XLA is a compiler used internally by JAX. JAX is distributed via PyPI wheels. The JAX Continuous Integration documentation explains how to build JAX wheels using the tensorflow/ml-build:latest Docker container.
We can extend these instructions to build XLA targets within the JAX container as well. This keeps the XLA targets' build configuration consistent with the JAX/XLA build configuration, which is useful when we want to use XLA tools to reproduce results of workloads that were originally created in JAX.
- Clone the JAX repository and navigate to the 'jax' directory:
  ```shell
  git clone https://github.com/jax-ml/jax.git
  cd jax
  ```
- Start the JAX CI/Release Docker container by running:
  ```shell
  ./ci/utilities/run_docker_container.sh
  ```
  This starts a Docker container named 'jax'.
- Build the jax-cuda-plugin target inside the container using:
  ```shell
  docker exec jax ./ci/build_artifacts.sh jax-cuda-plugin
  ```
  This creates the `.jax_configure.bazelrc` file with the required build configuration, including CUDA/cuDNN support.
- Access an interactive shell inside the container:
  ```shell
  docker exec -ti jax /bin/bash
  ```
  You should now be in the `/jax` directory within the container.
- Build the XLA target you need, e.g. the multi-host HLO runner:
  ```shell
  /usr/local/bin/bazel build \
    --config=cuda_libraries_from_stubs \
    --verbose_failures=true \
    @xla//xla/tools/multihost_hlo_runner:hlo_runner_main
  ```
  Optionally, you can override the hermetic (`HERMETIC_*`) repository environment variables by adding `--repo_env` flags to the command, e.g.:
  ```shell
  --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_90"
  ```
- Copy the resulting artifacts to `/jax/dist` so that you can access them from the host OS if needed:
  ```shell
  cp bazel-bin/external/xla/xla/tools/multihost_hlo_runner/hlo_runner_main \
    ./dist/
  ```
- Exit the interactive shell:
  ```shell
  exit
  ```
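The build command from the steps above can also be wrapped in a small function. This is a minimal sketch, assuming Bash inside the container; `build_xla_target` and the `DRY_RUN` variable are hypothetical names introduced here, not part of the JAX CI scripts. It uses only the flags shown in this guide and appends the optional `HERMETIC_CUDA_COMPUTE_CAPABILITIES` override only when a value is passed.

```shell
#!/usr/bin/env bash
# Sketch (hypothetical helper): compose the bazel build command used in this
# guide, optionally adding the HERMETIC_CUDA_COMPUTE_CAPABILITIES override.
# Set DRY_RUN=1 to print the command instead of running it.
build_xla_target() {
  local target="$1"
  local compute_caps="${2:-}"   # e.g. "sm_90"; empty keeps the default
  local -a cmd
  cmd=(/usr/local/bin/bazel build
       --config=cuda_libraries_from_stubs
       --verbose_failures=true)
  # Append the repo_env override only if a compute capability was given.
  if [[ -n "$compute_caps" ]]; then
    cmd+=("--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES=${compute_caps}")
  fi
  cmd+=("$target")
  if [[ "${DRY_RUN:-0}" == "1" ]]; then
    printf '%s\n' "${cmd[*]}"   # print the space-joined command
  else
    "${cmd[@]}"                 # run Bazel with each flag as one argument
  fi
}
```

For example, `DRY_RUN=1 build_xla_target @xla//xla/tools/multihost_hlo_runner:hlo_runner_main sm_90` prints the full command instead of running it. Collecting the flags in a Bash array keeps each flag a single argument, so the optional override can be appended without extra quoting.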
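The copy step relies on where Bazel places outputs of external repositories under `bazel-bin`. The sketch below, assuming Bash and the `bazel-bin/external/<repo>/<package>/<target>` layout used in this guide's `cp` command, converts a label such as `@xla//xla/tools/multihost_hlo_runner:hlo_runner_main` into that path; `label_to_bin_path` is a hypothetical helper name. Setups that use Bzlmod may give external repositories different directory names, so verify the path if the copy fails.

```shell
#!/usr/bin/env bash
# Sketch (hypothetical helper): map an external Bazel label (@repo//package:target)
# to the bazel-bin path used in this guide's copy step. Assumes the
# "bazel-bin/external/<repo>/..." layout; Bzlmod setups may differ.
label_to_bin_path() {
  local label="$1"
  local repo="${label%%//*}"   # "@xla//xla/tools/...:x" -> "@xla"
  repo="${repo#@}"             # "@xla" -> "xla"
  local rest="${label#*//}"    # package path plus optional ":target"
  local pkg name
  if [[ "$rest" == *:* ]]; then
    pkg="${rest%%:*}"          # text before the colon
    name="${rest#*:}"          # text after the colon
  else
    # No explicit target: Bazel uses the last package component as the name.
    pkg="$rest"
    name="${rest##*/}"
  fi
  printf 'bazel-bin/external/%s/%s/%s\n' "$repo" "$pkg" "$name"
}
```

Inside the container, it could be used as `cp "$(label_to_bin_path @xla//xla/tools/multihost_hlo_runner:hlo_runner_main)" ./dist/`.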