diff --git a/requirements.txt b/requirements.txt index 50aaa5fd..9108a15f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -63,4 +63,3 @@ wslink==1.12.4 yarl>=1 # via aiohttp -opengeodeweb-microservice==1.*,>=1.1.3 diff --git a/src/opengeodeweb_viewer/app.py b/src/opengeodeweb_viewer/app.py index ca61d2bf..bd088801 100644 --- a/src/opengeodeweb_viewer/app.py +++ b/src/opengeodeweb_viewer/app.py @@ -108,8 +108,7 @@ class _Server(VtkTypingMixin, ServerProtocol): @staticmethod def add_arguments(parser: argparse.ArgumentParser) -> None: parser.add_argument( - "--data_folder_path", - default=os.environ.get("DATA_FOLDER_PATH"), + "--project_folder_path", help="Path to the folder where data is stored", ) @@ -194,24 +193,40 @@ def initialize(self) -> None: def run_server(Server: type[ServerProtocol] = _Server) -> None: - PYTHON_ENV = os.environ.get("PYTHON_ENV", default="prod").strip().lower() - if PYTHON_ENV == "prod": - prod_config() - elif PYTHON_ENV == "dev": - dev_config() - parser = argparse.ArgumentParser(description="Vtk server") server.add_arguments(parser) - + parser.set_defaults(port=None, host=None) Server.add_arguments(parser) args = parser.parse_args() - if not "host" in args: - args.host = os.environ["DEFAULT_HOST"] - if not "port" in args or args.port == 8080: - args.port = os.environ.get("DEFAULT_PORT") - if "data_folder_path" in args and args.data_folder_path: - os.environ["DATA_FOLDER_PATH"] = args.data_folder_path + if args.project_folder_path is None: + raise ValueError("project_folder_path must be provided") + else: + args.project_folder_path = os.path.abspath(args.project_folder_path) + + PYTHON_ENV = os.environ.get("PYTHON_ENV", "prod").strip().lower() + + app_config: Config + if PYTHON_ENV == "prod": + app_config = ProdConfig(args.project_folder_path) + elif PYTHON_ENV == "dev": + app_config = DevConfig(args.project_folder_path) + elif PYTHON_ENV == "test": + app_config = TestConfig(args.project_folder_path) + else: + raise ValueError(f"Unknown PYTHON_ENV: {PYTHON_ENV!r}") + + if args.host is not None: + app_config.HOST = str(args.host) + else: + args.host = app_config.HOST + + if args.port is not None: + app_config.PORT = str(args.port) + else: + args.port = app_config.PORT + + app_config.sync_env() db_full_path = os.path.join(os.environ["DATA_FOLDER_PATH"], "project.db") connection.init_database(db_full_path, create_tables=False) @@ -219,7 +234,6 @@ def run_server(Server: type[ServerProtocol] = _Server) -> None: print(f"{args=}", flush=True) Server.configure(args) - server.start_webserver(options=args, protocol=Server) diff --git a/src/opengeodeweb_viewer/config.py b/src/opengeodeweb_viewer/config.py index 876fc697..a48fba18 100644 --- a/src/opengeodeweb_viewer/config.py +++ b/src/opengeodeweb_viewer/config.py @@ -1,29 +1,44 @@ import os from shutil import copyfile, copytree -from sys import platform -def default_config() -> None: - os.environ["DEFAULT_HOST"] = "localhost" - os.environ["DEFAULT_PORT"] = "1234" +class Config: + HOST = "localhost" + PORT = "1234" + DATABASE_FILENAME = "project.db" + def __init__(self, project_folder_path: str): + self.PROJECT_FOLDER_PATH = project_folder_path + self.DATA_FOLDER_PATH = os.path.join(project_folder_path, "data") + self.sync_env() -def prod_config() -> None: - default_config() - os.environ["DATA_FOLDER_PATH"] = "/data/" + def sync_env(self) -> None: + os.environ["PROJECT_FOLDER_PATH"] = self.PROJECT_FOLDER_PATH + os.environ["DATA_FOLDER_PATH"] = self.DATA_FOLDER_PATH + os.environ["HOST"] = self.HOST + os.environ["PORT"] = self.PORT + os.environ["DATABASE_FILENAME"] = self.DATABASE_FILENAME -def dev_config() -> None: - default_config() - if platform == "linux": - os.environ["DATA_FOLDER_PATH"] = "/temp/OpenGeodeWeb_Data/" - elif platform == "win32": - os.environ["DATA_FOLDER_PATH"] = os.path.join( - "C:/Users", os.getlogin(), "OpenGeodeWeb_Data" - ) - data_folder_path = os.environ.get("DATA_FOLDER_PATH") - if data_folder_path and not os.path.exists(data_folder_path): - os.mkdir(data_folder_path) +class ProdConfig(Config): + def __init__(self, project_folder_path: str) -> None: + super().__init__(project_folder_path) + + +class DevConfig(Config): + def __init__(self, project_folder_path: str) -> None: + super().__init__(project_folder_path) + os.makedirs(self.DATA_FOLDER_PATH, exist_ok=True) + + +class TestConfig(Config): + def __init__(self, project_folder_path: str) -> None: + print("Received ", project_folder_path, flush=True) + super().__init__(project_folder_path) + os.makedirs(self.DATA_FOLDER_PATH, exist_ok=True) + db_file = os.path.join(self.DATA_FOLDER_PATH, self.DATABASE_FILENAME) + if not os.path.exists(db_file): + open(db_file, "a").close() def _copy_test_assets( @@ -47,18 +62,3 @@ def _copy_test_assets( copyfile(src, os.path.join(tmp_data_root, test_id, file)) copyfile(src, os.path.join(structure_directory, file)) copyfile(src, os.path.join(uploads_directory, file)) - - -def test_config() -> None: - default_config() - if "DATA_FOLDER_PATH" not in os.environ: - data_path = os.path.join(os.path.dirname(__file__), "..", "..", "tests", "data") - os.environ["DATA_FOLDER_PATH"] = os.path.abspath(data_path) - - data_path = os.environ["DATA_FOLDER_PATH"] - if not os.path.exists(data_path): - os.makedirs(data_path, exist_ok=True) - - db_file = os.path.join(data_path, "project.db") - if not os.path.exists(db_file): - open(db_file, "a").close() diff --git a/tests/conftest.py b/tests/conftest.py index 1b39a6fb..6cc8cf02 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ import shutil import xml.etree.ElementTree as ET from typing import Callable, Generator, Any -from opengeodeweb_viewer import config +from opengeodeweb_viewer.config import TestConfig from opengeodeweb_microservice.database.connection import get_session, init_database from opengeodeweb_microservice.database.data import Data from opengeodeweb_viewer.rpc.viewer.viewer_protocols import VtkViewerView @@ -18,9 +18,9 @@ class ServerMonitor: - def __init__(self, log: str) -> None: + def __init__(self, log: str, port: str = "1234") -> None: self.log = log - self.ws = create_connection("ws://localhost:1234/ws") + self.ws = create_connection(f"ws://localhost:{port}/ws") self.images_dir_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "data", "images") ) @@ -158,7 +158,7 @@ class FixtureHelper: def __init__(self, root_path: Path) -> None: self.root_path = Path(root_path) - def get_xprocess_args(self) -> tuple[str, type, type]: + def get_xprocess_args(self, project_folder_path: str) -> tuple[str, type, type]: class Starter(ProcessStarter): # type: ignore terminate_on_interrupt = True pattern = "wslink: Starting factory" @@ -167,6 +167,8 @@ class Starter(ProcessStarter): # type: ignore # command to start process args = [ "opengeodeweb-viewer", + "--project_folder_path", + project_folder_path, ] return "app", Starter, ServerMonitor @@ -178,7 +180,8 @@ class Starter(ProcessStarter): # type: ignore @pytest.fixture def server(xprocess: XProcess) -> Generator[ServerMonitor, None, None]: - name, Starter, Monitor = HELPER.get_xprocess_args() + project_folder_path = str(Path(__file__).parent.absolute()) + name, Starter, Monitor = HELPER.get_xprocess_args(project_folder_path) os.environ["PYTHON_ENV"] = "test" _, log = xprocess.ensure(name, Starter) monitor = Monitor(log) @@ -193,16 +196,14 @@ def server(xprocess: XProcess) -> Generator[ServerMonitor, None, None]: @pytest.fixture(scope="session", autouse=True) def configure_test_environment() -> Generator[None, None, None]: - project_root = Path(__file__).parent.absolute() - os.environ["DATA_FOLDER_PATH"] = str(project_root / "data") - - config.test_config() - db_path = Path(os.environ["DATA_FOLDER_PATH"]) / "project.db" + project_folder_path = str(Path(__file__).parent.absolute()) + app_config = TestConfig(project_folder_path) + db_path = Path(app_config.DATA_FOLDER_PATH) / "project.db" init_database(db_path=str(db_path)) os.environ["TEST_DB_PATH"] = str(db_path) yield - tmp_data_path = os.environ.get("DATA_FOLDER_PATH") + tmp_data_path = app_config.DATA_FOLDER_PATH if tmp_data_path and "ogw_test_data_" in tmp_data_path: shutil.rmtree(tmp_data_path, ignore_errors=True) print(f"Cleaned up test data folder: {tmp_data_path}", flush=True)