# Brain-JEPA (safetensors)

Pretrained weights for Brain-JEPA (NeurIPS 2024, Spotlight), converted to the safetensors format for use with brainjepa-rs.
## Model description

Brain-JEPA is a brain dynamics foundation model that maps parcellated fMRI time series (450 ROIs x T time points) to latent representations using a Vision Transformer with:
- Brain gradient positioning for spatial (ROI) embeddings
- Temporal patch embedding via 1D convolution along time
- JEPA architecture (Joint Embedding Predictive Architecture)
The encoder is a 12-layer ViT-Base (768-dim, 12 heads, ~86M params) pretrained on UK Biobank resting-state fMRI for 300 epochs.
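The token count follows directly from these dimensions: a temporal patch size of 16 over 160 time points gives 10 patches per ROI, and 450 ROIs x 10 patches = 4500 tokens, matching the output shape reported below. A quick sanity check in plain Python:

```python
# Patch-count arithmetic for Brain-JEPA's default input size.
n_rois = 450        # parcellated ROIs
n_timepoints = 160  # fMRI time points per sample
patch_size = 16     # temporal patch length (1D conv kernel/stride)
embed_dim = 768     # ViT-Base embedding dimension

temporal_patches = n_timepoints // patch_size  # 10 patches per ROI
total_tokens = n_rois * temporal_patches       # 4500 patch tokens

print(total_tokens)               # 4500
print((total_tokens, embed_dim))  # (4500, 768) -- the encoder output shape
```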
## Files

| File | Description | Shape info |
|---|---|---|
| `brainjepa.safetensors` | All weights (encoder + predictor + target_encoder) | 384 tensors, ~709 MB |
| `gradient_mapping_450.csv` | Brain gradient coordinates for positional embeddings | 450 rows x 30 columns |
## Weight key structure

Keys are prefixed by component (`encoder.`, `predictor.`, `target_encoder.`):

```
encoder.patch_embed.proj.weight      [768, 1, 1, 16]
encoder.blocks.{i}.norm1.weight      [768]
encoder.blocks.{i}.attn.qkv.weight   [2304, 768]
encoder.blocks.{i}.attn.proj.weight  [768, 768]
encoder.blocks.{i}.mlp.fc1.weight    [3072, 768]
encoder.blocks.{i}.mlp.fc2.weight    [768, 3072]
encoder.norm.weight                  [768]
...
```

For inference, use the `target_encoder.*` keys (the EMA-smoothed weights from pretraining).
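One way to see the component split is to group key names by the segment before the first dot. A minimal sketch on a few illustrative key names (the real file holds 384 tensors):

```python
from collections import Counter

# Illustrative subset of checkpoint key names.
keys = [
    "encoder.patch_embed.proj.weight",
    "encoder.blocks.0.attn.qkv.weight",
    "predictor.blocks.0.attn.qkv.weight",
    "target_encoder.patch_embed.proj.weight",
    "target_encoder.norm.weight",
]

# Count tensors per component prefix.
by_component = Counter(k.split(".", 1)[0] for k in keys)
print(dict(by_component))
# {'encoder': 2, 'predictor': 1, 'target_encoder': 2}
```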
## Usage with brainjepa-rs (Rust)

```bash
# Install
git clone https://github.com/eugenehp/brainjepa-rs
cd brainjepa-rs

# Download weights from this repo, then place brainjepa.safetensors
# and gradient_mapping_450.csv in data/

# Run inference (CPU)
cargo run --release --bin infer -- \
    --weights data/brainjepa.safetensors \
    --gradient data/gradient_mapping_450.csv \
    --input data/fmri_sample.safetensors

# Run inference (GPU, Metal/Vulkan)
cargo run --release --no-default-features --features wgpu --bin infer -- \
    --weights data/brainjepa.safetensors \
    --gradient data/gradient_mapping_450.csv \
    --input data/fmri_sample.safetensors
```
### Rust library

```rust
use brainjepa_rs::{BrainJepaEncoder, ModelConfig, DataConfig};

let (encoder, _) = BrainJepaEncoder::<B>::from_weights(
    "data/brainjepa.safetensors",
    "data/gradient_mapping_450.csv",
    &ModelConfig::default(),
    &DataConfig::default(),
    &device,
)?;

let result = encoder.encode_safetensors("data/fmri.safetensors")?;
// result.embeddings: [4500, 768] float32
```
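For downstream tasks, the [4500, 768] patch embeddings are typically pooled into a single vector per scan; mean pooling is one common, simple choice. This pooling step is an assumption for illustration, not part of the library API:

```python
import numpy as np

# Placeholder for encoder output: 4500 patch tokens x 768 dims.
embeddings = np.zeros((4500, 768), dtype=np.float32)

# Mean-pool across tokens to get one 768-dim representation per scan,
# e.g. as input to a linear probe or classifier head.
scan_repr = embeddings.mean(axis=0)
print(scan_repr.shape)  # (768,)
```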
## Usage with original Python code

These weights were converted from the original PyTorch checkpoint. To use them with the original code:

```python
import torch
from safetensors.torch import load_file

tensors = load_file("brainjepa.safetensors")

# Filter for target_encoder weights and strip the prefix:
state_dict = {
    k.removeprefix("target_encoder."): v
    for k, v in tensors.items()
    if k.startswith("target_encoder.")
}
model.load_state_dict(state_dict)
```
## Conversion

Weights were converted from the original PyTorch checkpoint using:

```bash
python scripts/convert_weights.py \
    --input jepa-ep300.pth.tar \
    --output brainjepa.safetensors
```

The conversion script strips the `module.` prefix from DDP-wrapped state dicts, converts all tensors to float32, and saves them in safetensors format.
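The key-renaming step can be sketched on plain dicts (illustrative only; the actual script in `scripts/convert_weights.py` also loads the `.pth.tar` with torch and writes the result with `safetensors.torch.save_file`):

```python
def rename_keys(component: str, state: dict) -> dict:
    """Strip the DDP 'module.' prefix and namespace keys by component."""
    return {
        f"{component}." + k.removeprefix("module."): v
        for k, v in state.items()
    }

# Toy state dict standing in for a DDP-wrapped checkpoint component.
raw = {"module.norm.weight": [0.0] * 768, "module.norm.bias": [0.0] * 768}
renamed = rename_keys("target_encoder", raw)
print(sorted(renamed))
# ['target_encoder.norm.bias', 'target_encoder.norm.weight']
```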
## Benchmark

Tested on a Mac Mini M4 Pro (14 cores, 64 GB). Input: [1, 1, 450, 160] (single sample, ViT-Base, 86M params). Best-of-3 encode time.
| Backend | Encode | vs PyTorch CPU |
|---|---|---|
| Rust - NdArray + Rayon (CPU) | 28,778 ms | 0.06x |
| Rust - NdArray + Accelerate (CPU) | 21,092 ms | 0.08x |
| Python - PyTorch (CPU) | 1,782 ms | 1.0x |
| Python - PyTorch MPS (GPU) | 581 ms | 3.1x |
| Rust - wgpu f32 / Metal (GPU) | 83 ms | 21.5x |
| Rust - wgpu f16 / Metal (GPU) | 85 ms | 21.0x |
The Rust wgpu GPU backends are ~7x faster than PyTorch MPS and ~21x faster than PyTorch CPU.
## Architecture details
| Parameter | Value |
|---|---|
| Model | ViT-Base |
| Embedding dim | 768 |
| Encoder depth | 12 layers |
| Predictor depth | 6 layers |
| Attention heads | 12 |
| Head dim | 64 |
| MLP ratio | 4x (hidden=3072) |
| Patch size | 16 (temporal) |
| Input size | 450 ROIs x 160 time points |
| Output | 4500 patches x 768 dims |
| Normalization | LayerNorm (eps=1e-6) |
| Activation | GELU |
| Pretraining | 300 epochs on UK Biobank |
| Loss | Smooth L1 (JEPA representation matching) |
| Optimizer | AdamW (lr=1e-3, warmup=40 epochs, cosine decay) |
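The ~86M parameter figure can be roughly checked from the table: each block carries a 768-to-2304 QKV projection, a 768-to-768 output projection, and a 4x MLP. A back-of-the-envelope count (excluding positional/gradient embeddings, which the table does not size):

```python
# Rough parameter count for the 12-layer ViT-Base encoder in the table.
d = 768        # embedding dim
mlp = 4 * d    # MLP hidden size (3072)
depth = 12
patch = 16     # temporal patch size

attn = d * 3 * d + 3 * d + d * d + d  # qkv + output proj, with biases
ffn = d * mlp + mlp + mlp * d + d     # fc1 + fc2, with biases
norms = 2 * 2 * d                     # two LayerNorms (weight + bias)
per_block = attn + ffn + norms

patch_embed = patch * d + d           # 1D conv [768, 1, 1, 16] + bias
final_norm = 2 * d

total = depth * per_block + patch_embed + final_norm
print(f"{total / 1e6:.1f}M")  # 85.1M, consistent with the ~86M quoted
```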
## Source

Original paper and code:

Zijian Dong, Ruilin Li, Yilei Wu, et al. "Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking." NeurIPS 2024 (Spotlight). arXiv:2409.19407.
- Paper: arxiv.org/abs/2409.19407
- Original code: github.com/hzlab/Brain-JEPA
- Rust inference: github.com/eugenehp/brainjepa-rs
