@@ -159,20 +159,20 @@ def __getitem__(self, index):
159159 """
160160 img_id = self .img_ids [index ]
161161 img_dict : LVISImgEntry = self .lvis_api .load_imgs (ids = [img_id ])[0 ]
162- annotation_dicts = self .targets [index ]
162+ annotation_dicts : LVISImgTargets = self .targets [index ]
163163
164164 # Transform from LVIS dictionary to torchvision-style target
165- num_objs = len ( annotation_dicts )
165+ num_objs = annotation_dicts [ "bbox" ]. shape [ 0 ]
166166
167167 boxes = []
168168 labels = []
169169 for i in range (num_objs ):
170- xmin = annotation_dicts [i ][ "bbox" ][0 ]
171- ymin = annotation_dicts [i ][ "bbox" ][1 ]
172- xmax = xmin + annotation_dicts [i ][ "bbox" ][2 ]
173- ymax = ymin + annotation_dicts [i ][ "bbox" ][3 ]
170+ xmin = annotation_dicts ["bbox" ][ i ][0 ]
171+ ymin = annotation_dicts ["bbox" ][ i ][1 ]
172+ xmax = xmin + annotation_dicts ["bbox" ][ i ][2 ]
173+ ymax = ymin + annotation_dicts ["bbox" ][ i ][3 ]
174174 boxes .append ([xmin , ymin , xmax , ymax ])
175- labels .append (annotation_dicts [i ][ "category_id" ])
175+ labels .append (annotation_dicts ["category_id" ][ i ])
176176
177177 if len (boxes ) > 0 :
178178 boxes = torch .as_tensor (boxes , dtype = torch .float32 )
@@ -183,7 +183,7 @@ def __getitem__(self, index):
183183 image_id = torch .tensor ([img_id ])
184184 areas = []
185185 for i in range (num_objs ):
186- areas .append (annotation_dicts [i ][ "area" ])
186+ areas .append (annotation_dicts ["area" ][ i ])
187187 areas = torch .as_tensor (areas , dtype = torch .float32 )
188188 iscrowd = torch .zeros ((num_objs ,), dtype = torch .int64 )
189189
@@ -233,7 +233,17 @@ class LVISAnnotationEntry(TypedDict):
233233 category_id : int
234234
235235
236- class LVISDetectionTargets (Sequence [List [LVISAnnotationEntry ]]):
class LVISImgTargets(TypedDict):
    """Torchvision-style detection target for a single LVIS image.

    Produced by ``LVISDetectionTargets.__getitem__``, which packs the
    per-annotation LVIS dictionaries into per-image tensors. All tensors
    are indexed by annotation along dimension 0 (length ``n_annotations``),
    except ``image_id``.
    """

    # LVIS annotation ids; shape (n_annotations,), dtype long.
    id: torch.Tensor
    # Annotation areas; shape (n_annotations,), dtype float32.
    area: torch.Tensor
    # Raw LVIS ``segmentation`` entries, one per annotation — nested lists
    # of floats (presumably polygon coordinates per the LVIS format; the
    # builder appends them verbatim without conversion).
    segmentation: List[List[List[float]]]
    # Single-element tensor holding the image id; shape (1,), dtype long.
    image_id: torch.Tensor
    # Bounding boxes in LVIS (x, y, width, height) order; shape
    # (n_annotations, 4), dtype float32. Callers convert to
    # (xmin, ymin, xmax, ymax) by adding width/height to the origin.
    bbox: torch.Tensor
    # Category ids; shape (n_annotations,), dtype long.
    category_id: torch.Tensor
    # Alias of the *same* tensor object as ``category_id`` (not a copy):
    # writes through one key are visible through the other.
    labels: torch.Tensor
245+
246+ class LVISDetectionTargets (Sequence [List [LVISImgTargets ]]):
237247 def __init__ (
238248 self ,
239249 lvis_api : LVIS ,
@@ -254,7 +264,28 @@ def __getitem__(self, index):
254264 annotation_dicts : List [LVISAnnotationEntry ] = self .lvis_api .load_anns (
255265 annotation_ids
256266 )
257- return annotation_dicts
267+
268+ n_annotations = len (annotation_dicts )
269+
270+ category_tensor = torch .empty ((n_annotations ,), dtype = torch .long )
271+ target_dict : LVISImgTargets = {
272+ 'bbox' : torch .empty ((n_annotations , 4 ), dtype = torch .float32 ),
273+ 'category_id' : category_tensor ,
274+ 'id' : torch .empty ((n_annotations ,), dtype = torch .long ),
275+ 'area' : torch .empty ((n_annotations ,), dtype = torch .float32 ),
276+ 'image_id' : torch .full ((1 ,), img_id , dtype = torch .long ),
277+ 'segmentation' : [],
278+ 'labels' : category_tensor # Alias of category_id
279+ }
280+
281+ for ann_idx , annotation in enumerate (annotation_dicts ):
282+ target_dict ['bbox' ][ann_idx ] = torch .as_tensor (annotation ['bbox' ])
283+ target_dict ['category_id' ][ann_idx ] = annotation ['category_id' ]
284+ target_dict ['id' ][ann_idx ] = annotation ['id' ]
285+ target_dict ['area' ][ann_idx ] = annotation ['area' ]
286+ target_dict ['segmentation' ].append (annotation ['segmentation' ])
287+
288+ return target_dict
258289
259290
260291def _test_to_tensor (a , b ):
@@ -316,5 +347,6 @@ def _plot_detection_sample(img: Image.Image, target):
316347 "LvisDataset" ,
317348 "LVISImgEntry" ,
318349 "LVISAnnotationEntry" ,
350+ "LVISImgTargets" ,
319351 "LVISDetectionTargets" ,
320352]
0 commit comments