Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
tensordict / _td.py
Size: Mime:
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import numbers
import os
import weakref
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor, wait
from copy import copy
from numbers import Number
from pathlib import Path
from textwrap import indent
from typing import Any, Callable, Dict, Iterable, Iterator, List, Sequence, Tuple, Type
from warnings import warn

import numpy as np
import orjson as json
import torch

from tensordict.base import (
    _ACCEPTED_CLASSES,
    _default_is_leaf,
    _expand_to_match_shape,
    _is_leaf_nontensor,
    _is_tensor_collection,
    _load_metadata,
    _NESTED_TENSORS_AS_LISTS,
    _register_tensor_class,
    BEST_ATTEMPT_INPLACE,
    CompatibleType,
    is_tensor_collection,
    NO_DEFAULT,
    T,
    TensorDictBase,
)
from tensordict.memmap import MemoryMappedTensor
from tensordict.utils import (
    _add_batch_dim_pre_hook,
    _as_context_manager,
    _BatchedUninitializedBuffer,
    _BatchedUninitializedParameter,
    _check_inbuild,
    _clone_value,
    _get_item,
    _get_leaf_tensordict,
    _get_shape_from_args,
    _getitem_batch_size,
    _index_preserve_data_ptr,
    _is_shared,
    _is_tensorclass,
    _KEY_ERROR,
    _LOCK_ERROR,
    _NON_STR_KEY_ERR,
    _NON_STR_KEY_TUPLE_ERR,
    _parse_to,
    _prune_selected_keys,
    _set_item,
    _set_max_batch_size,
    _shape,
    _STRDTYPE2DTYPE,
    _StringKeys,
    _StringOnlyDict,
    _sub_index,
    _unravel_key_to_tuple,
    _zip_strict,
    Buffer,
    cache,
    convert_ellipsis_to_idx,
    DeviceType,
    expand_as_right,
    IndexType,
    is_non_tensor,
    is_tensorclass,
    KeyedJaggedTensor,
    lock_blocked,
    NestedKey,
    unravel_key,
    unravel_key_list,
)
from torch import Tensor
from torch._dynamo import graph_break
from torch.jit._shape_functions import infer_size_impl
from torch.nn.parameter import UninitializedTensorMixin
from torch.utils._pytree import tree_map

try:
    from functorch import dim as ftdim

    _has_funcdim = True
except ImportError:
    from tensordict.utils import _ftdim_mock as ftdim

    _has_funcdim = False
try:
    from torch.compiler import is_dynamo_compiling
except ImportError:  # torch 2.0
    from torch._dynamo import is_compiling as is_dynamo_compiling

_register_tensor_class(ftdim.Tensor)

__base__setattr__ = torch.nn.Module.__setattr__

_has_mps = torch.backends.mps.is_available()
_has_cuda = torch.cuda.is_available()
_has_functorch = False
try:
    try:
        from torch._C._functorch import (  # @manual=fbcode//caffe2:torch
            _add_batch_dim,
            _remove_batch_dim,
            is_batchedtensor,
        )
    except ImportError:
        from functorch._C import is_batchedtensor  # @manual=fbcode//functorch:_C

    _has_functorch = True
except ImportError:
    _has_functorch = False

    def is_batchedtensor(tensor: Tensor) -> bool:
        """Placeholder for the functorch function."""
        return False


class TensorDict(TensorDictBase):
    """A batched dictionary of tensors.

    TensorDict is a tensor container where all tensors are stored in a
    key-value pair fashion and where each element shares the same first ``N``
    leading dimensions shape, where is an arbitrary number with ``N >= 0``.

    Additionally, if the tensordict has a specified device, then each element
    must share that device.

    TensorDict instances support many regular tensor operations with the notable
    exception of algebraic operations:

    - operations on shape: when a shape operation is called (indexing,
      reshape, view, expand, transpose, permute,
      unsqueeze, squeeze, masking etc), the operations is done as if it
      was executed on a tensor of the same shape as the batch size then
      expended to the right, e.g.:

        >>> td = TensorDict({'a': torch.zeros(3, 4, 5)}, batch_size=[3, 4])
        >>> # returns a TensorDict of batch size [3, 4, 1]:
        >>> td_unsqueeze = td.unsqueeze(-1)
        >>> # returns a TensorDict of batch size [12]
        >>> td_view = td.view(-1)
        >>> # returns a tensor of batch size [12, 4]
        >>> a_view = td.view(-1).get("a")

    - casting operations: a TensorDict can be cast on a different device using

        >>> td_cpu = td.to("cpu")
        >>> dictionary = td.to_dict()

      A call of the `.to()` method with a dtype will return an error.

    - Cloning (:meth:`~TensorDictBase.clone`), contiguous (:meth:`~TensorDictBase.contiguous`);

    - Reading: `td.get(key)`, `td.get_at(key, index)`

    - Content modification: :obj:`td.set(key, value)`, :obj:`td.set_(key, value)`,
      :obj:`td.update(td_or_dict)`, :obj:`td.update_(td_or_dict)`, :obj:`td.fill_(key,
      value)`, :obj:`td.rename_key_(old_name, new_name)`, etc.

    - Operations on multiple tensordicts: `torch.cat(tensordict_list, dim)`,
      `torch.stack(tensordict_list, dim)`, `td1 == td2`, `td.apply(lambda x+y, other_td)` etc.

    Args:
        source (TensorDict or Dict[NestedKey, Union[Tensor, TensorDictBase]]): a
            data source. If empty, the tensordict can be populated subsequently.
            A ``TensorDict`` can also be built via a sequence of keyword arguments,
            as it is the case for ``dict(...)``.
        batch_size (iterable of int, optional): a batch size for the
            tensordict. The batch size can be modified subsequently as long
            as it is compatible with its content.
            If not batch-size is provided, an empty batch-size is assumed (it
            is not inferred automatically from the data). To automatically set
            the batch-size, refer to :meth:`~.auto_batch_size_`.
        device (torch.device or compatible type, optional): a device for the
            TensorDict. If provided, all tensors will be stored on that device.
            If not, tensors on different devices are allowed.
        names (lsit of str, optional): the names of the dimensions of the
            tensordict. If provided, its length must match the one of the
            ``batch_size``. Defaults to ``None`` (no dimension name, or ``None``
            for every dimension).
        non_blocking (bool, optional): if ``True`` and a device is passed, the tensordict
            is delivered without synchronization. This is the fastest option but is only
            safe when casting from cpu to cuda (otherwise a synchronization call must be
            implemented by the user).
            If ``False`` is passed, every tensor movement will be done synchronously.
            If ``None`` (default), the device casting will be done asynchronously but
            a synchronization will be executed after creation if required. This option
            should generally be faster than ``False`` and potentially slower than ``True``.
        lock (bool, optional): if ``True``, the resulting tensordict will be
            locked.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> source = {'random': torch.randn(3, 4),
        ...     'zeros': torch.zeros(3, 4, 5)}
        >>> batch_size = [3]
        >>> td = TensorDict(source, batch_size=batch_size)
        >>> print(td.shape)  # equivalent to td.batch_size
        torch.Size([3])
        >>> td_unqueeze = td.unsqueeze(-1)
        >>> print(td_unqueeze.get("zeros").shape)
        torch.Size([3, 1, 4, 5])
        >>> print(td_unqueeze[0].shape)
        torch.Size([1])
        >>> print(td_unqueeze.view(-1).shape)
        torch.Size([3])
        >>> print((td.clone()==td).all())
        True

    """

    _td_dim_names = None
    _is_shared = False
    _is_memmap = False
    _has_exclusive_keys = False

    def __init__(
        self,
        source: T | dict[str, CompatibleType] = None,
        batch_size: Sequence[int] | torch.Size | int | None = None,
        device: DeviceType | None = None,
        names: Sequence[str] | None = None,
        non_blocking: bool = None,
        lock: bool = False,
        **kwargs,
    ) -> None:
        if (source is not None) and kwargs:
            raise ValueError(
                "Either a dictionary or a sequence of kwargs must be provided, not both."
            )
        source = source if not kwargs else kwargs
        if names and is_dynamo_compiling():
            graph_break()
        has_device = False
        sub_non_blocking = False
        if device is not None:
            has_device = True
            if non_blocking is None:
                sub_non_blocking = True
                non_blocking = False
            else:
                sub_non_blocking = non_blocking
            device = torch.device(device) if device is not None else None
            if _has_mps:
                # With MPS, an explicit sync is required
                sub_non_blocking = True
        self._device = device

        self._tensordict = _StringOnlyDict()
        if source is None:
            source = {}
        if not isinstance(source, (TensorDictBase, dict)):
            raise ValueError(
                "A TensorDict source is expected to be a TensorDictBase "
                f"sub-type or a dictionary, found type(source)={type(source)}."
            )
        self._batch_size = self._parse_batch_size(source, batch_size)
        # TODO: this breaks when stacking tensorclasses with dynamo
        if not is_dynamo_compiling():
            self.names = names

        for key, value in source.items():
            self.set(key, value, non_blocking=sub_non_blocking)
        if not non_blocking and sub_non_blocking and has_device:
            self._sync_all()
        if lock:
            self.lock_()

    @classmethod
    def _new_unsafe(
        cls,
        source: T | dict[str, CompatibleType] = None,
        batch_size: Sequence[int] | torch.Size | int | None = None,
        device: DeviceType | None = None,
        names: Sequence[str] | None = None,
        non_blocking: bool = None,
        lock: bool = False,
        nested: bool = True,
    ) -> TensorDict:
        if is_dynamo_compiling():
            return TensorDict(
                source,
                batch_size=batch_size,
                device=device,
                names=names,
                non_blocking=non_blocking,
                lock=lock,
            )
        self = cls.__new__(cls)
        sub_non_blocking = False
        if device is not None:
            if non_blocking is None:
                sub_non_blocking = True
                non_blocking = False
            else:
                sub_non_blocking = non_blocking
            device = torch.device(device) if device is not None else None
            if _has_mps:
                # With MPS, an explicit sync is required
                sub_non_blocking = True
        self._device = device
        self._tensordict = _tensordict = _StringOnlyDict()
        self._batch_size = batch_size
        if source:  # faster than calling items
            for key, value in source.items():
                if nested and isinstance(value, dict):
                    value = TensorDict._new_unsafe(
                        source=value,
                        batch_size=self._batch_size,
                        device=self._device,
                        non_blocking=sub_non_blocking,
                    )
                _tensordict[key] = value
        # assert names is None or len(names) == self.batch_dims, (names, batch_size)
        # assert (names is None) or (not all(name is None for name in names))
        self._td_dim_names = names
        if lock:
            self.lock_()
        return self

    @classmethod
    def from_module(
        cls,
        module: torch.nn.Module,
        as_module: bool = False,
        lock: bool = False,
        use_state_dict: bool = False,
        filter_empty: bool = True,
    ):
        result = cls._from_module(
            module=module,
            as_module=as_module,
            use_state_dict=use_state_dict,
            filter_empty=filter_empty,
        )
        if result is None:
            result = TensorDict._new_unsafe({}, batch_size=torch.Size(()))
        if lock:
            result.lock_()
        return result

    @classmethod
    def _from_module(
        cls,
        module: torch.nn.Module,
        as_module: bool = False,
        use_state_dict: bool = False,
        prefix="",
        filter_empty: bool = True,
    ):
        from tensordict.nn import TensorDictParams

        if isinstance(module, TensorDictParams):
            return module
        destination = {}
        if use_state_dict:
            keep_vars = False
            # do we need this feature atm?
            local_metadata = {}
            # if hasattr(destination, "_metadata"):
            #     destination._metadata[prefix[:-1]] = local_metadata
            for hook in module._state_dict_pre_hooks.values():
                hook(module, prefix, keep_vars)
            module._save_to_state_dict(destination, "", keep_vars)
        else:
            for name, param in module._parameters.items():
                if param is None:
                    continue
                destination[name] = param
            for name, buffer in module._buffers.items():
                if buffer is None:
                    continue
                destination[name] = buffer

        if use_state_dict:
            for hook in module._state_dict_hooks.values():
                hook_result = hook(module, destination, prefix, local_metadata)
                if hook_result is not None:
                    destination = hook_result
        if not filter_empty or destination:
            destination_set = True
            destination = TensorDict._new_unsafe(destination, batch_size=torch.Size(()))
        else:
            destination_set = False
        for name, submodule in module._modules.items():
            if submodule is not None:
                subtd = cls._from_module(
                    module=submodule,
                    as_module=False,
                    use_state_dict=use_state_dict,
                    prefix=prefix + name + ".",
                    filter_empty=filter_empty,
                )
                if subtd is not None:
                    if not destination_set:
                        destination = TensorDict._new_unsafe(batch_size=torch.Size(()))
                        destination_set = True
                    destination._set_str(
                        name, subtd, validated=True, inplace=False, non_blocking=False
                    )
        if not destination_set:
            return

        if as_module:
            from tensordict.nn.params import TensorDictParams

            return TensorDictParams(destination, no_convert=True)
        return destination

    def is_empty(self):

        for item in self._tensordict.values():
            # we need to check if item is empty
            if _is_tensor_collection(type(item)):
                if not item.is_empty():
                    return False

                if is_non_tensor(item):
                    return False
            else:
                return False
        return True

    def _to_module(
        self,
        module,
        *,
        inplace: bool | None = None,
        return_swap: bool = True,
        swap_dest=None,
        memo=None,
        use_state_dict: bool = False,
        non_blocking: bool = False,
    ):
        is_dynamo = is_dynamo_compiling()
        if is_dynamo:
            _check_inbuild()

        if not use_state_dict and isinstance(module, TensorDictBase):
            if return_swap:
                swap = module.copy()
                module._param_td = getattr(self, "_param_td", self)
                return swap
            else:
                module.update(self)
                return

        hooks = memo["hooks"]
        if return_swap:
            _swap = {}
            if not is_dynamo:
                memo[weakref.ref(module)] = _swap

        if use_state_dict:
            if inplace is not None:
                raise RuntimeError(
                    "inplace argument cannot be passed when use_state_dict=True."
                )
            # execute module's pre-hooks
            state_dict = self.flatten_keys(".")
            prefix = ""
            strict = True
            local_metadata = {}
            missing_keys = []
            unexpected_keys = []
            error_msgs = []
            for hook in module._load_state_dict_pre_hooks.values():
                hook(
                    state_dict,
                    prefix,
                    local_metadata,
                    strict,
                    missing_keys,
                    unexpected_keys,
                    error_msgs,
                )

            def convert_type(x, y):
                if isinstance(y, torch.nn.Parameter):
                    return torch.nn.Parameter(x)
                if isinstance(y, Buffer):
                    return Buffer(x)
                return x

            input = state_dict.unflatten_keys(".")._fast_apply(
                convert_type, self, propagate_lock=True
            )
        else:
            input = self
            inplace = bool(inplace)

        # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
        if type(module).__setattr__ is __base__setattr__:
            __dict__ = module.__dict__
            _parameters = __dict__["_parameters"]
            _buffers = __dict__["_buffers"]
        else:
            __dict__ = None

        for key, value in input.items():
            if isinstance(value, (Tensor, ftdim.Tensor)):
                # For Dynamo, we use regular set/delattr as we're not
                #  much afraid by overhead (and dynamo doesn't like those
                #  hacks we're doing).
                if __dict__ is not None:
                    # if setattr is the native nn.Module.setattr, we can rely on _set_tensor_dict
                    local_out = _set_tensor_dict(
                        __dict__,
                        _parameters,
                        _buffers,
                        hooks,
                        module,
                        key,
                        value,
                        inplace,
                    )
                else:
                    if return_swap:
                        local_out = getattr(module, key)
                    if not inplace:
                        # use specialized __setattr__ if needed
                        delattr(module, key)
                        setattr(module, key, value)
                    else:
                        new_val = local_out
                        if return_swap:
                            local_out = local_out.clone()
                        new_val.data.copy_(value.data, non_blocking=non_blocking)
            else:
                if __dict__ is not None:
                    child = __dict__["_modules"][key]
                else:
                    child = module._modules.get(key)

                if not is_dynamo:
                    local_out = memo.get(weakref.ref(child), NO_DEFAULT)

                if is_dynamo or local_out is NO_DEFAULT:
                    local_out = value._to_module(
                        child,
                        inplace=inplace,
                        return_swap=return_swap,
                        swap_dest={},  # we'll be calling update later
                        memo=memo,
                        use_state_dict=use_state_dict,
                        non_blocking=non_blocking,
                    )

            if return_swap:
                _swap[key] = local_out

        if return_swap:
            if isinstance(swap_dest, dict):
                return _swap
            elif swap_dest is not None:

                def _quick_set(swap_dict, swap_td):
                    for key, val in swap_dict.items():
                        if isinstance(val, dict):
                            _quick_set(val, swap_td._get_str(key, default=NO_DEFAULT))
                        elif swap_td._get_str(key, None) is not val:
                            swap_td._set_str(
                                key,
                                val,
                                inplace=False,
                                validated=True,
                                non_blocking=non_blocking,
                            )

                _quick_set(_swap, swap_dest)
                return swap_dest
            else:
                return TensorDict._new_unsafe(_swap, batch_size=[])

    def __ne__(self, other: object) -> T | bool:
        if _is_tensorclass(other):
            return other != self
        if isinstance(other, (dict,)):
            other = self.from_dict_instance(other)
        if _is_tensor_collection(type(other)):
            keys1 = set(self.keys())
            keys2 = set(other.keys())
            if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
                raise KeyError(
                    f"keys in {self} and {other} mismatch, got {keys1} and {keys2}"
                )
            d = {}
            for key, item1 in self.items():
                d[key] = item1 != other.get(key)
            return TensorDict(batch_size=self.batch_size, source=d, device=self.device)
        if isinstance(other, (numbers.Number, Tensor)):
            return TensorDict(
                {key: value != other for key, value in self.items()},
                self.batch_size,
                device=self.device,
            )
        return True

    def __xor__(self, other: object) -> T | bool:
        if _is_tensorclass(other):
            return other ^ self
        if isinstance(other, (dict,)):
            other = self.from_dict_instance(other)
        if _is_tensor_collection(type(other)):
            keys1 = set(self.keys())
            keys2 = set(other.keys())
            if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
                raise KeyError(
                    f"keys in {self} and {other} mismatch, got {keys1} and {keys2}"
                )
            d = {}
            for key, item1 in self.items():
                d[key] = item1 ^ other.get(key)
            return TensorDict(batch_size=self.batch_size, source=d, device=self.device)
        if isinstance(other, (numbers.Number, Tensor)):
            return TensorDict(
                {key: value ^ other for key, value in self.items()},
                self.batch_size,
                device=self.device,
            )
        return True

    def __or__(self, other: object) -> T | bool:
        if _is_tensorclass(other):
            return other | self
        if isinstance(other, (dict,)):
            other = self.from_dict_instance(other)
        if _is_tensor_collection(type(other)):
            keys1 = set(self.keys())
            keys2 = set(other.keys())
            if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
                raise KeyError(
                    f"keys in {self} and {other} mismatch, got {keys1} and {keys2}"
                )
            d = {}
            for key, item1 in self.items():
                d[key] = item1 | other.get(key)
            return TensorDict(batch_size=self.batch_size, source=d, device=self.device)
        if isinstance(other, (numbers.Number, Tensor)):
            return TensorDict(
                {key: value | other for key, value in self.items()},
                self.batch_size,
                device=self.device,
            )
        return False

    def __eq__(self, other: object) -> T | bool:
        if is_tensorclass(other):
            return other == self
        if isinstance(other, (dict,)):
            other = self.from_dict_instance(other)
        if _is_tensor_collection(type(other)):
            keys1 = set(self.keys())
            keys2 = set(other.keys())
            if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
                keys1 = sorted(
                    keys1,
                    key=lambda key: "".join(key) if isinstance(key, tuple) else key,
                )
                keys2 = sorted(
                    keys2,
                    key=lambda key: "".join(key) if isinstance(key, tuple) else key,
                )
                raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
            d = {}
            for key, item1 in self.items():
                d[key] = item1 == other.get(key)
            return TensorDict(source=d, batch_size=self.batch_size, device=self.device)
        if isinstance(other, (numbers.Number, Tensor)):
            return TensorDict(
                {key: value == other for key, value in self.items()},
                self.batch_size,
                device=self.device,
            )
        return False

    def __ge__(self, other: object) -> T | bool:
        if is_tensorclass(other):
            return other <= self
        if isinstance(other, (dict,)):
            other = self.from_dict_instance(other)
        if _is_tensor_collection(type(other)):
            keys1 = set(self.keys())
            keys2 = set(other.keys())
            if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
                keys1 = sorted(
                    keys1,
                    key=lambda key: "".join(key) if isinstance(key, tuple) else key,
                )
                keys2 = sorted(
                    keys2,
                    key=lambda key: "".join(key) if isinstance(key, tuple) else key,
                )
                raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
            d = {}
            for key, item1 in self.items():
                d[key] = item1 >= other.get(key)
            return TensorDict(source=d, batch_size=self.batch_size, device=self.device)
        if isinstance(other, (numbers.Number, Tensor)):
            return TensorDict(
                {key: value >= other for key, value in self.items()},
                self.batch_size,
                device=self.device,
            )
        return False

    def __gt__(self, other: object) -> T | bool:
        if is_tensorclass(other):
            return other < self
        if isinstance(other, (dict,)):
            other = self.from_dict_instance(other)
        if _is_tensor_collection(type(other)):
            keys1 = set(self.keys())
            keys2 = set(other.keys())
            if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
                keys1 = sorted(
                    keys1,
                    key=lambda key: "".join(key) if isinstance(key, tuple) else key,
                )
                keys2 = sorted(
                    keys2,
                    key=lambda key: "".join(key) if isinstance(key, tuple) else key,
                )
                raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
            d = {}
            for key, item1 in self.items():
                d[key] = item1 > other.get(key)
            return TensorDict(source=d, batch_size=self.batch_size, device=self.device)
        if isinstance(other, (numbers.Number, Tensor)):
            return TensorDict(
                {key: value > other for key, value in self.items()},
                self.batch_size,
                device=self.device,
            )
        return False

    def __le__(self, other: object) -> T | bool:
        if is_tensorclass(other):
            return other >= self
        if isinstance(other, (dict,)):
            other = self.from_dict_instance(other)
        if _is_tensor_collection(type(other)):
            keys1 = set(self.keys())
            keys2 = set(other.keys())
            if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
                keys1 = sorted(
                    keys1,
                    key=lambda key: "".join(key) if isinstance(key, tuple) else key,
                )
                keys2 = sorted(
                    keys2,
                    key=lambda key: "".join(key) if isinstance(key, tuple) else key,
                )
                raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
            d = {}
            for key, item1 in self.items():
                d[key] = item1 <= other.get(key)
            return TensorDict(source=d, batch_size=self.batch_size, device=self.device)
        if isinstance(other, (numbers.Number, Tensor)):
            return TensorDict(
                {key: value <= other for key, value in self.items()},
                self.batch_size,
                device=self.device,
            )
        return False

    def __lt__(self, other: object) -> T | bool:
        if is_tensorclass(other):
            return other > self
        if isinstance(other, (dict,)):
            other = self.from_dict_instance(other)
        if _is_tensor_collection(type(other)):
            keys1 = set(self.keys())
            keys2 = set(other.keys())
            if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
                keys1 = sorted(
                    keys1,
                    key=lambda key: "".join(key) if isinstance(key, tuple) else key,
                )
                keys2 = sorted(
                    keys2,
                    key=lambda key: "".join(key) if isinstance(key, tuple) else key,
                )
                raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
            d = {}
            for key, item1 in self.items():
                d[key] = item1 < other.get(key)
            return TensorDict(source=d, batch_size=self.batch_size, device=self.device)
        if isinstance(other, (numbers.Number, Tensor)):
            return TensorDict(
                {key: value < other for key, value in self.items()},
                self.batch_size,
                device=self.device,
            )
        return False

    def __setitem__(
        self,
        index: IndexType,
        value: T | dict | numbers.Number | CompatibleType,
    ) -> None:
        istuple = isinstance(index, tuple)
        if istuple or isinstance(index, str):
            # try:
            index_unravel = _unravel_key_to_tuple(index)
            if index_unravel:
                self._set_tuple(
                    index_unravel,
                    value,
                    inplace=(
                        BEST_ATTEMPT_INPLACE
                        if isinstance(self, _SubTensorDict)
                        else False
                    ),
                    validated=False,
                    non_blocking=False,
                )
                return

        # we must use any and because using Ellipsis in index can break with some indices
        if index is Ellipsis or (
            isinstance(index, tuple) and any(idx is Ellipsis for idx in index)
        ):
            index = convert_ellipsis_to_idx(index, self.batch_size)

        if isinstance(value, (TensorDictBase, dict)):
            indexed_bs = _getitem_batch_size(self.batch_size, index)
            if isinstance(value, dict):
                value = self.from_dict_instance(value, batch_size=indexed_bs)
                # value = self.empty(recurse=True)[index].update(value)
            if value.batch_size != indexed_bs:
                if value.shape == indexed_bs[-len(value.shape) :]:
                    # try to expand on the left (broadcasting)
                    value = value.expand(indexed_bs)
                else:
                    try:
                        # copy and change batch_size if can't be expanded
                        value = value.copy()
                        value.batch_size = indexed_bs
                    except RuntimeError as err:
                        raise RuntimeError(
                            f"indexed destination TensorDict batch size is {indexed_bs} "
                            f"(batch_size = {self.batch_size}, index={index}), "
                            f"which differs from the source batch size {value.batch_size}"
                        ) from err

            keys = set(self.keys())
            subtd = None
            for value_key, item in value.items():
                if value_key in keys:
                    self._set_at_str(
                        value_key, item, index, validated=False, non_blocking=False
                    )
                else:
                    if subtd is None:
                        subtd = self._get_sub_tensordict(index)
                    subtd.set(value_key, item, inplace=True, non_blocking=False)
        else:
            for key in self.keys():
                self.set_at_(key, value, index)

    def all(self, dim: int = None) -> bool | TensorDictBase:
        if dim is not None and (dim >= self.batch_dims or dim < -self.batch_dims):
            raise RuntimeError(
                "dim must be greater than or equal to -tensordict.batch_dims and "
                "smaller than tensordict.batch_dims"
            )
        if dim is not None:
            if dim < 0:
                dim = self.batch_dims + dim

            names = None
            if self._has_names():
                names = copy(self.names)
                names = [name for i, name in enumerate(names) if i != dim]

            return TensorDict(
                source={key: value.all(dim=dim) for key, value in self.items()},
                batch_size=[b for i, b in enumerate(self.batch_size) if i != dim],
                device=self.device,
                names=names,
            )
        return all(value.all() for value in self.values())

    def any(self, dim: int = None) -> bool | TensorDictBase:
        if dim is not None and (dim >= self.batch_dims or dim < -self.batch_dims):
            raise RuntimeError(
                "dim must be greater than or equal to -tensordict.batch_dims and "
                "smaller than tensordict.batch_dims"
            )
        if dim is not None:
            if dim < 0:
                dim = self.batch_dims + dim

            names = None
            if self._has_names():
                names = copy(self.names)
                names = [name for i, name in enumerate(names) if i != dim]

            return TensorDict(
                source={key: value.any(dim=dim) for key, value in self.items()},
                batch_size=[b for i, b in enumerate(self.batch_size) if i != dim],
                device=self.device,
                names=names,
            )
        return any([value.any() for value in self.values()])

    def _cast_reduction(
        self,
        *,
        reduction_name,
        dim=NO_DEFAULT,
        keepdim=NO_DEFAULT,
        tuple_ok=True,
        further_reduce: bool,
        **kwargs,
    ):
        if further_reduce:
            # It is not very memory-efficient to do this, but it's the easiest to cover all use cases
            if dim is NO_DEFAULT:
                agglomerate = [
                    val.contiguous().flatten()
                    for val in self._values_list(
                        True, True, is_leaf=_NESTED_TENSORS_AS_LISTS
                    )
                ]
                agglomerate = torch.cat(agglomerate, dim=0)
                return getattr(torch, reduction_name)(agglomerate)
            else:
                agglomerate = list(
                    self._values_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
                )
                agglomerate = torch.cat(agglomerate, dim=0)
                return getattr(torch, reduction_name)(
                    agglomerate, keepdim=keepdim, dim=dim
                )

        def proc_dim(dim, tuple_ok=True):
            if dim is None:
                return dim
            if isinstance(dim, tuple):
                if tuple_ok:
                    return tuple(_d for d in dim for _d in proc_dim(d, tuple_ok=False))
                return dim
            if dim >= self.batch_dims or dim < -self.batch_dims:
                raise RuntimeError(
                    "dim must be greater than or equal to -tensordict.batch_dims and "
                    "smaller than tensordict.batch_dims"
                )
            if dim < 0:
                return (self.batch_dims + dim,)
            return (dim,)

        if dim is not NO_DEFAULT:
            dim = proc_dim(dim, tuple_ok=tuple_ok)
            if not tuple_ok:
                dim = dim[0]
        if dim is not NO_DEFAULT or keepdim:
            names = None
            if self._has_names():
                names = copy(self.names)
                if not keepdim and isinstance(dim, tuple):
                    names = [name for i, name in enumerate(names) if i not in dim]
                else:
                    names = [name for i, name in enumerate(names) if i != dim]
            if dim is not NO_DEFAULT:
                kwargs["dim"] = dim
            if keepdim is not NO_DEFAULT:
                kwargs["keepdim"] = keepdim

            def reduction(val):
                result = getattr(val, reduction_name)(
                    **kwargs,
                )
                return result

            if dim not in (None, NO_DEFAULT):
                if not keepdim:
                    if isinstance(dim, tuple):
                        batch_size = [
                            b for i, b in enumerate(self.batch_size) if i not in dim
                        ]
                    else:
                        batch_size = [
                            b for i, b in enumerate(self.batch_size) if i != dim
                        ]
                else:
                    if isinstance(dim, tuple):
                        batch_size = [
                            b if i not in dim else 1
                            for i, b in enumerate(self.batch_size)
                        ]
                    else:
                        batch_size = [
                            b if i != dim else 1 for i, b in enumerate(self.batch_size)
                        ]

            else:
                batch_size = [1 for b in self.batch_size]

            return self._fast_apply(
                reduction,
                call_on_nested=True,
                batch_size=torch.Size(batch_size),
                device=self.device,
                names=names,
            )

        def reduction(val):
            return getattr(val, reduction_name)(**kwargs)

        return self._fast_apply(
            reduction,
            call_on_nested=True,
            batch_size=torch.Size([]),
            device=self.device,
            names=None,
        )

    def _multithread_apply_flat(
        self,
        fn: Callable,
        *others: T,
        call_on_nested: bool = False,
        default: Any = NO_DEFAULT,
        named: bool = False,
        nested_keys: bool = False,
        prefix: tuple = (),
        is_leaf: Callable = None,
        executor: ThreadPoolExecutor,
        futures: List[Future],
        local_futures: List,
    ) -> None:
        if is_leaf is None:
            is_leaf = _default_is_leaf
        for key, item in self.items():
            if (
                not call_on_nested
                and not is_leaf(type(item))
                # and not is_non_tensor(item)
            ):
                if default is not NO_DEFAULT:
                    _others = [_other._get_str(key, default=None) for _other in others]
                    _others = [
                        self.empty(recurse=True) if _other is None else _other
                        for _other in _others
                    ]
                else:
                    _others = [
                        _other._get_str(key, default=NO_DEFAULT) for _other in others
                    ]
                local_futures.append([])
                item._multithread_apply_flat(
                    fn,
                    *_others,
                    named=named,
                    nested_keys=nested_keys,
                    prefix=prefix + (key,),
                    is_leaf=is_leaf,
                    executor=executor,
                    futures=futures,
                    local_futures=local_futures[-1],
                )
            else:
                _others = [_other._get_str(key, default=default) for _other in others]
                if named:
                    if nested_keys:
                        future = executor.submit(
                            fn, prefix + (key,) if prefix != () else key, item, *_others
                        )
                    else:
                        future = executor.submit(fn, key, item, *_others)
                else:
                    future = executor.submit(fn, item, *_others)
                futures.append(future)
                local_futures.append(future)

    def _multithread_rebuild(
        self,
        *,
        batch_size: Sequence[int] | None = None,
        device: torch.device | None = NO_DEFAULT,
        names: Sequence[str] | None = NO_DEFAULT,
        inplace: bool = False,
        checked: bool = False,
        out: TensorDictBase | None = None,
        filter_empty: bool = False,
        executor: ThreadPoolExecutor,
        futures: List[Future],
        local_futures: List,
        subs_results: Dict[Future, Any] | None = None,
        multithread_set: bool = False,  # Experimental
        **constructor_kwargs,
    ) -> None:
        if constructor_kwargs:
            raise RuntimeError(
                f"constructor_kwargs not supported for class {type(self)}."
            )
        # Rebuilds a tensordict from the futures of its leaves
        if inplace:
            result = self
            is_locked = result.is_locked
        elif out is not None:
            result = out
            if out.is_locked:
                raise RuntimeError(_LOCK_ERROR)
            is_locked = False
            if batch_size is not None and batch_size != out.batch_size:
                raise RuntimeError(
                    "batch_size and out.batch_size must be equal when both are provided."
                )
            if device is not NO_DEFAULT and device != out.device:
                raise RuntimeError(
                    "device and out.device must be equal when both are provided."
                )
        else:

            def make_result(names=names, batch_size=batch_size):
                if names is NO_DEFAULT:
                    if batch_size is not None:
                        # erase names
                        names = None
                    elif batch_size is None:
                        names = self.names if self._has_names() else None
                return self.empty(batch_size=batch_size, device=device, names=names)

            result = make_result()
            is_locked = False

        any_set = set()

        if isinstance(result, _SubTensorDict):

            def setter(
                item_trsf,
                key,
                inplace=inplace,
                result=result,
            ):
                set_item = item_trsf is not None
                any_set.add(set_item)
                if not set_item:
                    return
                result.set(key, item_trsf, inplace=inplace)

        elif isinstance(result, TensorDict) and checked and (inplace is not True):

            def setter(
                item_trsf,
                key,
                result=result,
            ):
                set_item = item_trsf is not None
                any_set.add(set_item)
                if not set_item:
                    return
                result._tensordict[key] = item_trsf

        else:

            local_inplace = BEST_ATTEMPT_INPLACE if inplace else False

            def setter(
                item_trsf,
                key,
                result=result,
                checked=checked,
            ):
                set_item = item_trsf is not None
                any_set.add(set_item)
                if not set_item:
                    return

                result._set_str(
                    key,
                    item_trsf,
                    inplace=local_inplace,
                    validated=checked,
                    non_blocking=False,
                )

        for i, (key, local_future) in enumerate(
            _zip_strict(self.keys(), local_futures)
        ):

            if isinstance(local_future, list):
                # We can't make this a future as it could cause deadlocks:
                #  If we put a future over the root and this triggers another
                #  call on the leaves, the root will occupy a spot in the execution queue
                #  and wait for completion, potentially preventing the leaf of
                #  getting in the execution queue at all.
                td = self._get_str(key, default=None)
                item_trsf = td._multithread_rebuild(
                    batch_size=batch_size,
                    device=device,
                    names=names,
                    inplace=inplace,
                    checked=checked,
                    out=out,
                    filter_empty=filter_empty,
                    executor=executor,
                    futures=futures,
                    local_futures=local_future,
                    subs_results=subs_results,
                    multithread_set=multithread_set,
                    **constructor_kwargs,
                )
                if multithread_set:
                    local_future = executor.submit(setter, item_trsf=item_trsf, key=key)
                    local_futures[i] = local_future
                    futures.append(local_future)
                else:
                    setter(item_trsf=item_trsf, key=key)
            else:
                if multithread_set:
                    if subs_results is not None:
                        local_result = subs_results[local_future]
                    else:
                        # TODO: check if add_done_callback can safely be used here
                        #  The issue is that it does not raises an exception encountered during the
                        #  execution, resulting in UBs.
                        local_result = local_future.result()
                    local_future = executor.submit(
                        setter, item_trsf=local_result, key=key
                    )
                    futures.append(local_future)
                    local_futures[i] = local_future
                else:
                    local_result = local_future.result()
                    setter(item_trsf=local_result, key=key)

        if multithread_set:
            wait(local_futures)
        any_set = True in any_set or is_non_tensor(self)

        if filter_empty and not any_set:
            return
        elif not filter_empty and not inplace and is_locked:
            result.lock_()
        return result

    def _apply_nest(
        self,
        fn: Callable,
        *others: T,
        batch_size: Sequence[int] | None = None,
        device: torch.device | None = NO_DEFAULT,
        names: Sequence[str] | None = NO_DEFAULT,
        inplace: bool = False,
        checked: bool = False,
        call_on_nested: bool = False,
        default: Any = NO_DEFAULT,
        named: bool = False,
        nested_keys: bool = False,
        prefix: tuple = (),
        filter_empty: bool | None = None,
        is_leaf: Callable = None,
        out: TensorDictBase | None = None,
        **constructor_kwargs,
    ) -> T | None:
        if inplace:
            result = self
            is_locked = result.is_locked
        elif out is not None:
            result = out
            if out.is_locked:
                raise RuntimeError(_LOCK_ERROR)
            is_locked = False
            if batch_size is not None and batch_size != out.batch_size:
                raise RuntimeError(
                    "batch_size and out.batch_size must be equal when both are provided."
                )
            if device is not NO_DEFAULT and device != out.device:
                raise RuntimeError(
                    "device and out.device must be equal when both are provided."
                )
        else:

            def make_result(names=names, batch_size=batch_size):
                if names is NO_DEFAULT:
                    if batch_size is not None:
                        # erase names
                        names = None
                    else:
                        names = self.names if self._has_names() else None
                return self.empty(batch_size=batch_size, device=device, names=names)

            result = None
            is_locked = False

        any_set = False
        if is_leaf is None:
            is_leaf = _default_is_leaf

        for key, item in self.items():
            if (
                not call_on_nested
                and not is_leaf(type(item))
                # and not is_non_tensor(item)
            ):
                if default is not NO_DEFAULT:
                    _others = [_other._get_str(key, default=None) for _other in others]
                    _others = [
                        self.empty(recurse=True) if _other is None else _other
                        for _other in _others
                    ]
                else:
                    _others = [
                        _other._get_str(key, default=NO_DEFAULT) for _other in others
                    ]

                item_trsf = item._apply_nest(
                    fn,
                    *_others,
                    inplace=inplace,
                    batch_size=batch_size,
                    device=device,
                    checked=checked,
                    named=named,
                    nested_keys=nested_keys,
                    default=default,
                    prefix=prefix + (key,),
                    filter_empty=filter_empty,
                    is_leaf=is_leaf,
                    out=out._get_str(key, default=None) if out is not None else None,
                    **constructor_kwargs,
                )
            else:
                _others = [_other._get_str(key, default=default) for _other in others]
                if named:
                    if nested_keys:
                        item_trsf = fn(
                            prefix + (key,) if prefix != () else key, item, *_others
                        )
                    else:
                        item_trsf = fn(key, item, *_others)
                else:
                    item_trsf = fn(item, *_others)
            if item_trsf is not None:
                if not any_set:
                    if result is None:
                        result = make_result()
                    any_set = True
                if isinstance(self, _SubTensorDict):
                    result.set(key, item_trsf, inplace=inplace)
                else:
                    result._set_str(
                        key,
                        item_trsf,
                        inplace=BEST_ATTEMPT_INPLACE if inplace else False,
                        validated=checked,
                        non_blocking=False,
                    )

        if filter_empty and not any_set:
            return
        elif filter_empty is None and not any_set and not self.is_empty():
            # we raise the deprecation warning only if the tensordict wasn't already empty.
            # After we introduce the new behaviour, we will have to consider what happens
            # to empty tensordicts by default: will they disappear or stay?
            warn(
                "Your resulting tensordict has no leaves but you did not specify filter_empty=True. "
                "This now returns None (filter_empty=True). "
                "To silence this warning, set filter_empty to the desired value in your call to `apply`. "
                "This warning will be removed in v0.6.",
                category=DeprecationWarning,
            )
            return
        if result is None:
            result = make_result()

        if not inplace and is_locked:
            result.lock_()
        return result

    # Functorch compatibility
    @cache  # noqa: B019
    def _add_batch_dim(self, *, in_dim, vmap_level):
        td = self

        def _add_batch_dim_wrapper(key, value):
            if is_tensor_collection(value):
                return value._add_batch_dim(in_dim=in_dim, vmap_level=vmap_level)

            if isinstance(
                value, (_BatchedUninitializedParameter, _BatchedUninitializedBuffer)
            ):
                value.in_dim = in_dim
                value.vmap_level = vmap_level
                return value
            return _add_batch_dim(value, in_dim, vmap_level)

        out = TensorDict._new_unsafe(
            {key: _add_batch_dim_wrapper(key, value) for key, value in td.items()},
            batch_size=torch.Size(
                [b for i, b in enumerate(td.batch_size) if i != in_dim]
            ),
            names=(
                [name for i, name in enumerate(td.names) if i != in_dim]
                if self._has_names()
                else None
            ),
            lock=self.is_locked,
        )
        return out

    @cache  # noqa: B019
    def _remove_batch_dim(self, vmap_level, batch_size, out_dim):
        new_batch_size = list(self.batch_size)
        new_batch_size.insert(out_dim, batch_size)
        new_names = list(self.names)
        new_names.insert(out_dim, None)
        out = TensorDict(
            {
                key: (
                    value._remove_batch_dim(
                        vmap_level=vmap_level, batch_size=batch_size, out_dim=out_dim
                    )
                    if is_tensor_collection(value)
                    else _remove_batch_dim(value, vmap_level, batch_size, out_dim)
                )
                for key, value in self.items()
            },
            batch_size=new_batch_size,
            names=new_names,
            lock=self.is_locked,
        )
        return out

    def _convert_to_tensordict(self, dict_value: dict[str, Any]) -> T:
        return TensorDict(
            dict_value,
            batch_size=self.batch_size,
            device=self.device,
        )

    def _index_tensordict(
        self,
        index: IndexType,
        new_batch_size: torch.Size | None = None,
        names: List[str] | None = None,
    ) -> T:
        batch_size = self.batch_size
        batch_dims = len(batch_size)

        def _check_for_invalid_index(index):
            if batch_size:
                return
            if index is None:
                return
            if (
                isinstance(index, torch.Tensor)
                and index.dtype == torch.bool
                and not index.ndim
            ):
                return
            if isinstance(index, tuple):
                if len(index) == 1:
                    return _check_for_invalid_index(index[0])
                elif all(idx is None for idx in index):
                    return
            raise RuntimeError(
                f"indexing a tensordict with td.batch_dims==0 is not permitted. Got index {index}."
            )

        _check_for_invalid_index(index)

        if new_batch_size is not None:
            batch_size = new_batch_size
        else:
            batch_size = _getitem_batch_size(batch_size, index)

        if names is None:
            names = self._get_names_idx(index)

        source = {}
        for key, item in self.items():
            if isinstance(item, TensorDict):
                # this is the simplest case, we can pre-compute the batch size easily
                new_batch_size = batch_size + item.batch_size[batch_dims:]
                source[key] = item._index_tensordict(
                    index, new_batch_size=new_batch_size
                )
            else:
                source[key] = _get_item(item, index)
        result = TensorDict._new_unsafe(
            source=source,
            batch_size=batch_size,
            device=self.device,
            names=names,
            # lock=self.is_locked,
        )
        if self._is_memmap and _index_preserve_data_ptr(index):
            result._is_memmap = True
            result.lock_()
        elif self._is_shared and _index_preserve_data_ptr(index):
            result._is_shared = True
            result.lock_()
        return result

    def expand(self, *args, **kwargs) -> T:
        tensordict_dims = self.batch_dims
        shape = _get_shape_from_args(*args, **kwargs)

        # new shape dim check
        if len(shape) < len(self.shape):
            raise RuntimeError(
                f"the number of sizes provided ({len(shape)}) must be greater or equal to the number of "
                f"dimensions in the TensorDict ({tensordict_dims})"
            )

        # new shape compatibility check
        for old_dim, new_dim in zip(self.batch_size, shape[-tensordict_dims:]):
            if old_dim != 1 and new_dim != old_dim:
                raise RuntimeError(
                    "Incompatible expanded shape: The expanded shape length at non-singleton dimension should be same "
                    f"as the original length. target_shape = {shape}, existing_shape = {self.batch_size}"
                )

        if self._has_names():
            names = [None] * (len(shape) - tensordict_dims) + self.names
        else:
            names = None

        def _expand(tensor):
            tensor_shape = tensor.shape
            tensor_dims = len(tensor_shape)
            last_n_dims = tensor_dims - tensordict_dims
            if last_n_dims > 0:
                new_shape = (*shape, *tensor_shape[-last_n_dims:])
            else:
                new_shape = shape
            return tensor.expand(new_shape)

        return self._fast_apply(
            _expand,
            batch_size=shape,
            call_on_nested=True,
            names=names,
            propagate_lock=True,
        )

    def _unbind(self, dim: int):
        batch_size = torch.Size([s for i, s in enumerate(self.batch_size) if i != dim])
        names = None
        if self._has_names():
            names = copy(self.names)
            names = [name for i, name in enumerate(names) if i != dim]
            # We could use any() but dynamo doesn't like generators
            for name in names:
                if name is not None:
                    break
            else:
                names = None
        device = self.device

        is_shared = self._is_shared
        is_memmap = self._is_memmap

        def empty(
            batch_size=batch_size,
            names=names,
            device=device,
            is_shared=is_shared,
            is_memmap=is_memmap,
        ):
            result = TensorDict._new_unsafe(
                {}, batch_size=batch_size, names=names, device=device
            )
            result._is_shared = is_shared
            result._is_memmap = is_memmap
            return result

        tds = tuple(empty() for _ in range(self.batch_size[dim]))

        def unbind(key, val, tds=tds):
            unbound = (
                val.unbind(dim)
                if not isinstance(val, TensorDictBase)
                # tensorclass is also unbound using plain unbind
                else val._unbind(dim)
            )
            for td, _val in _zip_strict(tds, unbound):
                td._set_str(
                    key, _val, validated=True, inplace=False, non_blocking=False
                )

        for key, val in self.items():
            unbind(key, val)
        return tds

    def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]:
        # we must use slices to keep the storage of the tensors
        WRONG_TYPE = "split(): argument 'split_size' must be int or list of ints"
        batch_size = self.batch_size
        batch_sizes = []
        batch_dims = len(batch_size)
        if dim < 0:
            dim = len(batch_size) + dim
        if dim >= batch_dims or dim < 0:
            raise IndexError(
                f"Dimension out of range (expected to be in range of [-{self.batch_dims}, {self.batch_dims - 1}], but got {dim})"
            )
        max_size = batch_size[dim]
        if isinstance(split_size, int):
            idx0 = 0
            idx1 = min(max_size, split_size)
            split_sizes = [slice(idx0, idx1)]
            batch_sizes.append(
                torch.Size(
                    tuple(
                        d if i != dim else idx1 - idx0 for i, d in enumerate(batch_size)
                    )
                )
            )
            while idx1 < max_size:
                idx0 = idx1
                idx1 = min(max_size, idx1 + split_size)
                split_sizes.append(slice(idx0, idx1))
                batch_sizes.append(
                    torch.Size(
                        tuple(
                            d if i != dim else idx1 - idx0
                            for i, d in enumerate(batch_size)
                        )
                    )
                )
        elif isinstance(split_size, (list, tuple)):
            if len(split_size) == 0:
                raise RuntimeError("Insufficient number of elements in split_size.")
            try:
                idx0 = 0
                idx1 = split_size[0]
                split_sizes = [slice(idx0, idx1)]
                batch_sizes.append(
                    torch.Size(
                        tuple(
                            d if i != dim else idx1 - idx0
                            for i, d in enumerate(batch_size)
                        )
                    )
                )
                for idx in split_size[1:]:
                    idx0 = idx1
                    idx1 = min(max_size, idx1 + idx)
                    split_sizes.append(slice(idx0, idx1))
                    batch_sizes.append(
                        torch.Size(
                            tuple(
                                d if i != dim else idx1 - idx0
                                for i, d in enumerate(batch_size)
                            )
                        )
                    )
            except TypeError:
                raise TypeError(WRONG_TYPE)

            if idx1 < batch_size[dim]:
                raise RuntimeError(
                    f"Split method expects split_size to sum exactly to {self.batch_size[dim]} (tensor's size at dimension {dim}), but got split_size={split_size}"
                )
        else:
            raise TypeError(WRONG_TYPE)
        index = (slice(None),) * dim
        names = self.names if self._has_names() else None
        return tuple(
            self._index_tensordict(index + (ss,), new_batch_size=bs, names=names)
            for ss, bs in _zip_strict(split_sizes, batch_sizes)
        )

    def masked_select(self, mask: Tensor) -> T:
        d = {}
        mask_expand = mask
        while mask_expand.ndimension() > self.batch_dims:
            mndim = mask_expand.ndimension()
            mask_expand = mask_expand.squeeze(-1)
            if mndim == mask_expand.ndimension():  # no more squeeze
                break
        for key, value in self.items():
            d[key] = value[mask_expand]
        dim = int(mask.sum().item())
        other_dim = self.shape[mask.ndim :]
        return TensorDict(
            device=self.device, source=d, batch_size=torch.Size([dim, *other_dim])
        )

    def _view(
        self,
        *args,
        **kwargs,
    ) -> T:
        shape = _get_shape_from_args(*args, **kwargs)
        if any(dim < 0 for dim in shape):
            shape = infer_size_impl(shape, self.numel())
        if torch.Size(shape) == self.shape:
            return self
        batch_dims = self.batch_dims

        def _view(tensor):
            return tensor.view((*shape, *tensor.shape[batch_dims:]))

        result = self._fast_apply(
            _view, batch_size=shape, call_on_nested=True, propagate_lock=True
        )
        self._maybe_set_shared_attributes(result)
        return result

    def reshape(
        self,
        *args,
        **kwargs,
    ) -> T:
        shape = _get_shape_from_args(*args, **kwargs)
        if any(dim < 0 for dim in shape):
            shape = infer_size_impl(shape, self.numel())
            shape = torch.Size(shape)
        if torch.Size(shape) == self.shape:
            return self
        batch_dims = self.batch_dims

        def _reshape(tensor):
            return tensor.reshape((*shape, *tensor.shape[batch_dims:]))

        return self._fast_apply(
            _reshape,
            batch_size=shape,
            call_on_nested=True,
            propagate_lock=True,
        )

    def _transpose(self, dim0, dim1):
        def _transpose(tensor):
            return tensor.transpose(dim0, dim1)

        batch_size = list(self.batch_size)
        v0 = batch_size[dim0]
        v1 = batch_size[dim1]
        batch_size[dim1] = v0
        batch_size[dim0] = v1
        if self._has_names():
            names = self.names
            names = [
                names[dim0] if i == dim1 else names[dim1] if i == dim0 else names[i]
                for i in range(self.ndim)
            ]
        else:
            names = None
        result = self._fast_apply(
            _transpose,
            batch_size=torch.Size(batch_size),
            call_on_nested=True,
            names=names,
            propagate_lock=True,
        )
        self._maybe_set_shared_attributes(result)
        return result

    def _permute(self, *args, **kwargs):
        dims_list = _get_shape_from_args(*args, kwarg_name="dims", **kwargs)
        dims_list = [dim if dim >= 0 else self.ndim + dim for dim in dims_list]
        if any(dim < 0 or dim >= self.ndim for dim in dims_list):
            raise ValueError(
                "Received an permutation order incompatible with the tensordict shape."
            )
        # note: to allow this to work recursively, we must allow permutation order with fewer elements than dims,
        # as long as this list is complete.
        if not np.array_equal(sorted(dims_list), range(len(dims_list))):
            raise ValueError(
                f"Cannot compute the permutation, got dims={dims_list} but expected a permutation of {list(range(len(dims_list)))}."
            )
        if not len(dims_list) and not self.batch_dims:
            return self
        if np.array_equal(dims_list, range(len(dims_list))):
            return self

        def _permute(tensor):
            return tensor.permute(*dims_list, *range(len(dims_list), tensor.ndim))

        batch_size = self.batch_size
        batch_size = [batch_size[p] for p in dims_list] + list(
            batch_size[len(dims_list) :]
        )
        if self._has_names():
            names = self.names
            names = [names[i] for i in dims_list]
        else:
            names = None
        result = self._fast_apply(
            _permute,
            batch_size=batch_size,
            call_on_nested=True,
            names=names,
            propagate_lock=True,
        )
        self._maybe_set_shared_attributes(result)
        return result

    def _squeeze(self, dim=None):
        batch_size = self.batch_size
        if dim is None:
            names = copy(self.names) if self._has_names() else None
            if names is not None:
                batch_size, names = _zip_strict(
                    *[
                        (size, name)
                        for size, name in _zip_strict(batch_size, names)
                        if size != 1
                    ]
                )
            else:
                batch_size = [size for size in batch_size if size != 1]
            batch_size = torch.Size(batch_size)
            if batch_size == self.batch_size:
                return self

            # we only want to squeeze dimensions lower than the batch dim, and view
            # is the perfect op for this
            def _squeeze(tensor):
                return tensor.view(*batch_size, *tensor.shape[self.batch_dims :])

            return self._fast_apply(
                _squeeze,
                batch_size=batch_size,
                names=names,
                inplace=False,
                call_on_nested=True,
                propagate_lock=True,
            )
        # make the dim positive
        if dim < 0:
            newdim = self.batch_dims + dim
        else:
            newdim = dim

        if (newdim >= self.batch_dims) or (newdim < 0):
            raise RuntimeError(
                f"squeezing is allowed for dims comprised between "
                f"`-td.batch_dims` and `td.batch_dims - 1` only. Got "
                f"dim={dim} with a batch size of {self.batch_size}."
            )
        if batch_size[dim] != 1:
            return self
        batch_size = list(batch_size)
        batch_size.pop(dim)
        batch_size = list(batch_size)
        names = copy(self.names) if self._has_names() else None
        if names:
            names.pop(dim)

        result = self._fast_apply(
            lambda x: x.squeeze(newdim),
            batch_size=batch_size,
            names=names,
            inplace=False,
            call_on_nested=True,
            propagate_lock=True,
        )
        self._maybe_set_shared_attributes(result)
        return result

    def _unsqueeze(self, dim):
        # make the dim positive
        if dim < 0:
            newdim = self.batch_dims + dim + 1
        else:
            newdim = dim

        if (newdim > self.batch_dims) or (newdim < 0):
            raise RuntimeError(
                f"unsqueezing is allowed for dims comprised between "
                f"`-td.batch_dims - 1` and `td.batch_dims` only. Got "
                f"dim={dim} with a batch size of {self.batch_size}."
            )
        batch_size = list(self.batch_size)
        batch_size.insert(newdim, 1)
        batch_size = torch.Size(batch_size)

        names = copy(self.names) if self._has_names() else None
        if names:
            names.insert(newdim, None)

        def _unsqueeze(tensor):
            return tensor.unsqueeze(newdim)

        result = self._fast_apply(
            _unsqueeze,
            batch_size=batch_size,
            names=names,
            inplace=False,
            call_on_nested=True,
            propagate_lock=True,
        )
        self._maybe_set_shared_attributes(result)
        return result

    @classmethod
    def from_dict(
        cls, input_dict, batch_size=None, device=None, batch_dims=None, names=None
    ):
        if batch_dims is not None and batch_size is not None:
            raise ValueError(
                "Cannot pass both batch_size and batch_dims to `from_dict`."
            )

        batch_size_set = torch.Size(()) if batch_size is None else batch_size
        input_dict = copy(input_dict)
        for key, value in list(input_dict.items()):
            if isinstance(value, (dict,)):
                # we don't know if another tensor of smaller size is coming
                # so we can't be sure that the batch-size will still be valid later
                input_dict[key] = TensorDict.from_dict(
                    value, batch_size=[], device=device, batch_dims=None
                )
        # regular __init__ breaks because a tensor may have the same batch-size as the tensordict
        out = cls(
            input_dict,
            batch_size=batch_size_set,
            device=device,
            names=names,
        )
        if batch_size is None:
            _set_max_batch_size(out, batch_dims)
        else:
            out.batch_size = batch_size
        return out

    @classmethod
    def _from_dict_validated(
        cls, input_dict, batch_size=None, device=None, batch_dims=None, names=None
    ):
        return cls._new_unsafe(
            input_dict,
            batch_size=torch.Size(batch_size),
            device=torch.device(device) if device is not None else device,
            names=names if any(name is not None for name in names) else None,
        )

    def from_dict_instance(
        self, input_dict, batch_size=None, device=None, batch_dims=None, names=None
    ):
        if batch_dims is not None and batch_size is not None:
            raise ValueError(
                "Cannot pass both batch_size and batch_dims to `from_dict`."
            )
        from tensordict import TensorDict

        batch_size_set = torch.Size(()) if batch_size is None else batch_size
        input_dict = copy(input_dict)
        for key, value in list(input_dict.items()):
            if isinstance(value, (dict,)):
                cur_value = self.get(key, None)
                if cur_value is not None:
                    input_dict[key] = cur_value.from_dict_instance(
                        value, batch_size=[], device=device, batch_dims=None
                    )
                    continue
                # we don't know if another tensor of smaller size is coming
                # so we can't be sure that the batch-size will still be valid later
                input_dict[key] = TensorDict.from_dict(
                    value, batch_size=[], device=device, batch_dims=None
                )
        out = TensorDict.from_dict(
            input_dict,
            batch_size=batch_size_set,
            device=device,
            names=names,
        )
        if batch_size is None:
            _set_max_batch_size(out, batch_dims)
        else:
            out.batch_size = batch_size
        return out

    @staticmethod
    def _parse_batch_size(
        source: T | dict,
        batch_size: Sequence[int] | torch.Size | int | None = None,
    ) -> torch.Size:
        try:
            return torch.Size(batch_size)
        except Exception:
            if batch_size is None:
                return torch.Size([])
            elif isinstance(batch_size, Number):
                return torch.Size([batch_size])
            elif isinstance(source, TensorDictBase):
                return source.batch_size
            raise ValueError(
                "batch size was not specified when creating the TensorDict "
                "instance and it could not be retrieved from source."
            )

    @property
    def batch_dims(self) -> int:
        return len(self.batch_size)

    @batch_dims.setter
    def batch_dims(self, value: int) -> None:
        raise RuntimeError(
            f"Setting batch dims on {type(self).__name__} instances is " f"not allowed."
        )

    def _has_names(self):
        return self._td_dim_names is not None

    def _erase_names(self):
        self._td_dim_names = None

    @property
    def names(self):
        names = self._td_dim_names
        if names is None:
            return [None for _ in range(self.batch_dims)]
        # assert len(names) == self.batch_dims, (names, self.batch_dims)
        return names

    @names.setter
    def names(self, value):
        if is_dynamo_compiling():
            if value is not None:
                graph_break()
            else:
                # We have already made sure that the tensordict was not named
                return

        # we don't run checks on types for efficiency purposes
        if value is None:
            self._rename_subtds(value)
            self._erase_names()
            return
        value = list(value)
        # Faster but incompatible with dynamo
        # num_none = sum(v is None for v in value)
        num_none = 0
        for v in value:
            num_none += v is None
        if num_none == self.batch_dims:
            self.names = None
            return
        if num_none:
            num_none -= 1
        if len(set(value)) != len(value) - num_none:
            raise ValueError(f"Some dimension names are non-unique: {value}.")
        if len(value) != self.batch_dims:
            raise ValueError(
                "the length of the dimension names must equate the tensordict batch_dims attribute. "
                f"Got {value} for batch_dims {self.batch_dims}."
            )
        self._rename_subtds(value)
        self._td_dim_names = list(value)

    def _rename_subtds(self, names):
        if names is None:
            for item in self._tensordict.values():
                if _is_tensor_collection(type(item)):
                    item._erase_names()
            return
        for item in self._tensordict.values():
            if _is_tensor_collection(type(item)):
                item_names = item.names
                td_names = list(names) + item_names[len(names) :]
                item.rename_(*td_names)

    @property
    def device(self) -> torch.device | None:
        """Device of the tensordict.

        Returns `None` if device hasn't been provided in the constructor or set via `tensordict.to(device)`.

        """
        return self._device

    @device.setter
    def device(self, value: DeviceType) -> None:
        raise RuntimeError(
            "device cannot be set using tensordict.device = device, "
            "because device cannot be updated in-place. To update device, use "
            "tensordict.to(new_device), which will return a new tensordict "
            "on the new device."
        )

    @property
    def batch_size(self) -> torch.Size:
        return self._batch_size

    @batch_size.setter
    def batch_size(self, new_size: torch.Size) -> None:
        self._batch_size_setter(new_size)

    def _change_batch_size(self, new_size: torch.Size) -> None:
        self._batch_size = new_size

    # Checks
    def _check_is_shared(self) -> bool:
        share_list = [_is_shared(value) for value in self.values()]
        if any(share_list) and not all(share_list):
            shared_str = ", ".join(
                [f"{key}: {_is_shared(value)}" for key, value in self.items()]
            )
            raise RuntimeError(
                f"tensors must be either all shared or not, but mixed "
                f"features is not allowed. "
                f"Found: {shared_str}"
            )
        return all(share_list) and len(share_list) > 0

    def _check_device(self) -> None:
        devices = {value.device for value in self.values()}
        if self.device is not None and len(devices) >= 1 and devices != {self.device}:
            raise RuntimeError(
                f"TensorDict.device is {self._device}, but elements have "
                f"device values {devices}. If TensorDict.device is set then "
                "all elements must share that device."
            )

    @lock_blocked
    def popitem(self) -> Tuple[NestedKey, CompatibleType]:
        return self._tensordict.popitem()

    def _set_str(
        self,
        key: NestedKey,
        value: dict[str, CompatibleType] | CompatibleType,
        *,
        inplace: bool,
        validated: bool,
        ignore_lock: bool = False,
        non_blocking: bool = False,
    ) -> T:
        if inplace is not False:
            best_attempt = inplace is BEST_ATTEMPT_INPLACE
            inplace = self._convert_inplace(inplace, key)
        if not validated:
            value = self._validate_value(value, check_shape=True)
        if not inplace:
            if self._is_locked and not ignore_lock:
                raise RuntimeError(_LOCK_ERROR)
            self._tensordict[key] = value
        else:
            try:
                dest = self._get_str(key, default=NO_DEFAULT)
                if best_attempt and _is_tensor_collection(type(dest)):
                    dest.update(value, inplace=True, non_blocking=non_blocking)
                else:
                    if dest is not value:
                        try:
                            dest.copy_(value, non_blocking=non_blocking)
                        except RuntimeError:
                            # if we're updating a param and the storages match, nothing needs to be done
                            if not (
                                isinstance(dest, torch.Tensor)
                                and dest.data.untyped_storage().data_ptr()
                                == value.data.untyped_storage().data_ptr()
                            ):
                                raise
            except KeyError as err:
                raise err
            except Exception as err:
                raise ValueError(
                    f"Failed to update '{key}' in tensordict {self}"
                ) from err
        return self

    def _set_dict(
        self,
        d: dict[str, CompatibleType],
        *,
        validated: bool,
    ):
        if not validated:
            raise RuntimeError("Not Implemented for non-validated inputs")
        self._tensordict = d

    def _set_tuple(
        self,
        key: NestedKey,
        value: dict[str, CompatibleType] | CompatibleType,
        *,
        inplace: bool,
        validated: bool,
        non_blocking: bool = False,
    ) -> T:
        if len(key) == 1:
            return self._set_str(
                key[0],
                value,
                inplace=inplace,
                validated=validated,
                non_blocking=non_blocking,
            )
        td = self._get_str(key[0], None)
        if td is None:
            td = self._create_nested_str(key[0])
            inplace = False
        elif not _is_tensor_collection(type(td)):
            raise KeyError(
                f"The entry {key[0]} is already present in tensordict {self}."
            )
        td._set_tuple(
            key[1:],
            value,
            inplace=inplace,
            validated=validated,
            non_blocking=non_blocking,
        )
        return self

    _SHARED_INPLACE_ERROR = (
        "You're attempting to update a leaf in-place with a shared "
        "tensordict, but the new value does not match the previous. "
        "If you're using NonTensorData, see the class documentation "
        "to see how to properly pre-allocate memory in shared contexts."
    )

    def _set_at_str(self, key, value, idx, *, validated, non_blocking: bool):
        if not validated:
            value = self._validate_value(value, check_shape=False)
            validated = True
        tensor_in = self._get_str(key, NO_DEFAULT)

        if is_non_tensor(value) and not (self._is_shared or self._is_memmap):
            dest = tensor_in
            is_diff = dest[idx].tolist() != value.tolist()
            if is_diff:
                dest_val = dest.maybe_to_stack()
                dest_val[idx] = value
                if dest_val is not dest:
                    self._set_str(
                        key,
                        dest_val,
                        validated=True,
                        inplace=False,
                        ignore_lock=True,
                    )
            return

        if isinstance(idx, tuple) and len(idx) and isinstance(idx[0], tuple):
            warn(
                "Multiple indexing can lead to unexpected behaviours when "
                "setting items, for instance `td[idx1][idx2] = other` may "
                "not write to the desired location if idx1 is a list/tensor."
            )
            tensor_in = _sub_index(tensor_in, idx)
            tensor_in.copy_(value, non_blocking=non_blocking)
        else:
            tensor_out = _set_item(
                tensor_in, idx, value, validated=validated, non_blocking=non_blocking
            )
            if tensor_in is not tensor_out:
                if self._is_shared or self._is_memmap:
                    raise RuntimeError(self._SHARED_INPLACE_ERROR)
                # this happens only when a NonTensorData becomes a NonTensorStack
                # so it is legitimate (there is no in-place modification of a tensor
                # that was expected to happen but didn't).
                # For this reason we can ignore the locked attribute of the td.
                self._set_str(
                    key,
                    tensor_out,
                    validated=True,
                    inplace=False,
                    ignore_lock=True,
                    non_blocking=non_blocking,
                )

        return self

    def _set_at_tuple(self, key, value, idx, *, validated, non_blocking: bool):
        if len(key) == 1:
            return self._set_at_str(
                key[0], value, idx, validated=validated, non_blocking=non_blocking
            )
        if key[0] not in self.keys():
            # this won't work
            raise KeyError(f"key {key} not found in set_at_ with tensordict {self}.")
        else:
            td = self._get_str(key[0], NO_DEFAULT)
        td._set_at_tuple(
            key[1:], value, idx, validated=validated, non_blocking=non_blocking
        )
        return self

    @lock_blocked
    def del_(self, key: NestedKey) -> T:
        key = _unravel_key_to_tuple(key)
        if len(key) > 1:
            td, subkey = _get_leaf_tensordict(self, key)
            td.del_(subkey)
            return self

        del self._tensordict[key[0]]
        return self

    @lock_blocked
    def rename_key_(
        self, old_key: NestedKey, new_key: NestedKey, safe: bool = False
    ) -> T:
        # these checks are not perfect, tuples that are not tuples of strings or empty
        # tuples could go through but (1) it will raise an error anyway and (2)
        # those checks are expensive when repeated often.
        if old_key == new_key:
            return self
        if not isinstance(old_key, (str, tuple)):
            raise TypeError(
                f"Expected old_name to be a string or a tuple of strings but found {type(old_key)}"
            )
        if not isinstance(new_key, (str, tuple)):
            raise TypeError(
                f"Expected new_name to be a string or a tuple of strings but found {type(new_key)}"
            )
        if safe and (new_key in self.keys(include_nested=True)):
            raise KeyError(f"key {new_key} already present in TensorDict.")

        if isinstance(new_key, str):
            self._set_str(
                new_key,
                self.get(old_key),
                inplace=False,
                validated=True,
                non_blocking=False,
            )
        else:
            self._set_tuple(
                new_key,
                self.get(old_key),
                inplace=False,
                validated=True,
                non_blocking=False,
            )
        self.del_(old_key)
        return self

    def _stack_onto_(self, list_item: list[CompatibleType], dim: int) -> TensorDict:
        # if not isinstance(key, str):
        #     raise ValueError("_stack_onto_ expects string keys.")
        for key in self.keys():
            vals = [item._get_str(key, None) for item in list_item]
            if all(v is None for v in vals):
                continue
            dest = self._get_str(key, NO_DEFAULT)
            new_dest = torch.stack(
                vals,
                dim=dim,
                out=dest,
            )
            if new_dest is not dest:
                # This can happen with non-tensor data
                self._set_str(key, new_dest, inplace=False, validated=True)
        return self

    def entry_class(self, key: NestedKey) -> type:
        return type(self.get(key))

    def _stack_onto_at_(
        self,
        list_item: list[CompatibleType],
        dim: int,
        idx: IndexType,
    ) -> TensorDict:
        if not isinstance(idx, tuple):
            idx = (idx,)
        idx = convert_ellipsis_to_idx(idx, self.batch_size)
        for key in self.keys():
            vals = [td._get_str(key, NO_DEFAULT) for td in list_item]
            if all(v is None for v in vals):
                continue
            v = self._get_str(key, NO_DEFAULT)
            v_idx = v[idx]
            if v.data_ptr() != v_idx.data_ptr():
                raise IndexError(
                    f"Index {idx} is incompatible with stack(..., out=data) as the storages of the indexed tensors differ."
                )
            torch.stack(vals, dim=dim, out=v_idx)
            # raise ValueError(
            #     f"Cannot stack onto an indexed tensor with index {idx} "
            #     f"as its storage differs."
            # )
        return self

    def _get_str(self, key, default):
        first_key = key
        out = self._tensordict.get(first_key, None)
        if out is None:
            return self._default_get(first_key, default)
        return out

    def _get_tuple(self, key, default):
        first = self._get_str(key[0], default)
        if len(key) == 1 or first is default:
            return first
        try:
            return first._get_tuple(key[1:], default=default)
        except AttributeError as err:
            if "has no attribute" in str(err):
                raise ValueError(
                    f"Expected a TensorDictBase instance but got {type(first)} instead"
                    f" for key '{key[1:]}' in tensordict:\n{self}."
                )

    def share_memory_(self) -> T:
        if self.is_memmap():
            raise RuntimeError(
                "memmap and shared memory are mutually exclusive features."
            )
        if self.device is not None and self.device.type == "cuda":
            # cuda tensors are shared by default
            return self
        for value in self.values():
            if (
                isinstance(value, Tensor)
                and value.device.type == "cpu"
                or _is_tensor_collection(type(value))
            ):
                value.share_memory_()
        self._is_shared = True
        self.lock_()
        return self

    def detach_(self) -> T:
        for value in self.values():
            value.detach_()
        return self

    def _memmap_(
        self,
        *,
        prefix: str | None,
        copy_existing: bool,
        executor,
        futures,
        inplace,
        like,
        share_non_tensor,
    ) -> T:

        if prefix is not None:
            prefix = Path(prefix)
            if not prefix.exists():
                os.makedirs(prefix, exist_ok=True)
            metadata = {}
        if inplace and self._is_shared:
            raise RuntimeError(
                "memmap and shared memory are mutually exclusive features."
            )
        dest = self if inplace else self.empty(device=torch.device("cpu"))

        # We must set these attributes before memmapping because we need the metadata
        # to match the tensordict content.
        if inplace:
            self._is_memmap = True
            self._is_shared = False  # since they are mutually exclusive
            self._device = torch.device("cpu")
        else:
            dest._is_memmap = True
            dest._is_shared = False  # since they are mutually exclusive

        for key, value in self.items():
            type_value = type(value)
            if _is_tensor_collection(type_value):
                dest._tensordict[key] = value._memmap_(
                    prefix=prefix / key if prefix is not None else None,
                    copy_existing=copy_existing,
                    executor=executor,
                    futures=futures,
                    inplace=inplace,
                    like=like,
                    share_non_tensor=share_non_tensor,
                )
                if prefix is not None:
                    _update_metadata(
                        metadata=metadata, key=key, value=value, is_collection=True
                    )
                continue
            else:

                if executor is None:
                    _populate_memmap(
                        dest=dest,
                        value=value,
                        key=key,
                        copy_existing=copy_existing,
                        prefix=prefix,
                        like=like,
                    )
                else:
                    futures.append(
                        executor.submit(
                            _populate_memmap,
                            dest=dest,
                            value=value,
                            key=key,
                            copy_existing=copy_existing,
                            prefix=prefix,
                            like=like,
                        )
                    )
                if prefix is not None:
                    _update_metadata(
                        metadata=metadata, key=key, value=value, is_collection=False
                    )

        if prefix is not None:
            if executor is None:
                _save_metadata(
                    dest,
                    prefix,
                    metadata=metadata,
                )
            else:
                futures.append(executor.submit(_save_metadata, dest, prefix, metadata))
        dest._is_locked = True
        dest._memmap_prefix = prefix
        return dest

    @classmethod
    def _load_memmap(
        cls,
        prefix: str,
        metadata: dict,
        device: torch.device | None = None,
        out=None,
    ) -> T:
        if metadata["device"] == "None":
            metadata["device"] = None
        else:
            metadata["device"] = torch.device(metadata["device"])
        metadata["shape"] = torch.Size(metadata["shape"])

        if out is None:
            result = cls(
                {},
                batch_size=metadata.pop("shape"),
                device=metadata.pop("device") if device is None else device,
            )
        else:
            result = out

        paths = set()
        for key, entry_metadata in metadata.items():
            if not isinstance(entry_metadata, dict):
                # there can be other metadata
                continue
            type_value = entry_metadata.get("type", None)
            if type_value is not None:
                paths.add(key)
                continue
            dtype = entry_metadata.get("dtype", None)
            shape = entry_metadata.get("shape", None)
            if (
                not (prefix / f"{key}.memmap").exists()
                or dtype is None
                or shape is None
            ):
                # invalid dict means
                continue
            try:
                # this was absent in earlier versions of pytorch
                is_fake = torch._guards.active_fake_mode()
            except AttributeError:
                # Let's just make sure that the private function is just not gone
                if torch.__version__ > "2.3.0":
                    raise
                is_fake = False
            if (device is None or device != torch.device("meta")) and not is_fake:
                if entry_metadata.get("is_nested", False):
                    # The shape is the shape of the shape, get the shape from it
                    shape = MemoryMappedTensor.from_filename(
                        (prefix / f"{key}.memmap").with_suffix(".shape.memmap"),
                        shape=shape,
                        dtype=torch.long,
                    )
                else:
                    shape = torch.Size(shape)
                tensor = MemoryMappedTensor.from_filename(
                    dtype=_STRDTYPE2DTYPE[dtype],
                    shape=shape,
                    filename=str(prefix / f"{key}.memmap"),
                )
                if device is not None:
                    tensor = tensor.to(device, non_blocking=True)
            else:
                tensor = torch.zeros(
                    torch.Size(shape),
                    device=device,
                    dtype=_STRDTYPE2DTYPE[dtype],
                )
            result._set_str(
                key,
                tensor,
                validated=True,
                inplace=False,
                non_blocking=False,
            )
        # iterate over folders and load them
        for path in prefix.iterdir():
            if path.is_dir() and path.parts[-1] in paths:
                key = path.parts[-1]  # path.parts[len(prefix.parts) :]
                existing_elt = result._get_str(key, default=None)
                if existing_elt is not None:
                    existing_elt.load_memmap_(path)
                else:
                    result._set_str(
                        key,
                        TensorDict.load_memmap(path, device=device, non_blocking=True),
                        inplace=False,
                        validated=False,
                    )
        result._memmap_prefix = prefix
        return result

    def _make_memmap_subtd(self, key):
        """Creates a sub-tensordict given a tuple key."""
        result = self
        for key_str in key:
            result_tmp = result._get_str(key_str, default=None)
            if result_tmp is None:
                result_tmp = result.empty()
                if result._memmap_prefix is not None:
                    result_tmp.memmap_(prefix=result._memmap_prefix / key_str)
                    metadata = _load_metadata(result._memmap_prefix)
                    _update_metadata(
                        metadata=metadata,
                        key=key_str,
                        value=result_tmp,
                        is_collection=True,
                    )
                    _save_metadata(
                        result, prefix=result._memmap_prefix, metadata=metadata
                    )
                result._tensordict[key_str] = result_tmp
            result = result_tmp
        return result

    def make_memmap(
        self,
        key: NestedKey,
        shape: torch.Size | torch.Tensor,
        *,
        dtype: torch.dtype | None = None,
    ) -> MemoryMappedTensor:
        if not self.is_memmap():
            raise RuntimeError(
                "Can only make a memmap tensor within a memory-mapped tensordict."
            )

        key = unravel_key(key)
        if isinstance(key, tuple):
            last_node = self._make_memmap_subtd(key[:-1])
            last_key = key[-1]
        else:
            last_node = self
            last_key = key
        if last_key in last_node.keys():
            raise RuntimeError(
                f"The key {last_key} already exists within the target tensordict. Delete that entry before "
                f"overwriting it."
            )
        if dtype is None:
            dtype = torch.get_default_dtype()
        if last_node._memmap_prefix is not None:
            metadata = _load_metadata(last_node._memmap_prefix)
            memmap_tensor = _populate_empty(
                key=last_key,
                dest=last_node,
                prefix=last_node._memmap_prefix,
                shape=shape,
                dtype=dtype,
            )
            _update_metadata(
                metadata=metadata,
                key=last_key,
                value=memmap_tensor,
                is_collection=False,
            )
            _save_metadata(
                last_node, prefix=last_node._memmap_prefix, metadata=metadata
            )
        else:
            memmap_tensor = MemoryMappedTensor.empty(shape=shape, dtype=dtype)

        last_node._set_str(
            last_key, memmap_tensor, validated=False, inplace=False, ignore_lock=True
        )

        return memmap_tensor

    def make_memmap_from_storage(
        self,
        key: NestedKey,
        storage: torch.UntypedStorage,
        shape: torch.Size | torch.Tensor,
        *,
        dtype: torch.dtype | None = None,
    ) -> MemoryMappedTensor:
        if not self.is_memmap():
            raise RuntimeError(
                "Can only make a memmap tensor within a memory-mapped tensordict."
            )

        key = unravel_key(key)
        if isinstance(key, tuple):
            last_node = self._make_memmap_subtd(key[:-1])
            last_key = key[-1]
        else:
            last_node = self
            last_key = key
        if last_key in last_node.keys():
            raise RuntimeError(
                f"The key {last_key} already exists within the target tensordict. Delete that entry before "
                f"overwriting it."
            )
        if dtype is None:
            dtype = torch.get_default_dtype()

        if last_node._memmap_prefix is not None:
            metadata = _load_metadata(last_node._memmap_prefix)
            memmap_tensor = _populate_storage(
                key=last_key,
                dest=last_node,
                prefix=last_node._memmap_prefix,
                storage=storage,
                shape=shape,
                dtype=dtype,
            )
            _update_metadata(
                metadata=metadata,
                key=last_key,
                value=memmap_tensor,
                is_collection=False,
            )
            _save_metadata(
                last_node, prefix=last_node._memmap_prefix, metadata=metadata
            )
        else:
            memmap_tensor = MemoryMappedTensor.from_storage(
                storage=storage, shape=shape, dtype=dtype
            )

        last_node._set_str(
            last_key, memmap_tensor, validated=False, inplace=False, ignore_lock=True
        )

        return memmap_tensor

    def make_memmap_from_tensor(
        self, key: NestedKey, tensor: torch.Tensor, *, copy_data: bool = True
    ) -> MemoryMappedTensor:
        if not self.is_memmap():
            raise RuntimeError(
                "Can only make a memmap tensor within a memory-mapped tensordict."
            )

        key = unravel_key(key)
        if isinstance(key, tuple):
            last_node = self._make_memmap_subtd(key[:-1])
            last_key = key[-1]
        else:
            last_node = self
            last_key = key
        if last_key in last_node.keys():
            raise RuntimeError(
                f"The key {last_key} already exists within the target tensordict. Delete that entry before "
                f"overwriting it."
            )

        if last_node._memmap_prefix is not None:
            metadata = _load_metadata(last_node._memmap_prefix)
            memmap_tensor = _populate_memmap(
                dest=last_node,
                value=tensor,
                key=last_key,
                copy_existing=True,
                prefix=last_node._memmap_prefix,
                like=not copy_data,
            )
            _update_metadata(
                metadata=metadata,
                key=last_key,
                value=memmap_tensor,
                is_collection=False,
            )
            _save_metadata(
                last_node, prefix=last_node._memmap_prefix, metadata=metadata
            )
        else:
            memmap_tensor = MemoryMappedTensor.from_tensor(tensor)

        last_node._set_str(
            last_key, memmap_tensor, validated=False, inplace=False, ignore_lock=True
        )

        return memmap_tensor

    def where(self, condition, other, *, out=None, pad=None):
        if _is_tensor_collection(type(other)):

            def func(tensor, _other, key):
                if tensor is None:
                    if pad is not None:
                        tensor = _other
                        _other = torch.tensor(pad, dtype=_other.dtype)
                    else:
                        raise KeyError(
                            f"Key {key} not found and no pad value provided."
                        )
                    cond = expand_as_right(~condition, tensor)
                elif _other is None:
                    if pad is not None:
                        _other = torch.tensor(pad, dtype=tensor.dtype)
                    else:
                        raise KeyError(
                            f"Key {key} not found and no pad value provided."
                        )
                    cond = expand_as_right(condition, tensor)
                else:
                    cond = expand_as_right(condition, tensor)
                return torch.where(
                    condition=cond,
                    input=tensor,
                    other=_other,
                )

            result = self.empty() if out is None else out
            other_keys = set(other.keys())
            # we turn into a list because out could be = to self!
            for key in list(self.keys()):
                tensor = self._get_str(key, default=NO_DEFAULT)
                _other = other._get_str(key, default=None)
                if _is_tensor_collection(type(tensor)):
                    _out = None if out is None else out._get_str(key, None)
                    if _other is None:
                        _other = tensor.empty()
                    val = tensor.where(
                        condition=condition, other=_other, out=_out, pad=pad
                    )
                else:
                    val = func(tensor, _other, key)
                result._set_str(
                    key, val, inplace=False, validated=True, non_blocking=False
                )
                other_keys.discard(key)
            for key in other_keys:
                tensor = None
                _other = other._get_str(key, default=NO_DEFAULT)
                if _is_tensor_collection(type(_other)):
                    try:
                        tensor = _other.empty()
                    except NotImplementedError:
                        # H5 tensordicts do not support select()
                        tensor = _other.to_tensordict().empty()
                    val = _other.where(
                        condition=~condition, other=tensor, out=None, pad=pad
                    )
                else:
                    val = func(tensor, _other, key)
                result._set_str(
                    key, val, inplace=False, validated=True, non_blocking=False
                )
            return result
        else:
            if out is None:

                def func(tensor):
                    return torch.where(
                        condition=expand_as_right(condition, tensor),
                        input=tensor,
                        other=other,
                    )

                return self._fast_apply(func, propagate_lock=True)
            else:

                def func(tensor, _out):
                    return torch.where(
                        condition=expand_as_right(condition, tensor),
                        input=tensor,
                        other=other,
                        out=_out,
                    )

                return self._fast_apply(func, out, propagate_lock=True)

    def masked_fill_(self, mask: Tensor, value: float | int | bool) -> T:
        for item in self.values():
            mask_expand = expand_as_right(mask, item)
            item.masked_fill_(mask_expand, value)
        return self

    def masked_fill(self, mask: Tensor, value: float | bool) -> T:
        td_copy = self.clone()
        return td_copy.masked_fill_(mask, value)

    def is_contiguous(self) -> bool:
        return all([value.is_contiguous() for _, value in self.items()])

    def _clone(self, recurse: bool = True) -> T:
        result = TensorDict._new_unsafe(
            source={key: _clone_value(value, recurse) for key, value in self.items()},
            batch_size=self.batch_size,
            device=self.device,
            names=copy(self._td_dim_names) if self._has_names() else None,
        )
        # If this is uncommented, a shallow copy of a shared/memmap will be shared and locked too
        # This may be undesirable, not sure if this should be the default behaviour
        # (one usually does a copy to modify it).
        # if not recurse:
        #     self._maybe_set_shared_attributes(result)
        return result

    def contiguous(self) -> T:
        source = {key: value.contiguous() for key, value in self.items()}
        batch_size = self.batch_size
        device = self.device
        out = TensorDict._new_unsafe(
            source=source,
            batch_size=batch_size,
            device=device,
            names=self.names if self._has_names() else None,
        )
        return out

    def empty(
        self, recurse=False, *, batch_size=None, device=NO_DEFAULT, names=NO_DEFAULT
    ) -> T:
        if not recurse:
            return TensorDict._new_unsafe(
                device=self._device if device is NO_DEFAULT else device,
                batch_size=(
                    self._batch_size if batch_size is None else torch.Size(batch_size)
                ),
                source={},
                names=(
                    (self.names if self._has_names() else None)
                    if names is NO_DEFAULT
                    else names
                ),
            )
        return super().empty(recurse=recurse)

    def _select(
        self,
        *keys: NestedKey,
        inplace: bool = False,
        strict: bool = True,
        set_shared: bool = True,
    ) -> T:
        if inplace and self.is_locked:
            raise RuntimeError(_LOCK_ERROR)

        source = {}
        if len(keys):
            keys_to_select = None
            for key in keys:
                if isinstance(key, str):
                    subkey = []
                else:
                    key, subkey = key[0], key[1:]

                val = self._get_str(key, default=None if not strict else NO_DEFAULT)
                if val is None:
                    continue
                source[key] = val
                if len(subkey):
                    if keys_to_select is None:
                        # delay creation of defaultdict
                        keys_to_select = defaultdict(list)
                    keys_to_select[key].append(subkey)

            if keys_to_select is not None:
                for key, val in keys_to_select.items():
                    source[key] = source[key]._select(
                        *val, strict=strict, inplace=inplace, set_shared=set_shared
                    )

        result = TensorDict._new_unsafe(
            device=self.device,
            batch_size=self.batch_size,
            source=source,
            # names=self.names if self._has_names() else None,
            names=self._td_dim_names,
        )
        if inplace:
            self._tensordict = result._tensordict
            return self
        # If this is uncommented, a shallow copy of a shared/memmap will be shared and locked too
        # This may be undesirable, not sure if this should be the default behaviour
        # (one usually does a copy to modify it).
        # if set_shared:
        #     self._maybe_set_shared_attributes(result)
        return result

    def _exclude(
        self, *keys: NestedKey, inplace: bool = False, set_shared: bool = True
    ) -> T:
        # faster than Base.exclude
        if not len(keys):
            return self.copy() if not inplace else self
        if not inplace:
            _tensordict = copy(self._tensordict)
        else:
            _tensordict = self._tensordict
        keys_to_exclude = None
        for key in keys:
            key = unravel_key(key)
            if isinstance(key, str):
                _tensordict.pop(key, None)
            else:
                if keys_to_exclude is None:
                    # delay creation of defaultdict
                    keys_to_exclude = defaultdict(list)
                if key[0] in self._tensordict:
                    keys_to_exclude[key[0]].append(key[1:])
        if keys_to_exclude is not None:
            for key, cur_keys in keys_to_exclude.items():
                val = _tensordict.get(key, None)
                if val is not None:
                    val = val._exclude(
                        *cur_keys, inplace=inplace, set_shared=set_shared
                    )
                    if not inplace:
                        _tensordict[key] = val
        if inplace:
            return self
        result = TensorDict._new_unsafe(
            _tensordict,
            batch_size=self.batch_size,
            device=self.device,
            names=self.names if self._has_names() else None,
        )
        # If this is uncommented, a shallow copy of a shared/memmap will be shared and locked too
        # This may be undesirable, not sure if this should be the default behaviour
        # (one usually does a copy to modify it).
        # if set_shared:
        #     self._maybe_set_shared_attributes(result)
        return result

    def keys(
        self,
        include_nested: bool = False,
        leaves_only: bool = False,
        is_leaf: Callable[[Type], bool] | None = None,
    ) -> _TensorDictKeysView:
        if not include_nested and not leaves_only and is_leaf is None:
            return _StringKeys(self._tensordict.keys())
        else:
            return self._nested_keys(
                include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
            )

    @cache  # noqa: B019
    def _nested_keys(
        self,
        include_nested: bool = False,
        leaves_only: bool = False,
        is_leaf: Callable[[Type], bool] | None = None,
    ) -> _TensorDictKeysView:
        return _TensorDictKeysView(
            self,
            include_nested=include_nested,
            leaves_only=leaves_only,
            is_leaf=is_leaf,
        )

    # some custom methods for efficiency
    def items(
        self,
        include_nested: bool = False,
        leaves_only: bool = False,
        is_leaf: Callable[[Type], bool] | None = None,
    ) -> Iterator[tuple[str, CompatibleType]]:
        if not include_nested and not leaves_only:
            return self._tensordict.items()
        elif include_nested and leaves_only:
            is_leaf = _default_is_leaf if is_leaf is None else is_leaf
            result = []
            if is_dynamo_compiling():

                def fast_iter():
                    for key, val in self._tensordict.items():
                        if not is_leaf(type(val)):
                            for _key, _val in val.items(
                                include_nested=include_nested,
                                leaves_only=leaves_only,
                                is_leaf=is_leaf,
                            ):
                                result.append(
                                    (
                                        (
                                            key,
                                            *(
                                                (_key,)
                                                if isinstance(_key, str)
                                                else _key
                                            ),
                                        ),
                                        _val,
                                    )
                                )
                        else:
                            result.append((key, val))
                    return result

            else:
                # dynamo doesn't like generators
                def fast_iter():
                    for key, val in self._tensordict.items():
                        if not is_leaf(type(val)):
                            yield from (
                                (
                                    (
                                        key,
                                        *((_key,) if isinstance(_key, str) else _key),
                                    ),
                                    _val,
                                )
                                for _key, _val in val.items(
                                    include_nested=include_nested,
                                    leaves_only=leaves_only,
                                    is_leaf=is_leaf,
                                )
                            )
                        else:
                            yield (key, val)

            return fast_iter()
        else:
            return super().items(
                include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
            )

    def values(
        self,
        include_nested: bool = False,
        leaves_only: bool = False,
        is_leaf: Callable[[Type], bool] | None = None,
    ) -> Iterator[tuple[str, CompatibleType]]:
        if not include_nested and not leaves_only:
            return self._tensordict.values()
        else:
            return TensorDictBase.values(
                self,
                include_nested=include_nested,
                leaves_only=leaves_only,
                is_leaf=is_leaf,
            )


class _SubTensorDict(TensorDictBase):
    """A TensorDict that only sees an index of the stored tensors."""

    _lazy = True
    _inplace_set = True
    _safe = False

    def __init__(
        self,
        source: T,
        idx: IndexType,
        batch_size: Sequence[int] | None = None,
    ) -> None:
        if not _is_tensor_collection(type(source)):
            raise TypeError(
                f"Expected source to be a subclass of TensorDictBase, "
                f"got {type(source)}"
            )
        self._source = source
        idx = (
            (idx,)
            if not isinstance(
                idx,
                (
                    tuple,
                    list,
                ),
            )
            else tuple(idx)
        )
        if any(item is Ellipsis for item in idx):
            idx = convert_ellipsis_to_idx(idx, self._source.batch_size)
        self._batch_size = _getitem_batch_size(self._source.batch_size, idx)
        self.idx = idx

        if batch_size is not None and batch_size != self.batch_size:
            raise RuntimeError("batch_size does not match self.batch_size.")

    # These attributes should never be set
    @property
    def _is_shared(self):
        return self._source._is_shared

    @property
    def _is_memmap(self):
        return self._source._is_memmap

    @staticmethod
    def _convert_ellipsis(idx, shape):
        if any(_idx is Ellipsis for _idx in idx):
            new_idx = []
            cursor = -1
            for _idx in idx:
                if _idx is Ellipsis:
                    if cursor == len(idx) - 1:
                        # then we can just skip
                        continue
                    n_upcoming = len(idx) - cursor - 1
                    while cursor < len(shape) - n_upcoming:
                        cursor += 1
                        new_idx.append(slice(None))
                else:
                    new_idx.append(_idx)
            return tuple(new_idx)
        return idx

    @property
    def batch_size(self) -> torch.Size:
        return self._batch_size

    @batch_size.setter
    def batch_size(self, new_size: torch.Size) -> None:
        self._batch_size_setter(new_size)

    @property
    def names(self):
        names = self._source._get_names_idx(self.idx)
        if names is None:
            return [None] * self.batch_dims
        return names

    @names.setter
    def names(self, value):
        raise RuntimeError(
            "Names of a subtensordict cannot be modified. Instantiate it as a TensorDict first."
        )

    def _has_names(self):
        return self._source._has_names()

    def _erase_names(self):
        raise RuntimeError(
            "Cannot erase names of a _SubTensorDict. Erase source TensorDict's names instead."
        )

    def _rename_subtds(self, names):
        for key in self.keys():
            if _is_tensor_collection(self.entry_class(key)):
                raise RuntimeError("Cannot rename nested sub-tensordict dimensions.")

    @property
    def device(self) -> None | torch.device:
        return self._source.device

    @device.setter
    def device(self, value: DeviceType) -> None:
        self._source.device = value

    def _preallocate(self, key: NestedKey, value: CompatibleType) -> T:
        return self._source.set(key, value)

    def _convert_inplace(self, inplace, key):
        has_key = key in self.keys()
        if inplace is not False:
            if inplace is True and not has_key:  # inplace could be None
                raise KeyError(
                    _KEY_ERROR.format(key, type(self).__name__, sorted(self.keys()))
                )
            inplace = has_key
        if not inplace and has_key:
            raise RuntimeError(
                "Calling `_SubTensorDict.set(key, value, inplace=False)` is "
                "prohibited for existing tensors. Consider calling "
                "_SubTensorDict.set_(...) or cloning your tensordict first."
            )
        elif not inplace and self.is_locked:
            raise RuntimeError(_LOCK_ERROR)
        return inplace

    from_dict_instance = TensorDict.from_dict_instance

    def _set_str(
        self,
        key: NestedKey,
        value: dict[str, CompatibleType] | CompatibleType,
        *,
        inplace: bool,
        validated: bool,
        ignore_lock: bool = False,
        non_blocking: bool = False,
    ) -> T:
        inplace = self._convert_inplace(inplace, key)
        # it is assumed that if inplace=False then the key doesn't exist. This is
        # checked in set method, but not here. responsibility lies with the caller
        # so that this method can have minimal overhead from runtime checks
        parent = self._source
        if not validated:
            value = self._validate_value(value, check_shape=True)
            validated = True
        if not inplace:
            if _is_tensor_collection(type(value)):
                # value has the shape of subtd[idx], so we want an expanded
                #  version value_expand such that value_expand[idx] has the
                #  shape of value
                value_expand = _expand_to_match_shape(
                    parent.batch_size,
                    value,
                    self.batch_dims,
                    self.device,
                    index=self.idx,
                )
                for _key, _tensor in value.items():
                    value_expand._set_str(
                        _key,
                        _expand_to_match_shape(
                            parent.batch_size,
                            _tensor,
                            self.batch_dims,
                            self.device,
                            index=self.idx,
                        ),
                        inplace=inplace,
                        validated=validated,
                        ignore_lock=ignore_lock,
                        non_blocking=non_blocking,
                    )
            else:
                value_expand = torch.zeros(
                    (
                        *parent.batch_size,
                        *_shape(value)[self.batch_dims :],
                    ),
                    dtype=value.dtype,
                    device=self.device,
                )
                if self._is_shared:
                    value_expand.share_memory_()
                elif self._is_memmap:
                    value_expand = MemoryMappedTensor.from_tensor(value_expand)
            parent._set_str(
                key,
                value_expand,
                inplace=False,
                validated=validated,
                ignore_lock=ignore_lock,
                non_blocking=non_blocking,
            )

        parent._set_at_str(
            key, value, self.idx, validated=validated, non_blocking=non_blocking
        )
        return self

    def _set_tuple(
        self,
        key: NestedKey,
        value: dict[str, CompatibleType] | CompatibleType,
        *,
        inplace: bool,
        validated: bool,
        non_blocking: bool = False,
    ) -> T:
        if len(key) == 1:
            return self._set_str(
                key[0],
                value,
                inplace=inplace,
                validated=validated,
                non_blocking=non_blocking,
            )
        parent = self._source
        td = parent._get_str(key[0], None)
        if td is None:
            td = parent.select()
            parent._set_str(
                key[0], td, inplace=False, validated=True, non_blocking=non_blocking
            )
        _SubTensorDict(td, self.idx)._set_tuple(
            key[1:],
            value,
            inplace=inplace,
            validated=validated,
            non_blocking=non_blocking,
        )
        return self

    def _set_at_str(self, key, value, idx, *, validated, non_blocking: bool):
        tensor_in = self._get_str(key, NO_DEFAULT)
        if not validated:
            value = self._validate_value(value, check_shape=False)
            validated = True
        if isinstance(idx, tuple) and len(idx) and isinstance(idx[0], tuple):
            warn(
                "Multiple indexing can lead to unexpected behaviours when "
                "setting items, for instance `td[idx1][idx2] = other` may "
                "not write to the desired location if idx1 is a list/tensor."
            )
            tensor_in = _sub_index(tensor_in, idx)
            tensor_in.copy_(value)
            tensor_out = tensor_in
        else:
            tensor_out = _set_item(
                tensor_in, idx, value, validated=validated, non_blocking=non_blocking
            )
        # make sure that the value is updated
        self._source._set_at_str(
            key, tensor_out, self.idx, validated=validated, non_blocking=non_blocking
        )
        return self

    def _set_at_tuple(self, key, value, idx, *, validated, non_blocking: bool):
        if len(key) == 1:
            return self._set_at_str(
                key[0], value, idx, validated=validated, non_blocking=non_blocking
            )
        if key[0] not in self.keys():
            # this won't work
            raise KeyError(f"key {key} not found in set_at_ with tensordict {self}.")
        else:
            td = self._get_str(key[0], NO_DEFAULT)
        td._set_at_tuple(
            key[1:], value, idx, validated=validated, non_blocking=non_blocking
        )
        return self

    # @cache  # noqa: B019
    def keys(
        self,
        include_nested: bool = False,
        leaves_only: bool = False,
        is_leaf: Callable[[Type], bool] | None = None,
    ) -> _TensorDictKeysView:
        return self._source.keys(
            include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
        )

    def entry_class(self, key: NestedKey) -> type:
        source_type = type(self._source.get(key))
        if _is_tensor_collection(source_type):
            return type(self)
        return source_type

    def _stack_onto_(self, list_item: list[CompatibleType], dim: int) -> _SubTensorDict:
        self._source._stack_onto_at_(list_item, dim=dim, idx=self.idx)
        return self

    def to(self, *args, **kwargs: Any) -> T:
        (
            device,
            dtype,
            non_blocking,
            convert_to_format,
            batch_size,
            pin_memory,
            num_threads,
        ) = _parse_to(*args, **kwargs)
        result = self

        if device is not None and dtype is None and device == self.device:
            return result
        return self.to_tensordict().to(*args, **kwargs)

    def _change_batch_size(self, new_size: torch.Size) -> None:
        self._batch_size = new_size

    def get(
        self,
        key: NestedKey,
        default: Tensor | str | None = NO_DEFAULT,
    ) -> CompatibleType:
        return self._source.get_at(key, self.idx, default=default)

    def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT):
        out = super()._get_non_tensor(key, default=default)

        if isinstance(out, _SubTensorDict) and is_non_tensor(out._source):
            return out._source
        return out

    def _get_str(self, key, default):
        if key in self.keys() and _is_tensor_collection(self.entry_class(key)):
            data = self._source._get_str(key, NO_DEFAULT)
            if is_non_tensor(data):
                return data[self.idx]
            return _SubTensorDict(data, self.idx)
        return self._source._get_at_str(key, self.idx, default=default)

    def _get_tuple(self, key, default):
        return self._source._get_at_tuple(key, self.idx, default=default)

    def update(
        self,
        input_dict_or_td: dict[str, CompatibleType] | TensorDictBase,
        clone: bool = False,
        inplace: bool = False,
        *,
        non_blocking: bool = False,
        keys_to_update: Sequence[NestedKey] | None = None,
        is_leaf: Callable[[Type], bool] | None = None,
        **kwargs,
    ) -> _SubTensorDict:
        if input_dict_or_td is self:
            # no op
            return self
        if is_leaf is None:
            is_leaf = _is_leaf_nontensor

        if getattr(self._source, "_has_exclusive_keys", False):
            raise RuntimeError(
                "Cannot use _SubTensorDict.update with a LazyStackedTensorDict that has exclusive keys."
            )
        if keys_to_update is not None:
            if len(keys_to_update) == 0:
                return self
            keys_to_update = unravel_key_list(keys_to_update)
        keys = set(self.keys(False))
        for key, value in input_dict_or_td.items():
            key = _unravel_key_to_tuple(key)
            firstkey, subkey = key[0], key[1:]
            if keys_to_update and not any(
                firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0]
                for ktu in keys_to_update
            ):
                continue
            if clone and hasattr(value, "clone"):
                value = value.clone()
            elif clone:
                value = tree_map(torch.clone, value)
            # the key must be a string by now. Let's check if it is present
            if firstkey in keys:
                target_class = self.entry_class(firstkey)
                if _is_tensor_collection(target_class):
                    target = self._source.get(firstkey)._get_sub_tensordict(self.idx)
                    if len(subkey):
                        sub_keys_to_update = _prune_selected_keys(
                            keys_to_update, firstkey
                        )
                        target.update(
                            {subkey: value},
                            inplace=False,
                            keys_to_update=sub_keys_to_update,
                            non_blocking=non_blocking,
                            is_leaf=is_leaf,
                        )
                        continue
                    elif isinstance(value, dict) or _is_tensor_collection(type(value)):
                        sub_keys_to_update = _prune_selected_keys(
                            keys_to_update, firstkey
                        )
                        target.update(
                            value,
                            keys_to_update=sub_keys_to_update,
                            non_blocking=non_blocking,
                        )
                        continue
                    raise ValueError(
                        f"Tried to replace a tensordict with an incompatible object of type {type(value)}"
                    )
                else:
                    self._set_tuple(
                        key,
                        value,
                        inplace=True,
                        validated=False,
                        non_blocking=non_blocking,
                    )
            else:
                self._set_tuple(
                    key,
                    value,
                    inplace=BEST_ATTEMPT_INPLACE if inplace else False,
                    validated=False,
                    non_blocking=non_blocking,
                )
        return self

    def update_(
        self,
        input_dict: dict[str, CompatibleType] | TensorDictBase,
        clone: bool = False,
        *,
        non_blocking: bool = False,
        keys_to_update: Sequence[NestedKey] | None = None,
    ) -> _SubTensorDict:
        return self.update_at_(
            input_dict,
            idx=self.idx,
            discard_idx_attr=True,
            clone=clone,
            keys_to_update=keys_to_update,
            non_blocking=non_blocking,
        )

    def update_at_(
        self,
        input_dict: dict[str, CompatibleType] | TensorDictBase,
        idx: IndexType,
        *,
        discard_idx_attr: bool = False,
        clone: bool = False,
        non_blocking: bool = False,
        keys_to_update: Sequence[NestedKey] | None = None,
    ) -> _SubTensorDict:
        if keys_to_update is not None:
            if len(keys_to_update) == 0:
                return self
            keys_to_update = unravel_key_list(keys_to_update)
        for key, value in input_dict.items():
            key = _unravel_key_to_tuple(key)
            firstkey, _ = key[0], key[1:]
            if keys_to_update and not any(
                firstkey == ktu if isinstance(ktu, str) else firstkey == ktu[0]
                for ktu in keys_to_update
            ):
                continue
            if not isinstance(value, tuple(_ACCEPTED_CLASSES)):
                raise TypeError(
                    f"Expected value to be one of types {_ACCEPTED_CLASSES} "
                    f"but got {type(value)}"
                )
            if clone:
                value = value.clone()
            if discard_idx_attr:
                self._source._set_at_tuple(
                    key,
                    value,
                    idx,
                    non_blocking=non_blocking,
                    validated=False,
                )
            else:
                self._set_at_tuple(
                    key, value, idx, validated=False, non_blocking=non_blocking
                )
        return self

    def get_parent_tensordict(self) -> T:
        if not isinstance(self._source, TensorDictBase):
            raise TypeError(
                f"_SubTensorDict was initialized with a source of type"
                f" {type(self._source).__name__}, "
                "parent tensordict not accessible"
            )
        return self._source

    @lock_blocked
    def del_(self, key: NestedKey) -> T:
        self._source = self._source.del_(key)
        return self

    @lock_blocked
    def popitem(self) -> Tuple[NestedKey, CompatibleType]:
        raise NotImplementedError(
            f"popitem not implemented for class {type(self).__name__}."
        )

    def _clone(self, recurse: bool = True) -> _SubTensorDict:
        """Clones the _SubTensorDict.

        Args:
            recurse (bool, optional): if ``True`` (default), a regular
                :class:`~.tensordict.TensorDict` instance will be created from the :class:`~.tensordict._SubTensorDict`.
                Otherwise, another :class:`~.tensordict._SubTensorDict` with identical content
                will be returned.

        Examples:
            >>> data = TensorDict({"a": torch.arange(4).reshape(2, 2,)}, batch_size=[2, 2])
            >>> sub_data = data._get_sub_tensordict([0,])
            >>> print(sub_data)
            _SubTensorDict(
                fields={
                    a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
                batch_size=torch.Size([2]),
                device=None,
                is_shared=False)
            >>> # the data of both subtensordict is the same
            >>> print(data.get("a").data_ptr(), sub_data.get("a").data_ptr())
            140183705558208 140183705558208
            >>> sub_data_clone = sub_data.clone(recurse=True)
            >>> print(sub_data_clone)
            TensorDict(
                fields={
                    a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
                batch_size=torch.Size([2]),
                device=None,
                is_shared=False)
            >>. print(sub_data.get("a").data_ptr())
            140183705558208
            >>> sub_data_clone = sub_data.clone(recurse=False)
            >>> print(sub_data_clone)
            _SubTensorDict(
                fields={
                    a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
                batch_size=torch.Size([2]),
                device=None,
                is_shared=False)
            >>> print(sub_data.get("a").data_ptr())
            140183705558208
        """
        if not recurse:
            return _SubTensorDict(
                source=self._source._clone(recurse=False), idx=self.idx
            )
        return self.to_tensordict()

    def is_contiguous(self) -> bool:
        return all(value.is_contiguous() for value in self.values())

    def contiguous(self) -> T:
        return TensorDict._new_unsafe(
            batch_size=self.batch_size,
            source={key: value.contiguous() for key, value in self.items()},
            device=self.device,
            names=self.names if self._has_names() else None,
        )

    def _select(
        self,
        *keys: NestedKey,
        inplace: bool = False,
        strict: bool = True,
        set_shared: bool = True,
    ) -> T:
        if inplace:
            raise RuntimeError("Cannot call select inplace on a lazy tensordict.")
        return self.to_tensordict()._select(
            *keys, inplace=False, strict=strict, set_shared=set_shared
        )

    def _exclude(
        self, *keys: NestedKey, inplace: bool = False, set_shared: bool = True
    ) -> T:
        if inplace:
            raise RuntimeError("Cannot call exclude inplace on a lazy tensordict.")
        return self.to_tensordict()._exclude(
            *keys, inplace=False, set_shared=set_shared
        )

    def expand(self, *args: int, inplace: bool = False) -> T:
        if len(args) == 1 and isinstance(args[0], Sequence):
            shape = tuple(args[0])
        else:
            shape = args
        return self._fast_apply(
            lambda x: x.expand((*shape, *x.shape[self.ndim :])),
            batch_size=shape,
            propagate_lock=True,
        )

    @classmethod
    def from_dict(
        cls, input_dict, batch_size=None, device=None, batch_dims=None, names=None
    ):
        raise NotImplementedError(f"from_dict not implemented for {cls.__name__}.")

    def is_shared(self) -> bool:
        return self._source.is_shared()

    def is_memmap(self) -> bool:
        return self._source.is_memmap()

    def rename_key_(
        self, old_key: NestedKey, new_key: NestedKey, safe: bool = False
    ) -> _SubTensorDict:
        self._source.rename_key_(old_key, new_key, safe=safe)
        return self

    def pin_memory(self, *args, **kwargs) -> T:
        raise RuntimeError(
            f"Cannot pin memory of a {type(self).__name__}. Call to_tensordict() before making this call."
        )

    def detach_(self) -> T:
        raise RuntimeError("Detaching a sub-tensordict in-place cannot be done.")

    def where(self, condition, other, *, out=None, pad=None):
        return self.to_tensordict().where(
            condition=condition, other=other, out=out, pad=pad
        )

    def masked_fill_(self, mask: Tensor, value: float | bool) -> T:
        for key, item in self.items():
            self.set_(key, torch.full_like(item, value))
        return self

    def masked_fill(self, mask: Tensor, value: float | bool) -> T:
        td_copy = self.clone()
        return td_copy.masked_fill_(mask, value)

    def memmap_(
        self,
        prefix: str | None = None,
        copy_existing: bool = False,
        num_threads: int = 0,
    ) -> T:
        raise RuntimeError(
            "Converting a sub-tensordict values to memmap cannot be done."
        )

    def _memmap_(
        self,
        *,
        prefix: str | None,
        copy_existing: bool,
        executor,
        futures,
        inplace,
        like,
        share_non_tensor,
    ) -> T:
        if prefix is not None:

            def save_metadata(prefix=prefix, self=self):
                prefix = Path(prefix)
                if not prefix.exists():
                    os.makedirs(prefix, exist_ok=True)
                with open(prefix / "meta.json", "wb") as f:
                    f.write(
                        json.dumps(
                            {
                                "_type": str(type(self)),
                                "index": _index_to_str(self.idx),
                            }
                        )
                    )

            if executor is None:
                save_metadata()
            else:
                futures.append(executor.submit(save_metadata))

        _source = self._source._memmap_(
            prefix=prefix / "_source" if prefix is not None else None,
            copy_existing=copy_existing,
            executor=executor,
            futures=futures,
            inplace=inplace,
            like=like,
            share_non_tensor=share_non_tensor,
        )
        if not inplace:
            result = _SubTensorDict(_source, idx=self.idx)
        else:
            result = self
        return result

    @classmethod
    def _load_memmap(
        cls, prefix: Path, metadata: dict, device: torch.device | None = None
    ):
        index = metadata["index"]
        return _SubTensorDict(
            TensorDict.load_memmap(prefix / "_source", device=device),
            _str_to_index(index),
        )

    def make_memmap(
        self,
        key: NestedKey,
        shape: torch.Size | torch.Tensor,
        *,
        dtype: torch.dtype | None = None,
    ) -> MemoryMappedTensor:
        raise RuntimeError(
            "Making a memory-mapped tensor after instantiation isn't currently allowed for _SubTensorDict."
            "If this feature is required, open an issue on GitHub to trigger a discussion on the topic!"
        )

    def make_memmap_from_storage(
        self,
        key: NestedKey,
        storage: torch.UntypedStorage,
        shape: torch.Size | torch.Tensor,
        *,
        dtype: torch.dtype | None = None,
    ) -> MemoryMappedTensor:
        raise RuntimeError(
            "Making a memory-mapped tensor after instantiation isn't currently allowed for _SubTensorDict."
            "If this feature is required, open an issue on GitHub to trigger a discussion on the topic!"
        )

    def make_memmap_from_tensor(
        self, key: NestedKey, tensor: torch.Tensor, *, copy_data: bool = True
    ) -> MemoryMappedTensor:
        raise RuntimeError(
            "Making a memory-mapped tensor after instantiation isn't currently allowed for _SubTensorDict."
            "If this feature is required, open an issue on GitHub to trigger a discussion on the topic!"
        )

    def share_memory_(self) -> T:
        raise RuntimeError(
            "Casting a sub-tensordict values to shared memory cannot be done."
        )

    @property
    def is_locked(self) -> bool:
        return self._source.is_locked

    @is_locked.setter
    def is_locked(self, value) -> bool:
        if value:
            self.lock_()
        else:
            self.unlock_()

    @_as_context_manager("is_locked")
    def lock_(self) -> T:
        # we can't lock sub-tensordicts because that would mean that the
        # parent tensordict cannot be modified either.
        if not self.is_locked:
            raise RuntimeError(
                "Cannot lock a _SubTensorDict. Lock the parent tensordict instead."
            )
        return self

    @_as_context_manager("is_locked")
    def unlock_(self) -> T:
        if self.is_locked:
            raise RuntimeError(
                "Cannot unlock a _SubTensorDict. Unlock the parent tensordict instead."
            )
        return self

    def _remove_lock(self, lock_id):
        raise RuntimeError(
            "Cannot unlock a _SubTensorDict. Unlock the parent tensordict instead."
        )

    def _propagate_lock(self, lock_ids=None, *, is_compiling):
        raise RuntimeError(
            "Cannot lock a _SubTensorDict. Lock the parent tensordict instead."
        )

    def __del__(self):
        pass

    def _create_nested_str(self, key):
        # this may fail with a sub-sub tensordict
        out = self._source.empty()
        self._source._set_str(
            key, out, inplace=False, validated=True, non_blocking=False
        )
        # the id of out changes
        return self._get_str(key, default=NO_DEFAULT)

    def _cast_reduction(
        self,
        *,
        reduction_name,
        dim=NO_DEFAULT,
        keepdim=NO_DEFAULT,
        tuple_ok=True,
        **kwargs,
    ):
        try:
            td = self.to_tensordict()
        except Exception:
            raise RuntimeError(
                f"{reduction_name} requires this object to be cast to a regular TensorDict. "
                f"If you need {type(self)} to support {reduction_name}, help us by filing an issue"
                f" on github!"
            )
        return td._cast_reduction(
            reduction_name=reduction_name,
            dim=dim,
            keepdim=keepdim,
            tuple_ok=tuple_ok,
            **kwargs,
        )

    # TODO: check these implementations
    __eq__ = TensorDict.__eq__
    __ne__ = TensorDict.__ne__
    __ge__ = TensorDict.__ge__
    __gt__ = TensorDict.__gt__
    __le__ = TensorDict.__le__
    __lt__ = TensorDict.__lt__
    __setitem__ = TensorDict.__setitem__
    __xor__ = TensorDict.__xor__
    __or__ = TensorDict.__or__
    _check_device = TensorDict._check_device
    _check_is_shared = TensorDict._check_is_shared
    all = TensorDict.all
    any = TensorDict.any
    masked_select = TensorDict.masked_select
    memmap_like = TensorDict.memmap_like
    reshape = TensorDict.reshape
    split = TensorDict.split
    _to_module = TensorDict._to_module
    _unbind = TensorDict._unbind

    def _view(self, *args, **kwargs):
        raise RuntimeError(
            "Cannot call `view` on a sub-tensordict. Call `reshape` instead."
        )

    def _transpose(self, dim0, dim1):
        raise RuntimeError(
            "Cannot call `transpose` on a sub-tensordict. Make it dense before calling this method by calling `to_tensordict`."
        )

    def _permute(
        self,
        *args,
        **kwargs,
    ):
        raise RuntimeError(
            "Cannot call `permute` on a sub-tensordict. Make it dense before calling this method by calling `to_tensordict`."
        )

    def _squeeze(self, dim=None):
        raise RuntimeError(
            "Cannot call `squeeze` on a sub-tensordict. Make it dense before calling this method by calling `to_tensordict`."
        )

    def _unsqueeze(self, dim):
        raise RuntimeError(
            "Cannot call `unsqueeze` on a sub-tensordict. Make it dense before calling this method by calling `to_tensordict`."
        )

    _add_batch_dim = TensorDict._add_batch_dim

    _apply_nest = TensorDict._apply_nest
    _multithread_apply_flat = TensorDict._multithread_apply_flat
    _multithread_rebuild = TensorDict._multithread_rebuild
    _convert_to_tensordict = TensorDict._convert_to_tensordict

    _get_names_idx = TensorDict._get_names_idx

    def _index_tensordict(self, index, new_batch_size=None, names=None):
        # we ignore the names and new_batch_size which are only provided for
        # efficiency purposes
        return self._get_sub_tensordict(index)

    def _remove_batch_dim(self, *args, **kwargs):
        raise NotImplementedError


###########################
# Keys utils


class _TensorDictKeysView:
    """A Key view for TensorDictBase instance.

    _TensorDictKeysView is returned when accessing tensordict.keys() and holds a
    reference to the original TensorDict. This class enables us to support nested keys
    when performing membership checks and when iterating over keys.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict

        >>> td = TensorDict(
        >>>     {"a": TensorDict({"b": torch.rand(1, 2)}, [1, 2]), "c": torch.rand(1)},
        >>>     [1],
        >>> )

        >>> assert "a" in td.keys()
        >>> assert ("a",) in td.keys()
        >>> assert ("a", "b") in td.keys()
        >>> assert ("a", "c") not in td.keys()

        >>> assert set(td.keys()) == {("a", "b"), "c"}
    """

    def __init__(
        self,
        tensordict: T,
        include_nested: bool,
        leaves_only: bool,
        is_leaf: Callable[[Type], bool] = None,
    ) -> None:
        self.tensordict = tensordict
        self.include_nested = include_nested
        self.leaves_only = leaves_only
        if is_leaf is None:
            is_leaf = _default_is_leaf
        self.is_leaf = is_leaf

    def __iter__(self) -> Iterable[str] | Iterable[tuple[str, ...]]:
        if not self.include_nested:
            if self.leaves_only:
                for key in self._keys():
                    target_class = self.tensordict.entry_class(key)
                    if _is_tensor_collection(target_class):
                        continue
                    yield key
            else:
                yield from self._keys()
        else:
            yield from (
                key if len(key) > 1 else key[0]
                for key in self._iter_helper(self.tensordict)
            )

    def _iter_helper(
        self, tensordict: T, prefix: str | None = None
    ) -> Iterable[str] | Iterable[tuple[str, ...]]:
        for key, value in self._items(tensordict):
            full_key = self._combine_keys(prefix, key)
            cls = type(value)
            while cls is list:
                # For lazy stacks
                value = value[0]
                cls = type(value)
            is_leaf = self.is_leaf(cls)
            if self.include_nested and not is_leaf:
                yield from self._iter_helper(value, prefix=full_key)
            if not self.leaves_only or is_leaf:
                yield full_key

    def _combine_keys(self, prefix: tuple | None, key: NestedKey) -> tuple:
        if prefix is not None:
            return prefix + (key,)
        return (key,)

    def __len__(self) -> int:
        return sum(1 for _ in self)

    def _items(
        self, tensordict: TensorDictBase | None = None
    ) -> Iterable[tuple[NestedKey, CompatibleType]]:
        if tensordict is None:
            tensordict = self.tensordict
        if isinstance(tensordict, TensorDict) or is_tensorclass(tensordict):
            return tensordict._tensordict.items()
        from tensordict.nn import TensorDictParams

        if isinstance(tensordict, TensorDictParams):
            return tensordict._param_td.items()
        if isinstance(tensordict, KeyedJaggedTensor):
            return tuple((key, tensordict[key]) for key in tensordict.keys())
        from tensordict._lazy import (
            _CustomOpTensorDict,
            _iter_items_lazystack,
            LazyStackedTensorDict,
        )

        if isinstance(tensordict, LazyStackedTensorDict):
            return _iter_items_lazystack(tensordict, return_none_for_het_values=True)
        if isinstance(tensordict, _CustomOpTensorDict):
            # it's possible that a TensorDict contains a nested LazyStackedTensorDict,
            # or _CustomOpTensorDict, so as we iterate through the contents we need to
            # be careful to not rely on tensordict._tensordict existing.
            return (
                (key, tensordict._get_str(key, NO_DEFAULT))
                for key in tensordict._source.keys()
            )
        raise NotImplementedError(type(tensordict))

    def _keys(self) -> _TensorDictKeysView:
        return self.tensordict._tensordict.keys()

    def __contains__(self, key: NestedKey) -> bool:
        key = _unravel_key_to_tuple(key)
        if not key:
            raise TypeError(_NON_STR_KEY_ERR)

        if isinstance(key, str):
            if key in self._keys():
                if self.leaves_only:
                    # TODO: make this faster for LazyStacked without compromising regular
                    return not _is_tensor_collection(
                        type(self.tensordict._get_str(key))
                    )
                return True
            return False
        else:
            # thanks to _unravel_key_to_tuple we know the key is a tuple
            if len(key) == 1:
                return key[0] in self._keys()
            elif self.include_nested:
                item_root = self.tensordict._get_str(key[0], default=None)
                if item_root is not None:
                    entry_type = type(item_root)
                    if issubclass(entry_type, Tensor):
                        return False
                    elif entry_type is KeyedJaggedTensor:
                        if len(key) > 2:
                            return False
                        return key[1] in item_root.keys()
                    # TODO: make this faster for LazyStacked without compromising regular
                    _is_tensordict = _is_tensor_collection(entry_type)
                    if _is_tensordict:
                        # # this will call _unravel_key_to_tuple many times
                        # return key[1:] in self.tensordict._get_str(key[0], NO_DEFAULT).keys(include_nested=self.include_nested)
                        # this won't call _unravel_key_to_tuple but requires to get the default which can be suboptimal
                        if len(key) >= 3:
                            leaf_td = item_root._get_tuple(key[1:-1], None)
                            if leaf_td is None or (
                                not _is_tensor_collection(type(leaf_td))
                                and not isinstance(leaf_td, KeyedJaggedTensor)
                            ):
                                return False
                        else:
                            leaf_td = item_root
                        return key[-1] in leaf_td.keys()
                return False
            # this is reached whenever there is more than one key but include_nested is False
            if all(isinstance(subkey, str) for subkey in key):
                raise TypeError(_NON_STR_KEY_TUPLE_ERR)

    def __repr__(self):
        include_nested = f"include_nested={self.include_nested}"
        leaves_only = f"leaves_only={self.leaves_only}"
        return f"{type(self).__name__}({list(self)},\n{indent(include_nested, 4 * ' ')},\n{indent(leaves_only, 4 * ' ')})"


def _set_tensor_dict(  # noqa: F811
    __dict__,
    _parameters,
    _buffers,
    hooks,
    module: torch.nn.Module,
    name: str,
    tensor: torch.Tensor,
    inplace: bool,
) -> None:
    """Simplified version of torch.nn.utils._named_member_accessor."""
    was_buffer = False
    out = _parameters.pop(name, None)  # type: ignore[assignment]
    if out is None:
        out = _buffers.pop(name, None)
        was_buffer = out is not None
    if out is None:
        out = __dict__.pop(name)
    if inplace:
        # swap tensor and out after updating out
        out_tmp = out.clone()
        out.data.copy_(tensor.data)
        tensor = out
        out = out_tmp

    if isinstance(tensor, torch.nn.Parameter):
        for hook in hooks:
            output = hook(module, name, tensor)
            if output is not None:
                tensor = output
        _parameters[name] = tensor

        if isinstance(tensor, UninitializedTensorMixin):
            module.register_forward_pre_hook(
                _add_batch_dim_pre_hook(), with_kwargs=True
            )

    elif was_buffer and isinstance(tensor, torch.Tensor):
        _buffers[name] = tensor
    else:
        __dict__[name] = tensor
    return out


def _index_to_str(index):
    if isinstance(index, tuple):
        return tuple(_index_to_str(elt) for elt in index)
    if isinstance(index, slice):
        return ("slice", {"start": index.start, "stop": index.stop, "step": index.step})
    if isinstance(index, range):
        return ("range", {"start": index.start, "stop": index.stop, "step": index.step})
    if isinstance(index, Tensor):
        return ("tensor", index.tolist(), str(index.device))
    return index


def _str_to_index(index):
    if isinstance(index, tuple):
        if not len(index):
            return index
        if index[0] == "slice":
            index = index[1]
            return slice(index["start"], index["stop"], index["step"])
        if index[0] == "range":
            index = index[1]
            return range(index["start"], index["stop"], index["step"])
        if index[0] == "tensor":
            index, device = index[1:]
            return torch.tensor(index, device=device)
        return tuple(_index_to_str(elt) for elt in index)
    return index


_register_tensor_class(TensorDict)
_register_tensor_class(_SubTensorDict)


def _save_metadata(data: TensorDictBase, prefix: Path, metadata=None):
    """Saves the metadata of a memmap tensordict on disk."""
    filepath = prefix / "meta.json"
    if metadata is None:
        metadata = {}
    metadata.update(
        {
            "shape": list(data.shape),
            "device": str(data.device),
            "_type": str(type(data)),
        }
    )
    with open(filepath, "wb") as json_metadata:
        json_metadata.write(json.dumps(metadata))


# user did specify location and memmap is in wrong place, so we copy
def _populate_memmap(*, dest, value, key, copy_existing, prefix, like):
    filename = None if prefix is None else str(prefix / f"{key}.memmap")
    if value.is_nested:
        shape = value._nested_tensor_size()
        # Make the shape a memmap tensor too
        if prefix is not None:
            shape_filename = Path(filename)
            shape_filename = shape_filename.with_suffix(".shape.memmap")
            MemoryMappedTensor.from_tensor(
                shape,
                filename=shape_filename,
                copy_existing=copy_existing,
                existsok=True,
                copy_data=True,
            )
    else:
        shape = None
    memmap_tensor = MemoryMappedTensor.from_tensor(
        value.data if value.requires_grad else value,
        filename=filename,
        copy_existing=copy_existing,
        existsok=True,
        copy_data=not like,
        shape=shape,
    )
    dest._tensordict[key] = memmap_tensor
    return memmap_tensor


def _populate_empty(
    *,
    dest,
    key,
    shape,
    dtype,
    prefix,
):
    filename = None if prefix is None else str(prefix / f"{key}.memmap")
    if isinstance(shape, torch.Tensor):
        # Make the shape a memmap tensor too
        if prefix is not None:
            shape_filename = Path(filename)
            shape_filename = shape_filename.with_suffix(".shape.memmap")
            MemoryMappedTensor.from_tensor(
                shape,
                filename=shape_filename,
                existsok=True,
                copy_data=True,
            )
    memmap_tensor = MemoryMappedTensor.empty(
        shape=shape,
        dtype=dtype,
        filename=filename,
        existsok=True,
    )
    dest._tensordict[key] = memmap_tensor
    return memmap_tensor


def _populate_storage(
    *,
    dest,
    key,
    shape,
    dtype,
    prefix,
    storage,
):
    filename = None if prefix is None else str(prefix / f"{key}.memmap")
    if isinstance(shape, torch.Tensor):
        # Make the shape a memmap tensor too
        if prefix is not None:
            shape_filename = Path(filename)
            shape_filename = shape_filename.with_suffix(".shape.memmap")
            MemoryMappedTensor.from_tensor(
                shape,
                filename=shape_filename,
                existsok=True,
                copy_data=True,
            )
    memmap_tensor = MemoryMappedTensor.from_storage(
        storage=storage,
        shape=shape,
        dtype=dtype,
        filename=filename,
    )
    dest._tensordict[key] = memmap_tensor
    return memmap_tensor


def _update_metadata(*, metadata, key, value, is_collection):
    if not is_collection:
        metadata[key] = {
            "device": str(value.device),
            "shape": (
                list(value.shape)
                if not value.is_nested
                else list(value._nested_tensor_size().shape)
            ),
            "dtype": str(value.dtype),
            "is_nested": value.is_nested,
        }
    else:
        metadata[key] = {
            "type": type(value).__name__,
        }