Inside the optimization of FP8 training on Ironwood

Authors: @Amanda_Liang @Parmita_Mehta

Amongst many other innovations, Google’s Ironwood chips are our first TPUs to support 8-bit floating point (FP8) precision, accelerating AI training and inference by reducing memory usage and doubling throughput compared to 16-bit formats (FP16/BF16). Importantly, this throughput gain comes while maintaining model quality statistically similar to higher-precision BF16 baselines.

Central to this capability is Ironwood’s native integration of FP8 formats directly within its Matrix Multiply Units (MXUs). Unlike rigid integer quantization, this architecture allows the silicon to process specialized numerical representations tailored for specific deep learning tasks. This flexibility enables the system to prioritize precision for weights and activations while allocating wider dynamic range for gradients, effectively unlocking aggressive quantization techniques such as coarse scaling and deterministic rounding that are typically infeasible with integer-only math.
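
To make this trade-off concrete, the two FP8 formats can be inspected directly in JAX. The quick check below is framework-level only and assumes nothing Ironwood-specific: E4M3FN spends more bits on the mantissa for finer precision, while E5M2 spends more bits on the exponent for a wider dynamic range.

import jax.numpy as jnp

# Compare the two FP8 formats: max representable value (dynamic range)
# versus machine epsilon (precision).
for name, dtype in [("E4M3FN", jnp.float8_e4m3fn), ("E5M2", jnp.float8_e5m2)]:
    info = jnp.finfo(dtype)
    print(f"{name}: max={float(info.max)}, eps={float(info.eps)}")
# E4M3FN: smaller range, finer precision (weights and activations).
# E5M2: larger range, coarser precision (gradients).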

We implemented many capabilities and optimizations for Ironwood to enable FP8 for your workloads. These include:

  • Production-ready recipes, such as those used for DeepSeek v3, demonstrate how to achieve highly efficient training on the new Ironwood architecture, so you can optimize performance.

  • Specialized FP8 formats like E4M3FN for forward passes, and E5M2 for backward passes, which preserve dynamic range and ensure numerical stability without compromising accuracy.

  • Advanced tuning capabilities, such as activation host offloading and SparseCore communication offloading, keep TensorCores fed and hide system collectives, mitigating system bottlenecks.

You can begin using these optimizations immediately with MaxText and the JAX ecosystem. Our FP8 training recipe is implemented through Qwix and can be enabled by setting a few flags in the MaxText configuration (details below). You can consult the Qwix user guide for guidance on quantizing your specific models.

Below is a deep dive into the journey we took and technical decisions we made to build this stack. After reading this blog, you will have techniques to effectively train your models with FP8 on Ironwood.

The journey from INT4/8 to FP8

To implement the first FP8 recipe on TPUs, we developed new scaling and quantization strategies specifically for Ironwood, rather than relying on established GPU or integer-based TPU workflows. These strategies are available for you to use as well.

We also optimized several state-of-the-art models for this architecture, and they are available as production-ready recipes. Our goal was to improve throughput using Ironwood’s FP8 capabilities while keeping quality neutral relative to BF16 baselines. Using DeepSeek v3 as a case study, we demonstrate below a production-ready methodology for FP8 training on Ironwood.

A case study on DeepSeek v3 training

We explored and developed a customized FP8 recipe for DeepSeek v3. You can repurpose the learnings from this case study when developing your own FP8 models for Ironwood.

1. Defining the quantization scope

For Sparse Mixture-of-Experts (MoE) architectures, profiling consistently reveals that the highest quantization impact lies in the MLP and attention projections, where matrix multiplication dominates the runtime. To improve efficiency in these bandwidth-heavy models, it is also critical to quantize Megablox kernel weight all-gathers, which significantly reduces communication and quantization overhead.

We also explored quantizing splash attention kernels, but quickly discovered that this approach led to significant quality degradation, even when applying the most conservative FP8 recipes. Furthermore, the performance goals could not be met because the kernels were heavily constrained by the Vector Processing Unit (VPU), which handles complex element-wise operations, rather than the MXU where FP8 provides its acceleration. As a result, converting the matrix operations to FP8 yielded no meaningful latency reduction, as the VPU bottleneck remained the dominant limiting factor.

2. DeepSeek FP8 training recipe

The final recipe for DeepSeek v3 training on Ironwood delivers both quality and performance:

  • Rounding method: Round to Nearest Even (RNE). We chose this deterministic approach over stochastic rounding to ensure reproducibility and eliminate training noise.

  • Precision formats:

    • Activations & weights: E4M3FN (to maximize precision in the forward pass)

    • Gradients: E5M2 (to capture the high dynamic range of the backward pass)

  • Scaling granularity: per-axis (see the sketch after this list). While per-tensor scaling was explored for performance, per-axis scaling was selected for the final launch to guarantee the highest model quality. The original DeepSeek papers selected per-block scaling with a block size of (1, 128) for pretraining, but this finer granularity was shown not to be needed in post-training.

  • Scaling mode: Hybrid.

    • Static scaling for weights and activations (pre-computed via profiling)

    • Dynamic scaling for gradients

  • Quantization scope:

    • FP8 weight All-Gather

    • All Megablox kernels - weights, activations and gradients

    • All Attention Projections - weights, activations and gradients
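
To make the recipe concrete, below is a minimal, self-contained JAX sketch of the per-axis, hybrid-scaling quantization described above. It is an illustration only, not the Qwix implementation; the helper names and the choice of scaling axis are assumptions.

import jax
import jax.numpy as jnp

E4M3_MAX = float(jnp.finfo(jnp.float8_e4m3fn).max)  # ~448
E5M2_MAX = float(jnp.finfo(jnp.float8_e5m2).max)    # ~57344

def quantize_per_axis(x, fp8_dtype, fp8_max, axis, scale=None):
    # One scale per slice along `axis` (per-axis granularity). If `scale` is
    # None it is computed dynamically from the tensor (as for gradients);
    # otherwise a pre-computed static scale is used (as for weights and
    # activations in the recipe above).
    if scale is None:
        amax = jnp.max(jnp.abs(x), axis=axis, keepdims=True)
        scale = amax / fp8_max
    scale = jnp.maximum(scale, jnp.finfo(jnp.float32).tiny)  # avoid divide-by-zero
    q = (x / scale).astype(fp8_dtype)  # the cast rounds with round-to-nearest-even
    return q, scale

def dequantize(q, scale):
    return q.astype(jnp.bfloat16) * scale.astype(jnp.bfloat16)

# Forward-pass tensors use E4M3FN (a dynamic scale is computed here for
# brevity; the recipe uses static, profiled scales for these tensors).
x = jax.random.normal(jax.random.PRNGKey(0), (128, 512), jnp.bfloat16)
qx, sx = quantize_per_axis(x, jnp.float8_e4m3fn, E4M3_MAX, axis=1)
x_hat = dequantize(qx, sx)

# Gradients use E5M2 with dynamic scaling to capture their wider range.
g = jax.random.normal(jax.random.PRNGKey(1), (128, 512), jnp.float32) * 1e-3
qg, sg = quantize_per_axis(g, jnp.float8_e5m2, E5M2_MAX, axis=1)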

We implemented the recipe through Qwix, and it can be enabled in MaxText by specifying the following flags:

quantization=fp8_full \
weight_quantization_calibration_method="fixed,-224,224" \
act_quantization_calibration_method="fixed,-224,224" \
use_qwix_quantization=true

You can find a detailed example here. For more information about customizing the recipe, please refer to the Qwix user guide.

3. Retune the FP8 model

The transition to FP8 introduces a classic optimization challenge: As compute becomes significantly faster, it exposes system bottlenecks that were previously hidden by slower math operations. With matrix multiplications accelerating, the relative cost of communication and data movement increases, requiring us to extensively re-tune XLA flags and adjust instruction scheduling to better hide collectives. We also leveraged activation host offloading to manage memory pressure and utilized Fully Sharded Data Parallelism (FSDP), which proved sufficient for the model scale when paired with these scheduler adjustments.

To keep the accelerated TensorCores fed, we offloaded heavy communication tasks (e.g., MoE token dispatch and collective operations) to the SparseCores. This included an XLA flag to enable Reduce-Scatter that decomposes operations to effectively utilize Inter-Chip Interconnect (ICI) bandwidth in the absence of Megacore support.

Finally, we optimized the compute path by tuning Megablox tiling strategies to match the increased throughput of the FP8 compute engine.

4. Results: Performance meets quality

Performance

At the time of writing, FP8 DeepSeek v3 training on Ironwood achieved 3307 tokens/s/chip with FSDP sharding on 256 chips, a ~1.3x speedup over the BF16 baseline (2590 tokens/s/chip).

Quality

The training loss curves below demonstrate that our FP8 recipe (blue) closely tracks the BF16 baseline (orange), proving that aggressive quantization can be applied without compromising the quality of the model.

Ready to train your own FP8 models?

Ironwood is the first TPU to fully embrace native FP8 support — a capability that has driven efficiency in the GPU ecosystem for years. By moving away from legacy integer formats, Google’s custom silicon is now aligned with the modern standard for high-performance AI training.

However, hardware support alone is not enough. The recipes and optimizations detailed here are the bridge that makes that hardware usable. This work allows you to use Ironwood TPUs with FP8 low precision, ensuring you can deploy your most demanding workloads on this new architecture with the same confidence and quality you expect from higher-precision formats. To learn more and get started, check out this guide.


Hi there!

Could you share more details about blockwise (1x128) FP8 quantization on TPU (maybe with Qwix and MaxText)? I tried to add support for blockwise (1x128) FP8 quantization in MaxText via Qwix, but encountered performance issues. In my case, FP8 dot_general (qwix _fast_dot_general) is much slower than BF16 on TPU v7. I find it hard to optimize it further.

[Profiling screenshots: bf16 vs. blockwise fp8]

Could you share your experiences with it? Do you think that per-block FP8 quantization is inherently unsuitable for TPUs, or might there be specific optimizations or configurations I’m missing?

Any insights or suggestions would be greatly appreciated!

Hi Zh,

You said you’re using qwix _fast_dot_general for the blockwise quantization. This means the fusion will be determined by the XLA cost model. Performant subchannel (blockwise) support through XLA is still under development.

This is not a TPU limitation. You can still achieve highly performant FP8 kernels through Pallas. We saw nearly 2x speedup on Ironwood for FP8 matmuls vs BF16.

I would suggest using Pallas kernels and writing your blockwise quantization directly in the kernel. See an example here.
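
For reference, here is a rough sketch of that approach: a Pallas matmul that applies DeepSeek-style per-K-block scales inside the kernel, so the blockwise scaling is fused with the dot rather than left to XLA fusion. This is illustrative only and is not the linked example; the block size, operand layout, and scale layout are assumptions, and the in-kernel FP8 dot assumes a TPU generation with FP8 MXU support (e.g. Ironwood).

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

BLOCK_K = 128  # illustrative quantization block along the contraction dim

def blockwise_fp8_matmul_kernel(x_ref, w_ref, sx_ref, sw_ref, o_ref):
    # x_ref: (M, K) float8_e4m3fn, sx_ref: (M, K // BLOCK_K) per-block scales.
    # w_ref: (K, N) float8_e4m3fn, sw_ref: (K // BLOCK_K, N) per-block scales.
    acc = jnp.zeros(o_ref.shape, jnp.float32)
    for b in range(x_ref.shape[1] // BLOCK_K):
        ks = slice(b * BLOCK_K, (b + 1) * BLOCK_K)
        # FP8 dot on one K block, accumulated in float32.
        part = jnp.dot(x_ref[:, ks], w_ref[ks, :],
                       preferred_element_type=jnp.float32)
        # Apply the per-block scales of both operands to the partial product.
        acc += part * sx_ref[:, b][:, None] * sw_ref[b, :][None, :]
    o_ref[...] = acc.astype(jnp.bfloat16)

def blockwise_fp8_matmul(x_fp8, w_fp8, sx, sw):
    m, _ = x_fp8.shape
    _, n = w_fp8.shape
    return pl.pallas_call(
        blockwise_fp8_matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.bfloat16),
    )(x_fp8, w_fp8, sx, sw)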


Hi Amanda,

I have a few questions I’d like to discuss with you.
1. Have you tried parallelism strategies in MaxText such as EP and TP? We found that the all-to-all performance of EP is very poor. In the best practices you provided, only FSDP is used. In scenarios with large gradient accumulation, the communication overhead can become very high.
2. Have you used the num_vocab_tiling parameter before? When computing the loss, num_vocab_tiling can significantly reduce peak HBM usage. However, because the backward pass is also performed chunk-wise, it can introduce numerical precision issues when summing and averaging the final gradients (it seems that the XLA compiler changes the order of gradient accumulation).
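
(As a tiny, standalone illustration of that reordering effect, unrelated to MaxText itself: floating-point addition is not associative, so a chunk-wise sum generally differs slightly from a single-pass sum.)

import numpy as np
import jax.numpy as jnp

# Summing the same values in a different order gives a slightly different
# float32 result, which is the kind of deviation chunk-wise loss/gradient
# accumulation can introduce.
x = np.random.default_rng(0).normal(size=(8, 4096)).astype(np.float32)
one_pass = jnp.sum(x)
chunked = sum(jnp.sum(chunk) for chunk in jnp.split(x, 8, axis=0))
print(one_pass - chunked)  # typically small but nonzero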

Hi Kexi,

  1. Yes there are engineers at Google working on EP/TP for some MoE models. I agree that communication overhead will be high if we don’t do it carefully with EP/TP.

  2. You’re right that using num_vocab_tiling could have some numerical complications. We did not use it in the DeepSeek v3 training. For some training, this numerical variance is harmless and gets absorbed as normal training noise.

Have you tried other techniques to reduce HBM usage, such as micro-batching or dropping some checkpointing? Also, the next-gen TPU8i and future TPUs will support sparsity, which will help with memory-bound matmuls. FP4 will be supported on TPU 8t too.

Hi Kexi,

For the vocab tiling, we have some tests protecting gradient numerical accuracy; see here. How large is the numerical precision deviation you observe in your experiments?

For EP, you could consider trying use_ring_of_experts=true to avoid expensive all2all.

For large GA (gradient accumulation), you could try Zero1 + explicit sharding, using the following flags:

  • ici_data_parallelism=X
  • shard_optimizer_over_data=true
  • shard_mode=explicit

This feature is under development and currently only supports the Llama model, though. Zero1 alone only brings a higher HBM usage burden.


Thank you

seq len = 4096, fsdp=8, num_vocab_tiling=8, per_device_batch_size=2, gradient_accumulation_steps=1.
After 5 steps, the MTP loss diff is over 0.1% compared with Megatron on GPU. Within 10 steps, the lm_loss diff is less than 0.1%.

Hi Amanda_Liang,

Thanks for your suggestions!

We’re running MoE GMM (Grouped Matrix Multiply) on TPU v7x and evaluating whether FP8 blockwise quantization (DeepSeek-style, block_size=128) can outperform BF16 Megablox.

Problem: Through Pallas kernel profiling, we found that even without quantization/dequantization overhead, the pure matmul itself is already slower for FP8 block128 compared to BF16.

Setup:

Shape: M=1280, K=2048, N=1024 (single expert tile, wi_fwd)

Grid: 256 groups (simulating full GMM)

BF16: single dot(1280, 2048) × (2048, 1024)

FP8 block128: K-dim split into 16 iterations of dot(1280, 128) × (128, 1024), accumulated sequentially

Measured results (pure MXU dot, no quantization/scale ops):

XLA estimated_cycles confirmation:

dot(1280, 2048) × (2048, 1024) FP8: 12,951 cycles

dot(1280, 128) × (128, 1024) FP8: 10,017 cycles

Effective compute per small dot is only ~809 cycles; 92% is pipeline fill/drain overhead

16 small dots total 160,272 cycles vs 20,128 cycles for BF16 large dot

XProf LLO utilization trace confirms: MXU row is nearly empty during block128 kernel execution, while Vector ALU shows more activity.

Our understanding: With block_size=128, the K-dimension is forced into 128-element chunks. Each small dot suffers extreme MXU pipeline inefficiency — K=128 is too short to saturate the systolic array. This appears to be a fundamental mismatch between blockwise quantization and the systolic array architecture. Even with a perfectly optimized kernel, FP8 block128 cannot outperform BF16.

Questions:

Is this understanding correct? Are there any compiler or hardware optimization paths we may have missed?

Are there plans for TPU v7x (or future generations) to support hardware-level microscaling, similar to NVIDIA Blackwell’s MXFP, where the MXU handles per-block scales internally?

For FP8 GMM on TPU, is there a recommended quantization granularity or approach?

Hi Zh,

Is this understanding correct? Are there any compiler or hardware optimization paths we may have missed?

Yes, your analysis is correct. On Ironwood (TPU v7), the MXU size is 256 (double previous generations), so you cannot fully utilize the MXU if the block size is not a multiple of 256. On top of that, we found that subchannel sizes smaller than 384 on Ironwood cause VPU boundness. Therefore, the recommended smallest out-of-the-box subchannel size on Ironwood is 512. If you try this size, your MXU will be fully saturated. Please let me know how it goes.

Are there plans for TPU v7x (or future generations) to support hardware-level microscaling, similar to NVIDIA Blackwell’s MXFP, where the MXU handles per-block scales internally?

Yes, starting with TPU 8t (coming out next year), native FP4 and block scaling/quantization will be supported: 16-element blocks for MXFP8 and 32-element blocks for MXFP4.

For FP8 GMM on TPU, is there a recommended quantization granularity or approach?

Yes, as noted in this article, we recommend per-channel granularity for DSv3 training, where we already observed quality neutrality without more fine-grained granularity.