[go: nahoru, domu]

Skip to content

Commit

Permalink
feat: change nodejs tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
kewent authored and iennae committed May 9, 2024
1 parent b04ac14 commit a3521a8
Showing 1 changed file with 32 additions and 46 deletions.
78 changes: 32 additions & 46 deletions ai-platform/snippets/test/predict-text-embeddings.test.js
Original file line number Diff line number Diff line change
@@ -1,65 +1,51 @@
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* Licensed under the Apache License, Version 2.0 (the 'License');
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* distributed under the License is distributed on an 'AS IS' BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

'use strict';

// [START aiplatform_sdk_embedding]
/**
 * Requests text embeddings from a Vertex AI publisher model and logs them.
 *
 * @param {string} project Google Cloud project ID.
 * @param {string} model Publisher model ID to call.
 * @param {string} texts Texts to embed, separated by ';'. Each becomes one prediction instance.
 * @param {string} task Embedding task type attached to every instance.
 * @param {number|string} outputDimensionality Requested embedding size. Anything that does not
 *     parse to a positive integer falls back to 256.
 * @param {string} apiEndpoint Regional API endpoint. The region is parsed from its host name.
 * @returns {Promise<void>} Resolves once the predictions have been logged; rejects if the call fails.
 */
async function main(
  project,
  model = 'text-embedding-004',
  texts = 'banana bread?;banana muffins?',
  task = 'QUESTION_ANSWERING',
  outputDimensionality = 0,
  apiEndpoint = 'us-central1-aiplatform.googleapis.com'
) {
  const aiplatform = require('@google-cloud/aiplatform');
  const {PredictionServiceClient} = aiplatform.v1;
  const {helpers} = aiplatform; // helps construct protobuf.Value objects.
  const clientOptions = {apiEndpoint: apiEndpoint};
  // Take the region (e.g. "us-central1") from the endpoint host name.
  // Fall back to us-central1 when the host has no "<word>-<word>" prefix.
  const match = apiEndpoint.match(/(?<Location>\w+-\w+)/);
  const location = match ? match.groups.Location : 'us-central1';
  const endpoint = `projects/${project}/locations/${location}/publishers/google/models/${model}`;
  // Prediction parameters must be a struct keyed by parameter name, not a bare number.
  // Command-line arguments arrive as strings, so parse the size explicitly.
  const dimensionality = Number.parseInt(outputDimensionality, 10);
  const parameters = helpers.toValue({
    outputDimensionality: dimensionality > 0 ? dimensionality : 256,
  });

  async function callPredict() {
    // One instance per ';'-separated text. `task_type` is the instance field
    // name documented by the text-embeddings API.
    const instances = texts
      .split(';')
      .map(e => helpers.toValue({content: e, task_type: task}));
    const request = {endpoint, instances, parameters};
    const client = new PredictionServiceClient(clientOptions);
    const [response] = await client.predict(request);
    console.log('Got predict response');
    const predictions = response.predictions;
    for (const prediction of predictions) {
      const embeddings = prediction.structValue.fields.embeddings;
      const values = embeddings.structValue.fields.values.listValue.values;
      console.log('Got prediction: ' + JSON.stringify(values));
    }
  }

  // Await so a failed prediction rejects main()'s promise rather than
  // surfacing only as a detached rejection.
  await callPredict();
}
// [END aiplatform_sdk_embedding]
// Project the sample scripts are run against, read from the environment.
const project = process.env.CAIP_PROJECT_ID;
// Five sample queries, pre-joined into the single ';'-delimited argument the
// sample scripts accept.
const texts = 'banana bread?;banana muffin?;banana?;recipe?;muffin recipe?';

// Report any promise rejection nobody handled and mark the process as failed,
// without exiting immediately.
process.on('unhandledRejection', err => {
  // A rejection reason can be any value, including null/undefined. Reading
  // `.message` off those would throw inside this handler, so only Errors are
  // reduced to their message; anything else is logged as-is.
  console.error(err instanceof Error ? err.message : err);
  process.exitCode = 1;
});

// Pass every CLI argument after `node <script>` to main() positionally.
const cliArgs = process.argv.slice(2);
main(...cliArgs);
// Test-only dependencies. They are declared at module scope so the suite below
// can see them; bindings declared inside a function body are not visible here.
const path = require('path');
const {assert} = require('chai');
const {describe, it} = require('mocha');
const cp = require('child_process');

// Run a shell command and return its stdout as a UTF-8 string. Extra options
// (such as `cwd`) are forwarded to child_process.execSync instead of being
// silently dropped.
const execSync = (cmd, options = {}) =>
  cp.execSync(cmd, {encoding: 'utf-8', ...options});
// Parent directory of this test file; the sample scripts are run from here.
const cwd = path.join(__dirname, '..');

describe('predict text embeddings', () => {
  it('should get text embeddings using the latest model', () => {
    const stdout = execSync(
      `node ./predict-text-embeddings.js ${project} text-embedding-004 '${texts}' QUESTION_ANSWERING 256`,
      {cwd}
    );
    assert.match(stdout, /Got predict response/);
  });
  it('should get text embeddings using the preview model', () => {
    const stdout = execSync(
      `node ./predict-text-embeddings-preview.js ${project} text-embedding-preview-0409 '${texts}' QUESTION_ANSWERING 256`,
      {cwd}
    );
    assert.match(stdout, /Got predict response/);
  });
});

0 comments on commit a3521a8

Please sign in to comment.