diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 28fdf1b0d88b..c5a2098442f9 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -45,6 +45,7 @@ from superset.mcp_service.common.error_schemas import ChartGenerationError from superset.mcp_service.system.schemas import ( PaginationInfo, + serialize_user_object, TagInfo, UserInfo, ) @@ -278,8 +279,9 @@ def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None: if getattr(chart, "tags", None) else [], owners=[ - UserInfo.model_validate(owner, from_attributes=True) + info for owner in getattr(chart, "owners", []) + if (info := serialize_user_object(owner)) is not None ] if getattr(chart, "owners", None) else [], diff --git a/superset/mcp_service/dashboard/schemas.py b/superset/mcp_service/dashboard/schemas.py index 7b002677c698..e09b3871ce6c 100644 --- a/superset/mcp_service/dashboard/schemas.py +++ b/superset/mcp_service/dashboard/schemas.py @@ -87,6 +87,7 @@ from superset.mcp_service.system.schemas import ( PaginationInfo, RoleInfo, + serialize_user_object, TagInfo, UserInfo, ) @@ -109,19 +110,8 @@ def create(cls, error: str, error_type: str) -> "DashboardError": return cls(error=error, error_type=error_type, timestamp=datetime.now()) -def serialize_user_object(user: Any) -> UserInfo | None: - """Serialize a user object to UserInfo""" - if not user: - return None - - return UserInfo( - id=getattr(user, "id", None), - username=getattr(user, "username", None), - first_name=getattr(user, "first_name", None), - last_name=getattr(user, "last_name", None), - email=getattr(user, "email", None), - active=getattr(user, "active", None), - ) +# serialize_user_object is imported from system.schemas and re-exported here +# for backward compatibility with dashboard tool modules. def serialize_tag_object(tag: Any) -> TagInfo | None: @@ -502,8 +492,9 @@ def dashboard_serializer(dashboard: "Dashboard") -> DashboardInfo: changed_on_humanized=dashboard.changed_on_humanized, chart_count=len(dashboard.slices) if dashboard.slices else 0, owners=[ - UserInfo.model_validate(owner, from_attributes=True) + info for owner in dashboard.owners + if (info := serialize_user_object(owner)) is not None ] if dashboard.owners else [], diff --git a/superset/mcp_service/dataset/schemas.py b/superset/mcp_service/dataset/schemas.py index b0dad96b5926..14dd329e8f7d 100644 --- a/superset/mcp_service/dataset/schemas.py +++ b/superset/mcp_service/dataset/schemas.py @@ -37,6 +37,7 @@ from superset.mcp_service.common.cache_schemas import MetadataCacheControl from superset.mcp_service.system.schemas import ( PaginationInfo, + serialize_user_object, TagInfo, UserInfo, ) @@ -338,8 +339,9 @@ def serialize_dataset_object(dataset: Any) -> DatasetInfo | None: if getattr(dataset, "tags", None) else [], owners=[ - UserInfo.model_validate(owner, from_attributes=True) + info for owner in getattr(dataset, "owners", []) + if (info := serialize_user_object(owner)) is not None ] if getattr(dataset, "owners", None) else [], diff --git a/superset/mcp_service/system/schemas.py b/superset/mcp_service/system/schemas.py index f1667471693f..9810cc4d3b3e 100644 --- a/superset/mcp_service/system/schemas.py +++ b/superset/mcp_service/system/schemas.py @@ -25,7 +25,7 @@ from __future__ import annotations from datetime import datetime -from typing import Dict, List +from typing import Any, Dict, List from pydantic import BaseModel, ConfigDict, Field @@ -170,6 +170,29 @@ class UserInfo(BaseModel): ) +def serialize_user_object(user: Any) -> UserInfo | None: + """Serialize a user ORM object to UserInfo, extracting role names as strings.""" + if not user: + return None + + user_roles: list[str] = [] + if (raw_roles := getattr(user, "roles", None)) is not None: + try: + user_roles = [role.name for role in raw_roles if hasattr(role, "name")] + except TypeError: + user_roles = [] + + return UserInfo( + id=getattr(user, "id", None), + username=getattr(user, "username", None), + first_name=getattr(user, "first_name", None), + last_name=getattr(user, "last_name", None), + email=getattr(user, "email", None), + active=getattr(user, "active", None), + roles=user_roles, + ) + + class TagInfo(BaseModel): id: int | None = None name: str | None = None diff --git a/tests/unit_tests/mcp_service/system/test_serialize_user_object.py b/tests/unit_tests/mcp_service/system/test_serialize_user_object.py new file mode 100644 index 000000000000..adca76fdfde9 --- /dev/null +++ b/tests/unit_tests/mcp_service/system/test_serialize_user_object.py @@ -0,0 +1,167 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest.mock import MagicMock + +import pytest + +from superset.mcp_service.system.schemas import serialize_user_object + + +def test_returns_none_for_none_user() -> None: + assert serialize_user_object(None) is None + + +def test_returns_none_for_falsy_user() -> None: + assert serialize_user_object(0) is None + assert serialize_user_object("") is None + + +def test_extracts_basic_fields() -> None: + user = MagicMock() + user.id = 1 + user.username = "admin" + user.first_name = "Ad" + user.last_name = "Min" + user.email = "admin@example.com" + user.active = True + user.roles = [] + + result = serialize_user_object(user) + + assert result is not None + assert result.id == 1 + assert result.username == "admin" + assert result.first_name == "Ad" + assert result.last_name == "Min" + assert result.email == "admin@example.com" + assert result.active is True + assert result.roles == [] + + +def test_extracts_role_names_from_orm_objects() -> None: + """The original bug: SQLAlchemy Role objects must be converted to strings.""" + role_admin = MagicMock() + role_admin.name = "Admin" + role_alpha = MagicMock() + role_alpha.name = "Alpha" + + user = MagicMock() + user.id = 1 + user.username = "admin" + user.first_name = "Ad" + user.last_name = "Min" + user.email = "admin@example.com" + user.active = True + user.roles = [role_admin, role_alpha] + + result = serialize_user_object(user) + + assert result is not None + assert result.roles == ["Admin", "Alpha"] + + +def test_handles_user_without_roles_attribute() -> None: + user = MagicMock( + spec=["id", "username", "first_name", "last_name", "email", "active"] + ) + user.id = 1 + user.username = "noroles" + user.first_name = "No" + user.last_name = "Roles" + user.email = "no@roles.com" + user.active = True + + result = serialize_user_object(user) + + assert result is not None + assert result.roles == [] + + +def _make_user(**overrides: object) -> MagicMock: + """Helper to create a fully-populated mock user.""" + defaults = { + "id": 1, + "username": "testuser", + "first_name": "Test", + "last_name": "User", + "email": "test@example.com", + "active": True, + "roles": [], + } + defaults.update(overrides) + user = MagicMock() + for k, v in defaults.items(): + setattr(user, k, v) + return user + + +def test_handles_non_iterable_roles() -> None: + user = _make_user(roles=42) + + result = serialize_user_object(user) + + assert result is not None + assert result.roles == [] + + +def test_skips_roles_without_name_attribute() -> None: + role_good = MagicMock() + role_good.name = "Admin" + + user = _make_user(roles=[role_good]) + + result = serialize_user_object(user) + + assert result is not None + assert result.roles == ["Admin"] + + +def test_handles_none_roles() -> None: + user = _make_user(roles=None) + + result = serialize_user_object(user) + + assert result is not None + assert result.roles == [] + + +@pytest.mark.parametrize( + "missing_field", + ["id", "username", "first_name", "last_name", "email", "active"], +) +def test_missing_fields_default_to_none(missing_field: str) -> None: + """Fields not present on the user object should default to None.""" + attrs = { + "id": 1, + "username": "test", + "first_name": "T", + "last_name": "U", + "email": "t@e.com", + "active": True, + "roles": [], + } + # Remove the field under test + del attrs[missing_field] + user = MagicMock(spec=list(attrs.keys())) + for k, v in attrs.items(): + setattr(user, k, v) + + result = serialize_user_object(user) + + assert result is not None + assert getattr(result, missing_field) is None