WebNN: Define transpose operation in mojo
The operation transpose permutes the dimensions of the input tensor
according to the given permutation argument.
This CL moves the validation steps of transpose to //components/,
implements CreateTransposeOperation() for creating transpose mojo
operation and adds operator validation on service side.
Bug: 1273291
Change-Id: Ib143cac3eb48f2db553847f2add935c461a17f9f
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4893811
Auto-Submit: Shiyi Zou <shiyi.zou@intel.com>
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Commit-Queue: Shiyi Zou <shiyi.zou@intel.com>
Cr-Commit-Position: refs/heads/main@{#1207460}
diff --git a/AUTHORS b/AUTHORS
index ac5660a2..a6a754c 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -1218,6 +1218,7 @@
Shirish S <shirish.s@amd.com>
Shiva Kumar <shiva.k1@samsung.com>
Shivakumar JM <shiva.jm@samsung.com>
+Shiyi Zou <shiyi.zou@intel.com>
Shobhit Goel <shobhit.goel@samsung.com>
Shouqun Liu <liushouqun@xiaomi.com>
Shouqun Liu <shouqun.liu@intel.com>
diff --git a/components/ml/webnn/graph_validation_utils.cc b/components/ml/webnn/graph_validation_utils.cc
index 13c03a1..d1c5b44 100644
--- a/components/ml/webnn/graph_validation_utils.cc
+++ b/components/ml/webnn/graph_validation_utils.cc
@@ -5,6 +5,7 @@
#include "components/ml/webnn/graph_validation_utils.h"
#include <algorithm>
+#include <set>
#include "base/check_op.h"
#include "base/notreached.h"
@@ -502,6 +503,28 @@
return Operand(a.data_type, std::move(output_shape));
}
+// Validates `permutation` against `input` and infers the transposed operand.
+//
+// Returns an error message when the number of permutation values differs from
+// the input rank, or when ValidateAxes() rejects the values. Otherwise returns
+// an Operand with the input's data type whose i-th dimension is the input
+// dimension selected by permutation[i].
+base::expected<Operand, std::string> ValidateTransposeAndInferOutput(
+    const Operand& input,
+    base::span<const uint32_t> permutation) {
+  // The dimensions are only read here, so bind them by reference instead of
+  // copying the whole container.
+  const auto& input_dimensions = input.dimensions;
+  const auto input_rank = input_dimensions.size();
+  if (permutation.size() != input_rank) {
+    return base::unexpected(
+        "The number of values in permutation must be the same as the rank of "
+        "the input tensor.");
+  }
+  // Each permutation value is used as an index into the input dimensions
+  // below, so it has to pass the axes validation first.
+  auto validation_result = ValidateAxes(permutation, input_rank);
+  if (!validation_result.has_value()) {
+    return base::unexpected(validation_result.error());
+  }
+
+  std::vector<uint32_t> output_shape(input_rank);
+  for (size_t i = 0; i < input_rank; ++i) {
+    output_shape[i] = input_dimensions[permutation[i]];
+  }
+  return Operand(input.data_type, std::move(output_shape));
+}
+
base::expected<size_t, std::string> ValidateAndCalculateElementsNumber(
base::span<const uint32_t> dimensions) {
if (dimensions.empty()) {
@@ -535,6 +558,24 @@
return checked_byte_length.ValueOrDie();
}
+// Checks that every value in `axes` is smaller than `rank` and that no value
+// appears more than once. Returns an error message for the first rule that is
+// violated, with the range check performed before the duplicate check.
+base::expected<void, std::string> ValidateAxes(base::span<const uint32_t> axes,
+                                               uint32_t rank) {
+  if (base::ranges::any_of(axes, [rank](uint32_t axis) {
+        return base::MakeStrictNum(axis) >= rank;
+      })) {
+    // `rank - 1` would wrap around for rank 0 and print a bogus upper bound,
+    // so that case gets its own message.
+    if (rank == 0) {
+      return base::unexpected(
+          "The values in axes are out of range because a tensor of rank 0 has "
+          "no axes.");
+    }
+    return base::unexpected(base::StringPrintf(
+        "The values in axes must be within the range from 0 to (%u).",
+        rank - 1));
+  }
+
+  // A set drops repeated values, so a size mismatch means there are
+  // duplicates.
+  if (axes.size() != std::set<uint32_t>(axes.begin(), axes.end()).size()) {
+    return base::unexpected(
+        "Two or more values are same in the axes sequence.");
+  }
+
+  return base::ok();
+}
+
absl::optional<std::vector<uint32_t>> BroadcastShapes(
base::span<const uint32_t> dims_lhs,
base::span<const uint32_t> dims_rhs,
diff --git a/components/ml/webnn/graph_validation_utils.h b/components/ml/webnn/graph_validation_utils.h
index 4851ce1..ddb0b85 100644
--- a/components/ml/webnn/graph_validation_utils.h
+++ b/components/ml/webnn/graph_validation_utils.h
@@ -184,6 +184,12 @@
const Operand& b,
const GemmAttributes& attributes);
+// Validate transpose operator defined in WebIDL here
+// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-transpose
+base::expected<Operand, std::string> ValidateTransposeAndInferOutput(
+ const Operand& input,
+ base::span<const uint32_t> permutation);
+
base::expected<size_t, std::string> ValidateAndCalculateElementsNumber(
base::span<const uint32_t> dimensions);
@@ -191,6 +197,11 @@
size_t type_bytes,
base::span<const uint32_t> dimensions);
+// Validate that the axes are within the range of [0, rank - 1] without
+// duplication.
+base::expected<void, std::string> ValidateAxes(base::span<const uint32_t> axes,
+ uint32_t rank);
+
// Broadcast the input shapes and return the output shape.
// If bidirectional is true, its behavior follows the numpy-broadcasting-rule:
// https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules.
diff --git a/services/webnn/dml/graph_impl.cc b/services/webnn/dml/graph_impl.cc
index 304375f..9c75adc 100644
--- a/services/webnn/dml/graph_impl.cc
+++ b/services/webnn/dml/graph_impl.cc
@@ -32,6 +32,7 @@
using mojom::ComputeResult;
using mojom::Operand;
using mojom::OperandPtr;
+using mojom::Operation;
using mojom::Operator;
using mojom::OperatorPtr;
@@ -95,6 +96,17 @@
}
}
+std::string OpTagToString(Operation::Tag tag) {
+ switch (tag) {
+ case Operation::Tag::kPool2d:
+ return "pool2d";
+ case Operation::Tag::kTranspose:
+ return "transpose";
+ case Operation::Tag::kGenericOperator:
+ NOTREACHED_NORETURN();
+ }
+}
+
// Calculate the total byte length of buffers and the D3D12_RANGE for each
// buffer, all with the required alignment.
template <typename Map>
@@ -888,12 +900,7 @@
break;
}
default:
- DLOG(ERROR) << "This operator kind (" + OpKindToString(operation->kind) +
- ") is not supported.";
- create_operator_result = base::unexpected(mojom::Error::New(
- mojom::Error::Code::kNotSupportedError,
- "This operator (" + OpKindToString(operation->kind) +
- ") is not supported."));
+ NOTREACHED_NORETURN();
}
return create_operator_result;
}
@@ -1347,18 +1354,27 @@
// message.
base::expected<void, mojom::ErrorPtr> create_operator_result;
switch (operation->which()) {
- case mojom::Operation::Tag::kPool2d: {
+ case Operation::Tag::kPool2d: {
create_operator_result = CreateOperatorNodeForPool2d(
id_to_operand_map, operation->get_pool2d(), graph_builder,
id_to_node_output_map);
break;
}
- case mojom::Operation::Tag::kGenericOperator: {
+ case Operation::Tag::kGenericOperator: {
create_operator_result = CreateGenericOperator(
id_to_operand_map, operation->get_generic_operator(), graph_builder,
id_to_node_output_map);
break;
}
+ default: {
+ DLOG(ERROR) << "This operator kind (" +
+ OpTagToString(operation->which()) +
+ ") is not supported.";
+ create_operator_result = base::unexpected(mojom::Error::New(
+ mojom::Error::Code::kNotSupportedError,
+ "This operator (" + OpTagToString(operation->which()) +
+ ") is not supported."));
+ }
}
if (!create_operator_result.has_value()) {
std::move(callback).Run(mojom::CreateGraphResult::NewError(
diff --git a/services/webnn/public/mojom/webnn_graph.mojom b/services/webnn/public/mojom/webnn_graph.mojom
index cce12a6..d79ba15 100644
--- a/services/webnn/public/mojom/webnn_graph.mojom
+++ b/services/webnn/public/mojom/webnn_graph.mojom
@@ -135,6 +135,17 @@
bool b_transpose = false;
};
+// Represents the transpose operation that permutes the dimensions of the
+// input tensor: output dimension i is input dimension permutation[i].
+struct Transpose {
+ // The id of the input operand.
+ uint64 input_operand_id;
+ // The id of the output operand.
+ uint64 output_operand_id;
+ // Must hold each axis index in [0, input rank - 1] exactly once.
+ array<uint32> permutation;
+};
+
// Holds one of operator attributes.
union OperatorAttributes {
ClampAttributes clamp;
@@ -177,6 +188,7 @@
// Holds one of operator.
union Operation {
Pool2d pool2d;
+ Transpose transpose;
// TODO(crbug.com/1273291): The `generic_operator` will be removed.
Operator generic_operator;
diff --git a/services/webnn/webnn_graph_impl.cc b/services/webnn/webnn_graph_impl.cc
index 5251473..544a18c 100644
--- a/services/webnn/webnn_graph_impl.cc
+++ b/services/webnn/webnn_graph_impl.cc
@@ -463,6 +463,28 @@
return true;
}
+// Returns true when both operand ids of `transpose` resolve to known operands
+// and the declared output operand equals the one inferred from the input
+// operand and the permutation.
+bool ValidateTranspose(const IdToOperandMap& id_to_operand_map,
+                       const mojom::TransposePtr& transpose) {
+  auto* input_operand =
+      GetMojoOperand(id_to_operand_map, transpose->input_operand_id);
+  auto* output_operand =
+      GetMojoOperand(id_to_operand_map, transpose->output_operand_id);
+  if (!input_operand || !output_operand) {
+    // The transpose operator is invalid.
+    return false;
+  }
+
+  const auto inferred_output = ValidateTransposeAndInferOutput(
+      ConvertToComponentOperand(input_operand), transpose->permutation);
+  // The declared output is only compared once the inference has succeeded.
+  return inferred_output.has_value() &&
+         inferred_output == ConvertToComponentOperand(output_operand);
+}
+
bool ValidateGenericOperator(const IdToOperandMap& id_to_operand_map,
const mojom::OperatorPtr& operation) {
switch (operation->kind) {
@@ -515,6 +537,8 @@
switch (operation->which()) {
case mojom::Operation::Tag::kPool2d:
return ValidatePool2d(id_to_operand_map, operation->get_pool2d());
+ case mojom::Operation::Tag::kTranspose:
+ return ValidateTranspose(id_to_operand_map, operation->get_transpose());
case mojom::Operation::Tag::kGenericOperator:
return ValidateGenericOperator(id_to_operand_map,
operation->get_generic_operator());
diff --git a/services/webnn/webnn_graph_impl_unittest.cc b/services/webnn/webnn_graph_impl_unittest.cc
index 0888b3e2b..4dc71f4 100644
--- a/services/webnn/webnn_graph_impl_unittest.cc
+++ b/services/webnn/webnn_graph_impl_unittest.cc
@@ -1144,6 +1144,90 @@
}
}
+// One transpose validation case: the graph built from `input`, `permutation`
+// and `output` must make ValidateGraph() return `expected`.
+struct TransposeTester {
+  OperandInfo input;
+  std::vector<uint32_t> permutation;
+  OperandInfo output;
+  bool expected;
+
+  void Test() {
+    // Assemble a mojo graph made of one input, one output and the transpose
+    // operation connecting them.
+    GraphInfoBuilder builder;
+    const uint64_t input_id =
+        builder.BuildInput("input", input.dimensions, input.type);
+    const uint64_t output_id =
+        builder.BuildOutput("output", output.dimensions, output.type);
+    // `permutation` is moved into the graph here.
+    builder.BuildTranspose(input_id, output_id, std::move(permutation));
+    EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected);
+  }
+};
+
+TEST_F(WebNNGraphImplTest, TransposeTest) {
+ {
+ // Test a valid transpose operator with permutation [2, 3, 1, 0].
+ TransposeTester{.input = {.type = mojom::Operand::DataType::kFloat32,
+ .dimensions = {1, 2, 3, 4}},
+ .permutation = {2, 3, 1, 0},
+ .output = {.type = mojom::Operand::DataType::kFloat32,
+ .dimensions = {3, 4, 2, 1}},
+ .expected = true}
+ .Test();
+ }
+ {
+ // Test the invalid graph when the permutation has more values than the
+ // input rank.
+ TransposeTester{.input = {.type = mojom::Operand::DataType::kFloat32,
+ .dimensions = {1, 2, 3}},
+ .permutation = {0, 1, 2, 2},
+ .output = {.type = mojom::Operand::DataType::kFloat32,
+ .dimensions = {1, 2, 3, 3}},
+ .expected = false}
+ .Test();
+ }
+ {
+ // Test the invalid graph when the permutation contains duplicate values.
+ TransposeTester{.input = {.type = mojom::Operand::DataType::kFloat32,
+ .dimensions = {1, 2, 3, 4}},
+ .permutation = {0, 1, 2, 2},
+ .output = {.type = mojom::Operand::DataType::kFloat32,
+ .dimensions = {1, 2, 3, 3}},
+ .expected = false}
+ .Test();
+ }
+ {
+ // Test the invalid graph when one value in permutation is greater than
+ // input_rank - 1.
+ TransposeTester{.input = {.type = mojom::Operand::DataType::kFloat16,
+ .dimensions = {1, 2, 3, 4}},
+ .permutation = {0, 1, 2, 4},
+ .output = {.type = mojom::Operand::DataType::kFloat16,
+ .dimensions = {1, 2, 3, 4}},
+ .expected = false}
+ .Test();
+ }
+ {
+ // Test the invalid graph when the output shape is not the expected one.
+ TransposeTester{.input = {.type = mojom::Operand::DataType::kFloat32,
+ .dimensions = {1, 2, 3, 4}},
+ .permutation = {0, 1, 2, 3},
+ .output = {.type = mojom::Operand::DataType::kFloat32,
+ .dimensions = {1, 2, 3}},
+ .expected = false}
+ .Test();
+ }
+ {
+ // Test the invalid graph when the output type doesn't match the input type.
+ TransposeTester{.input = {.type = mojom::Operand::DataType::kFloat32,
+ .dimensions = {1, 2, 3, 4}},
+ .permutation = {0, 1, 2, 3},
+ .output = {.type = mojom::Operand::DataType::kFloat16,
+ .dimensions = {1, 2, 3, 4}},
+ .expected = false}
+ .Test();
+ }
+}
+
TEST_F(WebNNGraphImplTest, ValidateInputsTest) {
const std::vector<uint32_t> dimensions = {3, 5};
// Build the graph with mojo type.
diff --git a/services/webnn/webnn_test_utils.cc b/services/webnn/webnn_test_utils.cc
index e407629..db497568 100644
--- a/services/webnn/webnn_test_utils.cc
+++ b/services/webnn/webnn_test_utils.cc
@@ -77,6 +77,17 @@
mojom::Operation::NewGenericOperator(std::move(operation)));
}
+// Appends a transpose operation with the given operand ids and `permutation`
+// to the operations of the graph being built.
+void GraphInfoBuilder::BuildTranspose(uint64_t input_operand_id,
+                                      uint64_t output_operand_id,
+                                      std::vector<uint32_t> permutation) {
+  // Transpose::New() receives the struct fields in their declaration order.
+  graph_info_->operations.push_back(
+      mojom::Operation::NewTranspose(mojom::Transpose::New(
+          input_operand_id, output_operand_id, std::move(permutation))));
+}
+
mojom::GraphInfoPtr GraphInfoBuilder::CloneGraphInfo() const {
CHECK_IS_TEST();
mojom::GraphInfoPtr cloned_graph_info = mojom::GraphInfo::New();
diff --git a/services/webnn/webnn_test_utils.h b/services/webnn/webnn_test_utils.h
index 1f757aef..ff7036e 100644
--- a/services/webnn/webnn_test_utils.h
+++ b/services/webnn/webnn_test_utils.h
@@ -75,6 +75,10 @@
mojom::Operation::NewPool2d(std::move(pool2d)));
}
+ void BuildTranspose(uint64_t input_operand_id,
+ uint64_t output_operand_id,
+ std::vector<uint32_t> permutation);
+
const mojom::GraphInfoPtr& GetGraphInfo() const { return graph_info_; }
// Get a clone of internal graph info. This is used by
diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc b/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc
index 9406dc3..1084f11 100644
--- a/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc
+++ b/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc
@@ -282,29 +282,6 @@
return true;
}
-bool ValidateAxes(const Vector<uint32_t>& axes,
- uint32_t input_rank,
- ExceptionState& exception_state) {
- if (base::ranges::any_of(axes, [input_rank](uint32_t axis) {
- return base::MakeStrictNum(axis) >= input_rank;
- })) {
- exception_state.ThrowDOMException(
- DOMExceptionCode::kDataError,
- String::Format("The values in axes must be within the range from 0 "
- "to (%u).",
- input_rank - 1));
- return false;
- }
-
- if (axes.size() != std::set<uint32_t>(axes.begin(), axes.end()).size()) {
- exception_state.ThrowDOMException(
- DOMExceptionCode::kDataError,
- "Two or more values are same in the axes sequence.");
- return false;
- }
- return true;
-}
-
absl::optional<Vector<uint32_t>> BroadcastShapes(
const Vector<uint32_t>& dims_lhs,
const Vector<uint32_t>& dims_rhs,
@@ -391,7 +368,11 @@
default_axes[i] = i;
}
const auto axes = options->getAxesOr(std::move(default_axes));
- if (!ValidateAxes(axes, input_rank, exception_state)) {
+ auto validation_result = webnn::ValidateAxes(axes, input_rank);
+ if (!validation_result.has_value()) {
+ exception_state.ThrowDOMException(
+ DOMExceptionCode::kDataError,
+ String::FromUTF8(validation_result.error()));
return nullptr;
}
@@ -1841,35 +1822,25 @@
// When permutation is not specified, it’s set to [N-1, ..., 0], where N is
// the rank of the input tensor.
auto input_rank = input->Dimensions().size();
- Vector<uint32_t> default_permutation(input_rank);
- for (wtf_size_t i = 0; i < input_rank - 1; i++) {
- default_permutation[i] = input_rank - 1 - i;
- }
const Vector<uint32_t> permutation =
- options->getPermutationOr(std::move(default_permutation));
- if (permutation.size() != input_rank) {
+ options->getPermutationOr(CreateDefaultPermutation(input_rank));
+ auto validated_output = webnn::ValidateTransposeAndInferOutput(
+ ConvertToComponentOperand(input), permutation);
+ if (!validated_output.has_value()) {
exception_state.ThrowDOMException(
DOMExceptionCode::kDataError,
- "The number of values in permutation must be the same as the rank "
- "of the input tensor.");
+ String::FromUTF8(validated_output.error()));
return nullptr;
}
- if (!ValidateAxes(permutation, input_rank, exception_state)) {
- return nullptr;
- }
-
- Vector<uint32_t> output_shape(input_rank);
- for (wtf_size_t i = 0; i < input_rank; ++i) {
- output_shape[i] = input->Dimensions()[permutation[i]];
- }
auto* transpose = MakeGarbageCollected<MLOperator>(
this, MLOperator::OperatorKind::kTranspose, options);
// According to WebNN spec
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-transpose, the output
// tensor of transpose has the same type as its input.
auto output = MLOperand::ValidateAndCreateOutput(
- this, input->Type(), std::move(output_shape), transpose);
+ this, ComponentOperandTypeToBlink(validated_output->data_type),
+ Vector<uint32_t>(validated_output->dimensions), transpose);
if (!output.has_value()) {
exception_state.ThrowDOMException(DOMExceptionCode::kDataError,
output.error());
diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_graph_test_mojo.cc b/third_party/blink/renderer/modules/ml/webnn/ml_graph_test_mojo.cc
index 4e719d4e..8a4b5a6 100644
--- a/third_party/blink/renderer/modules/ml/webnn/ml_graph_test_mojo.cc
+++ b/third_party/blink/renderer/modules/ml/webnn/ml_graph_test_mojo.cc
@@ -1465,6 +1465,95 @@
}
}
+struct TransposeTester {
+ OperandInfoBlink input;
+ absl::optional<Vector<uint32_t>> permutation;
+ OperandInfoMojo expected_operand;
+ Vector<uint32_t> expected_permutation;
+
+ void Test(MLGraphTestMojo& helper,
+ V8TestingScope& scope,
+ MLGraphBuilder* builder) {
+ // Build the graph; the permutation option is only set when one is given.
+ auto* input_operand = BuildInput(builder, "input", input.dimensions,
+ input.type, scope.GetExceptionState());
+ MLTransposeOptions* options = MLTransposeOptions::Create();
+ if (permutation.has_value()) {
+ options->setPermutation(permutation.value());
+ }
+ auto* output_operand =
+ builder->transpose(input_operand, options, scope.GetExceptionState());
+ auto [graph, build_exception] =
+ helper.BuildGraph(scope, builder, {{"output", output_operand}});
+ ASSERT_NE(graph, nullptr);
+
+ auto graph_info = helper.GetGraphInfo();
+ // Verify that the mojo graph information is as expected.
+ ASSERT_EQ(graph_info->operations.size(), 1u);
+ auto& operation = graph_info->operations[0];
+ ASSERT_TRUE(operation->is_transpose());
+ auto& transpose = operation->get_transpose();
+
+ // Validate the permutation of the transpose operation.
+ EXPECT_EQ(transpose->permutation, expected_permutation);
+
+ // Validate the input operand; its type equals the expected output type.
+ EXPECT_EQ(graph_info->input_operands.size(), 1u);
+ auto input_operand_id = graph_info->input_operands[0];
+ EXPECT_EQ(transpose->input_operand_id, input_operand_id);
+ auto input_operand_iter =
+ graph_info->id_to_operand_map.find(input_operand_id);
+ ASSERT_TRUE(input_operand_iter != graph_info->id_to_operand_map.end());
+ EXPECT_EQ(input_operand_iter->value->data_type, expected_operand.type);
+ EXPECT_EQ(input_operand_iter->value->dimensions, input.dimensions);
+
+ // Validate the output operand.
+ EXPECT_EQ(graph_info->output_operands.size(), 1u);
+ auto output_operand_id = graph_info->output_operands[0];
+ EXPECT_EQ(transpose->output_operand_id, output_operand_id);
+ auto output_operand_iter =
+ graph_info->id_to_operand_map.find(output_operand_id);
+ ASSERT_TRUE(output_operand_iter != graph_info->id_to_operand_map.end());
+ EXPECT_EQ(output_operand_iter->value->data_type, expected_operand.type);
+ EXPECT_EQ(output_operand_iter->value->dimensions,
+ expected_operand.dimensions);
+ }
+};
+
+TEST_P(MLGraphTestMojo, TransposeTest) {
+ V8TestingScope scope;
+ // Bind a fake WebNN context in the service for testing.
+ ScopedWebNNServiceBinder scoped_setup_binder(*this, scope);
+ base::test::ScopedFeatureList scoped_feature_list;
+ scoped_feature_list.InitAndEnableFeature(
+ webnn::features::kEnableMachineLearningNeuralNetworkService);
+ auto* options = MLContextOptions::Create();
+ // Create the WebNN context with the GPU device preference.
+ options->setDevicePreference(V8MLDevicePreference::Enum::kGpu);
+ auto* builder = CreateMLGraphBuilder(scope.GetExecutionContext(), options);
+ {
+ // Test default options: the expected permutation is the reversed axes.
+ TransposeTester{
+ .input = {.type = V8MLOperandType::Enum::kFloat32,
+ .dimensions = {1, 2, 3, 4}},
+ .expected_operand = {.type = blink_mojom::Operand::DataType::kFloat32,
+ .dimensions = {4, 3, 2, 1}},
+ .expected_permutation = {3, 2, 1, 0}}
+ .Test(*this, scope, builder);
+ }
+ {
+ // Test the transpose operator with an explicitly given permutation.
+ TransposeTester{
+ .input = {.type = V8MLOperandType::Enum::kFloat16,
+ .dimensions = {1, 2, 3, 4}},
+ .permutation = Vector<uint32_t>{3, 0, 2, 1},
+ .expected_operand = {.type = blink_mojom::Operand::DataType::kFloat16,
+ .dimensions = {4, 1, 3, 2}},
+ .expected_permutation = {3, 0, 2, 1}}
+ .Test(*this, scope, builder);
+ }
+}
+
template <typename T>
struct ConstantTester {
OperandInfo<T> constant;
diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_graph_type_converter.cc b/third_party/blink/renderer/modules/ml/webnn/ml_graph_type_converter.cc
index cb69ca6..e94d05b 100644
--- a/third_party/blink/renderer/modules/ml/webnn/ml_graph_type_converter.cc
+++ b/third_party/blink/renderer/modules/ml/webnn/ml_graph_type_converter.cc
@@ -8,6 +8,7 @@
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_conv_2d_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_gemm_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_pool_2d_options.h"
+#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_transpose_options.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_activation.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operand.h"
@@ -476,6 +477,29 @@
std::move(operator_mojo));
}
+// Builds the mojo Transpose operation for the blink `transpose` operator,
+// filling in the operand ids and the permutation taken from its options.
+OperationPtr CreateTransposeOperation(const OperandToIdMap& operand_to_id_map,
+                                      const MLOperator* transpose) {
+  auto mojo_transpose = webnn::mojom::blink::Transpose::New();
+  mojo_transpose->input_operand_id =
+      GetOperatorInputId(transpose, operand_to_id_map);
+  mojo_transpose->output_operand_id =
+      GetOperatorOutputId(transpose, operand_to_id_map);
+
+  const auto* transpose_options =
+      static_cast<const MLTransposeOptions*>(transpose->Options());
+  CHECK(transpose_options);
+
+  // When the options carry no permutation, fall back to the default one
+  // created for the rank of the first input.
+  const auto rank = transpose->Inputs()[0]->Dimensions().size();
+  mojo_transpose->permutation =
+      transpose_options->getPermutationOr(CreateDefaultPermutation(rank));
+  CHECK_EQ(mojo_transpose->permutation.size(), rank);
+
+  return webnn::mojom::blink::Operation::NewTranspose(
+      std::move(mojo_transpose));
+}
+
} // namespace
base::expected<OperationPtr, String> ConvertToMojoOperation(
@@ -504,13 +528,14 @@
return CreateReshapeOperator(operand_to_id_map, op);
case MLOperator::OperatorKind::kSoftmax:
return CreateSoftmaxOperator(operand_to_id_map, op);
+ case MLOperator::OperatorKind::kTranspose:
+ return CreateTransposeOperation(operand_to_id_map, op);
case MLOperator::OperatorKind::kHardSwish:
case MLOperator::OperatorKind::kReduceMean:
case MLOperator::OperatorKind::kReduceSum:
case MLOperator::OperatorKind::kResample2d:
case MLOperator::OperatorKind::kSigmoid:
case MLOperator::OperatorKind::kConcat:
- case MLOperator::OperatorKind::kTranspose:
case MLOperator::OperatorKind::kLeakyRelu:
case MLOperator::OperatorKind::kConvTranspose2d:
case MLOperator::OperatorKind::kPRelu:
diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.cc b/third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.cc
index 3df926b..704c171 100644
--- a/third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.cc
+++ b/third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.cc
@@ -203,4 +203,12 @@
NOTREACHED_NORETURN();
}
+// Returns the permutation [rank - 1, ..., 0]; it is empty when `rank` is 0.
+Vector<uint32_t> CreateDefaultPermutation(const wtf_size_t rank) {
+  Vector<uint32_t> permutation(rank);
+  // Count down from `rank` while filling the slots front to back, so the
+  // first slot receives rank - 1 and the last one receives 0.
+  wtf_size_t next_axis = rank;
+  for (auto& slot : permutation) {
+    slot = --next_axis;
+  }
+  return permutation;
+}
+
} // namespace blink
diff --git a/third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h b/third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h
index bd92f3b..05c7272e 100644
--- a/third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h
+++ b/third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h
@@ -77,6 +77,9 @@
webnn::AutoPad BlinkAutoPadToComponent(blink::V8MLAutoPad::Enum type);
+// Create a default permutation vector [rank - 1, ..., 0].
+Vector<uint32_t> CreateDefaultPermutation(const wtf_size_t rank);
+
// Helper to get padding sizes for convolution 2d or pooling 2d Nodes.
template <typename OptionsType>
webnn::Padding2d CalculatePadding2D(const OptionsType* options,