
Commit

Update python formatting.
PiperOrigin-RevId: 508555829
danielecook authored and Copybara-Service committed Feb 10, 2023
1 parent 8b87555 commit b5546cd
Showing 38 changed files with 2,549 additions and 1,493 deletions.
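The diffs rendered below are formatting-only rewraps. A minimal sketch of the pattern they follow (my reading is a Black-compatible style; the commit itself only says "Update python formatting", and the function name below is hypothetical, not from the repo): argument lists that used to hang off the opening parenthesis are rewrapped one argument per line, with a trailing comma and the closing parenthesis on its own line.

# Before: continuation arguments aligned under the opening parenthesis.
def make_attention(hidden_size: int,
                   num_heads: int,
                   attention_dropout: float):
  ...


# After: one argument per line, trailing comma, closing paren on its own line.
def make_attention(
    hidden_size: int,
    num_heads: int,
    attention_dropout: float,
):
  ...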
375 changes: 246 additions & 129 deletions deepconsensus/inference/quick_inference.py

Large diffs are not rendered by default.

76 changes: 45 additions & 31 deletions deepconsensus/models/attention_layer.py
@@ -34,11 +34,13 @@
class Attention(tf.keras.layers.Layer):
"""Multi-headed attention layer."""

def __init__(self,
hidden_size: int,
num_heads: int,
attention_dropout: float,
attn_win_size: Optional[int] = None):
def __init__(
self,
hidden_size: int,
num_heads: int,
attention_dropout: float,
attn_win_size: Optional[int] = None,
):
"""Initialize Attention.
Args:
@@ -51,7 +53,8 @@ def __init__(self,
if hidden_size % num_heads:
raise ValueError(
"Hidden size ({}) must be divisible by the number of heads ({})."
.format(hidden_size, num_heads))
.format(hidden_size, num_heads)
)

super(Attention, self).__init__()
self.hidden_size = hidden_size
@@ -77,35 +80,40 @@ def _glorot_initializer(fan_in, fan_out):
output_shape=(None, self.num_heads, size_per_head),
kernel_initializer=query_initializer,
bias_axes=None,
name="query")
name="query",
)
self.key_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
kernel_initializer=key_initializer,
bias_axes=None,
name="key")
name="key",
)
self.value_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
kernel_initializer=value_initializer,
bias_axes=None,
name="value")
name="value",
)

output_initializer = _glorot_initializer(self.hidden_size, self.hidden_size)
self.output_dense_layer = tf.keras.layers.experimental.EinsumDense(
"BTNH,NHE->BTE",
output_shape=(None, self.hidden_size),
kernel_initializer=output_initializer,
bias_axes=None,
name="output_transform")
name="output_transform",
)

# input_shape = [batch_size, input_length, hidden_size]
max_length = input_shape.as_list()[1]

if self.attn_win_size:
self.attn_mask = tf.ones([1, 1, max_length, max_length])
self.attn_mask = tf.linalg.band_part(self.attn_mask, self.attn_win_size,
self.attn_win_size)
self.attn_mask = tf.linalg.band_part(
self.attn_mask, self.attn_win_size, self.attn_win_size
)
# attn_mask will contain True values in the band and False values outside.
self.attn_mask = self.attn_mask > 0.0
else:
@@ -121,13 +129,15 @@ def get_config(self) -> Dict[str, Any]:
"attn_win_size": self.attn_win_size,
}

def call(self,
query_input: tf.Tensor,
source_input: tf.Tensor,
bias: tf.Tensor,
training: bool,
cache: Optional[Dict[str, tf.Tensor]] = None,
decode_loop_step: Optional[int] = None) -> Dict[str, tf.Tensor]:
def call(
self,
query_input: tf.Tensor,
source_input: tf.Tensor,
bias: tf.Tensor,
training: bool,
cache: Optional[Dict[str, tf.Tensor]] = None,
decode_loop_step: Optional[int] = None,
) -> Dict[str, tf.Tensor]:
"""Apply attention mechanism to query_input and source_input.
Args:
@@ -152,7 +162,6 @@ def call(self,
length_query, hidden_size]. Used as input to the feed_forward_network.
"attention scores": Attention map weights (after softmax) with shape
[batch_size, num_heads, length_query, length_query] - auxiliary output.
"""
# Linearly project the query, key and value using different learned
# projections. Splitting heads is automatically done during the linear
@@ -167,12 +176,14 @@ def call(self,
cache_k_shape = cache["k"].shape.as_list()
indices = tf.reshape(
tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
[1, cache_k_shape[1], 1, 1])
[1, cache_k_shape[1], 1, 1],
)
key = cache["k"] + key * indices
cache_v_shape = cache["v"].shape.as_list()
indices = tf.reshape(
tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
[1, cache_v_shape[1], 1, 1])
[1, cache_v_shape[1], 1, 1],
)
value = cache["v"] + value * indices
else:
key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
@@ -184,7 +195,7 @@ def call(self,

# Scale query to prevent the dot product between query and key from growing
# too large.
depth = (self.hidden_size // self.num_heads)
depth = self.hidden_size // self.num_heads
query *= depth**-0.5

# Calculate dot product attention
@@ -213,11 +224,14 @@ def call(self,
class SelfAttention(Attention):
"""Multiheaded self-attention layer."""

def call(self,
query_input: tf.Tensor,
bias: tf.Tensor,
training: bool,
cache: Optional[Dict[str, tf.Tensor]] = None,
decode_loop_step: Optional[int] = None) -> Dict[str, tf.Tensor]:
return super(SelfAttention, self).call(query_input, query_input, bias,
training, cache, decode_loop_step)
def call(
self,
query_input: tf.Tensor,
bias: tf.Tensor,
training: bool,
cache: Optional[Dict[str, tf.Tensor]] = None,
decode_loop_step: Optional[int] = None,
) -> Dict[str, tf.Tensor]:
return super(SelfAttention, self).call(
query_input, query_input, bias, training, cache, decode_loop_step
)
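For reference, a minimal sketch of the banded attention mask that the build() hunk above constructs (assuming TensorFlow 2.x; the helper name and the example sizes are illustrative and not part of the file): tf.linalg.band_part zeroes everything outside a band of attn_win_size positions on either side of the diagonal, and the final comparison turns that band into the boolean mask stored as self.attn_mask.

import tensorflow as tf


def banded_attention_mask(max_length: int, attn_win_size: int) -> tf.Tensor:
  # Start from an all-ones [1, 1, T, T] tensor, keep only the diagonal band,
  # then convert to booleans: True inside the window, False outside.
  mask = tf.ones([1, 1, max_length, max_length])
  mask = tf.linalg.band_part(mask, attn_win_size, attn_win_size)
  return mask > 0.0


# Example: a length-5 sequence with a window of 1 keeps the tri-diagonal band.
print(tf.cast(banded_attention_mask(5, 1)[0, 0], tf.int32).numpy())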
16 changes: 12 additions & 4 deletions deepconsensus/models/convert_to_saved_model.py
@@ -29,6 +29,7 @@
Example command:
convert_to_saved_model --checkpoint=/path/to/checkpoint --output=/tmp/output
"""

import os
@@ -47,8 +48,13 @@

# Model checkpoint:
_CHECKPOINT = flags.DEFINE_string(
'checkpoint', None, 'Path to checkpoint directory + prefix. '
'For example: <path/to/model>/checkpoint-50.')
'checkpoint',
None,
(
'Path to checkpoint directory + prefix. '
'For example: <path/to/model>/checkpoint-50.'
),
)


def register_required_flags():
@@ -81,7 +87,8 @@ def initialize_model(checkpoint_path: str) -> Optional[tf.keras.Model]:
input_shape = (1, params.total_rows, params.max_length, params.num_channels)
model_utils.print_model_summary(model, input_shape)
checkpoint.restore(
checkpoint_path).expect_partial().assert_existing_objects_matched()
checkpoint_path
).expect_partial().assert_existing_objects_matched()

logging.info('Finished initialize_model.')
return model
@@ -94,7 +101,8 @@ def main(_):
# Copy over the params.json. At this point, we know params.json exists.
json_path = os.path.join(os.path.dirname(_CHECKPOINT.value), 'params.json')
gfile.Copy(
json_path, os.path.join(_OUTPUT.value, 'params.json'), overwrite=True)
json_path, os.path.join(_OUTPUT.value, 'params.json'), overwrite=True
)


if __name__ == '__main__':
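As context for the initialize_model() hunk above, a minimal sketch of the reformatted restore call (assuming TensorFlow 2.x; wrapping the model in tf.train.Checkpoint and the placeholder prefix are assumptions for illustration, not taken from the script):

import tensorflow as tf


def restore_model(model: tf.keras.Model, checkpoint_path: str) -> tf.keras.Model:
  # Restore weights from a checkpoint prefix such as
  # '<path/to/model>/checkpoint-50' (see the _CHECKPOINT flag above).
  checkpoint = tf.train.Checkpoint(model=model)
  # expect_partial() silences warnings about checkpoint values that go unused
  # (e.g. optimizer slots), while assert_existing_objects_matched() still
  # fails if the model's own variables are not found in the checkpoint.
  checkpoint.restore(
      checkpoint_path
  ).expect_partial().assert_existing_objects_matched()
  return model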