Add a new op XlaRngBitGenerator that wraps the HLO op RngBitGenerator.
See https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator.

The use case for this op is jax2tf, where we need to expose the plain
HLO op. There are already TensorFlow ops that use the same HLO op, but they
all add some amount of sugar on top, which makes it hard to ensure that the
TF op has the exact semantics of the HLO op.

PiperOrigin-RevId: 388242420
Change-Id: I55b35c36a087b53b20694aac66f13a1aa8e2f1d2
gnecula authored and tensorflower-gardener committed Aug 2, 2021
1 parent 71e2833 commit e02c9ff
Showing 9 changed files with 186 additions and 0 deletions.
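
For orientation, here is a minimal sketch of the intended jax2tf-style call path through the new Python wrapper. The tf.function(jit_compile=True) entry point and the concrete key and shape values are illustrative assumptions, not part of this commit; the op itself is only lowered when the surrounding function is XLA-compiled.

import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.ops import stateless_random_ops

@tf.function(jit_compile=True)  # the wrapped HLO op requires XLA compilation
def draw_bits(state):
  # Returns (new_state, bits); THREEFRY expects a u64[2] state.
  return xla.rng_bit_generator(
      stateless_random_ops.Algorithm.THREEFRY, state, (4, 4),
      dtype=tf.uint32)

new_state, bits = draw_bits(np.array([1, 2], dtype=np.uint64))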
3 changes: 3 additions & 0 deletions tensorflow/compiler/jit/compilability_check_util.cc
@@ -705,6 +705,7 @@ static auto const ops_triggering_xla_compilation =
"XlaReduce",
"XlaReduceWindow",
"XlaReplicaId",
"XlaRngBitGenerator",
"XlaScatter",
"XlaSelectAndScatter",
"XlaSelfAdjointEig",
@@ -714,6 +715,8 @@ static auto const
"XlaSpmdFullToShardShape",
"XlaSpmdShardToFullShape",
"XlaSvd",
"XlaVariadicReduceV2",
"XlaVariadicSort",
"XlaWhile"};

static bool NodeCanTriggerXlaCompilation(const NodeDef& node) {
1 change: 1 addition & 0 deletions tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -2083,6 +2083,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"XlaReduceWindow",
"XlaRemoveDynamicDimensionSize",
"XlaReplicaId",
"XlaRngBitGenerator",
"XlaScatter",
"XlaSelectAndScatter",
"XlaSelfAdjointEig",
5 changes: 5 additions & 0 deletions tensorflow/compiler/jit/xla_ops_on_regular_devices.cc
@@ -97,6 +97,11 @@ namespace tensorflow {
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaEinsum").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaRngBitGenerator") \
.HostMemory("algorithm") \
.HostMemory("shape") \
.Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSpmdShardToFullShape").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSharding").Device(DEVICE), \
25 changes: 25 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -19481,6 +19481,31 @@ def TF_XlaReplicaIdOp : TF_Op<"XlaReplicaId", [NoSideEffect, TF_NoConstantFold]>
// constant folded at the compile time.
}

def TF_XlaRngBitGeneratorOp : TF_Op<"XlaRngBitGenerator", [NoSideEffect]> {
let summary = "Stateless PRNG bit generator.";

let description = [{
Wraps the XLA RngBitGenerator operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.
}];

let arguments = (ins
Arg<TF_Int32Tensor, [{The PRNG algorithm to use, one of
tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.}]>:$algorithm,
Arg<TF_Uint64Tensor, [{Initial state for the PRNG algorithm. For THREEFRY, it should be
a u64[2] and for PHILOX a u64[3].}]>:$initial_state,
Arg<TF_I32OrI64Tensor, [{The output shape of the generated data.}]>:$shape
);

let results = (outs
TF_Uint64Tensor:$output_key,
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
);

TF_DerivedOperandTypeAttr Tshape = TF_DerivedOperandTypeAttr<2>;
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<1>;
}

def TF_XlaScatterOp : TF_Op<"XlaScatter", [NoSideEffect]> {
let summary = "Wraps the XLA Scatter operator documented at";

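The two derived attributes in the TF_XlaRngBitGeneratorOp definition above mean that neither Tshape nor dtype is stored explicitly in the MLIR form: Tshape is recovered from the element type of the shape operand, and dtype from the element type of the data result. A hedged sketch against the generated raw op (the enum value and inputs are illustrative assumptions, and the call must ultimately run under XLA compilation):

import numpy as np
from tensorflow.compiler.tf2xla.ops import gen_xla_ops

output_key, output = gen_xla_ops.xla_rng_bit_generator(
    algorithm=2,  # assumed to be THREEFRY in tf.random.Algorithm numbering
    initial_state=np.array([1, 2], dtype=np.uint64),  # u64[2] for THREEFRY
    shape=np.array([2, 3], dtype=np.int32),  # Tshape is derived as int32
    dtype=np.uint64)  # fixes the element type of `output`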
@@ -120,6 +120,8 @@ void AddSupportedFunctionalOps(MLIRContext* context,
OperationName(TF::XlaReduceOp::getOperationName(), context));
supported_ops->insert(
OperationName(TF::XlaReduceWindowOp::getOperationName(), context));
supported_ops->insert(
OperationName(TF::XlaRngBitGeneratorOp::getOperationName(), context));
supported_ops->insert(
OperationName(TF::XlaScatterOp::getOperationName(), context));
supported_ops->insert(
56 changes: 56 additions & 0 deletions tensorflow/compiler/tests/xla_ops_test.py
@@ -34,6 +34,7 @@
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.platform import googletest


@@ -316,6 +317,26 @@ def pad_fn(x):
[[7, 7, 1, 7], [7, 7, 7, 7], [7, 7, 4, 7], [7, 7, 7, 7]],
dtype=dtype))

@parameterized.parameters(stateless_random_ops.Algorithm.THREEFRY,
stateless_random_ops.Algorithm.PHILOX,
stateless_random_ops.Algorithm.AUTO_SELECT)
@test_util.disable_mlir_bridge('Not supported yet')
def testRngBitGeneratorIsDeterministic(self, algorithm):
dtype = np.uint32
key = np.array([1, 2], dtype=np.uint64)
shape = (10, 12)

def rng_fun_is_deterministic(k):
res1 = xla.rng_bit_generator(algorithm, k, shape, dtype=dtype)
res2 = xla.rng_bit_generator(algorithm, k, shape, dtype=dtype)
return (res1[0] - res2[0], res1[1] - res2[1])

self._assertOpOutputMatchesExpected(
rng_fun_is_deterministic,
args=(key,),
expected=(np.zeros(key.shape, dtype=key.dtype),
np.zeros(shape, dtype=dtype)))

@test_util.disable_mlir_bridge('Not supported yet')
def testReduce(self):
for dtype in set(self.numeric_types).intersection(
@@ -968,6 +989,41 @@ def assert_output_shapes(output, expected_shape):
'All inputs must have the same shape'):
reduce_with_shapes((None, 4, 5), (3, None, 5), (13, 4, 5))

@parameterized.parameters(stateless_random_ops.Algorithm.THREEFRY,
stateless_random_ops.Algorithm.PHILOX,
stateless_random_ops.Algorithm.AUTO_SELECT)
def testRngBitGenerator(self, algorithm):
dtype = np.uint64
initial_state = array_ops.placeholder(np.uint64, shape=(2,))
shape = (2, 3)
res = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype)

self.assertEqual(res[0].shape, initial_state.shape)
self.assertEqual(res[1].shape, shape)

# The initial_state has unknown dimension size
initial_state = array_ops.placeholder(np.uint64, shape=(None,))
shape = (2, 3)
res = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype)

self.assertEqual(res[0].shape.as_list(), initial_state.shape.as_list())
self.assertEqual(res[1].shape, shape)

# The initial_state has unknown rank
initial_state = array_ops.placeholder(np.uint64, shape=None)
shape = (2, 3)
res = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype)

self.assertEqual(res[0].shape.as_list(), [None])
self.assertEqual(res[1].shape, shape)

# The output shape has unknown dimension
initial_state = array_ops.placeholder(np.uint64, shape=(None,))
shape = (None, 3)
with self.assertRaisesRegex(TypeError,
'Failed to convert object .* to Tensor'):
res = xla.rng_bit_generator(algorithm, initial_state, shape, dtype=dtype)


if __name__ == '__main__':
# This test is using Tensorflow sessions which are not compatible with eager
38 changes: 38 additions & 0 deletions tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc
@@ -575,5 +575,43 @@ class GetAlgOp : public XlaOpKernel {

REGISTER_XLA_OP(Name("StatelessRandomGetAlg"), GetAlgOp);

class XlaRngBitGeneratorOp : public XlaOpKernel {
public:
explicit XlaRngBitGeneratorOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx),
device_type_string_(ctx->device_type().type_string()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}

void Compile(XlaOpKernelContext* ctx) override {
xla::RandomAlgorithm algorithm;
OP_REQUIRES_OK(ctx,
AlgorithmFromInput(ctx, 0, device_type_string_, &algorithm));
xla::XlaOp initial_state = ctx->Input(1);

TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(2, &shape));
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));

xla::XlaOp result =
xla::RngBitGenerator(algorithm, initial_state, xla_shape);
ctx->SetOutput(0, xla::GetTupleElement(result, 0));
ctx->SetOutput(1, xla::GetTupleElement(result, 1));
}

private:
DataType dtype_;
string device_type_string_;

TF_DISALLOW_COPY_AND_ASSIGN(XlaRngBitGeneratorOp);
};

REGISTER_XLA_OP(Name("XlaRngBitGenerator")
.CompileTimeConstantInput("algorithm")
.CompileTimeConstantInput("shape")
.TypeConstraint("dtype", {DT_UINT32, DT_UINT64}),
XlaRngBitGeneratorOp);

} // namespace
} // namespace tensorflow
33 changes: 33 additions & 0 deletions tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -807,6 +807,39 @@ window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimensions
)doc");

REGISTER_OP("XlaRngBitGenerator")
.Input("algorithm: int32")
.Input("initial_state: uint64")
.Input("shape: Tshape")
.Output("output_key: uint64")
.Output("output: dtype")
.Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
.Attr("Tshape: {int32, int64} = DT_INT32")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle algorithm;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &algorithm));
shape_inference::ShapeHandle initial_state;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &initial_state));

c->set_output(0, initial_state);
shape_inference::ShapeHandle output;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &output));
c->set_output(1, output);
return Status::OK();
})
.Doc(R"doc(
Stateless PRNG bit generator.
Wraps the XLA RngBitGenerator operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.
algorithm: The PRNG algorithm to use, one of
tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
initial_state: Initial state for the PRNG algorithm. For THREEFRY, it should be
a u64[2] and for PHILOX a u64[3].
shape: The output shape of the generated data.
dtype: The type of the tensor.
)doc");

REGISTER_OP("XlaSelectAndScatter")
.Input("operand: T")
.Input("window_dimensions: Tindices")
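The shape function in the REGISTER_OP above encodes the op's static-shape contract: output_key mirrors the rank-1 initial_state, and output takes its shape from the compile-time constant shape input. A hedged graph-mode sketch of what this yields at construction time, mirroring testRngBitGenerator earlier in this commit:

import numpy as np
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import stateless_random_ops

state = array_ops.placeholder(np.uint64, shape=(2,))  # TF1-style graph mode
key, data = xla.rng_bit_generator(
    stateless_random_ops.Algorithm.THREEFRY, state, (2, 3), dtype=np.uint64)
# key.shape  == (2,)    -- copied from initial_state by the shape function
# data.shape == (2, 3)  -- from MakeShapeFromShapeTensor on the shape input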
23 changes: 23 additions & 0 deletions tensorflow/compiler/tf2xla/python/xla.py
@@ -40,6 +40,7 @@
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops.numpy_ops import np_utils

# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
@@ -377,6 +378,28 @@ def random_uniform(minval, maxval, dims, name=None):
dims, minval, maxval, dtype=minval.dtype, name=name)


def rng_bit_generator(algorithm, initial_state, shape, dtype):
"""Stateless PRNG bit generator.
Wraps the XLA RngBitGenerator operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.
Args:
algorithm: The PRNG algorithm to use, one of
tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
initial_state: Initial state for the PRNG algorithm. For THREEFRY, it
should be a u64[2] and for PHILOX a u64[3].
shape: The output shape of the generated data.
dtype: The type of the tensor.
Returns:
a tuple with a new state and generated data of the given shape.
"""
alg_int = stateless_random_ops.convert_alg_to_int(algorithm)
return gen_xla_ops.xla_rng_bit_generator(alg_int, initial_state, shape,
dtype=dtype)


recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
variadic_reduce = gen_xla_ops.xla_variadic_reduce_v2
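One usage note on the wrapper above: the generator is stateless, so callers obtain fresh bits by threading the returned state through successive calls themselves. A small hedged sketch (numpy inputs as in the unit tests; the calls must run under XLA compilation):

import numpy as np
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.ops import stateless_random_ops

state = np.array([1, 2], dtype=np.uint64)  # THREEFRY state is a u64[2]
state, bits1 = xla.rng_bit_generator(
    stateless_random_ops.Algorithm.THREEFRY, state, (8,), dtype=np.uint32)
state, bits2 = xla.rng_bit_generator(  # reuse the updated state for new bits
    stateless_random_ops.Algorithm.THREEFRY, state, (8,), dtype=np.uint32)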
