
Running on TPUs

Currently, we only have instructions for running tx on a single TPU VM. Multi-node instructions will be added later.

Setting up the TPU VM

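First, create the TPU VM. The command below requests a v6e-8 slice (8 chips) as a Spot VM, which is cheaper but can be preempted: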

gcloud compute tpus tpu-vm create <TPU_NAME> \
  --project=<PROJECT> \
  --zone=<ZONE> \
  --accelerator-type=v6e-8 \
  --version=v2-alpha-tpuv6e \
  --scopes=https://www.googleapis.com/auth/cloud-platform.read-only,https://www.googleapis.com/auth/devstorage.read_write \
  --network=<NETWORK> \
  --subnetwork=<SUBNETWORK> \
  --spot

Once the VM is running, you can SSH into it with

gcloud compute tpus tpu-vm ssh <TPU_NAME> --project=<PROJECT> --zone=<ZONE>

Setting up tx

Once you are logged into the VM, install uv and clone the tx repository with

curl -LsSf https://astral.sh/uv/install.sh | sh
git clone https://github.com/tx-project/tx
cd tx
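Before starting training, you can optionally confirm that JAX sees all of the TPU chips. This is a minimal sketch that is not part of tx: the file name `check_devices.py` and the helper `device_summary` are hypothetical, and it assumes JAX is installed through tx's `tpu` extra, so you would run it from the `tx` directory with `uv run --extra tpu python check_devices.py`. On a v6e-8 VM we would expect it to report 8 TPU devices. Run it before training starts, since the TPU chips are generally claimed by one process at a time.

```python
# check_devices.py (hypothetical helper, not shipped with tx)
import collections

import jax


def device_summary() -> dict[str, int]:
    """Count the devices JAX can see, grouped by platform (e.g. "tpu", "cpu")."""
    return dict(collections.Counter(d.platform for d in jax.devices()))


if __name__ == "__main__":
    # On a v6e-8 TPU VM we would expect something like {'tpu': 8}.
    print(device_summary())
    print("process count:", jax.process_count())
```

If the output shows only `cpu` devices, JAX did not pick up the TPU runtime, and training would not run on the TPU either.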

Starting the training

Next, download the Qwen3-4B model weights, which are used as the initial checkpoint, with

uv run --with huggingface_hub hf download Qwen/Qwen3-4B --local-dir /tmp/qwen3
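If a download is interrupted, training will fail later when it loads the checkpoint. A quick sanity check like the one below can catch that early. It is a sketch under the assumption that the checkpoint directory follows the usual Hugging Face layout (a `config.json` plus one or more `*.safetensors` weight files); `check_checkpoint_dir` is a hypothetical helper, not part of tx or `huggingface_hub`.

```python
import sys
from pathlib import Path


def check_checkpoint_dir(path: str) -> list[str]:
    """Return a list of problems found in a downloaded checkpoint directory.

    Assumes the standard Hugging Face layout: config.json plus *.safetensors files.
    An empty list means the basic files are present (it does not verify contents).
    """
    root = Path(path)
    if not root.is_dir():
        return [f"{root} does not exist or is not a directory"]
    problems = []
    if not (root / "config.json").is_file():
        problems.append("missing config.json")
    if not any(root.glob("*.safetensors")):
        problems.append("no *.safetensors weight files found")
    return problems


if __name__ == "__main__":
    # Defaults to the --local-dir used in the download command above.
    target = sys.argv[1] if len(sys.argv) > 1 else "/tmp/qwen3"
    issues = check_checkpoint_dir(target)
    print("checkpoint looks complete" if not issues else "\n".join(issues))
```

This only checks that the expected files exist; it does not validate file sizes or checksums.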

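You can then start training with the command below. It fine-tunes Qwen3-4B on the `train_sft` split of UltraChat 200k, starting from the weights in /tmp/qwen3, writing outputs to /tmp/ultrachat, and sharding the model across all 8 chips (`--tp-size 8` sets the tensor-parallel size):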

uv run --extra tpu --with jinja2 tx train \
  --model Qwen/Qwen3-4B \
  --dataset HuggingFaceH4/ultrachat_200k \
  --loader tx.loaders.chat \
  --split train_sft \
  --output-dir /tmp/ultrachat \
  --batch-size 8 \
  --load-checkpoint-path /tmp/qwen3 \
  --tp-size 8
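To build intuition for what a tensor-parallel size of 8 means, the JAX sketch below splits a weight matrix column-wise across an 8-device mesh axis named `tp`, so each device holds one eighth of it. This is a generic illustration of tensor parallelism, not tx's sharding code: which weights tx splits, and along which dimension, is not covered here, and the (1024, 4096) weight is hypothetical. So that it runs on an ordinary machine, it fakes 8 CPU devices with the XLA flag `--xla_force_host_platform_device_count`, which must be set before JAX is imported.

```python
import os

# Assumption for local experimentation: expose 8 virtual CPU devices.
# This flag only affects the host (CPU) platform and must be set before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

TP_SIZE = 8  # mirrors --tp-size 8 in the training command

# A 1-D device mesh with a single axis called "tp".
devices = np.array(jax.devices("cpu")[:TP_SIZE])
mesh = Mesh(devices, axis_names=("tp",))

# Hypothetical projection weight of shape (hidden, intermediate).
w = np.zeros((1024, 4096), dtype=np.float32)

# Keep dim 0 whole and split dim 1 over "tp": each device gets 4096 / 8 = 512 columns.
w_sharded = jax.device_put(w, NamedSharding(mesh, P(None, "tp")))

shard_shapes = [s.data.shape for s in w_sharded.addressable_shards]
print(shard_shapes[0], len(shard_shapes))
```

Running it prints `(1024, 512) 8`: every device stores only its slice of the matrix, which is what lets a model too large for one chip's memory be spread across the whole v6e-8 slice.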

Note that the first training steps are slow because the JIT compiler has to compile kernels for each new input shape it encounters. Once the common shapes have been compiled, training speeds up.
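The toy JAX example below illustrates why this happens. It is not tx code, and `toy_step` is hypothetical: a `jax.jit`-compiled function is traced and compiled once per distinct input shape, and later calls with an already-seen shape reuse the cached compilation. The shapes `(8, seq_len)` loosely mimic a batch of 8 with varying sequence lengths; whether tx pads sequences to a fixed set of lengths is not covered here.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def toy_step(x):
    global trace_count
    # Python side effects run only while JAX traces the function (just before
    # compiling it), so this counter tracks how many distinct shapes triggered a compile.
    trace_count += 1
    return jnp.sum(x * 2.0)


# Five calls, but only three distinct shapes: 128, 256 and 512 columns.
for seq_len in [128, 128, 256, 128, 512]:
    toy_step(jnp.ones((8, seq_len)))

print(trace_count)  # 3
```

The repeated `128` calls hit the cache, which is why the slowdown fades once the shapes that occur in the data have all been seen.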

See the CLI reference for the full set of tx options.

You can monitor TPU usage (for example, from a second SSH session while training is running) with

uv run --with libtpu --with git+https://github.com/google/cloud-accelerator-diagnostics/#subdirectory=tpu_info tpu-info