InvalidArgumentError when using class_weight in Model.fit with labels having extra dimension for time step #48555
Labels: comp:keras, stat:awaiting response, stat:awaiting tensorflower, TF 2.16, type:bug
We used the class_weight parameter to assign different weights to samples of different classes in keras.Model.fit().
Our data (input) and labels (output) have an extra time-step dimension besides the batch dimension.
The labels are indices of classes (i.e. not one-hot encoded), used along with the sparse CE loss function.
We've written a minimal reproduction script (as presented below) to simplify the situation.
It looks like class_weight is only designed for outputs of shape (batch_dim, n_classes).
Possible workarounds:
- Convert class_weight to sample_weight
- Use tf.reshape() in a Lambda layer, which would mess the code up

I've noticed there are related issues, but they were closed because the authors did not reply in time.
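The first workaround listed above (turning class_weight into sample_weight) could be sketched roughly as follows. This is only an illustration under stated assumptions: the helper name class_weight_to_sample_weight, the example labels, and the weight values are hypothetical and not taken from our actual script, and classes missing from the dictionary are assumed to get weight 1.0.

```python
import numpy as np

def class_weight_to_sample_weight(labels, class_weight):
    """Map integer class labels of any shape to per-element weights.

    labels: integer array, e.g. shape (batch, time_steps)
    class_weight: dict {class_index: weight}, same form as the
        class_weight argument of Model.fit
    Classes missing from the dict get weight 1.0 (assumption of this sketch).
    """
    labels = np.asarray(labels)
    # Size the lookup table so every label and every dict key fits.
    n_classes = max(int(labels.max()), max(class_weight)) + 1
    lookup = np.ones(n_classes, dtype=np.float32)
    for cls, weight in class_weight.items():
        lookup[cls] = weight
    # Fancy indexing keeps the shape of `labels`.
    return lookup[labels]

# Hypothetical data: 2 sequences, 3 time steps, 3 classes.
y = np.array([[0, 1, 2], [2, 2, 0]])
class_weight = {0: 1.0, 1: 5.0, 2: 0.5}
sample_weight = class_weight_to_sample_weight(y, class_weight)
print(sample_weight.shape)  # (2, 3)
```

The resulting array has the same (batch, time_steps) shape as the labels and would be passed as sample_weight instead of class_weight. Whether Model.fit accepts per-time-step sample weights of this shape in a given Keras version is something to verify, not something this sketch guarantees.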
System information
Describe the current behavior
model.fit crashed with an InvalidArgumentError (the traceback is appended at the end because it is too long).
The training progress bar did not appear.
Describe the expected behavior
Training completes, even though the minimal reproduction script has no trainable parameters.
Standalone code to reproduce the issue
Other info / logs
Traceback:
traceback.txt
Model Structure (in minimal repro script):