PR #13639: Fix UAF in Norm Rewriter
Imported from GitHub PR openxla/xla#13639

Resolves a use-after-free when matching and rewriting layer norm patterns. See #13606.
Copybara import of the project:

--
91ebf7b4a2ac90ebadce27d1a73e88fb4513aed4 by Philipp Hack <phack@nvidia.com>:

Resolves a use-after-free in the norm rewriter.

Merging this change closes #13639

PiperOrigin-RevId: 642548037
philipphack authored and tensorflower-gardener committed Jun 12, 2024
1 parent 321eb70 commit b94a1d2
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions third_party/xla/xla/service/gpu/cudnn_norm_rewriter.cc
@@ -1109,8 +1109,12 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
       for (HloInstruction* user : norm_factor->users()) {
         if (user->opcode() == HloOpcode::kDivide &&
             user->operand_index(norm_factor) == 0) {
-          TF_RETURN_IF_ERROR(MatchNormFactor(user, custom_call, variance,
-                                             expectation, epsilon));
+          TF_ASSIGN_OR_RETURN(bool changed,
+                              MatchNormFactor(user, custom_call, variance,
+                                              expectation, epsilon));
+          if (changed) {
+            break;
+          }
         }
       }
     }
@@ -1122,11 +1126,11 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
   // as the norm factor and its cube, (variance + epsilon)^-1/2 and (variance +
   // epsilon)^-3/2. When identified in the graph, these quantities are fused
   // into the layer norm Custom Call.
-  absl::Status MatchNormFactor(HloInstruction* instr,
-                               HloInstruction* custom_call,
-                               UniqueHloInstruction& variance,
-                               UniqueHloInstruction& expectation,
-                               UniqueHloInstruction& epsilon) {
+  absl::StatusOr<bool> MatchNormFactor(HloInstruction* instr,
+                                       HloInstruction* custom_call,
+                                       UniqueHloInstruction& variance,
+                                       UniqueHloInstruction& expectation,
+                                       UniqueHloInstruction& epsilon) {
     HloInstruction* gte = custom_call->users()[0];
     if (Match(instr,
               m::Divide(
@@ -1138,21 +1142,21 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor {
       // Verify the uniqueness of the operands.
       if (!variance.Instr() || !epsilon.Instr()) {
         VLOG(1) << "Layer norm operands not unique.";
-        return absl::OkStatus();
+        return false;
       }

       // Verify the element types.
       if (!CompatibleElementType(instr) ||
           !CompatibleElementType(expectation.Instr())) {
         VLOG(1) << "Layer norm input types not compatible.";
-        return absl::OkStatus();
+        return false;
       }

       // Retrieve metadata of the forward layer norm.
       auto norm_metadata = norm_metadata_.extract(custom_call);
       if (!norm_metadata) {
         VLOG(1) << "Unable to retrieve norm metadata of forward Custom Call.";
-        return absl::OkStatus();
+        return false;
       }

       // The shape of the expectation and norm factor return values of the
@@ -1241,7 +1245,7 @@
               << "Expectation and norm factor fused into layer norm Custom Call.";
     }

-    return absl::OkStatus();
+    return true;
   }

   // Matches and rewrites the backward graph of layer norm patterns into Custom
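
For context, here is a minimal, self-contained sketch of the hazard this commit addresses (hypothetical Node type and helper names, not the XLA API): a rewrite performed while iterating an instruction's user list can free entries of the very range being iterated. Returning a changed flag and breaking out of the loop, as the diff above does with MatchNormFactor, avoids touching the invalidated range.

#include <algorithm>
#include <memory>
#include <vector>

// Hypothetical stand-in for HloInstruction; not the XLA API.
struct Node {
  std::vector<Node*> users;
};

// Plays the role of MatchNormFactor: on a successful match it rewrites the
// graph, removing `user` from `owner.users` and freeing the instruction.
// The returned bool tells the caller that the user list it is walking has
// been mutated.
bool TryRewrite(Node& owner, std::unique_ptr<Node>& user_storage) {
  Node* user = user_storage.get();
  auto& u = owner.users;
  u.erase(std::remove(u.begin(), u.end(), user), u.end());
  user_storage.reset();  // `user` is now dangling
  return true;
}

int main() {
  Node norm_factor;
  auto divide = std::make_unique<Node>();
  norm_factor.users.push_back(divide.get());

  for (Node* user : norm_factor.users) {
    if (user == divide.get() && TryRewrite(norm_factor, divide)) {
      // Without this break the loop would keep advancing iterators into a
      // vector that TryRewrite just shrank, and could dereference the freed
      // Node -- the use-after-free this commit fixes.
      break;
    }
  }
  return 0;
}

The absl::StatusOr<bool> return in the actual change serves the same purpose: errors still propagate through the status, while the boolean tells the visitor whether the graph, and with it the users() range being iterated, was modified.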
