-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsetup.py
More file actions
86 lines (77 loc) · 2.51 KB
/
setup.py
File metadata and controls
86 lines (77 loc) · 2.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import pathlib
import re

from pybind11.setup_helpers import Pybind11Extension, build_ext
from setuptools import find_packages, setup
def get_version() -> str:
init = open(os.path.join("flowrl", "__init__.py"), "r").read().split()
return init[init.index("__version__") + 2][1:-1]
# Directory containing this setup.py; files are located relative to it so the
# build does not depend on the current working directory.
ROOT_DIR = pathlib.Path(__file__).parent
# Decode explicitly as UTF-8: read_text() without an encoding uses the
# platform's locale encoding and can fail on non-ASCII README content.
README = (ROOT_DIR / "README.md").read_text(encoding="utf-8")
VERSION = get_version()
def get_base_requirements():
    """Return the core runtime dependencies installed with every flowrl setup.

    A new list is built on each call, so callers may mutate the result freely.
    """
    # JAX (with its `cuda12` extra), Flax and Orbax checkpointing.
    jax_stack = [
        "jax[cuda12]==0.5.3",
        "flax==0.10.5",
        "orbax-checkpoint==0.11.23",
    ]
    # Environment APIs plus pinned helper packages.
    env_and_compat = [
        "gymnasium",
        "shimmy==1.3.0",
        "Cython<3",
        "six==1.17.0",
    ]
    # Progress bars, config, distributions, logging and plotting/video output.
    tooling = [
        "tqdm",
        "hydra-core",
        "distrax",
        "tensorboardX==2.6.2.2",
        "scikit-learn==1.6.1",
        "wandb",
        "matplotlib",
        "imageio[ffmpeg]",
    ]
    return jax_stack + env_and_compat + tooling
def get_install_requires():
    """Return the ``install_requires`` list handed to ``setup()``.

    At present this is exactly the base requirement set from
    ``get_base_requirements()``, with nothing added or removed.
    """
    requirements = get_base_requirements()
    return requirements
def get_extras_require():
    """Return the optional dependency groups (``pip install flowrl[<group>]``).

    Note that "online" pins newer dm_control/mujoco releases than "offline"
    and "humanoidbench" allow, so it cannot be combined with either of them.
    """
    extras = {}
    # Upper-bounded dm_control / mujoco releases.
    extras["offline"] = ["dm_control<=1.0.20", "mujoco<=3.1.6"]
    # Exact, newer pins (incompatible with the bounds above).
    extras["online"] = ["dm_control==1.0.27", "mujoco==3.2.7"]
    # humanoid-bench itself is deliberately not listed because it depends on
    # torch (see https://github.com/carlosferrazza/humanoid-bench/issues/65);
    # its source is git+https://github.com/carlosferrazza/humanoid-bench.git
    extras["humanoidbench"] = ["dm_control==1.0.20", "mujoco==3.1.6"]
    # Install with --extra-index-url https://pypi.nvidia.com
    extras["isaaclab"] = ["isaaclab[isaacsim,all]==2.3.2.post1"]
    return extras
def get_ext_modules():
    """Return the compiled extensions built alongside the Python package.

    A single pybind11 module named ``data_structure`` is compiled from
    ``flowrl/data_structure/data_structure.cc``. Its name carries no package
    prefix, so it installs as a top-level module rather than under ``flowrl``.
    The package version reaches the C++ code as the ``VERSION_INFO`` macro,
    defined as a double-quoted string literal.
    """
    version_macro = ("VERSION_INFO", f'"{VERSION}"')
    data_structure = Pybind11Extension(
        "data_structure",
        ["flowrl/data_structure/data_structure.cc"],
        define_macros=[version_macro],
    )
    return [data_structure]
# Package metadata and build configuration.
setup(
    name="flowrl",
    version=VERSION,
    description="A library designed for flow-based RL algorithms",
    long_description=README,
    long_description_content_type="text/markdown",
    url="https://github.com/typoverflow/flow-rl",
    author="typoverflow",
    author_email="typoverflow@gmail.com",
    license="MIT",
    packages=find_packages(),
    include_package_data=True,
    # NOTE(review): setuptools has deprecated `tests_require` together with the
    # `setup.py test` command; consider an extras group for test deps instead —
    # confirm against the setuptools version this project builds with.
    tests_require=["pytest", "mock"],
    python_requires=">=3.11",
    ext_modules=get_ext_modules(),
    # Use pybind11's build_ext command class for the extension modules above.
    cmdclass={"build_ext": build_ext},
    install_requires=get_install_requires(),
    extras_require=get_extras_require(),
)