Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / core / observer.h

#pragma once

#include <algorithm>
#include <cstddef>
#include <exception>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "caffe2/core/logging.h"

namespace caffe2 {

/**
 *  Base class for an observer in the Observer Pattern template: derive from
 *  ObserverBase<T> to implement an Observer of a subject of type T.
 */

template <class T>
class ObserverBase {
 public:
  /// Remembers `subject`; the pointer is stored as-is and never deleted here.
  explicit ObserverBase(T* subject) : subject_(subject) {}

  /// Notification hooks. Both do nothing unless a subclass overrides them.
  virtual void Start() {}
  virtual void Stop() {}

  /// Textual description of the observer; the base class only returns a
  /// placeholder message.
  virtual std::string debugInfo() {
    return std::string("Not implemented.");
  }

  virtual ~ObserverBase() noexcept = default;

  /// The subject pointer handed to the constructor.
  T* subject() const {
    return subject_;
  }

  /// Hook for producing a copy of this observer bound to another subject.
  /// The base implementation ignores both arguments and returns an empty
  /// pointer. NOTE(review): the meaning of `rnn_order` is defined by the
  /// overriding subclasses, not by this class.
  virtual std::unique_ptr<ObserverBase<T>> rnnCopy(T* subject, int rnn_order)
      const {
    return std::unique_ptr<ObserverBase<T>>();
  }

 protected:
  T* subject_;
};

/**
 *  Inherit to make your class observable.
 */
template <class T>
class Observable {
 public:
  Observable() = default;

  Observable(Observable&&) = default;
  Observable& operator =(Observable&&) = default;

  virtual ~Observable() = default;

  C10_DISABLE_COPY_AND_ASSIGN(Observable);

  using Observer = ObserverBase<T>;

  /* Returns a reference to the observer after addition. */
  const Observer* AttachObserver(std::unique_ptr<Observer> observer) {
    CAFFE_ENFORCE(observer, "Couldn't attach a null observer.");
    std::unordered_set<const Observer*> observers;
    for (auto& ob : observers_list_) {
      observers.insert(ob.get());
    }

    const auto* observer_ptr = observer.get();
    if (observers.count(observer_ptr)) {
      return observer_ptr;
    }
    observers_list_.push_back(std::move(observer));
    UpdateCache();

    return observer_ptr;
  }

  /**
   * Returns a unique_ptr to the removed observer. If not found, return a
   * nullptr
   */
  std::unique_ptr<Observer> DetachObserver(const Observer* observer_ptr) {
    for (auto it = observers_list_.begin(); it != observers_list_.end(); ++it) {
      if (it->get() == observer_ptr) {
        auto res = std::move(*it);
        observers_list_.erase(it);
        UpdateCache();
        return res;
      }
    }
    return nullptr;
  }

  virtual size_t NumObservers() {
    return num_observers_;
  }

 private:
  inline static void StartObserver(Observer* observer) {
    try {
      observer->Start();
    } catch (const std::exception& e) {
      LOG(ERROR) << "Exception from observer: " << e.what();
    } catch (...) {
      LOG(ERROR) << "Exception from observer: unknown";
    }
  }

  inline static void StopObserver(Observer* observer) {
    try {
      observer->Stop();
    } catch (const std::exception& e) {
      LOG(ERROR) << "Exception from observer: " << e.what();
    } catch (...) {
      LOG(ERROR) << "Exception from observer: unknown";
    }
  }

  void UpdateCache() {
    num_observers_ = observers_list_.size();
    if (num_observers_ != 1) {
      // we cannot take advantage of the cache
      return;
    }
    observer_cache_ = observers_list_[0].get();
  }

 public:
  void StartAllObservers() {
    // do not access observers_list_ unless necessary
    if (num_observers_ == 0) {
      return;
    } else if (num_observers_ == 1) {
      StartObserver(observer_cache_);
    } else {
      for (auto& observer : observers_list_) {
        StartObserver(observer.get());
      }
    }
  }

  void StopAllObservers() {
    // do not access observers_list_ unless necessary
    if (num_observers_ == 0) {
      return;
    } else if (num_observers_ == 1) {
      StopObserver(observer_cache_);
    } else {
      for (auto& observer : observers_list_) {
        StopObserver(observer.get());
      }
    }
  }

 private:
  // an on-stack cache for fast iteration;
  // ideally, inside StartAllObservers and StopAllObservers,
  // we should never access observers_list_
  Observer* observer_cache_;
  size_t num_observers_ = 0;

 protected:
  std::vector<std::unique_ptr<Observer>> observers_list_;
};

} // namespace caffe2