How I trained a Stance-Aware Cross-Encoder that classifies Indonesian news headlines against claims, starting on a free Colab TPU and scaling out to Cloud TPU v5p with a single @kinetic.run() decorator
Introduction
Misinformation is one of the defining problems of the social-media era, and Indonesia has been hit particularly hard. Hoaks (the Indonesian shorthand for fake news) spread through WhatsApp groups and Twitter threads faster than any fact-checker can keep up. Most published research on automated fake-news detection focuses on English-language data, which leaves practitioners working with Bahasa Indonesia in a frustrating spot: the techniques exist, but the tooling and pre-trained models are scarce.
This article walks through building a real, working hoax detector for Indonesian news from scratch. The model takes two text inputs, a claim (the original assertion, often from social media) and a headline (a news article headline that mentions the same topic), and predicts whether the article supports the claim (for), refutes it (against), or merely observes it without taking a side (observing).
The architecture is a Stance-Aware Cross-Encoder: a Transformer encoder (token and positional embeddings followed by stacked multi-head self-attention blocks) for each input, plus a cross-attention layer that lets the claim and headline literally read each other before classification. It is built end-to-end with JAX and Flax and trained on TPU.
The deployment story has two halves:
- Free Colab TPU for prototyping: Colab's free tier offers a v5e-1 TPU runtime (subject to availability and usage limits), which is enough to train this model end-to-end in under an hour at zero cost.
- Cloud TPU v5p via Keras Kinetic for serious training — when you outgrow Colab’s runtime limits, Keras Kinetic lets you ship the same training function to a Cloud TPU pod with a single Python decorator. No Docker, no Kubernetes YAML, no SSH.
By the end, you’ll have:
- A reusable Indonesian tokenizer and dataset loader
- A 4-layer Transformer encoder with stance-aware cross-attention, written in Flax
- A JIT-compiled training loop with Optax and Orbax checkpointing
- A working predict.py that runs new claim-headline pairs through the trained model
- The exact same code, deployed to Cloud TPU with @kinetic.run()
Let’s get into it.
Why This Problem Is Hard (and Interesting)
Naive fake-news detectors look at one piece of text and try to classify it as “real” or “fake.” That’s both technically weak and ethically uncomfortable — a single text rarely carries enough signal, and the “true/false” framing assumes the model has access to ground truth it can’t possibly have.
The stance detection framing is much more honest. Given a claim and a related news article, the model doesn’t decide whether the claim is true; it decides whether this particular article supports, refutes, or merely observes the claim. That’s a question a model can actually answer, and it’s exactly the input a downstream fact-checker needs to make a final call.
Mathematically, the task is a 3-way classification over the interaction of two pieces of text. That word, interaction, is what makes the architecture interesting. You can’t just encode each side independently and concatenate. You need a layer that lets the claim attend to the headline and vice versa, so the model can pick up on subtle cues like negation (“Government denies…”), hedging (“alleged…”), or framing (“according to critics…”).
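To make the "claim reads the headline" idea concrete, here is a minimal NumPy sketch of single-head scaled dot-product cross-attention run in both directions. It is an illustration, not the article's Flax module: the function names, the random projection matrices (shared across both directions for brevity), and the toy shapes (4 claim tokens, 6 headline tokens, width 8) are all assumptions.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max for numerical stability before exponentiating.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(queries_from, keys_values_from, w_q, w_k, w_v):
    """Single-head scaled dot-product attention.

    Queries come from one sequence, keys and values from the other, so every
    query token becomes a weighted mix of the *other* sequence's tokens.
    Returns (output, attention_weights).
    """
    q = queries_from @ w_q          # (Lq, d)
    k = keys_values_from @ w_k      # (Lkv, d)
    v = keys_values_from @ w_v      # (Lkv, d)
    d = q.shape[-1]
    weights = softmax(q @ k.T / np.sqrt(d), axis=-1)  # (Lq, Lkv), rows sum to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
d_model = 8
claim = rng.normal(size=(4, d_model))     # hypothetical encoded claim tokens
headline = rng.normal(size=(6, d_model))  # hypothetical encoded headline tokens
w_q, w_k, w_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))

# Claim attends to headline, and headline attends to claim.
claim_ctx, claim_attn = cross_attention(claim, headline, w_q, w_k, w_v)
headline_ctx, headline_attn = cross_attention(headline, claim, w_q, w_k, w_v)

print(claim_ctx.shape, claim_attn.shape)        # (4, 8) (4, 6)
print(headline_ctx.shape, headline_attn.shape)  # (6, 8) (6, 4)
```

Each row of claim_attn says how much one claim token "looks at" each headline token, which is where a cue like a negating verb in the headline can influence the claim's representation.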
Why JAX, Flax, Keras Kinetic, and TPU?
- JAX gives me NumPy-style code with automatic differentiation, JIT compilation via XLA, and transparent acceleration on CPU/GPU/TPU.
- Flax sits on top of JAX and lets me write neural networks as nn.Module classes. The model stacks a lot of attention layers, and Flax keeps the parameter management clean.
- Optax for optimization (AdamW with linear warmup + cosine decay) and Orbax for checkpointing — both are part of the JAX ecosystem and JIT-friendly.
- Keras Kinetic is the deployment glue. One decorator turns a local Python function into a remote TPU job, with container caching, log streaming, and automatic GKE provisioning.
- TPU because the workload is dominated by attention matmuls — exactly what TPU systolic arrays are built for. Free on Colab (v5e-1), and Cloud TPU v5p when you need to scale up.
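The learning-rate schedule mentioned above (linear warmup, then cosine decay) is easy to reason about in plain Python. This is a sketch of the curve's shape under made-up settings (peak 3e-4, 500 warmup steps, 10,000 total steps), not the article's configuration. In the real training loop, Optax ships this as optax.warmup_cosine_decay_schedule, which you can pass to optax.adamw as its learning_rate.

```python
import math

def warmup_cosine(step: int, peak: float, warmup_steps: int, total_steps: int,
                  init: float = 0.0, end: float = 0.0) -> float:
    """Linear warmup from `init` to `peak`, then cosine decay down to `end`.

    `total_steps` includes the warmup phase, matching the convention of
    optax.warmup_cosine_decay_schedule's `decay_steps` argument.
    """
    if step < warmup_steps:
        return init + (peak - init) * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at `end` once the schedule is over
    return end + (peak - end) * 0.5 * (1.0 + math.cos(math.pi * progress))

# Hypothetical settings, printed at a few checkpoints along the schedule.
for s in (0, 250, 500, 5_250, 10_000):
    print(s, f"{warmup_cosine(s, 3e-4, 500, 10_000):.2e}")
```

Warmup keeps AdamW's early updates small while its moment estimates are still noisy; the cosine tail then anneals the rate smoothly instead of in discrete steps.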
TPU vs GPU for This Workload
A dual-input Transformer with cross-attention is one of the cleanest TPU workloads you can write. Here’s why, and where GPUs still hold their own.
Hardware design
- GPU (NVIDIA A100/H100): general-purpose parallel processor, thousands of CUDA cores, great for arbitrary parallel computation.
- TPU (v5e or v5p): domain-specific accelerator built around a large systolic array (MXU) optimized for dense matrix multiplications.
What dominates the compute in this model
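- Multi-head self-attention: softmax(Q Kᵀ / √d) V, where d is the per-head dimension. That means the Q, K, and V projections plus two more matmuls (Q Kᵀ and the product with V) in every layer, followed by an output projection.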
- Cross-attention between claim and headline: same shape, just with different inputs feeding Q vs K/V.
- Feedforward blocks: two Dense layers with GELU between them.
All of those are dense matmuls with predictable shapes, which is exactly the pattern the TPU systolic array is built to run close to peak FLOPs. The XLA compiler fuses the entire train_step into a handful of kernels, and after the first compile, every step reuses the same compiled program.
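A back-of-the-envelope FLOP count shows how lopsided this workload is. Only the 64-token sequence length comes from the article; the model width (256), head count (8), and feedforward width (1024) are assumptions for illustration. A dense (m×k)·(k×n) matmul is counted as 2·m·k·n FLOPs, and softmax is counted very roughly as 5 elementwise ops per attention score.

```python
def encoder_layer_flops(seq_len: int, d_model: int, n_heads: int, d_ff: int) -> dict:
    """Approximate forward-pass FLOPs for one Transformer encoder layer, batch of 1.

    LayerNorm, GELU, and residual adds are ignored; like softmax, they are
    elementwise and small next to the matmuls.
    """
    matmul = {
        "qkv_projections": 3 * 2 * seq_len * d_model * d_model,
        "scores_qk": 2 * seq_len * seq_len * d_model,       # summed over all heads
        "weighted_sum_v": 2 * seq_len * seq_len * d_model,  # summed over all heads
        "output_projection": 2 * seq_len * d_model * d_model,
        "feedforward": 2 * (2 * seq_len * d_model * d_ff),  # two Dense layers
    }
    # Rough softmax cost: max, subtract, exp, sum, divide per attention score.
    softmax = 5 * n_heads * seq_len * seq_len
    return {"matmul": sum(matmul.values()), "softmax": softmax}

# Hypothetical sizes; 64 tokens as in the article.
flops = encoder_layer_flops(seq_len=64, d_model=256, n_heads=8, d_ff=1024)
share = flops["matmul"] / (flops["matmul"] + flops["softmax"])
print(flops)
print(f"matmul share: {share:.2%}")
```

Under these assumptions well over 99% of the arithmetic lands on the MXU, which is why the shapes matter more than the elementwise ops.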
Where GPUs still win for stance detection / NLP
- You’re doing token-level decoding with KV-cache and irregular generation lengths (we’re not — we’re doing classification).
- You need a HuggingFace transformers model that’s only available as a PyTorch checkpoint (we’re training from scratch, so this doesn’t apply).
- You want to iterate in a notebook with data-dependent Python control flow that doesn’t JIT cleanly (Colab gives you both a TPU and a notebook, so you don’t have to choose).
Where TPUs win for stance detection / NLP
- Fixed-length sequences (we pad to 64 tokens) → predictable shapes → great XLA compilation.
- The whole train_step JITs into a single fused execution graph.
- pmap / shard_map make multi-chip training a one-liner if you want to scale up.
- Free on Colab, and Cloud TPU v5e is roughly $0.40/chip-hour on Spot.
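The "predictable shapes" point is worth making concrete: jax.jit compiles a separate XLA program for each distinct input shape it sees. This plain-Python sketch counts the distinct (batch, seq_len) shapes a jitted step would receive if you padded each batch only to its own longest sequence, versus padding everything to 64. The token lengths are randomly generated stand-ins, not statistics from the dataset, and batch_shapes is a hypothetical helper.

```python
import random

def batch_shapes(lengths, batch_size, pad_to=None):
    """Return the set of (batch, seq_len) shapes a training loop would produce.

    pad_to=None pads each batch to its own longest sequence (dynamic padding);
    an integer pads/truncates every sequence to exactly that length.
    """
    shapes = set()
    for i in range(0, len(lengths), batch_size):
        batch = lengths[i:i + batch_size]
        seq_len = pad_to if pad_to is not None else max(batch)
        shapes.add((len(batch), seq_len))
    return shapes

random.seed(0)
# Hypothetical token counts: 10,240 = 320 full batches of 32, so no ragged last batch.
lengths = [random.randint(5, 60) for _ in range(10_240)]

dynamic = batch_shapes(lengths, batch_size=32)
fixed = batch_shapes(lengths, batch_size=32, pad_to=64)
print(len(dynamic), "distinct shapes with dynamic padding")
print(len(fixed), "distinct shape with fixed padding to 64")
```

Each extra shape is another trace-and-compile. Fixed padding spends some compute on pad tokens in exchange for compiling once; if the final partial batch would introduce a second shape, dropping or padding it keeps the count at one.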
Rule of thumb
- Quick prototyping in a notebook with Bahasa Indonesia data → free Colab TPU. (This article.)
- Iterative R&D using HuggingFace PyTorch checkpoints → GPU.
- Production training with batchable, JAX-native workloads → Cloud TPU via Kinetic.
- You need to fine-tune a 7B+ Indonesian LLM → that’s a different article (and a different category — vLLM or Tunix).
Now let’s actually build it.
Project Architecture
The pipeline is straightforward:
datasetika.csv (Claim, Judul, Stance)
           │
           ▼
┌──────────────────────┐
│ IndonesianTokenizer  │  whitespace + punctuation, vocab from corpus
└──────────────────────┘
           │
           ▼
┌──────────────────────┐
│ FakeNewsDataset      │  stratified train/val/test split, JAX-ready arrays
└──────────────────────┘
           │
           ▼
┌────────────────────────────────────────┐
│   FakeNewsDetector (Flax nn.Module)    │
│   ┌──────────────┐  ┌──────────────┐   │
│   │ Token+Pos    │  │ Token+Pos    │   │
│   │ Embedding    │  │ Embedding    │   │
│   │ (Claim)      │  │ (Headline)   │   │
│   └──────┬───────┘  └──────┬───────┘   │
│          ▼                 ▼           │
│   ┌──────────────┐  ┌──────────────┐   │
│   │ Transformer  │  │ Transformer  │   │
│   │ × N layers   │  │ × N layers   │   │
│   └──────┬───────┘  └──────┬───────┘   │
│          └────────┬────────┘           │
│                   ▼                    │
│        ┌─────────────────────┐         │
│        │ Stance Cross-Encoder│         │
│        │ (cross-attention +  │         │
│        │  diff & product)    │         │
│        └──────────┬──────────┘         │
│                   ▼                    │
│      Dense → Softmax (3 classes)       │
└────────────────────────────────────────┘
                    │
                    ▼
        for / against / observing
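Everything from the embeddings to the optimizer update runs inside one jax.jit-compiled train_step; tokenization and the train/val/test split happen once, up front, on the host. Now let’s walk through each piece.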
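The "diff & product" part of the Stance Cross-Encoder box in the diagram combines the two sides before the Dense head. A widely used sentence-pair recipe (InferSent popularized it) is to pool each side to a vector and concatenate [u, v, |u − v|, u ⊙ v]. This NumPy sketch assumes masked mean pooling and that exact 4·d layout; whether the article's module pools this way or uses a signed difference is an assumption, and the random hidden states are placeholders.

```python
import numpy as np

def masked_mean(x: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Mean over the sequence axis, ignoring padding positions (mask == 0).

    x: (batch, seq_len, d); mask: (batch, seq_len) of 0/1.
    """
    m = mask[..., None].astype(x.dtype)
    return (x * m).sum(axis=1) / np.maximum(m.sum(axis=1), 1.0)

def fuse_pair(claim_h, claim_mask, head_h, head_mask) -> np.ndarray:
    """Pool each side, then build [u, v, |u - v|, u * v] -> (batch, 4 * d)."""
    u = masked_mean(claim_h, claim_mask)
    v = masked_mean(head_h, head_mask)
    return np.concatenate([u, v, np.abs(u - v), u * v], axis=-1)

rng = np.random.default_rng(0)
batch, seq_len, d = 2, 64, 16
claim_h = rng.normal(size=(batch, seq_len, d))  # hypothetical cross-attended claim states
head_h = rng.normal(size=(batch, seq_len, d))   # hypothetical headline states
claim_mask = np.zeros((batch, seq_len))
claim_mask[:, :10] = 1   # 10 real claim tokens, rest is padding
head_mask = np.zeros((batch, seq_len))
head_mask[:, :14] = 1    # 14 real headline tokens, rest is padding

features = fuse_pair(claim_h, claim_mask, head_h, head_mask)
print(features.shape)  # (2, 64)
```

The |u − v| and u ⊙ v terms hand the classifier explicit agreement and disagreement signals, so the final Dense → Softmax head only has to map a 4·d vector to three stance logits.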
Step 1 — Hardware Setup
There are two paths. Pick whichever fits your stage of the project.
Path A: Free Colab TPU (recommended for first run)
- Open colab.research.google.com and create a new notebook.
- Click Runtime → Change runtime type.
- Under Hardware accelerator, select v5e-1 TPU.
- Click Save.
- Verify in a cell:
import jax
print(jax.devices())
# Expected: [TpuDevice(id=0, ...)]
That’s it. You now have a free TPU v5e chip for the duration of your Colab session.
Path B: Cloud TPU via Keras Kinetic (when you outgrow Colab)
Colab is fantastic for prototyping, but it has runtime limits and sessions can be reclaimed under heavy load. When you’re ready to run multi-hour training jobs, switch to a real Cloud TPU. The traditional path means provisioning a TPU VM, SSHing in, installing dependencies, and uploading scripts; Kinetic skips all of that.
On your local laptop:
pip install keras-kinetic
gcloud auth application-default login
gcloud config set project YOUR_PROJECT_ID
kinetic up --accelerator v5p-8 --yes
The last command provisions a GKE Autopilot cluster with a TPU v5p-8 node pool. It takes a few minutes the first time; after that, you don’t touch infrastructure again until tear-down.
I’ll show the actual @kinetic.run() deployment in Step 6. For now, let’s build the model.
Step 2 — Indonesian Tokenizer and Dataset
Bahasa Indonesia is morphologically less complex than, say, Turkish or Finnish, so a whitespace + punctuation tokenizer with a learned vocabulary works surprisingly well as a baseline. (For production, swap in IndoBERT — I’ll show how at the end of this section.)
The tokenizer reserves four special tokens, builds a frequency-ranked vocabulary from the training corpus, and emits (token_ids, attention_mask) pairs at a fixed length. Standard stuff, but with one subtlety: we tokenize both the Claim and Judul (headline) columns into a shared vocabulary so the embedding layer can pick up cross-input correlations.
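Before the full module, here is the fixed-length encode step in isolation, as a minimal sketch. It assumes a <CLS> tokens <SEP> layout with right-padding, reuses the special-token IDs from SPECIAL_TOKENS, and applies a simplified version of the cleaning regex. The encode function and the toy vocabulary are hypothetical, so the article's own method may differ in details such as where truncation happens.

```python
import re
from typing import Dict, List, Tuple

PAD, UNK, CLS, SEP = 0, 1, 2, 3  # same IDs as SPECIAL_TOKENS in the module

def encode(text: str, word2id: Dict[str, int], max_len: int = 64) -> Tuple[List[int], List[int]]:
    """Map text to (token_ids, attention_mask), both exactly max_len long."""
    text = re.sub(r"[^\w\s]", " ", text.lower())  # punctuation -> spaces
    ids = [word2id.get(tok, UNK) for tok in text.split()]
    ids = [CLS] + ids[: max_len - 2] + [SEP]      # truncate so CLS/SEP always fit
    mask = [1] * len(ids)                          # 1 = real token, 0 = padding
    pad = max_len - len(ids)
    return ids + [PAD] * pad, mask + [0] * pad

# Toy vocabulary; in practice it comes from build_vocab over Claim + Judul.
vocab = {"pemerintah": 4, "membantah": 5, "kenaikan": 6, "harga": 7, "bbm": 8}
# "Government denies fuel price increase"
ids, mask = encode("Pemerintah membantah kenaikan harga BBM!", vocab, max_len=10)
print(ids)   # [2, 4, 5, 6, 7, 8, 3, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
```

Claim and headline are encoded separately this way and fed to their own embedding branches, so both arrive at the model as fixed (batch, 64) integer arrays.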
"""
Data Preprocessing for Indonesian Fake News Detection
Tokenizes Claim + Judul columns, encodes Stance labels.
Compatible with JAX/Flax training pipeline.
"""
import re
import numpy as np
import pandas as pd
from collections import Counter
from typing import List, Tuple, Dict
from sklearn.model_selection import train_test_split

LABEL_MAP = {"for": 0, "against": 1, "observing": 2}
ID_TO_LABEL = {v: k for k, v in LABEL_MAP.items()}


class IndonesianTokenizer:
    """
    Lightweight whitespace + punctuation tokenizer for Indonesian text.
    For production, swap with:
        from transformers import AutoTokenizer
        tok = AutoTokenizer.from_pretrained("indobenchmark/indobert-base-p1")
    """

    SPECIAL_TOKENS = {"<PAD>": 0, "<UNK>": 1, "<CLS>": 2, "<SEP>": 3}

    def __init__(self, vocab_size: int = 30_000, min_freq: int = 2):
        self.vocab_size = vocab_size
        self.min_freq = min_freq
        self.word2id: Dict[str, int] = dict(self.SPECIAL_TOKENS)
        self.id2word: Dict[int, str] = {v: k for k, v in self.word2id.items()}

    @staticmethod
    def _clean(text: str) -> str:
        text = text.lower()
        text = re.sub(r"<[^>]+>", " ", text)  # strip HTML
        text = re.sub(r"[^\w\s]", " ", text, flags=re.UNICODE)  # keep alphanum
        text = re.sub(r"\s+", " ", text).strip()
        return text

    @staticmethod
    def tokenize(text: str) -> List[str]:
        return IndonesianTokenizer._clean(text).split()

    def build_vocab(self, texts: List[str]) -> None:
        counter: Counter = Counter()
        for t in texts:
            counter.update(self.tokenize