Tutorial: Multi-host Training with TPUs and SkyRL on GKE

SkyRL is a scalable and flexible reinforcement learning framework designed to handle the complex memory and compute patterns required for large language model (LLM) alignment. It is optimized for distributed environments, enabling efficient training by decoupling data, model weights, and the training engine.

Image reference: https://docs.skyrl.ai/docs/tinker/architecture

SkyRL provides a backend implementation for the Tinker API using JAX, allowing you to scale LLM post-training on TPUs or GPUs. The integration consists of three layers:

  1. API Layer: FastAPI HTTP server that accepts Tinker API requests, stores them in a database, and returns future IDs for async polling.

  2. Engine Layer: Background subprocess that polls the database, batches pending requests, and dispatches them to the backend.

  3. Backend Layer: Translates Tinker operations into training and inference calls, managing Ray workers, JAX, FSDP2/Megatron training, and vLLM inference.

See SkyRL architecture documentation for more details.
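To make the request flow concrete, here is a minimal in-process sketch of the future-ID pattern described above (illustrative names only, not SkyRL's actual classes; the real implementation uses a FastAPI server, a database, and a background engine subprocess):

```python
import uuid
from collections import deque

pending = deque()  # stands in for the database's pending-request table
results = {}       # stands in for completed results keyed by future ID

def submit(request):
    # API layer: store the request and immediately return a future ID.
    future_id = str(uuid.uuid4())
    pending.append((future_id, request))
    return future_id

def engine_step(batch_size=8):
    # Engine layer: drain up to batch_size pending requests and dispatch them.
    batch = [pending.popleft() for _ in range(min(batch_size, len(pending)))]
    for future_id, request in batch:
        results[future_id] = f"processed:{request}"  # backend call stub

def poll(future_id):
    # Client side: poll asynchronously; None until the engine has run.
    return results.get(future_id)

fid = submit("forward_backward")
assert poll(fid) is None   # not processed yet
engine_step()
assert poll(fid) == "processed:forward_backward"
```

The decoupling shown here is what lets the API server return instantly while training and inference work proceeds in the background.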

Why JAX and TPUs?

SkyRL utilizes JAX as its distributed backend for training and sampling during reinforcement learning. Because JAX is hardware-agnostic and offers primitives that scale seamlessly from local development to clusters of thousands of chips, SkyRL can extend model training across both GPUs and TPUs with few or no code changes.

Furthermore, SkyRL leverages JAX capabilities such as JIT compilation to achieve high-throughput execution. This significantly boosts performance, particularly on TPUs, which are natively optimized for JAX-based workloads.
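As a minimal illustration of the JIT path (a sketch, not SkyRL code), the same jitted function runs unmodified on CPU, GPU, or TPU, since XLA compiles it for whichever backend is present:

```python
import jax
import jax.numpy as jnp

# jax.jit traces the Python function once and compiles it with XLA for
# the available backend (CPU, GPU, or TPU); the code itself is unchanged.
@jax.jit
def scaled_matmul(x, y):
    return jnp.dot(x, y) * 2.0

x = jnp.eye(4)
print(scaled_matmul(x, x)[0, 0])  # 2.0 on any backend
```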

What We Are Building

In this tutorial, we will perform multi-host training (FSDP and tensor parallelism) on Qwen3-8B using SkyRL on GKE and Trillium TPUs.

Prerequisites

To follow along, you need a GKE cluster with a TPU v6e node pool in a 4x4 topology (4 hosts with 4 chips each), and the gcloud CLI and kubectl installed locally.

Step 1: Configure access to the GKE cluster

Get cluster credentials to access the GKE cluster:

export PROJECT_ID=<YOUR_PROJECT_ID>
export CLUSTER_REGION=<YOUR_REGION>
export CLUSTER_NAME=<YOUR_CLUSTER_NAME>

gcloud container clusters get-credentials $CLUSTER_NAME --region $CLUSTER_REGION --project $PROJECT_ID

Step 2: Install the JobSet controller

In this tutorial, we will use the JobSet API to orchestrate parallel execution of SkyRL workers.

First, install the JobSet CRDs and controller:

VERSION=v0.11.1
kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/$VERSION/manifests.yaml

Verify installation by confirming that the JobSet controller is running:

kubectl -n jobset-system get pods

Step 3: Deploy JobSet

In this tutorial, we will create a JobSet resource that will orchestrate and deploy each SkyRL worker on a 4x4 TPU v6e slice. Worker 0 is configured to start the SkyRL Tinker API server and engine.

Create the JobSet manifest:

# jobset.yaml
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: skyrl-tx-job
spec:
  failurePolicy:
    maxRestarts: 0
  replicatedJobs:
    - name: skyrl-workers
      replicas: 1
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              affinity:
                podAffinity:
                  requiredDuringSchedulingIgnoredDuringExecution:
                  - labelSelector:
                      matchExpressions:
                      - key: jobset.sigs.k8s.io/jobset-name
                        operator: In
                        values:
                        - skyrl-tx-job
                    topologyKey: cloud.google.com/gke-nodepool 
              subdomain: skyrl-tx-job
              restartPolicy: Never
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              containers:
              - name: worker
                image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
                securityContext:
                  privileged: false
                command:
                - bash
                - -c
                - |
                  set -ex

                  git clone https://github.com/NovaSky-AI/SkyRL.git
                  cd SkyRL

                  if [ "$JOB_COMPLETION_INDEX" -eq "0" ]; then
                    uv run --extra tpu --extra tinker --extra jax -m skyrl.tinker.api \
                    --base-model "Qwen/Qwen3-8B" \
                    --backend-config "{\"train_micro_batch_size\": 8, \"sample_max_num_sequences\": 256, \"tensor_parallel_size\": 4, \"fully_sharded_data_parallel_size\": 4, \"num_processes\": 4, \"coordinator_address\": \"skyrl-tx-job-skyrl-workers-0-0.skyrl-tx-job:7777\"}"
                  else
                    sleep 60
                    uv run --extra tpu --extra tinker --extra jax -m skyrl.backends.jax \
                    --coordinator-address skyrl-tx-job-skyrl-workers-0-0.skyrl-tx-job:7777 \
                    --num-processes 4 \
                    --process-id "$JOB_COMPLETION_INDEX"
                  fi
                resources:
                  requests:
                    google.com/tpu: 4
                  limits:
                    google.com/tpu: 4

Apply the resource with kubectl:

kubectl apply -f jobset.yaml
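The numbers in backend-config are tied to the slice shape: a 4x4 v6e slice exposes 16 chips across 4 hosts, and the tensor-parallel and FSDP sizes must multiply to the total device count. A quick sanity check, mirroring the manifest's values in plain Python:

```python
# Values from the JobSet manifest above.
num_processes = 4      # parallelism / completions: one process per host
chips_per_process = 4  # google.com/tpu: 4 in the pod resources
total_devices = num_processes * chips_per_process  # 16 chips in a 4x4 slice

tensor_parallel_size = 4
fsdp_size = 4          # fully_sharded_data_parallel_size

# The 2D device mesh (FSDP x TP) must cover every chip exactly once.
assert tensor_parallel_size * fsdp_size == total_devices
print(f"mesh: {fsdp_size} x {tensor_parallel_size} over {total_devices} devices")
```

If you change the topology (for example to 4x8), adjust num_processes and the parallelism sizes so this product still matches the chip count.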

Step 4: Verify the Tinker API server and engine

Retrieve logs from the pod with completion index 0, which runs the Tinker API server and engine:

kubectl logs -f -l batch.kubernetes.io/job-completion-index=0,job-name=skyrl-tx-job-skyrl-workers-0

Verify from the output that the engine and workers have started:

Installed 144 packages in 650ms
2026-03-26 17:18:44,231 - INFO - uvicorn.error: Started server process [401]
2026-03-26 17:18:44,233 - INFO - uvicorn.error: Waiting for application startup.
2026-03-26 17:18:45,877 - INFO - skyrl: Using internal engine for inference
2026-03-26 17:18:45,878 - DEBUG - skyrl: Detected API server uv run flags:
['--extra', 'tpu', '--extra', 'tinker', '--extra', 'jax']
2026-03-26 17:18:45,884 - INFO - skyrl: Started background engine with PID 404:
uv run --extra tpu --extra tinker --extra jax --extra tinker --extra jax -m
skyrl.tinker.engine --base-model Qwen/Qwen3-8B --backend jax --backend-config
{"train_micro_batch_size": 8, "sample_max_num_sequences": 256,
"tensor_parallel_size": 4, "fully_sharded_data_parallel_size": 4,
"num_processes": 4, "coordinator_address":
"skyrl-tx-job-skyrl-workers-0-0.skyrl-tx-job:7777"} --checkpoints-base
/tmp/skyrl_checkpoints --database-url
sqlite:////jax-ai-image/SkyRL/skyrl/tinker/tinker.db
--external-inference-api-key EMPTY --external-inference-lora-base
/tmp/lora_models --session-cleanup-interval-sec 60 --session-timeout-sec 300
2026-03-26 17:18:45,886 - INFO - uvicorn.error: Application startup complete.
2026-03-26 17:18:45,888 - INFO - uvicorn.error: Uvicorn running on
http://0.0.0.0:8000 (Press CTRL+C to quit)
warning: The `extra-build-dependencies` option is experimental and may change without warning. Pass `--preview-features extra-build-dependencies` to disable this warning.
WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.
2026-03-26 17:20:32,331 - INFO - skyrl: JAX distributed initialized:
process_id=0 (4 total), local devices: 4, total devices: 16
Fetching 11 files: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 11/11 [00:06<00:00,  1.83it/s]

Step 5: Run the training loop

Create a training script to fine-tune a Qwen3-8B model. The script implements a training loop designed to teach the model English-to-Pig Latin translation. It manages the full optimization cycle by executing forward-backward passes to compute loss and iteratively updating the model’s weights using the Adam optimizer. Each training step interacts with the SkyRL backend through the Tinker API.

train.py

import tinker
import numpy as np
from tinker import types

# Connect to the local server
service_client = tinker.ServiceClient(base_url="http://localhost:8000", api_key="tml-dummy")
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-8B")
tokenizer = training_client.get_tokenizer()

# Training examples
examples = [
    {"input": "banana split", "output": "anana-bay plit-say"},
    {"input": "quantum physics", "output": "uantum-qay ysics-phay"},
    {"input": "coding wizard", "output": "oding-cay izard-way"},
]

def process_example(example, tokenizer):
    prompt = f"English: {example['input']}\nPig Latin:"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)

    tokens = prompt_tokens + completion_tokens
    weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens)

    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=tokens[:-1]),
        loss_fn_inputs=dict(weights=weights[1:], target_tokens=tokens[1:])
    )

processed = [process_example(ex, tokenizer) for ex in examples]

# Training loop
for _ in range(6):
    fwdbwd = training_client.forward_backward(processed, "cross_entropy").result()
    training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result()

    logprobs = np.concatenate([o['logprobs'].tolist() for o in fwdbwd.loss_fn_outputs])
    weights = np.concatenate([e.loss_fn_inputs['weights'].tolist() for e in processed])
    print(f"Loss: {-np.dot(logprobs, weights) / weights.sum():.4f}")
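The loss printed each step is the weighted negative log-likelihood over completion tokens only, since prompt tokens carry weight 0. A toy calculation with hypothetical per-token log-probs shows the same arithmetic as the loop above:

```python
import numpy as np

# Hypothetical log-probs for a 5-token sequence; the first 2 tokens are
# the prompt and are masked out of the loss by their zero weights.
logprobs = np.array([-2.0, -1.5, -3.0, -2.5, -1.0])
weights = np.array([0, 0, 1, 1, 1])

loss = -np.dot(logprobs, weights) / weights.sum()
print(f"Loss: {loss:.4f}")  # Loss: 2.1667
```

Only the last three log-probs contribute, so the model is never penalized for how it would have predicted the prompt itself.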

Copy train.py to worker 0 and run it:

POD=$(kubectl get pod -l batch.kubernetes.io/job-completion-index=0,job-name=skyrl-tx-job-skyrl-workers-0 -o custom-columns=":metadata.name" --no-headers)

kubectl cp train.py $POD:/jax-ai-image/SkyRL/train.py
kubectl exec -ti $POD -- sh -c 'cd SkyRL && uv run train.py'

Verify results by viewing the training loss in the output:

Loss: 3.9209
Loss: 3.1629
Loss: 2.2993
Loss: 1.6949
Loss: 1.1634
Loss: 0.5750