Skip to content

Commit 4e41a1a

Browse files
authored
support alphatrion init (#152)
* support alphatrion init Signed-off-by: kerthcet <kerthcet@gmail.com> * fix lint Signed-off-by: kerthcet <kerthcet@gmail.com> * fix import error Signed-off-by: kerthcet <kerthcet@gmail.com> * fix import error Signed-off-by: kerthcet <kerthcet@gmail.com> * fix test Signed-off-by: kerthcet <kerthcet@gmail.com> * fix test Signed-off-by: kerthcet <kerthcet@gmail.com> * fix test Signed-off-by: kerthcet <kerthcet@gmail.com> --------- Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 3950e4f commit 4e41a1a

7 files changed

Lines changed: 203 additions & 106 deletions

File tree

README.md

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,24 @@ ServerPWD: alphatr1on
5656

5757
You can also visit the Docker Registry UI at `http://localhost:80` to see the local registry where the built images are stored.
5858

59+
Next, init the environment with a user and team:
60+
61+
```bash
62+
alphatrion init # see -h for options to specify username, email and team name
63+
```
64+
65+
You will see the generated user ID and team ID in the console. Use these IDs to initialize the AlphaTrion environment in your code later.
5966

6067
### Run a Simple Experiment
6168

6269
Below is a simple example with two approaches demonstrating how to create an experiment and log performance metrics.
6370

6471
```python
65-
import uuid
66-
6772
import alphatrion as alpha
6873
from alphatrion import experiment, project
6974

70-
# Better to use a fixed UUID for the team and user in real scenarios.
71-
alpha.init(team_id=uuid.uuid4(), user_id=uuid.uuid4())
75+
# Use the user ID and team ID generated from the `alphatrion init` command.
76+
alpha.init(user_id=<user_id>, team_id=<team_id>)
7277

7378
async def your_task():
7479
# Run your code here then log metrics.

alphatrion/server/cmd/main.py

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
import httpx
1111
import uvicorn
1212
from dotenv import load_dotenv
13+
from faker import Faker
1314
from fastapi import FastAPI, Request
1415
from fastapi.responses import FileResponse, Response
1516
from fastapi.staticfiles import StaticFiles
1617
from rich.console import Console
1718
from rich.text import Text
1819

19-
from alphatrion.server.graphql.runtime import init as graphql_init
20+
from alphatrion.server import runtime
2021

2122
load_dotenv()
2223
console = Console()
@@ -63,6 +64,30 @@ def main():
6364
)
6465
dashboard.set_defaults(func=start_dashboard)
6566

67+
# init command
68+
init = subparsers.add_parser(
69+
"init", help="Initialize AlphaTrion with a user and team"
70+
)
71+
init.add_argument(
72+
"--username",
73+
type=str,
74+
default=None,
75+
help="Username for the new user (auto-generated if not provided)",
76+
)
77+
init.add_argument(
78+
"--email",
79+
type=str,
80+
default=None,
81+
help="Email for the new user (auto-generated if not provided)",
82+
)
83+
init.add_argument(
84+
"--teamname",
85+
type=str,
86+
default="Default Team",
87+
help="Team name (default: Default Team)",
88+
)
89+
init.set_defaults(func=init_command)
90+
6691
# version command
6792
version = subparsers.add_parser("version", help="Show the version of AlphaTrion")
6893
version.set_defaults(func=lambda args: print(f"AlphaTrion version {__version__}"))
@@ -74,6 +99,71 @@ def main():
7499
parser.print_help()
75100

76101

102+
def init_command(args):
103+
"""Initialize AlphaTrion with a user and team."""
104+
# Initialize the Server runtime to get access to metadb
105+
runtime.init(init_tables=True)
106+
107+
fake = Faker()
108+
109+
# Generate username if not provided
110+
username = args.username if args.username else fake.name()
111+
email = (
112+
args.email
113+
if args.email
114+
else f"{username.lower().replace(' ', '.')}@inftyai.com"
115+
)
116+
teamname = args.teamname
117+
118+
try:
119+
metadb = runtime.server_runtime().metadb
120+
121+
# Create user
122+
console.print(
123+
Text(f"👤 Creating user: {username} ({email})", style="bold cyan")
124+
)
125+
user_id = metadb.create_user(username=username, email=email)
126+
127+
# Create team
128+
console.print(Text(f"🏢 Creating team: {teamname}", style="bold cyan"))
129+
team_id = metadb.create_team(name=teamname, description=f"Team for {username}")
130+
131+
# Add user to team
132+
metadb.add_user_to_team(user_id=user_id, team_id=team_id)
133+
134+
console.print()
135+
console.print(Text("✅ Initialization successful!", style="bold green"))
136+
console.print()
137+
console.print(Text("📋 Your user ID:", style="bold yellow"))
138+
console.print(Text(f" {user_id}", style="bold cyan"))
139+
console.print(Text(" Your team ID:", style="bold yellow"))
140+
console.print(Text(f" {team_id}", style="bold cyan"))
141+
console.print()
142+
console.print(Text("💡 Use this user ID to launch the dashboard:", style="dim"))
143+
console.print(
144+
Text(f" alphatrion dashboard --userid {user_id}", style="magenta")
145+
)
146+
console.print()
147+
console.print(
148+
Text(
149+
"🚀 Use this user ID and team ID to setup the experiment environment:",
150+
style="dim",
151+
)
152+
)
153+
console.print(Text(" import alphatrion as alpha", style="white"))
154+
console.print(
155+
Text(
156+
f" alpha.init(user_id='{user_id}', team_id='{team_id}')",
157+
style="white",
158+
)
159+
)
160+
console.print()
161+
162+
except Exception as e:
163+
console.print(Text(f"❌ Error during initialization: {e}", style="bold red"))
164+
raise
165+
166+
77167
def run_server(args):
78168
BLUE = "\033[94m"
79169
RESET = "\033[0m"
@@ -99,7 +189,7 @@ def run_server(args):
99189
style="bold green",
100190
)
101191
console.print(msg)
102-
graphql_init()
192+
runtime.init()
103193
uvicorn.run("alphatrion.server.cmd.app:app", host=args.host, port=args.port)
104194

105195

alphatrion/server/graphql/resolvers.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import strawberry
55

6-
from alphatrion.server.graphql import runtime
6+
from alphatrion.server import runtime
77
from alphatrion.storage.sql_models import Status
88

99
from .types import (
@@ -27,7 +27,7 @@
2727
class GraphQLResolvers:
2828
@staticmethod
2929
def list_teams(user_id: strawberry.ID) -> list[Team]:
30-
metadb = runtime.graphql_runtime().metadb
30+
metadb = runtime.server_runtime().metadb
3131
teams = metadb.list_user_teams(user_id=user_id)
3232
return [
3333
Team(
@@ -43,7 +43,7 @@ def list_teams(user_id: strawberry.ID) -> list[Team]:
4343

4444
@staticmethod
4545
def get_team(id: strawberry.ID) -> Team | None:
46-
metadb = runtime.graphql_runtime().metadb
46+
metadb = runtime.server_runtime().metadb
4747
team = metadb.get_team(team_id=uuid.UUID(id))
4848
if team:
4949
return Team(
@@ -58,7 +58,7 @@ def get_team(id: strawberry.ID) -> Team | None:
5858

5959
@staticmethod
6060
def get_user(id: strawberry.ID) -> User | None:
61-
metadb = runtime.graphql_runtime().metadb
61+
metadb = runtime.server_runtime().metadb
6262
user = metadb.get_user(user_id=uuid.UUID(id))
6363
if user:
6464
return User(
@@ -80,7 +80,7 @@ def list_projects(
8080
order_by: str = "created_at",
8181
order_desc: bool = True,
8282
) -> list[Project]:
83-
metadb = runtime.graphql_runtime().metadb
83+
metadb = runtime.server_runtime().metadb
8484
projects = metadb.list_projects(
8585
team_id=uuid.UUID(team_id),
8686
page=page,
@@ -104,7 +104,7 @@ def list_projects(
104104

105105
@staticmethod
106106
def get_project(id: strawberry.ID) -> Project | None:
107-
metadb = runtime.graphql_runtime().metadb
107+
metadb = runtime.server_runtime().metadb
108108
proj = metadb.get_project(project_id=uuid.UUID(id))
109109
if proj:
110110
return Project(
@@ -127,7 +127,7 @@ def list_experiments(
127127
order_by: str = "created_at",
128128
order_desc: bool = True,
129129
) -> list[Experiment]:
130-
metadb = runtime.graphql_runtime().metadb
130+
metadb = runtime.server_runtime().metadb
131131
exps = metadb.list_exps_by_project_id(
132132
project_id=uuid.UUID(project_id),
133133
page=page,
@@ -156,7 +156,7 @@ def list_experiments(
156156

157157
@staticmethod
158158
def get_experiment(id: strawberry.ID) -> Experiment | None:
159-
metadb = runtime.graphql_runtime().metadb
159+
metadb = runtime.server_runtime().metadb
160160
exp = metadb.get_experiment(experiment_id=uuid.UUID(id))
161161
if exp:
162162
return Experiment(
@@ -184,7 +184,7 @@ def list_runs(
184184
order_by: str = "created_at",
185185
order_desc: bool = True,
186186
) -> list[Run]:
187-
metadb = runtime.graphql_runtime().metadb
187+
metadb = runtime.server_runtime().metadb
188188
runs = metadb.list_runs_by_exp_id(
189189
exp_id=uuid.UUID(experiment_id),
190190
page=page,
@@ -208,7 +208,7 @@ def list_runs(
208208

209209
@staticmethod
210210
def get_run(id: strawberry.ID) -> Run | None:
211-
metadb = runtime.graphql_runtime().metadb
211+
metadb = runtime.server_runtime().metadb
212212
run = metadb.get_run(run_id=uuid.UUID(id))
213213
if run:
214214
return Run(
@@ -225,7 +225,7 @@ def get_run(id: strawberry.ID) -> Run | None:
225225

226226
@staticmethod
227227
def list_exp_metrics(experiment_id: strawberry.ID) -> list[Metric]:
228-
metadb = runtime.graphql_runtime().metadb
228+
metadb = runtime.server_runtime().metadb
229229
metrics = metadb.list_metrics_by_experiment_id(experiment_id=experiment_id)
230230
return [
231231
Metric(
@@ -243,17 +243,17 @@ def list_exp_metrics(experiment_id: strawberry.ID) -> list[Metric]:
243243

244244
@staticmethod
245245
def total_projects(team_id: strawberry.ID) -> int:
246-
metadb = runtime.graphql_runtime().metadb
246+
metadb = runtime.server_runtime().metadb
247247
return metadb.count_projects(team_id=team_id)
248248

249249
@staticmethod
250250
def total_experiments(team_id: strawberry.ID) -> int:
251-
metadb = runtime.graphql_runtime().metadb
251+
metadb = runtime.server_runtime().metadb
252252
return metadb.count_experiments(team_id=team_id)
253253

254254
@staticmethod
255255
def total_runs(team_id: strawberry.ID) -> int:
256-
metadb = runtime.graphql_runtime().metadb
256+
metadb = runtime.server_runtime().metadb
257257
return metadb.count_runs(team_id=team_id)
258258

259259
@staticmethod
@@ -262,7 +262,7 @@ def list_exps_by_timeframe(
262262
start_time: datetime,
263263
end_time: datetime,
264264
) -> list[Experiment]:
265-
metadb = runtime.graphql_runtime().metadb
265+
metadb = runtime.server_runtime().metadb
266266
experiments = metadb.list_exps_by_timeframe(
267267
team_id=team_id,
268268
start_time=start_time,
@@ -292,7 +292,7 @@ def list_exps_by_timeframe(
292292
class GraphQLMutations:
293293
@staticmethod
294294
def create_user(input: CreateUserInput) -> User:
295-
metadb = runtime.graphql_runtime().metadb
295+
metadb = runtime.server_runtime().metadb
296296
user_id = metadb.create_user(
297297
uuid=uuid.UUID(input.id) if input.id else None,
298298
username=input.username,
@@ -316,7 +316,7 @@ def create_user(input: CreateUserInput) -> User:
316316

317317
@staticmethod
318318
def update_user(input: UpdateUserInput) -> User:
319-
metadb = runtime.graphql_runtime().metadb
319+
metadb = runtime.server_runtime().metadb
320320
user_id = uuid.UUID(input.id)
321321

322322
user = metadb.update_user(user_id=user_id, meta=input.meta)
@@ -336,7 +336,7 @@ def update_user(input: UpdateUserInput) -> User:
336336

337337
@staticmethod
338338
def create_team(input: CreateTeamInput) -> Team:
339-
metadb = runtime.graphql_runtime().metadb
339+
metadb = runtime.server_runtime().metadb
340340
team_id = metadb.create_team(
341341
uuid=uuid.UUID(input.id) if input.id else None,
342342
name=input.name,
@@ -358,7 +358,7 @@ def create_team(input: CreateTeamInput) -> Team:
358358

359359
@staticmethod
360360
def add_user_to_team(input: AddUserToTeamInput) -> bool:
361-
metadb = runtime.graphql_runtime().metadb
361+
metadb = runtime.server_runtime().metadb
362362
user_id = uuid.UUID(input.user_id)
363363
team_id = uuid.UUID(input.team_id)
364364

@@ -379,7 +379,7 @@ def add_user_to_team(input: AddUserToTeamInput) -> bool:
379379

380380
@staticmethod
381381
def remove_user_from_team(input: RemoveUserFromTeamInput) -> bool:
382-
metadb = runtime.graphql_runtime().metadb
382+
metadb = runtime.server_runtime().metadb
383383
user_id = uuid.UUID(input.user_id)
384384
team_id = uuid.UUID(input.team_id)
385385

alphatrion/server/graphql/runtime.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

0 commit comments

Comments
 (0)