(No publicly visible changes) #70385

Draft · wants to merge 1 commit into base: master
30 changes: 30 additions & 0 deletions tensorflow/python/eager/context.py
@@ -93,6 +93,7 @@
_XLA_SHARDING_FOR_RESOURCE_VARIABLES = (
os.getenv("TF_XLA_SHARDING_FOR_RESOURCE_VARIABLES") == "1"
)
_OPTIONALS = os.getenv("TF_OPTIONALS") != "0"


def run_eager_op_as_function_enabled():
@@ -152,6 +153,26 @@ def xla_sharding_for_resource_variables_enabled():
return _XLA_SHARDING_FOR_RESOURCE_VARIABLES


def enable_optionals():
global _OPTIONALS
_OPTIONALS = True
if context_safe() is not None:
context_safe().optionals = True


def disable_optionals():
global _OPTIONALS
_OPTIONALS = False
if context_safe() is not None:
context_safe().optionals = False


def optionals_enabled():
if context_safe() is not None:
return context_safe().optionals
return _OPTIONALS


@contextlib.contextmanager
def temporarily_disable_xla_sharding_for_resource_variables():
"""Temporarily disables XLA sharding for resource variables.
@@ -553,6 +574,7 @@ def __init__(
self._xla_sharding_for_resource_variables = (
xla_sharding_for_resource_variables_enabled()
)
self._optionals = optionals_enabled()
self._server_def = server_def
self._collective_ops_server_def = None
self._collective_leader = None
@@ -2232,6 +2254,14 @@ def xla_sharding_for_resource_variables(self):
def xla_sharding_for_resource_variables(self, enable):
self._xla_sharding_for_resource_variables = enable

@property
def optionals(self):
return self._optionals

@optionals.setter
def optionals(self, enable):
self._optionals = enable

@property
def device_policy(self):
# Only get the policy from the context if it has already been initialized
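Taken together, the context.py changes add a process-wide toggle: the module-level `_OPTIONALS` global seeds each new `Context`, and the `enable_optionals()` / `disable_optionals()` helpers keep the global and any live context in sync. A minimal sketch of how the toggle would be exercised, assuming this draft lands as shown (the flag defaults to on unless `TF_OPTIONALS=0` is set before import):

```python
import os

# Assumption: TF_OPTIONALS is read once when context.py is imported,
# so it must be set before TensorFlow is imported.
os.environ["TF_OPTIONALS"] = "0"

from tensorflow.python.eager import context

print(context.optionals_enabled())   # False: falls back to the module global

context.enable_optionals()           # flips the global and, if a context
print(context.optionals_enabled())   # already exists, its `optionals` property

context.disable_optionals()
print(context.optionals_enabled())   # False again
```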
1 change: 1 addition & 0 deletions tensorflow/python/ops/BUILD
@@ -1385,6 +1385,7 @@ py_strict_library(
":optional_ops_gen",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:backprop_util",
"//tensorflow/python/eager:context",
"//tensorflow/python/framework:auto_control_deps",
"//tensorflow/python/framework:auto_control_deps_utils",
"//tensorflow/python/framework:constant_op",
16 changes: 13 additions & 3 deletions tensorflow/python/ops/cond_v2.py
@@ -23,6 +23,7 @@

from tensorflow.core.framework import types_pb2
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
@@ -138,7 +139,10 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
# NOTE(skyewm): if there are any active sessions, this modification to `op`
# may make them unrunnable!

if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
if (
not context.optionals_enabled()
or control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())
):
# XLA does not yet support optionals, so output intermediates directly and
# make them match via FakeParams, which can be converted to zeros in XLA.
# TODO(skyewm,jpienaar): can XLA support optionals?
@@ -1032,7 +1036,10 @@ def _capture_helper(self, tensor, name):
tensor_util.constant_value(tensor), dtype=tensor.dtype)
return self._captured_constants[tensor_id]

if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
if (
not context.optionals_enabled()
or control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())
):
# XLA does not yet support optionals, so capture intermediates directly.
# TODO(skyewm,jpienaar): can XLA support optionals?
if all(tensor is not capture for capture in self.external_captures):
@@ -1163,7 +1170,10 @@ def _CaseGrad(op, *grads): # pylint: disable=invalid-name
# NOTE(bjp): if there are any active sessions, this modification to `op`
# may make them unrunnable!

if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
if (
not context.optionals_enabled()
or control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())
):
# XLA does not yet support optionals, so output intermediates directly and
# make them match via FakeParams, which can be converted to zeros in XLA.
# TODO(bjp,jpienaar): can XLA support optionals?
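In all three gradient paths (`_IfGrad`, `_capture_helper`, `_CaseGrad`), the new `not context.optionals_enabled()` clause means that disabling the flag routes cond/case gradients through the same direct-intermediate/FakeParams path that XLA already uses, instead of wrapping branch intermediates in optionals. A hedged sketch of code that would exercise the guarded path (illustrative only, not part of the PR):

```python
import tensorflow as tf
from tensorflow.python.eager import context

context.disable_optionals()  # force the direct-intermediate (XLA-style) path

@tf.function
def grad_of_cond(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    # Differentiating a cond triggers _IfGrad; with optionals disabled,
    # branch intermediates are output directly rather than as optionals.
    y = tf.cond(x > 0.0, lambda: x * x, lambda: -x)
  return tape.gradient(y, x)

print(grad_of_cond(tf.constant(3.0)))  # tf.Tensor(6.0, ...)
```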