|
1 | 1 | #!/usr/bin/env python3 |
2 | 2 | """ |
3 | | -Script to download a sample TFLite model for object detection. |
| 3 | +Script to download a default TFLite model for object detection. |
| 4 | +
|
| 5 | +This is mainly intended for local development; Docker builds can also run it |
| 6 | +to bake the model into the image. |
4 | 7 | """ |
5 | 8 |
|
6 | | -import os |
| 9 | +from __future__ import annotations |
| 10 | + |
7 | 11 | import sys |
8 | | -import urllib.request |
9 | | -import zipfile |
10 | | -import shutil |
11 | 12 | from pathlib import Path |
12 | 13 |
|
13 | | -# Add parent directory to path to import config |
| 14 | +# Add parent directory to path to import config + utils |
14 | 15 | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
15 | | -from config import settings |
16 | | - |
17 | | - |
18 | | -def download_file(url, destination): |
19 | | - """Download a file from a URL to a destination.""" |
20 | | - print(f"Downloading {url} to {destination}...") |
21 | | - |
22 | | - # Create directory if it doesn't exist |
23 | | - os.makedirs(os.path.dirname(destination), exist_ok=True) |
24 | | - |
25 | | - # Download the file |
26 | | - urllib.request.urlretrieve(url, destination) |
27 | | - |
28 | | - print(f"Downloaded {destination}") |
29 | 16 |
|
30 | | - |
31 | | -def extract_zip(zip_path, extract_dir): |
32 | | - """Extract a zip file to a directory.""" |
33 | | - print(f"Extracting {zip_path} to {extract_dir}...") |
34 | | - |
35 | | - # Create directory if it doesn't exist |
36 | | - os.makedirs(extract_dir, exist_ok=True) |
37 | | - |
38 | | - # Extract the zip file |
39 | | - with zipfile.ZipFile(zip_path, 'r') as zip_ref: |
40 | | - zip_ref.extractall(extract_dir) |
41 | | - |
42 | | - print(f"Extracted {zip_path}") |
43 | | - |
44 | | - |
45 | | -def download_ssd_mobilenet(): |
46 | | - """Download SSD MobileNet v1 model.""" |
47 | | - model_url = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip" |
48 | | - zip_path = "/tmp/ssd_mobilenet.zip" |
49 | | - extract_dir = "/tmp/ssd_mobilenet" |
50 | | - model_dir = os.path.dirname(settings.TFLITE_MODEL_PATH) |
51 | | - |
52 | | - # Download the model |
53 | | - download_file(model_url, zip_path) |
54 | | - |
55 | | - # Extract the zip file |
56 | | - extract_zip(zip_path, extract_dir) |
57 | | - |
58 | | - # Create model directory if it doesn't exist |
59 | | - os.makedirs(model_dir, exist_ok=True) |
60 | | - |
61 | | - # Copy the model file |
62 | | - shutil.copy( |
63 | | - os.path.join(extract_dir, "detect.tflite"), |
64 | | - settings.TFLITE_MODEL_PATH |
65 | | - ) |
66 | | - |
67 | | - # Create a labelmap file |
68 | | - with open(settings.TFLITE_LABELS_PATH, 'w') as f: |
69 | | - # COCO dataset labels |
70 | | - labels = [ |
71 | | - "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", |
72 | | - "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", |
73 | | - "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", |
74 | | - "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", |
75 | | - "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", |
76 | | - "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", |
77 | | - "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", |
78 | | - "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", |
79 | | - "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", |
80 | | - "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush" |
81 | | - ] |
82 | | - for label in labels: |
83 | | - f.write(f"{label}\n") |
84 | | - |
85 | | - # Clean up |
86 | | - os.remove(zip_path) |
87 | | - shutil.rmtree(extract_dir) |
88 | | - |
89 | | - print(f"Model saved to {settings.TFLITE_MODEL_PATH}") |
90 | | - print(f"Labels saved to {settings.TFLITE_LABELS_PATH}") |
| 17 | +from config import settings |
| 18 | +from utils.model_download import ensure_tflite_ssd_mobilenet_v1, ModelDownloadError |
91 | 19 |
|
92 | 20 |
|
93 | 21 | def main(): |
94 | 22 | """Main function.""" |
95 | | - print("Downloading TFLite model for object detection...") |
96 | | - |
97 | | - # Download SSD MobileNet model |
98 | | - download_ssd_mobilenet() |
99 | | - |
100 | | - print("Done!") |
| 23 | + print("Ensuring default TFLite model for object detection...") |
| 24 | + |
| 25 | + try: |
| 26 | + result = ensure_tflite_ssd_mobilenet_v1( |
| 27 | + model_path=settings.TFLITE_MODEL_PATH, |
| 28 | + labels_path=settings.TFLITE_LABELS_PATH, |
| 29 | + force=False, |
| 30 | + ) |
| 31 | + except ModelDownloadError as e: |
| 32 | + print(f"ERROR: {e}") |
| 33 | + raise SystemExit(1) from e |
| 34 | + |
| 35 | + print(result.message) |
| 36 | + print("Done.") |
101 | 37 |
|
102 | 38 |
|
103 | 39 | if __name__ == "__main__": |
|
0 commit comments