tf.print() with XLA compilation #52944

Open
MIHIRKHAMBETE opened this issue Nov 4, 2021 · 12 comments
Assignees
Labels
comp:apis Highlevel API related issues comp:xla XLA stat:awaiting tensorflower Status - Awaiting response from tensorflower TF 2.11 Issues related to TF 2.11 type:feature Feature requests

Comments

@MIHIRKHAMBETE

Please make sure that this is a feature request. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub.

Hello,

Does anyone know whether tf.print (or an equivalent workaround) can be made to work with XLA compilation? I get an error saying that XLA does not recognize the "printV2" operation if I have a tf.print statement inside a function decorated with @tf.function(jit_compile=True).

If this functionality does not exist, I would like to request that it be added as a new feature!

Thanks!

System information

  • TensorFlow version (you are using): 2.6.0
  • Are you willing to contribute it (Yes/No): No

Describe the feature and the current behavior/state.
tf.print() does not seem to work when used within functions that are compiled with XLA (i.e., decorated with @tf.function(jit_compile=True)).

Will this change the current api? How?
No

Who will benefit from this feature?
Users of XLA-compiled TensorFlow code.

Any Other info.

@MIHIRKHAMBETE MIHIRKHAMBETE added the type:feature Feature requests label Nov 4, 2021
@sachinprasadhs
Contributor

Could you please provide the code for which you are facing the error. Thanks!

@sachinprasadhs sachinprasadhs added 2.6.0 comp:apis Highlevel API related issues stat:awaiting response Status - Awaiting response from author labels Nov 4, 2021
@MIHIRKHAMBETE
Author
MIHIRKHAMBETE commented Nov 4, 2021

import tensorflow as tf

@tf.function(jit_compile=True)
def print_ten():
  for i in range(10):
    tf.print('printing')
  return

print_ten()

@bhack
Contributor
bhack commented Nov 4, 2021

Why not

import tensorflow as tf

@tf.function(jit_compile=True)
def print_ten():
  for i in range(10):
    print('printing ' +str(i))
  return

print_ten()

@sachinprasadhs
Contributor

The normal print function works without any limitations; please find the gist here. Thanks!

@MIHIRKHAMBETE
Author

The problem with the normal print is that, when the function is called multiple times with the same arguments, the print statements are only executed on the first call: TensorFlow reuses the already-traced graph, which does not include the Python print calls (see https://www.tensorflow.org/guide/function, section "What is tracing?", for details).

tf.print(), on the other hand, will always print in graph mode (without XLA), even on non-first calls. However, tf.print() is not currently supported when XLA compilation is used.

Hope this clarifies the nature of the problem - please let me know if you need more info or code examples of the issue.
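To illustrate the non-XLA behaviour described above, here is a minimal sketch (not from the thread; the function name `add_one` is made up, and it assumes TensorFlow 2.x with eager execution enabled):

```python
import tensorflow as tf

@tf.function  # graph mode, but without jit_compile=True
def add_one(x):
    # The PrintV2 op is traced into the graph, so it runs on every call,
    # not just on the first trace.
    tf.print("add_one called with", x)
    return x + 1

add_one(tf.constant(1))  # prints
add_one(tf.constant(1))  # prints again, even though the trace is reused
```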

@MIHIRKHAMBETE
Author

For example,

import tensorflow as tf

@tf.function(jit_compile=True)
def print_ten():
  for i in range(10):
    print('printing ' + str(i))
  return

print_ten()
print_ten()

This will only print 10 times (corresponding to the first call). If the regular print is replaced with tf.print(), the code will NOT run, since XLA does not support tf.print at this time.

@bhack
Contributor
bhack commented Nov 4, 2021

Yes, printing on every function call is not supported under XLA. You can print the returned value outside of the compiled function.
I don't know if it is still kept up to date, but here you can find the list of supported ops:
#14798 (comment)
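The workaround suggested above (printing the returned value outside the compiled function) can be sketched as follows; this is a minimal illustration, not code from the thread, and the function name `double` is made up:

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def double(x):
    # Keep the XLA-compiled body free of PrintV2 ops; instead, return
    # whatever intermediate value you want to inspect.
    return x * 2

for i in range(3):
    y = double(tf.constant(i))
    # Plain Python print runs eagerly, outside the compiled graph,
    # so it executes on every iteration.
    print("step", i, "->", int(y.numpy()))
```

The trade-off is that only values threaded out through the function's return value can be observed this way; anything purely internal to the compiled body stays invisible.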

@MIHIRKHAMBETE
Author

Yup - that's why I made a feature request for adding tf.print() to the supported ops for XLA compilation.

@sachinprasadhs sachinprasadhs added stat:awaiting tensorflower Status - Awaiting response from tensorflower and removed stat:awaiting response Status - Awaiting response from author labels Nov 5, 2021
@wilecoyote2015

Is there a technical reason why tf.print is not supported under XLA?
I am really missing this, especially in custom training loops.

@rivershah

@sachinprasadhs any updates on this, please? tf.print does not work with XLA:

Detected unsupported operations when trying to compile graph __inference_run_step_2702[] on XLA_CPU_JIT: PrintV2 (No registered 'PrintV2' OpKernel for XLA_CPU_JIT devices compatible with node

@mohantym
Contributor

@MIHIRKHAMBETE!
I am able to replicate this issue with version 2.11.
Thank you!

@mohantym mohantym added TF 2.11 Issues related to TF 2.11 and removed 2.6.0 labels Jan 10, 2023
@chenmoneygithub
Contributor

@sachinprasadhs Please help forward this to the XLA or ML API team. Debugging graph-mode issues often relies on tf.print; since an XLA graph cannot run tf.print, it is extremely hard to debug when I see bugs specific to XLA graphs. Thanks!
