Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ python / cnn.py

## @package cnn
# Module caffe2.python.cnn





from caffe2.python import brew, workspace
from caffe2.python.model_helper import ModelHelper
from caffe2.proto import caffe2_pb2
import logging


class CNNModelHelper(ModelHelper):
    """A helper model so we can write CNN models more easily, without having to
    manually define parameter initializations and operators separately.

    Deprecated: prefer ``ModelHelper`` together with the ``brew`` module. The
    layer methods below are thin wrappers that forward to the corresponding
    ``brew`` helper, injecting this model's ``order`` and cuDNN settings
    where applicable.
    """

    def __init__(self, order="NCHW", name=None,
                 use_cudnn=True, cudnn_exhaustive_search=False,
                 ws_nbytes_limit=None, init_params=True,
                 skip_sparse_optim=False,
                 param_model=None):
        """Create a CNN model helper.

        Args:
            order: Image storage order; must be "NCHW" or "NHWC".
            name: Model name; "CNN" is used when None.
            use_cudnn: Stored and forwarded to brew helpers that accept it.
            cudnn_exhaustive_search: Stored and forwarded to the conv helpers.
            ws_nbytes_limit: Stored and forwarded to the conv helpers; also
                added to the base arg scope when truthy.
            init_params, skip_sparse_optim, param_model: Passed through to
                ``ModelHelper.__init__``.

        Raises:
            ValueError: if ``order`` is neither "NCHW" nor "NHWC".
        """
        logging.warning(
            "[====DEPRECATE WARNING====]: you are creating an "
            "object from CNNModelHelper class which will be deprecated soon. "
            "Please use ModelHelper object with brew module. For more "
            "information, please refer to caffe2.ai and python/brew.py, "
            "python/brew_test.py for more information."
        )

        # Validate up front so an invalid order fails before the (comparatively
        # expensive) base ModelHelper is constructed.
        if order not in ("NHWC", "NCHW"):
            raise ValueError(
                "Cannot understand the CNN storage order %s." % order
            )

        cnn_arg_scope = {
            'order': order,
            'use_cudnn': use_cudnn,
            'cudnn_exhaustive_search': cudnn_exhaustive_search,
        }
        # Only set when truthy, matching the historical behavior (None/0 are
        # left out of the arg scope).
        if ws_nbytes_limit:
            cnn_arg_scope['ws_nbytes_limit'] = ws_nbytes_limit
        super(CNNModelHelper, self).__init__(
            skip_sparse_optim=skip_sparse_optim,
            name="CNN" if name is None else name,
            init_params=init_params,
            param_model=param_model,
            arg_scope=cnn_arg_scope,
        )

        self.order = order
        self.use_cudnn = use_cudnn
        self.cudnn_exhaustive_search = cudnn_exhaustive_search
        self.ws_nbytes_limit = ws_nbytes_limit

    def ImageInput(self, blob_in, blob_out, use_gpu_transform=False, **kwargs):
        """Forward to ``brew.image_input`` using this model's storage order."""
        return brew.image_input(
            self,
            blob_in,
            blob_out,
            order=self.order,
            use_gpu_transform=use_gpu_transform,
            **kwargs
        )

    def VideoInput(self, blob_in, blob_out, **kwargs):
        """Forward to ``brew.video_input``."""
        return brew.video_input(
            self,
            blob_in,
            blob_out,
            **kwargs
        )

    def PadImage(self, blob_in, blob_out, **kwargs):
        """Add a PadImage operator to ``self.net`` and return its output.

        The result used to be discarded; returning it makes this helper
        chainable like the other layer wrappers.
        """
        # TODO(wyiming): remove this dummy helper later
        return self.net.PadImage(blob_in, blob_out, **kwargs)

    def ConvNd(self, *args, **kwargs):
        """Forward to ``brew.conv_nd`` with this model's cuDNN/order settings."""
        return brew.conv_nd(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def Conv(self, *args, **kwargs):
        """Forward to ``brew.conv`` with this model's cuDNN/order settings."""
        return brew.conv(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def ConvTranspose(self, *args, **kwargs):
        """Forward to ``brew.conv_transpose`` with this model's settings."""
        return brew.conv_transpose(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def GroupConv(self, *args, **kwargs):
        """Forward to ``brew.group_conv`` with this model's settings."""
        return brew.group_conv(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def GroupConv_Deprecated(self, *args, **kwargs):
        """Forward to ``brew.group_conv_deprecated`` with this model's settings."""
        return brew.group_conv_deprecated(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def FC(self, *args, **kwargs):
        """Forward to ``brew.fc``."""
        return brew.fc(self, *args, **kwargs)

    def PackedFC(self, *args, **kwargs):
        """Forward to ``brew.packed_fc``."""
        return brew.packed_fc(self, *args, **kwargs)

    def FC_Prune(self, *args, **kwargs):
        """Forward to ``brew.fc_prune``."""
        return brew.fc_prune(self, *args, **kwargs)

    def FC_Decomp(self, *args, **kwargs):
        """Forward to ``brew.fc_decomp``."""
        return brew.fc_decomp(self, *args, **kwargs)

    def FC_Sparse(self, *args, **kwargs):
        """Forward to ``brew.fc_sparse``."""
        return brew.fc_sparse(self, *args, **kwargs)

    def Dropout(self, *args, **kwargs):
        """Forward to ``brew.dropout`` with this model's order and cuDNN flag."""
        return brew.dropout(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def LRN(self, *args, **kwargs):
        """Forward to ``brew.lrn`` with this model's order and cuDNN flag."""
        return brew.lrn(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def Softmax(self, *args, **kwargs):
        """Forward to ``brew.softmax`` with this model's cuDNN flag."""
        return brew.softmax(self, *args, use_cudnn=self.use_cudnn, **kwargs)

    def SpatialBN(self, *args, **kwargs):
        """Forward to ``brew.spatial_bn`` with this model's storage order."""
        return brew.spatial_bn(self, *args, order=self.order, **kwargs)

    def SpatialGN(self, *args, **kwargs):
        """Forward to ``brew.spatial_gn`` with this model's storage order."""
        return brew.spatial_gn(self, *args, order=self.order, **kwargs)

    def InstanceNorm(self, *args, **kwargs):
        """Forward to ``brew.instance_norm`` with this model's storage order."""
        return brew.instance_norm(self, *args, order=self.order, **kwargs)

    def Relu(self, *args, **kwargs):
        """Forward to ``brew.relu`` with this model's order and cuDNN flag."""
        return brew.relu(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def PRelu(self, *args, **kwargs):
        """Forward to ``brew.prelu``."""
        return brew.prelu(self, *args, **kwargs)

    def Concat(self, *args, **kwargs):
        """Forward to ``brew.concat`` with this model's storage order."""
        return brew.concat(self, *args, order=self.order, **kwargs)

    def DepthConcat(self, *args, **kwargs):
        """The old depth concat function - we should move to use concat."""
        # Report through logging (like the constructor's deprecation notice)
        # rather than writing straight to stdout.
        logging.warning("DepthConcat is deprecated. use Concat instead.")
        return self.Concat(*args, **kwargs)

    def Sum(self, *args, **kwargs):
        """Forward to ``brew.sum``."""
        return brew.sum(self, *args, **kwargs)

    def Transpose(self, *args, **kwargs):
        """Forward to ``brew.transpose`` with this model's cuDNN flag."""
        return brew.transpose(self, *args, use_cudnn=self.use_cudnn, **kwargs)

    def Iter(self, *args, **kwargs):
        """Forward to ``brew.iter``."""
        return brew.iter(self, *args, **kwargs)

    def Accuracy(self, *args, **kwargs):
        """Forward to ``brew.accuracy``."""
        return brew.accuracy(self, *args, **kwargs)

    def MaxPool(self, *args, **kwargs):
        """Forward to ``brew.max_pool`` with this model's cuDNN/order settings."""
        return brew.max_pool(
            self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
        )

    def MaxPoolWithIndex(self, *args, **kwargs):
        """Forward to ``brew.max_pool_with_index`` with this model's order."""
        return brew.max_pool_with_index(self, *args, order=self.order, **kwargs)

    def AveragePool(self, *args, **kwargs):
        """Forward to ``brew.average_pool`` with this model's settings."""
        return brew.average_pool(
            self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
        )

    @property
    def XavierInit(self):
        """Initializer spec ``('XavierFill', {})``."""
        return ('XavierFill', {})

    def ConstantInit(self, value):
        """Initializer spec ``('ConstantFill', {'value': value})``.

        Unlike the other ``*Init`` members this is a method, because it
        takes an argument.
        """
        return ('ConstantFill', dict(value=value))

    @property
    def MSRAInit(self):
        """Initializer spec ``('MSRAFill', {})``."""
        return ('MSRAFill', {})

    @property
    def ZeroInit(self):
        """Initializer spec ``('ConstantFill', {})``.

        No explicit value is passed; presumably this relies on ConstantFill's
        default value being zero — confirm against the operator schema.
        """
        return ('ConstantFill', {})

    def AddWeightDecay(self, weight_decay):
        """Forward to ``brew.add_weight_decay``."""
        return brew.add_weight_decay(self, weight_decay)

    @property
    def CPU(self):
        """A fresh ``DeviceOption`` with ``device_type`` set to CPU."""
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CPU
        return device_option

    @property
    def GPU(self, gpu_id=0):
        """A fresh ``DeviceOption`` for ``workspace.GpuDeviceType``.

        NOTE: because this is a property, ``gpu_id`` can never be supplied
        through attribute access, so ``device_id`` is always 0. Build a
        ``DeviceOption`` directly to target another device.
        """
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = workspace.GpuDeviceType
        device_option.device_id = gpu_id
        return device_option