diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD
index c53835bf83a88d..4da1564a334ad1 100644
--- a/third_party/xla/xla/ffi/BUILD
+++ b/third_party/xla/xla/ffi/BUILD
@@ -102,6 +102,23 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "attribute_map",
+    srcs = ["attribute_map.cc"],
+    hdrs = ["attribute_map.h"],
+    deps = [
+        ":call_frame",
+        "//xla:xla_data_proto_cc",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+        "@local_tsl//tsl/platform:errors",
+    ],
+)
+
 xla_cc_test(
     name = "ffi_test",
     srcs = ["ffi_test.cc"],
diff --git a/third_party/xla/xla/ffi/attribute_map.cc b/third_party/xla/xla/ffi/attribute_map.cc
new file mode 100644
index 00000000000000..9443a8c010d542
--- /dev/null
+++ b/third_party/xla/xla/ffi/attribute_map.cc
@@ -0,0 +1,130 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/ffi/attribute_map.h"
+
+#include <cstdint>
+#include <string>
+#include <string_view>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "xla/ffi/call_frame.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+
+using FlatAttribute = xla::ffi::CallFrameBuilder::FlatAttribute;
+using FlatAttributesMap = xla::ffi::CallFrameBuilder::FlatAttributesMap;
+
+namespace xla::ffi {
+
+// Converts every entry of `dict` into a FlatAttribute keyed by the entry name.
+// Returns InvalidArgumentError on the first entry whose attribute kind, bit
+// width or array element type is not supported.
+absl::StatusOr<FlatAttributesMap> BuildAttributesMap(
+    mlir::DictionaryAttr dict) {
+  FlatAttributesMap attributes;
+  for (auto& kv : dict) {
+    std::string_view name = kv.getName().strref();
+
+    auto boolean = [&](mlir::BoolAttr boolean) {
+      attributes[name] = static_cast<bool>(boolean.getValue());
+      return absl::OkStatus();
+    };
+
+    // Integer attributes are forwarded as the matching fixed-width C++ type.
+    // Unsigned MLIR integer types (e.g. `ui32`) map to unsigned C++ types and
+    // all other (signless and signed) integer types map to signed C++ types.
+    // The value is read from the underlying APInt because
+    // IntegerAttr::getInt()/getUInt() assert on the signedness of the type.
+    auto integer = [&](mlir::IntegerAttr integer) {
+      mlir::Type type = integer.getType();
+
+      // `index` has no fixed bit width; getIntOrFloatBitWidth() would assert.
+      if (type.isIndex()) {
+        return absl::InvalidArgumentError(absl::StrCat(
+            "Unsupported index type for integer attribute: ", name));
+      }
+
+      const llvm::APInt value = integer.getValue();
+      const unsigned bit_width = type.getIntOrFloatBitWidth();
+
+      // A 1-bit integer is a boolean regardless of signedness (signless `i1`
+      // values are normally matched by the BoolAttr case instead).
+      if (bit_width == 1) {
+        attributes[name] = value.getBoolValue();
+        return absl::OkStatus();
+      }
+
+      if (type.isUnsignedInteger()) {
+        switch (bit_width) {
+          case 8:
+            attributes[name] = static_cast<uint8_t>(value.getZExtValue());
+            return absl::OkStatus();
+          case 16:
+            attributes[name] = static_cast<uint16_t>(value.getZExtValue());
+            return absl::OkStatus();
+          case 32:
+            attributes[name] = static_cast<uint32_t>(value.getZExtValue());
+            return absl::OkStatus();
+          case 64:
+            attributes[name] = static_cast<uint64_t>(value.getZExtValue());
+            return absl::OkStatus();
+          default:
+            break;
+        }
+      } else {
+        switch (bit_width) {
+          case 8:
+            attributes[name] = static_cast<int8_t>(value.getSExtValue());
+            return absl::OkStatus();
+          case 16:
+            attributes[name] = static_cast<int16_t>(value.getSExtValue());
+            return absl::OkStatus();
+          case 32:
+            attributes[name] = static_cast<int32_t>(value.getSExtValue());
+            return absl::OkStatus();
+          case 64:
+            attributes[name] = static_cast<int64_t>(value.getSExtValue());
+            return absl::OkStatus();
+          default:
+            break;
+        }
+      }
+      return absl::InvalidArgumentError(absl::StrCat(
+          "Unsupported integer attribute bit width for attribute: ", name));
+    };
+
+    auto fp = [&](mlir::FloatAttr fp) {
+      switch (fp.getType().getIntOrFloatBitWidth()) {
+        case 32:
+          attributes[name] = static_cast<float>(fp.getValue().convertToFloat());
+          return absl::OkStatus();
+        case 64:
+          attributes[name] =
+              static_cast<double>(fp.getValue().convertToDouble());
+          return absl::OkStatus();
+        default:
+          return absl::InvalidArgumentError(absl::StrCat(
+              "Unsupported float attribute bit width for attribute: ", name));
+      }
+    };
+
+    auto arr = [&](mlir::DenseArrayAttr arr) {
+      if (auto dense = mlir::dyn_cast<mlir::DenseI8ArrayAttr>(arr)) {
+        attributes[name] = dense.asArrayRef().vec();
+        return absl::OkStatus();
+      } else if (auto dense = mlir::dyn_cast<mlir::DenseI16ArrayAttr>(arr)) {
+        attributes[name] = dense.asArrayRef().vec();
+        return absl::OkStatus();
+      } else if (auto dense = mlir::dyn_cast<mlir::DenseI32ArrayAttr>(arr)) {
+        attributes[name] = dense.asArrayRef().vec();
+        return absl::OkStatus();
+      } else if (auto dense = mlir::dyn_cast<mlir::DenseI64ArrayAttr>(arr)) {
+        attributes[name] = dense.asArrayRef().vec();
+        return absl::OkStatus();
+      } else if (auto dense = mlir::dyn_cast<mlir::DenseF32ArrayAttr>(arr)) {
+        attributes[name] = dense.asArrayRef().vec();
+        return absl::OkStatus();
+      } else if (auto dense = mlir::dyn_cast<mlir::DenseF64ArrayAttr>(arr)) {
+        attributes[name] = dense.asArrayRef().vec();
+        return absl::OkStatus();
+      } else {
+        return absl::InvalidArgumentError(absl::StrCat(
+            "Unsupported array element type for attribute: ", name));
+      }
+    };
+
+    auto str = [&](mlir::StringAttr str) {
+      attributes[name] = str.getValue().str();
+      return absl::OkStatus();
+    };
+
+    // BoolAttr must be tried before IntegerAttr: a BoolAttr is an i1
+    // IntegerAttr and would otherwise be handled by the integer case.
+    TF_RETURN_IF_ERROR(
+        llvm::TypeSwitch<mlir::Attribute, absl::Status>(kv.getValue())
+            .Case<mlir::BoolAttr>(boolean)
+            .Case<mlir::IntegerAttr>(integer)
+            .Case<mlir::FloatAttr>(fp)
+            .Case<mlir::DenseArrayAttr>(arr)
+            .Case<mlir::StringAttr>(str)
+            .Default([&](mlir::Attribute) {
+              return absl::InvalidArgumentError(absl::StrCat(
+                  "Unsupported attribute type for attribute: ", name));
+            }));
+  }
+  return attributes;
+}
+
+}  // namespace xla::ffi
diff --git a/third_party/xla/xla/ffi/attribute_map.h b/third_party/xla/xla/ffi/attribute_map.h
new file mode 100644
index 00000000000000..d6c37b31c5522b
--- /dev/null
+++ b/third_party/xla/xla/ffi/attribute_map.h
@@ -0,0 +1,32 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_FFI_ATTRIBUTE_MAP_H_
+#define XLA_FFI_ATTRIBUTE_MAP_H_
+
+#include "absl/status/statusor.h"
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "xla/ffi/call_frame.h"
+
+namespace xla::ffi {
+
+// Converts the MLIR dictionary attribute attached to a custom call operation
+// into the flat attributes map that is forwarded to the FFI handler.
+//
+// Each dictionary entry is converted according to its attribute kind (bool,
+// integer, float, dense array or string) and stored under the entry's name.
+// Returns an InvalidArgumentError if an entry has an unsupported attribute
+// kind (e.g. a nested dictionary), bit width or array element type.
+absl::StatusOr<CallFrameBuilder::FlatAttributesMap> BuildAttributesMap(
+    mlir::DictionaryAttr dict);
+
+}  // namespace xla::ffi
+
+#endif  // XLA_FFI_ATTRIBUTE_MAP_H_
diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD
index 125af6dbce9170..1d0defbdcced10 100644
--- a/third_party/xla/xla/service/cpu/BUILD
+++ b/third_party/xla/xla/service/cpu/BUILD
@@ -1184,6 +1184,7 @@ cc_library(
     deps = [
         "//xla:shape_util",
         "//xla:xla_data_proto_cc",
+        "//xla/ffi:attribute_map",
         "//xla/ffi:call_frame",
         "//xla/ffi:ffi_api",
         "//xla/service:custom_call_status_public_headers",
diff --git a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc
index 7a7963593f592a..7d3271db600634 100644
--- a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc
+++ b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/ffi/attribute_map.h" #include "xla/ffi/call_frame.h" #include "xla/ffi/ffi_api.h" #include "xla/primitive_util.h" @@ -46,95 +47,6 @@ namespace ffi = xla::ffi; namespace { -using Attribute = ffi::CallFrameBuilder::FlatAttribute; -using AttributesMap = ffi::CallFrameBuilder::FlatAttributesMap; - -// TODO(heinsaar): This BuildAttributesMap() is originally an identical -// copy-paste of the same function in custom_call_thunk.cc -// May make sense to have one in a common place & reuse. -absl::StatusOr BuildAttributesMap(mlir::DictionaryAttr dict) { - AttributesMap attributes; - for (auto& kv : dict) { - std::string_view name = kv.getName().strref(); - - auto boolean = [&](mlir::BoolAttr boolean) { - attributes[name] = static_cast(boolean.getValue()); - return absl::OkStatus(); - }; - - auto integer = [&](mlir::IntegerAttr integer) { - const bool is_unsigned = integer.getType().isUnsignedInteger(); - if (is_unsigned) { - switch (integer.getType().getIntOrFloatBitWidth()) { - case 8: - attributes[name] = static_cast(integer.getUInt()); - return absl::OkStatus(); - case 16: - attributes[name] = static_cast(integer.getUInt()); - return absl::OkStatus(); - case 32: - attributes[name] = static_cast(integer.getUInt()); - return absl::OkStatus(); - case 64: - attributes[name] = static_cast(integer.getUInt()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported integer attribute bit width for attribute: ", - name)); - } - } else { - switch (integer.getType().getIntOrFloatBitWidth()) { - case 8: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - case 16: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - case 32: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); 
- case 64: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported integer attribute bit width for attribute: ", - name)); - } - } - }; - - auto fp = [&](mlir::FloatAttr fp) { - switch (fp.getType().getIntOrFloatBitWidth()) { - case 32: - attributes[name] = static_cast(fp.getValue().convertToFloat()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported float attribute bit width for attribute: ", name)); - } - }; - - auto str = [&](mlir::StringAttr str) { - attributes[name] = str.getValue().str(); - return absl::OkStatus(); - }; - - TF_RETURN_IF_ERROR( - llvm::TypeSwitch(kv.getValue()) - .Case(boolean) - .Case(integer) - .Case(fp) - .Case(str) - .Default([&](mlir::Attribute) { - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported attribute type for attribute: ", name)); - })); - } - return attributes; -} - absl::Span DecodeDims(int64_t* encoded_dims_data) { // Annotate memory coming from jit compiled function as initialized to // suppress false positives from msan sanitizer. @@ -230,7 +142,7 @@ inline absl::Status BuildAndCallFfi( // and build an MLIR compatible map of attributes out of it. mlir::Attribute attr = mlir::parseAttribute(backend_config, &mlir_context); if (auto dict = attr.dyn_cast_or_null()) { - TF_ASSIGN_OR_RETURN(attributes, BuildAttributesMap(dict)); + TF_ASSIGN_OR_RETURN(attributes, xla::ffi::BuildAttributesMap(dict)); } else { return absl::InternalError( "Unsupported backend config. 
Expected a string parsable into " diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 5faf098331bb14..e665640e8ce5eb 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -347,6 +347,7 @@ cc_library( "//xla:statusor", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/ffi:attribute_map", "//xla/ffi:ffi_api", "//xla/ffi/api:c_api", "//xla/hlo/ir:hlo", diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 02752d4bfffe04..2e0a1202157486 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -126,6 +126,7 @@ cc_library( "//xla:status", "//xla:statusor", "//xla:util", + "//xla/ffi:attribute_map", "//xla/ffi:ffi_api", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index 5d5cc5d3898e08..e6dc11a46cb0be 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/ffi/attribute_map.h" #include "xla/ffi/ffi_api.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -563,7 +564,7 @@ absl::StatusOr EmitCustomCall( mlir::Attribute attr = mlir::parseAttribute( backend_config_str, ir_emitter_context.mlir_context()); if (auto dict = mlir::dyn_cast_or_null(attr)) { - TF_ASSIGN_OR_RETURN(attributes, BuildAttributesMap(dict)); + TF_ASSIGN_OR_RETURN(attributes, xla::ffi::BuildAttributesMap(dict)); break; } return absl::InternalError( diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 308a7c479822f5..85e39381bac1d9 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -72,6 +72,7 @@ limitations under the License. 
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Export.h" // from @llvm-project #include "xla/ffi/api/c_api.h" +#include "xla/ffi/attribute_map.h" #include "xla/ffi/ffi_api.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -1405,7 +1406,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk( mlir::Attribute attr = mlir::parseAttribute( backend_config_str, ir_emitter_context_->mlir_context()); if (auto dict = mlir::dyn_cast_or_null(attr)) { - TF_ASSIGN_OR_RETURN(attributes, BuildAttributesMap(dict)); + TF_ASSIGN_OR_RETURN(attributes, xla::ffi::BuildAttributesMap(dict)); break; } return absl::InternalError( diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc index 9c25e984252519..b110384ce35563 100644 --- a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc @@ -167,99 +167,5 @@ absl::Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) { return handler_ ? 
ExecuteFfiHandler(params) : ExecuteCustomCall(params); } -absl::StatusOr BuildAttributesMap( - mlir::DictionaryAttr dict) { - CustomCallThunk::AttributesMap attributes; - for (auto& kv : dict) { - std::string_view name = kv.getName().strref(); - - auto boolean = [&](mlir::BoolAttr boolean) { - attributes[name] = static_cast(boolean.getValue()); - return absl::OkStatus(); - }; - - auto integer = [&](mlir::IntegerAttr integer) { - switch (integer.getType().getIntOrFloatBitWidth()) { - case 1: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - case 8: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - case 16: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - case 32: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - case 64: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported integer attribute bit width for attribute: ", name)); - } - }; - - auto fp = [&](mlir::FloatAttr fp) { - switch (fp.getType().getIntOrFloatBitWidth()) { - case 32: - attributes[name] = static_cast(fp.getValue().convertToFloat()); - return absl::OkStatus(); - case 64: - attributes[name] = - static_cast(fp.getValue().convertToDouble()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported float attribute bit width for attribute: ", name)); - } - }; - - auto arr = [&](mlir::DenseArrayAttr arr) { - if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = 
dense.asArrayRef().vec(); - return absl::OkStatus(); - } else if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else { - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported array element type for attribute: ", name)); - } - }; - - auto str = [&](mlir::StringAttr str) { - attributes[name] = str.getValue().str(); - return absl::OkStatus(); - }; - - TF_RETURN_IF_ERROR( - llvm::TypeSwitch(kv.getValue()) - .Case(boolean) - .Case(integer) - .Case(fp) - .Case(arr) - .Case(str) - .Default([&](mlir::Attribute) { - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported attribute type for attribute: ", name)); - })); - } - return attributes; -} - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h index 02679d2e0d21ff..2d797ecea01a6c 100644 --- a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h @@ -120,11 +120,6 @@ class CustomCallThunk : public Thunk { const HloComputation* called_computation_ = nullptr; }; -// Converts MLIR dictionary attribute attached to a custom call operation to a -// custom call thunk attributes that are forwarded to the FFI handler. -absl::StatusOr BuildAttributesMap( - mlir::DictionaryAttr dict); - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 2bffc0a1316e5f..58600c91d25852 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -1783,6 +1783,7 @@ xla_test( "//xla/client:xla_builder", "//xla/client/lib:constants", "//xla/ffi", + "//xla/ffi:attribute_map", "//xla/ffi:ffi_api", "//xla/hlo/ir:hlo", "//xla/service:custom_call_status",