Skip to content

Commit 36b15c5

Browse files
FBumann and Claude authored
perf: speed up LP file writing (2.5-3.9x on large models, no regressions on small) (#564)
* perf: use Polars streaming engine for LP file writing Extract _format_and_write() helper that uses lazy().collect(engine="streaming") with automatic fallback, replacing 7 instances of df.select(concat_str(...)).write_csv(...). * fix: log warning with traceback when Polars streaming fallback triggers * perf: speed up LP constraint writing by replacing concat+sort with join Replace the vertical concat + sort approach in Constraint.to_polars() with an inner join, so every row has all columns populated. This removes the need for the group_by validation step in constraints_to_file() and simplifies the formatting expressions by eliminating null checks on coeffs/vars columns. * fix: missing space in lp file * perf: skip group_terms when unnecessary and avoid xarray broadcast for short DataFrame - Skip group_terms_polars when _term dim size is 1 (no duplicate vars) - Build the short DataFrame (labels, rhs, sign) directly with numpy instead of going through xarray.broadcast + to_polars - Add sign column via pl.lit when uniform (common case), avoiding costly numpy string array → polars conversion Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * perf: skip group_terms in LinearExpression.to_polars when no duplicate vars Check n_unique before running the expensive group_by+sum. When all variable references are unique (common case for objectives), this saves ~31ms per 320k terms. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * perf: reduce per-constraint overhead in Constraint.to_polars() Replace np.unique with faster numpy equality check for sign uniformity. Eliminate redundant filter_nulls_polars and check_has_nulls_polars on the short DataFrame by applying the labels mask directly during construction. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * fix: handle empty constraint slices in sign_flat check Guard against IndexError when sign_flat is empty (no valid labels) by checking len(sign_flat) > 0 before accessing sign_flat[0]. 
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * docs: add LP write speed improvement to release notes Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * bench: add LP write benchmark script with plotting * bench: larger model * perf: Add maybe_group_terms_polars() helper in common.py that checks for duplicate (labels, vars) pairs before calling group_terms_polars. Use it in both Constraint.to_polars() and LinearExpression.to_polars() to avoid expensive group_by when terms already reference distinct variables * Add variance to plot * test: add coverage for streaming fallback and maybe_group_terms_polars * fix: mypy * fix: mypy * Move kwargs into method for readability * Remove fallback and pin polars >=1.31 * Remove the benchmark_lp_writer.py --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 59f92ae commit 36b15c5

9 files changed

Lines changed: 153 additions & 80 deletions

File tree

doc/release_notes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Upcoming Version
66

77
* Fix docs (pick highs solver)
88
* Add the `sphinx-copybutton` to the documentation
9+
* Speed up LP file writing by 2-2.7x on large models through Polars streaming engine, join-based constraint assembly, and reduced per-constraint overhead
910

1011
Upcoming Version
1112
----------------

linopy/common.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,25 @@ def group_terms_polars(df: pl.DataFrame) -> pl.DataFrame:
449449
return df
450450

451451

452+
def maybe_group_terms_polars(df: pl.DataFrame) -> pl.DataFrame:
453+
"""
454+
Group terms only if there are duplicate (labels, vars) pairs.
455+
456+
This avoids the expensive group_by operation when terms already
457+
reference distinct variables (e.g. ``x - y`` has ``_term=2`` but
458+
no duplicates). When skipping, columns are reordered to match the
459+
output of ``group_terms_polars``.
460+
"""
461+
varcols = [c for c in df.columns if c.startswith("vars")]
462+
keys = [c for c in ["labels"] + varcols if c in df.columns]
463+
key_count = df.select(pl.struct(keys).n_unique()).item()
464+
if key_count < df.height:
465+
return group_terms_polars(df)
466+
# Match column order of group_terms (group-by keys, coeffs, rest)
467+
rest = [c for c in df.columns if c not in keys and c != "coeffs"]
468+
return df.select(keys + ["coeffs"] + rest)
469+
470+
452471
def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset:
453472
"""
454473
Join multiple xarray Dataarray's to a Dataset and warn if coordinates are not equal.

linopy/constraints.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,9 @@
4040
generate_indices_for_printout,
4141
get_dims_with_index_levels,
4242
get_label_position,
43-
group_terms_polars,
4443
has_optimized_model,
45-
infer_schema_polars,
4644
iterate_slices,
45+
maybe_group_terms_polars,
4746
maybe_replace_signs,
4847
print_coord,
4948
print_single_constraint,
@@ -622,21 +621,38 @@ def to_polars(self) -> pl.DataFrame:
622621
long = to_polars(ds[keys])
623622

624623
long = filter_nulls_polars(long)
625-
long = group_terms_polars(long)
624+
if ds.sizes.get("_term", 1) > 1:
625+
long = maybe_group_terms_polars(long)
626626
check_has_nulls_polars(long, name=f"{self.type} {self.name}")
627627

628-
short_ds = ds[[k for k in ds if "_term" not in ds[k].dims]]
629-
schema = infer_schema_polars(short_ds)
630-
schema["sign"] = pl.Enum(["=", "<=", ">="])
631-
short = to_polars(short_ds, schema=schema)
632-
short = filter_nulls_polars(short)
633-
check_has_nulls_polars(short, name=f"{self.type} {self.name}")
634-
635-
df = pl.concat([short, long], how="diagonal_relaxed").sort(["labels", "rhs"])
636-
# delete subsequent non-null rhs (happens is all vars per label are -1)
637-
is_non_null = df["rhs"].is_not_null()
638-
prev_non_is_null = is_non_null.shift(1).fill_null(False)
639-
df = df.filter(is_non_null & ~prev_non_is_null | ~is_non_null)
628+
# Build short DataFrame (labels, rhs, sign) without xarray broadcast.
629+
# Apply labels mask directly instead of filter_nulls_polars.
630+
labels_flat = ds["labels"].values.reshape(-1)
631+
mask = labels_flat != -1
632+
labels_masked = labels_flat[mask]
633+
rhs_flat = np.broadcast_to(ds["rhs"].values, ds["labels"].shape).reshape(-1)
634+
635+
sign_values = ds["sign"].values
636+
sign_flat = np.broadcast_to(sign_values, ds["labels"].shape).reshape(-1)
637+
all_same_sign = len(sign_flat) > 0 and (
638+
sign_flat[0] == sign_flat[-1] and (sign_flat[0] == sign_flat).all()
639+
)
640+
641+
short_data: dict = {
642+
"labels": labels_masked,
643+
"rhs": rhs_flat[mask],
644+
}
645+
if all_same_sign:
646+
short = pl.DataFrame(short_data).with_columns(
647+
pl.lit(sign_flat[0]).cast(pl.Enum(["=", "<=", ">="])).alias("sign")
648+
)
649+
else:
650+
short_data["sign"] = pl.Series(
651+
"sign", sign_flat[mask], dtype=pl.Enum(["=", "<=", ">="])
652+
)
653+
short = pl.DataFrame(short_data)
654+
655+
df = long.join(short, on="labels", how="inner")
640656
return df[["labels", "coeffs", "vars", "sign", "rhs"]]
641657

642658
# Wrapped function which would convert variable to dataarray

linopy/expressions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
has_optimized_model,
6161
is_constant,
6262
iterate_slices,
63+
maybe_group_terms_polars,
6364
print_coord,
6465
print_single_expression,
6566
to_dataframe,
@@ -1469,7 +1470,7 @@ def to_polars(self) -> pl.DataFrame:
14691470

14701471
df = to_polars(self.data)
14711472
df = filter_nulls_polars(df)
1472-
df = group_terms_polars(df)
1473+
df = maybe_group_terms_polars(df)
14731474
check_has_nulls_polars(df, name=self.type)
14741475
return df
14751476

linopy/io.py

Lines changed: 30 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,21 @@ def clean_name(name: str) -> str:
5454
coord_sanitizer = str.maketrans("[,]", "(,)", " ")
5555

5656

57+
def _format_and_write(
58+
df: pl.DataFrame, columns: list[pl.Expr], f: BufferedWriter
59+
) -> None:
60+
"""
61+
Format columns via concat_str and write to file.
62+
63+
Uses Polars streaming engine for better memory efficiency.
64+
"""
65+
df.lazy().select(pl.concat_str(columns, ignore_nulls=True)).collect(
66+
engine="streaming"
67+
).write_csv(
68+
f, separator=" ", null_value="", quote_style="never", include_header=False
69+
)
70+
71+
5772
def signed_number(expr: pl.Expr) -> tuple[pl.Expr, pl.Expr]:
5873
"""
5974
Return polars expressions for a signed number string, handling -0.0 correctly.
@@ -155,10 +170,7 @@ def objective_write_linear_terms(
155170
*signed_number(pl.col("coeffs")),
156171
*print_variable(pl.col("vars")),
157172
]
158-
df = df.select(pl.concat_str(cols, ignore_nulls=True))
159-
df.write_csv(
160-
f, separator=" ", null_value="", quote_style="never", include_header=False
161-
)
173+
_format_and_write(df, cols, f)
162174

163175

164176
def objective_write_quadratic_terms(
@@ -171,10 +183,7 @@ def objective_write_quadratic_terms(
171183
*print_variable(pl.col("vars2")),
172184
]
173185
f.write(b"+ [\n")
174-
df = df.select(pl.concat_str(cols, ignore_nulls=True))
175-
df.write_csv(
176-
f, separator=" ", null_value="", quote_style="never", include_header=False
177-
)
186+
_format_and_write(df, cols, f)
178187
f.write(b"] / 2\n")
179188

180189

@@ -254,11 +263,7 @@ def bounds_to_file(
254263
*signed_number(pl.col("upper")),
255264
]
256265

257-
kwargs: Any = dict(
258-
separator=" ", null_value="", quote_style="never", include_header=False
259-
)
260-
formatted = df.select(pl.concat_str(columns, ignore_nulls=True))
261-
formatted.write_csv(f, **kwargs)
266+
_format_and_write(df, columns, f)
262267

263268

264269
def binaries_to_file(
@@ -296,11 +301,7 @@ def binaries_to_file(
296301
*print_variable(pl.col("labels")),
297302
]
298303

299-
kwargs: Any = dict(
300-
separator=" ", null_value="", quote_style="never", include_header=False
301-
)
302-
formatted = df.select(pl.concat_str(columns, ignore_nulls=True))
303-
formatted.write_csv(f, **kwargs)
304+
_format_and_write(df, columns, f)
304305

305306

306307
def integers_to_file(
@@ -339,11 +340,7 @@ def integers_to_file(
339340
*print_variable(pl.col("labels")),
340341
]
341342

342-
kwargs: Any = dict(
343-
separator=" ", null_value="", quote_style="never", include_header=False
344-
)
345-
formatted = df.select(pl.concat_str(columns, ignore_nulls=True))
346-
formatted.write_csv(f, **kwargs)
343+
_format_and_write(df, columns, f)
347344

348345

349346
def sos_to_file(
@@ -399,11 +396,7 @@ def sos_to_file(
399396
pl.col("var_weights"),
400397
]
401398

402-
kwargs: Any = dict(
403-
separator=" ", null_value="", quote_style="never", include_header=False
404-
)
405-
formatted = df.select(pl.concat_str(columns, ignore_nulls=True))
406-
formatted.write_csv(f, **kwargs)
399+
_format_and_write(df, columns, f)
407400

408401

409402
def constraints_to_file(
@@ -440,58 +433,32 @@ def constraints_to_file(
440433
if df.height == 0:
441434
continue
442435

443-
# Ensure each constraint has both coefficient and RHS terms
444-
analysis = df.group_by("labels").agg(
445-
[
446-
pl.col("coeffs").is_not_null().sum().alias("coeff_rows"),
447-
pl.col("sign").is_not_null().sum().alias("rhs_rows"),
448-
]
449-
)
450-
451-
valid = analysis.filter(
452-
(pl.col("coeff_rows") > 0) & (pl.col("rhs_rows") > 0)
453-
)
454-
455-
if valid.height == 0:
456-
continue
457-
458-
# Keep only constraints that have both parts
459-
df = df.join(valid.select("labels"), on="labels", how="inner")
460-
461436
# Sort by labels and mark first/last occurrences
462437
df = df.sort("labels").with_columns(
463438
[
464-
pl.when(pl.col("labels").is_first_distinct())
465-
.then(pl.col("labels"))
466-
.otherwise(pl.lit(None))
467-
.alias("labels_first"),
439+
pl.col("labels").is_first_distinct().alias("is_first_in_group"),
468440
(pl.col("labels") != pl.col("labels").shift(-1))
469441
.fill_null(True)
470442
.alias("is_last_in_group"),
471443
]
472444
)
473445

474-
row_labels = print_constraint(pl.col("labels_first"))
446+
row_labels = print_constraint(pl.col("labels"))
475447
col_labels = print_variable(pl.col("vars"))
476448
columns = [
477-
pl.when(pl.col("labels_first").is_not_null()).then(row_labels[0]),
478-
pl.when(pl.col("labels_first").is_not_null()).then(row_labels[1]),
479-
pl.when(pl.col("labels_first").is_not_null())
480-
.then(pl.lit(":\n"))
481-
.alias(":"),
449+
pl.when(pl.col("is_first_in_group")).then(row_labels[0]),
450+
pl.when(pl.col("is_first_in_group")).then(row_labels[1]),
451+
pl.when(pl.col("is_first_in_group")).then(pl.lit(":\n")).alias(":"),
482452
*signed_number(pl.col("coeffs")),
483-
pl.when(pl.col("vars").is_not_null()).then(col_labels[0]),
484-
pl.when(pl.col("vars").is_not_null()).then(col_labels[1]),
453+
col_labels[0],
454+
col_labels[1],
455+
pl.when(pl.col("is_last_in_group")).then(pl.lit("\n")),
485456
pl.when(pl.col("is_last_in_group")).then(pl.col("sign")),
486457
pl.when(pl.col("is_last_in_group")).then(pl.lit(" ")),
487458
pl.when(pl.col("is_last_in_group")).then(pl.col("rhs").cast(pl.String)),
488459
]
489460

490-
kwargs: Any = dict(
491-
separator=" ", null_value="", quote_style="never", include_header=False
492-
)
493-
formatted = df.select(pl.concat_str(columns, ignore_nulls=True))
494-
formatted.write_csv(f, **kwargs)
461+
_format_and_write(df, columns, f)
495462

496463
# in the future, we could use lazy dataframes when they support appending
497464
# tp existent files

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ dependencies = [
3333
"numexpr",
3434
"xarray>=2024.2.0",
3535
"dask>=0.18.0",
36-
"polars",
36+
"polars>=1.31",
3737
"tqdm",
3838
"deprecation",
3939
"packaging",

test/test_common.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
get_dims_with_index_levels,
2424
is_constant,
2525
iterate_slices,
26+
maybe_group_terms_polars,
2627
)
2728
from linopy.testing import assert_linequal, assert_varequal
2829

@@ -737,3 +738,20 @@ def test_is_constant() -> None:
737738
]
738739
for cv in constant_values:
739740
assert is_constant(cv)
741+
742+
743+
def test_maybe_group_terms_polars_no_duplicates() -> None:
744+
"""Fast path: distinct (labels, vars) pairs skip group_by."""
745+
df = pl.DataFrame({"labels": [0, 0], "vars": [1, 2], "coeffs": [3.0, 4.0]})
746+
result = maybe_group_terms_polars(df)
747+
assert result.shape == (2, 3)
748+
assert result.columns == ["labels", "vars", "coeffs"]
749+
assert result["coeffs"].to_list() == [3.0, 4.0]
750+
751+
752+
def test_maybe_group_terms_polars_with_duplicates() -> None:
753+
"""Slow path: duplicate (labels, vars) pairs trigger group_by."""
754+
df = pl.DataFrame({"labels": [0, 0], "vars": [1, 1], "coeffs": [3.0, 4.0]})
755+
result = maybe_group_terms_polars(df)
756+
assert result.shape == (1, 3)
757+
assert result["coeffs"].to_list() == [7.0]

test/test_constraint.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,20 @@ def test_constraint_to_polars(c: linopy.constraints.Constraint) -> None:
437437
assert isinstance(c.to_polars(), pl.DataFrame)
438438

439439

440+
def test_constraint_to_polars_mixed_signs(m: Model, x: linopy.Variable) -> None:
441+
"""Test to_polars when a constraint has mixed sign values across dims."""
442+
# Create a constraint, then manually patch the sign to have mixed values
443+
m.add_constraints(x >= 0, name="mixed")
444+
con = m.constraints["mixed"]
445+
# Replace sign data with mixed signs across the first dimension
446+
n = con.data.sizes["first"]
447+
signs = np.array(["<=" if i % 2 == 0 else ">=" for i in range(n)])
448+
con.data["sign"] = xr.DataArray(signs, dims=con.data["sign"].dims)
449+
df = con.to_polars()
450+
assert isinstance(df, pl.DataFrame)
451+
assert set(df["sign"].to_list()) == {"<=", ">="}
452+
453+
440454
def test_constraint_assignment_with_anonymous_constraints(
441455
m: Model, x: linopy.Variable, y: linopy.Variable
442456
) -> None:

test/test_io.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,3 +336,40 @@ def test_to_file_lp_with_negative_zero_coefficients(tmp_path: Path) -> None:
336336

337337
# Verify Gurobi can read it without errors
338338
gurobipy.read(str(fn))
339+
340+
341+
def test_to_file_lp_same_sign_constraints(tmp_path: Path) -> None:
342+
"""Test LP writing when all constraints have the same sign operator."""
343+
m = Model()
344+
N = np.arange(5)
345+
x = m.add_variables(coords=[N], name="x")
346+
# All constraints use <=
347+
m.add_constraints(x <= 10, name="upper")
348+
m.add_constraints(x <= 20, name="upper2")
349+
m.add_objective(x.sum())
350+
351+
fn = tmp_path / "same_sign.lp"
352+
m.to_file(fn)
353+
content = fn.read_text()
354+
assert "s.t." in content
355+
assert "<=" in content
356+
357+
358+
def test_to_file_lp_mixed_sign_constraints(tmp_path: Path) -> None:
359+
"""Test LP writing when constraints have different sign operators."""
360+
m = Model()
361+
N = np.arange(5)
362+
x = m.add_variables(coords=[N], name="x")
363+
# Mix of <= and >= constraints in the same container
364+
m.add_constraints(x <= 10, name="upper")
365+
m.add_constraints(x >= 1, name="lower")
366+
m.add_constraints(2 * x == 8, name="eq")
367+
m.add_objective(x.sum())
368+
369+
fn = tmp_path / "mixed_sign.lp"
370+
m.to_file(fn)
371+
content = fn.read_text()
372+
assert "s.t." in content
373+
assert "<=" in content
374+
assert ">=" in content
375+
assert "=" in content

0 commit comments

Comments (0)