

Enabling structured inputs to call for Keras 3 #18735

Open
areiner222 opened this issue Nov 6, 2023 · 2 comments
Labels: backend:tensorflow, stat:awaiting keras-eng (Awaiting response from Keras engineer), type:feature (The user is asking for a new feature.)


@areiner222

I've relied heavily on passing structured inputs to `call` in subclassed `Model`s and `Layer`s. Will Keras 3 support this?

I am unable to pass a TensorFlow `ExtensionType` or a generic dataclass (a `PyTreeNode` in JAX) to `call`: it fails on this value check.

I believe it should be possible to accept this kind of structured input, especially given the `__tf_flatten__` / `__tf_unflatten__` protocol in TensorFlow and pytree registration in JAX.
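On the JAX side, pytree registration already splits a registered class into array leaves plus hashable static metadata, and `jax.jit` / `jax.tree_util` pass such objects through transparently. A minimal sketch, assuming a recent JAX release; `CompositeInput` and `scaled_sum` are illustrative names that mirror the fields of the TF example, not part of any Keras API:

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class CompositeInput:
    value: jax.Array  # dynamic leaf (traced under jit)
    meta: int         # static metadata (stored in the treedef)

    def tree_flatten(self):
        children = (self.value,)
        aux_data = (self.meta,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild by field name so leaves and metadata cannot be swapped.
        return cls(value=children[0], meta=aux_data[0])


@jax.jit
def scaled_sum(x: CompositeInput):
    # x.meta is a plain Python int here because it is static aux data.
    return jnp.sum(x.value) * x.meta


inp = CompositeInput(value=jnp.ones((10, 64)), meta=3)
leaves, treedef = jax.tree_util.tree_flatten(inp)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(len(leaves), rebuilt.meta, scaled_sum(inp))
```

Because `meta` lives in the treedef rather than among the leaves, changing it triggers a retrace under `jit`, while `value` is traced as an ordinary array.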

TF extension type example:

```python
import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before importing Keras
import tensorflow as tf
import keras_core  # pre-release name of Keras 3 (`import keras` in Keras 3)

class CompositeTensor(tf.experimental.ExtensionType):
    value: tf.Tensor
    meta: int

    def __tf_flatten__(self):
        metadata = (self.meta,)  # static config.
        components = (self.value,)  # dynamic values.
        return metadata, components

    @classmethod
    def __tf_unflatten__(cls, metadata, components):
        # Pass fields by name: positional order would swap `value` and `meta`.
        return cls(value=components[0], meta=metadata[0])


class ModelCheck(keras_core.Model):
    def __init__(self):
        super().__init__()
        self.layer = keras_core.layers.Dense(32)

    def call(self, inp, training=None):
        # Only the tensor component is fed to the Dense layer.
        return self.layer(inp.value)

m = ModelCheck()

inp = CompositeTensor(value=tf.random.uniform((10, 64)), meta=3)
print([type(v) for v in tf.nest.flatten(inp)])
out = m(inp)  # this call fails on the value check described above
```
@sachinprasadhs added the type:feature, backend:tensorflow, and keras-team-review-pending labels on Nov 6, 2023
@fchollet (Member) commented on Nov 9, 2023

Thanks for the suggestion. We are looking into this. The key APIs to modify would be `is_tensor`, `convert_to_tensor`, and `convert_to_numpy`. Maybe we can just extend those on the TF and JAX sides.
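One way to picture this on the TF side: TensorFlow can already decompose an `ExtensionType` into its component tensors and rebuild it, so an extended `convert_to_tensor` could, for example, map over those components instead of rejecting the object. A rough sketch, assuming TF 2.x with `tf.experimental.ExtensionType`; `map_components` is a hypothetical helper, not a Keras API, and this is not how Keras implements it:

```python
import tensorflow as tf


class CompositeTensor(tf.experimental.ExtensionType):
    # Same fields as the issue's example; default ExtensionType handling
    # is used here (no custom __tf_flatten__).
    value: tf.Tensor
    meta: int


def map_components(fn, x):
    # Hypothetical helper (not a Keras API): apply `fn` to every tensor
    # component of `x`, then rebuild the original extension type.
    # Non-tensor fields such as `meta` live in the TypeSpec and are kept.
    return tf.nest.map_structure(fn, x, expand_composites=True)


inp = CompositeTensor(value=tf.ones((10, 64)), meta=3)

# The only dynamic component is `value`; `meta` is static metadata.
components = tf.nest.flatten(inp, expand_composites=True)

# E.g. tensor conversion applied per component, or any elementwise op.
converted = map_components(tf.convert_to_tensor, inp)
doubled = map_components(lambda t: t * 2.0, inp)
```

Because `meta` is stored in the type spec rather than as a component, it survives the round trip unchanged.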

@qlzh727 removed the keras-team-review-pending label on Nov 9, 2023
@akensert

Any progress on this? It would be fantastic to have extension types work with Keras 3 :)

@sachinprasadhs added the stat:awaiting keras-eng label on Apr 26, 2024
Projects: none yet · Development: no branches or pull requests · 5 participants