PR #47936: Add GPU support for AsString + StringToHashBucket
Imported from GitHub PR #47936

This PR adds GPU support for the AsString + StringToHashBucket op pattern. We propose a new fused TensorToHashBucket op and create it via the grappler remapper.
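As an illustration (not part of this PR), here is a minimal TF 2.x Python sketch of the op pattern the remapper targets; when the input is placed on a GPU and the remapper is enabled, the AsString + StringToHashBucketFast pair should be rewritten into the fused _TensorToHashBucketFast op:

import tensorflow as tf

@tf.function
def bucketize(ids):
  # AsString followed by StringToHashBucketFast: the pattern that the
  # grappler remapper rewrites into the fused _TensorToHashBucketFast op.
  as_str = tf.strings.as_string(ids)
  return tf.strings.to_hash_bucket_fast(as_str, num_buckets=100)

print(bucketize(tf.constant([[1, 2], [3, 4]], dtype=tf.int64)))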

@nluehr @benbarsdell
Copybara import of the project:

--
164fa47 by Kaixi Hou <kaixih@nvidia.com>:

Add TensorToHashBucket operation

--
7ce5a3f by Kaixi Hou <kaixih@nvidia.com>:

Cleanup codes

--
3f613de by Kaixi Hou <kaixih@nvidia.com>:

Minor changes

--
13a3142 by Kaixi Hou <kaixih@nvidia.com>:

Set the thread number

--
e7f45d3 by Kaixi Hou <kaixih@nvidia.com>:

Change functions and buffer size

--
e84a362 by Kaixi Hou <kaixih@nvidia.com>:

Update the max integer digits

--
ca82c20 by Kaixi Hou <kaixih@nvidia.com>:

Remove unintended tabs

--
fadf828 by Kaixi Hou <kaixih@nvidia.com>:

Remove unintended tabs

PiperOrigin-RevId: 375722953
Change-Id: Ib9fcc4b481304570a9735911660d6aac11d81287
jurahul authored and tensorflower-gardener committed May 25, 2021
1 parent 0b9fdf8 commit b6f0f31
Showing 12 changed files with 833 additions and 1 deletion.
6 changes: 6 additions & 0 deletions tensorflow/core/grappler/op_types.cc
@@ -105,6 +105,8 @@ bool IsAssign(const NodeDef& node) {

bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }

bool IsAsString(const NodeDef& node) { return node.op() == "AsString"; }

bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; }

bool IsBetainc(const NodeDef& node) { return node.op() == "Betainc"; }
@@ -576,6 +578,10 @@ bool IsStridedSliceGrad(const NodeDef& node) {
return node.op() == "StridedSliceGrad";
}

bool IsStringToHashBucketFast(const NodeDef& node) {
return node.op() == "StringToHashBucketFast";
}

bool IsSub(const NodeDef& node) { return node.op() == "Sub"; }

bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }
2 changes: 2 additions & 0 deletions tensorflow/core/grappler/op_types.h
@@ -41,6 +41,7 @@ bool IsArgMax(const NodeDef& node);
bool IsArgMin(const NodeDef& node);
bool IsAssert(const NodeDef& node);
bool IsAssign(const NodeDef& node);
bool IsAsString(const NodeDef& node);
bool IsAtan2(const NodeDef& node);
bool IsAvgPoolGrad(const NodeDef& node);
bool IsBetainc(const NodeDef& node);
@@ -187,6 +188,7 @@ bool IsStatefulPartitionedCall(const NodeDef& node);
bool IsStopGradient(const NodeDef& node);
bool IsStridedSlice(const NodeDef& node);
bool IsStridedSliceGrad(const NodeDef& node);
bool IsStringToHashBucketFast(const NodeDef& node);
bool IsSub(const NodeDef& node);
bool IsSum(const NodeDef& node);
bool IsSwitch(const NodeDef& node);
122 changes: 121 additions & 1 deletion tensorflow/core/grappler/optimizers/remapper.cc
@@ -69,10 +69,13 @@ constexpr char kFusedConv2D[] = "_FusedConv2D";
constexpr char kFusedMatMul[] = "_FusedMatMul";
constexpr char kFusedDepthwiseConv2dNative[] = "_FusedDepthwiseConv2dNative";
constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx";

constexpr char kTensorToHashBucket[] = "_TensorToHashBucketFast";
constexpr char kDataFormat[] = "data_format";
constexpr char kIsTraining[] = "is_training";

constexpr char kWidth[] = "width";
constexpr char kFill[] = "fill";

constexpr int kMissingIndex = -1;

struct RemapperContext {
@@ -111,6 +114,18 @@ struct FusedBatchNormEx {
int invalidated = kMissingIndex;
};

// A pattern of AsString + StringToHashBucketFast that can be replaced with a
// fused TensorToHashBucket op. We also include the fanin node of AsString
// ("pre_as_string") to determine the device.
struct TensorToHashBucket {
TensorToHashBucket() = default;
explicit TensorToHashBucket(int op1, int op2, int op3)
: pre_as_string(op1), as_string(op2), string_to_hash_bucket(op3) {}

int pre_as_string = kMissingIndex;
int as_string = kMissingIndex;
int string_to_hash_bucket = kMissingIndex;
};

// Contraction node followed by a BiasAdd.
struct ContractionWithBiasAdd {
ContractionWithBiasAdd() = default;
@@ -970,6 +985,65 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
return false;
}

bool FindTensorToHashBucket(const RemapperContext& ctx, int node_index,
TensorToHashBucket* matched) {
// Root of the pattern must be a StringToHashBucketFast.
const auto* node_view = ctx.graph_view.GetNode(node_index);
const auto* node_def = node_view->node();

if (!IsStringToHashBucketFast(*node_def) ||
HasControlFaninOrFanout(*node_view)) {
return false;
}

// Input to the StringToHashBucketFast must be AsString.
if (node_view->NumRegularFanins() < 1) return false;

const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
const auto* as_string_node_view = regular_fanin_0.node_view();
const auto* as_string_node_def = as_string_node_view->node();
bool is_as_string = IsAsString(*as_string_node_def);

if (!is_as_string || HasControlFaninOrFanout(*as_string_node_view) ||
!HasAtMostOneFanoutAtPort0(*as_string_node_view) ||
IsInPreserveSet(ctx, as_string_node_def))
return false;

// The DataType of AsString must be int8/16/32/64 and the width/fill attrs
// must have their default values.
if (!HasDataType(as_string_node_def, DT_INT8) &&
!HasDataType(as_string_node_def, DT_INT16) &&
!HasDataType(as_string_node_def, DT_INT32) &&
!HasDataType(as_string_node_def, DT_INT64)) {
return false;
}

int width;
if (!GetNodeAttr(*as_string_node_def, kWidth, &width).ok() || width != -1) {
return false;
}

string fill;
if (!GetNodeAttr(*as_string_node_def, kFill, &fill).ok() || !fill.empty()) {
return false;
}

// An input to the AsString must exist to determine the device.
if (as_string_node_view->NumRegularFanins() < 1) return false;

const auto& fanin_0 = as_string_node_view->GetRegularFanin(0);
const auto* pre_node_view = fanin_0.node_view();

// We successfully found an AsString + StringToHashBucketFast pattern.
const TensorToHashBucket pattern{pre_node_view->node_index(),
as_string_node_view->node_index(),
node_index};

*matched = pattern;

return true;
}

void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d,
const NodeDef* activation = nullptr) {
DCHECK(IsConv2D(conv2d)) << "Input node must be a Conv2D";
@@ -1644,6 +1718,44 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
return mutation->Apply();
}

Status AddTensorToHashBucketNode(RemapperContext* ctx,
const TensorToHashBucket& matched,
std::vector<bool>* invalidated_nodes,
std::vector<bool>* nodes_to_delete) {
const GraphDef* graph = ctx->graph_view.graph();
const NodeDef& pre_as_string = graph->node(matched.pre_as_string);
const NodeDef& as_string = graph->node(matched.as_string);
const NodeDef& string_to_hash_bucket =
graph->node(matched.string_to_hash_bucket);
VLOG(2) << "Fuse AsString with StringToHashBucketFast:"
<< " as_string=" << as_string.name()
<< " string_to_hash_bucket=" << string_to_hash_bucket.name()
<< " on device=" << pre_as_string.device();

NodeDef fused_op;
fused_op.set_name(string_to_hash_bucket.name());
fused_op.set_device(pre_as_string.device());
fused_op.add_input(as_string.input(0)); // 0: input
fused_op.set_op(kTensorToHashBucket);

auto* attr = fused_op.mutable_attr();
auto& src_attr0 = as_string.attr();
auto& src_attr1 = string_to_hash_bucket.attr();
(*attr)["T"] = src_attr0.at("T");
(*attr)["num_buckets"] = src_attr1.at("num_buckets");

utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
Status status;
mutation->AddNode(std::move(fused_op), &status);
TF_RETURN_IF_ERROR(status);
TF_RETURN_IF_ERROR(mutation->Apply());

(*invalidated_nodes)[matched.string_to_hash_bucket] = true;
(*nodes_to_delete)[matched.as_string] = true;

return Status::OK();
}

#ifdef INTEL_MKL
bool IsConv2DOrMatMul(const NodeDef& node) {
return IsConv2D(node) || IsMatMul(node);
@@ -1925,6 +2037,14 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
continue;
}

TensorToHashBucket tensor_to_hash_bucket;
if (allow_non_differentiable_rewrites &&
FindTensorToHashBucket(ctx, i, &tensor_to_hash_bucket)) {
TF_RETURN_IF_ERROR(AddTensorToHashBucketNode(
&ctx, tensor_to_hash_bucket, &invalidated_nodes, &nodes_to_delete));
continue;
}

// During inference, most of the inputs to FusedBatchNorm are constant, and
// we can therefore replace the op with a much cheaper set of primitives.
FusedBatchNorm fused_batch_norm;
68 changes: 68 additions & 0 deletions tensorflow/core/grappler/optimizers/remapper_test.cc
@@ -390,6 +390,74 @@ TEST_F(RemapperTest, FuseConv2DWithBias) {
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}

class RemapperTensorToHashBucketTest : public RemapperTest {
public:
template <DataType DTYPE>
void RunTest() {
using ::tensorflow::ops::Placeholder;

tensorflow::Scope s = tensorflow::Scope::NewRootScope();

auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3});
auto input = Placeholder(s.WithOpName("input"), DTYPE, input_shape);

int num_buckets = 100;
auto to_string = ops::AsString(s.WithOpName("to_string"), input);
auto to_bucket = ops::StringToHashBucketFast(s.WithOpName("to_bucket"),
to_string, num_buckets);
auto fetch = ops::Identity(s.WithOpName("fetch"), to_bucket);

auto input_t = GenerateRandomTensor<DTYPE>({8, 32, 32, 3});

GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_t}};
TF_ASSERT_OK(s.ToGraphDef(&item.graph));

// For CPU tests, we place all nodes on CPU. For GPU tests, we place the
// "input" node on GPU so that the fused op is placed on GPU as well.
const string input_device =
GetNumAvailableGPUs() > 0 ? "/device:GPU:0" : "/device:CPU:0";
for (int i = 0; i < item.graph.node_size(); ++i) {
if (item.graph.node(i).name() == "input") {
item.graph.mutable_node(i)->set_device(input_device);
} else {
item.graph.mutable_node(i)->set_device("/device:CPU:0");
}
}

Remapper optimizer(RewriterConfig::ON);
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "to_bucket") {
EXPECT_EQ(node.op(), "_TensorToHashBucketFast");
ASSERT_GE(node.input_size(), 1);
EXPECT_EQ(node.input(0), "input");
EXPECT_EQ(node.attr().at("num_buckets").i(), num_buckets);
found++;
}
}
EXPECT_EQ(found, 1);

auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
ASSERT_EQ(tensors_expected.size(), 1);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
ASSERT_EQ(tensors.size(), 1);
test::ExpectTensorEqual<int64>(tensors[0], tensors_expected[0]);
}
};

TEST_F(RemapperTensorToHashBucketTest, I8) { RunTest<DT_INT8>(); }

TEST_F(RemapperTensorToHashBucketTest, I16) { RunTest<DT_INT16>(); }

TEST_F(RemapperTensorToHashBucketTest, I32) { RunTest<DT_INT32>(); }

TEST_F(RemapperTensorToHashBucketTest, I64) { RunTest<DT_INT64>(); }

class RemapperFuseMatMulWithBiasTest : public RemapperTest {
public:
template <DataType DTYPE>
15 changes: 15 additions & 0 deletions tensorflow/core/kernels/BUILD
@@ -6,6 +6,7 @@ load(
"if_cuda_or_rocm",
"if_mobile",
"if_not_windows",
"if_oss",
"tf_cc_binary",
"tf_cc_shared_object",
"tf_cc_test",
@@ -20,6 +21,10 @@ load(
"if_mlir_generated_experimental_kernels_enabled",
"if_mlir_generated_gpu_kernels_enabled",
)
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_fingerprint_deps",
)

# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
@@ -5030,6 +5035,7 @@ cc_library(
":string_to_hash_bucket_op",
":string_upper_op",
":substr_op",
":tensor_to_hash_bucket_op",
":unicode_ops",
":unicode_script_op",
":unsorted_segment_join_op",
@@ -5070,6 +5076,15 @@ tf_kernel_library(
deps = STRING_DEPS,
)

tf_kernel_library(
name = "tensor_to_hash_bucket_op",
prefix = "tensor_to_hash_bucket_op",
deps = STRING_DEPS + if_oss(
if_cuda(["@farmhash_gpu_archive//:farmhash_gpu"]),
tf_fingerprint_deps(),
),
)

tf_kernel_library(
name = "reduce_join_op",
prefix = "reduce_join_op",
