[go: nahoru, domu]

Skip to content

Commit

Permalink
Engine/backend refCount refactoring (tensorflow#4628)
Browse files Browse the repository at this point in the history
DEV
  • Loading branch information
pyu10055 committed Feb 5, 2021
1 parent 9ad607d commit 0767a6e
Show file tree
Hide file tree
Showing 25 changed files with 410 additions and 205 deletions.
52 changes: 34 additions & 18 deletions tfjs-backend-cpu/src/backend_cpu.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
Expand Down Expand Up @@ -39,6 +39,10 @@ export class MathBackendCPU extends KernelBackend {

data: DataStorage<TensorData<DataType>>;
private firstUse = true;
/** Counter shared by every MathBackendCPU instance; holds the next id to issue. */
private static nextDataId = 0;

/**
 * Issue the next numeric id from the shared counter and advance it, so
 * successive calls yield 0, 1, 2, ... across all instances of this backend.
 */
private nextDataId(): number {
  const issued = MathBackendCPU.nextDataId;
  MathBackendCPU.nextDataId = issued + 1;
  return issued;
}

constructor() {
super();
Expand All @@ -63,7 +67,7 @@ export class MathBackendCPU extends KernelBackend {
'\n============================');
}
}
const dataId = {};
const dataId = {id: this.nextDataId()};

this.data.set(dataId, {values, dtype, refCount: 1});

Expand Down Expand Up @@ -93,6 +97,15 @@ export class MathBackendCPU extends KernelBackend {
return {dataId: outId, shape, dtype};
}

/**
 * Report the reference count currently recorded for `dataId`, or 0 when this
 * backend's data storage holds no entry for it.
 */
refCount(dataId: DataId): number {
  return this.data.has(dataId) ? this.data.get(dataId).refCount : 0;
}

/** Increase refCount of a `TensorData`. */
incRef(dataId: DataId): void {
const tensorData = this.data.get(dataId);
Expand All @@ -109,8 +122,8 @@ export class MathBackendCPU extends KernelBackend {

// NOTE(review): rendered from a diff without +/- markers; the signature ending
// in `dtype: DataType): void {` and the `refCount: 1` body appear to be the
// pre-change version, followed by the post-change signature and body.
/**
* Store `values` and `dtype` in this backend's data storage under `dataId`.
* `shape` is accepted but not stored here. The post-change version records the
* caller-supplied `refCount` instead of always starting the entry at 1.
*/
move(
dataId: DataId, values: backend_util.BackendValues, shape: number[],
dtype: DataType): void {
this.data.set(dataId, {values, dtype, refCount: 1});
dtype: DataType, refCount: number): void {
this.data.set(dataId, {values, dtype, refCount});
}

numDataIds(): number {
Expand Down Expand Up @@ -155,31 +168,34 @@ export class MathBackendCPU extends KernelBackend {
return engine().makeTensorFromDataId(dataId, shape, dtype, this) as T;
}

// NOTE(review): rendered from a diff without +/- markers; the one-argument
// `void` signature and the two unforced component calls below appear to be the
// pre-change lines, and the rest is the post-change version.
disposeData(dataId: DataId): void {
/**
* Drop one reference on the data stored under `dataId` and release it once no
* references remain, or immediately when `force` is set.
*
* @param dataId Key of the entry in this backend's data storage.
* @param force Optional, remove the data regardless of refCount.
* @returns true if the data was released or `dataId` is not managed by this
*     backend; false if references remain and the data was kept.
*/
disposeData(dataId: DataId, force = false): boolean {
if (this.data.has(dataId)) {
// The refCount is decremented before the check, so the entry survives only
// while other references are still outstanding.
this.data.get(dataId).refCount--;
if (!force && this.data.get(dataId).refCount > 0) {
return false;
}

const {complexTensorInfos} = this.data.get(dataId);

// A complex entry references separate data for its real and imaginary
// parts, and those are disposed along with it. The forced calls drop the
// parts regardless of their own refCount.
if (complexTensorInfos != null) {
this.disposeData(complexTensorInfos.real.dataId);
this.disposeData(complexTensorInfos.imag.dataId);
this.disposeData(complexTensorInfos.real.dataId, true);
this.disposeData(complexTensorInfos.imag.dataId, true);
}

this.data.delete(dataId);
}
return true;
}

// NOTE(review): rendered from a diff without +/- markers; the manual
// decrement-and-check body appears to be the pre-change version, replaced by
// the single `disposeData` call at the end.
/**
* Release one reference on the data behind `tensorInfo`; the backing storage
* is freed once no references remain.
*/
disposeIntermediateTensorInfo(tensorInfo: TensorInfo): void {
const dataId = tensorInfo.dataId;

if (this.data.has(dataId)) {
const tensorData = this.data.get(dataId);

tensorData.refCount--;

if (tensorData.refCount < 1) {
this.disposeData(dataId);
}
}
this.disposeData(tensorInfo.dataId);
}

async time(f: () => void): Promise<BackendTimingInfo> {
Expand Down
54 changes: 43 additions & 11 deletions tfjs-backend-wasm/src/backend_wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ interface TensorData {
memoryOffset: number;
shape: number[];
dtype: DataType;
refCount: number;
/** Only used for string tensors, storing encoded bytes. */
stringBytes?: Uint8Array[];
}
Expand All @@ -48,8 +49,8 @@ export class BackendWasm extends KernelBackend {

// NOTE(review): rendered from a diff without +/- markers; the `{}` dataId and
// the four-argument `move` call appear to be the pre-change lines.
/**
* Create a new dataId, store `values` under it via `move`, and return it.
* The post-change version starts the entry with a refCount of 1.
*
* NOTE(review): `move` also draws its own internal id from `dataIdNextNumber`,
* so the `id` placed on the returned dataId differs from the internal id
* registered with the wasm module. Confirm nothing relies on them matching.
*/
write(values: backend_util.BackendValues, shape: number[], dtype: DataType):
DataId {
const dataId = {};
this.move(dataId, values, shape, dtype);
const dataId = {id: this.dataIdNextNumber++};
this.move(dataId, values, shape, dtype, 1);
return dataId;
}

Expand All @@ -66,20 +67,21 @@ export class BackendWasm extends KernelBackend {

move(
dataId: DataId, values: backend_util.BackendValues, shape: number[],
dtype: DataType): void {
dtype: DataType, refCount: number): void {
const id = this.dataIdNextNumber++;
if (dtype === 'string') {
const stringBytes = values as Uint8Array[];
this.dataIdMap.set(
dataId, {id, stringBytes, shape, dtype, memoryOffset: null});
dataId,
{id, stringBytes, shape, dtype, memoryOffset: null, refCount});
return;
}

const size = util.sizeFromShape(shape);
const numBytes = size * util.bytesPerElement(dtype);
const memoryOffset = this.wasm._malloc(numBytes);

this.dataIdMap.set(dataId, {id, memoryOffset, shape, dtype});
this.dataIdMap.set(dataId, {id, memoryOffset, shape, dtype, refCount});

this.wasm.tfjs.registerTensor(id, size, memoryOffset);

Expand Down Expand Up @@ -108,11 +110,41 @@ export class BackendWasm extends KernelBackend {
return typedArrayFromBuffer(bytes.buffer, dtype);
}

// NOTE(review): rendered from a diff without +/- markers; the one-argument
// signature on the next line appears to be the pre-change version.
disposeData(dataId: DataId) {
/**
* Drop one reference on the data stored under `dataId` and free it once no
* references remain, or immediately when `force` is set.
*
* @param dataId Key of the entry in `dataIdMap`.
* @param force Optional, remove the data regardless of refCount.
* @returns true if the data was freed or `dataId` is not tracked by this
*     backend; false if references remain and the data was kept.
*/
disposeData(dataId: DataId, force = false): boolean {
if (this.dataIdMap.has(dataId)) {
const data = this.dataIdMap.get(dataId);
// Decrement first; the entry is kept only while other references remain.
data.refCount--;
if (!force && data.refCount > 0) {
return false;
}

// NOTE(review): in `move`, string tensors are stored with a null
// memoryOffset and return before `registerTensor`, yet both calls below
// also run for them. Confirm these are safe no-ops in that case.
this.wasm._free(data.memoryOffset);
this.wasm.tfjs.disposeData(data.id);
this.dataIdMap.delete(dataId);
}
return true;
}

/**
 * Look up the reference count tracked for `dataId` in the wasm data map;
 * an id with no registered entry reports 0.
 */
refCount(dataId: DataId): number {
  if (!this.dataIdMap.has(dataId)) {
    return 0;
  }
  const {refCount} = this.dataIdMap.get(dataId);
  return refCount;
}

// NOTE(review): rendered from a diff without +/- markers; the `_free`,
// `tfjs.disposeData` and `dataIdMap.delete` lines appear to be the removed
// body of the pre-change `disposeData`, not part of `incRef`.
/** Add one reference to the data under `dataId`; ids with no entry are ignored. */
incRef(dataId: DataId) {
const data = this.dataIdMap.get(dataId);
this.wasm._free(data.memoryOffset);
this.wasm.tfjs.disposeData(data.id);
this.dataIdMap.delete(dataId);
if (data != null) {
data.refCount++;
}
}

floatPrecision(): 32 {
Expand Down Expand Up @@ -146,9 +178,9 @@ export class BackendWasm extends KernelBackend {
if (memoryOffset == null) {
dataId = this.write(null /* values */, shape, dtype);
} else {
dataId = {};
const id = this.dataIdNextNumber++;
this.dataIdMap.set(dataId, {id, memoryOffset, shape, dtype});
dataId = {id};
this.dataIdMap.set(dataId, {id, memoryOffset, shape, dtype, refCount: 1});
const size = util.sizeFromShape(shape);
this.wasm.tfjs.registerTensor(id, size, memoryOffset);
}
Expand Down
3 changes: 3 additions & 0 deletions tfjs-backend-wasm/src/kernels/BatchMatMul.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ function batchMatMul(args: {
a3dId, aShapeBytes, a3d.shape.length, b3dId, bShapeBytes,
b3d.shape.length, transposeA, transposeB, outId);

backend.disposeData(a3d.dataId);
backend.disposeData(b3d.dataId);

out.shape = outShape;
return out;
}
Expand Down
3 changes: 3 additions & 0 deletions tfjs-backend-wasm/src/kernels/GatherV2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ function gatherV2(
xId, CppDType[x.dtype], xStridesBytes, stridesSize, indicesId,
shapeInfo.batchSize, outStridesBytes, outId);

backend.disposeData(flattenX.dataId);
backend.disposeData(flattenIndex.dataId);

// reshape
out.shape = shapeInfo.outputShape;
return out;
Expand Down
2 changes: 2 additions & 0 deletions tfjs-backend-wasm/src/kernels/Reshape.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ export function reshape(args: {
() => `new shape: ${$shape}, old shape: ${x.shape}. New shape and old ` +
`shape must have the same number of elements.`);

// Backend needs to track refCount for the dataId for reshape op
args.backend.incRef(x.dataId);
return {dataId: x.dataId, shape: $shape, dtype: x.dtype};
}

Expand Down
6 changes: 5 additions & 1 deletion tfjs-backend-wasm/src/kernels/Reverse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ export function reverse(
wasmReverse(
xId, axesBytes, axes.length, outShapeBytes, x.shape.length, outId);

return reshape({inputs: {x: out}, attrs: {shape: x.shape}, backend});
const reshaped =
reshape({inputs: {x: out}, attrs: {shape: x.shape}, backend});

backend.disposeData(out.dataId);
return reshaped;
}

export const reverseConfig: KernelConfig = {
Expand Down
13 changes: 11 additions & 2 deletions tfjs-backend-wasm/src/kernels/StridedSlice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,11 @@ export function stridedSlice(args: {
const nonStrided = strides.every(v => v === 1);
if (nonStrided) {
const xSliced = slice({inputs: {x}, attrs: {begin, size}, backend});
return reshape({inputs: {x: xSliced}, attrs: {shape: outShape}, backend});
backend.disposeData(xReshaped.dataId);
const reshaped =
reshape({inputs: {x: xSliced}, attrs: {shape: outShape}, backend});
backend.disposeData(xSliced.dataId);
return reshaped;
}

const out = backend.makeOutput(outShape, 'float32');
Expand All @@ -134,8 +138,13 @@ export function stridedSlice(args: {
stridesBytes, outputShapeBytes, outStridesBytes, outShape.length,
outId);
}
backend.disposeData(xReshaped.dataId);

return reshape({inputs: {x: out}, attrs: {shape: outShape}, backend});
const reshaped =
reshape({inputs: {x: out}, attrs: {shape: outShape}, backend});

backend.disposeData(out.dataId);
return reshaped;
}

export const stridedSliceConfig: KernelConfig = {
Expand Down
Loading

0 comments on commit 0767a6e

Please sign in to comment.