Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,24 @@ ServerPWD: alphatr1on

You can also visit the Docker Registry UI at `http://localhost:80` to see the local registry where the built images are stored.

Next, init the environment with a user and team:

```bash
alphatrion init # see -h for options to specify username, email and team name
```

You will see the generated user ID and team ID in the console. Use these IDs to initialize the AlphaTrion environment in your code later.

### Run a Simple Experiment

Below is a simple example with two approaches demonstrating how to create an experiment and log performance metrics.

```python
import uuid

import alphatrion as alpha
from alphatrion import experiment, project

# Better to use a fixed UUID for the team and user in real scenarios.
alpha.init(team_id=uuid.uuid4(), user_id=uuid.uuid4())
# Use the user ID and team ID generated from the `alphatrion init` command.
alpha.init(user_id=<user_id>, team_id=<team_id>)

async def your_task():
# Run your code here then log metrics.
Expand Down
94 changes: 92 additions & 2 deletions alphatrion/server/cmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
import httpx
import uvicorn
from dotenv import load_dotenv
from faker import Faker
from fastapi import FastAPI, Request
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles
from rich.console import Console
from rich.text import Text

from alphatrion.server.graphql.runtime import init as graphql_init
from alphatrion.server import runtime

load_dotenv()
console = Console()
Expand Down Expand Up @@ -63,6 +64,30 @@ def main():
)
dashboard.set_defaults(func=start_dashboard)

# init command
init = subparsers.add_parser(
"init", help="Initialize AlphaTrion with a user and team"
)
init.add_argument(
"--username",
type=str,
default=None,
help="Username for the new user (auto-generated if not provided)",
)
init.add_argument(
"--email",
type=str,
default=None,
help="Email for the new user (auto-generated if not provided)",
)
init.add_argument(
"--teamname",
type=str,
default="Default Team",
help="Team name (default: Default Team)",
)
init.set_defaults(func=init_command)

# version command
version = subparsers.add_parser("version", help="Show the version of AlphaTrion")
version.set_defaults(func=lambda args: print(f"AlphaTrion version {__version__}"))
Expand All @@ -74,6 +99,71 @@ def main():
parser.print_help()


def init_command(args):
"""Initialize AlphaTrion with a user and team."""
# Initialize the Server runtime to get access to metadb
runtime.init(init_tables=True)

fake = Faker()

# Generate username if not provided
username = args.username if args.username else fake.name()
email = (
args.email
if args.email
else f"{username.lower().replace(' ', '.')}@inftyai.com"
)
teamname = args.teamname

try:
metadb = runtime.server_runtime().metadb

# Create user
console.print(
Text(f"👤 Creating user: {username} ({email})", style="bold cyan")
)
user_id = metadb.create_user(username=username, email=email)

# Create team
console.print(Text(f"🏢 Creating team: {teamname}", style="bold cyan"))
team_id = metadb.create_team(name=teamname, description=f"Team for {username}")

# Add user to team
metadb.add_user_to_team(user_id=user_id, team_id=team_id)

console.print()
console.print(Text("✅ Initialization successful!", style="bold green"))
console.print()
console.print(Text("📋 Your user ID:", style="bold yellow"))
console.print(Text(f" {user_id}", style="bold cyan"))
console.print(Text(" Your team ID:", style="bold yellow"))
console.print(Text(f" {team_id}", style="bold cyan"))
console.print()
console.print(Text("💡 Use this user ID to launch the dashboard:", style="dim"))
console.print(
Text(f" alphatrion dashboard --userid {user_id}", style="magenta")
)
console.print()
console.print(
Text(
"🚀 Use this user ID and team ID to setup the experiment environment:",
style="dim",
)
)
console.print(Text(" import alphatrion as alpha", style="white"))
console.print(
Text(
f" alpha.init(user_id='{user_id}', team_id='{team_id}')",
style="white",
)
)
console.print()

except Exception as e:
console.print(Text(f"❌ Error during initialization: {e}", style="bold red"))
raise


def run_server(args):
BLUE = "\033[94m"
RESET = "\033[0m"
Expand All @@ -99,7 +189,7 @@ def run_server(args):
style="bold green",
)
console.print(msg)
graphql_init()
runtime.init()
uvicorn.run("alphatrion.server.cmd.app:app", host=args.host, port=args.port)


Expand Down
40 changes: 20 additions & 20 deletions alphatrion/server/graphql/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import strawberry

from alphatrion.server.graphql import runtime
from alphatrion.server import runtime
from alphatrion.storage.sql_models import Status

from .types import (
Expand All @@ -27,7 +27,7 @@
class GraphQLResolvers:
@staticmethod
def list_teams(user_id: strawberry.ID) -> list[Team]:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
teams = metadb.list_user_teams(user_id=user_id)
return [
Team(
Expand All @@ -43,7 +43,7 @@ def list_teams(user_id: strawberry.ID) -> list[Team]:

@staticmethod
def get_team(id: strawberry.ID) -> Team | None:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
team = metadb.get_team(team_id=uuid.UUID(id))
if team:
return Team(
Expand All @@ -58,7 +58,7 @@ def get_team(id: strawberry.ID) -> Team | None:

@staticmethod
def get_user(id: strawberry.ID) -> User | None:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
user = metadb.get_user(user_id=uuid.UUID(id))
if user:
return User(
Expand All @@ -80,7 +80,7 @@ def list_projects(
order_by: str = "created_at",
order_desc: bool = True,
) -> list[Project]:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
projects = metadb.list_projects(
team_id=uuid.UUID(team_id),
page=page,
Expand All @@ -104,7 +104,7 @@ def list_projects(

@staticmethod
def get_project(id: strawberry.ID) -> Project | None:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
proj = metadb.get_project(project_id=uuid.UUID(id))
if proj:
return Project(
Expand All @@ -127,7 +127,7 @@ def list_experiments(
order_by: str = "created_at",
order_desc: bool = True,
) -> list[Experiment]:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
exps = metadb.list_exps_by_project_id(
project_id=uuid.UUID(project_id),
page=page,
Expand Down Expand Up @@ -156,7 +156,7 @@ def list_experiments(

@staticmethod
def get_experiment(id: strawberry.ID) -> Experiment | None:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
exp = metadb.get_experiment(experiment_id=uuid.UUID(id))
if exp:
return Experiment(
Expand Down Expand Up @@ -184,7 +184,7 @@ def list_runs(
order_by: str = "created_at",
order_desc: bool = True,
) -> list[Run]:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
runs = metadb.list_runs_by_exp_id(
exp_id=uuid.UUID(experiment_id),
page=page,
Expand All @@ -208,7 +208,7 @@ def list_runs(

@staticmethod
def get_run(id: strawberry.ID) -> Run | None:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
run = metadb.get_run(run_id=uuid.UUID(id))
if run:
return Run(
Expand All @@ -225,7 +225,7 @@ def get_run(id: strawberry.ID) -> Run | None:

@staticmethod
def list_exp_metrics(experiment_id: strawberry.ID) -> list[Metric]:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
metrics = metadb.list_metrics_by_experiment_id(experiment_id=experiment_id)
return [
Metric(
Expand All @@ -243,17 +243,17 @@ def list_exp_metrics(experiment_id: strawberry.ID) -> list[Metric]:

@staticmethod
def total_projects(team_id: strawberry.ID) -> int:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
return metadb.count_projects(team_id=team_id)

@staticmethod
def total_experiments(team_id: strawberry.ID) -> int:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
return metadb.count_experiments(team_id=team_id)

@staticmethod
def total_runs(team_id: strawberry.ID) -> int:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
return metadb.count_runs(team_id=team_id)

@staticmethod
Expand All @@ -262,7 +262,7 @@ def list_exps_by_timeframe(
start_time: datetime,
end_time: datetime,
) -> list[Experiment]:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
experiments = metadb.list_exps_by_timeframe(
team_id=team_id,
start_time=start_time,
Expand Down Expand Up @@ -292,7 +292,7 @@ def list_exps_by_timeframe(
class GraphQLMutations:
@staticmethod
def create_user(input: CreateUserInput) -> User:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
user_id = metadb.create_user(
uuid=uuid.UUID(input.id) if input.id else None,
username=input.username,
Expand All @@ -316,7 +316,7 @@ def create_user(input: CreateUserInput) -> User:

@staticmethod
def update_user(input: UpdateUserInput) -> User:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
user_id = uuid.UUID(input.id)

user = metadb.update_user(user_id=user_id, meta=input.meta)
Expand All @@ -336,7 +336,7 @@ def update_user(input: UpdateUserInput) -> User:

@staticmethod
def create_team(input: CreateTeamInput) -> Team:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
team_id = metadb.create_team(
uuid=uuid.UUID(input.id) if input.id else None,
name=input.name,
Expand All @@ -358,7 +358,7 @@ def create_team(input: CreateTeamInput) -> Team:

@staticmethod
def add_user_to_team(input: AddUserToTeamInput) -> bool:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
user_id = uuid.UUID(input.user_id)
team_id = uuid.UUID(input.team_id)

Expand All @@ -379,7 +379,7 @@ def add_user_to_team(input: AddUserToTeamInput) -> bool:

@staticmethod
def remove_user_from_team(input: RemoveUserFromTeamInput) -> bool:
metadb = runtime.graphql_runtime().metadb
metadb = runtime.server_runtime().metadb
user_id = uuid.UUID(input.user_id)
team_id = uuid.UUID(input.team_id)

Expand Down
36 changes: 0 additions & 36 deletions alphatrion/server/graphql/runtime.py

This file was deleted.

Loading
Loading