In Part II we discussed how to train an LLM on GPUs using Cluster Director. Alas, the latter does not support TPUs as of this writing, so I thought I’d write a brief addendum explaining how to achieve the same goal on TPUs.
Deploy a TPU Cluster
Availability is still a concern, and the best way to ensure it is to reserve the number and type of chips you need. Assuming you have done so, you can proceed with deployment using the console or the command line. The screenshot below, for instance, shows the parameters you need to specify for a 64-chip TPU v5e cluster in zone us-west4-a.
The equivalent command line would be something like:
gcloud compute tpus queued-resources create your-queued-resource-id \
--node-id your-node-id \
--project your-project-id \
--zone us-west4-a \
--accelerator-type v5litepod-64 \
--runtime-version v2-alpha-tpuv5-lite \
--reserved \
--valid-after-time 2025-11-18T00:00:00Z \
--valid-until-time 2025-11-24T00:00:00Z
You can also pass a startup script with the parameter
--metadata-from-file='startup-script=startup-script.sh'
Note that the default topology for a v5e-64 cluster consists of 16 hosts, each with 4 chips.
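To make that topology concrete, here is the host/chip arithmetic as a quick sketch (the numbers are just the v5e-64 defaults stated above):

```python
# Default v5e-64 topology: 64 chips spread over hosts with 4 chips each.
total_chips = 64
chips_per_host = 4  # each v5e TPU VM worker exposes 4 chips
num_hosts = total_chips // chips_per_host
print(num_hosts)  # -> 16
```

Those 16 hosts are the "workers" that gcloud addresses with --worker=all in the commands that follow.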
Setting up the environment
You can collect the following instructions in a setup script that runs in the cloud console. For the sake of clarity, we’ll explain the commands one by one. First, let us set a few variables for further use.
export TPU_CLUSTER_NAME=<the cluster name you used before>
export PROJECT_ID=<your project id>
export ZONE=<the zone where you deployed>
Then we configure large memory pages (transparent hugepages) to reduce memory-management overhead on the TPU VMs. This is especially relevant for v5e machines and newer.
gcloud compute tpus tpu-vm ssh ${TPU_CLUSTER_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='sudo sh -c "echo always > /sys/kernel/mm/transparent_hugepage/enabled"'
We can then proceed to install a runtime and the necessary libraries. In our case, that means JAX and Keras 3.
gcloud compute tpus tpu-vm ssh ${TPU_CLUSTER_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
gcloud compute tpus tpu-vm ssh ${TPU_CLUSTER_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U keras keras-hub datasets'
Now we test whether the configuration was successful by running a simple Python script.
gcloud compute tpus tpu-vm ssh ${TPU_CLUSTER_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; jax.distributed.initialize(); print(jax.device_count()); print(jax.local_device_count())"'
Using ssh batch size of 1. Attempting to SSH into 1 nodes with a total of 16 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
[... workers 2 through 14 ...]
SSH: Attempting to connect to worker 15...
64
4
[... the remaining 15 workers print the same pair ...]
The command initializes the JAX runtime and asks each node to count the chips available in the cluster and those attached locally. In our case, each of the 16 nodes should report 64 and 4, respectively.
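Beyond counting devices, a slightly stronger check is to run a collective across all chips. The sketch below, assuming JAX is installed, performs a psum over every participating device; on the v5e-64 slice each entry should come back as 64, while on a single-host setup it degenerates to the local device count:

```python
import jax
import jax.numpy as jnp

# On the multi-host TPU cluster, initialize the distributed runtime first;
# on a single-host setup (e.g. local CPU testing) this is not needed.
if jax.default_backend() == "tpu":
    jax.distributed.initialize()

local = jax.local_device_count()
x = jnp.ones((local, 1))
# Sum across ALL participating devices: this verifies that the collectives
# work end to end, not just that the chips are visible.
total = jax.pmap(lambda v: jax.lax.psum(v, "i"), axis_name="i")(x)
print(int(total[0, 0]))  # equals jax.device_count()
```

Run on all workers with the same gcloud ssh pattern as the device-count test above.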
We still need to set up shared storage. There are several options (e.g. GCS, NFS or Lustre). Let’s consider the most economical one: mounting a GCS bucket with GCSFuse. First, we install the gcsfs Python client and then gcsfuse itself on all nodes.
gcloud compute tpus tpu-vm ssh ${TPU_CLUSTER_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install gcsfs proto-plus==1.24.0.dev1'
gcloud compute tpus tpu-vm ssh ${TPU_CLUSTER_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s`; echo "deb [signed-by=/usr/share/keyrings/cloud.google.asc] https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list'
gcloud compute tpus tpu-vm ssh ${TPU_CLUSTER_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo tee /usr/share/keyrings/cloud.google.asc'
gcloud compute tpus tpu-vm ssh ${TPU_CLUSTER_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='sudo apt-get update; sudo apt-get install gcsfuse'
Then we must make sure we have a staging bucket to which we copy the files we need (training scripts, configuration files, datasets, etc.).
The scripts are essentially the same ones we have used so far in this blog series. We must, however, take the new 16x4 topology into account when sharding the model layers.
We mount the staging bucket with the following command.
gcloud compute tpus tpu-vm ssh ${TPU_CLUSTER_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='mkdir -p /home/<your user id>/<your mount point>; gcsfuse --implicit-dirs <your staging bucket> /home/<your user id>/<your mount point>'
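Once mounted, a quick smoke test from Python confirms the bucket is writable through the mount point (the path below is a placeholder; substitute your own mount point):

```python
from pathlib import Path

# Placeholder path; on the TPU VMs this would be your GCSFuse mount point.
mount = Path("/tmp/staging-mount")
mount.mkdir(parents=True, exist_ok=True)

# Write a marker file and read it back through the (mounted) filesystem.
marker = mount / "mount-check.txt"
marker.write_text("ok")
print(marker.read_text())  # -> ok
```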
Run the training
Now we are ready to launch the main training job.
nohup gcloud compute tpus tpu-vm ssh ${TPU_CLUSTER_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd <your mount point/scripts> && python3 training-script.py' & disown
By prepending nohup we make sure the output is saved to the local filesystem (in nohup.out), and by appending & disown we let the process keep running even if the current shell times out.
The best way to observe progress is to use a service like TensorBoard or Weights & Biases. In our previous example, we stored the TensorBoard logs on GCS. We can visualize them simply by creating a Colab instance and loading the relevant extension. For instance, have a look at the screenshot below.
We can also monitor TPU utilization using the GCP metrics explorer, as shown in the following figure.
The run may last several days, and it is crucial to monitor the health of the cluster during that time. Fortunately, we have dashboards for that.
If you want to try it yourself, have a look at https://github.com/gmarchetti2020/colmo.git
We provide the code and some instructions, but not the Dolma dataset; you can find that on Hugging Face.