[go: nahoru, domu]

Skip to content

Commit

Permalink
(No publicly visible changes)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 646571492
  • Loading branch information
matthiaskramm authored and tensorflower-gardener committed Jun 26, 2024
1 parent b572b20 commit 6ae557f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
30 changes: 30 additions & 0 deletions tensorflow/python/eager/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
# Opt-in flag: XLA sharding for resource variables is only turned on when the
# environment variable is set to exactly "1" (unset means off).
_XLA_SHARDING_FOR_RESOURCE_VARIABLES = (
    os.environ.get("TF_XLA_SHARDING_FOR_RESOURCE_VARIABLES") == "1"
)
# Opt-out flag: optionals stay enabled unless TF_OPTIONALS is set to exactly
# "0" (unset, or any other value, means on).
_OPTIONALS = os.environ.get("TF_OPTIONALS") != "0"


def run_eager_op_as_function_enabled():
Expand Down Expand Up @@ -152,6 +153,26 @@ def xla_sharding_for_resource_variables_enabled():
return _XLA_SHARDING_FOR_RESOURCE_VARIABLES


def enable_optionals():
  """Turns the `optionals` flag on.

  Updates the module-level default, which `optionals_enabled()` returns when
  no context exists. If a context already exists, its `optionals` flag is
  updated too. When the flag is off, cond_v2 gradients output/capture
  intermediates directly instead of using optionals.
  """
  global _OPTIONALS
  _OPTIONALS = True
  # Look the context up once so the None check and the update act on the
  # same object (and context_safe() is not called redundantly).
  ctx = context_safe()
  if ctx is not None:
    ctx.optionals = True


def disable_optionals():
  """Turns the `optionals` flag off.

  Updates the module-level default, which `optionals_enabled()` returns when
  no context exists. If a context already exists, its `optionals` flag is
  updated too. With the flag off, cond_v2 gradients output/capture
  intermediates directly, the same path used under XLA.
  """
  global _OPTIONALS
  _OPTIONALS = False
  # Look the context up once so the None check and the update act on the
  # same object (and context_safe() is not called redundantly).
  ctx = context_safe()
  if ctx is not None:
    ctx.optionals = False


def optionals_enabled():
  """Returns whether the `optionals` flag is on.

  If a context exists, its `optionals` flag is authoritative. Otherwise the
  module-level default `_OPTIONALS` is returned. That default is initialized
  from the TF_OPTIONALS environment variable (on unless it is "0") and
  changed by enable_optionals()/disable_optionals().
  """
  # Single lookup: the None check and the read see the same context.
  ctx = context_safe()
  if ctx is not None:
    return ctx.optionals
  return _OPTIONALS


@contextlib.contextmanager
def temporarily_disable_xla_sharding_for_resource_variables():
"""Temporarily disables XLA sharding for resource variables.
Expand Down Expand Up @@ -553,6 +574,7 @@ def __init__(
self._xla_sharding_for_resource_variables = (
xla_sharding_for_resource_variables_enabled()
)
self._optionals = optionals_enabled()
self._server_def = server_def
self._collective_ops_server_def = None
self._collective_leader = None
Expand Down Expand Up @@ -2232,6 +2254,14 @@ def xla_sharding_for_resource_variables(self):
def xla_sharding_for_resource_variables(self, enable):
self._xla_sharding_for_resource_variables = enable

@property
def optionals(self):
  """Whether optionals are enabled for this object.

  Reads the value kept in `self._optionals`.
  """
  return self._optionals

@optionals.setter
def optionals(self, enable):
  """Stores `enable` as the new `optionals` value (no type coercion)."""
  self._optionals = enable

@property
def device_policy(self):
# Only get the policy from the context if it has already been initialized
Expand Down
1 change: 1 addition & 0 deletions tensorflow/python/ops/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,7 @@ py_strict_library(
":optional_ops_gen",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:backprop_util",
"//tensorflow/python/eager:context",
"//tensorflow/python/framework:auto_control_deps",
"//tensorflow/python/framework:auto_control_deps_utils",
"//tensorflow/python/framework:constant_op",
Expand Down
16 changes: 13 additions & 3 deletions tensorflow/python/ops/cond_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from tensorflow.core.framework import types_pb2
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
Expand Down Expand Up @@ -138,7 +139,10 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
# NOTE(skyewm): if there are any active sessions, this modification to `op`
# may make them unrunnable!

if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
if (
not context.optionals_enabled()
or control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())
):
# XLA does not yet support optionals, so output intermediates directly and
# make them match via FakeParams, which can be converted to zeros in XLA.
# TODO(skyewm,jpienaar): can XLA support optionals?
Expand Down Expand Up @@ -1032,7 +1036,10 @@ def _capture_helper(self, tensor, name):
tensor_util.constant_value(tensor), dtype=tensor.dtype)
return self._captured_constants[tensor_id]

if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
if (
not context.optionals_enabled()
or control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())
):
# XLA does not yet support optionals, so capture intermediates directly.
# TODO(skyewm,jpienaar): can XLA support optionals?
if all(tensor is not capture for capture in self.external_captures):
Expand Down Expand Up @@ -1163,7 +1170,10 @@ def _CaseGrad(op, *grads): # pylint: disable=invalid-name
# NOTE(bjp): if there are any active sessions, this modification to `op`
# may make them unrunnable!

if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
if (
not context.optionals_enabled()
or control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph())
):
# XLA does not yet support optionals, so output intermediates directly and
# make them match via FakeParams, which can be converted to zeros in XLA.
# TODO(bjp,jpienaar): can XLA support optionals?
Expand Down

0 comments on commit 6ae557f

Please sign in to comment.