In the world of machine learning, training a model is just the beginning; we also need to figure out how to serve predictions (inference). Broadly, inference workloads fall into two camps: real-time (or online) and offline batch.
Real-time inference is what powers the interactive web. Think of a chatbot answering your question, a recommendation engine personalizing your homepage, or a fraud detection system flagging a transaction—all requiring immediate, low-latency responses. These systems use always-on endpoints, ready to process single requests in milliseconds.
But what happens when the need isn’t for an instant response, but for massive throughput on a large dataset?
- An insurance company needs to classify a year’s worth of millions of property damage photos.
- A media house wants to analyze its entire petabyte-scale video archive for specific content.
- A research institute must process terabytes of satellite imagery to detect environmental changes.
For these scenarios, using a real-time endpoint is inefficient and prohibitively expensive. The right approach is offline batch inference. This paradigm prioritizes high throughput and cost-efficiency over low latency, making it ideal for large, non-interactive jobs.
The challenge, however, is building a system that can spin up a massive fleet of GPUs, feed them data at line-rate without bottlenecks, run the job efficiently, and tear it all down to control costs.
In this post, we’ll build exactly that: a purpose-built, high-performance offline batch inference pipeline on Google Cloud. We’ll tackle the key challenges head-on:
- Data I/O Bottlenecks: How do you efficiently feed terabytes of data, often stored as millions of individual files in a cloud bucket, to a fleet of GPUs without them sitting idle?
- Cost and Resource Management: How do you get access to powerful, expensive accelerators like NVIDIA H200 GPUs precisely when you need them, without paying for them when you don’t?
- Scalable Performance: How do you ensure that communication between hundreds of GPUs is blazing fast, so your distributed job isn’t bottlenecked on networking?
To demonstrate, we’ll build a scalable offline batch inference pipeline using PyTorch on Google Kubernetes Engine (GKE), leveraging a few key technologies designed for this exact purpose:
- Google Cloud’s DataFlux: For high-throughput, parallel data loading directly from Google Cloud Storage (GCS).
- Dynamic Workload Scheduler (DWS) with Kueue: For flexible, on-demand access to powerful A3 Ultra (H200) virtual machines.
- High-Performance Networking: Using Google’s RDMA over Converged Ethernet (RoCE) v2 for performant collective operations.
As a practical example, we’ll use the well-known COCO image dataset and classify images with Google’s Vision Transformer (ViT) model. However, this architecture is highly flexible: it can be adapted to any PyTorch model or code and applied to a variety of use cases (like training).
The solution stack: A high-level view
Before diving into the code, let’s take a look at the high-level architecture and how these components fit together.
As the diagram illustrates, a user submits a batch job to a GKE cluster using the Kubernetes Job API. Kueue, a job queueing controller, intercepts the job and holds it until the necessary resources are available via Dynamic Workload Scheduler (DWS). Once the nodes are ready, the PyTorch workers start on the GPUs, read images from Google Cloud Storage via DataFlux, process them with the ViT model from Hugging Face, and write the classification results back to GCS. Network communication between the GPU nodes is accelerated by Google’s RoCE (RDMA over Converged Ethernet) implementation. Once the job is done, DWS tears down the GPU resources.
The tutorial: Step-by-step implementation
This tutorial is divided into several parts, walking you through setting up the environment, understanding the infrastructure, preparing the data, and finally, running the PyTorch job.
Part 1: Environment setup
Getting your environment ready is straightforward. The provided repository includes a setup script that handles the heavy lifting.
- Clone the Repository:
git clone https://github.com/ai-on-gke/tutorials-and-examples.git
cd tutorials-and-examples/batch-inference-pytorch/standard/setup
- Configure Your Environment: Review the environment variables in the .env file and customize them as needed.
- Run the Setup Script:
./setup.sh
This script creates the necessary VPCs, sets up the GKE cluster with a DWS-enabled node pool, installs Kueue and NCCL, and configures all required permissions and service accounts.
Note that while our PyTorch code is geared toward an offline inference example for simplicity, everything we have set up here works just as well for other distributed GPU batch jobs, such as model training!
Part 2: Understanding the infrastructure
The setup.sh script automates the creation of several key components. Let’s explore some of the most important ones.
High-performance network foundation
To achieve peak performance for distributed GPU workloads, a specialized network topology is crucial. We create a VPC specifically engineered for high-bandwidth, low-latency GPU-to-GPU communication. By enabling RDMA (Remote Direct Memory Access) over Converged Ethernet (RoCE), we allow network cards to transfer data directly between the memory of different nodes, bypassing the main CPUs and drastically reducing latency.
GKE with Dynamic Workload Scheduler (DWS) and Kueue
For our compute infrastructure, we utilize GKE with DWS and the open-source job scheduler Kueue. This combination allows for:
- Queued Provisioning: Your job waits in a queue until all the nodes it needs are available.
- Scale from zero: The node pool is configured with zero initial nodes and autoscaling enabled, so you only pay for resources while a job is running.
- Job Stability: Autorepair and auto-upgrades are disabled to prevent interruptions to long-running jobs.
Configuring Kueue for DWS
Kueue is configured to use GKE’s DWS for provisioning. A ProvisioningRequestConfig custom resource tells Kueue to manage nvidia.com/gpu resources through DWS. A ClusterQueue defines a global queue, and a LocalQueue provides the user-facing entry point for job submissions.
Part 3: High-throughput data access with DataFlux
A common bottleneck in large-scale data processing is data loading. Reading millions of individual small files from cloud storage can be slow. This is where DataFlux for GCS comes in. It’s a purpose-built PyTorch Dataset abstraction designed to accelerate data loading from Google Cloud Storage, offering up to a 3.5x improvement in training times compared to alternatives.
DataFlux achieves these performance gains through:
- Fast Parallel Listing: A sophisticated work-stealing algorithm dramatically speeds up the initial listing of files.
- Dynamic Object Composition: Instead of fetching thousands of small files individually, DataFlux dynamically composes them into larger temporary objects on the fly, minimizing latency and maximizing throughput.
- Seamless PyTorch Integration: DataFlux is wrapped in a familiar PyTorch Dataset primitive, requiring minimal changes to your existing code.
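To make this concrete, here is a minimal sketch of how a DataFlux map-style dataset can feed a standard PyTorch DataLoader. This is not the tutorial’s exact code: it assumes the dataflux-pytorch package’s map-style dataset interface (DataFluxMapStyleDataset, Config, and the data_format_fn decode hook), and the project, bucket, and prefix values are placeholders.
# Minimal sketch (not the tutorial's exact code): plug a DataFlux
# map-style dataset into a standard PyTorch DataLoader.
# Assumption: the dataflux-pytorch package's DataFluxMapStyleDataset API.
import io

from PIL import Image
import torchvision.transforms as T
from torch.utils.data import DataLoader
from dataflux_pytorch import dataflux_mapstyle_dataset

_transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])

def decode_image(raw_bytes):
    # DataFlux hands back raw object bytes; decode and convert to a tensor
    # so the default DataLoader collate function can batch the samples.
    return _transform(Image.open(io.BytesIO(raw_bytes)).convert("RGB"))

dataset = dataflux_mapstyle_dataset.DataFluxMapStyleDataset(
    project_name="my-gcp-project",     # placeholder: your project ID
    bucket_name="my-image-bucket",     # placeholder: bucket holding the images
    config=dataflux_mapstyle_dataset.Config(prefix="val2017/"),
    data_format_fn=decode_image,
)

loader = DataLoader(dataset, batch_size=64, num_workers=8)
Because DataFlux exposes a regular Dataset, everything downstream (samplers, DataLoader workers, distributed sharding) works the same way it would with local files.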
Part 4: The PyTorch job
The entire workload is defined in the batch_inference.yaml file. Let’s break down the key parts.
Metadata storage via GCS Fuse
We use the GCS Fuse CSI driver to mount our GCS bucket directly into the pod as if it were a local directory. A PersistentVolume points to our GCS bucket with optimized mount options for caching and parallel downloads.
The Kubernetes job definition
The Kubernetes Job is configured with Kueue integration and multi-network interfaces, ensuring each GPU gets a dedicated RDMA lane. The job also includes fine-tuned NCCL environment variables for GCP to leverage the fast RDMA fabric for communication.
The PyTorch code logic
The Python script implements a distributed, multi-GPU batch inference pipeline using torch.nn.parallel.DistributedDataParallel (DDP), which processes our data in parallel while keeping a full replica of the model on each device. During the backward pass (in training), DDP uses the network to communicate and average gradients, a relatively small amount of data. It’s important to call out that for a small model in an inference-only job (no backward pass), inter-GPU communication is minimal, so this example doesn’t really exercise the RDMA fabric we set up. However, for training, or for inference with models too large to fit on a single device, torch.distributed.fsdp.FullyShardedDataParallel (FSDP) becomes essential: FSDP shards the model itself across GPUs and requires constant, high-bandwidth communication to exchange model parameters and activations, taking full advantage of the RDMA fabric. We use DDP in this example for simplicity and demonstration purposes.
The script also seamlessly integrates DataFlux with PyTorch’s native data loading tools. Each distributed rank processes a unique shard of the data and writes its results to a separate CSV file in the mounted GCS bucket.
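As an illustration of that pattern, here is a simplified sketch (not the repository’s script) of a DDP inference loop: each rank wraps the model, takes a unique shard of the data via DistributedSampler, and writes its own CSV. The build_dataset() helper, the /gcs/results output path, and the batch size are placeholders; the checkpoint is the standard google/vit-base-patch16-224 model from Hugging Face.
# Simplified sketch of the DDP inference pattern; launched with torchrun,
# which sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT.
import csv
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from transformers import ViTForImageClassification

def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").cuda()
    model = DDP(model, device_ids=[local_rank])
    model.eval()

    dataset = build_dataset()  # hypothetical helper: yields (image_id, pixel_values) pairs
    sampler = DistributedSampler(dataset, shuffle=False)  # unique shard per rank
    loader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=8)

    # Each rank writes its own CSV shard to the GCS Fuse mount (placeholder path).
    with open(f"/gcs/results/rank_{rank}.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["image_id", "classification"])
        with torch.no_grad():
            for image_ids, pixel_values in loader:
                logits = model(pixel_values=pixel_values.cuda()).logits
                preds = logits.argmax(dim=-1)
                for image_id, pred in zip(image_ids, preds):
                    writer.writerow([int(image_id), model.module.config.id2label[int(pred)]])

    dist.destroy_process_group()

if __name__ == "__main__":
    main()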
Part 5: Running the job and conclusion
Now it’s time to put it all together and run the job. Run the following commands from the root of the tutorial:
tutorials-and-examples/batch-inference-pytorch
Build and push the Docker image
The repository includes a Dockerfile to package the application.
# Authenticate Docker with Google Artifact Registry
gcloud auth configure-docker ${REGION}-docker.pkg.dev
# Build the image
docker build -f Dockerfile -t ${REGION}-docker.pkg.dev/${PROJECT}/torch-images/torch-ultra-job:latest .
# Push the image
docker push ${REGION}-docker.pkg.dev/${PROJECT}/torch-images/torch-ultra-job:latest
Get the dataset
Download the COCO dataset and upload it to your GCS bucket. Note that we’re using a rather small dataset (the COCO val2017 split, about 5,000 images) so the example finishes quickly, but you can scale this approach to terabytes of data.
wget http://images.cocodataset.org/zips/val2017.zip
unzip val2017.zip
python3 upload_images.py
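The repository’s upload_images.py handles the upload step. As a rough sketch of the idea (not the actual script), a loop over the unzipped directory with the google-cloud-storage client might look like this, where the local directory and object prefix are placeholders:
# Rough sketch of an image upload loop (not the repository's upload_images.py).
import os

from google.cloud import storage

client = storage.Client()
bucket = client.bucket(os.environ["GSDATABUCKET"])  # bucket name from the .env file

for filename in os.listdir("val2017"):
    blob = bucket.blob(f"val2017/{filename}")       # object name mirrors the local layout
    blob.upload_from_filename(os.path.join("val2017", filename))
    print(f"uploaded gs://{bucket.name}/val2017/{filename}")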
Launch the job
Apply the job manifest to your cluster.
cd standard/
envsubst < batch_inference.yaml | kubectl apply -f -
The job will be suspended until DWS provisions the necessary nodes. You can monitor its status with kubectl describe job torch-inference-job. Once running, you can view the logs from all processes via kubectl logs -l app=torch-inference-job -c job -f as they classify the images and write the results to your GCS bucket.
The final output will be a series of CSV files containing the image IDs and their corresponding classifications. You can find those CSV files in GCS at gs://$GSDATABUCKET (the bucket we specified in our .env file at the start).
A few lines from a sample job output file look like this:
image_id,classification
413621,"teddy, teddy bear"
54259,"home theater, home theatre"
311496,balloon
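Since each rank writes its own shard, you may want to merge the per-rank CSVs into a single file for downstream analysis. Here is a small sketch using the google-cloud-storage client and pandas; the results/ prefix is a placeholder for wherever your job writes its output:
# Merge the per-rank CSV shards into one local file (placeholder prefix).
import io
import os

import pandas as pd
from google.cloud import storage

client = storage.Client()
bucket = client.bucket(os.environ["GSDATABUCKET"])  # bucket name from the .env file

frames = []
for blob in bucket.list_blobs(prefix="results/"):   # placeholder prefix for the CSV shards
    if blob.name.endswith(".csv"):
        frames.append(pd.read_csv(io.StringIO(blob.download_as_text())))

merged = pd.concat(frames, ignore_index=True)
merged.to_csv("all_classifications.csv", index=False)
print(f"merged {len(frames)} shard(s), {len(merged)} rows total")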
Conclusion
By combining the power of GKE, PyTorch, DataFlux, DWS, and Kueue, we’ve created an offline batch inference system that is:
- Performant: Using high-end GPUs and tackling both data I/O and network bottlenecks
- Scalable: Easily scaling with DWS Flex and autoscaling
- Cost-effective: Paying for powerful GPUs only for the exact duration of the job
This architecture provides a robust blueprint for any large-scale offline processing task, bridging the gap between cutting-edge hardware and the practical challenges of capacity management. The principles and components discussed here can be extended to various PyTorch applications and other machine learning frameworks, empowering you to tackle even the most demanding batch processing workloads with efficiency and scale.