
HloReachabilityMap does not update the BitVectors properly when there is a cycle in the HloComputation #56542

Open
iamohcy opened this issue Jun 23, 2022 · 2 comments
Labels: comp:xla, stat:awaiting tensorflower, TF 2.9, type:bug

iamohcy commented Jun 23, 2022

Issue Type

Bug

Source

source

Tensorflow Version

tf 2.9

Custom Code

Yes

OS Platform and Distribution

Linux Ubuntu 20.04

Mobile device

No response

Python version

3.8.10

Bazel version

No response

GCC/Compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current Behaviour?

In HloReachabilityMap::Build, the BitVectors are computed in a single pass over the instructions in post order (operands before their users). However, if the graph contains a cycle, e.g. A->B->C->A (where X->Y means X is an operand of Y), and A is processed first, the following occurs:

  1. A's BitVector is OR-ed with the BitVector of its operand C.
  2. At this point C's BitVector has not been initialized yet, so the OR contributes nothing and A never records that B and C can reach it.

In general, if there is a cycle, whichever instruction in the cycle is processed first never receives the reachability BitVectors of the cycle members processed after it.

Since HloReachabilityMap is used in "InstructionFusion::MultiOutputFusionCreatesCycle" to detect cycles, I believe cycle detection will not work properly there.
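The failure mode can be reproduced outside XLA with a minimal sketch, assuming a hypothetical toy model rather than the real HloReachabilityMap: `SinglePass`, `IsReachable` and `ReproGraph` are illustrative names, each node's BitVector is a single `std::uint64_t`, and nodes are visited in creation order (the actual post order XLA produces for a cyclic graph may differ). The graph mirrors the standalone repro in this issue.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical toy model of the bit-vector pass, NOT the XLA API.
// graph[i] lists the operands of node i; bit j of reach[i] is set when
// node j reaches node i (j == i, or j is a transitive operand of i).
using Graph = std::vector<std::vector<int>>;

// One pass in `order`: each node's vector becomes its own bit OR-ed with the
// *current* vectors of its operands, so operands visited later contribute 0.
void SinglePass(const Graph& graph, const std::vector<int>& order,
                std::vector<std::uint64_t>& reach) {
  assert(graph.size() <= 64 && "toy model supports at most 64 nodes");
  for (int i : order) {
    std::uint64_t bits = std::uint64_t{1} << i;
    for (int op : graph[i]) bits |= reach[op];
    reach[i] = bits;
  }
}

// Same argument order as HloReachabilityMap::IsReachable(a, b):
// true if `b` is reachable from `a`.
bool IsReachable(const std::vector<std::uint64_t>& reach, int a, int b) {
  return ((reach[b] >> a) & 1) != 0;
}

// The repro graph: 0=constant1, 1=add1, 2=add2, 3=add3, 4=add4, where add1's
// operands became {add3, constant1} after ReplaceOperandWith(0, add3).
// Duplicate operands (add2 = add1 + add1, ...) are listed once.
Graph ReproGraph() { return {{}, {3, 0}, {1}, {2}, {3}}; }
```

Running `SinglePass(ReproGraph(), {0, 1, 2, 3, 4}, reach)` on a zero-initialized `reach` of size 5 leaves `IsReachable(reach, 3, 1)` (add3 reaches add1) and `IsReachable(reach, 3, 2)` false, because add1 ORs in add3's vector before add3 has been visited. `IsReachable(reach, 1, 3)` and `IsReachable(reach, 2, 3)` are true.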

A trivial fix would be to perform the BitVector update twice in HloReachabilityMap::Build, i.e. to run the following loop twice:

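  // Loop from HloReachabilityMap::Build. `all`, `inputs`, `add_dependencies`,
  // `add_input`, `channel_group` and `result` are defined earlier in Build
  // (not shown).
  for (const HloInstruction* hlo : all) {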
    inputs.clear();
    add_dependencies(hlo);
    switch (hlo->opcode()) {
      case HloOpcode::kRecvDone: {
        auto it = channel_group.find(*hlo->channel_id());
        if (it != channel_group.end()) {
          for (HloInstruction* channel : it->second) {
            if (channel->opcode() == HloOpcode::kSend) {
              add_input(channel);
            }
          }
        }
        break;
      }
      case HloOpcode::kAllReduce:
      case HloOpcode::kReduceScatter: {
        auto channel_id = hlo->channel_id();
        if (channel_id) {
          auto it = channel_group.find(channel_id.value());
          if (it != channel_group.end()) {
            for (HloInstruction* all_reduce : it->second) {
              add_dependencies(all_reduce);
            }
          }
        }
        break;
      }
      default:
        break;
    }

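    // hlo's BitVector becomes the union of its inputs' BitVectors (plus its
    // own bit), so inputs that appear later in `all` contribute nothing yet.
    result->FastSetReachabilityToUnion(inputs, hlo);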
  }

By running the loop twice, we ensure that in the second pass, when A's BitVector is OR-ed with C's, C's BitVector already records that C is reachable from A, B and C (itself).
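The proposed fix can be checked in the same kind of hypothetical toy model (again not XLA code; `Pass`, `BuildTwoPasses`, `BuildToFixedPoint` and the visiting orders are illustrative assumptions). On the repro graph visited in creation order, two passes make all four expectations hold. The second graph is a hypothetical three-node cycle visited in a different order, where two passes still miss that node 1 reaches node 0; repeating the pass until no BitVector changes (an alternative not proposed in this issue) does not depend on the visiting order.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical toy model, NOT the XLA API: graph[i] lists node i's operands;
// bit j of reach[i] is set when node j reaches node i.
using Graph = std::vector<std::vector<int>>;

// One pass in `order`; returns true if any node's vector changed.
bool Pass(const Graph& graph, const std::vector<int>& order,
          std::vector<std::uint64_t>& reach) {
  assert(graph.size() <= 64 && "toy model supports at most 64 nodes");
  bool changed = false;
  for (int i : order) {
    std::uint64_t bits = std::uint64_t{1} << i;
    for (int op : graph[i]) bits |= reach[op];
    if (bits != reach[i]) changed = true;
    reach[i] = bits;
  }
  return changed;
}

// The fix proposed above: run the same pass exactly twice.
std::vector<std::uint64_t> BuildTwoPasses(const Graph& graph,
                                          const std::vector<int>& order) {
  std::vector<std::uint64_t> reach(graph.size(), 0);
  Pass(graph, order, reach);
  Pass(graph, order, reach);
  return reach;
}

// Alternative (not from the issue): repeat until a pass changes nothing.
// Vectors only gain bits between passes, so this terminates after at most
// graph.size() + 1 passes.
std::vector<std::uint64_t> BuildToFixedPoint(const Graph& graph,
                                             const std::vector<int>& order) {
  std::vector<std::uint64_t> reach(graph.size(), 0);
  while (Pass(graph, order, reach)) {
    // Keep iterating until a full pass leaves every vector unchanged.
  }
  return reach;
}

// Same argument order as HloReachabilityMap::IsReachable(a, b).
bool IsReachable(const std::vector<std::uint64_t>& reach, int a, int b) {
  return ((reach[b] >> a) & 1) != 0;
}
```

For example, with the cycle `{{2}, {0}, {1}}` (0->1->2->0) visited in order `{0, 2, 1}`, `BuildTwoPasses` leaves node 0's vector without node 1's bit, while `BuildToFixedPoint` sets every bit. Whether two passes suffice therefore depends on the visiting order; the fixed-point loop trades a data-dependent pass count for order independence.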

Standalone code to reproduce the issue

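// Uses the HloReachabilityTest fixture (an HloTestBase) from
// hlo_reachability_test.cc, which provides CreateNewVerifiedModule().
TEST_F(HloReachabilityTest, CycleDetection) {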
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  auto builder = HloComputation::Builder("CycleDetection");
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
  auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32, HloOpcode::kAdd, constant1, constant1));
  auto add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, add1));
  auto add3 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add2, add2));
  auto add4 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add3, add3));
  // Create the cycle add1 -> add2 -> add3 -> add1 by making add3 the first
  // operand of add1.
  ASSERT_TRUE(add1->ReplaceOperandWith(0, add3).ok());

  auto module = CreateNewVerifiedModule();
  auto computation =
      module->AddEntryComputation(builder.Build(/*root_instruction=*/add4));

  std::unique_ptr<HloReachabilityMap> reachability =
      HloReachabilityMap::Build(computation);

  EXPECT_TRUE(reachability->IsReachable(add3, add1));
  EXPECT_TRUE(reachability->IsReachable(add1, add3));
  EXPECT_TRUE(reachability->IsReachable(add3, add2));
  EXPECT_TRUE(reachability->IsReachable(add2, add3));
}

Relevant log output

No response

@google-ml-butler google-ml-butler bot added the type:bug Bug label Jun 23, 2022
@tilakrayal tilakrayal added comp:xla XLA TF 2.9 Issues found in the TF 2.9 release (or RCs) labels Jun 23, 2022
@tilakrayal
Contributor

Hi @iamohcy,
To expedite troubleshooting, could you please provide the complete code to reproduce the issue reported here? Thank you!

@tilakrayal tilakrayal added the stat:awaiting response Status - Awaiting response from author label Jun 23, 2022
@iamohcy
Author
iamohcy commented Jun 24, 2022

Hi @tilakrayal, to reproduce this issue, add the test from the original report above to "hlo_reachability_test.cc"; it should fail.

@google-ml-butler google-ml-butler bot removed the stat:awaiting response Status - Awaiting response from author label Jun 24, 2022
@sachinprasadhs sachinprasadhs added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jul 11, 2022