We are excited to announce that starting with the next Ray release, Google Cloud TPUs are officially a first-class accelerator in Ray. TPUs are now integrated into the Ray release pipelines and supported across most core Ray libraries, including Ray Train and Ray Serve. These features are currently available in the nightly release of Ray.
Previously, TPUs were classified as “experimental”: while they were usable, they relied heavily on community support, and developers often had to build their own container images just to get started. With the release of Ray 2.55, TPUs become the first hardware accelerator promoted to “fully supported” status since the initial support for NVIDIA GPUs.
This upgrade dramatically improves the developer experience and production deployments with Ray, providing pre-built container images and official support from the Ray team. This article takes a deeper look at what the upgraded support brings to the Ray ecosystem.
Out-of-the-box container images
To make it easier than ever to get started with TPUs on Ray, we now build official -tpu-tagged Ray images with the jax[tpu] dependencies pre-installed. Aside from the CUDA/GPU-tagged images, these are the only accelerator-specific Ray images built directly within Ray’s CI.
Previously, developers had to write their own Dockerfiles to extend the base Ray CPU images, manually configure the correct libtpu versions, and maintain those images in their own container registries. Getting a workload up and running with these new pre-built images is now straightforward. For example, you can directly specify the TPU image in your KubeRay cluster deployments:
apiVersion: ray.io/v1
kind: RayCluster
metadata:
  name: ray-tpu-cluster
spec:
  ...
  workerGroupSpecs:
  - replicas: 2     # Number of slices
    numOfHosts: 4   # Number of TPU VMs per slice
    groupName: tpu-group
    template:
      spec:
        containers:
        - name: ray-worker
          # Use the official pre-built Ray TPU image
          image: rayproject/ray:nightly-py312-tpu
          resources:
            limits:
              google.com/tpu: "4"
            requests:
              google.com/tpu: "4"
          env:
          - name: JAX_PLATFORMS
            value: tpu
          - name: ENABLE_PJRT_COMPATIBILITY
            value: "true"  # env values must be strings in Kubernetes
          - name: LIBTPU_INIT_ARGS
          ...
The above spec creates a TPU worker group running on 2 physical slices (GKE node pools) with 4 TPU workers per slice, or 8 Ray workers total. This spec can be used to run multi-slice training with MaxText, a high-performance training library written in Python/JAX, and the JaxTrainer in Ray Train. The Ray training script looks like this:
import sys

from ray.train.v2.api.config import ScalingConfig, RunConfig
from ray.train.v2.jax import JaxTrainer

...

def train_loop_per_worker(config):
    import maxtext
    from maxtext.trainers.pre_train.train import main as maxtext_main

    argv = config["argv"]
    maxtext_main(argv)

# Forward the MaxText CLI arguments passed to this script on the command line.
argv = sys.argv

trainer = JaxTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"argv": argv},
    scaling_config=ScalingConfig(
        use_tpu=True,
        # To use multi-slice, just specify the topology and total workers.
        # Ray Train will automatically determine the number of slices.
        num_workers=8,
        topology="4x4",
        accelerator_type="TPU-V6E",
        resources_per_worker={"TPU": 4},
        placement_strategy="SPREAD",
    ),
    run_config=RunConfig(
        name="maxtext_multi_slice",
        worker_runtime_env={
            "uv": {
                # maxtext requires some additional dependencies
                "packages": ["maxtext[tpu]==0.2.1"],
                "uv_pip_install_options": ["--resolution=lowest"],
            },
        },
    ),
)
result = trainer.fit()
Ray Train handles scheduling the workload onto two complete TPU slices that can utilize the high-speed ICI mesh for their worker processes. The actual training logic is encapsulated inside MaxText’s train function. Since we’re using a Ray image that already includes the JAX dependencies, we just have to include maxtext in the runtime_env and our workload is ready to deploy.
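As a back-of-the-envelope check of how the ScalingConfig maps onto slices, the arithmetic works out as follows (a plain-Python sketch, not a Ray API; the 4-chips-per-host figure is an assumption that matches resources_per_worker={"TPU": 4} for v6e hosts):

```python
# Sketch: why num_workers=8 with topology="4x4" yields two slices.
# Assumes 4 TPU chips per host, matching resources_per_worker={"TPU": 4}.
CHIPS_PER_HOST = 4

def chips_in_topology(topology: str) -> int:
    """Total chips in a slice topology string such as "4x4" (16 chips)."""
    total = 1
    for dim in topology.split("x"):
        total *= int(dim)
    return total

hosts_per_slice = chips_in_topology("4x4") // CHIPS_PER_HOST  # 16 chips / 4 per host = 4
num_slices = 8 // hosts_per_slice  # num_workers=8 -> 2 slices
print(hosts_per_slice, num_slices)  # 4 2
```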
ray job submit \
--address http://localhost:8265 \
--working-dir . \
-- python maxtext_ray_trainer.py \
maxtext/MaxText/configs/base.yml \
base_output_directory=/data/ \
dataset_type=synthetic \
per_device_batch_size=4 \
max_target_length=4096 \
model_name=llama3-8b \
steps=100 \
ici_fsdp_parallelism=4 \
ici_tensor_parallelism=4 \
run_name=rayjob-multi-slice
...
(TrainController pid=1911) Requesting resources: {'TPU': 4, 'accelerator_type:TPU-V6E': 0.001} * 8
(TrainController pid=1911) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=1911) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8
...
(RayTrainWorker pid=38170, ip=10.60.4.10) [process=0][thread=save_finalize][step=99] CheckpointManager Save Finalize is syncing with other hosts... [repeated 7x across cluster]
(RayTrainWorker pid=38170, ip=10.60.4.10) [process=0][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=38170, ip=10.60.4.10) [process=0][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=38170, ip=10.60.4.10) [process=0][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 7x across cluster]
(RayTrainWorker pid=38170, ip=10.60.4.10) completed step: 99, seconds: 1.003, TFLOP/s/device: 114.969, Tokens/s/device: 16337.797, total_weights: 524288, loss: 1.924 [repeated 7x across cluster]
------------------------------------------
Job 'raysubmit_hJ9ncFfsruLYB85j' succeeded
------------------------------------------
The above result shows a successful, untuned pre-training run of Llama 3 8B using MaxText, multi-slice v6e TPUs, and Ray Train.
Upgraded support across the Ray ecosystem
The Ray project will continue to evolve its native libraries to improve TPU capabilities. With Ray 2.55 and KubeRay v1.6, you can expect the following integrations:
- Ray Train: We have added support for TPUs to Ray Train through the JaxTrainer API, which now supports both multi-slice and fault-tolerant elastic training.
- Ray Core: A TPU utility library has been added to simplify scheduling on a multi-host slice, handling tasks like setting coordinator environment variables.
- Ray Serve & vLLM: TPUs are now supported with the Ray backend executor through the tpu_inference package. We have validated support for both tensor and pipeline parallelism with vLLM on a RayCluster with multi-host TPUs. These changes are available in the vllm-tpu nightly image, which includes the latest version of Ray.
- SPMD support in KubeRay: KubeRay natively supports the atomic creation and deletion of multi-host “replicas” (slices) via the numOfHosts field, a critical pattern for running SPMD workloads. While this previously required the Ray autoscaler, KubeRay v1.6 promotes the RayMultiHostIndexing feature to Beta, enabling host- and slice-level indexing through Kubernetes Pod labels and the atomic scaling of Pods in the same group. On GKE, the KubeRay TPU webhook automatically utilizes these labels to bootstrap the required libtpu and MEGASCALE environment variables, seamlessly enabling multi-host and multi-slice workloads for all TPU generations v4+.
- Autoscaling & observability: TPUs are fully supported in the Ray V2 Autoscaler with KubeRay, and TPU metrics are natively available directly within the Ray dashboard.
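To make the webhook bullet concrete, here is a hypothetical plain-Python sketch of the kind of per-Pod environment that can be derived from slice- and host-indexing labels. The variable names follow GKE's multislice conventions; the exact set and values the KubeRay webhook injects are an assumption here, not its actual implementation:

```python
# Hypothetical sketch: deriving multi-host/multi-slice TPU env vars from
# slice and host indices (the kind of bootstrapping the webhook performs).
def tpu_env_for_pod(slice_id, host_id, num_slices, hostnames, coordinator):
    return {
        "TPU_WORKER_ID": str(host_id),                # host index within the slice
        "TPU_WORKER_HOSTNAMES": ",".join(hostnames),  # all hosts in this slice
        "MEGASCALE_SLICE_ID": str(slice_id),          # slice index within the job
        "MEGASCALE_NUM_SLICES": str(num_slices),
        "MEGASCALE_COORDINATOR_ADDRESS": coordinator,
    }

# Example: host 2 of slice 0 in a 2-slice, 4-hosts-per-slice group
# (hostnames and coordinator address are made up for illustration).
env = tpu_env_for_pod(
    slice_id=0, host_id=2, num_slices=2,
    hostnames=[f"host-{i}" for i in range(4)],
    coordinator="tpu-group-0-0:8081",
)
print(env["TPU_WORKER_ID"])  # 2
```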
For more details on accelerator scheduling, check out the Ray documentation. For TPUs, see Use TPUs with KubeRay.
What’s next: In-progress and future work
We are actively expanding TPU capabilities across the rest of the Ray ecosystem. Our current in-progress work includes:
- Ray Data: Adding JAX support for sharding across multi-host TPUs, for both data and model parallelism, through a new iter_jax_batches API and a JAX utils library. Proof of concept.
- Ray LLM: Introducing TPU scheduling support to automatically configure placement groups based on a new topology API. Proof of concept.
In the future, we also plan to explore utilizing SkyRL on multi-host TPUs for reinforcement learning and post-training, adding dynamic slicing support for TPU super/sub slices in Ray, and enhancing support for batch inference on TPUs through Ray LLM.