diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc
index 5f7f62fc99606e..493f64baac8f8d 100644
--- a/tensorflow/compiler/jit/compilability_check_util.cc
+++ b/tensorflow/compiler/jit/compilability_check_util.cc
@@ -705,6 +705,7 @@ static auto const ops_triggering_xla_compilation =
         "XlaReduce",
         "XlaReduceWindow",
         "XlaReplicaId",
+        "XlaRngBitGenerator",
         "XlaScatter",
         "XlaSelectAndScatter",
         "XlaSelfAdjointEig",
@@ -714,6 +715,8 @@ static auto const ops_triggering_xla_compilation =
         "XlaSpmdFullToShardShape",
         "XlaSpmdShardToFullShape",
         "XlaSvd",
+        "XlaVariadicReduceV2",
+        "XlaVariadicSort",
         "XlaWhile"};
 
 static bool NodeCanTriggerXlaCompilation(const NodeDef& node) {
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index f5c7c30e2918db..547ad0b10277ed 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -2083,6 +2083,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
             "XlaReduceWindow",
             "XlaRemoveDynamicDimensionSize",
             "XlaReplicaId",
+            "XlaRngBitGenerator",
             "XlaScatter",
             "XlaSelectAndScatter",
             "XlaSelfAdjointEig",
diff --git a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
index 0c08dd9210e53a..0dce2d0f7a4ab9 100644
--- a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
+++ b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
@@ -97,6 +97,11 @@ namespace tensorflow {
                           XlaCompileOnDemandOp);                              \
   REGISTER_KERNEL_BUILDER(Name("XlaEinsum").Device(DEVICE),                   \
                           XlaCompileOnDemandOp);                              \
+  REGISTER_KERNEL_BUILDER(Name("XlaRngBitGenerator")                          \
+                              .HostMemory("algorithm")                        \
+                              .HostMemory("shape")                            \
+                              .Device(DEVICE),                                \
+                          XlaCompileOnDemandOp);                              \
   REGISTER_KERNEL_BUILDER(Name("XlaSpmdShardToFullShape").Device(DEVICE),     \
                           XlaCompileOnDemandOp);                              \
   REGISTER_KERNEL_BUILDER(Name("XlaSharding").Device(DEVICE),                 \
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 6122044f4ca4d3..c6eb9e83c24ea6 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -19481,6 +19481,31 @@ def TF_XlaReplicaIdOp : TF_Op<"XlaReplicaId", [NoSideEffect, TF_NoConstantFold]>
   // constant folded at the compile time.
 }
 
+def TF_XlaRngBitGeneratorOp : TF_Op<"XlaRngBitGenerator", [NoSideEffect]> {
+  let summary = "Stateless PRNG bit generator.";
+
+  let description = [{
+Wraps the XLA RngBitGenerator operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.
+  }];
+
+  let arguments = (ins
+    Arg<TF_Int32Tensor, [{The PRNG algorithm to use, one of
+tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.}]>:$algorithm,
+    Arg<TF_Uint64Tensor, [{Initial state for the PRNG algorithm. For THREEFRY, it should be
+a u64[2] and for PHILOX a u64[3].}]>:$initial_state,
+    Arg<TF_I32OrI64Tensor, [{The output shape of the generated data.}]>:$shape
+  );
+
+  let results = (outs
+    TF_Uint64Tensor:$output_key,
+    TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
+  );
+
+  TF_DerivedOperandTypeAttr Tshape = TF_DerivedOperandTypeAttr<2>;
+  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<1>;
+}
+
 def TF_XlaScatterOp : TF_Op<"XlaScatter", [NoSideEffect]> {
   let summary = "Wraps the XLA Scatter operator documented at";
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc
index 08f7273e3b7f07..1172a5799eda1e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc
@@ -120,6 +120,8 @@ void AddSupportedFunctionalOps(MLIRContext* context,
       OperationName(TF::XlaReduceOp::getOperationName(), context));
   supported_ops->insert(
       OperationName(TF::XlaReduceWindowOp::getOperationName(), context));
+  supported_ops->insert(
+      OperationName(TF::XlaRngBitGeneratorOp::getOperationName(), context));
   supported_ops->insert(
       OperationName(TF::XlaScatterOp::getOperationName(), context));
   supported_ops->insert(
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index fb832e1a17d4f4..a3404b19cdf991 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -34,6 +34,7 @@
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import stateless_random_ops
 from tensorflow.python.platform import googletest
 
 
@@ -316,6 +317,26 @@ def pad_fn(x):
             [[7, 7, 1, 7], [7, 7, 7, 7], [7, 7, 4, 7], [7, 7, 7, 7]],
             dtype=dtype))
 
+  @parameterized.parameters(stateless_random_ops.Algorithm.THREEFRY,
+                            stateless_random_ops.Algorithm.PHILOX,
+                            stateless_random_ops.Algorithm.AUTO_SELECT)
+  @test_util.disable_mlir_bridge('Not supported yet')
+  def testRngBitGeneratorIsDeterministic(self, algorithm):
+    dtype = np.uint32
+    key = np.array([1, 2], dtype=np.uint64)
+    shape = (10, 12)
+
+    def rng_fun_is_deterministic(k):
+      res1 = xla.rng_bit_generator(algorithm, k, shape, dtype=dtype)
+      res2 = xla.rng_bit_generator(algorithm, k, shape, dtype=dtype)
+      return (res1[0] - res2[0], res1[1] - res2[1])
+
+    self._assertOpOutputMatchesExpected(
+        rng_fun_is_deterministic,
+        args=(key,),
+        expected=(np.zeros(key.shape, dtype=key.dtype),
+                  np.zeros(shape, dtype=dtype)))
+
   @test_util.disable_mlir_bridge('Not supported yet')
   def testReduce(self):
     for dtype in set(self.numeric_types).intersection(
@@ -968,6 +989,41 @@ def assert_output_shapes(output, expected_shape):
                                 'All inputs must have the same shape'):
       reduce_with_shapes((None, 4, 5), (3, None, 5), (13, 4, 5))
 
+  @parameterized.parameters(stateless_random_ops.Algorithm.THREEFRY,
+                            stateless_random_ops.Algorithm.PHILOX,
+                            stateless_random_ops.Algorithm.AUTO_SELECT)
+  def testRngBitGenerator(self, algorithm):
+    dtype = np.uint64
+    initial_state = array_ops.placeholder(np.uint64, shape=(2,))
+    shape = (2, 3)
+    res = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype)
+
+    self.assertEqual(res[0].shape, initial_state.shape)
+    self.assertEqual(res[1].shape, shape)
+
+    # The initial_state has an unknown dimension size.
+    initial_state = array_ops.placeholder(np.uint64, shape=(None,))
+    shape = (2, 3)
+    res = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype)
+
+    self.assertEqual(res[0].shape.as_list(), initial_state.shape.as_list())
+    self.assertEqual(res[1].shape, shape)
+
+    # The initial_state has an unknown rank.
+    initial_state = array_ops.placeholder(np.uint64, shape=None)
+    shape = (2, 3)
+    res = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype)
+
+    self.assertEqual(res[0].shape.as_list(), [None])
+    self.assertEqual(res[1].shape, shape)
+
+    # The output shape has an unknown dimension.
+    initial_state = array_ops.placeholder(np.uint64, shape=(None,))
+    shape = (None, 3)
+    with self.assertRaisesRegex(TypeError,
+                                'Failed to convert object .* to Tensor'):
+      res = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype)
+
 
 if __name__ == '__main__':
   # This test is using Tensorflow sessions which are not compatible with eager
diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc
index dabc19a3628831..53e5b1c21365c0 100644
--- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc
@@ -575,5 +575,43 @@ class GetAlgOp : public XlaOpKernel {
 
 REGISTER_XLA_OP(Name("StatelessRandomGetAlg"), GetAlgOp);
 
+class XlaRngBitGeneratorOp : public XlaOpKernel {
+ public:
+  explicit XlaRngBitGeneratorOp(OpKernelConstruction* ctx)
+      : XlaOpKernel(ctx),
+        device_type_string_(ctx->device_type().type_string()) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::RandomAlgorithm algorithm;
+    OP_REQUIRES_OK(ctx,
+                   AlgorithmFromInput(ctx, 0, device_type_string_, &algorithm));
+    xla::XlaOp initial_state = ctx->Input(1);
+
+    TensorShape shape;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(2, &shape));
+    xla::Shape xla_shape;
+    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
+
+    xla::XlaOp result =
+        xla::RngBitGenerator(algorithm, initial_state, xla_shape);
+    ctx->SetOutput(0, xla::GetTupleElement(result, 0));
+    ctx->SetOutput(1, xla::GetTupleElement(result, 1));
+  }
+
+ private:
+  DataType dtype_;
+  string device_type_string_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaRngBitGeneratorOp);
+};
+
+REGISTER_XLA_OP(Name("XlaRngBitGenerator")
+                    .CompileTimeConstantInput("algorithm")
+                    .CompileTimeConstantInput("shape")
+                    .TypeConstraint("dtype", {DT_UINT32, DT_UINT64}),
+                XlaRngBitGeneratorOp);
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 42c45d6864253f..ff4cc5c9dd8675 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -807,6 +807,39 @@ window_strides: the inter-window strides
 padding: the padding to apply at the start and end of each input dimensions
 )doc");
 
+REGISTER_OP("XlaRngBitGenerator")
+    .Input("algorithm: int32")
+    .Input("initial_state: uint64")
+    .Input("shape: Tshape")
+    .Output("output_key: uint64")
+    .Output("output: dtype")
+    .Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
+    .Attr("Tshape: {int32, int64} = DT_INT32")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle algorithm;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &algorithm));
+      shape_inference::ShapeHandle initial_state;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &initial_state));
+
+      c->set_output(0, initial_state);
+      shape_inference::ShapeHandle output;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output));
+      c->set_output(1, output);
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Stateless PRNG bit generator.
+Wraps the XLA RngBitGenerator operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.
+
+algorithm: The PRNG algorithm to use, one of
+  tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
+initial_state: Initial state for the PRNG algorithm. For THREEFRY, it should be
+  a u64[2] and for PHILOX a u64[3].
+shape: The output shape of the generated data.
+dtype: The type of the tensor.
+)doc");
+
 REGISTER_OP("XlaSelectAndScatter")
     .Input("operand: T")
     .Input("window_dimensions: Tindices")
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 1c136406cb6cfd..f9ee76c7d1652f 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -40,6 +40,7 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import special_math_ops
+from tensorflow.python.ops import stateless_random_ops
 from tensorflow.python.ops.numpy_ops import np_utils
 
 # TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
@@ -377,6 +378,28 @@ def random_uniform(minval, maxval, dims, name=None):
       dims, minval, maxval, dtype=minval.dtype, name=name)
 
 
+def rng_bit_generator(algorithm, initial_state, shape, dtype):
+  """Stateless PRNG bit generator.
+
+  Wraps the XLA RngBitGenerator operator, documented at
+  https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.
+
+  Args:
+    algorithm: The PRNG algorithm to use, one of
+      tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
+    initial_state: Initial state for the PRNG algorithm. For THREEFRY, it
+      should be a u64[2] and for PHILOX a u64[3].
+    shape: The output shape of the generated data.
+    dtype: The type of the tensor.
+
+  Returns:
+    a tuple with a new state and generated data of the given shape.
+  """
+  alg_int = stateless_random_ops.convert_alg_to_int(algorithm)
+  return gen_xla_ops.xla_rng_bit_generator(alg_int, initial_state, shape,
+                                           dtype=dtype)
+
+
 recv = gen_xla_ops.xla_recv
 reduce = gen_xla_ops.xla_reduce
 variadic_reduce = gen_xla_ops.xla_variadic_reduce_v2
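
Example usage of the new xla.rng_bit_generator Python wrapper, a minimal sketch that is not part of the patch: it assumes a TensorFlow build that already contains this change, and it wraps the call in a jit-compiled tf.function so the op is lowered through the XLA bridge. The THREEFRY algorithm and the u64[2] key below are illustrative choices, not requirements of the op.

# Minimal usage sketch (not part of the patch); assumes a TensorFlow build
# that includes this change.
import numpy as np
import tensorflow as tf

from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.ops import stateless_random_ops


@tf.function(jit_compile=True)
def draw_bits(key):
  # THREEFRY expects a u64[2] initial state; the op returns a
  # (new_state, generated_bits) pair.
  return xla.rng_bit_generator(
      stateless_random_ops.Algorithm.THREEFRY, key, shape=(2, 3),
      dtype=np.uint64)


new_key, bits = draw_bits(np.array([1, 2], dtype=np.uint64))
print(new_key.shape, bits.shape)  # (2,) (2, 3)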