[go: nahoru, domu]

Skip to content

Commit

Permalink
Fix all-reduce bytes_transmitted for ALL_GATHER collectives
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 630413612
  • Loading branch information
cliveverghese authored and tensorflower-gardener committed May 3, 2024
1 parent 9b67088 commit d1697a8
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,9 @@ int DcnTracker::GetReplicaGroupSize(const std::string& rendezvous_name,
return rendezvous_to_replica_group_size_map_[rendezvous_name];
}

// ComputeTransmittedDataSize is called with the buffer_size for recv-done.
uint64_t DcnTracker::ComputeTransmittedDataSize(
const int64_t buffer_size, const int group_size,
const int64_t recv_buffer_size, const int group_size,
const std::string& transfer_type) {
uint64_t transmitted_bytes = 0;
if (group_size == 0) {
Expand All @@ -251,17 +252,20 @@ uint64_t DcnTracker::ComputeTransmittedDataSize(
}

if (transfer_type == "ONE_TO_ONE") {
transmitted_bytes = group_size * buffer_size;
transmitted_bytes = group_size * recv_buffer_size;
} else if (transfer_type == "ALL_GATHER") {
transmitted_bytes = (group_size - 1) * buffer_size;
transmitted_bytes =
SafeDivide((group_size - 1) * recv_buffer_size, group_size);
} else if (transfer_type == "ALL_REDUCE") {
// Since the reduced buffer now has to be sent back to the replicas,
// the total bytes transmitted over the network is 2x the shape of the op.
transmitted_bytes =
2 * SafeDivide(group_size - 1, group_size) * buffer_size;
} else if (transfer_type == "ALL_TO_ALL" ||
transfer_type == "REDUCE_SCATTER") {
transmitted_bytes = SafeDivide(group_size - 1, group_size) * buffer_size;
2 * SafeDivide(group_size - 1, group_size) * recv_buffer_size;
} else if (transfer_type == "ALL_TO_ALL") {
transmitted_bytes =
SafeDivide(group_size - 1, group_size) * recv_buffer_size;
} else if (transfer_type == "REDUCE_SCATTER") {
transmitted_bytes = recv_buffer_size * (group_size - 1);
} else {
LOG(ERROR) << "Unsupported transfer type: " << transfer_type;
}
Expand Down

0 comments on commit d1697a8

Please sign in to comment.