[go: nahoru, domu]

Skip to content

Commit

Permalink
updated
Browse files Browse the repository at this point in the history
  • Loading branch information
cfregly committed Apr 4, 2020
1 parent 96cf2c1 commit 0ada93e
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 84 deletions.
112 changes: 58 additions & 54 deletions 06_train/04_Train_Reviews_BERT_TensorFlow2_ScriptMode.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 293,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -27,7 +27,7 @@
},
{
"cell_type": "code",
"execution_count": 294,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -36,7 +36,7 @@
},
{
"cell_type": "code",
"execution_count": 295,
"execution_count": 15,
"metadata": {},
"outputs": [
{
Expand All @@ -53,13 +53,13 @@
},
{
"cell_type": "code",
"execution_count": 296,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"prefix_train = '{}/output/bert-train2'.format(scikit_processing_job_s3_output_prefix)\n",
"prefix_validation = '{}/output/bert-validation2'.format(scikit_processing_job_s3_output_prefix)\n",
"prefix_test = '{}/output/bert-test2'.format(scikit_processing_job_s3_output_prefix)\n",
"prefix_train = '{}/output/bert-train-10'.format(scikit_processing_job_s3_output_prefix)\n",
"prefix_validation = '{}/output/bert-validation-10'.format(scikit_processing_job_s3_output_prefix)\n",
"prefix_test = '{}/output/bert-test-10'.format(scikit_processing_job_s3_output_prefix)\n",
"\n",
"path_train = './{}'.format(prefix_train)\n",
"path_validation = './{}'.format(prefix_validation)\n",
Expand All @@ -72,23 +72,23 @@
},
{
"cell_type": "code",
"execution_count": 297,
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-835319576252/sagemaker-scikit-learn-2020-03-30-03-34-18-188/output/bert-train2', 'S3DataDistributionType': 'ShardedByS3Key'}}}\n",
"{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-835319576252/sagemaker-scikit-learn-2020-03-30-03-34-18-188/output/bert-validation2', 'S3DataDistributionType': 'ShardedByS3Key'}}}\n",
"{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-835319576252/sagemaker-scikit-learn-2020-03-30-03-34-18-188/output/bert-test2', 'S3DataDistributionType': 'ShardedByS3Key'}}}\n"
"{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-835319576252/sagemaker-scikit-learn-2020-03-30-03-34-18-188/output/bert-train-10', 'S3DataDistributionType': 'FullyReplicated'}}}\n",
"{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-835319576252/sagemaker-scikit-learn-2020-03-30-03-34-18-188/output/bert-validation-10', 'S3DataDistributionType': 'FullyReplicated'}}}\n",
"{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-835319576252/sagemaker-scikit-learn-2020-03-30-03-34-18-188/output/bert-test-10', 'S3DataDistributionType': 'FullyReplicated'}}}\n"
]
}
],
"source": [
"s3_input_train_data = sagemaker.s3_input(s3_data=train_s3_uri, distribution='ShardedByS3Key') \n",
"s3_input_validation_data = sagemaker.s3_input(s3_data=validation_s3_uri, distribution='ShardedByS3Key')\n",
"s3_input_test_data = sagemaker.s3_input(s3_data=test_s3_uri, distribution='ShardedByS3Key')\n",
"s3_input_train_data = sagemaker.s3_input(s3_data=train_s3_uri) #, distribution='ShardedByS3Key') \n",
"s3_input_validation_data = sagemaker.s3_input(s3_data=validation_s3_uri) #, distribution='ShardedByS3Key')\n",
"s3_input_test_data = sagemaker.s3_input(s3_data=test_s3_uri) #, distribution='ShardedByS3Key')\n",
"\n",
"print(s3_input_train_data.config)\n",
"print(s3_input_validation_data.config)\n",
Expand All @@ -97,7 +97,7 @@
},
{
"cell_type": "code",
"execution_count": 298,
"execution_count": 18,
"metadata": {
"scrolled": false
},
Expand Down Expand Up @@ -128,11 +128,11 @@
"from transformers.configuration_bert import BertConfig\r\n",
"\r\n",
"MAX_SEQ_LENGTH = 128\r\n",
"BATCH_SIZE = 256\r\n",
"BATCH_SIZE = 128 \r\n",
"EVAL_BATCH_SIZE=BATCH_SIZE * 2\r\n",
"EPOCHS = 5\r\n",
"STEPS_PER_EPOCH = 1000\r\n",
"VALIDATION_STEPS = 1000\r\n",
"EPOCHS = 2 \r\n",
"STEPS_PER_EPOCH = 100\r\n",
"VALIDATION_STEPS = 100\r\n",
"CLASSES = [1, 2, 3, 4, 5]\r\n",
"# XLA (Accelerated Linear Algebra) is an optimizing compiler for TensorFlow\r\n",
"USE_XLA = True \r\n",
Expand Down Expand Up @@ -169,7 +169,7 @@
" print('***** Using input_filenames {}'.format(input_filenames))\r\n",
" dataset = tf.data.TFRecordDataset(input_filenames)\r\n",
"\r\n",
" dataset = dataset.repeat(EPOCHS)\r\n",
" dataset = dataset.repeat(EPOCHS * STEPS_PER_EPOCH)\r\n",
" dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)\r\n",
"\r\n",
" name_to_features = {\r\n",
Expand Down Expand Up @@ -244,18 +244,18 @@
" num_gpus = args.num_gpus\r\n",
"\r\n",
" tokenizer = None\r\n",
" config = None\r\n",
" model = None\r\n",
" config = None \r\n",
"\r\n",
" # Retry the download: when many instances launch at once, the urllib request seems to get denied periodically\r\n",
" successful_download = False\r\n",
" retries = 0\r\n",
" while (retries < 5 and not successful_download):\r\n",
" try:\r\n",
" tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\r\n",
" config = BertConfig(num_labels=len(CLASSES))\r\n",
" model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased', \r\n",
" config=config)\r\n",
" config = BertConfig(num_labels=len(CLASSES))\r\n",
" successful_download = True\r\n",
" print('Sucessfully downloaded after {} retries.'.format(retries))\r\n",
" except:\r\n",
Expand Down Expand Up @@ -301,42 +301,45 @@
" loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\r\n",
" metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')\r\n",
" model.compile(optimizer=optimizer, loss=loss, metrics=[metric])\r\n",
" model.layers[0].trainable = False\r\n",
" model.layers[0].trainable=False\r\n",
" model.summary()\r\n",
"\r\n",
" log_dir = './tensorboard/classification/'\r\n",
" tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)\r\n",
"\r\n",
" history = model.fit(train_dataset,\r\n",
"# shuffle=True,\r\n",
" shuffle=True,\r\n",
" epochs=EPOCHS,\r\n",
" steps_per_epoch=STEPS_PER_EPOCH,\r\n",
" validation_data=validation_dataset,\r\n",
" validation_steps=VALIDATION_STEPS,\r\n",
" callbacks=[tensorboard_callback])\r\n",
"\r\n",
" print('Trained model {}'.format(model))\r\n",
"\r\n",
" # Save the Model\r\n",
" model.save_pretrained(model_dir)\r\n",
"\r\n",
" loaded_model = TFBertForSequenceClassification.from_pretrained(model_dir,\r\n",
" id2label={\r\n",
" 0: 1,\r\n",
" 1: 2,\r\n",
" 2: 3,\r\n",
" 3: 4,\r\n",
" 4: 5\r\n",
" },\r\n",
" label2id={\r\n",
" 1: 0,\r\n",
" 2: 1,\r\n",
" 3: 2,\r\n",
" 4: 3,\r\n",
" 5: 4\r\n",
" })\r\n",
"\r\n",
" tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\r\n",
"\r\n",
" inference_pipeline = TextClassificationPipeline(model=loaded_model, \r\n",
"# loaded_model = TFBertForSequenceClassification.from_pretrained(model_dir,\r\n",
"# id2label={\r\n",
"# 0: 1,\r\n",
"# 1: 2,\r\n",
"# 2: 3,\r\n",
"# 3: 4,\r\n",
"# 4: 5\r\n",
"# },\r\n",
"# label2id={\r\n",
"# 1: 0,\r\n",
"# 2: 1,\r\n",
"# 3: 2,\r\n",
"# 4: 3,\r\n",
"# 5: 4\r\n",
"# })\r\n",
"\r\n",
"# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\r\n",
"\r\n",
"# inference_pipeline = TextClassificationPipeline(model=loaded_model, \r\n",
" inference_pipeline = TextClassificationPipeline(model=model, \r\n",
" tokenizer=tokenizer,\r\n",
" framework='tf',\r\n",
" device=-1) # -1 is CPU, >= 0 is GPU\r\n",
Expand All @@ -354,7 +357,7 @@
},
{
"cell_type": "code",
"execution_count": 299,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -374,7 +377,8 @@
"# 'model_name': 'bert-base-cased'},\n",
" distributions={'parameter_server': {'enabled': True}},\n",
" enable_cloudwatch_metrics=True,\n",
" input_mode='Pipe')"
"# input_mode='Pipe'\n",
" )"
]
},
{
Expand All @@ -386,7 +390,7 @@
},
{
"cell_type": "code",
"execution_count": 300,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -397,14 +401,14 @@
},
{
"cell_type": "code",
"execution_count": 301,
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training_job_name: tensorflow-training-2020-04-03-06-33-19-845\n"
"training_job_name: tensorflow-training-2020-04-03-22-08-48-700\n"
]
}
],
Expand All @@ -415,13 +419,13 @@
},
{
"cell_type": "code",
"execution_count": 302,
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<b>Review <a href=\"https://console.aws.amazon.com/sagemaker/home?region=us-east-1#/jobs/tensorflow-training-2020-04-03-06-33-19-845\">Training Job</a> After About 5 Minutes</b>"
"<b>Review <a href=\"https://console.aws.amazon.com/sagemaker/home?region=us-east-1#/jobs/tensorflow-training-2020-04-03-22-08-48-700\">Training Job</a> After About 5 Minutes</b>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand All @@ -439,13 +443,13 @@
},
{
"cell_type": "code",
"execution_count": 303,
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<b>Review <a href=\"https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/aws/sagemaker/TrainingJobs;prefix=tensorflow-training-2020-04-03-06-33-19-845;streamFilter=typeLogStreamPrefix\">CloudWatch Logs</a> After About 5 Minutes</b>"
"<b>Review <a href=\"https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/aws/sagemaker/TrainingJobs;prefix=tensorflow-training-2020-04-03-22-08-48-700;streamFilter=typeLogStreamPrefix\">CloudWatch Logs</a> After About 5 Minutes</b>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand All @@ -463,13 +467,13 @@
},
{
"cell_type": "code",
"execution_count": 304,
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<b>Review <a href=\"https://s3.console.aws.amazon.com/s3/buckets/sagemaker-us-east-1-835319576252/models/tf2-bert/tensorflow-training-2020-04-03-06-33-19-845/?region=us-east-1&tab=overview\">S3 Output Data</a> After The Training Job Has Completed</b>"
"<b>Review <a href=\"https://s3.console.aws.amazon.com/s3/buckets/sagemaker-us-east-1-835319576252/models/tf2-bert/tensorflow-training-2020-04-03-22-08-48-700/?region=us-east-1&tab=overview\">S3 Output Data</a> After The Training Job Has Completed</b>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand Down
3 changes: 2 additions & 1 deletion 06_train/src_bert_tf2/test-local.sh
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
SM_INPUT_DATA_CONFIG={\"train\":{\"TrainingInputMode\":\"Pipe\"}} SM_CURRENT_HOST=blah SM_NUM_GPUS=0 SM_HOSTS={\"hosts\":\"blah\"} SM_CHANNEL_TRAIN=data/train SM_CHANNEL_VALIDATION=data/validation SM_MODEL_DIR=. python tf_bert_reviews.py # TODO: --model-type=bert --model-name=bert-base-uncased
#SM_INPUT_DATA_CONFIG={\"train\":{\"TrainingInputMode\":\"Pipe\"}}
SM_CURRENT_HOST=blah SM_NUM_GPUS=0 SM_HOSTS={\"hosts\":\"blah\"} SM_CHANNEL_TRAIN=data/train SM_CHANNEL_VALIDATION=data/validation SM_MODEL_DIR=. python tf_bert_reviews.py # TODO: --model-type=bert --model-name=bert-base-uncased
Loading

0 comments on commit 0ada93e

Please sign in to comment.