From 85eaf7abfefd4d16cb4769da7b443b1509ba1431 Mon Sep 17 00:00:00 2001 From: naglepuff Date: Thu, 19 Mar 2026 12:02:52 -0400 Subject: [PATCH] Modify ingest script to check a local directory --- .../commands/load_public_dataset.py | 42 +++++++++++++++---- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/bats_ai/core/management/commands/load_public_dataset.py b/bats_ai/core/management/commands/load_public_dataset.py index d40b1cdc..c6c6d688 100644 --- a/bats_ai/core/management/commands/load_public_dataset.py +++ b/bats_ai/core/management/commands/load_public_dataset.py @@ -115,6 +115,7 @@ def _ingest_files_from_manifest( limit: int | None, file_key: str = "file_key", tag_keys: list[str] | None = None, + data_dir: Path | None = None, ): if tag_keys is None: tag_keys = [] @@ -131,6 +132,7 @@ def _ingest_files_from_manifest( filename = None try: + local = False s3_key = line[file_key] existing_recording = Recording.objects.filter(name=s3_key).first() if existing_recording: @@ -140,12 +142,22 @@ def _ingest_files_from_manifest( logger.info("Ingesting %s...", s3_key) object_exists = _try_head_s3_object(s3_client, bucket, s3_key) if not object_exists: - logger.warning("Could not HEAD object with key %s. Skipping...", s3_key) - continue - filename = _create_filename(s3_key) - logger.info("Downloading to temporary file %s...", filename) - s3_client.download_file(bucket, s3_key, filename) - logger.info("Creating recording for %s", s3_key) + if not data_dir: + logger.warning("Could not HEAD object with key %s. Skipping...", s3_key) + else: + logger.info( + "Could not HEAD object with key %s. Checking local directory %s", + s3_key, + data_dir, + ) + local = True + if not local: + filename = _create_filename(s3_key) + logger.info("Downloading to temporary file %s...", filename) + s3_client.download_file(bucket, s3_key, filename) + logger.info("Creating recording for %s", s3_key) + else: + filename = s3_key metadata = _get_metadata(filename, line) with open(filename, "rb") as f: recording = Recording.objects.create( @@ -182,7 +194,7 @@ def _ingest_files_from_manifest( ) recording_compute_spectrogram.delay(recording.pk) finally: - if filename: + if not local and filename: # Delete the file (this may run on a machine with limited resources) try: logger.info("Cleaning up by removing temporary file %s...", filename) @@ -192,7 +204,7 @@ def _ingest_files_from_manifest( class Command(BaseCommand): - help = "Create recordings and spectrograms from WAV files in a public s3 bucket" + help = "Ingest recordings from local filesystem and public s3 according to a manifest file." def add_arguments(self, parser): parser.add_argument( @@ -206,6 +218,9 @@ def add_arguments(self, parser): # Assume columns "Key" and "Tags" help="Manifest CSV file with file keys and tags", ) + parser.add_argument( + "--data-dir", type=str, help="The directory where local recordings are located" + ) parser.add_argument( "--owner", type=str, @@ -244,6 +259,16 @@ def handle(self, *args, **options): except ClientError: self.stdout.write(self.style.ERROR(f"Could not access bucket {bucket}")) return + + data_dir = options.get("data-dir") + if data_dir: + data_dir = Path(data_dir) + if not data_dir.exists(): + self.stdout.write( + self.style.ERROR(f"Specified data directory {data_dir} does not exist") + ) + return + manifest = Path(options["manifest"]) if not manifest.exists(): self.stdout.write(self.style.ERROR(f"Could not find manifest file {manifest}")) @@ -277,4 +302,5 @@ def handle(self, *args, **options): limit=limit, file_key=file_key, tag_keys=tag_keys, + data_dir=data_dir, )