This repository implements a modular, variational-inference-based framework for reconstructing continuous, branching developmental trajectories from single-cell RNA-seq data, inspired by the original Continuous-State Hidden Markov Model (CSHMM) proposed by Lin and Bar-Joseph (2019).
Traditional pseudotime inference methods either reduce data to low-dimensional embeddings and order cells deterministically, or use probabilistic state models with discrete assignments. The CSHMM framework bridges this gap by modeling cell states continuously along branching paths, capturing both noise and expression dynamics.
Our implementation expands on CSHMM by introducing a modular PyTorch-based variational inference framework with flexible training configurations including:
- Minibatching
- Curriculum learning (e.g., emission parameter freezing)
- Lagging variation training (alternate inference/generative updates)
- Pluggable trajectory and posterior models
Original Method:
Lin, C. & Bar-Joseph, Z. (2019). Continuous-state HMMs for modeling time-series single-cell RNA-Seq data. Bioinformatics, 35(22), 4707–4715.
DOI: 10.1093/bioinformatics/btz296
- Install dependencies:
pip install torch scanpy anndata numpy
-
Prepare your .h5ad dataset and initialize a trajectory graph (e.g., via Leiden + PAGA).
-
Run the trajectory_test.ipynb notebook to test preprocessing, trajectory construction, and model training.
Training is customizable via:
-
minibatch: Whether to train on mini-batches or full dataset
-
batch_size: Mini-batch size
-
freeze_emission_epochs: Number of epochs to freeze emission parameters
-
lagging: Whether to alternate E/M steps
-
inference_steps, generative_steps: Number of updates per epoch for each phase
-
See training/loop.py for implementation details.
This work builds on the CSHMM model introduced in: Lin, C. & Bar-Joseph, Z. Continuous-state HMMs for modeling time-series single-cell RNA-Seq data Bioinformatics (2019) DOI: 10.1093/bioinformatics/btz296