Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion superset/mcp_service/chart/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from superset.mcp_service.common.error_schemas import ChartGenerationError
from superset.mcp_service.system.schemas import (
PaginationInfo,
serialize_user_object,
TagInfo,
UserInfo,
)
Expand Down Expand Up @@ -278,8 +279,9 @@ def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None:
if getattr(chart, "tags", None)
else [],
owners=[
UserInfo.model_validate(owner, from_attributes=True)
info
for owner in getattr(chart, "owners", [])
if (info := serialize_user_object(owner)) is not None
]
if getattr(chart, "owners", None)
else [],
Expand Down
19 changes: 5 additions & 14 deletions superset/mcp_service/dashboard/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
from superset.mcp_service.system.schemas import (
PaginationInfo,
RoleInfo,
serialize_user_object,
TagInfo,
UserInfo,
)
Expand All @@ -109,19 +110,8 @@ def create(cls, error: str, error_type: str) -> "DashboardError":
return cls(error=error, error_type=error_type, timestamp=datetime.now())


def serialize_user_object(user: Any) -> UserInfo | None:
"""Serialize a user object to UserInfo"""
if not user:
return None

return UserInfo(
id=getattr(user, "id", None),
username=getattr(user, "username", None),
first_name=getattr(user, "first_name", None),
last_name=getattr(user, "last_name", None),
email=getattr(user, "email", None),
active=getattr(user, "active", None),
)
# serialize_user_object is imported from system.schemas and re-exported here
# for backward compatibility with dashboard tool modules.


def serialize_tag_object(tag: Any) -> TagInfo | None:
Expand Down Expand Up @@ -502,8 +492,9 @@ def dashboard_serializer(dashboard: "Dashboard") -> DashboardInfo:
changed_on_humanized=dashboard.changed_on_humanized,
chart_count=len(dashboard.slices) if dashboard.slices else 0,
owners=[
UserInfo.model_validate(owner, from_attributes=True)
info
for owner in dashboard.owners
if (info := serialize_user_object(owner)) is not None
]
if dashboard.owners
else [],
Expand Down
4 changes: 3 additions & 1 deletion superset/mcp_service/dataset/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from superset.mcp_service.common.cache_schemas import MetadataCacheControl
from superset.mcp_service.system.schemas import (
PaginationInfo,
serialize_user_object,
TagInfo,
UserInfo,
)
Expand Down Expand Up @@ -338,8 +339,9 @@ def serialize_dataset_object(dataset: Any) -> DatasetInfo | None:
if getattr(dataset, "tags", None)
else [],
owners=[
UserInfo.model_validate(owner, from_attributes=True)
info
for owner in getattr(dataset, "owners", [])
if (info := serialize_user_object(owner)) is not None
]
if getattr(dataset, "owners", None)
else [],
Expand Down
25 changes: 24 additions & 1 deletion superset/mcp_service/system/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Dict, List
from typing import Any, Dict, List

from pydantic import BaseModel, ConfigDict, Field

Expand Down Expand Up @@ -170,6 +170,29 @@ class UserInfo(BaseModel):
)


def serialize_user_object(user: Any) -> UserInfo | None:
"""Serialize a user ORM object to UserInfo, extracting role names as strings."""
if not user:
return None

user_roles: list[str] = []
if (raw_roles := getattr(user, "roles", None)) is not None:
try:
user_roles = [role.name for role in raw_roles if hasattr(role, "name")]
except TypeError:
user_roles = []

return UserInfo(
id=getattr(user, "id", None),
username=getattr(user, "username", None),
first_name=getattr(user, "first_name", None),
last_name=getattr(user, "last_name", None),
email=getattr(user, "email", None),
active=getattr(user, "active", None),
roles=user_roles,
)


class TagInfo(BaseModel):
id: int | None = None
name: str | None = None
Expand Down
167 changes: 167 additions & 0 deletions tests/unit_tests/mcp_service/system/test_serialize_user_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from unittest.mock import MagicMock

import pytest

from superset.mcp_service.system.schemas import serialize_user_object


def test_returns_none_for_none_user() -> None:
assert serialize_user_object(None) is None


def test_returns_none_for_falsy_user() -> None:
assert serialize_user_object(0) is None
assert serialize_user_object("") is None


def test_extracts_basic_fields() -> None:
user = MagicMock()
user.id = 1
user.username = "admin"
user.first_name = "Ad"
user.last_name = "Min"
user.email = "admin@example.com"
user.active = True
user.roles = []

result = serialize_user_object(user)

assert result is not None
assert result.id == 1
assert result.username == "admin"
assert result.first_name == "Ad"
assert result.last_name == "Min"
assert result.email == "admin@example.com"
assert result.active is True
assert result.roles == []


def test_extracts_role_names_from_orm_objects() -> None:
"""The original bug: SQLAlchemy Role objects must be converted to strings."""
role_admin = MagicMock()
role_admin.name = "Admin"
role_alpha = MagicMock()
role_alpha.name = "Alpha"

user = MagicMock()
user.id = 1
user.username = "admin"
user.first_name = "Ad"
user.last_name = "Min"
user.email = "admin@example.com"
user.active = True
user.roles = [role_admin, role_alpha]

result = serialize_user_object(user)

assert result is not None
assert result.roles == ["Admin", "Alpha"]


def test_handles_user_without_roles_attribute() -> None:
user = MagicMock(
spec=["id", "username", "first_name", "last_name", "email", "active"]
)
user.id = 1
user.username = "noroles"
user.first_name = "No"
user.last_name = "Roles"
user.email = "no@roles.com"
user.active = True

result = serialize_user_object(user)

assert result is not None
assert result.roles == []


def _make_user(**overrides: object) -> MagicMock:
"""Helper to create a fully-populated mock user."""
defaults = {
"id": 1,
"username": "testuser",
"first_name": "Test",
"last_name": "User",
"email": "test@example.com",
"active": True,
"roles": [],
}
defaults.update(overrides)
user = MagicMock()
for k, v in defaults.items():
setattr(user, k, v)
return user


def test_handles_non_iterable_roles() -> None:
user = _make_user(roles=42)

result = serialize_user_object(user)

assert result is not None
assert result.roles == []


def test_skips_roles_without_name_attribute() -> None:
role_good = MagicMock()
role_good.name = "Admin"

user = _make_user(roles=[role_good])

result = serialize_user_object(user)

assert result is not None
assert result.roles == ["Admin"]


def test_handles_none_roles() -> None:
user = _make_user(roles=None)

result = serialize_user_object(user)

assert result is not None
assert result.roles == []


@pytest.mark.parametrize(
"missing_field",
["id", "username", "first_name", "last_name", "email", "active"],
)
def test_missing_fields_default_to_none(missing_field: str) -> None:
"""Fields not present on the user object should default to None."""
attrs = {
"id": 1,
"username": "test",
"first_name": "T",
"last_name": "U",
"email": "t@e.com",
"active": True,
"roles": [],
}
# Remove the field under test
del attrs[missing_field]
user = MagicMock(spec=list(attrs.keys()))
for k, v in attrs.items():
setattr(user, k, v)

result = serialize_user_object(user)

assert result is not None
assert getattr(result, missing_field) is None
Loading