diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
index 67a761c8..1fabd502 100644
--- a/.github/copilot-instructions.md
+++ b/.github/copilot-instructions.md
@@ -42,7 +42,8 @@ and UPPER_SNAKE_CASE for constants.
 - Use `from __future__ import annotations` in every module.
 - Type-annotate all public function signatures.
-- Docstrings on all public classes and methods (numpy style).
+- Docstrings on all public classes and methods (numpy style). These must
+  include Parameters, Returns, and Raises sections where applicable.
 - Prefer flat over nested, explicit over clever.
 - Write straightforward code; do not add defensive checks for unlikely
   edge cases.
@@ -147,6 +148,8 @@ `docs/architecture/architecture.md`.
 - After changes, run linting and formatting fixes with `pixi run fix`.
   Do not check what was auto-fixed, just accept the fixes and move on.
+  Then run linting and formatting checks with `pixi run check` and
+  address any remaining issues until the code is clean.
 - After changes, run unit tests with `pixi run unit-tests`.
 - After changes, run integration tests with `pixi run integration-tests`.
diff --git a/docs/architecture/architecture.md b/docs/architecture/architecture.md
index f4c6fc52..8a7a9369 100644
--- a/docs/architecture/architecture.md
+++ b/docs/architecture/architecture.md
@@ -188,7 +188,7 @@ GuardedBase
 └── GenericDescriptorBase         # name, value (validated via AttributeSpec), description
     ├── GenericStringDescriptor   # _value_type = DataTypes.STRING
     └── GenericNumericDescriptor  # _value_type = DataTypes.NUMERIC, + units
-        └── GenericParameter      # + free, uncertainty, fit_min, fit_max, constrained, uid
+        └── GenericParameter      # + free, uncertainty, fit_min, fit_max, constrained
 ```
 
 CIF-bound concrete classes add a `CifHandler` for serialisation:
@@ -714,12 +714,13 @@ Projects are saved as a directory of CIF files:
 ```shell
 project_dir/
 ├── project.cif          # ProjectInfo
-├── analysis.cif         # Analysis settings
 ├── summary.cif          # Summary report
 ├── structures/
 │   └── lbco.cif         # One file per structure
-└── experiments/
-    └── hrpt.cif         # One file per experiment
+├── experiments/
+│   └── hrpt.cif         # One file per experiment
+└── analysis/
+    └── analysis.cif     # Analysis settings
 ```
 
 ### 7.3 Verbosity
@@ -919,6 +920,10 @@ project.experiments['xray_pdf'].peak_profile_type = 'gaussian-damped-sinc'
 - `DatablockItem` = one CIF `data_` block, `DatablockCollection` = set of blocks.
 - `CategoryItem` = one CIF category, `CategoryCollection` = CIF loop.
+- **Free-flag encoding**: a parameter's free/fixed status is encoded in
+  CIF via uncertainty brackets. `3.89` = fixed, `3.89(2)` = free with
+  esd, `3.89()` = free without esd. There is no separate list of free
+  parameters; the brackets are the single source of truth.
 
 ### 9.2 Immutability of Experiment Type
 
diff --git a/docs/architecture/issues_closed.md b/docs/architecture/issues_closed.md
index a67edcbe..f95e3b89 100644
--- a/docs/architecture/issues_closed.md
+++ b/docs/architecture/issues_closed.md
@@ -4,6 +4,56 @@ Issues that have been fully resolved. Kept for historical reference.
 ---
 
+## Implement `Project.load()`
+
+**Resolution:** implemented `Project.load(dir_path)` as a classmethod
+that reads `project.cif`, `structures/*.cif`, `experiments/*.cif`, and
+`analysis/analysis.cif` (with fallback to `analysis.cif` at root for
+backward compatibility). Reconstructs the full project state, including
+alias parameter references via `_resolve_alias_references()`.
+Integration tests verify save → load → parameter comparison and save →
+load → fit → χ² comparison. Also used by `fit_sequential` workers to
+reconstruct projects from CIF strings.
+
+---
+
+## Eliminate Dummy `Experiments` Wrapper in Single-Fit Mode
+
+**Resolution:** refactored `Fitter.fit()` and `_residual_function()` to
+accept `experiments: list[ExperimentBase]` instead of requiring an
+`Experiments` collection. `Analysis.fit()` passes
+`experiments_list = [experiment]` in single-fit mode and
+`list(experiments.values())` in joint-fit mode. Removed the
+`object.__setattr__` hack that forced `_parent` on the dummy wrapper.
+
+---
+
+## Replace UID Map with Direct References and Auto-Apply Constraints
+
+**Resolution:** eliminated `UidMapHandler` and random UID generation
+from parameters entirely. Aliases now store a direct object reference to
+the parameter (`Alias._param_ref`) instead of a random UID string.
+`ConstraintsHandler.apply()` uses the direct reference — no map lookup.
+For CIF serialisation, `Alias._param_unique_name` stores the parameter's
+deterministic `unique_name`. `_minimizer_uid` now returns
+`unique_name.replace('.', '__')` instead of a random string.
+
+Also added `enable()`/`disable()` on `Constraints` with auto-enable on
+`create()`, replacing the manual `apply_constraints()` call.
+`Analysis._update_categories()` now always syncs handler state from the
+current aliases and constraints when `constraints.enabled` is `True`,
+eliminating stale-state bugs (former issue #4). `_set_value_constrained`
+bypasses validation, like `_set_value_from_minimizer`, since constraints
+run inside the minimiser loop. `Analysis.fit()` calls
+`_update_categories()` before collecting free parameters so that
+constrained parameters are correctly excluded.
+
+API change: `aliases.create(label=..., param_uid=...uid)` →
+`aliases.create(label=..., param=...)`. `apply_constraints()` removed;
+`constraints.create()` auto-enables.
+
+---
+
 ## Dirty-Flag Guard Was Disabled
 
 **Resolution:** added `_set_value_from_minimizer()` on
diff --git a/docs/architecture/issues_open.md b/docs/architecture/issues_open.md
index 05ed175f..b1ed9aa8 100644
--- a/docs/architecture/issues_open.md
+++ b/docs/architecture/issues_open.md
@@ -10,25 +10,6 @@ needed.
 
 ---
 
-## 1. 🔴 Implement `Project.load()`
-
-**Type:** Completeness
-
-`save()` serialises all components to CIF files but `load()` is a stub
-that raises `NotImplementedError`. Users cannot round-trip a project.
-
-**Why first:** this is the highest-severity gap. Without it the save
-functionality is only half useful — CIF files are written but cannot be
-read back. Tutorials that demonstrate save/load are blocked.
-
-**Fix:** implement `load()` that reads CIF files from the project
-directory and reconstructs structures, experiments, and analysis
-settings.
-
-**Depends on:** nothing (standalone).
-
----
-
 ## 2. 🟡 Restore Minimiser Variant Support
 
 **Type:** Feature loss + Design limitation
@@ -83,31 +64,6 @@ exactly match `project.experiments.names`.
 
 ---
 
-## 4. 🔴 Refresh Constraint State Before Automatic Updates and Fitting
-
-**Type:** Correctness
-
-`ConstraintsHandler` is only synchronised from `analysis.aliases` and
-`analysis.constraints` when the user explicitly calls
-`project.analysis.apply_constraints()`. The normal fit / serialisation
-path calls `constraints_handler.apply()` directly, so newly added or
-edited aliases and constraints can be ignored until that manual sync
-step happens.
-
-**Why high:** this produces silently incorrect results. A user can
-define constraints, run a fit, and believe they were applied when the
-active singleton still contains stale state from a previous run or no
-state at all.
-
-**Fix:** before any automatic constraint application, always refresh the
-singleton from the current `Aliases` and `Constraints` collections. The
-sync should happen inside `Analysis._update_categories()` or inside the
-constraints category itself, not only in a user-facing helper method.
-
-**Depends on:** nothing.
-
----
-
 ## 5. 🟡 Make `Analysis` a `DatablockItem`
 
 **Type:** Consistency
@@ -150,24 +106,6 @@ effectively fixed after experiment creation.
 
 ---
 
-## 7. 🟡 Eliminate Dummy `Experiments` Wrapper in Single-Fit Mode
-
-**Type:** Fragility
-
-Single-fit mode creates a throw-away `Experiments` collection per
-experiment, manually forces `_parent` via `object.__setattr__`, and
-passes it to `Fitter`. This bypasses `GuardedBase` parent tracking and
-is fragile.
-
-**Fix:** make `Fitter.fit()` accept a list of experiment objects (or a
-single experiment) instead of requiring an `Experiments` collection. Or
-add a `fit_single(experiment)` method.
-
-**Depends on:** nothing, but simpler after issue 5 (Analysis refactor)
-clarifies the fitting orchestration.
-
----
-
 ## 8. 🟡 Add Explicit `create()` Signatures on Collections
 
 **Type:** API safety
@@ -339,21 +277,18 @@ re-derivable default.
 
 ## Summary
 
-| #   | Issue                                      | Severity | Type                    |
-| --- | ------------------------------------------ | -------- | ----------------------- |
-| 1   | Implement `Project.load()`                 | 🔴 High  | Completeness            |
-| 2   | Restore minimiser variants                 | 🟡 Med   | Feature loss            |
-| 3   | Rebuild joint-fit weights                  | 🟡 Med   | Fragility               |
-| 4   | Refresh constraint state before auto-apply | 🔴 High  | Correctness             |
-| 5   | `Analysis` as `DatablockItem`              | 🟡 Med   | Consistency             |
-| 6   | Restrict `data_type` switching             | 🔴 High  | Correctness/Data safety |
-| 7   | Eliminate dummy `Experiments`              | 🟡 Med   | Fragility               |
-| 8   | Explicit `create()` signatures             | 🟡 Med   | API safety              |
-| 9   | Future enum extensions                     | 🟢 Low   | Design                  |
-| 10  | Unify update orchestration                 | 🟢 Low   | Maintainability         |
-| 11  | Document `_update` contract                | 🟢 Low   | Maintainability         |
-| 12  | CIF round-trip integration test            | 🟢 Low   | Quality                 |
-| 13  | Suppress redundant dirty-flag sets         | 🟢 Low   | Performance             |
-| 14  | Finer-grained change tracking              | 🟢 Low   | Performance             |
-| 15  | Validate joint-fit weights                 | 🟡 Med   | Correctness             |
-| 16  | Persist per-experiment `calculator_type`   | 🟡 Med   | Completeness            |
+| #   | Issue                                    | Severity | Type                    |
+| --- | ---------------------------------------- | -------- | ----------------------- |
+| 2   | Restore minimiser variants               | 🟡 Med   | Feature loss            |
+| 3   | Rebuild joint-fit weights                | 🟡 Med   | Fragility               |
+| 5   | `Analysis` as `DatablockItem`            | 🟡 Med   | Consistency             |
+| 6   | Restrict `data_type` switching           | 🔴 High  | Correctness/Data safety |
+| 8   | Explicit `create()` signatures           | 🟡 Med   | API safety              |
+| 9   | Future enum extensions                   | 🟢 Low   | Design                  |
+| 10  | Unify update orchestration               | 🟢 Low   | Maintainability         |
+| 11  | Document `_update` contract              | 🟢 Low   | Maintainability         |
+| 12  | CIF round-trip integration test          | 🟢 Low   | Quality                 |
+| 13  | Suppress redundant dirty-flag sets       | 🟢 Low   | Performance             |
+| 14  | Finer-grained change tracking            | 🟢 Low   | Performance             |
+| 15  | Validate joint-fit weights               | 🟡 Med   | Correctness             |
+| 16  | Persist per-experiment `calculator_type` | 🟡 Med   | Completeness            |
diff --git a/docs/architecture/sequential_fitting_design.md b/docs/architecture/sequential_fitting_design.md
index fd10a104..0513c89f 100644
--- a/docs/architecture/sequential_fitting_design.md
+++ b/docs/architecture/sequential_fitting_design.md
@@ -1,7 +1,7 @@
 # Sequential Fitting — Architecture Design
 
-**Status:** Draft — for discussion before implementation **Date:**
-2026-04-02
+**Status:** Implementation in progress (PRs 1–13 complete; PR 14
+optional) **Date:** 2026-04-02 (updated 2026-04-03)
 
 ---
 
@@ -845,10 +845,11 @@ here.
     `analysis/` directory. All analysis artifacts (settings + results)
     live under one directory. See § 5.4 and § 9.6.
 
-11. **Singletons (`UidMapHandler`, `ConstraintsHandler`)** → replace
-    with instance-owned state on `Project` and `Analysis`. Fixes
-    notebook rerun issues, simplifies worker isolation, resolves issue
-    #4. See § 9.5.
+11. **Singletons (`UidMapHandler`, `ConstraintsHandler`)** →
+    `UidMapHandler` eliminated (aliases use direct references +
+    `unique_name`). `ConstraintsHandler` stays singleton but is now
+    always synced before use. Fixes notebook rerun issues, resolves
+    issue #4. See § 9.5.
 
 ---
 
@@ -857,51 +858,41 @@ These changes are needed before implementing `fit_sequential()` itself.
 Each is a separate, atomic change.
 
-### 9.1 Switch alias `param_uid` to `param_unique_name`
+### 9.1 Switch alias `param_uid` to `param_unique_name` ✅
 
-The `Alias` category currently stores `param_uid` (random UID). Change
-to `param_unique_name` (deterministic `unique_name`). Update:
+**Done.** Went further than planned: eliminated `UidMapHandler` and
+random UIDs entirely. Aliases now store a direct object reference to the
+parameter (`Alias._param_ref`, runtime) plus `Alias._param_unique_name`
+(`StringDescriptor`, CIF serialisation with tag
+`_alias.param_unique_name`). `ConstraintsHandler.apply()` uses the
+direct reference — no map lookup needed. `_minimizer_uid` returns
+`unique_name.replace('.', '__')` instead of a random string. All
+tutorials, tests, and call sites updated.
 
-- `Alias._param_uid` → `Alias._param_unique_name`
-- `CifHandler(names=['_alias.param_uid'])` →
-  `CifHandler(names=['_alias.param_unique_name'])`
-- `ConstraintsHandler` to resolve via `unique_name` lookup instead of
-  UID lookup.
-- `UidMapHandler` — may no longer be needed for constraint resolution
-  (but still used for other purposes).
-- Tutorial `ed-17.py` and any tests that create aliases.
+### 9.2 Fix `category_collection_to_cif` truncation ✅
 
-### 9.2 Fix `category_collection_to_cif` truncation
+**Done.** `category_collection_to_cif` default changed to
+`max_display=None` (emit all rows). Truncation is opt-in via an explicit
+`max_display` parameter, used only by display methods.
 
-`category_collection_to_cif` has `max_display=20` which truncates loop
-output. For CIF used in save/load/round-trip, all rows must be emitted.
+### 9.3 Verify CIF round-trip for experiments ✅
 
-Options:
+**Done.** Five integration tests in `test_cif_round_trip.py`:
 
-- (a) Remove `max_display` from `category_collection_to_cif` entirely,
-  add truncation only in display methods.
-- (b) Add a `full=True` parameter and use it when serialising for
-  persistence.
+1. Parameter values survive `as_cif` → `from_cif_str`.
+2. Free flags survive the round-trip.
+3. Category collections (background, excluded regions, linked phases)
+   preserve item count.
+4. Data points survive (count, first/last values).
+5. Structure round-trip with symmetry constraints.
 
-### 9.3 Verify CIF round-trip for experiments
+### 9.4 Add `destination` parameter to `extract_data_paths_from_zip` ✅
 
-Write an integration test:
+**Done.** Optional `destination` parameter added. When provided,
+extracts to the given directory. When `None`, uses a temporary directory
+(original behaviour).
 
-1. Create a fully configured experiment (instrument, peak, background,
-   excluded regions, linked phases, data).
-2. Serialise to CIF (`experiment.as_cif`).
-3. Reconstruct from CIF (`ExperimentFactory.from_cif_str(cif_str)`).
-4. Compare all parameter values.
-
-Fix any parameters that don't survive the round-trip.
-
-### 9.4 Add `destination` parameter to `extract_data_paths_from_zip`
-
-Currently extracts to a temp dir. Add optional `destination` parameter
-to extract to a user-specified directory, enabling a clean two-step
-workflow (extract → fit_sequential).
-
-### 9.5 Replace singletons with instance-owned state
+### 9.5 Replace singletons with instance-owned state (partially done)
 
 #### Problem
 
@@ -927,66 +918,41 @@ workflow (extract → fit_sequential).
 the same session (e.g. to compare fits), their constraints and UID
 maps collide in the shared singleton.
 
-#### Proposed fix
-
-Move the state owned by singletons into `Analysis` (for constraints) and
-`Project` (for the UID map):
-
-| Current singleton    | New owner                      | Lifetime                |
-| -------------------- | ------------------------------ | ----------------------- |
-| `ConstraintsHandler` | `Analysis._constraints_engine` | Per-`Analysis` instance |
-| `UidMapHandler`      | `Project._uid_map`             | Per-`Project` instance  |
-
-The objects are the same classes, just no longer singletons — they are
-instantiated in `__init__` and passed explicitly to the components that
-need them (e.g. `Parameter.__init__` receives a `uid_map` reference from
-its owning project, `ConstraintsHandler` is accessed via
-`self.project.analysis._constraints_engine`).
+#### Current status
 
-#### Impact on sequential fitting
+**`UidMapHandler`: eliminated entirely.** Random UIDs and the global
+UID-to-Parameter map have been removed. Aliases store direct object
+references at runtime and deterministic `unique_name` strings for CIF
+serialisation. This fully resolves problems 1–3 for the UID map.
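The `unique_name`-based scheme described above can be sketched in two lines. This is an illustrative stand-in for `_minimizer_uid`, whose documented behaviour is `unique_name.replace('.', '__')`; the example name `lbco.cell.length_a` is an assumption (the docs only name the `lbco` structure), and the reverse mapping shown here is only safe while component names contain no literal `__`.

```python
# Sketch of the deterministic minimiser-id mapping described above:
# dotted unique_name -> dot-free identifier for minimiser libraries.
def minimizer_uid(unique_name: str) -> str:
    return unique_name.replace('.', '__')

# Reverse mapping (illustrative; assumes no '__' inside name parts).
def unique_name_from_uid(uid: str) -> str:
    return uid.replace('__', '.')
```

Because the id is derived, not random, the same parameter always maps to the same minimiser variable across saves, reloads, and worker processes.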
 
-- **Simplifies workers:** each worker's `Project()` naturally creates
-  its own `_uid_map` and `_constraints_engine`. No singleton isolation
-  concern at all.
-- **Simplifies crash recovery and notebook reruns:** creating a new
-  `Project` starts with a blank slate, no stale state leaks.
-- **No impact on the `fit_sequential` API** — the change is purely
-  internal.
+**`ConstraintsHandler`: stale-state bug fixed, still a singleton.**
+`Analysis._update_categories()` now always syncs the handler from the
+current `aliases` and `constraints` before calling `apply()`. This
+resolves problem 1 (notebook reruns) and problem 2 (worker isolation is
+natural with `spawn`). Problem 3 (multiple projects) remains theoretical
+— if multi-project support becomes a real need, moving
+`ConstraintsHandler` to instance scope is a standalone follow-up.
 
-#### Scope and sequencing
+#### Remaining work (optional)
 
-This is a self-contained refactor that can be done independently of
-sequential fitting. It improves correctness for existing workflows
-(notebook reruns, issue #4) and simplifies the sequential fitting
-implementation. It is listed as a prerequisite because it eliminates a
-class of bugs that would otherwise need workaround code in the worker.
-
-However, if the refactor proves too large for the initial sequential
-fitting work, the `spawn`-based multiprocessing provides natural
-isolation and the singletons can be addressed in a follow-up. The
-sequential fitting design does **not** depend on this change — it works
-either way.
+Move `ConstraintsHandler` from singleton to per-`Analysis` instance.
+This only matters for the multiple-projects edge case. The sequential
+fitting design does **not** depend on this change.
 
 #### Relationship to issue #4
 
-Open issue #4 ("Refresh constraint state before auto-apply") is a
-symptom of the singleton problem. If constraints are instance-owned,
-there is no stale state to refresh — the constraint engine always
-reflects the current `Analysis` instance's aliases and constraints.
-Fixing the singleton issue resolves issue #4 as a side effect.
-
-### 9.6 Move `analysis.cif` into `analysis/` directory
+Issue #4 ("Refresh constraint state before auto-apply") is **fully
+resolved.** `_update_categories()` syncs handler state on every call.
+Constraints auto-enable on `create()` and are applied before fitting
+starts. The manual `apply_constraints()` method has been removed.
 
-Currently `analysis.cif` lives at the project root alongside
-`project.cif` and `summary.cif`. Adding an `analysis/` directory for
-`results.csv` next to a file named `analysis.cif` at the same level
-creates a naming conflict and a confusing layout.
+### 9.6 Move `analysis.cif` into `analysis/` directory ✅
 
-**Fix:** update `Project.save()` to write `analysis.cif` to
-`project_dir/analysis/analysis.cif`. Update `Project.load()` (when
-implemented) to read from the new path, with a fallback to the old path
-for backward compatibility with existing saved projects. Update docs
-(`architecture.md`, `project.md`), tests, and the save output messages.
+**Done.** `Project.save()` writes to `analysis/analysis.cif`.
+`Project.load()` checks `analysis/analysis.cif` first, falls back to
+`analysis.cif` at root for backward compatibility. Unit tests verify
+both layouts.
 
 ---
 
@@ -999,7 +965,7 @@ resolved first because they clean up the fitting internals that
 ### Foundation PRs (resolve existing issues)
 
-#### PR 1 — Eliminate dummy Experiments wrapper in single-fit mode (issue #7)
+#### PR 1 — Eliminate dummy Experiments wrapper in single-fit mode (issue #7) ✅
 
 > **Title:** `Accept single Experiment in Fitter.fit()`
 >
@@ -1011,38 +977,36 @@ resolved first because they clean up the fitting internals that
 > Update all callers (single-fit, joint-fit). Update unit and
 > integration tests.
 
-**Why first:** the current dummy-wrapper pattern is the exact
-antipattern that `fit_sequential` workers would otherwise inherit.
-Fixing it now gives the worker a clean
-`Fitter.fit(structures, [experiment])` call without any collection
-ceremony.
+**Done.** `Fitter.fit()` and `_residual_function()` now accept
+`experiments: list[ExperimentBase]`. `Analysis.fit()` passes
+`experiments_list = [experiment]` in single-fit mode and
+`list(experiments.values())` in joint-fit mode. No more dummy
+`Experiments` wrapper or `object.__setattr__` hack.
 
-#### PR 2 — Replace singletons with instance-owned state (issue #4 + § 9.5)
+#### PR 2 — Replace UID map with direct references and auto-apply constraints (issue #4 + § 9.5) ✅
 
 > **Title:**
-> `Move ConstraintsHandler and UidMapHandler to instance scope`
+> `Replace UID map with direct references and auto-apply constraints`
 >
-> **Description:** Replace the `SingletonBase` pattern for
-> `ConstraintsHandler` and `UidMapHandler` with per-project instances.
-> `Project.__init__` creates `_uid_map`; `Analysis.__init__` creates
-> `_constraints_engine`. Thread the references through to `Parameter`
-> and constraint resolution. Remove `SingletonBase` class if no longer
-> used. Update all call sites that use `.get()`. This also fixes issue
-> #4 (stale constraint state) as a side effect — the constraint engine
-> is always in sync with its owning `Analysis`.
-
-**Why second:** removes the global mutable state that makes notebook
-reruns unreliable and multi-project sessions impossible. Sequential
-fitting workers benefit from natural isolation (each `Project()` has its
-own engine), but the main benefit is correctness for existing workflows.
-
-This is a sub-step breakdown if the PR proves too large:
-
-- **PR 2a:** `Move UidMapHandler to Project instance scope`
-- **PR 2b:** `Move ConstraintsHandler to Analysis instance scope`
-- **PR 2c:** `Remove SingletonBase if unused`
-
-#### PR 3 — Implement Project.load() (issue #1)
+> **Description:** Eliminated `UidMapHandler` and random UID generation
+> entirely. Aliases store direct parameter object references at runtime
+> and deterministic `unique_name` strings for CIF. Added
+> `enable()`/`disable()` on `Constraints` with auto-enable on
+> `create()`, replacing the manual `apply_constraints()` call.
+> `Analysis._update_categories()` always syncs handler state when
+> constraints are enabled. Also fixes issue #4 (stale constraint state)
+> and completes PR 4 (alias `param_unique_name`).
+
+**Why second:** removes the global UID map that made constraint
+resolution opaque and fragile. The stale-state bug (issue #4) is fully
+fixed. `ConstraintsHandler` remains a singleton but is now always in
+sync — moving it to instance scope is an optional follow-up for the
+multi-project edge case.
+
+This PR also absorbed PR 4 (§ 9.1), since switching from random UIDs to
+`unique_name` was a natural part of the same change.
+
+#### PR 3 — Implement Project.load() (issue #1) ✅
 
 > **Title:** `Implement Project.load() from CIF directory`
 >
@@ -1053,25 +1017,20 @@ This is a sub-step breakdown if the PR proves too large:
 > as a fallback. Add integration test: save → load → compare all
 > parameter values.
 
-**Why third:** the CIF round-trip reliability that `load()` proves is
-the same reliability that `fit_sequential` workers depend on (they
-reconstruct a project from CIF strings). Implementing `load()` forces us
-to fix any serialisation gaps before they become worker bugs. Phase 3
-(dataset replay) also directly uses `load()`.
+**Done.** `Project.load()` reads CIF files from the project directory,
+reconstructs structures, experiments, and analysis. Resolves alias
+`param_unique_name` strings back to live `Parameter` references.
+Integration tests verify save → load → parameter comparison and save →
+load → fit → χ² comparison.
 
 ### Sequential-fitting prerequisite PRs
 
-#### PR 4 — Switch alias param_uid to param_unique_name (§ 9.1)
+#### PR 4 — Switch alias param_uid to param_unique_name (§ 9.1) ✅
 
-> **Title:** `Use unique_name instead of random UID in aliases`
->
-> **Description:** Rename `Alias._param_uid` to
-> `Alias._param_unique_name`. Update `CifHandler` names. Change
-> `ConstraintsHandler` to resolve parameters via `unique_name` lookup
-> instead of UID. Update `ed-17.py` tutorial and all tests that create
-> aliases.
+> Absorbed into PR 2. Aliases now use `param_unique_name` with direct
+> object references. All tutorials and tests updated.
 
-#### PR 5 — Fix CIF collection truncation (§ 9.2)
+#### PR 5 — Fix CIF collection truncation (§ 9.2) ✅
 
 > **Title:** `Remove max_display truncation from CIF serialisation`
 >
@@ -1080,7 +1039,11 @@ to fix any serialisation gaps before they become worker bugs. Phase 3
 > (`show_as_cif()`). Ensures experiments with many background/data
 > points survive CIF round-trips.
 
-#### PR 6 — Verify CIF round-trip for experiments (§ 9.3)
+**Done.** `category_collection_to_cif` default changed to
+`max_display=None` (emit all rows). Truncation is now opt-in, used only
+by display methods.
+
+#### PR 6 — Verify CIF round-trip for experiments (§ 9.3) ✅
 
 > **Title:** `Add CIF round-trip integration test for experiments`
 >
@@ -1090,7 +1053,11 @@ to fix any serialisation gaps before they become worker bugs. Phase 3
 > asserts all parameter values match. Fix any parameters that don't
 > survive the round-trip.
 
-#### PR 7 — Move analysis.cif into analysis/ directory (§ 9.6)
+**Done.** Five integration tests in `test_cif_round_trip.py`: parameter
+values, free flags, categories (background/excluded regions/linked
+phases), data points, and structure round-trip.
+
+#### PR 7 — Move analysis.cif into analysis/ directory (§ 9.6) ✅
 
 > **Title:** `Move analysis.cif into analysis/ directory`
 >
@@ -1099,7 +1066,12 @@ to fix any serialisation gaps before they become worker bugs. Phase 3
 > from the new path (with fallback to old path). Update docs
 > (`architecture.md`, `project.md`), tests, and console output messages.
 
-#### PR 8 — Add destination to extract_data_paths_from_zip (§ 9.4)
+**Done.** `Project.save()` writes to `analysis/analysis.cif`.
+`Project.load()` checks `analysis/analysis.cif` first, falls back to
+`analysis.cif` at root for backward compatibility. Unit tests verify
+both layouts.
+
+#### PR 8 — Add destination to extract_data_paths_from_zip (§ 9.4) ✅
 
 > **Title:** `Add destination parameter to extract_data_paths_from_zip`
 >
@@ -1108,9 +1080,13 @@ to fix any serialisation gaps before they become worker bugs. Phase 3
 > directory instead of a temp dir. Enables clean two-step workflow:
 > extract ZIP → pass directory to `fit_sequential()`.
 
+**Done.** `extract_data_paths_from_zip` accepts a `destination`
+parameter. When provided, extracts to the given directory. When `None`,
+uses a temporary directory (original behaviour).
+
 ### Sequential-fitting core PRs
 
-#### PR 9 — Streaming sequential fit (max_workers=1)
+#### PR 9 — Streaming sequential fit (max_workers=1) ✅
 
 > **Title:** `Add fit_sequential() for streaming single-worker fitting`
 >
@@ -1122,16 +1098,16 @@ to fix any serialisation gaps before they become worker bugs. Phase 3
 > `extract_diffrn` callback support for metadata columns. Unit tests for
 > CSV writing, crash recovery, parameter propagation.
 
-This is a sub-step breakdown if the PR proves too large:
-
-- **PR 9a:** `Add SequentialFitTemplate and _fit_worker function` —
-  dataclass, worker function, no CSV, no recovery.
-- **PR 9b:** `Add CSV output and crash recovery to fit_sequential` — CSV
-  writing, reading, resumption logic.
-- **PR 9c:** `Add parameter propagation and extract_diffrn callback` —
-  chunk-to-chunk seeding, diffrn metadata columns.
+**Done.** Full implementation in `analysis/sequential.py`:
+`SequentialFitTemplate` dataclass, `_fit_worker()` module-level
+function, CSV helpers (`_build_csv_header`, `_write_csv_header`,
+`_append_to_csv`, `_read_csv_for_recovery`), `_build_template()`,
+chunk-based processing with parameter propagation, `extract_diffrn`
+callback support, and progress reporting. Five integration tests in
+`test_sequential.py`: CSV production, crash recovery, parameter
+propagation, diffrn callback, precondition validation.
 
-#### PR 10 — Update plot_param_series to read from CSV
+#### PR 10 — Update plot_param_series to read from CSV ✅
 
 > **Title:** `Unify plot_param_series to always read from CSV`
 >
@@ -1141,7 +1117,13 @@ This is a sub-step breakdown if the PR proves too large:
 > and existing `fit()` single-mode (Phase 4). Remove the old
 > `_parameter_snapshots` dict.
 
-#### PR 11 — Parallel fitting (max_workers > 1)
+**Implemented:** `Plotter.plot_param_series()` reads CSV via pandas.
+`Plotter.plot_param_series_from_snapshots()` preserves backward
+compatibility for `fit()` single-mode (no CSV yet).
+`Project.plot_param_series()` tries CSV first, falls back to snapshots.
+Axis labels are derived from live descriptor objects.
+
+#### PR 11 — Parallel fitting (max_workers > 1) ✅
 
 > **Title:** `Add multiprocessing support to fit_sequential`
 >
@@ -1151,9 +1133,16 @@ This is a sub-step breakdown if the PR proves too large:
 > `max_workers='auto'` support (`os.cpu_count()`). Integration test:
 > parallel sequential fit (10 files, 2 workers).
 
+**Implemented:** `ProcessPoolExecutor` with `mp.get_context('spawn')`
+and `max_tasks_per_child=100` dispatches chunks in parallel when
+`max_workers > 1`. Single-worker mode (`max_workers=1`) still calls
+`_fit_worker` directly (no subprocess overhead). `max_workers='auto'`
+resolves to `os.cpu_count()`. Integration test
+`test_fit_sequential_parallel` verifies 2-worker parallel fitting.
+
 ### Post-sequential PRs
 
-#### PR 12 — Dataset replay from CSV
+#### PR 12 — Dataset replay from CSV ✅
 
 > **Title:** `Add apply_params_from_csv() for dataset replay`
 >
@@ -1162,7 +1151,12 @@ This is a sub-step breakdown if the PR proves too large:
 > reloads data from the file path in that row. Enables
 > `plot_meas_vs_calc()` for any previously fitted dataset.
 
-#### PR 13 — CSV output for existing single-fit mode
+**Implemented:** `Project.apply_params_from_csv(row_index)` reads a CSV
+row, overrides parameter values and uncertainties, and reloads measured
+data when `file_path` points to a real file (the sequential-fit case).
+Three integration tests: parameter override, missing CSV, out-of-range
+index.
+
+#### PR 13 — CSV output for existing single-fit mode ✅
 
 > **Title:** `Write results.csv from existing single-fit mode`
 >
@@ -1171,6 +1165,14 @@ This is a sub-step breakdown if the PR proves too large:
 > `fit_sequential`). This gives `ed-17.py`-style workflows persistent
 > CSV output and unified `plot_param_series()`.
 
+**Implemented:** `Analysis.fit()` single-mode now writes
+`analysis/results.csv` incrementally (one row per experiment) when the
+project has been saved. Reuses `_META_COLUMNS`, `_write_csv_header`, and
+`_append_to_csv` from `sequential.py`. Diffrn metadata and free
+parameter values/uncertainties are written per row. The in-memory
+`_parameter_snapshots` is kept as an unsaved-project fallback.
+`plot_param_series()` now uses CSV for saved projects automatically.
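The pool setup described for PR 11 can be sketched as below. This is a minimal illustration under stated assumptions, not the library's code: `fit_chunk` is a hypothetical stand-in for the real per-chunk worker, and `max_tasks_per_child` on `ProcessPoolExecutor` requires Python 3.11+.

```python
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor

def fit_chunk(chunk: list[str]) -> list[str]:
    # Hypothetical stand-in for the real per-chunk fitting worker.
    return [f'fitted:{name}' for name in chunk]

def run_parallel(chunks: list[list[str]], max_workers: int) -> list[str]:
    if max_workers == 1:
        # Single-worker mode: call directly, no subprocess overhead.
        results = [fit_chunk(chunk) for chunk in chunks]
    else:
        # 'spawn' avoids fork-related issues with Fortran global state;
        # max_tasks_per_child bounds worker memory growth over long runs.
        ctx = mp.get_context('spawn')
        with ProcessPoolExecutor(max_workers=max_workers,
                                 mp_context=ctx,
                                 max_tasks_per_child=100) as pool:
            results = list(pool.map(fit_chunk, chunks))
    return [row for chunk_result in results for row in chunk_result]
```

The `max_workers == 1` branch mirrors the documented behaviour of calling the worker directly rather than paying subprocess startup cost for a serial run.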
+
 #### PR 14 (optional) — Parallel single-fit for pre-loaded experiments
 
 > **Title:**
@@ -1185,25 +1187,24 @@ This is a sub-step breakdown if the PR proves too large:
 
 ### Dependency graph
 
 ```
-PR 1 (issue #7: eliminate dummy Experiments)
- └─► PR 2 (issue #4: singletons → instance-owned)
-      └─► PR 3 (issue #1: Project.load)
-           └─► PR 4 (alias unique_name)
-                └─► PR 5 (CIF truncation)
-                     └─► PR 6 (CIF round-trip test)
-                          ├─► PR 7 (analysis.cif → analysis/)
-                          │    └─► PR 9 (streaming sequential fit)
-                          │         ├─► PR 10 (plot from CSV)
-                          │         │    └─► PR 13 (CSV for existing fit)
-                          │         └─► PR 11 (parallel fitting)
-                          │              └─► PR 14 (optional: parallel fit())
-                          └─► PR 8 (zip destination)
-                               └─► PR 12 (dataset replay)
+PR 1 (issue #7: eliminate dummy Experiments) ✅
+ └─► PR 2 (issue #4: UID map + constraints) ✅
+      └─► PR 3 (issue #1: Project.load) ✅
+           └─► PR 5 (CIF truncation) ✅
+                └─► PR 6 (CIF round-trip test) ✅
+                     ├─► PR 7 (analysis.cif → analysis/) ✅
+                     │    └─► PR 9 (streaming sequential fit) ✅
+                     │         ├─► PR 10 (plot from CSV) ✅
+                     │         │    └─► PR 13 (CSV for existing fit) ✅
+                     │         └─► PR 11 (parallel fitting) ✅
+                     │              └─► PR 14 (optional: parallel fit())
+                     └─► PR 8 (zip destination) ✅
+                          └─► PR 12 (dataset replay) ✅
 ```
 
-Note: PRs 4–8 are largely independent of each other and can be
-parallelised or reordered as long as PRs 1–3 are done first and PRs 4–6
-are done before PR 9.
+Note: PR 4 was absorbed into PR 2. PRs 5–8 are largely independent of
+each other and can be parallelised or reordered as long as PRs 1–3 are
+done first and PRs 5–6 are done before PR 9.
 
 ---
 
@@ -1216,43 +1217,52 @@ are all stdlib.
### Risks -| Risk | Mitigation | -| ------------------------------------------------ | -------------------------------------------------------- | -| CIF round-trip loses information | PR 3 (load) + PR 6 (round-trip test) verify before PR 9 | -| CIF collection truncation at 20 rows | PR 5 fixes before PR 9 | -| Worker memory leak (large N, long-running pool) | Use `max_tasks_per_child=100` on the pool | -| Pickling failures for SequentialFitTemplate | Keep it a plain dataclass with only str/dict/list fields | -| crysfml Fortran global state in forked processes | Enforced `spawn` context avoids fork issues | - -### Resolved open issues (now prerequisites) - -- **Issue #7 (dummy Experiments wrapper):** resolved in PR 1. The worker - uses the clean `Fitter.fit(structures, [experiment])` API. -- **Issue #4 (constraint refresh) + § 9.5 (singletons):** resolved in - PR 2. Instance-owned constraint engine eliminates stale state. -- **Issue #1 (Project.load):** resolved in PR 3. CIF round-trip - reliability is proven before workers depend on it. Dataset replay - (PR 12) uses `load()` directly. +| Risk | Mitigation | +| ------------------------------------------------ | ----------------------------------------------------------- | +| CIF round-trip loses information | ✅ PR 3 (load) + PR 6 (round-trip test) verified | +| CIF collection truncation at 20 rows | ✅ PR 5 fixed (default `max_display=None`) | +| Worker memory leak (large N, long-running pool) | ✅ `max_tasks_per_child=100` on the pool (PR 11) | +| Pickling failures for SequentialFitTemplate | ✅ Keep it a plain dataclass with only str/dict/list fields | +| crysfml Fortran global state in forked processes | ✅ Enforced `spawn` context avoids fork issues (PR 11) | + +### Resolved open issues (now prerequisites) — all done ✅ + +- **Issue #7 (dummy Experiments wrapper):** resolved in PR 1. + `Fitter.fit()` and `_residual_function()` accept + `list[ExperimentBase]`. 
The worker uses the clean + `Fitter.fit(structures, [experiment])` API. +- **Issue #4 (constraint refresh) + § 9.1 (alias unique_name) + § 9.5 + (singletons):** resolved in PR 2. `UidMapHandler` eliminated; aliases + use direct object references and deterministic `unique_name` for CIF; + `_update_categories()` always syncs handler state; constraints + auto-enable on `create()`. `ConstraintsHandler` remains a singleton + but is always in sync — multi-project isolation is an optional + follow-up. +- **Issue #1 (Project.load):** resolved in PR 3. `Project.load()` reads + CIF files, reconstructs full project state, resolves alias + `param_unique_name` strings back to `Parameter` objects via + `_resolve_alias_references()`. Dataset replay (PR 12) uses `load()` + directly. --- ## 12. Summary -| Aspect | Decision | -| ------------------- | --------------------------------------------------------------------- | -| Parallelism backend | `concurrent.futures.ProcessPoolExecutor` with `spawn` | -| Worker isolation | Each worker creates a fresh `Project` — no shared state | -| Data source | `data_dir` argument; ZIP → extract first | -| Data flow | Template CIF + data path → worker → result dict → CSV | -| Parameter IDs | `unique_name` (deterministic), not `uid` (random) | -| Parameter seeding | Last successful result in chunk → next chunk | -| CSV location | `project_dir/analysis/results.csv` (deterministic) | -| CSV contents | Fit metrics + diffrn metadata + all free param values/uncert | -| Metadata extraction | User-provided `extract_diffrn` callback, not hidden in lib | -| Crash recovery | Read existing CSV, skip fitted files, resume | -| Plotting | Unified `plot_param_series()` always reads from CSV | -| Configuration | `max_workers` + `data_dir` on `fit_sequential()` | -| Project layout | `analysis.cif` moves into `analysis/` directory | -| Singletons | Replace with instance-owned state (recommended prerequisite) | -| New dependencies | None (stdlib only) | -| First step 
| PRs 1–3 (foundation issues), then PRs 4–8 (prerequisites), then PR 9+ | +| Aspect | Decision | Status | +| ------------------- | ---------------------------------------------------------------------------------- | ------ | +| Parallelism backend | `concurrent.futures.ProcessPoolExecutor` with `spawn` | ✅ | +| Worker isolation | Each worker creates a fresh `Project` — no shared state | ✅ | +| Data source | `data_dir` argument; ZIP → extract first | ✅ | +| Data flow | Template CIF + data path → worker → result dict → CSV | ✅ | +| Parameter IDs | `unique_name` (deterministic), not `uid` (random) | ✅ | +| Parameter seeding | Last successful result in chunk → next chunk | ✅ | +| CSV location | `project_dir/analysis/results.csv` (deterministic) | ✅ | +| CSV contents | Fit metrics + diffrn metadata + all free param values/uncert | ✅ | +| Metadata extraction | User-provided `extract_diffrn` callback, not hidden in lib | ✅ | +| Crash recovery | Read existing CSV, skip fitted files, resume | ✅ | +| Plotting | Unified `plot_param_series()` always reads from CSV | ✅ | +| Configuration | `max_workers` + `data_dir` on `fit_sequential()` | ✅ | +| Project layout | `analysis.cif` moves into `analysis/` directory | ✅ | +| Singletons | `UidMapHandler` eliminated; `ConstraintsHandler` stays singleton but always synced | ✅ | +| New dependencies | None (stdlib only) | ✅ | +| First step | PRs 1–13 done; PR 14 optional | ✅ | diff --git a/docs/docs/tutorials/ed-1.ipynb b/docs/docs/tutorials/ed-1.ipynb index d2e178a2..3b8085bb 100644 --- a/docs/docs/tutorials/ed-1.ipynb +++ b/docs/docs/tutorials/ed-1.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8dbf8e63", + "id": "e1c6f514", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-10.ipynb b/docs/docs/tutorials/ed-10.ipynb index 0c865ce0..5707c038 100644 --- a/docs/docs/tutorials/ed-10.ipynb +++ b/docs/docs/tutorials/ed-10.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": 
null, - "id": "f42176f2", + "id": "2e0ed9d7", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-11.ipynb b/docs/docs/tutorials/ed-11.ipynb index 92714987..30fc3e26 100644 --- a/docs/docs/tutorials/ed-11.ipynb +++ b/docs/docs/tutorials/ed-11.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b38dbf4f", + "id": "e0a12c6e", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-12.ipynb b/docs/docs/tutorials/ed-12.ipynb index 51bdc129..6aa16c6b 100644 --- a/docs/docs/tutorials/ed-12.ipynb +++ b/docs/docs/tutorials/ed-12.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "effff825", + "id": "edee23bc", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-13.ipynb b/docs/docs/tutorials/ed-13.ipynb index bb4dd4c2..66854a4c 100644 --- a/docs/docs/tutorials/ed-13.ipynb +++ b/docs/docs/tutorials/ed-13.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "263b6625", + "id": "1a143d79", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-14.ipynb b/docs/docs/tutorials/ed-14.ipynb index 5c1f6717..df9a857c 100644 --- a/docs/docs/tutorials/ed-14.ipynb +++ b/docs/docs/tutorials/ed-14.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "502ed03e", + "id": "80ba77ad", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-15.ipynb b/docs/docs/tutorials/ed-15.ipynb index 6e2b2547..60ec08ce 100644 --- a/docs/docs/tutorials/ed-15.ipynb +++ b/docs/docs/tutorials/ed-15.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "65ccac80", + "id": "7a35fc22", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-16.ipynb b/docs/docs/tutorials/ed-16.ipynb index 4cb7d1be..5d0aa659 100644 --- a/docs/docs/tutorials/ed-16.ipynb +++ b/docs/docs/tutorials/ed-16.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": 
"d57a9295", + "id": "0956c08b", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-17.ipynb b/docs/docs/tutorials/ed-17.ipynb index bbb422ca..e9369b01 100644 --- a/docs/docs/tutorials/ed-17.ipynb +++ b/docs/docs/tutorials/ed-17.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9fd2be7c", + "id": "79fce5be", "metadata": { "tags": [ "hide-in-docs" @@ -28,9 +28,9 @@ "\n", "This example demonstrates a Rietveld refinement of the Co2SiO4 crystal\n", "structure using constant-wavelength neutron powder diffraction data\n", - "from D20 at ILL. A sequential refinement of the same structure against\n", - "a temperature scan is performed to show how to manage multiple\n", - "experiments in a project." + "from D20 at ILL. A sequential refinement is performed against a\n", + "temperature scan using `fit_sequential`, which processes each data\n", + "file independently without loading all datasets into memory at once." ] }, { @@ -76,8 +76,8 @@ "id": "5", "metadata": {}, "source": [ - "Set output verbosity level to \"short\" to show only one-line status\n", - "messages during the analysis process." + "The project must be saved before running sequential fitting, so that\n", + "results can be written to `analysis/results.csv`." ] }, { @@ -87,7 +87,7 @@ "metadata": {}, "outputs": [], "source": [ - "project.verbosity = 'short'" + "project.save_as('data/cosio_project', temporary=False)" ] }, { @@ -229,10 +229,12 @@ "id": "15", "metadata": {}, "source": [ - "## Step 3: Define Experiments\n", + "## Step 3: Define Template Experiment\n", "\n", - "This section shows how to add experiments, configure their parameters,\n", - "and link the structures defined above.\n", + "For sequential fitting, we create a single template experiment from\n", + "the first data file. 
This template defines the instrument, peak\n", + "profile, background, and linked phases that will be reused for every\n", + "data file in the scan.\n", "\n", "#### Download Measured Data" ] }, { @@ -244,7 +246,7 @@ "metadata": {}, "outputs": [], "source": [ - "file_path = ed.download_data(id=27, destination='data')" + "zip_path = ed.download_data(id=27, destination='data')" ] }, { "cell_type": "markdown", "id": "17", "metadata": {}, "source": [ - "#### Create Experiments and Set Temperature" + "#### Extract Data Files" ] }, { "cell_type": "code", "execution_count": null, "id": "18", "metadata": {}, "outputs": [], "source": [ - "data_paths = ed.extract_data_paths_from_zip(file_path)\n", - "for i, data_path in enumerate(data_paths, start=1):\n", - " name = f'd20_{i}'\n", - " project.experiments.add_from_data_path(\n", - " name=name,\n", - " data_path=data_path,\n", - " )\n", - " expt = project.experiments[name]\n", - " expt.diffrn.ambient_temperature = ed.extract_metadata(\n", - " file_path=data_path,\n", - " pattern=r'^TEMP\\s+([0-9.]+)',\n", - " )" + "data_dir = 'data/d20_scan'\n", + "data_paths = ed.extract_data_paths_from_zip(zip_path, destination=data_dir)" ] }, { "cell_type": "markdown", "id": "19", "metadata": {}, "source": [ - "#### Set Instrument" + "#### Create Template Experiment from the First File" ] }, { "cell_type": "code", "execution_count": null, "id": "20", "metadata": {}, "outputs": [], "source": [ - "for expt in project.experiments:\n", - " expt.instrument.setup_wavelength = 1.87\n", - " expt.instrument.calib_twotheta_offset = 0.29" + "project.experiments.add_from_data_path(\n", + " name='d20',\n", + " data_path=data_paths[0],\n", + ")\n", + "expt = project.experiments['d20']" ] }, { "cell_type": "markdown", "id": "21", "metadata": {}, "source": [ - "#### Set Peak Profile" + "#### Set Instrument" ] }, { "cell_type": "code", "execution_count": null, "id": "22", "metadata": {}, "outputs": [], "source": [ - "for expt in project.experiments:\n", - " expt.peak.broad_gauss_u = 0.24\n", - " expt.peak.broad_gauss_v = -0.53\n", - " expt.peak.broad_gauss_w = 0.38\n", - " 
expt.peak.broad_lorentz_y = 0.02" + "expt.instrument.setup_wavelength = 1.87\n", + "expt.instrument.calib_twotheta_offset = 0.29" ] }, { @@ -323,7 +314,7 @@ "id": "23", "metadata": {}, "source": [ - "#### Set Excluded Regions" + "#### Set Peak Profile" ] }, { @@ -333,9 +324,10 @@ "metadata": {}, "outputs": [], "source": [ - "for expt in project.experiments:\n", - " expt.excluded_regions.create(id='1', start=0, end=8)\n", - " expt.excluded_regions.create(id='2', start=150, end=180)" + "expt.peak.broad_gauss_u = 0.24\n", + "expt.peak.broad_gauss_v = -0.53\n", + "expt.peak.broad_gauss_w = 0.38\n", + "expt.peak.broad_lorentz_y = 0.02" ] }, { @@ -343,7 +335,7 @@ "id": "25", "metadata": {}, "source": [ - "#### Set Background" + "#### Set Excluded Regions" ] }, { @@ -353,21 +345,8 @@ "metadata": {}, "outputs": [], "source": [ - "for expt in project.experiments:\n", - " expt.background.create(id='1', x=8, y=609)\n", - " expt.background.create(id='2', x=9, y=581)\n", - " expt.background.create(id='3', x=10, y=563)\n", - " expt.background.create(id='4', x=11, y=540)\n", - " expt.background.create(id='5', x=12, y=520)\n", - " expt.background.create(id='6', x=15, y=507)\n", - " expt.background.create(id='7', x=25, y=463)\n", - " expt.background.create(id='8', x=30, y=434)\n", - " expt.background.create(id='9', x=50, y=451)\n", - " expt.background.create(id='10', x=70, y=431)\n", - " expt.background.create(id='11', x=90, y=414)\n", - " expt.background.create(id='12', x=110, y=361)\n", - " expt.background.create(id='13', x=130, y=292)\n", - " expt.background.create(id='14', x=150, y=241)" + "expt.excluded_regions.create(id='1', start=0, end=8)\n", + "expt.excluded_regions.create(id='2', start=150, end=180)" ] }, { @@ -375,7 +354,7 @@ "id": "27", "metadata": {}, "source": [ - "#### Set Linked Phases" + "#### Set Background" ] }, { @@ -385,24 +364,54 @@ "metadata": {}, "outputs": [], "source": [ - "for expt in project.experiments:\n", - " expt.linked_phases.create(id='cosio', 
scale=1.2)" + "expt.background.create(id='1', x=8, y=609)\n", + "expt.background.create(id='2', x=9, y=581)\n", + "expt.background.create(id='3', x=10, y=563)\n", + "expt.background.create(id='4', x=11, y=540)\n", + "expt.background.create(id='5', x=12, y=520)\n", + "expt.background.create(id='6', x=15, y=507)\n", + "expt.background.create(id='7', x=25, y=463)\n", + "expt.background.create(id='8', x=30, y=434)\n", + "expt.background.create(id='9', x=50, y=451)\n", + "expt.background.create(id='10', x=70, y=431)\n", + "expt.background.create(id='11', x=90, y=414)\n", + "expt.background.create(id='12', x=110, y=361)\n", + "expt.background.create(id='13', x=130, y=292)\n", + "expt.background.create(id='14', x=150, y=241)" ] }, { "cell_type": "markdown", "id": "29", "metadata": {}, + "source": [ + "#### Set Linked Phases" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30", + "metadata": {}, + "outputs": [], + "source": [ + "expt.linked_phases.create(id='cosio', scale=1.2)" + ] + }, + { + "cell_type": "markdown", + "id": "31", + "metadata": {}, "source": [ "## Step 4: Perform Analysis\n", "\n", "This section shows how to set free parameters, define constraints,\n", - "and run the refinement." + "and run the sequential refinement." 
] }, { "cell_type": "markdown", - "id": "30", + "id": "32", "metadata": {}, "source": [ "#### Set Free Parameters" @@ -411,7 +420,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "33", "metadata": {}, "outputs": [], "source": [ @@ -442,27 +451,26 @@ { "cell_type": "code", "execution_count": null, - "id": "32", + "id": "34", "metadata": {}, "outputs": [], "source": [ - "for expt in project.experiments:\n", - " expt.linked_phases['cosio'].scale.free = True\n", + "expt.linked_phases['cosio'].scale.free = True\n", "\n", - " expt.instrument.calib_twotheta_offset.free = True\n", + "expt.instrument.calib_twotheta_offset.free = True\n", "\n", - " expt.peak.broad_gauss_u.free = True\n", - " expt.peak.broad_gauss_v.free = True\n", - " expt.peak.broad_gauss_w.free = True\n", - " expt.peak.broad_lorentz_y.free = True\n", + "expt.peak.broad_gauss_u.free = True\n", + "expt.peak.broad_gauss_v.free = True\n", + "expt.peak.broad_gauss_w.free = True\n", + "expt.peak.broad_lorentz_y.free = True\n", "\n", - " for point in expt.background:\n", - " point.y.free = True" + "for point in expt.background:\n", + " point.y.free = True" ] }, { "cell_type": "markdown", - "id": "33", + "id": "35", "metadata": {}, "source": [ "#### Set Constraints\n", @@ -473,23 +481,23 @@ { "cell_type": "code", "execution_count": null, - "id": "34", + "id": "36", "metadata": {}, "outputs": [], "source": [ "project.analysis.aliases.create(\n", " label='biso_Co1',\n", - " param_uid=structure.atom_sites['Co1'].b_iso.uid,\n", + " param=structure.atom_sites['Co1'].b_iso,\n", ")\n", "project.analysis.aliases.create(\n", " label='biso_Co2',\n", - " param_uid=structure.atom_sites['Co2'].b_iso.uid,\n", + " param=structure.atom_sites['Co2'].b_iso,\n", ")" ] }, { "cell_type": "markdown", - "id": "35", + "id": "37", "metadata": {}, "source": [ "Set constraints." 
@@ -498,91 +506,129 @@ { "cell_type": "code", "execution_count": null, - "id": "36", + "id": "38", "metadata": {}, "outputs": [], "source": [ - "project.analysis.constraints.create(\n", - " expression='biso_Co2 = biso_Co1',\n", - ")" + "project.analysis.constraints.create(expression='biso_Co2 = biso_Co1')" ] }, { "cell_type": "markdown", - "id": "37", + "id": "39", "metadata": {}, "source": [ - "Apply constraints." + "#### Run Single Fitting\n", + "\n", + "This is the fitting of the first dataset to optimize the initial\n", + "parameters for the sequential fitting. This step is optional but can\n", + "help with convergence and speed of the sequential fitting, especially\n", + "if the initial parameters are far from optimal." ] }, { "cell_type": "code", "execution_count": null, - "id": "38", + "id": "40", "metadata": {}, "outputs": [], "source": [ - "project.analysis.apply_constraints()" + "project.analysis.fit()" ] }, { "cell_type": "markdown", - "id": "39", + "id": "41", "metadata": {}, "source": [ - "#### Set Fit Mode" + "#### Run Sequential Fitting\n", + "\n", + "Set output verbosity level to \"short\" to show only one-line status\n", + "messages during the analysis process." ] }, { "cell_type": "code", "execution_count": null, - "id": "40", + "id": "42", "metadata": {}, "outputs": [], "source": [ - "project.analysis.fit_mode.mode = 'single'" + "project.verbosity = 'short'" ] }, { "cell_type": "markdown", - "id": "41", + "id": "43", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "\n", + "Define a callback that extracts the temperature from each data file." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44", + "metadata": {}, + "outputs": [], + "source": [ + "def extract_diffrn(file_path):\n", + " temperature = ed.extract_metadata(\n", + " file_path=file_path,\n", + " pattern=r'^TEMP\\s+([0-9.]+)',\n", + " )\n", + " return {'ambient_temperature': temperature}" + ] + }, + { + "cell_type": "markdown", + "id": "45", "metadata": {}, "source": [ - "#### Run Fitting" + "Run the sequential fit over all data files in the scan directory." ] }, { "cell_type": "code", "execution_count": null, - "id": "42", + "id": "46", "metadata": {}, "outputs": [], "source": [ - "project.analysis.fit()" + "project.analysis.fit_sequential(\n", + " data_dir=data_dir,\n", + " extract_diffrn=extract_diffrn,\n", + " max_workers='auto',\n", + ")" ] }, { "cell_type": "markdown", - "id": "43", + "id": "47", "metadata": {}, "source": [ - "#### Plot Measured vs Calculated" + "#### Replay a Dataset\n", + "\n", + "Apply fitted parameters from the last CSV row and plot the result." ] }, { "cell_type": "code", "execution_count": null, - "id": "44", + "id": "48", "metadata": {}, "outputs": [], "source": [ - "last_expt_name = project.experiments.names[-1]\n", - "project.plot_meas_vs_calc(expt_name=last_expt_name, show_residual=True)" + "project.apply_params_from_csv(row_index=-1)\n", + "project.plot_meas_vs_calc(expt_name='d20', show_residual=True)" ] }, { "cell_type": "markdown", - "id": "45", + "id": "49", "metadata": {}, "source": [ "#### Plot Parameter Evolution\n", @@ -593,16 +639,16 @@ { "cell_type": "code", "execution_count": null, - "id": "46", + "id": "50", "metadata": {}, "outputs": [], "source": [ - "temperature = project.experiments[0].diffrn.ambient_temperature" + "temperature = expt.diffrn.ambient_temperature" ] }, { "cell_type": "markdown", - "id": "47", + "id": "51", "metadata": {}, "source": [ "Plot unit cell parameters vs. temperature." 
@@ -611,7 +657,7 @@ { "cell_type": "code", "execution_count": null, - "id": "48", + "id": "52", "metadata": {}, "outputs": [], "source": [ @@ -622,7 +668,7 @@ }, { "cell_type": "markdown", - "id": "49", + "id": "53", "metadata": {}, "source": [ "Plot isotropic displacement parameters vs. temperature." @@ -631,7 +677,7 @@ { "cell_type": "code", "execution_count": null, - "id": "50", + "id": "54", "metadata": {}, "outputs": [], "source": [ @@ -644,7 +690,7 @@ }, { "cell_type": "markdown", - "id": "51", + "id": "55", "metadata": {}, "source": [ "Plot selected fractional coordinates vs. temperature." @@ -653,7 +699,7 @@ { "cell_type": "code", "execution_count": null, - "id": "52", + "id": "56", "metadata": {}, "outputs": [], "source": [ diff --git a/docs/docs/tutorials/ed-17.py b/docs/docs/tutorials/ed-17.py index af06031f..eb8bcd2a 100644 --- a/docs/docs/tutorials/ed-17.py +++ b/docs/docs/tutorials/ed-17.py @@ -3,9 +3,9 @@ # # This example demonstrates a Rietveld refinement of the Co2SiO4 crystal # structure using constant-wavelength neutron powder diffraction data -# from D20 at ILL. A sequential refinement of the same structure against -# a temperature scan is performed to show how to manage multiple -# experiments in a project. +# from D20 at ILL. A sequential refinement is performed against a +# temperature scan using `fit_sequential`, which processes each data +# file independently without loading all datasets into memory at once. # %% [markdown] # ## Import Library @@ -22,11 +22,11 @@ project = ed.Project() # %% [markdown] -# Set output verbosity level to "short" to show only one-line status -# messages during the analysis process. +# The project must be saved before running sequential fitting, so that +# results can be written to `analysis/results.csv`. 
# %% -project.verbosity = 'short' +project.save_as('data/cosio_project', temporary=False) # %% [markdown] # ## Step 2: Define Crystal Structure @@ -115,91 +115,88 @@ ) # %% [markdown] -# ## Step 3: Define Experiments +# ## Step 3: Define Template Experiment # -# This section shows how to add experiments, configure their parameters, -# and link the structures defined above. +# For sequential fitting, we create a single template experiment from +# the first data file. This template defines the instrument, peak +# profile, background, and linked phases that will be reused for every +# data file in the scan. # # #### Download Measured Data # %% -file_path = ed.download_data(id=27, destination='data') +zip_path = ed.download_data(id=27, destination='data') # %% [markdown] -# #### Create Experiments and Set Temperature +# #### Extract Data Files # %% -data_paths = ed.extract_data_paths_from_zip(file_path) -for i, data_path in enumerate(data_paths, start=1): - name = f'd20_{i}' - project.experiments.add_from_data_path( - name=name, - data_path=data_path, - ) - expt = project.experiments[name] - expt.diffrn.ambient_temperature = ed.extract_metadata( - file_path=data_path, - pattern=r'^TEMP\s+([0-9.]+)', - ) +data_dir = 'data/d20_scan' +data_paths = ed.extract_data_paths_from_zip(zip_path, destination=data_dir) + +# %% [markdown] +# #### Create Template Experiment from the First File + +# %% +project.experiments.add_from_data_path( + name='d20', + data_path=data_paths[0], +) +expt = project.experiments['d20'] # %% [markdown] # #### Set Instrument # %% -for expt in project.experiments: - expt.instrument.setup_wavelength = 1.87 - expt.instrument.calib_twotheta_offset = 0.29 +expt.instrument.setup_wavelength = 1.87 +expt.instrument.calib_twotheta_offset = 0.29 # %% [markdown] # #### Set Peak Profile # %% -for expt in project.experiments: - expt.peak.broad_gauss_u = 0.24 - expt.peak.broad_gauss_v = -0.53 - expt.peak.broad_gauss_w = 0.38 - expt.peak.broad_lorentz_y = 0.02 
+expt.peak.broad_gauss_u = 0.24 +expt.peak.broad_gauss_v = -0.53 +expt.peak.broad_gauss_w = 0.38 +expt.peak.broad_lorentz_y = 0.02 # %% [markdown] # #### Set Excluded Regions # %% -for expt in project.experiments: - expt.excluded_regions.create(id='1', start=0, end=8) - expt.excluded_regions.create(id='2', start=150, end=180) +expt.excluded_regions.create(id='1', start=0, end=8) +expt.excluded_regions.create(id='2', start=150, end=180) # %% [markdown] # #### Set Background # %% -for expt in project.experiments: - expt.background.create(id='1', x=8, y=609) - expt.background.create(id='2', x=9, y=581) - expt.background.create(id='3', x=10, y=563) - expt.background.create(id='4', x=11, y=540) - expt.background.create(id='5', x=12, y=520) - expt.background.create(id='6', x=15, y=507) - expt.background.create(id='7', x=25, y=463) - expt.background.create(id='8', x=30, y=434) - expt.background.create(id='9', x=50, y=451) - expt.background.create(id='10', x=70, y=431) - expt.background.create(id='11', x=90, y=414) - expt.background.create(id='12', x=110, y=361) - expt.background.create(id='13', x=130, y=292) - expt.background.create(id='14', x=150, y=241) +expt.background.create(id='1', x=8, y=609) +expt.background.create(id='2', x=9, y=581) +expt.background.create(id='3', x=10, y=563) +expt.background.create(id='4', x=11, y=540) +expt.background.create(id='5', x=12, y=520) +expt.background.create(id='6', x=15, y=507) +expt.background.create(id='7', x=25, y=463) +expt.background.create(id='8', x=30, y=434) +expt.background.create(id='9', x=50, y=451) +expt.background.create(id='10', x=70, y=431) +expt.background.create(id='11', x=90, y=414) +expt.background.create(id='12', x=110, y=361) +expt.background.create(id='13', x=130, y=292) +expt.background.create(id='14', x=150, y=241) # %% [markdown] # #### Set Linked Phases # %% -for expt in project.experiments: - expt.linked_phases.create(id='cosio', scale=1.2) +expt.linked_phases.create(id='cosio', scale=1.2) # %% [markdown] 
# ## Step 4: Perform Analysis # # This section shows how to set free parameters, define constraints, -# and run the refinement. +# and run the sequential refinement. # %% [markdown] # #### Set Free Parameters @@ -229,18 +226,17 @@ structure.atom_sites['O3'].b_iso.free = True # %% -for expt in project.experiments: - expt.linked_phases['cosio'].scale.free = True +expt.linked_phases['cosio'].scale.free = True - expt.instrument.calib_twotheta_offset.free = True +expt.instrument.calib_twotheta_offset.free = True - expt.peak.broad_gauss_u.free = True - expt.peak.broad_gauss_v.free = True - expt.peak.broad_gauss_w.free = True - expt.peak.broad_lorentz_y.free = True +expt.peak.broad_gauss_u.free = True +expt.peak.broad_gauss_v.free = True +expt.peak.broad_gauss_w.free = True +expt.peak.broad_lorentz_y.free = True - for point in expt.background: - point.y.free = True +for point in expt.background: + point.y.free = True # %% [markdown] # #### Set Constraints @@ -250,45 +246,71 @@ # %% project.analysis.aliases.create( label='biso_Co1', - param_uid=structure.atom_sites['Co1'].b_iso.uid, + param=structure.atom_sites['Co1'].b_iso, ) project.analysis.aliases.create( label='biso_Co2', - param_uid=structure.atom_sites['Co2'].b_iso.uid, + param=structure.atom_sites['Co2'].b_iso, ) # %% [markdown] # Set constraints. # %% -project.analysis.constraints.create( - expression='biso_Co2 = biso_Co1', -) +project.analysis.constraints.create(expression='biso_Co2 = biso_Co1') + +# %% [markdown] +# #### Run Single Fitting +# +# This is the fitting of the first dataset to optimize the initial +# parameters for the sequential fitting. This step is optional but can +# help with convergence and speed of the sequential fitting, especially +# if the initial parameters are far from optimal. + +# %% +project.analysis.fit() # %% [markdown] -# Apply constraints. +# #### Run Sequential Fitting +# +# Set output verbosity level to "short" to show only one-line status +# messages during the analysis process. 
# %% -project.analysis.apply_constraints() +project.verbosity = 'short' # %% [markdown] -# #### Set Fit Mode +# +# Define a callback that extracts the temperature from each data file. + # %% -project.analysis.fit_mode.mode = 'single' +def extract_diffrn(file_path): + temperature = ed.extract_metadata( + file_path=file_path, + pattern=r'^TEMP\s+([0-9.]+)', + ) + return {'ambient_temperature': temperature} + # %% [markdown] -# #### Run Fitting +# Run the sequential fit over all data files in the scan directory. # %% -project.analysis.fit() +project.analysis.fit_sequential( + data_dir=data_dir, + extract_diffrn=extract_diffrn, + max_workers='auto', +) # %% [markdown] -# #### Plot Measured vs Calculated +# #### Replay a Dataset +# +# Apply fitted parameters from the last CSV row and plot the result. # %% -last_expt_name = project.experiments.names[-1] -project.plot_meas_vs_calc(expt_name=last_expt_name, show_residual=True) +project.apply_params_from_csv(row_index=-1) +project.plot_meas_vs_calc(expt_name='d20', show_residual=True) # %% [markdown] # #### Plot Parameter Evolution @@ -296,7 +318,7 @@ # Define the quantity to use as the x-axis in the following plots. # %% -temperature = project.experiments[0].diffrn.ambient_temperature +temperature = expt.diffrn.ambient_temperature # %% [markdown] # Plot unit cell parameters vs. temperature. 
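`fit_sequential` is designed to survive interruption: on restart it reads the existing `results.csv` and skips files that already have a row. A stdlib sketch of that skip logic — the `file_path` column name matches the design, the helper name is illustrative:

```python
import csv
from pathlib import Path

def files_to_fit(data_dir, csv_path):
    # Collect data files that have no row yet in results.csv, so a
    # restarted run resumes where the previous one stopped.
    done = set()
    csv_file = Path(csv_path)
    if csv_file.exists():
        with csv_file.open(newline='') as f:
            done = {row['file_path'] for row in csv.DictReader(f)}
    return [path for path in sorted(Path(data_dir).iterdir())
            if str(path) not in done]
```

If the CSV does not exist yet, every file in `data_dir` is returned and the run starts from scratch.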
diff --git a/docs/docs/tutorials/ed-18.ipynb b/docs/docs/tutorials/ed-18.ipynb new file mode 100644 index 00000000..60ce7707 --- /dev/null +++ b/docs/docs/tutorials/ed-18.ipynb @@ -0,0 +1,193 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "71e56392", + "metadata": { + "tags": [ + "hide-in-docs" + ] + }, + "outputs": [], + "source": [ + "# Check whether easydiffraction is installed; install it if needed.\n", + "# Required for remote environments such as Google Colab.\n", + "import importlib.util\n", + "\n", + "if importlib.util.find_spec('easydiffraction') is None:\n", + " %pip install easydiffraction" + ] + }, + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Load Project and Fit: LBCO, HRPT\n", + "\n", + "This is the most minimal example of using EasyDiffraction. It shows\n", + "how to load a previously saved project from a directory and run\n", + "refinement — all in just a few lines of code.\n", + "\n", + "For details on how to define structures and experiments, see the other\n", + "tutorials." 
+ ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "## Import Modules" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "from easydiffraction import Project\n", + "from easydiffraction import download_data\n", + "from easydiffraction import extract_project_from_zip" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## Download Project Archive" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "zip_path = download_data(id=28, destination='data')" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "## Extract Project" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "project_dir = extract_project_from_zip(zip_path, destination='data')" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "## Load Project" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "project = Project.load(project_dir)" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## Perform Analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "project.analysis.fit()" + ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": {}, + "source": [ + "## Show Results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "project.analysis.show_fit_results()" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "## Plot Meas vs Calc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + 
"outputs": [], + "source": [ + "project.plot_meas_vs_calc(expt_name='hrpt', show_residual=True)" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "## Save Project" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "project.save()" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/docs/tutorials/ed-18.py b/docs/docs/tutorials/ed-18.py new file mode 100644 index 00000000..f4485dbc --- /dev/null +++ b/docs/docs/tutorials/ed-18.py @@ -0,0 +1,59 @@ +# %% [markdown] +# # Load Project and Fit: LBCO, HRPT +# +# This is the most minimal example of using EasyDiffraction. It shows +# how to load a previously saved project from a directory and run +# refinement — all in just a few lines of code. +# +# For details on how to define structures and experiments, see the other +# tutorials. 
+ +# %% [markdown] +# ## Import Modules + +# %% +from easydiffraction import Project +from easydiffraction import download_data +from easydiffraction import extract_project_from_zip + +# %% [markdown] +# ## Download Project Archive + +# %% +zip_path = download_data(id=28, destination='data') + +# %% [markdown] +# ## Extract Project + +# %% +project_dir = extract_project_from_zip(zip_path, destination='data') + +# %% [markdown] +# ## Load Project + +# %% +project = Project.load(project_dir) + +# %% [markdown] +# ## Perform Analysis + +# %% +project.analysis.fit() + +# %% [markdown] +# ## Show Results + +# %% +project.analysis.show_fit_results() + +# %% [markdown] +# ## Plot Meas vs Calc + +# %% +project.plot_meas_vs_calc(expt_name='hrpt', show_residual=True) + +# %% [markdown] +# ## Save Project + +# %% +project.save() diff --git a/docs/docs/tutorials/ed-2.ipynb b/docs/docs/tutorials/ed-2.ipynb index 06f40dc9..00d769ad 100644 --- a/docs/docs/tutorials/ed-2.ipynb +++ b/docs/docs/tutorials/ed-2.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7dda0a11", + "id": "b94f1ffd", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-3.ipynb b/docs/docs/tutorials/ed-3.ipynb index f64a041a..78c3ed5c 100644 --- a/docs/docs/tutorials/ed-3.ipynb +++ b/docs/docs/tutorials/ed-3.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "be19f628", + "id": "725cf769", "metadata": { "tags": [ "hide-in-docs" @@ -1425,11 +1425,11 @@ "source": [ "project.analysis.aliases.create(\n", " label='biso_La',\n", - " param_uid=project.structures['lbco'].atom_sites['La'].b_iso.uid,\n", + " param=project.structures['lbco'].atom_sites['La'].b_iso,\n", ")\n", "project.analysis.aliases.create(\n", " label='biso_Ba',\n", - " param_uid=project.structures['lbco'].atom_sites['Ba'].b_iso.uid,\n", + " param=project.structures['lbco'].atom_sites['Ba'].b_iso,\n", ")" ] }, @@ -1474,7 +1474,7 @@ "id": "144", "metadata": {}, "source": [ - 
"Show free parameters before applying constraints." + "Show free parameters." ] }, { @@ -1491,42 +1491,6 @@ "cell_type": "markdown", "id": "146", "metadata": {}, - "source": [ - "Apply constraints." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "147", - "metadata": {}, - "outputs": [], - "source": [ - "project.analysis.apply_constraints()" - ] - }, - { - "cell_type": "markdown", - "id": "148", - "metadata": {}, - "source": [ - "Show free parameters after applying constraints." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "149", - "metadata": {}, - "outputs": [], - "source": [ - "project.analysis.show_free_params()" - ] - }, - { - "cell_type": "markdown", - "id": "150", - "metadata": {}, "source": [ "#### Run Fitting" ] @@ -1534,7 +1498,7 @@ { "cell_type": "code", "execution_count": null, - "id": "151", + "id": "147", "metadata": {}, "outputs": [], "source": [ @@ -1544,7 +1508,7 @@ }, { "cell_type": "markdown", - "id": "152", + "id": "148", "metadata": {}, "source": [ "#### Plot Measured vs Calculated" @@ -1553,7 +1517,7 @@ { "cell_type": "code", "execution_count": null, - "id": "153", + "id": "149", "metadata": {}, "outputs": [], "source": [ @@ -1563,7 +1527,7 @@ { "cell_type": "code", "execution_count": null, - "id": "154", + "id": "150", "metadata": {}, "outputs": [], "source": [ @@ -1572,7 +1536,7 @@ }, { "cell_type": "markdown", - "id": "155", + "id": "151", "metadata": {}, "source": [ "#### Save Project State" @@ -1581,7 +1545,7 @@ { "cell_type": "code", "execution_count": null, - "id": "156", + "id": "152", "metadata": {}, "outputs": [], "source": [ @@ -1590,7 +1554,7 @@ }, { "cell_type": "markdown", - "id": "157", + "id": "153", "metadata": {}, "source": [ "### Perform Fit 5/5\n", @@ -1603,23 +1567,23 @@ { "cell_type": "code", "execution_count": null, - "id": "158", + "id": "154", "metadata": {}, "outputs": [], "source": [ "project.analysis.aliases.create(\n", " label='occ_La',\n", - " 
param_uid=project.structures['lbco'].atom_sites['La'].occupancy.uid,\n", + " param=project.structures['lbco'].atom_sites['La'].occupancy,\n", ")\n", "project.analysis.aliases.create(\n", " label='occ_Ba',\n", - " param_uid=project.structures['lbco'].atom_sites['Ba'].occupancy.uid,\n", + " param=project.structures['lbco'].atom_sites['Ba'].occupancy,\n", ")" ] }, { "cell_type": "markdown", - "id": "159", + "id": "155", "metadata": {}, "source": [ "Set more constraints." @@ -1628,7 +1592,7 @@ { "cell_type": "code", "execution_count": null, - "id": "160", + "id": "156", "metadata": {}, "outputs": [], "source": [ @@ -1639,7 +1603,7 @@ }, { "cell_type": "markdown", - "id": "161", + "id": "157", "metadata": {}, "source": [ "Show defined constraints." @@ -1648,8 +1612,10 @@ { "cell_type": "code", "execution_count": null, - "id": "162", - "metadata": {}, + "id": "158", + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [], "source": [ "project.analysis.show_constraints()" @@ -1657,25 +1623,7 @@ }, { "cell_type": "markdown", - "id": "163", - "metadata": {}, - "source": [ - "Apply constraints." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "164", - "metadata": {}, - "outputs": [], - "source": [ - "project.analysis.apply_constraints()" - ] - }, - { - "cell_type": "markdown", - "id": "165", + "id": "159", "metadata": {}, "source": [ "Set structure parameters to be refined." @@ -1684,7 +1632,7 @@ { "cell_type": "code", "execution_count": null, - "id": "166", + "id": "160", "metadata": {}, "outputs": [], "source": [ @@ -1693,7 +1641,7 @@ }, { "cell_type": "markdown", - "id": "167", + "id": "161", "metadata": {}, "source": [ "Show free parameters after selection." 
@@ -1702,7 +1650,7 @@ { "cell_type": "code", "execution_count": null, - "id": "168", + "id": "162", "metadata": {}, "outputs": [], "source": [ @@ -1711,7 +1659,7 @@ }, { "cell_type": "markdown", - "id": "169", + "id": "163", "metadata": {}, "source": [ "#### Run Fitting" @@ -1720,7 +1668,7 @@ { "cell_type": "code", "execution_count": null, - "id": "170", + "id": "164", "metadata": {}, "outputs": [], "source": [ @@ -1730,7 +1678,7 @@ }, { "cell_type": "markdown", - "id": "171", + "id": "165", "metadata": {}, "source": [ "#### Plot Measured vs Calculated" @@ -1739,7 +1687,7 @@ { "cell_type": "code", "execution_count": null, - "id": "172", + "id": "166", "metadata": {}, "outputs": [], "source": [ @@ -1749,7 +1697,7 @@ { "cell_type": "code", "execution_count": null, - "id": "173", + "id": "167", "metadata": {}, "outputs": [], "source": [ @@ -1758,7 +1706,7 @@ }, { "cell_type": "markdown", - "id": "174", + "id": "168", "metadata": {}, "source": [ "#### Save Project State" @@ -1767,7 +1715,7 @@ { "cell_type": "code", "execution_count": null, - "id": "175", + "id": "169", "metadata": {}, "outputs": [], "source": [ @@ -1776,7 +1724,7 @@ }, { "cell_type": "markdown", - "id": "176", + "id": "170", "metadata": {}, "source": [ "## Step 5: Summary\n", @@ -1786,7 +1734,7 @@ }, { "cell_type": "markdown", - "id": "177", + "id": "171", "metadata": {}, "source": [ "#### Show Project Summary" @@ -1795,7 +1743,7 @@ { "cell_type": "code", "execution_count": null, - "id": "178", + "id": "172", "metadata": {}, "outputs": [], "source": [ @@ -1805,7 +1753,7 @@ { "cell_type": "code", "execution_count": null, - "id": "179", + "id": "173", "metadata": {}, "outputs": [], "source": [] diff --git a/docs/docs/tutorials/ed-3.py b/docs/docs/tutorials/ed-3.py index 23b60d88..1a79d789 100644 --- a/docs/docs/tutorials/ed-3.py +++ b/docs/docs/tutorials/ed-3.py @@ -567,11 +567,11 @@ # %% project.analysis.aliases.create( label='biso_La', - param_uid=project.structures['lbco'].atom_sites['La'].b_iso.uid, 
+ param=project.structures['lbco'].atom_sites['La'].b_iso, ) project.analysis.aliases.create( label='biso_Ba', - param_uid=project.structures['lbco'].atom_sites['Ba'].b_iso.uid, + param=project.structures['lbco'].atom_sites['Ba'].b_iso, ) # %% [markdown] @@ -587,19 +587,7 @@ project.analysis.show_constraints() # %% [markdown] -# Show free parameters before applying constraints. - -# %% -project.analysis.show_free_params() - -# %% [markdown] -# Apply constraints. - -# %% -project.analysis.apply_constraints() - -# %% [markdown] -# Show free parameters after applying constraints. +# Show free parameters. # %% project.analysis.show_free_params() @@ -636,11 +624,11 @@ # %% project.analysis.aliases.create( label='occ_La', - param_uid=project.structures['lbco'].atom_sites['La'].occupancy.uid, + param=project.structures['lbco'].atom_sites['La'].occupancy, ) project.analysis.aliases.create( label='occ_Ba', - param_uid=project.structures['lbco'].atom_sites['Ba'].occupancy.uid, + param=project.structures['lbco'].atom_sites['Ba'].occupancy, ) # %% [markdown] @@ -657,11 +645,6 @@ # %% project.analysis.show_constraints() -# %% [markdown] -# Apply constraints. - -# %% -project.analysis.apply_constraints() # %% [markdown] # Set structure parameters to be refined. 
diff --git a/docs/docs/tutorials/ed-4.ipynb b/docs/docs/tutorials/ed-4.ipynb index 9d9381b6..fe06bd1e 100644 --- a/docs/docs/tutorials/ed-4.ipynb +++ b/docs/docs/tutorials/ed-4.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ebfd4b4e", + "id": "16833253", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-5.ipynb b/docs/docs/tutorials/ed-5.ipynb index f3b3ba67..f94b7564 100644 --- a/docs/docs/tutorials/ed-5.ipynb +++ b/docs/docs/tutorials/ed-5.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d07ce7b8", + "id": "a3281949", "metadata": { "tags": [ "hide-in-docs" @@ -522,11 +522,11 @@ "source": [ "project.analysis.aliases.create(\n", " label='biso_Co1',\n", - " param_uid=project.structures['cosio'].atom_sites['Co1'].b_iso.uid,\n", + " param=project.structures['cosio'].atom_sites['Co1'].b_iso,\n", ")\n", "project.analysis.aliases.create(\n", " label='biso_Co2',\n", - " param_uid=project.structures['cosio'].atom_sites['Co2'].b_iso.uid,\n", + " param=project.structures['cosio'].atom_sites['Co2'].b_iso,\n", ")" ] }, @@ -542,7 +542,9 @@ "cell_type": "code", "execution_count": null, "id": "42", - "metadata": {}, + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [], "source": [ "project.analysis.constraints.create(\n", @@ -554,24 +556,6 @@ "cell_type": "markdown", "id": "43", "metadata": {}, - "source": [ - "Apply constraints." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "44", - "metadata": {}, - "outputs": [], - "source": [ - "project.analysis.apply_constraints()" - ] - }, - { - "cell_type": "markdown", - "id": "45", - "metadata": {}, "source": [ "#### Run Fitting" ] @@ -579,7 +563,7 @@ { "cell_type": "code", "execution_count": null, - "id": "46", + "id": "44", "metadata": {}, "outputs": [], "source": [ @@ -589,7 +573,7 @@ }, { "cell_type": "markdown", - "id": "47", + "id": "45", "metadata": {}, "source": [ "#### Plot Measured vs Calculated" @@ -598,7 +582,7 @@ { "cell_type": "code", "execution_count": null, - "id": "48", + "id": "46", "metadata": {}, "outputs": [], "source": [ @@ -608,7 +592,7 @@ { "cell_type": "code", "execution_count": null, - "id": "49", + "id": "47", "metadata": {}, "outputs": [], "source": [ @@ -617,7 +601,7 @@ }, { "cell_type": "markdown", - "id": "50", + "id": "48", "metadata": {}, "source": [ "## Summary\n", @@ -627,7 +611,7 @@ }, { "cell_type": "markdown", - "id": "51", + "id": "49", "metadata": {}, "source": [ "#### Show Project Summary" @@ -636,7 +620,7 @@ { "cell_type": "code", "execution_count": null, - "id": "52", + "id": "50", "metadata": {}, "outputs": [], "source": [ diff --git a/docs/docs/tutorials/ed-5.py b/docs/docs/tutorials/ed-5.py index 74e1d887..4e41a905 100644 --- a/docs/docs/tutorials/ed-5.py +++ b/docs/docs/tutorials/ed-5.py @@ -255,11 +255,11 @@ # %% project.analysis.aliases.create( label='biso_Co1', - param_uid=project.structures['cosio'].atom_sites['Co1'].b_iso.uid, + param=project.structures['cosio'].atom_sites['Co1'].b_iso, ) project.analysis.aliases.create( label='biso_Co2', - param_uid=project.structures['cosio'].atom_sites['Co2'].b_iso.uid, + param=project.structures['cosio'].atom_sites['Co2'].b_iso, ) # %% [markdown] @@ -270,11 +270,6 @@ expression='biso_Co2 = biso_Co1', ) -# %% [markdown] -# Apply constraints. 
- -# %% -project.analysis.apply_constraints() # %% [markdown] # #### Run Fitting diff --git a/docs/docs/tutorials/ed-6.ipynb b/docs/docs/tutorials/ed-6.ipynb index 70a334b8..92130f6b 100644 --- a/docs/docs/tutorials/ed-6.ipynb +++ b/docs/docs/tutorials/ed-6.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "08932f7f", + "id": "48d300a4", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-7.ipynb b/docs/docs/tutorials/ed-7.ipynb index ce490e56..12ad852a 100644 --- a/docs/docs/tutorials/ed-7.ipynb +++ b/docs/docs/tutorials/ed-7.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "229e169b", + "id": "8cc5d312", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-8.ipynb b/docs/docs/tutorials/ed-8.ipynb index 13ce802e..80aec4e4 100644 --- a/docs/docs/tutorials/ed-8.ipynb +++ b/docs/docs/tutorials/ed-8.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8550c647", + "id": "0bc22f40", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/ed-9.ipynb b/docs/docs/tutorials/ed-9.ipynb index 1d3c883d..cc1f0ba4 100644 --- a/docs/docs/tutorials/ed-9.ipynb +++ b/docs/docs/tutorials/ed-9.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "520a3ca6", + "id": "9c2a5d62", "metadata": { "tags": [ "hide-in-docs" diff --git a/docs/docs/tutorials/index.json b/docs/docs/tutorials/index.json index 3f2f223c..9a438874 100644 --- a/docs/docs/tutorials/index.json +++ b/docs/docs/tutorials/index.json @@ -117,5 +117,12 @@ "title": "Structure Refinement: Co2SiO4, D20 (Temperature scan)", "description": "Sequential Rietveld refinement of Co2SiO4 using constant wavelength neutron powder diffraction data from D20 at ILL across a temperature scan", "level": "advanced" + }, + "18": { + "url": "https://easyscience.github.io/diffraction-lib/{version}/tutorials/ed-18/ed-18.ipynb", + "original_name": "", + "title": "Quick Start: LBCO 
Load Project", + "description": "Most minimal example: load a saved project from a directory and run Rietveld refinement of La0.5Ba0.5CoO3", + "level": "quick" } } diff --git a/docs/docs/tutorials/index.md b/docs/docs/tutorials/index.md index ef22e949..06a9c69e 100644 --- a/docs/docs/tutorials/index.md +++ b/docs/docs/tutorials/index.md @@ -17,6 +17,10 @@ The tutorials are organized into the following categories. ## Getting Started +- [LBCO `quick` `load`](ed-18.ipynb) – The most minimal example showing + how to load a previously saved project from a directory and run + refinement. Useful when a project has already been set up and saved in + a prior session. - [LBCO `quick` CIF](ed-1.ipynb) – A minimal example intended as a quick reference for users already familiar with the EasyDiffraction API or who want to see how Rietveld refinement of the La0.5Ba0.5CoO3 crystal diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index c272bdec..5f2415c3 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -191,6 +191,7 @@ nav: - Tutorials: - Tutorials: tutorials/index.md - Getting Started: + - LBCO quick load: tutorials/ed-18.ipynb - LBCO quick CIF: tutorials/ed-1.ipynb - LBCO quick code: tutorials/ed-2.ipynb - LBCO complete: tutorials/ed-3.ipynb diff --git a/src/easydiffraction/__init__.py b/src/easydiffraction/__init__.py index 10308402..11ea117c 100644 --- a/src/easydiffraction/__init__.py +++ b/src/easydiffraction/__init__.py @@ -6,6 +6,7 @@ from easydiffraction.io.ascii import extract_data_paths_from_dir from easydiffraction.io.ascii import extract_data_paths_from_zip from easydiffraction.io.ascii import extract_metadata +from easydiffraction.io.ascii import extract_project_from_zip from easydiffraction.project.project import Project from easydiffraction.utils.logging import Logger from easydiffraction.utils.logging import console diff --git a/src/easydiffraction/analysis/analysis.py b/src/easydiffraction/analysis/analysis.py index c21db3e6..80c5c1a3 100644 --- 
a/src/easydiffraction/analysis/analysis.py +++ b/src/easydiffraction/analysis/analysis.py @@ -563,16 +563,7 @@ def show_constraints(self) -> None: columns_alignment=['left'], columns_data=rows, ) - - def apply_constraints(self) -> None: - """Apply currently defined constraints to the project.""" - if not self.constraints._items: - log.warning('No constraints defined.') - return - - self.constraints_handler.set_aliases(self.aliases) - self.constraints_handler.set_constraints(self.constraints) - self.constraints_handler.apply() + console.print(f'Constraints enabled: {self.constraints.enabled}') def fit(self, verbosity: str | None = None) -> None: """ @@ -616,6 +607,11 @@ def fit(self, verbosity: str | None = None) -> None: log.warning('No experiments found in the project. Cannot run fit.') return + # Apply constraints before fitting so that constrained + # parameters are marked and excluded from the free parameter + # list built by the fitter. + self._update_categories() + # Run the fitting process mode = FitModeEnum(self._fit_mode.mode.value) if mode is FitModeEnum.JOINT: @@ -653,12 +649,14 @@ def fit(self, verbosity: str | None = None) -> None: short_alignments = ['left', 'right', 'right', 'center'] short_rows: list[list[str]] = [] short_display_handle: object | None = None + if verb is not VerbosityEnum.SILENT: + console.paragraph('Standard fitting') if verb is VerbosityEnum.SHORT: first = expt_names[0] last = expt_names[-1] minimizer_name = self.fitter.selection - console.paragraph( - f"Using {num_expts} experiments 🔬 from '{first}' to " + console.print( + f"📋 Using {num_expts} experiments 🔬 from '{first}' to " f"'{last}' for '{mode.value}' fitting" ) console.print(f"🚀 Starting fit process with '{minimizer_name}'...") @@ -667,8 +665,8 @@ def fit(self, verbosity: str | None = None) -> None: for _idx, expt_name in enumerate(expt_names, start=1): if verb is VerbosityEnum.FULL: - console.paragraph( - f"Using experiment 🔬 '{expt_name}' for '{mode.value}' fitting" + 
console.print( + f"📋 Using experiment 🔬 '{expt_name}' for '{mode.value}' fitting" ) experiment = experiments[expt_name] @@ -720,12 +718,64 @@ def fit(self, verbosity: str | None = None) -> None: raise NotImplementedError(msg) # After fitting, save the project - # TODO: Consider saving individual data during sequential - # (single) fitting, instead of waiting until the end and save - # only the last one if self.project.info.path is not None: self.project.save() + def fit_sequential( + self, + data_dir: str, + max_workers: int | str = 1, + chunk_size: int | None = None, + file_pattern: str = '*', + extract_diffrn: object = None, + verbosity: str | None = None, + ) -> None: + """ + Run sequential fitting over all data files in a directory. + + Fits each dataset independently using the current structure and + experiment as a template. Results are written incrementally to + ``analysis/results.csv`` in the project directory. + + The project must contain exactly one structure and one + experiment (the template), and must have been saved + (``save_as()``) before calling this method. + + Parameters + ---------- + data_dir : str + Path to directory containing data files. + max_workers : int | str, default=1 + Number of parallel worker processes. ``1`` = sequential. + ``'auto'`` = physical CPU count. Uses + ``ProcessPoolExecutor`` with ``spawn`` context when > 1. + chunk_size : int | None, default=None + Files per chunk. Default ``None`` uses *max_workers*. + file_pattern : str, default='*' + Glob pattern to filter files in *data_dir*. + extract_diffrn : object, default=None + User callback ``f(file_path) → {diffrn_field: value}``. + Called per file after fitting. ``None`` = no diffrn + metadata. + verbosity : str | None, default=None + ``'full'``, ``'short'``, or ``'silent'``. Default: project + verbosity. 
+ """ + from easydiffraction.analysis.sequential import fit_sequential as _fit_seq # noqa: PLC0415 + + # Apply constraints before building the template + self._update_categories() + + _fit_seq( + analysis=self, + data_dir=data_dir, + max_workers=max_workers, + chunk_size=chunk_size, + file_pattern=file_pattern, + extract_diffrn=extract_diffrn, + verbosity=verbosity, + ) + def show_fit_results(self) -> None: """ Display a summary of the fit results. @@ -762,16 +812,14 @@ def _update_categories(self, called_by_minimizer: bool = False) -> None: called_by_minimizer : bool, default=False Whether this is called during fitting. """ + del called_by_minimizer + # Apply constraints to sync dependent parameters - if self.constraints._items: + if self.constraints.enabled and self.constraints._items: + self.constraints_handler.set_aliases(self.aliases) + self.constraints_handler.set_constraints(self.constraints) self.constraints_handler.apply() - # Update category-specific logic - # TODO: Need self.categories as in the case of datablock.py - for category in [self.aliases, self.constraints]: - if hasattr(category, '_update'): - category._update(called_by_minimizer=called_by_minimizer) - def as_cif(self) -> str: """ Serialize the analysis section to a CIF string. diff --git a/src/easydiffraction/analysis/categories/aliases/default.py b/src/easydiffraction/analysis/categories/aliases/default.py index 7b1e0df0..8aac2cdc 100644 --- a/src/easydiffraction/analysis/categories/aliases/default.py +++ b/src/easydiffraction/analysis/categories/aliases/default.py @@ -1,10 +1,12 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause """ -Alias category for mapping friendly names to parameter UIDs. +Alias category for mapping friendly names to parameters. Defines a small record type used by analysis configuration to refer to -parameters via readable labels instead of raw unique identifiers. 
+parameters via readable labels instead of opaque identifiers. At runtime +each alias holds a direct object reference to the parameter; for CIF +serialization the parameter's ``unique_name`` is stored. """ from __future__ import annotations @@ -23,8 +25,9 @@ class Alias(CategoryItem): """ Single alias entry. - Maps a human-readable ``label`` to a concrete ``param_uid`` used by - the engine. + Maps a human-readable ``label`` to a parameter object. The + ``param_unique_name`` descriptor stores the parameter's + ``unique_name`` for CIF serialization. """ def __init__(self) -> None: @@ -32,23 +35,27 @@ def __init__(self) -> None: self._label = StringDescriptor( name='label', - description='...', # TODO + description='Human-readable alias for a parameter.', value_spec=AttributeSpec( default='_', # TODO, Maybe None? validator=RegexValidator(pattern=r'^[A-Za-z_][A-Za-z0-9_]*$'), ), cif_handler=CifHandler(names=['_alias.label']), ) - self._param_uid = StringDescriptor( - name='param_uid', - description='...', # TODO + self._param_unique_name = StringDescriptor( + name='param_unique_name', + description='Unique name of the referenced parameter.', value_spec=AttributeSpec( default='_', - validator=RegexValidator(pattern=r'^[A-Za-z_][A-Za-z0-9_]*$'), + validator=RegexValidator(pattern=r'^[A-Za-z_][A-Za-z0-9_.]*$'), ), - cif_handler=CifHandler(names=['_alias.param_uid']), + cif_handler=CifHandler(names=['_alias.param_unique_name']), ) + # Direct reference to the Parameter object (runtime only). + # Stored via object.__setattr__ to avoid parent-chain mutation. + object.__setattr__(self, '_param_ref', None) # noqa: PLC2801 + self._identity.category_code = 'alias' self._identity.category_entry_name = lambda: str(self.label.value) @@ -59,7 +66,7 @@ def __init__(self) -> None: @property def label(self) -> StringDescriptor: """ - ... + Human-readable alias label (e.g. ``'biso_La'``). Reading this property returns the underlying ``StringDescriptor`` object. 
Assigning to it updates the @@ -72,19 +79,38 @@ def label(self, value: str) -> None: self._label.value = value @property - def param_uid(self) -> StringDescriptor: + def param(self) -> object | None: + """ + The referenced parameter object, or None before resolution. """ - ... + return self._param_ref + + @property + def param_unique_name(self) -> StringDescriptor: + """ + Unique name of the referenced parameter (for CIF). Reading this property returns the underlying - ``StringDescriptor`` object. Assigning to it updates the - parameter value. + ``StringDescriptor`` object. + """ + return self._param_unique_name + + def _set_param(self, param: object) -> None: """ - return self._param_uid + Store a direct reference to the parameter. - @param_uid.setter - def param_uid(self, value: str) -> None: - self._param_uid.value = value + Also updates ``param_unique_name`` from the parameter's + ``unique_name`` for CIF round-tripping. + """ + object.__setattr__(self, '_param_ref', param) # noqa: PLC2801 + self._param_unique_name.value = param.unique_name + + @property + def parameters(self) -> list: + """ + Descriptors owned by this alias (excludes the param reference). + """ + return [self._label, self._param_unique_name] @AliasesFactory.register @@ -99,3 +125,19 @@ class Aliases(CategoryCollection): def __init__(self) -> None: """Create an empty collection of aliases.""" super().__init__(item_type=Alias) + + def create(self, *, label: str, param: object) -> None: + """ + Create a new alias mapping a label to a parameter. + + Parameters + ---------- + label : str + Human-readable alias name (e.g. ``'biso_La'``). + param : object + The parameter object to reference. 
+ """ + item = Alias() + item.label = label + item._set_param(param) + self.add(item) diff --git a/src/easydiffraction/analysis/categories/constraints/default.py b/src/easydiffraction/analysis/categories/constraints/default.py index 3bb1b77e..63a4264d 100644 --- a/src/easydiffraction/analysis/categories/constraints/default.py +++ b/src/easydiffraction/analysis/categories/constraints/default.py @@ -14,7 +14,6 @@ from easydiffraction.core.category import CategoryCollection from easydiffraction.core.category import CategoryItem from easydiffraction.core.metadata import TypeInfo -from easydiffraction.core.singleton import ConstraintsHandler from easydiffraction.core.validation import AttributeSpec from easydiffraction.core.validation import RegexValidator from easydiffraction.core.variable import StringDescriptor @@ -102,11 +101,27 @@ class Constraints(CategoryCollection): def __init__(self) -> None: """Create an empty constraints collection.""" super().__init__(item_type=Constraint) + self._enabled: bool = False + + @property + def enabled(self) -> bool: + """Whether constraints are currently active.""" + return self._enabled + + def enable(self) -> None: + """Activate constraints so they are applied during fitting.""" + self._enabled = True + + def disable(self) -> None: + """Deactivate constraints without deleting them.""" + self._enabled = False def create(self, *, expression: str) -> None: """ Create a constraint from an expression string. + Automatically enables constraints on the first call. 
+ Parameters ---------- expression : str @@ -116,9 +131,4 @@ def create(self, *, expression: str) -> None: item = Constraint() item.expression = expression self.add(item) - - def _update(self, called_by_minimizer: bool = False) -> None: - del called_by_minimizer - - constraints = ConstraintsHandler.get() - constraints.apply() + self._enabled = True diff --git a/src/easydiffraction/analysis/sequential.py b/src/easydiffraction/analysis/sequential.py new file mode 100644 index 00000000..9c3b45f6 --- /dev/null +++ b/src/easydiffraction/analysis/sequential.py @@ -0,0 +1,740 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +""" +Sequential fitting infrastructure: template, worker, CSV, recovery. +""" + +from __future__ import annotations + +import contextlib +import csv +import multiprocessing as mp +import sys +from concurrent.futures import ProcessPoolExecutor +from dataclasses import dataclass +from dataclasses import replace +from pathlib import Path +from typing import TYPE_CHECKING +from typing import Any + +from easydiffraction.io.ascii import extract_data_paths_from_dir +from easydiffraction.utils.enums import VerbosityEnum +from easydiffraction.utils.logging import console +from easydiffraction.utils.logging import log + +if TYPE_CHECKING: + from collections.abc import Callable + +# ------------------------------------------------------------------ +# Template dataclass (picklable for ProcessPoolExecutor) +# ------------------------------------------------------------------ + + +@dataclass(frozen=True) +class SequentialFitTemplate: + """ + Snapshot of everything a worker needs to recreate and fit a project. + + All fields are plain Python types (str, dict, list) so that the + template can be pickled for ``ProcessPoolExecutor``. 
+ """ + + structure_cif: str + experiment_cif: str + initial_params: dict[str, float] + free_param_unique_names: list[str] + alias_defs: list[dict[str, str]] + constraint_defs: list[str] + constraints_enabled: bool + minimizer_tag: str + calculator_tag: str + diffrn_field_names: list[str] + + +# ------------------------------------------------------------------ +# Worker function (module-level for pickling) +# ------------------------------------------------------------------ + + +def _fit_worker( + template: SequentialFitTemplate, + data_path: str, +) -> dict[str, Any]: + """ + Fit a single dataset in isolation. + + Creates a fresh Project, loads the template configuration via CIF, + replaces data from *data_path*, applies initial parameters, fits, + and returns a plain dict of results. + + Parameters + ---------- + template : SequentialFitTemplate + Snapshot of the project configuration. + data_path : str + Path to the data file to fit. + + Returns + ------- + dict[str, Any] + Result dict with keys: ``file_path``, ``fit_success``, + ``chi_squared``, ``reduced_chi_squared``, ``n_iterations``, and + per-parameter ``{unique_name}`` / ``{unique_name}.uncertainty``. + """ + # Lazy import to avoid circular dependencies and keep the module + # importable without heavy imports at top level. + from easydiffraction.project.project import Project # noqa: PLC0415 + + result: dict[str, Any] = {'file_path': data_path} + + try: + # 1. Create a fresh, isolated project + Project._loading = True + try: + project = Project(name='_worker') + finally: + Project._loading = False + + # 2. Load structure from template CIF + project.structures.add_from_cif_str(template.structure_cif) + + # 3. Load experiment from template CIF + # (full config + template data) + project.experiments.add_from_cif_str(template.experiment_cif) + expt = list(project.experiments.values())[0] + + # 4. Replace data from the new data path + expt._load_ascii_data_to_experiment(data_path) + + # 5. 
Override parameter values from propagated starting values + _apply_param_overrides(project, template.initial_params) + + # 6. Set free flags + _set_free_params(project, template.free_param_unique_names) + + # 7. Apply constraints + if template.constraints_enabled and template.alias_defs: + _apply_constraints( + project, + template.alias_defs, + template.constraint_defs, + ) + + # 8. Set calculator and minimizer + # (internal, no console output) + from easydiffraction.analysis.calculators.factory import CalculatorFactory # noqa: PLC0415 + from easydiffraction.analysis.fitting import Fitter # noqa: PLC0415 + + expt._calculator = CalculatorFactory.create(template.calculator_tag) + expt._calculator_type = template.calculator_tag + project.analysis.fitter = Fitter(template.minimizer_tag) + + # 9. Fit + project.analysis.fit(verbosity='silent') + + # 10. Collect results + result.update(_collect_results(project, template)) + + except Exception as exc: # noqa: BLE001 + result['fit_success'] = False + result['chi_squared'] = None + result['reduced_chi_squared'] = None + result['n_iterations'] = 0 + result['error'] = str(exc) + + return result + + +# ------------------------------------------------------------------ +# Helper functions +# ------------------------------------------------------------------ + + +def _apply_param_overrides( + project: object, + overrides: dict[str, float], +) -> None: + """ + Set parameter values from a ``{unique_name: value}`` dict. + + Parameters + ---------- + project : object + The worker's project instance. + overrides : dict[str, float] + Map of parameter unique names to values. 
+ """ + all_params = project.structures.parameters + project.experiments.parameters + by_name = {p.unique_name: p for p in all_params if hasattr(p, 'unique_name')} + for name, value in overrides.items(): + if name in by_name: + by_name[name].value = value + + +def _set_free_params( + project: object, + free_names: list[str], +) -> None: + """ + Mark parameters as free based on their unique names. + + Parameters + ---------- + project : object + The worker's project instance. + free_names : list[str] + Unique names of parameters to mark as free. + """ + from easydiffraction.core.variable import Parameter # noqa: PLC0415 + + all_params = project.structures.parameters + project.experiments.parameters + free_set = set(free_names) + for p in all_params: + if isinstance(p, Parameter) and hasattr(p, 'unique_name'): + p.free = p.unique_name in free_set + + +def _apply_constraints( + project: object, + alias_defs: list[dict[str, str]], + constraint_defs: list[str], +) -> None: + """ + Recreate aliases and constraints in the worker project. + + Parameters + ---------- + project : object + The worker's project instance. + alias_defs : list[dict[str, str]] + Each dict has ``label`` and ``param_unique_name``. + constraint_defs : list[str] + Constraint expression strings. + """ + all_params = project.structures.parameters + project.experiments.parameters + by_name = {p.unique_name: p for p in all_params if hasattr(p, 'unique_name')} + + for alias_def in alias_defs: + param = by_name.get(alias_def['param_unique_name']) + if param is not None: + project.analysis.aliases.create( + label=alias_def['label'], + param=param, + ) + + for expr in constraint_defs: + project.analysis.constraints.create(expression=expr) + + +def _collect_results( + project: object, + template: SequentialFitTemplate, +) -> dict[str, Any]: + """ + Collect fit results into a plain dict. + + Parameters + ---------- + project : object + The worker's project instance after fitting. 
+ template : SequentialFitTemplate + The template (for knowing which params to collect). + + Returns + ------- + dict[str, Any] + Fit metrics and parameter values/uncertainties. + """ + from easydiffraction.core.variable import Parameter # noqa: PLC0415 + + result: dict[str, Any] = {} + fit_results = project.analysis.fit_results + + if fit_results is not None: + result['fit_success'] = fit_results.success + result['chi_squared'] = fit_results.chi_square + result['reduced_chi_squared'] = fit_results.reduced_chi_square + result['n_iterations'] = project.analysis.fitter.minimizer.tracker.best_iteration or 0 + else: + result['fit_success'] = False + result['chi_squared'] = None + result['reduced_chi_squared'] = None + result['n_iterations'] = 0 + + # Collect all free parameter values and uncertainties + all_params = project.structures.parameters + project.experiments.parameters + free_set = set(template.free_param_unique_names) + result['params'] = {} + for p in all_params: + if isinstance(p, Parameter) and p.unique_name in free_set: + result[p.unique_name] = p.value + result[f'{p.unique_name}.uncertainty'] = p.uncertainty + result['params'][p.unique_name] = p.value + + return result + + +# ------------------------------------------------------------------ +# CSV helpers +# ------------------------------------------------------------------ + +_META_COLUMNS = [ + 'file_path', + 'chi_squared', + 'reduced_chi_squared', + 'fit_success', + 'n_iterations', +] + + +def _build_csv_header( + template: SequentialFitTemplate, +) -> list[str]: + """ + Build the CSV column header list. + + Parameters + ---------- + template : SequentialFitTemplate + The template for diffrn fields and free param names. + + Returns + ------- + list[str] + Ordered list of column names. 
+ """ + header = list(_META_COLUMNS) + header.extend(f'diffrn.{field}' for field in template.diffrn_field_names) + for name in template.free_param_unique_names: + header.append(name) + header.append(f'{name}.uncertainty') + return header + + +def _write_csv_header( + csv_path: Path, + header: list[str], +) -> None: + """ + Create the CSV file and write the header row. + + Parameters + ---------- + csv_path : Path + Path to the CSV file. + header : list[str] + Column names. + """ + with csv_path.open('w', newline='', encoding='utf-8') as f: + writer = csv.DictWriter(f, fieldnames=header) + writer.writeheader() + + +def _append_to_csv( + csv_path: Path, + header: list[str], + results: list[dict[str, Any]], +) -> None: + """ + Append result rows to the CSV file. + + Parameters + ---------- + csv_path : Path + Path to the CSV file. + header : list[str] + Column names (for DictWriter fieldnames). + results : list[dict[str, Any]] + Result dicts from workers. + """ + with csv_path.open('a', newline='', encoding='utf-8') as f: + writer = csv.DictWriter(f, fieldnames=header, extrasaction='ignore') + for result in results: + writer.writerow(result) + + +def _read_csv_for_recovery( + csv_path: Path, +) -> tuple[set[str], dict[str, float] | None]: + """ + Read an existing CSV for crash recovery. + + Parameters + ---------- + csv_path : Path + Path to the CSV file. + + Returns + ------- + tuple[set[str], dict[str, float] | None] + A set of already-fitted file paths and the parameter values from + the last successful row (or ``None`` if no rows). 
+ """ + fitted: set[str] = set() + last_params: dict[str, float] | None = None + + if not csv_path.is_file(): + return fitted, last_params + + with csv_path.open(newline='', encoding='utf-8') as f: + reader = csv.DictReader(f) + for row in reader: + file_path = row.get('file_path', '') + if file_path: + fitted.add(file_path) + if row.get('fit_success', '').lower() == 'true': + # Extract parameter values from this row + params: dict[str, float] = {} + for key, val in row.items(): + if key in _META_COLUMNS: + continue + if key.startswith('diffrn.'): + continue + if key.endswith('.uncertainty'): + continue + if val: + with contextlib.suppress(ValueError, TypeError): + params[key] = float(val) + if params: + last_params = params + + return fitted, last_params + + +# ------------------------------------------------------------------ +# Template builder +# ------------------------------------------------------------------ + + +def _build_template(project: object) -> SequentialFitTemplate: + """ + Build a SequentialFitTemplate from the current project state. + + Parameters + ---------- + project : object + The main project instance (must have exactly 1 structure and 1 + experiment). + + Returns + ------- + SequentialFitTemplate + A frozen, picklable snapshot. 
+ """ + from easydiffraction.core.variable import Parameter # noqa: PLC0415 + + structure = list(project.structures.values())[0] + experiment = list(project.experiments.values())[0] + + # Collect free parameter unique_names and initial values + all_params = project.structures.parameters + project.experiments.parameters + free_names: list[str] = [] + initial_params: dict[str, float] = {} + for p in all_params: + if isinstance(p, Parameter) and not p.constrained and p.free: + free_names.append(p.unique_name) + initial_params[p.unique_name] = p.value + + # Collect alias definitions + alias_defs: list[dict[str, str]] = [ + { + 'label': alias.label.value, + 'param_unique_name': alias.param_unique_name.value, + } + for alias in project.analysis.aliases + ] + + # Collect constraint expressions + constraint_defs: list[str] = [ + constraint.expression.value for constraint in project.analysis.constraints + ] + + # Collect diffrn field names from the experiment + diffrn_field_names: list[str] = [] + if hasattr(experiment, 'diffrn'): + diffrn_field_names.extend( + p.name + for p in experiment.diffrn.parameters + if hasattr(p, 'name') and p.name not in ('type',) + ) + + return SequentialFitTemplate( + structure_cif=structure.as_cif, + experiment_cif=experiment.as_cif, + initial_params=initial_params, + free_param_unique_names=free_names, + alias_defs=alias_defs, + constraint_defs=constraint_defs, + constraints_enabled=project.analysis.constraints.enabled, + minimizer_tag=project.analysis.current_minimizer or 'lmfit', + calculator_tag=experiment.calculator_type, + diffrn_field_names=diffrn_field_names, + ) + + +# ------------------------------------------------------------------ +# Progress reporting +# ------------------------------------------------------------------ + + +def _report_chunk_progress( + chunk_idx: int, + total_chunks: int, + results: list[dict[str, Any]], + verbosity: VerbosityEnum, +) -> None: + """ + Report progress after a chunk completes. 
+
+    Parameters
+    ----------
+    chunk_idx : int
+        1-based index of the current chunk.
+    total_chunks : int
+        Total number of chunks.
+    results : list[dict[str, Any]]
+        Results from the chunk.
+    verbosity : VerbosityEnum
+        Output verbosity.
+    """
+    if verbosity is VerbosityEnum.SILENT:
+        return
+
+    num_files = len(results)
+    successful = [r for r in results if r.get('fit_success')]
+    if successful:
+        avg_chi2 = sum(r['reduced_chi_squared'] for r in successful) / len(successful)
+        chi2_str = f'{avg_chi2:.2f}'
+    else:
+        chi2_str = '—'
+
+    if verbosity is VerbosityEnum.SHORT:
+        status = '✅' if successful else '❌'
+        print(f'{status} Chunk {chunk_idx}/{total_chunks}: {num_files} files, avg χ² = {chi2_str}')
+    elif verbosity is VerbosityEnum.FULL:
+        print(
+            f'Chunk {chunk_idx}/{total_chunks}: '
+            f'{num_files} files, {len(successful)} succeeded, '
+            f'avg reduced χ² = {chi2_str}'
+        )
+        for r in results:
+            status = '✅' if r.get('fit_success') else '❌'
+            rchi2 = r.get('reduced_chi_squared')
+            rchi2_str = f'{rchi2:.2f}' if rchi2 is not None else '—'
+            print(f'  {status} {Path(r["file_path"]).name}: χ² = {rchi2_str}')
+
+
+# ------------------------------------------------------------------
+# Main orchestration
+# ------------------------------------------------------------------
+
+
+def fit_sequential(
+    analysis: object,
+    data_dir: str,
+    max_workers: int | str = 1,
+    chunk_size: int | None = None,
+    file_pattern: str = '*',
+    extract_diffrn: Callable | None = None,
+    verbosity: str | None = None,
+) -> None:
+    """
+    Run sequential fitting over all data files in a directory.
+
+    Parameters
+    ----------
+    analysis : object
+        The ``Analysis`` instance (owns project reference).
+    data_dir : str
+        Path to directory containing data files.
+    max_workers : int | str, default=1
+        Number of parallel worker processes. ``1`` = sequential (no
+        subprocess overhead). ``'auto'`` = logical CPU count, as
+        reported by ``os.cpu_count()``. Uses ``ProcessPoolExecutor``
+        with ``spawn`` context when > 1.
+ chunk_size : int | None, default=None + Files per chunk. Default ``None`` uses ``max_workers``. + file_pattern : str, default='*' + Glob pattern to filter files in *data_dir*. + extract_diffrn : Callable | None, default=None + User callback: ``f(file_path) → {diffrn_field: value}``. + verbosity : str | None, default=None + ``'full'``, ``'short'``, ``'silent'``. Default: project + verbosity. + + Raises + ------ + ValueError + If preconditions are not met (e.g. multiple structures, missing + project path, no free parameters). + """ + # Guard against re-entry in spawned child processes. With the + # ``spawn`` multiprocessing context the child re-imports __main__, + # which re-executes the user script and would call fit_sequential + # again, causing infinite process spawning. + if mp.parent_process() is not None: + return + + project = analysis.project + verb = VerbosityEnum(verbosity if verbosity is not None else project.verbosity) + + # ── Preconditions ──────────────────────────────────────────── + if len(project.structures) != 1: + msg = f'Sequential fitting requires exactly 1 structure, found {len(project.structures)}.' + raise ValueError(msg) + + if len(project.experiments) != 1: + msg = ( + f'Sequential fitting requires exactly 1 experiment (the template), ' + f'found {len(project.experiments)}.' + ) + raise ValueError(msg) + + if project.info.path is None: + msg = 'Project must be saved before sequential fitting. Call save_as() first.' + raise ValueError(msg) + + # Discover data files + data_paths = extract_data_paths_from_dir(data_dir, file_pattern=file_pattern) + + from easydiffraction.core.variable import Parameter # noqa: PLC0415 + + free_params = [ + p for p in project.parameters if isinstance(p, Parameter) and not p.constrained and p.free + ] + if not free_params: + msg = 'No free parameters found. Mark at least one parameter as free.' 
+ raise ValueError(msg) + + # ── Build template ─────────────────────────────────────────── + template = _build_template(project) + + # ── CSV setup and crash recovery ───────────────────────────── + csv_path = project.info.path / 'analysis' / 'results.csv' + csv_path.parent.mkdir(parents=True, exist_ok=True) + header = _build_csv_header(template) + + already_fitted, recovered_params = _read_csv_for_recovery(csv_path) + + if already_fitted: + num_skipped = len(already_fitted) + log.info(f'Resuming: {num_skipped} files already fitted, skipping.') + if verb is not VerbosityEnum.SILENT: + print(f'📂 Resuming from CSV: {num_skipped} files already fitted.') + # Seed from recovered params if available + if recovered_params is not None: + template = replace(template, initial_params=recovered_params) + else: + _write_csv_header(csv_path, header) + + # Filter out already-fitted files + remaining = [p for p in data_paths if p not in already_fitted] + if not remaining: + if verb is not VerbosityEnum.SILENT: + print('✅ All files already fitted. 
Nothing to do.') + return + + # ── Resolve workers and chunk size ─────────────────────────── + if isinstance(max_workers, str) and max_workers == 'auto': + import os # noqa: PLC0415 + + max_workers = os.cpu_count() or 1 + + if not isinstance(max_workers, int) or max_workers < 1: + msg = f"max_workers must be a positive integer or 'auto', got {max_workers!r}" + raise ValueError(msg) + + if chunk_size is None: + chunk_size = max_workers + + # ── Chunk and fit ──────────────────────────────────────────── + chunks = [remaining[i : i + chunk_size] for i in range(0, len(remaining), chunk_size)] + total_chunks = len(chunks) + + if verb is not VerbosityEnum.SILENT: + minimizer_name = analysis.fitter.selection + console.paragraph('Sequential fitting') + console.print(f"🚀 Starting fit process with '{minimizer_name}'...") + console.print( + f'📋 {len(remaining)} files in {total_chunks} chunks (max_workers={max_workers})' + ) + console.print('📈 Goodness-of-fit (reduced χ²):') + + # Create a process pool for parallel dispatch, or a no-op context + # for single-worker mode (avoids process-spawn overhead). + # + # When max_workers > 1 we use ``spawn`` context, which normally + # re-imports ``__main__`` in every child process. If the user runs + # a script without an ``if __name__ == '__main__':`` guard the + # whole script would re-execute in every worker, causing infinite + # process spawning. To prevent this we temporarily hide + # ``__main__.__file__`` and ``__main__.__spec__`` so that the spawn + # bootstrap has no path to re-import the script. ``_fit_worker`` + # lives in this module (not ``__main__``), so it is still resolved + # via normal pickle/import machinery. 
+ _main_mod = sys.modules.get('__main__') + _main_file_bak = getattr(_main_mod, '__file__', None) + _main_spec_bak = getattr(_main_mod, '__spec__', None) + + if max_workers > 1: + # Hide __main__ origin from spawn + if _main_mod is not None and _main_file_bak is not None: + _main_mod.__file__ = None # type: ignore[assignment] + if _main_mod is not None and _main_spec_bak is not None: + _main_mod.__spec__ = None + + spawn_ctx = mp.get_context('spawn') + pool_cm = ProcessPoolExecutor( + max_workers=max_workers, + mp_context=spawn_ctx, + max_tasks_per_child=100, + ) + else: + pool_cm = contextlib.nullcontext() + + try: + with pool_cm as executor: + for chunk_idx, chunk in enumerate(chunks, start=1): + # Dispatch: parallel or sequential + if executor is not None: + templates = [template] * len(chunk) + results = list(executor.map(_fit_worker, templates, chunk)) + else: + results = [_fit_worker(template, path) for path in chunk] + + # Extract diffrn metadata in the main process + if extract_diffrn is not None: + for result in results: + try: + diffrn_values = extract_diffrn(result['file_path']) + for key, val in diffrn_values.items(): + result[f'diffrn.{key}'] = val + except Exception as exc: # noqa: BLE001 + log.warning(f'extract_diffrn failed for {result["file_path"]}: {exc}') + + # Write to CSV + _append_to_csv(csv_path, header, results) + + # Report progress + _report_chunk_progress(chunk_idx, total_chunks, results, verb) + + # Propagate: use last successful file's + # params as starting values + last_ok = None + for r in reversed(results): + if r.get('fit_success') and r.get('params'): + last_ok = r + break + + if last_ok is not None: + template = replace(template, initial_params=last_ok['params']) + finally: + # Restore __main__ attributes + if _main_mod is not None and _main_file_bak is not None: + _main_mod.__file__ = _main_file_bak + if _main_mod is not None and _main_spec_bak is not None: + _main_mod.__spec__ = _main_spec_bak + + if verb is not 
VerbosityEnum.SILENT: + total_fitted = len(already_fitted) + len(remaining) + print(f'✅ Sequential fitting complete: {total_fitted} files processed.') + print(f'📄 Results saved to: {csv_path}') diff --git a/src/easydiffraction/core/datablock.py b/src/easydiffraction/core/datablock.py index 36c5b589..5d497e4c 100644 --- a/src/easydiffraction/core/datablock.py +++ b/src/easydiffraction/core/datablock.py @@ -91,6 +91,27 @@ def as_cif(self) -> str: self._update_categories() return datablock_item_to_cif(self) + def _cif_for_display(self, max_loop_display: int = 20) -> str: + """ + Return CIF text with loop categories truncated for display. + + Parameters + ---------- + max_loop_display : int, default=20 + Maximum number of rows to show per loop category. + + Returns + ------- + str + CIF representation of this object, with loop categories + truncated to at most *max_loop_display* rows for display + purposes. + """ + from easydiffraction.io.cif.serialize import datablock_item_to_cif # noqa: PLC0415 + + self._update_categories() + return datablock_item_to_cif(self, max_loop_display=max_loop_display) + def help(self) -> None: """Print a summary of public attributes and categories.""" super().help() diff --git a/src/easydiffraction/core/singleton.py b/src/easydiffraction/core/singleton.py index 9d8a1d89..a4ac6b28 100644 --- a/src/easydiffraction/core/singleton.py +++ b/src/easydiffraction/core/singleton.py @@ -3,12 +3,9 @@ from typing import Any from typing import Self -from typing import TypeVar from asteval import Interpreter -T = TypeVar('T', bound='SingletonBase') - # ====================================================================== @@ -33,54 +30,6 @@ def get(cls) -> Self: # ====================================================================== -class UidMapHandler(SingletonBase): - """Global handler to manage UID-to-Parameter object mapping.""" - - def __init__(self) -> None: - # Internal map: uid (str) → Parameter instance - self._uid_map: dict[str, Any] = {} - - 
def get_uid_map(self) -> dict[str, Any]: - """Return the current UID-to-Parameter map.""" - return self._uid_map - - def add_to_uid_map(self, parameter: object) -> None: - """ - Add a single Parameter or Descriptor object to the UID map. - - Only Descriptor or Parameter instances are allowed (not - Components or others). - """ - from easydiffraction.core.variable import GenericDescriptorBase # noqa: PLC0415 - - if not isinstance(parameter, GenericDescriptorBase): - msg = ( - f'Cannot add object of type {type(parameter).__name__} to UID map. ' - 'Only Descriptor or Parameter instances are allowed.' - ) - raise TypeError(msg) - self._uid_map[parameter.uid] = parameter - - def replace_uid(self, old_uid: str, new_uid: str) -> None: - """ - Replace an existing UID key in the UID map with a new UID. - - Moves the associated parameter from old_uid to new_uid. Raises a - KeyError if the old_uid doesn't exist. - """ - if old_uid not in self._uid_map: - # Only raise if old_uid is not None and not empty - print('DEBUG: replace_uid failed', old_uid, 'current map:', list(self._uid_map.keys())) - msg = f"UID '{old_uid}' not found in the UID map." - raise KeyError(msg) - self._uid_map[new_uid] = self._uid_map.pop(old_uid) - - # TODO: Implement removing from the UID map - - -# ====================================================================== - - # TODO: Implement changing atrr '.constrained' back to False # when removing constraints class ConstraintsHandler(SingletonBase): @@ -94,7 +43,7 @@ class ConstraintsHandler(SingletonBase): def __init__(self) -> None: # Maps alias names - # (like 'biso_La') → ConstraintAlias(param=Parameter) + # (like 'biso_La') → Alias(param=Parameter) self._alias_to_param: dict[str, Any] = {} # Stores raw user-defined constraints indexed by lhs_alias @@ -106,7 +55,7 @@ def __init__(self) -> None: def set_aliases(self, aliases: object) -> None: """ - Set the alias map (name → parameter wrapper). + Set the alias map (name → alias wrapper). 
Called when user registers parameter aliases like: alias='biso_La', param=model.atom_sites['La'].b_iso @@ -137,25 +86,21 @@ def _parse_constraints(self) -> None: def apply(self) -> None: """ - Evaluate constraints and applies them to dependent parameters. + Evaluate constraints and apply them to dependent parameters. - For each constraint: - Evaluate RHS using current values of - aliases - Locate the dependent parameter by alias → uid → param + For each constraint: + - Evaluate RHS using current values of aliased parameters + - Locate the dependent parameter via direct alias reference - Update its value and mark it as constrained """ if not self._parsed_constraints: return # Nothing to apply - # Retrieve global UID → Parameter object map - uid_map = UidMapHandler.get().get_uid_map() - # Prepare a flat dict of {alias: value} for use in expressions param_values = {} for alias, alias_obj in self._alias_to_param.items(): - uid = alias_obj.param_uid.value - param = uid_map[uid] - value = param.value - param_values[alias] = value + param = alias_obj.param + param_values[alias] = param.value # Create an asteval interpreter for safe expression evaluation ae = Interpreter() @@ -167,8 +112,7 @@ def apply(self) -> None: rhs_value = ae(rhs_expr) # Get the actual parameter object we want to update - dependent_uid = self._alias_to_param[lhs_alias].param_uid.value - param = uid_map[dependent_uid] + param = self._alias_to_param[lhs_alias].param # Update its value and mark it as constrained param._set_value_constrained(rhs_value) diff --git a/src/easydiffraction/core/variable.py b/src/easydiffraction/core/variable.py index 2acce18d..6d987fd1 100644 --- a/src/easydiffraction/core/variable.py +++ b/src/easydiffraction/core/variable.py @@ -3,15 +3,12 @@ from __future__ import annotations -import secrets -import string from typing import TYPE_CHECKING import numpy as np from easydiffraction.core.diagnostic import Diagnostics from easydiffraction.core.guard import GuardedBase -from 
easydiffraction.core.singleton import UidMapHandler from easydiffraction.core.validation import AttributeSpec from easydiffraction.core.validation import DataTypes from easydiffraction.core.validation import RangeValidator @@ -287,9 +284,6 @@ def __init__( self._constrained_spec = self._BOOL_SPEC_TEMPLATE self._constrained = self._constrained_spec.default - self._uid: str = self._generate_uid() - UidMapHandler.get().add_to_uid_map(self) - def __str__(self) -> str: """Return string representation with uncertainty and free.""" s = GenericDescriptorBase.__str__(self) @@ -301,21 +295,10 @@ def __str__(self) -> str: s += f' (free={self.free})' return f'<{s}>' - @staticmethod - def _generate_uid(length: int = 16) -> str: - letters = string.ascii_lowercase - return ''.join(secrets.choice(letters) for _ in range(length)) - - @property - def uid(self) -> str: - """Stable random identifier for this descriptor.""" - return self._uid - @property def _minimizer_uid(self) -> str: - """Variant of uid that is safe for minimizer engines.""" - # return self.unique_name.replace('.', '__') - return self.uid + """Variant of unique_name that is safe for minimizer engines.""" + return self.unique_name.replace('.', '__') @property def constrained(self) -> bool: @@ -326,12 +309,17 @@ def _set_value_constrained(self, v: object) -> None: """ Set the value from a constraint expression. - Validates against the spec, marks the parent datablock dirty, - and flags the parameter as constrained. Used exclusively by - ``ConstraintsHandler.apply()``. + Bypasses validation and marks the parent datablock dirty, like + ``_set_value_from_minimizer``, because constraints are applied + inside the minimizer loop where trial values may exceed + physical-range validators. Flags the parameter as constrained. + Used exclusively by ``ConstraintsHandler.apply()``. 
""" - self.value = v + self._value = v self._constrained = True + parent_datablock = self._datablock_item() + if parent_datablock is not None: + parent_datablock._need_categories_update = True @property def free(self) -> bool: diff --git a/src/easydiffraction/datablocks/experiment/item/base.py b/src/easydiffraction/datablocks/experiment/item/base.py index 1ac72d0f..d3a0f4a0 100644 --- a/src/easydiffraction/datablocks/experiment/item/base.py +++ b/src/easydiffraction/datablocks/experiment/item/base.py @@ -129,10 +129,9 @@ def as_cif(self) -> str: def show_as_cif(self) -> None: """Pretty-print the experiment as CIF text.""" - experiment_cif = super().as_cif paragraph_title: str = f"Experiment 🔬 '{self.name}' as cif" console.paragraph(paragraph_title) - render_cif(experiment_cif) + render_cif(self._cif_for_display()) @abstractmethod def _load_ascii_data_to_experiment(self, data_path: str) -> None: diff --git a/src/easydiffraction/datablocks/structure/item/base.py b/src/easydiffraction/datablocks/structure/item/base.py index 80d8f76a..8181f1db 100644 --- a/src/easydiffraction/datablocks/structure/item/base.py +++ b/src/easydiffraction/datablocks/structure/item/base.py @@ -252,4 +252,4 @@ def show(self) -> None: def show_as_cif(self) -> None: """Render the CIF text for this structure in the terminal.""" console.paragraph(f"Structure 🧩 '{self.name}' as cif") - render_cif(self.as_cif) + render_cif(self._cif_for_display()) diff --git a/src/easydiffraction/display/plotting.py b/src/easydiffraction/display/plotting.py index 5b010ea4..92a3a031 100644 --- a/src/easydiffraction/display/plotting.py +++ b/src/easydiffraction/display/plotting.py @@ -570,6 +570,79 @@ def plot_meas_vs_calc( ) def plot_param_series( + self, + csv_path: str, + unique_name: str, + param_descriptor: object, + versus_descriptor: object | None = None, + ) -> None: + """ + Plot a parameter's value across sequential fit results. + + Reads data from the CSV file at *csv_path*. 
The y-axis values + come from the column named *unique_name*, uncertainties from + ``{unique_name}.uncertainty``. When *versus_descriptor* is + provided, the x-axis uses the corresponding ``diffrn.{name}`` + column; otherwise the row index is used. + + Axis labels are derived from the live descriptor objects + (*param_descriptor* and *versus_descriptor*), which carry + ``.description`` and ``.units`` attributes. + + Parameters + ---------- + csv_path : str + Path to the ``results.csv`` file. + unique_name : str + Unique name of the parameter to plot (CSV column key). + param_descriptor : object + The live parameter descriptor (for axis label / units). + versus_descriptor : object | None, default=None + A diffrn descriptor whose ``.name`` maps to a + ``diffrn.{name}`` CSV column. ``None`` → use row index. + """ + df = pd.read_csv(csv_path) + + if unique_name not in df.columns: + log.warning( + f"Parameter '{unique_name}' not found in CSV columns. " + f'Available: {list(df.columns)}' + ) + return + + y = df[unique_name].astype(float).tolist() + uncert_col = f'{unique_name}.uncertainty' + sy = df[uncert_col].astype(float).tolist() if uncert_col in df.columns else [0.0] * len(y) + + # X-axis: diffrn column or row index + versus_name = versus_descriptor.name if versus_descriptor is not None else None + diffrn_col = f'diffrn.{versus_name}' if versus_name else None + + if diffrn_col and diffrn_col in df.columns: + x = pd.to_numeric(df[diffrn_col], errors='coerce').tolist() + x_label = getattr(versus_descriptor, 'description', None) or versus_name + if hasattr(versus_descriptor, 'units') and versus_descriptor.units: + x_label = f'{x_label} ({versus_descriptor.units})' + else: + x = list(range(1, len(y) + 1)) + x_label = 'Experiment No.' 
+ + # Y-axis label from descriptor + param_units = getattr(param_descriptor, 'units', '') + y_label = f'Parameter value ({param_units})' if param_units else 'Parameter value' + + title = f"Parameter '{unique_name}' across fit results" + + self._backend.plot_scatter( + x=x, + y=y, + sy=sy, + axes_labels=[x_label, y_label], + title=title, + height=self.height, + ) + + def plot_param_series_from_snapshots( self, unique_name: str, versus_name: str | None, @@ -577,21 +650,22 @@ def plot_param_series( parameter_snapshots: dict[str, dict[str, dict]], ) -> None: """ - Plot a parameter's value across sequential fit results. + Plot a parameter's value from in-memory snapshots. + + This is a backward-compatibility method used when no CSV file is + available (e.g. after ``fit()`` in single mode, before PR 13 + adds CSV output to the existing fit loop). Parameters ---------- unique_name : str Unique name of the parameter to plot. versus_name : str | None - Name of the diffrn descriptor to use as the x-axis (e.g. - ``'ambient_temperature'``). When ``None``, the experiment - sequence index is used instead. + Name of the diffrn descriptor for the x-axis. experiments : object Experiments collection for accessing diffrn conditions. parameter_snapshots : dict[str, dict[str, dict]] - Per-experiment parameter value snapshots keyed by experiment - name, then by parameter unique name. + Per-experiment parameter value snapshots. 
""" x = [] y = [] diff --git a/src/easydiffraction/io/__init__.py b/src/easydiffraction/io/__init__.py index 6ce45a95..4d0c1560 100644 --- a/src/easydiffraction/io/__init__.py +++ b/src/easydiffraction/io/__init__.py @@ -4,4 +4,5 @@ from easydiffraction.io.ascii import extract_data_paths_from_dir from easydiffraction.io.ascii import extract_data_paths_from_zip from easydiffraction.io.ascii import extract_metadata +from easydiffraction.io.ascii import extract_project_from_zip from easydiffraction.io.ascii import load_numeric_block diff --git a/src/easydiffraction/io/ascii.py b/src/easydiffraction/io/ascii.py index ae231a2c..1bba03b0 100644 --- a/src/easydiffraction/io/ascii.py +++ b/src/easydiffraction/io/ascii.py @@ -13,21 +13,86 @@ import numpy as np -def extract_data_paths_from_zip(zip_path: str | Path) -> list[str]: +def extract_project_from_zip( + zip_path: str | Path, + destination: str | Path | None = None, +) -> str: + """ + Extract a project directory from a ZIP archive. + + The archive must contain exactly one directory with a + ``project.cif`` file. Files are extracted into *destination* when + provided, or into a temporary directory that persists for the + lifetime of the process. + + Parameters + ---------- + zip_path : str | Path + Path to the ZIP archive containing the project. + destination : str | Path | None, default=None + Directory to extract into. When ``None``, a temporary directory + is created. + + Returns + ------- + str + Absolute path to the extracted project directory (the directory + that contains ``project.cif``). + + Raises + ------ + FileNotFoundError + If *zip_path* does not exist. + ValueError + If the archive does not contain a ``project.cif`` file. 
+ """ + zip_path = Path(zip_path) + if not zip_path.exists(): + msg = f'ZIP file not found: {zip_path}' + raise FileNotFoundError(msg) + + if destination is not None: + extract_dir = Path(destination) + extract_dir.mkdir(parents=True, exist_ok=True) + else: + extract_dir = Path(tempfile.mkdtemp(prefix='ed_zip_')) + + with zipfile.ZipFile(zip_path, 'r') as zf: + # Determine the project directory from the archive contents + # *before* extraction, so we are not confused by unrelated + # project.cif files already present in the destination. + project_cif_entries = [name for name in zf.namelist() if name.endswith('project.cif')] + if not project_cif_entries: + msg = f'No project.cif found in ZIP archive: {zip_path}' + raise ValueError(msg) + + zf.extractall(extract_dir) + + project_cif_path = extract_dir / project_cif_entries[0] + return str(project_cif_path.parent.resolve()) + + +def extract_data_paths_from_zip( + zip_path: str | Path, + destination: str | Path | None = None, +) -> list[str]: """ Extract all files from a ZIP archive and return their paths. - Files are extracted into a temporary directory that persists for the - lifetime of the process. The returned paths are sorted - lexicographically by file name so that numbered data files (e.g. - ``scan_001.dat``, ``scan_002.dat``) appear in natural order. Hidden - files and directories (names starting with ``'.'`` or ``'__'``) are - excluded. + Files are extracted into *destination* when provided, or into a + temporary directory that persists for the lifetime of the process. + The returned paths are sorted lexicographically by file name so that + numbered data files (e.g. ``scan_001.dat``, ``scan_002.dat``) appear + in natural order. Hidden files and directories (names starting with + ``'.'`` or ``'__'``) are excluded. Parameters ---------- zip_path : str | Path Path to the ZIP archive. + destination : str | Path | None, default=None + Directory to extract files into. 
When ``None``, a temporary + directory is created. Returns ------- @@ -46,8 +111,12 @@ def extract_data_paths_from_zip(zip_path: str | Path) -> list[str]: msg = f'ZIP file not found: {zip_path}' raise FileNotFoundError(msg) - # TODO: Unify mkdir with other uses in the code - extract_dir = Path(tempfile.mkdtemp(prefix='ed_zip_')) + if destination is not None: + extract_dir = Path(destination) + extract_dir.mkdir(parents=True, exist_ok=True) + else: + # TODO: Unify mkdir with other uses in the code + extract_dir = Path(tempfile.mkdtemp(prefix='ed_zip_')) with zipfile.ZipFile(zip_path, 'r') as zf: zf.extractall(extract_dir) diff --git a/src/easydiffraction/io/cif/serialize.py b/src/easydiffraction/io/cif/serialize.py index 42a215eb..fa035981 100644 --- a/src/easydiffraction/io/cif/serialize.py +++ b/src/easydiffraction/io/cif/serialize.py @@ -35,9 +35,15 @@ def format_value(value: object) -> str: # Converting + # None → CIF unknown marker + if value is None: + value = '?' # Convert ints to floats - if isinstance(value, int): + elif isinstance(value, int): value = float(value) + # Empty strings → CIF unknown marker + elif isinstance(value, str) and not value.strip(): + value = '?' # Strings with whitespace are quoted elif isinstance(value, str) and (' ' in value or '\t' in value): value = f'"{value}"' @@ -59,15 +65,67 @@ def format_value(value: object) -> str: ################## +def format_param_value(param: object) -> str: + """ + Format a parameter value for CIF output, encoding the free flag. + + CIF convention for numeric parameters: + + - Fixed or constrained parameter: plain value, e.g. ``3.89090000`` + - Free parameter without uncertainty: value with empty brackets, + e.g. ``3.89090000()`` + - Free parameter with uncertainty: value with esd in brackets, + e.g. ``3.89090000(200000)`` + + Constrained (dependent) parameters are always written without + brackets, even if their ``free`` flag is ``True``, because they are + not independently varied by the minimizer. 
+ + Non-numeric parameters and descriptors without a ``free`` attribute + are formatted with :func:`format_value`. + + Parameters + ---------- + param : object + A descriptor or parameter exposing ``.value`` and optionally + ``.free``, ``.constrained``, and ``.uncertainty``. + + Returns + ------- + str + Formatted CIF value string. + """ + is_free = getattr(param, 'free', False) + is_constrained = getattr(param, 'constrained', False) + value = param.value # type: ignore[attr-defined] + + if not is_free or is_constrained or not isinstance(value, (int, float)): + return format_value(value) + + precision = 8 + uncertainty = getattr(param, 'uncertainty', None) + formatted_value = f'{float(value):.{precision}f}' + + if uncertainty is not None and uncertainty > 0: + from uncertainties import ufloat as _ufloat # noqa: PLC0415 + + u = _ufloat(float(value), float(uncertainty)) + return f'{u:.{precision}fS}' + + return f'{formatted_value}()' + + def param_to_cif(param: object) -> str: """ Render a single descriptor/parameter to a CIF line. Expects ``param`` to expose ``_cif_handler.names`` and ``value``. + Free parameters are written with uncertainty brackets (see + :func:`format_param_value`). """ tags: Sequence[str] = param._cif_handler.names # type: ignore[attr-defined] main_key: str = tags[0] - return f'{main_key} {format_value(param.value)}' + return f'{main_key} {format_param_value(param)}' def category_item_to_cif(item: object) -> str: @@ -83,12 +141,26 @@ def category_item_to_cif(item: object) -> str: def category_collection_to_cif( collection: object, - max_display: int | None = 20, + max_display: int | None = None, ) -> str: """ Render a CategoryCollection-like object to CIF text. Uses first item to build loop header, then emits rows for each item. + + Parameters + ---------- + collection : object + A ``CategoryCollection``-like object. 
+    max_display : int | None, default=None
+        When set to a positive integer, truncate the output to at most
+        this many rows (half from the start, half from the end) with an
+        ``...`` separator. ``None`` emits all rows.
+
+    Returns
+    -------
+    str
+        CIF text representing the collection as a loop.
     """
     if not len(collection):
         return ''
@@ -104,31 +176,47 @@
     # Rows
     # Limit number of displayed rows if requested
-    if len(collection) > max_display:
+    if max_display is not None and len(collection) > max_display:
         half_display = max_display // 2
         for i in range(half_display):
             item = list(collection.values())[i]
-            row_vals = [format_value(p.value) for p in item.parameters]
+            row_vals = [format_param_value(p) for p in item.parameters]
             lines.append(' '.join(row_vals))
         lines.append('...')
         for i in range(-half_display, 0):
             item = list(collection.values())[i]
-            row_vals = [format_value(p.value) for p in item.parameters]
+            row_vals = [format_param_value(p) for p in item.parameters]
             lines.append(' '.join(row_vals))
     # No limit
     else:
         for item in collection.values():
-            row_vals = [format_value(p.value) for p in item.parameters]
+            row_vals = [format_param_value(p) for p in item.parameters]
             lines.append(' '.join(row_vals))
     return '\n'.join(lines)
-def datablock_item_to_cif(datablock: object) -> str:
+def datablock_item_to_cif(
+    datablock: object,
+    max_loop_display: int | None = None,
+) -> str:
     """
     Render a DatablockItem-like object to CIF text.
     Emits a data_ header and then concatenates category CIF sections.
+
+    Parameters
+    ----------
+    datablock : object
+        A ``DatablockItem``-like object.
+    max_loop_display : int | None, default=None
+        When set, truncate loop categories to this many rows. ``None``
+        emits all rows (used for serialisation).
+
+    Returns
+    -------
+    str
+        CIF text representing the datablock.
""" # Local imports to avoid import-time cycles from easydiffraction.core.category import CategoryCollection # noqa: PLC0415 @@ -141,7 +229,11 @@ def datablock_item_to_cif(datablock: object) -> str: parts.extend(v.as_cif for v in vars(datablock).values() if isinstance(v, CategoryItem)) # Then collections - parts.extend(v.as_cif for v in vars(datablock).values() if isinstance(v, CategoryCollection)) + parts.extend( + category_collection_to_cif(v, max_display=max_loop_display) + for v in vars(datablock).values() + if isinstance(v, CategoryCollection) + ) return '\n\n'.join(parts) @@ -161,10 +253,12 @@ def project_info_to_cif(info: object) -> str: if len(info.description) > 60: description = f'\n;\n{info.description}\n;' - else: + elif info.description: description = f'{info.description}' if ' ' in description: description = f"'{description}'" + else: + description = '?' created = f"'{info._created.strftime('%d %b %Y %H:%M:%S')}'" last_modified = f"'{info._last_modified.strftime('%d %b %Y %H:%M:%S')}'" @@ -221,6 +315,135 @@ def summary_to_cif(_summary: object) -> str: return 'To be added...' +def _wrap_in_data_block(cif_text: str, block_name: str = '_') -> str: + """ + Wrap bare CIF key-value pairs in a ``data_`` block header. + + Parameters + ---------- + cif_text : str + CIF text without a ``data_`` header. + block_name : str, default='_' + Name for the CIF data block. + + Returns + ------- + str + CIF text with a ``data_`` header prepended. + """ + return f'data_{block_name}\n\n{cif_text}' + + +def project_info_from_cif(info: object, cif_text: str) -> None: + """ + Populate a ProjectInfo instance from CIF text. + + Reads ``_project.id``, ``_project.title``, and + ``_project.description`` from the given CIF string and sets them on + the *info* object. + + Parameters + ---------- + info : object + The ``ProjectInfo`` instance to populate. + cif_text : str + CIF text content of ``project.cif``. 
+ """ + import gemmi # noqa: PLC0415 + + doc = gemmi.cif.read_string(_wrap_in_data_block(cif_text, 'project')) + block = doc.sole_block() + + _read_cif_string = _make_cif_string_reader(block) + + name = _read_cif_string('_project.id') + if name is not None: + info.name = name + + title = _read_cif_string('_project.title') + if title is not None: + info.title = title + + description = _read_cif_string('_project.description') + if description is not None: + info.description = description + + +def analysis_from_cif(analysis: object, cif_text: str) -> None: + """ + Populate an Analysis instance from CIF text. + + Reads the fitting engine, fit mode, aliases, constraints, and + joint-fit experiment weights from the given CIF string. + + Parameters + ---------- + analysis : object + The ``Analysis`` instance to populate. + cif_text : str + CIF text content of ``analysis.cif``. + """ + import gemmi # noqa: PLC0415 + + doc = gemmi.cif.read_string(_wrap_in_data_block(cif_text, 'analysis')) + block = doc.sole_block() + + _read_cif_string = _make_cif_string_reader(block) + + # Restore minimizer selection + engine = _read_cif_string('_analysis.fitting_engine') + if engine is not None: + from easydiffraction.analysis.fitting import Fitter # noqa: PLC0415 + + analysis.fitter = Fitter(engine) + + # Restore fit mode + analysis.fit_mode.from_cif(block) + + # Restore aliases (loop) + analysis.aliases.from_cif(block) + + # Restore constraints (loop) + analysis.constraints.from_cif(block) + if analysis.constraints._items: + analysis.constraints.enable() + + # Restore joint-fit experiment weights (loop) + analysis._joint_fit_experiments.from_cif(block) + + +def _make_cif_string_reader(block: gemmi.cif.Block) -> object: + """ + Return a helper that reads a single CIF tag as a stripped string. + + Parameters + ---------- + block : gemmi.cif.Block + Parsed CIF data block. 
+ + Returns + ------- + object + A function ``(tag) -> str | None`` that returns the unquoted + value for *tag*, or ``None`` if not found. + """ + + def _read(tag: str) -> str | None: + vals = list(block.find_values(tag)) + if not vals: + return None + raw = vals[0] + # CIF unknown / inapplicable markers + if raw in ('?', '.'): + return None + # Strip surrounding quotes + if len(raw) >= 2 and raw[0] == raw[-1] and raw[0] in {"'", '"'}: + raw = raw[1:-1] + return raw + + return _read + + # TODO: Check the following methods: ###################### @@ -262,13 +485,19 @@ def param_from_cif( # If found, pick the one at the given index raw = found_values[idx] + # CIF unknown / inapplicable markers → keep default + if raw in ('?', '.'): + return + # If numeric, parse with uncertainty if present if self._value_type == DataTypes.NUMERIC: + has_brackets = '(' in raw u = str_to_ufloat(raw) self.value = u.n - if not np.isnan(u.s) and hasattr(self, 'uncertainty'): - self.uncertainty = u.s # type: ignore[attr-defined] - self.free = True # Mark as free if uncertainty is present + if has_brackets and hasattr(self, 'free'): + self.free = True # type: ignore[attr-defined] + if not np.isnan(u.s) and hasattr(self, 'uncertainty'): + self.uncertainty = u.s # type: ignore[attr-defined] # If string, strip quotes if present elif self._value_type == DataTypes.STRING: @@ -363,13 +592,19 @@ def _get_loop(block: object, category_item: object) -> object | None: # param_from_cif raw = array[row_idx][col_idx] + # CIF unknown / inapplicable markers → keep default + if raw in ('?', '.'): + break + # If numeric, parse with uncertainty if present if param._value_type == DataTypes.NUMERIC: + has_brackets = '(' in raw u = str_to_ufloat(raw) param.value = u.n - if not np.isnan(u.s) and hasattr(param, 'uncertainty'): - param.uncertainty = u.s # type: ignore[attr-defined] - param.free = True # Mark as free if uncertainty is present + if has_brackets and hasattr(param, 'free'): + param.free = True # type: 
ignore[attr-defined] + if not np.isnan(u.s) and hasattr(param, 'uncertainty'): + param.uncertainty = u.s # type: ignore[attr-defined] # If string, strip quotes if present # TODO: Make a helper function for this diff --git a/src/easydiffraction/project/project.py b/src/easydiffraction/project/project.py index 42bb71a0..5bc96e79 100644 --- a/src/easydiffraction/project/project.py +++ b/src/easydiffraction/project/project.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: BSD-3-Clause """Project facade to orchestrate models, experiments, and analysis.""" +from __future__ import annotations + import pathlib import tempfile @@ -32,6 +34,9 @@ class Project(GuardedBase): # ------------------------------------------------------------------ # Initialization # ------------------------------------------------------------------ + # Class-level sentinel: True while load() is constructing a project. + _loading: bool = False + def __init__( self, name: str = 'untitled_project', @@ -48,7 +53,7 @@ def __init__( self._analysis = Analysis(self) self._summary = Summary(self) self._saved = False - self._varname = varname() + self._varname = 'project' if type(self)._loading else varname() self._verbosity: VerbosityEnum = VerbosityEnum.FULL # ------------------------------------------------------------------ @@ -172,15 +177,118 @@ def verbosity(self, value: str) -> None: # Project File I/O # ------------------------------------------ - def load(self, dir_path: str) -> None: + @classmethod + def load(cls, dir_path: str) -> Project: + """ + Load a project from a saved directory. + + Reads ``project.cif``, ``structures/*.cif``, + ``experiments/*.cif``, and ``analysis.cif`` from *dir_path* and + reconstructs the full project state. + + Parameters + ---------- + dir_path : str + Path to the project directory previously created by + :meth:`save_as`. + + Returns + ------- + Project + A fully reconstructed project instance. + + Raises + ------ + FileNotFoundError + If *dir_path* does not exist. 
+ """ + from easydiffraction.io.cif.serialize import analysis_from_cif # noqa: PLC0415 + from easydiffraction.io.cif.serialize import project_info_from_cif # noqa: PLC0415 + + project_path = pathlib.Path(dir_path) + if not project_path.is_dir(): + msg = f"Project directory not found: '{dir_path}'" + raise FileNotFoundError(msg) + + # Create a minimal project. + # Use _loading sentinel to skip varname() inside __init__. + cls._loading = True + try: + project = cls() + finally: + cls._loading = False + project._saved = True + + # 1. Load project info + project_cif_path = project_path / 'project.cif' + if project_cif_path.is_file(): + cif_text = project_cif_path.read_text() + project_info_from_cif(project._info, cif_text) + + project._info.path = project_path + + # 2. Load structures + structures_dir = project_path / 'structures' + if structures_dir.is_dir(): + for cif_file in sorted(structures_dir.glob('*.cif')): + project._structures.add_from_cif_path(str(cif_file)) + + # 3. Load experiments + experiments_dir = project_path / 'experiments' + if experiments_dir.is_dir(): + for cif_file in sorted(experiments_dir.glob('*.cif')): + project._experiments.add_from_cif_path(str(cif_file)) + + # 4. Load analysis + # Check analysis/analysis.cif first (future layout), then + # fall back to analysis.cif at root (current layout). + analysis_cif_path = project_path / 'analysis' / 'analysis.cif' + if not analysis_cif_path.is_file(): + analysis_cif_path = project_path / 'analysis.cif' + if analysis_cif_path.is_file(): + cif_text = analysis_cif_path.read_text() + analysis_from_cif(project._analysis, cif_text) + + # 5. Resolve alias param references + project._resolve_alias_references() + + # 6. 
Apply symmetry constraints and update categories + for structure in project._structures: + structure._update_categories() + + log.info(f"Project '{project.name}' loaded from '{dir_path}'.") + return project + + def _resolve_alias_references(self) -> None: """ - Load a project from a given directory. + Resolve alias ``param_unique_name`` strings to live objects. - Loads project info, structures, experiments, etc. + After loading structures and experiments from CIF, aliases only + contain the ``param_unique_name`` string. This method builds a + ``{unique_name: param}`` map from all project parameters and + wires each alias's ``_param_ref``. """ - # TODO: load project components from files inside dir_path - msg = 'Project.load() is not implemented yet.' - raise NotImplementedError(msg) + aliases = self._analysis.aliases + if not aliases._items: + return + + # Build unique_name → parameter map + all_params = self._structures.parameters + self._experiments.parameters + param_map: dict[str, object] = {} + for p in all_params: + uname = getattr(p, 'unique_name', None) + if uname is not None: + param_map[uname] = p + + for alias in aliases: + uname = alias.param_unique_name.value + if uname in param_map: + alias._set_param(param_map[uname]) + else: + log.warning( + f"Alias '{alias.label.value}' references unknown " + f"parameter '{uname}'. Reference not resolved." + ) def save(self) -> None: """Save the project into the existing project directory.""" @@ -191,6 +299,11 @@ def save(self) -> None: console.paragraph(f"Saving project 📦 '{self.name}' to") console.print(self.info.path.resolve()) + # Apply constraints so dependent parameters are flagged + # before serialization (constrained params are written + # without brackets). 
+ self._analysis._update_categories() + # Ensure project directory exists self._info.path.mkdir(parents=True, exist_ok=True) @@ -222,9 +335,12 @@ def save(self) -> None: console.print(f'│ └── 📄 {file_name}') # Save analysis - with (self._info.path / 'analysis.cif').open('w') as f: + analysis_dir = self._info.path / 'analysis' + analysis_dir.mkdir(parents=True, exist_ok=True) + with (analysis_dir / 'analysis.cif').open('w') as f: f.write(self.analysis.as_cif()) - console.print('├── 📄 analysis.cif') + console.print('├── 📁 analysis/') + console.print('│ └── 📄 analysis.cif') # Save summary with (self._info.path / 'summary.cif').open('w') as f: @@ -246,6 +362,100 @@ def save_as( self._info.path = dir_path self.save() + def apply_params_from_csv(self, row_index: int) -> None: + """ + Load a single CSV row and apply its parameters to the project. + + Reads the row at *row_index* from ``analysis/results.csv``, + overrides parameter values in the live project, and (for + sequential-fit results where ``file_path`` points to a real + file) reloads the measured data into the template experiment. + + After calling this method, ``plot_meas_vs_calc()`` will show the + fit for that specific dataset. + + Parameters + ---------- + row_index : int + Row index in the CSV file. Supports Python-style negative + indexing (e.g. ``-1`` for the last row). + + Raises + ------ + FileNotFoundError + If ``analysis/results.csv`` does not exist. + IndexError + If *row_index* is out of range. + """ + import pandas as pd # noqa: PLC0415 + + from easydiffraction.analysis.sequential import _META_COLUMNS # noqa: PLC0415 + from easydiffraction.core.variable import Parameter # noqa: PLC0415 + + if self.info.path is None: + msg = 'Project has no saved path. Save the project first.' 
+ raise FileNotFoundError(msg) + + csv_path = pathlib.Path(self.info.path) / 'analysis' / 'results.csv' + if not csv_path.is_file(): + msg = f"Results CSV not found: '{csv_path}'" + raise FileNotFoundError(msg) + + df = pd.read_csv(csv_path) + n_rows = len(df) + + # Support Python-style negative indexing + if row_index < 0: + row_index += n_rows + + if row_index < 0 or row_index >= n_rows: + msg = f'Row index {row_index} out of range (CSV has {n_rows} rows).' + raise IndexError(msg) + + row = df.iloc[row_index] + + # 1. Reload data if file_path points to a real file + file_path = row.get('file_path', '') + if file_path and pathlib.Path(file_path).is_file(): + experiment = list(self.experiments.values())[0] + experiment._load_ascii_data_to_experiment(file_path) + + # 2. Override parameter values + all_params = self.structures.parameters + self.experiments.parameters + param_map = { + p.unique_name: p + for p in all_params + if isinstance(p, Parameter) and hasattr(p, 'unique_name') + } + + skip_cols = set(_META_COLUMNS) + for col_name in df.columns: + if col_name in skip_cols: + continue + if col_name.startswith('diffrn.'): + continue + if col_name.endswith('.uncertainty'): + continue + if col_name in param_map and pd.notna(row[col_name]): + param_map[col_name].value = float(row[col_name]) + + # 3. Apply uncertainties + for col_name in df.columns: + if not col_name.endswith('.uncertainty'): + continue + base_name = col_name.removesuffix('.uncertainty') + if base_name in param_map and pd.notna(row[col_name]): + param_map[base_name].uncertainty = float(row[col_name]) + + # 4. Force recalculation: data was replaced directly (bypassing + # value setters), so the dirty flag may not be set. 
+ for structure in self.structures: + structure._need_categories_update = True + for experiment in self.experiments.values(): + experiment._need_categories_update = True + + log.info(f'Applied parameters from CSV row {row_index} (file: {file_path}).') + # ------------------------------------------ # Plotting # ------------------------------------------ @@ -364,6 +574,11 @@ def plot_param_series(self, param: object, versus: object | None = None) -> None """ Plot a parameter's value across sequential fit results. + When a ``results.csv`` file exists in the project's + ``analysis/`` directory, data is read from CSV. Otherwise, + falls back to in-memory parameter snapshots (produced by + ``fit()`` in single mode). + Parameters ---------- param : object @@ -376,10 +591,27 @@ def plot_param_series(self, param: object, versus: object | None = None) -> None experiment sequence number is used instead. """ unique_name = param.unique_name - versus_name = versus.name if versus is not None else None - self.plotter.plot_param_series( - unique_name, - versus_name, - self.experiments, - self.analysis._parameter_snapshots, - ) + + # Try CSV first (produced by fit_sequential or future fit) + csv_path = None + if self.info.path is not None: + candidate = pathlib.Path(self.info.path) / 'analysis' / 'results.csv' + if candidate.is_file(): + csv_path = str(candidate) + + if csv_path is not None: + self.plotter.plot_param_series( + csv_path=csv_path, + unique_name=unique_name, + param_descriptor=param, + versus_descriptor=versus, + ) + else: + # Fallback: in-memory snapshots from fit() single mode + versus_name = versus.name if versus is not None else None + self.plotter.plot_param_series_from_snapshots( + unique_name, + versus_name, + self.experiments, + self.analysis._parameter_snapshots, + ) diff --git a/src/easydiffraction/utils/utils.py b/src/easydiffraction/utils/utils.py index 4a029917..0108422d 100644 --- a/src/easydiffraction/utils/utils.py +++ 
b/src/easydiffraction/utils/utils.py @@ -73,7 +73,7 @@ def _fetch_data_index() -> dict: _validate_url(index_url) # macOS: sha256sum index.json - index_hash = 'sha256:f421aab32ec532782dc62f4440a97320e5cec23b9e64f5ae3f8a3e818d013430' + index_hash = 'sha256:dfde966a084579c2103b0d35ed3e8688ddc6941335e251d3e1735a792ca06144' destination_dirname = 'easydiffraction' destination_fname = 'data-index.json' cache_dir = pooch.os_cache(destination_dirname) @@ -705,19 +705,23 @@ def str_to_ufloat(s: str | None, default: float | None = None) -> UFloat: Parse a CIF-style numeric string into a ufloat. Examples of supported input: - "3.566" → ufloat(3.566, nan) - - "3.566(2)" → ufloat(3.566, 0.002) - None → ufloat(default, nan) + "3.566(2)" → ufloat(3.566, 0.002) - "3.566()" → ufloat(3.566, 0.0) - + None → ufloat(default, nan) Behavior: - If the input string contains a value with parentheses (e.g. "3.566(2)"), the number in parentheses is interpreted as an - estimated standard deviation (esd) in the last digit(s). - If the - input string has no parentheses, an uncertainty of NaN is assigned - to indicate "no esd provided". - If parsing fails, the function - falls back to the given ``default`` value with uncertainty NaN. + estimated standard deviation (esd) in the last digit(s). - Empty + parentheses (e.g. "3.566()") are treated as zero uncertainty. - If + the input string has no parentheses, an uncertainty of NaN is + assigned to indicate "no esd provided". - If parsing fails, the + function falls back to the given ``default`` value with uncertainty + NaN. Parameters ---------- s : str | None - Numeric string in CIF format (e.g. "3.566", "3.566(2)") or None. + Numeric string in CIF format (e.g. "3.566", "3.566(2)", + "3.566()") or None. default : float | None, default=None Default value to use if ``s`` is None or parsing fails. 
@@ -733,6 +737,9 @@ def str_to_ufloat(s: str | None, default: float | None = None) -> UFloat:
     if '(' not in s and ')' not in s:
         s = f'{s}(nan)'
+    elif s.endswith('()'):
+        # Empty brackets → zero uncertainty (free parameter, no esd yet)
+        s = s[:-2] + '(0)'
     try:
         return ufloat_fromstr(s)
     except Exception:
diff --git a/tests/integration/fitting/test_cif_round_trip.py b/tests/integration/fitting/test_cif_round_trip.py
new file mode 100644
index 00000000..b089027b
--- /dev/null
+++ b/tests/integration/fitting/test_cif_round_trip.py
@@ -0,0 +1,322 @@
+# SPDX-FileCopyrightText: 2026 EasyScience contributors
+# SPDX-License-Identifier: BSD-3-Clause
+"""Integration tests for experiment CIF round-trip (as_cif → from_cif_str)."""
+
+from __future__ import annotations
+
+import tempfile
+
+from numpy.testing import assert_almost_equal
+
+from easydiffraction import ExperimentFactory
+from easydiffraction import StructureFactory
+from easydiffraction import download_data
+from easydiffraction.core.variable import Parameter
+
+TEMP_DIR = tempfile.gettempdir()
+
+
+def _build_fully_configured_experiment() -> object:
+    """
+    Create a fully configured powder CWL neutron experiment.
+
+    Includes instrument, peak profile, background, excluded regions,
+    linked phases, and measured data.
+
+    Returns
+    -------
+    ExperimentBase
+        A complete experiment ready for CIF round-trip testing.
+ """ + data_path = download_data(id=3, destination=TEMP_DIR) + expt = ExperimentFactory.from_data_path( + name='hrpt', + data_path=data_path, + ) + # Instrument + expt.instrument.setup_wavelength = 1.494 + expt.instrument.calib_twotheta_offset = 0.6225 + + # Peak profile + expt.peak.broad_gauss_u = 0.0834 + expt.peak.broad_gauss_v = -0.1168 + expt.peak.broad_gauss_w = 0.123 + expt.peak.broad_lorentz_x = 0.0 + expt.peak.broad_lorentz_y = 0.0797 + + # Background + expt.background.create(id='1', x=10, y=170) + expt.background.create(id='2', x=80, y=160) + expt.background.create(id='3', x=165, y=170) + + # Excluded regions + expt.excluded_regions.create(id='1', start=0, end=5) + expt.excluded_regions.create(id='2', start=165, end=180) + + # Linked phases + expt.linked_phases.create(id='lbco', scale=9.0) + + # Free parameters + expt.instrument.calib_twotheta_offset.free = True + expt.linked_phases['lbco'].scale.free = True + expt.background['1'].y.free = True + expt.background['2'].y.free = True + expt.background['3'].y.free = True + + return expt + + +def _collect_param_values(expt: object) -> dict[str, object]: + """ + Collect all parameter values from an experiment. + + Returns a dict keyed by unique_name with the parameter value. + Skips raw data parameters (pd_data.*) since those are large arrays. + """ + result = {} + for p in expt.parameters: + uname = getattr(p, 'unique_name', None) + if uname is None: + continue + # Skip raw data arrays + if 'pd_data.' 
in uname: + continue + result[uname] = p.value + return result + + +def _collect_free_flags(expt: object) -> dict[str, bool]: + """Return {unique_name: free} for fittable parameters.""" + return { + p.unique_name: p.free + for p in expt.parameters + if isinstance(p, Parameter) and not p.unique_name.startswith('pd_data.') + } + + +# ------------------------------------------------------------------ +# Test 1: Experiment CIF round-trip preserves all parameter values +# ------------------------------------------------------------------ + + +def test_experiment_cif_round_trip_preserves_parameters() -> None: + """ + Every parameter value must survive an as_cif → from_cif_str cycle. + + Creates a fully configured experiment, serialises it to CIF, + reconstructs from CIF, and compares all parameter values. + """ + original = _build_fully_configured_experiment() + + # Serialise + cif_str = original.as_cif + + # Reconstruct + loaded = ExperimentFactory.from_cif_str(cif_str) + + # Compare parameter values + orig_params = _collect_param_values(original) + loaded_params = _collect_param_values(loaded) + + for name, orig_val in orig_params.items(): + assert name in loaded_params, f'Parameter {name} missing after round-trip' + loaded_val = loaded_params[name] + if isinstance(orig_val, float): + assert_almost_equal( + loaded_val, + orig_val, + decimal=4, + err_msg=f'Value mismatch for {name}', + ) + else: + assert loaded_val == orig_val, ( + f'Value mismatch for {name}: expected {orig_val!r}, got {loaded_val!r}' + ) + + +# ------------------------------------------------------------------ +# Test 2: Free flags survive the round-trip +# ------------------------------------------------------------------ + + +def test_experiment_cif_round_trip_preserves_free_flags() -> None: + """ + Free flags must survive an as_cif → from_cif_str cycle. + + Parameters marked as free on the original experiment must also be + free on the reconstructed experiment. 
+ """ + original = _build_fully_configured_experiment() + + cif_str = original.as_cif + loaded = ExperimentFactory.from_cif_str(cif_str) + + orig_free = _collect_free_flags(original) + loaded_free = _collect_free_flags(loaded) + + for name, orig_flag in orig_free.items(): + if name in loaded_free: + assert loaded_free[name] == orig_flag, ( + f'Free flag mismatch for {name}: expected {orig_flag}, got {loaded_free[name]}' + ) + + +# ------------------------------------------------------------------ +# Test 3: Categories survive the round-trip +# ------------------------------------------------------------------ + + +def test_experiment_cif_round_trip_preserves_categories() -> None: + """ + Category collections (background, excluded regions, linked phases) + must preserve their item count after a round-trip. + """ + original = _build_fully_configured_experiment() + + cif_str = original.as_cif + loaded = ExperimentFactory.from_cif_str(cif_str) + + # Background points + assert len(loaded.background) == len(original.background), ( + f'Background count mismatch: ' + f'expected {len(original.background)}, got {len(loaded.background)}' + ) + + # Excluded regions + assert len(loaded.excluded_regions) == len(original.excluded_regions), ( + f'Excluded regions count mismatch: ' + f'expected {len(original.excluded_regions)}, ' + f'got {len(loaded.excluded_regions)}' + ) + + # Linked phases + assert len(loaded.linked_phases) == len(original.linked_phases), ( + f'Linked phases count mismatch: ' + f'expected {len(original.linked_phases)}, ' + f'got {len(loaded.linked_phases)}' + ) + + +# ------------------------------------------------------------------ +# Test 4: Data points survive the round-trip +# ------------------------------------------------------------------ + + +def test_experiment_cif_round_trip_preserves_data() -> None: + """ + Measured data points must survive an as_cif → from_cif_str cycle. + + The number of data points and the first/last values must match. 
+ """ + original = _build_fully_configured_experiment() + + cif_str = original.as_cif + loaded = ExperimentFactory.from_cif_str(cif_str) + + # Number of data points + assert len(loaded.data) == len(original.data), ( + f'Data point count mismatch: expected {len(original.data)}, got {len(loaded.data)}' + ) + + # First and last data point two_theta and intensity_meas + orig_first = list(original.data.values())[0] + loaded_first = list(loaded.data.values())[0] + orig_last = list(original.data.values())[-1] + loaded_last = list(loaded.data.values())[-1] + + assert_almost_equal( + loaded_first.two_theta.value, + orig_first.two_theta.value, + decimal=4, + err_msg='First data point two_theta mismatch', + ) + assert_almost_equal( + loaded_last.two_theta.value, + orig_last.two_theta.value, + decimal=4, + err_msg='Last data point two_theta mismatch', + ) + assert_almost_equal( + loaded_first.intensity_meas.value, + orig_first.intensity_meas.value, + decimal=2, + err_msg='First data point intensity_meas mismatch', + ) + + +# ------------------------------------------------------------------ +# Test 5: Structure CIF round-trip preserves all parameter values +# ------------------------------------------------------------------ + + +def test_structure_cif_round_trip_preserves_parameters() -> None: + """ + Every structure parameter must survive an as_cif → from_cif_str + cycle, including atom sites with symmetry constraints. 
+ """ + original = StructureFactory.from_scratch(name='lbco') + original.space_group.name_h_m = 'P m -3 m' + original.cell.length_a = 3.8909 + original.atom_sites.create( + label='La', + type_symbol='La', + fract_x=0, + fract_y=0, + fract_z=0, + wyckoff_letter='a', + occupancy=0.5, + b_iso=0.5, + ) + original.atom_sites.create( + label='Co', + type_symbol='Co', + fract_x=0.5, + fract_y=0.5, + fract_z=0.5, + wyckoff_letter='b', + b_iso=0.5, + ) + original.atom_sites.create( + label='O', + type_symbol='O', + fract_x=0, + fract_y=0.5, + fract_z=0.5, + wyckoff_letter='c', + b_iso=0.5, + ) + # Apply symmetry constraints before serialisation + original._update_categories() + + cif_str = original.as_cif + loaded = StructureFactory.from_cif_str(cif_str) + # Apply symmetry on loaded to match original state + loaded._update_categories() + + # Compare cell parameters + assert_almost_equal( + loaded.cell.length_a.value, + original.cell.length_a.value, + decimal=6, + ) + + # Compare space group + assert loaded.space_group.name_h_m.value == original.space_group.name_h_m.value + + # Compare atom sites count and values + assert len(loaded.atom_sites) == len(original.atom_sites) + for label in ['La', 'Co', 'O']: + orig_site = original.atom_sites[label] + loaded_site = loaded.atom_sites[label] + assert_almost_equal( + loaded_site.fract_x.value, + orig_site.fract_x.value, + decimal=6, + err_msg=f'fract_x mismatch for {label}', + ) + assert_almost_equal( + loaded_site.b_iso.value, + orig_site.b_iso.value, + decimal=4, + err_msg=f'b_iso mismatch for {label}', + ) diff --git a/tests/integration/fitting/test_powder-diffraction_constant-wavelength.py b/tests/integration/fitting/test_powder-diffraction_constant-wavelength.py index 7a98c15d..cb5e640a 100644 --- a/tests/integration/fitting/test_powder-diffraction_constant-wavelength.py +++ b/tests/integration/fitting/test_powder-diffraction_constant-wavelength.py @@ -277,28 +277,25 @@ def 
test_single_fit_neutron_pd_cwl_lbco_with_constraints() -> None:
     # Set aliases for parameters
     project.analysis.aliases.create(
         label='biso_La',
-        param_uid=atom_sites['La'].b_iso.uid,
+        param=atom_sites['La'].b_iso,
     )
     project.analysis.aliases.create(
         label='biso_Ba',
-        param_uid=atom_sites['Ba'].b_iso.uid,
+        param=atom_sites['Ba'].b_iso,
     )
     project.analysis.aliases.create(
         label='occ_La',
-        param_uid=atom_sites['La'].occupancy.uid,
+        param=atom_sites['La'].occupancy,
     )
     project.analysis.aliases.create(
         label='occ_Ba',
-        param_uid=atom_sites['Ba'].occupancy.uid,
+        param=atom_sites['Ba'].occupancy,
     )

     # Set constraints
     project.analysis.constraints.create(expression='biso_Ba = biso_La')
     project.analysis.constraints.create(expression='occ_Ba = 1 - occ_La')

-    # Apply constraints
-    project.analysis.apply_constraints()
-
     # Perform fit
     project.analysis.fit()

@@ -482,6 +479,56 @@ def test_fit_neutron_pd_cwl_hs() -> None:
     )


+def test_single_fit_neutron_pd_cwl_lbco_with_constraints_from_project() -> None:
+    import easydiffraction as ed
+
+    # Create a project from CIF files
+    project = ed.Project()
+    project.structures.add_from_cif_path(ed.download_data(id=1, destination='data'))
+    project.experiments.add_from_cif_path(ed.download_data(id=2, destination='data'))
+
+    # Set aliases and constraints
+    project.analysis.aliases.create(
+        label='biso_La',
+        param=project.structures['lbco'].atom_sites['La'].b_iso,
+    )
+    project.analysis.aliases.create(
+        label='biso_Ba',
+        param=project.structures['lbco'].atom_sites['Ba'].b_iso,
+    )
+
+    project.analysis.aliases.create(
+        label='occ_La',
+        param=project.structures['lbco'].atom_sites['La'].occupancy,
+    )
+    project.analysis.aliases.create(
+        label='occ_Ba',
+        param=project.structures['lbco'].atom_sites['Ba'].occupancy,
+    )
+
+    project.analysis.constraints.create(expression='biso_Ba = biso_La')
+    project.analysis.constraints.create(expression='occ_Ba = 1 - occ_La')
+
+    # Free an additional fit parameter
+    project.structures['lbco'].atom_sites['La'].occupancy.free = True
+
+    # Save to a directory
+    project.save_as('lbco_project')
+
+    # Load the project from the directory
+    project = ed.Project.load('lbco_project')
+
+    # Perform the fit
+    project.analysis.fit()
+
+    # Compare fit quality
+    assert_almost_equal(
+        project.analysis.fit_results.reduced_chi_square,
+        desired=1.28,
+        decimal=1,
+    )
+
+
 if __name__ == '__main__':
     test_fit_neutron_pd_cwl_hs()
     test_single_fit_neutron_pd_cwl_lbco()
diff --git a/tests/integration/fitting/test_project_load.py b/tests/integration/fitting/test_project_load.py
new file mode 100644
index 00000000..789482f3
--- /dev/null
+++ b/tests/integration/fitting/test_project_load.py
@@ -0,0 +1,246 @@
+# SPDX-FileCopyrightText: 2026 EasyScience contributors
+# SPDX-License-Identifier: BSD-3-Clause
+"""Integration tests for Project save → load round-trip."""
+
+from __future__ import annotations
+
+import tempfile
+
+from numpy.testing import assert_almost_equal
+
+from easydiffraction import ExperimentFactory
+from easydiffraction import Project
+from easydiffraction import StructureFactory
+from easydiffraction import download_data
+
+TEMP_DIR = tempfile.gettempdir()
+
+
+# ------------------------------------------------------------------
+# Helpers
+# ------------------------------------------------------------------
+
+
+def _create_lbco_project() -> Project:
+    """
+    Build a complete LBCO project ready for fitting.
+
+    Returns a project with one structure, one experiment (with data),
+    instrument settings, peak profile, background, linked phases, free
+    parameters, aliases, and constraints.
+ """ + # Structure + model = StructureFactory.from_scratch(name='lbco') + model.space_group.name_h_m = 'P m -3 m' + model.cell.length_a = 3.8909 + model.atom_sites.create( + label='La', + type_symbol='La', + fract_x=0, + fract_y=0, + fract_z=0, + wyckoff_letter='a', + occupancy=0.5, + b_iso=0.5, + ) + model.atom_sites.create( + label='Ba', + type_symbol='Ba', + fract_x=0, + fract_y=0, + fract_z=0, + wyckoff_letter='a', + occupancy=0.5, + b_iso=0.5, + ) + model.atom_sites.create( + label='Co', + type_symbol='Co', + fract_x=0.5, + fract_y=0.5, + fract_z=0.5, + wyckoff_letter='b', + b_iso=0.5, + ) + model.atom_sites.create( + label='O', + type_symbol='O', + fract_x=0, + fract_y=0.5, + fract_z=0.5, + wyckoff_letter='c', + b_iso=0.5, + ) + + # Experiment + data_path = download_data(id=3, destination=TEMP_DIR) + expt = ExperimentFactory.from_data_path( + name='hrpt', + data_path=data_path, + ) + expt.instrument.setup_wavelength = 1.494 + expt.instrument.calib_twotheta_offset = 0.6225 + expt.peak.broad_gauss_u = 0.0834 + expt.peak.broad_gauss_v = -0.1168 + expt.peak.broad_gauss_w = 0.123 + expt.peak.broad_lorentz_x = 0 + expt.peak.broad_lorentz_y = 0.0797 + expt.background.create(id='1', x=10, y=170) + expt.background.create(id='2', x=165, y=170) + expt.linked_phases.create(id='lbco', scale=9.0) + + # Project assembly + project = Project(name='lbco_project') + project.structures.add(model) + project.experiments.add(expt) + + # Free parameters + model.cell.length_a.free = True + expt.linked_phases['lbco'].scale.free = True + expt.instrument.calib_twotheta_offset.free = True + expt.background['1'].y.free = True + expt.background['2'].y.free = True + + # Aliases and constraints + project.analysis.aliases.create( + label='biso_La', + param=model.atom_sites['La'].b_iso, + ) + project.analysis.aliases.create( + label='biso_Ba', + param=model.atom_sites['Ba'].b_iso, + ) + project.analysis.constraints.create(expression='biso_Ba = biso_La') + + return project + + +def 
_collect_param_snapshot(project: Project) -> dict[str, float]: + """Return ``{unique_name: value}`` for model parameters (excluding raw data).""" + return { + p.unique_name: p.value + for p in project.parameters + if not p.unique_name.startswith('pd_data.') + } + + +def _collect_free_flags(project: Project) -> dict[str, bool]: + """Return ``{unique_name: free}`` for fittable parameters.""" + from easydiffraction.core.variable import Parameter # noqa: PLC0415 + + return {p.unique_name: p.free for p in project.parameters if isinstance(p, Parameter)} + + +# ------------------------------------------------------------------ +# Test 1: save → load → compare all parameters +# ------------------------------------------------------------------ + + +def test_save_load_round_trip_preserves_parameters(tmp_path) -> None: + """ + Every parameter value must survive a save → load cycle. + + Also verifies project info, free flags, aliases, and constraints. + """ + original = _create_lbco_project() + # Apply symmetry constraints so snapshot matches the loaded state + # (load() calls _update_categories which applies symmetry). 
+ for structure in original.structures: + structure._update_categories() + original_params = _collect_param_snapshot(original) + original_free = _collect_free_flags(original) + + # Save + proj_dir = str(tmp_path / 'lbco_project') + original.save_as(proj_dir) + + # Load + loaded = Project.load(proj_dir) + + # Compare project info + assert loaded.name == original.name + assert loaded.info.title == original.info.title + + # Compare structures + assert loaded.structures.names == original.structures.names + orig_s = original.structures['lbco'] + load_s = loaded.structures['lbco'] + assert load_s.space_group.name_h_m.value == orig_s.space_group.name_h_m.value + assert_almost_equal(load_s.cell.length_a.value, orig_s.cell.length_a.value, decimal=6) + assert len(load_s.atom_sites) == len(orig_s.atom_sites) + + # Compare experiments + assert loaded.experiments.names == original.experiments.names + + # Compare all parameter values + loaded_params = _collect_param_snapshot(loaded) + for name, orig_val in original_params.items(): + assert name in loaded_params, f'Parameter {name} missing after load' + if isinstance(orig_val, float): + assert_almost_equal( + loaded_params[name], + orig_val, + decimal=6, + err_msg=f'Mismatch for {name}', + ) + else: + assert loaded_params[name] == orig_val, f'Mismatch for {name}' + + # Compare free flags + loaded_free = _collect_free_flags(loaded) + for name, orig_flag in original_free.items(): + if name in loaded_free: + assert loaded_free[name] == orig_flag, ( + f'Free flag mismatch for {name}: expected {orig_flag}, got {loaded_free[name]}' + ) + + # Compare aliases + assert len(loaded.analysis.aliases) == len(original.analysis.aliases) + for orig_alias in original.analysis.aliases: + label = orig_alias.label.value + loaded_alias = loaded.analysis.aliases[label] + assert loaded_alias.param_unique_name.value == orig_alias.param_unique_name.value + assert loaded_alias.param is not None, f"Alias '{label}' param reference not resolved" + + # 
Compare constraints + assert len(loaded.analysis.constraints) == len(original.analysis.constraints) + for i, orig_c in enumerate(original.analysis.constraints): + assert loaded.analysis.constraints[i].expression.value == orig_c.expression.value + assert loaded.analysis.constraints.enabled is True + + # Compare analysis settings + assert loaded.analysis.current_minimizer == original.analysis.current_minimizer + assert loaded.analysis.fit_mode.mode.value == original.analysis.fit_mode.mode.value + + +# ------------------------------------------------------------------ +# Test 2: create → fit → save → load → fit → compare χ² +# ------------------------------------------------------------------ + + +def test_save_load_round_trip_preserves_fit_quality(tmp_path) -> None: + """ + A loaded project must produce the same χ² as the original. + + Fits the original project, saves it, loads it back, fits again, + and compares reduced χ² values. + """ + # Create and fit the original project + original = _create_lbco_project() + original.analysis.fit(verbosity='silent') + original_chi2 = original.analysis.fit_results.reduced_chi_square + + # Save the fitted project + proj_dir = str(tmp_path / 'lbco_fitted') + original.save_as(proj_dir) + + # Load + loaded = Project.load(proj_dir) + + # Fit the loaded project + loaded.analysis.fit(verbosity='silent') + loaded_chi2 = loaded.analysis.fit_results.reduced_chi_square + + # The χ² values should be very close (same starting point, + # same data, same model) + assert_almost_equal(loaded_chi2, original_chi2, decimal=1) diff --git a/tests/integration/fitting/test_sequential.py b/tests/integration/fitting/test_sequential.py new file mode 100644 index 00000000..12c14bea --- /dev/null +++ b/tests/integration/fitting/test_sequential.py @@ -0,0 +1,372 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +"""Integration tests for Analysis.fit_sequential().""" + +from __future__ import annotations + 
+import csv +import shutil +import tempfile +from pathlib import Path + +import pytest +from numpy.testing import assert_almost_equal + +from easydiffraction import ExperimentFactory +from easydiffraction import Project +from easydiffraction import StructureFactory +from easydiffraction import download_data +from easydiffraction.utils.enums import VerbosityEnum + +TEMP_DIR = tempfile.gettempdir() + + +def _create_sequential_project(tmp_path: Path) -> tuple[Project, str]: + """ + Build a project for sequential fitting and save it. + + Returns the project and the path to a data directory with a few + copies of the same data file (to simulate a scan). + """ + # Structure + model = StructureFactory.from_scratch(name='lbco') + model.space_group.name_h_m = 'P m -3 m' + model.cell.length_a = 3.8909 + model.atom_sites.create( + label='La', + type_symbol='La', + fract_x=0, + fract_y=0, + fract_z=0, + wyckoff_letter='a', + occupancy=0.5, + b_iso=0.5, + ) + model.atom_sites.create( + label='Ba', + type_symbol='Ba', + fract_x=0, + fract_y=0, + fract_z=0, + wyckoff_letter='a', + occupancy=0.5, + b_iso=0.5, + ) + model.atom_sites.create( + label='Co', + type_symbol='Co', + fract_x=0.5, + fract_y=0.5, + fract_z=0.5, + wyckoff_letter='b', + b_iso=0.5, + ) + model.atom_sites.create( + label='O', + type_symbol='O', + fract_x=0, + fract_y=0.5, + fract_z=0.5, + wyckoff_letter='c', + b_iso=0.5, + ) + + # Experiment (template) + data_path = download_data(id=3, destination=TEMP_DIR) + expt = ExperimentFactory.from_data_path( + name='template', + data_path=data_path, + verbosity=VerbosityEnum.SILENT, + ) + expt.instrument.setup_wavelength = 1.494 + expt.instrument.calib_twotheta_offset = 0.6225 + expt.peak.broad_gauss_u = 0.0834 + expt.peak.broad_gauss_v = -0.1168 + expt.peak.broad_gauss_w = 0.123 + expt.peak.broad_lorentz_x = 0 + expt.peak.broad_lorentz_y = 0.0797 + expt.background.create(id='1', x=10, y=170) + expt.background.create(id='2', x=165, y=170) + 
expt.linked_phases.create(id='lbco', scale=9.0) + + # Project assembly + project = Project(name='seq_test') + project.structures.add(model) + project.experiments.add(expt) + + # Free parameters + model.cell.length_a.free = True + expt.linked_phases['lbco'].scale.free = True + expt.instrument.calib_twotheta_offset.free = True + expt.background['1'].y.free = True + expt.background['2'].y.free = True + + # Initial fit on the template + project.analysis.fit(verbosity='silent') + + # Save project + proj_dir = str(tmp_path / 'seq_project') + project.save_as(proj_dir) + + # Create a data directory with copies of the same data file + data_dir = tmp_path / 'scan_data' + data_dir.mkdir() + for i in range(3): + shutil.copy(data_path, data_dir / f'scan_{i + 1:03d}.xye') + + return project, str(data_dir) + + +# ------------------------------------------------------------------ +# Test 1: Basic sequential fit produces CSV +# ------------------------------------------------------------------ + + +def test_fit_sequential_produces_csv(tmp_path) -> None: + """fit_sequential creates a results.csv with one row per file.""" + project, data_dir = _create_sequential_project(tmp_path) + + project.analysis.fit_sequential( + data_dir=data_dir, + verbosity='silent', + ) + + csv_path = project.info.path / 'analysis' / 'results.csv' + assert csv_path.is_file(), 'results.csv was not created' + + with csv_path.open() as f: + reader = csv.DictReader(f) + rows = list(reader) + + assert len(rows) == 3, f'Expected 3 rows, got {len(rows)}' + + # Each row should have fit_success + for row in rows: + assert row['fit_success'] == 'True', f'Fit failed for {row["file_path"]}' + + # Each row should have parameter values + assert 'lbco.cell.length_a' in rows[0] + assert rows[0]['lbco.cell.length_a'] != '' + + +# ------------------------------------------------------------------ +# Test 2: Crash recovery skips already-fitted files +# ------------------------------------------------------------------ + + +def 
test_fit_sequential_crash_recovery(tmp_path) -> None: + """Running fit_sequential twice does not re-fit already-fitted files.""" + project, data_dir = _create_sequential_project(tmp_path) + + # First run: fit all 3 files + project.analysis.fit_sequential( + data_dir=data_dir, + verbosity='silent', + ) + + csv_path = project.info.path / 'analysis' / 'results.csv' + with csv_path.open() as f: + rows_first = list(csv.DictReader(f)) + assert len(rows_first) == 3 + + # Second run: should skip all 3 files + project.analysis.fit_sequential( + data_dir=data_dir, + verbosity='silent', + ) + + with csv_path.open() as f: + rows_second = list(csv.DictReader(f)) + # Still 3 rows — no duplicates + assert len(rows_second) == 3 + + +# ------------------------------------------------------------------ +# Test 3: Parameter propagation +# ------------------------------------------------------------------ + + +def test_fit_sequential_parameter_propagation(tmp_path) -> None: + """Parameters from one fit propagate to the next.""" + project, data_dir = _create_sequential_project(tmp_path) + + project.analysis.fit_sequential( + data_dir=data_dir, + verbosity='silent', + ) + + csv_path = project.info.path / 'analysis' / 'results.csv' + with csv_path.open() as f: + rows = list(csv.DictReader(f)) + + # All rows should have similar parameter values (same data) + vals = [float(r['lbco.cell.length_a']) for r in rows] + for v in vals: + assert_almost_equal(v, vals[0], decimal=3) + + +# ------------------------------------------------------------------ +# Test 4: extract_diffrn callback +# ------------------------------------------------------------------ + + +def test_fit_sequential_with_diffrn_callback(tmp_path) -> None: + """extract_diffrn callback populates diffrn columns in CSV.""" + project, data_dir = _create_sequential_project(tmp_path) + + temperatures = {'scan_001.xye': 300.0, 'scan_002.xye': 350.0, 'scan_003.xye': 400.0} + + def extract_diffrn(file_path: str) -> dict[str, float]: + 
name = Path(file_path).name + return {'ambient_temperature': temperatures.get(name, 0.0)} + + project.analysis.fit_sequential( + data_dir=data_dir, + extract_diffrn=extract_diffrn, + verbosity='silent', + ) + + csv_path = project.info.path / 'analysis' / 'results.csv' + with csv_path.open() as f: + rows = list(csv.DictReader(f)) + + # Check that temperature column is present and populated + for row in rows: + name = Path(row['file_path']).name + if 'diffrn.ambient_temperature' in row: + expected = temperatures.get(name, 0.0) + assert_almost_equal(float(row['diffrn.ambient_temperature']), expected) + + +# ------------------------------------------------------------------ +# Test 5: Precondition checks +# ------------------------------------------------------------------ + + +def test_fit_sequential_requires_saved_project(tmp_path) -> None: + """fit_sequential raises if project hasn't been saved.""" + data_path = download_data(id=3, destination=TEMP_DIR) + model = StructureFactory.from_scratch(name='s') + expt = ExperimentFactory.from_data_path( + name='e', + data_path=data_path, + verbosity=VerbosityEnum.SILENT, + ) + expt.linked_phases.create(id='s', scale=1.0) + expt.linked_phases['s'].scale.free = True + project = Project(name='unsaved') + project.structures.add(model) + project.experiments.add(expt) + + with pytest.raises(ValueError, match='must be saved'): + project.analysis.fit_sequential(data_dir=str(tmp_path)) + + +def test_fit_sequential_requires_one_structure(tmp_path) -> None: + """fit_sequential raises if no structures exist.""" + project = Project(name='no_struct') + project.save_as(str(tmp_path / 'proj')) + + with pytest.raises(ValueError, match='exactly 1 structure'): + project.analysis.fit_sequential(data_dir=str(tmp_path)) + + +def test_fit_sequential_requires_one_experiment(tmp_path) -> None: + """fit_sequential raises if no experiments exist.""" + model = StructureFactory.from_scratch(name='s') + project = Project(name='no_expt') + 
project.structures.add(model) + project.save_as(str(tmp_path / 'proj')) + + with pytest.raises(ValueError, match='exactly 1 experiment'): + project.analysis.fit_sequential(data_dir=str(tmp_path)) + + +# ------------------------------------------------------------------ +# Test 6: Parallel sequential fit (max_workers=2) +# ------------------------------------------------------------------ + + +def test_fit_sequential_parallel(tmp_path) -> None: + """fit_sequential with max_workers=2 produces correct CSV.""" + project, data_dir = _create_sequential_project(tmp_path) + + project.analysis.fit_sequential( + data_dir=data_dir, + max_workers=2, + verbosity='silent', + ) + + csv_path = project.info.path / 'analysis' / 'results.csv' + assert csv_path.is_file(), 'results.csv was not created' + + with csv_path.open() as f: + reader = csv.DictReader(f) + rows = list(reader) + + assert len(rows) == 3, f'Expected 3 rows, got {len(rows)}' + + for row in rows: + assert row['fit_success'] == 'True', f'Fit failed for {row["file_path"]}' + + # Parameter values should be present and reasonable + assert 'lbco.cell.length_a' in rows[0] + vals = [float(r['lbco.cell.length_a']) for r in rows] + for v in vals: + assert_almost_equal(v, vals[0], decimal=3) + + +# ------------------------------------------------------------------ +# Test 7: Dataset replay from CSV (apply_params_from_csv) +# ------------------------------------------------------------------ + + +def test_apply_params_from_csv_loads_data_and_params(tmp_path) -> None: + """apply_params_from_csv overrides params and reloads data.""" + project, data_dir = _create_sequential_project(tmp_path) + + project.analysis.fit_sequential( + data_dir=data_dir, + verbosity='silent', + ) + + csv_path = project.info.path / 'analysis' / 'results.csv' + with csv_path.open() as f: + rows = list(csv.DictReader(f)) + + # Read the expected cell_length_a from CSV row 1 + expected_a = float(rows[1]['lbco.cell.length_a']) + + # Apply params from row 1 + 
project.apply_params_from_csv(row_index=1) + + # Verify the parameter value was overridden + model = list(project.structures.values())[0] + assert_almost_equal(model.cell.length_a.value, expected_a, decimal=5) + + # Verify that the experiment has measured data loaded + # (from the file_path in that CSV row) + expt = list(project.experiments.values())[0] + assert expt.data.intensity_meas is not None + + +def test_apply_params_from_csv_raises_on_missing_csv(tmp_path) -> None: + """apply_params_from_csv raises if no CSV exists.""" + project = Project(name='no_csv') + project.save_as(str(tmp_path / 'proj')) + + with pytest.raises(FileNotFoundError, match='Results CSV not found'): + project.apply_params_from_csv(row_index=0) + + +def test_apply_params_from_csv_raises_on_bad_index(tmp_path) -> None: + """apply_params_from_csv raises on out-of-range index.""" + project, data_dir = _create_sequential_project(tmp_path) + + project.analysis.fit_sequential( + data_dir=data_dir, + verbosity='silent', + ) + + with pytest.raises(IndexError, match='out of range'): + project.apply_params_from_csv(row_index=99) diff --git a/tests/unit/easydiffraction/analysis/categories/test_aliases.py b/tests/unit/easydiffraction/analysis/categories/test_aliases.py index 2545218a..5efd265c 100644 --- a/tests/unit/easydiffraction/analysis/categories/test_aliases.py +++ b/tests/unit/easydiffraction/analysis/categories/test_aliases.py @@ -3,15 +3,25 @@ from easydiffraction.analysis.categories.aliases import Alias from easydiffraction.analysis.categories.aliases import Aliases +from easydiffraction.core.validation import AttributeSpec +from easydiffraction.core.variable import Parameter +from easydiffraction.io.cif.handler import CifHandler def test_alias_creation_and_collection(): + p1 = Parameter( + name='b_iso', + value_spec=AttributeSpec(default=0.5), + cif_handler=CifHandler(names=['_atom_site.b_iso']), + ) a = Alias() a.label = 'x' - a.param_uid = 'p1' + a._set_param(p1) assert a.label.value == 
'x' + assert a.param is p1 coll = Aliases() - coll.create(label='x', param_uid='p1') + coll.create(label='x', param=p1) # Collections index by entry name; check via names or direct indexing assert 'x' in coll.names - assert coll['x'].param_uid.value == 'p1' + assert coll['x'].param is p1 + assert coll['x'].param_unique_name.value == p1.unique_name diff --git a/tests/unit/easydiffraction/core/test_parameters.py b/tests/unit/easydiffraction/core/test_parameters.py index 87ded2b6..d9f96b0e 100644 --- a/tests/unit/easydiffraction/core/test_parameters.py +++ b/tests/unit/easydiffraction/core/test_parameters.py @@ -66,8 +66,8 @@ def test_parameter_string_repr_and_as_cif_and_flags(): assert 'A' in s assert '(free=True)' in s - # CIF line is ` ` - assert p.as_cif == '_param.a 2.50000000' + # CIF line: free param with uncertainty uses esd brackets + assert p.as_cif == '_param.a 2.50000000(10000000)' # CifHandler uid is owner's unique_name (parameter name here) assert p._cif_handler.uid == p.unique_name == 'a' diff --git a/tests/unit/easydiffraction/core/test_singletons.py b/tests/unit/easydiffraction/core/test_singletons.py index ba69f07a..a68f76d8 100644 --- a/tests/unit/easydiffraction/core/test_singletons.py +++ b/tests/unit/easydiffraction/core/test_singletons.py @@ -1,12 +1,10 @@ # SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause -import pytest +from easydiffraction.core.singleton import ConstraintsHandler -def test_uid_map_handler_rejects_non_descriptor(): - from easydiffraction.core.singleton import UidMapHandler - - h = UidMapHandler.get() - with pytest.raises(TypeError): - h.add_to_uid_map(object()) +def test_constraints_handler_is_singleton(): + h1 = ConstraintsHandler.get() + h2 = ConstraintsHandler.get() + assert h1 is h2 diff --git a/tests/unit/easydiffraction/io/test_ascii.py b/tests/unit/easydiffraction/io/test_ascii.py new file mode 100644 index 00000000..ab180701 --- /dev/null +++ 
b/tests/unit/easydiffraction/io/test_ascii.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +"""Tests for extract_project_from_zip, extract_data_paths_from_zip and extract_data_paths_from_dir.""" + +from __future__ import annotations + +import zipfile + +import pytest + +from easydiffraction.io.ascii import extract_data_paths_from_dir +from easydiffraction.io.ascii import extract_data_paths_from_zip +from easydiffraction.io.ascii import extract_project_from_zip + + +class TestExtractProjectFromZip: + """Tests for extract_project_from_zip.""" + + def test_extracts_project_dir(self, tmp_path): + """Returns path to the directory containing project.cif.""" + zip_path = tmp_path / 'proj.zip' + with zipfile.ZipFile(zip_path, 'w') as zf: + zf.writestr('my_project/project.cif', 'data_project\n') + zf.writestr('my_project/structures/struct.cif', 'data_struct\n') + + result = extract_project_from_zip(zip_path, destination=tmp_path / 'out') + + assert result.endswith('my_project') + assert (tmp_path / 'out' / 'my_project' / 'project.cif').is_file() + + def test_extracts_to_temp_dir_by_default(self, tmp_path): + """Without destination, files go to a temp directory.""" + zip_path = tmp_path / 'proj.zip' + with zipfile.ZipFile(zip_path, 'w') as zf: + zf.writestr('myproj/project.cif', 'data_project\n') + + result = extract_project_from_zip(zip_path) + + assert 'myproj' in result + assert 'project.cif' not in result # returns parent dir, not file + + def test_raises_file_not_found(self, tmp_path): + """Raises FileNotFoundError for missing ZIP path.""" + with pytest.raises(FileNotFoundError): + extract_project_from_zip(tmp_path / 'missing.zip') + + def test_raises_value_error_no_project_cif(self, tmp_path): + """Raises ValueError when ZIP has no project.cif.""" + zip_path = tmp_path / 'bad.zip' + with zipfile.ZipFile(zip_path, 'w') as zf: + zf.writestr('data.dat', '1 2 3\n') + + with pytest.raises(ValueError, 
match='No project.cif found'):
+            extract_project_from_zip(zip_path)
+
+    def test_destination_creates_directory(self, tmp_path):
+        """Destination directory is created if it does not exist."""
+        zip_path = tmp_path / 'proj.zip'
+        dest = tmp_path / 'nested' / 'output'
+        with zipfile.ZipFile(zip_path, 'w') as zf:
+            zf.writestr('proj/project.cif', 'data\n')
+
+        result = extract_project_from_zip(zip_path, destination=dest)
+
+        assert dest.is_dir()
+        assert 'proj' in result
+
+    def test_ignores_other_project_cif_in_destination(self, tmp_path):
+        """Only finds project.cif from the zip, not pre-existing ones."""
+        dest = tmp_path / 'data'
+        # Pre-create another project directory in the destination
+        other_project = dest / 'aaa_other' / 'project.cif'
+        other_project.parent.mkdir(parents=True)
+        other_project.write_text('other\n')
+
+        zip_path = tmp_path / 'proj.zip'
+        with zipfile.ZipFile(zip_path, 'w') as zf:
+            zf.writestr('target_project/project.cif', 'correct\n')
+
+        result = extract_project_from_zip(zip_path, destination=dest)
+
+        assert 'target_project' in result
+        assert 'aaa_other' not in result
+
+
+class TestExtractDataPathsFromZip:
+    """Tests for extract_data_paths_from_zip."""
+
+    def test_extracts_to_temp_dir_by_default(self, tmp_path):
+        """Without destination, files go to a temp directory."""
+        zip_path = tmp_path / 'test.zip'
+        with zipfile.ZipFile(zip_path, 'w') as zf:
+            zf.writestr('scan_001.dat', '1 2 3\n')
+            zf.writestr('scan_002.dat', '4 5 6\n')
+
+        paths = extract_data_paths_from_zip(zip_path)
+
+        assert len(paths) == 2
+        assert 'scan_001.dat' in paths[0]
+        assert 'scan_002.dat' in paths[1]
+
+    def test_extracts_to_destination(self, tmp_path):
+        """With destination, files go to the specified directory."""
+        zip_path = tmp_path / 'test.zip'
+        dest = tmp_path / 'output'
+        with zipfile.ZipFile(zip_path, 'w') as zf:
+            zf.writestr('scan_001.dat', '1 2 3\n')
+            zf.writestr('scan_002.dat', '4 5 6\n')
+
+        paths = extract_data_paths_from_zip(zip_path, destination=dest)
+
+        assert len(paths) == 2
+        assert all(str(dest) in p for p in paths)
+        assert (dest / 'scan_001.dat').is_file()
+        assert (dest / 'scan_002.dat').is_file()
+
+    def test_destination_creates_directory(self, tmp_path):
+        """Destination directory is created if it does not exist."""
+        zip_path = tmp_path / 'test.zip'
+        dest = tmp_path / 'nested' / 'output'
+        with zipfile.ZipFile(zip_path, 'w') as zf:
+            zf.writestr('data.dat', '1 2 3\n')
+
+        paths = extract_data_paths_from_zip(zip_path, destination=dest)
+
+        assert len(paths) == 1
+        assert dest.is_dir()
+
+    def test_raises_file_not_found(self, tmp_path):
+        """Raises FileNotFoundError for missing ZIP path."""
+        with pytest.raises(FileNotFoundError):
+            extract_data_paths_from_zip(tmp_path / 'missing.zip')
+
+    def test_raises_value_error_for_empty_zip(self, tmp_path):
+        """Raises ValueError when ZIP has no usable files."""
+        zip_path = tmp_path / 'empty.zip'
+        with zipfile.ZipFile(zip_path, 'w') as zf:
+            zf.writestr('.hidden', 'hidden\n')
+
+        with pytest.raises(ValueError, match='No data files found'):
+            extract_data_paths_from_zip(zip_path)
+
+    def test_excludes_hidden_files(self, tmp_path):
+        """Hidden files are excluded from returned paths."""
+        zip_path = tmp_path / 'test.zip'
+        with zipfile.ZipFile(zip_path, 'w') as zf:
+            zf.writestr('data.dat', '1 2 3\n')
+            zf.writestr('.hidden', 'hidden\n')
+            zf.writestr('__meta', 'meta\n')
+
+        paths = extract_data_paths_from_zip(zip_path)
+
+        assert len(paths) == 1
+        assert 'data.dat' in paths[0]
+
+    def test_returns_sorted_paths(self, tmp_path):
+        """Returned paths are sorted lexicographically."""
+        zip_path = tmp_path / 'test.zip'
+        with zipfile.ZipFile(zip_path, 'w') as zf:
+            zf.writestr('c.dat', '3\n')
+            zf.writestr('a.dat', '1\n')
+            zf.writestr('b.dat', '2\n')
+
+        paths = extract_data_paths_from_zip(zip_path)
+
+        assert 'a.dat' in paths[0]
+        assert 'b.dat' in paths[1]
+        assert 'c.dat' in paths[2]
+
+
+class TestExtractDataPathsFromDir:
+    """Tests for extract_data_paths_from_dir."""
+
+    def test_lists_files_in_directory(self, tmp_path):
+        """Returns sorted paths for files in a directory."""
+        (tmp_path / 'scan_002.dat').write_text('2\n')
+        (tmp_path / 'scan_001.dat').write_text('1\n')
+
+        paths = extract_data_paths_from_dir(tmp_path)
+
+        assert len(paths) == 2
+        assert 'scan_001.dat' in paths[0]
+        assert 'scan_002.dat' in paths[1]
+
+    def test_raises_for_missing_directory(self, tmp_path):
+        """Raises FileNotFoundError for non-existent directory."""
+        with pytest.raises(FileNotFoundError):
+            extract_data_paths_from_dir(tmp_path / 'missing')
+
+    def test_raises_for_empty_directory(self, tmp_path):
+        """Raises ValueError when directory has no matching files."""
+        empty = tmp_path / 'empty'
+        empty.mkdir()
+
+        with pytest.raises(ValueError, match='No files matching'):
+            extract_data_paths_from_dir(empty)
diff --git a/tests/unit/easydiffraction/project/test_project_load.py b/tests/unit/easydiffraction/project/test_project_load.py
new file mode 100644
index 00000000..c676cb89
--- /dev/null
+++ b/tests/unit/easydiffraction/project/test_project_load.py
@@ -0,0 +1,142 @@
+# SPDX-FileCopyrightText: 2026 EasyScience contributors
+# SPDX-License-Identifier: BSD-3-Clause
+"""Unit tests for Project.load()."""
+
+from __future__ import annotations
+
+import pytest
+
+from easydiffraction.project.project import Project
+
+
+class TestLoadMinimal:
+    """Load a project that has no structures or experiments."""
+
+    def test_raises_on_missing_directory(self, tmp_path):
+        missing = tmp_path / 'nonexistent'
+        with pytest.raises(FileNotFoundError, match='not found'):
+            Project.load(str(missing))
+
+    def test_round_trips_empty_project(self, tmp_path):
+        original = Project(name='empty', title='Empty', description='nothing')
+        original.save_as(str(tmp_path / 'proj'))
+
+        loaded = Project.load(str(tmp_path / 'proj'))
+
+        assert loaded.name == 'empty'
+        assert loaded.info.title == 'Empty'
+        assert loaded.info.description == 'nothing'
+        assert loaded.info.path is not None
+        assert len(loaded.structures) == 0
+        assert len(loaded.experiments) == 0
+
+
+class TestLoadStructures:
+    """Load structures from a saved project."""
+
+    def test_round_trips_structure(self, tmp_path):
+        original = Project(name='s1')
+        original.structures.create(name='cosio')
+        s = original.structures['cosio']
+        s.space_group.name_h_m = 'P m -3 m'
+        s.cell.length_a = 3.88
+        s.atom_sites.create(
+            label='Co',
+            type_symbol='Co',
+            fract_x=0.0,
+            fract_y=0.0,
+            fract_z=0.0,
+            b_iso=0.5,
+        )
+        original.save_as(str(tmp_path / 'proj'))
+
+        loaded = Project.load(str(tmp_path / 'proj'))
+
+        assert len(loaded.structures) == 1
+        ls = loaded.structures['cosio']
+        assert ls.space_group.name_h_m.value == 'P m -3 m'
+        assert abs(ls.cell.length_a.value - 3.88) < 1e-6
+        assert len(ls.atom_sites) == 1
+        assert ls.atom_sites['Co'].type_symbol.value == 'Co'
+        assert abs(ls.atom_sites['Co'].b_iso.value - 0.5) < 1e-6
+
+
+class TestLoadAnalysis:
+    """Load analysis settings from a saved project."""
+
+    def test_round_trips_minimizer(self, tmp_path):
+        original = Project(name='a1')
+        original.save_as(str(tmp_path / 'proj'))
+
+        loaded = Project.load(str(tmp_path / 'proj'))
+
+        assert loaded.analysis.current_minimizer == 'lmfit'
+
+    def test_round_trips_fit_mode(self, tmp_path):
+        original = Project(name='a2')
+        original.analysis.fit_mode.mode = 'joint'
+        original.save_as(str(tmp_path / 'proj'))
+
+        loaded = Project.load(str(tmp_path / 'proj'))
+
+        assert loaded.analysis.fit_mode.mode.value == 'joint'
+
+    def test_round_trips_constraints(self, tmp_path):
+        original = Project(name='c1')
+        original.structures.create(name='s')
+        s = original.structures['s']
+        s.cell.length_a = 5.0
+        s.cell.length_b = 5.0
+
+        original.analysis.aliases.create(
+            label='a_param',
+            param=s.cell.length_a,
+        )
+        original.analysis.aliases.create(
+            label='b_param',
+            param=s.cell.length_b,
+        )
+        original.analysis.constraints.create(expression='b_param = a_param')
+        original.save_as(str(tmp_path / 'proj'))
+
+        loaded = Project.load(str(tmp_path / 'proj'))
+
+        assert len(loaded.analysis.aliases) == 2
+        assert loaded.analysis.aliases['a_param'].label.value == 'a_param'
+        assert loaded.analysis.aliases['b_param'].label.value == 'b_param'
+        # Verify alias param references are resolved
+        assert loaded.analysis.aliases['a_param'].param is not None
+        assert loaded.analysis.aliases['b_param'].param is not None
+
+        assert len(loaded.analysis.constraints) == 1
+        assert loaded.analysis.constraints[0].expression.value == 'b_param = a_param'
+        assert loaded.analysis.constraints.enabled is True
+
+
+class TestLoadAnalysisCifFallback:
+    """Load falls back from analysis/analysis.cif to analysis.cif at root."""
+
+    def test_loads_analysis_from_subdir(self, tmp_path):
+        """Current save layout: analysis/analysis.cif."""
+        original = Project(name='fb1')
+        original.save_as(str(tmp_path / 'proj'))
+
+        # Verify analysis.cif is in analysis/ subdirectory (current save layout)
+        assert (tmp_path / 'proj' / 'analysis' / 'analysis.cif').is_file()
+
+        loaded = Project.load(str(tmp_path / 'proj'))
+        assert loaded.analysis.current_minimizer == 'lmfit'
+
+    def test_loads_analysis_from_root_fallback(self, tmp_path):
+        """Old layout fallback: analysis.cif at project root."""
+        original = Project(name='fb2')
+        original.save_as(str(tmp_path / 'proj'))
+
+        # Move analysis.cif from analysis/ subdirectory to project root
+        proj_dir = tmp_path / 'proj'
+        analysis_dir = proj_dir / 'analysis'
+        (analysis_dir / 'analysis.cif').rename(proj_dir / 'analysis.cif')
+        analysis_dir.rmdir()
+
+        loaded = Project.load(str(proj_dir))
+        assert loaded.analysis.current_minimizer == 'lmfit'
diff --git a/tests/unit/easydiffraction/project/test_project_load_and_summary_wrap.py b/tests/unit/easydiffraction/project/test_project_load_and_summary_wrap.py
index cdeafd35..69f84127 100644
--- a/tests/unit/easydiffraction/project/test_project_load_and_summary_wrap.py
+++ b/tests/unit/easydiffraction/project/test_project_load_and_summary_wrap.py
@@ -2,15 +2,27 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 
-def test_project_load_prints_and_sets_path(tmp_path, capsys):
+def test_project_load_raises_on_missing_directory(tmp_path):
     import pytest
 
     from easydiffraction.project.project import Project
 
-    p = Project()
-    dir_path = tmp_path / 'pdir'
-    with pytest.raises(NotImplementedError, match='not implemented yet'):
-        p.load(str(dir_path))
+    missing_dir = tmp_path / 'nonexistent'
+    with pytest.raises(FileNotFoundError, match='not found'):
+        Project.load(str(missing_dir))
+
+
+def test_project_load_reads_project_info(tmp_path):
+    from easydiffraction.project.project import Project
+
+    p = Project(name='myproj', title='My Title', description='A description')
+    p.save_as(str(tmp_path / 'proj'))
+
+    loaded = Project.load(str(tmp_path / 'proj'))
+    assert loaded.name == 'myproj'
+    assert loaded.info.title == 'My Title'
+    assert loaded.info.description == 'A description'
+    assert loaded.info.path is not None
 
 
 def test_summary_show_project_info_wraps_description(capsys):
diff --git a/tests/unit/easydiffraction/project/test_project_save.py b/tests/unit/easydiffraction/project/test_project_save.py
index ac8b9895..bf632e11 100644
--- a/tests/unit/easydiffraction/project/test_project_save.py
+++ b/tests/unit/easydiffraction/project/test_project_save.py
@@ -13,7 +13,7 @@ def test_project_save_uses_cwd_when_no_explicit_path(monkeypatch, tmp_path, caps
     # It should announce saving and create the three core files
     assert 'Saving project' in out
     assert (tmp_path / 'project.cif').exists()
-    assert (tmp_path / 'analysis.cif').exists()
+    assert (tmp_path / 'analysis' / 'analysis.cif').exists()
     assert (tmp_path / 'summary.cif').exists()
@@ -34,7 +34,7 @@ def test_project_save_as_writes_core_files(tmp_path, monkeypatch):
     # Assert expected files/dirs exist
     assert (target / 'project.cif').is_file()
-    assert (target / 'analysis.cif').is_file()
+    assert (target / 'analysis' / 'analysis.cif').is_file()
     assert (target / 'summary.cif').is_file()
     assert (target / 'structures').is_dir()
     assert (target / 'experiments').is_dir()