
[Autograph] Inconsistent behaviour with lambda variable in loop #56089

Open
bhack opened this issue May 12, 2022 · 23 comments · Fixed by #56119
Labels: comp:autograph (Autograph related issues) · stat:awaiting tensorflower (Status - Awaiting response from tensorflower) · type:bug (Bug)

bhack (Contributor) commented May 12, 2022

Issue Type: Bug
Source: source
Tensorflow Version: master
Custom Code: No
OS Platform and Distribution: No response
Mobile device: No response
Python version: No response
Bazel version: No response
GCC/Compiler version: No response
CUDA/cuDNN version: No response
GPU model and memory: No response
Current Behaviour?

We have inconsistent behavior with lambda variables in a loop between pure Python and graph mode:
https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result

Standalone code to reproduce the issue

import tensorflow as tf

def test_a():  # pure Python: late binding, all lambdas share the final i
  fns = []
  for i in range(3):
    fns.append(lambda: print(i))
  for f in fns:
    f()

@tf.function
def test_b():  # same code under tf.function
  fns = []
  for i in range(3):
    fns.append(lambda: print(i))
  for f in fns:
    f()

def test_c():  # pure Python: default argument captures the current i
  fns = []
  for i in range(3):
    fns.append(lambda i=i: print(i))
  for f in fns:
    f()

@tf.function
def test_d():  # default-argument variant under tf.function
  fns = []
  for i in range(3):
    fns.append(lambda i=i: print(i))
  for f in fns:
    f()

test_a()
print("=="*10)
tf.config.run_functions_eagerly(False)
test_b()  # graph mode
print("=="*10)
tf.config.run_functions_eagerly(True)
test_b()  # eager mode
print("=="*10)
test_c()
print("=="*10)
tf.config.run_functions_eagerly(False)
test_d()  # graph mode
print("=="*10)
tf.config.run_functions_eagerly(True)
test_d()  # eager mode
2
2
2
====================
0
1
2
====================
2
2
2
====================
0
1
2
====================
0
1
2
====================
0
1
2

Relevant log output

test_b wrongly works "as expected" (printing 0, 1, 2) in graph mode, while in eager mode it prints 2, 2, 2 like plain Python:

# coding=utf-8
def tf__test():
    with ag__.FunctionScope('test', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        fns = []

        def get_state():
            return ()

        def set_state(block_vars):
            pass

        def loop_body(itr):
            i = itr
            ag__.converted_call(ag__.ld(fns).append, (ag__.autograph_artifact((lambda : ag__.ld(print)(ag__.ld(i)))),), None, fscope)
        i = ag__.Undefined('i')
        ag__.for_stmt(ag__.converted_call(ag__.ld(range), (3,), None, fscope), None, loop_body, get_state, set_state, (), {'iterate_names': 'i'})

        def get_state_1():
            return ()

        def set_state_1(block_vars):
            pass

        def loop_body_1(itr_1):
            f = itr_1
            ag__.converted_call(ag__.ld(f), (), None, fscope)
        f = ag__.Undefined('f')
        ag__.for_stmt(ag__.ld(fns), None, loop_body_1, get_state_1, set_state_1, (), {'iterate_names': 'f'})
bhack commented May 12, 2022

See more at keras-team/keras-cv#432

/cc @mdanatg

bhack changed the title from "[Autograph] Inconsistent behaviour between lambda variable in loop" to "[Autograph] Inconsistent behaviour with lambda variable in loop" on May 12, 2022
bhack commented May 12, 2022

As referenced in the official Python FAQ:

> Note that this behaviour is not peculiar to lambdas, but applies to regular functions too.
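For instance, a minimal pure-Python sketch showing the same late-binding behavior with a regular nested function:

# Same late-binding quirk with a regular def instead of a lambda.
def test_def():
  fns = []
  for i in range(3):
    def f():
      print(i)
    fns.append(f)
  for f in fns:
    f()  # prints 2, 2, 2

test_def()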

mdanatg commented May 13, 2022

I think it's because we pass the loop variable through a function argument, which is enough to avoid the closure aliasing:

        def loop_body(itr):
            i = itr
            ag__.converted_call(ag__.ld(fns).append, (ag__.autograph_artifact((lambda : ag__.ld(i))),), None, fscope)

In the code above, the lambda would close over the local i which has a copy of the value.
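For illustration, here's a minimal pure-Python sketch of this mechanism (the names are hypothetical, not the actual AutoGraph runtime):

# Passing the iterate through a function argument gives each lambda
# its own copy of i, so the closure aliasing is broken.
def test_arg_copy():
  fns = []
  def loop_body(itr):
    i = itr  # fresh local per call: each lambda closes over its own cell
    fns.append(lambda: print(i))
  for itr in range(3):
    loop_body(itr)
  for f in fns:
    f()  # prints 0, 1, 2

test_arg_copy()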

This only happens for the for_stmt operator, which passes the iterate as an argument. If we rewrite the test as a while loop, so that i is closed over by the function body, things are once again quirky as intended:

import tensorflow as tf

def test():
  fns = []
  i = 0
  while i < 3:
    fns.append(lambda: print(i))
    i += 1
  for f in fns:
    f()

test()


tf.autograph.set_verbosity(0, True)

@tf.function(autograph=True)
def test():
  fns = []
  i = 0
  while i < 3:
    fns.append(lambda: print(i))
    i += 1
  for f in fns:
    f()

test()

This means that the fix is to avoid passing the iterate as an argument to the loop body, and instead rely on the get_state/set_state functions, as is the case for the while loop.
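A minimal pure-Python sketch of that shape (the driver loop here is a hypothetical stand-in, not the actual for_stmt implementation):

# The iterate is written through set_state instead of being passed as
# an argument, so all lambdas alias the same binding, as in plain Python.
def test_set_state():
  fns = []
  i = None

  def set_state(vars_):
    nonlocal i
    (i,) = vars_

  def loop_body():
    fns.append(lambda: print(i))

  for itr in range(3):  # stand-in for the for_stmt driver loop
    set_state((itr,))
    loop_body()
  for f in fns:
    f()  # prints 2, 2, 2 -- consistent with eager Python

test_set_state()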

tilakrayal added the comp:autograph (Autograph related issues) and stat:awaiting response (Status - Awaiting response from author) labels on May 13, 2022
bhack commented May 13, 2022

Yes, that's what I suspected:

from tensorflow.python.autograph.impl import api
ag__ = api._TRANSPILER.get_extra_locals()['ag__']  # pylint:disable=protected-access
def tf__test_for():
    with ag__.FunctionScope('test', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        fns = []

        def get_state():
            return ()

        def set_state(block_vars):
            pass

        def loop_body(itr):
            i = itr
            ag__.converted_call(ag__.ld(fns).append, (ag__.autograph_artifact((lambda : ag__.ld(print)((ag__.converted_call(ag__.ld(str), (ag__.ld(i),), None, fscope))))),), None, fscope)
        i = ag__.Undefined('i')
        ag__.for_stmt(ag__.converted_call(ag__.ld(range), (3,), None, fscope), None, loop_body, get_state, set_state, (), {'iterate_names': 'i'})

        def get_state_1():
            return ()

        def set_state_1(block_vars):
            pass

        def loop_body_1(itr_1):
            f = itr_1
            ag__.converted_call(ag__.ld(f), (), None, fscope)
        f = ag__.Undefined('f')
        ag__.for_stmt(ag__.ld(fns), None, loop_body_1, get_state_1, set_state_1, (), {'iterate_names': 'f'})

def tf__test_while():
    with ag__.FunctionScope('test_while', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        fns = []
        i = 0

        def get_state():
            return (i,)

        def set_state(vars_):
            nonlocal i
            (i,) = vars_

        def loop_body():
            nonlocal i
            ag__.converted_call(ag__.ld(fns).append, (ag__.autograph_artifact((lambda : ag__.ld(print)(ag__.ld(i)))),), None, fscope)
            i = ag__.ld(i)
            i += 1

        def loop_test():
            return (ag__.ld(i) < 3)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('i',), {})

        def get_state_1():
            return ()

        def set_state_1(block_vars):
            pass

        def loop_body_1(itr):
            f = itr
            ag__.converted_call(ag__.ld(f), (), None, fscope)
        f = ag__.Undefined('f')
        ag__.for_stmt(ag__.ld(fns), None, loop_body_1, get_state_1, set_state_1, (), {'iterate_names': 'f'})
tf__test_for()
print("====")
tf__test_while()
0
1
2
====
3
3
3

And modifying your while example with a default argument correctly gives the output we "expect":

import tensorflow as tf

def test():
  fns = []
  i = 0
  while i < 3:
    fns.append(lambda i=i: print(i))
    i += 1
  for f in fns:
    f()

test()


tf.autograph.set_verbosity(0, True)

@tf.function(autograph=True)
def test():
  fns = []
  i = 0
  while i < 3:
    fns.append(lambda i=i: print(i))
    i += 1
  for f in fns:
    f()

test()
0
1
2
0
1
2

tilakrayal added the stat:awaiting tensorflower (Status - Awaiting response from tensorflower) label and removed the stat:awaiting response (Status - Awaiting response from author) label on May 13, 2022
bhack commented May 13, 2022

@mdanatg I've manually rewritten the transformation output. Do we need to produce something like this for the lambda: case?

from tensorflow.python.autograph.impl import api
ag__ = api._TRANSPILER.get_extra_locals()['ag__']  # pylint:disable=protected-access
def tf__test_for():
    with ag__.FunctionScope('test', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        fns = []
        i = 0
        def get_state():
            return (i,)

        def set_state(block_vars):
            nonlocal i
            (i,) = block_vars

        def loop_body(itr):
            nonlocal i
            ag__.converted_call(ag__.ld(fns).append, (ag__.autograph_artifact((lambda : ag__.ld(print)((ag__.converted_call(ag__.ld(str), (ag__.ld(i),), None, fscope))))),), None, fscope)
            i = ag__.ld(i)
            i += 1
        ag__.for_stmt(ag__.converted_call(ag__.ld(range), (3,), None, fscope), None, loop_body, get_state, set_state, (), {'iterate_names': 'i'})

        def get_state_1():
            return ()

        def set_state_1(block_vars):
            pass

        def loop_body_1(itr_1):
            f = itr_1
            ag__.converted_call(ag__.ld(f), (), None, fscope)
        f = ag__.Undefined('f')
        ag__.for_stmt(ag__.ld(fns), None, loop_body_1, get_state_1, set_state_1, (), {'iterate_names': 'f'})

mdanatg commented May 13, 2022

Yes, something like that. The loop_body function would then be a regular thunk: def loop_body():. We may also need to initialize i with Undefined rather than 0, and we might also need to replace (), {'iterate_names': 'i'} with just ('i',), but I'm not sure.

bhack commented May 13, 2022

Do we want to handle the iterate as nonlocal? In the helper function below, it only ends up in the undefined bucket:

def _get_block_vars(self, node, modified):
  """Determines the variables affected inside a control flow statement."""
  defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
  live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
  live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
  fn_scope = self.state[_Function].scope
  basic_scope_vars = self._get_block_basic_vars(
      modified,
      live_in,
      live_out)
  composite_scope_vars = self._get_block_composite_vars(modified, live_in)
  scope_vars = tuple(basic_scope_vars | composite_scope_vars)
  # Variables that are modified inside the scope, but not defined
  # before entering it. Only simple variables must be defined. The
  # composite ones will be implicitly checked at runtime.
  possibly_undefined = (
      modified - defined_in - fn_scope.globals - fn_scope.nonlocals)
  undefined = tuple(v for v in possibly_undefined if not v.is_composite())
  # Variables that are modified inside the scope, and depend on values outside
  # it.
  input_only = basic_scope_vars & live_in - live_out
  # Place the outputs first, then sort lexicographically.
  scope_vars = sorted(scope_vars, key=lambda v: (v in input_only, v))
  nouts = len(scope_vars) - len(input_only)
  return scope_vars, undefined, nouts

mdanatg commented May 13, 2022

Yes, I think we do.

bhack commented May 13, 2022

I don't know if it makes sense, but I've changed:

      if s in live_in or s in live_out or s in nonlocals or \
        (s not in live_in and s not in live_out):
With this change:

import tensorflow as tf
tf.autograph.set_verbosity(0, True)
@tf.function
def test_b():
  fns = []
  for i in range(3):
    fns.append(lambda: print(i))
  for f in fns:
    f()
test_b()
2
2
2

And

import tensorflow as tf
tf.autograph.set_verbosity(0, True)
@tf.function
def test_b():
  fns = []
  for i in range(3):
    fns.append(lambda i=i: print(i))
  for f in fns:
    f()
test_b()
0
1
2

So the output is correct, but the transformation is surely still quite ugly:

# coding=utf-8
def tf__test_b():
    with ag__.FunctionScope('test_b', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        fns = []

        def get_state():
            return (i,)

        def set_state(vars_):
            nonlocal i
            (i,) = vars_

        def loop_body(itr):
            nonlocal i
            i = itr
            ag__.converted_call(ag__.ld(fns).append, (ag__.autograph_artifact((lambda : ag__.ld(print)(ag__.ld(i)))),), None, fscope)
        i = ag__.Undefined('i')
        ag__.for_stmt(ag__.converted_call(ag__.ld(range), (3,), None, fscope), None, loop_body, get_state, set_state, ('i',), {'iterate_names': 'i'})

        def get_state_1():
            return (f,)

        def set_state_1(vars_):
            nonlocal f
            (f,) = vars_

        def loop_body_1(itr_1):
            nonlocal f
            f = itr_1
            ag__.converted_call(ag__.ld(f), (), None, fscope)
        f = ag__.Undefined('f')
        ag__.for_stmt(ag__.ld(fns), None, loop_body_1, get_state_1, set_state_1, ('f',), {'iterate_names': 'f'})

bhack commented May 13, 2022

P.S. As a side note, the loop scoping integration test is not registered in the BUILD file. How is the CI currently running it?:

# # Scoping and modularity

If it is activated, some loop_scoping tests fail:

# # Scoping and modularity
reference_test(name = "loop_scoping_test")

mdanatg commented May 13, 2022

Ah, that's not intended. Can you add it to the BUILD file, mark the failing ones with self.skipTest, and then file an issue to get them to pass?
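Something like this, perhaps (a sketch; the test class and method names here are assumptions, not the actual loop_scoping_test contents):

class LoopScopingTest(reference_test_base.TestCase):  # hypothetical name

  def test_lambda_closing_over_loop_variable(self):  # hypothetical case
    self.skipTest('Known failure, see #56089: lambdas in a for loop '
                  'capture a per-iteration copy of the iterate.')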

For the transformation, I'm not sure; the rules are quite finicky, and I wouldn't change them without extensive testing. It's likely still safer to manually add the iterate to the list of loop_vars.

mdanatg commented May 23, 2022

Side note - I just realized the limitations section of autograph does seem to document this case: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#variables-closed-over-by-lambda-functions


clime commented Feb 4, 2023

Hello, I have just discovered this in my app logs:

WARNING:tensorflow:From /home/clime/.virtualenvs/keras/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089

Is there anything I should look for in my code, or is this a generic warning shown to everyone? To be more precise, does the fact that I am seeing this warning mean that I am actually using something that is now deprecated?

bhack commented Feb 4, 2023

@clime We hope to align with the Python behavior after the deprecation:

https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result
clime commented Feb 4, 2023

> @clime We hope to align with the Python behavior after the deprecation:
> https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result

My question rather is whether seeing that warning on my screen necessarily means that I am triggering the deprecated behavior (i.e. that I am using a lambda somewhere in a @tf.function or similar).

bhack commented Feb 4, 2023

Do you have a small code gist to reproduce this?

clime commented Feb 4, 2023

> Do you have a small code gist to reproduce this?

Sorry, I don't understand what I should be reproducing. I just have a pretty big code base, and I want to know how relevant the warning is for me :).

clime commented Feb 4, 2023

I guess the answer is here: 6197fa3

bhack commented Feb 4, 2023

If it was internal, someone introduced a new deprecation case internally while ignoring the warning.

sokrypton commented
Are there any instructions on how to hide this warning? :P

python3.7/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089

bhack commented Feb 14, 2023

It is suppressed in the nightly build and will be in the next release.
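For older releases, a minimal workaround sketch (assuming the warning is emitted through TensorFlow's Python logger; note this also hides other TF warnings):

import logging
import tensorflow as tf

# Raise the TF logger threshold so deprecation warnings are dropped.
tf.get_logger().setLevel(logging.ERROR)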

MaximeLee commented
Hi,
How can one disable this warning so that it is not printed? I am using TF 2.11.1 and I can't really change my TF version.

> It is suppressed in the nightly build and will be in the next release.
