Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ experiments / python / SparseTransformer.py

# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

## @package SparseTransformer
# Module caffe2.experiments.python.SparseTransformer




from caffe2.python import workspace
import scipy.sparse


class NetDefNode:
    """
    A node of the operator graph built from a net definition.

    Each node keeps two dictionaries keyed by node name: `prev` holds the
    nodes feeding this one, `ops` holds the nodes that consume it.
    """

    def __init__(self, name, optype, p=None, op=None):
        """
        name: key used to register this node in its neighbours' dicts
        optype: operator type string of this node
        p: optional parent node, or list of parent nodes, to link to
        op: the operator object this node wraps (stored untouched)
        """
        self.name, self.optype = name, optype
        self.op = op
        self.visited = False
        self.ops, self.prev = {}, {}
        self.insertInput(p)

    def insertInput(self, p):
        """
        Register `p` as input of this node and this node as output of `p`.
        p: a node or a list of nodes; any other value is ignored
        """
        if isinstance(p, NetDefNode):
            parents = [p]
        elif isinstance(p, list):
            parents = p
        else:
            return
        for parent in parents:
            self.prev[parent.name] = parent
            parent.ops[self.name] = self

    def deleteInput(self, p):
        """
        Remove the link between `p` and this node in both directions.
        Non-node arguments are ignored; a node that is not currently linked
        raises KeyError.
        """
        if not isinstance(p, NetDefNode):
            return
        del self.prev[p.name]
        del p.ops[self.name]


def maskNallocate(weight_name):
    """
    Store the CSR form of a weight blob in the workspace.

    The blob called `weight_name` is fetched, converted with
    scipy.sparse.csr_matrix, and its `data`, `indptr` and `indices` arrays
    are fed back as blobs named weight_name + "wcsr" / "iw" / "jw".
    Returns those three blob names, in that order.

    NOTE(review): no separate mask is applied here; presumably the pruned
    entries of the weight blob are already zero -- confirm with the caller.
    """
    csr = scipy.sparse.csr_matrix(workspace.FetchBlob(weight_name))
    blobs = (
        (weight_name + "wcsr", csr.data),
        (weight_name + "iw", csr.indptr),
        (weight_name + "jw", csr.indices),
    )
    for blob_name, values in blobs:
        workspace.FeedBlob(blob_name, values)
    return tuple(blob_name for blob_name, _ in blobs)


def transFCRelu(cur, id2node, name2id, ops, model):
    """
    Add trans before and after this FC_Prune->(Relu)->FC_Prune chain.

    Starting at `cur`, every FC_Prune in the chain is re-emitted on `model`
    as an FC_Sparse op and every Relu as a Relu op, with a Transpose op
    emitted before the first and after the last one. A parallel chain of
    NetDefNode objects (all marked visited) is linked between the original
    parent of `cur` and the first node found after the chain, while the
    original chain nodes are unlinked from their inputs.

    cur: node where the chain starts; assumed to be an FC_Prune with a
         single input
    id2node, name2id, ops: not used in this function
    model: object providing Transpose / FC_Sparse / Relu and net.Proto()

    Raises StopIteration if `cur` has no input, or if the last chain node
    has no consumer at all.
    """
    # 1. add trans before the start of this chain
    # assuming that cur is a FC_Prune, and it has only one input
    # (dict views replace the Python 2-only itervalues()/iteritems())
    pre = next(iter(cur.prev.values()))
    # Create a node /op and insert it.
    # TODO(wyiming): check whether it is correct here
    current_blob = model.Transpose(cur.op.input[0], cur.op.input[0] + "_trans")
    # the op that was just emitted is the last one in the proto
    trans_op = model.net.Proto().op[-1]
    trans_node = NetDefNode(trans_op.output[0], "Transpose", pre, trans_op)
    trans_node.visited = True
    pre_new = trans_node

    # 2. use while loop to visit the chain
    while True:
        # breakup with the parent
        cur.deleteInput(pre)
        if cur.optype not in ("FC_Prune", "Relu"):
            print("Reaching the end of the chain")
            break
        if len(cur.ops) > 1:
            print("A FC/Relu giving more than 1 useful outputs")
        if cur.optype == "FC_Prune":
            op = cur.op
            # input[1] is used as the weight blob, input[3] as the bias
            wcsr, iw, jw = maskNallocate(op.input[1])
            bias_name = op.input[3]
            # TODO(wyiming): create a new Op here
            current_blob = model.FC_Sparse(current_blob,
                                           op.output[0] + "_Sparse",
                                           wcsr, iw, jw, bias_name)
            sps_op = model.net.Proto().op[-1]
            sps_node = NetDefNode(op.output[0] + "_Sparse",
                                  "FC_Sparse",
                                  pre_new, sps_op)
            sps_node.visited = True
            pre_new = sps_node
        elif cur.optype == "Relu":
            # in-place Relu on the blob produced so far
            current_blob = model.Relu(current_blob, current_blob)
            rel_op = model.net.Proto().op[-1]
            rel_node = NetDefNode(str(current_blob), "Relu",
                                  pre_new, rel_op)
            rel_node.visited = True
            pre_new = rel_node

        cur.visited = True
        pre = cur
        # Continue with a Relu/FC_Prune consumer if there is one; when
        # several qualify, the last one in iteration order is taken.
        next_in_chain = None
        for child in cur.ops.values():
            if child.optype in ("Relu", "FC_Prune"):
                next_in_chain = child
        if next_in_chain is None:
            # assume that there is only 1 output that is not PrintOP
            cur = next(iter(cur.ops.values()))
            cur.deleteInput(pre)
            print("No FC/RElu children")
            print(cur.op.type)
            break
        cur = next_in_chain

    # 3. add trans after this chain like 1.
    # The result is written to the last chain op's original output name.
    current_blob = model.Transpose(current_blob, pre.op.output[0])
    trans_op = model.net.Proto().op[-1]
    trans_node = NetDefNode(str(current_blob), "Transpose", pre_new, trans_op)
    trans_node.visited = True
    cur.insertInput(trans_node)
    print(cur.prev)
    print(trans_node.ops)


def Prune2Sparse(cur, id2node, name2id, ops, model):
    """
    Walk the graph below `cur` and convert every unvisited FC_Prune chain.

    An unvisited FC_Prune node is handed to transFCRelu; afterwards the
    node is marked visited and the walk recurses into its consumers.

    cur: node to start from
    id2node, name2id, ops, model: forwarded unchanged to transFCRelu and
        to the recursive calls
    """
    # Assume that FC and Relu takes in only 1 input;
    # If not raise warning
    if not cur.visited and cur.optype == "FC_Prune":
        transFCRelu(cur, id2node, name2id, ops, model)

    cur.visited = True
    # Converting a child rewires cur.ops (the child is unlinked and a new
    # node is attached), so never iterate the live dict. Work on snapshots
    # and repeat until no consumer is left unprocessed, so that nodes
    # attached during the walk are still followed.
    handled = set()
    while True:
        pending = [(name, node) for name, node in list(cur.ops.items())
                   if name not in handled]
        if not pending:
            break
        for name, node in pending:
            handled.add(name)
            Prune2Sparse(node, id2node, name2id, ops, model)


def net2list(net_root):
    """
    Collect the ops of a net into a list by breadth-first traversal.

    net_root: root node; its own op is not included, only those of the
        nodes reachable through `ops`.
    Returns the list of `op` attributes in the order nodes are dequeued.

    No visited-tracking is done, so a node that is reachable through
    several parents is listed once per incoming path.
    """
    # dict views replace the Python 2-only iteritems()
    bfs_queue = list(net_root.ops.values())
    op_list = []
    # A read cursor avoids copying the queue on every pop.
    head = 0
    while head < len(bfs_queue):
        node = bfs_queue[head]
        head += 1
        op_list.append(node.op)
        bfs_queue.extend(node.ops.values())

    return op_list


def netbuilder(model):
    """
    Build a NetDefNode graph from the ops of `model.net.Proto()`.

    Print ops are skipped. Each remaining op becomes a node that is linked
    to the nodes which produced its inputs earlier in the op list; an op
    with no such producer is attached to a synthetic root node.

    Returns (net_root, net_name2id, net_id2node) where net_name2id maps an
    output blob name to the id of the last op writing it and net_id2node
    maps an op id to its node.
    """
    print("Welcome to model checker")
    proto = model.net.Proto()
    net_root = NetDefNode("net_root", "root", None)
    net_name2id, net_id2node = {}, {}

    for op_id, op in enumerate(proto.op):
        if op.type == "Print":
            continue

        if op.name:
            op_name = '%s/%s (op#%d)' % (op.name, op.type, op_id)
        else:
            op_name = '%s (op#%d)' % (op.type, op_id)
        op_node = NetDefNode(op_name, op.type, op=op)
        net_id2node[op_id] = op_node

        # assume that un_occured name are non_layers
        # TODO: write a non-layer checker and log it
        producers = [net_id2node[net_name2id[blob]]
                     for blob in op.input if blob in net_name2id]
        if producers:
            op_node.insertInput(producers)
        else:
            op_node.insertInput(net_root)

        # outputs are registered after the inputs were resolved, so an op
        # never becomes its own producer
        for blob in op.output:
            net_name2id[blob] = op_id

    return net_root, net_name2id, net_id2node