# @package adaptive_weight
# Module caffe2.fb.python.layers.adaptive_weight
import numpy as np
from caffe2.python import core, schema
from caffe2.python.layers.layers import ModelLayer
from caffe2.python.regularizer import BoundedGradientProjection, LogBarrier
"""
Implementation of adaptive weighting: https://arxiv.org/pdf/1705.07115.pdf
"""
class AdaptiveWeight(ModelLayer):
    """Combine several per-task scalar values with learnable adaptive weights.

    Implementation of adaptive weighting from
    https://arxiv.org/pdf/1705.07115.pdf.

    Every field of ``input_record`` is reshaped to a single value x_i, and the
    layer outputs the scalar

        sum_i (weight_i * x_i + reg_i)

    where ``weight_i`` and ``reg_i`` depend on ``estimation_method``:

    * ``"log_std"``: learnable ``mu``;
      weight_i = exp(-mu_i) / 2, reg_i = mu_i / 2.
    * ``"inv_var"``: learnable ``k`` (kept positive via ``pos_optim_method``);
      weight_i = k_i / 2, reg_i = -log(k_i) / 2.

    The learnable parameter is initialized so that the initial weight_i
    equals ``weights[i]``.
    """

    def __init__(
        self,
        model,
        input_record,
        name="adaptive_weight",
        optimizer=None,
        weights=None,
        enable_diagnose=False,
        estimation_method="log_std",
        pos_optim_method="log_barrier",
        reg_lambda=0.1,
        **kwargs
    ):
        """Build the layer and create its learnable parameter.

        Args:
            model: layer model this layer belongs to.
            input_record: schema record whose fields are the per-task values
                to combine; each field is reshaped to shape [1], so it is
                expected to hold a single value.
            name: layer name.
            optimizer: optimizer passed to ``create_param`` for the learnable
                parameter (``mu`` or ``k``).
            weights: optional initial weights, one per input field. Defaults
                to uniform weights ``1 / num``. Must be positive and finite
                after conversion to float32.
            enable_diagnose: if true, also emit one blob per task weight and
                register each with ``model.add_ad_hoc_plot_blob``.
            estimation_method: ``"log_std"`` or ``"inv_var"`` (case
                insensitive); selects the ``<method>_init``,
                ``<method>_weight`` and ``<method>_reg`` methods.
            pos_optim_method: only consulted for ``"inv_var"``; either
                ``"log_barrier"`` or ``"pos_grad_proj"`` (case insensitive).
            reg_lambda: ``reg_lambda`` of the ``LogBarrier`` regularizer used
                when ``pos_optim_method == "log_barrier"``.
            **kwargs: forwarded to ``ModelLayer.__init__``.

        Raises:
            ValueError: if ``input_record`` has no fields, or ``weights`` has
                the wrong length or contains non-positive / non-finite values.
            TypeError: if ``estimation_method == "inv_var"`` and
                ``pos_optim_method`` is unknown.
            AttributeError: if ``estimation_method`` has no matching
                ``*_init`` / ``*_weight`` / ``*_reg`` methods.
        """
        super(AdaptiveWeight, self).__init__(model, name, input_record, **kwargs)
        self.output_schema = schema.Scalar(
            np.float32, self.get_next_blob_reference("adaptive_weight")
        )
        self.data = self.input_record.field_blobs()
        self.num = len(self.data)
        if self.num == 0:
            raise ValueError("AdaptiveWeight requires at least one input field")
        self.optimizer = optimizer
        if weights is None:
            weights = [1. / self.num] * self.num
        elif len(weights) != self.num:
            # Explicit check instead of `assert`, which is stripped under -O.
            raise ValueError(
                "expected {} initial weights (one per input field), got {}".format(
                    self.num, len(weights)
                )
            )
        self.weights = np.array(weights).astype(np.float32)
        # Validate after the float32 cast: a tiny positive weight can underflow
        # to 0 and NaN/inf would slip through a plain `min(...) > 0`, all of
        # which turn into inf/NaN parameters in the *_init methods below.
        if not np.all(np.isfinite(self.weights) & (self.weights > 0)):
            raise ValueError(
                "initial weights must be positive and finite, got {}".format(
                    self.weights
                )
            )
        self.estimation_method = str(estimation_method).lower()
        # Only consulted by inv_var_init, where k must stay positive: either a
        # LogBarrier regularizer ("log_barrier") or a BoundedGradientProjection
        # with lb=0, left_open=True ("pos_grad_proj").
        self.pos_optim_method = str(pos_optim_method).lower()
        self.reg_lambda = float(reg_lambda)
        self.enable_diagnose = enable_diagnose
        # Dispatch by name: e.g. "log_std" selects log_std_init,
        # log_std_weight and log_std_reg.
        self.init_func = getattr(self, self.estimation_method + "_init")
        self.weight_func = getattr(self, self.estimation_method + "_weight")
        self.reg_func = getattr(self, self.estimation_method + "_reg")
        self.init_func()
        if self.enable_diagnose:
            self.weight_i = [
                self.get_next_blob_reference("adaptive_weight_%d" % i)
                for i in range(self.num)
            ]
            for weight_blob in self.weight_i:
                self.model.add_ad_hoc_plot_blob(weight_blob)

    def concat_data(self, net):
        """Reshape every input field to shape [1] and concatenate them.

        Args:
            net: net to add the Reshape / Concat operators to.

        Returns:
            Blob of shape [num] whose i-th entry is the i-th input field.
        """
        reshaped = [net.NextScopedBlob("reshaped_data_%d" % i) for i in range(self.num)]
        # coerce shape for single real values
        for i in range(self.num):
            net.Reshape(
                [self.data[i]],
                [reshaped[i], net.NextScopedBlob("new_shape_%d" % i)],
                shape=[1],
            )
        concated = net.NextScopedBlob("concated_data")
        net.Concat(
            reshaped, [concated, net.NextScopedBlob("concated_new_shape")], axis=0
        )
        return concated

    def log_std_init(self):
        """
        mu = 2 log sigma, sigma = standard deviation
        per task objective:
        min 1 / 2 / e^mu X + mu / 2

        Initializes mu = log(1 / (2 w)) so that the initial weight
        e^-mu / 2 equals w.
        """
        values = np.log(1. / 2. / self.weights)
        initializer = (
            "GivenTensorFill",
            {"values": values, "dtype": core.DataType.FLOAT},
        )
        self.mu = self.create_param(
            param_name="mu",
            shape=[self.num],
            initializer=initializer,
            optimizer=self.optimizer,
        )

    def log_std_weight(self, x, net, weight):
        """
        min 1 / 2 / e^mu X + mu / 2

        Writes weight = exp(-mu) / 2 into ``weight``. ``x`` is unused; it is
        kept so every *_weight method shares one signature.
        """
        mu_neg = net.NextScopedBlob("mu_neg")
        net.Negative(self.mu, mu_neg)
        mu_neg_exp = net.NextScopedBlob("mu_neg_exp")
        net.Exp(mu_neg, mu_neg_exp)
        net.Scale(mu_neg_exp, weight, scale=0.5)

    def log_std_reg(self, net, reg):
        """Write the log_std regularization term mu / 2 into ``reg``."""
        net.Scale(self.mu, reg, scale=0.5)

    def inv_var_init(self):
        """
        k = 1 / variance
        per task objective:
        min 1 / 2 * k X - 1 / 2 * log k

        Initializes k = 2 w so that the initial weight k / 2 equals w, and
        attaches a regularizer chosen by ``pos_optim_method`` to keep k
        positive.

        Raises:
            TypeError: if ``pos_optim_method`` is unknown.
        """
        values = 2. * self.weights
        initializer = (
            "GivenTensorFill",
            {"values": values, "dtype": core.DataType.FLOAT},
        )
        if self.pos_optim_method == "log_barrier":
            regularizer = LogBarrier(reg_lambda=self.reg_lambda)
        elif self.pos_optim_method == "pos_grad_proj":
            regularizer = BoundedGradientProjection(lb=0, left_open=True)
        else:
            raise TypeError(
                "unknown positivity optimization method: {}".format(
                    self.pos_optim_method
                )
            )
        self.k = self.create_param(
            param_name="k",
            shape=[self.num],
            initializer=initializer,
            optimizer=self.optimizer,
            regularizer=regularizer,
        )

    def inv_var_weight(self, x, net, weight):
        """Write weight = k / 2 into ``weight`` (``x`` is unused)."""
        net.Scale(self.k, weight, scale=0.5)

    def inv_var_reg(self, net, reg):
        """Write the inv_var regularization term -log(k) / 2 into ``reg``."""
        log_k = net.NextScopedBlob("log_k")
        net.Log(self.k, log_k)
        net.Scale(log_k, reg, scale=-0.5)

    def _add_ops_impl(self, net, enable_diagnose):
        """Emit sum(weight * x + reg) into the output scalar blob.

        Args:
            net: net to add operators to.
            enable_diagnose: if true, also slice each task weight into its
                ``self.weight_i`` blob (created in ``__init__`` only when
                ``self.enable_diagnose`` is set).
        """
        x = self.concat_data(net)
        weight = net.NextScopedBlob("weight")
        reg = net.NextScopedBlob("reg")
        weighted_x = net.NextScopedBlob("weighted_x")
        weighted_x_add_reg = net.NextScopedBlob("weighted_x_add_reg")
        self.weight_func(x, net, weight)
        self.reg_func(net, reg)
        net.Mul([weight, x], weighted_x)
        net.Add([weighted_x, reg], weighted_x_add_reg)
        net.SumElements(weighted_x_add_reg, self.output_schema())
        if enable_diagnose:
            for i in range(self.num):
                net.Slice(weight, self.weight_i[i], starts=[i], ends=[i + 1])

    def add_ops(self, net):
        """Add this layer's operators to ``net``, honoring ``enable_diagnose``."""
        self._add_ops_impl(net, self.enable_diagnose)