From 673b993983f37f332ff70cdb642305f69089337d Mon Sep 17 00:00:00 2001
From: Rohan Jain
Date: Wed, 21 Oct 2020 05:44:08 -0700
Subject: [PATCH] Ensuring that the Switch op used as a pivot is always placed
 on the CPU.

To do this, we set a private _PivotSwitch attribute while creating the op and
then make sure that the device-overwriting logic in GraphPartition isn't
executed for it.

Note: control_flow_ops_py_test had to be fixed up so that it no longer expects
a GPU graph when none is produced. Since we now know up front that switch_pred
will be placed on the CPU, the placer ensures that its input is placed on the
CPU as well, which saves a copy. As a result, there is no GPU graph when we
partition.

PiperOrigin-RevId: 338246477
Change-Id: I5641c9ae1b2d593a2996947bafe92b22cb63371d
---
 tensorflow/core/common_runtime/BUILD          |   3 +-
 tensorflow/core/common_runtime/lower_if_op.cc |  11 +-
 .../core/common_runtime/lower_if_op_test.cc   | 109 ++++++++++++++++++
 tensorflow/core/graph/graph_partition.cc      |   7 ++
 .../kernel_tests/control_flow_ops_py_test.py  |   3 +-
 5 files changed, 130 insertions(+), 3 deletions(-)

diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index cf053b0af51c10..fcbf0c52905225 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -2522,10 +2522,11 @@ tf_cc_test(
     ],
 )
 
-tf_cc_test(
+tf_cc_test_gpu(
     name = "lower_if_op_test",
     size = "small",
     srcs = ["lower_if_op_test.cc"],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":core_cpu",
         ":core_cpu_internal",
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index ff010ad8a63298..2a0e5d35de5849 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -148,13 +148,22 @@ Status CondBuilder::SetColocationAndFinalize(NodeBuilder node_builder,
 Status CondBuilder::CreatePivotNodes() {
   // Construct the basic cond body (consisting of feeding in the predicate to
   // create pivot nodes).
+
+  // This is a special pivot switch node for lowering. We mark it with a
+  // _PivotSwitch attr because the graph partitioner later does special
+  // placement for Switch nodes, and it's necessary to distinguish between a
+  // "normal" Switch node and one of these pivot switches. We always want to
+  // place this node on the CPU, since pred_ will be on the CPU as well
+  // (either a CPU op output or a GPU op with a HostMemory annotation).
+  // TODO(b/171321391): Fix this for NUMA cases.
   Node* switch_pred;
   TF_RETURN_IF_ERROR(
       SetColocationAndFinalize(NodeBuilder(NewName("switch_pred"), "Switch",
                                            graph_->op_registry(), &debug_info_)
                                    .Input(NodeOut(pred_))
                                    .Input(NodeOut(pred_))
-                                   .Device(if_op_->requested_device()),
+                                   .Attr("_PivotSwitch", true)
+                                   .Device("/CPU:0"),
                                graph_, &switch_pred));
   control_predecessor_ = switch_pred;
   TF_RETURN_IF_ERROR(
diff --git a/tensorflow/core/common_runtime/lower_if_op_test.cc b/tensorflow/core/common_runtime/lower_if_op_test.cc
index cf7d35409bb078..b0304cfe29360a 100644
--- a/tensorflow/core/common_runtime/lower_if_op_test.cc
+++ b/tensorflow/core/common_runtime/lower_if_op_test.cc
@@ -147,6 +147,115 @@ TEST(LowerIfOpTest, Simple) {
   }
 }
 
+TEST(LowerIfOpTest, GPUPlacement) {
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+  // Add test functions for the then and else branches.
+  FunctionDefLibrary f_lib_proto;
+  *(f_lib_proto.add_function()) = test::function::XTimesTwo();
+  *(f_lib_proto.add_function()) = test::function::XTimesFour();
+
+  // Construct a simple conditional that switches on `pred` and operates only
+  // on the single input `A`.
+  Scope root = Scope::NewRootScope().ExitOnError();
+  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
+  auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
+  auto x = ops::Placeholder(root.WithOpName("X"), DT_INT32);
+  auto y = ops::Placeholder(root.WithOpName("Y"), DT_INT32);
+  Node* pred;
+  TF_ASSERT_OK(NodeBuilder("greater", "Greater", &root.graph()->flib_def())
+                   .Input(x.node())
+                   .Input(y.node())
+                   .Device("/GPU:0")
+                   .Finalize(root.graph(), &pred));
+  Node* written_if;
+  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
+  TF_ASSERT_OK(
+      NodeBuilder("if", "If", &root.graph()->flib_def())
+          .Input(pred)
+          .Input(inputs)
+          .Attr("then_branch", FuncAttr("XTimesTwo"))
+          .Attr("else_branch", FuncAttr("XTimesFour"))
+          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
+          .Attr("Tout", {DT_INT32})
+          .Device("/GPU:0")
+          .Finalize(root.graph(), &written_if));
+  TF_ASSERT_OK(root.DoShapeInference(written_if));
+  TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+  // The input graph has no switch or merge nodes.
+  int node_called_if_count = 0;
+  for (const auto* op : graph->op_nodes()) {
+    ASSERT_FALSE(op->IsSwitch());
+    ASSERT_FALSE(op->IsMerge());
+    if (op->name() == "if") {
+      ++node_called_if_count;
+    }
+  }
+  ASSERT_EQ(node_called_if_count, 1);
+
+  TF_ASSERT_OK(Rewrite(&graph));
+
+  // Verify that the resulting graph has switch and merge nodes and a node
+  // called `if` (but no If nodes).
+  int switch_count = 0;
+  int merge_count = 0;
+  node_called_if_count = 0;
+  for (const auto* op : graph->op_nodes()) {
+    if (op->IsSwitch()) {
+      ++switch_count;
+    }
+    if (op->IsMerge()) {
+      ++merge_count;
+    }
+    ASSERT_NE(op->type_string(), "If");
+    if (op->name() == "if") {
+      ++node_called_if_count;
+    }
+  }
+  // One switch for the predicate and one for the input (A).
+  ASSERT_EQ(switch_count, 2);
+  // One merge for the single output value of then and else, and one more
+  // merge to enforce then and else function call execution (the
+  // `branch_executed` node).
+  ASSERT_EQ(merge_count, 2);
+  ASSERT_EQ(node_called_if_count, 1);
+
+  // Verify execution.
+  ClientSession session(root, SessionOptionsWithInlining());
+  {
+    RunMetadata metadata;
+    RunOptions options;
+    options.set_output_partition_graphs(true);
+    ClientSession::FeedType feeds;
+    feeds.emplace(Output(x.node()), Input::Initializer(5));
+    feeds.emplace(Output(y.node()), Input::Initializer(10));
+    feeds.emplace(Output(a.node()), Input::Initializer(10));
+    std::vector<Tensor> out_tensors;
+    TF_ASSERT_OK(session.Run(options, feeds, {Output(written_if)}, {},
+                             &out_tensors, &metadata));
+    GraphDef cpu_graph = metadata.partition_graphs(1);
+    int num_cpu_switch = 0;
+    for (const auto& node : cpu_graph.node()) {
+      if (node.op() == "Switch") {
+        ++num_cpu_switch;
+      }
+    }
+    EXPECT_EQ(num_cpu_switch, 2);
+    EXPECT_EQ(out_tensors.size(), 1);
+    EXPECT_EQ(out_tensors[0].scalar<int>()(), 40);
+  }
+  {
+    ClientSession::FeedType feeds;
+    feeds.emplace(Output(x.node()), Input::Initializer(10));
+    feeds.emplace(Output(y.node()), Input::Initializer(5));
+    feeds.emplace(Output(a.node()), Input::Initializer(10));
+    std::vector<Tensor> out_tensors;
+    TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
+    EXPECT_EQ(out_tensors.size(), 1);
+    EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
+  }
+}
+
 TEST(LowerIfOpTest, BranchFunctionsWithoutOutputs) {
   using ::tensorflow::test::function::GDef;
   using ::tensorflow::test::function::NDef;
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index bf57e263441bd0..7680bcacba5a8e 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -371,6 +371,13 @@ NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
 void OptimizeControlFlowColocation(Graph* graph) {
   auto visit = [](Node* node) {
     if (IsSwitch(node)) {
+      // Pivot Switch nodes (which are also of type Switch) are already placed
+      // on the CPU and colocated with their inputs, which are also already on
+      // the CPU (or may be placed on the GPU but in host memory).
+      if (HasNodeAttr(node->def(), "_PivotSwitch")) {
+        DCHECK(node->requested_device().find("CPU") != string::npos);
+        return;
+      }
       for (const Edge* in_edge : node->in_edges()) {
         if (in_edge->dst_input() == 0) {
           // Colocate with the data input.
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 54bbd2b2e9ec95..532dac1d85a992 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -730,6 +730,8 @@ def _count_matching_switch_nodes_on_device(self, run_metadata, device_str,
         g for g in run_metadata.partition_graphs
         if device_str in g.node[0].device
     ]
+    if not device_graphs:
+      return 0
     self.assertLen(device_graphs, 1)
     switch_nodes = [
         n for n in device_graphs[0].node
@@ -759,7 +761,6 @@ def true_fn():
       options = config_pb2.RunOptions(output_partition_graphs=True)
       sess.run(
           r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
-      self.assertLen(run_metadata.partition_graphs, 2)
      # Check that the Switch for `arg` gets placed on CPU.
       self.assertEqual(
           self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
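
For reference, here is a minimal end-to-end sketch, not part of the patch, in
the spirit of the updated control_flow_ops_py_test. It assumes a CUDA-enabled
TensorFlow build with at least one visible GPU, and it enables v2 control flow
so that tf.cond builds an If op that the lowering pass rewrites into
Switch/Merge nodes. After this change, the pivot Switch should land in the CPU
partition; note that, per the commit note, there may be no GPU partition graph
at all, and the exact set of partition graphs can vary by build and placement.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.enable_control_flow_v2()  # cond builds an If op, which gets lowered.

x = tf.placeholder(tf.float32, shape=[])
with tf.device("/GPU:0"):
  # The pivot Switch created while lowering this cond should be pinned to
  # /CPU:0 even though the cond itself requests GPU placement.
  r = tf.cond(x < 0., lambda: -x, lambda: x)

with tf.Session() as sess:
  run_metadata = tf.RunMetadata()
  options = tf.RunOptions(output_partition_graphs=True)
  sess.run(r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
  # Count Switch nodes per partition graph; with this change, the Switch
  # nodes (pivot and data) are expected in the CPU partition.
  for g in run_metadata.partition_graphs:
    num_switch = sum(1 for n in g.node if n.op == "Switch")
    print(g.node[0].device, num_switch)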