Unsupervised deep learning pipeline for single-cell phenotyping — 96% recall on rare cell populations, built as a hands-on introduction to JAX and CI/CD engineering.
BioVAE-Phenotyper was built with two parallel objectives. The first was scientific: can an unsupervised model identify distinct cell phenotypes from microscopy data without any labels? The second was engineering: use this project as a concrete introduction to JAX and CI/CD pipelines — two tools I wanted to learn properly, not just in a tutorial.
The result is a project that genuinely works biologically (96% recall on a rare cell class) while also serving as a demonstration that research code can be held to software engineering standards. Every design choice — Poetry for dependency management, GitHub Actions for automated testing, JAX for independent numerical validation — was deliberate and hands-on.
The dataset is BloodMNIST from MedMNIST v2: standardised 28×28 microscopy images of 8 blood cell types, a realistic benchmark for unsupervised cell phenotyping.
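Loading the benchmark is a one-liner with the `medmnist` package; below is a minimal sketch (the `preprocess` helper is illustrative, not part of the project's API) showing the typical uint8 → float scaling step:

```python
# Illustrative BloodMNIST preprocessing sketch — `preprocess` is a hypothetical helper.
import numpy as np

def preprocess(images: np.ndarray) -> np.ndarray:
    """Scale uint8 images of shape (N, 28, 28, 3) to float32 in [0, 1]."""
    return images.astype(np.float32) / 255.0

# With the `medmnist` package installed, loading would look like:
# from medmnist import BloodMNIST
# train = BloodMNIST(split="train", download=True)
# x = preprocess(train.imgs)
```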
The key result: 96% recall on Class 7 — a rare blood cell phenotype — using purely unsupervised latent feature extraction. No labels were provided during training. In the t-SNE projection, Class 7 forms a distinct, isolated cluster, suggesting this approach could transfer to identifying specific biological states in a high-content screen.
The t-SNE projection of the 64-dimensional latent space shows clearly separated cell phenotype clusters. The reconstruction panel confirms that the 3-layer architecture preserves critical morphological details — nucleus shape and texture — despite heavy compression.
t-SNE · 64-dim latent space · Class 7 isolated
Top: original cells · Bottom: VAE reconstructions
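The projection above can be reproduced with scikit-learn's `TSNE` applied to the encoder outputs; the sketch below uses random data as a stand-in for the (N, 64) latent means:

```python
# Sketch: projecting 64-dim VAE latents to 2-D with t-SNE.
# `latents` is a random placeholder for the encoder's mu outputs.
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
latents = rng.normal(size=(200, 64))  # stand-in for real encoder outputs

emb = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(latents)
print(emb.shape)  # (200, 2) — one 2-D point per cell, ready for scatter-plotting
```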
The model is a custom β-VAE with 3 convolutional layers and Batch Normalization. The β weighting on the KL term pushes the model to learn a more disentangled latent space — separating biological factors of variation rather than just compressing pixels. The tradeoff between reconstruction quality and latent structure is tuned for phenotype separability rather than visual fidelity.
```python
# beta-VAE loss — higher beta = more disentangled latent space
import torch
import torch.nn.functional as F

def loss(recon_x, x, mu, log_var, beta=4.0):
    recon = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon + beta * kld
```
One of the explicit goals of this project was to get hands-on with JAX. Rather than rewriting the full training loop (PyTorch is the right tool for that), I used JAX for what it does best: JIT-compiled numerical verification of key metrics after training.
Concretely, recall per class and reconstruction metrics are independently computed in JAX and cross-checked against the PyTorch outputs. This hybrid approach served two purposes: it caught a subtle numerical discrepancy in my early loss implementation, and it gave me practical experience with JAX's functional style and jax.jit compilation model — meaningfully different from PyTorch's imperative paradigm.
```python
# JAX — JIT-compiled recall metric for independent validation
import jax
import jax.numpy as jnp

@jax.jit
def compute_recall(preds, labels, target_class):
    tp = jnp.sum((preds == target_class) & (labels == target_class))
    fn = jnp.sum((preds != target_class) & (labels == target_class))
    return tp / (tp + fn + 1e-8)

# Cross-check: PyTorch result vs JAX result must match within 1e-4
assert abs(torch_recall - float(jax_recall)) < 1e-4
```
The second explicit learning goal was CI/CD. Before this project, my testing was ad-hoc. Here I set up a full GitHub Actions pipeline that runs on every push — enforcing that the codebase is always in a working state.
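A minimal workflow along those lines looks like the sketch below — an illustrative example, not the project's exact file:

```yaml
# .github/workflows/ci.yml — illustrative sketch of a Poetry + Pytest pipeline
name: CI
on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install Poetry and dependencies
        run: |
          pip install poetry
          poetry install
      - name: Run tests
        run: poetry run pytest
```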
Setting this up from scratch was instructive: writing the YAML workflow, understanding the Poetry environment setup in a CI context, and structuring Pytest tests around a model that depends on random seeds all required real problem-solving. The discipline it enforces — broken code never reaches main — is something I now consider non-negotiable for any serious project.
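The seed-handling point can be made concrete with a small sketch: pinning `torch.manual_seed` before model construction makes outputs reproducible across CI runs (`make_model` is a hypothetical helper, not the project's actual API):

```python
# Sketch of a seed-pinned Pytest check — `make_model` is illustrative.
import torch

def make_model():
    torch.manual_seed(42)  # pin weight initialisation for reproducibility
    return torch.nn.Linear(64, 8)

def test_forward_is_deterministic():
    model = make_model()
    torch.manual_seed(0)   # pin the input draw too
    x = torch.randn(4, 64)
    out = model(x)
    assert out.shape == (4, 8)
    # identical seeds => bit-identical weights, hence identical outputs
    assert torch.equal(out, make_model()(x))
```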