Implement "bad_indcies_policy" for ScatterNd.
For testing, we also introduce a "ScatterNdTest" that verifies the existing "default" behavior for comparison.

PiperOrigin-RevId: 640429046
tensorflower-gardener committed Jun 5, 2024
1 parent 4430cec commit b151b39
Showing 2 changed files with 209 additions and 10 deletions.
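
In short, the kernel now parses the new "bad_indices_policy" string attribute into the BadIndicesPolicy enum and uses it to decide whether out-of-range indices fail the op or are silently dropped. A minimal sketch of that decision, mirroring the check added to DoScatterNdImpl below (the helper name ShouldErrorOnBadIndex is illustrative only, and the kIgnore enumerator is assumed; only kDefault and kError appear in this diff):

#include "tensorflow/core/util/bad_indices_policy.h"

namespace tensorflow {

// Sketch: with the default policy only the CPU kernel validates indices
// (preserving the existing behavior), "ERROR" always validates, and
// "IGNORE" (assumed to map to kIgnore) skips validation so out-of-range
// updates are silently dropped.
inline bool ShouldErrorOnBadIndex(bool is_cpu_device,
                                  BadIndicesPolicy policy) {
  return (is_cpu_device && policy == BadIndicesPolicy::kDefault) ||
         policy == BadIndicesPolicy::kError;
}

}  // namespace tensorflow
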
93 changes: 83 additions & 10 deletions tensorflow/core/kernels/scatter_nd_op.cc
@@ -16,13 +16,18 @@ limitations under the License.
// See docs in ../ops/state_ops.cc.
#define EIGEN_USE_THREADS

#include <string>
#include <type_traits>

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "absl/status/statusor.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -35,8 +40,11 @@ limitations under the License.
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bad_indices_policy.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/util.h"

@@ -45,6 +53,19 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {
constexpr char kBadIndicesPolicyAttr[] = "bad_indices_policy";
} // namespace

namespace functor {

template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape, Tensor* out,
bool allocate, BadIndicesPolicy bad_indices_policy);
} // namespace functor

// Returns true if the three tensors have valid number of elements
// If shape_input has 0 elements, then we need to have indices and updates with
// exactly 0 elements too, otherwise we should error. If indices has 0 elements
@@ -65,6 +86,19 @@ class ScatterNdOp : public OpKernel {
const DataType dt = DataTypeToEnum<T>::v();
const DataType index_t = DataTypeToEnum<Index>::v();
OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
std::string bad_indices_policy_str;
OP_REQUIRES_OK(c,
c->GetAttr(kBadIndicesPolicyAttr, &bad_indices_policy_str));
absl::StatusOr<BadIndicesPolicy> bad_indices_policy =
BadIndicesPolicyFromString(bad_indices_policy_str);
OP_REQUIRES_OK(c, bad_indices_policy.status());
bad_indices_policy_ = *bad_indices_policy;
if constexpr (std::is_same<Device, GPUDevice>::value) {
OP_REQUIRES(
c, bad_indices_policy_ != BadIndicesPolicy::kError,
errors::InvalidArgument(
"ERROR bad_indices_policy is not supported on GPU devices."));
}
}

void Compute(OpKernelContext* c) override {
@@ -128,9 +162,13 @@ class ScatterNdOp : public OpKernel {
Tensor out;
OP_REQUIRES_OK(
c, functor::DoScatterNd<Device, T, Index, scatter_nd_op::UpdateOp::ADD>(
c, indices, updates, shape, &out, true /*allocate*/));
c, indices, updates, shape, &out, true /*allocate*/,
bad_indices_policy_));
c->set_output(0, out);
}

private:
BadIndicesPolicy bad_indices_policy_ = BadIndicesPolicy::kDefault;
};

template <typename Device, typename T, typename Index,
@@ -881,7 +919,8 @@ template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op>
Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape,
Tensor* out, bool allocate) {
Tensor* out, bool allocate,
BadIndicesPolicy bad_indices_policy) {
int64_t slice_dim;
Index num_updates;
Index slice_size;
@@ -947,7 +986,11 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
slice_dim);
}
}
if (bad_i >= 0) {
const bool check_bad_indices =
((std::is_same<Device, CPUDevice>::value &&
bad_indices_policy == BadIndicesPolicy::kDefault) ||
bad_indices_policy == BadIndicesPolicy::kError);
if (check_bad_indices && bad_i >= 0) {
auto slice_shape = indices.shape();
slice_shape.RemoveLastDims(1);
return errors::InvalidArgument(
@@ -959,10 +1002,28 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
return absl::OkStatus();
}

template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op>
Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape,
Tensor* out, bool allocate) {
return DoScatterNdImpl<Device, T, Index, Op>(
c, indices, updates, shape, out, allocate, BadIndicesPolicy::kDefault);
}

template <typename T, typename Index, scatter_nd_op::UpdateOp Op>
Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape,
Tensor* out, bool allocate);
Tensor* out, bool allocate,
BadIndicesPolicy bad_indices_policy);

template <typename T, typename Index, scatter_nd_op::UpdateOp Op>
Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape,
Tensor* out, bool allocate) {
return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
allocate, BadIndicesPolicy::kDefault);
}

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -973,7 +1034,8 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
template <typename T, typename Index, scatter_nd_op::UpdateOp Op>
Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape,
Tensor* out, bool allocate) {
Tensor* out, bool allocate,
BadIndicesPolicy bad_indices_policy) {
AllocatorAttributes alloc_attr;
alloc_attr.set_on_host(true);
alloc_attr.set_gpu_compatible(true);
@@ -1016,7 +1078,8 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,

TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
TF_RETURN_IF_ERROR(DoScatterNd<CPUDevice, T, Index, Op>(
c, host_indices, host_updates, shape, &host_out, /*allocate=*/false));
c, host_indices, host_updates, shape, &host_out, /*allocate=*/false,
bad_indices_policy));

// Copy 'host_out' to device.
se::DeviceMemoryBase out_ptr(out->flat<T>().data(),
@@ -1036,7 +1099,7 @@ template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape, Tensor* out,
bool allocate) {
bool allocate, BadIndicesPolicy bad_indices_policy) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (std::is_same<Device, GPUDevice>::value &&
tensorflow::OpDeterminismRequired() && !DisableScatterOpDeterminism()) {
@@ -1050,12 +1113,22 @@ Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
if constexpr (std::is_same<Device, GPUDevice>::value &&
std::is_integral<T>::value) {
return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
allocate);
allocate, bad_indices_policy);
} else {
return DoScatterNdImpl<Device, T, Index, Op>(c, indices, updates, shape,
out, allocate);
return DoScatterNdImpl<Device, T, Index, Op>(
c, indices, updates, shape, out, allocate, bad_indices_policy);
}
}

template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape, Tensor* out,
bool allocate) {
return DoScatterNd<Device, T, Index, Op>(
c, indices, updates, shape, out, allocate, BadIndicesPolicy::kDefault);
}

} // namespace functor

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
126 changes: 126 additions & 0 deletions tensorflow/core/kernels/scatter_nd_op_test.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

@@ -240,6 +241,131 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
<< s;
}

class ScatterNdOpTest : public OpsTestBase {
protected:
void MakeOp(DataType variable_type, DataType index_type) {
TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNd")
.Input(FakeInput(index_type))
.Input(FakeInput(variable_type))
.Input(FakeInput(DT_INT32))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
}
};

TEST_F(ScatterNdOpTest, Simple_OneD) {
MakeOp(DT_FLOAT, DT_INT32);

// Feed and run
// Index: [[0], [4], [2]].
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
// Updates: [100, 101, 102].
AddInputFromArray<float>(TensorShape({3, 1}), {100, 101, 102});
// Shape: output tensor of 5x1 shape.
AddInputFromArray<int32>(TensorShape({2}), {5, 1});
TF_ASSERT_OK(RunOpKernel());

// Check the output.
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 1}));
test::FillValues<float>(&expected, {100, 0, 102, 0, 101});
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}

TEST_F(ScatterNdOpTest, Error_IndexOutOfRange) {
MakeOp(DT_FLOAT, DT_INT32);

// Feed and run
// Put the bad index in the middle to make sure the others are still updated.
// Index: [[0], [5], [2]].
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 5, 2});
// Updates: [100, 101, 102].
AddInputFromArray<float>(TensorShape({3, 1}), {100, 101, 102});
// Shape: output tensor of 5x1 shape.
AddInputFromArray<int32>(TensorShape({2}), {5, 1});
Status s = RunOpKernel();
// The valid index range is [0,5). Expect "5" to raise error.
EXPECT_TRUE(absl::StrContains(
s.ToString(), "indices[1] = [5] does not index into shape [5,1]"))
<< s;
}

class ScatterNdOpErrorOnBadIndicesTest : public OpsTestBase {
protected:
void MakeOp(DataType variable_type, DataType index_type) {
TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNd")
.Input(FakeInput(index_type))
.Input(FakeInput(variable_type))
.Input(FakeInput(DT_INT32))
.Attr("bad_indices_policy", "ERROR")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
}
};

TEST_F(ScatterNdOpErrorOnBadIndicesTest, Error_IndexOutOfRange) {
MakeOp(DT_FLOAT, DT_INT32);

// Feed and run
// Put the bad index in the middle to make sure the others are still updated.
// Index: [[0], [5], [2]].
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 5, 2});
// Updates: [100, 101, 102].
AddInputFromArray<float>(TensorShape({3, 1}), {100, 101, 102});
// Shape: output tensor of 5x1 shape.
AddInputFromArray<int32>(TensorShape({2}), {5, 1});
Status s = RunOpKernel();
// The valid index range is [0,5). Expect "5" to raise error.
EXPECT_TRUE(absl::StrContains(
s.ToString(), "indices[1] = [5] does not index into shape [5,1]"))
<< s;
}

class ScatterNdOpIgnoreBadIndicesTest : public OpsTestBase {
protected:
void MakeOp(DataType variable_type, DataType index_type) {
TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNd")
.Input(FakeInput(index_type))
.Input(FakeInput(variable_type))
.Input(FakeInput(DT_INT32))
.Attr("bad_indices_policy", "IGNORE")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
}
};

TEST_F(ScatterNdOpIgnoreBadIndicesTest, DropOutOfRangeIndices) {
MakeOp(DT_FLOAT, DT_INT32);

// Feed and run
// Put the bad index in the middle to make sure the others are still updated.
// Index: [[0], [5], [2]].
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 5, 2});
// Updates: [100, 101, 102].
AddInputFromArray<float>(TensorShape({3, 1}), {100, 101, 102});
// Shape: output tensor of 5x1 shape.
AddInputFromArray<int32>(TensorShape({2}), {5, 1});
TF_ASSERT_OK(RunOpKernel());

// Check the output.
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 1}));
// The valid index range is [0,5). Expect to drop index[1] of value "5" and
// update output[0] and output[2].
test::FillValues<float>(&expected, {100, 0, 102, 0, 0});
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}

class ScatterNdOpConstructionTest : public OpsTestBase {};

TEST_F(ScatterNdOpConstructionTest, Error_BadIndicesPolicyInvalid) {
TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNd")
.Input(FakeInput(DT_INT32))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_INT32))
.Attr("bad_indices_policy", "AN_UNRECOGNIZED_POLICY")
.Finalize(node_def()));
EXPECT_NE(InitOp(), OkStatus());
}

class ScatterNdUpdateBM : public ScatterNdUpdateOpTest {
public:
void TestBody() override {}
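
Taken together, the new tests show how callers select a policy when constructing the op. A minimal usage sketch modeled on those tests (the MakeScatterNdNodeDef helper is illustrative only, and FakeInput is the test-only input stub used above):

#include <string>

#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

// Illustrative helper (not part of this commit): builds a ScatterNd NodeDef
// with the given policy string. "" keeps today's behavior, "ERROR" makes
// out-of-range indices an error (and is rejected at construction on GPU),
// and "IGNORE" drops out-of-range updates.
Status MakeScatterNdNodeDef(const std::string& bad_indices_policy,
                            NodeDef* node_def) {
  return NodeDefBuilder("scatter_nd_example", "ScatterNd")
      .Input(FakeInput(DT_INT32))   // indices
      .Input(FakeInput(DT_FLOAT))   // updates
      .Input(FakeInput(DT_INT32))   // output shape
      .Attr("bad_indices_policy", bad_indices_policy)
      .Finalize(node_def);
}

}  // namespace tensorflow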
