# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
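"""Demo for V-JEPA 2: loads the video encoder twice (HuggingFace hub weights
and local PyTorch weights), verifies that the two implementations produce
matching patch features on a sample video, then classifies the video with an
attentive probe trained on Something-Something V2 (174 classes).

Requires a CUDA GPU and the repo's `src` package on the import path; set the
YOUR_MODEL_PATH / YOUR_ATTENTIVE_PROBE_PATH placeholders below to local
checkpoints. Run from the repo root with: `python -m notebooks.vjepa2_demo`.
"""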
import json
import os
import subprocess

import numpy as np
import torch
import torch.nn.functional as F
from decord import VideoReader
from transformers import AutoModel, AutoVideoProcessor

import src.datasets.utils.video.transforms as video_transforms
import src.datasets.utils.video.volume_transforms as volume_transforms
from src.models.attentive_pooler import AttentiveClassifier
from src.models.vision_transformer import vit_giant_xformers_rope

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def load_pretrained_vjepa_pt_weights(model, pretrained_weights):
    # Load weights of the VJEPA2 encoder
    # The PyTorch state_dict is already preprocessed to have the right key names
    pretrained_dict = torch.load(pretrained_weights, weights_only=True, map_location="cpu")["encoder"]
    pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}
    pretrained_dict = {k.replace("backbone.", ""): v for k, v in pretrained_dict.items()}
    msg = model.load_state_dict(pretrained_dict, strict=False)
    print("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))


def load_pretrained_vjepa_classifier_weights(model, pretrained_weights):
    # Load weights of the VJEPA2 classifier
    # The PyTorch state_dict is already preprocessed to have the right key names
    pretrained_dict = torch.load(pretrained_weights, weights_only=True, map_location="cpu")["classifiers"][0]
    pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}
    msg = model.load_state_dict(pretrained_dict, strict=False)
    print("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))


def build_pt_video_transform(img_size):
    short_side_size = int(256.0 / 224 * img_size)
    # The eval transform has no random cropping or flipping
    eval_transform = video_transforms.Compose(
        [
            video_transforms.Resize(short_side_size, interpolation="bilinear"),
            video_transforms.CenterCrop(size=(img_size, img_size)),
            volume_transforms.ClipToTensor(),
            video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ]
    )
    return eval_transform
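
# Shape note (an assumption from the transform definitions above, not asserted
# in code): ClipToTensor is expected to turn a T x C x H x W uint8 clip into a
# [0, 1]-scaled float tensor of shape C x T x img_size x img_size, which is
# then normalized channel-wise; the batch dimension is added later via
# unsqueeze(0).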


def get_video():
    vr = VideoReader("sample_video.mp4")
    # Sample every other frame from the first 128, i.e. 64 frames, matching the
    # model's 64-frame clips; you can define a more complex sampling strategy here
    frame_idx = np.arange(0, 128, 2)
    video = vr.get_batch(frame_idx).asnumpy()
    return video


def forward_vjepa_video(model_hf, model_pt, hf_transform, pt_transform):
    # Run a sample inference with VJEPA
    with torch.inference_mode():
        # Read and pre-process the video
        video = get_video()  # T x H x W x C
        video = torch.from_numpy(video).permute(0, 3, 1, 2)  # T x C x H x W
        x_pt = pt_transform(video).cuda().unsqueeze(0)
        x_hf = hf_transform(video, return_tensors="pt")["pixel_values_videos"].to("cuda")
        # Extract the patch-wise features from the last layer
        out_patch_features_pt = model_pt(x_pt)
        out_patch_features_hf = model_hf.get_vision_features(x_hf)
    return out_patch_features_hf, out_patch_features_pt
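
# Back-of-envelope output size (assuming the default 16x16 patches and a
# tubelet depth of 2 frames, which is not verified here): a 64-frame 384x384
# clip yields (64 / 2) * (384 / 16) ** 2 = 18432 patch tokens, so each output
# above has shape [1, 18432, embed_dim].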


def get_vjepa_video_classification_results(classifier, out_patch_features_pt):
    # Load the SSV2 id -> label mapping
    with open("ssv2_classes.json", "r") as f:
        SOMETHING_SOMETHING_V2_CLASSES = json.load(f)

    with torch.inference_mode():
        out_classifier = classifier(out_patch_features_pt)

    print(f"Classifier output shape: {out_classifier.shape}")
    print("Top 5 predicted class names:")
    top5 = out_classifier.topk(5)
    top5_indices = top5.indices[0]
    top5_probs = F.softmax(top5.values[0], dim=0) * 100.0  # convert to percentages
    for idx, prob in zip(top5_indices, top5_probs):
        str_idx = str(idx.item())
        print(f"{SOMETHING_SOMETHING_V2_CLASSES[str_idx]} ({prob.item():.1f}%)")


def run_sample_inference():
    # HuggingFace model repo name; swap in another V-JEPA 2 checkpoint if desired
    hf_model_name = "facebook/vjepa2-vitg-fpc64-384"
    # Path to local PyTorch weights
    pt_model_path = "YOUR_MODEL_PATH"
    sample_video_path = "sample_video.mp4"

    # Download the sample video if it is not already present locally
    if not os.path.exists(sample_video_path):
        video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/bowling/-WH-lxmGJVY_000005_000015.mp4"
        print("Downloading video")
        subprocess.run(["wget", video_url, "-O", sample_video_path])

    # Initialize the HuggingFace model, load pretrained weights
    model_hf = AutoModel.from_pretrained(hf_model_name)
    model_hf.cuda().eval()

    # Build HuggingFace preprocessing transform
    hf_transform = AutoVideoProcessor.from_pretrained(hf_model_name)
    img_size = hf_transform.crop_size["height"]  # E.g. 384, 256, etc.

    # Initialize the PyTorch model, load pretrained weights
    model_pt = vit_giant_xformers_rope(img_size=(img_size, img_size), num_frames=64)
    model_pt.cuda().eval()
    load_pretrained_vjepa_pt_weights(model_pt, pt_model_path)

    # Build PyTorch preprocessing transform
    pt_video_transform = build_pt_video_transform(img_size=img_size)

    # Inference on video
    out_patch_features_hf, out_patch_features_pt = forward_vjepa_video(
        model_hf, model_pt, hf_transform, pt_video_transform
    )
    print(
        f"""
        Inference results on video:
            HuggingFace output shape: {out_patch_features_hf.shape}
            PyTorch output shape:     {out_patch_features_pt.shape}
            Absolute difference sum:  {torch.abs(out_patch_features_pt - out_patch_features_hf).sum():.6f}
            Close: {torch.allclose(out_patch_features_pt, out_patch_features_hf, atol=1e-3, rtol=1e-3)}
        """
    )

    # Initialize the attentive-probe classifier (174 = number of SSV2 classes)
    classifier_model_path = "YOUR_ATTENTIVE_PROBE_PATH"
    classifier = (
        AttentiveClassifier(embed_dim=model_pt.embed_dim, num_heads=16, depth=4, num_classes=174).cuda().eval()
    )
    load_pretrained_vjepa_classifier_weights(classifier, classifier_model_path)

    # Download SSV2 class names if not already present
    ssv2_classes_path = "ssv2_classes.json"
    if not os.path.exists(ssv2_classes_path):
        print("Downloading SSV2 classes")
        subprocess.run(
            [
                "wget",
                "https://huggingface.co/datasets/huggingface/label-files/resolve/d79675f2d50a7b1ecf98923d42c30526a51818e2/"
                "something-something-v2-id2label.json",
                "-O",
                ssv2_classes_path,
            ]
        )

    get_vjepa_video_classification_results(classifier, out_patch_features_pt)


if __name__ == "__main__":
    # Run with: `python -m notebooks.vjepa2_demo`
    run_sample_inference()