HloReachabilityMap does not update the BitVectors properly when there is a cycle in the HloComputation #56542
Labels
comp:xla, stat:awaiting tensorflower, TF 2.9, type:bug
Issue Type
Bug
Source
source
Tensorflow Version
tf 2.9
Custom Code
Yes
OS Platform and Distribution
Linux Ubuntu 20.04
Mobile device
No response
Python version
3.8.10
Bazel version
No response
GCC/Compiler version
No response
CUDA/cuDNN version
No response
GPU model and memory
No response
Current Behaviour?
In the HloReachabilityMap, the BitVectors are updated in topological order starting from the root node. However, if there is a cycle in the graph, e.g. A->B->C->A, with A being closer to the root, the following occurs:
Basically, if there is a cycle, whichever instruction in the cycle is processed first will not be updated with the actual reachability BitVectors of the instructions processed later.
Since the HloReachabilityMap is used in "InstructionFusion::MultiOutputFusionCreatesCycle" to detect cycles, I believe it will not work properly in this use case.
A trivial fix would be to perform the BitVector update twice in HloReachabilityMap::Build, i.e. repeat the following code twice:
By running it twice, we can ensure that, e.g., in the second loop, when we update A's BitVector with C's, C's BitVector already denotes that it is reachable from A, B and C (itself).
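The behaviour described above can be illustrated with a small standalone model. This is only a sketch and not XLA code: `BuildReachability`, `kMaxNodes`, and the `inputs` adjacency layout are hypothetical names chosen for the example, and `std::bitset` stands in for XLA's BitVector. It assumes each node's vector is recomputed as "itself plus the union of its inputs' vectors", visiting nodes in a fixed order.

```cpp
#include <bitset>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy model, not the XLA implementation.
constexpr std::size_t kMaxNodes = 8;
using BitVector = std::bitset<kMaxNodes>;

// inputs[i] lists the nodes that have an edge into node i.
// reach[i] ends up holding the nodes known to reach node i (including i).
std::vector<BitVector> BuildReachability(
    const std::vector<std::vector<int>>& inputs,
    const std::vector<int>& order, int passes) {
  assert(inputs.size() <= kMaxNodes);
  std::vector<BitVector> reach(inputs.size());
  for (int pass = 0; pass < passes; ++pass) {
    for (int node : order) {
      BitVector bits;
      bits.set(node);  // every node reaches itself
      for (int in : inputs[node]) {
        bits |= reach[in];  // may still be empty if `in` was not visited yet
      }
      reach[node] = bits;
    }
  }
  return reach;
}
```

With A=0, B=1, C=2 and the cycle A->B->C->A written as `inputs = {{2}, {0}, {1}}`, a single pass in the order A, B, C leaves A's vector containing only A, while C's vector contains A, B and C. A second pass in the same order fills in A's vector. In this toy model the number of passes needed depends on the visiting order: with the order C, B, A, two passes still leave C's vector without A, so iterating until no vector changes would be the more general variant of the same idea.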
Standalone code to reproduce the issue
Relevant log output
No response