[go: nahoru, domu]

Skip to content

Commit

Permalink
Preserve HloModuleConfig in HLO<->MHLO.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 642016412
  • Loading branch information
GleasonK authored and tensorflower-gardener committed Jun 10, 2024
1 parent 705c22d commit 6312122
Show file tree
Hide file tree
Showing 14 changed files with 300 additions and 59 deletions.
12 changes: 12 additions & 0 deletions third_party/xla/xla/translate/hlo_to_mhlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ cc_library(
":custom_call_importer",
":hlo_utils",
":location_importer",
":module_config_importer",
"//xla:comparison_util",
"//xla:literal",
"//xla:protobuf_util",
Expand All @@ -82,6 +83,7 @@ cc_library(
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/mlir_hlo",
"//xla/service:hlo_module_config",
"//xla/service:hlo_proto_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down Expand Up @@ -164,6 +166,16 @@ cc_library(
],
)

# Provides ImportHloModuleConfig (module_config_importer.h), which copies
# HloModuleConfig fields onto an mlir::ModuleOp as `mhlo.`-prefixed attributes
# during HLO -> MHLO import.
cc_library(
name = "module_config_importer",
srcs = ["module_config_importer.cc"],
hdrs = ["module_config_importer.h"],
deps = [
"//xla/service:hlo_module_config",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "translate",
srcs = ["translate.cc"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ limitations under the License.
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/shape_util.h"
#include "xla/translate/hlo_to_mhlo/hlo_function_importer.h"
#include "xla/translate/hlo_to_mhlo/module_config_importer.h"
#include "xla/xla.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
Expand Down Expand Up @@ -111,6 +112,7 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) {
"mhlo.is_dynamic",
mlir::BoolAttr::get(builder_.getContext(), hlo_module.is_dynamic()));
ImportFrontendAttributes(hlo_module, module);
ImportHloModuleConfig(hlo_module.config(), module);
module->setAttr("mhlo.use_auto_spmd_partitioning",
mlir::BoolAttr::get(builder_.getContext(),
hlo_module.use_auto_spmd_partitioning()));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/translate/hlo_to_mhlo/module_config_importer.h"

#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "xla/service/hlo_module_config.h"

namespace xla {

namespace {

constexpr char kConfigNumPartitions[] = "mhlo.num_partitions";
constexpr char kConfigNumReplicas[] = "mhlo.num_replicas";

} // namespace

void ImportHloModuleConfig(const HloModuleConfig& config,
mlir::ModuleOp module) {
mlir::Builder builder(module.getContext());
if (config.num_partitions() != 1) {
module->setAttr(kConfigNumPartitions,
builder.getI32IntegerAttr(config.num_partitions()));
}
if (config.replica_count() != 1) {
module->setAttr(kConfigNumReplicas,
builder.getI32IntegerAttr(config.replica_count()));
}
}

} // namespace xla
30 changes: 30 additions & 0 deletions third_party/xla/xla/translate/hlo_to_mhlo/module_config_importer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_TRANSLATE_HLO_TO_MHLO_MODULE_CONFIG_IMPORTER_H_
#define XLA_TRANSLATE_HLO_TO_MHLO_MODULE_CONFIG_IMPORTER_H_

#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "xla/service/hlo_module_config.h"

namespace xla {
// Imports the HLO module config into the MLIR module as module attributes
// prefixed with `mhlo.`.
// TODO (b/345755258) Support roundtrip of all HLO module config fields.
void ImportHloModuleConfig(const xla::HloModuleConfig& config,
mlir::ModuleOp module);
} // namespace xla

#endif // XLA_TRANSLATE_HLO_TO_MHLO_MODULE_CONFIG_IMPORTER_H_
1 change: 1 addition & 0 deletions third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ lit_test_suite(
"layouts_and_names.hlo",
"location.hlo",
"module_attributes.hlo",
"module_config.hlo",
"send_recv.hlo",
"simple.hlo",
"spmd_module_sharding.hlo",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s

// CHECK: module @check_imported_configs attributes {{.*}} mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 2 : i32
HloModule check_imported_configs, replica_count=2, num_partitions=4

ENTRY %main.2 (Arg_0.1: f32[1]) -> f32[1] {
ROOT %Arg_0.1 = f32[1] parameter(0)
}
18 changes: 18 additions & 0 deletions third_party/xla/xla/translate/mhlo_to_hlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ cc_library(
],
)

# Exporter counterpart of module_config_importer. It presumably provides
# mhlo::ExportHloModuleConfig, which mlir_hlo_to_hlo.cc calls to overlay MLIR
# module attributes onto an HloModuleConfig (TODO confirm against the header).
cc_library(
name = "module_config_exporter",
srcs = ["module_config_exporter.cc"],
hdrs = ["module_config_exporter.h"],
deps = [
"//xla/service:hlo_module_config",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "stack_frame_index_builder",
srcs = ["stack_frame_index_builder.cc"],
Expand All @@ -82,6 +92,7 @@ cc_library(
":attribute_exporter",
":layout_util",
":location_exporter",
":module_config_exporter",
":operator_writer_inc",
":stack_frame_index_builder",
":type_to_shape",
Expand All @@ -104,10 +115,12 @@ cc_library(
"//xla/mlir/utils:type_util",
"//xla/mlir_hlo",
"//xla/mlir_hlo:mhlo_passes",
"//xla/service:hlo_module_config",
"//xla/service:hlo_parser",
"//xla/service:hlo_proto_cc",
"//xla/service/gpu:backend_configs_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
Expand Down Expand Up @@ -163,7 +176,12 @@ cc_library(
deps = [
":mlir_hlo_to_hlo",
":type_to_shape",
"//xla/hlo/ir:hlo",
"//xla/service:hlo_module_config",
"//xla/service:hlo_proto_cc",
"//xla/service:hlo_proto_util",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
Expand Down
79 changes: 45 additions & 34 deletions third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,14 @@ limitations under the License.
#include "xla/primitive_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_parser.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/translate/mhlo_to_hlo/attribute_exporter.h"
#include "xla/translate/mhlo_to_hlo/location_exporter.h"
#include "xla/translate/mhlo_to_hlo/module_config_exporter.h"
#include "xla/translate/mhlo_to_hlo/stack_frame_index_builder.h"
#include "xla/translate/mhlo_to_hlo/type_to_shape.h"
#include "xla/types.h"
Expand Down Expand Up @@ -711,12 +713,9 @@ class ConvertToHloModule {
// single value.
explicit ConvertToHloModule(mlir::ModuleOp module,
xla::XlaBuilder& module_builder,
bool use_tuple_args, bool return_tuple,
MlirToHloConversionOptions options)
: module_(module),
module_builder_(module_builder),
use_tuple_args_(use_tuple_args),
return_tuple_(return_tuple),
options_(options) {}

// Perform the lowering to XLA. This function returns failure if an error was
Expand Down Expand Up @@ -820,15 +819,10 @@ class ConvertToHloModule {
// Map between function and lowered computation.
FunctionLoweringMap lowered_computation_;

// Whether the entry function should take a single tuple as input.
bool use_tuple_args_;

// Whether to always return a tuple.
bool return_tuple_;

// Unique suffix to give to the name of the next lowered region.
size_t region_id_ = 0;

// Conversion options
MlirToHloConversionOptions options_;
};

Expand Down Expand Up @@ -3178,7 +3172,8 @@ LogicalResult ConvertToHloModule::Lower(
unsigned num_return_values = inst->getNumOperands();
std::optional<xla::OpSharding> ret_tuple_sharding =
CreateTupleSharding(ret_shardings);
if ((return_tuple_ && is_entry_function) || num_return_values != 1) {
if ((options_.return_tuple && is_entry_function) ||
num_return_values != 1) {
std::vector<xla::XlaOp> returns(num_return_values);
for (OpOperand& ret : inst->getOpOperands()) {
unsigned index = ret.getOperandNumber();
Expand Down Expand Up @@ -3293,7 +3288,7 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) {
auto buffer_donor =
f.getArgAttrOfType<mlir::BoolAttr>(i, "jax.buffer_donor");
if (buffer_donor) {
if (use_tuple_args_) {
if (options_.use_tuple_args) {
builder.AddBufferDonor(/*param_number=*/0, /*param_index=*/{i});
} else {
builder.AddBufferDonor(/*param_number=*/i, /*param_index=*/{});
Expand All @@ -3303,7 +3298,7 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) {
f.getArgAttrOfType<mlir::IntegerAttr>(i, "tf.aliasing_output");
if (!aliasing_output) continue;
xla::ShapeIndex output_index;
if ((return_tuple_ && entry_function) || f.getNumResults() != 1) {
if ((options_.return_tuple && entry_function) || f.getNumResults() != 1) {
output_index = {aliasing_output.getInt()};
} else {
if (aliasing_output.getInt() != 0) {
Expand All @@ -3312,7 +3307,7 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) {
}
output_index = {};
}
if (use_tuple_args_) {
if (options_.use_tuple_args) {
builder.SetUpAlias(output_index, /*param_number=*/0,
/*param_index=*/{i});
} else {
Expand Down Expand Up @@ -3453,7 +3448,7 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(

// If using tuples as input, then there is only one input parameter that is a
// tuple.
if (is_entry_function && use_tuple_args_) {
if (is_entry_function && options_.use_tuple_args) {
llvm::SmallVector<xla::Shape, 4> arg_shapes;
std::vector<bool> leaf_replication;
if (failed(SetEntryTupleShapesAndLeafReplication(
Expand Down Expand Up @@ -3640,8 +3635,7 @@ absl::Status PrepareForExport(mlir::ModuleOp module) {
} // namespace

absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module,
xla::HloProto* hlo_proto, bool use_tuple_args,
bool return_tuple,
xla::HloProto* hlo_proto,
MlirToHloConversionOptions options) {
// To support the ongoing migration of XLA's compiler interface from MHLO
// to StableHLO, we've inserted this fallback to provide support for backends
Expand All @@ -3650,26 +3644,16 @@ absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module,
// supports not just MHLO, but also CHLO and StableHLO, but we will
// temporarily support StableHLO to MHLO lowering here as well to ensure
// a smooth migration.
// TODO(b/263811577): Remove this functionality once we have reasonable
// confidence that everyone has migrated from calling ConvertMlirHloToHlo
// directly.
bool hasStablehloOps = false;
module.walk([&](Operation* op) {
hasStablehloOps |= isa<stablehlo::StablehloDialect>(op->getDialect());
return hasStablehloOps ? WalkResult::interrupt() : WalkResult::advance();
});
if (hasStablehloOps) {
mlir::PassManager pm(module->getContext());
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
if (failed(pm.run(module)))
return tsl::errors::Internal("Unable to convert StableHLO to MHLO");
mlir::PassManager pm(module->getContext());
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
if (failed(pm.run(module))) {
return tsl::errors::Internal("Unable to convert StableHLO to MHLO");
}

TF_RETURN_IF_ERROR(PrepareForExport(module));
mlir::BaseScopedDiagnosticHandler diag_handler(module.getContext());
xla::XlaBuilder module_builder("main");
ConvertToHloModule converter(module, module_builder, use_tuple_args,
return_tuple, options);
ConvertToHloModule converter(module, module_builder, options);
if (failed(converter.Run())) return diag_handler.ConsumeStatus();
auto hlo_module = converter.ConsumeMainProto();
StringRef module_name = module.getName() ? *module.getName() : "main";
Expand Down Expand Up @@ -3715,15 +3699,33 @@ absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module,
return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<xla::HloModule>> ConvertMlirHloToHloModule(
    mlir::ModuleOp module, MlirToHloConversionOptions options) {
  // Step 1: lower the MLIR module into an HloProto.
  xla::HloProto proto;
  TF_RETURN_IF_ERROR(ConvertMlirHloToHlo(module, &proto, options));
  const xla::HloModuleProto& hlo_module_proto = proto.hlo_module();

  // Step 2: build a config from the lowered proto and the flag-derived debug
  // options.
  TF_ASSIGN_OR_RETURN(
      xla::HloModuleConfig module_config,
      xla::HloModule::CreateModuleConfigFromProto(
          hlo_module_proto, xla::GetDebugOptionsFromFlags()));

  // Step 3: overlay config values that were stored as MLIR module attributes.
  mhlo::ExportHloModuleConfig(module_config, module);

  // Step 4: materialize the HloModule from the proto and the final config.
  return xla::HloModule::CreateFromProto(hlo_module_proto, module_config);
}

absl::Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
llvm::ArrayRef<xla::XlaOp> xla_params,
std::vector<xla::XlaOp>& returns,
MlirToHloConversionOptions options) {
auto module = block.getParentOp()->getParentOfType<mlir::ModuleOp>();
TF_RETURN_IF_ERROR(PrepareForExport(module));
ConvertToHloModule converter(module, builder,
/*use_tuple_args=*/false, /*return_tuple=*/false,
options);
// No tuple support in Builder converter API.
options.return_tuple = false;
options.use_tuple_args = false;
ConvertToHloModule converter(module, builder, options);

ConvertToHloModule::ValueLoweringMap lowering;
// xla_params should only include non-constant parameters the block arguments
Expand Down Expand Up @@ -3760,4 +3762,13 @@ absl::Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
return absl::OkStatus();
}

absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module,
                                 ::xla::HloProto* hlo_proto,
                                 bool use_tuple_args, bool return_tuple,
                                 MlirToHloConversionOptions options) {
  // Overload that takes the tuple flags as separate arguments. The explicit
  // flags override whatever `options` carries, and the conversion is
  // delegated to the options-only overload.
  MlirToHloConversionOptions effective_options = options;
  effective_options.return_tuple = return_tuple;
  effective_options.use_tuple_args = use_tuple_args;
  return ConvertMlirHloToHlo(module, hlo_proto, effective_options);
}

} // namespace mlir
Loading

0 comments on commit 6312122

Please sign in to comment.