[go: nahoru, domu]

Skip to content

Commit

Permalink
Change AllReduceSimplifier to handle trivial cross-partition all-reduces.
Browse files Browse the repository at this point in the history

This CL ensures that AllReduceSimplifier can simplify trivial all-reduces (all-reduces in which each subgroup consists of a single participant) that are not necessarily cross-replica (for example, a cross-partition all-reduce). We only simplify non-cross-replica all-reduces when the module is SPMD.

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#14073 from apivovarov:select_compare_algsimp 6fe68d7319b272ff041b67e038359540cddda489
PiperOrigin-RevId: 644267887
  • Loading branch information
tensorflower-gardener committed Jun 26, 2024
1 parent 633e9cc commit cde7a2e
Show file tree
Hide file tree
Showing 55 changed files with 790 additions and 262 deletions.
93 changes: 55 additions & 38 deletions tensorflow/compiler/mlir/lite/quantization/lite/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ package(
licenses = ["notice"],
)

# Export every .bin test model under testdata/ as a file label.
# The tests in this package refer to these files through `data` entries and
# `$(location ...)` args. Because they are exported, targets in other packages
# can refer to them by label as well.
exports_files(glob([
"testdata/*.bin",
]))

package_group(
name = "friends",
packages = [
Expand Down Expand Up @@ -123,52 +127,52 @@ tf_cc_test(
name = "quantize_model_test",
srcs = ["quantize_model_test.cc"],
args = [
"--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
"--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)",
],
data = [
"//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin",
"//tensorflow/lite/tools/optimize:testdata/argmax.bin",
"//tensorflow/lite/tools/optimize:testdata/broadcast_to.bin",
"//tensorflow/lite/tools/optimize:testdata/concat.bin",
"//tensorflow/lite/tools/optimize:testdata/fc.bin",
"//tensorflow/lite/tools/optimize:testdata/fc_qat.bin",
"//tensorflow/lite/tools/optimize:testdata/gather_nd.bin",
"//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin",
"//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin",
"//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin",
"//tensorflow/lite/tools/optimize:testdata/lstm_quantized2.bin",
"//tensorflow/lite/tools/optimize:testdata/maximum.bin",
"//tensorflow/lite/tools/optimize:testdata/minimum.bin",
"//tensorflow/lite/tools/optimize:testdata/mixed.bin",
"//tensorflow/lite/tools/optimize:testdata/mixed16x8.bin",
"//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin",
"//tensorflow/lite/tools/optimize:testdata/pack.bin",
"//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin",
"//tensorflow/lite/tools/optimize:testdata/single_conv_no_bias.bin",
"//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
"//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin",
"//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin",
"//tensorflow/lite/tools/optimize:testdata/split.bin",
"//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin",
"//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin",
"//tensorflow/lite/tools/optimize:testdata/transpose.bin",
"//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_calibrated.bin",
"//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_quantized.bin",
"//tensorflow/lite/tools/optimize:testdata/unpack.bin",
"//tensorflow/lite/tools/optimize:testdata/where.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/add_with_const_input.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/argmax.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/broadcast_to.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/concat.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc_qat.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/gather_nd.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated2.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized2.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/maximum.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/minimum.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed16x8.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/multi_input_add_reshape.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/pack.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_avg_pool_min_minus_5_max_plus_5.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_no_bias.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_minus_127_max_plus_127.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_softmax_min_minus_5_max_plus_5.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/split.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_calibrated.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_quantized.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/transpose.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_calibrated.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_quantized.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unpack.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/where.bin",
],
tags = [
"tflite_not_portable_android",
"tflite_not_portable_ios",
],
deps = [
":quantize_model",
":test_util",
"//tensorflow/compiler/mlir/lite/schema:schema_fbs",
"//tensorflow/compiler/mlir/lite/schema:schema_utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/lite:framework",
"//tensorflow/lite/tools/optimize:test_util",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest",
Expand All @@ -181,13 +185,13 @@ tf_cc_test(
name = "quantize_weights_test",
srcs = ["quantize_weights_test.cc"],
args = [
"--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
"--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)",
],
data = [
"//tensorflow/lite/tools/optimize:testdata/custom_op.bin",
"//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin",
"//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
"//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/custom_op.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/quantized_with_gather.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin",
"//tensorflow/compiler/mlir/lite/quantization/lite:testdata/weight_shared_between_convs.bin",
],
tags = [
# TODO(b/327796566): re-enable after the bug is fixed
Expand All @@ -200,15 +204,28 @@ tf_cc_test(
],
deps = [
":quantize_weights",
":test_util",
"//tensorflow/compiler/mlir/lite/schema:schema_fbs",
"//tensorflow/compiler/mlir/lite/schema:schema_utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/lite:framework",
"//tensorflow/lite/tools/optimize:test_util",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest",
"@flatbuffers",
"@local_tsl//tsl/platform:logging",
],
)

# Shared C++ test helpers (test_util.h / test_util.cc) for the quantization
# tests in this package. Both quantize_model_test and quantize_weights_test
# list ":test_util" in their deps.
cc_library(
name = "test_util",
testonly = 1,  # Only testonly targets (e.g. tests) may depend on this library.
srcs = ["test_util.cc"],
hdrs = ["test_util.h"],
deps = [
"//tensorflow/lite:framework",
"//tensorflow/lite/core/api",
"@com_google_googletest//:gtest",
"@flatbuffers",
],
)
Loading

0 comments on commit cde7a2e

Please sign in to comment.