SkyRL is a scalable and flexible reinforcement learning framework designed to handle the complex memory and compute patterns required for large language model (LLM) alignment. It is optimized for distributed environments, enabling efficient training by decoupling data, model weights, and the training engine.
Image reference: https://docs.skyrl.ai/docs/tinker/architecture

SkyRL provides a backend implementation for the Tinker API using JAX, allowing you to scale LLM post-training on TPUs or GPUs. The integration contains three layers:
- API Layer: FastAPI HTTP server that accepts Tinker API requests, stores them in a database, and returns future IDs for async polling.
- Engine Layer: Background subprocess that polls the database, batches pending requests, and dispatches them to the backend.
- Backend Layer: Translates Tinker operations into training and inference calls, managing Ray workers, JAX, FSDP2/Megatron training, and vLLM inference.
See SkyRL architecture documentation for more details.
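The request lifecycle across these layers can be sketched in a few lines. This is an illustrative, self-contained Python sketch of the future-ID polling pattern described above; an in-memory dict stands in for SkyRL's database, and all names here are hypothetical, not the actual SkyRL API:

```python
import uuid

# In-memory stand-in for the database the API layer writes to and the
# engine layer polls (hypothetical names, for illustration only).
pending_requests = {}
results = {}

def submit_request(payload):
    """API layer: store the request and return a future ID immediately."""
    future_id = str(uuid.uuid4())
    pending_requests[future_id] = payload
    return future_id

def engine_step():
    """Engine layer: batch all pending requests and dispatch them."""
    batch = list(pending_requests.items())
    pending_requests.clear()
    for future_id, payload in batch:
        # The backend layer would run training/inference here; we echo the payload.
        results[future_id] = {"status": "done", "output": payload}

def poll(future_id):
    """Client: poll until the engine has produced a result."""
    return results.get(future_id)

fid = submit_request({"op": "forward_backward"})
assert poll(fid) is None  # not processed yet; client keeps polling
engine_step()
assert poll(fid)["status"] == "done"
```

The key property this models is that the API layer never blocks on the backend: clients get a future ID right away and poll for completion.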
Why JAX and TPUs?
SkyRL uses JAX as its distributed backend for training and sampling during reinforcement learning. Because JAX is hardware-agnostic and offers primitives that scale from local development to clusters of thousands of chips, SkyRL can run model training across both GPUs and TPUs with few or no code changes.
Furthermore, SkyRL leverages JAX capabilities such as JIT compilation to achieve high-throughput execution. This significantly boosts performance, particularly on TPUs, which are natively optimized for JAX-based workloads.
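As a minimal illustration of the JIT compilation mentioned above (generic JAX, not SkyRL code; assumes `jax` is installed locally):

```python
import jax
import jax.numpy as jnp

# jax.jit traces the function once, compiles it with XLA, and reuses the
# compiled kernel on subsequent same-shaped calls; the identical code runs
# on CPU, GPU, or TPU without modification.
@jax.jit
def masked_mean(x, w):
    return jnp.dot(x, w) / w.sum()

x = jnp.array([2.0, 4.0, 6.0])
w = jnp.array([0.0, 1.0, 1.0])
print(masked_mean(x, w))
```

The first call pays a one-time compilation cost; later calls run the compiled kernel, which is where the throughput gains on TPUs come from.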
What We Are Building
In this tutorial, we will perform multi-host training (FSDP and tensor parallelism) on Qwen3-8B using SkyRL on GKE and Trillium TPUs.
Prerequisites
- GKE cluster with a 4x4 TPU v6e node pool. To get started, follow Deploy TPU workloads in GKE Standard
- kubectl installed, with the GKE authentication plugin.
Step 1: Configure access to the GKE cluster
Get cluster credentials to access the GKE cluster:
export PROJECT_ID=<YOUR_PROJECT_ID>
export CLUSTER_REGION=<YOUR_REGION>
export CLUSTER_NAME=<YOUR_CLUSTER_NAME>
gcloud container clusters get-credentials $CLUSTER_NAME --region $CLUSTER_REGION --project $PROJECT_ID
Step 2: Install the JobSet controller
In this tutorial, we will use the JobSet API to orchestrate parallel execution of SkyRL workers.
First, install the JobSet CRDs and controller:
VERSION=v0.11.1
kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/$VERSION/manifests.yaml
Verify installation by confirming that the JobSet controller is running:
kubectl -n jobset-system get pods
Step 3: Deploy JobSet
In this tutorial, we will create a JobSet resource that will orchestrate and deploy each SkyRL worker on a 4x4 TPU v6e slice. Worker 0 is configured to start the SkyRL Tinker API server and engine.
Create the JobSet manifest:
# jobset.yaml
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: skyrl-tx-job
spec:
  failurePolicy:
    maxRestarts: 0
  replicatedJobs:
  - name: skyrl-workers
    replicas: 1
    template:
      spec:
        parallelism: 4
        completions: 4
        backoffLimit: 0
        template:
          spec:
            affinity:
              podAffinity:
                requiredDuringSchedulingIgnoredDuringExecution:
                - labelSelector:
                    matchExpressions:
                    - key: jobset.sigs.k8s.io/jobset-name
                      operator: In
                      values:
                      - skyrl-tx-job
                  topologyKey: cloud.google.com/gke-nodepool
            subdomain: test
            restartPolicy: Never
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: 4x4
            containers:
            - name: worker
              image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
              securityContext:
                privileged: false
              command:
              - bash
              - -c
              - |
                set -ex
                git clone https://github.com/NovaSky-AI/SkyRL.git
                cd SkyRL
                if [ "$JOB_COMPLETION_INDEX" -eq "0" ]; then
                  uv run --extra tpu --extra tinker --extra jax -m skyrl.tinker.api \
                    --base-model "Qwen/Qwen3-8B" \
                    --backend-config "{\"train_micro_batch_size\": 8, \"sample_max_num_sequences\": 256, \"tensor_parallel_size\": 4, \"fully_sharded_data_parallel_size\": 4, \"num_processes\": 4, \"coordinator_address\": \"skyrl-tx-job-skyrl-workers-0-0.skyrl-tx-job:7777\"}"
                else
                  sleep 60
                  uv run --extra tpu --extra tinker --extra jax -m skyrl.backends.jax \
                    --coordinator-address skyrl-tx-job-skyrl-workers-0-0.skyrl-tx-job:7777 \
                    --num-processes 4 \
                    --process-id "$JOB_COMPLETION_INDEX"
                fi
              resources:
                requests:
                  google.com/tpu: 4
                limits:
                  google.com/tpu: 4
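Before applying the manifest, it is worth sanity-checking that the backend-config agrees with the TPU topology: a 4x4 v6e slice exposes 16 chips across 4 hosts (4 chips per host, matching `google.com/tpu: 4` in the pod spec), and the product of the tensor-parallel and FSDP sizes must cover all devices. An illustrative Python check of this arithmetic:

```python
import json

# Same backend-config JSON passed to worker 0 in the manifest above.
backend_config = json.loads(
    '{"train_micro_batch_size": 8, "sample_max_num_sequences": 256, '
    '"tensor_parallel_size": 4, "fully_sharded_data_parallel_size": 4, '
    '"num_processes": 4, "coordinator_address": '
    '"skyrl-tx-job-skyrl-workers-0-0.skyrl-tx-job:7777"}'
)

chips_per_host = 4  # google.com/tpu: 4 in the pod resources
total_devices = backend_config["num_processes"] * chips_per_host
mesh = (backend_config["fully_sharded_data_parallel_size"],
        backend_config["tensor_parallel_size"])

# The fsdp x tp mesh must tile the full device count exactly.
assert total_devices == 16
assert mesh[0] * mesh[1] == total_devices
print(f"mesh: fsdp={mesh[0]} x tp={mesh[1]} over {total_devices} devices")
```

This matches the startup log shown later: `process_id=0 (4 total), local devices: 4, total devices: 16`.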
Apply the resource with kubectl:
kubectl apply -f jobset.yaml
Step 4: Verify the Tinker API server and engine
Retrieve logs from the first replicated job running the Tinker API server and engine:
kubectl logs -f -l batch.kubernetes.io/job-completion-index=0,job-name=skyrl-tx-job-skyrl-workers-0
Verify that the engine and workers have started from the output:
Installed 144 packages in 650ms
2026-03-26 17:18:44,231 - INFO - uvicorn.error: Started server process [401]
2026-03-26 17:18:44,233 - INFO - uvicorn.error: Waiting for application startup.
2026-03-26 17:18:45,877 - INFO - skyrl: Using internal engine for inference
2026-03-26 17:18:45,878 - DEBUG - skyrl: Detected API server uv run flags:
['--extra', 'tpu', '--extra', 'tinker', '--extra', 'jax']
2026-03-26 17:18:45,884 - INFO - skyrl: Started background engine with PID 404:
uv run --extra tpu --extra tinker --extra jax --extra tinker --extra jax -m
skyrl.tinker.engine --base-model Qwen/Qwen3-8B --backend jax --backend-config
{"train_micro_batch_size": 8, "sample_max_num_sequences": 256,
"tensor_parallel_size": 4, "fully_sharded_data_parallel_size": 4,
"num_processes": 4, "coordinator_address":
"skyrl-tx-job-skyrl-workers-0-0.skyrl-tx-job:7777"} --checkpoints-base
/tmp/skyrl_checkpoints --database-url
sqlite:////jax-ai-image/SkyRL/skyrl/tinker/tinker.db
--external-inference-api-key EMPTY --external-inference-lora-base
/tmp/lora_models --session-cleanup-interval-sec 60 --session-timeout-sec 300
2026-03-26 17:18:45,886 - INFO - uvicorn.error: Application startup complete.
2026-03-26 17:18:45,888 - INFO - uvicorn.error: Uvicorn running on
http://0.0.0.0:8000 (Press CTRL+C to quit)
warning: The `extra-build-dependencies` option is experimental and may change without warning. Pass `--preview-features extra-build-dependencies` to disable this warning.
WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.
2026-03-26 17:20:32,331 - INFO - skyrl: JAX distributed initialized:
process_id=0 (4 total), local devices: 4, total devices: 16
Fetching 11 files: 100%|██████████| 11/11 [00:06<00:00, 1.83it/s]
Step 5: Run the training loop
Create a training script to fine-tune the Qwen3-8B model. The script implements a training loop that teaches the model English-to-Pig Latin translation: each step executes a forward-backward pass to compute the loss and updates the model's weights with the Adam optimizer, interacting with the SkyRL backend through the Tinker API.
train.py
import tinker
import numpy as np
from tinker import types

# Connect to the local server
service_client = tinker.ServiceClient(base_url="http://localhost:8000", api_key="tml-dummy")
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-8B")
tokenizer = training_client.get_tokenizer()

# Training examples
examples = [
    {"input": "banana split", "output": "anana-bay plit-say"},
    {"input": "quantum physics", "output": "uantum-qay ysics-phay"},
    {"input": "coding wizard", "output": "oding-cay izard-way"},
]

def process_example(example, tokenizer):
    prompt = f"English: {example['input']}\nPig Latin:"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
    tokens = prompt_tokens + completion_tokens
    # Zero weights mask the prompt so loss is computed only on the completion
    weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens)
    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=tokens[:-1]),
        loss_fn_inputs=dict(weights=weights[1:], target_tokens=tokens[1:]),
    )

processed = [process_example(ex, tokenizer) for ex in examples]

# Training loop
for _ in range(6):
    fwdbwd = training_client.forward_backward(processed, "cross_entropy").result()
    training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result()
    logprobs = np.concatenate([o['logprobs'].tolist() for o in fwdbwd.loss_fn_outputs])
    weights = np.concatenate([e.loss_fn_inputs['weights'].tolist() for e in processed])
    print(f"Loss: {-np.dot(logprobs, weights) / weights.sum():.4f}")
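The loss printed at each step is a weighted mean of per-token negative log-probabilities, where the weights mask out the prompt so only completion tokens contribute. A toy numpy example of that computation (made-up values, not real model output):

```python
import numpy as np

# Toy per-token log-probabilities for a 6-token sequence (illustrative values).
logprobs = np.array([-2.0, -1.5, -3.0, -0.5, -0.8, -1.2])
# First three tokens are prompt (weight 0), last three are completion (weight 1).
weights = np.array([0, 0, 0, 1, 1, 1])

# Same expression as the training loop: masked mean negative log-likelihood.
loss = -np.dot(logprobs, weights) / weights.sum()
print(f"Loss: {loss:.4f}")
```

Only the last three log-probabilities affect the result; the prompt tokens are zeroed out by the mask.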
Copy train.py to worker 0 and run it:
POD=$(kubectl get pod -l batch.kubernetes.io/job-completion-index=0,job-name=skyrl-tx-job-skyrl-workers-0 -o custom-columns=":metadata.name" --no-headers)
kubectl cp train.py $POD:/jax-ai-image/SkyRL/train.py
kubectl exec -ti $POD -- sh -c 'cd SkyRL && uv run train.py'
Verify results by viewing the training loss in the output:
Loss: 3.9209
Loss: 3.1629
Loss: 2.2993
Loss: 1.6949
Loss: 1.1634
Loss: 0.5750
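A quick programmatic check that the run converged is to confirm the reported loss decreases at every step, for example with the values above:

```python
# Loss values reported by the six training steps above.
losses = [3.9209, 3.1629, 2.2993, 1.6949, 1.1634, 0.5750]

# Each optimizer step should strictly reduce the loss on this tiny dataset.
assert all(a > b for a, b in zip(losses, losses[1:]))
print("loss decreased monotonically:", losses[0], "->", losses[-1])
```

On such a small memorization task the loss should fall steadily; a flat or rising curve would suggest a misconfigured learning rate or weight mask.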
