

Ensuring that the Switch op used as a pivot is always placed on the CPU.

For this we set a private attribute _PivotSwitch while creating this op and then make sure that the device overwriting logic in GraphPartition isn't executed for this op.

Note: Had to fix up control_flow_ops_py_test so that we don't expect a GPU graph when we don't get one. Since switch_pred is now known to be placed on the CPU, the placer places its input on the CPU as well, and we end up saving a copy. This means there is no GPU graph when we partition.
PiperOrigin-RevId: 338246477
Change-Id: I5641c9ae1b2d593a2996947bafe92b22cb63371d
rohan100jain authored and tensorflower-gardener committed Oct 21, 2020
1 parent be24f6d commit 673b993
Showing 5 changed files with 130 additions and 3 deletions.
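
In outline, the change has two halves, sketched below in condensed form (a paraphrase of the diffs that follow, not a drop-in patch): lower_if_op.cc tags the pivot Switch with the private _PivotSwitch attribute and pins it to the CPU, and graph_partition.cc skips its colocate-with-data-input rewrite for any Switch carrying that attribute.

// lower_if_op.cc (sketch): mark the pivot Switch and always place it on the CPU.
Node* switch_pred;
TF_RETURN_IF_ERROR(
    SetColocationAndFinalize(NodeBuilder(NewName("switch_pred"), "Switch",
                                         graph_->op_registry(), &debug_info_)
                                 .Input(NodeOut(pred_))
                                 .Input(NodeOut(pred_))
                                 .Attr("_PivotSwitch", true)  // private marker attr
                                 .Device("/CPU:0"),           // pinned to the CPU
                             graph_, &switch_pred));

// graph_partition.cc (sketch): OptimizeControlFlowColocation leaves a marked
// pivot Switch where it is instead of colocating it with its data input.
if (IsSwitch(node)) {
  if (HasNodeAttr(node->def(), "_PivotSwitch")) {
    return;  // already pinned to the CPU next to its predicate
  }
  // ... otherwise colocate the Switch with its data input as before ...
}
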
3 changes: 2 additions & 1 deletion tensorflow/core/common_runtime/BUILD
@@ -2522,10 +2522,11 @@ tf_cc_test(
],
)

tf_cc_test(
tf_cc_test_gpu(
name = "lower_if_op_test",
size = "small",
srcs = ["lower_if_op_test.cc"],
tags = tf_cuda_tests_tags(),
deps = [
":core_cpu",
":core_cpu_internal",
11 changes: 10 additions & 1 deletion tensorflow/core/common_runtime/lower_if_op.cc
@@ -148,13 +148,22 @@ Status CondBuilder::SetColocationAndFinalize(NodeBuilder node_builder,
Status CondBuilder::CreatePivotNodes() {
// Construct the basic cond body (consisting of feeding in the predicate to
// create pivot nodes).

// This is a special pivot switch node for lowering. We mark this with a
// special _PivotSwitch attr on it as later on in the graph partitioner we
// do some special placement for Switch nodes and it's necessary to distinguish
// between a "normal" Switch node and one of these pivot switches. We would
// like to place this node on the CPU always as the pred_ will be on the CPU
// as well (either a CPU op output or a GPU op with HostMemory annotation).
// TODO(b/171321391): Fix this for NUMA cases.
Node* switch_pred;
TF_RETURN_IF_ERROR(
SetColocationAndFinalize(NodeBuilder(NewName("switch_pred"), "Switch",
graph_->op_registry(), &debug_info_)
.Input(NodeOut(pred_))
.Input(NodeOut(pred_))
.Device(if_op_->requested_device()),
.Attr("_PivotSwitch", true)
.Device("/CPU:0"),
graph_, &switch_pred));
control_predecessor_ = switch_pred;
TF_RETURN_IF_ERROR(
109 changes: 109 additions & 0 deletions tensorflow/core/common_runtime/lower_if_op_test.cc
@@ -147,6 +147,115 @@ TEST(LowerIfOpTest, Simple) {
}
}

TEST(LowerIfOpTest, GPUPlacement) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

// Add test functions for then and else branch.
FunctionDefLibrary f_lib_proto;
*(f_lib_proto.add_function()) = test::function::XTimesTwo();
*(f_lib_proto.add_function()) = test::function::XTimesFour();

// Construct simple conditional that switches on `pred` and operates only on
// single input `A`.
Scope root = Scope::NewRootScope().ExitOnError();
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
auto x = ops::Placeholder(root.WithOpName("X"), DT_INT32);
auto y = ops::Placeholder(root.WithOpName("Y"), DT_INT32);
Node* pred;
TF_ASSERT_OK(NodeBuilder("greater", "Greater", &root.graph()->flib_def())
.Input(x.node())
.Input(y.node())
.Device("/GPU:0")
.Finalize(root.graph(), &pred));
Node* written_if;
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
TF_ASSERT_OK(
NodeBuilder("if", "If", &root.graph()->flib_def())
.Input(pred)
.Input(inputs)
.Attr("then_branch", FuncAttr("XTimesTwo"))
.Attr("else_branch", FuncAttr("XTimesFour"))
.Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
.Attr("Tout", {DT_INT32})
.Device("/GPU:0")
.Finalize(root.graph(), &written_if));
TF_ASSERT_OK(root.DoShapeInference(written_if));
TF_ASSERT_OK(root.ToGraph(graph.get()));

// The input graph has no switch or merge nodes.
int node_called_if_count = 0;
for (const auto* op : graph->op_nodes()) {
ASSERT_FALSE(op->IsSwitch());
ASSERT_FALSE(op->IsMerge());
if (op->name() == "if") {
++node_called_if_count;
}
}
ASSERT_EQ(node_called_if_count, 1);

TF_ASSERT_OK(Rewrite(&graph));

// Verify the resultant graph has switch and merge nodes, and a node called
// `if` (but not If nodes).
int switch_count = 0;
int merge_count = 0;
node_called_if_count = 0;
for (const auto* op : graph->op_nodes()) {
if (op->IsSwitch()) {
++switch_count;
}
if (op->IsMerge()) {
++merge_count;
}
ASSERT_NE(op->type_string(), "If");
if (op->name() == "if") {
++node_called_if_count;
}
}
// One switch for predicate and one for input (A).
ASSERT_EQ(switch_count, 2);
// One merge for the single output value of then and else, and one more merge
// to enforce then and else function call execution (`branch_executed` node).
ASSERT_EQ(merge_count, 2);
ASSERT_EQ(node_called_if_count, 1);

// Verify execution.
ClientSession session(root, SessionOptionsWithInlining());
{
RunMetadata metadata;
RunOptions options;
options.set_output_partition_graphs(true);
ClientSession::FeedType feeds;
feeds.emplace(Output(x.node()), Input::Initializer(5));
feeds.emplace(Output(y.node()), Input::Initializer(10));
feeds.emplace(Output(a.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(options, feeds, {Output(written_if)}, {},
&out_tensors, &metadata));
GraphDef cpu_graph = metadata.partition_graphs(1);
int num_cpu_switch = 0;
for (const auto& node : cpu_graph.node()) {
if (node.op() == "Switch") {
++num_cpu_switch;
}
}
EXPECT_EQ(num_cpu_switch, 2);
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 40);
}
{
ClientSession::FeedType feeds;
feeds.emplace(Output(x.node()), Input::Initializer(10));
feeds.emplace(Output(y.node()), Input::Initializer(5));
feeds.emplace(Output(a.node()), Input::Initializer(10));
std::vector<Tensor> out_tensors;
TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
EXPECT_EQ(out_tensors.size(), 1);
EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
}
}

TEST(LowerIfOpTest, BranchFunctionsWithoutOutputs) {
using ::tensorflow::test::function::GDef;
using ::tensorflow::test::function::NDef;
7 changes: 7 additions & 0 deletions tensorflow/core/graph/graph_partition.cc
@@ -371,6 +371,13 @@ NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
void OptimizeControlFlowColocation(Graph* graph) {
auto visit = [](Node* node) {
if (IsSwitch(node)) {
// Pivot Switch nodes (which are also of type Switch) are already placed
// on the CPU and colocated with their inputs, which are also on the CPU
// (or are on the GPU but produce their output in host memory).
if (HasNodeAttr(node->def(), "_PivotSwitch")) {
DCHECK(node->requested_device().find("CPU") != string::npos);
return;
}
for (const Edge* in_edge : node->in_edges()) {
if (in_edge->dst_input() == 0) {
// Colocate with the data input.
3 changes: 2 additions & 1 deletion tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -730,6 +730,8 @@ def _count_matching_switch_nodes_on_device(self, run_metadata, device_str,
g for g in run_metadata.partition_graphs
if device_str in g.node[0].device
]
if not device_graphs:
return 0
self.assertLen(device_graphs, 1)
switch_nodes = [
n for n in device_graphs[0].node
@@ -759,7 +761,6 @@ def true_fn():
options = config_pb2.RunOptions(output_partition_graphs=True)
sess.run(
r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
self.assertLen(run_metadata.partition_graphs, 2)
# Check that the Switch for `arg` gets placed on CPU.
self.assertEqual(
self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
