[go: nahoru, domu]

Skip to content

Commit

Permalink
random backoff if pretrained model fails to download from s3
Browse files Browse the repository at this point in the history
  • Loading branch information
cfregly committed Apr 3, 2020
1 parent e5b2ea9 commit db7f26e
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions 06_train/new_src_bert_tf2/tf_bert_reviews.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import time
import random
import pandas as pd
from glob import glob
import argparse
Expand Down Expand Up @@ -129,13 +131,37 @@ def _decode_record(record, name_to_features):
current_host = args.current_host
num_gpus = args.num_gpus

# Pretrained artifacts; each stays None until the download loop below succeeds.
tokenizer = None
model = None

# Build the config BEFORE loading the model: it is passed to from_pretrained()
# below, so it must already carry num_labels=len(CLASSES) at that point.
# Constructing it does not call from_pretrained(), so it lives outside the
# retry loop.
config = BertConfig(num_labels=len(CLASSES))

# This is required when launching many instances at once... the urllib request
# seems to get denied periodically. Each failed attempt waits a random number
# of seconds (1-30) before trying again.
max_attempts = 5
successful_download = False
retries = 0
while retries < max_attempts and not successful_download:
    try:
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased',
                                                                config=config)
        successful_download = True
        print('Sucessfully downloaded after {} retries.'.format(retries))
    except Exception as download_error:
        # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
        # stop the job, and surface the reason for the failure in the log.
        retries = retries + 1
        print('Download attempt failed: {}'.format(download_error))
        if retries < max_attempts:
            # Only back off when another attempt will actually follow.
            random_sleep = random.randint(1, 30)
            print('Retry #{}. Sleeping for {} seconds'.format(retries, random_sleep))
            time.sleep(random_sleep)

if tokenizer is None or model is None:
    # Fail here with a clear message rather than continuing with None objects.
    print('Not properly initialized...')
    raise RuntimeError(
        'Could not download bert-base-uncased after {} attempts'.format(max_attempts))

# Read the SM_INPUT_DATA_CONFIG environment variable (empty string when unset)
# and log its raw value.
pipe_mode_str = os.getenv('SM_INPUT_DATA_CONFIG', '')
print('pipe_mode_str {}'.format(pipe_mode_str))

# pipe_mode is True when the text 'Pipe' appears anywhere in that value.
pipe_mode = 'Pipe' in pipe_mode_str
print('pipe_mode {}'.format(pipe_mode))

train_data_filenames = glob('{}/*.tfrecord'.format(train_data))
print(train_data_filenames)
print('train_data_filenames {}'.format(train_data_filenames))
train_dataset = file_based_input_dataset_builder(
channel='train',
input_filenames=train_data_filenames,
Expand All @@ -144,7 +170,7 @@ def _decode_record(record, name_to_features):
drop_remainder=False).map(select_data_and_label_from_record)

validation_data_filenames = glob('{}/*.tfrecord'.format(validation_data))
print(validation_data_filenames)
print('validation_data_filenames {}'.format(validation_data_filenames))
validation_dataset = file_based_input_dataset_builder(
channel='validation',
input_filenames=validation_data_filenames,
Expand All @@ -155,13 +181,7 @@ def _decode_record(record, name_to_features):
tf.config.optimizer.set_jit(USE_XLA)
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = BertConfig(num_labels=len(CLASSES))
model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased',
config=config)

optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)

if USE_AMP:
# loss scaling is currently required when using mixed precision
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, 'dynamic')
Expand All @@ -179,8 +199,8 @@ def _decode_record(record, name_to_features):
# shuffle=True,
epochs=EPOCHS,
steps_per_epoch=STEPS_PER_EPOCH,
# validation_data=validation_dataset,
# validation_steps=VALIDATION_STEPS,
validation_data=validation_dataset,
validation_steps=VALIDATION_STEPS,
callbacks=[tensorboard_callback])

# Save the Model
Expand Down

0 comments on commit db7f26e

Please sign in to comment.