diff --git a/.gitignore b/.gitignore index b379f823bb..80e81a9384 100644 --- a/.gitignore +++ b/.gitignore @@ -62,6 +62,7 @@ tfjs-backend-wasm/wasm-out/*.wasm yalc.lock yarn-error.log cloudbuild_generated.yml +wasm-dist/ # User-specific .bazelrc .bazelrc.user diff --git a/tfjs-backend-cpu/package.json b/tfjs-backend-cpu/package.json index 6284dc7c03..e0fb6bf1c9 100644 --- a/tfjs-backend-cpu/package.json +++ b/tfjs-backend-cpu/package.json @@ -24,7 +24,7 @@ "@rollup/plugin-commonjs": "^11.0.2", "@rollup/plugin-node-resolve": "^7.1.1", "@rollup/plugin-typescript": "^3.0.0", - "@tensorflow/tfjs-core": "link:../tfjs-core", + "@tensorflow/tfjs-core": "link:../link-package-core/node_modules/@tensorflow/tfjs-core", "@types/jasmine": "~3.0.0", "clang-format": "~1.2.4", "jasmine": "~3.1.0", @@ -51,10 +51,8 @@ "build": "bazel build :tfjs-backend-cpu_pkg", "bundle": "bazel build :tfjs-backend-cpu_pkg", "bundle-ci": "yarn bundle", - "build-core": "cd ../tfjs-core && yarn && yarn build", - "build-core-ci": "cd ../tfjs-core && yarn && yarn build-ci", - "build-deps": "yarn build-core && yarn build", - "build-deps-ci": "yarn build-core-ci && yarn build-ci", + "build-link-package-core": "cd ../link-package-core && yarn build", + "build-deps": "yarn build-link-package-core", "build-npm": "bazel build :tfjs-backend-cpu_pkg", "link-local": "yalc link", "publish-local": "rimraf dist/ && yarn build && rollup -c && yalc push", @@ -71,7 +69,7 @@ "seedrandom": "2.4.3" }, "peerDependencies": { - "@tensorflow/tfjs-core": "link:../dist/bin/tfjs-core/tfjs-core_pkg" + "@tensorflow/tfjs-core": "link:../link-package-core/node_modules/@tensorflow/tfjs-core" }, "browser": { "util": false, diff --git a/tfjs-backend-cpu/src/kernels/GatherV2.ts b/tfjs-backend-cpu/src/kernels/GatherV2.ts index e4bf765288..27a691f63b 100644 --- a/tfjs-backend-cpu/src/kernels/GatherV2.ts +++ b/tfjs-backend-cpu/src/kernels/GatherV2.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {backend_util, GatherV2, GatherV2Attrs, GatherV2Inputs, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {backend_util, GatherV2, GatherV2Attrs, GatherV2Inputs, KernelConfig, KernelFunc, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core'; import {MathBackendCPU} from '../backend_cpu'; import {assertNotComplex} from '../cpu_util'; @@ -33,6 +33,18 @@ export function gatherV2(args: { assertNotComplex([x, indices], 'gatherV2'); + // Throw error when any index is out of bound. + const parsedAxis = util.parseAxisParam(axis, x.shape)[0]; + const indicesVals = backend.data.get(indices.dataId).values as TypedArray; + const axisDim = x.shape[parsedAxis]; + for (let i = 0; i < indicesVals.length; ++i) { + const index = indicesVals[i]; + util.assert( + index <= axisDim - 1 && index >= 0, + () => + `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`); + } + let $batchDims = batchDims; if (batchDims == null) { @@ -41,7 +53,6 @@ export function gatherV2(args: { const indicesSize = util.sizeFromShape(indices.shape); - const parsedAxis = util.parseAxisParam(axis, x.shape)[0]; const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo( x, indices, parsedAxis, $batchDims); diff --git a/tfjs-backend-cpu/yarn.lock b/tfjs-backend-cpu/yarn.lock index 2780884c4c..efe81a56ff 100644 --- a/tfjs-backend-cpu/yarn.lock +++ b/tfjs-backend-cpu/yarn.lock @@ -263,7 +263,7 @@ estree-walker "^1.0.1" picomatch "^2.2.2" -"@tensorflow/tfjs-core@link:../tfjs-core": +"@tensorflow/tfjs-core@link:../link-package-core/node_modules/@tensorflow/tfjs-core": version "0.0.0" uid "" diff --git a/tfjs-backend-wasm/src/kernels/GatherV2.ts b/tfjs-backend-wasm/src/kernels/GatherV2.ts index d66abbdc29..9965261e1c 100644 --- a/tfjs-backend-wasm/src/kernels/GatherV2.ts +++ b/tfjs-backend-wasm/src/kernels/GatherV2.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {backend_util, GatherV2, GatherV2Attrs, GatherV2Inputs, KernelConfig, KernelFunc, Tensor, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {backend_util, GatherV2, GatherV2Attrs, GatherV2Inputs, KernelConfig, KernelFunc, Tensor, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; @@ -47,7 +47,18 @@ function gatherV2( const {x, indices} = inputs; const {axis, batchDims} = attrs; + // Throw error when any index is out of bound. const parsedAxis = util.parseAxisParam(axis, x.shape)[0]; + const indicesVals = backend.readSync(indices.dataId) as TypedArray; + const axisDim = x.shape[parsedAxis]; + for (let i = 0; i < indicesVals.length; ++i) { + const index = indicesVals[i]; + util.assert( + index <= axisDim - 1 && index >= 0, + () => + `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`); + } + const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo( x as Tensor, indices as Tensor, parsedAxis, batchDims); diff --git a/tfjs-backend-wasm/yarn.lock b/tfjs-backend-wasm/yarn.lock index fc95ffe125..43e48d07f2 100644 --- a/tfjs-backend-wasm/yarn.lock +++ b/tfjs-backend-wasm/yarn.lock @@ -897,9 +897,11 @@ "@tensorflow/tfjs-backend-cpu@link:../link-package/node_modules/@tensorflow/tfjs-backend-cpu": version "0.0.0" + uid "" "@tensorflow/tfjs-core@link:../link-package/node_modules/@tensorflow/tfjs-core": version "0.0.0" + uid "" "@types/component-emitter@^1.2.10": version "1.2.10" @@ -936,11 +938,21 @@ resolved "https://registry.yarnpkg.com/@types/jasmine/-/jasmine-2.8.17.tgz#65fa3be377126253f6c7988b365dfc78d62d536e" integrity sha512-lXmY2lBjE38ASvP7ah38yZwXCdc7DTCKhHqx4J3WGNiVzp134U0BD9VKdL5x9q9AAfhnpJeQr4owL6ZOXhOpfA== +"@types/long@^4.0.1": + version "4.0.1" + resolved "https://registry.yarnpkg.com/@types/long/-/long-4.0.1.tgz#459c65fa1867dafe6a8f322c4c51695663cc55e9" + integrity sha512-5tXH6Bx/kNGd3MgffdmP4dy2Z+G4eaXw0SE81Tq3BNadtnMR5/ySMzX4SLEzHJzSmPNn4HIdpQsBvXMUykr58w== + "@types/node@*", "@types/node@>=10.0.0": version "14.14.37" resolved "https://registry.yarnpkg.com/@types/node/-/node-14.14.37.tgz#a3dd8da4eb84a996c36e331df98d82abd76b516e" integrity sha512-XYmBiy+ohOR4Lh5jE379fV2IU+6Jn4g5qASinhitfyO71b/sCo6MKsMLF5tc7Zf2CE8hViVQyYSobJNke8OvUw== +"@types/offscreencanvas@~2019.3.0": + version "2019.3.0" + resolved "https://registry.yarnpkg.com/@types/offscreencanvas/-/offscreencanvas-2019.3.0.tgz#3336428ec7e9180cf4566dfea5da04eb586a6553" + integrity sha512-esIJx9bQg+QYF0ra8GnvfianIY8qWB0GBx54PK5Eps6m+xTj86KLavHv6qDhzKcu5UUOgNfJ2pWaIIV7TRUd9Q== + "@types/resolve@0.0.8": version "0.0.8" resolved "https://registry.yarnpkg.com/@types/resolve/-/resolve-0.0.8.tgz#f26074d238e02659e323ce1a13d041eee280e194" @@ -948,6 +960,16 @@ dependencies: "@types/node" "*" +"@types/seedrandom@2.4.27": + version "2.4.27" + resolved "https://registry.yarnpkg.com/@types/seedrandom/-/seedrandom-2.4.27.tgz#9db563937dd86915f69092bc43259d2f48578e41" + integrity sha1-nbVjk33YaRX2kJK8QyWdL0hXjkE= + +"@types/webgl-ext@0.0.30": + version "0.0.30" + resolved "https://registry.yarnpkg.com/@types/webgl-ext/-/webgl-ext-0.0.30.tgz#0ce498c16a41a23d15289e0b844d945b25f0fb9d" + integrity sha512-LKVgNmBxN0BbljJrVUwkxwRYqzsAEPcZOe6S2T6ZaBDIrFp0qu4FNlpc5sM1tGbXUYFgdVQIoeLk1Y1UoblyEg== + accepts@~1.3.4: version "1.3.7" resolved "https://registry.yarnpkg.com/accepts/-/accepts-1.3.7.tgz#531bc726517a3b2b41f850021c6cc15eaab507cd" @@ -2539,6 +2561,11 @@ log4js@^6.2.1, log4js@^6.3.0: rfdc "^1.1.4" streamroller "^2.2.4" +long@4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/long/-/long-4.0.0.tgz#9a7b71cfb7d361a194ea555241c92f7468d5bf28" + integrity sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA== + magic-string@^0.25.2, magic-string@^0.25.7: version "0.25.7" resolved "https://registry.yarnpkg.com/magic-string/-/magic-string-0.25.7.tgz#3f497d6fd34c669c6798dcb821f2ef31f5445051" @@ -2661,6 +2688,11 @@ negotiator@0.6.2: resolved "https://registry.yarnpkg.com/negotiator/-/negotiator-0.6.2.tgz#feacf7ccf525a77ae9634436a64883ffeca346fb" integrity sha512-hZXc7K2e+PgeI1eDBe/10Ard4ekbfrrqG8Ep+8Jmf4JID2bNg7NvCPOZN+kfF574pFQI7mum2AUqDidoKqcTOw== +node-fetch@~2.6.1: + version "2.6.1" + resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.6.1.tgz#045bd323631f76ed2e2b55573394416b639a0052" + integrity sha512-V4aYg89jEoVRxRb2fJdAg8FHvI7cEyYdVAh94HH0UIK8oJxUfkjlDQN9RbMx+bEjP7+ggMiFRprSti032Oipxw== + node-releases@^1.1.71: version "1.1.72" resolved "https://registry.yarnpkg.com/node-releases/-/node-releases-1.1.72.tgz#14802ab6b1039a79a0c7d662b610a5bbd76eacbe" @@ -3107,6 +3139,11 @@ safe-buffer@~5.1.1: resolved "https://registry.yarnpkg.com/safer-buffer/-/safer-buffer-2.1.2.tgz#44fa161b0187b9549dd84bb91802f9bd8385cd6a" integrity sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg== +seedrandom@2.4.3: + version "2.4.3" + resolved "https://registry.yarnpkg.com/seedrandom/-/seedrandom-2.4.3.tgz#2438504dad33917314bff18ac4d794f16d6aaecc" + integrity sha1-JDhQTa0zkXMUv/GKxNeU8W1qrsw= + semver@7.0.0: version "7.0.0" resolved "https://registry.yarnpkg.com/semver/-/semver-7.0.0.tgz#5f3ca35761e47e05b206c6daff2cf814f0316b8e" diff --git a/tfjs-backend-webgl/src/kernels/GatherV2.ts b/tfjs-backend-webgl/src/kernels/GatherV2.ts index e815d0f0d7..e07aef60a8 100644 --- a/tfjs-backend-webgl/src/kernels/GatherV2.ts +++ b/tfjs-backend-webgl/src/kernels/GatherV2.ts @@ -32,7 +32,18 @@ export function gatherV2(args: { const {x, indices} = inputs; const {axis, batchDims} = attrs; + // Throw error when any index is out of bound. const parsedAxis = util.parseAxisParam(axis, x.shape)[0]; + const indicesVals = backend.readSync(indices.dataId) as TypedArray; + const axisDim = x.shape[parsedAxis]; + for (let i = 0; i < indicesVals.length; ++i) { + const index = indicesVals[i]; + util.assert( + index <= axisDim - 1 && index >= 0, + () => + `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`); + } + const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo( x, indices, parsedAxis, batchDims); diff --git a/tfjs-backend-webgpu/src/kernels/GatherV2.ts b/tfjs-backend-webgpu/src/kernels/GatherV2.ts index 3bcf724e68..0c308e82d2 100644 --- a/tfjs-backend-webgpu/src/kernels/GatherV2.ts +++ b/tfjs-backend-webgpu/src/kernels/GatherV2.ts @@ -31,7 +31,18 @@ export function gatherV2( const {x, indices} = inputs; const {axis, batchDims} = attrs; + // Throw error when any index is out of bound. const parsedAxis = util.parseAxisParam(axis, x.shape)[0]; + const indicesVals = backend.readSync(indices.dataId) as TypedArray; + const axisDim = x.shape[parsedAxis]; + for (let i = 0; i < indicesVals.length; ++i) { + const index = indicesVals[i]; + util.assert( + index <= axisDim - 1 && index >= 0, + () => + `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`); + } + const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo( x, indices, parsedAxis, batchDims); diff --git a/tfjs-core/src/ops/gather_test.ts b/tfjs-core/src/ops/gather_test.ts index 8fb2f67fa6..5b2a627986 100644 --- a/tfjs-core/src/ops/gather_test.ts +++ b/tfjs-core/src/ops/gather_test.ts @@ -220,6 +220,14 @@ describeWithFlags('gather', ALL_ENVS, () => { .toThrowError(/Argument 'indices' passed to 'gather' must be a Tensor/); }); + it('throws when index is out of bound', async () => { + const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); + expect(() => tf.gather(t, tf.tensor1d([100], 'int32'))) + .toThrowError(/GatherV2: the index value 100 is not in \[0, 1\]/); + expect(() => tf.gather(t, tf.tensor1d([-1], 'int32'))) + .toThrowError(/GatherV2: the index value -1 is not in \[0, 1\]/); + }); + it('accepts a tensor-like object', async () => { const res = tf.gather([1, 2, 3], [0, 2, 0, 1], 0); expect(res.shape).toEqual([4]); diff --git a/tfjs-node/src/kernels/GatherV2.ts b/tfjs-node/src/kernels/GatherV2.ts index 7aea08f91d..e5f7450cd8 100644 --- a/tfjs-node/src/kernels/GatherV2.ts +++ b/tfjs-node/src/kernels/GatherV2.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {backend_util, GatherV2, GatherV2Attrs, GatherV2Inputs, KernelConfig, scalar, Tensor} from '@tensorflow/tfjs'; +import {backend_util, GatherV2, GatherV2Attrs, GatherV2Inputs, KernelConfig, scalar, Tensor, TypedArray, util} from '@tensorflow/tfjs'; import {createTensorsTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend'; @@ -27,6 +27,17 @@ export const gatherV2Config: KernelConfig = { const backend = args.backend as NodeJSKernelBackend; const {axis, batchDims} = args.attrs as {} as GatherV2Attrs; + // Throw error when any index is out of bound. + const indicesVals = backend.readSync(indices.dataId) as TypedArray; + const axisDim = x.shape[axis]; + for (let i = 0; i < indicesVals.length; ++i) { + const index = indicesVals[i]; + util.assert( + index <= axisDim - 1 && index >= 0, + () => `GatherV2: the index value ${index} is not in [0, ${ + axisDim - 1}]`); + } + // validate the inputs backend_util.segment_util.collectGatherOpShapeInfo( x as Tensor, indices as Tensor, axis, batchDims); diff --git a/tfjs-node/yarn.lock b/tfjs-node/yarn.lock index d05bbd9a1c..c3835a04d8 100644 --- a/tfjs-node/yarn.lock +++ b/tfjs-node/yarn.lock @@ -235,7 +235,7 @@ version "0.0.0" uid "" -"@tensorflow/tfjs-converter@link:../tfjs-converter": +"@tensorflow/tfjs-converter@link:../link-package/node_modules/@tensorflow/tfjs-converter": version "0.0.0" uid "" @@ -348,10 +348,10 @@ abbrev@1: resolved "https://registry.yarnpkg.com/abbrev/-/abbrev-1.1.1.tgz#f8f2c887ad10bf67f634f005b6987fed3179aac8" integrity sha512-nne9/IiQ/hzIhY6pdDnbBtz7DjPTKrY00P/zvPSm5pOFkl6xuGrGnXn/VtTNNfNtAfZ9/1RtehkszU9qcTii0Q== -adm-zip@^0.4.11: - version "0.4.16" - resolved "https://registry.yarnpkg.com/adm-zip/-/adm-zip-0.4.16.tgz#cf4c508fdffab02c269cbc7f471a875f05570365" - integrity sha512-TFi4HBKSGfIKsK5YCkKaaFG2m4PEDyViZmEwof3MTIgzimHLto6muaHVpbrljdIvIrFZzEq/p4nafOeLcYegrg== +adm-zip@^0.5.2: + version "0.5.5" + resolved "https://registry.yarnpkg.com/adm-zip/-/adm-zip-0.5.5.tgz#b6549dbea741e4050309f1bb4d47c47397ce2c4f" + integrity sha512-IWwXKnCbirdbyXSfUDvCCrmYrOHANRZcc8NcRrvTlIApdl7PwE9oGcsYvNeJPAVY1M+70b4PxXGKIf8AEuiQ6w== agent-base@6: version "6.0.2"