diff --git a/cloud_pipelines_backend/api_router.py b/cloud_pipelines_backend/api_router.py index ce05ecb..8dac120 100644 --- a/cloud_pipelines_backend/api_router.py +++ b/cloud_pipelines_backend/api_router.py @@ -481,6 +481,46 @@ def get_current_user( ) ) + # region UserPipelines routes + user_pipelines_service = api_server_sql.UserPipelineApiService() + + # The routes pass `file_path` as query parameter, not path parameter since file_path may contain slashes and ASGI/FastAPI has buggy handling of them. + # We're not using `{file_path:path}` (although that's pretty tempting) since this will essentially block all future API expansion under `/api/users/me/pipelines/`. + # For example adding `/api/users/me/pipelines/copy` method would be impossible since it would clash with pipeline name. + # See https://github.com/fastapi/fastapi/issues/791 https://github.com/Kludex/starlette/issues/826 + + router.get( + "/api/users/me/pipelines/all", tags=["user_pipelines"], **default_config + )( + inject_session_dependency( + inject_user_name( + user_pipelines_service.list_pipelines, parameter_name="user_id" + ) + ) + ) + router.get("/api/users/me/pipelines", tags=["user_pipelines"], **default_config)( + inject_session_dependency( + inject_user_name( + user_pipelines_service.get_pipeline, parameter_name="user_id" + ) + ) + ) + router.put("/api/users/me/pipelines", tags=["user_pipelines"], **default_config)( + inject_session_dependency( + inject_user_name( + user_pipelines_service.set_pipeline, parameter_name="user_id" + ) + ), + ) + router.delete("/api/users/me/pipelines", tags=["user_pipelines"], **default_config)( + inject_session_dependency( + inject_user_name( + user_pipelines_service.delete_pipeline, parameter_name="user_id" + ) + ) + ) + # endregion + # region UserSettings routes user_settings_service = api_server_sql.UserSettingsApiService() diff --git a/cloud_pipelines_backend/api_server_sql.py b/cloud_pipelines_backend/api_server_sql.py index 086ef1f..03dea6f 100644 --- a/cloud_pipelines_backend/api_server_sql.py +++ b/cloud_pipelines_backend/api_server_sql.py @@ -1210,6 +1210,166 @@ def list_secrets( ) +# region: UserSettingsService +# /api/user/me/pipelines + + +@dataclasses.dataclass(kw_only=True) +class UserPipelineResponse: + file_path: str + pipeline_name: str | None = None + created_at: datetime.datetime + modified_at: datetime.datetime + root_pipeline_task: structures.TaskSpec + pipeline_run_annotations: dict[str, Any] | None = None + + @classmethod + def from_db(cls, pipeline_row: bts.UserPipeline) -> "UserPipelineResponse": + return UserPipelineResponse( + file_path=pipeline_row.file_path, + pipeline_name=(pipeline_row.extra_data or {}).get("pipeline_name"), + created_at=pipeline_row.created_at, + modified_at=pipeline_row.modified_at, + root_pipeline_task=structures.TaskSpec.from_json_dict( + pipeline_row.root_pipeline_task + ), + pipeline_run_annotations=pipeline_row.pipeline_run_annotations, + ) + + +@dataclasses.dataclass(kw_only=True) +class UserPipelineShortResponse: + file_path: str + pipeline_name: str | None = None + created_at: datetime.datetime + modified_at: datetime.datetime + + +@dataclasses.dataclass(kw_only=True) +class ListUserPipelinesResponse: + pipelines: list[UserPipelineShortResponse] + + +class UserPipelineApiService: + + def get_pipeline( + self, + *, + session: orm.Session, + user_id: str, + file_path: str, + ) -> UserPipelineResponse: + pipeline_row = session.scalar( + sql.select(bts.UserPipeline).where( + bts.UserPipeline.user_id == user_id, + bts.UserPipeline.file_path == file_path, + ) + ) + if pipeline_row is None: + raise errors.ItemNotFoundError( + f"Pipeline with file path {file_path} not found." + ) + return UserPipelineResponse.from_db(pipeline_row) + + def list_pipelines( + self, + *, + session: orm.Session, + user_id: str, + ) -> ListUserPipelinesResponse: + pipelines = [ + UserPipelineShortResponse( + file_path=file_path, + created_at=created_at, + modified_at=modified_at, + pipeline_name=(extra_data or {}).get("pipeline_name"), + ) + for file_path, created_at, modified_at, extra_data in session.execute( + sql.select( + bts.UserPipeline.file_path, + bts.UserPipeline.created_at, + bts.UserPipeline.modified_at, + bts.UserPipeline.extra_data, + ) + .where(bts.UserPipeline.user_id == user_id) + .order_by(bts.UserPipeline.modified_at.desc()) + ).tuples() + ] + return ListUserPipelinesResponse(pipelines=pipelines) + + def set_pipeline( + self, + *, + session: orm.Session, + user_id: str, + file_path: str, + root_pipeline_task: structures.TaskSpec, + pipeline_run_annotations: dict[str, str] | None = None, + ) -> None: + file_path = file_path.strip() + if not (0 < len(file_path) <= bts.UserPipeline.MAX_FILE_PATH_LENGTH): + raise ApiServiceError( + f"Pipeline file path must be between 1 and {bts.UserPipeline.MAX_FILE_PATH_LENGTH} characters." + ) + # Note: It's OK if the pipeline is not fully valid (e.g. required inputs without arguments) + pipeline_row = session.scalar( + sql.select(bts.UserPipeline).where( + bts.UserPipeline.user_id == user_id, + bts.UserPipeline.file_path == file_path, + ) + ) + current_time = _get_current_time() + if pipeline_row is None: + pipeline_row = bts.UserPipeline( + user_id=user_id, + file_path=file_path, + created_at=current_time, + modified_at=current_time, + root_pipeline_task=root_pipeline_task.to_json_dict(), + pipeline_run_annotations=pipeline_run_annotations, + ) + session.add(pipeline_row) + else: + pipeline_row.modified_at = current_time + pipeline_row.root_pipeline_task = root_pipeline_task.to_json_dict() + pipeline_row.pipeline_run_annotations = pipeline_run_annotations + # Storing pipeline name. Storing it in extra_data instead of a column to avoid issues with long pipeline names + # TODO: Hydrate pipeline from text if needed. + pipeline_name = ( + root_pipeline_task.component_ref.spec.name + if root_pipeline_task.component_ref.spec + else None + ) + if pipeline_name: + if not pipeline_row.extra_data: + pipeline_row.extra_data = {} + pipeline_row.extra_data["pipeline_name"] = pipeline_name + session.commit() + + def delete_pipeline( + self, + *, + session: orm.Session, + user_id: str, + file_path: str, + ) -> None: + pipeline_row = session.scalar( + sql.select(bts.UserPipeline).where( + bts.UserPipeline.user_id == user_id, + bts.UserPipeline.file_path == file_path, + ) + ) + if pipeline_row is None: + raise errors.ItemNotFoundError( + f"Pipeline with file path {file_path} not found." + ) + session.delete(pipeline_row) + session.commit() + + +# endregion + + # region: User Settings API Service # /api/user/me/settings diff --git a/cloud_pipelines_backend/backend_types_sql.py b/cloud_pipelines_backend/backend_types_sql.py index b984c6c..b7d77ef 100644 --- a/cloud_pipelines_backend/backend_types_sql.py +++ b/cloud_pipelines_backend/backend_types_sql.py @@ -531,6 +531,64 @@ class Secret(_TableBase): extra_data: orm.Mapped[dict[str, Any] | None] = orm.mapped_column(default=None) +class UserPipeline(_TableBase): + __tablename__ = "user_pipeline" + + # What should be the maximum file path length we support? + # VARCHAR length cannot be more than ~16 * 1024 (or even less) in some databases like MySQL. + # See for example: https://dev.mysql.com/doc/refman/8.4/en/column-count-limit.html (end of the document) + # Maybe, set it to 1024 for now. + # No. On MySQL, the max length of an indexed string column must be at least <= 768. + # Otherwise we get an error in MySQL: "Specified key was too long; max key length is 3072 bytes" when creating the main index. + # See https://github.com/TangleML/tangle/issues/173 + # And since the index also includes user_id (which is a string or max size 255), we have even less space: 512 bytes. + # So, let's just limit the length to 255 for now. + MAX_FILE_PATH_LENGTH = 255 + + # What should be the primary key? + # * (user_id, file_path)? + # * Surrogate ID? + # * User-provided ID? + # Value of `file_path` may be changed by the user (in the future) and changing IDs is discouraged, so this leads us to use surrogate primary key. + # Should we use generate_unique_id or normal auto-increment integer? + # Leaning towards using generate_unique_id here too. + id: orm.Mapped[str] = orm.mapped_column( + primary_key=True, init=False, insert_default=generate_unique_id + ) + user_id: orm.Mapped[str] = orm.mapped_column(index=True) + # Which SQL type to use for file paths? + # The TEXT type in MySQL is stored off-row and creates some issues, especially when it's part of an index + # See https://dev.mysql.com/doc/refman/8.4/en/blob.html + # file_path: orm.Mapped[str] = orm.mapped_column(type_=sql.Text()) + file_path: orm.Mapped[str] = orm.mapped_column( + type_=sql.String(MAX_FILE_PATH_LENGTH) + ) + created_at: orm.Mapped[datetime.datetime] = orm.mapped_column() + modified_at: orm.Mapped[datetime.datetime] = orm.mapped_column() + + # What exactly do we want to store? + # ! Pipeline is usually a ComponentSpec, but we want to save more. + # First of all, we want to save pipeline arguments. So we need TaskSpec. + # But the user might also want to save pipeline run annotations. + root_pipeline_task: orm.Mapped[dict[str, Any]] = orm.mapped_column() + pipeline_run_annotations: orm.Mapped[dict[str, Any] | None] = orm.mapped_column( + default=None + ) + extra_data: orm.Mapped[dict[str, Any] | None] = orm.mapped_column(default=None) + + __table_args__ = ( + sql.Index( + "ix_user_pipeline_user_id_file_path_unique", user_id, file_path, unique=True + ), + sql.Index( + "ix_user_pipeline_user_id_created_at_desc", user_id, created_at.desc() + ), + sql.Index( + "ix_user_pipeline_user_id_modified_at_desc", user_id, modified_at.desc() + ), + ) + + class UserSettings(_TableBase): __tablename__ = "user_settings"