Skip to content

Commit 40c861d

Browse files
committed
rebase with new upstream
1 parent 35a7b9b commit 40c861d

3 files changed

Lines changed: 133 additions & 35 deletions

File tree

src/post_processing/utils/filtering_utils.py

Lines changed: 88 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -196,33 +196,62 @@ def get_dataset(df: DataFrame) -> list[str]:
196196
def get_canonical_tz(tz):
197197
"""Return timezone of object as a pytz timezone."""
198198
if isinstance(tz, datetime.timezone):
199-
if tz == datetime.timezone.utc:
199+
if tz == datetime.UTC:
200200
return pytz.utc
201201
offset_minutes = int(tz.utcoffset(None).total_seconds() / 60)
202202
return pytz.FixedOffset(offset_minutes)
203203
if hasattr(tz, "zone") and tz.zone:
204204
return pytz.timezone(tz.zone)
205205
if hasattr(tz, "key"):
206206
return pytz.timezone(tz.key)
207-
else:
208-
msg = f"Unknown timezone: {tz}"
209-
raise TypeError(msg)
207+
msg = f"Unknown timezone: {tz}"
208+
raise TypeError(msg)
210209

211210

212211
def get_timezone(df: DataFrame):
213-
"""Return timezone(s) from DataFrame."""
212+
"""Return timezone(s) from APLOSE DataFrame.
213+
214+
Parameters
215+
----------
216+
df: DataFrame
217+
APLOSE result Dataframe
218+
219+
Returns
220+
-------
221+
tz: pytz timezone | list[pytz timezone]
222+
A single timezone when all rows share it, otherwise a list of the distinct timezones.
223+
224+
"""
214225
timezones = {get_canonical_tz(ts.tzinfo) for ts in df["start_datetime"]}
215226

216227
if len(timezones) == 1:
217228
return next(iter(timezones))
218229
return list(timezones)
219230

220231

232+
def check_timestamp(df: DataFrame, timestamp_audio: list[Timestamp]) -> None:
233+
"""Check if provided timestamp_audio list is correctly formatted.
234+
235+
Parameters
236+
----------
237+
df: DataFrame APLOSE results Dataframe.
238+
timestamp_audio: A list of timestamps. Each timestamp is the start datetime of the
239+
corresponding audio file for each detection in df.
240+
241+
"""
242+
if timestamp_audio is None:
243+
msg = "`timestamp_wav` is empty"
244+
raise ValueError(msg)
245+
if len(timestamp_audio) != len(df):
246+
msg = "`timestamp_wav` is not the same length as `df`"
247+
raise ValueError(msg)
248+
249+
221250
def reshape_timebin(
222251
df: DataFrame,
223-
timestamp_wav: list[Timestamp],
252+
*,
224253
timebin_new: Timedelta | None,
225-
timestamp: list[Timestamp] | None = None,
254+
timestamp_audio: list[Timestamp] | None = None,
226255
) -> DataFrame:
227256
"""Reshape an APLOSE result DataFrame according to a new time bin.
228257
@@ -232,11 +261,9 @@ def reshape_timebin(
232261
An APLOSE result DataFrame.
233262
timebin_new: Timedelta
234263
The size of the new time bin.
235-
timestamp: list[Timestamp]
264+
timestamp_audio: list[Timestamp]
236265
A list of the start datetimes of the audio file corresponding
237266
to each detection in df. Length should be the same as df.
238-
timestamp_wav: list[Timestamp]
239-
A list of the start datetime of each wavfile. Length should be the same as df
240267
241268
Returns
242269
-------
@@ -251,16 +278,20 @@ def reshape_timebin(
251278
if not timebin_new:
252279
return df
253280

281+
check_timestamp(df, timestamp_audio)
282+
254283
annotators = get_annotators(df)
255284
labels = get_labels(df)
256285
max_freq = get_max_freq(df)
257286
dataset = get_dataset(df)
258287

259288
if isinstance(get_timezone(df), list):
260289
df["start_datetime"] = [to_datetime(elem, utc=True)
261-
for elem in df["start_datetime"]]
290+
for elem in df["start_datetime"]
291+
]
262292
df["end_datetime"] = [to_datetime(elem, utc=True)
263-
for elem in df["end_datetime"]]
293+
for elem in df["end_datetime"]
294+
]
264295

265296
results = []
266297
for ant in annotators:
@@ -270,13 +301,13 @@ def reshape_timebin(
270301
if df_1annot_1label.empty:
271302
continue
272303

273-
if timestamp is not None:
304+
if timestamp_audio is not None:
274305
# I do not remember if this is a regular case or not
275306
# might need to be deleted
276-
origin_timebin = timestamp[1] - timestamp[0]
277-
step = int(timebin_new / origin_timebin)
278-
time_vector = timestamp[0::step]
279-
else:
307+
#origin_timebin = timestamp_audio[1] - timestamp_audio[0]
308+
#step = int(timebin_new / origin_timebin)
309+
#time_vector = timestamp_audio[0::step]
310+
#else:
280311
t1 = min(df_1annot_1label["start_datetime"]).floor(timebin_new)
281312
t2 = max(df_1annot_1label["end_datetime"]).ceil(timebin_new)
282313
time_vector = date_range(start=t1, end=t2, freq=timebin_new)
@@ -292,7 +323,7 @@ def reshape_timebin(
292323
bisect.bisect_left(ts_detect_beg, ts) != len(ts_detect_beg)):
293324
idx = bisect.bisect_left(ts_detect_beg, ts)
294325
filename_vector.append(
295-
filenames[idx] if timestamp_wav[idx] <= ts else
326+
filenames[idx] if timestamp_audio[idx] <= ts else
296327
filenames[idx - 1],
297328
)
298329
elif bisect.bisect_left(ts_detect_beg, ts) == len(ts_detect_beg):
@@ -338,9 +369,39 @@ def reshape_timebin(
338369
),
339370
)
340371

341-
return concat(results).sort_values(by=["start_datetime", "end_datetime",
342-
"annotator", "annotation"]).reset_index(drop=True)
372+
return (concat(results).
373+
sort_values(by=["start_datetime", "end_datetime",
374+
"annotator", "annotation"]).reset_index(drop=True)
375+
)
376+
343377

378+
def get_filename_timestamps(df: DataFrame, date_parser: str) -> list[Timestamp]:
379+
"""Get start timestamps of the wav files of each detection contained in df.
380+
381+
Parameters
382+
----------
383+
df: DataFrame
384+
An APLOSE result DataFrame.
385+
date_parser: str
386+
date parser of the wav file
387+
388+
Returns
389+
-------
390+
List of Timestamps corresponding to the wav files' start timestamps
391+
of each detection contained in df.
392+
393+
"""
394+
tz = get_timezone(df)
395+
try:
396+
return [
397+
to_datetime(
398+
ts,
399+
format=date_parser,
400+
).tz_localize(tz) for ts in df["filename"]
401+
]
402+
except ValueError:
403+
msg = """Could not parse timestamps from `df["filename"]`."""
404+
raise ValueError(msg) from None
344405

345406
def ensure_in_list(value: str, candidates: list[str], label: str) -> None:
346407
"""Check for non-valid elements of a list."""
@@ -378,10 +439,14 @@ def load_detections(filters: DetectionFilter) -> DataFrame:
378439
df = filter_by_label(df, label=filters.annotation)
379440
df = filter_by_freq(df, filters.f_min, filters.f_max)
380441
df = filter_by_score(df, filters.score)
381-
df = reshape_timebin(df, filters.timebin_new)
442+
filename_ts = get_filename_timestamps(df, filters.filename_format)
443+
df = reshape_timebin(df,
444+
timebin_new=filters.timebin_new,
445+
timestamp_audio=filename_ts
446+
)
382447

383448
annotators = get_annotators(df)
384-
if len(annotators) > 1 and filters.user_sel in ["union", "intersection"]:
449+
if len(annotators) > 1 and filters.user_sel in {"union", "intersection"}:
385450
df = intersection_or_union(df, user_sel=filters.user_sel)
386451

387452
return df.sort_values(by=["start_datetime", "end_datetime"]).reset_index(drop=True)
@@ -397,7 +462,7 @@ def intersection_or_union(df: DataFrame, user_sel: str) -> DataFrame:
397462
if user_sel == "all":
398463
return df
399464

400-
if user_sel not in ("intersection", "union"):
465+
if user_sel not in {"intersection", "union"}:
401466
msg = "'user_sel' must be either 'intersection' or 'union'"
402467
raise ValueError(msg)
403468

tests/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import io
22
import os
33
from pathlib import Path
4+
from sqlite3.dbapi2 import Timestamp
45

56
import numpy as np
67
import pytest
78
import soundfile as sf
89
import yaml
910
from osekit.utils.timestamp_utils import strftime_osmose_format
10-
from pandas import DataFrame, read_csv
11+
from pandas import DataFrame, read_csv, to_datetime
12+
13+
from post_processing.utils.filtering_utils import get_timezone, read_dataframe
1114

1215
SAMPLE = """dataset,filename,start_time,end_time,start_frequency,end_frequency,annotation,annotator,start_datetime,end_datetime,type,score
1316
sample_dataset,2025_01_25_06_20_00,0.0,10.0,0.0,72000.0,lbl2,ann2,2025-01-25T06:20:00.000+00:00,2025-01-25T06:20:10.000+00:00,WEAK,0.11

tests/test_filtering_utils.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def test_get_timezone_single(sample_df: DataFrame) -> None:
298298
def test_get_timezone_several(sample_df: DataFrame) -> None:
299299
new_row = {
300300
"dataset": "dataset",
301-
"filename": "filename",
301+
"filename": "2025_01_26_06_20_00",
302302
"start_time": 0,
303303
"end_time": 2,
304304
"start_frequency": 100,
@@ -384,7 +384,7 @@ def test_no_timebin_returns_original(sample_df: DataFrame) -> None:
384384
def test_no_timebin_several_tz(sample_df: DataFrame) -> None:
385385
new_row = {
386386
"dataset": "dataset",
387-
"filename": "filename",
387+
"filename": "2025_01_26_06_20_00",
388388
"start_time": 0,
389389
"end_time": 2,
390390
"start_frequency": 100,
@@ -400,13 +400,24 @@ def test_no_timebin_several_tz(sample_df: DataFrame) -> None:
400400
[sample_df, DataFrame([new_row])],
401401
ignore_index=False
402402
)
403-
404-
df_out = reshape_timebin(sample_df, timebin_new=None)
403+
tz = get_timezone(sample_df)
404+
timestamp_wav = to_datetime(sample_df["filename"],
405+
format="%Y_%m_%d_%H_%M_%S").dt.tz_localize(pytz.UTC)
406+
df_out = reshape_timebin(sample_df, timestamp_audio=timestamp_wav, timebin_new=None)
405407
assert df_out.equals(sample_df)
406408

407409

408410
def test_no_timebin_original_timebin(sample_df: DataFrame) -> None:
409-
df_out = reshape_timebin(sample_df, timebin_new=Timedelta("1min"))
411+
tz = get_timezone(sample_df)
412+
timestamp_wav = to_datetime(
413+
sample_df["filename"],
414+
format="%Y_%m_%d_%H_%M_%S"
415+
).dt.tz_localize(tz)
416+
df_out = reshape_timebin(
417+
sample_df,
418+
timestamp_audio=timestamp_wav,
419+
timebin_new=Timedelta("1min"),
420+
)
410421
expected = DataFrame(
411422
{
412423
"dataset": ["sample_dataset"] * 18,
@@ -488,7 +499,16 @@ def test_no_timebin_original_timebin(sample_df: DataFrame) -> None:
488499

489500

490501
def test_simple_reshape_hourly(sample_df: DataFrame) -> None:
491-
df_out = reshape_timebin(sample_df, timebin_new=Timedelta(hours=1))
502+
tz = get_timezone(sample_df)
503+
timestamp_wav = to_datetime(
504+
sample_df["filename"],
505+
format="%Y_%m_%d_%H_%M_%S"
506+
).dt.tz_localize(tz)
507+
df_out = reshape_timebin(
508+
sample_df,
509+
timestamp_audio=timestamp_wav,
510+
timebin_new=Timedelta(hours=1),
511+
)
492512
assert not df_out.empty
493513
assert all(df_out["end_time"] == 3600.0)
494514
assert df_out["end_frequency"].max() == sample_df["end_frequency"].max()
@@ -497,7 +517,12 @@ def test_simple_reshape_hourly(sample_df: DataFrame) -> None:
497517

498518

499519
def test_reshape_daily_multiple_bins(sample_df: DataFrame) -> None:
500-
df_out = reshape_timebin(sample_df, timebin_new=Timedelta(days=1))
520+
tz = get_timezone(sample_df)
521+
timestamp_wav = to_datetime(
522+
sample_df["filename"],
523+
format="%Y_%m_%d_%H_%M_%S"
524+
).dt.tz_localize(tz)
525+
df_out = reshape_timebin(sample_df, timestamp_audio=timestamp_wav, timebin_new=Timedelta(days=1))
501526
assert not df_out.empty
502527
assert all(df_out["end_time"] == 86400.0)
503528
assert df_out["start_datetime"].min() >= sample_df["start_datetime"].min().floor("D")
@@ -508,11 +533,13 @@ def test_with_manual_timestamps_vector(sample_df: DataFrame) -> None:
508533
t0 = sample_df["start_datetime"].min().floor("30min")
509534
t1 = sample_df["end_datetime"].max().ceil("30min")
510535
ts_vec = list(date_range(t0, t1, freq="30min"))
511-
536+
tz = get_timezone(sample_df)
537+
timestamp_wav = to_datetime(sample_df["filename"],
538+
format="%Y_%m_%d_%H_%M_%S").dt.tz_localize(tz)
512539
df_out = reshape_timebin(
513540
sample_df,
514-
timebin_new=Timedelta(hours=1),
515-
timestamp=ts_vec,
541+
timestamp_audio=timestamp_wav,
542+
timebin_new=Timedelta(hours=1)
516543
)
517544

518545
assert not df_out.empty
@@ -521,8 +548,11 @@ def test_with_manual_timestamps_vector(sample_df: DataFrame) -> None:
521548

522549

523550
def test_empty_result_when_no_matching(sample_df: DataFrame) -> None:
551+
tz = get_timezone(sample_df)
552+
timestamp_wav = to_datetime(sample_df["filename"],
553+
format="%Y_%m_%d_%H_%M_%S").dt.tz_localize(tz)
524554
with pytest.raises(ValueError, match="DataFrame is empty"):
525-
reshape_timebin(DataFrame(), Timedelta(hours=1))
555+
reshape_timebin(DataFrame(), timestamp_audio=timestamp_wav, timebin_new=Timedelta(hours=1))
526556

527557

528558
# %% ensure_no_invalid

0 commit comments

Comments
 (0)