project_02 · deep learning · cell biology · software engineering

BioVAE-Phenotyper

Unsupervised deep learning pipeline for single-cell phenotyping — 96% recall on a rare cell population, built as a hands-on introduction to JAX and CI/CD engineering.

Status: ✓ Complete
Dataset: MedMNIST v2 — BloodMNIST
Stack: PyTorch · JAX · Poetry · GitHub Actions

Context & Learning Goals

BioVAE-Phenotyper was built with two parallel objectives. The first was scientific: can an unsupervised model identify distinct cell phenotypes from microscopy data without any labels? The second was engineering: use this project as a concrete introduction to JAX and CI/CD pipelines — two tools I wanted to learn properly, not just in a tutorial.

The result is a project that genuinely works biologically (96% recall on a rare cell class) while also serving as a demonstration that research code can be held to software engineering standards. Every design choice — Poetry for dependency management, GitHub Actions for automated testing, JAX for independent numerical validation — was deliberate and hands-on.

The dataset is BloodMNIST from MedMNIST v2: standardised 28×28 microscopy images of 8 blood cell types, a realistic benchmark for unsupervised cell phenotyping.

Results

- 96% recall on Class 7 (rare phenotype) · fully unsupervised
- 64 latent-space dimensions
- 50 training epochs
- CI passing · GitHub Actions

The key result: 96% recall on Class 7 — a rare blood cell phenotype — using purely unsupervised latent feature extraction. No labels were provided during training. In the t-SNE projection, Class 7 forms a distinct, isolated cluster, demonstrating the model's ability to identify specific biological states in a high-content screen.
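The write-up does not spell out how unlabelled latent features are turned into a per-class recall figure. One common evaluation recipe — sketched below as an assumption, not the project's actual code — is to cluster the latent vectors and map each cluster to its majority true label, using the labels for evaluation only, never for training. The function name `unsupervised_recall` is hypothetical.

```python
# Hypothetical sketch: evaluating recall on a rare class from purely
# unsupervised latent features. Labels enter only at evaluation time.
import numpy as np
from sklearn.cluster import KMeans

def unsupervised_recall(latents, labels, target_class, n_clusters=8, seed=0):
    km = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed)
    cluster_ids = km.fit_predict(latents)
    # Majority-vote mapping: each cluster predicts its most common true label
    preds = np.empty_like(labels)
    for c in range(n_clusters):
        mask = cluster_ids == c
        if mask.any():
            preds[mask] = np.bincount(labels[mask]).argmax()
    tp = np.sum((preds == target_class) & (labels == target_class))
    fn = np.sum((preds != target_class) & (labels == target_class))
    return tp / (tp + fn + 1e-8)
```

If the clusters in latent space are as well separated as the t-SNE projection suggests, this mapping recovers the rare class almost perfectly.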

Latent Space & Reconstructions

The t-SNE projection of the 64-dimensional latent space shows clearly separated cell phenotype clusters. The reconstruction panel confirms that the 3-layer architecture preserves critical morphological details — nucleus shape and texture — despite heavy compression.
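The projection step itself is a single call to scikit-learn's `TSNE` (both t-SNE and scikit-learn appear in the project's stack; the exact parameters below are assumptions):

```python
# Sketch: projecting 64-dim latent vectors to 2D with scikit-learn's t-SNE.
# perplexity and init are illustrative defaults, not the project's settings.
import numpy as np
from sklearn.manifold import TSNE

def project_latents(latents, seed=0):
    tsne = TSNE(n_components=2, perplexity=30, init="pca", random_state=seed)
    return tsne.fit_transform(latents)  # shape: (n_samples, 2)
```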

Figure: t-SNE projection of the 64-dimensional latent space · Class 7 isolated
Figure: VAE reconstructions · top: original cells · bottom: reconstructions

Architecture — β-VAE

The model is a custom β-VAE with 3 convolutional layers and Batch Normalization. The β weighting on the KL term pushes the model to learn a more disentangled latent space — separating biological factors of variation rather than just compressing pixels. The tradeoff between reconstruction quality and latent structure is tuned for phenotype separability rather than visual fidelity.

import torch
import torch.nn.functional as F

# beta-VAE loss: higher beta weights the KL term more heavily,
# encouraging a more disentangled latent space
def loss(recon_x, x, mu, log_var, beta=4.0):
    recon = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kld   = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon + beta * kld
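The `mu` and `log_var` arguments above come from the encoder; sampling a latent `z` from them while keeping gradients flowing uses the standard VAE reparameterization trick. A minimal sketch (the function name is illustrative):

```python
import torch

def reparameterize(mu, log_var):
    # z = mu + sigma * eps, with eps ~ N(0, I); expressing the sample this
    # way keeps it differentiable with respect to mu and log_var
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std
```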

Introducing JAX — Hybrid Validation

One of the explicit goals of this project was to get hands-on with JAX. Rather than rewriting the full training loop (PyTorch is the right tool for that), I used JAX for what it does best: JIT-compiled numerical verification of key metrics after training.

Concretely, recall per class and reconstruction metrics are independently computed in JAX and cross-checked against the PyTorch outputs. This hybrid approach served two purposes: it caught a subtle numerical discrepancy in my early loss implementation, and it gave me practical experience with JAX's functional style and jax.jit compilation model — meaningfully different from PyTorch's imperative paradigm.

# JAX — JIT-compiled recall metric for independent validation
import jax
import jax.numpy as jnp

@jax.jit
def compute_recall(preds, labels, target_class):
    tp = jnp.sum((preds == target_class) & (labels == target_class))
    fn = jnp.sum((preds != target_class) & (labels == target_class))
    return tp / (tp + fn + 1e-8)

# Cross-check: PyTorch result vs JAX result must match within 1e-4
assert abs(torch_recall - float(jax_recall)) < 1e-4

Introducing CI/CD — GitHub Actions

The second explicit learning goal was CI/CD. Before this project, my testing was ad-hoc. Here I set up a full GitHub Actions pipeline that runs on every push — enforcing that the codebase is always in a working state.

Setting this up from scratch was instructive: writing the YAML workflow, understanding the Poetry environment setup in a CI context, and structuring Pytest tests around a model that depends on random seeds all required real problem-solving. The discipline it enforces — broken code never reaches main — is something I now consider non-negotiable for any serious project.

git push → GitHub Actions trigger on push and pull_request
Poetry install → locked environment, fully reproducible across machines
Pytest → unit tests: forward pass · loss shape · recall metric
Badge → README always green, or fix it before merging
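The steps above can be sketched as a minimal workflow file. This is an illustrative reconstruction, not the project's actual YAML; the file path, action versions, and step names are assumptions.

```yaml
# .github/workflows/ci.yml — minimal sketch of the pipeline described above
name: CI
on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.12"
      - name: Install Poetry
        run: pipx install poetry
      - name: Install dependencies (locked)
        run: poetry install
      - name: Run tests
        run: poetry run pytest
```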

Pipeline

01 · BloodMNIST · MedMNIST v2 auto-download
02 · β-VAE Training · 50 epochs · PyTorch
03 · JAX Validation · JIT metrics · cross-check
04 · t-SNE + Probing · 64-dim → 2D clusters
05 · 96% Recall · Class 7, unsupervised

Tech Stack

PyTorch JAX Python 3.12 GitHub Actions Poetry Pytest scikit-learn t-SNE MedMNIST v2 Seaborn