Enable access to intermediate model outputs.
PiperOrigin-RevId: 463203521
anastasiyabl authored and Copybara-Service committed Jul 25, 2022
1 parent 9fe0ce4 commit 0609176
Showing 5 changed files with 125 additions and 61 deletions.
42 changes: 25 additions & 17 deletions deepconsensus/models/attention_layer.py
@@ -27,14 +27,15 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Implementation of multiheaded attention and self-attention layers."""
import math

from typing import Any, Dict, Optional, Union, Iterable
import tensorflow as tf


class Attention(tf.keras.layers.Layer):
"""Multi-headed attention layer."""

def __init__(self, hidden_size, num_heads, attention_dropout):
def __init__(self, hidden_size: int, num_heads: int,
attention_dropout: float):
"""Initialize Attention.
Args:
@@ -52,7 +53,7 @@ def __init__(self, hidden_size, num_heads, attention_dropout):
self.num_heads = num_heads
self.attention_dropout = attention_dropout

def build(self, input_shape):
def build(self, input_shape: Union[tf.TensorShape, Iterable[tf.TensorShape]]):
"""Builds the layer."""
# Layers for linearly projecting the queries, keys, and values.
size_per_head = self.hidden_size // self.num_heads
@@ -91,20 +92,20 @@ def _glorot_initializer(fan_in, fan_out):
name="output_transform")
super(Attention, self).build(input_shape)

def get_config(self):
def get_config(self) -> Dict[str, Any]:
return {
"hidden_size": self.hidden_size,
"num_heads": self.num_heads,
"attention_dropout": self.attention_dropout,
}

def call(self,
query_input,
source_input,
bias,
training,
cache=None,
decode_loop_step=None):
query_input: tf.Tensor,
source_input: tf.Tensor,
bias: tf.Tensor,
training: bool,
cache: Optional[Dict[str, tf.Tensor]] = None,
decode_loop_step: Optional[int] = None) -> Dict[str, tf.Tensor]:
"""Apply attention mechanism to query_input and source_input.
Args:
@@ -124,7 +125,12 @@ def call(self,
for autoregressive inference on TPU.
Returns:
Attention layer output with shape [batch_size, length_query, hidden_size]
Dictionary with the following (key:value) pairs:
"main_output": Attention layer output with shape [batch_size,
length_query, hidden_size]. Used as input to the feed_forward_network.
"attention scores": Attention map weights (after softmax) with shape
[batch_size, num_heads, length_query, length_query] - auxiliary output.
"""
# Linearly project the query, key and value using different learned
# projections. Splitting heads is automatically done during the linear
@@ -173,17 +179,19 @@ def call(self,
# Run the outputs through another linear projection layer. Recombining heads
# is automatically done --> [batch_size, length, hidden_size]
attention_output = self.output_dense_layer(attention_output)
return attention_output

layer_output = dict(main_output=attention_output, attention_scores=weights)
return layer_output


class SelfAttention(Attention):
"""Multiheaded self-attention layer."""

def call(self,
query_input,
bias,
training,
cache=None,
decode_loop_step=None):
query_input: tf.Tensor,
bias: tf.Tensor,
training: bool,
cache: Optional[Dict[str, tf.Tensor]] = None,
decode_loop_step: Optional[int] = None) -> Dict[str, tf.Tensor]:
return super(SelfAttention, self).call(query_input, query_input, bias,
training, cache, decode_loop_step)
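
As a quick illustration of the new contract, the sketch below (not part of this commit) shows how a caller might consume the dictionary that SelfAttention now returns. The "main_output" and "attention_scores" keys come from the diff above; the hidden size, head count, tensor shapes, and the all-zeros attention bias are illustrative assumptions.

```python
# Illustrative sketch only; hidden_size, num_heads, shapes, and the zero bias
# (no masking) are assumptions, not values taken from this repository.
import tensorflow as tf
from deepconsensus.models import attention_layer

layer = attention_layer.SelfAttention(
    hidden_size=64, num_heads=2, attention_dropout=0.1)
query = tf.random.normal([1, 10, 64])  # [batch_size, length, hidden_size]
bias = tf.zeros([1, 1, 1, 10])         # additive attention bias; zeros = no masking
outputs = layer(query, bias, training=False)

main = outputs["main_output"]          # [1, 10, 64], fed to the feed-forward sublayer
scores = outputs["attention_scores"]   # [1, 2, 10, 10], post-softmax attention weights
```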
59 changes: 43 additions & 16 deletions deepconsensus/models/encoder_stack.py
@@ -31,6 +31,9 @@
Transformer model code source: https://github.com/tensorflow/tensor2tensor
"""

from typing import Any, Dict, Union, Iterable

import ml_collections
import tensorflow as tf

from deepconsensus.models import attention_layer
@@ -40,37 +43,40 @@
class PrePostProcessingWrapper(tf.keras.layers.Layer):
"""Wrapper class that applies layer pre-processing and post-processing."""

def __init__(self, layer, params):
def __init__(self, layer: tf.keras.layers.Layer,
params: ml_collections.ConfigDict):
super(PrePostProcessingWrapper, self).__init__()
self.layer = layer
self.params = params
self.postprocess_dropout = params["layer_postprocess_dropout"]

def build(self, input_shape):
def build(self, input_shape: Union[tf.TensorShape, Iterable[tf.TensorShape]]):
# Create normalization layer
self.layer_norm = tf.keras.layers.LayerNormalization(
epsilon=1e-6, dtype="float32")
super(PrePostProcessingWrapper, self).build(input_shape)

def get_config(self):
def get_config(self) -> Dict[str, Any]:
return {
"params": self.params,
}

def call(self, x, *args, **kwargs):
def call(self, x: tf.Tensor, *args, **kwargs) -> Dict[str, tf.Tensor]:
"""Calls wrapped layer with same parameters."""
# Preprocessing: apply layer normalization
training = kwargs["training"]

y = self.layer_norm(x)

# Get layer output
y = self.layer(y, *args, **kwargs)
# Get layer output.
layer_output = self.layer(y, *args, **kwargs)
y = layer_output["main_output"]

# Postprocessing: apply dropout and residual connection
if training:
y = tf.nn.dropout(y, rate=self.postprocess_dropout)
return x + y
layer_output["main_output"] = x + y
return layer_output


class EncoderStack(tf.keras.layers.Layer):
@@ -82,12 +88,12 @@ class EncoderStack(tf.keras.layers.Layer):
2. Feedforward network (which is 2 fully-connected layers)
"""

def __init__(self, params):
def __init__(self, params: ml_collections.ConfigDict):
super(EncoderStack, self).__init__()
self.params = params
self.layers = []

def build(self, input_shape):
def build(self, input_shape: Union[tf.TensorShape, Iterable[tf.TensorShape]]):
"""Builds the encoder stack."""
params = self.params
for _ in range(params["num_hidden_layers"]):
@@ -108,12 +114,13 @@ def build(self, input_shape):
epsilon=1e-6, dtype="float32")
super(EncoderStack, self).build(input_shape)

def get_config(self):
def get_config(self) -> Dict[str, Any]:
return {
"params": self.params,
}

def call(self, encoder_inputs, attention_bias, inputs_padding, training):
def call(self, encoder_inputs: tf.Tensor, attention_bias: tf.Tensor,
inputs_padding: tf.Tensor, training: bool) -> Dict[str, tf.Tensor]:
"""Return the output of the encoder layer stacks.
Args:
@@ -125,20 +132,40 @@ def call(self, encoder_inputs, attention_bias, inputs_padding, training):
training: boolean, whether in training mode or not.
Returns:
Output of encoder layer stack.
float32 tensor with shape [batch_size, input_length, hidden_size]
Dictionary with the following (key:value) pairs:
"self_attention_layer_{n}": Attention layer output for every layer in
the encoder stack with shape [batch_size, input_length, hidden_size].
"attention_scores_{n}" : Attention map for every layer in the
encoder stack with shape [batch_size, num_heads, input_length,
input_length].
"ffn_layer_{n}": Feedforward network output for every layer in the
encoder stack with shape [batch_size, input_length, hidden_size].
"final_output": Final output of the entire encoder stack after
normalization with shape [batch_size, input_length, hidden_size]. Used
as input to the fully-connected layer which outputs logits.
"""
outputs_dict = dict()
for n, layer in enumerate(self.layers):
# Run inputs through the sublayers.
self_attention_layer = layer[0]
feed_forward_network = layer[1]

with tf.name_scope("layer_%d" % n):
with tf.name_scope("self_attention"):
encoder_inputs = self_attention_layer(
layer_outputs = self_attention_layer(
encoder_inputs, attention_bias, training=training)
encoder_inputs = layer_outputs["main_output"]
# Add attention layer outputs and attention map scores to outputs.
outputs_dict[f"self_attention_layer_{n}"] = encoder_inputs
outputs_dict[f"attention_scores_{n}"] = layer_outputs[
"attention_scores"]
with tf.name_scope("ffn"):
encoder_inputs = feed_forward_network(
layer_outputs = feed_forward_network(
encoder_inputs, training=training)
encoder_inputs = layer_outputs["main_output"]
# Add output of the feedforward network to outputs.
outputs_dict[f"ffn_layer_{n}"] = encoder_inputs

return self.output_normalization(encoder_inputs)
# Add normalized final output of the entire encoder stack to outputs.
outputs_dict["final_output"] = self.output_normalization(encoder_inputs)
return outputs_dict
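
For context, here is a hedged sketch of reading the per-layer outputs that EncoderStack now exposes. The output keys ("attention_scores_{n}", "ffn_layer_{n}", "final_output", and so on) are taken from the docstring above; the ConfigDict keys and all numeric values are assumptions modeled on tensor2tensor-style hyperparameters and may not match this repository's config exactly.

```python
# Sketch only: the params keys below follow common tensor2tensor naming and
# may differ from this repository's actual config.
import ml_collections
import tensorflow as tf
from deepconsensus.models import encoder_stack

params = ml_collections.ConfigDict({
    "num_hidden_layers": 2,
    "hidden_size": 64,
    "num_heads": 2,
    "attention_dropout": 0.1,
    "filter_size": 128,
    "relu_dropout": 0.1,
    "layer_postprocess_dropout": 0.1,
})
stack = encoder_stack.EncoderStack(params)

encoder_inputs = tf.random.normal([1, 10, 64])   # [batch_size, input_length, hidden_size]
attention_bias = tf.zeros([1, 1, 1, 10])         # no positions masked
inputs_padding = tf.zeros([1, 10])               # no padding
outputs = stack(encoder_inputs, attention_bias, inputs_padding, training=False)

final = outputs["final_output"]             # [1, 10, 64], input to the logits layer
attn_map_0 = outputs["attention_scores_0"]  # [1, 2, 10, 10], first-layer attention map
```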
17 changes: 9 additions & 8 deletions deepconsensus/models/ffn_layer.py
@@ -27,13 +27,14 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Implementation of fully connected network."""

from typing import Any, Dict, Union, Iterable
import tensorflow as tf


class FeedForwardNetwork(tf.keras.layers.Layer):
"""Fully connected feedforward network."""

def __init__(self, hidden_size, filter_size, relu_dropout):
def __init__(self, hidden_size: int, filter_size: int, relu_dropout: float):
"""Initialize FeedForwardNetwork.
Args:
@@ -46,7 +47,7 @@ def __init__(self, hidden_size, filter_size, relu_dropout):
self.filter_size = filter_size
self.relu_dropout = relu_dropout

def build(self, input_shape):
def build(self, input_shape: Union[tf.TensorShape, Iterable[tf.TensorShape]]):
self.filter_dense_layer = tf.keras.layers.Dense(
self.filter_size,
use_bias=True,
@@ -56,29 +57,29 @@ def build(self, input_shape):
self.hidden_size, use_bias=True, name="output_layer")
super(FeedForwardNetwork, self).build(input_shape)

def get_config(self):
def get_config(self) -> Dict[str, Any]:
return {
"hidden_size": self.hidden_size,
"filter_size": self.filter_size,
"relu_dropout": self.relu_dropout,
}

def call(self, x, training):
def call(self, x: tf.Tensor, training: bool) -> Dict[str, tf.Tensor]:
"""Return outputs of the feedforward network.
Args:
x: tensor with shape [batch_size, length, hidden_size]
training: boolean, whether in training mode or not.
Returns:
Output of the feedforward network.
tensor with shape [batch_size, length, hidden_size]
Dictionary with the following (key:value) pairs:
"main_output": Output of the feedforward network with shape [batch_size,
length, hidden_size]. Used as input to the next encoder layer.
"""
# Retrieve dynamically known shapes

output = self.filter_dense_layer(x)
if training:
output = tf.nn.dropout(output, rate=self.relu_dropout)
output = self.output_dense_layer(output)

return output
return dict(main_output=output)
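
Similarly, a minimal sketch of the feed-forward layer's new dictionary return; only the "main_output" key is taken from the diff, and the sizes are illustrative.

```python
# Illustrative sketch; hidden_size, filter_size, and shapes are arbitrary.
import tensorflow as tf
from deepconsensus.models import ffn_layer

ffn = ffn_layer.FeedForwardNetwork(
    hidden_size=64, filter_size=128, relu_dropout=0.1)
x = tf.random.normal([1, 10, 64])              # [batch_size, length, hidden_size]
out = ffn(x, training=False)["main_output"]    # [1, 10, 64]
```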
16 changes: 12 additions & 4 deletions deepconsensus/models/model_distillation.py
@@ -218,10 +218,14 @@ def train_step(inputs):
"""Training StepFn."""
features, labels = inputs
# Get logits from the teacher model.
teacher_logits = teacher_model.get_logits(features, training=False)
teacher_intermediate_outputs_dict = teacher_model.get_intermediate_outputs(
features, training=False)
teacher_logits = teacher_intermediate_outputs_dict['logits']

with tf.GradientTape() as tape:
student_logits = model.get_logits(features, training=True)
student_intermediate_outputs_dict = model.get_intermediate_outputs(
features, training=True)
student_logits = student_intermediate_outputs_dict['logits']
student_preds = tf.nn.softmax(student_logits)
train_losses_dict = compute_loss(labels, student_preds, student_logits,
teacher_logits)
@@ -240,9 +244,13 @@ def eval_step(inputs):
"""Eval StepFn."""
features, labels = inputs
# Get logits from the teacher model.
teacher_logits = teacher_model.get_logits(features, training=False)
teacher_intermediate_outputs_dict = teacher_model.get_intermediate_outputs(
features, training=False)
teacher_logits = teacher_intermediate_outputs_dict['logits']

student_logits = model.get_logits(features, training=False)
student_intermediate_outputs_dict = model.get_intermediate_outputs(
features, training=False)
student_logits = student_intermediate_outputs_dict['logits']
student_preds = tf.nn.softmax(student_logits)
eval_losses_dict = compute_loss(labels, student_preds, student_logits,
teacher_logits)