[go: nahoru, domu]

Skip to content

Commit

Permalink
Throw errors when any index is out of bounds in GatherV2 (tensorflow#5439)

* Throw errors when an index is out of bounds in GatherV2

* fix

* Add checks for various backends

* fix

* fix
  • Loading branch information
jinjingforever committed Sep 7, 2021
1 parent 261ca57 commit a007f43
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 16 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ tfjs-backend-wasm/wasm-out/*.wasm
yalc.lock
yarn-error.log
cloudbuild_generated.yml
wasm-dist/

# User-specific .bazelrc
.bazelrc.user
10 changes: 4 additions & 6 deletions tfjs-backend-cpu/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"@rollup/plugin-commonjs": "^11.0.2",
"@rollup/plugin-node-resolve": "^7.1.1",
"@rollup/plugin-typescript": "^3.0.0",
"@tensorflow/tfjs-core": "link:../tfjs-core",
"@tensorflow/tfjs-core": "link:../link-package-core/node_modules/@tensorflow/tfjs-core",
"@types/jasmine": "~3.0.0",
"clang-format": "~1.2.4",
"jasmine": "~3.1.0",
Expand All @@ -51,10 +51,8 @@
"build": "bazel build :tfjs-backend-cpu_pkg",
"bundle": "bazel build :tfjs-backend-cpu_pkg",
"bundle-ci": "yarn bundle",
"build-core": "cd ../tfjs-core && yarn && yarn build",
"build-core-ci": "cd ../tfjs-core && yarn && yarn build-ci",
"build-deps": "yarn build-core && yarn build",
"build-deps-ci": "yarn build-core-ci && yarn build-ci",
"build-link-package-core": "cd ../link-package-core && yarn build",
"build-deps": "yarn build-link-package-core",
"build-npm": "bazel build :tfjs-backend-cpu_pkg",
"link-local": "yalc link",
"publish-local": "rimraf dist/ && yarn build && rollup -c && yalc push",
Expand All @@ -71,7 +69,7 @@
"seedrandom": "2.4.3"
},
"peerDependencies": {
"@tensorflow/tfjs-core": "link:../dist/bin/tfjs-core/tfjs-core_pkg"
"@tensorflow/tfjs-core": "link:../link-package-core/node_modules/@tensorflow/tfjs-core"
},
"browser": {
"util": false,
Expand Down
15 changes: 13 additions & 2 deletions tfjs-backend-cpu/src/kernels/GatherV2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {backend_util, GatherV2, GatherV2Attrs, GatherV2Inputs, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core';
import {backend_util, GatherV2, GatherV2Attrs, GatherV2Inputs, KernelConfig, KernelFunc, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core';

import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';
Expand All @@ -33,6 +33,18 @@ export function gatherV2(args: {

assertNotComplex([x, indices], 'gatherV2');

// Throw error when any index is out of bound.
const parsedAxis = util.parseAxisParam(axis, x.shape)[0];
const indicesVals = backend.data.get(indices.dataId).values as TypedArray;
const axisDim = x.shape[parsedAxis];
for (let i = 0; i < indicesVals.length; ++i) {
const index = indicesVals[i];
util.assert(
index <= axisDim - 1 && index >= 0,
() =>
`GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);
}

let $batchDims = batchDims;

if (batchDims == null) {
Expand All @@ -41,7 +53,6 @@ export function gatherV2(args: {

const indicesSize = util.sizeFromShape(indices.shape);

const parsedAxis = util.parseAxisParam(axis, x.shape)[0];
const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo(
x, indices, parsedAxis, $batchDims);

Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-cpu/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@
estree-walker "^1.0.1"
picomatch "^2.2.2"

"@tensorflow/tfjs-core@link:../tfjs-core":
"@tensorflow/tfjs-core@link:../link-package-core/node_modules/@tensorflow/tfjs-core":
version "0.0.0"
uid ""

Expand Down
13 changes: 12 additions & 1 deletion tfjs-backend-wasm/src/kernels/GatherV2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {backend_util, GatherV2, GatherV2Attrs, GatherV2Inputs, KernelConfig, KernelFunc, Tensor, TensorInfo, util} from '@tensorflow/tfjs-core';
import {backend_util, GatherV2, GatherV2Attrs, GatherV2Inputs, KernelConfig, KernelFunc, Tensor, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

Expand Down Expand Up @@ -47,7 +47,18 @@ function gatherV2(
const {x, indices} = inputs;
const {axis, batchDims} = attrs;

// Throw error when any index is out of bound.
const parsedAxis = util.parseAxisParam(axis, x.shape)[0];
const indicesVals = backend.readSync(indices.dataId) as TypedArray;
const axisDim = x.shape[parsedAxis];
for (let i = 0; i < indicesVals.length; ++i) {
const index = indicesVals[i];
util.assert(
index <= axisDim - 1 && index >= 0,
() =>
`GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);
}

const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo(
x as Tensor, indices as Tensor, parsedAxis, batchDims);

Expand Down
37 changes: 37 additions & 0 deletions tfjs-backend-wasm/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -897,9 +897,11 @@

"@tensorflow/tfjs-backend-cpu@link:../link-package/node_modules/@tensorflow/tfjs-backend-cpu":
version "0.0.0"
uid ""

"@tensorflow/tfjs-core@link:../link-package/node_modules/@tensorflow/tfjs-core":
version "0.0.0"
uid ""

"@types/component-emitter@^1.2.10":
version "1.2.10"
Expand Down Expand Up @@ -936,18 +938,38 @@
resolved "https://registry.yarnpkg.com/@types/jasmine/-/jasmine-2.8.17.tgz#65fa3be377126253f6c7988b365dfc78d62d536e"
integrity sha512-lXmY2lBjE38ASvP7ah38yZwXCdc7DTCKhHqx4J3WGNiVzp134U0BD9VKdL5x9q9AAfhnpJeQr4owL6ZOXhOpfA==

"@types/long@^4.0.1":
version "4.0.1"
resolved "https://registry.yarnpkg.com/@types/long/-/long-4.0.1.tgz#459c65fa1867dafe6a8f322c4c51695663cc55e9"
integrity sha512-5tXH6Bx/kNGd3MgffdmP4dy2Z+G4eaXw0SE81Tq3BNadtnMR5/ySMzX4SLEzHJzSmPNn4HIdpQsBvXMUykr58w==

"@types/node@*", "@types/node@>=10.0.0":
version "14.14.37"
resolved "https://registry.yarnpkg.com/@types/node/-/node-14.14.37.tgz#a3dd8da4eb84a996c36e331df98d82abd76b516e"
integrity sha512-XYmBiy+ohOR4Lh5jE379fV2IU+6Jn4g5qASinhitfyO71b/sCo6MKsMLF5tc7Zf2CE8hViVQyYSobJNke8OvUw==

"@types/offscreencanvas@~2019.3.0":
version "2019.3.0"
resolved "https://registry.yarnpkg.com/@types/offscreencanvas/-/offscreencanvas-2019.3.0.tgz#3336428ec7e9180cf4566dfea5da04eb586a6553"
integrity sha512-esIJx9bQg+QYF0ra8GnvfianIY8qWB0GBx54PK5Eps6m+xTj86KLavHv6qDhzKcu5UUOgNfJ2pWaIIV7TRUd9Q==

"@types/resolve@0.0.8":
version "0.0.8"
resolved "https://registry.yarnpkg.com/@types/resolve/-/resolve-0.0.8.tgz#f26074d238e02659e323ce1a13d041eee280e194"
integrity sha512-auApPaJf3NPfe18hSoJkp8EbZzer2ISk7o8mCC3M9he/a04+gbMF97NkpD2S8riMGvm4BMRI59/SZQSaLTKpsQ==
dependencies:
"@types/node" "*"

"@types/seedrandom@2.4.27":
version "2.4.27"
resolved "https://registry.yarnpkg.com/@types/seedrandom/-/seedrandom-2.4.27.tgz#9db563937dd86915f69092bc43259d2f48578e41"
integrity sha1-nbVjk33YaRX2kJK8QyWdL0hXjkE=

"@types/webgl-ext@0.0.30":
version "0.0.30"
resolved "https://registry.yarnpkg.com/@types/webgl-ext/-/webgl-ext-0.0.30.tgz#0ce498c16a41a23d15289e0b844d945b25f0fb9d"
integrity sha512-LKVgNmBxN0BbljJrVUwkxwRYqzsAEPcZOe6S2T6ZaBDIrFp0qu4FNlpc5sM1tGbXUYFgdVQIoeLk1Y1UoblyEg==

accepts@~1.3.4:
version "1.3.7"
resolved "https://registry.yarnpkg.com/accepts/-/accepts-1.3.7.tgz#531bc726517a3b2b41f850021c6cc15eaab507cd"
Expand Down Expand Up @@ -2539,6 +2561,11 @@ log4js@^6.2.1, log4js@^6.3.0:
rfdc "^1.1.4"
streamroller "^2.2.4"

long@4.0.0:
version "4.0.0"
resolved "https://registry.yarnpkg.com/long/-/long-4.0.0.tgz#9a7b71cfb7d361a194ea555241c92f7468d5bf28"
integrity sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==

magic-string@^0.25.2, magic-string@^0.25.7:
version "0.25.7"
resolved "https://registry.yarnpkg.com/magic-string/-/magic-string-0.25.7.tgz#3f497d6fd34c669c6798dcb821f2ef31f5445051"
Expand Down Expand Up @@ -2661,6 +2688,11 @@ negotiator@0.6.2:
resolved "https://registry.yarnpkg.com/negotiator/-/negotiator-0.6.2.tgz#feacf7ccf525a77ae9634436a64883ffeca346fb"
integrity sha512-hZXc7K2e+PgeI1eDBe/10Ard4ekbfrrqG8Ep+8Jmf4JID2bNg7NvCPOZN+kfF574pFQI7mum2AUqDidoKqcTOw==

node-fetch@~2.6.1:
version "2.6.1"
resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.6.1.tgz#045bd323631f76ed2e2b55573394416b639a0052"
integrity sha512-V4aYg89jEoVRxRb2fJdAg8FHvI7cEyYdVAh94HH0UIK8oJxUfkjlDQN9RbMx+bEjP7+ggMiFRprSti032Oipxw==

node-releases@^1.1.71:
version "1.1.72"
resolved "https://registry.yarnpkg.com/node-releases/-/node-releases-1.1.72.tgz#14802ab6b1039a79a0c7d662b610a5bbd76eacbe"
Expand Down Expand Up @@ -3107,6 +3139,11 @@ safe-buffer@~5.1.1:
resolved "https://registry.yarnpkg.com/safer-buffer/-/safer-buffer-2.1.2.tgz#44fa161b0187b9549dd84bb91802f9bd8385cd6a"
integrity sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==

seedrandom@2.4.3:
version "2.4.3"
resolved "https://registry.yarnpkg.com/seedrandom/-/seedrandom-2.4.3.tgz#2438504dad33917314bff18ac4d794f16d6aaecc"
integrity sha1-JDhQTa0zkXMUv/GKxNeU8W1qrsw=

semver@7.0.0:
version "7.0.0"
resolved "https://registry.yarnpkg.com/semver/-/semver-7.0.0.tgz#5f3ca35761e47e05b206c6daff2cf814f0316b8e"
Expand Down
11 changes: 11 additions & 0 deletions tfjs-backend-webgl/src/kernels/GatherV2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,18 @@ export function gatherV2(args: {
const {x, indices} = inputs;
const {axis, batchDims} = attrs;

// Throw error when any index is out of bound.
const parsedAxis = util.parseAxisParam(axis, x.shape)[0];
const indicesVals = backend.readSync(indices.dataId) as TypedArray;
const axisDim = x.shape[parsedAxis];
for (let i = 0; i < indicesVals.length; ++i) {
const index = indicesVals[i];
util.assert(
index <= axisDim - 1 && index >= 0,
() =>
`GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);
}

const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo(
x, indices, parsedAxis, batchDims);

Expand Down
11 changes: 11 additions & 0 deletions tfjs-backend-webgpu/src/kernels/GatherV2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,18 @@ export function gatherV2(
const {x, indices} = inputs;
const {axis, batchDims} = attrs;

// Throw error when any index is out of bound.
const parsedAxis = util.parseAxisParam(axis, x.shape)[0];
const indicesVals = backend.readSync(indices.dataId) as TypedArray;
const axisDim = x.shape[parsedAxis];
for (let i = 0; i < indicesVals.length; ++i) {
const index = indicesVals[i];
util.assert(
index <= axisDim - 1 && index >= 0,
() =>
`GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);
}

const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo(
x, indices, parsedAxis, batchDims);

Expand Down
8 changes: 8 additions & 0 deletions tfjs-core/src/ops/gather_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ describeWithFlags('gather', ALL_ENVS, () => {
.toThrowError(/Argument 'indices' passed to 'gather' must be a Tensor/);
});

it('throws when index is out of bound', async () => {
  // Both dimensions of `t` have size 2, so along either axis the only valid
  // indices are 0 and 1 and the expected error range is "[0, 1]".
  const t = tf.tensor2d([1, 11, 2, 22], [2, 2]);

  // Far above the upper bound.
  expect(() => tf.gather(t, tf.tensor1d([100], 'int32')))
      .toThrowError(/GatherV2: the index value 100 is not in \[0, 1\]/);

  // Exactly one past the last valid index: guards the off-by-one boundary of
  // the `index <= axisDim - 1` check in the backend kernels.
  expect(() => tf.gather(t, tf.tensor1d([2], 'int32')))
      .toThrowError(/GatherV2: the index value 2 is not in \[0, 1\]/);

  // Below the lower bound.
  expect(() => tf.gather(t, tf.tensor1d([-1], 'int32')))
      .toThrowError(/GatherV2: the index value -1 is not in \[0, 1\]/);
});

it('accepts a tensor-like object', async () => {
const res = tf.gather([1, 2, 3], [0, 2, 0, 1], 0);
expect(res.shape).toEqual([4]);
Expand Down
13 changes: 12 additions & 1 deletion tfjs-node/src/kernels/GatherV2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {backend_util, GatherV2, GatherV2Attrs, GatherV2Inputs, KernelConfig, scalar, Tensor} from '@tensorflow/tfjs';
import {backend_util, GatherV2, GatherV2Attrs, GatherV2Inputs, KernelConfig, scalar, Tensor, TypedArray, util} from '@tensorflow/tfjs';

import {createTensorsTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend';

Expand All @@ -27,6 +27,17 @@ export const gatherV2Config: KernelConfig = {
const backend = args.backend as NodeJSKernelBackend;
const {axis, batchDims} = args.attrs as {} as GatherV2Attrs;

// Throw error when any index is out of bound.
const indicesVals = backend.readSync(indices.dataId) as TypedArray;
const axisDim = x.shape[axis];
for (let i = 0; i < indicesVals.length; ++i) {
const index = indicesVals[i];
util.assert(
index <= axisDim - 1 && index >= 0,
() => `GatherV2: the index value ${index} is not in [0, ${
axisDim - 1}]`);
}

// validate the inputs
backend_util.segment_util.collectGatherOpShapeInfo(
x as Tensor, indices as Tensor, axis, batchDims);
Expand Down
10 changes: 5 additions & 5 deletions tfjs-node/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@
version "0.0.0"
uid ""

"@tensorflow/tfjs-converter@link:../tfjs-converter":
"@tensorflow/tfjs-converter@link:../link-package/node_modules/@tensorflow/tfjs-converter":
version "0.0.0"
uid ""

Expand Down Expand Up @@ -348,10 +348,10 @@ abbrev@1:
resolved "https://registry.yarnpkg.com/abbrev/-/abbrev-1.1.1.tgz#f8f2c887ad10bf67f634f005b6987fed3179aac8"
integrity sha512-nne9/IiQ/hzIhY6pdDnbBtz7DjPTKrY00P/zvPSm5pOFkl6xuGrGnXn/VtTNNfNtAfZ9/1RtehkszU9qcTii0Q==

adm-zip@^0.4.11:
version "0.4.16"
resolved "https://registry.yarnpkg.com/adm-zip/-/adm-zip-0.4.16.tgz#cf4c508fdffab02c269cbc7f471a875f05570365"
integrity sha512-TFi4HBKSGfIKsK5YCkKaaFG2m4PEDyViZmEwof3MTIgzimHLto6muaHVpbrljdIvIrFZzEq/p4nafOeLcYegrg==
adm-zip@^0.5.2:
version "0.5.5"
resolved "https://registry.yarnpkg.com/adm-zip/-/adm-zip-0.5.5.tgz#b6549dbea741e4050309f1bb4d47c47397ce2c4f"
integrity sha512-IWwXKnCbirdbyXSfUDvCCrmYrOHANRZcc8NcRrvTlIApdl7PwE9oGcsYvNeJPAVY1M+70b4PxXGKIf8AEuiQ6w==

agent-base@6:
version "6.0.2"
Expand Down

0 comments on commit a007f43

Please sign in to comment.