[go: nahoru, domu]

Skip to content

Commit

Permalink
Share StableHLO/MHLO pretty printers for ReduceOp and WhileOp
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609037997
  • Loading branch information
GleasonK authored and tensorflower-gardener committed Feb 21, 2024
1 parent 264e6a8 commit 658fa2f
Show file tree
Hide file tree
Showing 8 changed files with 1,918 additions and 416 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK-NEXT: %[[cmul:.*]] = mhlo.convert %[[mul]] : tensor<8x8x8x8xf32>
// CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-NEXT: %[[convert_init:.*]] = mhlo.convert %[[init]] : tensor<f32>
// CHECK: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[convert_init]]) across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
// CHECK: %[[red1:.*]] = mhlo.reduce(%[[cmul]] init: %[[convert_init]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
// CHECK: %[[scr2:.*]] = mhlo.convert %[[red1]] : tensor<8xf32>

// CHECK: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
Expand All @@ -575,7 +575,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK: %[[cgrad:.*]] = mhlo.convert %[[grad]] : tensor<8x8x8x8xf32>
// CHECK: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-NEXT: %[[convert_init2:.*]] = mhlo.convert %[[init2]] : tensor<f32>
// CHECK: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[convert_init2]]) across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
// CHECK: %[[red2:.*]] = mhlo.reduce(%[[cgrad]] init: %[[convert_init2]]) applies mhlo.add across dimensions = [0, 1, 2] : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
// CHECK: %[[offset_backprop:.*]] = mhlo.convert %[[red2]] : tensor<8xf32>

// CHECK: %[[x_backprop:.*]] = mhlo.convert %[[mul3]] : tensor<8x8x8x8xf32>
Expand Down
951 changes: 951 additions & 0 deletions third_party/stablehlo/temporary.patch

Large diffs are not rendered by default.

951 changes: 951 additions & 0 deletions third_party/xla/third_party/stablehlo/temporary.patch

Large diffs are not rendered by default.

373 changes: 8 additions & 365 deletions third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -25,45 +25,6 @@ func.func @reduce_one_op_all_locs_same(%arg0: tensor<?x?xf32>, %arg1 : tensor<f3
func.return %0: tensor<?xf32>
}

// The test case is not eligible for pretty-printing reduce-op. The location of
// the reduce-op is different.

// CHECK-LABEL: func @reduce_one_op_all_locs_not_same_1
// CHECK-NEXT: mhlo.reduce(%arg0 init: %arg1)
// CHECK-SAME: across dimensions = [1] {foo = "bar"}
// CHECK-SAME: : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK-NEXT: reducer(%arg[[x:.+]]: tensor<f32> loc("foo"), %arg[[y:.+]]: tensor<f32> loc("foo"))
// CHECK-NEXT: mhlo.add %arg[[x]], %arg[[y]] : tensor<f32> loc("foo")
// CHECK-NEXT: mhlo.return %{{[0-9]+}} : tensor<f32> loc("foo")
// CHECK-NEXT: loc("not_foo")

func.func @reduce_one_op_all_locs_not_same_1(%arg0: tensor<?x?xf32>, %arg1 : tensor<f32>) -> (tensor<?xf32>) {
%0 = "mhlo.reduce"(%arg0, %arg1) ({
^bb0(%arg2: tensor<f32> loc("foo"), %arg3: tensor<f32> loc("foo")):
%1 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32> loc("foo")
"mhlo.return"(%1) : (tensor<f32>) -> () loc("foo")
}) {dimensions = dense<[1]> : tensor<1xi64>, foo = "bar"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32> loc("not_foo")

func.return %0: tensor<?xf32>
}

// The test case is not eligible for pretty-printing reduce-op. The locations
// of the block-arguments are different.

// CHECK-LABEL: func @reduce_one_op_all_locs_not_same_2
// CHECK-NOT: applies

func.func @reduce_one_op_all_locs_not_same_2(%arg0: tensor<?x?xf32>, %arg1 : tensor<f32>) -> (tensor<?xf32>) {
%0 = "mhlo.reduce"(%arg0, %arg1) ({
^bb0(%arg2: tensor<f32> loc("foo"), %arg3: tensor<f32> loc("not_foo")):
%1 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32> loc("foo")
"mhlo.return"(%1) : (tensor<f32>) -> () loc("foo")
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32> loc("foo")

func.return %0: tensor<?xf32>
}


// The test case is not eligible for pretty-printing reduce-op. More than two
// block-arguments which are not perfectly forwarded to inner-op.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func.func @dot3(%arg0: tensor<4xf64, #SV>,
// CHECK-LABEL: func @sparse_reduce(
// CHECK-SAME: %[[A:.*]]: tensor<10xi64, #{{.*}}>) -> tensor<i64> {
// CHECK: %[[C:.*]] = mhlo.constant dense<0> : tensor<i64>
// CHECK: %[[T:.*]] = mhlo.reduce(%[[A]] init: %[[C]]) across dimensions = [0] : (tensor<10xi64, #{{.*}}>) -> tensor<i64>
// CHECK: %[[T:.*]] = mhlo.reduce(%[[A]] init: %[[C]]) applies mhlo.add across dimensions = [0] : (tensor<10xi64, #{{.*}}>) -> tensor<i64>
// CHECK: return %[[T]] : tensor<i64>
func.func @sparse_reduce(%arg0: tensor<10xi64, #SV>) -> tensor<i64> {
%0 = mhlo.constant dense<0> : tensor<i64>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -521,23 +521,23 @@ func.func @reduce_verify_rettype(%arg0: tensor<?x?xf32>, %arg1 : tensor<f32>)
// -----

func.func @reduce_parsing_pretty_reduce_non_commutative(%arg0: tensor<?x?xf32> , %arg1: tensor<f32> ) -> tensor<?xf32> {
// expected-error@+1 {{expected the inner-op to be a commutative binary-op from mhlo dialect, zero region, producing single result}}
// expected-error@+1 {{expected the inner-op to be a commutative binary-op that matching the reduce op dialect, with zero region, producing single result}}
%0 = mhlo.reduce(%arg0 init: %arg1) applies mhlo.divide across dimensions = [1] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32> loc("foo")
func.return %0 : tensor<?xf32>
}

// -----

func.func @reduce_parsing_pretty_reduce_wrong_dialect(%arg0: tensor<?x?xf32> , %arg1: tensor<f32> ) -> tensor<?xf32> {
// expected-error@+1 {{expected the inner-op to be a commutative binary-op from mhlo dialect, zero region, producing single result}}
// expected-error@+1 {{expected the inner-op to be a commutative binary-op that matching the reduce op dialect, with zero region, producing single result}}
%0 = mhlo.reduce(%arg0 init: %arg1) applies std.add across dimensions = [1] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32> loc("foo")
func.return %0 : tensor<?xf32>
}

// -----

func.func @reduce_parsing_pretty_reduce_non_binary(%arg0: tensor<?x?xf32> , %arg1: tensor<f32> ) -> tensor<?xf32> {
// expected-error@+1 {{expected the inner-op to be a commutative binary-op from mhlo dialect, zero region, producing single result}}
// expected-error@+1 {{expected the inner-op to be a commutative binary-op that matching the reduce op dialect, with zero region, producing single result}}
%0 = mhlo.reduce(%arg0 init: %arg1) applies mhlo.reshape across dimensions = [1] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32> loc("foo")
func.return %0 : tensor<?xf32>
}
8 changes: 2 additions & 6 deletions third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlotxt
Original file line number Diff line number Diff line change
Expand Up @@ -1142,16 +1142,12 @@ add {
// CHECK: mhlo.tuple %0#0, %0#1 {xla_shape = {{.*}}} : tuple<tensor<f32>, tensor<f32>>
%reduce.1 = (f32[], f32[]) reduce(%Arg_0.1, %Arg_0.1, %Arg_2.3, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.1

// CHECK: [[VAL2:%.*]] = mhlo.reduce([[ARG0]] init: [[ARG2]])
// CHECK: mhlo.add{{.*}} : tensor<f32>
// CHECK: [[VAL2:%.*]] = mhlo.reduce([[ARG0]] init: [[ARG2]]) applies mhlo.add across dimensions = [0, 1] : (tensor<4x4xf32>, tensor<f32>) -> tensor<f32>
%reduce.3 = f32[] reduce(%Arg_0.1, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.3
// CHECK: [[VAL3:%.*]] = mhlo.reduce([[ARG0]] init: [[ARG1]])
// CHECK-
// CHECK: mhlo.add{{.*}} : tensor<4xf32>
%reduce.2 = f32[4] reduce(%Arg_0.1, %Arg_1.2), dimensions={0}, to_apply=%reduce_helper.2
// CHECK: [[VAL4:%.*]] = mhlo.reduce([[VAL3]] init: [[ARG2]])
// CHECK-SAME: dimensions = [0]
// CHECK: mhlo.add{{.*}} : tensor<f32>
// CHECK: [[VAL4:%.*]] = mhlo.reduce([[VAL3]] init: [[ARG2]]) applies mhlo.add across dimensions = [0] : (tensor<4xf32>, tensor<f32>) -> tensor<f32>
%reduce.4 = f32[] reduce(%reduce.2, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.3

// CHECK: %5 = mhlo.subtract [[VAL2]], [[VAL4]] : tensor<f32>
Expand Down

0 comments on commit 658fa2f

Please sign in to comment.