Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ dependencies = [
"torchvision>=0.15.0",
"uvicorn==0.38.0",
"zarr>=2.18",
"connectomics",
"pytorch-connectomics",
]

[tool.uv.sources]
connectomics = { path = "pytorch_connectomics", editable = true }
"pytorch-connectomics" = { path = "pytorch_connectomics", editable = true }

[dependency-groups]
dev = [
Expand Down
2 changes: 1 addition & 1 deletion pytorch_connectomics
Submodule pytorch_connectomics updated from 20ccfd to 0a0dce
6 changes: 3 additions & 3 deletions scripts/bootstrap.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ if (Test-Path (Join-Path $pytcDir '.git')) {
Push-Location $pytcDir
git fetch origin *> $null
$currentCommit = (git rev-parse HEAD) -replace '\s', ''
$targetCommit = '20ccfde'
$targetCommit = '0a0dceb'
if ($currentCommit -ne $targetCommit) {
git checkout $targetCommit
}
Pop-Location
} else {
git clone 'https://github.com/zudi-lin/pytorch_connectomics.git' $pytcDir
git clone 'https://github.com/PytorchConnectomics/pytorch_connectomics.git' $pytcDir
Push-Location $pytcDir
git checkout '20ccfde'
git checkout '0a0dceb'
Pop-Location
}

Expand Down
4 changes: 2 additions & 2 deletions scripts/setup_pytorch_connectomics.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ if ! command -v git >/dev/null 2>&1; then
exit 1
fi

PYTORCH_CONNECTOMICS_COMMIT="20ccfde"
REPO_URL="https://github.com/zudi-lin/pytorch_connectomics.git"
PYTORCH_CONNECTOMICS_COMMIT="0a0dceb"
REPO_URL="https://github.com/PytorchConnectomics/pytorch_connectomics.git"
PYTORCH_CONNECTOMICS_DIR="pytorch_connectomics"

if [ -d "${PYTORCH_CONNECTOMICS_DIR}/.git" ]; then
Expand Down
106 changes: 78 additions & 28 deletions server_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,33 +90,88 @@ def health():


BASE_DIR = pathlib.Path(__file__).resolve().parent.parent
PYTC_CONFIG_ROOT = BASE_DIR / "pytorch_connectomics" / "configs"
PYTC_BUILD_FILE = (
BASE_DIR / "pytorch_connectomics" / "connectomics" / "model" / "build.py"
PYTC_ROOT = BASE_DIR / "pytorch_connectomics"
PYTC_CONFIG_ROOTS = (
PYTC_ROOT / "tutorials",
PYTC_ROOT / "configs",
)
PYTC_CONFIG_SUFFIXES = (".yaml", ".yml")


def _iter_existing_config_roots():
for root in PYTC_CONFIG_ROOTS:
if root.exists() and root.is_dir():
yield root


def _list_pytc_configs() -> List[str]:
if not PYTC_CONFIG_ROOT.exists():
return []
return sorted(
[
str(path.relative_to(PYTC_CONFIG_ROOT)).replace("\\", "/")
for path in PYTC_CONFIG_ROOT.rglob("*.yaml")
]
configs = []
for root in _iter_existing_config_roots():
for suffix in PYTC_CONFIG_SUFFIXES:
for path in root.rglob(f"*{suffix}"):
configs.append(str(path.relative_to(PYTC_ROOT)).replace("\\", "/"))
return sorted(set(configs))


def _is_relative_to(path: pathlib.Path, root: pathlib.Path) -> bool:
try:
path.relative_to(root)
return True
except ValueError:
return False


def _is_valid_config_path(path: pathlib.Path) -> bool:
if not path.is_file():
return False
if path.suffix.lower() not in PYTC_CONFIG_SUFFIXES:
return False
if not _is_relative_to(path, PYTC_ROOT.resolve()):
return False
return any(
_is_relative_to(path, root.resolve()) for root in _iter_existing_config_roots()
)


def _resolve_requested_config(path: str) -> Optional[pathlib.Path]:
if not path:
return None

normalized = path.replace("\\", "/").strip()
if not normalized or normalized.startswith("/"):
return None
if ".." in pathlib.PurePosixPath(normalized).parts:
return None

candidates = [(PYTC_ROOT / normalized).resolve()]
for root in _iter_existing_config_roots():
candidates.append((root / normalized).resolve())

for candidate in candidates:
if _is_valid_config_path(candidate):
return candidate
return None


def _read_model_architectures() -> List[str]:
if not PYTC_BUILD_FILE.exists():
return []
text = PYTC_BUILD_FILE.read_text(encoding="utf-8", errors="ignore")
match = re.search(r"MODEL_MAP\s*=\s*{(.*?)}", text, re.S)
if not match:
return []
block = match.group(1)
keys = re.findall(r"'([^']+)'\s*:", block)
return sorted(set(keys))
# Prefer runtime registry from the installed connectomics package.
try:
from connectomics.models.arch import list_architectures

architectures = list_architectures()
if architectures:
return sorted(set(architectures))
except Exception:
pass

# Fallback: parse decorator registrations from source files.
pattern = re.compile(r"""@register_architecture\(\s*['"]([^'"]+)['"]\s*\)""")
architectures = []
arch_root = PYTC_ROOT / "connectomics" / "models" / "arch"
for py_file in arch_root.rglob("*.py"):
text = py_file.read_text(encoding="utf-8", errors="ignore")
architectures.extend(pattern.findall(text))
return sorted(set(architectures))


@app.get("/pytc/configs")
Expand All @@ -131,17 +186,12 @@ def list_pytc_configs():
def get_pytc_config(path: str):
if not path:
raise HTTPException(status_code=400, detail="Config path is required.")
if ".." in path or path.startswith("/"):
raise HTTPException(status_code=400, detail="Invalid config path.")
requested = (PYTC_CONFIG_ROOT / path).resolve()
if not str(requested).startswith(str(PYTC_CONFIG_ROOT.resolve())):
raise HTTPException(status_code=400, detail="Invalid config path.")
if not requested.is_file():
requested = _resolve_requested_config(path)
if requested is None:
raise HTTPException(status_code=404, detail="Config not found.")
if requested.suffix.lower() != ".yaml":
raise HTTPException(status_code=400, detail="Config must be a YAML file.")
content = requested.read_text(encoding="utf-8", errors="ignore")
return {"path": path, "content": content}
canonical_path = str(requested.relative_to(PYTC_ROOT)).replace("\\", "/")
return {"path": canonical_path, "content": content}


@app.get("/pytc/architectures")
Expand Down
3 changes: 1 addition & 2 deletions server_pytc/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ async def get_tensorboard_url():
async def start_model_inference(req: Request):
req = await req.json()
print("start model inference")
# log_dir = req["log_dir"]
start_inference(req)
return start_inference(req)


@app.post("/stop_model_inference")
Expand Down
Loading
Loading