Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,6 @@ seed:
seed-cleanup:
python hack/seed.py cleanup

.PHONY: dashboard
dashboard:
cd dashboard && npm run dev
.PHONY: build-dashboard
build-dashboard:
cd dashboard && npm install && npm run build
31 changes: 21 additions & 10 deletions alphatrion/server/cmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ def main():
help="Backend server URL to proxy requests to (default: http://localhost:8000)",
)
dashboard.add_argument(
"--no-browser", action="store_true", help="Don't automatically open browser"
"--userid",
type=str,
required=True,
help="User ID to scope the dashboard (required)",
)
dashboard.set_defaults(func=start_dashboard)

Expand Down Expand Up @@ -160,18 +163,27 @@ def start_dashboard(args):
console.print(
Text(f"🔗 Proxying backend requests to: {args.backend_url}", style="dim")
)
console.print(Text(f"👤 Dashboard scoped to user: {args.userid}", style="yellow"))
console.print()
console.print(
Text("💡 Note: Make sure the backend server is running:", style="bold yellow")
)
console.print(Text(" alphatrion server --port 8000", style="cyan"))
console.print(Text(" alphatrion server", style="cyan"))
console.print()

app = FastAPI()

# Store user ID in app state
app.state.user_id = args.userid

# Create HTTP client for proxying requests to backend
http_client = httpx.AsyncClient(base_url=args.backend_url, timeout=30.0)

# Endpoint to get current user ID (for frontend)
@app.get("/api/config")
async def get_config():
return {"userId": app.state.user_id}

# Proxy /graphql requests to backend (MUST be before catch-all route)
@app.api_route("/graphql", methods=["GET", "POST"])
async def proxy_graphql(request: Request):
Expand Down Expand Up @@ -251,16 +263,15 @@ async def shutdown_event():

url = f"http://127.0.0.1:{args.port}"

# Open browser after a short delay (unless --no-browser is set)
if not args.no_browser:
console.print(Text(f"🌐 Dashboard URL: {url}", style="bold cyan"))

def open_browser():
time.sleep(1) # Wait for server to start
webbrowser.open(url)
# Open browser after a short delay to ensure server is ready
def open_browser():
time.sleep(1) # Wait for server to start
webbrowser.open(url)

threading.Thread(target=open_browser, daemon=True).start()
else:
console.print(Text(f"🌐 Open your browser at: {url}", style="bold cyan"))
browser_thread = threading.Thread(target=open_browser, daemon=True)
browser_thread.start()

try:
uvicorn.run(app, host="127.0.0.1", port=args.port, log_level="info")
Expand Down
119 changes: 116 additions & 3 deletions alphatrion/server/graphql/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,28 @@
from alphatrion.storage.sql_models import Status

from .types import (
AddUserToTeamInput,
CreateTeamInput,
CreateUserInput,
Experiment,
GraphQLExperimentType,
GraphQLExperimentTypeEnum,
GraphQLStatusEnum,
Metric,
Project,
RemoveUserFromTeamInput,
Run,
Team,
UpdateUserInput,
User,
)


class GraphQLResolvers:
@staticmethod
def list_teams(page: int = 0, page_size: int = 10) -> list[Team]:
def list_teams(user_id: strawberry.ID) -> list[Team]:
metadb = runtime.graphql_runtime().metadb
teams = metadb.list_teams(page=page, page_size=page_size)
teams = metadb.list_user_teams(user_id=user_id)
return [
Team(
id=t.uuid,
Expand Down Expand Up @@ -60,7 +65,7 @@ def get_user(id: strawberry.ID) -> User | None:
id=user.uuid,
username=user.username,
email=user.email,
team_id=user.team_id,
avatar_url=user.avatar_url,
meta=user.meta,
created_at=user.created_at,
updated_at=user.updated_at,
Expand Down Expand Up @@ -282,3 +287,111 @@ def list_exps_by_timeframe(
)
for e in experiments
]


class GraphQLMutations:
@staticmethod
def create_user(input: CreateUserInput) -> User:
metadb = runtime.graphql_runtime().metadb
user_id = metadb.create_user(
username=input.username,
email=input.email,
avatar_url=input.avatar_url,
meta=input.meta,
)
user = metadb.get_user(user_id=user_id)
if user:
return User(
id=user.uuid,
username=user.username,
email=user.email,
avatar_url=user.avatar_url,
meta=user.meta,
created_at=user.created_at,
updated_at=user.updated_at,
)
msg = f"Failed to create user with username {input.username}"
raise RuntimeError(msg)

@staticmethod
def update_user(input: UpdateUserInput) -> User:
metadb = runtime.graphql_runtime().metadb
user_id = uuid.UUID(input.id)

user = metadb.update_user(user_id=user_id, meta=input.meta)
if not user:
msg = f"User with id {input.id} not found"
raise ValueError(msg)

return User(
id=user.uuid,
username=user.username,
email=user.email,
avatar_url=user.avatar_url,
meta=user.meta,
created_at=user.created_at,
updated_at=user.updated_at,
)

@staticmethod
def create_team(input: CreateTeamInput) -> Team:
metadb = runtime.graphql_runtime().metadb
team_id = metadb.create_team(
name=input.name,
description=input.description,
meta=input.meta,
)
team = metadb.get_team(team_id=team_id)
if team:
return Team(
id=team.uuid,
name=team.name,
description=team.description,
meta=team.meta,
created_at=team.created_at,
updated_at=team.updated_at,
)
msg = f"Failed to create team with name {input.name}"
raise RuntimeError(msg)

@staticmethod
def add_user_to_team(input: AddUserToTeamInput) -> bool:
metadb = runtime.graphql_runtime().metadb
user_id = uuid.UUID(input.user_id)
team_id = uuid.UUID(input.team_id)

# Verify team exists
team = metadb.get_team(team_id=team_id)
if not team:
msg = f"Team with id {input.team_id} not found"
raise ValueError(msg)

# Verify user exists
user = metadb.get_user(user_id=user_id)
if not user:
msg = f"User with id {input.user_id} not found"
raise ValueError(msg)

# Add user to team (creates TeamMember entry)
return metadb.add_user_to_team(user_id=user_id, team_id=team_id)

@staticmethod
def remove_user_from_team(input: RemoveUserFromTeamInput) -> bool:
metadb = runtime.graphql_runtime().metadb
user_id = uuid.UUID(input.user_id)
team_id = uuid.UUID(input.team_id)

# Verify team exists
team = metadb.get_team(team_id=team_id)
if not team:
msg = f"Team with id {input.team_id} not found"
raise ValueError(msg)

# Verify user exists
user = metadb.get_user(user_id=user_id)
if not user:
msg = f"User with id {input.user_id} not found"
raise ValueError(msg)

# Remove user from team (deletes TeamMember entry)
return metadb.remove_user_from_team(user_id=user_id, team_id=team_id)
40 changes: 37 additions & 3 deletions alphatrion/server/graphql/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import strawberry

from alphatrion.server.graphql.resolvers import GraphQLResolvers
from alphatrion.server.graphql.types import Experiment, Project, Run, Team, User
from alphatrion.server.graphql.resolvers import GraphQLMutations, GraphQLResolvers
from alphatrion.server.graphql.types import (
AddUserToTeamInput,
CreateTeamInput,
CreateUserInput,
Experiment,
Project,
RemoveUserFromTeamInput,
Run,
Team,
UpdateUserInput,
User,
)


@strawberry.type
Expand Down Expand Up @@ -71,4 +82,27 @@ def runs(
run: Run | None = strawberry.field(resolver=GraphQLResolvers.get_run)


schema = strawberry.Schema(Query)
@strawberry.type
class Mutation:
@strawberry.mutation
def create_user(self, input: CreateUserInput) -> User:
return GraphQLMutations.create_user(input=input)

@strawberry.mutation
def update_user(self, input: UpdateUserInput) -> User:
return GraphQLMutations.update_user(input=input)

@strawberry.mutation
def create_team(self, input: CreateTeamInput) -> Team:
return GraphQLMutations.create_team(input=input)

@strawberry.mutation
def add_user_to_team(self, input: AddUserToTeamInput) -> bool:
return GraphQLMutations.add_user_to_team(input=input)

@strawberry.mutation
def remove_user_from_team(self, input: RemoveUserFromTeamInput) -> bool:
return GraphQLMutations.remove_user_from_team(input=input)


schema = strawberry.Schema(query=Query, mutation=Mutation)
42 changes: 41 additions & 1 deletion alphatrion/server/graphql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,17 @@ class User:
id: strawberry.ID
username: str
email: str
team_id: strawberry.ID
avatar_url: str | None
meta: JSON | None
created_at: datetime
updated_at: datetime

@strawberry.field
def teams(self) -> list["Team"] | None:
from .resolvers import GraphQLResolvers

return GraphQLResolvers.list_teams(user_id=self.id)


@strawberry.type
class Project:
Expand Down Expand Up @@ -134,3 +140,37 @@ class Metric:
experiment_id: strawberry.ID
run_id: strawberry.ID
created_at: datetime


# Input types for mutations
@strawberry.input
class CreateUserInput:
username: str
email: str
avatar_url: str | None = None
meta: JSON | None = None


@strawberry.input
class CreateTeamInput:
name: str
description: str | None = None
meta: JSON | None = None


@strawberry.input
class UpdateUserInput:
id: strawberry.ID
meta: JSON | None = None


@strawberry.input
class AddUserToTeamInput:
user_id: strawberry.ID
team_id: strawberry.ID


@strawberry.input
class RemoveUserFromTeamInput:
user_id: strawberry.ID
team_id: strawberry.ID
3 changes: 2 additions & 1 deletion alphatrion/storage/metastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def create_user(
self,
username: str,
email: str,
team_id: uuid.UUID,
avatar_url: str | None = None,
team_id: uuid.UUID | None = None,
meta: dict | None = None,
) -> uuid.UUID:
raise NotImplementedError("Subclasses must implement this method.")
Expand Down
21 changes: 19 additions & 2 deletions alphatrion/storage/sql_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ class User(Base):
__tablename__ = "users"

uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
username = Column(String, nullable=False, unique=True)
username = Column(String, nullable=False)
email = Column(String, nullable=False, unique=True)
team_id = Column(UUID(as_uuid=True), nullable=False)
avatar_url = Column(String, nullable=True)
meta = Column(
MutableDict.as_mutable(JSON),
nullable=True,
Expand All @@ -74,6 +74,23 @@ class User(Base):
is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted")


class TeamMember(Base):
__tablename__ = "team_members"

uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
team_id = Column(UUID(as_uuid=True), nullable=False)
user_id = Column(UUID(as_uuid=True), nullable=False)

created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)

__table_args__ = (UniqueConstraint("team_id", "user_id", name="unique_team_user"),)


# Define the Project model for SQLAlchemy
class Project(Base):
__tablename__ = "projects"
Expand Down
Loading
Loading