[go: nahoru, domu]

Skip to content

Commit

Permalink
[odml] Change StableHLO->VHLO pass to use partial conversion instead …
Browse files Browse the repository at this point in the history
…of greedy rewriter to avoid CSE of arith constants.

PiperOrigin-RevId: 611605374
  • Loading branch information
GleasonK authored and tensorflower-gardener committed Feb 29, 2024
1 parent d5cea8e commit b4f292d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,19 @@ func.func @op_with_region_mixed_shlo_tfl_shlo(%arg0: tensor<7x5xf32>, %arg1 : te
}) {dimensions = array<i64: 0>} : (tensor<7x5xf32>, tensor<5xf32>) -> tensor<5xf32>
func.return %0: tensor<5xf32>
}

// -----

// There are cases where the ODML converter relies on constants not being folded
// or CSE'd. This test ensures that the StableHLO->VHLO conversion neither folds
// nor CSEs constants.

// CHECK-LABEL: mixed_no_constant_folding
func.func @mixed_no_constant_folding() -> (tensor<f32>) {
// CHECK: %[[CST0:.+]] = arith.constant dense<0.000000e+00>
// CHECK-NEXT: %[[CST1:.+]] = arith.constant dense<0.000000e+00>
// CHECK-NEXT: "vhlo.add_v1"(%[[CST0]], %[[CST1]]) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// Two identical zero constants feed one add. The expected output above keeps
// two separate constants and a non-folded vhlo.add_v1 that uses both.
// - If the conversion CSE'd the constants, only one constant line would be
//   emitted, and the second constant pattern would not find an adjacent match.
// - If the conversion folded the add, no vhlo.add_v1 op would remain.
%cst_0 = arith.constant dense<0.000000e+00> : tensor<f32>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<f32>
%0 = stablehlo.add %cst_0, %cst_1 : tensor<f32>
return %0 : tensor<f32>
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ limitations under the License.
#include "stablehlo/dialect/VhloTypes.h" // from @stablehlo
#include "stablehlo/transforms/Passes.h" // from @stablehlo
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h"

#define DEBUG_TYPE "compat-passes"

Expand Down Expand Up @@ -255,9 +256,12 @@ LogicalResult ApplyVhloToStablehloPatterns(ModuleOp module) {
}

LogicalResult ApplyUnrealizedCastCanonicalization(ModuleOp module) {
RewritePatternSet patterns(module->getContext());
MLIRContext *context = module->getContext();
RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addIllegalOp<UnrealizedConversionCastOp>();
populateReconcileUnrealizedCastsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
return module->emitError("Failed to fold unrealized cast");
}
return success();
Expand Down

0 comments on commit b4f292d

Please sign in to comment.