PAXlib augments PyTorch with JAX-inspired abstractions including functional state management and pytree-compatible module structures. It enables researchers to write more modular and composable deep learning code within a PyTorch workflow, without requiring a full migration to JAX.
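The two ideas named above can be illustrated without reference to PAXlib's actual interface, which this page does not document. The following is a minimal pure-Python sketch under stated assumptions: the helper names (`tree_flatten`, `tree_unflatten`, `tree_map`, `linear`, `sgd_step`) are hypothetical and written for illustration only, and plain floats stand in for tensors. It shows a pytree as a nested container that can be flattened to a list of leaves and rebuilt, and functional state management as an update that returns new parameters instead of mutating the old ones.

```python
# Illustrative sketch only: these helpers are hypothetical and are NOT PAXlib's API.

def tree_flatten(tree):
    """Flatten nested dicts/lists/tuples into (leaves, treedef)."""
    if isinstance(tree, dict):
        keys = sorted(tree)  # fixed key order makes flattening deterministic
        leaves, defs = [], []
        for key in keys:
            sub_leaves, sub_def = tree_flatten(tree[key])
            leaves.extend(sub_leaves)
            defs.append(sub_def)
        return leaves, ("dict", keys, defs)
    if isinstance(tree, (list, tuple)):
        leaves, defs = [], []
        for item in tree:
            sub_leaves, sub_def = tree_flatten(item)
            leaves.extend(sub_leaves)
            defs.append(sub_def)
        kind = "tuple" if isinstance(tree, tuple) else "list"
        return leaves, (kind, None, defs)
    return [tree], ("leaf", None, None)  # anything else is treated as a leaf


def tree_unflatten(treedef, leaves):
    """Rebuild a container from a treedef and a flat sequence of leaves."""
    leaf_iter = iter(leaves)

    def build(node):
        kind, keys, defs = node
        if kind == "leaf":
            return next(leaf_iter)
        children = [build(child) for child in defs]
        if kind == "dict":
            return dict(zip(keys, children))
        return tuple(children) if kind == "tuple" else children

    return build(treedef)


def tree_map(fn, tree):
    """Apply fn to every leaf while preserving the container structure."""
    leaves, treedef = tree_flatten(tree)
    return tree_unflatten(treedef, [fn(leaf) for leaf in leaves])


def linear(params, x):
    """Pure forward pass: the result depends only on the explicit arguments."""
    return params["w"] * x + params["b"]


def sgd_step(params, grads, lr):
    """Return updated parameters; the input parameters are left unchanged."""
    p_leaves, treedef = tree_flatten(params)
    g_leaves, _ = tree_flatten(grads)  # assumes grads has the same structure
    return tree_unflatten(treedef, [p - lr * g for p, g in zip(p_leaves, g_leaves)])


params = {"w": 2.0, "b": 1.0}
grads = {"w": 0.5, "b": -1.0}  # made-up gradient values for the example
new_params = sgd_step(params, grads, lr=0.1)
```

Because `sgd_step` builds a new dictionary, `params` still holds its original values after the call, and either version can be passed to `linear`. That explicit passing of state is what makes functions of this kind easy to compose.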
This page was last edited on 2024-04-09.