Training large language models (LLMs) has traditionally been the domain of a select few with access to massive, dedicated supercomputers. The process is fraught with challenges: acquiring and managing hundreds of GPUs, overcoming network bottlenecks, and orchestrating complex software environments. This tutorial challenges that paradigm.
We’ll demonstrate how to build a robust, scalable, and cost-effective system for distributed LLM training on Google Cloud in a simple and reproducible way. This approach allows you to provision a massive fleet of cutting-edge GPUs on-demand, run a high-performance training job, and tear it all down, paying only for what you use. We’ll tackle the primary challenges of large-scale training head-on:
- On-Demand Supercomputing: How do you gain access to a large cluster of powerful accelerators like NVIDIA B200 GPUs precisely when you need them, without the cost of maintaining an always-on environment?
- Networking Bottlenecks: How do you ensure that the communication between hundreds of GPUs is fast enough to keep them saturated with work, preventing the network from becoming the main bottleneck in your training job?
- Complex Job Orchestration: How do you simplify the deployment and management of a multi-node distributed training job, including all its dependencies, environment variables, and networking requirements?
To demonstrate, we will fine-tune a well-known model (Llama 3.1 70B Instruct) on a text-to-SQL dataset. We won’t do much with the fine-tuned model (stay tuned for future posts), but we will use it as the basis for running distributed training on real infrastructure. The system we build is purpose-built for performance: we will use PyTorch as our ML framework on Google Kubernetes Engine (GKE), together with several key technologies:
- Dynamic Workload Scheduler (DWS) with Kueue: For queuing jobs and provisioning large-scale GPU clusters that scale from zero.
- High-Performance Networking: Using Google’s RDMA over Converged Ethernet (RoCE) to provide a low-latency, high-bandwidth fabric for inter-GPU communication.
- PyTorch-specific features: Fully Sharded Data Parallel 2 (FSDP2), FlashAttention 2, and Torch Dynamo (torch.compile) to improve the performance of training a large LLM at scale.
- Helm: To package our Kubernetes configurations, making our complex training application easy to manage, configure, and deploy.
The Solution Stack: A High-Level View
Before we dive in, let’s look at the architecture we build in this tutorial.
The flow starts at the Kubernetes Job API; in this case the API is invoked by a user running Helm via the CLI. Kueue, a job queueing controller, intercepts the job and holds it until Dynamic Workload Scheduler (DWS), using Flex-start mode, can provision the requested GPU VMs. Once the nodes are ready, they mount a Google Cloud Storage (GCS) bucket using GCSFuse, over the Google Virtual Network Interface Card (gVNIC) on the VMs, for reading the dataset and writing checkpoints. The PyTorch workers on the GPU VMs then start. They use PyTorch's FSDP implementation, which benefits enormously from a high-speed RDMA network for the collective communications between GPUs; the diagram shows the RDMA VPC and the corresponding subnets for each GPU Network Interface Card (NIC). FSDP shards the model parameters, gradients, and optimizer states across our fleet of GPUs, and training can then commence. Once the job is done, DWS tears down the GPU VMs. We'll drill into the setup as we go along, and you can refer back to this overview as needed.
Pre-Requisites
If you intend to run this tutorial yourself, ensure you have the following:
- A Google Cloud project with billing enabled.
- A Hugging Face account with an access token that can download gated models like Llama 3.1; you will pass this token in via the setup environment variables.
- Sufficient quota in your project for GKE, Compute Engine, and specifically for "A4 series GPUs with 8 NVIDIA_B200" in your target region. The tutorial uses spot/preemptible quota, which is self-service.
- A local environment with the Google Cloud CLI, kubectl, helm, and envsubst installed.
We're not going to give an exhaustive list of IAM requirements, so you may have to work through some IAM challenges if you are not an 'Owner' or 'Editor' in your GCP project and are using fine-grained IAM controls (which is the right way to do things). Also note that running this tutorial yourself will incur billing costs.
Part 1: Environment Setup
Getting your cloud environment ready is automated by a single setup script that we provide.
- Clone the Repository:
git clone https://github.com/esaaren/torch-distributed-training-gke.git && cd torch-distributed-training-gke
export REPO_ROOT=`git rev-parse --show-toplevel`
- Configure Your Environment: Review the environment variables in the .env file and customize them to your liking. We leave some placeholders so you can edit these quickly.
- Run the Setup Script:
source $REPO_ROOT/.env
cd $REPO_ROOT/setup && ./setup.sh
This script handles all the heavy lifting: creating dedicated VPCs for high-performance networking, setting up the GKE cluster with a DWS-enabled node pool, installing Kueue, and configuring all necessary permissions. Feel free to review setup.sh to see how all of this is done (or to use it as inspiration for other projects). Everything in it is done with simple CLI commands and should be straightforward to follow or adapt.
Part 2: Understanding the Infrastructure
The setup script provisions several critical components designed for large-scale AI workloads. Many of the components here are adapted from great public doc resources found here and here; the former describes how we set up the clusters and RDMA, and the latter describes how we set up DWS and Kueue.
High-Performance Network Foundation
Distributed training, especially with large models that require sharding strategies like FSDP, is very network-intensive because the model's weights and gradients are constantly being communicated between nodes. To prevent the network from becoming a bottleneck, we deploy Google's implementation of RDMA over Converged Ethernet (RoCE) under the Network Setup section. RoCE allows the network cards (NICs) on different GPU nodes to exchange data directly from memory, bypassing the host CPUs and reducing the network hops required to transfer payloads to and from GPUs. This significantly reduces latency and maximizes bandwidth, which is essential for keeping the GPUs fed with data during the frequent synchronization steps of FSDP and other operations during training.
As the public docs call out, we use the default network for one CPU NIC, build a second VPC for the other CPU NIC, and build a third VPC for the GPU NICs with 8 subnets in it (one for each NIC). The CPU NICs handle networking that does not involve the GPUs (downloading files, etc.). We only need to create these networking objects once, and they work for as many VMs (GPU and CPU) as we want in parallel; think of these VPCs and subnets as dedicated swim lanes for the chips to move information across. To further illustrate why we do all of this, see the architecture diagram of an A3U (H200) and A4 (B200) VM below. The 8 CX-7 NICs (one per GPU) and the two gVNICs (one per CPU) on the diagram show how our cloud networking objects map to the underlying physical hardware.
To be clear, the networking and all components of this tutorial only work for A3U and A4 machines on GCP, since the RDMA networking implementation for other GPU types differs slightly, but the principles remain the same.
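To build intuition for what "keeping the GPUs fed" means, it can help to run a quick collective-communication sanity check over the fabric before launching a long training job. The sketch below is illustrative only and is not part of the repository; it assumes it is launched with torchrun across the GPU nodes and uses all-reduce as a stand-in for the all-gather/reduce-scatter traffic FSDP generates.

import os
import time

import torch
import torch.distributed as dist

# Minimal sanity check of the NCCL/RDMA path (illustrative, not from the repo).
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

tensor = torch.ones(256 * 1024 * 1024, dtype=torch.bfloat16, device="cuda")  # ~512 MB payload

# Warm up, then time a burst of collectives.
for _ in range(5):
    dist.all_reduce(tensor)
torch.cuda.synchronize()

start = time.time()
for _ in range(20):
    dist.all_reduce(tensor)
torch.cuda.synchronize()
elapsed = time.time() - start

if dist.get_rank() == 0:
    gb_moved = 20 * tensor.numel() * tensor.element_size() / 1e9
    print(f"all_reduce payload throughput: ~{gb_moved / elapsed:.1f} GB/s")
dist.destroy_process_group()

If this number is far below what the RDMA fabric should deliver, it is worth revisiting the networking setup before debugging the training script itself.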
GKE with Dynamic Workload Scheduler (DWS) and Kueue
GKE and DWS form the backbone of our cost-effective, on-demand, democratized compute cluster. This combination gives us access to the following:
- Compact placement: B200s on DWS Flex-start are compactly placed by default (the GPUs are co-located in the data center), which is essential for the frequent gradient and weight synchronizations required by frameworks like FSDP and translates into higher training throughput (MFU).
- Queued Provisioning: Your job waits in a Kueue-managed queue until the exact number and type of GPU nodes it requires are provisioned by DWS. Everyone submitting jobs to DWS sits in a shared queue.
- Scales from Zero: The node pool is configured with zero initial nodes. DWS automatically scales it up when a job starts and scales it back down to zero when the job finishes.
- Job Stability: For long-running training jobs, stability is key. The node pool is configured with auto-repair and auto-upgrades disabled to prevent interruptions.
We can see how this is done under the Cluster Setup and Nodepool Setup sections in our setup script via simple gcloud commands.
The cluster create command includes a few arguments that are mandatory for using RDMA:
gcloud container clusters create ${CLUSTER_NAME} \
...
--enable-dataplane-v2 \
--enable-ip-alias \
--enable-multi-networking \
...
We can also see some of the definitions that allow us to use DWS flex in our node pools create command:
gcloud container node-pools create gpu-nodepool-dws \
...
--reservation-affinity=none \
--location-policy=ANY \
--enable-queued-provisioning \
--flex-start \
--no-enable-autoupgrade \
--no-enable-autorepair \
--enable-autoscaling \
...
Also, pay special attention to network_mapping.yaml (via GKENetworkParamSet) and how it relates to our node pool creation command; this is how the networking interfaces we built earlier in the networking section are tied together to make RDMA function. The interfaces are passed like so: --additional-node-network=network=${RDMA_NETWORK_PREFIX}-net,subnetwork=${RDMA_NETWORK_PREFIX}-sub-0 and then referenced in our network mapping like so:
apiVersion: networking.gke.io/v1
kind: GKENetworkParamSet
metadata:
  name: rdma-0
spec:
  vpc: ${RDMA_NETWORK_PREFIX}-net
  vpcSubnet: ${RDMA_NETWORK_PREFIX}-sub-0
  deviceMode: RDMA
This allows us to reference these network mappings later (rdma-0, rdma-1, etc.) in our Kubernetes Job manifests to ensure we are using RDMA correctly.
Kueue is then set up in the following section Kueue Setup where we simply apply the OSS Kueue manifests and then apply our own Kueue queue configuration via kueue.yaml. The only difference between this manifest and a standard simple Kueue manifest is that we are using ProvisioningRequestConfig to hook into GCP’s DWS technology via provisioningClassName: queued-provisioning.gke.io to procure our GPUs.
Data and Checkpoints with GCS Fuse
Large datasets and model checkpoints are stored in Google Cloud Storage. We use the GCS FUSE CSI driver to mount our GCS bucket directly into the training pods as if it were a local filesystem. This provides a simple, POSIX-compliant interface for our PyTorch script to read training data and write sharded model checkpoints without needing to integrate complex object storage SDKs. Down the road we can also very easily integrate storage features like Anywhere Cache or Rapid storage to improve our storage performance without requiring any major code re-writes. An overview of some of these other storage features released by GCP can be reviewed here.
In the job we see later we follow some of our published GCSFuse performance best practices in the configuration.
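Because the bucket appears as a regular directory inside the pod, the training code can use plain file I/O. Below is a minimal sketch, assuming the bucket is mounted at a hypothetical path like /data (the actual mount path is configured in the Helm chart, not here).

import os

import torch

# Hypothetical mount point and paths; the real values come from the job configuration.
GCS_MOUNT = "/data"
ckpt_dir = os.path.join(GCS_MOUNT, "checkpoints")
os.makedirs(ckpt_dir, exist_ok=True)

# Reading the pre-processed dataset is ordinary file access through the mount...
for name in os.listdir(os.path.join(GCS_MOUNT, "processed_dataset")):
    print(name)

# ...and so is writing artifacts, e.g. a small metadata file next to the checkpoints.
torch.save({"step": 0}, os.path.join(ckpt_dir, "example_metadata.pt"))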
Image and software dependencies
A critical, often overlooked, part of a successful distributed training setup is the software dependency chain. Getting this right is essential, as even minor mismatches between the host machine’s drivers and the container’s libraries can lead to cryptic errors. Let’s walk through the layers of our setup to see how they align.
- GKE Host Node (The Foundation): When we create our GKE cluster with B200s, we set gpu-driver-version=DEFAULT. This provides our nodes with a specific NVIDIA driver that supports CUDA 12.8. This is the base layer everything else must be compatible with.
- NCCL & gIB DaemonSet: We then deploy a DaemonSet via nccl_installer.yaml to install the NCCL and gIB libraries directly onto the GKE nodes. This allows our pods to communicate over the high-speed RDMA fabric. The installer uses a specific GCP-managed image (us-docker.pkg.dev/gce-ai-infra/gpudirect-gib/nccl-plugin-gib:v1.0.6), which has its own requirements:

| Machine series | Bundled NCCL version | Supported NCCL versions | Minimum GPU driver version | Minimum CUDA runtime version |
|---|---|---|---|---|
| A4 | 2.26.6-1 | 2.25, 2.26 | 570.124.06 | 12.8 |

- Docker Container (The Application Layer): Finally, we build our application container. All choices inside the Dockerfile must align with the layers below it:
  - Base Image: We start with nvcr.io/nvidia/cuda-dl-base:25.03-cuda12.8-devel-ubuntu24.04, which is pre-packaged with a CUDA 12.8 runtime and an NCCL version that aligns with the supported NCCL versions from our NCCL DaemonSet.
  - PyTorch: We install the specific PyTorch build compiled for CUDA 12.8 by using the correct index URL: ... --index-url https://download.pytorch.org/whl/cu128.
  - CUDA Compiler: We install the nvcc compiler (cuda-toolkit-12-8) inside the container. This is crucial for torch.compile, which JIT-compiles parts of our model into optimized CUDA kernels at runtime.
Ensuring this chain is in sync, from the GKE drivers to the NCCL DaemonSet to the container's CUDA runtime and PyTorch build, is what makes the end-to-end setup work. It also provides a clear blueprint for managing these nuances if you need to adapt this solution to different library or hardware versions.
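A quick way to confirm the chain is aligned is to print what PyTorch actually resolves at runtime from inside a running pod. This small diagnostic sketch is not part of the repository; the expected values in the comments come from the matrix above.

import torch

# Print the versions the container resolved so they can be checked against the
# driver/NCCL/CUDA alignment described above.
print("PyTorch:", torch.__version__)
print("CUDA runtime (torch build):", torch.version.cuda)        # expect 12.8
print("NCCL (bundled with torch):", torch.cuda.nccl.version())  # expect the 2.25/2.26 family
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))                # expect NVIDIA B200
    print("Compute capability:", torch.cuda.get_device_capability(0))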
Part 3: The PyTorch Training Job
The entire workload is defined and managed by a Helm chart, which deploys our PyTorch script. First, let’s review the script itself.
The PyTorch code (fsdp.py)
Our training script is built for performance and scalability. While this is not an exhaustive walkthrough, we will highlight some key features.
Performance-Tuned Model: To maximize throughput, we apply several state-of-the-art optimizations:
FSDP Sharding: We use PyTorch’s implementation of FullyShardedDataParallel (FSDP2) to shard the model’s parameters, gradients, and optimizer states across all GPUs. This is the core technique that enables us to fine-tune models that are too large for a single GPU’s memory. Here is how we apply the sharding policy to the model’s decoder layers:
# --- Step 5: Shard Model with FSDP ---
if global_rank == 0: print("🚀 Sharding model with FSDPv2...")
...
for module in model.modules():
    if isinstance(module, decoder_layer):
        fully_shard(module, mp_policy=mp_policy, reshard_after_forward=True)
fully_shard(model, mp_policy=mp_policy, reshard_after_forward=True)
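The mp_policy passed to fully_shard above is a mixed-precision policy. As a minimal sketch (the exact dtypes and import path used in the repository may differ), it could be constructed like this:

import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Assumed mixed-precision setup: keep sharded parameters and gradient
# reductions in bfloat16, matching the dtype used when loading the model.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
)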
FlashAttention 2: We enable the highly-optimized FlashAttention 2 kernel directly when loading the model from Hugging Face. This is a simple but powerful one-line change that significantly speeds up the attention mechanism by reducing memory I/O and is made easy for us by Hugging Face and the transformers library.
# --- Step 4: Prepare Model ---
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2"  # <-- Enabled here
)
Torch Compile: After the model is sharded and all of our other primitives (optimizers, schedulers, etc.) are defined, we use torch.compile to JIT-compile the model into optimized kernels, further boosting performance. We include a brief warmup phase later in the script to absorb the initial compilation overhead.
# -- Step 7: Model compilation ---
if global_rank == 0: print("🚀 Compiling model...")
model = torch.compile(model, dynamic=True)
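The warmup mentioned above can be as simple as running a handful of untimed steps before the main loop so the JIT compilation cost is not attributed to a training step. The sketch below is illustrative, reuses names already defined in the script (model, train_loader, optimizer), and assumes a config value NUM_WARMUP_STEPS that is not necessarily what the repository calls it.

# Illustrative warmup: a few forward/backward passes trigger torch.compile's
# JIT compilation so later, timed steps reflect steady-state performance.
# Assumes batches are already collated and moved to the GPU.
model.train()
for i, batch in enumerate(train_loader):
    if i >= NUM_WARMUP_STEPS:  # assumed config value
        break
    outputs = model(**batch)
    outputs.loss.backward()
    optimizer.zero_grad(set_to_none=True)  # discard warmup gradients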
Some other nice things we do to make our training script more complete include:
Efficient Data Handling: The script uses a “process-once, read-many” strategy. On the first run, a single worker downloads the raw dataset, formats it into a conversational prompt template, and saves the result to a shared GCS Fuse mount. All workers then load this pre-processed data for every epoch, with a DistributedSampler ensuring each GPU gets a unique data slice during training.
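A minimal sketch of the sampler wiring is shown below; the names (train_dataset, PER_DEVICE_TRAIN_BATCH_SIZE, NUM_EPOCHS) are illustrative and the repository's DataLoader settings may differ.

from torch.utils.data import DataLoader, DistributedSampler

# Each rank sees a disjoint shard of the pre-processed dataset.
sampler = DistributedSampler(train_dataset, shuffle=True)
train_loader = DataLoader(
    train_dataset,
    batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,  # assumed config value
    sampler=sampler,
    num_workers=4,    # illustrative
    pin_memory=True,
)

for epoch in range(NUM_EPOCHS):
    sampler.set_epoch(epoch)  # reshuffles consistently across ranks each epoch
    for batch in train_loader:
        ...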
Checkpointing: The script uses torch.distributed.checkpoint to save the sharded model state to our GCSFuse mount. This includes logic to automatically find and resume from the most recent checkpoint, making our jobs resilient if we want to resume from failed or past runs. We check for an existing checkpoint before training begins:
# --- Step 8: Checkpoint auto resuming ---
starting_epoch = 0
resume_from_checkpoint, last_epoch = find_most_recent_checkpoint(OUTPUT_DIR)
if resume_from_checkpoint:
    dc.load(
        state_dict=state_to_load,
        checkpoint_id=checkpoint_shard_path
    )
    starting_epoch = loaded_epoch + 1
Then, during the training loop, we save a new checkpoint periodically:
if (epoch + 1) % CHECKPOINT_EPOCHS == 0:
    save_checkpoint(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        # ... other args
    )
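For reference, here is a hedged sketch of what a save_checkpoint helper built on torch.distributed.checkpoint might look like; the repository's implementation and argument list may differ.

import torch.distributed.checkpoint as dc
from torch.distributed.checkpoint.state_dict import get_state_dict

def save_checkpoint(model, optimizer, scheduler, epoch, output_dir):
    # Collect the sharded model/optimizer state in a DCP-friendly format.
    model_state, optim_state = get_state_dict(model, optimizer)
    state_dict = {
        "model": model_state,
        "optimizer": optim_state,
        "scheduler": scheduler.state_dict(),
        "epoch": epoch,
    }
    # Each rank writes only its own shards to the shared GCSFuse mount.
    dc.save(state_dict=state_dict, checkpoint_id=f"{output_dir}/epoch_{epoch}")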
We don't enable checkpointing by default in the tutorial, but you can review some of the job configurations later on to change how frequently a checkpoint is written. We will use these trained checkpoints in a follow-up post for something interesting!
Live Performance Logging: A custom PerfLogger class calculates and prints key metrics like TFLOPs/s/GPU and Tokens/s/GPU for every step, providing immediate insight into the job’s performance.
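The TFLOPs figure can be approximated with the common "6 × parameters" FLOPs-per-token rule of thumb for a forward plus backward pass; the repository's PerfLogger may use a more detailed formula, but a sketch of the calculation looks like this.

def approx_tflops_per_gpu(n_params: float, tokens_per_gpu_per_step: int, step_time_s: float) -> float:
    """Rough forward+backward throughput using the 6 * params FLOPs-per-token approximation."""
    flops_per_step = 6.0 * n_params * tokens_per_gpu_per_step
    return flops_per_step / step_time_s / 1e12

# Using numbers implied by the logs shown later (~2,254 tokens/s/GPU at ~14.5 s per step,
# i.e. roughly 32.8k tokens per GPU per step), this lands around 954 TFLOPs/s/GPU.
print(approx_tflops_per_gpu(70.55e9, 32768, 14.54))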
Parameter-Efficient Fine-Tuning (PEFT / LoRA) adapters: The script optionally supports Parameter-Efficient Fine-Tuning (PEFT) via the popular LoRA method. This technique freezes the large pre-trained model and trains only a small set of lightweight "adapter" layers. This provides two major benefits: it drastically reduces the GPU memory required for the optimizer states, and it produces final checkpoints that are smaller (megabytes instead of gigabytes) and easy to share. We can enable this via configuration provided in the repository.
if PEFT:
    peft_config = LoraConfig(
        r=PEFT_R,
        lora_alpha=PEFT_ALPHA,
        lora_dropout=PEFT_DROPOUT,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, peft_config)
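If you enable LoRA, a quick way to see the reduction in trainable state is to print the trainable parameter count on the PEFT-wrapped model (this helper is part of the standard peft API):

# Reports trainable vs. total parameters for the LoRA-wrapped model.
model.print_trainable_parameters()
# Prints something like: trainable params: ... || all params: ... || trainable%: ...
# For a 70B model with adapters only on the attention projections, the trainable
# fraction is well under 1%.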
Gradient Accumulation: To achieve the training stability of a large batch size without the high memory cost, the script uses gradient accumulation. It processes several smaller batches sequentially, accumulates their gradients, and then performs a single optimizer update. This is useful for finding the right balance between memory capacity and the effective batch size needed for stable model convergence. We're not as focused on model convergence here, but it also results in slightly better GPU utilization, since we are not performing an optimizer update on every single forward pass. See the sketch after the loop below for how the effective batch size works out.
for step, batch in enumerate(train_loader):
    # 1. Calculate loss and scale it down for accumulation
    outputs = model(**batch)
    loss = outputs.loss
    loss = loss / GRADIENT_ACCUMULATION_STEPS

    # 2. Accumulate gradients
    loss.backward()

    # 3. Perform optimizer step ONLY after enough steps
    if total_dataloader_steps % GRADIENT_ACCUMULATION_STEPS == 0:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
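The effective global batch size is simply the product of the per-device batch size, the number of GPUs, and the accumulation steps. The numbers below are illustrative: the per-device batch size and GPU count match the job we launch later, while the accumulation step count is an assumed value.

# Effective batch size with gradient accumulation (illustrative values).
per_device_batch = 8      # per_device_train_batch_size in values.yaml
num_gpus = 2 * 8          # 2 nodes x 8 B200s each
grad_accum_steps = 4      # assumed value for GRADIENT_ACCUMULATION_STEPS

effective_batch = per_device_batch * num_gpus * grad_accum_steps
print(effective_batch)    # 512 sequences per optimizer update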
The Kubernetes Job Definition (Helm Chart)
Now that we understand the training script, let’s review how we actually submit the job. We will use Helm to manage the complexity of our Kubernetes deployment. This allows us to define our infrastructure as code and easily configure it. We can find the helm configuration under the training folder, with the core Job manifest being in templates/torch_job.yaml.
- Dynamic Configuration with values.yaml: All job-specific settings (model ID, dataset name, hyperparameters, NCCL settings, and so on) are defined in a central values.yaml file. This makes it easy to experiment with different configurations without changing the core Kubernetes manifests. These values can also be overridden via the CLI, making parallel or subsequent runs easy and reproducible. Feel free to explore different variables (like enabling or disabling LoRA, or changing the learning rate) and see how they impact your training job and performance!
- Orchestration with torchrun: The job's entry point is torchrun, the standard PyTorch utility for launching distributed jobs. The Helm template dynamically configures its arguments, such as --nnodes, using values from our values.yaml file.
- Multi-Network Pods: The Job template includes a networking.gke.io/interfaces annotation. This is a critical GKE feature that attaches multiple network interfaces to our pod, ensuring our GPUs have a dedicated, high-speed lane for RDMA communication. As we saw earlier, this is how rdma-0, rdma-1, etc. are mapped all the way back to the physical underlying NICs.
- Kueue and DWS: We wire in our Kueue and DWS integrations as shown below, so that our job lands on the queue with the DWS hook to procure our GPUs and stays suspended until the GPUs are available:
metadata:
  labels:
    kueue.x-k8s.io/queue-name: dws-local-queue
spec:
  suspend: true
  template:
    metadata:
      annotations:
        provreq.kueue.x-k8s.io/maxRunDurationSeconds:
One additional thing to call out is the NCCL configuration. Earlier, when we ran setup.sh, we deployed nccl_installer.yaml on the cluster, which runs a DaemonSet to install the NCCL and gIB libraries on the underlying host VMs. We expose these host libraries to our job container via:
spec:
  volumes:
    - name: library-dir-host
      hostPath:
        path: /home/kubernetes/bin/nvidia
    - name: gib
      hostPath:
        path: /home/kubernetes/bin/gib
These mounts are critical for NCCL and RDMA to function. We also provide some optimized NCCL configurations for RDMA on GCP in the job manifest:
nccl_params:
  # --- NCCL Performance settings ---
  net: "gIB"
  cross_nic: "0"
  net_gdr_level: "PIX"
  p2p_net_chunksize: "131072"
  p2p_pci_chunksize: "131072"
  p2p_nvl_chunksize: "524288"
  nvls_chunksize: "524288"
  ib_gid_index: "3"
  ib_adaptive_routing: "1"
  ib_qps_per_connection: "4"
  ib_tc: "52"
  ib_fifo_tc: "84"
  tuner_config_path: "/usr/local/gib/configs/tuner_config_a4.txtpb"
Note that if you decide to run this with A3U machines (H200), the value of tuner_config_path changes to "/usr/local/gib/configs/tuner_config_a3u.txtpb".
Part 4: Running the Job and Conclusion
Now, let’s bring it all together and launch our distributed training job.
- Build and Push the Docker Image: The repository includes a Dockerfile to package the training script and its dependencies. We will use Cloud Build and the provided config files to build and push an image. Note that this can take some time since the base image is quite large.
cd $REPO_ROOT/torch
gcloud artifacts repositories create $REPOSITORY --repository-format=docker --location=${REGION} --project=${PROJECT}
gcloud builds submit . \
--project="${PROJECT}" \
--region="${REGION}" \
--config=cloudbuild.yaml \
--substitutions="_ARTIFACT_REGISTRY=${ARTIFACT_REGISTRY},_IMAGE_NAME=${IMAGE_NAME}" \
--timeout="2h" \
--machine-type="e2-highcpu-32"
- Launch the Job: We use a Helm command to deploy our training job. This command installs (or upgrades) a release named torch-training-job using the chart in the current directory. We use the --set flag to override the default image name and GCS bucket at runtime (since these are unique to each user), but you can override any of the other defaults as well. We are also specifying a parallelism of 2 (to run on 2 VMs with 8 GPUs each). Note: the more GPUs you ask for, the longer you will likely wait in the queue.
cd $REPO_ROOT/training
helm upgrade --install torch-training-job . \
  --set infra.nodepool_name="gpu-nodepool-dws" \
  --set training_params.model_id="meta-llama/Llama-3.1-70B" \
  --set training_params.per_device_train_batch_size=8 \
  --set training.parallelism=2 \
  --set image.name=${REGION}-docker.pkg.dev/${PROJECT}/${REPOSITORY}/${IMAGE_NAME}:latest \
  --set fuse.bucket=${GSBUCKET}
The job will now sit in Kueue's queue, waiting for DWS to provision the nodes. You can monitor its status with kubectl describe job torch-job. When the job is ready, we should see output like the events below, showing how our job was created, suspended while waiting for resources, and then resumed and running about 4 minutes later once the resources became available. Because DWS capacity is transient, the queue/wait time will vary from run to run.
Events:
Type Reason Age From Message
---- ------ ---- ---- -------
Normal Suspended 10m job-controller Job suspended
Normal CreatedWorkload 10m batch/job-kueue-controller Created Workload: default/job-torch-job-d8689
Normal UpdatedAdmissionCheck 10m batch/job-kueue-controller dws-prov: Waiting for resources. Currently there are not enough resources available to fulfill the request.
Normal Started 6m57s batch/job-kueue-controller Admitted by clusterQueue dws-cluster-queue
Normal SuccessfulCreate 6m57s job-controller Created pod: torch-job-0-72jdl
Normal SuccessfulCreate 6m57s job-controller Created pod: torch-job-2-h548r
Normal SuccessfulCreate 6m57s job-controller Created pod: torch-job-1-zl6z4
Normal SuccessfulCreate 6m57s job-controller Created pod: torch-job-3-znc5l
Normal Resumed 6m57s job-controller Job resumed
Once the nodes are ready and the job is running, you can stream the logs from all pods to see the performance metrics in real-time.
kubectl logs -l app=torch-job -c job -f
You should see output similar to this after all of the initialization steps, showing the performance of the training job step-by-step:
[PerfLogger] Using 70.55B parameters for TFLOPs calculation.
--- Starting Epoch 1/10 ---
Step: 1 | Time: 27.02s | TFLOPs/s/GPU: 513.3 | Tokens/s/GPU: 1213 | Loss: 1.3697
Step: 2 | Time: 15.12s | TFLOPs/s/GPU: 917.7 | Tokens/s/GPU: 2168 | Loss: 1.3077
Step: 3 | Time: 15.28s | TFLOPs/s/GPU: 907.6 | Tokens/s/GPU: 2144 | Loss: 1.1527
Step: 4 | Time: 14.53s | TFLOPs/s/GPU: 954.9 | Tokens/s/GPU: 2256 | Loss: 0.9699
Step: 5 | Time: 14.54s | TFLOPs/s/GPU: 954.0 | Tokens/s/GPU: 2254 | Loss: 0.6620
Step: 6 | Time: 14.55s | TFLOPs/s/GPU: 953.7 | Tokens/s/GPU: 2253 | Loss: 0.6007
Step: 7 | Time: 14.52s | TFLOPs/s/GPU: 955.1 | Tokens/s/GPU: 2256 | Loss: 0.5132
Step: 8 | Time: 14.56s | TFLOPs/s/GPU: 952.4 | Tokens/s/GPU: 2250 | Loss: 0.5110
Step: 9 | Time: 14.54s | TFLOPs/s/GPU: 954.0 | Tokens/s/GPU: 2254 | Loss: 0.4839
Step: 10 | Time: 14.53s | TFLOPs/s/GPU: 954.4 | Tokens/s/GPU: 2254 | Loss: 0.4395
Looking at the output, we can confirm the success of our training job. First, the loss is consistently decreasing, indicating that the model is learning. Second, and more importantly for this tutorial, the performance metrics are strong and stable: after the initial compilation overhead in the first few steps, throughput stabilizes at ~955 TFLOP/s per GPU and holds there for the rest of the run.
This translates to a Model FLOPs Utilization (MFU) of approximately 42.5% on a B200 GPU in bfloat16. This is an excellent baseline result: it proves that our entire distributed stack, from the on-demand provisioning with GKE and DWS to the high-speed RDMA networking and the code-level optimizations, is working correctly, and building that foundation was the primary goal of this tutorial. Pushing toward state-of-the-art performance beyond this point involves further code-level optimizations, which are great next steps for an advanced user.
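The MFU figure follows directly from dividing the achieved throughput by the accelerator's peak dense throughput. The sketch below assumes a peak of roughly 2,250 TFLOPS of dense bfloat16 compute per B200; check NVIDIA's published specifications for the exact number.

# Model FLOPs Utilization = achieved FLOPs / peak hardware FLOPs.
achieved_tflops = 955.0      # steady-state value from the training logs above
peak_bf16_tflops = 2250.0    # assumed dense bfloat16 peak for a B200
print(f"MFU ~= {achieved_tflops / peak_bf16_tflops:.1%}")  # roughly 42-43%, matching the figure above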
If you want the tutorial to output a checkpoint (the defaults we specify result in no checkpoint being written), update the following variable in values.yaml or via your helm install command:
checkpoint_epochs: "X"
A checkpoint will then be written every X epochs. If we do this, we can find the checkpoints written out to the GCS bucket ${GSBUCKET} during and after training.
We can clean everything up by running:
source $REPO_ROOT/.env
cd $REPO_ROOT/setup
./cleanup.sh
Take care to ensure that the cleanup script runs to completion with no errors, and validate that all of the resources for the tutorial are indeed deleted.
Conclusion
By combining GKE, DWS, Kueue and PyTorch, we’ve created a system for LLM training that is:
- Performant: Fully utilizing the power of modern compactly placed B200s and high-speed networking with FSDP and RDMA.
- Scalable and Cost-Effective: Leveraging DWS and Kueue to build an on-demand supercomputer that you only pay for when it's running.
- Manageable: Using Helm to abstract away the complexity of deploying and configuring distributed training workloads.
- Accessible: We did all of this without needing a reservation or a GPU commitment, using a shared queue with our fellow ML users.
It's important to understand the cost of this tutorial (and of using DWS in general) so we know what we are getting into. The price for running DWS with B200s (this tutorial supports H200s as well), among other GPU types, can be found here. Each training step in this tutorial (as we ran it) cost about $0.5198, so fine-tuning a model for 100 steps would cost roughly $50 USD.
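If you want to estimate the cost of a differently sized run, the per-step cost is simply the fleet's hourly GPU price prorated by the step time. The sketch below uses a placeholder price; take the real per-GPU-hour value from the pricing page linked above.

# Back-of-the-envelope training cost estimate (illustrative only).
num_gpus = 16              # 2 nodes x 8 B200s
price_per_gpu_hour = 0.0   # placeholder: substitute the DWS price from the pricing page
step_time_s = 14.5         # steady-state step time from the logs
num_steps = 100

cost_per_step = num_gpus * price_per_gpu_hour * step_time_s / 3600
print(f"Estimated total: ${cost_per_step * num_steps:.2f}")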
One extra note: while this tutorial focuses on DWS Flex-start mode for burst use, DWS also offers a Calendar mode that lets you reserve capacity for 1 to 90 days. This gives you the flexibility to run on a pre-defined schedule if your job will take at least 24 hours, and it can also get you access to a much larger number of GPUs without requiring a longer commitment.
To recap: this tutorial provides a simple and accessible blueprint for any organization looking to train large-scale models on Google Cloud, democratizing access to the cutting-edge infrastructure required for modern AI development.

