# Package version: 2.1.2+cpu
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torch.utils._pytree import tree_flatten, tree_unflatten
def tree_map_(fn_, pytree):
    """Call ``fn_`` on every leaf of ``pytree`` for its side effects.

    The pytree is flattened with ``tree_flatten`` and ``fn_`` is invoked once
    per leaf, in flattening order. Whatever ``fn_`` returns is discarded; the
    trailing underscore in the name follows the in-place naming convention.

    Args:
        fn_: Callable taking a single leaf.
        pytree: The (possibly nested) structure whose leaves are visited.

    Returns:
        The same ``pytree`` object that was passed in.
    """
    leaves, _ = tree_flatten(pytree)
    # A plain loop rather than a comprehension: the results are not needed,
    # so there is no reason to build a throwaway list.
    for leaf in leaves:
        fn_(leaf)
    return pytree
class PlaceHolder:
    """Stand-in object whose ``repr`` is a single asterisk."""

    def __repr__(self):
        # Rendered in place of a real value when a container is printed.
        return "*"
def treespec_pprint(spec):
    """Return the ``repr`` of ``spec``'s structure with ``*`` for each leaf.

    One ``PlaceHolder`` is created per leaf (``spec.num_leaves`` in total),
    the placeholders are unflattened according to ``spec``, and the ``repr``
    of the rebuilt structure is returned.
    """
    placeholders = []
    for _ in range(spec.num_leaves):
        placeholders.append(PlaceHolder())
    return repr(tree_unflatten(placeholders, spec))