[go: nahoru, domu]

Skip to content

Commit

Permalink
Add TensorFlow test files (lutzroeder#294)
Browse files (browse the repository at this point in the history)
  • Loading branch information
lutzroeder committed Sep 10, 2022
1 parent ba624f4 commit 45a5672
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 17 deletions.
2 changes: 2 additions & 0 deletions source/base.js
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,8 @@ if (typeof window !== 'undefined' && typeof window.Long != 'undefined') {
if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    // Expose the shared helpers to CommonJS consumers.
    Object.assign(module.exports, {
        Int64: base.Int64,
        Uint64: base.Uint64,
        // Both complex widths are served by the single base.Complex class.
        // NOTE(review): base.Complex is defined outside this view - confirm it exists
        // and provides whatever API callers expect of Complex64/Complex128.
        Complex64: base.Complex,
        Complex128: base.Complex,
        BinaryReader: base.BinaryReader,
        Metadata: base.Metadata
    });
}
44 changes: 27 additions & 17 deletions source/tf.js
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,9 @@ tf.Tensor = class {
else {
const DataType = tf.proto.tensorflow.DataType;
switch (tensor.dtype) {
case DataType.DT_INVALID: {
break;
}
case DataType.DT_BFLOAT16: {
const values = tensor.half_val || [];
this._values = new Uint8Array(values.length << 2);
Expand Down Expand Up @@ -1309,6 +1312,24 @@ tf.Tensor = class {
this._layout = '|';
break;
}
case DataType.DT_COMPLEX64: {
this._layout = '|';
const values = tensor.scomplex_val || null;
this._values = new Array(values.length >> 1);
for (let i = 0; i < values.length; i += 2) {
this._values[i >> 1] = base.Complex64.create(values[i], values[i + 1]);
}
break;
}
case DataType.DT_COMPLEX128: {
this._layout = '|';
const values = tensor.dcomplex_val || null;
this._values = new Array(values.length >> 1);
for (let i = 0; i < values.length; i += 2) {
this._values[i >> 1] = base.Complex128.create(values[i], values[i + 1]);
}
break;
}
default: {
throw new tf.Error("Unsupported tensor data type '" + tensor.dtype + "'.");
}
Expand Down Expand Up @@ -1913,13 +1934,11 @@ tf.Utility = class {

static dataType(type) {
if (!tf.Utility._dataTypes) {
const dataTypes = new Map();
const DataType = tf.proto.tensorflow.DataType;
for (let key of Object.keys(DataType)) {
const value = DataType[key];
key = key.startsWith('DT_') ? key.substring(3) : key;
dataTypes.set(value, key.toLowerCase());
}
const dataTypes = new Map(Object.entries(DataType).map((entry) => {
const key = entry[0].startsWith('DT_') ? entry[0].substring(3) : entry[0];
return [ entry[1], key.toLowerCase() ];
}));
dataTypes.set(DataType.DT_HALF, 'float16');
dataTypes.set(DataType.DT_FLOAT, 'float32');
dataTypes.set(DataType.DT_DOUBLE, 'float64');
Expand All @@ -1931,17 +1950,8 @@ tf.Utility = class {

static dataTypeKey(type) {
if (!tf.Utility._dataTypeKeys) {
const dataTypeKeys = new Map();
const DataType = tf.proto.tensorflow.DataType;
for (let key of Object.keys(DataType)) {
const value = DataType[key];
key = key.startsWith('DT_') ? key.substring(3) : key;
dataTypeKeys.set(key.toLowerCase(), value);
}
dataTypeKeys.set('float16', DataType.DT_HALF);
dataTypeKeys.set('float32', DataType.DT_FLOAT);
dataTypeKeys.set('float64', DataType.DT_DOUBLE);
tf.Utility._dataTypeKeys = dataTypeKeys;
tf.Utility.dataType(0);
tf.Utility._dataTypeKeys = new Map(Array.from(tf.Utility._dataTypes).map((entry) => [ entry[1], entry[0] ]));
}
return tf.Utility._dataTypeKeys.get(type);
}
Expand Down
21 changes: 21 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -5367,6 +5367,13 @@
"action": "skip-render",
"link": "https://github.com/google-research/bert#pre-trained-models"
},
{
"type": "tf",
"target": "bfloat16.pbtxt",
"source": "https://github.com/lutzroeder/netron/files/9540969/bfloat16.pbtxt.zip[bfloat16.pbtxt]",
"format": "TensorFlow Graph",
"link": "https://github.com/lutzroeder/netron/issues/187"
},
{
"type": "tf",
"target": "char-rnn-tensorflow.pb",
Expand Down Expand Up @@ -5403,6 +5410,20 @@
"format": "TensorFlow Graph",
"link": "https://github.com/taey16/tf"
},
{
"type": "tf",
"target": "complex64.pb",
"source": "https://github.com/lutzroeder/netron/files/9540970/complex.zip[complex64.pb]",
"format": "TensorFlow Graph",
"link": "https://github.com/lutzroeder/netron/issues/187"
},
{
"type": "tf",
"target": "complex128.pb",
"source": "https://github.com/lutzroeder/netron/files/9540970/complex.zip[complex128.pb]",
"format": "TensorFlow Graph",
"link": "https://github.com/lutzroeder/netron/issues/187"
},
{
"type": "tf",
"target": "conv-layers.pb.zip",
Expand Down

0 comments on commit 45a5672

Please sign in to comment.