@@ -70,7 +70,7 @@ def define_time_grid(self, start_time: float, end_time: float, dt: float):
7070 self .t_horizon = end_time - start_time
7171
7272 def get_default_x0 (self ):
73- return self .dae .start (self .dae .x ())
73+ return { x_name : self .dae .start (x_name ) for x_name in self .dae .x ()}
7474
7575 def get_default_u (self ):
7676 u_start = np .array (self .dae .start (self .dae .u ()))
@@ -79,7 +79,7 @@ def get_default_u(self):
7979 else :
8080 return u_start .reshape (- 1 , 1 )* np .ones ((1 , self .N + 1 ))
8181
82- def simulate (self , x0 : np . array = None , u_sim : np .array = None ,
82+ def simulate (self , x0 : dict = None , u_sim : np .array = None ,
8383 start_time : float = None , end_time : float = None , dt : float = None ):
8484 # Optionally override time grid if parameters are provided
8585 if start_time is not None and end_time is not None and dt is not None :
@@ -90,7 +90,9 @@ def simulate(self, x0: np.array = None, u_sim: np.array = None,
9090 self .f_sim = ca .integrator ("simulator" , "cvodes" , dae_dict , 0 , self .tgrid , opts )
9191
9292 # Get external values if provided. Otherwise use those from the model
93- x0 = x0 if x0 is not None else self .get_default_x0 ()
93+ x0 = x0 if x0 is not None else self .get_default_x0 ()
94+ # Convert x0 to a numpy array as required by casadi
95+ x0 = np .atleast_2d ([x0 [name ] for name in self .dae .x ()])
9496 u_sim = np .atleast_2d (u_sim ) if u_sim is not None else self .get_default_u ()
9597
9698 # Simulate the model dynamics
@@ -154,8 +156,12 @@ def define_optimization(self,
154156 x0 = x0 if x0 is not None else self .get_default_x0 ()
155157
156158 # Set initial state values
157- for i , x_name in enumerate (self .dae .x ()):
158- self .ocp .subject_to (self .ocp .at_t0 (self .dae .var (x_name )) == x0 [i ])
159+ for x_name in self .dae .x ():
160+ self .ocp .subject_to (self .ocp .at_t0 (self .dae .var (x_name )) == x0 [x_name ])
161+
162+ # Set the initial guess based on the initial state values
163+ for x_name in self .dae .x ():
164+ self .ocp .set_initial (self .dae .var (x_name ), x0 [x_name ])
159165
160166 # Set initial time
161167 t0 = 0
0 commit comments