[go: nahoru, domu]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[e2e] Refactor flag usage #7624

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 0 additions & 120 deletions e2e/benchmarks/benchmark_util.js
Original file line number Diff line number Diff line change
Expand Up @@ -471,126 +471,6 @@ function aggregateKernelTime(kernels) {
.sort((a, b) => b.timeMs - a.timeMs);
}

/**
 * Maps each tunable flag to the list of values it is allowed to take.
 *
 * A flag is listed here only when both of the following hold:
 *   - It is tunable. For example, `IS_BROWSER` and `IS_CHROME` are not
 *     tunable, because they are fixed when running the scripts.
 *   - It does not depend on other flags when registered through
 *     `ENV.registerFlag()`. This keeps the list streamlined and, since there
 *     are dependencies between flags, avoids the inconsistency that could
 *     arise from modifying an independent flag without modifying its
 *     dependents.
 *     (`WEBGL_RENDER_FLOAT32_CAPABLE` is an exception, because only exposing
 *     `WEBGL_FORCE_F16_TEXTURES` may confuse users.)
 */
const TUNABLE_FLAG_VALUE_RANGE_MAP = (() => {
  // Every boolean flag gets its own array so no two entries share a
  // reference.
  const booleanRange = () => [true, false];

  return {
    WEBGL_VERSION: [1, 2],
    WASM_HAS_SIMD_SUPPORT: booleanRange(),
    WASM_HAS_MULTITHREAD_SUPPORT: booleanRange(),
    WEBGL_CPU_FORWARD: booleanRange(),
    WEBGL_PACK: booleanRange(),
    WEBGL_FORCE_F16_TEXTURES: booleanRange(),
    WEBGL_RENDER_FLOAT32_CAPABLE: booleanRange(),
    WEBGL_FLUSH_THRESHOLD: [-1, 0, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2],
    WEBGL_PACK_DEPTHWISECONV: booleanRange(),
    CHECK_COMPUTATION_FOR_ERRORS: booleanRange(),
    KEEP_INTERMEDIATE_TENSORS: booleanRange(),
    WEBGL_USE_SHAPES_UNIFORMS: booleanRange(),
    WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE: [1, 5, 10, 15, 20, 25, 30, 35, 40],
  };
})();

/**
 * Set environment flags for testing.
 *
 * This will first set tunable flags (the keys of
 * `TUNABLE_FLAG_VALUE_RANGE_MAP`), then set URL parameter flags. If the two
 * overlap, URL parameter flags override tunable flags.
 *
 * ```js
 * const flagConfig = {
 *   WEBGL_PACK: false,
 * };
 * await setEnvFlags(flagConfig);
 *
 * console.log(tf.env().getBool('WEBGL_PACK'));  // false
 * console.log(tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS'));  // false
 * ```
 *
 * @param flagConfig An object to store flag-value pairs.
 * @returns `true` when `flagConfig` is `null`/`undefined` (nothing is
 *     changed). Otherwise the value returned by `resetBackend()` for the
 *     rebuilt backend (the wasm result when both backends are rebuilt), or
 *     `undefined` when no backend had to be rebuilt.
 * @throws {Error} If `flagConfig` is not an object, names a flag that is not
 *     a key of `TUNABLE_FLAG_VALUE_RANGE_MAP`, or assigns a value outside the
 *     flag's allowed range. Nothing is set in that case.
 */
async function setEnvFlags(flagConfig) {
  if (flagConfig == null) {
    return true;
  } else if (typeof flagConfig !== 'object') {
    throw new Error(
        `An object is expected, while a(n) ${typeof flagConfig} is found.`);
  }

  // Check the validation of flags and values.
  for (const flag in flagConfig) {
    // TODO: check whether flag can be set as flagConfig[flag].
    // Use an own-property lookup: a plain `in` check would also accept
    // inherited names such as `toString` and then crash on `indexOf` below.
    if (!Object.prototype.hasOwnProperty.call(
            TUNABLE_FLAG_VALUE_RANGE_MAP, flag)) {
      throw new Error(`${flag} is not a tunable or valid environment flag.`);
    }
    const allowedValues = TUNABLE_FLAG_VALUE_RANGE_MAP[flag];
    if (allowedValues.indexOf(flagConfig[flag]) === -1) {
      throw new Error(
          `${flag} value is expected to be in the range [${
              allowedValues}], while ${flagConfig[flag]}` +
          ' is found.');
    }
  }

  tf.env().setFlags(flagConfig);
  setEnvFlagsFromUrlParameters();

  // `WASM_HAS_SIMD_SUPPORT` and `WEBGL_VERSION` are also evaluated when
  // initializing backends, not only inferring.
  // TODO: The following backend rebuild logics can be implemented in `setHook`
  // when registering these flags.
  const needsWasmReset = 'WASM_HAS_SIMD_SUPPORT' in flagConfig;
  const needsWebglReset = 'WEBGL_VERSION' in flagConfig;

  // Rebuild every affected backend. Returning right after the wasm reset
  // would leave the webgl backend stale when both flags are configured.
  let resetResult;
  if (needsWasmReset) {
    resetResult = await resetBackend('wasm');
  }
  if (needsWebglReset) {
    const webglResult = await resetBackend('webgl');
    if (!needsWasmReset) {
      resetResult = webglResult;
    }
  }
  return resetResult;
}

/**
 * Apply the flags given in the page URL to the TensorFlow.js environment.
 *
 * The query string should be in the format `?tfjsflags=FLAG1:1,FLAG2:true`,
 * i.e. a comma-separated list of `NAME:value` pairs. Each value is converted
 * with `parseValue()` before being set. Does nothing when the URL carries no
 * `tfjsflags` parameter.
 */
function setEnvFlagsFromUrlParameters() {
  const flagsParamName = 'tfjsflags';
  const queryParams = tf.env().getQueryParams(location.search);
  if (!(flagsParamName in queryParams)) {
    return;
  }

  for (const pair of queryParams[flagsParamName].split(',')) {
    const [flagName, rawValue] = pair.split(':');
    try {
      tf.env().set(flagName, parseValue(rawValue));
    } catch (err) {
      // Report the bad pair and keep going so the remaining flags still apply.
      console.error(err);
    }
  }
}

/**
 * Convert a URL parameter string to a typed value: a boolean, a number, or
 * the original string.
 *
 * `true`/`false` are matched case-insensitively. A value becomes a number
 * only when the number converts back to exactly the same text (so `'1'`
 * becomes `1`, while `'1.0'` stays a string). Anything else is returned
 * unchanged, with its original casing.
 *
 * @param value The raw string taken from the URL.
 * @returns The boolean, number, or string the text represents.
 */
function parseValue(value) {
  const normalized = value.toLowerCase();
  switch (normalized) {
    case 'true':
      return true;
    case 'false':
      return false;
    default:
      break;
  }

  const asNumber = Number(normalized);
  return String(asNumber) === normalized ? asNumber : value;
}

/**
* Reset the target backend.
*
Expand Down
141 changes: 0 additions & 141 deletions e2e/benchmarks/benchmark_util_test.js
Original file line number Diff line number Diff line change
Expand Up @@ -218,147 +218,6 @@ describe('benchmark_util', () => {
});
});

// Specs for `setEnvFlags`: rejected configs must leave the environment
// untouched, valid configs must be applied, and flags that are read while a
// backend initializes must still take effect.
describe('setEnvFlags', () => {
  describe('changes nothing when setting empty config or rejecting', () => {
    let originalFlags = {};

    beforeEach(() => {
      // Snapshot the flags so each spec can show that nothing was modified.
      originalFlags = {...tf.env().flags};
    });
    afterAll(() => tf.env().reset());

    it('empty config', async () => {
      await setEnvFlags();
      expect(tf.env().flags).toEqual(originalFlags);
    });

    // `expectAsync` matchers return a promise. It has to be awaited,
    // otherwise the spec can finish before the assertion is evaluated.
    it('rejects when setting untunable flags', async () => {
      const flagConfig = {
        IS_BROWSER: false,
      };
      await expectAsync(setEnvFlags(flagConfig))
          .toBeRejectedWithError(
              Error, /is not a tunable or valid environment flag./);
      expect(tf.env().flags).toEqual(originalFlags);
    });

    it('rejects when setting a number flag by a boolean value', async () => {
      const flagConfig = {
        WEBGL_VERSION: false,
      };
      await expectAsync(setEnvFlags(flagConfig)).toBeRejectedWithError(Error);
      expect(tf.env().flags).toEqual(originalFlags);
    });

    it('rejects when setting boolean flag by a number', async () => {
      const flagConfig = {
        WEBGL_PACK: 1,
      };
      await expectAsync(setEnvFlags(flagConfig)).toBeRejectedWithError(Error);
      expect(tf.env().flags).toEqual(originalFlags);
    });

    it('rejects when setting flag value out of the range', async () => {
      const outOfRangeValue =
          Math.max(...TUNABLE_FLAG_VALUE_RANGE_MAP.WEBGL_VERSION) + 1;
      const flagConfig = {
        WEBGL_VERSION: outOfRangeValue,
      };
      await expectAsync(setEnvFlags(flagConfig)).toBeRejectedWithError(Error);
      expect(tf.env().flags).toEqual(originalFlags);
    });
  });

  describe('reset simple flags', () => {
    beforeEach(() => tf.env().reset());
    afterEach(() => tf.env().reset());

    it('reset number type flags', async () => {
      const flagConfig = {
        WEBGL_VERSION: 1,
      };
      await setEnvFlags(flagConfig);
      expect(tf.env().getNumber('WEBGL_VERSION')).toBe(1);
    });

    it('reset boolean flags', async () => {
      const flagConfig = {
        WASM_HAS_SIMD_SUPPORT: false,
        WEBGL_CPU_FORWARD: false,
        WEBGL_PACK: false,
        WEBGL_FORCE_F16_TEXTURES: false,
        WEBGL_RENDER_FLOAT32_CAPABLE: false,
      };
      await setEnvFlags(flagConfig);
      expect(tf.env().getBool('WASM_HAS_SIMD_SUPPORT')).toBe(false);
      expect(tf.env().getBool('WEBGL_CPU_FORWARD')).toBe(false);
      expect(tf.env().getBool('WEBGL_PACK')).toBe(false);
      expect(tf.env().getBool('WEBGL_FORCE_F16_TEXTURES')).toBe(false);
      expect(tf.env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')).toBe(false);
    });
  });

  describe('reset flags related to environment initialization', () => {
    beforeEach(() => tf.engine().reset());
    afterAll(async () => {
      tf.engine().reset();
      // Wait for the backend switch so it cannot leak into later suites.
      await tf.setBackend('cpu');
    });

    it(`set 'WEBGL_VERSION' to 2`, async () => {
      if (!tf.webgl_util.isWebGLVersionEnabled(2)) {
        pending(
            'Please use a browser supporting WebGL 2.0 to run this test.');
      }
      const flagConfig = {
        WEBGL_VERSION: 2,
      };
      await setEnvFlags(flagConfig);
      expect(tf.env().getBool('WEBGL_BUFFER_SUPPORTED')).toBe(true);
    });

    it(`set 'WEBGL_VERSION' to 1`, async () => {
      if (!tf.webgl_util.isWebGLVersionEnabled(1)) {
        pending(
            'Please use a browser supporting WebGL 1.0 to run this test.');
      }
      const flagConfig = {
        WEBGL_VERSION: 1,
      };
      await setEnvFlags(flagConfig);
      expect(tf.env().getBool('WEBGL_BUFFER_SUPPORTED')).toBe(false);
    });

    it(`reset flags when the related backend is active`, async () => {
      if (!tf.webgl_util.isWebGLVersionEnabled(1)) {
        pending(
            'Please use a browser supporting WebGL 1.0 to run this test.');
      }
      await tf.setBackend('webgl');
      const flagConfig = {
        WEBGL_VERSION: 1,
      };
      await setEnvFlags(flagConfig);
      expect(tf.getBackend()).toBe('webgl');
    });

    it(`reset 'WASM_HAS_SIMD_SUPPORT' as true`,
       async () => {
           // TODO: add test for SIMD after SIMD implementation.
           // const simdSupported = await
           // env().getAsync('WASM_HAS_SIMD_SUPPORT');
       });

    it(`reset 'WASM_HAS_SIMD_SUPPORT' as false`, async () => {
      const flagConfig = {
        WASM_HAS_SIMD_SUPPORT: false,
      };
      await setEnvFlags(flagConfig);
      expect(tf.env().getBool('WASM_HAS_SIMD_SUPPORT')).toBe(false);
    });
  });
});

describe('resetBackend', () => {
beforeEach(() => tf.setBackend('cpu'));
afterAll(() => tf.engine().reset());
Expand Down
Loading