
Tensor shape error using TF 2.8.0 with XLA enabled #54973

Open
leondgarse opened this issue Mar 4, 2022 · 4 comments

Labels
comp:xla XLA stat:awaiting tensorflower Status - Awaiting response from tensorflower TF 2.8 type:bug Bug

Comments

@leondgarse
Contributor
leondgarse commented Mar 4, 2022

Please make sure that this is a bug. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): TF 2.8.0 / TF 2.6.3 / tf-nightly 2.9.0-dev20220303
  • Python version: 3.8.10
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: cuda-driver-dev-11-2, 11.2.152-1, libcudnn8, 8.3.2.44-1+cuda11.5
  • GPU model and memory: RTX 2080 Ti, 11016MiB / RTX 8000, 46080MiB

Describe the current behavior
After upgrading from TF 2.6.3 to 2.8.0, my daily training script fails with the error Must have updates.shape = indices.shape[:batch_dim] + buffer_shape[num_index_dims:], got updates.shape: [32], indices.shape: [320,2], buffer_shape: [32,10], num_index_dims: 2, and batch_dim: 1 when the TF_XLA_FLAGS="--tf_xla_auto_jit=2" flag is set. The same script with the same flag works fine in TF 2.6.3.
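
For reference, the shape rule named in the error can be checked in eager mode with a tiny sketch; the shapes below are only chosen to match the rule (for a rank-2 buffer, indices of shape [N, 2] need updates of shape [N]) and are not taken from the failing run:

import tensorflow as tf

buffer = tf.zeros([32, 10])                                          # buffer_shape: [32, 10]
indices = tf.stack([tf.range(32, dtype=tf.int64), tf.zeros([32], dtype=tf.int64)], axis=1)  # [32, 2]
updates = tf.ones([32])                                              # updates.shape: [32] -> accepted
print(tf.tensor_scatter_nd_update(buffer, indices, updates).shape)   # (32, 10)
# With indices of shape [320, 2], as reported in the error, the updates would need shape [320].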

Describe the expected behavior
The script should keep working with XLA enabled, as it did in TF 2.6.3.

Contributing

  • Do you want to contribute a PR? (yes/no): no
  • Briefly describe your candidate solution (if contributing): I have no idea...

Standalone code to reproduce the issue
This is my standalone code for reproducing the issue, with most details stripped out:

#!/usr/bin/env python3

import tensorflow as tf
from tensorflow import keras


class NormDense(keras.layers.Layer):
    def __init__(self, units=1000, append_norm=False, **kwargs):
        super().__init__(**kwargs)
        self.units, self.append_norm = units, append_norm

    def build(self, input_shape):
        self.w = self.add_weight(name="norm_dense_w", shape=(input_shape[-1], self.units), trainable=True)
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        # tf.print("tf.reduce_mean(self.w):", tf.reduce_mean(self.w))
        norm_w = tf.nn.l2_normalize(self.w, axis=0)
        norm_inputs = tf.nn.l2_normalize(inputs, axis=1)
        output = tf.matmul(norm_inputs, norm_w)
        if self.append_norm:
            output = tf.concat([output, tf.norm(inputs, axis=1, keepdims=True) * -1], axis=-1)
        return output

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units, "append_norm": self.append_norm})
        return config


class NormDenseLoss(tf.keras.losses.Loss):
    def __init__(self, from_logits=True, **kwargs):
        super().__init__(**kwargs)
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        if y_pred.shape[-1] == y_true.shape[-1]:
            norm_logits = y_pred
            margin = 0.3
            regularizer_loss = 0.0
        else:
            norm_logits, feature_norm = y_pred[:, :-1], y_pred[:, -1] * -1
            margin = 0.04 * (feature_norm - 10) + 10.0  # This triggers the error
            regularizer_loss = feature_norm / 1e4 + 1.0 / feature_norm

        pick_cond = tf.where(y_true > 0)
        y_pred_vals = tf.gather_nd(norm_logits, pick_cond)
        theta_valid = y_pred_vals - margin

        # tf.print(">>>>", norm_logits.shape, pick_cond, tf.reduce_sum(tf.cast(y_true > 0, "float32")), theta_valid.shape)
        logits = tf.tensor_scatter_nd_update(norm_logits, pick_cond, theta_valid)
        # theta_one_hot = tf.expand_dims(theta_valid, 1) * tf.cast(y_true, dtype=tf.float32)
        # logits = tf.where(tf.cast(y_true, dtype=tf.bool), theta_one_hot, norm_logits)
        # tf.print(">>>>", norm_logits.shape, logits.shape, y_true.shape)
        cls_loss = tf.keras.losses.categorical_crossentropy(y_true, logits, from_logits=self.from_logits)

        # tf.print(">>>>", cls_loss.shape, regularizer_loss.shape)
        return cls_loss + regularizer_loss * 35.0

    def get_config(self):
        config = super().get_config()
        config.update({"from_logits": self.from_logits})
        return config


if __name__ == "__main__":
    import sys
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--append_norm", action="store_true", help="append norm")
    args = parser.parse_known_args(sys.argv[1:])[0]

    xx = tf.random.uniform([160, 32, 32, 3])
    yy = tf.one_hot(tf.cast(tf.random.uniform([160], 0, 10), "int32"), 10)
    mm = keras.models.Sequential([keras.layers.Input([32, 32, 3]), keras.layers.Flatten(), keras.layers.Dense(32), NormDense(10, append_norm=args.append_norm)])
    mm.compile(loss=NormDenseLoss(), optimizer="adam")
    mm.fit(xx, yy)

Run test using TF 2.6.3:

# Pass
CUDA_VISIBLE_DEVICES='0' TF_XLA_FLAGS="--tf_xla_auto_jit=2" python ./tf_280_xla_test.py
# Pass
CUDA_VISIBLE_DEVICES='0' TF_XLA_FLAGS="--tf_xla_auto_jit=2" python ./tf_280_xla_test.py --append_norm

Run test using TF 2.8.0 / tf-nightly 2.9.0-dev20220303:

# Pass
CUDA_VISIBLE_DEVICES='0' TF_XLA_FLAGS="--tf_xla_auto_jit=2" python ./tf_280_xla_test.py
# Error
CUDA_VISIBLE_DEVICES='0' TF_XLA_FLAGS="--tf_xla_auto_jit=2" python ./tf_280_xla_test.py --append_norm
# Must have updates.shape = indices.shape[:batch_dim] + buffer_shape[num_index_dims:], got updates.shape: [32], indices.shape: [320,2], buffer_shape: [32,10], num_index_dims: 2, and batch_dim: 1

# Pass
CUDA_VISIBLE_DEVICES='0' python ./tf_280_xla_test.py --append_norm

Some parts of this script may not be needed for the reproduction, I'm not sure. I suspect something goes wrong with margin = 0.04 * (feature_norm - 10) + 10.0, but I cannot tell what exactly happens there... Please take a look.
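
For what it's worth, here is a small eager-mode sketch of the shapes around that line (random values; it only shows that the per-sample margin lines up with the gathered values because exactly one index is picked per one-hot row):

import tensorflow as tf

y_true = tf.one_hot(tf.cast(tf.random.uniform([32], 0, 10), "int32"), 10)  # [32, 10], one-hot
norm_logits = tf.random.uniform([32, 10])
feature_norm = tf.random.uniform([32], 1.0, 20.0)

margin = 0.04 * (feature_norm - 10) + 10.0            # per-sample margin, shape [32]
pick_cond = tf.where(y_true > 0)                       # one index per row -> shape [32, 2]
y_pred_vals = tf.gather_nd(norm_logits, pick_cond)     # shape [32]
theta_valid = y_pred_vals - margin                     # elementwise, shape [32]
out = tf.tensor_scatter_nd_update(norm_logits, pick_cond, theta_valid)
print(pick_cond.shape, theta_valid.shape, out.shape)   # (32, 2) (32,) (32, 10)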

@mohantym
Contributor
mohantym commented Mar 4, 2022

Hi @gadagashwini! Could you take a look at this issue? It is not reproducing in Colab with the 2.8 version, though.

@mohantym mohantym assigned gadagashwini and unassigned mohantym Mar 4, 2022
@leondgarse
Contributor Author
leondgarse commented Mar 4, 2022

@mohantym Uh, right, we can run scripts in Colab... Try this tf_280_xla_test.ipynb. Setting CUDA_VISIBLE_DEVICES='1' left no GPU for your run, my bad. I have updated the commands.
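
In a notebook, the auto-clustering can also be turned on from inside the script instead of via the environment variable; as far as I know this is equivalent to TF_XLA_FLAGS="--tf_xla_auto_jit=2":

import tensorflow as tf

# Enable XLA auto-clustering for the whole session.
tf.config.optimizer.set_jit("autoclustering")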

@sachinprasadhs sachinprasadhs added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Apr 7, 2022
@leondgarse
Contributor Author
leondgarse commented Apr 14, 2022

Just verified the issue still exists in TF 2.9.0-rc0. Test results are updated in tf_280_xla_test.ipynb. For other versions: TF 2.7.1 works, and TF 2.8.0-rc0 throws the error.

@leondgarse
Contributor Author
leondgarse commented May 17, 2022
  • Here is another test using only tf.gather_nd that hits a similar error. It also works in TF 2.7.1 but is broken since TF 2.8.0:
    import tensorflow as tf
    from tensorflow import keras
    
    def gathered_loss(y_true, y_pred):
        pick_cond = tf.where(y_true > 0)
        y_pred_vals = tf.gather_nd(y_pred, pick_cond)
        theta_valid = y_pred_vals - 0.3  # Some calculation here
    
        theta_one_hot = tf.expand_dims(theta_valid, 1) * tf.cast(y_true, dtype=tf.float32)
        logits = tf.where(tf.cast(y_true, dtype=tf.bool), theta_one_hot, y_pred)
        return tf.keras.losses.categorical_crossentropy(y_true, logits, from_logits=True)
    
    if __name__ == "__main__":
        xx = tf.random.uniform([160, 32, 32, 3])
        yy = tf.one_hot(tf.cast(tf.random.uniform([160], 0, 10), "int32"), 10)
        mm = keras.models.Sequential([keras.layers.Input([32, 32, 3]), keras.layers.Flatten(), keras.layers.Dense(10)])
        mm.compile(loss=gathered_loss, optimizer="adam")
        mm.fit(xx, yy)
  • Results are updated in the above tf_280_xla_test.ipynb, along with TF 2.9.0 test results. I'm wondering if it's related to [XLA] Different JIT compile behavior from TF2.7 #55610.
  • Also note that if we use theta_valid = y_pred_vals - 0.3 instead of theta_valid = y_pred_vals - margin in the first script, TF 2.8.0 / TF 2.9.0 with XLA enabled also works... (a sketch of a scatter-free rewrite follows below).
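
As a possible workaround sketch (hypothetical on my side, not verified against this XLA issue), the margin can be applied without gather_nd / tensor_scatter_nd_update at all by using the one-hot mask directly; when y_true is one-hot this should match the original logic:

import tensorflow as tf

def apply_margin(norm_logits, y_true, margin):
    # margin may be a scalar or a per-sample tensor of shape [batch]
    margin = tf.reshape(tf.cast(margin, norm_logits.dtype), [-1, 1])
    mask = tf.cast(y_true > 0, norm_logits.dtype)   # one-hot mask, shape [batch, classes]
    return norm_logits - mask * margin               # subtract margin only at the true class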
