Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions bats_ai/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def global_auth(request):
api = NinjaAPI(auth=global_auth)

api.add_router("/recording/", views.recording_router)
api.add_router(
"/recording-locations/",
views.recording_locations_router,
)
api.add_router("/species/", views.species_router)
api.add_router("/grts/", views.grts_cells_router)
api.add_router("/guano/", views.guano_metadata_router)
Expand Down
36 changes: 32 additions & 4 deletions bats_ai/core/management/commands/copy_recordings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from bats_ai.core.models import (
CompressedSpectrogram,
GRTSCells,
Recording,
RecordingAnnotation,
RecordingTag,
Expand All @@ -31,7 +32,7 @@

logger = logging.getLogger(__name__)

DEFAULT_TAGS = ["test", "foo", "bar"]
DEFAULT_TAGS = ["test", "data", "sample", "foo", "bar"]


def _link_spectrogram_and_annotations(source_recording, new_recording):
Expand Down Expand Up @@ -141,6 +142,14 @@ def add_arguments(self, parser):
help="Username of the owner for the new recordings\
(default: use source recording owner)",
)
parser.add_argument(
"--random-grts-cell-id",
action="store_true",
help=(
"Assign a random valid GRTS Cell ID to each copied recording. "
"When enabled, recording_location and grts_cell are cleared."
),
)

def handle(self, *args, **options):
count = options["count"]
Expand All @@ -149,6 +158,7 @@ def handle(self, *args, **options):
if not tag_texts:
tag_texts = DEFAULT_TAGS
owner_username = options.get("owner")
randomize_grts_cell_id = options.get("random_grts_cell_id", False)

if count < 1:
raise CommandError("--count must be at least 1.")
Expand All @@ -157,6 +167,16 @@ def handle(self, *args, **options):
if not recordings:
raise CommandError("No existing recordings found. Create or import some first.")

valid_grts_cell_ids: list[int] = []
if randomize_grts_cell_id:
valid_grts_cell_ids = list(
GRTSCells.objects.exclude(grts_cell_id__isnull=True)
.values_list("grts_cell_id", flat=True)
.distinct()
)
if not valid_grts_cell_ids:
raise CommandError("No valid GRTS Cell IDs were found in GRTSCells.")

owner = None
if owner_username:
try:
Expand Down Expand Up @@ -190,6 +210,14 @@ def handle(self, *args, **options):
ext = "." + source.audio_file.name.rsplit(".", 1)[-1]
save_name = new_name + ext if ext else new_name

grts_cell_id = source.grts_cell_id
grts_cell = source.grts_cell
recording_location = source.recording_location
if randomize_grts_cell_id:
grts_cell_id = random.choice(valid_grts_cell_ids) # noqa: S311
grts_cell = None
recording_location = None

new_recording = Recording(
name=new_name,
owner=owner,
Expand All @@ -198,9 +226,9 @@ def handle(self, *args, **options):
recorded_time=source.recorded_time,
equipment=source.equipment,
comments=source.comments,
recording_location=source.recording_location,
grts_cell_id=source.grts_cell_id,
grts_cell=source.grts_cell,
recording_location=recording_location,
grts_cell_id=grts_cell_id,
grts_cell=grts_cell,
public=source.public,
software=source.software,
detector=source.detector,
Expand Down
20 changes: 19 additions & 1 deletion bats_ai/core/management/commands/importRecordings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,22 @@
import contextlib
import logging
from pathlib import Path
import random

from django.contrib.auth.models import User
from django.contrib.gis.geos import Point
from django.core.files import File
from django.core.management.base import BaseCommand
from django.utils import timezone

from bats_ai.core.models import Recording
from bats_ai.core.models import Recording, RecordingTag
from bats_ai.core.tasks.tasks import recording_compute_spectrogram
from bats_ai.core.utils.guano_utils import extract_guano_metadata

logger = logging.getLogger(__name__)

_RANDOM_TAG_POOL = ("foo", "bar", "test", "sample", "data")


class Command(BaseCommand):
help = "Import WAV files from a directory, extract GUANO metadata, and create recordings"
Expand Down Expand Up @@ -43,12 +46,21 @@ def add_arguments(self, parser):
type=int,
help="Limit the number of WAV files to import (useful for testing)",
)
parser.add_argument(
"--assign-random-tags",
action="store_true",
help=(
"Assign each imported recording one tag chosen at random from: "
+ ", ".join(_RANDOM_TAG_POOL)
),
)

def handle(self, *args, **options):
directory_path = Path(options["directory"])
owner_username = options.get("owner")
is_public = options.get("public", False)
limit = options.get("limit")
assign_random_tags = options.get("assign_random_tags", False)

# Validate directory
if not directory_path.exists():
Expand Down Expand Up @@ -176,6 +188,12 @@ def handle(self, *args, **options):

self.stdout.write(self.style.SUCCESS(f" Created recording ID: {recording.pk}"))

if assign_random_tags:
tag_text = random.choice(_RANDOM_TAG_POOL) # noqa: S311
tag, _ = RecordingTag.objects.get_or_create(user=owner, text=tag_text)
recording.tags.add(tag)
self.stdout.write(self.style.SUCCESS(f" Assigned random tag: {tag_text}"))

# Generate spectrogram synchronously
self.stdout.write(" Generating spectrogram...")
try:
Expand Down
14 changes: 11 additions & 3 deletions bats_ai/core/management/commands/loadGRTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ class Command(BaseCommand):

def add_arguments(self, parser):
parser.add_argument(
"--batch-size", type=int, default=5000, help="Batch size for database insertion"
"--batch-size",
type=int,
default=5000,
help="Batch size for database insertion",
)

def _download_file(self, url: str, zip_path: Path) -> None:
Expand All @@ -78,7 +81,8 @@ def handle(self, *args, **options):
self._download_file(url, zip_path)
except requests.RequestException as e:
logger.warning(
"Failed to download from primary URL: %s. Attempting backup URL...", e
"Failed to download from primary URL: %s. Attempting backup URL...",
e,
)
if backup_url is None:
logger.warning("No backup URL provided, skipping this shapefile.")
Expand Down Expand Up @@ -117,7 +121,9 @@ def handle(self, *args, **options):
count_new = 0

for idx, row in tqdm(
gdf.iterrows(), total=len(gdf), desc=f"Importing {sample_frame_id}"
gdf.iterrows(),
total=len(gdf),
desc=f"Importing {sample_frame_id}",
):
# Hard fail if GRTS_ID is missing
if "GRTS_ID" not in row or row["GRTS_ID"] is None:
Expand All @@ -131,6 +137,7 @@ def handle(self, *args, **options):
continue

geom_4326 = row.geometry.wkt
centroid_4326 = row.geometry.centroid.wkt
if gdf.crs and gdf.crs.to_epsg() != 4326:
grts_geom = row.geometry.to_wkt()
else:
Expand All @@ -142,6 +149,7 @@ def handle(self, *args, **options):
sample_frame_id=sample_frame_id,
grts_geom=grts_geom,
geom_4326=geom_4326,
centroid_4326=centroid_4326,
)
records_to_create.append(cell)
count_new += 1
Expand Down
44 changes: 44 additions & 0 deletions bats_ai/core/migrations/0035_add_grtscells_centroid_4326.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from django.contrib.gis.db import models
from django.db import migrations


def backfill_centroids(apps, schema_editor) -> None:
    """Populate ``centroid_4326`` for existing ``GRTSCells`` rows.

    Only rows whose ``centroid_4326`` is NULL and whose ``geom_4326`` is not
    NULL are touched. Rows are streamed and written back in batches so the
    whole table is never held in memory at once.

    Args:
        apps: Historical app registry supplied by the migration framework;
            used to obtain the ``GRTSCells`` model as of this migration.
        schema_editor: Schema editor supplied by the migration framework;
            its connection alias selects the database to operate on.
    """
    GRTSCells = apps.get_model("core", "GRTSCells")

    # Model managers do not follow the migration's connection on their own.
    # Pin every read and write to the database that is actually being
    # migrated so multi-database setups backfill the right one.
    db_alias = schema_editor.connection.alias
    cells_on_db = GRTSCells.objects.using(db_alias)

    batch_size = 1000
    pending = []

    missing_centroid = (
        cells_on_db.filter(centroid_4326__isnull=True)
        .exclude(geom_4326__isnull=True)
        .only("id", "grts_cell_id", "sample_frame_id", "geom_4326")
    )

    for cell in missing_centroid.iterator(chunk_size=batch_size):
        # `centroid` returns a Point geometry in the same SRID.
        cell.centroid_4326 = cell.geom_4326.centroid
        pending.append(cell)

        if len(pending) >= batch_size:
            cells_on_db.bulk_update(pending, ["centroid_4326"], batch_size=batch_size)
            pending.clear()

    # Flush the final, partially filled batch.
    if pending:
        cells_on_db.bulk_update(pending, ["centroid_4326"], batch_size=batch_size)


class Migration(migrations.Migration):
    """Add the nullable ``centroid_4326`` point column to ``GRTSCells`` and backfill it."""

    dependencies = [("core", "0034_alter_spectrogramimage_type")]

    operations = [
        # Schema step first: the column has to exist before the data step runs.
        migrations.AddField(
            model_name="grtscells",
            name="centroid_4326",
            field=models.PointField(blank=True, null=True, srid=4326),
        ),
        # Data step: fill the new column for existing rows. The reverse
        # direction is an explicit no-op.
        migrations.RunPython(
            code=backfill_centroids,
            reverse_code=migrations.RunPython.noop,
        ),
    ]
2 changes: 2 additions & 0 deletions bats_ai/core/models/grts_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class GRTSCells(models.Model):
grts_geom = models.GeometryField(blank=True, null=True)
# continue defining all fields similarly
geom_4326 = models.GeometryField()
# Precomputed centroid of `geom_4326` for faster lookup of cell centers.
centroid_4326 = models.PointField(srid=4326, blank=True, null=True)

@property
def sample_frame_mapping(self):
Expand Down
2 changes: 2 additions & 0 deletions bats_ai/core/views/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .processing_tasks import router as processing_task_router
from .recording import router as recording_router
from .recording_annotation import router as recording_annotation_router
from .recording_location import router as recording_locations_router
from .recording_tag import router as recording_tag_router
from .species import router as species_router
from .vetting_details import router as vetting_router
Expand All @@ -18,6 +19,7 @@
"guano_metadata_router",
"processing_task_router",
"recording_annotation_router",
"recording_locations_router",
"recording_router",
"recording_tag_router",
"species_router",
Expand Down
Loading