Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions bats_ai/core/management/commands/load_public_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def _ingest_files_from_manifest(
limit: int | None,
file_key: str = "file_key",
tag_keys: list[str] | None = None,
data_dir: Path | None = None,
):
if tag_keys is None:
tag_keys = []
Expand All @@ -131,6 +132,7 @@ def _ingest_files_from_manifest(
filename = None

try:
local = False
s3_key = line[file_key]
existing_recording = Recording.objects.filter(name=s3_key).first()
if existing_recording:
Expand All @@ -140,12 +142,22 @@ def _ingest_files_from_manifest(
logger.info("Ingesting %s...", s3_key)
object_exists = _try_head_s3_object(s3_client, bucket, s3_key)
if not object_exists:
logger.warning("Could not HEAD object with key %s. Skipping...", s3_key)
continue
filename = _create_filename(s3_key)
logger.info("Downloading to temporary file %s...", filename)
s3_client.download_file(bucket, s3_key, filename)
logger.info("Creating recording for %s", s3_key)
if not data_dir:
logger.warning("Could not HEAD object with key %s. Skipping...", s3_key)
else:
logger.info(
"Could not HEAD object with key %s. Checking local directory %s",
s3_key,
data_dir,
)
local = True
if not local:
filename = _create_filename(s3_key)
logger.info("Downloading to temporary file %s...", filename)
s3_client.download_file(bucket, s3_key, filename)
logger.info("Creating recording for %s", s3_key)
else:
filename = s3_key
metadata = _get_metadata(filename, line)
with open(filename, "rb") as f:
recording = Recording.objects.create(
Expand Down Expand Up @@ -182,7 +194,7 @@ def _ingest_files_from_manifest(
)
recording_compute_spectrogram.delay(recording.pk)
finally:
if filename:
if not local and filename:
# Delete the file (this may run on a machine with limited resources)
try:
logger.info("Cleaning up by removing temporary file %s...", filename)
Expand All @@ -192,7 +204,7 @@ def _ingest_files_from_manifest(


class Command(BaseCommand):
help = "Create recordings and spectrograms from WAV files in a public s3 bucket"
help = "Ingest recordings from local filesystem and public s3 according to a manifest file."

def add_arguments(self, parser):
parser.add_argument(
Expand All @@ -206,6 +218,9 @@ def add_arguments(self, parser):
# Assume columns "Key" and "Tags"
help="Manifest CSV file with file keys and tags",
)
parser.add_argument(
"--data-dir", type=str, help="The directory where local recordings are located"
)
parser.add_argument(
"--owner",
type=str,
Expand Down Expand Up @@ -244,6 +259,16 @@ def handle(self, *args, **options):
except ClientError:
self.stdout.write(self.style.ERROR(f"Could not access bucket {bucket}"))
return

data_dir = options.get("data-dir")
if data_dir:
data_dir = Path(data_dir)
if not data_dir.exists():
self.stdout.write(
self.style.ERROR(f"Specified data directory {data_dir} does not exist")
)
return

manifest = Path(options["manifest"])
if not manifest.exists():
self.stdout.write(self.style.ERROR(f"Could not find manifest file {manifest}"))
Expand Down Expand Up @@ -277,4 +302,5 @@ def handle(self, *args, **options):
limit=limit,
file_key=file_key,
tag_keys=tag_keys,
data_dir=data_dir,
)