diff --git a/third_party/xla/xla/service/gpu/gemm_fusion.cc b/third_party/xla/xla/service/gpu/gemm_fusion.cc index 9d93c2b2a8aa63..3b3c90890735d8 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/gemm_fusion.cc @@ -798,13 +798,14 @@ bool ShouldTritonHandleGEMM(HloDotInstruction& dot, absl::StatusOr GemmFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_RETURN_IF_ERROR(EnsureTritonSupportsComputeCapability(gpu_version_)); + TF_RETURN_IF_ERROR( + EnsureTritonSupportsComputeCapability(compute_capability_)); bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { TF_ASSIGN_OR_RETURN(bool result, - RunOnComputation(computation, gpu_version_)); + RunOnComputation(computation, compute_capability_)); changed |= result; } return changed; diff --git a/third_party/xla/xla/service/gpu/gemm_fusion.h b/third_party/xla/xla/service/gpu/gemm_fusion.h index 1138ad28a36a5f..c858b43822f194 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion.h +++ b/third_party/xla/xla/service/gpu/gemm_fusion.h @@ -38,8 +38,8 @@ bool ShouldTritonHandleGEMM(HloDotInstruction&, // that target Triton-based matmul emitter. class GemmFusion : public HloModulePass { public: - explicit GemmFusion(const se::GpuComputeCapability& gpu_version) - : gpu_version_(gpu_version) {} + explicit GemmFusion(const se::GpuComputeCapability& compute_capability) + : compute_capability_(compute_capability) {} absl::string_view name() const override { return "triton-gemm-rewriter"; } using HloPassInterface::Run; @@ -48,7 +48,7 @@ class GemmFusion : public HloModulePass { const absl::flat_hash_set& execution_threads) override; private: - se::GpuComputeCapability gpu_version_; + se::GpuComputeCapability compute_capability_; }; } // namespace gpu