[go: nahoru, domu]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[determinism] Add GPU excepts, CPU d9m, and tests to crop_and_resize #48905

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions tensorflow/core/kernels/image/crop_and_resize_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
Expand Down Expand Up @@ -189,7 +190,7 @@ class CropAndResizeOp : public AsyncOpKernel {

if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeKernel."));
errors::Internal("Failed to launch CropAndResizeKernel."));
}
};

Expand Down Expand Up @@ -407,6 +408,14 @@ class CropAndResizeGradImageOp : public AsyncOpKernel {
context, grads.dim_size(3) == depth,
errors::InvalidArgument("image_size and grads are incompatible"), done);

if (std::is_same<Device, GPUDevice>::value) {
OP_REQUIRES_ASYNC(
context, !OpDeterminismRequired(), errors::Unimplemented(
"Deterministic GPU implementation of CropAndResizeBackpropImage"
" not available."),
done);
}

// Allocate output tensor.
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(
Expand All @@ -426,7 +435,7 @@ class CropAndResizeGradImageOp : public AsyncOpKernel {

if (!status) {
context->SetStatus(errors::Internal(
"Failed launch CropAndResizeBackpropImage kernel."));
"Failed to launch CropAndResizeBackpropImage kernel."));
}
};

Expand Down Expand Up @@ -545,8 +554,13 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {

const DeviceBase::CpuWorkerThreads& worker_threads =
*(context->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers, num_boxes,
cost_per_box, CropAndResizeBackImgPerBox);

// Sharding introduces nondeterminism when the gradients associated with
// more than one crop backprop into the same element in the source image.
int max_threads = OpDeterminismRequired() ? 1 : worker_threads.num_threads;

Shard(max_threads, worker_threads.workers, num_boxes, cost_per_box,
CropAndResizeBackImgPerBox);

return true;
}
Expand Down Expand Up @@ -610,6 +624,14 @@ class CropAndResizeGradBoxesOp : public AsyncOpKernel {
errors::InvalidArgument("boxes and grads have incompatible shape"),
done);

if (std::is_same<Device, GPUDevice>::value) {
OP_REQUIRES_ASYNC(
context, !OpDeterminismRequired(), errors::Unimplemented(
"Deterministic GPU implementation of CropAndResizeBackpropBoxes"
" not available."),
done);
}

// Allocate output tensor.
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(
Expand All @@ -628,7 +650,7 @@ class CropAndResizeGradBoxesOp : public AsyncOpKernel {
box_index.tensor<int32, 1>(), output->tensor<float, 2>());
if (!status) {
context->SetStatus(errors::Internal(
"Failed launch CropAndResizeBackpropBoxes kernel."));
"Failed to launch CropAndResizeBackpropBoxes kernel."));
}
};

Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ filegroup(
"bcast.h",
"command_line_flags.h",
"debug_events_writer.h",
"determinism.h",
"device_name_utils.h",
"dump_graph.h",
"einsum_op_util.h",
Expand Down Expand Up @@ -214,6 +215,7 @@ filegroup(
"bcast.cc",
"command_line_flags.cc",
"debug_events_writer.cc",
"determinism.cc",
"device_name_utils.cc",
"dump_graph.cc",
"equal_graph_def.cc",
Expand Down
32 changes: 32 additions & 0 deletions tensorflow/core/util/determinism.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

bool OpDeterminismRequired() {
  // Read TF_DETERMINISTIC_OPS exactly once per process; the thread-safe
  // initialization of the function-local static caches the result for all
  // subsequent calls.
  static const bool determinism_enabled = []() -> bool {
    bool value = false;
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
                                               /*default_val=*/false, &value));
    return value;
  }();
  return determinism_enabled;
}

} // namespace tensorflow
27 changes: 27 additions & 0 deletions tensorflow/core/util/determinism.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_DETERMINISM_H_
#define TENSORFLOW_CORE_UTIL_DETERMINISM_H_

namespace tensorflow {

// Returns true if deterministic op behavior has been requested for this
// process. The result is derived from the TF_DETERMINISTIC_OPS environment
// variable (see determinism.cc), read once on first call and cached
// thereafter.
//
// TODO(duncanriach): use this to replace existing instances of
// RequireDeterminism in various core/kernels.
bool OpDeterminismRequired();

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_DETERMINISM_H_
1 change: 1 addition & 0 deletions tensorflow/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2925,6 +2925,7 @@ py_library(
srcs_version = "PY3",
deps = [
":client_testlib",
":errors",
":gradients",
":image_ops",
"//tensorflow/python/framework:for_generated_wrappers",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,13 @@ def testExceptionThrowing(self):


class SparseSoftmaxCrossEntropyWithLogitsDeterministicTest(test.TestCase):
"""Test that SparseSoftmaxCrossEntropyWithLogits operates reproducibly."""
"""Test that SparseSoftmaxCrossEntropyWithLogits operates reproducibly.

Note that the deterministic functionality currently tested by this class is
always activated (not enabled by TF_DETERMINISTIC_OPS), so this class does not
currently need to inherit from a base op test class (to ensure that the op
still functions correctly when determinism is enabled).
"""

def _randomInts(self, shape, high, dtype):
return constant_op.constant(
Expand Down
8 changes: 7 additions & 1 deletion tensorflow/python/kernel_tests/xent_op_deterministic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@ def testExceptionThrowing(self):


class SoftmaxCrossEntropyWithLogitsDeterministicTest(test.TestCase):
"""Test that SoftmaxCrossEntropyWithLogits operates reproducibly."""
"""Test that SoftmaxCrossEntropyWithLogits operates reproducibly.

Note that the deterministic functionality currently tested by this class is
always activated (not enabled by TF_DETERMINISTIC_OPS), so this class does not
currently need to inherit from a base op test class (to ensure that the op
still functions correctly when determinism is enabled).
"""

def _randomFloats(self, shape, dtype, normalized_rows=False):
a = (2 * np.random.random_sample(shape) - 1).astype(dtype)
Expand Down
Loading