[go: nahoru, domu]

Skip to content

Commit

Permalink
[wasm] Add all/any kernels (tensorflow#4799)
Browse files Browse the repository at this point in the history
FEATURE
* Add all/any kernels

* Fix license

* Fix license

Co-authored-by: Ping Yu <4018+pyu10055@users.noreply.github.com>
  • Loading branch information
kon72 and pyu10055 committed Mar 18, 2021
1 parent 34e2720 commit e1420f6
Show file tree
Hide file tree
Showing 7 changed files with 318 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ tfjs_cc_library(
":Abs",
":Add",
":AddN",
":All",
":Any",
":ArgMax",
":AvgPool",
":BatchMatMul",
Expand Down Expand Up @@ -330,6 +332,24 @@ tfjs_cc_library(
],
)

# C++ library target for the `All` kernel; compiled from kernels/All.cc.
tfjs_cc_library(
name = "All",
srcs = ["kernels/All.cc"],
deps = [
":backend",
":util",
],
)

# C++ library target for the `Any` kernel; compiled from kernels/Any.cc.
tfjs_cc_library(
name = "Any",
srcs = ["kernels/Any.cc"],
deps = [
":backend",
":util",
],
)

tfjs_cc_library(
name = "ArgMax",
srcs = ["kernels/ArgMax.cc"],
Expand Down
62 changes: 62 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/All.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <cstddef>

#include "tfjs-backend-wasm/src/cc/backend.h"

namespace tfjs {
namespace wasm {
// We use C-style API to interface with Javascript.
extern "C" {

// Logical-AND reduction over consecutive groups of `reduce_size` booleans.
//
// The input buffer of tensor `x_id` is read as `out_size` back-to-back groups
// of `reduce_size` elements, where `out_size` is the size of tensor `out_id`.
// Output element `i` is written as `true` only if every element of group `i`
// is `true`.
//
// x_id:        id of the boolean input tensor.
// reduce_size: number of input elements folded into each output element.
// out_id:      id of the boolean output tensor.
#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void All(const size_t x_id, const size_t reduce_size, const size_t out_id) {
  auto& x_info = backend::get_tensor_info(x_id);
  auto& out_info = backend::get_tensor_info_out(out_id);

  const bool* x_buf = x_info.b();

  bool* out_buf = out_info.b_write();
  const size_t out_size = out_info.size;

  // Start of the group that feeds the current output element.
  const bool* x_offset = x_buf;

  for (size_t i = 0; i < out_size; ++i) {
    // Seed with the identity of AND instead of reading the first element:
    // the loop below visits that element anyway, and an empty group
    // (reduce_size == 0) must not touch the input buffer at all.
    bool all = true;

    const bool* x_iter_end = x_offset + reduce_size;

    for (const bool* x = x_offset; x < x_iter_end; ++x) {
      if (!*x) {
        // One false element decides the result; skip the rest of the group.
        all = false;
        break;
      }
    }

    x_offset += reduce_size;

    out_buf[i] = all;
  }
}

}  // extern "C"
}  // namespace wasm
}  // namespace tfjs
62 changes: 62 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/Any.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <cstddef>

#include "tfjs-backend-wasm/src/cc/backend.h"

namespace tfjs {
namespace wasm {
// We use C-style API to interface with Javascript.
extern "C" {

// Logical-OR reduction over consecutive groups of `reduce_size` booleans.
//
// The input buffer of tensor `x_id` is read as `out_size` back-to-back groups
// of `reduce_size` elements, where `out_size` is the size of tensor `out_id`.
// Output element `i` is written as `true` if at least one element of group
// `i` is `true`.
//
// x_id:        id of the boolean input tensor.
// reduce_size: number of input elements folded into each output element.
// out_id:      id of the boolean output tensor.
#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void Any(const size_t x_id, const size_t reduce_size, const size_t out_id) {
  auto& x_info = backend::get_tensor_info(x_id);
  auto& out_info = backend::get_tensor_info_out(out_id);

  const bool* x_buf = x_info.b();

  bool* out_buf = out_info.b_write();
  const size_t out_size = out_info.size;

  // Start of the group that feeds the current output element.
  const bool* x_offset = x_buf;

  for (size_t i = 0; i < out_size; ++i) {
    // Seed with the identity of OR instead of reading the first element:
    // the loop below visits that element anyway, and an empty group
    // (reduce_size == 0) must not touch the input buffer at all.
    bool any = false;

    const bool* x_iter_end = x_offset + reduce_size;

    for (const bool* x = x_offset; x < x_iter_end; ++x) {
      if (*x) {
        // One true element decides the result; skip the rest of the group.
        any = true;
        break;
      }
    }

    x_offset += reduce_size;

    out_buf[i] = any;
  }
}

}  // extern "C"
}  // namespace wasm
}  // namespace tfjs
79 changes: 79 additions & 0 deletions tfjs-backend-wasm/src/kernels/All.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {All, AllAttrs, AllInputs, backend_util, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

import {permuteAxesAndTranspose} from './kernel_utils';

/** Handle to the compiled WASM `All` kernel; assigned by `setup`. */
let wasmAll: (xId: number, reduceSize: number, outId: number) => void;

/**
 * Binds the WASM `All` export to `wasmAll`.
 *
 * The argument-type list holds one `'number'` entry per parameter of the
 * native function (x id, reduce size, out id). A single
 * `'number, number, number'` string is not a valid Emscripten type name.
 *
 * @param backend WASM backend whose module exposes `cwrap`.
 */
function setup(backend: BackendWasm): void {
  wasmAll = backend.wasm.cwrap(
      All, null /*void*/, ['number', 'number', 'number']);
}

/**
 * `All` kernel entry point for the WASM backend.
 *
 * Asks `permuteAxesAndTranspose` for the axes to reduce (and a transposed copy
 * of `x` when one was needed), checks that those axes are inner-most, allocates
 * the output and hands the reduction to the native `wasmAll` kernel. The native
 * call is skipped when the input holds no elements.
 *
 * @param args.backend WASM backend used for id lookup, allocation and disposal.
 * @param args.inputs Kernel inputs; only `x` is read.
 * @param args.attrs Kernel attributes `axis` and `keepDims`.
 * @returns The output tensor info; when `keepDims` is set its shape is
 *     replaced by the result of `expandShapeToKeepDim`.
 */
function all(args: {backend: BackendWasm, inputs: AllInputs, attrs: AllAttrs}):
    TensorInfo {
  const backend = args.backend;
  const x = args.inputs.x;
  const axis = args.attrs.axis;
  const keepDims = args.attrs.keepDims;

  // Resolve the caller's tensor first; it is the kernel input unless a
  // transpose was required.
  const sourceId = backend.dataIdMap.get(x.dataId).id;

  const permuted = permuteAxesAndTranspose(x, axis, backend);
  const {axes, originalAxes, inputWasTransposed} = permuted;

  const reduceInput = inputWasTransposed ? permuted.transposed : x;
  const reduceInputId = inputWasTransposed ?
      backend.dataIdMap.get(permuted.transposed.dataId).id :
      sourceId;

  backend_util.assertAxesAreInnerMostDims(
      'all', axes, reduceInput.shape.length);
  const shapes =
      backend_util.computeOutAndReduceShapes(reduceInput.shape, axes);
  const outShape = shapes[0];
  const reduceSize = util.sizeFromShape(shapes[1]);

  const result = backend.makeOutput(outShape, x.dtype);
  const inputIsEmpty = util.sizeFromShape(reduceInput.shape) === 0;
  if (!inputIsEmpty) {
    wasmAll(
        reduceInputId, reduceSize, backend.dataIdMap.get(result.dataId).id);
  }

  if (inputWasTransposed) {
    // The transposed copy was only needed as kernel input.
    backend.disposeData(permuted.transposed.dataId);
  }

  if (keepDims) {
    result.shape =
        backend_util.expandShapeToKeepDim(result.shape, originalAxes);
  }

  return result;
}

/**
 * Kernel registration entry for `All` on the 'wasm' backend: `setup` is the
 * setup hook and `all` is the kernel implementation.
 */
export const allConfig: KernelConfig = {
kernelName: All,
backendName: 'wasm',
setupFunc: setup,
kernelFunc: all as {} as KernelFunc
};
79 changes: 79 additions & 0 deletions tfjs-backend-wasm/src/kernels/Any.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Any, AnyAttrs, AnyInputs, backend_util, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

import {permuteAxesAndTranspose} from './kernel_utils';

/** Handle to the compiled WASM `Any` kernel; assigned by `setup`. */
let wasmAny: (xId: number, reduceSize: number, outId: number) => void;

/**
 * Binds the WASM `Any` export to `wasmAny`.
 *
 * The argument-type list holds one `'number'` entry per parameter of the
 * native function (x id, reduce size, out id). A single
 * `'number, number, number'` string is not a valid Emscripten type name.
 *
 * @param backend WASM backend whose module exposes `cwrap`.
 */
function setup(backend: BackendWasm): void {
  wasmAny = backend.wasm.cwrap(
      Any, null /*void*/, ['number', 'number', 'number']);
}

/**
 * `Any` kernel entry point for the WASM backend.
 *
 * Asks `permuteAxesAndTranspose` for the axes to reduce (and a transposed copy
 * of `x` when one was needed), checks that those axes are inner-most, allocates
 * the output and hands the reduction to the native `wasmAny` kernel. The native
 * call is skipped when the input holds no elements.
 *
 * @param args.backend WASM backend used for id lookup, allocation and disposal.
 * @param args.inputs Kernel inputs; only `x` is read.
 * @param args.attrs Kernel attributes `axis` and `keepDims`.
 * @returns The output tensor info; when `keepDims` is set its shape is
 *     replaced by the result of `expandShapeToKeepDim`.
 */
function any(args: {backend: BackendWasm, inputs: AnyInputs, attrs: AnyAttrs}):
    TensorInfo {
  const backend = args.backend;
  const x = args.inputs.x;
  const axis = args.attrs.axis;
  const keepDims = args.attrs.keepDims;

  // Resolve the caller's tensor first; it is the kernel input unless a
  // transpose was required.
  const sourceId = backend.dataIdMap.get(x.dataId).id;

  const permuted = permuteAxesAndTranspose(x, axis, backend);
  const {axes, originalAxes, inputWasTransposed} = permuted;

  const reduceInput = inputWasTransposed ? permuted.transposed : x;
  const reduceInputId = inputWasTransposed ?
      backend.dataIdMap.get(permuted.transposed.dataId).id :
      sourceId;

  backend_util.assertAxesAreInnerMostDims(
      'any', axes, reduceInput.shape.length);
  const shapes =
      backend_util.computeOutAndReduceShapes(reduceInput.shape, axes);
  const outShape = shapes[0];
  const reduceSize = util.sizeFromShape(shapes[1]);

  const result = backend.makeOutput(outShape, x.dtype);
  const inputIsEmpty = util.sizeFromShape(reduceInput.shape) === 0;
  if (!inputIsEmpty) {
    wasmAny(
        reduceInputId, reduceSize, backend.dataIdMap.get(result.dataId).id);
  }

  if (inputWasTransposed) {
    // The transposed copy was only needed as kernel input.
    backend.disposeData(permuted.transposed.dataId);
  }

  if (keepDims) {
    result.shape =
        backend_util.expandShapeToKeepDim(result.shape, originalAxes);
  }

  return result;
}

/**
 * Kernel registration entry for `Any` on the 'wasm' backend: `setup` is the
 * setup hook and `any` is the kernel implementation.
 */
export const anyConfig: KernelConfig = {
kernelName: Any,
backendName: 'wasm',
setupFunc: setup,
kernelFunc: any as {} as KernelFunc
};
4 changes: 4 additions & 0 deletions tfjs-backend-wasm/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import {fusedMatMulConfig} from './kernels/_FusedMatMul';
import {absConfig} from './kernels/Abs';
import {addConfig} from './kernels/Add';
import {addNConfig} from './kernels/AddN';
import {allConfig} from './kernels/All';
import {anyConfig} from './kernels/Any';
import {argMaxConfig} from './kernels/ArgMax';
import {avgPoolConfig} from './kernels/AvgPool';
import {batchMatMulConfig} from './kernels/BatchMatMul';
Expand Down Expand Up @@ -112,6 +114,8 @@ const kernelConfigs: KernelConfig[] = [
absConfig,
addConfig,
addNConfig,
allConfig,
anyConfig,
argMaxConfig,
avgPoolConfig,
batchMatMulConfig,
Expand Down
12 changes: 12 additions & 0 deletions tfjs-backend-wasm/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,18 @@ const TEST_FILTERS: TestFilter[] = [
},
{include: 'step kernel'},
{include: 'ceil'},
{
startsWith: 'all',
excludes: [
'ignores NaNs' // Doesn't yet ignore NaN
]
},
{
startsWith: 'any',
excludes: [
'ignores NaNs' // Doesn't yet ignore NaN
]
},
];

const customInclude = (testName: string) => {
Expand Down

0 comments on commit e1420f6

Please sign in to comment.