diff --git a/deepconsensus/models/attention_layer.py b/deepconsensus/models/attention_layer.py
index 61d8ad1..a57cee4 100644
--- a/deepconsensus/models/attention_layer.py
+++ b/deepconsensus/models/attention_layer.py
@@ -27,14 +27,15 @@
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 """Implementation of multiheaded attention and self-attention layers."""
 import math
-
+from typing import Any, Dict, Optional, Union, Iterable
 import tensorflow as tf


 class Attention(tf.keras.layers.Layer):
   """Multi-headed attention layer."""

-  def __init__(self, hidden_size, num_heads, attention_dropout):
+  def __init__(self, hidden_size: int, num_heads: int,
+               attention_dropout: float):
     """Initialize Attention.

     Args:
@@ -52,7 +53,7 @@ def __init__(self, hidden_size, num_heads, attention_dropout):
     self.num_heads = num_heads
     self.attention_dropout = attention_dropout

-  def build(self, input_shape):
+  def build(self, input_shape: Union[tf.TensorShape, Iterable[tf.TensorShape]]):
     """Builds the layer."""
     # Layers for linearly projecting the queries, keys, and values.
     size_per_head = self.hidden_size // self.num_heads
@@ -91,7 +92,7 @@ def _glorot_initializer(fan_in, fan_out):
         name="output_transform")
     super(Attention, self).build(input_shape)

-  def get_config(self):
+  def get_config(self) -> Dict[str, Any]:
     return {
         "hidden_size": self.hidden_size,
         "num_heads": self.num_heads,
@@ -99,12 +100,12 @@ def get_config(self):
     }

   def call(self,
-           query_input,
-           source_input,
-           bias,
-           training,
-           cache=None,
-           decode_loop_step=None):
+           query_input: tf.Tensor,
+           source_input: tf.Tensor,
+           bias: tf.Tensor,
+           training: bool,
+           cache: Optional[Dict[str, tf.Tensor]] = None,
+           decode_loop_step: Optional[int] = None) -> Dict[str, tf.Tensor]:
     """Apply attention mechanism to query_input and source_input.

     Args:
@@ -124,7 +125,12 @@ def call(self,
         for autoregressive inference on TPU.

     Returns:
-      Attention layer output with shape [batch_size, length_query, hidden_size]
+      Dictionary with the following (key:value) pairs:
+        "main_output": Attention layer output with shape [batch_size,
+          length_query, hidden_size]. Used as input to the feed_forward_network.
+        "attention_scores": Attention map weights (after softmax) with shape
+          [batch_size, num_heads, length_query, length_query] - auxiliary output.
+
     """
     # Linearly project the query, key and value using different learned
     # projections. Splitting heads is automatically done during the linear
@@ -173,17 +179,19 @@ def call(self,
     # Run the outputs through another linear projection layer. Recombining heads
     # is automatically done --> [batch_size, length, hidden_size]
     attention_output = self.output_dense_layer(attention_output)
-    return attention_output
+
+    layer_output = dict(main_output=attention_output, attention_scores=weights)
+    return layer_output


 class SelfAttention(Attention):
   """Multiheaded self-attention layer."""

   def call(self,
-           query_input,
-           bias,
-           training,
-           cache=None,
-           decode_loop_step=None):
+           query_input: tf.Tensor,
+           bias: tf.Tensor,
+           training: bool,
+           cache: Optional[Dict[str, tf.Tensor]] = None,
+           decode_loop_step: Optional[int] = None) -> Dict[str, tf.Tensor]:
     return super(SelfAttention, self).call(query_input, query_input, bias,
                                            training, cache, decode_loop_step)
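Reviewer note, not part of the patch: a minimal sketch of how a caller now consumes the dictionary returned by Attention.call / SelfAttention.call. The import path and dictionary keys come from the diff above; the hyperparameters and tensor shapes are illustrative only.

import tensorflow as tf

from deepconsensus.models import attention_layer

# Illustrative sizes only; hidden_size must stay divisible by num_heads.
layer = attention_layer.SelfAttention(
    hidden_size=64, num_heads=2, attention_dropout=0.1)
x = tf.random.normal([8, 100, 64])  # [batch_size, length, hidden_size]
bias = tf.zeros([8, 1, 1, 100])     # additive attention bias; zeros mask nothing
outputs = layer(x, bias, training=False)
main = outputs["main_output"]         # [8, 100, 64]; the tensor the layer used to return
scores = outputs["attention_scores"]  # [8, 2, 100, 100]; new auxiliary output

Callers that previously used the returned tensor directly must now index "main_output", which is exactly what PrePostProcessingWrapper does in the next file.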
diff --git a/deepconsensus/models/encoder_stack.py b/deepconsensus/models/encoder_stack.py
index 48df78a..9c27186 100644
--- a/deepconsensus/models/encoder_stack.py
+++ b/deepconsensus/models/encoder_stack.py
@@ -31,6 +31,9 @@
 Transformer model code source:
 https://github.com/tensorflow/tensor2tensor
 """
+from typing import Any, Dict, Union, Iterable
+
+import ml_collections
 import tensorflow as tf

 from deepconsensus.models import attention_layer
@@ -40,37 +43,40 @@ class PrePostProcessingWrapper(tf.keras.layers.Layer):
   """Wrapper class that applies layer pre-processing and post-processing."""

-  def __init__(self, layer, params):
+  def __init__(self, layer: tf.keras.layers.Layer,
+               params: ml_collections.ConfigDict):
     super(PrePostProcessingWrapper, self).__init__()
     self.layer = layer
     self.params = params
     self.postprocess_dropout = params["layer_postprocess_dropout"]

-  def build(self, input_shape):
+  def build(self, input_shape: Union[tf.TensorShape, Iterable[tf.TensorShape]]):
     # Create normalization layer
     self.layer_norm = tf.keras.layers.LayerNormalization(
         epsilon=1e-6, dtype="float32")
     super(PrePostProcessingWrapper, self).build(input_shape)

-  def get_config(self):
+  def get_config(self) -> Dict[str, Any]:
     return {
         "params": self.params,
     }

-  def call(self, x, *args, **kwargs):
+  def call(self, x: tf.Tensor, *args, **kwargs) -> Dict[str, tf.Tensor]:
     """Calls wrapped layer with same parameters."""
     # Preprocessing: apply layer normalization
     training = kwargs["training"]
     y = self.layer_norm(x)

-    # Get layer output
-    y = self.layer(y, *args, **kwargs)
+    # Get layer output.
+    layer_output = self.layer(y, *args, **kwargs)
+    y = layer_output["main_output"]

     # Postprocessing: apply dropout and residual connection
     if training:
       y = tf.nn.dropout(y, rate=self.postprocess_dropout)
-    return x + y
+    layer_output["main_output"] = x + y
+    return layer_output


 class EncoderStack(tf.keras.layers.Layer):
@@ -82,12 +88,12 @@ class EncoderStack(tf.keras.layers.Layer):
     2. Feedforward network (which is 2 fully-connected layers)
   """

-  def __init__(self, params):
+  def __init__(self, params: ml_collections.ConfigDict):
     super(EncoderStack, self).__init__()
     self.params = params
     self.layers = []

-  def build(self, input_shape):
+  def build(self, input_shape: Union[tf.TensorShape, Iterable[tf.TensorShape]]):
     """Builds the encoder stack."""
     params = self.params
     for _ in range(params["num_hidden_layers"]):
@@ -108,12 +114,13 @@ def build(self, input_shape):
         epsilon=1e-6, dtype="float32")
     super(EncoderStack, self).build(input_shape)

-  def get_config(self):
+  def get_config(self) -> Dict[str, Any]:
     return {
         "params": self.params,
     }

-  def call(self, encoder_inputs, attention_bias, inputs_padding, training):
+  def call(self, encoder_inputs: tf.Tensor, attention_bias: tf.Tensor,
+           inputs_padding: tf.Tensor, training: bool) -> Dict[str, tf.Tensor]:
     """Return the output of the encoder layer stacks.

     Args:
@@ -125,9 +132,19 @@ def call(self, encoder_inputs, attention_bias, inputs_padding, training):
       training: boolean, whether in training mode or not.

     Returns:
-      Output of encoder layer stack.
-      float32 tensor with shape [batch_size, input_length, hidden_size]
+      Dictionary with the following (key:value) pairs:
+        "self_attention_layer_{n}": Attention layer output for every layer in
+          the encoder stack with shape [batch_size, input_length, hidden_size].
+        "attention_scores_{n}": Attention map for every layer in the
+          encoder stack with shape [batch_size, num_heads, input_length,
+          input_length].
+        "ffn_layer_{n}": Feedforward network output for every layer in the
+          encoder stack with shape [batch_size, input_length, hidden_size].
+        "final_output": Final output of the entire encoder stack after
+          normalization with shape [batch_size, input_length, hidden_size]. Used
+          as input to the fully-connected layer which outputs logits.
     """
+    outputs_dict = dict()
     for n, layer in enumerate(self.layers):
       # Run inputs through the sublayers.
       self_attention_layer = layer[0]
@@ -135,10 +152,20 @@ def call(self, encoder_inputs, attention_bias, inputs_padding, training):

       with tf.name_scope("layer_%d" % n):
         with tf.name_scope("self_attention"):
-          encoder_inputs = self_attention_layer(
+          layer_outputs = self_attention_layer(
               encoder_inputs, attention_bias, training=training)
+          encoder_inputs = layer_outputs["main_output"]
+          # Add attention layer outputs and attention map scores to outputs.
+          outputs_dict[f"self_attention_layer_{n}"] = encoder_inputs
+          outputs_dict[f"attention_scores_{n}"] = layer_outputs[
+              "attention_scores"]
         with tf.name_scope("ffn"):
-          encoder_inputs = feed_forward_network(
+          layer_outputs = feed_forward_network(
               encoder_inputs, training=training)
+          encoder_inputs = layer_outputs["main_output"]
+          # Add output of the feedforward network to outputs.
+          outputs_dict[f"ffn_layer_{n}"] = encoder_inputs

-    return self.output_normalization(encoder_inputs)
+    # Add normalized final output of the entire encoder stack to outputs.
+    outputs_dict["final_output"] = self.output_normalization(encoder_inputs)
+    return outputs_dict
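Reviewer note, not part of the patch: a sketch of iterating the per-layer entries that EncoderStack.call now returns. Only "num_hidden_layers" and "layer_postprocess_dropout" are visible in this diff; the other config keys are assumed from the standard Transformer hyperparameters and may not match the project's exact ConfigDict.

import ml_collections
import tensorflow as tf

from deepconsensus.models import encoder_stack

params = ml_collections.ConfigDict({
    "num_hidden_layers": 2,            # shown in the diff
    "layer_postprocess_dropout": 0.1,  # shown in the diff
    # Assumed standard Transformer keys; values are illustrative.
    "hidden_size": 64,
    "num_heads": 2,
    "attention_dropout": 0.1,
    "filter_size": 128,
    "relu_dropout": 0.1,
})
stack = encoder_stack.EncoderStack(params)
x = tf.random.normal([4, 50, 64])  # [batch_size, input_length, hidden_size]
bias = tf.zeros([4, 1, 1, 50])
outputs = stack(x, bias, inputs_padding=None, training=False)  # padding is unused
final = outputs["final_output"]    # [4, 50, 64], later fed to the logits layer
for n in range(params["num_hidden_layers"]):
  attn_map = outputs[f"attention_scores_{n}"]  # [4, 2, 50, 50]
  ffn_out = outputs[f"ffn_layer_{n}"]          # [4, 50, 64]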
"""Implementation of fully connected network.""" +from typing import Any, Dict, Union, Iterable import tensorflow as tf class FeedForwardNetwork(tf.keras.layers.Layer): """Fully connected feedforward network.""" - def __init__(self, hidden_size, filter_size, relu_dropout): + def __init__(self, hidden_size: int, filter_size: int, relu_dropout: float): """Initialize FeedForwardNetwork. Args: @@ -46,7 +47,7 @@ def __init__(self, hidden_size, filter_size, relu_dropout): self.filter_size = filter_size self.relu_dropout = relu_dropout - def build(self, input_shape): + def build(self, input_shape: Union[tf.TensorShape, Iterable[tf.TensorShape]]): self.filter_dense_layer = tf.keras.layers.Dense( self.filter_size, use_bias=True, @@ -56,14 +57,14 @@ def build(self, input_shape): self.hidden_size, use_bias=True, name="output_layer") super(FeedForwardNetwork, self).build(input_shape) - def get_config(self): + def get_config(self) -> Dict[str, Any]: return { "hidden_size": self.hidden_size, "filter_size": self.filter_size, "relu_dropout": self.relu_dropout, } - def call(self, x, training): + def call(self, x: tf.Tensor, training: bool) -> Dict[str, tf.Tensor]: """Return outputs of the feedforward network. Args: @@ -71,8 +72,9 @@ def call(self, x, training): training: boolean, whether in training mode or not. Returns: - Output of the feedforward network. - tensor with shape [batch_size, length, hidden_size] + Dictionary with the following (key:value) pairs: + "main_output": Output of the feedforward network with shape [batch_size, + length, hidden_size]. Used as input to the next encoder layer. """ # Retrieve dynamically known shapes @@ -80,5 +82,4 @@ def call(self, x, training): if training: output = tf.nn.dropout(output, rate=self.relu_dropout) output = self.output_dense_layer(output) - - return output + return dict(main_output=output) diff --git a/deepconsensus/models/model_distillation.py b/deepconsensus/models/model_distillation.py index 69b5556..1d9330a 100644 --- a/deepconsensus/models/model_distillation.py +++ b/deepconsensus/models/model_distillation.py @@ -218,10 +218,14 @@ def train_step(inputs): """Training StepFn.""" features, labels = inputs # Get logits from the teacher model. - teacher_logits = teacher_model.get_logits(features, training=False) + teacher_intermediate_outputs_dict = teacher_model.get_intermediate_outputs( + features, training=False) + teacher_logits = teacher_intermediate_outputs_dict['logits'] with tf.GradientTape() as tape: - student_logits = model.get_logits(features, training=True) + student_intermediate_outputs_dict = model.get_intermediate_outputs( + features, training=True) + student_logits = student_intermediate_outputs_dict['logits'] student_preds = tf.nn.softmax(student_logits) train_losses_dict = compute_loss(labels, student_preds, student_logits, teacher_logits) @@ -240,9 +244,13 @@ def eval_step(inputs): """Eval StepFn.""" features, labels = inputs # Get logits from the teacher model. 
diff --git a/deepconsensus/models/model_distillation.py b/deepconsensus/models/model_distillation.py
index 69b5556..1d9330a 100644
--- a/deepconsensus/models/model_distillation.py
+++ b/deepconsensus/models/model_distillation.py
@@ -218,10 +218,14 @@ def train_step(inputs):
     """Training StepFn."""
     features, labels = inputs
     # Get logits from the teacher model.
-    teacher_logits = teacher_model.get_logits(features, training=False)
+    teacher_intermediate_outputs_dict = teacher_model.get_intermediate_outputs(
+        features, training=False)
+    teacher_logits = teacher_intermediate_outputs_dict['logits']

     with tf.GradientTape() as tape:
-      student_logits = model.get_logits(features, training=True)
+      student_intermediate_outputs_dict = model.get_intermediate_outputs(
+          features, training=True)
+      student_logits = student_intermediate_outputs_dict['logits']
       student_preds = tf.nn.softmax(student_logits)
       train_losses_dict = compute_loss(labels, student_preds, student_logits,
                                        teacher_logits)
@@ -240,9 +244,13 @@ def eval_step(inputs):
     """Eval StepFn."""
     features, labels = inputs
     # Get logits from the teacher model.
-    teacher_logits = teacher_model.get_logits(features, training=False)
+    teacher_intermediate_outputs_dict = teacher_model.get_intermediate_outputs(
+        features, training=False)
+    teacher_logits = teacher_intermediate_outputs_dict['logits']

-    student_logits = model.get_logits(features, training=False)
+    student_intermediate_outputs_dict = model.get_intermediate_outputs(
+        features, training=False)
+    student_logits = student_intermediate_outputs_dict['logits']
     student_preds = tf.nn.softmax(student_logits)
     eval_losses_dict = compute_loss(labels, student_preds, student_logits,
                                     teacher_logits)
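Reviewer note, not part of the patch: the distillation code above only reads 'logits', but the same dictionary exposes every key documented in the networks.py diff that follows. A sketch, reusing the `model` and `features` names from train_step/eval_step:

outputs = model.get_intermediate_outputs(features, training=False)
logits = outputs['logits']                       # [batch_size, input_length, vocab_size]
final_hidden = outputs['final_output']           # [batch_size, input_length, hidden_size]
attention_map_0 = outputs['attention_scores_0']  # [batch_size, num_heads, length, length]
preds = tf.nn.softmax(logits)                    # matches what call() computes in networks.py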
diff --git a/deepconsensus/models/networks.py b/deepconsensus/models/networks.py
index c1dc5ba..92496d9 100644
--- a/deepconsensus/models/networks.py
+++ b/deepconsensus/models/networks.py
@@ -208,12 +208,15 @@ def call(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
       each position in the sequence.

     """
     with tf.name_scope('Transformer'):
-      logits = self.get_logits(inputs, training=training)
+      intermediate_outputs_dict = self.get_intermediate_outputs(
+          inputs, training=training)
+      logits = intermediate_outputs_dict['logits']
       preds = self.softmax(logits)
       return preds

-  def get_logits(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
-    """Get logits of the model.
+  def get_intermediate_outputs(self, inputs: tf.Tensor,
+                               training: bool) -> Dict[str, tf.Tensor]:
+    """Get intermediate outputs of the model.

     Args:
@@ -221,8 +224,19 @@ def get_logits(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
       inputs: tensor of shape (batch_size, hidden_size, input_length
       training: boolean, whether in training mode or not.

     Returns:
-      Output logits over the vocabulary at each position in the sequence. The
-      output tensor is of shape (batch_size, length, vocab_size).
+      Dictionary with the following (key:value) pairs:
+        "self_attention_layer_{n}": Attention layer output for every layer in
+          the encoder stack with shape [batch_size, input_length, hidden_size].
+        "attention_scores_{n}": Attention map for every layer in the
+          encoder stack with shape [batch_size, num_heads, input_length,
+          input_length].
+        "ffn_layer_{n}": Feedforward network output for every layer in the
+          encoder stack with shape [batch_size, input_length, hidden_size].
+        "final_output": Final output of the entire encoder stack after
+          normalization with shape [batch_size, input_length, hidden_size]. Used
+          as input to the fully-connected layer which outputs logits.
+        "logits": Logits over the vocabulary at each position in the sequence
+          with shape [batch_size, input_length, vocab_size].
     """
     # Get rid of the channel dimension as we only have one channel.
@@ -239,12 +253,13 @@ def get_logits(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
     all_zeros = tf.reduce_sum(tf.zeros_like(inputs), -1)
     attention_bias = tf.expand_dims(tf.expand_dims(all_zeros, 1), 1)

-    # Run inputs through the encoder. Encoder returns logits from dense layer.
-    encoder_outputs = self.encode(inputs, attention_bias, training)
-    return encoder_outputs
+    # Run inputs through the encoder. Encoder returns a dictionary of
+    # logits from dense layer as well as other intermediate model outputs.
+    intermediate_outputs_dict = self.encode(inputs, attention_bias, training)
+    return intermediate_outputs_dict

   def encode(self, inputs: tf.Tensor, attention_bias: tf.Tensor,
-             training: bool) -> tf.Tensor:
+             training: bool) -> Dict[str, tf.Tensor]:
     """Runs the input through Encoder stack and problem-specific layers."""

     with tf.name_scope('encode'):
@@ -287,14 +302,19 @@ def encode(self, inputs: tf.Tensor, attention_bias: tf.Tensor,
           encoder_inputs, rate=self.params['layer_postprocess_dropout'])

       # Pass inputs through the encoder. As mentioned above, `inputs_padding` is
-      # not actually used by EncoderStack.call. Encoder stack output has shape
-      # (batch_size, input_length, hidden_size).
-      encoder_outputs = self.encoder_stack(
+      # not actually used by EncoderStack.call. Encoder stack output is a
+      # dictionary containing final output of the encoder stack with shape
+      # (batch_size, input_length, hidden_size) as well as intermediate outputs
+      # of each of the attention and feed forward network layers in the stack.
+      encoder_outputs_dict = self.encoder_stack(
          encoder_inputs, attention_bias, inputs_padding, training=training)

-      # Pass through dense layer and output logits over vocab for each position.
-      encoder_outputs = self.fc1(encoder_outputs)
-      return encoder_outputs
+      # Pass the final output of the encoder stack through dense layer and
+      # output logits over vocab for each position.
+      encoder_outputs = self.fc1(encoder_outputs_dict['final_output'])
+      # Add logits to the outputs dictionary.
+      encoder_outputs_dict['logits'] = encoder_outputs
+      return encoder_outputs_dict

   def decode(self, encoder_outputs: tf.Tensor, attention_bias: tf.Tensor,
              training: bool) -> tf.Tensor:
@@ -365,7 +385,7 @@ def __init__(self,
         bias_initializer='zeros')

   def encode(self, inputs: tf.Tensor, attention_bias: tf.Tensor,
-             training: bool) -> tf.Tensor:
+             training: bool) -> Dict[str, tf.Tensor]:
     """Runs the input through Encoder stack and problem-specific layers."""

     # Input to embedding layer is [batch_size, length] and output will be