Written by Ashish Narasimham & Zhenzhen (Jen) Tan, Members of the AI Infrastructure Technical Practice Community
Since the arrival of Large Language Models, deep learning has been front and center. This raises the question: is there still a place for those “other” models? How do we navigate the algorithms that preceded this latest innovation? In parallel, organizations are amassing more and more data, with an eye toward using it to train highly accurate models for targeted inference use cases. However, as data volumes grow, scaling up a single machine reaches its limits. What do you do when the data no longer fits on one machine?
The answer isn’t to abandon XGBoost for a more complex deep learning model. The answer is to scale it. In this post, we’ll show you how to take your familiar XGBoost workload and scale it to handle terabytes of data using Dask for distributed computing and Google Kubernetes Engine (GKE) for flexible, on-demand infrastructure. We’ll cover some of the nuances of the techniques used here, and provide a working example for you to try.
Let’s start with understanding the technology stack and answering some important questions.
Should we use decision trees when deep learning models abound?
XGBoost, a tree-based gradient boosting library, trains a class of models that receives less attention today amid the excitement around LLMs and the rapid pace of deep learning innovation. Despite that focus, in practice many organizations use decision trees with great accuracy at reduced training cost. Trees train faster, are easier to understand, and are more compact, which yields better loading and inference times, among other benefits.
The benefits of deep learning models are well known, but when might you steer away from them? Data needs grow as model architectures become more complex; the model type may be overkill for the problem statement; or the incremental compute cost may be too high to justify the ROI, especially with periodic retraining.
XGBoost and Dask have a tight integration, which enables an effortless transition from your data processing pipeline into model training. Read the data across the network, load it into CPU or (preferably) GPU memory, pre-process it, and then execute the single xgboost.dask.train call with your hyperparameters.
Leverage Dask to scale out your data pipeline
Dask is an open-source parallel computing library that scales your existing Python tools, like Pandas and NumPy, across multiple machines. When combined with the RAPIDS ecosystem, Dask can fully utilize GPUs: it coordinates CPU, host RAM, GPU, and device RAM to intelligently load data from persistent storage into host and device memory, run your data processing pipeline in a distributed context, and then hand off to your model training framework of choice. Some frameworks are more integrated than others - XGBoost, which is what we'll leverage here, has deep integrations with Dask and can fully utilize the distributed context to run training iterations across the cluster.
You can also use popular frameworks like PyTorch to run deep learning model training, but the distributed contexts of PyTorch and Dask do not know about each other by default. Libraries exist to facilitate this handoff, but you'll need to stitch them together yourself or use persistent storage as the handoff point.
GKE: Flexible, scalable, fast
The role of Google Kubernetes Engine here (and in many places) is as an automation layer for the Dask cluster configuration.
In a virtual machine environment, we’d have to provision another VM, install drivers and dependencies, and validate correct functioning of NCCL and other prerequisites before using the machine for testing.
With GKE, we can:
- Define once, deploy many - for complex setups like NCCL configuration
- Get pre-configured functionality, like sidecar containers for GCSFuse on each worker
- Leverage abstractions such as Deployments and Services for static communication paths
- Conduct multiple tests across configurations, scaling out with a simple YAML change
- Set up/tear down with a single-line command
- Use ephemeral deployment types such as Dynamic Workload Scheduler Flex-start (DWS Flex-start) to reduce costs for short-term workloads
GKE consumption options to meet your needs: DWS Flex and Queued Provisioning
Deployment options fall on a spectrum from ephemeral to committed, offering varying levels of cost relative to the level of commitment made.
| Provisioning Model (by increasing pricing and guarantees) | Capabilities |
|---|---|
| Spot | Immediate start, Flexible end with 1-day maximum |
| DWS Flex-Start | Flexible start, Guaranteed end with 7-day maximum |
| DWS Calendar | Guaranteed start, Guaranteed end with 90-day maximum |
We will choose DWS Flex-Start due to the short, flexible, and guaranteed nature of the work we are doing. Training workloads are well-suited to DWS Calendar as well since they are typically periodic and can fit within the required timeframe.
To optimize useful time for our large-scale XGBoost workload, we pair DWS Flex with queued provisioning in GKE. Queued provisioning ensures “atomic gang scheduling”—guaranteeing that all requested nodes for our distributed cluster are provisioned simultaneously. This ensures that we aren’t paying for partially available resources that are unusable until the entire cluster is ready for training. In GKE, we will create a separate node pool configured for DWS Flex with queued provisioning.
Queued Provisioning is accomplished using Kueue, which automates creating the provisioning request to obtain the required machines and then resumes the GKE resource request once the machines are added to the node pool. With MultiKueue you can even scale out across clusters, load balancing your workloads and choosing the best cluster for the job. This optimizes fleet utilization and helps your entire organization maximize GKE usage.
Example - Training an XGBoost model on Dask
Now that we have an understanding of the layers and technologies, let's put it into practice and deploy the full stack!
Here’s the architecture we’ll be reviewing and then deploying:
We will approach this in two phases - data generation and model training. In a real-world scenario, data will already exist; we’ll only need to pre-process it. Here, we are both generating and pre-processing as the flow above shows.
Note: for this example, we generate new data for each run, so we don’t persist the final model. In a production workflow, you’d add a final step to save your model to disk.
Overall, we will follow this flow to run both generation and training code:
- Start the scheduler node via GKE manifest
- Start the worker nodes via GKE manifest
- Exec into the scheduler node and start the job
- Monitor job logs until job completion
- Exit the exec'd container environment and delete the GKE deployments
Here’s the GitHub repository with the full code for the example: . Make sure to review the README file if you plan to replicate this testing.
Generation
We’ll generate a 1 TB dataset to simulate a real-world, large-scale problem, and process it on a three-node cluster (two workers, one scheduler). You can scale this node count up or down to meet your needs.
We generate a multi-cluster dataset with the make_classification method so that our boosted decision tree model can learn important features within the dataset and reduce loss. Alternatively, you can generate random data, which simulates training times well but does not allow the algorithm to improve its accuracy.
Pro tip: to optimize storage and data load, we store the features and labels together in Google Cloud Storage. Why? Repartition operations within Dask can move features and labels to different nodes of the cluster, causing errors at train time. With this technique, each file contains one partition of both features and labels.
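A minimal sketch of one generation step under this scheme, using scikit-learn's make_classification as the per-partition generator; the row counts, column names, and GCS path are illustrative:

```python
import pandas as pd
from sklearn.datasets import make_classification

def make_partition(n_rows: int, n_features: int, seed: int) -> pd.DataFrame:
    """Generate one partition with features and label stored together."""
    X, y = make_classification(
        n_samples=n_rows,
        n_features=n_features,
        n_informative=2,          # two informative columns
        n_clusters_per_class=1,   # one cluster per class -> two clusters total
        n_classes=2,
        random_state=seed,        # vary per partition for distinct data
    )
    df = pd.DataFrame(X, columns=[f"f{i}" for i in range(n_features)])
    df["label"] = y               # the label travels in the same file as its features
    return df

part = make_partition(n_rows=1_000, n_features=10, seed=0)
# In the real pipeline, each partition is written as one object in GCS, e.g.:
# part.to_parquet("gs://<bucket>/data/part-0000.parquet")
```

Because every file carries both its feature columns and its label column, a later Dask repartition can never separate a row from its label.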
Training
To train, we first have to load the dataset into HBM.
- Our data is located in Google Cloud Storage, so we can take advantage of the 8 workers per node to parallelize the partition reads
- We reached 60-80 Gbps with the two-node setup here; this number can scale linearly with additional nodes up to a certain point, because the scheduler assigns the generated partitions stored in GCS to each of the workers to read independently
- We introduce a synchronous wait step after the data load. We found it is both more efficient and removes the OOM errors otherwise encountered when Dask tries to operate on multiple parts of the execution graph at once - loading data, splitting the data into X's and y's, and building the QuantileDMatrix are all memory-intensive operations
- Splitting the data into X's and y's is not a trivial operation! It requires twice the memory at peak, since two copies of the data exist before one is deleted. Dask can use host memory as a temporary storage location, split the data there, and then copy it back to device memory before training begins
The generated dataset has two clusters and two columns of informative data; this provides a non-trivial, interesting dataset that the model must intelligently learn from, reducing loss and reaching good accuracy by the end of training.
We were able to train the XGBoost model across 2000 iterations, with the 1 terabyte of data distributed across two worker nodes in the cluster.
The memory usage is as follows:
- The cluster HBM totals 2.89 TB: 181 GB per B200 chip * 8 chips * 2 nodes
- The data takes up about 1 TB of that - likely a bit more with the additional metadata that the application adds
- XGBoost histograms take up another significant chunk, around 1.3 TB by observation
Therefore a 1TB dataset+computation almost fills the 2.89TB HBM of a 2-node cluster! This may be a surprise, but working memory is no small component of a model training process - as you can see, it comes out to be more than the dataset itself.
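The memory budget above, written out as plain arithmetic (all figures taken from the numbers in this post, in GB):

```python
# Cluster HBM: per-GPU HBM times GPUs per node times nodes.
hbm_per_gpu = 181          # GB of HBM per B200 chip
gpus_per_node = 8
nodes = 2

total_hbm = hbm_per_gpu * gpus_per_node * nodes
print(total_hbm)           # 2896 GB, i.e. ~2.89 TB of cluster HBM

dataset = 1000             # ~1 TB of training data
histograms = 1300          # ~1.3 TB of observed XGBoost working memory
headroom = total_hbm - dataset - histograms
print(headroom)            # 596 GB left - the cluster is nearly full
```

The working memory alone exceeds the dataset itself, which is why a 1.25 TB dataset no longer fits (as discussed in the next section).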
Next steps & optimizations
There are a few areas to work on to turn this example into a production deployment:
- Real data - the nuances and intricacies of real data cannot be replicated by synthetic data generation. Even using a public dataset as an intermediate step could help reveal additional interesting areas of modification
- Better utilization of the scheduler node - the scheduler node has 8 GPUs just like the worker nodes, but the act of scheduling uses none of them. Running a slightly smaller worker pod on that node could be a good way to utilize those resources
- Incremental optimization of HBM - we could probably fit a slightly larger dataset into HBM and still train without Out Of Memory errors; we'd need to find how large the dataset can grow while still leaving room for model working memory. We attempted training with 1.25 TB of data, but received OOM errors - the working memory demands were too high!
Recap & conclusion - what did we do
We’ve now generated data, trained an XGBoost model, and seen validation results, all end to end - that was a lot!
Let’s recap:
- Google Cloud Storage has an extremely high-bandwidth pipe to your compute instances, enabling fast data load from disk into memory
- The NVIDIA B200 GPU is a powerful chip that enables massively parallel workloads with even just a 2-node A4-highgpu-8g cluster
- GKE enables flexible deployment options to balance permanence with cost
- GKE's orchestration allows us to spin up/tear down entire clusters quickly, deploying when needed and shutting things down when done
- Dask then coordinates a multi-node GPU-accelerated cluster, enabling distributed computation including data processing and model training
- XGBoost integrates deeply with Dask, providing an out-of-the-box distributed model training experience on a Dask cluster
Deep learning is the top area of innovation today, but XGBoost and decision trees continue to be a high-accuracy group of algorithms that have a fundamental place within the enterprise due to their high ROI and fast training times.
So reader - what are you waiting for? Leverage this cloud-native tech stack to get started - clone the git repo, deploy the Kubernetes manifests, and start your data generation!
All code available in the GitHub repository.

