# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
## @package SparseTransformer
# Module caffe2.experiments.python.SparseTransformer
from collections import deque

from caffe2.python import workspace
import scipy.sparse
class NetDefNode():
    """
    One operator in a doubly linked graph view of a net.

    Attributes set by the constructor:
    name: key under which this node is stored in its neighbours' dicts
    optype: operator type string
    ops: dict mapping name -> node for every node that consumes this one
    prev: dict mapping name -> node for every node this one consumes
    visited: traversal flag, initially False
    op: the operator object this node wraps (may be None)
    """

    def __init__(self, name, optype, p=None, op=None):
        self.name = name
        self.optype = optype
        self.op = op
        self.visited = False
        self.prev = {}
        self.ops = {}
        # Wire up the initial parent(s); a None parent links nothing.
        self.insertInput(p)

    def _link(self, parent):
        # Record the edge on both endpoints.
        self.prev[parent.name] = parent
        parent.ops[self.name] = self

    def insertInput(self, p):
        """
        Register p as input of this node and this node as output of p.
        p: a node or a list of nodes; any other value is ignored
        """
        if isinstance(p, NetDefNode):
            self._link(p)
            return
        if isinstance(p, list):
            for parent in p:
                self._link(parent)

    def deleteInput(self, p):
        """
        Remove the edge between p and this node from both endpoints.
        p: a node; any other value is ignored
        Raises KeyError when p is a node that is not linked to this one.
        """
        if not isinstance(p, NetDefNode):
            return
        del self.prev[p.name]
        del p.ops[self.name]
def maskNallocate(weight_name):
    """
    Publish the CSR form of a weight blob to the workspace.

    The blob named weight_name is fetched, converted with
    scipy.sparse.csr_matrix, and its three arrays are fed back under the
    names weight_name + "wcsr" (values), weight_name + "iw" (row pointer)
    and weight_name + "jw" (column indices).

    Returns the three blob names as a tuple, in that order.

    NOTE(review): no mask blob is read here; presumably pruned entries are
    already zero in the weight blob -- confirm against the caller.
    """
    sparse_weight = scipy.sparse.csr_matrix(workspace.FetchBlob(weight_name))
    csr_parts = (
        ("wcsr", sparse_weight.data),
        ("iw", sparse_weight.indptr),
        ("jw", sparse_weight.indices),
    )
    blob_names = []
    for suffix, values in csr_parts:
        blob_name = weight_name + suffix
        workspace.FeedBlob(blob_name, values)
        blob_names.append(blob_name)
    return tuple(blob_names)
def transFCRelu(cur, id2node, name2id, ops, model):
    """
    Add trans before and after this FC_Prune->(Relu)->FC_Prune chain.

    Starting at cur, the chain of FC_Prune / Relu nodes is walked. For every
    chain node a replacement operator is appended to model (FC_Sparse for
    FC_Prune, Relu for Relu) and a matching NetDefNode is linked behind the
    previously created one. The original chain nodes are unlinked from their
    parents as they are passed. A Transpose is emitted before the first and
    after the last replacement operator, and the first node after the chain
    is re-attached below the trailing Transpose node.

    cur: first node of the chain; assumed to be a FC_Prune with one input
    id2node, name2id, ops: accepted for signature compatibility, not used
    model: object providing Transpose, FC_Sparse, Relu and net.Proto()

    Raises StopIteration when cur has no input node, or when the last chain
    node has no consumer to re-attach.
    """
    # 1. add trans before the start of this chain
    # assuming that cur is a FC_Prune, and it has only one input
    # next(iter(...)) works on both Python 2 and 3, unlike itervalues().next()
    pre = next(iter(cur.prev.values()))
    # Create a node /op and insert it.
    # TODO(wyiming): check whether it is correct here
    current_blob = model.Transpose(cur.op.input[0], cur.op.input[0] + "_trans")
    # the operator just appended to the net is the last one in the proto
    trans_op = model.net.Proto().op[-1]
    trans_node = NetDefNode(trans_op.output[0], "Transpose", pre, trans_op)
    trans_node.visited = True
    # pre_new tracks the tail of the newly built chain
    pre_new = trans_node
    # 2. use while loop to visit the chain
    while True:
        # breakup with the parent
        cur.deleteInput(pre)
        if not (cur.optype == "FC_Prune" or cur.optype == "Relu"):
            print("Reaching the end of the chain")
            break
        if len(cur.ops) > 1:
            print("A FC/Relu giving more than 1 useful outputs")
        if cur.optype == "FC_Prune":
            op = cur.op
            # input[1] is handed over as the weight blob, input[3] as the bias
            wcsr, iw, jw = maskNallocate(op.input[1])
            bias_name = op.input[3]
            # TODO(wyiming): create a new Op here
            current_blob = model.FC_Sparse(current_blob,
                                           cur.op.output[0] + "_Sparse",
                                           wcsr, iw, jw, bias_name)
            sps_op = model.net.Proto().op[-1]
            sps_node = NetDefNode(cur.op.output[0] + "_Sparse",
                                  "FC_Sparse",
                                  pre_new, sps_op)
            sps_node.visited = True
            pre_new = sps_node
        if cur.optype == "Relu":
            # the replacement Relu writes to its own input blob
            current_blob = model.Relu(current_blob, current_blob)
            rel_op = model.net.Proto().op[-1]
            rel_node = NetDefNode(str(current_blob), "Relu",
                                  pre_new, rel_op)
            rel_node.visited = True
            pre_new = rel_node
        cur.visited = True
        pre = cur
        # Pick the chain successor; when several children qualify, the last
        # one in iteration order is taken.
        successor = None
        for child in pre.ops.values():
            if child.optype == "Relu" or child.optype == "FC_Prune":
                successor = child
        if successor is not None:
            cur = successor
        else:
            # assume that there is only 1 output that is not PrintOP
            cur = next(iter(pre.ops.values()))
            cur.deleteInput(pre)
            print("No FC/RElu children")
            print(cur.op.type)
            break
    # 3. add trans after this chain like 1.
    # the trailing Transpose writes to the last chain node's original output
    current_blob = model.Transpose(current_blob, pre.op.output[0])
    trans_op = model.net.Proto().op[-1]
    trans_node = NetDefNode(str(current_blob), "Transpose", pre_new, trans_op)
    trans_node.visited = True
    cur.insertInput(trans_node)
    print(cur.prev)
    print(trans_node.ops)
def Prune2Sparse(cur, id2node, name2id, ops, model):
    """
    Walk the graph depth-first from cur and rewrite every FC_Prune chain.

    An unvisited FC_Prune node is handed to transFCRelu; every node reached
    is marked visited; then all consumers of cur are processed recursively.

    cur: node to start from
    id2node, name2id, ops: forwarded unchanged to transFCRelu
    model: forwarded unchanged to transFCRelu
    """
    # Assume that FC and Relu takes in only 1 input;
    # If not raise warning
    if not cur.visited and cur.optype == "FC_Prune":
        transFCRelu(cur, id2node, name2id, ops, model)
    cur.visited = True
    # Rewriting a child chain adds a node to cur.ops and removes the child
    # from it, so the dict must not be iterated live. Work on snapshots and
    # re-scan until no unprocessed child is left, so that nodes attached
    # during the traversal are still reached.
    handled = {}  # id(node) -> node; holding the node keeps its id unique
    while True:
        pending = [n for n in cur.ops.values() if id(n) not in handled]
        if not pending:
            break
        for n in pending:
            handled[id(n)] = n
            Prune2Sparse(n, id2node, name2id, ops, model)
def net2list(net_root):
    """
    List the ops of a net in breadth-first order.

    net_root: root node; its own op is not included in the result
    Returns a list holding the .op of every node reachable from net_root,
    level by level. There is no de-duplication: a node reachable through
    several parents is listed once per incoming path.
    """
    op_list = []
    # deque gives O(1) pops from the front; .values() works on Python 2 and 3
    bfs_queue = deque(net_root.ops.values())
    while bfs_queue:
        node = bfs_queue.popleft()
        op_list.append(node.op)
        bfs_queue.extend(node.ops.values())
    return op_list
def netbuilder(model):
    """
    Build a NetDefNode graph from the operators of model.net.

    Operators of type "Print" are skipped. Each remaining operator becomes
    a node that is linked below the nodes that produced its input blobs;
    an operator none of whose inputs was produced by an earlier operator
    is linked below a synthetic root node instead.

    Returns (net_root, net_name2id, net_id2node):
    net_root: the synthetic root node
    net_name2id: blob name -> index of the latest operator that output it
    net_id2node: operator index -> node
    """
    print("Welcome to model checker")
    net_proto = model.net.Proto()
    net_root = NetDefNode("net_root", "root", None)
    net_name2id = {}
    net_id2node = {}
    for op_id, op in enumerate(net_proto.op):
        if op.type == "Print":
            continue
        if op.name:
            node_name = '%s/%s (op#%d)' % (op.name, op.type, op_id)
        else:
            node_name = '%s (op#%d)' % (op.type, op_id)
        op_node = NetDefNode(node_name, op.type, op=op)
        net_id2node[op_id] = op_node
        # assume that un_occured name are non_layers
        # TODO: write a non-layer checker and log it
        producers = [
            net_id2node[net_name2id[blob]]
            for blob in op.input
            if blob in net_name2id
        ]
        if producers:
            op_node.insertInput(producers)
        else:
            op_node.insertInput(net_root)
        # Register outputs only after the inputs were resolved, so that an
        # op reusing an input blob name as output links to the earlier
        # producer rather than to itself.
        for blob in op.output:
            net_name2id[blob] = op_id
    return net_root, net_name2id, net_id2node