Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ python / hsm_util.py

## @package hsm_util
# Module caffe2.python.hsm_util





from caffe2.proto import hsm_pb2

'''
    Hierarchical softmax utility methods that can be used to:
    1) create TreeProto structure given list of word_ids or NodeProtos
    2) create HierarchyProto structure using the user-inputted TreeProto
'''


def create_node_with_words(words, name='node'):
    node = hsm_pb2.NodeProto()
    node.name = name
    for word in words:
        node.word_ids.append(word)
    return node


def create_node_with_nodes(nodes, name='node'):
    node = hsm_pb2.NodeProto()
    node.name = name
    for child_node in nodes:
        new_child_node = node.children.add()
        new_child_node.MergeFrom(child_node)
    return node


def create_hierarchy(tree_proto):
    max_index = 0

    def create_path(path, word):
        path_proto = hsm_pb2.PathProto()
        path_proto.word_id = word
        for entry in path:
            new_path_node = path_proto.path_nodes.add()
            new_path_node.index = entry[0]
            new_path_node.length = entry[1]
            new_path_node.target = entry[2]
        return path_proto

    def recursive_path_builder(node_proto, path, hierarchy_proto, max_index):
        node_proto.offset = max_index
        path.append([max_index,
                    len(node_proto.word_ids) + len(node_proto.children), 0])
        max_index += len(node_proto.word_ids) + len(node_proto.children)
        if hierarchy_proto.size < max_index:
            hierarchy_proto.size = max_index
        for target, node in enumerate(node_proto.children):
            path[-1][2] = target
            max_index = recursive_path_builder(node, path, hierarchy_proto,
                                               max_index)
        for target, word in enumerate(node_proto.word_ids):
            path[-1][2] = target + len(node_proto.children)
            path_entry = create_path(path, word)
            new_path_entry = hierarchy_proto.paths.add()
            new_path_entry.MergeFrom(path_entry)
        del path[-1]
        return max_index

    node = tree_proto.root_node
    hierarchy_proto = hsm_pb2.HierarchyProto()
    path = []
    max_index = recursive_path_builder(node, path, hierarchy_proto, max_index)
    return hierarchy_proto