# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torch.utils._pytree import tree_flatten, tree_unflatten
def tree_map_(fn_, pytree):
flat_args, _ = tree_flatten(pytree)
[fn_(arg) for arg in flat_args]
return pytree
class PlaceHolder():
def __repr__(self):
return '*'
def treespec_pprint(spec):
leafs = [PlaceHolder() for _ in range(spec.num_leaves)]
result = tree_unflatten(leafs, spec)
return repr(result)