JAX

Transformable numerical computing at scale

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.
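As a rough illustration of "program transformation", here is a minimal sketch assuming JAX is installed (e.g. `pip install jax`). It applies three of JAX's core transformations (`jax.grad`, `jax.jit`, `jax.vmap`) to an ordinary Python function. The `loss` function and the toy data are hypothetical, made up only for this example.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function: mean squared error of a linear model.
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# grad: transforms `loss` into a function returning d(loss)/d(w) (first argument).
grad_loss = jax.grad(loss)

# jit: compiles the gradient function with XLA for the default backend (CPU, GPU or TPU).
fast_grad = jax.jit(grad_loss)

# vmap: maps `loss` over the leading axis of x and y, keeping w shared (in_axes=None).
per_example_loss = jax.vmap(loss, in_axes=(None, 0, 0))

# Toy data (illustrative only).
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)

g = fast_grad(w, x, y)                 # gradient of the batch loss, shape (2,)
per_ex = per_example_loss(w, x, y)     # one loss value per example, shape (3,)
print(g, per_ex)
```

The transformations compose freely (for example `jax.jit(jax.vmap(jax.grad(...)))`), which is the main design idea behind the library: you write plain array code once and let JAX derive differentiated, vectorized, or compiled versions of it.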

https://jax.readthedocs.io/

Via:

  1. Tune Llama3 405B on AMD MI300x (our journey) - https://publish.obsidian.md/felafax/pages/Tune+Llama3+405B+on+AMD+MI300x+(our+journey)
  2. PyTorch is dead. Long live JAX. - https://neel04.github.io/my-website/blog/pytorch_rant/
  3. Using JAX to accelerate our research - https://deepmind.google/discover/blog/using-jax-to-accelerate-our-research/


AJAX