JAX
Transformable numerical computing at scale
JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.
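"Program transformation" means that JAX functions such as `jax.grad`, `jax.jit` and `jax.vmap` take an ordinary Python function and return a new function. Below is a minimal sketch of that idea. The linear model, the toy data and the helper names `loss` and `sq_err` are made up for illustration. Only the three `jax.*` transformations and `jax.numpy` come from the library.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model: mean squared error of a linear predictor.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# Compose transformations: differentiate w.r.t. the first argument (w),
# then JIT-compile the resulting gradient function via XLA.
grad_loss = jax.jit(jax.grad(loss))

# Hypothetical per-example error for a single input row xi.
def sq_err(w, xi, yi):
    return (jnp.dot(xi, w) - yi) ** 2

# Vectorize over the batch axis of xi and yi while sharing the same w.
per_example = jax.vmap(sq_err, in_axes=(None, 0, 0))

w = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([1.0, 0.0])

g = grad_loss(w, x, y)          # gradient of the mean loss: [0., 2.]
errs = per_example(w, x, y)     # per-example squared errors: [0., 4.]
```

Because each transformation returns a plain function, transformations compose freely. For example, `jax.jit(jax.vmap(jax.grad(...)))` would give compiled per-example gradients.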
Via:
- Tune Llama3 405B on AMD MI300x (our journey) - https://publish.obsidian.md/felafax/pages/Tune+Llama3+405B+on+AMD+MI300x+(our+journey)
- PyTorch is dead. Long live JAX. - https://neel04.github.io/my-website/blog/pytorch_rant/
- Using JAX to accelerate our research - https://deepmind.google/discover/blog/using-jax-to-accelerate-our-research/
Tags:
Programming, Language