From 64f4a59d5e39b60d67047d5e0b82de0cbcc6c2df Mon Sep 17 00:00:00 2001
From: Ran Chen
Date: Mon, 6 Apr 2020 15:55:40 -0700
Subject: [PATCH] Add multi worker mirrored strategy combinations to enable easier testing again

The previous change was rolled back because it broke Windows builds.

This change adds has_chief and num_workers to NamedDistribution. If there is
more than one worker (chief + workers), the distribute flavor of the
combinations library runs the test in multiple processes to simulate a multi
worker setup. Users are required to call combinations.main() instead of
test.main().

Note that this is the same as running the test concurrently in multiple
processes, so you're expected to write your test as if you were writing a
multi client program. There's no way to get the return values from all
processes; if you need that, use multi_process_runner directly.

This is currently very slow. Do not use it for a large number of tests.

PiperOrigin-RevId: 305136390
Change-Id: I814323cccaee5b8b7eed3b4dfef2971ab9d09cb4
---
 .../saved_model/integration_tests/BUILD       |   1 +
 .../integration_tests/saved_model_test.py     |   7 +-
 tensorflow/python/distribute/BUILD            | 159 ++++++-----
 tensorflow/python/distribute/combinations.py  | 258 +++++++++++++++---
 .../python/distribute/combinations_test.py    | 151 ++++++++++
 .../distribute/strategy_combinations.py       |  13 +
 6 files changed, 481 insertions(+), 108 deletions(-)
 create mode 100644 tensorflow/python/distribute/combinations_test.py

diff --git a/tensorflow/examples/saved_model/integration_tests/BUILD b/tensorflow/examples/saved_model/integration_tests/BUILD
index 0e55a0af437755..4f55cfa3042d1f 100644
--- a/tensorflow/examples/saved_model/integration_tests/BUILD
+++ b/tensorflow/examples/saved_model/integration_tests/BUILD
@@ -64,6 +64,7 @@ cuda_py_test(
         ":distribution_strategy_utils",
         ":integration_scripts",
         "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_combinations",
         "//tensorflow/python/distribute:combinations",
         "@absl_py//absl/testing:parameterized",
     ],
diff --git a/tensorflow/examples/saved_model/integration_tests/saved_model_test.py b/tensorflow/examples/saved_model/integration_tests/saved_model_test.py
index d97b93418affdb..6333e55999e32c 100644
--- a/tensorflow/examples/saved_model/integration_tests/saved_model_test.py
+++ b/tensorflow/examples/saved_model/integration_tests/saved_model_test.py
@@ -25,7 +25,8 @@
 from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils
 from tensorflow.examples.saved_model.integration_tests import integration_scripts as scripts
-from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import combinations as distribute_combinations
+from tensorflow.python.framework import combinations


 class SavedModelTest(scripts.TestCase, parameterized.TestCase):
@@ -89,8 +90,8 @@ def test_text_embedding_in_dataset(self):
             retrain_flag_value=["true", "false"],
             regularization_loss_multiplier=[None, 2],  # Test for b/134528831.
         )),
-    test_combinations=(combinations.NamedGPUCombination(),
-                       combinations.NamedTPUCombination()))
+    test_combinations=(distribute_combinations.NamedGPUCombination(),
+                       distribute_combinations.NamedTPUCombination()))


 @combinations.generate(**TEST_MNIST_CNN_GENERATE_KWARGS)
 def test_mnist_cnn(self, use_keras_save_api, named_strategy,
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index d0a70f1829403c..55d5463527beb8 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -315,19 +315,19 @@ py_library(
     srcs = ["parameter_server_strategy.py"],
     visibility = ["//tensorflow:internal"],
     deps = [
+        ":cross_device_ops",
         ":input_lib",
         ":mirrored_run",
+        ":multi_worker_util",
         ":numpy_dataset",
+        ":reduce_util",
+        ":values",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:training",
         "//tensorflow/python:util",
-        "//tensorflow/python/distribute:cross_device_ops",
-        "//tensorflow/python/distribute:multi_worker_util",
-        "//tensorflow/python/distribute:reduce_util",
-        "//tensorflow/python/distribute:values",
         "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:context",
     ],
@@ -369,18 +369,18 @@ py_library(
     srcs = ["collective_all_reduce_strategy.py"],
     visibility = ["//tensorflow:internal"],
    deps = [
+        ":cross_device_ops",
+        ":cross_device_utils",
+        ":input_lib",
         ":mirrored_strategy",
+        ":multi_worker_util",
+        ":numpy_dataset",
+        ":values",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:collective_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:training",
-        "//tensorflow/python/distribute:cross_device_ops",
-        "//tensorflow/python/distribute:cross_device_utils",
-        "//tensorflow/python/distribute:input_lib",
-        "//tensorflow/python/distribute:multi_worker_util",
-        "//tensorflow/python/distribute:numpy_dataset",
-        "//tensorflow/python/distribute:values",
         "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:context",
     ],
@@ -445,16 +445,16 @@ cuda_py_test(
     ],
     deps = [
         ":collective_all_reduce_strategy",
+        ":combinations",
         ":multi_process_runner",
+        ":multi_worker_test_base",
         ":reduce_util",
+        ":strategy_combinations",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:multi_worker_test_base",
-        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
         "//third_party/py/numpy",
@@ -707,9 +707,26 @@ py_library(
     srcs = ["combinations.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":multi_process_runner",
+        ":multi_worker_test_base",
         "//tensorflow/python:framework_combinations",
         "//tensorflow/python:framework_test_combinations_lib",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:util",
         "//tensorflow/python/eager:context",
+        "@six_archive//:six",
+    ],
+)
+
+py_test(
+    name = "combinations_test",
+    srcs = ["combinations_test.py"],
+    deps = [
+        ":combinations",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_combinations",
+        "//tensorflow/python/distribute/cluster_resolver:tfconfig_cluster_resolver_py",
+        "@absl_py//absl/testing:parameterized",
     ],
 )

@@ -719,19 +736,24 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":central_storage_strategy",
+        ":collective_all_reduce_strategy",
         ":combinations",
         ":distribute_lib",
         ":mirrored_strategy",
+        ":multi_process_runner",
+        ":multi_worker_test_base",
         ":one_device_strategy",
         ":tpu_strategy",
         "//tensorflow/python:config",
-        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:tf2",
         "//tensorflow/python:training",
-        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
+        "//tensorflow/python:util",
+        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:remote",
         "//tensorflow/python/keras/optimizer_v2",
-        "@absl_py//absl/testing:parameterized",
+        "//tensorflow/python/tpu:tpu_lib",
     ],
 )
@@ -743,9 +765,10 @@ py_test(
         ":combinations",
         ":reduce_util",
         ":strategy_combinations",
+        "//tensorflow/python:client_testlib",
         "//tensorflow/python:config",
         "//tensorflow/python:constant_op",
-        "//tensorflow/python/eager:test",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
@@ -787,13 +810,13 @@ cuda_py_test(
         "multi_and_single_gpu",
     ],
     deps = [
+        ":combinations",
+        ":strategy_combinations",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:training",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
-        "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
     ],
 )
@@ -805,10 +828,10 @@ distribute_py_test(
         "multi_and_single_gpu",
     ],
     deps = [
+        ":combinations",
+        ":strategy_combinations",
         ":tpu_strategy",
         "//tensorflow/compiler/tests:xla_test",
-        "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/eager:test",
         "//tensorflow/python/keras",
         "//tensorflow/python/training/tracking:util",
@@ -827,19 +850,19 @@ distribute_py_test(
     ],
     deps = [
         ":collective_all_reduce_strategy",
+        ":combinations",
+        ":input_lib",
         ":mirrored_strategy",
+        ":multi_worker_test_base",
         ":reduce_util",
+        ":strategy_combinations",
+        ":values",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:errors",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:sparse_ops",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:input_lib",
-        "//tensorflow/python/distribute:multi_worker_test_base",
-        "//tensorflow/python/distribute:strategy_combinations",
-        "//tensorflow/python/distribute:values",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
         "//tensorflow/python/ops/ragged:ragged_tensor",
@@ -852,14 +875,14 @@ cuda_py_test(
     name = "cross_device_utils_test",
     srcs = ["cross_device_utils_test.py"],
     deps = [
+        ":combinations",
+        ":cross_device_utils",
+        ":strategy_combinations",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
-        "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:cross_device_utils",
-        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
         "@absl_py//absl/testing:parameterized",
@@ -877,16 +900,16 @@ cuda_py_test(
     deps = [
         ":collective_all_reduce_strategy",
         ":collective_util",
+        ":combinations",
+        ":cross_device_ops",
         ":mirrored_strategy",
+        ":multi_worker_test_base",
":strategy_combinations", + ":values", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", - "//tensorflow/python/distribute:combinations", - "//tensorflow/python/distribute:cross_device_ops", - "//tensorflow/python/distribute:multi_worker_test_base", - "//tensorflow/python/distribute:strategy_combinations", - "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", "@absl_py//absl/testing:parameterized", @@ -898,9 +921,9 @@ cuda_py_test( srcs = ["one_device_strategy_test.py"], grpc_enabled = True, deps = [ + ":combinations", + ":strategy_combinations", ":strategy_test_lib", - "//tensorflow/python/distribute:combinations", - "//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/eager:test", ], ) @@ -936,6 +959,9 @@ py_library( srcs = ["strategy_test_lib.py"], srcs_version = "PY2AND3", deps = [ + ":distribute_lib", + ":reduce_util", + ":values", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", @@ -949,9 +975,6 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/distribute:distribute_lib", - "//tensorflow/python/distribute:reduce_util", - "//tensorflow/python/distribute:values", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", @@ -973,8 +996,13 @@ distribute_py_test( "no_oss", # Target too big to run serially reliably. ], deps = [ + ":combinations", + ":device_util", + ":distribute_lib", ":mirrored_strategy", ":parameter_server_strategy", + ":strategy_combinations", + ":values", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", @@ -983,11 +1011,6 @@ distribute_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", - "//tensorflow/python/distribute:combinations", - "//tensorflow/python/distribute:device_util", - "//tensorflow/python/distribute:distribute_lib", - "//tensorflow/python/distribute:strategy_combinations", - "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", "//tensorflow/python/saved_model/model_utils:mode_keys", @@ -1000,13 +1023,13 @@ distribute_py_test( srcs = ["moving_averages_test.py"], main = "moving_averages_test.py", deps = [ + ":combinations", + ":strategy_combinations", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:training", "//tensorflow/python:variables", - "//tensorflow/python/distribute:combinations", - "//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/eager:test", "@absl_py//absl/testing:parameterized", ], @@ -1020,10 +1043,10 @@ distribute_py_test( "multi_and_single_gpu", ], deps = [ + ":combinations", + ":strategy_combinations", "//tensorflow/python:errors", "//tensorflow/python:variables", - "//tensorflow/python/distribute:combinations", - "//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/eager:test", "@absl_py//absl/testing:parameterized", ], @@ -1041,11 +1064,11 @@ distribute_py_test( "no_oss", # Target too big to run serially reliably. 
     ],
     deps = [
+        ":combinations",
+        ":strategy_combinations",
         "//tensorflow/python:errors",
         "//tensorflow/python:variables",
         "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/eager:test",
         "//tensorflow/python/keras",
         "@absl_py//absl/testing:parameterized",
     ],
 )
@@ -1060,11 +1083,11 @@ distribute_py_test(
         "multi_and_single_gpu",
     ],
     deps = [
+        ":combinations",
+        ":strategy_combinations",
         "//tensorflow/python:errors",
         "//tensorflow/python:variables",
         "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/eager:test",
         "//tensorflow/python/keras",
         "@absl_py//absl/testing:parameterized",
     ],
 )
@@ -1125,8 +1148,10 @@ distribute_py_test(
         "notap",  # TODO(b/139815303): enable after this is fixed.
     ],
     deps = [
+        ":combinations",
         ":mirrored_strategy",
         ":single_loss_example",
+        ":strategy_combinations",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:control_flow_v2_toggles",
         "//tensorflow/python:framework_ops",
@@ -1134,8 +1159,6 @@ distribute_py_test(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
         "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
         "//tensorflow/python/keras/layers",
@@ -1178,11 +1201,11 @@ distribute_py_test(
         "multi_and_single_gpu",
     ],
     deps = [
+        ":combinations",
         ":single_loss_example",
+        ":strategy_combinations",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:variables",
-        "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
         "//third_party/py/numpy",
@@ -1198,13 +1221,13 @@ cuda_py_test(
         "multi_and_single_gpu",
     ],
     deps = [
+        ":combinations",
+        ":strategy_combinations",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:training",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
-        "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
     ],
 )
@@ -1213,9 +1236,11 @@ cuda_py_test(
     srcs = ["remote_mirrored_strategy_eager_test.py"],
     deps = [
         ":combinations",
+        ":distribute_lib",
         ":mirrored_strategy",
         ":multi_worker_test_base",
         ":strategy_test_lib",
+        ":values",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_util",
         "//tensorflow/python:variable_scope",
-        "//tensorflow/python/distribute:distribute_lib",
-        "//tensorflow/python/distribute:values",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
     ],
 )
@@ -1271,15 +1294,15 @@ cuda_py_test(
     deps = [
         ":collective_all_reduce_strategy",
         ":combinations",
+        ":distribute_lib",
         ":strategy_combinations",
+        ":values",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:state_ops",
         "//tensorflow/python:variable_scope",
-        "//tensorflow/python/distribute:distribute_lib",
-        "//tensorflow/python/distribute:values",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
"//tensorflow/python/keras/layers", @@ -1429,13 +1452,13 @@ distribute_py_test( "noguitar", # b/140755528 ], deps = [ + ":combinations", + ":strategy_combinations", "//tensorflow/python:keras_lib", "//tensorflow/python:platform_test", "//tensorflow/python:util", "//tensorflow/python/compat:v2_compat", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/distribute:combinations", - "//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/eager:test", ], ) @@ -1450,6 +1473,7 @@ cuda_py_test( deps = [ ":collective_all_reduce_strategy", ":combinations", + ":cross_device_utils", ":multi_worker_test_base", ":strategy_combinations", ":strategy_test_lib", @@ -1463,7 +1487,6 @@ cuda_py_test( "//tensorflow/python:init_ops", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", - "//tensorflow/python/distribute:cross_device_utils", "//tensorflow/python/eager:context", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/keras/layers", @@ -1486,9 +1509,11 @@ cuda_py_test( ":central_storage_strategy", ":combinations", ":multi_worker_test_base", + ":multi_worker_util", ":parameter_server_strategy", ":strategy_combinations", ":strategy_test_lib", + ":values", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1501,8 +1526,6 @@ cuda_py_test( "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", - "//tensorflow/python/distribute:multi_worker_util", - "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/keras/layers", @@ -1557,6 +1580,6 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python/distribute:multi_worker_util", + ":multi_worker_util", ], ) diff --git a/tensorflow/python/distribute/combinations.py b/tensorflow/python/distribute/combinations.py index 5f6779911c43ac..e9a45f0cc106e3 100644 --- a/tensorflow/python/distribute/combinations.py +++ b/tensorflow/python/distribute/combinations.py @@ -22,20 +22,27 @@ from __future__ import division from __future__ import print_function -import functools import sys +import types +import unittest +import six + +from tensorflow.python.distribute import multi_process_runner +from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.eager import context from tensorflow.python.framework import combinations as framework_combinations -from tensorflow.python.framework import test_combinations +from tensorflow.python.framework import test_combinations as combinations_lib from tensorflow.python.platform import flags +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect FLAGS = flags.FLAGS # TODO(rchao): Rename `distribution` parameter to `strategy` or # `distribute_strategy` in all tests. -class DistributionParameter(test_combinations.ParameterModifier): +class DistributionParameter(combinations_lib.ParameterModifier): """Transforms arguments of type `NamedDistribution`. Convert all arguments of type `NamedDistribution` to the value of their @@ -51,7 +58,32 @@ def modified_arguments(self, kwargs, requested_parameters): return distribution_arguments -class NamedGPUCombination(test_combinations.TestCombination): +class ClusterParameters(combinations_lib.ParameterModifier): + """Adds cluster parameters if a `NamedDistribution` has it. + + It needs to be before DistributionParameter. 
+ """ + + def modified_arguments(self, kwargs, requested_parameters): + strategy = None + for _, v in kwargs.items(): + if isinstance(v, NamedDistribution): + if strategy is not None and _num_total_workers(v.has_chief, + v.num_workers) > 1: + raise ValueError("Only support one NamedDistribution for multi worker" + "tests.") + strategy = v + # Always set cluster parameters if they're requested. So that generate() + # works when there's no startegy in the combinations. + update = {} + if "has_chief" in requested_parameters: + update["has_chief"] = strategy.has_chief if strategy else False + if "num_workers" in requested_parameters: + update["num_workers"] = strategy.num_workers if strategy else 1 + return update + + +class NamedGPUCombination(combinations_lib.TestCombination): """Enable tests to request GPU hardware and skip non-GPU combinations. This class expects test_combinations to be generated with `NamedDistribution` @@ -90,18 +122,20 @@ def should_execute_combination(self, kwargs): return (True, None) def parameter_modifiers(self): - return [test_combinations.OptionalParameter("required_gpus")] + return [combinations_lib.OptionalParameter("required_gpus")] class GPUCombination(NamedGPUCombination): """NamedGPUCombination that passes `tf.distribute.Strategy` to the tests.""" def parameter_modifiers(self): - return [DistributionParameter() - ] + NamedGPUCombination.parameter_modifiers(self) + return [ + ClusterParameters(), + DistributionParameter(), + ] + NamedGPUCombination.parameter_modifiers(self) -class NamedTPUCombination(test_combinations.TestCombination): +class NamedTPUCombination(combinations_lib.TestCombination): """Allow to request TPU hardware and skip non-TPU combinations. This class expects test_combinations to be generated with `NamedDistribution` @@ -158,9 +192,9 @@ def should_execute_combination(self, kwargs): def parameter_modifiers(self): return [ - test_combinations.OptionalParameter("required_tpus"), - test_combinations.OptionalParameter("required_tpu"), - test_combinations.OptionalParameter("use_cloud_tpu"), + combinations_lib.OptionalParameter("required_tpus"), + combinations_lib.OptionalParameter("required_tpu"), + combinations_lib.OptionalParameter("use_cloud_tpu"), ] @@ -168,38 +202,47 @@ class TPUCombination(NamedTPUCombination): """NamedTPUCombination that passes `tf.distribute.Strategy` to the tests.""" def parameter_modifiers(self): - return [DistributionParameter() - ] + NamedTPUCombination.parameter_modifiers(self) + return [ + ClusterParameters(), + DistributionParameter(), + ] + NamedTPUCombination.parameter_modifiers(self) class NamedDistribution(object): """Wraps a `tf.distribute.Strategy` and adds a name for test titles.""" - def __init__(self, name, distribution_fn, required_gpus=None, - required_tpu=False, use_cloud_tpu=False): + def __init__(self, + name, + distribution_fn, + required_gpus=None, + required_tpu=False, + use_cloud_tpu=False, + has_chief=False, + num_workers=1): + """Initialize NamedDistribution. + + Args: + name: Name that will be a part of the name of the test case. + distribution_fn: A callable that creates a `tf.distribute.Strategy`. + required_gpus: The number of GPUs that the strategy requires. + required_tpu: Whether the strategy requires TPU. + use_cloud_tpu: Whether the strategy requires cloud TPU. + has_chief: Whether the strategy requires a chief worker. + num_workers: The number of workers that the strategy requires. 
+ """ object.__init__(self) self._name = name self._distribution_fn = distribution_fn - self._required_gpus = required_gpus - self._required_tpu = required_tpu - self._use_cloud_tpu = use_cloud_tpu + self.required_gpus = required_gpus + self.required_tpu = required_tpu + self.use_cloud_tpu = use_cloud_tpu + self.has_chief = has_chief + self.num_workers = num_workers @property def strategy(self): return self._distribution_fn() - @property - def required_gpus(self): - return self._required_gpus - - @property - def required_tpu(self): - return self._required_tpu - - @property - def use_cloud_tpu(self): - return self._use_cloud_tpu - def __repr__(self): return self._name @@ -212,11 +255,152 @@ def concat(*combined): return result -_defaults = framework_combinations.generate.keywords["test_combinations"] +def generate(combinations, test_combinations=()): + # pylint: disable=g-doc-args,g-doc-return-or-yield + """Distributed adapter of `framework.combinations_lib.generate`. + + All tests with distributed strategy should use this one instead of + `framework.test_combinations.generate`. This function has support of strategy + combinations, GPU/TPU and multi worker support. + + See `framework.test_combinations_lib.generate` for usage. + """ + # pylint: enable=g-doc-args,g-doc-return-or-yield + default_combinations = ( + framework_combinations.EagerGraphCombination(), + framework_combinations.TFVersionCombination(), + GPUCombination(), + TPUCombination(), + ) + # We apply our own decoration to handle multi worker tests before applying + # framework.test_combinations.generate. The order is important since we need + # framework.test_combinations.generate to apply all parameter modifiers first. + combination_decorator = combinations_lib.generate( + combinations, test_combinations=default_combinations + test_combinations) + + def decorator(test_method_or_class): + if isinstance(test_method_or_class, type): + # If it's a test class. + class_object = test_method_or_class + # Decorate each test method with _multi_worker_test. + for name, test_method in six.iteritems(class_object.__dict__.copy()): + if (name.startswith(unittest.TestLoader.testMethodPrefix) and + isinstance(test_method, types.FunctionType)): + setattr(class_object, name, _multi_worker_test(test_method)) + return combination_decorator(class_object) + else: + return combination_decorator(_multi_worker_test(test_method_or_class)) + + return decorator + + +combine = combinations_lib.combine +times = combinations_lib.times +NamedObject = combinations_lib.NamedObject + + +def main(): + """Tests must call this main().""" + return multi_process_runner.test_main() + + +# Identifies whether we're in the main process or worker processes. +# `_multi_worker_test` decoration behaves differently in the main processs and +# the worker processes. See the documentation of _multi_worker_test for detail. +_running_in_worker = False + + +def _test_runner(test_id): + """Executes the test with the given test_id. + + This is a simple wrapper around TestRunner to be used with + multi_process_runner. Similar to test.main(), but it executes only one test + specified by test_id and returns whether the test succeeds. If the test fails, + the function prints failures and errors to stdout. + + Args: + test_id: TestCase.id() + + Returns: + A boolean indicates whether the test succeeds. + """ + global _running_in_worker + # No need to restore the value of _running_in_worker since it should always be + # True in worker processes. 
+  _running_in_worker = True
+  test = unittest.defaultTestLoader.loadTestsFromName(test_id)
+  runner = unittest.TextTestRunner()
+  result = runner.run(test)
+  # Print failures and errors to stdout; multi_process_runner will collect
+  # them and stream them back to the main process.
+  for _, msg in result.failures + result.errors:
+    print(msg)
+  # Return expected failures as failures, so that the main process can get
+  # them and fail as expected.
+  if result.expectedFailures:
+    return False
+  return result.wasSuccessful()
+
+
+def _multi_worker_test(test_method):
+  """Decorates test_method so that it runs in each worker.
+
+  We use `multi_process_runner` to simulate multiple workers. Since we run
+  this function in the main process and all worker processes, this decoration
+  behaves differently in the main process and the worker processes. In the
+  main process, it spawns subprocesses and runs the test on each of them; in a
+  worker process, it executes the test in the same way as a normal test, e.g.
+  setUp()/tearDown() are called before/after the test.
+
+  Args:
+    test_method: a function which must be a test method.
+
+  Returns:
+    Decorated `test_method`. Note that the decorated function has additional
+    arguments.
+  """

-generate = functools.partial(
-    framework_combinations.generate,
-    test_combinations=_defaults + (GPUCombination(), TPUCombination()))
-combine = framework_combinations.combine
-times = framework_combinations.times
-NamedObject = framework_combinations.NamedObject
+  def decorator(self, has_chief, num_workers, **kwargs):
+    if _num_total_workers(has_chief, num_workers) == 1 or _running_in_worker:
+      # We're in a worker process or the test is for a single worker. In
+      # either case we execute the test method directly instead of spawning
+      # subprocesses.
+      test_method(self, **kwargs)
+      return
+
+    # We're in the main process. We spawn subprocesses and run the *test* on
+    # each of them. Note that we're not directly executing the test_method
+    # passed to _multi_worker_test, because we need setUp()/tearDown() to be
+    # called and all the decorations on the test method. The conceptual call
+    # stack is:
+    #   [main process]test.main()
+    #     [main process]test_runner.run(test)
+    #       [main process]wrapper by combinations.generate()
+    #         [main process]_multi_worker_test.decorator()
+    #           # A sub process goes through the same code path as the main
+    #           # process.
+    #           [sub process]_test_runner()
+    #             [sub process]test_runner.run(test)
+    #               [sub process]wrapper by combinations.generate()
+    #                 [sub process]_multi_worker_test.decorator()
+    #                   # _running_in_worker is True
+    #                   [sub process]test_method()
+    test_id = self.id()
+    cluster_spec = multi_worker_test_base.create_cluster_spec(
+        has_chief=has_chief, num_workers=num_workers, num_ps=0, has_eval=False)
+    result = multi_process_runner.run(
+        _test_runner, cluster_spec, args=(test_id,))
+    for was_successful in result.return_value:
+      if not was_successful:
+        raise AssertionError("some worker failed, see logs for details")
+
+  argspec = tf_inspect.getfullargspec(test_method)
+  decorator_args = (argspec.args or []) + ["has_chief", "num_workers"]
+  decorator_argspec = argspec._replace(args=decorator_args)
+  return tf_decorator.make_decorator(
+      test_method, decorator, decorator_argspec=decorator_argspec)
+
+
+def _num_total_workers(has_chief, num_workers):
+  """Returns the number of workers including the chief."""
+  if has_chief:
+    return num_workers + 1
+  return num_workers
diff --git a/tensorflow/python/distribute/combinations_test.py b/tensorflow/python/distribute/combinations_test.py
new file mode 100644
index 00000000000000..7033e3e3b337a9
--- /dev/null
+++ b/tensorflow/python/distribute/combinations_test.py
@@ -0,0 +1,151 @@
+# Lint as: python3
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.python.distribute.combinations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+from absl.testing import parameterized
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver
+from tensorflow.python.framework import combinations as framework_combinations
+from tensorflow.python.platform import test
+
+
+class ClusterParametersTest(test.TestCase, parameterized.TestCase):
+  # For this test we need to use `framework.test_combinations` because our
+  # `generate` eats the cluster parameters.
+  #
+  # Note that we don't have a standalone combination for ClusterParameters, so
+  # we should use GPUCombination which contains it.
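+  # (GPUCombination lists ClusterParameters among its parameter_modifiers(),
+  # alongside DistributionParameter; see combinations.py in this change.)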
+
+  @framework_combinations.generate(
+      framework_combinations.combine(distribution=[
+          combinations.NamedDistribution(
+              "HasClusterParams", lambda: None, has_chief=True, num_workers=2),
+      ]),
+      test_combinations=(combinations.GPUCombination(),))
+  def testClusterParams(self, distribution, has_chief, num_workers):
+    self.assertTrue(has_chief)
+    self.assertEqual(num_workers, 2)
+
+  @framework_combinations.generate(
+      framework_combinations.combine(distribution=[
+          combinations.NamedDistribution("NoClusterParams", lambda: None),
+      ]),
+      test_combinations=(combinations.GPUCombination(),))
+  def testClusterParamsHasDefault(self, distribution, has_chief, num_workers):
+    self.assertFalse(has_chief)
+    self.assertEqual(num_workers, 1)
+
+  @framework_combinations.generate(
+      framework_combinations.combine(v=1),
+      test_combinations=(combinations.GPUCombination(),))
+  def testClusterParamsNoStrategy(self, v, has_chief, num_workers):
+    self.assertFalse(has_chief)
+    self.assertEqual(num_workers, 1)
+
+  @framework_combinations.generate(
+      framework_combinations.combine(distribution=[
+          combinations.NamedDistribution(
+              "WithClusterParams", lambda: None, has_chief=True,
+              num_workers=2),
+          combinations.NamedDistribution("WithoutClusterParams", lambda: None),
+      ]),
+      test_combinations=(combinations.GPUCombination(),))
+  def testClusterParamsAreOptional(self, distribution):
+    # If the combinations library doesn't raise an exception, the test passes.
+    pass
+
+  @framework_combinations.generate(
+      framework_combinations.combine(
+          ds1=combinations.NamedDistribution(
+              "Strategy1", lambda: None, has_chief=True, num_workers=0),
+          ds2=combinations.NamedDistribution(
+              "Strategy2", lambda: None, has_chief=False, num_workers=1),
+          ds3=combinations.NamedDistribution(
+              "Strategy3", lambda: None, has_chief=True, num_workers=0),
+      ),
+      test_combinations=(combinations.GPUCombination(),))
+  def testMultipleDistributionSingleWorker(self, ds1, ds2, ds3):
+    # If the combinations library doesn't raise an exception, the test passes.
+    pass
+
+
+# unittest.expectedFailure doesn't work with parameterized test methods, so we
+# have to decorate the class instead.
+@unittest.expectedFailure
+class ClusterParametersShouldFailTest(test.TestCase, parameterized.TestCase):
+
+  @framework_combinations.generate(
+      framework_combinations.combine(
+          ds1=combinations.NamedDistribution(
+              "Strategy1", lambda: None, has_chief=True, num_workers=2),
+          ds2=combinations.NamedDistribution(
+              "Strategy2", lambda: None, has_chief=True, num_workers=2),
+      ),
+      test_combinations=(combinations.GPUCombination(),))
+  def testMultipleDistributionMultiWorker(self, ds1, ds2):
+    # The combinations library should raise an exception.
+    pass
+
+
+# Tests that we *actually* run the test method in multiple workers instead of
+# just passing silently. More importantly, it verifies that the test can fail.
+# Note that unittest.expectedFailure doesn't work with parameterized test
+# methods, so we have to decorate the class instead.
+@unittest.expectedFailure
+class CombinationsExpectedFailureTest(test.TestCase, parameterized.TestCase):
+
+  @combinations.generate(
+      combinations.combine(distribution=[
+          combinations.NamedDistribution(
+              "OneChiefOneWorker", lambda: None, has_chief=True,
+              num_workers=1),
+          combinations.NamedDistribution(
+              "TwoWorkers", lambda: None, has_chief=False, num_workers=2),
+      ]))
+  def testMultiWorkerCanFail(self, distribution):
+    resolver = tfconfig_cluster_resolver.TFConfigClusterResolver()
+    # This should fail.
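+    # multi_process_runner sets a TF_CONFIG with a task index in each spawned
+    # process, so task_id is non-None in the workers and the assertion fails
+    # there, as intended.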
+    self.assertIsNone(resolver.task_id)
+
+
+# Tests that we *actually* run the test method in multiple workers instead of
+# just passing silently. More importantly, it verifies that the test can fail.
+# Note that unittest.expectedFailure doesn't work with parameterized test
+# methods, so we have to decorate the class instead.
+@unittest.expectedFailure
+@combinations.generate(
+    combinations.combine(distribution=[
+        combinations.NamedDistribution(
+            "OneChiefOneWorker", lambda: None, has_chief=True, num_workers=1),
+        combinations.NamedDistribution(
+            "TwoWorkers", lambda: None, has_chief=False, num_workers=2),
+    ]))
+class CombinationsOnClassMultiWorkerExpectedFailureTest(test.TestCase,
+                                                        parameterized.TestCase):
+
+  def test(self, distribution):
+    resolver = tfconfig_cluster_resolver.TFConfigClusterResolver()
+    # This should fail.
+    self.assertIsNone(resolver.task_id)
+
+
+if __name__ == "__main__":
+  combinations.main()
diff --git a/tensorflow/python/distribute/strategy_combinations.py b/tensorflow/python/distribute/strategy_combinations.py
index 72e4324d0b665f..8d721698d5c6cb 100644
--- a/tensorflow/python/distribute/strategy_combinations.py
+++ b/tensorflow/python/distribute/strategy_combinations.py
@@ -20,6 +20,7 @@
 from tensorflow.python import tf2
 from tensorflow.python.distribute import central_storage_strategy
+from tensorflow.python.distribute import collective_all_reduce_strategy
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import mirrored_strategy as mirrored_lib
@@ -148,6 +149,18 @@ def _create_tpu_strategy():
     lambda: central_storage_strategy.CentralStorageStrategy(
         ["/gpu:0", "/cpu:0"]),
     required_gpus=1)
+multi_worker_mirrored_two_workers = combinations.NamedDistribution(
+    "MultiWorkerMirroredTwoWorkers",
+    collective_all_reduce_strategy.CollectiveAllReduceStrategy,
+    has_chief=False,
+    num_workers=2,
+)
+multi_worker_mirrored_one_chief_one_worker = combinations.NamedDistribution(
+    "MultiWorkerMirroredOneChiefOneWorker",
+    collective_all_reduce_strategy.CollectiveAllReduceStrategy,
+    has_chief=True,
+    num_workers=1,
+)

 gradient_descent_optimizer_v1_fn = combinations.NamedObject(
     "GradientDescentV1",