Use ReZero instead of LayerNorms.
PiperOrigin-RevId: 496754847
Genomics team in Google Health authored and Copybara-Service committed Dec 20, 2022
1 parent 09d3660 commit 01a4e1c
Showing 12 changed files with 34 additions and 16 deletions.
23 changes: 17 additions & 6 deletions deepconsensus/models/encoder_stack.py
@@ -51,9 +51,14 @@ def __init__(self, layer: tf.keras.layers.Layer,
     self.postprocess_dropout = params["layer_postprocess_dropout"]

   def build(self, input_shape: Union[tf.TensorShape, Iterable[tf.TensorShape]]):
-    # Create normalization layer
-    self.layer_norm = tf.keras.layers.LayerNormalization(
-        epsilon=1e-6, dtype="float32")
+    if self.params["rezero"]:
+      # Variable used in ReZero (paper: https://arxiv.org/abs/2003.04887).
+      alpha_init = tf.zeros_initializer()
+      self.alpha = tf.Variable(
+          initial_value=alpha_init(shape=()), trainable=True)
+    else:
+      self.layer_norm = tf.keras.layers.LayerNormalization(
+          epsilon=1e-6, dtype="float32")
     super(PrePostProcessingWrapper, self).build(input_shape)

   def get_config(self) -> Dict[str, Any]:
@@ -63,10 +68,12 @@ def get_config(self) -> Dict[str, Any]:

   def call(self, x: tf.Tensor, *args, **kwargs) -> Dict[str, tf.Tensor]:
     """Calls wrapped layer with same parameters."""
-    # Preprocessing: apply layer normalization
     training = kwargs["training"]

-    y = self.layer_norm(x)
+    if self.params["rezero"]:
+      y = x
+    else:
+      y = self.layer_norm(x)

     # Get layer output.
     layer_output = self.layer(y, *args, **kwargs)
@@ -75,7 +82,11 @@ def call(self, x: tf.Tensor, *args, **kwargs) -> Dict[str, tf.Tensor]:

     # Postprocessing: apply dropout and residual connection
     if training:
       y = tf.nn.dropout(y, rate=self.postprocess_dropout)
-    layer_output["main_output"] = x + y
+    if self.params["rezero"]:
+      # Apply ReZero.
+      layer_output["main_output"] = x + self.alpha * y
+    else:
+      layer_output["main_output"] = x + y
     return layer_output


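For context, the residual rule this change switches to can be summarized as out = x + alpha * F(x), where alpha is a trainable scalar initialized to zero so every block starts as the identity (see the paper linked in the diff, https://arxiv.org/abs/2003.04887). Below is a minimal, self-contained Keras sketch of that idea; it is an illustration only, not the repository's PrePostProcessingWrapper, and the class name and Dense sublayer are placeholders.

import tensorflow as tf


class ReZeroResidual(tf.keras.layers.Layer):
  """Illustrative ReZero residual block: out = x + alpha * F(x)."""

  def __init__(self, inner_layer: tf.keras.layers.Layer, **kwargs):
    super().__init__(**kwargs)
    self.inner_layer = inner_layer

  def build(self, input_shape):
    # Trainable scalar initialized to zero: the residual branch contributes
    # nothing at initialization, so no LayerNormalization is required.
    self.alpha = self.add_weight(
        name="alpha", shape=(), initializer="zeros", trainable=True)
    super().build(input_shape)

  def call(self, x: tf.Tensor) -> tf.Tensor:
    return x + self.alpha * self.inner_layer(x)


# Example usage: wrap a width-preserving sublayer (placeholder Dense layer).
block = ReZeroResidual(tf.keras.layers.Dense(32, activation="relu"))
out = block(tf.random.normal([4, 32]))

Because alpha starts at zero and its contribution is learned, each block initially passes its input through unchanged, which is what lets the pre-layer LayerNormalization be dropped in the change above.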
2 changes: 2 additions & 0 deletions deepconsensus/models/model_configs.py
@@ -88,6 +88,8 @@ def _set_base_transformer_hparams(params):
   # tuning.
   params.num_heads = 2
   params.layer_norm = False
+  # Whether to use ReZero instead of LayerNorms.
+  params.rezero = True
   params.condense_transformer_input = False
   params.transformer_model_size = 'base'

Binary file modified deepconsensus/testdata/human_1m/tf_examples/eval/eval.tfrecord.gz
Binary file not shown.
@@ -5,6 +5,8 @@
   "zmw_trimmed_insertions_bp": 9421,
   "n_zmw_inference": 10,
   "n_zmw_pass": 10,
+  "example_width_bucket_100": 1593,
+  "n_examples_skip_large_windows_keep": 1593,
   "n_examples_inference": 1593,
   "n_examples": 1593,
   "max_passes": "20",
@@ -17,5 +19,5 @@
   "truth_bed": "None",
   "truth_split": "None",
   "ins_trim": "5",
-  "version": "1.0.0"
+  "version": "1.1.0"
 }
@@ -5,6 +5,8 @@
   "zmw_trimmed_insertions_bp": 9421,
   "n_zmw_train": 7,
   "n_zmw_pass": 9,
+  "example_width_bucket_100": 1551,
+  "n_examples_skip_large_windows_keep": 1507,
   "n_examples_adjusted_label": 305,
   "n_examples_train": 1239,
   "n_examples": 1507,
@@ -24,5 +26,5 @@
   "truth_bed": "testdata/human_1m/truth.bed",
   "truth_split": "testdata/human_1m/truth_split.tsv",
   "ins_trim": "5",
-  "version": "1.0.0"
+  "version": "1.1.0"
 }
Binary file modified deepconsensus/testdata/human_1m/tf_examples/test/test.tfrecord.gz
Binary file not shown.
Binary file modified deepconsensus/testdata/model/checkpoint-1.index
Binary file not shown.
Binary file modified deepconsensus/testdata/model/checkpoint-2.index
Binary file not shown.
16 changes: 8 additions & 8 deletions deepconsensus/testdata/model/checkpoint_metrics.tsv
@@ -1,9 +1,9 @@
 checkpoint_name group name value
-/tmp/deepconsensus/model/20221022174103/checkpoint-1 eval eval/loss 181.65109252929688
-/tmp/deepconsensus/model/20221022174103/checkpoint-1 eval eval/per_example_accuracy 0.0
-/tmp/deepconsensus/model/20221022174103/checkpoint-1 eval eval/per_batch_alignment_identity 0.2653157711029053
-/tmp/deepconsensus/model/20221022174103/checkpoint-1 eval eval/yield_over_ccs 0.0
-/tmp/deepconsensus/model/20221022174103/checkpoint-2 eval eval/loss 0.0
-/tmp/deepconsensus/model/20221022174103/checkpoint-2 eval eval/per_example_accuracy 0.0
-/tmp/deepconsensus/model/20221022174103/checkpoint-2 eval eval/per_batch_alignment_identity 0.0
-/tmp/deepconsensus/model/20221022174103/checkpoint-2 eval eval/yield_over_ccs 0.0
+/tmp/deepconsensus/model/20221218190123/checkpoint-1 eval eval/loss 187.8463134765625
+/tmp/deepconsensus/model/20221218190123/checkpoint-1 eval eval/per_example_accuracy 0.0
+/tmp/deepconsensus/model/20221218190123/checkpoint-1 eval eval/per_batch_alignment_identity 0.26910263299942017
+/tmp/deepconsensus/model/20221218190123/checkpoint-1 eval eval/yield_over_ccs 0.0
+/tmp/deepconsensus/model/20221218190123/checkpoint-2 eval eval/loss 0.0
+/tmp/deepconsensus/model/20221218190123/checkpoint-2 eval eval/per_example_accuracy 0.0
+/tmp/deepconsensus/model/20221218190123/checkpoint-2 eval eval/per_batch_alignment_identity 0.0
+/tmp/deepconsensus/model/20221218190123/checkpoint-2 eval eval/yield_over_ccs 0.0
1 change: 1 addition & 0 deletions deepconsensus/testdata/model/params.json
@@ -57,6 +57,7 @@
   "pw_hidden_size": 8,
   "relu_dropout": 0.1,
   "remove_label_gaps": false,
+  "rezero": true,
   "seed": 1,
   "sn_hidden_size": 8,
   "static_batch": false,
