-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvlm.py
More file actions
142 lines (113 loc) · 7.06 KB
/
vlm.py
File metadata and controls
142 lines (113 loc) · 7.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# -*- coding: utf-8 -*-
"""Untitled15.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1NIELBD81VWWvVXn8ZuzS6zlFUIteRdff
"""
import os  # For interacting with the operating system, like file paths
import random  # For selecting random images
import subprocess  # For running pip through the current interpreter
import sys  # For locating the current Python executable
import zipfile  # Essential for handling .zip files

import matplotlib.image as mpimg  # For reading images as NumPy arrays
import matplotlib.pyplot as plt  # For displaying images
import torch  # For detecting whether a CUDA GPU is available
from PIL import Image  # For opening and manipulating images
from transformers import BlipProcessor, BlipForConditionalGeneration  # Hugging Face libraries for the BLIP model
# --- 0. Colab Setup (Run these cells sequentially in Colab) ---
# Cell 1: Install/Upgrade necessary libraries
# The notebook originally used IPython "!pip ..." shell escapes, which are a
# SyntaxError in a plain .py file. Running pip through the current interpreter
# works both inside Colab/Jupyter and from the command line, and installs into
# the same environment that executes this script.
# NOTE: the imports at the top of this file run before these installs, so an
# upgraded package only takes effect after the runtime is restarted (see Cell 2).


def _pip_install(*packages, upgrade=False):
    """Best-effort ``pip install`` of *packages* into this interpreter's environment.

    Args:
        *packages: Package names passed straight to ``pip install``.
        upgrade: When True, pass ``--upgrade`` so already-installed packages
            are upgraded to the latest release.

    A failed install prints a warning instead of raising. This mirrors the
    notebook's ``!pip`` lines, which did not stop execution on failure.
    """
    command = [sys.executable, "-m", "pip", "install", *packages]
    if upgrade:
        command.append("--upgrade")
    result = subprocess.run(command, check=False)
    if result.returncode != 0:
        print(f"Warning: pip exited with status {result.returncode} for: {' '.join(packages)}")


print("Installing/Upgrading required libraries...")
_pip_install("transformers", upgrade=True)  # Upgrade the transformers library
_pip_install("torch", "torchvision", "torchaudio", upgrade=True)  # Upgrade PyTorch and its related libraries
_pip_install("matplotlib")  # Install matplotlib for plotting and image display
print("Libraries installed.")

# Cell 2: Set Runtime to GPU (VERY IMPORTANT!)
# This is a guiding message. You must perform this action manually via the Colab menu:
# Go to 'Runtime' -> 'Change runtime type' -> Select 'GPU' for 'Hardware accelerator' -> Click 'Save'.
# Colab will prompt you to restart the runtime. Confirm and restart.
# After restarting, run all cells again from the beginning.
print("\nEnsure your Colab runtime is set to GPU:")
print("Go to 'Runtime' -> 'Change runtime type' -> Select 'GPU' for 'Hardware accelerator' -> Click 'Save'.")
print("Then restart the runtime (Runtime -> Restart runtime) and run all cells again.")
# Cell 3: Unzip your dataset
# Name of the uploaded dataset archive; it must match the uploaded file exactly.
dataset_zip_name = 'Flickr8k_2k.zip'  # <--- VERIFY THIS NAME!
# Directory the archive is expected in, and also where it is extracted to.
extracted_data_base_path = '/content/'
# Folder inside the archive that holds the image files. Note the archive's
# spelling: 'Flicker8k_2kDataset' (with an extra "e"), not 'Flickr8k'.
images_dir_in_colab = os.path.join(extracted_data_base_path, 'Flicker8k_2kDataset')  # <--- IMPORTANT: note the 'Flicker' spelling
zip_file_full_path = os.path.join(extracted_data_base_path, dataset_zip_name)

if os.path.isfile(zip_file_full_path):
    print(f"\nUnzipping {dataset_zip_name} to {extracted_data_base_path}...")
    # The context manager closes the archive even if extraction fails.
    with zipfile.ZipFile(zip_file_full_path, 'r') as zip_ref:
        zip_ref.extractall(extracted_data_base_path)
    print("Unzipping complete.")
elif os.path.isdir(images_dir_in_colab):
    # No archive, but the image folder is already there (e.g. extracted by an
    # earlier run or uploaded directly), so there is nothing to unzip.
    print(f"\nZIP file '{dataset_zip_name}' not found; using the existing folder '{images_dir_in_colab}'.")
else:
    print(f"Error: ZIP file '{dataset_zip_name}' not found at '{extracted_data_base_path}'.")
    print("Please ensure you have correctly uploaded your .zip file to Colab session storage and its name is exact.")
    # Stop with a non-zero status. The bare exit() used previously is an
    # interactive-only builtin and reported success (status 0).
    raise SystemExit(1)

# Verify that the image directory exists (after unzipping, if that happened).
if not os.path.isdir(images_dir_in_colab):
    print(f"Error: Expected image directory '{images_dir_in_colab}' not found after unzipping.")
    print("Please check the internal structure of your .zip file again.")
    print("It should contain a folder named 'Flicker8k_2kDataset' with images inside.")
    # For debugging, inspect what was extracted, e.g. print(os.listdir(extracted_data_base_path)).
    raise SystemExit(1)
print(f"Images will be loaded from: {images_dir_in_colab}")
# --- 1. Set Dataset Paths for Colab (pointing to the extracted location) ---
# This variable now points to the actual directory containing your images.
images_dir = images_dir_in_colab

# --- 2. Load the VLM Model (BLIP) ---
# Hugging Face checkpoint shared by the processor and the model, so the two
# can never be loaded from mismatched checkpoints.
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"

print("\nLoading BLIP model and processor...")
# BlipProcessor: used for image preprocessing and text tokenization/decoding.
processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)

# Prefer the GPU for fast inference. If the runtime has no GPU attached, fall
# back to the CPU instead of crashing on .to("cuda"); captions are just slower.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("Warning: no CUDA GPU detected; running BLIP on the CPU (this will be slow).")

# BlipForConditionalGeneration: the Vision-Language Model used for caption generation.
model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT).to(device)
print("Model loaded successfully.")
# --- 3. Image Caption Generation Function ---
def generate_caption(image_path, **generate_kwargs):
    """Generate a text caption for one image using the module-level BLIP model.

    Args:
        image_path: Path to an image file that PIL can open.
        **generate_kwargs: Optional keyword arguments forwarded unchanged to
            ``model.generate`` (for example ``max_new_tokens=30`` or
            ``num_beams=3``). With none given, the call is the same as the
            plain ``model.generate(**inputs)``.

    Returns:
        The decoded caption string. This function is best-effort: if anything
        fails, it does not raise. Instead it returns the string
        ``"Error processing image <image_path>: <exception>"``.
    """
    try:
        # The with-block closes the file handle deterministically. convert("RGB")
        # returns a new, fully loaded image that is still usable after the close.
        # It also normalizes grayscale/RGBA/palette inputs to the 3 channels the
        # model expects.
        with Image.open(image_path) as img:
            raw_image = img.convert("RGB")
        # Put the input tensors on whatever device the model lives on (GPU or
        # CPU) rather than assuming "cuda".
        inputs = processor(raw_image, return_tensors="pt").to(model.device)
        outputs = model.generate(**inputs, **generate_kwargs)
        # outputs[0] holds the token IDs for the single image in this batch.
        # skip_special_tokens=True removes special tokens like [CLS], [SEP].
        caption = processor.decode(outputs[0], skip_special_tokens=True)
        return caption
    except Exception as e:  # Best-effort: report the failure as the caption text.
        return f"Error processing image {image_path}: {e}"
print("\nGenerating captions for sample images from the dataset...")
# --- 4. Select and Process Sample Images ---
# File-name suffixes (matched case-insensitively) treated as images.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
# Maximum number of random images to caption and display.
NUM_SAMPLES = 5

# Regular files directly inside images_dir with an image extension. A
# sub-directory that merely has an image-like name is skipped.
image_files = [
    f for f in os.listdir(images_dir)
    if f.lower().endswith(IMAGE_EXTENSIONS) and os.path.isfile(os.path.join(images_dir, f))
]
if not image_files:
    print("No image files found in the dataset directory. Please check the folder structure.")
    print(f"Ensure image files are directly inside '{images_dir}'.")
    # Non-zero status: finding no images is an error, not a successful run.
    raise SystemExit(1)

# Never ask random.sample for more items than exist.
num_samples_to_show = min(NUM_SAMPLES, len(image_files))
selected_sample_images = random.sample(image_files, num_samples_to_show)
for img_file in selected_sample_images:
    image_path = os.path.join(images_dir, img_file)
    caption = generate_caption(image_path)
    print(f"\nImage: {img_file}")
    print(f"Generated Caption: {caption}")
    # Display the image with its generated caption as the title.
    plt.imshow(mpimg.imread(image_path))
    plt.title(f"Generated: {caption}", fontsize=10)
    plt.axis('off')  # Hide axes for a cleaner image display.
    plt.show()
print("\n--- VLM test completed successfully! ---")