[go: nahoru, domu]

Skip to content

Commit

Permalink
Move JAX builds to build.py
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 642425106
  • Loading branch information
ddunl authored and tensorflower-gardener committed Jun 11, 2024
1 parent 7d37407 commit 8234a8d
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 90 deletions.
87 changes: 1 addition & 86 deletions third_party/xla/.kokoro/jax/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,89 +19,4 @@
# -o history: record shell history
set -euox pipefail -o history

# Builds + tests jaxlib against CL/PR version of XLA + JAX main.

source "${KOKORO_GFILE_DIR}/utils.sh"

# Succeeds (exit status 0) when $KOKORO_JOB_NAME contains a match for the
# JAX GPU job pattern below; returns non-zero otherwise.
is_linux_gpu_job() {
  local gpu_job_regex='tensorflow/xla/jax/.*gpu.*'
  # The regex is left unquoted on purpose: quoting the right-hand side of =~
  # would turn it into a literal string comparison.
  [[ "${KOKORO_JOB_NAME}" =~ ${gpu_job_regex} ]]
}

# Clones the google/jax GitHub repository into the current directory.
# No branch or depth options are passed, so git's defaults apply.
clone_main_jax() {
  local jax_repo_url="https://github.com/google/jax.git"
  git clone "${jax_repo_url}"
}

# Prepares the environment for the JAX build: exports JAX-related environment
# variables, sets up Python and its packages, clones JAX, updates bazel, and
# leaves the shell inside the `jax` checkout.
#
# NOTE(review): setup_env_vars_py39, setup_env_vars_py312,
# use_local_or_install_python, install_packages and update_bazel_linux are not
# defined in this script; presumably they come from the utils.sh sourced
# above -- verify.
prelude() {
export JAX_ENABLE_X64=0

# GPU jobs pin CUDA/cuDNN versions, print GPU status via nvidia-smi, and use
# the py39 environment setup; all other jobs use the py312 setup.
if is_linux_gpu_job ; then
export JAX_CUDA_VERSION=12
export JAX_CUDNN_VERSION=9.1
nvidia-smi
setup_env_vars_py39
else
setup_env_vars_py312
fi

# Everything below (including the JAX clone) happens under the artifacts dir.
cd "${KOKORO_ARTIFACTS_DIR}"

use_local_or_install_python
# NUMPY_VERSION / SCIPY_VERSION are not assigned in this script; presumably
# they are set by the setup_env_vars_* helper called above -- confirm.
install_packages "$NUMPY_VERSION" "$SCIPY_VERSION"
clone_main_jax
# Install bazel
update_bazel_linux

# Enter the checkout created by clone_main_jax so later bazel commands run there.
cd jax

}

# Runs the JAX CPU and backend-independent test targets with `bazel test`,
# overriding the `xla` repository with the checkout located at
# ${KOKORO_ARTIFACTS_DIR}/github/xla.
build_and_test_on_rbe_cpu() {
  # Flags are collected in an array so each one stays a separate argv word.
  local -a bazel_flags=(
    --verbose_failures=true
    --override_repository=xla="${KOKORO_ARTIFACTS_DIR}"/github/xla
    --config=avx_posix
    --config=mkl_open_source_only
    --config="rbe_cpu_linux_py3.12"
    --config=tensorflow_testing_rbe_linux
    --test_env=JAX_NUM_GENERATED_CASES=25
    --test_output=errors
  )
  local -a test_targets=(
    //tests:cpu_tests
    //tests:backend_independent_tests
  )

  # Run the tests.
  bazel test "${bazel_flags[@]}" -- "${test_targets[@]}"
}

# Runs the JAX GPU and backend-independent test targets with `bazel test`,
# overriding the `xla` repository with the checkout located at
# ${KOKORO_ARTIFACTS_DIR}/github/xla. Tests tagged `multiaccelerator` are
# filtered out via --test_tag_filters.
#
# NOTE(review): an earlier comment here said "--run_under needs an absolute
# path", but no --run_under flag is passed by this command.
build_and_test_on_rbe_gpu() {
  # Flags are collected in an array so each one stays a separate argv word.
  local -a bazel_flags=(
    --verbose_failures=true
    --override_repository=xla="${KOKORO_ARTIFACTS_DIR}"/github/xla
    --config=avx_posix
    --config=mkl_open_source_only
    --config="rbe_linux_cuda12.3_nvcc_py3.9"
    --config=tensorflow_testing_rbe_linux
    --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
    --test_output=errors
    --test_env=JAX_SKIP_SLOW_TESTS=1
    --test_env=TF_CPP_MIN_LOG_LEVEL=0
    --test_env=JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow"
    --test_tag_filters=-multiaccelerator
  )
  local -a test_targets=(
    //tests:gpu_tests
    //tests:backend_independent_tests
  )

  bazel test "${bazel_flags[@]}" -- "${test_targets[@]}"
}

# Generate a templated results file to make output accessible to everyone
"$KOKORO_ARTIFACTS_DIR"/github/xla/.kokoro/generate_index_html.sh "$KOKORO_ARTIFACTS_DIR"/index.html

# Set up the environment and clone JAX; prelude ends by cd-ing into the jax
# checkout, so the bazel commands below run from there.
prelude

# Pick the GPU or the CPU test run based on the Kokoro job name.
if is_linux_gpu_job ; then
build_and_test_on_rbe_gpu
else
build_and_test_on_rbe_cpu
fi

# Report where bazel put its test logs.
#
# `find` exits 0 even when nothing matches, so the "not found" message must be
# driven by whether find printed anything, not by its exit status. `|| true`
# keeps a find error (e.g. an unreadable directory) from aborting the script
# under `set -e`; any paths found before the error are still reported.
echo "bazel-testlogs (test results) location:"
testlogs_paths="$(find "$KOKORO_ARTIFACTS_DIR" \
  -type l,d -name bazel-testlogs || true)"
if [[ -n "${testlogs_paths}" ]]; then
  echo "${testlogs_paths}"
else
  echo "bazel-testlogs not found"
fi
"$KOKORO_ARTIFACTS_DIR"/github/xla/build_tools/build.py
72 changes: 68 additions & 4 deletions third_party/xla/build_tools/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class BuildType(enum.Enum):
GPU = enum.auto()
GPU_CONTINUOUS = enum.auto()

JAX_CPU = enum.auto()
JAX_GPU = enum.auto()


@dataclasses.dataclass(frozen=True, **_KW_ONLY_IF_PYTHON310)
class DockerImage:
Expand Down Expand Up @@ -157,6 +160,7 @@ class Build:
configs: Tuple[str, ...] = ()
tag_filters: Tuple[str, ...] = ()
action_env: Dict[str, Any] = dataclasses.field(default_factory=dict)
test_env: Dict[str, Any] = dataclasses.field(default_factory=dict)
options: Dict[str, Any] = dataclasses.field(default_factory=dict)

def bazel_test_command(self) -> List[str]:
Expand All @@ -165,9 +169,10 @@ def bazel_test_command(self) -> List[str]:
build_tag_filters = f"--build_tag_filters={','.join(self.tag_filters)}"
test_tag_filters = f"--test_tag_filters={','.join(self.tag_filters)}"
action_env = [f"--action_env={k}={v}" for k, v in self.action_env.items()]
all_options = (
[build_tag_filters, test_tag_filters] + configs + action_env + options
)
test_env = [f"--test_env={k}={v}" for k, v in self.test_env.items()]

tag_filters = [build_tag_filters, test_tag_filters]
all_options = tag_filters + configs + action_env + test_env + options
return ["bazel", "test", *all_options, "--", *self.target_patterns]


Expand Down Expand Up @@ -255,13 +260,63 @@ def nvidia_gpu_build_with_compute_capability(
type_=BuildType.GPU_CONTINUOUS, compute_capability=80
)

# Build definition for testing JAX (repo google/jax) on CPU against this XLA
# checkout: runs JAX's CPU and backend-independent test targets.
_JAX_CPU_BUILD = Build(
    type_=BuildType.JAX_CPU,
    repo="google/jax",
    docker_image=_DEFAULT_IMAGE,
    configs=(
        "avx_posix",
        "mkl_open_source_only",
        "rbe_cpu_linux_py3.12",
        "tensorflow_testing_rbe_linux",
    ),
    target_patterns=(
        "//tests:cpu_tests",
        "//tests:backend_independent_tests",
    ),
    # NOTE(review): the legacy .kokoro/jax/build.sh CPU command only set
    # JAX_NUM_GENERATED_CASES; JAX_SKIP_SLOW_TESTS is new here -- confirm
    # that skipping slow tests on CPU is intended.
    test_env={
        "JAX_NUM_GENERATED_CASES": 25,
        "JAX_SKIP_SLOW_TESTS": 1,
    },
    options={
        "verbose_failures": True,
        "test_output": "errors",
        # Presumably the path of the XLA checkout inside the docker
        # container -- verify against the container's mount layout.
        "override_repository": "xla=/github/xla",
        "profile": "profile.json.gz",
    },
)

# Build definition for testing JAX (repo google/jax) on GPU against this XLA
# checkout: runs JAX's GPU and backend-independent test targets, excluding
# anything tagged `multiaccelerator`.
_JAX_GPU_BUILD = Build(
    type_=BuildType.JAX_GPU,
    repo="google/jax",
    docker_image=_DEFAULT_IMAGE,
    configs=(
        "avx_posix",
        "mkl_open_source_only",
        "rbe_linux_cuda12.3_nvcc_py3.9",
        "tensorflow_testing_rbe_linux",
    ),
    target_patterns=(
        "//tests:gpu_tests",
        "//tests:backend_independent_tests",
    ),
    # NOTE(review): the legacy .kokoro/jax/build.sh GPU command passed this
    # only as --test_tag_filters; if tag_filters also feeds
    # --build_tag_filters, the filter now applies at build time too -- confirm.
    tag_filters=("-multiaccelerator",),
    # NOTE(review): the legacy GPU command also set
    # XLA_PYTHON_CLIENT_ALLOCATOR=platform, which is not carried over here --
    # confirm the omission is intended.
    test_env={
        "JAX_SKIP_SLOW_TESTS": 1,
        "TF_CPP_MIN_LOG_LEVEL": 0,
        "JAX_EXCLUDE_TEST_TARGETS": "PmapTest.testSizeOverflow",
    },
    options={
        "verbose_failures": True,
        "test_output": "errors",
        # Presumably the path of the XLA checkout inside the docker
        # container -- verify against the container's mount layout.
        "override_repository": "xla=/github/xla",
        "profile": "profile.json.gz",
    },
)

# Maps a Kokoro job name to the Build definition that job should run.
# Several job names may share one Build (e.g. presubmit and continuous CPU).
_KOKORO_JOB_NAME_TO_BUILD_MAP = {
# XLA presubmit jobs.
"tensorflow/xla/linux/arm64/build_cpu": _CPU_ARM64_BUILD,
"tensorflow/xla/linux/cpu/build_cpu": _CPU_X86_BUILD,
"tensorflow/xla/linux/gpu/build_gpu": _GPU_BUILD,
# XLA github_continuous jobs.
"tensorflow/xla/linux/github_continuous/arm64/build_cpu": _CPU_ARM64_BUILD,
"tensorflow/xla/linux/github_continuous/build_gpu": _GPU_CONTINUOUS_BUILD,
"tensorflow/xla/linux/github_continuous/build_cpu": _CPU_X86_BUILD,
# JAX jobs.
"tensorflow/xla/jax/cpu/build_cpu": _JAX_CPU_BUILD,
"tensorflow/xla/jax/gpu/build_gpu": _JAX_GPU_BUILD,
}


Expand All @@ -273,6 +328,16 @@ def main():

sh(["./github/xla/.kokoro/generate_index_html.sh", "index.html"])

_, repo_name = build.repo.split("/")
if build.repo != "openxla/xla":
sh([
"git",
"clone",
"--depth=1",
f"https://github.com/{build.repo}",
f"./github/{repo_name}",
])

# TODO(b/338885148): Remove this block after TF was updated to cuDNN 9
if build.type_ in (BuildType.GPU, BuildType.GPU_CONTINUOUS):
sh(
Expand All @@ -284,7 +349,6 @@ def main():
],
)

_, repo_name = build.repo.split("/")
with build.docker_image.pull_and_run(
workdir=f"/github/{repo_name}", **_DEFAULT_DOCKER_OPTIONS
) as docker_exec:
Expand Down

0 comments on commit 8234a8d

Please sign in to comment.