Minimal Speedup When Reproducing the JAX All-Gather Overlap Example (Looking for Guidance)

Hi everyone,

I’m investigating how to achieve better compute/communication overlap on TPU. As part of this, I tried to reproduce “Example 1: All-gather on one side” from the JAX shard_map notebook:
https://docs.jax.dev/en/latest/notebooks/shard_map.html#example-1-all-gather-on-one-side

In the notebook, the example shows a large improvement (roughly 353 µs → 226 µs). However, on my setup I only see a very small improvement:

  • Platform: v5litepod-4 (v2-alpha-tpuv5-lite)

  • Observed latency: ~330 µs (basic) → ~319 µs (overlapped_bidi)

My understanding is that the notebook's timings may come from a setup with additional axes of parallelism, which could be a key reason my reproduction doesn't show a similar gain.
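
For reference, here is a minimal sketch of what I think "additional axes of parallelism" could mean: running the same matmul on a 2D mesh, so each collective step overlaps with a different amount of local compute. The 2x2 mesh shape and the 'j' axis name are my assumptions, not something the notebook states:

import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P

# Hypothetical 2D mesh (my assumption): reshape the 4 chips into 2x2.
devices_2d = np.array(jax.devices()).reshape(2, 2)
mesh_2d = Mesh(devices_2d, ('i', 'j'))

# One possible 2D sharding: lhs along M on 'i'; rhs along K on 'i' and N on 'j'.
lhs_spec_2d = P('i', None)
rhs_spec_2d = P('i', 'j')
out_spec_2d = P('i', 'j')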

I’m including the script I used for testing below.

Thanks in advance for any pointers or guidance.

import jax
import jax.numpy as jnp
from jax import device_put
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental.shard_map import shard_map
import time
import os
from functools import partial
import jax.lax as lax

# ============== Configuration ==============
M, K, N = 8192, 8192, 1024
NUM_WARMUP = 5
NUM_ITERS = 10
TRACE_DIR = "/home/moha/jax-trace-allgather"

# Select kernel mode: "basic", "overlapped", or "overlapped_bidi"
KERNEL_MODE = "overlapped_bidi"

# ============== Initialization ==============
devices = jax.devices()
num_devices = len(devices)
print(f"🚀 Available TPU devices: {num_devices}")
print(f"   Device type: {devices[0].device_kind}")

# Create a 1D mesh
mesh = Mesh(devices, ('i',))

# ============== Create test data ==============
lhs = jnp.ones((M, K), dtype=jnp.bfloat16)
rhs = jnp.ones((K, N), dtype=jnp.bfloat16)

# Sharding specs
# In the all-gather matmul setup:
# - lhs is sharded along the M dimension (rows)
# - rhs is sharded along the K dimension (rows)
# The basic kernel all-gathers rhs to reconstruct the full rhs and then runs a
# local matmul; the overlapped variants pipeline the same transfer via ppermute.
lhs_spec = P('i', None)  # Shard along M
rhs_spec = P('i', None)  # Shard along K
out_spec = P('i', None)  # Output sharded along M (matches lhs)
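# With num_devices == 4, per-shard shapes are lhs: (2048, 8192), rhs: (2048, 1024).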

# Create NamedSharding for device_put
lhs_sharding = NamedSharding(mesh, lhs_spec)
rhs_sharding = NamedSharding(mesh, rhs_spec)

# Place arrays with sharding
lhs = device_put(lhs, lhs_sharding)
rhs = device_put(rhs, rhs_sharding)

print(f"   lhs shape: {lhs.shape}, sharding: {lhs_spec}")
print(f"   rhs shape: {rhs.shape}, sharding: {rhs_spec}")

# ============== Define kernel functions ==============

# --- Option 1: basic all-gather ---
def matmul_allgather_kernel(lhs_block, rhs_block):
    """All-gather rhs_block to reconstruct the full rhs, then do a local matmul."""
    rhs_full = jax.lax.all_gather(rhs_block, 'i', tiled=True)
    return lhs_block @ rhs_full

# --- Option 2: overlapped (overlap communication and compute) ---
def matmul_allgather_overlapped_kernel(lhs_block, rhs_block):
    """Ring ppermute rhs blocks to pipeline transfers and overlap comm/compute."""
    size = num_devices
    idx = jax.lax.axis_index('i')
    shift = partial(jax.lax.ppermute, axis_name='i',
                    perm=[(i, (i + 1) % size) for i in range(size)])

    # lhs_block is (M // size, K): slice its K columns into `size` chunks of
    # width B; chunk j multiplies the rhs rows originally owned by device j.
    B = lhs_block.shape[1] // size
    lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)

    # First iteration: use local rhs_block with the corresponding lhs slice.
    out_block = lhs_blocks(idx) @ rhs_block

    # Subsequent iterations: rotate rhs_block around the ring and accumulate.
    for i in range(1, size):
        rhs_block = shift(rhs_block)
        out_block += lhs_blocks((idx - i) % size) @ rhs_block

    return out_block

# --- Option 3: overlapped bidirectional (bidirectional ring overlap) ---
def matmul_allgather_overlapped_bidi_kernel(lhs_block, rhs_block):
    """Bidirectional ring ppermute of rhs blocks.

    Split rhs_block into two halves and send them in opposite directions to improve
    link utilization while overlapping communication with computation.
    """
    size = num_devices
    idx = jax.lax.axis_index('i')
    shift_up = partial(jax.lax.ppermute, axis_name='i',
                       perm=[(i, (i + 1) % size) for i in range(size)])
    shift_dn = partial(jax.lax.ppermute, axis_name='i',
                       perm=[(i, (i - 1) % size) for i in range(size)])

    # On each device, split lhs_block columns into size*2 sub-blocks of width B.
    B = lhs_block.shape[1] // size // 2  # half-size blocks
    lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2 * i + hi) * B, B, 1)

    # Split rhs_block along the row dimension into lower/upper halves.
    rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)

    # First iteration: multiply local rhs halves with the corresponding lhs sub-blocks.
    out_block = lhs_blocks(idx, 0) @ rhs_block_lo
    out_block += lhs_blocks(idx, 1) @ rhs_block_hi

    # Subsequent iterations: send the lower half up and the upper half down.
    for i in range(1, size):
        rhs_block_lo = shift_up(rhs_block_lo)
        rhs_block_hi = shift_dn(rhs_block_hi)
        out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
        out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi

    return out_block


# Select the kernel based on KERNEL_MODE
if KERNEL_MODE == "overlapped_bidi":
    kernel_fn = matmul_allgather_overlapped_bidi_kernel
    kernel_name = "matmul_allgather_overlapped_bidi"
elif KERNEL_MODE == "overlapped":
    kernel_fn = matmul_allgather_overlapped_kernel
    kernel_name = "matmul_allgather_overlapped"
else:
    kernel_fn = matmul_allgather_kernel
    kernel_name = "matmul_allgather_basic"

print(f"\n⚙️  Kernel mode: {kernel_name}")

# Wrap with shard_map + jit
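# check_rep=False below matches the notebook's shard_map examples.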
matmul_fn = jax.jit(shard_map(
    kernel_fn,
    mesh=mesh,
    in_specs=(lhs_spec, rhs_spec),
    out_specs=out_spec,
    check_rep=False
))

# ============== Warmup ==============
print(f"\n🔥 Warming up ({NUM_WARMUP} iterations)...")
for _ in range(NUM_WARMUP):
    result = matmul_fn(lhs, rhs)
    result.block_until_ready()
    # jax.debug.visualize_array_sharding(result)
print("  Warmup complete.")

# ============== Benchmark ==============
print(f"\n📊 Running benchmark ({NUM_ITERS} iterations)...")
start = time.time()
for _ in range(NUM_ITERS):
    result = matmul_fn(lhs, rhs)
    result.block_until_ready()
elapsed = time.time() - start

flops = 2 * M * K * N * NUM_ITERS
tflops = flops / elapsed / 1e12

print(f"   Kernel: {kernel_name}")
print(f"   Matrix size: {M}x{K} @ {K}x{N} (bfloat16)")
print(f"   Total time: {elapsed:.3f}s for {NUM_ITERS} iterations")
print(f"   Per iteration: {elapsed/NUM_ITERS*1000:.2f}ms")
print(f"   Throughput: {tflops:.2f} TFLOPS")
print(f"   Output shape: {result.shape}")

# ============== Profiling ==============
print(f"\n🔬 Running profiler (10 iterations)...")
os.makedirs(TRACE_DIR, exist_ok=True)

with jax.profiler.trace(TRACE_DIR, create_perfetto_trace=True):
    for i in range(10):
        result = matmul_fn(lhs, rhs)
        result.block_until_ready()

print(f" ✅ Trace saved to: {TRACE_DIR}")
