Skip to content

Commit cd68c03

Browse files
committed
Use initial states to set initial guess.
1 parent b82a6e6 commit cd68c03

1 file changed

Lines changed: 11 additions & 5 deletions

File tree

wedoco_optimo/model.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def define_time_grid(self, start_time: float, end_time: float, dt: float):
7070
self.t_horizon = end_time - start_time
7171

7272
def get_default_x0(self):
73-
return self.dae.start(self.dae.x())
73+
return {x_name: self.dae.start(x_name) for x_name in self.dae.x()}
7474

7575
def get_default_u(self):
7676
u_start = np.array(self.dae.start(self.dae.u()))
@@ -79,7 +79,7 @@ def get_default_u(self):
7979
else:
8080
return u_start.reshape(-1, 1)*np.ones((1, self.N+1))
8181

82-
def simulate(self, x0: np.array = None, u_sim: np.array = None,
82+
def simulate(self, x0: dict = None, u_sim: np.array = None,
8383
start_time: float = None, end_time: float = None, dt: float = None):
8484
# Optionally override time grid if parameters are provided
8585
if start_time is not None and end_time is not None and dt is not None:
@@ -90,7 +90,9 @@ def simulate(self, x0: np.array = None, u_sim: np.array = None,
9090
self.f_sim = ca.integrator("simulator", "cvodes", dae_dict, 0, self.tgrid, opts)
9191

9292
# Get external values if provided. Otherwise use those from the model
93-
x0 = x0 if x0 is not None else self.get_default_x0()
93+
x0 = x0 if x0 is not None else self.get_default_x0()
94+
# Convert x0 to a numpy array as required by casadi
95+
x0 = np.atleast_2d([x0[name] for name in self.dae.x()])
9496
u_sim = np.atleast_2d(u_sim) if u_sim is not None else self.get_default_u()
9597

9698
# Simulate the model dynamics
@@ -154,8 +156,12 @@ def define_optimization(self,
154156
x0 = x0 if x0 is not None else self.get_default_x0()
155157

156158
# Set initial state values
157-
for i, x_name in enumerate(self.dae.x()):
158-
self.ocp.subject_to(self.ocp.at_t0(self.dae.var(x_name)) == x0[i])
159+
for x_name in self.dae.x():
160+
self.ocp.subject_to(self.ocp.at_t0(self.dae.var(x_name)) == x0[x_name])
161+
162+
# Set the initial guess based on the initial state values
163+
for x_name in self.dae.x():
164+
self.ocp.set_initial(self.dae.var(x_name), x0[x_name])
159165

160166
# Set initial time
161167
t0 = 0

0 commit comments

Comments
 (0)