From 0018db366526b9a28a7e3bae6ed53b4af9d23a89 Mon Sep 17 00:00:00 2001 From: redwrasse Date: Tue, 26 Mar 2024 20:42:28 -0700 Subject: [PATCH 001/124] add num_results + skip overflow check test --- tensorflow/python/ops/sobol_ops_test.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/sobol_ops_test.py b/tensorflow/python/ops/sobol_ops_test.py index a795cbdcf4f06e..c7517186b0e13f 100644 --- a/tensorflow/python/ops/sobol_ops_test.py +++ b/tensorflow/python/ops/sobol_ops_test.py @@ -140,7 +140,7 @@ def test_non_scalar_input(self): skip=constant_op.constant([1]))) @test_util.run_in_graph_and_eager_modes - def testDimNumResultsOverflow(self): + def test_dim_num_results_overflow(self): with self.assertRaisesRegex( (ValueError, errors.InvalidArgumentError), r'num_results\*dim must be less than 2147483647'): @@ -148,6 +148,15 @@ def testDimNumResultsOverflow(self): gen_math_ops.sobol_sample( dim=2560, num_results=16384000, skip=0, dtype=dtypes.float32)) + @test_util.run_in_graph_and_eager_modes + def test_num_results_skip_overflow(self): + with self.assertRaisesRegex( + (ValueError, errors.InvalidArgumentError), + r'num_results\+skip must be less than 2147483647'): + self.evaluate( + gen_math_ops.sobol_sample( + dim=1, num_results=1, skip=2147483647, dtype=dtypes.float32)) + if __name__ == '__main__': googletest.main() From d6e5325809bac99b37bb8d8f13c66d694cecdf34 Mon Sep 17 00:00:00 2001 From: RJ Ascani Date: Thu, 28 Mar 2024 00:41:41 +0000 Subject: [PATCH 002/124] shlo_ref: Move template specializations GCC gives compilation errors that these template specializations are in a non-namespace scope. This PR moves the template specializations out of the struct scope and into a namespace scope. --- tensorflow/lite/experimental/shlo/ops/cbrt.cc | 18 +++++++++--------- .../lite/experimental/shlo/ops/cbrt_test.cc | 18 +++++++++--------- tensorflow/lite/experimental/shlo/ops/ceil.cc | 18 +++++++++--------- .../lite/experimental/shlo/ops/ceil_test.cc | 18 +++++++++--------- .../lite/experimental/shlo/ops/cosine.cc | 18 +++++++++--------- .../lite/experimental/shlo/ops/cosine_test.cc | 18 +++++++++--------- .../lite/experimental/shlo/ops/exponential.cc | 18 +++++++++--------- .../shlo/ops/exponential_minus_one.cc | 18 +++++++++--------- .../shlo/ops/exponential_minus_one_test.cc | 18 +++++++++--------- .../experimental/shlo/ops/exponential_test.cc | 18 +++++++++--------- .../lite/experimental/shlo/ops/floor.cc | 18 +++++++++--------- .../lite/experimental/shlo/ops/floor_test.cc | 18 +++++++++--------- tensorflow/lite/experimental/shlo/ops/log.cc | 18 +++++++++--------- .../experimental/shlo/ops/log_plus_one.cc | 18 +++++++++--------- .../shlo/ops/log_plus_one_test.cc | 18 +++++++++--------- .../lite/experimental/shlo/ops/log_test.cc | 18 +++++++++--------- .../lite/experimental/shlo/ops/logistic.cc | 18 +++++++++--------- .../experimental/shlo/ops/logistic_test.cc | 18 +++++++++--------- tensorflow/lite/experimental/shlo/ops/not.cc | 9 +++++---- .../lite/experimental/shlo/ops/not_test.cc | 9 +++++---- tensorflow/lite/experimental/shlo/ops/sign.cc | 18 +++++++++--------- .../lite/experimental/shlo/ops/sign_test.cc | 18 +++++++++--------- tensorflow/lite/experimental/shlo/ops/sine.cc | 18 +++++++++--------- .../lite/experimental/shlo/ops/sine_test.cc | 18 +++++++++--------- tensorflow/lite/experimental/shlo/ops/sqrt.cc | 18 +++++++++--------- .../lite/experimental/shlo/ops/sqrt_test.cc | 19 ++++++++++--------- tensorflow/lite/experimental/shlo/ops/tanh.cc | 18 +++++++++--------- .../lite/experimental/shlo/ops/tanh_test.cc | 18 +++++++++--------- 28 files changed, 245 insertions(+), 242 deletions(-) diff --git a/tensorflow/lite/experimental/shlo/ops/cbrt.cc b/tensorflow/lite/experimental/shlo/ops/cbrt.cc index 2e50c92c2e5998..076e5175436a5e 100644 --- a/tensorflow/lite/experimental/shlo/ops/cbrt.cc +++ b/tensorflow/lite/experimental/shlo/ops/cbrt.cc @@ -32,17 +32,17 @@ struct Cbrt { T operator()(T v) const { return std::cbrt(v); } +}; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Cbrt::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -}; +template <> +BF16 Cbrt::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} CbrtOp Create(CbrtOp::Attributes) { return {}; } diff --git a/tensorflow/lite/experimental/shlo/ops/cbrt_test.cc b/tensorflow/lite/experimental/shlo/ops/cbrt_test.cc index 687e3cb7debb15..b62edcfebbfb7c 100644 --- a/tensorflow/lite/experimental/shlo/ops/cbrt_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/cbrt_test.cc @@ -48,17 +48,17 @@ struct Cbrt { T operator()(T v) const { return std::cbrt(v); } +} cbrt_ref; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Cbrt::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -} cbrt_ref; +template <> +BF16 Cbrt::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} INSTANTIATE_TYPED_TEST_SUITE_P(Cbrt, UnaryElementwiseOpShapePropagationTest, CbrtOp, TestParamNames); diff --git a/tensorflow/lite/experimental/shlo/ops/ceil.cc b/tensorflow/lite/experimental/shlo/ops/ceil.cc index a6b501131db5f9..95f96a38bafc4a 100644 --- a/tensorflow/lite/experimental/shlo/ops/ceil.cc +++ b/tensorflow/lite/experimental/shlo/ops/ceil.cc @@ -33,17 +33,17 @@ struct Ceil { T operator()(T v) const { return std::ceil(v); } +}; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Ceil::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -}; +template <> +BF16 Ceil::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} CeilOp Create(CeilOp::Attributes) { return {}; } diff --git a/tensorflow/lite/experimental/shlo/ops/ceil_test.cc b/tensorflow/lite/experimental/shlo/ops/ceil_test.cc index 4059b19bcca63c..2e3e6288c14f4f 100644 --- a/tensorflow/lite/experimental/shlo/ops/ceil_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/ceil_test.cc @@ -48,17 +48,17 @@ struct Ceil { T operator()(T v) const { return std::ceil(v); } +} ceil_ref; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Ceil::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -} ceil_ref; +template <> +BF16 Ceil::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} INSTANTIATE_TYPED_TEST_SUITE_P(Ceil, UnaryElementwiseOpShapePropagationTest, CeilOp, TestParamNames); diff --git a/tensorflow/lite/experimental/shlo/ops/cosine.cc b/tensorflow/lite/experimental/shlo/ops/cosine.cc index 8b757f9709ef18..5bd347836af75e 100644 --- a/tensorflow/lite/experimental/shlo/ops/cosine.cc +++ b/tensorflow/lite/experimental/shlo/ops/cosine.cc @@ -32,17 +32,17 @@ struct Cosine { T operator()(T v) const { return std::cos(v); } +}; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Cosine::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -}; +template <> +BF16 Cosine::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} CosineOp Create(CosineOp::Attributes) { return {}; } diff --git a/tensorflow/lite/experimental/shlo/ops/cosine_test.cc b/tensorflow/lite/experimental/shlo/ops/cosine_test.cc index 41fce8a264dd57..9f858c9d6cc666 100644 --- a/tensorflow/lite/experimental/shlo/ops/cosine_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/cosine_test.cc @@ -48,17 +48,17 @@ struct Cosine { T operator()(T v) const { return std::cos(v); } +} cosine_ref; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Cosine::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -} cosine_ref; +template <> +BF16 Cosine::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} INSTANTIATE_TYPED_TEST_SUITE_P(Cosine, UnaryElementwiseOpShapePropagationTest, CosineOp, TestParamNames); diff --git a/tensorflow/lite/experimental/shlo/ops/exponential.cc b/tensorflow/lite/experimental/shlo/ops/exponential.cc index 8c8994ccc8b296..979ddef45889ff 100644 --- a/tensorflow/lite/experimental/shlo/ops/exponential.cc +++ b/tensorflow/lite/experimental/shlo/ops/exponential.cc @@ -32,17 +32,17 @@ struct Exponential { T operator()(T v) const { return std::exp(v); } +}; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Exponential::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -}; +template <> +BF16 Exponential::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} ExponentialOp Create(ExponentialOp::Attributes) { return {}; } diff --git a/tensorflow/lite/experimental/shlo/ops/exponential_minus_one.cc b/tensorflow/lite/experimental/shlo/ops/exponential_minus_one.cc index 57de5eb188df0f..a5bcab04280ba1 100644 --- a/tensorflow/lite/experimental/shlo/ops/exponential_minus_one.cc +++ b/tensorflow/lite/experimental/shlo/ops/exponential_minus_one.cc @@ -32,17 +32,17 @@ struct ExponentialMinusOne { T operator()(T v) const { return std::expm1(v); } +}; - template <> - F16 operator()(F16 v) const { - return F16(operator()(static_cast(v))); - } +template <> +F16 ExponentialMinusOne::operator()(F16 v) const { + return F16(operator()(static_cast(v))); +} - template <> - BF16 operator()(BF16 v) const { - return BF16(operator()(static_cast(v))); - } -}; +template <> +BF16 ExponentialMinusOne::operator()(BF16 v) const { + return BF16(operator()(static_cast(v))); +} ExponentialMinusOneOp Create(ExponentialMinusOneOp::Attributes) { return {}; } diff --git a/tensorflow/lite/experimental/shlo/ops/exponential_minus_one_test.cc b/tensorflow/lite/experimental/shlo/ops/exponential_minus_one_test.cc index b791350e49a304..0fbe259ecefafc 100644 --- a/tensorflow/lite/experimental/shlo/ops/exponential_minus_one_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/exponential_minus_one_test.cc @@ -48,17 +48,17 @@ struct ExponentialMinusOne { T operator()(T v) const { return std::expm1(v); } +} exponential_minus_one_ref; - template <> - F16 operator()(F16 v) const { - return F16(operator()(static_cast(v))); - } +template <> +F16 ExponentialMinusOne::operator()(F16 v) const { + return F16(operator()(static_cast(v))); +} - template <> - BF16 operator()(BF16 v) const { - return BF16(operator()(static_cast(v))); - } -} exponential_minus_one_ref; +template <> +BF16 ExponentialMinusOne::operator()(BF16 v) const { + return BF16(operator()(static_cast(v))); +} INSTANTIATE_TYPED_TEST_SUITE_P(ExponentialMinusOne, UnaryElementwiseOpShapePropagationTest, diff --git a/tensorflow/lite/experimental/shlo/ops/exponential_test.cc b/tensorflow/lite/experimental/shlo/ops/exponential_test.cc index 12a180a5a60826..f8cab0a7afc137 100644 --- a/tensorflow/lite/experimental/shlo/ops/exponential_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/exponential_test.cc @@ -48,17 +48,17 @@ struct Exponential { T operator()(T v) const { return std::exp(v); } +} exponential_ref; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Exponential::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -} exponential_ref; +template <> +BF16 Exponential::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} INSTANTIATE_TYPED_TEST_SUITE_P(Exponential, UnaryElementwiseOpShapePropagationTest, diff --git a/tensorflow/lite/experimental/shlo/ops/floor.cc b/tensorflow/lite/experimental/shlo/ops/floor.cc index 1a3c9e8efbb07f..7ef86a3cb53e93 100644 --- a/tensorflow/lite/experimental/shlo/ops/floor.cc +++ b/tensorflow/lite/experimental/shlo/ops/floor.cc @@ -32,17 +32,17 @@ struct Floor { T operator()(T v) const { return std::floor(v); } +}; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Floor::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -}; +template <> +BF16 Floor::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} FloorOp Create(FloorOp::Attributes) { return {}; } diff --git a/tensorflow/lite/experimental/shlo/ops/floor_test.cc b/tensorflow/lite/experimental/shlo/ops/floor_test.cc index 08ad8a8de670c5..bf0e19f0c10aeb 100644 --- a/tensorflow/lite/experimental/shlo/ops/floor_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/floor_test.cc @@ -48,17 +48,17 @@ struct Floor { T operator()(T v) const { return std::floor(v); } +} floor_ref; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Floor::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -} floor_ref; +template <> +BF16 Floor::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} INSTANTIATE_TYPED_TEST_SUITE_P(Floor, UnaryElementwiseOpShapePropagationTest, FloorOp, TestParamNames); diff --git a/tensorflow/lite/experimental/shlo/ops/log.cc b/tensorflow/lite/experimental/shlo/ops/log.cc index 1beca617bc5880..9f3f68ae8e7fdf 100644 --- a/tensorflow/lite/experimental/shlo/ops/log.cc +++ b/tensorflow/lite/experimental/shlo/ops/log.cc @@ -32,17 +32,17 @@ struct Log { T operator()(T v) const { return std::log(v); } +}; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Log::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -}; +template <> +BF16 Log::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} LogOp Create(LogOp::Attributes) { return {}; } diff --git a/tensorflow/lite/experimental/shlo/ops/log_plus_one.cc b/tensorflow/lite/experimental/shlo/ops/log_plus_one.cc index 66d3a7a11cc5b8..f80ebcce54d9fa 100644 --- a/tensorflow/lite/experimental/shlo/ops/log_plus_one.cc +++ b/tensorflow/lite/experimental/shlo/ops/log_plus_one.cc @@ -32,17 +32,17 @@ struct LogPlusOne { T operator()(T v) const { return std::log1p(v); } +}; - template <> - F16 operator()(F16 v) const { - return F16(operator()(static_cast(v))); - } +template <> +F16 LogPlusOne::operator()(F16 v) const { + return F16(operator()(static_cast(v))); +} - template <> - BF16 operator()(BF16 v) const { - return BF16(operator()(static_cast(v))); - } -}; +template <> +BF16 LogPlusOne::operator()(BF16 v) const { + return BF16(operator()(static_cast(v))); +} LogPlusOneOp Create(LogPlusOneOp::Attributes) { return {}; } diff --git a/tensorflow/lite/experimental/shlo/ops/log_plus_one_test.cc b/tensorflow/lite/experimental/shlo/ops/log_plus_one_test.cc index f636b72cb23e3f..9e303cc812d310 100644 --- a/tensorflow/lite/experimental/shlo/ops/log_plus_one_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/log_plus_one_test.cc @@ -48,17 +48,17 @@ struct LogPlusOne { T operator()(T v) const { return std::log1p(v); } +} log_plus_one_ref; - template <> - F16 operator()(F16 v) const { - return F16(operator()(static_cast(v))); - } +template <> +F16 LogPlusOne::operator()(F16 v) const { + return F16(operator()(static_cast(v))); +} - template <> - BF16 operator()(BF16 v) const { - return BF16(operator()(static_cast(v))); - } -} log_plus_one_ref; +template <> +BF16 LogPlusOne::operator()(BF16 v) const { + return BF16(operator()(static_cast(v))); +} INSTANTIATE_TYPED_TEST_SUITE_P(LogPlusOne, UnaryElementwiseOpShapePropagationTest, diff --git a/tensorflow/lite/experimental/shlo/ops/log_test.cc b/tensorflow/lite/experimental/shlo/ops/log_test.cc index 2d5c45239bd0c5..5f2c59f147ee4e 100644 --- a/tensorflow/lite/experimental/shlo/ops/log_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/log_test.cc @@ -48,17 +48,17 @@ struct Log { T operator()(T v) const { return std::log(v); } +} log_ref; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Log::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -} log_ref; +template <> +BF16 Log::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} INSTANTIATE_TYPED_TEST_SUITE_P(Log, UnaryElementwiseOpShapePropagationTest, LogOp, TestParamNames); diff --git a/tensorflow/lite/experimental/shlo/ops/logistic.cc b/tensorflow/lite/experimental/shlo/ops/logistic.cc index 0174e4113e2c29..3953cfe3441810 100644 --- a/tensorflow/lite/experimental/shlo/ops/logistic.cc +++ b/tensorflow/lite/experimental/shlo/ops/logistic.cc @@ -33,17 +33,17 @@ struct Logistic { constexpr T one = static_cast(1); return one / (one + std::exp(-v)); } +}; - template <> - F16 operator()(F16 v) const { - return F16(operator()(static_cast(v))); - } +template <> +F16 Logistic::operator()(F16 v) const { + return F16(operator()(static_cast(v))); +} - template <> - BF16 operator()(BF16 v) const { - return BF16(operator()(static_cast(v))); - } -}; +template <> +BF16 Logistic::operator()(BF16 v) const { + return BF16(operator()(static_cast(v))); +} LogisticOp Create(LogisticOp::Attributes) { return {}; } diff --git a/tensorflow/lite/experimental/shlo/ops/logistic_test.cc b/tensorflow/lite/experimental/shlo/ops/logistic_test.cc index e11df7372a40c0..3d8014e33d133a 100644 --- a/tensorflow/lite/experimental/shlo/ops/logistic_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/logistic_test.cc @@ -49,17 +49,17 @@ struct Logistic { constexpr T one = static_cast(1); return one / (one + std::exp(-v)); } +} logistic_ref; - template <> - F16 operator()(F16 v) const { - return F16(operator()(static_cast(v))); - } +template <> +F16 Logistic::operator()(F16 v) const { + return F16(operator()(static_cast(v))); +} - template <> - BF16 operator()(BF16 v) const { - return BF16(operator()(static_cast(v))); - } -} logistic_ref; +template <> +BF16 Logistic::operator()(BF16 v) const { + return BF16(operator()(static_cast(v))); +} INSTANTIATE_TYPED_TEST_SUITE_P(Logistic, UnaryElementwiseOpShapePropagationTest, LogisticOp, TestParamNames); diff --git a/tensorflow/lite/experimental/shlo/ops/not.cc b/tensorflow/lite/experimental/shlo/ops/not.cc index 5029afc7260f41..b8d613309b9010 100644 --- a/tensorflow/lite/experimental/shlo/ops/not.cc +++ b/tensorflow/lite/experimental/shlo/ops/not.cc @@ -28,12 +28,13 @@ struct Not { T operator()(T v) const { return ~v; } - template <> - bool operator()(bool v) const { - return !v; - } }; +template <> +bool Not::operator()(bool v) const { + return !v; +} + NotOp Create(NotOp::Attributes) { return {}; } absl::Status Prepare(NotOp& op, const Tensor& input, Tensor& output) { diff --git a/tensorflow/lite/experimental/shlo/ops/not_test.cc b/tensorflow/lite/experimental/shlo/ops/not_test.cc index cc719a1badbf3f..6a036867895f84 100644 --- a/tensorflow/lite/experimental/shlo/ops/not_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/not_test.cc @@ -48,12 +48,13 @@ struct Not { T operator()(T v) const { return ~v; } - template <> - bool operator()(bool v) const { - return !v; - } } not_ref; +template <> +bool Not::operator()(bool v) const { + return !v; +} + INSTANTIATE_TYPED_TEST_SUITE_P(Not, UnaryElementwiseOpShapePropagationTest, NotOp, TestParamNames); diff --git a/tensorflow/lite/experimental/shlo/ops/sign.cc b/tensorflow/lite/experimental/shlo/ops/sign.cc index 6703650dc0c18c..197b87ba6bb3bc 100644 --- a/tensorflow/lite/experimental/shlo/ops/sign.cc +++ b/tensorflow/lite/experimental/shlo/ops/sign.cc @@ -32,17 +32,17 @@ struct Sign { constexpr T zero = static_cast(0); return v < zero ? -one : (v > zero ? one : v); } +}; - template <> - F16 operator()(F16 v) const { - return static_cast(operator()(static_cast(v))); - } +template <> +F16 Sign::operator()(F16 v) const { + return static_cast(operator()(static_cast(v))); +} - template <> - BF16 operator()(BF16 v) const { - return static_cast(operator()(static_cast(v))); - } -}; +template <> +BF16 Sign::operator()(BF16 v) const { + return static_cast(operator()(static_cast(v))); +} SignOp Create(SignOp::Attributes) { return {}; } diff --git a/tensorflow/lite/experimental/shlo/ops/sign_test.cc b/tensorflow/lite/experimental/shlo/ops/sign_test.cc index e786c185b9840d..67ec823482036a 100644 --- a/tensorflow/lite/experimental/shlo/ops/sign_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/sign_test.cc @@ -49,17 +49,17 @@ struct Sign { constexpr T zero = static_cast(0); return v < zero ? -one : (v > zero ? one : v); } +} sign_ref; - template <> - F16 operator()(F16 v) const { - return static_cast(operator()(static_cast(v))); - } +template <> +F16 Sign::operator()(F16 v) const { + return static_cast(operator()(static_cast(v))); +} - template <> - BF16 operator()(BF16 v) const { - return static_cast(operator()(static_cast(v))); - } -} sign_ref; +template <> +BF16 Sign::operator()(BF16 v) const { + return static_cast(operator()(static_cast(v))); +} INSTANTIATE_TYPED_TEST_SUITE_P(Sign, UnaryElementwiseOpShapePropagationTest, SignOp, TestParamNames); diff --git a/tensorflow/lite/experimental/shlo/ops/sine.cc b/tensorflow/lite/experimental/shlo/ops/sine.cc index 1d1228f18f0804..e69d98e07c1517 100644 --- a/tensorflow/lite/experimental/shlo/ops/sine.cc +++ b/tensorflow/lite/experimental/shlo/ops/sine.cc @@ -33,17 +33,17 @@ struct Sine { T operator()(T v) const { return std::sin(v); } +}; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Sine::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -}; +template <> +BF16 Sine::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} SineOp Create(SineOp::Attributes) { return {}; } diff --git a/tensorflow/lite/experimental/shlo/ops/sine_test.cc b/tensorflow/lite/experimental/shlo/ops/sine_test.cc index c82f2c570bb858..fa16dee3b1d27f 100644 --- a/tensorflow/lite/experimental/shlo/ops/sine_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/sine_test.cc @@ -48,17 +48,17 @@ struct Sine { T operator()(T v) const { return std::sin(v); } +} sine_ref; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Sine::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -} sine_ref; +template <> +BF16 Sine::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} INSTANTIATE_TYPED_TEST_SUITE_P(Sine, UnaryElementwiseOpShapePropagationTest, SineOp, TestParamNames); diff --git a/tensorflow/lite/experimental/shlo/ops/sqrt.cc b/tensorflow/lite/experimental/shlo/ops/sqrt.cc index e13ff7a3025aa1..f34841fc9cb874 100644 --- a/tensorflow/lite/experimental/shlo/ops/sqrt.cc +++ b/tensorflow/lite/experimental/shlo/ops/sqrt.cc @@ -32,17 +32,17 @@ struct Sqrt { T operator()(T v) const { return std::sqrt(v); } +}; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Sqrt::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -}; +template <> +BF16 Sqrt::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} SqrtOp Create(SqrtOp::Attributes) { return {}; } diff --git a/tensorflow/lite/experimental/shlo/ops/sqrt_test.cc b/tensorflow/lite/experimental/shlo/ops/sqrt_test.cc index 161a6c6882d4a4..937c871b93745f 100644 --- a/tensorflow/lite/experimental/shlo/ops/sqrt_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/sqrt_test.cc @@ -49,17 +49,18 @@ struct Sqrt { return std::sqrt(v); } - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } - - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } } sqrt_ref; +template <> +F16 Sqrt::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} + +template <> +BF16 Sqrt::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} + INSTANTIATE_TYPED_TEST_SUITE_P(Sqrt, UnaryElementwiseOpShapePropagationTest, SqrtOp, TestParamNames); diff --git a/tensorflow/lite/experimental/shlo/ops/tanh.cc b/tensorflow/lite/experimental/shlo/ops/tanh.cc index 3ba4c17a88dba6..d2518f6ba81b5f 100644 --- a/tensorflow/lite/experimental/shlo/ops/tanh.cc +++ b/tensorflow/lite/experimental/shlo/ops/tanh.cc @@ -33,17 +33,17 @@ struct Tanh { T operator()(T v) const { return std::tanh(v); } +}; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Tanh::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -}; +template <> +BF16 Tanh::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} TanhOp Create(TanhOp::Attributes) { return {}; } diff --git a/tensorflow/lite/experimental/shlo/ops/tanh_test.cc b/tensorflow/lite/experimental/shlo/ops/tanh_test.cc index d57e52d7318235..0343a087b65a62 100644 --- a/tensorflow/lite/experimental/shlo/ops/tanh_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/tanh_test.cc @@ -48,17 +48,17 @@ struct Tanh { T operator()(T v) const { return std::tanh(v); } +} tanh_ref; - template <> - F16 operator()(F16 val) const { - return F16(operator()(static_cast(val))); - } +template <> +F16 Tanh::operator()(F16 val) const { + return F16(operator()(static_cast(val))); +} - template <> - BF16 operator()(BF16 val) const { - return BF16(operator()(static_cast(val))); - } -} tanh_ref; +template <> +BF16 Tanh::operator()(BF16 val) const { + return BF16(operator()(static_cast(val))); +} INSTANTIATE_TYPED_TEST_SUITE_P(Tanh, UnaryElementwiseOpShapePropagationTest, TanhOp, TestParamNames); From d6475ffcc84bf3f94d671e90f933cc3f9b89b19c Mon Sep 17 00:00:00 2001 From: Vlad Sytchenko Date: Wed, 27 Mar 2024 22:16:12 -0700 Subject: [PATCH 003/124] If two calls to "ProductOfElementaryHouseholderReflectors" are in the same module with the same shape of the first operand, but different shapes of tau, one will end up expanded to the wrong computation. Append tau shape to the computation name when its there to avoid conflicts. PiperOrigin-RevId: 619788209 --- third_party/xla/xla/client/lib/qr_test.cc | 35 ++++++++++++++++++++++ third_party/xla/xla/service/qr_expander.cc | 5 +++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/client/lib/qr_test.cc b/third_party/xla/xla/client/lib/qr_test.cc index 1da9cab60a02d0..a21932b3e797e3 100644 --- a/third_party/xla/xla/client/lib/qr_test.cc +++ b/third_party/xla/xla/client/lib/qr_test.cc @@ -145,4 +145,39 @@ XLA_TEST_F(QrTest, SubnormalComplex) { xla::ErrorSpec(1e-4, 1e-4)); } +XLA_TEST_F(QrTest, DuplicateHouseholderExpansion) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a0_vals({ + {0, 1, 1}, + {1, 0, 1}, + {1, 1, 0}, + }); + xla::Array2D a1_vals({ + {1, 0}, + {0, 1}, + {1, 0}, + }); + + // Verifies that different computations are created to generate HouseHolder + // transformations with identical QR shapes, but different tau shapes. + // The first QR decomposition should generate a ([3,3], [3]) computation, + // the second should generate a ([3,3], [2]) computation. Mismatch will result + // in compilation failure. + + xla::XlaOp a0, q0, r0; + auto a0_data = CreateR2Parameter(a0_vals, 0, "a0", &builder, &a0); + xla::QrExplicit(a0, /*full_matrices=*/true, q0, r0); + + xla::XlaOp a1, q1, r1; + auto a1_data = CreateR2Parameter(a1_vals, 1, "a1", &builder, &a1); + xla::QrExplicit(a1, /*full_matrices=*/true, q1, r1); + + // Verifies that the decomposition composes back to the original matrix. + xla::BatchDot(q1, r1, xla::PrecisionConfig::HIGHEST); + + ComputeAndCompareR2(&builder, a1_vals, {a0_data.get(), a1_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + } // namespace diff --git a/third_party/xla/xla/service/qr_expander.cc b/third_party/xla/xla/service/qr_expander.cc index cc9b10b0905fec..e817b66b61d2c8 100644 --- a/third_party/xla/xla/service/qr_expander.cc +++ b/third_party/xla/xla/service/qr_expander.cc @@ -509,9 +509,12 @@ bool QrExpander::InstructionMatchesPattern(HloInstruction* instruction) { absl::StatusOr QrExpander::ExpandInstruction( HloInstruction* instruction) { - const std::string name = + std::string name = absl::StrFormat("xla.%s_%s", instruction->custom_call_target(), instruction->operand(0)->shape().ToString()); + if (instruction->custom_call_target() == kHouseholderProductCustomCallName) { + name += "_" + instruction->operand(1)->shape().ToString(); + } HloModule* module = instruction->GetModule(); From 573fc841781f593f40ac75e8f8de204c1162f8f9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 27 Mar 2024 22:42:07 -0700 Subject: [PATCH 004/124] Automated Code Change PiperOrigin-RevId: 619794438 --- tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD | 4 ---- .../c/experimental/ops/gen/cpp/renderers/cpp_config.cc | 2 ++ .../ops/gen/cpp/renderers/cpp_file_renderer.cc | 3 ++- .../experimental/ops/gen/cpp/renderers/guard_renderer.cc | 5 ++++- .../c/experimental/ops/gen/cpp/renderers/guard_renderer.h | 1 + .../ops/gen/cpp/renderers/include_renderer.cc | 5 ++++- .../experimental/ops/gen/cpp/renderers/include_renderer.h | 1 + .../ops/gen/cpp/renderers/namespace_renderer.cc | 4 +++- .../ops/gen/cpp/renderers/namespace_renderer.h | 1 + .../ops/gen/cpp/renderers/op_comment_renderer.cc | 2 ++ .../ops/gen/cpp/renderers/op_comment_renderer.h | 1 + .../ops/gen/cpp/renderers/op_implementation_renderer.cc | 1 + .../ops/gen/cpp/renderers/op_implementation_renderer.h | 1 + .../c/experimental/ops/gen/cpp/renderers/op_renderer.cc | 8 ++++++++ .../c/experimental/ops/gen/cpp/renderers/op_renderer.h | 2 ++ .../c/experimental/ops/gen/cpp/renderers/renderer.cc | 3 +++ .../c/experimental/ops/gen/cpp/renderers/renderer.h | 1 + .../c/experimental/ops/gen/cpp/renderers/renderer_test.cc | 4 ++++ 18 files changed, 41 insertions(+), 8 deletions(-) diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD b/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD index 7589ea2d2f24a2..c13bc899f2d016 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD @@ -20,13 +20,9 @@ cc_library( deps = [ "//tensorflow/c/experimental/ops/gen/common", "//tensorflow/c/experimental/ops/gen/cpp/views", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:op_gen_lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:str_util", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc index 36c25c92760872..1fc16e093c011d 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h" +#include "absl/strings/str_split.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.cc index 44f23ae0fb6aed..71132cfc3bf8b2 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.cc @@ -14,8 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.h" -#include "tensorflow/c/experimental/ops/gen/common/view_util.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc index 8bfd5a334c565d..7a4275b532eda7 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc @@ -15,7 +15,10 @@ limitations under the License. #include "tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h" #include "tensorflow/c/experimental/ops/gen/common/case_format.h" -#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h index cfe2a99acfddce..a45fe89a7a011c 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_GUARD_RENDERER_H_ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc index 5242d6f1baf255..38f31209f6da24 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc @@ -14,7 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h" -#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h index b98547079f3ac7..e43715a62e45b0 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_INCLUDE_RENDERER_H_ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc index 5547ca22df7ab0..db28ab303ae5c6 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.h" -#include "absl/strings/str_split.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.h index a54fc5878a0ad4..fd8ccf9531ef51 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.cc index e5afb7b6d63393..5d11bcada6e8c0 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h index 1d85c4c9fd7940..9131cc945349af 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_OP_COMMENT_RENDERER_H_ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.cc index e2184fcc7f834f..804e0585f88cca 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/experimental/ops/gen/common/view_util.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/arg_view.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/attr_view.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.h index 9237eb9410bad7..98c3b0d75524aa 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_OP_IMPLEMENTATION_RENDERER_H_ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc index 41db2ced426b47..c58e67782dfc34 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc @@ -16,7 +16,15 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/substitute.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/c/experimental/ops/gen/cpp/views/op_argument_view.h" +#include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h index c29fb35b5b6b7c..3360e14e672e3a 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h @@ -17,7 +17,9 @@ limitations under the License. #include "tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc index 0e6ee460512d2d..41d1dea64b3689 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc @@ -14,9 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stringpiece.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h index b0a95baefa7676..b6168b196b35b2 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_RENDERER_H_ #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_RENDERER_H_ +#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc index 2674e5f156d9d5..eff654c5938160 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc @@ -14,8 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/common/path_config.h" +#include "tensorflow/c/experimental/ops/gen/common/source_code.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { From f29b6b1517cab1f1e9da22413840bb28ba0a8136 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 27 Mar 2024 22:51:58 -0700 Subject: [PATCH 005/124] Automated Code Change PiperOrigin-RevId: 619796552 --- third_party/xla/xla/backends/profiler/cpu/BUILD | 1 - third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc | 1 - 2 files changed, 2 deletions(-) diff --git a/third_party/xla/xla/backends/profiler/cpu/BUILD b/third_party/xla/xla/backends/profiler/cpu/BUILD index 5c55e71f8406b3..6cfd1f1a51815b 100644 --- a/third_party/xla/xla/backends/profiler/cpu/BUILD +++ b/third_party/xla/xla/backends/profiler/cpu/BUILD @@ -126,7 +126,6 @@ xla_cc_test( srcs = ["host_tracer_test.cc"], deps = [ ":host_tracer_impl", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/lib/core:status_test_util", diff --git a/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc b/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc index 2d71ca5286ad42..881f46e50837ff 100644 --- a/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc +++ b/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" From 3b7619292d3cbe9cf124870f9dc5810d8d96aa5e Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Wed, 27 Mar 2024 23:18:47 -0700 Subject: [PATCH 006/124] Add basic support for folding constant RTVars This allows turning dynamic HLO operands into a index expressions when the HLO yields some constant value. As a first step this is adding support for the HLO ops `constant` and `iota`. PiperOrigin-RevId: 619803271 --- third_party/xla/xla/service/gpu/fusions/BUILD | 4 + .../xla/service/gpu/fusions/fusion_emitter.cc | 3 +- .../in_place_dynamic_update_slice_mlir.cc | 3 +- .../xla/xla/service/gpu/fusions/loop.cc | 2 +- .../xla/xla/service/gpu/fusions/loop_mlir.cc | 3 +- .../xla/xla/service/gpu/fusions/mlir/BUILD | 1 + .../gpu/fusions/mlir/simplify_affine.cc | 3 +- .../xla/xla/service/gpu/fusions/scatter.cc | 3 +- .../xla/service/gpu/fusions/scatter_mlir.cc | 5 +- .../xla/xla/service/gpu/fusions/transpose.cc | 4 +- .../xla/service/gpu/fusions/transpose_mlir.cc | 6 +- .../xla/xla/service/gpu/ir_emitter_triton.cc | 2 +- third_party/xla/xla/service/gpu/model/BUILD | 5 +- .../service/gpu/model/coalescing_analysis.cc | 8 +- .../service/gpu/model/indexing_analysis.cc | 24 ++- .../xla/service/gpu/model/indexing_analysis.h | 5 + .../gpu/model/indexing_analysis_test.cc | 2 +- .../xla/xla/service/gpu/model/indexing_map.cc | 76 ++++++++- .../xla/xla/service/gpu/model/indexing_map.h | 12 +- .../service/gpu/model/indexing_map_test.cc | 148 ++++++++++++++++-- 20 files changed, 276 insertions(+), 43 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index c54387cf6ff642..87995e26bf2ec4 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -66,6 +66,7 @@ cc_library( "//xla/service/gpu/fusions/mlir:computation_partitioner", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -348,6 +349,7 @@ cc_library( "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/model:indexing_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -393,6 +395,7 @@ cc_library( "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -518,6 +521,7 @@ cc_library( "//xla/service/gpu:ir_emitter_context", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:parallel_loop_emitter", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/llvm_ir:fused_ir_emitter", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", diff --git a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc index 5cf2481dcb5322..e652532fd0464c 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc @@ -49,6 +49,7 @@ limitations under the License. #include "xla/service/gpu/kernel_arguments.h" #include "xla/service/gpu/kernel_reuse_cache.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/runtime/kernel_thunk.h" #include "xla/service/gpu/target_util.h" @@ -196,7 +197,7 @@ IndexingMap KernelFusionInterface::GetDefaultThreadIdIndexingMap( } else { indexing_map.AddConstraint(linear_index, Interval{0, num_elements - 1}); } - indexing_map.Simplify(); + indexing_map.Simplify(GetIndexingMapForInstruction); return indexing_map; } diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc index 2334a5b74c892a..eccdcfceee8a8e 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/xla_data.pb.h" @@ -111,7 +112,7 @@ absl::Status MlirInPlaceDynamicUpdateSliceFusion::EmitEntryFunction( auto indexing = *ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/kDUSUpdateIndex, mlir_context); - indexing.Simplify(); + indexing.Simplify(GetIndexingMapForInstruction); indexing.RemoveUnusedSymbols(); int num_inputs = fusion.fused_instructions_computation()->num_parameters(); diff --git a/third_party/xla/xla/service/gpu/fusions/loop.cc b/third_party/xla/xla/service/gpu/fusions/loop.cc index b2ef86d5a916b6..e417f96923f4e3 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop.cc @@ -239,7 +239,7 @@ std::optional LoopFusion::ComputeThreadIdToInputIndexing( CHECK_EQ(output_to_input_indexing_set.size(), 1); IndexingMap thread_id_to_input_indexing_map = ComposeIndexingMaps( *thread_id_to_output_indexing, *output_to_input_indexing_set.begin()); - thread_id_to_input_indexing_map.Simplify(); + thread_id_to_input_indexing_map.Simplify(GetIndexingMapForInstruction); return thread_id_to_input_indexing_map; } diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc index 63b54cadb2bdc5..bf41d50930ea95 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/shape.h" #include "xla/status_macros.h" #include "xla/xla_data.pb.h" @@ -87,7 +88,7 @@ std::optional MlirLoopFusion::ComputeThreadIdToInputIndexing( CHECK_EQ(output_to_input_indexing_set.size(), 1); IndexingMap thread_id_to_input_indexing_map = ComposeIndexingMaps( *thread_id_to_output_indexing, *output_to_input_indexing_set.begin()); - thread_id_to_input_indexing_map.Simplify(); + thread_id_to_input_indexing_map.Simplify(GetIndexingMapForInstruction); return thread_id_to_input_indexing_map; } diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD index 4b9298fda6e6b5..3e475b1a9cdbeb 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD @@ -285,6 +285,7 @@ cc_library( "//xla/mlir_hlo:map_mhlo_to_scalar_op", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_map", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc b/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc index 585bc4b5cf6420..e4020c97a57e7d 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "xla/service/gpu/fusions/mlir/passes.h" +#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" namespace xla { @@ -116,7 +117,7 @@ struct RewriteAffineApply IndexingMap map(op.getAffineMap(), dim_ranges, symbol_ranges, /*rt_vars=*/{}); - map.Simplify(); + map.Simplify(GetIndexingMapForInstruction); auto expr = map.GetAffineMap().getResult(0); RangeEvaluator range_evaluator(map.GetDimensionBounds(), diff --git a/third_party/xla/xla/service/gpu/fusions/scatter.cc b/third_party/xla/xla/service/gpu/fusions/scatter.cc index 55488cc0df5e0c..0625f9efd4653b 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter.cc +++ b/third_party/xla/xla/service/gpu/fusions/scatter.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/ir_emitter_nested.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/parallel_loop_emitter.h" #include "xla/service/llvm_ir/fused_ir_emitter.h" #include "xla/service/llvm_ir/ir_array.h" @@ -275,7 +276,7 @@ std::optional ScatterFusion::ComputeThreadIdToInputIndexing( RangeVarsFromTensorSizes({scatter_indices_shape.dimensions(1)}), /*rt_vars=*/{}}; auto scatter_indices_map = scatter_update_map * updates_to_indices_map; - scatter_indices_map.Simplify(); + scatter_indices_map.Simplify(GetIndexingMapForInstruction); return scatter_indices_map; } return scatter_update_map; diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc index 932ec9cfccf1a1..6d8969a5251f31 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc @@ -42,6 +42,7 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" @@ -123,7 +124,7 @@ std::optional MlirScatterFusion::ComputeThreadIdToInputIndexing( RangeVarsFromTensorSizes({scatter_indices_shape.dimensions(1)}), /*rt_vars=*/{}}; auto scatter_indices_map = scatter_update_map * updates_to_indices_map; - scatter_indices_map.Simplify(); + scatter_indices_map.Simplify(GetIndexingMapForInstruction); return scatter_indices_map; } return scatter_update_map; @@ -190,7 +191,7 @@ absl::Status MlirScatterFusion::EmitEntryFunction( /*root_index=*/0, /*hero_operand_index=*/kScatterUpdateIndex, mlir_context) .value(); - thread_id_to_update_map.Simplify(); + thread_id_to_update_map.Simplify(GetIndexingMapForInstruction); thread_id_to_update_map.RemoveUnusedSymbols(); const auto& root_computation = computations.FindPartitionedComputation( diff --git a/third_party/xla/xla/service/gpu/fusions/transpose.cc b/third_party/xla/xla/service/gpu/fusions/transpose.cc index 99f113cbafbea7..ca7b3f7ff79228 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose.cc @@ -306,7 +306,7 @@ std::optional TransposeFusion::ComputeThreadIdToOutputIndexing( tiling_.GetNumBlocks(), tiling_.GetThreadTileSize(), permuted_tiled_shape.dimensions()), GetBitcastMap(permuted_tiled_shape, hero.shape(), ctx)); - map.Simplify(); + map.Simplify(GetIndexingMapForInstruction); return map; } @@ -318,7 +318,7 @@ std::optional TransposeFusion::ComputeThreadIdToInputIndexing( auto map = ComposeIndexingMaps( GetIndexingMapForTiling(tiling_, ctx), GetBitcastMap(tiling_.GetXlaShape(), hero.operand(0)->shape(), ctx)); - map.Simplify(); + map.Simplify(GetIndexingMapForInstruction); return map; } diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index c654c2e4ec8b99..9e2e5be2ea564b 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -155,7 +155,7 @@ std::optional MlirTransposeFusion::ComputeThreadIdToOutputIndexing( tiling_.GetNumBlocks(), tiling_.GetThreadTileSize(), permuted_tiled_shape.dimensions()), GetBitcastMap(permuted_tiled_shape, hero.shape(), mlir_context)); - map.Simplify(); + map.Simplify(GetIndexingMapForInstruction); return map; } @@ -165,7 +165,7 @@ IndexingMap MlirTransposeFusion::ComputeThreadIdToInputIndexing( GetIndexingMapForTiling(tiling_, mlir_context), GetBitcastMap(tiling_.GetXlaShape(), hero.operand(0)->shape(), mlir_context)); - map.Simplify(); + map.Simplify(GetIndexingMapForInstruction); return map; } @@ -207,7 +207,7 @@ IndexingMap GetSharedMemoryWriteIndexingMap( thread_id_indexing.GetRangeVars(), thread_id_indexing.GetRTVars(), thread_id_indexing.GetConstraints()}; - shmem_write_indexing.Simplify(); + shmem_write_indexing.Simplify(GetIndexingMapForInstruction); return shmem_write_indexing; } diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index 1fcbb8ce3c4e2c..98bf1ac6519286 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -2442,7 +2442,7 @@ absl::Status EmitTiledSoftMax(mlir::OpBuilder builder, IndexingMap program_id_to_input_tile_indexing = ComposeIndexingMaps( program_id_to_output_tile_indexing, tiled_hlo_instruction.indexing_map); - program_id_to_input_tile_indexing.Simplify(); + program_id_to_input_tile_indexing.Simplify(GetIndexingMapForInstruction); // Manually compute pointer offset to avoid materialized fully parallel // dimensions in the tile. Current codegen tried to avoid size-1 dims. diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 560ebe466a3777..45084afa659674 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -424,7 +424,6 @@ cc_library( deps = [ ":affine_map_printer", "//xla/hlo/ir:hlo", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -438,8 +437,12 @@ xla_cc_test( srcs = ["indexing_map_test.cc"], deps = [ ":affine_map_printer", + ":indexing_analysis", ":indexing_map", ":indexing_test_utils", + "//xla:literal_util", + "//xla:shape_util", + "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_googletest//:gtest", diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc index 283569d551a6b7..e7ad4bd6c01a93 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc @@ -234,7 +234,7 @@ bool IsCoalesced(const IndexingMap& thread_id_to_input_indexing_map, /*rt_vars=*/{}}; IndexingMap thread_x_to_linearized_input = thread_x_first_32_elements * thread_id_to_input_indexing_map; - thread_x_to_linearized_input.Simplify(); + thread_x_to_linearized_input.Simplify(GetIndexingMapForInstruction); thread_x_to_linearized_input.RemoveUnusedSymbols(); return EstimateCoalescingViaMemoryTransactionsCount( FindContiguousIntervals(thread_x_to_linearized_input), element_type); @@ -300,7 +300,8 @@ std::optional GetThreadIdToInputMemoryLayoutsMaps( IndexingMap operand_logical_to_linearized_physical_shape = operand_logical_to_physical_map * operand_physical_to_linearized_shape; - operand_logical_to_linearized_physical_shape.Simplify(); + operand_logical_to_linearized_physical_shape.Simplify( + GetIndexingMapForInstruction); for (const IndexingMap& operand_indexing_map : operand_indexing_maps_it->second) { @@ -316,7 +317,8 @@ std::optional GetThreadIdToInputMemoryLayoutsMaps( IndexingMap thread_id_to_linearized_physical_input_map = *thread_id_to_hero_operand_map * logical_output_to_linearized_physical_input_map; - thread_id_to_linearized_physical_input_map.Simplify(); + thread_id_to_linearized_physical_input_map.Simplify( + GetIndexingMapForInstruction); result[operand].insert(thread_id_to_linearized_physical_input_map); } } diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc index 79c6a04db159dd..b0deed5f9f8952 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc @@ -625,7 +625,7 @@ IndexingMap ComposeIndexingMapsForWindow( // Composed indexing. IndexingMap result = ComposeIndexingMaps(input_indexing_no_padding, padded_input_indexing); - result.Simplify(); + result.Simplify(GetIndexingMapForInstruction); result.RemoveUnusedSymbols(); return result; } @@ -939,7 +939,7 @@ HloInstructionIndexing ComputeOutputToInputReshapeOpIndexing( IndexingMap reshape_indexing_map = IndexingMap::FromTensorSizes( ComputeReshapeIndexingMap(input, output, mlir_context), output.dimensions(), {}); - reshape_indexing_map.Simplify(); + reshape_indexing_map.Simplify(GetIndexingMapForInstruction); return HloInstructionIndexing::FromIndexingMaps({reshape_indexing_map}); } HloInstructionIndexing ComputeInputToOutputReshapeOpIndexing( @@ -950,7 +950,7 @@ HloInstructionIndexing ComputeInputToOutputReshapeOpIndexing( IndexingMap reshape_indexing_map = IndexingMap::FromTensorSizes( ComputeReshapeIndexingMap(output, input, mlir_context), input.dimensions(), {}); - reshape_indexing_map.Simplify(); + reshape_indexing_map.Simplify(GetIndexingMapForInstruction); return HloInstructionIndexing::FromIndexingMaps({reshape_indexing_map}); } @@ -1065,7 +1065,7 @@ HloInstructionIndexing ComputeOutputToInputBitcastOpIndexing( const HloInstruction* bitcast, MLIRContext* mlir_context) { auto bitcast_map = GetBitcastMap(bitcast->shape(), bitcast->operand(0)->shape(), mlir_context); - bitcast_map.Simplify(); + bitcast_map.Simplify(GetIndexingMapForInstruction); return HloInstructionIndexing::FromIndexingMaps({bitcast_map}); } @@ -1073,7 +1073,7 @@ HloInstructionIndexing ComputeInputToOutputBitcastOpIndexing( const HloInstruction* bitcast, MLIRContext* mlir_context) { auto bitcast_map = GetBitcastMap(bitcast->operand(0)->shape(), bitcast->shape(), mlir_context); - bitcast_map.Simplify(); + bitcast_map.Simplify(GetIndexingMapForInstruction); return HloInstructionIndexing::FromIndexingMaps({bitcast_map}); } @@ -1233,7 +1233,7 @@ bool HloInstructionIndexing::Simplify() { to_remove.push_back(map); if (map.IsUndefined()) { to_add.push_back(map); - } else if (map.Simplify()) { + } else if (map.Simplify(GetIndexingMapForInstruction)) { map.RemoveUnusedSymbols(); } else { to_remove.pop_back(); @@ -1348,7 +1348,7 @@ GroupedByOpIndexingMap ComputeGroupedOutputToInputIndexing( for (const IndexingMap& producer_map : producer_operand_indexing) { for (const IndexingMap& consumer_map : consumer_indexing_maps_copy) { auto composed_map = ComposeIndexingMaps(consumer_map, producer_map); - composed_map.Simplify(); + composed_map.Simplify(GetIndexingMapForInstruction); composed_map.RemoveUnusedSymbols(); grouped_indexing_maps[&producer_operand_adaptor.instruction()].insert( composed_map); @@ -1497,11 +1497,19 @@ IndexingMap ComputeEpilogueInputToOutputIndexing( auto user_indexing = ComputeInputToOutputIndexing( user, user->operand_index(instr), mlir_context); root_indexing = root_indexing * *user_indexing.indexing_maps[0].begin(); - root_indexing.Simplify(); + root_indexing.Simplify(GetIndexingMapForInstruction); instr = user; } return root_indexing; } +IndexingMap GetIndexingMapForInstruction(const HloInstruction* instr, + int64_t operand_idx, + mlir::MLIRContext* mlir_context) { + HloInstructionIndexing indexing = + ComputeOutputToInputIndexing(instr, operand_idx, mlir_context); + return *indexing.indexing_maps[0].begin(); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.h b/third_party/xla/xla/service/gpu/model/indexing_analysis.h index ed8c495c425899..22012ea472f887 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.h +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.h @@ -180,6 +180,11 @@ llvm::SmallVector DelinearizeInBoundsIndex( mlir::AffineExpr linear, absl::Span sizes, absl::Span strides); +// Returns the output-to-input indexing map of the first output of `instr` +IndexingMap GetIndexingMapForInstruction(const HloInstruction* instr, + int64_t operand_idx, + mlir::MLIRContext* mlir_context); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc index d6424b4aaaf9ce..1ae80e76fa0907 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc @@ -2546,7 +2546,7 @@ TEST_F(IndexingAnalysisTest, TilingIndexing) { /*tile_sizes=*/{8, 1, 4}, /*num_threads=*/{1, 4, 4}}; auto indexing_map = GetIndexingMapForTiling(tiling, &mlir_context_); - indexing_map.Simplify(); + indexing_map.Simplify(GetIndexingMapForInstruction); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( (d3 floordiv 64) * 8 + s0, diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 5ad902f690bea1..129a5e67cf7c4e 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/gpu/model/indexing_map.h" #include +#include #include #include #include @@ -34,6 +35,9 @@ limitations under the License. #include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/model/affine_map_printer.h" #include "tsl/platform/logging.h" // IWYU pragma: keep @@ -945,9 +949,11 @@ IndexingMap operator*(const IndexingMap& lhs, const IndexingMap& rhs) { // RangeEvaluator for every constraint. Note that we start with "expr" // simplification, because the ranges of constraints were already optimized once // when IndexingMap was constructed. -bool IndexingMap::Simplify() { +bool IndexingMap::Simplify(IndexingMapProvider indexing_map_provider) { if (IsUndefined()) return false; + bool rtvars_were_eliminated = ReplaceConstantRTVars(indexing_map_provider); + // Simplify constraints to shrink the lower/upper bounds of dims and symbols. bool constraints_were_simplified = false; while (true) { @@ -967,7 +973,8 @@ bool IndexingMap::Simplify() { if (affine_map_was_simplified) { affine_map_ = simplified_affine_map; } - return affine_map_was_simplified || constraints_were_simplified; + return affine_map_was_simplified || constraints_were_simplified || + rtvars_were_eliminated; } bool IndexingMap::SimplifyConstraintExprs() { @@ -1330,5 +1337,70 @@ bool IndexingMap::RescaleSymbols() { return !to_delete.empty(); } +static std::optional FoldsIntoConstantIndexingExpression( + const HloInstruction* instr, const mlir::AffineMap& affine_map, + MLIRContext* mlir_context, + IndexingMap::IndexingMapProvider indexing_map_provider) { + if (auto constant_expr = DynCast(instr)) { + if (affine_map.isConstant()) { + const auto idx = affine_map.getConstantResults(); + return getAffineConstantExpr( + constant_expr->literal().GetIntegralAsS64(idx).value(), mlir_context); + } + return std::nullopt; + } + + if (auto iota_expr = DynCast(instr)) { + auto iota_dimension = iota_expr->iota_dimension(); + CHECK(iota_dimension < affine_map.getNumResults()); + return affine_map.getResults()[iota_dimension]; + } + + return std::nullopt; +} + +bool IndexingMap::ReplaceConstantRTVars( + IndexingMap::IndexingMapProvider indexing_map_provider) { + if (rt_vars_.empty()) return false; + + std::vector to_delete; + + for (const auto& [index, rt_var] : llvm::enumerate(rt_vars_)) { + auto folded_expr = FoldsIntoConstantIndexingExpression( + rt_var.hlo, rt_var.map, GetMLIRContext(), indexing_map_provider); + if (!folded_expr.has_value()) continue; + + auto symbol_index = range_vars_.size() + index; + affine_map_ = affine_map_.replace( + {{mlir::getAffineSymbolExpr(symbol_index, GetMLIRContext()), + folded_expr.value()}}); + + llvm::DenseMap replacements; + + for (const auto& [constraint, interval] : constraints_) { + auto modified_constraint = constraint.replace( + mlir::getAffineSymbolExpr(symbol_index, GetMLIRContext()), + folded_expr.value()); + + if (constraint == modified_constraint) continue; + replacements[constraint] = modified_constraint; + } + + for (const auto& [old_expr, new_expr] : replacements) { + auto interval = constraints_.at(old_expr); + constraints_.erase(old_expr); + constraints_[new_expr] = interval; + } + + to_delete.emplace_back(index); + } + + for (auto index : llvm::reverse(to_delete)) { + rt_vars_.erase(rt_vars_.begin() + index); + } + + return !to_delete.empty(); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.h b/third_party/xla/xla/service/gpu/model/indexing_map.h index 0419c05557364c..bfc8abf30bdd39 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.h +++ b/third_party/xla/xla/service/gpu/model/indexing_map.h @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/model/affine_map_printer.h" @@ -244,8 +245,13 @@ class IndexingMap { void Print(std::ostream& out, const AffineMapPrinter& printer) const; + // TODO(hebecker): Rearrange code structure so that we can call + // `ComputeInputToOutputIndexing` from `:indexing_analysis` directly. + using IndexingMapProvider = llvm::function_ref; + // Returns true if the map was simplified. - bool Simplify(); + bool Simplify(IndexingMapProvider indexing_map_provider); // Return MLIRContext. mlir::MLIRContext* GetMLIRContext() const; @@ -333,6 +339,10 @@ class IndexingMap { // Merges "mod" constraints for the same AffineExpr. void MergeModConstraints(); + // Replace RTVars that yield constants by indexing expressions. + // Returns true if a replacement was performed, otherwise false. + bool ReplaceConstantRTVars(IndexingMapProvider indexing_map_provider); + mlir::AffineMap affine_map_; std::vector dim_vars_; std::vector range_vars_; diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index 6118265cea60ce..084b535708d332 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/model/indexing_map.h" +#include #include #include #include @@ -24,8 +25,12 @@ limitations under the License. #include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/literal_util.h" #include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/test.h" @@ -163,7 +168,7 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { s0 mod 3 in [1, 1] s2 mod 4 in [0, 0] )")); - composed.Simplify(); + composed.Simplify(GetIndexingMapForInstruction); EXPECT_THAT(composed, MatchIndexingMap(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0) domain: @@ -381,7 +386,7 @@ TEST_F(IndexingMapTest, ConstraintMerge_Mod) { Interval{0, 0}); indexing_map.AddConstraint(ParseAffineExpr("s1 mod 5", &mlir_context_), Interval{1, 1}); - indexing_map.Simplify(); + indexing_map.Simplify(GetIndexingMapForInstruction); EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( (d0)[s0, s1] -> (d0, s1, s0) @@ -399,7 +404,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ConstantDims) { IndexingMap indexing_map = IndexingMap(ParseAffineMap("(d0) -> (d0)", &mlir_context_), {DimVar{{5, 5}}}, /*range_vars=*/{}, /*rt_vars=*/{}); - indexing_map.Simplify(); + indexing_map.Simplify(GetIndexingMapForInstruction); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0) -> (5) domain: @@ -412,7 +417,7 @@ TEST_F(IndexingMapTest, auto serialized_map = "(d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( ParseAffineMap(serialized_map, &mlir_context_), {8, 16}, {}); - indexing_map.Simplify(); + indexing_map.Simplify(GetIndexingMapForInstruction); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0, d1) domain: @@ -429,7 +434,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { IndexingMap indexing_map = IndexingMap::FromTensorSizes( ParseAffineMap(serialized_map, &mlir_context_), {9, 9, 9}, {}); - indexing_map.Simplify(); + indexing_map.Simplify(GetIndexingMapForInstruction); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1, d2) -> (d0, d1, d2) @@ -448,7 +453,7 @@ TEST_F(IndexingMapTest, IndexingMap indexing_map = IndexingMap::FromTensorSizes( ParseAffineMap(serialized_map, &mlir_context_), {10, 10, 10}, {}); - indexing_map.Simplify(); + indexing_map.Simplify(GetIndexingMapForInstruction); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1, d2) -> (d0 * 2 + (d1 + d2 floordiv 4) floordiv 2, (d1 * 4 + d2) mod 8) @@ -465,7 +470,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithReverse) { "d0 * 11 + d1 + ((d0 * -11 - d1 + 109) floordiv 11) * 11 - 99)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( ParseAffineMap(serialized_map, &mlir_context_), {8, 9}, {}); - indexing_map.Simplify(); + indexing_map.Simplify(GetIndexingMapForInstruction); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( (d0, d1) -> (d0, d1) domain: @@ -479,7 +484,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape) { "()[s0] -> ((s0 * 128) mod 715 + ((s0 * 128) floordiv 715) * 715)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( ParseAffineMap(serialized_map, &mlir_context_), {}, {128}); - indexing_map.Simplify(); + indexing_map.Simplify(GetIndexingMapForInstruction); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (s0 * 128) domain: s0 in [0, 127] @@ -492,7 +497,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape_Regression) { "()[s0] -> ((s0 * 128) mod 715 + ((s0 * 64) floordiv 715) * 715)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( ParseAffineMap(serialized_map, &mlir_context_), {}, {128}); - indexing_map.Simplify(); + indexing_map.Simplify(GetIndexingMapForInstruction); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> ((s0 * 128) mod 715 + ((s0 * 64) floordiv 715) * 715) domain: s0 in [0, 127] @@ -505,7 +510,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { "14)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( ParseAffineMap(serialized_map, &mlir_context_), {}, {1234}); - indexing_map.Simplify(); + indexing_map.Simplify(GetIndexingMapForInstruction); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0] -> (s0) domain: @@ -519,7 +524,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_DivGcdGreater1) { "floordiv 3) * 768 + ((s0 * 128 + s1) floordiv 192) * 768)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128, 4}); - indexing_map.Simplify(); + indexing_map.Simplify(GetIndexingMapForInstruction); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2) domain: @@ -535,7 +540,7 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { "20000)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( ParseAffineMap(serialized_map, &mlir_context_), {}, {872, 4, 128, 896}); - indexing_map.Simplify(); + indexing_map.Simplify(GetIndexingMapForInstruction); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0, s1, s2, s3] -> ( s1 + (s0 * 458752 + s2 * 4 + s3 * 512) mod 20000 @@ -555,7 +560,7 @@ TEST_F(IndexingMapTest, "* 2) floordiv 4)"; IndexingMap indexing_map = IndexingMap::FromTensorSizes( ParseAffineMap(serialized_map, &mlir_context_), {}, {2, 128}); - indexing_map.Simplify(); + indexing_map.Simplify(GetIndexingMapForInstruction); EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( ()[s0, s1] -> ( s0 * 4 + s1 floordiv 32 @@ -704,6 +709,123 @@ TEST(IntervalComparisionTest, Comparisons) { EXPECT_EQ(point != 16, true); } +TEST_F(IndexingMapTest, ReplaceConstantRTVars_ScalarConstant) { + // auto zero_dim_map = AffineMap::get(&mlir_context_); + auto constant = + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42)); + + IndexingMap indexing_map(ParseAffineMap("()[s0] -> (s0)", &mlir_context_), + /*dimensions=*/{}, + /*range_vars=*/{}, + {RTVar{Interval{42, 42}, constant.get(), + AffineMap::get(0, 0, {}, &mlir_context_)}}); + + EXPECT_TRUE(indexing_map.Simplify(GetIndexingMapForInstruction)); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + () -> (42) + domain: + )")); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_StaticIndexIntoTensorConstant) { + // auto zero_dim_map = AffineMap::get(&mlir_context_); + auto constant = HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 2, 3, 4}, {11, 12, 13, 14}})); + + IndexingMap indexing_map( + ParseAffineMap("()[s0] -> (s0)", &mlir_context_), + /*dimensions=*/{}, + /*range_vars=*/{}, + {RTVar{Interval{1, 14}, constant.get(), + ParseAffineMap("() -> (1,2)", &mlir_context_)}}); + + EXPECT_TRUE(indexing_map.Simplify(GetIndexingMapForInstruction)); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + () -> (13) + domain: + )")); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_NonFoldableTensor) { + // auto zero_dim_map = AffineMap::get(&mlir_context_); + auto constant = HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 2, 3, 4}, {11, 12, 13, 14}})); + + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0] -> (s0)", &mlir_context_), + /*dimensions=*/{}, + /*range_vars=*/{}, + {RTVar{Interval{1, 14}, constant.get(), + ParseAffineMap("(d0) -> (1, d0)", &mlir_context_)}}); + + EXPECT_FALSE(indexing_map.Simplify(GetIndexingMapForInstruction)); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_Iota) { + auto iota = HloInstruction::CreateIota( + ShapeUtil::MakeShape(PrimitiveType::S64, {10, 10}), 0); + + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), + /*dimensions=*/{{0, 255}}, + /*range_vars=*/{}, + {RTVar{Interval{0, 9}, iota.get(), + ParseAffineMap("(d0) -> (d0, 7)", &mlir_context_)}}); + + EXPECT_TRUE(indexing_map.Simplify(GetIndexingMapForInstruction)); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0) -> (d0, d0) + domain: + d0 in [0, 255] + )")); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_IotaAsConstant) { + auto iota = HloInstruction::CreateIota( + ShapeUtil::MakeShape(PrimitiveType::S64, {10, 10}), 1); + + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), + /*dimensions=*/{{0, 255}}, + /*range_vars=*/{}, + {RTVar{Interval{0, 9}, iota.get(), + ParseAffineMap("(d0) -> (d0, 7)", &mlir_context_)}}); + + EXPECT_TRUE(indexing_map.Simplify(GetIndexingMapForInstruction)); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0) -> (d0, 7) + domain: + d0 in [0, 255] + )")); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_ConstraintsGetUpdated) { + auto iota = HloInstruction::CreateIota( + ShapeUtil::MakeShape(PrimitiveType::S64, {10, 10}), 0); + + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), + /*dimensions=*/{{0, 255}}, + /*range_vars=*/{}, + {RTVar{Interval{0, 9}, iota.get(), + ParseAffineMap("(d0) -> (d0, 7)", &mlir_context_)}}); + indexing_map.AddConstraint(ParseAffineExpr("s0 mod 2", &mlir_context_), + Interval{0, 0}); + + EXPECT_TRUE(indexing_map.Simplify(GetIndexingMapForInstruction)); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0) -> (d0, d0) + domain: + d0 in [0, 254] + d0 mod 2 in [0, 0] + )")); +} + } // namespace } // namespace gpu } // namespace xla From e4e2605ca259f89c6eec453a1fd22167da351b3c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Mar 2024 00:29:28 -0700 Subject: [PATCH 007/124] Automated Code Change PiperOrigin-RevId: 619820412 --- tensorflow/lite/tools/BUILD | 2 ++ tensorflow/lite/tools/tool_params_test.cc | 1 - tensorflow/lite/tools/utils.cc | 2 ++ 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index d7608ec188363e..8c60f8ad012bd8 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -409,6 +409,8 @@ cc_library( hdrs = ["utils.h"], deps = [ ":logging", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/c:common", "//tensorflow/lite/core/c:common", "//tensorflow/lite/kernels:kernel_util", ], diff --git a/tensorflow/lite/tools/tool_params_test.cc b/tensorflow/lite/tools/tool_params_test.cc index 248db53b0a4d42..e34c40b3cba143 100644 --- a/tensorflow/lite/tools/tool_params_test.cc +++ b/tensorflow/lite/tools/tool_params_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/lite/tools/tool_params.h" -#include #include namespace tflite { diff --git a/tensorflow/lite/tools/utils.cc b/tensorflow/lite/tools/utils.cc index 846f76471f2ce1..12396ed7c3ce05 100644 --- a/tensorflow/lite/tools/utils.cc +++ b/tensorflow/lite/tools/utils.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/tools/logging.h" From c5621b856d913d5692b59e2861f37244dcad00d4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Mar 2024 02:02:15 -0700 Subject: [PATCH 008/124] Update GraphDef version to 1815. PiperOrigin-RevId: 619842796 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 8d59a74287919d..df98002eacc475 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1814 // Updated: 2024/3/27 +#define TF_GRAPH_DEF_VERSION 1815 // Updated: 2024/3/28 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From e34abc2477bd1862f38073689f2d12c145dfc66e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Mar 2024 02:03:41 -0700 Subject: [PATCH 009/124] compat: Update forward compatibility horizon to 2024-03-28 PiperOrigin-RevId: 619843215 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index a96d4514f68cda..d3a357f833f2ef 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 3, 27) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 3, 28) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 9e45913593a9470976671f1710eff7eeed36f44a Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Thu, 28 Mar 2024 02:48:19 -0700 Subject: [PATCH 010/124] Disable Batch Mat Mul delegate test until flaky behaviour is fixed. PiperOrigin-RevId: 619853923 --- .../xnnpack/batch_matrix_multiply_test.cc | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tensorflow/lite/delegates/xnnpack/batch_matrix_multiply_test.cc b/tensorflow/lite/delegates/xnnpack/batch_matrix_multiply_test.cc index 4d7af0f07e73ce..30642a3df2a517 100644 --- a/tensorflow/lite/delegates/xnnpack/batch_matrix_multiply_test.cc +++ b/tensorflow/lite/delegates/xnnpack/batch_matrix_multiply_test.cc @@ -25,7 +25,7 @@ limitations under the License. namespace tflite { namespace xnnpack { -class BatchMatrixMultiplyTest : public testing::Test { +class DISABLED_BatchMatrixMultiplyTest : public testing::Test { public: // std::unique_ptr auto get_delegate(int num_threads = 1) { @@ -52,7 +52,7 @@ class BatchMatrixMultiplyTest : public testing::Test { std::mt19937 rng_ = std::mt19937(random_device_()); }; -TEST_F(BatchMatrixMultiplyTest, 3D) { +TEST_F(DISABLED_BatchMatrixMultiplyTest, 3D) { const auto batch = shape_rng(); const auto height = shape_rng(); const auto input1_channels = channels_rng(); @@ -65,7 +65,7 @@ TEST_F(BatchMatrixMultiplyTest, 3D) { .Test(xnnpack_delegate.get()); } -TEST_F(BatchMatrixMultiplyTest, BroadcastOne3D) { +TEST_F(DISABLED_BatchMatrixMultiplyTest, BroadcastOne3D) { const auto batch = shape_rng(); const auto height = shape_rng(); const auto input1_channels = channels_rng(); @@ -83,7 +83,7 @@ TEST_F(BatchMatrixMultiplyTest, BroadcastOne3D) { .Test(xnnpack_delegate.get()); } -TEST_F(BatchMatrixMultiplyTest, BroadcastImplicit3D) { +TEST_F(DISABLED_BatchMatrixMultiplyTest, BroadcastImplicit3D) { const auto batch = shape_rng(); const auto height = shape_rng(); const auto input1_channels = channels_rng(); @@ -101,7 +101,7 @@ TEST_F(BatchMatrixMultiplyTest, BroadcastImplicit3D) { .Test(xnnpack_delegate.get()); } -TEST_F(BatchMatrixMultiplyTest, 4D) { +TEST_F(DISABLED_BatchMatrixMultiplyTest, 4D) { const auto outer_batch = shape_rng(); const auto inner_batch = shape_rng(); const auto height = shape_rng(); @@ -115,7 +115,7 @@ TEST_F(BatchMatrixMultiplyTest, 4D) { .Test(xnnpack_delegate.get()); } -TEST_F(BatchMatrixMultiplyTest, BroadcastOne4D) { +TEST_F(DISABLED_BatchMatrixMultiplyTest, BroadcastOne4D) { const auto outer_batch = shape_rng(); const auto inner_batch = shape_rng(); const auto height = shape_rng(); @@ -149,7 +149,7 @@ TEST_F(BatchMatrixMultiplyTest, BroadcastOne4D) { .Test(xnnpack_delegate.get()); } -TEST_F(BatchMatrixMultiplyTest, BroadcastImplicit4D) { +TEST_F(DISABLED_BatchMatrixMultiplyTest, BroadcastImplicit4D) { const auto outer_batch = shape_rng(); const auto inner_batch = shape_rng(); const auto height = shape_rng(); @@ -175,7 +175,7 @@ TEST_F(BatchMatrixMultiplyTest, BroadcastImplicit4D) { .Test(xnnpack_delegate.get()); } -TEST_F(BatchMatrixMultiplyTest, 4D_AdjY) { +TEST_F(DISABLED_BatchMatrixMultiplyTest, 4D_AdjY) { const auto outer_batch = shape_rng(); const auto inner_batch = shape_rng(); const auto height = shape_rng(); @@ -190,7 +190,7 @@ TEST_F(BatchMatrixMultiplyTest, 4D_AdjY) { .Test(xnnpack_delegate.get()); } -TEST_F(BatchMatrixMultiplyTest, MultiThreading) { +TEST_F(DISABLED_BatchMatrixMultiplyTest, MultiThreading) { const auto batch = shape_rng(); const auto height = shape_rng(); const auto input1_channels = channels_rng(); @@ -203,7 +203,7 @@ TEST_F(BatchMatrixMultiplyTest, MultiThreading) { .Test(xnnpack_delegate.get()); } -TEST_F(BatchMatrixMultiplyTest, WeightsCache) { +TEST_F(DISABLED_BatchMatrixMultiplyTest, WeightsCache) { TfLiteXNNPackDelegateOptions delegate_options = TfLiteXNNPackDelegateOptionsDefault(); std::unique_ptr Date: Thu, 28 Mar 2024 03:44:05 -0700 Subject: [PATCH 011/124] [xla:gpu] Sort sliced operand paths on the fly PiperOrigin-RevId: 619867884 --- .../address_computation_fusion_rewriter.cc | 36 +++----------- ...ddress_computation_fusion_rewriter_test.cc | 48 +++++++++---------- 2 files changed, 30 insertions(+), 54 deletions(-) diff --git a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc index 44181f9206d366..6be6b4cc0be087 100644 --- a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc @@ -133,8 +133,7 @@ bool IsAlignedSlice(const Shape& src_shape, const Shape& dst_shape, UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr, bool dynamic) { - UseDefDataflowPaths sliced_operand_paths = { - const_cast(instr)}; + UseDefDataflowPaths sliced_operand_paths; auto fusion = HloFusionAdaptor::ForComputation(instr->parent()); // This set is used to avoid duplicates in the matched results. It contains @@ -191,12 +190,14 @@ UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr, // still need to add instructions encountered in the sliced operand path // during the latest traversal. sliced_operand_paths.insert(sliced_operand_paths.end(), - maybe_sliced_operand_path.begin(), - maybe_sliced_operand_path.end()); + maybe_sliced_operand_path.rbegin(), + maybe_sliced_operand_path.rend()); processed_instrs.insert(maybe_sliced_operand_path.begin(), maybe_sliced_operand_path.end()); } } + + sliced_operand_paths.push_back(const_cast(instr)); return sliced_operand_paths; } @@ -278,30 +279,6 @@ absl::InlinedVector GetPatternCaptures( return captures; } -UseDefDataflowPaths GetSortedMatches( - absl::Span matches) { - UseDefDataflowPaths sorted_matches; - InstructionSet matched_instrs(matches.begin(), matches.end()); - InstructionSet processed_instrs; - // Topologically sort `matches` - for (auto it = matches.rbegin(); it != matches.rend(); ++it) { - if (processed_instrs.contains(*it)) continue; - for (auto* operand : (*it)->operands()) { - if (!matched_instrs.contains(operand)) { - continue; - } - if (!processed_instrs.contains(operand)) { - sorted_matches.emplace_back(operand); - processed_instrs.insert(operand); - } - } - sorted_matches.emplace_back(*it); - processed_instrs.insert(*it); - } - - return sorted_matches; -} - Status CreateRootTuple(HloInstruction* hero, HloComputation::Builder& builder, DefUseDataflowPaths sliced_user_paths, absl::flat_hash_map AddressComputationFusionRewriter::Run( absl::c_copy(sliced_user_path, std::back_inserter(matches)); auto captures = GetPatternCaptures(matches); - auto sorted_operand_matches = GetSortedMatches(operand_matches); TF_ASSIGN_OR_RETURN(HloComputation * fusion_body, - CreateFusionBody(module, sorted_operand_matches, + CreateFusionBody(module, operand_matches, sliced_user_paths, captures)); TF_ASSIGN_OR_RETURN(HloInstruction * fusion, diff --git a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter_test.cc index cf63e3978758c2..ec358a7b522e9f 100644 --- a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter_test.cc @@ -324,10 +324,10 @@ TEST_F(AddressComputationFusionRewriterTest, ENTRY %main.9 { %p0 = f16[2,8,8]{2,1,0} parameter(0) - %p1 = f16[2,8,8]{2,1,0} parameter(1) + %p1 = f16[4,8,8]{2,1,0} parameter(1) %slice.13 = f16[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} %bitcast.41 = f16[8,8]{1,0} bitcast(%slice.13) - %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %slice.14 = f16[1,8,8]{2,1,0} slice(%p1), slice={[2:3], [0:8], [0:8]} %bitcast.42 = f16[8,8]{1,0} bitcast(%slice.14) %custom-call.1 = f16[8,8]{1,0} custom-call(%bitcast.41, %bitcast.42), @@ -355,26 +355,26 @@ TEST_F(AddressComputationFusionRewriterTest, const char* expected = R"( ; CHECK: %address-computation {{.*}} { - ; CHECK-DAG: [[P0:%[^ ]+]] = f16[8,8]{1,0} parameter(0) - ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) - ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[8,8]{1,0} parameter(1) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[2:3], [0:8], [0:8]} ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) ; CHECK: ROOT [[CC:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[B1]]), ; CHECK: custom_call_target="__cublas$gemm" ; CHECK: } ; CHECK: ENTRY %main{{.*}} { - ; CHECK-DAG: [[P0:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) - ; CHECK-DAG: [[S0:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P0]]), slice={[1:2], [0:8], [0:8]} - ; CHECK-DAG: [[B0:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S0]]) - ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(1) - ; CHECK: [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[B0]], [[P1]]) + ; CHECK-DAG: [[P1:%[^ ]+]] = f16[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: [[S1:%[^ ]+]] = f16[1,8,8]{2,1,0} slice([[P1]]), slice={[1:2], [0:8], [0:8]} + ; CHECK-DAG: [[B1:%[^ ]+]] = f16[8,8]{1,0} bitcast([[S1]]) + ; CHECK-DAG: [[P0:%[^ ]+]] = f16[4,8,8]{2,1,0} parameter(1) + ; CHECK: [[FUSION:%[^ ]+]] = f16[8,8]{1,0} fusion([[P0]], [[B1]]) ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", ; CHECK: "custom_fusion_config":{"name":"address_computation"} ; CHECK: } - ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[B0]]) + ; CHECK: ROOT {{.*}} = f16[8,8]{1,0} add([[FUSION]], [[B1]]) ; CHECK: } )"; @@ -848,11 +848,11 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmOperandAliasingOutput) { const char* expected = R"( ; CHECK: %address-computation {{.*}} { - ; CHECK-DAG: [[P0:%[^ ]+]] = f32[100,100]{1,0} parameter(0) + ; CHECK-DAG: [[P2:%[^ ]+]] = f32[100,100]{1,0} parameter(2) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[100,100]{1,0} parameter(1) - ; CHECK-DAG: [[P2:%[^ ]+]] = f32[200,100]{1,0} parameter(2) - ; CHECK-DAG: [[S1:%[^ ]+]] = f32[100,100]{1,0} slice([[P2]]), slice={[16:116], [0:100]} - ; CHECK: [[CC:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) custom-call([[P0]], [[S1]], [[P1]]), + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[200,100]{1,0} parameter(0) + ; CHECK-DAG: [[S1:%[^ ]+]] = f32[100,100]{1,0} slice([[P0]]), slice={[16:116], [0:100]} + ; CHECK: [[CC:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) custom-call([[P1]], [[S1]], [[P2]]), ; CHECK: custom_call_target="__cublas$gemm" ; CHECK: } @@ -862,7 +862,7 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmOperandAliasingOutput) { ; CHECK: [[GTE1:%[^ ]+]] = f32[100,100]{1,0} get-tuple-element([[P]]), index=1 ; CHECK: [[CONCAT:%[^ ]+]] = f32[200,100]{1,0} concatenate([[GTE0]], [[GTE1]]), dimensions={0} ; CHECK: [[S:%[^ ]+]] = f32[100,100]{1,0} slice([[CONCAT]]), slice={[99:199], [0:100]} - ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) fusion([[GTE0]], [[S]], [[CONCAT]]) + ; CHECK: ROOT [[FUSION:%[^ ]+]] = (f32[100,100]{1,0}, s8[120000]{0}) fusion([[CONCAT]], [[GTE0]], [[S]]) ; CHECK: kind=kCustom, calls=%address-computation, ; CHECK: backend_config={ ; CHECK: "kind":"__custom_fusion", @@ -1108,12 +1108,12 @@ TEST_F(AddressComputationFusionRewriterTest, TupleSliceCustomCallLegacy) { const char* expected = R"( ; CHECK: %address-computation {{.*}} { - ; CHECK-DAG: [[P2:%[^ ]+]] = f32[8,8]{1,0} parameter(2) - ; CHECK-DAG: [[S0:%[^ ]+]] = f32[4,8]{1,0} slice([[P2]]), slice={[0:4], [0:8]} + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0) + ; CHECK-DAG: [[S0:%[^ ]+]] = f32[4,8]{1,0} slice([[P0]]), slice={[0:4], [0:8]} ; CHECK-DAG: [[P1:%[^ ]+]] = f32[256]{0} parameter(1) ; CHECK-DAG: [[T0:%[^ ]+]] = (f32[4,8]{1,0}, f32[256]{0}) tuple([[S0]], [[P1]]) - ; CHECK-DAG: [[P0:%[^ ]+]] = (f32[1024]{0}, f32[8]{0}) parameter(0) - ; CHECK: ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[T0]], [[P0]]), + ; CHECK-DAG: [[P2:%[^ ]+]] = (f32[1024]{0}, f32[8]{0}) parameter(2) + ; CHECK: ROOT [[CC:%[^ ]+]] = f32[128]{0} custom-call([[T0]], [[P2]]), ; CHECK: custom_call_target="Callback_Void" ; CHECK: } @@ -1184,12 +1184,12 @@ TEST_F(AddressComputationFusionRewriterTest, TupledOutputCustomCallLegacy) { const char* expected = R"( ; CHECK: %address-computation {{.*}} { - ; CHECK-DAG: [[P0:%[^ ]+]] = (f32[1024]{0}, f32[8]{0}) parameter(0) + ; CHECK-DAG: [[P2:%[^ ]+]] = (f32[1024]{0}, f32[8]{0}) parameter(2) ; CHECK-DAG: [[P1:%[^ ]+]] = f32[256]{0} parameter(1) - ; CHECK-DAG: [[P2:%[^ ]+]] = f32[8,8]{1,0} parameter(2) - ; CHECK-DAG: [[S0:%[^ ]+]] = f32[4,8]{1,0} slice([[P2]]), slice={[0:4], [0:8]} + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[8,8]{1,0} parameter(0) + ; CHECK-DAG: [[S0:%[^ ]+]] = f32[4,8]{1,0} slice([[P0]]), slice={[0:4], [0:8]} ; CHECK-DAG: [[T0:%[^ ]+]] = (f32[4,8]{1,0}, f32[256]{0}) tuple([[S0]], [[P1]]) - ; CHECK: [[CC:%[^ ]+]] = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) custom-call([[T0]], [[P0]]), + ; CHECK: [[CC:%[^ ]+]] = (f32[8]{0}, (f32[128]{0}, f32[256]{0}), f32[1024]{0}, f32[4,8]{1,0}) custom-call([[T0]], [[P2]]), ; CHECK: custom_call_target="Callback_Void" ; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[8]{0} get-tuple-element([[CC]]), index=0 ; CHECK-DAG: [[GTE1:%[^ ]+]] = (f32[128]{0}, f32[256]{0}) get-tuple-element([[CC]]), index=1 From 42883cb09d1a8155824ce4ed044794c0dffdd19f Mon Sep 17 00:00:00 2001 From: Aliia Khasanova Date: Thu, 28 Mar 2024 04:24:07 -0700 Subject: [PATCH 012/124] Run xla triton hopper tests on demand on TAP PiperOrigin-RevId: 619877948 --- third_party/xla/xla/service/gpu/BUILD | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index b2483ae3b3c115..3ae6b8646d3822 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -600,6 +600,7 @@ xla_test( backends = [ "gpu_a100", "gpu_v100", + "gpu_h100", ], shard_count = 20, tags = ["nomac"], @@ -653,7 +654,10 @@ xla_test( backend_tags = {"gpu": [ "requires-gpu-sm70", ]}, - backends = ["gpu"], + backends = [ + "gpu", + "gpu_h100", + ], tags = [ "large", "no_oss", # requires-mem:16g tag doesn't work in open source From 462d6e01352ed769ada17d509a5d1fdda5e2ba51 Mon Sep 17 00:00:00 2001 From: Vladyslav Tsilytskyi Date: Thu, 28 Mar 2024 04:26:54 -0700 Subject: [PATCH 013/124] [stream_executor:host] Add LLVM kernel support in kernel_spec Related to https://github.com/openxla/xla/issues/7234 PiperOrigin-RevId: 619878628 --- third_party/xla/xla/stream_executor/kernel.h | 19 ++++++++++++ .../xla/xla/stream_executor/kernel_spec.cc | 20 ++++++++++++- .../xla/xla/stream_executor/kernel_spec.h | 30 +++++++++++++++++++ 3 files changed, 68 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/stream_executor/kernel.h b/third_party/xla/xla/stream_executor/kernel.h index a88cc59071d099..edf0e24b31a119 100644 --- a/third_party/xla/xla/stream_executor/kernel.h +++ b/third_party/xla/xla/stream_executor/kernel.h @@ -312,6 +312,14 @@ class TypedKernel { absl::string_view kernel_name, void *symbol); + // Creates a kernel which can be launched with `stream.ThenLaunch(...)` from + // an LLVM IR. + static absl::StatusOr Create(StreamExecutor *executor, + absl::string_view ir, + absl::string_view entrypoint, + absl::string_view kernel_name, + absl::Span options); + TypedKernel() = default; Kernel &operator*() { return *kernel_; } @@ -757,6 +765,17 @@ inline absl::StatusOr> TypedKernel::Create( return TypedKernel::Create(executor, loader_spec); } +template +inline absl::StatusOr> TypedKernel::Create( + StreamExecutor *executor, absl::string_view ir, + absl::string_view entrypoint, absl::string_view kernel_name, + absl::Span options) { + MultiKernelLoaderSpec loader_spec(TypedKernel::kNumberOfParameters); + loader_spec.AddLlvmHostKernel(ir, entrypoint, kernel_name, options); + + return TypedKernel::Create(executor, loader_spec); +} + } // namespace stream_executor #endif // XLA_STREAM_EXECUTOR_KERNEL_H_ diff --git a/third_party/xla/xla/stream_executor/kernel_spec.cc b/third_party/xla/xla/stream_executor/kernel_spec.cc index fe29fabe52643c..5f7077e991bbbc 100644 --- a/third_party/xla/xla/stream_executor/kernel_spec.cc +++ b/third_party/xla/xla/stream_executor/kernel_spec.cc @@ -59,6 +59,15 @@ CudaPtxInMemory::CudaPtxInMemory( } } +LlvmHostKernel::LlvmHostKernel(absl::string_view ir, + absl::string_view entrypoint, + absl::string_view kernel_name, + absl::Span options) + : KernelLoaderSpec(std::move(kernel_name)), + ir_(ir), + entrypoint_(entrypoint), + options_(options.cbegin(), options.cend()) {} + const char *CudaPtxInMemory::default_text() const { if (ptx_by_compute_capability_.empty()) { return nullptr; @@ -84,7 +93,7 @@ MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddInProcessSymbol( void *symbol, absl::string_view kernel_name) { CHECK(in_process_symbol_ == nullptr); in_process_symbol_ = - std::make_unique(symbol, std::string(kernel_name)); + std::make_shared(symbol, std::string(kernel_name)); return this; } @@ -102,6 +111,15 @@ MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaPtxInMemory( return this; } +MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddLlvmHostKernel( + absl::string_view ir, absl::string_view entrypoint, + absl::string_view kernel_name, absl::Span options) { + CHECK(llvm_host_kernel_ == nullptr); + llvm_host_kernel_ = + std::make_shared(ir, entrypoint, kernel_name, options); + return this; +} + MultiKernelLoaderSpec::MultiKernelLoaderSpec( size_t arity, KernelArgsPacking kernel_args_packing) : arity_(arity), kernel_args_packing_(std::move(kernel_args_packing)) {} diff --git a/third_party/xla/xla/stream_executor/kernel_spec.h b/third_party/xla/xla/stream_executor/kernel_spec.h index aa75e4a7b7454e..d50ac23713dc5e 100644 --- a/third_party/xla/xla/stream_executor/kernel_spec.h +++ b/third_party/xla/xla/stream_executor/kernel_spec.h @@ -175,6 +175,25 @@ class CudaCubinInMemory : public KernelLoaderSpec { void operator=(const CudaCubinInMemory &) = delete; }; +class LlvmHostKernel : public KernelLoaderSpec { + public: + LlvmHostKernel(absl::string_view ir, absl::string_view entrypoint, + absl::string_view kernel_name, + absl::Span options); + + absl::string_view ir() const { return ir_; } + absl::string_view entrypoint() const { return entrypoint_; } + absl::Span options() const { return options_; } + + private: + std::string ir_; + std::string entrypoint_; + std::vector options_; + + LlvmHostKernel(const LlvmHostKernel &) = delete; + void operator=(const LlvmHostKernel &) = delete; +}; + // Describes how to load a kernel on any subset of a number of target platforms. class MultiKernelLoaderSpec { public: @@ -199,6 +218,7 @@ class MultiKernelLoaderSpec { return cuda_cubin_in_memory_ != nullptr; } bool has_cuda_ptx_in_memory() const { return cuda_ptx_in_memory_ != nullptr; } + bool has_llvm_host_kernel() const { return llvm_host_kernel_ != nullptr; } // Accessors for platform variant kernel load specifications. // Precondition: corresponding has_* is true. @@ -214,6 +234,10 @@ class MultiKernelLoaderSpec { CHECK(has_cuda_ptx_in_memory()); return *cuda_ptx_in_memory_; } + const LlvmHostKernel &llvm_host_kernel() const { + CHECK(has_llvm_host_kernel()); + return *llvm_host_kernel_; + } // Builder-pattern-like methods for use in initializing a // MultiKernelLoaderSpec. Each of these should be used at most once for a // single MultiKernelLoaderSpec object. See file comment for example usage. @@ -227,6 +251,10 @@ class MultiKernelLoaderSpec { absl::Span cubin_bytes, absl::string_view kernel_name); MultiKernelLoaderSpec *AddCudaPtxInMemory(absl::string_view ptx, absl::string_view kernel_name); + MultiKernelLoaderSpec *AddLlvmHostKernel(absl::string_view ir, + absl::string_view entrypoint, + absl::string_view kernel_name, + absl::Span options); const KernelArgsPacking &kernel_args_packing() const { return kernel_args_packing_; @@ -239,6 +267,8 @@ class MultiKernelLoaderSpec { cuda_cubin_in_memory_; // Binary CUDA program in memory. std::shared_ptr cuda_ptx_in_memory_; // PTX text that resides in memory. + std::shared_ptr + llvm_host_kernel_; // LLVM kernel for host execution. // Number of parameters that the kernel takes. (This is nicer to have in a // constexpr than having to determine it from the types via template From b23fc08a6028eb0541ff3f826b9cae50e55ab8de Mon Sep 17 00:00:00 2001 From: Thai Nguyen Date: Thu, 28 Mar 2024 04:29:52 -0700 Subject: [PATCH 014/124] Avoid loading then saving reprensentative dataset again If the representative dataset is in QuantizationOptions, currently, we first load it then save it again. This is not efficient, just use it directly. PiperOrigin-RevId: 619879297 --- .../tensorflow/python/quantize_model.py | 132 ++++++++++-------- 1 file changed, 74 insertions(+), 58 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index 094f344221581f..e0eeca13d92f20 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -59,6 +59,8 @@ _QuantizationComponent = _QuantizationComponentSpec.QuantizationComponent _TensorType = _QuantizationComponentSpec.TensorType +_RepresentativeDatasetFile = quant_opts_pb2.RepresentativeDatasetFile + # Mapping of signature def key -> SignatureDef. _SignatureDefMap = Mapping[str, meta_graph_pb2.SignatureDef] @@ -99,6 +101,57 @@ def _serialize_signature_def_map( return signature_def_map_serialized +def _save_representative_dataset( + representative_dataset: repr_dataset.RepresentativeDatasetOrMapping, + signature_def_map: _SignatureDefMap, +) -> Mapping[str, _RepresentativeDatasetFile]: + """Saves the representative dataset to temporary TFRecord files. + + Args: + representative_dataset: Representative dataset used for the calibration + step. Representative datasets should exist for each signature def key in + `signature_def_keys`. + signature_def_map: Signature def key -> SignatureDef mapping. + + Returns: + A map from signature key to the saved representative dataset file. + """ + if isinstance(representative_dataset, Mapping): + if set(signature_def_map.keys()) != set(representative_dataset.keys()): + raise ValueError( + 'The signature keys and the keys of representative dataset map ' + f'do not match. Signature keys: {set(signature_def_map.keys())}, ' + f'representative dataset map: {set(representative_dataset.keys())}.' + ) + representative_dataset_map = representative_dataset + elif len(signature_def_map.keys()) > 1: + raise ValueError( + 'Representative dataset is not a mapping (got: ' + f'{type(representative_dataset)}), but there is more than one ' + 'signature key provided. Please provide a map of ' + '{signature_key -> dataset} with more than one signature key.' + ) + else: + representative_dataset_map = { + list(signature_def_map.keys())[0]: representative_dataset, + } + + # Save the representative dataset to temporary TFRecord files. + path_map = {} + expected_input_key_map = {} + for signature_key, signature_def in signature_def_map.items(): + # Filepath is the second return value of mkstemp. + _, path_map[signature_key] = tempfile.mkstemp( + suffix='.tfrecord', prefix=signature_key + ) + expected_input_key_map[signature_key] = signature_def.inputs.keys() + + return repr_dataset.TfRecordRepresentativeDatasetSaver( + path_map=path_map, + expected_input_key_map=expected_input_key_map, + ).save(representative_dataset_map) + + def _run_static_range_qat( src_saved_model_path: str, dst_saved_model_path: str, @@ -133,7 +186,7 @@ def _run_static_range_ptq( src_saved_model_path: str, dst_saved_model_path: str, quant_opts: _QuantizationOptions, - representative_dataset: repr_dataset.RepresentativeDatasetOrMapping, + representative_dataset: Mapping[str, _RepresentativeDatasetFile], signature_def_map: _SignatureDefMap, ) -> None: """Runs static-range Post-Training Quantization. @@ -147,9 +200,8 @@ def _run_static_range_ptq( src_saved_model_path: Path to the source SavedModel directory. dst_saved_model_path: Path to the destination SavedModel directory. quant_opts: Quantization options. - representative_dataset: Representative dataset used for the calibration - step. Representative datasets should exist for each signature def key in - `signature_def_keys`. + representative_dataset: A map from signature key to the saved representative + dataset file. signature_def_map: Signature def key -> SignatureDef mapping. Raises: @@ -159,48 +211,11 @@ def _run_static_range_ptq( signature_def_map_serialized = _serialize_signature_def_map(signature_def_map) - if isinstance(representative_dataset, Mapping): - if set(signature_def_map.keys()) != set(representative_dataset.keys()): - raise ValueError( - 'The signature keys and the keys of representative dataset map ' - f'do not match. Signature keys: {set(signature_def_map.keys())}, ' - f'representative dataset map: {set(representative_dataset.keys())}.' - ) - representative_dataset_map = representative_dataset - elif len(signature_def_map.keys()) > 1: - raise ValueError( - 'Representative dataset is not a mapping (got: ' - f'{type(representative_dataset)}), but there is more than one ' - 'signature key provided. Please provide a map of ' - '{signature_key -> dataset} with more than one signature key.' - ) - else: - representative_dataset_map = { - list(signature_def_map.keys())[0]: representative_dataset, - } - - # Save the representative dataset to temporary TFRecord files. - # TODO: b/329552787 - If the representative dataset is in QuantizationOptions - # avoid loading then saving it again. - path_map = {} - expected_input_key_map = {} - for signature_key, signature_def in signature_def_map.items(): - # Filepath is the second return value of mkstemp. - _, path_map[signature_key] = tempfile.mkstemp( - suffix='.tfrecord', prefix=signature_key - ) - expected_input_key_map[signature_key] = signature_def.inputs.keys() - - dataset_file_map = repr_dataset.TfRecordRepresentativeDatasetSaver( - path_map=path_map, - expected_input_key_map=expected_input_key_map, - ).save(representative_dataset_map) - # `quantize_ptq_static_range` requires `RepresentativeDatasetFile`s to be # serialized. Serialize the values to match the type. dataset_file_map_serialized = { signature_key: dataset_file.SerializeToString() - for signature_key, dataset_file in dataset_file_map.items() + for signature_key, dataset_file in representative_dataset.items() } pywrap_quantize_model.quantize_ptq_static_range( src_saved_model_path, @@ -265,9 +280,24 @@ def _static_range_quantize( set(quantization_options.tags), ) + if ( + representative_dataset is not None + and quantization_options.representative_datasets + ): + raise ValueError( + 'Do not specify both the `representative_dataset` argument and' + ' the `representative_datasets` field in `QuantizationOptions`.' + ) + + saved_representative_dataset = quantization_options.representative_datasets + if representative_dataset is not None: + saved_representative_dataset = _save_representative_dataset( + representative_dataset, signature_def_map + ) + # Checks if the model is from QAT or method is METHOD_NO_QUANTIZE. if ( - representative_dataset is None + not saved_representative_dataset and not is_qat_saved_model_or_method_no_quantize ): raise ValueError( @@ -293,7 +323,7 @@ def _static_range_quantize( src_saved_model_path, dst_saved_model_path, quantization_options, - representative_dataset, + saved_representative_dataset, signature_def_map, ) @@ -859,20 +889,6 @@ def quantize( _populate_quantization_options_default_values(quantization_options) - if ( - representative_dataset is not None - and quantization_options.representative_datasets - ): - raise ValueError( - 'Do not specify both the `representative_dataset` argument and' - ' the `representative_datasets` field in `QuantizationOptions`.' - ) - - if quantization_options.representative_datasets: - representative_dataset = repr_dataset.TfRecordRepresentativeDatasetLoader( - quantization_options.representative_datasets - ).load() - method: _QuantizationMethod = quantization_options.quantization_method if ( method.preset_method == _PresetMethod.METHOD_STATIC_RANGE_INT8 From 2218d5f10bda1bd2d9c1eb8b1aa60b2877c3acbd Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Thu, 28 Mar 2024 05:28:47 -0700 Subject: [PATCH 015/124] Use indexing map symbol rescaling in the ReduceWindow emitter We current codegen the reduce window op as a loop nest with the inner loop iterating over the window and accumulating input values. If the op had a base dilation set we would still generate the same inner loop but the loop body now woudl have an additional condition that checks whether we are in bounds of the dilation. This change makes use of IndexingMap's symbol rescaling which results in the generation of an inner loop with fewer iterations. It also avoids the in-bounds check by only iterating over the tensor elements that actually need to be accumulated. PiperOrigin-RevId: 619893695 --- .../gpu/fusions/mlir/elemental_hlo_to_mlir.cc | 6 +-- .../mlir/elemental_hlo_to_mlir_test.cc | 41 +++++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc index 321b06c9c42a97..4e3e181f94c817 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -311,7 +311,8 @@ absl::StatusOr> EmitReduceWindow( MLIRContext* mlir_context = b.getContext(); HloInstructionIndexing indexing = ComputeOutputToInputIndexing(instr, 0, mlir_context); - const auto& indexing_map = *indexing.indexing_maps[0].begin(); + auto indexing_map = *indexing.indexing_maps[0].begin(); + indexing_map.RescaleSymbols(); auto reduce_window = DynCast(instr); CHECK(reduce_window != nullptr); @@ -1228,9 +1229,6 @@ void GetLoopBoundsFromIndexingMap(ImplicitLocOpBuilder& b, for (const Interval& bound : indexing_map.GetSymbolBounds()) { lbs->push_back(b.create(bound.lower)); ubs->push_back(b.create(bound.upper + 1)); - // Note that this is not optimal, when there are mod constraints on symbols, - // e.g. for reduce-window. In that case we have to extract loop steps from - // the mod constraints. steps->push_back(c1); } } diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc index 3bb4e9fbd1f01d..1acdff315457e4 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -245,6 +245,47 @@ TEST_F(ElementalHloToMlirTest, ReduceWindow) { )")); } +TEST_F(ElementalHloToMlirTest, ReduceWindowWithRescaling) { + TF_EXPECT_OK(Run(R"( + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT sum = f32[] add(p0, p1) + } + + ENTRY main { + p0 = f32[42,12,8] parameter(0) + p1 = f32[] parameter(1) + ROOT r = f32[19,12,8] reduce-window(p0, p1), window={ + size=8x1x1 + stride=4x1x1 + pad=0_0x0_0x0_0 + lhs_dilate=2x1x1 + }, + to_apply=add + })", + R"( + // CHECK: @main_r( + // CHECK-SAME: %[[ARG0:.*]]: tensor<42x12x8xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor + // CHECK-SAME: %[[X:arg[0-9]*]]: index {{[^}]*}}}, + // CHECK-SAME: %[[Y:arg[0-9]*]]: index {{[^}]*}}}, + // CHECK-SAME: %[[Z:arg[0-9]*]]: index {{[^}]*}}}) -> f32 + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + + // We have a window size of 8, but expect a loop from 0 to 4 + // due to the base dilation of 2 and the applied symbol rescaling: + // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] + // CHECK: %[[K:.*]] = affine.apply affine_map<()[s0, s1] -> + // If symbol rescaling wasn't working we would have a + // `s0 floordiv ` in the map: + // CHECK-SAME: (s0 + s1 * 2)>()[%[[I]], %[[X]]] + // CHECK: tensor.extract %[[ARG0]][%[[K]], %[[Y]], %[[Z]]] + )")); +} + TEST_F(ElementalHloToMlirTest, Concatenate) { TF_EXPECT_OK(Run(R"( ENTRY main { From 1713d5563df74c0b384159dfdd27f4410b74978a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 28 Mar 2024 05:36:09 -0700 Subject: [PATCH 016/124] [XLA:Python] Refactor jit argument parsing code. * Put the logic to split static and dynamic arguments under ParseArguments, which returns a ArgumentSignature and a list of dynamic arguments. * Move the other argument parsing logic for jit and pmap into a separate CallSignature type. This simplifies the code, and prepares for adding another use case for the argument parsing code without the rest of the jit logic. Refactoring only, no functional changes intended. PiperOrigin-RevId: 619895559 --- third_party/xla/xla/python/BUILD | 4 +- third_party/xla/xla/python/jax_jit.cc | 140 +++++++++++---------- third_party/xla/xla/python/jax_jit.h | 161 +++++++++++++------------ third_party/xla/xla/python/pjit.cc | 109 +++++++++-------- third_party/xla/xla/python/pmap_lib.cc | 46 +++---- 5 files changed, 246 insertions(+), 214 deletions(-) diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 9938364cfd94bc..91b78aa20c8822 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -537,8 +537,6 @@ cc_library( "@local_config_python//:python_headers", # build_cleaner: keep "//xla/pjrt:pjrt_client", "//xla/pjrt:status_casters", - "//xla/python/ifrt", - "@local_tsl//tsl/concurrency:ref_count", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/profiler/lib:traceme", ], @@ -733,6 +731,7 @@ cc_library( ":types", # placeholder for index annotation deps "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -747,7 +746,6 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/pjrt:exceptions", "//xla/pjrt:pjrt_client", - "//xla/pjrt:pjrt_future", "//xla/pjrt:status_casters", "//xla/python/ifrt", "@local_tsl//tsl/concurrency:ref_count", diff --git a/third_party/xla/xla/python/jax_jit.cc b/third_party/xla/xla/python/jax_jit.cc index 726972f0d7fc6d..20ff684cd38ca4 100644 --- a/third_party/xla/xla/python/jax_jit.cc +++ b/third_party/xla/xla/python/jax_jit.cc @@ -38,6 +38,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -138,13 +139,60 @@ bool FetchMemoriesFlag() { *global_state.enable_memories); } -std::string CallSignature::DebugString() const { +std::string ArgumentSignature::DebugString() const { auto py_object_formatter = [](std::string* out, const nb::object& o) { out->append(nb::cast(nb::str(o))); }; auto treedef_formatter = [](std::string* out, const xla::PyTreeDef& d) { out->append(d.ToString()); }; + return absl::StrFormat( + "static args (positional + keyword): [%s], " + "static arg keyword names: [%s], " + "dynamic arg signatures (positional + keyword): [%s]" + "dynamic arg shardings: [%s]", + absl::StrJoin(static_args, ",", py_object_formatter), + absl::StrJoin(static_arg_names, ",", py_object_formatter), + absl::StrJoin(dynamic_arg_names, ",", py_object_formatter), + absl::StrJoin(dynamic_arg_treedefs, "| ", treedef_formatter)); +} + +bool ArgumentSignature::operator==(const ArgumentSignature& other) const { + if (dynamic_arg_treedefs != other.dynamic_arg_treedefs) { + return false; + } + auto object_ptr_equality = [](nb::handle a, nb::handle b) { + return a.ptr() == b.ptr(); + }; + if (!absl::c_equal(dynamic_arg_names, other.dynamic_arg_names, + object_ptr_equality)) { + return false; + } + if (!absl::c_equal(static_arg_names, other.static_arg_names, + object_ptr_equality)) { + return false; + } + return absl::c_equal( + static_args, other.static_args, + [](const nb::object& a, const nb::object& b) { + try { + return a.type().ptr() == b.type().ptr() && a.equal(b); + } catch (const nb::python_error& e) { + throw std::invalid_argument(absl::StrCat( + "static arguments should be comparable using __eq__." + "The following error was raised when comparing two objects of " + "types ", + nb::cast(nb::str(a.type())), " and ", + nb::cast(nb::str(b.type())), + ". The error was:\n", e.what())); + } + }); +} + +std::string CallSignature::DebugString() const { + auto py_object_formatter = [](std::string* out, const nb::object& o) { + out->append(nb::cast(nb::str(o))); + }; auto signature_formatter = [](std::string* out, const xla::PyArgSignature& s) { out->append(s.DebugString()); @@ -153,25 +201,20 @@ std::string CallSignature::DebugString() const { out->append(o ? "true" : "false"); }; return absl::StrFormat( - "static args (positional + keyword): %s\nstatic arg keyword names: %s\n" + "arg signature: %s\n" "dynamic arg signatures (positional + keyword): %s\n" "dynamic arg shardings: %s\n" "committed args: %s\n" - "dynamic arg keyword names: %s\n" - "dynamic arg treedefs: %s\n" "device: %s\n" "default_device: %s\n" "jax_enable_x64: %d\n" "jax_enable_memories: %d\n" "global_extra_jit_context: %s\n" "thread_local_extra_jit_context: %s\n", - absl::StrJoin(static_args, ",", py_object_formatter), - absl::StrJoin(static_arg_names, ",", py_object_formatter), + arg_signature.DebugString(), absl::StrJoin(dynamic_arg_signatures, ", ", signature_formatter), absl::StrJoin(dynamic_arg_shardings, ", ", py_object_formatter), absl::StrJoin(committed_args, ",", bool_formatter), - absl::StrJoin(dynamic_arg_names, ",", py_object_formatter), - absl::StrJoin(dynamic_arg_treedefs, "| ", treedef_formatter), // new line device != nullptr ? device->DebugString() : "nullptr", OptionalDebugString(default_device), jax_enable_x64, jax_enable_memories, OptionalDebugString(global_extra_jit_context), @@ -179,14 +222,7 @@ std::string CallSignature::DebugString() const { } bool CallSignature::operator==(const CallSignature& other) const { - if (dynamic_arg_treedefs != other.dynamic_arg_treedefs) { - return false; - } - auto object_ptr_equality = [](nb::handle a, nb::handle b) { - return a.ptr() == b.ptr(); - }; - if (!absl::c_equal(dynamic_arg_names, other.dynamic_arg_names, - object_ptr_equality)) { + if (arg_signature != other.arg_signature) { return false; } if (dynamic_arg_signatures != other.dynamic_arg_signatures) { @@ -201,10 +237,6 @@ bool CallSignature::operator==(const CallSignature& other) const { if (jax_enable_memories != other.jax_enable_memories) { return false; } - if (!absl::c_equal(static_arg_names, other.static_arg_names, - object_ptr_equality)) { - return false; - } if (committed_args != other.committed_args) { return false; } @@ -212,21 +244,6 @@ bool CallSignature::operator==(const CallSignature& other) const { // `==` on py:objects is the Python `is`. We need equal. absl::c_equal(dynamic_arg_shardings, other.dynamic_arg_shardings, ShardingEqual) && - absl::c_equal( - static_args, other.static_args, - [this](const nb::object& a, const nb::object& b) { - try { - return a.type().ptr() == b.type().ptr() && a.equal(b); - } catch (const nb::python_error& e) { - throw std::invalid_argument(absl::StrCat( - "static arguments should be comparable using __eq__." - "The following error was raised during a call to '", - function_name, "' when comparing two objects of types ", - nb::cast(nb::str(a.type())), " and ", - nb::cast(nb::str(b.type())), - ". The error was:\n", e.what())); - } - }) && (global_extra_jit_context.has_value() == other.global_extra_jit_context.has_value()) && (!global_extra_jit_context.has_value() || @@ -243,41 +260,37 @@ bool CallSignature::operator==(const CallSignature& other) const { // Filter out static arguments, flatten and concatenate other arguments (i.e. // dynamic positional and keyword arguments), filling `arguments` in place. -absl::Status ParseArguments(absl::Span positional_args, - absl::Span keyword_args, - nb::handle kwnames, - absl::Span static_argnums, - absl::Span static_argnames, - xla::PyTreeRegistry* pytree_registry, - ParsedArgumentsAsBuffers& arguments) { +absl::Status ParseArguments( + absl::Span positional_args, + absl::Span keyword_args, nb::handle kwnames, + absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry, ArgumentSignature& signature, + absl::InlinedVector& flat_dynamic_args) { tsl::profiler::TraceMe traceme("ParseArguments"); - arguments.flat_dynamic_args.reserve(positional_args.size() + - keyword_args.size()); + flat_dynamic_args.reserve(positional_args.size() + keyword_args.size()); if (static_argnums.empty()) { - arguments.signature.dynamic_arg_treedefs.reserve(positional_args.size()); + signature.dynamic_arg_treedefs.reserve(positional_args.size()); // Positional arguments. for (int i = 0; i < positional_args.size(); ++i) { - arguments.signature.dynamic_arg_treedefs.emplace_back(pytree_registry); - xla::PyTreeDef& pytree_def = - arguments.signature.dynamic_arg_treedefs.back(); - pytree_def.Flatten(nb::handle(positional_args[i]), - arguments.flat_dynamic_args); + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(nb::handle(positional_args[i]), flat_dynamic_args); } } else { - arguments.signature.dynamic_arg_treedefs.reserve(positional_args.size()); + signature.dynamic_arg_treedefs.reserve(positional_args.size()); // Positional arguments. for (int i = 0; i < positional_args.size(); ++i) { if (std::find(static_argnums.begin(), static_argnums.end(), i) == static_argnums.end()) { - arguments.signature.dynamic_arg_treedefs.emplace_back(pytree_registry); - xla::PyTreeDef& pytree_def = - arguments.signature.dynamic_arg_treedefs.back(); - pytree_def.Flatten(positional_args[i], arguments.flat_dynamic_args); + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(positional_args[i], flat_dynamic_args); } else { - arguments.signature.static_args.emplace_back( + signature.static_args.emplace_back( nb::borrow(positional_args[i])); } } @@ -313,21 +326,20 @@ absl::Status ParseArguments(absl::Span positional_args, return false; }; - arguments.signature.dynamic_arg_names.reserve(keyword_args.size()); + signature.dynamic_arg_names.reserve(keyword_args.size()); for (int i = 0; i < keyword_args.size(); ++i) { if (kwarg_is_static(kwargs[i].first)) { - arguments.signature.static_arg_names.push_back( + signature.static_arg_names.push_back( nb::steal(kwargs[i].first)); - arguments.signature.static_args.push_back( + signature.static_args.push_back( nb::borrow(kwargs[i].second)); } else { - arguments.signature.dynamic_arg_names.push_back( + signature.dynamic_arg_names.push_back( nb::steal(kwargs[i].first)); - arguments.signature.dynamic_arg_treedefs.emplace_back(pytree_registry); - xla::PyTreeDef& pytree_def = - arguments.signature.dynamic_arg_treedefs.back(); + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); pytree_def.Flatten(nb::handle(kwargs[i].second.ptr()), - arguments.flat_dynamic_args); + flat_dynamic_args); } } } diff --git a/third_party/xla/xla/python/jax_jit.h b/third_party/xla/xla/python/jax_jit.h index b16b111c4a5b70..c076ef9cbeabe1 100644 --- a/third_party/xla/xla/python/jax_jit.h +++ b/third_party/xla/xla/python/jax_jit.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include #include #include #include @@ -32,13 +33,11 @@ limitations under the License. #include "absl/types/span.h" #include "third_party/nanobind/include/nanobind/nanobind.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/python/ifrt/array.h" #include "xla/python/nb_helpers.h" #include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" #include "xla/python/pytree.h" #include "xla/python/sharding.h" -#include "tsl/concurrency/ref_count.h" #include "tsl/platform/logging.h" namespace jax { @@ -93,6 +92,88 @@ bool GetEnableX64(); std::optional GetDefaultDevice(); std::optional GetPostHook(); +// An ArgumentSignature describes the static arguments to a function call, and +// how the dynamic arguments are related to the arguments. Together with the +// values of the dynamic arguments, this fully describes the arguments. +struct ArgumentSignature { + // A PyTreeDef for each dynamic argument, positional arguments first + // followed by keyword arguments. Keyword arguments are in the order given + // by dynamic_arg_names. + absl::InlinedVector dynamic_arg_treedefs; + + // Dynamic keyword argument names. Interned, and sorted by the keyword + // name. Interned values are safe to compare by pointer. + std::vector dynamic_arg_names; + + // Static arguments. Contains the positional arguments sorted in argument + // order, followed by static keyword arguments in the order given by + // `static_arg_names`. + std::vector static_args; + + // Static keyword argument names. Interned, and sorted by keyword name. + std::vector static_arg_names; + + bool operator==(const ArgumentSignature& other) const; + bool operator!=(const ArgumentSignature& other) const { + return !(*this == other); + } + + std::string DebugString() const; +}; + +template +H AbslHashValue(H h, const ArgumentSignature& s) { + h = H::combine(std::move(h), s.dynamic_arg_treedefs, + s.dynamic_arg_names.size(), s.static_args.size(), + s.static_arg_names.size()); + + for (const auto& name : s.dynamic_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + for (size_t i = 0; i < s.static_args.size(); ++i) { + const auto& static_arg = s.static_args[i]; + Py_hash_t hash; + try { + hash = xla::nb_hash(static_arg); + } catch (const nanobind::python_error& e) { + if (!e.matches(PyExc_TypeError)) throw; + throw std::invalid_argument(absl::StrCat( + "Non-hashable static arguments are not supported. An error occurred " + "while trying to hash an object of type ", + nanobind::cast(nanobind::str(static_arg.type())), + ", ", nanobind::cast(nanobind::str(static_arg)), + ". The error was:\n", e.what(), "\n")); + } + h = H::combine(std::move(h), hash); + } + for (const auto& name : s.static_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + return h; +} + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `arguments` in place. +// Args: +// positional_args: positional arguments +// keyword_args: the values of the keyword arguments +// kwnames: either None or a tuple containing the keyword argument names +// static_argnums: the indices of the static arguments in the positional +// arguments +// static_argnames: the names of the static arguments +// pytree_registry: the registry to use to convert the arguments to pytrees +// arguments: output; describes the static arguments and the identities of the +// dynamic arguments. +// flat_dynamic_args: output; the concatenation of the dynamic positional +// arguments and sorted keyword arguments. +absl::Status ParseArguments( + absl::Span positional_args, + absl::Span keyword_args, nanobind::handle kwnames, + absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry, ArgumentSignature& signature, + absl::InlinedVector& flat_dynamic_args); + // The signature of Python jitted function call, partitioned into: // - dynamic positional arguments (i.e. positional args which are not static) // - static positional arguments (i.e. the args associated to static_argnums) @@ -106,13 +187,8 @@ struct CallSignature { // Not part of the signature, but we need it for error messages. std::string_view function_name; - // A PyTreeDef for each dynamic argument, positional arguments first - // followed by keyword arguments. Keyword arguments are in the order given - // by dynamic_arg_names. - absl::InlinedVector dynamic_arg_treedefs; - // Dynamic keyword argument names. Interned, and sorted by the keyword - // name. - std::vector dynamic_arg_names; + ArgumentSignature arg_signature; + // Shape and dtype for both the dynamic positional arguments and the keyword // arguments (sorted by keyword name). absl::InlinedVector dynamic_arg_signatures; @@ -121,13 +197,6 @@ struct CallSignature { // jax.Array enabled. std::vector dynamic_arg_shardings; - // Static arguments. Contains the positional arguments sorted in argument - // order, followed by static keyword arguments in the order given by - // `static_arg_names`. - std::vector static_args; - // Static keyword argument names. Interned, and sorted by keyword name. - std::vector static_arg_names; - absl::InlinedVector committed_args; // For JIT, we need this in the key because computation follows the data, so @@ -155,8 +224,7 @@ struct CallSignature { template H AbslHashValue(H h, const CallSignature& s) { - h = H::combine(std::move(h), s.dynamic_arg_treedefs, - s.dynamic_arg_signatures); + h = H::combine(std::move(h), s.arg_signature, s.dynamic_arg_signatures); DCHECK(s.dynamic_arg_shardings.empty() || s.dynamic_arg_shardings.size() == s.dynamic_arg_signatures.size()); @@ -169,35 +237,7 @@ H AbslHashValue(H h, const CallSignature& s) { h = H::combine(std::move(h), ShardingHash(sharding.ptr())); } - for (const auto& name : s.dynamic_arg_names) { - h = H::combine(std::move(h), name.ptr()); - } - - h = H::combine(std::move(h), s.committed_args); - - h = H::combine(std::move(h), s.dynamic_arg_names.size()); - for (const auto& static_arg : s.static_args) { - ssize_t hash; - try { - hash = xla::nb_hash(static_arg); - } catch (const nanobind::python_error& e) { - if (!e.matches(PyExc_TypeError)) throw; - throw std::invalid_argument(absl::StrCat( - "Non-hashable static arguments are not supported. An error occurred " - "during a call to '", - s.function_name, "' while trying to hash an object of type ", - nanobind::cast(nanobind::str(static_arg.type())), - ", ", nanobind::cast(nanobind::str(static_arg)), - ". The error was:\n", e.what(), "\n")); - } - h = H::combine(std::move(h), hash); - } - h = H::combine(std::move(h), s.static_args.size()); - for (const auto& name : s.static_arg_names) { - h = H::combine(std::move(h), name.ptr()); - } - h = H::combine(std::move(h), s.static_arg_names.size()); - h = H::combine(std::move(h), s.device, s.jax_enable_x64); + h = H::combine(std::move(h), s.committed_args, s.device, s.jax_enable_x64); // We do not hash the extra_jit_context fields since calling Python hash // functions is expensive (~300ns) and we don't expect a large number of @@ -205,33 +245,6 @@ H AbslHashValue(H h, const CallSignature& s) { return h; } -// The resulting information of the parsing and conversion of the arguments. -struct ParsedArgumentsAsBuffers { - // The call signature will be filled during 2 steps: - // - `ParseArguments` will fill the static arguments and the pytree - // structures - // - the shapes and dtypes are filled later, by `ParseAndTransferArguments`. - CallSignature signature; - // The concatenation of the dynamic positional arguments and the sorted - // keyword arguments. - absl::InlinedVector flat_dynamic_args; - std::vector keep_alive_objects; - - xla::ifrt::Client* ifrt_client; - // The following is only valid if the parsing succeeds. - std::vector> ifrt_arg_arrays; -}; - -// Filter out static arguments, flatten and concatenate other arguments (i.e. -// dynamic positional and keyword arguments), filling `arguments` in place. -absl::Status ParseArguments(absl::Span positional_args, - absl::Span keyword_args, - nanobind::handle kwnames, - absl::Span static_argnums, - absl::Span static_argnames, - xla::PyTreeRegistry* pytree_registry, - ParsedArgumentsAsBuffers& arguments); - // The function to call in `xla.cc` to add the bindings for this module. void BuildJaxjitSubmodule(nanobind::module_& m); diff --git a/third_party/xla/xla/python/pjit.cc b/third_party/xla/xla/python/pjit.cc index f3aa9b21a2ce54..c562eaa2eadd65 100644 --- a/third_party/xla/xla/python/pjit.cc +++ b/third_party/xla/xla/python/pjit.cc @@ -275,10 +275,11 @@ class PjitFunction { } private: - absl::Status UpdateArgsSignature(ParsedArgumentsAsBuffers& arguments); + absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, + CallSignature& call_signature); void PopulateCacheEntry(PjitCacheEntry& cache_entry, - const CallSignature& signature, const nb::tuple& out_and_fastpath_data); std::string function_name_; @@ -352,31 +353,32 @@ PjitFunction::~PjitFunction() { GetGlobalPjitFunctionStore().Erase(this); } void CallShardArgFallback( nb::handle arg, nb::handle sharding, const nb::callable& fallback, std::vector>& num_args_arrays, - ParsedArgumentsAsBuffers& arguments) { + std::vector& keep_alive_objects) { tsl::profiler::TraceMe traceme("cpp_pjit_shard_arg_fallback"); auto py_array_or_bufs = fallback(arg, sharding); auto py_array = nb::cast(py_array_or_bufs); num_args_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); - arguments.keep_alive_objects.push_back(std::move(py_array_or_bufs)); + keep_alive_objects.push_back(std::move(py_array_or_bufs)); } // Prepares the input PjRtBuffers from the python arguments. This is equivalent // to shard_args() in pxla.py but for only a few supported cases. absl::StatusOr>> PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, - ParsedArgumentsAsBuffers& arguments, - const std::vector& kept_args, + absl::Span flat_dynamic_args, + bool enable_x64, const std::vector& kept_args, const std::vector& in_shardings, - const nb::callable& shard_arg_fallback) { + const nb::callable& shard_arg_fallback, + std::vector& keep_alive_objects) { const auto& addressable_devices = executable.ifrt_loaded_executable()->addressable_devices(); - int num_args = arguments.flat_dynamic_args.size(); + int num_args = flat_dynamic_args.size(); std::vector> num_args_arrays; num_args_arrays.reserve(num_args); xla::DevicePutOptions options; - options.squash_64bit_types = !arguments.signature.jax_enable_x64; + options.squash_64bit_types = !enable_x64; options.allow_zero_copy = true; xla::PjRtDevice* data_device = nullptr; if (executable.ifrt_loaded_executable()->num_devices() == 1) { @@ -390,7 +392,7 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, int dce_index = dce_i; ++dce_i; - const nb::object& arg = arguments.flat_dynamic_args[i]; + const nb::object& arg = flat_dynamic_args[i]; auto transfer_guard_formatter = [] { return std::string(""); }; @@ -405,13 +407,13 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, num_args_arrays.push_back(std::move(on_device.ifrt_array)); if (on_device.owning_pybuffer) { - arguments.keep_alive_objects.push_back( - std::move(on_device.owning_pybuffer)); + keep_alive_objects.push_back(std::move(on_device.owning_pybuffer)); } continue; } else { CallShardArgFallback(arg.ptr(), in_shardings[dce_index], - shard_arg_fallback, num_args_arrays, arguments); + shard_arg_fallback, num_args_arrays, + keep_alive_objects); continue; } } @@ -428,20 +430,22 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, if (sharding.type().ptr() == jax::PmapSharding::type().ptr()) { CallShardArgFallback(arg.ptr(), in_shardings[dce_index], - shard_arg_fallback, num_args_arrays, arguments); + shard_arg_fallback, num_args_arrays, + keep_alive_objects); continue; } if (py_array.num_shards() != addressable_devices.size()) { CallShardArgFallback(arg.ptr(), in_shardings[dce_index], - shard_arg_fallback, num_args_arrays, arguments); + shard_arg_fallback, num_args_arrays, + keep_alive_objects); continue; } xla::ifrt::Array* ifrt_array = py_array.ifrt_array(); // PyArray inputs should have already been checked in // `xla::PyArgSignatureOfValue()` called by - // `PjitFunction::UpdateArgsSignature()`. + // `PjitFunction::ComputeCallSignature()`. DCHECK(ifrt_array != nullptr) << "PyArray has been unexpectedly deleted."; if (sharding_num_devices == 1 && @@ -460,7 +464,7 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, num_args_arrays.push_back(tsl::FormRef(ifrt_array)); } - arguments.keep_alive_objects.push_back(arg); + keep_alive_objects.push_back(arg); } return num_args_arrays; @@ -471,7 +475,6 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, size_t nargs, PyObject* kwnames) { tsl::profiler::TraceMe traceme( [&] { return absl::StrCat("PjitFunction(", function_name_, ")"); }); - ParsedArgumentsAsBuffers arguments; // Make sure we trigger a garbage collection on JIT function calls. Otherwise // code like @@ -516,9 +519,13 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, absl::Span positional_args(args, num_positional_args); absl::Span keyword_args(args + num_positional_args, num_keyword_args); - auto status = - ParseArguments(positional_args, keyword_args, kwnames, static_argnums_, - static_argnames_, pytree_registry_.get(), arguments); + + CallSignature call_signature; + std::vector keep_alive_objects; + absl::InlinedVector flat_dynamic_args; + auto status = ParseArguments( + positional_args, keyword_args, kwnames, static_argnums_, static_argnames_, + pytree_registry_.get(), call_signature.arg_signature, flat_dynamic_args); if (!status.ok()) { VLOG(2) << "ParseArguments failed: " << status; return fallback_to_cache_miss(); @@ -528,7 +535,7 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, // committed PyArray inputs. For other cases, e.g. Tracers or ShapedArray, it // will fallback to python. For jit, numpy arrays and scalars are also // allowed, which we will check later. - for (const auto& arg : arguments.flat_dynamic_args) { + for (const auto& arg : flat_dynamic_args) { if (arg.type().ptr() != xla::PyArray::type().ptr()) { continue; } @@ -552,17 +559,17 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, } } - status = UpdateArgsSignature(arguments); + status = ComputeCallSignature(flat_dynamic_args, call_signature); if (!status.ok()) { - VLOG(2) << "UpdateArgsSignature failed: " << status; + VLOG(2) << "ComputeCallSignature failed: " << status; return fallback_to_cache_miss(); } - VLOG(2) << "CallSignature:\n" << arguments.signature.DebugString(); + VLOG(2) << "CallSignature:\n" << call_signature.DebugString(); bool inserted = false; std::shared_ptr cache_entry = executables_->GetOrCreateIfAbsent( - arguments.signature, [this, &inserted](const CallSignature& unused) { + call_signature, [this, &inserted](const CallSignature& unused) { inserted = true; return std::make_shared(pytree_registry_.get()); }); @@ -573,7 +580,7 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, if (inserted) { nb::object out_and_fastpath_data; nb::tuple out_tuple; - VLOG(2) << "Cache miss for " << arguments.signature.DebugString(); + VLOG(2) << "Cache miss for " << call_signature.DebugString(); try { // Calls Python and may release the GIL. May also throw if // compilation/tracing fails. @@ -583,7 +590,7 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, } out_tuple = nb::cast(out_and_fastpath_data); - PopulateCacheEntry(*cache_entry, arguments.signature, out_tuple); + PopulateCacheEntry(*cache_entry, out_tuple); } catch (const std::exception& e) { VLOG(2) << "cache miss fail: " << e.what(); cache_entry->fall_back_to_python = true; @@ -599,7 +606,7 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, } else { if (cache_entry->thread_id == std::this_thread::get_id()) { auto error_string = absl::StrCat("Recursively calling jit: ", - arguments.signature.DebugString()); + call_signature.DebugString()); PyErr_SetString(PyExc_RecursionError, error_string.c_str()); throw nb::python_error(); } @@ -617,8 +624,9 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, // A vector of [num_inputs]. auto num_args_arrays = PrepareIfrtInputs( - *cache_entry->executable, arguments, cache_entry->kept_var_bitvec, - cache_entry->in_shardings, shard_arg_fallback_); + *cache_entry->executable, flat_dynamic_args, + call_signature.jax_enable_x64, cache_entry->kept_var_bitvec, + cache_entry->in_shardings, shard_arg_fallback_, keep_alive_objects); if (!num_args_arrays.ok()) { VLOG(2) << "Failed to prepare IFRT inputs: " << num_args_arrays.status(); @@ -684,49 +692,48 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, return out; } -absl::Status PjitFunction::UpdateArgsSignature( - ParsedArgumentsAsBuffers& arguments) { - arguments.signature.function_name = function_name_; +absl::Status PjitFunction::ComputeCallSignature( + absl::Span flat_dynamic_args, CallSignature& signature) { + signature.function_name = function_name_; // Get dynamic argument signatures. JitState& global_state = jax::GlobalJitState(); JitState& tls = jax::ThreadLocalJitState(); bool jax_enable_x64 = GetEnableX64(); - arguments.signature.default_device = GetDefaultDevice(); - arguments.signature.jax_enable_x64 = jax_enable_x64; - arguments.signature.jax_enable_memories = GetEnableMemories(); + signature.default_device = GetDefaultDevice(); + signature.jax_enable_x64 = jax_enable_x64; + signature.jax_enable_memories = GetEnableMemories(); - auto& dynamic_arg_signatures = arguments.signature.dynamic_arg_signatures; - dynamic_arg_signatures.reserve(arguments.flat_dynamic_args.size()); - auto& dynamic_arg_shardings = arguments.signature.dynamic_arg_shardings; - dynamic_arg_shardings.reserve(arguments.flat_dynamic_args.size()); + auto& dynamic_arg_signatures = signature.dynamic_arg_signatures; + dynamic_arg_signatures.reserve(flat_dynamic_args.size()); + auto& dynamic_arg_shardings = signature.dynamic_arg_shardings; + dynamic_arg_shardings.reserve(flat_dynamic_args.size()); - for (nb::handle arg : arguments.flat_dynamic_args) { - TF_ASSIGN_OR_RETURN(auto signature, + for (nb::handle arg : flat_dynamic_args) { + TF_ASSIGN_OR_RETURN(auto arg_signature, xla::PyArgSignatureOfValue(arg, jax_enable_x64)); - arguments.signature.dynamic_arg_signatures.push_back(std::move(signature)); + signature.dynamic_arg_signatures.push_back(std::move(arg_signature)); // It should be already checked previously in the entry point of // PjitFunction::Call(). if (arg.type().ptr() == xla::PyArray::type().ptr()) { auto py_array = nb::borrow(arg); - arguments.signature.dynamic_arg_shardings.push_back(py_array.sharding()); - arguments.signature.committed_args.push_back(py_array.committed()); + signature.dynamic_arg_shardings.push_back(py_array.sharding()); + signature.committed_args.push_back(py_array.committed()); } else { - arguments.signature.dynamic_arg_shardings.push_back(nb::none()); - arguments.signature.committed_args.push_back(false); + signature.dynamic_arg_shardings.push_back(nb::none()); + signature.committed_args.push_back(false); } } - arguments.signature.thread_local_extra_jit_context = tls.extra_jit_context; - arguments.signature.global_extra_jit_context = global_state.extra_jit_context; + signature.thread_local_extra_jit_context = tls.extra_jit_context; + signature.global_extra_jit_context = global_state.extra_jit_context; return absl::OkStatus(); } void PjitFunction::PopulateCacheEntry(PjitCacheEntry& cache_entry, - const CallSignature& signature, const nb::tuple& out_and_fastpath_data) { DCHECK_EQ(out_and_fastpath_data.size(), 2); diff --git a/third_party/xla/xla/python/pmap_lib.cc b/third_party/xla/xla/python/pmap_lib.cc index 9caaa46d0a44a6..61b16dedccbd1a 100644 --- a/third_party/xla/xla/python/pmap_lib.cc +++ b/third_party/xla/xla/python/pmap_lib.cc @@ -30,6 +30,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -330,7 +331,7 @@ class PmapFunction { } const std::vector& static_argnums() const { return static_argnums_; } - // nanobind::object typed subclass for PmapFunction objects. + // nb::object typed subclass for PmapFunction objects. class pyobject : public nb::object { public: NB_OBJECT(pyobject, nb::object, "PmapFunction", @@ -363,27 +364,28 @@ class PmapFunction { // // It deals with the arguments signatures and also of the global and // thread-local jit context. - absl::Status UpdateArgsSignature(ParsedArgumentsAsBuffers& arguments) { - arguments.signature.function_name = function_name_; + absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, + CallSignature& signature) { + signature.function_name = function_name_; // Get dynamic argument signatures. JitState& global_state = jax::GlobalJitState(); JitState& tls = jax::ThreadLocalJitState(); const bool jax_enable_x64 = GetEnableX64(); - arguments.signature.jax_enable_x64 = jax_enable_x64; - for (nb::handle arg : arguments.flat_dynamic_args) { + signature.jax_enable_x64 = jax_enable_x64; + for (nb::handle arg : flat_dynamic_args) { auto signature_or_error = xla::PyArgSignatureOfValue(arg, jax_enable_x64); if (!signature_or_error.ok()) { VLOG(2) << "PyArgSignatureOfValue failed: " << signature_or_error.status(); return signature_or_error.status(); } - arguments.signature.dynamic_arg_signatures.push_back( + signature.dynamic_arg_signatures.push_back( std::move(signature_or_error).value()); } - arguments.signature.thread_local_extra_jit_context = tls.extra_jit_context; - arguments.signature.global_extra_jit_context = - global_state.extra_jit_context; + signature.thread_local_extra_jit_context = tls.extra_jit_context; + signature.global_extra_jit_context = global_state.extra_jit_context; return absl::Status(); } @@ -404,7 +406,6 @@ class PmapFunction { private: // Mutates `cache_entry` in place. void PopulateCacheEntry(PmapCacheEntry& cache_entry, - const CallSignature& signature, const nb::tuple& out_and_fastpath_data); bool always_fallback_to_python_ = false; @@ -428,7 +429,6 @@ class PmapFunction { }; void PmapFunction::PopulateCacheEntry(PmapCacheEntry& cache_entry, - const CallSignature& signature, const nb::tuple& out_and_fastpath_data) { CHECK_EQ(out_and_fastpath_data.size(), 2); if (out_and_fastpath_data[1].is_none()) { @@ -551,16 +551,19 @@ absl::StatusOr PmapFunction::Call(nb::handle callable, absl::Span positional_args(args, num_positional_args); absl::Span keyword_args(args + num_positional_args, num_keyword_args); - ParsedArgumentsAsBuffers arguments; + CallSignature call_signature; + absl::InlinedVector flat_dynamic_args; + std::vector keep_alive_objects; absl::Status status = ParseArguments(positional_args, keyword_args, kwnames, static_argnums_, - /*static_argnames=*/{}, pytree_registry_.get(), arguments); + /*static_argnames=*/{}, pytree_registry_.get(), + call_signature.arg_signature, flat_dynamic_args); if (!status.ok()) { VLOG(2) << "ParseArguments failed: " << status; return fallback_to_cache_miss(); } - status = UpdateArgsSignature(arguments); + status = ComputeCallSignature(flat_dynamic_args, call_signature); if (!status.ok()) { return fallback_to_cache_miss(); } @@ -570,7 +573,7 @@ absl::StatusOr PmapFunction::Call(nb::handle callable, it; bool inserted; std::tie(it, inserted) = executables_.try_emplace( - arguments.signature, std::unique_ptr()); + call_signature, std::unique_ptr()); if (inserted) { it->second = std::make_unique(pytree_registry_.get()); } @@ -582,7 +585,7 @@ absl::StatusOr PmapFunction::Call(nb::handle callable, if (inserted) { nb::object out_and_fastpath_data; nb::tuple out_tuple; - VLOG(2) << "Cache miss for " << arguments.signature.DebugString(); + VLOG(2) << "Cache miss for " << call_signature.DebugString(); try { // Calls Python and may release the GIL. May also throw if // compilation/tracing fails. @@ -592,7 +595,7 @@ absl::StatusOr PmapFunction::Call(nb::handle callable, } out_tuple = nb::cast(out_and_fastpath_data); - PopulateCacheEntry(cache_entry, arguments.signature, out_tuple); + PopulateCacheEntry(cache_entry, out_tuple); } catch (const std::exception& e) { cache_entry.fall_back_to_python = true; cache_entry.compilation_complete.Notify(); @@ -618,20 +621,19 @@ absl::StatusOr PmapFunction::Call(nb::handle callable, // 1. Parse arguments. std::vector& input_devices = cache_entry.devices; std::vector& input_specs = cache_entry.input_specs; - const int num_args = arguments.flat_dynamic_args.size(); + const int num_args = flat_dynamic_args.size(); // We need [num_args] for the `Execute` call below. std::vector> num_args_arrays(num_args); for (int i = 0; i < num_args; ++i) { TF_ASSIGN_OR_RETURN( ShardArgResult sharded_arg, - ShardArg(arguments.flat_dynamic_args[i].ptr(), input_devices, - input_specs[i], cache_entry.py_devices, - python_shard_arg_fallback_)); + ShardArg(flat_dynamic_args[i].ptr(), input_devices, input_specs[i], + cache_entry.py_devices, python_shard_arg_fallback_)); num_args_arrays[i] = std::move(sharded_arg.ifrt_array); if (sharded_arg.owning_sda) { - arguments.keep_alive_objects.push_back(std::move(sharded_arg.owning_sda)); + keep_alive_objects.push_back(std::move(sharded_arg.owning_sda)); } } From 3f12e20fe4979527b1fdd85ec570d3334f9cd46a Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Thu, 28 Mar 2024 05:49:17 -0700 Subject: [PATCH 017/124] Add support for folding constant unary non-compute ops in RTVar optimization So far we can only optimize away RTVars that refer to a `constant` or `iota` HLO op. This change is extending that to unary ops without compute (given the operand is also constant). That includes: - Bitcast - Broadcast - Reshape - Reverse - Slice - Transpose PiperOrigin-RevId: 619898399 --- .../xla/xla/service/gpu/model/indexing_map.cc | 82 ++++++++++++----- .../service/gpu/model/indexing_map_test.cc | 88 +++++++++++++++++++ 2 files changed, 148 insertions(+), 22 deletions(-) diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 129a5e67cf7c4e..a740243bc2e795 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/types/span.h" @@ -38,6 +39,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/model/affine_map_printer.h" #include "tsl/platform/logging.h" // IWYU pragma: keep @@ -1337,26 +1339,50 @@ bool IndexingMap::RescaleSymbols() { return !to_delete.empty(); } -static std::optional FoldsIntoConstantIndexingExpression( - const HloInstruction* instr, const mlir::AffineMap& affine_map, - MLIRContext* mlir_context, +// Returns either: +// 1. an AffineExpr if the RTVar folds entirely into a constant expression +// 2. an updated RTVar if some partial optimization was possible +// 3. an unchanged RTVar if no optimization was possible +static std::variant OptimizeRTVar( + RTVar rt_var, MLIRContext* mlir_context, IndexingMap::IndexingMapProvider indexing_map_provider) { - if (auto constant_expr = DynCast(instr)) { - if (affine_map.isConstant()) { - const auto idx = affine_map.getConstantResults(); - return getAffineConstantExpr( - constant_expr->literal().GetIntegralAsS64(idx).value(), mlir_context); + while (true) { + if (auto constant_expr = DynCast(rt_var.hlo)) { + if (rt_var.map.isConstant()) { + const auto idx = rt_var.map.getConstantResults(); + return getAffineConstantExpr( + constant_expr->literal().GetIntegralAsS64(idx).value(), + mlir_context); + } + return rt_var; } - return std::nullopt; - } - if (auto iota_expr = DynCast(instr)) { - auto iota_dimension = iota_expr->iota_dimension(); - CHECK(iota_dimension < affine_map.getNumResults()); - return affine_map.getResults()[iota_dimension]; - } + if (auto iota_expr = DynCast(rt_var.hlo)) { + auto iota_dimension = iota_expr->iota_dimension(); + CHECK(iota_dimension < rt_var.map.getNumResults()); + return rt_var.map.getResults()[iota_dimension]; + } - return std::nullopt; + auto is_indexing_transformation = [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kBitcast || + instr->opcode() == HloOpcode::kBroadcast || + instr->opcode() == HloOpcode::kReshape || + instr->opcode() == HloOpcode::kReverse || + instr->opcode() == HloOpcode::kSlice || + instr->opcode() == HloOpcode::kTranspose; + }; + + if (is_indexing_transformation(rt_var.hlo)) { + auto instr_indexing_map = + indexing_map_provider(rt_var.hlo, 0, mlir_context); + + rt_var.hlo = rt_var.hlo->operand(0); + rt_var.map = instr_indexing_map.GetAffineMap().compose(rt_var.map); + continue; + } + + return rt_var; + } } bool IndexingMap::ReplaceConstantRTVars( @@ -1365,22 +1391,34 @@ bool IndexingMap::ReplaceConstantRTVars( std::vector to_delete; - for (const auto& [index, rt_var] : llvm::enumerate(rt_vars_)) { - auto folded_expr = FoldsIntoConstantIndexingExpression( - rt_var.hlo, rt_var.map, GetMLIRContext(), indexing_map_provider); - if (!folded_expr.has_value()) continue; + for (auto index = 0; index < rt_vars_.size(); ++index) { + auto& rt_var = rt_vars_[index]; + auto result = + OptimizeRTVar(rt_var, GetMLIRContext(), indexing_map_provider); + + // If we got an RTVar back, then we just replace it and move on. + if (std::holds_alternative(result)) { + rt_var = std::get(std::move(result)); + continue; + } + + // But if we received an AffineExpr we can eliminate the RTVar from + // all expressions in the indexing map. + auto folded_expr = std::get(std::move(result)); + // range_vars and rt_vars share the symbol space, with the rt_vars coming + // after the range_vars. auto symbol_index = range_vars_.size() + index; affine_map_ = affine_map_.replace( {{mlir::getAffineSymbolExpr(symbol_index, GetMLIRContext()), - folded_expr.value()}}); + folded_expr}}); llvm::DenseMap replacements; for (const auto& [constraint, interval] : constraints_) { auto modified_constraint = constraint.replace( mlir::getAffineSymbolExpr(symbol_index, GetMLIRContext()), - folded_expr.value()); + folded_expr); if (constraint == modified_constraint) continue; replacements[constraint] = modified_constraint; diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index 084b535708d332..2b250db7871099 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -826,6 +826,94 @@ TEST_F(IndexingMapTest, ReplaceConstantRTVars_ConstraintsGetUpdated) { )")); } +TEST_F(IndexingMapTest, ReplaceConstantRTVars_Broadcast) { + auto iota = HloInstruction::CreateIota( + ShapeUtil::MakeShape(PrimitiveType::S64, {12}), 0); + auto transpose = HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(PrimitiveType::S64, {32, 12}), iota.get(), {1}); + + // (d0, 11): d0 maps into the broadcasted dimension, so it doesn't matter + // and 11 maps to 11 in iota. + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), + /*dimensions=*/{{0, 31}}, + /*range_vars=*/{}, + {RTVar{Interval{0, 11}, transpose.get(), + ParseAffineMap("(d0) -> (d0, 11)", &mlir_context_)}}); + + indexing_map.Simplify(GetIndexingMapForInstruction); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0) -> (d0, 11) + domain: + d0 in [0, 31] + )")); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_ChainedNoncomputeOps) { + auto iota = HloInstruction::CreateIota( + ShapeUtil::MakeShape(PrimitiveType::S64, {12}), 0); + auto reverse = HloInstruction::CreateReverse( + ShapeUtil::MakeShape(PrimitiveType::S64, {12}), iota.get(), {0}); + auto reshape = HloInstruction::CreateReshape( + ShapeUtil::MakeShape(PrimitiveType::S64, {3, 4}), reverse.get()); + auto broadcast = HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(PrimitiveType::S64, {36, 3, 4}), reshape.get(), + {1, 2}); + + // - Iota: [0, 1, ,,,, 11] + // - Reverse: [11, 10, ..., 0] + // - Reshape: [[11, 10, 9, 8], [7, 6, 5, 4], [3, 2, 1, 0]] + // - Coordinates: (d0 floordiv 12, 3) + // - y-coordinate=3 means we index into [8, 4, 0] + // - x-coordinate=(d0 floordiv 12) means our constant looks like this: + // [8, ..., 8, 4, ..., 4, 0, ..., 0] + // - Hence our final expression: (d0 floordiv 12) * -4 + 8 + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), + /*dimensions=*/{{0, 35}}, + /*range_vars=*/{}, + {RTVar{ + Interval{0, 11}, broadcast.get(), + ParseAffineMap("(d0) -> (d0, d0 floordiv 12, 3)", &mlir_context_)}}); + + indexing_map.Simplify(GetIndexingMapForInstruction); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0) -> (d0, (d0 floordiv 12) * -4 + 8) + domain: + d0 in [0, 35] + )")); +} + +TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartialRTVarRemoval) { + auto iota = HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 7, 25, 1, 7, 25, 1, 7, 25, 1, 7, 25})); + auto broadcast = HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(PrimitiveType::S64, {24, 12}), iota.get(), {1}); + + // (d0, d0 floordiv 2): d0 maps into the broadcasted dimension, so it can't be + // removed, but d0 floordiv 2 doesn't yield an affine expression so we need to + // keep the RTVar, but can optimize it by removing the broadcast. + IndexingMap indexing_map( + ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), + /*dimensions=*/{{0, 23}}, + /*range_vars=*/{}, + {RTVar{Interval{0, 512}, broadcast.get(), + ParseAffineMap("(d0) -> (d0, d0 floordiv 2)", &mlir_context_)}}); + + indexing_map.Simplify(GetIndexingMapForInstruction); + + EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + (d0)[s0] -> (d0, s0) + domain: + d0 in [0, 23] + s0 in [0, 512] + hlo: %constant = s64[12]{0} constant({...}) + (d0) -> (d0 floordiv 2) + )")); +} + } // namespace } // namespace gpu } // namespace xla From 5ceb83858e590102a69b10ce9682ab86bf427c63 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 28 Mar 2024 05:55:20 -0700 Subject: [PATCH 018/124] [XLA:GPU] Deprecate Triton codegen before Ampere. Unfortunately, upstream Triton has decided to drop support for NVIDIA GPUs below Ampere, so we bump the GPU version requirements for using Triton. PiperOrigin-RevId: 619899728 --- third_party/xla/xla/service/gpu/BUILD | 26 ++- .../xla/xla/service/gpu/gemm_fusion.cc | 14 +- .../service/gpu/gemm_fusion_autotuner_test.cc | 42 ---- .../xla/xla/service/gpu/gemm_fusion_test.cc | 19 +- .../xla/xla/service/gpu/gpu_compiler.cc | 4 +- .../xla/xla/service/gpu/ir_emitter_triton.cc | 32 +-- .../ir_emitter_triton_parametrized_test.cc | 172 +-------------- .../xla/service/gpu/ir_emitter_triton_test.cc | 205 +++++------------- .../xla/service/gpu/nvptx_compiler_test.cc | 21 +- .../service/gpu/softmax_rewriter_triton.cc | 7 + .../gpu/softmax_rewriter_triton_test.cc | 75 +++---- third_party/xla/xla/service/gpu/tests/BUILD | 12 +- .../gpu/tests/gpu_triton_custom_call_test.cc | 97 ++++++++- .../xla/xla/service/gpu/triton_support.cc | 16 +- .../service/gpu/triton_tiling_propagation.cc | 5 - 15 files changed, 285 insertions(+), 462 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 3ae6b8646d3822..2e2d4b943ad100 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -599,7 +599,6 @@ xla_test( srcs = if_cuda_is_configured(["ir_emitter_triton_test.cc"]), backends = [ "gpu_a100", - "gpu_v100", "gpu_h100", ], shard_count = 20, @@ -636,7 +635,6 @@ xla_test( "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Transforms", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", @@ -652,7 +650,7 @@ xla_test( name = "ir_emitter_triton_large_test", srcs = if_cuda_is_configured(["ir_emitter_triton_large_test.cc"]), backend_tags = {"gpu": [ - "requires-gpu-sm70", + "requires-gpu-sm80", ]}, backends = [ "gpu", @@ -680,7 +678,6 @@ xla_test( srcs = if_cuda_is_configured(["ir_emitter_triton_parametrized_test.cc"]), backends = [ "gpu_a100", - "gpu_v100", ], shard_count = 10, tags = ["nomac"], @@ -768,7 +765,7 @@ xla_test( name = "gemm_fusion_autotuner_test", srcs = if_cuda_is_configured(["gemm_fusion_autotuner_test.cc"]), backend_tags = {"gpu": [ - "requires-gpu-sm70", + "requires-gpu-sm80", ]}, backends = [ "gpu", @@ -1618,8 +1615,10 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", ], ) @@ -3965,26 +3964,27 @@ cc_library( ]), ) -xla_cc_test( +xla_test( name = "nvptx_compiler_test", srcs = if_gpu_is_configured([ "nvptx_compiler_test.cc", ]), + backends = [ + "gpu_v100", + "gpu_a100", + ], tags = [ "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. - "requires-gpu-sm70", ], deps = [ - ":gpu_compiler", ":nvptx_compiler_impl", - "//xla:statusor", "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/service:backend", "//xla/service:buffer_assignment", - "//xla/service:gpu_plugin", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/status:statusor", @@ -5917,13 +5917,15 @@ xla_cc_test( ], ) -xla_cc_test( +xla_test( name = "determinism_test", srcs = ["determinism_test.cc"], + backends = [ + "gpu_a100", + ], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - tags = tf_gpu_tests_tags(), deps = [ ":autotuner_util", "//xla:literal", diff --git a/third_party/xla/xla/service/gpu/gemm_fusion.cc b/third_party/xla/xla/service/gpu/gemm_fusion.cc index 2bceba577ff8d5..7518fa51269533 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion.cc @@ -794,7 +794,7 @@ bool IsSupportedByTriton(PrecisionConfig::Algorithm algorithm, switch (algorithm) { case PrecisionConfig::ALG_DOT_TF32_TF32_F32: if (cuda_compute_capability) { - return cuda_compute_capability->IsAtLeastAmpere(); + return true; } return false; case PrecisionConfig::ALG_DOT_BF16_BF16_F32: @@ -802,7 +802,7 @@ bool IsSupportedByTriton(PrecisionConfig::Algorithm algorithm, case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: if (cuda_compute_capability) { - return cuda_compute_capability->IsAtLeastAmpere(); + return true; } if (rocm_compute_capability) { return rocm_compute_capability->has_bf16_dtype_support(); @@ -852,8 +852,7 @@ FusionDecision CanTritonHandleGEMM( return true; case BF16: if (cuda_compute_capability) { - return cuda_compute_capability->IsAtLeast( - stream_executor::CudaComputeCapability::AMPERE); + return true; } if (rocm_compute_capability) { return rocm_compute_capability->has_bf16_dtype_support(); @@ -908,6 +907,13 @@ bool ShouldTritonHandleGEMM(HloDotInstruction& dot, absl::StatusOr GemmFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { + auto cuda_compute_capability = + std::get_if(&gpu_version_); + if (!cuda_compute_capability || !cuda_compute_capability->IsAtLeastAmpere()) { + return absl::FailedPreconditionError( + "Triton support is only enabled for Ampere GPUs and up."); + } + bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc index d4557f4737f70d..f8def9455b2c1f 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc @@ -230,28 +230,6 @@ class GemmFusionAutotunerTestWithMorePreciseReduction } }; -TEST_F(GemmFusionAutotunerTest, VoltaUsesNoMoreThanTwoStages) { - std::unique_ptr module = ParseAndReturnVerifiedModule(R"( -ENTRY e { - p0 = f32[1024,1024] parameter(0) - p1 = f32[1024,1024] parameter(1) - ROOT r = f32[1024,1024] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})") - .value(); - const se::CudaComputeCapability compute_capability{ - se::CudaComputeCapability::VOLTA, /*minor=*/0}; - TF_ASSERT_OK_AND_ASSIGN( - const std::vector configs, - GetPossibleMatmulAutotuneConfigs( - *Cast( - module->entry_computation()->root_instruction()), - compute_capability, GetDebugOptionsForTest())); - EXPECT_FALSE(std::any_of( - configs.begin(), configs.end(), - [](const TritonGemmConfig& config) { return config.num_stages > 2; })); -} - TEST_F(GemmFusionAutotunerTest, AmpereUsesMoreThanTwoStages) { std::unique_ptr module = ParseAndReturnVerifiedModule(R"( ENTRY e { @@ -366,10 +344,6 @@ ENTRY e { } TEST_F(GemmFusionAutotunerTest, SelectsSplitK) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } // Shapes with K >> M, N have to force split-K configurations. const std::string kHloText = R"( HloModule t @@ -395,10 +369,6 @@ ENTRY e { } TEST_F(GemmFusionAutotunerTestWithMorePreciseReduction, SelectsSplitK) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } // Shapes with K >> M, N have to force split-K configurations. constexpr absl::string_view kHloText = R"( HloModule t @@ -468,10 +438,6 @@ ENTRY %e { backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}} })"; - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Not enough shared memory to run big tiles before Ampere."; - } auto module = ParseAndReturnVerifiedModule(kHloText).value(); EXPECT_THAT( backend().compiler()->RunBackend(std::move(module), @@ -507,10 +473,6 @@ ENTRY %e { backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}} })"; - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Not enough shared memory to run big tiles before Ampere."; - } auto module = ParseAndReturnVerifiedModule(kHloText).value(); HloModuleConfig config = module->config(); DebugOptions debug_options = config.debug_options(); @@ -594,10 +556,6 @@ ENTRY e { } TEST_F(GemmFusionAutotunerTest, AutotuneCuDnnFusion) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "cuDNN fusion autotuning is not tested before Ampere."; - } const std::string kHlo = R"( fusion1 { p0 = f32[3,28,32] parameter(0) diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/gemm_fusion_test.cc index a4db45eb78d08e..e5986c9968b5ea 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -36,6 +37,7 @@ limitations under the License. #include "xla/tests/verified_hlo_module.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" namespace xla { @@ -194,7 +196,7 @@ ENTRY e { lhs_contracting_dims={0}, rhs_contracting_dims={0} })")); - const se::CudaComputeCapability cc{se::CudaComputeCapability::VOLTA, 0}; + const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; EXPECT_TRUE(CublasRequiresPadding( *xla::Cast( module->entry_computation()->root_instruction()), @@ -215,7 +217,7 @@ ENTRY e { ROOT t = tuple(d, s1) })")); - const se::CudaComputeCapability cc{se::CudaComputeCapability::VOLTA, 0}; + const se::CudaComputeCapability cc{se::CudaComputeCapability::AMPERE, 0}; EXPECT_TRUE(GemmFusion(cc).Run(module.get()).value()); } @@ -759,7 +761,7 @@ e { m::Parameter(), m::Parameter())))); } -TEST_F(GemmFusionLevel2Test, FusionLevelIsLimitedOnVolta) { +TEST_F(GemmFusionLevel2Test, GemmFusionBailsOutPreAmpere) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( ENTRY e { @@ -770,12 +772,13 @@ ENTRY e { ROOT dot = f32[2,2] dot(p0e, p1c), lhs_contracting_dims={1}, rhs_contracting_dims={0} })")); - EXPECT_TRUE( + EXPECT_THAT( GemmFusion(se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}) - .Run(module.get()) - .value()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch((m::Fusion(m::Exp(), m::Parameter())))); + .Run(module.get()), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::StrEq( + "Triton support is only enabled for Ampere GPUs and up."))); } TEST_F(GemmFusionLevel2Test, ParameterUsedElementwiseTwiceIsFused) { diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 8ec5319dd8b10b..c6f0490777370a 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1402,7 +1402,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // and may rewrite quantized FP8 GEMMs as higher-precision GEMMs. pipeline.AddPass(gpu_version, /*f8_rewrite=*/true); if (debug_options.xla_gpu_enable_triton_gemm() && cuda_cc != nullptr && - cuda_cc->IsAtLeast(se::CudaComputeCapability::VOLTA)) { + cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) { pipeline.AddPass(gpu_version); } // Rewrite non-FP8 GEMMs. @@ -1424,7 +1424,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // harder. if (debug_options.xla_gpu_enable_triton_softmax_fusion() && cuda_cc != nullptr && - cuda_cc->IsAtLeast(se::CudaComputeCapability::VOLTA)) { + cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) { pipeline.AddPass>(simplifier_options); pipeline.AddPass(gpu_version); } diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index 98bf1ac6519286..184880632da09e 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -351,8 +351,7 @@ Value Compare(ImplicitLocOpBuilder& b, ValueRange values, Value Maximum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, ValueRange values) { - if (mlir::getElementTypeOrSelf(values[0]).isa() && - device_info.cuda_compute_capability().IsAtLeastAmpere()) { + if (mlir::getElementTypeOrSelf(values[0]).isa()) { return b.create(values); } // logic: isNaN(lhs) || (!isNan(rhs) && lhs >= rhs) ? lhs : rhs @@ -373,8 +372,7 @@ Value Maximum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, Value Minimum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, ValueRange values) { - if (mlir::getElementTypeOrSelf(values[0]).isa() && - device_info.cuda_compute_capability().IsAtLeastAmpere()) { + if (mlir::getElementTypeOrSelf(values[0]).isa()) { return b.create(values); } // logic: isNaN(lhs) || (!isNan(rhs) && lhs <= rhs) ? lhs : rhs @@ -969,10 +967,9 @@ absl::Status CreateTritonPipeline( pm.addPass(mt::gpu::createOptimizeDotOperandsPass()); pm.addPass(mlir::createCSEPass()); - if (cc.IsAtLeastAmpere()) { - pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps, - config.num_ctas, ccAsInt)); - } + pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps, + config.num_ctas, ccAsInt)); + if (!cc.IsAtLeastHopper()) { pm.addPass(mt::gpu::createPrefetchPass()); } @@ -1907,9 +1904,7 @@ bool Is6xBfloat16MatMul(const HloDotInstruction* dot_instr, if (algorithm == PrecisionConfig::ALG_UNSET) { const HloModule* hlo_module = dot_instr->GetModule(); Type f32 = builder.getF32Type(); - // BF16 datatype is not supported before Ampere. - return device_info.cuda_compute_capability().IsAtLeastAmpere() && - hlo_module->config() + return hlo_module->config() .debug_options() .xla_gpu_enable_bf16_6way_gemm() && dot_input_lhs.getType().cast().getElementType() == f32 && @@ -1929,9 +1924,7 @@ bool Is3xBfloat16MatMul(const HloDotInstruction* dot_instr, if (algorithm == PrecisionConfig::ALG_UNSET) { const HloModule* hlo_module = dot_instr->GetModule(); Type f32 = builder.getF32Type(); - // BF16 datatype is not supported before Ampere. - return device_info.cuda_compute_capability().IsAtLeastAmpere() && - hlo_module->config() + return hlo_module->config() .debug_options() .xla_gpu_enable_bf16_3way_gemm() && dot_input_lhs.getType().cast().getElementType() == f32 && @@ -2186,7 +2179,6 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, Value accumulator_next; if (Is6xBfloat16MatMul(dot_instr, b, dot_input_lhs, dot_input_rhs, device_info)) { - CHECK(device_info.cuda_compute_capability().IsAtLeastAmpere()); absl::StatusOr accumulator_next_or = Emit6xBfloat16MatMul( b, dot_input_lhs, dot_input_rhs, iter_args.back()); TF_CHECK_OK(accumulator_next_or.status()); @@ -2753,6 +2745,11 @@ absl::StatusOr TritonWrapper( const se::DeviceDescription& device_info, const TritonGemmConfig& config, llvm::Module* llvm_module, TritonIrEmitter ir_emitter, mlir::MLIRContext& mlir_context) { + if (!cc.IsAtLeastAmpere()) { + return absl::FailedPreconditionError( + "Triton support is only enabled for Ampere GPUs and up."); + } + auto debug_options = GetDebugOptionsFromFlags(); if (debug_options.xla_gpu_enable_triton_hopper()) { // Set environment variables for consumption by Triton. @@ -2782,6 +2779,11 @@ absl::StatusOr CompileTritonToLLVM( const se::DeviceDescription& device_info, const TritonGemmConfig& config, mlir::ModuleOp triton_module, llvm::Module* llvm_module, mlir::MLIRContext& mlir_context) { + if (!cc.IsAtLeastAmpere()) { + return absl::FailedPreconditionError( + "Triton support is only enabled for Ampere GPUs and up."); + } + bool should_verify = (hlo_config.debug_options().xla_gpu_llvm_verification_level() >= 1); #ifndef NDEBUG diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc index 8e01b864b893ee..3be360ae554379 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_parametrized_test.cc @@ -71,11 +71,6 @@ class MixedTypeTest : public GpuCodegenTest, TEST_P(MixedTypeTest, MixedTypeDotProducesCorrectResult) { MixTypeParams params = GetParam(); - if ((params.lhs_ty == BF16 || params.rhs_ty == BF16) && - !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string hlo_string_template = R"( HloModule m @@ -799,10 +794,6 @@ TEST_P(TritonSoftmaxTest, CanFuseAndEmitExactSoftmax) { if (data_type == F16) { GTEST_SKIP() << "Exponential op does not support F16."; - } else if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; } const std::string hlo_text_template = R"( @@ -887,14 +878,7 @@ ENTRY main { const std::string hlo_text = absl::Substitute( hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); - std::string hlo_ref_template; - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - hlo_ref_template = R"( -; CHECK-NOT: triton -)"; - } else { - hlo_ref_template = R"( + std::string hlo_ref_template = R"( ; CHECK: ENTRY ; CHECK: %[[P0:.*]] = $0[127,125]{1,0} parameter(0) ; CHECK: ROOT @@ -902,7 +886,6 @@ ENTRY main { ; CHECK-SAME: kind=kCustom ; CHECK-SAME: __triton_softmax )"; - } const std::string hlo_ref = absl::Substitute( hlo_ref_template, primitive_util::LowercasePrimitiveTypeName(data_type)); @@ -927,12 +910,6 @@ ENTRY main { TEST_P(TritonSoftmaxTest, CanFuseAndEmitSoftmaxDiamondWithSmallRows) { PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - constexpr absl::string_view kHloTextTemplate = R"( HloModule softmax min_computation { @@ -969,12 +946,6 @@ ENTRY main { } TEST_F(TritonSoftmaxTest, CanFuseAndEmitDiamondWithBF16Converts) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text = R"( HloModule softmax max_computation { @@ -1016,10 +987,6 @@ TEST_P( if (data_type == F16) { GTEST_SKIP() << "Exponential op does not support F16."; - } else if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; } const std::string hlo_text_template = R"( @@ -1089,12 +1056,6 @@ TEST_P(TritonSoftmaxTest, CanFuseAndEmitDiamondWithMultipleBroadcastDimensions) { PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule softmax max_computation { @@ -1151,12 +1112,7 @@ TEST_P(TritonSoftmaxTest, if (data_type == F16) { GTEST_SKIP() << "Exponential op does not support F16."; - } else if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; } - const std::string hlo_text_template = R"( HloModule softmax max_computation { @@ -1220,12 +1176,6 @@ TEST_P( CanFuseAndEmitTwoDiamondsWithSecondDiamondProducerEqualToFirstDiamondRoot) { PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule softmax max_computation { @@ -1289,12 +1239,6 @@ TEST_P(TritonSoftmaxTest, CanFuseAndEmitDiamondWithTrailingUnaryElementwiseAtTheRoot) { PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule softmax max_computation { @@ -1349,12 +1293,6 @@ ENTRY main { TEST_P(TritonSoftmaxTest, CanFuseAndEmitDiamondWithUnaryElementwisePrefix) { PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule softmax max_computation { @@ -1410,12 +1348,6 @@ TEST_P(TritonSoftmaxTest, CanFuseAndEmitSoftmaxDiamondWithLastDimensionBitcastAfterReduce) { PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule softmax max_computation { @@ -1469,17 +1401,10 @@ ENTRY main { ErrorSpec(/*aabs=*/tolerance, /*arel=*/tolerance))); } -TEST_P( - TritonSoftmaxTest, - CanFuseAndEmitConvertInvolvingBF16InputIntoSoftmaxDiamondCorrectlyForAmpereAndVoltaComputeCapability) { // NOLINT(whitespace/line_length) +TEST_P(TritonSoftmaxTest, + CanFuseAndEmitConvertInvolvingBF16InputIntoSoftmaxDiamondCorrectly) { PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule softmax max_computation { @@ -1499,8 +1424,7 @@ ENTRY main { const std::string hlo_text = absl::Substitute( hlo_text_template, primitive_util::LowercasePrimitiveTypeName(data_type)); - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) { - const std::string hlo_ref = R"( + const std::string hlo_ref = R"( ; CHECK: ENTRY ; CHECK: %[[P0:.*]] = bf16[127,125]{1,0} parameter(0) ; CHECK: ROOT @@ -1509,23 +1433,7 @@ ENTRY main { ; CHECK-SAME: __triton_softmax )"; - MatchOptimizedHlo(hlo_text, hlo_ref); - } else { - const std::string hlo_ref_template = R"( -; CHECK: ENTRY -; CHECK: %[[P0:.*]] = bf16[127,125]{1,0} parameter(0) -; CHECK: %[[CONVERT:.*]] = $0[127,125]{1,0} convert(%[[P0]]) -; CHECK: ROOT -; CHECK-SAME: fusion(%[[CONVERT]]) -; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton_softmax -)"; - - const std::string hlo_ref = - absl::Substitute(hlo_ref_template, - primitive_util::LowercasePrimitiveTypeName(data_type)); - MatchOptimizedHlo(hlo_text, hlo_ref); - } + MatchOptimizedHlo(hlo_text, hlo_ref); float tolerance; switch (data_type) { @@ -1550,12 +1458,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseProducerIntoDiamondWhenBothOperandsAreTheSame) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamond max_computation { @@ -1612,12 +1514,6 @@ TEST_P( CanFuseAndEmitIntermediateBinaryElementwiseWithinDiamondWhenBothOperandsAreTheSame) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamond max_computation { @@ -1674,12 +1570,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseWhenBothOperandsAreTheSameBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamonds max_computation { @@ -1745,12 +1635,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseConsumerWhereBothOperandsAreTheSameIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamond max_computation { @@ -1813,12 +1697,6 @@ TEST_P( CanFuseAndEmitTwoBinaryElementwiseWhereBothOperandsAreTheSameBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamonds max_computation { @@ -1919,10 +1797,6 @@ TEST_P(TritonSoftmaxTest, CanFuseAndEmitRMSNormDiamond) { if (data_type == F16) { GTEST_SKIP() << "rsqrt op does not support F16."; - } else if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; } const std::string hlo_text_template = R"( @@ -1989,12 +1863,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamonds add_computation { @@ -2058,12 +1926,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseWhereTheSecondOperandIsASplatConstantBetweenDiamonds) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamonds add_computation { @@ -2127,12 +1989,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseWhereTheFirstOperandIsASplatConstantWithinDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamond max_computation { @@ -2192,12 +2048,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseConsumerWhereTheFirstOperandIsASplatConstantIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamond add_computation { @@ -2256,12 +2106,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseProducerWhereTheFirstOperandIsASplatConstantIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule fusible_diamond add_computation { @@ -2321,12 +2165,6 @@ TEST_P( CanFuseAndEmitBinaryElementwiseOperationWhereOneOperandIsASharedSplatProducerIntoDiamond) { // NOLINT(whitespace/line_length) PrimitiveType data_type = GetParam(); - if (data_type == BF16 && !GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << R"(No BF16 before Ampere. Pre-Ampere BF16 behavior is tested - in CanFuseAndEmitFirstSoftmaxDiamond, and in SoftmaxRewriterTritonTest.)"; - } - const std::string hlo_text_template = R"( HloModule nonfusible_diamond max_computation { diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index 78e277f224b316..9a2b5a4e8b1d25 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -41,7 +41,6 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" @@ -1289,17 +1288,10 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, ParseAndReturnVerifiedModule(kHloText)); - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) { - CompileAndOptionallyVerifyPtx(std::move(verified_module), - R"( + CompileAndOptionallyVerifyPtx(std::move(verified_module), + R"( CHECK: mma )"); - } else { - CompileAndOptionallyVerifyPtx(std::move(verified_module), - R"( -CHECK: fma -)"); - } } TEST_F(TritonGemmTest, FailIfTooMuchShmem) { @@ -2021,12 +2013,6 @@ ENTRY e { class TritonGemmLevel2Test : public TritonGemmTest { public: - void SetUp() override { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Triton fusion on pre-Ampere GPUs is limited."; - } - } DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_triton_fusion_level(2); @@ -3061,11 +3047,6 @@ ENTRY e { } TEST_F(CompareTest, BF16TransposedLHS) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } - const char* hlo_text_ref = R"( HloModule r @@ -3106,11 +3087,6 @@ ENTRY e { } TEST_F(CompareTest, UsingOptinSharedMemoryOnAmpereProducesSameResult) { - // On pre-Ampere GPUs the test would use a different amount of shared memory. - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "This test is for Ampere+ GPUs."; - } const se::DeviceDescription dev_info = backend().default_stream_executor()->GetDeviceDescription(); constexpr int kBytesOfSharedMemoryTested = 64 * 1024; @@ -3276,10 +3252,6 @@ ENTRY e { } TEST_F(CompareTest, S8BF16) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const char* hlo_text_ref = R"( HloModule r @@ -3327,10 +3299,6 @@ ENTRY e { } TEST_F(CompareTest, SplitK) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string hlo_text_ref = R"( HloModule t, is_scheduled=true @@ -3404,10 +3372,6 @@ ENTRY e { } TEST_F(CompareTest, SplitKBatch) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string kHloTextRef = R"( HloModule m, is_scheduled=true @@ -3470,10 +3434,6 @@ ENTRY e { } TEST_F(CompareTest, SplitKNontrivialBitcast) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string kHloTextRef = R"( HloModule module, is_scheduled=true @@ -4048,10 +4008,6 @@ ENTRY e { } TEST_F(CompareTest, PredToBF16ConversionWorks) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string kHloTextTest = R"( HloModule m, is_scheduled=true @@ -4172,10 +4128,6 @@ class TritonGemmContractionDims : public TritonGemmTest { }; TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_0) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string kHloText = R"( HloModule m @@ -4198,10 +4150,6 @@ ENTRY e { } TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_1_2) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string kHloText = R"( HloModule m @@ -4224,10 +4172,6 @@ ENTRY e { } TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_0_1) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string kHloText = R"( HloModule m @@ -4251,10 +4195,6 @@ ENTRY e { } TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_1) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const std::string kHloText = R"( HloModule m @@ -4318,9 +4258,6 @@ class Triton6xBF16GemmTestWithFlag : public TritonFilecheckTest { }; TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmWhenBothInputsAreF32) { - if (!GetCudaComputeCapability().IsAtLeastAmpere()) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const char* kHloText = R"( HloModule t @@ -4364,9 +4301,6 @@ CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf } TEST_F(Triton6xBF16GemmTestWithFlag, Emit6xBF16GemmWhenBothInputsAreF32) { - if (!GetCudaComputeCapability().IsAtLeastAmpere()) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const char* kHloText = R"( HloModule t @@ -4409,9 +4343,6 @@ CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf } TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmWorksForLongContractingDimension) { - if (!GetCudaComputeCapability().IsAtLeastAmpere()) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const char* kHloText = R"( HloModule t @@ -4442,9 +4373,6 @@ CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, m } TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmCanHandleInfinity) { - if (!GetCudaComputeCapability().IsAtLeastAmpere()) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const char* kHloText = R"( HloModule t @@ -4489,9 +4417,6 @@ CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, m } TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmCanHandleNaN) { - if (!GetCudaComputeCapability().IsAtLeastAmpere()) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const char* kHloText = R"( HloModule t @@ -4548,9 +4473,6 @@ CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, m // x_lo: 5.17201445e+33 // The result of x*x would be NaN instead of positive infinity. TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmWorksForInputsWithLargeExponent) { - if (!GetCudaComputeCapability().IsAtLeastAmpere()) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const char* kHloText = R"( HloModule t @@ -4594,42 +4516,6 @@ CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, m ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); } -TEST_F(Triton6xBF16GemmTestWithFlag, ShouldNotEmit6xBF16GemmForPreAmpere) { - if (GetCudaComputeCapability().IsAtLeastAmpere()) { - GTEST_SKIP() << "6xBF16Gemm should be emitted post-Ampere."; - } - const char* kHloText = R"( -HloModule t - -triton_dot { - p0 = f32[5,7] parameter(0) - p1 = f32[7,33] parameter(1) - ROOT dot = f32[5,33] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = f32[5,7]{1,0} parameter(0) - p1 = f32[7,33]{1,0} parameter(1) - ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, - ParseAndReturnVerifiedModule(kHloText)); - - CompileAndOptionallyVerifyPtx(std::move(verified_module), - R"( -CHECK-NOT: mma -CHECK: selp.f32 -CHECK: st.shared{{(\.v[24])?}}.f32 -CHECK: ld.shared{{(\.v[24])?}}.f32 -CHECK: fma.rn.f32 -CHECK: st.shared{{(\.v[24])?}}.f32 -)"); -} TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmEndToEnd) { const char* kHloText = R"( @@ -4645,18 +4531,13 @@ ENTRY e { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, ParseAndReturnVerifiedModule(kHloText)); - if (GetCudaComputeCapability().IsAtLeastAmpere()) { - CompileAndOptionallyVerifyPtx(std::move(verified_module), - R"( + CompileAndOptionallyVerifyPtx(std::move(verified_module), + R"( CHECK: mma.sync.aligned.{{.*}}.row.col.f32.bf16.bf16.f32 CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32 )"); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, - /*arel=*/1e-6})); - } else { - EXPECT_THAT(CompileToExecutable(std::move(verified_module)), - tsl::testing::StatusIs(absl::StatusCode::kUnimplemented)); - } + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, + /*arel=*/1e-6})); } // In these tests, we depend on "algorithm" annotations for selecting the 3XBF16 @@ -4702,9 +4583,6 @@ class Triton3xBF16GemmTestWithFlag : public TritonFilecheckTest { }; TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmWhenBothInputsAreF32) { - if (!GetCudaComputeCapability().IsAtLeastAmpere()) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const char* kHloText = R"( HloModule t @@ -4748,9 +4626,6 @@ CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf } TEST_F(Triton3xBF16GemmTestWithFlag, Emit3xBF16GemmWhenBothInputsAreF32) { - if (!GetCudaComputeCapability().IsAtLeastAmpere()) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const char* kHloText = R"( HloModule t @@ -4822,9 +4697,6 @@ CHECK-NOT: tt.dot } TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmWorksForLongContractingDimension) { - if (!GetCudaComputeCapability().IsAtLeastAmpere()) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const char* kHloText = R"( HloModule t @@ -4855,9 +4727,6 @@ CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, m } TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmCanHandleInfinity) { - if (!GetCudaComputeCapability().IsAtLeastAmpere()) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const char* kHloText = R"( HloModule t @@ -4902,9 +4771,6 @@ CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, m } TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmCanHandleNaN) { - if (!GetCudaComputeCapability().IsAtLeastAmpere()) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const char* kHloText = R"( HloModule t @@ -4951,9 +4817,6 @@ CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} {allowTF32 = false, m } TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmWorksForInputsWithLargeExponent) { - if (!GetCudaComputeCapability().IsAtLeastAmpere()) { - GTEST_SKIP() << "No BF16 before Ampere."; - } const char* kHloText = R"( HloModule t @@ -5011,19 +4874,57 @@ ENTRY e { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, ParseAndReturnVerifiedModule(kHloText)); - if (GetCudaComputeCapability().IsAtLeastAmpere()) { - CompileAndOptionallyVerifyPtx(std::move(verified_module), - R"( + CompileAndOptionallyVerifyPtx(std::move(verified_module), + R"( CHECK: mma.sync.aligned.{{.*}}.row.col.f32.bf16.bf16.f32 CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32 )"); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-5, - /*arel=*/1e-5})); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-5, + /*arel=*/1e-5})); +} - } else { - EXPECT_THAT(CompileToExecutable(std::move(verified_module)), - tsl::testing::StatusIs(absl::StatusCode::kUnimplemented)); - } +using TritonEmitterTest = TritonGemmTest; + +TEST_F(TritonEmitterTest, EmitterFailsIfComputeCapabilityIsBelowAmpere) { + const std::string kHloText = R"( +HloModule module, is_scheduled=true + +triton_gemm_dot { + p0 = f32[10,20] parameter(0) + p1 = f32[20,30] parameter(1) + ROOT dot = f32[10,30] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY entry { + p0 = f32[10,20] parameter(0) + p1 = f32[20,30] parameter(1) + ROOT r = f32[10,30] fusion(p0, p1), + kind=kCustom, calls=triton_gemm_dot +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloText)); + const HloComputation* triton_dot_computation = + hlo_module->entry_computation() + ->root_instruction() + ->fused_instructions_computation(); + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(); + llvm::LLVMContext llvm_ctx; + llvm::Module llvm_module("module", llvm_ctx); + mlir::MLIRContext mlir_context; + + EXPECT_THAT( + TritonWrapper(*TritonFusionAnalysis::Execute(*triton_dot_computation), + "test_fn", triton_dot_computation, + se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, + /*minor=*/0}, + dev_info, TritonGemmConfig{}, &llvm_module, &EmitMatMul, + mlir_context), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::StrEq( + "Triton support is only enabled for Ampere GPUs and up."))); } } // namespace diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc b/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc index e6ab31b84033cf..f247582d38a5a3 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_query.h" #include "xla/service/backend.h" #include "xla/service/buffer_assignment.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "xla/xla.pb.h" @@ -132,7 +133,7 @@ ENTRY entry { TEST_F(NVPTXCompilerTestTriton, DotDimensionAreSortedBeforePaddingForCublasEnablingTritonFusion) { - MatchOptimizedHlo(R"( + const absl::string_view hlo_string = R"( ENTRY e { p0 = f16[11,22,33,44] parameter(0) p1 = s8[11,22,33,44] parameter(1) @@ -140,13 +141,25 @@ ENTRY e { ROOT d = f16[11,22,44,44] dot(p0, p1c), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} -})", - R"( +})"; + + se::CudaComputeCapability cc = backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + + if (cc.IsAtLeastAmpere()) { + MatchOptimizedHlo(hlo_string, R"( ; CHECK: ENTRY ; CHECK-NEXT: parameter ; CHECK-NEXT: parameter ; CHECK-NEXT: __triton_gemm - )"); + )"); + } else { + MatchOptimizedHlo(hlo_string, R"( +; CHECK-NOT: triton + )"); + } } TEST_F(NVPTXCompilerTest, RemovesUnnecessaryCopyInPostSchedulingPipelines) { diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc index df771c4df85d32..902a10ed935b71 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc @@ -670,6 +670,13 @@ absl::Status SoftmaxRewriterTriton::FuseDiamondChain( absl::StatusOr SoftmaxRewriterTriton::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { + auto cuda_compute_capability = + std::get_if(&gpu_version_); + if (!cuda_compute_capability || !cuda_compute_capability->IsAtLeastAmpere()) { + return absl::FailedPreconditionError( + "Triton support is only enabled for Ampere GPUs and up."); + } + std::vector diamond_chains = FindAllFusibleDiamondChains(*module, execution_threads); diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc index 0683f405e30f6e..ef0bc0f6f7d7f6 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc @@ -1,8 +1,11 @@ /* Copyright 2023 The OpenXLA Authors. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. + You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -1073,9 +1076,8 @@ ENTRY main { GmockMatch(m::Fusion(m::Parameter()))); } -TEST_P( - SoftmaxRewriterTritonTest, - CanOnlyFuseConvertInvolvingBF16InputIntoSoftmaxDiamondWithAtLeastAmpereComputeCapability) { // NOLINT(whitespace/line_length) +TEST_P(SoftmaxRewriterTritonTest, + CanFuseConvertInvolvingBF16InputIntoSoftmaxDiamond) { PrimitiveType data_type = GetParam(); const std::string hlo_string_template = R"( HloModule softmax @@ -1091,52 +1093,51 @@ ENTRY main { reduce = $0[127]{0} reduce(param_0_$0, constant_neg_inf), dimensions={1}, to_apply=max_computation broadcast = $0[127,125]{1,0} broadcast(reduce), dimensions={0} ROOT subtract = $0[127,125]{1,0} subtract(param_0_$0, broadcast) -} -)"; +})"; const std::string hlo_string = absl::Substitute(hlo_string_template, primitive_util::LowercasePrimitiveTypeName(data_type)); - auto ampere_module = ParseAndReturnVerifiedModule(hlo_string).value(); - auto volta_module = ampere_module->Clone(); + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - // Ampere EXPECT_TRUE( SoftmaxRewriterTritonMatchAndRewrite( se::CudaComputeCapability{se::CudaComputeCapability::AMPERE, 0}, - ampere_module.get()) + module.get()) .value()); - EXPECT_TRUE(verifier().Run(ampere_module.get()).status().ok()); - VLOG(2) << ampere_module->ToString(); - EXPECT_THAT(ampere_module->entry_computation()->root_instruction(), + EXPECT_TRUE(verifier().Run(module.get()).status().ok()); + VLOG(2) << module->ToString(); + EXPECT_THAT(module->entry_computation()->root_instruction(), GmockMatch(m::Fusion(m::Parameter()))); +} - // Volta (pre-Ampere) - VLOG(2) << volta_module->ToString(); +TEST_F(SoftmaxRewriterTritonTest, RewriterBailsOutOnPreAmpereGpu) { + const std::string hlo_string = R"( +HloModule softmax +max_computation { + arg_0 = f32[] parameter(0) + arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = bf16[127,125]{1,0} parameter(0) + param_0_f32 = f32[127,125]{1,0} convert(param_0) + constant_neg_inf = f32[] constant(-inf) + reduce = f32[127]{0} reduce(param_0_f32, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = f32[127,125]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = f32[127,125]{1,0} subtract(param_0_f32, broadcast) +})"; - switch (data_type) { - case F32: - case F16: - EXPECT_TRUE( - SoftmaxRewriterTritonMatchAndRewrite( - se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}, - volta_module.get()) - .value()); - EXPECT_TRUE(verifier().Run(volta_module.get()).status().ok()); - EXPECT_THAT(volta_module->entry_computation()->root_instruction(), - GmockMatch(m::Fusion(m::Convert(m::Parameter())))); - break; - case BF16: - // When bf16 is used, no fusion is possible on Volta. - EXPECT_FALSE( - SoftmaxRewriterTritonMatchAndRewrite( - se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}, - volta_module.get()) - .value()); - break; - default: - ABSL_UNREACHABLE(); - } + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + + EXPECT_THAT( + SoftmaxRewriterTriton( + se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}) + .Run(module.get()), + tsl::testing::StatusIs( + tsl::error::FAILED_PRECONDITION, + ::testing::StrEq( + "Triton support is only enabled for Ampere GPUs and up."))); } TEST_P(SoftmaxRewriterTritonTest, DoesNotFuseConvertWithC64DataType) { diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 783661f929211d..87617a29b8c289 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -442,21 +442,25 @@ xla_cc_test( ], ) -xla_cc_test( +xla_test( name = "gpu_triton_custom_call_test", srcs = ["gpu_triton_custom_call_test.cc"], - tags = tf_cuda_tests_tags(), + backends = [ + "gpu_a100", + "gpu_v100", + ], deps = [ ":gpu_codegen_test", "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/stream_executor:device_description", "//xla/tests:verified_hlo_module", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIRHeaders", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/service/gpu/tests/gpu_triton_custom_call_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_triton_custom_call_test.cc index 439292ca2eab92..7dc6fda816609a 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_triton_custom_call_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_triton_custom_call_test.cc @@ -18,7 +18,9 @@ limitations under the License. #include #include +#include #include +#include "absl/status/status.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -29,7 +31,9 @@ limitations under the License. #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/status_matchers.h" namespace xla { namespace gpu { @@ -37,10 +41,22 @@ namespace gpu { using ::mlir::ArrayRef; using ::mlir::NamedAttribute; -using GpuIrEmitterUnnestedTest = GpuCodegenTest; +class GpuIrEmitterUnnestedTest : public GpuCodegenTest { + public: + se::CudaComputeCapability GetCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } +}; TEST_F(GpuIrEmitterUnnestedTest, EmitTritonCustomCallWithCorrectLoweringAndWithoutNoaliasOrAlignment) { + if (!GetCudaComputeCapability().IsAtLeastAmpere()) { + GTEST_SKIP() << "Triton support is only enabled for Ampere GPUs and up."; + } + // Tests that the lowering of a Triton custom call produces the correct LLVM // IR, and that the arguments do not specify noalias or alignment attributes. @@ -139,5 +155,84 @@ TEST_F(GpuIrEmitterUnnestedTest, /*match_optimized_ir=*/false); } +TEST_F(GpuIrEmitterUnnestedTest, CanNotEmitTritonCustomCallOnPreAmpereGpu) { + if (GetCudaComputeCapability().IsAtLeastAmpere()) { + GTEST_SKIP() << "Running on Ampere or more recent GPU, skipping."; + } + + HloComputation::Builder computation_builder(TestName()); + mlir::MLIRContext context_; + mlir::Builder builder(&context_); + + // Create parameters and custom call in the computation builder. + Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + HloInstruction* param_0 = computation_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "arg_0")); + + HloInstruction* param_1 = computation_builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "arg_1")); + + // Create the backend_config for the triton custom call. + const std::string kMLIRText = R"( + module { + tt.func public @add_one(%arg0: !tt.ptr {tt.divisibility = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 32 : i32}, %arg3: !tt.ptr {tt.divisibility = 32 : i32}) { + %0 = tt.get_program_id x : i32 + %1 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 + %2 = tt.load %arg1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 + %cst = arith.constant 1.000000e+00 : f32 + %3 = arith.addf %1, %cst : f32 + %4 = tt.load %arg2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 + tt.store %arg2, %3 {cache = 1 : i32, evict = 1 : i32} : f32 + %5 = tt.load %arg3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 + tt.store %arg3, %2 {cache = 1 : i32, evict = 1 : i32} : f32 + tt.return + } + } + )"; + + NamedAttribute name = + builder.getNamedAttr("name", builder.getStringAttr("add_one")); + NamedAttribute ir = + builder.getNamedAttr("ir", builder.getStringAttr(kMLIRText)); + NamedAttribute num_stages = + builder.getNamedAttr("num_stages", builder.getI32IntegerAttr(3)); + NamedAttribute num_warps = + builder.getNamedAttr("num_warps", builder.getI32IntegerAttr(4)); + NamedAttribute grid_x = + builder.getNamedAttr("grid_x", builder.getI32IntegerAttr(1)); + NamedAttribute grid_y = + builder.getNamedAttr("grid_y", builder.getI32IntegerAttr(1)); + NamedAttribute grid_z = + builder.getNamedAttr("grid_z", builder.getI32IntegerAttr(1)); + NamedAttribute debug = + builder.getNamedAttr("debug", builder.getBoolAttr(false)); + + std::vector attributes = { + name, ir, num_stages, num_warps, grid_x, grid_y, grid_z, debug}; + ArrayRef attributesRef(attributes); + mlir::DictionaryAttr backend_config = + mlir::DictionaryAttr::get(&context_, attributesRef); + + // Parse the backend_config into a string. + std::string backend_config_str; + llvm::raw_string_ostream(backend_config_str) << backend_config; + + computation_builder.AddInstruction(HloInstruction::CreateCustomCall( + tuple_shape, {param_0, param_1}, "__gpu$xla.gpu.triton", + backend_config_str)); + + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(computation_builder.Build()); + + EXPECT_THAT( + CompileToExecutable(std::move(module), /*run_optimization_passes=*/false), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::StrEq( + "Triton support is only enabled for Ampere GPUs and up."))); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/triton_support.cc b/third_party/xla/xla/service/gpu/triton_support.cc index a4c837c182f13b..3e2b15222320d1 100644 --- a/third_party/xla/xla/service/gpu/triton_support.cc +++ b/third_party/xla/xla/service/gpu/triton_support.cc @@ -61,15 +61,13 @@ bool IsTritonSupportedDataType(PrimitiveType type, case F32: return true; case BF16: - return std::visit( - VariantVisitor{[](const se::CudaComputeCapability& cc) { - return cc.IsAtLeast( - stream_executor::CudaComputeCapability::AMPERE); - }, - [](const se::RocmComputeCapability& cc) { - return cc.has_bf16_dtype_support(); - }}, - gpu_version); + return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { + return true; + }, + [](const se::RocmComputeCapability& cc) { + return cc.has_bf16_dtype_support(); + }}, + gpu_version); default: return false; } diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc index 9d54797f27f849..89b54ed4b4a4fa 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -1086,11 +1086,6 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( int fusion_level = hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level(); // TODO(ROCm): Check fusion level for ROCm. - if (std::holds_alternative(gpu_version) && - !std::get(gpu_version) - .IsAtLeast(se::CudaComputeCapability::AMPERE)) { - fusion_level = std::min(fusion_level, 1); - } if (transform_direction == TransformDirection::kOutputToInput) { if (fusion_level < 2) { if (hlo.opcode() == HloOpcode::kConvert) { From 34a4054ac11a9161fd791c29ca9f3da3127319fe Mon Sep 17 00:00:00 2001 From: RJ Ascani Date: Thu, 28 Mar 2024 05:57:52 -0700 Subject: [PATCH 019/124] #shlo_ref Use _Float16 for F16 with GCC The GCC OSS build was failing to use `_Float16` for the `shlo_ref::F16` type alias despite `_Float16`'s availability with GCC. This was because the `has_keyword` macro was always returning false on GCC, as it was implemented using `__is_identifier`, which is Clang only. _Float16 should be available on both GCC & Clang, so we should just use those unless std::float16_t is available. This also removes has_keyword.h, as it is no longer used and ensures that `BF16` and `F16` map to `std::bfloat16_t` and `std::float16_t` when those are available. PiperOrigin-RevId: 619900330 --- tensorflow/lite/experimental/shlo/BUILD | 7 ---- tensorflow/lite/experimental/shlo/bf16.h | 2 +- tensorflow/lite/experimental/shlo/f16.h | 8 ++--- .../lite/experimental/shlo/has_keyword.h | 32 ------------------- tensorflow/lite/experimental/shlo/ops/BUILD | 2 +- .../experimental/shlo/ops/is_finite_test.cc | 12 ++++--- 6 files changed, 11 insertions(+), 52 deletions(-) delete mode 100644 tensorflow/lite/experimental/shlo/has_keyword.h diff --git a/tensorflow/lite/experimental/shlo/BUILD b/tensorflow/lite/experimental/shlo/BUILD index 02838f1449002d..a274bab28d0f4a 100644 --- a/tensorflow/lite/experimental/shlo/BUILD +++ b/tensorflow/lite/experimental/shlo/BUILD @@ -86,7 +86,6 @@ cc_library( name = "bf16", hdrs = ["bf16.h"], deps = [ - ":has_keyword", "@com_google_absl//absl/base", "@com_google_absl//absl/log:absl_check", ], @@ -106,12 +105,6 @@ cc_test( cc_library( name = "f16", hdrs = ["f16.h"], - deps = [":has_keyword"], -) - -cc_library( - name = "has_keyword", - hdrs = ["has_keyword.h"], ) cc_library( diff --git a/tensorflow/lite/experimental/shlo/bf16.h b/tensorflow/lite/experimental/shlo/bf16.h index 33b614d1888435..3f228fb161af14 100644 --- a/tensorflow/lite/experimental/shlo/bf16.h +++ b/tensorflow/lite/experimental/shlo/bf16.h @@ -19,7 +19,7 @@ limitations under the License. #if defined(__STDCPP_BFLOAT16_T__) #include namespace shlo_ref { -using BF16 = bfloat16_t; +using BF16 = ::std::bfloat16_t; } // namespace shlo_ref #else diff --git a/tensorflow/lite/experimental/shlo/f16.h b/tensorflow/lite/experimental/shlo/f16.h index 2496b31b84dc9f..f18170cb052682 100644 --- a/tensorflow/lite/experimental/shlo/f16.h +++ b/tensorflow/lite/experimental/shlo/f16.h @@ -16,21 +16,17 @@ limitations under the License. #ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_F16_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_F16_H_ -#include "tensorflow/lite/experimental/shlo/has_keyword.h" - #if defined(__STDCPP_FLOAT16_T__) #include namespace shlo_ref { -using F16 = float16_t; +using F16 = ::std::float16_t; } // namespace shlo_ref -#elif __has_keyword(_Float16) +#else namespace shlo_ref { using F16 = _Float16; } // namespace shlo_ref -#else -#error Type F16 is not available #endif #endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_F16_H_ diff --git a/tensorflow/lite/experimental/shlo/has_keyword.h b/tensorflow/lite/experimental/shlo/has_keyword.h deleted file mode 100644 index 548c86eec4de36..00000000000000 --- a/tensorflow/lite/experimental/shlo/has_keyword.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_HAS_KEYWORD_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_HAS_KEYWORD_H_ - -// CAUTION: __is_identifier behaves opposite how you would expect! -// '__is_identifier' returns '0' if '__x' is a reserved identifier provided by -// the compiler and '1' otherwise. borrowed from LLVM __config header under -// Apache license 2. -// (https://www.mend.io/blog/top-10-apache-license-questions-answered/) - -#ifndef __is_identifier // Optional of course. -#define __is_identifier(x) 1 // Compatibility with non-clang compilers. -#endif - -// More sensible macro for keyword detection -#define __has_keyword(__x) !(__is_identifier(__x)) - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_HAS_KEYWORD_H_ diff --git a/tensorflow/lite/experimental/shlo/ops/BUILD b/tensorflow/lite/experimental/shlo/ops/BUILD index b4e3004ff12f8a..8a924dfd80aae3 100644 --- a/tensorflow/lite/experimental/shlo/ops/BUILD +++ b/tensorflow/lite/experimental/shlo/ops/BUILD @@ -35,10 +35,10 @@ cc_test( srcs = ["is_finite_test.cc"], linkopts = shlo_ref_linkopts(), deps = [ - ":benchmark_util", ":is_finite", "//tensorflow/lite/experimental/shlo:bf16", "//tensorflow/lite/experimental/shlo:data_type", + "//tensorflow/lite/experimental/shlo:f16", "//tensorflow/lite/experimental/shlo:shape", "//tensorflow/lite/experimental/shlo:status_matcher", "//tensorflow/lite/experimental/shlo:tensor", diff --git a/tensorflow/lite/experimental/shlo/ops/is_finite_test.cc b/tensorflow/lite/experimental/shlo/ops/is_finite_test.cc index 0c78f5264a1849..be5fbdbcf1817b 100644 --- a/tensorflow/lite/experimental/shlo/ops/is_finite_test.cc +++ b/tensorflow/lite/experimental/shlo/ops/is_finite_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/lite/experimental/shlo/bf16.h" #include "tensorflow/lite/experimental/shlo/data_type.h" +#include "tensorflow/lite/experimental/shlo/f16.h" #include "tensorflow/lite/experimental/shlo/shape.h" #include "tensorflow/lite/experimental/shlo/status_matcher.h" #include "tensorflow/lite/experimental/shlo/tensor.h" @@ -65,11 +66,12 @@ INSTANTIATE_TEST_SUITE_P( BF16{-1.0f}, BF16{0.0f}, BF16{1.0f}}), TensorWithData::Create( Shape{{7}}, {false, false, false, false, true, true, true})}, - Params{TensorWithData::Create( - Shape{{7}}, - {+NAN, -NAN, -INFINITY, +INFINITY, -1.0f, 0.0f, 1.0f}), - TensorWithData::Create( - Shape{{7}}, {false, false, false, false, true, true, true})}, + Params{ + TensorWithData::Create( + Shape{{7}}, {F16{+NAN}, F16{-NAN}, F16{-INFINITY}, + F16{+INFINITY}, F16{-1.0f}, F16{0.0f}, F16{1.0f}}), + TensorWithData::Create( + Shape{{7}}, {false, false, false, false, true, true, true})}, Params{ TensorWithData::Create( Shape{{7}}, From 5b5e2fbd3705ea460d861745cc37940e88c8fe2a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Mar 2024 06:25:00 -0700 Subject: [PATCH 020/124] Treat entry computation parameter that has host memory space as MoveToHost annotation in host_offload_legalize. PiperOrigin-RevId: 619907383 --- third_party/xla/xla/service/hlo_verifier.cc | 10 ++-- .../xla/xla/service/host_offload_legalize.cc | 45 ++++++++++---- .../xla/service/host_offload_legalize_test.cc | 58 ++++++++++++++++++- .../xla/xla/service/layout_assignment.cc | 21 +++++-- 4 files changed, 113 insertions(+), 21 deletions(-) diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc index 497b74734920af..9b526c3bae41f0 100644 --- a/third_party/xla/xla/service/hlo_verifier.cc +++ b/third_party/xla/xla/service/hlo_verifier.cc @@ -1160,11 +1160,12 @@ Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { (ShapeUtil::ArrayDataSize(output_shape) == ShapeUtil::ArrayDataSize(operand_shape)))) { return Internal( - "Bitcast cannot have different shape sizes of output (%d) and " + "%s: Bitcast cannot have different shape sizes of output (%d) and " "operand " "(%d) (%s) (%s)", - opts_.shape_size(output_shape), opts_.shape_size(operand_shape), - output_shape.ToString(true), operand_shape.ToString(true)); + bitcast->ToString(), opts_.shape_size(output_shape), + opts_.shape_size(operand_shape), output_shape.ToString(true), + operand_shape.ToString(true)); } } return OkStatus(); @@ -1975,7 +1976,8 @@ Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) { if (!ShapesSame(parameter->shape(), layout.parameter_shape(i), Shape::Equal() .IgnoreTilesInLayout() - .IgnoreTailPaddingAlignmentInElements())) { + .IgnoreTailPaddingAlignmentInElements() + .IgnoreMemorySpaceInLayout())) { return Internal( "Shape of the entry computation parameter %d is %s should be " "compatible to the one specified in module's entry computation " diff --git a/third_party/xla/xla/service/host_offload_legalize.cc b/third_party/xla/xla/service/host_offload_legalize.cc index a56a5aad1ddc4c..958ec67718c51e 100644 --- a/third_party/xla/xla/service/host_offload_legalize.cc +++ b/third_party/xla/xla/service/host_offload_legalize.cc @@ -47,7 +47,7 @@ constexpr std::array kUsersOpcodes = {HloOpcode::kSlice, HloOpcode::kDynamicSlice}; // Find an annotation moving up. Meant to find an annotation from a DUS operand. -HloInstruction* FindAnnotationToUpdate(HloInstruction* instr) { +HloInstruction* FindToHostAnnotationToUpdate(HloInstruction* instr) { while (!instr->IsCustomCall( host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { if ((instr->opcode() != HloOpcode::kBitcast && @@ -61,7 +61,8 @@ HloInstruction* FindAnnotationToUpdate(HloInstruction* instr) { return instr; } -// Find an annotation moving up. Meant to find an annotation from a DUS operand. +// Find an annotation moving up. Meant to find an annotation from a DUS +// instruction. HloInstruction* FindToDeviceAnnotationToUpdate(HloInstruction* instr) { while (!instr->IsCustomCall( host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) { @@ -330,6 +331,11 @@ absl::StatusOr ProcessAnnotationForCopyMovement( HloInstruction* instruction, const CallGraph* call_graph, absl::flat_hash_set& processed_annotations, std::vector& to_remove) { + auto is_entry_computation_parameter = [](HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kParameter && + instruction->parent()->IsEntryComputation(); + }; + HloInstruction* starting_instr = FindDUSFromAnnotation(instruction->users()[0]); // If it's the pure copy case reset instruction. @@ -343,7 +349,8 @@ absl::StatusOr ProcessAnnotationForCopyMovement( // to update (required in case there are multiple insertions in the buffer). processed_annotations.insert(current_value.first); if (!current_value.first->IsCustomCall( - host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { + host_memory_offload_annotations::kMoveToHostCustomCallTarget) && + !is_entry_computation_parameter(current_value.first)) { CHECK_EQ(current_value.first->opcode(), HloOpcode::kDynamicUpdateSlice); while (true) { VLOG(10) << "Current value before: " << current_value.first->ToString(); @@ -361,7 +368,7 @@ absl::StatusOr ProcessAnnotationForCopyMovement( HloInstruction* annotation = current_value.first; if (annotation->opcode() == HloOpcode::kDynamicUpdateSlice) { HloInstruction* real_annotation = - FindAnnotationToUpdate(annotation->mutable_operand(1)); + FindToHostAnnotationToUpdate(annotation->mutable_operand(1)); // Check if this dynamic-update-slice doesn't have an annotation // attached. if (!real_annotation->IsCustomCall( @@ -473,8 +480,11 @@ absl::StatusOr ProcessAnnotationForCopyMovement( } update_shape_layout(std::make_pair(new_annotation, -1), copy_to_move.first); + Shape new_copy_shape = new_annotation->shape(); + *new_copy_shape.mutable_layout() = + copy_to_move.first->shape().layout(); HloInstruction* new_copy = instruction.first->AddInstruction( - copy_to_move.first->CloneWithNewOperands(new_annotation->shape(), + copy_to_move.first->CloneWithNewOperands(new_copy_shape, {new_annotation})); std::vector users = instruction.first->users(); for (auto* use : users) { @@ -495,8 +505,8 @@ absl::StatusOr ProcessAnnotationForCopyMovement( // Move the annotation first just before dynamic-update-slice to avoid // shape changes. if (instruction.first->opcode() == HloOpcode::kDynamicUpdateSlice) { - HloInstruction* annotation = - FindAnnotationToUpdate(instruction.first->mutable_operand(1)); + HloInstruction* annotation = FindToHostAnnotationToUpdate( + instruction.first->mutable_operand(1)); if (annotation == nullptr) { CHECK(false); return false; @@ -535,7 +545,7 @@ absl::StatusOr FixupInterveningCopies( std::vector annotations_to_remove; bool changed = false; for (HloInstruction* instruction : copy_to_host_annotations) { - if (processed_annotations.count(instruction)) { + if (processed_annotations.contains(instruction)) { continue; } TF_ASSIGN_OR_RETURN(bool changed_annotation_for_copy_movement, @@ -576,11 +586,22 @@ absl::StatusOr HostOffloadLegalize::Run( for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kCustomCall) { - continue; + if (instruction->opcode() == HloOpcode::kParameter && + instruction->parent()->IsEntryComputation()) { + Shape param_shape = + module->entry_computation_layout() + .parameter_layout(instruction->parameter_number()) + .shape(); + // TODO(mingyao): Add support for tuple parameter. + if (param_shape.has_layout() && + param_shape.layout().memory_space() == kHostMemorySpaceColor) { + copy_to_host_annotations.push_back(instruction); + continue; + } } - if (instruction->custom_call_target() == - host_memory_offload_annotations::kMoveToHostCustomCallTarget) { + + if (instruction->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { copy_to_host_annotations.push_back(instruction); } } diff --git a/third_party/xla/xla/service/host_offload_legalize_test.cc b/third_party/xla/xla/service/host_offload_legalize_test.cc index a1abd7e0a188b9..f9929648a3f12f 100644 --- a/third_party/xla/xla/service/host_offload_legalize_test.cc +++ b/third_party/xla/xla/service/host_offload_legalize_test.cc @@ -78,7 +78,7 @@ class HostOffloadLegalizeTest : public HloTestBase { TEST_F(HostOffloadLegalizeTest, NoCopyWithOptBarrierMoreElaborate) { const std::string& hlo_string = R"( -HloModule jit_f, entry_computation_layout={(f32[16,256]{0,1})->f32[16,256]{0,1}} +HloModule jit_f, entry_computation_layout={(f32[16,256]{0,1})->f32[16,256]{1,0}} ENTRY main.24 { Arg_0.1 = f32[16,256]{0,1} parameter(0) @@ -120,6 +120,62 @@ ENTRY main.24 { HloInstruction* custom_call = FindInstruction(module.get(), "custom-call.18"); EXPECT_EQ(custom_call->users()[0]->opcode(), HloOpcode::kCopy); EXPECT_EQ(custom_call->shape().layout(), LayoutUtil::MakeLayout({0, 1})); + EXPECT_EQ(custom_call->users()[0]->shape().layout(), + LayoutUtil::MakeLayout({1, 0})); +} + +TEST_F(HostOffloadLegalizeTest, XposeCopyOnParameterStreaming) { + const std::string& hlo_string = R"( +HloModule jit_f, entry_computation_layout={(f32[16,256]{0,1},f32[16,256]{0,1:T(8,128)S(5)})->f32[16,256]{1,0}} + +ENTRY main.24 { + Arg_0.1 = f32[16,256]{0,1} parameter(0) + Arg_0.2 = f32[16,256]{0,1:T(8,128)} parameter(1) + cp0 = f32[16,256]{1,0} copy(Arg_0.2) + cosine.4 = f32[16,256]{0,1} cosine(Arg_0.1) + custom-call.5 = f32[16,256]{0,1} custom-call(cosine.4), custom_call_target="MoveToHost" + sine.3 = f32[16,256]{0,1} sine(Arg_0.1) + cosine.7 = f32[16,256]{0,1} cosine(sine.3) + custom-call.8 = f32[16,256]{0,1} custom-call(cosine.7), custom_call_target="MoveToHost" + constant.2 = f32[] constant(1) + cp1 = f32[16,256]{1,0} copy(custom-call.8) + tuple.11 = (f32[16,256]{0,1}, f32[16,256]{1,0}, f32[16,256]{1,0}, f32[]) tuple(custom-call.5, cp1, cp0, constant.2) + opt-barrier.12 = (f32[16,256]{0,1}, f32[16,256]{1,0}, f32[16,256]{1,0}, f32[]) opt-barrier(tuple.11) + get-tuple-element.16 = f32[] get-tuple-element(opt-barrier.12), index=3 + broadcast.20 = f32[16,256]{0,1} broadcast(get-tuple-element.16), dimensions={} + get-tuple-element.15 = f32[16,256]{1,0} get-tuple-element(opt-barrier.12), index=2 + custom-call.19 = f32[16,256]{1,0} custom-call(get-tuple-element.15), custom_call_target="MoveToDevice" + multiply.21 = f32[16,256]{0,1} multiply(broadcast.20, custom-call.19) + cp2 = f32[16,256]{1,0} copy(multiply.21) + get-tuple-element.14 = f32[16,256]{1,0} get-tuple-element(opt-barrier.12), index=1 + custom-call.18 = f32[16,256]{1,0} custom-call(get-tuple-element.14), custom_call_target="MoveToDevice" + multiply.22 = f32[16,256]{1,0} multiply(cp2, custom-call.18) + get-tuple-element.13 = f32[16,256]{0,1} get-tuple-element(opt-barrier.12), index=0 + custom-call.17 = f32[16,256]{0,1} custom-call(get-tuple-element.13), custom_call_target="MoveToDevice" + cp3 = f32[16,256]{1,0} copy(custom-call.17) + ROOT multiply.23 = f32[16,256]{1,0} multiply(multiply.22, cp3) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloadLegalize(module.get())); + + EXPECT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + HloInstruction* custom_call = FindInstruction(module.get(), "custom-call.18"); + EXPECT_EQ(custom_call->users()[0]->opcode(), HloOpcode::kCopy); + EXPECT_EQ(custom_call->shape().layout(), LayoutUtil::MakeLayout({0, 1})); + EXPECT_EQ(custom_call->users()[0]->shape().layout(), + LayoutUtil::MakeLayout({1, 0})); + + custom_call = FindInstruction(module.get(), "custom-call.19"); + EXPECT_EQ(custom_call->users()[0]->opcode(), HloOpcode::kCopy); + EXPECT_EQ(custom_call->shape().layout(), + LayoutUtil::MakeLayout({0, 1}, {}, {}, {}, {Tile{{8, 128}}})); + EXPECT_EQ(custom_call->users()[0]->shape().layout(), + LayoutUtil::MakeLayout({1, 0})); } TEST_F(HostOffloadLegalizeTest, LlmActivationHostMemoryMultipleConsumers) { diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index 82624c414b8168..8ca2296875e62b 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -721,11 +721,23 @@ Status LayoutAssignment::AddMandatoryConstraints( if (parameter_layout.LayoutIsSet()) { // Parameter layouts must match the respective layout in // ComputationLayout, if there is one. - TF_RETURN_IF_ERROR( - SetInstructionLayout(parameter_layout.shape(), instruction)); + Shape param_shape = parameter_layout.shape(); + // Clear out memory space in layout. Host offloader will do the + // analysis later. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( + ¶m_shape, [](Shape* subshape, const ShapeIndex& index) { + if (!subshape->has_layout() || !subshape->IsArray()) { + return OkStatus(); + } + subshape->mutable_layout()->set_memory_space( + Layout::kDefaultMemorySpace); + return OkStatus(); + })); + + TF_RETURN_IF_ERROR(SetInstructionLayout(param_shape, instruction)); if (reverse_computation_order_) { TF_RETURN_IF_ERROR(PropagateParameterLayoutToUsers( - instruction, parameter_layout.shape(), this)); + instruction, param_shape, this)); } } } @@ -2537,7 +2549,8 @@ Status LayoutAssignment::PropagateComputationLayouts( } const auto& computed_subshape = ShapeUtil::GetSubshape( computed_computation_layout.parameter_shape(i), shape_index); - if (subshape.layout() != computed_subshape.layout()) { + if (!Layout::Equal().IgnoreMemorySpace()( + subshape.layout(), computed_subshape.layout())) { return Internal( "Assigned parameter shape %s does not match layout of " "computation shape: %s", From 41b8006c12180815d30134f24d87b4dc786ffc73 Mon Sep 17 00:00:00 2001 From: Dmitri Gribenko Date: Thu, 28 Mar 2024 06:49:50 -0700 Subject: [PATCH 021/124] Integrate LLVM at llvm/llvm-project@feebcd65fb7e Updates LLVM usage to match [feebcd65fb7e](https://github.com/llvm/llvm-project/commit/feebcd65fb7e) PiperOrigin-RevId: 619913390 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 83a5ccb8ec502b..ad8f149f602ed5 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "3cf169ca160eaf5464503fbd93d73ee1d8597936" - LLVM_SHA256 = "b63cac687df1bc98e3eb0289f3be6824fcb1b106d0720b5c083417918d1029fd" + LLVM_COMMIT = "feebcd65fb7e0534f5219e05432a05e45aa8cd2a" + LLVM_SHA256 = "39b2b0c5f5fefb54866a0e9738f1617d79049dbac3b5cdecb7b1f785a57bb669" tf_http_archive( name = name, From 3c2fb510507ea4913692a17c5d27c5b4853b5e91 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 28 Mar 2024 07:13:53 -0700 Subject: [PATCH 022/124] Add Linux ppc64le support to CUDA stubs. Update implib.so to include a commit with ppc64le support. May fix https://github.com/google/jax/issues/19992, although I don't have such a machine so it's untested. PiperOrigin-RevId: 619919522 --- third_party/implib_so/workspace.bzl | 6 +++--- third_party/xla/third_party/implib_so/workspace.bzl | 6 +++--- .../xla/third_party/tsl/third_party/implib_so/workspace.bzl | 6 +++--- third_party/xla/xla/tsl/cuda/stub.bzl | 1 + 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/third_party/implib_so/workspace.bzl b/third_party/implib_so/workspace.bzl index 01dad3b169f402..37f36cc135fd6d 100644 --- a/third_party/implib_so/workspace.bzl +++ b/third_party/implib_so/workspace.bzl @@ -6,8 +6,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): tf_http_archive( name = "implib_so", - strip_prefix = "Implib.so-5fb84c2a750434b9df1da67d67b749eb929598f1", - sha256 = "10de0a616df24849f2a883747784c115f209708960e44556f5ce384de6f103e8", - urls = tf_mirror_urls("https://github.com/yugr/Implib.so/archive/5fb84c2a750434b9df1da67d67b749eb929598f1.tar.gz"), + strip_prefix = "Implib.so-2cce6cab8ff2c15f9da858ea0b68646a8d62aef2", + sha256 = "4ef3089969d57a5b60bb41b8212c478eaa15c56941f86d4bf5e7f98a3afd24e8", + urls = tf_mirror_urls("https://github.com/yugr/Implib.so/archive/2cce6cab8ff2c15f9da858ea0b68646a8d62aef2.tar.gz"), build_file = "//third_party/implib_so:implib_so.BUILD", ) diff --git a/third_party/xla/third_party/implib_so/workspace.bzl b/third_party/xla/third_party/implib_so/workspace.bzl index 01dad3b169f402..37f36cc135fd6d 100644 --- a/third_party/xla/third_party/implib_so/workspace.bzl +++ b/third_party/xla/third_party/implib_so/workspace.bzl @@ -6,8 +6,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): tf_http_archive( name = "implib_so", - strip_prefix = "Implib.so-5fb84c2a750434b9df1da67d67b749eb929598f1", - sha256 = "10de0a616df24849f2a883747784c115f209708960e44556f5ce384de6f103e8", - urls = tf_mirror_urls("https://github.com/yugr/Implib.so/archive/5fb84c2a750434b9df1da67d67b749eb929598f1.tar.gz"), + strip_prefix = "Implib.so-2cce6cab8ff2c15f9da858ea0b68646a8d62aef2", + sha256 = "4ef3089969d57a5b60bb41b8212c478eaa15c56941f86d4bf5e7f98a3afd24e8", + urls = tf_mirror_urls("https://github.com/yugr/Implib.so/archive/2cce6cab8ff2c15f9da858ea0b68646a8d62aef2.tar.gz"), build_file = "//third_party/implib_so:implib_so.BUILD", ) diff --git a/third_party/xla/third_party/tsl/third_party/implib_so/workspace.bzl b/third_party/xla/third_party/tsl/third_party/implib_so/workspace.bzl index 01dad3b169f402..37f36cc135fd6d 100644 --- a/third_party/xla/third_party/tsl/third_party/implib_so/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/implib_so/workspace.bzl @@ -6,8 +6,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): tf_http_archive( name = "implib_so", - strip_prefix = "Implib.so-5fb84c2a750434b9df1da67d67b749eb929598f1", - sha256 = "10de0a616df24849f2a883747784c115f209708960e44556f5ce384de6f103e8", - urls = tf_mirror_urls("https://github.com/yugr/Implib.so/archive/5fb84c2a750434b9df1da67d67b749eb929598f1.tar.gz"), + strip_prefix = "Implib.so-2cce6cab8ff2c15f9da858ea0b68646a8d62aef2", + sha256 = "4ef3089969d57a5b60bb41b8212c478eaa15c56941f86d4bf5e7f98a3afd24e8", + urls = tf_mirror_urls("https://github.com/yugr/Implib.so/archive/2cce6cab8ff2c15f9da858ea0b68646a8d62aef2.tar.gz"), build_file = "//third_party/implib_so:implib_so.BUILD", ) diff --git a/third_party/xla/xla/tsl/cuda/stub.bzl b/third_party/xla/xla/tsl/cuda/stub.bzl index d5e644dc13c97c..1aaa52746d69b8 100644 --- a/third_party/xla/xla/tsl/cuda/stub.bzl +++ b/third_party/xla/xla/tsl/cuda/stub.bzl @@ -21,6 +21,7 @@ def cuda_stub(name, srcs): cmd = select({ "@local_tsl//tsl:linux_aarch64": "$(location //third_party/implib_so:make_stub) $< --outdir $(RULEDIR) --target aarch64", "@local_tsl//tsl:linux_x86_64": "$(location //third_party/implib_so:make_stub) $< --outdir $(RULEDIR) --target x86_64", + "@local_tsl//tsl:linux_ppc64le": "$(location //third_party/implib_so:make_stub) $< --outdir $(RULEDIR) --target powerpc64le", "//conditions:default": "NOT_IMPLEMENTED_FOR_THIS_PLATFORM_OR_ARCHITECTURE", }), ) From 8b1b801437cc28690a2bf843919692636c45821b Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Thu, 28 Mar 2024 07:54:27 -0700 Subject: [PATCH 023/124] Replace usage of `PermuteShape` to `Permute`. Replaces the duplicate implementation of array permutation. As a corollary, replaces permutation values to existing definitions in `attrs_and_constraints.h`. PiperOrigin-RevId: 619930275 --- .../common/attrs_and_constraints.h | 6 +++ .../passes/nchw_convolution_to_nhwc.cc | 46 ++++--------------- 2 files changed, 14 insertions(+), 38 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h index faf56159b39cc5..852902e229a9fc 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h @@ -49,6 +49,12 @@ inline constexpr std::array kNhwcToNchwPermutation = {0, 3, 1, 2}; // permutation of `kNchwToNhwcPermutation`. inline constexpr std::array kNchwToNhwcPermutation = {0, 2, 3, 1}; +// Permutation from the OIHW (== (output features, input features, height, +// width)) tensor format to HWIO. This is commonly used to transpose convolution +// weights represented as OIHW format to HWIO, which is more desirable for +// certain downstream optimization passes (e.g. XLA). +inline constexpr std::array kOihwToHwioPermutation = {2, 3, 1, 0}; + // Returns true if the value has static shape. bool HasStaticShape(Value value); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc index 5ba80df30a9f2d..521f701598fb0a 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc @@ -28,6 +28,7 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h" namespace mlir::quant::stablehlo { @@ -72,20 +73,20 @@ class RewriteNchwConvolutionToNhwc // Transpose the input tensor: [b, f, 0, 1] => [b, 0, 1, f] Value input = op->getOperand(0); const TensorType new_input_tensor_type = GetTransposedTensorType( - input.getType().cast(), kActivationPermutation); + input.getType().cast(), kNchwToNhwcPermutation); auto input_transpose_op = rewriter.create( op.getLoc(), /*resultType0=*/new_input_tensor_type, /*operand=*/input, - rewriter.getDenseI64ArrayAttr(kActivationPermutation)); + rewriter.getDenseI64ArrayAttr(kNchwToNhwcPermutation)); // Transpose the filter tensor: [o, i, 0, 1] => [0, 1, i, o] Value filter = op->getOperand(1); const TensorType new_filter_tensor_type = GetTransposedTensorType( - filter.getType().cast(), kFilterPermutation); + filter.getType().cast(), kOihwToHwioPermutation); auto filter_transpose_op = rewriter.create( op.getLoc(), /*resultType0=*/new_filter_tensor_type, /*operand=*/filter, - rewriter.getDenseI64ArrayAttr(kFilterPermutation)); + rewriter.getDenseI64ArrayAttr(kOihwToHwioPermutation)); // [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] const auto new_dimension_nums = rewriter.getAttr( @@ -99,7 +100,7 @@ class RewriteNchwConvolutionToNhwc // Determine the shape of the output tensor: [b, f, 0, 1] => [b, 0, 1, f] auto output_tensor_type = op->getResult(0).getType().cast(); const TensorType new_conv_output_tensor_type = - GetTransposedTensorType(output_tensor_type, kOutputPermutation); + GetTransposedTensorType(output_tensor_type, kNchwToNhwcPermutation); // window_strides, padding, lhs_dilation, rhs_dilation, window_reversal are // reused without modification because the ordering of spatial dimensions @@ -125,31 +126,12 @@ class RewriteNchwConvolutionToNhwc auto output_transpose_op = rewriter.create( new_convolution_op.getLoc(), /*resultType0=*/output_tensor_type, /*operand=*/new_convolution_op, - rewriter.getDenseI64ArrayAttr(kOutputReversePermutation)); + rewriter.getDenseI64ArrayAttr(kNhwcToNchwPermutation)); rewriter.replaceAllUsesWith(op, output_transpose_op); } private: - // Permutation to transpose the input tensor from [b, f, 0, 1] to - // [b, 0, 1, f]. - static constexpr std::array kActivationPermutation = {0, 2, 3, 1}; - - // Permutation to transpose the filter tensor from [o, i, 0, 1] to - // [0, 1, i, o]. - static constexpr std::array kFilterPermutation = {2, 3, 1, 0}; - - // Permutation to transpose the output tensor from [b, f, 0, 1] to - // [b, 0, 1, f]. This is used to determine the shape of the new - // `ConvolutionOp`'s output tensor. - static constexpr std::array kOutputPermutation = {0, 2, 3, 1}; - - // Permutation to transpose the output tensor from [b, 0, 1, f] to - // [b, f, 0, 1]. This is used to revert the new output tensor of - // `ConvolutionOp` with a `TransposeOp`. - static constexpr std::array kOutputReversePermutation = {0, 3, 1, - 2}; - // Matches input dimensions corresponding to: [b, f, 0, 1]. bool MatchInputDimensionNumbers( const ConvDimensionNumbersAttr dimension_numbers) const { @@ -183,21 +165,9 @@ class RewriteNchwConvolutionToNhwc TensorType GetTransposedTensorType( const TensorType type, const ArrayRef permutation) const { const SmallVector after_shape = - PermuteShape(type.getShape(), permutation); + Permute(type.getShape(), permutation); return type.cloneWith(after_shape, type.getElementType()); } - - // Permutes the shape according to the permutation. The size of `shape` and - // `permutation` should be equal. - SmallVector PermuteShape(const ArrayRef shape, - const ArrayRef permutation) const { - const int64_t size = shape.size(); - SmallVector after_shape(size); - for (int i = 0; i < size; ++i) { - after_shape[i] = shape[permutation[i]]; - } - return after_shape; - } }; } // namespace From 4aa077841d27eb902054d6a9baad1345e78605eb Mon Sep 17 00:00:00 2001 From: Thai Nguyen Date: Thu, 28 Mar 2024 08:07:34 -0700 Subject: [PATCH 024/124] Remove legacy post-training quantization functions PiperOrigin-RevId: 619933983 --- .../tensorflow/python/quantize_model.cc | 68 ------------------- .../tensorflow/python/quantize_model.h | 13 ---- 2 files changed, 81 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index 8d230b97ca772b..89467d30944ca9 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -298,74 +298,6 @@ absl::StatusOr QuantizeQatModel( *function_aliases); } -absl::StatusOr QuantizePtqModelPreCalibration( - absl::string_view saved_model_path, - const std::vector &signature_keys, - const std::unordered_set &tags, - const QuantizationOptions &quantization_options) { - std::unique_ptr context = - CreateMlirContextForQuantization(); - - absl::StatusOr> - function_aliases = GetFunctionAliases(saved_model_path, tags); - if (!function_aliases.ok()) { - return absl::InternalError(absl::StrCat( - "Failed to get function alias: ", function_aliases.status().message())); - } - - const bool is_stablehlo = quantization_options.op_set() == OpSet::STABLEHLO; - absl::StatusOr> module = - ImportAndPreprocessSavedModel( - saved_model_path, signature_keys, tags, context.get(), - /*is_inliner_run=*/true, - /*run_tf_to_stablehlo=*/is_stablehlo, - /*deserialize_xla_call_module=*/false, *function_aliases); - if (!module.status().ok()) { - return absl::InternalError( - absl::StrCat("Failed to import and preprocess SavedModel: ", - module.status().message())); - } - mlir::OwningOpRef module_ref = std::move(module).value(); - - return QuantizePtqModelPreCalibrationImpl( - *module_ref, context.get(), quantization_options, *function_aliases); -} - -absl::StatusOr QuantizePtqModelPostCalibration( - absl::string_view saved_model_path, - const std::vector &signature_keys, - const std::unordered_set &tags, - const QuantizationOptions &quantization_options) { - std::unique_ptr context = - CreateMlirContextForQuantization(); - - absl::StatusOr> - function_aliases = GetFunctionAliases(saved_model_path, tags); - if (!function_aliases.ok()) { - return absl::InternalError(absl::StrCat( - "Failed to get function alias: ", function_aliases.status().message())); - } - - // Freezing is required again since variables might have been produced during - // the pre-calibration step. `is_inliner_run = false` to prevent the functions - // lifted for quantization from being inlined. - absl::StatusOr> module = - ImportAndPreprocessSavedModel( - saved_model_path, signature_keys, tags, context.get(), - /*is_inliner_run=*/false, - /*run_tf_to_stablehlo=*/false, - /*deserialize_xla_call_module=*/false, *function_aliases); - if (!module.status().ok()) { - return absl::InternalError( - absl::StrCat("Failed to import and preprocess SavedModel: ", - module.status().message())); - } - mlir::OwningOpRef module_ref = std::move(module).value(); - - return QuantizePtqModelPostCalibrationImpl( - *module_ref, context.get(), quantization_options, *function_aliases); -} - absl::StatusOr QuantizeDynamicRangePtq( absl::string_view saved_model_path, const std::vector &signature_keys, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h index ec7df2929660d5..a54e988c043aa3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h @@ -70,19 +70,6 @@ absl::StatusOr QuantizeStaticRangePtq( const absl::flat_hash_map& representative_dataset_file_map_serialized); -// Legacy versions of static-range quantization. -absl::StatusOr QuantizePtqModelPreCalibration( - absl::string_view saved_model_path, - const std::vector& signature_keys, - const std::unordered_set& tags, - const QuantizationOptions& quantization_options); - -absl::StatusOr QuantizePtqModelPostCalibration( - absl::string_view saved_model_path, - const std::vector& signature_keys, - const std::unordered_set& tags, - const QuantizationOptions& quantization_options); - } // namespace quantization } // namespace tensorflow From c046dde407558f688b8ec1bbb1d380f38eb0a9e0 Mon Sep 17 00:00:00 2001 From: Matt Callanan Date: Thu, 28 Mar 2024 08:12:15 -0700 Subject: [PATCH 025/124] #tf-data-service Reduce severity of missing default transfer server log line. PiperOrigin-RevId: 619935370 --- .../core/data/service/client/data_service_client.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/data/service/client/data_service_client.cc b/tensorflow/core/data/service/client/data_service_client.cc index 472562e750bb73..7bba141740292a 100644 --- a/tensorflow/core/data/service/client/data_service_client.cc +++ b/tensorflow/core/data/service/client/data_service_client.cc @@ -400,12 +400,12 @@ DataServiceClient::CreateWorkerClient(const TaskInfo& task_info) { return CreateAlternativeWorkerClientWithGrpcFallback(*transfer_server, task_info); } - LOG(INFO) << "Failed to find transfer server for default data transfer " - "protocol '" - << default_protocol << "' for worker '" - << task_info.worker_address() - << "'; falling back to grpc. Original error: " - << transfer_server.status(); + VLOG(1) << "Failed to find transfer server for default data transfer " + "protocol '" + << default_protocol << "' for worker '" + << task_info.worker_address() + << "'; falling back to grpc. Original error: " + << transfer_server.status(); metrics::RecordTFDataServiceDataTransferProtocolFallback( default_protocol, error::Code::NOT_FOUND, "Failed to find transfer server for default protocol"); From 27fc08fb84fc0d76f3ebf8e924b3e4c6b59876c4 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 28 Mar 2024 08:28:00 -0700 Subject: [PATCH 026/124] [xla:gpu] Added a test checking that Triton kernels compiled via XLA do not dedup arguments PiperOrigin-RevId: 619940030 --- third_party/xla/xla/service/gpu/tests/BUILD | 2 + .../gpu/tests/gpu_triton_custom_call_test.cc | 177 ++++++++++-------- 2 files changed, 97 insertions(+), 82 deletions(-) diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 87617a29b8c289..b9e7764047d8c4 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -454,8 +454,10 @@ xla_test( "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/service/gpu/tests/gpu_triton_custom_call_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_triton_custom_call_test.cc index 7dc6fda816609a..52351018c743bb 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_triton_custom_call_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_triton_custom_call_test.cc @@ -21,17 +21,20 @@ limitations under the License. #include #include #include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" #include "tsl/platform/status_matchers.h" @@ -41,39 +44,13 @@ namespace gpu { using ::mlir::ArrayRef; using ::mlir::NamedAttribute; -class GpuIrEmitterUnnestedTest : public GpuCodegenTest { - public: - se::CudaComputeCapability GetCudaComputeCapability() { - return backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability(); - } -}; +namespace { -TEST_F(GpuIrEmitterUnnestedTest, - EmitTritonCustomCallWithCorrectLoweringAndWithoutNoaliasOrAlignment) { - if (!GetCudaComputeCapability().IsAtLeastAmpere()) { - GTEST_SKIP() << "Triton support is only enabled for Ampere GPUs and up."; - } - - // Tests that the lowering of a Triton custom call produces the correct LLVM - // IR, and that the arguments do not specify noalias or alignment attributes. - - HloComputation::Builder computation_builder(TestName()); +std::unique_ptr CreateAddTritonCustomCall( + Shape tuple_shape, HloInstruction* param_0, HloInstruction* param_1) { mlir::MLIRContext context_; mlir::Builder builder(&context_); - // Create parameters and custom call in the computation builder. - Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {}); - Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); - - HloInstruction* param_0 = computation_builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "arg_0")); - - HloInstruction* param_1 = computation_builder.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "arg_1")); - // Create the backend_config for the triton custom call. const std::string kMLIRText = R"( module { @@ -119,9 +96,46 @@ TEST_F(GpuIrEmitterUnnestedTest, std::string backend_config_str; llvm::raw_string_ostream(backend_config_str) << backend_config; - computation_builder.AddInstruction(HloInstruction::CreateCustomCall( - tuple_shape, {param_0, param_1}, "__gpu$xla.gpu.triton", - backend_config_str)); + return HloInstruction::CreateCustomCall(tuple_shape, {param_0, param_1}, + "__gpu$xla.gpu.triton", + backend_config_str); +} + +} // namespace + +class GpuIrEmitterUnnestedTest : public GpuCodegenTest { + public: + se::CudaComputeCapability GetCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } +}; + +TEST_F(GpuIrEmitterUnnestedTest, + EmitTritonCustomCallWithCorrectLoweringAndWithoutNoaliasOrAlignment) { + if (!GetCudaComputeCapability().IsAtLeastAmpere()) { + GTEST_SKIP() << "Triton support is only enabled for Ampere GPUs and up."; + } + + // Tests that the lowering of a Triton custom call produces the correct LLVM + // IR, and that the arguments do not specify noalias or alignment attributes. + + HloComputation::Builder computation_builder(TestName()); + + // Create parameters and custom call in the computation builder. + Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + HloInstruction* param_0 = computation_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "arg_0")); + + HloInstruction* param_1 = computation_builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "arg_1")); + + computation_builder.AddInstruction( + CreateAddTritonCustomCall(tuple_shape, param_0, param_1)); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(computation_builder.Build()); @@ -161,8 +175,6 @@ TEST_F(GpuIrEmitterUnnestedTest, CanNotEmitTritonCustomCallOnPreAmpereGpu) { } HloComputation::Builder computation_builder(TestName()); - mlir::MLIRContext context_; - mlir::Builder builder(&context_); // Create parameters and custom call in the computation builder. Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {}); @@ -174,54 +186,8 @@ TEST_F(GpuIrEmitterUnnestedTest, CanNotEmitTritonCustomCallOnPreAmpereGpu) { HloInstruction* param_1 = computation_builder.AddInstruction( HloInstruction::CreateParameter(1, scalar_shape, "arg_1")); - // Create the backend_config for the triton custom call. - const std::string kMLIRText = R"( - module { - tt.func public @add_one(%arg0: !tt.ptr {tt.divisibility = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 32 : i32}, %arg3: !tt.ptr {tt.divisibility = 32 : i32}) { - %0 = tt.get_program_id x : i32 - %1 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 - %2 = tt.load %arg1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 - %cst = arith.constant 1.000000e+00 : f32 - %3 = arith.addf %1, %cst : f32 - %4 = tt.load %arg2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 - tt.store %arg2, %3 {cache = 1 : i32, evict = 1 : i32} : f32 - %5 = tt.load %arg3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32 - tt.store %arg3, %2 {cache = 1 : i32, evict = 1 : i32} : f32 - tt.return - } - } - )"; - - NamedAttribute name = - builder.getNamedAttr("name", builder.getStringAttr("add_one")); - NamedAttribute ir = - builder.getNamedAttr("ir", builder.getStringAttr(kMLIRText)); - NamedAttribute num_stages = - builder.getNamedAttr("num_stages", builder.getI32IntegerAttr(3)); - NamedAttribute num_warps = - builder.getNamedAttr("num_warps", builder.getI32IntegerAttr(4)); - NamedAttribute grid_x = - builder.getNamedAttr("grid_x", builder.getI32IntegerAttr(1)); - NamedAttribute grid_y = - builder.getNamedAttr("grid_y", builder.getI32IntegerAttr(1)); - NamedAttribute grid_z = - builder.getNamedAttr("grid_z", builder.getI32IntegerAttr(1)); - NamedAttribute debug = - builder.getNamedAttr("debug", builder.getBoolAttr(false)); - - std::vector attributes = { - name, ir, num_stages, num_warps, grid_x, grid_y, grid_z, debug}; - ArrayRef attributesRef(attributes); - mlir::DictionaryAttr backend_config = - mlir::DictionaryAttr::get(&context_, attributesRef); - - // Parse the backend_config into a string. - std::string backend_config_str; - llvm::raw_string_ostream(backend_config_str) << backend_config; - - computation_builder.AddInstruction(HloInstruction::CreateCustomCall( - tuple_shape, {param_0, param_1}, "__gpu$xla.gpu.triton", - backend_config_str)); + computation_builder.AddInstruction( + CreateAddTritonCustomCall(tuple_shape, param_0, param_1)); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(computation_builder.Build()); @@ -234,5 +200,52 @@ TEST_F(GpuIrEmitterUnnestedTest, CanNotEmitTritonCustomCallOnPreAmpereGpu) { "Triton support is only enabled for Ampere GPUs and up."))); } +class TritonCustomCallTest : public HloTestBase {}; + +TEST_F(TritonCustomCallTest, NoArgumentDeduplication) { + if (auto cc = backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + !cc.IsAtLeastAmpere()) { + GTEST_SKIP() << "Triton support is only enabled for Ampere GPUs and up."; + } + + // Tests that no argument deduplication is done for Triton kernels. + // + // Triton kernels are compiled on the first call and re-used for all the + // following calls. So, if we are unlucky, we could end up calling the + // compiled kernel with fewer arguments than it expects in the presence + // of argument deduplication. + // + // For example, + // + // * The first call is f(x, y). The arguments are distinct, no deduplication + // is done at compilation time and the compiled kernel expects two + // arguments. + // * The second call is f(x, x). The arguments are deduplicated and we + // call the previously compiled kernel with just x, causing a crash. + + HloComputation::Builder computation_builder(TestName()); + + Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + HloInstruction* param_0 = computation_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "arg_0")); + + HloInstruction* param_1 = computation_builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "arg_1")); + + auto* instr_0 = computation_builder.AddInstruction( + CreateAddTritonCustomCall(tuple_shape, param_0, param_1)); + computation_builder.AddInstruction( + CreateAddTritonCustomCall(tuple_shape, instr_0, instr_0)); + + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(computation_builder.Build()); + EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false)); +} + } // namespace gpu } // namespace xla From ab23fef288dd6b5af64fbb575e194c7bedbc06cc Mon Sep 17 00:00:00 2001 From: Quentin Khan Date: Thu, 28 Mar 2024 08:28:42 -0700 Subject: [PATCH 027/124] #shlo_ref Add `maximum` op. PiperOrigin-RevId: 619940246 --- tensorflow/lite/experimental/shlo/ops/BUILD | 29 ++++ .../lite/experimental/shlo/ops/maximum.cc | 66 ++++++++ .../lite/experimental/shlo/ops/maximum.h | 36 +++++ .../experimental/shlo/ops/maximum_test.cc | 151 ++++++++++++++++++ 4 files changed, 282 insertions(+) create mode 100644 tensorflow/lite/experimental/shlo/ops/maximum.cc create mode 100644 tensorflow/lite/experimental/shlo/ops/maximum.h create mode 100644 tensorflow/lite/experimental/shlo/ops/maximum_test.cc diff --git a/tensorflow/lite/experimental/shlo/ops/BUILD b/tensorflow/lite/experimental/shlo/ops/BUILD index 8a924dfd80aae3..8624708052c3d2 100644 --- a/tensorflow/lite/experimental/shlo/ops/BUILD +++ b/tensorflow/lite/experimental/shlo/ops/BUILD @@ -963,3 +963,32 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "maximum", + srcs = ["maximum.cc"], + hdrs = ["maximum.h"], + deps = [ + ":binary_elementwise", + ":util", + "//tensorflow/lite/experimental/shlo:dispatch", + "//tensorflow/lite/experimental/shlo:tensor", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "maximum_test", + srcs = ["maximum_test.cc"], + deps = [ + ":binary_elementwise_test_util", + ":maximum", + ":test_util", + "//tensorflow/lite/experimental/shlo:quantize", + "//tensorflow/lite/experimental/shlo:quantized_tensor_element_type", + "//tensorflow/lite/experimental/shlo:shape", + "//tensorflow/lite/experimental/shlo:status_matcher", + "//tensorflow/lite/experimental/shlo:tensor", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/experimental/shlo/ops/maximum.cc b/tensorflow/lite/experimental/shlo/ops/maximum.cc new file mode 100644 index 00000000000000..9239dadaac3f92 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/ops/maximum.cc @@ -0,0 +1,66 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/shlo/ops/maximum.h" + +#include "absl/status/status.h" +#include "tensorflow/lite/experimental/shlo/dispatch.h" +#include "tensorflow/lite/experimental/shlo/ops/binary_elementwise.h" +#include "tensorflow/lite/experimental/shlo/ops/util.h" +#include "tensorflow/lite/experimental/shlo/tensor.h" + +namespace shlo_ref { + +struct Maximum { + template + constexpr auto operator()(const T a, const T b) { + return a > b ? a : b; + } +}; + +MaximumOp Create(MaximumOp::Attributes) { return {}; } + +absl::Status Prepare(MaximumOp& op, const Tensor& lhs, const Tensor& rhs, + Tensor& output) { + SHLO_REF_RETURN_ON_ERROR(Propagate(lhs.shape(), rhs.shape(), output.shape())); + SHLO_REF_RETURN_ON_ERROR( + CheckSupportedTypes(CheckCtx("maximum"), lhs, IsBoolTensor, IsIntTensor, + IsFloatTensor, IsQuantizedPerTensorTensor)); + SHLO_REF_RETURN_ON_ERROR( + CheckSameBaselineType(CheckCtx("maximum"), lhs, output)); + SHLO_REF_RETURN_ON_ERROR( + CheckSameBaselineType(CheckCtx("maximum"), rhs, output)); + return absl::OkStatus(); +} + +absl::Status Evaluate(MaximumOp& op, const Tensor& lhs, const Tensor& rhs, + Tensor& output) { + Maximum maximum; + if (IsBoolTensor(lhs) || IsIntTensor(lhs) || IsFloatTensor(lhs)) { + // Note: all the arithmetic types share the same implementation. + DISPATCH_BOOL_INT_FLOAT(detail::EvaluateNoQuantization, + lhs.tensor_element_type(), maximum, lhs, rhs, + output); + } else if (IsQuantizedPerTensorTensor(lhs)) { + DISPATCH_QUANTIZED(detail::DequantizeOpQuantizePerTensor, + lhs.quantized_tensor_element_type().StorageType(), + lhs.quantized_tensor_element_type().ExpressedType(), + maximum, lhs, rhs, output) + } + return absl::FailedPreconditionError( + "stablehlo.maximum: Unsupported tensor type."); +} + +} // namespace shlo_ref diff --git a/tensorflow/lite/experimental/shlo/ops/maximum.h b/tensorflow/lite/experimental/shlo/ops/maximum.h new file mode 100644 index 00000000000000..1ce0be360542f9 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/ops/maximum.h @@ -0,0 +1,36 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_OPS_MAXIMUM_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_OPS_MAXIMUM_H_ + +#include "absl/status/status.h" +#include "tensorflow/lite/experimental/shlo/tensor.h" + +namespace shlo_ref { + +struct MaximumOp { + struct Attributes {}; +}; + +MaximumOp Create(MaximumOp::Attributes); +absl::Status Prepare(MaximumOp& op, const Tensor& lhs, const Tensor& rhs, + Tensor& output); +absl::Status Evaluate(MaximumOp& op, const Tensor& lhs, const Tensor& rhs, + Tensor& output); + +} // namespace shlo_ref + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_OPS_MAXIMUM_H_ diff --git a/tensorflow/lite/experimental/shlo/ops/maximum_test.cc b/tensorflow/lite/experimental/shlo/ops/maximum_test.cc new file mode 100644 index 00000000000000..0422331324daf9 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/ops/maximum_test.cc @@ -0,0 +1,151 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/shlo/ops/maximum.h" + +#include + +#include +#include +#include "tensorflow/lite/experimental/shlo/ops/binary_elementwise_test_util.h" +#include "tensorflow/lite/experimental/shlo/ops/test_util.h" +#include "tensorflow/lite/experimental/shlo/quantize.h" +#include "tensorflow/lite/experimental/shlo/quantized_tensor_element_type.h" +#include "tensorflow/lite/experimental/shlo/shape.h" +#include "tensorflow/lite/experimental/shlo/status_matcher.h" +#include "tensorflow/lite/experimental/shlo/tensor.h" + +using testing::FloatEq; +using testing::Pointwise; + +namespace shlo_ref { + +template <> +struct ParamName { + static std::string Get() { return "Maximum"; } +}; + +struct Maximum { + template + constexpr auto operator()(const T a, const T b) { + return a > b ? a : b; + } +}; + +namespace { + +INSTANTIATE_TYPED_TEST_SUITE_P(Maximum, BinaryElementwiseOpShapePropagationTest, + MaximumOp, TestParamNames); + +using MaximumBaselineContraintTypes = BinaryElementwiseBaselineConstraintTypes< + MaximumOp, ConcatTypes>; + +INSTANTIATE_TYPED_TEST_SUITE_P( + Maximum, BinaryElementwiseSameBaselineElementTypeConstraintTest, + MaximumBaselineContraintTypes, TestParamNames); + +using UnsupportedTypes = + WithOpTypes>; + +INSTANTIATE_TYPED_TEST_SUITE_P(Maximum, BinaryElementwiseUnsupportedTypeTest, + UnsupportedTypes, TestParamNames); + +using SupportedTypes = ConcatTypes; + +template +struct MaximumTest : ::testing::Test {}; + +TYPED_TEST_SUITE(MaximumTest, SupportedTypes, TestParamNames); + +TYPED_TEST(MaximumTest, ArithmeticTestTypesTensorsWork) { + using StorageT = typename TypeParam::StorageT; + + const Shape shape({2, 3, 4}); + Vector lhs_data = + RandomBuffer(shape, /*min=*/-50, /*max=*/50); + Vector rhs_data = + RandomBuffer(shape, /*min=*/1, /*max=*/5); + Vector output_data(shape.NumElements()); + Tensor lhs_tensor{ + .type = TensorType{.shape = shape, .element_type = TypeParam::kStorage}, + .data = lhs_data.data()}; + Tensor rhs_tensor{ + .type = TensorType{.shape = shape, .element_type = TypeParam::kStorage}, + .data = rhs_data.data()}; + Tensor output_tensor{ + .type = TensorType{.shape = shape, .element_type = TypeParam::kStorage}, + .data = output_data.data()}; + + Vector expected_data(shape.NumElements()); + absl::c_transform(lhs_data, rhs_data, expected_data.begin(), Maximum()); + + auto op = Create(MaximumOp::Attributes{}); + ASSERT_OK(Prepare(op, lhs_tensor, rhs_tensor, output_tensor)); + ASSERT_OK(Evaluate(op, lhs_tensor, rhs_tensor, output_tensor)); + EXPECT_THAT(output_data, Pointwise(FloatEq(), expected_data)); +} + +template +struct QuantizedMaximumTest : ::testing::Test {}; + +TYPED_TEST_SUITE(QuantizedMaximumTest, QuantizedTestTypes, TestParamNames); + +TYPED_TEST(QuantizedMaximumTest, PerTensorWorks) { + using StorageT = typename TypeParam::StorageT; + using ExpressedT = typename TypeParam::ExpressedT; + + const Shape shape({2, 3, 4}); + const ExpressedT scale = static_cast(1.5); + const StorageT zero_point = static_cast(2); + Vector lhs_data = + RandomBuffer(shape, /*min=*/-50, /*max=*/50); + Vector rhs_data = RandomBuffer( + shape, /*min=*/zero_point + 1, /*max=*/zero_point + 5); + Vector output_data(shape.NumElements()); + const QuantizedTensorElementType tensor_type = + QuantizedTensorElementType::PerTensor(scale, + zero_point); + Tensor lhs_tensor{ + .type = QuantizedTensorType{.shape = shape, .element_type = tensor_type}, + .data = lhs_data.data()}; + Tensor rhs_tensor{ + .type = QuantizedTensorType{.shape = shape, .element_type = tensor_type}, + .data = rhs_data.data()}; + Tensor output_tensor{ + .type = QuantizedTensorType{.shape = shape, .element_type = tensor_type}, + .data = output_data.data()}; + + Vector expected_data(shape.NumElements()); + absl::c_transform( + lhs_data, rhs_data, expected_data.begin(), + [zero_point, scale](auto lhs, auto rhs) { + const ExpressedT dequantized_lhs = Dequantize(lhs, zero_point, scale); + const ExpressedT dequantized_rhs = Dequantize(rhs, zero_point, scale); + const ExpressedT dequantized_res = + Maximum()(dequantized_lhs, dequantized_rhs); + return Quantize( + dequantized_res, zero_point, static_cast(1.) / scale); + }); + + auto op = Create(MaximumOp::Attributes{}); + ASSERT_OK(Prepare(op, lhs_tensor, rhs_tensor, output_tensor)); + ASSERT_OK(Evaluate(op, lhs_tensor, rhs_tensor, output_tensor)); + EXPECT_THAT(output_data, Pointwise(FloatEq(), expected_data)); +} +} // namespace +} // namespace shlo_ref From c063f06f7fb3d87858d34d2811dc978276f3c987 Mon Sep 17 00:00:00 2001 From: Quentin Khan Date: Thu, 28 Mar 2024 08:58:37 -0700 Subject: [PATCH 028/124] #shlo_ref Add `minimum` op. PiperOrigin-RevId: 619948992 --- tensorflow/lite/experimental/shlo/ops/BUILD | 29 ++++ .../lite/experimental/shlo/ops/minimum.cc | 66 ++++++++ .../lite/experimental/shlo/ops/minimum.h | 36 +++++ .../experimental/shlo/ops/minimum_test.cc | 151 ++++++++++++++++++ 4 files changed, 282 insertions(+) create mode 100644 tensorflow/lite/experimental/shlo/ops/minimum.cc create mode 100644 tensorflow/lite/experimental/shlo/ops/minimum.h create mode 100644 tensorflow/lite/experimental/shlo/ops/minimum_test.cc diff --git a/tensorflow/lite/experimental/shlo/ops/BUILD b/tensorflow/lite/experimental/shlo/ops/BUILD index 8624708052c3d2..f839714f3ecaf2 100644 --- a/tensorflow/lite/experimental/shlo/ops/BUILD +++ b/tensorflow/lite/experimental/shlo/ops/BUILD @@ -992,3 +992,32 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "minimum", + srcs = ["minimum.cc"], + hdrs = ["minimum.h"], + deps = [ + ":binary_elementwise", + ":util", + "//tensorflow/lite/experimental/shlo:dispatch", + "//tensorflow/lite/experimental/shlo:tensor", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "minimum_test", + srcs = ["minimum_test.cc"], + deps = [ + ":binary_elementwise_test_util", + ":minimum", + ":test_util", + "//tensorflow/lite/experimental/shlo:quantize", + "//tensorflow/lite/experimental/shlo:quantized_tensor_element_type", + "//tensorflow/lite/experimental/shlo:shape", + "//tensorflow/lite/experimental/shlo:status_matcher", + "//tensorflow/lite/experimental/shlo:tensor", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/experimental/shlo/ops/minimum.cc b/tensorflow/lite/experimental/shlo/ops/minimum.cc new file mode 100644 index 00000000000000..c583d7afb6b147 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/ops/minimum.cc @@ -0,0 +1,66 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/shlo/ops/minimum.h" + +#include "absl/status/status.h" +#include "tensorflow/lite/experimental/shlo/dispatch.h" +#include "tensorflow/lite/experimental/shlo/ops/binary_elementwise.h" +#include "tensorflow/lite/experimental/shlo/ops/util.h" +#include "tensorflow/lite/experimental/shlo/tensor.h" + +namespace shlo_ref { + +struct Minimum { + template + constexpr auto operator()(const T a, const T b) { + return a < b ? a : b; + } +}; + +MinimumOp Create(MinimumOp::Attributes) { return {}; } + +absl::Status Prepare(MinimumOp& op, const Tensor& lhs, const Tensor& rhs, + Tensor& output) { + SHLO_REF_RETURN_ON_ERROR(Propagate(lhs.shape(), rhs.shape(), output.shape())); + SHLO_REF_RETURN_ON_ERROR( + CheckSupportedTypes(CheckCtx("minimum"), lhs, IsBoolTensor, IsIntTensor, + IsFloatTensor, IsQuantizedPerTensorTensor)); + SHLO_REF_RETURN_ON_ERROR( + CheckSameBaselineType(CheckCtx("minimum"), lhs, output)); + SHLO_REF_RETURN_ON_ERROR( + CheckSameBaselineType(CheckCtx("minimum"), rhs, output)); + return absl::OkStatus(); +} + +absl::Status Evaluate(MinimumOp& op, const Tensor& lhs, const Tensor& rhs, + Tensor& output) { + Minimum minimum; + if (IsBoolTensor(lhs) || IsIntTensor(lhs) || IsFloatTensor(lhs)) { + // Note: all the arithmetic types share the same implementation. + DISPATCH_BOOL_INT_FLOAT(detail::EvaluateNoQuantization, + lhs.tensor_element_type(), minimum, lhs, rhs, + output); + } else if (IsQuantizedPerTensorTensor(lhs)) { + DISPATCH_QUANTIZED(detail::DequantizeOpQuantizePerTensor, + lhs.quantized_tensor_element_type().StorageType(), + lhs.quantized_tensor_element_type().ExpressedType(), + minimum, lhs, rhs, output) + } + return absl::FailedPreconditionError( + "stablehlo.minimum: Unsupported tensor type."); +} + +} // namespace shlo_ref diff --git a/tensorflow/lite/experimental/shlo/ops/minimum.h b/tensorflow/lite/experimental/shlo/ops/minimum.h new file mode 100644 index 00000000000000..5fc2205566de9c --- /dev/null +++ b/tensorflow/lite/experimental/shlo/ops/minimum.h @@ -0,0 +1,36 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SHLO_OPS_MINIMUM_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SHLO_OPS_MINIMUM_H_ + +#include "absl/status/status.h" +#include "tensorflow/lite/experimental/shlo/tensor.h" + +namespace shlo_ref { + +struct MinimumOp { + struct Attributes {}; +}; + +MinimumOp Create(MinimumOp::Attributes); +absl::Status Prepare(MinimumOp& op, const Tensor& lhs, const Tensor& rhs, + Tensor& output); +absl::Status Evaluate(MinimumOp& op, const Tensor& lhs, const Tensor& rhs, + Tensor& output); + +} // namespace shlo_ref + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SHLO_OPS_MINIMUM_H_ diff --git a/tensorflow/lite/experimental/shlo/ops/minimum_test.cc b/tensorflow/lite/experimental/shlo/ops/minimum_test.cc new file mode 100644 index 00000000000000..586c096f1cbab0 --- /dev/null +++ b/tensorflow/lite/experimental/shlo/ops/minimum_test.cc @@ -0,0 +1,151 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/shlo/ops/minimum.h" + +#include + +#include +#include +#include "tensorflow/lite/experimental/shlo/ops/binary_elementwise_test_util.h" +#include "tensorflow/lite/experimental/shlo/ops/test_util.h" +#include "tensorflow/lite/experimental/shlo/quantize.h" +#include "tensorflow/lite/experimental/shlo/quantized_tensor_element_type.h" +#include "tensorflow/lite/experimental/shlo/shape.h" +#include "tensorflow/lite/experimental/shlo/status_matcher.h" +#include "tensorflow/lite/experimental/shlo/tensor.h" + +using testing::FloatEq; +using testing::Pointwise; + +namespace shlo_ref { + +template <> +struct ParamName { + static std::string Get() { return "Minimum"; } +}; + +struct Minimum { + template + constexpr auto operator()(const T a, const T b) { + return a < b ? a : b; + } +}; + +namespace { + +INSTANTIATE_TYPED_TEST_SUITE_P(Minimum, BinaryElementwiseOpShapePropagationTest, + MinimumOp, TestParamNames); + +using MinimumBaselineContraintTypes = BinaryElementwiseBaselineConstraintTypes< + MinimumOp, ConcatTypes>; + +INSTANTIATE_TYPED_TEST_SUITE_P( + Minimum, BinaryElementwiseSameBaselineElementTypeConstraintTest, + MinimumBaselineContraintTypes, TestParamNames); + +using UnsupportedTypes = + WithOpTypes>; + +INSTANTIATE_TYPED_TEST_SUITE_P(Minimum, BinaryElementwiseUnsupportedTypeTest, + UnsupportedTypes, TestParamNames); + +using SupportedTypes = ConcatTypes; + +template +struct MinimumTest : ::testing::Test {}; + +TYPED_TEST_SUITE(MinimumTest, SupportedTypes, TestParamNames); + +TYPED_TEST(MinimumTest, ArithmeticTestTypesTensorsWork) { + using StorageT = typename TypeParam::StorageT; + + const Shape shape({2, 3, 4}); + Vector lhs_data = + RandomBuffer(shape, /*min=*/-50, /*max=*/50); + Vector rhs_data = + RandomBuffer(shape, /*min=*/1, /*max=*/5); + Vector output_data(shape.NumElements()); + Tensor lhs_tensor{ + .type = TensorType{.shape = shape, .element_type = TypeParam::kStorage}, + .data = lhs_data.data()}; + Tensor rhs_tensor{ + .type = TensorType{.shape = shape, .element_type = TypeParam::kStorage}, + .data = rhs_data.data()}; + Tensor output_tensor{ + .type = TensorType{.shape = shape, .element_type = TypeParam::kStorage}, + .data = output_data.data()}; + + Vector expected_data(shape.NumElements()); + absl::c_transform(lhs_data, rhs_data, expected_data.begin(), Minimum()); + + auto op = Create(MinimumOp::Attributes{}); + ASSERT_OK(Prepare(op, lhs_tensor, rhs_tensor, output_tensor)); + ASSERT_OK(Evaluate(op, lhs_tensor, rhs_tensor, output_tensor)); + EXPECT_THAT(output_data, Pointwise(FloatEq(), expected_data)); +} + +template +struct QuantizedMinimumTest : ::testing::Test {}; + +TYPED_TEST_SUITE(QuantizedMinimumTest, QuantizedTestTypes, TestParamNames); + +TYPED_TEST(QuantizedMinimumTest, PerTensorWorks) { + using StorageT = typename TypeParam::StorageT; + using ExpressedT = typename TypeParam::ExpressedT; + + const Shape shape({2, 3, 4}); + const ExpressedT scale = static_cast(1.5); + const StorageT zero_point = static_cast(2); + Vector lhs_data = + RandomBuffer(shape, /*min=*/-50, /*max=*/50); + Vector rhs_data = RandomBuffer( + shape, /*min=*/zero_point + 1, /*max=*/zero_point + 5); + Vector output_data(shape.NumElements()); + const QuantizedTensorElementType tensor_type = + QuantizedTensorElementType::PerTensor(scale, + zero_point); + Tensor lhs_tensor{ + .type = QuantizedTensorType{.shape = shape, .element_type = tensor_type}, + .data = lhs_data.data()}; + Tensor rhs_tensor{ + .type = QuantizedTensorType{.shape = shape, .element_type = tensor_type}, + .data = rhs_data.data()}; + Tensor output_tensor{ + .type = QuantizedTensorType{.shape = shape, .element_type = tensor_type}, + .data = output_data.data()}; + + Vector expected_data(shape.NumElements()); + absl::c_transform( + lhs_data, rhs_data, expected_data.begin(), + [zero_point, scale](auto lhs, auto rhs) { + const ExpressedT dequantized_lhs = Dequantize(lhs, zero_point, scale); + const ExpressedT dequantized_rhs = Dequantize(rhs, zero_point, scale); + const ExpressedT dequantized_res = + Minimum()(dequantized_lhs, dequantized_rhs); + return Quantize( + dequantized_res, zero_point, static_cast(1.) / scale); + }); + + auto op = Create(MinimumOp::Attributes{}); + ASSERT_OK(Prepare(op, lhs_tensor, rhs_tensor, output_tensor)); + ASSERT_OK(Evaluate(op, lhs_tensor, rhs_tensor, output_tensor)); + EXPECT_THAT(output_data, Pointwise(FloatEq(), expected_data)); +} +} // namespace +} // namespace shlo_ref From 48bb29f982ccab0d800043fb83f7118f8173a9fb Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Thu, 28 Mar 2024 09:06:31 -0700 Subject: [PATCH 029/124] [xla:gpu] Address computation thunk should not be inside command buffer Address computation is not compatible with command buffer since it requires a host sync. PiperOrigin-RevId: 619951741 --- third_party/xla/xla/service/gpu/command_buffer_scheduling.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc index 97e28bdd356761..a9ed89cbf91ea8 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc +++ b/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc @@ -167,6 +167,9 @@ static bool IsCommand(const HloInstruction* hlo, &custom_call_adaptor->instruction()); return IsCommand(custom_call, config); } + if (custom_config.name() == "dynamic_address_computation") { + return false; + } return config.enabled_commands.contains(DebugOptions::FUSION); } From 2953811e87b70627c70647aea23ed6e986f1c647 Mon Sep 17 00:00:00 2001 From: Kanvi Khanna Date: Thu, 28 Mar 2024 09:14:06 -0700 Subject: [PATCH 030/124] PR #10344: [XLA:CPU] Enable BMM+Mul+Add fusion Imported from GitHub PR https://github.com/openxla/xla/pull/10344 This PR enables the BatchMatMul + Mul + Add fusion and adds a simple test Copybara import of the project: -- 2a202c9d171064bb3005e0753af41e26f2e1baf3 by Kanvi Khanna : Enable BMM+Mul+Add fusion -- 584e59861f5f91fae8222b9281e501dbeb270b94 by Kanvi Khanna : Address review comments Merging this change closes #10344 PiperOrigin-RevId: 619954276 --- .../xla/service/cpu/onednn_matmul_rewriter.cc | 4 +-- .../xla/xla/tests/onednn_matmul_test.cc | 32 +++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc index 0bbd677325150b..d73975b3eac654 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc @@ -434,7 +434,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { OptionalConvertAndBitcast(&optional_dot_convert, &optional_dot_bitcast, OneDnnMatmulInstr(&dot)) .WithOneUser(), - m::Op(&addend_intermediate).WithOneUser()); + m::Op(&addend_intermediate)); if (Match(instr, pattern)) { if (!IsSupportedType(dot->shape().element_type())) return OkStatus(); @@ -587,7 +587,7 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { .WithOneUser() .WithOpcode(HloOpcode::kCustomCall) .WithCustomCallTarget({"__onednn$matmul"}), - m::Broadcast(m::Constant(&constant)).WithOneUser()); + m::Broadcast(m::Constant(&constant))); if (Match(instr, pattern)) { std::vector new_operands; diff --git a/third_party/xla/xla/tests/onednn_matmul_test.cc b/third_party/xla/xla/tests/onednn_matmul_test.cc index d8488bc2ca4fb6..8b100b6141bee2 100644 --- a/third_party/xla/xla/tests/onednn_matmul_test.cc +++ b/third_party/xla/xla/tests/onednn_matmul_test.cc @@ -618,6 +618,38 @@ TEST_F(MatmulTest, TestTransposeBNoRewriteF32) { )"); } +TEST_F(MatmulTest, SimpleTestF32WithMulAndAddFusion) { + const char* matmul_module_str = R"( + ENTRY matmul.mul.add.test.f32 { + arg0.1 = f32[32,32,40,30] parameter(0), parameter_replication={false} + arg0.2 = f32[32,32,30,40] parameter(1), parameter_replication={false} + dot.7 = f32[32,32,40,40] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + const.0 = f32[] constant(0.044715) + bcast.0 = f32[32,32,40,40] broadcast(const.0), dimensions={} + mul.0 = f32[32,32,40,40] multiply(dot.7,bcast.0) + const.1 = f32[] constant(0.65) + bcast.1 = f32[32,32,40,40] broadcast(const.1), dimensions={} + add.0 = f32[32,32,40,40] add(mul.0, bcast.1) + const.2 = f32[] constant(0.65) + bcast.2 = f32[32,32,40,40] broadcast(const.2), dimensions={} + add.1 = f32[32,32,40,40] add(bcast.2, bcast.1) + tuple.12 = (f32[32,32,40,40]) tuple(add.0) + ROOT get-tuple-element.13 = f32[32,32,40,40] get-tuple-element(tuple.12), index=0 + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); + MatchOptimizedHlo(matmul_module_str, + R"( + ; CHECK: custom_call_target="__onednn$matmul", + ; CHECK: backend_config={ + ; CHECK-DAG: "outer_dimension_partitions":[], + ; CHECK-DAG: "onednn_matmul_config":{ + ; CHECK-DAG: "fused_ops":["LINEAR","BINARY_ADD"] + ; CHECK-DAG: } + ; CHECK: } + )"); +} + } // namespace cpu } // namespace xla From 0722946a072cf0542134235e1c29b1de7826311f Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Thu, 28 Mar 2024 09:16:17 -0700 Subject: [PATCH 031/124] [xla:gpu][NFC] Refactor and rename in address_computation_fusion_rewriter for clarity PiperOrigin-RevId: 619954961 --- .../address_computation_fusion_rewriter.cc | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc index 6be6b4cc0be087..89a47d7cd9cd05 100644 --- a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc @@ -314,7 +314,7 @@ Status CreateRootTuple(HloInstruction* hero, HloComputation::Builder& builder, } absl::StatusOr CreateFusionBody( - HloModule* module, absl::Span operand_matches, + HloModule* module, absl::Span sliced_operand_paths, DefUseDataflowPaths sliced_user_paths, absl::Span captures) { HloComputation::Builder builder("address-computation"); @@ -342,7 +342,7 @@ absl::StatusOr CreateFusionBody( // Instructions in the pattern are already topologically sorted, as we visited // them following use-def path, then reverse the list. HloInstruction* hero; - for (HloInstruction* instr : operand_matches) { + for (HloInstruction* instr : sliced_operand_paths) { instr_mapping[instr] = builder.AddInstruction( instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr))); hero = instr; @@ -441,36 +441,36 @@ absl::StatusOr AddressComputationFusionRewriter::Run( if (matches.empty()) return false; HloSchedule& schedule = module->schedule(); - for (auto& kv : matches) { - auto& [operand_matches, sliced_user_paths] = kv.second; - std::vector matches; - absl::c_copy(operand_matches, std::back_inserter(matches)); + for (auto& [hero, paths] : matches) { + auto& [sliced_operand_paths, sliced_user_paths] = paths; + std::vector matched_instrs; + absl::c_copy(sliced_operand_paths, std::back_inserter(matched_instrs)); for (auto& sliced_user_path : sliced_user_paths) - absl::c_copy(sliced_user_path, std::back_inserter(matches)); + absl::c_copy(sliced_user_path, std::back_inserter(matched_instrs)); - auto captures = GetPatternCaptures(matches); + auto captures = GetPatternCaptures(matched_instrs); TF_ASSIGN_OR_RETURN(HloComputation * fusion_body, - CreateFusionBody(module, operand_matches, + CreateFusionBody(module, sliced_operand_paths, sliced_user_paths, captures)); TF_ASSIGN_OR_RETURN(HloInstruction * fusion, - CreateFusionInstruction(module, kv.first, captures, + CreateFusionInstruction(module, hero, captures, fusion_body, dynamic)); // As we are running after scheduling we have to keep it valid. - HloComputation* parent = kv.first->parent(); + HloComputation* parent = hero->parent(); // Update schedule to replace the custom call instruction with the fusion // instruction. // Removal of the rest of the instructions in the sequence is handled by // schedule update below. HloInstructionSequence& sequence = schedule.GetOrCreateSequence(parent); - sequence.replace_instruction(kv.first, fusion); + sequence.replace_instruction(hero, fusion); if (fusion->shape().IsTuple()) { TF_RETURN_IF_ERROR(parent->ReplaceInstructionWithDifferentShape( - const_cast(kv.first), fusion)); + const_cast(hero), fusion)); for (auto& sliced_user_path : sliced_user_paths) { auto old_gte = Cast(sliced_user_path.front()); @@ -481,28 +481,29 @@ absl::StatusOr AddressComputationFusionRewriter::Run( parent->ReplaceInstruction(sliced_user_path.back(), gte)); } } else { - auto* old_instr = const_cast(kv.first); + auto* instr_to_be_replaced = const_cast(hero); if (sliced_user_paths.empty()) { // The only case where a tuple-shaped original hero op is fused into a // non-tuple-shaped fusion is there's only one element of the original // tuple being used. In that case, we need to replace that single // get-tuple-element (instead of the hero op) with the fusion // instruction. - if (kv.first->shape().IsTuple()) { - if (kv.first->user_count() != 1 || + if (hero->shape().IsTuple()) { + if (hero->user_count() != 1 || !DynCast( - kv.first->users().front())) { + hero->users().front())) { return absl::InternalError( "Expect a single get-tuple-element user of the original " "tuple-shaped hero op when address computation fusion does " "not return a tuple"); } - old_instr = kv.first->users().front(); + instr_to_be_replaced = hero->users().front(); } } else { - old_instr = sliced_user_paths.front().back(); + instr_to_be_replaced = sliced_user_paths.front().back(); } - TF_RETURN_IF_ERROR(parent->ReplaceInstruction(old_instr, fusion)); + TF_RETURN_IF_ERROR( + parent->ReplaceInstruction(instr_to_be_replaced, fusion)); } } From 3a2cd8887ed96de6abbd26e46844b959aa42e481 Mon Sep 17 00:00:00 2001 From: Marcello Maggioni Date: Thu, 28 Mar 2024 09:26:08 -0700 Subject: [PATCH 032/124] [XLA] Add functionality to ReduceScatterDecomposer to be selective PiperOrigin-RevId: 619957874 --- .../xla/xla/service/collective_opt_utils.cc | 34 ++++++++++++++++++- .../xla/xla/service/collective_opt_utils.h | 2 +- .../xla/service/reduce_scatter_decomposer.cc | 4 +++ .../xla/service/reduce_scatter_decomposer.h | 6 ++-- .../service/reduce_scatter_decomposer_test.cc | 32 +++++++++++++++-- 5 files changed, 71 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/service/collective_opt_utils.cc b/third_party/xla/xla/service/collective_opt_utils.cc index 8e7a6d874cfa8a..183173801077c5 100644 --- a/third_party/xla/xla/service/collective_opt_utils.cc +++ b/third_party/xla/xla/service/collective_opt_utils.cc @@ -267,13 +267,45 @@ bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, return true; } +ReduceScatterSpec SpecFromReduceScatterInstr(const HloInstruction* rs_instr, + int64_t num_partitions, + int64_t num_replicas, + bool is_constrain_layout, + bool use_global_device_ids, + bool is_cross_module) { + CHECK(rs_instr->opcode() == HloOpcode::kReduceScatter); + ReduceScatterSpec spec; + spec.split_dim = rs_instr->dimensions(0); + if (!is_cross_module) { + spec.sharded_replicas = num_replicas; + spec.group_size = rs_instr->replica_groups().empty() + ? num_replicas + : rs_instr->replica_groups()[0].replica_ids_size(); + } else if (use_global_device_ids) { + spec.sharded_replicas = num_replicas; + spec.sharded_partitions = num_partitions; + spec.group_size = rs_instr->replica_groups()[0].replica_ids_size(); + } else { + spec.sharded_partitions = num_partitions; + spec.group_size = num_partitions; + } + spec.original_split_dims = {spec.split_dim}; + spec.dynamic_slice = nullptr; + return spec; +} + } // namespace std::optional MatchReduceScatter( - const HloAllReduceInstruction* ar, int64_t num_partitions, + const HloAllReduceInstructionBase* ar, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims, bool allow_intervening_reshape, int64_t min_rank, HloPredicate match_partition_id, HloPredicate match_replica_id) { + if (ar->opcode() == HloOpcode::kReduceScatter) { + return SpecFromReduceScatterInstr( + ar, num_partitions, num_replicas, ar->constrain_layout(), + ar->use_global_device_ids(), ar->channel_id().has_value()); + } auto spec = MatchWithDynamicSlice( ar, num_partitions, num_replicas, allow_multiple_split_dims, allow_intervening_reshape, min_rank, match_partition_id, match_replica_id, diff --git a/third_party/xla/xla/service/collective_opt_utils.h b/third_party/xla/xla/service/collective_opt_utils.h index 11b65c1acc4160..7d044be3c34568 100644 --- a/third_party/xla/xla/service/collective_opt_utils.h +++ b/third_party/xla/xla/service/collective_opt_utils.h @@ -36,7 +36,7 @@ struct ReduceScatterSpec { // Matches the given all-reduce operation to a reduce-scatter pattern. std::optional MatchReduceScatter( - const HloAllReduceInstruction* ar, int64_t num_partitions, + const HloAllReduceInstructionBase* ar, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims = false, bool allow_intervening_reshape = false, int64_t min_rank = 1, HloPredicate match_partition_id = HloPredicateIsOp, diff --git a/third_party/xla/xla/service/reduce_scatter_decomposer.cc b/third_party/xla/xla/service/reduce_scatter_decomposer.cc index 7210a2c12b4f30..da2fed224a53f5 100644 --- a/third_party/xla/xla/service/reduce_scatter_decomposer.cc +++ b/third_party/xla/xla/service/reduce_scatter_decomposer.cc @@ -53,7 +53,11 @@ absl::StatusOr ReduceScatterDecomposer::Run( if (rs->channel_id()) { channel_id = next_channel_id++; } + if (should_decompose_ && !should_decompose_(rs)) { + continue; + } + VLOG(2) << "Decompose: " << rs->ToString(); // Create an all-reduce HloComputation *apply_clone = module->AddComputationAndUnifyNamesAndIds( rs->to_apply()->Clone(), /*is_entry=*/false); diff --git a/third_party/xla/xla/service/reduce_scatter_decomposer.h b/third_party/xla/xla/service/reduce_scatter_decomposer.h index 324d97d0e915e9..1ee1f603c09f28 100644 --- a/third_party/xla/xla/service/reduce_scatter_decomposer.h +++ b/third_party/xla/xla/service/reduce_scatter_decomposer.h @@ -29,8 +29,9 @@ namespace xla { class ReduceScatterDecomposer : public HloModulePass { public: explicit ReduceScatterDecomposer( - std::function update_layout = nullptr) - : update_layout_(update_layout) {} + std::function update_layout = nullptr, + std::function should_decompose = nullptr) + : update_layout_(update_layout), should_decompose_(should_decompose) {} absl::string_view name() const override { return "reduce-scatter-decomposer"; } @@ -40,6 +41,7 @@ class ReduceScatterDecomposer : public HloModulePass { HloModule* module, const absl::flat_hash_set& execution_threads) override; std::function update_layout_; + std::function should_decompose_; }; } // namespace xla diff --git a/third_party/xla/xla/service/reduce_scatter_decomposer_test.cc b/third_party/xla/xla/service/reduce_scatter_decomposer_test.cc index bfaa918930befb..d7f8360fbdc910 100644 --- a/third_party/xla/xla/service/reduce_scatter_decomposer_test.cc +++ b/third_party/xla/xla/service/reduce_scatter_decomposer_test.cc @@ -41,13 +41,18 @@ class ReduceScatterDecomposerTest : public HloTestBase { absl::string_view hlo_module, PassAction action, CollectiveOpGroupMode mode = CollectiveOpGroupMode::kCrossReplica, int64_t shard_size = 0, int64_t shard_dimension = 0, - int64_t replica_count = 2) { + int64_t replica_count = 2, + std::function should_decompose = + [](const HloInstruction *) { return true; }) { const int64_t partition_count = 2; TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndReturnVerifiedModule(hlo_module, replica_count, partition_count)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, - ReduceScatterDecomposer().Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ReduceScatterDecomposer(/*update_layout=*/nullptr, + /*should_decompose=*/should_decompose) + .Run(module.get())); if (action == PassAction::kNoChange) { ASSERT_FALSE(changed); return; @@ -222,5 +227,26 @@ ENTRY main { RunPass(hlo_string, PassAction::kNoChange); } +TEST_F(ReduceScatterDecomposerTest, NoChangeWithShouldDecompose) { + absl::string_view hlo_string = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + +ENTRY main { + p0 = f32[4, 8] parameter(0) + ROOT rs = f32[4, 4] reduce-scatter(p0), replica_groups={{0,1}, {2,3}}, channel_id=1, dimensions={1}, to_apply=sum, use_global_device_ids=true +} +)"; + RunPass(hlo_string, PassAction::kNoChange, + CollectiveOpGroupMode::kCrossReplica, + /*shard_size=*/0, /*shard_dimension=*/0, + /*replica_count=*/2, [](const HloInstruction *) { return false; }); +} + } // namespace } // namespace xla From 6bbc1688e7c99a5067c12bdd2352f7509bb3b1cd Mon Sep 17 00:00:00 2001 From: Vladyslav Tsilytskyi Date: Thu, 28 Mar 2024 09:45:23 -0700 Subject: [PATCH 033/124] [stream_executor:host] Rename host_gpu_executor PiperOrigin-RevId: 619965098 --- third_party/xla/xla/stream_executor/host/BUILD | 9 ++++----- .../host/{host_gpu_executor.cc => host_executor.cc} | 2 +- .../host/{host_gpu_executor.h => host_executor.h} | 6 +++--- .../xla/xla/stream_executor/host/host_platform.cc | 2 +- 4 files changed, 9 insertions(+), 10 deletions(-) rename third_party/xla/xla/stream_executor/host/{host_gpu_executor.cc => host_executor.cc} (99%) rename third_party/xla/xla/stream_executor/host/{host_gpu_executor.h => host_executor.h} (97%) diff --git a/third_party/xla/xla/stream_executor/host/BUILD b/third_party/xla/xla/stream_executor/host/BUILD index a29fa215b5736c..04edca1c9e225a 100644 --- a/third_party/xla/xla/stream_executor/host/BUILD +++ b/third_party/xla/xla/stream_executor/host/BUILD @@ -40,7 +40,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - ":host_gpu_executor", + ":host_executor", ":host_platform_id", "//xla/stream_executor", "//xla/stream_executor:platform_manager", @@ -110,14 +110,13 @@ xla_cc_test( ], ) -# TODO(22689637): Rename this target. cc_library( - name = "host_gpu_executor", + name = "host_executor", srcs = [ - "host_gpu_executor.cc", + "host_executor.cc", ], hdrs = [ - "host_gpu_executor.h", + "host_executor.h", ], deps = [ ":host_stream", diff --git a/third_party/xla/xla/stream_executor/host/host_gpu_executor.cc b/third_party/xla/xla/stream_executor/host/host_executor.cc similarity index 99% rename from third_party/xla/xla/stream_executor/host/host_gpu_executor.cc rename to third_party/xla/xla/stream_executor/host/host_executor.cc index 13f8ec2e0bca9c..33d54a16f3fee1 100644 --- a/third_party/xla/xla/stream_executor/host/host_gpu_executor.cc +++ b/third_party/xla/xla/stream_executor/host/host_executor.cc @@ -15,7 +15,7 @@ limitations under the License. // Implementation of HostExecutor class [of those methods not defined in the // class declaration]. -#include "xla/stream_executor/host/host_gpu_executor.h" +#include "xla/stream_executor/host/host_executor.h" #include #include diff --git a/third_party/xla/xla/stream_executor/host/host_gpu_executor.h b/third_party/xla/xla/stream_executor/host/host_executor.h similarity index 97% rename from third_party/xla/xla/stream_executor/host/host_gpu_executor.h rename to third_party/xla/xla/stream_executor/host/host_executor.h index 0c86d7080c8755..6123e227591fa8 100644 --- a/third_party/xla/xla/stream_executor/host/host_gpu_executor.h +++ b/third_party/xla/xla/stream_executor/host/host_executor.h @@ -16,8 +16,8 @@ limitations under the License. // Declares the HostExecutor class, which is a CPU-only implementation of // the StreamExecutor interface. For now, this is used for testing and to // examine the performance of host-based StreamExecutor code. -#ifndef XLA_STREAM_EXECUTOR_HOST_HOST_GPU_EXECUTOR_H_ -#define XLA_STREAM_EXECUTOR_HOST_HOST_GPU_EXECUTOR_H_ +#ifndef XLA_STREAM_EXECUTOR_HOST_HOST_EXECUTOR_H_ +#define XLA_STREAM_EXECUTOR_HOST_HOST_EXECUTOR_H_ #include #include @@ -147,4 +147,4 @@ class HostExecutor : public internal::StreamExecutorInterface { } // namespace host } // namespace stream_executor -#endif // XLA_STREAM_EXECUTOR_HOST_HOST_GPU_EXECUTOR_H_ +#endif // XLA_STREAM_EXECUTOR_HOST_HOST_EXECUTOR_H_ diff --git a/third_party/xla/xla/stream_executor/host/host_platform.cc b/third_party/xla/xla/stream_executor/host/host_platform.cc index 58771670d0b6a1..23112fbecd51aa 100644 --- a/third_party/xla/xla/stream_executor/host/host_platform.cc +++ b/third_party/xla/xla/stream_executor/host/host_platform.cc @@ -24,7 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/host/host_gpu_executor.h" +#include "xla/stream_executor/host/host_executor.h" #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/initialize.h" From 5e27ca300ab463783f491fb3e67d6bfdb4ba9cd2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Mar 2024 09:50:49 -0700 Subject: [PATCH 034/124] Make the computation of the memory budget lower bound more efficient. Specifically: 1. Iterate over AliasAnalysis::buffers() rather than the live ranges as the latter has too many duplicates 2. Turn a Shape object into a const reference 3. Pull out a conditional to exit early when possible. PiperOrigin-RevId: 619967473 --- .../auto_sharding/auto_sharding.cc | 44 ++++++++++++------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 82bac59e3b2d7d..49370d832b393f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -2650,7 +2650,7 @@ int64_t MemoryBudgetLowerBound( const HloAliasAnalysis& alias_analysis, const int64_t num_devices, const absl::flat_hash_map>& preserved_shardings) { - auto get_value_sharding = [](const HloValue* value) { + auto get_value_sharding = [](const HloValue* value) -> HloSharding { return !value->index().empty() ? value->instruction()->sharding().GetSubSharding( value->instruction()->shape(), value->index()) @@ -2664,9 +2664,8 @@ int64_t MemoryBudgetLowerBound( absl::flat_hash_map buffer_to_sharded_value_mapping; bool vlog_is_on_5 = VLOG_IS_ON(5); - for (LivenessIdx time_idx = 0; time_idx < liveness_set.size(); ++time_idx) { - for (const HloValue* value : liveness_set[time_idx]) { - const HloBuffer& buffer = alias_analysis.GetBufferContainingValue(*value); + for (const HloBuffer& buffer : alias_analysis.buffers()) { + for (const HloValue* value : buffer.values()) { if (value->instruction()->has_sharding()) { if (vlog_is_on_5) { const HloSharding& this_value_sharding = get_value_sharding(value); @@ -2694,36 +2693,51 @@ int64_t MemoryBudgetLowerBound( } int64_t max_memory_usage = 0; + absl::flat_hash_map value_to_memory_size_mapping; for (LivenessIdx time_idx = 0; time_idx < liveness_set.size(); ++time_idx) { int64_t memory_usage = 0; for (const HloValue* value : liveness_set[time_idx]) { if (value->instruction()->shape().IsTuple() && value->index().empty()) { continue; } - Shape shape = - ShapeUtil::GetSubshape(value->instruction()->shape(), value->index()); - const HloBuffer& buffer = alias_analysis.GetBufferContainingValue(*value); - auto iter = buffer_to_sharded_value_mapping.find(buffer.id()); + + auto iter1 = value_to_memory_size_mapping.find(value); + if (iter1 != value_to_memory_size_mapping.end()) { + memory_usage += iter1->second; + continue; + } + std::optional optional_sharding = std::nullopt; - if (iter != buffer_to_sharded_value_mapping.end()) { + const HloBuffer& buffer = alias_analysis.GetBufferContainingValue(*value); + auto iter2 = buffer_to_sharded_value_mapping.find(buffer.id()); + if (iter2 != buffer_to_sharded_value_mapping.end()) { // The instructions here can have partial sharding annotations from // previous iterations with partial mesh shapes when // solve_nd_sharding_iteratively is true. To exclude these, we only // utilize those shardings which corresponding to the current device // mesh. - const HloSharding& value_sharding = get_value_sharding(iter->second); if (preserved_shardings.find(value->instruction()->name()) != - preserved_shardings.end() || - !value_sharding.IsTiled() || - value_sharding.TotalNumTiles() == num_devices) { - optional_sharding = value_sharding; + preserved_shardings.end()) { + optional_sharding = get_value_sharding(iter2->second); + } else { + const HloSharding& value_sharding = get_value_sharding(iter2->second); + if (!value_sharding.IsTiled() || + value_sharding.TotalNumTiles() == num_devices) { + optional_sharding = value_sharding; + } } } - memory_usage += + + const Shape& shape = + ShapeUtil::GetSubshape(value->instruction()->shape(), value->index()); + int64_t value_memory_usage = GetShardedInstructionSize(shape, num_devices, optional_sharding); + value_to_memory_size_mapping[value] = value_memory_usage; + memory_usage += value_memory_usage; } max_memory_usage = std::max(max_memory_usage, memory_usage); } + return max_memory_usage; } From 498618553949d8a5b33f7e7515b9fbd45cffdaaa Mon Sep 17 00:00:00 2001 From: akhilgoe <114951738+akhilgoe@users.noreply.github.com> Date: Thu, 28 Mar 2024 09:53:05 -0700 Subject: [PATCH 035/124] PR #10759: [XLA:CPU][oneDNN] Enable matrix-vector and vector-vector product Imported from GitHub PR https://github.com/openxla/xla/pull/10759 This PR relaxes conditions to rewrite dot operations of the form vector-matrix, matrix-vector, or vector-vector to oneDNN custom calls, provided the original problem meets the empirically determined multiply-accumulate threshold. In particular this PR: 1. Relaxes some constraints on Dot to oneDNN matmul custom call conversion 2. Reconfigures the dimensions of the operands and outputs of convertible dot operations. 3. Adds tests to verify rewrite and execution result Copybara import of the project: -- d24e5cd0b77d0734a0f33011ae03127f00d80e7d by Akhil Goel : Relax constraints for matmul rewrite -- 9582288fc5d06eb8f4641be32d6c41746f13dbce by Akhil Goel : Fix gemv test after merge -- 604d4fbd29a78a7d305092356176e31081ea4ff0 by Akhil Goel : Address review comments -- 4534fbd7e96a6e45ccc5a501b38b940aa8bd9d38 by Akhil Goel : Optional commit Merging this change closes #10759 PiperOrigin-RevId: 619968241 --- .../xla/xla/service/cpu/cpu_float_support.cc | 14 +- .../xla/service/cpu/onednn_matmul_rewriter.cc | 97 +++- .../xla/xla/tests/onednn_matmul_test.cc | 466 ++++++++++-------- 3 files changed, 344 insertions(+), 233 deletions(-) diff --git a/third_party/xla/xla/service/cpu/cpu_float_support.cc b/third_party/xla/xla/service/cpu/cpu_float_support.cc index dd5c6c5b9d5049..0bb4dd8e875a75 100644 --- a/third_party/xla/xla/service/cpu/cpu_float_support.cc +++ b/third_party/xla/xla/service/cpu/cpu_float_support.cc @@ -27,7 +27,7 @@ bool CpuFloatSupport::IsSupported(const HloInstruction& hlo) const { // oneDNN rewritable ops case HloOpcode::kDot: return LowPrecisionType() == BF16 && - OneDnnMatMulRewriter::ShouldRewrite(&hlo) && DotSupported(hlo); + OneDnnMatMulRewriter::ShouldRewrite(&hlo); // Collective ops. case HloOpcode::kAllGather: case HloOpcode::kAllReduce: @@ -59,18 +59,6 @@ bool CpuFloatSupport::IsSupported(const HloInstruction& hlo) const { } } -bool CpuFloatSupport::DotSupported(const HloInstruction& hlo) const { - bool supported = true; - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - if (lhs_shape.rank() == rhs_shape.rank() && lhs_shape.rank() == 2) { - // If first dim size is 1, it may be removed by a later pass which makes it - // unsupported case. - supported &= lhs_shape.dimensions(0) != 1; - } - return supported; -} - } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc index d73975b3eac654..e08cd7b6c7118a 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc @@ -333,12 +333,15 @@ bool OneDnnMatMulRewriter::ShouldRewrite(const HloInstruction* dot_instr) { ShapeUtil::IsZeroElementArray(output_shape)) { return false; } - // OneDNN only supports 2 <= rank <= kOneDnnMaxNDims. - if (lhs_shape.rank() != rhs_shape.rank() || - rhs_shape.rank() != output_shape.rank() || lhs_shape.rank() < 2 || - lhs_shape.rank() > kOneDnnMaxNDims) { + // OneDNN only supports rank <= kOneDnnMaxNDims and singular non-contracting + // dimensions. We should not rewrite if any of these conditions are violated. + if (lhs_shape.rank() <= 0 || lhs_shape.rank() > kOneDnnMaxNDims || + rhs_shape.rank() <= 0 || rhs_shape.rank() > kOneDnnMaxNDims || + output_shape.rank() > std::min({lhs_shape.rank(), rhs_shape.rank(), + static_cast(kOneDnnMaxNDims)})) { return false; } + // Layout should be row-major, contraction dimensions captures transpose // scenarios in last two dimensions. if (!IsRowMajor(lhs_shape) || !IsRowMajor(rhs_shape) || @@ -362,7 +365,7 @@ bool OneDnnMatMulRewriter::ShouldRewrite(const HloInstruction* dot_instr) { auto num_flops = xla::HloCostAnalysis::GetDotFlops(lhs_shape, output_shape, dot_dim_numbers); auto rank = output_shape.rank(); - auto flops_threshold = (rank == 2) ? (1 << 24) : (1 << 19); + auto flops_threshold = (rank <= 2) ? (1 << 24) : (1 << 19); return (num_flops >= flops_threshold); } @@ -375,10 +378,11 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { auto pattern = m::Op(&dot_instr).WithOpcode(HloOpcode::kDot); if (!Match(instr, pattern)) return OkStatus(); - auto dot_dim_numbers = dot_instr->dot_dimension_numbers(); - TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot_dim_numbers)); - + TF_RETURN_IF_ERROR( + ValidateDotDimensionNumbers(dot_instr->dot_dimension_numbers())); if (!OneDnnMatMulRewriter::ShouldRewrite(dot_instr)) return OkStatus(); + TF_ASSIGN_OR_RETURN(dot_instr, ReconfigureDotDimensions(dot_instr)); + auto dot_dim_numbers = dot_instr->dot_dimension_numbers(); const Shape& lhs_shape = dot_instr->operand(0)->shape(); const Shape& rhs_shape = dot_instr->operand(1)->shape(); const Shape& output_shape = dot_instr->shape(); @@ -630,6 +634,83 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { return ReplaceWithNewInstruction(activation, std::move(output)); } + + // This function changes dot instruction for supported matrix + // multiplication scenarios. In particular, it changes the shape + // of lhs, rhs and result arrays. + // - lhs configuration scenario + // lhs: [batch_dims,contracting_dim] to [batch_dims,1,contracting_dim] + // result: [batch_dims,feature_dim] to [batch_dims,1,feature_dim] + // + // - rhs configuration scenario + // rhs: [batch_dims,contracting_dim] to [batch_dims,contracting_dim,1] + // result: [batch_dims,feature_dim] to [batch_dims,feature_dim, 1] + // + // - both lhs and rhs configuration scenario + // lhs: [batch_dims,contracting_dim] to [batch_dims,1,contracting_dim] + // rhs: [batch_dims,contracting_dim] to [batch_dims,contracting_dim,1] + // result: [batch_dims] to [batch_dims,1,1] + StatusOr ReconfigureDotDimensions( + HloInstruction* dot_instr) { + HloInstruction* lhs = dot_instr->mutable_operand(0); + HloInstruction* rhs = dot_instr->mutable_operand(1); + DotDimensionNumbers dim_numbers = dot_instr->dot_dimension_numbers(); + + auto lhs_batch_dims = dim_numbers.lhs_batch_dimensions(); + auto lhs_contraction_dims = dim_numbers.lhs_contracting_dimensions(); + bool is_lhs_vector = lhs->shape().rank() == + (lhs_batch_dims.size() + lhs_contraction_dims.size()); + + auto rhs_batch_dims = dim_numbers.rhs_batch_dimensions(); + auto rhs_contraction_dims = dim_numbers.rhs_contracting_dimensions(); + bool is_rhs_vector = rhs->shape().rank() == + (rhs_batch_dims.size() + rhs_contraction_dims.size()); + + if (!is_lhs_vector && !is_rhs_vector) return dot_instr; + + std::vector adjusted_lhs_dims(lhs->shape().dimensions().begin(), + lhs->shape().dimensions().end()); + std::vector adjusted_rhs_dims(rhs->shape().dimensions().begin(), + rhs->shape().dimensions().end()); + std::vector adjusted_dot_dims( + dot_instr->shape().dimensions().begin(), + dot_instr->shape().dimensions().end()); + + if (is_lhs_vector) { + auto lhs_it = adjusted_lhs_dims.begin() + lhs_batch_dims.size(); + adjusted_lhs_dims.insert(lhs_it, 1, 1); + auto result_it = adjusted_dot_dims.begin() + lhs_batch_dims.size(); + adjusted_dot_dims.insert(result_it, 1, 1); + auto lhs_contraction_dim = + dot_instr->dot_dimension_numbers().lhs_contracting_dimensions(0); + dim_numbers.set_lhs_contracting_dimensions(0, lhs_contraction_dim + 1); + lhs = lhs->AddInstruction(HloInstruction::CreateBitcast( + ShapeUtil::MakeShape(lhs->shape().element_type(), adjusted_lhs_dims), + lhs)); + } + + if (is_rhs_vector) { + auto it = adjusted_rhs_dims.end(); + adjusted_rhs_dims.insert(it, 1, 1); + auto result_it = adjusted_dot_dims.end(); + adjusted_dot_dims.insert(result_it, 1, 1); + rhs = rhs->AddInstruction(HloInstruction::CreateBitcast( + ShapeUtil::MakeShape(rhs->shape().element_type(), adjusted_rhs_dims), + rhs)); + } + + HloInstruction* adjusted_dot = + dot_instr->AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(dot_instr->shape().element_type(), + adjusted_dot_dims), + lhs, rhs, dim_numbers, dot_instr->precision_config())); + + HloInstruction* replacement_instr = adjusted_dot->AddInstruction( + HloInstruction::CreateBitcast(dot_instr->shape(), adjusted_dot)); + + TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, replacement_instr)); + return adjusted_dot; + } }; class OneDnnMatMulReorderVisitor : public DfsHloRewriteVisitor { diff --git a/third_party/xla/xla/tests/onednn_matmul_test.cc b/third_party/xla/xla/tests/onednn_matmul_test.cc index 8b100b6141bee2..64befeab768c25 100644 --- a/third_party/xla/xla/tests/onednn_matmul_test.cc +++ b/third_party/xla/xla/tests/onednn_matmul_test.cc @@ -67,12 +67,12 @@ class MatmulTest : public HloTestBase { TEST_F(MatmulTest, SimpleTestF32) { const char* matmul_module_str = R"( - HloModule matmul.test.f32, entry_computation_layout={(f32[32,8,128,64]{3,2,1,0},f32[32,8,64,128]{3,2,1,0})->f32[32,8,128,128]{3,2,1,0}} + HloModule matmul.test.f32 ENTRY matmul.test.f32 { - arg.0 = f32[32,8,128,64]{3,2,1,0} parameter(0), parameter_replication={false} - arg.1 = f32[32,8,64,128]{3,2,1,0} parameter(1), parameter_replication={false} - ROOT onednn.matmul.0 = f32[32,8,128,128]{3,2,1,0} dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + arg.0 = f32[32,8,128,64] parameter(0), parameter_replication={false} + arg.1 = f32[32,8,64,128] parameter(1), parameter_replication={false} + ROOT onednn.matmul.0 = f32[32,8,128,128] dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -87,12 +87,12 @@ TEST_F(MatmulTest, SimpleTestBF16) { } const char* matmul_module_str = R"( - HloModule matmul.test.bf16, entry_computation_layout={(bf16[32,8,128,64]{3,2,1,0},bf16[32,8,64,128]{3,2,1,0})->bf16[32,8,128,128]{3,2,1,0}} + HloModule matmul.test.bf16 ENTRY matmul.test.bf16 { - arg.0 = bf16[32,8,128,64]{3,2,1,0} parameter(0), parameter_replication={false} - arg.1 = bf16[32,8,64,128]{3,2,1,0} parameter(1), parameter_replication={false} - ROOT onednn.matmul.0 = bf16[32,8,128,128]{3,2,1,0} dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + arg.0 = bf16[32,8,128,64] parameter(0), parameter_replication={false} + arg.1 = bf16[32,8,64,128] parameter(1), parameter_replication={false} + ROOT onednn.matmul.0 = bf16[32,8,128,128] dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-4})); @@ -105,11 +105,12 @@ TEST_F(MatmulTest, SimpleTestF16) { } const char* matmul_module_str = R"( - HloModule matmul.test.f16, entry_computation_layout={(f16[32,8,128,64]{3,2,1,0},f16[32,8,64,128]{3,2,1,0})->f16[32,8,128,128]{3,2,1,0}} + HloModule matmul.test.f16 + ENTRY matmul.test.f16 { - arg.0 = f16[32,8,128,64]{3,2,1,0} parameter(0), parameter_replication={false} - arg.1 = f16[32,8,64,128]{3,2,1,0} parameter(1), parameter_replication={false} - ROOT onednn.matmul.0 = f16[32,8,128,128]{3,2,1,0} dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + arg.0 = f16[32,8,128,64] parameter(0), parameter_replication={false} + arg.1 = f16[32,8,64,128] parameter(1), parameter_replication={false} + ROOT onednn.matmul.0 = f16[32,8,128,128] dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-4})); @@ -118,12 +119,12 @@ TEST_F(MatmulTest, SimpleTestF16) { TEST_F(MatmulTest, SimpleTestF32TransposeB) { const char* matmul_module_str = R"( - HloModule matmul.test.1, entry_computation_layout={(f32[32,8,128,64]{3,1,2,0},f32[32,8,128,64]{3,1,2,0})->f32[32,8,128,128]{3,2,1,0}} + HloModule matmul.test.1 ENTRY matmul.test.1 { arg.0 = f32[32,8,128,64]{3,1,2,0} parameter(0), parameter_replication={false} arg.1 = f32[32,8,128,64]{3,1,2,0} parameter(1), parameter_replication={false} - ROOT onednn.matmul.0 = f32[32,8,128,128]{3,2,1,0} dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + ROOT onednn.matmul.0 = f32[32,8,128,128] dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -132,21 +133,21 @@ TEST_F(MatmulTest, SimpleTestF32TransposeB) { TEST_F(MatmulTest, SimpleTestF32WithBiasAddFusion1) { const char* matmul_module_str = R"( - HloModule matmul.biasadd.test.f32, entry_computation_layout={(f32[32,32,40,30]{3,2,1,0})->f32[32,32,40,40]{3,2,1,0}} - + HloModule matmul.biasadd.test.f32 + ENTRY matmul.biasadd.test.f32 { - arg0.1 = f32[32,32,40,30]{3,2,1,0} parameter(0), parameter_replication={false} - reshape.2 = f32[32,32,40,30]{3,2,1,0} reshape(arg0.1) + arg0.1 = f32[32,32,40,30] parameter(0), parameter_replication={false} + reshape.2 = f32[32,32,40,30] reshape(arg0.1) constant.3 = f32[] constant(1) - broadcast.4 = f32[32,32,30,40]{3,2,1,0} broadcast(constant.3), dimensions={} - dot.7 = f32[32,32,40,40]{3,2,1,0} dot(reshape.2, broadcast.4), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + broadcast.4 = f32[32,32,30,40] broadcast(constant.3), dimensions={} + dot.7 = f32[32,32,40,40] dot(reshape.2, broadcast.4), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} constant.5 = f32[] constant(15) - broadcast.6 = f32[40]{0} broadcast(constant.5), dimensions={} - broadcast.9 = f32[32,32,40,40]{3,2,1,0} broadcast(broadcast.6), dimensions={3} - add.10 = f32[32,32,40,40]{3,2,1,0} add(dot.7, broadcast.9) - reshape.11 = f32[32,32,40,40]{3,2,1,0} reshape(add.10) - tuple.12 = (f32[32,32,40,40]{3,2,1,0}) tuple(reshape.11) - ROOT get-tuple-element.13 = f32[32,32,40,40]{3,2,1,0} get-tuple-element(tuple.12), index=0 + broadcast.6 = f32[40] broadcast(constant.5), dimensions={} + broadcast.9 = f32[32,32,40,40] broadcast(broadcast.6), dimensions={3} + add.10 = f32[32,32,40,40] add(dot.7, broadcast.9) + reshape.11 = f32[32,32,40,40] reshape(add.10) + tuple.12 = (f32[32,32,40,40]) tuple(reshape.11) + ROOT get-tuple-element.13 = f32[32,32,40,40] get-tuple-element(tuple.12), index=0 })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -155,21 +156,21 @@ TEST_F(MatmulTest, SimpleTestF32WithBiasAddFusion1) { TEST_F(MatmulTest, SimpleTestF32WithBiasAddFusion2) { const char* matmul_module_str = R"( - HloModule matmul.biasadd.test.f32, entry_computation_layout={(f32[400,300]{1,0})->f32[400,1,400]{2,1,0}} + HloModule matmul.biasadd.test.f32 ENTRY matmul.biasadd.test.f32 { - arg0.1 = f32[400,300]{1,0} parameter(0), parameter_replication={false} - reshape.2 = f32[400,300]{1,0} reshape(arg0.1) + arg0.1 = f32[400,300] parameter(0), parameter_replication={false} + reshape.2 = f32[400,300] reshape(arg0.1) constant.3 = f32[] constant(1) - broadcast.4 = f32[300,400]{1,0} broadcast(constant.3), dimensions={} - dot.7 = f32[400,400]{1,0} dot(reshape.2, broadcast.4), lhs_batch_dims={}, lhs_contracting_dims={1}, rhs_batch_dims={}, rhs_contracting_dims={0} - reshape.1 = f32[400,1,400]{2,1,0} reshape(dot.7) + broadcast.4 = f32[300,400] broadcast(constant.3), dimensions={} + dot.7 = f32[400,400] dot(reshape.2, broadcast.4), lhs_batch_dims={}, lhs_contracting_dims={1}, rhs_batch_dims={}, rhs_contracting_dims={0} + reshape.1 = f32[400,1,400] reshape(dot.7) constant.5 = f32[] constant(15) - broadcast.6 = f32[400]{0} broadcast(constant.5), dimensions={} - broadcast.9 = f32[400,1,400]{2,1,0} broadcast(broadcast.6), dimensions={2} - add.10 = f32[400,1,400]{2,1,0} add(reshape.1, broadcast.9) - tuple.12 = (f32[400,1,400]{2,1,0}) tuple(add.10) - ROOT get-tuple-element.13 = f32[400,1,400]{2,1,0} get-tuple-element(tuple.12), index=0 + broadcast.6 = f32[400] broadcast(constant.5), dimensions={} + broadcast.9 = f32[400,1,400] broadcast(broadcast.6), dimensions={2} + add.10 = f32[400,1,400] add(reshape.1, broadcast.9) + tuple.12 = (f32[400,1,400]) tuple(add.10) + ROOT get-tuple-element.13 = f32[400,1,400] get-tuple-element(tuple.12), index=0 })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -178,17 +179,17 @@ TEST_F(MatmulTest, SimpleTestF32WithBiasAddFusion2) { TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter1) { const char* matmul_module_str = R"( - HloModule matmul.biasadd.test.f32, entry_computation_layout={(f32[32,32,40,30]{3,2,1,0}, f32[32,32,30,40]{3,2,1,0}, f32[32,32,40,40]{3,2,1,0})->f32[32,32,40,40]{3,2,1,0}} - + HloModule matmul.biasadd.test.f32 + ENTRY matmul.biasadd.test.f32 { - arg0.1 = f32[32,32,40,30]{3,2,1,0} parameter(0), parameter_replication={false} - arg0.2 = f32[32,32,30,40]{3,2,1,0} parameter(1), parameter_replication={false} - arg0.3 = f32[32,32,40,40]{3,2,1,0} parameter(2), parameter_replication={false} - dot.7 = f32[32,32,40,40]{3,2,1,0} dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} - add.10 = f32[32,32,40,40]{3,2,1,0} add(dot.7, arg0.3) - reshape.11 = f32[32,32,40,40]{3,2,1,0} reshape(add.10) - tuple.12 = (f32[32,32,40,40]{3,2,1,0}) tuple(reshape.11) - ROOT get-tuple-element.13 = f32[32,32,40,40]{3,2,1,0} get-tuple-element(tuple.12), index=0 + arg0.1 = f32[32,32,40,30] parameter(0), parameter_replication={false} + arg0.2 = f32[32,32,30,40] parameter(1), parameter_replication={false} + arg0.3 = f32[32,32,40,40] parameter(2), parameter_replication={false} + dot.7 = f32[32,32,40,40] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + add.10 = f32[32,32,40,40] add(dot.7, arg0.3) + reshape.11 = f32[32,32,40,40] reshape(add.10) + tuple.12 = (f32[32,32,40,40]) tuple(reshape.11) + ROOT get-tuple-element.13 = f32[32,32,40,40] get-tuple-element(tuple.12), index=0 })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -197,18 +198,18 @@ TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter1) { TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter2) { const char* matmul_module_str = R"( - HloModule matmul.biasadd.test.f32, entry_computation_layout={(f32[32,32,40,30]{3,2,1,0}, f32[32,32,30,40]{3,2,1,0}, f32[40]{0})->f32[32,32,40,40]{3,2,1,0}} - + HloModule matmul.biasadd.test.f32 + ENTRY matmul.biasadd.test.f32 { - arg0.1 = f32[32,32,40,30]{3,2,1,0} parameter(0), parameter_replication={false} - arg0.2 = f32[32,32,30,40]{3,2,1,0} parameter(1), parameter_replication={false} + arg0.1 = f32[32,32,40,30] parameter(0), parameter_replication={false} + arg0.2 = f32[32,32,30,40] parameter(1), parameter_replication={false} arg0.3 = f32[40]{0} parameter(2), parameter_replication={false} - dot.7 = f32[32,32,40,40]{3,2,1,0} dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} - broad.1 = f32[32,32,40,40]{3,2,1,0} broadcast(arg0.3), dimensions={3} - add.10 = f32[32,32,40,40]{3,2,1,0} add(dot.7, broad.1) - reshape.11 = f32[32,32,40,40]{3,2,1,0} reshape(add.10) - tuple.12 = (f32[32,32,40,40]{3,2,1,0}) tuple(reshape.11) - ROOT get-tuple-element.13 = f32[32,32,40,40]{3,2,1,0} get-tuple-element(tuple.12), index=0 + dot.7 = f32[32,32,40,40] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + broad.1 = f32[32,32,40,40] broadcast(arg0.3), dimensions={3} + add.10 = f32[32,32,40,40] add(dot.7, broad.1) + reshape.11 = f32[32,32,40,40] reshape(add.10) + tuple.12 = (f32[32,32,40,40]) tuple(reshape.11) + ROOT get-tuple-element.13 = f32[32,32,40,40] get-tuple-element(tuple.12), index=0 })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -217,18 +218,18 @@ TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter2) { TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter2D) { const char* matmul_module_str = R"( - HloModule matmul.biasadd.test.f32, entry_computation_layout={(f32[2,2,400,30]{3,2,1,0}, f32[2,2,30,400]{3,2,1,0}, f32[2,400]{1,0})->f32[2,2,400,400]{3,2,1,0}} - + HloModule matmul.biasadd.test.f32 + ENTRY matmul.biasadd.test.f32 { - arg0.1 = f32[2,2,400,30]{3,2,1,0} parameter(0), parameter_replication={false} - arg0.2 = f32[2,2,30,400]{3,2,1,0} parameter(1), parameter_replication={false} - arg0.3 = f32[2,400]{1,0} parameter(2), parameter_replication={false} - dot.7 = f32[2,2,400,400]{3,2,1,0} dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} - broad.1 = f32[2,2,400,400]{3,2,1,0} broadcast(arg0.3), dimensions={0,3} - add.10 = f32[2,2,400,400]{3,2,1,0} add(dot.7, broad.1) - reshape.11 = f32[2,2,400,400]{3,2,1,0} reshape(add.10) - tuple.12 = (f32[2,2,400,400]{3,2,1,0}) tuple(reshape.11) - ROOT get-tuple-element.13 = f32[2,2,400,400]{3,2,1,0} get-tuple-element(tuple.12), index=0 + arg0.1 = f32[2,2,400,30] parameter(0), parameter_replication={false} + arg0.2 = f32[2,2,30,400] parameter(1), parameter_replication={false} + arg0.3 = f32[2,400] parameter(2), parameter_replication={false} + dot.7 = f32[2,2,400,400] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + broad.1 = f32[2,2,400,400] broadcast(arg0.3), dimensions={0,3} + add.10 = f32[2,2,400,400] add(dot.7, broad.1) + reshape.11 = f32[2,2,400,400] reshape(add.10) + tuple.12 = (f32[2,2,400,400]) tuple(reshape.11) + ROOT get-tuple-element.13 = f32[2,2,400,400] get-tuple-element(tuple.12), index=0 })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -237,18 +238,18 @@ TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter2D) { TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter2D1B) { const char* matmul_module_str = R"( - HloModule matmul.biasadd.test.f32, entry_computation_layout={(f32[1,2,400,30]{3,2,1,0}, f32[1,2,30,400]{3,2,1,0}, f32[1,400]{1,0})->f32[1,2,400,400]{3,2,1,0}} - + HloModule matmul.biasadd.test.f32 + ENTRY matmul.biasadd.test.f32 { - arg0.1 = f32[1,2,400,30]{3,2,1,0} parameter(0), parameter_replication={false} - arg0.2 = f32[1,2,30,400]{3,2,1,0} parameter(1), parameter_replication={false} - arg0.3 = f32[1,400]{1,0} parameter(2), parameter_replication={false} - dot.7 = f32[1,2,400,400]{3,2,1,0} dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} - broad.1 = f32[1,2,400,400]{3,2,1,0} broadcast(arg0.3), dimensions={0,3} - add.10 = f32[1,2,400,400]{3,2,1,0} add(dot.7, broad.1) - reshape.11 = f32[1,2,400,400]{3,2,1,0} reshape(add.10) - tuple.12 = (f32[1,2,400,400]{3,2,1,0}) tuple(reshape.11) - ROOT get-tuple-element.13 = f32[1,2,400,400]{3,2,1,0} get-tuple-element(tuple.12), index=0 + arg0.1 = f32[1,2,400,30] parameter(0), parameter_replication={false} + arg0.2 = f32[1,2,30,400] parameter(1), parameter_replication={false} + arg0.3 = f32[1,400] parameter(2), parameter_replication={false} + dot.7 = f32[1,2,400,400] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + broad.1 = f32[1,2,400,400] broadcast(arg0.3), dimensions={0,3} + add.10 = f32[1,2,400,400] add(dot.7, broad.1) + reshape.11 = f32[1,2,400,400] reshape(add.10) + tuple.12 = (f32[1,2,400,400]) tuple(reshape.11) + ROOT get-tuple-element.13 = f32[1,2,400,400] get-tuple-element(tuple.12), index=0 })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -257,18 +258,18 @@ TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter2D1B) { TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter3) { const char* matmul_module_str = R"( - HloModule matmul.biasadd.test.f32, entry_computation_layout={(f32[16,128,768]{2,1,0}, f32[768,768]{1,0}, f32[768]{0})->f32[16,128,768]{2,1,0}} - + HloModule matmul.biasadd.test.f32 + ENTRY matmul.biasadd.test.f32 { - arg0.1 = f32[16,128,768]{2,1,0} parameter(0), sharding={replicated} - arg0.2 = f32[768,768]{1,0} parameter(1), sharding={replicated} - dot.84 = f32[16,128,768]{2,1,0} dot(arg0.1, arg0.2), lhs_contracting_dims={2}, rhs_contracting_dims={0} + arg0.1 = f32[16,128,768] parameter(0), sharding={replicated} + arg0.2 = f32[768,768] parameter(1), sharding={replicated} + dot.84 = f32[16,128,768] dot(arg0.1, arg0.2), lhs_contracting_dims={2}, rhs_contracting_dims={0} arg0.3 = f32[768]{0} parameter(2), sharding={replicated} - reshape.85 = f32[1,1,768]{2,1,0} reshape(arg0.3) - broadcast.86 = f32[1,1,768]{2,1,0} broadcast(reshape.85), dimensions={0,1,2} + reshape.85 = f32[1,1,768] reshape(arg0.3) + broadcast.86 = f32[1,1,768] broadcast(reshape.85), dimensions={0,1,2} reshape.87 = f32[768]{0} reshape(broadcast.86) - broadcast.88 = f32[16,128,768]{2,1,0} broadcast(reshape.87), dimensions={2} - ROOT add.89 = f32[16,128,768]{2,1,0} add(dot.84, broadcast.88) + broadcast.88 = f32[16,128,768] broadcast(reshape.87), dimensions={2} + ROOT add.89 = f32[16,128,768] add(dot.84, broadcast.88) })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -277,8 +278,8 @@ TEST_F(MatmulTest, SimpleTestF32WithBiasAsParameter3) { TEST_F(MatmulTest, SimpleTestF32TransposeBWithBiasAddFusion) { const char* matmul_module_str = R"( - HloModule matmul.test.1, entry_computation_layout={(f32[32,8,4,16]{3,1,2,0},f32[32,8,16,16]{3,1,2,0})->f32[32,8,4,16]{3,2,1,0}} - + HloModule matmul.test.1 + ENTRY matmul.test.1 { arg.0 = f32[32,8,4,16]{3,1,2,0} parameter(0), parameter_replication={false} arg.1 = f32[32,8,16,16]{3,1,2,0} parameter(1), parameter_replication={false} @@ -298,19 +299,19 @@ TEST_F(MatmulTest, SimpleTestF32TransposeBWithBiasAddFusion) { TEST_F(MatmulTest, F32BiasAddFusionNonCompatibleBias) { const char* matmul_module_str = R"( - HloModule matmul.test.f32, entry_computation_layout={(f32[12288,2]{1,0},f32[2,1024]{1,0})->f32[32,384,1024]{2,1,0}} + HloModule matmul.test.f32 ENTRY matmul.test.1 { - arg.0 = f32[12288,2]{1,0} parameter(0), parameter_replication={false} - arg.1 = f32[2,1024]{1,0} parameter(1), parameter_replication={false} - dot.0 = f32[12288,1024]{1,0} dot(arg.0, arg.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} - reshape.0 = f32[32,384,1024]{2,1,0} reshape(dot.0) - constant.0 = f32[1,384,1024]{2,1,0} constant(15) - reshape.1 = f32[384,1024]{1,0} reshape(constant.0) - broadcast.0 = f32[32,384,1024]{2,1,0} broadcast(reshape.1), dimensions={1,2} - add.0 = f32[32,384,1024]{2,1,0} add(reshape.0, broadcast.0) - tuple.0 = (f32[32,384,1024]{2,1,0}) tuple(add.0) - ROOT get-tuple-element.0 = f32[32,384,1024]{2,1,0} get-tuple-element(tuple.0), index=0 + arg.0 = f32[12288,2] parameter(0), parameter_replication={false} + arg.1 = f32[2,1024] parameter(1), parameter_replication={false} + dot.0 = f32[12288,1024] dot(arg.0, arg.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + reshape.0 = f32[32,384,1024] reshape(dot.0) + constant.0 = f32[1,384,1024] constant(15) + reshape.1 = f32[384,1024] reshape(constant.0) + broadcast.0 = f32[32,384,1024] broadcast(reshape.1), dimensions={1,2} + add.0 = f32[32,384,1024] add(reshape.0, broadcast.0) + tuple.0 = (f32[32,384,1024]) tuple(add.0) + ROOT get-tuple-element.0 = f32[32,384,1024] get-tuple-element(tuple.0), index=0 })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -319,29 +320,29 @@ TEST_F(MatmulTest, F32BiasAddFusionNonCompatibleBias) { TEST_F(MatmulTest, ApproxGELUTestF32) { const char* matmul_module_str = R"( - HloModule matmul.test.f32, entry_computation_layout={(f32[32,32,4,16]{3,2,1,0},f32[32,32,16,32]{3,2,1,0})->f32[32,32,4,32]{3,2,1,0}} + HloModule matmul.test.f32 ENTRY matmul.test.f32 { - arg.0 = f32[32,32,4,16]{3,2,1,0} parameter(0), parameter_replication={false} - arg.1 = f32[32,32,16,32]{3,2,1,0} parameter(1), parameter_replication={false} - onednn.matmul.0 = f32[32,32,4,32]{3,2,1,0} dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} - mul.0 = f32[32,32,4,32]{3,2,1,0} multiply(onednn.matmul.0, onednn.matmul.0) - mul.1 = f32[32,32,4,32]{3,2,1,0} multiply(onednn.matmul.0, mul.0) + arg.0 = f32[32,32,4,16] parameter(0), parameter_replication={false} + arg.1 = f32[32,32,16,32] parameter(1), parameter_replication={false} + onednn.matmul.0 = f32[32,32,4,32] dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + mul.0 = f32[32,32,4,32] multiply(onednn.matmul.0, onednn.matmul.0) + mul.1 = f32[32,32,4,32] multiply(onednn.matmul.0, mul.0) const.0 = f32[] constant(0.044715) - bcast.0 = f32[32,32,4,32]{3,2,1,0} broadcast(const.0), dimensions={} - mul.2 = f32[32,32,4,32]{3,2,1,0} multiply(mul.1, bcast.0) - add.0 = f32[32,32,4,32]{3,2,1,0} add(onednn.matmul.0, mul.2) + bcast.0 = f32[32,32,4,32] broadcast(const.0), dimensions={} + mul.2 = f32[32,32,4,32] multiply(mul.1, bcast.0) + add.0 = f32[32,32,4,32] add(onednn.matmul.0, mul.2) const.1 = f32[] constant(0.797884583) - bcast.1 = f32[32,32,4,32]{3,2,1,0} broadcast(const.1), dimensions={} - mul.3 = f32[32,32,4,32]{3,2,1,0} multiply(add.0, bcast.1) - tanh = f32[32,32,4,32]{3,2,1,0} tanh(mul.3) + bcast.1 = f32[32,32,4,32] broadcast(const.1), dimensions={} + mul.3 = f32[32,32,4,32] multiply(add.0, bcast.1) + tanh = f32[32,32,4,32] tanh(mul.3) const.2 = f32[] constant(1) - bcast.2 = f32[32,32,4,32]{3,2,1,0} broadcast(const.2), dimensions={} - add.2 = f32[32,32,4,32]{3,2,1,0} add(tanh, bcast.2) + bcast.2 = f32[32,32,4,32] broadcast(const.2), dimensions={} + add.2 = f32[32,32,4,32] add(tanh, bcast.2) const.3 = f32[] constant(0.5) - bcast.3 = f32[32,32,4,32]{3,2,1,0} broadcast(const.3), dimensions={} - mul.4 = f32[32,32,4,32]{3,2,1,0} multiply(add.2, bcast.3) - ROOT out = f32[32,32,4,32]{3,2,1,0} multiply(onednn.matmul.0, mul.4) + bcast.3 = f32[32,32,4,32] broadcast(const.3), dimensions={} + mul.4 = f32[32,32,4,32] multiply(add.2, bcast.3) + ROOT out = f32[32,32,4,32] multiply(onednn.matmul.0, mul.4) })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -361,35 +362,35 @@ TEST_F(MatmulTest, ApproxGELUTestF32) { // batch=32; seq_len=32; hidden_size=64; intermediate_size=256 TEST_F(MatmulTest, BiasAndApproxGELUTestF32) { const char* matmul_module_str = R"( - HloModule matmul.test.f32, entry_computation_layout={(f32[32,32,64]{2,1,0}, f32[64,256]{1,0}, f32[256]{0})->f32[32,32,256]{2,1,0}} + HloModule matmul.test.f32 ENTRY matmul.test.f32 { - Arg_5.6 = f32[32,32,64]{2,1,0} parameter(0), sharding={replicated} - Arg_7.8 = f32[64,256]{1,0} parameter(1), sharding={replicated} - dot.232 = f32[32,32,256]{2,1,0} dot(Arg_5.6, Arg_7.8), lhs_contracting_dims={2}, rhs_contracting_dims={0} - Arg_6.7 = f32[256]{0} parameter(2), sharding={replicated} - reshape.233 = f32[1,1,256]{2,1,0} reshape(Arg_6.7) - broadcast.234 = f32[1,1,256]{2,1,0} broadcast(reshape.233), dimensions={0,1,2} - reshape.235 = f32[256]{0} reshape(broadcast.234) - broadcast.236 = f32[32,32,256]{2,1,0} broadcast(reshape.235), dimensions={2} - add.237 = f32[32,32,256]{2,1,0} add(dot.232, broadcast.236) - multiply.238 = f32[32,32,256]{2,1,0} multiply(add.237, add.237) - multiply.239 = f32[32,32,256]{2,1,0} multiply(add.237, multiply.238) + Arg_5.6 = f32[32,32,64] parameter(0), sharding={replicated} + Arg_7.8 = f32[64,256] parameter(1), sharding={replicated} + dot.232 = f32[32,32,256] dot(Arg_5.6, Arg_7.8), lhs_contracting_dims={2}, rhs_contracting_dims={0} + Arg_6.7 = f32[256] parameter(2), sharding={replicated} + reshape.233 = f32[1,1,256] reshape(Arg_6.7) + broadcast.234 = f32[1,1,256] broadcast(reshape.233), dimensions={0,1,2} + reshape.235 = f32[256] reshape(broadcast.234) + broadcast.236 = f32[32,32,256] broadcast(reshape.235), dimensions={2} + add.237 = f32[32,32,256] add(dot.232, broadcast.236) + multiply.238 = f32[32,32,256] multiply(add.237, add.237) + multiply.239 = f32[32,32,256] multiply(add.237, multiply.238) constant.20 = f32[] constant(0.044715) - broadcast.21 = f32[32,32,256]{2,1,0} broadcast(constant.20), dimensions={} - multiply.240 = f32[32,32,256]{2,1,0} multiply(multiply.239, broadcast.21) - add.241 = f32[32,32,256]{2,1,0} add(add.237, multiply.240) + broadcast.21 = f32[32,32,256] broadcast(constant.20), dimensions={} + multiply.240 = f32[32,32,256] multiply(multiply.239, broadcast.21) + add.241 = f32[32,32,256] add(add.237, multiply.240) constant.18 = f32[] constant(0.797884583) - broadcast.19 = f32[32,32,256]{2,1,0} broadcast(constant.18), dimensions={} - multiply.242 = f32[32,32,256]{2,1,0} multiply(add.241, broadcast.19) - tanh.243 = f32[32,32,256]{2,1,0} tanh(multiply.242) + broadcast.19 = f32[32,32,256] broadcast(constant.18), dimensions={} + multiply.242 = f32[32,32,256] multiply(add.241, broadcast.19) + tanh.243 = f32[32,32,256] tanh(multiply.242) constant.16 = f32[] constant(1) - broadcast.17 = f32[32,32,256]{2,1,0} broadcast(constant.16), dimensions={} - add.244 = f32[32,32,256]{2,1,0} add(tanh.243, broadcast.17) + broadcast.17 = f32[32,32,256] broadcast(constant.16), dimensions={} + add.244 = f32[32,32,256] add(tanh.243, broadcast.17) constant.14 = f32[] constant(0.5) - broadcast.15 = f32[32,32,256]{2,1,0} broadcast(constant.14), dimensions={} - multiply.245 = f32[32,32,256]{2,1,0} multiply(add.244, broadcast.15) - ROOT out = f32[32,32,256]{2,1,0} multiply(add.237, multiply.245) + broadcast.15 = f32[32,32,256] broadcast(constant.14), dimensions={} + multiply.245 = f32[32,32,256] multiply(add.244, broadcast.15) + ROOT out = f32[32,32,256] multiply(add.237, multiply.245) })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -407,20 +408,20 @@ TEST_F(MatmulTest, BiasAndApproxGELUTestF32) { TEST_F(MatmulTest, ReLUTestF32) { const char* matmul_module_str = R"( - HloModule matmul.test.f32, entry_computation_layout={(f32[32,32,4,16]{3,2,1,0},f32[32,32,16,32]{3,2,1,0})->f32[32,32,4,32]{3,2,1,0}} + HloModule matmul.test.f32 relu.1 { - Arg_0.3 = f32[32,32,4,32]{3,2,1,0} parameter(0) + Arg_0.3 = f32[32,32,4,32] parameter(0) constant.4 = f32[] constant(0) - broadcast.5 = f32[32,32,4,32]{3,2,1,0} broadcast(constant.4), dimensions={} - ROOT maximum.6 = f32[32,32,4,32]{3,2,1,0} maximum(Arg_0.3, broadcast.5) + broadcast.5 = f32[32,32,4,32] broadcast(constant.4), dimensions={} + ROOT maximum.6 = f32[32,32,4,32] maximum(Arg_0.3, broadcast.5) } ENTRY matmul.test.f32 { - arg.0 = f32[32,32,4,16]{3,2,1,0} parameter(0), parameter_replication={false} - arg.1 = f32[32,32,16,32]{3,2,1,0} parameter(1), parameter_replication={false} - onednn.matmul.0 = f32[32,32,4,32]{3,2,1,0} dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} - ROOT call.7 = f32[32,32,4,32]{3,2,1,0} call(onednn.matmul.0), to_apply=relu.1 + arg.0 = f32[32,32,4,16] parameter(0), parameter_replication={false} + arg.1 = f32[32,32,16,32] parameter(1), parameter_replication={false} + onednn.matmul.0 = f32[32,32,4,32] dot(arg.0, arg.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + ROOT call.7 = f32[32,32,4,32] call(onednn.matmul.0), to_apply=relu.1 })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -442,20 +443,21 @@ TEST_F(MatmulTest, SimpleBiasTestBF16_PARAM_F32) { } const char* matmul_module_str = R"( - HloModule jit_apply, entry_computation_layout={(f32[3072]{0}, f32[768,3072]{1,0}, f32[16,128,768]{2,1,0})->bf16[16,128,3072]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true} + HloModule jit_apply + ENTRY matmul.test.bf16 { - Arg_2.3 = f32[16,128,768]{2,1,0} parameter(2), sharding={replicated} - convert.4 = bf16[16,128,768]{2,1,0} convert(Arg_2.3) - Arg_1.2 = f32[768,3072]{1,0} parameter(1), sharding={replicated} - convert.5 = bf16[768,3072]{1,0} convert(Arg_1.2) - dot.7 = bf16[16,128,3072]{2,1,0} dot(convert.4, convert.5), lhs_contracting_dims={2}, rhs_contracting_dims={0} - Arg_0.1 = f32[3072]{0} parameter(0), sharding={replicated} - convert.6 = bf16[3072]{0} convert(Arg_0.1) - reshape.8 = bf16[1,1,3072]{2,1,0} reshape(convert.6) - broadcast.9 = bf16[1,1,3072]{2,1,0} broadcast(reshape.8), dimensions={0,1,2} - reshape.10 = bf16[3072]{0} reshape(broadcast.9) - broadcast.11 = bf16[16,128,3072]{2,1,0} broadcast(reshape.10), dimensions={2} - ROOT add.12 = bf16[16,128,3072]{2,1,0} add(dot.7, broadcast.11) + Arg_2.3 = f32[16,128,768] parameter(2), sharding={replicated} + convert.4 = bf16[16,128,768] convert(Arg_2.3) + Arg_1.2 = f32[768,3072] parameter(1), sharding={replicated} + convert.5 = bf16[768,3072] convert(Arg_1.2) + dot.7 = bf16[16,128,3072] dot(convert.4, convert.5), lhs_contracting_dims={2}, rhs_contracting_dims={0} + Arg_0.1 = f32[3072] parameter(0), sharding={replicated} + convert.6 = bf16[3072] convert(Arg_0.1) + reshape.8 = bf16[1,1,3072] reshape(convert.6) + broadcast.9 = bf16[1,1,3072] broadcast(reshape.8), dimensions={0,1,2} + reshape.10 = bf16[3072] reshape(broadcast.9) + broadcast.11 = bf16[16,128,3072] broadcast(reshape.10), dimensions={2} + ROOT add.12 = bf16[16,128,3072] add(dot.7, broadcast.11) })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-2})); @@ -468,18 +470,19 @@ TEST_F(MatmulTest, SimpleBiasTestBF16_PARAM_BF16) { } const char* matmul_module_str = R"( - HloModule jit_apply, entry_computation_layout={(bf16[3072]{0}, bf16[768,3072]{1,0}, f32[16,128,768]{2,1,0})->bf16[16,128,3072]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true} + HloModule jit_apply + ENTRY matmul.test.bf16 { - Arg_2.3 = f32[16,128,768]{2,1,0} parameter(2), sharding={replicated} - convert.4 = bf16[16,128,768]{2,1,0} convert(Arg_2.3) - Arg_1.2 = bf16[768,3072]{1,0} parameter(1), sharding={replicated} - dot.5 = bf16[16,128,3072]{2,1,0} dot(convert.4, Arg_1.2), lhs_contracting_dims={2}, rhs_contracting_dims={0} - Arg_0.1 = bf16[3072]{0} parameter(0), sharding={replicated} - reshape.6 = bf16[1,1,3072]{2,1,0} reshape(Arg_0.1) - broadcast.7 = bf16[1,1,3072]{2,1,0} broadcast(reshape.6), dimensions={0,1,2} - reshape.8 = bf16[3072]{0} reshape(broadcast.7) - broadcast.9 = bf16[16,128,3072]{2,1,0} broadcast(reshape.8), dimensions={2} - ROOT add.10 = bf16[16,128,3072]{2,1,0} add(dot.5, broadcast.9) + Arg_2.3 = f32[16,128,768] parameter(2), sharding={replicated} + convert.4 = bf16[16,128,768] convert(Arg_2.3) + Arg_1.2 = bf16[768,3072] parameter(1), sharding={replicated} + dot.5 = bf16[16,128,3072] dot(convert.4, Arg_1.2), lhs_contracting_dims={2}, rhs_contracting_dims={0} + Arg_0.1 = bf16[3072] parameter(0), sharding={replicated} + reshape.6 = bf16[1,1,3072] reshape(Arg_0.1) + broadcast.7 = bf16[1,1,3072] broadcast(reshape.6), dimensions={0,1,2} + reshape.8 = bf16[3072] reshape(broadcast.7) + broadcast.9 = bf16[16,128,3072] broadcast(reshape.8), dimensions={2} + ROOT add.10 = bf16[16,128,3072] add(dot.5, broadcast.9) })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-2})); @@ -488,14 +491,15 @@ TEST_F(MatmulTest, SimpleBiasTestBF16_PARAM_BF16) { TEST_F(MatmulTest, DivisionByConstantWithEltwiseLinearF32) { const char* matmul_module_str = R"( - HloModule matmul.divide.test.1, entry_computation_layout={(f32[16,128,768]{2,1,0}, f32[768,12,64]{2,1,0})->f32[16,128,12,64]{3,2,1,0}} + HloModule matmul.divide.test.1 + ENTRY matmul.divide.test.f32 { - Arg_4.5 = f32[16,128,768]{2,1,0} parameter(0), sharding={replicated} - Arg_2.3 = f32[768,12,64]{2,1,0} parameter(1), sharding={replicated} - onednn.matmul.0 = f32[16,128,12,64]{3,2,1,0} dot(Arg_4.5, Arg_2.3), lhs_contracting_dims={2}, rhs_contracting_dims={0} + Arg_4.5 = f32[16,128,768] parameter(0), sharding={replicated} + Arg_2.3 = f32[768,12,64] parameter(1), sharding={replicated} + onednn.matmul.0 = f32[16,128,12,64] dot(Arg_4.5, Arg_2.3), lhs_contracting_dims={2}, rhs_contracting_dims={0} constant.8 = f32[] constant(8) - broadcast.9 = f32[16,128,12,64]{3,2,1,0} broadcast(constant.8), dimensions={} - ROOT divide.16 = f32[16,128,12,64]{3,2,1,0} divide(onednn.matmul.0, broadcast.9) + broadcast.9 = f32[16,128,12,64] broadcast(constant.8), dimensions={} + ROOT divide.16 = f32[16,128,12,64] divide(onednn.matmul.0, broadcast.9) })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec(1e-4, 1e-4))); @@ -517,20 +521,21 @@ TEST_F(MatmulTest, SimpleBiasTestF16_PARAM_F32) { } const char* matmul_module_str = R"( - HloModule jit_apply, entry_computation_layout={(f32[3072]{0}, f32[768,3072]{1,0}, f32[16,128,768]{2,1,0})->f16[16,128,3072]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true} + HloModule jit_apply + ENTRY matmul.test.f16 { - Arg_2.3 = f32[16,128,768]{2,1,0} parameter(2), sharding={replicated} - convert.4 = f16[16,128,768]{2,1,0} convert(Arg_2.3) - Arg_1.2 = f32[768,3072]{1,0} parameter(1), sharding={replicated} - convert.5 = f16[768,3072]{1,0} convert(Arg_1.2) - dot.7 = f16[16,128,3072]{2,1,0} dot(convert.4, convert.5), lhs_contracting_dims={2}, rhs_contracting_dims={0} - Arg_0.1 = f32[3072]{0} parameter(0), sharding={replicated} - convert.6 = f16[3072]{0} convert(Arg_0.1) - reshape.8 = f16[1,1,3072]{2,1,0} reshape(convert.6) - broadcast.9 = f16[1,1,3072]{2,1,0} broadcast(reshape.8), dimensions={0,1,2} - reshape.10 = f16[3072]{0} reshape(broadcast.9) - broadcast.11 = f16[16,128,3072]{2,1,0} broadcast(reshape.10), dimensions={2} - ROOT add.12 = f16[16,128,3072]{2,1,0} add(dot.7, broadcast.11) + Arg_2.3 = f32[16,128,768] parameter(2), sharding={replicated} + convert.4 = f16[16,128,768] convert(Arg_2.3) + Arg_1.2 = f32[768,3072] parameter(1), sharding={replicated} + convert.5 = f16[768,3072] convert(Arg_1.2) + dot.7 = f16[16,128,3072] dot(convert.4, convert.5), lhs_contracting_dims={2}, rhs_contracting_dims={0} + Arg_0.1 = f32[3072] parameter(0), sharding={replicated} + convert.6 = f16[3072] convert(Arg_0.1) + reshape.8 = f16[1,1,3072] reshape(convert.6) + broadcast.9 = f16[1,1,3072] broadcast(reshape.8), dimensions={0,1,2} + reshape.10 = f16[3072] reshape(broadcast.9) + broadcast.11 = f16[16,128,3072] broadcast(reshape.10), dimensions={2} + ROOT add.12 = f16[16,128,3072] add(dot.7, broadcast.11) })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-2})); @@ -542,18 +547,19 @@ TEST_F(MatmulTest, SimpleBiasTestF16_PARAM_F16) { GTEST_SKIP() << "CPU does not support F16."; } const char* matmul_module_str = R"( - HloModule jit_apply, entry_computation_layout={(f16[3072]{0}, f16[768,3072]{1,0}, f32[16,128,768]{2,1,0})->f16[16,128,3072]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true} + HloModule jit_apply + ENTRY matmul.test.f16 { - Arg_2.3 = f32[16,128,768]{2,1,0} parameter(2), sharding={replicated} - convert.4 = f16[16,128,768]{2,1,0} convert(Arg_2.3) - Arg_1.2 = f16[768,3072]{1,0} parameter(1), sharding={replicated} - dot.5 = f16[16,128,3072]{2,1,0} dot(convert.4, Arg_1.2), lhs_contracting_dims={2}, rhs_contracting_dims={0} - Arg_0.1 = f16[3072]{0} parameter(0), sharding={replicated} - reshape.6 = f16[1,1,3072]{2,1,0} reshape(Arg_0.1) - broadcast.7 = f16[1,1,3072]{2,1,0} broadcast(reshape.6), dimensions={0,1,2} - reshape.8 = f16[3072]{0} reshape(broadcast.7) - broadcast.9 = f16[16,128,3072]{2,1,0} broadcast(reshape.8), dimensions={2} - ROOT add.10 = f16[16,128,3072]{2,1,0} add(dot.5, broadcast.9) + Arg_2.3 = f32[16,128,768] parameter(2), sharding={replicated} + convert.4 = f16[16,128,768] convert(Arg_2.3) + Arg_1.2 = f16[768,3072] parameter(1), sharding={replicated} + dot.5 = f16[16,128,3072] dot(convert.4, Arg_1.2), lhs_contracting_dims={2}, rhs_contracting_dims={0} + Arg_0.1 = f16[3072] parameter(0), sharding={replicated} + reshape.6 = f16[1,1,3072] reshape(Arg_0.1) + broadcast.7 = f16[1,1,3072] broadcast(reshape.6), dimensions={0,1,2} + reshape.8 = f16[3072] reshape(broadcast.7) + broadcast.9 = f16[16,128,3072] broadcast(reshape.8), dimensions={2} + ROOT add.10 = f16[16,128,3072] add(dot.5, broadcast.9) })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-2, 1e-2})); @@ -562,12 +568,12 @@ TEST_F(MatmulTest, SimpleBiasTestF16_PARAM_F16) { TEST_F(MatmulTest, TestF32NonConstantWeights) { const char* matmul_module_str = R"( - HloModule matmul.test.f32, entry_computation_layout={(f32[64,256,16]{2,1,0},f32[16,32]{1,0})->f32[64,256,32]{2,1,0}} + HloModule matmul.test.f32 ENTRY matmul.test.f32 { - arg.0 = f32[64,256,16]{2,1,0} parameter(0), parameter_replication={false} - arg.1 = f32[16,32]{1,0} parameter(1), parameter_replication={false} - ROOT onednn.matmul.0 = f32[64,256,32]{2,1,0} dot(arg.0, arg.1), lhs_contracting_dims={2}, rhs_contracting_dims={0} + arg.0 = f32[64,256,16] parameter(0), parameter_replication={false} + arg.1 = f32[16,32] parameter(1), parameter_replication={false} + ROOT onednn.matmul.0 = f32[64,256,32] dot(arg.0, arg.1), lhs_contracting_dims={2}, rhs_contracting_dims={0} })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -581,13 +587,13 @@ TEST_F(MatmulTest, TestF32NonConstantWeights) { TEST_F(MatmulTest, TestF32ConstantWeights) { const char* matmul_module_str = R"( - HloModule matmul.test.f32, entry_computation_layout={(f32[64,256,16]{2,1,0})->f32[64,256,32]{2,1,0}} + HloModule matmul.test.f32 ENTRY matmul.test.f32 { - arg.0 = f32[64,256,16]{2,1,0} parameter(0), parameter_replication={false} + arg.0 = f32[64,256,16] parameter(0), parameter_replication={false} constant = f32[] constant(1) - arg.1 = f32[16,32]{1,0} broadcast(constant), dimensions={} - ROOT onednn.matmul.0 = f32[64,256,32]{2,1,0} dot(arg.0, arg.1), lhs_contracting_dims={2}, rhs_contracting_dims={0} + arg.1 = f32[16,32] broadcast(constant), dimensions={} + ROOT onednn.matmul.0 = f32[64,256,32] dot(arg.0, arg.1), lhs_contracting_dims={2}, rhs_contracting_dims={0} })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -599,9 +605,45 @@ TEST_F(MatmulTest, TestF32ConstantWeights) { )"); } +TEST_F(MatmulTest, SimpleTestBF16Gemv1) { + if (!IsSupportedType(PrimitiveType::BF16)) { + GTEST_SKIP() << "CPU does not support BF16."; + } + + const char* matmul_module_str = R"( + HloModule matmul.test.bf16 + + ENTRY matmul.test.bf16 { + arg.0 = bf16[1000,10000] parameter(0) + arg.1 = bf16[10000] parameter(1) + ROOT onednn.matmul.0 = bf16[1000] dot(arg.0, arg.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{2e-2, 1e-4})); + MatchOptimizedHlo(matmul_module_str, matmul_rewrite_str_); +} + +TEST_F(MatmulTest, SimpleTestBF16Gemv2) { + if (!IsSupportedType(PrimitiveType::BF16)) { + GTEST_SKIP() << "CPU does not support BF16."; + } + + const char* matmul_module_str = R"( + HloModule matmul.test.bf16 + + ENTRY matmul.test.bf16 { + arg.0 = bf16[100,300,300] parameter(0) + arg.1 = bf16[300] parameter(1) + ROOT onednn.matmul.0 = bf16[100,300] dot(arg.0, arg.1), lhs_contracting_dims={2}, rhs_contracting_dims={0} + })"; + + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{2e-2, 1e-4})); + MatchOptimizedHlo(matmul_module_str, matmul_rewrite_str_); +} + TEST_F(MatmulTest, TestTransposeBNoRewriteF32) { const char* matmul_module_str = R"( - HloModule matmul.test.f32, entry_computation_layout={(f32[384,1024]{1,0},f32[2,1024]{1,0})->f32[384,2]{1,0}} + HloModule matmul.test.f32 ENTRY matmul.test.f32 { arg.0 = f32[384,1024]{1,0} parameter(0), parameter_replication={false} From 1593877b000b738a15b15ccf46b036c0ca43a47c Mon Sep 17 00:00:00 2001 From: RJ Ascani Date: Thu, 28 Mar 2024 10:07:47 -0700 Subject: [PATCH 036/124] #shlo_ref: Fix adb typo PiperOrigin-RevId: 619974234 --- tensorflow/lite/experimental/shlo/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/experimental/shlo/README.md b/tensorflow/lite/experimental/shlo/README.md index e4b5e83ce6630c..4cc46da7b13429 100644 --- a/tensorflow/lite/experimental/shlo/README.md +++ b/tensorflow/lite/experimental/shlo/README.md @@ -199,7 +199,7 @@ using ADB. ```sh adb push shlo/ops/op_name_test /data/local/tmp -ash shell /data/local/tmp/op_name_test +adb shell /data/local/tmp/op_name_test ``` #### iOS From 69847fc42297d5310270fb4f5af2927bb13c2579 Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Thu, 28 Mar 2024 10:29:45 -0700 Subject: [PATCH 037/124] [XLA:GPU] Skip matrix-matrix multiplication in GemvRewriter PiperOrigin-RevId: 619983009 --- third_party/xla/xla/service/gpu/gemv_rewriter.cc | 5 +++++ .../xla/xla/service/gpu/gemv_rewriter_test.cc | 14 ++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter.cc b/third_party/xla/xla/service/gpu/gemv_rewriter.cc index 67ffd2b81db172..21e5f477e4b059 100644 --- a/third_party/xla/xla/service/gpu/gemv_rewriter.cc +++ b/third_party/xla/xla/service/gpu/gemv_rewriter.cc @@ -75,6 +75,11 @@ class GemvRewriterVisitor : public DfsHloRewriteVisitor { dim_numbers.rhs_batch_dimensions_size() + dim_numbers.rhs_contracting_dimensions_size() + 1; + // Skip matrix-matrix multiplication. + if (lhs_has_non_contracting_dim && rhs_has_non_contracting_dim) { + return absl::OkStatus(); + } + // Skip vector-vector multiplication. if (!lhs_has_non_contracting_dim && !rhs_has_non_contracting_dim) { return absl::OkStatus(); diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc b/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc index 46aee0aab3fb88..2a8b8103e0a94e 100644 --- a/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc @@ -111,6 +111,20 @@ TEST_F(GemvRewriterTest, DotNotRewriteVectorVectorMultiplication) { RunAndFilecheckHloRewrite(hlo, GemvRewriter(), /*expected=*/std::nullopt); } +TEST_F(GemvRewriterTest, DotNotRewriteMatrixMatrixMultiplication) { + const char* hlo = R"( + HloModule m + + ENTRY e { + p0 = f32[5,7] parameter(0) + p1 = f32[7,32] parameter(1) + ROOT d = f32[5,32] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + })"; + + RunAndFilecheckHloRewrite(hlo, GemvRewriter(), /*expected=*/std::nullopt); +} + TEST_F(GemvRewriterTest, DoNotRewriteDotsWithNonNormalizedLayout) { const char* hlo = R"( HloModule m From ddbbd8993263a6176cb42806b07124a76f6392dd Mon Sep 17 00:00:00 2001 From: Peter Gavin Date: Thu, 28 Mar 2024 10:36:44 -0700 Subject: [PATCH 038/124] [xla] remove BorrowingLiteral constructor for LiteralProto PiperOrigin-RevId: 619985533 --- third_party/xla/xla/literal.cc | 103 ---------------------------- third_party/xla/xla/literal.h | 16 ----- third_party/xla/xla/literal_test.cc | 42 ++++++------ 3 files changed, 20 insertions(+), 141 deletions(-) diff --git a/third_party/xla/xla/literal.cc b/third_party/xla/xla/literal.cc index d5364cb848e652..935e6092ac07c8 100644 --- a/third_party/xla/xla/literal.cc +++ b/third_party/xla/xla/literal.cc @@ -2709,107 +2709,4 @@ BorrowingLiteral::BorrowingLiteral(absl::Span src_buf_ptrs, } } -BorrowingLiteral::BorrowingLiteral(const LiteralProto& proto) - : LiteralBase(), shape_(std::make_unique(proto.shape())) { - root_piece_ = Piece(); - root_piece_.set_subshape(shape_.get()); - - if (shape().IsArray()) { - absl::Span data; - switch (shape_->element_type()) { -#define BORROWING_LITERAL_CAST_DATA_(FIELD) \ - absl::Span(reinterpret_cast(proto.FIELD().data()), \ - proto.FIELD().size() * sizeof *proto.FIELD().data()) - case PRED: - data = BORROWING_LITERAL_CAST_DATA_(preds); - break; - case S4: - data = proto.s4s(); - break; - case S8: - data = proto.s8s(); - break; - case S16: - data = proto.s16s(); - break; - case S32: - data = BORROWING_LITERAL_CAST_DATA_(s32s); - break; - case S64: - data = BORROWING_LITERAL_CAST_DATA_(s64s); - break; - case U4: - data = proto.u4s(); - break; - case U8: - data = proto.u8s(); - break; - case U16: - data = proto.u16s(); - break; - case U32: - data = BORROWING_LITERAL_CAST_DATA_(u32s); - break; - case U64: - data = BORROWING_LITERAL_CAST_DATA_(u64s); - break; - case F16: - data = proto.f16s(); - break; - case F32: - data = BORROWING_LITERAL_CAST_DATA_(f32s); - break; - case BF16: - data = proto.bf16s(); - break; - case F64: - data = BORROWING_LITERAL_CAST_DATA_(f64s); - break; - case F8E5M2: - data = proto.f8e5m2s(); - break; - case F8E4M3FN: - data = proto.f8e4m3fns(); - break; - case F8E4M3B11FNUZ: - data = proto.f8e4m3b11fnuzs(); - break; - case F8E5M2FNUZ: - data = proto.f8e5m2fnuzs(); - break; - case F8E4M3FNUZ: - data = proto.f8e4m3fnuzs(); - break; - case C64: - data = BORROWING_LITERAL_CAST_DATA_(c64s); - break; - case C128: - data = BORROWING_LITERAL_CAST_DATA_(c128s); - break; -#undef BORROWING_LITERAL_CAST_DATA_ - default: - LOG(FATAL) << "Invalid element type for array: " << shape(); - } - CHECK_EQ(data.size(), ShapeUtil::ByteSizeOfElements(*shape_)); - root_piece_.set_buffer(const_cast(data.data())); - } else if (shape_->IsTuple()) { - CHECK_EQ(shape().tuple_shapes_size(), proto.tuple_literals_size()); - BuildPieceSubtree(*shape_, &root_piece_); - for (int i = 0; i < shape_->tuple_shapes_size(); ++i) { - BorrowingLiteral child(proto.tuple_literals(i)); - child.root_piece_.ForEachMutableSubpiece( - [&](const ShapeIndex& child_index, Piece* child_piece) { - if (!child_piece->subshape().IsArray()) { - return; - } - ShapeIndex index = {i}; - index.insert(index.end(), child_index.begin(), child_index.end()); - root_piece_.child(index).set_buffer(child_piece->buffer()); - }); - } - } else { - LOG(FATAL) << "Invalid shape: " << *shape_; - } -} - } // namespace xla diff --git a/third_party/xla/xla/literal.h b/third_party/xla/xla/literal.h index fea38633994652..2ebe0c2d727174 100644 --- a/third_party/xla/xla/literal.h +++ b/third_party/xla/xla/literal.h @@ -912,18 +912,6 @@ class LiteralBase { return tuple_rep->children[index]; } - Piece& child(ShapeIndexView index) { - return const_cast(const_cast(this)->child(index)); - } - const Piece& child(ShapeIndexView index) const { - const Piece* result = this; - while (!index.empty()) { - result = &result->child(index.front()); - index.remove_prefix(1); - } - return *result; - } - // Adds a child piece to this piece's children. void emplace_back(Piece child_piece) { auto* tuple_rep = GetTupleRep(); @@ -1590,10 +1578,6 @@ class BorrowingLiteral : public LiteralBase { const Shape& shape); // TODO(b/79707221): adding constructors for nested tuples as well. - // Construct a BorrowingLiteral from a LiteralProto. The proto must not be - // modified during the lifetime of the BorrowingLiteral. - explicit BorrowingLiteral(const LiteralProto& proto); - private: // Recursively builds the subtree for the given piece and sets the subshapes // of the given piece with the given shape. diff --git a/third_party/xla/xla/literal_test.cc b/third_party/xla/xla/literal_test.cc index e61ba78d95e2ba..24c12bb92d6d84 100644 --- a/third_party/xla/xla/literal_test.cc +++ b/third_party/xla/xla/literal_test.cc @@ -2197,32 +2197,30 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal}); - auto test_proto = [](const Literal& literal) { - LiteralProto proto = literal.ToProto(); - EXPECT_EQ(literal, Literal::CreateFromProto(proto).value()); - EXPECT_EQ(literal, BorrowingLiteral(proto)); + auto to_from_proto = [](const Literal& literal) -> Literal { + return Literal::CreateFromProto(literal.ToProto()).value(); }; - test_proto(one_f32); - test_proto(vector_int8); - test_proto(vector_uint8); - test_proto(vector_c64); - test_proto(vector_c128); - test_proto(vector_bfloat16); - test_proto(vector_f8e5m2); - test_proto(vector_f8e4m3); - test_proto(vector_f8e4m3b11); - test_proto(vector_f8e5m2fnuz); - test_proto(vector_f8e4m3fnuz); - test_proto(matrix_pred); - test_proto(vector_s4); - test_proto(vector_u4); - test_proto(tuple); - test_proto(nested_tuple); - test_proto(nil_literal); + EXPECT_EQ(one_f32, to_from_proto(one_f32)); + EXPECT_EQ(vector_int8, to_from_proto(vector_int8)); + EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8)); + EXPECT_EQ(vector_c64, to_from_proto(vector_c64)); + EXPECT_EQ(vector_c128, to_from_proto(vector_c128)); + EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); + EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); + EXPECT_EQ(vector_f8e4m3, to_from_proto(vector_f8e4m3)); + EXPECT_EQ(vector_f8e4m3b11, to_from_proto(vector_f8e4m3b11)); + EXPECT_EQ(vector_f8e5m2fnuz, to_from_proto(vector_f8e5m2fnuz)); + EXPECT_EQ(vector_f8e4m3fnuz, to_from_proto(vector_f8e4m3fnuz)); + EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); + EXPECT_EQ(vector_s4, to_from_proto(vector_s4)); + EXPECT_EQ(vector_u4, to_from_proto(vector_u4)); + EXPECT_EQ(tuple, to_from_proto(tuple)); + EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple)); + EXPECT_EQ(nil_literal, to_from_proto(nil_literal)); EXPECT_NE(one_f32, two_f32); - EXPECT_NE(one_f32, Literal::CreateFromProto(two_f32.ToProto()).value()); + EXPECT_NE(one_f32, to_from_proto(two_f32)); } TEST_F(LiteralUtilTest, InvalidProtoNoValues) { From 354e90c4d0619e9bf8a0b149790df682356dfb95 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 28 Mar 2024 10:49:07 -0700 Subject: [PATCH 039/124] [PJRT:GPU] Fix implementation of .compute_capability on devices to remove the need for an ifdef. PiperOrigin-RevId: 619989520 --- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index c57fe404badb38..9dec918f993274 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/base/thread_annotations.h" @@ -38,8 +39,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" @@ -994,15 +995,16 @@ Status BuildDistributedDevices( } std::string MakeComputeCapabilityString(const se::DeviceDescription* desc) { - std::string compute_capability; -#if GOOGLE_CUDA - se::CudaComputeCapability cc = desc->cuda_compute_capability(); - compute_capability = - std::to_string(cc.major) + "." + std::to_string(cc.minor); -#else // GOOGLE_CUDA - compute_capability = desc->rocm_compute_capability().gfx_version(); -#endif // GOOGLE_CUDA - return compute_capability; + se::GpuComputeCapability cc = desc->gpu_compute_capability(); + if (std::holds_alternative(cc)) { + auto nvcc = std::get(cc); + return absl::StrCat(nvcc.major, ".", nvcc.minor); + } else if (std::holds_alternative(cc)) { + auto rocmcc = std::get(cc); + return rocmcc.gfx_version(); + } else { + return "unknown"; + } } StreamExecutorGpuDevice::StreamExecutorGpuDevice( From 8b681670751c6b2dcc4c49e25a1eb0409a664450 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 28 Mar 2024 10:59:48 -0700 Subject: [PATCH 040/124] [stream_executor:host] Add missing externs to C API header PiperOrigin-RevId: 619992965 --- .../xla/xla/stream_executor/host/host_kernel_c_api.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/third_party/xla/xla/stream_executor/host/host_kernel_c_api.h b/third_party/xla/xla/stream_executor/host/host_kernel_c_api.h index 6564bb49e58f22..6768706abc2800 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel_c_api.h +++ b/third_party/xla/xla/stream_executor/host/host_kernel_c_api.h @@ -23,6 +23,10 @@ limitations under the License. // StreamExecutor Host Kernel API //===----------------------------------------------------------------------===// +#ifdef __cplusplus +extern "C" { +#endif + // StreamExecutor host kernel API is an integration point between a codegen // backend and a runtime. XLA:CPU backend compiles fusion regions to native // functions (via LLVM backend) that are compatible with a kernel API (and ABI), @@ -77,4 +81,8 @@ typedef struct SE_HOST_KernelError SE_HOST_KernelError; typedef SE_HOST_KernelError* SE_HOST_Kernel( const SE_HOST_KernelCallFrame* call_frame); +#ifdef __cplusplus +} +#endif + #endif // XLA_STREAM_EXECUTOR_HOST_HOST_KERNEL_C_API_H_ From 22e22f7a9af07ff2fc4a0c00de61909af1df9825 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Mar 2024 11:13:15 -0700 Subject: [PATCH 041/124] Logs a fatal message if Auto Sharding times out (since we no longer rely on GSPMD as a backup). PiperOrigin-RevId: 619997588 --- .../hlo/experimental/auto_sharding/auto_sharding.cc | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 49370d832b393f..ef7420bb384f4b 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -4101,16 +4101,7 @@ absl::StatusOr AutoSharding::Run( absl::StatusOr module_is_changed; if (skip_auto_sharding) { - VLOG(1) << "Solver timed out. Will now rely on sharding propagation to " - "perform sharding."; - if (!ModuleHasUserShardings(module)) { - LOG(WARNING) - << "The auto-sharding solver has timed out without a solution. " - "Further, as the input module does not contain any sharding " - "annotations, we cannot rely on sharding propagation to perform " - "heuristic-guided sharding. The module therefore may not be " - "sharded leading to low performance."; - } + LOG(FATAL) << "The auto-sharding solver has timed out without a solution."; module_is_changed = false; } else { std::string trying_to_find; From 81d3c5c07a6b4f9488640924eda33a20b54f2336 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Mar 2024 11:23:34 -0700 Subject: [PATCH 042/124] [XLA] add a setter for HloInputOutputAliasConfig and HloBufferDonorConfig This is needed to move those configs from an old HloModule to a new one. PiperOrigin-RevId: 620001061 --- third_party/xla/xla/hlo/ir/hlo_module.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/third_party/xla/xla/hlo/ir/hlo_module.h b/third_party/xla/xla/hlo/ir/hlo_module.h index 6311929ac533a5..0f6d449127e68d 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.h +++ b/third_party/xla/xla/hlo/ir/hlo_module.h @@ -505,6 +505,9 @@ class HloModule { const HloInputOutputAliasConfig& input_output_alias_config() const { return input_output_alias_config_; } + void set_input_output_alias_config(HloInputOutputAliasConfig config) { + input_output_alias_config_ = std::move(config); + } // buffer_donor_config_ indicates the set of input buffer donors that are // expected from the module. @@ -512,6 +515,9 @@ class HloModule { const HloBufferDonorConfig& buffer_donor_config() const { return buffer_donor_config_; } + void set_buffer_donor_config(HloBufferDonorConfig config) { + buffer_donor_config_ = std::move(config); + } // Returns an id that is unique to this module across all modules created over // the lifetime of this process. From 5742cd596308506cc5dc1e7ceff102cd0b11bd7e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Mar 2024 11:34:59 -0700 Subject: [PATCH 043/124] Change the default gpu loop value to 1 PiperOrigin-RevId: 620005388 --- tensorflow/lite/delegates/gpu/api.h | 4 +++- tensorflow/lite/delegates/gpu/cl/api.cc | 4 ++++ tensorflow/lite/delegates/gpu/delegate_options.cc | 2 +- tensorflow/lite/delegates/gpu/delegate_options.h | 2 ++ tensorflow/lite/tools/benchmark/benchmark_model.cc | 10 +++++----- .../tools/benchmark/benchmark_performance_options.cc | 4 ++-- .../lite/tools/delegates/default_execution_provider.cc | 2 +- 7 files changed, 18 insertions(+), 10 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/api.h b/tensorflow/lite/delegates/gpu/api.h index 7b0d8312e2b154..535f077bf0f8b7 100644 --- a/tensorflow/lite/delegates/gpu/api.h +++ b/tensorflow/lite/delegates/gpu/api.h @@ -372,7 +372,9 @@ struct InferenceOptions { InferencePriority priority3 = InferencePriority::AUTO; #ifdef TFLITE_GPU_ENABLE_INVOKE_LOOP - int gpu_invoke_loop_times = -1; + // Number of times to invoke the inference in GPU delegate, to collect more + // accurate latency result. Default as 1, which is the original behavior. + int gpu_invoke_loop_times = 1; #endif }; diff --git a/tensorflow/lite/delegates/gpu/cl/api.cc b/tensorflow/lite/delegates/gpu/cl/api.cc index 08cf817da685b9..dc9a8d3a6ca649 100644 --- a/tensorflow/lite/delegates/gpu/cl/api.cc +++ b/tensorflow/lite/delegates/gpu/cl/api.cc @@ -545,6 +545,10 @@ class InferenceRunnerImpl : public CLInferenceRunner { #ifdef TFLITE_GPU_ENABLE_INVOKE_LOOP // TODO(b/328511338): Remove code enabled by TFLITE_GPU_ENABLE_INVOKE_LOOP // when Async API solution is ready to replace it. + if (gpu_invoke_loop_times_ <= 0) { + return absl::InvalidArgumentError( + "gpu_invoke_loop_times must be positive"); + } for (int i = 0; i < gpu_invoke_loop_times_; i++) { RETURN_IF_ERROR(RunWithoutExternalBufferCopy()); } diff --git a/tensorflow/lite/delegates/gpu/delegate_options.cc b/tensorflow/lite/delegates/gpu/delegate_options.cc index 7b7059df37e4d1..e596045e2a9d47 100644 --- a/tensorflow/lite/delegates/gpu/delegate_options.cc +++ b/tensorflow/lite/delegates/gpu/delegate_options.cc @@ -35,7 +35,7 @@ TfLiteGpuDelegateOptionsV2 TfLiteGpuDelegateOptionsV2Default() { options.last_delegate_node_index = std::numeric_limits::max(); #endif #ifdef TFLITE_GPU_ENABLE_INVOKE_LOOP - options.gpu_invoke_loop_times = -1; + options.gpu_invoke_loop_times = 1; #endif return options; } diff --git a/tensorflow/lite/delegates/gpu/delegate_options.h b/tensorflow/lite/delegates/gpu/delegate_options.h index 98f1347bb7febe..b52d45c8abd5b1 100644 --- a/tensorflow/lite/delegates/gpu/delegate_options.h +++ b/tensorflow/lite/delegates/gpu/delegate_options.h @@ -144,6 +144,8 @@ typedef struct { int last_delegate_node_index; #endif #ifdef TFLITE_GPU_ENABLE_INVOKE_LOOP + // Number of times to invoke the inference in GPU delegate, to collect more + // accurate latency result. Default as 1, which is the original behavior. int gpu_invoke_loop_times; #endif } TfLiteGpuDelegateOptionsV2; diff --git a/tensorflow/lite/tools/benchmark/benchmark_model.cc b/tensorflow/lite/tools/benchmark/benchmark_model.cc index 9fd9aac8508f90..77b20d1b9b1db8 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_model.cc @@ -73,7 +73,7 @@ BenchmarkParams BenchmarkModel::DefaultParams() { BenchmarkParam::Create(false)); params.AddParam("memory_footprint_check_interval_ms", BenchmarkParam::Create(kMemoryCheckIntervalMs)); - params.AddParam("gpu_invoke_loop_times", BenchmarkParam::Create(-1)); + params.AddParam("gpu_invoke_loop_times", BenchmarkParam::Create(1)); return params; } @@ -207,10 +207,10 @@ void BenchmarkModel::LogParams() { LOG_BENCHMARK_PARAM(int32_t, "memory_footprint_check_interval_ms", "Memory footprint check interval (ms)", verbose); #ifdef TFLITE_GPU_ENABLE_INVOKE_LOOP - LOG_BENCHMARK_PARAM( - int32_t, "gpu_invoke_loop_times", - "Number of GPU delegate invoke loop iterations to divide latency by", - verbose); + LOG_BENCHMARK_PARAM(int32_t, "gpu_invoke_loop_times", + "Number of GPU delegate invoke loop iterations. Latency " + "will be divided by it.", + verbose); #endif } diff --git a/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc b/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc index ecce8cb1572046..1645c7ec50fd4c 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc @@ -149,7 +149,7 @@ BenchmarkParams BenchmarkPerformanceOptions::DefaultParams() { BenchmarkParam::Create(-1.0f)); params.AddParam("random_shuffle_benchmark_runs", BenchmarkParam::Create(true)); - params.AddParam("gpu_invoke_loop_times", BenchmarkParam::Create(-1)); + params.AddParam("gpu_invoke_loop_times", BenchmarkParam::Create(1)); return params; } @@ -250,7 +250,7 @@ void BenchmarkPerformanceOptions::ResetPerformanceOptions() { single_option_run_params_->Set("num_threads", 1); single_option_run_params_->Set("use_gpu", false); #ifdef TFLITE_GPU_ENABLE_INVOKE_LOOP - single_option_run_params_->Set("gpu_invoke_loop_times", -1); + single_option_run_params_->Set("gpu_invoke_loop_times", 1); single_option_run_params_->Set("require_full_delegation", false); #endif #if defined(__ANDROID__) diff --git a/tensorflow/lite/tools/delegates/default_execution_provider.cc b/tensorflow/lite/tools/delegates/default_execution_provider.cc index 93dc242a8e2bfd..38e8fa56632784 100644 --- a/tensorflow/lite/tools/delegates/default_execution_provider.cc +++ b/tensorflow/lite/tools/delegates/default_execution_provider.cc @@ -81,7 +81,7 @@ std::vector DefaultExecutionProvider::CreateFlags( CreateFlag( "gpu_invoke_loop_times", params, "Number of GPU delegate invoke loop iterations. Used only when " - "TFLITE_GPU_ENABLE_INVOKE_LOOP is defined. Default is -1."), + "TFLITE_GPU_ENABLE_INVOKE_LOOP is defined. Default is 1."), CreateFlag( "delegate_serialize_dir", params, "Directory to be used by delegates for serializing any model data. " From ce8999191a1d1c61a9190c37e42b4ada29431a2d Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Thu, 28 Mar 2024 11:39:01 -0700 Subject: [PATCH 044/124] NFC: Add `@llvm-project//mlir:TransformUtils` dependency as preparation for https://github.com/llvm/llvm-project/pull/86819. PiperOrigin-RevId: 620006777 --- tensorflow/compiler/mlir/lite/BUILD | 1 + .../compiler/mlir/lite/experimental/tac/BUILD | 2 + tensorflow/compiler/mlir/lite/stablehlo/BUILD | 10 +++ .../mlir/quantization/stablehlo/BUILD | 1 + .../mlir/quantization/tensorflow/cc/BUILD | 1 + .../compiler/mlir/tensorflow/transforms/BUILD | 2 + .../compiler/mlir/tf2xla/transforms/BUILD | 1 + tensorflow/compiler/mlir/tfrt/BUILD | 4 + tensorflow/compiler/mlir/tfrt/ir/BUILD | 2 + tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD | 1 + .../compiler/mlir/tools/kernel_gen/BUILD | 1 + .../mlir/tools/kernel_gen/transforms/BUILD | 5 ++ tensorflow/compiler/mlir/tosa/BUILD | 4 + tensorflow/compiler/tf2xla/BUILD | 1 + .../core/transforms/constant_folding/BUILD | 1 + tensorflow/core/transforms/remapper/BUILD | 1 + tensorflow/dtensor/mlir/BUILD | 1 + third_party/triton/cl619443019.patch | 76 +++++++++++++++++++ third_party/triton/workspace.bzl | 1 + .../xla/third_party/triton/cl619443019.patch | 76 +++++++++++++++++++ .../xla/third_party/triton/workspace.bzl | 1 + .../xla/xla/mlir/runtime/transforms/BUILD | 3 + third_party/xla/xla/mlir_hlo/BUILD | 9 +++ 23 files changed, 205 insertions(+) create mode 100644 third_party/triton/cl619443019.patch create mode 100644 third_party/xla/third_party/triton/cl619443019.patch diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index d684d1f3a5c27c..f535d0d1aaea2e 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -742,6 +742,7 @@ cc_library( "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla:status", "@local_xla//xla:statusor", diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD index 75228b1bc607bc..1c5a0703d0a58a 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD @@ -199,6 +199,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -268,6 +269,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], alwayslink = 1, diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index b6abb996f47837..bd83f16de105f8 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -86,6 +86,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", ], @@ -109,6 +110,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:stablehlo_ops", ], @@ -138,6 +140,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:hlo_dialect_registration", @@ -315,6 +318,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", ], @@ -339,6 +343,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", ], @@ -367,6 +372,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", ], @@ -419,6 +425,7 @@ cc_library( "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_serialization", @@ -514,6 +521,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:stablehlo_ops", ], @@ -537,6 +545,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", ], @@ -676,6 +685,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", "@stablehlo//:broadcast_utils", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index bf1d249bf05cb5..4998c87f70febe 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -334,6 +334,7 @@ cc_library( "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/mlir_hlo", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD index 62a6f27c8ad5f1..23ce2105634854 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD @@ -183,6 +183,7 @@ tf_cc_test( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index d44bd428bd9456..c84871fd564156 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -60,6 +60,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -337,6 +338,7 @@ cc_library( "@llvm-project//mlir:MLProgramDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index 1250206ef80cea..b76b52c9fd774a 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -287,6 +287,7 @@ cc_library( "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla:shape_util", "@local_xla//xla:side_effect_util", diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 2fa322d47c5c58..21cdf1203a3554 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -145,6 +145,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:core_runtime_opdefs", @@ -165,6 +166,7 @@ cc_library( "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_opdefs", "@llvm-project//mlir:IR", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:core_runtime_opdefs", @@ -180,6 +182,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -254,6 +257,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:core_runtime_opdefs", diff --git a/tensorflow/compiler/mlir/tfrt/ir/BUILD b/tensorflow/compiler/mlir/tfrt/ir/BUILD index 92a1dc3d2757cb..88baa91f6de0d0 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/BUILD @@ -52,8 +52,10 @@ cc_library( ":tfrt_fallback_opdefs", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:compiler_tfrt_op_interfaces", diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD index cf64a37c2a696c..ce69fa85189423 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD @@ -167,6 +167,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_side_effects", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Transforms", "@tf_runtime//:compiler_tfrt_op_interfaces", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 59dabbafe2de88..86e2e269e4d329 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -89,6 +89,7 @@ cc_library( "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:ShapeToStandard", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:all_passes", # fixdeps: keep diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index f94d1adf2b0766..c4abb6420d9b38 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -39,6 +39,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -57,6 +58,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:chlo_ops", ], @@ -73,6 +75,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -148,6 +151,7 @@ cc_library( "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TensorTransforms", "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorToLLVM", @@ -216,6 +220,7 @@ cc_library( "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo:transforms_passes", ], diff --git a/tensorflow/compiler/mlir/tosa/BUILD b/tensorflow/compiler/mlir/tosa/BUILD index e25d2229c605c8..a7d9610a472308 100644 --- a/tensorflow/compiler/mlir/tosa/BUILD +++ b/tensorflow/compiler/mlir/tosa/BUILD @@ -102,6 +102,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TosaDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_tsl//tsl/framework/fixedpoint", ], @@ -157,6 +158,7 @@ cc_library( "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TosaDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -219,6 +221,7 @@ cc_library( "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", "@llvm-project//mlir:TosaDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -248,6 +251,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TosaDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 6ee27ba27a8345..16315f71d9652c 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -303,6 +303,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@ducc//:fft_wrapper", "@eigen_archive//:eigen3", + "@llvm-project//mlir:TransformUtils", "@local_xla//xla:empty", "//tensorflow/core/framework:numeric_types", "//tensorflow/core/platform:bfloat16", diff --git a/tensorflow/core/transforms/constant_folding/BUILD b/tensorflow/core/transforms/constant_folding/BUILD index e64e9d868f2677..ba1bc56b1749d2 100644 --- a/tensorflow/core/transforms/constant_folding/BUILD +++ b/tensorflow/core/transforms/constant_folding/BUILD @@ -30,6 +30,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/core/transforms/remapper/BUILD b/tensorflow/core/transforms/remapper/BUILD index 0c348a93f6d723..38b09b8bf601cb 100644 --- a/tensorflow/core/transforms/remapper/BUILD +++ b/tensorflow/core/transforms/remapper/BUILD @@ -48,6 +48,7 @@ cc_library( "@llvm-project//mlir:PDLInterpDialect", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/dtensor/mlir/BUILD b/tensorflow/dtensor/mlir/BUILD index 063d78221644bb..18a1057808e981 100644 --- a/tensorflow/dtensor/mlir/BUILD +++ b/tensorflow/dtensor/mlir/BUILD @@ -267,6 +267,7 @@ cc_library( "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_tsl//tsl/platform:status", "@local_xla//xla:xla_data_proto_cc", diff --git a/third_party/triton/cl619443019.patch b/third_party/triton/cl619443019.patch new file mode 100644 index 00000000000000..95ce54b6e4d6aa --- /dev/null +++ b/third_party/triton/cl619443019.patch @@ -0,0 +1,76 @@ +==== triton/BUILD#44 - /google/src/cloud/csigg/mlir_transform_utils/triton/BUILD ==== +# action=edit type=text +--- triton/BUILD 2024-03-22 08:02:38.000000000 -0700 ++++ triton/BUILD 2024-03-27 01:34:43.000000000 -0700 +@@ -620,6 +620,7 @@ + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", +@@ -628,6 +629,7 @@ + # The following is added to make Utility compile + ":TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", ++ "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + ) +@@ -646,6 +648,7 @@ + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ++ "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + alwayslink = True, # TritonDialect uses getCanonicalizationPatterns(). +@@ -729,6 +732,7 @@ + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", ++ "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + ) +@@ -780,6 +784,7 @@ + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:Pass", ++ "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + ) +==== triton/test/BUILD#18 - /google/src/cloud/csigg/mlir_transform_utils/triton/test/BUILD ==== +# action=edit type=text +--- triton/test/BUILD 2024-03-11 11:42:57.000000000 -0700 ++++ triton/test/BUILD 2024-03-27 01:32:04.000000000 -0700 +@@ -53,6 +53,7 @@ + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", ++ "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", +==== triton/third_party/nvidia/BUILD#3 - /google/src/cloud/csigg/mlir_transform_utils/triton/third_party/nvidia/BUILD ==== +# action=edit type=text +--- triton/third_party/nvidia/BUILD 2024-03-11 11:42:57.000000000 -0700 ++++ triton/third_party/nvidia/BUILD 2024-03-27 01:32:46.000000000 -0700 +@@ -66,6 +66,7 @@ + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ++ "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + ], +@@ -113,6 +114,7 @@ + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", ++ "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index ce8d828c2fc64b..2773b250ac8554 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -17,5 +17,6 @@ def repo(): "//third_party/triton:cl607293980.patch", # long standing :( "//third_party/triton:cl617812302.patch", "//third_party/triton:cl619146327.patch", + "//third_party/triton:cl619443019.patch", ], ) diff --git a/third_party/xla/third_party/triton/cl619443019.patch b/third_party/xla/third_party/triton/cl619443019.patch new file mode 100644 index 00000000000000..95ce54b6e4d6aa --- /dev/null +++ b/third_party/xla/third_party/triton/cl619443019.patch @@ -0,0 +1,76 @@ +==== triton/BUILD#44 - /google/src/cloud/csigg/mlir_transform_utils/triton/BUILD ==== +# action=edit type=text +--- triton/BUILD 2024-03-22 08:02:38.000000000 -0700 ++++ triton/BUILD 2024-03-27 01:34:43.000000000 -0700 +@@ -620,6 +620,7 @@ + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", +@@ -628,6 +629,7 @@ + # The following is added to make Utility compile + ":TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", ++ "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + ) +@@ -646,6 +648,7 @@ + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ++ "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + alwayslink = True, # TritonDialect uses getCanonicalizationPatterns(). +@@ -729,6 +732,7 @@ + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", ++ "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + ) +@@ -780,6 +784,7 @@ + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:Pass", ++ "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + ) +==== triton/test/BUILD#18 - /google/src/cloud/csigg/mlir_transform_utils/triton/test/BUILD ==== +# action=edit type=text +--- triton/test/BUILD 2024-03-11 11:42:57.000000000 -0700 ++++ triton/test/BUILD 2024-03-27 01:32:04.000000000 -0700 +@@ -53,6 +53,7 @@ + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", ++ "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", +==== triton/third_party/nvidia/BUILD#3 - /google/src/cloud/csigg/mlir_transform_utils/triton/third_party/nvidia/BUILD ==== +# action=edit type=text +--- triton/third_party/nvidia/BUILD 2024-03-11 11:42:57.000000000 -0700 ++++ triton/third_party/nvidia/BUILD 2024-03-27 01:32:46.000000000 -0700 +@@ -66,6 +66,7 @@ + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ++ "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + ], +@@ -113,6 +114,7 @@ + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", ++ "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index ce8d828c2fc64b..2773b250ac8554 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -17,5 +17,6 @@ def repo(): "//third_party/triton:cl607293980.patch", # long standing :( "//third_party/triton:cl617812302.patch", "//third_party/triton:cl619146327.patch", + "//third_party/triton:cl619443019.patch", ], ) diff --git a/third_party/xla/xla/mlir/runtime/transforms/BUILD b/third_party/xla/xla/mlir/runtime/transforms/BUILD index 2de0bcb1c4f878..fa29f9ec740de2 100644 --- a/third_party/xla/xla/mlir/runtime/transforms/BUILD +++ b/third_party/xla/xla/mlir/runtime/transforms/BUILD @@ -80,6 +80,7 @@ cc_library( deps = [ "//xla/mlir/runtime/ir:rt", "@llvm-project//mlir:IR", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -93,6 +94,7 @@ xla_cc_test( "//xla/mlir/runtime/ir:rt", "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -170,6 +172,7 @@ cc_library( deps = [ ":custom_call_encoding", "//xla/runtime:type_id", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD index 1344327b89e7ec..0727dd7bc7f0fe 100644 --- a/third_party/xla/xla/mlir_hlo/BUILD +++ b/third_party/xla/xla/mlir_hlo/BUILD @@ -306,6 +306,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -446,6 +447,7 @@ cc_library( "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", @@ -619,6 +621,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TensorUtils", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:base", "@stablehlo//:chlo_ops", @@ -638,6 +641,7 @@ cc_library( "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:stablehlo_ops", ], @@ -791,6 +795,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TensorUtils", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:chlo_ops", ], @@ -878,6 +883,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) @@ -909,6 +915,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_ops_inc_gen", @@ -928,6 +935,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_ops_inc_gen", @@ -1115,6 +1123,7 @@ cc_library( "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:ShapeToStandard", "@llvm-project//mlir:TensorInferTypeOpInterfaceImpl", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorToLLVM", From 5e726bca2cad4ea7a4209bf7c73d668cc31078d8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Mar 2024 12:33:20 -0700 Subject: [PATCH 045/124] Add Batch Norm to the integration test. Remove all the rng_seed parameters. PiperOrigin-RevId: 620022797 --- .../integration_test/quantize_model_test.py | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py index 6db660e23c20ef..bc25b9a858440f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py @@ -64,7 +64,6 @@ class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): ([10, 1, 1024], [10, 1024, 3]), ([2, 3, 1, 1024], [2, 3, 1024, 3]), ), - 'rng_seed': (1230, 1231, 1232, 1233), }]) ) @test_util.run_in_graph_and_eager_modes @@ -73,7 +72,6 @@ def test_matmul_ptq_model( bias_fn: Optional[ops.Operation], activation_fn: Optional[ops.Operation], dim_sizes: Sequence[int], - rng_seed: int, ): lhs_dim_size, rhs_dim_size = dim_sizes input_shape = (*lhs_dim_size,) @@ -87,7 +85,7 @@ def test_matmul_ptq_model( activation_fn, ) - rng = np.random.default_rng(rng_seed) + rng = np.random.default_rng(seed=42) input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( np.float32 @@ -164,14 +162,12 @@ def data_gen() -> repr_dataset.RepresentativeDataset: 'slice', 'transpose', ), - 'rng_seed': (0, 11, 222, 3333), }]) ) @test_util.run_in_graph_and_eager_modes def test_matmul_and_same_scale_ptq_model( self, same_scale_op: str, - rng_seed: int, ): input_shape = (2, 3, 1, 1024) filter_shape = (2, 3, 1024, 3) @@ -184,7 +180,7 @@ def test_matmul_and_same_scale_ptq_model( same_scale_op, ) - rng = np.random.default_rng(rng_seed) + rng = np.random.default_rng(seed=42) input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( np.float32 @@ -249,7 +245,6 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # TODO: b/326242075 - Support other same-scale ops. ), 'dim_sizes': (([None, 1024], [1024, 3]),), - 'rng_seed': (0, 11, 222, 3333), }]) ) @test_util.run_in_graph_and_eager_modes @@ -257,7 +252,6 @@ def test_matmul_and_same_scale_ptq_model_dynamic( self, same_scale_op: str, dim_sizes: Sequence[int], - rng_seed: int, ): input_dim_size, filter_dim_size = dim_sizes input_shape = (*input_dim_size,) @@ -271,7 +265,7 @@ def test_matmul_and_same_scale_ptq_model_dynamic( same_scale_op, ) - rng = np.random.default_rng(rng_seed) + rng = np.random.default_rng(seed=42) input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( np.float32 @@ -339,7 +333,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: nn_ops.relu, nn_ops.relu6, ), - 'has_batch_norm': (False,), + 'has_batch_norm': (False, True), 'input_shape_dynamic': ( False, True, @@ -348,7 +342,6 @@ def data_gen() -> repr_dataset.RepresentativeDataset: False, True, ), - 'rng_seed': (10, 11, 12, 13), }]) ) @test_util.run_in_graph_and_eager_modes @@ -359,7 +352,6 @@ def test_conv_ptq_model( has_batch_norm: bool, input_shape_dynamic: bool, enable_per_channel_quantized_weight: bool, - rng_seed: int, dilations: Sequence[int] = None, ): input_shape = (None, 3, 4, 3) if input_shape_dynamic else (1, 3, 4, 3) @@ -375,9 +367,18 @@ def test_conv_ptq_model( strides, dilations, ) + # TODO(b/331809306): investigate why these tests fail. + # skip these test cases. + if ( + bias_fn is None + and has_batch_norm + and input_shape_dynamic + and enable_per_channel_quantized_weight + ): + return # Generate model input data. - rng = np.random.default_rng(rng_seed) + rng = np.random.default_rng(seed=42) static_input_shape = [dim if dim is not None else 2 for dim in input_shape] input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( @@ -450,13 +451,11 @@ def data_gen() -> repr_dataset.RepresentativeDataset: 'abc,cde->abde', 'abc,dce->abde', ), - 'rng_seed': (82, 82732, 4444, 14), }]) ) def test_einsum_ptq_model( self, equation: str, - rng_seed: int, ): _, y_shape, bias_shape, x_signature, y_signature = ( self._prepare_sample_einsum_datashapes(equation, use_bias=True) @@ -472,7 +471,7 @@ def test_einsum_ptq_model( ) # Generate model input data. - rng = np.random.default_rng(rng_seed) + rng = np.random.default_rng(seed=42) input_data = ops.convert_to_tensor( rng.uniform(low=0.0, high=1.0, size=x_signature).astype('f4') ) From 47bdff097e6c04cc4ec82209a4366a4f1f9be4b6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 28 Mar 2024 12:45:29 -0700 Subject: [PATCH 046/124] [XLA:Python] Add pytype_srcs and pytype_deps attributes to pytype_strict_library. These attributes are ignored, because we don't use pytype in OSS builds at the moment. While we're here, remove a copybara transformation that strips off a pytype_srcs attribute: we can just leave the attribute alone and it won't do any harm. PiperOrigin-RevId: 620026406 --- third_party/xla/xla/python/BUILD | 1 + third_party/xla/xla/pytype.default.bzl | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 91b78aa20c8822..e1f044aead00b3 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -40,6 +40,7 @@ package_group( pytype_strict_library( name = "xla_client", srcs = ["xla_client.py"], + pytype_srcs = ["xla_client.pyi"], srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ diff --git a/third_party/xla/xla/pytype.default.bzl b/third_party/xla/xla/pytype.default.bzl index 05143e8a715181..b63011cc1b8e48 100644 --- a/third_party/xla/xla/pytype.default.bzl +++ b/third_party/xla/xla/pytype.default.bzl @@ -10,5 +10,6 @@ def pytype_strict_binary(name, **kwargs): native.py_binary(name = name, **kwargs) # Placeholder to use until bazel supports pytype_strict_library. -def pytype_strict_library(name, **kwargs): +def pytype_strict_library(name, pytype_deps = [], pytype_srcs = [], **kwargs): + _ = (pytype_deps, pytype_srcs) # @unused native.py_library(name = name, **kwargs) From 89591b6bbc322f690e0e6add26b1e603bb21120e Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Thu, 28 Mar 2024 13:27:01 -0700 Subject: [PATCH 047/124] Move `tsl/util` to `xla/tsl/util` PiperOrigin-RevId: 620039887 --- tensorflow/core/BUILD | 4 +- .../colocate_predecessor_trees_pass.cc | 2 +- tensorflow/core/common_runtime/eager/BUILD | 6 +- .../core/common_runtime/eager/context.cc | 2 +- .../core/common_runtime/eager/execute.cc | 2 +- .../core/common_runtime/eager/execute_node.cc | 2 +- .../common_runtime/gpu/gpu_process_state.cc | 2 +- .../process_function_library_runtime.cc | 2 +- tensorflow/core/data/BUILD | 2 +- tensorflow/core/data/dataset_utils_test.cc | 2 +- tensorflow/core/framework/BUILD | 2 +- tensorflow/core/framework/tensor.cc | 2 +- tensorflow/core/kernels/BUILD | 2 +- tensorflow/core/kernels/image/BUILD | 4 +- .../core/kernels/image/decode_image_op.cc | 2 +- tensorflow/core/kernels/logging_ops_test.cc | 2 +- .../core/kernels/numeric_options_utils.h | 2 +- tensorflow/core/kernels/scatter_nd_util.h | 2 +- tensorflow/core/profiler/convert/BUILD | 2 +- .../op_stats_to_input_pipeline_analysis.cc | 2 +- tensorflow/core/profiler/utils/BUILD | 2 +- .../core/profiler/utils/derived_timeline.cc | 2 +- tensorflow/core/util/BUILD | 32 +++---- tensorflow/core/util/command_line_flags.h | 2 +- tensorflow/core/util/determinism.h | 2 +- tensorflow/core/util/device_name_utils.h | 2 +- tensorflow/core/util/env_var.h | 2 +- tensorflow/core/util/mkl_util.h | 2 +- tensorflow/core/util/proto/BUILD | 2 +- tensorflow/core/util/proto/proto_utils.h | 2 +- tensorflow/core/util/reporter.h | 2 +- .../core/util/stat_summarizer_options.h | 2 +- tensorflow/core/util/stats_calculator.h | 2 +- tensorflow/core/util/tensor_bundle/BUILD | 4 +- .../core/util/tensor_bundle/byte_swap_array.h | 2 +- .../core/util/tensor_bundle/tensor_bundle.cc | 2 +- tensorflow/core/util/use_cudnn.h | 2 +- tensorflow/core/util/work_sharder.cc | 2 +- tensorflow/dtensor/cc/BUILD | 4 +- tensorflow/dtensor/cc/dtensor_device.cc | 2 +- tensorflow/dtensor/cc/dtensor_utils.cc | 2 +- tensorflow/dtensor/mlir/utils/BUILD | 2 +- .../dtensor/mlir/utils/collective_lowering.cc | 2 +- tensorflow/lite/CMakeLists.txt | 2 + .../lite/examples/label_image/CMakeLists.txt | 2 +- tensorflow/lite/testing/BUILD | 2 +- .../testing/generated_examples_zip_test.cc | 2 +- .../lite/tools/benchmark/CMakeLists.txt | 2 +- .../lite/tools/benchmark/experimental/c/BUILD | 2 +- .../experimental/c/benchmark_c_api.cc | 2 +- tensorflow/lite/tools/evaluation/stages/BUILD | 6 +- .../stages/image_preprocessing_stage.h | 2 +- .../stages/inference_profiler_stage.h | 2 +- .../stages/tflite_inference_stage.h | 2 +- tensorflow/python/framework/BUILD | 2 +- tensorflow/python/framework/offset_counter.cc | 2 +- .../python/framework/python_op_gen_main.cc | 2 +- tensorflow/python/util/BUILD | 2 +- .../python/util/stat_summarizer_wrapper.cc | 4 +- .../inference_interface/jni/run_stats_jni.h | 2 +- .../distributed_runtime/coordination/BUILD | 2 +- .../coordination/coordination_service.cc | 2 +- .../tsl/tsl/distributed_runtime/rpc/BUILD | 6 +- .../distributed_runtime/rpc/grpc_channel.cc | 2 +- .../rpc/grpc_channel_test.cc | 2 +- .../tsl/distributed_runtime/rpc/grpc_state.h | 2 +- .../xla/third_party/tsl/tsl/framework/BUILD | 4 +- .../tsl/tsl/framework/device_id_utils.h | 2 +- .../tsl/tsl/framework/device_id_utils_test.cc | 2 +- .../third_party/tsl/tsl/platform/cloud/BUILD | 2 +- .../tsl/platform/cloud/curl_http_request.cc | 2 +- .../third_party/tsl/tsl/profiler/lib/BUILD | 2 +- .../tsl/tsl/profiler/lib/profiler_lock.cc | 2 +- .../third_party/tsl/tsl/profiler/utils/BUILD | 2 +- .../tsl/tsl/profiler/utils/xplane_utils.cc | 2 +- .../xla/third_party/tsl/tsl/protobuf/BUILD | 2 +- third_party/xla/xla/BUILD | 9 +- .../xla/xla/backends/profiler/gpu/BUILD | 4 +- .../xla/backends/profiler/gpu/cupti_utils.cc | 2 +- .../profiler/gpu/device_tracer_cuda.cc | 2 +- .../profiler/gpu/device_tracer_rocm.cc | 2 +- .../backends/profiler/gpu/rocm_collector.cc | 2 +- third_party/xla/xla/debug_options_flags.cc | 2 +- third_party/xla/xla/debug_options_flags.h | 2 +- third_party/xla/xla/literal.cc | 2 +- third_party/xla/xla/parse_flags_from_env.cc | 2 +- third_party/xla/xla/parse_flags_from_env.h | 2 +- .../xla/xla/parse_flags_from_env_test.cc | 2 +- third_party/xla/xla/pjrt/gpu/BUILD | 4 +- third_party/xla/xla/pjrt/gpu/gpu_helpers.cc | 2 +- third_party/xla/xla/service/BUILD | 4 +- third_party/xla/xla/service/cpu/BUILD | 8 +- .../xla/xla/service/cpu/onednn_layer_norm.cc | 2 +- .../xla/xla/service/cpu/onednn_matmul.cc | 2 +- .../xla/service/cpu/onednn_matmul_rewriter.cc | 2 +- .../xla/xla/service/cpu/onednn_softmax.cc | 2 +- third_party/xla/xla/service/gpu/BUILD | 14 +-- .../xla/service/gpu/conv_algorithm_picker.cc | 2 +- .../xla/service/gpu/gemm_algorithm_picker.cc | 2 +- .../xla/service/gpu/gemm_fusion_autotuner.cc | 2 +- .../xla/service/gpu/llvm_gpu_backend/BUILD | 2 +- .../gpu/llvm_gpu_backend/gpu_backend_lib.cc | 2 +- third_party/xla/xla/service/gpu/model/BUILD | 2 +- .../service/gpu/model/hlo_op_profiler_run.cc | 2 +- .../xla/xla/service/gpu/nvptx_compiler.cc | 2 +- .../xla/service/gpu/stream_executor_util.cc | 4 +- .../service/gpu/stream_executor_util_test.cc | 2 +- .../service/gpu_compilation_environment.cc | 2 +- .../xla/xla/service/xla_compile_main.cc | 3 +- third_party/xla/xla/stream_executor/BUILD | 2 +- .../xla/xla/stream_executor/cuda/BUILD | 2 +- .../xla/xla/stream_executor/cuda/cuda_dnn.cc | 2 +- third_party/xla/xla/stream_executor/gpu/BUILD | 4 +- .../gpu/gpu_cudamallocasync_allocator.cc | 2 +- .../xla/xla/stream_executor/rocm/BUILD | 6 +- .../xla/xla/stream_executor/rocm/rocm_blas.cc | 2 +- .../xla/xla/stream_executor/rocm/rocm_dnn.cc | 4 +- .../stream_executor/stream_executor_pimpl.cc | 2 +- third_party/xla/xla/tools/BUILD | 14 +-- .../tools/extract_collective_operations.cc | 2 +- .../xla/tools/hex_floats_to_packed_literal.cc | 2 +- third_party/xla/xla/tools/hlo_bisect/BUILD | 2 +- .../xla/xla/tools/hlo_bisect/hlo_bisect.cc | 2 +- third_party/xla/xla/tools/hlo_expand.cc | 2 +- third_party/xla/xla/tools/hlo_expand.h | 2 +- third_party/xla/xla/tools/hlo_expand_main.cc | 2 +- third_party/xla/xla/tools/hlo_opt/BUILD | 2 +- third_party/xla/xla/tools/hlo_opt/opt_main.cc | 2 +- .../xla/xla/tools/hlo_proto_to_json.cc | 2 +- .../xla/xla/tools/interactive_graphviz.cc | 2 +- .../xla/xla/tools/multihost_hlo_runner/BUILD | 2 +- .../multihost_hlo_runner/hlo_runner_main.cc | 2 +- .../xla/xla/tools/run_hlo_module_main.cc | 2 +- .../{third_party/tsl => xla}/tsl/util/BUILD | 94 +++++++++---------- .../tsl => xla}/tsl/util/byte_swap_array.cc | 2 +- .../tsl => xla}/tsl/util/byte_swap_array.h | 6 +- .../tsl/util/command_line_flags.cc | 2 +- .../tsl => xla}/tsl/util/command_line_flags.h | 6 +- .../tsl => xla}/tsl/util/determinism.cc | 4 +- .../tsl => xla}/tsl/util/determinism.h | 6 +- .../tsl/util/determinism_test_util.h | 8 +- .../tsl => xla}/tsl/util/device_name_utils.cc | 2 +- .../tsl => xla}/tsl/util/device_name_utils.h | 6 +- .../tsl/util/device_name_utils_test.cc | 2 +- .../tsl => xla}/tsl/util/env_var.cc | 2 +- .../tsl => xla}/tsl/util/env_var.h | 6 +- .../tsl => xla}/tsl/util/onednn_threadpool.h | 6 +- .../tsl => xla}/tsl/util/proto/BUILD | 0 .../tsl => xla}/tsl/util/proto/proto_utils.h | 6 +- .../tsl => xla}/tsl/util/reporter.cc | 2 +- .../tsl => xla}/tsl/util/reporter.h | 6 +- .../tsl/util/stat_summarizer_options.h | 6 +- .../tsl => xla}/tsl/util/stats_calculator.cc | 2 +- .../tsl => xla}/tsl/util/stats_calculator.h | 8 +- .../tsl/util/stats_calculator_test.cc | 2 +- .../tsl => xla}/tsl/util/use_cudnn.cc | 4 +- .../tsl => xla}/tsl/util/use_cudnn.h | 6 +- third_party/xla/xla/xla.bzl | 2 +- 158 files changed, 286 insertions(+), 284 deletions(-) rename third_party/xla/{third_party/tsl => xla}/tsl/util/BUILD (77%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/byte_swap_array.cc (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/byte_swap_array.h (96%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/command_line_flags.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/command_line_flags.h (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/determinism.cc (96%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/determinism.h (86%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/determinism_test_util.h (84%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/device_name_utils.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/device_name_utils.h (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/device_name_utils_test.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/env_var.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/env_var.h (95%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/onednn_threadpool.h (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/proto/BUILD (100%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/proto/proto_utils.h (90%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/reporter.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/reporter.h (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/stat_summarizer_options.h (88%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/stats_calculator.cc (99%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/stats_calculator.h (96%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/stats_calculator_test.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/use_cudnn.cc (98%) rename third_party/xla/{third_party/tsl => xla}/tsl/util/use_cudnn.h (92%) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 97c79bec1faccb..656c02e1214ac6 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1708,8 +1708,8 @@ tf_cuda_library( "@local_tsl//tsl/framework:cancellation", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:stringpiece", - "@local_tsl//tsl/util:command_line_flags", - "@local_tsl//tsl/util:device_name_utils", + "@local_xla//xla/tsl/util:command_line_flags", + "@local_xla//xla/tsl/util:device_name_utils", ] + if_cuda([ "@local_config_cuda//cuda:cudnn_header", ]) + if_static( diff --git a/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.cc b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.cc index 2b2debbd90458a..31a101a421a284 100644 --- a/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.cc +++ b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/util/device_name_utils.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -36,7 +37,6 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/dump_graph.h" -#include "tsl/util/device_name_utils.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index 8161b5cd49aac3..cb4d241e639f44 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -124,7 +124,7 @@ tf_cuda_library( "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_tsl//tsl/util:env_var", + "@local_xla//xla/tsl/util:env_var", ] + select({ "//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib_lite", @@ -643,7 +643,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/util:env_var", + "@local_xla//xla/tsl/util:env_var", ] + select({ "//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib_lite", @@ -813,7 +813,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/util:env_var", + "@local_xla//xla/tsl/util:env_var", ] + select({ "//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib_lite", diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 594ea4bf045388..0677e45b4c83a6 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -46,6 +46,7 @@ limitations under the License. #include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor_internal.h" +#include "xla/tsl/util/env_var.h" #include "tensorflow/core/common_runtime/collective_executor_mgr.h" #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" #include "tensorflow/core/common_runtime/colocation_graph.h" @@ -63,7 +64,6 @@ limitations under the License. #include "tensorflow/core/util/device_name_utils.h" #include "tsl/platform/refcount.h" #include "tsl/platform/statusor.h" -#include "tsl/util/env_var.h" #if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h" #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 7aff350fa65ac4..57af63ddb05e3d 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -59,6 +59,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/compiler/jit/defs.h" +#include "xla/tsl/util/env_var.h" #include "tensorflow/core/common_runtime/colocation_graph.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" @@ -81,7 +82,6 @@ limitations under the License. #include "tensorflow/core/util/device_name_utils.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/statusor.h" -#include "tsl/util/env_var.h" #if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h" diff --git a/tensorflow/core/common_runtime/eager/execute_node.cc b/tensorflow/core/common_runtime/eager/execute_node.cc index ebedabf3eef3ee..02e032e604e1de 100644 --- a/tensorflow/core/common_runtime/eager/execute_node.cc +++ b/tensorflow/core/common_runtime/eager/execute_node.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/eager/execute_node.h" +#include "xla/tsl/util/env_var.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/lib/core/errors.h" -#include "tsl/util/env_var.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc index b3d85d8d792d7e..9dc7530fc1e5d9 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc @@ -32,6 +32,7 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/stream_executor/integrations/device_mem_allocator.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/util/env_var.h" #include "tensorflow/core/common_runtime/device/device_host_allocator.h" #include "tensorflow/core/common_runtime/device_id_utils.h" #include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h" @@ -50,7 +51,6 @@ limitations under the License. #include "tsl/platform/mutex.h" #include "tsl/platform/strcat.h" #include "tsl/platform/types.h" -#include "tsl/util/env_var.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index a404cb014aef35..b1d491fa2fcd87 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "absl/types/variant.h" +#include "xla/tsl/util/env_var.h" #include "tensorflow/core/common_runtime/build_graph_options.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/function.h" @@ -67,7 +68,6 @@ limitations under the License. #include "tensorflow/core/util/reffed_status_callback.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/util/env_var.h" #if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" #endif // IS_MOBILE_PLATFORM diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD index 0816237fb0e1d6..c5191a995b60a2 100644 --- a/tensorflow/core/data/BUILD +++ b/tensorflow/core/data/BUILD @@ -181,7 +181,7 @@ tf_cc_test( "//tensorflow/core/platform:str_util", "@com_google_absl//absl/container:flat_hash_set", "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/util:determinism_test_util", + "@local_xla//xla/tsl/util:determinism_test_util", ], ) diff --git a/tensorflow/core/data/dataset_utils_test.cc b/tensorflow/core/data/dataset_utils_test.cc index 853f5e6c5c0bfa..e581f6e3cbe3e8 100644 --- a/tensorflow/core/data/dataset_utils_test.cc +++ b/tensorflow/core/data/dataset_utils_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "xla/tsl/util/determinism_test_util.h" #include "tensorflow/core/data/compression_utils.h" #include "tensorflow/core/data/dataset_test_base.h" #include "tensorflow/core/data/serialization_utils.h" @@ -39,7 +40,6 @@ limitations under the License. #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/util/work_sharder.h" #include "tsl/platform/status_matchers.h" -#include "tsl/util/determinism_test_util.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index b5f24419245868..15db744c0b2201 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -875,7 +875,7 @@ tf_cuda_library( "@com_google_absl//absl/strings", "@eigen_archive//:eigen3", "@local_tsl//tsl/framework:device_type", - "@local_tsl//tsl/util:byte_swap_array", + "@local_xla//xla/tsl/util:byte_swap_array", ], alwayslink = 1, ) diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index 010395a9a2c4bd..d2b0cd3efa0461 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -38,6 +38,7 @@ limitations under the License. #include #include "absl/strings/escaping.h" +#include "xla/tsl/util/byte_swap_array.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/resource_handle.h" @@ -65,7 +66,6 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tensor_coding.h" #include "tensorflow/core/platform/types.h" -#include "tsl/util/byte_swap_array.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 41546d050012f7..e2c67628646a18 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3433,7 +3433,7 @@ tf_cc_tests( "//tensorflow/core:testlib", "//tensorflow/core/platform:status_matchers", "@com_google_absl//absl/strings", - "@local_tsl//tsl/util:determinism_test_util", + "@local_xla//xla/tsl/util:determinism_test_util", ], ) diff --git a/tensorflow/core/kernels/image/BUILD b/tensorflow/core/kernels/image/BUILD index d26a532e1a8315..38c741f9844df0 100644 --- a/tensorflow/core/kernels/image/BUILD +++ b/tensorflow/core/kernels/image/BUILD @@ -206,7 +206,7 @@ tf_kernel_library( prefix = "decode_image_op", deps = IMAGE_DEPS + [ "@com_google_absl//absl/strings", - "@local_tsl//tsl/util:byte_swap_array", + "@local_xla//xla/tsl/util:byte_swap_array", ], ) @@ -455,7 +455,7 @@ cc_library( "//tensorflow/core/platform:byte_order", "//tensorflow/core/platform:errors", "@com_google_absl//absl/strings", - "@local_tsl//tsl/util:byte_swap_array", + "@local_xla//xla/tsl/util:byte_swap_array", ], alwayslink = 1, ) diff --git a/tensorflow/core/kernels/image/decode_image_op.cc b/tensorflow/core/kernels/image/decode_image_op.cc index 2ca9f67e17aca1..afb653191e3e8a 100644 --- a/tensorflow/core/kernels/image/decode_image_op.cc +++ b/tensorflow/core/kernels/image/decode_image_op.cc @@ -24,6 +24,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "absl/strings/match.h" +#include "xla/tsl/util/byte_swap_array.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" @@ -42,7 +43,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/tstring.h" -#include "tsl/util/byte_swap_array.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/kernels/logging_ops_test.cc b/tensorflow/core/kernels/logging_ops_test.cc index 8ba44782a194c0..fdb85fda2d70a0 100644 --- a/tensorflow/core/kernels/logging_ops_test.cc +++ b/tensorflow/core/kernels/logging_ops_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "xla/tsl/util/determinism_test_util.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/status_matchers.h" -#include "tsl/util/determinism_test_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/kernels/numeric_options_utils.h b/tensorflow/core/kernels/numeric_options_utils.h index d9ea889b443a87..eb1d50ae7f80bc 100644 --- a/tensorflow/core/kernels/numeric_options_utils.h +++ b/tensorflow/core/kernels/numeric_options_utils.h @@ -17,8 +17,8 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_NUMERIC_OPTIONS_UTILS_H_ #include "xla/stream_executor/numeric_options.h" +#include "xla/tsl/util/determinism.h" #include "tsl/platform/tensor_float_32_utils.h" -#include "tsl/util/determinism.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/scatter_nd_util.h b/tensorflow/core/kernels/scatter_nd_util.h index ae78a6abad3a0d..f0530048ef699a 100644 --- a/tensorflow/core/kernels/scatter_nd_util.h +++ b/tensorflow/core/kernels/scatter_nd_util.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_UTIL_H_ #define TENSORFLOW_CORE_KERNELS_SCATTER_ND_UTIL_H_ +#include "xla/tsl/util/env_var.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tsl/util/env_var.h" namespace tensorflow { diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 1b6781bdd8e012..40d967302cc34c 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -242,7 +242,7 @@ cc_library( "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/utils:format_utils", "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/util:stats_calculator_portable", + "@local_xla//xla/tsl/util:stats_calculator_portable", ], ) diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc index 120e22bab5d64d..268908b3f1cf28 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/util/stats_calculator.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -46,7 +47,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" #include "tsl/profiler/utils/format_utils.h" #include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/util/stats_calculator.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD index b3b363375f1bf4..a3f7d16a834cef 100644 --- a/tensorflow/core/profiler/utils/BUILD +++ b/tensorflow/core/profiler/utils/BUILD @@ -239,7 +239,7 @@ cc_library( "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", "@local_tsl//tsl/profiler/utils:timespan", "@local_tsl//tsl/profiler/utils:tpu_xplane_utils", - "@local_tsl//tsl/util:stats_calculator_portable", + "@local_xla//xla/tsl/util:stats_calculator_portable", ], ) diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc index 3e28bac0c766cf..981b6a0c54e8de 100644 --- a/tensorflow/core/profiler/utils/derived_timeline.cc +++ b/tensorflow/core/profiler/utils/derived_timeline.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/util/stats_calculator.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/gpu_event_stats.h" @@ -43,7 +44,6 @@ limitations under the License. #include "tsl/profiler/utils/tf_xplane_visitor.h" #include "tsl/profiler/utils/timespan.h" #include "tsl/profiler/utils/tpu_xplane_utils.h" -#include "tsl/util/stats_calculator.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index 6f8dce9bf2107d..c2b913a17a6a8f 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -66,7 +66,7 @@ filegroup( "padding.h", "tensor_format.cc", "tensor_format.h", - "@local_tsl//tsl/util:mobile_srcs_no_runtime", + "@local_xla//xla/tsl/util:mobile_srcs_no_runtime", ], ) @@ -132,7 +132,7 @@ filegroup( "work_sharder.h", "//tensorflow/core/config:mobile_srcs_only_runtime", "//tensorflow/core/util/quantization:mobile_srcs_only_runtime", - "@local_tsl//tsl/util:mobile_srcs_only_runtime", + "@local_xla//xla/tsl/util:mobile_srcs_only_runtime", ], ) @@ -189,7 +189,7 @@ filegroup( "util.h", "work_sharder.h", "xla_config_registry.h", - "@local_tsl//tsl/util:framework_internal_private_hdrs", + "@local_xla//xla/tsl/util:framework_internal_private_hdrs", ], ) @@ -231,7 +231,7 @@ filegroup( "util.cc", "work_sharder.cc", "xla_config_registry.cc", - "@local_tsl//tsl/util:framework_internal_impl_srcs", + "@local_xla//xla/tsl/util:framework_internal_impl_srcs", ], ) @@ -240,7 +240,7 @@ filegroup( srcs = [ "env_var.h", "use_cudnn.h", - "@local_tsl//tsl/util:lib_internal_public_hdrs", + "@local_xla//xla/tsl/util:lib_internal_public_hdrs", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -286,7 +286,7 @@ filegroup( testonly = 1, srcs = [ "reporter.h", - "@local_tsl//tsl/util:test_hdrs", + "@local_xla//xla/tsl/util:test_hdrs", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -297,7 +297,7 @@ filegroup( "mkl_heuristics.h", "mkl_util.h", "onednn_env_vars.h", - "@local_tsl//tsl/util:onednn_util_hdrs", + "@local_xla//xla/tsl/util:onednn_util_hdrs", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -316,7 +316,7 @@ filegroup( testonly = 1, srcs = [ "reporter.h", - "@local_tsl//tsl/util:android_test_hdrs", + "@local_xla//xla/tsl/util:android_test_hdrs", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -326,7 +326,7 @@ filegroup( testonly = 1, srcs = [ ":android_test_hdrs", - "@local_tsl//tsl/util:android_test_srcs", + "@local_xla//xla/tsl/util:android_test_srcs", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -367,7 +367,7 @@ filegroup( "use_cudnn.h", "util.h", "work_sharder.h", - "@local_tsl//tsl/util:framework_srcs", + "@local_xla//xla/tsl/util:framework_srcs", ], ) @@ -422,7 +422,7 @@ cc_library( "//tensorflow:internal", ], deps = [ - "@local_tsl//tsl/util:stats_calculator_portable", + "@local_xla//xla/tsl/util:stats_calculator_portable", ], ) @@ -510,7 +510,7 @@ cc_library( "//tensorflow/core/platform:mutex", "//tensorflow/core/platform:str_util", "//tensorflow/core/platform:types", - "@local_tsl//tsl/util:reporter", + "@local_xla//xla/tsl/util:reporter", ], ) @@ -641,7 +641,7 @@ cc_library( "//tensorflow/core/platform:strcat", "//tensorflow/core/platform:stringpiece", "//tensorflow/core/platform:types", - "@local_tsl//tsl/util:env_var", + "@local_xla//xla/tsl/util:env_var", ], ) @@ -666,7 +666,7 @@ cc_library( deps = [ ":env_var", "//tensorflow/core/platform:mutex", - "@local_tsl//tsl/util:determinism", + "@local_xla//xla/tsl/util:determinism", ], alwayslink = 1, ) @@ -675,7 +675,7 @@ filegroup( name = "determinism_hdr", srcs = [ "determinism.h", - "@local_tsl//tsl/util:determinism_hdr", + "@local_xla//xla/tsl/util:determinism_hdr", ], compatible_with = get_compatible_with_portable(), visibility = ["//tensorflow:__subpackages__"], @@ -688,7 +688,7 @@ cc_library( # TODO(b/298501506): narrow this in a way that won't break TFRT visibility = ["//visibility:public"], deps = [ - "@local_tsl//tsl/util:determinism_hdr_lib", + "@local_xla//xla/tsl/util:determinism_hdr_lib", ], ) diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h index cc8ca1b8f119b4..ebc58f7ee476ab 100644 --- a/tensorflow/core/util/command_line_flags.h +++ b/tensorflow/core/util/command_line_flags.h @@ -20,8 +20,8 @@ limitations under the License. #include #include +#include "xla/tsl/util/command_line_flags.h" #include "tensorflow/core/platform/types.h" -#include "tsl/util/command_line_flags.h" namespace tensorflow { using tsl::Flag; // NOLINT diff --git a/tensorflow/core/util/determinism.h b/tensorflow/core/util/determinism.h index e42fb71d42b0bc..136534ea828570 100644 --- a/tensorflow/core/util/determinism.h +++ b/tensorflow/core/util/determinism.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_UTIL_DETERMINISM_H_ #define TENSORFLOW_CORE_UTIL_DETERMINISM_H_ -#include "tsl/util/determinism.h" +#include "xla/tsl/util/determinism.h" namespace tensorflow { diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h index 20b1f21786b2f0..28b5b0f1b6e764 100644 --- a/tensorflow/core/util/device_name_utils.h +++ b/tensorflow/core/util/device_name_utils.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_UTIL_DEVICE_NAME_UTILS_H_ #define TENSORFLOW_CORE_UTIL_DEVICE_NAME_UTILS_H_ -#include "tsl/util/device_name_utils.h" +#include "xla/tsl/util/device_name_utils.h" namespace tensorflow { // NOLINTBEGIN(misc-unused-using-decls) diff --git a/tensorflow/core/util/env_var.h b/tensorflow/core/util/env_var.h index fac0e2373ad145..faad61533d648a 100644 --- a/tensorflow/core/util/env_var.h +++ b/tensorflow/core/util/env_var.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_CORE_UTIL_ENV_VAR_H_ #define TENSORFLOW_CORE_UTIL_ENV_VAR_H_ +#include "xla/tsl/util/env_var.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/types.h" -#include "tsl/util/env_var.h" namespace tensorflow { diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 7d6b1d76d78b2a..a6164cc0264518 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -42,7 +42,7 @@ limitations under the License. #if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) #include "tensorflow/core/platform/mutex.h" #endif -#include "tsl/util/onednn_threadpool.h" +#include "xla/tsl/util/onednn_threadpool.h" using dnnl::engine; using dnnl::memory; diff --git a/tensorflow/core/util/proto/BUILD b/tensorflow/core/util/proto/BUILD index be93c135384c59..1d84749d6bf523 100644 --- a/tensorflow/core/util/proto/BUILD +++ b/tensorflow/core/util/proto/BUILD @@ -72,7 +72,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:platform_base", "@com_google_absl//absl/strings", - "@local_tsl//tsl/util/proto:proto_utils", + "@local_xla//xla/tsl/util/proto:proto_utils", ], ) diff --git a/tensorflow/core/util/proto/proto_utils.h b/tensorflow/core/util/proto/proto_utils.h index 43f8c918299c65..f0347a84cbe429 100644 --- a/tensorflow/core/util/proto/proto_utils.h +++ b/tensorflow/core/util/proto/proto_utils.h @@ -17,10 +17,10 @@ limitations under the License. #define TENSORFLOW_CORE_UTIL_PROTO_PROTO_UTILS_H_ #include "absl/strings/string_view.h" +#include "xla/tsl/util/proto/proto_utils.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/protobuf.h" -#include "tsl/util/proto/proto_utils.h" namespace tensorflow { namespace proto_utils { diff --git a/tensorflow/core/util/reporter.h b/tensorflow/core/util/reporter.h index f36b7c72bfb275..2db7a6f827dc22 100644 --- a/tensorflow/core/util/reporter.h +++ b/tensorflow/core/util/reporter.h @@ -21,11 +21,11 @@ limitations under the License. #include #include +#include "xla/tsl/util/reporter.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" -#include "tsl/util/reporter.h" namespace tensorflow { diff --git a/tensorflow/core/util/stat_summarizer_options.h b/tensorflow/core/util/stat_summarizer_options.h index daf3d3b3a9c2c9..71f9bf372454f7 100644 --- a/tensorflow/core/util/stat_summarizer_options.h +++ b/tensorflow/core/util/stat_summarizer_options.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_UTIL_STAT_SUMMARIZER_OPTIONS_H_ #define TENSORFLOW_CORE_UTIL_STAT_SUMMARIZER_OPTIONS_H_ -#include "tsl/util/stat_summarizer_options.h" +#include "xla/tsl/util/stat_summarizer_options.h" namespace tensorflow { using tsl::StatSummarizerOptions; diff --git a/tensorflow/core/util/stats_calculator.h b/tensorflow/core/util/stats_calculator.h index 90605aaa9da6d7..20c997ced374a7 100644 --- a/tensorflow/core/util/stats_calculator.h +++ b/tensorflow/core/util/stats_calculator.h @@ -26,8 +26,8 @@ limitations under the License. #include #include +#include "xla/tsl/util/stats_calculator.h" #include "tensorflow/core/util/stat_summarizer_options.h" -#include "tsl/util/stats_calculator.h" namespace tensorflow { diff --git a/tensorflow/core/util/tensor_bundle/BUILD b/tensorflow/core/util/tensor_bundle/BUILD index 76a0e41669d42e..4ca9b222fb114b 100644 --- a/tensorflow/core/util/tensor_bundle/BUILD +++ b/tensorflow/core/util/tensor_bundle/BUILD @@ -62,7 +62,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/lib/io:buffered_file", - "@local_tsl//tsl/util:byte_swap_array", + "@local_xla//xla/tsl/util:byte_swap_array", ], ) @@ -86,7 +86,7 @@ cc_library( "//tensorflow/core/platform:byte_order", "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", - "@local_tsl//tsl/util:byte_swap_array", + "@local_xla//xla/tsl/util:byte_swap_array", ], ) diff --git a/tensorflow/core/util/tensor_bundle/byte_swap_array.h b/tensorflow/core/util/tensor_bundle/byte_swap_array.h index ed3b6e1445eabb..97315b12917744 100644 --- a/tensorflow/core/util/tensor_bundle/byte_swap_array.h +++ b/tensorflow/core/util/tensor_bundle/byte_swap_array.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_BYTE_SWAP_ARRAY_H_ #define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_BYTE_SWAP_ARRAY_H_ +#include "xla/tsl/util/byte_swap_array.h" #include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" -#include "tsl/util/byte_swap_array.h" namespace tensorflow { diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index f2db274c497c8a..8f58c0bdbcb6d9 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/base/call_once.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/util/byte_swap_array.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -54,7 +55,6 @@ limitations under the License. #include "tensorflow/core/util/tensor_bundle/naming.h" #include "tensorflow/core/util/tensor_slice_util.h" #include "tsl/lib/io/buffered_file.h" -#include "tsl/util/byte_swap_array.h" #ifdef PLATFORM_WINDOWS #undef DeleteFile diff --git a/tensorflow/core/util/use_cudnn.h b/tensorflow/core/util/use_cudnn.h index ac9f918b494b30..ba13b74016ce7e 100644 --- a/tensorflow/core/util/use_cudnn.h +++ b/tensorflow/core/util/use_cudnn.h @@ -20,7 +20,7 @@ limitations under the License. #include -#include "tsl/util/use_cudnn.h" +#include "xla/tsl/util/use_cudnn.h" namespace tensorflow { diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc index 6f039c85b948ff..0d48191d69b9e6 100644 --- a/tensorflow/core/util/work_sharder.cc +++ b/tensorflow/core/util/work_sharder.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include +#include "xla/tsl/util/env_var.h" #include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/logging.h" #include "tsl/profiler/lib/traceme.h" -#include "tsl/util/env_var.h" namespace tensorflow { namespace { diff --git a/tensorflow/dtensor/cc/BUILD b/tensorflow/dtensor/cc/BUILD index 7d5070507fd14f..3d28e474d680ea 100644 --- a/tensorflow/dtensor/cc/BUILD +++ b/tensorflow/dtensor/cc/BUILD @@ -43,7 +43,7 @@ cc_library( deps = [ "//tensorflow/core:lib", "@com_google_absl//absl/strings", - "@local_tsl//tsl/util:env_var", + "@local_xla//xla/tsl/util:env_var", ], ) @@ -373,11 +373,11 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/util:env_var", "@local_xla//xla:status_macros", "@local_xla//xla/stream_executor/tpu:c_api_decl", "@local_xla//xla/stream_executor/tpu:tpu_platform_interface", "@local_xla//xla/stream_executor/tpu:tpu_topology_external", + "@local_xla//xla/tsl/util:env_var", ] + tf_dtensor_tpu_dependencies(), ) diff --git a/tensorflow/dtensor/cc/dtensor_device.cc b/tensorflow/dtensor/cc/dtensor_device.cc index cc3a92ec8d1c97..d169ba4c8595b6 100644 --- a/tensorflow/dtensor/cc/dtensor_device.cc +++ b/tensorflow/dtensor/cc/dtensor_device.cc @@ -54,6 +54,7 @@ limitations under the License. #include "xla/stream_executor/tpu/c_api_decl.h" #include "xla/stream_executor/tpu/tpu_platform_interface.h" #include "xla/stream_executor/tpu/tpu_topology.h" +#include "xla/tsl/util/env_var.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" @@ -93,7 +94,6 @@ limitations under the License. #include "tensorflow/dtensor/proto/layout.pb.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/util/env_var.h" using tensorflow::EagerExecutor; diff --git a/tensorflow/dtensor/cc/dtensor_utils.cc b/tensorflow/dtensor/cc/dtensor_utils.cc index dc5f1f827befaa..f5261a719a82e6 100644 --- a/tensorflow/dtensor/cc/dtensor_utils.cc +++ b/tensorflow/dtensor/cc/dtensor_utils.cc @@ -23,8 +23,8 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" +#include "xla/tsl/util/env_var.h" #include "tensorflow/core/platform/logging.h" -#include "tsl/util/env_var.h" namespace tensorflow { namespace dtensor { diff --git a/tensorflow/dtensor/mlir/utils/BUILD b/tensorflow/dtensor/mlir/utils/BUILD index f53ca53420b72a..56620400ec950e 100644 --- a/tensorflow/dtensor/mlir/utils/BUILD +++ b/tensorflow/dtensor/mlir/utils/BUILD @@ -51,7 +51,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", - "@local_tsl//tsl/util:env_var", "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/tsl/util:env_var", ], ) diff --git a/tensorflow/dtensor/mlir/utils/collective_lowering.cc b/tensorflow/dtensor/mlir/utils/collective_lowering.cc index 07e545f5851a83..5a12d4de95dcc0 100644 --- a/tensorflow/dtensor/mlir/utils/collective_lowering.cc +++ b/tensorflow/dtensor/mlir/utils/collective_lowering.cc @@ -43,6 +43,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" +#include "xla/tsl/util/env_var.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/dtensor/cc/constants.h" @@ -59,7 +60,6 @@ limitations under the License. #include "tensorflow/dtensor/mlir/layout_parsing.h" #include "tensorflow/dtensor/mlir/spmd_expander_common.h" #include "tensorflow/dtensor/mlir/value_utils.h" -#include "tsl/util/env_var.h" namespace tensorflow { namespace dtensor { diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index a61f537365a86e..7fa2ede1f4066d 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -50,6 +50,7 @@ if(NOT TENSORFLOW_SOURCE_DIR) endif() set(TF_SOURCE_DIR "${TENSORFLOW_SOURCE_DIR}/tensorflow") set(TSL_SOURCE_DIR "${TENSORFLOW_SOURCE_DIR}/third_party/xla/third_party/tsl") +set(XLA_SOURCE_DIR "${TENSORFLOW_SOURCE_DIR}/third_party/xla/") set(TFLITE_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}") set(CMAKE_MODULE_PATH "${TFLITE_SOURCE_DIR}/tools/cmake/modules" @@ -161,6 +162,7 @@ find_package(ruy REQUIRED) # Include TSL, which is in tensorflow/third_party include_directories( ${TSL_SOURCE_DIR} + ${XLA_SOURCE_DIR} ) # Download necessary dependencies. # Download pthreadpool source package if it doesn't exist. diff --git a/tensorflow/lite/examples/label_image/CMakeLists.txt b/tensorflow/lite/examples/label_image/CMakeLists.txt index 08044b1675beb3..9874801f34fa31 100644 --- a/tensorflow/lite/examples/label_image/CMakeLists.txt +++ b/tensorflow/lite/examples/label_image/CMakeLists.txt @@ -21,7 +21,7 @@ populate_source_vars("${TFLITE_SOURCE_DIR}/examples/label_image" FILTER "_test\\.cc$" ) list(APPEND TFLITE_LABEL_IMAGE_SRCS - ${TSL_SOURCE_DIR}/tsl/util/stats_calculator.cc + ${XLA_SOURCE_DIR}/xla/tsl/util/stats_calculator.cc ${TFLITE_SOURCE_DIR}/profiling/memory_info.cc ${TFLITE_SOURCE_DIR}/profiling/profile_summarizer.cc ${TFLITE_SOURCE_DIR}/profiling/profile_summary_formatter.cc diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD index 863ca7c6b72170..1ffa19e4e7b65d 100644 --- a/tensorflow/lite/testing/BUILD +++ b/tensorflow/lite/testing/BUILD @@ -106,7 +106,7 @@ _test_size_override = { "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:subprocess", - "@local_tsl//tsl/util:command_line_flags", + "@local_xla//xla/tsl/util:command_line_flags", ], "//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib", diff --git a/tensorflow/lite/testing/generated_examples_zip_test.cc b/tensorflow/lite/testing/generated_examples_zip_test.cc index 43c0c0851d5e10..35e6737137a6ac 100644 --- a/tensorflow/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/lite/testing/generated_examples_zip_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/match.h" #include "re2/re2.h" +#include "xla/tsl/util/command_line_flags.h" #include "tensorflow/lite/string_type.h" #include "tensorflow/lite/testing/parse_testdata.h" #include "tensorflow/lite/testing/tflite_driver.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tsl/platform/env.h" #include "tsl/platform/status.h" #include "tsl/platform/subprocess.h" -#include "tsl/util/command_line_flags.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/tools/benchmark/CMakeLists.txt b/tensorflow/lite/tools/benchmark/CMakeLists.txt index 6702c1f9d352c6..fc2a1be282f985 100644 --- a/tensorflow/lite/tools/benchmark/CMakeLists.txt +++ b/tensorflow/lite/tools/benchmark/CMakeLists.txt @@ -21,7 +21,7 @@ populate_source_vars("${TFLITE_SOURCE_DIR}/tools/benchmark" FILTER "(_test|_plus_flex_main|_performance_options.*)\\.cc$" ) list(APPEND TFLITE_BENCHMARK_SRCS - ${TSL_SOURCE_DIR}/tsl/util/stats_calculator.cc + ${XLA_SOURCE_DIR}/xla/tsl/util/stats_calculator.cc ${TFLITE_SOURCE_DIR}/kernels/internal/utils/sparsity_format_converter.cc ${TFLITE_SOURCE_DIR}/profiling/memory_info.cc ${TFLITE_SOURCE_DIR}/profiling/memory_usage_monitor.cc diff --git a/tensorflow/lite/tools/benchmark/experimental/c/BUILD b/tensorflow/lite/tools/benchmark/experimental/c/BUILD index edefa5e6e35158..2a7a93671c265d 100644 --- a/tensorflow/lite/tools/benchmark/experimental/c/BUILD +++ b/tensorflow/lite/tools/benchmark/experimental/c/BUILD @@ -31,6 +31,6 @@ cc_library( deps = [ "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/tools/benchmark:benchmark_tflite_model_lib", - "@local_tsl//tsl/util:stats_calculator_portable", + "@local_xla//xla/tsl/util:stats_calculator_portable", ], ) diff --git a/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.cc b/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.cc index d2cc7f65ceb9e0..2f07561e42c416 100644 --- a/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.cc +++ b/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "xla/tsl/util/stats_calculator.h" #include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h" -#include "tsl/util/stats_calculator.h" extern "C" { diff --git a/tensorflow/lite/tools/evaluation/stages/BUILD b/tensorflow/lite/tools/evaluation/stages/BUILD index 2ebe3a67bd729e..b2eb3a2ee0cdde 100644 --- a/tensorflow/lite/tools/evaluation/stages/BUILD +++ b/tensorflow/lite/tools/evaluation/stages/BUILD @@ -49,7 +49,7 @@ cc_library( "//tensorflow/lite/tools/evaluation/proto:preprocessing_steps_cc_proto", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", - "@local_tsl//tsl/util:stats_calculator_portable", + "@local_xla//xla/tsl/util:stats_calculator_portable", ] + select({ "//tensorflow:android": [ "//tensorflow/core:portable_jpeg_internal", @@ -118,7 +118,7 @@ cc_library( "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", "@com_google_absl//absl/base:core_headers", - "@local_tsl//tsl/util:stats_calculator_portable", + "@local_xla//xla/tsl/util:stats_calculator_portable", ], ) @@ -172,7 +172,7 @@ cc_library( "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", "@FP16", - "@local_tsl//tsl/util:stats_calculator_portable", + "@local_xla//xla/tsl/util:stats_calculator_portable", ], ) diff --git a/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h b/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h index 5de1c1bcf96288..1e7d36e098fb5c 100644 --- a/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h @@ -21,12 +21,12 @@ limitations under the License. #include #include +#include "xla/tsl/util/stats_calculator.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/tools/evaluation/evaluation_stage.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/proto/preprocessing_steps.pb.h" -#include "tsl/util/stats_calculator.h" namespace tflite { namespace evaluation { diff --git a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h index f81a7c3b2a1f3d..5f1bee82d33a35 100644 --- a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h @@ -21,11 +21,11 @@ limitations under the License. #include #include +#include "xla/tsl/util/stats_calculator.h" #include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/evaluation_stage.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h" -#include "tsl/util/stats_calculator.h" namespace tflite { namespace evaluation { diff --git a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h index b9451c2ee17f99..e6f2436f739b12 100644 --- a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h @@ -19,6 +19,7 @@ limitations under the License. #include +#include "xla/tsl/util/stats_calculator.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/interpreter.h" #include "tensorflow/lite/core/kernels/register.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/evaluation_stage.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" -#include "tsl/util/stats_calculator.h" namespace tflite { namespace evaluation { diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index a1dc47c492de7b..fd25d20a55ffef 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -203,7 +203,7 @@ tf_cc_binary( "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/util:command_line_flags", + "@local_xla//xla/tsl/util:command_line_flags", ], ) diff --git a/tensorflow/python/framework/offset_counter.cc b/tensorflow/python/framework/offset_counter.cc index 4dbae6a231a7cb..09a6facfb5c6b7 100644 --- a/tensorflow/python/framework/offset_counter.cc +++ b/tensorflow/python/framework/offset_counter.cc @@ -22,13 +22,13 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/util/command_line_flags.h" #include "tensorflow/python/framework/offset_counter_helper.h" #include "tensorflow/python/framework/op_reg_offset.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/init_main.h" #include "tsl/platform/strcat.h" #include "tsl/platform/types.h" -#include "tsl/util/command_line_flags.h" inline constexpr absl::string_view kUsage = "offset_counter reads C++ source codes, scans for the location of where " diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc index 35314d604ffde4..fc6426e9e438bf 100644 --- a/tensorflow/python/framework/python_op_gen_main.cc +++ b/tensorflow/python/framework/python_op_gen_main.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/util/command_line_flags.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_gen_lib.h" @@ -39,7 +40,6 @@ limitations under the License. #include "tensorflow/python/framework/python_op_gen.h" #include "tsl/platform/errors.h" #include "tsl/platform/str_util.h" -#include "tsl/util/command_line_flags.h" namespace tensorflow { namespace { diff --git a/tensorflow/python/util/BUILD b/tensorflow/python/util/BUILD index e5c0b1b6d29025..01b7664e1c30de 100644 --- a/tensorflow/python/util/BUILD +++ b/tensorflow/python/util/BUILD @@ -162,7 +162,7 @@ tf_python_pybind_extension( "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", "@eigen_archive//:eigen3", - "@local_tsl//tsl/util:stats_calculator_portable", + "@local_xla//xla/tsl/util:stats_calculator_portable", "@pybind11", ], ) diff --git a/tensorflow/python/util/stat_summarizer_wrapper.cc b/tensorflow/python/util/stat_summarizer_wrapper.cc index e6d00ff355b829..47120b21a24ee9 100644 --- a/tensorflow/python/util/stat_summarizer_wrapper.cc +++ b/tensorflow/python/util/stat_summarizer_wrapper.cc @@ -18,11 +18,11 @@ limitations under the License. #include "absl/memory/memory.h" #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 +#include "xla/tsl/util/stat_summarizer_options.h" +#include "xla/tsl/util/stats_calculator.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/util/stat_summarizer.h" -#include "tsl/util/stat_summarizer_options.h" -#include "tsl/util/stats_calculator.h" namespace py = pybind11; diff --git a/tensorflow/tools/android/inference_interface/jni/run_stats_jni.h b/tensorflow/tools/android/inference_interface/jni/run_stats_jni.h index 8bbd38692b13c4..7020f63ad449e4 100644 --- a/tensorflow/tools/android/inference_interface/jni/run_stats_jni.h +++ b/tensorflow/tools/android/inference_interface/jni/run_stats_jni.h @@ -18,8 +18,8 @@ limitations under the License. #include +#include "xla/tsl/util/stats_calculator.h" #include "tensorflow/core/util/stat_summarizer.h" -#include "tsl/util/stats_calculator.h" #ifdef __cplusplus extern "C" { diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/BUILD b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/BUILD index c5686096428dba..a95081f56bd160 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/BUILD +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/BUILD @@ -78,13 +78,13 @@ tsl_gpu_library( "//tsl/platform:thread_annotations", "//tsl/protobuf:coordination_config_proto_cc", "//tsl/protobuf:coordination_service_proto_cc", - "//tsl/util:device_name_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + "@local_xla//xla/tsl/util:device_name_utils", ], alwayslink = 1, ) diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.cc index 8bb30c31cd798c..45dfb972a131f5 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/coordination/coordination_service.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/notification.h" #include "absl/time/time.h" +#include "xla/tsl/util/device_name_utils.h" #include "tsl/distributed_runtime/call_options.h" #include "tsl/distributed_runtime/coordination/coordination_client.h" #include "tsl/distributed_runtime/coordination/coordination_service_error_util.h" @@ -45,7 +46,6 @@ limitations under the License. #include "tsl/platform/thread_annotations.h" #include "tsl/protobuf/coordination_config.pb.h" #include "tsl/protobuf/coordination_service.pb.h" -#include "tsl/util/device_name_utils.h" namespace tsl { namespace { diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/BUILD b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/BUILD index cc0a37aa4d0009..205029043be73f 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/BUILD +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/BUILD @@ -95,8 +95,8 @@ cc_library( "//tsl/platform:thread_annotations", "//tsl/platform:types", "//tsl/protobuf:rpc_options_proto_cc", - "//tsl/util:device_name_utils", "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/util:device_name_utils", ] + tsl_grpc_cc_dependencies(), ) @@ -114,7 +114,7 @@ tsl_cc_test( "//tsl/platform:test", "//tsl/platform:test_main", "//tsl/protobuf:rpc_options_proto_cc_impl", - "//tsl/util:device_name_utils", + "@local_xla//xla/tsl/util:device_name_utils", ], ) @@ -129,8 +129,8 @@ cc_library( "//tsl/platform:errors", "//tsl/platform:status", "//tsl/platform:strcat", - "//tsl/util:env_var", "@com_google_absl//absl/status", + "@local_xla//xla/tsl/util:env_var", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.cc b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.cc index 492c984e12f13a..ba12449f03bf2f 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.cc +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/escaping.h" #include "absl/strings/match.h" #include "absl/strings/str_split.h" +#include "xla/tsl/util/device_name_utils.h" #include "tsl/distributed_runtime/rpc/grpc_channel_common.h" #include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tsl/platform/thread_annotations.h" #include "tsl/platform/types.h" #include "tsl/protobuf/rpc_options.pb.h" -#include "tsl/util/device_name_utils.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel_test.cc b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel_test.cc index 6b2d330cb1d57a..adc0df2b89ddef 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel_test.cc +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_channel_test.cc @@ -18,11 +18,11 @@ limitations under the License. #include #include +#include "xla/tsl/util/device_name_utils.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/strcat.h" #include "tsl/platform/test.h" #include "tsl/protobuf/rpc_options.pb.h" -#include "tsl/util/device_name_utils.h" namespace tsl { #define IsSameAddrSp DeviceNameUtils::IsSameAddressSpace diff --git a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_state.h b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_state.h index 893e1b0192f694..21d8f2df5099e3 100644 --- a/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_state.h +++ b/third_party/xla/third_party/tsl/tsl/distributed_runtime/rpc/grpc_state.h @@ -23,6 +23,7 @@ limitations under the License. #include "grpcpp/generic/generic_stub.h" #include "grpcpp/grpcpp.h" #include "absl/status/status.h" +#include "xla/tsl/util/env_var.h" #include "tsl/distributed_runtime/call_options.h" #include "tsl/distributed_runtime/rpc/grpc_client_cq_tag.h" #include "tsl/distributed_runtime/rpc/grpc_util.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tsl/platform/status.h" #include "tsl/platform/strcat.h" #include "tsl/platform/threadpool.h" -#include "tsl/util/env_var.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/framework/BUILD b/third_party/xla/third_party/tsl/tsl/framework/BUILD index cfa12cab82e00e..bca669a616f0ec 100644 --- a/third_party/xla/third_party/tsl/tsl/framework/BUILD +++ b/third_party/xla/third_party/tsl/tsl/framework/BUILD @@ -273,10 +273,10 @@ cc_library( "//tsl/platform:status", "//tsl/platform:statusor", "//tsl/platform:str_util", - "//tsl/util:device_name_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/util:device_name_utils", ], ) @@ -465,7 +465,7 @@ tsl_cc_test( "//tsl/platform:status_matchers", "//tsl/platform:test_main", "//tsl/protobuf:error_codes_proto_impl_cc", - "//tsl/util:device_name_utils", + "@local_xla//xla/tsl/util:device_name_utils", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/framework/device_id_utils.h b/third_party/xla/third_party/tsl/tsl/framework/device_id_utils.h index c2479aded5fe0a..e814e68c8530a8 100644 --- a/third_party/xla/third_party/tsl/tsl/framework/device_id_utils.h +++ b/third_party/xla/third_party/tsl/tsl/framework/device_id_utils.h @@ -20,11 +20,11 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "xla/tsl/util/device_name_utils.h" #include "tsl/framework/device_id.h" #include "tsl/framework/device_type.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/util/device_name_utils.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/framework/device_id_utils_test.cc b/third_party/xla/third_party/tsl/tsl/framework/device_id_utils_test.cc index ddf7cdd479935b..21e574f95c1b2c 100644 --- a/third_party/xla/third_party/tsl/tsl/framework/device_id_utils_test.cc +++ b/third_party/xla/third_party/tsl/tsl/framework/device_id_utils_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include +#include "xla/tsl/util/device_name_utils.h" #include "tsl/framework/device_id_manager.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" -#include "tsl/util/device_name_utils.h" namespace tsl { namespace { diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD index c9067c74c8e526..21e13663a4b4cf 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD @@ -217,8 +217,8 @@ cc_library( "//tsl/platform:str_util", "//tsl/platform:stringpiece", "//tsl/platform:types", - "//tsl/util:env_var", "@curl", + "@local_xla//xla/tsl/util:env_var", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc index a7e6a65e37335d..c41f967c04b055 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc @@ -17,13 +17,13 @@ limitations under the License. #include +#include "xla/tsl/util/env_var.h" #include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" #include "tsl/platform/scanner.h" #include "tsl/platform/str_util.h" #include "tsl/platform/types.h" -#include "tsl/util/env_var.h" #define CHECK_CURL_OK(expr) CHECK_EQ(expr, CURLE_OK) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD index e6f8b25a83809b..84f6a1439711aa 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD @@ -149,8 +149,8 @@ cc_library( "//tsl/platform:errors", "//tsl/platform:macros", "//tsl/platform:statusor", - "//tsl/util:env_var", "@com_google_absl//absl/strings:string_view", + "@local_xla//xla/tsl/util:env_var", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.cc index e99db5ae366969..325713117a333a 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.cc @@ -16,10 +16,10 @@ limitations under the License. #include +#include "xla/tsl/util/env_var.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" #include "tsl/platform/statusor.h" -#include "tsl/util/env_var.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD index 24d0417b13c652..41d669aedbd103 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD @@ -213,12 +213,12 @@ cc_library( "//tsl/platform:types", "//tsl/profiler/lib:context_types", "//tsl/profiler/protobuf:xplane_proto_cc", - "//tsl/util:stats_calculator_portable", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/util:stats_calculator_portable", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.cc index 333b93743ae64f..88c7e30b76eee5 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" +#include "xla/tsl/util/stats_calculator.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/types.h" #include "tsl/profiler/lib/context_types.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tsl/profiler/utils/xplane_builder.h" #include "tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/utils/xplane_visitor.h" -#include "tsl/util/stats_calculator.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/BUILD b/third_party/xla/third_party/tsl/tsl/protobuf/BUILD index db185588785a46..557d9fcf2208e6 100644 --- a/third_party/xla/third_party/tsl/tsl/protobuf/BUILD +++ b/third_party/xla/third_party/tsl/tsl/protobuf/BUILD @@ -105,7 +105,7 @@ tf_proto_library( make_default_target_header_only = True, visibility = internal_visibility([ "//tensorflow/core:__subpackages__", - "//tsl/util:__pkg__", + "@local_xla//xla/tsl/util:__pkg__", ]), ) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index dd5858b633f826..fc4d67a6fe564f 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -604,6 +604,7 @@ cc_library( ":types", ":util", ":xla_data_proto_cc", + "//xla/tsl/util:byte_swap_array", "@com_google_absl//absl/base", "@com_google_absl//absl/base:config", "@com_google_absl//absl/base:core_headers", @@ -622,7 +623,6 @@ cc_library( "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/util:byte_swap_array", ], ) @@ -1073,12 +1073,12 @@ cc_library( deps = [ ":types", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/util:command_line_flags", ], ) @@ -1088,12 +1088,12 @@ xla_cc_test( deps = [ ":parse_flags_from_env", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:subprocess", "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/util:command_line_flags", ], ) @@ -1111,6 +1111,7 @@ cc_library( ":parse_flags_from_env", ":xla_proto_cc", "//xla/stream_executor/cuda:ptx_compiler_support", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", @@ -1119,7 +1120,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/util:command_line_flags", ], ) @@ -1255,4 +1255,5 @@ cc_library( alias( name = "bazel_issue_21519", actual = ":empty", + visibility = ["//visibility:public"], ) diff --git a/third_party/xla/xla/backends/profiler/gpu/BUILD b/third_party/xla/xla/backends/profiler/gpu/BUILD index 0d259e573afdec..5ea81b05c3276d 100644 --- a/third_party/xla/xla/backends/profiler/gpu/BUILD +++ b/third_party/xla/xla/backends/profiler/gpu/BUILD @@ -42,6 +42,7 @@ tsl_gpu_library( ], deps = [ ":cupti_utils", + "//xla/tsl/util:env_var", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -52,7 +53,6 @@ tsl_gpu_library( "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_tsl//tsl/profiler/utils:time_utils", - "@local_tsl//tsl/util:env_var", ], alwayslink = 1, ) @@ -329,7 +329,7 @@ tsl_gpu_library( "@com_google_absl//absl/memory", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:stringpiece", - "@local_tsl//tsl/util:env_var", + "//xla/tsl/util:env_var", ], visibility = ["//visibility:public"], alwayslink = 1, diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_utils.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_utils.cc index ee9a542485a48c..a4198811286bed 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_utils.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_utils.cc @@ -17,9 +17,9 @@ limitations under the License. #include "xla/backends/profiler/gpu/cupti_error_manager.h" #include "xla/backends/profiler/gpu/cupti_interface.h" #include "xla/backends/profiler/gpu/cupti_wrapper.h" +#include "xla/tsl/util/env_var.h" #include "tsl/platform/logging.h" #include "tsl/platform/stringpiece.h" -#include "tsl/util/env_var.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc b/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc index 70530221b68123..d7bb2524b66762 100644 --- a/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc +++ b/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/backends/profiler/gpu/cupti_collector.h" #include "xla/backends/profiler/gpu/cupti_tracer.h" #include "xla/backends/profiler/gpu/cupti_wrapper.h" +#include "xla/tsl/util/env_var.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" #include "tsl/platform/thread_annotations.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/xplane.pb.h" #include "tsl/profiler/utils/time_utils.h" -#include "tsl/util/env_var.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/device_tracer_rocm.cc b/third_party/xla/xla/backends/profiler/gpu/device_tracer_rocm.cc index 6ce64f81b68657..81eb2d192ea09a 100644 --- a/third_party/xla/xla/backends/profiler/gpu/device_tracer_rocm.cc +++ b/third_party/xla/xla/backends/profiler/gpu/device_tracer_rocm.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "xla/backends/profiler/gpu/rocm_collector.h" #include "xla/backends/profiler/gpu/rocm_tracer.h" +#include "xla/tsl/util/env_var.h" #include "tsl/platform/abi.h" #include "tsl/platform/env_time.h" #include "tsl/platform/errors.h" @@ -39,7 +40,6 @@ limitations under the License. #include "tsl/profiler/utils/xplane_builder.h" #include "tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/util/env_var.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc index f28b91b5dd9e27..41b21c486eb340 100644 --- a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc +++ b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "xla/stream_executor/rocm/roctracer_wrapper.h" +#include "xla/tsl/util/env_var.h" #include "tsl/platform/abi.h" #include "tsl/platform/env_time.h" #include "tsl/platform/errors.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tsl/profiler/utils/xplane_builder.h" #include "tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/util/env_var.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index f4f9f865e4d889..8b29c823b1d921 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -36,9 +36,9 @@ limitations under the License. #include "xla/debug_options_parsers.h" #include "xla/parse_flags_from_env.h" #include "xla/stream_executor/cuda/ptx_compiler_support.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/xla.pb.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep -#include "tsl/util/command_line_flags.h" namespace xla { diff --git a/third_party/xla/xla/debug_options_flags.h b/third_party/xla/xla/debug_options_flags.h index 15cc0fb8b448f5..4bc8420441af1e 100644 --- a/third_party/xla/xla/debug_options_flags.h +++ b/third_party/xla/xla/debug_options_flags.h @@ -19,9 +19,9 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/xla.pb.h" #include "tsl/platform/logging.h" -#include "tsl/util/command_line_flags.h" namespace xla { diff --git a/third_party/xla/xla/literal.cc b/third_party/xla/xla/literal.cc index 935e6092ac07c8..1f3768c1cc1a2b 100644 --- a/third_party/xla/xla/literal.cc +++ b/third_party/xla/xla/literal.cc @@ -49,6 +49,7 @@ limitations under the License. #include "xla/status.h" #include "xla/status_macros.h" #include "xla/statusor.h" +#include "xla/tsl/util/byte_swap_array.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -59,7 +60,6 @@ limitations under the License. #include "tsl/platform/ml_dtypes.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/util/byte_swap_array.h" namespace xla { namespace { diff --git a/third_party/xla/xla/parse_flags_from_env.cc b/third_party/xla/xla/parse_flags_from_env.cc index 84ca13de4487cb..0f58671ebff7df 100644 --- a/third_party/xla/xla/parse_flags_from_env.cc +++ b/third_party/xla/xla/parse_flags_from_env.cc @@ -32,8 +32,8 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/logging.h" -#include "tsl/util/command_line_flags.h" namespace xla { diff --git a/third_party/xla/xla/parse_flags_from_env.h b/third_party/xla/xla/parse_flags_from_env.h index e73a38421faf97..01d476f22fa3dc 100644 --- a/third_party/xla/xla/parse_flags_from_env.h +++ b/third_party/xla/xla/parse_flags_from_env.h @@ -51,8 +51,8 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/types.h" -#include "tsl/util/command_line_flags.h" namespace xla { diff --git a/third_party/xla/xla/parse_flags_from_env_test.cc b/third_party/xla/xla/parse_flags_from_env_test.cc index 01e82889130a29..f00cb309c12a96 100644 --- a/third_party/xla/xla/parse_flags_from_env_test.cc +++ b/third_party/xla/xla/parse_flags_from_env_test.cc @@ -24,11 +24,11 @@ limitations under the License. #include #include "absl/strings/str_format.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" #include "tsl/platform/subprocess.h" #include "tsl/platform/test.h" -#include "tsl/util/command_line_flags.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 2620ae166bb7b9..d6f5bfe7af8fcf 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -26,10 +26,10 @@ cc_library( "//xla/service:platform_util", "//xla/stream_executor", "//xla/stream_executor/integrations:device_mem_allocator", + "//xla/tsl/util:env_var", "@com_google_absl//absl/types:span", "@local_tsl//tsl/framework:bfc_allocator", "@local_tsl//tsl/framework:device_id_impl", - "@local_tsl//tsl/util:env_var", ], ) @@ -86,6 +86,7 @@ cc_library( "//xla/stream_executor:platform", "//xla/stream_executor/integrations:device_mem_allocator", "//xla/stream_executor/integrations:tf_allocator_adapter", + "//xla/tsl/util:env_var", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -111,7 +112,6 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/profiler/lib:connected_traceme", "@local_tsl//tsl/profiler/lib:traceme", - "@local_tsl//tsl/util:env_var", ] + if_cuda_or_rocm([ ":nccl_id_store", "//xla/service/gpu:gpu_compiler", diff --git a/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc b/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc index efd5fbf8c17f60..c9c6fe4da4dcef 100644 --- a/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc +++ b/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc @@ -27,9 +27,9 @@ limitations under the License. #include "xla/statusor.h" #include "xla/stream_executor/integrations/device_host_allocator.h" #include "xla/stream_executor/integrations/device_mem_allocator.h" +#include "xla/tsl/util/env_var.h" #include "xla/util.h" #include "tsl/framework/device_id.h" -#include "tsl/util/env_var.h" namespace xla { diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 6a72a4431f47c0..79c7d1051897b9 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -7459,12 +7459,12 @@ xla_cc_binary( ":cpu_plugin", "//xla:status", "//xla/tools:xla_compile_lib", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/util:command_line_flags", ] + if_cuda_is_configured([ "//xla/service/gpu:executable_proto_cc", "//xla/service/gpu:gpu_compiler", @@ -7654,10 +7654,10 @@ cc_library( "//xla:statusor", "//xla:util", "//xla:xla_proto_cc", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/util:command_line_flags", ], ) diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 04d5a6bff995ab..a69aafc5c113f3 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -1602,7 +1602,7 @@ cc_library( srcs = ["onednn_matmul.cc"], hdrs = [ "onednn_matmul.h", - "@local_tsl//tsl/util:onednn_util_hdrs", + "//xla/tsl/util:onednn_util_hdrs", ], copts = runtime_copts() + tsl_copts(), visibility = ["//visibility:public"], @@ -1627,7 +1627,7 @@ cc_library( srcs = ["onednn_layer_norm.cc"], hdrs = [ "onednn_layer_norm.h", - "@local_tsl//tsl/util:onednn_util_hdrs", + "//xla/tsl/util:onednn_util_hdrs", ], copts = runtime_copts() + tsl_copts(), visibility = ["//visibility:public"], @@ -1650,7 +1650,7 @@ cc_library( srcs = ["onednn_softmax.cc"], hdrs = [ "onednn_softmax.h", - "@local_tsl//tsl/util:onednn_util_hdrs", + "//xla/tsl/util:onednn_util_hdrs", ], copts = runtime_copts() + tsl_copts(), visibility = ["//visibility:public"], @@ -1680,7 +1680,7 @@ cc_library( hdrs = [ "onednn_matmul.h", "onednn_matmul_rewriter.h", - "@local_tsl//tsl/util:onednn_util_hdrs", + "//xla/tsl/util:onednn_util_hdrs", ], copts = tsl_copts(), deps = [ diff --git a/third_party/xla/xla/service/cpu/onednn_layer_norm.cc b/third_party/xla/xla/service/cpu/onednn_layer_norm.cc index 1d42f0290ee839..d2109a1bc2f956 100644 --- a/third_party/xla/xla/service/cpu/onednn_layer_norm.cc +++ b/third_party/xla/xla/service/cpu/onednn_layer_norm.cc @@ -30,7 +30,7 @@ limitations under the License. #include "xla/service/cpu/backend_config.pb.h" #include "xla/service/cpu/onednn_memory_util.h" #include "xla/service/cpu/runtime_lightweight_check.h" -#include "tsl/util/onednn_threadpool.h" +#include "xla/tsl/util/onednn_threadpool.h" #include "unsupported/Eigen/CXX11/Tensor" namespace xla { diff --git a/third_party/xla/xla/service/cpu/onednn_matmul.cc b/third_party/xla/xla/service/cpu/onednn_matmul.cc index 3686827198df7a..4c01c732a96da9 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul.cc @@ -35,8 +35,8 @@ limitations under the License. #include "xla/service/cpu/runtime_lightweight_check.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/util/onednn_threadpool.h" #include "tsl/platform/logging.h" -#include "tsl/util/onednn_threadpool.h" namespace xla { namespace cpu { diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc index e08cd7b6c7118a..c9b9e9b4a04a41 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc @@ -32,8 +32,8 @@ limitations under the License. #include "xla/service/hlo_cost_analysis.h" #include "xla/service/pattern_matcher.h" #include "xla/status_macros.h" +#include "xla/tsl/util/onednn_threadpool.h" #include "tsl/platform/logging.h" // IWYU pragma: keep -#include "tsl/util/onednn_threadpool.h" namespace xla { namespace cpu { diff --git a/third_party/xla/xla/service/cpu/onednn_softmax.cc b/third_party/xla/xla/service/cpu/onednn_softmax.cc index 18efb700eb3fde..5af6de54078596 100644 --- a/third_party/xla/xla/service/cpu/onednn_softmax.cc +++ b/third_party/xla/xla/service/cpu/onednn_softmax.cc @@ -36,7 +36,7 @@ limitations under the License. #include "xla/service/cpu/backend_config.pb.h" #include "xla/service/cpu/onednn_memory_util.h" #include "xla/service/cpu/runtime_lightweight_check.h" -#include "tsl/util/onednn_threadpool.h" +#include "xla/tsl/util/onednn_threadpool.h" #include "unsupported/Eigen/CXX11/Tensor" namespace xla { diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 2e2d4b943ad100..1b8117cf702373 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -757,7 +757,7 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/util/proto:proto_utils", + "//xla/tsl/util/proto:proto_utils", ]), ) @@ -1786,6 +1786,7 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor/gpu:redzone_allocator", + "//xla/tsl/util/proto:proto_utils", "//xla:util", "//xla:autotuning_proto_cc", "//xla:shape_util", @@ -1793,7 +1794,6 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/util/proto:proto_utils", ]), ) @@ -2133,7 +2133,7 @@ cc_library( "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/util/proto:proto_utils", + "//xla/tsl/util/proto:proto_utils", ]), ) @@ -3960,7 +3960,7 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", - "@local_tsl//tsl/util:env_var", + "//xla/tsl/util:env_var", ]), ) @@ -4433,6 +4433,8 @@ cc_library( "//xla/service:hlo_module_config", "//xla/stream_executor", "//xla/stream_executor:launch_dim", + "//xla/tsl/util:env_var", + "//xla/tsl/util/proto:proto_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -4448,8 +4450,6 @@ cc_library( "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/util:env_var", - "@local_tsl//tsl/util/proto:proto_utils", ], ) @@ -4460,10 +4460,10 @@ xla_cc_test( ":stream_executor_util", "//xla:autotuning_proto_cc", "//xla/service:hlo_module_config", + "//xla/tsl/util/proto:proto_utils", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/util/proto:proto_utils", ], ) diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc index 54bde3ac33e147..d7fa93b47e5332 100644 --- a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc @@ -64,13 +64,13 @@ limitations under the License. #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/util/proto/proto_utils.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/numbers.h" #include "tsl/platform/statusor.h" -#include "tsl/util/proto/proto_utils.h" #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) #include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: keep diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc index 2efb705ab3926e..515a05f2ffb9d8 100644 --- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc @@ -48,11 +48,11 @@ limitations under the License. #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/scratch_allocator.h" +#include "xla/tsl/util/proto/proto_utils.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" -#include "tsl/util/proto/proto_utils.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/service/gpu/buffer_comparator.h" diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc index f5d934455f818c..58615453d7ea65 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc @@ -80,6 +80,7 @@ limitations under the License. #include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/tools/hlo_decomposer.h" +#include "xla/tsl/util/proto/proto_utils.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" @@ -89,7 +90,6 @@ limitations under the License. #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" -#include "tsl/util/proto/proto_utils.h" // Log levels used in this file: // VLOG(1): Overview diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD index 5731422b1c87eb..d1b47ba1bbe021 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD @@ -38,6 +38,7 @@ cc_library( "//xla/service/llvm_ir:llvm_command_line_options", "//xla/service/llvm_ir:llvm_type_conversion_util", "//xla/stream_executor:device_description", + "//xla/tsl/util:env_var", "@com_google_absl//absl/base", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -66,7 +67,6 @@ cc_library( "@local_tsl//tsl/platform:random", "@local_tsl//tsl/platform:rocm_rocdl_path", "@local_tsl//tsl/profiler/lib:traceme", - "@local_tsl//tsl/util:env_var", ] + if_rocm_is_configured([ "@local_config_rocm//rocm:rocm_headers", "@llvm-project//llvm:AMDGPUCodeGen", diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 3b87df5b729004..6c6a6c20fbe27f 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -61,6 +61,7 @@ limitations under the License. #include "xla/service/llvm_ir/llvm_type_conversion_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" +#include "xla/tsl/util/env_var.h" #include "xla/types.h" #include "xla/util.h" #include "tsl/platform/cuda_libdevice_path.h" @@ -70,7 +71,6 @@ limitations under the License. #include "tsl/platform/random.h" #include "tsl/platform/rocm_rocdl_path.h" #include "tsl/profiler/lib/traceme.h" -#include "tsl/util/env_var.h" #if !defined(PLATFORM_GOOGLE) && TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 45084afa659674..fe2c14ab9ac2c8 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -744,6 +744,7 @@ cc_library( "//xla/service:hlo_runner", "//xla/service:platform_util", "//xla/stream_executor:device_description", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -751,7 +752,6 @@ cc_library( "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/util:command_line_flags", ], ) for sm in [ diff --git a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_run.cc b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_run.cc index 38479bbc982a38..b71dc91505dbd6 100644 --- a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_run.cc +++ b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_run.cc @@ -28,12 +28,12 @@ limitations under the License. #include "xla/service/hlo_runner.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/init_main.h" #include "tsl/platform/path.h" #include "tsl/platform/status.h" -#include "tsl/util/command_line_flags.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index 4a29483ce03091..8620bbbb8c0dbb 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -100,6 +100,7 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/util/env_var.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "tsl/platform/env.h" @@ -109,7 +110,6 @@ limitations under the License. #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" #include "tsl/profiler/lib/traceme.h" -#include "tsl/util/env_var.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.cc b/third_party/xla/xla/service/gpu/stream_executor_util.cc index b6ae3f29182859..4827fc251ce087 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.cc +++ b/third_party/xla/xla/service/gpu/stream_executor_util.cc @@ -56,12 +56,12 @@ limitations under the License. #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/util/env_var.h" +#include "xla/tsl/util/proto/proto_utils.h" #include "xla/util.h" #include "tsl/platform/ml_dtypes.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/util/env_var.h" -#include "tsl/util/proto/proto_utils.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/stream_executor_util_test.cc b/third_party/xla/xla/service/gpu/stream_executor_util_test.cc index 558de0ec3604ba..cb3be24a6ceaa6 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util_test.cc +++ b/third_party/xla/xla/service/gpu/stream_executor_util_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/time/time.h" #include "xla/autotuning.pb.h" #include "xla/service/hlo_module_config.h" -#include "tsl/util/proto/proto_utils.h" +#include "xla/tsl/util/proto/proto_utils.h" namespace xla::gpu { namespace { diff --git a/third_party/xla/xla/service/gpu_compilation_environment.cc b/third_party/xla/xla/service/gpu_compilation_environment.cc index d6551239db2a99..d598c02df3d5f6 100644 --- a/third_party/xla/xla/service/gpu_compilation_environment.cc +++ b/third_party/xla/xla/service/gpu_compilation_environment.cc @@ -25,11 +25,11 @@ limitations under the License. #include "xla/service/compilation_environments.h" #include "xla/status.h" #include "xla/statusor.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" -#include "tsl/util/command_line_flags.h" namespace xla { diff --git a/third_party/xla/xla/service/xla_compile_main.cc b/third_party/xla/xla/service/xla_compile_main.cc index 929ae0ec2888d1..b7f97fd800f8af 100644 --- a/third_party/xla/xla/service/xla_compile_main.cc +++ b/third_party/xla/xla/service/xla_compile_main.cc @@ -22,10 +22,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/status.h" #include "xla/tools/xla_compile_lib.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/init_main.h" #include "tsl/platform/types.h" -#include "tsl/util/command_line_flags.h" - namespace xla { namespace xla_compile { diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 57163b1845b107..194a0182bdab22 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -566,6 +566,7 @@ cc_library( ":platform", ":stream_executor_headers", ":stream_executor_internal", + "//xla/tsl/util:env_var", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", @@ -582,7 +583,6 @@ cc_library( "@local_tsl//tsl/platform:stacktrace", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/util:env_var", ], ) diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 4c851dece33b95..1403c0dca73efe 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -396,6 +396,7 @@ cuda_only_cc_library( "//xla/stream_executor/gpu:gpu_timer_header", "//xla/stream_executor/platform", "//xla/tsl/cuda:cudnn", + "//xla/tsl/util:env_var", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -420,7 +421,6 @@ cuda_only_cc_library( "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_tsl//tsl/platform:tensor_float_32_utils", "@local_tsl//tsl/protobuf:dnn_proto_cc", - "@local_tsl//tsl/util:env_var", ], alwayslink = True, ) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index 3514df43929e7a..1f8ce927570e47 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -69,13 +69,13 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" +#include "xla/tsl/util/env_var.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/tensor_float_32_utils.h" #include "tsl/protobuf/dnn.pb.h" -#include "tsl/util/env_var.h" // clang-format off #include "third_party/gpus/cuda/include/library_types.h" diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index f90bf78da0cc23..cfe260ce180fb3 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -445,6 +445,7 @@ gpu_only_cc_library( "//xla/stream_executor/cuda:ptx_compiler", "//xla/stream_executor/cuda:ptx_compiler_support", "//xla/stream_executor/platform", + "//xla/tsl/util:env_var", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", @@ -469,7 +470,6 @@ gpu_only_cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:subprocess", - "@local_tsl//tsl/util:env_var", ] + if_cuda_is_configured([ "//xla/stream_executor/cuda:cuda_asm_compiler", "//xla/stream_executor/cuda:cuda_driver", @@ -591,6 +591,7 @@ tsl_gpu_library( deps = [ ":gpu_init_impl", "//xla/stream_executor:stream_executor_headers", + "//xla/tsl/util:env_var", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -598,7 +599,6 @@ tsl_gpu_library( "@local_tsl//tsl/framework:device_id", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:mutex", - "@local_tsl//tsl/util:env_var", ], ) diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc index a97ebdd56b3e77..c7a7ad403e3aa7 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc @@ -32,11 +32,11 @@ limitations under the License. #include "absl/strings/str_join.h" #include "xla/stream_executor/gpu/gpu_init.h" // IWYU pragma: keep #include "xla/stream_executor/stream_executor.h" // IWYU pragma: keep +#include "xla/tsl/util/env_var.h" // IWYU pragma: keep #include "tsl/framework/allocator.h" #include "tsl/framework/device_id.h" #include "tsl/platform/logging.h" #include "tsl/platform/mutex.h" -#include "tsl/util/env_var.h" // IWYU pragma: keep namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index 49265995ff5beb..c2ad1f56fb9974 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -229,7 +229,7 @@ cc_library( "//xla/stream_executor/platform", "//xla/stream_executor/platform:dso_loader", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/util:determinism_for_kernels", + "//xla/tsl/util:determinism_for_kernels", ]), alwayslink = True, ) @@ -348,8 +348,8 @@ cc_library( "@com_google_absl//absl/types:span", "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/util:env_var", - "@local_tsl//tsl/util:determinism_for_kernels", + "//xla/tsl/util:env_var", + "//xla/tsl/util:determinism_for_kernels", ]), alwayslink = True, ) diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc b/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc index eae41d7d584851..5de18b8557094c 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc @@ -40,8 +40,8 @@ limitations under the License. #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/util/determinism.h" #include "tsl/platform/logging.h" -#include "tsl/util/determinism.h" using tsl::OpDeterminismRequired; namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc index d3030a0ad7a063..28521b319d54d0 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc @@ -43,12 +43,12 @@ limitations under the License. #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/util/determinism.h" +#include "xla/tsl/util/env_var.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/hash.h" #include "tsl/platform/logging.h" -#include "tsl/util/determinism.h" -#include "tsl/util/env_var.h" namespace { diff --git a/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc b/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc index e7ff3bb831df8a..12d2d0a9b33c25 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc +++ b/third_party/xla/xla/stream_executor/stream_executor_pimpl.cc @@ -54,12 +54,12 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_internal.h" +#include "xla/tsl/util/env_var.h" #include "tsl/platform/errors.h" #include "tsl/platform/numbers.h" #include "tsl/platform/stacktrace.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/util/env_var.h" namespace stream_executor { namespace { diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index 601c27f26f0b09..4b11cb0362ab5c 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -55,6 +55,7 @@ xla_cc_binary( srcs = ["hex_floats_to_packed_literal.cc"], deps = [ "//xla:types", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", "@local_tsl//tsl/lib/io:buffered_inputstream", @@ -63,7 +64,6 @@ xla_cc_binary( "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/util:command_line_flags", ], ) @@ -228,11 +228,11 @@ xla_cc_binary( "//xla:statusor", "//xla:util", "//xla/service:hlo_proto_cc", + "//xla/tsl/util:command_line_flags", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/util:command_line_flags", ], ) @@ -297,13 +297,13 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_pass_pipeline", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/util:command_line_flags", ], ) @@ -323,7 +323,7 @@ cc_library( "//xla/service:sharding_propagation", "//xla/service:triangular_solve_expander", "//xla/service/spmd:stateful_rng_spmd_partitioner", - "@local_tsl//tsl/util:command_line_flags", + "//xla/tsl/util:command_line_flags", ], ) @@ -397,6 +397,7 @@ xla_cc_binary( "//xla/service:hlo_runner", "//xla/service:local_service", "//xla/service:platform_util", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", @@ -405,7 +406,6 @@ xla_cc_binary( "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:subprocess", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", - "@local_tsl//tsl/util:command_line_flags", ] + if_cuda_or_rocm([ "//xla/service:gpu_plugin", ]) + if_cuda([ @@ -562,13 +562,13 @@ xla_cc_binary( "//xla/service:hlo_runner", "//xla/service:interpreter_plugin", "//xla/service:platform_util", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/util:command_line_flags", ] + if_cuda_or_rocm([ "//xla/service:gpu_plugin", ]) + if_cuda([ @@ -663,6 +663,7 @@ xla_cc_binary( "//xla:status", "//xla/hlo/ir:hlo", "//xla/service:hlo_proto_cc", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -672,7 +673,6 @@ xla_cc_binary( "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/util:command_line_flags", ], ) diff --git a/third_party/xla/xla/tools/extract_collective_operations.cc b/third_party/xla/xla/tools/extract_collective_operations.cc index cc579b7bd445fc..1a15fa8fbc1fe6 100644 --- a/third_party/xla/xla/tools/extract_collective_operations.cc +++ b/third_party/xla/xla/tools/extract_collective_operations.cc @@ -29,12 +29,12 @@ limitations under the License. #include "xla/status.h" #include "xla/tools/hlo_decomposer.h" #include "xla/tools/hlo_module_loader.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/env.h" #include "tsl/platform/init_main.h" #include "tsl/platform/path.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/util/command_line_flags.h" namespace { const char* const kUsage = R"( diff --git a/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc b/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc index a845dabe02b3a1..c4d591ba34928a 100644 --- a/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc +++ b/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/strings/string_view.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/types.h" #include "tsl/lib/io/buffered_inputstream.h" #include "tsl/lib/io/random_inputstream.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" -#include "tsl/util/command_line_flags.h" using std::string; diff --git a/third_party/xla/xla/tools/hlo_bisect/BUILD b/third_party/xla/xla/tools/hlo_bisect/BUILD index f0dce33a38533a..732384eca9dada 100644 --- a/third_party/xla/xla/tools/hlo_bisect/BUILD +++ b/third_party/xla/xla/tools/hlo_bisect/BUILD @@ -32,8 +32,8 @@ xla_cc_binary( "//xla/service:cpu_plugin", "//xla/service:gpu_plugin", "//xla/service:interpreter_plugin", + "//xla/tsl/util:command_line_flags", "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/util:command_line_flags", ] + if_cuda(["//xla/stream_executor/cuda:cublas_plugin"]), ) diff --git a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect.cc b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect.cc index fda3fb25968ac2..73b018323f34c3 100644 --- a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect.cc +++ b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include "xla/tools/hlo_bisect/hlo_bisect_utils.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/init_main.h" -#include "tsl/util/command_line_flags.h" const char* const kUsage = R"( Given an HloModule that manifests an XLA bug, either crashes the compiler or diff --git a/third_party/xla/xla/tools/hlo_expand.cc b/third_party/xla/xla/tools/hlo_expand.cc index 70aed16de04941..cd568564339d27 100644 --- a/third_party/xla/xla/tools/hlo_expand.cc +++ b/third_party/xla/xla/tools/hlo_expand.cc @@ -27,8 +27,8 @@ limitations under the License. #include "xla/service/sharding_propagation.h" #include "xla/service/spmd/stateful_rng_spmd_partitioner.h" #include "xla/service/triangular_solve_expander.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/xla_data.pb.h" -#include "tsl/util/command_line_flags.h" namespace xla { diff --git a/third_party/xla/xla/tools/hlo_expand.h b/third_party/xla/xla/tools/hlo_expand.h index a80f05b5e789a6..5c1818e91a81da 100644 --- a/third_party/xla/xla/tools/hlo_expand.h +++ b/third_party/xla/xla/tools/hlo_expand.h @@ -20,7 +20,7 @@ limitations under the License. #include #include "xla/service/hlo_pass_pipeline.h" -#include "tsl/util/command_line_flags.h" +#include "xla/tsl/util/command_line_flags.h" namespace xla { diff --git a/third_party/xla/xla/tools/hlo_expand_main.cc b/third_party/xla/xla/tools/hlo_expand_main.cc index 4e34c9f6c37fed..60a83c3b55837c 100644 --- a/third_party/xla/xla/tools/hlo_expand_main.cc +++ b/third_party/xla/xla/tools/hlo_expand_main.cc @@ -25,11 +25,11 @@ limitations under the License. #include "xla/service/hlo_pass_pipeline.h" #include "xla/tools/hlo_expand.h" #include "xla/tools/hlo_module_loader.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/xla.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/init_main.h" #include "tsl/platform/path.h" -#include "tsl/util/command_line_flags.h" namespace { diff --git a/third_party/xla/xla/tools/hlo_opt/BUILD b/third_party/xla/xla/tools/hlo_opt/BUILD index cf847c5ad47d3b..1487bc0a484d84 100644 --- a/third_party/xla/xla/tools/hlo_opt/BUILD +++ b/third_party/xla/xla/tools/hlo_opt/BUILD @@ -117,6 +117,7 @@ cc_library( "//xla/service:platform_util", "//xla/tools:hlo_module_loader", "//xla/tools:run_hlo_module_lib", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -128,7 +129,6 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/util:command_line_flags", ] + if_gpu_is_configured([ ":gpu_opt", ]) + if_cuda_is_configured([ diff --git a/third_party/xla/xla/tools/hlo_opt/opt_main.cc b/third_party/xla/xla/tools/hlo_opt/opt_main.cc index a0f803c6751633..ef24a0a4da2405 100644 --- a/third_party/xla/xla/tools/hlo_opt/opt_main.cc +++ b/third_party/xla/xla/tools/hlo_opt/opt_main.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/tools/hlo_module_loader.h" #include "xla/tools/hlo_opt/opt_lib.h" #include "xla/tools/run_hlo_module.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/init_main.h" @@ -48,7 +49,6 @@ limitations under the License. #include "tsl/platform/path.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/util/command_line_flags.h" namespace { const char* const kUsage = R"( diff --git a/third_party/xla/xla/tools/hlo_proto_to_json.cc b/third_party/xla/xla/tools/hlo_proto_to_json.cc index ef16a966edcd7f..cdadd1e3ce1d79 100644 --- a/third_party/xla/xla/tools/hlo_proto_to_json.cc +++ b/third_party/xla/xla/tools/hlo_proto_to_json.cc @@ -30,12 +30,12 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/statusor.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" -#include "tsl/util/command_line_flags.h" using std::string; diff --git a/third_party/xla/xla/tools/interactive_graphviz.cc b/third_party/xla/xla/tools/interactive_graphviz.cc index 6bd32a96427f8d..10d68162d54627 100644 --- a/third_party/xla/xla/tools/interactive_graphviz.cc +++ b/third_party/xla/xla/tools/interactive_graphviz.cc @@ -46,12 +46,12 @@ limitations under the License. #include "xla/service/local_service.h" #include "xla/service/platform_util.h" #include "xla/tools/hlo_extractor.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/subprocess.h" #include "tsl/protobuf/error_codes.pb.h" -#include "tsl/util/command_line_flags.h" #if defined(PLATFORM_GOOGLE) #include "util/readline/readline.h" #endif diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD index fe7b8b8bf03482..1a2ac425ed166c 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD +++ b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD @@ -40,13 +40,13 @@ xla_cc_binary( "//xla:statusor", "//xla/pjrt:pjrt_client", "//xla/service:cpu_plugin", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/util:command_line_flags", ] + if_cuda_or_rocm([ "//xla/service:gpu_plugin", ]) + if_cuda([ diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc b/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc index 6a005380ad62c4..b0666ff491fcba 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/hlo_runner_main.cc @@ -28,11 +28,11 @@ limitations under the License. #include "xla/statusor.h" #include "xla/tools/multihost_hlo_runner/functional_hlo_runner.h" #include "xla/tools/multihost_hlo_runner/hlo_runner_flags.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/util/command_line_flags.h" namespace { const char* const kUsage = R"( diff --git a/third_party/xla/xla/tools/run_hlo_module_main.cc b/third_party/xla/xla/tools/run_hlo_module_main.cc index 6f4005a2da5d9c..c19130df2c6561 100644 --- a/third_party/xla/xla/tools/run_hlo_module_main.cc +++ b/third_party/xla/xla/tools/run_hlo_module_main.cc @@ -26,11 +26,11 @@ limitations under the License. #include "xla/service/hlo_runner.h" #include "xla/service/platform_util.h" #include "xla/tools/run_hlo_module.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" -#include "tsl/util/command_line_flags.h" namespace { const char* const kUsage = R"( diff --git a/third_party/xla/third_party/tsl/tsl/util/BUILD b/third_party/xla/xla/tsl/util/BUILD similarity index 77% rename from third_party/xla/third_party/tsl/tsl/util/BUILD rename to third_party/xla/xla/tsl/util/BUILD index 45432219bdcc4d..c96aa52c81f1f7 100644 --- a/third_party/xla/third_party/tsl/tsl/util/BUILD +++ b/third_party/xla/xla/tsl/util/BUILD @@ -5,24 +5,24 @@ # to other TF components outside of TSL. load( - "@local_tsl//tsl/platform:rules_cc.bzl", - "cc_library", -) -load( - "//tsl:tsl.bzl", + "@local_tsl//tsl:tsl.bzl", "check_deps", "internal_visibility", "tsl_copts", ) -load("//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") +load("@local_tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load( - "//tsl/platform:build_config.bzl", + "@local_tsl//tsl/platform:build_config.bzl", "tsl_cc_test", ) load( - "//tsl/platform:build_config_root.bzl", + "@local_tsl//tsl/platform:build_config_root.bzl", "if_static", ) +load( + "@local_tsl//tsl/platform:rules_cc.bzl", + "cc_library", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -127,9 +127,9 @@ cc_library( srcs = ["byte_swap_array.cc"], hdrs = ["byte_swap_array.h"], deps = [ - "//tsl/platform:byte_order", - "//tsl/platform:errors", - "//tsl/platform:status", + "@local_tsl//tsl/platform:byte_order", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", ], ) @@ -151,8 +151,8 @@ cc_library( visibility = internal_visibility(["//tensorflow:__subpackages__"]), deps = [ ":env_var", - "//tsl/platform:mutex", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:mutex", ], alwayslink = 1, ) @@ -200,14 +200,14 @@ cc_library( srcs = ["env_var.cc"], hdrs = ["env_var.h"], deps = [ - "//tsl/platform:errors", - "//tsl/platform:logging", - "//tsl/platform:numbers", - "//tsl/platform:status", - "//tsl/platform:str_util", - "//tsl/platform:strcat", - "//tsl/platform:stringpiece", - "//tsl/platform:types", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:str_util", + "@local_tsl//tsl/platform:strcat", + "@local_tsl//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:types", ], ) @@ -217,17 +217,17 @@ cc_library( hdrs = ["reporter.h"], visibility = internal_visibility([ "//tensorflow/core:__subpackages__", - "//tsl:__subpackages__", + "@local_tsl//tsl:__subpackages__", ]), deps = [ - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:errors", - "//tsl/platform:macros", - "//tsl/platform:mutex", - "//tsl/platform:str_util", - "//tsl/platform:types", - "//tsl/protobuf:test_log_proto_cc", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:mutex", + "@local_tsl//tsl/platform:str_util", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/protobuf:test_log_proto_cc", ], ) @@ -242,7 +242,7 @@ cc_library( ], copts = tsl_copts(), visibility = internal_visibility([ - "//tsl:internal", + "@local_tsl//tsl:internal", ]), ) @@ -251,8 +251,8 @@ tsl_cc_test( srcs = ["stats_calculator_test.cc"], deps = [ ":stats_calculator_portable", - "//tsl/platform:test", - "//tsl/platform:test_main", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -261,9 +261,9 @@ cc_library( srcs = ["device_name_utils.cc"], hdrs = ["device_name_utils.h"], deps = [ - "//tsl/platform:errors", - "//tsl/platform:status", - "//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:stringpiece", ], ) @@ -273,12 +273,12 @@ tsl_cc_test( srcs = ["device_name_utils_test.cc"], deps = [ ":device_name_utils", - "//tsl/lib/core:status_test_util", - "//tsl/platform:errors", - "//tsl/platform:strcat", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:strcat", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", ], ) @@ -287,12 +287,12 @@ cc_library( srcs = ["command_line_flags.cc"], hdrs = ["command_line_flags.h"], deps = [ - "//tsl/platform:logging", - "//tsl/platform:str_util", - "//tsl/platform:stringpiece", - "//tsl/platform:stringprintf", - "//tsl/platform:types", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:str_util", + "@local_tsl//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:stringprintf", + "@local_tsl//tsl/platform:types", ], ) @@ -311,7 +311,7 @@ filegroup( "onednn_threadpool.h", ], visibility = internal_visibility([ - "@local_xla//xla:__subpackages__", + "//xla:__subpackages__", "//tensorflow/core:__pkg__", "//tensorflow/core/framework:__pkg__", "//tensorflow/core/util:__pkg__", diff --git a/third_party/xla/third_party/tsl/tsl/util/byte_swap_array.cc b/third_party/xla/xla/tsl/util/byte_swap_array.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/util/byte_swap_array.cc rename to third_party/xla/xla/tsl/util/byte_swap_array.cc index e77e4bab8defc0..3b21798f0caf41 100644 --- a/third_party/xla/third_party/tsl/tsl/util/byte_swap_array.cc +++ b/third_party/xla/xla/tsl/util/byte_swap_array.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/byte_swap_array.h" +#include "xla/tsl/util/byte_swap_array.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/util/byte_swap_array.h b/third_party/xla/xla/tsl/util/byte_swap_array.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/util/byte_swap_array.h rename to third_party/xla/xla/tsl/util/byte_swap_array.h index ad7e34efcd51f7..88c87afd2696e7 100644 --- a/third_party/xla/third_party/tsl/tsl/util/byte_swap_array.h +++ b/third_party/xla/xla/tsl/util/byte_swap_array.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_BYTE_SWAP_ARRAY_H_ -#define TENSORFLOW_TSL_UTIL_BYTE_SWAP_ARRAY_H_ +#ifndef XLA_TSL_UTIL_BYTE_SWAP_ARRAY_H_ +#define XLA_TSL_UTIL_BYTE_SWAP_ARRAY_H_ #include "tsl/platform/byte_order.h" #include "tsl/platform/errors.h" @@ -101,4 +101,4 @@ Status ByteSwapArray(char *array, size_t bytes_per_elem, int array_len); } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_BYTE_SWAP_ARRAY_H_ +#endif // XLA_TSL_UTIL_BYTE_SWAP_ARRAY_H_ diff --git a/third_party/xla/third_party/tsl/tsl/util/command_line_flags.cc b/third_party/xla/xla/tsl/util/command_line_flags.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/util/command_line_flags.cc rename to third_party/xla/xla/tsl/util/command_line_flags.cc index 5e316e9ae9fc6a..f5a97a50eb1980 100644 --- a/third_party/xla/third_party/tsl/tsl/util/command_line_flags.cc +++ b/third_party/xla/xla/tsl/util/command_line_flags.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/command_line_flags.h" +#include "xla/tsl/util/command_line_flags.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/util/command_line_flags.h b/third_party/xla/xla/tsl/util/command_line_flags.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/util/command_line_flags.h rename to third_party/xla/xla/tsl/util/command_line_flags.h index 2710de5753cd01..d4b3efd662a94d 100644 --- a/third_party/xla/third_party/tsl/tsl/util/command_line_flags.h +++ b/third_party/xla/xla/tsl/util/command_line_flags.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_COMMAND_LINE_FLAGS_H_ -#define TENSORFLOW_TSL_UTIL_COMMAND_LINE_FLAGS_H_ +#ifndef XLA_TSL_UTIL_COMMAND_LINE_FLAGS_H_ +#define XLA_TSL_UTIL_COMMAND_LINE_FLAGS_H_ #include #include @@ -145,4 +145,4 @@ class Flags { } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_COMMAND_LINE_FLAGS_H_ +#endif // XLA_TSL_UTIL_COMMAND_LINE_FLAGS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/util/determinism.cc b/third_party/xla/xla/tsl/util/determinism.cc similarity index 96% rename from third_party/xla/third_party/tsl/tsl/util/determinism.cc rename to third_party/xla/xla/tsl/util/determinism.cc index b9a5abd9af40d1..6089cc96458dc1 100644 --- a/third_party/xla/third_party/tsl/tsl/util/determinism.cc +++ b/third_party/xla/xla/tsl/util/determinism.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/determinism.h" +#include "xla/tsl/util/determinism.h" #include "absl/strings/string_view.h" +#include "xla/tsl/util/env_var.h" #include "tsl/platform/mutex.h" -#include "tsl/util/env_var.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/util/determinism.h b/third_party/xla/xla/tsl/util/determinism.h similarity index 86% rename from third_party/xla/third_party/tsl/tsl/util/determinism.h rename to third_party/xla/xla/tsl/util/determinism.h index fff5b195845a39..2f1861ed60a23b 100644 --- a/third_party/xla/third_party/tsl/tsl/util/determinism.h +++ b/third_party/xla/xla/tsl/util/determinism.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_DETERMINISM_H_ -#define TENSORFLOW_TSL_UTIL_DETERMINISM_H_ +#ifndef XLA_TSL_UTIL_DETERMINISM_H_ +#define XLA_TSL_UTIL_DETERMINISM_H_ namespace tsl { @@ -24,4 +24,4 @@ void EnableOpDeterminism(bool enabled); } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_DETERMINISM_H_ +#endif // XLA_TSL_UTIL_DETERMINISM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/util/determinism_test_util.h b/third_party/xla/xla/tsl/util/determinism_test_util.h similarity index 84% rename from third_party/xla/third_party/tsl/tsl/util/determinism_test_util.h rename to third_party/xla/xla/tsl/util/determinism_test_util.h index e458dc9cdacc50..34b4552bb62d6a 100644 --- a/third_party/xla/third_party/tsl/tsl/util/determinism_test_util.h +++ b/third_party/xla/xla/tsl/util/determinism_test_util.h @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_DETERMINISM_TEST_UTIL_H_ -#define TENSORFLOW_TSL_UTIL_DETERMINISM_TEST_UTIL_H_ +#ifndef XLA_TSL_UTIL_DETERMINISM_TEST_UTIL_H_ +#define XLA_TSL_UTIL_DETERMINISM_TEST_UTIL_H_ -#include "tsl/util/determinism.h" +#include "xla/tsl/util/determinism.h" namespace tsl { namespace test { @@ -35,4 +35,4 @@ class DeterministicOpsScope { } // namespace test } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_DETERMINISM_TEST_UTIL_H_ +#endif // XLA_TSL_UTIL_DETERMINISM_TEST_UTIL_H_ diff --git a/third_party/xla/third_party/tsl/tsl/util/device_name_utils.cc b/third_party/xla/xla/tsl/util/device_name_utils.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/util/device_name_utils.cc rename to third_party/xla/xla/tsl/util/device_name_utils.cc index 0920532c62eddb..180e3336666bca 100644 --- a/third_party/xla/third_party/tsl/tsl/util/device_name_utils.cc +++ b/third_party/xla/xla/tsl/util/device_name_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/device_name_utils.h" +#include "xla/tsl/util/device_name_utils.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/util/device_name_utils.h b/third_party/xla/xla/tsl/util/device_name_utils.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/util/device_name_utils.h rename to third_party/xla/xla/tsl/util/device_name_utils.h index 162af1c55b4b47..82b5fa3b1aec2e 100644 --- a/third_party/xla/third_party/tsl/tsl/util/device_name_utils.h +++ b/third_party/xla/xla/tsl/util/device_name_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_DEVICE_NAME_UTILS_H_ -#define TENSORFLOW_TSL_UTIL_DEVICE_NAME_UTILS_H_ +#ifndef XLA_TSL_UTIL_DEVICE_NAME_UTILS_H_ +#define XLA_TSL_UTIL_DEVICE_NAME_UTILS_H_ #include @@ -291,4 +291,4 @@ std::ostream& operator<<(std::ostream& os, } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_DEVICE_NAME_UTILS_H_ +#endif // XLA_TSL_UTIL_DEVICE_NAME_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/util/device_name_utils_test.cc b/third_party/xla/xla/tsl/util/device_name_utils_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/util/device_name_utils_test.cc rename to third_party/xla/xla/tsl/util/device_name_utils_test.cc index dce1fc5807604f..03aa5fca5899b9 100644 --- a/third_party/xla/third_party/tsl/tsl/util/device_name_utils_test.cc +++ b/third_party/xla/xla/tsl/util/device_name_utils_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/device_name_utils.h" +#include "xla/tsl/util/device_name_utils.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/util/env_var.cc b/third_party/xla/xla/tsl/util/env_var.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/util/env_var.cc rename to third_party/xla/xla/tsl/util/env_var.cc index e7d818445c7def..564617aa082889 100644 --- a/third_party/xla/third_party/tsl/tsl/util/env_var.cc +++ b/third_party/xla/xla/tsl/util/env_var.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/env_var.h" +#include "xla/tsl/util/env_var.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/util/env_var.h b/third_party/xla/xla/tsl/util/env_var.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/util/env_var.h rename to third_party/xla/xla/tsl/util/env_var.h index 9c6925c57f643b..69c0bff2a1658c 100644 --- a/third_party/xla/third_party/tsl/tsl/util/env_var.h +++ b/third_party/xla/xla/tsl/util/env_var.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_ENV_VAR_H_ -#define TENSORFLOW_TSL_UTIL_ENV_VAR_H_ +#ifndef XLA_TSL_UTIL_ENV_VAR_H_ +#define XLA_TSL_UTIL_ENV_VAR_H_ #include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" @@ -53,4 +53,4 @@ Status ReadStringsFromEnvVar(StringPiece env_var_name, StringPiece default_val, } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_ENV_VAR_H_ +#endif // XLA_TSL_UTIL_ENV_VAR_H_ diff --git a/third_party/xla/third_party/tsl/tsl/util/onednn_threadpool.h b/third_party/xla/xla/tsl/util/onednn_threadpool.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/util/onednn_threadpool.h rename to third_party/xla/xla/tsl/util/onednn_threadpool.h index 7d8a093ae89fa6..0c81806352f863 100644 --- a/third_party/xla/third_party/tsl/tsl/util/onednn_threadpool.h +++ b/third_party/xla/xla/tsl/util/onednn_threadpool.h @@ -14,8 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_ONEDNN_THREADPOOL_H_ -#define TENSORFLOW_TSL_UTIL_ONEDNN_THREADPOOL_H_ +#ifndef XLA_TSL_UTIL_ONEDNN_THREADPOOL_H_ +#define XLA_TSL_UTIL_ONEDNN_THREADPOOL_H_ #ifdef INTEL_MKL #include @@ -190,4 +190,4 @@ class OneDnnThreadPool { } // namespace tsl #endif // INTEL_MKL -#endif // TENSORFLOW_TSL_UTIL_ONEDNN_THREADPOOL_H_ +#endif // XLA_TSL_UTIL_ONEDNN_THREADPOOL_H_ diff --git a/third_party/xla/third_party/tsl/tsl/util/proto/BUILD b/third_party/xla/xla/tsl/util/proto/BUILD similarity index 100% rename from third_party/xla/third_party/tsl/tsl/util/proto/BUILD rename to third_party/xla/xla/tsl/util/proto/BUILD diff --git a/third_party/xla/third_party/tsl/tsl/util/proto/proto_utils.h b/third_party/xla/xla/tsl/util/proto/proto_utils.h similarity index 90% rename from third_party/xla/third_party/tsl/tsl/util/proto/proto_utils.h rename to third_party/xla/xla/tsl/util/proto/proto_utils.h index 9a1dee8eed5224..2762f4df0e8af1 100644 --- a/third_party/xla/third_party/tsl/tsl/util/proto/proto_utils.h +++ b/third_party/xla/xla/tsl/util/proto/proto_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_PROTO_PROTO_UTILS_H_ -#define TENSORFLOW_TSL_UTIL_PROTO_PROTO_UTILS_H_ +#ifndef XLA_TSL_UTIL_PROTO_PROTO_UTILS_H_ +#define XLA_TSL_UTIL_PROTO_PROTO_UTILS_H_ #include "google/protobuf/duration.pb.h" #include "absl/time/time.h" @@ -39,4 +39,4 @@ inline absl::Duration FromDurationProto(google::protobuf::Duration proto) { } // namespace proto_utils } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_PROTO_PROTO_UTILS_H_ +#endif // XLA_TSL_UTIL_PROTO_PROTO_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/util/reporter.cc b/third_party/xla/xla/tsl/util/reporter.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/util/reporter.cc rename to third_party/xla/xla/tsl/util/reporter.cc index 41501bc68e8ced..c8ee2f2f87c4ea 100644 --- a/third_party/xla/third_party/tsl/tsl/util/reporter.cc +++ b/third_party/xla/xla/tsl/util/reporter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/reporter.h" +#include "xla/tsl/util/reporter.h" #include "tsl/platform/errors.h" #include "tsl/platform/mutex.h" diff --git a/third_party/xla/third_party/tsl/tsl/util/reporter.h b/third_party/xla/xla/tsl/util/reporter.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/util/reporter.h rename to third_party/xla/xla/tsl/util/reporter.h index d020e94fae1276..cf1e2b2c274b25 100644 --- a/third_party/xla/third_party/tsl/tsl/util/reporter.h +++ b/third_party/xla/xla/tsl/util/reporter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_REPORTER_H_ -#define TENSORFLOW_TSL_UTIL_REPORTER_H_ +#ifndef XLA_TSL_UTIL_REPORTER_H_ +#define XLA_TSL_UTIL_REPORTER_H_ #include #include @@ -131,4 +131,4 @@ class TestReporter { } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_REPORTER_H_ +#endif // XLA_TSL_UTIL_REPORTER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/util/stat_summarizer_options.h b/third_party/xla/xla/tsl/util/stat_summarizer_options.h similarity index 88% rename from third_party/xla/third_party/tsl/tsl/util/stat_summarizer_options.h rename to third_party/xla/xla/tsl/util/stat_summarizer_options.h index e07de6e8d5d9d1..c3ed6ffd7e48bf 100644 --- a/third_party/xla/third_party/tsl/tsl/util/stat_summarizer_options.h +++ b/third_party/xla/xla/tsl/util/stat_summarizer_options.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_STAT_SUMMARIZER_OPTIONS_H_ -#define TENSORFLOW_TSL_UTIL_STAT_SUMMARIZER_OPTIONS_H_ +#ifndef XLA_TSL_UTIL_STAT_SUMMARIZER_OPTIONS_H_ +#define XLA_TSL_UTIL_STAT_SUMMARIZER_OPTIONS_H_ namespace tsl { // Used to control the output of the statistics summarizer; struct StatSummarizerOptions { @@ -41,4 +41,4 @@ struct StatSummarizerOptions { }; } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_STAT_SUMMARIZER_OPTIONS_H_ +#endif // XLA_TSL_UTIL_STAT_SUMMARIZER_OPTIONS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/util/stats_calculator.cc b/third_party/xla/xla/tsl/util/stats_calculator.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/util/stats_calculator.cc rename to third_party/xla/xla/tsl/util/stats_calculator.cc index 99ab1e3e7c6bc5..cdfa46c94417c3 100644 --- a/third_party/xla/third_party/tsl/tsl/util/stats_calculator.cc +++ b/third_party/xla/xla/tsl/util/stats_calculator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/stats_calculator.h" +#include "xla/tsl/util/stats_calculator.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/util/stats_calculator.h b/third_party/xla/xla/tsl/util/stats_calculator.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/util/stats_calculator.h rename to third_party/xla/xla/tsl/util/stats_calculator.h index 5c23f432971c23..84045fb6ceece2 100644 --- a/third_party/xla/third_party/tsl/tsl/util/stats_calculator.h +++ b/third_party/xla/xla/tsl/util/stats_calculator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_UTIL_STATS_CALCULATOR_H_ -#define TENSORFLOW_TSL_UTIL_STATS_CALCULATOR_H_ +#ifndef XLA_TSL_UTIL_STATS_CALCULATOR_H_ +#define XLA_TSL_UTIL_STATS_CALCULATOR_H_ #include @@ -26,7 +26,7 @@ limitations under the License. #include #include -#include "tsl/util/stat_summarizer_options.h" +#include "xla/tsl/util/stat_summarizer_options.h" namespace tsl { @@ -198,4 +198,4 @@ class StatsCalculator { } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_STATS_CALCULATOR_H_ +#endif // XLA_TSL_UTIL_STATS_CALCULATOR_H_ diff --git a/third_party/xla/third_party/tsl/tsl/util/stats_calculator_test.cc b/third_party/xla/xla/tsl/util/stats_calculator_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/util/stats_calculator_test.cc rename to third_party/xla/xla/tsl/util/stats_calculator_test.cc index 9093701e4478c9..d58186630598f0 100644 --- a/third_party/xla/third_party/tsl/tsl/util/stats_calculator_test.cc +++ b/third_party/xla/xla/tsl/util/stats_calculator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/stats_calculator.h" +#include "xla/tsl/util/stats_calculator.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/util/use_cudnn.cc b/third_party/xla/xla/tsl/util/use_cudnn.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/util/use_cudnn.cc rename to third_party/xla/xla/tsl/util/use_cudnn.cc index 3156a319b73b3d..a3e1b4d25d2667 100644 --- a/third_party/xla/third_party/tsl/tsl/util/use_cudnn.cc +++ b/third_party/xla/xla/tsl/util/use_cudnn.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/util/use_cudnn.h" +#include "xla/tsl/util/use_cudnn.h" #include +#include "xla/tsl/util/env_var.h" #include "tsl/platform/str_util.h" #include "tsl/platform/stringpiece.h" -#include "tsl/util/env_var.h" #if GOOGLE_CUDA #include "third_party/gpus/cudnn/cudnn.h" diff --git a/third_party/xla/third_party/tsl/tsl/util/use_cudnn.h b/third_party/xla/xla/tsl/util/use_cudnn.h similarity index 92% rename from third_party/xla/third_party/tsl/tsl/util/use_cudnn.h rename to third_party/xla/xla/tsl/util/use_cudnn.h index 738e727e4c7808..41c29b256f7be0 100644 --- a/third_party/xla/third_party/tsl/tsl/util/use_cudnn.h +++ b/third_party/xla/xla/tsl/util/use_cudnn.h @@ -15,8 +15,8 @@ limitations under the License. // The utility to check Cudnn dependency and set Cudnn-related flags. -#ifndef TENSORFLOW_TSL_UTIL_USE_CUDNN_H_ -#define TENSORFLOW_TSL_UTIL_USE_CUDNN_H_ +#ifndef XLA_TSL_UTIL_USE_CUDNN_H_ +#define XLA_TSL_UTIL_USE_CUDNN_H_ #include @@ -40,4 +40,4 @@ bool ShouldCudnnGroupedConvolutionBeUsed(const int32_t filter_rows, const int32_t out_depth); } // namespace tsl -#endif // TENSORFLOW_TSL_UTIL_USE_CUDNN_H_ +#endif // XLA_TSL_UTIL_USE_CUDNN_H_ diff --git a/third_party/xla/xla/xla.bzl b/third_party/xla/xla/xla.bzl index 260fa5c1731310..71b67f8e74ce62 100644 --- a/third_party/xla/xla/xla.bzl +++ b/third_party/xla/xla/xla.bzl @@ -65,7 +65,7 @@ _XLA_SHARED_OBJECT_SENSITIVE_DEPS = if_static(extra_deps = [], otherwise = [ Label("//xla/stream_executor/gpu:gpu_stream"), Label("//xla/stream_executor/rocm:all_runtime"), Label("//xla/stream_executor/rocm:stream_executor_rocm"), - "@local_tsl//tsl/util:determinism", + "//xla/tsl/util:determinism", ]) def xla_cc_binary(deps = [], copts = tsl_copts(), **kwargs): From ab3e910fd68f5dd51c3a63450a46f660295f1e73 Mon Sep 17 00:00:00 2001 From: Isha Arkatkar Date: Thu, 28 Mar 2024 13:46:25 -0700 Subject: [PATCH 048/124] Disable Compiler ir test for h100 GPU targets PiperOrigin-RevId: 620045586 --- tensorflow/python/eager/polymorphic_function/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/eager/polymorphic_function/BUILD b/tensorflow/python/eager/polymorphic_function/BUILD index 663ac4582ba3d9..2492bd4bd06124 100644 --- a/tensorflow/python/eager/polymorphic_function/BUILD +++ b/tensorflow/python/eager/polymorphic_function/BUILD @@ -639,6 +639,7 @@ tf_xla_py_strict_test( disabled_backends = [ "cpu_ondemand", "gpu_a100", + "gpu_h100", ], enable_mlir_bridge = True, python_version = "PY3", From 2457ae2ef3013935b124b5a85f1af0271afa236e Mon Sep 17 00:00:00 2001 From: "Jiyoun (Jen) Ha" Date: Thu, 28 Mar 2024 13:46:34 -0700 Subject: [PATCH 049/124] Do not wrap lifted function in `TF.CustomAggregatorOp` with improper `quantization_method`. PiperOrigin-RevId: 620045624 --- .../compiler/mlir/quantization/stablehlo/BUILD | 2 +- .../compiler/mlir/quantization/stablehlo/cc/BUILD | 1 + .../stablehlo/cc/pre_calibration_test.cc | 8 +++++++- .../testing/test_pre_calibration_component.cc | 8 +++++++- .../components/pre_calibration_component.mlir | 8 ++++---- .../compiler/mlir/quantization/tensorflow/BUILD | 1 + .../passes/insert_custom_aggregation_ops.cc | 15 +++++++++++---- 7 files changed, 32 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 4998c87f70febe..3b53b3c74bb7cb 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -521,6 +521,7 @@ cc_library( ":quantization_config_proto_cc", ":stablehlo_test_passes_inc_gen", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:config", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:post_calibration", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:pre_calibration", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", @@ -531,7 +532,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD index 7c7b57451a5f4a..77629c7719bf44 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -299,6 +299,7 @@ tf_cc_test( name = "pre_calibration_test", srcs = ["pre_calibration_test.cc"], deps = [ + ":config", ":pre_calibration", "//tensorflow/compiler/mlir/quantization/common:test_base", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration_test.cc index c17c39d8783ba8..3d4d2295455a5c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/common/test_base.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" @@ -34,6 +35,8 @@ limitations under the License. namespace mlir::quant::stablehlo { namespace { +using ::stablehlo::quantization::ExpandPresets; +using ::stablehlo::quantization::PopulateDefaults; using ::stablehlo::quantization::QuantizationConfig; using ::testing::Contains; using ::testing::SizeIs; @@ -92,8 +95,11 @@ TEST_F(PreCalibrationComponentTest, )mlir"); ASSERT_TRUE(module_op); + QuantizationConfig quantization_config{}; + quantization_config.mutable_static_range_ptq_preset(); + quantization_config = ExpandPresets(PopulateDefaults(quantization_config)); absl::StatusOr pre_calibration_result = - component.Run(*module_op, QuantizationConfig()); + component.Run(*module_op, quantization_config); EXPECT_THAT(pre_calibration_result, IsOk()); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_pre_calibration_component.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_pre_calibration_component.cc index 06b53035c80c7a..0c41771a5c43b0 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_pre_calibration_component.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_pre_calibration_component.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/Support/TypeID.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep #include "stablehlo/dialect/VhloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" @@ -34,6 +35,8 @@ namespace mlir::quant::stablehlo::testing { namespace { +using ::stablehlo::quantization::ExpandPresets; +using ::stablehlo::quantization::PopulateDefaults; using ::stablehlo::quantization::QuantizationConfig; class TestPreCalibrationComponentPass @@ -52,7 +55,10 @@ void TestPreCalibrationComponentPass::runOnOperation() { // Simply runs the PreCalibrationComponent with a default configuration. PreCalibrationComponent component(&ctx); - if (!component.Run(module_op, QuantizationConfig::default_instance()).ok()) { + QuantizationConfig quantization_config{}; + quantization_config.mutable_static_range_ptq_preset(); + quantization_config = ExpandPresets(PopulateDefaults(quantization_config)); + if (!component.Run(module_op, quantization_config).ok()) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir index 6a5b58a7ba7b64..1fe56cde49601d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/components/pre_calibration_component.mlir @@ -8,10 +8,10 @@ func.func @main(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { } // CHECK: @main(%[[ARG_0:.+]]: tensor<1x4xf32>) -> tensor<1x3xf32> // CHECK-DAG: %[[CST:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> -// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]] = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {calibration_method = 0 : i32, {{.*}}} : (tensor<1x4xf32>) -> tensor<1x4xf32> +// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]] = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {{.*}} : (tensor<1x4xf32>) -> tensor<1x4xf32> // CHECK: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[CST]]) // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" -// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {calibration_method = 0 : i32, {{.*}}} : (tensor<1x3xf32>) -> tensor<1x3xf32> +// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {{.*}} : (tensor<1x3xf32>) -> tensor<1x3xf32> // CHECK: return %[[CUSTOM_AGGREGATOR_1]] : tensor<1x3xf32> // CHECK: } // CHECK: } @@ -28,10 +28,10 @@ func.func @serving_default(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { } // CHECK: @serving_default(%[[ARG_0:.+]]: tensor<1x4xf32>) -> tensor<1x3xf32> // CHECK-DAG: %[[CST:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> -// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]] = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {calibration_method = 0 : i32, {{.*}}} : (tensor<1x4xf32>) -> tensor<1x4xf32> +// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]] = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {{.*}} : (tensor<1x4xf32>) -> tensor<1x4xf32> // CHECK: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[CST]]) // CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1" -// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {calibration_method = 0 : i32, {{.*}}} : (tensor<1x3xf32>) -> tensor<1x3xf32> +// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {{.*}} : (tensor<1x3xf32>) -> tensor<1x3xf32> // CHECK: return %[[CUSTOM_AGGREGATOR_1]] : tensor<1x3xf32> // CHECK: } // CHECK: } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index 60099ccb0ea075..be0792ab76aff3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -448,6 +448,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/random", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_googlesource_code_re2//:re2", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc index e518826d7e6d12..56b9d7393aacfd 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -32,6 +33,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" @@ -45,6 +47,7 @@ namespace quant { namespace { using ::stablehlo::quantization::CalibrationOptions; +using ::stablehlo::quantization::Method; constexpr StringRef kQuantTraitAttrName = "_tfl_quant_trait"; @@ -199,7 +202,7 @@ class AddCustomAggregationOp : public RewritePattern { // The CustomAggregatorOp is only added after quantizable values. SmallVector quantizable_values; - if (isCallToLiftedFunction(op)) { + if (IsCallToQuantizableLiftedFunction(op)) { // Quantize inputs of quantizable composite functions. for (Value input : op->getOperands()) { Type element_type = getElementTypeOrSelf(input.getType()); @@ -226,7 +229,7 @@ class AddCustomAggregationOp : public RewritePattern { // Quantize output of fully quantizable composite functions. for (Value input : op->getOperands()) { auto defining_op = input.getDefiningOp(); - if (!isCallToLiftedFunction(defining_op)) { + if (!IsCallToQuantizableLiftedFunction(defining_op)) { continue; } @@ -282,9 +285,13 @@ class AddCustomAggregationOp : public RewritePattern { CalibrationOptions calib_opts_; // Whether the op is a call op to lifted composite function. - bool isCallToLiftedFunction(Operation *op) const { + bool IsCallToQuantizableLiftedFunction(Operation *op) const { if (!op) return false; - if (isa(op)) return true; + if (auto xla_call_module_op = dyn_cast_or_null(op); + xla_call_module_op != nullptr) { + absl::StatusOr method = GetQuantizationMethod(xla_call_module_op); + if (method.ok() && method->has_static_range_ptq()) return true; + } TF::PartitionedCallOp call_op = dyn_cast_or_null(op); return call_op && call_op->hasAttrOfType(kQuantTraitAttrName) && From 0d87e299d7ed225aa29439a0072f69a2c20d46f5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Mar 2024 13:53:49 -0700 Subject: [PATCH 050/124] Fix a use-after-free issue, change a non-nullable pointer argument to a reference and get rid of an unused function argument. PiperOrigin-RevId: 620047666 --- .../experimental/auto_sharding/auto_sharding.cc | 14 +++++++------- .../auto_sharding/auto_sharding_util.cc | 16 +++++++--------- .../auto_sharding/auto_sharding_util.h | 6 ++---- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index ef7420bb384f4b..a9caf4daa6308e 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -2179,7 +2179,7 @@ Status SetHloShardingPostProcessing( const HloInstructionSequence& sequence, const StrategyMap& strategy_map, const CostGraph& cost_graph, absl::Span s_val, const ClusterEnvironment& cluster_env, const bool crash_at_error, - absl::flat_hash_map>* + absl::flat_hash_map>& preserve_shardings) { const std::vector& instructions = sequence.instructions(); const Array& device_mesh = cluster_env.device_mesh_; @@ -2269,8 +2269,8 @@ Status SetHloShardingPostProcessing( // In the analysis itself, we use replicated strategies as a stand-in for // the (expected) maximal sharding annotations that send-done ops usually // have. Here we restore these maximal shardings if present. - auto preserved_sharding_iter = preserve_shardings->find(inst->name()); - if (preserved_sharding_iter != preserve_shardings->end()) { + auto preserved_sharding_iter = preserve_shardings.find(inst->name()); + if (preserved_sharding_iter != preserve_shardings.end()) { const auto& preserved_sharding = preserved_sharding_iter->second; if (preserved_sharding.size() > 1) { std::vector tuple_elements_shape( @@ -2295,8 +2295,8 @@ Status SetHloShardingPostProcessing( // In the analysis itself, we use replicated strategies as a stand-in for // the (expected) maximal sharding annotations that send ops usually // have. Here we restore these maximal shardings if present. - auto preserved_sharding_iter = preserve_shardings->find(inst->name()); - if (preserved_sharding_iter != preserve_shardings->end()) { + auto preserved_sharding_iter = preserve_shardings.find(inst->name()); + if (preserved_sharding_iter != preserve_shardings.end()) { const auto& preserved_sharding = preserved_sharding_iter->second; if (preserved_sharding.size() > 1) { inst->set_sharding( @@ -2365,7 +2365,7 @@ Status SetHloShardingPostProcessing( } } FixMixedMeshShapeReshardingGetTupleElementWithTupleOutput( - inst, dst_shardings, device_mesh, preserve_shardings); + inst, dst_shardings, device_mesh); break; } @@ -3846,7 +3846,7 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( if (!SetHloShardingPostProcessing( sequence, strategy_map, cost_graph, s_val, cluster_env, /* crash_at_error */ !option_.try_multiple_mesh_shapes, - &preserve_shardings) + preserve_shardings) .ok()) { return AutoShardingResult::kModuleUnchanged; } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 827fb83881231d..0cb674711c9b66 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -1351,9 +1351,7 @@ HloInstruction* ReshardTensor(HloInstruction* tensor, void FixMixedMeshShapeReshardingGetTupleElementWithTupleOutput( HloInstruction* inst, const std::vector>& dst_shardings, - const Array& device_mesh, - absl::flat_hash_map>* - preserve_shardings) { + const Array& device_mesh) { size_t tuple_size = inst->shape().tuple_shapes_size(); auto current_sharding = inst->sharding(); @@ -1414,7 +1412,7 @@ void FixMixedMeshShapeReshardingGetTupleElementWithTupleOutput( void FixMixedMeshShapeReshardingGetTupleElement( HloInstruction* inst, const HloSharding& dst_sharding, const Array& device_mesh, - absl::flat_hash_map>* + absl::flat_hash_map>& preserve_shardings) { HloInstruction* operand = inst->mutable_operand(0); auto input_tuple_sharding = operand->sharding(); @@ -1444,11 +1442,11 @@ void FixMixedMeshShapeReshardingGetTupleElement( TF_CHECK_OK(inst->ReplaceUseWith(user, replace_with)); } - CHECK_NE(preserve_shardings, nullptr); - if (preserve_shardings->contains(inst->name())) { - (*preserve_shardings)[replace_with->name()] = - std::vector(preserve_shardings->at(inst->name())); - preserve_shardings->erase(inst->name()); + auto iter = preserve_shardings.find(inst->name()); + if (iter != preserve_shardings.end()) { + preserve_shardings[replace_with->name()] = + std::vector(iter->second); + preserve_shardings.erase(inst->name()); } } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index e293adddd19ca6..4a129bb159076e 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -465,15 +465,13 @@ Shape ComputeIntermediateShape(const HloSharding& src_sharding, void FixMixedMeshShapeReshardingGetTupleElement( HloInstruction* inst, const HloSharding& dst_sharding, const Array& device_mesh, - absl::flat_hash_map>* + absl::flat_hash_map>& preserve_shardings); void FixMixedMeshShapeReshardingGetTupleElementWithTupleOutput( HloInstruction* inst, const std::vector>& dst_sharding, - const Array& device_mesh, - absl::flat_hash_map>* - preserve_shardings); + const Array& device_mesh); void FixMixedMeshShapeResharding(HloInstruction* inst, int operand_num, const HloSharding& dst_sharding, From d2f599eb551e2d2596588a79a4a4a69c104c8287 Mon Sep 17 00:00:00 2001 From: Wren Romano Date: Thu, 28 Mar 2024 14:18:17 -0700 Subject: [PATCH 051/124] [XLA:Python] Factors the ":logging" library out from ":xla_extension". This is a prospective change for https://github.com/openxla/xla/pull/10966. In particular, this will help fix an OSS build problem: "tensorflow/xla/linux/cpu/build_cpu" not being able to find the `InitializeAbslLogging` function. PiperOrigin-RevId: 620055000 --- third_party/xla/xla/python/BUILD | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index e1f044aead00b3..b77b4eefc7f919 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -1140,13 +1140,18 @@ cc_library( ]), ) +cc_library( + name = "logging", + srcs = ["logging.cc"], + hdrs = ["logging.h"], + deps = [ + "@com_google_absl//absl/log:initialize", + ], +) + tsl_pybind_extension( name = "xla_extension", - srcs = [ - "logging.cc", - "logging.h", - "xla.cc", - ], + srcs = ["xla.cc"], copts = [ "-fexceptions", "-fno-strict-aliasing", @@ -1177,6 +1182,7 @@ tsl_pybind_extension( ":custom_call_sharding", ":dlpack", ":jax_jit", + ":logging", ":mlir", ":nb_absl_flat_hash_map", ":nb_absl_span", From 97fd9c1093140beb4178ff095363383b293222d3 Mon Sep 17 00:00:00 2001 From: Swachhand Lokhande Date: Thu, 28 Mar 2024 14:26:39 -0700 Subject: [PATCH 052/124] Add support for tiled sharding with replicate_on_last_tile_dim in TpuRewrite pass PiperOrigin-RevId: 620057438 --- tensorflow/compiler/mlir/tensorflow/BUILD | 3 + .../mlir/tensorflow/tests/tpu_rewrite.mlir | 134 ++++++++++++++++++ .../tensorflow/utils/xla_sharding_util.cc | 105 ++++++++++---- .../mlir/tensorflow/utils/xla_sharding_util.h | 5 + .../distributed_tpu_rewrite_pass.cc | 32 +---- 5 files changed, 227 insertions(+), 52 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 2a0a2222d9aa04..79203111b03ea4 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1375,6 +1375,9 @@ cc_library( deps = [ ":tensorflow", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index 55b68e5de2fb5f..db28242944434e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -2196,6 +2196,73 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // ----- +// Tests inputs to TPUComputation that are tiled in multiple dimensions with +// replicate_on_last_tile_dim set. + +// The following OpSharding is used for TPU computation inputs in below test: +// Proto debug string: +// input 0 +// type: OTHER +// tile_assignment_dimensions: 2 +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// tile_assignment_devices: 2 +// tile_assignment_devices: 3 +// replicate_on_last_tile_dim: true +// Serialized string: +// "\08\03\1A\03\02\01\02\22\04\00\01\02\030\01" +// +// input 1 +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @multi_dimension_tiled_input_replicate_last_dim + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x10xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<*xi32>) + func.func @multi_dimension_tiled_input_replicate_last_dim(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32> + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // CHECK: %[[CONST_SPLIT_0_DIM:.*]] = "tf.Const"() + // CHECK: %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_0_DIM]], %[[RI_0]]) + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_0_OUT]]#0, %[[COMPILE]]#1) + // CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_0_OUT]]#0, %[[RI_1]], %[[COMPILE]]#2) + // CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]] + // CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_0_OUT]]#1, %[[COMPILE]]#3) + // CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]] + // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_0_OUT]]#1, %[[COMPILE]]#4) + // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\1A\03\02\01\02\22\04\00\01\02\030\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> + } + func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> + } + func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1> + func.return %4, %3 : tensor<*xi32>, tensor<*xi1> + } +} + +// ----- + // Tests that tiled output with multiple dimension sharding works properly. // The following OpSharding is used for TPU computation outputs in below test: @@ -2278,6 +2345,73 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // ----- +// Tests that tiled output with multiple dimension sharding works properly with +// replicate_on_last_tile_dim set. + +// The following OpSharding is used for TPU computation outputs in below test: +// output 0 +// Proto debug string: +// type: OTHER +// tile_assignment_dimensions: 2 +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// tile_assignment_devices: 2 +// tile_assignment_devices: 3 +// replicate_on_last_tile_dim: true +// Serialized string: +// "\08\03\1A\03\02\01\02\22\04\00\01\02\030\01" +// +// output 1 +// Proto debug string: +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// Serialized string: +// "\08\01\1A\01\01\22\01\00" + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @multi_dimension_tiled_output_replicate_last_dim + func.func @multi_dimension_tiled_output_replicate_last_dim(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<*xi32> + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE:[a-z0-9]+]]:5 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute" + // CHECK: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute" + // CHECK: tf_device.return %[[EXECUTE_1_OUTPUT]] + // CHECK: %[[LAUNCH_2_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_2_OUTPUT:[0-9]*]] = "tf.TPUExecute"( + // CHECK: tf_device.return %[[EXECUTE_2_OUTPUT]] + // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"( + // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] + // CHECK: %[[CONST_CONCAT_DIM:.*]] = "tf.Const"() + // CHECK: %[[CONCAT_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#0, %[[PARALLEL_EXECUTE_OUTPUT]]#3 + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\1A\03\02\01\02\22\04\00\01\02\030\01", "\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> + } + func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> + } + func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1> + func.return %4, %3 : tensor<*xi32>, tensor<*xi1> + } +} + +// ----- + // Tests inputs device assignment order is well preserved for tiled input sharding. // The following OpSharding is used for TPU computation inputs in below test: diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index 58adaa41349b14..1d3df520549da7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -15,10 +15,15 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" +#include +#include #include #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" @@ -152,16 +157,21 @@ mlir::LogicalResult HandleTileShardedInputs( // are created such that input data is sharded in row major order. // Split nodes at ith depth from the original input node represent nodes // that split the input data at i-th dimension. - const auto& dimension_splits = input_sharding.tile_assignment_dimensions(); - for (const auto& num_splits_and_index : llvm::enumerate(dimension_splits)) { - const int num_splits = num_splits_and_index.value(); - const int dimension_index = num_splits_and_index.index(); - if (num_splits == 1) continue; + auto dimension_to_splits_map = + GetDimensionIndicesAndNumSplitsFromSharding(input_sharding); + if (!dimension_to_splits_map.ok()) { + LOG(ERROR) << dimension_to_splits_map.status(); + return mlir::failure(); + } + + for (const auto& dimension_and_num_splits : *dimension_to_splits_map) { + const int dimension = dimension_and_num_splits.first; + const int num_splits = dimension_and_num_splits.second; // Creates root split op. if (split_ops_for_tiled_input.empty()) { mlir::TF::SplitOp root_split_op; - auto result = CreateSplitOp(num_splits, dimension_index, location, + auto result = CreateSplitOp(num_splits, dimension, location, original_source, builder, &root_split_op); if (mlir::failed(result)) return mlir::failure(); @@ -176,7 +186,7 @@ mlir::LogicalResult HandleTileShardedInputs( for (auto parent_split_output_value : split_op.getResults()) { mlir::TF::SplitOp child_split_op; auto result = - CreateSplitOp(num_splits, dimension_index, location, + CreateSplitOp(num_splits, dimension, location, parent_split_output_value, builder, &child_split_op); if (mlir::failed(result)) return mlir::failure(); @@ -188,12 +198,21 @@ mlir::LogicalResult HandleTileShardedInputs( } // `split_ops_for_tiled_input` now includes final split nodes - // from which sharded data will be fed into TPUExcute ops -- sorted by + // from which sharded data will be fed into TPUExecute ops -- sorted by // row major order. + tiled_inputs->clear(); tiled_inputs->reserve(input_sharding.tile_assignment_devices_size()); - for (auto split_op : split_ops_for_tiled_input) - tiled_inputs->append(split_op.getResults().begin(), - split_op.getResults().end()); + for (auto split_op : split_ops_for_tiled_input) { + for (auto split_op_output : split_op.getResults()) { + int64_t repeat_count = + input_sharding.replicate_on_last_tile_dim() + ? *input_sharding.tile_assignment_dimensions().rbegin() + : 1; + for (int64_t i = 0; i < repeat_count; ++i) { + tiled_inputs->push_back(split_op_output); + } + } + } return mlir::success(); } @@ -205,6 +224,29 @@ bool UnsupportedPartitionedShardingType(xla::OpSharding::Type sharding) { } // namespace +absl::StatusOr> GetDimensionIndicesAndNumSplitsFromSharding( + const xla::OpSharding& sharding) { + int64_t tensor_tile_rank = sharding.tile_assignment_dimensions_size(); + if (sharding.replicate_on_last_tile_dim()) { + tensor_tile_rank--; + } + + std::map dimension_to_splits_map; + for (int dim_index = 0; dim_index < tensor_tile_rank; ++dim_index) { + if (sharding.tile_assignment_dimensions(dim_index) > 1) { + dimension_to_splits_map.emplace( + dim_index, sharding.tile_assignment_dimensions(dim_index)); + } + } + + if (dimension_to_splits_map.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Arg has unnecessary tiled sharding: ", sharding.DebugString())); + } + + return dimension_to_splits_map; +} + int GetDimsFromXLAShardingTiled(const xla::OpSharding& xla_sharding) { return xla_sharding.tile_assignment_dimensions_size() - (xla_sharding.replicate_on_last_tile_dim() ? 1 : 0) - @@ -478,15 +520,25 @@ mlir::LogicalResult GetTileShardedOutputsToMerge( const xla::OpSharding& sharding = output_sharding_config[cluster_func_output_index]; outputs_to_merge->reserve(sharding.tile_assignment_devices_size()); - for (const auto logical_device_id : sharding.tile_assignment_devices()) { + for (const auto& core_id_and_index : + llvm::enumerate(sharding.tile_assignment_devices())) { + auto core_id = core_id_and_index.value(); + auto tile_index = core_id_and_index.index(); + + int last_tile_dim_size = *sharding.tile_assignment_dimensions().rbegin(); + if (sharding.replicate_on_last_tile_dim() && + tile_index % last_tile_dim_size != 0) { + continue; + } + int region_output_index; - auto status = LookupClusterToCoreIndex( - location, cluster_to_core_index, logical_device_id, - cluster_func_output_index, ®ion_output_index); + auto status = LookupClusterToCoreIndex(location, cluster_to_core_index, + core_id, cluster_func_output_index, + ®ion_output_index); if (failed(status)) return mlir::failure(); const auto output_from_logical_device = - new_parallel_execute.GetRegionOutputs( - cluster_idx + logical_device_id)[region_output_index]; + new_parallel_execute.GetRegionOutputs(cluster_idx + + core_id)[region_output_index]; outputs_to_merge->emplace_back(output_from_logical_device); } @@ -518,12 +570,18 @@ mlir::LogicalResult HandleTileShardedOutputs( // devices to a single replica output. const xla::OpSharding& sharding = output_sharding_config[cluster_func_output_index]; - int concat_dimension = sharding.tile_assignment_dimensions_size() - 1; - for (auto num_splits : llvm::reverse(sharding.tile_assignment_dimensions())) { - if (num_splits == 1) { - --concat_dimension; - continue; - } + + auto dimension_to_splits_map = + GetDimensionIndicesAndNumSplitsFromSharding(sharding); + if (!dimension_to_splits_map.ok()) { + LOG(ERROR) << dimension_to_splits_map.status(); + return mlir::failure(); + } + + for (auto it = dimension_to_splits_map->rbegin(); + it != dimension_to_splits_map->rend(); ++it) { + int concat_dimension = it->first; + int num_splits = it->second; llvm::SmallVector new_outputs; new_outputs.reserve(num_splits); @@ -539,7 +597,6 @@ mlir::LogicalResult HandleTileShardedOutputs( } std::swap(new_outputs, outputs_to_merge); - --concat_dimension; } assert(outputs_to_merge.size() == 1); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h index 6295be3776416e..ab22eb978214ad 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h @@ -16,8 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_ +#include #include +#include "absl/status/statusor.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -122,6 +124,9 @@ bool IsSplitSharding(const xla::OpSharding& sharding); // REPLICATED type and replicated OTHER type. bool IsReplicatedSharding(const xla::OpSharding& sharding); +// Returns a map of dimension indices and number of splits for tiled sharding. +absl::StatusOr> GetDimensionIndicesAndNumSplitsFromSharding( + const xla::OpSharding& sharding); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_ diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc index d13eb1dbdfc901..9008591c2800db 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc @@ -650,28 +650,6 @@ Status GetStepMarkerLocation(const Node& replicate_node, return absl::OkStatus(); } -// Extracts a map of dimension and number of splits for tiled input from xla -// sharding attribute. -Status GetDimensionIndicesAndNumSplitsFromSharding( - const xla::OpSharding& sharding, std::map* split_dimension_map) { - int64_t tensor_tile_rank = sharding.tile_assignment_dimensions_size(); - if (sharding.replicate_on_last_tile_dim()) { - tensor_tile_rank--; - } - for (int dim_index = 0; dim_index < tensor_tile_rank; dim_index++) { - if (sharding.tile_assignment_dimensions(dim_index) > 1) { - split_dimension_map->emplace( - dim_index, sharding.tile_assignment_dimensions(dim_index)); - } - } - - if (split_dimension_map->empty()) { - return absl::InvalidArgumentError(absl::StrCat( - "Arg has unnecessary tiled sharding: ", sharding.DebugString())); - } - return absl::OkStatus(); -} - // Updates contents of the function with `function_name` in function library // definition `flib_def` to `new_graph`. This is required when graph // transformation happens inside a function call body. @@ -861,9 +839,8 @@ StatusOr CreateOrGetSplitNodesForInputSharding( } // Maps input dimension and number of splits with which the // dimension sharded. - std::map split_dimension_map; - TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding( - sharding, &split_dimension_map)); + TF_ASSIGN_OR_RETURN(auto split_dimension_map, + GetDimensionIndicesAndNumSplitsFromSharding(sharding)); TF_RET_CHECK(!split_dimension_map.empty()) << "Unnecessary sharding attribute found."; @@ -1280,9 +1257,8 @@ StatusOr CreateConcatNodesForRetval( const PartialTensorShape& inferred_shape, int replica_id, const std::vector& orig_inputs, Graph* graph, absl::string_view device) { - std::map split_dimension_map; - TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding( - sharding, &split_dimension_map)); + TF_ASSIGN_OR_RETURN(auto split_dimension_map, + GetDimensionIndicesAndNumSplitsFromSharding(sharding)); std::vector inputs_to_sharded_retval = orig_inputs; bool has_paddings = false; From d952cc21e0ad6601f46ffc295072ad32d0e298d0 Mon Sep 17 00:00:00 2001 From: Jian Cai Date: Thu, 28 Mar 2024 14:35:50 -0700 Subject: [PATCH 053/124] Add a field to StreamZ metric /tensorflow/core/tf_mlir_bridge_first_phase_count Add a filed for the type of TF2XLA Phase 1 Bridge , i.e. Replicated Bridge vs. Non-replicated Bridge. PiperOrigin-RevId: 620060074 --- .../tensorflow/transforms/host_runtime/BUILD | 3 +- .../lower_cluster_to_runtime_ops.cc | 14 +++++-- .../lower_cluster_to_runtime_ops_test.cc | 9 +++-- .../mlir/tensorflow/utils/attribute_utils.h | 12 ++++++ tensorflow/compiler/mlir/tf2xla/api/v1/BUILD | 3 +- .../compiler/mlir/tf2xla/api/v1/cluster_tf.cc | 15 +++---- .../mlir/tf2xla/api/v1/cluster_tf_test.cc | 15 +++++-- tensorflow/compiler/mlir/tf2xla/api/v2/BUILD | 6 +-- .../compiler/mlir/tf2xla/api/v2/cluster_tf.cc | 20 ++++++---- .../mlir/tf2xla/api/v2/cluster_tf_test.cc | 27 +++++++++---- tensorflow/compiler/tf2xla/BUILD | 2 +- .../compiler/tf2xla/mlir_bridge_pass.cc | 39 ++++++++++++++----- tensorflow/core/framework/metrics.cc | 12 +++--- tensorflow/core/framework/metrics.h | 6 ++- 14 files changed, 127 insertions(+), 56 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD index 5d500453d17fe0..15f339eccd2f93 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD @@ -26,10 +26,10 @@ cc_library( deps = [ ":runtime_passes", "//tensorflow/compiler/jit:flags_headers", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:bridge_logger", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", "//tensorflow/core:framework", "//tensorflow/core:lib_proto_parsing", @@ -62,6 +62,7 @@ tf_cc_test( ":lower_cluster_to_runtime_ops", "//tensorflow/compiler/mlir:register_common_dialects", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc index 6f46766a3250fa..713e9080f2e03b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/runtime_passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" @@ -121,6 +122,7 @@ void CreateNonTPULowerClusterToRuntimeOpsPassPipeline( // TODO(b/306728216): Move this out of the Bridge component and into a Host // runtime component. tensorflow::Status RecordIfErrorStatus(const std::string error_prefix, + std::string bridge_type, tsl::DeviceType device_type, absl::Status status) { if (status.ok()) { @@ -129,11 +131,12 @@ tensorflow::Status RecordIfErrorStatus(const std::string error_prefix, VLOG(2) << error_prefix << " " << status; tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( - device_type.type_string(), /*bridge_version=*/"v2", + bridge_type, + /*bridge_version=*/mlir::TF::kMlirPh1BridgeCounterV2, + device_type.type_string(), /*fallback_enabled=*/false, /*result=*/"failure"); - constexpr char kBridgeComponent[] = "TFXLABridge"; std::string bridge_subcomponent = "TFXLA_PHASE_ONE_MLIR_TPU_BRIDGE"; tsl::OkOrSetErrorCounterPayload( @@ -144,7 +147,7 @@ tensorflow::Status RecordIfErrorStatus(const std::string error_prefix, bridge_subcomponent = "TFXLA_PHASE_ONE_MLIR_CPU/GPU_BRIDGE"; } - tsl::error_logging::Log(kBridgeComponent, bridge_subcomponent, + tsl::error_logging::Log(mlir::TF::kBridgeComponent, bridge_subcomponent, status.ToString()) .IgnoreError(); @@ -194,10 +197,13 @@ absl::Status RunLowerClusterToRuntimeOpsPassPipeline( module, llvm::StringRef(), &runtime_lowering); } + std::string bridge_type = xla_device_type == DeviceType(DEVICE_TPU_XLA_JIT) + ? mlir::TF::kMlirPh1BridgeCounterReplicated + : mlir::TF::kMlirPh1BridgeCounterNonReplicated; auto result_status = diag_handler.ConsumeStatus(); TF_RETURN_IF_ERROR( RecordIfErrorStatus(/*error_prefix=*/"lower_cluster_to_runtime", - xla_device_type, result_status)); + bridge_type, xla_device_type, result_status)); return absl::OkStatus(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc index 3e3e8db504f1da..1f0cf146203de2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/register_common_dialects.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/platform/env.h" @@ -167,9 +168,11 @@ TEST_F(LowerClusterToRuntimeOpsTest, ErrorsWithBadCluster) { *mlir_module_, DeviceType(DEVICE_TPU_XLA_JIT)) .ok()); - EXPECT_EQ(compilation_status.Delta("XLA_TPU_JIT", "v2", "fallback_disabled", - "failure"), - 1); + EXPECT_EQ( + compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated, + mlir::TF::kMlirPh1BridgeCounterV2, "XLA_TPU_JIT", + "fallback_disabled", "failure"), + 1); } TEST_F(LowerClusterToRuntimeOpsTest, DumpsPipelinePasses) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h index b50135c9bdfac3..5a99806d4295f3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h @@ -121,6 +121,18 @@ inline constexpr llvm::StringRef kDynamicArgIndexAttr = "_dynamic_arg_index"; inline constexpr llvm::StringRef kParallelExecAnnotation = "_parallel_execution_ids"; +// Logging + +// Name of component for error logging. This name is fixed and required to +// enable logging. +inline const char kBridgeComponent[] = "TFXLABridge"; +inline const char kMlirPh1BridgeCounterReplicated[] = "replicated"; +inline const char kMlirPh1BridgeCounterNonReplicated[] = "nonreplicated"; +inline const char kMlirPh1BridgeCounterV1[] = "v1"; +inline const char kMlirPh1BridgeCounterV2[] = "v2"; +inline const char kMlirPh1BridgeCounterTpu[] = "tpu"; +inline const char kMlirPh1BridgeCounterNonTpu[] = "cpu/gpu"; + // Copies attributes that satisfy the given predicate from `from` to `to`. template void CopyAttributes(Operation *from, Operation *to, Predicate P) { diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD index 3b55c2f954a22b..53a65bd3ae3662 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -194,6 +194,7 @@ cc_library( ], deps = [ ":tf_dialect_to_executor", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:bridge_logger", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", @@ -232,6 +233,7 @@ tf_cc_test( deps = [ ":cluster_tf", "//tensorflow/compiler/mlir:register_common_dialects", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", "//tensorflow/core/lib/monitoring:cell_reader", "//tensorflow/core/platform:resource_loader", @@ -241,7 +243,6 @@ tf_cc_test( "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/lib/monitoring:test_utils", "@local_tsl//tsl/platform:status", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc index 09209d8673524c..38c11ec857f072 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" @@ -60,10 +61,6 @@ using mlir::func::FuncOp; namespace { -// Name of component for error logging. This name is fixed and required to -// enable logging. -constexpr char kBridgeComponent[] = "TFXLABridge"; - void CreateReplicatedBridgePipelineV1(OpPassManager &pm) { pm.addPass(mlir::tf2xla::internal::CreateInferenceMetricsPass()); @@ -152,10 +149,12 @@ tensorflow::Status RecordStatusIfError(const std::string error_prefix, } tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( - /*device_type=*/"tpu", /*bridge_version=*/"v1", + /*bridge_type=*/mlir::TF::kMlirPh1BridgeCounterReplicated, + /*bridge_version=*/mlir::TF::kMlirPh1BridgeCounterV1, + /*device_type*/ mlir::TF::kMlirPh1BridgeCounterTpu, /*fallback_enabled=*/is_in_fallback_enabled_mode, /*result=*/"failure"); - tsl::error_logging::Log(kBridgeComponent, + tsl::error_logging::Log(mlir::TF::kBridgeComponent, "TFXLA_PHASE_ONE_MLIR_TPU_V1_COMPAT_BRIDGE", status.ToString()) .IgnoreError(); @@ -221,7 +220,9 @@ tensorflow::Status RunSessionTf2xlaClusteringBridge( RunClusteringPipelineOnSubmodule(module, is_in_fallback_enabled_mode)); tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( - /*device_type=*/"tpu", /*bridge_version=*/"v1", + /*bridge_type=*/mlir::TF::kMlirPh1BridgeCounterReplicated, + /*bridge_version=*/mlir::TF::kMlirPh1BridgeCounterV1, + /*device_type*/ mlir::TF::kMlirPh1BridgeCounterTpu, /*n_fallback_enabled*/ is_in_fallback_enabled_mode, /*result=*/"success"); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc index 44eafb25f579c8..e674989d2174ba 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "tensorflow/compiler/mlir/register_common_dialects.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/platform/resource_loader.h" #include "tsl/lib/core/status_test_util.h" @@ -84,8 +85,11 @@ TEST_F(SessionClusterTensorflowDialectTest, ClustersTf) { TF_EXPECT_OK( RunSessionTf2xlaClusteringBridge(*mlir_module_, /*is_in_fallback_enabled_mode=*/false)); - EXPECT_EQ( - compilation_status.Delta("tpu", "v1", "fallback_disabled", "success"), 1); + EXPECT_EQ(compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated, + mlir::TF::kMlirPh1BridgeCounterV1, + mlir::TF::kMlirPh1BridgeCounterTpu, + "fallback_disabled", "success"), + 1); } TEST_F(SessionClusterTensorflowDialectTest, FailsWithMultipleSubmodules) { @@ -98,8 +102,11 @@ TEST_F(SessionClusterTensorflowDialectTest, FailsWithMultipleSubmodules) { /*is_in_fallback_enabled_mode=*/false) .ok()); - EXPECT_EQ( - compilation_status.Delta("tpu", "v1", "fallback_disabled", "failure"), 1); + EXPECT_EQ(compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated, + mlir::TF::kMlirPh1BridgeCounterV1, + mlir::TF::kMlirPh1BridgeCounterTpu, + "fallback_disabled", "failure"), + 1); } } // namespace diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD index a92239e8dbba69..545203ad20ea23 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD @@ -119,12 +119,11 @@ cc_library( ], deps = [ ":device_type_proto_cc", - ":tf_dialect_to_executor", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", - "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops", "//tensorflow/compiler/mlir/tf2xla/internal:clustering_bridge_passes", "//tensorflow/compiler/mlir/tf2xla/internal:logging_hooks", "//tensorflow/core:framework", @@ -133,7 +132,6 @@ cc_library( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:stacktrace", "//tensorflow/core/platform:status", - "//tensorflow/core/tpu:tpu_defs", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@llvm-project//llvm:Support", @@ -143,7 +141,6 @@ cc_library( "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:error_logging", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", ], ) @@ -159,6 +156,7 @@ tf_cc_test( ":cluster_tf", "//tensorflow/compiler/mlir:register_common_dialects", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", "//tensorflow/core/lib/monitoring:cell_reader", "//tensorflow/core/platform:resource_loader", diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc index 23480374032aaa..41df5eb0750459 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/device_type.pb.h" @@ -52,8 +53,6 @@ using mlir::OpPassManager; using mlir::PassManager; using mlir::func::FuncOp; -constexpr char kBridgeComponent[] = "TFXLABridge"; - // Run the TF XLA Bridge based on the input pipeline, which can be either TPU // bridge pipeline or non TPU bridge pipeline. tensorflow::Status RunTFXLABridge( @@ -114,6 +113,7 @@ tensorflow::Status RunTFXLABridge( tensorflow::Status RecordIfErrorStatus(const std::string error_prefix, bool fallback_enabled, + std::string bridge_type, std::string device_type, absl::Status status) { if (status.ok()) { @@ -122,7 +122,7 @@ tensorflow::Status RecordIfErrorStatus(const std::string error_prefix, VLOG(2) << error_prefix << " " << status; tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( - device_type, /*bridge_version=*/"v2", + /*bridge_type*/ bridge_type, /*bridge_version=*/"v2", device_type, /*fallback_enabled=*/fallback_enabled, /*result=*/"failure"); @@ -135,7 +135,7 @@ tensorflow::Status RecordIfErrorStatus(const std::string error_prefix, bridge_subcomponent = "TFXLA_PHASE_ONE_MLIR_CPU/GPU_BRIDGE"; } - tsl::error_logging::Log(kBridgeComponent, bridge_subcomponent, + tsl::error_logging::Log(mlir::TF::kBridgeComponent, bridge_subcomponent, status.ToString()) .IgnoreError(); @@ -162,8 +162,9 @@ void CreateReplicatedClusteringPipelineV2(OpPassManager &pm) { tensorflow::Status RunFunctionTf2xlaClusteringBridge( ModuleOp module, bool is_supported_by_replicated_brige, bool is_in_fallback_enabled_mode, llvm::StringRef module_name) { - std::string device_type_filter = - is_supported_by_replicated_brige ? "tpu" : "cpu/gpu"; + std::string device_type = is_supported_by_replicated_brige + ? mlir::TF::kMlirPh1BridgeCounterTpu + : mlir::TF::kMlirPh1BridgeCounterNonTpu; VLOG(2) << (is_supported_by_replicated_brige ? "Replicated" : "NonReplicated") @@ -186,14 +187,17 @@ tensorflow::Status RunFunctionTf2xlaClusteringBridge( }, module_name, /*dump_prefix=*/"tf_xla_bridge_v2_nonreplicated"); + std::string bridge_type = is_supported_by_replicated_brige + ? mlir::TF::kMlirPh1BridgeCounterReplicated + : mlir::TF::kMlirPh1BridgeCounterNonReplicated; // TODO(b/317798386): add is_supported_by_replicated_brige as a filter. TF_RETURN_IF_ERROR(RecordIfErrorStatus( /*error_prefix=*/"clustering_v2", is_in_fallback_enabled_mode, - device_type_filter, clustering_status)); + bridge_type, device_type, clustering_status)); // TODO(b/317798386): add is_supported_by_replicated_brige as a filter. tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( - device_type_filter, /*bridge_version=*/"v2", + bridge_type, /*bridge_version=*/"v2", device_type, /*fallback_enabled=*/is_in_fallback_enabled_mode, /*result=*/"success"); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc index c4a96702533c49..a5f64a91cd8cb4 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/register_common_dialects.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/platform/resource_loader.h" #include "tsl/lib/core/status_test_util.h" @@ -94,8 +95,11 @@ TEST_F(FunctionClusterTensorflowDialectTest, ClustersTfReplicatedBridge) { FuncOp main = mlir_module_->lookupSymbol("main"); ASSERT_TRUE(main); - EXPECT_EQ( - compilation_status.Delta("tpu", "v2", "fallback_disabled", "success"), 1); + EXPECT_EQ(compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated, + mlir::TF::kMlirPh1BridgeCounterV2, + mlir::TF::kMlirPh1BridgeCounterTpu, + "fallback_disabled", "success"), + 1); } TEST_F(FunctionClusterTensorflowDialectTest, @@ -118,8 +122,11 @@ TEST_F(FunctionClusterTensorflowDialectTest, }); EXPECT_TRUE(has_cluster_op); - EXPECT_EQ( - compilation_status.Delta("tpu", "v2", "fallback_disabled", "success"), 1); + EXPECT_EQ(compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated, + mlir::TF::kMlirPh1BridgeCounterV2, + mlir::TF::kMlirPh1BridgeCounterTpu, + "fallback_disabled", "success"), + 1); } TEST_F(FunctionClusterTensorflowDialectTest, ClustersTFNonReplicatedBridge) { @@ -135,7 +142,10 @@ TEST_F(FunctionClusterTensorflowDialectTest, ClustersTFNonReplicatedBridge) { ASSERT_TRUE(main); EXPECT_EQ( - compilation_status.Delta("cpu/gpu", "v2", "fallback_disabled", "success"), + compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterNonReplicated, + mlir::TF::kMlirPh1BridgeCounterV2, + mlir::TF::kMlirPh1BridgeCounterNonTpu, + "fallback_disabled", "success"), 1); } @@ -148,8 +158,11 @@ TEST_F(FunctionClusterTensorflowDialectTest, LogsFallbackMode) { *mlir_module_, /*is_supported_by_replicated_brige*/ true, /*is_in_fallback_enabled_mode=*/true)); - EXPECT_EQ( - compilation_status.Delta("tpu", "v2", "fallback_enabled", "success"), 1); + EXPECT_EQ(compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated, + mlir::TF::kMlirPh1BridgeCounterV2, + mlir::TF::kMlirPh1BridgeCounterTpu, + "fallback_enabled", "success"), + 1); } } // namespace diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 16315f71d9652c..01e85cc7c6cfc7 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -1172,11 +1172,11 @@ cc_library( hdrs = ["mlir_bridge_pass.h"], visibility = [":internal"], deps = [ - ":tf2xla_defs", ":xla_op_registry", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:device_util", "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops", "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy", diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 0402508fe92f56..c24654c894b34f 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -29,13 +29,14 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h" #include "tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h" -#include "tensorflow/compiler/tf2xla/tf2xla_defs.h" +// #include "tensorflow/compiler/tf2xla/tf2xla_defs.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/device.h" @@ -162,16 +163,28 @@ MlirOptimizationPassState GetPassStateImpl( << " Bridge, disabled by user. " "The fallback will evaluate."; metrics::UpdateTfMlirBridgeFirstPhaseCounter( - is_supported_by_replicated_brige ? "tpu" : "cpu/gpu", "v2", true, - "disabled_by_user"); + /*bridge_type*/ is_supported_by_replicated_brige + ? mlir::TF::kMlirPh1BridgeCounterReplicated + : mlir::TF::kMlirPh1BridgeCounterNonReplicated, + /*bridge_version*/ mlir::TF::kMlirPh1BridgeCounterV2, + /*device_type*/ + is_supported_by_replicated_brige + ? mlir::TF::kMlirPh1BridgeCounterTpu + : mlir::TF::kMlirPh1BridgeCounterNonTpu, + /*fallback_enabled*/ true, + /*result*/ "disabled_by_user"); return MlirOptimizationPassState::Disabled; } case MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis: // Graph analysis only runs on TPU graph. VLOG(1) << "Skipping MLIR TPU Bridge, disabled because the " "graph has unsupported features. The fallback will evaluate."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v2", true, - "invalid_graph"); + metrics::UpdateTfMlirBridgeFirstPhaseCounter( + /*bridge_type*/ mlir::TF::kMlirPh1BridgeCounterReplicated, + /*bridge_version*/ mlir::TF::kMlirPh1BridgeCounterV2, + /*device_type*/ mlir::TF::kMlirPh1BridgeCounterTpu, + /*fallback_enabled*/ true, + /*result*/ "invalid_graph"); // We set `uses_uninitialized_resource_args` to false here because the // first phase of the bridge is not affected by uninitialized resource // args. @@ -305,16 +318,24 @@ MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState( VLOG(1) << "Skipping MLIR Replicated Bridge V1 Compat, MLIR Replicated " "bridge disabled " "by user. Fallback will evaluate."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v1", true, - "disabled_by_user"); + metrics::UpdateTfMlirBridgeFirstPhaseCounter( + /*bridge_type*/ mlir::TF::kMlirPh1BridgeCounterReplicated, + /*bridge_version*/ mlir::TF::kMlirPh1BridgeCounterV1, + /*device_type*/ mlir::TF::kMlirPh1BridgeCounterTpu, + /*fallback_enabled*/ true, + /*result*/ "disabled_by_user"); return MlirOptimizationPassState::Disabled; case MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis: VLOG(1) << "Skipping MLIR Replicated Bridge V1 Compat, MLIR Replicated " "bridge disabled " "because graph has unsupported features. Old bridge will " "evaluate."; - metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v1", true, - "invalid_graph"); + metrics::UpdateTfMlirBridgeFirstPhaseCounter( + /*bridge_type*/ mlir::TF::kMlirPh1BridgeCounterReplicated, + /*bridge_version*/ mlir::TF::kMlirPh1BridgeCounterV1, + /*device_type*/ mlir::TF::kMlirPh1BridgeCounterTpu, + /*fallback_enabled*/ true, + /*result*/ "invalid_graph"); // We set `uses_uninitialized_resource_args` to false here because the // first phase of the bridge is not affected by uninitialized resource // args. diff --git a/tensorflow/core/framework/metrics.cc b/tensorflow/core/framework/metrics.cc index 902b25bd12356d..863f0c209513ac 100644 --- a/tensorflow/core/framework/metrics.cc +++ b/tensorflow/core/framework/metrics.cc @@ -443,10 +443,10 @@ auto* eager_client_error_counter = tsl::monitoring::Counter<2>::New( "Count the errors in eager client as a central place.", "error_source", "error_type"); -auto* mlir_bridge_first_phase_counter = tsl::monitoring::Counter<4>::New( +auto* mlir_bridge_first_phase_counter = tsl::monitoring::Counter<5>::New( "/tensorflow/core/tf_mlir_bridge_first_phase_count", - "Tracks processing state in first phase of mlir bridge", "device", - "version", "fallback", "result"); + "Tracks processing state in first phase of mlir bridge", "bridge", + "version", "device", "fallback", "result"); auto* mlir_second_phase_count = tensorflow::monitoring::Counter<1>::New( "/tensorflow/core/tf2xla/api/v2/phase2_compilation_status" /*metric_name*/, @@ -948,14 +948,16 @@ void TestDelta::Reset() { last_value_ = cell_->value(); } int64 TestDelta::Get() { return cell_->value() - last_value_; } -void UpdateTfMlirBridgeFirstPhaseCounter(const std::string& device_type, +void UpdateTfMlirBridgeFirstPhaseCounter(const std::string& bridge_type, const std::string& bridge_version, + const std::string& device_type, bool fallback_enabled, const std::string& result) { std::string fallback_status = fallback_enabled ? "fallback_enabled" : "fallback_disabled"; mlir_bridge_first_phase_counter - ->GetCell(device_type, bridge_version, fallback_status, result) + ->GetCell(bridge_type, bridge_version, device_type, fallback_status, + result) ->IncrementBy(1); } diff --git a/tensorflow/core/framework/metrics.h b/tensorflow/core/framework/metrics.h index 955a6461bd7c36..1a6ba8a88bf890 100644 --- a/tensorflow/core/framework/metrics.h +++ b/tensorflow/core/framework/metrics.h @@ -341,12 +341,14 @@ int64_t GetFunctionGraphOptimizationCacheLoadCount( // Records the activity of the first phase of the mlir bridge using the // tf_metadata.tf_mlir_bridge_first_phase_count metric. -// device_type: tpu, cpu, gpu, etc. +// bridge_type: replicated, nonreplicated, etc. // bridge_version: v1 compat, v2, etc. +// device_type: tpu, cpu, gpu, etc. // fallback_enabled: true if fallback will happen, false if not // result: outcome of bridge (success, failure, disabled, invalid_graph, etc.) -void UpdateTfMlirBridgeFirstPhaseCounter(const std::string& device_type, +void UpdateTfMlirBridgeFirstPhaseCounter(const std::string& bridge_type, const std::string& bridge_version, + const std::string& device_type, bool fallback_enabled, const std::string& result); From 3107f4090909cd323051c7926bfbee8a2b1bcd55 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 28 Mar 2024 14:40:01 -0700 Subject: [PATCH 054/124] [MLIR Exporter] Move created `NodeDef` and `AttrValue` protocol buffers into the exported `Graph`. PiperOrigin-RevId: 620061282 --- .../mlir/tensorflow/translate/export_graphdef.cc | 6 +++--- .../compiler/mlir/tensorflow/utils/export_utils.cc | 14 ++++++++------ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 6042ae37ee8fa2..550ff9fee330a4 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -421,7 +421,7 @@ Status Exporter::AddInstructionNode(Operation* inst) { inst, name, /*ignore_unregistered_attrs=*/false)); UseOriginalFunctionNames(*node_def); - TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(*node_def)); + TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(std::move(*node_def))); DCHECK(node != nullptr); nodes_[inst] = node; return OkStatus(); @@ -436,7 +436,7 @@ bool IsEntryFunctionArg(BlockArgument arg) { Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index, llvm::StringRef name) { TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name)); - TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(*node_def)); + TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(std::move(*node_def))); args_[arg] = node; return OkStatus(); } @@ -455,7 +455,7 @@ Status Exporter::AddFetchNode(FuncOp function, mlir::tf_executor::FetchOp fetch, GetReturnNode(function, operand_and_idx.value(), operand_and_idx.index(), names.empty() ? "" : names[operand_and_idx.index()])); - TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(*node_def)); + TF_ASSIGN_OR_RETURN(Node * node, graph_->AddNode(std::move(*node_def))); return_nodes.push_back(node); } return OkStatus(); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index f9dc740cee1aae..f01a3f0e09d19b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -61,6 +61,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace { @@ -395,12 +396,12 @@ Status ConvertAttributes( if (auto symbol_ref = attr.dyn_cast()) { TF_RETURN_IF_ERROR( ConvertAttribute(symbol_ref.cast(), &value)); - func_call_attrs[string(name)] = value; + func_call_attrs[string(name)] = std::move(value); continue; } if (auto func_attr = attr.dyn_cast()) { TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, remove_ref_type, &value)); - func_call_attrs[string(name)] = value; + func_call_attrs[string(name)] = std::move(value); continue; } if (attr.isa()) { @@ -434,13 +435,14 @@ Status ConvertAttributes( TF_RET_CHECK(name_tokens.size() <= 2); auto it = func_call_attrs.find(name_tokens[0]); if (it == func_call_attrs.end()) { - (*values)[string(name)] = value; + (*values)[string(name)] = std::move(value); } else { - (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = value; + (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = + std::move(value); } } - for (const auto& it : func_call_attrs) { - (*values)[it.first] = it.second; + for (auto& it : func_call_attrs) { + (*values)[it.first] = std::move(it.second); } return OkStatus(); } From 86e08a6048d564c37c7be6a51826c4ac9f581c76 Mon Sep 17 00:00:00 2001 From: Wren Romano Date: Thu, 28 Mar 2024 14:50:54 -0700 Subject: [PATCH 055/124] [XLA:Python] Add python function to convert `xla::LiteralProto` into a tuple-tree of `numpy.ndarray`. This is intended for internal debugging use. It cannot be used on OSS because the relevant protobufs are not part of the public API. (Though it must not break the OSS build, naturally.) PiperOrigin-RevId: 620064326 --- third_party/xla/xla/python/tools/BUILD | 93 +++++++++ third_party/xla/xla/python/tools/__init__.py | 0 third_party/xla/xla/python/tools/_types.cc | 128 ++++++++++++ third_party/xla/xla/python/tools/_types.pyi | 23 +++ third_party/xla/xla/python/tools/types.py | 51 +++++ .../xla/xla/python/tools/types_test.py | 183 ++++++++++++++++++ 6 files changed, 478 insertions(+) create mode 100644 third_party/xla/xla/python/tools/BUILD create mode 100644 third_party/xla/xla/python/tools/__init__.py create mode 100644 third_party/xla/xla/python/tools/_types.cc create mode 100644 third_party/xla/xla/python/tools/_types.pyi create mode 100644 third_party/xla/xla/python/tools/types.py create mode 100644 third_party/xla/xla/python/tools/types_test.py diff --git a/third_party/xla/xla/python/tools/BUILD b/third_party/xla/xla/python/tools/BUILD new file mode 100644 index 00000000000000..ac45d9d3f10c66 --- /dev/null +++ b/third_party/xla/xla/python/tools/BUILD @@ -0,0 +1,93 @@ +load("@local_tsl//tsl:tsl.default.bzl", "tsl_pybind_extension") + +# NOTE: We can't use `pytype_pybind_extension` nor `pytype_strict_contrib_test` +# because the OSS versions of these files do not include ports of those rules. +# We must instead use `tsl_pybind_extension` and `py_strict_test`. +load("//xla:pytype.default.bzl", "pytype_strict_library") +load("//xla:strict.default.bzl", "py_strict_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +exports_files([ + "__init__.py", + "types.py", + "_types.pyi", +]) + +# NOTE: This wrapper library is necessary in order to capture the Python +# dependencies of our extension (namely `ml_dtypes`). Although the +# underlying `pybind_extension` rule has a `py_deps` argument for capturing +# such dependencies directly, the `tsl_pybind_extension` rule doesn't expose +# that `py_deps` argument for us to use. +# +# NOTE: On the OSS side, the `pytype_strict_library` rule is changed into +# the non-typed rule, which in turn causes an error about the `pytype_srcs` +# field. The "..:xla_client" target gets around this by adding a custom +# copybara rule; but in lieu of adding yet another custom rule to maintain, +# we just use the generic copybara mechanism for commenting the field out +# on the OSS side. +# TODO(wrengr,phawkins): Once cl/619904840 lands, we can remove the +# pragma and the preceding commentary. +pytype_strict_library( + name = "types", + srcs = ["types.py"], + # copybara:uncomment pytype_srcs = ["_types.pyi"], + srcs_version = "PY3", + # Cannot build this on OSS because the ":xla_data_proto_py_pb2" + # dependency isn't part of the public API. + tags = ["no_oss"], + visibility = ["//visibility:public"], + deps = [ + ":_types", # buildcleaner: keep + "//third_party/py/numpy", + "//xla:xla_data_proto_py_pb2", + "@ml_dtypes", + ], +) + +# NOTE: Copybara detects the `tsl_pybind_extension` rule and automatically +# injects the "@com_google_protobuf//:protobuf_python" python dependency +# required by "@pybind11_protobuf//pybind11_protobuf:native_proto_caster". +tsl_pybind_extension( + name = "_types", + srcs = ["_types.cc"], + pytype_deps = ["//third_party/py/numpy"], + pytype_srcs = ["_types.pyi"], + # Users should depend on ":types" instead. + visibility = ["//visibility:private"], + deps = [ + "//third_party/nanobind", + "//xla:literal", + "//xla:xla_data_proto_cc", + "//xla/python:logging", + "//xla/python:types", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/python/lib/core:numpy", + "@pybind11", + "@pybind11_abseil//pybind11_abseil:status_casters", + "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", + ], +) + +py_strict_test( + name = "types_test", + size = "small", + srcs = ["types_test.py"], + python_version = "PY3", + srcs_version = "PY3", + # Cannot build this on OSS because the ":xla_data_proto_py_pb2" + # dependency isn't part of the public API. + tags = ["no_oss"], + deps = [ + ":types", + #internal proto upb dep + "//third_party/py/numpy", + "//xla:xla_data_proto_py_pb2", + "@absl_py//absl/testing:absltest", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/third_party/xla/xla/python/tools/__init__.py b/third_party/xla/xla/python/tools/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/third_party/xla/xla/python/tools/_types.cc b/third_party/xla/xla/python/tools/_types.cc new file mode 100644 index 00000000000000..f18e360c13c7fb --- /dev/null +++ b/third_party/xla/xla/python/tools/_types.cc @@ -0,0 +1,128 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "third_party/nanobind/include/nanobind/nanobind.h" +#include "third_party/nanobind/include/nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "pybind11/detail/common.h" // from @pybind11 +#include "pybind11/numpy.h" // from @pybind11 +#include "pybind11/pybind11.h" // from @pybind11 +#include "pybind11/pytypes.h" // from @pybind11 +#include "pybind11_abseil/status_casters.h" // from @pybind11_abseil +#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf +#include "xla/literal.h" +#include "xla/python/logging.h" +#include "xla/python/types.h" +#include "xla/xla_data.pb.h" +// NOTE: The tsl-numpy header forbids importing the actual NumPy arrayobject.h +// header before tsl-numpy (whereas, importing pybind11-numpy before tsl-numpy +// is fine); however, tsl-numpy does reexport NumPy's arrayobject.h header. +// Since one of the TF headers above already includes tsl-numpy, therefore +// we must include it down here rather than including actual NumPy directly. +#include "tsl/python/lib/core/numpy.h" + +namespace py = ::pybind11; +namespace nb = ::nanobind; + +namespace { +absl::StatusOr MakeNdarray(const xla::LiteralProto& proto) { + auto m_lit = xla::Literal::CreateFromProto(proto); + if (!m_lit.ok()) { + // NOTE: The OSS version of XLA is still using an old version of + // Abseil (LTS branch, Aug 2023, Patch 1) which does not have the + // `AbslStringify` interface for implicitly converting `absl::Status` + // into the `absl::AlphaNum` required by `absl::StrCat`. Therefore we + // inline the latest definition of the `AbslStringify` overload. + throw py::value_error(absl::StrCat( + "Cannot `xla::Literal::CreateFromProto`: ", + m_lit.status().ToString(absl::StatusToStringMode::kWithEverything))); + } + + // Move (not copy) the literal onto the heap, for sharing with Python. + auto lit = std::make_shared(std::move(m_lit).value()); + + TF_ASSIGN_OR_RETURN(auto nbobj, xla::LiteralToPython(std::move(lit))); + + // Convert `nb::object` into `py::object`. + return py::reinterpret_steal(nbobj.release().ptr()); +} +} // namespace + +// NOTE: It seems insurmountable to get "native_proto_caster.h" to work +// with nanobind modules; therefore, we define our extension as a pybind11 +// module so that we can use `pybind11::module_::def`. +PYBIND11_MODULE(_types, py_m) { + // Initialize ABSL logging because code within XLA uses it. + // (As per `xla::Init` in "xla.cc"; though we don't need it ourselves.) +#ifndef PLATFORM_GOOGLE + xla::InitializeAbslLogging(); +#endif // PLATFORM_GOOGLE + + // Normally this would happen at the start of NB_MODULE, but since + // this is a pybind11 module we have to do this ourselves. + // (As per `xla::Init` in "xla.cc".) + nb::detail::init(NB_DOMAIN_STR); + + // Import implicit conversions from Python protobuf objects to C++ + // protobuf objects. + pybind11_protobuf::ImportNativeProtoCasters(); + + // Import implicit conversions from `absl::StatusOr` to Python exceptions. + // (The code for performing conversions is easy enough to port to nanobind; + // albeit, the conversion calls themselves have to be made explicit, + // since `nb::detail::type_caster` disallows raising exceptions.) + py::google::ImportStatusModule(); + + // Import the 'ml_dtypes' module; which is implicitly required by + // `xla::LiteralToPython`. + // NOTE: If the `tsl_pybind_extension` build rule allowed us to specify + // this as a py_dep, then importing the module here would mean that + // client Python code need not import the hidden dependency themselves. + // However, since `tsl_pybind_extension` does not allow specifying py_deps, + // if client rules do not themselves declare the dependency then this will + // generate a `ModuleNotFoundError` / `ImportError` exception. Hence why + // we define the "types.py" wrapper library to encapsulate the dependency. + py::module_::import("ml_dtypes"); + + // Ensure that tsl-numpy initializes datastructures of the actual-NumPy + // implementation, and does whatever else tsl-numpy needs. + tsl::ImportNumpy(); + + // Declare that C++ can `nb::cast` from `std::shared_ptr` + // to `nb::object`; which is implicitly required by `xla::LiteralToPython`. + // (FWIW: This also enables using `nb::type()` to get + // the Python-type-object associated with the C++ class.) + // + // NOTE: This does *not* mean that C++ can `py::cast` from `xla::Literal` + // to `py::object`. It's unclear whether we can simultaneously provide + // both nanobind and pybind11 bindings (if we wanted the latter). + nb::module_ nb_m = nb::cast(nb::borrow(py_m.ptr())); + nb::class_(nb_m, "Literal") + .def("__repr__", &xla::Literal::ToString); + + // We do not define `py_m.doc()` here, since it wouldn't be inherited + // by the "types.py" wrapper library. See there for the python docstring. + + // LINT.IfChange + py_m.def("make_ndarray", &MakeNdarray, py::arg("proto").none(false), + py::pos_only(), R"pbdoc( + Converts `tensorflow.compiler.xla.xla_data_pb2.LiteralProto` + into an `xla::Literal` and then converts that literal into a tree + of tuples with leaves being `numpy.ndarray` views of array-shaped + sub-literals. + )pbdoc"); + // LINT.ThenChange(_types.pyi) +} diff --git a/third_party/xla/xla/python/tools/_types.pyi b/third_party/xla/xla/python/tools/_types.pyi new file mode 100644 index 00000000000000..1ca5071367a0cd --- /dev/null +++ b/third_party/xla/xla/python/tools/_types.pyi @@ -0,0 +1,23 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Union +import numpy as np +from xla import xla_data_pb2 + +# LINT.IfChange +NdarrayTree = Union[np.ndarray, tuple['NdarrayTree', ...]] +def make_ndarray(proto: xla_data_pb2.LiteralProto, /) -> NdarrayTree: ... +# LINT.ThenChange(types.py, _types.cc) diff --git a/third_party/xla/xla/python/tools/types.py b/third_party/xla/xla/python/tools/types.py new file mode 100644 index 00000000000000..d13ee2241ed479 --- /dev/null +++ b/third_party/xla/xla/python/tools/types.py @@ -0,0 +1,51 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""tensorflow.compiler.xla.python.tools.types. + +This module provides Python bindings for various functions in +'tensorflow/compiler/xla/python/types.h'. It is primarily intended +to assist internal users in debugging things; and is not considered +part of the public API for OpenXLA. + +NOTE: This module *does* depend on Python protocol buffers; so beware! +The XLA Python bindings are currently packaged both as part of jaxlib and +as part of TensorFlow. Therefore, since we use protocol buffers here, +importing both jaxlib and TensorFlow may fail with duplicate protocol +buffer message definitions. +""" + +from typing import Union +# NOTE: `ml_dtypes` is implicitly required by `xla::LiteralToPython`. +# The entire goal of this wrapper library is to capture this dependency, +# so that client code need not be aware of it. +import ml_dtypes # pylint: disable=unused-import +import numpy +# NOTE: These protos are not part of TensorFlow's public API, therefore +# we cannot abide by [g-direct-tensorflow-import]. +# pylint: disable=g-direct-tensorflow-import,unused-import +from local_xla.xla import xla_data_pb2 +# pylint: enable=g-direct-tensorflow-import,unused-import + +# NOTE: `import as ` is required for names to be exported. +# See PEP 484 & +# pylint: disable=g-importing-member,useless-import-alias,unused-import +# LINT.IfChange +from ._types import ( + make_ndarray as make_ndarray, +) +# TODO(wrengr): We can't import the `NdarrayTree` defined in the pyi file. +# So re-defining it here for now. +NdarrayTree = Union[numpy.ndarray, tuple['NdarrayTree', ...]] +# LINT.ThenChange(_types.pyi) diff --git a/third_party/xla/xla/python/tools/types_test.py b/third_party/xla/xla/python/tools/types_test.py new file mode 100644 index 00000000000000..e056e05be24f35 --- /dev/null +++ b/third_party/xla/xla/python/tools/types_test.py @@ -0,0 +1,183 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import itertools +import math +import re +from typing import List, NamedTuple + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np + +# NOTE: These protos are not part of the public API, therefore we cannot +# abide by [g-direct-tensorflow-import]. +# pylint: disable=g-direct-tensorflow-import +from local_xla.xla import xla_data_pb2 +from xla.python.tools import types +# pylint: enable=g-direct-tensorflow-import + + +class MakeNdarrayInvalidTest(absltest.TestCase): + """Tests for invalid/unsupported arguments to `make_ndarray`.""" + + def setUp(self): + super().setUp() + self.assert_cannot_create_from_proto = self.assertRaisesRegex( + ValueError, re.escape('Cannot `xla::Literal::CreateFromProto`') + ) + + # NOTE: The `Literal(const Shape&, bool, ArrayValueState)` ctor does + # a CHECK forbidding `element_size_in_bits` from being specified; + # so we can't test anything about custom sizes here. + + def testMissingLayout(self): + # NOTE: `CreateFromProto` requires explicit `shape.layout.minor_to_major`. + # Though in principle it could use a default ctor instead, like we + # do in `make_named_parameter` below`. + pb = xla_data_pb2.LiteralProto( + shape=xla_data_pb2.ShapeProto( + element_type=xla_data_pb2.PrimitiveType.F64, + dimensions=[1, 2, 3], + ) + ) + with self.assert_cannot_create_from_proto: + types.make_ndarray(pb) + + def testMissingMinorToMajor(self): + # NOTE: `CreateFromProto` requires explicit `shape.layout.minor_to_major`. + # Though in principle it could use a default ctor instead, like we + # do in `make_named_parameter` below`. + pb = xla_data_pb2.LiteralProto( + shape=xla_data_pb2.ShapeProto( + element_type=xla_data_pb2.PrimitiveType.F64, + dimensions=[1, 2, 3], + layout=xla_data_pb2.LayoutProto(), + ) + ) + with self.assert_cannot_create_from_proto: + types.make_ndarray(pb) + + def testInvalidPrimitiveType(self): + # NOTE: The `is_dynamic_dimension` field isn't required by + # `CreateFromProto`; however, the `Shape(const ShapeProto&)` ctor + # will log warnings if we leave it unspecified. + pb = xla_data_pb2.LiteralProto( + shape=xla_data_pb2.ShapeProto( + element_type=xla_data_pb2.PrimitiveType.PRIMITIVE_TYPE_INVALID, + dimensions=[1, 2, 3], + is_dynamic_dimension=[False, False, False], + layout=xla_data_pb2.LayoutProto( + minor_to_major=[0, 1, 2], + ), + ) + ) + with self.assert_cannot_create_from_proto: + types.make_ndarray(pb) + + def testHasDimLevelTypes(self): + # NOTE: `CreateFromProto` forbids `dim_level_types` (even if all-dense). + pb = xla_data_pb2.LiteralProto( + shape=xla_data_pb2.ShapeProto( + element_type=xla_data_pb2.PrimitiveType.F64, + dimensions=[1, 2, 3], + is_dynamic_dimension=[False, False, False], + layout=xla_data_pb2.LayoutProto( + dim_level_types=[ + xla_data_pb2.DimLevelType.DIM_DENSE, + xla_data_pb2.DimLevelType.DIM_DENSE, + xla_data_pb2.DimLevelType.DIM_DENSE, + ], + minor_to_major=[0, 1, 2], + ), + ) + ) + with self.assert_cannot_create_from_proto: + types.make_ndarray(pb) + + +class MakeNdarrayValidTestParameter(NamedTuple): + testcase_name: str + proto: xla_data_pb2.LiteralProto + arr: np.ndarray + + +def make_named_parameter( + testcase_name: str, + dimensions: List[int], + data: List[float], +) -> MakeNdarrayValidTestParameter: + """Helper function to construct parameters for `MakeNdarrayValidTest`.""" + assert math.prod(dimensions) == len(data) + nd = len(dimensions) + proto = xla_data_pb2.LiteralProto( + shape=xla_data_pb2.ShapeProto( + element_type=xla_data_pb2.PrimitiveType.F64, + dimensions=dimensions, + is_dynamic_dimension=itertools.repeat(False, nd), + layout=xla_data_pb2.LayoutProto( + minor_to_major=range(nd), + ), + ), + f64s=data, + ) + arr = types.make_ndarray(proto) + return MakeNdarrayValidTestParameter(testcase_name, proto, arr) + + +@parameterized.named_parameters( + make_named_parameter('A', [2, 3], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), + make_named_parameter('B', [1, 2, 3], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), + make_named_parameter('C', [2, 3], [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]), + make_named_parameter('D', [3, 2], [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]), +) +class MakeNdarrayValidTest(parameterized.TestCase): + """Correctness tests for valid arguments to `make_ndarray`.""" + + def testHasCorrectDtype(self, proto, arr): + """Test that the result has the right dtype.""" + # Silence [unused-argument] warning. + del proto + # TODO(wrengr): Add pybind for `xla::PrimitiveTypeToDtype`, + # so that we can avoid hard-coding the expected np.dtype. + # Alternatively, we could use `xla_client.dtype_to_etype` (ideally + # after refactoring that into a small library, so we need not pull in + # all the rest of xla_client). + self.assertEqual(np.float64, arr.dtype) + + def testHasCorrectRank(self, proto, arr): + """Test that the result has the right rank.""" + self.assertLen(proto.shape.dimensions, arr.ndim) + + def testHasCorrectShape(self, proto, arr): + """Test that the result has the same/right shape.""" + self.assertTupleEqual(tuple(proto.shape.dimensions), arr.shape) + + def testHasCorrectData(self, proto, arr): + """Test that the result has the same/right data.""" + # TODO(wrengr): Figure out a way to abstract away the name of the + # proto field containing the data; so that we can test multiple types. + self.assertSequenceAlmostEqual(proto.f64s, list(np.nditer(arr))) + + # TODO(wrengr): Add tests for: + # * dynamic dimension sizes. + # * non-trivial `minor_to_major`. + # * problematic types {PRED,F16,C64,C128} are all handled correctly. + # * BF16 is handled correctly. + # * tuples are handled correctly + + +if __name__ == '__main__': + absltest.main() From 2d48a6d1eba05a768e3a8a09326c370518722086 Mon Sep 17 00:00:00 2001 From: Matthias Kramm Date: Thu, 28 Mar 2024 14:51:37 -0700 Subject: [PATCH 056/124] Reduce memory usage of convert_control_to_data_outputs. PiperOrigin-RevId: 620064529 --- .../tensorflow/analysis/resource_dataflow.cc | 71 ++++++++++++++----- .../tensorflow/analysis/resource_dataflow.h | 17 ++++- .../convert_control_to_data_outputs.cc | 6 +- 3 files changed, 71 insertions(+), 23 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc index d0a05e45617cf6..5ceda80490f688 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" #include "mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project #include "mlir/Analysis/DataFlowFramework.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -55,9 +54,6 @@ ResourceConstructingOps ResourceConstructingOps::EntryState(Value value) { tf_saved_model::GlobalTensorOp>(func, barg.getArgNumber(), symbol_table); ResourceConstructingOps result(global_tensor); - if (func.getArgAttr(barg.getArgNumber(), kCompositeDevice)) { - result.is_on_composite_device = true; - } return result; } } else if (auto vh = dyn_cast(value.getDefiningOp())) { @@ -75,17 +71,47 @@ ResourceConstructingOps ResourceConstructingOps::join( ResourceConstructingOps ret; ret.ops.insert(lhs.ops.begin(), lhs.ops.end()); ret.ops.insert(rhs.ops.begin(), rhs.ops.end()); - ret.is_on_composite_device = - lhs.is_on_composite_device || rhs.is_on_composite_device; return ret; } void ResourceConstructingOps::print(raw_ostream &os) const { llvm::interleaveComma(ops, os << "["); + os << "]"; +} + +IsComposite::IsComposite(Operation *op) {} + +IsComposite IsComposite::EntryState(MLIRContext *context) { + return IsComposite(); +} + +IsComposite IsComposite::EntryState(Value value) { + IsComposite result; + if (auto barg = value.dyn_cast()) { + if (func::FuncOp func = + dyn_cast(barg.getOwner()->getParentOp())) { + if (func.getArgAttr(barg.getArgNumber(), kCompositeDevice)) { + result.is_on_composite_device = true; + } + return result; + } + } + return result; +} + +IsComposite IsComposite::join(const IsComposite &lhs, const IsComposite &rhs) { + IsComposite ret; + ret.is_on_composite_device = + lhs.is_on_composite_device || rhs.is_on_composite_device; + return ret; +} + +void IsComposite::print(raw_ostream &os) const { if (is_on_composite_device) { - os << " COMPOSITE"; + os << "COMPOSITE"; + } else { + os << "NOT_COMPOSITE"; } - os << "]"; } class ResourceDataflowAnalysis @@ -94,23 +120,32 @@ class ResourceDataflowAnalysis using TensorflowDataflowAnalysis< ResourceConstructingOps>::TensorflowDataflowAnalysis; void visitOperation(Operation *op, ArrayRef operands, - ArrayRef results) override; + ArrayRef results) override { + if (ForwardThroughTFOperation(op, operands, results)) return; + setAllToEntryStates(results); + } ~ResourceDataflowAnalysis() override = default; }; -void ResourceDataflowAnalysis::visitOperation(Operation *op, - ArrayRef operands, - ArrayRef results) { - LLVM_DEBUG(llvm::dbgs() << "ResAn: Visiting operation: " << *op << "\n"); - - if (ForwardThroughTFOperation(op, operands, results)) return; - - setAllToEntryStates(results); -} +class IsCompositeDataflowAnalysis + : public TensorflowDataflowAnalysis { + public: + using TensorflowDataflowAnalysis::TensorflowDataflowAnalysis; + void visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override { + if (ForwardThroughTFOperation(op, operands, results)) return; + setAllToEntryStates(results); + } + ~IsCompositeDataflowAnalysis() override = default; +}; void LoadResourceDataflowAnalysis(DataFlowSolver &solver) { solver.load(); } +void LoadIsCompositeDataflowAnalysis(DataFlowSolver &solver) { + solver.load(); +} + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h index 9015b9dc739634..0cf3611af1d20c 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h @@ -46,8 +46,7 @@ struct ResourceConstructingOps { static ResourceConstructingOps EntryState(MLIRContext *context); static ResourceConstructingOps EntryState(Value value); bool operator==(const ResourceConstructingOps &rhs) const { - return ops == rhs.ops && - is_on_composite_device == rhs.is_on_composite_device; + return ops == rhs.ops; } static ResourceConstructingOps join(const ResourceConstructingOps &lhs, @@ -57,13 +56,27 @@ struct ResourceConstructingOps { // The operation(s) which created the resource value. // IR constructs (i.e., GlobalTensorOp) are not const-correct. mutable DenseSet ops; +}; + +struct IsComposite { + explicit IsComposite(Operation *op = nullptr); + static IsComposite EntryState(MLIRContext *context); + static IsComposite EntryState(Value value); + bool operator==(const IsComposite &rhs) const { + return is_on_composite_device == rhs.is_on_composite_device; + } + + static IsComposite join(const IsComposite &lhs, const IsComposite &rhs); + void print(raw_ostream &os) const; bool is_on_composite_device = false; }; typedef dataflow::Lattice ResourceDataflowState; +typedef dataflow::Lattice IsCompositeDataflowState; void LoadResourceDataflowAnalysis(DataFlowSolver &solver); +void LoadIsCompositeDataflowAnalysis(DataFlowSolver &solver); } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc index f265ac68fa5f27..4de43317677f63 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc @@ -145,7 +145,7 @@ bool OnlyOperatesOnCompositeDevices( continue; } auto lattice = - solver.lookupState(arg.get())->getValue(); + solver.lookupState(arg.get())->getValue(); bool is_read = read_array.contains(arg.getOperandNumber()); bool is_update = update_array.contains(arg.getOperandNumber()); // We want the resource operands that are on composite devices to be the @@ -214,7 +214,7 @@ void CollectChainResources( // device-specific (see below). bool resource_is_on_composite_device = false; for (Value value : alias_analysis.GetValuesForResourceId(resource_id)) { - auto lattice = solver.lookupState(value); + auto lattice = solver.lookupState(value); if (lattice) { resource_is_on_composite_device |= lattice->getValue().is_on_composite_device; @@ -604,7 +604,7 @@ void ConvertControlToDataOutputsPass::runOnOperation() { DataFlowSolver solver; solver.load(); solver.load(); - TF::LoadResourceDataflowAnalysis(solver); + TF::LoadIsCompositeDataflowAnalysis(solver); if (failed(solver.initializeAndRun(module))) return signalPassFailure(); // This pass assumes that all functions are suitable for export i.e., each From 9926a407d0929e6444cb63cafe9a934eae9c17e7 Mon Sep 17 00:00:00 2001 From: "Jiyoun (Jen) Ha" Date: Thu, 28 Mar 2024 15:01:23 -0700 Subject: [PATCH 057/124] Add quantization/legalization for `stablehlo.add` and respective pipeline changes. * Added `enable_full_int_quantization` in `StaticRangePtqPreset` to determine full int quantization. This value will be `false` by default, meaning only compute-heavy ops will be quantized unless specified. * Added tests for the above config change. * Follow up tests will include e2e python tests. PiperOrigin-RevId: 620067140 --- .../uniform-quantized-stablehlo-to-tfl.mlir | 172 +++++++++++------- ...uniform_quantized_stablehlo_to_tfl_pass.cc | 164 ++++++++--------- .../mlir/quantization/stablehlo/cc/config.cc | 8 +- .../quantization/stablehlo/cc/config_test.cc | 22 ++- .../stablehlo/cc/pass_pipeline.cc | 3 + ...t_quantizable_spots_as_functions_simple.td | 8 + .../quantization/stablehlo/passes/passes.td | 8 + .../stablehlo/passes/quantization_patterns.cc | 6 + .../stablehlo/passes/quantization_patterns.h | 5 + .../quantization/stablehlo/passes/quantize.cc | 7 + .../passes/quantize_composite_functions.cc | 5 +- .../stablehlo/passes/testing/passes.h | 7 +- .../stablehlo/passes/testing/passes.td | 4 +- ...ts_as_functions_with_quantization_specs.cc | 11 ++ .../integration_test/quantize_model_test.py | 4 +- .../stablehlo/quantization_config.proto | 5 +- ..._as_functions_with_quantization_specs.mlir | 63 +++++-- .../quantize_composite_functions_all_ops.mlir | 46 +++++ tensorflow/lite/python/lite.py | 1 + 19 files changed, 363 insertions(+), 186 deletions(-) create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_all_ops.mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir index 05d00443ebcca6..5653dfeb9f2b8f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir @@ -764,7 +764,7 @@ func.func @conv_with_bias_same_padding_srq_depthwise(%arg0: tensor<1x4x5x3x!quan // ----- -// Tests that a quantized stablehlo.transpose is converted to tfl.transpose. +// Tests that a quantized `stablehlo.transpose` is converted to `tfl.transpose`. func.func @transpose( %arg0: tensor<2x3x4x!quant.uniform> @@ -781,19 +781,19 @@ func.func @transpose( // ----- -// Tests that a float stablehlo.transpose is not converted to tfl.transpose. +// Tests that a float `stablehlo.transpose` is not converted to `tfl.transpose`. -func.func @float_transpose(%arg0: tensor<2x3x4xf32>) -> tensor<4x3x2xf32> { +func.func @transpose_float(%arg0: tensor<2x3x4xf32>) -> tensor<4x3x2xf32> { %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<2x3x4xf32>) -> tensor<4x3x2xf32> return %0 : tensor<4x3x2xf32> } -// CHECK-LABEL: float_transpose +// CHECK-LABEL: transpose_float // CHECK-NOT: tfl.transpose // CHECK: stablehlo.transpose // ----- -// Tests that a quantized stablehlo.reshape is converted to tfl.reshape. +// Tests that a quantized `stablehlo.reshape` is converted to `tfl.reshape`. func.func @reshape( %arg0: tensor<2x3x4x!quant.uniform> @@ -810,19 +810,19 @@ func.func @reshape( // ----- -// Tests that a float stablehlo.reshape is not converted to tfl.reshape. +// Tests that a float `stablehlo.reshape` is not converted to `tfl.reshape`. -func.func @float_reshape(%arg0: tensor<2x3x4xf32>) -> tensor<6x4xf32> { +func.func @reshape_float(%arg0: tensor<2x3x4xf32>) -> tensor<6x4xf32> { %0 = stablehlo.reshape %arg0 : (tensor<2x3x4xf32>) -> tensor<6x4xf32> return %0 : tensor<6x4xf32> } -// CHECK-LABEL: float_reshape +// CHECK-LABEL: reshape_float // CHECK-NOT: tfl.reshape // CHECK: stablehlo.reshape // ----- -// Tests that a quantized stablehlo.select is converted to tfl.select_v2. +// Tests that a quantized `stablehlo.select` is converted to `tfl.select_v2`. func.func @select( %arg0: tensor<1x3xi1>, @@ -844,19 +844,20 @@ func.func @select( // ----- -// Tests that a float stablehlo.select is not converted to tfl.select_v2. +// Tests that a float `stablehlo.select` is not converted to `tfl.select_v2`. -func.func @float_select(%arg0: tensor<1x3xi1>, %arg1: tensor<1x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> { +func.func @select_float(%arg0: tensor<1x3xi1>, %arg1: tensor<1x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> { %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor<1x3xi1>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> return %0 : tensor<1x3xf32> } -// CHECK-LABEL: float_select +// CHECK-LABEL: select_float // CHECK-NOT: tfl.select_v2 // CHECK: stablehlo.select // ----- -// Tests that a quantized stablehlo.concatenate is converted to tfl.concatenation. +// Tests that a quantized `stablehlo.concatenate` is converted to +// `tfl.concatenation`. func.func @concatenate( %arg0: tensor<3x2x!quant.uniform>, @@ -876,20 +877,21 @@ func.func @concatenate( // ----- -// Tests that a float stablehlo.concatenate is not converted to tfl.concatenation. +// Tests that a float `stablehlo.concatenate` is not converted to +// `tfl.concatenation`. -func.func @float_concatenate(%arg0: tensor<3x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<4x2xf32> { +func.func @concatenate_float(%arg0: tensor<3x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<4x2xf32> { %0 = "stablehlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x2xf32>, tensor<1x2xf32>) -> tensor<4x2xf32> return %0 : tensor<4x2xf32> } -// CHECK-LABEL: float_concatenate +// CHECK-LABEL: concatenate_float // CHECK-NOT: tfl.concatenation // CHECK: stablehlo.concatenate // ----- -// Tests that a quantized stablehlo.pad without interior padding is converted to -// tfl.padv2. +// Tests that a quantized `stablehlo.pad` without interior padding is +// converted to `tfl.padv2`. func.func @pad_without_interior_padding( %arg0: tensor<2x3x!quant.uniform>, @@ -911,8 +913,8 @@ func.func @pad_without_interior_padding( // ----- -// Tests that a quantized stablehlo.pad with interior padding is converted to -// tfl.dilate and tfl.padv2. +// Tests that a quantized `stablehlo.pad` with interior padding is converted to +// `tfl.dilate` and `tfl.padv2`. func.func @pad_with_interior_padding( %arg0: tensor<2x3x!quant.uniform>, @@ -937,20 +939,20 @@ func.func @pad_with_interior_padding( // ----- -// Tests that a float stablehlo.pad is not converted to tfl.padv2. +// Tests that a float `stablehlo.pad` is not converted to `tfl.padv2`. -func.func @float_pad(%arg0: tensor<2x3xf32>, %arg1: tensor) -> tensor<4x5xf32> { +func.func @pad_float(%arg0: tensor<2x3xf32>, %arg1: tensor) -> tensor<4x5xf32> { %0 = stablehlo.pad %arg0, %arg1, low = [0, 1], high = [2, 1], interior = [0, 0] : (tensor<2x3xf32>, tensor) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> } -// CHECK-LABEL: float_pad +// CHECK-LABEL: pad_float // CHECK-NOT: tfl.padv2 // CHECK: stablehlo.pad // ----- -// Tests that a quantized stablehlo.slice is converted to tfl.slice when stride -// is 1. +// Tests that a quantized `stablehlo.slice` is converted to +// `tfl.slice` when stride is 1. func.func @slice( %arg0: tensor<3x4x!quant.uniform> @@ -973,8 +975,8 @@ func.func @slice( // ----- -// Tests that a quantized stablehlo.slice is converted to tfl.strided_slice when -// stride is not 1. +// Tests that a quantized `stablehlo.slice` is converted to `tfl.strided_slice` +// when stride is not 1. func.func @strided_slice( %arg0: tensor<3x6x!quant.uniform> @@ -1001,9 +1003,9 @@ func.func @strided_slice( // ----- -// Tests that a float stablehlo.slice is not converted to tfl.slice. +// Tests that a float `stablehlo.slice` is not converted to `tfl.slice`. -func.func @float_slice(%arg0: tensor<3x4xf32>) -> tensor<2x2xf32> { +func.func @slice_float(%arg0: tensor<3x4xf32>) -> tensor<2x2xf32> { %0 = "stablehlo.slice"(%arg0) { start_indices = array, limit_indices = array, @@ -1011,15 +1013,15 @@ func.func @float_slice(%arg0: tensor<3x4xf32>) -> tensor<2x2xf32> { } : (tensor<3x4xf32>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> } -// CHECK-LABEL: float_slice +// CHECK-LABEL: slice_float // CHECK-NOT: tfl.slice // CHECK-NOT: tfl.strided_slice // CHECK: stablehlo.slice // ----- -// Tests that a quantized stablehlo.broadcast_in_dim is converted to -// tfl.broadcast_to. +// Tests that a quantized `stablehlo.broadcast_in_dim` is converted to +// `tfl.broadcast_to`. func.func @broadcast_in_dim( %arg0: tensor<1x2x!quant.uniform> @@ -1038,8 +1040,8 @@ func.func @broadcast_in_dim( // ----- -// Tests that a quantized stablehlo.broadcast_in_dim is converted to -// tfl.transpose and tfl.broadcast_to when broadcast_dimensions is not in +// Tests that a quantized `stablehlo.broadcast_in_dim` is converted to +// `tfl.transpose` and `tfl.broadcast_to` when `broadcast_dimensions` is not in // ascending order. func.func @broadcast_in_dim_with_transpose( @@ -1062,8 +1064,8 @@ func.func @broadcast_in_dim_with_transpose( // ----- -// Tests that a quantized stablehlo.broadcast_in_dim is converted to -// tfl.expand_dims and tfl.broadcast_to when input rank is smaller than output +// Tests that a quantized `stablehlo.broadcast_in_dim` is converted to +// tfl.expand_dims and `tfl.broadcast_to` when input rank is smaller than output // rank. func.func @broadcast_in_dim_with_expand( @@ -1086,9 +1088,10 @@ func.func @broadcast_in_dim_with_expand( // ----- -// Tests that a quantized stablehlo.broadcast_in_dim is converted to -// tfl.transpose, tfl.expand_dims and tfl.broadcast_to when broadcast_dimensions -// is not in ascending order and input rank is smaller than output rank. +// Tests that a quantized `stablehlo.broadcast_in_dim` is converted to +// `tfl.transpose`, `tfl.expand_dims` and `tfl.broadcast_to` when +// `broadcast_dimensions` is not in ascending order and input rank is smaller +// than output rank. func.func @broadcast_in_dim_with_transpose_and_expand( %arg0: tensor<2x3x4x!quant.uniform> @@ -1112,15 +1115,16 @@ func.func @broadcast_in_dim_with_transpose_and_expand( // ----- -// Tests that a float stablehlo.broadcast_in_dim is not converted to tfl.broadcast_to. +// Tests that a float `stablehlo.broadcast_in_dim` is not converted to +// `tfl.broadcast_to`. -func.func @float_broadcast_in_dim(%arg0: tensor<1x2xf32>) -> tensor<3x2xf32> { +func.func @broadcast_in_dim_float(%arg0: tensor<1x2xf32>) -> tensor<3x2xf32> { %0 = "stablehlo.broadcast_in_dim"(%arg0) { broadcast_dimensions = array } : (tensor<1x2xf32>) -> tensor<3x2xf32> return %0 : tensor<3x2xf32> } -// CHECK-LABEL: float_broadcast_in_dim +// CHECK-LABEL: broadcast_in_dim_float // CHECK-NOT: tfl.broadcast_to // CHECK-NOT: tfl.transpose // CHECK-NOT: tfl.expand_dims @@ -1128,8 +1132,8 @@ func.func @float_broadcast_in_dim(%arg0: tensor<1x2xf32>) -> tensor<3x2xf32> { // ----- -// Test that a quantized stablehlo.reduce_window with max is converted to -// tfl.max_pool_2d. +// Tests that a quantized `stablehlo.reduce_window` with max is converted to +// `tfl.max_pool_2d`. func.func @reduce_window_with_max( %arg0: tensor<2x9x10x3x!quant.uniform>, @@ -1153,8 +1157,8 @@ func.func @reduce_window_with_max( // ----- -// Test that a quantized stablehlo.reduce_window with max whose rank is not 4 -// is not converted to tfl.max_pool_2d. +// Tests that a quantized `stablehlo.reduce_window `with max whose rank is not 4 +// is not converted to `tfl.max_pool_2d`. func.func @reduce_window_not_4d( %arg0: tensor<3x2x9x10x3x!quant.uniform>, @@ -1174,8 +1178,8 @@ func.func @reduce_window_not_4d( // ----- -// Test that a quantized stablehlo.reduce_window with max that takes multiple -// inputs is not converted to tfl.max_pool_2d. +// Tests that a quantized `stablehlo.reduce_window` with max that takes multiple +// inputs is not converted to `tfl.max_pool_2d`. func.func @reduce_window_not_binary( %arg0: tensor<3x2x9x10x3x!quant.uniform>, @@ -1198,10 +1202,10 @@ func.func @reduce_window_not_binary( // ----- -// Test that a float stablehlo.reduce_window with max is not converted to -// tfl.max_pool_2d. +// Tests that a float `stablehlo.reduce_window` with max is not converted to +// `tfl.max_pool_2d`. -func.func @float_reduce_window_with_max( +func.func @reduce_window_with_max_float( %arg0: tensor<2x9x10x3xf32>, %arg1: tensor ) -> tensor<2x4x3x3xf32> { @@ -1213,13 +1217,14 @@ func.func @float_reduce_window_with_max( return %0 : tensor<2x4x3x3xf32> } -// CHECK-LABEL: float_reduce_window_with_max +// CHECK-LABEL: reduce_window_with_max_float // CHECK: stablehlo.reduce_window // CHECK-NOT: tfl.max_pool_2d // ----- -// Test that a quantized stablehlo.dynamic_reshape is converted to tfl.reshape. +// Tests that a quantized `stablehlo.dynamic_reshape` is converted to +// `tfl.reshape`. func.func @dynamic_reshape( %arg0: tensor>, @@ -1240,20 +1245,21 @@ func.func @dynamic_reshape( // ----- -// Test that a float stablehlo.dynamic_reshape is not converted to tfl.reshape. +// Tests that a float `stablehlo.dynamic_reshape` is not converted to +// `tfl.reshape`. -func.func @float_dynamic_reshape(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor { +func.func @dynamic_reshape_float(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor { %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor, tensor<2xi32>) -> tensor return %0 : tensor } -// CHECK-LABEL: func @float_dynamic_reshape +// CHECK-LABEL: func @dynamic_reshape_float // CHECK: stablehlo.dynamic_reshape // CHECK-NOT: tfl.reshape // ----- -// Test that a quantized stablehlo.gather is converted to tfl.gather_nd. +// Tests that a quantized `stablehlo.gather` is converted to tfl.gather_nd. func.func @gather( %arg0: tensor<3x4x2x2x!quant.uniform>, @@ -1282,8 +1288,8 @@ func.func @gather( // ----- -// Test that a quantized stablehlo.gather with unsorted start_index_map is not -// converted to tfl.gather_nd (condition 1 is not satisfied). +// Tests that a quantized `stablehlo.gather` with unsorted start_index_map is +// not converted to `tfl.gather_nd` (condition 1 is not satisfied). func.func @gather_start_index_map_not_sorted( %arg0: tensor<3x4x2x2x!quant.uniform>, @@ -1311,7 +1317,7 @@ func.func @gather_start_index_map_not_sorted( // ----- -// Test that a quantized stablehlo.gather is not converted to tfl.gather_nd +// Tests that a quantized `stablehlo.gather` is not converted to tfl.gather_nd // when index_vector_dim is not the last dimension of start_indices (condition 2 // is not satisfied). @@ -1341,7 +1347,7 @@ func.func @gather_start_index_vector_dim_not_at_last( // ----- -// Test that a quantized stablehlo.gather is not converted to tfl.gather_nd +// Tests that a quantized `stablehlo.gather` is not converted to tfl.gather_nd // when offset_dims are not the last dimensions of the output (condition 3 is // not satisfied). @@ -1371,7 +1377,7 @@ func.func @gather_offset_dims_not_at_last( // ----- -// Test that a quantized stablehlo.gather is not converted to tfl.gather_nd +// Tests that a quantized `stablehlo.gather` is not converted to tfl.gather_nd // when shape of slice is not same with shape of offset (condition 4 is not // satisfied). @@ -1401,9 +1407,9 @@ func.func @gather_different_slice_and_offset( // ----- -// Test that a float stablehlo.gather is not converted to tfl.gather_nd. +// Tests that a float `stablehlo.gather` is not converted to `tfl.gather_nd`. -func.func @float_gather(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> { +func.func @gather_float(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> { %0 = "stablehlo.gather"(%arg0, %arg1) { dimension_numbers = #stablehlo.gather< offset_dims = [2, 3], @@ -1416,14 +1422,14 @@ func.func @float_gather(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<2x3x2xi64>) -> return %0 : tensor<2x3x2x2xf32> } -// CHECK-LABEL: func @float_gather +// CHECK-LABEL: func @gather_float // CHECK: stablehlo.gather // CHECK-NOT: tfl.gather_nd // CHECK-NOT: tfl.gather // ----- -// Test that a quantized stablehlo.dynamic_slice is converted to tfl.slice. +// Tests that a quantized `stablehlo.dynamic_slice` is converted to `tfl.slice`. // CHECK-LABEL: func @dynamic_slice // CHECK-SAME: %[[ARG0:.+]]: tensor<4x4x!quant.uniform>, %[[ARG1:.+]]: tensor, %[[ARG2:.+]]: tensor @@ -1457,18 +1463,46 @@ func.func @dynamic_slice( // ----- -// Test that a float stablehlo.dynamic_slice is not converted to tfl.slice. +// Tests that a float `stablehlo.dynamic_slice` is not converted to `tfl.slice`. -func.func @float_dynamic_slice(%arg0: tensor<4x4xf32>, %arg1: tensor, %arg2: tensor) -> tensor<2x1xf32> { +func.func @dynamic_slice_float(%arg0: tensor<4x4xf32>, %arg1: tensor, %arg2: tensor) -> tensor<2x1xf32> { %0 = "stablehlo.dynamic_slice"(%arg0, %arg1, %arg2) { slice_sizes = array } : (tensor<4x4xf32>, tensor, tensor) -> tensor<2x1xf32> return %0 : tensor<2x1xf32> } -// CHECK-LABEL: func @float_dynamic_slice +// CHECK-LABEL: func @dynamic_slice_float // CHECK: stablehlo.dynamic_slice // CHECK-NOT: tfl.bitcast // CHECK-NOT: tfl.minimum // CHECK-NOT: tfl.maximum // CHECK-NOT: tfl.slice + +// ----- + +// Tests that `stablehlo.add` with both operands int8 UniformQuantizedType is +// properly converted into `tfl.add`. + +func.func @add(%arg0: tensor<1x3x!quant.uniform>, %arg1: tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> { + %0 = stablehlo.add %arg0, %arg1 : (tensor<1x3x!quant.uniform>, tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> +} + +// CHECK-LABEL: func @add +// CHECK: %[[ADD:.+]] = tfl.add(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x3x!quant.uniform>, tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: return %[[ADD]] + +// ----- + +// Tests that `stablehlo.add` with int32 UniformQuantizedPerAxisTypes is +// not converted. + +func.func @add_i32(%arg0: tensor<1x3x!quant.uniform>, %arg1: tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> { + %0 = stablehlo.add %arg0, %arg1 : (tensor<1x3x!quant.uniform>, tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> +} + +// CHECK-LABEL: func @add_i32 +// CHECK: stablehlo.add +// CHECK-NOT: tfl.add diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc index c124f33d6f55eb..f9417d6da30274 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc @@ -62,6 +62,7 @@ using ::mlir::quant::CreateI32F32UniformQuantizedType; using ::mlir::quant::CreateI8F32UniformQuantizedPerAxisType; using ::mlir::quant::CreateI8F32UniformQuantizedType; using ::mlir::quant::FindUserOfType; +using ::mlir::quant::GetElementType; using ::mlir::quant::IsI32F32UniformQuantizedPerAxisType; using ::mlir::quant::IsI32F32UniformQuantizedType; using ::mlir::quant::IsI8F32UniformQuantizedPerAxisType; @@ -142,10 +143,7 @@ TFL::QConstOp CreateTransposedTflConstOpForFilter( Type new_filter_quantized_type; if (is_per_channel) { - auto filter_quantized_type = filter_constant_op.getResult() - .getType() - .cast() - .getElementType() + auto filter_quantized_type = GetElementType(filter_constant_op.getResult()) .cast(); new_filter_quantized_type = CreateI8F32UniformQuantizedPerAxisType( filter_constant_op->getLoc(), *rewriter.getContext(), @@ -153,10 +151,7 @@ TFL::QConstOp CreateTransposedTflConstOpForFilter( filter_quantized_type.getZeroPoints(), /*quantization_dimension=*/0, /*narrow_range=*/true); } else { - auto filter_quantized_type = filter_constant_op.getResult() - .getType() - .cast() - .getElementType() + auto filter_quantized_type = GetElementType(filter_constant_op.getResult()) .cast(); new_filter_quantized_type = CreateI8F32UniformQuantizedType( filter_constant_op->getLoc(), *rewriter.getContext(), @@ -224,9 +219,7 @@ TFL::QConstOp CreateTflConstOpForDummyBias( Type bias_quantized_type; if (is_per_channel) { const auto filter_quantized_element_type = - filter_const_op.getResult() - .getType() - .getElementType() + GetElementType(filter_const_op.getResult()) .cast(); // The storage type is i32 for bias, which is the precision used for @@ -238,9 +231,7 @@ TFL::QConstOp CreateTflConstOpForDummyBias( /*quantization_dimension=*/0); } else { const auto filter_quantized_element_type = - filter_const_op.getResult() - .getType() - .getElementType() + GetElementType(filter_const_op.getResult()) .cast(); // The storage type is i32 for bias, which is the precision used for @@ -277,8 +268,9 @@ arith::ConstantOp CreateI32ShapeConstantOp(const TensorType op_type, } // Returns the desired qi8 per-tensor quantized output type for a given gemm op. -Type GetOutputType(Operation* op, MLIRContext& ctx, const bool has_i32_output, - const bool fuse_bias_constant) { +Type GetQuantizedOutputType(Operation* op, PatternRewriter& rewriter, + const bool has_i32_output, + const bool fuse_bias_constant) { Operation* uniform_quantize_op; if (!has_i32_output) return op->getResult(0).getType(); if (fuse_bias_constant) { @@ -289,17 +281,15 @@ Type GetOutputType(Operation* op, MLIRContext& ctx, const bool has_i32_output, } // StableHLO Quantizer outputs an i32 type. Rewrite to i8 type result // to meet TFLite op requirement. - auto result_quantized_type = uniform_quantize_op->getResult(0) - .getType() - .cast() - .getElementType() + auto result_quantized_type = GetElementType(uniform_quantize_op->getResult(0)) .cast(); auto new_result_quantized_type = CreateI8F32UniformQuantizedType( - uniform_quantize_op->getLoc(), ctx, result_quantized_type.getScale(), - result_quantized_type.getZeroPoint()); + uniform_quantize_op->getLoc(), *rewriter.getContext(), + result_quantized_type.getScale(), result_quantized_type.getZeroPoint()); // Omit any bias and requantize ops as `tfl.{gemm_op}` outputs a // fused `qi8` type. - FindUserOfType<>(uniform_quantize_op)->setOperand(0, op->getResult(0)); + rewriter.replaceAllUsesWith(uniform_quantize_op->getResult(0), + op->getResult(0)); return op->getResult(0).getType().cast().clone( new_result_quantized_type); } @@ -315,8 +305,7 @@ class RewriteUniformQuantizeOp // detailed limitations // (https://github.com/tensorflow/tensorflow/blob/8f145d579aa0ee7f4187af32dbbf4e12fdabbffe/tensorflow/lite/kernels/quantize.cc#L105). LogicalResult match(stablehlo::UniformQuantizeOp op) const override { - const Type input_element_type = - op.getOperand().getType().cast().getElementType(); + const Type input_element_type = GetElementType(op.getOperand()); if (!(input_element_type.isa() || IsI32F32UniformQuantizedType(input_element_type) || IsI32F32UniformQuantizedPerAxisType(input_element_type))) { @@ -328,10 +317,7 @@ class RewriteUniformQuantizeOp // Output type of `UniformQuantizeOp` is guaranteed to be a quantized // tensor with integer storage type. - const auto output_storage_type = op.getResult() - .getType() - .cast() - .getElementType() + const auto output_storage_type = GetElementType(op.getResult()) .cast() .getStorageType() .cast(); @@ -363,10 +349,7 @@ class RewriteUniformDequantizeOp // detailed limitations // (https://github.com/tensorflow/tensorflow/blob/8f145d579aa0ee7f4187af32dbbf4e12fdabbffe/tensorflow/lite/kernels/dequantize.cc#L52). LogicalResult match(stablehlo::UniformDequantizeOp op) const override { - const auto input_storage_type = op.getOperand() - .getType() - .cast() - .getElementType() + const auto input_storage_type = GetElementType(op.getOperand()) .cast() .getStorageType() .cast(); @@ -377,11 +360,8 @@ class RewriteUniformDequantizeOp } // Output type is guaranteed to be a float tensor for a valid StableHLO. - const auto output_element_type = op.getResult() - .getType() - .cast() - .getElementType() - .cast(); + const auto output_element_type = + GetElementType(op.getResult()).cast(); if (!output_element_type.isa()) { LLVM_DEBUG(llvm::dbgs() << "Uniform dequantize op's output element type " "should be f32. Got: " @@ -448,8 +428,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp op.getDotDimensionNumbers(); const bool is_batch_matmul = !dot_dimension_nums.getLhsBatchingDimensions().empty(); - const Type elem_type = - op.getResult().getType().cast().getElementType(); + const Type elem_type = GetElementType(op.getResult()); const bool has_i32_output = IsI32F32UniformQuantizedType(elem_type) || IsI32F32UniformQuantizedPerAxisType(elem_type); @@ -479,8 +458,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp void rewrite(stablehlo::DotGeneralOp op, PatternRewriter& rewriter) const override { - const Type output_type = - op.getResult().getType().cast().getElementType(); + const Type output_type = GetElementType(op.getResult()); const bool has_i32_output = IsI32F32UniformQuantizedType(output_type) || IsI32F32UniformQuantizedPerAxisType(output_type); @@ -656,8 +634,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp static LogicalResult MatchOutput(const Value output, const bool has_i32_output, const bool is_batch_matmul) { - const Type output_element_type = - output.getType().cast().getElementType(); + const Type output_element_type = GetElementType(output); if (has_i32_output) { if (is_batch_matmul && !IsI32F32UniformQuantizedType(output_element_type)) { @@ -760,11 +737,8 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp TFL::QConstOp filter_constant_op = CreateTflConstOpForFilter( rhs_value.getDefiningOp(), rewriter, /*is_per_channel=*/true); - const double input_scale = lhs_value.getType() - .cast() - .getElementType() - .cast() - .getScale(); + const double input_scale = + GetElementType(lhs_value).cast().getScale(); TFL::QConstOp bias_tfl_op; bool fuse_bias_constant = FindUserOfType(op) && has_i32_output; @@ -800,16 +774,10 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp Operation* add_op = FindUserOfType(op); uniform_quantize_op = FindUserOfType(add_op); const auto filter_quantized_type = - op->getOperand(1) - .getType() - .cast() - .getElementType() + GetElementType(op->getOperand(1)) .cast(); const SmallVector bias_scales = GetBiasScales( - /*input_scale=*/op->getOperand(0) - .getType() - .cast() - .getElementType() + /*input_scale=*/GetElementType(op->getOperand(0)) .cast() .getScale(), /*filter_scales=*/filter_quantized_type.getScales()); @@ -821,10 +789,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp const auto bias_quantized_type = CreateI32F32UniformQuantizedPerAxisType( op->getLoc(), *op->getContext(), std::move(bias_scales), - op->getResult(0) - .getType() - .cast() - .getElementType() + GetElementType(op->getResult(0)) .cast() .getZeroPoints(), /*quantization_dimension=*/0); @@ -841,11 +806,9 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp uniform_quantize_op = FindUserOfType(op); } - const auto result_quantized_type = uniform_quantize_op->getResult(0) - .getType() - .cast() - .getElementType() - .cast(); + const auto result_quantized_type = + GetElementType(uniform_quantize_op->getResult(0)) + .cast(); const auto new_result_quantized_type = CreateI8F32UniformQuantizedType( uniform_quantize_op->getLoc(), *rewriter.getContext(), result_quantized_type.getScale(), @@ -856,8 +819,8 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp // fused `qi8` type. FindUserOfType<>(uniform_quantize_op)->setOperand(0, op->getResult(0)); } else { - output_type = GetOutputType(op, *rewriter.getContext(), has_i32_output, - fuse_bias_constant); + output_type = GetQuantizedOutputType(op, rewriter, has_i32_output, + fuse_bias_constant); } return output_type; } @@ -898,8 +861,8 @@ class RewriteQuantizedConvolutionOp public: using OpRewritePattern::OpRewritePattern; LogicalResult match(stablehlo::ConvolutionOp op) const override { - const bool has_i32_output = IsI32F32UniformQuantizedPerAxisType( - op.getResult().getType().cast().getElementType()); + const bool has_i32_output = + IsI32F32UniformQuantizedPerAxisType(GetElementType(op.getResult())); const bool fuse_bias_constant = FindUserOfType(op) && has_i32_output; stablehlo::ConvDimensionNumbersAttr dimension_numbers = @@ -965,8 +928,8 @@ class RewriteQuantizedConvolutionOp void rewrite(stablehlo::ConvolutionOp op, PatternRewriter& rewriter) const override { - const bool has_i32_output = IsI32F32UniformQuantizedPerAxisType( - op.getResult().getType().cast().getElementType()); + const bool has_i32_output = + IsI32F32UniformQuantizedPerAxisType(GetElementType(op.getResult())); stablehlo::ConvDimensionNumbersAttr dimension_numbers = op.getDimensionNumbers(); @@ -993,8 +956,8 @@ class RewriteQuantizedConvolutionOp input_value = pad_op.getResult(); } - const Type output_type = GetOutputType(op, *rewriter.getContext(), - has_i32_output, fuse_bias_constant); + const Type output_type = GetQuantizedOutputType( + op, rewriter, has_i32_output, fuse_bias_constant); const auto [stride_h, stride_w] = GetStrides(op); const auto [dilation_h_factor, dilation_w_factor] = GetDilationFactors(op); if (is_depthwise) { @@ -1110,8 +1073,7 @@ class RewriteQuantizedConvolutionOp } static LogicalResult MatchOutput(Value output) { - const Type output_element_type = - output.getType().cast().getElementType(); + const Type output_element_type = GetElementType(output); if (!IsI32F32UniformQuantizedPerAxisType(output_element_type) && !IsI8F32UniformQuantizedType(output_element_type)) { LLVM_DEBUG( @@ -1397,10 +1359,7 @@ class RewriteQuantizedConvolutionOp Value filter_value = op.getOperand(1); Operation* filter_op = filter_value.getDefiningOp(); auto filter_uniform_quantized_type = - filter_value.getType() - .cast() - .getElementType() - .cast(); + GetElementType(filter_value).cast(); auto filter_constant_value_attr = cast( cast(filter_value.getDefiningOp()).getValue()); const DenseIntElementsAttr new_filter_value_attr = @@ -1440,10 +1399,7 @@ class RewriteQuantizedConvolutionOp const SmallVector bias_shape, const bool has_i32_output, const bool fuse_bias_constant) const { const SmallVector bias_scales = GetBiasScales( - /*input_scale=*/op.getOperand(0) - .getType() - .cast() - .getElementType() + /*input_scale=*/GetElementType(op.getOperand(0)) .cast() .getScale(), /*filter_scales=*/new_filter_quantized_type.getScales()); @@ -2108,6 +2064,44 @@ class RewriteQuantizedDynamicSliceOp } }; +class RewriteQuantizedAddOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(stablehlo::AddOp op) const override { + return success(IsI8F32UniformQuantizedType(GetElementType(op.getLhs())) && + IsI8F32UniformQuantizedType(GetElementType(op.getRhs()))); + } + + void rewrite(stablehlo::AddOp op, PatternRewriter& rewriter) const override { + TFL::QConstOp lhs_qconst_op; + TFL::QConstOp rhs_qconst_op; + + auto GetBroadcastedConstOp = [&](Value operand) -> TFL::QConstOp { + if (auto broadcast_op = dyn_cast_or_null( + operand.getDefiningOp())) { + auto stablehlo_const_op = dyn_cast_or_null( + broadcast_op.getOperand().getDefiningOp()); + auto const_uniform_quantized_type = + stablehlo_const_op.getResult().getType().cast(); + return rewriter.create( + op.getLoc(), TypeAttr::get(const_uniform_quantized_type), + cast(stablehlo_const_op.getValue())); + } + return nullptr; + }; + + lhs_qconst_op = GetBroadcastedConstOp(op.getLhs()); + rhs_qconst_op = GetBroadcastedConstOp(op.getRhs()); + + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), + lhs_qconst_op ? lhs_qconst_op : op.getOperand(0), + rhs_qconst_op ? rhs_qconst_op : op.getOperand(1), + /*fused_activation_function=*/rewriter.getStringAttr("NONE")); + } +}; + void UniformQuantizedStableHloToTflPass::runOnOperation() { func::FuncOp func_op = getOperation(); MLIRContext& ctx = getContext(); @@ -2121,7 +2115,7 @@ void UniformQuantizedStableHloToTflPass::runOnOperation() { RewriteQuantizedGatherOp, RewriteQuantizedPadOp, RewriteQuantizedReduceWindowOpWithMax, RewriteQuantizedReshapeOp, RewriteQuantizedSelectOp, RewriteQuantizedSliceOp, - RewriteQuantizedTransposeOp>(&ctx); + RewriteQuantizedTransposeOp, RewriteQuantizedAddOp>(&ctx); if (failed(applyPatternsAndFoldGreedily(func_op, std::move(patterns)))) { func_op.emitError() << "Failed to convert stablehlo ops with uniform " diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc index ccf2ddf768b88b..0f9932d053cb4d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc @@ -98,10 +98,11 @@ void PopulateDefaultCalibrationOptions(QuantizationConfig& quant_config) { // {matcher {function_name {regex: ".*"}} // {method {static_range_ptq {}}} // } -QuantizationSpec GetDefaultStaticRangePtqSpec() { +QuantizationSpec GetDefaultStaticRangePtqSpec(StaticRangePtqPreset preset) { QuantizationSpec spec{}; // Default for all ops. - spec.mutable_matcher()->mutable_function_name()->set_regex(".*"); + spec.mutable_matcher()->mutable_function_name()->set_regex( + preset.enable_full_int_quantization() ? ".*" : "^.*(conv|dot|gather).*"); spec.mutable_method()->mutable_static_range_ptq(); return spec; @@ -161,7 +162,8 @@ void ExpandStaticRangePtqPreset(const StaticRangePtqPreset& preset, // expansion from `StaticRangePtqPreset` gets populated first and then // user-provided explicit `QuantizationSpec`s will be appended. QuantizationSpecs new_specs{}; - *new_specs.add_specs() = GetDefaultStaticRangePtqSpec(); + *new_specs.add_specs() = + GetDefaultStaticRangePtqSpec(/*preset=*/config.static_range_ptq_preset()); *new_specs.add_specs() = GetStaticRangePtqSpecForConvolution(); // Append user-provided specs to override existing specs. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc index 70d23808f6df97..e3f2bfde3d10c3 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc @@ -147,10 +147,12 @@ TEST(ExpandPresetsTest, ExpandUnspecifiedPreset) { EXPECT_FALSE(new_config.has_pipeline_config()); } -TEST(ExpandPresetsTest, ExpandStaticRangePtqPreset) { +TEST(ExpandPresetsTest, ExpandStaticRangePtqEnableFullIntquantization) { QuantizationConfig config{}; RepresentativeDatasetConfig& preset_dataset_config = *config.mutable_static_range_ptq_preset()->add_representative_datasets(); + config.mutable_static_range_ptq_preset()->set_enable_full_int_quantization( + true); preset_dataset_config.mutable_tf_record()->set_path("/test/path"); const QuantizationConfig new_config = ExpandPresets(config); @@ -185,6 +187,21 @@ TEST(ExpandPresetsTest, ExpandStaticRangePtqPreset) { StrEq("/test/path")); } +TEST(ExpandPresetsTest, ExpandStaticRangePtqPresetDefault) { + QuantizationConfig config{}; + RepresentativeDatasetConfig& preset_dataset_config = + *config.mutable_static_range_ptq_preset()->add_representative_datasets(); + preset_dataset_config.mutable_tf_record()->set_path("/test/path"); + + const QuantizationConfig new_config = ExpandPresets(config); + ASSERT_THAT(new_config.specs().specs(), SizeIs(2)); + + const QuantizationSpec& spec = new_config.specs().specs(0); + EXPECT_THAT(spec.matcher().function_name().regex(), + StrEq("^.*(conv|dot|gather).*")); + EXPECT_TRUE(spec.method().has_static_range_ptq()); +} + TEST(ExpandPresetsTest, ExpandStaticRangePtqPresetWithTopLevelRepresentativeDataset) { // Test the scenario where both @@ -216,7 +233,8 @@ TEST(ExpandPresetsTest, TEST(ExpandPresetsTest, ExpandStaticRangePtqPresetThenAppendExplicitSpecs) { QuantizationConfig config{}; - config.mutable_static_range_ptq_preset(); + config.mutable_static_range_ptq_preset()->set_enable_full_int_quantization( + true); QuantizationSpec& user_provided_spec = *config.mutable_specs()->add_specs(); user_provided_spec.mutable_matcher()->mutable_function_name()->set_regex( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc index c871ab3ac1adc2..ebe950c58142f6 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc @@ -58,8 +58,11 @@ void AddPostCalibrationPasses( OpPassManager& pm, const PipelineConfig& pipeline_config, const StaticRangePtqPreset& static_range_ptq_preset) { QuantizeCompositeFunctionsPassOptions options; + // TODO: b/331120943 - Use QuantizationConfig instead of preset flags. options.enable_per_channel_quantized_weight_ = static_range_ptq_preset.enable_per_channel_quantized_weight(); + options.enable_full_int_quantization_ = + static_range_ptq_preset.enable_full_int_quantization(); // For debugging purposes. options.mlir_dump_file_name_ = "quantize_composite_functions"; options.enable_weight_only_ = false; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td index 07598356cce7d3..eaa8a9092f41f2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td @@ -67,3 +67,11 @@ def LiftGather : Pat< (NamedAttr<"slice_sizes"> $slice_sizes), (NamedAttr<"indices_are_sorted"> (DefaultOrNullAttr $indices_are_sorted)))), [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $operand)], [], (addBenefit 1)>; + +def LiftAdd : Pat< + (StableHLO_AddOp:$res + $lhs, $rhs), + (LiftAsTFXlaCallModule<"composite_add_fn"> + (ArgumentList $lhs, $rhs), + (ResultList $res)), + [(IsNotInLiftedFunc $res), (IsNotInStableHloOpRegion $res)], [], (addBenefit 1)>; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td index e69f4d02b4ba84..80847e8283652e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -60,6 +60,10 @@ def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-function "enable-per-channel-quantized-weight", "bool", /*default=*/"true", "Whether to enable per-channel quantized weights.">, + Option<"enable_full_int_quantization_", + "enable-full-int-quantization", + "bool", /*default=*/"false", + "Whether to enable full int quantization, including non compute-heavy ops.">, Option<"mlir_dump_file_name_", "mlir-dump-file-name", "std::optional", /*default=*/"std::nullopt", "MLIR dump file name.">, @@ -102,6 +106,10 @@ def QuantizePass : Pass<"stablehlo-quantize", "mlir::ModuleOp"> { "enable-per-channel-quantized-weight", "bool", /*default=*/"true", "Whether to enable per-channel quantized weights.">, + Option<"enable_full_int_quantization_", + "enable-full-int-quantization", + "bool", /*default=*/"false", + "Whether to apply full int quantization, including non compute-heavy ops.">, Option<"enable_weight_only_", "enable-weight-only", "bool", /*default=*/"false", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc index 3b53bef99ba179..f2b78caeb3cb44 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc @@ -954,6 +954,12 @@ void PopulateComputeHeavyPatterns( patterns.add(ctx); } +void PopulateAllQuantizablePatterns(MLIRContext& ctx, + RewritePatternSet& patterns) { + patterns.add>>( + ctx, /*enable_per_channel_quantized_weight=*/false); +} + void PopulateQuantizeWeightOnlyPatterns(MLIRContext& ctx, RewritePatternSet& patterns) { patterns.add< diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h index 7e30fb54966077..9aa33ee0316ee1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h @@ -254,6 +254,11 @@ class StableHloQuantizationPattern : public OpRewritePattern { void PopulateComputeHeavyPatterns(MLIRContext& ctx, RewritePatternSet& patterns, bool enable_per_channel_quantized_weight); +// Populates conversion patterns for all quantizable ops, including +// ops that are not compute-heavy and data movement ops. +void PopulateAllQuantizablePatterns(MLIRContext& ctx, + RewritePatternSet& patterns); + // Populates pattern weight-only quantization. void PopulateQuantizeWeightOnlyPatterns(MLIRContext& ctx, RewritePatternSet& patterns); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc index a0749a4f3d3caa..4d6f4b3fe86832 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc @@ -96,9 +96,11 @@ class QuantizePass : public impl::QuantizePassBase { using impl::QuantizePassBase::QuantizePassBase; explicit QuantizePass(const bool enable_per_channel_quantized_weight, + const bool enable_full_int_quantization, const bool enable_weight_only, const QuantizationSpecs& quant_specs) { enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; + enable_full_int_quantization_ = enable_full_int_quantization; enable_weight_only_ = enable_weight_only; } @@ -120,6 +122,11 @@ void QuantizePass::runOnOperation() { PopulateComputeHeavyPatterns(ctx, patterns, enable_per_channel_quantized_weight_); + // Quantize all quantizable ops, including ops that are not compute-heavy. + if (enable_full_int_quantization_) { + PopulateAllQuantizablePatterns(ctx, patterns); + } + if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) { // There are cases where no rewrites happen even if a pattern matches, // causing this to result in a convergence failure. Consider this as a diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc index 3583ff4cb4c08d..f3cf92dde359d1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc @@ -55,8 +55,9 @@ class QuantizeCompositeFunctionsPass explicit QuantizeCompositeFunctionsPass( const bool enable_per_channel_quantized_weight, - const bool enable_weight_only) { + const bool enable_weight_only, const bool enable_full_int_quantization) { enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; + enable_full_int_quantization_ = enable_full_int_quantization; enable_weight_only_ = enable_weight_only; } @@ -89,6 +90,8 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { QuantizePassOptions quantize_options; quantize_options.enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight_; + quantize_options.enable_full_int_quantization_ = + enable_full_int_quantization_; quantize_options.enable_weight_only_ = enable_weight_only_; // QuantizePass modifies FuncOps referenced outside of its given scope // and therefore requires a module-level context. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h index 7ba129d1c7a40d..a8a59d1cd3b46b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h @@ -23,9 +23,10 @@ namespace mlir::quant::stablehlo::testing { // `TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass`. The pass // option argument is specified in line comments for each enum value. enum class TestQuantizationSpecs { - kEmpty, // empty - kDisableAllDotGeneral, // disable-all-dot-general - kStaticRangePtqToAll, // static-range-ptq-to-all + kEmpty, // empty + kDisableAllDotGeneral, // disable-all-dot-general + kStaticRangePtqToAll, // static-range-ptq-to-all + kStaticRangePtqToComputeHeavy, // static-range-ptq-to-compute-heavy }; // Adds generated pass default constructors or options definitions. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td index c2be397d764d58..ee525f2deead04 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td @@ -80,7 +80,9 @@ def TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass : clEnumValN(mlir::quant::stablehlo::testing::TestQuantizationSpecs::kDisableAllDotGeneral, "disable-all-dot-general", "Disables all dot_general ops by matching lifted function names"), clEnumValN(mlir::quant::stablehlo::testing::TestQuantizationSpecs::kStaticRangePtqToAll, - "static-range-ptq-to-all", "Applies `StaticRangePtq` to all quantizable units.") + "static-range-ptq-to-all", "Applies `StaticRangePtq` to all quantizable units."), + clEnumValN(mlir::quant::stablehlo::testing::TestQuantizationSpecs::kStaticRangePtqToComputeHeavy, + "static-range-ptq-to-compute-heavy", "Applies `StaticRangePtq` to only compute heavy units.") )}]> ]; let dependentDialects = [ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc index 25920c986e4d1d..062fbdddd4150d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_lift_quantizable_spots_as_functions_with_quantization_specs.cc @@ -62,6 +62,15 @@ constexpr absl::string_view kSpecsStaticRangePtqToAll = method { static_range_ptq {} } }])pb"; +// Configure `QuantizationSpecs` to apply `StaticRangePtq` to compute heavy +// units. +constexpr absl::string_view kSpecsStaticRangePtqToComputeHeavy = + R"pb(specs + [ { + matcher { function_name { regex: "^.*(conv|dot|gather).*" } } + method { static_range_ptq {} } + }])pb"; + class TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass : public impl:: TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase< @@ -88,6 +97,8 @@ absl::string_view GetQuantizationSpecsTextProto( return kSpecsDisableAllDotGeneral; case TestQuantizationSpecs::kStaticRangePtqToAll: return kSpecsStaticRangePtqToAll; + case TestQuantizationSpecs::kStaticRangePtqToComputeHeavy: + return kSpecsStaticRangePtqToComputeHeavy; } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py index bc25b9a858440f..a76ca4e75ac764 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py @@ -442,7 +442,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: testing.get_size_ratio( self._output_saved_model_path, self._input_saved_model_path ), - 0.6, + 0.61, ) @parameterized.parameters( @@ -931,7 +931,7 @@ def data_gen() -> repr_dataset.RepresentativeDataset: testing.get_size_ratio( self._output_saved_model_path, self._input_saved_model_path ), - 0.4, + 0.46, ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto index 93d98d9067ef9c..efdceebd6c2008 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto +++ b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto @@ -53,7 +53,7 @@ message RepresentativeDatasetConfig { // channel dimension, which assumes the weight tensor is in NHWC format. // * Applies static-range PTQ for all other ops. // -// Next ID: 3 +// Next ID: 4 message StaticRangePtqPreset { // Configures representative dataset. Each item corresponds to a // representative dataset used to calibrate a function. @@ -72,6 +72,9 @@ message StaticRangePtqPreset { // // Default value: true bool enable_per_channel_quantized_weight = 2 [deprecated = true]; + + // Whether to quantize all quantizable ops or only compute-heavy ops. + bool enable_full_int_quantization = 3; } // Applies int8 per-tensor weight-only quantization for all dot_general op. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir index c8bffa8be6b6b4..69bf09104c814d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/lift_quantizable_spots_as_functions_with_quantization_specs.mlir @@ -1,9 +1,5 @@ // RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs="quantization-specs=disable-all-dot-general" \ // RUN: -split-input-file | FileCheck %s --check-prefix=DISABLE-ALL-DOT-GENERAL -// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs="quantization-specs=empty" \ -// RUN: -split-input-file | FileCheck %s --check-prefix=EMPTY -// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs="quantization-specs=static-range-ptq-to-all" \ -// RUN: -split-input-file | FileCheck %s --check-prefix=STATIC-RANGE-PTQ-TO-ALL // Tests that `composite_dot_general_fn_1` and its corresponding XlaCallModuleOp // contains attributes required for quantization, including the @@ -16,8 +12,8 @@ func.func @main(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { return %1 : tensor<1x1x64xf32> } -// DISABLE-ALL-DOT-GENERAL: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> -// DISABLE-ALL-DOT-GENERAL: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// DISABLE-ALL-DOT-GENERAL: %[[CONST:.+]] = stablehlo.constant dense<2.000000e+00> +// DISABLE-ALL-DOT-GENERAL: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) // Check that the `_quantization_method` attribute contains the quantization // method in textproto format. The dot_general op quantization is explicitly @@ -27,17 +23,20 @@ func.func @main(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { // DISABLE-ALL-DOT-GENERAL-SAME: _quantization_method = "no_quantization { }" // DISABLE-ALL-DOT-GENERAL-SAME: _tfl_quant_trait = "fully_quantizable" -// DISABLE-ALL-DOT-GENERAL: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// DISABLE-ALL-DOT-GENERAL: return %[[XLA_CALL_MODULE:.+]] : tensor<1x1x64xf32> // DISABLE-ALL-DOT-GENERAL: } // DISABLE-ALL-DOT-GENERAL-LABEL: private @composite_dot_general_fn_1 // DISABLE-ALL-DOT-GENERAL-SAME: tf_quant.composite_function -// DISABLE-ALL-DOT-GENERAL: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 -// DISABLE-ALL-DOT-GENERAL: return %[[DOT_GENERAL:.*]] : tensor<1x1x64xf32> +// DISABLE-ALL-DOT-GENERAL: %[[DOT_GENERAL:.+]] = stablehlo.dot_general %arg0, %arg1 +// DISABLE-ALL-DOT-GENERAL: return %[[DOT_GENERAL:.+]] : tensor<1x1x64xf32> // DISABLE-ALL-DOT-GENERAL: } // ----- +// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs="quantization-specs=empty" \ +// RUN: -split-input-file | FileCheck %s --check-prefix=EMPTY + // Tests that `composite_dot_general_fn_1` and its corresponding XlaCallModuleOp // contains attributes required for quantization. `_quantization_method` is not // set, as it is implicitly disabled. @@ -49,8 +48,8 @@ func.func @main(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { return %1 : tensor<1x1x64xf32> } -// EMPTY: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> -// EMPTY: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// EMPTY: %[[CONST:.+]] = stablehlo.constant dense<2.000000e+00> +// EMPTY: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) // Check that the `_quantization_method` attribute doesn't contain the // quantization method, implying "no_quantization". @@ -59,17 +58,20 @@ func.func @main(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { // EMPTY-NOT: _quantization_method // EMPTY-SAME: _tfl_quant_trait = "fully_quantizable" -// EMPTY: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// EMPTY: return %[[XLA_CALL_MODULE:.+]] : tensor<1x1x64xf32> // EMPTY: } // EMPTY-LABEL: private @composite_dot_general_fn_1 // EMPTY-SAME: tf_quant.composite_function -// EMPTY: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 -// EMPTY: return %[[DOT_GENERAL:.*]] : tensor<1x1x64xf32> +// EMPTY: %[[DOT_GENERAL:.+]] = stablehlo.dot_general %arg0, %arg1 +// EMPTY: return %[[DOT_GENERAL:.+]] : tensor<1x1x64xf32> // EMPTY: } // ----- +// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs="quantization-specs=static-range-ptq-to-all" \ +// RUN: -split-input-file | FileCheck %s --check-prefix=STATIC-RANGE-PTQ-TO-ALL + // STATIC-RANGE-PTQ-TO-ALL: @main func.func @main(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> @@ -80,8 +82,8 @@ func.func @main(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { // contains attributes required for quantization, including the // `_quantization_method` attribute that contains textpb of `Method`. -// STATIC-RANGE-PTQ-TO-ALL: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> -// STATIC-RANGE-PTQ-TO-ALL: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// STATIC-RANGE-PTQ-TO-ALL: %[[CONST:.+]] = stablehlo.constant dense<2.000000e+00> +// STATIC-RANGE-PTQ-TO-ALL: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) // Check that the `_quantization_method` attribute contains the quantization // method in textproto format, enabling static-range PTQ. @@ -90,11 +92,34 @@ func.func @main(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { // STATIC-RANGE-PTQ-TO-ALL-SAME: _quantization_method = "static_range_ptq { }" // STATIC-RANGE-PTQ-TO-ALL-SAME: _tfl_quant_trait = "fully_quantizable" -// STATIC-RANGE-PTQ-TO-ALL: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// STATIC-RANGE-PTQ-TO-ALL: return %[[XLA_CALL_MODULE:.+]] : tensor<1x1x64xf32> // STATIC-RANGE-PTQ-TO-ALL: } // STATIC-RANGE-PTQ-TO-ALL-LABEL: private @composite_dot_general_fn_1 // STATIC-RANGE-PTQ-TO-ALL-SAME: tf_quant.composite_function -// STATIC-RANGE-PTQ-TO-ALL: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 -// STATIC-RANGE-PTQ-TO-ALL: return %[[DOT_GENERAL:.*]] : tensor<1x1x64xf32> +// STATIC-RANGE-PTQ-TO-ALL: %[[DOT_GENERAL:.+]] = stablehlo.dot_general %arg0, %arg1 +// STATIC-RANGE-PTQ-TO-ALL: return %[[DOT_GENERAL:.+]] : tensor<1x1x64xf32> // STATIC-RANGE-PTQ-TO-ALL: } + +// ----- + +// RUN: stablehlo-quant-opt %s -stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs="quantization-specs=static-range-ptq-to-compute-heavy" \ +// RUN: -split-input-file | FileCheck %s --check-prefix=STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY + +// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: @main +func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> { + %0 = stablehlo.add %arg0, %arg0 : tensor<1x2xf32> + return %0 : tensor<1x2xf32> +} +// Tests that `composite_add_fn_1` does not quantize when quantizing +// only compute-heavy ops. + +// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: %[[CONST:.+]] = stablehlo.constant dense<2.000000e+00> +// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%arg0, %arg0) + +// Check that the `_quantization_method` attribute contains the quantization +// method in textproto format, enabling static-range PTQ. +// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: _entry_function = @composite_add_fn_1 +// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: _original_entry_function +// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY-NOT: _quantization_method +// STATIC-RANGE-PTQ-TO-COMPUTE-HEAVY: _tfl_quant_trait = "fully_quantizable" diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_all_ops.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_all_ops.mlir new file mode 100644 index 00000000000000..72851d92b64b75 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_all_ops.mlir @@ -0,0 +1,46 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -stablehlo-quantize-composite-functions=enable-full-int-quantization=true | FileCheck --check-prefix=CHECK-FULL-INT %s + +// Tests that a basic `stablehlo.add` and a fused `stablehlo.dot_general` +// are properly quantized. + +module attributes {tf_saved_model.semantics} { +// CHECK-FULL-INT: func.func private @quantize_add_fn(%[[ARG:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} + func.func private @quantize_add_fn(%arg: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst_0 = "tf.Const"() {value = dense<1.00000000e-1> : tensor<1x2xf32>} : () -> tensor<1x2xf32> + %cst_1 = "tf.Const"() {value = dense<1.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantfork.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst_0) {Sout = [#tf_type.shape<1x2>], _entry_function = @composite_add_fn, _original_entry_function = "composite_add_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %3 = "quantfork.stats"(%2) {layerStats = dense<[5.00000000e-6, 6.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %4 = "tf.XlaCallModule"(%3, %cst_1) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "quantfork.stats"(%4) {layerStats = dense<[5.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %5 : tensor<1x3xf32> + } +// CHECK-FULL-INT: %[[CONST:.+]] = stablehlo.constant() {value = dense<127> : tensor<1x2xi8>} : () -> tensor<1x2x!quant.uniform> +// CHECK-FULL-INT: %[[CONST_0:.+]] = stablehlo.constant() {value = dense<127> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}>> +// CHECK-FULL-INT: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK-FULL-INT: %[[CALL:.+]] = call @quantized_add_fn(%[[UNIFORM_QUANTIZE]], %[[CONST]]) : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK-FULL-INT: %[[UNIFORM_DEQUANTIZE:.+]] = stablehlo.uniform_dequantize %[[CALL]] : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// CHECK-FULL-INT: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[UNIFORM_DEQUANTIZE]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK-FULL-INT: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK-FULL-INT: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK-FULL-INT: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + +// CHECK-FULL-INT: func.func private @quantized_add_fn(%[[ARG_0:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_1:.+]]: tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_add_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.add %arg0, %arg1 : tensor<1x2xf32> + return %0 : tensor<1x2xf32> + } +// CHECK-FULL-INT: %[[ADD:.+]] = stablehlo.add %arg0, %arg1 : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK-FULL-INT: return %[[ADD]] : tensor<1x2x!quant.uniform> + +// CHECK-FULL-INT: func.func private @quantized_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_1:.+]]: tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// CHECK-FULL-INT: %[[DOT_GENERAL:.+]] = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1,{{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK-FULL-INT: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-FULL-INT: return %[[UNIFORM_QUANTIZE]] : tensor<1x3x!quant.uniform> +} diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 849e854798368b..9b91a640bc7923 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -868,6 +868,7 @@ def _get_base_converter_args(self): ) ], enable_per_channel_quantized_weight=True, + enable_full_int_quantization=True, ), # For ODML use cases, uniform quantized types should be left intact. pipeline_config=qc.PipelineConfig( From a7d81b539f9834e4033a178730676e4b0e48c485 Mon Sep 17 00:00:00 2001 From: Wren Romano Date: Thu, 28 Mar 2024 15:04:52 -0700 Subject: [PATCH 058/124] [XLA:Python] Adding `xla::PrimitiveType` <-> `numpy.dtype` conversions to the library for internal debugging tools. PiperOrigin-RevId: 620068167 --- third_party/xla/xla/python/tools/BUILD | 1 + third_party/xla/xla/python/tools/_types.cc | 33 ++++++++++++++++++- third_party/xla/xla/python/tools/_types.pyi | 2 ++ third_party/xla/xla/python/tools/types.py | 4 ++- .../xla/xla/python/tools/types_test.py | 14 ++++---- 5 files changed, 44 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/python/tools/BUILD b/third_party/xla/xla/python/tools/BUILD index ac45d9d3f10c66..2338525fe744d3 100644 --- a/third_party/xla/xla/python/tools/BUILD +++ b/third_party/xla/xla/python/tools/BUILD @@ -63,6 +63,7 @@ tsl_pybind_extension( "//xla:literal", "//xla:xla_data_proto_cc", "//xla/python:logging", + "//xla/python:nb_numpy", "//xla/python:types", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/python/tools/_types.cc b/third_party/xla/xla/python/tools/_types.cc index f18e360c13c7fb..320404637a0462 100644 --- a/third_party/xla/xla/python/tools/_types.cc +++ b/third_party/xla/xla/python/tools/_types.cc @@ -25,6 +25,7 @@ limitations under the License. #include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf #include "xla/literal.h" #include "xla/python/logging.h" +#include "xla/python/nb_numpy.h" #include "xla/python/types.h" #include "xla/xla_data.pb.h" // NOTE: The tsl-numpy header forbids importing the actual NumPy arrayobject.h @@ -59,6 +60,19 @@ absl::StatusOr MakeNdarray(const xla::LiteralProto& proto) { // Convert `nb::object` into `py::object`. return py::reinterpret_steal(nbobj.release().ptr()); } + +// Partial reversion of cl/617156835, until we can get the proto-casters +// (and hence the extension) switched over to nanobind. +// TODO(wrengr): Or can we mix `{py,nb}::module_::def` calls?? +absl::StatusOr DtypeToEtype(const py::dtype& py_d) { + auto nb_d = nb::borrow(py_d.ptr()); + return xla::DtypeToPrimitiveType(nb_d); +} + +absl::StatusOr EtypeToDtype(xla::PrimitiveType p) { + TF_ASSIGN_OR_RETURN(xla::nb_dtype nb_d, xla::PrimitiveTypeToNbDtype(p)); + return py::reinterpret_steal(nb_d.release().ptr()); +} } // namespace // NOTE: It seems insurmountable to get "native_proto_caster.h" to work @@ -98,7 +112,8 @@ PYBIND11_MODULE(_types, py_m) { py::module_::import("ml_dtypes"); // Ensure that tsl-numpy initializes datastructures of the actual-NumPy - // implementation, and does whatever else tsl-numpy needs. + // implementation, and does whatever else tsl-numpy needs. This is + // also necessary for using the `xla::nb_dtype` type. tsl::ImportNumpy(); // Declare that C++ can `nb::cast` from `std::shared_ptr` @@ -124,5 +139,21 @@ PYBIND11_MODULE(_types, py_m) { of tuples with leaves being `numpy.ndarray` views of array-shaped sub-literals. )pbdoc"); + + // This method name is based on `xla_client.dtype_to_etype`. + // NOTE: `xla_client` uses a Python class wrapping the protobuf-enum, + // rather than using the protobuf-enum directly. See the module docstring + // in "types.py" for more explanation on why. + py_m.def("dtype_to_etype", &DtypeToEtype, py::arg("dtype").none(false), + py::pos_only(), R"pbdoc( + Converts `numpy.dtype` into + `tensorflow.compiler.xla.xla_data_pb2.PrimitiveType`. + )pbdoc"); + + py_m.def("etype_to_dtype", &EtypeToDtype, py::arg("ptype").none(false), + py::pos_only(), R"pbdoc( + Converts `tensorflow.compiler.xla.xla_data_pb2.PrimitiveType` into + `numpy.dtype`. + )pbdoc"); // LINT.ThenChange(_types.pyi) } diff --git a/third_party/xla/xla/python/tools/_types.pyi b/third_party/xla/xla/python/tools/_types.pyi index 1ca5071367a0cd..f355656f05b674 100644 --- a/third_party/xla/xla/python/tools/_types.pyi +++ b/third_party/xla/xla/python/tools/_types.pyi @@ -20,4 +20,6 @@ from xla import xla_data_pb2 # LINT.IfChange NdarrayTree = Union[np.ndarray, tuple['NdarrayTree', ...]] def make_ndarray(proto: xla_data_pb2.LiteralProto, /) -> NdarrayTree: ... +def dtype_to_etype(dtype: np.dtype, /) -> xla_data_pb2.PrimitiveType: ... +def etype_to_dtype(ptype: xla_data_pb2.PrimitiveType, /) -> np.dtype: ... # LINT.ThenChange(types.py, _types.cc) diff --git a/third_party/xla/xla/python/tools/types.py b/third_party/xla/xla/python/tools/types.py index d13ee2241ed479..189758f1e749c8 100644 --- a/third_party/xla/xla/python/tools/types.py +++ b/third_party/xla/xla/python/tools/types.py @@ -40,10 +40,12 @@ # NOTE: `import as ` is required for names to be exported. # See PEP 484 & -# pylint: disable=g-importing-member,useless-import-alias,unused-import +# pylint: disable=g-importing-member,useless-import-alias,unused-import,g-multiple-import # LINT.IfChange from ._types import ( make_ndarray as make_ndarray, + dtype_to_etype as dtype_to_etype, + etype_to_dtype as etype_to_dtype, ) # TODO(wrengr): We can't import the `NdarrayTree` defined in the pyi file. # So re-defining it here for now. diff --git a/third_party/xla/xla/python/tools/types_test.py b/third_party/xla/xla/python/tools/types_test.py index e056e05be24f35..a6cdb1d0f76b13 100644 --- a/third_party/xla/xla/python/tools/types_test.py +++ b/third_party/xla/xla/python/tools/types_test.py @@ -148,14 +148,12 @@ class MakeNdarrayValidTest(parameterized.TestCase): def testHasCorrectDtype(self, proto, arr): """Test that the result has the right dtype.""" - # Silence [unused-argument] warning. - del proto - # TODO(wrengr): Add pybind for `xla::PrimitiveTypeToDtype`, - # so that we can avoid hard-coding the expected np.dtype. - # Alternatively, we could use `xla_client.dtype_to_etype` (ideally - # after refactoring that into a small library, so we need not pull in - # all the rest of xla_client). - self.assertEqual(np.float64, arr.dtype) + e = proto.shape.element_type + d = arr.dtype + with self.subTest(msg='etype_to_dtype'): + self.assertEqual(types.etype_to_dtype(e), d) + with self.subTest(msg='dtype_to_etype'): + self.assertEqual(e, types.dtype_to_etype(d)) def testHasCorrectRank(self, proto, arr): """Test that the result has the right rank.""" From a9db13b49666b6461249bd6eb82fdc02fb5bb16c Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Thu, 28 Mar 2024 15:09:11 -0700 Subject: [PATCH 059/124] Integrate StableHLO at openxla/stablehlo@271e8634 PiperOrigin-RevId: 620069321 --- third_party/stablehlo/workspace.bzl | 4 ++-- third_party/xla/third_party/stablehlo/workspace.bzl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index b79bde1851c6c6..ca2f3a937f73ad 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "f4459e76553770ecc94f23de29984c7859ad9f05" - STABLEHLO_SHA256 = "00e2bcd62db577297a0a9b6f9203a9f2f58bd40bfa2574908ffa883ad7f60fd5" + STABLEHLO_COMMIT = "271e8634de184fbfafd677d3876170feb6d08c97" + STABLEHLO_SHA256 = "06db84c751bd4a980dc76249e02f10e119175fceba3eebed008da122cb480bab" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index b79bde1851c6c6..ca2f3a937f73ad 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "f4459e76553770ecc94f23de29984c7859ad9f05" - STABLEHLO_SHA256 = "00e2bcd62db577297a0a9b6f9203a9f2f58bd40bfa2574908ffa883ad7f60fd5" + STABLEHLO_COMMIT = "271e8634de184fbfafd677d3876170feb6d08c97" + STABLEHLO_SHA256 = "06db84c751bd4a980dc76249e02f10e119175fceba3eebed008da122cb480bab" # LINT.ThenChange(Google-internal path) tf_http_archive( From 0adcfc22f441b53a581ce14b1fc48298bb35705e Mon Sep 17 00:00:00 2001 From: Matthias Kramm Date: Thu, 28 Mar 2024 15:22:06 -0700 Subject: [PATCH 060/124] Move sparsecore passes under transforms/sparsecore. PiperOrigin-RevId: 620072718 --- tensorflow/compiler/mlir/BUILD | 2 + .../compiler/mlir/tensorflow/transforms/BUILD | 3 - .../tensorflow/transforms/host_runtime/BUILD | 1 + .../lower_cluster_to_runtime_ops.cc | 1 + .../mlir/tensorflow/transforms/passes.h | 7 - .../tensorflow/transforms/sparsecore/BUILD | 123 ++++++++++++++++++ .../{ => sparsecore}/embedding_pipelining.cc | 6 +- .../{ => sparsecore}/embedding_program_key.cc | 3 +- .../{ => sparsecore}/embedding_sequencing.cc | 8 +- .../transforms/sparsecore/sparsecore_passes.h | 50 +++++++ .../sparsecore/sparsecore_passes.td | 83 ++++++++++++ .../mlir/tensorflow/transforms/tf_passes.td | 67 ---------- .../compiler/mlir/tf2xla/internal/BUILD | 1 + .../internal/clustering_bridge_passes.cc | 1 + tensorflow/compiler/mlir/tf_mlir_opt_main.cc | 5 +- 15 files changed, 277 insertions(+), 84 deletions(-) create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD rename tensorflow/compiler/mlir/tensorflow/transforms/{ => sparsecore}/embedding_pipelining.cc (99%) rename tensorflow/compiler/mlir/tensorflow/transforms/{ => sparsecore}/embedding_program_key.cc (99%) rename tensorflow/compiler/mlir/tensorflow/transforms/{ => sparsecore}/embedding_sequencing.cc (98%) create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.td diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index b30f08a1bfe1b4..d0286e5acff9ce 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -59,6 +59,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", # buildcleaner:keep "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops", "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:runtime_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms/sparsecore:sparsecore_passes", "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", "//tensorflow/compiler/mlir/tf2xla/internal/passes:clustering_passes", "//tensorflow/compiler/mlir/tf2xla/internal/passes:mlir_to_graph_passes", @@ -69,6 +70,7 @@ cc_library( "//tensorflow/compiler/mlir/tosa:tfl_passes", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir/framework/ir:xla_framework", "@local_xla//xla/mlir/framework/transforms:passes", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index c84871fd564156..3d1cf1bd58fa38 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -460,9 +460,6 @@ cc_library( "device_index_selector.cc", "drop_while_shape_invariant.cc", "einsum.cc", - "embedding_pipelining.cc", - "embedding_program_key.cc", - "embedding_sequencing.cc", "executor_island_coarsening.cc", "executor_tpuv1_inline_tpu_island.cc", "executor_tpuv1_island_coarsening.cc", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD index 15f339eccd2f93..f8e75d9032f3e5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD @@ -31,6 +31,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", + "//tensorflow/compiler/mlir/tensorflow/transforms/sparsecore:sparsecore_passes", "//tensorflow/core:framework", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core/platform:error_payloads", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc index 713e9080f2e03b..a239c7304a0ae0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/runtime_passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 9c475f1f9f5281..da89e77cb0862c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -446,13 +446,6 @@ std::unique_ptr> CreateReplicateToIslandPass( std::unique_ptr> CreateReplicaIDToDeviceOrdinalPass(); -// Creates a pass that adds pipelining to a graph that contains device -// accelerated embeddings. The EmbeddingSequencingPass is a temporary fallback -// while developing full pipelining capabilities. -std::unique_ptr> CreateEmbeddingSequencingPass(); -std::unique_ptr> CreateEmbeddingPipeliningPass(); -std::unique_ptr> CreateEmbeddingProgramKeyPass(); - // Creates a pass that creates `tf_executor.island` from a single // `tf_device.parallel_execute` island. std::unique_ptr> CreateParallelExecuteToIslandsPass( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD new file mode 100644 index 00000000000000..bff95d357c885f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD @@ -0,0 +1,123 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tensorflow/compiler/mlir:__pkg__", + "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__", + "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:__pkg__", + "//tensorflow/compiler/mlir/tf2xla/api:__subpackages__", + "//tensorflow/compiler/mlir/tf2xla/internal:__pkg__", + ], + licenses = ["notice"], +) + +gentbl_cc_library( + name = "sparsecore_passes_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=SparseCore", + ], + "sparsecore_passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "sparsecore_passes.td", + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + +cc_library( + name = "sparsecore_passes", + hdrs = [ + "sparsecore_passes.h", + ], + textual_hdrs = [ + "sparsecore_passes.h.inc", + ], + deps = [ + ":embedding_pipelining", + ":embedding_program_key", + ":embedding_sequencing", + ":sparsecore_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + +cc_library( + name = "embedding_pipelining", + srcs = ["embedding_pipelining.cc"], + hdrs = [ + "sparsecore_passes.h", + ], + deps = [ + ":sparsecore_passes_inc_gen", + "//tensorflow/compiler/jit:flags_headers", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "embedding_sequencing", + srcs = ["embedding_sequencing.cc"], + hdrs = [ + "sparsecore_passes.h", + ], + deps = [ + ":sparsecore_passes_inc_gen", + "//tensorflow/compiler/jit:flags_headers", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "embedding_program_key", + srcs = ["embedding_program_key.cc"], + hdrs = [ + "sparsecore_passes.h", + ], + deps = [ + ":sparsecore_passes_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@local_xla//xla/mlir_hlo", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc similarity index 99% rename from tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc rename to tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc index ee334b3f032155..0c450126e4e090 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc @@ -157,7 +157,7 @@ return selected_results #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #define GEN_PASS_DEF_EMBEDDINGPIPELININGPASS -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h.inc" static constexpr char kEmbeddingPipelining[] = "_embedding_pipelining"; static constexpr char kEmbeddingPipeliningInlineAttr[] = @@ -1289,7 +1289,7 @@ LogicalResult StartStep0(OpBuilder& builder, Location& loc, func::FuncOp orig_parent_func = callers.backward->getParentOfType(); - std::vector operands = loop_operands_nm0; + const std::vector& operands = loop_operands_nm0; // Input types will be the same as the original loop body. std::vector input_types = GetValueTypes(operands); @@ -1373,7 +1373,7 @@ LogicalResult StartStep1(OpBuilder& builder, Location& loc, func::FuncOp orig_parent_func = callers.backward->getParentOfType(); - std::vector operands = loop_operands_1; + const std::vector& operands = loop_operands_1; // Input types will be the same as the original loop body. std::vector input_types = GetValueTypes(operands); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_program_key.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_program_key.cc similarity index 99% rename from tensorflow/compiler/mlir/tensorflow/transforms/embedding_program_key.cc rename to tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_program_key.cc index a5575ef156ddb9..3e41762feb16c2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_program_key.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_program_key.cc @@ -31,7 +31,6 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { @@ -42,7 +41,7 @@ constexpr char kMiniBatchSplitsAttr[] = "mini_batch_splits"; constexpr char kMiniBatchCsrAttr[] = "mini_batch_in_csr"; #define GEN_PASS_DEF_EMBEDDINGPROGRAMKEYPASS -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h.inc" struct EmbeddingProgramKeyPass : public impl::EmbeddingProgramKeyPassBase { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_sequencing.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc similarity index 98% rename from tensorflow/compiler/mlir/tensorflow/transforms/embedding_sequencing.cc rename to tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc index a77dd6f498a144..7ed29a3ed58cc3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_sequencing.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc @@ -32,6 +32,8 @@ limitations under the License. #include #include +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Casting.h" @@ -40,6 +42,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Region.h" // from @llvm-project @@ -47,17 +50,20 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #define GEN_PASS_DEF_EMBEDDINGSEQUENCINGPASS -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h.inc" static constexpr char kEmbeddingPipelining[] = "_embedding_pipelining"; static constexpr char kEmbeddingForward[] = "forward"; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h new file mode 100644 index 00000000000000..8944745dd3fff9 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SPARSECORE_SPARSECORE_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SPARSECORE_SPARSECORE_PASSES_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace TFDevice { + +// For architectures that support accelerated embedding lookups, this pass will +// rewrite the graph to use pipelining for better device utilization. +std::unique_ptr> CreateEmbeddingSequencingPass(); + +// This is a strictly sequential and formally correct fallback option for the +// embedding pipelining pass intended for debugging during pipelining +// development. +std::unique_ptr> CreateEmbeddingPipeliningPass(); + +// Passes in the program key to embedding ops, by moving the embedding ops +// after the _TPUCompileMlir op. +std::unique_ptr> CreateEmbeddingProgramKeyPass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_EMBEDDINGSEQUENCINGPASS +#define GEN_PASS_DECL_EMBEDDINGPIPELININGPASS +#define GEN_PASS_DECL_EMBEDDINGPROGRAMKEYPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h.inc" + +} // namespace TFDevice +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SPARSECORE_SPARSECORE_PASSES_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.td new file mode 100644 index 00000000000000..a9c5981393df6c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.td @@ -0,0 +1,83 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/Pass/PassBase.td" + +def EmbeddingPipeliningPass : Pass<"tf-embedding-pipelining", "mlir::ModuleOp"> { + let summary = "Rewrite graph for embedding pipelining"; + let constructor = "TFDevice::CreateEmbeddingPipeliningPass()"; + let description = [{ + For architectures that support accelerated embedding lookups, this pass will + rewrite the graph to use pipelining for better device utilization. + }]; +} + +def EmbeddingSequencingPass : Pass<"tf-embedding-sequencing", "mlir::ModuleOp"> { + let summary = "Rewrite graph for sequential execution of embeddings"; + let constructor = "TFDevice::CreateEmbeddingSequencingPass()"; + let description = [{ + This is a strictly sequential and formally correct fallback option for the + embedding pipelining pass intended for debugging during pipelining + development. + }]; +} + +def EmbeddingProgramKeyPass : Pass<"tf-embedding-program-key", "mlir::func::FuncOp"> { + let summary = "Sets the program key for embedding ops."; + let constructor = "TFDevice::CreateEmbeddingProgramKeyPass()"; + let description = [{ + Passes in the program key to embedding ops. Will move the embedding ops + after a _TPUCompileMlir op if there is no predecessor _TPUCompileMlir op. + Both the embedding op and compile op are assumed to be wrapped in separate + tf_device.launch() ops. This is because the embedding op is head outside + compiled and the compile op is wrapped in launch to execute on host + during TPURewritePass. + + For example, the tf.OpA with the `mini_batch_splits` attribute will be + moved after _TPUCompileMlir and the first input will use the + _TPUCompileMlir program output: + + ```mlir + "tf_device.launch"() ({ + %cst_0 = "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + "tf.OpA"(%cst_0) { mini_batch_splits = ""} : (tensor<1x!tf_type.string>) -> () + tf_device.return + }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> () + %0:2 = "tf_device.launch"() ({ + %compilation_status, %program = "tf._TPUCompileMlir"() { metadata = "...", mlir_module = "..." } : () -> (tensor, tensor<3x!tf_type.string>) + tf_device.return %compilation_status, %program : tensor, tensor<3x!tf_type.string> + }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor, tensor<3x!tf_type.string>) + ``` + + becomes: + + ```mlir + %0:2 = "tf_device.launch"() ({ + %compilation_status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor, tensor<3x!tf_type.string>) + tf_device.return %compilation_status, %program : tensor, tensor<3x!tf_type.string> + }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor, tensor<3x!tf_type.string>) + "tf_device.launch"() ({ + %cst = "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + "tf.OpA"(%0#1) {mini_batch_splits = ""} : (tensor<3x!tf_type.string>) -> () + tf_device.return + }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> () + ``` + }]; + + let dependentDialects = [ + "mhlo::MhloDialect", + "tf_device::TensorFlowDeviceDialect" + ]; +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td index b00e70eb73c4cc..6b53cae7099688 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td @@ -329,73 +329,6 @@ def ReplicaIDToDeviceOrdinalPass : Pass<"tf-replica-id-to-device-ordinal", "mlir }]; } -def EmbeddingPipeliningPass : Pass<"tf-embedding-pipelining", "mlir::ModuleOp"> { - let summary = "Rewrite graph for embedding pipelining"; - let constructor = "TFDevice::CreateEmbeddingPipeliningPass()"; - let description = [{ - For architectures that support accelerated embedding lookups, this pass will - rewrite the graph to use pipelining for better device utilization. - }]; -} - -def EmbeddingProgramKeyPass : Pass<"tf-embedding-program-key", "mlir::func::FuncOp"> { - let summary = "Sets the program key for embedding ops."; - let constructor = "TFDevice::CreateEmbeddingProgramKeyPass()"; - let description = [{ - Passes in the program key to embedding ops. Will move the embedding ops - after a _TPUCompileMlir op if there is no predecessor _TPUCompileMlir op. - Both the embedding op and compile op are assumed to be wrapped in separate - tf_device.launch() ops. This is because the embedding op is head outside - compiled and the compile op is wrapped in launch to execute on host - during TPURewritePass. - - For example, the tf.OpA with the `mini_batch_splits` attribute will be - moved after _TPUCompileMlir and the first input will use the - _TPUCompileMlir program output: - - ```mlir - "tf_device.launch"() ({ - %cst_0 = "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> - "tf.OpA"(%cst_0) { mini_batch_splits = ""} : (tensor<1x!tf_type.string>) -> () - tf_device.return - }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> () - %0:2 = "tf_device.launch"() ({ - %compilation_status, %program = "tf._TPUCompileMlir"() { metadata = "...", mlir_module = "..." } : () -> (tensor, tensor<3x!tf_type.string>) - tf_device.return %compilation_status, %program : tensor, tensor<3x!tf_type.string> - }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor, tensor<3x!tf_type.string>) - ``` - - becomes: - - ```mlir - %0:2 = "tf_device.launch"() ({ - %compilation_status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor, tensor<3x!tf_type.string>) - tf_device.return %compilation_status, %program : tensor, tensor<3x!tf_type.string> - }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor, tensor<3x!tf_type.string>) - "tf_device.launch"() ({ - %cst = "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> - "tf.OpA"(%0#1) {mini_batch_splits = ""} : (tensor<3x!tf_type.string>) -> () - tf_device.return - }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> () - ``` - }]; - - let dependentDialects = [ - "mhlo::MhloDialect", - "tf_device::TensorFlowDeviceDialect" - ]; -} - -def EmbeddingSequencingPass : Pass<"tf-embedding-sequencing", "mlir::ModuleOp"> { - let summary = "Rewrite graph for sequential execution of embeddings"; - let constructor = "TFDevice::CreateEmbeddingSequencingPass()"; - let description = [{ - This is a strictly sequential and formally correct fallback option for the - embedding pipelining pass intended for debugging during pipelining - development. - }]; -} - def ConvertReadonlyReferenceVariablesToResourceVariablesPass : Pass<"tf-readonly-references-to-resources", "mlir::func::FuncOp"> { let summary = "Convert readonly reference variables to resource variables."; diff --git a/tensorflow/compiler/mlir/tf2xla/internal/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/BUILD index 246481c5cab7db..7e937d2ce49f8b 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/BUILD @@ -187,6 +187,7 @@ cc_library( "//tensorflow/compiler/jit:flags_headers", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", + "//tensorflow/compiler/mlir/tensorflow/transforms/sparsecore:sparsecore_passes", "//tensorflow/compiler/mlir/tf2xla/internal/passes:clustering_passes", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", diff --git a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc index 603d928daf9032..e289934b69fbe0 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h" #include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index 2c49198be7bad8..1ce45fe7345c11 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow//compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" @@ -24,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/runtime_passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/test_passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mlprogram_util.h" @@ -35,7 +37,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tosa/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/tosa/tfl_passes.h" #include "tensorflow/compiler/mlir/tosa/transforms/passes.h" -#include "xla/mlir/framework/ir/xla_framework.h" #include "xla/mlir/framework/transforms/passes.h" #include "xla/mlir_hlo/lhlo/transforms/passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -69,6 +70,8 @@ int main(int argc, char **argv) { tensorflow::RegisterGraphOptimizationPasses(); tensorflow::RegisterMlProgramPasses(); mlir::TFTPU::registerRuntimeLoweringPasses(); + mlir::TFDevice::registerSparseCorePasses(); + tensorflow::tfrt_compiler::RegisterTPULowerClusterToRuntimeOpsPassPipeline(); tensorflow::tfrt_compiler:: RegisterNonTPULowerClusterToRuntimeOpsPassPipeline(); From a3db0abb7adc8101d90d2646edb2c7700b1f0c3b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Mar 2024 15:41:03 -0700 Subject: [PATCH 061/124] Internal cleanup of BUILD/.bzl files PiperOrigin-RevId: 620077983 --- tensorflow/lite/acceleration/configuration/BUILD | 2 ++ tensorflow/lite/experimental/acceleration/configuration/BUILD | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tensorflow/lite/acceleration/configuration/BUILD b/tensorflow/lite/acceleration/configuration/BUILD index f78221caf8ee2c..a8e13de931c3b0 100644 --- a/tensorflow/lite/acceleration/configuration/BUILD +++ b/tensorflow/lite/acceleration/configuration/BUILD @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================== +# buildifier: disable=out-of-order-load + load("@flatbuffers//:build_defs.bzl", "DEFAULT_FLATC_ARGS", "flatbuffer_android_library", "flatbuffer_cc_library", "flatbuffer_java_library", "flatc_path") # copybara:comment_begin(oss-only) diff --git a/tensorflow/lite/experimental/acceleration/configuration/BUILD b/tensorflow/lite/experimental/acceleration/configuration/BUILD index 79932ae571be7e..f4f92bcbe6860b 100644 --- a/tensorflow/lite/experimental/acceleration/configuration/BUILD +++ b/tensorflow/lite/experimental/acceleration/configuration/BUILD @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================== +# buildifier: disable=out-of-order-load + load("@flatbuffers//:build_defs.bzl", "DEFAULT_FLATC_ARGS", "flatbuffer_android_library", "flatbuffer_cc_library", "flatbuffer_java_library") # copybara:comment_begin(oss-only) From 8ae67af5aeb33f15b9f8323a594c5c76bbfa94d0 Mon Sep 17 00:00:00 2001 From: Michael Levesque-Dion Date: Thu, 28 Mar 2024 16:06:00 -0700 Subject: [PATCH 062/124] Delete populateRankSpecialization*Patterns functions These were used by KernelGen but are no longer needed. PiperOrigin-RevId: 620084345 --- third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h | 7 ------- 1 file changed, 7 deletions(-) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h index 14e3add6f814fe..c40a087ab52cbd 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/rewriters.h @@ -166,13 +166,6 @@ void populateGroupReductionDimensionsPatterns(MLIRContext *context, RewritePatternSet *patterns, bool preferColumnsReductions); -/// Populate rank specialization clustering and lowering patterns. -void populateRankSpecializationClusterPatterns(MLIRContext *context, - RewritePatternSet *patterns); -void populateRankSpecializationToSCFPatterns(MLIRContext *context, - RewritePatternSet *patterns, - int64_t maxTargetRank); - /// Populate sparse tensor specific rewriting patterns. void populateSparseRewritingPatterns(RewritePatternSet *patterns, MLIRContext *ctx); From 90ab9c1d4030fabb2e9af171ab43dbe0ecec97ef Mon Sep 17 00:00:00 2001 From: "T.J. Alumbaugh" Date: Thu, 28 Mar 2024 16:58:40 -0700 Subject: [PATCH 063/124] Add a Resource for KV Cache buffer storage PiperOrigin-RevId: 620097541 --- tensorflow/lite/experimental/resource/BUILD | 26 +++++++++ .../experimental/resource/cache_buffer.cc | 56 +++++++++++++++++++ .../lite/experimental/resource/cache_buffer.h | 51 +++++++++++++++++ .../resource/cache_buffer_test.cc | 47 ++++++++++++++++ .../experimental/resource/resource_variable.h | 2 +- 5 files changed, 181 insertions(+), 1 deletion(-) create mode 100644 tensorflow/lite/experimental/resource/cache_buffer.cc create mode 100644 tensorflow/lite/experimental/resource/cache_buffer.h create mode 100644 tensorflow/lite/experimental/resource/cache_buffer_test.cc diff --git a/tensorflow/lite/experimental/resource/BUILD b/tensorflow/lite/experimental/resource/BUILD index 45ed5395dbc9b8..bed57f7489d25e 100644 --- a/tensorflow/lite/experimental/resource/BUILD +++ b/tensorflow/lite/experimental/resource/BUILD @@ -6,6 +6,32 @@ package( licenses = ["notice"], ) +cc_library( + name = "cache_buffer", + srcs = ["cache_buffer.cc"], + hdrs = [ + "cache_buffer.h", + "//tensorflow/lite/core/c:common.h", + ], + deps = [ + ":resource", + "//tensorflow/lite/core/c:c_api_types", + "//tensorflow/lite/core/c:common", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels/internal:compatibility", + ], +) + +cc_test( + name = "cache_buffer_test", + srcs = ["cache_buffer_test.cc"], + deps = [ + ":cache_buffer", + "//tensorflow/lite/c:common", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "resource", srcs = [ diff --git a/tensorflow/lite/experimental/resource/cache_buffer.cc b/tensorflow/lite/experimental/resource/cache_buffer.cc new file mode 100644 index 00000000000000..0e221589b4cc64 --- /dev/null +++ b/tensorflow/lite/experimental/resource/cache_buffer.cc @@ -0,0 +1,56 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/resource/cache_buffer.h" + +#include +#include + +#include "tensorflow/lite/core/c/c_api_types.h" +#include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace resource { + +constexpr char kCacheBufferTensorName[] = "CacheBuffer"; + +TfLiteStatus CacheBuffer::Initialize(const TfLiteIntArray &shape, + const TfLiteType &type) { + // Set basic parameters. + tensor_.name = kCacheBufferTensorName; + tensor_.allocation_type = kTfLiteDynamic; + tensor_.type = type; + + // Set the shape and allocate the memory. + tensor_.dims = TfLiteIntArrayCopy(&shape); + const size_t num_bytes = TfLiteTypeGetSize(type) * NumElements(&tensor_); + TfLiteTensorRealloc(num_bytes, &tensor_); + + memset(tensor_.data.raw, 0, tensor_.bytes); + is_initialized_ = true; + return kTfLiteOk; +} + +size_t CacheBuffer::GetNumEntries() const { return num_entries_; } + +void CacheBuffer::SetNumEntries(size_t count) { + TFLITE_DCHECK(count <= tensor_.dims->data[2]); + num_entries_ = count; +} + +} // namespace resource +} // namespace tflite diff --git a/tensorflow/lite/experimental/resource/cache_buffer.h b/tensorflow/lite/experimental/resource/cache_buffer.h new file mode 100644 index 00000000000000..1e500fab07c269 --- /dev/null +++ b/tensorflow/lite/experimental/resource/cache_buffer.h @@ -0,0 +1,51 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_CACHE_BUFFER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_CACHE_BUFFER_H_ + +#include +#include + +#include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/experimental/resource/resource_variable.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace resource { + +/// WARNING: Experimental interface, subject to change. +// A Cache Buffer class. Useful for keeping the keys and values of a +// transformer block attention mechanism in autoregressive decode. +// Ops can access this buffer and add tensors to it. It also keeps track of the +// number of used entries in the cache. +class CacheBuffer : public ResourceVariable { + public: + CacheBuffer() = default; + CacheBuffer(const CacheBuffer &) = delete; + CacheBuffer &operator=(const CacheBuffer &) = delete; + // Initialize tensor of a certain shape using the provided type. + TfLiteStatus Initialize(const TfLiteIntArray &shape, const TfLiteType &type); + size_t GetNumEntries() const; + void SetNumEntries(size_t count); + + private: + // The number of entries currently used in the buffer; + size_t num_entries_ = 0; +}; + +} // namespace resource +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_CACHE_BUFFER_H_ diff --git a/tensorflow/lite/experimental/resource/cache_buffer_test.cc b/tensorflow/lite/experimental/resource/cache_buffer_test.cc new file mode 100644 index 00000000000000..6b54f6c787138d --- /dev/null +++ b/tensorflow/lite/experimental/resource/cache_buffer_test.cc @@ -0,0 +1,47 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/resource/cache_buffer.h" + +#include +#include "tensorflow/lite/c/common.h" + +namespace tflite { +namespace resource { + +TEST(CacheBufferTest, Initialize) { + TfLiteIntArray* shape = TfLiteIntArrayCreate(4); + shape->data[0] = 1; + shape->data[1] = 3; + shape->data[2] = 5; + shape->data[3] = 7; + + TfLiteType type = kTfLiteFloat32; + CacheBuffer cache_buffer; + cache_buffer.Initialize(*shape, type); + + EXPECT_EQ(cache_buffer.GetTensor()->type, type); + EXPECT_EQ(cache_buffer.GetTensor()->dims->size, 4); + EXPECT_EQ(cache_buffer.GetTensor()->dims->data[0], 1); + EXPECT_EQ(cache_buffer.GetTensor()->dims->data[1], 3); + EXPECT_EQ(cache_buffer.GetTensor()->bytes, 420); + ASSERT_NE(cache_buffer.GetTensor()->data.raw, nullptr); + EXPECT_EQ(cache_buffer.GetNumEntries(), 0); + cache_buffer.SetNumEntries(3); + EXPECT_EQ(cache_buffer.GetNumEntries(), 3); + TfLiteIntArrayFree(shape); +} + +} // namespace resource +} // namespace tflite diff --git a/tensorflow/lite/experimental/resource/resource_variable.h b/tensorflow/lite/experimental/resource/resource_variable.h index 3f34082c85f553..881aaa1f0b3a19 100644 --- a/tensorflow/lite/experimental/resource/resource_variable.h +++ b/tensorflow/lite/experimental/resource/resource_variable.h @@ -50,7 +50,7 @@ class ResourceVariable : public ResourceBase { return is_initialized_ ? tensor_.bytes : 0; } - private: + protected: // The tensor (and its buffer stored in `tensor_.data` is fully owned by // the `ResourceVariable` object. TfLiteTensor tensor_; From 7fadd476bd7915ac0b855c18fd747b813c681975 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Mar 2024 17:44:19 -0700 Subject: [PATCH 064/124] Correctly handle output streaming case where the MoveToHost annotation is the entry computation root PiperOrigin-RevId: 620107928 --- third_party/xla/xla/service/BUILD | 8 +- third_party/xla/xla/service/hlo_verifier.cc | 3 +- .../xla/xla/service/host_offload_legalize.cc | 12 ++- third_party/xla/xla/service/host_offloader.cc | 30 +++--- .../xla/xla/service/host_offloader_test.cc | 102 +++++++++++++++--- 5 files changed, 118 insertions(+), 37 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 79c7d1051897b9..1db01dd5c90dea 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -5882,17 +5882,13 @@ cc_library( srcs = ["host_offload_legalize.cc"], hdrs = ["host_offload_legalize.h"], deps = [ + ":call_graph", ":hlo_alias_analysis", - ":hlo_buffer", ":hlo_pass", ":hlo_value", ":host_memory_offload_annotations_hdr", - ":host_offloader", - ":pattern_matcher", - "//xla:literal_util", "//xla:shape_util", "//xla:status", - "//xla:statusor", "//xla:util", "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", @@ -5900,7 +5896,9 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc index 9b526c3bae41f0..751756ccdab359 100644 --- a/third_party/xla/xla/service/hlo_verifier.cc +++ b/third_party/xla/xla/service/hlo_verifier.cc @@ -1952,7 +1952,8 @@ Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) { result_layout.shape(), Shape::Equal() .IgnoreTilesInLayout() - .IgnoreTailPaddingAlignmentInElements())) { + .IgnoreTailPaddingAlignmentInElements() + .IgnoreMemorySpaceInLayout())) { return Internal( "Shape of the root instruction of entry computation (%s) should be " "compatible to one specified in module's entry computation layout (%s)", diff --git a/third_party/xla/xla/service/host_offload_legalize.cc b/third_party/xla/xla/service/host_offload_legalize.cc index 958ec67718c51e..e80e0ef32b5a50 100644 --- a/third_party/xla/xla/service/host_offload_legalize.cc +++ b/third_party/xla/xla/service/host_offload_legalize.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include -#include +#include #include #include @@ -26,15 +26,18 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/call_graph.h" #include "xla/service/hlo_value.h" #include "xla/service/host_memory_offload_annotations.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" -#include "xla/statusor.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -336,8 +339,11 @@ absl::StatusOr ProcessAnnotationForCopyMovement( instruction->parent()->IsEntryComputation(); }; + if (instruction->IsRoot()) { + return false; + } HloInstruction* starting_instr = - FindDUSFromAnnotation(instruction->users()[0]); + FindDUSFromAnnotation(instruction->users().at(0)); // If it's the pure copy case reset instruction. if (starting_instr->opcode() != HloOpcode::kDynamicUpdateSlice) { starting_instr = instruction; diff --git a/third_party/xla/xla/service/host_offloader.cc b/third_party/xla/xla/service/host_offloader.cc index 9058a9aa48c515..d17277a9141c6b 100644 --- a/third_party/xla/xla/service/host_offloader.cc +++ b/third_party/xla/xla/service/host_offloader.cc @@ -222,28 +222,34 @@ Status HostOffloader::HandleMoveToHostCustomCall(HloInstruction* custom_call) { // Save a pointer to this custom call for when we want to remove it later. custom_calls_to_remove_.emplace(custom_call); - // We expect that the DUS is the only user of this custom call. - if (custom_call->user_count() != 1) { + // We expect that either the custom call is the root or the DUS is the only + // user of this custom call. + if (!custom_call->IsRoot() && custom_call->user_count() != 1) { return FailedPrecondition( - "Expecting custom call %s to only have 1 user; it has %d users: [%s]", + "Expecting custom call %s to either be the root or only have 1 user; " + "it is not the root and has %d users: [%s]", custom_call->name(), custom_call->user_count(), absl::StrJoin(custom_call->users(), ", ", [](std::string* out, const HloInstruction* user) { out->append(user->name()); })); } - HloInstruction* op_being_annotated = custom_call->users()[0]; - // Skip past any bitcasts. - while (op_being_annotated->opcode() == HloOpcode::kBitcast) { - VLOG(1) << "Skipping bitcast " << op_being_annotated->ToString(); - op_being_annotated = op_being_annotated->users()[0]; + HloInstruction* consumer = nullptr; + if (!custom_call->IsRoot()) { + consumer = custom_call->users().at(0); + // Skip past any bitcasts. + while (consumer != nullptr && consumer->opcode() == HloOpcode::kBitcast) { + VLOG(1) << "Skipping bitcast " << consumer->ToString(); + consumer = consumer->users().at(0); + } } - if (op_being_annotated->opcode() == HloOpcode::kDynamicUpdateSlice) { - TF_RETURN_IF_ERROR(MemoryOnlyOffloadStartingWithDus(op_being_annotated)); - } else if (op_being_annotated->opcode() == HloOpcode::kCopy) { - TF_RETURN_IF_ERROR(MemoryOnlyOffloadStartingWithCopy(op_being_annotated)); + if (consumer != nullptr && + consumer->opcode() == HloOpcode::kDynamicUpdateSlice) { + TF_RETURN_IF_ERROR(MemoryOnlyOffloadStartingWithDus(consumer)); + } else if (consumer != nullptr && consumer->opcode() == HloOpcode::kCopy) { + TF_RETURN_IF_ERROR(MemoryOnlyOffloadStartingWithCopy(consumer)); } else { TF_ASSIGN_OR_RETURN(bool did_output_streaming, TryOutputStreaming(custom_call)); diff --git a/third_party/xla/xla/service/host_offloader_test.cc b/third_party/xla/xla/service/host_offloader_test.cc index 6b367fe53a2f54..162bbb4630de45 100644 --- a/third_party/xla/xla/service/host_offloader_test.cc +++ b/third_party/xla/xla/service/host_offloader_test.cc @@ -1856,22 +1856,22 @@ ENTRY main { TEST_F(HostOffloaderTest, OutputStreaming) { const std::string& hlo_string = R"( -HloModule ParameterStreaming, entry_computation_layout={(s32[2,1]{1,0:T(2,128)}, s32[2,1]{1,0:T(2,128)})->(s32[2,1]{1,0:T(2,128)S(5)}, s32[2,1]{1,0:T(2,128)})} - -ENTRY main { - param_0 = s32[2,1]{1,0} parameter(0) - param_1 = s32[2,1]{1,0} parameter(1) - constant_2 = s32[] constant(2) - constant_4 = s32[] constant(4) - broadcast_0 = s32[2,1]{1,0} broadcast(constant_2), dimensions={} - multiply_0 = s32[2,1]{1,0} multiply(param_1, broadcast_0) - multiply_1 = s32[2,1]{1,0} multiply(multiply_0, param_0) - broadcast_1 = s32[2,1]{1,0} broadcast(constant_4), dimensions={} - multiply_2 = s32[2,1]{1,0} multiply(multiply_1, broadcast_1) - custom_call = s32[2,1]{1,0} custom-call(multiply_2), custom_call_target="MoveToHost" - ROOT tuple = (s32[2,1]{1,0}, s32[2,1]{1,0}) tuple(custom_call, multiply_1) -} -)"; + HloModule ParameterStreaming, entry_computation_layout={(s32[2,1]{1,0:T(2,128)}, s32[2,1]{1,0:T(2,128)})->(s32[2,1]{1,0:T(2,128)S(5)}, s32[2,1]{1,0:T(2,128)})} + + ENTRY main { + param_0 = s32[2,1]{1,0} parameter(0) + param_1 = s32[2,1]{1,0} parameter(1) + constant_2 = s32[] constant(2) + constant_4 = s32[] constant(4) + broadcast_0 = s32[2,1]{1,0} broadcast(constant_2), dimensions={} + multiply_0 = s32[2,1]{1,0} multiply(param_1, broadcast_0) + multiply_1 = s32[2,1]{1,0} multiply(multiply_0, param_0) + broadcast_1 = s32[2,1]{1,0} broadcast(constant_4), dimensions={} + multiply_2 = s32[2,1]{1,0} multiply(multiply_1, broadcast_1) + custom_call = s32[2,1]{1,0} custom-call(multiply_2), custom_call_target="MoveToHost" + ROOT tuple = (s32[2,1]{1,0}, s32[2,1]{1,0}) tuple(custom_call, multiply_1) + } + )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); @@ -1935,6 +1935,76 @@ ENTRY main { EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); } +TEST_F(HostOffloaderTest, OutputStreamingCustomCallRoot) { + const std::string& hlo_string = R"( + HloModule ParameterStreaming, entry_computation_layout={(s32[2,1]{1,0:T(2,128)}, s32[2,1]{1,0:T(2,128)})->s32[2,1]{1,0:T(2,128)S(5)}} + + ENTRY main { + param_0 = s32[2,1]{1,0} parameter(0) + param_1 = s32[2,1]{1,0} parameter(1) + constant_2 = s32[] constant(2) + constant_4 = s32[] constant(4) + broadcast_0 = s32[2,1]{1,0} broadcast(constant_2), dimensions={} + multiply_0 = s32[2,1]{1,0} multiply(param_1, broadcast_0) + multiply_1 = s32[2,1]{1,0} multiply(multiply_0, param_0) + broadcast_1 = s32[2,1]{1,0} broadcast(constant_4), dimensions={} + multiply_2 = s32[2,1]{1,0} multiply(multiply_1, broadcast_1) + ROOT custom_call = s32[2,1]{1,0} custom-call(multiply_2), custom_call_target="MoveToHost" + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + + EXPECT_TRUE(changed); + + // Look for the following pattern: + // constant + // | + // param1 broadcast param0 + // \ / / + // multiply / + // \ / + // \ / + // multiply constant + // | | + // | ---+---broadcast + // | / + // multiply + // | + // copy + HloInstruction* param_1; + HloInstruction* broadcast_0; + HloInstruction* multiply_0; + HloInstruction* param_0; + HloInstruction* multiply_1; + HloInstruction* broadcast_1; + HloInstruction* multiply_2; + HloInstruction* copy; + auto multiplyPattern = + m::Multiply(&multiply_1, + m::Multiply(&multiply_0, m::Parameter(¶m_1), + m::Broadcast(&broadcast_0, m::ConstantScalar(2))), + m::Parameter(¶m_0)); + ASSERT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Copy( + ©, m::Multiply(&multiply_2, multiplyPattern, + m::Broadcast(&broadcast_1, + m::ConstantScalar(4)))))); + TestShapeHasMemorySpace(param_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(broadcast_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(param_0->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(broadcast_1->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(multiply_2->shape(), Layout::kDefaultMemorySpace); + TestShapeHasMemorySpace(copy->shape(), kHostMemorySpaceColor); + + EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get())); +} + } // namespace } // namespace xla From 56e98e590e7ebf6d5b470328d6cba39eff98b1cf Mon Sep 17 00:00:00 2001 From: Siqiao Wu Date: Thu, 28 Mar 2024 18:04:29 -0700 Subject: [PATCH 065/124] Remove the duplicate device assignment in ifrt_serving_executable. PiperOrigin-RevId: 620111648 --- .../mlir/tfrt/transforms/ifrt/tf2hlo.cc | 1 + .../core/tfrt/ifrt/ifrt_serving_executable.cc | 35 +++++-------------- 2 files changed, 10 insertions(+), 26 deletions(-) diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc index 65b5078495ea04..a0b01ba1ffc3f7 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc @@ -128,6 +128,7 @@ absl::StatusOr GetCompileMetadata( // Create a default device assignment if one is not given by the model. if (!metadata.has_device_assignment()) { + // TODO(b/316068010): integrate core selection. TF_ASSIGN_OR_RETURN( auto device_assignment, ifrt_client.GetDefaultDeviceAssignment( diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc index 868f49ad070ac4..bb9176c3b3de6f 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc @@ -87,37 +87,21 @@ absl::StatusOr> BuildDtypeAndShape( } absl::StatusOr GetXlaDeviceAssignment( - const xla::ifrt::Client& ifrt_client, const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata) { - int num_replicas = compile_metadata.num_replicas(); - int num_partitions = compile_metadata.num_cores_per_replica(); - - VLOG(2) << " Number of replcas is " << num_replicas - << " and num_partitions is " << num_partitions; - - if (num_replicas > 1) { - return absl::UnimplementedError( - absl::StrCat("Only support single replica, but replica number is ", - num_replicas, " and num_partitions is ", num_partitions)); - } - - if (compile_metadata.has_device_assignment()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr da, - xla::DeviceAssignment::Deserialize( - compile_metadata.device_assignment())); - - return *std::move(da); - } else { - // TODO(b/316068010): integrate core selection. - return ifrt_client.GetDefaultDeviceAssignment(num_replicas, num_partitions); + if (!compile_metadata.has_device_assignment()) { + return absl::InternalError("No device assignment found."); } + TF_ASSIGN_OR_RETURN( + std::unique_ptr da, + xla::DeviceAssignment::Deserialize(compile_metadata.device_assignment())); + return *da; } absl::StatusOr> GetAssignedDevices( const xla::ifrt::Client& ifrt_client, const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata) { TF_ASSIGN_OR_RETURN(auto device_assignment, - GetXlaDeviceAssignment(ifrt_client, compile_metadata)); + GetXlaDeviceAssignment(compile_metadata)); const int num_devices = device_assignment.replica_count() * device_assignment.computation_count(); @@ -173,9 +157,8 @@ IfrtServingExecutable::CreateExecutableSynchronously( num_replicas, " and num_partitions is ", num_partitions)); } - TF_ASSIGN_OR_RETURN( - xla::DeviceAssignment da, - GetXlaDeviceAssignment(*ifrt_client_, tf2hlo_result.compile_metadata)); + TF_ASSIGN_OR_RETURN(xla::DeviceAssignment da, + GetXlaDeviceAssignment(tf2hlo_result.compile_metadata)); VLOG(2) << "Device assignment :" << da.ToString(); From 41bfc5d180756de01ed1d181428c65cc20468102 Mon Sep 17 00:00:00 2001 From: Jake Harmon Date: Thu, 28 Mar 2024 18:06:04 -0700 Subject: [PATCH 066/124] Set release_base for all release platforms PiperOrigin-RevId: 620111958 --- .bazelrc | 12 ++++++++++-- third_party/xla/.bazelrc | 12 ++++++++++-- third_party/xla/third_party/tsl/.bazelrc | 12 ++++++++++-- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/.bazelrc b/.bazelrc index c630c1350bbd77..d8990ac5c12cc5 100644 --- a/.bazelrc +++ b/.bazelrc @@ -597,8 +597,12 @@ try-import %workspace%/.bazelrc.user # Build TensorFlow v2. test:release_base --test_size_filters=small,medium +# Ensure release_base is set on linux +build:release_linux_base --config=release_base + # Target the AVX instruction set build:release_linux_base --config=avx_linux + # Enable support for all targets build:release_base --config=cpu_cross @@ -679,12 +683,14 @@ build:unsupported_gpu_linux --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gc build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain build:release_cpu_macos --config=avx_linux -test:release_cpu_macos --config=release_base # Base build configs for macOS build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer build:release_macos_base --define=no_nccl_support=true --output_filter=^$ +# Ensure release_base is set on mac +build:release_macos_base --config=release_base + # Build configs for macOS x86 build:release_macos_x86 --config=release_macos_base # Build with the AVX instruction set when on macOS x86 @@ -714,10 +720,12 @@ test:release_macos_x86 --config=release_macos_base # Test configs for macOS Arm64 test:release_macos_arm64 --config=release_macos_base +# Ensure release_base is set on windows +build:release_cpu_windows --config=release_base + # TODO(kanglan): Update windows configs after b/289091160 is fixed build:release_cpu_windows --config=avx_win build:release_cpu_windows --define=no_tensorflow_py_deps=true -test:release_cpu_windows --config=release_base # Exclude TFRT integration for anything but Linux. build:android --config=no_tfrt diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index c630c1350bbd77..d8990ac5c12cc5 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -597,8 +597,12 @@ try-import %workspace%/.bazelrc.user # Build TensorFlow v2. test:release_base --test_size_filters=small,medium +# Ensure release_base is set on linux +build:release_linux_base --config=release_base + # Target the AVX instruction set build:release_linux_base --config=avx_linux + # Enable support for all targets build:release_base --config=cpu_cross @@ -679,12 +683,14 @@ build:unsupported_gpu_linux --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gc build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain build:release_cpu_macos --config=avx_linux -test:release_cpu_macos --config=release_base # Base build configs for macOS build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer build:release_macos_base --define=no_nccl_support=true --output_filter=^$ +# Ensure release_base is set on mac +build:release_macos_base --config=release_base + # Build configs for macOS x86 build:release_macos_x86 --config=release_macos_base # Build with the AVX instruction set when on macOS x86 @@ -714,10 +720,12 @@ test:release_macos_x86 --config=release_macos_base # Test configs for macOS Arm64 test:release_macos_arm64 --config=release_macos_base +# Ensure release_base is set on windows +build:release_cpu_windows --config=release_base + # TODO(kanglan): Update windows configs after b/289091160 is fixed build:release_cpu_windows --config=avx_win build:release_cpu_windows --define=no_tensorflow_py_deps=true -test:release_cpu_windows --config=release_base # Exclude TFRT integration for anything but Linux. build:android --config=no_tfrt diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index c630c1350bbd77..d8990ac5c12cc5 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -597,8 +597,12 @@ try-import %workspace%/.bazelrc.user # Build TensorFlow v2. test:release_base --test_size_filters=small,medium +# Ensure release_base is set on linux +build:release_linux_base --config=release_base + # Target the AVX instruction set build:release_linux_base --config=avx_linux + # Enable support for all targets build:release_base --config=cpu_cross @@ -679,12 +683,14 @@ build:unsupported_gpu_linux --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gc build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain build:release_cpu_macos --config=avx_linux -test:release_cpu_macos --config=release_base # Base build configs for macOS build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer build:release_macos_base --define=no_nccl_support=true --output_filter=^$ +# Ensure release_base is set on mac +build:release_macos_base --config=release_base + # Build configs for macOS x86 build:release_macos_x86 --config=release_macos_base # Build with the AVX instruction set when on macOS x86 @@ -714,10 +720,12 @@ test:release_macos_x86 --config=release_macos_base # Test configs for macOS Arm64 test:release_macos_arm64 --config=release_macos_base +# Ensure release_base is set on windows +build:release_cpu_windows --config=release_base + # TODO(kanglan): Update windows configs after b/289091160 is fixed build:release_cpu_windows --config=avx_win build:release_cpu_windows --define=no_tensorflow_py_deps=true -test:release_cpu_windows --config=release_base # Exclude TFRT integration for anything but Linux. build:android --config=no_tfrt From a6bbcb71721040569f8aa737aabc78d5df5050de Mon Sep 17 00:00:00 2001 From: Matthias Kramm Date: Thu, 28 Mar 2024 18:11:34 -0700 Subject: [PATCH 067/124] Add missing 'END' to definition of XlaSplitND. PiperOrigin-RevId: 620112924 --- tensorflow/core/api_def/base_api/api_def_XlaSplitND.pbtxt | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/core/api_def/base_api/api_def_XlaSplitND.pbtxt b/tensorflow/core/api_def/base_api/api_def_XlaSplitND.pbtxt index c7deacb4fec21d..a31cfa5c4d85ec 100644 --- a/tensorflow/core/api_def/base_api/api_def_XlaSplitND.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_XlaSplitND.pbtxt @@ -5,6 +5,7 @@ op { name: "input" description: < Date: Thu, 28 Mar 2024 18:50:52 -0700 Subject: [PATCH 068/124] Enable weight-only quantization of stablehlo.convolution PiperOrigin-RevId: 620121421 --- .../quantization/stablehlo/passes/passes.td | 4 +- .../stablehlo/passes/quantization_patterns.cc | 111 ++++++++---------- .../quantization/stablehlo/passes/quantize.cc | 14 ++- .../integration_test/quantize_model_test.py | 93 ++++++++++++++- .../passes/quantize/quantize_weight_only.mlir | 36 +++++- ...ntize_composite_functions_weight_only.mlir | 30 +++++ 6 files changed, 217 insertions(+), 71 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td index 80847e8283652e..63f6f822dbebdf 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -70,7 +70,7 @@ def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-function Option<"enable_weight_only_", "enable-weight-only", "bool", /*default=*/"false", - "Whether to produce weight-only quantized op for dot_general op.">, + "Whether to produce weight-only quantized op for convolution and dot_general op.">, ]; let dependentDialects = [ "mlir::arith::ArithDialect", @@ -113,7 +113,7 @@ def QuantizePass : Pass<"stablehlo-quantize", "mlir::ModuleOp"> { Option<"enable_weight_only_", "enable-weight-only", "bool", /*default=*/"false", - "Whether to produce weight-only quantized op for dot_general op.">, + "Whether to produce weight-only quantized op for convolution and dot_general op.">, ]; let dependentDialects = [ "mlir::stablehlo::StablehloDialect", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc index f2b78caeb3cb44..10b15f1132fe62 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc @@ -410,9 +410,11 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter, class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { public: explicit QuantizeDotGeneralOpPattern( - const bool enable_per_channel_quantized_weight) + const bool enable_per_channel_quantized_weight, + const bool enable_weight_only) : enable_per_channel_quantized_weight_( - enable_per_channel_quantized_weight) {} + enable_per_channel_quantized_weight), + enable_weight_only_(enable_weight_only) {} LogicalResult match(func::FuncOp entry_func_op) const override { return MatchGemmStyleOp(entry_func_op); @@ -420,6 +422,7 @@ class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, PatternRewriter& rewriter) const override { + if (enable_weight_only_) return; DotGeneralOp dot_general_op = *entry_func_op.getOps().begin(); const bool should_quantize_per_channel = enable_per_channel_quantized_weight_ && @@ -432,15 +435,20 @@ class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { [[deprecated( "Do not rely on this field for per-channel quantization. Use `Method` " "instead.")]] const bool enable_per_channel_quantized_weight_; + // TODO: b/331510853 - Deprecate boolean flag and use `Method` to perform + // weight-only quantization. + const bool enable_weight_only_; }; // Quantizes the entry function's body containing a `ConvolutionOp`. class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { public: explicit QuantizeConvolutionOpPattern( - const bool enable_per_channel_quantized_weight) + const bool enable_per_channel_quantized_weight, + const bool enable_weight_only) : enable_per_channel_quantized_weight_( - enable_per_channel_quantized_weight) {} + enable_per_channel_quantized_weight), + enable_weight_only_(enable_weight_only) {} LogicalResult match(func::FuncOp entry_func_op) const override { return MatchGemmStyleOp(entry_func_op); @@ -448,6 +456,7 @@ class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, PatternRewriter& rewriter) const override { + if (enable_weight_only_) return; RewriteGemmStyleOp( entry_func_op, rewriter, enable_per_channel_quantized_weight_ && @@ -475,13 +484,17 @@ class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { [[deprecated( "Do not rely on this field for per-channel quantization. Use `Method` " "instead.")]] const bool enable_per_channel_quantized_weight_; + // TODO: b/331510853 - Deprecate boolean flag and use `Method` to perform + // weight-only quantization. + const bool enable_weight_only_; }; template class QuantizeSingularOpPattern : public EntryFuncBodyQuantizationPattern { public: explicit QuantizeSingularOpPattern( - const bool enable_per_channel_quantized_weight) {} + const bool enable_per_channel_quantized_weight, + const bool enable_weight_only) {} LogicalResult match(func::FuncOp entry_func_op) const override { const auto op_iterator_range = entry_func_op.getOps(); @@ -569,10 +582,12 @@ template { public: explicit XlaCallModuleOpToCallOp( - MLIRContext& ctx, const bool enable_per_channel_quantized_weight) + MLIRContext& ctx, const bool enable_per_channel_quantized_weight, + const bool enable_weight_only) : OpRewritePattern(&ctx), enable_per_channel_quantized_weight_( - enable_per_channel_quantized_weight) {} + enable_per_channel_quantized_weight), + enable_weight_only_(enable_weight_only) {} LogicalResult match(TF::XlaCallModuleOp op) const override { ModuleOp module_op = op->getParentOfType(); @@ -581,13 +596,19 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern { // Ignore unquantized ops. if (!IsQuantizedXlaCallModuleOp(op)) return failure(); + // For weight-only quantization, op should be hybrid quantized. + if (enable_weight_only_ && !IsHybridQuantizedOp(op)) { + return failure(); + } + func::FuncOp entry_func_op = GetEntryFuncOp(op, symbol_table); if (!entry_func_op) { op->emitError("Failed to find a valid entry function."); return failure(); } - return FuncBodyRewritePatternT(enable_per_channel_quantized_weight_) + return FuncBodyRewritePatternT(enable_per_channel_quantized_weight_, + enable_weight_only_) .match(entry_func_op); } @@ -601,7 +622,8 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern { ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( *rewriter.getContext(), rewriter, xla_call_module_op, - FuncBodyRewritePatternT(enable_per_channel_quantized_weight_), + FuncBodyRewritePatternT(enable_per_channel_quantized_weight_, + enable_weight_only_), quantization_method); } @@ -609,6 +631,9 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern { [[deprecated( "Do not rely on this field for per-channel quantization. Use `Method` " "instead.")]] const bool enable_per_channel_quantized_weight_; + // TODO: b/331510853 - Deprecate boolean flag and use `Method` to perform + // weight-only quantization. + const bool enable_weight_only_; }; // Quantizes op with regions such as stablehlo.reduce_window op. @@ -883,72 +908,32 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) { return false; } -class QuantizeWeightOnlyDotGeneralPattern - : public EntryFuncBodyQuantizationPattern { +template +class QuantizeWeightOnlyOpPattern : public EntryFuncBodyQuantizationPattern { public: - explicit QuantizeWeightOnlyDotGeneralPattern() = default; + explicit QuantizeWeightOnlyOpPattern( + const bool enable_per_channel_quantized_weight) {} LogicalResult match(func::FuncOp entry_func_op) const override { - return MatchGemmStyleOp(entry_func_op); + return MatchGemmStyleOp(entry_func_op); } void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, PatternRewriter& rewriter) const override {} }; -template >> -class WeightOnlyXlaCallModuleOpToCallOp - : public OpRewritePattern { - public: - explicit WeightOnlyXlaCallModuleOpToCallOp( - MLIRContext& ctx, const bool enable_per_channel_quantized_weight) - : OpRewritePattern(&ctx) {}; - - LogicalResult match(TF::XlaCallModuleOp op) const override { - ModuleOp module_op = op->getParentOfType(); - SymbolTable symbol_table(module_op); - - // Ignore unquantized ops. - if (!IsHybridQuantizedOp(op) || !IsOpQuantizableStableHlo(op)) { - return failure(); - } - - func::FuncOp entry_func_op = GetEntryFuncOp(op, symbol_table); - if (!entry_func_op) { - op->emitError("Failed to find a valid entry function."); - return failure(); - } - - return FuncBodyRewritePatternT().match(entry_func_op); - } - - void rewrite(TF::XlaCallModuleOp xla_call_module_op, - PatternRewriter& rewriter) const override { - // TODO: b/331145946 - Each quantization method should be valid - // (GetQuantizationMethodOrDefault swallows invalid method attribute). Check - // the validity in `match()`. Use accessors to achieve this. - const Method quantization_method = - GetQuantizationMethodOrDefault(xla_call_module_op); - - ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( - *rewriter.getContext(), rewriter, xla_call_module_op, - FuncBodyRewritePatternT(), quantization_method); - } -}; - // Compute heavy patterns should be quantized for both server and ODML targets. void PopulateComputeHeavyPatterns( MLIRContext& ctx, RewritePatternSet& patterns, const bool enable_per_channel_quantized_weight) { patterns.add>( - ctx, enable_per_channel_quantized_weight); + ctx, enable_per_channel_quantized_weight, /*enable_weight_only=*/false); patterns.add>( - ctx, enable_per_channel_quantized_weight); + ctx, enable_per_channel_quantized_weight, /*enable_weight_only=*/false); // TODO: b/307620772 - Per-channel quantization for gather. patterns.add>>( - ctx, /*enable_per_channel_quantized_weight=*/false); + ctx, /*enable_per_channel_quantized_weight=*/false, + /*enable_weight_only=*/false); // Populate pattern for quantization of ops with regions such as // `stablehlo.reduce_window` op. patterns.add(ctx); @@ -957,14 +942,16 @@ void PopulateComputeHeavyPatterns( void PopulateAllQuantizablePatterns(MLIRContext& ctx, RewritePatternSet& patterns) { patterns.add>>( - ctx, /*enable_per_channel_quantized_weight=*/false); + ctx, /*enable_per_channel_quantized_weight=*/false, + /*enable_weight_only=*/false); } void PopulateQuantizeWeightOnlyPatterns(MLIRContext& ctx, RewritePatternSet& patterns) { - patterns.add< - WeightOnlyXlaCallModuleOpToCallOp>( - ctx, /*enable_per_channel_quantized_weight=*/false); + patterns.add, + XlaCallModuleOpToCallOp>( + ctx, /*enable_per_channel_quantized_weight*/ false, + /*enable_weight_only=*/true); } } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc index 4d6f4b3fe86832..8bb2bd33564481 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc @@ -77,6 +77,14 @@ struct StableHloQuantizationReverse quantfork::QuantizeCastOp>(ctx) {} }; +bool IsHybridQuantizableOp(Operation& op) { + auto call_op = cast(op); + if (call_op == nullptr) return false; + StringRef entry_function_name = GetEntryFunctionName(call_op); + return entry_function_name.contains("conv") || + entry_function_name.contains("dot_general"); +} + // Quantization rewrite pattern using DQ as the root op. struct StableHloQuantizationWeightOnly : public StableHloQuantizationBase { @@ -84,8 +92,7 @@ struct StableHloQuantizationWeightOnly : StableHloQuantizationBase(ctx) {} static bool AllowWeightOnlyQuantization(Operation& op) { - auto call_op = cast(op); - return call_op && GetEntryFunctionName(call_op).contains("dot_general"); + return IsHybridQuantizableOp(op); } }; @@ -97,8 +104,7 @@ class QuantizePass : public impl::QuantizePassBase { explicit QuantizePass(const bool enable_per_channel_quantized_weight, const bool enable_full_int_quantization, - const bool enable_weight_only, - const QuantizationSpecs& quant_specs) { + const bool enable_weight_only) { enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; enable_full_int_quantization_ = enable_full_int_quantization; enable_weight_only_ = enable_weight_only; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py index a76ca4e75ac764..80a2c560ef865b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/quantize_model_test.py @@ -1014,12 +1014,101 @@ def test_matmul_weight_only_model( # dequantization. self.assertTrue(re.search('stablehlo.subtract', module_str)) self.assertTrue(re.search('stablehlo.multiply', module_str)) + # Tests that the output graph contains float dot_general. self.assertTrue( re.search('stablehlo.dot_general.*xf32>.*xf32>.*xf32>', module_str) ) + + # Due to other meta data, the compression is not exactly 1/4. + self.assertLess( + testing.get_size_ratio( + self._output_saved_model_path, self._input_saved_model_path + ), + 0.3, + ) + + @parameterized.parameters( + testing.parameter_combinations([{ + 'bias_fn': ( + None, + nn_ops.bias_add, + ), + 'activation_fn': ( + None, + nn_ops.relu, + nn_ops.relu6, + ), + 'has_batch_norm': (False,), + 'input_shape_dynamic': ( + False, + True, + ), + }]) + ) + @test_util.run_in_graph_and_eager_modes + def test_conv_weight_only_model( + self, + bias_fn: Optional[ops.Operation], + activation_fn: Optional[ops.Operation], + has_batch_norm: bool, + input_shape_dynamic: bool, + dilations: Sequence[int] = None, + ): + input_shape = (None, 3, 4, 3) if input_shape_dynamic else (1, 3, 4, 3) + filter_shape = (2, 3, 3, 2) + strides = (1, 1, 1, 1) + model = self._create_conv2d_model( + input_shape, + filter_shape, + self._input_saved_model_path, + bias_fn, + activation_fn, + has_batch_norm, + strides, + dilations, + ) + + rng = np.random.default_rng(1234) + static_input_shape = [dim if dim is not None else 2 for dim in input_shape] + input_data = ops.convert_to_tensor( + rng.uniform(low=0.0, high=1.0, size=static_input_shape).astype( + np.float32 + ) + ) + + config = qc.QuantizationConfig( + weight_only_preset=qc.WeightOnlyPreset(), + tf_saved_model=qc.TfSavedModelConfig(tags=[tag_constants.SERVING]), + ) + quantization.quantize_saved_model( + self._input_saved_model_path, + self._output_saved_model_path, + config, + ) + + expected_outputs = model.conv2d(input_data) + + root = load.load(self._output_saved_model_path) + self.assertCountEqual(root.signatures.keys(), {'serving_default'}) + + new_outputs = root.signatures['serving_default']( + input_tensor=ops.convert_to_tensor(input_data) + ) + # Tests that the quantized graph outputs similar values. The rtol and atol + # values are arbitrary. + self.assertAllClose(new_outputs, expected_outputs, rtol=0.03, atol=0.2) + + module_str = self._extract_first_xla_call_module_op( + self._output_saved_model_path + ) + + # Tests that the output graph contains subtract and multiply for + # dequantization. + self.assertTrue(re.search('stablehlo.subtract', module_str)) + self.assertTrue(re.search('stablehlo.multiply', module_str)) # Tests that the output graph contains float dot_general. self.assertTrue( - re.search('stablehlo.dot_general.*xf32>.*xf32>.*xf32>', module_str) + re.search('stablehlo.convolution.*xf32>.*xf32>.*xf32>', module_str) ) # Due to other meta data, the compression is not exactly 1/4. @@ -1027,7 +1116,7 @@ def test_matmul_weight_only_model( testing.get_size_ratio( self._output_saved_model_path, self._input_saved_model_path ), - 0.3, + 0.35, ) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir index f9a6aaea3a500f..6db474de676ccc 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/quantize_weight_only.mlir @@ -1,6 +1,7 @@ // RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-quantize=enable-weight-only=true | FileCheck %s -// Test that hybrid quantized op is produced when q/dq pair only exists for weight. +// Test that hybrid quantized dot_general is produced when q/dq pair only exists +// for weight. module attributes {tf_saved_model.semantics} { func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { @@ -29,3 +30,36 @@ module attributes {tf_saved_model.semantics} { // CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG1]], %[[ARG2]] // CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> // CHECK: return %[[DOT]] + +// ----- + +// Test that hybrid quantized convolution is produced when q/dq pair only exists +// for weight. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> + %0 = "quantfork.qcast"(%cst) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> + %1 = "quantfork.dcast"(%0) : (tensor<2x3x3x2x!quant.uniform>) -> tensor<2x3x3x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +} + +// CHECK-LABEL: quantize_conv_fn +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x4x3xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> +// CHECK: %[[Q:.+]] = "quantfork.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[Q]]) : (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_conv_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x3x4x3xf32>, %[[ARG2:.+]]: tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG1]], %[[ARG2]]) +// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CONV]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir index c14ff0e36340b3..dce15fe07760e2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions_weight_only.mlir @@ -28,3 +28,33 @@ module attributes {tf_saved_model.semantics} { // CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG1]], %[[ARG2]] // CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> // CHECK: return %[[DOT]] + +// ----- + +// Test that hybrid quantized convolution op is produced when enable-weight-only +// is set to true. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %1 : tensor<1x3x4x2xf32> + } + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +} + +// CHECK-LABEL: quantize_conv_fn +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x4x3xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant() {value = dense<127> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[CST]]) : (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_conv_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x3x4x3xf32>, %[[ARG2:.+]]: tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG1]], %[[ARG2]]) +// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CONV]] From e84bf191f278a6a52e9acca0c0deda4037bd3867 Mon Sep 17 00:00:00 2001 From: Marcello Maggioni Date: Thu, 28 Mar 2024 18:51:38 -0700 Subject: [PATCH 069/124] [XLA] Respect min_rank for reduce scatter version of MatchReduceScatter. We need to honor it. PiperOrigin-RevId: 620121620 --- .../xla/xla/service/collective_opt_utils.cc | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/service/collective_opt_utils.cc b/third_party/xla/xla/service/collective_opt_utils.cc index 183173801077c5..cbc7a4c8867bd4 100644 --- a/third_party/xla/xla/service/collective_opt_utils.cc +++ b/third_party/xla/xla/service/collective_opt_utils.cc @@ -267,12 +267,13 @@ bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, return true; } -ReduceScatterSpec SpecFromReduceScatterInstr(const HloInstruction* rs_instr, - int64_t num_partitions, - int64_t num_replicas, - bool is_constrain_layout, - bool use_global_device_ids, - bool is_cross_module) { +std::optional SpecFromReduceScatterInstr( + const HloInstruction* rs_instr, int64_t num_partitions, + int64_t num_replicas, int64_t min_rank, bool is_constrain_layout, + bool use_global_device_ids, bool is_cross_module) { + if (rs_instr->shape().rank() < min_rank) { + return std::nullopt; + } CHECK(rs_instr->opcode() == HloOpcode::kReduceScatter); ReduceScatterSpec spec; spec.split_dim = rs_instr->dimensions(0); @@ -303,7 +304,7 @@ std::optional MatchReduceScatter( HloPredicate match_partition_id, HloPredicate match_replica_id) { if (ar->opcode() == HloOpcode::kReduceScatter) { return SpecFromReduceScatterInstr( - ar, num_partitions, num_replicas, ar->constrain_layout(), + ar, num_partitions, num_replicas, min_rank, ar->constrain_layout(), ar->use_global_device_ids(), ar->channel_id().has_value()); } auto spec = MatchWithDynamicSlice( From 0a9c74775046f815cbd30563cb5bed847774e16f Mon Sep 17 00:00:00 2001 From: Dmitri Gribenko Date: Thu, 28 Mar 2024 19:04:05 -0700 Subject: [PATCH 070/124] Integrate LLVM at llvm/llvm-project@aa2c14de1adc Updates LLVM usage to match [aa2c14de1adc](https://github.com/llvm/llvm-project/commit/aa2c14de1adc) PiperOrigin-RevId: 620124069 --- .../compiler/mlir/lite/experimental/tac/BUILD | 1 + .../compiler/mlir/lite/quantization/ir/BUILD | 1 + .../mlir/quantization/common/ir/BUILD | 1 + tensorflow/compiler/mlir/tensorflow/BUILD | 4 + .../mlir/tensorflow/ir/host_runtime/BUILD | 1 + tensorflow/compiler/mlir/tfrt/ir/BUILD | 2 + .../compiler/mlir/tools/kernel_gen/ir/BUILD | 1 + third_party/llvm/generated.patch | 206 ++++++++++++++++++ third_party/llvm/workspace.bzl | 4 +- third_party/xla/xla/mlir_hlo/BUILD | 2 + third_party/xla/xla/pjrt/BUILD | 1 + third_party/xla/xla/python/BUILD | 1 + .../xla/xla/service/gpu/fusions/mlir/ir/BUILD | 1 + .../xla/xla/service/gpu/gpu_compiler.cc | 4 +- 14 files changed, 226 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD index 1c5a0703d0a58a..248a55c7fe17e1 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD @@ -82,6 +82,7 @@ cc_library( "//tensorflow/compiler/mlir/lite:tensorflow_lite", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD index ffac0779313307..a6d6c61444548e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD @@ -92,6 +92,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", diff --git a/tensorflow/compiler/mlir/quantization/common/ir/BUILD b/tensorflow/compiler/mlir/quantization/common/ir/BUILD index c1429a27368d51..615f54f70d2373 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/ir/BUILD @@ -70,6 +70,7 @@ cc_library( deps = [ ":QuantOpsIncGen", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:QuantOps", diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 79203111b03ea4..26d5e4d52b41d7 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -356,6 +356,7 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:DerivedAttributeOpInterface", @@ -404,6 +405,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:DerivedAttributeOpInterface", @@ -453,6 +455,7 @@ cc_library( "//tensorflow/core/common_runtime:lower_function_call_inline_policy", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:DerivedAttributeOpInterface", @@ -556,6 +559,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:CallOpInterfacesIncGen", "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:ControlFlowInterfaces", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD index ddef04d4185e1d..ccf7b0b547ab90 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD @@ -73,6 +73,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core/framework:resource_handle", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:Dialect", diff --git a/tensorflow/compiler/mlir/tfrt/ir/BUILD b/tensorflow/compiler/mlir/tfrt/ir/BUILD index 88baa91f6de0d0..68e9624e118453 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/BUILD @@ -51,6 +51,7 @@ cc_library( ":tfrt_fallback_common", ":tfrt_fallback_opdefs", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:IR", "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:SideEffectInterfaces", @@ -256,6 +257,7 @@ cc_library( ":tfrt_fallback_opdefs", ":tfrt_gpu_opdefs_inc_gen", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:IR", ], ) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD index 29335382ec41a7..42d679c35d0173 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD @@ -86,6 +86,7 @@ cc_library( "@com_google_absl//absl/status", "@llvm-project//mlir:AllocationOpInterface", "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e83..229971a2e9ad47 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,207 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/clang/lib/APINotes/APINotesWriter.cpp b/clang/lib/APINotes/APINotesWriter.cpp +--- a/clang/lib/APINotes/APINotesWriter.cpp ++++ b/clang/lib/APINotes/APINotesWriter.cpp +@@ -441,7 +441,7 @@ + std::sort(VI.begin(), VI.end(), + [](const std::pair &LHS, + const std::pair &RHS) -> bool { +- assert(LHS.first != RHS.first && ++ assert((&LHS == &RHS || LHS.first != RHS.first) && + "two entries for the same version"); + return LHS.first < RHS.first; + }); +diff -ruN --strip-trailing-cr a/clang/test/APINotes/module-cache.m b/clang/test/APINotes/module-cache.m +--- a/clang/test/APINotes/module-cache.m ++++ b/clang/test/APINotes/module-cache.m +@@ -27,6 +27,7 @@ + // RUN: FileCheck -check-prefix=CHECK-ONE-ERROR %s < %t/before.log + + // Change the API notes file, after the module has rebuilt once. ++// RUN: chmod u+w %t/APINotes/SomeOtherKit.apinotes + // RUN: echo ' - Selector: "methodA"' >> %t/APINotes/SomeOtherKit.apinotes + // RUN: echo ' MethodKind: Instance' >> %t/APINotes/SomeOtherKit.apinotes + // RUN: echo ' Availability: none' >> %t/APINotes/SomeOtherKit.apinotes +diff -ruN --strip-trailing-cr a/lld/test/ELF/lto/libcall-archive.ll b/lld/test/ELF/lto/libcall-archive.ll +--- a/lld/test/ELF/lto/libcall-archive.ll ++++ b/lld/test/ELF/lto/libcall-archive.ll +@@ -4,8 +4,8 @@ + ; RUN: llvm-as -o %t2.o %S/Inputs/libcall-archive.ll + ; RUN: llvm-mc -filetype=obj -triple=x86_64-unknown-linux -o %t3.o %S/Inputs/libcall-archive.s + ; RUN: llvm-ar rcs %t.a %t2.o %t3.o +-; RUN: ld.lld --why-extract=why.txt -o %t %t.o %t.a +-; RUN: FileCheck %s --input-file=why.txt --check-prefix=CHECK-WHY ++; RUN: ld.lld --why-extract=%t.why.txt -o %t %t.o %t.a ++; RUN: FileCheck %s --input-file=%t.why.txt --check-prefix=CHECK-WHY + ; RUN: llvm-nm %t | FileCheck %s + ; RUN: ld.lld -o %t2 %t.o --start-lib %t2.o %t3.o --end-lib + ; RUN: llvm-nm %t2 | FileCheck %s +diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/Verifier.h b/llvm/include/llvm/IR/Verifier.h +--- a/llvm/include/llvm/IR/Verifier.h ++++ b/llvm/include/llvm/IR/Verifier.h +@@ -77,7 +77,6 @@ + /// Visit an instruction and return true if it is valid, return false if an + /// invalid TBAA is attached. + bool visitTBAAMetadata(Instruction &I, const MDNode *MD); +- bool visitTBAAStructMetadata(Instruction &I, const MDNode *MD); + }; + + /// Check a function for errors, useful for use when debugging a +diff -ruN --strip-trailing-cr a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp +--- a/llvm/lib/IR/Verifier.cpp ++++ b/llvm/lib/IR/Verifier.cpp +@@ -5096,9 +5096,6 @@ + if (MDNode *TBAA = I.getMetadata(LLVMContext::MD_tbaa)) + TBAAVerifyHelper.visitTBAAMetadata(I, TBAA); + +- if (MDNode *TBAA = I.getMetadata(LLVMContext::MD_tbaa_struct)) +- TBAAVerifyHelper.visitTBAAStructMetadata(I, TBAA); +- + if (MDNode *MD = I.getMetadata(LLVMContext::MD_noalias)) + visitAliasScopeListMetadata(MD); + if (MDNode *MD = I.getMetadata(LLVMContext::MD_alias_scope)) +@@ -7422,35 +7419,6 @@ + return true; + } + +-bool TBAAVerifier::visitTBAAStructMetadata(Instruction &I, const MDNode *MD) { +- CheckTBAA(MD->getNumOperands() % 3 == 0, +- "tbaa.struct operands must occur in groups of three", &I, MD); +- +- // Each group of three operands must consist of two integers and a +- // tbaa node. Moreover, the regions described by the offset and size +- // operands must be non-overlapping. +- std::optional NextFree; +- for (unsigned int Idx = 0; Idx < MD->getNumOperands(); Idx += 3) { +- auto *OffsetCI = +- mdconst::dyn_extract_or_null(MD->getOperand(Idx)); +- CheckTBAA(OffsetCI, "Offset must be a constant integer", &I, MD); +- +- auto *SizeCI = +- mdconst::dyn_extract_or_null(MD->getOperand(Idx + 1)); +- CheckTBAA(SizeCI, "Size must be a constant integer", &I, MD); +- +- MDNode *TBAA = dyn_cast_or_null(MD->getOperand(Idx + 2)); +- CheckTBAA(TBAA, "TBAA tag missing", &I, MD); +- visitTBAAMetadata(I, TBAA); +- +- bool NonOverlapping = !NextFree || NextFree->ule(OffsetCI->getValue()); +- CheckTBAA(NonOverlapping, "Overlapping tbaa.struct regions", &I, MD); +- +- NextFree = OffsetCI->getValue() + SizeCI->getValue(); +- } +- return true; +-} +- + char VerifierLegacyPass::ID = 0; + INITIALIZE_PASS(VerifierLegacyPass, "verify", "Module Verifier", false, false) + +diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AArch64/arm64-abi_align.ll b/llvm/test/CodeGen/AArch64/arm64-abi_align.ll +--- a/llvm/test/CodeGen/AArch64/arm64-abi_align.ll ++++ b/llvm/test/CodeGen/AArch64/arm64-abi_align.ll +@@ -518,6 +518,4 @@ + !1 = !{!"omnipotent char", !2} + !2 = !{!"Simple C/C++ TBAA"} + !3 = !{!"short", !1} +-!4 = !{i64 0, i64 4, !5, i64 4, i64 2, !6, i64 8, i64 4, !5, i64 12, i64 2, !6, i64 16, i64 4, !5, i64 20, i64 2, !6} +-!5 = !{!0, !0, i64 0} +-!6 = !{!3, !3, i64 0} ++!4 = !{i64 0, i64 4, !0, i64 4, i64 2, !3, i64 8, i64 4, !0, i64 12, i64 2, !3, i64 16, i64 4, !0, i64 20, i64 2, !3} +diff -ruN --strip-trailing-cr a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/mem-intrinsics.ll b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/mem-intrinsics.ll +--- a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/mem-intrinsics.ll ++++ b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/mem-intrinsics.ll +@@ -141,4 +141,4 @@ + !5 = distinct !{!5, !"some domain"} + !6 = !{!7} + !7 = distinct !{!7, !5, !"some scope 2"} +-!8 = !{i64 0, i64 8, !0} ++!8 = !{i64 0, i64 8, null} +diff -ruN --strip-trailing-cr a/llvm/test/Transforms/InstCombine/struct-assign-tbaa.ll b/llvm/test/Transforms/InstCombine/struct-assign-tbaa.ll +--- a/llvm/test/Transforms/InstCombine/struct-assign-tbaa.ll ++++ b/llvm/test/Transforms/InstCombine/struct-assign-tbaa.ll +@@ -75,7 +75,7 @@ + !1 = !{!"omnipotent char", !0} + !2 = !{!5, !5, i64 0} + !3 = !{i64 0, i64 4, !2} +-!4 = !{i64 0, i64 8, !2} ++!4 = !{i64 0, i64 8, null} + !5 = !{!"float", !0} + !6 = !{i64 0, i64 4, !2, i64 4, i64 4, !2} + !7 = !{i64 0, i64 2, !2, i64 4, i64 6, !2} +diff -ruN --strip-trailing-cr a/llvm/test/Transforms/Scalarizer/basic-inseltpoison.ll b/llvm/test/Transforms/Scalarizer/basic-inseltpoison.ll +--- a/llvm/test/Transforms/Scalarizer/basic-inseltpoison.ll ++++ b/llvm/test/Transforms/Scalarizer/basic-inseltpoison.ll +@@ -836,6 +836,5 @@ + !2 = !{ !"set2", !0 } + !3 = !{ !3, !{!"llvm.loop.parallel_accesses", !13} } + !4 = !{ float 4.0 } +-!5 = !{ i64 0, i64 8, !6 } +-!6 = !{ !1, !1, i64 0 } ++!5 = !{ i64 0, i64 8, null } + !13 = distinct !{} +diff -ruN --strip-trailing-cr a/llvm/test/Transforms/Scalarizer/basic.ll b/llvm/test/Transforms/Scalarizer/basic.ll +--- a/llvm/test/Transforms/Scalarizer/basic.ll ++++ b/llvm/test/Transforms/Scalarizer/basic.ll +@@ -870,6 +870,5 @@ + !2 = !{ !"set2", !0 } + !3 = !{ !3, !{!"llvm.loop.parallel_accesses", !13} } + !4 = !{ float 4.0 } +-!5 = !{ i64 0, i64 8, !6 } +-!6 = !{ !1, !1, i64 0 } ++!5 = !{ i64 0, i64 8, null } + !13 = distinct !{} +diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SROA/tbaa-struct3.ll b/llvm/test/Transforms/SROA/tbaa-struct3.ll +--- a/llvm/test/Transforms/SROA/tbaa-struct3.ll ++++ b/llvm/test/Transforms/SROA/tbaa-struct3.ll +@@ -539,7 +539,7 @@ + !6 = !{!5, !5, i64 0} + !7 = !{i64 0, i64 8, !6, i64 8, i64 4, !1} + !8 = !{i64 0, i64 4, !1, i64 4, i64 8, !6} +-!9 = !{i64 0, i64 8, !6, i64 8, i64 8, !1} ++!9 = !{i64 0, i64 8, !6, i64 4, i64 8, !1} + !10 = !{i64 0, i64 2, !1, i64 2, i64 2, !1} + !11 = !{i64 0, i64 1, !1, i64 1, i64 3, !1} + !12 = !{i64 0, i64 2, !1, i64 2, i64 6, !1} +diff -ruN --strip-trailing-cr a/llvm/test/Verifier/tbaa-struct.ll b/llvm/test/Verifier/tbaa-struct.ll +--- a/llvm/test/Verifier/tbaa-struct.ll ++++ b/llvm/test/Verifier/tbaa-struct.ll +@@ -1,36 +1,28 @@ +-; RUN: not llvm-as < %s 2>&1 | FileCheck %s ++; RUN: llvm-as < %s 2>&1 ++ ++; FIXME: The verifer should reject the invalid !tbaa.struct nodes below. + + define void @test_overlapping_regions(ptr %a1) { +-; CHECK: Overlapping tbaa.struct regions +-; CHECK-NEXT: %ld = load i8, ptr %a1, align 1, !tbaa.struct !0 + %ld = load i8, ptr %a1, align 1, !tbaa.struct !0 + ret void + } + + define void @test_size_not_integer(ptr %a1) { +-; CHECK: Size must be a constant integer +-; CHECK-NEXT: store i8 1, ptr %a1, align 1, !tbaa.struct !5 + store i8 1, ptr %a1, align 1, !tbaa.struct !5 + ret void + } + + define void @test_offset_not_integer(ptr %a1, ptr %a2) { +-; CHECK: Offset must be a constant integer +-; CHECK-NEXT: tail call void @llvm.memcpy.p0.p0.i64(ptr align 8 %a1, ptr align 8 %a2, i64 16, i1 false), !tbaa.struct !6 + tail call void @llvm.memcpy.p0.p0.i64(ptr align 8 %a1, ptr align 8 %a2, i64 16, i1 false), !tbaa.struct !6 + ret void + } + + define void @test_tbaa_missing(ptr %a1, ptr %a2) { +-; CHECK: TBAA tag missing +-; CHECK-NEXT: tail call void @llvm.memcpy.p0.p0.i64(ptr align 8 %a1, ptr align 8 %a2, i64 16, i1 false), !tbaa.struct !7 + tail call void @llvm.memcpy.p0.p0.i64(ptr align 8 %a1, ptr align 8 %a2, i64 16, i1 false), !tbaa.struct !7 + ret void + } + + define void @test_tbaa_invalid(ptr %a1) { +-; CHECK: Old-style TBAA is no longer allowed, use struct-path TBAA instead +-; CHECK-NEXT: store i8 1, ptr %a1, align 1, !tbaa.struct !8 + store i8 1, ptr %a1, align 1, !tbaa.struct !8 + ret void + } diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index ad8f149f602ed5..1a7d56b5764590 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "feebcd65fb7e0534f5219e05432a05e45aa8cd2a" - LLVM_SHA256 = "39b2b0c5f5fefb54866a0e9738f1617d79049dbac3b5cdecb7b1f785a57bb669" + LLVM_COMMIT = "aa2c14de1adcd265bf0c0fb44f97b5d6c1c38710" + LLVM_SHA256 = "50d2c7cd5355ec04a75991f2a4e2c89a3876b46fc1b71cd9fa3245f212d55da0" tf_http_archive( name = name, diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD index 0727dd7bc7f0fe..dfed55f77e5fd3 100644 --- a/third_party/xla/xla/mlir_hlo/BUILD +++ b/third_party/xla/xla/mlir_hlo/BUILD @@ -440,6 +440,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:Dialect", @@ -1104,6 +1105,7 @@ cc_library( "@llvm-project//mlir:FuncToLLVM", "@llvm-project//mlir:GPUCommonTransforms", "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToGPURuntimeTransforms", "@llvm-project//mlir:GPUToNVVMTransforms", "@llvm-project//mlir:GPUToROCDLTransforms", "@llvm-project//mlir:GPUTransforms", diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 7d316928c72772..fe75981da57781 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -570,6 +570,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index b77b4eefc7f919..2f022afb5c51f1 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -907,6 +907,7 @@ cc_library( "//xla/mlir/utils:error_util", "@com_google_absl//absl/status", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD index d81257a2a05382..e82df33e965214 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD @@ -58,6 +58,7 @@ cc_library( deps = [ ":xla_gpu_ops_inc_gen", "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index c6f0490777370a..c9c7e3523170c4 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1695,12 +1695,12 @@ absl::StatusOr> GpuCompiler::AssignBuffers( using OutputInfoMap = absl::flat_hash_map; -static void NullDiagnosticHandler(const llvm::DiagnosticInfo& diag_info, +static void NullDiagnosticHandler(const llvm::DiagnosticInfo* diag_info, void* context) { std::string error_string; llvm::raw_string_ostream string_printer(error_string); llvm::DiagnosticPrinterRawOStream diagnostic_printer(string_printer); - diag_info.print(diagnostic_printer); + diag_info->print(diagnostic_printer); VLOG(5) << error_string; } From 8330127f0aa8f9f117ddaa4134c5598ed9ea3e61 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Mar 2024 19:46:52 -0700 Subject: [PATCH 071/124] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 620131121 --- tensorflow/go/op/wrappers.go | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 4f3a7d7af782be..8a825b310ca47c 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -59743,17 +59743,13 @@ func XlaSplitNDPaddings(value []int64) XlaSplitNDAttr { // // Arguments: // -// input: Input tensor to split across all dimensions. -// } -// out_arg { -// name: "outputs" -// description: < Date: Thu, 28 Mar 2024 21:35:01 -0700 Subject: [PATCH 072/124] Automated Code Change PiperOrigin-RevId: 620149815 --- .../xla/third_party/tsl/tsl/profiler/lib/connected_traceme.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/connected_traceme.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/connected_traceme.h index e6e5bfed1493cc..a4b01ae517f650 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/connected_traceme.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/connected_traceme.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PROFILER_LIB_CONNECTED_TRACEME_H_ #define TENSORFLOW_TSL_PROFILER_LIB_CONNECTED_TRACEME_H_ +#include #include #include @@ -79,7 +80,7 @@ class TraceMeProducer : public TraceMe { template explicit TraceMeProducer(NameT&& name, ContextType context_type = ContextType::kGeneric, - absl::optional context_id = absl::nullopt, + std::optional context_id = std::nullopt, int level = 2) : TraceMe(std::forward(name), level), context_id_(context_id.has_value() ? context_id.value() From 903f741b18ef0fcf004d8677f03c893bb66ee2f2 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Thu, 28 Mar 2024 22:26:52 -0700 Subject: [PATCH 073/124] [xla:gpu][NFC] Use absl::Span more consistenly PiperOrigin-RevId: 620156928 --- .../address_computation_fusion_rewriter.cc | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc index 89a47d7cd9cd05..631ce71a4407d1 100644 --- a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc @@ -59,13 +59,20 @@ namespace { // A dataflow path flowing from a definition to a user. using DefUseDataflowPath = absl::InlinedVector; + // All dataflow paths flowing from a definition to all users. Each user will // have a separate entry in the vector. using DefUseDataflowPaths = absl::InlinedVector; + // A dataflow path flowing from a user to a definition. using UseDefDataflowPath = absl::InlinedVector; + // All dataflow paths flowing from a user to all definitions of its operands. using UseDefDataflowPaths = absl::InlinedVector; + +using DataflowPathView = absl::Span; +using DataflowPathsView = absl::Span; + using InstructionSet = absl::flat_hash_set; bool IsNoOp(const HloInstruction* hlo) { @@ -262,7 +269,7 @@ DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr) { } absl::InlinedVector GetPatternCaptures( - absl::Span matches) { + DataflowPathView matches) { absl::InlinedVector captures; InstructionSet matched_instrs(matches.begin(), matches.end()); @@ -280,7 +287,7 @@ absl::InlinedVector GetPatternCaptures( } Status CreateRootTuple(HloInstruction* hero, HloComputation::Builder& builder, - DefUseDataflowPaths sliced_user_paths, + DataflowPathsView sliced_user_paths, absl::flat_hash_map& instr_mapping) { unsigned tuple_size = hero->shape().tuple_shapes_size(); @@ -314,9 +321,8 @@ Status CreateRootTuple(HloInstruction* hero, HloComputation::Builder& builder, } absl::StatusOr CreateFusionBody( - HloModule* module, absl::Span sliced_operand_paths, - DefUseDataflowPaths sliced_user_paths, - absl::Span captures) { + HloModule* module, DataflowPathView sliced_operand_paths, + DataflowPathsView sliced_user_paths, DataflowPathView captures) { HloComputation::Builder builder("address-computation"); // A mapping from original instructions to instructions in the fusion body. @@ -366,9 +372,8 @@ absl::StatusOr CreateFusionBody( } absl::StatusOr CreateFusionInstruction( - HloModule* module, HloInstruction* orig, - absl::Span captures, HloComputation* body, - bool dynamic) { + HloModule* module, HloInstruction* orig, DataflowPathView captures, + HloComputation* body, bool dynamic) { HloComputation* parent = orig->parent(); // Add a fusion operation calling outlined fusion computation. @@ -446,14 +451,21 @@ absl::StatusOr AddressComputationFusionRewriter::Run( std::vector matched_instrs; absl::c_copy(sliced_operand_paths, std::back_inserter(matched_instrs)); - for (auto& sliced_user_path : sliced_user_paths) + std::vector sliced_user_paths_view; + for (auto& sliced_user_path : sliced_user_paths) { absl::c_copy(sliced_user_path, std::back_inserter(matched_instrs)); + DataflowPathView sliced_user_path_view{&sliced_user_path.front(), + sliced_user_path.size()}; + sliced_user_paths_view.push_back(std::move(sliced_user_path_view)); + } auto captures = GetPatternCaptures(matched_instrs); - TF_ASSIGN_OR_RETURN(HloComputation * fusion_body, - CreateFusionBody(module, sliced_operand_paths, - sliced_user_paths, captures)); + TF_ASSIGN_OR_RETURN( + HloComputation * fusion_body, + CreateFusionBody(module, sliced_operand_paths, + DataflowPathsView(sliced_user_paths_view), + captures)); TF_ASSIGN_OR_RETURN(HloInstruction * fusion, CreateFusionInstruction(module, hero, captures, From 3d9b57da50c0e04454f35ef511a1542f68884af7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 28 Mar 2024 22:29:12 -0700 Subject: [PATCH 074/124] [xla][gpu] Extracting triton codegen requirements for hlo instructions This CL extracts current triton codegen requirements for each hlo instruction into a single function to clean the codes in the triton fusion passes. PiperOrigin-RevId: 620157253 --- third_party/xla/xla/service/gpu/BUILD | 42 + .../xla/service/gpu/cublas_pad_for_gemms.cc | 3 +- .../xla/xla/service/gpu/fusions/triton.cc | 1 + .../xla/xla/service/gpu/gemm_fusion.cc | 116 +-- third_party/xla/xla/service/gpu/gemm_fusion.h | 4 - .../xla/xla/service/gpu/ir_emitter_triton.cc | 27 +- .../service/gpu/softmax_rewriter_triton.cc | 67 +- .../gpu/softmax_rewriter_triton_test.cc | 9 +- .../xla/xla/service/gpu/triton_support.cc | 218 ++++ .../xla/xla/service/gpu/triton_support.h | 6 + .../xla/service/gpu/triton_support_test.cc | 940 ++++++++++++++++++ 11 files changed, 1246 insertions(+), 187 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/triton_support_test.cc diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 1b8117cf702373..b8bd5348f53318 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1477,8 +1477,49 @@ cc_library( ":variant_visitor", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service:instruction_fusion", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@local_tsl//tsl/platform:tensor_float_32_utils", + ], +) + +xla_test( + name = "triton_support_test", + srcs = if_cuda_is_configured(["triton_support_test.cc"]), + backends = [ + "gpu_a100", + ], + shard_count = 10, + tags = ["nomac"], + deps = [ + ":gpu_device_info_for_tests", + ":gpu_float_support", + ":ir_emission_utils", + ":ir_emitter_triton", + ":matmul_utils", + ":triton_fusion_analysis", + ":triton_support", + "//xla:error_spec", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:float_normalization", + "//xla/service:hlo_pass_pipeline", + "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", ], ) @@ -3069,6 +3110,7 @@ cc_library( deps = [ ":gemm_fusion", ":ir_emission_utils", + ":triton_support", "//xla:literal_util", "//xla:shape_util", "//xla:statusor", diff --git a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.cc b/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.cc index 42a0b55b8bc1ac..050f219d12b6c8 100644 --- a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.cc +++ b/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/gpu/gemm_fusion.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/triton_support.h" #include "xla/shape.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" @@ -179,7 +180,7 @@ static std::vector GetRelevantDots( ->config() .debug_options() .xla_gpu_enable_triton_gemm() && - CanTritonHandleGEMM(*dot, gpu_compute_capability) && + IsTritonSupportedInstruction(*dot, gpu_compute_capability) && ShouldTritonHandleGEMM(*dot, gpu_compute_capability))) { gemms.push_back(dot); } diff --git a/third_party/xla/xla/service/gpu/fusions/triton.cc b/third_party/xla/xla/service/gpu/fusions/triton.cc index 2b9565d87bbee9..2fc6d15898da6c 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton.cc @@ -153,6 +153,7 @@ absl::StatusOr TritonFusion::Emit( triton_config.set_split_k(1); triton_config.set_num_stages(1); triton_config.set_num_warps(2); + triton_config.set_num_ctas(1); } TF_ASSIGN_OR_RETURN( TritonGemmConfig config, diff --git a/third_party/xla/xla/service/gpu/gemm_fusion.cc b/third_party/xla/xla/service/gpu/gemm_fusion.cc index 7518fa51269533..05e758a73f3d47 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion.cc @@ -618,10 +618,11 @@ absl::StatusOr CreateDotFusion( std::vector& fusion_inputs, HloInstruction** fusion_output_ptr) { VLOG(5) << dot.ToString(); - if (FusionDecision can_handle = CanTritonHandleGEMM(dot, gpu_version); - !can_handle) { - VLOG(3) << can_handle.Explain(); - return can_handle; + if (CodegenDecision is_supported = + IsTritonSupportedInstruction(dot, gpu_version); + !is_supported) { + VLOG(3) << is_supported.Explain(); + return is_supported; } // Verify sparse dot constraints. @@ -785,116 +786,9 @@ absl::StatusOr RunOnComputation( return visitor.changed(); } -bool IsSupportedByTriton(PrecisionConfig::Algorithm algorithm, - const se::GpuComputeCapability& gpu_version) { - auto cuda_compute_capability = - std::get_if(&gpu_version); - auto rocm_compute_capability = - std::get_if(&gpu_version); - switch (algorithm) { - case PrecisionConfig::ALG_DOT_TF32_TF32_F32: - if (cuda_compute_capability) { - return true; - } - return false; - case PrecisionConfig::ALG_DOT_BF16_BF16_F32: - - case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: - case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: - if (cuda_compute_capability) { - return true; - } - if (rocm_compute_capability) { - return rocm_compute_capability->has_bf16_dtype_support(); - } - return false; - - // TODO(b/326579472): Fix the support of this algorithm and maybe allow it - // here. - case PrecisionConfig::ALG_DOT_F16_F16_F32: - // Slow to compile: - case PrecisionConfig::ALG_DOT_F32_F32_F32: - default: - return false; - } -} } // namespace -FusionDecision CanTritonHandleGEMM( - const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version) { - auto cuda_compute_capability = - std::get_if(&gpu_version); - auto rocm_compute_capability = - std::get_if(&gpu_version); - - if (!cuda_compute_capability && !rocm_compute_capability) { - return "Non CUDA or ROCM device."; - } - - if (dot.precision_config().algorithm() == PrecisionConfig::ALG_UNSET) { - if (!tsl::tensor_float_32_execution_enabled() || - absl::c_any_of(dot.precision_config().operand_precision(), - [](int x) { return x != PrecisionConfig::DEFAULT; })) { - return "Non-default precision."; - } - } else { - if (!IsSupportedByTriton(dot.precision_config().algorithm(), - *cuda_compute_capability)) { - return "Unsupported algorithm on the current device(s)."; - } - } - - auto supported_output_type = [&](const PrimitiveType t) { - switch (t) { - case F16: - case F32: - return true; - case BF16: - if (cuda_compute_capability) { - return true; - } - if (rocm_compute_capability) { - return rocm_compute_capability->has_bf16_dtype_support(); - } - return false; - default: - return false; - } - }; - - // TODO(b/266862493): Support more output types. - if (!supported_output_type(dot.shape().element_type())) { - return "Unsupported output data type."; - } - - if (!IsTritonSupportedDataType(dot.operand(0)->shape().element_type(), - gpu_version) || - !IsTritonSupportedDataType(dot.operand(1)->shape().element_type(), - gpu_version)) { - return "Unsupported input data type."; - } - - const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); - - // TODO(b/269580541): support multiple batch dimensions. - if (dim_numbers.lhs_batch_dimensions().size() > 1) { - return "Multiple batch dimensions."; - } - - // Cases where lhs or rhs have no non-contracting dims are not handled. - if (dim_numbers.lhs_batch_dimensions().size() + - dim_numbers.lhs_contracting_dimensions().size() == - dot.operand(0)->shape().rank() || - dim_numbers.rhs_batch_dimensions().size() + - dim_numbers.rhs_contracting_dimensions().size() == - dot.operand(1)->shape().rank()) { - return "No non-contracting dimensions."; - } - - return FusionDecision{}; -} - bool ShouldTritonHandleGEMM(HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version) { std::vector fusion_inputs; diff --git a/third_party/xla/xla/service/gpu/gemm_fusion.h b/third_party/xla/xla/service/gpu/gemm_fusion.h index 1ddf1bd850fc95..1138ad28a36a5f 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion.h +++ b/third_party/xla/xla/service/gpu/gemm_fusion.h @@ -30,10 +30,6 @@ limitations under the License. namespace xla { namespace gpu { -// Filters GEMMs which can be handled using Triton. -FusionDecision CanTritonHandleGEMM(const HloDotInstruction&, - const se::GpuComputeCapability&); - // Filters GEMMs which are better to handle using Triton. bool ShouldTritonHandleGEMM(HloDotInstruction&, const se::GpuComputeCapability&); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index 184880632da09e..427ac4dbd6b4fa 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -415,9 +415,11 @@ Value AddPtr(ImplicitLocOpBuilder& b, Value ptr, Value offset) { return b.create(ptr.getType(), ptr, offset); } -Value EmitElementwise(ImplicitLocOpBuilder& b, absl::string_view libdevice_path, - const se::DeviceDescription& device_info, - const HloInstruction& hlo, ValueRange inputs) { +absl::StatusOr EmitElementwise(ImplicitLocOpBuilder& b, + absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const HloInstruction& hlo, + ValueRange inputs) { if (mlir::getElementTypeOrSelf(inputs[0]).isF32() || mlir::getElementTypeOrSelf(inputs[0]).isF64()) { auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode()); @@ -489,7 +491,8 @@ Value EmitElementwise(ImplicitLocOpBuilder& b, absl::string_view libdevice_path, mlir::mhlo::ComparisonDirection::NE), inputs[1], inputs[2]); default: - LOG(FATAL) << "Unsupported operation " << hlo.ToString(); + return absl::InvalidArgumentError( + absl::StrCat("Unsupported elementwise operation ", hlo.ToString())); } } @@ -901,7 +904,8 @@ absl::StatusOr EmitScope( for (const HloInstruction* operand : hlo->operands()) { operands.push_back(values[operand]); } - result = EmitElementwise(b, libdevice_path, device_info, *hlo, operands); + TF_ASSIGN_OR_RETURN(result, EmitElementwise(b, libdevice_path, + device_info, *hlo, operands)); } else if (hlo->opcode() == HloOpcode::kTuple) { TF_RET_CHECK(hlo->IsRoot()) << hlo->ToString(); } else if (hlo->opcode() == HloOpcode::kBitcast || @@ -919,7 +923,8 @@ absl::StatusOr EmitScope( EmitNestedFusion(b, libdevice_path, device_info, *fusion_instruction, values)); } else { - LOG(FATAL) << hlo->ToString(); + return absl::InvalidArgumentError( + absl::StrCat("Unsupported operation ", hlo->ToString())); } TF_RET_CHECK(values.insert({hlo, result}).second) << hlo->ToString(); VLOG(8) << "Emitted " << hlo->ToString(HloPrintOptions::ShortParsable()); @@ -1191,11 +1196,11 @@ struct MatMulLaunchConfig { matmul_dims.out_lhs_noncontracting_dim_idx = dot.shape().rank() - 2; auto* root = dot.parent()->root_instruction(); - matmul_dims.n = analysis - .IterSpec(TritonFusionAnalysis::Scope::OUTPUT, root, - matmul_dims.out_rhs_noncontracting_dim_idx) - ->at(0) - .count; + auto iter_spec = + analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, root, + matmul_dims.out_rhs_noncontracting_dim_idx); + TF_RET_CHECK(iter_spec != nullptr); + matmul_dims.n = iter_spec->at(0).count; // Contracting dimension length. if (config.split_k > 1 && dot.operand(0)->operand(0)->opcode() == HloOpcode::kPad) { diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc index 902a10ed935b71..2db88bc59379a1 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc @@ -59,38 +59,6 @@ bool HasDefaultLayout(const Shape& shape) { LayoutUtil::IsMonotonicWithDim0Major(shape.layout()); } -bool IsTritonSupportedInstruction(const HloInstruction* instr, - const se::GpuComputeCapability& gpu_version) { - if (!instr->shape().IsArray()) { - return false; - } - - if (!IsTritonSupportedDataType(instr->shape().element_type(), gpu_version)) { - return false; - } - - for (const HloInstruction* operand : instr->operands()) { - if (!IsTritonSupportedDataType(operand->shape().element_type(), - gpu_version)) { - return false; - } - } - - // TODO(bchetioui): expand with non-trivial instructions. - if (instr->IsElementwise()) { - return IsTritonSupportedElementwise(instr->opcode(), - instr->shape().element_type()); - } - - switch (instr->opcode()) { - case HloOpcode::kBitcast: - case HloOpcode::kParameter: - return true; - default: - return false; - } -} - // Returns true if a trivially connected producer of 'consumer' with opcode // 'opcode' exists. If such an instruction is found, the value of 'producer' is // set to it. The definition of "trivial" operations is as given in @@ -268,7 +236,7 @@ bool IsTriviallyFusible(HloInstruction* instr, } if (instr->IsElementwise() && instr->operand_count() == 1) { - return IsTritonSupportedInstruction(instr, gpu_version); + return static_cast(IsTritonSupportedInstruction(*instr, gpu_version)); } // Elementwise binary ops are trivially fusible if the operands are the same, @@ -280,7 +248,8 @@ bool IsTriviallyFusible(HloInstruction* instr, // Elementwise binary ops should be fused if both operands are the same and // if the operand is triton supported. if (operand_0 == operand_1) { - return IsTritonSupportedInstruction(instr, gpu_version); + return static_cast( + IsTritonSupportedInstruction(*instr, gpu_version)); } // For simplicity we only fuse elementwise binary ops with splat operands @@ -291,7 +260,8 @@ bool IsTriviallyFusible(HloInstruction* instr, IsSupportedBroadcastOfParameter(*operand_0)) ^ (IsBroadcastOfScalarConstant(*operand_1) || IsSupportedBroadcastOfParameter(*operand_1))) { - return IsTritonSupportedInstruction(instr, gpu_version); + return static_cast( + IsTritonSupportedInstruction(*instr, gpu_version)); } } @@ -337,14 +307,6 @@ bool IsTriviallyConnectedProducerOf( return false; } -bool IsTritonSupportedComputation(const HloComputation* computation, - const se::GpuComputeCapability& gpu_version) { - return absl::c_all_of( - computation->instructions(), [&](const HloInstruction* instr) { - return IsTritonSupportedInstruction(instr, gpu_version); - }); -} - // Finds the first non-fusible producer of a diamond. This instruction is either // 1. the direct producer of the diamond, if that producer is used more than // twice and/or is not otherwise trivially fusible @@ -447,7 +409,7 @@ SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond( return "Root is not elementwise binary."; } - if (!IsTritonSupportedInstruction(instr, gpu_version_)) { + if (!IsTritonSupportedInstruction(*instr, gpu_version_)) { return "Root is not supported for Triton instruction."; } @@ -471,13 +433,11 @@ SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond( return "Broadcast or reduce have non-default layouts."; } - if (!(reduce->operand_count() == 2 && - reduce->operand(1)->opcode() == HloOpcode::kConstant)) { - return "Reduce has a non-constant second operand and/or is variadic."; - } - - if (!(IsTritonSupportedComputation(reduce->to_apply(), gpu_version_))) { - return "Unsupported reduction by Triton."; + if (CodegenDecision is_supported = + IsTritonSupportedInstruction(*reduce, gpu_version_); + !is_supported) { + VLOG(3) << is_supported.Explain(); + return is_supported; } if (!HasOneUse(broadcast) || !HasOneUse(reduce)) { @@ -486,11 +446,6 @@ SoftmaxRewriterTriton::MatchesTritonCompatibleClosedReductionDiamond( producer = reduce->mutable_operand(0); - if (reduce->dimensions().size() != 1 || - reduce->dimensions(0) != producer->shape().rank() - 1) { - return "Reduction is not a row-reduction of a single operand."; - } - if (absl::c_linear_search(broadcast->dimensions(), broadcast->shape().rank() - 1)) { return "Broadcast is not along the reduction dimension."; diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc index ef0bc0f6f7d7f6..74e800f9a815cc 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc @@ -1770,10 +1770,11 @@ ENTRY main { if (std::holds_alternative(decision)) { std::string actual_decision = std::get(decision).Explain(); - EXPECT_THAT(actual_decision, - AnyOf(HasSubstr("Root is not elementwise binary"), - HasSubstr("Reduce has a non-constant second operand " - "and/or is variadic"))); + EXPECT_THAT( + actual_decision, + AnyOf(HasSubstr("Root is not elementwise binary"), + HasSubstr("Reduction init value should be a constant or a " + "convert of a constant."))); unmatched++; } else { matched++; diff --git a/third_party/xla/xla/service/gpu/triton_support.cc b/third_party/xla/xla/service/gpu/triton_support.cc index 3e2b15222320d1..66631dbd19ad3d 100644 --- a/third_party/xla/xla/service/gpu/triton_support.cc +++ b/third_party/xla/xla/service/gpu/triton_support.cc @@ -20,11 +20,15 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/variant_visitor.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/tensor_float_32_utils.h" namespace xla { namespace gpu { @@ -50,6 +54,7 @@ bool IsDistributiveOverAddition(const HloInstruction& hlo) { // BF16 is supported in a sense that all operations on it are implemented // through F32 and converts have to be inserted into the HLO graph, but // they can be missing during fusion. +// TODO(b/266862493): Support more data types (F8, F64, etc.). bool IsTritonSupportedDataType(PrimitiveType type, const se::GpuComputeCapability& gpu_version) { switch (type) { @@ -130,5 +135,218 @@ bool IsTritonSupportedElementwise(HloOpcode opcode, opcode); } +CodegenDecision CanTritonHandleElementwise( + const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { + if (!IsTritonSupportedDataType(instr.shape().element_type(), gpu_version)) { + return "Unsupported output data type."; + } + + for (const HloInstruction* operand : instr.operands()) { + if (!IsTritonSupportedDataType(operand->shape().element_type(), + gpu_version)) { + return "Unsupported input data type."; + } + } + + if (instr.opcode() == HloOpcode::kConstant) { + return CodegenDecision{}; + } else if (!IsTritonSupportedElementwise( + instr.opcode(), instr.operand(0)->shape().element_type())) { + return "Unsupported elementwise operation."; + } + return CodegenDecision{}; +} + +bool IsDotAlgorithmSupportedByTriton( + PrecisionConfig::Algorithm algorithm, + const se::GpuComputeCapability& gpu_version) { + auto cuda_compute_capability = + std::get_if(&gpu_version); + auto rocm_compute_capability = + std::get_if(&gpu_version); + switch (algorithm) { + case PrecisionConfig::ALG_DOT_TF32_TF32_F32: + if (cuda_compute_capability) { + return true; + } + return false; + case PrecisionConfig::ALG_DOT_BF16_BF16_F32: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: + if (cuda_compute_capability) { + return true; + } + if (rocm_compute_capability) { + return rocm_compute_capability->has_bf16_dtype_support(); + } + return false; + + // TODO(b/326579472): Fix the support of this algorithm and maybe allow it + // here. + case PrecisionConfig::ALG_DOT_F16_F16_F32: + // TODO(b/311331155): Triton F32 is about 3x slower than Triton TF32 and is + // slow to compile. Disable it for now. + case PrecisionConfig::ALG_DOT_F32_F32_F32: + default: + return false; + } +} + +// Filters GEMMs which can be handled using Triton. +CodegenDecision CanTritonHandleGEMM( + const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version) { + auto cuda_compute_capability = + std::get_if(&gpu_version); + auto rocm_compute_capability = + std::get_if(&gpu_version); + + CHECK(cuda_compute_capability || rocm_compute_capability); + + if (dot.precision_config().algorithm() == PrecisionConfig::ALG_UNSET) { + if (!tsl::tensor_float_32_execution_enabled() || + absl::c_any_of(dot.precision_config().operand_precision(), + [](int x) { return x != PrecisionConfig::DEFAULT; })) { + return "Having non-default operand precisions or TensorFloat-32 disabled " + "for Dot op with unset algorithm."; + } + } else { + if (!IsDotAlgorithmSupportedByTriton(dot.precision_config().algorithm(), + gpu_version)) { + return "Unsupported algorithm on the current device(s)."; + } + } + + auto supported_output_type = [&](const PrimitiveType t) { + switch (t) { + case F16: + case F32: + return true; + case BF16: + if (cuda_compute_capability) { + return true; + } + if (rocm_compute_capability) { + return rocm_compute_capability->has_bf16_dtype_support(); + } + return false; + default: + return false; + } + }; + + // TODO(b/266862493): Support more output types. + if (!supported_output_type(dot.shape().element_type())) { + return "Unsupported output data type for Dot op."; + } + + if (!IsTritonSupportedDataType(dot.operand(0)->shape().element_type(), + gpu_version) || + !IsTritonSupportedDataType(dot.operand(1)->shape().element_type(), + gpu_version)) { + return "Unsupported input data type for Dot op."; + } + + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + + // TODO(b/269580541): support multiple batch dimensions. + if (dim_numbers.lhs_batch_dimensions().size() > 1) { + return "Multiple batch dimensions."; + } + + // Cases where lhs or rhs have no non-contracting dims are not handled. + if (dim_numbers.lhs_batch_dimensions().size() + + dim_numbers.lhs_contracting_dimensions().size() == + dot.operand(0)->shape().rank() || + dim_numbers.rhs_batch_dimensions().size() + + dim_numbers.rhs_contracting_dimensions().size() == + dot.operand(1)->shape().rank()) { + return "No non-contracting dimensions."; + } + + return CodegenDecision{}; +} + +// Filters Reduces which can be handled using Triton. +CodegenDecision CanTritonHandleReduce( + const HloReduceInstruction& reduce, + const se::GpuComputeCapability& gpu_version) { + if (!IsTritonSupportedDataType(reduce.shape().element_type(), gpu_version)) { + return "Unsupported output data type for Reduce op."; + } + + for (const HloInstruction* operand : reduce.operands()) { + if (!IsTritonSupportedDataType(operand->shape().element_type(), + gpu_version)) { + return "Unsupported input data type for Reduce op."; + } + } + + bool is_triton_supported_reduction_computation = [&]() { + return absl::c_all_of( + reduce.to_apply()->instructions(), [&](const HloInstruction* instr) { + return IsTritonSupportedInstruction(*instr, gpu_version); + }); + }(); + if (!is_triton_supported_reduction_computation) { + return "Unsupported reduction computation by Triton."; + } + + if (reduce.dimensions().size() == 1 && + reduce.dimensions().front() == reduce.operand(0)->shape().rank() - 1 && + reduce.operand_count() == 2) { + const HloInstruction* operand = reduce.operand(1); + // We assume that the reduction init value was input as a constant, or in + // the case of a data type affected by float normalization, a convert of a + // constant. + if (operand->opcode() == HloOpcode::kConvert) { + if (operand->operand(0)->opcode() == HloOpcode::kConstant && + operand->operand(0)->shape().element_type() == BF16 && + operand->shape().element_type() == F32) { + return CodegenDecision{}; + } + } else if (operand->opcode() == HloOpcode::kConstant) { + return CodegenDecision{}; + } + return "Reduction init value should be a constant or a convert of a " + "constant."; + } + return "Reduction is not a row-reduction of a single operand."; +} + +CodegenDecision IsTritonSupportedInstruction( + const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { + if (instr.IsElementwise()) { + return CanTritonHandleElementwise(instr, gpu_version); + } + + switch (instr.opcode()) { + case HloOpcode::kDot: { + return CanTritonHandleGEMM(*Cast(&instr), gpu_version); + } + case HloOpcode::kReduce: { + return CanTritonHandleReduce(*Cast(&instr), + gpu_version); + } + case HloOpcode::kTuple: { + if (instr.IsRoot()) { + return CodegenDecision{}; + } + return "Only supports root tuples."; + } + case HloOpcode::kBitcast: + case HloOpcode::kTranspose: + case HloOpcode::kSlice: + case HloOpcode::kReshape: + case HloOpcode::kPad: + case HloOpcode::kConcatenate: + case HloOpcode::kParameter: + case HloOpcode::kBroadcast: + return CodegenDecision{}; + default: + break; + } + return "Unsupported opcode."; +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/triton_support.h b/third_party/xla/xla/service/gpu/triton_support.h index 02f6da34089f89..072c9ab948ec00 100644 --- a/third_party/xla/xla/service/gpu/triton_support.h +++ b/third_party/xla/xla/service/gpu/triton_support.h @@ -22,11 +22,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/instruction_fusion.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" namespace xla { namespace gpu { +using CodegenDecision = FusionDecision; // Tells if f(a+b) == f(a) + f(b). bool IsDistributiveOverAddition(const HloInstruction& hlo); @@ -46,6 +48,10 @@ bool IsTritonSupportedDataType(PrimitiveType, const se::GpuComputeCapability&); // Checks elementwise operation against all supported by Triton GEMM codegen. bool IsTritonSupportedElementwise(HloOpcode, PrimitiveType); +// Checks instruction against requirements of triton emitter. +CodegenDecision IsTritonSupportedInstruction( + const HloInstruction& instr, const se::GpuComputeCapability& gpu_version); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/triton_support_test.cc b/third_party/xla/xla/service/gpu/triton_support_test.cc new file mode 100644 index 00000000000000..e3ad43b2f0f783 --- /dev/null +++ b/third_party/xla/xla/service/gpu/triton_support_test.cc @@ -0,0 +1,940 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/triton_support.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/base/optimization.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "llvm/IR/Module.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/primitive_util.h" +#include "xla/service/float_normalization.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/gpu_float_support.h" +#include "xla/service/gpu/ir_emitter_triton.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/service/hlo_pass_pipeline.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +class TritonSupportTest : public GpuCodegenTest { + public: + se::CudaComputeCapability GetCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } + absl::StatusOr ApplyFloatNormalization(HloModule* module) { + const GpuFloatSupport bf16_support(GetCudaComputeCapability(), BF16); + HloPassPipeline pipeline("hlo float normalization"); + pipeline.AddPass(&bf16_support); + return pipeline.Run(module); + } + + float getTolerance(PrimitiveType data_type) { + float tolerance; + switch (data_type) { + case F64: + case F32: + tolerance = 1e-6; + break; + case F16: + tolerance = 2e-4; + break; + case BF16: + tolerance = 2e-2; + break; + case PRED: + case S8: + tolerance = 3e-2; + break; + case S16: + tolerance = 3e-3; + break; + case S32: + tolerance = 3e-3; + break; + default: + ABSL_UNREACHABLE(); + } + return tolerance; + } + + protected: + llvm::LLVMContext llvm_ctx_; + llvm::Module llvm_module_{"module", llvm_ctx_}; + mlir::MLIRContext mlir_context_; + TritonGemmConfig config_{16, 32, 512, 1, 4, 8}; +}; + +class TritonSupportTestWithParam : public TritonSupportTest, + public ::testing::WithParamInterface< + std::tuple> {}; + +std::string TestParamsToString( + const ::testing::TestParamInfo>& + data) { + PrimitiveType data_type; + HloOpcode opcode; + std::tie(data_type, opcode) = data.param; + return absl::StrCat( + primitive_util::LowercasePrimitiveTypeName(data_type), "_", + absl::StrReplaceAll(HloOpcodeString(opcode), {{"-", "_"}})); +} + +using UnaryElementwiseTest = TritonSupportTestWithParam; + +// TODO(b/331636835): updates elementwise op tests to directly emit single op +// instead of relying on triton gemm kernel. +TEST_P(UnaryElementwiseTest, IsTritonSupportedExecutesCorrectlyForUnary) { + PrimitiveType data_type; + HloOpcode opcode; + std::tie(data_type, opcode) = GetParam(); + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE) && + data_type == BF16) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloTestTemplate = R"( +triton_gemm___computation { + parameter_0 = f32[15,33]{1,0} parameter(0) + parameter_1 = $0[33,68]{1,0} parameter(1) + unary = $0[33,68]{1,0} $1(parameter_1) + convert = f32[33,68]{1,0} convert(unary) + ROOT dot = f32[15,68]{1,0} dot(parameter_0, convert), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + parameter_0 = f32[15,33]{1,0} parameter(0) + parameter_1 = $0[33,68]{1,0} parameter(1) + ROOT triton_gemm = f32[15,68]{1,0} fusion(parameter_0, parameter_1), + kind=kCustom, calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_test)); + const HloComputation* computation = + module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, opcode); + if (IsTritonSupportedInstruction(*instr, GetCudaComputeCapability())) { + float tolerance = getTolerance(data_type); + EXPECT_OK(ApplyFloatNormalization(module.get())); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(module), ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance})); + } else { + // TODO(b/331632717): update the check to use SymbolicTileAnalysis to avoid + // tiling failures and check triton emitter fails gracefully. + EXPECT_THAT(TritonFusionAnalysis::Execute(*computation), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::HasSubstr( + "Can not propagate dim orders and requirements"))); + } +} + +INSTANTIATE_TEST_SUITE_P( + UnaryElementwiseTestSuite, UnaryElementwiseTest, + ::testing::Combine(::testing::Values(S8, S16, S32, F16, F32, BF16), + ::testing::Values(HloOpcode::kConvert, HloOpcode::kAbs, + HloOpcode::kNegate)), + TestParamsToString); +INSTANTIATE_TEST_SUITE_P( + UnaryPREDTestSuite, UnaryElementwiseTest, + ::testing::Combine(::testing::Values(PRED), + ::testing::Values(HloOpcode::kConvert, HloOpcode::kNot)), + TestParamsToString); +INSTANTIATE_TEST_SUITE_P( + UnaryMathTestSuite, UnaryElementwiseTest, + ::testing::Combine(::testing::Values(F16, F32, BF16), + ::testing::Values(HloOpcode::kCos, HloOpcode::kExp, + HloOpcode::kExpm1, HloOpcode::kLog, + HloOpcode::kLog1p, HloOpcode::kRsqrt, + HloOpcode::kSin, HloOpcode::kSqrt, + HloOpcode::kCbrt, HloOpcode::kTan, + HloOpcode::kTanh, HloOpcode::kErf)), + TestParamsToString); + +using BinaryElementwiseTest = TritonSupportTestWithParam; + +TEST_P(BinaryElementwiseTest, IsTritonSupportedExecutesCorrectlyForBinaryE) { + PrimitiveType data_type; + HloOpcode opcode; + std::tie(data_type, opcode) = GetParam(); + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE) && + data_type == BF16) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloTestTemplate = R"( +triton_gemm___computation { + parameter_0 = f32[92,11]{1,0} parameter(0) + parameter_1 = $0[11,63]{1,0} parameter(1) + parameter_2 = $0[11,63]{1,0} parameter(2) + binary = $0[11,63]{1,0} $1(parameter_1, parameter_2) + convert = f32[11,63]{1,0} convert(binary) + ROOT dot = f32[92,63]{1,0} dot(parameter_0, convert), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + parameter_0 = f32[92,11]{1,0} parameter(0) + parameter_1 = $0[11,63]{1,0} parameter(1) + parameter_2 = $0[11,63]{1,0} parameter(2) + ROOT triton_gemm = f32[92,63]{1,0} fusion(parameter_0, parameter_1, parameter_2), + kind=kCustom, calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_test)); + const HloComputation* computation = + module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, opcode); + if (IsTritonSupportedInstruction(*instr, GetCudaComputeCapability())) { + float tolerance = getTolerance(data_type); + EXPECT_OK(ApplyFloatNormalization(module.get())); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(module), ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance})); + } else { + EXPECT_THAT(TritonFusionAnalysis::Execute(*computation), + ::testing::AnyOf( + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr( + "std::holds_alternative")), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::HasSubstr( + "Can not propagate dim orders and requirements")))); + } +} + +INSTANTIATE_TEST_SUITE_P( + BinaryElementwiseTestSuite, BinaryElementwiseTest, + ::testing::Combine(::testing::Values(S8, S16, S32, F16, F32, BF16), + ::testing::Values(HloOpcode::kAdd, HloOpcode::kMultiply, + HloOpcode::kMaximum, + HloOpcode::kMinimum, + HloOpcode::kSubtract)), + TestParamsToString); + +INSTANTIATE_TEST_SUITE_P(BinaryPREDTestSuite, BinaryElementwiseTest, + ::testing::Combine(::testing::Values(PRED), + ::testing::Values(HloOpcode::kAnd, + HloOpcode::kOr, + HloOpcode::kXor)), + TestParamsToString); +INSTANTIATE_TEST_SUITE_P( + BinaryMathTestSuite, BinaryElementwiseTest, + ::testing::Combine(::testing::Values(F16, F32, BF16), + ::testing::Values(HloOpcode::kAtan2, HloOpcode::kDivide, + HloOpcode::kPower)), + TestParamsToString); + +using CompareTest = TritonSupportTestWithParam; + +TEST_P(CompareTest, IsTritonSupportedExecutesCorrectlyForCompare) { + PrimitiveType data_type; + HloOpcode opcode; + std::tie(data_type, opcode) = GetParam(); + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE) && + data_type == BF16) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloTestTemplate = R"( +triton_gemm___computation { + parameter_0 = f32[92,11]{1,0} parameter(0) + parameter_1 = $0[11,63]{1,0} parameter(1) + parameter_2 = $0[11,63]{1,0} parameter(2) + compare = pred[11,63]{1,0} $1(parameter_1, parameter_2), direction=GE + convert = f32[11,63]{1,0} convert(compare) + ROOT dot = f32[92,63]{1,0} dot(parameter_0, convert), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + parameter_0 = f32[92,11]{1,0} parameter(0) + parameter_1 = $0[11,63]{1,0} parameter(1) + parameter_2 = $0[11,63]{1,0} parameter(2) + ROOT triton_gemm = f32[92,63]{1,0} fusion(parameter_0, parameter_1, parameter_2), + kind=kCustom, calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_test)); + const HloComputation* computation = + module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, opcode); + if (IsTritonSupportedInstruction(*instr, GetCudaComputeCapability())) { + float tolerance = getTolerance(data_type); + EXPECT_OK(ApplyFloatNormalization(module.get())); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(module), ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance})); + } else { + EXPECT_THAT( + TritonFusionAnalysis::Execute(*computation), + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr("std::holds_alternative"))); + } +} + +INSTANTIATE_TEST_SUITE_P( + CompareTestSuite, CompareTest, + ::testing::Combine(::testing::Values(PRED, S8, S16, S32, F16, F32, BF16), + ::testing::Values(HloOpcode::kCompare)), + TestParamsToString); + +using TernaryElementwiseTest = TritonSupportTestWithParam; + +TEST_P(TernaryElementwiseTest, IsTritonSupportedExecutesCorrectlyForTernary) { + PrimitiveType data_type; + HloOpcode opcode; + std::tie(data_type, opcode) = GetParam(); + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE) && + data_type == BF16) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloTestTemplate = R"( +triton_gemm___computation { + parameter_0 = f32[92,13]{1,0} parameter(0) + parameter_1 = $0[13,63]{1,0} parameter(1) + parameter_2 = $0[13,63]{1,0} parameter(2) + parameter_3 = pred[13,63]{1,0} parameter(3) + ternary = $0[13,63]{1,0} $1(parameter_3, parameter_1, parameter_2) + convert = f32[13,63]{1,0} convert(ternary) + ROOT dot = f32[92,63]{1,0} dot(parameter_0, convert), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + operand_precision={HIGH, HIGH} +} + +ENTRY e { + parameter_0 = f32[92,13]{1,0} parameter(0) + parameter_1 = $0[13,63]{1,0} parameter(1) + parameter_2 = $0[13,63]{1,0} parameter(2) + parameter_3 = pred[13,63]{1,0} parameter(3) + ROOT triton_gemm = f32[92,63]{1,0} fusion(parameter_0, parameter_1, parameter_2, parameter_3), + kind=kCustom, calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_test)); + const HloComputation* computation = + module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, opcode); + if (IsTritonSupportedInstruction(*instr, GetCudaComputeCapability())) { + float tolerance = getTolerance(data_type); + EXPECT_OK(ApplyFloatNormalization(module.get())); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(module), ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance})); + } else { + EXPECT_THAT( + TritonFusionAnalysis::Execute(*computation), + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr("std::holds_alternative"))); + } +} + +INSTANTIATE_TEST_SUITE_P( + TernaryElementwiseTestSuite, TernaryElementwiseTest, + ::testing::Combine(::testing::Values(PRED, S8, S16, S32, F16, F32, BF16), + ::testing::Values(HloOpcode::kSelect)), + TestParamsToString); + +using DotTest = TritonSupportTestWithParam; + +TEST_P(DotTest, IsTritonSupportedExecutesCorrectlyForDot) { + PrimitiveType data_type; + HloOpcode opcode; + std::tie(data_type, opcode) = GetParam(); + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE) && + data_type == BF16) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloTestTemplate = R"( +triton_gemm___computation { + parameter_0 = $0[92,11]{1,0} parameter(0) + parameter_1 = $0[11,63]{1,0} parameter(1) + ROOT dot = $0[92,63]{1,0} $1(parameter_0, parameter_1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + parameter_0 = $0[92,11]{1,0} parameter(0) + parameter_1 = $0[11,63]{1,0} parameter(1) + ROOT triton_gemm = $0[92,63]{1,0} fusion(parameter_0, parameter_1), kind=kCustom, + calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_test)); + const HloComputation* computation = + module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, opcode); + if (IsTritonSupportedInstruction(*instr, GetCudaComputeCapability())) { + EXPECT_OK(ApplyFloatNormalization(module.get())); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(module), ErrorSpec{/*aabs=*/2e-4, /*arel=*/2e-4})); + } else { + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT( + TritonWrapper(*TritonFusionAnalysis::Execute(*computation), "test_fn", + computation, GetCudaComputeCapability(), dev_info, + config_, &llvm_module_, &EmitMatMul, mlir_context_), + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr("Failed to compile Triton kernel"))); + } +} + +INSTANTIATE_TEST_SUITE_P(DotTestTestSuite, DotTest, + ::testing::Combine(::testing::Values(F16, F32, BF16), + ::testing::Values(HloOpcode::kDot)), + TestParamsToString); + +TEST_F(TritonSupportTest, UnsupportedDotOutputTypeFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +triton_gemm___computation { + parameter_0 = f32[92,11]{1,0} parameter(0) + parameter_1 = f32[11,63]{1,0} parameter(1) + ROOT dot = pred[92,63]{1,0} dot(parameter_0, parameter_1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + parameter_0 = f32[92,11]{1,0} parameter(0) + parameter_1 = f32[11,63]{1,0} parameter(1) + ROOT triton_gemm = pred[92,63]{1,0} fusion(parameter_0, parameter_1), kind=kCustom, + calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT(IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr("Unsupported output data type for Dot op.")); + EXPECT_THAT( + TritonWrapper(*TritonFusionAnalysis::Execute(*computation), "test_fn", + computation, GetCudaComputeCapability(), dev_info, config_, + &llvm_module_, &EmitMatMul, mlir_context_), + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr("pm.run(triton_module.get()).succeeded()"))); +} + +TEST_F(TritonSupportTest, + UnsupportedDotWithMultipleBatchDimensionsFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +triton_gemm___computation { + parameter_0 = f32[2,2,2,2]{3,2,1,0} parameter(0) + parameter_1 = f32[2,2,2,2]{3,2,1,0} parameter(1) + ROOT dot = f32[2,2,2,2]{3,2,1,0} dot(parameter_0, parameter_1), + lhs_contracting_dims={3}, lhs_batch_dims={1,0}, rhs_contracting_dims={2}, + rhs_batch_dims={1,0} +} + +ENTRY e { + parameter_0 = f32[2,2,2,2]{3,2,1,0} parameter(0) + parameter_1 = f32[2,2,2,2]{3,2,1,0} parameter(1) + ROOT triton_gemm = f32[2,2,2,2]{3,2,1,0} fusion(parameter_0, parameter_1), + kind=kCustom, calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT(IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr("Multiple batch dimensions")); + EXPECT_THAT( + TritonWrapper(*TritonFusionAnalysis::Execute(*computation), "test_fn", + computation, GetCudaComputeCapability(), dev_info, config_, + &llvm_module_, &EmitMatMul, mlir_context_), + tsl::testing::StatusIs(absl::StatusCode::kInternal, + ::testing::HasSubstr("num_batch_dims <= 1"))); +} + +TEST_F(TritonSupportTest, + UnsupportedDotWithNoNonContractingDimensionsFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +triton_gemm___computation { + parameter_0 = f32[2]{0} parameter(0) + parameter_1 = f32[2]{0} parameter(1) + ROOT dot = f32[] dot(parameter_0, parameter_1), + lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +ENTRY e { + parameter_0 = f32[2]{0} parameter(0) + parameter_1 = f32[2]{0} parameter(1) + ROOT triton_gemm = f32[] fusion(parameter_0, parameter_1), kind=kCustom, + calls=triton_gemm___computation, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_gemm___computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + EXPECT_THAT(IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr("No non-contracting dimensions.")); + EXPECT_THAT(TritonFusionAnalysis::Execute(*computation), + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr("non_contracting_dims.size() == 1"))); +} + +using ReduceConstTest = TritonSupportTestWithParam; +TEST_P(ReduceConstTest, + IsTritonSupportedExecutesCorrectlyForReduceWithConstInit) { + PrimitiveType data_type; + HloOpcode opcode; + std::tie(data_type, opcode) = GetParam(); + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE) && + data_type == BF16) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + + const std::string kHloTestTemplate = R"( +HloModule t +add { + Arg_0 = $0[] parameter(0) + Arg_1 = $0[] parameter(1) + ROOT add = $0[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + parameter_0 = $0[125,127]{1,0} parameter(0) + multiply_0 = $0[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = $0[] constant(0) + reduce = $0[125]{0} $1(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast = $0[125,127]{1,0} broadcast(reduce), dimensions={0} + ROOT multiply = $0[125,127]{1,0} multiply(multiply_0, broadcast) +} + +ENTRY main { + parameter_0 = $0[125,127]{1,0} parameter(0) + ROOT triton_softmax = $0[125,127]{1,0} fusion(parameter_0), + kind=kCustom, calls=triton_softmax_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton_softmax"}} +})"; + const std::string hlo_test = absl::Substitute( + kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type), + HloOpcodeString(opcode)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_test)); + + const HloComputation* computation = + module->GetComputationWithName("triton_softmax_computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, opcode); + if (IsTritonSupportedInstruction(*instr, GetCudaComputeCapability())) { + float tolerance = getTolerance(data_type); + EXPECT_OK(ApplyFloatNormalization(module.get())); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(module), ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance})); + } else { + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT( + TritonWrapper(*TritonFusionAnalysis::Execute(*computation), "test_fn", + computation, GetCudaComputeCapability(), dev_info, + config_, &llvm_module_, &EmitSoftMax, mlir_context_), + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr("Failed to compile Triton kernel"))); + } +} + +INSTANTIATE_TEST_SUITE_P( + ReduceConstTestSuite, ReduceConstTest, + ::testing::Combine(::testing::Values(F16, F32, BF16), + ::testing::Values(HloOpcode::kReduce)), + TestParamsToString); + +TEST_F(TritonSupportTest, + SupportedReduceWithConvertConstantIsCodegenedSuccessfullyWithTriton) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "No BF16 before Ampere."; + } + const std::string kHloTest = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + parameter_0 = f32[125,127]{1,0} parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = bf16[] constant(0) + convert_0 = f32[] convert(constant_0) + reduce = f32[125]{0} reduce(multiply_0, convert_0), dimensions={1}, to_apply=add + broadcast = f32[125,127]{1,0} broadcast(reduce), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast) +} + +ENTRY main { + parameter_0 = f32[125,127]{1,0} parameter(0) + ROOT triton_softmax = f32[125,127]{1,0} fusion(parameter_0), kind=kCustom, + calls=triton_softmax_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton_softmax"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_softmax_computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = hlo_query::GetFirstInstructionWithOpcode( + *computation, HloOpcode::kReduce); + EXPECT_TRUE(IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .CanFuse()); + EXPECT_OK(ApplyFloatNormalization(hlo_module.get())); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(hlo_module), ErrorSpec{/*aabs=*/2e-4, /*arel=*/2e-4})); +} + +TEST_F( + TritonSupportTest, + UnsupportedReduceWithMoreThanOneReduceDimensionsFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + parameter_0 = f32[2,125,127]{2,1,0} parameter(0) + multiply_0 = f32[2,125,127]{2,1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + reduce = f32[2]{0} reduce(multiply_0, constant_0), dimensions={1,2}, to_apply=add + broadcast = f32[2,125,127]{2,1,0} broadcast(reduce), dimensions={0} + ROOT multiply = f32[2,125,127]{2,1,0} multiply(multiply_0, broadcast) +} + +ENTRY main { + parameter_0 = f32[2,125,127]{2,1,0} parameter(0) + ROOT triton_softmax = f32[2,125,127]{2,1,0} fusion(parameter_0), + kind=kCustom, calls=triton_softmax_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton_softmax"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_softmax_computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = hlo_query::GetFirstInstructionWithOpcode( + *computation, HloOpcode::kReduce); + EXPECT_THAT(IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr( + "Reduction is not a row-reduction of a single operand.")); + EXPECT_THAT(TritonFusionAnalysis::Execute(*computation), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::HasSubstr( + "Can not propagate dim orders and requirements"))); +} + +TEST_F(TritonSupportTest, + UnsupportedReduceWithNoneLastReduceDimensionFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + parameter_0 = f32[2,125,127]{2,1,0} parameter(0) + multiply_0 = f32[2,125,127]{2,1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + reduce = f32[2,127]{1,0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add + broadcast = f32[2,125,127]{2,1,0} broadcast(reduce), dimensions={0,2} + ROOT multiply = f32[2,125,127]{2,1,0} multiply(multiply_0, broadcast) +} + +ENTRY main { + parameter_0 = f32[2,125,127]{2,1,0} parameter(0) + ROOT triton_softmax = f32[2,125,127]{2,1,0} fusion(parameter_0), + kind=kCustom, calls=triton_softmax_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton_softmax"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_softmax_computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = hlo_query::GetFirstInstructionWithOpcode( + *computation, HloOpcode::kReduce); + EXPECT_THAT(IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr( + "Reduction is not a row-reduction of a single operand.")); + EXPECT_THAT(TritonFusionAnalysis::Execute(*computation), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::HasSubstr( + "Can not propagate dim orders and requirements"))); +} + +TEST_F(TritonSupportTest, + UnsupportedReduceWithMoreThanOneOperandsFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_2 = f32[] parameter(1) + Arg_1 = f32[] parameter(2) + Arg_3 = f32[] parameter(3) + add_0 = f32[] add(Arg_0, Arg_2) + add_1 = f32[] add(Arg_1, Arg_3) + ROOT pair = (f32[], f32[]) tuple(add_0, add_1) +} + +triton_softmax_computation { + parameter_0 = f32[125,127] parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + tuple_0 = (f32[125]{0}, f32[125]{0}) reduce(multiply_0, multiply_0, constant_0, constant_0), dimensions={1}, to_apply=add + reduce = f32[125]{0} get-tuple-element(tuple_0), index=0 + broadcast = f32[125,127]{1,0} broadcast(reduce), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast) +} + +ENTRY main { + parameter_0 = f32[125,127]{1,0} parameter(0) + ROOT triton_softmax = f32[125,127]{1,0} fusion(parameter_0), + kind=kCustom, calls=triton_softmax_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton_softmax"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_softmax_computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = hlo_query::GetFirstInstructionWithOpcode( + *computation, HloOpcode::kReduce); + EXPECT_THAT( + IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr("Unsupported output data type for Reduce op.")); + EXPECT_THAT(TritonFusionAnalysis::Execute(*computation), + tsl::testing::StatusIs( + absl::StatusCode::kFailedPrecondition, + ::testing::HasSubstr( + "Can not propagate dim orders and requirements"))); +} + +TEST_F(TritonSupportTest, + UnsupportedReduceWithNonConstReduceValueFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +triton_softmax_computation { + parameter_0 = f32[125,127]{1,0} parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + init = f32[] parameter(1) + reduce = f32[125]{0} reduce(multiply_0, init), dimensions={1}, to_apply=add + broadcast = f32[125,127]{1,0} broadcast(reduce), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast) +} + +ENTRY main { + parameter_0 = f32[125,127]{1,0} parameter(0) + parameter_1 = f32[] parameter(1) + ROOT triton_softmax = f32[125,127]{1,0} fusion(parameter_0, parameter_1), + kind=kCustom, calls=triton_softmax_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton_softmax"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_softmax_computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = hlo_query::GetFirstInstructionWithOpcode( + *computation, HloOpcode::kReduce); + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT(IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr("Reduction init value should be a constant " + "or a convert of a constant.")); + EXPECT_THAT( + TritonWrapper(*TritonFusionAnalysis::Execute(*computation), "test_fn", + computation, GetCudaComputeCapability(), dev_info, config_, + &llvm_module_, &EmitSoftMax, mlir_context_), + tsl::testing::StatusIs( + absl::StatusCode::kInternal, + ::testing::HasSubstr("operand->opcode() == HloOpcode::kConstant"))); +} + +TEST_F(TritonSupportTest, + UnsupportedReductionComputationFailsGracefullyWithTriton) { + const std::string kHloTest = R"( +HloModule t +custom_call { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT custom_call = f32[] custom-call(Arg_0, Arg_1), custom_call_target="foo" +} + +triton_softmax_computation { + parameter_0 = f32[125,127]{1,0} parameter(0) + multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0) + constant_0 = f32[] constant(0) + reduce = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=custom_call + broadcast = f32[125,127]{1,0} broadcast(reduce), dimensions={0} + ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast) +} + +ENTRY main { + parameter_0 = f32[125,127]{1,0} parameter(0) + ROOT triton_softmax = f32[125,127]{1,0} fusion(parameter_0), + kind=kCustom, calls=triton_softmax_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton_softmax"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloTest)); + + const HloComputation* computation = + hlo_module->GetComputationWithName("triton_softmax_computation"); + ASSERT_TRUE(computation != nullptr); + const HloInstruction* instr = hlo_query::GetFirstInstructionWithOpcode( + *computation, HloOpcode::kReduce); + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(GetCudaComputeCapability()); + EXPECT_THAT( + IsTritonSupportedInstruction(*instr, GetCudaComputeCapability()) + .Explain(), + ::testing::HasSubstr("Unsupported reduction computation by Triton.")); + EXPECT_THAT( + TritonWrapper(*TritonFusionAnalysis::Execute(*computation), "test_fn", + computation, GetCudaComputeCapability(), dev_info, config_, + &llvm_module_, &EmitSoftMax, mlir_context_), + tsl::testing::StatusIs(absl::StatusCode::kInvalidArgument, + ::testing::HasSubstr("Unsupported operation"))); +} +} // namespace +} // namespace gpu +} // namespace xla From a829ac043c3531ad514febe0a7dfa003621a6d7f Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Fri, 29 Mar 2024 00:54:46 -0700 Subject: [PATCH 075/124] [xla:gpu] Create fake buffer allocations for embedded thunk This is required in cases where embedded thunk arguments share the same buffer (i.e. they are located at different offsets of the same buffer) PiperOrigin-RevId: 620179451 --- .../xla/xla/service/gpu/fusions/custom.cc | 32 ++- .../gpu/runtime/address_computation_thunk.cc | 12 +- .../gpu/runtime/address_computation_thunk.h | 2 + .../runtime/address_computation_thunk_test.cc | 244 +++++++++++++----- 4 files changed, 205 insertions(+), 85 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index 147fb0e89e401b..753d610360c691 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -295,6 +295,12 @@ absl::StatusOr EmitDynamicSlicedGemm( int64_t out_fake_byte_size = ShapeUtil::ByteSizeOf( custom_call.shape().IsArray() ? custom_call.shape() : custom_call.shape().tuple_shapes(0)); + + // Handling cases where multiple operands share the same buffer, with + // different offset by creating new fake allocations so each operand will have + // a different buffer index. The slices can thus always start at offset 0. + // AddressComputationThunk will take care of the offset adjustment. + std::vector> fake_allocations(4); if (fusion.shape().IsArray()) { TF_ASSIGN_OR_RETURN(output, get_original_result_slice(&custom_call, /*index=*/{})); @@ -312,8 +318,10 @@ absl::StatusOr EmitDynamicSlicedGemm( &fusion, /*index=*/{1})); slice_instr = nullptr; collect_slice_info(); - slice_workspace_fake = - BufferAllocation::Slice(workspace->allocation(), 0, workspace->size()); + fake_allocations[3] = std::make_unique( + /*index=*/3, workspace->size(), /*color=*/0); + slice_workspace_fake = BufferAllocation::Slice(fake_allocations[3].get(), 0, + workspace->size()); } if (absl::c_all_of(offset_buffer_indices, [&](auto offset_slices) { @@ -331,20 +339,23 @@ absl::StatusOr EmitDynamicSlicedGemm( GemmConfig config, GemmConfig::For(static_cast(&custom_call))); - // TODO(vuson): handle cases where LHS and RHS share the same buffer, with - // different offset. In such cases, the fake slices need to contain the - // correct offset instead of default value 0. int64_t lhs_byte_size = ShapeUtil::ByteSizeOf(custom_call.operand(0)->shape()); - BufferAllocation::Slice slice_lhs_fake(lhs_slice.allocation(), 0, + fake_allocations[0] = std::make_unique( + /*index=*/0, lhs_byte_size, /*color=*/0); + BufferAllocation::Slice slice_lhs_fake(fake_allocations[0].get(), 0, lhs_byte_size); int64_t rhs_byte_size = ShapeUtil::ByteSizeOf(custom_call.operand(1)->shape()); - BufferAllocation::Slice slice_rhs_fake(rhs_slice.allocation(), 0, + fake_allocations[1] = std::make_unique( + /*index=*/1, rhs_byte_size, /*color=*/0); + BufferAllocation::Slice slice_rhs_fake(fake_allocations[1].get(), 0, rhs_byte_size); - BufferAllocation::Slice slice_out_fake(output.allocation(), 0, + fake_allocations[2] = std::make_unique( + /*index=*/2, out_fake_byte_size, /*color=*/0); + BufferAllocation::Slice slice_out_fake(fake_allocations[2].get(), 0, out_fake_byte_size); ThunkSequence seq; seq.emplace_back(std::make_unique( @@ -358,7 +369,8 @@ absl::StatusOr EmitDynamicSlicedGemm( auto thunk = std::make_unique( Thunk::ThunkInfo::WithProfileAnnotation(&custom_call), std::make_unique(std::move(seq)), arguments, - offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes); + std::move(fake_allocations), offset_buffer_indices, orig_shapes, + sliced_shapes, offset_byte_sizes); FusionEmissionResult result; result.thunks.push_back(std::move(thunk)); @@ -602,6 +614,8 @@ absl::StatusOr AddressComputationFusion::Emit( absl::StatusOr DynamicAddressComputationFusion::Emit( IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const { + // std::cerr << "TYB \n" + // << fusion.fused_instructions_computation()->ToString() << '\n'; const HloFusionAdaptor& adaptor = analysis_.fusion(); auto maybe_custom_call_adaptor = HloFindIf( adaptor.GetRoots(), adaptor, diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc index 705aede8672f5f..15efc3e89b4b5b 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc @@ -46,6 +46,7 @@ namespace gpu { AddressComputationThunk::AddressComputationThunk( ThunkInfo thunk_info, std::unique_ptr embedded_thunk, std::vector> arguments, + std::vector> fake_allocations, std::vector>> offset_buffer_indices, std::vector> orig_shapes, @@ -55,6 +56,7 @@ AddressComputationThunk::AddressComputationThunk( embedded_thunk_(std::make_unique( ThunkInfo(thunk_info.op), std::move(*embedded_thunk))), embedded_thunk_arguments_(std::move(arguments)), + fake_allocations_(std::move(fake_allocations)), offset_buffer_indices_(std::move(offset_buffer_indices)), orig_shapes_(std::move(orig_shapes)), sliced_shapes_(std::move(sliced_shapes)), @@ -113,8 +115,8 @@ absl::Status AddressComputationThunk::ExecuteOnStream( const ExecuteParams& params) { auto& stream = *params.stream; const BufferAllocations& orig_allocations = *params.buffer_allocations; - std::vector new_buffers(orig_allocations.size(), - se::DeviceMemoryBase()); + std::vector new_buffers( + embedded_thunk_arguments_.size(), se::DeviceMemoryBase()); // Get memory allocation for copying offsets from device. int64_t* offsets_base = [&] { @@ -136,10 +138,9 @@ absl::Status AddressComputationThunk::ExecuteOnStream( // `argument_slice` within `orig_allocations` se::DeviceMemoryBase orig_argument = orig_allocations.GetDeviceAddress(*argument_slice); - auto buffer_idx = argument_slice->index(); if (offset_slice == std::nullopt) { - new_buffers[buffer_idx] = orig_argument; + new_buffers[argument_idx] = orig_argument; continue; } @@ -185,7 +186,8 @@ absl::Status AddressComputationThunk::ExecuteOnStream( new_offset += start * stride; } - new_buffers[buffer_idx] = orig_argument.GetByteSlice(new_offset, new_size); + new_buffers[argument_idx] = + orig_argument.GetByteSlice(new_offset, new_size); } // Safe to create a local BufferAllocations here since buffers are only slices diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h index 05765f215a9e14..e1c0b30d9953aa 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h @@ -45,6 +45,7 @@ class AddressComputationThunk : public Thunk { AddressComputationThunk( ThunkInfo thunk_info, std::unique_ptr embedded_thunk, std::vector> arguments, + std::vector> fake_allocations_, std::vector>> offset_buffer_indices, std::vector> orig_shapes, @@ -63,6 +64,7 @@ class AddressComputationThunk : public Thunk { std::unique_ptr embedded_thunk_; std::vector> embedded_thunk_arguments_; + std::vector> fake_allocations_; std::vector>> offset_buffer_indices_; std::vector> orig_shapes_; diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc index 030e45a154ed7c..8e43b77cc6c04a 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc @@ -83,17 +83,30 @@ TEST(AddressComputationThunkTest, SlicedGemm) { // Prepare embedded and address computation thunks. // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(4); + + fake_allocations.push_back( + std::make_unique(/*index=*/0, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0, + rhs_length); + BufferAllocation alloc_lhs(/*index=*/0, lhs_length, /*color=*/0); BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); - BufferAllocation alloc_rhs(/*index=*/1, rhs_length, /*color=*/0); - BufferAllocation::Slice slice_rhs(&alloc_rhs, 0, rhs_length); + fake_allocations.push_back( + std::make_unique(/*index=*/1, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_rhs(fake_allocations.back().get(), 0, + rhs_length); - BufferAllocation alloc_out(/*index=*/2, out_length, /*color=*/0); - BufferAllocation::Slice slice_out(&alloc_out, 0, out_length); + fake_allocations.push_back( + std::make_unique(/*index=*/2, out_length, /*color=*/0)); + BufferAllocation::Slice slice_out(fake_allocations.back().get(), 0, + out_length); - BufferAllocation alloc_workspace(/*index=*/3, 1024 * 1024, /*color=*/0); - BufferAllocation::Slice slice_workspace(&alloc_workspace, 0, 1024 * 1024); + fake_allocations.push_back(std::make_unique( + /*index=*/3, 1024 * 1024, /*color=*/0)); + BufferAllocation::Slice slice_workspace(fake_allocations.back().get(), 0, + 1024 * 1024); BufferAllocation alloc_lhs_offset_0(/*index=*/4, offset_length, /*color=*/0); @@ -105,9 +118,6 @@ TEST(AddressComputationThunkTest, SlicedGemm) { BufferAllocation::Slice slice_lhs_offset_1(&alloc_lhs_offset_1, 0, offset_length); - BufferAllocation alloc_lhs_fake(/*index=*/0, rhs_length, /*color=*/0); - BufferAllocation::Slice slice_lhs_fake(&alloc_lhs_fake, 0, rhs_length); - // Preparing config for GEMM thunk. auto config = GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, @@ -130,6 +140,7 @@ TEST(AddressComputationThunkTest, SlicedGemm) { Thunk::ThunkInfo(nullptr), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs, slice_out, slice_workspace}, + std::move(fake_allocations), {lhs_offsets, std::nullopt, std::nullopt, std::nullopt}, {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt, std::nullopt, std::nullopt}, @@ -212,17 +223,33 @@ TEST(AddressComputationThunkTest, SlicedNonContiguousGemm) { // Prepare embedded and address computation thunks. // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(4); + + fake_allocations.push_back(std::make_unique( + /*index=*/0, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0, + slice_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/1, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_rhs_fake(fake_allocations.back().get(), 0, + slice_length); + BufferAllocation alloc_lhs(/*index=*/0, lhs_length, /*color=*/0); BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); BufferAllocation alloc_rhs(/*index=*/1, rhs_length, /*color=*/0); BufferAllocation::Slice slice_rhs(&alloc_rhs, 0, rhs_length); - BufferAllocation alloc_out(/*index=*/2, out_length, /*color=*/0); - BufferAllocation::Slice slice_out(&alloc_out, 0, out_length); + fake_allocations.push_back( + std::make_unique(/*index=*/2, out_length, /*color=*/0)); + BufferAllocation::Slice slice_out(fake_allocations.back().get(), 0, + out_length); - BufferAllocation alloc_workspace(/*index=*/3, 1024 * 1024, /*color=*/0); - BufferAllocation::Slice slice_workspace(&alloc_workspace, 0, 1024 * 1024); + fake_allocations.push_back(std::make_unique( + /*index=*/3, 1024 * 1024, /*color=*/0)); + BufferAllocation::Slice slice_workspace(fake_allocations.back().get(), 0, + 1024 * 1024); BufferAllocation alloc_lhs_offset_0(/*index=*/4, offset_length, /*color=*/0); @@ -244,12 +271,6 @@ TEST(AddressComputationThunkTest, SlicedNonContiguousGemm) { BufferAllocation::Slice slice_rhs_offset_1(&alloc_rhs_offset_1, 0, offset_length); - BufferAllocation alloc_lhs_fake(/*index=*/0, slice_length, /*color=*/0); - BufferAllocation::Slice slice_lhs_fake(&alloc_lhs_fake, 0, slice_length); - - BufferAllocation alloc_rhs_fake(/*index=*/1, slice_length, /*color=*/0); - BufferAllocation::Slice slice_rhs_fake(&alloc_rhs_fake, 0, slice_length); - // Preparing config for GEMM thunk. auto config = GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}), {}, {1}, @@ -274,6 +295,7 @@ TEST(AddressComputationThunkTest, SlicedNonContiguousGemm) { Thunk::ThunkInfo(nullptr), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs, slice_out, slice_workspace}, + std::move(fake_allocations), {lhs_offsets, rhs_offsets, std::nullopt, std::nullopt}, {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), ShapeUtil::MakeShape(PrimitiveType::F32, {4, 3}), std::nullopt, @@ -362,17 +384,33 @@ TEST(AddressComputationThunkTest, MulipleSlicedOperandsGemm) { // Prepare embedded and address computation thunks. // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(4); + + fake_allocations.push_back(std::make_unique( + /*index=*/0, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0, + slice_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/1, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_rhs_fake(fake_allocations.back().get(), 0, + slice_length); + BufferAllocation alloc_lhs(/*index=*/0, length, /*color=*/0); BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, length); BufferAllocation alloc_rhs(/*index=*/1, length, /*color=*/0); BufferAllocation::Slice slice_rhs(&alloc_rhs, 0, length); - BufferAllocation alloc_out(/*index=*/2, out_length, /*color=*/0); - BufferAllocation::Slice slice_out(&alloc_out, 0, out_length); + fake_allocations.push_back( + std::make_unique(/*index=*/2, out_length, /*color=*/0)); + BufferAllocation::Slice slice_out(fake_allocations.back().get(), 0, + out_length); - BufferAllocation alloc_workspace(/*index=*/3, 1024 * 1024, /*color=*/0); - BufferAllocation::Slice slice_workspace(&alloc_workspace, 0, 1024 * 1024); + fake_allocations.push_back(std::make_unique( + /*index=*/3, 1024 * 1024, /*color=*/0)); + BufferAllocation::Slice slice_workspace(fake_allocations.back().get(), 0, + 1024 * 1024); BufferAllocation alloc_lhs_offset_0(/*index=*/4, offset_length, /*color=*/0); @@ -394,12 +432,6 @@ TEST(AddressComputationThunkTest, MulipleSlicedOperandsGemm) { BufferAllocation::Slice slice_rhs_offset_1(&alloc_rhs_offset_1, 0, offset_length); - BufferAllocation alloc_lhs_fake(/*index=*/0, slice_length, /*color=*/0); - BufferAllocation::Slice slice_lhs_fake(&alloc_lhs, 0, slice_length); - - BufferAllocation alloc_rhs_fake(/*index=*/1, slice_length, /*color=*/0); - BufferAllocation::Slice slice_rhs_fake(&alloc_rhs, 0, slice_length); - // Preparing config for GEMM thunk. auto config = GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, @@ -424,6 +456,7 @@ TEST(AddressComputationThunkTest, MulipleSlicedOperandsGemm) { Thunk::ThunkInfo(nullptr), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs, slice_out, slice_workspace}, + std::move(fake_allocations), {lhs_offsets, rhs_offsets, std::nullopt, std::nullopt}, {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), ShapeUtil::MakeShape(PrimitiveType::F32, {8, 1}), std::nullopt, @@ -542,11 +575,21 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) { // Prepare embedded and address computation thunks. // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(2); + + // Fake slices for embedded thunk creation. + fake_allocations.push_back(std::make_unique( + /*index=*/0, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_src_fake(fake_allocations.back().get(), 0, + slice_length); + BufferAllocation alloc_src(/*index=*/0, src_length, /*color=*/0); BufferAllocation::Slice slice_src(&alloc_src, 0, src_length); - BufferAllocation alloc_dst(/*index=*/1, dst_length, /*color=*/0); - BufferAllocation::Slice slice_dst(&alloc_dst, 0, dst_length); + fake_allocations.push_back( + std::make_unique(/*index=*/1, dst_length, /*color=*/0)); + BufferAllocation::Slice slice_dst(fake_allocations.back().get(), 0, + dst_length); BufferAllocation alloc_offset_0(/*index=*/2, offset_length, /*color=*/0); BufferAllocation::Slice slice_offset_0(&alloc_offset_0, 0, offset_length); @@ -560,10 +603,6 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) { BufferAllocation alloc_offset_3(/*index=*/5, offset_length, /*color=*/0); BufferAllocation::Slice slice_offset_3(&alloc_offset_3, 0, offset_length); - // Fake slices for embedded thunk creation. - BufferAllocation alloc_src_fake(/*index=*/0, slice_length, /*color=*/0); - BufferAllocation::Slice slice_src_fake(&alloc_src_fake, 0, slice_length); - // Preparing custom call thunk: setting up call target and operands + results // buffers. auto registration = xla::ffi::FindHandler("__xla_test$$memcpy", PLATFORM); @@ -589,7 +628,7 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) { AddressComputationThunk thunk( Thunk::ThunkInfo(nullptr), std::make_unique(std::move(seq)), {slice_src, slice_dst}, - {slice_offsets, std::nullopt}, + std::move(fake_allocations), {slice_offsets, std::nullopt}, {ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 8}), std::nullopt}, // Make sure to pass a dst shape with the same rank as src shape (i.e. // original slice result and not bitcasted one) @@ -672,6 +711,19 @@ TEST(AddressComputationThunkTest, SlicedOutputMemcpy) { // Prepare embedded and address computation thunks. // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(2); + + // Fake slices for embedded thunk creation. + fake_allocations.push_back(std::make_unique( + /*index=*/0, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_src_fake(fake_allocations.back().get(), 0, + slice_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/1, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_dst_fake(fake_allocations.back().get(), 0, + slice_length); + BufferAllocation alloc_src(/*index=*/0, src_length, /*color=*/0); BufferAllocation::Slice slice_src(&alloc_src, 0, src_length); @@ -710,13 +762,6 @@ TEST(AddressComputationThunkTest, SlicedOutputMemcpy) { BufferAllocation::Slice slice_dst_offset_3(&alloc_dst_offset_3, 0, offset_length); - // Fake slices for embedded thunk creation. - BufferAllocation alloc_src_fake(/*index=*/0, slice_length, /*color=*/0); - BufferAllocation::Slice slice_src_fake(&alloc_src_fake, 0, slice_length); - - BufferAllocation alloc_dst_fake(/*index=*/1, slice_length, /*color=*/0); - BufferAllocation::Slice slice_dst_fake(&alloc_dst_fake, 0, slice_length); - // Preparing custom call thunk: setting up call target and operands + results // buffers. auto registration = xla::ffi::FindHandler("__xla_test$$memcpy", PLATFORM); @@ -746,7 +791,7 @@ TEST(AddressComputationThunkTest, SlicedOutputMemcpy) { AddressComputationThunk thunk( Thunk::ThunkInfo(nullptr), std::make_unique(std::move(seq)), {slice_src, slice_dst}, - {slice_src_offsets, slice_dst_offsets}, + std::move(fake_allocations), {slice_src_offsets, slice_dst_offsets}, {ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 2}), ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2, 2, 2})}, // Make sure to pass a dst shape with the same rank as src shape (i.e. @@ -849,6 +894,28 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryArgumentOrder) { // Prepare embedded and address computation thunks. // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(4); + + fake_allocations.push_back( + std::make_unique(/*index=*/0, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0, + rhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/1, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_rhs_fake(fake_allocations.back().get(), 0, + rhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/2, out_length, /*color=*/0)); + BufferAllocation::Slice slice_out_fake(fake_allocations.back().get(), 0, + out_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/3, 1024 * 1024, /*color=*/0)); + BufferAllocation::Slice slice_workspace_fake(fake_allocations.back().get(), 0, + 1024 * 1024); + BufferAllocation alloc_lhs(/*index=*/1, lhs_length, /*color=*/0); BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); @@ -871,9 +938,6 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryArgumentOrder) { BufferAllocation::Slice slice_lhs_offset_1(&alloc_lhs_offset_1, 0, offset_length); - BufferAllocation alloc_lhs_fake(/*index=*/1, rhs_length, /*color=*/0); - BufferAllocation::Slice slice_lhs_fake(&alloc_lhs_fake, 0, rhs_length); - // Preparing config for GEMM thunk. auto config = GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, @@ -886,8 +950,8 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryArgumentOrder) { // Creating embedded GEMM thunk. ThunkSequence seq; seq.emplace_back(std::make_unique( - Thunk::ThunkInfo(nullptr), config.value(), slice_lhs_fake, slice_rhs, - slice_out, slice_workspace, /*deterministic=*/true)); + Thunk::ThunkInfo(nullptr), config.value(), slice_lhs_fake, slice_rhs_fake, + slice_out_fake, slice_workspace_fake, /*deterministic=*/true)); // Wrapping address computation thunk around the GEMM thunk. std::vector lhs_offsets{slice_lhs_offset_0, @@ -896,6 +960,7 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryArgumentOrder) { Thunk::ThunkInfo(nullptr), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs, slice_out, slice_workspace}, + std::move(fake_allocations), {lhs_offsets, std::nullopt, std::nullopt, std::nullopt}, {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt, std::nullopt, std::nullopt}, @@ -977,6 +1042,28 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryNumberOfArguments) { // Prepare embedded and address computation thunks. // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(4); + + fake_allocations.push_back( + std::make_unique(/*index=*/0, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0, + rhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/1, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_rhs_fake(fake_allocations.back().get(), 0, + rhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/2, out_length, /*color=*/0)); + BufferAllocation::Slice slice_out_fake(fake_allocations.back().get(), 0, + out_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/3, 1024 * 1024, /*color=*/0)); + BufferAllocation::Slice slice_workspace_fake(fake_allocations.back().get(), 0, + 1024 * 1024); + BufferAllocation alloc_lhs(/*index=*/7, lhs_length, /*color=*/0); BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); @@ -999,9 +1086,6 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryNumberOfArguments) { BufferAllocation::Slice slice_lhs_offset_1(&alloc_lhs_offset_1, 0, offset_length); - BufferAllocation alloc_lhs_fake(/*index=*/7, rhs_length, /*color=*/0); - BufferAllocation::Slice slice_lhs_fake(&alloc_lhs_fake, 0, rhs_length); - // Preparing config for GEMM thunk. auto config = GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, @@ -1014,8 +1098,8 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryNumberOfArguments) { // Creating embedded GEMM thunk. ThunkSequence seq; seq.emplace_back(std::make_unique( - Thunk::ThunkInfo(nullptr), config.value(), slice_lhs_fake, slice_rhs, - slice_out, slice_workspace, /*deterministic=*/true)); + Thunk::ThunkInfo(nullptr), config.value(), slice_lhs_fake, slice_rhs_fake, + slice_out_fake, slice_workspace_fake, /*deterministic=*/true)); // Wrapping address computation thunk around the GEMM thunk. std::vector lhs_offsets{slice_lhs_offset_0, @@ -1024,6 +1108,7 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryNumberOfArguments) { Thunk::ThunkInfo(nullptr), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs, slice_out, slice_workspace}, + std::move(fake_allocations), {lhs_offsets, std::nullopt, std::nullopt, std::nullopt}, {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt, std::nullopt, std::nullopt}, @@ -1106,17 +1191,30 @@ TEST(AddressComputationThunkTest, SlicedTupledOperandGemm) { // Prepare embedded and address computation thunks. // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(4); + + fake_allocations.push_back( + std::make_unique(/*index=*/0, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0, + rhs_length); + BufferAllocation alloc_lhs(/*index=*/0, 3 * lhs_length, /*color=*/0); BufferAllocation::Slice slice_lhs(&alloc_lhs, lhs_length, lhs_length); - BufferAllocation alloc_rhs(/*index=*/1, rhs_length, /*color=*/0); - BufferAllocation::Slice slice_rhs(&alloc_rhs, 0, rhs_length); + fake_allocations.push_back( + std::make_unique(/*index=*/1, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_rhs(fake_allocations.back().get(), 0, + rhs_length); - BufferAllocation alloc_out(/*index=*/2, out_length, /*color=*/0); - BufferAllocation::Slice slice_out(&alloc_out, 0, out_length); + fake_allocations.push_back( + std::make_unique(/*index=*/2, out_length, /*color=*/0)); + BufferAllocation::Slice slice_out(fake_allocations.back().get(), 0, + out_length); - BufferAllocation alloc_workspace(/*index=*/3, 1024 * 1024, /*color=*/0); - BufferAllocation::Slice slice_workspace(&alloc_workspace, 0, 1024 * 1024); + fake_allocations.push_back(std::make_unique( + /*index=*/3, 1024 * 1024, /*color=*/0)); + BufferAllocation::Slice slice_workspace(fake_allocations.back().get(), 0, + 1024 * 1024); BufferAllocation alloc_lhs_offset_0(/*index=*/4, offset_length, /*color=*/0); @@ -1128,9 +1226,6 @@ TEST(AddressComputationThunkTest, SlicedTupledOperandGemm) { BufferAllocation::Slice slice_lhs_offset_1(&alloc_lhs_offset_1, 0, offset_length); - BufferAllocation alloc_lhs_fake(/*index=*/0, rhs_length, /*color=*/0); - BufferAllocation::Slice slice_lhs_fake(&alloc_lhs_fake, 0, rhs_length); - // Preparing config for GEMM thunk. auto config = GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, @@ -1153,6 +1248,7 @@ TEST(AddressComputationThunkTest, SlicedTupledOperandGemm) { Thunk::ThunkInfo(nullptr), std::make_unique(std::move(seq)), {slice_lhs, slice_rhs, slice_out, slice_workspace}, + std::move(fake_allocations), {lhs_offsets, std::nullopt, std::nullopt, std::nullopt}, {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt, std::nullopt, std::nullopt}, @@ -1244,6 +1340,19 @@ TEST(AddressComputationThunkTest, SlicedMemcpyOOB) { // Prepare embedded and address computation thunks. // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(2); + + // Fake slices for embedded thunk creation. + fake_allocations.push_back(std::make_unique( + /*index=*/0, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_src_fake(fake_allocations.back().get(), 0, + slice_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/1, slice_length, /*color=*/0)); + BufferAllocation::Slice slice_dst_fake(fake_allocations.back().get(), 0, + slice_length); + BufferAllocation alloc_src(/*index=*/0, src_length, /*color=*/0); BufferAllocation::Slice slice_src(&alloc_src, 0, src_length); @@ -1282,13 +1391,6 @@ TEST(AddressComputationThunkTest, SlicedMemcpyOOB) { BufferAllocation::Slice slice_dst_offset_3(&alloc_dst_offset_3, 0, offset_length); - // Fake slices for embedded thunk creation. - BufferAllocation alloc_src_fake(/*index=*/0, slice_length, /*color=*/0); - BufferAllocation::Slice slice_src_fake(&alloc_src_fake, 0, slice_length); - - BufferAllocation alloc_dst_fake(/*index=*/1, slice_length, /*color=*/0); - BufferAllocation::Slice slice_dst_fake(&alloc_dst_fake, 0, slice_length); - // Preparing custom call thunk: setting up call target and operands + results // buffers. auto registration = xla::ffi::FindHandler("__xla_test$$memcpy", PLATFORM); @@ -1318,7 +1420,7 @@ TEST(AddressComputationThunkTest, SlicedMemcpyOOB) { AddressComputationThunk thunk( Thunk::ThunkInfo(nullptr), std::make_unique(std::move(seq)), {slice_src, slice_dst}, - {slice_src_offsets, slice_dst_offsets}, + std::move(fake_allocations), {slice_src_offsets, slice_dst_offsets}, {ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 2}), ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2, 2, 2})}, // Make sure to pass a dst shape with the same rank as src shape (i.e. From 06ef8a60031342ed6ba72a572aa509837a97be8b Mon Sep 17 00:00:00 2001 From: Doyeon Kim Date: Fri, 29 Mar 2024 01:11:55 -0700 Subject: [PATCH 076/124] Convert quantized stablehlo.constant to tfl.pseudo_qconst PiperOrigin-RevId: 620182445 --- .../uniform-quantized-stablehlo-to-tfl.mlir | 29 +++++++++++++++++++ ...uniform_quantized_stablehlo_to_tfl_pass.cc | 23 +++++++++++++-- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir index 5653dfeb9f2b8f..76699b8c860c23 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir @@ -1506,3 +1506,32 @@ func.func @add_i32(%arg0: tensor<1x3x!quant.uniform tensor<1x2x4x5x!quant.uniform> { + %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> + return %0 : tensor<1x2x4x5x!quant.uniform> +} + +// CHECK: %[[QCONST:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} +// CHECK-SAME: () -> tensor<1x2x4x5x!quant.uniform> +// CHECK: return %[[QCONST]] + +// ----- + +// Tests that a float `stablehlo.constant` is not converted into `tfl.qconst`. + +// CHECK-LABEL: func @float_constant +func.func @float_constant() -> tensor<1x2x4x5xf32> { + %0 = stablehlo.constant() {value = dense<1.0> : tensor<1x2x4x5xf32>} : () -> tensor<1x2x4x5xf32> + return %0 : tensor<1x2x4x5xf32> +} + +// CHECK: stablehlo.constant +// CHECK-NOT: tfl.pseudo_qconst +// CHECK-NOT: tfl.pseudo_const +// CHECK-NOT: arith.constant diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc index f9417d6da30274..4148ef49f6604a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc @@ -2102,20 +2102,39 @@ class RewriteQuantizedAddOp : public OpRewritePattern { } }; +// Rewrites quantized `stablehlo.constant` to `tfl.pseudo_qconst`. +class RewriteQuantizedConstantOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(stablehlo::ConstantOp op) const override { + return success(IsQuantizedTensorType(op.getOutput().getType())); + } + + void rewrite(stablehlo::ConstantOp op, + PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp( + op, /*qtype=*/TypeAttr::get(op.getOutput().getType()), + /*value=*/op.getValue()); + } +}; + void UniformQuantizedStableHloToTflPass::runOnOperation() { func::FuncOp func_op = getOperation(); MLIRContext& ctx = getContext(); RewritePatternSet patterns(&ctx); patterns.add(&ctx); + RewriteQuantizedTransposeOp>(&ctx); if (failed(applyPatternsAndFoldGreedily(func_op, std::move(patterns)))) { func_op.emitError() << "Failed to convert stablehlo ops with uniform " From f2a78c0e2beb95fcae7b6aa346802f36ddb62219 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Fri, 29 Mar 2024 01:26:57 -0700 Subject: [PATCH 077/124] [xla:gpu][NFC] Add AddressComputationThunk test with GEMM operands sharing the same buffer PiperOrigin-RevId: 620184639 --- .../runtime/address_computation_thunk_test.cc | 152 ++++++++++++++++++ 1 file changed, 152 insertions(+) diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc index 8e43b77cc6c04a..ee6b4eee6b6164 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk_test.cc @@ -1511,4 +1511,156 @@ TEST(AddressComputationThunkTest, SlicedMemcpyOOB) { ASSERT_EQ(out, ref); } +TEST(AddressComputationThunkTest, SlicedOperandsSameBufferGemm) { + se::StreamExecutor* executor = GpuExecutor(); + + se::Stream stream(executor); + TF_ASSERT_OK(stream.Initialize()); + + int64_t lhs_length = sizeof(float) * 2 * 4; + int64_t rhs_length = sizeof(float) * 3 * 1; + int64_t out_length = sizeof(float) * 1 * 1; + int64_t offset_length = sizeof(int64_t); + + // Step 1: + // Prepare embedded and address computation thunks. + + // Preparing buffer allocation slices for thunk creations. + std::vector> fake_allocations(4); + + fake_allocations.push_back( + std::make_unique(/*index=*/0, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0, + rhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/1, rhs_length, /*color=*/0)); + BufferAllocation::Slice slice_rhs_fake(fake_allocations.back().get(), 0, + rhs_length); + + fake_allocations.push_back( + std::make_unique(/*index=*/2, out_length, /*color=*/0)); + BufferAllocation::Slice slice_out_fake(fake_allocations.back().get(), 0, + out_length); + + fake_allocations.push_back(std::make_unique( + /*index=*/3, 1024 * 1024, /*color=*/0)); + BufferAllocation::Slice slice_workspace_fake(fake_allocations.back().get(), 0, + 1024 * 1024); + + BufferAllocation alloc(/*index=*/0, lhs_length + rhs_length + out_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs(&alloc, 0, lhs_length); + BufferAllocation::Slice slice_rhs(&alloc, lhs_length, rhs_length); + BufferAllocation::Slice slice_out(&alloc, lhs_length + rhs_length, + out_length); + + BufferAllocation alloc_workspace(/*index=*/1, 1024 * 1024, /*color=*/0); + BufferAllocation::Slice slice_workspace(&alloc_workspace, 0, 1024 * 1024); + + BufferAllocation alloc_lhs_offset_0(/*index=*/2, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_0(&alloc_lhs_offset_0, 0, + offset_length); + + BufferAllocation alloc_lhs_offset_1(/*index=*/3, offset_length, + /*color=*/0); + BufferAllocation::Slice slice_lhs_offset_1(&alloc_lhs_offset_1, 0, + offset_length); + + // Preparing config for GEMM thunk. + auto config = + GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, + ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), {}, {0}, + ShapeUtil::MakeShape(PrimitiveType::F32, {1, 1}), 1.0, + 0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt, + se::blas::kDefaultComputePrecision, false, false); + ASSERT_TRUE(config.ok()); + + // Creating embedded GEMM thunk. + ThunkSequence seq; + seq.emplace_back(std::make_unique( + Thunk::ThunkInfo(nullptr), config.value(), slice_lhs_fake, slice_rhs_fake, + slice_out_fake, slice_workspace_fake, /*deterministic=*/true)); + + // Wrapping address computation thunk around the GEMM thunk. + std::vector lhs_offsets{slice_lhs_offset_0, + slice_lhs_offset_1}; + AddressComputationThunk thunk( + Thunk::ThunkInfo(nullptr), + std::make_unique(std::move(seq)), + {slice_lhs, slice_rhs, slice_out, slice_workspace}, + std::move(fake_allocations), + {lhs_offsets, std::nullopt, std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt, + std::nullopt, std::nullopt}, + {ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt, + std::nullopt, std::nullopt}, + {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt}); + + // Step 2: + // Execute address computation thunk. + // + + // Preparing memory for thunk arguments. + // lhs = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, + // 5.0, 6.0, 7.0, 8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + // + // The real `lhs` tensor will look more like this: + // lhs = [1.0, 2.0, 3.0, 4.0, + // 5.0, 6.0, 7.0, 8.0] + // The `lhs` slice that we want to use will be equivalent to this static + // slice op: + // f32[1,3]{1,0} slice(lhs), slice={[0:1], [1:4]} + se::DeviceMemory buffer = + executor->AllocateArray(lhs_length + rhs_length + out_length); + TF_ASSERT_OK(stream.MemZero(&buffer, lhs_length + rhs_length + out_length)); + + se::DeviceMemoryBase lhs = buffer.GetByteSlice(0, lhs_length); + std::vector lhs_arr{1, 2, 3, 4, 5, 6, 7, 8}; + TF_ASSERT_OK(stream.Memcpy(&lhs, lhs_arr.data(), lhs_length)); + + // rhs = [1.0, + // 1.0, + // 1.0] + se::DeviceMemoryBase rhs = buffer.GetByteSlice(lhs_length, rhs_length); + std::vector rhs_arr(3, 1); + TF_ASSERT_OK(stream.Memcpy(&rhs, rhs_arr.data(), rhs_length)); + + se::DeviceMemoryBase out = + buffer.GetByteSlice(lhs_length + rhs_length, out_length); + + se::DeviceMemory workspace = + executor->AllocateArray(1024 * 1024); + TF_ASSERT_OK(stream.MemZero(&workspace, 1024 * 1024)); + + se::DeviceMemory lhs_offset_0 = executor->AllocateArray(1); + se::DeviceMemory lhs_offset_1 = executor->AllocateArray(1); + std::vector lhs_offset_arr{0, 1}; + TF_ASSERT_OK(stream.Memcpy(&lhs_offset_0, &lhs_offset_arr[0], offset_length)); + TF_ASSERT_OK(stream.Memcpy(&lhs_offset_1, &lhs_offset_arr[1], offset_length)); + + // Preparing parameters for thunk execution. + ServiceExecutableRunOptions run_options; + BufferAllocations allocations({buffer, workspace, lhs_offset_0, lhs_offset_1}, + 0, executor->GetAllocator()); + + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + run_options, allocations, &stream, &stream, {}, nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK( + thunk.Initialize({executor, source, &allocations, &stream, &stream})); + + // Executing address computation thunk. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + + // Copying `out` data back to host for verification. + std::vector dst(1, 0); + TF_ASSERT_OK(stream.Memcpy(dst.data(), out, out_length)); + + ASSERT_EQ(dst, std::vector({9})); +} + } // namespace xla::gpu From 17aaaa5ed5da664a51efa220a32c54fbf5adb985 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 29 Mar 2024 02:02:17 -0700 Subject: [PATCH 078/124] compat: Update forward compatibility horizon to 2024-03-29 PiperOrigin-RevId: 620190038 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index d3a357f833f2ef..149839623ef935 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 3, 28) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 3, 29) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 30166624ca249a6bdba5487ab0374f124fba566a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 29 Mar 2024 02:02:26 -0700 Subject: [PATCH 079/124] Update GraphDef version to 1816. PiperOrigin-RevId: 620190062 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index df98002eacc475..6352dcf15edd44 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1815 // Updated: 2024/3/28 +#define TF_GRAPH_DEF_VERSION 1816 // Updated: 2024/3/29 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 06f2799e4630d759841689f04a374dba89774983 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Fri, 29 Mar 2024 02:31:43 -0700 Subject: [PATCH 080/124] [xla:gpu][NFC] Use meaningful constexpr PiperOrigin-RevId: 620194665 --- .../xla/xla/service/gpu/fusions/custom.cc | 61 +++++++++++-------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index 753d610360c691..8c2229506e35d5 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -70,6 +70,12 @@ namespace xla { namespace gpu { namespace { +constexpr unsigned kLHSOperandIndex = 0; +constexpr unsigned kRHSOperandIndex = 1; + +constexpr unsigned kGEMMOutputBufferIndex = 0; +constexpr unsigned kGEMMWorkspaceBufferIndex = 1; + absl::StatusOr> BuildCustomKernelThunkForFusion( IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion, CustomKernel custom_kernel) { @@ -144,12 +150,14 @@ absl::StatusOr EmitGemm( TF_ASSIGN_OR_RETURN( BufferAllocation::Slice lhs_slice, GetSliceWithUpdatedOffsetAndSize(buffer_assignment, adaptor, fusion, - *custom_call.operand(0), /*index=*/{})); + *custom_call.operand(kLHSOperandIndex), + /*index=*/{})); TF_ASSIGN_OR_RETURN( BufferAllocation::Slice rhs_slice, GetSliceWithUpdatedOffsetAndSize(buffer_assignment, adaptor, fusion, - *custom_call.operand(1), /*index=*/{})); + *custom_call.operand(kRHSOperandIndex), + /*index=*/{})); BufferAllocation::Slice output; std::optional workspace; @@ -161,10 +169,11 @@ absl::StatusOr EmitGemm( TF_ASSIGN_OR_RETURN(output, GetAllocationSlice(buffer_assignment, &fusion, {})); } else { - TF_ASSIGN_OR_RETURN(output, - GetAllocationSlice(buffer_assignment, &fusion, {0})); + TF_ASSIGN_OR_RETURN(output, GetAllocationSlice(buffer_assignment, &fusion, + {kGEMMOutputBufferIndex})); TF_ASSIGN_OR_RETURN(workspace, - GetAllocationSlice(buffer_assignment, &fusion, {1})); + GetAllocationSlice(buffer_assignment, &fusion, + {kGEMMWorkspaceBufferIndex})); } bool deterministic_ops = @@ -249,15 +258,15 @@ absl::StatusOr EmitDynamicSlicedGemm( slice_instr->index_operands().front()->shape().element_type())); }; - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice lhs_slice, - get_original_operand_slice(custom_call.operand(0), /*index=*/{})); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_slice, + get_original_operand_slice( + custom_call.operand(kLHSOperandIndex), /*index=*/{})); collect_slice_info(); slice_instr = nullptr; - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice rhs_slice, - get_original_operand_slice(custom_call.operand(1), /*index=*/{})); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_slice, + get_original_operand_slice( + custom_call.operand(kRHSOperandIndex), /*index=*/{})); collect_slice_info(); slice_instr = nullptr; @@ -309,13 +318,15 @@ absl::StatusOr EmitDynamicSlicedGemm( slice_instr = nullptr; collect_slice_info(); } else { - TF_ASSIGN_OR_RETURN(output, - get_original_result_slice(&custom_call, /*index=*/{0})); + TF_ASSIGN_OR_RETURN( + output, get_original_result_slice(&custom_call, + /*index=*/{kGEMMOutputBufferIndex})); collect_slice_info(); // TODO(vuson): If we want to support slices of workspace, we'd need to // start `HloFindIf` with `get-tuple-element` with the right index. - TF_ASSIGN_OR_RETURN(workspace, GetAllocationSlice(buffer_assignment, - &fusion, /*index=*/{1})); + TF_ASSIGN_OR_RETURN( + workspace, GetAllocationSlice(buffer_assignment, &fusion, + /*index=*/{kGEMMWorkspaceBufferIndex})); slice_instr = nullptr; collect_slice_info(); fake_allocations[3] = std::make_unique( @@ -340,18 +351,18 @@ absl::StatusOr EmitDynamicSlicedGemm( GemmConfig::For(static_cast(&custom_call))); int64_t lhs_byte_size = - ShapeUtil::ByteSizeOf(custom_call.operand(0)->shape()); - fake_allocations[0] = std::make_unique( - /*index=*/0, lhs_byte_size, /*color=*/0); - BufferAllocation::Slice slice_lhs_fake(fake_allocations[0].get(), 0, - lhs_byte_size); + ShapeUtil::ByteSizeOf(custom_call.operand(kLHSOperandIndex)->shape()); + fake_allocations[kLHSOperandIndex] = std::make_unique( + /*index=*/kLHSOperandIndex, lhs_byte_size, /*color=*/0); + BufferAllocation::Slice slice_lhs_fake( + fake_allocations[kLHSOperandIndex].get(), 0, lhs_byte_size); int64_t rhs_byte_size = - ShapeUtil::ByteSizeOf(custom_call.operand(1)->shape()); - fake_allocations[1] = std::make_unique( - /*index=*/1, rhs_byte_size, /*color=*/0); - BufferAllocation::Slice slice_rhs_fake(fake_allocations[1].get(), 0, - rhs_byte_size); + ShapeUtil::ByteSizeOf(custom_call.operand(kRHSOperandIndex)->shape()); + fake_allocations[kRHSOperandIndex] = std::make_unique( + /*index=*/kRHSOperandIndex, rhs_byte_size, /*color=*/0); + BufferAllocation::Slice slice_rhs_fake( + fake_allocations[kRHSOperandIndex].get(), 0, rhs_byte_size); fake_allocations[2] = std::make_unique( /*index=*/2, out_fake_byte_size, /*color=*/0); From 147638c8d4391b3606e0ddcc40d14e6aa5a0f1fa Mon Sep 17 00:00:00 2001 From: Doyeon Kim Date: Fri, 29 Mar 2024 02:52:05 -0700 Subject: [PATCH 081/124] Split hybrid quantized dot-like StableHLO ops into TFLite dequantize and float op Hybrid quantized op has semantics for weight-only quantization within StableHLO, so it should be splitted into dequantize and float op for legalization towards TFLite. PiperOrigin-RevId: 620197428 --- .../uniform-quantized-stablehlo-to-tfl.mlir | 45 +++++++++++++++++++ ...uniform_quantized_stablehlo_to_tfl_pass.cc | 35 ++++++++++++++- 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir index 76699b8c860c23..7107f7dcb08a45 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir @@ -1535,3 +1535,48 @@ func.func @float_constant() -> tensor<1x2x4x5xf32> { // CHECK-NOT: tfl.pseudo_qconst // CHECK-NOT: tfl.pseudo_const // CHECK-NOT: arith.constant + +// ----- + +// Tests that a hybrid quantized dot_general is splitted into dequantize and float +// dot_general. + +// CHECK-LABEL: func @dot_general_hybrid +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x3x4xf32> +func.func @dot_general_hybrid(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x5xf32> { + %0 = stablehlo.constant() {value = dense<1> : tensor<1x2x4x5xi8>} : () -> tensor<1x2x4x5x!quant.uniform> + %1 = "stablehlo.dot_general"(%arg0, %0) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0, 1], + rhs_batching_dimensions = [0, 1], + lhs_contracting_dimensions = [3], + rhs_contracting_dimensions = [2]>, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x2x3x4xf32>, tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x3x5xf32> + return %1 : tensor<1x2x3x5xf32> +} + +// CHECK: %[[WEIGHT:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x2x4x5x!quant.uniform>, value = dense<1> : tensor<1x2x4x5xi8>} +// CHECK: %[[DQ:.+]] = "tfl.dequantize"(%[[WEIGHT]]) : (tensor<1x2x4x5x!quant.uniform>) -> tensor<1x2x4x5xf32> +// CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG0]], %[[DQ]], batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2], precision = [DEFAULT, DEFAULT] : (tensor<1x2x3x4xf32>, tensor<1x2x4x5xf32>) -> tensor<1x2x3x5xf32> +// CHECK: return %[[DOT]] + +// ----- + +// Tests that a hybrid quantized convolution is splitted into dequantize and +// float convolution. + +// CHECK-LABEL: func @convolution_hybrid +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x3x4xf32> +func.func @convolution_hybrid(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x2xf32> { + %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2xf32> + return %1 : tensor<1x3x3x2xf32> +} + +// CHECK: %[[WEIGHT:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x3x4x2x!quant.uniform>, value = dense<3> : tensor<3x3x4x2xi8>} +// CHECK: %[[DQ:.+]] = "tfl.dequantize"(%[[WEIGHT]]) : (tensor<3x3x4x2x!quant.uniform>) -> tensor<3x3x4x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG0]], %[[DQ]]) +// CHECK{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} +// CHECK-SAME: (tensor<1x3x3x4xf32>, tensor<3x3x4x2xf32>) -> tensor<1x3x3x2xf32> +// CHECK: return %[[CONV]] diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc index 4148ef49f6604a..8fed8f3f01ed54 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc @@ -2120,12 +2120,45 @@ class RewriteQuantizedConstantOp } }; +// Splits dot-like hybrid quantized StableHLO ops into `tfl.dequantize` and +// float StableHLO op. Legalization of float StableHLO op depends on existing +// passes for conversion of StableHLO -> MHLO -> TF -> TFL. +template +class RewriteHybridQuantizedDotLikeOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(OpType op) const override { + if (op->getNumOperands() != 2 || op->getNumResults() != 1) { + return failure(); + } + // Lhs and result should not be quantized and rhs should be quantized. + return success(!IsQuantizedTensorType(op->getOperand(0).getType()) && + IsQuantizedTensorType(op->getOperand(1).getType()) && + !IsQuantizedTensorType(op->getResult(0).getType())); + } + + void rewrite(OpType op, PatternRewriter& rewriter) const override { + Value rhs = op.getOperand(1); + Type lhs_element_type = + op.getOperand(0).getType().template cast().getElementType(); + Type dequantized_rhs_type = + quant::CloneTypeWithNewElementType(rhs.getType(), lhs_element_type); + auto dq = rewriter.create( + op->getLoc(), /*output=*/dequantized_rhs_type, + /*input=*/rhs); + rewriter.replaceAllUsesExcept(rhs, dq.getOutput(), dq); + } +}; + void UniformQuantizedStableHloToTflPass::runOnOperation() { func::FuncOp func_op = getOperation(); MLIRContext& ctx = getContext(); RewritePatternSet patterns(&ctx); - patterns.add, + RewriteHybridQuantizedDotLikeOp, + RewriteUniformDequantizeOp, RewriteUniformQuantizeOp, RewriteQuantizedAddOp, RewriteQuantizedBroadcastInDimOp, RewriteQuantizedConcatenateOp, RewriteQuantizedConstantOp, RewriteQuantizedConvolutionOp, From bcf3638a87a7eb9c3d160546dce4b4242f5b0d79 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 29 Mar 2024 09:10:25 -0700 Subject: [PATCH 082/124] Deduplicate inferred mesh shapes when try_multiple_mesh_shapes=true. PiperOrigin-RevId: 620258542 --- .../auto_sharding/auto_sharding_util.cc | 43 ++++++++----------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 0cb674711c9b66..baa827febc909a 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -2199,29 +2199,6 @@ void EnumerateAllPossibleMeshShapesHelper( } } -std::vector> EnumerateAllPossibleMeshShapes( - const int64_t num_devices, int num_mesh_dims, bool symmetrical_mesh_dims) { - std::vector> result; - EnumerateAllPossibleMeshShapesHelper(num_devices, num_mesh_dims, {}, result); - - if (symmetrical_mesh_dims) { - absl::flat_hash_set> dedup_result; - for (const std::vector& mesh_shape : result) { - dedup_result.insert( - absl::btree_multiset(mesh_shape.begin(), mesh_shape.end())); - } - - result.clear(); - - for (const absl::btree_multiset& mesh_shape_set : dedup_result) { - result.push_back( - std::vector(mesh_shape_set.begin(), mesh_shape_set.end())); - } - } - - return result; -} - std::vector> InferMeshShapesToTry( const HloModule& module) { int64_t sharding_1d = -1; @@ -2280,10 +2257,24 @@ std::vector> InferOrEnumerateMeshShapesToTry( bool symmetrical_mesh_dims) { std::vector> mesh_shapes = InferMeshShapesToTry(module); if (mesh_shapes.empty()) { - mesh_shapes = spmd::EnumerateAllPossibleMeshShapes( - num_devices, num_mesh_dims, - /* symmetrical_mesh_dims */ symmetrical_mesh_dims); + EnumerateAllPossibleMeshShapesHelper(num_devices, num_mesh_dims, {}, + mesh_shapes); + } + if (symmetrical_mesh_dims) { + absl::flat_hash_set> dedup_result; + for (const std::vector& mesh_shape : mesh_shapes) { + dedup_result.insert( + absl::btree_multiset(mesh_shape.begin(), mesh_shape.end())); + } + + mesh_shapes.clear(); + + for (const absl::btree_multiset& mesh_shape_set : dedup_result) { + mesh_shapes.push_back( + std::vector(mesh_shape_set.begin(), mesh_shape_set.end())); + } } + return mesh_shapes; } From 370c98d1260e1ad5f883cb62fc5400d3031ac0ad Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 29 Mar 2024 09:17:53 -0700 Subject: [PATCH 083/124] Delete the redundant compilation_cache_test PiperOrigin-RevId: 620259968 --- third_party/xla/xla/tests/BUILD | 23 --- .../xla/xla/tests/compilation_cache_test.cc | 171 ------------------ 2 files changed, 194 deletions(-) delete mode 100644 third_party/xla/xla/tests/compilation_cache_test.cc diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 32788698bf4203..fca608158771f3 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -2181,29 +2181,6 @@ xla_test( ], ) -xla_test( - name = "compilation_cache_test", - srcs = ["compilation_cache_test.cc"], - deps = [ - ":client_library_test_base", - ":literal_test_util", - ":test_macros_header", - ":test_utils", - ":xla_internal_test_main", - "//xla:literal", - "//xla:shape_util", - "//xla:statusor", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/client:global_data", - "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:test", - ], -) - xla_test( name = "floor_ceil_test", srcs = ["floor_ceil_test.cc"], diff --git a/third_party/xla/xla/tests/compilation_cache_test.cc b/third_party/xla/xla/tests/compilation_cache_test.cc deleted file mode 100644 index 057015d5233696..00000000000000 --- a/third_party/xla/xla/tests/compilation_cache_test.cc +++ /dev/null @@ -1,171 +0,0 @@ -/* Copyright 2017 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "absl/types/span.h" -#include "xla/client/global_data.h" -#include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" -#include "xla/literal.h" -#include "xla/shape_util.h" -#include "xla/statusor.h" -#include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" -#include "xla/tests/test_macros.h" -#include "xla/tests/test_utils.h" -#include "xla/xla.pb.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" - -namespace xla { -namespace { - -class CompilationCacheTest : public ClientLibraryTestBase { - public: - void ExecuteComputationR0F32(const XlaComputation& computation, - absl::Span arguments, - float expected_result, bool expect_cache_hit) { - ExecutionProfile execution_profile; - Literal result = - client_ - ->ExecuteAndTransfer(computation, arguments, - /*execution_options=*/&execution_options_, - &execution_profile) - .value(); - EXPECT_TRUE(LiteralTestUtil::Near( - LiteralUtil::CreateR0(expected_result), result, error_spec_)); - EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); - } - - void ExecuteComputationR2F32( - const XlaComputation& computation, - absl::Span arguments, - std::initializer_list> expected_result, - bool expect_cache_hit) { - ExecutionProfile execution_profile; - auto data_handle = client_ - ->Execute(computation, arguments, - &execution_options_, &execution_profile) - .value(); - Literal result = client_->Transfer(*data_handle).value(); - EXPECT_TRUE(LiteralTestUtil::Near( - LiteralUtil::CreateR2(expected_result), result, error_spec_)); - EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); - } - - ErrorSpec error_spec_{0.0001}; -}; - -// TODO(b/74197823): Disabled because there is no cache in the new design. -XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) { - XlaBuilder builder(TestName()); - Neg(ConstantR0(&builder, 42.0)); - XlaComputation computation = builder.Build().value(); - - ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); - ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); -} - -// TODO(b/74197823): Disabled because there is no cache in the new design. -XLA_TEST_F(CompilationCacheTest, - DISABLED_ComputationCalledWithDifferentParameters) { - std::unique_ptr data_42 = - client_->TransferToServer(LiteralUtil::CreateR0(42.0f)).value(); - std::unique_ptr data_123 = - client_->TransferToServer(LiteralUtil::CreateR0(123.0f)).value(); - std::unique_ptr data_456 = - client_->TransferToServer(LiteralUtil::CreateR0(456.0f)).value(); - - XlaBuilder builder(TestName()); - Neg(Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param")); - XlaComputation computation = builder.Build().value(); - - ExecuteComputationR0F32(computation, {data_42.get()}, -42.0, - /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation, {data_123.get()}, -123.0, - /*expect_cache_hit=*/true); - ExecuteComputationR0F32(computation, {data_456.get()}, -456.0, - /*expect_cache_hit=*/true); - ExecuteComputationR0F32(computation, {data_42.get()}, -42.0, - /*expect_cache_hit=*/true); -} - -// TODO(b/74197823): Disabled because there is no cache in the new design. -XLA_TEST_F(CompilationCacheTest, DISABLED_MultipleComputations) { - XlaBuilder builder_neg(TestName() + "_neg"); - Neg(ConstantR0(&builder_neg, 42.0)); - XlaComputation computation_neg = builder_neg.Build().value(); - - XlaBuilder builder_exp(TestName() + "_exp"); - Exp(ConstantR0(&builder_exp, 1.0)); - XlaComputation computation_exp = builder_exp.Build().value(); - - XlaBuilder builder_add(TestName() + "_add"); - Add(ConstantR0(&builder_add, 2.0), - ConstantR0(&builder_add, 3.0)); - XlaComputation computation_add = builder_add.Build().value(); - - ExecuteComputationR0F32(computation_neg, {}, -42.0, - /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation_exp, {}, 2.7182817, - /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation_add, {}, 5.0, - /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation_neg, {}, -42.0, - /*expect_cache_hit=*/true); -} - -// TODO(b/74197823): Disabled because there is no cache in the new design. -XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) { - // Create two GlobalData arrays with the same shape but different - // layouts. Use these arrays as parameters to a simple computation. If the - // layout of the array changes then computation should be recompiled (cache - // miss). - auto rowmaj_array = LiteralUtil::CreateR2WithLayout( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0})); - auto rowmaj_handle = client_->TransferToServer(rowmaj_array).value(); - - auto colmaj_array = LiteralUtil::CreateR2WithLayout( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})); - auto colmaj_handle = client_->TransferToServer(colmaj_array).value(); - - XlaBuilder builder(TestName()); - Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); - XlaComputation computation = builder.Build().value(); - - ExecuteComputationR2F32(computation, {colmaj_handle.get()}, - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - /*expect_cache_hit=*/false); - ExecuteComputationR2F32(computation, {colmaj_handle.get()}, - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - /*expect_cache_hit=*/true); - ExecuteComputationR2F32(computation, {rowmaj_handle.get()}, - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - /*expect_cache_hit=*/false); - ExecuteComputationR2F32(computation, {rowmaj_handle.get()}, - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - /*expect_cache_hit=*/true); - ExecuteComputationR2F32(computation, {colmaj_handle.get()}, - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - /*expect_cache_hit=*/true); -} - -} // namespace -} // namespace xla From 210d15664debd6a0393a8c9698b217cd5b77e056 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 29 Mar 2024 09:19:31 -0700 Subject: [PATCH 084/124] Restore GOOGLE_CUDA guard in scoped_annotation.h PiperOrigin-RevId: 620260337 --- .../xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h index bfd222a81184e3..c41a2a39a8dc3a 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/scoped_annotation.h @@ -35,11 +35,13 @@ namespace tsl::profiler { // TraceCollector until PopAnnotation() is called. template void PushAnnotation(const T& generator) { +#if GOOGLE_CUDA if (auto domain = DefaultProfilerDomain(); TF_PREDICT_FALSE(domain != nullptr)) { RangePush(domain, generator()); return; } +#endif #if !defined(IS_MOBILE_PLATFORM) if (TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { @@ -60,11 +62,13 @@ inline void PopAnnotation() { // fail probably due to compiler in that presubmit config. std::atomic_thread_fence(std::memory_order_acquire); +#if GOOGLE_CUDA if (auto domain = DefaultProfilerDomain(); TF_PREDICT_FALSE(domain != nullptr)) { RangePop(domain); return; } +#endif #if !defined(IS_MOBILE_PLATFORM) if (TF_PREDICT_FALSE(AnnotationStack::IsEnabled())) { From b36854adf7405d4fc95cd6e6fec0b9b4cc057707 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Fri, 29 Mar 2024 10:00:03 -0700 Subject: [PATCH 085/124] [odml] Remove MHLO from CHLO->StableHLO lowering. Migrate JAX random lowering to StableHLO PiperOrigin-RevId: 620269527 --- tensorflow/compiler/mlir/lite/BUILD | 2 ++ .../mlir/lite/tests/legalize_jax_random.mlir | 20 +++++++++---------- .../compiler/mlir/lite/tf_tfl_passes.cc | 15 ++++++++------ .../lite/transforms/legalize_jax_random.cc | 6 +++--- .../compiler/mlir/lite/transforms/passes.td | 2 +- .../lite/transforms/quantize_variables.cc | 1 + 6 files changed, 26 insertions(+), 20 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index f535d0d1aaea2e..c3826f1bfb935c 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -747,6 +747,7 @@ cc_library( "@local_xla//xla:status", "@local_xla//xla:statusor", "@local_xla//xla/mlir_hlo", + "@stablehlo//:stablehlo_ops", ], ) @@ -886,6 +887,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@local_xla//xla/mlir_hlo", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/tensorflow/compiler/mlir/lite/tests/legalize_jax_random.mlir b/tensorflow/compiler/mlir/lite/tests/legalize_jax_random.mlir index d7d77f2e77a97b..76f453d1d3a8aa 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize_jax_random.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize_jax_random.mlir @@ -3,31 +3,31 @@ // CHECK-LABEL: func @tfl_wrapped_jax_random_normal( // CHECK-SAME: %[[RNG:.*]]: tensor<2xui32>) -> tuple> { -// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<[3, 4]> : tensor<2xi32> +// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<[3, 4]> : tensor<2xi32> // CHECK: %[[VAL_1:.*]] = "tfl.custom"(%[[VAL_0]]) {custom_code = "RandomStandardNormal", custom_option = #tfl} : (tensor<2xi32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_2:.*]] = mhlo.tuple %[[VAL_1]] : tuple> +// CHECK: %[[VAL_2:.*]] = stablehlo.tuple %[[VAL_1]] : tuple> // CHECK: return %[[VAL_2]] : tuple> // CHECK: } func.func @tfl_wrapped_jax_random_normal(%arg0: tensor<2xui32>) -> tuple> { // This is a fake jax random normal body. - %0 = mhlo.constant dense<0.0> : tensor<12xf32> - %1 = "mhlo.reshape"(%0) : (tensor<12xf32>) -> tensor<3x4xf32> - %2 = "mhlo.tuple"(%1) : (tensor<3x4xf32>) -> tuple> + %0 = stablehlo.constant dense<0.0> : tensor<12xf32> + %1 = "stablehlo.reshape"(%0) : (tensor<12xf32>) -> tensor<3x4xf32> + %2 = "stablehlo.tuple"(%1) : (tensor<3x4xf32>) -> tuple> func.return %2 : tuple> } // CHECK-LABEL: func @tfl_wrapped_jax_random_uniform( // CHECK-SAME: %[[RNG:.*]]: tensor<2xui32>) -> tuple> { -// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<[1, 2]> : tensor<2xi32> +// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<[1, 2]> : tensor<2xi32> // CHECK: %[[VAL_1:.*]] = "tfl.custom"(%[[VAL_0]]) {custom_code = "RandomUniform", custom_option = #tfl} : (tensor<2xi32>) -> tensor<1x2xf32> -// CHECK: %[[VAL_2:.*]] = mhlo.tuple %[[VAL_1]] : tuple> +// CHECK: %[[VAL_2:.*]] = stablehlo.tuple %[[VAL_1]] : tuple> // CHECK: return %[[VAL_2]] : tuple> // CHECK: } func.func @tfl_wrapped_jax_random_uniform(%arg0: tensor<2xui32>) -> tuple> { // This is a fake jax random uniform body. - %0 = mhlo.constant dense<0.0> : tensor<2xf32> - %1 = "mhlo.reshape"(%0) : (tensor<2xf32>) -> tensor<1x2xf32> - %2 = "mhlo.tuple"(%1) : (tensor<1x2xf32>) -> tuple> + %0 = stablehlo.constant dense<0.0> : tensor<2xf32> + %1 = "stablehlo.reshape"(%0) : (tensor<2xf32>) -> tensor<1x2xf32> + %2 = "stablehlo.tuple"(%1) : (tensor<1x2xf32>) -> tuple> func.return %2 : tuple> } diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 72abf68f852fb5..f4aa97069655e8 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -144,14 +144,17 @@ void AddPreQuantizationStableHloToTfPasses( pass_manager.addPass( mlir::odml::CreateLegalizeTFXlaCallModuleToStablehloPass()); - // Add CHLO to StableHLO Decompositions: - // This is needed since we are relying on XlaCallModule uses MHLO - // specific features like mhlo::ErfOp which aren't supported - // in StableHLO, but we have CHLO->StableHLO decompositions to legalize. + // Legalize MHLO to StableHLO should be moved closer to where it is needed + // There are some entry points that start with HLO->MHLO like + // jax_to_tfl_flatbuffer.cc which can likely be updated to emit StableHLO + // to be consistent with other entrypoints. pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + + // Decompose CHLO into StableHLO ops + // TODO(b/331843141): There are some CHLO's like TopK which we could instead + // lower to TFL ops. mlir::stablehlo::experimental::createChloLegalizeToStablehloPipeline( pass_manager); - pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); // The following two passes find specific uniform quantization patterns in // StableHLO and converts them to TFLite ops that accept or produce uniform @@ -168,7 +171,6 @@ void AddPreQuantizationStableHloToTfPasses( pass_manager.addNestedPass( mlir::odml::CreateUniformQuantizedStableHloToTflPass()); - pass_manager.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); // Legalize jax random to tflite custom op. // The CreateLegalizeJaxRandom Pass has to stay at because we need to replace // the random function body before being inlined. @@ -176,6 +178,7 @@ void AddPreQuantizationStableHloToTfPasses( mlir::TFL::CreateLegalizeJaxRandomPass()); // Canonicalize, CSE etc. + pass_manager.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); pass_manager.addNestedPass( mlir::createCanonicalizerPass()); pass_manager.addNestedPass(mlir::createCSEPass()); diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc index 72120f1502f021..e8bae6eb64280f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc @@ -47,10 +47,10 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { namespace TFL { @@ -99,7 +99,7 @@ void LegalizeJaxRandomPass::runOnOperation() { } auto result_shape_attr = builder.getI32TensorAttr(result_shape_i32); Value result_shape_tensor = - builder.create(result_shape_attr); + builder.create(result_shape_attr); auto custom_code = IsJaxRandomUniform(func) ? "RandomUniform" : "RandomStandardNormal"; @@ -112,7 +112,7 @@ void LegalizeJaxRandomPass::runOnOperation() { ValueRange(result_shape_tensor_vec), custom_code, attr) .getResult(0); - Value tulple_result = builder.create(random_result); + Value tulple_result = builder.create(random_result); builder.create(tulple_result); } } // namespace diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.td b/tensorflow/compiler/mlir/lite/transforms/passes.td index 988ad189a6ec00..eefb109d2b966e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/transforms/passes.td @@ -108,7 +108,7 @@ def LegalizeHashTablesPass : Pass<"tfl-legalize-hashtables-tf", "mlir::ModuleOp" def LegalizeJaxRandomPass : Pass<"tfl-legalize-random", "mlir::func::FuncOp"> { let summary = "Replace jax.random.uniform/normal with tfl.custom."; let constructor = "CreateLegalizeJaxRandomPass()"; - let dependentDialects = ["TFL::TensorFlowLiteDialect"]; + let dependentDialects = ["TFL::TensorFlowLiteDialect", "stablehlo::StablehloDialect"]; } def LegalizeTFPass : Pass<"tfl-legalize-tf", "mlir::func::FuncOp"> { diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc b/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc index 33580d1ea95dbc..0d9db051ef27ff 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" From bc4b9d84ac43580f33b235141a9d4abf1a85be33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Fri, 29 Mar 2024 10:07:30 -0700 Subject: [PATCH 086/124] Enhanced zeta readability based on the article Changes based on the Hurwitz Zeta algorithm from the article linked in the comments. PiperOrigin-RevId: 620272234 --- third_party/stablehlo/temporary.patch | 1181 +++++++++++++++++ .../xla/third_party/stablehlo/temporary.patch | 1181 +++++++++++++++++ .../Dialect/chlo/chlo_legalize_to_mhlo.mlir | 834 ++++++------ 3 files changed, 2779 insertions(+), 417 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 70d9744d6e8ae1..94971c07102a21 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -2645,4 +2645,1185 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloTrivialDce.cpp +} // namespace experimental +} // namespace stablehlo +} // namespace mlir +diff --ruN a/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir b/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +--- stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir ++++ stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +@@ -1283,153 +1283,153 @@ + func.func @zeta_f16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: %[[TMP_0:.*]] = stablehlo.convert %[[X]] : (tensor) -> tensor + // CHECK: %[[TMP_1:.*]] = stablehlo.convert %[[Q]] : (tensor) -> tensor +- // CHECK: %[[TMP_2:.*]] = stablehlo.constant dense<0.000000e+00> +- // CHECK: %[[TMP_3:.*]] = stablehlo.negate %[[TMP_0]] +- // CHECK: %[[TMP_4:.*]] = stablehlo.power %[[TMP_1]], %[[TMP_3]] +- // CHECK: %[[TMP_5:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_6:.*]] = stablehlo.add %[[TMP_1]], %[[TMP_5]] +- // CHECK: %[[TMP_7:.*]] = stablehlo.power %[[TMP_6]], %[[TMP_3]] +- // CHECK: %[[TMP_8:.*]] = stablehlo.add %[[TMP_4]], %[[TMP_7]] +- // CHECK: %[[TMP_9:.*]] = stablehlo.add %[[TMP_6]], %[[TMP_5]] +- // CHECK: %[[TMP_10:.*]] = stablehlo.power %[[TMP_9]], %[[TMP_3]] ++ // CHECK-DAG: %[[TMP_2:.*]] = stablehlo.constant dense<0.000000e+00> ++ // CHECK-DAG: %[[TMP_3:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_4:.*]] = stablehlo.negate %[[TMP_0]] ++ // CHECK: %[[TMP_5:.*]] = stablehlo.power %[[TMP_1]], %[[TMP_4]] ++ // CHECK: %[[TMP_6:.*]] = stablehlo.add %[[TMP_1]], %[[TMP_3]] ++ // CHECK: %[[TMP_7:.*]] = stablehlo.power %[[TMP_6]], %[[TMP_4]] ++ // CHECK: %[[TMP_8:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_7]] ++ // CHECK: %[[TMP_9:.*]] = stablehlo.add %[[TMP_6]], %[[TMP_3]] ++ // CHECK: %[[TMP_10:.*]] = stablehlo.power %[[TMP_9]], %[[TMP_4]] + // CHECK: %[[TMP_11:.*]] = stablehlo.add %[[TMP_8]], %[[TMP_10]] +- // CHECK: %[[TMP_12:.*]] = stablehlo.add %[[TMP_9]], %[[TMP_5]] +- // CHECK: %[[TMP_13:.*]] = stablehlo.power %[[TMP_12]], %[[TMP_3]] ++ // CHECK: %[[TMP_12:.*]] = stablehlo.add %[[TMP_9]], %[[TMP_3]] ++ // CHECK: %[[TMP_13:.*]] = stablehlo.power %[[TMP_12]], %[[TMP_4]] + // CHECK: %[[TMP_14:.*]] = stablehlo.add %[[TMP_11]], %[[TMP_13]] +- // CHECK: %[[TMP_15:.*]] = stablehlo.add %[[TMP_12]], %[[TMP_5]] +- // CHECK: %[[TMP_16:.*]] = stablehlo.power %[[TMP_15]], %[[TMP_3]] ++ // CHECK: %[[TMP_15:.*]] = stablehlo.add %[[TMP_12]], %[[TMP_3]] ++ // CHECK: %[[TMP_16:.*]] = stablehlo.power %[[TMP_15]], %[[TMP_4]] + // CHECK: %[[TMP_17:.*]] = stablehlo.add %[[TMP_14]], %[[TMP_16]] +- // CHECK: %[[TMP_18:.*]] = stablehlo.add %[[TMP_15]], %[[TMP_5]] +- // CHECK: %[[TMP_19:.*]] = stablehlo.power %[[TMP_18]], %[[TMP_3]] ++ // CHECK: %[[TMP_18:.*]] = stablehlo.add %[[TMP_15]], %[[TMP_3]] ++ // CHECK: %[[TMP_19:.*]] = stablehlo.power %[[TMP_18]], %[[TMP_4]] + // CHECK: %[[TMP_20:.*]] = stablehlo.add %[[TMP_17]], %[[TMP_19]] +- // CHECK: %[[TMP_21:.*]] = stablehlo.add %[[TMP_18]], %[[TMP_5]] +- // CHECK: %[[TMP_22:.*]] = stablehlo.power %[[TMP_21]], %[[TMP_3]] ++ // CHECK: %[[TMP_21:.*]] = stablehlo.add %[[TMP_18]], %[[TMP_3]] ++ // CHECK: %[[TMP_22:.*]] = stablehlo.power %[[TMP_21]], %[[TMP_4]] + // CHECK: %[[TMP_23:.*]] = stablehlo.add %[[TMP_20]], %[[TMP_22]] +- // CHECK: %[[TMP_24:.*]] = stablehlo.add %[[TMP_21]], %[[TMP_5]] +- // CHECK: %[[TMP_25:.*]] = stablehlo.power %[[TMP_24]], %[[TMP_3]] ++ // CHECK: %[[TMP_24:.*]] = stablehlo.add %[[TMP_21]], %[[TMP_3]] ++ // CHECK: %[[TMP_25:.*]] = stablehlo.power %[[TMP_24]], %[[TMP_4]] + // CHECK: %[[TMP_26:.*]] = stablehlo.add %[[TMP_23]], %[[TMP_25]] +- // CHECK: %[[TMP_27:.*]] = stablehlo.add %[[TMP_24]], %[[TMP_5]] +- // CHECK: %[[TMP_28:.*]] = stablehlo.power %[[TMP_27]], %[[TMP_3]] ++ // CHECK: %[[TMP_27:.*]] = stablehlo.add %[[TMP_24]], %[[TMP_3]] ++ // CHECK: %[[TMP_28:.*]] = stablehlo.power %[[TMP_27]], %[[TMP_4]] + // CHECK: %[[TMP_29:.*]] = stablehlo.add %[[TMP_26]], %[[TMP_28]] +- // CHECK: %[[TMP_30:.*]] = stablehlo.add %[[TMP_27]], %[[TMP_5]] +- // CHECK: %[[TMP_31:.*]] = stablehlo.power %[[TMP_30]], %[[TMP_3]] ++ // CHECK: %[[TMP_30:.*]] = stablehlo.add %[[TMP_27]], %[[TMP_3]] ++ // CHECK: %[[TMP_31:.*]] = stablehlo.power %[[TMP_30]], %[[TMP_4]] + // CHECK: %[[TMP_32:.*]] = stablehlo.add %[[TMP_29]], %[[TMP_31]] +- // CHECK: %[[TMP_33:.*]] = stablehlo.add %[[TMP_30]], %[[TMP_5]] +- // CHECK: %[[TMP_34:.*]] = stablehlo.power %[[TMP_33]], %[[TMP_3]] ++ // CHECK: %[[TMP_33:.*]] = stablehlo.add %[[TMP_30]], %[[TMP_3]] ++ // CHECK: %[[TMP_34:.*]] = stablehlo.power %[[TMP_33]], %[[TMP_4]] + // CHECK: %[[TMP_35:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_36:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_35]] +- // CHECK: %[[TMP_37:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_33]] +- // CHECK: %[[TMP_38:.*]] = stablehlo.divide %[[TMP_37]], %[[TMP_36]] +- // CHECK: %[[TMP_39:.*]] = stablehlo.add %[[TMP_32]], %[[TMP_38]] +- // CHECK: %[[TMP_40:.*]] = stablehlo.multiply %[[TMP_33]], %[[TMP_33]] +- // CHECK: %[[TMP_41:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_40]] +- // CHECK: %[[TMP_42:.*]] = stablehlo.constant dense<2.200000e+01> +- // CHECK: %[[TMP_43:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_42]] +- // CHECK: %[[TMP_44:.*]] = stablehlo.constant dense<2.100000e+01> +- // CHECK: %[[TMP_45:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_44]] +- // CHECK: %[[TMP_46:.*]] = stablehlo.multiply %[[TMP_43]], %[[TMP_45]] +- // CHECK: %[[TMP_47:.*]] = stablehlo.constant dense<-1.39544646E-19> +- // CHECK: %[[TMP_48:.*]] = stablehlo.add %[[TMP_2]], %[[TMP_47]] +- // CHECK: %[[TMP_49:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_48]] +- // CHECK: %[[TMP_50:.*]] = stablehlo.multiply %[[TMP_46]], %[[TMP_49]] +- // CHECK: %[[TMP_51:.*]] = stablehlo.constant dense<2.000000e+01> +- // CHECK: %[[TMP_52:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_51]] +- // CHECK: %[[TMP_53:.*]] = stablehlo.constant dense<1.900000e+01> +- // CHECK: %[[TMP_54:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_53]] +- // CHECK: %[[TMP_55:.*]] = stablehlo.multiply %[[TMP_52]], %[[TMP_54]] +- // CHECK: %[[TMP_56:.*]] = stablehlo.constant dense<5.50900303E-18> +- // CHECK: %[[TMP_57:.*]] = stablehlo.add %[[TMP_50]], %[[TMP_56]] +- // CHECK: %[[TMP_58:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_57]] +- // CHECK: %[[TMP_59:.*]] = stablehlo.multiply %[[TMP_55]], %[[TMP_58]] +- // CHECK: %[[TMP_60:.*]] = stablehlo.constant dense<1.800000e+01> +- // CHECK: %[[TMP_61:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_60]] +- // CHECK: %[[TMP_62:.*]] = stablehlo.constant dense<1.700000e+01> +- // CHECK: %[[TMP_63:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_62]] +- // CHECK: %[[TMP_64:.*]] = stablehlo.multiply %[[TMP_61]], %[[TMP_63]] +- // CHECK: %[[TMP_65:.*]] = stablehlo.constant dense<-2.17486866E-16> +- // CHECK: %[[TMP_66:.*]] = stablehlo.add %[[TMP_59]], %[[TMP_65]] +- // CHECK: %[[TMP_67:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_66]] +- // CHECK: %[[TMP_68:.*]] = stablehlo.multiply %[[TMP_64]], %[[TMP_67]] +- // CHECK: %[[TMP_69:.*]] = stablehlo.constant dense<1.600000e+01> +- // CHECK: %[[TMP_70:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_69]] +- // CHECK: %[[TMP_71:.*]] = stablehlo.constant dense<1.500000e+01> +- // CHECK: %[[TMP_72:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_71]] +- // CHECK: %[[TMP_73:.*]] = stablehlo.multiply %[[TMP_70]], %[[TMP_72]] +- // CHECK: %[[TMP_74:.*]] = stablehlo.constant dense<8.58606213E-15> +- // CHECK: %[[TMP_75:.*]] = stablehlo.add %[[TMP_68]], %[[TMP_74]] +- // CHECK: %[[TMP_76:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_75]] +- // CHECK: %[[TMP_77:.*]] = stablehlo.multiply %[[TMP_73]], %[[TMP_76]] +- // CHECK: %[[TMP_78:.*]] = stablehlo.constant dense<1.400000e+01> +- // CHECK: %[[TMP_79:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_78]] +- // CHECK: %[[TMP_80:.*]] = stablehlo.constant dense<1.300000e+01> +- // CHECK: %[[TMP_81:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_80]] +- // CHECK: %[[TMP_82:.*]] = stablehlo.multiply %[[TMP_79]], %[[TMP_81]] +- // CHECK: %[[TMP_83:.*]] = stablehlo.constant dense<-3.3896803E-13> +- // CHECK: %[[TMP_84:.*]] = stablehlo.add %[[TMP_77]], %[[TMP_83]] +- // CHECK: %[[TMP_85:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_84]] +- // CHECK: %[[TMP_86:.*]] = stablehlo.multiply %[[TMP_82]], %[[TMP_85]] +- // CHECK: %[[TMP_87:.*]] = stablehlo.constant dense<1.200000e+01> +- // CHECK: %[[TMP_88:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_87]] +- // CHECK: %[[TMP_89:.*]] = stablehlo.constant dense<1.100000e+01> +- // CHECK: %[[TMP_90:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_89]] +- // CHECK: %[[TMP_91:.*]] = stablehlo.multiply %[[TMP_88]], %[[TMP_90]] +- // CHECK: %[[TMP_92:.*]] = stablehlo.constant dense<1.33825364E-11> +- // CHECK: %[[TMP_93:.*]] = stablehlo.add %[[TMP_86]], %[[TMP_92]] +- // CHECK: %[[TMP_94:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_93]] +- // CHECK: %[[TMP_95:.*]] = stablehlo.multiply %[[TMP_91]], %[[TMP_94]] +- // CHECK: %[[TMP_96:.*]] = stablehlo.constant dense<1.000000e+01> +- // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_96]] +- // CHECK: %[[TMP_98:.*]] = stablehlo.constant dense<9.000000e+00> +- // CHECK: %[[TMP_99:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_98]] +- // CHECK: %[[TMP_100:.*]] = stablehlo.multiply %[[TMP_97]], %[[TMP_99]] +- // CHECK: %[[TMP_101:.*]] = stablehlo.constant dense<-5.28419031E-10> +- // CHECK: %[[TMP_102:.*]] = stablehlo.add %[[TMP_95]], %[[TMP_101]] +- // CHECK: %[[TMP_103:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_102]] +- // CHECK: %[[TMP_104:.*]] = stablehlo.multiply %[[TMP_100]], %[[TMP_103]] +- // CHECK: %[[TMP_105:.*]] = stablehlo.constant dense<8.000000e+00> +- // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_105]] +- // CHECK: %[[TMP_107:.*]] = stablehlo.constant dense<7.000000e+00> +- // CHECK: %[[TMP_108:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_107]] +- // CHECK: %[[TMP_109:.*]] = stablehlo.multiply %[[TMP_106]], %[[TMP_108]] +- // CHECK: %[[TMP_110:.*]] = stablehlo.constant dense<2.08767563E-8> +- // CHECK: %[[TMP_111:.*]] = stablehlo.add %[[TMP_104]], %[[TMP_110]] +- // CHECK: %[[TMP_112:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_111]] +- // CHECK: %[[TMP_113:.*]] = stablehlo.multiply %[[TMP_109]], %[[TMP_112]] +- // CHECK: %[[TMP_114:.*]] = stablehlo.constant dense<6.000000e+00> +- // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_114]] +- // CHECK: %[[TMP_116:.*]] = stablehlo.constant dense<5.000000e+00> +- // CHECK: %[[TMP_117:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_116]] +- // CHECK: %[[TMP_118:.*]] = stablehlo.multiply %[[TMP_115]], %[[TMP_117]] +- // CHECK: %[[TMP_119:.*]] = stablehlo.constant dense<-8.26719599E-7> +- // CHECK: %[[TMP_120:.*]] = stablehlo.add %[[TMP_113]], %[[TMP_119]] +- // CHECK: %[[TMP_121:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_120]] +- // CHECK: %[[TMP_122:.*]] = stablehlo.multiply %[[TMP_118]], %[[TMP_121]] +- // CHECK: %[[TMP_123:.*]] = stablehlo.constant dense<4.000000e+00> +- // CHECK: %[[TMP_124:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_123]] +- // CHECK: %[[TMP_125:.*]] = stablehlo.constant dense<3.000000e+00> +- // CHECK: %[[TMP_126:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_125]] +- // CHECK: %[[TMP_127:.*]] = stablehlo.multiply %[[TMP_124]], %[[TMP_126]] +- // CHECK: %[[TMP_128:.*]] = stablehlo.constant dense<3.30687835E-5> +- // CHECK: %[[TMP_129:.*]] = stablehlo.add %[[TMP_122]], %[[TMP_128]] +- // CHECK: %[[TMP_130:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_129]] +- // CHECK: %[[TMP_131:.*]] = stablehlo.multiply %[[TMP_127]], %[[TMP_130]] +- // CHECK: %[[TMP_132:.*]] = stablehlo.constant dense<2.000000e+00> +- // CHECK: %[[TMP_133:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_132]] +- // CHECK: %[[TMP_134:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_135:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_134]] +- // CHECK: %[[TMP_136:.*]] = stablehlo.multiply %[[TMP_133]], %[[TMP_135]] +- // CHECK: %[[TMP_137:.*]] = stablehlo.constant dense<-0.00138888892> +- // CHECK: %[[TMP_138:.*]] = stablehlo.add %[[TMP_131]], %[[TMP_137]] +- // CHECK: %[[TMP_139:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_138]] +- // CHECK: %[[TMP_140:.*]] = stablehlo.multiply %[[TMP_136]], %[[TMP_139]] +- // CHECK: %[[TMP_141:.*]] = stablehlo.constant dense<5.000000e-01> +- // CHECK: %[[TMP_142:.*]] = stablehlo.divide %[[TMP_0]], %[[TMP_33]] +- // CHECK: %[[TMP_143:.*]] = stablehlo.constant dense<0.0833333358> +- // CHECK: %[[TMP_144:.*]] = stablehlo.add %[[TMP_143]], %[[TMP_140]] +- // CHECK: %[[TMP_145:.*]] = stablehlo.multiply %[[TMP_142]], %[[TMP_144]] +- // CHECK: %[[TMP_146:.*]] = stablehlo.add %[[TMP_141]], %[[TMP_145]] +- // CHECK: %[[TMP_147:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_146]] +- // CHECK: %[[TMP_148:.*]] = stablehlo.add %[[TMP_39]], %[[TMP_147]] ++ // CHECK: %[[TMP_36:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_33]] ++ // CHECK: %[[TMP_37:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_35]] ++ // CHECK: %[[TMP_38:.*]] = stablehlo.divide %[[TMP_36]], %[[TMP_37]] ++ // CHECK: %[[TMP_39:.*]] = stablehlo.multiply %[[TMP_33]], %[[TMP_33]] ++ // CHECK: %[[TMP_40:.*]] = stablehlo.divide %[[TMP_3]], %[[TMP_39]] ++ // CHECK: %[[TMP_41:.*]] = stablehlo.constant dense<2.200000e+01> ++ // CHECK: %[[TMP_42:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_41]] ++ // CHECK: %[[TMP_43:.*]] = stablehlo.constant dense<2.100000e+01> ++ // CHECK: %[[TMP_44:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_43]] ++ // CHECK: %[[TMP_45:.*]] = stablehlo.multiply %[[TMP_42]], %[[TMP_44]] ++ // CHECK: %[[TMP_46:.*]] = stablehlo.constant dense<-1.39544646E-19> ++ // CHECK: %[[TMP_47:.*]] = stablehlo.add %[[TMP_2]], %[[TMP_46]] ++ // CHECK: %[[TMP_48:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_47]] ++ // CHECK: %[[TMP_49:.*]] = stablehlo.multiply %[[TMP_45]], %[[TMP_48]] ++ // CHECK: %[[TMP_50:.*]] = stablehlo.constant dense<2.000000e+01> ++ // CHECK: %[[TMP_51:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_50]] ++ // CHECK: %[[TMP_52:.*]] = stablehlo.constant dense<1.900000e+01> ++ // CHECK: %[[TMP_53:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_52]] ++ // CHECK: %[[TMP_54:.*]] = stablehlo.multiply %[[TMP_51]], %[[TMP_53]] ++ // CHECK: %[[TMP_55:.*]] = stablehlo.constant dense<5.50900303E-18> ++ // CHECK: %[[TMP_56:.*]] = stablehlo.add %[[TMP_49]], %[[TMP_55]] ++ // CHECK: %[[TMP_57:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_56]] ++ // CHECK: %[[TMP_58:.*]] = stablehlo.multiply %[[TMP_54]], %[[TMP_57]] ++ // CHECK: %[[TMP_59:.*]] = stablehlo.constant dense<1.800000e+01> ++ // CHECK: %[[TMP_60:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_59]] ++ // CHECK: %[[TMP_61:.*]] = stablehlo.constant dense<1.700000e+01> ++ // CHECK: %[[TMP_62:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_61]] ++ // CHECK: %[[TMP_63:.*]] = stablehlo.multiply %[[TMP_60]], %[[TMP_62]] ++ // CHECK: %[[TMP_64:.*]] = stablehlo.constant dense<-2.17486866E-16> ++ // CHECK: %[[TMP_65:.*]] = stablehlo.add %[[TMP_58]], %[[TMP_64]] ++ // CHECK: %[[TMP_66:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_65]] ++ // CHECK: %[[TMP_67:.*]] = stablehlo.multiply %[[TMP_63]], %[[TMP_66]] ++ // CHECK: %[[TMP_68:.*]] = stablehlo.constant dense<1.600000e+01> ++ // CHECK: %[[TMP_69:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_68]] ++ // CHECK: %[[TMP_70:.*]] = stablehlo.constant dense<1.500000e+01> ++ // CHECK: %[[TMP_71:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_70]] ++ // CHECK: %[[TMP_72:.*]] = stablehlo.multiply %[[TMP_69]], %[[TMP_71]] ++ // CHECK: %[[TMP_73:.*]] = stablehlo.constant dense<8.58606213E-15> ++ // CHECK: %[[TMP_74:.*]] = stablehlo.add %[[TMP_67]], %[[TMP_73]] ++ // CHECK: %[[TMP_75:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_74]] ++ // CHECK: %[[TMP_76:.*]] = stablehlo.multiply %[[TMP_72]], %[[TMP_75]] ++ // CHECK: %[[TMP_77:.*]] = stablehlo.constant dense<1.400000e+01> ++ // CHECK: %[[TMP_78:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_77]] ++ // CHECK: %[[TMP_79:.*]] = stablehlo.constant dense<1.300000e+01> ++ // CHECK: %[[TMP_80:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_79]] ++ // CHECK: %[[TMP_81:.*]] = stablehlo.multiply %[[TMP_78]], %[[TMP_80]] ++ // CHECK: %[[TMP_82:.*]] = stablehlo.constant dense<-3.3896803E-13> ++ // CHECK: %[[TMP_83:.*]] = stablehlo.add %[[TMP_76]], %[[TMP_82]] ++ // CHECK: %[[TMP_84:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_83]] ++ // CHECK: %[[TMP_85:.*]] = stablehlo.multiply %[[TMP_81]], %[[TMP_84]] ++ // CHECK: %[[TMP_86:.*]] = stablehlo.constant dense<1.200000e+01> ++ // CHECK: %[[TMP_87:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_86]] ++ // CHECK: %[[TMP_88:.*]] = stablehlo.constant dense<1.100000e+01> ++ // CHECK: %[[TMP_89:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_88]] ++ // CHECK: %[[TMP_90:.*]] = stablehlo.multiply %[[TMP_87]], %[[TMP_89]] ++ // CHECK: %[[TMP_91:.*]] = stablehlo.constant dense<1.33825364E-11> ++ // CHECK: %[[TMP_92:.*]] = stablehlo.add %[[TMP_85]], %[[TMP_91]] ++ // CHECK: %[[TMP_93:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_92]] ++ // CHECK: %[[TMP_94:.*]] = stablehlo.multiply %[[TMP_90]], %[[TMP_93]] ++ // CHECK: %[[TMP_95:.*]] = stablehlo.constant dense<1.000000e+01> ++ // CHECK: %[[TMP_96:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_95]] ++ // CHECK: %[[TMP_97:.*]] = stablehlo.constant dense<9.000000e+00> ++ // CHECK: %[[TMP_98:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_97]] ++ // CHECK: %[[TMP_99:.*]] = stablehlo.multiply %[[TMP_96]], %[[TMP_98]] ++ // CHECK: %[[TMP_100:.*]] = stablehlo.constant dense<-5.28419031E-10> ++ // CHECK: %[[TMP_101:.*]] = stablehlo.add %[[TMP_94]], %[[TMP_100]] ++ // CHECK: %[[TMP_102:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_101]] ++ // CHECK: %[[TMP_103:.*]] = stablehlo.multiply %[[TMP_99]], %[[TMP_102]] ++ // CHECK: %[[TMP_104:.*]] = stablehlo.constant dense<8.000000e+00> ++ // CHECK: %[[TMP_105:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_104]] ++ // CHECK: %[[TMP_106:.*]] = stablehlo.constant dense<7.000000e+00> ++ // CHECK: %[[TMP_107:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_106]] ++ // CHECK: %[[TMP_108:.*]] = stablehlo.multiply %[[TMP_105]], %[[TMP_107]] ++ // CHECK: %[[TMP_109:.*]] = stablehlo.constant dense<2.08767563E-8> ++ // CHECK: %[[TMP_110:.*]] = stablehlo.add %[[TMP_103]], %[[TMP_109]] ++ // CHECK: %[[TMP_111:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_110]] ++ // CHECK: %[[TMP_112:.*]] = stablehlo.multiply %[[TMP_108]], %[[TMP_111]] ++ // CHECK: %[[TMP_113:.*]] = stablehlo.constant dense<6.000000e+00> ++ // CHECK: %[[TMP_114:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_113]] ++ // CHECK: %[[TMP_115:.*]] = stablehlo.constant dense<5.000000e+00> ++ // CHECK: %[[TMP_116:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_115]] ++ // CHECK: %[[TMP_117:.*]] = stablehlo.multiply %[[TMP_114]], %[[TMP_116]] ++ // CHECK: %[[TMP_118:.*]] = stablehlo.constant dense<-8.26719599E-7> ++ // CHECK: %[[TMP_119:.*]] = stablehlo.add %[[TMP_112]], %[[TMP_118]] ++ // CHECK: %[[TMP_120:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_119]] ++ // CHECK: %[[TMP_121:.*]] = stablehlo.multiply %[[TMP_117]], %[[TMP_120]] ++ // CHECK: %[[TMP_122:.*]] = stablehlo.constant dense<4.000000e+00> ++ // CHECK: %[[TMP_123:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_122]] ++ // CHECK: %[[TMP_124:.*]] = stablehlo.constant dense<3.000000e+00> ++ // CHECK: %[[TMP_125:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_124]] ++ // CHECK: %[[TMP_126:.*]] = stablehlo.multiply %[[TMP_123]], %[[TMP_125]] ++ // CHECK: %[[TMP_127:.*]] = stablehlo.constant dense<3.30687835E-5> ++ // CHECK: %[[TMP_128:.*]] = stablehlo.add %[[TMP_121]], %[[TMP_127]] ++ // CHECK: %[[TMP_129:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_128]] ++ // CHECK: %[[TMP_130:.*]] = stablehlo.multiply %[[TMP_126]], %[[TMP_129]] ++ // CHECK: %[[TMP_131:.*]] = stablehlo.constant dense<2.000000e+00> ++ // CHECK: %[[TMP_132:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_131]] ++ // CHECK: %[[TMP_133:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_134:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_133]] ++ // CHECK: %[[TMP_135:.*]] = stablehlo.multiply %[[TMP_132]], %[[TMP_134]] ++ // CHECK: %[[TMP_136:.*]] = stablehlo.constant dense<-0.00138888892> ++ // CHECK: %[[TMP_137:.*]] = stablehlo.add %[[TMP_130]], %[[TMP_136]] ++ // CHECK: %[[TMP_138:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_137]] ++ // CHECK: %[[TMP_139:.*]] = stablehlo.multiply %[[TMP_135]], %[[TMP_138]] ++ // CHECK: %[[TMP_140:.*]] = stablehlo.constant dense<5.000000e-01> ++ // CHECK: %[[TMP_141:.*]] = stablehlo.divide %[[TMP_0]], %[[TMP_33]] ++ // CHECK: %[[TMP_142:.*]] = stablehlo.constant dense<0.0833333358> ++ // CHECK: %[[TMP_143:.*]] = stablehlo.add %[[TMP_142]], %[[TMP_139]] ++ // CHECK: %[[TMP_144:.*]] = stablehlo.multiply %[[TMP_141]], %[[TMP_143]] ++ // CHECK: %[[TMP_145:.*]] = stablehlo.add %[[TMP_140]], %[[TMP_144]] ++ // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_145]] ++ // CHECK: %[[TMP_147:.*]] = stablehlo.add %[[TMP_32]], %[[TMP_38]] ++ // CHECK: %[[TMP_148:.*]] = stablehlo.add %[[TMP_147]], %[[TMP_146]] + // CHECK: %[[TMP_149:.*]] = stablehlo.abs %[[TMP_34]] + // CHECK: %[[TMP_150:.*]] = stablehlo.abs %[[TMP_32]] + // CHECK: %[[TMP_151:.*]] = stablehlo.constant dense<1.401300e-45> +@@ -1456,7 +1456,7 @@ + // CHECK: %[[TMP_172:.*]] = stablehlo.and %[[TMP_169]], %[[TMP_171]] : tensor + // CHECK: %[[TMP_173:.*]] = stablehlo.select %[[TMP_172]], %[[TMP_163]], %[[TMP_155]] + // CHECK: %[[TMP_174:.*]] = stablehlo.select %[[TMP_166]], %[[TMP_173]], %[[TMP_162]] +- // CHECK: %[[TMP_175:.*]] = stablehlo.compare EQ, %[[TMP_0]], %[[TMP_5]], NOTYPE ++ // CHECK: %[[TMP_175:.*]] = stablehlo.compare EQ, %[[TMP_0]], %[[TMP_3]], NOTYPE + // CHECK: %[[TMP_176:.*]] = stablehlo.select %[[TMP_175]], %[[TMP_163]], %[[TMP_174]] + // CHECK: %[[TMP_177:.*]] = stablehlo.convert %[[TMP_176]] : (tensor) -> tensor + %0 = chlo.zeta %arg0, %arg1 : tensor, tensor -> tensor +@@ -1465,8 +1465,7 @@ + + // ----- + +- +-// CHECK-LABEL: @polygamma_f32 ++// CHECK: @polygamma_f32 + // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) + func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor { + // CHECK-DAG: %[[TMP_0:.*]] = stablehlo.constant dense<1.000000e+00> +@@ -1559,153 +1558,153 @@ + // CHECK: %[[TMP_87:.*]] = stablehlo.constant dense<0x7F800000> + // CHECK: %[[TMP_88:.*]] = stablehlo.select %[[TMP_86]], %[[TMP_87]], %[[TMP_83]] + // CHECK: %[[TMP_89:.*]] = stablehlo.exponential %[[TMP_88]] +- // CHECK: %[[TMP_90:.*]] = stablehlo.constant dense<0.000000e+00> +- // CHECK: %[[TMP_91:.*]] = stablehlo.negate %[[TMP_5]] +- // CHECK: %[[TMP_92:.*]] = stablehlo.power %[[ARG1]], %[[TMP_91]] +- // CHECK: %[[TMP_93:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_94:.*]] = stablehlo.add %[[ARG1]], %[[TMP_93]] +- // CHECK: %[[TMP_95:.*]] = stablehlo.power %[[TMP_94]], %[[TMP_91]] +- // CHECK: %[[TMP_96:.*]] = stablehlo.add %[[TMP_92]], %[[TMP_95]] +- // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_94]], %[[TMP_93]] +- // CHECK: %[[TMP_98:.*]] = stablehlo.power %[[TMP_97]], %[[TMP_91]] ++ // CHECK-DAG: %[[TMP_90:.*]] = stablehlo.constant dense<0.000000e+00> ++ // CHECK-DAG: %[[TMP_91:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_92:.*]] = stablehlo.negate %[[TMP_5]] ++ // CHECK: %[[TMP_93:.*]] = stablehlo.power %[[ARG1]], %[[TMP_92]] ++ // CHECK: %[[TMP_94:.*]] = stablehlo.add %[[ARG1]], %[[TMP_91]] ++ // CHECK: %[[TMP_95:.*]] = stablehlo.power %[[TMP_94]], %[[TMP_92]] ++ // CHECK: %[[TMP_96:.*]] = stablehlo.add %[[TMP_93]], %[[TMP_95]] ++ // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_94]], %[[TMP_91]] ++ // CHECK: %[[TMP_98:.*]] = stablehlo.power %[[TMP_97]], %[[TMP_92]] + // CHECK: %[[TMP_99:.*]] = stablehlo.add %[[TMP_96]], %[[TMP_98]] +- // CHECK: %[[TMP_100:.*]] = stablehlo.add %[[TMP_97]], %[[TMP_93]] +- // CHECK: %[[TMP_101:.*]] = stablehlo.power %[[TMP_100]], %[[TMP_91]] ++ // CHECK: %[[TMP_100:.*]] = stablehlo.add %[[TMP_97]], %[[TMP_91]] ++ // CHECK: %[[TMP_101:.*]] = stablehlo.power %[[TMP_100]], %[[TMP_92]] + // CHECK: %[[TMP_102:.*]] = stablehlo.add %[[TMP_99]], %[[TMP_101]] +- // CHECK: %[[TMP_103:.*]] = stablehlo.add %[[TMP_100]], %[[TMP_93]] +- // CHECK: %[[TMP_104:.*]] = stablehlo.power %[[TMP_103]], %[[TMP_91]] ++ // CHECK: %[[TMP_103:.*]] = stablehlo.add %[[TMP_100]], %[[TMP_91]] ++ // CHECK: %[[TMP_104:.*]] = stablehlo.power %[[TMP_103]], %[[TMP_92]] + // CHECK: %[[TMP_105:.*]] = stablehlo.add %[[TMP_102]], %[[TMP_104]] +- // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_103]], %[[TMP_93]] +- // CHECK: %[[TMP_107:.*]] = stablehlo.power %[[TMP_106]], %[[TMP_91]] ++ // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_103]], %[[TMP_91]] ++ // CHECK: %[[TMP_107:.*]] = stablehlo.power %[[TMP_106]], %[[TMP_92]] + // CHECK: %[[TMP_108:.*]] = stablehlo.add %[[TMP_105]], %[[TMP_107]] +- // CHECK: %[[TMP_109:.*]] = stablehlo.add %[[TMP_106]], %[[TMP_93]] +- // CHECK: %[[TMP_110:.*]] = stablehlo.power %[[TMP_109]], %[[TMP_91]] ++ // CHECK: %[[TMP_109:.*]] = stablehlo.add %[[TMP_106]], %[[TMP_91]] ++ // CHECK: %[[TMP_110:.*]] = stablehlo.power %[[TMP_109]], %[[TMP_92]] + // CHECK: %[[TMP_111:.*]] = stablehlo.add %[[TMP_108]], %[[TMP_110]] +- // CHECK: %[[TMP_112:.*]] = stablehlo.add %[[TMP_109]], %[[TMP_93]] +- // CHECK: %[[TMP_113:.*]] = stablehlo.power %[[TMP_112]], %[[TMP_91]] ++ // CHECK: %[[TMP_112:.*]] = stablehlo.add %[[TMP_109]], %[[TMP_91]] ++ // CHECK: %[[TMP_113:.*]] = stablehlo.power %[[TMP_112]], %[[TMP_92]] + // CHECK: %[[TMP_114:.*]] = stablehlo.add %[[TMP_111]], %[[TMP_113]] +- // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_112]], %[[TMP_93]] +- // CHECK: %[[TMP_116:.*]] = stablehlo.power %[[TMP_115]], %[[TMP_91]] ++ // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_112]], %[[TMP_91]] ++ // CHECK: %[[TMP_116:.*]] = stablehlo.power %[[TMP_115]], %[[TMP_92]] + // CHECK: %[[TMP_117:.*]] = stablehlo.add %[[TMP_114]], %[[TMP_116]] +- // CHECK: %[[TMP_118:.*]] = stablehlo.add %[[TMP_115]], %[[TMP_93]] +- // CHECK: %[[TMP_119:.*]] = stablehlo.power %[[TMP_118]], %[[TMP_91]] ++ // CHECK: %[[TMP_118:.*]] = stablehlo.add %[[TMP_115]], %[[TMP_91]] ++ // CHECK: %[[TMP_119:.*]] = stablehlo.power %[[TMP_118]], %[[TMP_92]] + // CHECK: %[[TMP_120:.*]] = stablehlo.add %[[TMP_117]], %[[TMP_119]] +- // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_93]] +- // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_91]] ++ // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_91]] ++ // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_92]] + // CHECK: %[[TMP_123:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_124:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] +- // CHECK: %[[TMP_125:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] +- // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_125]], %[[TMP_124]] +- // CHECK: %[[TMP_127:.*]] = stablehlo.add %[[TMP_120]], %[[TMP_126]] +- // CHECK: %[[TMP_128:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] +- // CHECK: %[[TMP_129:.*]] = stablehlo.divide %[[TMP_93]], %[[TMP_128]] +- // CHECK: %[[TMP_130:.*]] = stablehlo.constant dense<2.200000e+01> +- // CHECK: %[[TMP_131:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_130]] +- // CHECK: %[[TMP_132:.*]] = stablehlo.constant dense<2.100000e+01> +- // CHECK: %[[TMP_133:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_132]] +- // CHECK: %[[TMP_134:.*]] = stablehlo.multiply %[[TMP_131]], %[[TMP_133]] +- // CHECK: %[[TMP_135:.*]] = stablehlo.constant dense<-1.39544646E-19> +- // CHECK: %[[TMP_136:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_135]] +- // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_136]] +- // CHECK: %[[TMP_138:.*]] = stablehlo.multiply %[[TMP_134]], %[[TMP_137]] +- // CHECK: %[[TMP_139:.*]] = stablehlo.constant dense<2.000000e+01> +- // CHECK: %[[TMP_140:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_139]] +- // CHECK: %[[TMP_141:.*]] = stablehlo.constant dense<1.900000e+01> +- // CHECK: %[[TMP_142:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_141]] +- // CHECK: %[[TMP_143:.*]] = stablehlo.multiply %[[TMP_140]], %[[TMP_142]] +- // CHECK: %[[TMP_144:.*]] = stablehlo.constant dense<5.50900303E-18> +- // CHECK: %[[TMP_145:.*]] = stablehlo.add %[[TMP_138]], %[[TMP_144]] +- // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_145]] +- // CHECK: %[[TMP_147:.*]] = stablehlo.multiply %[[TMP_143]], %[[TMP_146]] +- // CHECK: %[[TMP_148:.*]] = stablehlo.constant dense<1.800000e+01> +- // CHECK: %[[TMP_149:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_148]] +- // CHECK: %[[TMP_150:.*]] = stablehlo.constant dense<1.700000e+01> +- // CHECK: %[[TMP_151:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_150]] +- // CHECK: %[[TMP_152:.*]] = stablehlo.multiply %[[TMP_149]], %[[TMP_151]] +- // CHECK: %[[TMP_153:.*]] = stablehlo.constant dense<-2.17486866E-16> +- // CHECK: %[[TMP_154:.*]] = stablehlo.add %[[TMP_147]], %[[TMP_153]] +- // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_154]] +- // CHECK: %[[TMP_156:.*]] = stablehlo.multiply %[[TMP_152]], %[[TMP_155]] +- // CHECK: %[[TMP_157:.*]] = stablehlo.constant dense<1.600000e+01> +- // CHECK: %[[TMP_158:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_157]] +- // CHECK: %[[TMP_159:.*]] = stablehlo.constant dense<1.500000e+01> +- // CHECK: %[[TMP_160:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_159]] +- // CHECK: %[[TMP_161:.*]] = stablehlo.multiply %[[TMP_158]], %[[TMP_160]] +- // CHECK: %[[TMP_162:.*]] = stablehlo.constant dense<8.58606213E-15> +- // CHECK: %[[TMP_163:.*]] = stablehlo.add %[[TMP_156]], %[[TMP_162]] +- // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_163]] +- // CHECK: %[[TMP_165:.*]] = stablehlo.multiply %[[TMP_161]], %[[TMP_164]] +- // CHECK: %[[TMP_166:.*]] = stablehlo.constant dense<1.400000e+01> +- // CHECK: %[[TMP_167:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_166]] +- // CHECK: %[[TMP_168:.*]] = stablehlo.constant dense<1.300000e+01> +- // CHECK: %[[TMP_169:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_168]] +- // CHECK: %[[TMP_170:.*]] = stablehlo.multiply %[[TMP_167]], %[[TMP_169]] +- // CHECK: %[[TMP_171:.*]] = stablehlo.constant dense<-3.3896803E-13> +- // CHECK: %[[TMP_172:.*]] = stablehlo.add %[[TMP_165]], %[[TMP_171]] +- // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_172]] +- // CHECK: %[[TMP_174:.*]] = stablehlo.multiply %[[TMP_170]], %[[TMP_173]] +- // CHECK: %[[TMP_175:.*]] = stablehlo.constant dense<1.200000e+01> +- // CHECK: %[[TMP_176:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_175]] +- // CHECK: %[[TMP_177:.*]] = stablehlo.constant dense<1.100000e+01> +- // CHECK: %[[TMP_178:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_177]] +- // CHECK: %[[TMP_179:.*]] = stablehlo.multiply %[[TMP_176]], %[[TMP_178]] +- // CHECK: %[[TMP_180:.*]] = stablehlo.constant dense<1.33825364E-11> +- // CHECK: %[[TMP_181:.*]] = stablehlo.add %[[TMP_174]], %[[TMP_180]] +- // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_181]] +- // CHECK: %[[TMP_183:.*]] = stablehlo.multiply %[[TMP_179]], %[[TMP_182]] +- // CHECK: %[[TMP_184:.*]] = stablehlo.constant dense<1.000000e+01> +- // CHECK: %[[TMP_185:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_184]] +- // CHECK: %[[TMP_186:.*]] = stablehlo.constant dense<9.000000e+00> +- // CHECK: %[[TMP_187:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_186]] +- // CHECK: %[[TMP_188:.*]] = stablehlo.multiply %[[TMP_185]], %[[TMP_187]] +- // CHECK: %[[TMP_189:.*]] = stablehlo.constant dense<-5.28419031E-10> +- // CHECK: %[[TMP_190:.*]] = stablehlo.add %[[TMP_183]], %[[TMP_189]] +- // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_190]] +- // CHECK: %[[TMP_192:.*]] = stablehlo.multiply %[[TMP_188]], %[[TMP_191]] +- // CHECK: %[[TMP_193:.*]] = stablehlo.constant dense<8.000000e+00> +- // CHECK: %[[TMP_194:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_193]] +- // CHECK: %[[TMP_195:.*]] = stablehlo.constant dense<7.000000e+00> +- // CHECK: %[[TMP_196:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_195]] +- // CHECK: %[[TMP_197:.*]] = stablehlo.multiply %[[TMP_194]], %[[TMP_196]] +- // CHECK: %[[TMP_198:.*]] = stablehlo.constant dense<2.08767563E-8> +- // CHECK: %[[TMP_199:.*]] = stablehlo.add %[[TMP_192]], %[[TMP_198]] +- // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_199]] +- // CHECK: %[[TMP_201:.*]] = stablehlo.multiply %[[TMP_197]], %[[TMP_200]] +- // CHECK: %[[TMP_202:.*]] = stablehlo.constant dense<6.000000e+00> +- // CHECK: %[[TMP_203:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_202]] +- // CHECK: %[[TMP_204:.*]] = stablehlo.constant dense<5.000000e+00> +- // CHECK: %[[TMP_205:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_204]] +- // CHECK: %[[TMP_206:.*]] = stablehlo.multiply %[[TMP_203]], %[[TMP_205]] +- // CHECK: %[[TMP_207:.*]] = stablehlo.constant dense<-8.26719599E-7> +- // CHECK: %[[TMP_208:.*]] = stablehlo.add %[[TMP_201]], %[[TMP_207]] +- // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_208]] +- // CHECK: %[[TMP_210:.*]] = stablehlo.multiply %[[TMP_206]], %[[TMP_209]] +- // CHECK: %[[TMP_211:.*]] = stablehlo.constant dense<4.000000e+00> +- // CHECK: %[[TMP_212:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_211]] +- // CHECK: %[[TMP_213:.*]] = stablehlo.constant dense<3.000000e+00> +- // CHECK: %[[TMP_214:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_213]] +- // CHECK: %[[TMP_215:.*]] = stablehlo.multiply %[[TMP_212]], %[[TMP_214]] +- // CHECK: %[[TMP_216:.*]] = stablehlo.constant dense<3.30687835E-5> +- // CHECK: %[[TMP_217:.*]] = stablehlo.add %[[TMP_210]], %[[TMP_216]] +- // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_217]] +- // CHECK: %[[TMP_219:.*]] = stablehlo.multiply %[[TMP_215]], %[[TMP_218]] +- // CHECK: %[[TMP_220:.*]] = stablehlo.constant dense<2.000000e+00> +- // CHECK: %[[TMP_221:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_220]] +- // CHECK: %[[TMP_222:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_223:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_222]] +- // CHECK: %[[TMP_224:.*]] = stablehlo.multiply %[[TMP_221]], %[[TMP_223]] +- // CHECK: %[[TMP_225:.*]] = stablehlo.constant dense<-0.00138888892> +- // CHECK: %[[TMP_226:.*]] = stablehlo.add %[[TMP_219]], %[[TMP_225]] +- // CHECK: %[[TMP_227:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_226]] +- // CHECK: %[[TMP_228:.*]] = stablehlo.multiply %[[TMP_224]], %[[TMP_227]] +- // CHECK: %[[TMP_229:.*]] = stablehlo.constant dense<5.000000e-01> +- // CHECK: %[[TMP_230:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_121]] +- // CHECK: %[[TMP_231:.*]] = stablehlo.constant dense<0.0833333358> +- // CHECK: %[[TMP_232:.*]] = stablehlo.add %[[TMP_231]], %[[TMP_228]] +- // CHECK: %[[TMP_233:.*]] = stablehlo.multiply %[[TMP_230]], %[[TMP_232]] +- // CHECK: %[[TMP_234:.*]] = stablehlo.add %[[TMP_229]], %[[TMP_233]] +- // CHECK: %[[TMP_235:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_234]] +- // CHECK: %[[TMP_236:.*]] = stablehlo.add %[[TMP_127]], %[[TMP_235]] ++ // CHECK: %[[TMP_124:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] ++ // CHECK: %[[TMP_125:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] ++ // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_124]], %[[TMP_125]] ++ // CHECK: %[[TMP_127:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] ++ // CHECK: %[[TMP_128:.*]] = stablehlo.divide %[[TMP_91]], %[[TMP_127]] ++ // CHECK: %[[TMP_129:.*]] = stablehlo.constant dense<2.200000e+01> ++ // CHECK: %[[TMP_130:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_129]] ++ // CHECK: %[[TMP_131:.*]] = stablehlo.constant dense<2.100000e+01> ++ // CHECK: %[[TMP_132:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_131]] ++ // CHECK: %[[TMP_133:.*]] = stablehlo.multiply %[[TMP_130]], %[[TMP_132]] ++ // CHECK: %[[TMP_134:.*]] = stablehlo.constant dense<-1.39544646E-19> ++ // CHECK: %[[TMP_135:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_134]] ++ // CHECK: %[[TMP_136:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_135]] ++ // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_133]], %[[TMP_136]] ++ // CHECK: %[[TMP_138:.*]] = stablehlo.constant dense<2.000000e+01> ++ // CHECK: %[[TMP_139:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_138]] ++ // CHECK: %[[TMP_140:.*]] = stablehlo.constant dense<1.900000e+01> ++ // CHECK: %[[TMP_141:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_140]] ++ // CHECK: %[[TMP_142:.*]] = stablehlo.multiply %[[TMP_139]], %[[TMP_141]] ++ // CHECK: %[[TMP_143:.*]] = stablehlo.constant dense<5.50900303E-18> ++ // CHECK: %[[TMP_144:.*]] = stablehlo.add %[[TMP_137]], %[[TMP_143]] ++ // CHECK: %[[TMP_145:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_144]] ++ // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_142]], %[[TMP_145]] ++ // CHECK: %[[TMP_147:.*]] = stablehlo.constant dense<1.800000e+01> ++ // CHECK: %[[TMP_148:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_147]] ++ // CHECK: %[[TMP_149:.*]] = stablehlo.constant dense<1.700000e+01> ++ // CHECK: %[[TMP_150:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_149]] ++ // CHECK: %[[TMP_151:.*]] = stablehlo.multiply %[[TMP_148]], %[[TMP_150]] ++ // CHECK: %[[TMP_152:.*]] = stablehlo.constant dense<-2.17486866E-16> ++ // CHECK: %[[TMP_153:.*]] = stablehlo.add %[[TMP_146]], %[[TMP_152]] ++ // CHECK: %[[TMP_154:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_153]] ++ // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_151]], %[[TMP_154]] ++ // CHECK: %[[TMP_156:.*]] = stablehlo.constant dense<1.600000e+01> ++ // CHECK: %[[TMP_157:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_156]] ++ // CHECK: %[[TMP_158:.*]] = stablehlo.constant dense<1.500000e+01> ++ // CHECK: %[[TMP_159:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_158]] ++ // CHECK: %[[TMP_160:.*]] = stablehlo.multiply %[[TMP_157]], %[[TMP_159]] ++ // CHECK: %[[TMP_161:.*]] = stablehlo.constant dense<8.58606213E-15> ++ // CHECK: %[[TMP_162:.*]] = stablehlo.add %[[TMP_155]], %[[TMP_161]] ++ // CHECK: %[[TMP_163:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_162]] ++ // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_160]], %[[TMP_163]] ++ // CHECK: %[[TMP_165:.*]] = stablehlo.constant dense<1.400000e+01> ++ // CHECK: %[[TMP_166:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_165]] ++ // CHECK: %[[TMP_167:.*]] = stablehlo.constant dense<1.300000e+01> ++ // CHECK: %[[TMP_168:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_167]] ++ // CHECK: %[[TMP_169:.*]] = stablehlo.multiply %[[TMP_166]], %[[TMP_168]] ++ // CHECK: %[[TMP_170:.*]] = stablehlo.constant dense<-3.3896803E-13> ++ // CHECK: %[[TMP_171:.*]] = stablehlo.add %[[TMP_164]], %[[TMP_170]] ++ // CHECK: %[[TMP_172:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_171]] ++ // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_169]], %[[TMP_172]] ++ // CHECK: %[[TMP_174:.*]] = stablehlo.constant dense<1.200000e+01> ++ // CHECK: %[[TMP_175:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_174]] ++ // CHECK: %[[TMP_176:.*]] = stablehlo.constant dense<1.100000e+01> ++ // CHECK: %[[TMP_177:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_176]] ++ // CHECK: %[[TMP_178:.*]] = stablehlo.multiply %[[TMP_175]], %[[TMP_177]] ++ // CHECK: %[[TMP_179:.*]] = stablehlo.constant dense<1.33825364E-11> ++ // CHECK: %[[TMP_180:.*]] = stablehlo.add %[[TMP_173]], %[[TMP_179]] ++ // CHECK: %[[TMP_181:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_180]] ++ // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_178]], %[[TMP_181]] ++ // CHECK: %[[TMP_183:.*]] = stablehlo.constant dense<1.000000e+01> ++ // CHECK: %[[TMP_184:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_183]] ++ // CHECK: %[[TMP_185:.*]] = stablehlo.constant dense<9.000000e+00> ++ // CHECK: %[[TMP_186:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_185]] ++ // CHECK: %[[TMP_187:.*]] = stablehlo.multiply %[[TMP_184]], %[[TMP_186]] ++ // CHECK: %[[TMP_188:.*]] = stablehlo.constant dense<-5.28419031E-10> ++ // CHECK: %[[TMP_189:.*]] = stablehlo.add %[[TMP_182]], %[[TMP_188]] ++ // CHECK: %[[TMP_190:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_189]] ++ // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_187]], %[[TMP_190]] ++ // CHECK: %[[TMP_192:.*]] = stablehlo.constant dense<8.000000e+00> ++ // CHECK: %[[TMP_193:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_192]] ++ // CHECK: %[[TMP_194:.*]] = stablehlo.constant dense<7.000000e+00> ++ // CHECK: %[[TMP_195:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_194]] ++ // CHECK: %[[TMP_196:.*]] = stablehlo.multiply %[[TMP_193]], %[[TMP_195]] ++ // CHECK: %[[TMP_197:.*]] = stablehlo.constant dense<2.08767563E-8> ++ // CHECK: %[[TMP_198:.*]] = stablehlo.add %[[TMP_191]], %[[TMP_197]] ++ // CHECK: %[[TMP_199:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_198]] ++ // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_196]], %[[TMP_199]] ++ // CHECK: %[[TMP_201:.*]] = stablehlo.constant dense<6.000000e+00> ++ // CHECK: %[[TMP_202:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_201]] ++ // CHECK: %[[TMP_203:.*]] = stablehlo.constant dense<5.000000e+00> ++ // CHECK: %[[TMP_204:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_203]] ++ // CHECK: %[[TMP_205:.*]] = stablehlo.multiply %[[TMP_202]], %[[TMP_204]] ++ // CHECK: %[[TMP_206:.*]] = stablehlo.constant dense<-8.26719599E-7> ++ // CHECK: %[[TMP_207:.*]] = stablehlo.add %[[TMP_200]], %[[TMP_206]] ++ // CHECK: %[[TMP_208:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_207]] ++ // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_205]], %[[TMP_208]] ++ // CHECK: %[[TMP_210:.*]] = stablehlo.constant dense<4.000000e+00> ++ // CHECK: %[[TMP_211:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_210]] ++ // CHECK: %[[TMP_212:.*]] = stablehlo.constant dense<3.000000e+00> ++ // CHECK: %[[TMP_213:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_212]] ++ // CHECK: %[[TMP_214:.*]] = stablehlo.multiply %[[TMP_211]], %[[TMP_213]] ++ // CHECK: %[[TMP_215:.*]] = stablehlo.constant dense<3.30687835E-5> ++ // CHECK: %[[TMP_216:.*]] = stablehlo.add %[[TMP_209]], %[[TMP_215]] ++ // CHECK: %[[TMP_217:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_216]] ++ // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_214]], %[[TMP_217]] ++ // CHECK: %[[TMP_219:.*]] = stablehlo.constant dense<2.000000e+00> ++ // CHECK: %[[TMP_220:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_219]] ++ // CHECK: %[[TMP_221:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_222:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_221]] ++ // CHECK: %[[TMP_223:.*]] = stablehlo.multiply %[[TMP_220]], %[[TMP_222]] ++ // CHECK: %[[TMP_224:.*]] = stablehlo.constant dense<-0.00138888892> ++ // CHECK: %[[TMP_225:.*]] = stablehlo.add %[[TMP_218]], %[[TMP_224]] ++ // CHECK: %[[TMP_226:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_225]] ++ // CHECK: %[[TMP_227:.*]] = stablehlo.multiply %[[TMP_223]], %[[TMP_226]] ++ // CHECK: %[[TMP_228:.*]] = stablehlo.constant dense<5.000000e-01> ++ // CHECK: %[[TMP_229:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_121]] ++ // CHECK: %[[TMP_230:.*]] = stablehlo.constant dense<0.0833333358> ++ // CHECK: %[[TMP_231:.*]] = stablehlo.add %[[TMP_230]], %[[TMP_227]] ++ // CHECK: %[[TMP_232:.*]] = stablehlo.multiply %[[TMP_229]], %[[TMP_231]] ++ // CHECK: %[[TMP_233:.*]] = stablehlo.add %[[TMP_228]], %[[TMP_232]] ++ // CHECK: %[[TMP_234:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_233]] ++ // CHECK: %[[TMP_235:.*]] = stablehlo.add %[[TMP_120]], %[[TMP_126]] ++ // CHECK: %[[TMP_236:.*]] = stablehlo.add %[[TMP_235]], %[[TMP_234]] + // CHECK: %[[TMP_237:.*]] = stablehlo.abs %[[TMP_122]] + // CHECK: %[[TMP_238:.*]] = stablehlo.abs %[[TMP_120]] + // CHECK: %[[TMP_239:.*]] = stablehlo.constant dense<1.401300e-45> +@@ -1732,7 +1731,7 @@ + // CHECK: %[[TMP_260:.*]] = stablehlo.and %[[TMP_257]], %[[TMP_259]] + // CHECK: %[[TMP_261:.*]] = stablehlo.select %[[TMP_260]], %[[TMP_251]], %[[TMP_243]] + // CHECK: %[[TMP_262:.*]] = stablehlo.select %[[TMP_254]], %[[TMP_261]], %[[TMP_250]] +- // CHECK: %[[TMP_263:.*]] = stablehlo.compare EQ, %[[TMP_5]], %[[TMP_93]], NOTYPE ++ // CHECK: %[[TMP_263:.*]] = stablehlo.compare EQ, %[[TMP_5]], %[[TMP_91]], NOTYPE + // CHECK: %[[TMP_264:.*]] = stablehlo.select %[[TMP_263]], %[[TMP_251]], %[[TMP_262]] + // CHECK: %[[TMP_265:.*]] = stablehlo.multiply %[[TMP_4]], %[[TMP_89]] + // CHECK: %[[TMP_266:.*]] = stablehlo.multiply %[[TMP_265]], %[[TMP_264]] +@@ -1853,8 +1852,7 @@ + + // ----- + +- +-// CHECK-LABEL: @polygamma_f64 ++// CHECK: @polygamma_f64 + // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) + func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor { + // CHECK-DAG: %[[TMP_0:.*]] = stablehlo.constant dense<1.000000e+00> +@@ -1947,153 +1945,153 @@ + // CHECK: %[[TMP_87:.*]] = stablehlo.constant dense<0x7FF0000000000000> + // CHECK: %[[TMP_88:.*]] = stablehlo.select %[[TMP_86]], %[[TMP_87]], %[[TMP_83]] + // CHECK: %[[TMP_89:.*]] = stablehlo.exponential %[[TMP_88]] +- // CHECK: %[[TMP_90:.*]] = stablehlo.constant dense<0.000000e+00> +- // CHECK: %[[TMP_91:.*]] = stablehlo.negate %[[TMP_5]] +- // CHECK: %[[TMP_92:.*]] = stablehlo.power %[[ARG1]], %[[TMP_91]] +- // CHECK: %[[TMP_93:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_94:.*]] = stablehlo.add %[[ARG1]], %[[TMP_93]] +- // CHECK: %[[TMP_95:.*]] = stablehlo.power %[[TMP_94]], %[[TMP_91]] +- // CHECK: %[[TMP_96:.*]] = stablehlo.add %[[TMP_92]], %[[TMP_95]] +- // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_94]], %[[TMP_93]] +- // CHECK: %[[TMP_98:.*]] = stablehlo.power %[[TMP_97]], %[[TMP_91]] ++ // CHECK-DAG: %[[TMP_90:.*]] = stablehlo.constant dense<0.000000e+00> ++ // CHECK-DAG: %[[TMP_91:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_92:.*]] = stablehlo.negate %[[TMP_5]] ++ // CHECK: %[[TMP_93:.*]] = stablehlo.power %[[ARG1]], %[[TMP_92]] ++ // CHECK: %[[TMP_94:.*]] = stablehlo.add %[[ARG1]], %[[TMP_91]] ++ // CHECK: %[[TMP_95:.*]] = stablehlo.power %[[TMP_94]], %[[TMP_92]] ++ // CHECK: %[[TMP_96:.*]] = stablehlo.add %[[TMP_93]], %[[TMP_95]] ++ // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_94]], %[[TMP_91]] ++ // CHECK: %[[TMP_98:.*]] = stablehlo.power %[[TMP_97]], %[[TMP_92]] + // CHECK: %[[TMP_99:.*]] = stablehlo.add %[[TMP_96]], %[[TMP_98]] +- // CHECK: %[[TMP_100:.*]] = stablehlo.add %[[TMP_97]], %[[TMP_93]] +- // CHECK: %[[TMP_101:.*]] = stablehlo.power %[[TMP_100]], %[[TMP_91]] ++ // CHECK: %[[TMP_100:.*]] = stablehlo.add %[[TMP_97]], %[[TMP_91]] ++ // CHECK: %[[TMP_101:.*]] = stablehlo.power %[[TMP_100]], %[[TMP_92]] + // CHECK: %[[TMP_102:.*]] = stablehlo.add %[[TMP_99]], %[[TMP_101]] +- // CHECK: %[[TMP_103:.*]] = stablehlo.add %[[TMP_100]], %[[TMP_93]] +- // CHECK: %[[TMP_104:.*]] = stablehlo.power %[[TMP_103]], %[[TMP_91]] ++ // CHECK: %[[TMP_103:.*]] = stablehlo.add %[[TMP_100]], %[[TMP_91]] ++ // CHECK: %[[TMP_104:.*]] = stablehlo.power %[[TMP_103]], %[[TMP_92]] + // CHECK: %[[TMP_105:.*]] = stablehlo.add %[[TMP_102]], %[[TMP_104]] +- // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_103]], %[[TMP_93]] +- // CHECK: %[[TMP_107:.*]] = stablehlo.power %[[TMP_106]], %[[TMP_91]] ++ // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_103]], %[[TMP_91]] ++ // CHECK: %[[TMP_107:.*]] = stablehlo.power %[[TMP_106]], %[[TMP_92]] + // CHECK: %[[TMP_108:.*]] = stablehlo.add %[[TMP_105]], %[[TMP_107]] +- // CHECK: %[[TMP_109:.*]] = stablehlo.add %[[TMP_106]], %[[TMP_93]] +- // CHECK: %[[TMP_110:.*]] = stablehlo.power %[[TMP_109]], %[[TMP_91]] ++ // CHECK: %[[TMP_109:.*]] = stablehlo.add %[[TMP_106]], %[[TMP_91]] ++ // CHECK: %[[TMP_110:.*]] = stablehlo.power %[[TMP_109]], %[[TMP_92]] + // CHECK: %[[TMP_111:.*]] = stablehlo.add %[[TMP_108]], %[[TMP_110]] +- // CHECK: %[[TMP_112:.*]] = stablehlo.add %[[TMP_109]], %[[TMP_93]] +- // CHECK: %[[TMP_113:.*]] = stablehlo.power %[[TMP_112]], %[[TMP_91]] ++ // CHECK: %[[TMP_112:.*]] = stablehlo.add %[[TMP_109]], %[[TMP_91]] ++ // CHECK: %[[TMP_113:.*]] = stablehlo.power %[[TMP_112]], %[[TMP_92]] + // CHECK: %[[TMP_114:.*]] = stablehlo.add %[[TMP_111]], %[[TMP_113]] +- // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_112]], %[[TMP_93]] +- // CHECK: %[[TMP_116:.*]] = stablehlo.power %[[TMP_115]], %[[TMP_91]] ++ // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_112]], %[[TMP_91]] ++ // CHECK: %[[TMP_116:.*]] = stablehlo.power %[[TMP_115]], %[[TMP_92]] + // CHECK: %[[TMP_117:.*]] = stablehlo.add %[[TMP_114]], %[[TMP_116]] +- // CHECK: %[[TMP_118:.*]] = stablehlo.add %[[TMP_115]], %[[TMP_93]] +- // CHECK: %[[TMP_119:.*]] = stablehlo.power %[[TMP_118]], %[[TMP_91]] ++ // CHECK: %[[TMP_118:.*]] = stablehlo.add %[[TMP_115]], %[[TMP_91]] ++ // CHECK: %[[TMP_119:.*]] = stablehlo.power %[[TMP_118]], %[[TMP_92]] + // CHECK: %[[TMP_120:.*]] = stablehlo.add %[[TMP_117]], %[[TMP_119]] +- // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_93]] +- // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_91]] ++ // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_91]] ++ // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_92]] + // CHECK: %[[TMP_123:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_124:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] +- // CHECK: %[[TMP_125:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] +- // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_125]], %[[TMP_124]] +- // CHECK: %[[TMP_127:.*]] = stablehlo.add %[[TMP_120]], %[[TMP_126]] +- // CHECK: %[[TMP_128:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] +- // CHECK: %[[TMP_129:.*]] = stablehlo.divide %[[TMP_93]], %[[TMP_128]] +- // CHECK: %[[TMP_130:.*]] = stablehlo.constant dense<2.200000e+01> +- // CHECK: %[[TMP_131:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_130]] +- // CHECK: %[[TMP_132:.*]] = stablehlo.constant dense<2.100000e+01> +- // CHECK: %[[TMP_133:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_132]] +- // CHECK: %[[TMP_134:.*]] = stablehlo.multiply %[[TMP_131]], %[[TMP_133]] +- // CHECK: %[[TMP_135:.*]] = stablehlo.constant dense<-1.3954464685812522E-19> +- // CHECK: %[[TMP_136:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_135]] +- // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_136]] +- // CHECK: %[[TMP_138:.*]] = stablehlo.multiply %[[TMP_134]], %[[TMP_137]] +- // CHECK: %[[TMP_139:.*]] = stablehlo.constant dense<2.000000e+01> +- // CHECK: %[[TMP_140:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_139]] +- // CHECK: %[[TMP_141:.*]] = stablehlo.constant dense<1.900000e+01> +- // CHECK: %[[TMP_142:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_141]] +- // CHECK: %[[TMP_143:.*]] = stablehlo.multiply %[[TMP_140]], %[[TMP_142]] +- // CHECK: %[[TMP_144:.*]] = stablehlo.constant dense<5.5090028283602295E-18> +- // CHECK: %[[TMP_145:.*]] = stablehlo.add %[[TMP_138]], %[[TMP_144]] +- // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_145]] +- // CHECK: %[[TMP_147:.*]] = stablehlo.multiply %[[TMP_143]], %[[TMP_146]] +- // CHECK: %[[TMP_148:.*]] = stablehlo.constant dense<1.800000e+01> +- // CHECK: %[[TMP_149:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_148]] +- // CHECK: %[[TMP_150:.*]] = stablehlo.constant dense<1.700000e+01> +- // CHECK: %[[TMP_151:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_150]] +- // CHECK: %[[TMP_152:.*]] = stablehlo.multiply %[[TMP_149]], %[[TMP_151]] +- // CHECK: %[[TMP_153:.*]] = stablehlo.constant dense<-2.1748686985580617E-16> +- // CHECK: %[[TMP_154:.*]] = stablehlo.add %[[TMP_147]], %[[TMP_153]] +- // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_154]] +- // CHECK: %[[TMP_156:.*]] = stablehlo.multiply %[[TMP_152]], %[[TMP_155]] +- // CHECK: %[[TMP_157:.*]] = stablehlo.constant dense<1.600000e+01> +- // CHECK: %[[TMP_158:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_157]] +- // CHECK: %[[TMP_159:.*]] = stablehlo.constant dense<1.500000e+01> +- // CHECK: %[[TMP_160:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_159]] +- // CHECK: %[[TMP_161:.*]] = stablehlo.multiply %[[TMP_158]], %[[TMP_160]] +- // CHECK: %[[TMP_162:.*]] = stablehlo.constant dense<8.5860620562778452E-15> +- // CHECK: %[[TMP_163:.*]] = stablehlo.add %[[TMP_156]], %[[TMP_162]] +- // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_163]] +- // CHECK: %[[TMP_165:.*]] = stablehlo.multiply %[[TMP_161]], %[[TMP_164]] +- // CHECK: %[[TMP_166:.*]] = stablehlo.constant dense<1.400000e+01> +- // CHECK: %[[TMP_167:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_166]] +- // CHECK: %[[TMP_168:.*]] = stablehlo.constant dense<1.300000e+01> +- // CHECK: %[[TMP_169:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_168]] +- // CHECK: %[[TMP_170:.*]] = stablehlo.multiply %[[TMP_167]], %[[TMP_169]] +- // CHECK: %[[TMP_171:.*]] = stablehlo.constant dense<-3.3896802963225832E-13> +- // CHECK: %[[TMP_172:.*]] = stablehlo.add %[[TMP_165]], %[[TMP_171]] +- // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_172]] +- // CHECK: %[[TMP_174:.*]] = stablehlo.multiply %[[TMP_170]], %[[TMP_173]] +- // CHECK: %[[TMP_175:.*]] = stablehlo.constant dense<1.200000e+01> +- // CHECK: %[[TMP_176:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_175]] +- // CHECK: %[[TMP_177:.*]] = stablehlo.constant dense<1.100000e+01> +- // CHECK: %[[TMP_178:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_177]] +- // CHECK: %[[TMP_179:.*]] = stablehlo.multiply %[[TMP_176]], %[[TMP_178]] +- // CHECK: %[[TMP_180:.*]] = stablehlo.constant dense<1.3382536530684679E-11> +- // CHECK: %[[TMP_181:.*]] = stablehlo.add %[[TMP_174]], %[[TMP_180]] +- // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_181]] +- // CHECK: %[[TMP_183:.*]] = stablehlo.multiply %[[TMP_179]], %[[TMP_182]] +- // CHECK: %[[TMP_184:.*]] = stablehlo.constant dense<1.000000e+01> +- // CHECK: %[[TMP_185:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_184]] +- // CHECK: %[[TMP_186:.*]] = stablehlo.constant dense<9.000000e+00> +- // CHECK: %[[TMP_187:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_186]] +- // CHECK: %[[TMP_188:.*]] = stablehlo.multiply %[[TMP_185]], %[[TMP_187]] +- // CHECK: %[[TMP_189:.*]] = stablehlo.constant dense<-5.2841901386874932E-10> +- // CHECK: %[[TMP_190:.*]] = stablehlo.add %[[TMP_183]], %[[TMP_189]] +- // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_190]] +- // CHECK: %[[TMP_192:.*]] = stablehlo.multiply %[[TMP_188]], %[[TMP_191]] +- // CHECK: %[[TMP_193:.*]] = stablehlo.constant dense<8.000000e+00> +- // CHECK: %[[TMP_194:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_193]] +- // CHECK: %[[TMP_195:.*]] = stablehlo.constant dense<7.000000e+00> +- // CHECK: %[[TMP_196:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_195]] +- // CHECK: %[[TMP_197:.*]] = stablehlo.multiply %[[TMP_194]], %[[TMP_196]] +- // CHECK: %[[TMP_198:.*]] = stablehlo.constant dense<2.08767569878681E-8> +- // CHECK: %[[TMP_199:.*]] = stablehlo.add %[[TMP_192]], %[[TMP_198]] +- // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_199]] +- // CHECK: %[[TMP_201:.*]] = stablehlo.multiply %[[TMP_197]], %[[TMP_200]] +- // CHECK: %[[TMP_202:.*]] = stablehlo.constant dense<6.000000e+00> +- // CHECK: %[[TMP_203:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_202]] +- // CHECK: %[[TMP_204:.*]] = stablehlo.constant dense<5.000000e+00> +- // CHECK: %[[TMP_205:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_204]] +- // CHECK: %[[TMP_206:.*]] = stablehlo.multiply %[[TMP_203]], %[[TMP_205]] +- // CHECK: %[[TMP_207:.*]] = stablehlo.constant dense<-8.2671957671957675E-7> +- // CHECK: %[[TMP_208:.*]] = stablehlo.add %[[TMP_201]], %[[TMP_207]] +- // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_208]] +- // CHECK: %[[TMP_210:.*]] = stablehlo.multiply %[[TMP_206]], %[[TMP_209]] +- // CHECK: %[[TMP_211:.*]] = stablehlo.constant dense<4.000000e+00> +- // CHECK: %[[TMP_212:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_211]] +- // CHECK: %[[TMP_213:.*]] = stablehlo.constant dense<3.000000e+00> +- // CHECK: %[[TMP_214:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_213]] +- // CHECK: %[[TMP_215:.*]] = stablehlo.multiply %[[TMP_212]], %[[TMP_214]] +- // CHECK: %[[TMP_216:.*]] = stablehlo.constant dense<3.3068783068783071E-5> +- // CHECK: %[[TMP_217:.*]] = stablehlo.add %[[TMP_210]], %[[TMP_216]] +- // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_217]] +- // CHECK: %[[TMP_219:.*]] = stablehlo.multiply %[[TMP_215]], %[[TMP_218]] +- // CHECK: %[[TMP_220:.*]] = stablehlo.constant dense<2.000000e+00> +- // CHECK: %[[TMP_221:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_220]] +- // CHECK: %[[TMP_222:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_223:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_222]] +- // CHECK: %[[TMP_224:.*]] = stablehlo.multiply %[[TMP_221]], %[[TMP_223]] +- // CHECK: %[[TMP_225:.*]] = stablehlo.constant dense<-0.0013888888888888889> +- // CHECK: %[[TMP_226:.*]] = stablehlo.add %[[TMP_219]], %[[TMP_225]] +- // CHECK: %[[TMP_227:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_226]] +- // CHECK: %[[TMP_228:.*]] = stablehlo.multiply %[[TMP_224]], %[[TMP_227]] +- // CHECK: %[[TMP_229:.*]] = stablehlo.constant dense<5.000000e-01> +- // CHECK: %[[TMP_230:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_121]] +- // CHECK: %[[TMP_231:.*]] = stablehlo.constant dense<0.083333333333333329> +- // CHECK: %[[TMP_232:.*]] = stablehlo.add %[[TMP_231]], %[[TMP_228]] +- // CHECK: %[[TMP_233:.*]] = stablehlo.multiply %[[TMP_230]], %[[TMP_232]] +- // CHECK: %[[TMP_234:.*]] = stablehlo.add %[[TMP_229]], %[[TMP_233]] +- // CHECK: %[[TMP_235:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_234]] +- // CHECK: %[[TMP_236:.*]] = stablehlo.add %[[TMP_127]], %[[TMP_235]] ++ // CHECK: %[[TMP_124:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] ++ // CHECK: %[[TMP_125:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] ++ // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_124]], %[[TMP_125]] ++ // CHECK: %[[TMP_127:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] ++ // CHECK: %[[TMP_128:.*]] = stablehlo.divide %[[TMP_91]], %[[TMP_127]] ++ // CHECK: %[[TMP_129:.*]] = stablehlo.constant dense<2.200000e+01> ++ // CHECK: %[[TMP_130:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_129]] ++ // CHECK: %[[TMP_131:.*]] = stablehlo.constant dense<2.100000e+01> ++ // CHECK: %[[TMP_132:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_131]] ++ // CHECK: %[[TMP_133:.*]] = stablehlo.multiply %[[TMP_130]], %[[TMP_132]] ++ // CHECK: %[[TMP_134:.*]] = stablehlo.constant dense<-1.3954464685812522E-19> ++ // CHECK: %[[TMP_135:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_134]] ++ // CHECK: %[[TMP_136:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_135]] ++ // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_133]], %[[TMP_136]] ++ // CHECK: %[[TMP_138:.*]] = stablehlo.constant dense<2.000000e+01> ++ // CHECK: %[[TMP_139:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_138]] ++ // CHECK: %[[TMP_140:.*]] = stablehlo.constant dense<1.900000e+01> ++ // CHECK: %[[TMP_141:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_140]] ++ // CHECK: %[[TMP_142:.*]] = stablehlo.multiply %[[TMP_139]], %[[TMP_141]] ++ // CHECK: %[[TMP_143:.*]] = stablehlo.constant dense<5.5090028283602295E-18> ++ // CHECK: %[[TMP_144:.*]] = stablehlo.add %[[TMP_137]], %[[TMP_143]] ++ // CHECK: %[[TMP_145:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_144]] ++ // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_142]], %[[TMP_145]] ++ // CHECK: %[[TMP_147:.*]] = stablehlo.constant dense<1.800000e+01> ++ // CHECK: %[[TMP_148:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_147]] ++ // CHECK: %[[TMP_149:.*]] = stablehlo.constant dense<1.700000e+01> ++ // CHECK: %[[TMP_150:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_149]] ++ // CHECK: %[[TMP_151:.*]] = stablehlo.multiply %[[TMP_148]], %[[TMP_150]] ++ // CHECK: %[[TMP_152:.*]] = stablehlo.constant dense<-2.1748686985580617E-16> ++ // CHECK: %[[TMP_153:.*]] = stablehlo.add %[[TMP_146]], %[[TMP_152]] ++ // CHECK: %[[TMP_154:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_153]] ++ // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_151]], %[[TMP_154]] ++ // CHECK: %[[TMP_156:.*]] = stablehlo.constant dense<1.600000e+01> ++ // CHECK: %[[TMP_157:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_156]] ++ // CHECK: %[[TMP_158:.*]] = stablehlo.constant dense<1.500000e+01> ++ // CHECK: %[[TMP_159:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_158]] ++ // CHECK: %[[TMP_160:.*]] = stablehlo.multiply %[[TMP_157]], %[[TMP_159]] ++ // CHECK: %[[TMP_161:.*]] = stablehlo.constant dense<8.5860620562778452E-15> ++ // CHECK: %[[TMP_162:.*]] = stablehlo.add %[[TMP_155]], %[[TMP_161]] ++ // CHECK: %[[TMP_163:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_162]] ++ // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_160]], %[[TMP_163]] ++ // CHECK: %[[TMP_165:.*]] = stablehlo.constant dense<1.400000e+01> ++ // CHECK: %[[TMP_166:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_165]] ++ // CHECK: %[[TMP_167:.*]] = stablehlo.constant dense<1.300000e+01> ++ // CHECK: %[[TMP_168:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_167]] ++ // CHECK: %[[TMP_169:.*]] = stablehlo.multiply %[[TMP_166]], %[[TMP_168]] ++ // CHECK: %[[TMP_170:.*]] = stablehlo.constant dense<-3.3896802963225832E-13> ++ // CHECK: %[[TMP_171:.*]] = stablehlo.add %[[TMP_164]], %[[TMP_170]] ++ // CHECK: %[[TMP_172:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_171]] ++ // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_169]], %[[TMP_172]] ++ // CHECK: %[[TMP_174:.*]] = stablehlo.constant dense<1.200000e+01> ++ // CHECK: %[[TMP_175:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_174]] ++ // CHECK: %[[TMP_176:.*]] = stablehlo.constant dense<1.100000e+01> ++ // CHECK: %[[TMP_177:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_176]] ++ // CHECK: %[[TMP_178:.*]] = stablehlo.multiply %[[TMP_175]], %[[TMP_177]] ++ // CHECK: %[[TMP_179:.*]] = stablehlo.constant dense<1.3382536530684679E-11> ++ // CHECK: %[[TMP_180:.*]] = stablehlo.add %[[TMP_173]], %[[TMP_179]] ++ // CHECK: %[[TMP_181:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_180]] ++ // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_178]], %[[TMP_181]] ++ // CHECK: %[[TMP_183:.*]] = stablehlo.constant dense<1.000000e+01> ++ // CHECK: %[[TMP_184:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_183]] ++ // CHECK: %[[TMP_185:.*]] = stablehlo.constant dense<9.000000e+00> ++ // CHECK: %[[TMP_186:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_185]] ++ // CHECK: %[[TMP_187:.*]] = stablehlo.multiply %[[TMP_184]], %[[TMP_186]] ++ // CHECK: %[[TMP_188:.*]] = stablehlo.constant dense<-5.2841901386874932E-10> ++ // CHECK: %[[TMP_189:.*]] = stablehlo.add %[[TMP_182]], %[[TMP_188]] ++ // CHECK: %[[TMP_190:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_189]] ++ // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_187]], %[[TMP_190]] ++ // CHECK: %[[TMP_192:.*]] = stablehlo.constant dense<8.000000e+00> ++ // CHECK: %[[TMP_193:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_192]] ++ // CHECK: %[[TMP_194:.*]] = stablehlo.constant dense<7.000000e+00> ++ // CHECK: %[[TMP_195:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_194]] ++ // CHECK: %[[TMP_196:.*]] = stablehlo.multiply %[[TMP_193]], %[[TMP_195]] ++ // CHECK: %[[TMP_197:.*]] = stablehlo.constant dense<2.08767569878681E-8> ++ // CHECK: %[[TMP_198:.*]] = stablehlo.add %[[TMP_191]], %[[TMP_197]] ++ // CHECK: %[[TMP_199:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_198]] ++ // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_196]], %[[TMP_199]] ++ // CHECK: %[[TMP_201:.*]] = stablehlo.constant dense<6.000000e+00> ++ // CHECK: %[[TMP_202:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_201]] ++ // CHECK: %[[TMP_203:.*]] = stablehlo.constant dense<5.000000e+00> ++ // CHECK: %[[TMP_204:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_203]] ++ // CHECK: %[[TMP_205:.*]] = stablehlo.multiply %[[TMP_202]], %[[TMP_204]] ++ // CHECK: %[[TMP_206:.*]] = stablehlo.constant dense<-8.2671957671957675E-7> ++ // CHECK: %[[TMP_207:.*]] = stablehlo.add %[[TMP_200]], %[[TMP_206]] ++ // CHECK: %[[TMP_208:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_207]] ++ // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_205]], %[[TMP_208]] ++ // CHECK: %[[TMP_210:.*]] = stablehlo.constant dense<4.000000e+00> ++ // CHECK: %[[TMP_211:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_210]] ++ // CHECK: %[[TMP_212:.*]] = stablehlo.constant dense<3.000000e+00> ++ // CHECK: %[[TMP_213:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_212]] ++ // CHECK: %[[TMP_214:.*]] = stablehlo.multiply %[[TMP_211]], %[[TMP_213]] ++ // CHECK: %[[TMP_215:.*]] = stablehlo.constant dense<3.3068783068783071E-5> ++ // CHECK: %[[TMP_216:.*]] = stablehlo.add %[[TMP_209]], %[[TMP_215]] ++ // CHECK: %[[TMP_217:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_216]] ++ // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_214]], %[[TMP_217]] ++ // CHECK: %[[TMP_219:.*]] = stablehlo.constant dense<2.000000e+00> ++ // CHECK: %[[TMP_220:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_219]] ++ // CHECK: %[[TMP_221:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_222:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_221]] ++ // CHECK: %[[TMP_223:.*]] = stablehlo.multiply %[[TMP_220]], %[[TMP_222]] ++ // CHECK: %[[TMP_224:.*]] = stablehlo.constant dense<-0.0013888888888888889> ++ // CHECK: %[[TMP_225:.*]] = stablehlo.add %[[TMP_218]], %[[TMP_224]] ++ // CHECK: %[[TMP_226:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_225]] ++ // CHECK: %[[TMP_227:.*]] = stablehlo.multiply %[[TMP_223]], %[[TMP_226]] ++ // CHECK: %[[TMP_228:.*]] = stablehlo.constant dense<5.000000e-01> ++ // CHECK: %[[TMP_229:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_121]] ++ // CHECK: %[[TMP_230:.*]] = stablehlo.constant dense<0.083333333333333329> ++ // CHECK: %[[TMP_231:.*]] = stablehlo.add %[[TMP_230]], %[[TMP_227]] ++ // CHECK: %[[TMP_232:.*]] = stablehlo.multiply %[[TMP_229]], %[[TMP_231]] ++ // CHECK: %[[TMP_233:.*]] = stablehlo.add %[[TMP_228]], %[[TMP_232]] ++ // CHECK: %[[TMP_234:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_233]] ++ // CHECK: %[[TMP_235:.*]] = stablehlo.add %[[TMP_120]], %[[TMP_126]] ++ // CHECK: %[[TMP_236:.*]] = stablehlo.add %[[TMP_235]], %[[TMP_234]] + // CHECK: %[[TMP_237:.*]] = stablehlo.abs %[[TMP_122]] + // CHECK: %[[TMP_238:.*]] = stablehlo.abs %[[TMP_120]] + // CHECK: %[[TMP_239:.*]] = stablehlo.constant dense<4.940660e-324> +@@ -2120,7 +2118,7 @@ + // CHECK: %[[TMP_260:.*]] = stablehlo.and %[[TMP_257]], %[[TMP_259]] + // CHECK: %[[TMP_261:.*]] = stablehlo.select %[[TMP_260]], %[[TMP_251]], %[[TMP_243]] + // CHECK: %[[TMP_262:.*]] = stablehlo.select %[[TMP_254]], %[[TMP_261]], %[[TMP_250]] +- // CHECK: %[[TMP_263:.*]] = stablehlo.compare EQ, %[[TMP_5]], %[[TMP_93]], NOTYPE ++ // CHECK: %[[TMP_263:.*]] = stablehlo.compare EQ, %[[TMP_5]], %[[TMP_91]], NOTYPE + // CHECK: %[[TMP_264:.*]] = stablehlo.select %[[TMP_263]], %[[TMP_251]], %[[TMP_262]] + // CHECK: %[[TMP_265:.*]] = stablehlo.multiply %[[TMP_4]], %[[TMP_89]] + // CHECK: %[[TMP_266:.*]] = stablehlo.multiply %[[TMP_265]], %[[TMP_264]] +diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +--- stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp ++++ stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +@@ -1575,11 +1575,21 @@ + + static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc, + ValueRange args) { +- // Code should match XLA's materializeZeta from chlo_legalize_to_hlo.cc ++ // Implementation ported from: ++ // https://github.com/openxla/xla/blob/7a067a7b88d2ffb15b1dc5e3c06f701a15f0391d/xla/client/lib/math.cc#L1912-L1917 ++ // Reference: Johansson, Fredrik. ++ // "Rigorous high-precision computation of the Hurwitz zeta function and its ++ // derivatives." Numerical Algorithms 69.2 (2015): 253-270. ++ // https://arxiv.org/abs/1309.2877 - formula (5) ++ // Notation is more or less kept as a reference to the whitepaper. + assert(args.size() == 2); + Value x = args[0]; + Value q = args[1]; +- static const std::array kZetaCoeffs{ ++ ++ static constexpr auto kTerms = 12; ++ static constexpr auto kIters = 9; ++ static constexpr auto kTwoTermsMinusOne = 2 * kTerms - 1; ++ static constexpr auto kZetaCoeffs = std::array{ + -7.1661652561756670113e18, + 1.8152105401943546773e17, + -4.5979787224074726105e15, +@@ -1596,131 +1606,134 @@ + + // For speed we'll always use 9 iterations for the initial series estimate, + // and a 12 term expansion for the Euler-Maclaurin formula. +- Value a = q; +- Value zero = getConstantLike(rewriter, loc, 0.0, a); +- Value negPower = zero; +- Value negX = rewriter.create(loc, x); +- Value initialSum = rewriter.create(loc, q, negX); +- Value one = getConstantLike(rewriter, loc, 1.0, a); +- for (int i = 0; i < 9; ++i) { +- a = rewriter.create(loc, a, one); +- negPower = rewriter.create(loc, a, negX); +- initialSum = +- rewriter.create(loc, initialSum, negPower); +- } +- +- a = rewriter.create(loc, a, one); +- negPower = rewriter.create(loc, a, negX); ++ Value zero = getConstantLike(rewriter, loc, 0.0, q); ++ Value one = getConstantLike(rewriter, loc, 1.0, q); ++ Value acc = q; ++ Value qNegPower = zero; ++ Value negX = rewriter.create(loc, x); ++ Value powerSum = rewriter.create(loc, q, negX); ++ for (int i = 0; i < kIters; ++i) { ++ acc = rewriter.create(loc, acc, one); ++ qNegPower = rewriter.create(loc, acc, negX); ++ powerSum = ++ rewriter.create(loc, powerSum, qNegPower); ++ } ++ acc = rewriter.create(loc, acc, one); ++ qNegPower = rewriter.create(loc, acc, negX); + Value oneLikeX = getConstantLike(rewriter, loc, 1.0, x); +- Value xMinusOne = +- rewriter.create(loc, x, oneLikeX); +- Value negPowerMulA = +- rewriter.create(loc, negPower, a); +- Value negPowerMulADivXMinusOne = +- rewriter.create(loc, negPowerMulA, xMinusOne); +- Value s = rewriter.create(loc, initialSum, +- negPowerMulADivXMinusOne); +- Value aInverseSquare = rewriter.create( +- loc, one, rewriter.create(loc, a, a)); +- +- Value hornerSum = zero; +- Value factor = one; ++ Value correctionEulerMaclaurin = rewriter.create( ++ loc, rewriter.create(loc, qNegPower, acc), ++ rewriter.create(loc, x, oneLikeX)); ++ ++ // Manual reciprocal of the square root as RsqrtOp produces different results ++ Value rsqrtAcc = rewriter.create( ++ loc, one, rewriter.create(loc, acc, acc)); ++ + // Use Horner's rule for this. + // Note this differs from Cephes which does a 'naive' polynomial evaluation. + // Using Horner's rule allows to avoid some NaN's and Infs from happening, + // resulting in more numerically stable code. +- for (int i = 0; i < 11; ++i) { +- Value factorLhs = rewriter.create( +- loc, x, getConstantLike(rewriter, loc, 22 - 2 * i, x)); +- Value factorRhs = rewriter.create( +- loc, x, getConstantLike(rewriter, loc, 21 - 2 * i, x)); +- factor = rewriter.create(loc, factorLhs, factorRhs); +- hornerSum = rewriter.create( +- loc, factor, +- rewriter.create( +- loc, aInverseSquare, +- rewriter.create( ++ Value hornerSum = zero; ++ Value hornerProduct = one; ++ ++ for (int i = 0; i < kTerms - 1; ++i) { ++ Value factorLhs = rewriter.create( ++ loc, x, ++ getConstantLike(rewriter, loc, kTwoTermsMinusOne - 1 - 2 * i, x)); ++ Value factorRhs = rewriter.create( ++ loc, x, ++ getConstantLike(rewriter, loc, kTwoTermsMinusOne - 2 - 2 * i, x)); ++ hornerProduct = ++ rewriter.create(loc, factorLhs, factorRhs); ++ hornerSum = rewriter.create( ++ loc, hornerProduct, ++ rewriter.create( ++ loc, rsqrtAcc, ++ rewriter.create( + loc, hornerSum, +- getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a)))); +- } +- Value zeroPointFiveLikeNegPower = +- getConstantLike(rewriter, loc, .5, negPower); +- Value xDivA = rewriter.create(loc, x, a); +- s = rewriter.create( +- loc, s, +- rewriter.create( +- loc, negPower, +- rewriter.create( +- loc, zeroPointFiveLikeNegPower, +- rewriter.create( +- loc, xDivA, +- rewriter.create( +- loc, +- getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11], a), +- hornerSum))))); ++ getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], acc)))); ++ } ++ Value zeroPointFiveLikeQNegPower = ++ getConstantLike(rewriter, loc, .5, qNegPower); ++ Value xDivAcc = rewriter.create(loc, x, acc); ++ Value bernoulliTailTerm = rewriter.create( ++ loc, qNegPower, ++ rewriter.create( ++ loc, zeroPointFiveLikeQNegPower, ++ rewriter.create( ++ loc, xDivAcc, ++ rewriter.create( ++ loc, ++ getConstantLike(rewriter, loc, 1. / kZetaCoeffs[kTerms - 1], ++ acc), ++ hornerSum)))); ++ Value accurateResult = rewriter.create( ++ loc, ++ rewriter.create(loc, powerSum, ++ correctionEulerMaclaurin), ++ bernoulliTailTerm); + + // Use the initial zeta sum without the correction term coming + // from Euler-Maclaurin if it is accurate enough. +- Value absNegPower = rewriter.create(loc, negPower); +- Value absInitialSum = +- rewriter.create(loc, initialSum); +- Value output = rewriter.create( ++ Value absQNegPower = rewriter.create(loc, qNegPower); ++ Value absPowerSum = rewriter.create(loc, powerSum); ++ Value output = rewriter.create( + loc, +- rewriter.create( +- loc, absNegPower, +- rewriter.create( +- loc, absInitialSum, +- getConstantLikeSmallestFiniteValue(rewriter, loc, a)), +- mlir::stablehlo::ComparisonDirection::LT), +- initialSum, s); ++ rewriter.create( ++ loc, absQNegPower, ++ rewriter.create( ++ loc, absPowerSum, ++ getConstantLikeSmallestFiniteValue(rewriter, loc, acc)), ++ ComparisonDirection::LT), ++ powerSum, accurateResult); + + // Function is not defined for x < 1. + Value nan = getConstantLike(rewriter, loc, + std::numeric_limits::quiet_NaN(), x); +- output = rewriter.create( ++ output = rewriter.create( + loc, +- rewriter.create( +- loc, x, oneLikeX, mlir::stablehlo::ComparisonDirection::LT), ++ rewriter.create( ++ loc, x, oneLikeX, ComparisonDirection::LT), + nan, output); + + // For q <= 0, x must be an integer. +- Value qLeZero = rewriter.create( +- loc, q, zero, mlir::stablehlo::ComparisonDirection::LE); +- Value xNotInt = rewriter.create( +- loc, x, rewriter.create(loc, x), +- mlir::stablehlo::ComparisonDirection::NE); ++ Value qLeZero = rewriter.create( ++ loc, q, zero, ComparisonDirection::LE); ++ Value xNotInt = rewriter.create( ++ loc, x, rewriter.create(loc, x), ++ ComparisonDirection::NE); + Value xDomainError = +- rewriter.create(loc, qLeZero, xNotInt); +- output = rewriter.create(loc, xDomainError, nan, ++ rewriter.create(loc, qLeZero, xNotInt); ++ output = rewriter.create(loc, xDomainError, nan, + output); + + // For all integer q <= 0, zeta has a pole. The limit is only defined as + // +inf if x is and even integer. + Value inf = getConstantLike(rewriter, loc, + std::numeric_limits::infinity(), x); +- Value qIsInt = rewriter.create( +- loc, q, rewriter.create(loc, q), +- mlir::stablehlo::ComparisonDirection::EQ); +- Value atPole = rewriter.create(loc, qLeZero, qIsInt); ++ Value qIsInt = rewriter.create( ++ loc, q, rewriter.create(loc, q), ++ ComparisonDirection::EQ); ++ Value atPole = rewriter.create(loc, qLeZero, qIsInt); + Value two = getConstantLike(rewriter, loc, 2.0, x); +- Value xIsInt = rewriter.create( +- loc, x, rewriter.create(loc, x), +- mlir::stablehlo::ComparisonDirection::EQ); +- Value xIsEven = rewriter.create( +- loc, rewriter.create(loc, x, two), zero, +- mlir::stablehlo::ComparisonDirection::EQ); ++ Value xIsInt = rewriter.create( ++ loc, x, rewriter.create(loc, x), ++ ComparisonDirection::EQ); ++ Value xIsEven = rewriter.create( ++ loc, rewriter.create(loc, x, two), zero, ++ ComparisonDirection::EQ); + Value xIsEvenInt = +- rewriter.create(loc, xIsInt, xIsEven); +- output = rewriter.create( ++ rewriter.create(loc, xIsInt, xIsEven); ++ output = rewriter.create( + loc, atPole, +- rewriter.create(loc, xIsEvenInt, inf, nan), ++ rewriter.create(loc, xIsEvenInt, inf, nan), + output); + + // For x = 1, this is the harmonic series and diverges. +- output = rewriter.create( ++ output = rewriter.create( + loc, +- rewriter.create( +- loc, x, one, mlir::stablehlo::ComparisonDirection::EQ), ++ rewriter.create( ++ loc, x, one, ComparisonDirection::EQ), + inf, output); + + return output; diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 70d9744d6e8ae1..94971c07102a21 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -2645,4 +2645,1185 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloTrivialDce.cpp +} // namespace experimental +} // namespace stablehlo +} // namespace mlir +diff --ruN a/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir b/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +--- stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir ++++ stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +@@ -1283,153 +1283,153 @@ + func.func @zeta_f16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: %[[TMP_0:.*]] = stablehlo.convert %[[X]] : (tensor) -> tensor + // CHECK: %[[TMP_1:.*]] = stablehlo.convert %[[Q]] : (tensor) -> tensor +- // CHECK: %[[TMP_2:.*]] = stablehlo.constant dense<0.000000e+00> +- // CHECK: %[[TMP_3:.*]] = stablehlo.negate %[[TMP_0]] +- // CHECK: %[[TMP_4:.*]] = stablehlo.power %[[TMP_1]], %[[TMP_3]] +- // CHECK: %[[TMP_5:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_6:.*]] = stablehlo.add %[[TMP_1]], %[[TMP_5]] +- // CHECK: %[[TMP_7:.*]] = stablehlo.power %[[TMP_6]], %[[TMP_3]] +- // CHECK: %[[TMP_8:.*]] = stablehlo.add %[[TMP_4]], %[[TMP_7]] +- // CHECK: %[[TMP_9:.*]] = stablehlo.add %[[TMP_6]], %[[TMP_5]] +- // CHECK: %[[TMP_10:.*]] = stablehlo.power %[[TMP_9]], %[[TMP_3]] ++ // CHECK-DAG: %[[TMP_2:.*]] = stablehlo.constant dense<0.000000e+00> ++ // CHECK-DAG: %[[TMP_3:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_4:.*]] = stablehlo.negate %[[TMP_0]] ++ // CHECK: %[[TMP_5:.*]] = stablehlo.power %[[TMP_1]], %[[TMP_4]] ++ // CHECK: %[[TMP_6:.*]] = stablehlo.add %[[TMP_1]], %[[TMP_3]] ++ // CHECK: %[[TMP_7:.*]] = stablehlo.power %[[TMP_6]], %[[TMP_4]] ++ // CHECK: %[[TMP_8:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_7]] ++ // CHECK: %[[TMP_9:.*]] = stablehlo.add %[[TMP_6]], %[[TMP_3]] ++ // CHECK: %[[TMP_10:.*]] = stablehlo.power %[[TMP_9]], %[[TMP_4]] + // CHECK: %[[TMP_11:.*]] = stablehlo.add %[[TMP_8]], %[[TMP_10]] +- // CHECK: %[[TMP_12:.*]] = stablehlo.add %[[TMP_9]], %[[TMP_5]] +- // CHECK: %[[TMP_13:.*]] = stablehlo.power %[[TMP_12]], %[[TMP_3]] ++ // CHECK: %[[TMP_12:.*]] = stablehlo.add %[[TMP_9]], %[[TMP_3]] ++ // CHECK: %[[TMP_13:.*]] = stablehlo.power %[[TMP_12]], %[[TMP_4]] + // CHECK: %[[TMP_14:.*]] = stablehlo.add %[[TMP_11]], %[[TMP_13]] +- // CHECK: %[[TMP_15:.*]] = stablehlo.add %[[TMP_12]], %[[TMP_5]] +- // CHECK: %[[TMP_16:.*]] = stablehlo.power %[[TMP_15]], %[[TMP_3]] ++ // CHECK: %[[TMP_15:.*]] = stablehlo.add %[[TMP_12]], %[[TMP_3]] ++ // CHECK: %[[TMP_16:.*]] = stablehlo.power %[[TMP_15]], %[[TMP_4]] + // CHECK: %[[TMP_17:.*]] = stablehlo.add %[[TMP_14]], %[[TMP_16]] +- // CHECK: %[[TMP_18:.*]] = stablehlo.add %[[TMP_15]], %[[TMP_5]] +- // CHECK: %[[TMP_19:.*]] = stablehlo.power %[[TMP_18]], %[[TMP_3]] ++ // CHECK: %[[TMP_18:.*]] = stablehlo.add %[[TMP_15]], %[[TMP_3]] ++ // CHECK: %[[TMP_19:.*]] = stablehlo.power %[[TMP_18]], %[[TMP_4]] + // CHECK: %[[TMP_20:.*]] = stablehlo.add %[[TMP_17]], %[[TMP_19]] +- // CHECK: %[[TMP_21:.*]] = stablehlo.add %[[TMP_18]], %[[TMP_5]] +- // CHECK: %[[TMP_22:.*]] = stablehlo.power %[[TMP_21]], %[[TMP_3]] ++ // CHECK: %[[TMP_21:.*]] = stablehlo.add %[[TMP_18]], %[[TMP_3]] ++ // CHECK: %[[TMP_22:.*]] = stablehlo.power %[[TMP_21]], %[[TMP_4]] + // CHECK: %[[TMP_23:.*]] = stablehlo.add %[[TMP_20]], %[[TMP_22]] +- // CHECK: %[[TMP_24:.*]] = stablehlo.add %[[TMP_21]], %[[TMP_5]] +- // CHECK: %[[TMP_25:.*]] = stablehlo.power %[[TMP_24]], %[[TMP_3]] ++ // CHECK: %[[TMP_24:.*]] = stablehlo.add %[[TMP_21]], %[[TMP_3]] ++ // CHECK: %[[TMP_25:.*]] = stablehlo.power %[[TMP_24]], %[[TMP_4]] + // CHECK: %[[TMP_26:.*]] = stablehlo.add %[[TMP_23]], %[[TMP_25]] +- // CHECK: %[[TMP_27:.*]] = stablehlo.add %[[TMP_24]], %[[TMP_5]] +- // CHECK: %[[TMP_28:.*]] = stablehlo.power %[[TMP_27]], %[[TMP_3]] ++ // CHECK: %[[TMP_27:.*]] = stablehlo.add %[[TMP_24]], %[[TMP_3]] ++ // CHECK: %[[TMP_28:.*]] = stablehlo.power %[[TMP_27]], %[[TMP_4]] + // CHECK: %[[TMP_29:.*]] = stablehlo.add %[[TMP_26]], %[[TMP_28]] +- // CHECK: %[[TMP_30:.*]] = stablehlo.add %[[TMP_27]], %[[TMP_5]] +- // CHECK: %[[TMP_31:.*]] = stablehlo.power %[[TMP_30]], %[[TMP_3]] ++ // CHECK: %[[TMP_30:.*]] = stablehlo.add %[[TMP_27]], %[[TMP_3]] ++ // CHECK: %[[TMP_31:.*]] = stablehlo.power %[[TMP_30]], %[[TMP_4]] + // CHECK: %[[TMP_32:.*]] = stablehlo.add %[[TMP_29]], %[[TMP_31]] +- // CHECK: %[[TMP_33:.*]] = stablehlo.add %[[TMP_30]], %[[TMP_5]] +- // CHECK: %[[TMP_34:.*]] = stablehlo.power %[[TMP_33]], %[[TMP_3]] ++ // CHECK: %[[TMP_33:.*]] = stablehlo.add %[[TMP_30]], %[[TMP_3]] ++ // CHECK: %[[TMP_34:.*]] = stablehlo.power %[[TMP_33]], %[[TMP_4]] + // CHECK: %[[TMP_35:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_36:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_35]] +- // CHECK: %[[TMP_37:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_33]] +- // CHECK: %[[TMP_38:.*]] = stablehlo.divide %[[TMP_37]], %[[TMP_36]] +- // CHECK: %[[TMP_39:.*]] = stablehlo.add %[[TMP_32]], %[[TMP_38]] +- // CHECK: %[[TMP_40:.*]] = stablehlo.multiply %[[TMP_33]], %[[TMP_33]] +- // CHECK: %[[TMP_41:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_40]] +- // CHECK: %[[TMP_42:.*]] = stablehlo.constant dense<2.200000e+01> +- // CHECK: %[[TMP_43:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_42]] +- // CHECK: %[[TMP_44:.*]] = stablehlo.constant dense<2.100000e+01> +- // CHECK: %[[TMP_45:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_44]] +- // CHECK: %[[TMP_46:.*]] = stablehlo.multiply %[[TMP_43]], %[[TMP_45]] +- // CHECK: %[[TMP_47:.*]] = stablehlo.constant dense<-1.39544646E-19> +- // CHECK: %[[TMP_48:.*]] = stablehlo.add %[[TMP_2]], %[[TMP_47]] +- // CHECK: %[[TMP_49:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_48]] +- // CHECK: %[[TMP_50:.*]] = stablehlo.multiply %[[TMP_46]], %[[TMP_49]] +- // CHECK: %[[TMP_51:.*]] = stablehlo.constant dense<2.000000e+01> +- // CHECK: %[[TMP_52:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_51]] +- // CHECK: %[[TMP_53:.*]] = stablehlo.constant dense<1.900000e+01> +- // CHECK: %[[TMP_54:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_53]] +- // CHECK: %[[TMP_55:.*]] = stablehlo.multiply %[[TMP_52]], %[[TMP_54]] +- // CHECK: %[[TMP_56:.*]] = stablehlo.constant dense<5.50900303E-18> +- // CHECK: %[[TMP_57:.*]] = stablehlo.add %[[TMP_50]], %[[TMP_56]] +- // CHECK: %[[TMP_58:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_57]] +- // CHECK: %[[TMP_59:.*]] = stablehlo.multiply %[[TMP_55]], %[[TMP_58]] +- // CHECK: %[[TMP_60:.*]] = stablehlo.constant dense<1.800000e+01> +- // CHECK: %[[TMP_61:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_60]] +- // CHECK: %[[TMP_62:.*]] = stablehlo.constant dense<1.700000e+01> +- // CHECK: %[[TMP_63:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_62]] +- // CHECK: %[[TMP_64:.*]] = stablehlo.multiply %[[TMP_61]], %[[TMP_63]] +- // CHECK: %[[TMP_65:.*]] = stablehlo.constant dense<-2.17486866E-16> +- // CHECK: %[[TMP_66:.*]] = stablehlo.add %[[TMP_59]], %[[TMP_65]] +- // CHECK: %[[TMP_67:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_66]] +- // CHECK: %[[TMP_68:.*]] = stablehlo.multiply %[[TMP_64]], %[[TMP_67]] +- // CHECK: %[[TMP_69:.*]] = stablehlo.constant dense<1.600000e+01> +- // CHECK: %[[TMP_70:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_69]] +- // CHECK: %[[TMP_71:.*]] = stablehlo.constant dense<1.500000e+01> +- // CHECK: %[[TMP_72:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_71]] +- // CHECK: %[[TMP_73:.*]] = stablehlo.multiply %[[TMP_70]], %[[TMP_72]] +- // CHECK: %[[TMP_74:.*]] = stablehlo.constant dense<8.58606213E-15> +- // CHECK: %[[TMP_75:.*]] = stablehlo.add %[[TMP_68]], %[[TMP_74]] +- // CHECK: %[[TMP_76:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_75]] +- // CHECK: %[[TMP_77:.*]] = stablehlo.multiply %[[TMP_73]], %[[TMP_76]] +- // CHECK: %[[TMP_78:.*]] = stablehlo.constant dense<1.400000e+01> +- // CHECK: %[[TMP_79:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_78]] +- // CHECK: %[[TMP_80:.*]] = stablehlo.constant dense<1.300000e+01> +- // CHECK: %[[TMP_81:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_80]] +- // CHECK: %[[TMP_82:.*]] = stablehlo.multiply %[[TMP_79]], %[[TMP_81]] +- // CHECK: %[[TMP_83:.*]] = stablehlo.constant dense<-3.3896803E-13> +- // CHECK: %[[TMP_84:.*]] = stablehlo.add %[[TMP_77]], %[[TMP_83]] +- // CHECK: %[[TMP_85:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_84]] +- // CHECK: %[[TMP_86:.*]] = stablehlo.multiply %[[TMP_82]], %[[TMP_85]] +- // CHECK: %[[TMP_87:.*]] = stablehlo.constant dense<1.200000e+01> +- // CHECK: %[[TMP_88:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_87]] +- // CHECK: %[[TMP_89:.*]] = stablehlo.constant dense<1.100000e+01> +- // CHECK: %[[TMP_90:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_89]] +- // CHECK: %[[TMP_91:.*]] = stablehlo.multiply %[[TMP_88]], %[[TMP_90]] +- // CHECK: %[[TMP_92:.*]] = stablehlo.constant dense<1.33825364E-11> +- // CHECK: %[[TMP_93:.*]] = stablehlo.add %[[TMP_86]], %[[TMP_92]] +- // CHECK: %[[TMP_94:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_93]] +- // CHECK: %[[TMP_95:.*]] = stablehlo.multiply %[[TMP_91]], %[[TMP_94]] +- // CHECK: %[[TMP_96:.*]] = stablehlo.constant dense<1.000000e+01> +- // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_96]] +- // CHECK: %[[TMP_98:.*]] = stablehlo.constant dense<9.000000e+00> +- // CHECK: %[[TMP_99:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_98]] +- // CHECK: %[[TMP_100:.*]] = stablehlo.multiply %[[TMP_97]], %[[TMP_99]] +- // CHECK: %[[TMP_101:.*]] = stablehlo.constant dense<-5.28419031E-10> +- // CHECK: %[[TMP_102:.*]] = stablehlo.add %[[TMP_95]], %[[TMP_101]] +- // CHECK: %[[TMP_103:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_102]] +- // CHECK: %[[TMP_104:.*]] = stablehlo.multiply %[[TMP_100]], %[[TMP_103]] +- // CHECK: %[[TMP_105:.*]] = stablehlo.constant dense<8.000000e+00> +- // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_105]] +- // CHECK: %[[TMP_107:.*]] = stablehlo.constant dense<7.000000e+00> +- // CHECK: %[[TMP_108:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_107]] +- // CHECK: %[[TMP_109:.*]] = stablehlo.multiply %[[TMP_106]], %[[TMP_108]] +- // CHECK: %[[TMP_110:.*]] = stablehlo.constant dense<2.08767563E-8> +- // CHECK: %[[TMP_111:.*]] = stablehlo.add %[[TMP_104]], %[[TMP_110]] +- // CHECK: %[[TMP_112:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_111]] +- // CHECK: %[[TMP_113:.*]] = stablehlo.multiply %[[TMP_109]], %[[TMP_112]] +- // CHECK: %[[TMP_114:.*]] = stablehlo.constant dense<6.000000e+00> +- // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_114]] +- // CHECK: %[[TMP_116:.*]] = stablehlo.constant dense<5.000000e+00> +- // CHECK: %[[TMP_117:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_116]] +- // CHECK: %[[TMP_118:.*]] = stablehlo.multiply %[[TMP_115]], %[[TMP_117]] +- // CHECK: %[[TMP_119:.*]] = stablehlo.constant dense<-8.26719599E-7> +- // CHECK: %[[TMP_120:.*]] = stablehlo.add %[[TMP_113]], %[[TMP_119]] +- // CHECK: %[[TMP_121:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_120]] +- // CHECK: %[[TMP_122:.*]] = stablehlo.multiply %[[TMP_118]], %[[TMP_121]] +- // CHECK: %[[TMP_123:.*]] = stablehlo.constant dense<4.000000e+00> +- // CHECK: %[[TMP_124:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_123]] +- // CHECK: %[[TMP_125:.*]] = stablehlo.constant dense<3.000000e+00> +- // CHECK: %[[TMP_126:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_125]] +- // CHECK: %[[TMP_127:.*]] = stablehlo.multiply %[[TMP_124]], %[[TMP_126]] +- // CHECK: %[[TMP_128:.*]] = stablehlo.constant dense<3.30687835E-5> +- // CHECK: %[[TMP_129:.*]] = stablehlo.add %[[TMP_122]], %[[TMP_128]] +- // CHECK: %[[TMP_130:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_129]] +- // CHECK: %[[TMP_131:.*]] = stablehlo.multiply %[[TMP_127]], %[[TMP_130]] +- // CHECK: %[[TMP_132:.*]] = stablehlo.constant dense<2.000000e+00> +- // CHECK: %[[TMP_133:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_132]] +- // CHECK: %[[TMP_134:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_135:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_134]] +- // CHECK: %[[TMP_136:.*]] = stablehlo.multiply %[[TMP_133]], %[[TMP_135]] +- // CHECK: %[[TMP_137:.*]] = stablehlo.constant dense<-0.00138888892> +- // CHECK: %[[TMP_138:.*]] = stablehlo.add %[[TMP_131]], %[[TMP_137]] +- // CHECK: %[[TMP_139:.*]] = stablehlo.multiply %[[TMP_41]], %[[TMP_138]] +- // CHECK: %[[TMP_140:.*]] = stablehlo.multiply %[[TMP_136]], %[[TMP_139]] +- // CHECK: %[[TMP_141:.*]] = stablehlo.constant dense<5.000000e-01> +- // CHECK: %[[TMP_142:.*]] = stablehlo.divide %[[TMP_0]], %[[TMP_33]] +- // CHECK: %[[TMP_143:.*]] = stablehlo.constant dense<0.0833333358> +- // CHECK: %[[TMP_144:.*]] = stablehlo.add %[[TMP_143]], %[[TMP_140]] +- // CHECK: %[[TMP_145:.*]] = stablehlo.multiply %[[TMP_142]], %[[TMP_144]] +- // CHECK: %[[TMP_146:.*]] = stablehlo.add %[[TMP_141]], %[[TMP_145]] +- // CHECK: %[[TMP_147:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_146]] +- // CHECK: %[[TMP_148:.*]] = stablehlo.add %[[TMP_39]], %[[TMP_147]] ++ // CHECK: %[[TMP_36:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_33]] ++ // CHECK: %[[TMP_37:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_35]] ++ // CHECK: %[[TMP_38:.*]] = stablehlo.divide %[[TMP_36]], %[[TMP_37]] ++ // CHECK: %[[TMP_39:.*]] = stablehlo.multiply %[[TMP_33]], %[[TMP_33]] ++ // CHECK: %[[TMP_40:.*]] = stablehlo.divide %[[TMP_3]], %[[TMP_39]] ++ // CHECK: %[[TMP_41:.*]] = stablehlo.constant dense<2.200000e+01> ++ // CHECK: %[[TMP_42:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_41]] ++ // CHECK: %[[TMP_43:.*]] = stablehlo.constant dense<2.100000e+01> ++ // CHECK: %[[TMP_44:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_43]] ++ // CHECK: %[[TMP_45:.*]] = stablehlo.multiply %[[TMP_42]], %[[TMP_44]] ++ // CHECK: %[[TMP_46:.*]] = stablehlo.constant dense<-1.39544646E-19> ++ // CHECK: %[[TMP_47:.*]] = stablehlo.add %[[TMP_2]], %[[TMP_46]] ++ // CHECK: %[[TMP_48:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_47]] ++ // CHECK: %[[TMP_49:.*]] = stablehlo.multiply %[[TMP_45]], %[[TMP_48]] ++ // CHECK: %[[TMP_50:.*]] = stablehlo.constant dense<2.000000e+01> ++ // CHECK: %[[TMP_51:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_50]] ++ // CHECK: %[[TMP_52:.*]] = stablehlo.constant dense<1.900000e+01> ++ // CHECK: %[[TMP_53:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_52]] ++ // CHECK: %[[TMP_54:.*]] = stablehlo.multiply %[[TMP_51]], %[[TMP_53]] ++ // CHECK: %[[TMP_55:.*]] = stablehlo.constant dense<5.50900303E-18> ++ // CHECK: %[[TMP_56:.*]] = stablehlo.add %[[TMP_49]], %[[TMP_55]] ++ // CHECK: %[[TMP_57:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_56]] ++ // CHECK: %[[TMP_58:.*]] = stablehlo.multiply %[[TMP_54]], %[[TMP_57]] ++ // CHECK: %[[TMP_59:.*]] = stablehlo.constant dense<1.800000e+01> ++ // CHECK: %[[TMP_60:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_59]] ++ // CHECK: %[[TMP_61:.*]] = stablehlo.constant dense<1.700000e+01> ++ // CHECK: %[[TMP_62:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_61]] ++ // CHECK: %[[TMP_63:.*]] = stablehlo.multiply %[[TMP_60]], %[[TMP_62]] ++ // CHECK: %[[TMP_64:.*]] = stablehlo.constant dense<-2.17486866E-16> ++ // CHECK: %[[TMP_65:.*]] = stablehlo.add %[[TMP_58]], %[[TMP_64]] ++ // CHECK: %[[TMP_66:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_65]] ++ // CHECK: %[[TMP_67:.*]] = stablehlo.multiply %[[TMP_63]], %[[TMP_66]] ++ // CHECK: %[[TMP_68:.*]] = stablehlo.constant dense<1.600000e+01> ++ // CHECK: %[[TMP_69:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_68]] ++ // CHECK: %[[TMP_70:.*]] = stablehlo.constant dense<1.500000e+01> ++ // CHECK: %[[TMP_71:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_70]] ++ // CHECK: %[[TMP_72:.*]] = stablehlo.multiply %[[TMP_69]], %[[TMP_71]] ++ // CHECK: %[[TMP_73:.*]] = stablehlo.constant dense<8.58606213E-15> ++ // CHECK: %[[TMP_74:.*]] = stablehlo.add %[[TMP_67]], %[[TMP_73]] ++ // CHECK: %[[TMP_75:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_74]] ++ // CHECK: %[[TMP_76:.*]] = stablehlo.multiply %[[TMP_72]], %[[TMP_75]] ++ // CHECK: %[[TMP_77:.*]] = stablehlo.constant dense<1.400000e+01> ++ // CHECK: %[[TMP_78:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_77]] ++ // CHECK: %[[TMP_79:.*]] = stablehlo.constant dense<1.300000e+01> ++ // CHECK: %[[TMP_80:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_79]] ++ // CHECK: %[[TMP_81:.*]] = stablehlo.multiply %[[TMP_78]], %[[TMP_80]] ++ // CHECK: %[[TMP_82:.*]] = stablehlo.constant dense<-3.3896803E-13> ++ // CHECK: %[[TMP_83:.*]] = stablehlo.add %[[TMP_76]], %[[TMP_82]] ++ // CHECK: %[[TMP_84:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_83]] ++ // CHECK: %[[TMP_85:.*]] = stablehlo.multiply %[[TMP_81]], %[[TMP_84]] ++ // CHECK: %[[TMP_86:.*]] = stablehlo.constant dense<1.200000e+01> ++ // CHECK: %[[TMP_87:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_86]] ++ // CHECK: %[[TMP_88:.*]] = stablehlo.constant dense<1.100000e+01> ++ // CHECK: %[[TMP_89:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_88]] ++ // CHECK: %[[TMP_90:.*]] = stablehlo.multiply %[[TMP_87]], %[[TMP_89]] ++ // CHECK: %[[TMP_91:.*]] = stablehlo.constant dense<1.33825364E-11> ++ // CHECK: %[[TMP_92:.*]] = stablehlo.add %[[TMP_85]], %[[TMP_91]] ++ // CHECK: %[[TMP_93:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_92]] ++ // CHECK: %[[TMP_94:.*]] = stablehlo.multiply %[[TMP_90]], %[[TMP_93]] ++ // CHECK: %[[TMP_95:.*]] = stablehlo.constant dense<1.000000e+01> ++ // CHECK: %[[TMP_96:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_95]] ++ // CHECK: %[[TMP_97:.*]] = stablehlo.constant dense<9.000000e+00> ++ // CHECK: %[[TMP_98:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_97]] ++ // CHECK: %[[TMP_99:.*]] = stablehlo.multiply %[[TMP_96]], %[[TMP_98]] ++ // CHECK: %[[TMP_100:.*]] = stablehlo.constant dense<-5.28419031E-10> ++ // CHECK: %[[TMP_101:.*]] = stablehlo.add %[[TMP_94]], %[[TMP_100]] ++ // CHECK: %[[TMP_102:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_101]] ++ // CHECK: %[[TMP_103:.*]] = stablehlo.multiply %[[TMP_99]], %[[TMP_102]] ++ // CHECK: %[[TMP_104:.*]] = stablehlo.constant dense<8.000000e+00> ++ // CHECK: %[[TMP_105:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_104]] ++ // CHECK: %[[TMP_106:.*]] = stablehlo.constant dense<7.000000e+00> ++ // CHECK: %[[TMP_107:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_106]] ++ // CHECK: %[[TMP_108:.*]] = stablehlo.multiply %[[TMP_105]], %[[TMP_107]] ++ // CHECK: %[[TMP_109:.*]] = stablehlo.constant dense<2.08767563E-8> ++ // CHECK: %[[TMP_110:.*]] = stablehlo.add %[[TMP_103]], %[[TMP_109]] ++ // CHECK: %[[TMP_111:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_110]] ++ // CHECK: %[[TMP_112:.*]] = stablehlo.multiply %[[TMP_108]], %[[TMP_111]] ++ // CHECK: %[[TMP_113:.*]] = stablehlo.constant dense<6.000000e+00> ++ // CHECK: %[[TMP_114:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_113]] ++ // CHECK: %[[TMP_115:.*]] = stablehlo.constant dense<5.000000e+00> ++ // CHECK: %[[TMP_116:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_115]] ++ // CHECK: %[[TMP_117:.*]] = stablehlo.multiply %[[TMP_114]], %[[TMP_116]] ++ // CHECK: %[[TMP_118:.*]] = stablehlo.constant dense<-8.26719599E-7> ++ // CHECK: %[[TMP_119:.*]] = stablehlo.add %[[TMP_112]], %[[TMP_118]] ++ // CHECK: %[[TMP_120:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_119]] ++ // CHECK: %[[TMP_121:.*]] = stablehlo.multiply %[[TMP_117]], %[[TMP_120]] ++ // CHECK: %[[TMP_122:.*]] = stablehlo.constant dense<4.000000e+00> ++ // CHECK: %[[TMP_123:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_122]] ++ // CHECK: %[[TMP_124:.*]] = stablehlo.constant dense<3.000000e+00> ++ // CHECK: %[[TMP_125:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_124]] ++ // CHECK: %[[TMP_126:.*]] = stablehlo.multiply %[[TMP_123]], %[[TMP_125]] ++ // CHECK: %[[TMP_127:.*]] = stablehlo.constant dense<3.30687835E-5> ++ // CHECK: %[[TMP_128:.*]] = stablehlo.add %[[TMP_121]], %[[TMP_127]] ++ // CHECK: %[[TMP_129:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_128]] ++ // CHECK: %[[TMP_130:.*]] = stablehlo.multiply %[[TMP_126]], %[[TMP_129]] ++ // CHECK: %[[TMP_131:.*]] = stablehlo.constant dense<2.000000e+00> ++ // CHECK: %[[TMP_132:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_131]] ++ // CHECK: %[[TMP_133:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_134:.*]] = stablehlo.add %[[TMP_0]], %[[TMP_133]] ++ // CHECK: %[[TMP_135:.*]] = stablehlo.multiply %[[TMP_132]], %[[TMP_134]] ++ // CHECK: %[[TMP_136:.*]] = stablehlo.constant dense<-0.00138888892> ++ // CHECK: %[[TMP_137:.*]] = stablehlo.add %[[TMP_130]], %[[TMP_136]] ++ // CHECK: %[[TMP_138:.*]] = stablehlo.multiply %[[TMP_40]], %[[TMP_137]] ++ // CHECK: %[[TMP_139:.*]] = stablehlo.multiply %[[TMP_135]], %[[TMP_138]] ++ // CHECK: %[[TMP_140:.*]] = stablehlo.constant dense<5.000000e-01> ++ // CHECK: %[[TMP_141:.*]] = stablehlo.divide %[[TMP_0]], %[[TMP_33]] ++ // CHECK: %[[TMP_142:.*]] = stablehlo.constant dense<0.0833333358> ++ // CHECK: %[[TMP_143:.*]] = stablehlo.add %[[TMP_142]], %[[TMP_139]] ++ // CHECK: %[[TMP_144:.*]] = stablehlo.multiply %[[TMP_141]], %[[TMP_143]] ++ // CHECK: %[[TMP_145:.*]] = stablehlo.add %[[TMP_140]], %[[TMP_144]] ++ // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_34]], %[[TMP_145]] ++ // CHECK: %[[TMP_147:.*]] = stablehlo.add %[[TMP_32]], %[[TMP_38]] ++ // CHECK: %[[TMP_148:.*]] = stablehlo.add %[[TMP_147]], %[[TMP_146]] + // CHECK: %[[TMP_149:.*]] = stablehlo.abs %[[TMP_34]] + // CHECK: %[[TMP_150:.*]] = stablehlo.abs %[[TMP_32]] + // CHECK: %[[TMP_151:.*]] = stablehlo.constant dense<1.401300e-45> +@@ -1456,7 +1456,7 @@ + // CHECK: %[[TMP_172:.*]] = stablehlo.and %[[TMP_169]], %[[TMP_171]] : tensor + // CHECK: %[[TMP_173:.*]] = stablehlo.select %[[TMP_172]], %[[TMP_163]], %[[TMP_155]] + // CHECK: %[[TMP_174:.*]] = stablehlo.select %[[TMP_166]], %[[TMP_173]], %[[TMP_162]] +- // CHECK: %[[TMP_175:.*]] = stablehlo.compare EQ, %[[TMP_0]], %[[TMP_5]], NOTYPE ++ // CHECK: %[[TMP_175:.*]] = stablehlo.compare EQ, %[[TMP_0]], %[[TMP_3]], NOTYPE + // CHECK: %[[TMP_176:.*]] = stablehlo.select %[[TMP_175]], %[[TMP_163]], %[[TMP_174]] + // CHECK: %[[TMP_177:.*]] = stablehlo.convert %[[TMP_176]] : (tensor) -> tensor + %0 = chlo.zeta %arg0, %arg1 : tensor, tensor -> tensor +@@ -1465,8 +1465,7 @@ + + // ----- + +- +-// CHECK-LABEL: @polygamma_f32 ++// CHECK: @polygamma_f32 + // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) + func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor { + // CHECK-DAG: %[[TMP_0:.*]] = stablehlo.constant dense<1.000000e+00> +@@ -1559,153 +1558,153 @@ + // CHECK: %[[TMP_87:.*]] = stablehlo.constant dense<0x7F800000> + // CHECK: %[[TMP_88:.*]] = stablehlo.select %[[TMP_86]], %[[TMP_87]], %[[TMP_83]] + // CHECK: %[[TMP_89:.*]] = stablehlo.exponential %[[TMP_88]] +- // CHECK: %[[TMP_90:.*]] = stablehlo.constant dense<0.000000e+00> +- // CHECK: %[[TMP_91:.*]] = stablehlo.negate %[[TMP_5]] +- // CHECK: %[[TMP_92:.*]] = stablehlo.power %[[ARG1]], %[[TMP_91]] +- // CHECK: %[[TMP_93:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_94:.*]] = stablehlo.add %[[ARG1]], %[[TMP_93]] +- // CHECK: %[[TMP_95:.*]] = stablehlo.power %[[TMP_94]], %[[TMP_91]] +- // CHECK: %[[TMP_96:.*]] = stablehlo.add %[[TMP_92]], %[[TMP_95]] +- // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_94]], %[[TMP_93]] +- // CHECK: %[[TMP_98:.*]] = stablehlo.power %[[TMP_97]], %[[TMP_91]] ++ // CHECK-DAG: %[[TMP_90:.*]] = stablehlo.constant dense<0.000000e+00> ++ // CHECK-DAG: %[[TMP_91:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_92:.*]] = stablehlo.negate %[[TMP_5]] ++ // CHECK: %[[TMP_93:.*]] = stablehlo.power %[[ARG1]], %[[TMP_92]] ++ // CHECK: %[[TMP_94:.*]] = stablehlo.add %[[ARG1]], %[[TMP_91]] ++ // CHECK: %[[TMP_95:.*]] = stablehlo.power %[[TMP_94]], %[[TMP_92]] ++ // CHECK: %[[TMP_96:.*]] = stablehlo.add %[[TMP_93]], %[[TMP_95]] ++ // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_94]], %[[TMP_91]] ++ // CHECK: %[[TMP_98:.*]] = stablehlo.power %[[TMP_97]], %[[TMP_92]] + // CHECK: %[[TMP_99:.*]] = stablehlo.add %[[TMP_96]], %[[TMP_98]] +- // CHECK: %[[TMP_100:.*]] = stablehlo.add %[[TMP_97]], %[[TMP_93]] +- // CHECK: %[[TMP_101:.*]] = stablehlo.power %[[TMP_100]], %[[TMP_91]] ++ // CHECK: %[[TMP_100:.*]] = stablehlo.add %[[TMP_97]], %[[TMP_91]] ++ // CHECK: %[[TMP_101:.*]] = stablehlo.power %[[TMP_100]], %[[TMP_92]] + // CHECK: %[[TMP_102:.*]] = stablehlo.add %[[TMP_99]], %[[TMP_101]] +- // CHECK: %[[TMP_103:.*]] = stablehlo.add %[[TMP_100]], %[[TMP_93]] +- // CHECK: %[[TMP_104:.*]] = stablehlo.power %[[TMP_103]], %[[TMP_91]] ++ // CHECK: %[[TMP_103:.*]] = stablehlo.add %[[TMP_100]], %[[TMP_91]] ++ // CHECK: %[[TMP_104:.*]] = stablehlo.power %[[TMP_103]], %[[TMP_92]] + // CHECK: %[[TMP_105:.*]] = stablehlo.add %[[TMP_102]], %[[TMP_104]] +- // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_103]], %[[TMP_93]] +- // CHECK: %[[TMP_107:.*]] = stablehlo.power %[[TMP_106]], %[[TMP_91]] ++ // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_103]], %[[TMP_91]] ++ // CHECK: %[[TMP_107:.*]] = stablehlo.power %[[TMP_106]], %[[TMP_92]] + // CHECK: %[[TMP_108:.*]] = stablehlo.add %[[TMP_105]], %[[TMP_107]] +- // CHECK: %[[TMP_109:.*]] = stablehlo.add %[[TMP_106]], %[[TMP_93]] +- // CHECK: %[[TMP_110:.*]] = stablehlo.power %[[TMP_109]], %[[TMP_91]] ++ // CHECK: %[[TMP_109:.*]] = stablehlo.add %[[TMP_106]], %[[TMP_91]] ++ // CHECK: %[[TMP_110:.*]] = stablehlo.power %[[TMP_109]], %[[TMP_92]] + // CHECK: %[[TMP_111:.*]] = stablehlo.add %[[TMP_108]], %[[TMP_110]] +- // CHECK: %[[TMP_112:.*]] = stablehlo.add %[[TMP_109]], %[[TMP_93]] +- // CHECK: %[[TMP_113:.*]] = stablehlo.power %[[TMP_112]], %[[TMP_91]] ++ // CHECK: %[[TMP_112:.*]] = stablehlo.add %[[TMP_109]], %[[TMP_91]] ++ // CHECK: %[[TMP_113:.*]] = stablehlo.power %[[TMP_112]], %[[TMP_92]] + // CHECK: %[[TMP_114:.*]] = stablehlo.add %[[TMP_111]], %[[TMP_113]] +- // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_112]], %[[TMP_93]] +- // CHECK: %[[TMP_116:.*]] = stablehlo.power %[[TMP_115]], %[[TMP_91]] ++ // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_112]], %[[TMP_91]] ++ // CHECK: %[[TMP_116:.*]] = stablehlo.power %[[TMP_115]], %[[TMP_92]] + // CHECK: %[[TMP_117:.*]] = stablehlo.add %[[TMP_114]], %[[TMP_116]] +- // CHECK: %[[TMP_118:.*]] = stablehlo.add %[[TMP_115]], %[[TMP_93]] +- // CHECK: %[[TMP_119:.*]] = stablehlo.power %[[TMP_118]], %[[TMP_91]] ++ // CHECK: %[[TMP_118:.*]] = stablehlo.add %[[TMP_115]], %[[TMP_91]] ++ // CHECK: %[[TMP_119:.*]] = stablehlo.power %[[TMP_118]], %[[TMP_92]] + // CHECK: %[[TMP_120:.*]] = stablehlo.add %[[TMP_117]], %[[TMP_119]] +- // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_93]] +- // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_91]] ++ // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_91]] ++ // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_92]] + // CHECK: %[[TMP_123:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_124:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] +- // CHECK: %[[TMP_125:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] +- // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_125]], %[[TMP_124]] +- // CHECK: %[[TMP_127:.*]] = stablehlo.add %[[TMP_120]], %[[TMP_126]] +- // CHECK: %[[TMP_128:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] +- // CHECK: %[[TMP_129:.*]] = stablehlo.divide %[[TMP_93]], %[[TMP_128]] +- // CHECK: %[[TMP_130:.*]] = stablehlo.constant dense<2.200000e+01> +- // CHECK: %[[TMP_131:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_130]] +- // CHECK: %[[TMP_132:.*]] = stablehlo.constant dense<2.100000e+01> +- // CHECK: %[[TMP_133:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_132]] +- // CHECK: %[[TMP_134:.*]] = stablehlo.multiply %[[TMP_131]], %[[TMP_133]] +- // CHECK: %[[TMP_135:.*]] = stablehlo.constant dense<-1.39544646E-19> +- // CHECK: %[[TMP_136:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_135]] +- // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_136]] +- // CHECK: %[[TMP_138:.*]] = stablehlo.multiply %[[TMP_134]], %[[TMP_137]] +- // CHECK: %[[TMP_139:.*]] = stablehlo.constant dense<2.000000e+01> +- // CHECK: %[[TMP_140:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_139]] +- // CHECK: %[[TMP_141:.*]] = stablehlo.constant dense<1.900000e+01> +- // CHECK: %[[TMP_142:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_141]] +- // CHECK: %[[TMP_143:.*]] = stablehlo.multiply %[[TMP_140]], %[[TMP_142]] +- // CHECK: %[[TMP_144:.*]] = stablehlo.constant dense<5.50900303E-18> +- // CHECK: %[[TMP_145:.*]] = stablehlo.add %[[TMP_138]], %[[TMP_144]] +- // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_145]] +- // CHECK: %[[TMP_147:.*]] = stablehlo.multiply %[[TMP_143]], %[[TMP_146]] +- // CHECK: %[[TMP_148:.*]] = stablehlo.constant dense<1.800000e+01> +- // CHECK: %[[TMP_149:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_148]] +- // CHECK: %[[TMP_150:.*]] = stablehlo.constant dense<1.700000e+01> +- // CHECK: %[[TMP_151:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_150]] +- // CHECK: %[[TMP_152:.*]] = stablehlo.multiply %[[TMP_149]], %[[TMP_151]] +- // CHECK: %[[TMP_153:.*]] = stablehlo.constant dense<-2.17486866E-16> +- // CHECK: %[[TMP_154:.*]] = stablehlo.add %[[TMP_147]], %[[TMP_153]] +- // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_154]] +- // CHECK: %[[TMP_156:.*]] = stablehlo.multiply %[[TMP_152]], %[[TMP_155]] +- // CHECK: %[[TMP_157:.*]] = stablehlo.constant dense<1.600000e+01> +- // CHECK: %[[TMP_158:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_157]] +- // CHECK: %[[TMP_159:.*]] = stablehlo.constant dense<1.500000e+01> +- // CHECK: %[[TMP_160:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_159]] +- // CHECK: %[[TMP_161:.*]] = stablehlo.multiply %[[TMP_158]], %[[TMP_160]] +- // CHECK: %[[TMP_162:.*]] = stablehlo.constant dense<8.58606213E-15> +- // CHECK: %[[TMP_163:.*]] = stablehlo.add %[[TMP_156]], %[[TMP_162]] +- // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_163]] +- // CHECK: %[[TMP_165:.*]] = stablehlo.multiply %[[TMP_161]], %[[TMP_164]] +- // CHECK: %[[TMP_166:.*]] = stablehlo.constant dense<1.400000e+01> +- // CHECK: %[[TMP_167:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_166]] +- // CHECK: %[[TMP_168:.*]] = stablehlo.constant dense<1.300000e+01> +- // CHECK: %[[TMP_169:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_168]] +- // CHECK: %[[TMP_170:.*]] = stablehlo.multiply %[[TMP_167]], %[[TMP_169]] +- // CHECK: %[[TMP_171:.*]] = stablehlo.constant dense<-3.3896803E-13> +- // CHECK: %[[TMP_172:.*]] = stablehlo.add %[[TMP_165]], %[[TMP_171]] +- // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_172]] +- // CHECK: %[[TMP_174:.*]] = stablehlo.multiply %[[TMP_170]], %[[TMP_173]] +- // CHECK: %[[TMP_175:.*]] = stablehlo.constant dense<1.200000e+01> +- // CHECK: %[[TMP_176:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_175]] +- // CHECK: %[[TMP_177:.*]] = stablehlo.constant dense<1.100000e+01> +- // CHECK: %[[TMP_178:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_177]] +- // CHECK: %[[TMP_179:.*]] = stablehlo.multiply %[[TMP_176]], %[[TMP_178]] +- // CHECK: %[[TMP_180:.*]] = stablehlo.constant dense<1.33825364E-11> +- // CHECK: %[[TMP_181:.*]] = stablehlo.add %[[TMP_174]], %[[TMP_180]] +- // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_181]] +- // CHECK: %[[TMP_183:.*]] = stablehlo.multiply %[[TMP_179]], %[[TMP_182]] +- // CHECK: %[[TMP_184:.*]] = stablehlo.constant dense<1.000000e+01> +- // CHECK: %[[TMP_185:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_184]] +- // CHECK: %[[TMP_186:.*]] = stablehlo.constant dense<9.000000e+00> +- // CHECK: %[[TMP_187:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_186]] +- // CHECK: %[[TMP_188:.*]] = stablehlo.multiply %[[TMP_185]], %[[TMP_187]] +- // CHECK: %[[TMP_189:.*]] = stablehlo.constant dense<-5.28419031E-10> +- // CHECK: %[[TMP_190:.*]] = stablehlo.add %[[TMP_183]], %[[TMP_189]] +- // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_190]] +- // CHECK: %[[TMP_192:.*]] = stablehlo.multiply %[[TMP_188]], %[[TMP_191]] +- // CHECK: %[[TMP_193:.*]] = stablehlo.constant dense<8.000000e+00> +- // CHECK: %[[TMP_194:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_193]] +- // CHECK: %[[TMP_195:.*]] = stablehlo.constant dense<7.000000e+00> +- // CHECK: %[[TMP_196:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_195]] +- // CHECK: %[[TMP_197:.*]] = stablehlo.multiply %[[TMP_194]], %[[TMP_196]] +- // CHECK: %[[TMP_198:.*]] = stablehlo.constant dense<2.08767563E-8> +- // CHECK: %[[TMP_199:.*]] = stablehlo.add %[[TMP_192]], %[[TMP_198]] +- // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_199]] +- // CHECK: %[[TMP_201:.*]] = stablehlo.multiply %[[TMP_197]], %[[TMP_200]] +- // CHECK: %[[TMP_202:.*]] = stablehlo.constant dense<6.000000e+00> +- // CHECK: %[[TMP_203:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_202]] +- // CHECK: %[[TMP_204:.*]] = stablehlo.constant dense<5.000000e+00> +- // CHECK: %[[TMP_205:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_204]] +- // CHECK: %[[TMP_206:.*]] = stablehlo.multiply %[[TMP_203]], %[[TMP_205]] +- // CHECK: %[[TMP_207:.*]] = stablehlo.constant dense<-8.26719599E-7> +- // CHECK: %[[TMP_208:.*]] = stablehlo.add %[[TMP_201]], %[[TMP_207]] +- // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_208]] +- // CHECK: %[[TMP_210:.*]] = stablehlo.multiply %[[TMP_206]], %[[TMP_209]] +- // CHECK: %[[TMP_211:.*]] = stablehlo.constant dense<4.000000e+00> +- // CHECK: %[[TMP_212:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_211]] +- // CHECK: %[[TMP_213:.*]] = stablehlo.constant dense<3.000000e+00> +- // CHECK: %[[TMP_214:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_213]] +- // CHECK: %[[TMP_215:.*]] = stablehlo.multiply %[[TMP_212]], %[[TMP_214]] +- // CHECK: %[[TMP_216:.*]] = stablehlo.constant dense<3.30687835E-5> +- // CHECK: %[[TMP_217:.*]] = stablehlo.add %[[TMP_210]], %[[TMP_216]] +- // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_217]] +- // CHECK: %[[TMP_219:.*]] = stablehlo.multiply %[[TMP_215]], %[[TMP_218]] +- // CHECK: %[[TMP_220:.*]] = stablehlo.constant dense<2.000000e+00> +- // CHECK: %[[TMP_221:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_220]] +- // CHECK: %[[TMP_222:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_223:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_222]] +- // CHECK: %[[TMP_224:.*]] = stablehlo.multiply %[[TMP_221]], %[[TMP_223]] +- // CHECK: %[[TMP_225:.*]] = stablehlo.constant dense<-0.00138888892> +- // CHECK: %[[TMP_226:.*]] = stablehlo.add %[[TMP_219]], %[[TMP_225]] +- // CHECK: %[[TMP_227:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_226]] +- // CHECK: %[[TMP_228:.*]] = stablehlo.multiply %[[TMP_224]], %[[TMP_227]] +- // CHECK: %[[TMP_229:.*]] = stablehlo.constant dense<5.000000e-01> +- // CHECK: %[[TMP_230:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_121]] +- // CHECK: %[[TMP_231:.*]] = stablehlo.constant dense<0.0833333358> +- // CHECK: %[[TMP_232:.*]] = stablehlo.add %[[TMP_231]], %[[TMP_228]] +- // CHECK: %[[TMP_233:.*]] = stablehlo.multiply %[[TMP_230]], %[[TMP_232]] +- // CHECK: %[[TMP_234:.*]] = stablehlo.add %[[TMP_229]], %[[TMP_233]] +- // CHECK: %[[TMP_235:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_234]] +- // CHECK: %[[TMP_236:.*]] = stablehlo.add %[[TMP_127]], %[[TMP_235]] ++ // CHECK: %[[TMP_124:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] ++ // CHECK: %[[TMP_125:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] ++ // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_124]], %[[TMP_125]] ++ // CHECK: %[[TMP_127:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] ++ // CHECK: %[[TMP_128:.*]] = stablehlo.divide %[[TMP_91]], %[[TMP_127]] ++ // CHECK: %[[TMP_129:.*]] = stablehlo.constant dense<2.200000e+01> ++ // CHECK: %[[TMP_130:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_129]] ++ // CHECK: %[[TMP_131:.*]] = stablehlo.constant dense<2.100000e+01> ++ // CHECK: %[[TMP_132:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_131]] ++ // CHECK: %[[TMP_133:.*]] = stablehlo.multiply %[[TMP_130]], %[[TMP_132]] ++ // CHECK: %[[TMP_134:.*]] = stablehlo.constant dense<-1.39544646E-19> ++ // CHECK: %[[TMP_135:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_134]] ++ // CHECK: %[[TMP_136:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_135]] ++ // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_133]], %[[TMP_136]] ++ // CHECK: %[[TMP_138:.*]] = stablehlo.constant dense<2.000000e+01> ++ // CHECK: %[[TMP_139:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_138]] ++ // CHECK: %[[TMP_140:.*]] = stablehlo.constant dense<1.900000e+01> ++ // CHECK: %[[TMP_141:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_140]] ++ // CHECK: %[[TMP_142:.*]] = stablehlo.multiply %[[TMP_139]], %[[TMP_141]] ++ // CHECK: %[[TMP_143:.*]] = stablehlo.constant dense<5.50900303E-18> ++ // CHECK: %[[TMP_144:.*]] = stablehlo.add %[[TMP_137]], %[[TMP_143]] ++ // CHECK: %[[TMP_145:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_144]] ++ // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_142]], %[[TMP_145]] ++ // CHECK: %[[TMP_147:.*]] = stablehlo.constant dense<1.800000e+01> ++ // CHECK: %[[TMP_148:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_147]] ++ // CHECK: %[[TMP_149:.*]] = stablehlo.constant dense<1.700000e+01> ++ // CHECK: %[[TMP_150:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_149]] ++ // CHECK: %[[TMP_151:.*]] = stablehlo.multiply %[[TMP_148]], %[[TMP_150]] ++ // CHECK: %[[TMP_152:.*]] = stablehlo.constant dense<-2.17486866E-16> ++ // CHECK: %[[TMP_153:.*]] = stablehlo.add %[[TMP_146]], %[[TMP_152]] ++ // CHECK: %[[TMP_154:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_153]] ++ // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_151]], %[[TMP_154]] ++ // CHECK: %[[TMP_156:.*]] = stablehlo.constant dense<1.600000e+01> ++ // CHECK: %[[TMP_157:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_156]] ++ // CHECK: %[[TMP_158:.*]] = stablehlo.constant dense<1.500000e+01> ++ // CHECK: %[[TMP_159:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_158]] ++ // CHECK: %[[TMP_160:.*]] = stablehlo.multiply %[[TMP_157]], %[[TMP_159]] ++ // CHECK: %[[TMP_161:.*]] = stablehlo.constant dense<8.58606213E-15> ++ // CHECK: %[[TMP_162:.*]] = stablehlo.add %[[TMP_155]], %[[TMP_161]] ++ // CHECK: %[[TMP_163:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_162]] ++ // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_160]], %[[TMP_163]] ++ // CHECK: %[[TMP_165:.*]] = stablehlo.constant dense<1.400000e+01> ++ // CHECK: %[[TMP_166:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_165]] ++ // CHECK: %[[TMP_167:.*]] = stablehlo.constant dense<1.300000e+01> ++ // CHECK: %[[TMP_168:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_167]] ++ // CHECK: %[[TMP_169:.*]] = stablehlo.multiply %[[TMP_166]], %[[TMP_168]] ++ // CHECK: %[[TMP_170:.*]] = stablehlo.constant dense<-3.3896803E-13> ++ // CHECK: %[[TMP_171:.*]] = stablehlo.add %[[TMP_164]], %[[TMP_170]] ++ // CHECK: %[[TMP_172:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_171]] ++ // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_169]], %[[TMP_172]] ++ // CHECK: %[[TMP_174:.*]] = stablehlo.constant dense<1.200000e+01> ++ // CHECK: %[[TMP_175:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_174]] ++ // CHECK: %[[TMP_176:.*]] = stablehlo.constant dense<1.100000e+01> ++ // CHECK: %[[TMP_177:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_176]] ++ // CHECK: %[[TMP_178:.*]] = stablehlo.multiply %[[TMP_175]], %[[TMP_177]] ++ // CHECK: %[[TMP_179:.*]] = stablehlo.constant dense<1.33825364E-11> ++ // CHECK: %[[TMP_180:.*]] = stablehlo.add %[[TMP_173]], %[[TMP_179]] ++ // CHECK: %[[TMP_181:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_180]] ++ // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_178]], %[[TMP_181]] ++ // CHECK: %[[TMP_183:.*]] = stablehlo.constant dense<1.000000e+01> ++ // CHECK: %[[TMP_184:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_183]] ++ // CHECK: %[[TMP_185:.*]] = stablehlo.constant dense<9.000000e+00> ++ // CHECK: %[[TMP_186:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_185]] ++ // CHECK: %[[TMP_187:.*]] = stablehlo.multiply %[[TMP_184]], %[[TMP_186]] ++ // CHECK: %[[TMP_188:.*]] = stablehlo.constant dense<-5.28419031E-10> ++ // CHECK: %[[TMP_189:.*]] = stablehlo.add %[[TMP_182]], %[[TMP_188]] ++ // CHECK: %[[TMP_190:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_189]] ++ // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_187]], %[[TMP_190]] ++ // CHECK: %[[TMP_192:.*]] = stablehlo.constant dense<8.000000e+00> ++ // CHECK: %[[TMP_193:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_192]] ++ // CHECK: %[[TMP_194:.*]] = stablehlo.constant dense<7.000000e+00> ++ // CHECK: %[[TMP_195:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_194]] ++ // CHECK: %[[TMP_196:.*]] = stablehlo.multiply %[[TMP_193]], %[[TMP_195]] ++ // CHECK: %[[TMP_197:.*]] = stablehlo.constant dense<2.08767563E-8> ++ // CHECK: %[[TMP_198:.*]] = stablehlo.add %[[TMP_191]], %[[TMP_197]] ++ // CHECK: %[[TMP_199:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_198]] ++ // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_196]], %[[TMP_199]] ++ // CHECK: %[[TMP_201:.*]] = stablehlo.constant dense<6.000000e+00> ++ // CHECK: %[[TMP_202:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_201]] ++ // CHECK: %[[TMP_203:.*]] = stablehlo.constant dense<5.000000e+00> ++ // CHECK: %[[TMP_204:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_203]] ++ // CHECK: %[[TMP_205:.*]] = stablehlo.multiply %[[TMP_202]], %[[TMP_204]] ++ // CHECK: %[[TMP_206:.*]] = stablehlo.constant dense<-8.26719599E-7> ++ // CHECK: %[[TMP_207:.*]] = stablehlo.add %[[TMP_200]], %[[TMP_206]] ++ // CHECK: %[[TMP_208:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_207]] ++ // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_205]], %[[TMP_208]] ++ // CHECK: %[[TMP_210:.*]] = stablehlo.constant dense<4.000000e+00> ++ // CHECK: %[[TMP_211:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_210]] ++ // CHECK: %[[TMP_212:.*]] = stablehlo.constant dense<3.000000e+00> ++ // CHECK: %[[TMP_213:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_212]] ++ // CHECK: %[[TMP_214:.*]] = stablehlo.multiply %[[TMP_211]], %[[TMP_213]] ++ // CHECK: %[[TMP_215:.*]] = stablehlo.constant dense<3.30687835E-5> ++ // CHECK: %[[TMP_216:.*]] = stablehlo.add %[[TMP_209]], %[[TMP_215]] ++ // CHECK: %[[TMP_217:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_216]] ++ // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_214]], %[[TMP_217]] ++ // CHECK: %[[TMP_219:.*]] = stablehlo.constant dense<2.000000e+00> ++ // CHECK: %[[TMP_220:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_219]] ++ // CHECK: %[[TMP_221:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_222:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_221]] ++ // CHECK: %[[TMP_223:.*]] = stablehlo.multiply %[[TMP_220]], %[[TMP_222]] ++ // CHECK: %[[TMP_224:.*]] = stablehlo.constant dense<-0.00138888892> ++ // CHECK: %[[TMP_225:.*]] = stablehlo.add %[[TMP_218]], %[[TMP_224]] ++ // CHECK: %[[TMP_226:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_225]] ++ // CHECK: %[[TMP_227:.*]] = stablehlo.multiply %[[TMP_223]], %[[TMP_226]] ++ // CHECK: %[[TMP_228:.*]] = stablehlo.constant dense<5.000000e-01> ++ // CHECK: %[[TMP_229:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_121]] ++ // CHECK: %[[TMP_230:.*]] = stablehlo.constant dense<0.0833333358> ++ // CHECK: %[[TMP_231:.*]] = stablehlo.add %[[TMP_230]], %[[TMP_227]] ++ // CHECK: %[[TMP_232:.*]] = stablehlo.multiply %[[TMP_229]], %[[TMP_231]] ++ // CHECK: %[[TMP_233:.*]] = stablehlo.add %[[TMP_228]], %[[TMP_232]] ++ // CHECK: %[[TMP_234:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_233]] ++ // CHECK: %[[TMP_235:.*]] = stablehlo.add %[[TMP_120]], %[[TMP_126]] ++ // CHECK: %[[TMP_236:.*]] = stablehlo.add %[[TMP_235]], %[[TMP_234]] + // CHECK: %[[TMP_237:.*]] = stablehlo.abs %[[TMP_122]] + // CHECK: %[[TMP_238:.*]] = stablehlo.abs %[[TMP_120]] + // CHECK: %[[TMP_239:.*]] = stablehlo.constant dense<1.401300e-45> +@@ -1732,7 +1731,7 @@ + // CHECK: %[[TMP_260:.*]] = stablehlo.and %[[TMP_257]], %[[TMP_259]] + // CHECK: %[[TMP_261:.*]] = stablehlo.select %[[TMP_260]], %[[TMP_251]], %[[TMP_243]] + // CHECK: %[[TMP_262:.*]] = stablehlo.select %[[TMP_254]], %[[TMP_261]], %[[TMP_250]] +- // CHECK: %[[TMP_263:.*]] = stablehlo.compare EQ, %[[TMP_5]], %[[TMP_93]], NOTYPE ++ // CHECK: %[[TMP_263:.*]] = stablehlo.compare EQ, %[[TMP_5]], %[[TMP_91]], NOTYPE + // CHECK: %[[TMP_264:.*]] = stablehlo.select %[[TMP_263]], %[[TMP_251]], %[[TMP_262]] + // CHECK: %[[TMP_265:.*]] = stablehlo.multiply %[[TMP_4]], %[[TMP_89]] + // CHECK: %[[TMP_266:.*]] = stablehlo.multiply %[[TMP_265]], %[[TMP_264]] +@@ -1853,8 +1852,7 @@ + + // ----- + +- +-// CHECK-LABEL: @polygamma_f64 ++// CHECK: @polygamma_f64 + // CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) + func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor { + // CHECK-DAG: %[[TMP_0:.*]] = stablehlo.constant dense<1.000000e+00> +@@ -1947,153 +1945,153 @@ + // CHECK: %[[TMP_87:.*]] = stablehlo.constant dense<0x7FF0000000000000> + // CHECK: %[[TMP_88:.*]] = stablehlo.select %[[TMP_86]], %[[TMP_87]], %[[TMP_83]] + // CHECK: %[[TMP_89:.*]] = stablehlo.exponential %[[TMP_88]] +- // CHECK: %[[TMP_90:.*]] = stablehlo.constant dense<0.000000e+00> +- // CHECK: %[[TMP_91:.*]] = stablehlo.negate %[[TMP_5]] +- // CHECK: %[[TMP_92:.*]] = stablehlo.power %[[ARG1]], %[[TMP_91]] +- // CHECK: %[[TMP_93:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_94:.*]] = stablehlo.add %[[ARG1]], %[[TMP_93]] +- // CHECK: %[[TMP_95:.*]] = stablehlo.power %[[TMP_94]], %[[TMP_91]] +- // CHECK: %[[TMP_96:.*]] = stablehlo.add %[[TMP_92]], %[[TMP_95]] +- // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_94]], %[[TMP_93]] +- // CHECK: %[[TMP_98:.*]] = stablehlo.power %[[TMP_97]], %[[TMP_91]] ++ // CHECK-DAG: %[[TMP_90:.*]] = stablehlo.constant dense<0.000000e+00> ++ // CHECK-DAG: %[[TMP_91:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_92:.*]] = stablehlo.negate %[[TMP_5]] ++ // CHECK: %[[TMP_93:.*]] = stablehlo.power %[[ARG1]], %[[TMP_92]] ++ // CHECK: %[[TMP_94:.*]] = stablehlo.add %[[ARG1]], %[[TMP_91]] ++ // CHECK: %[[TMP_95:.*]] = stablehlo.power %[[TMP_94]], %[[TMP_92]] ++ // CHECK: %[[TMP_96:.*]] = stablehlo.add %[[TMP_93]], %[[TMP_95]] ++ // CHECK: %[[TMP_97:.*]] = stablehlo.add %[[TMP_94]], %[[TMP_91]] ++ // CHECK: %[[TMP_98:.*]] = stablehlo.power %[[TMP_97]], %[[TMP_92]] + // CHECK: %[[TMP_99:.*]] = stablehlo.add %[[TMP_96]], %[[TMP_98]] +- // CHECK: %[[TMP_100:.*]] = stablehlo.add %[[TMP_97]], %[[TMP_93]] +- // CHECK: %[[TMP_101:.*]] = stablehlo.power %[[TMP_100]], %[[TMP_91]] ++ // CHECK: %[[TMP_100:.*]] = stablehlo.add %[[TMP_97]], %[[TMP_91]] ++ // CHECK: %[[TMP_101:.*]] = stablehlo.power %[[TMP_100]], %[[TMP_92]] + // CHECK: %[[TMP_102:.*]] = stablehlo.add %[[TMP_99]], %[[TMP_101]] +- // CHECK: %[[TMP_103:.*]] = stablehlo.add %[[TMP_100]], %[[TMP_93]] +- // CHECK: %[[TMP_104:.*]] = stablehlo.power %[[TMP_103]], %[[TMP_91]] ++ // CHECK: %[[TMP_103:.*]] = stablehlo.add %[[TMP_100]], %[[TMP_91]] ++ // CHECK: %[[TMP_104:.*]] = stablehlo.power %[[TMP_103]], %[[TMP_92]] + // CHECK: %[[TMP_105:.*]] = stablehlo.add %[[TMP_102]], %[[TMP_104]] +- // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_103]], %[[TMP_93]] +- // CHECK: %[[TMP_107:.*]] = stablehlo.power %[[TMP_106]], %[[TMP_91]] ++ // CHECK: %[[TMP_106:.*]] = stablehlo.add %[[TMP_103]], %[[TMP_91]] ++ // CHECK: %[[TMP_107:.*]] = stablehlo.power %[[TMP_106]], %[[TMP_92]] + // CHECK: %[[TMP_108:.*]] = stablehlo.add %[[TMP_105]], %[[TMP_107]] +- // CHECK: %[[TMP_109:.*]] = stablehlo.add %[[TMP_106]], %[[TMP_93]] +- // CHECK: %[[TMP_110:.*]] = stablehlo.power %[[TMP_109]], %[[TMP_91]] ++ // CHECK: %[[TMP_109:.*]] = stablehlo.add %[[TMP_106]], %[[TMP_91]] ++ // CHECK: %[[TMP_110:.*]] = stablehlo.power %[[TMP_109]], %[[TMP_92]] + // CHECK: %[[TMP_111:.*]] = stablehlo.add %[[TMP_108]], %[[TMP_110]] +- // CHECK: %[[TMP_112:.*]] = stablehlo.add %[[TMP_109]], %[[TMP_93]] +- // CHECK: %[[TMP_113:.*]] = stablehlo.power %[[TMP_112]], %[[TMP_91]] ++ // CHECK: %[[TMP_112:.*]] = stablehlo.add %[[TMP_109]], %[[TMP_91]] ++ // CHECK: %[[TMP_113:.*]] = stablehlo.power %[[TMP_112]], %[[TMP_92]] + // CHECK: %[[TMP_114:.*]] = stablehlo.add %[[TMP_111]], %[[TMP_113]] +- // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_112]], %[[TMP_93]] +- // CHECK: %[[TMP_116:.*]] = stablehlo.power %[[TMP_115]], %[[TMP_91]] ++ // CHECK: %[[TMP_115:.*]] = stablehlo.add %[[TMP_112]], %[[TMP_91]] ++ // CHECK: %[[TMP_116:.*]] = stablehlo.power %[[TMP_115]], %[[TMP_92]] + // CHECK: %[[TMP_117:.*]] = stablehlo.add %[[TMP_114]], %[[TMP_116]] +- // CHECK: %[[TMP_118:.*]] = stablehlo.add %[[TMP_115]], %[[TMP_93]] +- // CHECK: %[[TMP_119:.*]] = stablehlo.power %[[TMP_118]], %[[TMP_91]] ++ // CHECK: %[[TMP_118:.*]] = stablehlo.add %[[TMP_115]], %[[TMP_91]] ++ // CHECK: %[[TMP_119:.*]] = stablehlo.power %[[TMP_118]], %[[TMP_92]] + // CHECK: %[[TMP_120:.*]] = stablehlo.add %[[TMP_117]], %[[TMP_119]] +- // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_93]] +- // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_91]] ++ // CHECK: %[[TMP_121:.*]] = stablehlo.add %[[TMP_118]], %[[TMP_91]] ++ // CHECK: %[[TMP_122:.*]] = stablehlo.power %[[TMP_121]], %[[TMP_92]] + // CHECK: %[[TMP_123:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_124:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] +- // CHECK: %[[TMP_125:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] +- // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_125]], %[[TMP_124]] +- // CHECK: %[[TMP_127:.*]] = stablehlo.add %[[TMP_120]], %[[TMP_126]] +- // CHECK: %[[TMP_128:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] +- // CHECK: %[[TMP_129:.*]] = stablehlo.divide %[[TMP_93]], %[[TMP_128]] +- // CHECK: %[[TMP_130:.*]] = stablehlo.constant dense<2.200000e+01> +- // CHECK: %[[TMP_131:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_130]] +- // CHECK: %[[TMP_132:.*]] = stablehlo.constant dense<2.100000e+01> +- // CHECK: %[[TMP_133:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_132]] +- // CHECK: %[[TMP_134:.*]] = stablehlo.multiply %[[TMP_131]], %[[TMP_133]] +- // CHECK: %[[TMP_135:.*]] = stablehlo.constant dense<-1.3954464685812522E-19> +- // CHECK: %[[TMP_136:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_135]] +- // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_136]] +- // CHECK: %[[TMP_138:.*]] = stablehlo.multiply %[[TMP_134]], %[[TMP_137]] +- // CHECK: %[[TMP_139:.*]] = stablehlo.constant dense<2.000000e+01> +- // CHECK: %[[TMP_140:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_139]] +- // CHECK: %[[TMP_141:.*]] = stablehlo.constant dense<1.900000e+01> +- // CHECK: %[[TMP_142:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_141]] +- // CHECK: %[[TMP_143:.*]] = stablehlo.multiply %[[TMP_140]], %[[TMP_142]] +- // CHECK: %[[TMP_144:.*]] = stablehlo.constant dense<5.5090028283602295E-18> +- // CHECK: %[[TMP_145:.*]] = stablehlo.add %[[TMP_138]], %[[TMP_144]] +- // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_145]] +- // CHECK: %[[TMP_147:.*]] = stablehlo.multiply %[[TMP_143]], %[[TMP_146]] +- // CHECK: %[[TMP_148:.*]] = stablehlo.constant dense<1.800000e+01> +- // CHECK: %[[TMP_149:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_148]] +- // CHECK: %[[TMP_150:.*]] = stablehlo.constant dense<1.700000e+01> +- // CHECK: %[[TMP_151:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_150]] +- // CHECK: %[[TMP_152:.*]] = stablehlo.multiply %[[TMP_149]], %[[TMP_151]] +- // CHECK: %[[TMP_153:.*]] = stablehlo.constant dense<-2.1748686985580617E-16> +- // CHECK: %[[TMP_154:.*]] = stablehlo.add %[[TMP_147]], %[[TMP_153]] +- // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_154]] +- // CHECK: %[[TMP_156:.*]] = stablehlo.multiply %[[TMP_152]], %[[TMP_155]] +- // CHECK: %[[TMP_157:.*]] = stablehlo.constant dense<1.600000e+01> +- // CHECK: %[[TMP_158:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_157]] +- // CHECK: %[[TMP_159:.*]] = stablehlo.constant dense<1.500000e+01> +- // CHECK: %[[TMP_160:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_159]] +- // CHECK: %[[TMP_161:.*]] = stablehlo.multiply %[[TMP_158]], %[[TMP_160]] +- // CHECK: %[[TMP_162:.*]] = stablehlo.constant dense<8.5860620562778452E-15> +- // CHECK: %[[TMP_163:.*]] = stablehlo.add %[[TMP_156]], %[[TMP_162]] +- // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_163]] +- // CHECK: %[[TMP_165:.*]] = stablehlo.multiply %[[TMP_161]], %[[TMP_164]] +- // CHECK: %[[TMP_166:.*]] = stablehlo.constant dense<1.400000e+01> +- // CHECK: %[[TMP_167:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_166]] +- // CHECK: %[[TMP_168:.*]] = stablehlo.constant dense<1.300000e+01> +- // CHECK: %[[TMP_169:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_168]] +- // CHECK: %[[TMP_170:.*]] = stablehlo.multiply %[[TMP_167]], %[[TMP_169]] +- // CHECK: %[[TMP_171:.*]] = stablehlo.constant dense<-3.3896802963225832E-13> +- // CHECK: %[[TMP_172:.*]] = stablehlo.add %[[TMP_165]], %[[TMP_171]] +- // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_172]] +- // CHECK: %[[TMP_174:.*]] = stablehlo.multiply %[[TMP_170]], %[[TMP_173]] +- // CHECK: %[[TMP_175:.*]] = stablehlo.constant dense<1.200000e+01> +- // CHECK: %[[TMP_176:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_175]] +- // CHECK: %[[TMP_177:.*]] = stablehlo.constant dense<1.100000e+01> +- // CHECK: %[[TMP_178:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_177]] +- // CHECK: %[[TMP_179:.*]] = stablehlo.multiply %[[TMP_176]], %[[TMP_178]] +- // CHECK: %[[TMP_180:.*]] = stablehlo.constant dense<1.3382536530684679E-11> +- // CHECK: %[[TMP_181:.*]] = stablehlo.add %[[TMP_174]], %[[TMP_180]] +- // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_181]] +- // CHECK: %[[TMP_183:.*]] = stablehlo.multiply %[[TMP_179]], %[[TMP_182]] +- // CHECK: %[[TMP_184:.*]] = stablehlo.constant dense<1.000000e+01> +- // CHECK: %[[TMP_185:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_184]] +- // CHECK: %[[TMP_186:.*]] = stablehlo.constant dense<9.000000e+00> +- // CHECK: %[[TMP_187:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_186]] +- // CHECK: %[[TMP_188:.*]] = stablehlo.multiply %[[TMP_185]], %[[TMP_187]] +- // CHECK: %[[TMP_189:.*]] = stablehlo.constant dense<-5.2841901386874932E-10> +- // CHECK: %[[TMP_190:.*]] = stablehlo.add %[[TMP_183]], %[[TMP_189]] +- // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_190]] +- // CHECK: %[[TMP_192:.*]] = stablehlo.multiply %[[TMP_188]], %[[TMP_191]] +- // CHECK: %[[TMP_193:.*]] = stablehlo.constant dense<8.000000e+00> +- // CHECK: %[[TMP_194:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_193]] +- // CHECK: %[[TMP_195:.*]] = stablehlo.constant dense<7.000000e+00> +- // CHECK: %[[TMP_196:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_195]] +- // CHECK: %[[TMP_197:.*]] = stablehlo.multiply %[[TMP_194]], %[[TMP_196]] +- // CHECK: %[[TMP_198:.*]] = stablehlo.constant dense<2.08767569878681E-8> +- // CHECK: %[[TMP_199:.*]] = stablehlo.add %[[TMP_192]], %[[TMP_198]] +- // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_199]] +- // CHECK: %[[TMP_201:.*]] = stablehlo.multiply %[[TMP_197]], %[[TMP_200]] +- // CHECK: %[[TMP_202:.*]] = stablehlo.constant dense<6.000000e+00> +- // CHECK: %[[TMP_203:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_202]] +- // CHECK: %[[TMP_204:.*]] = stablehlo.constant dense<5.000000e+00> +- // CHECK: %[[TMP_205:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_204]] +- // CHECK: %[[TMP_206:.*]] = stablehlo.multiply %[[TMP_203]], %[[TMP_205]] +- // CHECK: %[[TMP_207:.*]] = stablehlo.constant dense<-8.2671957671957675E-7> +- // CHECK: %[[TMP_208:.*]] = stablehlo.add %[[TMP_201]], %[[TMP_207]] +- // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_208]] +- // CHECK: %[[TMP_210:.*]] = stablehlo.multiply %[[TMP_206]], %[[TMP_209]] +- // CHECK: %[[TMP_211:.*]] = stablehlo.constant dense<4.000000e+00> +- // CHECK: %[[TMP_212:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_211]] +- // CHECK: %[[TMP_213:.*]] = stablehlo.constant dense<3.000000e+00> +- // CHECK: %[[TMP_214:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_213]] +- // CHECK: %[[TMP_215:.*]] = stablehlo.multiply %[[TMP_212]], %[[TMP_214]] +- // CHECK: %[[TMP_216:.*]] = stablehlo.constant dense<3.3068783068783071E-5> +- // CHECK: %[[TMP_217:.*]] = stablehlo.add %[[TMP_210]], %[[TMP_216]] +- // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_217]] +- // CHECK: %[[TMP_219:.*]] = stablehlo.multiply %[[TMP_215]], %[[TMP_218]] +- // CHECK: %[[TMP_220:.*]] = stablehlo.constant dense<2.000000e+00> +- // CHECK: %[[TMP_221:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_220]] +- // CHECK: %[[TMP_222:.*]] = stablehlo.constant dense<1.000000e+00> +- // CHECK: %[[TMP_223:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_222]] +- // CHECK: %[[TMP_224:.*]] = stablehlo.multiply %[[TMP_221]], %[[TMP_223]] +- // CHECK: %[[TMP_225:.*]] = stablehlo.constant dense<-0.0013888888888888889> +- // CHECK: %[[TMP_226:.*]] = stablehlo.add %[[TMP_219]], %[[TMP_225]] +- // CHECK: %[[TMP_227:.*]] = stablehlo.multiply %[[TMP_129]], %[[TMP_226]] +- // CHECK: %[[TMP_228:.*]] = stablehlo.multiply %[[TMP_224]], %[[TMP_227]] +- // CHECK: %[[TMP_229:.*]] = stablehlo.constant dense<5.000000e-01> +- // CHECK: %[[TMP_230:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_121]] +- // CHECK: %[[TMP_231:.*]] = stablehlo.constant dense<0.083333333333333329> +- // CHECK: %[[TMP_232:.*]] = stablehlo.add %[[TMP_231]], %[[TMP_228]] +- // CHECK: %[[TMP_233:.*]] = stablehlo.multiply %[[TMP_230]], %[[TMP_232]] +- // CHECK: %[[TMP_234:.*]] = stablehlo.add %[[TMP_229]], %[[TMP_233]] +- // CHECK: %[[TMP_235:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_234]] +- // CHECK: %[[TMP_236:.*]] = stablehlo.add %[[TMP_127]], %[[TMP_235]] ++ // CHECK: %[[TMP_124:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_121]] ++ // CHECK: %[[TMP_125:.*]] = stablehlo.subtract %[[TMP_5]], %[[TMP_123]] ++ // CHECK: %[[TMP_126:.*]] = stablehlo.divide %[[TMP_124]], %[[TMP_125]] ++ // CHECK: %[[TMP_127:.*]] = stablehlo.multiply %[[TMP_121]], %[[TMP_121]] ++ // CHECK: %[[TMP_128:.*]] = stablehlo.divide %[[TMP_91]], %[[TMP_127]] ++ // CHECK: %[[TMP_129:.*]] = stablehlo.constant dense<2.200000e+01> ++ // CHECK: %[[TMP_130:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_129]] ++ // CHECK: %[[TMP_131:.*]] = stablehlo.constant dense<2.100000e+01> ++ // CHECK: %[[TMP_132:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_131]] ++ // CHECK: %[[TMP_133:.*]] = stablehlo.multiply %[[TMP_130]], %[[TMP_132]] ++ // CHECK: %[[TMP_134:.*]] = stablehlo.constant dense<-1.3954464685812522E-19> ++ // CHECK: %[[TMP_135:.*]] = stablehlo.add %[[TMP_90]], %[[TMP_134]] ++ // CHECK: %[[TMP_136:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_135]] ++ // CHECK: %[[TMP_137:.*]] = stablehlo.multiply %[[TMP_133]], %[[TMP_136]] ++ // CHECK: %[[TMP_138:.*]] = stablehlo.constant dense<2.000000e+01> ++ // CHECK: %[[TMP_139:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_138]] ++ // CHECK: %[[TMP_140:.*]] = stablehlo.constant dense<1.900000e+01> ++ // CHECK: %[[TMP_141:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_140]] ++ // CHECK: %[[TMP_142:.*]] = stablehlo.multiply %[[TMP_139]], %[[TMP_141]] ++ // CHECK: %[[TMP_143:.*]] = stablehlo.constant dense<5.5090028283602295E-18> ++ // CHECK: %[[TMP_144:.*]] = stablehlo.add %[[TMP_137]], %[[TMP_143]] ++ // CHECK: %[[TMP_145:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_144]] ++ // CHECK: %[[TMP_146:.*]] = stablehlo.multiply %[[TMP_142]], %[[TMP_145]] ++ // CHECK: %[[TMP_147:.*]] = stablehlo.constant dense<1.800000e+01> ++ // CHECK: %[[TMP_148:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_147]] ++ // CHECK: %[[TMP_149:.*]] = stablehlo.constant dense<1.700000e+01> ++ // CHECK: %[[TMP_150:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_149]] ++ // CHECK: %[[TMP_151:.*]] = stablehlo.multiply %[[TMP_148]], %[[TMP_150]] ++ // CHECK: %[[TMP_152:.*]] = stablehlo.constant dense<-2.1748686985580617E-16> ++ // CHECK: %[[TMP_153:.*]] = stablehlo.add %[[TMP_146]], %[[TMP_152]] ++ // CHECK: %[[TMP_154:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_153]] ++ // CHECK: %[[TMP_155:.*]] = stablehlo.multiply %[[TMP_151]], %[[TMP_154]] ++ // CHECK: %[[TMP_156:.*]] = stablehlo.constant dense<1.600000e+01> ++ // CHECK: %[[TMP_157:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_156]] ++ // CHECK: %[[TMP_158:.*]] = stablehlo.constant dense<1.500000e+01> ++ // CHECK: %[[TMP_159:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_158]] ++ // CHECK: %[[TMP_160:.*]] = stablehlo.multiply %[[TMP_157]], %[[TMP_159]] ++ // CHECK: %[[TMP_161:.*]] = stablehlo.constant dense<8.5860620562778452E-15> ++ // CHECK: %[[TMP_162:.*]] = stablehlo.add %[[TMP_155]], %[[TMP_161]] ++ // CHECK: %[[TMP_163:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_162]] ++ // CHECK: %[[TMP_164:.*]] = stablehlo.multiply %[[TMP_160]], %[[TMP_163]] ++ // CHECK: %[[TMP_165:.*]] = stablehlo.constant dense<1.400000e+01> ++ // CHECK: %[[TMP_166:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_165]] ++ // CHECK: %[[TMP_167:.*]] = stablehlo.constant dense<1.300000e+01> ++ // CHECK: %[[TMP_168:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_167]] ++ // CHECK: %[[TMP_169:.*]] = stablehlo.multiply %[[TMP_166]], %[[TMP_168]] ++ // CHECK: %[[TMP_170:.*]] = stablehlo.constant dense<-3.3896802963225832E-13> ++ // CHECK: %[[TMP_171:.*]] = stablehlo.add %[[TMP_164]], %[[TMP_170]] ++ // CHECK: %[[TMP_172:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_171]] ++ // CHECK: %[[TMP_173:.*]] = stablehlo.multiply %[[TMP_169]], %[[TMP_172]] ++ // CHECK: %[[TMP_174:.*]] = stablehlo.constant dense<1.200000e+01> ++ // CHECK: %[[TMP_175:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_174]] ++ // CHECK: %[[TMP_176:.*]] = stablehlo.constant dense<1.100000e+01> ++ // CHECK: %[[TMP_177:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_176]] ++ // CHECK: %[[TMP_178:.*]] = stablehlo.multiply %[[TMP_175]], %[[TMP_177]] ++ // CHECK: %[[TMP_179:.*]] = stablehlo.constant dense<1.3382536530684679E-11> ++ // CHECK: %[[TMP_180:.*]] = stablehlo.add %[[TMP_173]], %[[TMP_179]] ++ // CHECK: %[[TMP_181:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_180]] ++ // CHECK: %[[TMP_182:.*]] = stablehlo.multiply %[[TMP_178]], %[[TMP_181]] ++ // CHECK: %[[TMP_183:.*]] = stablehlo.constant dense<1.000000e+01> ++ // CHECK: %[[TMP_184:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_183]] ++ // CHECK: %[[TMP_185:.*]] = stablehlo.constant dense<9.000000e+00> ++ // CHECK: %[[TMP_186:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_185]] ++ // CHECK: %[[TMP_187:.*]] = stablehlo.multiply %[[TMP_184]], %[[TMP_186]] ++ // CHECK: %[[TMP_188:.*]] = stablehlo.constant dense<-5.2841901386874932E-10> ++ // CHECK: %[[TMP_189:.*]] = stablehlo.add %[[TMP_182]], %[[TMP_188]] ++ // CHECK: %[[TMP_190:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_189]] ++ // CHECK: %[[TMP_191:.*]] = stablehlo.multiply %[[TMP_187]], %[[TMP_190]] ++ // CHECK: %[[TMP_192:.*]] = stablehlo.constant dense<8.000000e+00> ++ // CHECK: %[[TMP_193:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_192]] ++ // CHECK: %[[TMP_194:.*]] = stablehlo.constant dense<7.000000e+00> ++ // CHECK: %[[TMP_195:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_194]] ++ // CHECK: %[[TMP_196:.*]] = stablehlo.multiply %[[TMP_193]], %[[TMP_195]] ++ // CHECK: %[[TMP_197:.*]] = stablehlo.constant dense<2.08767569878681E-8> ++ // CHECK: %[[TMP_198:.*]] = stablehlo.add %[[TMP_191]], %[[TMP_197]] ++ // CHECK: %[[TMP_199:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_198]] ++ // CHECK: %[[TMP_200:.*]] = stablehlo.multiply %[[TMP_196]], %[[TMP_199]] ++ // CHECK: %[[TMP_201:.*]] = stablehlo.constant dense<6.000000e+00> ++ // CHECK: %[[TMP_202:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_201]] ++ // CHECK: %[[TMP_203:.*]] = stablehlo.constant dense<5.000000e+00> ++ // CHECK: %[[TMP_204:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_203]] ++ // CHECK: %[[TMP_205:.*]] = stablehlo.multiply %[[TMP_202]], %[[TMP_204]] ++ // CHECK: %[[TMP_206:.*]] = stablehlo.constant dense<-8.2671957671957675E-7> ++ // CHECK: %[[TMP_207:.*]] = stablehlo.add %[[TMP_200]], %[[TMP_206]] ++ // CHECK: %[[TMP_208:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_207]] ++ // CHECK: %[[TMP_209:.*]] = stablehlo.multiply %[[TMP_205]], %[[TMP_208]] ++ // CHECK: %[[TMP_210:.*]] = stablehlo.constant dense<4.000000e+00> ++ // CHECK: %[[TMP_211:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_210]] ++ // CHECK: %[[TMP_212:.*]] = stablehlo.constant dense<3.000000e+00> ++ // CHECK: %[[TMP_213:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_212]] ++ // CHECK: %[[TMP_214:.*]] = stablehlo.multiply %[[TMP_211]], %[[TMP_213]] ++ // CHECK: %[[TMP_215:.*]] = stablehlo.constant dense<3.3068783068783071E-5> ++ // CHECK: %[[TMP_216:.*]] = stablehlo.add %[[TMP_209]], %[[TMP_215]] ++ // CHECK: %[[TMP_217:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_216]] ++ // CHECK: %[[TMP_218:.*]] = stablehlo.multiply %[[TMP_214]], %[[TMP_217]] ++ // CHECK: %[[TMP_219:.*]] = stablehlo.constant dense<2.000000e+00> ++ // CHECK: %[[TMP_220:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_219]] ++ // CHECK: %[[TMP_221:.*]] = stablehlo.constant dense<1.000000e+00> ++ // CHECK: %[[TMP_222:.*]] = stablehlo.add %[[TMP_5]], %[[TMP_221]] ++ // CHECK: %[[TMP_223:.*]] = stablehlo.multiply %[[TMP_220]], %[[TMP_222]] ++ // CHECK: %[[TMP_224:.*]] = stablehlo.constant dense<-0.0013888888888888889> ++ // CHECK: %[[TMP_225:.*]] = stablehlo.add %[[TMP_218]], %[[TMP_224]] ++ // CHECK: %[[TMP_226:.*]] = stablehlo.multiply %[[TMP_128]], %[[TMP_225]] ++ // CHECK: %[[TMP_227:.*]] = stablehlo.multiply %[[TMP_223]], %[[TMP_226]] ++ // CHECK: %[[TMP_228:.*]] = stablehlo.constant dense<5.000000e-01> ++ // CHECK: %[[TMP_229:.*]] = stablehlo.divide %[[TMP_5]], %[[TMP_121]] ++ // CHECK: %[[TMP_230:.*]] = stablehlo.constant dense<0.083333333333333329> ++ // CHECK: %[[TMP_231:.*]] = stablehlo.add %[[TMP_230]], %[[TMP_227]] ++ // CHECK: %[[TMP_232:.*]] = stablehlo.multiply %[[TMP_229]], %[[TMP_231]] ++ // CHECK: %[[TMP_233:.*]] = stablehlo.add %[[TMP_228]], %[[TMP_232]] ++ // CHECK: %[[TMP_234:.*]] = stablehlo.multiply %[[TMP_122]], %[[TMP_233]] ++ // CHECK: %[[TMP_235:.*]] = stablehlo.add %[[TMP_120]], %[[TMP_126]] ++ // CHECK: %[[TMP_236:.*]] = stablehlo.add %[[TMP_235]], %[[TMP_234]] + // CHECK: %[[TMP_237:.*]] = stablehlo.abs %[[TMP_122]] + // CHECK: %[[TMP_238:.*]] = stablehlo.abs %[[TMP_120]] + // CHECK: %[[TMP_239:.*]] = stablehlo.constant dense<4.940660e-324> +@@ -2120,7 +2118,7 @@ + // CHECK: %[[TMP_260:.*]] = stablehlo.and %[[TMP_257]], %[[TMP_259]] + // CHECK: %[[TMP_261:.*]] = stablehlo.select %[[TMP_260]], %[[TMP_251]], %[[TMP_243]] + // CHECK: %[[TMP_262:.*]] = stablehlo.select %[[TMP_254]], %[[TMP_261]], %[[TMP_250]] +- // CHECK: %[[TMP_263:.*]] = stablehlo.compare EQ, %[[TMP_5]], %[[TMP_93]], NOTYPE ++ // CHECK: %[[TMP_263:.*]] = stablehlo.compare EQ, %[[TMP_5]], %[[TMP_91]], NOTYPE + // CHECK: %[[TMP_264:.*]] = stablehlo.select %[[TMP_263]], %[[TMP_251]], %[[TMP_262]] + // CHECK: %[[TMP_265:.*]] = stablehlo.multiply %[[TMP_4]], %[[TMP_89]] + // CHECK: %[[TMP_266:.*]] = stablehlo.multiply %[[TMP_265]], %[[TMP_264]] +diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +--- stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp ++++ stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +@@ -1575,11 +1575,21 @@ + + static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc, + ValueRange args) { +- // Code should match XLA's materializeZeta from chlo_legalize_to_hlo.cc ++ // Implementation ported from: ++ // https://github.com/openxla/xla/blob/7a067a7b88d2ffb15b1dc5e3c06f701a15f0391d/xla/client/lib/math.cc#L1912-L1917 ++ // Reference: Johansson, Fredrik. ++ // "Rigorous high-precision computation of the Hurwitz zeta function and its ++ // derivatives." Numerical Algorithms 69.2 (2015): 253-270. ++ // https://arxiv.org/abs/1309.2877 - formula (5) ++ // Notation is more or less kept as a reference to the whitepaper. + assert(args.size() == 2); + Value x = args[0]; + Value q = args[1]; +- static const std::array kZetaCoeffs{ ++ ++ static constexpr auto kTerms = 12; ++ static constexpr auto kIters = 9; ++ static constexpr auto kTwoTermsMinusOne = 2 * kTerms - 1; ++ static constexpr auto kZetaCoeffs = std::array{ + -7.1661652561756670113e18, + 1.8152105401943546773e17, + -4.5979787224074726105e15, +@@ -1596,131 +1606,134 @@ + + // For speed we'll always use 9 iterations for the initial series estimate, + // and a 12 term expansion for the Euler-Maclaurin formula. +- Value a = q; +- Value zero = getConstantLike(rewriter, loc, 0.0, a); +- Value negPower = zero; +- Value negX = rewriter.create(loc, x); +- Value initialSum = rewriter.create(loc, q, negX); +- Value one = getConstantLike(rewriter, loc, 1.0, a); +- for (int i = 0; i < 9; ++i) { +- a = rewriter.create(loc, a, one); +- negPower = rewriter.create(loc, a, negX); +- initialSum = +- rewriter.create(loc, initialSum, negPower); +- } +- +- a = rewriter.create(loc, a, one); +- negPower = rewriter.create(loc, a, negX); ++ Value zero = getConstantLike(rewriter, loc, 0.0, q); ++ Value one = getConstantLike(rewriter, loc, 1.0, q); ++ Value acc = q; ++ Value qNegPower = zero; ++ Value negX = rewriter.create(loc, x); ++ Value powerSum = rewriter.create(loc, q, negX); ++ for (int i = 0; i < kIters; ++i) { ++ acc = rewriter.create(loc, acc, one); ++ qNegPower = rewriter.create(loc, acc, negX); ++ powerSum = ++ rewriter.create(loc, powerSum, qNegPower); ++ } ++ acc = rewriter.create(loc, acc, one); ++ qNegPower = rewriter.create(loc, acc, negX); + Value oneLikeX = getConstantLike(rewriter, loc, 1.0, x); +- Value xMinusOne = +- rewriter.create(loc, x, oneLikeX); +- Value negPowerMulA = +- rewriter.create(loc, negPower, a); +- Value negPowerMulADivXMinusOne = +- rewriter.create(loc, negPowerMulA, xMinusOne); +- Value s = rewriter.create(loc, initialSum, +- negPowerMulADivXMinusOne); +- Value aInverseSquare = rewriter.create( +- loc, one, rewriter.create(loc, a, a)); +- +- Value hornerSum = zero; +- Value factor = one; ++ Value correctionEulerMaclaurin = rewriter.create( ++ loc, rewriter.create(loc, qNegPower, acc), ++ rewriter.create(loc, x, oneLikeX)); ++ ++ // Manual reciprocal of the square root as RsqrtOp produces different results ++ Value rsqrtAcc = rewriter.create( ++ loc, one, rewriter.create(loc, acc, acc)); ++ + // Use Horner's rule for this. + // Note this differs from Cephes which does a 'naive' polynomial evaluation. + // Using Horner's rule allows to avoid some NaN's and Infs from happening, + // resulting in more numerically stable code. +- for (int i = 0; i < 11; ++i) { +- Value factorLhs = rewriter.create( +- loc, x, getConstantLike(rewriter, loc, 22 - 2 * i, x)); +- Value factorRhs = rewriter.create( +- loc, x, getConstantLike(rewriter, loc, 21 - 2 * i, x)); +- factor = rewriter.create(loc, factorLhs, factorRhs); +- hornerSum = rewriter.create( +- loc, factor, +- rewriter.create( +- loc, aInverseSquare, +- rewriter.create( ++ Value hornerSum = zero; ++ Value hornerProduct = one; ++ ++ for (int i = 0; i < kTerms - 1; ++i) { ++ Value factorLhs = rewriter.create( ++ loc, x, ++ getConstantLike(rewriter, loc, kTwoTermsMinusOne - 1 - 2 * i, x)); ++ Value factorRhs = rewriter.create( ++ loc, x, ++ getConstantLike(rewriter, loc, kTwoTermsMinusOne - 2 - 2 * i, x)); ++ hornerProduct = ++ rewriter.create(loc, factorLhs, factorRhs); ++ hornerSum = rewriter.create( ++ loc, hornerProduct, ++ rewriter.create( ++ loc, rsqrtAcc, ++ rewriter.create( + loc, hornerSum, +- getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a)))); +- } +- Value zeroPointFiveLikeNegPower = +- getConstantLike(rewriter, loc, .5, negPower); +- Value xDivA = rewriter.create(loc, x, a); +- s = rewriter.create( +- loc, s, +- rewriter.create( +- loc, negPower, +- rewriter.create( +- loc, zeroPointFiveLikeNegPower, +- rewriter.create( +- loc, xDivA, +- rewriter.create( +- loc, +- getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11], a), +- hornerSum))))); ++ getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], acc)))); ++ } ++ Value zeroPointFiveLikeQNegPower = ++ getConstantLike(rewriter, loc, .5, qNegPower); ++ Value xDivAcc = rewriter.create(loc, x, acc); ++ Value bernoulliTailTerm = rewriter.create( ++ loc, qNegPower, ++ rewriter.create( ++ loc, zeroPointFiveLikeQNegPower, ++ rewriter.create( ++ loc, xDivAcc, ++ rewriter.create( ++ loc, ++ getConstantLike(rewriter, loc, 1. / kZetaCoeffs[kTerms - 1], ++ acc), ++ hornerSum)))); ++ Value accurateResult = rewriter.create( ++ loc, ++ rewriter.create(loc, powerSum, ++ correctionEulerMaclaurin), ++ bernoulliTailTerm); + + // Use the initial zeta sum without the correction term coming + // from Euler-Maclaurin if it is accurate enough. +- Value absNegPower = rewriter.create(loc, negPower); +- Value absInitialSum = +- rewriter.create(loc, initialSum); +- Value output = rewriter.create( ++ Value absQNegPower = rewriter.create(loc, qNegPower); ++ Value absPowerSum = rewriter.create(loc, powerSum); ++ Value output = rewriter.create( + loc, +- rewriter.create( +- loc, absNegPower, +- rewriter.create( +- loc, absInitialSum, +- getConstantLikeSmallestFiniteValue(rewriter, loc, a)), +- mlir::stablehlo::ComparisonDirection::LT), +- initialSum, s); ++ rewriter.create( ++ loc, absQNegPower, ++ rewriter.create( ++ loc, absPowerSum, ++ getConstantLikeSmallestFiniteValue(rewriter, loc, acc)), ++ ComparisonDirection::LT), ++ powerSum, accurateResult); + + // Function is not defined for x < 1. + Value nan = getConstantLike(rewriter, loc, + std::numeric_limits::quiet_NaN(), x); +- output = rewriter.create( ++ output = rewriter.create( + loc, +- rewriter.create( +- loc, x, oneLikeX, mlir::stablehlo::ComparisonDirection::LT), ++ rewriter.create( ++ loc, x, oneLikeX, ComparisonDirection::LT), + nan, output); + + // For q <= 0, x must be an integer. +- Value qLeZero = rewriter.create( +- loc, q, zero, mlir::stablehlo::ComparisonDirection::LE); +- Value xNotInt = rewriter.create( +- loc, x, rewriter.create(loc, x), +- mlir::stablehlo::ComparisonDirection::NE); ++ Value qLeZero = rewriter.create( ++ loc, q, zero, ComparisonDirection::LE); ++ Value xNotInt = rewriter.create( ++ loc, x, rewriter.create(loc, x), ++ ComparisonDirection::NE); + Value xDomainError = +- rewriter.create(loc, qLeZero, xNotInt); +- output = rewriter.create(loc, xDomainError, nan, ++ rewriter.create(loc, qLeZero, xNotInt); ++ output = rewriter.create(loc, xDomainError, nan, + output); + + // For all integer q <= 0, zeta has a pole. The limit is only defined as + // +inf if x is and even integer. + Value inf = getConstantLike(rewriter, loc, + std::numeric_limits::infinity(), x); +- Value qIsInt = rewriter.create( +- loc, q, rewriter.create(loc, q), +- mlir::stablehlo::ComparisonDirection::EQ); +- Value atPole = rewriter.create(loc, qLeZero, qIsInt); ++ Value qIsInt = rewriter.create( ++ loc, q, rewriter.create(loc, q), ++ ComparisonDirection::EQ); ++ Value atPole = rewriter.create(loc, qLeZero, qIsInt); + Value two = getConstantLike(rewriter, loc, 2.0, x); +- Value xIsInt = rewriter.create( +- loc, x, rewriter.create(loc, x), +- mlir::stablehlo::ComparisonDirection::EQ); +- Value xIsEven = rewriter.create( +- loc, rewriter.create(loc, x, two), zero, +- mlir::stablehlo::ComparisonDirection::EQ); ++ Value xIsInt = rewriter.create( ++ loc, x, rewriter.create(loc, x), ++ ComparisonDirection::EQ); ++ Value xIsEven = rewriter.create( ++ loc, rewriter.create(loc, x, two), zero, ++ ComparisonDirection::EQ); + Value xIsEvenInt = +- rewriter.create(loc, xIsInt, xIsEven); +- output = rewriter.create( ++ rewriter.create(loc, xIsInt, xIsEven); ++ output = rewriter.create( + loc, atPole, +- rewriter.create(loc, xIsEvenInt, inf, nan), ++ rewriter.create(loc, xIsEvenInt, inf, nan), + output); + + // For x = 1, this is the harmonic series and diverges. +- output = rewriter.create( ++ output = rewriter.create( + loc, +- rewriter.create( +- loc, x, one, mlir::stablehlo::ComparisonDirection::EQ), ++ rewriter.create( ++ loc, x, one, ComparisonDirection::EQ), + inf, output); + + return output; diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir index 4d9835dc8b790c..27a2d6567f50d7 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir @@ -1109,153 +1109,153 @@ func.func @digamma_f16(%arg : tensor) -> tensor { func.func @zeta_f16(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[TMP_0:.*]] = mhlo.convert %[[X]] : (tensor) -> tensor // CHECK: %[[TMP_1:.*]] = mhlo.convert %[[Q]] : (tensor) -> tensor - // CHECK: %[[TMP_2:.*]] = mhlo.constant dense<0.000000e+00> - // CHECK: %[[TMP_3:.*]] = mhlo.negate %[[TMP_0]] - // CHECK: %[[TMP_4:.*]] = mhlo.power %[[TMP_1]], %[[TMP_3]] - // CHECK: %[[TMP_5:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_6:.*]] = mhlo.add %[[TMP_1]], %[[TMP_5]] - // CHECK: %[[TMP_7:.*]] = mhlo.power %[[TMP_6]], %[[TMP_3]] - // CHECK: %[[TMP_8:.*]] = mhlo.add %[[TMP_4]], %[[TMP_7]] - // CHECK: %[[TMP_9:.*]] = mhlo.add %[[TMP_6]], %[[TMP_5]] - // CHECK: %[[TMP_10:.*]] = mhlo.power %[[TMP_9]], %[[TMP_3]] + // CHECK-DAG: %[[TMP_2:.*]] = mhlo.constant dense<0.000000e+00> + // CHECK-DAG: %[[TMP_3:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_4:.*]] = mhlo.negate %[[TMP_0]] + // CHECK: %[[TMP_5:.*]] = mhlo.power %[[TMP_1]], %[[TMP_4]] + // CHECK: %[[TMP_6:.*]] = mhlo.add %[[TMP_1]], %[[TMP_3]] + // CHECK: %[[TMP_7:.*]] = mhlo.power %[[TMP_6]], %[[TMP_4]] + // CHECK: %[[TMP_8:.*]] = mhlo.add %[[TMP_5]], %[[TMP_7]] + // CHECK: %[[TMP_9:.*]] = mhlo.add %[[TMP_6]], %[[TMP_3]] + // CHECK: %[[TMP_10:.*]] = mhlo.power %[[TMP_9]], %[[TMP_4]] // CHECK: %[[TMP_11:.*]] = mhlo.add %[[TMP_8]], %[[TMP_10]] - // CHECK: %[[TMP_12:.*]] = mhlo.add %[[TMP_9]], %[[TMP_5]] - // CHECK: %[[TMP_13:.*]] = mhlo.power %[[TMP_12]], %[[TMP_3]] + // CHECK: %[[TMP_12:.*]] = mhlo.add %[[TMP_9]], %[[TMP_3]] + // CHECK: %[[TMP_13:.*]] = mhlo.power %[[TMP_12]], %[[TMP_4]] // CHECK: %[[TMP_14:.*]] = mhlo.add %[[TMP_11]], %[[TMP_13]] - // CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_12]], %[[TMP_5]] - // CHECK: %[[TMP_16:.*]] = mhlo.power %[[TMP_15]], %[[TMP_3]] + // CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_12]], %[[TMP_3]] + // CHECK: %[[TMP_16:.*]] = mhlo.power %[[TMP_15]], %[[TMP_4]] // CHECK: %[[TMP_17:.*]] = mhlo.add %[[TMP_14]], %[[TMP_16]] - // CHECK: %[[TMP_18:.*]] = mhlo.add %[[TMP_15]], %[[TMP_5]] - // CHECK: %[[TMP_19:.*]] = mhlo.power %[[TMP_18]], %[[TMP_3]] + // CHECK: %[[TMP_18:.*]] = mhlo.add %[[TMP_15]], %[[TMP_3]] + // CHECK: %[[TMP_19:.*]] = mhlo.power %[[TMP_18]], %[[TMP_4]] // CHECK: %[[TMP_20:.*]] = mhlo.add %[[TMP_17]], %[[TMP_19]] - // CHECK: %[[TMP_21:.*]] = mhlo.add %[[TMP_18]], %[[TMP_5]] - // CHECK: %[[TMP_22:.*]] = mhlo.power %[[TMP_21]], %[[TMP_3]] + // CHECK: %[[TMP_21:.*]] = mhlo.add %[[TMP_18]], %[[TMP_3]] + // CHECK: %[[TMP_22:.*]] = mhlo.power %[[TMP_21]], %[[TMP_4]] // CHECK: %[[TMP_23:.*]] = mhlo.add %[[TMP_20]], %[[TMP_22]] - // CHECK: %[[TMP_24:.*]] = mhlo.add %[[TMP_21]], %[[TMP_5]] - // CHECK: %[[TMP_25:.*]] = mhlo.power %[[TMP_24]], %[[TMP_3]] + // CHECK: %[[TMP_24:.*]] = mhlo.add %[[TMP_21]], %[[TMP_3]] + // CHECK: %[[TMP_25:.*]] = mhlo.power %[[TMP_24]], %[[TMP_4]] // CHECK: %[[TMP_26:.*]] = mhlo.add %[[TMP_23]], %[[TMP_25]] - // CHECK: %[[TMP_27:.*]] = mhlo.add %[[TMP_24]], %[[TMP_5]] - // CHECK: %[[TMP_28:.*]] = mhlo.power %[[TMP_27]], %[[TMP_3]] + // CHECK: %[[TMP_27:.*]] = mhlo.add %[[TMP_24]], %[[TMP_3]] + // CHECK: %[[TMP_28:.*]] = mhlo.power %[[TMP_27]], %[[TMP_4]] // CHECK: %[[TMP_29:.*]] = mhlo.add %[[TMP_26]], %[[TMP_28]] - // CHECK: %[[TMP_30:.*]] = mhlo.add %[[TMP_27]], %[[TMP_5]] - // CHECK: %[[TMP_31:.*]] = mhlo.power %[[TMP_30]], %[[TMP_3]] + // CHECK: %[[TMP_30:.*]] = mhlo.add %[[TMP_27]], %[[TMP_3]] + // CHECK: %[[TMP_31:.*]] = mhlo.power %[[TMP_30]], %[[TMP_4]] // CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_29]], %[[TMP_31]] - // CHECK: %[[TMP_33:.*]] = mhlo.add %[[TMP_30]], %[[TMP_5]] - // CHECK: %[[TMP_34:.*]] = mhlo.power %[[TMP_33]], %[[TMP_3]] + // CHECK: %[[TMP_33:.*]] = mhlo.add %[[TMP_30]], %[[TMP_3]] + // CHECK: %[[TMP_34:.*]] = mhlo.power %[[TMP_33]], %[[TMP_4]] // CHECK: %[[TMP_35:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_36:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_35]] - // CHECK: %[[TMP_37:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_33]] - // CHECK: %[[TMP_38:.*]] = mhlo.divide %[[TMP_37]], %[[TMP_36]] - // CHECK: %[[TMP_39:.*]] = mhlo.add %[[TMP_32]], %[[TMP_38]] - // CHECK: %[[TMP_40:.*]] = mhlo.multiply %[[TMP_33]], %[[TMP_33]] - // CHECK: %[[TMP_41:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_40]] - // CHECK: %[[TMP_42:.*]] = mhlo.constant dense<2.200000e+01> - // CHECK: %[[TMP_43:.*]] = mhlo.add %[[TMP_0]], %[[TMP_42]] - // CHECK: %[[TMP_44:.*]] = mhlo.constant dense<2.100000e+01> - // CHECK: %[[TMP_45:.*]] = mhlo.add %[[TMP_0]], %[[TMP_44]] - // CHECK: %[[TMP_46:.*]] = mhlo.multiply %[[TMP_43]], %[[TMP_45]] - // CHECK: %[[TMP_47:.*]] = mhlo.constant dense<-1.39544646E-19> - // CHECK: %[[TMP_48:.*]] = mhlo.add %[[TMP_2]], %[[TMP_47]] - // CHECK: %[[TMP_49:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_48]] - // CHECK: %[[TMP_50:.*]] = mhlo.multiply %[[TMP_46]], %[[TMP_49]] - // CHECK: %[[TMP_51:.*]] = mhlo.constant dense<2.000000e+01> - // CHECK: %[[TMP_52:.*]] = mhlo.add %[[TMP_0]], %[[TMP_51]] - // CHECK: %[[TMP_53:.*]] = mhlo.constant dense<1.900000e+01> - // CHECK: %[[TMP_54:.*]] = mhlo.add %[[TMP_0]], %[[TMP_53]] - // CHECK: %[[TMP_55:.*]] = mhlo.multiply %[[TMP_52]], %[[TMP_54]] - // CHECK: %[[TMP_56:.*]] = mhlo.constant dense<5.50900303E-18> - // CHECK: %[[TMP_57:.*]] = mhlo.add %[[TMP_50]], %[[TMP_56]] - // CHECK: %[[TMP_58:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_57]] - // CHECK: %[[TMP_59:.*]] = mhlo.multiply %[[TMP_55]], %[[TMP_58]] - // CHECK: %[[TMP_60:.*]] = mhlo.constant dense<1.800000e+01> - // CHECK: %[[TMP_61:.*]] = mhlo.add %[[TMP_0]], %[[TMP_60]] - // CHECK: %[[TMP_62:.*]] = mhlo.constant dense<1.700000e+01> - // CHECK: %[[TMP_63:.*]] = mhlo.add %[[TMP_0]], %[[TMP_62]] - // CHECK: %[[TMP_64:.*]] = mhlo.multiply %[[TMP_61]], %[[TMP_63]] - // CHECK: %[[TMP_65:.*]] = mhlo.constant dense<-2.17486866E-16> - // CHECK: %[[TMP_66:.*]] = mhlo.add %[[TMP_59]], %[[TMP_65]] - // CHECK: %[[TMP_67:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_66]] - // CHECK: %[[TMP_68:.*]] = mhlo.multiply %[[TMP_64]], %[[TMP_67]] - // CHECK: %[[TMP_69:.*]] = mhlo.constant dense<1.600000e+01> - // CHECK: %[[TMP_70:.*]] = mhlo.add %[[TMP_0]], %[[TMP_69]] - // CHECK: %[[TMP_71:.*]] = mhlo.constant dense<1.500000e+01> - // CHECK: %[[TMP_72:.*]] = mhlo.add %[[TMP_0]], %[[TMP_71]] - // CHECK: %[[TMP_73:.*]] = mhlo.multiply %[[TMP_70]], %[[TMP_72]] - // CHECK: %[[TMP_74:.*]] = mhlo.constant dense<8.58606213E-15> - // CHECK: %[[TMP_75:.*]] = mhlo.add %[[TMP_68]], %[[TMP_74]] - // CHECK: %[[TMP_76:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_75]] - // CHECK: %[[TMP_77:.*]] = mhlo.multiply %[[TMP_73]], %[[TMP_76]] - // CHECK: %[[TMP_78:.*]] = mhlo.constant dense<1.400000e+01> - // CHECK: %[[TMP_79:.*]] = mhlo.add %[[TMP_0]], %[[TMP_78]] - // CHECK: %[[TMP_80:.*]] = mhlo.constant dense<1.300000e+01> - // CHECK: %[[TMP_81:.*]] = mhlo.add %[[TMP_0]], %[[TMP_80]] - // CHECK: %[[TMP_82:.*]] = mhlo.multiply %[[TMP_79]], %[[TMP_81]] - // CHECK: %[[TMP_83:.*]] = mhlo.constant dense<-3.3896803E-13> - // CHECK: %[[TMP_84:.*]] = mhlo.add %[[TMP_77]], %[[TMP_83]] - // CHECK: %[[TMP_85:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_84]] - // CHECK: %[[TMP_86:.*]] = mhlo.multiply %[[TMP_82]], %[[TMP_85]] - // CHECK: %[[TMP_87:.*]] = mhlo.constant dense<1.200000e+01> - // CHECK: %[[TMP_88:.*]] = mhlo.add %[[TMP_0]], %[[TMP_87]] - // CHECK: %[[TMP_89:.*]] = mhlo.constant dense<1.100000e+01> - // CHECK: %[[TMP_90:.*]] = mhlo.add %[[TMP_0]], %[[TMP_89]] - // CHECK: %[[TMP_91:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_90]] - // CHECK: %[[TMP_92:.*]] = mhlo.constant dense<1.33825364E-11> - // CHECK: %[[TMP_93:.*]] = mhlo.add %[[TMP_86]], %[[TMP_92]] - // CHECK: %[[TMP_94:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_93]] - // CHECK: %[[TMP_95:.*]] = mhlo.multiply %[[TMP_91]], %[[TMP_94]] - // CHECK: %[[TMP_96:.*]] = mhlo.constant dense<1.000000e+01> - // CHECK: %[[TMP_97:.*]] = mhlo.add %[[TMP_0]], %[[TMP_96]] - // CHECK: %[[TMP_98:.*]] = mhlo.constant dense<9.000000e+00> - // CHECK: %[[TMP_99:.*]] = mhlo.add %[[TMP_0]], %[[TMP_98]] - // CHECK: %[[TMP_100:.*]] = mhlo.multiply %[[TMP_97]], %[[TMP_99]] - // CHECK: %[[TMP_101:.*]] = mhlo.constant dense<-5.28419031E-10> - // CHECK: %[[TMP_102:.*]] = mhlo.add %[[TMP_95]], %[[TMP_101]] - // CHECK: %[[TMP_103:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_102]] - // CHECK: %[[TMP_104:.*]] = mhlo.multiply %[[TMP_100]], %[[TMP_103]] - // CHECK: %[[TMP_105:.*]] = mhlo.constant dense<8.000000e+00> - // CHECK: %[[TMP_106:.*]] = mhlo.add %[[TMP_0]], %[[TMP_105]] - // CHECK: %[[TMP_107:.*]] = mhlo.constant dense<7.000000e+00> - // CHECK: %[[TMP_108:.*]] = mhlo.add %[[TMP_0]], %[[TMP_107]] - // CHECK: %[[TMP_109:.*]] = mhlo.multiply %[[TMP_106]], %[[TMP_108]] - // CHECK: %[[TMP_110:.*]] = mhlo.constant dense<2.08767563E-8> - // CHECK: %[[TMP_111:.*]] = mhlo.add %[[TMP_104]], %[[TMP_110]] - // CHECK: %[[TMP_112:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_111]] - // CHECK: %[[TMP_113:.*]] = mhlo.multiply %[[TMP_109]], %[[TMP_112]] - // CHECK: %[[TMP_114:.*]] = mhlo.constant dense<6.000000e+00> - // CHECK: %[[TMP_115:.*]] = mhlo.add %[[TMP_0]], %[[TMP_114]] - // CHECK: %[[TMP_116:.*]] = mhlo.constant dense<5.000000e+00> - // CHECK: %[[TMP_117:.*]] = mhlo.add %[[TMP_0]], %[[TMP_116]] - // CHECK: %[[TMP_118:.*]] = mhlo.multiply %[[TMP_115]], %[[TMP_117]] - // CHECK: %[[TMP_119:.*]] = mhlo.constant dense<-8.26719599E-7> - // CHECK: %[[TMP_120:.*]] = mhlo.add %[[TMP_113]], %[[TMP_119]] - // CHECK: %[[TMP_121:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_120]] - // CHECK: %[[TMP_122:.*]] = mhlo.multiply %[[TMP_118]], %[[TMP_121]] - // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<4.000000e+00> - // CHECK: %[[TMP_124:.*]] = mhlo.add %[[TMP_0]], %[[TMP_123]] - // CHECK: %[[TMP_125:.*]] = mhlo.constant dense<3.000000e+00> - // CHECK: %[[TMP_126:.*]] = mhlo.add %[[TMP_0]], %[[TMP_125]] - // CHECK: %[[TMP_127:.*]] = mhlo.multiply %[[TMP_124]], %[[TMP_126]] - // CHECK: %[[TMP_128:.*]] = mhlo.constant dense<3.30687835E-5> - // CHECK: %[[TMP_129:.*]] = mhlo.add %[[TMP_122]], %[[TMP_128]] - // CHECK: %[[TMP_130:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_129]] - // CHECK: %[[TMP_131:.*]] = mhlo.multiply %[[TMP_127]], %[[TMP_130]] - // CHECK: %[[TMP_132:.*]] = mhlo.constant dense<2.000000e+00> - // CHECK: %[[TMP_133:.*]] = mhlo.add %[[TMP_0]], %[[TMP_132]] - // CHECK: %[[TMP_134:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_135:.*]] = mhlo.add %[[TMP_0]], %[[TMP_134]] - // CHECK: %[[TMP_136:.*]] = mhlo.multiply %[[TMP_133]], %[[TMP_135]] - // CHECK: %[[TMP_137:.*]] = mhlo.constant dense<-0.00138888892> - // CHECK: %[[TMP_138:.*]] = mhlo.add %[[TMP_131]], %[[TMP_137]] - // CHECK: %[[TMP_139:.*]] = mhlo.multiply %[[TMP_41]], %[[TMP_138]] - // CHECK: %[[TMP_140:.*]] = mhlo.multiply %[[TMP_136]], %[[TMP_139]] - // CHECK: %[[TMP_141:.*]] = mhlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_142:.*]] = mhlo.divide %[[TMP_0]], %[[TMP_33]] - // CHECK: %[[TMP_143:.*]] = mhlo.constant dense<0.0833333358> - // CHECK: %[[TMP_144:.*]] = mhlo.add %[[TMP_143]], %[[TMP_140]] - // CHECK: %[[TMP_145:.*]] = mhlo.multiply %[[TMP_142]], %[[TMP_144]] - // CHECK: %[[TMP_146:.*]] = mhlo.add %[[TMP_141]], %[[TMP_145]] - // CHECK: %[[TMP_147:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_146]] - // CHECK: %[[TMP_148:.*]] = mhlo.add %[[TMP_39]], %[[TMP_147]] + // CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_33]] + // CHECK: %[[TMP_37:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_35]] + // CHECK: %[[TMP_38:.*]] = mhlo.divide %[[TMP_36]], %[[TMP_37]] + // CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[TMP_33]], %[[TMP_33]] + // CHECK: %[[TMP_40:.*]] = mhlo.divide %[[TMP_3]], %[[TMP_39]] + // CHECK: %[[TMP_41:.*]] = mhlo.constant dense<2.200000e+01> + // CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_0]], %[[TMP_41]] + // CHECK: %[[TMP_43:.*]] = mhlo.constant dense<2.100000e+01> + // CHECK: %[[TMP_44:.*]] = mhlo.add %[[TMP_0]], %[[TMP_43]] + // CHECK: %[[TMP_45:.*]] = mhlo.multiply %[[TMP_42]], %[[TMP_44]] + // CHECK: %[[TMP_46:.*]] = mhlo.constant dense<-1.39544646E-19> + // CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_2]], %[[TMP_46]] + // CHECK: %[[TMP_48:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_47]] + // CHECK: %[[TMP_49:.*]] = mhlo.multiply %[[TMP_45]], %[[TMP_48]] + // CHECK: %[[TMP_50:.*]] = mhlo.constant dense<2.000000e+01> + // CHECK: %[[TMP_51:.*]] = mhlo.add %[[TMP_0]], %[[TMP_50]] + // CHECK: %[[TMP_52:.*]] = mhlo.constant dense<1.900000e+01> + // CHECK: %[[TMP_53:.*]] = mhlo.add %[[TMP_0]], %[[TMP_52]] + // CHECK: %[[TMP_54:.*]] = mhlo.multiply %[[TMP_51]], %[[TMP_53]] + // CHECK: %[[TMP_55:.*]] = mhlo.constant dense<5.50900303E-18> + // CHECK: %[[TMP_56:.*]] = mhlo.add %[[TMP_49]], %[[TMP_55]] + // CHECK: %[[TMP_57:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_56]] + // CHECK: %[[TMP_58:.*]] = mhlo.multiply %[[TMP_54]], %[[TMP_57]] + // CHECK: %[[TMP_59:.*]] = mhlo.constant dense<1.800000e+01> + // CHECK: %[[TMP_60:.*]] = mhlo.add %[[TMP_0]], %[[TMP_59]] + // CHECK: %[[TMP_61:.*]] = mhlo.constant dense<1.700000e+01> + // CHECK: %[[TMP_62:.*]] = mhlo.add %[[TMP_0]], %[[TMP_61]] + // CHECK: %[[TMP_63:.*]] = mhlo.multiply %[[TMP_60]], %[[TMP_62]] + // CHECK: %[[TMP_64:.*]] = mhlo.constant dense<-2.17486866E-16> + // CHECK: %[[TMP_65:.*]] = mhlo.add %[[TMP_58]], %[[TMP_64]] + // CHECK: %[[TMP_66:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_65]] + // CHECK: %[[TMP_67:.*]] = mhlo.multiply %[[TMP_63]], %[[TMP_66]] + // CHECK: %[[TMP_68:.*]] = mhlo.constant dense<1.600000e+01> + // CHECK: %[[TMP_69:.*]] = mhlo.add %[[TMP_0]], %[[TMP_68]] + // CHECK: %[[TMP_70:.*]] = mhlo.constant dense<1.500000e+01> + // CHECK: %[[TMP_71:.*]] = mhlo.add %[[TMP_0]], %[[TMP_70]] + // CHECK: %[[TMP_72:.*]] = mhlo.multiply %[[TMP_69]], %[[TMP_71]] + // CHECK: %[[TMP_73:.*]] = mhlo.constant dense<8.58606213E-15> + // CHECK: %[[TMP_74:.*]] = mhlo.add %[[TMP_67]], %[[TMP_73]] + // CHECK: %[[TMP_75:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_74]] + // CHECK: %[[TMP_76:.*]] = mhlo.multiply %[[TMP_72]], %[[TMP_75]] + // CHECK: %[[TMP_77:.*]] = mhlo.constant dense<1.400000e+01> + // CHECK: %[[TMP_78:.*]] = mhlo.add %[[TMP_0]], %[[TMP_77]] + // CHECK: %[[TMP_79:.*]] = mhlo.constant dense<1.300000e+01> + // CHECK: %[[TMP_80:.*]] = mhlo.add %[[TMP_0]], %[[TMP_79]] + // CHECK: %[[TMP_81:.*]] = mhlo.multiply %[[TMP_78]], %[[TMP_80]] + // CHECK: %[[TMP_82:.*]] = mhlo.constant dense<-3.3896803E-13> + // CHECK: %[[TMP_83:.*]] = mhlo.add %[[TMP_76]], %[[TMP_82]] + // CHECK: %[[TMP_84:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_83]] + // CHECK: %[[TMP_85:.*]] = mhlo.multiply %[[TMP_81]], %[[TMP_84]] + // CHECK: %[[TMP_86:.*]] = mhlo.constant dense<1.200000e+01> + // CHECK: %[[TMP_87:.*]] = mhlo.add %[[TMP_0]], %[[TMP_86]] + // CHECK: %[[TMP_88:.*]] = mhlo.constant dense<1.100000e+01> + // CHECK: %[[TMP_89:.*]] = mhlo.add %[[TMP_0]], %[[TMP_88]] + // CHECK: %[[TMP_90:.*]] = mhlo.multiply %[[TMP_87]], %[[TMP_89]] + // CHECK: %[[TMP_91:.*]] = mhlo.constant dense<1.33825364E-11> + // CHECK: %[[TMP_92:.*]] = mhlo.add %[[TMP_85]], %[[TMP_91]] + // CHECK: %[[TMP_93:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_92]] + // CHECK: %[[TMP_94:.*]] = mhlo.multiply %[[TMP_90]], %[[TMP_93]] + // CHECK: %[[TMP_95:.*]] = mhlo.constant dense<1.000000e+01> + // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_0]], %[[TMP_95]] + // CHECK: %[[TMP_97:.*]] = mhlo.constant dense<9.000000e+00> + // CHECK: %[[TMP_98:.*]] = mhlo.add %[[TMP_0]], %[[TMP_97]] + // CHECK: %[[TMP_99:.*]] = mhlo.multiply %[[TMP_96]], %[[TMP_98]] + // CHECK: %[[TMP_100:.*]] = mhlo.constant dense<-5.28419031E-10> + // CHECK: %[[TMP_101:.*]] = mhlo.add %[[TMP_94]], %[[TMP_100]] + // CHECK: %[[TMP_102:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_101]] + // CHECK: %[[TMP_103:.*]] = mhlo.multiply %[[TMP_99]], %[[TMP_102]] + // CHECK: %[[TMP_104:.*]] = mhlo.constant dense<8.000000e+00> + // CHECK: %[[TMP_105:.*]] = mhlo.add %[[TMP_0]], %[[TMP_104]] + // CHECK: %[[TMP_106:.*]] = mhlo.constant dense<7.000000e+00> + // CHECK: %[[TMP_107:.*]] = mhlo.add %[[TMP_0]], %[[TMP_106]] + // CHECK: %[[TMP_108:.*]] = mhlo.multiply %[[TMP_105]], %[[TMP_107]] + // CHECK: %[[TMP_109:.*]] = mhlo.constant dense<2.08767563E-8> + // CHECK: %[[TMP_110:.*]] = mhlo.add %[[TMP_103]], %[[TMP_109]] + // CHECK: %[[TMP_111:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_110]] + // CHECK: %[[TMP_112:.*]] = mhlo.multiply %[[TMP_108]], %[[TMP_111]] + // CHECK: %[[TMP_113:.*]] = mhlo.constant dense<6.000000e+00> + // CHECK: %[[TMP_114:.*]] = mhlo.add %[[TMP_0]], %[[TMP_113]] + // CHECK: %[[TMP_115:.*]] = mhlo.constant dense<5.000000e+00> + // CHECK: %[[TMP_116:.*]] = mhlo.add %[[TMP_0]], %[[TMP_115]] + // CHECK: %[[TMP_117:.*]] = mhlo.multiply %[[TMP_114]], %[[TMP_116]] + // CHECK: %[[TMP_118:.*]] = mhlo.constant dense<-8.26719599E-7> + // CHECK: %[[TMP_119:.*]] = mhlo.add %[[TMP_112]], %[[TMP_118]] + // CHECK: %[[TMP_120:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_119]] + // CHECK: %[[TMP_121:.*]] = mhlo.multiply %[[TMP_117]], %[[TMP_120]] + // CHECK: %[[TMP_122:.*]] = mhlo.constant dense<4.000000e+00> + // CHECK: %[[TMP_123:.*]] = mhlo.add %[[TMP_0]], %[[TMP_122]] + // CHECK: %[[TMP_124:.*]] = mhlo.constant dense<3.000000e+00> + // CHECK: %[[TMP_125:.*]] = mhlo.add %[[TMP_0]], %[[TMP_124]] + // CHECK: %[[TMP_126:.*]] = mhlo.multiply %[[TMP_123]], %[[TMP_125]] + // CHECK: %[[TMP_127:.*]] = mhlo.constant dense<3.30687835E-5> + // CHECK: %[[TMP_128:.*]] = mhlo.add %[[TMP_121]], %[[TMP_127]] + // CHECK: %[[TMP_129:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_128]] + // CHECK: %[[TMP_130:.*]] = mhlo.multiply %[[TMP_126]], %[[TMP_129]] + // CHECK: %[[TMP_131:.*]] = mhlo.constant dense<2.000000e+00> + // CHECK: %[[TMP_132:.*]] = mhlo.add %[[TMP_0]], %[[TMP_131]] + // CHECK: %[[TMP_133:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_134:.*]] = mhlo.add %[[TMP_0]], %[[TMP_133]] + // CHECK: %[[TMP_135:.*]] = mhlo.multiply %[[TMP_132]], %[[TMP_134]] + // CHECK: %[[TMP_136:.*]] = mhlo.constant dense<-0.00138888892> + // CHECK: %[[TMP_137:.*]] = mhlo.add %[[TMP_130]], %[[TMP_136]] + // CHECK: %[[TMP_138:.*]] = mhlo.multiply %[[TMP_40]], %[[TMP_137]] + // CHECK: %[[TMP_139:.*]] = mhlo.multiply %[[TMP_135]], %[[TMP_138]] + // CHECK: %[[TMP_140:.*]] = mhlo.constant dense<5.000000e-01> + // CHECK: %[[TMP_141:.*]] = mhlo.divide %[[TMP_0]], %[[TMP_33]] + // CHECK: %[[TMP_142:.*]] = mhlo.constant dense<0.0833333358> + // CHECK: %[[TMP_143:.*]] = mhlo.add %[[TMP_142]], %[[TMP_139]] + // CHECK: %[[TMP_144:.*]] = mhlo.multiply %[[TMP_141]], %[[TMP_143]] + // CHECK: %[[TMP_145:.*]] = mhlo.add %[[TMP_140]], %[[TMP_144]] + // CHECK: %[[TMP_146:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_145]] + // CHECK: %[[TMP_147:.*]] = mhlo.add %[[TMP_32]], %[[TMP_38]] + // CHECK: %[[TMP_148:.*]] = mhlo.add %[[TMP_147]], %[[TMP_146]] // CHECK: %[[TMP_149:.*]] = mhlo.abs %[[TMP_34]] // CHECK: %[[TMP_150:.*]] = mhlo.abs %[[TMP_32]] // CHECK: %[[TMP_151:.*]] = mhlo.constant dense<1.401300e-45> @@ -1282,7 +1282,7 @@ func.func @zeta_f16(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[TMP_172:.*]] = mhlo.and %[[TMP_169]], %[[TMP_171]] : tensor // CHECK: %[[TMP_173:.*]] = mhlo.select %[[TMP_172]], %[[TMP_163]], %[[TMP_155]] // CHECK: %[[TMP_174:.*]] = mhlo.select %[[TMP_166]], %[[TMP_173]], %[[TMP_162]] - // CHECK: %[[TMP_175:.*]] = mhlo.compare EQ, %[[TMP_0]], %[[TMP_5]], NOTYPE + // CHECK: %[[TMP_175:.*]] = mhlo.compare EQ, %[[TMP_0]], %[[TMP_3]], NOTYPE // CHECK: %[[TMP_176:.*]] = mhlo.select %[[TMP_175]], %[[TMP_163]], %[[TMP_174]] // CHECK: %[[TMP_177:.*]] = mhlo.convert %[[TMP_176]] : (tensor) -> tensor %0 = chlo.zeta %arg0, %arg1 : tensor, tensor -> tensor @@ -1384,153 +1384,153 @@ func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_87:.*]] = mhlo.constant dense<0x7F800000> // CHECK: %[[TMP_88:.*]] = mhlo.select %[[TMP_86]], %[[TMP_87]], %[[TMP_83]] // CHECK: %[[TMP_89:.*]] = mhlo.exponential %[[TMP_88]] - // CHECK: %[[TMP_90:.*]] = mhlo.constant dense<0.000000e+00> - // CHECK: %[[TMP_91:.*]] = mhlo.negate %[[TMP_5]] - // CHECK: %[[TMP_92:.*]] = mhlo.power %[[ARG1]], %[[TMP_91]] - // CHECK: %[[TMP_93:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_94:.*]] = mhlo.add %[[ARG1]], %[[TMP_93]] - // CHECK: %[[TMP_95:.*]] = mhlo.power %[[TMP_94]], %[[TMP_91]] - // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_92]], %[[TMP_95]] - // CHECK: %[[TMP_97:.*]] = mhlo.add %[[TMP_94]], %[[TMP_93]] - // CHECK: %[[TMP_98:.*]] = mhlo.power %[[TMP_97]], %[[TMP_91]] + // CHECK-DAG: %[[TMP_90:.*]] = mhlo.constant dense<0.000000e+00> + // CHECK-DAG: %[[TMP_91:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_92:.*]] = mhlo.negate %[[TMP_5]] + // CHECK: %[[TMP_93:.*]] = mhlo.power %[[ARG1]], %[[TMP_92]] + // CHECK: %[[TMP_94:.*]] = mhlo.add %[[ARG1]], %[[TMP_91]] + // CHECK: %[[TMP_95:.*]] = mhlo.power %[[TMP_94]], %[[TMP_92]] + // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_93]], %[[TMP_95]] + // CHECK: %[[TMP_97:.*]] = mhlo.add %[[TMP_94]], %[[TMP_91]] + // CHECK: %[[TMP_98:.*]] = mhlo.power %[[TMP_97]], %[[TMP_92]] // CHECK: %[[TMP_99:.*]] = mhlo.add %[[TMP_96]], %[[TMP_98]] - // CHECK: %[[TMP_100:.*]] = mhlo.add %[[TMP_97]], %[[TMP_93]] - // CHECK: %[[TMP_101:.*]] = mhlo.power %[[TMP_100]], %[[TMP_91]] + // CHECK: %[[TMP_100:.*]] = mhlo.add %[[TMP_97]], %[[TMP_91]] + // CHECK: %[[TMP_101:.*]] = mhlo.power %[[TMP_100]], %[[TMP_92]] // CHECK: %[[TMP_102:.*]] = mhlo.add %[[TMP_99]], %[[TMP_101]] - // CHECK: %[[TMP_103:.*]] = mhlo.add %[[TMP_100]], %[[TMP_93]] - // CHECK: %[[TMP_104:.*]] = mhlo.power %[[TMP_103]], %[[TMP_91]] + // CHECK: %[[TMP_103:.*]] = mhlo.add %[[TMP_100]], %[[TMP_91]] + // CHECK: %[[TMP_104:.*]] = mhlo.power %[[TMP_103]], %[[TMP_92]] // CHECK: %[[TMP_105:.*]] = mhlo.add %[[TMP_102]], %[[TMP_104]] - // CHECK: %[[TMP_106:.*]] = mhlo.add %[[TMP_103]], %[[TMP_93]] - // CHECK: %[[TMP_107:.*]] = mhlo.power %[[TMP_106]], %[[TMP_91]] + // CHECK: %[[TMP_106:.*]] = mhlo.add %[[TMP_103]], %[[TMP_91]] + // CHECK: %[[TMP_107:.*]] = mhlo.power %[[TMP_106]], %[[TMP_92]] // CHECK: %[[TMP_108:.*]] = mhlo.add %[[TMP_105]], %[[TMP_107]] - // CHECK: %[[TMP_109:.*]] = mhlo.add %[[TMP_106]], %[[TMP_93]] - // CHECK: %[[TMP_110:.*]] = mhlo.power %[[TMP_109]], %[[TMP_91]] + // CHECK: %[[TMP_109:.*]] = mhlo.add %[[TMP_106]], %[[TMP_91]] + // CHECK: %[[TMP_110:.*]] = mhlo.power %[[TMP_109]], %[[TMP_92]] // CHECK: %[[TMP_111:.*]] = mhlo.add %[[TMP_108]], %[[TMP_110]] - // CHECK: %[[TMP_112:.*]] = mhlo.add %[[TMP_109]], %[[TMP_93]] - // CHECK: %[[TMP_113:.*]] = mhlo.power %[[TMP_112]], %[[TMP_91]] + // CHECK: %[[TMP_112:.*]] = mhlo.add %[[TMP_109]], %[[TMP_91]] + // CHECK: %[[TMP_113:.*]] = mhlo.power %[[TMP_112]], %[[TMP_92]] // CHECK: %[[TMP_114:.*]] = mhlo.add %[[TMP_111]], %[[TMP_113]] - // CHECK: %[[TMP_115:.*]] = mhlo.add %[[TMP_112]], %[[TMP_93]] - // CHECK: %[[TMP_116:.*]] = mhlo.power %[[TMP_115]], %[[TMP_91]] + // CHECK: %[[TMP_115:.*]] = mhlo.add %[[TMP_112]], %[[TMP_91]] + // CHECK: %[[TMP_116:.*]] = mhlo.power %[[TMP_115]], %[[TMP_92]] // CHECK: %[[TMP_117:.*]] = mhlo.add %[[TMP_114]], %[[TMP_116]] - // CHECK: %[[TMP_118:.*]] = mhlo.add %[[TMP_115]], %[[TMP_93]] - // CHECK: %[[TMP_119:.*]] = mhlo.power %[[TMP_118]], %[[TMP_91]] + // CHECK: %[[TMP_118:.*]] = mhlo.add %[[TMP_115]], %[[TMP_91]] + // CHECK: %[[TMP_119:.*]] = mhlo.power %[[TMP_118]], %[[TMP_92]] // CHECK: %[[TMP_120:.*]] = mhlo.add %[[TMP_117]], %[[TMP_119]] - // CHECK: %[[TMP_121:.*]] = mhlo.add %[[TMP_118]], %[[TMP_93]] - // CHECK: %[[TMP_122:.*]] = mhlo.power %[[TMP_121]], %[[TMP_91]] + // CHECK: %[[TMP_121:.*]] = mhlo.add %[[TMP_118]], %[[TMP_91]] + // CHECK: %[[TMP_122:.*]] = mhlo.power %[[TMP_121]], %[[TMP_92]] // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_124:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] - // CHECK: %[[TMP_125:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] - // CHECK: %[[TMP_126:.*]] = mhlo.divide %[[TMP_125]], %[[TMP_124]] - // CHECK: %[[TMP_127:.*]] = mhlo.add %[[TMP_120]], %[[TMP_126]] - // CHECK: %[[TMP_128:.*]] = mhlo.multiply %[[TMP_121]], %[[TMP_121]] - // CHECK: %[[TMP_129:.*]] = mhlo.divide %[[TMP_93]], %[[TMP_128]] - // CHECK: %[[TMP_130:.*]] = mhlo.constant dense<2.200000e+01> - // CHECK: %[[TMP_131:.*]] = mhlo.add %[[TMP_5]], %[[TMP_130]] - // CHECK: %[[TMP_132:.*]] = mhlo.constant dense<2.100000e+01> - // CHECK: %[[TMP_133:.*]] = mhlo.add %[[TMP_5]], %[[TMP_132]] - // CHECK: %[[TMP_134:.*]] = mhlo.multiply %[[TMP_131]], %[[TMP_133]] - // CHECK: %[[TMP_135:.*]] = mhlo.constant dense<-1.39544646E-19> - // CHECK: %[[TMP_136:.*]] = mhlo.add %[[TMP_90]], %[[TMP_135]] - // CHECK: %[[TMP_137:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_136]] - // CHECK: %[[TMP_138:.*]] = mhlo.multiply %[[TMP_134]], %[[TMP_137]] - // CHECK: %[[TMP_139:.*]] = mhlo.constant dense<2.000000e+01> - // CHECK: %[[TMP_140:.*]] = mhlo.add %[[TMP_5]], %[[TMP_139]] - // CHECK: %[[TMP_141:.*]] = mhlo.constant dense<1.900000e+01> - // CHECK: %[[TMP_142:.*]] = mhlo.add %[[TMP_5]], %[[TMP_141]] - // CHECK: %[[TMP_143:.*]] = mhlo.multiply %[[TMP_140]], %[[TMP_142]] - // CHECK: %[[TMP_144:.*]] = mhlo.constant dense<5.50900303E-18> - // CHECK: %[[TMP_145:.*]] = mhlo.add %[[TMP_138]], %[[TMP_144]] - // CHECK: %[[TMP_146:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_145]] - // CHECK: %[[TMP_147:.*]] = mhlo.multiply %[[TMP_143]], %[[TMP_146]] - // CHECK: %[[TMP_148:.*]] = mhlo.constant dense<1.800000e+01> - // CHECK: %[[TMP_149:.*]] = mhlo.add %[[TMP_5]], %[[TMP_148]] - // CHECK: %[[TMP_150:.*]] = mhlo.constant dense<1.700000e+01> - // CHECK: %[[TMP_151:.*]] = mhlo.add %[[TMP_5]], %[[TMP_150]] - // CHECK: %[[TMP_152:.*]] = mhlo.multiply %[[TMP_149]], %[[TMP_151]] - // CHECK: %[[TMP_153:.*]] = mhlo.constant dense<-2.17486866E-16> - // CHECK: %[[TMP_154:.*]] = mhlo.add %[[TMP_147]], %[[TMP_153]] - // CHECK: %[[TMP_155:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_154]] - // CHECK: %[[TMP_156:.*]] = mhlo.multiply %[[TMP_152]], %[[TMP_155]] - // CHECK: %[[TMP_157:.*]] = mhlo.constant dense<1.600000e+01> - // CHECK: %[[TMP_158:.*]] = mhlo.add %[[TMP_5]], %[[TMP_157]] - // CHECK: %[[TMP_159:.*]] = mhlo.constant dense<1.500000e+01> - // CHECK: %[[TMP_160:.*]] = mhlo.add %[[TMP_5]], %[[TMP_159]] - // CHECK: %[[TMP_161:.*]] = mhlo.multiply %[[TMP_158]], %[[TMP_160]] - // CHECK: %[[TMP_162:.*]] = mhlo.constant dense<8.58606213E-15> - // CHECK: %[[TMP_163:.*]] = mhlo.add %[[TMP_156]], %[[TMP_162]] - // CHECK: %[[TMP_164:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_163]] - // CHECK: %[[TMP_165:.*]] = mhlo.multiply %[[TMP_161]], %[[TMP_164]] - // CHECK: %[[TMP_166:.*]] = mhlo.constant dense<1.400000e+01> - // CHECK: %[[TMP_167:.*]] = mhlo.add %[[TMP_5]], %[[TMP_166]] - // CHECK: %[[TMP_168:.*]] = mhlo.constant dense<1.300000e+01> - // CHECK: %[[TMP_169:.*]] = mhlo.add %[[TMP_5]], %[[TMP_168]] - // CHECK: %[[TMP_170:.*]] = mhlo.multiply %[[TMP_167]], %[[TMP_169]] - // CHECK: %[[TMP_171:.*]] = mhlo.constant dense<-3.3896803E-13> - // CHECK: %[[TMP_172:.*]] = mhlo.add %[[TMP_165]], %[[TMP_171]] - // CHECK: %[[TMP_173:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_172]] - // CHECK: %[[TMP_174:.*]] = mhlo.multiply %[[TMP_170]], %[[TMP_173]] - // CHECK: %[[TMP_175:.*]] = mhlo.constant dense<1.200000e+01> - // CHECK: %[[TMP_176:.*]] = mhlo.add %[[TMP_5]], %[[TMP_175]] - // CHECK: %[[TMP_177:.*]] = mhlo.constant dense<1.100000e+01> - // CHECK: %[[TMP_178:.*]] = mhlo.add %[[TMP_5]], %[[TMP_177]] - // CHECK: %[[TMP_179:.*]] = mhlo.multiply %[[TMP_176]], %[[TMP_178]] - // CHECK: %[[TMP_180:.*]] = mhlo.constant dense<1.33825364E-11> - // CHECK: %[[TMP_181:.*]] = mhlo.add %[[TMP_174]], %[[TMP_180]] - // CHECK: %[[TMP_182:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_181]] - // CHECK: %[[TMP_183:.*]] = mhlo.multiply %[[TMP_179]], %[[TMP_182]] - // CHECK: %[[TMP_184:.*]] = mhlo.constant dense<1.000000e+01> - // CHECK: %[[TMP_185:.*]] = mhlo.add %[[TMP_5]], %[[TMP_184]] - // CHECK: %[[TMP_186:.*]] = mhlo.constant dense<9.000000e+00> - // CHECK: %[[TMP_187:.*]] = mhlo.add %[[TMP_5]], %[[TMP_186]] - // CHECK: %[[TMP_188:.*]] = mhlo.multiply %[[TMP_185]], %[[TMP_187]] - // CHECK: %[[TMP_189:.*]] = mhlo.constant dense<-5.28419031E-10> - // CHECK: %[[TMP_190:.*]] = mhlo.add %[[TMP_183]], %[[TMP_189]] - // CHECK: %[[TMP_191:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_190]] - // CHECK: %[[TMP_192:.*]] = mhlo.multiply %[[TMP_188]], %[[TMP_191]] - // CHECK: %[[TMP_193:.*]] = mhlo.constant dense<8.000000e+00> - // CHECK: %[[TMP_194:.*]] = mhlo.add %[[TMP_5]], %[[TMP_193]] - // CHECK: %[[TMP_195:.*]] = mhlo.constant dense<7.000000e+00> - // CHECK: %[[TMP_196:.*]] = mhlo.add %[[TMP_5]], %[[TMP_195]] - // CHECK: %[[TMP_197:.*]] = mhlo.multiply %[[TMP_194]], %[[TMP_196]] - // CHECK: %[[TMP_198:.*]] = mhlo.constant dense<2.08767563E-8> - // CHECK: %[[TMP_199:.*]] = mhlo.add %[[TMP_192]], %[[TMP_198]] - // CHECK: %[[TMP_200:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_199]] - // CHECK: %[[TMP_201:.*]] = mhlo.multiply %[[TMP_197]], %[[TMP_200]] - // CHECK: %[[TMP_202:.*]] = mhlo.constant dense<6.000000e+00> - // CHECK: %[[TMP_203:.*]] = mhlo.add %[[TMP_5]], %[[TMP_202]] - // CHECK: %[[TMP_204:.*]] = mhlo.constant dense<5.000000e+00> - // CHECK: %[[TMP_205:.*]] = mhlo.add %[[TMP_5]], %[[TMP_204]] - // CHECK: %[[TMP_206:.*]] = mhlo.multiply %[[TMP_203]], %[[TMP_205]] - // CHECK: %[[TMP_207:.*]] = mhlo.constant dense<-8.26719599E-7> - // CHECK: %[[TMP_208:.*]] = mhlo.add %[[TMP_201]], %[[TMP_207]] - // CHECK: %[[TMP_209:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_208]] - // CHECK: %[[TMP_210:.*]] = mhlo.multiply %[[TMP_206]], %[[TMP_209]] - // CHECK: %[[TMP_211:.*]] = mhlo.constant dense<4.000000e+00> - // CHECK: %[[TMP_212:.*]] = mhlo.add %[[TMP_5]], %[[TMP_211]] - // CHECK: %[[TMP_213:.*]] = mhlo.constant dense<3.000000e+00> - // CHECK: %[[TMP_214:.*]] = mhlo.add %[[TMP_5]], %[[TMP_213]] - // CHECK: %[[TMP_215:.*]] = mhlo.multiply %[[TMP_212]], %[[TMP_214]] - // CHECK: %[[TMP_216:.*]] = mhlo.constant dense<3.30687835E-5> - // CHECK: %[[TMP_217:.*]] = mhlo.add %[[TMP_210]], %[[TMP_216]] - // CHECK: %[[TMP_218:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_217]] - // CHECK: %[[TMP_219:.*]] = mhlo.multiply %[[TMP_215]], %[[TMP_218]] - // CHECK: %[[TMP_220:.*]] = mhlo.constant dense<2.000000e+00> - // CHECK: %[[TMP_221:.*]] = mhlo.add %[[TMP_5]], %[[TMP_220]] - // CHECK: %[[TMP_222:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_223:.*]] = mhlo.add %[[TMP_5]], %[[TMP_222]] - // CHECK: %[[TMP_224:.*]] = mhlo.multiply %[[TMP_221]], %[[TMP_223]] - // CHECK: %[[TMP_225:.*]] = mhlo.constant dense<-0.00138888892> - // CHECK: %[[TMP_226:.*]] = mhlo.add %[[TMP_219]], %[[TMP_225]] - // CHECK: %[[TMP_227:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_226]] - // CHECK: %[[TMP_228:.*]] = mhlo.multiply %[[TMP_224]], %[[TMP_227]] - // CHECK: %[[TMP_229:.*]] = mhlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_230:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_121]] - // CHECK: %[[TMP_231:.*]] = mhlo.constant dense<0.0833333358> - // CHECK: %[[TMP_232:.*]] = mhlo.add %[[TMP_231]], %[[TMP_228]] - // CHECK: %[[TMP_233:.*]] = mhlo.multiply %[[TMP_230]], %[[TMP_232]] - // CHECK: %[[TMP_234:.*]] = mhlo.add %[[TMP_229]], %[[TMP_233]] - // CHECK: %[[TMP_235:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_234]] - // CHECK: %[[TMP_236:.*]] = mhlo.add %[[TMP_127]], %[[TMP_235]] + // CHECK: %[[TMP_124:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] + // CHECK: %[[TMP_125:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] + // CHECK: %[[TMP_126:.*]] = mhlo.divide %[[TMP_124]], %[[TMP_125]] + // CHECK: %[[TMP_127:.*]] = mhlo.multiply %[[TMP_121]], %[[TMP_121]] + // CHECK: %[[TMP_128:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_127]] + // CHECK: %[[TMP_129:.*]] = mhlo.constant dense<2.200000e+01> + // CHECK: %[[TMP_130:.*]] = mhlo.add %[[TMP_5]], %[[TMP_129]] + // CHECK: %[[TMP_131:.*]] = mhlo.constant dense<2.100000e+01> + // CHECK: %[[TMP_132:.*]] = mhlo.add %[[TMP_5]], %[[TMP_131]] + // CHECK: %[[TMP_133:.*]] = mhlo.multiply %[[TMP_130]], %[[TMP_132]] + // CHECK: %[[TMP_134:.*]] = mhlo.constant dense<-1.39544646E-19> + // CHECK: %[[TMP_135:.*]] = mhlo.add %[[TMP_90]], %[[TMP_134]] + // CHECK: %[[TMP_136:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_135]] + // CHECK: %[[TMP_137:.*]] = mhlo.multiply %[[TMP_133]], %[[TMP_136]] + // CHECK: %[[TMP_138:.*]] = mhlo.constant dense<2.000000e+01> + // CHECK: %[[TMP_139:.*]] = mhlo.add %[[TMP_5]], %[[TMP_138]] + // CHECK: %[[TMP_140:.*]] = mhlo.constant dense<1.900000e+01> + // CHECK: %[[TMP_141:.*]] = mhlo.add %[[TMP_5]], %[[TMP_140]] + // CHECK: %[[TMP_142:.*]] = mhlo.multiply %[[TMP_139]], %[[TMP_141]] + // CHECK: %[[TMP_143:.*]] = mhlo.constant dense<5.50900303E-18> + // CHECK: %[[TMP_144:.*]] = mhlo.add %[[TMP_137]], %[[TMP_143]] + // CHECK: %[[TMP_145:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_144]] + // CHECK: %[[TMP_146:.*]] = mhlo.multiply %[[TMP_142]], %[[TMP_145]] + // CHECK: %[[TMP_147:.*]] = mhlo.constant dense<1.800000e+01> + // CHECK: %[[TMP_148:.*]] = mhlo.add %[[TMP_5]], %[[TMP_147]] + // CHECK: %[[TMP_149:.*]] = mhlo.constant dense<1.700000e+01> + // CHECK: %[[TMP_150:.*]] = mhlo.add %[[TMP_5]], %[[TMP_149]] + // CHECK: %[[TMP_151:.*]] = mhlo.multiply %[[TMP_148]], %[[TMP_150]] + // CHECK: %[[TMP_152:.*]] = mhlo.constant dense<-2.17486866E-16> + // CHECK: %[[TMP_153:.*]] = mhlo.add %[[TMP_146]], %[[TMP_152]] + // CHECK: %[[TMP_154:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_153]] + // CHECK: %[[TMP_155:.*]] = mhlo.multiply %[[TMP_151]], %[[TMP_154]] + // CHECK: %[[TMP_156:.*]] = mhlo.constant dense<1.600000e+01> + // CHECK: %[[TMP_157:.*]] = mhlo.add %[[TMP_5]], %[[TMP_156]] + // CHECK: %[[TMP_158:.*]] = mhlo.constant dense<1.500000e+01> + // CHECK: %[[TMP_159:.*]] = mhlo.add %[[TMP_5]], %[[TMP_158]] + // CHECK: %[[TMP_160:.*]] = mhlo.multiply %[[TMP_157]], %[[TMP_159]] + // CHECK: %[[TMP_161:.*]] = mhlo.constant dense<8.58606213E-15> + // CHECK: %[[TMP_162:.*]] = mhlo.add %[[TMP_155]], %[[TMP_161]] + // CHECK: %[[TMP_163:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_162]] + // CHECK: %[[TMP_164:.*]] = mhlo.multiply %[[TMP_160]], %[[TMP_163]] + // CHECK: %[[TMP_165:.*]] = mhlo.constant dense<1.400000e+01> + // CHECK: %[[TMP_166:.*]] = mhlo.add %[[TMP_5]], %[[TMP_165]] + // CHECK: %[[TMP_167:.*]] = mhlo.constant dense<1.300000e+01> + // CHECK: %[[TMP_168:.*]] = mhlo.add %[[TMP_5]], %[[TMP_167]] + // CHECK: %[[TMP_169:.*]] = mhlo.multiply %[[TMP_166]], %[[TMP_168]] + // CHECK: %[[TMP_170:.*]] = mhlo.constant dense<-3.3896803E-13> + // CHECK: %[[TMP_171:.*]] = mhlo.add %[[TMP_164]], %[[TMP_170]] + // CHECK: %[[TMP_172:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_171]] + // CHECK: %[[TMP_173:.*]] = mhlo.multiply %[[TMP_169]], %[[TMP_172]] + // CHECK: %[[TMP_174:.*]] = mhlo.constant dense<1.200000e+01> + // CHECK: %[[TMP_175:.*]] = mhlo.add %[[TMP_5]], %[[TMP_174]] + // CHECK: %[[TMP_176:.*]] = mhlo.constant dense<1.100000e+01> + // CHECK: %[[TMP_177:.*]] = mhlo.add %[[TMP_5]], %[[TMP_176]] + // CHECK: %[[TMP_178:.*]] = mhlo.multiply %[[TMP_175]], %[[TMP_177]] + // CHECK: %[[TMP_179:.*]] = mhlo.constant dense<1.33825364E-11> + // CHECK: %[[TMP_180:.*]] = mhlo.add %[[TMP_173]], %[[TMP_179]] + // CHECK: %[[TMP_181:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_180]] + // CHECK: %[[TMP_182:.*]] = mhlo.multiply %[[TMP_178]], %[[TMP_181]] + // CHECK: %[[TMP_183:.*]] = mhlo.constant dense<1.000000e+01> + // CHECK: %[[TMP_184:.*]] = mhlo.add %[[TMP_5]], %[[TMP_183]] + // CHECK: %[[TMP_185:.*]] = mhlo.constant dense<9.000000e+00> + // CHECK: %[[TMP_186:.*]] = mhlo.add %[[TMP_5]], %[[TMP_185]] + // CHECK: %[[TMP_187:.*]] = mhlo.multiply %[[TMP_184]], %[[TMP_186]] + // CHECK: %[[TMP_188:.*]] = mhlo.constant dense<-5.28419031E-10> + // CHECK: %[[TMP_189:.*]] = mhlo.add %[[TMP_182]], %[[TMP_188]] + // CHECK: %[[TMP_190:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_189]] + // CHECK: %[[TMP_191:.*]] = mhlo.multiply %[[TMP_187]], %[[TMP_190]] + // CHECK: %[[TMP_192:.*]] = mhlo.constant dense<8.000000e+00> + // CHECK: %[[TMP_193:.*]] = mhlo.add %[[TMP_5]], %[[TMP_192]] + // CHECK: %[[TMP_194:.*]] = mhlo.constant dense<7.000000e+00> + // CHECK: %[[TMP_195:.*]] = mhlo.add %[[TMP_5]], %[[TMP_194]] + // CHECK: %[[TMP_196:.*]] = mhlo.multiply %[[TMP_193]], %[[TMP_195]] + // CHECK: %[[TMP_197:.*]] = mhlo.constant dense<2.08767563E-8> + // CHECK: %[[TMP_198:.*]] = mhlo.add %[[TMP_191]], %[[TMP_197]] + // CHECK: %[[TMP_199:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_198]] + // CHECK: %[[TMP_200:.*]] = mhlo.multiply %[[TMP_196]], %[[TMP_199]] + // CHECK: %[[TMP_201:.*]] = mhlo.constant dense<6.000000e+00> + // CHECK: %[[TMP_202:.*]] = mhlo.add %[[TMP_5]], %[[TMP_201]] + // CHECK: %[[TMP_203:.*]] = mhlo.constant dense<5.000000e+00> + // CHECK: %[[TMP_204:.*]] = mhlo.add %[[TMP_5]], %[[TMP_203]] + // CHECK: %[[TMP_205:.*]] = mhlo.multiply %[[TMP_202]], %[[TMP_204]] + // CHECK: %[[TMP_206:.*]] = mhlo.constant dense<-8.26719599E-7> + // CHECK: %[[TMP_207:.*]] = mhlo.add %[[TMP_200]], %[[TMP_206]] + // CHECK: %[[TMP_208:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_207]] + // CHECK: %[[TMP_209:.*]] = mhlo.multiply %[[TMP_205]], %[[TMP_208]] + // CHECK: %[[TMP_210:.*]] = mhlo.constant dense<4.000000e+00> + // CHECK: %[[TMP_211:.*]] = mhlo.add %[[TMP_5]], %[[TMP_210]] + // CHECK: %[[TMP_212:.*]] = mhlo.constant dense<3.000000e+00> + // CHECK: %[[TMP_213:.*]] = mhlo.add %[[TMP_5]], %[[TMP_212]] + // CHECK: %[[TMP_214:.*]] = mhlo.multiply %[[TMP_211]], %[[TMP_213]] + // CHECK: %[[TMP_215:.*]] = mhlo.constant dense<3.30687835E-5> + // CHECK: %[[TMP_216:.*]] = mhlo.add %[[TMP_209]], %[[TMP_215]] + // CHECK: %[[TMP_217:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_216]] + // CHECK: %[[TMP_218:.*]] = mhlo.multiply %[[TMP_214]], %[[TMP_217]] + // CHECK: %[[TMP_219:.*]] = mhlo.constant dense<2.000000e+00> + // CHECK: %[[TMP_220:.*]] = mhlo.add %[[TMP_5]], %[[TMP_219]] + // CHECK: %[[TMP_221:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_222:.*]] = mhlo.add %[[TMP_5]], %[[TMP_221]] + // CHECK: %[[TMP_223:.*]] = mhlo.multiply %[[TMP_220]], %[[TMP_222]] + // CHECK: %[[TMP_224:.*]] = mhlo.constant dense<-0.00138888892> + // CHECK: %[[TMP_225:.*]] = mhlo.add %[[TMP_218]], %[[TMP_224]] + // CHECK: %[[TMP_226:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_225]] + // CHECK: %[[TMP_227:.*]] = mhlo.multiply %[[TMP_223]], %[[TMP_226]] + // CHECK: %[[TMP_228:.*]] = mhlo.constant dense<5.000000e-01> + // CHECK: %[[TMP_229:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_121]] + // CHECK: %[[TMP_230:.*]] = mhlo.constant dense<0.0833333358> + // CHECK: %[[TMP_231:.*]] = mhlo.add %[[TMP_230]], %[[TMP_227]] + // CHECK: %[[TMP_232:.*]] = mhlo.multiply %[[TMP_229]], %[[TMP_231]] + // CHECK: %[[TMP_233:.*]] = mhlo.add %[[TMP_228]], %[[TMP_232]] + // CHECK: %[[TMP_234:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_233]] + // CHECK: %[[TMP_235:.*]] = mhlo.add %[[TMP_120]], %[[TMP_126]] + // CHECK: %[[TMP_236:.*]] = mhlo.add %[[TMP_235]], %[[TMP_234]] // CHECK: %[[TMP_237:.*]] = mhlo.abs %[[TMP_122]] // CHECK: %[[TMP_238:.*]] = mhlo.abs %[[TMP_120]] // CHECK: %[[TMP_239:.*]] = mhlo.constant dense<1.401300e-45> @@ -1557,7 +1557,7 @@ func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_260:.*]] = mhlo.and %[[TMP_257]], %[[TMP_259]] // CHECK: %[[TMP_261:.*]] = mhlo.select %[[TMP_260]], %[[TMP_251]], %[[TMP_243]] // CHECK: %[[TMP_262:.*]] = mhlo.select %[[TMP_254]], %[[TMP_261]], %[[TMP_250]] - // CHECK: %[[TMP_263:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_93]], NOTYPE + // CHECK: %[[TMP_263:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_91]], NOTYPE // CHECK: %[[TMP_264:.*]] = mhlo.select %[[TMP_263]], %[[TMP_251]], %[[TMP_262]] // CHECK: %[[TMP_265:.*]] = mhlo.multiply %[[TMP_4]], %[[TMP_89]] // CHECK: %[[TMP_266:.*]] = mhlo.multiply %[[TMP_265]], %[[TMP_264]] @@ -1771,153 +1771,153 @@ func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_87:.*]] = mhlo.constant dense<0x7FF0000000000000> // CHECK: %[[TMP_88:.*]] = mhlo.select %[[TMP_86]], %[[TMP_87]], %[[TMP_83]] // CHECK: %[[TMP_89:.*]] = mhlo.exponential %[[TMP_88]] - // CHECK: %[[TMP_90:.*]] = mhlo.constant dense<0.000000e+00> - // CHECK: %[[TMP_91:.*]] = mhlo.negate %[[TMP_5]] - // CHECK: %[[TMP_92:.*]] = mhlo.power %[[ARG1]], %[[TMP_91]] - // CHECK: %[[TMP_93:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_94:.*]] = mhlo.add %[[ARG1]], %[[TMP_93]] - // CHECK: %[[TMP_95:.*]] = mhlo.power %[[TMP_94]], %[[TMP_91]] - // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_92]], %[[TMP_95]] - // CHECK: %[[TMP_97:.*]] = mhlo.add %[[TMP_94]], %[[TMP_93]] - // CHECK: %[[TMP_98:.*]] = mhlo.power %[[TMP_97]], %[[TMP_91]] + // CHECK-DAG: %[[TMP_90:.*]] = mhlo.constant dense<0.000000e+00> + // CHECK-DAG: %[[TMP_91:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_92:.*]] = mhlo.negate %[[TMP_5]] + // CHECK: %[[TMP_93:.*]] = mhlo.power %[[ARG1]], %[[TMP_92]] + // CHECK: %[[TMP_94:.*]] = mhlo.add %[[ARG1]], %[[TMP_91]] + // CHECK: %[[TMP_95:.*]] = mhlo.power %[[TMP_94]], %[[TMP_92]] + // CHECK: %[[TMP_96:.*]] = mhlo.add %[[TMP_93]], %[[TMP_95]] + // CHECK: %[[TMP_97:.*]] = mhlo.add %[[TMP_94]], %[[TMP_91]] + // CHECK: %[[TMP_98:.*]] = mhlo.power %[[TMP_97]], %[[TMP_92]] // CHECK: %[[TMP_99:.*]] = mhlo.add %[[TMP_96]], %[[TMP_98]] - // CHECK: %[[TMP_100:.*]] = mhlo.add %[[TMP_97]], %[[TMP_93]] - // CHECK: %[[TMP_101:.*]] = mhlo.power %[[TMP_100]], %[[TMP_91]] + // CHECK: %[[TMP_100:.*]] = mhlo.add %[[TMP_97]], %[[TMP_91]] + // CHECK: %[[TMP_101:.*]] = mhlo.power %[[TMP_100]], %[[TMP_92]] // CHECK: %[[TMP_102:.*]] = mhlo.add %[[TMP_99]], %[[TMP_101]] - // CHECK: %[[TMP_103:.*]] = mhlo.add %[[TMP_100]], %[[TMP_93]] - // CHECK: %[[TMP_104:.*]] = mhlo.power %[[TMP_103]], %[[TMP_91]] + // CHECK: %[[TMP_103:.*]] = mhlo.add %[[TMP_100]], %[[TMP_91]] + // CHECK: %[[TMP_104:.*]] = mhlo.power %[[TMP_103]], %[[TMP_92]] // CHECK: %[[TMP_105:.*]] = mhlo.add %[[TMP_102]], %[[TMP_104]] - // CHECK: %[[TMP_106:.*]] = mhlo.add %[[TMP_103]], %[[TMP_93]] - // CHECK: %[[TMP_107:.*]] = mhlo.power %[[TMP_106]], %[[TMP_91]] + // CHECK: %[[TMP_106:.*]] = mhlo.add %[[TMP_103]], %[[TMP_91]] + // CHECK: %[[TMP_107:.*]] = mhlo.power %[[TMP_106]], %[[TMP_92]] // CHECK: %[[TMP_108:.*]] = mhlo.add %[[TMP_105]], %[[TMP_107]] - // CHECK: %[[TMP_109:.*]] = mhlo.add %[[TMP_106]], %[[TMP_93]] - // CHECK: %[[TMP_110:.*]] = mhlo.power %[[TMP_109]], %[[TMP_91]] + // CHECK: %[[TMP_109:.*]] = mhlo.add %[[TMP_106]], %[[TMP_91]] + // CHECK: %[[TMP_110:.*]] = mhlo.power %[[TMP_109]], %[[TMP_92]] // CHECK: %[[TMP_111:.*]] = mhlo.add %[[TMP_108]], %[[TMP_110]] - // CHECK: %[[TMP_112:.*]] = mhlo.add %[[TMP_109]], %[[TMP_93]] - // CHECK: %[[TMP_113:.*]] = mhlo.power %[[TMP_112]], %[[TMP_91]] + // CHECK: %[[TMP_112:.*]] = mhlo.add %[[TMP_109]], %[[TMP_91]] + // CHECK: %[[TMP_113:.*]] = mhlo.power %[[TMP_112]], %[[TMP_92]] // CHECK: %[[TMP_114:.*]] = mhlo.add %[[TMP_111]], %[[TMP_113]] - // CHECK: %[[TMP_115:.*]] = mhlo.add %[[TMP_112]], %[[TMP_93]] - // CHECK: %[[TMP_116:.*]] = mhlo.power %[[TMP_115]], %[[TMP_91]] + // CHECK: %[[TMP_115:.*]] = mhlo.add %[[TMP_112]], %[[TMP_91]] + // CHECK: %[[TMP_116:.*]] = mhlo.power %[[TMP_115]], %[[TMP_92]] // CHECK: %[[TMP_117:.*]] = mhlo.add %[[TMP_114]], %[[TMP_116]] - // CHECK: %[[TMP_118:.*]] = mhlo.add %[[TMP_115]], %[[TMP_93]] - // CHECK: %[[TMP_119:.*]] = mhlo.power %[[TMP_118]], %[[TMP_91]] + // CHECK: %[[TMP_118:.*]] = mhlo.add %[[TMP_115]], %[[TMP_91]] + // CHECK: %[[TMP_119:.*]] = mhlo.power %[[TMP_118]], %[[TMP_92]] // CHECK: %[[TMP_120:.*]] = mhlo.add %[[TMP_117]], %[[TMP_119]] - // CHECK: %[[TMP_121:.*]] = mhlo.add %[[TMP_118]], %[[TMP_93]] - // CHECK: %[[TMP_122:.*]] = mhlo.power %[[TMP_121]], %[[TMP_91]] + // CHECK: %[[TMP_121:.*]] = mhlo.add %[[TMP_118]], %[[TMP_91]] + // CHECK: %[[TMP_122:.*]] = mhlo.power %[[TMP_121]], %[[TMP_92]] // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_124:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] - // CHECK: %[[TMP_125:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] - // CHECK: %[[TMP_126:.*]] = mhlo.divide %[[TMP_125]], %[[TMP_124]] - // CHECK: %[[TMP_127:.*]] = mhlo.add %[[TMP_120]], %[[TMP_126]] - // CHECK: %[[TMP_128:.*]] = mhlo.multiply %[[TMP_121]], %[[TMP_121]] - // CHECK: %[[TMP_129:.*]] = mhlo.divide %[[TMP_93]], %[[TMP_128]] - // CHECK: %[[TMP_130:.*]] = mhlo.constant dense<2.200000e+01> - // CHECK: %[[TMP_131:.*]] = mhlo.add %[[TMP_5]], %[[TMP_130]] - // CHECK: %[[TMP_132:.*]] = mhlo.constant dense<2.100000e+01> - // CHECK: %[[TMP_133:.*]] = mhlo.add %[[TMP_5]], %[[TMP_132]] - // CHECK: %[[TMP_134:.*]] = mhlo.multiply %[[TMP_131]], %[[TMP_133]] - // CHECK: %[[TMP_135:.*]] = mhlo.constant dense<-1.3954464685812522E-19> - // CHECK: %[[TMP_136:.*]] = mhlo.add %[[TMP_90]], %[[TMP_135]] - // CHECK: %[[TMP_137:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_136]] - // CHECK: %[[TMP_138:.*]] = mhlo.multiply %[[TMP_134]], %[[TMP_137]] - // CHECK: %[[TMP_139:.*]] = mhlo.constant dense<2.000000e+01> - // CHECK: %[[TMP_140:.*]] = mhlo.add %[[TMP_5]], %[[TMP_139]] - // CHECK: %[[TMP_141:.*]] = mhlo.constant dense<1.900000e+01> - // CHECK: %[[TMP_142:.*]] = mhlo.add %[[TMP_5]], %[[TMP_141]] - // CHECK: %[[TMP_143:.*]] = mhlo.multiply %[[TMP_140]], %[[TMP_142]] - // CHECK: %[[TMP_144:.*]] = mhlo.constant dense<5.5090028283602295E-18> - // CHECK: %[[TMP_145:.*]] = mhlo.add %[[TMP_138]], %[[TMP_144]] - // CHECK: %[[TMP_146:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_145]] - // CHECK: %[[TMP_147:.*]] = mhlo.multiply %[[TMP_143]], %[[TMP_146]] - // CHECK: %[[TMP_148:.*]] = mhlo.constant dense<1.800000e+01> - // CHECK: %[[TMP_149:.*]] = mhlo.add %[[TMP_5]], %[[TMP_148]] - // CHECK: %[[TMP_150:.*]] = mhlo.constant dense<1.700000e+01> - // CHECK: %[[TMP_151:.*]] = mhlo.add %[[TMP_5]], %[[TMP_150]] - // CHECK: %[[TMP_152:.*]] = mhlo.multiply %[[TMP_149]], %[[TMP_151]] - // CHECK: %[[TMP_153:.*]] = mhlo.constant dense<-2.1748686985580617E-16> - // CHECK: %[[TMP_154:.*]] = mhlo.add %[[TMP_147]], %[[TMP_153]] - // CHECK: %[[TMP_155:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_154]] - // CHECK: %[[TMP_156:.*]] = mhlo.multiply %[[TMP_152]], %[[TMP_155]] - // CHECK: %[[TMP_157:.*]] = mhlo.constant dense<1.600000e+01> - // CHECK: %[[TMP_158:.*]] = mhlo.add %[[TMP_5]], %[[TMP_157]] - // CHECK: %[[TMP_159:.*]] = mhlo.constant dense<1.500000e+01> - // CHECK: %[[TMP_160:.*]] = mhlo.add %[[TMP_5]], %[[TMP_159]] - // CHECK: %[[TMP_161:.*]] = mhlo.multiply %[[TMP_158]], %[[TMP_160]] - // CHECK: %[[TMP_162:.*]] = mhlo.constant dense<8.5860620562778452E-15> - // CHECK: %[[TMP_163:.*]] = mhlo.add %[[TMP_156]], %[[TMP_162]] - // CHECK: %[[TMP_164:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_163]] - // CHECK: %[[TMP_165:.*]] = mhlo.multiply %[[TMP_161]], %[[TMP_164]] - // CHECK: %[[TMP_166:.*]] = mhlo.constant dense<1.400000e+01> - // CHECK: %[[TMP_167:.*]] = mhlo.add %[[TMP_5]], %[[TMP_166]] - // CHECK: %[[TMP_168:.*]] = mhlo.constant dense<1.300000e+01> - // CHECK: %[[TMP_169:.*]] = mhlo.add %[[TMP_5]], %[[TMP_168]] - // CHECK: %[[TMP_170:.*]] = mhlo.multiply %[[TMP_167]], %[[TMP_169]] - // CHECK: %[[TMP_171:.*]] = mhlo.constant dense<-3.3896802963225832E-13> - // CHECK: %[[TMP_172:.*]] = mhlo.add %[[TMP_165]], %[[TMP_171]] - // CHECK: %[[TMP_173:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_172]] - // CHECK: %[[TMP_174:.*]] = mhlo.multiply %[[TMP_170]], %[[TMP_173]] - // CHECK: %[[TMP_175:.*]] = mhlo.constant dense<1.200000e+01> - // CHECK: %[[TMP_176:.*]] = mhlo.add %[[TMP_5]], %[[TMP_175]] - // CHECK: %[[TMP_177:.*]] = mhlo.constant dense<1.100000e+01> - // CHECK: %[[TMP_178:.*]] = mhlo.add %[[TMP_5]], %[[TMP_177]] - // CHECK: %[[TMP_179:.*]] = mhlo.multiply %[[TMP_176]], %[[TMP_178]] - // CHECK: %[[TMP_180:.*]] = mhlo.constant dense<1.3382536530684679E-11> - // CHECK: %[[TMP_181:.*]] = mhlo.add %[[TMP_174]], %[[TMP_180]] - // CHECK: %[[TMP_182:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_181]] - // CHECK: %[[TMP_183:.*]] = mhlo.multiply %[[TMP_179]], %[[TMP_182]] - // CHECK: %[[TMP_184:.*]] = mhlo.constant dense<1.000000e+01> - // CHECK: %[[TMP_185:.*]] = mhlo.add %[[TMP_5]], %[[TMP_184]] - // CHECK: %[[TMP_186:.*]] = mhlo.constant dense<9.000000e+00> - // CHECK: %[[TMP_187:.*]] = mhlo.add %[[TMP_5]], %[[TMP_186]] - // CHECK: %[[TMP_188:.*]] = mhlo.multiply %[[TMP_185]], %[[TMP_187]] - // CHECK: %[[TMP_189:.*]] = mhlo.constant dense<-5.2841901386874932E-10> - // CHECK: %[[TMP_190:.*]] = mhlo.add %[[TMP_183]], %[[TMP_189]] - // CHECK: %[[TMP_191:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_190]] - // CHECK: %[[TMP_192:.*]] = mhlo.multiply %[[TMP_188]], %[[TMP_191]] - // CHECK: %[[TMP_193:.*]] = mhlo.constant dense<8.000000e+00> - // CHECK: %[[TMP_194:.*]] = mhlo.add %[[TMP_5]], %[[TMP_193]] - // CHECK: %[[TMP_195:.*]] = mhlo.constant dense<7.000000e+00> - // CHECK: %[[TMP_196:.*]] = mhlo.add %[[TMP_5]], %[[TMP_195]] - // CHECK: %[[TMP_197:.*]] = mhlo.multiply %[[TMP_194]], %[[TMP_196]] - // CHECK: %[[TMP_198:.*]] = mhlo.constant dense<2.08767569878681E-8> - // CHECK: %[[TMP_199:.*]] = mhlo.add %[[TMP_192]], %[[TMP_198]] - // CHECK: %[[TMP_200:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_199]] - // CHECK: %[[TMP_201:.*]] = mhlo.multiply %[[TMP_197]], %[[TMP_200]] - // CHECK: %[[TMP_202:.*]] = mhlo.constant dense<6.000000e+00> - // CHECK: %[[TMP_203:.*]] = mhlo.add %[[TMP_5]], %[[TMP_202]] - // CHECK: %[[TMP_204:.*]] = mhlo.constant dense<5.000000e+00> - // CHECK: %[[TMP_205:.*]] = mhlo.add %[[TMP_5]], %[[TMP_204]] - // CHECK: %[[TMP_206:.*]] = mhlo.multiply %[[TMP_203]], %[[TMP_205]] - // CHECK: %[[TMP_207:.*]] = mhlo.constant dense<-8.2671957671957675E-7> - // CHECK: %[[TMP_208:.*]] = mhlo.add %[[TMP_201]], %[[TMP_207]] - // CHECK: %[[TMP_209:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_208]] - // CHECK: %[[TMP_210:.*]] = mhlo.multiply %[[TMP_206]], %[[TMP_209]] - // CHECK: %[[TMP_211:.*]] = mhlo.constant dense<4.000000e+00> - // CHECK: %[[TMP_212:.*]] = mhlo.add %[[TMP_5]], %[[TMP_211]] - // CHECK: %[[TMP_213:.*]] = mhlo.constant dense<3.000000e+00> - // CHECK: %[[TMP_214:.*]] = mhlo.add %[[TMP_5]], %[[TMP_213]] - // CHECK: %[[TMP_215:.*]] = mhlo.multiply %[[TMP_212]], %[[TMP_214]] - // CHECK: %[[TMP_216:.*]] = mhlo.constant dense<3.3068783068783071E-5> - // CHECK: %[[TMP_217:.*]] = mhlo.add %[[TMP_210]], %[[TMP_216]] - // CHECK: %[[TMP_218:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_217]] - // CHECK: %[[TMP_219:.*]] = mhlo.multiply %[[TMP_215]], %[[TMP_218]] - // CHECK: %[[TMP_220:.*]] = mhlo.constant dense<2.000000e+00> - // CHECK: %[[TMP_221:.*]] = mhlo.add %[[TMP_5]], %[[TMP_220]] - // CHECK: %[[TMP_222:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_223:.*]] = mhlo.add %[[TMP_5]], %[[TMP_222]] - // CHECK: %[[TMP_224:.*]] = mhlo.multiply %[[TMP_221]], %[[TMP_223]] - // CHECK: %[[TMP_225:.*]] = mhlo.constant dense<-0.0013888888888888889> - // CHECK: %[[TMP_226:.*]] = mhlo.add %[[TMP_219]], %[[TMP_225]] - // CHECK: %[[TMP_227:.*]] = mhlo.multiply %[[TMP_129]], %[[TMP_226]] - // CHECK: %[[TMP_228:.*]] = mhlo.multiply %[[TMP_224]], %[[TMP_227]] - // CHECK: %[[TMP_229:.*]] = mhlo.constant dense<5.000000e-01> - // CHECK: %[[TMP_230:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_121]] - // CHECK: %[[TMP_231:.*]] = mhlo.constant dense<0.083333333333333329> - // CHECK: %[[TMP_232:.*]] = mhlo.add %[[TMP_231]], %[[TMP_228]] - // CHECK: %[[TMP_233:.*]] = mhlo.multiply %[[TMP_230]], %[[TMP_232]] - // CHECK: %[[TMP_234:.*]] = mhlo.add %[[TMP_229]], %[[TMP_233]] - // CHECK: %[[TMP_235:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_234]] - // CHECK: %[[TMP_236:.*]] = mhlo.add %[[TMP_127]], %[[TMP_235]] + // CHECK: %[[TMP_124:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] + // CHECK: %[[TMP_125:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] + // CHECK: %[[TMP_126:.*]] = mhlo.divide %[[TMP_124]], %[[TMP_125]] + // CHECK: %[[TMP_127:.*]] = mhlo.multiply %[[TMP_121]], %[[TMP_121]] + // CHECK: %[[TMP_128:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_127]] + // CHECK: %[[TMP_129:.*]] = mhlo.constant dense<2.200000e+01> + // CHECK: %[[TMP_130:.*]] = mhlo.add %[[TMP_5]], %[[TMP_129]] + // CHECK: %[[TMP_131:.*]] = mhlo.constant dense<2.100000e+01> + // CHECK: %[[TMP_132:.*]] = mhlo.add %[[TMP_5]], %[[TMP_131]] + // CHECK: %[[TMP_133:.*]] = mhlo.multiply %[[TMP_130]], %[[TMP_132]] + // CHECK: %[[TMP_134:.*]] = mhlo.constant dense<-1.3954464685812522E-19> + // CHECK: %[[TMP_135:.*]] = mhlo.add %[[TMP_90]], %[[TMP_134]] + // CHECK: %[[TMP_136:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_135]] + // CHECK: %[[TMP_137:.*]] = mhlo.multiply %[[TMP_133]], %[[TMP_136]] + // CHECK: %[[TMP_138:.*]] = mhlo.constant dense<2.000000e+01> + // CHECK: %[[TMP_139:.*]] = mhlo.add %[[TMP_5]], %[[TMP_138]] + // CHECK: %[[TMP_140:.*]] = mhlo.constant dense<1.900000e+01> + // CHECK: %[[TMP_141:.*]] = mhlo.add %[[TMP_5]], %[[TMP_140]] + // CHECK: %[[TMP_142:.*]] = mhlo.multiply %[[TMP_139]], %[[TMP_141]] + // CHECK: %[[TMP_143:.*]] = mhlo.constant dense<5.5090028283602295E-18> + // CHECK: %[[TMP_144:.*]] = mhlo.add %[[TMP_137]], %[[TMP_143]] + // CHECK: %[[TMP_145:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_144]] + // CHECK: %[[TMP_146:.*]] = mhlo.multiply %[[TMP_142]], %[[TMP_145]] + // CHECK: %[[TMP_147:.*]] = mhlo.constant dense<1.800000e+01> + // CHECK: %[[TMP_148:.*]] = mhlo.add %[[TMP_5]], %[[TMP_147]] + // CHECK: %[[TMP_149:.*]] = mhlo.constant dense<1.700000e+01> + // CHECK: %[[TMP_150:.*]] = mhlo.add %[[TMP_5]], %[[TMP_149]] + // CHECK: %[[TMP_151:.*]] = mhlo.multiply %[[TMP_148]], %[[TMP_150]] + // CHECK: %[[TMP_152:.*]] = mhlo.constant dense<-2.1748686985580617E-16> + // CHECK: %[[TMP_153:.*]] = mhlo.add %[[TMP_146]], %[[TMP_152]] + // CHECK: %[[TMP_154:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_153]] + // CHECK: %[[TMP_155:.*]] = mhlo.multiply %[[TMP_151]], %[[TMP_154]] + // CHECK: %[[TMP_156:.*]] = mhlo.constant dense<1.600000e+01> + // CHECK: %[[TMP_157:.*]] = mhlo.add %[[TMP_5]], %[[TMP_156]] + // CHECK: %[[TMP_158:.*]] = mhlo.constant dense<1.500000e+01> + // CHECK: %[[TMP_159:.*]] = mhlo.add %[[TMP_5]], %[[TMP_158]] + // CHECK: %[[TMP_160:.*]] = mhlo.multiply %[[TMP_157]], %[[TMP_159]] + // CHECK: %[[TMP_161:.*]] = mhlo.constant dense<8.5860620562778452E-15> + // CHECK: %[[TMP_162:.*]] = mhlo.add %[[TMP_155]], %[[TMP_161]] + // CHECK: %[[TMP_163:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_162]] + // CHECK: %[[TMP_164:.*]] = mhlo.multiply %[[TMP_160]], %[[TMP_163]] + // CHECK: %[[TMP_165:.*]] = mhlo.constant dense<1.400000e+01> + // CHECK: %[[TMP_166:.*]] = mhlo.add %[[TMP_5]], %[[TMP_165]] + // CHECK: %[[TMP_167:.*]] = mhlo.constant dense<1.300000e+01> + // CHECK: %[[TMP_168:.*]] = mhlo.add %[[TMP_5]], %[[TMP_167]] + // CHECK: %[[TMP_169:.*]] = mhlo.multiply %[[TMP_166]], %[[TMP_168]] + // CHECK: %[[TMP_170:.*]] = mhlo.constant dense<-3.3896802963225832E-13> + // CHECK: %[[TMP_171:.*]] = mhlo.add %[[TMP_164]], %[[TMP_170]] + // CHECK: %[[TMP_172:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_171]] + // CHECK: %[[TMP_173:.*]] = mhlo.multiply %[[TMP_169]], %[[TMP_172]] + // CHECK: %[[TMP_174:.*]] = mhlo.constant dense<1.200000e+01> + // CHECK: %[[TMP_175:.*]] = mhlo.add %[[TMP_5]], %[[TMP_174]] + // CHECK: %[[TMP_176:.*]] = mhlo.constant dense<1.100000e+01> + // CHECK: %[[TMP_177:.*]] = mhlo.add %[[TMP_5]], %[[TMP_176]] + // CHECK: %[[TMP_178:.*]] = mhlo.multiply %[[TMP_175]], %[[TMP_177]] + // CHECK: %[[TMP_179:.*]] = mhlo.constant dense<1.3382536530684679E-11> + // CHECK: %[[TMP_180:.*]] = mhlo.add %[[TMP_173]], %[[TMP_179]] + // CHECK: %[[TMP_181:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_180]] + // CHECK: %[[TMP_182:.*]] = mhlo.multiply %[[TMP_178]], %[[TMP_181]] + // CHECK: %[[TMP_183:.*]] = mhlo.constant dense<1.000000e+01> + // CHECK: %[[TMP_184:.*]] = mhlo.add %[[TMP_5]], %[[TMP_183]] + // CHECK: %[[TMP_185:.*]] = mhlo.constant dense<9.000000e+00> + // CHECK: %[[TMP_186:.*]] = mhlo.add %[[TMP_5]], %[[TMP_185]] + // CHECK: %[[TMP_187:.*]] = mhlo.multiply %[[TMP_184]], %[[TMP_186]] + // CHECK: %[[TMP_188:.*]] = mhlo.constant dense<-5.2841901386874932E-10> + // CHECK: %[[TMP_189:.*]] = mhlo.add %[[TMP_182]], %[[TMP_188]] + // CHECK: %[[TMP_190:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_189]] + // CHECK: %[[TMP_191:.*]] = mhlo.multiply %[[TMP_187]], %[[TMP_190]] + // CHECK: %[[TMP_192:.*]] = mhlo.constant dense<8.000000e+00> + // CHECK: %[[TMP_193:.*]] = mhlo.add %[[TMP_5]], %[[TMP_192]] + // CHECK: %[[TMP_194:.*]] = mhlo.constant dense<7.000000e+00> + // CHECK: %[[TMP_195:.*]] = mhlo.add %[[TMP_5]], %[[TMP_194]] + // CHECK: %[[TMP_196:.*]] = mhlo.multiply %[[TMP_193]], %[[TMP_195]] + // CHECK: %[[TMP_197:.*]] = mhlo.constant dense<2.08767569878681E-8> + // CHECK: %[[TMP_198:.*]] = mhlo.add %[[TMP_191]], %[[TMP_197]] + // CHECK: %[[TMP_199:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_198]] + // CHECK: %[[TMP_200:.*]] = mhlo.multiply %[[TMP_196]], %[[TMP_199]] + // CHECK: %[[TMP_201:.*]] = mhlo.constant dense<6.000000e+00> + // CHECK: %[[TMP_202:.*]] = mhlo.add %[[TMP_5]], %[[TMP_201]] + // CHECK: %[[TMP_203:.*]] = mhlo.constant dense<5.000000e+00> + // CHECK: %[[TMP_204:.*]] = mhlo.add %[[TMP_5]], %[[TMP_203]] + // CHECK: %[[TMP_205:.*]] = mhlo.multiply %[[TMP_202]], %[[TMP_204]] + // CHECK: %[[TMP_206:.*]] = mhlo.constant dense<-8.2671957671957675E-7> + // CHECK: %[[TMP_207:.*]] = mhlo.add %[[TMP_200]], %[[TMP_206]] + // CHECK: %[[TMP_208:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_207]] + // CHECK: %[[TMP_209:.*]] = mhlo.multiply %[[TMP_205]], %[[TMP_208]] + // CHECK: %[[TMP_210:.*]] = mhlo.constant dense<4.000000e+00> + // CHECK: %[[TMP_211:.*]] = mhlo.add %[[TMP_5]], %[[TMP_210]] + // CHECK: %[[TMP_212:.*]] = mhlo.constant dense<3.000000e+00> + // CHECK: %[[TMP_213:.*]] = mhlo.add %[[TMP_5]], %[[TMP_212]] + // CHECK: %[[TMP_214:.*]] = mhlo.multiply %[[TMP_211]], %[[TMP_213]] + // CHECK: %[[TMP_215:.*]] = mhlo.constant dense<3.3068783068783071E-5> + // CHECK: %[[TMP_216:.*]] = mhlo.add %[[TMP_209]], %[[TMP_215]] + // CHECK: %[[TMP_217:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_216]] + // CHECK: %[[TMP_218:.*]] = mhlo.multiply %[[TMP_214]], %[[TMP_217]] + // CHECK: %[[TMP_219:.*]] = mhlo.constant dense<2.000000e+00> + // CHECK: %[[TMP_220:.*]] = mhlo.add %[[TMP_5]], %[[TMP_219]] + // CHECK: %[[TMP_221:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_222:.*]] = mhlo.add %[[TMP_5]], %[[TMP_221]] + // CHECK: %[[TMP_223:.*]] = mhlo.multiply %[[TMP_220]], %[[TMP_222]] + // CHECK: %[[TMP_224:.*]] = mhlo.constant dense<-0.0013888888888888889> + // CHECK: %[[TMP_225:.*]] = mhlo.add %[[TMP_218]], %[[TMP_224]] + // CHECK: %[[TMP_226:.*]] = mhlo.multiply %[[TMP_128]], %[[TMP_225]] + // CHECK: %[[TMP_227:.*]] = mhlo.multiply %[[TMP_223]], %[[TMP_226]] + // CHECK: %[[TMP_228:.*]] = mhlo.constant dense<5.000000e-01> + // CHECK: %[[TMP_229:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_121]] + // CHECK: %[[TMP_230:.*]] = mhlo.constant dense<0.083333333333333329> + // CHECK: %[[TMP_231:.*]] = mhlo.add %[[TMP_230]], %[[TMP_227]] + // CHECK: %[[TMP_232:.*]] = mhlo.multiply %[[TMP_229]], %[[TMP_231]] + // CHECK: %[[TMP_233:.*]] = mhlo.add %[[TMP_228]], %[[TMP_232]] + // CHECK: %[[TMP_234:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_233]] + // CHECK: %[[TMP_235:.*]] = mhlo.add %[[TMP_120]], %[[TMP_126]] + // CHECK: %[[TMP_236:.*]] = mhlo.add %[[TMP_235]], %[[TMP_234]] // CHECK: %[[TMP_237:.*]] = mhlo.abs %[[TMP_122]] // CHECK: %[[TMP_238:.*]] = mhlo.abs %[[TMP_120]] // CHECK: %[[TMP_239:.*]] = mhlo.constant dense<4.940660e-324> @@ -1944,7 +1944,7 @@ func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_260:.*]] = mhlo.and %[[TMP_257]], %[[TMP_259]] // CHECK: %[[TMP_261:.*]] = mhlo.select %[[TMP_260]], %[[TMP_251]], %[[TMP_243]] // CHECK: %[[TMP_262:.*]] = mhlo.select %[[TMP_254]], %[[TMP_261]], %[[TMP_250]] - // CHECK: %[[TMP_263:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_93]], NOTYPE + // CHECK: %[[TMP_263:.*]] = mhlo.compare EQ, %[[TMP_5]], %[[TMP_91]], NOTYPE // CHECK: %[[TMP_264:.*]] = mhlo.select %[[TMP_263]], %[[TMP_251]], %[[TMP_262]] // CHECK: %[[TMP_265:.*]] = mhlo.multiply %[[TMP_4]], %[[TMP_89]] // CHECK: %[[TMP_266:.*]] = mhlo.multiply %[[TMP_265]], %[[TMP_264]] From 19efc8f0a30f727f6349fcbdf3b19fb2dc0316ae Mon Sep 17 00:00:00 2001 From: Majid Dadashi Date: Fri, 29 Mar 2024 10:11:36 -0700 Subject: [PATCH 087/124] [tflite] Add the composite lowering logic for hardswish This directly lowers an aten hardswish to a tflite hardswish PiperOrigin-RevId: 620273404 --- .../stablehlo/tests/composite-lowering.mlir | 34 +++++++++++++++++++ .../transforms/composite_lowering_patterns.td | 6 ++++ 2 files changed, 40 insertions(+) create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir new file mode 100644 index 00000000000000..5924d0dce396c4 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir @@ -0,0 +1,34 @@ +// RUN: odml-to-stablehlo-opt -composite-lowering -verify-diagnostics %s | FileCheck %s + +func.func @hardswish(%arg0: tensor<2xf32>) -> (tensor<*xf32>) { + %0 = mhlo.composite "aten.hardswish.default" %arg0 {decomposition = @XlaCallModule_aten.hardswish.default.impl_0} : (tensor<2xf32>) -> tensor<2xf32> + %1 = "tf.Identity"(%0) {device = ""} : (tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Identity"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> +} +func.func private @XlaCallModule_aten.hardswish.default.impl_0(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = mhlo.constant dense<6.000000e+00> : tensor + %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + %2 = mhlo.constant dense<3.40282347E+38> : tensor + %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + %4 = mhlo.constant dense<3.000000e+00> : tensor + %5 = "mhlo.broadcast_in_dim"(%4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + %6 = mhlo.constant dense<0.000000e+00> : tensor + %7 = "mhlo.broadcast_in_dim"(%6) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + %8 = mhlo.constant dense<-3.40282347E+38> : tensor + %9 = "mhlo.broadcast_in_dim"(%8) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + %10 = mhlo.add %arg0, %5 : tensor<2xf32> + %11 = mhlo.clamp %7, %10, %3 : tensor<2xf32> + %12 = mhlo.clamp %9, %11, %1 : tensor<2xf32> + %13 = mhlo.multiply %arg0, %12 : tensor<2xf32> + %14 = mhlo.divide %13, %1 : tensor<2xf32> + return %14 : tensor<2xf32> +} + +// CHECK-LABEL: func.func @hardswish( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xf32>) -> tensor<*xf32> { +// CHECK: %[[VAL_1:.*]] = "tfl.hard_swish"(%[[VAL_0]]) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: %[[VAL_2:.*]] = "tf.Identity"(%[[VAL_1]]) {device = ""} : (tensor<2xf32>) -> tensor<*xf32> +// CHECK: %[[VAL_3:.*]] = "tf.Identity"(%[[VAL_2]]) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[VAL_3]] : tensor<*xf32> +// CHECK: } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td index 74d8bb372c7d37..1b62b6fcc4aeae 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td @@ -20,3 +20,9 @@ include "mlir/Dialect/Func/IR/FuncOps.td" include "mhlo/IR/hlo_ops.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" + +def LegalizeHardSwishComposite: Pat< + (MHLO_CompositeOp:$old_value + (variadic $input), + ConstantStrAttr, $_, $_, $_), + (TFL_HardSwishOp $input)>; From 5c66d026b089c47b021a65663153a268f6598e03 Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Fri, 29 Mar 2024 10:11:59 -0700 Subject: [PATCH 088/124] Rollback https://github.com/openxla/xla/commit/0ab2be0b5a575da3206d2c2f92b85e6346708405. There is an internal issue with running tests on H100s requiring the change to be rolled back. Reverts 42883cb09d1a8155824ce4ed044794c0dffdd19f PiperOrigin-RevId: 620273492 --- third_party/xla/xla/service/gpu/BUILD | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index b8bd5348f53318..d723001b77dcf2 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -599,7 +599,6 @@ xla_test( srcs = if_cuda_is_configured(["ir_emitter_triton_test.cc"]), backends = [ "gpu_a100", - "gpu_h100", ], shard_count = 20, tags = ["nomac"], @@ -652,10 +651,7 @@ xla_test( backend_tags = {"gpu": [ "requires-gpu-sm80", ]}, - backends = [ - "gpu", - "gpu_h100", - ], + backends = ["gpu"], tags = [ "large", "no_oss", # requires-mem:16g tag doesn't work in open source From 1a9dbe83517f4697acfb640726c56ecc6df0a450 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 29 Mar 2024 10:12:51 -0700 Subject: [PATCH 089/124] Modify the matrix class to keep track of both memory and communication resharding costs for a given edge as part of one matrix object. PiperOrigin-RevId: 620273768 --- .../auto_sharding/auto_sharding.cc | 13 +- .../auto_sharding/auto_sharding_cost_graph.cc | 142 ++++++------------ .../auto_sharding/auto_sharding_cost_graph.h | 45 ++++-- .../hlo/experimental/auto_sharding/matrix.h | 15 +- 4 files changed, 92 insertions(+), 123 deletions(-) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index a9caf4daa6308e..49d4807266603b 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -1852,8 +1852,7 @@ AutoShardingSolverResult CallSolver( if (max_cost) { request.mutable_max_cost()->set_coeff(*max_cost); } - for (const auto& [edge, edge_cost] : cost_graph.edge_communication_costs_) { - const auto& edge_memory_cost = cost_graph.edge_memory_costs_.at(edge); + for (const auto& [edge, edge_cost] : cost_graph.edge_costs_) { AutoShardingSolverRequest_Pair raw_edge; raw_edge.set_first(edge.first); raw_edge.set_second(edge.second); @@ -1862,8 +1861,8 @@ AutoShardingSolverResult CallSolver( AutoShardingSolverRequest_Costs mij; for (NodeStrategyIdx i = 0; i < edge_cost.n_; i++) { for (NodeStrategyIdx j = 0; j < edge_cost.m_; j++) { - rij.add_costs(edge_cost(i, j)); - mij.add_costs(edge_memory_cost(i, j)); + rij.add_costs(edge_cost(i, j).communication_cost); + mij.add_costs(edge_cost(i, j).memory_cost); } } request.mutable_resharding_costs()->Add(std::move(rij)); @@ -1929,8 +1928,8 @@ AutoShardingSolverResult CallSolver( for (const auto& pair : alias_set) { const StrategyGroup* src_strategy_group = strategy_groups[pair.first]; const StrategyGroup* dst_strategy_group = strategy_groups[pair.second]; - Matrix raw_cost(src_strategy_group->strategies.size(), - dst_strategy_group->strategies.size()); + Matrix raw_cost(src_strategy_group->strategies.size(), + dst_strategy_group->strategies.size()); for (NodeStrategyIdx i = 0; i < src_strategy_group->strategies.size(); ++i) { for (NodeStrategyIdx j = 0; j < dst_strategy_group->strategies.size(); @@ -3782,7 +3781,7 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( std::vector> node_to_edges( strategy_groups.size()); spmd::EdgeIdx edge_idx = 0; - for (const auto& [edge, _] : cost_graph.edge_communication_costs_) { + for (const auto& [edge, _] : cost_graph.edge_costs_) { node_to_edges[edge.second].insert(edge_idx); ++edge_idx; } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc index 5db37cc868bc2f..1156e0b80c3027 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc @@ -51,22 +51,16 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups, if (!in_nodes[i]->is_tuple) { NodeIdx src_idx = in_nodes[i]->node_idx; NodeIdx dst_idx = strategy_group->node_idx; - Matrix edge_communication_cost = - CreateEdgeCommunicationCost(src_idx, dst_idx, i, strategy_group); - Matrix edge_memory_cost = - CreateEdgeMemoryCost(src_idx, dst_idx, i, strategy_group); - AddEdgeCost(src_idx, dst_idx, edge_communication_cost, - edge_memory_cost); + EdgeReshardingCostMatrix edge_cost = + CreateEdgeCost(src_idx, dst_idx, i, strategy_group); + AddEdgeCost(src_idx, dst_idx, edge_cost); } else if (in_nodes[i]->is_tuple && in_nodes.size() > 1) { for (size_t l = 0; l < in_nodes[i]->childs.size(); ++l) { NodeIdx src_idx = in_nodes[i]->childs[l]->node_idx; NodeIdx dst_idx = strategy_group->node_idx; - Matrix edge_communication_cost = CreateEdgeCommunicationCost( - src_idx, dst_idx, i, strategy_group, true); - Matrix edge_memory_cost = - CreateEdgeMemoryCost(src_idx, dst_idx, i, strategy_group, true); - AddEdgeCost(src_idx, dst_idx, edge_communication_cost, - edge_memory_cost); + EdgeReshardingCostMatrix edge_cost = + CreateEdgeCost(src_idx, dst_idx, i, strategy_group, true); + AddEdgeCost(src_idx, dst_idx, edge_cost); } } else { CHECK_EQ(in_nodes.size(), 1) @@ -80,12 +74,9 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups, // operands. If there is only one operand and it's a tuple, the // first index of communication_resharding_costs is for the tuple // element. - Matrix edge_communication_cost = CreateEdgeCommunicationCost( + EdgeReshardingCostMatrix edge_cost = CreateEdgeCost( src_idx, dst_idx, /*in_node_idx=*/l, strategy_group); - Matrix edge_memory_cost = CreateEdgeMemoryCost( - src_idx, dst_idx, /*in_node_idx=*/l, strategy_group); - AddEdgeCost(src_idx, dst_idx, edge_communication_cost, - edge_memory_cost); + AddEdgeCost(src_idx, dst_idx, edge_cost); } } } @@ -110,8 +101,8 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups, NodeIdx src_idx = pair.first->node_idx; NodeIdx dst_idx = pair.second->node_idx; - Matrix edge_communication_cost(node_lens_[src_idx], node_lens_[dst_idx]); - Matrix edge_memory_cost(node_lens_[src_idx], node_lens_[dst_idx]); + EdgeReshardingCostMatrix edge_cost(node_lens_[src_idx], + node_lens_[dst_idx]); absl::flat_hash_map src_strategy_name_to_idx_map; for (NodeStrategyIdx i = 0; i < node_lens_[src_idx]; ++i) { @@ -132,49 +123,21 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups, CHECK_LE(std::abs(src_strategy.communication_cost - dst_strategy.communication_cost), 1e-6); - edge_communication_cost(it->second, i) = + edge_cost(it->second, i).communication_cost = -src_strategy.communication_cost; } } } - AddEdgeCost(src_idx, dst_idx, edge_communication_cost, edge_memory_cost); - } -} - -Matrix CostGraph::CreateEdgeCommunicationCost(const NodeIdx src_idx, - const NodeIdx dst_idx, - const size_t in_node_idx, - StrategyGroup* strategy_group, - const bool zero_cost) { - CHECK_LT(src_idx, node_lens_.size()); - CHECK_LT(dst_idx, node_lens_.size()); - Matrix edge_communication_cost(node_lens_[src_idx], node_lens_[dst_idx]); - for (NodeStrategyIdx k = 0; k < strategy_group->strategies.size(); ++k) { - const ShardingStrategy& strategy = strategy_group->strategies[k]; - size_t start_idx = 0; - if (strategy.communication_resharding_costs[in_node_idx].size() > - node_lens_[src_idx]) { - start_idx = strategy.communication_resharding_costs[in_node_idx].size() - - node_lens_[src_idx]; - } - for (size_t j = start_idx; - j < strategy.communication_resharding_costs[in_node_idx].size(); ++j) { - edge_communication_cost(j - start_idx, k) = - zero_cost ? 0 - : strategy.communication_resharding_costs[in_node_idx][j]; - } + AddEdgeCost(src_idx, dst_idx, edge_cost); } - return edge_communication_cost; } -Matrix CostGraph::CreateEdgeMemoryCost(const NodeIdx src_idx, - const NodeIdx dst_idx, - const size_t in_node_idx, - StrategyGroup* strategy_group, - const bool zero_cost) { +EdgeReshardingCostMatrix CostGraph::CreateEdgeCost( + const NodeIdx src_idx, const NodeIdx dst_idx, const size_t in_node_idx, + StrategyGroup* strategy_group, const bool zero_cost) { CHECK_LT(src_idx, node_lens_.size()); CHECK_LT(dst_idx, node_lens_.size()); - Matrix edge_communication_cost(node_lens_[src_idx], node_lens_[dst_idx]); + EdgeReshardingCostMatrix edge_cost(node_lens_[src_idx], node_lens_[dst_idx]); for (NodeStrategyIdx k = 0; k < strategy_group->strategies.size(); ++k) { const ShardingStrategy& strategy = strategy_group->strategies[k]; size_t start_idx = 0; @@ -187,46 +150,43 @@ Matrix CostGraph::CreateEdgeMemoryCost(const NodeIdx src_idx, } for (size_t j = start_idx; j < strategy.memory_resharding_costs[in_node_idx].size(); ++j) { - edge_communication_cost(j - start_idx, k) = - zero_cost ? 0 : strategy.memory_resharding_costs[in_node_idx][j]; + double communication_cost = 0; + double memory_cost = 0; + if (!zero_cost) { + communication_cost = + strategy.communication_resharding_costs[in_node_idx][j]; + memory_cost = strategy.memory_resharding_costs[in_node_idx][j]; + } + edge_cost(j - start_idx, k) = + EdgeReshardingCost(communication_cost, memory_cost); } } - return edge_communication_cost; + return edge_cost; } -Matrix CostGraph::GetEdgeCommunicationCost(const NodeIdx i, const NodeIdx j) { +EdgeReshardingCostMatrix CostGraph::GetEdgeCost(const NodeIdx i, + const NodeIdx j) { if (i <= j) { - return edge_communication_costs_[{i, j}]; + return edge_costs_[{i, j}]; } - return edge_communication_costs_[{j, i}].Transpose(); + return edge_costs_[{j, i}].Transpose(); } -Matrix CostGraph::GetEdgeMemoryCost(const NodeIdx i, const NodeIdx j) { - if (i <= j) { - return edge_memory_costs_[{i, j}]; - } - return edge_memory_costs_[{j, i}].Transpose(); -} - -void CostGraph::AddEdgeCost(NodeIdx i, NodeIdx j, Matrix& cost, - Matrix& memory_cost) { +void CostGraph::AddEdgeCost(NodeIdx i, NodeIdx j, + EdgeReshardingCostMatrix& cost) { if (i > j) { std::swap(i, j); cost = cost.Transpose(); - memory_cost = memory_cost.Transpose(); } - if (edge_communication_costs_.contains({i, j})) { + if (edge_costs_.contains({i, j})) { CHECK(adjacency_[i].contains(j)); CHECK(adjacency_[j].contains(i)); - edge_communication_costs_[{i, j}] = - edge_communication_costs_[{i, j}] + cost; - edge_memory_costs_[{i, j}] = edge_memory_costs_[{i, j}] + memory_cost; + edge_costs_[{i, j}] = edge_costs_[{i, j}] + cost; } else { adjacency_[i].insert(j); adjacency_[j].insert(i); - edge_communication_costs_[{i, j}] = cost; - edge_memory_costs_[{i, j}] = memory_cost; + edge_costs_[{i, j}] = cost; } } @@ -237,13 +197,11 @@ void CostGraph::RemoveEdge(NodeIdx i, NodeIdx j) { CHECK(adjacency_[i].contains(j)); CHECK(adjacency_[j].contains(i)); - CHECK(edge_communication_costs_.contains({i, j})); - CHECK(edge_memory_costs_.contains({i, j})); + CHECK(edge_costs_.contains({i, j})); adjacency_[i].erase(j); adjacency_[j].erase(i); - edge_communication_costs_.erase({i, j}); - edge_memory_costs_.erase({i, j}); + edge_costs_.erase({i, j}); } void CostGraph::MergeNode(const NodeIdx src, const NodeIdx dst) { @@ -253,7 +211,7 @@ void CostGraph::MergeNode(const NodeIdx src, const NodeIdx dst) { CHECK(!merged_to_.contains(dst)); CHECK_NE(src, dst); - Matrix edge_communication_cost = GetEdgeCommunicationCost(dst, src); + EdgeReshardingCostMatrix edge_cost = GetEdgeCost(dst, src); std::vector reindexing(node_lens_[dst]); if (node_lens_[dst] == node_lens_[src]) { @@ -277,7 +235,7 @@ void CostGraph::MergeNode(const NodeIdx src, const NodeIdx dst) { // as the last strategy in BuildStrategyAndCost. keys.reserve(node_lens_[src]); for (NodeStrategyIdx j = 0; j < node_lens_[src]; ++j) { - keys.push_back({edge_communication_cost(i, j), -j}); + keys.push_back({edge_cost(i, j).communication_cost, -j}); } std::sort(arange.begin(), arange.end(), [&keys](int l, int r) { @@ -296,25 +254,19 @@ void CostGraph::MergeNode(const NodeIdx src, const NodeIdx dst) { for (const NodeIdx adj : adj_list) { if (adj == dst) { for (NodeStrategyIdx i = 0; i < node_lens_[dst]; ++i) { - extra_node_costs_[dst][i] += edge_communication_cost(i, reindexing[i]); + extra_node_costs_[dst][i] += + edge_cost(i, reindexing[i]).communication_cost; } } else { - Matrix added_edge_communication_cost(node_lens_[dst], node_lens_[adj]); - Matrix added_edge_memory_cost(node_lens_[dst], node_lens_[adj]); - Matrix edge_communication_cost_src_adj = - GetEdgeCommunicationCost(src, adj); - Matrix edge_memory_cost_src_adj = GetEdgeMemoryCost(src, adj); - + EdgeReshardingCostMatrix added_edge_cost(node_lens_[dst], + node_lens_[adj]); + EdgeReshardingCostMatrix edge_cost_src_adj = GetEdgeCost(src, adj); for (NodeStrategyIdx i = 0; i < node_lens_[dst]; ++i) { for (NodeStrategyIdx k = 0; k < node_lens_[adj]; ++k) { - added_edge_communication_cost(i, k) = - edge_communication_cost_src_adj(reindexing[i], k); - added_edge_memory_cost(i, k) = - edge_memory_cost_src_adj(reindexing[i], k); + added_edge_cost(i, k) = edge_cost_src_adj(reindexing[i], k); } } - AddEdgeCost(dst, adj, added_edge_communication_cost, - added_edge_memory_cost); + AddEdgeCost(dst, adj, added_edge_cost); } } // Remove edges @@ -380,7 +332,7 @@ std::string CostGraph::ToString() const { } absl::StrAppend(&str, "\n"); - for (const auto& iter : edge_communication_costs_) { + for (const auto& iter : edge_costs_) { absl::StrAppend(&str, "Edge (", iter.first.first, ", ", iter.first.second, "):\n"); absl::StrAppend(&str, iter.second.ToString(), "\n"); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h index 6ef20f764cbca5..08b0bd968b6d4c 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/log/check.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/matrix.h" @@ -32,6 +33,28 @@ limitations under the License. namespace xla { namespace spmd { +struct EdgeReshardingCost { + double communication_cost = 0; + double memory_cost = 0; + + EdgeReshardingCost() : communication_cost(0), memory_cost(0) {} + + EdgeReshardingCost(double communication_cost_, double memory_cost_) + : communication_cost(communication_cost_), memory_cost(memory_cost_) {} + + EdgeReshardingCost operator+(const EdgeReshardingCost& other) const { + return EdgeReshardingCost(other.communication_cost + communication_cost, + other.memory_cost + memory_cost); + } + + std::string ToString() const { + return absl::StrCat("{communication_cost=", communication_cost, + ", memory_cost=", memory_cost, "}"); + } +}; + +using EdgeReshardingCostMatrix = Matrix; + // A graph data structure to simplify the edge cost graph. It merges nodes and // performs path compression. class CostGraph { @@ -39,20 +62,14 @@ class CostGraph { CostGraph(const StrategyGroups& strategy_groups, const AssociativeDotPairs& associative_dot_pairs); - Matrix CreateEdgeCommunicationCost(NodeIdx src_idx, NodeIdx dst_idx, - size_t in_node_idx, - StrategyGroup* strategy_group, - bool zero_cost = false); - - Matrix CreateEdgeMemoryCost(NodeIdx src_idx, NodeIdx dst_idx, - size_t in_node_idx, StrategyGroup* strategy_group, - bool zero_cost = false); - - Matrix GetEdgeCommunicationCost(NodeIdx i, NodeIdx j); + EdgeReshardingCostMatrix CreateEdgeCost(NodeIdx src_idx, NodeIdx dst_idx, + size_t in_node_idx, + StrategyGroup* strategy_group, + bool zero_cost = false); - Matrix GetEdgeMemoryCost(NodeIdx i, NodeIdx j); + EdgeReshardingCostMatrix GetEdgeCost(NodeIdx i, NodeIdx j); - void AddEdgeCost(NodeIdx i, NodeIdx j, Matrix& cost, Matrix& memory_cost); + void AddEdgeCost(NodeIdx i, NodeIdx j, EdgeReshardingCostMatrix& cost); void RemoveEdge(NodeIdx i, NodeIdx j); @@ -90,8 +107,8 @@ class CostGraph { std::vector> adjacency_; // The cost matrix between two nodes. - StableHashMap, Matrix> edge_communication_costs_; - StableHashMap, Matrix> edge_memory_costs_; + StableHashMap, EdgeReshardingCostMatrix> + edge_costs_; // The extra node costs introduced by merging nodes. std::vector> extra_node_costs_; // The reindexing vector of the node. diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/matrix.h b/third_party/xla/xla/hlo/experimental/auto_sharding/matrix.h index 40eb8d35887685..903973eea5a3a6 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/matrix.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/matrix.h @@ -36,6 +36,7 @@ namespace spmd { // It can create a view for matrix transpose without copying the memory. // TODO (zhuohan): Inherit from Array2D and add Transpose and operator+ (See // tensorflow/compiler/xla/array2d.h;l=39) +template class Matrix { public: Matrix() : n_(0), m_(0), transpose_(false), data_(nullptr) {} @@ -44,11 +45,11 @@ class Matrix { this->n_ = n; this->m_ = m; transpose_ = false; - data_ = std::make_shared>(n * m, 0.0); + data_ = std::make_shared>(n * m, T()); } Matrix(size_t n, size_t m, bool transpose, - std::shared_ptr> data) { + std::shared_ptr> data) { this->n_ = n; this->m_ = m; this->transpose_ = transpose; @@ -57,7 +58,7 @@ class Matrix { Matrix Transpose() { return Matrix(m_, n_, !transpose_, data_); } - double operator()(size_t i, size_t j) const { + T operator()(size_t i, size_t j) const { size_t idx; if (transpose_) { idx = j * n_ + i; @@ -69,7 +70,7 @@ class Matrix { return (*data_)[idx]; } - double& operator()(size_t i, size_t j) { + T& operator()(size_t i, size_t j) { size_t idx; if (transpose_) { idx = j * n_ + i; @@ -81,7 +82,7 @@ class Matrix { return (*data_)[idx]; } - Matrix operator+(const Matrix& other) { + Matrix operator+(const Matrix& other) { CHECK_EQ(n_, other.n_); CHECK_EQ(m_, other.m_); Matrix ret = Matrix(n_, m_); @@ -98,7 +99,7 @@ class Matrix { for (size_t i = 0; i < n_; ++i) { for (size_t j = 0; j < m_; ++j) { - absl::StrAppend(&str, operator()(i, j), " "); + absl::StrAppend(&str, operator()(i, j).ToString(), " "); } absl::StrAppend(&str, "\n"); } @@ -109,7 +110,7 @@ class Matrix { size_t n_; size_t m_; bool transpose_; - std::shared_ptr> data_; + std::shared_ptr> data_; }; } // namespace spmd } // namespace xla From c9cffcd9ee77dba29078a04185e2e1f3865a6a78 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Fri, 29 Mar 2024 10:38:02 -0700 Subject: [PATCH 090/124] [xla:gpu][NFC] Simplify `collect_slice_info` PiperOrigin-RevId: 620281417 --- .../xla/xla/service/gpu/fusions/custom.cc | 52 ++++++++----------- .../gpu/runtime/address_computation_thunk.cc | 4 +- .../gpu/runtime/address_computation_thunk.h | 8 +-- 3 files changed, 29 insertions(+), 35 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index 8c2229506e35d5..c8c0900bf0523b 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -199,10 +199,10 @@ absl::StatusOr EmitDynamicSlicedGemm( ir_emitter_context.buffer_assignment(); std::vector>> - offset_buffer_indices; - std::vector> orig_shapes; - std::vector> sliced_shapes; - std::vector> offset_byte_sizes; + offset_buffer_indices(4, std::nullopt); + std::vector> orig_shapes(4, std::nullopt); + std::vector> sliced_shapes(4, std::nullopt); + std::vector> offset_byte_sizes(4, std::nullopt); HloDynamicIndexInstruction* slice_instr = nullptr; auto get_original_operand_slice = @@ -231,12 +231,8 @@ absl::StatusOr EmitDynamicSlicedGemm( fusion.operand(param->parameter_number()), index); }; - auto collect_slice_info = [&]() { + auto collect_slice_info = [&](unsigned idx) { if (slice_instr == nullptr) { - offset_buffer_indices.push_back(std::nullopt); - orig_shapes.push_back(std::nullopt); - sliced_shapes.push_back(std::nullopt); - offset_byte_sizes.push_back(std::nullopt); return; } @@ -249,27 +245,29 @@ absl::StatusOr EmitDynamicSlicedGemm( /*index=*/{}) .value()); } - offset_buffer_indices.push_back(offset_slices); - orig_shapes.push_back(slice_instr->operand(0)->shape()); - sliced_shapes.push_back(DynCast(slice_instr) - ? slice_instr->shape() - : slice_instr->operand(1)->shape()); - offset_byte_sizes.push_back(ShapeUtil::ByteSizeOfPrimitiveType( - slice_instr->index_operands().front()->shape().element_type())); + offset_buffer_indices[idx] = std::move(offset_slices); + orig_shapes[idx] = slice_instr->operand(0)->shape(); + sliced_shapes[idx] = DynCast(slice_instr) + ? slice_instr->shape() + : slice_instr->operand(1)->shape(); + offset_byte_sizes[idx] = ShapeUtil::ByteSizeOfPrimitiveType( + slice_instr->index_operands().front()->shape().element_type()); + + // Reset `slice_instr` for the next call to `collect_slice_info()`. + slice_instr = nullptr; }; + unsigned argument_idx = 0; TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_slice, get_original_operand_slice( - custom_call.operand(kLHSOperandIndex), /*index=*/{})); - collect_slice_info(); + custom_call.operand(argument_idx), /*index=*/{})); + collect_slice_info(argument_idx++); - slice_instr = nullptr; TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_slice, get_original_operand_slice( - custom_call.operand(kRHSOperandIndex), /*index=*/{})); - collect_slice_info(); + custom_call.operand(argument_idx), /*index=*/{})); + collect_slice_info(argument_idx++); - slice_instr = nullptr; BufferAllocation::Slice output; std::optional workspace = std::nullopt; std::optional slice_workspace_fake = std::nullopt; @@ -313,22 +311,18 @@ absl::StatusOr EmitDynamicSlicedGemm( if (fusion.shape().IsArray()) { TF_ASSIGN_OR_RETURN(output, get_original_result_slice(&custom_call, /*index=*/{})); - collect_slice_info(); - // Collect slice info for std::nullopt workspace. - slice_instr = nullptr; - collect_slice_info(); + collect_slice_info(argument_idx); } else { TF_ASSIGN_OR_RETURN( output, get_original_result_slice(&custom_call, /*index=*/{kGEMMOutputBufferIndex})); - collect_slice_info(); + collect_slice_info(argument_idx++); // TODO(vuson): If we want to support slices of workspace, we'd need to // start `HloFindIf` with `get-tuple-element` with the right index. TF_ASSIGN_OR_RETURN( workspace, GetAllocationSlice(buffer_assignment, &fusion, /*index=*/{kGEMMWorkspaceBufferIndex})); - slice_instr = nullptr; - collect_slice_info(); + collect_slice_info(argument_idx); fake_allocations[3] = std::make_unique( /*index=*/3, workspace->size(), /*color=*/0); slice_workspace_fake = BufferAllocation::Slice(fake_allocations[3].get(), 0, diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc index 15efc3e89b4b5b..b24a4f2b7cc3b7 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.cc @@ -49,8 +49,8 @@ AddressComputationThunk::AddressComputationThunk( std::vector> fake_allocations, std::vector>> offset_buffer_indices, - std::vector> orig_shapes, - std::vector> sliced_shapes, + std::vector> orig_shapes, + std::vector> sliced_shapes, std::vector> offset_byte_sizes) : Thunk(Kind::kAddressComputation, thunk_info), embedded_thunk_(std::make_unique( diff --git a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h index e1c0b30d9953aa..8d36751b9d830d 100644 --- a/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/address_computation_thunk.h @@ -48,8 +48,8 @@ class AddressComputationThunk : public Thunk { std::vector> fake_allocations_, std::vector>> offset_buffer_indices, - std::vector> orig_shapes, - std::vector> sliced_shapes, + std::vector> orig_shapes, + std::vector> sliced_shapes, std::vector> offset_byte_sizes); AddressComputationThunk(const AddressComputationThunk&) = delete; @@ -67,8 +67,8 @@ class AddressComputationThunk : public Thunk { std::vector> fake_allocations_; std::vector>> offset_buffer_indices_; - std::vector> orig_shapes_; - std::vector> sliced_shapes_; + std::vector> orig_shapes_; + std::vector> sliced_shapes_; std::vector> offset_byte_sizes_; // Pinned host memory for transferring offset values from device to host. From 8012c68f2fbbb7fc5d6986e55c49976504fb7eb3 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Fri, 29 Mar 2024 10:49:09 -0700 Subject: [PATCH 091/124] Change include order in `ml_dtypes.cc` to prevent errors. Trying to prevent `error: "Using deprecated NumPy API, disable it with #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION"` PiperOrigin-RevId: 620284610 --- .../xla/third_party/tsl/tsl/python/lib/core/ml_dtypes.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/third_party/xla/third_party/tsl/tsl/python/lib/core/ml_dtypes.cc b/third_party/xla/third_party/tsl/tsl/python/lib/core/ml_dtypes.cc index 662a29f84f9387..2815b25b24469f 100644 --- a/third_party/xla/third_party/tsl/tsl/python/lib/core/ml_dtypes.cc +++ b/third_party/xla/third_party/tsl/tsl/python/lib/core/ml_dtypes.cc @@ -17,13 +17,16 @@ limitations under the License. #include #include +// Must be included first to ensure `NPY_NO_DEPRECATED_API` is defined. +// clang-format off +#include "tsl/python/lib/core/numpy.h" // IWYU pragma: keep +// clang-format on #include "numpy/ndarraytypes.h" #include "absl/base/attributes.h" #include "absl/base/call_once.h" #include "pybind11/gil.h" // from @pybind11 #include "pybind11/numpy.h" // from @pybind11 #include "pybind11/pybind11.h" // from @pybind11 -#include "tsl/python/lib/core/numpy.h" // IWYU pragma: keep namespace tsl { namespace ml_dtypes { From 1faaaceb5f5c27cdfbb5b9f254ad4c3d95bb6b39 Mon Sep 17 00:00:00 2001 From: Dmitri Gribenko Date: Fri, 29 Mar 2024 10:53:24 -0700 Subject: [PATCH 092/124] Integrate LLVM at llvm/llvm-project@80aa52d8c5a8 Updates LLVM usage to match [80aa52d8c5a8](https://github.com/llvm/llvm-project/commit/80aa52d8c5a8) PiperOrigin-RevId: 620285862 --- third_party/llvm/generated.patch | 206 ------------------------------- third_party/llvm/workspace.bzl | 4 +- 2 files changed, 2 insertions(+), 208 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 229971a2e9ad47..509398da979e83 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,207 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/clang/lib/APINotes/APINotesWriter.cpp b/clang/lib/APINotes/APINotesWriter.cpp ---- a/clang/lib/APINotes/APINotesWriter.cpp -+++ b/clang/lib/APINotes/APINotesWriter.cpp -@@ -441,7 +441,7 @@ - std::sort(VI.begin(), VI.end(), - [](const std::pair &LHS, - const std::pair &RHS) -> bool { -- assert(LHS.first != RHS.first && -+ assert((&LHS == &RHS || LHS.first != RHS.first) && - "two entries for the same version"); - return LHS.first < RHS.first; - }); -diff -ruN --strip-trailing-cr a/clang/test/APINotes/module-cache.m b/clang/test/APINotes/module-cache.m ---- a/clang/test/APINotes/module-cache.m -+++ b/clang/test/APINotes/module-cache.m -@@ -27,6 +27,7 @@ - // RUN: FileCheck -check-prefix=CHECK-ONE-ERROR %s < %t/before.log - - // Change the API notes file, after the module has rebuilt once. -+// RUN: chmod u+w %t/APINotes/SomeOtherKit.apinotes - // RUN: echo ' - Selector: "methodA"' >> %t/APINotes/SomeOtherKit.apinotes - // RUN: echo ' MethodKind: Instance' >> %t/APINotes/SomeOtherKit.apinotes - // RUN: echo ' Availability: none' >> %t/APINotes/SomeOtherKit.apinotes -diff -ruN --strip-trailing-cr a/lld/test/ELF/lto/libcall-archive.ll b/lld/test/ELF/lto/libcall-archive.ll ---- a/lld/test/ELF/lto/libcall-archive.ll -+++ b/lld/test/ELF/lto/libcall-archive.ll -@@ -4,8 +4,8 @@ - ; RUN: llvm-as -o %t2.o %S/Inputs/libcall-archive.ll - ; RUN: llvm-mc -filetype=obj -triple=x86_64-unknown-linux -o %t3.o %S/Inputs/libcall-archive.s - ; RUN: llvm-ar rcs %t.a %t2.o %t3.o --; RUN: ld.lld --why-extract=why.txt -o %t %t.o %t.a --; RUN: FileCheck %s --input-file=why.txt --check-prefix=CHECK-WHY -+; RUN: ld.lld --why-extract=%t.why.txt -o %t %t.o %t.a -+; RUN: FileCheck %s --input-file=%t.why.txt --check-prefix=CHECK-WHY - ; RUN: llvm-nm %t | FileCheck %s - ; RUN: ld.lld -o %t2 %t.o --start-lib %t2.o %t3.o --end-lib - ; RUN: llvm-nm %t2 | FileCheck %s -diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/Verifier.h b/llvm/include/llvm/IR/Verifier.h ---- a/llvm/include/llvm/IR/Verifier.h -+++ b/llvm/include/llvm/IR/Verifier.h -@@ -77,7 +77,6 @@ - /// Visit an instruction and return true if it is valid, return false if an - /// invalid TBAA is attached. - bool visitTBAAMetadata(Instruction &I, const MDNode *MD); -- bool visitTBAAStructMetadata(Instruction &I, const MDNode *MD); - }; - - /// Check a function for errors, useful for use when debugging a -diff -ruN --strip-trailing-cr a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp ---- a/llvm/lib/IR/Verifier.cpp -+++ b/llvm/lib/IR/Verifier.cpp -@@ -5096,9 +5096,6 @@ - if (MDNode *TBAA = I.getMetadata(LLVMContext::MD_tbaa)) - TBAAVerifyHelper.visitTBAAMetadata(I, TBAA); - -- if (MDNode *TBAA = I.getMetadata(LLVMContext::MD_tbaa_struct)) -- TBAAVerifyHelper.visitTBAAStructMetadata(I, TBAA); -- - if (MDNode *MD = I.getMetadata(LLVMContext::MD_noalias)) - visitAliasScopeListMetadata(MD); - if (MDNode *MD = I.getMetadata(LLVMContext::MD_alias_scope)) -@@ -7422,35 +7419,6 @@ - return true; - } - --bool TBAAVerifier::visitTBAAStructMetadata(Instruction &I, const MDNode *MD) { -- CheckTBAA(MD->getNumOperands() % 3 == 0, -- "tbaa.struct operands must occur in groups of three", &I, MD); -- -- // Each group of three operands must consist of two integers and a -- // tbaa node. Moreover, the regions described by the offset and size -- // operands must be non-overlapping. -- std::optional NextFree; -- for (unsigned int Idx = 0; Idx < MD->getNumOperands(); Idx += 3) { -- auto *OffsetCI = -- mdconst::dyn_extract_or_null(MD->getOperand(Idx)); -- CheckTBAA(OffsetCI, "Offset must be a constant integer", &I, MD); -- -- auto *SizeCI = -- mdconst::dyn_extract_or_null(MD->getOperand(Idx + 1)); -- CheckTBAA(SizeCI, "Size must be a constant integer", &I, MD); -- -- MDNode *TBAA = dyn_cast_or_null(MD->getOperand(Idx + 2)); -- CheckTBAA(TBAA, "TBAA tag missing", &I, MD); -- visitTBAAMetadata(I, TBAA); -- -- bool NonOverlapping = !NextFree || NextFree->ule(OffsetCI->getValue()); -- CheckTBAA(NonOverlapping, "Overlapping tbaa.struct regions", &I, MD); -- -- NextFree = OffsetCI->getValue() + SizeCI->getValue(); -- } -- return true; --} -- - char VerifierLegacyPass::ID = 0; - INITIALIZE_PASS(VerifierLegacyPass, "verify", "Module Verifier", false, false) - -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AArch64/arm64-abi_align.ll b/llvm/test/CodeGen/AArch64/arm64-abi_align.ll ---- a/llvm/test/CodeGen/AArch64/arm64-abi_align.ll -+++ b/llvm/test/CodeGen/AArch64/arm64-abi_align.ll -@@ -518,6 +518,4 @@ - !1 = !{!"omnipotent char", !2} - !2 = !{!"Simple C/C++ TBAA"} - !3 = !{!"short", !1} --!4 = !{i64 0, i64 4, !5, i64 4, i64 2, !6, i64 8, i64 4, !5, i64 12, i64 2, !6, i64 16, i64 4, !5, i64 20, i64 2, !6} --!5 = !{!0, !0, i64 0} --!6 = !{!3, !3, i64 0} -+!4 = !{i64 0, i64 4, !0, i64 4, i64 2, !3, i64 8, i64 4, !0, i64 12, i64 2, !3, i64 16, i64 4, !0, i64 20, i64 2, !3} -diff -ruN --strip-trailing-cr a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/mem-intrinsics.ll b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/mem-intrinsics.ll ---- a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/mem-intrinsics.ll -+++ b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/mem-intrinsics.ll -@@ -141,4 +141,4 @@ - !5 = distinct !{!5, !"some domain"} - !6 = !{!7} - !7 = distinct !{!7, !5, !"some scope 2"} --!8 = !{i64 0, i64 8, !0} -+!8 = !{i64 0, i64 8, null} -diff -ruN --strip-trailing-cr a/llvm/test/Transforms/InstCombine/struct-assign-tbaa.ll b/llvm/test/Transforms/InstCombine/struct-assign-tbaa.ll ---- a/llvm/test/Transforms/InstCombine/struct-assign-tbaa.ll -+++ b/llvm/test/Transforms/InstCombine/struct-assign-tbaa.ll -@@ -75,7 +75,7 @@ - !1 = !{!"omnipotent char", !0} - !2 = !{!5, !5, i64 0} - !3 = !{i64 0, i64 4, !2} --!4 = !{i64 0, i64 8, !2} -+!4 = !{i64 0, i64 8, null} - !5 = !{!"float", !0} - !6 = !{i64 0, i64 4, !2, i64 4, i64 4, !2} - !7 = !{i64 0, i64 2, !2, i64 4, i64 6, !2} -diff -ruN --strip-trailing-cr a/llvm/test/Transforms/Scalarizer/basic-inseltpoison.ll b/llvm/test/Transforms/Scalarizer/basic-inseltpoison.ll ---- a/llvm/test/Transforms/Scalarizer/basic-inseltpoison.ll -+++ b/llvm/test/Transforms/Scalarizer/basic-inseltpoison.ll -@@ -836,6 +836,5 @@ - !2 = !{ !"set2", !0 } - !3 = !{ !3, !{!"llvm.loop.parallel_accesses", !13} } - !4 = !{ float 4.0 } --!5 = !{ i64 0, i64 8, !6 } --!6 = !{ !1, !1, i64 0 } -+!5 = !{ i64 0, i64 8, null } - !13 = distinct !{} -diff -ruN --strip-trailing-cr a/llvm/test/Transforms/Scalarizer/basic.ll b/llvm/test/Transforms/Scalarizer/basic.ll ---- a/llvm/test/Transforms/Scalarizer/basic.ll -+++ b/llvm/test/Transforms/Scalarizer/basic.ll -@@ -870,6 +870,5 @@ - !2 = !{ !"set2", !0 } - !3 = !{ !3, !{!"llvm.loop.parallel_accesses", !13} } - !4 = !{ float 4.0 } --!5 = !{ i64 0, i64 8, !6 } --!6 = !{ !1, !1, i64 0 } -+!5 = !{ i64 0, i64 8, null } - !13 = distinct !{} -diff -ruN --strip-trailing-cr a/llvm/test/Transforms/SROA/tbaa-struct3.ll b/llvm/test/Transforms/SROA/tbaa-struct3.ll ---- a/llvm/test/Transforms/SROA/tbaa-struct3.ll -+++ b/llvm/test/Transforms/SROA/tbaa-struct3.ll -@@ -539,7 +539,7 @@ - !6 = !{!5, !5, i64 0} - !7 = !{i64 0, i64 8, !6, i64 8, i64 4, !1} - !8 = !{i64 0, i64 4, !1, i64 4, i64 8, !6} --!9 = !{i64 0, i64 8, !6, i64 8, i64 8, !1} -+!9 = !{i64 0, i64 8, !6, i64 4, i64 8, !1} - !10 = !{i64 0, i64 2, !1, i64 2, i64 2, !1} - !11 = !{i64 0, i64 1, !1, i64 1, i64 3, !1} - !12 = !{i64 0, i64 2, !1, i64 2, i64 6, !1} -diff -ruN --strip-trailing-cr a/llvm/test/Verifier/tbaa-struct.ll b/llvm/test/Verifier/tbaa-struct.ll ---- a/llvm/test/Verifier/tbaa-struct.ll -+++ b/llvm/test/Verifier/tbaa-struct.ll -@@ -1,36 +1,28 @@ --; RUN: not llvm-as < %s 2>&1 | FileCheck %s -+; RUN: llvm-as < %s 2>&1 -+ -+; FIXME: The verifer should reject the invalid !tbaa.struct nodes below. - - define void @test_overlapping_regions(ptr %a1) { --; CHECK: Overlapping tbaa.struct regions --; CHECK-NEXT: %ld = load i8, ptr %a1, align 1, !tbaa.struct !0 - %ld = load i8, ptr %a1, align 1, !tbaa.struct !0 - ret void - } - - define void @test_size_not_integer(ptr %a1) { --; CHECK: Size must be a constant integer --; CHECK-NEXT: store i8 1, ptr %a1, align 1, !tbaa.struct !5 - store i8 1, ptr %a1, align 1, !tbaa.struct !5 - ret void - } - - define void @test_offset_not_integer(ptr %a1, ptr %a2) { --; CHECK: Offset must be a constant integer --; CHECK-NEXT: tail call void @llvm.memcpy.p0.p0.i64(ptr align 8 %a1, ptr align 8 %a2, i64 16, i1 false), !tbaa.struct !6 - tail call void @llvm.memcpy.p0.p0.i64(ptr align 8 %a1, ptr align 8 %a2, i64 16, i1 false), !tbaa.struct !6 - ret void - } - - define void @test_tbaa_missing(ptr %a1, ptr %a2) { --; CHECK: TBAA tag missing --; CHECK-NEXT: tail call void @llvm.memcpy.p0.p0.i64(ptr align 8 %a1, ptr align 8 %a2, i64 16, i1 false), !tbaa.struct !7 - tail call void @llvm.memcpy.p0.p0.i64(ptr align 8 %a1, ptr align 8 %a2, i64 16, i1 false), !tbaa.struct !7 - ret void - } - - define void @test_tbaa_invalid(ptr %a1) { --; CHECK: Old-style TBAA is no longer allowed, use struct-path TBAA instead --; CHECK-NEXT: store i8 1, ptr %a1, align 1, !tbaa.struct !8 - store i8 1, ptr %a1, align 1, !tbaa.struct !8 - ret void - } diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 1a7d56b5764590..6ed4d29d211c15 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "aa2c14de1adcd265bf0c0fb44f97b5d6c1c38710" - LLVM_SHA256 = "50d2c7cd5355ec04a75991f2a4e2c89a3876b46fc1b71cd9fa3245f212d55da0" + LLVM_COMMIT = "80aa52d8c5a8a1c26b4114c60c2159c743d236d8" + LLVM_SHA256 = "b9079d7e8d72d7bb2453d908be1bd2bc4e5d62fd358ea9b7108f1c7d3b3c8585" tf_http_archive( name = name, From 8dc94a2ad54f5ea13d842d1eacbab9d3f04433ce Mon Sep 17 00:00:00 2001 From: Chris Minge Date: Fri, 29 Mar 2024 11:15:22 -0700 Subject: [PATCH 093/124] Add support for LocalDeviceManager in tfrt_session. PiperOrigin-RevId: 620292882 --- tensorflow/core/tfrt/tfrt_session/tfrt_session.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc b/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc index 2836105ca727e1..a664a14d58ffe1 100644 --- a/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc +++ b/tensorflow/core/tfrt/tfrt_session/tfrt_session.cc @@ -464,6 +464,10 @@ class TfrtSession : public tensorflow::Session { Status ListDevices(std::vector* response) override { return errors::Unimplemented("TfrtSession::ListDevices is Unimplemented."); } + Status LocalDeviceManager(const DeviceMgr** output) override { + *output = &graph_executor_->fallback_state().device_manager(); + return absl::OkStatus(); + } private: tfrt::HostContext* GetHostContext() { From 66ee739cec70a0488c0f4c7ef9ce99e5f14fe14f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 29 Mar 2024 11:41:47 -0700 Subject: [PATCH 094/124] Add new methods in backend async kernel interface. Reflect to the design of go/async-io-coherency New methods will set / retrieve the attributes based on the buffers. PiperOrigin-RevId: 620300315 --- .../async/backend_async_kernel_interface.cc | 20 ++++++++++ .../async/backend_async_kernel_interface.h | 14 +++++++ .../backend_async_kernel_interface_test.cc | 4 ++ .../lite/async/testing/mock_async_kernel.h | 5 +++ .../lite/core/async/async_kernel_internal.h | 11 ++++++ .../lite/core/async/async_signature_runner.cc | 10 +++++ .../lite/core/async/async_signature_runner.h | 11 ++++++ tensorflow/lite/core/async/async_subgraph.cc | 10 +++++ tensorflow/lite/core/async/async_subgraph.h | 18 +++++++++ tensorflow/lite/core/async/c/async_kernel.cc | 18 +++++++++ tensorflow/lite/core/async/c/async_kernel.h | 38 +++++++++++++++++++ tensorflow/lite/delegates/gpu/delegate.cc | 16 ++++++++ 12 files changed, 175 insertions(+) diff --git a/tensorflow/lite/async/backend_async_kernel_interface.cc b/tensorflow/lite/async/backend_async_kernel_interface.cc index ef7bc0563018a2..a3e0234bbc1fd9 100644 --- a/tensorflow/lite/async/backend_async_kernel_interface.cc +++ b/tensorflow/lite/async/backend_async_kernel_interface.cc @@ -152,6 +152,22 @@ TfLiteStatus Finish(TfLiteAsyncKernel* async_kernel, ->Finish(context, task); } +TfLiteStatus SetBufferAttributes(TfLiteAsyncKernel* async_kernel, + const TfLiteBackendBuffer* buffer, + const TfLiteAttributeMap* attrs) { + return reinterpret_cast( + TfLiteAsyncKernelGetKernelData(async_kernel)) + ->SetBufferAttributes(buffer, attrs); +} + +TfLiteStatus GetBufferAttributes(TfLiteAsyncKernel* async_kernel, + const TfLiteBackendBuffer* buffer, + TfLiteAttributeMap* attrs) { + return reinterpret_cast( + TfLiteAsyncKernelGetKernelData(async_kernel)) + ->GetBufferAttributes(buffer, attrs); +} + } // namespace internal BackendAsyncKernelInterface::BackendAsyncKernelInterface() { @@ -167,6 +183,10 @@ BackendAsyncKernelInterface::BackendAsyncKernelInterface() { TfLiteAsyncKernelSetReconcileRestrictions(kernel_, internal::ReconcileRestrictions); TfLiteAsyncKernelSetSetAttributes(kernel_, internal::SetAttributes); + TfLiteAsyncKernelSetSetBufferAttributes(kernel_, + internal::SetBufferAttributes); + TfLiteAsyncKernelSetGetBufferAttributes(kernel_, + internal::GetBufferAttributes); TfLiteAsyncKernelSetPrepare(kernel_, internal::Prepare); TfLiteAsyncKernelSetEval(kernel_, internal::Eval); TfLiteAsyncKernelSetWait(kernel_, internal::Wait); diff --git a/tensorflow/lite/async/backend_async_kernel_interface.h b/tensorflow/lite/async/backend_async_kernel_interface.h index 2849c229c61cc9..c8d94341c417ad 100644 --- a/tensorflow/lite/async/backend_async_kernel_interface.h +++ b/tensorflow/lite/async/backend_async_kernel_interface.h @@ -130,6 +130,20 @@ class BackendAsyncKernelInterface { TfLiteOpaqueNode* node, int tensor_index, const TfLiteAttributeMap* attrs) = 0; + // Set buffer's attributes. Backend will check if the buffer has been + // registered. And return TfLiteOk if the `attrs` for the `buffer` could be + // set in the corresponding async kernel. + virtual TfLiteStatus SetBufferAttributes(const TfLiteBackendBuffer* buffer, + const TfLiteAttributeMap* attrs) = 0; + + // Get buffer's attributes. Backend will check if the buffer has been + // registered. And return TfLiteOk if provided `attrs` for the `buffer` could + // be found in the registration pool in corresponding async kernel. If `attrs` + // is a non-empty map, it will be overwritten by the attributes of the + // `buffer`. + virtual TfLiteStatus GetBufferAttributes(const TfLiteBackendBuffer* buffer, + TfLiteAttributeMap* attrs) = 0; + // Prepares the kernel using the information from Set[In|Out]putAttributes // call above. virtual TfLiteStatus Prepare(TfLiteOpaqueContext* context, diff --git a/tensorflow/lite/async/backend_async_kernel_interface_test.cc b/tensorflow/lite/async/backend_async_kernel_interface_test.cc index 2620cc707a1ad0..e4f68bb32167b6 100644 --- a/tensorflow/lite/async/backend_async_kernel_interface_test.cc +++ b/tensorflow/lite/async/backend_async_kernel_interface_test.cc @@ -35,6 +35,8 @@ TEST(BackendAsyncKernelInterfaceTest, BasicTest) { EXPECT_CALL(kernel, UnregisterBuffer(_, _)); EXPECT_CALL(kernel, ReconcileRestrictions(_, _, _, _, _, _)); EXPECT_CALL(kernel, SetAttributes(_, _, _, _)); + EXPECT_CALL(kernel, SetBufferAttributes(_, _)); + EXPECT_CALL(kernel, GetBufferAttributes(_, _)); EXPECT_CALL(kernel, Prepare(_, _)); EXPECT_CALL(kernel, Eval(_, _, _)); EXPECT_CALL(kernel, Wait(_, _)); @@ -49,6 +51,8 @@ TEST(BackendAsyncKernelInterfaceTest, BasicTest) { tflite_kernel->reconcile_restrictions(tflite_kernel, nullptr, nullptr, 0, nullptr, nullptr, nullptr); tflite_kernel->set_attributes(tflite_kernel, nullptr, nullptr, 0, nullptr); + tflite_kernel->set_buffer_attributes(tflite_kernel, nullptr, nullptr); + tflite_kernel->get_buffer_attributes(tflite_kernel, nullptr, nullptr); tflite_kernel->prepare(tflite_kernel, nullptr, nullptr); tflite_kernel->eval(tflite_kernel, nullptr, nullptr, nullptr); tflite_kernel->wait(tflite_kernel, nullptr, nullptr); diff --git a/tensorflow/lite/async/testing/mock_async_kernel.h b/tensorflow/lite/async/testing/mock_async_kernel.h index a3297f849b12f6..be31a2a71a843c 100644 --- a/tensorflow/lite/async/testing/mock_async_kernel.h +++ b/tensorflow/lite/async/testing/mock_async_kernel.h @@ -48,6 +48,11 @@ class MockAsyncKernel : public delegates::BackendAsyncKernelInterface { (TfLiteOpaqueContext*, TfLiteOpaqueNode*, int, const TfLiteAttributeMap*), (override)); + MOCK_METHOD(TfLiteStatus, SetBufferAttributes, + (const TfLiteBackendBuffer*, const TfLiteAttributeMap*), + (override)); + MOCK_METHOD(TfLiteStatus, GetBufferAttributes, + (const TfLiteBackendBuffer*, TfLiteAttributeMap*), (override)); MOCK_METHOD(TfLiteStatus, Prepare, (TfLiteOpaqueContext*, TfLiteOpaqueNode*), (override)); MOCK_METHOD(TfLiteStatus, Eval, diff --git a/tensorflow/lite/core/async/async_kernel_internal.h b/tensorflow/lite/core/async/async_kernel_internal.h index 2ce473a029e516..efc341be8b3f0c 100644 --- a/tensorflow/lite/core/async/async_kernel_internal.h +++ b/tensorflow/lite/core/async/async_kernel_internal.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_CORE_ASYNC_ASYNC_KERNEL_INTERNAL_H_ #define TENSORFLOW_LITE_CORE_ASYNC_ASYNC_KERNEL_INTERNAL_H_ +#include #include #include @@ -115,6 +116,16 @@ struct TfLiteAsyncKernel { TfLiteOpaqueNode* node, int tensor_index, const TfLiteAttributeMap* attrs) = nullptr; + // Set attributes to the buffer, backend kernel will validate the buffer. + TfLiteStatus (*set_buffer_attributes)( + TfLiteAsyncKernel* async_kernel, const TfLiteBackendBuffer* buffer, + const TfLiteAttributeMap* attrs) = nullptr; + + // Get attributes from the buffer, backend kernel will validate the buffer. + TfLiteStatus (*get_buffer_attributes)(TfLiteAsyncKernel* async_kernel, + const TfLiteBackendBuffer* buffer, + TfLiteAttributeMap* attrs) = nullptr; + // Prepares the kernel using the information from Set[In|Out]putAttributes // call above. TfLiteStatus (*prepare)(TfLiteAsyncKernel* async_kernel, diff --git a/tensorflow/lite/core/async/async_signature_runner.cc b/tensorflow/lite/core/async/async_signature_runner.cc index d87f9f40683ab1..ad1ee15a601bdb 100644 --- a/tensorflow/lite/core/async/async_signature_runner.cc +++ b/tensorflow/lite/core/async/async_signature_runner.cc @@ -137,6 +137,16 @@ TfLiteStatus AsyncSignatureRunner::SetAttributes( return async_subgraph_->SetAttributes(tensor_index, attrs); } +TfLiteStatus AsyncSignatureRunner::SetBufferAttributes( + const TfLiteBackendBuffer* buffer, const TfLiteAttributeMap* attrs) { + return async_subgraph_->SetBufferAttributes(buffer, attrs); +} + +TfLiteStatus AsyncSignatureRunner::GetBufferAttributes( + const TfLiteBackendBuffer* buffer, TfLiteAttributeMap* attrs) { + return async_subgraph_->GetBufferAttributes(buffer, attrs); +} + TfLiteStatus AsyncSignatureRunner::PrepareBackends() { return async_subgraph_->Prepare(); } diff --git a/tensorflow/lite/core/async/async_signature_runner.h b/tensorflow/lite/core/async/async_signature_runner.h index b23a460debea29..d0a85a6c7eb038 100644 --- a/tensorflow/lite/core/async/async_signature_runner.h +++ b/tensorflow/lite/core/async/async_signature_runner.h @@ -122,6 +122,17 @@ class AsyncSignatureRunner { // Returns true if all backends accept the `attrs`. TfLiteStatus SetAttributes(int tensor_index, const TfLiteAttributeMap* attrs); + // Set the attributes of a specific buffer. Returns + // kTfLiteDelegateError if the buffer is not registered. + TfLiteStatus SetBufferAttributes(const TfLiteBackendBuffer* buffer, + const TfLiteAttributeMap* attrs); + + // Get the attributes from a specific buffer. Returns + // kTfLiteDelegateError if the buffer has not been found in the + // backends. + TfLiteStatus GetBufferAttributes(const TfLiteBackendBuffer* buffer, + TfLiteAttributeMap* attrs); + // Prepares delegate backends for execution. // Must be called after calling `SetAttributes`. TfLiteStatus PrepareBackends(); diff --git a/tensorflow/lite/core/async/async_subgraph.cc b/tensorflow/lite/core/async/async_subgraph.cc index 7a575372049a07..11fcef091be9ff 100644 --- a/tensorflow/lite/core/async/async_subgraph.cc +++ b/tensorflow/lite/core/async/async_subgraph.cc @@ -179,6 +179,16 @@ TfLiteStatus AsyncSubgraph::SetAttributes(int tensor_index, opaque_node_, tensor_index, attrs); } +TfLiteStatus AsyncSubgraph::SetBufferAttributes( + const TfLiteBackendBuffer* buffer, const TfLiteAttributeMap* attrs) { + return (*async_kernel_->set_buffer_attributes)(async_kernel_, buffer, attrs); +} + +TfLiteStatus AsyncSubgraph::GetBufferAttributes( + const TfLiteBackendBuffer* buffer, TfLiteAttributeMap* attrs) { + return (*async_kernel_->get_buffer_attributes)(async_kernel_, buffer, attrs); +} + TfLiteStatus AsyncSubgraph::Prepare() { if (async_kernel() == nullptr) return kTfLiteError; return (*async_kernel_->prepare)(async_kernel_, opaque_context(), diff --git a/tensorflow/lite/core/async/async_subgraph.h b/tensorflow/lite/core/async/async_subgraph.h index edf87ecaae72d1..cf4f3c905ca381 100644 --- a/tensorflow/lite/core/async/async_subgraph.h +++ b/tensorflow/lite/core/async/async_subgraph.h @@ -109,6 +109,24 @@ class AsyncSubgraph { // Returns true if all backends accept the `attrs`. TfLiteStatus SetAttributes(int tensor_index, const TfLiteAttributeMap* attrs); + // Set the attributes for a specific buffer. `attrs` should be initialized + // before calling this function and could be constructed by calling + // TfLiteAttributeMapCreate(). The attributes will be sent to backend kernels + // and stored in the map with the buffer. `buffer` and `attrs` should not be + // nullptr. The buffer needs to be registered before calling this function. + TfLiteStatus SetBufferAttributes(const TfLiteBackendBuffer* buffer, + const TfLiteAttributeMap* attrs); + + // Get the attributes for a specific buffer. `attrs` should be initialized + // before calling this function and could be constructed by calling + // TfLiteAttributeMapCreate(). `attrs` will be used to store the attributes + // obtained from the backend kernel. If `attrs` is a non-empty map, it will be + // overwritten by the attributes of the buffer. `buffer` and `attrs` should + // not be nullptr. The buffer needs to be registered before calling this + // function. + TfLiteStatus GetBufferAttributes(const TfLiteBackendBuffer* buffer, + TfLiteAttributeMap* attrs); + // Prepares delegate backends for execution. // Must be called after calling `SetAttributes`. TfLiteStatus Prepare(); diff --git a/tensorflow/lite/core/async/c/async_kernel.cc b/tensorflow/lite/core/async/c/async_kernel.cc index e220014954926c..08e6f0f2f8581f 100644 --- a/tensorflow/lite/core/async/c/async_kernel.cc +++ b/tensorflow/lite/core/async/c/async_kernel.cc @@ -100,6 +100,24 @@ void TfLiteAsyncKernelSetSetAttributes( async_kernel->set_attributes = set_attributes; } +void TfLiteAsyncKernelSetSetBufferAttributes( + TfLiteAsyncKernel* async_kernel, + TfLiteStatus (*set_buffer_attributes)(TfLiteAsyncKernel* async_kernel, + const TfLiteBackendBuffer* buffer, + const TfLiteAttributeMap* attrs)) { + if (!async_kernel) return; + async_kernel->set_buffer_attributes = set_buffer_attributes; +} + +void TfLiteAsyncKernelSetGetBufferAttributes( + TfLiteAsyncKernel* async_kernel, + TfLiteStatus (*get_buffer_attributes)(TfLiteAsyncKernel* async_kernel, + const TfLiteBackendBuffer* buffer, + TfLiteAttributeMap* attrs)) { + if (!async_kernel) return; + async_kernel->get_buffer_attributes = get_buffer_attributes; +}; + void TfLiteAsyncKernelSetPrepare( TfLiteAsyncKernel* async_kernel, TfLiteStatus (*prepare)(TfLiteAsyncKernel* async_kernel, diff --git a/tensorflow/lite/core/async/c/async_kernel.h b/tensorflow/lite/core/async/c/async_kernel.h index 1b3c76acee6324..e53eca75c70f65 100644 --- a/tensorflow/lite/core/async/c/async_kernel.h +++ b/tensorflow/lite/core/async/c/async_kernel.h @@ -178,6 +178,44 @@ TFL_CAPI_EXPORT extern void TfLiteAsyncKernelSetSetAttributes( TfLiteOpaqueNode* node, int tensor_index, const TfLiteAttributeMap* attrs)); +/// Sets the callback for the backend to set buffer attributes. +/// +/// `set_buffer_attributes`: +/// Sets the attributes of the buffers. +/// Backend kernel will check if the provided buffer has been registered, and +/// update the map in the backend, so that the callers can retrieve specific +/// buffer's attributes. `attrs` should be initialized +/// before calling this function and could be constructed by calling +/// TfLiteAttributeMapCreate(). The attributes will be sent to backend kernels +/// and stored in the map with the buffer. `buffer` and `attrs` should not be +/// nullptr. The buffer needs to be registered before calling this +/// function. Returns kTfLiteOk if the buffer has been registered and +/// callers can successfully set the attributes for a buffer. +TFL_CAPI_EXPORT extern void TfLiteAsyncKernelSetSetBufferAttributes( + TfLiteAsyncKernel* async_kernel, + TfLiteStatus (*set_buffer_attributes)(TfLiteAsyncKernel* async_kernel, + const TfLiteBackendBuffer* buffer, + const TfLiteAttributeMap* attrs)); + +/// Sets the callback for the backend to get buffer attributes. +/// +/// `get_buffer_attributes`: +/// Gets the attributes of the buffers. +/// Backend kernel will check if the provided buffer has been registered, and +/// get the corresponding attributes from the map. `attrs` should be initialized +/// before calling this function and could be constructed by calling +/// TfLiteAttributeMapCreate(). `attrs` will be used to store the attributes +/// obtained from the backend kernel. If `attrs` is a non-empty map, it will be +/// overwritten by the attributes of the buffer. `buffer` and `attrs` should not +/// be nullptr. The buffer needs to be registered before calling this function. +/// Returns kTfLiteOk if the buffer has been registered and callers can +/// successfully get the attributes for a buffer. +TFL_CAPI_EXPORT extern void TfLiteAsyncKernelSetGetBufferAttributes( + TfLiteAsyncKernel* async_kernel, + TfLiteStatus (*get_buffer_attributes)(TfLiteAsyncKernel* async_kernel, + const TfLiteBackendBuffer* buffer, + TfLiteAttributeMap* attrs)); + /// Sets the callback to prepare the kernels using the information from /// `set_attributes` calls. TFL_CAPI_EXPORT extern void TfLiteAsyncKernelSetPrepare( diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index 409c7da45d5add..3389cf7948e0d6 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -809,6 +809,10 @@ class DelegateAsyncKernel : public BackendAsyncKernelInterface { TfLiteStatus SetAttributes(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node, int tensor_index, const TfLiteAttributeMap* attrs) override; + TfLiteStatus SetBufferAttributes(const TfLiteBackendBuffer* buffer, + const TfLiteAttributeMap* attrs) override; + TfLiteStatus GetBufferAttributes(const TfLiteBackendBuffer* buffer, + TfLiteAttributeMap* attrs) override; TfLiteStatus Prepare(TfLiteOpaqueContext* context, TfLiteOpaqueNode* node) override; @@ -1068,6 +1072,18 @@ TfLiteStatus DelegateAsyncKernel::SetAttributesImpl( return kTfLiteOk; } +TfLiteStatus DelegateAsyncKernel::SetBufferAttributes( + const TfLiteBackendBuffer* buffer, const TfLiteAttributeMap* attrs) { + // TODO(b/325338475): Implement the details for set attributes to buffer. + return kTfLiteDelegateError; +} + +TfLiteStatus DelegateAsyncKernel::GetBufferAttributes( + const TfLiteBackendBuffer* buffer, TfLiteAttributeMap* attrs) { + // TODO(b/325338475): Implement the details for get attributes from buffer. + return kTfLiteDelegateError; +} + TfLiteStatus DelegateAsyncKernel::Prepare(TfLiteOpaqueContext* opaque_context, TfLiteOpaqueNode* opaque_node) { // The following cast is safe only because this code is part of the From 8958c652df2d5db28f71d95b83fa119cdc03942a Mon Sep 17 00:00:00 2001 From: Clemens Giuliani Date: Fri, 29 Mar 2024 11:49:47 -0700 Subject: [PATCH 095/124] PR #7849: [XLA:CPU] Add support for cross-process collectives using mpi. Imported from GitHub PR https://github.com/openxla/xla/pull/7849 Mpi collectives as proposed in https://github.com/google/jax/issues/11182?notification_referrer_id=NT_kwDOAG8zGbIzODQ5MDcxMzM0OjcyODc1Nzc#issuecomment-1851591135. I only implemented the inter-process communication and this does not yet support more than 1 threads per process. Adding support for multiple threads/devices per process in the future seems quite a bit more involved if one wanted to do it properly. For MPI I am building and linking against https://github.com/eschnett/MPItrampoline, which dlopens the (wrapped) mpi library at runtime. To wrap and load the desired mpi library one needs compile https://github.com/eschnett/MPIwrapper and set `MPITRAMPOLINE_LIB=/path/to/libmpiwrapper.so`. @hawkinsp Copybara import of the project: -- b74bbb909d902bd30523f943a7c15f2c754cf98a by Clemens Giuliani : add mpi collectives -- 23508eb46848464f6711dd8f3f91830ea1adb16d by Clemens Giuliani : add explicit Init and Finalize methods and export them to python -- bbe5840b8eb56a306a66ed03d701fd8976e01491 by Clemens Giuliani : add comment -- 38d156282ecc89509f4b21d80db1a37cb290437a by Clemens Giuliani : fix windows build -- 201f7238f166197ede5cf5d4d70e117a91eddcd7 by Clemens Giuliani : fmt -- 2784869df650c1c123c346401db2f67cb153b03e by Clemens Giuliani : bump xla_extension_version Merging this change closes #7849 PiperOrigin-RevId: 620302264 --- third_party/mpitrampoline/BUILD | 1 + third_party/mpitrampoline/gen.patch | 149 +++++++++ third_party/mpitrampoline/mpitrampoline.BUILD | 135 +++++++++ third_party/mpitrampoline/workspace.bzl | 18 ++ .../xla/third_party/mpitrampoline/BUILD | 1 + .../xla/third_party/mpitrampoline/gen.patch | 149 +++++++++ .../mpitrampoline/mpitrampoline.BUILD | 135 +++++++++ .../third_party/mpitrampoline/workspace.bzl | 18 ++ third_party/xla/workspace2.bzl | 2 + third_party/xla/xla/pjrt/cpu/BUILD | 32 ++ .../xla/xla/pjrt/cpu/mpi_collectives.cc | 283 ++++++++++++++++++ .../xla/xla/pjrt/cpu/mpi_collectives.h | 102 +++++++ third_party/xla/xla/python/BUILD | 6 + third_party/xla/xla/python/xla.cc | 22 ++ third_party/xla/xla/python/xla_client.py | 2 +- 15 files changed, 1054 insertions(+), 1 deletion(-) create mode 100644 third_party/mpitrampoline/BUILD create mode 100644 third_party/mpitrampoline/gen.patch create mode 100644 third_party/mpitrampoline/mpitrampoline.BUILD create mode 100644 third_party/mpitrampoline/workspace.bzl create mode 100644 third_party/xla/third_party/mpitrampoline/BUILD create mode 100644 third_party/xla/third_party/mpitrampoline/gen.patch create mode 100644 third_party/xla/third_party/mpitrampoline/mpitrampoline.BUILD create mode 100644 third_party/xla/third_party/mpitrampoline/workspace.bzl create mode 100644 third_party/xla/xla/pjrt/cpu/mpi_collectives.cc create mode 100644 third_party/xla/xla/pjrt/cpu/mpi_collectives.h diff --git a/third_party/mpitrampoline/BUILD b/third_party/mpitrampoline/BUILD new file mode 100644 index 00000000000000..3c413807167aeb --- /dev/null +++ b/third_party/mpitrampoline/BUILD @@ -0,0 +1 @@ +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/third_party/mpitrampoline/gen.patch b/third_party/mpitrampoline/gen.patch new file mode 100644 index 00000000000000..35124db0abb1e3 --- /dev/null +++ b/third_party/mpitrampoline/gen.patch @@ -0,0 +1,149 @@ +diff --git a/gen/gen_decl.py b/gen/gen_decl.py +index 1005b95..696b4e0 100755 +--- a/gen/gen_decl.py ++++ b/gen/gen_decl.py +@@ -9,8 +9,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi")) + + from mpi_constants import constants + from mpi_functions import functions +-from mpi_constants_fortran import constants_fortran +-from mpi_functions_fortran import functions_fortran ++# from mpi_constants_fortran import constants_fortran ++# from mpi_functions_fortran import functions_fortran + + support_profiling = True + have_weak_symbols = False +@@ -24,7 +24,7 @@ def wrap(line): + lines.append(line) + return "\n".join(lines) + +-with open("include/mpi_decl_constants_c.h", "w") as file: ++with open(sys.argv[1], "w") as file: + file.write("// Declare C MPI constants\n") + file.write("\n") + for (tp, nm) in constants: +@@ -32,7 +32,7 @@ with open("include/mpi_decl_constants_c.h", "w") as file: + 'mpi_nm': nm} + file.write(Template("extern $mpi_tp MPITRAMPOLINE_CONST $mpi_nm;\n").substitute(subs)) + +-with open("include/mpi_decl_functions_c.h", "w") as file: ++with open(sys.argv[2], "w") as file: + file.write("// Declare C MPI functions\n") + file.write("\n") + for (tp, nm, args, flags) in functions: +@@ -90,7 +90,7 @@ with open("include/mpi_decl_functions_c.h", "w") as file: + file.write(Template("\n".join(tmpl)).substitute(subs)) + file.write("\n") + +-with open("include/mpi_decl_constants_fortran.h", "w") as file: ++if False: + file.write("! Declare Fortran MPI constants\n") + file.write("\n") + for (tp, nm) in constants_fortran: +@@ -104,7 +104,7 @@ with open("include/mpi_decl_constants_fortran.h", "w") as file: + file.write("\n".join(map(lambda line: wrap(Template(line).substitute(subs)), tmpl))) + file.write("\n") + +-with open("include/mpi_decl_functions_fortran.h", "w") as file: ++if False: + file.write("! Declare Fortran MPI functions\n") + file.write("\n") + for (tp, nm, args) in functions_fortran: +diff --git a/gen/gen_defn.py b/gen/gen_defn.py +index bf31f35..318222e 100755 +--- a/gen/gen_defn.py ++++ b/gen/gen_defn.py +@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi")) + + from mpi_constants import constants + from mpi_functions import functions +-from mpi_constants_fortran import constants_fortran +-from mpi_functions_fortran import functions_fortran ++# from mpi_constants_fortran import constants_fortran ++# from mpi_functions_fortran import functions_fortran + + support_profiling = True + have_weak_symbols = False + replace_sentinels = False + +-with open("src/mpi_defn_constants_c.h", "w") as file: ++with open(sys.argv[1], "w") as file: + file.write("// Define C MPI constants") + file.write("\n") + for (tp, nm) in constants: +@@ -24,7 +24,7 @@ with open("src/mpi_defn_constants_c.h", "w") as file: + 'mpi_nm': nm} + file.write(Template("$mpi_tp $mpi_nm = ($mpi_tp)0xdeadbeef;\n").substitute(subs)) + +-with open("src/mpi_defn_functions_c.h", "w") as file: ++with open(sys.argv[2], "w") as file: + file.write("// Define C MPI functions\n") + file.write("\n") + for (tp, nm, args, flags) in functions: +@@ -89,7 +89,7 @@ with open("src/mpi_defn_functions_c.h", "w") as file: + file.write(Template("\n".join(tmpl)).substitute(subs)) + file.write("\n") + +-with open("src/mpi_defn_constants_fortran.h", "w") as file: ++if False: + file.write("// Define Fortran MPI constants\n") + file.write("\n") + for (tp, nm) in constants_fortran: +@@ -98,7 +98,7 @@ with open("src/mpi_defn_constants_fortran.h", "w") as file: + # Fortran common blocks with `-march=skylake-avx512` are aligned to 64 bytes + file.write(Template("$mpi_tp $abi_nm __attribute__((__aligned__(64))) = (int)0xdeadbeef;\n").substitute(subs)) + +-with open("src/mpi_defn_functions_fortran.h", "w") as file: ++if False: + file.write("// Define Fortran MPI functions\n") + file.write("\n") + for (tp, nm, args) in functions_fortran: +diff --git a/gen/gen_init.py b/gen/gen_init.py +index 4939261..0e52822 100755 +--- a/gen/gen_init.py ++++ b/gen/gen_init.py +@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi")) + + from mpi_constants import constants + from mpi_functions import functions +-from mpi_constants_fortran import constants_fortran +-from mpi_functions_fortran import functions_fortran ++# from mpi_constants_fortran import constants_fortran ++# from mpi_functions_fortran import functions_fortran + + support_profiling = True + have_weak_symbols = False + replace_sentinels = False + +-with open("src/mpi_init_constants_c.h", "w") as file: ++with open(sys.argv[1], "w") as file: + file.write("// Initialize C MPI constants") + file.write("\n") + for (tp, nm) in constants: +@@ -25,7 +25,7 @@ with open("src/mpi_init_constants_c.h", "w") as file: + 'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm)} + file.write(Template("$mpi_nm = *($mpi_tp const *)get_symbol(handle, \"$abi_nm\");\n").substitute(subs)) + +-with open("src/mpi_init_functions_c.h", "w") as file: ++with open(sys.argv[2], "w") as file: + file.write("// Initialize C MPI functions\n") + file.write("\n") + for (tp, nm, args, flags) in functions: +@@ -39,7 +39,7 @@ with open("src/mpi_init_functions_c.h", "w") as file: + subs['anm{0}'.format(i)] = anm + file.write(Template("$abi_nm = get_symbol(handle, \"$abi_nm\");\n").substitute(subs)) + +-with open("src/mpi_init_constants_fortran.h", "w") as file: ++if False: + file.write("// Initialize Fortran MPI constants\n") + file.write("\n") + for (tp, nm) in constants_fortran: +@@ -47,7 +47,7 @@ with open("src/mpi_init_constants_fortran.h", "w") as file: + 'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm).lower() + "_"} + file.write(Template("$abi_nm = *($abi_tp const*)get_symbol(handle, \"$abi_nm\");\n").substitute(subs)) + +-with open("src/mpi_init_functions_fortran.h", "w") as file: ++if False: + file.write("// Initialize Fortran MPI functions\n") + file.write("\n") + for (tp, nm, args) in functions_fortran: diff --git a/third_party/mpitrampoline/mpitrampoline.BUILD b/third_party/mpitrampoline/mpitrampoline.BUILD new file mode 100644 index 00000000000000..cf8e9c336e4e33 --- /dev/null +++ b/third_party/mpitrampoline/mpitrampoline.BUILD @@ -0,0 +1,135 @@ +# Description: +# A forwarding MPI implementation that can use any other MPI implementation via an MPI ABI + +load("@org_tensorflow//xla:strict.default.bzl", "py_strict_binary") +load("//third_party/bazel_skylib/rules:expand_template.bzl", "expand_template") + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +exports_files(["LICENSE.md"]) + +genrule( + name = "mpi_version", + srcs = [ + "CMakeLists.txt", + "include/mpi_version.h.in", + ], + outs = ["include/mpi_version.h"], + cmd = """ + PROJECT_VERSION=`cat $(location CMakeLists.txt) \ + | grep "MPItrampoline VERSION" | awk '{print $$NF}'` + PROJECT_VERSION_MAJOR=`echo $$PROJECT_VERSION | cut -d. -f1` + PROJECT_VERSION_MINOR=`echo $$PROJECT_VERSION | cut -d. -f2` + PROJECT_VERSION_PATCH=`echo $$PROJECT_VERSION | cut -d. -f3` + sed -e "s/@PROJECT_VERSION@/$${PROJECT_VERSION}/" \ + -e "s/@PROJECT_VERSION_MAJOR@/$${PROJECT_VERSION_MAJOR}/" \ + -e "s/@PROJECT_VERSION_MINOR@/$${PROJECT_VERSION_MINOR}/" \ + -e "s/@PROJECT_VERSION_PATCH@/$${PROJECT_VERSION_PATCH}/" \ + $(location include/mpi_version.h.in) > $(location include/mpi_version.h) + """, +) + +expand_template( + name = "mpi_defaults", + out = "src/mpi_defaults.h", + substitutions = { + "@MPITRAMPOLINE_DEFAULT_DELAY_INIT@": "", + "@MPITRAMPOLINE_DEFAULT_DLOPEN_BINDING@": "", + "@MPITRAMPOLINE_DEFAULT_DLOPEN_MODE@": "", + "@MPITRAMPOLINE_DEFAULT_LIB@": "", + "@MPITRAMPOLINE_DEFAULT_PRELOAD@": "", + "@MPITRAMPOLINE_DEFAULT_VERBOSE@": "", + }, + template = "src/mpi_defaults.h.in", +) + +py_strict_binary( + name = "gen_decl", + srcs = [ + "gen/gen_decl.py", + "mpiabi/mpi_constants.py", + "mpiabi/mpi_functions.py", + ], +) + +genrule( + name = "decl", + outs = [ + "include/mpi_decl_constants_c.h", + "include/mpi_decl_functions_c.h", + ], + cmd = "$(location :gen_decl) $(location include/mpi_decl_constants_c.h) \ + $(location include/mpi_decl_functions_c.h)", + tools = [":gen_decl"], +) + +py_strict_binary( + name = "gen_defn", + srcs = [ + "gen/gen_defn.py", + "mpiabi/mpi_constants.py", + "mpiabi/mpi_functions.py", + ], +) + +genrule( + name = "defn", + outs = [ + "include/mpi_defn_constants_c.h", + "include/mpi_defn_functions_c.h", + ], + cmd = "$(location :gen_defn) $(location include/mpi_defn_constants_c.h) \ + $(location include/mpi_defn_functions_c.h)", + tools = [":gen_defn"], +) + +py_strict_binary( + name = "gen_init", + srcs = [ + "gen/gen_init.py", + "mpiabi/mpi_constants.py", + "mpiabi/mpi_functions.py", + ], +) + +genrule( + name = "init", + outs = [ + "include/mpi_init_constants_c.h", + "include/mpi_init_functions_c.h", + ], + cmd = "$(location :gen_init) $(location include/mpi_init_constants_c.h) \ + $(location include/mpi_init_functions_c.h)", + tools = [":gen_init"], +) + +cc_library( + name = "mpitrampoline", + srcs = [ + "src/mpi.c", + ], + hdrs = [ + "include/mpi.h", + "include/mpi_decl_constants_c.h", + "include/mpi_decl_functions_c.h", + "include/mpi_defn_constants_c.h", + "include/mpi_defn_functions_c.h", + "include/mpi_init_constants_c.h", + "include/mpi_init_functions_c.h", + "include/mpi_version.h", + "mpiabi/mpiabi.h", + "src/mpi_defaults.h", + ], + copts = [ + "-fexceptions", + ], + includes = [ + "include", + "mpiabi", + "src", + ], +) diff --git a/third_party/mpitrampoline/workspace.bzl b/third_party/mpitrampoline/workspace.bzl new file mode 100644 index 00000000000000..4748931ae6e368 --- /dev/null +++ b/third_party/mpitrampoline/workspace.bzl @@ -0,0 +1,18 @@ +"""Provides the repository macro to import mpitrampoline.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + """Imports mpitrampoline.""" + + MPITRAMPOLINE_COMMIT = "25efb0f7a4cd00ed82bafb8b1a6285fc50d297ed" + MPITRAMPOLINE_SHA256 = "5a36656205c472bdb639bffebb0f014523b32dda0c2cbedd9ce7abfc9e879e84" + + tf_http_archive( + name = "mpitrampoline", + sha256 = MPITRAMPOLINE_SHA256, + strip_prefix = "MPItrampoline-{commit}".format(commit = MPITRAMPOLINE_COMMIT), + urls = tf_mirror_urls("https://github.com/eschnett/mpitrampoline/archive/{commit}.tar.gz".format(commit = MPITRAMPOLINE_COMMIT)), + patch_file = ["//third_party/mpitrampoline:gen.patch"], + build_file = "//third_party/mpitrampoline:mpitrampoline.BUILD", + ) diff --git a/third_party/xla/third_party/mpitrampoline/BUILD b/third_party/xla/third_party/mpitrampoline/BUILD new file mode 100644 index 00000000000000..3c413807167aeb --- /dev/null +++ b/third_party/xla/third_party/mpitrampoline/BUILD @@ -0,0 +1 @@ +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/third_party/xla/third_party/mpitrampoline/gen.patch b/third_party/xla/third_party/mpitrampoline/gen.patch new file mode 100644 index 00000000000000..35124db0abb1e3 --- /dev/null +++ b/third_party/xla/third_party/mpitrampoline/gen.patch @@ -0,0 +1,149 @@ +diff --git a/gen/gen_decl.py b/gen/gen_decl.py +index 1005b95..696b4e0 100755 +--- a/gen/gen_decl.py ++++ b/gen/gen_decl.py +@@ -9,8 +9,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi")) + + from mpi_constants import constants + from mpi_functions import functions +-from mpi_constants_fortran import constants_fortran +-from mpi_functions_fortran import functions_fortran ++# from mpi_constants_fortran import constants_fortran ++# from mpi_functions_fortran import functions_fortran + + support_profiling = True + have_weak_symbols = False +@@ -24,7 +24,7 @@ def wrap(line): + lines.append(line) + return "\n".join(lines) + +-with open("include/mpi_decl_constants_c.h", "w") as file: ++with open(sys.argv[1], "w") as file: + file.write("// Declare C MPI constants\n") + file.write("\n") + for (tp, nm) in constants: +@@ -32,7 +32,7 @@ with open("include/mpi_decl_constants_c.h", "w") as file: + 'mpi_nm': nm} + file.write(Template("extern $mpi_tp MPITRAMPOLINE_CONST $mpi_nm;\n").substitute(subs)) + +-with open("include/mpi_decl_functions_c.h", "w") as file: ++with open(sys.argv[2], "w") as file: + file.write("// Declare C MPI functions\n") + file.write("\n") + for (tp, nm, args, flags) in functions: +@@ -90,7 +90,7 @@ with open("include/mpi_decl_functions_c.h", "w") as file: + file.write(Template("\n".join(tmpl)).substitute(subs)) + file.write("\n") + +-with open("include/mpi_decl_constants_fortran.h", "w") as file: ++if False: + file.write("! Declare Fortran MPI constants\n") + file.write("\n") + for (tp, nm) in constants_fortran: +@@ -104,7 +104,7 @@ with open("include/mpi_decl_constants_fortran.h", "w") as file: + file.write("\n".join(map(lambda line: wrap(Template(line).substitute(subs)), tmpl))) + file.write("\n") + +-with open("include/mpi_decl_functions_fortran.h", "w") as file: ++if False: + file.write("! Declare Fortran MPI functions\n") + file.write("\n") + for (tp, nm, args) in functions_fortran: +diff --git a/gen/gen_defn.py b/gen/gen_defn.py +index bf31f35..318222e 100755 +--- a/gen/gen_defn.py ++++ b/gen/gen_defn.py +@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi")) + + from mpi_constants import constants + from mpi_functions import functions +-from mpi_constants_fortran import constants_fortran +-from mpi_functions_fortran import functions_fortran ++# from mpi_constants_fortran import constants_fortran ++# from mpi_functions_fortran import functions_fortran + + support_profiling = True + have_weak_symbols = False + replace_sentinels = False + +-with open("src/mpi_defn_constants_c.h", "w") as file: ++with open(sys.argv[1], "w") as file: + file.write("// Define C MPI constants") + file.write("\n") + for (tp, nm) in constants: +@@ -24,7 +24,7 @@ with open("src/mpi_defn_constants_c.h", "w") as file: + 'mpi_nm': nm} + file.write(Template("$mpi_tp $mpi_nm = ($mpi_tp)0xdeadbeef;\n").substitute(subs)) + +-with open("src/mpi_defn_functions_c.h", "w") as file: ++with open(sys.argv[2], "w") as file: + file.write("// Define C MPI functions\n") + file.write("\n") + for (tp, nm, args, flags) in functions: +@@ -89,7 +89,7 @@ with open("src/mpi_defn_functions_c.h", "w") as file: + file.write(Template("\n".join(tmpl)).substitute(subs)) + file.write("\n") + +-with open("src/mpi_defn_constants_fortran.h", "w") as file: ++if False: + file.write("// Define Fortran MPI constants\n") + file.write("\n") + for (tp, nm) in constants_fortran: +@@ -98,7 +98,7 @@ with open("src/mpi_defn_constants_fortran.h", "w") as file: + # Fortran common blocks with `-march=skylake-avx512` are aligned to 64 bytes + file.write(Template("$mpi_tp $abi_nm __attribute__((__aligned__(64))) = (int)0xdeadbeef;\n").substitute(subs)) + +-with open("src/mpi_defn_functions_fortran.h", "w") as file: ++if False: + file.write("// Define Fortran MPI functions\n") + file.write("\n") + for (tp, nm, args) in functions_fortran: +diff --git a/gen/gen_init.py b/gen/gen_init.py +index 4939261..0e52822 100755 +--- a/gen/gen_init.py ++++ b/gen/gen_init.py +@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi")) + + from mpi_constants import constants + from mpi_functions import functions +-from mpi_constants_fortran import constants_fortran +-from mpi_functions_fortran import functions_fortran ++# from mpi_constants_fortran import constants_fortran ++# from mpi_functions_fortran import functions_fortran + + support_profiling = True + have_weak_symbols = False + replace_sentinels = False + +-with open("src/mpi_init_constants_c.h", "w") as file: ++with open(sys.argv[1], "w") as file: + file.write("// Initialize C MPI constants") + file.write("\n") + for (tp, nm) in constants: +@@ -25,7 +25,7 @@ with open("src/mpi_init_constants_c.h", "w") as file: + 'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm)} + file.write(Template("$mpi_nm = *($mpi_tp const *)get_symbol(handle, \"$abi_nm\");\n").substitute(subs)) + +-with open("src/mpi_init_functions_c.h", "w") as file: ++with open(sys.argv[2], "w") as file: + file.write("// Initialize C MPI functions\n") + file.write("\n") + for (tp, nm, args, flags) in functions: +@@ -39,7 +39,7 @@ with open("src/mpi_init_functions_c.h", "w") as file: + subs['anm{0}'.format(i)] = anm + file.write(Template("$abi_nm = get_symbol(handle, \"$abi_nm\");\n").substitute(subs)) + +-with open("src/mpi_init_constants_fortran.h", "w") as file: ++if False: + file.write("// Initialize Fortran MPI constants\n") + file.write("\n") + for (tp, nm) in constants_fortran: +@@ -47,7 +47,7 @@ with open("src/mpi_init_constants_fortran.h", "w") as file: + 'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm).lower() + "_"} + file.write(Template("$abi_nm = *($abi_tp const*)get_symbol(handle, \"$abi_nm\");\n").substitute(subs)) + +-with open("src/mpi_init_functions_fortran.h", "w") as file: ++if False: + file.write("// Initialize Fortran MPI functions\n") + file.write("\n") + for (tp, nm, args) in functions_fortran: diff --git a/third_party/xla/third_party/mpitrampoline/mpitrampoline.BUILD b/third_party/xla/third_party/mpitrampoline/mpitrampoline.BUILD new file mode 100644 index 00000000000000..f46e39d762a159 --- /dev/null +++ b/third_party/xla/third_party/mpitrampoline/mpitrampoline.BUILD @@ -0,0 +1,135 @@ +# Description: +# A forwarding MPI implementation that can use any other MPI implementation via an MPI ABI + +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") +load("@local_xla//xla:strict.default.bzl", "py_strict_binary") + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +exports_files(["LICENSE.md"]) + +genrule( + name = "mpi_version", + srcs = [ + "CMakeLists.txt", + "include/mpi_version.h.in", + ], + outs = ["include/mpi_version.h"], + cmd = """ + PROJECT_VERSION=`cat $(location CMakeLists.txt) \ + | grep "MPItrampoline VERSION" | awk '{print $$NF}'` + PROJECT_VERSION_MAJOR=`echo $$PROJECT_VERSION | cut -d. -f1` + PROJECT_VERSION_MINOR=`echo $$PROJECT_VERSION | cut -d. -f2` + PROJECT_VERSION_PATCH=`echo $$PROJECT_VERSION | cut -d. -f3` + sed -e "s/@PROJECT_VERSION@/$${PROJECT_VERSION}/" \ + -e "s/@PROJECT_VERSION_MAJOR@/$${PROJECT_VERSION_MAJOR}/" \ + -e "s/@PROJECT_VERSION_MINOR@/$${PROJECT_VERSION_MINOR}/" \ + -e "s/@PROJECT_VERSION_PATCH@/$${PROJECT_VERSION_PATCH}/" \ + $(location include/mpi_version.h.in) > $(location include/mpi_version.h) + """, +) + +expand_template( + name = "mpi_defaults", + out = "src/mpi_defaults.h", + substitutions = { + "@MPITRAMPOLINE_DEFAULT_DELAY_INIT@": "", + "@MPITRAMPOLINE_DEFAULT_DLOPEN_BINDING@": "", + "@MPITRAMPOLINE_DEFAULT_DLOPEN_MODE@": "", + "@MPITRAMPOLINE_DEFAULT_LIB@": "", + "@MPITRAMPOLINE_DEFAULT_PRELOAD@": "", + "@MPITRAMPOLINE_DEFAULT_VERBOSE@": "", + }, + template = "src/mpi_defaults.h.in", +) + +py_strict_binary( + name = "gen_decl", + srcs = [ + "gen/gen_decl.py", + "mpiabi/mpi_constants.py", + "mpiabi/mpi_functions.py", + ], +) + +genrule( + name = "decl", + outs = [ + "include/mpi_decl_constants_c.h", + "include/mpi_decl_functions_c.h", + ], + cmd = "$(location :gen_decl) $(location include/mpi_decl_constants_c.h) \ + $(location include/mpi_decl_functions_c.h)", + tools = [":gen_decl"], +) + +py_strict_binary( + name = "gen_defn", + srcs = [ + "gen/gen_defn.py", + "mpiabi/mpi_constants.py", + "mpiabi/mpi_functions.py", + ], +) + +genrule( + name = "defn", + outs = [ + "include/mpi_defn_constants_c.h", + "include/mpi_defn_functions_c.h", + ], + cmd = "$(location :gen_defn) $(location include/mpi_defn_constants_c.h) \ + $(location include/mpi_defn_functions_c.h)", + tools = [":gen_defn"], +) + +py_strict_binary( + name = "gen_init", + srcs = [ + "gen/gen_init.py", + "mpiabi/mpi_constants.py", + "mpiabi/mpi_functions.py", + ], +) + +genrule( + name = "init", + outs = [ + "include/mpi_init_constants_c.h", + "include/mpi_init_functions_c.h", + ], + cmd = "$(location :gen_init) $(location include/mpi_init_constants_c.h) \ + $(location include/mpi_init_functions_c.h)", + tools = [":gen_init"], +) + +cc_library( + name = "mpitrampoline", + srcs = [ + "src/mpi.c", + ], + hdrs = [ + "include/mpi.h", + "include/mpi_decl_constants_c.h", + "include/mpi_decl_functions_c.h", + "include/mpi_defn_constants_c.h", + "include/mpi_defn_functions_c.h", + "include/mpi_init_constants_c.h", + "include/mpi_init_functions_c.h", + "include/mpi_version.h", + "mpiabi/mpiabi.h", + "src/mpi_defaults.h", + ], + copts = [ + "-fexceptions", + ], + includes = [ + "include", + "mpiabi", + "src", + ], +) diff --git a/third_party/xla/third_party/mpitrampoline/workspace.bzl b/third_party/xla/third_party/mpitrampoline/workspace.bzl new file mode 100644 index 00000000000000..4748931ae6e368 --- /dev/null +++ b/third_party/xla/third_party/mpitrampoline/workspace.bzl @@ -0,0 +1,18 @@ +"""Provides the repository macro to import mpitrampoline.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + """Imports mpitrampoline.""" + + MPITRAMPOLINE_COMMIT = "25efb0f7a4cd00ed82bafb8b1a6285fc50d297ed" + MPITRAMPOLINE_SHA256 = "5a36656205c472bdb639bffebb0f014523b32dda0c2cbedd9ce7abfc9e879e84" + + tf_http_archive( + name = "mpitrampoline", + sha256 = MPITRAMPOLINE_SHA256, + strip_prefix = "MPItrampoline-{commit}".format(commit = MPITRAMPOLINE_COMMIT), + urls = tf_mirror_urls("https://github.com/eschnett/mpitrampoline/archive/{commit}.tar.gz".format(commit = MPITRAMPOLINE_COMMIT)), + patch_file = ["//third_party/mpitrampoline:gen.patch"], + build_file = "//third_party/mpitrampoline:mpitrampoline.BUILD", + ) diff --git a/third_party/xla/workspace2.bzl b/third_party/xla/workspace2.bzl index e7c6c4be2500d1..9b9fd5e9265ed9 100644 --- a/third_party/xla/workspace2.bzl +++ b/third_party/xla/workspace2.bzl @@ -10,6 +10,7 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # Import third party repository rules. See go/tfbr-thirdparty. load("//third_party/dlpack:workspace.bzl", dlpack = "repo") load("//third_party/gloo:workspace.bzl", gloo = "repo") +load("//third_party/mpitrampoline:workspace.bzl", mpitrampoline = "repo") load("//third_party/nanobind:workspace.bzl", nanobind = "repo") load("//third_party/robin_map:workspace.bzl", robin_map = "repo") load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo") @@ -19,6 +20,7 @@ def _initialize_third_party(): """ Load third party repositories. See above load() statements. """ dlpack() gloo() + mpitrampoline() nanobind() robin_map() stablehlo() diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD index a673cd5191e5ee..324d684611e22c 100644 --- a/third_party/xla/xla/pjrt/cpu/BUILD +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -1,3 +1,4 @@ +load("@local_tsl//tsl:tsl.bzl", "if_oss") load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla:xla.bzl", "xla_cc_test") @@ -286,3 +287,34 @@ cc_library( "@local_tsl//tsl/platform:logging", ], ) + +cc_library( + name = "mpi_collectives", + srcs = if_oss(["mpi_collectives.cc"]), + hdrs = if_oss(["mpi_collectives.h"]), + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = if_oss([ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/service/cpu:collectives_interface", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@mpitrampoline", + ]), +) diff --git a/third_party/xla/xla/pjrt/cpu/mpi_collectives.cc b/third_party/xla/xla/pjrt/cpu/mpi_collectives.cc new file mode 100644 index 00000000000000..d2c93fd75450f5 --- /dev/null +++ b/third_party/xla/xla/pjrt/cpu/mpi_collectives.cc @@ -0,0 +1,283 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/pjrt/cpu/mpi_collectives.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mpi.h" // NOLINT +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/primitive_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/global_device_id.h" +#include "xla/status_macros.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" + +namespace xla::cpu { + +absl::StatusOr PrimitiveTypeToMpiType( + PrimitiveType element_type) { + switch (element_type) { + case S8: + return MPI_INT8_T; + case U8: + case PRED: + return MPI_UINT8_T; + case S16: + return MPI_INT16_T; + case U16: + return MPI_UINT16_T; + case S32: + return MPI_INT32_T; + case U32: + return MPI_UINT32_T; + case S64: + return MPI_INT64_T; + case U64: + return MPI_UINT64_T; + case F32: + return MPI_FLOAT; + case F64: + return MPI_DOUBLE; + case C64: + return MPI_C_COMPLEX; + case C128: + return MPI_C_DOUBLE_COMPLEX; + default: + // For implementing the reduction of unsupported types + // see e.g. https://stackoverflow.com/a/29643391 + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported primitive type for reduction: ", + primitive_util::LowercasePrimitiveTypeName(element_type))); + } +} + +bool MpiTypeIsComplex(MPI_Datatype type) { + return type == MPI_C_COMPLEX || type == MPI_C_DOUBLE_COMPLEX; +} + +absl::StatusOr ReductionKindToMpiOp(ReductionKind reduction_kind, + MPI_Datatype type) { + switch (reduction_kind) { + case ReductionKind::SUM: + return MPI_SUM; + case ReductionKind::PRODUCT: + return MPI_PROD; + case ReductionKind::MIN: + if (!MpiTypeIsComplex(type)) { + return MPI_MIN; + } else { + return absl::InvalidArgumentError( + "MIN reduction not supported for complex types"); + } + case ReductionKind::MAX: + if (!MpiTypeIsComplex(type)) { + return MPI_MAX; + } else { + return absl::InvalidArgumentError( + "MAX reduction not supported for complex types"); + } + default: + return absl::InvalidArgumentError( + absl::StrCat("Unknown reduction kind: ", reduction_kind)); + } +} + +static absl::Status MpiErrorToAbslStatus(int error) { + if (error != MPI_SUCCESS) { + char error_str[MPI_MAX_ERROR_STRING]; + int len; + MPI_Error_string(error, error_str, &len); + return absl::UnknownError(absl::StrCat("MPI error: ", error_str)); + } + return absl::OkStatus(); +} + +MpiCollectivesCommunicator::MpiCollectivesCommunicator(int color, int key) { + MPI_Comm_split(MPI_COMM_WORLD, color, key, &comm_); + MPI_Comm_rank(comm_, &mpi_rank_); + MPI_Comm_size(comm_, &mpi_size_); +} + +MpiCollectivesCommunicator::~MpiCollectivesCommunicator() { + MPI_Comm_free(&comm_); +}; + +absl::Status MpiCollectivesCommunicator::AllReduce( + const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t num_elements, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(element_type)); + TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); + return MpiErrorToAbslStatus(MPI_Allreduce(input_buffer, output_buffer, + num_elements, type, op, comm_)); +} + +absl::Status MpiCollectivesCommunicator::CollectivePermute( + const RendezvousKey& key, size_t num_bytes, std::optional source_rank, + absl::Span target_ranks, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + int tag = 0; // TODO come up with better tags. + + const int rank = mpi_rank_; + + std::vector requests; + + if (source_rank) { + if (*source_rank == rank) { + std::memcpy(output_buffer, input_buffer, num_bytes); + } else { + VLOG(1) << "recv at " << rank << " from " << *source_rank; + requests.emplace_back(); + TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( + MPI_Irecv(output_buffer, num_bytes, MPI_BYTE, *source_rank, tag, + comm_, &requests.back()))); + } + } else { + std::memset(output_buffer, 0, num_bytes); + } + + for (int target : target_ranks) { + if (target != rank) { + VLOG(1) << "send from " << rank << " to " << target; + requests.emplace_back(); + TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( + MPI_Isend(input_buffer, num_bytes, MPI_BYTE, target, tag, comm_, + &requests.back()))); + } + } + + for (auto& request : requests) { + TF_RETURN_IF_ERROR( + MpiErrorToAbslStatus(MPI_Wait(&request, MPI_STATUS_IGNORE))); + } + + return absl::OkStatus(); +} + +absl::Status MpiCollectivesCommunicator::AllToAll( + const RendezvousKey& key, size_t chunk_bytes, + absl::Span input_buffers, + absl::Span output_buffers, absl::Duration timeout) { + // We can't use MPI_Alltoall directly because it assumes that the inputs and + // outputs are contiguous. Therefore here we implement it using MPI_Sendrecv. + + int tag = 0; // TODO use better tags. + const int rank = mpi_rank_; + const int size = mpi_size_; + TF_RET_CHECK(size == input_buffers.size()); + TF_RET_CHECK(size == output_buffers.size()); + + std::memcpy(output_buffers[rank], input_buffers[rank], chunk_bytes); + + for (int i = 1; i < size; i++) { + int send_rank = (rank + i) % size; + int recv_rank = (rank + size - i) % size; + TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( + MPI_Sendrecv(input_buffers[send_rank], chunk_bytes, MPI_BYTE, send_rank, + tag, output_buffers[recv_rank], chunk_bytes, MPI_BYTE, + recv_rank, tag, comm_, MPI_STATUS_IGNORE))); + } + + return absl::OkStatus(); +} + +absl::Status MpiCollectivesCommunicator::AllGather(const RendezvousKey& key, + size_t chunk_bytes, + const void* input_buffer, + void* output_buffer, + absl::Duration timeout) { + return MpiErrorToAbslStatus(MPI_Allgather(input_buffer, chunk_bytes, MPI_BYTE, + output_buffer, chunk_bytes, + MPI_BYTE, comm_)); +} + +absl::Status MpiCollectivesCommunicator::ReduceScatter( + const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + const int size = mpi_size_; + std::vector recvcounts(size, chunk_elems); + TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(element_type)); + TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); + return MpiErrorToAbslStatus(MPI_Reduce_scatter( + input_buffer, output_buffer, recvcounts.data(), type, op, comm_)); +} + +void MpiCollectives::Init() { + int provided; + MPI_Init_thread(NULL, NULL, MPI_THREAD_FUNNELED, &provided); + MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank_); + MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size_); + VLOG(1) << "MPI rank=" << mpi_world_rank_ << " size=" << mpi_world_size_; +} + +void MpiCollectives::Finalize() { + contexts_.clear(); + MPI_Finalize(); +} + +absl::StatusOr> +MpiCollectives::GetCommunicator(absl::Span global_devices, + int rank) { + int flag; + MPI_Is_thread_main(&flag); + if (!flag) { + return absl::UnknownError( + absl::StrCat("MPI: Communicator requested from a thread that is not " + "the one MPI was initialized from. Multiple " + "threads/devices per process are not yet supported.")); + } + + auto& context = contexts_[std::make_tuple( + std::vector(global_devices.begin(), global_devices.end()), + rank)]; + if (context) { + return context; + } + + int color; + int key = 0; + if (global_devices.size() > 0) { + color = static_cast(global_devices.at(0).value()); + key = rank; + } else { + color = MPI_UNDEFINED; + } + context = std::make_shared(color, key); + return context; +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/pjrt/cpu/mpi_collectives.h b/third_party/xla/xla/pjrt/cpu/mpi_collectives.h new file mode 100644 index 00000000000000..fdf6ec81b6dc6b --- /dev/null +++ b/third_party/xla/xla/pjrt/cpu/mpi_collectives.h @@ -0,0 +1,102 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_CPU_MPI_COLLECTIVES_H_ +#define XLA_PJRT_CPU_MPI_COLLECTIVES_H_ + +#include +#include +#include +#include +#include + +#include "mpi.h" // NOLINT +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/global_device_id.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +class MpiCollectivesCommunicator : public CollectivesCommunicator { + public: + explicit MpiCollectivesCommunicator(int color, int key); + ~MpiCollectivesCommunicator() override; + + absl::Status AllReduce(const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t num_elements, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + absl::Status CollectivePermute(const RendezvousKey& key, size_t num_bytes, + std::optional source_rank, + absl::Span target_ranks, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes, + absl::Span input_buffers, + absl::Span output_buffers, + absl::Duration timeout) override; + absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + absl::Status ReduceScatter(const RendezvousKey& key, + ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + + private: + MPI_Comm comm_; + int mpi_rank_; + int mpi_size_; +}; + +class MpiCollectives : public CollectivesInterface { + public: + /* + The user has to explicitly call Init() and Finalize() before and + after use. + For example, using the Python client, this can be achieved with: + + collectives = xla_client._xla.make_mpi_collectives() + collectives.Init() + atexit.register(collectives.Finalize) + */ + void Init(); + void Finalize(); + + absl::StatusOr> GetCommunicator( + absl::Span global_devices, int rank) override; + + private: + absl::Status ExchangeGlobalDeviceIds( + absl::Span global_devices, int rank); + + int mpi_world_rank_; + int mpi_world_size_; + absl::flat_hash_map, int>, + std::shared_ptr> + contexts_; +}; + +} // namespace xla::cpu + +#endif // XLA_PJRT_CPU_MPI_COLLECTIVES_H_ diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 2f022afb5c51f1..267e7284f05cb0 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -1261,6 +1261,12 @@ tsl_pybind_extension( "//xla/pjrt/cpu:gloo_collectives", "//xla/pjrt/cpu:gloo_kv_store", ], + }) + select({ + # mpitrampoline does not build on windows + "@local_tsl//tsl:windows": [], + "//conditions:default": [ + "//xla/pjrt/cpu:mpi_collectives", + ], }) + select({ ":gpu_enabled": [ "//xla/pjrt/gpu:se_gpu_pjrt_client", diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index 2e4f390f0dfb93..1c9cf5bbcdadbb 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -67,6 +67,11 @@ limitations under the License. #include "xla/pjrt/cpu/gloo_collectives.h" #include "xla/pjrt/cpu/gloo_kv_store.h" #endif // __linux__ + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) +#include "xla/pjrt/cpu/mpi_collectives.h" +#endif // !_WIN32 && !PLATFORM_GOOGLE + #include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/exceptions.h" @@ -270,6 +275,23 @@ NB_MODULE(xla_extension, m_nb) { nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt, nb::arg("interface").none() = std::nullopt); +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) + nb::class_ mpi_collectives(m_nb, "MpiCollectives", + cpu_collectives); + mpi_collectives.def("Init", &cpu::MpiCollectives::Init); + mpi_collectives.def("Finalize", &cpu::MpiCollectives::Finalize); + m_nb.def("make_mpi_collectives", + []() -> std::shared_ptr { + return std::make_shared(); + }); +#else // !_WIN32 && !PLATFORM_GOOGLE + m_nb.def("make_mpi_collectives", + []() -> std::shared_ptr { + throw xla::XlaRuntimeError( + "make_mpi_collectives is not implemented for Windows"); + }); +#endif // !_WIN32 && !PLATFORM_GOOGLE + m_nb.def( "get_tfrt_cpu_client", [](bool asynchronous, diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index ef9e4e5291ab97..64eb6cd7d4e1dd 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -48,7 +48,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 250 +_version = 251 # Version number for MLIR:Python components. mlir_api_version = 55 From f606d83b95f6517f3dca4ede243713d6f3af03aa Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Fri, 29 Mar 2024 12:40:12 -0700 Subject: [PATCH 096/124] Use `bytes` proto field type for string values of `xla::PjRtValueType` to `bytes` `xla::PjRtValueType` is defined in C++, where its `std::string` value can contain any string (not necessarily UTF-8). Protobuf verison 3 requires a `string` field to contain UTF-8, so it is more suitable to use `bytes` to express this value. (Note that the string value of `xla::PjRtValueType` would be often consumed by Python, where nanobind would convert `std::string` into Python `str` with UTF-8 decoding. However, this is what some users of `xla::PjRtValueType` choose to do; this is not sufficient enough to constrain the string to be UTF-8 only in C++ APIs.) This is a preemptive change; there is no known problem of using a `string` field previously. PiperOrigin-RevId: 620315110 --- third_party/xla/xla/python/ifrt_proxy/common/types.proto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/python/ifrt_proxy/common/types.proto b/third_party/xla/xla/python/ifrt_proxy/common/types.proto index 7ef48aed10c54b..2de88772abe906 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/types.proto +++ b/third_party/xla/xla/python/ifrt_proxy/common/types.proto @@ -85,7 +85,7 @@ message Variant { } oneof value { - string string_value = 1; + bytes string_value = 1; sfixed64 int64_value = 2; Int64List int64_list = 3; float float_value = 4; From ea867d6ee69125ebae452f07f5c6b54d82267f14 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Fri, 29 Mar 2024 13:05:35 -0700 Subject: [PATCH 097/124] Move `tsl/python` to `xla/tsl/python` PiperOrigin-RevId: 620320903 --- .../lite/python/interpreter_wrapper/numpy.h | 3 ++- tensorflow/python/BUILD | 10 ++++---- tensorflow/python/client/BUILD | 6 ++--- tensorflow/python/client/tf_session_helper.h | 2 +- .../python/client/tf_session_wrapper.cc | 2 +- tensorflow/python/eager/BUILD | 2 +- tensorflow/python/eager/pywrap_tensor.cc | 2 +- tensorflow/python/eager/pywrap_tensor.h | 2 +- tensorflow/python/framework/BUILD | 6 ++--- tensorflow/python/lib/core/BUILD | 16 ++++++------- tensorflow/python/lib/core/ndarray_tensor.cc | 4 ++-- .../python/lib/core/ndarray_tensor_bridge.cc | 4 ++-- .../python/lib/core/ndarray_tensor_bridge.h | 2 +- tensorflow/python/lib/core/py_func.cc | 2 +- tensorflow/python/lib/core/py_seq_tensor.cc | 2 +- tensorflow/python/tfe_wrapper.cc | 2 +- .../tools/def_file_filter/symbols_pybind.txt | 4 ++-- .../tools/pip_package/build_pip_package.py | 23 ++++++++++++------- .../tools/def_file_filter/symbols_pybind.txt | 4 ++-- .../third_party/tsl/tsl/platform/ml_dtypes.h | 14 +++++------ third_party/xla/xla/python/BUILD | 10 ++++---- third_party/xla/xla/python/nb_numpy.cc | 2 +- third_party/xla/xla/python/nb_numpy.h | 2 +- third_party/xla/xla/python/pmap_lib.cc | 2 +- .../xla/xla/python/py_compile_only_client.cc | 2 +- third_party/xla/xla/python/py_values.cc | 2 +- third_party/xla/xla/python/tools/BUILD | 2 +- third_party/xla/xla/python/tools/_types.cc | 2 +- third_party/xla/xla/python/types.cc | 2 +- third_party/xla/xla/python/xla.cc | 2 +- .../tsl => xla}/tsl/python/lib/core/BUILD | 0 .../tsl/python/lib/core/ml_dtypes.cc | 4 ++-- .../tsl/python/lib/core/ml_dtypes.h | 6 ++--- .../tsl => xla}/tsl/python/lib/core/numpy.cc | 2 +- .../tsl => xla}/tsl/python/lib/core/numpy.h | 6 ++--- 35 files changed, 83 insertions(+), 75 deletions(-) rename third_party/xla/{third_party/tsl => xla}/tsl/python/lib/core/BUILD (100%) rename third_party/xla/{third_party/tsl => xla}/tsl/python/lib/core/ml_dtypes.cc (97%) rename third_party/xla/{third_party/tsl => xla}/tsl/python/lib/core/ml_dtypes.h (90%) rename third_party/xla/{third_party/tsl => xla}/tsl/python/lib/core/numpy.cc (95%) rename third_party/xla/{third_party/tsl => xla}/tsl/python/lib/core/numpy.h (91%) diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.h b/tensorflow/lite/python/interpreter_wrapper/numpy.h index acc3dbd9fdab3a..e04418c32df7f4 100644 --- a/tensorflow/lite/python/interpreter_wrapper/numpy.h +++ b/tensorflow/lite/python/interpreter_wrapper/numpy.h @@ -40,7 +40,8 @@ limitations under the License. // translation unit boundaries. // // For more info see https://sourceforge.net/p/numpy/mailman/message/5700519 -// See also tensorflow/tsl/python/lib/core/numpy.h for a similar approach. +// See also tensorflow/compiler/xla/tsl/python/lib/core/numpy.h for a similar +// approach. #define PY_ARRAY_UNIQUE_SYMBOL _tflite_numpy_api #ifndef TFLITE_IMPORT_NUMPY #define NO_IMPORT_ARRAY diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 210fa3b1bff005..159e6db204f0cb 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -834,9 +834,9 @@ pywrap_tensorflow_macro( "@com_google_absl//absl/types:span", "@local_tsl//tsl/profiler/rpc:profiler_server_impl", "@local_tsl//tsl/profiler/rpc/client:profiler_client_impl", - "@local_tsl//tsl/python/lib/core:numpy", "@local_xla//xla/backends/profiler/cpu:python_tracer", "@local_xla//xla/stream_executor:stream_executor_impl", + "@local_xla//xla/tsl/python/lib/core:numpy", ] + select({ "//tensorflow/compiler/mlir/python:disable_mlir_config": [], "//conditions:default": [ @@ -951,9 +951,9 @@ filegroup( "@local_tsl//tsl/profiler/lib:profiler_session_impl", # profiler "@local_tsl//tsl/profiler/rpc:profiler_server_impl", # profiler "@local_tsl//tsl/profiler/rpc/client:profiler_client_impl", - "@local_tsl//tsl/python/lib/core:ml_dtypes_lib", # bfloat16, float8_e4m3fn, float8_e5m2 - "@local_tsl//tsl/python/lib/core:numpy", # checkpoint_reader "@local_xla//xla/stream_executor", # stat_summarizer + "@local_xla//xla/tsl/python/lib/core:ml_dtypes_lib", # bfloat16, float8_e4m3fn, float8_e5m2 + "@local_xla//xla/tsl/python/lib/core:numpy", # checkpoint_reader ] + select({ "//tensorflow/compiler/mlir/python:disable_mlir_config": [], "//conditions:default": [ @@ -1151,7 +1151,7 @@ cc_library( name = "unified_api_pywrap_required_headers", textual_hdrs = [ "//tensorflow/python/lib/core:basic_hdrs", - "@local_tsl//tsl/python/lib/core:basic_hdrs", + "@local_xla//xla/tsl/python/lib/core:basic_hdrs", "//tensorflow/c:headers", "//tensorflow/c:safe_ptr_hdr", "//tensorflow/c/eager:headers", @@ -1205,7 +1205,7 @@ tf_python_pybind_extension( "//tensorflow/python/util:util_hdr", "@local_tsl//tsl/distributed_runtime:pywrap_required_hdrs", "@local_tsl//tsl/distributed_runtime/coordination:pywrap_required_hdrs", - "@local_tsl//tsl/python/lib/core:numpy_hdr", + "@local_xla//xla/tsl/python/lib/core:numpy_hdr", ], dynamic_deps = [":_pywrap_tensorflow_internal.so"] + select({ "//tensorflow:macos": ["//tensorflow:libtensorflow_framework.%s.dylib" % VERSION], diff --git a/tensorflow/python/client/BUILD b/tensorflow/python/client/BUILD index 15f5937a06b95a..64ff5af1017bea 100644 --- a/tensorflow/python/client/BUILD +++ b/tensorflow/python/client/BUILD @@ -45,7 +45,7 @@ tf_python_pybind_extension( "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs", "@local_tsl//tsl/distributed_runtime:pywrap_required_hdrs", "@local_tsl//tsl/distributed_runtime/coordination:pywrap_required_hdrs", - "@local_tsl//tsl/python/lib/core:numpy_hdr", + "@local_xla//xla/tsl/python/lib/core:numpy_hdr", ], enable_stub_generation = True, pytype_srcs = [ @@ -68,7 +68,7 @@ tf_python_pybind_extension( "//third_party/python_runtime:headers", "@com_google_absl//absl/types:optional", "@eigen_archive//:eigen3", - "@local_tsl//tsl/python/lib/core:numpy", + "@local_xla//xla/tsl/python/lib/core:numpy", "@pybind11", "@pybind11_abseil//pybind11_abseil:absl_casters", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", @@ -259,7 +259,7 @@ tf_cuda_library( "//tensorflow/python/lib/core:safe_pyobject_ptr", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_tsl//tsl/python/lib/core:numpy", + "@local_xla//xla/tsl/python/lib/core:numpy", ], alwayslink = 1, ) diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 19e04fd1c5b29a..8a9682307c4ff5 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -18,7 +18,7 @@ limitations under the License. // Must be included first // clang-format off -#include "tsl/python/lib/core/numpy.h" //NOLINT +#include "xla/tsl/python/lib/core/numpy.h" //NOLINT // clang-format on #include "tensorflow/c/c_api.h" diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc index 0e7a5efe142b13..7492c91bbdb355 100644 --- a/tensorflow/python/client/tf_session_wrapper.cc +++ b/tensorflow/python/client/tf_session_wrapper.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/c/safe_ptr.h" #include "tensorflow/c/tf_buffer.h" #include "tensorflow/c/tf_datatype.h" +#include "xla/tsl/python/lib/core/numpy.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/framework/full_type.pb.h" #include "tensorflow/core/framework/versions.pb.h" @@ -50,7 +51,6 @@ limitations under the License. #include "tensorflow/python/lib/core/pybind11_status.h" #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" #include "tsl/platform/mutex.h" -#include "tsl/python/lib/core/numpy.h" namespace pybind11 { namespace detail { diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index af25ec35db0bdf..2c92adf4f1e300 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -80,7 +80,7 @@ cc_library( "@com_google_absl//absl/types:variant", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/profiler/lib:traceme", - "@local_tsl//tsl/python/lib/core:numpy", + "@local_xla//xla/tsl/python/lib/core:numpy", "@pybind11", ], ) diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 167511e80c65ef..6dc8541ef09592 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // Must be included first // clang-format off -#include "tsl/python/lib/core/numpy.h" //NOLINT +#include "xla/tsl/python/lib/core/numpy.h" //NOLINT // clang-format on #include "tensorflow/python/eager/pywrap_tensor.h" diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h index 53c5b66a93fcb9..bebf4e8558c463 100644 --- a/tensorflow/python/eager/pywrap_tensor.h +++ b/tensorflow/python/eager/pywrap_tensor.h @@ -17,7 +17,7 @@ limitations under the License. // Must be included first // clang-format off -#include "tsl/python/lib/core/numpy.h" //NOLINT +#include "xla/tsl/python/lib/core/numpy.h" //NOLINT // clang-format on #include "tensorflow/c/eager/c_api.h" diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index fd25d20a55ffef..dccc222720757f 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -865,7 +865,7 @@ tf_python_pybind_extension( "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs", "@local_tsl//tsl/distributed_runtime:pywrap_required_hdrs", "@local_tsl//tsl/distributed_runtime/coordination:pywrap_required_hdrs", - "@local_tsl//tsl/python/lib/core:numpy_hdr", + "@local_xla//xla/tsl/python/lib/core:numpy_hdr", ], enable_stub_generation = True, pytype_srcs = [ @@ -965,7 +965,7 @@ tf_python_pybind_extension( "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs", "@local_tsl//tsl/distributed_runtime:pywrap_required_hdrs", "@local_tsl//tsl/distributed_runtime/coordination:pywrap_required_hdrs", - "@local_tsl//tsl/python/lib/core:numpy_hdr", + "@local_xla//xla/tsl/python/lib/core:numpy_hdr", ], enable_stub_generation = True, pytype_srcs = [ @@ -1109,7 +1109,7 @@ tf_python_pybind_extension( "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs", "@local_tsl//tsl/distributed_runtime:pywrap_required_hdrs", "@local_tsl//tsl/distributed_runtime/coordination:pywrap_required_hdrs", - "@local_tsl//tsl/python/lib/core:numpy_hdr", + "@local_xla//xla/tsl/python/lib/core:numpy_hdr", ], enable_stub_generation = True, pytype_srcs = [ diff --git a/tensorflow/python/lib/core/BUILD b/tensorflow/python/lib/core/BUILD index a6904053c60b76..0ab9f789976e11 100644 --- a/tensorflow/python/lib/core/BUILD +++ b/tensorflow/python/lib/core/BUILD @@ -39,8 +39,8 @@ cc_library( "//tensorflow/c:tf_datatype_hdrs", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "@local_tsl//tsl/python/lib/core:ml_dtypes_lib", - "@local_tsl//tsl/python/lib/core:numpy", + "@local_xla//xla/tsl/python/lib/core:ml_dtypes_lib", + "@local_xla//xla/tsl/python/lib/core:numpy", ], ) @@ -171,7 +171,7 @@ cc_library( "//tensorflow/python/eager:pywrap_tfe_lib", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", - "@local_tsl//tsl/python/lib/core:numpy", + "@local_xla//xla/tsl/python/lib/core:numpy", ], alwayslink = 1, ) @@ -213,7 +213,7 @@ cc_library( "//tensorflow/c:headers", "//tensorflow/c:safe_ptr_hdr", "//tensorflow/c/eager:headers", - "@local_tsl//tsl/python/lib/core:numpy_hdr", + "@local_xla//xla/tsl/python/lib/core:numpy_hdr", ], features = [ "-parse_headers", @@ -228,7 +228,7 @@ cc_library( "//tensorflow/core/common_runtime:core_cpu_headers_lib", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", - "@local_tsl//tsl/python/lib/core:numpy", + "@local_xla//xla/tsl/python/lib/core:numpy", ], ) @@ -249,8 +249,8 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/python/lib/core:safe_pyobject_ptr", - "@local_tsl//tsl/python/lib/core:ml_dtypes_lib", - "@local_tsl//tsl/python/lib/core:numpy", + "@local_xla//xla/tsl/python/lib/core:ml_dtypes_lib", + "@local_xla//xla/tsl/python/lib/core:numpy", ], ) @@ -273,7 +273,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//third_party/python_runtime:headers", # build_cleaner: keep; DNR: b/35864863 - "@local_tsl//tsl/python/lib/core:numpy", + "@local_xla//xla/tsl/python/lib/core:numpy", ], ) diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc index ea38b32266ab63..5096026cbdb1db 100644 --- a/tensorflow/python/lib/core/ndarray_tensor.cc +++ b/tensorflow/python/lib/core/ndarray_tensor.cc @@ -15,7 +15,7 @@ limitations under the License. // Must be included first // clang-format off #include "tensorflow/c/tf_datatype.h" -#include "tsl/python/lib/core/numpy.h" //NOLINT +#include "xla/tsl/python/lib/core/numpy.h" //NOLINT // clang-format on #include "tensorflow/python/lib/core/ndarray_tensor.h" @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/c/eager/tfe_context_internal.h" #include "tensorflow/c/tf_tensor_internal.h" +#include "xla/tsl/python/lib/core/ml_dtypes.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/core/util/port.h" #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" -#include "tsl/python/lib/core/ml_dtypes.h" namespace tensorflow { namespace { diff --git a/tensorflow/python/lib/core/ndarray_tensor_bridge.cc b/tensorflow/python/lib/core/ndarray_tensor_bridge.cc index d45ca6ee0c67d6..c7fa135c82ad7c 100644 --- a/tensorflow/python/lib/core/ndarray_tensor_bridge.cc +++ b/tensorflow/python/lib/core/ndarray_tensor_bridge.cc @@ -16,7 +16,7 @@ limitations under the License. // clang-format off // Must be included first. #include "tensorflow/c/tf_datatype.h" -#include "tsl/python/lib/core/numpy.h" +#include "xla/tsl/python/lib/core/numpy.h" // clang-format on #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" @@ -24,10 +24,10 @@ limitations under the License. #include #include "tensorflow/c/c_api.h" +#include "xla/tsl/python/lib/core/ml_dtypes.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/python/lib/core/py_util.h" -#include "tsl/python/lib/core/ml_dtypes.h" namespace tensorflow { diff --git a/tensorflow/python/lib/core/ndarray_tensor_bridge.h b/tensorflow/python/lib/core/ndarray_tensor_bridge.h index ed2da4afc0c230..fe98f8818d46bb 100644 --- a/tensorflow/python/lib/core/ndarray_tensor_bridge.h +++ b/tensorflow/python/lib/core/ndarray_tensor_bridge.h @@ -17,7 +17,7 @@ limitations under the License. // Must be included first // clang-format off -#include "tsl/python/lib/core/numpy.h" //NOLINT +#include "xla/tsl/python/lib/core/numpy.h" //NOLINT // clang-format on #include diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 9eac4c2c207e97..d1b7986c0a998e 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -17,7 +17,7 @@ limitations under the License. // Must be included first. #include "tensorflow/python/lib/core/py_func.h" -#include "tsl/python/lib/core/numpy.h" +#include "xla/tsl/python/lib/core/numpy.h" // clang-format: on #include diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index b572244c74b40f..aeac93b3711984 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // Must be included first // clang-format off -#include "tsl/python/lib/core/numpy.h" //NOLINT +#include "xla/tsl/python/lib/core/numpy.h" //NOLINT // clang-format on #include "tensorflow/python/lib/core/py_seq_tensor.h" diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index a308a621912917..846b8693c227b9 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -20,7 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "pybind11/attr.h" // from @pybind11 -#include "tsl/python/lib/core/numpy.h" //NOLINT +#include "xla/tsl/python/lib/core/numpy.h" //NOLINT // clang-format on #include "Python.h" diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 1a158e696f8590..78c42c6f454c5c 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -59,7 +59,7 @@ tensorflow::tfprof::SerializeToString [//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool] # graph_analyze tensorflow::grappler::graph_analyzer::GraphAnalyzerTool -[//external/local_tsl/tsl/python/lib/core:ml_dtypes_lib] # bfloat16, float8 +[//external/local_xla/xla/tsl/python/lib/core:ml_dtypes_lib] # bfloat16, float8 tsl::ml_dtypes::RegisterTypes tsl::ml_dtypes::GetBfloat16Dtype tsl::ml_dtypes::GetFloat8E4m3b11fnuzDtype @@ -307,7 +307,7 @@ tensorflow::AddWhileInputHack tensorflow::RecordMutation tensorflow::Graph::IsControlEdge -[//external/local_tsl/tsl/python/lib/core:numpy] # tf_session +[//external/local_xla/xla/tsl/python/lib/core:numpy] # tf_session tsl::ImportNumpy _tsl_numpy_api diff --git a/tensorflow/tools/pip_package/build_pip_package.py b/tensorflow/tools/pip_package/build_pip_package.py index 88f27b25cc5329..9ac5ad2d7b57c8 100644 --- a/tensorflow/tools/pip_package/build_pip_package.py +++ b/tensorflow/tools/pip_package/build_pip_package.py @@ -219,14 +219,21 @@ def patch_so(srcs_dir: str) -> None: srcs_dir: target directory with .so files to patch. """ to_patch = { - "tensorflow/python/_pywrap_tensorflow_internal.so": - "$ORIGIN/../../tensorflow/tsl/python/lib/core", - ("tensorflow/compiler/mlir/quantization/tensorflow/python/" - "pywrap_function_lib.so"): "$ORIGIN/../../../../../python", - ("tensorflow/compiler/mlir/quantization/tensorflow/python/" - "pywrap_quantize_model.so"): "$ORIGIN/../../../../../python", - ("tensorflow/compiler/mlir/quantization/tensorflow/calibrator/" - "pywrap_calibration.so"): "$ORIGIN/../../../../../python", + "tensorflow/python/_pywrap_tensorflow_internal.so": ( + "$ORIGIN/../../tensorflow/compiler/xla/tsl/python/lib/core" + ), + ( + "tensorflow/compiler/mlir/quantization/tensorflow/python/" + "pywrap_function_lib.so" + ): "$ORIGIN/../../../../../python", + ( + "tensorflow/compiler/mlir/quantization/tensorflow/python/" + "pywrap_quantize_model.so" + ): "$ORIGIN/../../../../../python", + ( + "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/" + "pywrap_calibration.so" + ): "$ORIGIN/../../../../../python", } for file, path in to_patch.items(): rpath = subprocess.check_output( diff --git a/third_party/xla/third_party/tsl/tools/def_file_filter/symbols_pybind.txt b/third_party/xla/third_party/tsl/tools/def_file_filter/symbols_pybind.txt index 1a158e696f8590..78c42c6f454c5c 100644 --- a/third_party/xla/third_party/tsl/tools/def_file_filter/symbols_pybind.txt +++ b/third_party/xla/third_party/tsl/tools/def_file_filter/symbols_pybind.txt @@ -59,7 +59,7 @@ tensorflow::tfprof::SerializeToString [//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool] # graph_analyze tensorflow::grappler::graph_analyzer::GraphAnalyzerTool -[//external/local_tsl/tsl/python/lib/core:ml_dtypes_lib] # bfloat16, float8 +[//external/local_xla/xla/tsl/python/lib/core:ml_dtypes_lib] # bfloat16, float8 tsl::ml_dtypes::RegisterTypes tsl::ml_dtypes::GetBfloat16Dtype tsl::ml_dtypes::GetFloat8E4m3b11fnuzDtype @@ -307,7 +307,7 @@ tensorflow::AddWhileInputHack tensorflow::RecordMutation tensorflow::Graph::IsControlEdge -[//external/local_tsl/tsl/python/lib/core:numpy] # tf_session +[//external/local_xla/xla/tsl/python/lib/core:numpy] # tf_session tsl::ImportNumpy _tsl_numpy_api diff --git a/third_party/xla/third_party/tsl/tsl/platform/ml_dtypes.h b/third_party/xla/third_party/tsl/tsl/platform/ml_dtypes.h index c25efc2f865b70..504085af8518ee 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/ml_dtypes.h +++ b/third_party/xla/third_party/tsl/tsl/platform/ml_dtypes.h @@ -20,16 +20,16 @@ limitations under the License. #include "ml_dtypes/include/int4.h" // from @ml_dtypes namespace tsl { -using float8_e4m3fn = ml_dtypes::float8_e4m3fn; -using float8_e4m3fnuz = ml_dtypes::float8_e4m3fnuz; -using float8_e4m3b11fnuz = ml_dtypes::float8_e4m3b11fnuz; +using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn; +using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz; +using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz; using float8_e4m3b11 = float8_e4m3b11fnuz; // Deprecated: old name for // backward-compatibility only. -using float8_e5m2 = ml_dtypes::float8_e5m2; -using float8_e5m2fnuz = ml_dtypes::float8_e5m2fnuz; +using float8_e5m2 = ::ml_dtypes::float8_e5m2; +using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz; -using int4 = ml_dtypes::int4; -using uint4 = ml_dtypes::uint4; +using int4 = ::ml_dtypes::int4; +using uint4 = ::ml_dtypes::uint4; } // namespace tsl #endif // TENSORFLOW_TSL_PLATFORM_ML_DTYPES_H_ diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 267e7284f05cb0..4d0cda7ddd02fd 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -197,6 +197,7 @@ cc_library( "//xla/pjrt:exceptions", "//xla/python/ifrt", "//xla/python/pjrt_ifrt", + "//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status:statusor", @@ -205,7 +206,6 @@ cc_library( "@local_config_python//:python_headers", # buildcleaner: keep "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/python/lib/core:numpy", ], ) @@ -377,6 +377,7 @@ cc_library( "//xla/service:custom_call_status", "//xla/service:custom_call_target_registry", "//xla/service:platform_util", + "//xla/tsl/python/lib/core:numpy", "@local_tsl//tsl/concurrency:ref_count", "@local_tsl//tsl/framework:allocator", "@local_tsl//tsl/platform:casts", @@ -387,7 +388,6 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", - "@local_tsl//tsl/python/lib/core:numpy", "@com_google_protobuf//:protobuf", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -749,11 +749,11 @@ cc_library( "//xla/pjrt:pjrt_client", "//xla/pjrt:status_casters", "//xla/python/ifrt", + "//xla/tsl/python/lib/core:numpy", "@local_tsl//tsl/concurrency:ref_count", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", - "@local_tsl//tsl/python/lib/core:numpy", ], ) @@ -1246,12 +1246,12 @@ tsl_pybind_extension( "//xla/python/ifrt_proxy/client:py_module", "//xla/python/pjrt_ifrt", "//xla/service/cpu:collectives_interface", + "//xla/tsl/python/lib/core:numpy", "@local_tsl//tsl/distributed_runtime/preemption:preemption_sync_manager", "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform/cloud:gcs_file_system", - "@local_tsl//tsl/python/lib/core:numpy", ] + select({ # gloo transport only builds on linux "@local_tsl//tsl:macos": [], @@ -1356,9 +1356,9 @@ cc_library( features = ["-use_header_modules"], deps = [ "//third_party/nanobind", + "//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/types:span", "@local_config_python//:python_headers", - "@local_tsl//tsl/python/lib/core:numpy", ], ) diff --git a/third_party/xla/xla/python/nb_numpy.cc b/third_party/xla/xla/python/nb_numpy.cc index f6006cd94786e7..2210f67569a283 100644 --- a/third_party/xla/xla/python/nb_numpy.cc +++ b/third_party/xla/xla/python/nb_numpy.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/types/span.h" #include "third_party/nanobind/include/nanobind/nanobind.h" -#include "tsl/python/lib/core/numpy.h" +#include "xla/tsl/python/lib/core/numpy.h" namespace nb = nanobind; diff --git a/third_party/xla/xla/python/nb_numpy.h b/third_party/xla/xla/python/nb_numpy.h index 64c6c55cc93d2e..23dc85f7ce900c 100644 --- a/third_party/xla/xla/python/nb_numpy.h +++ b/third_party/xla/xla/python/nb_numpy.h @@ -30,7 +30,7 @@ limitations under the License. #include "absl/types/span.h" #include "third_party/nanobind/include/nanobind/nanobind.h" -#include "tsl/python/lib/core/numpy.h" +#include "xla/tsl/python/lib/core/numpy.h" #if NPY_ABI_VERSION < 0x02000000 #define PyDataType_ELSIZE(descr) ((descr)->elsize) diff --git a/third_party/xla/xla/python/pmap_lib.cc b/third_party/xla/xla/python/pmap_lib.cc index 61b16dedccbd1a..1dbeb30d01ef8b 100644 --- a/third_party/xla/xla/python/pmap_lib.cc +++ b/third_party/xla/xla/python/pmap_lib.cc @@ -68,13 +68,13 @@ limitations under the License. #include "xla/python/traceback.h" #include "xla/python/types.h" #include "xla/status_macros.h" +#include "xla/tsl/python/lib/core/numpy.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/concurrency/ref_count.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" -#include "tsl/python/lib/core/numpy.h" namespace jax { diff --git a/third_party/xla/xla/python/py_compile_only_client.cc b/third_party/xla/xla/python/py_compile_only_client.cc index a4297bce47bacc..a85758f921096e 100644 --- a/third_party/xla/xla/python/py_compile_only_client.cc +++ b/third_party/xla/xla/python/py_compile_only_client.cc @@ -60,11 +60,11 @@ limitations under the License. #include "xla/python/nb_class_ptr.h" #include "xla/python/py_client.h" #include "xla/service/computation_placer.h" +#include "xla/tsl/python/lib/core/numpy.h" #include "xla/util.h" #include "tsl/concurrency/ref_count.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" -#include "tsl/python/lib/core/numpy.h" namespace nb = nanobind; diff --git a/third_party/xla/xla/python/py_values.cc b/third_party/xla/xla/python/py_values.cc index 396dc542b11c1c..7454dd4e63bf93 100644 --- a/third_party/xla/xla/python/py_values.cc +++ b/third_party/xla/xla/python/py_values.cc @@ -50,6 +50,7 @@ limitations under the License. #include "xla/python/sharding.h" #include "xla/python/types.h" #include "xla/shape.h" +#include "xla/tsl/python/lib/core/numpy.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -57,7 +58,6 @@ limitations under the License. #include "tsl/platform/ml_dtypes.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" -#include "tsl/python/lib/core/numpy.h" namespace nb = nanobind; diff --git a/third_party/xla/xla/python/tools/BUILD b/third_party/xla/xla/python/tools/BUILD index 2338525fe744d3..2b6ce41ba112b0 100644 --- a/third_party/xla/xla/python/tools/BUILD +++ b/third_party/xla/xla/python/tools/BUILD @@ -65,9 +65,9 @@ tsl_pybind_extension( "//xla/python:logging", "//xla/python:nb_numpy", "//xla/python:types", + "//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/python/lib/core:numpy", "@pybind11", "@pybind11_abseil//pybind11_abseil:status_casters", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", diff --git a/third_party/xla/xla/python/tools/_types.cc b/third_party/xla/xla/python/tools/_types.cc index 320404637a0462..cd2a65382642c9 100644 --- a/third_party/xla/xla/python/tools/_types.cc +++ b/third_party/xla/xla/python/tools/_types.cc @@ -33,7 +33,7 @@ limitations under the License. // is fine); however, tsl-numpy does reexport NumPy's arrayobject.h header. // Since one of the TF headers above already includes tsl-numpy, therefore // we must include it down here rather than including actual NumPy directly. -#include "tsl/python/lib/core/numpy.h" +#include "xla/tsl/python/lib/core/numpy.h" namespace py = ::pybind11; namespace nb = ::nanobind; diff --git a/third_party/xla/xla/python/types.cc b/third_party/xla/xla/python/types.cc index 3ca11ef9d9e836..5ecb181fb2592e 100644 --- a/third_party/xla/xla/python/types.cc +++ b/third_party/xla/xla/python/types.cc @@ -45,11 +45,11 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/python/lib/core/numpy.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" -#include "tsl/python/lib/core/numpy.h" namespace xla { diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index 1c9cf5bbcdadbb..cde953367c498d 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -56,7 +56,7 @@ limitations under the License. #include "xla/python/ifrt_proxy/client/py_module.h" #include "xla/python/py_client.h" #include "xla/service/cpu/collectives_interface.h" -#include "tsl/python/lib/core/numpy.h" //NOLINT +#include "xla/tsl/python/lib/core/numpy.h" //NOLINT #ifdef XLA_PYTHON_ENABLE_GPU #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #endif // XLA_PYTHON_ENABLE_GPU diff --git a/third_party/xla/third_party/tsl/tsl/python/lib/core/BUILD b/third_party/xla/xla/tsl/python/lib/core/BUILD similarity index 100% rename from third_party/xla/third_party/tsl/tsl/python/lib/core/BUILD rename to third_party/xla/xla/tsl/python/lib/core/BUILD diff --git a/third_party/xla/third_party/tsl/tsl/python/lib/core/ml_dtypes.cc b/third_party/xla/xla/tsl/python/lib/core/ml_dtypes.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/python/lib/core/ml_dtypes.cc rename to third_party/xla/xla/tsl/python/lib/core/ml_dtypes.cc index 2815b25b24469f..d138c00bf9e6d5 100644 --- a/third_party/xla/third_party/tsl/tsl/python/lib/core/ml_dtypes.cc +++ b/third_party/xla/xla/tsl/python/lib/core/ml_dtypes.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/python/lib/core/ml_dtypes.h" +#include "xla/tsl/python/lib/core/ml_dtypes.h" #include #include // Must be included first to ensure `NPY_NO_DEPRECATED_API` is defined. // clang-format off -#include "tsl/python/lib/core/numpy.h" // IWYU pragma: keep +#include "xla/tsl/python/lib/core/numpy.h" // IWYU pragma: keep // clang-format on #include "numpy/ndarraytypes.h" #include "absl/base/attributes.h" diff --git a/third_party/xla/third_party/tsl/tsl/python/lib/core/ml_dtypes.h b/third_party/xla/xla/tsl/python/lib/core/ml_dtypes.h similarity index 90% rename from third_party/xla/third_party/tsl/tsl/python/lib/core/ml_dtypes.h rename to third_party/xla/xla/tsl/python/lib/core/ml_dtypes.h index f2b93ebee41a21..bf9eab2200a76b 100644 --- a/third_party/xla/third_party/tsl/tsl/python/lib/core/ml_dtypes.h +++ b/third_party/xla/xla/tsl/python/lib/core/ml_dtypes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PYTHON_LIB_CORE_ML_DTYPES_H_ -#define TENSORFLOW_TSL_PYTHON_LIB_CORE_ML_DTYPES_H_ +#ifndef XLA_TSL_PYTHON_LIB_CORE_ML_DTYPES_H_ +#define XLA_TSL_PYTHON_LIB_CORE_ML_DTYPES_H_ // Registers all custom types from the python ml_dtypes package. // https://github.com/jax-ml/ml_dtypes @@ -47,4 +47,4 @@ inline int GetBfloat16TypeNum() { return GetNumpyDtypes().bfloat16; } } // namespace ml_dtypes } // namespace tsl -#endif // TENSORFLOW_TSL_PYTHON_LIB_CORE_ML_DTYPES_H_ +#endif // XLA_TSL_PYTHON_LIB_CORE_ML_DTYPES_H_ diff --git a/third_party/xla/third_party/tsl/tsl/python/lib/core/numpy.cc b/third_party/xla/xla/tsl/python/lib/core/numpy.cc similarity index 95% rename from third_party/xla/third_party/tsl/tsl/python/lib/core/numpy.cc rename to third_party/xla/xla/tsl/python/lib/core/numpy.cc index 3013a1a7c68d46..3f54df1281c2d5 100644 --- a/third_party/xla/third_party/tsl/tsl/python/lib/core/numpy.cc +++ b/third_party/xla/xla/tsl/python/lib/core/numpy.cc @@ -17,7 +17,7 @@ limitations under the License. // ImportNumpy function to populate it. #define XLA_IMPORT_NUMPY -#include "tsl/python/lib/core/numpy.h" +#include "xla/tsl/python/lib/core/numpy.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/python/lib/core/numpy.h b/third_party/xla/xla/tsl/python/lib/core/numpy.h similarity index 91% rename from third_party/xla/third_party/tsl/tsl/python/lib/core/numpy.h rename to third_party/xla/xla/tsl/python/lib/core/numpy.h index ee4b920d0ebf5c..6a5a6a6486ccf7 100644 --- a/third_party/xla/third_party/tsl/tsl/python/lib/core/numpy.h +++ b/third_party/xla/xla/tsl/python/lib/core/numpy.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PYTHON_LIB_CORE_NUMPY_H_ -#define TENSORFLOW_TSL_PYTHON_LIB_CORE_NUMPY_H_ +#ifndef XLA_TSL_PYTHON_LIB_CORE_NUMPY_H_ +#define XLA_TSL_PYTHON_LIB_CORE_NUMPY_H_ #ifdef PyArray_Type #error "Numpy cannot be included before numpy.h." @@ -50,4 +50,4 @@ void ImportNumpy(); } // namespace tsl -#endif // TENSORFLOW_TSL_PYTHON_LIB_CORE_NUMPY_H_ +#endif // XLA_TSL_PYTHON_LIB_CORE_NUMPY_H_ From 5bbc00214f106863808539104be9fdc6e01af02d Mon Sep 17 00:00:00 2001 From: Marcello Maggioni Date: Fri, 29 Mar 2024 13:22:30 -0700 Subject: [PATCH 098/124] Reverts 3a2cd8887ed96de6abbd26e46844b959aa42e481 PiperOrigin-RevId: 620324878 --- .../xla/xla/service/collective_opt_utils.cc | 35 +------------------ .../xla/xla/service/collective_opt_utils.h | 2 +- .../xla/service/reduce_scatter_decomposer.cc | 4 --- .../xla/service/reduce_scatter_decomposer.h | 6 ++-- .../service/reduce_scatter_decomposer_test.cc | 32 ++--------------- 5 files changed, 7 insertions(+), 72 deletions(-) diff --git a/third_party/xla/xla/service/collective_opt_utils.cc b/third_party/xla/xla/service/collective_opt_utils.cc index cbc7a4c8867bd4..8e7a6d874cfa8a 100644 --- a/third_party/xla/xla/service/collective_opt_utils.cc +++ b/third_party/xla/xla/service/collective_opt_utils.cc @@ -267,46 +267,13 @@ bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size, return true; } -std::optional SpecFromReduceScatterInstr( - const HloInstruction* rs_instr, int64_t num_partitions, - int64_t num_replicas, int64_t min_rank, bool is_constrain_layout, - bool use_global_device_ids, bool is_cross_module) { - if (rs_instr->shape().rank() < min_rank) { - return std::nullopt; - } - CHECK(rs_instr->opcode() == HloOpcode::kReduceScatter); - ReduceScatterSpec spec; - spec.split_dim = rs_instr->dimensions(0); - if (!is_cross_module) { - spec.sharded_replicas = num_replicas; - spec.group_size = rs_instr->replica_groups().empty() - ? num_replicas - : rs_instr->replica_groups()[0].replica_ids_size(); - } else if (use_global_device_ids) { - spec.sharded_replicas = num_replicas; - spec.sharded_partitions = num_partitions; - spec.group_size = rs_instr->replica_groups()[0].replica_ids_size(); - } else { - spec.sharded_partitions = num_partitions; - spec.group_size = num_partitions; - } - spec.original_split_dims = {spec.split_dim}; - spec.dynamic_slice = nullptr; - return spec; -} - } // namespace std::optional MatchReduceScatter( - const HloAllReduceInstructionBase* ar, int64_t num_partitions, + const HloAllReduceInstruction* ar, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims, bool allow_intervening_reshape, int64_t min_rank, HloPredicate match_partition_id, HloPredicate match_replica_id) { - if (ar->opcode() == HloOpcode::kReduceScatter) { - return SpecFromReduceScatterInstr( - ar, num_partitions, num_replicas, min_rank, ar->constrain_layout(), - ar->use_global_device_ids(), ar->channel_id().has_value()); - } auto spec = MatchWithDynamicSlice( ar, num_partitions, num_replicas, allow_multiple_split_dims, allow_intervening_reshape, min_rank, match_partition_id, match_replica_id, diff --git a/third_party/xla/xla/service/collective_opt_utils.h b/third_party/xla/xla/service/collective_opt_utils.h index 7d044be3c34568..11b65c1acc4160 100644 --- a/third_party/xla/xla/service/collective_opt_utils.h +++ b/third_party/xla/xla/service/collective_opt_utils.h @@ -36,7 +36,7 @@ struct ReduceScatterSpec { // Matches the given all-reduce operation to a reduce-scatter pattern. std::optional MatchReduceScatter( - const HloAllReduceInstructionBase* ar, int64_t num_partitions, + const HloAllReduceInstruction* ar, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims = false, bool allow_intervening_reshape = false, int64_t min_rank = 1, HloPredicate match_partition_id = HloPredicateIsOp, diff --git a/third_party/xla/xla/service/reduce_scatter_decomposer.cc b/third_party/xla/xla/service/reduce_scatter_decomposer.cc index da2fed224a53f5..7210a2c12b4f30 100644 --- a/third_party/xla/xla/service/reduce_scatter_decomposer.cc +++ b/third_party/xla/xla/service/reduce_scatter_decomposer.cc @@ -53,11 +53,7 @@ absl::StatusOr ReduceScatterDecomposer::Run( if (rs->channel_id()) { channel_id = next_channel_id++; } - if (should_decompose_ && !should_decompose_(rs)) { - continue; - } - VLOG(2) << "Decompose: " << rs->ToString(); // Create an all-reduce HloComputation *apply_clone = module->AddComputationAndUnifyNamesAndIds( rs->to_apply()->Clone(), /*is_entry=*/false); diff --git a/third_party/xla/xla/service/reduce_scatter_decomposer.h b/third_party/xla/xla/service/reduce_scatter_decomposer.h index 1ee1f603c09f28..324d97d0e915e9 100644 --- a/third_party/xla/xla/service/reduce_scatter_decomposer.h +++ b/third_party/xla/xla/service/reduce_scatter_decomposer.h @@ -29,9 +29,8 @@ namespace xla { class ReduceScatterDecomposer : public HloModulePass { public: explicit ReduceScatterDecomposer( - std::function update_layout = nullptr, - std::function should_decompose = nullptr) - : update_layout_(update_layout), should_decompose_(should_decompose) {} + std::function update_layout = nullptr) + : update_layout_(update_layout) {} absl::string_view name() const override { return "reduce-scatter-decomposer"; } @@ -41,7 +40,6 @@ class ReduceScatterDecomposer : public HloModulePass { HloModule* module, const absl::flat_hash_set& execution_threads) override; std::function update_layout_; - std::function should_decompose_; }; } // namespace xla diff --git a/third_party/xla/xla/service/reduce_scatter_decomposer_test.cc b/third_party/xla/xla/service/reduce_scatter_decomposer_test.cc index d7f8360fbdc910..bfaa918930befb 100644 --- a/third_party/xla/xla/service/reduce_scatter_decomposer_test.cc +++ b/third_party/xla/xla/service/reduce_scatter_decomposer_test.cc @@ -41,18 +41,13 @@ class ReduceScatterDecomposerTest : public HloTestBase { absl::string_view hlo_module, PassAction action, CollectiveOpGroupMode mode = CollectiveOpGroupMode::kCrossReplica, int64_t shard_size = 0, int64_t shard_dimension = 0, - int64_t replica_count = 2, - std::function should_decompose = - [](const HloInstruction *) { return true; }) { + int64_t replica_count = 2) { const int64_t partition_count = 2; TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndReturnVerifiedModule(hlo_module, replica_count, partition_count)); - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - ReduceScatterDecomposer(/*update_layout=*/nullptr, - /*should_decompose=*/should_decompose) - .Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ReduceScatterDecomposer().Run(module.get())); if (action == PassAction::kNoChange) { ASSERT_FALSE(changed); return; @@ -227,26 +222,5 @@ ENTRY main { RunPass(hlo_string, PassAction::kNoChange); } -TEST_F(ReduceScatterDecomposerTest, NoChangeWithShouldDecompose) { - absl::string_view hlo_string = R"( -HloModule m - -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} - -ENTRY main { - p0 = f32[4, 8] parameter(0) - ROOT rs = f32[4, 4] reduce-scatter(p0), replica_groups={{0,1}, {2,3}}, channel_id=1, dimensions={1}, to_apply=sum, use_global_device_ids=true -} -)"; - RunPass(hlo_string, PassAction::kNoChange, - CollectiveOpGroupMode::kCrossReplica, - /*shard_size=*/0, /*shard_dimension=*/0, - /*replica_count=*/2, [](const HloInstruction *) { return false; }); -} - } // namespace } // namespace xla From b25ab7111ed33efb300e60760dd6a03268826814 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 29 Mar 2024 13:38:12 -0700 Subject: [PATCH 099/124] Implementation of get / retrieve buffer attributes methods in GPU async kernel. PiperOrigin-RevId: 620328996 --- tensorflow/lite/delegates/gpu/delegate.cc | 138 +++++++++++++++++----- 1 file changed, 110 insertions(+), 28 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index 3389cf7948e0d6..9fac6e598f1b1a 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -117,7 +117,7 @@ using tflite::delegates::utils::WriteSyncAttrs; namespace tflite { namespace gpu { namespace { - +// TODO(b/328628170): Add productive coverage to GPU delegate. using delegates::Serialization; using delegates::SerializationParams; using tflite::TFLITE_LOG_WARNING; @@ -880,6 +880,37 @@ class DelegateAsyncKernel : public BackendAsyncKernelInterface { return desc_ahwb; } + // Validate the attributes passed in, return kTfLiteOk if the attributes + // meet the requirements. Return the registered buffer attributes in + // `buffer_attrs`. + static TfLiteStatus CheckAttributes(const TfLiteAttributeMap* attrs, + BufferAttributes& buffer_attrs) { + // Validate buffer attributes. + TFLITE_RET_CHECK_STATUS( + TfLiteAttributeMapIsBufferAttributeMap(attrs), + "calling RegisterBuffer with invalid attribute map type"); + buffer_attrs = ReadBufferAttrs(attrs); + TFLITE_RET_CHECK_STATUS( + buffer_attrs.buffer_type.has_value(), + "calling RegisterBuffer with buffer resource type name unspecified"); + TFLITE_RET_CHECK_STATUS( + buffer_attrs.buffer_type.value() != BufferType::kUnknown, + "calling RegisterBuffer with unknown buffer resource type"); + size_t alignment = buffer_attrs.alignment.value_or(kRequiredByteAlignment); + TFLITE_RET_CHECK_STATUS( + alignment % kRequiredByteAlignment == 0, + "calling RegisterBuffer with non-zero buffer alignment"); + size_t padding = buffer_attrs.padding.value_or(kRequiredBytePadding); + TFLITE_RET_CHECK_STATUS( + padding % kRequiredBytePadding == 0, + "calling RegisterBuffer with non-zero buffer padding"); + size_t offset = buffer_attrs.offset.value_or(0); + TFLITE_RET_CHECK_STATUS(offset == 0, + "calling RegisterBuffer with non-zero offset"); + + return kTfLiteOk; + } + // For SupportedBufferTypes and SupportedSynchronizations const std::vector supported_buffer_types_ = { ::tflite::delegates::utils::kBufferTypeAHardwareBufferBlob}; @@ -901,6 +932,9 @@ class DelegateAsyncKernel : public BackendAsyncKernelInterface { absl::flat_hash_map buffer_by_handle_ ABSL_GUARDED_BY(eval_mutex_); + + absl::flat_hash_map attributes_by_buffer_ + ABSL_GUARDED_BY(eval_mutex_); std::vector output_sync_types_ ABSL_GUARDED_BY(eval_mutex_); }; @@ -1074,14 +1108,78 @@ TfLiteStatus DelegateAsyncKernel::SetAttributesImpl( TfLiteStatus DelegateAsyncKernel::SetBufferAttributes( const TfLiteBackendBuffer* buffer, const TfLiteAttributeMap* attrs) { - // TODO(b/325338475): Implement the details for set attributes to buffer. - return kTfLiteDelegateError; + TFLITE_ABORT_CHECK(buffer != nullptr, "Buffer is null"); + TFLITE_ABORT_CHECK(attrs != nullptr, "Attribute is null"); + + // We depend on the availability of AHardwareBuffer. + TFLITE_RET_CHECK_STATUS( + TFLITE_AHWB_AVAILABLE(), + "calling tflite::gpu::DelegateAsyncKernel::SetBufferAttributes on device " + "without AHardwareBuffer support"); + BufferAttributes buffer_attrs; + TFLITE_RET_CHECK_STATUS(CheckAttributes(attrs, buffer_attrs) == kTfLiteOk, + "SetBufferAttributes(): Failed to check attributes"); + + // Validate ahardwarebuffer. + auto* ahwb = + reinterpret_cast(TfLiteBackendBufferGetPtr(buffer)); + TFLITE_RET_CHECK_STATUS(ahwb != nullptr, + "calling SetBufferAttributes with nullptr buffer"); + UniquePtrAHardwareBuffer uptr_ahwb = Acquire(ahwb); + const AHardwareBuffer_Desc desc_ahwb = Describe(uptr_ahwb); + TFLITE_RET_CHECK_STATUS(desc_ahwb.format == AHARDWAREBUFFER_FORMAT_BLOB, + "calling SetBufferAttributes with an AHardwareBuffer " + "of format other than BLOB is not supported"); + size_t size = buffer_attrs.size.value_or(desc_ahwb.width); + TFLITE_RET_CHECK_STATUS( + size <= desc_ahwb.width, + "calling SetBufferAttributes with buffer size larger than the actual " + "AHardwareBuffer size"); + + absl::MutexLock eval_lock(&eval_mutex_); + if (attributes_by_buffer_.find(uptr_ahwb.get()) != + attributes_by_buffer_.end()) { + attributes_by_buffer_[uptr_ahwb.get()] = buffer_attrs; + } else { + TFLITE_LOG_PROD( + TFLITE_LOG_ERROR, + "SetBufferAttributes(): Unable to find the buffer in the map."); + } + return kTfLiteOk; } TfLiteStatus DelegateAsyncKernel::GetBufferAttributes( const TfLiteBackendBuffer* buffer, TfLiteAttributeMap* attrs) { - // TODO(b/325338475): Implement the details for get attributes from buffer. - return kTfLiteDelegateError; + TFLITE_ABORT_CHECK(buffer != nullptr, "Buffer is null"); + TFLITE_ABORT_CHECK(attrs != nullptr, "Attribute map is null"); + + // We depend on the availability of AHardwareBuffer. + TFLITE_RET_CHECK_STATUS( + TFLITE_AHWB_AVAILABLE(), + "calling tflite::gpu::DelegateAsyncKernel::GetBufferAttributes on device " + "without AHardwareBuffer support"); + TFLITE_RET_CHECK_STATUS( + TfLiteAttributeMapIsBufferAttributeMap(attrs), + "calling GetBufferAttributes with an invalid attribute map type"); + + // Validate ahardwarebuffer. + auto* ahwb = + reinterpret_cast(TfLiteBackendBufferGetPtr(buffer)); + TFLITE_RET_CHECK_STATUS(ahwb != nullptr, + "calling GetBufferAttributes with nullptr buffer"); + UniquePtrAHardwareBuffer uptr_ahwb = Acquire(ahwb); + const AHardwareBuffer_Desc desc_ahwb = Describe(uptr_ahwb); + TFLITE_RET_CHECK_STATUS(desc_ahwb.format == AHARDWAREBUFFER_FORMAT_BLOB, + "calling GetBufferAttributes with an AHardwareBuffer " + "of format other than " + "BLOB is not supported"); + + absl::MutexLock eval_lock(&eval_mutex_); + auto it = attributes_by_buffer_.find(uptr_ahwb.get()); + TFLITE_RET_CHECK_STATUS(it != attributes_by_buffer_.end(), + "Unable to find the buffer."); + WriteBufferAttrs(it->second, attrs); + return kTfLiteOk; } TfLiteStatus DelegateAsyncKernel::Prepare(TfLiteOpaqueContext* opaque_context, @@ -1141,34 +1239,14 @@ TfLiteStatus DelegateAsyncKernel::RegisterBufferImpl( TFLITE_ABORT_CHECK(buffer != nullptr, ""); // Crash OK TFLITE_ABORT_CHECK(attrs != nullptr, ""); // Crash OK TFLITE_ABORT_CHECK(handle != kTfLiteNullBufferHandle, ""); // Crash OK - - // Validate buffer attributes. - TFLITE_RET_CHECK_STATUS( - TfLiteAttributeMapIsBufferAttributeMap(attrs), - "calling RegisterBuffer with invalid attribute map type"); - auto buffer_attrs = ReadBufferAttrs(attrs); - TFLITE_RET_CHECK_STATUS( - buffer_attrs.buffer_type.has_value(), - "calling RegisterBuffer with buffer resource type name unspecified"); - TFLITE_RET_CHECK_STATUS( - buffer_attrs.buffer_type.value() != BufferType::kUnknown, - "calling RegisterBuffer with unknown buffer resource type"); - size_t alignment = buffer_attrs.alignment.value_or(kRequiredByteAlignment); - TFLITE_RET_CHECK_STATUS( - alignment % kRequiredByteAlignment == 0, - "calling RegisterBuffer with invalid buffer alignment"); - size_t padding = buffer_attrs.padding.value_or(kRequiredBytePadding); - TFLITE_RET_CHECK_STATUS(padding % kRequiredBytePadding == 0, - "calling RegisterBuffer with invalid buffer padding"); - size_t offset = buffer_attrs.offset.value_or(0); - TFLITE_RET_CHECK_STATUS(offset == 0, - "calling RegisterBuffer with non-zero offset"); - // We depend on the availability of AHardwareBuffer. TFLITE_RET_CHECK_STATUS( TFLITE_AHWB_AVAILABLE(), "calling tflite::gpu::DelegateAsyncKernel::RegisterBuffer on device " "without AHardwareBuffer support"); + BufferAttributes buffer_attrs; + TFLITE_RET_CHECK_STATUS(CheckAttributes(attrs, buffer_attrs) == kTfLiteOk, + "RegisterBufferImpl(): Failed to check attributes"); // Retrieve and validate the buffer. auto* ahwb = @@ -1193,6 +1271,10 @@ TfLiteStatus DelegateAsyncKernel::RegisterBufferImpl( buffer_by_handle_.try_emplace(handle, std::move(uptr_ahwb)); TFLITE_RET_CHECK_STATUS(did_something, "RegisterBuffer called with duplicate handle"); + + auto [iterator, check] = + attributes_by_buffer_.try_emplace(it->second.get(), buffer_attrs); + TFLITE_RET_CHECK_STATUS(check, "RegisterBuffer called with same buffer"); return kTfLiteOk; } From 7810e7330e76c452e56bd5c42fe23d536d9cef14 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Fri, 29 Mar 2024 14:03:53 -0700 Subject: [PATCH 100/124] [xla:gpu] Add a version of `HloPredicateIsOp` for `HloInstructionAdaptor` PiperOrigin-RevId: 620335202 --- third_party/xla/xla/service/gpu/hlo_traversal.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/third_party/xla/xla/service/gpu/hlo_traversal.h b/third_party/xla/xla/service/gpu/hlo_traversal.h index fa5bc0f81817fb..d77e669b48f162 100644 --- a/third_party/xla/xla/service/gpu/hlo_traversal.h +++ b/third_party/xla/xla/service/gpu/hlo_traversal.h @@ -68,6 +68,11 @@ H AbslHashValue(H h, const HloInstructionAdaptor& m) { m.instruction_->unique_id()); } +template +bool IsOpcodeAnyOf(const HloInstructionAdaptor& adaptor) { + return (adaptor.opcode() == op) || ((adaptor.opcode() == rest) || ...); +} + class HloFusionAdaptor { public: virtual ~HloFusionAdaptor() = default; From d6b38af08cbb1cec8d158529a4df9fb6d8fa16a1 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Fri, 29 Mar 2024 15:37:59 -0700 Subject: [PATCH 101/124] [PJRT C API] Add a PJRT extension to register custom partitioner. - This extension has one C API which registers a custom partitioner with callbacks from the input. - Update xla_client.register_custom_call_partitioner to take an optional PJRT_Api* input. - Add xla_bridge.register_plugin_initialization_callbacks to register callbacks to be called with PJRT_Api* after plugins are discovered. PiperOrigin-RevId: 620357554 --- third_party/xla/xla/pjrt/c/BUILD | 11 ++ third_party/xla/xla/pjrt/c/CHANGELOG.md | 15 +- third_party/xla/xla/pjrt/c/pjrt_c_api.h | 3 +- .../pjrt_c_api_custom_partitioner_extension.h | 134 ++++++++++++++++++ .../xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 22 ++- third_party/xla/xla/python/BUILD | 6 +- .../xla/xla/python/custom_call_sharding.cc | 47 +++++- .../xla/python/custom_partition_callback.cc | 1 + .../xla/python/custom_partition_callback.h | 84 +---------- third_party/xla/xla/python/xla_client.py | 2 +- .../xla/xla/python/xla_extension/__init__.pyi | 1 + 11 files changed, 229 insertions(+), 97 deletions(-) create mode 100644 third_party/xla/xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h diff --git a/third_party/xla/xla/pjrt/c/BUILD b/third_party/xla/xla/pjrt/c/BUILD index e01564fb6a0ca9..bf9a3a85cb1b3c 100644 --- a/third_party/xla/xla/pjrt/c/BUILD +++ b/third_party/xla/xla/pjrt/c/BUILD @@ -53,6 +53,15 @@ cc_library( ], ) +cc_library( + name = "pjrt_c_api_custom_partitioner_extension_hdrs", + hdrs = ["pjrt_c_api_custom_partitioner_extension.h"], + visibility = ["//visibility:public"], + deps = [ + ":pjrt_c_api_hdrs", + ], +) + cc_library( name = "pjrt_c_api_wrapper_impl", srcs = ["pjrt_c_api_wrapper_impl.cc"], @@ -166,6 +175,7 @@ cc_library( hdrs = ["pjrt_c_api_gpu_internal.h"], visibility = ["//visibility:public"], deps = [ + ":pjrt_c_api_custom_partitioner_extension_hdrs", ":pjrt_c_api_gpu_extension_hdrs", ":pjrt_c_api_hdrs", ":pjrt_c_api_helpers", @@ -186,6 +196,7 @@ cc_library( "//xla/pjrt/gpu:gpu_helpers", "//xla/pjrt/gpu:se_gpu_pjrt_client", "//xla/pjrt/gpu:se_gpu_pjrt_compiler", # To register GPU AOT compiler + "//xla/python:custom_partition_callback", "//xla/python:inspect_sharding", # To register "InspectSharding" custom partitioning handler. "//xla/service:compiler", "//xla/service:custom_call_target_registry", diff --git a/third_party/xla/xla/pjrt/c/CHANGELOG.md b/third_party/xla/xla/pjrt/c/CHANGELOG.md index ca223ce6007b1e..cb9cb750d81940 100644 --- a/third_party/xla/xla/pjrt/c/CHANGELOG.md +++ b/third_party/xla/xla/pjrt/c/CHANGELOG.md @@ -1,25 +1,28 @@ # PJRT C API changelog -## 0.46 +## 0.47 +* Added ``PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner``. + +## 0.46 (Feb 29, 2024) * Update outdated struct sizes from previous changes to ``PJRT_Device_AddressableMemories_Args`` and ``PJRT_ExecuteOptions``. -## 0.45 +## 0.45 (Feb 27, 2024) * Breaking changes * Added struct_size field to beginning of PJRT_Extension_Base. This is so forwards and backwards compatibility logic can be implemented with extension structs. -## 0.44 +## 0.44 (Feb 26, 2024) * Changed all ``void*`` extension fields to have type ``PJRT_Extension_Base*`` -## 0.43 +## 0.43 (Feb 24, 2024) * Added some new fields to PJRT_Executable_GetCompiledMemoryStats -## 0.42 +## 0.42 (Feb 13, 2024) * Renamed all ``priv`` fields to ``extension_start`` -## 0.41 +## 0.41 (Feb 13, 2024) * Renamed PJRT_Structure_Base to PJRT_Extension_Base * Renamed PJRT_Structure_Type to PJRT_Extension_Type (and similarly for enum fields) diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api.h b/third_party/xla/xla/pjrt/c/pjrt_c_api.h index c39b7636983e39..da1934e64f2e64 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api.h @@ -41,6 +41,7 @@ extern "C" { typedef enum { PJRT_Extension_Type_Gpu_Custom_Call = 0, PJRT_Extension_Type_Profiler, + PJRT_Extension_Type_Custom_Partitioner, } PJRT_Extension_Type; // PJRT_Extension_Base contains a type and a pointer to next @@ -75,7 +76,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 46 +#define PJRT_API_MINOR 47 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h new file mode 100644 index 00000000000000..825734610b863c --- /dev/null +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h @@ -0,0 +1,134 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_C_PJRT_C_API_CUSTOM_PARTITIONER_EXTENSION_H_ +#define XLA_PJRT_C_PJRT_C_API_CUSTOM_PARTITIONER_EXTENSION_H_ + +#include +#include + +#include "xla/pjrt/c/pjrt_c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define PJRT_API_CUSTOM_PARTITIONER_EXTENSION_VERSION 0 + +struct JAX_CustomCallPartitioner_string { + const char* data; + size_t size; +}; + +struct JAX_CustomCallPartitioner_aval { + JAX_CustomCallPartitioner_string shape; + bool has_sharding; + JAX_CustomCallPartitioner_string sharding; +}; + +// General callback information containing api versions, the result error +// message and the cleanup function to free any temporary memory that is backing +// the results. Arguments are always owned by the caller, and results are owned +// by the cleanup_fn. These should never be used directly. Args and results +// should be serialized via the PopulateArgs, ReadArgs, PopulateResults, +// ConsumeResults functions defined below. +struct JAX_CustomCallPartitioner_version_and_error { + int64_t api_version; + void* data; // out + // cleanup_fn cleans up any returned results. The caller must finish with all + // uses by the point the cleanup is called. + void (*cleanup_fn)(void* data); // out + bool has_error; + PJRT_Error_Code code; // out + JAX_CustomCallPartitioner_string error_msg; // out +}; + +struct JAX_CustomCallPartitioner_Partition_Args { + JAX_CustomCallPartitioner_version_and_error header; + + size_t num_args; + JAX_CustomCallPartitioner_aval* op_args; + JAX_CustomCallPartitioner_aval op_result; + JAX_CustomCallPartitioner_string backend_config; + + // out + JAX_CustomCallPartitioner_string mlir_module; + JAX_CustomCallPartitioner_string* args_sharding; + JAX_CustomCallPartitioner_string result_sharding; +}; + +struct JAX_CustomCallPartitioner_InferShardingFromOperands_Args { + JAX_CustomCallPartitioner_version_and_error header; + + size_t num_args; + JAX_CustomCallPartitioner_aval* op_args; + JAX_CustomCallPartitioner_string result_shape; + JAX_CustomCallPartitioner_string backend_config; + + bool has_result_sharding; + JAX_CustomCallPartitioner_string result_sharding; +}; + +struct JAX_CustomCallPartitioner_PropagateUserSharding_Args { + JAX_CustomCallPartitioner_version_and_error header; + + JAX_CustomCallPartitioner_string backend_config; + + JAX_CustomCallPartitioner_string result_shape; + + JAX_CustomCallPartitioner_string result_sharding; // inout +}; + +struct JAX_CustomCallPartitioner_Callbacks { + int64_t version; + void* private_data; + void (*dtor)(JAX_CustomCallPartitioner_Callbacks* data); + void (*partition)(JAX_CustomCallPartitioner_Callbacks* data, + JAX_CustomCallPartitioner_Partition_Args* args); + void (*infer_sharding)( + JAX_CustomCallPartitioner_Callbacks* data, + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args); + void (*propagate_user_sharding)( + JAX_CustomCallPartitioner_Callbacks* data, + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args); + bool can_side_effecting_have_replicated_sharding; +}; + +struct PJRT_Register_Custom_Partitioner_Args { + size_t struct_size; + const char* name; // lifetime of the call. + size_t name_size; + JAX_CustomCallPartitioner_Callbacks* callbacks; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Register_Custom_Partitioner_Args, callbacks); + +// Registers a custom partitioner. +typedef PJRT_Error* PJRT_Register_Custom_Partitioner( + PJRT_Register_Custom_Partitioner_Args* args); + +typedef struct PJRT_Custom_Partitioner_Extension { + size_t struct_size; + PJRT_Extension_Type type; + PJRT_Extension_Base* next; + PJRT_Register_Custom_Partitioner* register_custom_partitioner; +} PJRT_Custom_Partitioner_Extension; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Custom_Partitioner_Extension, + register_custom_partitioner); + +#ifdef __cplusplus +} +#endif + +#endif // XLA_PJRT_C_PJRT_C_API_CUSTOM_PARTITIONER_EXTENSION_H_ diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index c687c7760c8752..211fcc5d538c03 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -32,6 +32,7 @@ limitations under the License. #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" #include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h" #include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/c/pjrt_c_api_profiler_extension.h" @@ -42,6 +43,7 @@ limitations under the License. #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_device_description.h" +#include "xla/python/custom_partition_callback.h" #include "xla/service/compiler.h" #include "xla/service/custom_call_target_registry.h" #include "xla/stream_executor/device_description.h" @@ -198,6 +200,24 @@ PJRT_Profiler_Extension profiler_extension{ /*profiler_api=*/&profiler_api, }; +PJRT_Error* PJRT_Register_Custom_Partitioner( + PJRT_Register_Custom_Partitioner_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_Register_Custom_Partitioner_Args", + PJRT_Register_Custom_Partitioner_Args_STRUCT_SIZE, args->struct_size)); + std::string name(args->name, args->name_size); + RegisterCustomCallPartitioner( + name, jax::CreateCApiCustomCallPartitioner(args->callbacks)); + return nullptr; +} + +PJRT_Custom_Partitioner_Extension custom_partitioner{ + /*struct_size=*/PJRT_Gpu_Custom_Call_STRUCT_SIZE, + /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner, + /*next=*/reinterpret_cast(&profiler_extension), + /*register_custom_partitioner=*/PJRT_Register_Custom_Partitioner, +}; + PJRT_Error* PJRT_Gpu_Register_Custom_Call( PJRT_Gpu_Register_Custom_Call_Args* args) { PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( @@ -228,7 +248,7 @@ const PJRT_Api* GetGpuPjrtApi() { static PJRT_Gpu_Custom_Call custom_call{ /*struct_size=*/PJRT_Gpu_Custom_Call_STRUCT_SIZE, /*type=*/PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call, - /*next=*/reinterpret_cast(&profiler_extension), + /*next=*/reinterpret_cast(&custom_partitioner), /*custom_call=*/PJRT_Gpu_Register_Custom_Call, }; static const PJRT_Api pjrt_api = diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 4d0cda7ddd02fd..660ef5ecb87adb 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -566,6 +566,7 @@ cc_library( "//xla/client:xla_computation", "//xla/hlo/ir:hlo", "//xla/pjrt:mlir_to_hlo", + "//xla/pjrt/c:pjrt_c_api_custom_partitioner_extension_hdrs", "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/pjrt/c:pjrt_c_api_helpers", "//xla/service:call_inliner", @@ -599,11 +600,14 @@ cc_library( "@com_google_absl//absl/status:statusor", "//third_party/nanobind", "//xla:shape_util", + "//xla:status", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_sharding_util", "//xla/pjrt:status_casters", - "@local_tsl//tsl/platform:errors", + "//xla/pjrt/c:pjrt_c_api_custom_partitioner_extension_hdrs", + "//xla/pjrt/c:pjrt_c_api_hdrs", + "//xla/pjrt/c:pjrt_c_api_helpers", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/python/custom_call_sharding.cc b/third_party/xla/xla/python/custom_call_sharding.cc index c9056f22465f05..599cb160a9d94d 100644 --- a/third_party/xla/xla/python/custom_call_sharding.cc +++ b/third_party/xla/xla/python/custom_call_sharding.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -33,9 +34,14 @@ limitations under the License. #include "third_party/nanobind/include/nanobind/stl/vector.h" // IWYU pragma: keep #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/utils/hlo_sharding_util.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h" +#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/custom_partition_callback.h" #include "xla/python/inspect_sharding.h" #include "xla/shape.h" +#include "xla/status.h" #include "xla/util.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" @@ -210,15 +216,45 @@ void BuildCustomCallShardingPybindAPI(nb::module_& m) { "register_custom_call_partitioner", [](std::string name, nb::object prop_user_sharding, nb::object partition, nb::object infer_sharding_from_operands, - bool can_side_effecting_have_replicated_sharding) { + bool can_side_effecting_have_replicated_sharding, + std::optional c_api) { auto* c_fns = (new PyCustomCallPartitionerCallbacks(prop_user_sharding, partition, infer_sharding_from_operands)) ->callbacks(); c_fns->can_side_effecting_have_replicated_sharding = can_side_effecting_have_replicated_sharding; - RegisterCustomCallPartitioner( - name, jax::CreateCApiCustomCallPartitioner(c_fns)); + if (!c_api.has_value()) { + RegisterCustomCallPartitioner( + name, jax::CreateCApiCustomCallPartitioner(c_fns)); + return; + } + + if (std::string_view(c_api->name()) != "pjrt_c_api") { + throw absl::InvalidArgumentError( + "Argument to register_custom_call_partitioner was not a " + "pjrt_c_api capsule."); + } + auto* c_api_value = static_cast(c_api->data()); + PJRT_Custom_Partitioner_Extension* extension = + pjrt::FindExtension( + c_api_value, + PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner); + if (extension == nullptr) { + return; + } + PJRT_Register_Custom_Partitioner_Args args; + args.struct_size = PJRT_Register_Custom_Partitioner_Args_STRUCT_SIZE; + args.name = name.c_str(); + args.name_size = name.size(); + args.callbacks = c_fns; + PJRT_Error* error = + reinterpret_cast( + extension) + ->register_custom_partitioner(&args); + std::unique_ptr error_ptr( + error, pjrt::MakeErrorDeleter(c_api_value)); + ThrowIfError(pjrt::PjrtErrorToStatus(error_ptr.get(), c_api_value)); }, R"(Registers a partitioner for a custom-call operation. @@ -233,10 +269,13 @@ void BuildCustomCallShardingPybindAPI(nb::module_& m) { Takes operand sharding and returns the instruction sharding. can_side_effecting_have_replicated_sharding: Side effecting ops are not allowed to have replicated sharding. Pass true to disable this check. + c_api: Optional `PJRT_Api*` if it is called with a plugin. This is safe to + call on plugins that do not implement the custom partitioner extension )", nb::arg("name"), nb::arg("prop_user_sharding"), nb::arg("partition"), nb::arg("infer_sharding_from_operands"), - nb::arg("can_side_effecting_have_replicated_sharding") = false); + nb::arg("can_side_effecting_have_replicated_sharding") = false, + nb::arg("c_api").none() = std::nullopt); m.def("encode_inspect_sharding_callback", [](nb::object handler) -> nb::bytes { JAX_InspectSharding_Callback cb; diff --git a/third_party/xla/xla/python/custom_partition_callback.cc b/third_party/xla/xla/python/custom_partition_callback.cc index b37c6d03b4c729..d9bcb596bf5999 100644 --- a/third_party/xla/xla/python/custom_partition_callback.cc +++ b/third_party/xla/xla/python/custom_partition_callback.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/service/call_inliner.h" diff --git a/third_party/xla/xla/python/custom_partition_callback.h b/third_party/xla/xla/python/custom_partition_callback.h index d026f1f5cbb59b..33cc31e75fc9bf 100644 --- a/third_party/xla/xla/python/custom_partition_callback.h +++ b/third_party/xla/xla/python/custom_partition_callback.h @@ -24,91 +24,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h" #include "xla/service/custom_call_sharding_helper.h" -extern "C" { - -struct JAX_CustomCallPartitioner_string { - const char* data; - size_t size; -}; - -struct JAX_CustomCallPartitioner_aval { - JAX_CustomCallPartitioner_string shape; - bool has_sharding; - JAX_CustomCallPartitioner_string sharding; -}; - -// General callback information containing api versions, the result error -// message and the cleanup function to free any temporary memory that is backing -// the results. Arguments are always owned by the caller, and results are owned -// by the cleanup_fn. These should never be used directly. Args and results -// should be serialized via the PopulateArgs, ReadArgs, PopulateResults, -// ConsumeResults functions defined below. -struct JAX_CustomCallPartitioner_version_and_error { - int64_t api_version; - void* data; // out - // cleanup_fn cleans up any returned results. The caller must finish with all - // uses by the point the cleanup is called. - void (*cleanup_fn)(void* data); // out - bool has_error; - PJRT_Error_Code code; // out - JAX_CustomCallPartitioner_string error_msg; // out -}; - -struct JAX_CustomCallPartitioner_Partition_Args { - JAX_CustomCallPartitioner_version_and_error header; - - size_t num_args; - JAX_CustomCallPartitioner_aval* op_args; - JAX_CustomCallPartitioner_aval op_result; - JAX_CustomCallPartitioner_string backend_config; - - // out - JAX_CustomCallPartitioner_string mlir_module; - JAX_CustomCallPartitioner_string* args_sharding; - JAX_CustomCallPartitioner_string result_sharding; -}; - -struct JAX_CustomCallPartitioner_InferShardingFromOperands_Args { - JAX_CustomCallPartitioner_version_and_error header; - - size_t num_args; - JAX_CustomCallPartitioner_aval* op_args; - JAX_CustomCallPartitioner_string result_shape; - JAX_CustomCallPartitioner_string backend_config; - - bool has_result_sharding; - JAX_CustomCallPartitioner_string result_sharding; -}; - -struct JAX_CustomCallPartitioner_PropagateUserSharding_Args { - JAX_CustomCallPartitioner_version_and_error header; - - JAX_CustomCallPartitioner_string backend_config; - - JAX_CustomCallPartitioner_string result_shape; - - JAX_CustomCallPartitioner_string result_sharding; // inout -}; - -struct JAX_CustomCallPartitioner_Callbacks { - int64_t version; - void* private_data; - void (*dtor)(JAX_CustomCallPartitioner_Callbacks* data); - void (*partition)(JAX_CustomCallPartitioner_Callbacks* data, - JAX_CustomCallPartitioner_Partition_Args* args); - void (*infer_sharding)( - JAX_CustomCallPartitioner_Callbacks* data, - JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args); - void (*propagate_user_sharding)( - JAX_CustomCallPartitioner_Callbacks* data, - JAX_CustomCallPartitioner_PropagateUserSharding_Args* args); - bool can_side_effecting_have_replicated_sharding; -}; - -} // extern "C" - namespace jax { struct PartitionScratch { diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index 64eb6cd7d4e1dd..a6b8b6e8dc4e58 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -48,7 +48,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 251 +_version = 252 # Version number for MLIR:Python components. mlir_api_version = 55 diff --git a/third_party/xla/xla/python/xla_extension/__init__.pyi b/third_party/xla/xla/python/xla_extension/__init__.pyi index 8fe1300bd94c73..b8a02ae8f1f41e 100644 --- a/third_party/xla/xla/python/xla_extension/__init__.pyi +++ b/third_party/xla/xla/python/xla_extension/__init__.pyi @@ -267,6 +267,7 @@ def register_custom_call_partitioner( partition: Callable, infer_sharding_from_operands: Callable, can_side_effecting_have_replicated_sharding: bool, + c_api: Optional[Any], ) -> None: ... def encode_inspect_sharding_callback(handler: Any) -> bytes: ... From 30464e6cef4184f40333c60c906f574f72fb85f2 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Fri, 29 Mar 2024 16:22:52 -0700 Subject: [PATCH 102/124] In general, avoid the suffix on StatusOr. PiperOrigin-RevId: 620366682 --- third_party/xla/xla/service/BUILD | 3 +- .../xla/xla/service/shape_inference_test.cc | 1177 ++++++++--------- 2 files changed, 586 insertions(+), 594 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 1db01dd5c90dea..dc4a8ff4263651 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -596,7 +596,6 @@ xla_cc_test( ":hlo_parser", ":shape_inference", "//xla:shape_util", - "//xla:status", "//xla:statusor", "//xla:test", "//xla:test_helpers", @@ -609,7 +608,7 @@ xla_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/shape_inference_test.cc b/third_party/xla/xla/service/shape_inference_test.cc index 174e10f007037d..781852cd41d9dd 100644 --- a/third_party/xla/xla/service/shape_inference_test.cc +++ b/third_party/xla/xla/service/shape_inference_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/test_helpers.h" #include "xla/types.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla { @@ -82,11 +83,11 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest { const Shape& expected_inferred_shape, const Shape& arg, absl::Span dimensions_to_reduce) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape({&arg, &f32_}, dimensions_to_reduce, to_apply); - EXPECT_IS_OK(inferred_status.status()); - EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape, *inferred_status)); + EXPECT_IS_OK(inferred_shape.status()); + EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape, *inferred_shape)); } }; @@ -164,138 +165,138 @@ class UnboundedSelectOpShapeInferenceTest TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, *inferred_status)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, *inferred_shape)); } TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) { const Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, pred_, tuple, tuple); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("Expected array argument for select")); } TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_status)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_shape)); } TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) { const Shape predarray = ShapeUtil::MakeShape(PRED, {64, 48}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_status)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_shape)); } TEST_F(ShapeInferenceTest, SelectBadShapes) { - const absl::StatusOr inferred_status_error1 = + const absl::StatusOr inferred_shape_error1 = ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_); - ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_THAT(inferred_status_error1.status().message(), + ASSERT_FALSE(inferred_shape_error1.ok()); + ASSERT_THAT(inferred_shape_error1.status().message(), HasSubstr("Operands to select must be the same shape")); - const absl::StatusOr inferred_status_error2 = + const absl::StatusOr inferred_shape_error2 = ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_); - ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_THAT(inferred_status_error2.status().message(), + ASSERT_FALSE(inferred_shape_error2.ok()); + ASSERT_THAT(inferred_shape_error2.status().message(), HasSubstr("pred operand must have PRED")); - const absl::StatusOr inferred_status_error3 = + const absl::StatusOr inferred_shape_error3 = ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_, matrix_64_48_); - ASSERT_FALSE(inferred_status_error3.ok()); + ASSERT_FALSE(inferred_shape_error3.ok()); ASSERT_THAT( - inferred_status_error3.status().message(), + inferred_shape_error3.status().message(), HasSubstr("Operands to select and predicate must be the same shape")); // Tuples have a TUPLE element type and cannot be the pred of a select. - const absl::StatusOr inferred_status_error4 = + const absl::StatusOr inferred_shape_error4 = ShapeInference::InferTernaryOpShape( HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}), ShapeUtil::MakeTupleShape({f32_, f32_}), ShapeUtil::MakeTupleShape({f32_, f32_})); - ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_THAT(inferred_status_error4.status().message(), + ASSERT_FALSE(inferred_shape_error4.ok()); + ASSERT_THAT(inferred_shape_error4.status().message(), HasSubstr("Expected array argument for select pred")); } TEST_F(ShapeInferenceTest, ClampAllMatrix) { - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_status)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_shape)); } TEST_F(ShapeInferenceTest, ClampAllScalar) { - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(f32_, *inferred_status)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(f32_, *inferred_shape)); } TEST_F(ShapeInferenceTest, ClampMinScalar) { - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_status)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_shape)); } TEST_F(ShapeInferenceTest, ClampMaxScalar) { - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_status)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_shape)); } TEST_F(ShapeInferenceTest, ClampOperandScalar) { - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("Clamp with incompatible shapes")); } TEST_F(ShapeInferenceTest, ClampMinMatrix) { - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, matrix_64_48_, f32_, f32_); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("Clamp with incompatible shapes")); } TEST_F(ShapeInferenceTest, ClampMaxMatrix) { - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, matrix_64_48_); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("Clamp with incompatible shapes")); } TEST_F(ShapeInferenceTest, ClampOperandMatrix) { - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, matrix_64_48_, f32_); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_status)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, *inferred_shape)); } TEST_F(ShapeInferenceTest, ClampBadShapes) { @@ -397,80 +398,80 @@ TEST_F(ShapeInferenceTest, ReduceWindowInHalf) { const Shape float_scalar = ShapeUtil::MakeShape(F32, {}); ProgramShape to_apply = ShapeUtil::MakeProgramShape( {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferReduceWindowShape(matrix_shape, init_value_shape, window, to_apply); - ASSERT_IS_OK(inferred_status.status()); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), *inferred_status)); + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), *inferred_shape)); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) { - const absl::StatusOr inferred_status_ok = + const absl::StatusOr inferred_shape_ok = ShapeInference::InferSelectAndScatterShape( operand_shape_, select_program_shape_, window_, source_shape_, init_value_shape_, scatter_program_shape_); - ASSERT_IS_OK(inferred_status_ok.status()); - ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, *inferred_status_ok)); + ASSERT_IS_OK(inferred_shape_ok.status()); + ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, *inferred_shape_ok)); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) { const Shape source_shape_fail = ShapeUtil::MakeShape(F32, {4, 6}); - const absl::StatusOr inferred_status_fail = + const absl::StatusOr inferred_shape_fail = ShapeInference::InferSelectAndScatterShape( operand_shape_, select_program_shape_, window_, source_shape_fail, init_value_shape_, scatter_program_shape_); - ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_THAT(inferred_status_fail.status().message(), + ASSERT_FALSE(inferred_shape_fail.ok()); + ASSERT_THAT(inferred_shape_fail.status().message(), HasSubstr("Source shape does not match")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape({ShapeUtil::MakeShape(F32, {})}, pred_); - const absl::StatusOr inferred_status_fail = + const absl::StatusOr inferred_shape_fail = ShapeInference::InferSelectAndScatterShape( operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); - ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_THAT(inferred_status_fail.status().message(), + ASSERT_FALSE(inferred_shape_fail.ok()); + ASSERT_THAT(inferred_shape_fail.status().message(), HasSubstr("Select function must take 2 parameters")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape( {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_); - const absl::StatusOr inferred_status_fail = + const absl::StatusOr inferred_shape_fail = ShapeInference::InferSelectAndScatterShape( operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); - ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_THAT(inferred_status_fail.status().message(), + ASSERT_FALSE(inferred_shape_fail.ok()); + ASSERT_THAT(inferred_shape_fail.status().message(), HasSubstr("Select function must have rank-0 PRED")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})}, pred_); - const absl::StatusOr inferred_status_fail = + const absl::StatusOr inferred_shape_fail = ShapeInference::InferSelectAndScatterShape( operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); - ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_THAT(inferred_status_fail.status().message(), + ASSERT_FALSE(inferred_shape_fail.ok()); + ASSERT_THAT(inferred_shape_fail.status().message(), HasSubstr("Select function's first parameter")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape( {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(U32, {})}, pred_); - const absl::StatusOr inferred_status_fail = + const absl::StatusOr inferred_shape_fail = ShapeInference::InferSelectAndScatterShape( operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); - ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_THAT(inferred_status_fail.status().message(), + ASSERT_FALSE(inferred_shape_fail.ok()); + ASSERT_THAT(inferred_shape_fail.status().message(), HasSubstr("Select function's second parameter")); } @@ -575,14 +576,14 @@ TEST_F(ShapeInferenceTest, Convolve) { dim1->set_padding_high(0); dim1->set_window_dilation(1); dim1->set_base_dilation(1); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferConvolveShape( lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status.status()); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}), - *inferred_status)); + *inferred_shape)); } TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { @@ -622,14 +623,14 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { dim1->set_padding_high(1); dim1->set_window_dilation(2); dim1->set_base_dilation(1); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferConvolveShape( lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status.status()); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}), - *inferred_status)); + *inferred_shape)); } TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { @@ -669,14 +670,14 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { dim1->set_padding_high(1); dim1->set_window_dilation(1); dim1->set_base_dilation(2); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferConvolveShape( lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status.status()); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}), - *inferred_status)); + *inferred_shape)); } TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { @@ -709,13 +710,13 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { dim1->set_stride(2); dim1->set_padding_low(1); dim1->set_padding_high(1); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferConvolveShape( lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("each dimension exactly once")); } @@ -748,13 +749,13 @@ TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) { dim1->set_stride(1); dim0->set_window_dilation(3); dim1->set_window_dilation(2); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferConvolveShape( lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6, window, dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("to be a multiple of batch group count")); } @@ -953,18 +954,18 @@ static const char* innermost_dimension_matches = static void Pass(const Shape& shape, FftType type, absl::Span length, const Shape& expected_shape) { - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferFftShape(shape, type, length); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(expected_shape, *inferred_status)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(expected_shape, *inferred_shape)); } static void Fail(const Shape& shape, FftType type, absl::Span length, absl::string_view message) { - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferFftShape(shape, type, length); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr(std::string(message))); } @@ -1087,35 +1088,35 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftTypes) { TEST_F(ShapeInferenceTest, MapThatChangesElementType) { const Shape arg = ShapeUtil::MakeShape(F32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferMapShape({&arg}, to_apply, {0}); - EXPECT_IS_OK(inferred_status.status()); + EXPECT_IS_OK(inferred_shape.status()); const Shape expected = ShapeUtil::MakeShape(S32, {20}); - EXPECT_TRUE(ShapeUtil::Equal(expected, *inferred_status)); + EXPECT_TRUE(ShapeUtil::Equal(expected, *inferred_shape)); } TEST_F(ShapeInferenceTest, Map) { - const absl::StatusOr inferred_status_r1f32 = + const absl::StatusOr inferred_shape_r1f32 = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); - EXPECT_IS_OK(inferred_status_r1f32.status()); - EXPECT_TRUE(ShapeUtil::Equal(vector_32_, *inferred_status_r1f32)); + EXPECT_IS_OK(inferred_shape_r1f32.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_32_, *inferred_shape_r1f32)); // It's OK to provide a single argument, as long as the applied arity matches // (this degenerates to a Map). - const absl::StatusOr inferred_status_r1f32_one = + const absl::StatusOr inferred_shape_r1f32_one = ShapeInference::InferMapShape( {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0}); - EXPECT_IS_OK(inferred_status_r1f32_one.status()); - EXPECT_TRUE(ShapeUtil::Equal(vector_32_, *inferred_status_r1f32_one)); + EXPECT_IS_OK(inferred_shape_r1f32_one.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_32_, *inferred_shape_r1f32_one)); - const absl::StatusOr inferred_status_r2s32 = + const absl::StatusOr inferred_shape_r2s32 = ShapeInference::InferMapShape( {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_}, ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1}); - EXPECT_IS_OK(inferred_status_r2s32.status()); - EXPECT_TRUE(ShapeUtil::Equal(s32matrix_64_64_, *inferred_status_r2s32)); + EXPECT_IS_OK(inferred_shape_r2s32.status()); + EXPECT_TRUE(ShapeUtil::Equal(s32matrix_64_64_, *inferred_shape_r2s32)); const auto no_args_error = ShapeInference::InferMapShape( {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {}); @@ -1160,37 +1161,37 @@ TEST_F(ShapeInferenceTest, Map) { const Shape arg = ShapeUtil::MakeShape(F32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferMapShape({&arg}, to_apply, {0}); - EXPECT_IS_OK(inferred_status.status()); - EXPECT_TRUE(ShapeUtil::Equal(arg, *inferred_status)); + EXPECT_IS_OK(inferred_shape.status()); + EXPECT_TRUE(ShapeUtil::Equal(arg, *inferred_shape)); - const absl::StatusOr inferred_status_error1 = + const absl::StatusOr inferred_shape_error1 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); - ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_THAT(inferred_status_error1.status().message(), + ASSERT_FALSE(inferred_shape_error1.ok()); + ASSERT_THAT(inferred_shape_error1.status().message(), HasSubstr("arity must match number of arguments")); - const absl::StatusOr inferred_status_error2 = + const absl::StatusOr inferred_shape_error2 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0}); - ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_THAT(inferred_status_error2.status().message(), + ASSERT_FALSE(inferred_shape_error2.ok()); + ASSERT_THAT(inferred_shape_error2.status().message(), HasSubstr("has to be a scalar")); - const absl::StatusOr inferred_status_error3 = + const absl::StatusOr inferred_shape_error3 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0}); - ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_THAT(inferred_status_error3.status().message(), + ASSERT_FALSE(inferred_shape_error3.ok()); + ASSERT_THAT(inferred_shape_error3.status().message(), HasSubstr("has to be a scalar")); - const absl::StatusOr inferred_status_error5 = + const absl::StatusOr inferred_shape_error5 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0}); - ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_THAT(inferred_status_error5.status().message(), + ASSERT_FALSE(inferred_shape_error5.ok()); + ASSERT_THAT(inferred_shape_error5.status().message(), HasSubstr("parameter type has to match argument")); } @@ -1198,11 +1199,11 @@ TEST_F(ShapeInferenceTest, MapWithDifferentInputTypes) { const Shape arg0 = ShapeUtil::MakeShape(F32, {20}); const Shape arg1 = ShapeUtil::MakeShape(S32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, s32_}, s32_); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferMapShape({&arg0, &arg1}, to_apply, {0}); - EXPECT_IS_OK(inferred_status.status()); + EXPECT_IS_OK(inferred_shape.status()); const Shape expected = ShapeUtil::MakeShape(S32, {20}); - EXPECT_TRUE(ShapeUtil::Equal(expected, *inferred_status)); + EXPECT_TRUE(ShapeUtil::Equal(expected, *inferred_shape)); } TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) { @@ -1255,12 +1256,11 @@ TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) { const Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); ProgramShape to_apply = ShapeUtil::MakeProgramShape( {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); - const absl::StatusOr inferred_status = - ShapeInference::InferReduceShape( - {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); - EXPECT_IS_OK(inferred_status.status()); + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_IS_OK(inferred_shape.status()); EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}), - *inferred_status)); + *inferred_shape)); } TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) { @@ -1279,15 +1279,15 @@ TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) { const Window window, ShapeInference::InferWindowFromDimensions( window_dimensions, window_strides, padding_values, {}, {})); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferReduceWindowShape( absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply); - VLOG(2) << inferred_status->ToString() << "\n"; - EXPECT_IS_OK(inferred_status.status()); + VLOG(2) << inferred_shape->ToString() << "\n"; + EXPECT_IS_OK(inferred_shape.status()); EXPECT_TRUE(ShapeUtil::Equal( ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {5, 2, 0}), ShapeUtil::MakeShape(S32, {5, 2, 0})}), - *inferred_status)); + *inferred_shape)); } TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) { @@ -1296,11 +1296,10 @@ TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); - const absl::StatusOr inferred_status = - ShapeInference::InferReduceShape( - {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("must take 4 parameters, but takes 6 parameter(s)")); } @@ -1309,12 +1308,11 @@ TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) { const Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); ProgramShape to_apply = ShapeUtil::MakeProgramShape( {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); - const absl::StatusOr inferred_status = - ShapeInference::InferReduceShape( - {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); - EXPECT_FALSE(inferred_status.ok()); + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_shape.ok()); EXPECT_THAT( - inferred_status.status().message(), + inferred_shape.status().message(), HasSubstr( "parameter shape differs from the result shape: s32[] vs f32[]")); } @@ -1322,10 +1320,10 @@ TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) { TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) { ProgramShape to_apply = ShapeUtil::MakeProgramShape( {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape({}, {0, 1}, to_apply); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("must have at least 2 arguments, has 0")); } @@ -1345,11 +1343,11 @@ TEST_F(ReduceShapeInferenceTest, ErrorBadReduceWindowInput) { const Window window, ShapeInference::InferWindowFromDimensions( window_dimensions, window_strides, padding_values, {}, {})); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferReduceWindowShape( absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply); - EXPECT_FALSE(inferred_status.status().ok()); - EXPECT_THAT(inferred_status.status().message(), HasSubstr("f32[] vs s32[]")); + EXPECT_FALSE(inferred_shape.status().ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("f32[] vs s32[]")); } TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) { @@ -1357,12 +1355,11 @@ TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) { const Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_); - const absl::StatusOr inferred_status = - ShapeInference::InferReduceShape( - {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); - EXPECT_FALSE(inferred_status.ok()); + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_shape.ok()); EXPECT_THAT( - inferred_status.status().message(), + inferred_shape.status().message(), HasSubstr("must produce a tuple with 2 elements, but produces a scalar")); } @@ -1371,12 +1368,11 @@ TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) { const Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); ProgramShape to_apply = ShapeUtil::MakeProgramShape( {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_})); - const absl::StatusOr inferred_status = - ShapeInference::InferReduceShape( - {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); - EXPECT_FALSE(inferred_status.ok()); + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_shape.ok()); EXPECT_THAT( - inferred_status.status().message(), + inferred_shape.status().message(), HasSubstr("must produce a tuple with 2 elements, but has 3 elements")); } @@ -1385,11 +1381,10 @@ TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) { const Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); ProgramShape to_apply = ShapeUtil::MakeProgramShape( {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_})); - const absl::StatusOr inferred_status = - ShapeInference::InferReduceShape( - {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("accumulator shape at index 0 differs from the " "init_value shape: s32[] vs f32[]")); } @@ -1397,108 +1392,106 @@ TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) { TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - const absl::StatusOr inferred_status = - ShapeInference::InferReduceShape({&arg_shape, &f32_}, - /*dimensions_to_reduce=*/{3, 4}, - to_apply); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( + {&arg_shape, &f32_}, + /*dimensions_to_reduce=*/{3, 4}, to_apply); + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("out-of-bounds dimension")); } TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_); const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("take 2 parameters")); } TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_); const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("0-th parameter shape differs")); } TEST_F(ReduceShapeInferenceTest, ReduceWithRepeatedReduceDimension) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); - const absl::StatusOr inferred_status = - ShapeInference::InferReduceShape({&arg_shape, &f32_}, - /*dimensions_to_reduce=*/{0, 0}, - to_apply); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( + {&arg_shape, &f32_}, + /*dimensions_to_reduce=*/{0, 0}, to_apply); + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("Duplicate reduction dimension: 0")); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2) { const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {1, 1}); - ASSERT_IS_OK(inferred_status.status()); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), *inferred_status)); + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), *inferred_shape)); } TEST_F(ShapeInferenceTest, InferSliceWithDynamicDimensions) { const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}, {true, true}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {33, 64}, {1, 1}); - ASSERT_IS_OK(inferred_status.status()); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE(ShapeUtil::Equal( - ShapeUtil::MakeShape(F32, {1, 64}, {false, true}), *inferred_status)); + ShapeUtil::MakeShape(F32, {1, 64}, {false, true}), *inferred_shape)); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) { const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {2, 4}); - ASSERT_IS_OK(inferred_status.status()); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), *inferred_status)); + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), *inferred_shape)); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) { const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferSliceShape(matrix_shape, {15, 0}, {20, 13}, {2, 4}); - ASSERT_IS_OK(inferred_status.status()); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}), *inferred_status)); + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}), *inferred_shape)); } TEST_F(ShapeInferenceTest, InferInvalidStride) { const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {0, 1}); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_EQ(tsl::error::INVALID_ARGUMENT, inferred_status.status().code()); + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_EQ(tsl::error::INVALID_ARGUMENT, inferred_shape.status().code()); } TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) { const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {1, 1}); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_EQ(tsl::error::INVALID_ARGUMENT, inferred_status.status().code()); + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_EQ(tsl::error::INVALID_ARGUMENT, inferred_shape.status().code()); } TEST_F(ShapeInferenceTest, InferSliceShapeRank1) { const Shape vector_shape = ShapeUtil::MakeShape(F32, {17}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferSliceShape(vector_shape, {2}, {4}, {1}); - ASSERT_TRUE(inferred_status.ok()); + ASSERT_TRUE(inferred_shape.ok()); ASSERT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2}), *inferred_status)); + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2}), *inferred_shape)); } TEST_F(ShapeInferenceTest, InferConstIndexShape) { @@ -1529,21 +1522,21 @@ TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) { TEST_F(ShapeInferenceTest, InferPowShape) { const Shape ten_floats = ShapeUtil::MakeShape(F32, {10}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBinaryOpShape(HloOpcode::kPower, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(ten_floats, *inferred_status)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(ten_floats, *inferred_shape)); } TEST_F(ShapeInferenceTest, InferCompareShape) { const Shape ten_floats = ShapeUtil::MakeShape(F32, {10}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, ten_floats, f32_, {}); - ASSERT_IS_OK(inferred_status.status()); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), *inferred_status)); + ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), *inferred_shape)); } TEST_F(ShapeInferenceTest, InferReshapeDegenerateCombine) { @@ -1600,11 +1593,11 @@ TEST_F(ShapeInferenceTest, InferDynamicBroadcast) { // %broadcast = s32[15,<=15]{1,0} broadcast(s32[<=15]{0}), dimensions={1} const Shape operand_shape = ShapeUtil::MakeShape(F32, {15}, {true}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBroadcastShape(operand_shape, {15}); - ASSERT_IS_OK(inferred_status.status()); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_EQ(ShapeUtil::MakeShape(F32, {15, 15}, {false, true}), - *inferred_status); + *inferred_shape); } TEST_F(ShapeInferenceTest, BroadcastScalar) { @@ -1645,10 +1638,10 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) { // scalar vector: ok TEST_F(ShapeInferenceTest, ScalarDotVector) { DotDimensionNumbers dot_dnums; - const absl::StatusOr inferred_status = ShapeInference::InferDotOpShape( + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape( f32_, vector_32_, dot_dnums, /*preferred_element_type=*/std::nullopt); - EXPECT_TRUE(inferred_status.ok()); - EXPECT_EQ(*inferred_status, vector_32_); + EXPECT_TRUE(inferred_shape.ok()); + EXPECT_EQ(*inferred_shape, vector_32_); } // 3D 2D: error @@ -1656,11 +1649,11 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - const absl::StatusOr inferred_status = ShapeInference::InferDotOpShape( + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); - EXPECT_TRUE(inferred_status.ok()); - EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, + EXPECT_TRUE(inferred_shape.ok()); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, ShapeUtil::MakeShape(F32, {32, 32, 64}))); } @@ -1669,15 +1662,15 @@ TEST_F(ShapeInferenceTest, VectorDotVector) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(f32_, *inferred_status)); - const absl::StatusOr inferred_status_mismatch = + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(f32_, *inferred_shape)); + const absl::StatusOr inferred_shape_mismatch = ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status_mismatch.ok()); + ASSERT_FALSE(inferred_shape_mismatch.ok()); } // matrix vector -> vector @@ -1685,15 +1678,15 @@ TEST_F(ShapeInferenceTest, MatrixDotVector) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(*inferred_status, vector_32_)); - const absl::StatusOr inferred_status_mismatch = + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape, vector_32_)); + const absl::StatusOr inferred_shape_mismatch = ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status_mismatch.ok()); + ASSERT_FALSE(inferred_shape_mismatch.ok()); } // vector matrix -> vector @@ -1701,15 +1694,15 @@ TEST_F(ShapeInferenceTest, VectorDotMatrix) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(*inferred_status, vector_64_)); - const absl::StatusOr inferred_status_mismatch = + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape, vector_64_)); + const absl::StatusOr inferred_shape_mismatch = ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status_mismatch.ok()); + ASSERT_FALSE(inferred_shape_mismatch.ok()); } // matrix matrix -> matrix @@ -1717,17 +1710,17 @@ TEST_F(ShapeInferenceTest, MatrixDotMatrix) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - const absl::StatusOr inferred_status_match = + const absl::StatusOr inferred_shape_match = ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status_match.status()); - ASSERT_TRUE(ShapeUtil::Equal(*inferred_status_match, matrix_32_48_)) - << "inferred: " << ShapeUtil::HumanString(*inferred_status_match) + ASSERT_IS_OK(inferred_shape_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape_match, matrix_32_48_)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape_match) << " expected: " << ShapeUtil::HumanString(matrix_64_48_); - const absl::StatusOr inferred_status_mismatch = + const absl::StatusOr inferred_shape_mismatch = ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status_mismatch.ok()); + ASSERT_FALSE(inferred_shape_mismatch.ok()); } // BatchMatMul with two batch dimensions and one contracting dimension. @@ -1745,12 +1738,12 @@ TEST_F(ShapeInferenceTest, DotGeneral) { dot_dnums.add_rhs_batch_dimensions(0); dot_dnums.add_rhs_batch_dimensions(1); - const absl::StatusOr inferred_status_match = + const absl::StatusOr inferred_shape_match = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_IS_OK(inferred_status_match.status()); - ASSERT_TRUE(ShapeUtil::Equal(*inferred_status_match, output_shape)) - << "inferred: " << ShapeUtil::HumanString(*inferred_status_match) + ASSERT_IS_OK(inferred_shape_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape_match, output_shape)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape_match) << " expected: " << ShapeUtil::HumanString(output_shape); } @@ -1767,11 +1760,11 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { dot_dnums.add_rhs_contracting_dimensions(1); dot_dnums.add_rhs_batch_dimensions(0); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("Must specify the same number of contracting " "dimensions for lhs and rhs.")); } @@ -1790,34 +1783,34 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) { dot_dnums.add_rhs_contracting_dimensions(2); dot_dnums.add_rhs_batch_dimensions(0); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); - EXPECT_TRUE(inferred_status.ok()); - EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, output_shape)); + EXPECT_TRUE(inferred_shape.ok()); + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, output_shape)); } TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) { const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); const Shape val_shape = ShapeUtil::MakeShape(S32, {1}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferSetDimensionSizeShape(arg_shape, val_shape, /*dimension=*/0); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("value has to be S32 scalar")); } TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) { const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); const Shape val_shape = ShapeUtil::MakeShape(U32, {}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferSetDimensionSizeShape(arg_shape, val_shape, /*dimension=*/0); - EXPECT_FALSE(inferred_status.ok()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_FALSE(inferred_shape.ok()); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("value has to be S32 scalar")); } @@ -1833,11 +1826,11 @@ TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) { dot_dnums.add_rhs_contracting_dimensions(1); dot_dnums.add_rhs_batch_dimensions(0); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("Batch dimension sizes are not compatible")); } @@ -1853,11 +1846,11 @@ TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersPasses) { dot_dnums.add_rhs_contracting_dimensions(0); dot_dnums.add_rhs_batch_dimensions(1); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_TRUE(inferred_status.ok()); - ASSERT_TRUE(ShapeUtil::Equal(*inferred_status, + ASSERT_TRUE(inferred_shape.ok()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape, ShapeUtil::MakeShape(F32, {2, 11, 14}))); } @@ -1873,11 +1866,11 @@ TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) { dot_dnums.add_rhs_contracting_dimensions(0); dot_dnums.add_rhs_batch_dimensions(1); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("A dimension number is out of range")); } @@ -1893,11 +1886,11 @@ TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) { dot_dnums.add_rhs_contracting_dimensions(0); dot_dnums.add_rhs_batch_dimensions(1); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().message(), + ASSERT_FALSE(inferred_shape.ok()); + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("A dimension number is not unique")); } @@ -2073,13 +2066,13 @@ TEST_F(ShapeInferenceTest, DotWithIncorrectSparseDimensionSizeRatio) { sparsity_descriptor.set_dimension(1); std::vector sparsity = {sparsity_descriptor}; - const absl::StatusOr inferred_status = ShapeInference::InferDotOpShape( + const absl::StatusOr inferred_shape = ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(F32, {10, 32}), ShapeUtil::MakeShape(F32, {32, 20}), dot_dnums, /*preferred_element_type=*/std::nullopt, absl::MakeSpan(sparsity)); - ASSERT_FALSE(inferred_status.ok()); + ASSERT_FALSE(inferred_shape.ok()); ASSERT_THAT( - inferred_status.status().message(), + inferred_shape.status().message(), HasSubstr("Sparse dimension size ratio doesn't match the descriptor")); } @@ -2109,23 +2102,23 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { const Shape vec8 = ShapeUtil::MakeShape(F32, {8}); const Shape vec16 = ShapeUtil::MakeShape(F32, {16}); - absl::StatusOr inferred_status_match = + absl::StatusOr inferred_shape_match = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1}); - ASSERT_IS_OK(inferred_status_match.status()); - ASSERT_TRUE(ShapeUtil::Equal(*inferred_status_match, mat)); + ASSERT_IS_OK(inferred_shape_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape_match, mat)); - absl::StatusOr inferred_status_mismatch = + absl::StatusOr inferred_shape_mismatch = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0}); - ASSERT_FALSE(inferred_status_mismatch.ok()); + ASSERT_FALSE(inferred_shape_mismatch.ok()); - inferred_status_match = + inferred_shape_match = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0}); - ASSERT_IS_OK(inferred_status_match.status()); - ASSERT_TRUE(ShapeUtil::Equal(*inferred_status_match, mat)); + ASSERT_IS_OK(inferred_shape_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape_match, mat)); - inferred_status_mismatch = + inferred_shape_mismatch = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1}); - ASSERT_FALSE(inferred_status_mismatch.ok()); + ASSERT_FALSE(inferred_shape_mismatch.ok()); } TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) { @@ -2135,21 +2128,21 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) { const Shape matrix16_4 = ShapeUtil::MakeShape(F32, {16, 4}); const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8}); - absl::StatusOr inferred_status_match = + absl::StatusOr inferred_shape_match = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube, matrix8_4, {1, 2}); - ASSERT_IS_OK(inferred_status_match.status()); - ASSERT_TRUE(ShapeUtil::Equal(*inferred_status_match, cube)); + ASSERT_IS_OK(inferred_shape_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape_match, cube)); - inferred_status_match = ShapeInference::InferBinaryOpShape( + inferred_shape_match = ShapeInference::InferBinaryOpShape( HloOpcode::kAdd, cube, matrix16_4, {0, 2}); - ASSERT_IS_OK(inferred_status_match.status()); - ASSERT_TRUE(ShapeUtil::Equal(*inferred_status_match, cube)); + ASSERT_IS_OK(inferred_shape_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape_match, cube)); - inferred_status_match = ShapeInference::InferBinaryOpShape( + inferred_shape_match = ShapeInference::InferBinaryOpShape( HloOpcode::kAdd, cube, matrix16_8, {0, 1}); - ASSERT_IS_OK(inferred_status_match.status()); - ASSERT_TRUE(ShapeUtil::Equal(*inferred_status_match, cube)); + ASSERT_IS_OK(inferred_shape_match.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape_match, cube)); } TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { @@ -2161,65 +2154,65 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8}); // "magical" broadcast rejected - const absl::StatusOr inferred_status_error1 = + const absl::StatusOr inferred_shape_error1 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {}); - ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_THAT(inferred_status_error1.status().message(), + ASSERT_FALSE(inferred_shape_error1.ok()); + ASSERT_THAT(inferred_shape_error1.status().message(), HasSubstr("Shapes must be equal rank")); // broadcast_dimension out of bounds for tensor's rank - const absl::StatusOr inferred_status_error2 = + const absl::StatusOr inferred_shape_error2 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3}); - ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_THAT(inferred_status_error2.status().message(), + ASSERT_FALSE(inferred_shape_error2.ok()); + ASSERT_THAT(inferred_shape_error2.status().message(), ContainsRegex("Broadcast dimension number .* too large")); // broadcast_dimension doesn't match corresponding dimension - const absl::StatusOr inferred_status_error3 = + const absl::StatusOr inferred_shape_error3 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0}); - ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_THAT(inferred_status_error3.status().message(), + ASSERT_FALSE(inferred_shape_error3.ok()); + ASSERT_THAT(inferred_shape_error3.status().message(), HasSubstr("Broadcast dimension 0 mismatch")); // broadcast_dimensions list too long - const absl::StatusOr inferred_status_error4 = + const absl::StatusOr inferred_shape_error4 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2}); - ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_THAT(inferred_status_error4.status().message(), + ASSERT_FALSE(inferred_shape_error4.ok()); + ASSERT_THAT(inferred_shape_error4.status().message(), HasSubstr("broadcast_dimensions has to match")); // there's a dimension above the rank of the tensor - const absl::StatusOr inferred_status_error5 = + const absl::StatusOr inferred_shape_error5 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, matrix8_4, {3, 0}); - ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_THAT(inferred_status_error5.status().message(), + ASSERT_FALSE(inferred_shape_error5.ok()); + ASSERT_THAT(inferred_shape_error5.status().message(), ContainsRegex("dimension number .* too large")); // broadcasting dimensions don't match in this order - const absl::StatusOr inferred_status_error6 = + const absl::StatusOr inferred_shape_error6 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, matrix8_4, {2, 1}); - ASSERT_FALSE(inferred_status_error6.ok()); - ASSERT_THAT(inferred_status_error6.status().message(), + ASSERT_FALSE(inferred_shape_error6.ok()); + ASSERT_THAT(inferred_shape_error6.status().message(), HasSubstr("dimension 0 mismatch")); // The following two tests make sure that broadcasting dimensions are listed // in a proper (strictly increasing) order, even if the lower-rank array // matches the higher-rank array in many different ways. - const absl::StatusOr inferred_status_error7 = + const absl::StatusOr inferred_shape_error7 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0}); - ASSERT_FALSE(inferred_status_error7.ok()); - ASSERT_THAT(inferred_status_error7.status().message(), + ASSERT_FALSE(inferred_shape_error7.ok()); + ASSERT_THAT(inferred_shape_error7.status().message(), HasSubstr("dimensions order is wrong")); - const absl::StatusOr inferred_status_error8 = + const absl::StatusOr inferred_shape_error8 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0}); - ASSERT_FALSE(inferred_status_error8.ok()); - ASSERT_THAT(inferred_status_error8.status().message(), + ASSERT_FALSE(inferred_shape_error8.ok()); + ASSERT_THAT(inferred_shape_error8.status().message(), HasSubstr("dimensions order is wrong")); } @@ -2228,47 +2221,48 @@ TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) { const Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_}); ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_); ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferWhileShape(cond, body, result_shape); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(result_shape, *inferred_status)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(result_shape, *inferred_shape)); } // Tests for the while instruction with wrong shapes. TEST_F(ShapeInferenceTest, WhileWithBadShapes) { - const Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_}); - ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_); - ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape); + const Shape inferred_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_}); + ProgramShape cond = ShapeUtil::MakeProgramShape({inferred_shape}, pred_); + ProgramShape body = + ShapeUtil::MakeProgramShape({inferred_shape}, inferred_shape); const auto bad_shape_1 = - ShapeUtil::MakeProgramShape({s32_, result_shape}, pred_); - const absl::StatusOr inferred_status_error1 = - ShapeInference::InferWhileShape(bad_shape_1, body, result_shape); - ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_THAT(inferred_status_error1.status().message(), + ShapeUtil::MakeProgramShape({s32_, inferred_shape}, pred_); + const absl::StatusOr inferred_shape_error1 = + ShapeInference::InferWhileShape(bad_shape_1, body, inferred_shape); + ASSERT_FALSE(inferred_shape_error1.ok()); + ASSERT_THAT(inferred_shape_error1.status().message(), HasSubstr("Condition must take 1 arguments")); const auto bad_shape_2 = - ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape); - const absl::StatusOr inferred_status_error2 = - ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape); - ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_THAT(inferred_status_error2.status().message(), + ShapeUtil::MakeProgramShape({s32_, inferred_shape}, inferred_shape); + const absl::StatusOr inferred_shape_error2 = + ShapeInference::InferWhileShape(cond, bad_shape_2, inferred_shape); + ASSERT_FALSE(inferred_shape_error2.ok()); + ASSERT_THAT(inferred_shape_error2.status().message(), HasSubstr("Body must take 1 arguments")); - const auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_); - const absl::StatusOr inferred_status_error3 = - ShapeInference::InferWhileShape(bad_shape_3, body, result_shape); - ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_THAT(inferred_status_error3.status().message(), + const auto bad_shape_3 = ShapeUtil::MakeProgramShape({inferred_shape}, s32_); + const absl::StatusOr inferred_shape_error3 = + ShapeInference::InferWhileShape(bad_shape_3, body, inferred_shape); + ASSERT_FALSE(inferred_shape_error3.ok()); + ASSERT_THAT(inferred_shape_error3.status().message(), HasSubstr("Condition must return a boolean")); const auto bad_shape_4 = - ShapeUtil::MakeProgramShape({result_shape}, vector_32_); - const absl::StatusOr inferred_status_error4 = - ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape); - ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_THAT(inferred_status_error4.status().message(), + ShapeUtil::MakeProgramShape({inferred_shape}, vector_32_); + const absl::StatusOr inferred_shape_error4 = + ShapeInference::InferWhileShape(cond, bad_shape_4, inferred_shape); + ASSERT_FALSE(inferred_shape_error4.ok()); + ASSERT_THAT(inferred_shape_error4.status().message(), HasSubstr("parameter of condition and body")); } @@ -2278,81 +2272,81 @@ TEST_F(ShapeInferenceTest, ConcatenateWithDynamicShapes) { ShapeUtil::MakeShape(F32, {32, 160, 10}, {true, false, false}); const auto dynamic_shape_2 = ShapeUtil::MakeShape(F32, {32, 160, 10}, {false, true, false}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferConcatOpShape({&dynamic_shape_1, &dynamic_shape_2}, /*dimension=*/0); - ASSERT_IS_OK(inferred_status.status()); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE(ShapeUtil::Equal( ShapeUtil::MakeShape(F32, {64, 160, 10}, {true, true, false}), - *inferred_status)); + *inferred_shape)); } // Tests for the concatenate instruction with proper shapes. TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) { - const absl::StatusOr inferred_status_1 = + const absl::StatusOr inferred_shape_1 = ShapeInference::InferConcatOpShape({&vector_32_, &vector_64_}, /*dimension=*/0); - ASSERT_IS_OK(inferred_status_1.status()); + ASSERT_IS_OK(inferred_shape_1.status()); ASSERT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), *inferred_status_1)); + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), *inferred_shape_1)); - const absl::StatusOr inferred_status_2 = + const absl::StatusOr inferred_shape_2 = ShapeInference::InferConcatOpShape( {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0); - ASSERT_IS_OK(inferred_status_2.status()); + ASSERT_IS_OK(inferred_shape_2.status()); ASSERT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), *inferred_status_2)); + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), *inferred_shape_2)); - const absl::StatusOr inferred_status_3 = + const absl::StatusOr inferred_shape_3 = ShapeInference::InferConcatOpShape( {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1); - ASSERT_IS_OK(inferred_status_3.status()); + ASSERT_IS_OK(inferred_shape_3.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}), - *inferred_status_3)); + *inferred_shape_3)); } // Tests for the concatenate instruction with wrong shapes. TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { - const absl::StatusOr inferred_status_error1 = + const absl::StatusOr inferred_shape_error1 = ShapeInference::InferConcatOpShape({}, /*dimension=*/0); - ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_THAT(inferred_status_error1.status().message(), + ASSERT_FALSE(inferred_shape_error1.ok()); + ASSERT_THAT(inferred_shape_error1.status().message(), HasSubstr("Concatenate expects at least one argument")); - const absl::StatusOr inferred_status_error2 = + const absl::StatusOr inferred_shape_error2 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1); - ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_THAT(inferred_status_error2.status().message(), + ASSERT_FALSE(inferred_shape_error2.ok()); + ASSERT_THAT(inferred_shape_error2.status().message(), HasSubstr("dimension out of bounds: -1")); - const absl::StatusOr inferred_status_error3 = + const absl::StatusOr inferred_shape_error3 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1); - ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_THAT(inferred_status_error3.status().message(), + ASSERT_FALSE(inferred_shape_error3.ok()); + ASSERT_THAT(inferred_shape_error3.status().message(), HasSubstr("dimension out of bounds: 1")); const Shape tuple = ShapeUtil::MakeTupleShape({vector_32_}); - const absl::StatusOr inferred_status_error4 = + const absl::StatusOr inferred_shape_error4 = ShapeInference::InferConcatOpShape({&vector_32_, &tuple}, /*dimension=*/0); - ASSERT_FALSE(inferred_status_error4.ok()); + ASSERT_FALSE(inferred_shape_error4.ok()); ASSERT_THAT( - inferred_status_error4.status().message(), + inferred_shape_error4.status().message(), HasSubstr("Expected array argument for operand of concatenation")); const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); - const absl::StatusOr inferred_status_error5 = + const absl::StatusOr inferred_shape_error5 = ShapeInference::InferConcatOpShape({&vector_32_, &vector_s32}, /*dimension=*/0); - ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_THAT(inferred_status_error5.status().message(), + ASSERT_FALSE(inferred_shape_error5.ok()); + ASSERT_THAT(inferred_shape_error5.status().message(), HasSubstr("concatenate arrays with different element types")); - const absl::StatusOr inferred_status_error6 = + const absl::StatusOr inferred_shape_error6 = ShapeInference::InferConcatOpShape({&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0); - ASSERT_FALSE(inferred_status_error6.ok()); - ASSERT_THAT(inferred_status_error6.status().message(), + ASSERT_FALSE(inferred_shape_error6.ok()); + ASSERT_THAT(inferred_shape_error6.status().message(), HasSubstr("concatenate arrays that differ in " "dimensions other than the one being " "concatenated")); @@ -2373,11 +2367,11 @@ TEST_F(ShapeInferenceTest, Pad) { dimension1->set_edge_padding_high(5); dimension1->set_interior_padding(0); - const absl::StatusOr inferred_status = ShapeInference::InferPadShape( + const absl::StatusOr inferred_shape = ShapeInference::InferPadShape( input_shape, padding_value_shape, padding_config); - ASSERT_IS_OK(inferred_status.status()); + ASSERT_IS_OK(inferred_shape.status()); ASSERT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), *inferred_status)); + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), *inferred_shape)); dimension1->set_edge_padding_low(-20); dimension1->set_edge_padding_high(-10); @@ -2391,74 +2385,74 @@ TEST_F(ShapeInferenceTest, Pad) { TEST_F(ShapeInferenceTest, Reverse) { const Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferReverseShape(input_shape, {0, 1}); - ASSERT_IS_OK(inferred_status.status()); - ASSERT_TRUE(ShapeUtil::Equal(input_shape, *inferred_status)); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(input_shape, *inferred_shape)); } TEST_F(ShapeInferenceTest, ReverseInvalidDimension) { const Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); - const absl::StatusOr inferred_status_error0 = + const absl::StatusOr inferred_shape_error0 = ShapeInference::InferReverseShape(input_shape, {0, 2}); - ASSERT_FALSE(inferred_status_error0.ok()); - ASSERT_THAT(inferred_status_error0.status().message(), + ASSERT_FALSE(inferred_shape_error0.ok()); + ASSERT_THAT(inferred_shape_error0.status().message(), HasSubstr("out-of-bounds")); - const absl::StatusOr inferred_status_error1 = + const absl::StatusOr inferred_shape_error1 = ShapeInference::InferReverseShape(input_shape, {0, -1}); - ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_THAT(inferred_status_error1.status().message(), + ASSERT_FALSE(inferred_shape_error1.ok()); + ASSERT_THAT(inferred_shape_error1.status().message(), HasSubstr("out-of-bounds")); - const absl::StatusOr inferred_status_error2 = + const absl::StatusOr inferred_shape_error2 = ShapeInference::InferReverseShape(input_shape, {0, 0}); - ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_THAT(inferred_status_error2.status().message(), + ASSERT_FALSE(inferred_shape_error2.ok()); + ASSERT_THAT(inferred_shape_error2.status().message(), HasSubstr("duplicated")); const Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape}); - const absl::StatusOr inferred_status_error3 = + const absl::StatusOr inferred_shape_error3 = ShapeInference::InferReverseShape(tuple_shape, {0}); - ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_THAT(inferred_status_error3.status().message(), + ASSERT_FALSE(inferred_shape_error3.ok()); + ASSERT_THAT(inferred_shape_error3.status().message(), HasSubstr("Expected array argument")); } TEST_F(ShapeInferenceTest, Call) { - const absl::StatusOr inferred_status0 = + const absl::StatusOr inferred_shape0 = ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({}, f32_)); - EXPECT_IS_OK(inferred_status0.status()); - EXPECT_TRUE(ShapeUtil::Equal(f32_, *inferred_status0)); + EXPECT_IS_OK(inferred_shape0.status()); + EXPECT_TRUE(ShapeUtil::Equal(f32_, *inferred_shape0)); - const absl::StatusOr inferred_status1 = ShapeInference::InferCallShape( + const absl::StatusOr inferred_shape1 = ShapeInference::InferCallShape( {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_}, ShapeUtil::MakeProgramShape( {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_)); - EXPECT_IS_OK(inferred_status1.status()); - EXPECT_TRUE(ShapeUtil::Equal(s32matrix_64_64_, *inferred_status1)); + EXPECT_IS_OK(inferred_shape1.status()); + EXPECT_TRUE(ShapeUtil::Equal(s32matrix_64_64_, *inferred_shape1)); - const absl::StatusOr inferred_status_error0 = + const absl::StatusOr inferred_shape_error0 = ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({f32_}, f32_)); - EXPECT_FALSE(inferred_status_error0.ok()); - EXPECT_THAT(inferred_status_error0.status().message(), + EXPECT_FALSE(inferred_shape_error0.ok()); + EXPECT_THAT(inferred_shape_error0.status().message(), HasSubstr("arity must match")); - const absl::StatusOr inferred_status_error1 = + const absl::StatusOr inferred_shape_error1 = ShapeInference::InferCallShape({&f32_}, ShapeUtil::MakeProgramShape({}, f32_)); - EXPECT_FALSE(inferred_status_error1.ok()); - EXPECT_THAT(inferred_status_error1.status().message(), + EXPECT_FALSE(inferred_shape_error1.ok()); + EXPECT_THAT(inferred_shape_error1.status().message(), HasSubstr("arity must match")); - const absl::StatusOr inferred_status_error2 = + const absl::StatusOr inferred_shape_error2 = ShapeInference::InferCallShape({&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_)); - EXPECT_FALSE(inferred_status_error2.ok()); - EXPECT_THAT(inferred_status_error2.status().message(), + EXPECT_FALSE(inferred_shape_error2.ok()); + EXPECT_THAT(inferred_shape_error2.status().message(), HasSubstr("parameter must match argument")); } @@ -2481,140 +2475,140 @@ TEST_F(ShapeInferenceTest, Rank1Transpose) { } TEST_F(ShapeInferenceTest, ConditionalPred) { - const absl::StatusOr inferred_status0 = + const absl::StatusOr inferred_shape0 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, {vector_32_, vector_64_}); - EXPECT_IS_OK(inferred_status0.status()); - EXPECT_TRUE(ShapeUtil::Equal(f32_, *inferred_status0)); + EXPECT_IS_OK(inferred_shape0.status()); + EXPECT_TRUE(ShapeUtil::Equal(f32_, *inferred_shape0)); - const absl::StatusOr inferred_status1 = + const absl::StatusOr inferred_shape1 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)}, {matrix_32_48_, vector_32_}); - EXPECT_IS_OK(inferred_status1.status()); - EXPECT_TRUE(ShapeUtil::Equal(vector_64_, *inferred_status1)); + EXPECT_IS_OK(inferred_shape1.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_64_, *inferred_shape1)); const auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); - const absl::StatusOr inferred_status2 = + const absl::StatusOr inferred_shape2 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)}, {matrix_32_48_, tuple_f32_v32}); - EXPECT_IS_OK(inferred_status2.status()); - EXPECT_TRUE(ShapeUtil::Equal(vector_32_, *inferred_status2)); + EXPECT_IS_OK(inferred_shape2.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_32_, *inferred_shape2)); - const absl::StatusOr inferred_status_error0 = + const absl::StatusOr inferred_shape_error0 = ShapeInference::InferConditionalShape( f32_, {ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, {vector_32_, vector_64_}); - EXPECT_FALSE(inferred_status_error0.ok()); - EXPECT_THAT(inferred_status_error0.status().message(), + EXPECT_FALSE(inferred_shape_error0.ok()); + EXPECT_THAT(inferred_shape_error0.status().message(), HasSubstr("must be bool or int32_t")); - const absl::StatusOr inferred_status_error1 = + const absl::StatusOr inferred_shape_error1 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)}, {ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_}); - EXPECT_FALSE(inferred_status_error1.ok()); - EXPECT_THAT(inferred_status_error1.status().message(), + EXPECT_FALSE(inferred_shape_error1.ok()); + EXPECT_THAT(inferred_shape_error1.status().message(), HasSubstr("branch computation 0 must take 1 argument")); - const absl::StatusOr inferred_status_error2 = + const absl::StatusOr inferred_shape_error2 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({vector_64_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, {vector_32_, vector_64_}); - EXPECT_FALSE(inferred_status_error2.ok()); - EXPECT_THAT(inferred_status_error2.status().message(), + EXPECT_FALSE(inferred_shape_error2.ok()); + EXPECT_THAT(inferred_shape_error2.status().message(), HasSubstr("branch operand 0 must match the shape of the only " "parameter of branch computation 0")); - const absl::StatusOr inferred_status_error3 = + const absl::StatusOr inferred_shape_error3 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)}, {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})}); - EXPECT_FALSE(inferred_status_error3.ok()); - EXPECT_THAT(inferred_status_error3.status().message(), + EXPECT_FALSE(inferred_shape_error3.ok()); + EXPECT_THAT(inferred_shape_error3.status().message(), HasSubstr("branch computation 1 must take 1 argument")); - const absl::StatusOr inferred_status_error4 = + const absl::StatusOr inferred_shape_error4 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_32_}, f32_)}, {vector_32_, vector_64_}); - EXPECT_FALSE(inferred_status_error4.ok()); - EXPECT_THAT(inferred_status_error4.status().message(), + EXPECT_FALSE(inferred_shape_error4.ok()); + EXPECT_THAT(inferred_shape_error4.status().message(), HasSubstr("branch operand 1 must match the shape of the only " "parameter of branch computation 1")); - const absl::StatusOr inferred_status_error5 = + const absl::StatusOr inferred_shape_error5 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)}, {vector_32_, vector_64_}); - EXPECT_FALSE(inferred_status_error5.ok()); - EXPECT_THAT(inferred_status_error5.status().message(), + EXPECT_FALSE(inferred_shape_error5.ok()); + EXPECT_THAT(inferred_shape_error5.status().message(), HasSubstr("the result of branch 0 computation and branch 1 " "computation must have the same shape")); } TEST_F(ShapeInferenceTest, ConditionalIndexed) { const Shape r0s32 = ShapeUtil::MakeShape(S32, {}); - const absl::StatusOr inferred_status0 = + const absl::StatusOr inferred_shape0 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, {vector_32_, vector_64_, vector_64_}); - EXPECT_IS_OK(inferred_status0.status()); - EXPECT_TRUE(ShapeUtil::Equal(f32_, *inferred_status0)); + EXPECT_IS_OK(inferred_shape0.status()); + EXPECT_TRUE(ShapeUtil::Equal(f32_, *inferred_shape0)); - const absl::StatusOr inferred_status1 = + const absl::StatusOr inferred_shape1 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), ShapeUtil::MakeProgramShape({vector_32_}, vector_64_), ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)}, {matrix_32_48_, vector_32_, matrix_32_48_}); - EXPECT_IS_OK(inferred_status1.status()); - EXPECT_TRUE(ShapeUtil::Equal(vector_64_, *inferred_status1)); + EXPECT_IS_OK(inferred_shape1.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_64_, *inferred_shape1)); const auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); - const absl::StatusOr inferred_status2 = + const absl::StatusOr inferred_shape2 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)}, {tuple_f32_v32}); - EXPECT_IS_OK(inferred_status2.status()); - EXPECT_TRUE(ShapeUtil::Equal(vector_32_, *inferred_status2)); + EXPECT_IS_OK(inferred_shape2.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_32_, *inferred_shape2)); - const absl::StatusOr inferred_status_error0 = + const absl::StatusOr inferred_shape_error0 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, {vector_32_, vector_32_, vector_64_}); - EXPECT_FALSE(inferred_status_error0.ok()); - EXPECT_THAT(inferred_status_error0.status().message(), + EXPECT_FALSE(inferred_shape_error0.ok()); + EXPECT_THAT(inferred_shape_error0.status().message(), HasSubstr("2 == branch_computations.size()")); - const absl::StatusOr inferred_status_error1 = + const absl::StatusOr inferred_shape_error1 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), @@ -2622,23 +2616,23 @@ TEST_F(ShapeInferenceTest, ConditionalIndexed) { ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)}, {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_}); - EXPECT_FALSE(inferred_status_error1.ok()); - EXPECT_THAT(inferred_status_error1.status().message(), + EXPECT_FALSE(inferred_shape_error1.ok()); + EXPECT_THAT(inferred_shape_error1.status().message(), HasSubstr("branch computation 1 must take 1 argument")); - const absl::StatusOr inferred_status_error2 = + const absl::StatusOr inferred_shape_error2 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({r0s32}, f32_), ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_32_}, f32_)}, {r0s32, vector_32_, vector_64_}); - EXPECT_FALSE(inferred_status_error2.ok()); - EXPECT_THAT(inferred_status_error2.status().message(), + EXPECT_FALSE(inferred_shape_error2.ok()); + EXPECT_THAT(inferred_shape_error2.status().message(), HasSubstr("branch operand 2 must match the shape of the only " "parameter of branch computation 2")); - const absl::StatusOr inferred_status_error3 = + const absl::StatusOr inferred_shape_error3 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({vector_32_}, f32_), @@ -2646,15 +2640,15 @@ TEST_F(ShapeInferenceTest, ConditionalIndexed) { ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)}, {vector_32_, vector_32_, vector_32_, vector_64_}); - EXPECT_FALSE(inferred_status_error3.ok()); - EXPECT_THAT(inferred_status_error3.status().message(), + EXPECT_FALSE(inferred_shape_error3.ok()); + EXPECT_THAT(inferred_shape_error3.status().message(), HasSubstr("the result of branch 0 computation and branch 3 " "computation must have the same shape")); - const absl::StatusOr inferred_status_error4 = + const absl::StatusOr inferred_shape_error4 = ShapeInference::InferConditionalShape(r0s32, {}, {}); - EXPECT_FALSE(inferred_status_error4.ok()); - EXPECT_THAT(inferred_status_error4.status().message(), + EXPECT_FALSE(inferred_shape_error4.ok()); + EXPECT_THAT(inferred_shape_error4.status().message(), HasSubstr("!branch_computations.empty()")); } @@ -2662,25 +2656,25 @@ TEST_F(ShapeInferenceTest, ConditionalDynamic) { const Shape r0s32 = ShapeUtil::MakeShape(S32, {}); const Shape static_shape = ShapeUtil::MakeShape(S32, {4}, {false}); const Shape dynamic_shape = ShapeUtil::MakeShape(S32, {4}, {true}); - const absl::StatusOr inferred_status0 = + const absl::StatusOr inferred_shape0 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({vector_32_}, static_shape), ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape), ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)}, {vector_32_, vector_64_, vector_64_}); - EXPECT_IS_OK(inferred_status0.status()); - EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, *inferred_status0)); + EXPECT_IS_OK(inferred_shape0.status()); + EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, *inferred_shape0)); - const absl::StatusOr inferred_status1 = + const absl::StatusOr inferred_shape1 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({vector_32_}, dynamic_shape), ShapeUtil::MakeProgramShape({vector_64_}, static_shape), ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)}, {vector_32_, vector_64_, vector_64_}); - EXPECT_IS_OK(inferred_status1.status()); - EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, *inferred_status1)); + EXPECT_IS_OK(inferred_shape1.status()); + EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, *inferred_shape1)); } TEST_F(ShapeInferenceTest, BadSlice) { @@ -4018,18 +4012,18 @@ TEST_P(UnboundedUnaryOpShapeInferenceTest, UnboundedUnaryOps) { TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedAdd) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, lhs, rhs, GetParam().broadcast_dimensions); - if (inferred_status.ok()) { + if (inferred_shape.ok()) { TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam().expected)); - EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(*inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } else { ASSERT_TRUE(GetParam().error_message.has_value()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr(*GetParam().error_message)); } } @@ -4037,18 +4031,18 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedAdd) { TEST_P(UnboundedLogicalOpShapeInferenceTest, UnboundedAnd) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBinaryOpShape(HloOpcode::kAnd, lhs, rhs, GetParam().broadcast_dimensions); - if (inferred_status.ok()) { + if (inferred_shape.ok()) { TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam().expected)); - EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(*inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } else { ASSERT_TRUE(GetParam().error_message.has_value()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr(*GetParam().error_message)); } } @@ -4056,17 +4050,17 @@ TEST_P(UnboundedLogicalOpShapeInferenceTest, UnboundedAnd) { TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedAtan2) { TF_ASSERT_OK_AND_ASSIGN(Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(Shape rhs, ParseShape(GetParam().rhs)); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBinaryOpShape(HloOpcode::kAtan2, lhs, rhs, GetParam().broadcast_dimensions); - if (inferred_status.ok()) { + if (inferred_shape.ok()) { TF_ASSERT_OK_AND_ASSIGN(Shape expected, ParseShape(GetParam().expected)); - EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(*inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } else { ASSERT_TRUE(GetParam().error_message.has_value()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr(*GetParam().error_message)); } } @@ -4074,11 +4068,11 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedAtan2) { TEST_F(ShapeInferenceTest, UnboundedBitcastConvert) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]")); TF_ASSERT_OK_AND_ASSIGN( - const Shape inferred_status, + const Shape inferred_shape, ShapeInference::InferBitcastConvertShape(operand, PrimitiveType::F16)); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f16[?, 10, 2]")); - EXPECT_TRUE(ShapeUtil::Equal(inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } @@ -4091,13 +4085,13 @@ TEST_F(ShapeInferenceTest, UnboundedBatchNormGrad) { TF_ASSERT_OK_AND_ASSIGN(const Shape grad_scale, ParseShape("f32[?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape grad_offset, ParseShape("f32[?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape grad_output, ParseShape("f32[5, ?, 7]")); - TF_ASSERT_OK_AND_ASSIGN(const Shape result_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, ShapeInference::InferBatchNormGradShape( operand, scale, mean, variance, grad_output, 1)); const Shape expected_tuple_shape = ShapeUtil::MakeTupleShape({grad_operand, grad_scale, grad_offset}); - EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected_tuple_shape)) - << "inferred: " << ShapeUtil::HumanString(result_shape) + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected_tuple_shape)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) << " expected: " << ShapeUtil::HumanString(expected_tuple_shape); } @@ -4107,12 +4101,12 @@ TEST_F(ShapeInferenceTest, UnboundedBatchNormInference) { TF_ASSERT_OK_AND_ASSIGN(const Shape offset, ParseShape("f32[5]")); TF_ASSERT_OK_AND_ASSIGN(const Shape mean, ParseShape("f32[5]")); TF_ASSERT_OK_AND_ASSIGN(const Shape variance, ParseShape("f32[5]")); - TF_ASSERT_OK_AND_ASSIGN(const Shape result_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, ShapeInference::InferBatchNormInferenceShape( operand, scale, offset, mean, variance, 1)); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, ?, 7]")); - EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) - << "inferred: " << ShapeUtil::HumanString(result_shape) + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } @@ -4126,28 +4120,28 @@ TEST_F(ShapeInferenceTest, UnboundedBatchNormTraining) { const Shape expected_tuple_shape = ShapeUtil::MakeTupleShape({output, batch_mean, batch_var}); TF_ASSERT_OK_AND_ASSIGN( - const Shape result_shape, + const Shape inferred_shape, ShapeInference::InferBatchNormTrainingShape(operand, scale, offset, 1)); - EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected_tuple_shape)) - << "inferred: " << ShapeUtil::HumanString(result_shape) + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected_tuple_shape)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) << " expected: " << ShapeUtil::HumanString(expected_tuple_shape); } TEST_F(ShapeInferenceTest, UnboundedBroadcastUnsupportedOperand) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[<=2, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[1, <=2, ?]")); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBroadcastShape(operand, /*broadcast_sizes=*/{1}); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("is_unbounded_dynamic")); } TEST_F(ShapeInferenceTest, UnboundedBroadcastUnsupportedBroadcastSize) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[<=2, 4]")); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBroadcastShape( operand, /*broadcast_sizes=*/{Shape::kUnboundedSize}); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("Non-broadcast dimensions must not be dynamic.")); } @@ -4155,11 +4149,11 @@ TEST_F(ShapeInferenceTest, UnboundedBroadcastInDim) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[<=2, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[<=2, 3, 4]")); TF_ASSERT_OK_AND_ASSIGN( - const Shape inferred_status, + const Shape inferred_shape, ShapeInference::InferBroadcastShape(operand, expected, /*broadcast_dimensions=*/{0, 2})); - EXPECT_TRUE(ShapeUtil::Equal(inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } @@ -4167,30 +4161,30 @@ TEST_F(ShapeInferenceTest, UnboundedBroadcastInDimToBounded) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[<=2, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[<=2, 3, <=4]")); TF_ASSERT_OK_AND_ASSIGN( - const Shape inferred_status, + const Shape inferred_shape, ShapeInference::InferBroadcastShape(operand, expected, /*broadcast_dimensions=*/{0, 2})); - EXPECT_TRUE(ShapeUtil::Equal(inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } TEST_F(ShapeInferenceTest, UnboundedBroadcastInDimUnsupportedOutput) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[<=2, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[<=2, 3, ?]")); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBroadcastShape(operand, expected, /*broadcast_dimensions=*/{0, 2}); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("is_unbounded_dynamic")); } TEST_F(ShapeInferenceTest, UnboundedBroadcastInDimUnsupported) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[<=2, 4]")); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBroadcastShape( operand, /*broadcast_sizes=*/{2, Shape::kUnboundedSize, 4}); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("Non-broadcast dimensions must not be dynamic.")); } @@ -4198,15 +4192,15 @@ TEST_P(UnboundedClampOpShapeInferenceTest, UnboundedClamp) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam()[0])); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam()[1])); TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape(GetParam()[2])); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, lhs, rhs, ehs); - if (inferred_status.ok()) { + if (inferred_shape.ok()) { TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam()[3])); - EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(*inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } else { - EXPECT_EQ(inferred_status.status().message(), GetParam()[4]); + EXPECT_EQ(inferred_shape.status().message(), GetParam()[4]); } } @@ -4215,10 +4209,10 @@ TEST_F(ShapeInferenceTest, UnboundedClampWithTuple) { TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("(f32[?], f32[2])")); TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape("(f32[2], f32[?])")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("(f32[?], f32[2])")); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, lhs, rhs, ehs); EXPECT_THAT( - inferred_status.status().message(), + inferred_shape.status().message(), HasSubstr( "Expected array argument for clamp min, but got (f32[2], f32[?]).")); } @@ -4226,18 +4220,18 @@ TEST_F(ShapeInferenceTest, UnboundedClampWithTuple) { TEST_P(UnboundedCompareOpShapeInferenceTest, UnboundedCompare) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs, GetParam().broadcast_dimensions); - if (inferred_status.ok()) { + if (inferred_shape.ok()) { TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam().expected)); - EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(*inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } else { ASSERT_TRUE(GetParam().error_message.has_value()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr(*GetParam().error_message)); } } @@ -4245,16 +4239,16 @@ TEST_P(UnboundedCompareOpShapeInferenceTest, UnboundedCompare) { TEST_P(UnboundedConcatenateOpShapeInferenceTest, UnboundedConcatenate) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand1, ParseShape(GetParam()[0])); TF_ASSERT_OK_AND_ASSIGN(const Shape operand2, ParseShape(GetParam()[1])); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferConcatOpShape({&operand1, &operand2}, /*dimension=*/0); - if (inferred_status.ok()) { + if (inferred_shape.ok()) { TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam()[2])); - EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(*inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } else { - EXPECT_EQ(inferred_status.status().message(), GetParam()[3]); + EXPECT_EQ(inferred_shape.status().message(), GetParam()[3]); } } @@ -4263,10 +4257,10 @@ TEST_F(UnboundedConcatenateOpShapeInferenceTest, TF_ASSERT_OK_AND_ASSIGN(const Shape operand1, ParseShape("f32[2, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape operand2, ParseShape("f32[2, 3]")); TF_ASSERT_OK_AND_ASSIGN(const Shape operand3, ParseShape("f32[2, 4]")); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferConcatOpShape({&operand1, &operand2, &operand3}, /*dimension=*/0); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("Mismatched dimension sizes 3 and 4 in dimension 1")); } @@ -4275,10 +4269,10 @@ TEST_F(UnboundedConcatenateOpShapeInferenceTest, TF_ASSERT_OK_AND_ASSIGN(const Shape operand1, ParseShape("f32[2, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape operand2, ParseShape("f32[2, <=3]")); TF_ASSERT_OK_AND_ASSIGN(const Shape operand3, ParseShape("f32[2, <=4]")); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferConcatOpShape({&operand1, &operand2, &operand3}, /*dimension=*/0); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("Mismatched bound sizes 3 and 4 in dimension 1")); } @@ -4319,31 +4313,31 @@ TEST_F(ShapeInferenceTest, UnboundedConvolution) { /*window_dimensions=*/{2, 2}, /*window_strides=*/{1, 1}, Padding::kValid), /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); - TF_ASSERT_OK_AND_ASSIGN(const Shape result_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, ShapeInference::InferConvolveShape( lhs, rhs, /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, /*preferred_element_type=*/std::nullopt)); - EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) - << "inferred: " << ShapeUtil::HumanString(result_shape) + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedDiv) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBinaryOpShape(HloOpcode::kDivide, lhs, rhs, GetParam().broadcast_dimensions); - if (inferred_status.ok()) { + if (inferred_shape.ok()) { TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam().expected)); - EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(*inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } else { ASSERT_TRUE(GetParam().error_message.has_value()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr(*GetParam().error_message)); } } @@ -4358,11 +4352,11 @@ TEST_F(ShapeInferenceTest, UnboundedDot) { dnums.add_rhs_contracting_dimensions(0); TF_ASSERT_OK_AND_ASSIGN( - const Shape result_shape, + const Shape inferred_shape, ShapeInference::InferDotOpShape(lhs, rhs, dnums, /*preferred_element_type=*/std::nullopt)); - EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) - << "inferred: " << ShapeUtil::HumanString(result_shape) + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } @@ -4378,11 +4372,11 @@ TEST_F(ShapeInferenceTest, UnboundedDotGeneral) { dnums.add_rhs_contracting_dimensions(1); TF_ASSERT_OK_AND_ASSIGN( - const Shape result_shape, + const Shape inferred_shape, ShapeInference::InferDotOpShape(lhs, rhs, dnums, /*preferred_element_type=*/std::nullopt)); - EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) - << "inferred: " << ShapeUtil::HumanString(result_shape) + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } @@ -4400,30 +4394,30 @@ TEST_F(ShapeInferenceTest, UnboundedGather) { dimension_numbers.add_start_index_map(0); dimension_numbers.set_index_vector_dim(2); - TF_ASSERT_OK_AND_ASSIGN(const Shape result_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, ShapeInference::InferGatherShape( operand, start_indices, dimension_numbers, /*slice_sizes=*/{1, 2, 2})); - EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) - << "inferred: " << ShapeUtil::HumanString(result_shape) + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMax) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBinaryOpShape(HloOpcode::kMaximum, lhs, rhs, GetParam().broadcast_dimensions); - if (inferred_status.ok()) { + if (inferred_shape.ok()) { TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam().expected)); - EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(*inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } else { ASSERT_TRUE(GetParam().error_message.has_value()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr(*GetParam().error_message)); } } @@ -4431,18 +4425,18 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMax) { TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMul) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBinaryOpShape(HloOpcode::kMultiply, lhs, rhs, GetParam().broadcast_dimensions); - if (inferred_status.ok()) { + if (inferred_shape.ok()) { TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam().expected)); - EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(*inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } else { ASSERT_TRUE(GetParam().error_message.has_value()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr(*GetParam().error_message)); } } @@ -4450,18 +4444,18 @@ TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedMul) { TEST_P(UnboundedLogicalOpShapeInferenceTest, UnboundedOr) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBinaryOpShape(HloOpcode::kOr, lhs, rhs, GetParam().broadcast_dimensions); - if (inferred_status.ok()) { + if (inferred_shape.ok()) { TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam().expected)); - EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(*inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } else { ASSERT_TRUE(GetParam().error_message.has_value()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr(*GetParam().error_message)); } } @@ -4480,28 +4474,28 @@ TEST_F(ShapeInferenceTest, UnboundedPad) { } TF_ASSERT_OK_AND_ASSIGN( - const Shape result_shape, + const Shape inferred_shape, ShapeInference::InferPadShape(operand, padding_value, padding_config)); - EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) - << "inferred: " << ShapeUtil::HumanString(result_shape) + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedPow) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBinaryOpShape(HloOpcode::kPower, lhs, rhs, GetParam().broadcast_dimensions); - if (inferred_status.ok()) { + if (inferred_shape.ok()) { TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam().expected)); - EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(*inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } else { ASSERT_TRUE(GetParam().error_message.has_value()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr(*GetParam().error_message)); } } @@ -4515,13 +4509,13 @@ TEST_F(ShapeInferenceTest, UnboundedReduce) { {f32_, f32_, f32_, f32_, f32_, f32_}, ShapeUtil::MakeTupleShape({f32_, f32_, f32_})); TF_ASSERT_OK_AND_ASSIGN( - const Shape result_shape, + const Shape inferred_shape, ShapeInference::InferReduceShape( {&input0, &input1, &input2, &f32_, &f32_, &f32_}, {1}, to_apply)); const Shape shape = ShapeUtil::MakeShape(F32, {7}); const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape}); - EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) - << "inferred: " << ShapeUtil::HumanString(result_shape) + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } @@ -4533,10 +4527,9 @@ TEST_F(ShapeInferenceTest, UnboundedReduceInvalidReduceDimension) { ProgramShape to_apply = ShapeUtil::MakeProgramShape( {f32_, f32_, f32_, f32_, f32_, f32_}, ShapeUtil::MakeTupleShape({f32_, f32_, f32_})); - const absl::StatusOr inferred_status = - ShapeInference::InferReduceShape( - {&input0, &input1, &input2, &f32_, &f32_, &f32_}, {1}, to_apply); - EXPECT_THAT(inferred_status.status().message(), + const absl::StatusOr inferred_shape = ShapeInference::InferReduceShape( + {&input0, &input1, &input2, &f32_, &f32_, &f32_}, {1}, to_apply); + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("All reduced tensors must have compatible dimension")); } @@ -4582,22 +4575,22 @@ TEST_F(ShapeInferenceTest, UnboundedReshape) { TEST_F(ShapeInferenceTest, UnboundedReshapeUnsupportedOutputShape) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[6]")); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferReshapeShape( operand, /*dimensions=*/{0}, /*new_sizes=*/{Shape::kUnboundedSize, Shape::kUnboundedSize}, -1); EXPECT_THAT( - inferred_status.status().message(), + inferred_shape.status().message(), HasSubstr("Reshaping with unbounded result shape is not supported.")); } TEST_F(ShapeInferenceTest, UnboundedReshapeUnsupportedMixOfDynamism) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, <=3]")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[<=3]")); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferReshapeShape(operand, /*dimensions=*/{0}, /*new_sizes=*/{3}, -1); - ASSERT_THAT(inferred_status.status().message(), + ASSERT_THAT(inferred_shape.status().message(), HasSubstr("Reshape operand with bounded and unbounded dynamism " "not supported.")); } @@ -4606,15 +4599,15 @@ TEST_P(UnboundedSelectOpShapeInferenceTest, UnboundedSelect) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam()[0])); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam()[1])); TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape(GetParam()[2])); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, lhs, rhs, ehs); - if (inferred_status.ok()) { + if (inferred_shape.ok()) { TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam()[3])); - EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(*inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } else { - EXPECT_EQ(inferred_status.status().message(), GetParam()[4]); + EXPECT_EQ(inferred_shape.status().message(), GetParam()[4]); } } @@ -4623,9 +4616,9 @@ TEST_F(ShapeInferenceTest, UnboundedSelectWithTupleUnsupported) { TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape("(f32[?], f32[2])")); TF_ASSERT_OK_AND_ASSIGN(const Shape ehs, ParseShape("(f32[2], f32[?])")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("(f32[?], f32[2])")); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, lhs, rhs, ehs); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr("Expected array argument for select pred, but got " "(pred[2], pred[?]).")); } @@ -4634,30 +4627,30 @@ TEST_F(ShapeInferenceTest, UnboundedSlice) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[1, <=3, ?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[1, <=2, 3]")); TF_ASSERT_OK_AND_ASSIGN( - const Shape result_shape, + const Shape inferred_shape, ShapeInference::InferSliceShape(operand, /*starts=*/{0, 1, 2}, /*limits=*/{1, 3, 5}, /*strides=*/{1, 1, 1})); - EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) - << "inferred: " << ShapeUtil::HumanString(result_shape) + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } TEST_P(UnboundedBinaryOpShapeInferenceTest, UnboundedSub) { TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs)); TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs)); - const absl::StatusOr inferred_status = + const absl::StatusOr inferred_shape = ShapeInference::InferBinaryOpShape(HloOpcode::kSubtract, lhs, rhs, GetParam().broadcast_dimensions); - if (inferred_status.ok()) { + if (inferred_shape.ok()) { TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape(GetParam().expected)); - EXPECT_TRUE(ShapeUtil::Equal(*inferred_status, expected)) - << "inferred: " << ShapeUtil::HumanString(*inferred_status) + EXPECT_TRUE(ShapeUtil::Equal(*inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(*inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } else { ASSERT_TRUE(GetParam().error_message.has_value()); - EXPECT_THAT(inferred_status.status().message(), + EXPECT_THAT(inferred_shape.status().message(), HasSubstr(*GetParam().error_message)); } } @@ -4693,11 +4686,11 @@ TEST_F(ShapeInferenceTest, UnboundedTranspose) { ParseShape("f32[1, ?, 2, ?, <=2]{4,3,2,1,0}")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[<=2, 1, ?, 2, ?]{0,2,3,4,1}")); - TF_ASSERT_OK_AND_ASSIGN(const Shape result_shape, + TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape, ShapeInference::InferTransposeShape( operand, /*dimensions=*/{4, 0, 3, 2, 1})); - EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) - << "inferred: " << ShapeUtil::HumanString(result_shape) + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } @@ -4705,10 +4698,10 @@ TEST_F(ShapeInferenceTest, UnboundedTransposeRank1) { TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?]")); TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?]")); TF_ASSERT_OK_AND_ASSIGN( - const Shape result_shape, + const Shape inferred_shape, ShapeInference::InferTransposeShape(operand, /*dimensions=*/{0})); - EXPECT_TRUE(ShapeUtil::Equal(result_shape, expected)) - << "inferred: " << ShapeUtil::HumanString(result_shape) + EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected)) + << "inferred: " << ShapeUtil::HumanString(inferred_shape) << " expected: " << ShapeUtil::HumanString(expected); } From 2c908271c07aadddbd7c6bfea7decc0be0026f65 Mon Sep 17 00:00:00 2001 From: Swachhand Lokhande Date: Fri, 29 Mar 2024 16:24:01 -0700 Subject: [PATCH 103/124] Add a contextmanager to temporarily disable XLA sharding support for ResourceVariables. PiperOrigin-RevId: 620366887 --- .../resource_variable_xla_sharding_test.py | 50 +++++++++++++++++++ tensorflow/python/eager/context.py | 19 +++++++ 2 files changed, 69 insertions(+) diff --git a/tensorflow/python/compiler/xla/experimental/resource_variable_xla_sharding_test.py b/tensorflow/python/compiler/xla/experimental/resource_variable_xla_sharding_test.py index ef7192a4f45807..cab7c9810063fa 100644 --- a/tensorflow/python/compiler/xla/experimental/resource_variable_xla_sharding_test.py +++ b/tensorflow/python/compiler/xla/experimental/resource_variable_xla_sharding_test.py @@ -131,6 +131,56 @@ def tpu_fn(x): sharding_proto, ) + def test_disabling_xla_sharding_ops_temporarily(self): + w = variables.Variable( + initial_value=math_ops.range(8, dtype=dtypes.float32), + name='w', + ) + self.assertIsInstance(w, resource_variable_ops.BaseResourceVariable) + + context.enable_xla_sharding_for_resource_variables() + with context.temporarily_disable_xla_sharding_for_resource_variables(): + with self.assertRaisesRegex( + AttributeError, + '.*Tensor.op is undefined when eager execution is enabled.*', + ): + xla_sharding.split( + w, + split_dimension=0, + num_devices=8, + ) + + # xla_sharding_for_resource_variables is enabled again. Following line + # doesn't throw an error. + xla_sharding.split( + w, + split_dimension=0, + num_devices=8, + ) + + context.disable_xla_sharding_for_resource_variables() + with context.temporarily_disable_xla_sharding_for_resource_variables(): + with self.assertRaisesRegex( + AttributeError, + '.*Tensor.op is undefined when eager execution is enabled.*', + ): + xla_sharding.split( + w, + split_dimension=0, + num_devices=8, + ) + + # xla_sharding_for_resource_variables stays disabled. + with self.assertRaisesRegex( + AttributeError, + '.*Tensor.op is undefined when eager execution is enabled.*', + ): + xla_sharding.split( + w, + split_dimension=0, + num_devices=8, + ) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 0cb1074d251f04..e400fbaa117209 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -150,6 +150,25 @@ def xla_sharding_for_resource_variables_enabled(): return _XLA_SHARDING_FOR_RESOURCE_VARIABLES +@contextlib.contextmanager +def temporarily_disable_xla_sharding_for_resource_variables(): + """Temporarily disables XLA sharding for resource variables. + + Should be a no-op if it is already disabled. + + Yields: + None. + """ + previously_enabled = xla_sharding_for_resource_variables_enabled() + + try: + disable_xla_sharding_for_resource_variables() + yield + finally: + if previously_enabled: + enable_xla_sharding_for_resource_variables() + + # Expose it as internally public APIs for Keras use cases in b/171080602. tf_export("__internal__.is_tfrt_enabled", v1=[])(is_tfrt_enabled) From 27bd888ea71430ba64bd8f1ab16cc201cf5c36cb Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Fri, 29 Mar 2024 17:57:07 -0700 Subject: [PATCH 104/124] [xla:gpu] AddressComputationFusionRewriter should run before other fusions PiperOrigin-RevId: 620383409 --- .../address_computation_fusion_rewriter.cc | 13 - ...ddress_computation_fusion_rewriter_test.cc | 231 ++++++------------ .../xla/xla/service/gpu/gpu_compiler.cc | 21 +- .../xla/xla/service/gpu/gpu_compiler_test.cc | 13 +- 4 files changed, 99 insertions(+), 179 deletions(-) diff --git a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc index 631ce71a4407d1..afb429b1942ab3 100644 --- a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc @@ -404,8 +404,6 @@ absl::StatusOr CreateFusionInstruction( absl::StatusOr AddressComputationFusionRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { - if (!module->has_schedule()) return Internal("module is not scheduled"); - auto process_slices = [&](bool dynamic) -> absl::StatusOr { absl::flat_hash_map> @@ -445,7 +443,6 @@ absl::StatusOr AddressComputationFusionRewriter::Run( if (matches.empty()) return false; - HloSchedule& schedule = module->schedule(); for (auto& [hero, paths] : matches) { auto& [sliced_operand_paths, sliced_user_paths] = paths; std::vector matched_instrs; @@ -471,15 +468,7 @@ absl::StatusOr AddressComputationFusionRewriter::Run( CreateFusionInstruction(module, hero, captures, fusion_body, dynamic)); - // As we are running after scheduling we have to keep it valid. HloComputation* parent = hero->parent(); - // Update schedule to replace the custom call instruction with the fusion - // instruction. - // Removal of the rest of the instructions in the sequence is handled by - // schedule update below. - HloInstructionSequence& sequence = schedule.GetOrCreateSequence(parent); - sequence.replace_instruction(hero, fusion); - if (fusion->shape().IsTuple()) { TF_RETURN_IF_ERROR(parent->ReplaceInstructionWithDifferentShape( const_cast(hero), fusion)); @@ -519,8 +508,6 @@ absl::StatusOr AddressComputationFusionRewriter::Run( } } - TF_RETURN_IF_ERROR(module->schedule().Update()); - return true; }; diff --git a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter_test.cc index ec358a7b522e9f..4d14024115a621 100644 --- a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter_test.cc @@ -51,7 +51,7 @@ class AddressComputationFusionRewriterTest : public HloTestBase {}; TEST_F(AddressComputationFusionRewriterTest, SimpleGemm) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY %main.9 { %p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -107,15 +107,12 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemm) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmWithWorkspace) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY %main.9 { %p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -175,15 +172,12 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmWithWorkspace) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmWorkspaceIgnored) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY %main.9 { %p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -245,15 +239,12 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmWorkspaceIgnored) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmNotRoot) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY %main.9 { %p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -311,16 +302,13 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmNotRoot) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmOperandHasMultipleUsers) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY %main.9 { %p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -380,16 +368,13 @@ TEST_F(AddressComputationFusionRewriterTest, auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmOperandsHaveMultipleUsers) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY %main.9 { %p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -448,7 +433,7 @@ TEST_F(AddressComputationFusionRewriterTest, TEST_F(AddressComputationFusionRewriterTest, SimpleGemmSlicingNotParameter) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY %main.9 { %p0 = f16[4,8,8]{2,1,0} parameter(0) @@ -510,15 +495,12 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmSlicingNotParameter) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmNotContiguousSlice) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY %main.9 { %p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -557,7 +539,7 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmNotContiguousSlice) { TEST_F(AddressComputationFusionRewriterTest, SimpleGemmNonNoOpInSliceChain) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY %main.9 { %p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -600,7 +582,7 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmNonNoOpInSliceChain) { TEST_F(AddressComputationFusionRewriterTest, SimpleGemmDuplicateOperand) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY %main { %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) @@ -674,15 +656,12 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmDuplicateOperand) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmReverseOperandOrder) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY %main.9 { %p0 = f16[2,8,8]{2,1,0} parameter(1) @@ -740,15 +719,12 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmReverseOperandOrder) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmReverseOperandOrder2) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY %main.9 { %p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -806,15 +782,12 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmReverseOperandOrder2) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmOperandAliasingOutput) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY %main.9 { %p0 = (f32[100,100]{1,0}, f32[100,100]{1,0}) parameter(0) @@ -873,15 +846,12 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmOperandAliasingOutput) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, SimpleGemmOperandsFromSameSlice) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY %main.9 { %p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -934,10 +904,7 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmOperandsFromSameSlice) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } static absl::Status Memcpy(se::Stream* stream, ffi::BufferBase src, @@ -977,12 +944,12 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleCustomCall) { hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(hlo.get(), [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - })); - TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); + // TF_ASSERT_OK_AND_ASSIGN( + // HloSchedule schedule, + // ScheduleModule(hlo.get(), [](const BufferValue& buffer) { + // return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + // })); + // TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); const char* expected = R"( ; CHECK: %address-computation {{.*}} { @@ -1006,12 +973,8 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleCustomCall) { )"; auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - RunAndFilecheckHloRewrite(hlo->ToString(), - AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + RunAndFilecheckHloRewrite( + hlo->ToString(), AddressComputationFusionRewriter(PLATFORM), expected); } void Callback_Void(se::gpu::GpuStreamHandle stream, void** buffers, @@ -1035,12 +998,12 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleCustomCallLegacy) { hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(hlo.get(), [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - })); - TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); + // TF_ASSERT_OK_AND_ASSIGN( + // HloSchedule schedule, + // ScheduleModule(hlo.get(), [](const BufferValue& buffer) { + // return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + // })); + // TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); const char* expected = R"( ; CHECK: %address-computation {{.*}} { @@ -1063,12 +1026,8 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleCustomCallLegacy) { )"; auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - RunAndFilecheckHloRewrite(hlo->ToString(), - AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + RunAndFilecheckHloRewrite( + hlo->ToString(), AddressComputationFusionRewriter(PLATFORM), expected); } TEST_F(AddressComputationFusionRewriterTest, TupleSliceCustomCallLegacy) { @@ -1099,12 +1058,12 @@ TEST_F(AddressComputationFusionRewriterTest, TupleSliceCustomCallLegacy) { hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(hlo.get(), [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - })); - TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); + // TF_ASSERT_OK_AND_ASSIGN( + // HloSchedule schedule, + // ScheduleModule(hlo.get(), [](const BufferValue& buffer) { + // return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + // })); + // TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); const char* expected = R"( ; CHECK: %address-computation {{.*}} { @@ -1128,12 +1087,8 @@ TEST_F(AddressComputationFusionRewriterTest, TupleSliceCustomCallLegacy) { )"; auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - RunAndFilecheckHloRewrite(hlo->ToString(), - AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + RunAndFilecheckHloRewrite( + hlo->ToString(), AddressComputationFusionRewriter(PLATFORM), expected); } TEST_F(AddressComputationFusionRewriterTest, TupledOutputCustomCallLegacy) { @@ -1175,12 +1130,12 @@ TEST_F(AddressComputationFusionRewriterTest, TupledOutputCustomCallLegacy) { hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(hlo.get(), [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - })); - TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); + // TF_ASSERT_OK_AND_ASSIGN( + // HloSchedule schedule, + // ScheduleModule(hlo.get(), [](const BufferValue& buffer) { + // return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + // })); + // TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); const char* expected = R"( ; CHECK: %address-computation {{.*}} { @@ -1216,12 +1171,8 @@ TEST_F(AddressComputationFusionRewriterTest, TupledOutputCustomCallLegacy) { )"; auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - RunAndFilecheckHloRewrite(hlo->ToString(), - AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + RunAndFilecheckHloRewrite( + hlo->ToString(), AddressComputationFusionRewriter(PLATFORM), expected); } TEST_F(AddressComputationFusionRewriterTest, UnalignedSlice) { @@ -1240,12 +1191,12 @@ TEST_F(AddressComputationFusionRewriterTest, UnalignedSlice) { hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(hlo.get(), [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - })); - TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); + // TF_ASSERT_OK_AND_ASSIGN( + // HloSchedule schedule, + // ScheduleModule(hlo.get(), [](const BufferValue& buffer) { + // return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + // })); + // TF_CHECK_OK(hlo->set_schedule(std::move(schedule))); auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo->ToString(), @@ -1255,7 +1206,7 @@ TEST_F(AddressComputationFusionRewriterTest, UnalignedSlice) { TEST_F(AddressComputationFusionRewriterTest, DynamicSimpleGemm) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY main.9 { p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -1315,15 +1266,12 @@ TEST_F(AddressComputationFusionRewriterTest, DynamicSimpleGemm) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, DynamicSimpleGemmWithWorkspace) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY main.9 { p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -1388,16 +1336,13 @@ TEST_F(AddressComputationFusionRewriterTest, DynamicSimpleGemmWithWorkspace) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, DynamicSimpleGemmWorkspaceIgnored) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY main.9 { p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -1463,15 +1408,12 @@ TEST_F(AddressComputationFusionRewriterTest, auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, DynamicSimpleGemmNotRoot) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY main.9 { p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -1533,15 +1475,12 @@ TEST_F(AddressComputationFusionRewriterTest, DynamicSimpleGemmNotRoot) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, DUSSimpleGemm) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY main.9 { p0 = f16[1,8,8]{2,1,0} parameter(0) @@ -1600,15 +1539,12 @@ TEST_F(AddressComputationFusionRewriterTest, DUSSimpleGemm) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, DUSSimpleGemmNotRoot) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY main.9 { p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -1676,15 +1612,12 @@ TEST_F(AddressComputationFusionRewriterTest, DUSSimpleGemmNotRoot) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, DUSSimpleGemmWithWorkspace) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY main.9 { p0 = f16[2,8,8]{2,1,0} parameter(0) @@ -1762,15 +1695,12 @@ TEST_F(AddressComputationFusionRewriterTest, DUSSimpleGemmWithWorkspace) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } TEST_F(AddressComputationFusionRewriterTest, DUSSimpleGemmWorkspaceIgnored) { const char* hlo = R"( - HloModule test, is_scheduled=true + HloModule test ENTRY %main.9 { %p0 = f16[8,8]{1,0} parameter(0) @@ -1833,10 +1763,7 @@ TEST_F(AddressComputationFusionRewriterTest, DUSSimpleGemmWorkspaceIgnored) { auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); RunAndFilecheckHloRewrite(hlo, AddressComputationFusionRewriter(PLATFORM), - expected, [](HloModule* module) { - EXPECT_TRUE(module->has_schedule()); - TF_CHECK_OK(module->schedule().Verify()); - }); + expected); } } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index c9c7e3523170c4..63d1082da8ef9e 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -1286,6 +1286,17 @@ absl::Status GpuCompiler::OptimizeHloModule( TF_RETURN_IF_ERROR(OptimizeHloPostLayoutAssignment( hlo_module, stream_exec, options, gpu_target_config, thread_pool.get())); + // This is a "low effort, high impact" fusion that should be run first. + if (hlo_module->config() + .debug_options() + .xla_gpu_enable_address_computation_fusion()) { + HloPassPipeline pipeline("address-computation"); + TF_ASSIGN_OR_RETURN(se::Platform * platform, + se::PlatformManager::PlatformWithId(PlatformId())); + pipeline.AddPass(platform->Name()); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + TF_RETURN_IF_ERROR(RunFusionPasses(hlo_module, gpu_target_config, thread_pool.get(), ShapeSizeBytesFunction())); @@ -2218,16 +2229,6 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( } } - if (module->config() - .debug_options() - .xla_gpu_enable_address_computation_fusion()) { - HloPassPipeline pipeline("address-computation"); - TF_ASSIGN_OR_RETURN(se::Platform * platform, - se::PlatformManager::PlatformWithId(PlatformId())); - pipeline.AddPass(platform->Name()); - TF_RETURN_IF_ERROR(pipeline.Run(module).status()); - } - { HloPassPipeline pipeline("fusion-wrapper"); pipeline.AddPass(); diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index 02511c1f99e6bc..b168f69308eb55 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -349,16 +349,21 @@ ENTRY main { )"; HloModuleConfig config; - DebugOptions debug_options = GetDebugOptionsForTest(); - config.set_debug_options(GetDebugOptionsForTest()); + DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest(); + triton_enabled_debug_options.set_xla_gpu_enable_address_computation_fusion( + false); + config.set_debug_options(triton_enabled_debug_options); config.set_replica_count(1); config.set_num_partitions(1); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_string, config)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr triton_enabled_module, GetOptimizedModule(std::move(module))); - debug_options.set_xla_gpu_enable_triton_gemm(false); - config.set_debug_options(debug_options); + DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest(); + triton_disabled_debug_options.set_xla_gpu_enable_address_computation_fusion( + false); + triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false); + config.set_debug_options(triton_disabled_debug_options); TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnVerifiedModule(hlo_string, config)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr triton_disabled_module, From ce7a13230b6a169b66b6c7a63abddd37b24849fe Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Fri, 29 Mar 2024 18:41:05 -0700 Subject: [PATCH 105/124] [xla:gpu] Unify GEMM emission for (Dynamic)AddressComputationFusion emitters PiperOrigin-RevId: 620391616 --- .../address_computation_fusion_test.cc | 16 +- .../xla/xla/service/gpu/fusions/custom.cc | 255 +++++++++--------- 2 files changed, 133 insertions(+), 138 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc b/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc index 3a20611f1fde4b..341cf154394a87 100644 --- a/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc @@ -159,7 +159,7 @@ TEST_F(AddressComputationFusionTest, CublasGemmSimple) { %p0 = bf16[2,8,8]{2,1,0} parameter(0), sharding={replicated} %p1 = bf16[2,8,8]{2,1,0} parameter(1), sharding={replicated} ROOT %fusion.2 = bf16[8,8]{1,0} fusion(%p0, %p1), kind=kCustom, calls=%fused_computation, - backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"}}} + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), @@ -241,7 +241,7 @@ TEST_F(AddressComputationFusionTest, CublasGemmWithWorkspace) { %p0 = f16[2,8,8]{2,1,0} parameter(0), sharding={replicated} %p1 = f16[2,8,8]{2,1,0} parameter(1), sharding={replicated} ROOT %fusion.2 = (f16[8,8]{1,0}, s8[256]{0}) fusion(%p0, %p1), kind=kCustom, calls=%fused_computation, - backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"}}} + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), @@ -320,7 +320,7 @@ TEST_F(AddressComputationFusionTest, ContiguousSlice) { %p0 = bf16[2,8,8]{2,1,0} parameter(0), sharding={replicated} %p1 = bf16[8,8,10,8]{3,2,1,0} parameter(1), sharding={replicated} ROOT %fusion.2 = bf16[4,8]{1,0} fusion(%p0, %p1), kind=kCustom, calls=%fused_computation, - backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"}}} + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), @@ -399,7 +399,7 @@ TEST_F(AddressComputationFusionTest, ContiguousSliceNonDefaultLayout) { %p0 = bf16[2,8,8]{1,2,0} parameter(0), sharding={replicated} %p1 = bf16[8,8,10,8]{1,2,3,0} parameter(1), sharding={replicated} ROOT %fusion.2 = bf16[4,8]{1,0} fusion(%p0, %p1), kind=kCustom, calls=%fused_computation, - backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"}}} + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), @@ -529,7 +529,7 @@ TEST_F(AddressComputationFusionTest, OperandIsSlicedGetTupleElement) { calls=%address-computation, backend_config={ "fusion_backend_config":{ - "kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"} + "kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"} } } })"; @@ -615,7 +615,7 @@ TEST_F(AddressComputationFusionTest, ReversedOperandOrder) { calls=%address-computation, backend_config={ "fusion_backend_config":{ - "kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"} + "kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"} } } })"; @@ -746,7 +746,7 @@ TEST_F(AddressComputationFusionTest, SingleOperandComputation) { calls=%address-computation, backend_config={ "fusion_backend_config":{ - "kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"} + "kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"} } } })"; @@ -837,7 +837,7 @@ TEST_F(AddressComputationFusionTest, SlicedOperandAliasingOutput) { output_to_operand_aliasing={{0}: (1, {})}, backend_config={ "fusion_backend_config":{ - "kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"} + "kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"} } } })"; diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index c8c0900bf0523b..48d9ff6fc9cc6a 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -147,75 +147,24 @@ absl::StatusOr EmitGemm( const BufferAssignment& buffer_assignment = ir_emitter_context.buffer_assignment(); - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice lhs_slice, - GetSliceWithUpdatedOffsetAndSize(buffer_assignment, adaptor, fusion, - *custom_call.operand(kLHSOperandIndex), - /*index=*/{})); - - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice rhs_slice, - GetSliceWithUpdatedOffsetAndSize(buffer_assignment, adaptor, fusion, - *custom_call.operand(kRHSOperandIndex), - /*index=*/{})); - - BufferAllocation::Slice output; - std::optional workspace; - - // Result of a legacy cuBLAS custom call can be a tuple if we explicitly - // allocate workspace buffer in HLO. If result is an array, it means that - // workspace is not available, and cuBLAS will allocate its own workspace. - if (custom_call.shape().IsArray()) { - TF_ASSIGN_OR_RETURN(output, - GetAllocationSlice(buffer_assignment, &fusion, {})); - } else { - TF_ASSIGN_OR_RETURN(output, GetAllocationSlice(buffer_assignment, &fusion, - {kGEMMOutputBufferIndex})); - TF_ASSIGN_OR_RETURN(workspace, - GetAllocationSlice(buffer_assignment, &fusion, - {kGEMMWorkspaceBufferIndex})); - } - - bool deterministic_ops = - ir_emitter_context.debug_options().xla_gpu_deterministic_ops(); - - TF_ASSIGN_OR_RETURN( - GemmConfig config, - GemmConfig::For(static_cast(&custom_call))); - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(&custom_call), std::move(config), - lhs_slice, rhs_slice, output, workspace, deterministic_ops); - - FusionEmissionResult result; - result.thunks.push_back(std::move(thunk)); - return result; -} - -absl::StatusOr EmitDynamicSlicedGemm( - IrEmitterContext& ir_emitter_context, const HloFusionAdaptor& adaptor, - const HloFusionInstruction& fusion, - const HloCustomCallInstruction& custom_call) { - const BufferAssignment& buffer_assignment = - ir_emitter_context.buffer_assignment(); - std::vector>> offset_buffer_indices(4, std::nullopt); std::vector> orig_shapes(4, std::nullopt); std::vector> sliced_shapes(4, std::nullopt); std::vector> offset_byte_sizes(4, std::nullopt); - HloDynamicIndexInstruction* slice_instr = nullptr; + std::vector slice_instrs(4, nullptr); auto get_original_operand_slice = - [&](const HloInstruction* start, - const ShapeIndex& index) -> absl::StatusOr { - auto* param = DynCast(start); - auto slice_adaptor = HloFindIf( - {HloInstructionAdaptor(*start)}, adaptor, - [](auto node) { return node.opcode() == HloOpcode::kDynamicSlice; }); + [&](const HloInstruction* start, const ShapeIndex& index, + unsigned param_idx) -> absl::StatusOr { + auto slice_adaptor = + HloFindIf({HloInstructionAdaptor(*start)}, adaptor, [](auto node) { + return IsOpcodeAnyOf( + node); + }); if (slice_adaptor.has_value()) { - slice_instr = const_cast( - static_cast( - &slice_adaptor->instruction())); + auto* slice_instr = + const_cast(&slice_adaptor->instruction()); if (!IsContiguousSlice(slice_instr->operand(0)->shape(), slice_instr->shape())) { @@ -224,14 +173,49 @@ absl::StatusOr EmitDynamicSlicedGemm( "currently"); } - param = Cast(slice_instr->operand(0)); + slice_instrs[param_idx] = slice_instr; + + const auto* param = + Cast(slice_instr->operand(0)); + TF_ASSIGN_OR_RETURN( + BufferAllocation::Slice orig_slice, + GetAllocationSlice(buffer_assignment, + fusion.operand(param->parameter_number()), index)); + + if (auto* static_slice = DynCast(slice_instr)) { + // Update static slices. + const Shape& src_shape = static_slice->operand(0)->shape(); + const Shape& dst_shape = static_slice->shape(); + int64_t size = ShapeUtil::ByteSizeOf(dst_shape); + + // Given this slice + // f16[1,4,8]{2,1,0} slice(f16[2,8,8]{2,1,0}), + // slice={[1:2], [4:8], [0:8]} + // + // The offset of the slice should be: + // slice_starts(0) * 8 * 8 * sizeof(f16) + + // slice_starts(1) * 8 * sizeof(f16) + int64_t offset = orig_slice.offset(); + for (auto [start, stride] : + llvm::zip(static_slice->slice_starts(), + *ShapeUtil::ByteStrides(src_shape))) { + offset += start * stride; + } + + return BufferAllocation::Slice(orig_slice.allocation(), offset, size); + } + + return orig_slice; } + const auto* param = DynCast(start); return GetAllocationSlice(buffer_assignment, fusion.operand(param->parameter_number()), index); }; auto collect_slice_info = [&](unsigned idx) { + auto* slice_instr = + DynCastOrNull(slice_instrs[idx]); if (slice_instr == nullptr) { return; } @@ -252,29 +236,26 @@ absl::StatusOr EmitDynamicSlicedGemm( : slice_instr->operand(1)->shape(); offset_byte_sizes[idx] = ShapeUtil::ByteSizeOfPrimitiveType( slice_instr->index_operands().front()->shape().element_type()); - - // Reset `slice_instr` for the next call to `collect_slice_info()`. - slice_instr = nullptr; }; - unsigned argument_idx = 0; + unsigned param_idx = 0; TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_slice, - get_original_operand_slice( - custom_call.operand(argument_idx), /*index=*/{})); - collect_slice_info(argument_idx++); + get_original_operand_slice(custom_call.operand(param_idx), + /*index=*/{}, param_idx)); + collect_slice_info(param_idx++); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_slice, - get_original_operand_slice( - custom_call.operand(argument_idx), /*index=*/{})); - collect_slice_info(argument_idx++); + get_original_operand_slice(custom_call.operand(param_idx), + /*index=*/{}, param_idx)); + collect_slice_info(param_idx++); BufferAllocation::Slice output; std::optional workspace = std::nullopt; std::optional slice_workspace_fake = std::nullopt; auto get_original_result_slice = - [&](const HloInstruction* start, - const ShapeIndex& index) -> absl::StatusOr { + [&](const HloInstruction* start, const ShapeIndex& index, + unsigned param_idx) -> absl::StatusOr { auto slice_adaptor = HloFindIf( {HloInstructionAdaptor(*start)}, adaptor, [](auto node) { @@ -282,9 +263,9 @@ absl::StatusOr EmitDynamicSlicedGemm( }, false); if (slice_adaptor.has_value()) { - slice_instr = const_cast( - static_cast( - &slice_adaptor->instruction())); + auto* slice_instr = + const_cast(&slice_adaptor->instruction()); + slice_instrs[param_idx] = slice_instr; if (!IsContiguousSlice(slice_instr->shape(), Cast(slice_instr) @@ -299,44 +280,42 @@ absl::StatusOr EmitDynamicSlicedGemm( return GetAllocationSlice(buffer_assignment, &fusion, index); }; - int64_t out_fake_byte_size = ShapeUtil::ByteSizeOf( - custom_call.shape().IsArray() ? custom_call.shape() - : custom_call.shape().tuple_shapes(0)); - // Handling cases where multiple operands share the same buffer, with // different offset by creating new fake allocations so each operand will have // a different buffer index. The slices can thus always start at offset 0. // AddressComputationThunk will take care of the offset adjustment. std::vector> fake_allocations(4); if (fusion.shape().IsArray()) { - TF_ASSIGN_OR_RETURN(output, - get_original_result_slice(&custom_call, /*index=*/{})); - collect_slice_info(argument_idx); + TF_ASSIGN_OR_RETURN(output, get_original_result_slice( + &custom_call, /*index=*/{}, param_idx)); + collect_slice_info(param_idx); } else { TF_ASSIGN_OR_RETURN( - output, get_original_result_slice(&custom_call, - /*index=*/{kGEMMOutputBufferIndex})); - collect_slice_info(argument_idx++); + output, + get_original_result_slice( + &custom_call, /*index=*/{kGEMMOutputBufferIndex}, param_idx)); + collect_slice_info(param_idx++); + // TODO(vuson): If we want to support slices of workspace, we'd need to // start `HloFindIf` with `get-tuple-element` with the right index. TF_ASSIGN_OR_RETURN( workspace, GetAllocationSlice(buffer_assignment, &fusion, /*index=*/{kGEMMWorkspaceBufferIndex})); - collect_slice_info(argument_idx); - fake_allocations[3] = std::make_unique( - /*index=*/3, workspace->size(), /*color=*/0); - slice_workspace_fake = BufferAllocation::Slice(fake_allocations[3].get(), 0, - workspace->size()); + collect_slice_info(param_idx); + fake_allocations[param_idx] = std::make_unique( + /*index=*/param_idx, workspace->size(), /*color=*/0); + slice_workspace_fake = BufferAllocation::Slice( + fake_allocations[param_idx].get(), 0, workspace->size()); } - if (absl::c_all_of(offset_buffer_indices, [&](auto offset_slices) { - return offset_slices == std::nullopt; - })) + if (absl::c_all_of(slice_instrs, [&](auto slice_instr) { + return slice_instr == nullptr; + })) { return absl::InternalError( "DynamicAddressComputationFusion expects at least one sliced " "operand/result"); + } - // Creating embedded GEMM thunk. bool deterministic_ops = ir_emitter_context.debug_options().xla_gpu_deterministic_ops(); @@ -344,38 +323,55 @@ absl::StatusOr EmitDynamicSlicedGemm( GemmConfig config, GemmConfig::For(static_cast(&custom_call))); - int64_t lhs_byte_size = - ShapeUtil::ByteSizeOf(custom_call.operand(kLHSOperandIndex)->shape()); - fake_allocations[kLHSOperandIndex] = std::make_unique( - /*index=*/kLHSOperandIndex, lhs_byte_size, /*color=*/0); - BufferAllocation::Slice slice_lhs_fake( - fake_allocations[kLHSOperandIndex].get(), 0, lhs_byte_size); - - int64_t rhs_byte_size = - ShapeUtil::ByteSizeOf(custom_call.operand(kRHSOperandIndex)->shape()); - fake_allocations[kRHSOperandIndex] = std::make_unique( - /*index=*/kRHSOperandIndex, rhs_byte_size, /*color=*/0); - BufferAllocation::Slice slice_rhs_fake( - fake_allocations[kRHSOperandIndex].get(), 0, rhs_byte_size); - - fake_allocations[2] = std::make_unique( - /*index=*/2, out_fake_byte_size, /*color=*/0); - BufferAllocation::Slice slice_out_fake(fake_allocations[2].get(), 0, - out_fake_byte_size); - ThunkSequence seq; - seq.emplace_back(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(&custom_call), std::move(config), - slice_lhs_fake, slice_rhs_fake, slice_out_fake, slice_workspace_fake, - deterministic_ops)); - - std::vector> arguments{ - lhs_slice, rhs_slice, output, workspace}; - - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(&custom_call), - std::make_unique(std::move(seq)), arguments, - std::move(fake_allocations), offset_buffer_indices, orig_shapes, - sliced_shapes, offset_byte_sizes); + std::unique_ptr thunk; + auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(&custom_call); + + if (absl::c_any_of(slice_instrs, [&](auto slice_instr) { + return DynCastOrNull(slice_instr) != + nullptr; + })) { + // Creating embedded GEMM thunk. + unsigned arg_idx = 0; + int64_t lhs_byte_size = + ShapeUtil::ByteSizeOf(custom_call.operand(arg_idx)->shape()); + fake_allocations[arg_idx] = std::make_unique( + /*index=*/arg_idx, lhs_byte_size, /*color=*/0); + BufferAllocation::Slice slice_lhs_fake(fake_allocations[arg_idx].get(), 0, + lhs_byte_size); + + arg_idx++; + int64_t rhs_byte_size = + ShapeUtil::ByteSizeOf(custom_call.operand(arg_idx)->shape()); + fake_allocations[arg_idx] = std::make_unique( + /*index=*/arg_idx, rhs_byte_size, /*color=*/0); + BufferAllocation::Slice slice_rhs_fake(fake_allocations[arg_idx].get(), 0, + rhs_byte_size); + + arg_idx++; + int64_t out_fake_byte_size = ShapeUtil::ByteSizeOf( + custom_call.shape().IsArray() ? custom_call.shape() + : custom_call.shape().tuple_shapes(0)); + fake_allocations[arg_idx] = std::make_unique( + /*index=*/arg_idx, out_fake_byte_size, /*color=*/0); + BufferAllocation::Slice slice_out_fake(fake_allocations[arg_idx].get(), 0, + out_fake_byte_size); + ThunkSequence seq; + seq.emplace_back(std::make_unique( + thunk_info, std::move(config), slice_lhs_fake, slice_rhs_fake, + slice_out_fake, slice_workspace_fake, deterministic_ops)); + + std::vector> arguments{ + lhs_slice, rhs_slice, output, workspace}; + + thunk = std::make_unique( + thunk_info, std::make_unique(std::move(seq)), arguments, + std::move(fake_allocations), offset_buffer_indices, orig_shapes, + sliced_shapes, offset_byte_sizes); + } else { + thunk = std::make_unique(thunk_info, std::move(config), + lhs_slice, rhs_slice, output, workspace, + deterministic_ops); + } FusionEmissionResult result; result.thunks.push_back(std::move(thunk)); @@ -633,8 +629,7 @@ absl::StatusOr DynamicAddressComputationFusion::Emit( const auto& custom_call = *static_cast( &maybe_custom_call_adaptor->instruction()); if (IsLegacyCublasMatmul(custom_call)) { - return EmitDynamicSlicedGemm(ir_emitter_context, adaptor, fusion, - custom_call); + return EmitGemm(ir_emitter_context, adaptor, fusion, custom_call); } return absl::UnimplementedError(absl::StrCat( From 7e9d002edd604b73cd4d4ee3437fbad3e8a6b969 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Sat, 30 Mar 2024 01:42:11 -0700 Subject: [PATCH 106/124] [xla:gpu][NFC] Remove unused constexprs PiperOrigin-RevId: 620446803 --- third_party/xla/xla/service/gpu/fusions/custom.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index 48d9ff6fc9cc6a..2397cd0bafbc99 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -70,9 +70,6 @@ namespace xla { namespace gpu { namespace { -constexpr unsigned kLHSOperandIndex = 0; -constexpr unsigned kRHSOperandIndex = 1; - constexpr unsigned kGEMMOutputBufferIndex = 0; constexpr unsigned kGEMMWorkspaceBufferIndex = 1; @@ -615,8 +612,6 @@ absl::StatusOr AddressComputationFusion::Emit( absl::StatusOr DynamicAddressComputationFusion::Emit( IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const { - // std::cerr << "TYB \n" - // << fusion.fused_instructions_computation()->ToString() << '\n'; const HloFusionAdaptor& adaptor = analysis_.fusion(); auto maybe_custom_call_adaptor = HloFindIf( adaptor.GetRoots(), adaptor, From c96e6c560c2d89f6f5965146b15c1840a1b1a617 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 30 Mar 2024 02:02:02 -0700 Subject: [PATCH 107/124] Update GraphDef version to 1817. PiperOrigin-RevId: 620449190 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 6352dcf15edd44..562fb1fe7d5136 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1816 // Updated: 2024/3/29 +#define TF_GRAPH_DEF_VERSION 1817 // Updated: 2024/3/30 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From c65e20552379c3d0e4f03391256546f2e729581b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 30 Mar 2024 02:02:03 -0700 Subject: [PATCH 108/124] compat: Update forward compatibility horizon to 2024-03-30 PiperOrigin-RevId: 620449193 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 149839623ef935..5de1cb41d21693 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 3, 29) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 3, 30) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From fd091d499a90bd73e9eedcd648bfe35229e97ce3 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Sat, 30 Mar 2024 02:39:15 -0700 Subject: [PATCH 109/124] [xla:gpu] Unify static and dynamic slice cases for AddressComputationFusionRewriter PiperOrigin-RevId: 620453556 --- .../address_computation_fusion_rewriter.cc | 49 ++++++++++--------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc index afb429b1942ab3..03317a8f09e166 100644 --- a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc @@ -138,8 +138,7 @@ bool IsAlignedSlice(const Shape& src_shape, const Shape& dst_shape, return true; } -UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr, - bool dynamic) { +UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { UseDefDataflowPaths sliced_operand_paths; auto fusion = HloFusionAdaptor::ForComputation(instr->parent()); @@ -164,34 +163,31 @@ UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr, auto maybe_slice_adaptor = HloFindIf({HloInstructionAdaptor(*operand)}, *fusion, [&](auto node) { const HloInstruction* cur = &node.instruction(); + // If the node is a match that has been processed, stop the traversal. if (processed_instrs.contains(cur)) return true; + maybe_sliced_operand_path.push_back(const_cast(cur)); - if (dynamic) { - if (const auto slice_instr = - DynCast(cur)) { - if (IsAlignedSlice(slice_instr->operand(0)->shape(), - slice_instr->shape(), nullptr)) { - slice_found = true; - return slice_found; - } - } - } else { - if (const auto slice_instr = DynCast(cur)) { - if (IsAlignedSlice(slice_instr->operand(0)->shape(), - slice_instr->shape(), slice_instr)) { - slice_found = true; - return slice_found; - } + + if (IsOpcodeAnyOf( + node)) { + if (IsAlignedSlice(cur->operand(0)->shape(), cur->shape(), + DynCast(cur))) { + slice_found = true; + return slice_found; } } + // TODO(vuson): lift the first restriction by considering fusing other // uses of the operand to reuse the address computation. Only worth it // if other uses are also custom calls though. return cur->user_count() > 1 || !IsNoOp(cur); }); + if (maybe_slice_adaptor == std::nullopt) continue; + const auto& maybe_slice_instr = maybe_slice_adaptor->instruction(); + if (slice_found || processed_instrs.contains(&maybe_slice_instr)) { // Even in the case of stopping at a match that has been processed, we // still need to add instructions encountered in the sliced operand path @@ -415,11 +411,11 @@ absl::StatusOr AddressComputationFusionRewriter::Run( for (HloInstruction* instr : computation->instructions()) { if (IsLegacyCublasMatmul(*instr) || (!dynamic && IsCustomCall(instr, platform_name_))) { - auto sliced_operand_paths = GetSlicedOperandPaths(instr, dynamic); + UseDefDataflowPaths sliced_operand_paths = + GetSlicedOperandPaths(instr); bool has_sliced_operand_paths = sliced_operand_paths.size() > 1; - DefUseDataflowPaths sliced_user_paths{}; - if (dynamic) sliced_user_paths = GetSlicedUserPaths(instr); + DefUseDataflowPaths sliced_user_paths = GetSlicedUserPaths(instr); bool has_sliced_user_paths = absl::c_any_of(sliced_user_paths, [&](auto& sliced_user_path) { return !sliced_user_path.empty(); @@ -464,9 +460,14 @@ absl::StatusOr AddressComputationFusionRewriter::Run( DataflowPathsView(sliced_user_paths_view), captures)); - TF_ASSIGN_OR_RETURN(HloInstruction * fusion, - CreateFusionInstruction(module, hero, captures, - fusion_body, dynamic)); + bool has_dynamic_slices = + absl::c_any_of(matched_instrs, [&](auto* instr) { + return DynCast(instr) != nullptr; + }); + TF_ASSIGN_OR_RETURN( + HloInstruction * fusion, + CreateFusionInstruction(module, hero, captures, fusion_body, + has_dynamic_slices)); HloComputation* parent = hero->parent(); if (fusion->shape().IsTuple()) { From 92b03bde2ec875d863d67fb5b6f9001f3bc7ab92 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Sat, 30 Mar 2024 12:09:23 -0700 Subject: [PATCH 110/124] [xla:gpu][NFC] Make lambdas static functions for better reusability PiperOrigin-RevId: 620511999 --- third_party/xla/xla/service/gpu/fusions/BUILD | 1 + .../xla/xla/service/gpu/fusions/custom.cc | 322 ++++++++++-------- 2 files changed, 179 insertions(+), 144 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 87995e26bf2ec4..dd386e6d354270 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -147,6 +147,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index 2397cd0bafbc99..999f7d967da5ad 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/AsmParser/AsmParser.h" // from @llvm-project @@ -137,6 +138,130 @@ absl::StatusOr GetSliceWithUpdatedOffsetAndSize( return BufferAllocation::Slice(orig_slice.allocation(), offset, size); } +absl::StatusOr GetOperandSlice( + const BufferAssignment& buffer_assignment, const HloFusionAdaptor& adaptor, + const HloInstruction& fusion_instr, const HloInstruction& start_instr, + std::vector& slice_instrs, const ShapeIndex& shape_idx, + unsigned arg_idx) { + auto slice_adaptor = + HloFindIf({HloInstructionAdaptor(start_instr)}, adaptor, [](auto node) { + return IsOpcodeAnyOf(node); + }); + if (slice_adaptor.has_value()) { + auto* slice_instr = + const_cast(&slice_adaptor->instruction()); + + if (!IsContiguousSlice(slice_instr->operand(0)->shape(), + slice_instr->shape())) { + return absl::InternalError( + "DynamicAddressComputationFusion only handles contiguous slices " + "currently"); + } + + slice_instrs[arg_idx] = slice_instr; + + const auto* param = Cast(slice_instr->operand(0)); + TF_ASSIGN_OR_RETURN( + BufferAllocation::Slice orig_slice, + GetAllocationSlice(buffer_assignment, + fusion_instr.operand(param->parameter_number()), + shape_idx)); + + if (auto* static_slice = DynCast(slice_instr)) { + // Update static slices. + const Shape& src_shape = static_slice->operand(0)->shape(); + const Shape& dst_shape = static_slice->shape(); + int64_t size = ShapeUtil::ByteSizeOf(dst_shape); + + // Given this slice + // f16[1,4,8]{2,1,0} slice(f16[2,8,8]{2,1,0}), + // slice={[1:2], [4:8], [0:8]} + // + // The offset of the slice should be: + // slice_starts(0) * 8 * 8 * sizeof(f16) + + // slice_starts(1) * 8 * sizeof(f16) + int64_t offset = orig_slice.offset(); + for (auto [start, stride] : + llvm::zip(static_slice->slice_starts(), + *ShapeUtil::ByteStrides(src_shape))) { + offset += start * stride; + } + + return BufferAllocation::Slice(orig_slice.allocation(), offset, size); + } + + return orig_slice; + } + + const auto* param = DynCast(&start_instr); + return GetAllocationSlice(buffer_assignment, + fusion_instr.operand(param->parameter_number()), + shape_idx); +} + +absl::Status CollectSliceInfo( + const BufferAssignment& buffer_assignment, + const HloInstruction& fusion_instr, + absl::Span slice_instrs, + std::vector>>& + offset_buffer_indices, + std::vector>& orig_shapes, + std::vector>& sliced_shapes, + std::vector>& offset_byte_sizes, unsigned arg_idx) { + auto* slice_instr = + DynCastOrNull(slice_instrs[arg_idx]); + if (slice_instr == nullptr) { + return absl::OkStatus(); + } + + std::vector offset_slices; + for (auto idx_op : slice_instr->index_operands()) { + const auto* param = Cast(idx_op); + TF_ASSIGN_OR_RETURN( + auto offset_slice, + GetAllocationSlice(buffer_assignment, + fusion_instr.operand(param->parameter_number()), + /*index=*/{})); + offset_slices.push_back(offset_slice); + } + offset_buffer_indices[arg_idx] = std::move(offset_slices); + orig_shapes[arg_idx] = slice_instr->operand(0)->shape(); + sliced_shapes[arg_idx] = DynCast(slice_instr) + ? slice_instr->shape() + : slice_instr->operand(1)->shape(); + offset_byte_sizes[arg_idx] = ShapeUtil::ByteSizeOfPrimitiveType( + slice_instr->index_operands().front()->shape().element_type()); + + return absl::OkStatus(); +} + +absl::StatusOr GetResultSlice( + const BufferAssignment& buffer_assignment, const HloFusionAdaptor& adaptor, + const HloInstruction& fusion_instr, const HloInstruction& start_instr, + std::vector& slice_instrs, const ShapeIndex& shape_idx, + unsigned arg_idx) { + auto slice_adaptor = HloFindIf( + {HloInstructionAdaptor(start_instr)}, adaptor, + [](auto node) { return node.opcode() == HloOpcode::kDynamicUpdateSlice; }, + false); + if (slice_adaptor.has_value()) { + auto* slice_instr = + const_cast(&slice_adaptor->instruction()); + slice_instrs[arg_idx] = slice_instr; + + if (!IsContiguousSlice(slice_instr->shape(), + Cast(slice_instr) + ->update() + ->shape())) { + return absl::InternalError( + "DynamicAddressComputationFusion only handles contiguous slices " + "currently"); + } + } + + return GetAllocationSlice(buffer_assignment, &fusion_instr, shape_idx); +} + absl::StatusOr EmitGemm( IrEmitterContext& ir_emitter_context, const HloFusionAdaptor& adaptor, const HloFusionInstruction& fusion, @@ -151,158 +276,67 @@ absl::StatusOr EmitGemm( std::vector> offset_byte_sizes(4, std::nullopt); std::vector slice_instrs(4, nullptr); - auto get_original_operand_slice = - [&](const HloInstruction* start, const ShapeIndex& index, - unsigned param_idx) -> absl::StatusOr { - auto slice_adaptor = - HloFindIf({HloInstructionAdaptor(*start)}, adaptor, [](auto node) { - return IsOpcodeAnyOf( - node); - }); - if (slice_adaptor.has_value()) { - auto* slice_instr = - const_cast(&slice_adaptor->instruction()); - - if (!IsContiguousSlice(slice_instr->operand(0)->shape(), - slice_instr->shape())) { - return absl::InternalError( - "DynamicAddressComputationFusion only handles contiguous slices " - "currently"); - } - - slice_instrs[param_idx] = slice_instr; - - const auto* param = - Cast(slice_instr->operand(0)); - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice orig_slice, - GetAllocationSlice(buffer_assignment, - fusion.operand(param->parameter_number()), index)); - - if (auto* static_slice = DynCast(slice_instr)) { - // Update static slices. - const Shape& src_shape = static_slice->operand(0)->shape(); - const Shape& dst_shape = static_slice->shape(); - int64_t size = ShapeUtil::ByteSizeOf(dst_shape); - - // Given this slice - // f16[1,4,8]{2,1,0} slice(f16[2,8,8]{2,1,0}), - // slice={[1:2], [4:8], [0:8]} - // - // The offset of the slice should be: - // slice_starts(0) * 8 * 8 * sizeof(f16) + - // slice_starts(1) * 8 * sizeof(f16) - int64_t offset = orig_slice.offset(); - for (auto [start, stride] : - llvm::zip(static_slice->slice_starts(), - *ShapeUtil::ByteStrides(src_shape))) { - offset += start * stride; - } - - return BufferAllocation::Slice(orig_slice.allocation(), offset, size); - } - - return orig_slice; - } - const auto* param = DynCast(start); - return GetAllocationSlice(buffer_assignment, - fusion.operand(param->parameter_number()), index); - }; - - auto collect_slice_info = [&](unsigned idx) { - auto* slice_instr = - DynCastOrNull(slice_instrs[idx]); - if (slice_instr == nullptr) { - return; - } - - std::vector offset_slices; - for (auto idx_op : slice_instr->index_operands()) { - const auto* param = Cast(idx_op); - offset_slices.push_back( - GetAllocationSlice(buffer_assignment, - fusion.operand(param->parameter_number()), - /*index=*/{}) - .value()); - } - offset_buffer_indices[idx] = std::move(offset_slices); - orig_shapes[idx] = slice_instr->operand(0)->shape(); - sliced_shapes[idx] = DynCast(slice_instr) - ? slice_instr->shape() - : slice_instr->operand(1)->shape(); - offset_byte_sizes[idx] = ShapeUtil::ByteSizeOfPrimitiveType( - slice_instr->index_operands().front()->shape().element_type()); - }; - - unsigned param_idx = 0; + unsigned arg_idx = 0; TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_slice, - get_original_operand_slice(custom_call.operand(param_idx), - /*index=*/{}, param_idx)); - collect_slice_info(param_idx++); + GetOperandSlice(buffer_assignment, adaptor, fusion, + *custom_call.operand(arg_idx), + slice_instrs, /*shape_idx=*/{}, arg_idx)); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion, absl::Span(slice_instrs), + offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes, + arg_idx++)); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_slice, - get_original_operand_slice(custom_call.operand(param_idx), - /*index=*/{}, param_idx)); - collect_slice_info(param_idx++); + GetOperandSlice(buffer_assignment, adaptor, fusion, + *custom_call.operand(arg_idx), + slice_instrs, /*shape_idx=*/{}, arg_idx)); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion, absl::Span(slice_instrs), + offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes, + arg_idx++)); BufferAllocation::Slice output; std::optional workspace = std::nullopt; std::optional slice_workspace_fake = std::nullopt; - auto get_original_result_slice = - [&](const HloInstruction* start, const ShapeIndex& index, - unsigned param_idx) -> absl::StatusOr { - auto slice_adaptor = HloFindIf( - {HloInstructionAdaptor(*start)}, adaptor, - [](auto node) { - return node.opcode() == HloOpcode::kDynamicUpdateSlice; - }, - false); - if (slice_adaptor.has_value()) { - auto* slice_instr = - const_cast(&slice_adaptor->instruction()); - slice_instrs[param_idx] = slice_instr; - - if (!IsContiguousSlice(slice_instr->shape(), - Cast(slice_instr) - ->update() - ->shape())) { - return absl::InternalError( - "DynamicAddressComputationFusion only handles contiguous slices " - "currently"); - } - } - - return GetAllocationSlice(buffer_assignment, &fusion, index); - }; - // Handling cases where multiple operands share the same buffer, with // different offset by creating new fake allocations so each operand will have // a different buffer index. The slices can thus always start at offset 0. // AddressComputationThunk will take care of the offset adjustment. std::vector> fake_allocations(4); if (fusion.shape().IsArray()) { - TF_ASSIGN_OR_RETURN(output, get_original_result_slice( - &custom_call, /*index=*/{}, param_idx)); - collect_slice_info(param_idx); + TF_ASSIGN_OR_RETURN( + output, GetResultSlice(buffer_assignment, adaptor, fusion, custom_call, + slice_instrs, /*shape_idx=*/{}, arg_idx)); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion, absl::Span(slice_instrs), + offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes, + arg_idx)); } else { TF_ASSIGN_OR_RETURN( output, - get_original_result_slice( - &custom_call, /*index=*/{kGEMMOutputBufferIndex}, param_idx)); - collect_slice_info(param_idx++); + GetResultSlice(buffer_assignment, adaptor, fusion, custom_call, + slice_instrs, /*shape_idx=*/{kGEMMOutputBufferIndex}, + arg_idx)); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion, absl::Span(slice_instrs), + offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes, + arg_idx++)); // TODO(vuson): If we want to support slices of workspace, we'd need to // start `HloFindIf` with `get-tuple-element` with the right index. TF_ASSIGN_OR_RETURN( workspace, GetAllocationSlice(buffer_assignment, &fusion, /*index=*/{kGEMMWorkspaceBufferIndex})); - collect_slice_info(param_idx); - fake_allocations[param_idx] = std::make_unique( - /*index=*/param_idx, workspace->size(), /*color=*/0); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion, absl::Span(slice_instrs), + offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes, + arg_idx)); + fake_allocations[arg_idx] = std::make_unique( + /*index=*/arg_idx, workspace->size(), /*color=*/0); slice_workspace_fake = BufferAllocation::Slice( - fake_allocations[param_idx].get(), 0, workspace->size()); + fake_allocations[arg_idx].get(), 0, workspace->size()); } if (absl::c_all_of(slice_instrs, [&](auto slice_instr) { @@ -328,30 +362,30 @@ absl::StatusOr EmitGemm( nullptr; })) { // Creating embedded GEMM thunk. - unsigned arg_idx = 0; + unsigned fake_arg_idx = 0; int64_t lhs_byte_size = - ShapeUtil::ByteSizeOf(custom_call.operand(arg_idx)->shape()); - fake_allocations[arg_idx] = std::make_unique( - /*index=*/arg_idx, lhs_byte_size, /*color=*/0); - BufferAllocation::Slice slice_lhs_fake(fake_allocations[arg_idx].get(), 0, - lhs_byte_size); + ShapeUtil::ByteSizeOf(custom_call.operand(fake_arg_idx)->shape()); + fake_allocations[fake_arg_idx] = std::make_unique( + /*index=*/fake_arg_idx, lhs_byte_size, /*color=*/0); + BufferAllocation::Slice slice_lhs_fake(fake_allocations[fake_arg_idx].get(), + 0, lhs_byte_size); - arg_idx++; + fake_arg_idx++; int64_t rhs_byte_size = - ShapeUtil::ByteSizeOf(custom_call.operand(arg_idx)->shape()); - fake_allocations[arg_idx] = std::make_unique( - /*index=*/arg_idx, rhs_byte_size, /*color=*/0); - BufferAllocation::Slice slice_rhs_fake(fake_allocations[arg_idx].get(), 0, - rhs_byte_size); + ShapeUtil::ByteSizeOf(custom_call.operand(fake_arg_idx)->shape()); + fake_allocations[fake_arg_idx] = std::make_unique( + /*index=*/fake_arg_idx, rhs_byte_size, /*color=*/0); + BufferAllocation::Slice slice_rhs_fake(fake_allocations[fake_arg_idx].get(), + 0, rhs_byte_size); - arg_idx++; + fake_arg_idx++; int64_t out_fake_byte_size = ShapeUtil::ByteSizeOf( custom_call.shape().IsArray() ? custom_call.shape() : custom_call.shape().tuple_shapes(0)); - fake_allocations[arg_idx] = std::make_unique( - /*index=*/arg_idx, out_fake_byte_size, /*color=*/0); - BufferAllocation::Slice slice_out_fake(fake_allocations[arg_idx].get(), 0, - out_fake_byte_size); + fake_allocations[fake_arg_idx] = std::make_unique( + /*index=*/fake_arg_idx, out_fake_byte_size, /*color=*/0); + BufferAllocation::Slice slice_out_fake(fake_allocations[fake_arg_idx].get(), + 0, out_fake_byte_size); ThunkSequence seq; seq.emplace_back(std::make_unique( thunk_info, std::move(config), slice_lhs_fake, slice_rhs_fake, From b2e993e84f8c7d03281ddf78dd8817e468bad83b Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Sat, 30 Mar 2024 15:05:27 -0700 Subject: [PATCH 111/124] [xla:gpu][NFC] Explicitly rewrite AddressComputationFusion in custom call tests AddressComputationFusionRewriter is now part of `RunHloPasses`, we need to explicitly call it to transform the HLO in order to keep tests meaningful. PiperOrigin-RevId: 620529073 --- third_party/xla/xla/service/gpu/fusions/BUILD | 1 + .../address_computation_fusion_test.cc | 28 +++++++++++++++---- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index dd386e6d354270..2fbbcd17d47da1 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -173,6 +173,7 @@ xla_test( "//xla/service:custom_call_target_registry", "//xla/service:executable", "//xla/service:hlo_module_config", + "//xla/service/gpu:address_computation_fusion_rewriter", "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor/gpu:gpu_types_header", diff --git a/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc b/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc index 341cf154394a87..b3832da27130cf 100644 --- a/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/ffi/ffi_api.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/custom_call_target_registry.h" +#include "xla/service/gpu/address_computation_fusion_rewriter.h" #include "xla/service/hlo_module_config.h" #include "xla/service/service_executable_run_options.h" #include "xla/shape.h" @@ -887,14 +888,14 @@ TEST_F(AddressComputationFusionTest, CustomCallSimple) { TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); - hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), - error_spec, - /*run_hlo_passes=*/false)); + error_spec, /*run_hlo_passes=*/false)); } static absl::Status SubBuffers(se::Stream* stream, ffi::BufferBase src0, @@ -993,9 +994,12 @@ TEST_F(AddressComputationFusionTest, CustomCallWithTuple) { TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), - error_spec, - /*run_hlo_passes=*/false)); + error_spec, /*run_hlo_passes=*/false)); } static absl::Status NoOp(se::Stream* stream, ffi::BufferBase operand) { @@ -1039,6 +1043,10 @@ TEST_F(AddressComputationFusionTest, NilTuple) { TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), error_spec, /*run_hlo_passes=*/false)); @@ -1079,6 +1087,10 @@ TEST_F(AddressComputationFusionTest, CustomCallLegacyAPI) { TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), error_spec, /*run_hlo_passes=*/false)); @@ -1113,6 +1125,10 @@ TEST_F(AddressComputationFusionTest, NilTupleLegacyAPI) { TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), error_spec, /*run_hlo_passes=*/false)); From e6bb7838413df021df74b9465d63720e1380d2f1 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Sat, 30 Mar 2024 20:55:31 -0700 Subject: [PATCH 112/124] [xla:gpu][NFC] Use the same helpers to get slices for GEMM and generic custom call emissions PiperOrigin-RevId: 620567579 --- .../xla/xla/service/gpu/fusions/custom.cc | 88 ++++++------------- 1 file changed, 28 insertions(+), 60 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index 999f7d967da5ad..10eb3942924003 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -85,59 +85,6 @@ absl::StatusOr> BuildCustomKernelThunkForFusion( &fusion, std::move(custom_kernel), std::move(kernel_arguments.args())); } -absl::StatusOr GetSliceWithUpdatedOffsetAndSize( - const BufferAssignment& buffer_assignment, const HloFusionAdaptor& fusion, - const HloInstruction& fusion_instr, const HloInstruction& start, - const ShapeIndex& index) { - if (const auto* param = DynCast(&start)) { - return GetAllocationSlice(buffer_assignment, - fusion_instr.operand(param->parameter_number()), - index); - } - - auto slice_adaptor = - HloFindIf({HloInstructionAdaptor(start)}, fusion, - [](auto node) { return node.opcode() == HloOpcode::kSlice; }); - if (!slice_adaptor.has_value()) { - return absl::InternalError( - "AddressComputationFusion expects at least one sliced operand"); - } - - const auto& slice_instr = - *static_cast(&slice_adaptor->instruction()); - - if (!IsContiguousSlice(slice_instr)) { - return absl::InternalError( - "AddressComputationFusion only handles contiguous slices currently"); - } - - const Shape& src_shape = slice_instr.operand(0)->shape(); - const Shape& dst_shape = slice_instr.shape(); - int64_t size = ShapeUtil::ByteSizeOf(dst_shape); - - const auto* param = Cast(slice_instr.operand(0)); - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice orig_slice, - GetAllocationSlice(buffer_assignment, - fusion_instr.operand(param->parameter_number()), - index)); - - // Given this slice - // f16[1,4,8]{2,1,0} slice(f16[2,8,8]{2,1,0}), - // slice={[1:2], [4:8], [0:8]} - // - // The offset of the slice should be: - // slice_starts(0) * 8 * 8 * sizeof(f16) + - // slice_starts(1) * 8 * sizeof(f16) - int64_t offset = orig_slice.offset(); - for (auto [start, stride] : llvm::zip(slice_instr.slice_starts(), - *ShapeUtil::ByteStrides(src_shape))) { - offset += start * stride; - } - - return BufferAllocation::Slice(orig_slice.allocation(), offset, size); -} - absl::StatusOr GetOperandSlice( const BufferAssignment& buffer_assignment, const HloFusionAdaptor& adaptor, const HloInstruction& fusion_instr, const HloInstruction& start_instr, @@ -343,7 +290,7 @@ absl::StatusOr EmitGemm( return slice_instr == nullptr; })) { return absl::InternalError( - "DynamicAddressComputationFusion expects at least one sliced " + "AddressComputationFusion expects at least one sliced " "operand/result"); } @@ -441,21 +388,31 @@ absl::StatusOr EmitCustomCall( using Slices = std::vector>; + int64_t num_args = ShapeUtil::GetLeafCount(custom_call.shape()); + absl::c_for_each(custom_call.operands(), [&](auto* operand) { + num_args += ShapeUtil::GetLeafCount(operand->shape()); + }); + + std::vector slice_instrs(num_args, nullptr); + Slices operands; - // TODO(vuson): add test with custom call with token-typed operands + unsigned arg_idx = 0; + // TODO(vuson): add test for custom call with token-typed operands for (auto* operand : custom_call.operands()) { TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( operand->shape(), [&](const Shape& subshape, const ShapeIndex& index) { if (subshape.IsToken()) { + arg_idx++; operands.push_back(std::nullopt); return absl::OkStatus(); } if (!subshape.IsArray()) { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(auto slice, GetSliceWithUpdatedOffsetAndSize( - buffer_assignment, adaptor, - fusion, *operand, index)); + TF_ASSIGN_OR_RETURN( + auto slice, + GetOperandSlice(buffer_assignment, adaptor, fusion, *operand, + slice_instrs, /*shape_idx=*/index, arg_idx++)); operands.push_back(CustomCallThunk::Slice{slice, subshape}); return absl::OkStatus(); })); @@ -463,8 +420,9 @@ absl::StatusOr EmitCustomCall( Slices results; TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - fusion.shape(), [&](const Shape& subshape, const ShapeIndex& index) { + custom_call.shape(), [&](const Shape& subshape, const ShapeIndex& index) { if (subshape.IsToken()) { + arg_idx++; results.push_back(std::nullopt); return absl::OkStatus(); } @@ -472,11 +430,21 @@ absl::StatusOr EmitCustomCall( return absl::OkStatus(); } TF_ASSIGN_OR_RETURN( - auto slice, GetAllocationSlice(buffer_assignment, &fusion, index)); + auto slice, + GetResultSlice(buffer_assignment, adaptor, fusion, custom_call, + slice_instrs, /*shape_idx=*/index, arg_idx++)); results.push_back(CustomCallThunk::Slice{slice, subshape}); return absl::OkStatus(); })); + if (absl::c_all_of(slice_instrs, [&](auto slice_instr) { + return slice_instr == nullptr; + })) { + return absl::InternalError( + "AddressComputationFusion expects at least one sliced " + "operand/result"); + } + // For legacy custom calls we convert all API versions into the latest // status-returning one and pass backend config as an opaque string. CustomCallThunk::CustomCallTarget custom_call_target; From b4a779f77fab3d8ef8caab1f12cf3b7c4509dabf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 30 Mar 2024 20:58:08 -0700 Subject: [PATCH 113/124] Enable SparseCore threads in TpuLayoutAssignment PiperOrigin-RevId: 620567851 --- .../xla/xla/hlo/utils/hlo_sharding_util.cc | 41 +++++++++++++++++- .../xla/xla/hlo/utils/hlo_sharding_util.h | 7 +++ .../xla/hlo/utils/hlo_sharding_util_test.cc | 31 +++++++++++++ .../xla/xla/service/layout_assignment.cc | 43 ++++++++++--------- .../xla/xla/service/layout_assignment.h | 37 +++++++++++++++- 5 files changed, 135 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index 0661c620c4d5f5..0177abc48f506f 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -3063,13 +3063,52 @@ Shape UntileLeafShape(const HloSharding& sharding, const Shape& shape) { if (sharding.IsTileMaximal() || sharding.IsManual() || sharding.IsUnknown()) { return shape; } + if (!shape.IsArray()) { + return shape; + } Shape result_shape = shape; - for (int64_t i = 0; i < sharding.TiledDataRank(); ++i) { + // sharding.TiledDataRank() == i < shape.rank() is not always true? + for (int64_t i = 0; i < sharding.TiledDataRank() && i < shape.rank(); ++i) { result_shape.set_dimensions( i, shape.dimensions(i) * sharding.tile_assignment().dim(i)); } return result_shape; } +Shape TileShape(const HloSharding& sharding, const Shape& shape) { + if (!sharding.IsTuple()) { + return TileLeafShape(sharding, shape); + } + Shape result_shape = shape; + ShapeUtil::ForEachMutableSubshape( + &result_shape, + [&shape, &sharding](Shape* subshape, const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(shape, index)) { + return; + } + const HloSharding& subshape_sharding = + sharding.GetSubSharding(shape, index); + *subshape = TileLeafShape(subshape_sharding, *subshape); + }); + + return result_shape; +} + +Shape TileLeafShape(const HloSharding& sharding, const Shape& shape) { + if (sharding.IsTileMaximal() || sharding.IsManual() || sharding.IsUnknown()) { + return shape; + } + if (!shape.IsArray()) { + return shape; + } + Shape result_shape = shape; + for (int64_t i = 0; i < sharding.TiledDataRank() && i < shape.rank(); ++i) { + CHECK_EQ(shape.dimensions(i) % sharding.tile_assignment().dim(i), 0); + result_shape.set_dimensions( + i, shape.dimensions(i) / sharding.tile_assignment().dim(i)); + } + return result_shape; +} + } // namespace hlo_sharding_util } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h index bbf074c408a4a7..8671ebfe2554f2 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h @@ -479,6 +479,13 @@ Shape UntileShape(const HloSharding& sharding, const Shape& shape); // REQUIRES: !sharding.IsTuple() Shape UntileLeafShape(const HloSharding& sharding, const Shape& shape); +// Returns the tiled shape. +Shape TileShape(const HloSharding& sharding, const Shape& shape); + +// Returns the tiled shape. +// REQUIRES: !sharding.IsTuple() +Shape TileLeafShape(const HloSharding& sharding, const Shape& shape); + } // namespace hlo_sharding_util } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc index bcf9d1690cda3c..ad042361a7bf27 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc @@ -823,6 +823,37 @@ TEST(HloShardingUtilTest, IsSortOperandShardingMovableSortDimUnsharded) { iota.set_sharding(HloSharding::IotaTile({1, 2})); EXPECT_FALSE(IsSortOperandShardingMovable(&iota, 0)); } + +TEST(HloShardingUtilTest, TileShape) { + HloSharding sharding = HloSharding::Tile(TileAssignment({4, 1})); + Shape shape_0 = ShapeUtil::MakeShape(F32, {80, 128}); + auto tile_shape_0 = hlo_sharding_util::TileShape(sharding, shape_0); + auto expected_shape_0 = ShapeUtil::MakeShape(F32, {20, 128}); + EXPECT_EQ(tile_shape_0, expected_shape_0); + Shape shape_1 = ShapeUtil::MakeShape(F32, {40, 128}); + auto tile_shape_1 = hlo_sharding_util::TileShape(sharding, shape_1); + auto expected_shape_1 = ShapeUtil::MakeShape(F32, {10, 128}); + EXPECT_EQ(tile_shape_1, expected_shape_1); + const Shape tuple = ShapeUtil::MakeTupleShape({tile_shape_0, tile_shape_1}); + EXPECT_EQ(hlo_sharding_util::TileShape(sharding, tuple), + ShapeUtil::MakeTupleShape({expected_shape_0, expected_shape_1})); +} + +TEST(HloShardingUtilTest, UntileShape) { + HloSharding sharding = HloSharding::Tile(TileAssignment({4, 1})); + Shape shape_0 = ShapeUtil::MakeShape(F32, {80, 128}); + auto tile_shape_0 = hlo_sharding_util::UntileShape(sharding, shape_0); + auto expected_shape_0 = ShapeUtil::MakeShape(F32, {320, 128}); + EXPECT_EQ(tile_shape_0, expected_shape_0); + Shape shape_1 = ShapeUtil::MakeShape(F32, {40, 128}); + auto tile_shape_1 = hlo_sharding_util::UntileShape(sharding, shape_1); + auto expected_shape_1 = ShapeUtil::MakeShape(F32, {160, 128}); + EXPECT_EQ(tile_shape_1, expected_shape_1); + const Shape tuple = ShapeUtil::MakeTupleShape({tile_shape_0, tile_shape_1}); + EXPECT_EQ(hlo_sharding_util::UntileShape(sharding, tuple), + ShapeUtil::MakeTupleShape({expected_shape_0, expected_shape_1})); +} + } // namespace } // namespace hlo_sharding_util } // namespace xla diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index 8ca2296875e62b..a6a3b9ebf94480 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -806,14 +806,17 @@ Status LayoutAssignment::AddMandatoryConstraints( const ComputationLayout& called_computation_layout = FindOrDie(computation_layouts_, instruction->to_apply()) ->computation_layout(); - TF_RETURN_IF_ERROR(SetInstructionLayout( - called_computation_layout.result_layout().shape(), instruction)); + auto result_shape = UnShardedShape( + instruction, called_computation_layout.result_layout().shape(), -1); + TF_RETURN_IF_ERROR(SetInstructionLayout(result_shape, instruction)); TF_RET_CHECK(instruction->operand_count() == called_computation_layout.parameter_count()); for (int64_t i = 0; i < instruction->operand_count(); ++i) { - TF_RETURN_IF_ERROR(SetOperandLayout( - called_computation_layout.parameter_layout(i).shape(), instruction, - i, /*mandatory=*/true, /*dfs=*/true)); + auto operand_shape = UnShardedShape( + instruction, called_computation_layout.parameter_layout(i).shape(), + i); + TF_RETURN_IF_ERROR(SetOperandLayout(operand_shape, instruction, i, + /*mandatory=*/true, /*dfs=*/true)); } } else if (instruction->opcode() == HloOpcode::kWhile && computation_layouts_.find(instruction->while_body()) != @@ -963,22 +966,6 @@ bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) { return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout()); } -// The operands of a call must match the layouts of parameters in the -// ComputationLayout, and the call instruction itself must match the result -// layout in the ComputationLayout. -Status CheckCallLayout(HloInstruction* call, - const ComputationLayout& computation_layout) { - HloComputation* computation = call->to_apply(); - TF_RET_CHECK(computation->num_parameters() == call->operand_count()); - for (int64_t i = 0; i < computation->num_parameters(); ++i) { - TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape( - call->operand(i)->shape(), /*minor_to_major_only=*/true)); - } - TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape( - call->shape(), /*minor_to_major_only=*/true)); - return OkStatus(); -} - // Operands of layout-constrained custom calls must match the expected // constrained layouts. Status CheckCustomCallLayout(HloInstruction* instruction) { @@ -1126,6 +1113,20 @@ Status CheckBroadcastLayout(HloInstruction* broadcast) { } // namespace +Status LayoutAssignment::CheckCallLayout( + HloInstruction* call, const ComputationLayout& computation_layout) { + HloComputation* computation = call->to_apply(); + TF_RET_CHECK(computation->num_parameters() == call->operand_count()); + for (int64_t i = 0; i < computation->num_parameters(); ++i) { + TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape( + ShardedShape(call, call->operand(i)->shape(), i), + /*minor_to_major_only=*/true)); + } + TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape( + ShardedShape(call, call->shape(), -1), /*minor_to_major_only=*/true)); + return OkStatus(); +} + absl::StatusOr LayoutAssignment::CreateCopyWithNewLayout( const Shape& shape_with_layout, HloInstruction* instruction) { TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout)); diff --git a/third_party/xla/xla/service/layout_assignment.h b/third_party/xla/xla/service/layout_assignment.h index b855b315acd75f..22586493917105 100644 --- a/third_party/xla/xla/service/layout_assignment.h +++ b/third_party/xla/xla/service/layout_assignment.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ #define XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ +#include #include #include #include @@ -28,19 +29,26 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/container/node_hash_map.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" #include "xla/layout_util.h" +#include "xla/map_util.h" #include "xla/service/call_graph.h" #include "xla/service/computation_layout.h" #include "xla/service/hlo_pass_interface.h" #include "xla/service/logical_buffer.h" #include "xla/service/tuple_points_to_analysis.h" +#include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" +#include "xla/status.h" #include "xla/statusor.h" #include "xla/types.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/status.h" @@ -117,7 +125,7 @@ class OperandLayoutConstraint : public LayoutConstraint { const ShapeLayout& shape_layout() const { return shape_layout_[0]; } const HloInstruction* instruction() const { return instruction_; } - const int64_t operand_no() const { return operand_no_; } + int64_t operand_no() const { return operand_no_; } const HloInstruction* operand() const { return instruction_->operand(operand_no_); } @@ -196,7 +204,7 @@ class ComputationLayoutConstraint : public LayoutConstraint { class ChannelLayoutConstraints { public: // Construct an empty constraint set. - ChannelLayoutConstraints() {} + ChannelLayoutConstraints() = default; // Returns true if channel_id has a layout constraint. bool IsChannelConstrained(int64_t channel_id) const { @@ -516,6 +524,31 @@ class LayoutAssignment : public HloModulePass { virtual bool InstructionCanChangeLayoutInstance( const HloInstruction* instruction); + // The shapes in caller can be different from the shapes in callee. For + // example, a shape (1024, 128) of an array can be distributed to four threads + // so the shape for each thread is (256, 128). When verifying the callee's + // shapes based on the caller, we should use this function to compute the + // expected shape. The param_id should be the parameter id of the shape or -1 + // for the result output or unknown. + virtual Shape ShardedShape(const HloInstruction* call, const Shape& shape, + int param_id) { + return shape; + } + // When verifying the caller's shapes based on the callee, we should use this + // function to compute the expected shape. + // The param_id should be the parameter id of the shape or -1 for the result + // output or unknown. + virtual Shape UnShardedShape(const HloInstruction* call, const Shape& shape, + int param_id) { + return shape; + } + + // The operands of a call must match the layouts of parameters in the + // ComputationLayout, and the call instruction itself must match the result + // layout in the ComputationLayout. + Status CheckCallLayout(HloInstruction* call, + const ComputationLayout& computation_layout); + private: // Initializes the layout assignment object for a new Run() call. Status Init(HloModule* module); From 83f815b5c0bf5d718361902b058e22d98fb454a1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 31 Mar 2024 02:02:36 -0700 Subject: [PATCH 114/124] compat: Update forward compatibility horizon to 2024-03-31 PiperOrigin-RevId: 620606935 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 5de1cb41d21693..8ca0d0f2d923c6 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 3, 30) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 3, 31) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From c1781fcc12bbf272385076ef1cea20f9f1a822d4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 31 Mar 2024 02:03:10 -0700 Subject: [PATCH 115/124] Update GraphDef version to 1818. PiperOrigin-RevId: 620607049 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 562fb1fe7d5136..a71104785c2b9d 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1817 // Updated: 2024/3/30 +#define TF_GRAPH_DEF_VERSION 1818 // Updated: 2024/3/31 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From e8cd9fd702cedf5b67824d4912d01325a9bda3c1 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Sun, 31 Mar 2024 12:52:56 -0700 Subject: [PATCH 116/124] [xla:gpu] Generic custom call emission for DynamicAddressComputationFusion emitter PiperOrigin-RevId: 620681261 --- .../address_computation_fusion_rewriter.cc | 7 +- .../address_computation_fusion_test.cc | 318 ++++++++++++++++-- .../xla/xla/service/gpu/fusions/custom.cc | 165 +++++++-- 3 files changed, 432 insertions(+), 58 deletions(-) diff --git a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc index 03317a8f09e166..db5ece796ffe79 100644 --- a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc @@ -160,6 +160,11 @@ UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { if (aliased_operands.contains(instr->operand_index(operand))) continue; UseDefDataflowPath maybe_sliced_operand_path; bool slice_found = false; + // TODO: currently HloFindIf exits upon encountering the first node that + // matches. This works well if each operand only has 1 data flow (i.e. only + // flows through unary op). We might want to keep finding until the queue is + // empty: if the operand is a tuple, it might have different data flows + // (i.e. 1 for each element). auto maybe_slice_adaptor = HloFindIf({HloInstructionAdaptor(*operand)}, *fusion, [&](auto node) { const HloInstruction* cur = &node.instruction(); @@ -410,7 +415,7 @@ absl::StatusOr AddressComputationFusionRewriter::Run( if (computation->IsFusionComputation()) continue; for (HloInstruction* instr : computation->instructions()) { if (IsLegacyCublasMatmul(*instr) || - (!dynamic && IsCustomCall(instr, platform_name_))) { + (IsCustomCall(instr, platform_name_))) { UseDefDataflowPaths sliced_operand_paths = GetSlicedOperandPaths(instr); bool has_sliced_operand_paths = sliced_operand_paths.size() > 1; diff --git a/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc b/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc index b3832da27130cf..07733a13aa04e4 100644 --- a/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc @@ -901,20 +901,27 @@ TEST_F(AddressComputationFusionTest, CustomCallSimple) { static absl::Status SubBuffers(se::Stream* stream, ffi::BufferBase src0, ffi::BufferBase src1, ffi::BufferBase src2, ffi::BufferBase src3, ffi::BufferBase src4, - ffi::BufferBase dst0, ffi::BufferBase dst1, - ffi::BufferBase dst2, ffi::BufferBase dst3, - ffi::BufferBase dst4) { + ffi::BufferBase src5, ffi::BufferBase src6, + ffi::BufferBase src7, ffi::BufferBase dst0, + ffi::BufferBase dst1, ffi::BufferBase dst2, + ffi::BufferBase dst3, ffi::BufferBase dst4, + ffi::BufferBase dst5, ffi::BufferBase dst6) { // src0: param 0 at tuple index {0}, shape f32[128] // src1: param 0 at tuple index {1}, shape f32[256] // src2: param 1 at tuple index {0}, shape f32[1024] // src3: param 1 at tuple index {1}, shape f32[8] // src4: param 2, shape f32[4,8] + // src5: param 3 at tuple index {0, 0}, shape f32[32] + // src6: param 3 at tuple index {0, 1}, shape f32[64] + // src7: param 3 at tuple index {1}, shape f32[3,128] // // dst0: result at tuple index {0}, shape f32[8] // dst1: result at tuple index {1, 0}, shape f32[128] // dst2: result at tuple index {1, 1}, shape f32[256] // dst3: result at tuple index {2}, shape f32[1024] // dst4: result at tuple index {3}, shape f32[4,8] + // dst5: result at tuple index {4}, shape f32[3,128] + // dst6: result at tuple index {5}, shape f32[96] TF_RETURN_IF_ERROR( stream->MemcpyD2D(&dst0.data, src3.data, 8 * sizeof(float))); @@ -926,6 +933,13 @@ static absl::Status SubBuffers(se::Stream* stream, ffi::BufferBase src0, stream->MemcpyD2D(&dst3.data, src2.data, 1024 * sizeof(float))); TF_RETURN_IF_ERROR( stream->MemcpyD2D(&dst4.data, src4.data, 4 * 8 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst5.data, src7.data, 3 * 128 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst6.data, src6.data, 64 * sizeof(float))); + stream_executor::DeviceMemoryBase slice = + dst6.data.GetByteSlice(64 * sizeof(float), 32 * sizeof(float)); + TF_RETURN_IF_ERROR(stream->MemcpyD2D(&slice, src6.data, 32 * sizeof(float))); return absl::OkStatus(); } @@ -937,52 +951,71 @@ XLA_FFI_DEFINE_HANDLER(kSubBuffers, SubBuffers, .Arg() // src2 .Arg() // src3 .Arg() // src4 + .Arg() // src5 + .Arg() // src6 + .Arg() // src7 .Arg() // dst0 .Arg() // dst1 .Arg() // dst2 .Arg() // dst3 .Arg() // dst4 + .Arg() // dst5 + .Arg() // dst6 ); XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$subbuffers", PLATFORM, kSubBuffers); TEST_F(AddressComputationFusionTest, CustomCallWithTuple) { XlaBuilder b(TestName()); - CustomCall(&b, "__xla_test$$subbuffers", /*operands=*/ - { - Tuple(&b, - { - Broadcast(ConstantR0WithType(&b, F32, 1), {128}), - Broadcast(ConstantR0WithType(&b, F32, 2), {256}), - }), - Tuple(&b, - { - Broadcast(ConstantR0WithType(&b, F32, 3), {1024}), - Broadcast(ConstantR0WithType(&b, F32, 4), {8}), - }), - Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}), - {0, 0}, {4, 8}, {1, 1}), - }, - ShapeUtil::MakeTupleShape({ - ShapeUtil::MakeShape(F32, {8}), - ShapeUtil::MakeTupleShape({ - ShapeUtil::MakeShape(F32, {128}), - ShapeUtil::MakeShape(F32, {256}), - }), - ShapeUtil::MakeShape(F32, {1024}), - ShapeUtil::MakeShape(F32, {4, 8}), - }), - /*opaque=*/"", - /*has_side_effect=*/false, - /*output_operand_aliasing=*/{}, /*literal=*/nullptr, - /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, - /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + CustomCall( + &b, "__xla_test$$subbuffers", /*operands=*/ + { + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 1), {128}), + Broadcast(ConstantR0WithType(&b, F32, 2), {256}), + }), + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 3), {1024}), + Broadcast(ConstantR0WithType(&b, F32, 4), {8}), + }), + Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}), {0, 0}, + {4, 8}, {1, 1}), + Tuple(&b, + { + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 6), {32}), + Broadcast(ConstantR0WithType(&b, F32, 7), {64}), + }), + Slice(Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4, 128}), + "p0"), + {1, 0}, {4, 128}, {1, 1}), + }), + }, + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {8}), + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {128}), + ShapeUtil::MakeShape(F32, {256}), + }), + ShapeUtil::MakeShape(F32, {1024}), + ShapeUtil::MakeShape(F32, {4, 8}), + ShapeUtil::MakeShape(F32, {3, 128}), + ShapeUtil::MakeShape(F32, {32 + 64}), + }), + /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); xla::HloModuleConfig hlo_config( xla::ProgramShape(computation.proto().host_program_shape()), - /*ignore_layouts=*/false); + /*ignore_layouts=*/true); DebugOptions debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_address_computation_fusion(false); hlo_config.set_debug_options(debug_options); @@ -2422,6 +2455,225 @@ TEST_F(AddressComputationFusionTest, CublasGemmDUSOffsetOOB) { /*run_hlo_passes=*/false)); } +TEST_F(AddressComputationFusionTest, DynamicCustomCallSimple) { + XlaBuilder b(TestName()); + CustomCall( + &b, "__xla_test$$memcpy", + /*operands=*/ + {DynamicSlice(Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4, 128}), "p0"), + {Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "start0"), + Parameter(&b, 2, ShapeUtil::MakeShape(S32, {}), "start1")}, + {2, 128})}, + ShapeUtil::MakeShape(F32, {2, 128}), /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, DynamicCustomCallWithTuple) { + XlaBuilder b(TestName()); + CustomCall( + &b, "__xla_test$$subbuffers", /*operands=*/ + { + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 1), {128}), + Broadcast(ConstantR0WithType(&b, F32, 2), {256}), + }), + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 3), {1024}), + Broadcast(ConstantR0WithType(&b, F32, 4), {8}), + }), + Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}), {0, 0}, + {4, 8}, {1, 1}), + Tuple(&b, + { + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 6), {32}), + Broadcast(ConstantR0WithType(&b, F32, 7), {64}), + }), + DynamicSlice( + Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4, 128}), + "p0"), + {Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), + "start0"), + Parameter(&b, 2, ShapeUtil::MakeShape(S32, {}), + "start1")}, + {3, 128}), + }), + }, + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {8}), + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {128}), + ShapeUtil::MakeShape(F32, {256}), + }), + ShapeUtil::MakeShape(F32, {1024}), + ShapeUtil::MakeShape(F32, {4, 8}), + ShapeUtil::MakeShape(F32, {3, 128}), + ShapeUtil::MakeShape(F32, {32 + 64}), + }), + /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/true); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, /*run_hlo_passes=*/false)); +} + +static absl::Status SubBuffers2(se::Stream* stream, ffi::BufferBase src0, + ffi::BufferBase src1, ffi::BufferBase src2, + ffi::BufferBase src3, ffi::BufferBase src4, + ffi::BufferBase src5, ffi::BufferBase src6, + ffi::BufferBase src7, ffi::BufferBase dst0, + ffi::BufferBase dst1, ffi::BufferBase dst2, + ffi::BufferBase dst3, ffi::BufferBase dst4) { + // src0: param 0 at tuple index {0}, shape f32[128] + // src1: param 0 at tuple index {1}, shape f32[256] + // src2: param 1 at tuple index {0}, shape f32[1024] + // src3: param 1 at tuple index {1}, shape f32[8] + // src4: param 2, shape f32[4,8] + // + // dst0: result at tuple index {0}, shape f32[8] + // dst1: result at tuple index {1, 0}, shape f32[128] + // dst2: result at tuple index {1, 1}, shape f32[256] + // dst3: result at tuple index {2}, shape f32[1024] + // dst4: result at tuple index {3}, shape f32[4,8] + + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst0.data, src3.data, 8 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst1.data, src0.data, 128 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst2.data, src1.data, 256 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst3.data, src2.data, 1024 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst4.data, src4.data, 4 * 8 * sizeof(float))); + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kSubBuffers2, SubBuffers2, + ffi::Ffi::Bind() + .Ctx() + .Arg() // src0 + .Arg() // src1 + .Arg() // src2 + .Arg() // src3 + .Arg() // src4 + .Arg() // src5 + .Arg() // src6 + .Arg() // src7 + .Arg() // dst0 + .Arg() // dst1 + .Arg() // dst2 + .Arg() // dst3 + .Arg() // dst4 +); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$subbuffers2", + PLATFORM, kSubBuffers2); + +TEST_F(AddressComputationFusionTest, Test) { + XlaBuilder b(TestName()); + CustomCall( + &b, "Callback_Void", /*operands=*/ + { + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 1), {128}), + Broadcast(ConstantR0WithType(&b, F32, 2), {256}), + }), + Tuple(&b, + { + Slice(Broadcast(ConstantR0WithType(&b, F32, 3), {1024}), + {512}, {512 + 256}, {1}), + Broadcast(ConstantR0WithType(&b, F32, 4), {8}), + }), + Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}), {0, 0}, + {4, 8}, {1, 1}), + Tuple(&b, + { + Tuple(&b, + { + Broadcast(ConstantR0WithType(&b, F32, 6), {32}), + Broadcast(ConstantR0WithType(&b, F32, 7), {64}), + }), + }), + }, + ShapeUtil::MakeNil(), + // ShapeUtil::MakeShape(F32, {128}), + /*opaque=*/""); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, + /*run_hlo_passes=*/false)); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index 10eb3942924003..43d6c5d59b8e19 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -90,8 +90,29 @@ absl::StatusOr GetOperandSlice( const HloInstruction& fusion_instr, const HloInstruction& start_instr, std::vector& slice_instrs, const ShapeIndex& shape_idx, unsigned arg_idx) { + if (const auto* param = DynCast(&start_instr)) { + return GetAllocationSlice(buffer_assignment, + fusion_instr.operand(param->parameter_number()), + shape_idx); + } + + // Walk through ShapeIndex to find the real starting point. + auto* start = const_cast(&start_instr); + for (auto idx : shape_idx) { + CHECK(start->shape().IsTuple()); + start = const_cast(start->operand(idx)); + } + + if (const auto* param = DynCast(start)) { + // At this point we've walked through all `shape_idx`, `index` should be + // empty. + return GetAllocationSlice(buffer_assignment, + fusion_instr.operand(param->parameter_number()), + /*index*/ {}); + } + auto slice_adaptor = - HloFindIf({HloInstructionAdaptor(start_instr)}, adaptor, [](auto node) { + HloFindIf({HloInstructionAdaptor(*start)}, adaptor, [](auto node) { return IsOpcodeAnyOf(node); }); if (slice_adaptor.has_value()) { @@ -108,11 +129,13 @@ absl::StatusOr GetOperandSlice( slice_instrs[arg_idx] = slice_instr; const auto* param = Cast(slice_instr->operand(0)); + // At this point we've walked through all `shape_idx`, `index` should be + // empty. TF_ASSIGN_OR_RETURN( BufferAllocation::Slice orig_slice, GetAllocationSlice(buffer_assignment, fusion_instr.operand(param->parameter_number()), - shape_idx)); + /*index*/ {})); if (auto* static_slice = DynCast(slice_instr)) { // Update static slices. @@ -140,10 +163,7 @@ absl::StatusOr GetOperandSlice( return orig_slice; } - const auto* param = DynCast(&start_instr); - return GetAllocationSlice(buffer_assignment, - fusion_instr.operand(param->parameter_number()), - shape_idx); + return absl::InternalError("WTF"); } absl::Status CollectSliceInfo( @@ -190,7 +210,7 @@ absl::StatusOr GetResultSlice( auto slice_adaptor = HloFindIf( {HloInstructionAdaptor(start_instr)}, adaptor, [](auto node) { return node.opcode() == HloOpcode::kDynamicUpdateSlice; }, - false); + /*visit_operands=*/false); if (slice_adaptor.has_value()) { auto* slice_instr = const_cast(&slice_adaptor->instruction()); @@ -342,9 +362,10 @@ absl::StatusOr EmitGemm( lhs_slice, rhs_slice, output, workspace}; thunk = std::make_unique( - thunk_info, std::make_unique(std::move(seq)), arguments, - std::move(fake_allocations), offset_buffer_indices, orig_shapes, - sliced_shapes, offset_byte_sizes); + thunk_info, std::make_unique(std::move(seq)), + std::move(arguments), std::move(fake_allocations), + std::move(offset_buffer_indices), std::move(orig_shapes), + std::move(sliced_shapes), std::move(offset_byte_sizes)); } else { thunk = std::make_unique(thunk_info, std::move(config), lhs_slice, rhs_slice, output, workspace, @@ -393,11 +414,19 @@ absl::StatusOr EmitCustomCall( num_args += ShapeUtil::GetLeafCount(operand->shape()); }); + std::vector>> + offset_buffer_indices(num_args, std::nullopt); + std::vector> orig_shapes(num_args, std::nullopt); + std::vector> sliced_shapes(num_args, std::nullopt); + std::vector> offset_byte_sizes(num_args, + std::nullopt); + std::vector slice_instrs(num_args, nullptr); + std::vector> arguments; - Slices operands; unsigned arg_idx = 0; // TODO(vuson): add test for custom call with token-typed operands + Slices operands; for (auto* operand : custom_call.operands()) { TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( operand->shape(), [&](const Shape& subshape, const ShapeIndex& index) { @@ -412,8 +441,14 @@ absl::StatusOr EmitCustomCall( TF_ASSIGN_OR_RETURN( auto slice, GetOperandSlice(buffer_assignment, adaptor, fusion, *operand, - slice_instrs, /*shape_idx=*/index, arg_idx++)); + slice_instrs, /*shape_idx=*/index, arg_idx)); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion, + absl::Span(slice_instrs), offset_buffer_indices, + orig_shapes, sliced_shapes, offset_byte_sizes, arg_idx++)); + operands.push_back(CustomCallThunk::Slice{slice, subshape}); + arguments.push_back(slice); return absl::OkStatus(); })); } @@ -432,8 +467,14 @@ absl::StatusOr EmitCustomCall( TF_ASSIGN_OR_RETURN( auto slice, GetResultSlice(buffer_assignment, adaptor, fusion, custom_call, - slice_instrs, /*shape_idx=*/index, arg_idx++)); + slice_instrs, /*shape_idx=*/index, arg_idx)); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion, + absl::Span(slice_instrs), offset_buffer_indices, + orig_shapes, sliced_shapes, offset_byte_sizes, arg_idx++)); + results.push_back(CustomCallThunk::Slice{slice, subshape}); + arguments.push_back(slice); return absl::OkStatus(); })); @@ -517,23 +558,101 @@ absl::StatusOr EmitCustomCall( custom_call.api_version()); } - auto ffi_thunk = [&] { + std::unique_ptr thunk; + auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(&custom_call); + + auto ffi_thunk = [&](Slices ops, Slices res) { auto& called_computations = custom_call.called_computations(); return std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(&custom_call), - registration->handler, std::move(operands), std::move(results), + thunk_info, registration->handler, std::move(ops), std::move(res), std::move(attributes), called_computations.empty() ? nullptr : called_computations[0]); }; - auto legacy_thunk = [&] { + auto legacy_thunk = [&](Slices ops, Slices res) { return std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(&custom_call), - std::move(custom_call_target), std::move(operands), std::move(results), - std::move(opaque)); + thunk_info, std::move(custom_call_target), std::move(ops), + std::move(res), std::move(opaque)); }; + + std::vector> fake_allocations(num_args); + if (absl::c_any_of(slice_instrs, [&](auto slice_instr) { + return DynCastOrNull(slice_instr) != + nullptr; + })) { + // Creating embedded custom call thunk. + unsigned fake_arg_idx = 0; + + Slices fake_operands; + for (auto* operand : custom_call.operands()) { + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + operand->shape(), + [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsToken()) { + fake_arg_idx++; + fake_operands.push_back(std::nullopt); + return absl::OkStatus(); + } + if (!subshape.IsArray()) { + return absl::OkStatus(); + } + + int64_t operand_byte_size = ShapeUtil::ByteSizeOf(subshape); + fake_allocations[fake_arg_idx] = std::make_unique( + /*index=*/fake_arg_idx, operand_byte_size, /*color=*/0); + BufferAllocation::Slice fake_slice( + fake_allocations[fake_arg_idx].get(), 0, operand_byte_size); + + fake_arg_idx++; + fake_operands.push_back( + CustomCallThunk::Slice{fake_slice, subshape}); + return absl::OkStatus(); + })); + } + + Slices fake_results; + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + custom_call.shape(), + [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsToken()) { + fake_arg_idx++; + fake_results.push_back(std::nullopt); + return absl::OkStatus(); + } + if (!subshape.IsArray()) { + return absl::OkStatus(); + } + + int64_t result_byte_size = ShapeUtil::ByteSizeOf(subshape); + fake_allocations[fake_arg_idx] = std::make_unique( + /*index=*/fake_arg_idx, result_byte_size, /*color=*/0); + BufferAllocation::Slice fake_slice( + fake_allocations[fake_arg_idx].get(), 0, result_byte_size); + + fake_arg_idx++; + fake_results.push_back(CustomCallThunk::Slice{fake_slice, subshape}); + return absl::OkStatus(); + })); + + ThunkSequence seq; + seq.emplace_back( + found_ffi_handler + ? ffi_thunk(std::move(fake_operands), std::move(fake_results)) + : legacy_thunk(std::move(fake_operands), std::move(fake_results))); + + thunk = std::make_unique( + thunk_info, std::make_unique(std::move(seq)), + std::move(arguments), std::move(fake_allocations), + std::move(offset_buffer_indices), std::move(orig_shapes), + std::move(sliced_shapes), std::move(offset_byte_sizes)); + } else { + thunk = found_ffi_handler + ? ffi_thunk(std::move(operands), std::move(results)) + : legacy_thunk(std::move(operands), std::move(results)); + } + FusionEmissionResult result; - result.thunks.push_back(found_ffi_handler ? ffi_thunk() : legacy_thunk()); + result.thunks.push_back(std::move(thunk)); return result; } @@ -629,9 +748,7 @@ absl::StatusOr DynamicAddressComputationFusion::Emit( return EmitGemm(ir_emitter_context, adaptor, fusion, custom_call); } - return absl::UnimplementedError(absl::StrCat( - "No emission for DynamicAddressComputationFusion of custom call ", - custom_call.custom_call_target())); + return EmitCustomCall(ir_emitter_context, adaptor, fusion, custom_call); } } // namespace gpu From 0cc13d257ed9b2be535e6e898e2e4a2c467f51a5 Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Sun, 31 Mar 2024 14:34:02 -0700 Subject: [PATCH 117/124] [xla:gpu] DUS support for generic custom call emission in DynamicAddressComputationFusion emitter PiperOrigin-RevId: 620691743 --- .../address_computation_fusion_test.cc | 137 +++++++++++++++--- .../xla/xla/service/gpu/fusions/custom.cc | 20 ++- 2 files changed, 132 insertions(+), 25 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc b/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc index 07733a13aa04e4..abe64600d848e9 100644 --- a/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc @@ -1089,7 +1089,7 @@ void Callback_Memcpy(se::gpu::GpuStreamHandle stream, void** buffers, const char* /*opaque*/, size_t /*opaque_len*/) { void* src = buffers[0]; void* dst = buffers[1]; - auto err = gpuMemcpyAsync(dst, src, /*count=*/sizeof(float) * 128, + auto err = gpuMemcpyAsync(dst, src, /*count=*/sizeof(float) * 3 * 128, gpuMemcpyDeviceToDevice, stream); ASSERT_EQ(err, gpuSuccess); } @@ -1100,9 +1100,9 @@ TEST_F(AddressComputationFusionTest, CustomCallLegacyAPI) { XlaBuilder b(TestName()); CustomCall(&b, "Callback_Memcpy", /*operands=*/ - {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0}, - {128}, {1})}, - ShapeUtil::MakeShape(F32, {128}), /*opaque=*/""); + {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {512}), {128}, + {4 * 128}, {1})}, + ShapeUtil::MakeShape(F32, {3 * 128}), /*opaque=*/""); ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); @@ -2570,20 +2570,25 @@ static absl::Status SubBuffers2(se::Stream* stream, ffi::BufferBase src0, ffi::BufferBase src1, ffi::BufferBase src2, ffi::BufferBase src3, ffi::BufferBase src4, ffi::BufferBase src5, ffi::BufferBase src6, - ffi::BufferBase src7, ffi::BufferBase dst0, - ffi::BufferBase dst1, ffi::BufferBase dst2, - ffi::BufferBase dst3, ffi::BufferBase dst4) { + ffi::BufferBase dst0, ffi::BufferBase dst1, + ffi::BufferBase dst2, ffi::BufferBase dst3, + ffi::BufferBase dst4, ffi::BufferBase dst5, + ffi::BufferBase dst6) { // src0: param 0 at tuple index {0}, shape f32[128] // src1: param 0 at tuple index {1}, shape f32[256] // src2: param 1 at tuple index {0}, shape f32[1024] // src3: param 1 at tuple index {1}, shape f32[8] // src4: param 2, shape f32[4,8] + // src5: param 3 at tuple index {0, 0}, shape f32[3,128] + // src6: param 3 at tuple index {0, 1}, shape f32[5,128] // // dst0: result at tuple index {0}, shape f32[8] // dst1: result at tuple index {1, 0}, shape f32[128] // dst2: result at tuple index {1, 1}, shape f32[256] // dst3: result at tuple index {2}, shape f32[1024] // dst4: result at tuple index {3}, shape f32[4,8] + // dst5: result at tuple index {4, 0}, shape f32[5,128] + // dst6: result at tuple index {4, 1}, shape f32[3,128] TF_RETURN_IF_ERROR( stream->MemcpyD2D(&dst0.data, src3.data, 8 * sizeof(float))); @@ -2595,6 +2600,10 @@ static absl::Status SubBuffers2(se::Stream* stream, ffi::BufferBase src0, stream->MemcpyD2D(&dst3.data, src2.data, 1024 * sizeof(float))); TF_RETURN_IF_ERROR( stream->MemcpyD2D(&dst4.data, src4.data, 4 * 8 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst5.data, src6.data, 5 * 128 * sizeof(float))); + TF_RETURN_IF_ERROR( + stream->MemcpyD2D(&dst6.data, src5.data, 3 * 128 * sizeof(float))); return absl::OkStatus(); } @@ -2608,20 +2617,64 @@ XLA_FFI_DEFINE_HANDLER(kSubBuffers2, SubBuffers2, .Arg() // src4 .Arg() // src5 .Arg() // src6 - .Arg() // src7 .Arg() // dst0 .Arg() // dst1 .Arg() // dst2 .Arg() // dst3 .Arg() // dst4 + .Arg() // dst5 + .Arg() // dst6 ); XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$subbuffers2", PLATFORM, kSubBuffers2); -TEST_F(AddressComputationFusionTest, Test) { +TEST_F(AddressComputationFusionTest, CustomCallDUS) { XlaBuilder b(TestName()); - CustomCall( - &b, "Callback_Void", /*operands=*/ + auto custom_call = + CustomCall(&b, "Callback_Memcpy", + /*operands=*/ + {Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {10, 128}), + {2, 0}, {5, 128}, {1, 1})}, + ShapeUtil::MakeShape(F32, {3, 128}), /*opaque=*/""); + + DynamicUpdateSlice( + Broadcast(ConstantR0WithType(&b, F32, 92.0), {10, 128}), custom_call, + {ConstantR0WithType(&b, S32, 4), ConstantR0WithType(&b, S32, 0)}); + + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); + xla::HloModuleConfig hlo_config( + xla::ProgramShape(computation.proto().host_program_shape()), + /*ignore_layouts=*/false); + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_address_computation_fusion(false); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + debug_options.set_xla_gpu_enable_address_computation_fusion(true); + hlo_config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( + computation.proto(), hlo_config)); + + AddressComputationFusionRewriter pass(PLATFORM); + TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); + EXPECT_TRUE(changed); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), + error_spec, + /*run_hlo_passes=*/false)); +} + +TEST_F(AddressComputationFusionTest, CustomCallDUSTuple) { + XlaBuilder b(TestName()); + auto big_buffer1 = + Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 128}), "p0"); + auto big_buffer2 = + Parameter(&b, 1, ShapeUtil::MakeShape(F32, {10, 256}), "p1"); + auto custom_call = CustomCall( + &b, "__xla_test$$subbuffers2", /*operands=*/ { Tuple(&b, { @@ -2630,24 +2683,60 @@ TEST_F(AddressComputationFusionTest, Test) { }), Tuple(&b, { - Slice(Broadcast(ConstantR0WithType(&b, F32, 3), {1024}), - {512}, {512 + 256}, {1}), + Broadcast(ConstantR0WithType(&b, F32, 3), {1024}), Broadcast(ConstantR0WithType(&b, F32, 4), {8}), }), Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}), {0, 0}, {4, 8}, {1, 1}), - Tuple(&b, - { - Tuple(&b, - { - Broadcast(ConstantR0WithType(&b, F32, 6), {32}), - Broadcast(ConstantR0WithType(&b, F32, 7), {64}), - }), - }), + Tuple( + &b, + { + Tuple( + &b, + { + Broadcast(ConstantR0WithType(&b, F32, 6), {3, 128}), + DynamicSlice(Broadcast(ConstantR0WithType(&b, F32, 7), + {8, 128}), + {ConstantR0WithType(&b, S32, 2), + ConstantR0WithType(&b, S32, 0)}, + {5, 128}), + }), + }), }, - ShapeUtil::MakeNil(), - // ShapeUtil::MakeShape(F32, {128}), - /*opaque=*/""); + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {8}), + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {128}), + ShapeUtil::MakeShape(F32, {256}), + }), + ShapeUtil::MakeShape(F32, {1024}), + ShapeUtil::MakeShape(F32, {4, 8}), + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {5, 128}), + ShapeUtil::MakeShape(F32, {3, 128}), + }), + }), + /*opaque=*/"", + /*has_side_effect=*/false, + /*output_operand_aliasing=*/{}, /*literal=*/nullptr, + /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI); + auto tuple_gte = GetTupleElement(custom_call, 4); + auto dus1 = DynamicUpdateSlice( + big_buffer1, GetTupleElement(tuple_gte, 0), + {ConstantR0WithType(&b, S32, 2), ConstantR0WithType(&b, S32, 0)}); + auto dus2 = DynamicUpdateSlice( + big_buffer1, GetTupleElement(tuple_gte, 1), + {ConstantR0WithType(&b, S32, 7), ConstantR0WithType(&b, S32, 0)}); + auto dus3 = DynamicUpdateSlice( + big_buffer2, + xla::internal::XlaBuilderFriend::BuildBitcast( + &b, GetTupleElement(custom_call, 2), + ShapeUtil::MakeShape(F32, {4, 256})), + {Parameter(&b, 2, ShapeUtil::MakeShape(S32, {}), "start0"), + Parameter(&b, 3, ShapeUtil::MakeShape(S32, {}), "start1")}); + Tuple(&b, {dus1, dus2, dus3}); + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index 43d6c5d59b8e19..99b1e9a0918669 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -207,8 +207,26 @@ absl::StatusOr GetResultSlice( const HloInstruction& fusion_instr, const HloInstruction& start_instr, std::vector& slice_instrs, const ShapeIndex& shape_idx, unsigned arg_idx) { + auto* start = const_cast(&start_instr); + // Walk through ShapeIndex to find the real "user" (i.e. not get-tuple-element + // user). Otherwise one sliced element will mark all buffers of all other + // elements "sliced" too. + if (start->shape().IsTuple()) { + for (auto idx : shape_idx) { + std::vector gte_users( + start->shape().tuple_shapes_size(), nullptr); + for (auto* user : start->users()) + if (auto* gte = DynCast(user)) + gte_users[gte->tuple_index()] = gte; + + start = static_cast(gte_users[idx]); + if (start == nullptr) + return GetAllocationSlice(buffer_assignment, &fusion_instr, shape_idx); + } + } + auto slice_adaptor = HloFindIf( - {HloInstructionAdaptor(start_instr)}, adaptor, + {HloInstructionAdaptor(*start)}, adaptor, [](auto node) { return node.opcode() == HloOpcode::kDynamicUpdateSlice; }, /*visit_operands=*/false); if (slice_adaptor.has_value()) { From 96817f4184908aef93634cf3ba61b1bc75c27b89 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 1 Apr 2024 02:02:09 -0700 Subject: [PATCH 118/124] Update GraphDef version to 1819. PiperOrigin-RevId: 620789686 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index a71104785c2b9d..bb798e7845959e 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1818 // Updated: 2024/3/31 +#define TF_GRAPH_DEF_VERSION 1819 // Updated: 2024/4/1 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 2686c02b1d10c44a27fd17cf69f259bdc153d682 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 1 Apr 2024 02:02:35 -0700 Subject: [PATCH 119/124] compat: Update forward compatibility horizon to 2024-04-01 PiperOrigin-RevId: 620789797 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 8ca0d0f2d923c6..7f59445d85966d 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 3, 31) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 4, 1) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 869892f6e59da06300488380aaf3df1f4c5af651 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 1 Apr 2024 08:19:49 -0700 Subject: [PATCH 120/124] [PJRT:CPU] Fix thread-pool stack sizes to 2MiB. The default thread pool size is too small on Mac OS. An older version of this runtime based on StreamExecutor set a 2MiB stack size as well, but that change was most likely lost during the TFRT rewrite. Fixes https://github.com/google/jax/issues/20428 PiperOrigin-RevId: 620853544 --- third_party/xla/xla/pjrt/cpu/cpu_client.cc | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index 6c0bcfc141e5a9..63f930a3c037af 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -363,14 +363,23 @@ absl::StatusOr> GetTfrtCpuClient( std::move(options.collectives), num_threads)); } +static tsl::ThreadOptions GetThreadOptions() { + tsl::ThreadOptions thread_options; + // On Mac OS the default stack size is 512KiB, which is too small for some + // BLAS and LAPACK functions (https://github.com/google/jax/issues/20428). + thread_options.stack_size = 2 * 1024 * 1024; + return thread_options; +} + TfrtCpuClient::TfrtCpuClient( int process_index, std::vector> devices, std::shared_ptr collectives, size_t num_threads) : process_index_(process_index), owned_devices_(std::move(devices)), computation_placer_(std::make_unique()), - pjrt_client_thread_pool_(new tsl::thread::ThreadPool( - tsl::Env::Default(), "XLATfrtCpuClient", num_threads)), + pjrt_client_thread_pool_( + new tsl::thread::ThreadPool(tsl::Env::Default(), GetThreadOptions(), + "XLATfrtCpuClient", num_threads)), async_work_runner_(std::make_unique( pjrt_client_thread_pool_.get())), eigen_intraop_pool_(new tsl::thread::ThreadPool( From 61cac5cec1e3c31d795decb56b57c82584aefeb9 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 1 Apr 2024 09:23:01 -0700 Subject: [PATCH 121/124] [PJRT:CPU] Replace references to pjrt/tfrt_cpu_pjrt_client with pjrt/cpu/cpu_client.h. The two are aliases and the former is a forwarding header pointing to the latter. Cleanup only, no functional changes. PiperOrigin-RevId: 620867527 --- tensorflow/c/experimental/next_pluggable_device/BUILD | 2 +- .../next_pluggable_device/tensor_pjrt_buffer_util_test.cc | 2 +- third_party/xla/xla/pjrt/cpu/BUILD | 6 ++---- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tensorflow/c/experimental/next_pluggable_device/BUILD b/tensorflow/c/experimental/next_pluggable_device/BUILD index 03c83a4e8f99e0..3d92b7ad3d2992 100644 --- a/tensorflow/c/experimental/next_pluggable_device/BUILD +++ b/tensorflow/c/experimental/next_pluggable_device/BUILD @@ -94,9 +94,9 @@ tf_cc_test( "@local_xla//xla:shape_util", "@local_xla//xla/pjrt:pjrt_api", "@local_xla//xla/pjrt:pjrt_c_api_client", - "@local_xla//xla/pjrt:tfrt_cpu_pjrt_client", "@local_xla//xla/pjrt/c:pjrt_c_api_cpu", "@local_xla//xla/pjrt/c:pjrt_c_api_hdrs", "@local_xla//xla/pjrt/c:pjrt_c_api_wrapper_impl", + "@local_xla//xla/pjrt/cpu:cpu_client", ], ) diff --git a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc index c72f0cfafa6ead..7f45fd91a1baea 100644 --- a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc +++ b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc @@ -25,9 +25,9 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_cpu.h" #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" +#include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_c_api_client.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "tensorflow/core/framework/types.h" diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD index 324d684611e22c..603a8ef30e2dcc 100644 --- a/third_party/xla/xla/pjrt/cpu/BUILD +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -1,4 +1,4 @@ -load("@local_tsl//tsl:tsl.bzl", "if_oss") +load("@local_tsl//tsl:tsl.bzl", "if_oss", "internal_visibility") load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla:xla.bzl", "xla_cc_test") @@ -130,9 +130,7 @@ cc_library( name = "cpu_client", srcs = ["cpu_client.cc"], hdrs = ["cpu_client.h"], - visibility = [ - "//xla:friends", - ], + visibility = internal_visibility(["//xla:friends"]), deps = [ ":abstract_tfrt_cpu_buffer", ":cpu_topology", From 007eb8a5b793ecbaa648b2e894df4f4c846a6156 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 1 Apr 2024 09:41:37 -0700 Subject: [PATCH 122/124] [xla:gpu] No need for dynamic/static mode in AddressComputationFusionRewriter PiperOrigin-RevId: 620871987 --- .../address_computation_fusion_rewriter.cc | 195 ++++++++---------- 1 file changed, 90 insertions(+), 105 deletions(-) diff --git a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc index db5ece796ffe79..5b92eb3423ebbb 100644 --- a/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/address_computation_fusion_rewriter.cc @@ -405,125 +405,110 @@ absl::StatusOr CreateFusionInstruction( absl::StatusOr AddressComputationFusionRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { - auto process_slices = [&](bool dynamic) -> absl::StatusOr { - absl::flat_hash_map> - matches; - - // Collect all potential custom call matches in the non-fusion computations. - for (HloComputation* computation : module->computations()) { - if (computation->IsFusionComputation()) continue; - for (HloInstruction* instr : computation->instructions()) { - if (IsLegacyCublasMatmul(*instr) || - (IsCustomCall(instr, platform_name_))) { - UseDefDataflowPaths sliced_operand_paths = - GetSlicedOperandPaths(instr); - bool has_sliced_operand_paths = sliced_operand_paths.size() > 1; - - DefUseDataflowPaths sliced_user_paths = GetSlicedUserPaths(instr); - bool has_sliced_user_paths = - absl::c_any_of(sliced_user_paths, [&](auto& sliced_user_path) { - return !sliced_user_path.empty(); - }); - - if (absl::c_any_of(sliced_user_paths, [&](auto& sliced_user_path) { - return DynCast( - sliced_user_path.back()) == nullptr; - })) { - return absl::InternalError( - "Expect sliced user path to end with a DUS."); - } + absl::flat_hash_map> + matches; + + // Collect all potential custom call matches in the non-fusion computations. + for (HloComputation* computation : module->computations()) { + if (computation->IsFusionComputation()) continue; + for (HloInstruction* instr : computation->instructions()) { + if (IsLegacyCublasMatmul(*instr) || + (IsCustomCall(instr, platform_name_))) { + UseDefDataflowPaths sliced_operand_paths = GetSlicedOperandPaths(instr); + bool has_sliced_operand_paths = sliced_operand_paths.size() > 1; + + DefUseDataflowPaths sliced_user_paths = GetSlicedUserPaths(instr); + bool has_sliced_user_paths = absl::c_any_of( + sliced_user_paths, + [&](auto& sliced_user_path) { return !sliced_user_path.empty(); }); + + if (absl::c_any_of(sliced_user_paths, [&](auto& sliced_user_path) { + return DynCast( + sliced_user_path.back()) == nullptr; + })) { + return absl::InternalError( + "Expect sliced user path to end with a DUS."); + } - if (has_sliced_operand_paths || has_sliced_user_paths) { - matches[instr] = std::make_pair(std::move(sliced_operand_paths), - std::move(sliced_user_paths)); - } + if (has_sliced_operand_paths || has_sliced_user_paths) { + matches[instr] = std::make_pair(std::move(sliced_operand_paths), + std::move(sliced_user_paths)); } } } + } - if (matches.empty()) return false; + if (matches.empty()) return false; - for (auto& [hero, paths] : matches) { - auto& [sliced_operand_paths, sliced_user_paths] = paths; - std::vector matched_instrs; - absl::c_copy(sliced_operand_paths, std::back_inserter(matched_instrs)); + for (auto& [hero, paths] : matches) { + auto& [sliced_operand_paths, sliced_user_paths] = paths; + std::vector matched_instrs; + absl::c_copy(sliced_operand_paths, std::back_inserter(matched_instrs)); - std::vector sliced_user_paths_view; + std::vector sliced_user_paths_view; + for (auto& sliced_user_path : sliced_user_paths) { + absl::c_copy(sliced_user_path, std::back_inserter(matched_instrs)); + DataflowPathView sliced_user_path_view{&sliced_user_path.front(), + sliced_user_path.size()}; + sliced_user_paths_view.push_back(std::move(sliced_user_path_view)); + } + + auto captures = GetPatternCaptures(matched_instrs); + + TF_ASSIGN_OR_RETURN( + HloComputation * fusion_body, + CreateFusionBody(module, sliced_operand_paths, + DataflowPathsView(sliced_user_paths_view), captures)); + + bool has_dynamic_slices = absl::c_any_of(matched_instrs, [&](auto* instr) { + return DynCast(instr) != nullptr; + }); + TF_ASSIGN_OR_RETURN( + HloInstruction * fusion, + CreateFusionInstruction(module, hero, captures, fusion_body, + has_dynamic_slices)); + + HloComputation* parent = hero->parent(); + if (fusion->shape().IsTuple()) { + TF_RETURN_IF_ERROR(parent->ReplaceInstructionWithDifferentShape( + const_cast(hero), fusion)); for (auto& sliced_user_path : sliced_user_paths) { - absl::c_copy(sliced_user_path, std::back_inserter(matched_instrs)); - DataflowPathView sliced_user_path_view{&sliced_user_path.front(), - sliced_user_path.size()}; - sliced_user_paths_view.push_back(std::move(sliced_user_path_view)); + auto old_gte = + Cast(sliced_user_path.front()); + HloInstruction* gte = + parent->AddInstruction(HloInstruction::CreateGetTupleElement( + fusion, old_gte->tuple_index())); + TF_RETURN_IF_ERROR( + parent->ReplaceInstruction(sliced_user_path.back(), gte)); } - - auto captures = GetPatternCaptures(matched_instrs); - - TF_ASSIGN_OR_RETURN( - HloComputation * fusion_body, - CreateFusionBody(module, sliced_operand_paths, - DataflowPathsView(sliced_user_paths_view), - captures)); - - bool has_dynamic_slices = - absl::c_any_of(matched_instrs, [&](auto* instr) { - return DynCast(instr) != nullptr; - }); - TF_ASSIGN_OR_RETURN( - HloInstruction * fusion, - CreateFusionInstruction(module, hero, captures, fusion_body, - has_dynamic_slices)); - - HloComputation* parent = hero->parent(); - if (fusion->shape().IsTuple()) { - TF_RETURN_IF_ERROR(parent->ReplaceInstructionWithDifferentShape( - const_cast(hero), fusion)); - for (auto& sliced_user_path : sliced_user_paths) { - auto old_gte = - Cast(sliced_user_path.front()); - HloInstruction* gte = - parent->AddInstruction(HloInstruction::CreateGetTupleElement( - fusion, old_gte->tuple_index())); - TF_RETURN_IF_ERROR( - parent->ReplaceInstruction(sliced_user_path.back(), gte)); - } - } else { - auto* instr_to_be_replaced = const_cast(hero); - if (sliced_user_paths.empty()) { - // The only case where a tuple-shaped original hero op is fused into a - // non-tuple-shaped fusion is there's only one element of the original - // tuple being used. In that case, we need to replace that single - // get-tuple-element (instead of the hero op) with the fusion - // instruction. - if (hero->shape().IsTuple()) { - if (hero->user_count() != 1 || - !DynCast( - hero->users().front())) { - return absl::InternalError( - "Expect a single get-tuple-element user of the original " - "tuple-shaped hero op when address computation fusion does " - "not return a tuple"); - } - instr_to_be_replaced = hero->users().front(); + } else { + auto* instr_to_be_replaced = const_cast(hero); + if (sliced_user_paths.empty()) { + // The only case where a tuple-shaped original hero op is fused into a + // non-tuple-shaped fusion is there's only one element of the original + // tuple being used. In that case, we need to replace that single + // get-tuple-element (instead of the hero op) with the fusion + // instruction. + if (hero->shape().IsTuple()) { + if (hero->user_count() != 1 || + !DynCast(hero->users().front())) { + return absl::InternalError( + "Expect a single get-tuple-element user of the original " + "tuple-shaped hero op when address computation fusion does " + "not return a tuple"); } - } else { - instr_to_be_replaced = sliced_user_paths.front().back(); + instr_to_be_replaced = hero->users().front(); } - TF_RETURN_IF_ERROR( - parent->ReplaceInstruction(instr_to_be_replaced, fusion)); + } else { + instr_to_be_replaced = sliced_user_paths.front().back(); } + TF_RETURN_IF_ERROR( + parent->ReplaceInstruction(instr_to_be_replaced, fusion)); } + } - return true; - }; - - // TODO(vuson): unify dynamic_address_computation and address_computation - TF_ASSIGN_OR_RETURN(bool processed_pattern_with_static_slices, - process_slices(false)); - TF_ASSIGN_OR_RETURN(bool processed_pattern_with_dynamic_slices, - process_slices(true)); - return processed_pattern_with_static_slices || - processed_pattern_with_dynamic_slices; + return true; } } // namespace gpu From e7efc3a618162ce1093007467b75da2139aa622d Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 1 Apr 2024 10:57:48 -0700 Subject: [PATCH 123/124] [xla:gpu][NFC] No need for custom HloModuleConfigs in address_computation_fusion_test Since the rewriter is now in RunHloPasses, these configs do not do anything. PiperOrigin-RevId: 620894763 --- .../address_computation_fusion_test.cc | 81 +++++++------------ 1 file changed, 27 insertions(+), 54 deletions(-) diff --git a/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc b/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc index abe64600d848e9..2374f5abe6e2fd 100644 --- a/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/address_computation_fusion_test.cc @@ -70,18 +70,8 @@ namespace { class AddressComputationFusionTest : public HloTestBase { public: - HloModuleConfig GetRefModuleConfig() { + HloModuleConfig GetModuleConfigWithoutCommandBuffer() { DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); - debug_options.clear_xla_gpu_enable_command_buffer(); - HloModuleConfig config; - config.set_debug_options(debug_options); - return config; - } - - HloModuleConfig GetOptModuleConfig() { - DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); debug_options.clear_xla_gpu_enable_command_buffer(); HloModuleConfig config; config.set_debug_options(debug_options); @@ -163,8 +153,7 @@ TEST_F(AddressComputationFusionTest, CublasGemmSimple) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -245,8 +234,7 @@ TEST_F(AddressComputationFusionTest, CublasGemmWithWorkspace) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -324,8 +312,7 @@ TEST_F(AddressComputationFusionTest, ContiguousSlice) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -403,8 +390,7 @@ TEST_F(AddressComputationFusionTest, ContiguousSliceNonDefaultLayout) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -535,8 +521,7 @@ TEST_F(AddressComputationFusionTest, OperandIsSlicedGetTupleElement) { } })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -621,8 +606,7 @@ TEST_F(AddressComputationFusionTest, ReversedOperandOrder) { } })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -752,8 +736,7 @@ TEST_F(AddressComputationFusionTest, SingleOperandComputation) { } })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -843,8 +826,7 @@ TEST_F(AddressComputationFusionTest, SlicedOperandAliasingOutput) { } })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -1247,8 +1229,7 @@ TEST_F(AddressComputationFusionTest, CublasGemmDynamic) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -1335,8 +1316,7 @@ TEST_F(AddressComputationFusionTest, CublasGemmDynamicWithWorkspace) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -1426,8 +1406,7 @@ TEST_F(AddressComputationFusionTest, DynamicContiguousSlice) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -1518,8 +1497,7 @@ TEST_F(AddressComputationFusionTest, DynamicContiguousSliceNonDefaultLayout) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -1653,8 +1631,7 @@ TEST_F(AddressComputationFusionTest, DynamicOperandIsSlicedGetTupleElement) { } })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -1745,8 +1722,7 @@ TEST_F(AddressComputationFusionTest, DynamicReversedOperandOrder) { } })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -1879,8 +1855,7 @@ TEST_F(AddressComputationFusionTest, DynamicSingleOperandComputation) { } })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -1978,8 +1953,7 @@ TEST_F(AddressComputationFusionTest, DynamicSlicedOperandAliasingOutput) { } })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -2070,9 +2044,12 @@ TEST_F(AddressComputationFusionTest, CublasGemmDUS) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, - /*run_hlo_passes=*/false)); + // The GEMM custom call does not have a workspace, shouldn't be run in command + // buffer. + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_opt, GetModuleConfigWithoutCommandBuffer(), + GetModuleConfigWithoutCommandBuffer(), error_spec, + /*run_hlo_passes=*/false)); } TEST_F(AddressComputationFusionTest, CublasGemmDUSWithWorkspace) { @@ -2168,8 +2145,7 @@ TEST_F(AddressComputationFusionTest, CublasGemmDUSWithWorkspace) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -2254,8 +2230,7 @@ TEST_F(AddressComputationFusionTest, CublasGemmDUSWorkspaceIgnored) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -2352,8 +2327,7 @@ TEST_F(AddressComputationFusionTest, CublasGemmDUSOffsetS32NotConstant) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } @@ -2450,8 +2424,7 @@ TEST_F(AddressComputationFusionTest, CublasGemmDUSOffsetOOB) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, GetRefModuleConfig(), - GetOptModuleConfig(), error_spec, + EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, /*run_hlo_passes=*/false)); } From d02479b2b7da0b2348a810ea183c80217860e4be Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Wed, 27 Mar 2024 22:14:12 -0700 Subject: [PATCH 124/124] Make usages of Eigen::array compatible with std::array. Eigen::array is no longer necessary, so will be deprecated/removed, and replaced with `std::array`. The main difference is the constructor - currently Eigen::array allows `array(a, b, c, ...)` construction, whereas `std::array` requires an initializer list. We also need to remove any direct access to the `Eigen::array::values` internal parameter, in favor of regular index access. PiperOrigin-RevId: 619787691 --- .../core/distributed_runtime/master_test.cc | 6 +-- .../core/kernels/gather_nd_op_gpu.cu.cc | 11 +++-- .../core/kernels/image/adjust_contrast_op.cc | 40 +++++++++++-------- ...arameterized_truncated_normal_op_gpu.cu.cc | 2 - tensorflow/core/kernels/random_binomial_op.cc | 2 - .../kernels/sparse_tensor_dense_matmul_op.cc | 3 +- .../eigen_spatial_convolutions_test.cc | 4 -- 7 files changed, 35 insertions(+), 33 deletions(-) diff --git a/tensorflow/core/distributed_runtime/master_test.cc b/tensorflow/core/distributed_runtime/master_test.cc index 5c2f17e31f819d..1e9e5545183191 100644 --- a/tensorflow/core/distributed_runtime/master_test.cc +++ b/tensorflow/core/distributed_runtime/master_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "grpcpp/grpcpp.h" - +#include "Eigen/Core" // from @eigen_archive #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h" @@ -389,8 +389,8 @@ TEST_F(MasterTest, EigenProblem) { TF_CHECK_OK(CreateSession(def, &handle, &initial_version)); // Temps supporting the computation of the convergence condition. - const Eigen::array sum_along_dim(0); - const Eigen::array matrix_transpose({1, 0}); + const Eigen::array sum_along_dim{0}; + const Eigen::array matrix_transpose{1, 0}; Tensor x(DT_FLOAT, TensorShape({2, 1})); Tensor y(DT_FLOAT, TensorShape({2, 1})); Eigen::Tensor y_square_sum; diff --git a/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc b/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc index 227cd311c244ee..c26f5bdf492e39 100644 --- a/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc +++ b/tensorflow/core/kernels/gather_nd_op_gpu.cu.cc @@ -39,11 +39,14 @@ __global__ void GatherSliceOpKernel( const auto indices_i = indices + IXDIM * loc; bool out_of_bounds = false; Index offset = 0; + // Avoid empty std::array access, which fails to compile on GPU. + if constexpr (IXDIM > 0) { #pragma unroll - for (int j = 0; j < IXDIM; ++j) { - const Index index_j = ldg(indices_i + j); - out_of_bounds |= !FastBoundsCheck(index_j, batch_indices[j]); - offset += batch_strides[j] * index_j; + for (int j = 0; j < IXDIM; ++j) { + const Index index_j = ldg(indices_i + j); + out_of_bounds |= !FastBoundsCheck(index_j, batch_indices[j]); + offset += batch_strides[j] * index_j; + } } // TODO(ebrevdo): // This is the only part that depends on the offset. The part diff --git a/tensorflow/core/kernels/image/adjust_contrast_op.cc b/tensorflow/core/kernels/image/adjust_contrast_op.cc index 7cef95b9479022..df8650ebfed515 100644 --- a/tensorflow/core/kernels/image/adjust_contrast_op.cc +++ b/tensorflow/core/kernels/image/adjust_contrast_op.cc @@ -248,6 +248,7 @@ class AdjustContrastOpv2 : public AdjustContrastOpV2Base { TTypes::Tensor mean_flat(&mean(0, 0), mean.size()); TTypes::Tensor summation_scratch(&scratch(0, 0, 0), scratch.size()); + using Eigen::DenseIndex; typedef Eigen::array Index; const int64_t plane_size = image_size * channels; // Since the number of channels in the early layers is often small, a @@ -255,10 +256,10 @@ class AdjustContrastOpv2 : public AdjustContrastOpV2Base { // This algorithm repeatedly folds each image plane by half, until // only one set of channels remains. for (int64_t i = 0; i < batch; i++) { - auto input_plane = - input_flat.slice(Index(i * plane_size), Index(plane_size)); - auto summation_plane = - summation_scratch.slice(Index(i * plane_size), Index(plane_size)); + auto input_plane = input_flat.slice(Index{DenseIndex(i * plane_size)}, + Index{DenseIndex(plane_size)}); + auto summation_plane = summation_scratch.slice( + Index{DenseIndex(i * plane_size)}, Index{DenseIndex(plane_size)}); int64_t remaining_size = image_size; int round = 0; // Sum the input(i, :, k) into mean(i, k). Repeatedly splits the input @@ -289,26 +290,29 @@ class AdjustContrastOpv2 : public AdjustContrastOpV2Base { if (round == 0) { // In the first round, sum the left side and right side of the input // array into the summation area. - summation_plane.slice(Index(0), Index(right_size * channels)) = - input_plane.slice(Index(left_size * channels), - Index(right_size * channels)) + - input_plane.slice(Index(0), Index(right_size * channels)); + summation_plane.slice(Index{0}, + Index{DenseIndex(right_size * channels)}) = + input_plane.slice(Index{DenseIndex(left_size * channels)}, + Index{DenseIndex(right_size * channels)}) + + input_plane.slice(Index{0}, + Index{DenseIndex(right_size * channels)}); if (left_size > right_size) { DCHECK_EQ(left_size - right_size, 1); // Copy over the remaining column if the remaining_size is odd. // This also handles the case where image_size == 1. - summation_plane.slice(Index(right_size * channels), - Index(channels)) = - input_plane.slice(Index(right_size * channels), - Index(channels)); + summation_plane.slice(Index{DenseIndex(right_size * channels)}, + Index{DenseIndex(channels)}) = + input_plane.slice(Index{DenseIndex(right_size * channels)}, + Index{DenseIndex(channels)}); } } else { // For all the remaining rounds, add the second half of the inputs // into the first half of the inputs. With the flat structure and // large size, this utilizes vectorization between components. - summation_plane.slice(Index(0), Index(right_size * channels)) += - summation_plane.slice(Index(left_size * channels), - Index(right_size * channels)); + summation_plane.slice(Index{0}, + Index{DenseIndex(right_size * channels)}) += + summation_plane.slice(Index{DenseIndex(left_size * channels)}, + Index{DenseIndex(right_size * channels)}); } remaining_size = left_size; round++; @@ -316,9 +320,11 @@ class AdjustContrastOpv2 : public AdjustContrastOpV2Base { const float mean_scaling = 1.0f / image_size; // The first channels elements in summation_plane now holds the summation. // Scale it with image_size and copy over to the means. - auto mean_plane = mean_flat.slice(Index(i * channels), Index(channels)); + auto mean_plane = mean_flat.slice(Index{DenseIndex(i * channels)}, + Index{DenseIndex(channels)}); mean_plane = - summation_plane.slice(Index(0), Index(channels)) * mean_scaling; + summation_plane.slice(Index{0}, Index{DenseIndex(channels)}) * + mean_scaling; } } diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc index b826564437c0a1..e7b76653dc329e 100644 --- a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc +++ b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc @@ -132,8 +132,6 @@ __global__ void __launch_bounds__(1024) (normMax >= T(0.))) || ((normMax > kStdDevsInsideBoundsToUseRandnSampler) && (normMin <= T(0.)))) { - Eigen::array n; - int numIterations = 0; while (numIterations < kMaxIterations) { const auto randn = normal_dist(&gen); diff --git a/tensorflow/core/kernels/random_binomial_op.cc b/tensorflow/core/kernels/random_binomial_op.cc index 8fceaf70c0dbbb..98118b78eb5b58 100644 --- a/tensorflow/core/kernels/random_binomial_op.cc +++ b/tensorflow/core/kernels/random_binomial_op.cc @@ -187,8 +187,6 @@ struct RandomBinomialFunctor { &gen, &output](int64_t start_output, int64_t limit_output) { // Vectorized intermediate calculations for uniform rejection sampling. // We always generate at most 4 samples. - Eigen::array z; - Eigen::array g; const bool should_bcast = bcast.IsBroadcastingRequired(); const auto& counts_batch_indices = bcast.x_batch_indices(); const auto& probs_batch_indices = bcast.y_batch_indices(); diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc index 04aff711362552..cb80fa34230a20 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h" +#include "Eigen/Core" // from @eigen_archive #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -310,7 +311,7 @@ Status SparseTensorDenseMatMulImpl( if (ADJ_B) { // Perform transpose and conjugation on B once, since we chip out B's // columns in the nnz loop. - Eigen::array shuffle(1, 0); // preserve dimension order + Eigen::array shuffle{1, 0}; // preserve dimension order Eigen::Tensor col_major_conj_b = b.swap_layout().shuffle(shuffle).conjugate(); LOOP_NNZ(col_major_conj_b); diff --git a/third_party/xla/third_party/tsl/tsl/framework/convolution/eigen_spatial_convolutions_test.cc b/third_party/xla/third_party/tsl/tsl/framework/convolution/eigen_spatial_convolutions_test.cc index 48e0379bce1466..9f589c549901e9 100644 --- a/third_party/xla/third_party/tsl/tsl/framework/convolution/eigen_spatial_convolutions_test.cc +++ b/third_party/xla/third_party/tsl/tsl/framework/convolution/eigen_spatial_convolutions_test.cc @@ -1036,10 +1036,6 @@ static void PackLhsHelper(::testing::benchmark::State& state, reshape_dims[0] = filter_count; reshape_dims[1] = input_depth * filter_rows * filter_cols; - // We are going to contract along the 'in_depth * filter_rows * filter_cols`. - nocontract_t nocontract_dim = {0}; - contract_t contract_dim = {1}; - // These values computed using the algorithm in TensorContraction.h, with // 'nocontract_dim' and 'contract_dim' values specified above. nocontract_t nocontract_strides = {1};