[go: nahoru, domu]

Skip to content

Commit

Permalink
[XLA:GPU] Remove the emit_param_load_fn callback in EmitTiledScope.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 612955251
  • Loading branch information
dimitar-asenov authored and tensorflower-gardener committed May 28, 2024
1 parent 04a2e22 commit 7ac685e
Show file tree
Hide file tree
Showing 22 changed files with 352 additions and 436 deletions.
2 changes: 1 addition & 1 deletion tensorflow/compiler/mlir/lite/stablehlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ cc_library(
":stablehlo_util",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib",
"//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf",
"//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf_with_tf2xla",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir_hlo/mhlo/IR/register.h"
Expand Down
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ std::string TestDataPath() {
}

static constexpr char kCompilationStreamz[] =
"/tensorflow/core/tf_mlir_bridge_first_phase_count";
"/tensorflow/core/tf_mlir_bridge_first_phase_v2_count";

class LowerClusterToRuntimeOpsTest : public ::testing::Test {
public:
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ using ::mlir::OwningOpRef;
using ::tensorflow::monitoring::testing::CellReader;

static constexpr char kCompilationStreamz[] =
"/tensorflow/core/tf_mlir_bridge_first_phase_count";
"/tensorflow/core/tf_mlir_bridge_first_phase_v2_count";

std::string TestDataPath() {
return tensorflow::GetDataDependencyFilepath(
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ std::string TestDataPath() {
}

static constexpr char kCompilationStreamz[] =
"/tensorflow/core/tf_mlir_bridge_first_phase_count";
"/tensorflow/core/tf_mlir_bridge_first_phase_v2_count";

class FunctionClusterTensorflowDialectTest : public ::testing::Test {
public:
Expand Down
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/tf2xla/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ cc_library(
"legalize_tf_with_tf2xla.cc",
],
hdrs = [
"legalize_tf_with_tf2xla_passes.h",
"passes.h",
],
deps = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_LEGALIZE_TF_WITH_TF2XLA_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_LEGALIZE_TF_WITH_TF2XLA_PASSES_H_

#include <memory>
#include <optional>

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project

namespace mlir {

namespace func {
class FuncOp;
}  // namespace func
class ModuleOp;
class Operation;
template <typename T>
class OperationPass;
class Pass;
// NOTE(review): none of the forward declarations above are referenced by the
// declarations in this header. They are kept so that any includer relying on
// them is unaffected; consider removing them once that has been checked.

namespace mhlo {

/// Type converter to be used along with the fallback TF2XLA patterns added by
/// `PopulateLegalizeTfWithTf2XlaPatterns` (declared below).
class Tf2XlaTypeConverter : public TypeConverter {
 public:
  Tf2XlaTypeConverter();
};

/// Adds the TF to XLA via TF2XLA rewrite patterns to the `patterns` list.
///
/// Which ops are included depends on `prefer_tf2xla`:
///   * `prefer_tf2xla == true`: an op is included iff it is not in
///     `MlirLegalizedUnderPreferTf2XlaSet`.
///   * `prefer_tf2xla == false` (the default): an op is included if there is
///     no native MLIR legalization for the op.
///
/// `converter` is the `Tf2XlaTypeConverter` meant to be used together with the
/// added patterns, and `ctx` is the MLIR context for them. `device_type`
/// presumably selects the device whose TF2XLA kernels back the fallback
/// patterns — TODO(review): confirm against the implementation.
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
                                          RewritePatternSet& patterns,
                                          MLIRContext* ctx,
                                          Tf2XlaTypeConverter& converter,
                                          bool prefer_tf2xla = false);

}  // namespace mhlo
}  // namespace mlir

#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_LEGALIZE_TF_WITH_TF2XLA_PASSES_H_
17 changes: 0 additions & 17 deletions tensorflow/compiler/mlir/tf2xla/transforms/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,6 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeTFPass(
std::optional<StringRef> tf2xla_fallback_device_type = std::nullopt,
bool prefer_tf2xla = false);

/// Converter to be used along with the fallback Tf2Xla patterns below.
class Tf2XlaTypeConverter : public TypeConverter {
public:
Tf2XlaTypeConverter();
};

/// Adds the TF to XLA via TF2XLA rewrite patterns to the pattern list.
/// `prefer_tf2xla` means an op will be included iff it is not in
/// `MlirLegalizedUnderPreferTf2XlaSet`. `!prefer_tf2xla` mean an op will be
/// included if there is no native MLIR legalization for the op.
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
RewritePatternSet& patterns,
MLIRContext* ctx,
Tf2XlaTypeConverter& converter,
bool prefer_tf2xla = false);

/// Adds the TF to TF lowerings and TF to XLA rewrite patterns to the pattern
/// list.
void PopulateLegalizeTfPatterns(MLIRContext* context,
Expand Down Expand Up @@ -129,7 +113,6 @@ CreateInfeedsOpsXlaAdjustLayoutPass();

#define GEN_PASS_REGISTRATION
#define GEN_PASS_DECL_LEGALIZETFCOMMUNICATIONPASS
#define GEN_PASS_DECL_LEGALIZETFWITHTF2XLA
#include "tensorflow/compiler/mlir/tf2xla/transforms/tf_xla_passes.h.inc"
} // namespace mhlo
} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/framework/metrics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ auto* eager_client_error_counter = tsl::monitoring::Counter<2>::New(
"error_type");

auto* mlir_bridge_first_phase_counter = tsl::monitoring::Counter<5>::New(
"/tensorflow/core/tf_mlir_bridge_first_phase_count",
"/tensorflow/core/tf_mlir_bridge_first_phase_v2_count",
"Tracks processing state in first phase of mlir bridge", "bridge",
"version", "device", "fallback", "result");

Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/framework/metrics.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ int64_t GetFunctionGraphOptimizationCacheLoadCount(
GraphOptimizationSource source);

// Records the activity of the first phase of the mlir bridge using the
// tf_metadata.tf_mlir_bridge_first_phase_count metric.
// tf_metadata.tf_mlir_bridge_first_phase_v2_count metric.
// bridge_type: replicated, nonreplicated, etc.
// bridge_version: v1 compat, v2, etc.
// device_type: tpu, cpu, gpu, etc.
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 266
_version = 267

# Version number for MLIR:Python components.
mlir_api_version = 56
Expand Down
36 changes: 0 additions & 36 deletions third_party/xla/xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -638,47 +638,12 @@ cc_library(
],
)

cc_library(
name = "host_kernel_emitter",
srcs = ["host_kernel_emitter.cc"],
hdrs = ["host_kernel_emitter.h"],
deps = [
"//xla:shape_util",
"//xla/service/llvm_ir:ir_array",
"//xla/service/llvm_ir:llvm_util",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@local_tsl//tsl/platform:logging",
],
)

xla_cc_test(
name = "host_kernel_emitter_test",
srcs = ["host_kernel_emitter_test.cc"],
deps = [
":host_kernel_emitter",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/service/llvm_ir:llvm_util",
"//xla/tests:filecheck",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:OrcJIT",
"@llvm-project//llvm:Support",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
"@local_tsl//tsl/platform:test_main",
],
)

cc_library(
name = "ir_emitter2",
srcs = ["ir_emitter2.cc"],
hdrs = ["ir_emitter2.h"],
deps = [
":elemental_math_emitter",
":host_kernel_emitter",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/service:elemental_ir_emitter",
Expand All @@ -700,7 +665,6 @@ xla_cc_test(
name = "ir_emitter2_test",
srcs = ["ir_emitter2_test.cc"],
deps = [
":host_kernel_emitter",
":ir_emitter2",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
Expand Down
Loading

0 comments on commit 7ac685e

Please sign in to comment.