#pragma once
#include <ATen/core/ivalue.h>
#include <c10/util/hash.h>
namespace c10 {
namespace detail {
// Equality predicate used for dictionary keys.
//
// Two tensor keys are considered equal only when they are the same object
// (identity, mirroring Python). Every other combination goes through
// _fastEqualsForContainer, which checks identity first for speed and then
// falls back to value comparison (see: [container equality]).
inline bool DictKeyEqualTo::operator()(const IValue& lhs, const IValue& rhs) const {
  const bool bothAreTensors = lhs.isTensor() && rhs.isTensor();
  return bothAreTensors ? lhs.is(rhs) : _fastEqualsForContainer(lhs, rhs);
}
// Forward declarations only; neither function is defined in this part of the file.
// getTypePtr<T>() is dereferenced below to compare against a dict's stored
// element types, and toString(Type) is used to build assertion messages.
template<class T> decltype(auto) getTypePtr();
std::string toString(const Type& type);
namespace impl {
// Converts a GenericDict into a statically typed Dict<Key, Value>.
//
// The dict's recorded key and value types must equal getTypePtr<Key>() and
// getTypePtr<Value>() respectively; a mismatch trips TORCH_INTERNAL_ASSERT
// with a message naming both the stored and the requested types.
// The underlying impl is moved out of `dict`, not copied.
template <class Key, class Value>
Dict<Key, Value> toTypedDict(GenericDict dict) {
  const auto& storedTypes = dict.impl_->elementTypes;

  TORCH_INTERNAL_ASSERT(
      *getTypePtr<Key>() == *storedTypes.keyType,
      "Tried to cast a Dict<", toString(*storedTypes.keyType), ", ", toString(*storedTypes.valueType),
      "> to a Dict<", toString(*getTypePtr<Key>()), ", ", toString(*getTypePtr<Value>()),
      ">. Key types mismatch.");
  TORCH_INTERNAL_ASSERT(
      *getTypePtr<Value>() == *storedTypes.valueType,
      "Tried to cast a Dict<", toString(*storedTypes.keyType), ", ", toString(*storedTypes.valueType),
      "> to a Dict<", toString(*getTypePtr<Key>()), ", ", toString(*getTypePtr<Value>()),
      ">. Value types mismatch.");

  return Dict<Key, Value>(std::move(dict.impl_));
}
// Converts a typed Dict<Key, Value> into a GenericDict by handing over its
// underlying impl (moved, not copied). No type checks are performed here.
template <class Key, class Value>
GenericDict toGenericDict(Dict<Key, Value> dict) {
  auto sharedImpl = std::move(dict.impl_);
  return GenericDict(std::move(sharedImpl));
}
namespace detail {
// Hash functor for IValue dictionary keys.
//
// Supported key kinds, checked in this order: int, string, double,
// complex double, bool, tensor, device. Tensor keys are hashed by the
// address of their TensorImpl, not by their contents.
//
// Throws std::runtime_error for any other kind of IValue; the message
// includes the value's tag kind.
inline size_t DictKeyHash::operator()(const IValue& ivalue) const {
  if (ivalue.isInt()) {
    return std::hash<int64_t>{}(ivalue.toInt());
  }
  if (ivalue.isString()) {
    return std::hash<c10::string_view>{}(ivalue.toStringView());
  }
  if (ivalue.isDouble()) {
    return std::hash<double>{}(ivalue.toDouble());
  }
  if (ivalue.isComplexDouble()) {
    return c10::hash<c10::complex<double>>{}(ivalue.toComplexDouble());
  }
  if (ivalue.isBool()) {
    return std::hash<bool>{}(ivalue.toBool());
  }
  if (ivalue.isTensor()) {
    // Only the pointer value is hashed; the tensor data is never inspected.
    TensorImpl* const tensorImpl = ivalue.toTensor().unsafeGetTensorImpl();
    return std::hash<TensorImpl*>{}(tensorImpl);
  }
  if (ivalue.isDevice()) {
    return std::hash<Device>{}(ivalue.toDevice());
  }

  // None of the hashable kinds matched.
  throw std::runtime_error(
      "Can't hash IValues with tag '" + ivalue.tagKind() + "'");
}
// Allocates a new DictImpl built from this object's `dict` and
// `elementTypes`, and returns an owning pointer to it.
inline intrusive_ptr<DictImpl> DictImpl::copy() const {
  auto duplicate = make_intrusive<DictImpl>(dict, elementTypes);
  return duplicate;
}
// Default constructor for statically typed dictionaries.
// The element types are obtained from the template parameters via
// getTypePtr<Key>() / getTypePtr<Value>().
// NOTE(review): the beginning of the member-initializer list is not visible
// in this view (the line after the signature closes parentheses that are
// never opened here) — confirm against the complete file.
template<class Key, class Value>
Dict<Key, Value>::Dict()
detail::DictImpl::DictElementTypes{getTypePtr<Key>(), getTypePtr<Value>()})) {
// Compile-time guard: this constructor must not be instantiated when Key or
// Value is IValue; the messages direct callers to GenericDict(keyType, valueType).
static_assert(!std::is_same<Key, IValue>::value, "This constructor is not valid for Dict<IValue, _>. Please use c10::impl::GenericDict(keyType, valueType) instead.");
static_assert(!std::is_same<Value, IValue>::value, "This constructor is not valid for Dict<_, IValue>. Please use c10::impl::GenericDict(keyType, valueType) instead.");
// Constructor taking explicit runtime key/value types.
// Delegates to the impl-taking constructor with a freshly allocated DictImpl
// whose element types are the (moved-in) keyType and valueType.
// NOTE(review): DictImpl::copy() in this file constructs a DictImpl from
// (dict, elementTypes); only the element types are visible as an argument
// here — confirm the argument list against the complete file.
template<class Key, class Value>
Dict<Key, Value>::Dict(TypePtr keyType, TypePtr valueType)
: Dict(make_intrusive<detail::DictImpl>(
detail::DictImpl::DictElementTypes {std::move(keyType), std::move(valueType)})) {
// Compile-time guard: only instantiable when both Key and Value are IValue,
// i.e. for c10::impl::GenericDict.
static_assert(std::is_same<Key, IValue>::value, "This constructor is only valid for c10::impl::GenericDict.");
static_assert(std::is_same<Value, IValue>::value, "This constructor is only valid for c10::impl::GenericDict.");
// Wraps an existing DictImpl: takes ownership of the given pointer by moving
// it into impl_. No type validation happens here.
template <class Key, class Value>
Dict<Key, Value>::Dict(c10::intrusive_ptr<detail::DictImpl>&& impl)
    : impl_(std::move(impl)) {}
// Returns a new Dict backed by the result of impl_->copy(), so the returned
// handle does not share its impl with this one.
template <class Key, class Value>
Dict<Key, Value> Dict<Key, Value>::copy() const {
  auto clonedImpl = impl_->copy();
  return Dict<Key, Value>(std::move(clonedImpl));
}
// Returns a Dict iterator wrapping the underlying map's begin() position.
template <class Key, class Value>
typename Dict<Key, Value>::iterator Dict<Key, Value>::begin() const {
  auto firstEntry = impl_->dict.begin();
  return iterator{firstEntry};
}
// Returns a Dict iterator wrapping the underlying map's end() position.
template <class Key, class Value>
typename Dict<Key, Value>::iterator Dict<Key, Value>::end() const {
  auto pastLastEntry = impl_->dict.end();
  return iterator{pastLastEntry};
}
// Reports whether the underlying map has no entries.
template <class Key, class Value>
bool Dict<Key, Value>::empty() const {
  const bool hasNoEntries = impl_->dict.empty();
  return hasNoEntries;
}
// Returns the number of entries reported by the underlying map.
template <class Key, class Value>
typename Dict<Key, Value>::size_type Dict<Key, Value>::size() const {
  auto& entries = impl_->dict;
  return entries.size();
}
// Dict::clear() — declared const, so it operates through impl_ rather than
// on the handle itself.
// NOTE(review): no body statements are visible in this view; presumably this
// empties the underlying map — confirm against the complete file.
template<class Key, class Value>
void Dict<Key, Value>::clear() const {
// Inserts an entry through the underlying map's insert() and returns
// {iterator wrapping the map's returned position, the map's returned bool}.
// Key_ / Value_ are the caller-supplied argument types; they must be usable
// to construct Key / Value, which is enforced at compile time.
template<class Key, class Value>
template<class Key_, class Value_>
std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert(Key_&& key, Value_&& value) const {
static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of Dict::insert");
static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of Dict::insert");
// NOTE(review): the contents of the std::pair<IValue, IValue>{...} initializer
// are not visible in this view — confirm how key/value are forwarded.
auto inserted = impl_->dict.insert(std::pair<IValue, IValue>{
return {iterator{inserted.first}, inserted.second};
// Calls the underlying map's insert_or_assign() and returns
// {iterator wrapping the map's returned position, the map's returned bool}.
// Key_ / Value_ must be usable to construct Key / Value (checked at compile time).
template<class Key, class Value>
template<class Key_, class Value_>
std::pair<typename Dict<Key, Value>::iterator, bool> Dict<Key, Value>::insert_or_assign(Key_&& key, Value_&& value) const {
static_assert(std::is_constructible<Key, Key_>::value, "Wrong type for the key argument of Dict::insert_or_assign");
static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of Dict::insert_or_assign");
// NOTE(review): the arguments to insert_or_assign are not visible in this
// view — confirm how key/value are forwarded.
auto inserted = impl_->dict.insert_or_assign(
return {iterator{inserted.first}, inserted.second};
// Iterator overload of erase(); returns nothing.
// NOTE(review): no body statements are visible in this view; presumably this
// removes the entry `iter` refers to from the underlying map — confirm
// against the complete file.
template<class Key, class Value>
void Dict<Key, Value>::erase(iterator iter) const {
// Key overload of erase(): forwards to the underlying map's erase(key) and
// returns its result. Marked C10_NODISCARD, so callers are expected to use
// the returned count.
template <class Key, class Value>
C10_NODISCARD size_t Dict<Key, Value>::erase(const Key& key) const {
  auto& entries = impl_->dict;
  return entries.erase(key);
}
// Looks up `key` with the underlying map's at() and converts the stored
// IValue to Value. Behaviour for a missing key is whatever the map's at()
// does.
template <class Key, class Value>
Value Dict<Key, Value>::at(const Key& key) const {
  auto& stored = impl_->dict.at(key);
  return stored.template to<Value>();
}
// Returns a Dict iterator wrapping the result of the underlying map's
// find(key).
template <class Key, class Value>
typename Dict<Key, Value>::iterator Dict<Key, Value>::find(const Key& key) const {
  auto position = impl_->dict.find(key);
  return iterator{position};
}
// True when find(key) yields something other than end().
template <class Key, class Value>
bool Dict<Key, Value>::contains(const Key& key) const {
  const auto pastTheEnd = end();
  const auto located = find(key);
  return pastTheEnd != located;
}
// Dict::reserve(count) — declared const, so it operates through impl_.
// NOTE(review): no body statements are visible in this view; presumably this
// forwards `count` to the underlying map's reserve — confirm against the
// complete file.
template<class Key, class Value>
void Dict<Key, Value>::reserve(size_type count) const {
// Returns the key type recorded in the impl's elementTypes.
template <class Key, class Value>
TypePtr Dict<Key, Value>::keyType() const {
  const auto& recordedTypes = impl_->elementTypes;
  return recordedTypes.keyType;
}
// Returns the value type recorded in the impl's elementTypes.
template <class Key, class Value>
TypePtr Dict<Key, Value>::valueType() const {
  const auto& recordedTypes = impl_->elementTypes;
  return recordedTypes.valueType;
}
// Overwrites the recorded key type with `t` (moved in). No check is made
// that existing keys conform to the new type — hence "unsafe".
template <class Key, class Value>
void Dict<Key, Value>::unsafeSetKeyType(TypePtr t) {
  auto& recordedTypes = impl_->elementTypes;
  recordedTypes.keyType = std::move(t);
}
// Overwrites the recorded value type with `t` (moved in). No check is made
// that existing values conform to the new type — hence "unsafe".
template <class Key, class Value>
void Dict<Key, Value>::unsafeSetValueType(TypePtr t) {
  auto& recordedTypes = impl_->elementTypes;
  recordedTypes.valueType = std::move(t);
}
// Equality for two Dicts of the same static type.
// Handles that share the same impl are equal without further work;
// otherwise the two impls are compared by value.
template <class Key_, class Value_>
bool operator==(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
  // Short-circuit: the value comparison only runs when identity differs.
  return (lhs.impl_ == rhs.impl_) || (*lhs.impl_ == *rhs.impl_);
}
// Inequality for two Dicts: the exact negation of operator==.
template <class Key_, class Value_>
bool operator!=(const Dict<Key_, Value_>& lhs, const Dict<Key_, Value_>& rhs) {
  const bool areEqual = (lhs == rhs);
  return !areEqual;
}
// Identity check: true when this Dict and `rhs` hold the same impl_,
// i.e. they are handles to the same underlying dictionary. Contents are
// not compared here (operator== does that).
template <class Key, class Value>
bool Dict<Key, Value>::is(const Dict& rhs) const {
return this->impl_ == rhs.impl_;