[go: nahoru, domu]

Skip to content

Commit

Permalink
Add test case for 1D convolution
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 644286074
  • Loading branch information
Adam-Banas authored and tensorflower-gardener committed Jun 18, 2024
1 parent 9128be0 commit cf31ac0
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions third_party/xla/xla/tests/convolution_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,61 @@ class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest {
// Instantiate the Convolve2D_1x3x3x5_3x3x5x3_Valid fixture once per element
// type in TestTypes (declared outside this view) and run its RunTest().
TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x3_Valid, TestTypes);
TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x3_Valid, Types) {
  this->RunTest();
}

template <typename T>
class Convolve1D_1x3x5_3x5x3_Valid : public ConvolutionTest {
public:
void RunTest() {
XlaBuilder builder(TestName());
std::vector<int64_t> input_dims = {1, 3, 5};
std::vector<int64_t> filter_dims = {3, 5, 3};
Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
{
auto input = Parameter(&builder, 0, input_shape, "input");
auto filter = Parameter(&builder, 1, filter_shape, "filter");

// Tensorflow dimension numbers for 1D convolution.
// Layout as supported by Eigen convolution.
ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(0);
dnums.add_input_spatial_dimensions(1);
dnums.set_input_feature_dimension(2);
dnums.add_kernel_spatial_dimensions(0);
dnums.set_kernel_input_feature_dimension(1);
dnums.set_kernel_output_feature_dimension(2);
dnums.set_output_batch_dimension(0);
dnums.add_output_spatial_dimensions(1);
dnums.set_output_feature_dimension(2);

ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, dnums);
}

std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
auto input_r3 = input_r1.Reshape(input_dims).value();

std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
auto filter_r3 = filter_r1.Reshape(filter_dims).value();

auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(3480), static_cast<T>(3600), static_cast<T>(3720)});
auto expected_r3 = expected_r1.Reshape({1, 1, 3}).value();

auto input_literal = client_->TransferToServer(input_r3).value();
auto filter_literal = client_->TransferToServer(filter_r3).value();

ComputeAndCompareLiteral(&builder, expected_r3,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
};

TYPED_TEST_CASE(Convolve1D_1x3x5_3x5x3_Valid, TestTypes);
TYPED_TEST(Convolve1D_1x3x5_3x5x3_Valid, Types) { this->RunTest(); }

template <typename T>
class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest {
public:
Expand Down Expand Up @@ -1755,6 +1810,19 @@ ENTRY TestComputation {
EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}

// f16 2D convolution: a batch of 8 single-channel 5x5 images convolved with
// 32 filters of size 3x3 (dim_labels b01f_01io->b01f). Padding 1 on each side
// of the size-3 window keeps the 5x5 spatial extent, giving f16[8,5,5,32].
// The result is checked by RunAndCompare within the tolerance below.
XLA_TEST_F(ConvolutionHloTest, TestConv2DF16) {
  const std::string kHlo = R"(
HloModule TestModule
ENTRY TestComputation {
%p0 = f16[8,5,5,1] parameter(0)
%p1 = f16[3,3,1,32] parameter(1)
ROOT %conv = f16[8,5,5,32] convolution(p0, p1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f
})";
  // Loose absolute/relative tolerance to accommodate half-precision rounding.
  const ErrorSpec kTolerance{0.01, 0.01};

  EXPECT_TRUE(RunAndCompare(kHlo, kTolerance));
}

XLA_TEST_F(ConvolutionHloTest, TestFusedConv2D) {
std::string kHlo = R"(
HloModule TestModule
Expand Down

0 comments on commit cf31ac0

Please sign in to comment.