nestordemeure/xmap

Xmap

An alternative xmap implementation for Jax.

Jax famously has a vmap vectorizing function, which lets you batch a function along a given axis, making it trivial to take a function and apply it to arrays efficiently.
However, lines like vmap(f, (0, 1), 0) can be hard to read, to the point that they are often accompanied by a comment along the lines of ([b,a], [a,b]) -> [b].
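To make the readability problem concrete, here is a minimal sketch of such a call. The function f and the array shapes are hypothetical and chosen only for illustration; the point is that the positional axis indices say nothing about which dimension is which unless a comment spells it out.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # Hypothetical per-item function: x and y are both vectors of shape [a].
    return jnp.dot(x, y)

# ([b, a], [a, b]) -> [b]
# Batch over axis 0 of the first input and axis 1 of the second input.
f_batched = jax.vmap(f, in_axes=(0, 1), out_axes=0)

x = jnp.ones((3, 4))  # [b=3, a=4]
y = jnp.ones((4, 3))  # [a=4, b=3]
result = f_batched(x, y)  # one dot product per value of b
```

Without the shape comment, a reader has to work out from (0, 1) alone which dimension of each input is the batch dimension.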

The experimental.maps namespace contains an xmap function that solves this problem elegantly using named axes (the prior example would become something like xmap(f, in_axes=(['b', ...], [..., 'b']), out_axes=['b'])) and lets you vectorize over multiple axes simultaneously (something that would otherwise require consecutive calls to vmap and quickly becomes messy).
However, xmap bundles a lot of other things (like resource partitioning and jitting the code), does not play well with static arguments, and is in the experimental namespace for a reason: it is the part of your codebase most likely to break when you update your Jax version.

This library provides you with an xmap implementation that:

  • focuses on giving you a good interface to vectorize a function along one or more axes (it implements none of the other functionalities supported by the official implementation),
  • jits down to several calls to vmap (making it resilient to updates on Jax's side),
  • provides basic jit-time checks to try and catch common errors (such as forgetting an argument or passing something of the wrong shape / type) with nice error messages.
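As a rough illustration of what "several calls to vmap" can look like, the sketch below hand-writes the nested vmap calls for two named axes. This is not the library's actual code: the function f, its argument names, and the order in which the axes are peeled off are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

def f(scale, block_index, interval_start):
    # Hypothetical per-element function operating on scalars only.
    return scale * (block_index + interval_start)

# Target signature, in named-axis terms:
#   scale:          single value
#   block_index:    ['n_intervals', 'blocks_per_interval']
#   interval_start: ['n_intervals']
#   output:         ['n_intervals', 'blocks_per_interval']

# Inner vmap handles 'blocks_per_interval': only block_index carries it.
over_blocks = jax.vmap(f, in_axes=(None, 0, None), out_axes=0)
# Outer vmap handles 'n_intervals': both array inputs carry it on axis 0.
over_intervals = jax.vmap(over_blocks, in_axes=(None, 0, 0), out_axes=0)

scale = 2.0
block_indices = jnp.arange(6.0).reshape(2, 3)  # [n_intervals=2, blocks_per_interval=3]
interval_starts = jnp.array([10.0, 20.0])      # [n_intervals=2]
out = over_intervals(scale, block_indices, interval_starts)
```

Writing these calls by hand means tracking, for every input, which positional axis remains after each level of slicing, which is the bookkeeping that named axes are meant to remove.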

Installation

Copy the xmap.py file into your project.

Usage

Here is a usage example taken from production code:

f_batched = xmap(
    f,
    in_axes={
        'step_length': int,
        'det_data': [..., ...],
        'use_flag': bool,
        'flag_data': ['n_intervals', ...],
        'flag_mask': int,
        'amplitude_offset': int,
        'amplitude_view_offset': ['n_intervals'],
        'block_indices': ['n_intervals', 'blocks_per_interval'],
        'interval_starts': ['n_intervals'],
        'interval_ends': ['n_intervals'],
    },
    out_axes=(
        ['n_intervals', 'blocks_per_interval'],
        ['n_intervals', 'blocks_per_interval'],
    ),
)

Usage is similar to the official implementation. Note that:

  • each input / output is described either as a single value (of a given type) or as an array (described by a list of named or unnamed dimensions),
  • input axes are described as a dictionary (which lets us name each input for improved readability),
  • inputs / outputs can be single values but also tuples and other pytrees.

All axes named in the inputs have to be present in the outputs. The resulting function (f_batched) is equivalent to a function that would run one loop per named axis and pass the resulting sliced data to f.
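The loop equivalence can be sketched in plain Python. The example below is illustrative only: f, f_looped, and the shapes are hypothetical, and the loops are a reference for the semantics, not how the library executes the function.

```python
import jax.numpy as jnp

def f(block_index, interval_start):
    # Hypothetical per-element function operating on scalars only.
    return block_index + interval_start

def f_looped(block_indices, interval_starts):
    # Reference semantics for
    #   in_axes:  block_indices   ['n_intervals', 'blocks_per_interval']
    #             interval_starts ['n_intervals']
    #   out_axes: ['n_intervals', 'blocks_per_interval']
    n_intervals, blocks_per_interval = block_indices.shape
    rows = []
    for i in range(n_intervals):              # loop over 'n_intervals'
        row = []
        for j in range(blocks_per_interval):  # loop over 'blocks_per_interval'
            # Each input is sliced only along the named axes it declares.
            row.append(f(block_indices[i, j], interval_starts[i]))
        rows.append(jnp.stack(row))
    return jnp.stack(rows)

block_indices = jnp.arange(6.0).reshape(2, 3)
interval_starts = jnp.array([10.0, 20.0])
looped = f_looped(block_indices, interval_starts)
```

An input that does not declare a given named axis (here, interval_starts and 'blocks_per_interval') is simply passed unchanged through that loop.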
