Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ include / caffe2 / observers / profile_observer.h

/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "caffe2/core/common.h"
#include "caffe2/core/event.h"
#include "caffe2/core/net.h"
#include "caffe2/core/observer.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/timer.h"
#include "caffe2/observers/operator_attaching_net_observer.h"

namespace caffe2 {

/**
 * This observer displays a description of each operator executed in a network.
 * This includes input and output tensors (name, size, type), arguments, and
 * execution time. This can be used to analyze different performance
 * characteristics.
 * NOTE: Currently this observer only supports synchronous computation.
 **/

// Forward declaration: ProfileOperatorObserver keeps a (non-owning) pointer
// back to the net-level ProfileObserver, which is defined later in this file.
class ProfileObserver;

/**
 * Timing state shared by profiling observers: a Timer plus two float slots
 * for a start time and a run time. The fields are protected so that derived
 * observers can read and update them; the code that fills them in is not in
 * this header (presumably the derived Start()/Stop() — confirm in the .cc).
 */
class ProfileCounter {
 public:
  explicit ProfileCounter() = default;

 protected:
  Timer timer_;
  float start_time_{0.0f};
  float run_time_{0.0f};
};

/**
 * Per-operator profiling observer. Each instance observes one OperatorBase
 * and keeps a non-owning pointer to the net-level ProfileObserver it belongs
 * to. Timing state is inherited from ProfileCounter.
 *
 * getId() yields the operator's net position, suffixed with "-<rnn_order>"
 * when an RNN order has been supplied (four-argument constructor).
 */
class TORCH_API ProfileOperatorObserver final
    : public ProfileCounter,
      public ObserverBase<OperatorBase> {
 public:
  // A net observer is mandatory; construction from a subject alone is banned.
  explicit ProfileOperatorObserver(OperatorBase* subject) = delete;

  /**
   * @param subject     Operator to observe. May be null, in which case
   *                    net_position_ keeps its kNoNetPositionSet default.
   * @param netObserver Owning net-level observer (not owned by this object).
   */
  explicit ProfileOperatorObserver(
      OperatorBase* subject,
      ProfileObserver* netObserver)
      : ObserverBase<OperatorBase>(subject), netObserver_(netObserver) {
    if (subject) {
      net_position_ = subject->net_position();
    }
  }

  /**
   * Variant that takes the position explicitly instead of reading it from
   * the operator; the explicit values override whatever the delegated
   * constructor read from `subject`.
   * @param net_position Position reported by getId().
   * @param rnn_order    RNN order appended to getId() unless it equals
   *                     OperatorBase::kNoNetPositionSet.
   */
  explicit ProfileOperatorObserver(
      OperatorBase* subject,
      ProfileObserver* netObserver,
      int net_position,
      int rnn_order)
      : ProfileOperatorObserver(subject, netObserver) {
    net_position_ = net_position;
    rnn_order_ = rnn_order;
  }

  // Defined out of line; presumably builds an observer for `subject` with the
  // given rnn_order (e.g. via the four-argument constructor) — confirm.
  std::unique_ptr<ObserverBase<OperatorBase>> rnnCopy(
      OperatorBase* subject,
      int rnn_order) const override;

  // Defined out of line; presumably emits this operator's profile — confirm.
  void Dump() const;

  /**
   * @return "<net_position>" or, when rnn_order_ is set,
   *         "<net_position>-<rnn_order>". std::to_string is used so the
   *         identifier does not depend on the global stream locale.
   */
  virtual std::string getId() const {
    std::string id = std::to_string(net_position_);
    if (rnn_order_ != OperatorBase::kNoNetPositionSet) {
      id += '-';
      id += std::to_string(rnn_order_);
    }
    return id;
  }

 protected:
  ProfileObserver* netObserver_;
  // Needed because this is not visible in RNN Executor. Defaulted so that
  // getId() never reads an uninitialized value when the subject is null.
  int net_position_ = OperatorBase::kNoNetPositionSet;
  int rnn_order_ = OperatorBase::kNoNetPositionSet;

 private:
  void Start() override;
  void Stop() override;
};

/**
 * Net-level profiling observer. It is an OperatorAttachingNetObserver
 * parameterized on ProfileOperatorObserver and passes `this` as the net
 * observer handed to the per-operator observers. Net-level Start()/Stop()
 * are no-ops; timing state lives in the per-operator observers.
 */
class TORCH_API ProfileObserver final : public OperatorAttachingNetObserver<
                                             ProfileOperatorObserver,
                                             ProfileObserver> {
 public:
  // NOTE(review): `this` is passed to the base before ProfileObserver is fully
  // constructed; this is only safe if the base merely stores it — confirm.
  explicit ProfileObserver(NetBase* subject)
      : OperatorAttachingNetObserver<ProfileOperatorObserver, ProfileObserver>(
            subject,
            this) {}

  // Intentionally empty at net level.
  void Start() override {}
  void Stop() override {}

 private:
  // Non-owning pointers to operator observers; not used in this header.
  std::vector<const ProfileOperatorObserver*> operator_observers_;
};

} // namespace caffe2