Monkfish: Distributed latent video model training on TPUs (and other stuff maybe)

This is the training code for a two-stage autoregressive video model.

TODO:

  • Chunked scatter/gather/init functions
  • Parallel model save/load
  • Dtype conversions at scatter/gather/init functions
  • Distributed data loading
  • Distributed model training
  • Multi-platform file backend via PyFilesystem2
  • GPU Support
  • SLURM Support
  • Kubernetes Support
  • Text-conditional diffusion transformer
  • (5/6)-D parallelism
    • FSDP
    • Ring attention
    • Pipeline parallelism
    • Async swarm
  • Llama 3 support
  • Sophisticated logging (Logfire/SQL database)

References For Developers

Parameter scaling:

JAX sharding:

Data loader design: