## @package beam_search
# Module caffe2.python.models.seq2seq.beam_search
from collections import namedtuple
from caffe2.python import core
import caffe2.python.models.seq2seq.seq2seq_util as seq2seq_util
from caffe2.python.models.seq2seq.seq2seq_model_helper import Seq2SeqModelHelper
class BeamSearchForwardOnly(object):
"""
Class generalizing forward beam search for seq2seq models.
Also provides types to specify the recurrent structure of decoding:
StateConfig:
initial_value: blob providing value of state at the first step
state_prev_link: LinkConfig describing how recurrent step receives
input from global state blob in each step
state_link: LinkConfig describing how step writes (produces new state)
to global state blob in each step
LinkConfig:
blob: blob connecting global state blob to step application
offset: offset from beginning of global blob for link in time dimension
window: width of global blob to read/write in time dimension
"""
LinkConfig = namedtuple('LinkConfig', ['blob', 'offset', 'window'])
StateConfig = namedtuple(
'StateConfig',
['initial_value', 'state_prev_link', 'state_link'],
)
def __init__(
    self,
    beam_size,
    model,
    eos_token_id,
    go_token_id=seq2seq_util.GO_ID,
    post_eos_penalty=None,
):
    """Set up the step model and its recurrent inputs for beam search.

    Creates a ``Seq2SeqModelHelper`` named 'step_model' that uses ``model``
    as its ``param_model``, registers the blobs 'timestep', 'scores_t_prev',
    'tokens_t_prev', 'hypo_t_prev' and 'attention_t_prev' as external inputs
    of the step net, and adds ops to the step net that cast the previous
    tokens to INT32 and reshape them to ``[1, -1]``.

    Args:
        beam_size: stored as ``self.beam_size``.
        model: stored as ``self.model`` and handed to the step model helper
            as ``param_model``.
        eos_token_id: stored as ``self.eos_token_id``.
        go_token_id: stored as ``self.go_token_id``; defaults to
            ``seq2seq_util.GO_ID``.
        post_eos_penalty: stored as ``self.post_eos_penalty``; defaults to
            ``None``.
    """
    # Plain configuration values kept on the instance.
    self.beam_size = beam_size
    self.go_token_id = go_token_id
    self.eos_token_id = eos_token_id
    self.post_eos_penalty = post_eos_penalty

    # The per-step helper is built on top of the caller's model.
    self.model = model
    self.step_model = Seq2SeqModelHelper(
        name='step_model',
        param_model=model,
    )
    step_net = self.step_model.net

    # Blobs fed into every step from outside the step net.
    recurrent_inputs = step_net.AddExternalInputs(
        'timestep',
        'scores_t_prev',
        'tokens_t_prev',
        'hypo_t_prev',
        'attention_t_prev',
    )
    (
        self.timestep,
        self.scores_t_prev,
        self.tokens_t_prev,
        self.hypo_t_prev,
        self.attention_t_prev,
    ) = recurrent_inputs

    # Previous tokens as INT32, then reshaped to [1, -1]. The Reshape names
    # the cast blob as both its input and its first output; its second
    # output ('input_t_int32_old_shape') is not kept.
    prev_tokens_int32 = step_net.Cast(
        self.tokens_t_prev,
        'tokens_t_prev_int32',
        to=core.DataType.INT32,
    )
    flattened_tokens, _old_shape = step_net.Reshape(
        [prev_tokens_int32],
        [prev_tokens_int32, 'input_t_int32_old_shape'],
        shape=[1, -1],
    )
    self.tokens_t_prev_int32_flattened = flattened_tokens
def get_step_model(self):
    """Return the object stored in ``self.step_model``.

    ``__init__`` sets this attribute to the ``Seq2SeqModelHelper`` it
    creates under the name 'step_model'.
    """
    return self.step_model
def get_previous_tokens(self):
    """Return the blob stored in ``self.tokens_t_prev_int32_flattened``.

    ``__init__`` sets this attribute to the first output of the Reshape
    (``shape=[1, -1]``) applied to the INT32 cast of 'tokens_t_prev'.
    """
    return self.tokens_t_prev_int32_flattened
def get_timestep(self):
    """Return the blob stored in ``self.timestep``.

    ``__init__`` sets this attribute to the 'timestep' external input
    registered on the step net.
    """
    return self.timestep
# TODO: make attentions a generic state
# data_dependencies is a list of blobs that the operator should wait for
# before beginning execution. This ensures that ops are run in the correct
# order when, for example, the RecurrentNetwork op is embedded in a DAGNet.
def apply(
self,
inputs,
length,
log_probs,
attentions,
state_configs,
data_dependencies,
word_rewards=None,
possible_translation_tokens=None,
go_token_id=None,
):
ZERO = self.model.param_init_net.ConstantFill(
[],
'ZERO',
shape=[1],
value=0,
dtype=core.DataType.INT32,
)
on_initial_step = self.step_model.net.EQ(
[ZERO, self.timestep],
'on_initial_step',
)
if self.post_eos_penalty is not None:
eos_token = self.model.param_init_net.ConstantFill(
[],
'eos_token',
shape=[self.beam_size],
value=self.eos_token_id,
dtype=core.DataType.INT32,
)
finished_penalty = self.model.param_init_net.ConstantFill(
[],
'finished_penalty',
shape=[1],
value=float(self.post_eos_penalty),
dtype=core.DataType.FLOAT,
)
ZERO_FLOAT = self.model.param_init_net.ConstantFill(
[],
'ZERO_FLOAT',
shape=[1],
value=0.0,
dtype=core.DataType.FLOAT,
)
finished_penalty = self.step_model.net.Conditional(
[on_initial_step, ZERO_FLOAT, finished_penalty],
'possible_finished_penalty',
)
tokens_t_flat = self.step_model.net.FlattenToVec(
self.tokens_t_prev,
'tokens_t_flat',
)
tokens_t_flat_int = self.step_model.net.Cast(
tokens_t_flat,
'tokens_t_flat_int',
to=core.DataType.INT32,
)
predecessor_is_eos = self.step_model.net.EQ(
[tokens_t_flat_int, eos_token],
'predecessor_is_eos',
)
predecessor_is_eos_float = self.step_model.net.Cast(
predecessor_is_eos,
'predecessor_is_eos_float',
to=core.DataType.FLOAT,
)
predecessor_is_eos_penalty = self.step_model.net.Mul(
[predecessor_is_eos_float, finished_penalty],
'predecessor_is_eos_penalty',
broadcast=1,
)
log_probs = self.step_model.net.Add(
[log_probs, predecessor_is_eos_penalty],
'log_probs_penalized',
broadcast=1,
axis=0,
)
# [beam_size, beam_size]
best_scores_per_hypo, best_tokens_per_hypo = self.step_model.net.TopK(
log_probs,
['best_scores_per_hypo', 'best_tokens_per_hypo_indices'],
k=self.beam_size,
)
if possible_translation_tokens:
# [beam_size, beam_size]
best_tokens_per_hypo = self.step_model.net.Gather(
[possible_translation_tokens, best_tokens_per_hypo],
['best_tokens_per_hypo']
)
# [beam_size]
scores_t_prev_squeezed, _ = self.step_model.net.Reshape(
self.scores_t_prev,
['scores_t_prev_squeezed', 'scores_t_prev_old_shape'],
shape=[self.beam_size],
)
# [beam_size, beam_size]
output_scores = self.step_model.net.Add(
[best_scores_per_hypo, scores_t_prev_squeezed],
'output_scores',
broadcast=1,
axis=0,
)
if word_rewards is not None:
# [beam_size, beam_size]
word_rewards_for_best_tokens_per_hypo = self.step_model.net.Gather(
[word_rewards, best_tokens_per_hypo],
'word_rewards_for_best_tokens_per_hypo',
)
# [beam_size, beam_size]
output_scores = self.step_model.net.Add(
[output_scores, word_rewards_for_best_tokens_per_hypo],
'output_scores',
)
# [beam_size * beam_size]
output_scores_flattened, _ = self.step_model.net.Reshape(
[output_scores],
[output_scores, 'output_scores_old_shape'],
shape=[-1],
)
MINUS_ONE_INT32 = self.model.param_init_net.ConstantFill(
[],
'MINUS_ONE_INT32',
value=-1,
shape=[1],
dtype=core.DataType.INT32,
)
BEAM_SIZE = self.model.param_init_net.ConstantFill(
[],
'beam_size',
shape=[1],
value=self.beam_size,
dtype=core.DataType.INT32,
)
# current_beam_size (predecessor states from previous step)
# is 1 on first step (so we just need beam_size scores),
# and beam_size subsequently (so we need all beam_size * beam_size
# scores)
slice_end = self.step_model.net.Conditional(
[on_initial_step, BEAM_SIZE, MINUS_ONE_INT32],
['slice_end'],
)
# [current_beam_size * beam_size]
output_scores_flattened_slice = self.step_model.net.Slice(
[output_scores_flattened, ZERO, slice_end],
'output_scores_flattened_slice',
)
# [1, current_beam_size * beam_size]
output_scores_flattened_slice, _ = self.step_model.net.Reshape(
output_scores_flattened_slice,
[
output_scores_flattened_slice,
'output_scores_flattened_slice_old_shape',
],
shape=[1, -1],
)
# [1, beam_size]
scores_t, best_indices = self.step_model.net.TopK(
output_scores_flattened_slice,
['scores_t', 'best_indices'],
k=self.beam_size,
)
BEAM_SIZE_64 = self.model.param_init_net.Cast(
BEAM_SIZE,
'BEAM_SIZE_64',
to=core.DataType.INT64,
)
# [1, beam_size]
hypo_t_int32 = self.step_model.net.Div(
[best_indices, BEAM_SIZE_64],
'hypo_t_int32',
broadcast=1,
)
hypo_t = self.step_model.net.Cast(
hypo_t_int32,
'hypo_t',
to=core.DataType.FLOAT,
)
# [beam_size, encoder_length, 1]
attention_t = self.step_model.net.Gather(
[attentions, hypo_t_int32],
'attention_t',
)
# [1, beam_size, encoder_length]
attention_t, _ = self.step_model.net.Reshape(
attention_t,
[attention_t, 'attention_t_old_shape'],
shape=[1, self.beam_size, -1],
)
# [beam_size * beam_size]
best_tokens_per_hypo_flatten, _ = self.step_model.net.Reshape(
best_tokens_per_hypo,
[
'best_tokens_per_hypo_flatten',
'best_tokens_per_hypo_old_shape',
],
shape=[-1],
)
tokens_t_int32 = self.step_model.net.Gather(
[best_tokens_per_hypo_flatten, best_indices],
'tokens_t_int32',
)
tokens_t = self.step_model.net.Cast(
tokens_t_int32,
'tokens_t',
to=core.DataType.FLOAT,
)
def choose_state_per_hypo(state_config):
state_flattened, _ = self.step_model.net.Reshape(
state_config.state_link.blob,
[
state_config.state_link.blob,
state_config.state_link.blob + '_old_shape',
],
shape=[self.beam_size, -1],
)
state_chosen_per_hypo = self.step_model.net.Gather(
[state_flattened, hypo_t_int32],
str(state_config.state_link.blob) + '_chosen_per_hypo',
)
return self.StateConfig(
initial_value=state_config.initial_value,
state_prev_link=state_config.state_prev_link,
state_link=self.LinkConfig(
blob=state_chosen_per_hypo,
offset=state_config.state_link.offset,
window=state_config.state_link.window,
)
)
state_configs = [choose_state_per_hypo(c) for c in state_configs]
initial_scores = self.model.param_init_net.ConstantFill(
[],
'initial_scores',
shape=[1],
value=0.0,
dtype=core.DataType.FLOAT,
)
if go_token_id:
initial_tokens = self.model.net.Copy(
[go_token_id],
'initial_tokens',
Loading ...