# JAX
diffrax
Diffrax is a JAX-based library offering numerical solvers for ordinary, stochastic, and controlled differential equations. It supports GPU acceleration and automatic differentiation, features vmappable operations and PyTree-valued states, and is efficient for neural differential equations. The unified framework suits both academic and practical computational tasks.
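A minimal sketch of solving the ODE dy/dt = -y, following the pattern of Diffrax's documented quick-start ('Dopri5' is just one of the solvers it provides):

```python
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Dopri5

def f(t, y, args):
    return -y  # vector field for dy/dt = -y

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2.0, 3.0])
# Integrate from t=0 to t=1 with an initial step size of 0.1.
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)
print(solution.ys)  # solution values at the output times
```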
tf2jax
TF2JAX converts TensorFlow functions into JAX-compatible equivalents, enabling the use of JAX features like JIT compilation and autograd. The library helps integrate and optimize TensorFlow models within JAX codebases and supports various serialization formats and custom gradients. As the API is experimental, it may be unstable and warrants thorough testing. Community contributions are encouraged to broaden operation support.
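A hedged sketch of the conversion flow: 'tf2jax.convert' takes a 'tf.function' plus example inputs and returns a pure JAX function together with its parameters (following the project README; details may differ across versions):

```python
import numpy as np
import tensorflow as tf
import tf2jax

@tf.function
def forward(x):
    return tf.sin(tf.cos(x))

# Convert, using example inputs to trace shapes and dtypes.
jax_func, jax_params = tf2jax.convert(forward, np.zeros((10,), np.float32))

# The converted function is purely functional: it takes and returns params.
outputs, updated_params = jax_func(jax_params, np.ones((10,), np.float32))
```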
gymnax
Gymnax brings JAX acceleration to the classic gym API, enhancing speed and efficiency in reinforcement learning environments. Supporting settings from classic control to bsuite, it employs JAX primitives like 'jit', 'vmap', and 'pmap' for high-throughput experiments. The project offers explicit control over environment state and randomness, which benefits meta reinforcement learning and evolutionary optimization, and enables implementing the Anakin sub-architecture. Speed tests on NVIDIA A100 GPUs illustrate its capabilities, making it suitable for scalable RL experiments. Tutorial resources help users start exploring its features.
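The gym-like rollout loop in functional form, mirroring the README's CartPole example; state and randomness are passed explicitly:

```python
import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

env, env_params = gymnax.make("CartPole-v1")
obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_act)
# Step returns the next observation/state plus reward and termination info.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
```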
EasyLM
EasyLM provides an efficient framework for pre-training, fine-tuning, evaluating, and deploying large language models with JAX/Flax. It supports TPU/GPU scaling across multiple hosts using JAX's pjit utility and integrates with Huggingface's tools for straightforward customization. Models like LLaMA and its successors are available. Participate in discussions on JAX-based LLM training on Discord for further insights.
jax-triton
The jax-triton repository integrates Triton kernels with JAX for optimized GPU computation: 'jax_triton.triton_call' invokes Triton functions inside 'jax.jit'-compiled routines. Users can begin with examples like Triton's vector-addition tutorial and progress to advanced tasks such as fused attention. Installation is straightforward, supporting both stable and nightly Triton releases, with a CUDA-compatible JAX install as a prerequisite. Developers can contribute by cloning the repository and performing an editable install, with tests run via 'pytest'.
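A condensed version of the vector-addition pattern, invoking a Triton kernel from JAX via 'jax_triton.triton_call'; this sketch omits the boundary masking of the full tutorial, so it assumes the array size divides the block size:

```python
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, block_size: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * block_size + tl.arange(0, block_size)
    tl.store(output_ptr + offsets,
             tl.load(x_ptr + offsets) + tl.load(y_ptr + offsets))

def add(x, y):
    out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    block_size = 8
    return jt.triton_call(x, y, kernel=add_kernel, out_shape=out_shape,
                          grid=(x.size // block_size,), block_size=block_size)

x = jnp.arange(8, dtype=jnp.float32)
print(jax.jit(add)(x, x))  # composes with jax.jit
```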
Mava
Mava facilitates advanced research in multi-agent reinforcement learning (MARL) with scalable JAX-based algorithms, enabling efficient parallel execution on various devices. It supports common MARL paradigms such as CTDE (centralised training with decentralised execution) and DTDE (decentralised training and execution), and includes adaptable environment wrappers for diverse tasks such as robotics and foraging. Mava enables precise evaluation with comprehensive JSON logging, aiding detailed analysis and performance comparison across hardware environments. The open-source model supports community contributions and integration into complex research workflows.
awesome-jax
This comprehensive guide features a wide range of JAX libraries, projects, and resources suited for researchers in high-performance machine learning using GPUs and TPUs. It includes neural network libraries such as Flax, Haiku, and Objax, along with specialized tools like Jraph for graph neural networks and NumPyro for probabilistic programming. The guide also highlights pioneering libraries like FedJAX for federated learning and Levanter for scalable model deployment. Discover models, projects, tutorials, videos, and community resources that make this an essential resource for maximizing the use of JAX.
DMFF
Discover DMFF, a JAX-based Python package for differentiable molecular force field modeling. This tool streamlines parameterization and handles complex potentials, including polarizable models, enhancing simulations of systems such as water, proteins, and organic compounds. By incorporating machine learning, DMFF automates parameter fitting and merges learned components with traditional force fields, improving predictive accuracy and adaptability to new molecules. GPU-accelerated computation supports detailed molecular analysis.
pyprobml
The pyprobml project offers Python 3 code to recreate figures from the books 'Probabilistic Machine Learning: An Introduction' and 'Probabilistic Machine Learning: Advanced Topics'. It leverages libraries such as NumPy, SciPy, and Matplotlib, alongside the JAX, TensorFlow, and PyTorch frameworks, making it a significant resource for machine learning research. Use environments like Colab for easy notebook execution, or configure it locally with the detailed instructions. Dashboards are available to stay informed, and this extensive, well-documented repository, currently in maintenance mode, deepens knowledge of probabilistic models.
dm_pix
PIX is built on JAX to provide image processing functions that benefit from optimization and parallelization: its operations compose with jax.jit and jax.vmap, offering essential tools for machine learning pipelines. Easily installed with pip, PIX performs reliably in parallel workloads and ships with a thorough testing suite. Contributions are welcomed to enhance its capabilities.
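Typical usage is plain function calls on image arrays, which compose directly with JAX transformations; a small sketch, with 'flip_left_right' and 'adjust_brightness' as two of the documented operations:

```python
import jax
import jax.numpy as jnp
import dm_pix as pix

image = jnp.zeros((128, 128, 3))  # placeholder HWC image
flipped = pix.flip_left_right(image)
brighter = pix.adjust_brightness(image, delta=0.2)

# Operations are pure functions, so they vmap over a batch of images.
batch = jnp.zeros((16, 128, 128, 3))
flipped_batch = jax.vmap(pix.flip_left_right)(batch)
```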
evojax
EvoJAX is a toolkit for hardware-accelerated neuroevolution, utilizing multi-TPU/GPU setups for parallel processing. Built on JAX, it facilitates quick experiments across various learning tasks: algorithms are expressed in NumPy-style code that is just-in-time compiled for improved computational efficiency. The toolkit supports customization via its trainer and simulation manager components and encourages community contributions of advanced algorithms and new tasks.
flashbax
Flashbax offers optimized replay buffers for JAX, supporting both academic and industrial reinforcement learning applications. It provides several buffer types, including Flat, Trajectory, and Prioritised Buffers, with an emphasis on memory efficiency and prioritised sampling. Trajectory storage makes it well suited to algorithms using recurrent networks, and Flashbax integrates easily into projects to improve speed and performance in RL environments. Detailed examples and benchmarks show how to use Flashbax effectively in reinforcement learning projects.
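A hedged sketch of the flat-buffer workflow following the README pattern: build a buffer of pure functions, initialise its state from an example timestep, then add and sample (exact signatures may vary by version; the timestep fields here are placeholders):

```python
import jax
import jax.numpy as jnp
import flashbax as fbx

buffer = fbx.make_flat_buffer(max_length=32, min_length=2, sample_batch_size=4)

# The buffer is a set of pure functions acting on an explicit state.
timestep = {"obs": jnp.zeros(4), "reward": jnp.zeros(())}
state = buffer.init(timestep)
for _ in range(4):
    state = buffer.add(state, timestep)

rng_key = jax.random.PRNGKey(0)
if buffer.can_sample(state):
    batch = buffer.sample(state, rng_key)  # a batch of transitions
```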
optimistix
Optimistix, a JAX library, solves nonlinear problems such as root finding and least squares, with modular optimizers and fast runtimes. It interoperates with Optax and supports GPU/TPU acceleration, leveraging JAX strengths like automatic differentiation. Installation requires only Python and JAX. Examples and further reading are in the documentation, and academic users are encouraged to cite it in their work.
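A minimal root-finding sketch in the documented style: pick a solver, hand it a function of '(y, args)', and read the solution off 'sol.value':

```python
import jax.numpy as jnp
import optimistix as optx

def fn(y, args):
    # Root of cos(y) - y = 0 (the "Dottie number", ~0.739).
    return jnp.cos(y) - y

solver = optx.Newton(rtol=1e-8, atol=1e-8)
sol = optx.root_find(fn, solver, jnp.array(1.0))  # y0 = 1.0
print(sol.value)
```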
grok-1
This repository offers JAX example code to run the Grok-1 model, a mixture-of-experts architecture with 8 experts and 314 billion parameters. The model has 64 layers and 48 attention heads, and running it requires significant GPU memory. Its SentencePiece tokenizer has a vocabulary of 131,072 tokens, and the model features rotary embeddings, activation sharding, and 8-bit quantization, handling sequences up to 8,192 tokens. Weights can be downloaded through a magnet link or the Hugging Face Hub under the Apache 2.0 license.
ttt-lm-pytorch
Discover sequence modeling layers with linear complexity and expressive hidden states that improve RNN efficiency on long contexts. This PyTorch version implements Test-Time Training (TTT) layers, whose hidden states continue to adapt at test time. It features TTT-Linear and TTT-MLP layers geared toward inference and designed for seamless use with Hugging Face Transformers.
mctx
Mctx provides a JAX-native implementation of Monte Carlo tree search algorithms, supporting models like AlphaZero and MuZero. It utilizes JIT-compilation for enhanced performance. The library is ideal for reinforcement learning research and offers configurable search methods. Installation is available through PyPI or GitHub, with practical examples included for ease of use.
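A hedged toy sketch of the search API: the caller supplies a batched root ('mctx.RootFnOutput') and a model step ('recurrent_fn'), and the policy returns actions and visit statistics; the dynamics below are dummies purely to show the shapes:

```python
import jax
import jax.numpy as jnp
import mctx

batch, num_actions = 1, 4

def recurrent_fn(params, rng_key, action, embedding):
    # Dummy model: zero reward/value, uniform priors, unchanged embedding.
    output = mctx.RecurrentFnOutput(
        reward=jnp.zeros(batch),
        discount=jnp.ones(batch),
        prior_logits=jnp.zeros((batch, num_actions)),
        value=jnp.zeros(batch),
    )
    return output, embedding

root = mctx.RootFnOutput(
    prior_logits=jnp.zeros((batch, num_actions)),
    value=jnp.zeros(batch),
    embedding=jnp.zeros((batch, 1)),
)
policy_output = mctx.muzero_policy(
    params=(), rng_key=jax.random.PRNGKey(0), root=root,
    recurrent_fn=recurrent_fn, num_simulations=32)
print(policy_output.action)  # chosen action per batch element
```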
paxml
Utilize PaxML for efficient configuration and execution of machine learning experiments on JAX-powered Cloud TPU frameworks; it manages scalable training tasks on both TPUs and GPUs. Available from PyPI or GitHub, PaxML includes support for complex models such as GPT-3, backed by detailed documentation and Jupyter Notebook tutorials. NVIDIA-contributed enhancements improve GPU performance across various computational scenarios.
learned_optimization
Explore a robust research codebase for designing and evaluating learned optimizers with JAX. The project provides tools for meta-training learned optimizers and other dynamical systems, plus comprehensive tutorials via Colab notebooks. It covers outer-training algorithms such as ES, PES, and truncated backpropagation through time, with practical examples tailored to deep learning practitioners. The documentation explains how to create custom tasks, develop gradient estimators, and apply meta-training techniques with the gradient learner, all aimed at advancing research in learned optimization.
jumanji
Explore 22 scalable reinforcement learning environments written in JAX to boost research efficiency. Ranging from basic games to hard combinatorial problems, these environments support research applications in both academia and industry. Jumanji integrates with popular interfaces such as OpenAI Gym and DeepMind's dm_env, and provides practical examples for easy adoption, suiting novice and experienced users alike.
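Environments follow a functional reset/step pattern that wraps cleanly in 'jax.jit', after the README; the fixed action below is a placeholder, where real code would sample from the environment's action spec:

```python
import jax
import jax.numpy as jnp
import jumanji

env = jumanji.make("Snake-v1")
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)

action = jnp.int32(0)  # placeholder; consult the env's action spec
state, timestep = jax.jit(env.step)(state, action)
```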
jaxdf
Jaxdf provides tools for developing numerical models of physical systems and solving partial differential equations using JAX. Its seamless integration with JAX facilitates customized, differentiable models for diverse research applications, including neural network layers and physics-based loss functions. The framework enhances the use of differentiable programming in complex computations, promoting efficiency and adaptability in scientific research.
equinox
Equinox is a versatile JAX library that simplifies model construction with PyTorch-inspired syntax. Representing models as PyTrees, it offers advanced capabilities like PyTree manipulation and runtime error handling, and integrates cleanly with the rest of the JAX ecosystem. Equinox is a solid choice for developers coming from Flax or Haiku, thanks to its filtered transformations that carry models across JIT and grad boundaries. It requires Python 3.9+ and JAX 0.4.13+, and positions itself as a library rather than a framework for researchers and developers.
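Models are plain PyTrees declared as dataclass-like 'eqx.Module's, so they pass straight through JAX transformations; a minimal sketch after the documented pattern:

```python
import equinox as eqx
import jax
import jax.numpy as jnp

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

@eqx.filter_jit  # JIT that tolerates non-array leaves in the model PyTree
def loss(model, x, y):
    return jnp.mean((jax.vmap(model)(x) - y) ** 2)
```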
synjax
SynJax is a neural network library designed for JAX, emphasizing structured probability distributions such as Linear Chain CRF and Semi-Markov CRF. It supports essential operations like log-probabilities and entropy calculations, utilizing JAX transformations for optimized performance. SynJax provides easy installation aligned with JAX's guidelines and includes practical examples in its notebooks for effective learning.
jax
JAX is a Python library for efficient numerical computing and large-scale machine learning on accelerators like GPUs and TPUs. It provides automatic differentiation for Python and NumPy functions and compiles programs for optimal execution. With transformations like 'grad' for differentiation and 'jit' for just-in-time compilation, JAX simplifies the development of sophisticated algorithms. Contributions are welcomed through feedback and bug reporting.
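The canonical pattern composes these transformations on ordinary NumPy-style code:

```python
import jax
import jax.numpy as jnp

@jax.jit  # compile the function with XLA
def f(x):
    return jnp.sum(jnp.tanh(x) ** 2)

grad_f = jax.grad(f)  # exact gradients via automatic differentiation
print(grad_f(jnp.arange(3.0)))
```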
blackjax
BlackJAX is a versatile sampling library for JAX, providing efficient samplers that run on both CPU and GPU. It lets developers and researchers compose modular algorithms, facilitating advances in probabilistic programming through customizable techniques. BlackJAX fills the need for standalone sampling routines decoupled from any modeling language, offering reusable code that supports Bayesian inference research.
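The building blocks are exposed as init/step pairs that the user drives explicitly, as in this NUTS sketch following the README pattern; the target here is a toy Gaussian log-density:

```python
import jax
import jax.numpy as jnp
import blackjax

def logdensity_fn(x):
    return -0.5 * jnp.sum(x ** 2)  # standard normal, up to a constant

nuts = blackjax.nuts(logdensity_fn, step_size=1e-1,
                     inverse_mass_matrix=jnp.ones(2))
state = nuts.init(jnp.zeros(2))

rng_key = jax.random.PRNGKey(0)
for _ in range(100):
    rng_key, sub_key = jax.random.split(rng_key)
    state, info = nuts.step(sub_key, state)  # one NUTS transition
```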
EasyDeL
EasyDeL is an open-source framework designed to enhance machine learning model training with a focus on JAX/Flax optimization for TPU and GPU. It supports a variety of model architectures and includes specialized trainers and efficient serving engines. Features include customizable sharding strategies, advanced quantization, and a flexible framework suitable for cutting-edge ML research. Regular updates provide access to the latest technologies, enabling scalable model training and experimentation with easy-to-use APIs.
brax
Brax, built with JAX, is a high-performance physics engine for fast simulations in robotics, human perception, and reinforcement learning. Optimized for accelerator hardware, it scales simulations across multiple devices. Brax includes training algorithms such as PPO and SAC, and its differentiable simulator enables analytic policy gradients. With four physics pipelines, including MuJoCo XLA (MJX) and spring, Brax adapts to diverse simulation needs. It offers Colab notebook support and integrates with frameworks like PyTorch, and installation is flexible via pip, Conda, or source to fit different computational environments.
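A hedged rollout sketch: environments are created by name with a chosen pipeline backend, and reset/step are pure functions that JIT cleanly (argument names per the Brax v2 API; the zero action is a stand-in for a policy):

```python
import jax
import jax.numpy as jnp
from brax import envs

env = envs.create(env_name="ant", backend="positional")
state = jax.jit(env.reset)(jax.random.PRNGKey(0))
# Apply a zero action for one step; real code would query a policy.
state = jax.jit(env.step)(state, jnp.zeros(env.action_size))
print(state.reward)
```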
evosax
Evosax enables efficient neuroevolution through JAX and XLA, implementing both classic strategies (e.g. CMA-ES) and modern ones (e.g. OpenAI-ES). Auto-vectorization and parallelization deliver strong performance on accelerators while keeping the ask-evaluate-tell loop simple. Evosax provides robust implementations for scalable evolutionary computation, ideal for high-throughput needs without asynchronous infrastructure.
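A hedged sketch of the ask-evaluate-tell loop, following the classic evosax API (newer releases may rename these methods; the objective here is a toy sphere function):

```python
import jax
import jax.numpy as jnp
from evosax import CMA_ES

rng = jax.random.PRNGKey(0)
strategy = CMA_ES(popsize=20, num_dims=2)
es_params = strategy.default_params
state = strategy.initialize(rng, es_params)

for _ in range(10):
    rng, rng_ask = jax.random.split(rng)
    x, state = strategy.ask(rng_ask, state, es_params)  # candidate solutions
    fitness = jnp.sum(x ** 2, axis=1)                   # toy objective
    state = strategy.tell(x, fitness, state, es_params)
```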
GradCache
Gradient Cache overcomes GPU/TPU memory limits when scaling contrastive learning. Compatible with PyTorch and JAX, it enables dense passage retrieval training on a single GPU, matching setups that otherwise demand far more memory and thereby lowering hardware costs. Suitable for deep learning workloads, it supports mixed precision and distributed training, and offers functional and decorator APIs for streamlined cache implementation.
axlearn
AXLearn, built on JAX and XLA, provides scalable and object-oriented tools for crafting extensive deep learning models. It integrates with Flax and Hugging Face transformers, and supports NLP, computer vision, and speech recognition tasks. Designed for public cloud deployment, it handles immense model sizes efficiently, thanks to its global computation paradigm. Comprehensive documentation guides users in configuring models using modular components.
s2fft
This Python package provides efficient spherical Fourier and Wigner transforms with JAX and PyTorch support, enabling differentiable operations on GPUs and TPUs. Features include spin spherical harmonic transforms and optimized computation at various angular resolutions, with support for HEALPix sampling, making it well suited to scientific and machine learning applications.
scenic
Scenic offers a robust framework for creating attention-based computer vision models, supporting tasks like classification and segmentation across multiple modalities. Utilizing JAX and Flax, it simplifies large-scale training through efficient pipelines and established baselines, ideal for research. Explore projects with state-of-the-art models like ViViT. Scenic provides adaptable solutions for both newcomers and experts, facilitating easy integration into existing workflows.
dm-haiku
Haiku is a compact neural network library for JAX that brings an object-oriented programming model to JAX's function transformations. Created by the developers of Sonnet for TensorFlow, Haiku focuses on clean parameter and state management without imposing a framework. While now in maintenance mode, receiving bug fixes and compatibility updates only, Haiku still offers its key features, hk.Module and hk.transform, easing the transition from TensorFlow to JAX. It meets large-scale project requirements, supports stochastic models and non-trainable state, and extends to distributed training through jax.pmap. Well-documented resources and examples help users leverage Haiku effectively.
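The core idiom: write object-oriented module code, then 'hk.transform' it into a pure init/apply pair (a minimal sketch after the documented MLP example):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    mlp = hk.nets.MLP([300, 100, 10])  # modules are created inside transform
    return mlp(x)

net = hk.transform(forward)
rng = jax.random.PRNGKey(42)
x = jnp.ones([8, 28 * 28])
params = net.init(rng, x)           # pure function returning parameters
logits = net.apply(params, rng, x)  # pure function consuming parameters
```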
Feedback Email: [email protected]