33from sqlalchemy .pool import StaticPool
44from typing import Optional , List
55from app .config import settings
6- from app .models import Ship , SessionShip
6+ from app .models import Ship , SessionShip , ShipStatus
77from datetime import datetime , timezone
88
99
@@ -63,10 +63,12 @@ async def update_ship(self, ship: Ship) -> Ship:
6363 ship .updated_at = datetime .now (timezone .utc )
6464 session = self .get_session ()
6565 try :
66- session .add (ship )
66+ # Use merge() instead of add() to handle detached objects
67+ # merge() copies the state of the given instance into a persistent instance
68+ merged_ship = await session .merge (ship )
6769 await session .commit ()
68- await session .refresh (ship )
69- return ship
70+ await session .refresh (merged_ship )
71+ return merged_ship
7072 finally :
7173 await session .close ()
7274
@@ -87,10 +89,23 @@ async def delete_ship(self, ship_id: str) -> bool:
8789 await session .close ()
8890
8991 async def list_active_ships (self ) -> List [Ship ]:
90- """List all active ships"""
92+ """List all active ships (running and creating)"""
93+ session = self .get_session ()
94+ try :
95+ # Include both RUNNING and CREATING status ships
96+ statement = select (Ship ).where (
97+ (Ship .status == ShipStatus .RUNNING ) | (Ship .status == ShipStatus .CREATING )
98+ )
99+ result = await session .execute (statement )
100+ return list (result .scalars ().all ())
101+ finally :
102+ await session .close ()
103+
104+ async def list_all_ships (self ) -> List [Ship ]:
105+ """List all ships (including stopped)"""
91106 session = self .get_session ()
92107 try :
93- statement = select (Ship ).where (Ship .status == 1 )
108+ statement = select (Ship ).order_by (Ship .created_at . desc () )
94109 result = await session .execute (statement )
95110 return list (result .scalars ().all ())
96111 finally :
@@ -163,20 +178,21 @@ async def update_session_ship(self, session_ship: SessionShip) -> SessionShip:
163178 """Update session-ship relationship"""
164179 session = self .get_session ()
165180 try :
166- session .add (session_ship )
181+ # Use merge() instead of add() to handle detached objects
182+ merged_session_ship = await session .merge (session_ship )
167183 await session .commit ()
168- await session .refresh (session_ship )
169- return session_ship
184+ await session .refresh (merged_session_ship )
185+ return merged_session_ship
170186 finally :
171187 await session .close ()
172188
173189 async def find_available_ship (self , session_id : str ) -> Optional [Ship ]:
174190 """Find an available ship that can accept a new session"""
175191 session = self .get_session ()
176192 try :
177- # Find ships that have available session slots
193+ # Find ships that have available session slots (only RUNNING ships)
178194 statement = select (Ship ).where (
179- Ship .status == 1 , Ship .current_session_num < Ship .max_session_num
195+ Ship .status == ShipStatus . RUNNING , Ship .current_session_num < Ship .max_session_num
180196 )
181197 result = await session .execute (statement )
182198 ships = list (result .scalars ().all ())
@@ -193,38 +209,50 @@ async def find_available_ship(self, session_id: str) -> Optional[Ship]:
193209 await session .close ()
194210
195211 async def find_active_ship_for_session (self , session_id : str ) -> Optional [Ship ]:
196- """Find an active running ship that this session has access to"""
212+ """Find an active running ship that this session has access to.
213+
214+ If the session has access to multiple running ships, returns the most recently updated one.
215+ """
197216 session = self .get_session ()
198217 try :
199- # Find active ships that this session has access to
218+ # Find RUNNING ships that this session has access to
219+ # Order by updated_at desc to get the most recently used one
200220 statement = (
201221 select (Ship )
202222 .join (SessionShip , Ship .id == SessionShip .ship_id )
203223 .where (
204224 SessionShip .session_id == session_id ,
205- Ship .status == 1 ,
225+ Ship .status == ShipStatus . RUNNING ,
206226 )
227+ .order_by (Ship .updated_at .desc ())
207228 )
208229 result = await session .execute (statement )
209- return result .scalar_one_or_none ()
230+ # Use scalars().first() instead of scalar_one_or_none() to handle multiple results
231+ return result .scalars ().first ()
210232 finally :
211233 await session .close ()
212234
213235 async def find_stopped_ship_for_session (self , session_id : str ) -> Optional [Ship ]:
214- """Find a stopped ship that belongs to this session"""
236+ """Find a stopped ship that belongs to this session.
237+
238+ If the session has access to multiple stopped ships, returns the most recently updated one.
239+ """
215240 session = self .get_session ()
216241 try :
217- # Find stopped ships that this session has access to
242+ # Find STOPPED ships that this session has access to
243+ # Order by updated_at desc to get the most recently stopped one
218244 statement = (
219245 select (Ship )
220246 .join (SessionShip , Ship .id == SessionShip .ship_id )
221247 .where (
222248 SessionShip .session_id == session_id ,
223- Ship .status == 0 ,
249+ Ship .status == ShipStatus . STOPPED ,
224250 )
251+ .order_by (Ship .updated_at .desc ())
225252 )
226253 result = await session .execute (statement )
227- return result .scalar_one_or_none ()
254+ # Use scalars().first() instead of scalar_one_or_none() to handle multiple results
255+ return result .scalars ().first ()
228256 finally :
229257 await session .close ()
230258
@@ -266,5 +294,88 @@ async def decrement_ship_session_count(self, ship_id: str) -> Optional[Ship]:
266294 finally :
267295 await session .close ()
268296
297+ async def delete_sessions_for_ship (self , ship_id : str ) -> List [str ]:
298+ """Delete all session-ship relationships for a ship and return deleted session IDs"""
299+ session = self .get_session ()
300+ try :
301+ # First, get all session IDs for this ship
302+ statement = select (SessionShip ).where (SessionShip .ship_id == ship_id )
303+ result = await session .execute (statement )
304+ session_ships = list (result .scalars ().all ())
305+
306+ deleted_session_ids = [ss .session_id for ss in session_ships ]
307+
308+ # Delete all session-ship relationships
309+ for ss in session_ships :
310+ await session .delete (ss )
311+
312+ await session .commit ()
313+ return deleted_session_ids
314+ finally :
315+ await session .close ()
316+
317+ async def extend_session_ttl (
318+ self , session_id : str , ttl : int
319+ ) -> Optional [SessionShip ]:
320+ """Extend the TTL for a session by updating expires_at"""
321+ from datetime import timedelta
322+
323+ session = self .get_session ()
324+ try :
325+ statement = select (SessionShip ).where (SessionShip .session_id == session_id )
326+ result = await session .execute (statement )
327+ session_ship = result .scalar_one_or_none ()
328+
329+ if session_ship :
330+ now = datetime .now (timezone .utc )
331+ session_ship .expires_at = now + timedelta (seconds = ttl )
332+ session_ship .last_activity = now
333+ session .add (session_ship )
334+ await session .commit ()
335+ await session .refresh (session_ship )
336+
337+ return session_ship
338+ finally :
339+ await session .close ()
340+
341+ async def expire_sessions_for_ship (self , ship_id : str ) -> int :
342+ """Mark all sessions for a ship as expired by setting expires_at to current time.
343+
344+ This is called when a ship is stopped to ensure session status
345+ reflects the actual container state.
346+
347+ Args:
348+ ship_id: The ship ID
349+
350+ Returns:
351+ Number of sessions updated
352+ """
353+ session = self .get_session ()
354+ try :
355+ statement = select (SessionShip ).where (SessionShip .ship_id == ship_id )
356+ result = await session .execute (statement )
357+ session_ships = list (result .scalars ().all ())
358+
359+ now = datetime .now (timezone .utc )
360+ updated_count = 0
361+
362+ for ss in session_ships :
363+ # Only update if session is still active (expires_at > now)
364+ expires_at = ss .expires_at
365+ if expires_at is not None :
366+ if expires_at .tzinfo is None :
367+ expires_at = expires_at .replace (tzinfo = timezone .utc )
368+ if expires_at > now :
369+ ss .expires_at = now
370+ session .add (ss )
371+ updated_count += 1
372+
373+ if updated_count > 0 :
374+ await session .commit ()
375+
376+ return updated_count
377+ finally :
378+ await session .close ()
379+
269380
270381db_service = DatabaseService ()
0 commit comments