Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions embodichain/toolkits/scaffold/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

"""Scaffold new EmbodiChain task environments (in-repo or external extension)."""

from __future__ import annotations

from embodichain.toolkits.scaffold.cli import main
from embodichain.toolkits.scaffold.generator import generate_task
from embodichain.toolkits.scaffold.spec import TaskSpec

__all__ = ["TaskSpec", "generate_task", "main"]
22 changes: 22 additions & 0 deletions embodichain/toolkits/scaffold/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

from embodichain.toolkits.scaffold.cli import main

if __name__ == "__main__":
main()
191 changes: 191 additions & 0 deletions embodichain/toolkits/scaffold/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

import argparse
import sys
from pathlib import Path

from embodichain.toolkits.scaffold.generator import generate_task, print_summary
from embodichain.toolkits.scaffold.naming import default_gym_id
from embodichain.toolkits.scaffold.spec import INREPO_CATEGORIES, TaskSpec


def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
prog="embodichain-new-task",
description="Scaffold EmbodiChain task environments (in-repo or external extension).",
)
parser.add_argument(
"--target",
choices=("inrepo", "extension"),
default="inrepo",
help="Generate inside EmbodiChain repo or a new extension project.",
)
parser.add_argument(
"--workflow",
choices=("demo", "rl", "config-only"),
required=False,
help="Task workflow: expert demo, RL, or config-only env class.",
)
parser.add_argument(
"--name",
"--task-name",
dest="task_snake",
help="Task name in snake_case (e.g. pick_place).",
)
parser.add_argument("--gym-id", help="Gym registration id (e.g. PickPlace-v1).")
parser.add_argument(
"--category",
choices=INREPO_CATEGORIES,
default="tableware",
help="In-repo task category (ignored for RL workflow).",
)
parser.add_argument(
"--robot-preset",
choices=("cobot_magic", "ur5_minimal"),
default="cobot_magic",
help="Robot/sensor/light preset for gym JSON.",
)
parser.add_argument(
"--max-episode-steps", type=int, default=300, help="Max steps per episode."
)
parser.add_argument(
"--max-episodes", type=int, default=5, help="Episodes in gym JSON metadata."
)
parser.add_argument(
"--reward-style",
choices=("json", "python"),
default="json",
help="RL rewards in gym JSON or Python get_reward().",
)
parser.add_argument(
"--project-name",
help="Extension project name (pyproject name).",
)
parser.add_argument(
"--package-name",
help="Extension Python package name (snake_case).",
)
parser.add_argument(
"--output-dir",
type=Path,
help="Extension output directory (default: ./<project-name>).",
)
parser.add_argument(
"--no-test", action="store_true", help="Skip generating test stub."
)
parser.add_argument(
"--no-black", action="store_true", help="Skip running black on generated files."
)
parser.add_argument(
"--init-git",
action="store_true",
help="Run git init in extension output (extension target only).",
)
parser.add_argument(
"--force", action="store_true", help="Overwrite existing files."
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Print paths only; do not write files.",
)
parser.add_argument(
"-i",
"--interactive",
action="store_true",
help="Prompt for missing options.",
)
return parser


def _prompt_choice(label: str, options: list[str], default: str) -> str:
print(f"{label} [{'/'.join(options)}] (default: {default}): ", end="")
value = input().strip()
return value if value in options else default


def _interactive_fill(args: argparse.Namespace) -> None:
if args.target is None:
args.target = _prompt_choice("Target", ["inrepo", "extension"], "inrepo")
if args.workflow is None:
args.workflow = _prompt_choice(
"Workflow", ["demo", "rl", "config-only"], "demo"
)
if args.task_snake is None:
args.task_snake = input("Task name (snake_case): ").strip()
if args.gym_id is None and args.task_snake:
args.gym_id = default_gym_id(args.task_snake)
print(f"Gym id (default: {args.gym_id}): ", end="")
custom = input().strip()
if custom:
args.gym_id = custom
if args.target == "inrepo" and args.workflow != "rl":
if args.category == "tableware":
args.category = _prompt_choice(
"Category", list(INREPO_CATEGORIES), "tableware"
)
if args.target == "extension":
if args.package_name is None:
args.package_name = args.task_snake
if args.project_name is None:
args.project_name = args.package_name.replace("_", "-")


def main(argv: list[str] | None = None) -> int:
parser = _build_parser()
args = parser.parse_args(argv)

if args.interactive or args.workflow is None or args.task_snake is None:
_interactive_fill(args)

if args.workflow is None:
parser.error("--workflow is required (or use --interactive).")
if args.task_snake is None:
parser.error("--name is required (or use --interactive).")

try:
spec = TaskSpec(
task_snake=args.task_snake,
workflow=args.workflow,
target=args.target,
gym_id=args.gym_id,
category=args.category,
robot_preset=args.robot_preset,
max_episode_steps=args.max_episode_steps,
max_episodes=args.max_episodes,
reward_style=args.reward_style,
project_name=args.project_name,
package_name=args.package_name,
output_dir=args.output_dir,
include_test=not args.no_test,
dry_run=args.dry_run,
force=args.force,
run_black=not args.no_black,
init_git=args.init_git,
)
paths = generate_task(spec)
print_summary(spec, paths)
except (ValueError, FileExistsError) as exc:
print(f"Error: {exc}", file=sys.stderr)
return 1
return 0


if __name__ == "__main__":
raise SystemExit(main())
153 changes: 153 additions & 0 deletions embodichain/toolkits/scaffold/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# ----------------------------------------------------------------------------
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import annotations

import shutil
from pathlib import Path

from embodichain.toolkits.scaffold import post_process
from embodichain.toolkits.scaffold.presets import gym_config_to_json
from embodichain.toolkits.scaffold.render import (
render_extension_file,
render_task_py,
render_test_py,
)
from embodichain.toolkits.scaffold.spec import TaskSpec

_APACHE_LICENSE = """Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/LICENSE-2.0
"""


def _write(path: Path, content: str, spec: TaskSpec) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
if path.exists() and spec.force:
path.unlink()
path.write_text(content, encoding="utf-8")


def _collect_paths(spec: TaskSpec) -> list[Path]:
paths = list(spec.all_output_paths())
if (
spec.target == "extension"
and spec.output_dir is not None
and spec.is_new_extension_project()
and not spec.dry_run
):
if (
spec.output_dir.exists()
and any(spec.output_dir.iterdir())
and not spec.force
):
raise FileExistsError(
f"Output directory is not empty: {spec.output_dir} (use --force)"
)
post_process.check_collisions(spec, paths)
return paths


def generate_task(spec: TaskSpec) -> list[Path]:
"""Generate task scaffold files. Returns list of written paths."""
_collect_paths(spec)

if spec.dry_run:
return spec.all_output_paths()

written: list[Path] = []

task_py = spec.task_py_path()
_write(task_py, render_task_py(spec), spec)
written.append(task_py)

gym_json = spec.gym_json_path()
_write(gym_json, gym_config_to_json(spec), spec)
written.append(gym_json)

if test_path := spec.test_py_path():
_write(test_path, render_test_py(spec), spec)
written.append(test_path)

if spec.target == "inrepo":
post_process.patch_tasks_init(spec)
elif spec.is_new_extension_project():
_generate_extension_tree(spec, written)
else:
post_process.patch_tasks_init(spec)

if spec.run_black:
post_process.run_black(written)

if spec.target == "extension" and spec.init_git and spec.output_dir:
post_process.init_git_repo(spec.output_dir)

return written


def _generate_extension_tree(spec: TaskSpec, written: list[Path]) -> None:
assert spec.output_dir is not None
assert spec.package_name is not None
root = spec.output_dir
pkg = root / spec.package_name

license_text = _APACHE_LICENSE
repo_license = spec.repo_root / "LICENSE"
if repo_license.is_file():
license_text = repo_license.read_text(encoding="utf-8")

files: list[tuple[Path, str]] = [
(root / "pyproject.toml", render_extension_file("pyproject.toml", spec)),
(root / "README.md", render_extension_file("README.md", spec)),
(root / "LICENSE", license_text),
(root / ".gitignore", render_extension_file("gitignore", spec)),
(root / "VERSION", "0.1.0\n"),
(pkg / "VERSION", "0.1.0\n"),
(pkg / "__init__.py", render_extension_file("package_init.py", spec)),
(pkg / "tasks" / "__init__.py", render_extension_file("tasks_init.py", spec)),
(pkg / "data" / "__init__.py", ""),
(pkg / "data" / "constants.py", render_extension_file("constants.py", spec)),
(pkg / "utils" / "__init__.py", ""),
(root / "scripts" / "run_env.py", render_extension_file("run_env.py", spec)),
]

for path, content in files:
_write(path, content, spec)
written.append(path)

post_process.patch_tasks_init(spec)


def print_summary(spec: TaskSpec, paths: list[Path]) -> None:
"""Print generation summary and next-step commands."""
mode = "DRY RUN — would write" if spec.dry_run else "Wrote"
print(f"\n{mode} {len(paths)} file(s):\n")
for p in paths:
print(f" {p}")

print("\nNext steps:\n")
if spec.target == "inrepo":
gym_cfg = spec.gym_json_path().relative_to(spec.repo_root)
print(f" python embodichain/lab/scripts/run_env.py " f"--gym_config {gym_cfg}")
print(" pytest " + str(spec.test_py_path().relative_to(spec.repo_root)))
else:
assert spec.output_dir is not None
print(f" cd {spec.output_dir}")
print(" pip install -e .")
print(
f" python scripts/run_env.py "
f"--gym_config configs/{spec.task_snake}/gym.json --headless"
)
Loading
Loading