Skip to content

Commit 45e72a2

Browse files
committed
unified get_count_rate method and corrected get_elepsed_time
1 parent 2b9372e commit 45e72a2

1 file changed

Lines changed: 209 additions & 84 deletions

File tree

src/sed/loader/cfel/loader.py

Lines changed: 209 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import numpy as np
2020
import scipy.interpolate as sint
2121
from natsort import natsorted
22-
from typing import Sequence
2322

2423
from sed.core.logging import set_verbosity
2524
from sed.core.logging import setup_logging
@@ -168,14 +167,14 @@ def _initialize_dirs(self) -> None:
168167
self.processed_dir = str(processed_dir)
169168
self.meta_dir = str(meta_dir)
170169

170+
@staticmethod
171171
def _file_index(path: Path) -> int:
172172
"""
173173
Extract file index from filename.
174174
Returns 0 for single-file runs.
175175
"""
176176
stem = path.stem # no extension
177177
parts = stem.rsplit("_", 1)
178-
179178
if len(parts) == 2 and parts[1].isdigit():
180179
return int(parts[1])
181180

@@ -285,13 +284,16 @@ def file_index(path: Path) -> int:
285284

286285
files: list[Path] = []
287286
for folder in folders:
288-
files.extend(
289-
natsorted(
290-
Path(folder).glob(file_pattern),
291-
key=file_index,
292-
)
293-
)
294-
287+
found = list(Path(folder).glob(file_pattern))
288+
# Use the static method directly
289+
files.extend(natsorted(found, key=self._file_index))
290+
# for folder in folders:
291+
# files.extend(
292+
# natsorted(
293+
# Path(folder).glob(file_pattern),
294+
# key=file_index,
295+
# )
296+
# )
295297
if not files:
296298
raise FileNotFoundError(
297299
f"No files found for run {run_id} in directory {folders}",
@@ -540,34 +542,108 @@ def get_count_rate_simple(
540542
count_rate = np.array(all_counts) / np.array(elapsed_times)
541543
times = np.cumsum(elapsed_times)
542544
return count_rate, times
545+
# def get_count_rate(
546+
# self,
547+
# fids: Sequence[int] | None = None,
548+
# runs: Sequence[int] | None = None,
549+
# **kwds,
550+
# ) -> tuple[np.ndarray, np.ndarray]:
551+
# """
552+
# Returns the count rate. By default, returns high-resolution
553+
# point-resolved rates using the millisecond counter.
554+
555+
# Args:
556+
# fids (Sequence[int], optional):
557+
# File IDs to include. Defaults to all files.
558+
# runs (Sequence[int], optional):
559+
# Run IDs to include. If provided, overrides `fids`.
560+
# **kwds:
561+
# Additional arguments passed to `get_count_rate_ms`.
562+
# - mode: "point" (default) or "file".
563+
564+
# Returns:
565+
# tuple[np.ndarray, np.ndarray]:
566+
# - count_rate : array of count rates in Hz
567+
# - time : array of global times in seconds since scan start
568+
# """
569+
# mode = kwds.pop("mode", "point")
570+
# # Resolve runs to fids before calling get_count_rate_ms
571+
# fids_resolved = self._resolve_fids(fids=fids, runs=runs)
572+
# return self.get_count_rate_ms(fids=fids_resolved, mode=mode, **kwds)
543573
def get_count_rate(
544574
self,
545575
fids: Sequence[int] | None = None,
546576
runs: Sequence[int] | None = None,
577+
method: str = "fast", # "fast" (metadata) or "precise" (h5)
578+
mode: str = "file", # "file" (1 pt/file) or "point" (intra-file)
547579
**kwds,
548580
) -> tuple[np.ndarray, np.ndarray]:
549581
"""
550-
Returns the count rate. By default, returns high-resolution
551-
point-resolved rates using the millisecond counter.
552-
553-
Args:
554-
fids (Sequence[int], optional):
555-
File IDs to include. Defaults to all files.
556-
runs (Sequence[int], optional):
557-
Run IDs to include. If provided, overrides `fids`.
558-
**kwds:
559-
Additional arguments passed to `get_count_rate_ms`.
560-
- mode: "point" (default) or "file".
561-
562-
Returns:
563-
tuple[np.ndarray, np.ndarray]:
564-
- count_rate : array of count rates in Hz
565-
- time : array of global times in seconds since scan start
582+
Returns the count rate for specified files or runs.
583+
584+
By default, calculates a fast, file-resolved count rate using metadata.
585+
Supports high-resolution, hardware-timed rates within files when
586+
method='precise' is used.
587+
588+
Parameters
589+
----------
590+
fids : Sequence[int], optional
591+
File indices to include. Defaults to all loaded files.
592+
runs : Sequence[int], optional
593+
Run IDs to include. If provided, overrides `fids`.
594+
method : {"fast", "precise"}, default "fast"
595+
Calculation methodology:
596+
- "fast": Uses pre-collected metadata (very quick, low RAM).
597+
- "precise": Reads hardware 'millisecCounter' from H5 files.
598+
mode : {"file", "point"}, default "file"
599+
Temporal resolution:
600+
- "file": One average rate per file.
601+
- "point": Intra-file time-resolved rates (hardware or statistical).
602+
**kwds : dict
603+
Additional arguments:
604+
- time_bin_size (float): Binning for "fast" + "point" mode (default: 1.0s).
605+
- bin_size (int): Rolling average window for "precise" + "point" mode.
606+
607+
Returns
608+
-------
609+
count_rate : np.ndarray
610+
Array of count rates in Hz.
611+
time : np.ndarray
612+
Array of global times in seconds since the start of the scan.
613+
614+
Notes
615+
-----
616+
'Precise' mode requires 'millisecCounter' to be present in the H5 files.
617+
'Fast' mode requires 'file_statistics' to be populated in the loader.
566618
"""
567-
mode = kwds.pop("mode", "point")
568-
# Resolve runs to fids before calling get_count_rate_ms
569619
fids_resolved = self._resolve_fids(fids=fids, runs=runs)
570-
return self.get_count_rate_ms(fids=fids_resolved, mode=mode, **kwds)
620+
621+
if method == "fast":
622+
if mode == "file":
623+
# Original get_count_rate_simple logic
624+
all_counts = [
625+
self.metadata["file_statistics"]["electron"][str(fid)]["num_rows"]
626+
for fid in fids_resolved
627+
]
628+
# Use metadata-based duration if available, else quick H5 peek
629+
durations = self.get_elapsed_time(fids=fids_resolved)
630+
rates = np.array(all_counts) / np.array(durations)
631+
times = np.cumsum(durations) # Simplified timeline
632+
return rates, times
633+
634+
elif mode == "point":
635+
# Original get_count_rate_time_resolved logic
636+
# Statistical 'point' mode: resolution without high-precision H5 reading
637+
return self.get_count_rate_time_resolved(
638+
fids=fids_resolved,
639+
time_bin_size=kwds.get("time_bin_size", 1.0)
640+
)
641+
642+
elif method == "precise":
643+
# Always uses get_count_rate_ms logic (which already handles file vs point)
644+
return self.get_count_rate_ms(fids=fids_resolved, mode=mode, **kwds)
645+
646+
raise ValueError(f"Invalid method/mode combination: {method}/{mode}")
571647

572648
# -------------------------------
573649
# Time-resolved count rate (binned)
@@ -635,87 +711,136 @@ def get_elapsed_time(
635711
fids: Sequence[int] | None = None,
636712
*,
637713
runs: Sequence[int] | None = None,
638-
first_files: int | None = None,
714+
precise: bool = False,
639715
aggregate: bool = False,
640716
) -> float | list[float]:
641717
"""
642-
Calculates the elapsed acquisition time using millisecCounter.
718+
Calculates the elapsed acquisition time for specified files or runs.
643719
644-
Uses millisecCounter directly from H5 files for accurate duration calculation.
720+
Determines the duration of data collection. By default, it uses fast
721+
metadata-based timestamps. If 'precise' is True, it reads the hardware
722+
millisecCounter directly from the H5 files.
645723
646724
Parameters
647725
----------
648726
fids : Sequence[int] | None
649-
File IDs to include.
727+
File indices to include.
650728
runs : Sequence[int] | None
651-
Run IDs to include.
729+
Run IDs to include. If provided, overrides fids.
730+
precise : bool, default False
731+
If True, forces reading the hardware 'millisecCounter' from HDF5 files
732+
for higher accuracy. If False, uses pre-collected metadata timestamps.
652733
first_files : int | None
653-
Limit to first N resolved files.
654-
aggregate : bool
655-
If True, return total elapsed time (s),
656-
otherwise return per-file elapsed times.
734+
Limit the result to the first N resolved files.
735+
aggregate : bool, default False
736+
If True, returns the total sum of elapsed time (s).
737+
If False, returns a list of elapsed times per file.
657738
658739
Returns
659740
-------
660741
float | list[float]
661742
Elapsed time(s) in seconds.
662-
"""
663743
664-
millis_key = self._config.get("millis_counter_key", "/DLD/millisecCounter")
665-
666-
# ----------------------------
667-
# Resolve files consistently
668-
# ----------------------------
669-
fids_resolved = self._resolve_fids(
670-
fids=fids,
671-
runs=runs,
672-
first_files=first_files,
673-
)
674-
675-
elapsed_per_file: list[float] = []
744+
Raises
745+
------
746+
KeyError
747+
If `precise=True` and the hardware counter is missing from the H5 file.
748+
"""
749+
fids_resolved = self._resolve_fids(fids=fids, runs=runs)
750+
elapsed_per_file = []
676751

677752
for fid in fids_resolved:
678-
try:
679-
with h5py.File(self.files[fid], "r") as h5:
680-
if millis_key not in h5:
681-
raise KeyError(f"millisecCounter not found in file {self.files[fid]}")
753+
dt_s = None
754+
755+
# 1. Try Metadata first (Fast & Safe)
756+
if not precise:
757+
try:
758+
# Accessing the statistics you stored in self.metadata
759+
file_stats = self.metadata["file_statistics"]["timed"][str(fid)]
760+
time_stamps = file_stats["columns"].get("timeStamp", file_stats["columns"].get("timestamp"))
682761

683-
ms = np.asarray(h5[millis_key], dtype=np.float64)
762+
# If these are stored as datetime objects, get the delta
763+
t_min = time_stamps["min"]
764+
t_max = time_stamps["max"]
684765

685-
if len(ms) == 0:
686-
raise ValueError(f"Empty millisecCounter in file {self.files[fid]}")
766+
# Handle both float timestamps and datetime objects
767+
t1 = t_min.total_seconds() if hasattr(t_min, "total_seconds") else float(t_min)
768+
t2 = t_max.total_seconds() if hasattr(t_max, "total_seconds") else float(t_max)
769+
dt_s = t2 - t1
770+
except (KeyError, TypeError):
771+
logger.debug(f"Metadata duration missing for fid {fid}, falling back to H5.")
772+
773+
# 2. Try H5 millisecCounter (Precise, but risky)
774+
if dt_s is None:
775+
millis_key = self._config.get("millis_counter_key", "/DLD/millisecCounter")
776+
try:
777+
with h5py.File(self.files[fid], "r") as h5:
778+
ms = np.asarray(h5[millis_key], dtype=np.float64)
779+
dt_s = (ms[-1] - ms[0]) / 1000.0
780+
except (KeyError, IndexError):
781+
# Ultimate fallback: if everything fails, we can't calculate a rate
782+
logger.warning(f"Could not determine duration for fid {fid}. Using 0.0")
783+
dt_s = 0.0
784+
785+
elapsed_per_file.append(dt_s)
786+
787+
return sum(elapsed_per_file) if aggregate else elapsed_per_file
788+
789+
# millis_key = self._config.get("millis_counter_key", "/DLD/millisecCounter")
790+
791+
# # ----------------------------
792+
# # Resolve files consistently
793+
# # ----------------------------
794+
# fids_resolved = self._resolve_fids(
795+
# fids=fids,
796+
# runs=runs,
797+
# first_files=first_files,
798+
# )
799+
800+
# elapsed_per_file: list[float] = []
801+
802+
# for fid in fids_resolved:
803+
# try:
804+
# with h5py.File(self.files[fid], "r") as h5:
805+
# if millis_key not in h5:
806+
# raise KeyError(f"millisecCounter not found in file {self.files[fid]}")
687807

688-
# Duration is simply last - first millisecond value
689-
dt_ms = ms[-1] - ms[0]
690-
dt_s = dt_ms / 1000.0 # Convert to seconds
808+
# ms = np.asarray(h5[millis_key], dtype=np.float64)
691809

692-
if dt_s < 0:
693-
raise ValueError(
694-
f"Negative elapsed time in file {fid}: {dt_s}s"
695-
)
810+
# if len(ms) == 0:
811+
# raise ValueError(f"Empty millisecCounter in file {self.files[fid]}")
696812

697-
elapsed_per_file.append(dt_s)
813+
# # Duration is simply last - first millisecond value
814+
# dt_ms = ms[-1] - ms[0]
815+
# dt_s = dt_ms / 1000.0 # Convert to seconds
698816

699-
logger.debug(
700-
f"[get_elapsed_time] File {fid}: ms_min={ms[0]}, ms_max={ms[-1]}, "
701-
f"duration={dt_s:.2f}s"
702-
)
817+
# if dt_s < 0:
818+
# raise ValueError(
819+
# f"Negative elapsed time in file {fid}: {dt_s}s"
820+
# )
703821

704-
except KeyError as exc:
705-
filename = (
706-
Path(self.files[fid]).name
707-
if fid < len(self.files)
708-
else f"file_{fid}"
709-
)
710-
raise KeyError(
711-
f"millisecCounter missing in file {filename} (fid={fid}). "
712-
"Ensure millisecCounter is available in the H5 file."
713-
) from exc
714-
715-
if aggregate:
716-
return sum(elapsed_per_file)
717-
718-
return elapsed_per_file
822+
# elapsed_per_file.append(dt_s)
823+
824+
# logger.debug(
825+
# f"[get_elapsed_time] File {fid}: ms_min={ms[0]}, ms_max={ms[-1]}, "
826+
# f"duration={dt_s:.2f}s"
827+
# )
828+
829+
# except KeyError as exc:
830+
# filename = (
831+
# Path(self.files[fid]).name
832+
# if fid < len(self.files)
833+
# else f"file_{fid}"
834+
# )
835+
# raise KeyError(
836+
# f"millisecCounter missing in file {filename} (fid={fid}). "
837+
# "Ensure millisecCounter is available in the H5 file."
838+
# ) from exc
839+
840+
# if aggregate:
841+
# return sum(elapsed_per_file)
842+
843+
# return elapsed_per_file
719844

720845

721846
def read_dataframe(

0 commit comments

Comments
 (0)