From e02c9ff3f24175ad9a259c0ff0b05cc5b96158a3 Mon Sep 17 00:00:00 2001
From: George Necula
Date: Mon, 2 Aug 2021 09:56:48 -0700
Subject: [PATCH] Add a new op XlaRngBitGenerator that wraps the HLO op
 RngBitGenerator.

See https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator.

The use case for this op is jax2tf, where we need to expose the plain HLO op.
There are already TensorFlow ops that use the same HLO op, but they all seem
to contain a certain amount of sugar on top, which makes it hard to ensure
that the TF op has the exact semantics of the HLO op.

PiperOrigin-RevId: 388242420
Change-Id: I55b35c36a087b53b20694aac66f13a1aa8e2f1d2
---
 .../compiler/jit/compilability_check_util.cc  |  3 +
 .../compiler/jit/mark_for_compilation_pass.cc |  1 +
 .../jit/xla_ops_on_regular_devices.cc         |  5 ++
 .../mlir/tensorflow/ir/tf_generated_ops.td    | 25 +++++++++
 .../mark_ops_for_outside_compilation.cc       |  2 +
 tensorflow/compiler/tests/xla_ops_test.py     | 56 +++++++++++++++++++
 .../tf2xla/kernels/stateless_random_ops_v2.cc | 38 +++++++++++++
 tensorflow/compiler/tf2xla/ops/xla_ops.cc     | 33 +++++++++++
 tensorflow/compiler/tf2xla/python/xla.py      | 23 ++++++++
 9 files changed, 186 insertions(+)

diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc
index 5f7f62fc99606e..493f64baac8f8d 100644
--- a/tensorflow/compiler/jit/compilability_check_util.cc
+++ b/tensorflow/compiler/jit/compilability_check_util.cc
@@ -705,6 +705,7 @@ static auto const ops_triggering_xla_compilation =
                                          "XlaReduce",
                                          "XlaReduceWindow",
                                          "XlaReplicaId",
+                                         "XlaRngBitGenerator",
                                          "XlaScatter",
                                          "XlaSelectAndScatter",
                                          "XlaSelfAdjointEig",
@@ -714,6 +715,8 @@ static auto const ops_triggering_xla_compilation =
                                          "XlaSpmdFullToShardShape",
                                          "XlaSpmdShardToFullShape",
                                          "XlaSvd",
+                                         "XlaVariadicReduceV2",
+                                         "XlaVariadicSort",
                                          "XlaWhile"};

 static bool NodeCanTriggerXlaCompilation(const NodeDef& node) {
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index f5c7c30e2918db..547ad0b10277ed 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -2083,6 +2083,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
             "XlaReduceWindow",
             "XlaRemoveDynamicDimensionSize",
             "XlaReplicaId",
+            "XlaRngBitGenerator",
             "XlaScatter",
             "XlaSelectAndScatter",
             "XlaSelfAdjointEig",
diff --git a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
index 0c08dd9210e53a..0dce2d0f7a4ab9 100644
--- a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
+++ b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
@@ -97,6 +97,11 @@ namespace tensorflow {
                           XlaCompileOnDemandOp);                              \
   REGISTER_KERNEL_BUILDER(Name("XlaEinsum").Device(DEVICE),                   \
                           XlaCompileOnDemandOp);                              \
+  REGISTER_KERNEL_BUILDER(Name("XlaRngBitGenerator")                          \
+                              .HostMemory("algorithm")                        \
+                              .HostMemory("shape")                            \
+                              .Device(DEVICE),                                \
+                          XlaCompileOnDemandOp);                              \
   REGISTER_KERNEL_BUILDER(Name("XlaSpmdShardToFullShape").Device(DEVICE),     \
                           XlaCompileOnDemandOp);                              \
   REGISTER_KERNEL_BUILDER(Name("XlaSharding").Device(DEVICE),                 \
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 6122044f4ca4d3..c6eb9e83c24ea6 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -19481,6 +19481,31 @@ def TF_XlaReplicaIdOp : TF_Op<"XlaReplicaId", [NoSideEffect, TF_NoConstantFold]>
TF_Op<"XlaReplicaId", [NoSideEffect, TF_NoConstantFold]> // constant folded at the compile time. } +def TF_XlaRngBitGeneratorOp : TF_Op<"XlaRngBitGenerator", [NoSideEffect]> { + let summary = "Stateless PRNG bit generator."; + + let description = [{ +Wraps the XLA RngBitGenerator operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator. + }]; + + let arguments = (ins + Arg:$algorithm, + Arg:$initial_state, + Arg:$shape + ); + + let results = (outs + TF_Uint64Tensor:$output_key, + TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output + ); + + TF_DerivedOperandTypeAttr Tshape = TF_DerivedOperandTypeAttr<2>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<1>; +} + def TF_XlaScatterOp : TF_Op<"XlaScatter", [NoSideEffect]> { let summary = "Wraps the XLA Scatter operator documented at"; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc index 08f7273e3b7f07..1172a5799eda1e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc @@ -120,6 +120,8 @@ void AddSupportedFunctionalOps(MLIRContext* context, OperationName(TF::XlaReduceOp::getOperationName(), context)); supported_ops->insert( OperationName(TF::XlaReduceWindowOp::getOperationName(), context)); + supported_ops->insert( + OperationName(TF::XlaRngBitGeneratorOp::getOperationName(), context)); supported_ops->insert( OperationName(TF::XlaScatterOp::getOperationName(), context)); supported_ops->insert( diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index fb832e1a17d4f4..a3404b19cdf991 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -34,6 +34,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import stateless_random_ops from tensorflow.python.platform import googletest @@ -316,6 +317,26 @@ def pad_fn(x): [[7, 7, 1, 7], [7, 7, 7, 7], [7, 7, 4, 7], [7, 7, 7, 7]], dtype=dtype)) + @parameterized.parameters(stateless_random_ops.Algorithm.THREEFRY, + stateless_random_ops.Algorithm.PHILOX, + stateless_random_ops.Algorithm.AUTO_SELECT) + @test_util.disable_mlir_bridge('Not supported yet') + def testRngBitGeneratorIsDeterministic(self, algorithm): + dtype = np.uint32 + key = np.array([1, 2], dtype=np.uint64) + shape = (10, 12) + + def rng_fun_is_deterministic(k): + res1 = xla.rng_bit_generator(algorithm, k, shape, dtype=dtype) + res2 = xla.rng_bit_generator(algorithm, k, shape, dtype=dtype) + return (res1[0] - res2[0], res1[1] - res2[1]) + + self._assertOpOutputMatchesExpected( + rng_fun_is_deterministic, + args=(key,), + expected=(np.zeros(key.shape, dtype=key.dtype), + np.zeros(shape, dtype=dtype))) + @test_util.disable_mlir_bridge('Not supported yet') def testReduce(self): for dtype in set(self.numeric_types).intersection( @@ -968,6 +989,41 @@ def assert_output_shapes(output, expected_shape): 'All inputs must have the same shape'): reduce_with_shapes((None, 4, 5), (3, None, 5), (13, 4, 5)) + @parameterized.parameters(stateless_random_ops.Algorithm.THREEFRY, + stateless_random_ops.Algorithm.PHILOX, + stateless_random_ops.Algorithm.AUTO_SELECT) + def testRngBitGenerator(self, algorithm): + dtype = 
+    initial_state = array_ops.placeholder(np.uint64, shape=(2,))
+    shape = (2, 3)
+    res = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype)
+
+    self.assertEqual(res[0].shape, initial_state.shape)
+    self.assertEqual(res[1].shape, shape)
+
+    # The initial_state has an unknown dimension size.
+    initial_state = array_ops.placeholder(np.uint64, shape=(None,))
+    shape = (2, 3)
+    res = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype)
+
+    self.assertEqual(res[0].shape.as_list(), initial_state.shape.as_list())
+    self.assertEqual(res[1].shape, shape)
+
+    # The initial_state has unknown rank.
+    initial_state = array_ops.placeholder(np.uint64, shape=None)
+    shape = (2, 3)
+    res = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype)
+
+    self.assertEqual(res[0].shape.as_list(), [None])
+    self.assertEqual(res[1].shape, shape)
+
+    # The output shape has an unknown dimension.
+    initial_state = array_ops.placeholder(np.uint64, shape=(None,))
+    shape = (None, 3)
+    with self.assertRaisesRegex(TypeError,
+                                'Failed to convert object .* to Tensor'):
+      res = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype)
+
 if __name__ == '__main__':
   # This test is using Tensorflow sessions which are not compatible with eager
   # mode.
diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc
index dabc19a3628831..53e5b1c21365c0 100644
--- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc
@@ -575,5 +575,43 @@ class GetAlgOp : public XlaOpKernel {

 REGISTER_XLA_OP(Name("StatelessRandomGetAlg"), GetAlgOp);

+class XlaRngBitGeneratorOp : public XlaOpKernel {
+ public:
+  explicit XlaRngBitGeneratorOp(OpKernelConstruction* ctx)
+      : XlaOpKernel(ctx),
+        device_type_string_(ctx->device_type().type_string()) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::RandomAlgorithm algorithm;
+    OP_REQUIRES_OK(ctx,
+                   AlgorithmFromInput(ctx, 0, device_type_string_, &algorithm));
+    xla::XlaOp initial_state = ctx->Input(1);
+
+    TensorShape shape;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(2, &shape));
+    xla::Shape xla_shape;
+    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
+
+    xla::XlaOp result =
+        xla::RngBitGenerator(algorithm, initial_state, xla_shape);
+    ctx->SetOutput(0, xla::GetTupleElement(result, 0));
+    ctx->SetOutput(1, xla::GetTupleElement(result, 1));
+  }
+
+ private:
+  DataType dtype_;
+  string device_type_string_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaRngBitGeneratorOp);
+};
+
+REGISTER_XLA_OP(Name("XlaRngBitGenerator")
+                    .CompileTimeConstantInput("algorithm")
+                    .CompileTimeConstantInput("shape")
+                    .TypeConstraint("dtype", {DT_UINT32, DT_UINT64}),
+                XlaRngBitGeneratorOp);
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 42c45d6864253f..ff4cc5c9dd8675 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -807,6 +807,39 @@ window_strides: the inter-window strides
 padding: the padding to apply at the start and end of each input dimensions
 )doc");

+REGISTER_OP("XlaRngBitGenerator")
+    .Input("algorithm: int32")
+    .Input("initial_state: uint64")
+    .Input("shape: Tshape")
+    .Output("output_key: uint64")
+    .Output("output: dtype")
+    .Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
DT_UINT64") + .Attr("Tshape: {int32, int64} = DT_INT32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle algorithm; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &algorithm)); + shape_inference::ShapeHandle initial_state; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &initial_state)); + + c->set_output(0, initial_state); + shape_inference::ShapeHandle output; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output)); + c->set_output(1, output); + return Status::OK(); + }) + .Doc(R"doc( +Stateless PRNG bit generator. +Wraps the XLA RngBitGenerator operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator. + +algorithm: The PRNG algorithm to use, one of + tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}. +initial_state: Initial state for the PRNG algorithm. For THREEFRY, it should be + a u64[2] and for PHILOX a u64[3]. +shape: The output shape of the generated data. +dtype: The type of the tensor. +)doc"); + REGISTER_OP("XlaSelectAndScatter") .Input("operand: T") .Input("window_dimensions: Tindices") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 1c136406cb6cfd..f9ee76c7d1652f 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -40,6 +40,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops +from tensorflow.python.ops import stateless_random_ops from tensorflow.python.ops.numpy_ops import np_utils # TODO(phawkins): provide wrappers for all XLA operators. Currently the missing @@ -377,6 +378,28 @@ def random_uniform(minval, maxval, dims, name=None): dims, minval, maxval, dtype=minval.dtype, name=name) +def rng_bit_generator(algorithm, initial_state, shape, dtype): + """Stateless PRNG bit generator. + + Wraps the XLA RngBitGenerator operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator. + + Args: + algorithm: The PRNG algorithm to use, one of + tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}. + initial_state: Initial state for the PRNG algorithm. For THREEFRY, it + should be a u64[2] and for PHILOX a u64[3]. + shape: The output shape of the generated data. + dtype: The type of the tensor. + + Returns: + a tuple with a new state and generated data of the given shape. + """ + alg_int = stateless_random_ops.convert_alg_to_int(algorithm) + return gen_xla_ops.xla_rng_bit_generator(alg_int, initial_state, shape, + dtype=dtype) + + recv = gen_xla_ops.xla_recv reduce = gen_xla_ops.xla_reduce variadic_reduce = gen_xla_ops.xla_variadic_reduce_v2