[go: nahoru, domu]

Skip to content

Commit

Permalink
Improve effect support on internal backends.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 484566725
  • Loading branch information
hyeontaek authored and jax authors committed Oct 28, 2022
1 parent 2cc4fd2 commit 54967e9
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1631,7 +1631,8 @@ def from_hlo(xla_computation,
if hasattr(pci.backend, "compile_replicated"):
return _compile_replicated_pmap_executable_from_hlo(
xla_computation, pci, input_indices, in_shardings, handle_outs,
- compile_options, host_callbacks)
+ compile_options, host_callbacks, bool(unordered_effects),
+ ordered_effects)

with dispatch.log_elapsed_time(
f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec"):
Expand Down Expand Up @@ -3316,8 +3317,8 @@ def from_hlo(name: str,
return _compile_replicated_mesh_executable_from_hlo(
name, computation, global_in_avals, global_out_avals, in_shardings,
out_shardings, in_is_global, auto_spmd_lowering, compile_options,
- host_callbacks, kept_var_idx, backend, device_assignment, committed,
- pmap_nreps)
+ host_callbacks, bool(unordered_effects), ordered_effects,
+ kept_var_idx, backend, device_assignment, committed, pmap_nreps)
else:
with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} "
"in {elapsed_time} sec"):
Expand Down Expand Up @@ -3460,17 +3461,18 @@ def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args
return out_handler(in_handler(outs))


- def _compile_replicated_pmap_executable_from_hlo(xla_computation, pci,
- input_indices, in_shardings,
- handle_outs, compile_options,
- host_callbacks):
+ def _compile_replicated_pmap_executable_from_hlo(
+ xla_computation, pci, input_indices, in_shardings, handle_outs,
+ compile_options, host_callbacks, has_unordered_effects, ordered_effects):
# Use the standard out_handler.
execute_fun = pci.backend.compile_replicated(
is_trivial=False, name=pci.name, computation=xla_computation,
compile_options=compile_options, host_callbacks=host_callbacks,
- in_avals=pci.avals, in_indices=input_indices,
- in_shardings=in_shardings, kept_var_idx=set(range(len(pci.avals))),
- mode=InputsHandlerMode.pmap, out_handler=handle_outs)
+ has_unordered_effects=has_unordered_effects,
+ ordered_effects=ordered_effects, in_avals=pci.avals,
+ in_indices=input_indices, in_shardings=in_shardings,
+ kept_var_idx=set(range(len(pci.avals))), mode=InputsHandlerMode.pmap,
+ out_handler=handle_outs)
# TODO(frostig): need `compile_replicated` to give us the XLA executable
return PmapExecutable(None, execute_fun, None, pci.avals)

Expand All @@ -3479,8 +3481,8 @@ def _compile_replicated_pmap_executable_from_hlo(xla_computation, pci,
def _compile_replicated_mesh_executable_from_hlo(
name, computation, global_in_avals, global_out_avals, in_shardings,
out_shardings, in_is_global, auto_spmd_lowering, compile_options,
- host_callbacks, kept_var_idx, backend, device_assignment, committed,
- pmap_nreps):
+ host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx,
+ backend, device_assignment, committed, pmap_nreps):
assert not auto_spmd_lowering
in_shardings, input_indices, input_avals = _get_input_metadata(
global_in_avals, in_shardings, in_is_global) # type: ignore
Expand All @@ -3493,10 +3495,12 @@ def _compile_replicated_mesh_executable_from_hlo(
unsafe_call = backend.compile_replicated(
is_trivial=False, name=name, computation=computation,
compile_options=compile_options, host_callbacks=host_callbacks,
- in_avals=input_avals, in_indices=input_indices,
- in_shardings=in_shardings, kept_var_idx=kept_var_idx,
- mode=InputsHandlerMode.pjit_or_xmap, out_avals=global_out_avals,
- out_shardings=out_shardings, committed=committed)
+ has_unordered_effects=has_unordered_effects,
+ ordered_effects=ordered_effects, in_avals=input_avals,
+ in_indices=input_indices, in_shardings=in_shardings,
+ kept_var_idx=kept_var_idx, mode=InputsHandlerMode.pjit_or_xmap,
+ out_avals=global_out_avals, out_shardings=out_shardings,
+ committed=committed)
xla_executable = None
return MeshExecutable(xla_executable, unsafe_call, input_avals,
in_shardings, out_shardings, auto_spmd_lowering,
Expand Down

0 comments on commit 54967e9

Please sign in to comment.