Optimizing frontier model training on TPU v7x (Ironwood)

Authors: Hina Jajoo, Parmita Mehta, Carlos Bustamante, Dipak Gaikwad, Harsh Shah

As generative AI models scale to massive sizes, the demand for compute power is outpacing traditional hardware capabilities. While next-generation accelerators like the seventh-generation Ironwood TPU offer massive theoretical compute, raw power alone doesn’t guarantee faster training times. The real challenge for AI teams is bridging the gap between theoretical hardware peaks and actual, realized performance — without paying a massive “engineering tax” to migrate workloads and optimize for a completely new architecture.

To help you get the most out of Ironwood TPUs from day one, we proactively optimized the most critical training workloads for its unique architecture. By diving deep into the hardware and utilizing new features in MaxText and Tokamax, we’ve unlocked high-throughput training benchmarks — meaning faster, more efficient model training for your business.

In this post, we share the exact optimization techniques our ML performance engineers use, so you can maximize Ironwood’s performance right away. For a deeper dive into the hardware, check out the TPU cloud documentation.

Components of training performance optimization

Let’s take a deeper look at the various components of Ironwood’s architecture that you need to understand to tune model training performance.

Taking advantage of the architecture

1. Utilizing the memory hierarchy

Managing data movement between Ironwood’s multi-tiered memory system is a crucial element of managing performance. Ironwood features high-bandwidth memory (HBM), vector memory (VMEM), and host memory with the following characteristics:

  • HBM: Each chip features 192 GB of HBM, a 6x increase over Trillium. Peak bandwidth is 7.38 TB/s. While vast, HBM can still be a bottleneck for memory-bound vector operations or inefficient data access.
  • Vector Memory (VMEM): VMEM is a smaller, on-chip SRAM with significantly higher bandwidth to the MXU than HBM. It acts as a high-speed scratchpad for custom kernels.
  • Host memory and PCIe: Each set of four TPU chips connects to a CPU host via PCIe. The host’s main memory can be used for offloading activations or optimizer states to free up HBM.

Interconnect fabric and arithmetic intensity

Arithmetic intensity (AI) is the ratio of peak compute throughput (FLOPs/s) to communication bandwidth (bytes/s): how many FLOPs the chip can perform for each byte of data it moves. For Ironwood, the one-dimensional ICI arithmetic intensity is very high, approximately 11,500. When tuning performance, focus on minimizing or hiding data movement so that the MXUs aren’t left idle waiting for data.
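A quick way to apply this number is to compare the arithmetic intensity of an individual operation against the hardware figure. The helper below is a minimal sketch for a dense matmul; note that the ~11,500 figure above relates compute to interconnect bandwidth, so compare against whichever bandwidth (HBM or ICI) actually bounds your operation.

```python
def matmul_ai(m, k, n, bytes_per_elem=2):
    """Arithmetic intensity (FLOPs per byte) of an (m, k) x (k, n) matmul,
    assuming both inputs and the output are each moved exactly once."""
    flops = 2 * m * k * n
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved

# A 4096-cubed bf16 matmul: 4096/3 ~= 1365 FLOPs per byte.
print(round(matmul_ai(4096, 4096, 4096), 1))  # 1365.3
```

If an operation's intensity falls well below the hardware's, it is memory- or communication-bound and the MXUs will stall unless the data movement is overlapped with other compute.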

2. Utilizing SparseCore

SparseCore is a unique component of TPUs, a processing unit engineered for high-performance acceleration of workloads that involve irregular, sparse memory access and computation. One of the ways you can utilize SparseCore for large-scale model training on Ironwood is to offload collective computation to it. This allows collective communication operations (like All-Gather or Reduce-Scatter) to execute asynchronously with the main computations happening on the TensorCores. Using specific XLA flags enables this offloading for the most common collectives.
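As a sketch of how this offloading is enabled: the relevant XLA flags are passed through `LIBTPU_INIT_ARGS` before the TPU runtime initializes. The flag names below are assumptions based on published MaxText recipes and may vary by libtpu release, so confirm them against the Cloud TPU documentation for your environment.

```python
import os

# Illustrative only: verify exact flag names for your libtpu version.
# In real jobs, merge with any LIBTPU_INIT_ARGS you already set.
os.environ["LIBTPU_INIT_ARGS"] = " ".join([
    "--xla_tpu_enable_sparse_core_collective_offload_all_gather=true",
    "--xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true",
])
print(os.environ["LIBTPU_INIT_ARGS"])
```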

3. Design for architectural alignment

Achieving peak performance on specialized hardware such as Ironwood requires designing hardware-friendly model architectures. Performance tuning starts with model definition, as architectural choices can set permanent performance limits. Here are some details of the Matrix Multiplication Unit (MXU) that should be kept in mind while designing models to be performant on TPUs.

  • Architectural specification: The Ironwood MXU is a 256x256 systolic array, which is most efficient when the contracting dimension of a matrix multiplication is a multiple of 256.
  • MXU utilization: Models whose head dimensions are a multiple of 256 can utilize the MXU fully and see high Model FLOPs Utilization (MFU) on the attention blocks. For models with a head_dim of 128 or 64, the QK product in flash attention leaves the MXU 50% or more underutilized; we recommend other techniques to compensate for it.
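The padding effect described above is easy to quantify: a contracting dimension is padded up to the next multiple of the MXU's edge length, and everything beyond the real dimension is wasted work. This small sketch makes the 50% figure for head_dim 128 concrete.

```python
import math

def mxu_contracting_utilization(dim, mxu_dim=256):
    """Fraction of the systolic array doing useful work when the
    contracting dimension is padded to the next multiple of mxu_dim."""
    padded = math.ceil(dim / mxu_dim) * mxu_dim
    return dim / padded

print(mxu_contracting_utilization(256))  # 1.0  -> full utilization
print(mxu_contracting_utilization(128))  # 0.5  -> the 50% underutilization
print(mxu_contracting_utilization(64))   # 0.25
```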

Balancing compute and memory utilization

The next challenge in performance optimization is managing the trade-off between compute and memory efficiency. This involves selecting appropriate sharding strategies and techniques like activation rematerialization to optimize resource use.

1. Finding optimal sharding strategy

Choosing the right sharding strategy is essential. A guiding principle is to select the simplest strategy that meets memory constraints, as this typically minimizes communication overhead. Before selecting a strategy, perform a roofline analysis to determine whether a given computation is limited by compute, memory bandwidth, or interconnect bandwidth.
Here are some common sharding strategies:

  • Fully Sharded Data Parallelism (FSDP): This is the preferred strategy for large-scale model training that exceeds the memory capacity of a single chip. FSDP shards the model’s weights, gradients, and optimizer states. Increasing the per-device batch size improves efficiency by adding compute that can hide the latency of the All-Gather operations FSDP introduces.
  • Tensor Parallelism (TP): TP shards individual tensors. Ironwood’s high AI (11.5k) requires an MLP dimension greater than 46k (for TP degree 4) to be viable over ICI. Most open source models like Llama3 70B (MLP dimension 28,672) and Qwen 2.5 7B (MLP dimension 18,944) fall short, and using TP here would result in the system becoming communication-bound.
  • Expert Parallelism (EP): This can be a helpful sharding strategy for training Mixture of Experts (MoE) models. EP shards the “expert” layers across a set of devices (a device contains only a subset of experts), and an All-to-All communication collective is used to route tokens to their designated expert device.
  • Context Parallelism (CP): CP is essential for long sequence lengths. It shards the sequence dimension of activation tensors, allowing for a fractional per-device batch size. Because CP introduces more communication than FSDP, the rule of thumb is to use the minimum degree of CP necessary.
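The TP viability rule above follows directly from the roofline argument: tensor parallelism of degree tp stays compute-bound only while the per-shard MLP dimension exceeds the hardware arithmetic intensity, which for Ironwood's ~11,500 gives the >46k threshold at TP degree 4. A minimal sketch of that check:

```python
ICI_ARITHMETIC_INTENSITY = 11_500  # FLOPs per ICI byte for Ironwood (from this post)

def tp_is_compute_bound(mlp_dim, tp_degree):
    """Rule of thumb: TP is viable over ICI only if the per-shard MLP
    dimension exceeds the hardware arithmetic intensity."""
    return mlp_dim / tp_degree > ICI_ARITHMETIC_INTENSITY

print(tp_is_compute_bound(46_000, 4))  # False: exactly at the threshold
print(tp_is_compute_bound(28_672, 4))  # False: Llama3 70B falls short
print(tp_is_compute_bound(65_536, 4))  # True
```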

2. Activation rematerialization

Rematerialization reduces HBM footprint by discarding activations and recomputing them during the backward pass. While it saves significant amounts of memory, it incurs ~25-30% additional FLOPs.

MaxText provides granular control over these trade-offs via the remat_policy flag. Beyond presets like full (maximizes memory savings) and minimal (maximizes training speed), users can implement custom policies. This allows you to specify behavior for individual layers:

  • device: Store the activation in HBM.
  • remat: Recompute the activation during the backward pass.
  • offload: Move the activation to CPU host memory via PCIe to free up HBM without the compute cost of recomputation.
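As an illustrative sketch of a custom policy that mixes all three behaviors (the per-tensor key names, such as decoder_layer_input and query_proj, are assumptions following MaxText's base config; verify them against your MaxText version):

```yaml
# Hypothetical custom remat policy: keep layer inputs in HBM, offload the
# attention projections to host memory over PCIe, recompute the rest.
remat_policy: "custom"
decoder_layer_input: "device"
query_proj: "offload"
key_proj: "offload"
value_proj: "offload"
out_proj: "remat"
```

Start from full, then selectively pin the tensors that are most expensive to recompute, measuring step time and HBM headroom after each change.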

Leveraging kernels optimized for Ironwood

While architecture provides the foundation, achieving maximum performance requires optimizing the computational routines themselves. This section details how leveraging specialized kernels and fine-tuning the memory pipeline enable us to hit peak utilization on Ironwood.

1. Leveraging Tokamax kernels

To address hardware-specific bottlenecks, we recommend Tokamax, a library of highly optimized, high-performance JAX kernels for TPUs. The following are the specific kernels and tuning methodologies we utilized to maximize Ironwood’s MXU utilization and mitigate memory bottlenecks.

  • Splash Attention: We use Splash Attention as the primary attention implementation; it eliminates the HBM bottleneck of standard attention and is the most efficient attention implementation on TPUs.
  • Megablox Grouped Matrix Multiplication (GMM): For MoE workloads, Megablox efficiently handles grouped matrix multiplications by computing over the ragged activations representation. It efficiently maps over the ragged dimension, computing matrix multiplications between ragged groups of rows, and the corresponding expert matrix, avoiding the need to pad batches to a fixed size.
  • Empirical tuning with tune-jax: Tokamax library has utilities that use tune-jax to perform empirical searches for optimal block sizes. Default kernel tile sizes are often suboptimal; tuning allows choosing hardware friendly VMEM tile sizes (as well as other hyperparameters) to maximize hardware utilization.
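To make the "ragged" grouped computation concrete, here is a minimal NumPy sketch of a Megablox-style grouped matmul (not the Tokamax API): tokens are pre-sorted by expert, and each contiguous group is multiplied only by its own expert's weights, with no padding to a common group size.

```python
import numpy as np

def grouped_matmul(tokens, group_sizes, expert_weights):
    """Multiply each contiguous group of rows by its expert's weight matrix.
    tokens: (num_tokens, d); group_sizes sums to num_tokens."""
    outputs, start = [], 0
    for size, w in zip(group_sizes, expert_weights):
        outputs.append(tokens[start:start + size] @ w)  # (size, d) @ (d, f)
        start += size
    return np.concatenate(outputs, axis=0)

tokens = np.ones((5, 4))                       # 5 tokens, model dim 4
weights = [np.eye(4), 2 * np.eye(4)]           # two experts
out = grouped_matmul(tokens, [2, 3], weights)  # expert 0: 2 tokens, expert 1: 3
print(out.shape)  # (5, 4)
```

The real kernel performs these group GEMMs in parallel on the MXU with tuned tiling; the point here is only that no FLOPs are spent on padded rows.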

2. Memory pipeline tuning

Kernel performance, such as that of flash attention, depends on the tile sizes selected in the kernel (for example, block_q and block_k), which are limited by the total available VMEM (on-chip SRAM). Ironwood chips have 64 MB of VMEM, which is split between the current scope (scoped VMEM) and future weight prefetch. Increasing the VMEM reserved for the current scope allows larger kernel tile sizes, potentially removing memory stalls and increasing kernel performance. You can control the scoped VMEM size by setting xla_tpu_scoped_vmem_limit_kib (in LIBTPU_INIT_ARGS). Experimenting with this setting lets you explore both kernel-level and end-to-end performance limits; optimizing the scoped VMEM size improves custom Pallas kernel performance.
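A minimal sketch of setting this flag before the TPU runtime initializes (the 49152 KiB value is purely illustrative; sweep it empirically for your kernels):

```python
import os

# xla_tpu_scoped_vmem_limit_kib is passed through LIBTPU_INIT_ARGS and must
# be set before the TPU runtime starts. Appending preserves any flags you
# have already configured.
existing = os.environ.get("LIBTPU_INIT_ARGS", "")
os.environ["LIBTPU_INIT_ARGS"] = (
    existing + " --xla_tpu_scoped_vmem_limit_kib=49152"
).strip()
print(os.environ["LIBTPU_INIT_ARGS"])
```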

Case studies: Detailed optimization profiles

We ran pre-training benchmarks for both custom models and common OSS models on Ironwood. We conducted these benchmarks using a 4x4x4 configuration (64 chips) to evaluate performance across the 3D Torus topology. Let’s take a look at the results.

Case study: Dense LLM (< 20B parameters) – short context (8k)

In this regime, the workload is primarily compute-bound. The objective is to keep the MXUs fully saturated and minimize TensorCore idle time.

  • Taking advantage of the architecture - SparseCore offload: The TensorCores’ MXU units often sit idle while waiting for communication collectives (like All-Gather or Reduce-Scatter) to complete. By offloading these to the SparseCore, we freed the TensorCores to focus on MXU operations and achieved near-perfect overlap between communication and computation. Result: 22% decrease in step time.
  • Balancing compute and memory utilization - Sharding with FSDP: The model was large enough that the weights, gradients and optimizer states would not fit into memory, making sharding necessary. FSDP gave us the best performance as it is designed to overlap communication with computation more efficiently.
  • Leveraging kernels optimized for Ironwood - Tokamax Splash Attention and kernel tuning: We replaced standard attention with Splash Attention. We chose this because the naive attention implementation is heavily IO-bound due to the large attention matrix; Splash Attention keeps these computations in SRAM. Because default block sizes often lead to either memory stalls or poor compute units overlap, we used tune-jax to find the exact “sweet spot” for Ironwood’s SRAM. Result: 12% decrease in step time.

Case study: Dense LLM (< 20B parameters) – long context (128k)

At a context length of 128k, the workload can exhaust total memory capacity. Activation memory grows with sequence length, making out-of-memory (OOM) errors the primary hurdle. Here’s how we approached this benchmarking job:

  • Taking advantage of the architecture - SparseCore offload: By offloading All-Gather and Reduce-Scatter operations to the SparseCore, we ensured that the communication required for TP and CP did not stall the MXUs. This overlap is critical as communication volume increases alongside activation size. Result: 5% reduction in step time.
  • Balancing compute and memory utilization - Hybrid Parallelism (FSDP16 + TP2 + CP2): During initial testing, we found that the largest per-device batch size whose activations could fit in memory was 0.25. To handle a full batch, we needed to utilize Tensor Parallelism (TP) and/or Context Parallelism (CP) with a combined degree of 4 to split the batch. We experimented with various configurations (such as CP4) but found that a hybrid approach of CP2 and TP2 delivered the best performance. We chose TP2 specifically to align the workload with Ironwood’s dual-chiplet architecture. This allows the frequent communications to occur over the internal die-to-die (D2D) interface — which is 6x faster than the standard ICI. By balancing the communication load between the D2D interface and the ICI fabric, we achieved a 4% performance improvement compared to using CP4 alone.
  • Leveraging kernels optimized for Ironwood - Max logits estimate: The Tokamax Splash Attention kernel can be further optimized by setting a value for max_logit_const. If set, it replaces the reduction calculation of the max logit during attention’s softmax (softmax(QKᵀ)), saving some computation and synchronization overhead. In MaxText, this is exposed as the config use_max_logits_estimate, which can be set to None (disabled) or a floating-point value. We used this parameter to get a 4% reduction in step time.
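The hybrid layout described above can be expressed as MaxText overrides. This is an illustrative sketch: the parallelism flag names (ici_fsdp_parallelism, ici_tensor_parallelism, ici_context_parallelism) follow MaxText's base config, but verify them and the invocation against your MaxText version.

```shell
# Illustrative FSDP16 + TP2 + CP2 layout on 64 chips (16 x 2 x 2 = 64).
# Flag names follow MaxText's base config; verify for your version.
python3 MaxText/train.py MaxText/configs/base.yml \
  ici_fsdp_parallelism=16 \
  ici_tensor_parallelism=2 \
  ici_context_parallelism=2 \
  per_device_batch_size=0.25
```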

Case study: MoE 110B – short context (8k)

Training a 110B MoE model introduces unique structural inefficiencies because tokens are routed to specific “experts,” creating “ragged” batches where experts receive an imbalanced number of tokens. Consider the following benchmarking strategy:

  • Taking advantage of the architecture - SparseCore offload: We leveraged SparseCore offloading to handle the heavy communication requirements of the MoE architecture. This was essential to ensure that the communication collectives required for the expert layers did not stall the MXUs. Result: 15% decrease in step time.
  • Balancing compute and memory utilization - Sharding using FSDP: We started by experimenting with a hybrid approach of EP and FSDP. However, we saw that the All-to-All collective used in Expert Parallelism caused a large bottleneck that we were unable to optimize. We got the best performance using FSDP for this model.
  • Leveraging kernels optimized for Ironwood - Tokamax GMM kernel: Standard dense matrix multiplication kernels often struggle with uneven shapes, either wasting FLOPs on padding or dropping tokens. We employed Megablox because it performs only the necessary work for each expert using parallel dense GEMMs, without wasteful padding. Using tune-jax further optimized the tiling strategy for Ironwood’s SRAM and model dimensions. Result: 10% decrease in step time.

Get started

7th generation Ironwood TPUs are available for your frontier model training workloads. To learn more and get started:

  • Explore MaxText: Access our open source reference implementation for highly performant models in JAX.
  • Experiment with Tokamax kernels: Use our high-performance JAX and Pallas kernels library to address hardware-specific bottlenecks and optimize attention and MoE workloads.
  • Deploy optimized training recipes: Use these ready-to-use optimized recipes to understand techniques used to run common OSS models on Ironwood efficiently.