Skip to content

Commit 5db1e47

Browse files
committed
Propagate M.
1 parent c9501b0 commit 5db1e47

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

wedoco_optimo/model.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,21 +52,22 @@ def transfer_model(self, model: str, modelica_files=list[str], force_recompile:
5252

5353
return fmu_file_path
5454

55-
def define_time_grid(self, start_time: float, end_time: float, dt: float):
55+
def define_time_grid(self, start_time: float, end_time: float, dt: float, M: int=1):
5656
"""
5757
Define the time grid for the simulation and optimization problems.
5858
Args:
5959
start_time: The starting time (float, e.g., 0.0)
6060
end_time: The ending time (float, e.g., 10.0)
6161
dt: Time interval between points (float, e.g., 0.1)
62+
M: Number of time steps per control step (int, default=1)
6263
"""
6364
self.start_time = start_time
6465
self.end_time = end_time
6566
self.dt = dt
67+
self.M = M
6668
# Use np.arange to ensure the last point is included if possible
6769
self.tgrid = np.arange(start_time, end_time + dt, dt)
6870
self.N = len(self.tgrid) - 1 # Number of control steps
69-
self.M = 1 # Number of integration steps per control step
7071
self.t_horizon = end_time - start_time
7172

7273
def get_default_x0(self):
@@ -187,13 +188,14 @@ def optimize(self):
187188
sol = self.ocp.solve()
188189

189190
# Extract the results
190-
x_ocp = np.atleast_2d([sol.sample(self.dae.var(v), grid="integrator")[1].T for v in self.dae.x()])
191+
t_ocp = sol.sample(self.dae.var(self.dae.u()[0]), grid="control")[0].T
192+
x_ocp = np.atleast_2d([sol.sample(self.dae.var(v), grid="control")[1].T for v in self.dae.x()])
191193
u_ocp = np.atleast_2d([sol.sample(self.dae.var(v), grid="control")[1].T for v in self.dae.u()])
192194
res_xyu_ocp = self.f_xu_xyu(x=x_ocp, u=u_ocp)
193195
y_ocp = res_xyu_ocp["y"].full()
194196
u_ocp = res_xyu_ocp["u"].full()
195197

196198
t0 = 0
197-
res_df = get_dae_results(self.tgrid, self.dae, x_ocp, y_ocp, u_ocp, t0)
199+
res_df = get_dae_results(t_ocp, self.dae, x_ocp, y_ocp, u_ocp, t0)
198200

199201
return res_df

0 commit comments

Comments
 (0)