

Issues in Tensorflow model training #69984

Open
myh1234567 opened this issue Jun 18, 2024 · 2 comments
Labels
comp:dist-strat (Distribution Strategy related issues), TF 2.16, type:bug (Bug)

Comments

@myh1234567

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

No

Source

source

TensorFlow version

tf 2.16.1

Custom code

Yes

OS platform and distribution

linux

Mobile device

x86_64

Python version

3.12

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

With the same code and the same data, training worked without issues under Python 3.9 with TensorFlow 2.11.
After upgrading to Python 3.12 with TensorFlow 2.16.1, model_db.fit fails with the following error:
Traceback (most recent call last):
  File "/home/train_model", line 379, in <module>
    history_db = model_db.fit(X_transformed_train_db, y_train_db_, epochs=3)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/lib/python3.12/site-packages/keras/utils/traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/miniconda3/lib/python3.12/site-packages/keras/trainers/data_adapters/__init__.py", line 113, in get_data_adapter
    raise ValueError(f"Unrecognized data type: x={x} (of type {type(x)})")
ValueError: Unrecognized data type: x=SparseTensor(indices=tf.Tensor(
[[ 0 11068]
 [ 0 16849]
 [ 0 35681]
 ...
 [8563599 29603]
 [8563599 31778]
 [8563599 38279]], shape=(41839230, 2), dtype=int64), values=tf.Tensor([0.52680086 0.59266628 0.42937609 ... 0.27554394 0.25317136 0.33879793], shape=(41839230,), dtype=float64), dense_shape=tf.Tensor([8563600 40000], shape=(2,), dtype=int64)) (of type <class 'tensorflow.python.framework.sparse_tensor.SparseTensor'>)

Standalone code to reproduce the issue

n/a

Relevant log output

No response

@sushreebarsa
Contributor

@myh1234567
Could you please verify that all input data (x, y, etc.) passed to your model is of the type (tf.float32, tf.float64, etc.) expected by the model's architecture and the loss function? Also, to expedite the troubleshooting process, please provide a code snippet that reproduces the issue. Thank you!
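A minimal sketch of the check requested above, assuming a small hypothetical SparseTensor `x` and one-hot NumPy labels `y` in place of the real training data. It prints the container type and dtype of each input, then casts both to float32, the default float type of Keras `Dense` weights. Note that the traceback reports the Python type of `x` (`SparseTensor`) rather than its dtype, so casting alone may not change the outcome.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy input standing in for the real training matrix:
# a 3x4 SparseTensor with float64 values, as in the reported traceback.
x = tf.sparse.SparseTensor(
    indices=[[0, 1], [1, 3], [2, 0]],
    values=np.array([0.5, 0.25, 0.75], dtype=np.float64),
    dense_shape=[3, 4],
)
# Hypothetical one-hot labels (3 samples, 3 classes) built directly with NumPy.
y = np.eye(3, dtype=np.float64)[[0, 2, 1]]

# Report the container type and element dtype of each input.
print("x:", type(x).__name__, x.dtype)
print("y:", type(y).__name__, y.dtype)

# tf.cast accepts a SparseTensor and returns a SparseTensor with cast values.
x32 = tf.cast(x, tf.float32)
y32 = y.astype(np.float32)
```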

sushreebarsa added the stat:awaiting response (Status - Awaiting response from author) label on Jun 19, 2024
@myh1234567
Author
myh1234567 commented Jun 19, 2024


The following code snippet reproduces the issue:

import tensorflow as tf
import numpy as np
import scipy.sparse as sp
import keras
from keras.utils import to_categorical

physical_device = tf.config.list_physical_devices("GPU")
print("===================================== num gpu:", len(physical_device))

def convert_sparse_matrix_to_sparse_tensor(X) -> tf.SparseTensor:
    """Convert a scipy sparse matrix to a SparseTensor."""
    coo = X.tocoo()
    indices = np.mat([coo.row, coo.col]).transpose()
    return tf.sparse.SparseTensor(indices, coo.data, coo.shape)

X_transformed_train_db = sp.rand(100, 1000, density=0.1, format='coo')
X_train_db = convert_sparse_matrix_to_sparse_tensor(X_transformed_train_db)
X_train_db = tf.sparse.reorder(X_train_db)
y_train_db = np.random.randint(0, 3, size=(100,)) # Let's assume 3 classes
y_train_db_ = to_categorical(y_train_db)

gpus = tf.config.list_logical_devices("GPU")
strategy = tf.distribute.MirroredStrategy(gpus)

with strategy.scope():
    opti = keras.optimizers.Adam(learning_rate=0.0001)
    input_dimension = X_transformed_train_db.shape[1]
    model_db = keras.Sequential()
    model_db.add(keras.layers.Dense(1000, kernel_initializer=keras.initializers.HeNormal(seed=1), activation='relu', input_dim=input_dimension))
    model_db.add(keras.layers.Dropout(0.1))
    model_db.add(keras.layers.Dense(500, kernel_initializer=keras.initializers.HeNormal(seed=2), activation='relu'))
    model_db.add(keras.layers.Dropout(0.1))
    model_db.add(keras.layers.Dense(200, kernel_initializer=keras.initializers.HeNormal(), activation='relu'))
    model_db.add(keras.layers.Dropout(0.1))
    model_db.add(keras.layers.Dense(y_train_db_.shape[1], kernel_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.05, seed=4), activation='softmax'))
    model_db.compile(optimizer=opti, loss='categorical_crossentropy', metrics=['accuracy'])

history_db = model_db.fit(X_transformed_train_db, y_train_db_, epochs=3)

tensorflow 2.16.1
keras 3.3.3
scipy 1.13.1
numpy 1.26.4

Thank you!
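One alternative that may be worth trying is to wrap the sparse features and labels in a `tf.data.Dataset`, which Keras 3 accepts as input to `fit`, so that rows are batched while still sparse. At the scale shown in the traceback (8,563,600 × 40,000), a dense float32 copy of the whole matrix would need roughly 1.37 TB, which rules out densifying everything up front. This is a sketch on toy data with the snippet's shapes; the batch size of 32, the shuffle buffer, and the seeds are arbitrary assumptions, and `np.column_stack` stands in for the deprecated `np.mat` call.

```python
import numpy as np
import scipy.sparse as sp
import tensorflow as tf

# Toy data with the same shapes as the reproduction snippet.
X = sp.rand(100, 1000, density=0.1, format="coo", random_state=0)
labels = np.random.default_rng(0).integers(0, 3, size=100)
y = np.eye(3, dtype=np.float32)[labels]  # one-hot labels, 3 classes

coo = X.tocoo()
# (nnz, 2) int64 array of [row, col] pairs.
indices = np.column_stack([coo.row, coo.col]).astype(np.int64)
x_sparse = tf.sparse.SparseTensor(indices, coo.data.astype(np.float32), coo.shape)
# Put indices in canonical row-major order before slicing, as the snippet does.
x_sparse = tf.sparse.reorder(x_sparse)

# Each element is one sparse row plus its label; batching reassembles
# the rows into a (batch, 1000) SparseTensor.
dataset = (
    tf.data.Dataset.from_tensor_slices((x_sparse, y))
    .shuffle(100, seed=0)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)
```

With the model from the snippet, training would then be `model_db.fit(dataset, epochs=3)`. Whether `MirroredStrategy` and the `Dense` layers handle these sparse batches under TensorFlow 2.16.1 and Keras 3.3.3 has not been verified here, so treat this as an untested workaround rather than a confirmed fix.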

google-ml-butler bot removed the stat:awaiting response (Status - Awaiting response from author) label on Jun 19, 2024
sushreebarsa added the comp:dist-strat (Distribution Strategy related issues) label on Jun 24, 2024