diff --git a/linopy/matrices.py b/linopy/matrices.py index e1489e76..416cd184 100644 --- a/linopy/matrices.py +++ b/linopy/matrices.py @@ -51,7 +51,22 @@ def __init__(self, model: Model) -> None: def clean_cached_properties(self) -> None: """Clear the cache for all cached properties of an object""" - for cached_prop in ["flat_vars", "flat_cons", "sol", "dual"]: + for cached_prop in [ + "flat_vars", + "flat_cons", + "sol", + "dual", + "vlabels", + "clabels", + "A", + "c", + "b", + "sense", + "lb", + "ub", + "vtypes", + "Q", + ]: # check existence of cached_prop without creating it if cached_prop in self.__dict__: delattr(self, cached_prop) @@ -66,13 +81,13 @@ def flat_cons(self) -> pd.DataFrame: m = self._parent return m.constraints.flat - @property + @cached_property def vlabels(self) -> ndarray: """Vector of labels of all non-missing variables.""" df: pd.DataFrame = self.flat_vars return create_vector(df.key, df.labels, -1) - @property + @cached_property def vtypes(self) -> ndarray: """Vector of types of all non-missing variables.""" m = self._parent @@ -93,7 +108,7 @@ def vtypes(self) -> ndarray: ds = df.set_index("key").labels.map(ds) return create_vector(ds.index, ds.to_numpy(), fill_value="") - @property + @cached_property def lb(self) -> ndarray: """Vector of lower bounds of all non-missing variables.""" df: pd.DataFrame = self.flat_vars @@ -123,13 +138,13 @@ def dual(self) -> ndarray: ) return create_vector(df.key, df.dual, fill_value=np.nan) - @property + @cached_property def ub(self) -> ndarray: """Vector of upper bounds of all non-missing variables.""" df: pd.DataFrame = self.flat_vars return create_vector(df.key, df.upper) - @property + @cached_property def clabels(self) -> ndarray: """Vector of labels of all non-missing constraints.""" df: pd.DataFrame = self.flat_cons @@ -137,7 +152,7 @@ def clabels(self) -> ndarray: return np.array([], dtype=int) return create_vector(df.key, df.labels, fill_value=-1) - @property + @cached_property def A(self) -> csc_matrix | None: """Constraint matrix of all non-missing constraints and variables.""" m = self._parent @@ -146,19 +161,19 @@ def A(self) -> csc_matrix | None: A: csc_matrix = m.constraints.to_matrix(filter_missings=False) return A[self.clabels][:, self.vlabels] - @property + @cached_property def sense(self) -> ndarray: """Vector of senses of all non-missing constraints.""" df: pd.DataFrame = self.flat_cons return create_vector(df.key, df.sign.astype(np.dtype(" ndarray: """Vector of right-hand-sides of all non-missing constraints.""" df: pd.DataFrame = self.flat_cons return create_vector(df.key, df.rhs) - @property + @cached_property def c(self) -> ndarray: """Vector of objective coefficients of all non-missing variables.""" m = self._parent @@ -171,7 +186,7 @@ def c(self) -> ndarray: shape: int = self.flat_vars.key.max() + 1 return create_vector(vars, ds.coeffs, fill_value=0.0, shape=shape) - @property + @cached_property def Q(self) -> csc_matrix | None: """Matrix objective coefficients of quadratic terms of all non-missing variables.""" m = self._parent diff --git a/test/test_matrices.py b/test/test_matrices.py index 98a73564..e9abc564 100644 --- a/test/test_matrices.py +++ b/test/test_matrices.py @@ -77,3 +77,36 @@ def test_matrices_float_c() -> None: c = m.matrices.c assert np.all(c == np.array([1.5, 1.5])) + + +def test_matrices_properties_are_cached() -> None: + """Verify that MatrixAccessor properties are cached after first access.""" + m = Model() + + lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)]) + upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper, name="x") + y = m.add_variables(name="y") + + m.add_constraints(1 * x + 10 * y, EQUAL, 0) + m.add_objective((10 * x + 5 * y).sum()) + + M = m.matrices + + # Access each property twice — second access should return the same object + for prop in ("vlabels", "clabels", "lb", "ub", "b", "sense", "c"): + first = getattr(M, prop) + second = getattr(M, prop) + assert first is second, f"{prop} is not cached (returns new object each time)" + + # A and Q return complex objects — verify they are also cached + first_A = M.A + second_A = M.A + assert first_A is second_A, "A is not cached" + + # Verify clean_cached_properties clears the cache + M.clean_cached_properties() + fresh = M.vlabels + assert fresh is not M.__dict__.get("_stale_ref", None) + # After cleaning, accessing again should still work + assert np.array_equal(fresh, M.vlabels)