Skip to content

Commit e0fde7a

Browse files
authored
Merge pull request #505 from FlorianPfaff/jaxnew3
Added jax backend
2 parents 21146a0 + eed9867 commit e0fde7a

11 files changed

Lines changed: 745 additions & 12 deletions

File tree

.github/workflows/mega-linter.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,13 @@ jobs:
102102
run: |
103103
sed 's/==.*//' requirements-dev.txt > requirements-dev_no_version.txt
104104
105-
- name: Remove torch, tirton, and nvidia entries (unsupported by alpine)
105+
- name: Remove torch, triton, jax, and nvidia entries (unsupported by alpine)
106106
if: steps.cache-wheels.outputs.cache-hit != 'true'
107107
run: |
108108
sed -i '/^torch/d' requirements-dev_no_version.txt
109109
sed -i '/^nvidia/d' requirements-dev_no_version.txt
110110
sed -i '/^triton/d' requirements-dev_no_version.txt
111+
sed -i '/^jax/d' requirements-dev_no_version.txt
111112
112113
- name: Run CMake to find LAPACK
113114
if: steps.cache-wheels.outputs.cache-hit != 'true'

.github/workflows/tests.yml

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,43 @@ jobs:
7979
env:
8080
PYTHONPATH: ${{ github.workspace }}
8181

82-
- name: Upload pytorch test result artifact
82+
test-jax:
83+
runs-on: ubuntu-latest
84+
permissions:
85+
checks: write
86+
pull-requests: write
87+
88+
steps:
89+
- name: Check out repository
90+
uses: actions/checkout@v4
91+
92+
- name: Set up Python
93+
uses: actions/setup-python@v4
94+
with:
95+
python-version: "3.11"
96+
97+
- name: Install dependencies
98+
run: |
99+
python -m pip install --upgrade pip
100+
python -m pip install poetry
101+
poetry env use python
102+
poetry install --extras "healpy_support" --extras "jax_support"
103+
104+
- name: Run tests with jax backend
105+
run: |
106+
export PYRECEST_BACKEND=jax
107+
poetry run python -m pytest --rootdir . -v --strict-config --junitxml=junit_test_results_jax.xml ./pyrecest
108+
env:
109+
PYTHONPATH: ${{ github.workspace }}
110+
111+
- name: Upload jax test result artifact
83112
uses: actions/upload-artifact@v3
84113
with:
85-
name: pytorch-test-results
86-
path: junit_test_results_pytorch.xml
114+
name: jax-test-results
115+
path: junit_test_results_jax.xml
87116

88117
publish-results:
89-
needs: [test-numpy, test-pytorch]
118+
needs: [test-numpy, test-pytorch, test-jax]
90119
runs-on: ubuntu-latest
91120
if: always()
92121
permissions:
@@ -104,3 +133,4 @@ jobs:
104133
files: |
105134
test-results/numpy-test-results/junit_test_results_numpy.xml
106135
test-results/pytorch-test-results/junit_test_results_pytorch.xml
136+
test-results/pytorch-test-results/junit_test_results_jax.xml

poetry.lock

Lines changed: 162 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "pyrecest"
33
description = "Framework for recursive Bayesian estimation in Python."
44
readme = "README.md"
55
authors = ["Florian Pfaff <pfaff@kit.edu>"]
6-
version = "0.6.1"
6+
version = "0.7.0"
77

88
[tool.poetry.dependencies]
99
python = ">=3.10,<3.13"
@@ -20,10 +20,14 @@ shapely = "*"
2020
[tool.poetry.extras]
2121
healpy_support = ["healpy"]
2222
pytorch_support = ["torch"]
23+
jax_support = ["jax", "jaxlib", "autograd"]
2324

2425
[tool.poetry.group.dev.dependencies]
2526
healpy = "*"
2627
torch = "*"
28+
jax = "*"
29+
jaxlib = "*"
30+
autograd = "*"
2731
autopep8 = "^2.0.2"
2832
pytest = "*"
2933
parameterized = "*"

0 commit comments

Comments
 (0)