[go: nahoru, domu]

Skip to content

Commit

Permalink
[e2e] Refactor flags
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Apr 25, 2023
1 parent 27a77c9 commit 3c84ac9
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 309 deletions.
120 changes: 0 additions & 120 deletions e2e/benchmarks/benchmark_util.js
Original file line number Diff line number Diff line change
Expand Up @@ -471,126 +471,6 @@ function aggregateKernelTime(kernels) {
.sort((a, b) => b.timeMs - a.timeMs);
}

/**
 * This map describes the tunable flags and the values each may be set to.
 *
 * The flags (keys) in the map satisfy the following two conditions:
 * - They are tunable. For example, `IS_BROWSER` and `IS_CHROME` are not
 *   tunable, because they are fixed when running the scripts.
 * - They do not depend on other flags when registered in
 *   `ENV.registerFlag()`. This keeps the list streamlined: flags can depend on
 *   each other, so changing one flag without also changing the flags derived
 *   from it may leave the environment inconsistent.
 *   (`WEBGL_RENDER_FLOAT32_CAPABLE` is an exception, because exposing only
 *   `WEBGL_FORCE_F16_TEXTURES` may confuse users.)
 *
 * Each value is the list of values the flag is allowed to take.
 */
const TUNABLE_FLAG_VALUE_RANGE_MAP = {
  WEBGL_VERSION: [1, 2],
  WASM_HAS_SIMD_SUPPORT: [true, false],
  WASM_HAS_MULTITHREAD_SUPPORT: [true, false],
  WEBGL_CPU_FORWARD: [true, false],
  WEBGL_PACK: [true, false],
  WEBGL_FORCE_F16_TEXTURES: [true, false],
  WEBGL_RENDER_FLOAT32_CAPABLE: [true, false],
  WEBGL_FLUSH_THRESHOLD: [-1, 0, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2],
  WEBGL_PACK_DEPTHWISECONV: [true, false],
  CHECK_COMPUTATION_FOR_ERRORS: [true, false],
  KEEP_INTERMEDIATE_TENSORS: [true, false],
  WEBGL_USE_SHAPES_UNIFORMS: [true, false],
  WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE: [1, 5, 10, 15, 20, 25, 30, 35, 40]
};

/**
 * Set environment flags for testing.
 *
 * This first sets the flags in `flagConfig`, whose keys must be keys of
 * `TUNABLE_FLAG_VALUE_RANGE_MAP` and whose values must be listed there. It
 * then applies flags from the URL parameters. Where the two overlap, URL
 * parameter flags override `flagConfig`.
 *
 * ```js
 * const flagConfig = {
 *   WEBGL_PACK: false,
 * };
 * await setEnvFlags(flagConfig);
 *
 * console.log(tf.env().getBool('WEBGL_PACK')); // false
 * console.log(tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS')); // false
 * ```
 *
 * @param flagConfig An object to store flag-value pairs.
 * @returns `true` when `flagConfig` is null/undefined. Otherwise it returns
 *     the result of the last `resetBackend` call made, or `undefined` when no
 *     backend needed a reset.
 * @throws Error if `flagConfig` is not an object, names a flag that is not
 *     tunable, or gives a flag a value outside its allowed range. All flags
 *     are validated before any is set, so a rejected config sets nothing.
 */
async function setEnvFlags(flagConfig) {
  if (flagConfig == null) {
    return true;
  } else if (typeof flagConfig !== 'object') {
    throw new Error(
        `An object is expected, while a(n) ${typeof flagConfig} is found.`);
  }

  const flags = Object.keys(flagConfig);

  // Validate every flag before setting any of them.
  for (const flag of flags) {
    // TODO: check whether flag can be set as flagConfig[flag].
    // Own-property check: `in` would also accept inherited names such as
    // `toString`, which would then crash below instead of being rejected.
    if (!Object.prototype.hasOwnProperty.call(
            TUNABLE_FLAG_VALUE_RANGE_MAP, flag)) {
      throw new Error(`${flag} is not a tunable or valid environment flag.`);
    }
    const range = TUNABLE_FLAG_VALUE_RANGE_MAP[flag];
    if (!range.includes(flagConfig[flag])) {
      throw new Error(
          `${flag} value is expected to be in the range [${range}], while ` +
          `${flagConfig[flag]} is found.`);
    }
  }

  tf.env().setFlags(flagConfig);
  setEnvFlagsFromUrlParameters();

  // `WASM_HAS_SIMD_SUPPORT` and `WEBGL_VERSION` are also evaluated when
  // initializing backends, not only when inferring, so the affected backend
  // must be rebuilt. The two checks are independent so that a config setting
  // both flags rebuilds both backends.
  // TODO: The following backend rebuild logic can be implemented in `setHook`
  // when registering these flags.
  let result;
  if (flags.includes('WASM_HAS_SIMD_SUPPORT')) {
    result = await resetBackend('wasm');
  }
  if (flags.includes('WEBGL_VERSION')) {
    result = await resetBackend('webgl');
  }
  return result;
}

/**
 * Apply environment flags encoded in the page URL.
 *
 * Expected query format: `?tfjsflags=FLAG1:1,FLAG2:true`, i.e. a
 * comma-separated list of `name:value` pairs under the `tfjsflags` key. Each
 * value goes through `parseValue` and is then set via `tf.env().set`. A pair
 * whose parsing or setting throws is reported with `console.error` and
 * skipped; the remaining pairs are still applied.
 */
function setEnvFlagsFromUrlParameters() {
  const FLAGS_QUERY_KEY = 'tfjsflags';
  const queryParams = tf.env().getQueryParams(location.search);
  if (!(FLAGS_QUERY_KEY in queryParams)) {
    return;
  }

  const pairs = queryParams[FLAGS_QUERY_KEY].split(',');
  for (const pair of pairs) {
    const [flagName, rawValue] = pair.split(':');
    try {
      tf.env().set(flagName, parseValue(rawValue));
    } catch (err) {
      console.error(err);
    }
  }
}

/**
 * Convert a raw URL parameter string into a typed value.
 *
 * Matching is case-insensitive. `'true'`/`'false'` become booleans, and a
 * string whose numeric form prints back to the same (lower-cased) text becomes
 * a number. Anything else is returned unchanged, with its original casing.
 */
function parseValue(value) {
  const normalized = value.toLowerCase();
  if (normalized === 'true') {
    return true;
  }
  if (normalized === 'false') {
    return false;
  }
  const asNumber = Number(normalized);
  return String(asNumber) === normalized ? asNumber : value;
}

/**
* Reset the target backend.
*
Expand Down
141 changes: 0 additions & 141 deletions e2e/benchmarks/benchmark_util_test.js
Original file line number Diff line number Diff line change
Expand Up @@ -218,147 +218,6 @@ describe('benchmark_util', () => {
});
});

describe('setEnvFlags', () => {
  describe('changes nothing when setting empty config or rejecting', () => {
    let originalFlags = {};

    beforeEach(() => {
      originalFlags = {...tf.env().flags};
    });
    afterAll(() => tf.env().reset());

    it('empty config', async () => {
      await setEnvFlags();
      expect(tf.env().flags).toEqual(originalFlags);
    });

    // The async matchers below are awaited: without `await`, the spec can
    // complete before the rejection is checked, so a failure goes unreported.
    it('rejects when setting untunable flags', async () => {
      const flagConfig = {
        IS_BROWSER: false,
      };
      await expectAsync(setEnvFlags(flagConfig))
          .toBeRejectedWithError(
              Error, /is not a tunable or valid environment flag\./);
      expect(tf.env().flags).toEqual(originalFlags);
    });

    it('rejects when setting a number flag by a boolean value', async () => {
      const flagConfig = {
        WEBGL_VERSION: false,
      };
      await expectAsync(setEnvFlags(flagConfig)).toBeRejectedWithError(Error);
      expect(tf.env().flags).toEqual(originalFlags);
    });

    it('rejects when setting boolean flag by a number', async () => {
      const flagConfig = {
        WEBGL_PACK: 1,
      };
      await expectAsync(setEnvFlags(flagConfig)).toBeRejectedWithError(Error);
      expect(tf.env().flags).toEqual(originalFlags);
    });

    it('rejects when setting flag value out of the range', async () => {
      const outOfRangeValue =
          Math.max(...TUNABLE_FLAG_VALUE_RANGE_MAP.WEBGL_VERSION) + 1;
      const flagConfig = {
        WEBGL_VERSION: outOfRangeValue,
      };
      await expectAsync(setEnvFlags(flagConfig)).toBeRejectedWithError(Error);
      expect(tf.env().flags).toEqual(originalFlags);
    });
  });

  describe('reset simple flags', () => {
    beforeEach(() => tf.env().reset());
    afterEach(() => tf.env().reset());

    it('reset number type flags', async () => {
      const flagConfig = {
        WEBGL_VERSION: 1,
      };
      await setEnvFlags(flagConfig);
      expect(tf.env().getNumber('WEBGL_VERSION')).toBe(1);
    });

    it('reset boolean flags', async () => {
      const flagConfig = {
        WASM_HAS_SIMD_SUPPORT: false,
        WEBGL_CPU_FORWARD: false,
        WEBGL_PACK: false,
        WEBGL_FORCE_F16_TEXTURES: false,
        WEBGL_RENDER_FLOAT32_CAPABLE: false,
      };
      await setEnvFlags(flagConfig);
      expect(tf.env().getBool('WASM_HAS_SIMD_SUPPORT')).toBe(false);
      expect(tf.env().getBool('WEBGL_CPU_FORWARD')).toBe(false);
      expect(tf.env().getBool('WEBGL_PACK')).toBe(false);
      expect(tf.env().getBool('WEBGL_FORCE_F16_TEXTURES')).toBe(false);
      expect(tf.env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')).toBe(false);
    });
  });

  describe('reset flags related to environment initialization', () => {
    beforeEach(() => tf.engine().reset());
    // Await the backend switch so later suites do not start mid-switch.
    afterAll(async () => {
      tf.engine().reset();
      await tf.setBackend('cpu');
    });

    it(`set 'WEBGL_VERSION' to 2`, async () => {
      if (!tf.webgl_util.isWebGLVersionEnabled(2)) {
        pending('Please use a browser supporting WebGL 2.0 to run this test.');
      }
      const flagConfig = {
        WEBGL_VERSION: 2,
      };
      await setEnvFlags(flagConfig);
      expect(tf.env().getBool('WEBGL_BUFFER_SUPPORTED')).toBe(true);
    });

    it(`set 'WEBGL_VERSION' to 1`, async () => {
      if (!tf.webgl_util.isWebGLVersionEnabled(1)) {
        pending('Please use a browser supporting WebGL 1.0 to run this test.');
      }
      const flagConfig = {
        WEBGL_VERSION: 1,
      };
      await setEnvFlags(flagConfig);
      expect(tf.env().getBool('WEBGL_BUFFER_SUPPORTED')).toBe(false);
    });

    it(`reset flags when the related backend is active`, async () => {
      if (!tf.webgl_util.isWebGLVersionEnabled(1)) {
        pending('Please use a browser supporting WebGL 1.0 to run this test.');
      }
      await tf.setBackend('webgl');
      const flagConfig = {
        WEBGL_VERSION: 1,
      };
      await setEnvFlags(flagConfig);
      expect(tf.getBackend()).toBe('webgl');
    });

    it(`reset 'WASM_HAS_SIMD_SUPPORT' as true`, async () => {
      // TODO: add test for SIMD after SIMD implementation.
      // const simdSupported = await
      // env().getAsync('WASM_HAS_SIMD_SUPPORT');
    });

    it(`reset 'WASM_HAS_SIMD_SUPPORT' as false`, async () => {
      const flagConfig = {
        WASM_HAS_SIMD_SUPPORT: false,
      };
      await setEnvFlags(flagConfig);
      expect(tf.env().getBool('WASM_HAS_SIMD_SUPPORT')).toBe(false);
    });
  });
});

describe('resetBackend', () => {
beforeEach(() => tf.setBackend('cpu'));
afterAll(() => tf.engine().reset());
Expand Down
Loading

0 comments on commit 3c84ac9

Please sign in to comment.