Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ repos:
- id: check-useless-excludes

- repo: https://github.com/ComPWA/policy
rev: 0.8.9
rev: 0.8.10
hooks:
- id: check-dev-files
args:
Expand Down Expand Up @@ -61,7 +61,7 @@ repos:
metadata.vscode

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.10
rev: v0.15.12
hooks:
- id: ruff-check
args: [--fix]
Expand Down Expand Up @@ -114,7 +114,7 @@ repos:
- --in-place

- repo: https://github.com/python-jsonschema/check-jsonschema
rev: 0.37.1
rev: 0.37.2
hooks:
- id: check-jsonschema
name: Check CITATION.cff
Expand Down Expand Up @@ -154,6 +154,6 @@ repos:
types_or: [python, pyi, jupyter]

- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.11.6
rev: 0.11.13
hooks:
- id: uv-lock
9 changes: 6 additions & 3 deletions docs/amplitude-analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1501,10 +1501,13 @@
"outputs": [],
"source": [
"from tensorwaves.optimizer import Minuit2\n",
"from tensorwaves.optimizer.callbacks import CSVSummary\n",
"from tensorwaves.optimizer.callbacks import CallbackList, CSVSummary, TqdmProgressBar\n",
"\n",
"minuit2 = Minuit2(\n",
" callback=CSVSummary(\"fit_traceback.csv\"),\n",
" callback=CallbackList([\n",
" CSVSummary(\"fit_traceback.csv\"),\n",
" TqdmProgressBar(),\n",
" ]),\n",
" use_analytic_gradient=False,\n",
")\n",
"fit_result = minuit2.optimize(estimator, initial_parameters)\n",
Expand Down Expand Up @@ -1905,7 +1908,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.12"
"version": "3.13.13"
}
},
"nbformat": 4,
Expand Down
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def get_tensorflow_url() -> str:
"np.ndarray": "numpy.ndarray",
"ParameterValue": "tensorwaves.interface.ParameterValue",
"Path": "pathlib.Path",
"ProgressColumn": "rich.progress.ProgressColumn",
"sp.Basic": "sympy.core.basic.Basic",
"sp.Expr": "sympy.core.expr.Expr",
"sp.Symbol": "sympy.core.symbol.Symbol",
Expand Down Expand Up @@ -212,6 +213,7 @@ def get_tensorflow_url() -> str:
"pandas": (f"https://pandas.pydata.org/pandas-docs/version/{pin('pandas')}", None),
"python": ("https://docs.python.org/3", None),
"qrules": (f"https://qrules.readthedocs.io/{pin('qrules')}", None),
"rich": ("https://rich.readthedocs.io/en/stable", None),
"scipy": (get_scipy_url(), None),
"sympy": ("https://docs.sympy.org/latest", None),
"tensorflow": (get_tensorflow_url(), "tensorflow.inv"),
Expand Down
17 changes: 11 additions & 6 deletions docs/usage/basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@
" density=True,\n",
" label=\"weighted with $f$\",\n",
")\n",
"plt.legend();"
"plt.legend()\n",
"plt.show()"
]
},
{
Expand Down Expand Up @@ -571,7 +572,8 @@
"source": [
"%config InlineBackend.figure_formats = ['png']\n",
"\n",
"plt.hist(data[\"x\"], bins=200);"
"plt.hist(data[\"x\"], bins=200)\n",
"plt.show()"
]
},
{
Expand Down Expand Up @@ -637,7 +639,8 @@
" histtype=\"step\",\n",
" color=\"red\",\n",
" density=True,\n",
");"
")\n",
"plt.show()"
]
},
{
Expand Down Expand Up @@ -753,7 +756,8 @@
" histtype=\"step\",\n",
" color=\"red\",\n",
" density=True,\n",
");"
")\n",
"plt.show()"
]
},
{
Expand Down Expand Up @@ -913,7 +917,8 @@
"\n",
"y_range = (y, -sp.pi, +sp.pi)\n",
"substituted_expr_2d = expression_2d.subs(parameter_defaults)\n",
"plot3d(substituted_expr_2d, x_range, y_range);"
"plot3d(substituted_expr_2d, x_range, y_range)\n",
"plt.show()"
]
},
{
Expand Down Expand Up @@ -1258,7 +1263,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.12"
"version": "3.13.13"
}
},
"nbformat": 4,
Expand Down
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"attrs >=20.1.0", # https://www.attrs.org/en/stable/api.html#next-gen
"iminuit >=2.0",
"numpy",
"rich",
"sympy >=1.9", # lambdify cse
"tqdm >=4.24.0", # autonotebook
]
Expand Down Expand Up @@ -116,6 +117,7 @@ notebooks = [
"ipympl",
"matplotlib",
"pandas",
"rich[jupyter]",
"tensorwaves[jax,pwa]",
]
style = [
Expand Down Expand Up @@ -376,6 +378,13 @@ ref = "test"
env = {UV_PYTHON = "3.13"}
ref = "test"

[tool.poe.tasks.upgrade]
executor = {type = "simple"}
parallel = [
{cmd = "pre-commit autoupdate -j8"},
{cmd = "uv lock --upgrade"},
]

[tool.pytest]
addopts = [
"--color=yes",
Expand Down
4 changes: 2 additions & 2 deletions src/tensorwaves/function/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,13 +458,13 @@ def extract_constant_sub_expressions(
substitutions = {
expr: sp.Symbol(f"f{i}") for i, expr in enumerate(constant_sub_expressions)
}
top_expression: sp.Expr = expression.xreplace(substitutions)
top_expression = expression.xreplace(substitutions)
sub_expressions = {
symbol: expr
for expr, symbol in substitutions.items()
if symbol in _get_free_symbols(top_expression)
}
return top_expression, sub_expressions
return top_expression, sub_expressions # ty:ignore[invalid-return-type]


def prepare_caching(
Expand Down
114 changes: 114 additions & 0 deletions src/tensorwaves/optimizer/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
from collections.abc import Iterable
from pathlib import Path

from rich.progress import Progress as RichProgress
from rich.progress import ProgressColumn
from tqdm import tqdm as TqdmType # noqa: N812

from tensorwaves.interface import Estimator, Optimizer, ParameterValue


Expand Down Expand Up @@ -94,6 +98,116 @@ def on_function_call_end(
callback.on_function_call_end(function_call, logs)


class RichProgressBar(Callback):
"""Display a `rich` progress bar during optimization.

Args:
*columns: The :ref:`columns <rich:columns>` to display in the progress bar. If
not provided, a default set of columns will be used.
**progress_kwargs: Keyword arguments forwarded to `rich.progress.Progress`.
total: The expected total number of function calls to be made during
optimization in order to get a time estimate.
"""

def __init__(
self,
*columns: str | ProgressColumn,
total: int | None = None,
**progress_kwargs: Any,
) -> None:
if columns:
self.__progress_columns = columns
else:
from rich.progress import ( # noqa: PLC0415
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
)

self.__progress_columns = (
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
MofNCompleteColumn(),
TimeElapsedColumn(),
)
self.__progress_kwargs = progress_kwargs
self.__progress: RichProgress | None = None
self.__task_id: Any = None
self.__total = total

def on_optimize_start(self, logs: dict[str, Any] | None = None) -> None:
from rich.progress import Progress # noqa: PLC0415

self.__progress = Progress(
*self.__progress_columns,
**self.__progress_kwargs,
)
self.__progress.start()
self.__task_id = self.__progress.add_task("Optimizing", total=self.__total)

def on_optimize_end(self, logs: dict[str, Any] | None = None) -> None:
if self.__progress is not None:
self.__progress.stop()
self.__progress = None
self.__task_id = None

def on_iteration_end(
self, iteration: int, logs: dict[str, Any] | None = None
) -> None:
pass

def on_function_call_end(
self, function_call: int, logs: dict[str, Any] | None = None
) -> None:
if self.__progress is None or self.__task_id is None:
return
description = "Optimizing"
if logs is not None:
estimator_value = logs.get("estimator", {}).get("value")
if estimator_value is not None:
description = f"estimator={estimator_value:.6g}"
self.__progress.update(self.__task_id, description=description, advance=1)


class TqdmProgressBar(Callback):
"""Display a ``tqdm`` progress bar during optimization.

Args:
**tqdm_kwargs: Keyword arguments forwarded to `tqdm <https://tqdm.github.io/docs/tqdm>`_.
"""

def __init__(self, **tqdm_kwargs: Any) -> None:
self.__tqdm_kwargs = tqdm_kwargs
self.__progress_bar: TqdmType | None = None

def on_optimize_start(self, logs: dict[str, Any] | None = None) -> None:
from tqdm.auto import tqdm # noqa: PLC0415

self.__progress_bar = tqdm(**self.__tqdm_kwargs)

def on_optimize_end(self, logs: dict[str, Any] | None = None) -> None:
if self.__progress_bar is not None:
self.__progress_bar.close()
self.__progress_bar = None

def on_iteration_end(
self, iteration: int, logs: dict[str, Any] | None = None
) -> None:
pass

def on_function_call_end(
self, function_call: int, logs: dict[str, Any] | None = None
) -> None:
if self.__progress_bar is None:
return
if logs is not None:
estimator_value = logs.get("estimator", {}).get("value")
if estimator_value is not None:
self.__progress_bar.set_postfix({"estimator": estimator_value})
self.__progress_bar.update()


class CSVSummary(Callback, Loadable):
"""Log fit parameters and the estimator value to a CSV file."""

Expand Down
4 changes: 0 additions & 4 deletions src/tensorwaves/optimizer/minuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import TYPE_CHECKING, Any

import iminuit
from tqdm.auto import tqdm

from tensorwaves.interface import Estimator, FitResult, Optimizer, ParameterValue
from tensorwaves.optimizer.callbacks import Callback, _create_log
Expand Down Expand Up @@ -63,7 +62,6 @@ def optimize(
parameter_handler = ParameterFlattener(initial_parameters)
flattened_parameters = parameter_handler.flatten(initial_parameters)

progress_bar = tqdm(disable=_LOGGER.level > logging.WARNING)
n_function_calls = 0

parameters = parameter_handler.unflatten(flattened_parameters)
Expand All @@ -88,8 +86,6 @@ def wrapped_function(pars: list) -> float:
update_parameters(pars)
parameters = parameter_handler.unflatten(flattened_parameters)
estimator_value = float(estimator(parameters))
progress_bar.set_postfix({"estimator": estimator_value})
progress_bar.update()
if self.__callback is not None:
self.__callback.on_function_call_end(
n_function_calls,
Expand Down
5 changes: 0 additions & 5 deletions src/tensorwaves/optimizer/scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import time
from typing import TYPE_CHECKING, Any

from tqdm.auto import tqdm

from tensorwaves.function._backend import raise_missing_module_error
from tensorwaves.interface import Estimator, FitResult, Optimizer, ParameterValue
from tensorwaves.optimizer.parameter import ParameterFlattener
Expand Down Expand Up @@ -54,7 +52,6 @@ def optimize( # noqa: C901
parameter_handler = ParameterFlattener(initial_parameters)
flattened_parameters = parameter_handler.flatten(initial_parameters)

progress_bar = tqdm(disable=_LOGGER.level > logging.WARNING)
n_function_calls = 0
iterations = 0
estimator_value = 0.0
Expand Down Expand Up @@ -88,8 +85,6 @@ def wrapped_function(pars: list) -> float:
update_parameters(pars)
parameters = parameter_handler.unflatten(flattened_parameters)
estimator_value = estimator(parameters)
progress_bar.set_postfix({"estimator": estimator_value})
progress_bar.update()
if self.__callback is not None:
self.__callback.on_function_call_end(
n_function_calls,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_create_cached_function(backend):
symbols: tuple[sp.Symbol, ...] = sp.symbols("a b c d x y")
a, b, c, d, x, y = symbols
expression = a * x + b * (c * x + d * y**2)
parameter_defaults = {a: -2.5, b: 1.4, c: 0.8, d: 3.7}
parameter_defaults: dict[sp.Basic, int | float] = {a: -2.5, b: 1.4, c: 0.8, d: 3.7}

function = create_parametrized_function(expression, parameter_defaults, backend)
cached_function, cache_transformer = create_cached_function(
Expand Down
Loading
Loading