NVIDIA recently published an interesting blog post about this: https://developer.nvidia.com/blog/accelerating-long-context-model-training-in-jax-and-xla/