MaxKernel: Automating Pallas Kernel Generation and Optimization via Agentic Systems

The Problem: The Performance Bottleneck

Getting the best performance from today’s ML workloads on machine learning accelerators requires careful software optimization. Today, that means writing handcrafted kernels, which are needed for both training and inference workloads.

ML engineers working on kernels currently face a significant gap in the ecosystem:

  • Need for Deep Expertise: Engineers need expertise across hardware, low-level software, compilers, and ML workloads in order to write optimized kernel code.

  • Time-Consuming Optimization: Currently, optimizing model code and writing new kernels or tuning existing ones for the TPU architecture is a laborious process that takes several weeks.

  • Complex, Multi-Step Workflows: Translating and optimizing these kernels isn’t a single step. It is a highly manual, iterative cycle that involves rigorous profiling, experimentation, code implementation, and validation.

The Solution: To overcome these friction points and accelerate development, there is a critical need for an intelligent, agentic system capable of automating the heavy lifting of kernel generation and optimization.

Overview

MaxKernel is an interactive, Human-in-the-Loop (HITL) agentic framework designed to solve the Pallas kernel bottleneck. It acts as an intelligent “co-pilot” for kernel engineers, transforming JAX code or existing GPU code (such as CUDA, Triton, or PyTorch) into highly optimized TPU kernel implementations.

MaxKernel guides users through a multi-stage workflow, providing a conversational interface that keeps you entirely in control. You can manually iterate on the generated code, reviewing and refining plans at every step.

Key Capabilities & Agent Architecture: MaxKernel operates using a sophisticated hierarchy of specialized sub-agents working together to orchestrate the kernel optimization lifecycle:

  • Plan-Driven Development (PlanKernelAgent): Before writing any code, the system creates detailed optimization plans. It suggests tiling strategies and memory optimizations and sets performance targets, all subject to user approval.

  • Kernel Implementation (ImplementKernelAgent): Once a plan is approved, this agent generates clean, idiomatic JAX/Pallas code. It leverages RAG (Retrieval-Augmented Generation) tied to official JAX/Pallas documentation to provide highly accurate, context-aware advice.

  • GPU-to-JAX Conversion (GpuToJaxAgent): This pipeline automatically converts existing CUDA, Triton, and PyTorch code into JAX by intelligently stripping out hardware-specific optimizations and generating equivalent JAX code.

  • Automated Testing & Validation (ValidatedTestGenerationAgent): The system automatically generates comprehensive pytest test suites. It runs compilation checks, verifies numerical correctness, and benchmarks performance, providing full tracebacks for easy debugging.

  • Performance Profiling (ProfileAgentOrchestrator): This agent analyzes your kernel’s execution, looking at DMA/memory transfers and the compute-vs-memory ratio. It identifies bottlenecks and provides actionable, data-driven optimization recommendations.

  • Safety & Control: MaxKernel operates with scoped file system permissions. All file access is restricted to a configurable work directory, ensuring that your system remains secure while you interact with the agent.
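The Safety & Control point above can be illustrated with a minimal sketch of path scoping. This is not MaxKernel's actual implementation; the class name ScopedWorkDir, its method, and the example directory are hypothetical, and the sketch only assumes that every path the agent touches is resolved and checked against a single configured work directory.

```python
from pathlib import Path


class ScopedWorkDir:
    """Hypothetical helper (not part of MaxKernel): confines file access to one root."""

    def __init__(self, root):
        # Resolve once so later comparisons use a canonical absolute path.
        self.root = Path(root).resolve()

    def resolve(self, requested):
        """Return the absolute path for `requested`, or raise if it escapes the root."""
        # Joining with an absolute path replaces the root entirely, and
        # resolve() collapses ".." segments and follows symlinks, so the
        # containment check below sees the real target location.
        candidate = (self.root / requested).resolve()
        if not candidate.is_relative_to(self.root):  # available in Python 3.9+
            raise PermissionError(f"{requested!r} is outside the work directory")
        return candidate


# Example work directory; the path does not need to exist for resolution.
scope = ScopedWorkDir("/tmp/maxkernel_work")
inside = scope.resolve("kernels/attention.py")      # allowed
try:
    scope.resolve("../outside.txt")                 # escapes the root
except PermissionError as err:
    rejected = str(err)
```

Resolving before checking matters: a plain string-prefix comparison would accept a path such as "kernels/../../outside.txt", whereas comparing canonical paths rejects it.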

Experimental Results: Real-World Impact

MaxKernel has already demonstrated significant benefits during internal dogfooding for inference workloads.

  • Inference Use Case 1: DeepSeek MLA kernel on the v5p TPU platform

    • Latency Improvement: 8.7% speedup (latency dropped from 3.12 ms to 2.856 ms) compared to a human-written, optimized Pallas kernel baseline.

    • Throughput Improvement: 9% increase (from 116.73 TFLOPS to 127.82 TFLOPS) compared to a human-written, optimized Pallas kernel baseline.

  • Inference Use Case 2: The MaxKernel agent addressed a critical crash in the RPAv3 kernel caused by unpadded inputs during the prefill phase. To resolve this, the agent implemented logic to prevent pipeline deadlocks, clamped values to avoid out-of-bounds DMA errors, and added defensive sizing to handle edge cases. These changes ensure robust processing of left-padded inputs with negligible performance impact. The fix has been captured in a pull request, ensuring correctness for chunked prefill sizes.
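As a quick cross-check of the Use Case 1 figures, the relative changes can be recomputed from the quoted latency and throughput values. The sketch below uses only the four numbers reported above; the helper name percent_change is illustrative. The results land close to, but not exactly on, the quoted 8.7% and 9%, which may reflect rounding of the published measurements (an assumption, since the unrounded data is not given here).

```python
def percent_change(old, new):
    """Relative change from `old` to `new`, in percent."""
    return (new - old) / old * 100.0


# Figures quoted for the DeepSeek MLA kernel on v5p.
baseline_ms, optimized_ms = 3.12, 2.856
baseline_tflops, optimized_tflops = 116.73, 127.82

# Latency reduction: how much shorter the optimized run is.
latency_reduction_pct = -percent_change(baseline_ms, optimized_ms)
# Speedup factor: baseline time divided by optimized time.
speedup = baseline_ms / optimized_ms
# Throughput gain: how much more work per second the optimized kernel does.
throughput_gain_pct = percent_change(baseline_tflops, optimized_tflops)

print(f"latency reduction: {latency_reduction_pct:.2f}%")  # 8.46%
print(f"speedup factor:    {speedup:.3f}x")                # 1.092x
print(f"throughput gain:   {throughput_gain_pct:.2f}%")    # 9.50%
```

Latency reduction and speedup are different ratios of the same two numbers, so it helps to state which one a percentage refers to when comparing kernels.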

Conclusion & How to Try It Out

MaxKernel represents a major step forward in automating Pallas kernel generation and optimization. MaxKernel is available as open-source software in the Google Accelerator Agents OSS GitHub repository.

Getting Started:

  1. Prerequisites: You will need a Python 3.9+ environment with JAX installed, access to Google Cloud Vertex AI (for the agent and RAG retrieval), and TPU access for kernel execution.

  2. Installation: Simply navigate to the repo directory and run our setup script.

    • bash prepare_hitl_agent.sh

    • This will prompt you to choose your Python environment, install dependencies, configure your scoped work directory, and generate a .env file for your Google Cloud credentials.

  3. Running the Agent: You can run MaxKernel in two modes based on your workflow:

    • CLI Mode (Recommended for Devs): Run bash run_hitl_agent.sh for a low-overhead, interactive command-line interface.

    • Web UI Mode: Run bash run_hitl_agent.sh --ui to spin up a visual interface on port 1430, complete with conversation history.

We invite the community to clone the repository and start building the next generation of highly optimized Pallas kernels!

Various teams and individuals within Google Cloud contributed to this effort. Special thanks to the core engineering team (in alphabetical order): Andi Gavrilescu, Newfel Harrat, Steven Ingram, Gerson Kroiz, Sethu Sankaran, Hassan Sipra, George Vanica, Shangkun Wang.

Deepak Patil - Group Product Manager

Nina Cai - Senior Machine Learning Software Developer
