From ec7f4e89dcbd65264c2a48decb968d8711f97bff Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 29 Apr 2026 15:38:51 +0200 Subject: [PATCH 01/15] Add VitDet encoder / classification model --- doctr/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/doctr/models/__init__.py b/doctr/models/__init__.py index b6db1c0678..8bdcccd1dd 100644 --- a/doctr/models/__init__.py +++ b/doctr/models/__init__.py @@ -1,5 +1,6 @@ from .classification import * from .detection import * from .recognition import * +from .layout import * from .zoo import * from .factory import * From 3dec2c66f295d8aaf7ce9af7275df524cd08683c Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 6 May 2026 14:22:31 +0200 Subject: [PATCH 02/15] Add LW-DETR --- doctr/models/layout/__init__.py | 2 + doctr/models/layout/lw_detr/__init__.py | 1 + doctr/models/layout/lw_detr/base.py | 242 ++++++ .../models/layout/lw_detr/layers/__init__.py | 1 + doctr/models/layout/lw_detr/layers/pytorch.py | 654 +++++++++++++++ doctr/models/layout/lw_detr/pytorch.py | 764 ++++++++++++++++++ doctr/models/layout/predictor/__init__.py | 1 + doctr/models/layout/predictor/pytorch.py | 80 ++ doctr/models/layout/zoo.py | 88 ++ tests/common/test_models_layout.py | 66 ++ tests/pytorch/test_models_layout.py | 182 +++++ 11 files changed, 2081 insertions(+) create mode 100644 doctr/models/layout/__init__.py create mode 100644 doctr/models/layout/lw_detr/__init__.py create mode 100644 doctr/models/layout/lw_detr/base.py create mode 100644 doctr/models/layout/lw_detr/layers/__init__.py create mode 100644 doctr/models/layout/lw_detr/layers/pytorch.py create mode 100644 doctr/models/layout/lw_detr/pytorch.py create mode 100644 doctr/models/layout/predictor/__init__.py create mode 100644 doctr/models/layout/predictor/pytorch.py create mode 100644 doctr/models/layout/zoo.py create mode 100644 tests/common/test_models_layout.py create mode 100644 tests/pytorch/test_models_layout.py diff --git a/doctr/models/layout/__init__.py b/doctr/models/layout/__init__.py new file mode 100644 index 0000000000..d4147b2974 --- /dev/null +++ b/doctr/models/layout/__init__.py @@ -0,0 +1,2 @@ +from .zoo import * +from .lw_detr import * diff --git a/doctr/models/layout/lw_detr/__init__.py b/doctr/models/layout/lw_detr/__init__.py new file mode 100644 index 0000000000..e3c861310c --- /dev/null +++ b/doctr/models/layout/lw_detr/__init__.py @@ -0,0 +1 @@ +from .pytorch import * diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py new file mode 100644 index 0000000000..3a7ab89ad9 --- /dev/null +++ b/doctr/models/layout/lw_detr/base.py @@ -0,0 +1,242 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
+ +import cv2 +import numpy as np + +from doctr.models.core import BaseModel + +__all__ = ["_LWDETR", "LWDETRPostProcessor"] + + +class LWDETRPostProcessor: + """Implements a post processor for LW-DETR model + + Args: + num_classes: number of classes + score_thresh: confidence threshold for filtering predictions + iou_thresh: IoU threshold for NMS + topk: number of top predictions to keep before NMS + assume_straight_pages: whether the pages are assumed to be straight (i.e., no rotation) + """ + + def __init__( + self, + num_classes: int, + score_thresh: float = 0.3, + iou_thresh: float = 0.5, + topk: int = 300, + assume_straight_pages: bool = True, + ): + self.num_classes = num_classes + self.score_thresh = score_thresh + self.iou_thresh = iou_thresh + self.topk = topk + self.assume_straight_pages = assume_straight_pages + + def _decode_boxes(self, boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Decode the predicted boxes from OBB format to polygon format + + Args: + boxes: array of predicted boxes in OBB format (N, 6) (cx, cy, w, h, sin(theta), cos(theta)) + + Returns: + tuple of (polys, angles) where polys is an array of decoded polygons (N, 4, 2) + and angles is an array of angles in radians (N,) + """ + cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3] + sin, cos = boxes[:, 4], boxes[:, 5] + + angles = np.arctan2(sin, cos) + + polys = [] + for i in range(len(boxes)): + rect = ((float(cx[i]), float(cy[i])), (float(w[i]), float(h[i])), float(np.degrees(angles[i]))) + + poly = cv2.boxPoints(rect) + polys.append(poly) + + return np.asarray(polys, dtype=np.float32), angles + + def _iou(self, poly1: np.ndarray, poly2: np.ndarray) -> float: + """Compute the IoU between two polygons + + Args: + poly1: first polygon (4, 2) + poly2: second polygon (4, 2) + + Returns: + IoU between the two polygons + """ + inter = cv2.intersectConvexConvex( + poly1.astype(np.float32), + poly2.astype(np.float32), + )[0] + + if inter <= 0: + return 0.0 + + area1 = cv2.contourArea(poly1) + area2 = cv2.contourArea(poly2) + + return inter / (area1 + area2 - inter + 1e-6) + + def _nms(self, polys: np.ndarray, scores: np.ndarray) -> list[int]: + """Perform NMS on the predicted polygons + + Args: + polys: array of predicted polygons (N, 4, 2) + scores: array of predicted scores (N,) + + Returns: + list of indices of the polygons to keep after NMS + """ + idxs = np.argsort(scores)[::-1] + keep = [] + + while idxs.size > 0: + i = idxs[0] + keep.append(i) + + if idxs.size == 1: + break + + rest = idxs[1:] + + ious = np.array([self._iou(polys[i], polys[j]) for j in rest]) + + idxs = rest[ious < self.iou_thresh] + + return keep + + def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int], np.ndarray, list[float]]]: + + logits = np.asarray(logits) + boxes = np.asarray(boxes) + + results: list[tuple[list[int], np.ndarray, list[float]]] = [] + + for b in range(boxes.shape[0]): + # Convert logits to probabilities and get scores and labels + exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True)) + prob = exp / exp.sum(axis=-1, keepdims=True) + + scores = prob[:, 1:].max(axis=-1) + labels = prob[:, 1:].argmax(axis=-1) + 1 + + # Keep only topk predictions before NMS + if self.topk is not None and len(scores) > self.topk: + idxs = np.argsort(scores)[::-1][: self.topk] + else: + idxs = np.arange(len(scores)) + + scores_b = scores[idxs] + labels_b = labels[idxs] + bboxes = boxes[b][idxs] + + mask = scores_b > self.score_thresh + + bboxes = bboxes[mask] + scores_b = 
scores_b[mask] + labels_b = labels_b[mask] + + polys, _ = ( + self._decode_boxes(bboxes) + if len(bboxes) > 0 + else ( + np.zeros((0, 4, 2), dtype=np.float32), + np.zeros((0,), dtype=np.float32), + ) + ) + + keep = self._nms(polys, scores_b) if len(polys) > 0 else [] + + final_labels = [] + final_boxes = [] + final_scores = [] + + for idx in keep: + poly = polys[idx].reshape(-1).tolist() + if self.assume_straight_pages: + x_coords = poly[0::2] + y_coords = poly[1::2] + xmin, xmax = min(x_coords), max(x_coords) + ymin, ymax = min(y_coords), max(y_coords) + final_boxes.append([xmin, ymin, xmax, ymax]) + else: + final_boxes.append(poly) + + final_labels.append(int(labels_b[idx])) + final_scores.append(float(scores_b[idx])) + + final_boxes_arr = ( + np.asarray(final_boxes, dtype=np.float32).reshape(-1, 4, 2) + if not self.assume_straight_pages + else np.asarray(final_boxes, dtype=np.float32).reshape(-1, 4) + ) + + results.append(( + final_labels, + final_boxes_arr, + final_scores, + )) + + return results + + +class _LWDETR(BaseModel): + """LW-DETR as described in `"LW-DETR: A Transformer Replacement to YOLO for Real-Time Detection" + `_. + """ + + def build_target( + self, + target: list[tuple[list[int], np.ndarray]], + ): + + targets = [] + + def _quad_to_obb(poly: np.ndarray): + p1, p2, p3, p4 = poly + + cx, cy = np.mean(poly, axis=0) + + w = (np.linalg.norm(p2 - p1) + np.linalg.norm(p3 - p4)) / 2 + h = (np.linalg.norm(p3 - p2) + np.linalg.norm(p4 - p1)) / 2 + + theta = np.arctan2(*(p2 - p1)[::-1]) + + return np.array( + [cx, cy, w, h, np.sin(theta), np.cos(theta)], + dtype=np.float32, + ) + + for class_ids, boxes in target: + boxes_all = [] + labels_all = [] + + if len(boxes) == 0: + targets.append({ + "boxes": np.zeros((0, 6), dtype=np.float32), + "labels": np.zeros((0,), dtype=np.int64), + }) + continue + + for cls_id, box in zip(np.asarray(class_ids), np.asarray(boxes)): + poly = box.reshape(4, 2) + obb = _quad_to_obb(poly) + + if obb[2] <= 1e-3 or obb[3] <= 1e-3: + continue + + boxes_all.append(obb) + labels_all.append(cls_id + 1) # background = 0 + + targets.append({ + "boxes": np.asarray(boxes_all, dtype=np.float32), + "labels": np.asarray(labels_all, dtype=np.int64), + }) + + return targets diff --git a/doctr/models/layout/lw_detr/layers/__init__.py b/doctr/models/layout/lw_detr/layers/__init__.py new file mode 100644 index 0000000000..e3c861310c --- /dev/null +++ b/doctr/models/layout/lw_detr/layers/__init__.py @@ -0,0 +1 @@ +from .pytorch import * diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py new file mode 100644 index 0000000000..ed5ec086f6 --- /dev/null +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -0,0 +1,654 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import math + +import torch +import torch.nn.functional as F +from torch import nn + +from doctr.models.modules import ChannelLayerNorm +from doctr.models.utils import conv_sequence_pt + +__all__ = ["MultiScaleProjector", "C2fBottleneck", "LWDETRHead", "LWDETRDecoder", "LWDETRMultiscaleDeformableAttention"] + + +class LWDETRHead(nn.Module): + """ + Simple MLP used to predict the normalized center coordinates, height and width of a bounding box w.r.t. an image. 
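+
+    A minimal shape sketch (illustrative values; the head maps per-query
+    features to per-query predictions, e.g. 6 values for an oriented box):
+
+    >>> import torch
+    >>> head = LWDETRHead(input_dim=256, hidden_dim=256, output_dim=6, num_layers=3)
+    >>> head(torch.rand(2, 300, 256)).shape
+    torch.Size([2, 300, 6])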
+ + Args: + input_dim: number of input features + hidden_dim: number of hidden features + output_dim: number of output features + num_layers: number of layers in the MLP + """ + + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for i, layer in enumerate(self.layers): + x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class LWDETRMLP(nn.Module): + """Simple MLP used in the decoder layers of LW-DETR. + + Args: + d_model: number of input and output features + ff_dim: number of hidden features + dropout_prob: dropout probability + """ + + def __init__( + self, + d_model: int, + ff_dim: int, + dropout_prob: float = 0.1, + ): + super().__init__() + self.act = nn.ReLU() + self.dropout_1 = nn.Dropout(dropout_prob) + self.dropout_2 = nn.Dropout(dropout_prob) + self.fc1 = nn.Linear(d_model, ff_dim) + self.fc2 = nn.Linear(ff_dim, d_model) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.dropout_2(self.fc2(self.dropout_1(self.act(self.fc1(x))))) + + +class LWDETRAttention(nn.Module): + """This module implements the self-attention mechanism used in LW-DETR. + It performs multi-head self-attention on the input hidden states. + The group detr technique is used during training to add more supervision by + using multiple weight-sharing decoders at once for faster convergence. + + Args: + sa_num_heads: number of attention heads for self-attention + d_model: number of input and output features + dropout_prob: dropout probability for attention weights + group_detr: number of weight-sharing decoders to use during training + layer_idx: index of the decoder layer (used for group detr) + """ + + def __init__( + self, + sa_num_heads: int = 8, + d_model: int = 256, + dropout_prob: float = 0.0, + group_detr: int = 13, + layer_idx: int = 0, + ): + super().__init__() + self.layer_idx = layer_idx + self.head_dim = d_model // sa_num_heads + self.scaling = self.head_dim**-0.5 + self.dropout = dropout_prob + self.group_detr = group_detr + + self.q_proj = nn.Linear(d_model, sa_num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(d_model, sa_num_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(d_model, sa_num_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(sa_num_heads * self.head_dim, d_model, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, seq_len, _ = hidden_states.shape + + hidden_states_original = hidden_states + if position_embeddings is not None: + hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings + + if self.training: + # at training, we use group detr technique to + # add more supervision by using multiple weight-sharing decoders at once for faster convergence + # at inference, we only use one decoder + hidden_states_original = torch.cat(hidden_states_original.split(seq_len // self.group_detr, dim=1), dim=0) + hidden_states = torch.cat(hidden_states.split(seq_len // self.group_detr, dim=1), dim=0) + + attention_input_shape = hidden_states.shape[:-1] + hidden_shape = (*attention_input_shape, -1, self.head_dim) + query_states = 
self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states_original).view(hidden_shape).transpose(1, 2) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(*attention_input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if self.training: + attn_output = torch.cat(torch.split(attn_output, batch_size, dim=0), dim=1) + + return attn_output, attn_weights + + +class MultiScaleDeformableAttention(nn.Module): + """This module implements MultiScaleDeformableAttention from Deformable DETR. + It performs multi-scale deformable attention on the input feature maps. + Borrowed from: + https://github.com/huggingface/transformers/blob/main/src/transformers/models/deformable_detr/modeling_deformable_detr.py + """ + + def forward( + self, + value: torch.Tensor, + value_spatial_shapes_list: list[tuple], + sampling_locations: torch.Tensor, + attention_weights: torch.Tensor, + ) -> torch.Tensor: + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape + value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes_list): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = ( + value_list[level_id] + .flatten(2) + .transpose(1, 2) + .reshape(batch_size * num_heads, hidden_dim, height, width) + ) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) + # batch_size*num_heads, hidden_dim, num_queries, num_points + sampling_value_l_ = nn.functional.grid_sample( + value_l_, + sampling_grid_l_, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + batch_size * num_heads, 1, num_queries, num_levels * num_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +class LWDETRMultiscaleDeformableAttention(nn.Module): + """ + Multiscale deformable attention as proposed in Deformable DETR. 
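+
+    In this oriented-box variant, reference points carry 6 values
+    (cx, cy, w, h, sin(theta), cos(theta)): the per-head sampling offsets are
+    scaled by the box size and rotated by the box angle before being added to
+    the box center,
+
+        dx' = dx * cos(theta) - dy * sin(theta)
+        dy' = dx * sin(theta) + dy * cos(theta)
+
+    so the sampled locations follow the orientation of the predicted box (see
+    the ``num_coordinates == 6`` branch in ``forward`` below).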
+
+    Args:
+        d_model: number of input and output features
+        ca_num_heads: number of attention heads for cross-attention
+        dec_n_points: number of sampling points for each attention head
+    """
+
+    def __init__(
+        self,
+        d_model: int = 256,
+        ca_num_heads: int = 16,
+        dec_n_points: int = 2,
+    ):
+        super().__init__()
+
+        self.attn = MultiScaleDeformableAttention()
+
+        self.d_model = d_model
+        self.n_levels = 1
+        self.n_heads = ca_num_heads
+        self.n_points = dec_n_points
+
+        self.sampling_offsets = nn.Linear(d_model, self.n_heads * self.n_levels * dec_n_points * 2)
+        self.attention_weights = nn.Linear(d_model, self.n_heads * self.n_levels * dec_n_points)
+        self.value_proj = nn.Linear(d_model, d_model)
+        self.output_proj = nn.Linear(d_model, d_model)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        encoder_hidden_states=None,
+        position_embeddings: torch.Tensor | None = None,
+        reference_points=None,
+        spatial_shapes=None,
+        spatial_shapes_list=None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # add position embeddings to the hidden states before projecting to queries and keys
+        if position_embeddings is not None:
+            hidden_states = hidden_states + position_embeddings
+
+        batch_size, num_queries, _ = hidden_states.shape
+        batch_size, sequence_length, _ = encoder_hidden_states.shape
+
+        value = self.value_proj(encoder_hidden_states)
+        if attention_mask is not None:
+            # we invert the attention_mask
+            value = value.masked_fill(~attention_mask[..., None], float(0))
+        value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
+        sampling_offsets = self.sampling_offsets(hidden_states).view(
+            batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
+        )
+        attention_weights = self.attention_weights(hidden_states).view(
+            batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
+        )
+        attention_weights = F.softmax(attention_weights, -1).view(
+            batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
+        )
+        # batch_size, num_queries, n_heads, n_levels, n_points, 2
+        num_coordinates = reference_points.shape[-1]
+
+        if num_coordinates == 4:
+            sampling_locations = (
+                reference_points[:, :, None, :, None, :2]
+                + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
+            )
+        elif num_coordinates == 6:
+            ref = reference_points[:, :, None, :, None, :]  # (..., 6)
+
+            center = ref[..., :2]  # (cx, cy)
+            wh = ref[..., 2:4]  # (w, h)
+            sin = ref[..., 4:5]  # sinθ
+            cos = ref[..., 5:6]  # cosθ
+
+            # normalize offsets
+            offsets = sampling_offsets / self.n_points * wh * 0.5
+
+            dx = offsets[..., 0:1]
+            dy = offsets[..., 1:2]
+
+            # rotate offsets
+            dx_rot = dx * cos - dy * sin
+            dy_rot = dx * sin + dy * cos
+
+            rotated_offsets = torch.cat([dx_rot, dy_rot], dim=-1)
+
+            sampling_locations = center + rotated_offsets
+        else:
+            raise ValueError(f"Last dim of reference_points must be 4 or 6, but got {reference_points.shape[-1]}")
+
+        output = self.attn(
+            value,
+            spatial_shapes_list,
+            sampling_locations,
+            attention_weights,
+        )
+
+        output = self.output_proj(output)
+
+        return output, attention_weights
+
+
+class LWDETRDecoderLayer(nn.Module):
+    """This module implements a single decoder layer of LW-DETR,
+    which consists of self-attention, cross-attention and an MLP.
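+
+    Schematically, each sub-block is residual with post-normalization:
+
+        x = LayerNorm(x + Dropout(SelfAttn(x)))
+        x = LayerNorm(x + Dropout(CrossAttn(x, encoder_states)))
+        x = LayerNorm(MLP(x))  # LWDETRMLP adds its own residual internally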
+ + Args: + d_model: number of input and output features + ff_dim: number of hidden features in the MLP + dropout_prob: dropout probability for the attention and MLP layers + ca_num_heads: number of attention heads for cross-attention + dec_n_points: number of sampling points for each attention head in cross-attention + sa_num_heads: number of attention heads for self-attention + group_detr: number of weight-sharing decoders to use during training for the group detr technique + layer_idx: index of the decoder layer (used for group detr) + """ + + def __init__( + self, + d_model: int = 256, + ff_dim: int = 2048, + dropout_prob: float = 0.0, + ca_num_heads: int = 16, + dec_n_points: int = 2, + sa_num_heads: int = 8, + group_detr: int = 13, + layer_idx: int = 0, + ): + super().__init__() + self.dropout = dropout_prob + + # self-attention + self.self_attn = LWDETRAttention( + sa_num_heads=sa_num_heads, + d_model=d_model, + dropout_prob=dropout_prob, + group_detr=group_detr, + layer_idx=layer_idx, + ) + self.self_attn_layer_norm = nn.LayerNorm(d_model) + + # cross-attention + self.cross_attn = LWDETRMultiscaleDeformableAttention( + d_model=d_model, + ca_num_heads=ca_num_heads, + dec_n_points=dec_n_points, + ) + self.cross_attn_layer_norm = nn.LayerNorm(d_model) + + # mlp + self.mlp = LWDETRMLP(d_model, ff_dim, dropout_prob) + self.layer_norm = nn.LayerNorm(d_model) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor | None = None, + reference_points=None, + spatial_shapes_list=None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + self_attention_output, self_attn_weights = self.self_attn( + hidden_states, position_embeddings=position_embeddings + ) + + self_attention_output = F.dropout(self_attention_output, p=self.dropout, training=self.training) + hidden_states = hidden_states + self_attention_output + hidden_states = self.self_attn_layer_norm(hidden_states) + + cross_attention_output, cross_attn_weights = self.cross_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes_list=spatial_shapes_list, + ) + cross_attention_output = F.dropout(cross_attention_output, p=self.dropout, training=self.training) + hidden_states = hidden_states + cross_attention_output + hidden_states = self.cross_attn_layer_norm(hidden_states) + + hidden_states = self.mlp(hidden_states) + hidden_states = self.layer_norm(hidden_states) + + return hidden_states + + +# function to generate sine positional embedding for 4d coordinates +# Borrowed from: https://github.com/Atten4Vis/LW-DETR/blob/main/models/transformer.py +def gen_sine_position_embeddings(pos_tensor: torch.Tensor, hidden_size: int = 256) -> torch.Tensor: + """ + This function computes position embeddings using sine and cosine functions from the input positional tensor, + which has a shape of (batch_size, num_queries, 4). + The last dimension of `pos_tensor` represents the following coordinates: + - 0: x-coord + - 1: y-coord + - 2: width + - 3: height + + The output shape is (batch_size, num_queries, 512), + where final dim (hidden_size*2 = 512) is the total embedding dimension + achieved by concatenating the sine and cosine values for each coordinate. 
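+
+    A minimal shape check (illustrative):
+
+    >>> import torch
+    >>> pos = gen_sine_position_embeddings(torch.rand(2, 300, 4), hidden_size=256)
+    >>> pos.shape
+    torch.Size([2, 300, 512])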
+ """ + scale = 2 * math.pi + dim = hidden_size // 2 + dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) + if pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}") + return pos.to(pos_tensor.dtype) + + +class LWDETRDecoder(nn.Module): + """This module implements the decoder of LW-DETR, + which consists of multiple decoder layers and a reference point head. + + Args: + num_layers: number of decoder layers + d_model: number of input and output features for each decoder layer + sa_num_heads: number of attention heads for self-attention in each decoder layer + ca_num_heads: number of attention heads for cross-attention in each decoder layer + ff_dim: number of hidden features in the MLP of each decoder layer + dec_n_points: number of sampling points for each attention head in cross-attention of each decoder layer + group_detr: number of weight-sharing decoders to use during training for the group detr technique + dropout_prob: dropout probability for the attention and MLP layers in each decoder layer + """ + + def __init__( + self, + num_layers: int = 3, + d_model: int = 256, + sa_num_heads: int = 8, + ca_num_heads: int = 16, + ff_dim: int = 2048, + dec_n_points: int = 2, + group_detr: int = 13, + dropout_prob: float = 0.0, + ): + super().__init__() + self.dropout_prob = dropout_prob + self.d_model = d_model + self.layers = nn.ModuleList([ + LWDETRDecoderLayer( + d_model=self.d_model, + sa_num_heads=sa_num_heads, + ca_num_heads=ca_num_heads, + ff_dim=ff_dim, + dec_n_points=dec_n_points, + group_detr=group_detr, + dropout_prob=dropout_prob, + layer_idx=i, + ) + for i in range(num_layers) + ]) + self.layernorm = nn.LayerNorm(self.d_model) + + self.ref_point_head = LWDETRHead(2 * self.d_model, self.d_model, self.d_model, num_layers=2) + self.angle_proj = nn.Linear(2, self.d_model) + + def get_reference(self, reference_points, valid_ratios): + obj_center = reference_points[..., :4] + reference_points_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] + # DETR positional encoding + query_sine_embed = gen_sine_position_embeddings(reference_points_inputs[:, :, 0, :], self.d_model) + base_query_pos = self.ref_point_head(query_sine_embed) + # angle embedding + angle = reference_points[..., 4:6] # (sin, cos) + angle_emb = self.angle_proj(angle) + # Combine + query_pos = base_query_pos + angle_emb + return reference_points_inputs, query_pos + + def forward( + self, + inputs_embeds: torch.Tensor | None = None, + reference_points: torch.Tensor | None = None, + spatial_shapes_list: torch.Tensor | None = None, + valid_ratios: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = 
None, + encoder_attention_mask: torch.Tensor | None = None, + ): + intermediate = () + intermediate_reference_points = (reference_points,) + + if inputs_embeds is not None: + hidden_states = inputs_embeds + + reference_points_inputs, query_pos = self.get_reference(reference_points, valid_ratios) + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + position_embeddings=query_pos, + reference_points=reference_points_inputs, + spatial_shapes_list=spatial_shapes_list, + ) + intermediate_hidden_states = self.layernorm(hidden_states) + intermediate += (intermediate_hidden_states,) + + intermediate = torch.stack(intermediate) + last_hidden_state = intermediate[-1] + intermediate_reference_points = torch.stack(intermediate_reference_points) + + return last_hidden_state, intermediate, intermediate_reference_points + + +class MultiScaleProjector(nn.Module): + """ + This module implements MultiScaleProjector in :paper:`lwdetr`. + It creates pyramid features built on top of the input feature map. + This is modified from the original MultiScaleProjector to use only the levels used in LW-DETR small and medium. + + Args: + in_channels (list[int]): list of input channels for each level of the input feature maps. + out_channels (int): number of channels in the output feature maps. + num_blocks (int): number of blocks in the C2fBottleneck. + """ + + def __init__(self, in_channels: list[int], out_channels: int, num_blocks: int = 3): + super().__init__() + + self.use_extra_pool = False + + self.stages_sampling = nn.ModuleList() + self.stages = nn.ModuleList() + + sampling_layers = nn.ModuleList() + out_dim: int = 0 + + for in_dim in in_channels: + layers, out_dim = [nn.Identity()], in_dim + sampling_layers.append(nn.Sequential(*layers)) + + self.stages_sampling.append(sampling_layers) + + fusion_in_dim = out_dim * len(in_channels) + + self.stages.append( + nn.Sequential( + C2fBottleneck(fusion_in_dim, out_channels, num_blocks), + ChannelLayerNorm(out_channels), + ) + ) + + def forward(self, x: torch.Tensor) -> list[tuple[torch.Tensor, torch.Tensor]]: + feats = [layer(xi) for layer, xi in zip(self.stages_sampling[0], x)] # type: ignore[call-overload] + fused = torch.cat(feats, dim=1) + return [self.stages[0](fused)] + + +class C2fBottleneck(nn.Module): + """Faster implementation of CSP bottleneck with 2 convolutions and 1 residual connection + + Args: + input_dim: number of input channels + out_channels: number of output channels + num_blocks: number of bottleneck blocks + """ + + def __init__(self, input_dim: int, out_channels: int, num_blocks: int): + super().__init__() + + self.c = int(out_channels * 0.5) + + self.conv_seq_1 = nn.Sequential( + *conv_sequence_pt( + in_channels=input_dim, + out_channels=2 * self.c, + kernel_size=1, + stride=1, + padding=0, + groups=1, + dilation=1, + act=True, + bias=False, + bn=True, + activation=nn.SiLU(inplace=True), + ) + ) + + self.blocks = nn.ModuleList([ + nn.Sequential( + *conv_sequence_pt( + in_channels=self.c, + out_channels=self.c, + kernel_size=3, + stride=1, + padding=1, + groups=1, + dilation=1, + act=True, + bias=False, + bn=True, + activation=nn.SiLU(inplace=True), + ), + *conv_sequence_pt( + in_channels=self.c, + out_channels=self.c, + kernel_size=3, + stride=1, + padding=1, + groups=1, + dilation=1, + act=True, + bias=False, + bn=True, + activation=nn.SiLU(inplace=True), + ), + ) + for _ in range(num_blocks) + ]) + + 
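+        # Final fusion: the two initial splits plus one output per bottleneck
+        # block are concatenated ((2 + num_blocks) * c channels) and projected
+        # back to `out_channels` with a 1x1 convolution (C2f-style).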
self.conv_seq_2 = nn.Sequential( + *conv_sequence_pt( + in_channels=(2 + num_blocks) * self.c, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + dilation=1, + act=True, + bias=False, + bn=True, + activation=nn.SiLU(inplace=True), + ) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = list(self.conv_seq_1(x).split((self.c, self.c), dim=1)) + + for block in self.blocks: + y.append(block(y[-1])) + + return self.conv_seq_2(torch.cat(y, dim=1)) diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py new file mode 100644 index 0000000000..a17592313b --- /dev/null +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -0,0 +1,764 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import math +from collections.abc import Callable +from typing import Any + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from doctr.models.classification import vit_det_m, vit_det_s + +from ...utils import load_pretrained_params +from .base import _LWDETR, LWDETRPostProcessor +from .layers import LWDETRDecoder, LWDETRHead, LWDETRMultiscaleDeformableAttention, MultiScaleProjector + +__all__ = ["LWDETR", "lw_detr_s", "lw_detr_m"] + + +default_cfgs: dict[str, dict[str, Any]] = { + "lw_detr_s": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "class_names": [ + "Caption", + "Footnote", + "Formula", + "List-item", + "Page-footer", + "Page-header", + "Picture", + "Section-header", + "Table", + "Text", + "Title", + "Document Index", + "Code", + "Checkbox-Selected", + "Checkbox-Unselected", + "Form", + "Key-Value Region", + ], + "url": None, + }, + "lw_detr_m": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "class_names": [ + "Caption", + "Footnote", + "Formula", + "List-item", + "Page-footer", + "Page-header", + "Picture", + "Section-header", + "Table", + "Text", + "Title", + "Document Index", + "Code", + "Checkbox-Selected", + "Checkbox-Unselected", + "Form", + "Key-Value Region", + ], + "url": None, + }, +} + + +class LWDETRBackbone(nn.Module): + """Backbone of LW-DETR, based on a ViT Det architecture. The backbone is used as feature extractor. + + Args: + encoder_fn: the function to build the encoder of the backbone, which is a ViT Det architecture + out_channels: number of channels in the output feature maps of the backbone. + num_blocks: number of blocks in the C2fBottleneck of the projector. + """ + + def __init__( + self, + encoder_fn: nn.Module, + out_channels: int = 256, + num_blocks: int = 3, + ) -> None: + super().__init__() + self.encoder = encoder_fn + + _is_training = self.encoder.training + + self.encoder.eval() + with torch.no_grad(): + in_shape = (3, 512, 512) + out = self.encoder(torch.zeros((1, *in_shape))) + # Get the number of channels for each feature map output by the backbone + _shapes = [feat.shape[1] for feat in out] + self.encoder.train(_is_training) + + self.projector = MultiScaleProjector( + in_channels=_shapes, + out_channels=out_channels, + num_blocks=num_blocks, + ) + + def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> list[tuple[torch.Tensor, torch.Tensor]]: + """Forward pass of the backbone. 
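+
+        A minimal sketch (illustrative; assumes a ViTDet-style encoder such as
+        the `vit_det_s` introduced in this patch series):
+
+        >>> import torch
+        >>> from doctr.models.classification import vit_det_s
+        >>> encoder = vit_det_s(False, include_top=False, input_shape=(3, 1024, 1024))
+        >>> backbone = LWDETRBackbone(encoder)
+        >>> feat, mask = backbone(torch.rand(2, 3, 1024, 1024))[0]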
+ + Args: + x: batched images, of shape [batch_size x 3 x H x W] + mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + + Returns: + A list of tuples (feat, mask) for each feature map, where: + - feat is the feature map of shape [batch_size x out_channels x H' x W'] + - mask is the corresponding attention mask of shape [batch_size x H' x W'], containing 1 on padded pixels + """ + # (H, W, B, C) + feats = self.encoder(x) + feats = self.projector(feats) + # [(B, C, H, W)] + if mask is None: + mask = torch.zeros((x.shape[0], x.shape[2], x.shape[3]), dtype=torch.bool, device=x.device) + return [ + (feat, F.interpolate(mask.unsqueeze(1).float(), size=feat.shape[-2:], mode="nearest").squeeze(1).bool()) + for feat in feats + ] + + +class LWDETR(nn.Module, _LWDETR): + """LW-DETR as described in `"LW-DETR: A Transformer Replacement to YOLO for Real-Time Detection" + `_. + + Args: + feat_extractor: the backbone of the model, used as feature extractor + class_names: list of class names to be detected by the model + score_thresh: the score threshold for post-processing the model outputs + iou_thresh: the IoU threshold for post-processing the model outputs + d_model: the dimension of the model + num_queries: the number of object queries + group_detr: the number of groups in the group DETR architecture + dec_layers: the number of decoder layers + sa_num_heads: the number of heads in the self-attention of the decoder + ca_num_heads: the number of heads in the cross-attention of the decoder + ff_dim: the dimension of the feedforward network in the decoder + dec_n_points: the number of sampling points in the deformable attention of the decoder + dropout_prob: the dropout probability in the decoder + assume_straight_pages: if True, fit straight bounding boxes only + exportable: onnx exportable returns only logits + cfg: the configuration dict of the model + """ + + def __init__( + self, + feat_extractor: LWDETRBackbone, + class_names: list[str], + score_thresh: float = 0.3, + iou_thresh: float = 0.5, + d_model: int = 256, + num_queries: int = 300, + group_detr: int = 13, + dec_layers: int = 3, + sa_num_heads: int = 8, + ca_num_heads: int = 16, + ff_dim: int = 2048, + dec_n_points: int = 2, + dropout_prob: float = 0.0, + assume_straight_pages: bool = True, + exportable: bool = False, + cfg: dict[str, Any] | None = None, + ) -> None: + super().__init__() + + self.class_names = ["__background__"] + class_names + self.num_classes = len(self.class_names) + self.cfg = cfg + self.exportable = exportable + self.assume_straight_pages = assume_straight_pages + + self.feat_extractor = feat_extractor + + self.group_detr = group_detr + self.num_queries = num_queries + self.d_model = d_model + + self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 6) + # Initialize angle to (sin=0, cos=1) + with torch.no_grad(): + self.reference_point_embed.weight[:, 4] = 0.0 # sinθ + self.reference_point_embed.weight[:, 5] = 1.0 # cosθ + + self.query_feat = nn.Embedding(self.num_queries * self.group_detr, self.d_model) + + self.decoder = LWDETRDecoder( + num_layers=dec_layers, + d_model=d_model, + sa_num_heads=sa_num_heads, + ca_num_heads=ca_num_heads, + ff_dim=ff_dim, + dec_n_points=dec_n_points, + group_detr=group_detr, + dropout_prob=dropout_prob, + ) + + self.enc_output = nn.ModuleList([nn.Linear(self.d_model, self.d_model) for _ in range(self.group_detr)]) + self.enc_output_norm = nn.ModuleList([nn.LayerNorm(self.d_model) for _ in range(self.group_detr)]) + + 
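+        # One detection head pair (box deltas + class logits) per detr group:
+        # during training each group selects its own top-k proposals for extra
+        # supervision, while inference only uses the first group.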
self.enc_out_bbox_embed = nn.ModuleList([ + LWDETRHead(self.d_model, self.d_model, 6, num_layers=3) for _ in range(self.group_detr) + ]) + self.enc_out_class_embed = nn.ModuleList([ + nn.Linear(self.d_model, self.num_classes) for _ in range(self.group_detr) + ]) + self.class_embed = nn.Linear(self.d_model, self.num_classes) + self.bbox_embed = LWDETRHead(self.d_model, self.d_model, 6, num_layers=3) + + self.postprocessor = LWDETRPostProcessor( + num_classes=self.num_classes, + score_thresh=score_thresh, + iou_thresh=iou_thresh, + assume_straight_pages=self.assume_straight_pages, + ) + + for n, m in self.named_modules(): + # Don't override the initialization of the backbone + if n.startswith("feat_extractor."): + continue + + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): + if hasattr(m, "weight") and m.weight is not None: + nn.init.ones_(m.weight) + if hasattr(m, "bias") and m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, std=0.02) + elif isinstance(m, LWDETRMultiscaleDeformableAttention): + nn.init.constant_(m.sampling_offsets.weight, 0.0) + + thetas = torch.arange(m.n_heads, dtype=torch.float32) * (2.0 * math.pi / m.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(m.n_heads, 1, 1, 2) + .repeat(1, m.n_levels, m.n_points, 1) + ) + + for i in range(m.n_points): + grid_init[:, :, i, :] *= i + 1 + + with torch.no_grad(): + m.sampling_offsets.bias.copy_(grid_init.view(-1)) + + nn.init.constant_(m.attention_weights.weight, 0.0) + nn.init.constant_(m.attention_weights.bias, 0.0) + + nn.init.xavier_uniform_(m.value_proj.weight) + nn.init.zeros_(m.value_proj.bias) + + nn.init.xavier_uniform_(m.output_proj.weight) + nn.init.zeros_(m.output_proj.bias) + + if isinstance(m, nn.Linear) and m.out_features == self.num_classes: + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + if m.bias is not None: + nn.init.constant_(m.bias, bias_value) + if isinstance(m, LWDETRHead): + last = m.layers[-1] + if isinstance(last, nn.Linear): + nn.init.zeros_(last.weight) + nn.init.zeros_(last.bias) + + def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: + """Load pretrained parameters onto the model + + Args: + path_or_url: the path or URL to the model parameters (checkpoint) + **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params` + """ + load_pretrained_params(self, path_or_url, **kwargs) + + def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: + """Refine bounding boxes by applying the predicted deltas to the reference points. + The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. 
+ The refined boxes are computed as follows: + + cx' = cx + delta_cx * w + cy' = cy + delta_cy * h + w' = w * exp(delta_w) + h' = h * exp(delta_h) + sinθ' = sinθ * cos(delta_θ) + cosθ * sin(delta_θ) + cosθ' = cosθ * cos(delta_θ) - sinθ * sin(delta_θ) + + Args: + reference_points: (N, S, 6) tensor containing the reference points + deltas: (N, S, 6) tensor containing the predicted deltas + + Returns: + refined_boxes: (N, S, 6) tensor containing the refined bounding boxes + """ + reference_points = reference_points.to(deltas.device) + # center + cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] + # size + wh = deltas[..., 2:4].exp() * reference_points[..., 2:4] + # rotation (sin, cos) + sin_ref = reference_points[..., 4:5] + cos_ref = reference_points[..., 5:6] + sin_delta = deltas[..., 4:5] + cos_delta = deltas[..., 5:6] + # compose rotations (like adding angles) + sin_new = sin_ref * cos_delta + cos_ref * sin_delta + cos_new = cos_ref * cos_delta - sin_ref * sin_delta + # normalize + norm = torch.sqrt(sin_new**2 + cos_new**2 + 1e-6) + sin_new = sin_new / norm + cos_new = cos_new / norm + + return torch.cat((cxcy, wh, sin_new, cos_new), dim=-1) + + def get_valid_ratio(self, mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Get the valid ratio of all feature maps. + + Args: + mask: (N, H, W) binary tensor containing 1 on padded pixels + dtype: the desired data type of the output tensor + + Returns: + valid_ratio: (N, 2) tensor containing the valid ratio of width and height for each image in the batch + """ + _, height, width = mask.shape + valid_height = torch.sum(mask[:, :, 0], 1) + valid_width = torch.sum(mask[:, 0, :], 1) + valid_ratio_height = valid_height.to(dtype) / height + valid_ratio_width = valid_width.to(dtype) / width + valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1) + return valid_ratio + + def gen_encoder_output_proposals( + self, enc_output: torch.Tensor, padding_mask: torch.Tensor, spatial_shapes: list[tuple[int, int]] + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate the encoder output proposals from encoded enc_output. + + Args: + enc_output: Output of the encoder + padding_mask: Padding mask for `enc_output` + spatial_shapes: Spatial shapes of the feature maps + + Returns: + A tuple of feature map and bbox prediction. + - object_query: Object query features. Later used to directly predict a bounding box. + - output_proposals: Normalized proposals in [0, 1] space. + Invalid positions (padding or out-of-bounds) are filled with 0. + - invalid_mask: Boolean mask that is True for invalid positions + (padded pixels or proposals whose coordinates fall outside (0.01, 0.99)). 
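+
+        Each feature-map cell yields one anchor-like proposal: its normalized
+        cell center as (cx, cy), a level-dependent default size of
+        0.05 * 2**level for (w, h), and an identity rotation (sin=0, cos=1).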
+ """ + batch_size = enc_output.shape[0] + proposals = [] + _cur = 0 + for level, (height, width) in enumerate(spatial_shapes): + mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1) + valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid( + torch.linspace( + 0, + height - 1, + height, + dtype=enc_output.dtype, + device=enc_output.device, + ), + torch.linspace( + 0, + width - 1, + width, + dtype=enc_output.dtype, + device=enc_output.device, + ), + indexing="ij", + ) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale + width_height = torch.ones_like(grid) * 0.05 * (2.0**level) + # add default rotation (sin=0, cos=1) + sin = torch.zeros_like(grid[..., :1]) + cos = torch.ones_like(grid[..., :1]) + proposal = torch.cat((grid, width_height, sin, cos), -1).view(batch_size, -1, 6) + proposals.append(proposal) + _cur += height * width + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + invalid_mask = padding_mask | ~output_proposals_valid.squeeze(-1) + invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid + output_proposals = output_proposals.masked_fill(invalid_mask, float(0)) + + # assign each pixel as an object query + object_query = enc_output + object_query = object_query.masked_fill(invalid_mask, float(0)) + return object_query, output_proposals, invalid_mask + + def forward( + self, + input: torch.Tensor, + masks: torch.Tensor, + target: list[tuple[list[int], np.ndarray]] | None = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> dict[str, Any]: + feats = self.feat_extractor(input, masks) + + sources = [] + masks = [] + for source, mask in feats: + sources.append(source) + masks.append(mask) + if mask is None: + raise ValueError("No attention mask was provided") + + if self.training: + reference_points = self.reference_point_embed.weight + query_feat = self.query_feat.weight + else: + # only use one group in inference + reference_points = self.reference_point_embed.weight[: self.num_queries] + query_feat = self.query_feat.weight[: self.num_queries] + + # Prepare encoder inputs (by flattening) + source_flatten: list[torch.Tensor] | torch.Tensor = [] + mask_flatten: list[torch.Tensor] | torch.Tensor = [] + spatial_shapes_list: list[tuple[int, int]] = [] + for source, mask in zip(sources, masks): + batch_size, num_channels, height, width = source.shape + spatial_shape = (height, width) + spatial_shapes_list.append(spatial_shape) + source = source.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + source_flatten.append(source) + mask_flatten.append(mask) + source_flatten = torch.cat(source_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1) + + tgt = query_feat.unsqueeze(0).expand(batch_size, -1, -1) + reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1) + + object_query_embedding, output_proposals, invalid_mask = self.gen_encoder_output_proposals( + source_flatten, mask_flatten, spatial_shapes_list + ) + + group_detr = self.group_detr if self.training else 1 + topk = 
self.num_queries + topk_coords_logits = [] + topk_coords_logits_undetach = [] + object_query_undetach = [] + + # For each group, predict class logits and bbox deltas from the object query embeddings, + # and select the top-k proposals based on the predicted class logits. + for group_id in range(group_detr): + group_object_query = self.enc_output[group_id](object_query_embedding) + group_object_query = self.enc_output_norm[group_id](group_object_query) + + group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query) + group_enc_outputs_class = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) + group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query) + group_enc_outputs_coord = self.refine_bboxes(output_proposals, group_delta_bbox) + + group_topk_proposals = torch.topk(group_enc_outputs_class.max(-1)[0], topk, dim=1)[1] + group_topk_coords_logits_undetach = torch.gather( + group_enc_outputs_coord, + 1, + group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6), + ) + group_topk_coords_logits = group_topk_coords_logits_undetach.detach() + group_object_query_undetach = torch.gather( + group_object_query, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model) + ) + + topk_coords_logits.append(group_topk_coords_logits) + topk_coords_logits_undetach.append(group_topk_coords_logits_undetach) + object_query_undetach.append(group_object_query_undetach) + + topk_coords_logits = torch.cat(topk_coords_logits, 1) + topk_coords_logits_undetach = torch.cat(topk_coords_logits_undetach, 1) + object_query_undetach = torch.cat(object_query_undetach, 1) + + reference_points = self.refine_bboxes(topk_coords_logits, reference_points) + + last_hidden_states, intermediate, intermediate_reference_points = self.decoder( + inputs_embeds=tgt, + reference_points=reference_points, + spatial_shapes_list=spatial_shapes_list, + valid_ratios=valid_ratios, + encoder_hidden_states=source_flatten, + ) + + logits = self.class_embed(last_hidden_states) + pred_boxes_delta = self.bbox_embed(last_hidden_states) + pred_boxes = self.refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta) + + out: dict[str, Any] = {} + + if self.exportable: + out["logits"] = logits + out["pred_boxes"] = pred_boxes + return out + + if return_model_output or target is None or return_preds: + out["logits"] = logits + + if target is None or return_preds: + # Disable for torch.compile compatibility + @torch.compiler.disable + def _postprocess(logits, boxes): + return self.postprocessor(logits, boxes) + + out["preds"] = _postprocess(logits.detach().cpu().numpy(), pred_boxes.detach().cpu().numpy()) + + if target is not None: + loss = self.compute_loss(logits, pred_boxes, target) + out["loss"] = loss + + return out + + def compute_loss( + self, logits: torch.Tensor, pred_boxes: torch.Tensor, target: list[tuple[list[int], np.ndarray]] + ) -> torch.Tensor: + """ + Compute the loss for LW-DETR. The loss consists of three components: + classification loss, box regression loss, and rotation loss. + The classification loss is a cross-entropy loss between the predicted class logits and the target classes. + The box regression loss is a Smooth L1 loss between the predicted boxes and the target boxes, + computed only on the positive samples. + The rotation loss is computed as 1 - cosine similarity between the predicted rotation and the target rotation, + averaged over the positive samples. 
+        The positive samples are determined using a SimOTA-like assignment strategy, where for each ground truth box,
+        we select the top-k queries with the lowest cost
+        (combination of classification cost, box regression cost, and rotation cost).
+
+        Args:
+            logits: (B, Q, C) tensor containing the predicted class logits for each query
+            pred_boxes: (B, Q, 6) tensor containing the predicted boxes for each query
+            target: list of length B, where each element is a tuple of (classes, boxes)
+                for the corresponding image in the batch.
+                - classes: list of length N_i containing the class indices of the N_i ground truth boxes in the image
+                - boxes: (N_i, 4, 2) array containing the ground truth boxes
+                  in the format [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
+
+        Returns:
+            loss: the computed loss value
+        """
+        device = logits.device
+        B, Q, C = logits.shape
+
+        targets = self.build_target(target)
+        bg_idx = 0
+
+        total_cls, total_box, total_rot = (
+            torch.tensor(0.0, device=device),
+            torch.tensor(0.0, device=device),
+            torch.tensor(0.0, device=device),
+        )
+
+        for b in range(B):
+            tgt_boxes = torch.as_tensor(targets[b]["boxes"], device=device)
+            tgt_cls = torch.as_tensor(targets[b]["labels"], device=device)
+
+            num_gt = len(tgt_cls)
+
+            pred_logits = logits[b]
+            pred_boxes_b = pred_boxes[b]
+
+            pred_xy = pred_boxes_b[:, :4]
+            pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1)
+
+            if num_gt == 0:
+                target_classes = torch.full((Q,), bg_idx, device=device)
+                total_cls += F.cross_entropy(pred_logits, target_classes)
+                continue
+
+            tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1)
+            tgt_boxes_xy = tgt_boxes[:, :4]
+
+            with torch.no_grad():
+                prob = pred_logits.softmax(-1)
+
+                cost_cls = -prob[:, tgt_cls]  # (Q, G)
+                cost_box = torch.cdist(pred_xy, tgt_boxes_xy, p=1)
+                cost_rot = 1 - pred_rot @ tgt_rot.T
+
+                cost = cost_cls + cost_box + 0.5 * cost_rot
+
+                # IoU for dynamic k
+                def box_iou(a, b):
+                    a = torch.cat([a[:, :2] - a[:, 2:] / 2, a[:, :2] + a[:, 2:] / 2], -1)
+                    b = torch.cat([b[:, :2] - b[:, 2:] / 2, b[:, :2] + b[:, 2:] / 2], -1)
+
+                    lt = torch.max(a[:, None, :2], b[:, :2])
+                    rb = torch.min(a[:, None, 2:], b[:, 2:])
+                    wh = (rb - lt).clamp(min=0)
+                    inter = wh[..., 0] * wh[..., 1]
+
+                    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
+                    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
+
+                    union = area_a[:, None] + area_b - inter + 1e-6
+                    return inter / union
+
+                iou = box_iou(pred_xy, tgt_boxes_xy)
+
+                topk_iou, _ = torch.topk(iou, k=min(10, Q), dim=0)
+                dynamic_k = torch.clamp(topk_iou.sum(0).int(), min=1)
+
+                pos_mask = torch.zeros(Q, dtype=torch.bool, device=device)
+                gt_indices = []
+
+                # SimOTA assignment
+                for gt_idx in range(num_gt):
+                    k = dynamic_k[gt_idx].item()
+
+                    _, query_idx = torch.topk(-cost[:, gt_idx], k=k)
+                    pos_mask[query_idx] = True
+
+                    gt_indices.append(torch.full((k,), gt_idx, device=device, dtype=torch.long))
+
+            if pos_mask.any():
+                pos_idx = pos_mask.nonzero(as_tuple=False).squeeze(1)
+
+                # map each positive query to a GT (best-effort stable assignment)
+                gt_indices = torch.cat(gt_indices)[: len(pos_idx)]
+
+                target_classes = torch.full((Q,), bg_idx, device=device)
+                target_classes[pos_idx] = tgt_cls[gt_indices]
+
+                total_cls += F.cross_entropy(pred_logits, target_classes)
+
+                total_box += F.smooth_l1_loss(
+                    pred_xy[pos_idx],
+                    tgt_boxes_xy[gt_indices],
+                )
+
+                total_rot += (1 - (pred_rot[pos_idx] * tgt_rot[gt_indices]).sum(-1)).mean()
+            else:
+                target_classes = torch.full((Q,), bg_idx, device=device)
+                total_cls += F.cross_entropy(pred_logits, target_classes)
+
+        return torch.as_tensor((total_cls + 5.0 * total_box + total_rot) / B, device=device)
+
+
+def _lw_detr(
+    arch: str,
+    pretrained: bool,
+    backbone_fn: Callable[[bool], nn.Module],
+    ignore_keys: list[str] | None = None,
+    **kwargs: Any,
+) -> LWDETR:
+
+    # Build the feature extractor
+    backbone = backbone_fn(  # type: ignore[call-arg]
+        False,
+        include_top=False,
+        input_shape=default_cfgs[arch]["input_shape"],
+        patch_size=kwargs.pop("patch_size", (16, 16)),
+    )
+    feat_extractor = LWDETRBackbone(encoder_fn=backbone)
+
+    # Build the model (pop class_names so it is not passed twice through kwargs)
+    class_names = kwargs.pop("class_names", default_cfgs[arch].get("class_names", []))
+    model = LWDETR(
+        feat_extractor,
+        cfg=default_cfgs[arch],
+        class_names=class_names,
+        **kwargs,
+    )
+    # Load pretrained parameters
+    if pretrained:
+        # The number of class_names is not the same as the number of classes in the pretrained model =>
+        # remove the layer weights
+        _ignore_keys = ignore_keys if class_names != default_cfgs[arch].get("class_names") else None
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+    return model
+
+
+def lw_detr_s(pretrained: bool = False, **kwargs: Any) -> LWDETR:
+    """LW-DETR as described in `"LW-DETR: A Transformer Replacement to YOLO for Real-Time Detection"
+    `_.
+
+    >>> import torch
+    >>> from doctr.models import lw_detr_s
+    >>> model = lw_detr_s(pretrained=True).eval()
+    >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on our layout detection dataset
+        **kwargs: keyword arguments of the LWDETR architecture
+
+    Returns:
+        layout detection architecture
+    """
+    return _lw_detr(
+        "lw_detr_s",
+        pretrained,
+        vit_det_s,
+        ignore_keys=[
+            "class_embed.weight",
+            "class_embed.bias",
+            *[f"enc_out_class_embed.{i}.weight" for i in range(kwargs.get("group_detr", 13))],
+            *[f"enc_out_class_embed.{i}.bias" for i in range(kwargs.get("group_detr", 13))],
+        ],
+        **kwargs,
+    )
+
+
+def lw_detr_m(pretrained: bool = False, **kwargs: Any) -> LWDETR:
+    """LW-DETR as described in `"LW-DETR: A Transformer Replacement to YOLO for Real-Time Detection"
+    `_.
+
+    >>> import torch
+    >>> from doctr.models import lw_detr_m
+    >>> model = lw_detr_m(pretrained=True).eval()
+    >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on our layout detection dataset
+        **kwargs: keyword arguments of the LWDETR architecture
+
+    Returns:
+        layout detection architecture
+    """
+    return _lw_detr(
+        "lw_detr_m",
+        pretrained,
+        vit_det_m,
+        ignore_keys=[
+            "class_embed.weight",
+            "class_embed.bias",
+            *[f"enc_out_class_embed.{i}.weight" for i in range(kwargs.get("group_detr", 13))],
+            *[f"enc_out_class_embed.{i}.bias" for i in range(kwargs.get("group_detr", 13))],
+        ],
+        **kwargs,
+    )
diff --git a/doctr/models/layout/predictor/__init__.py b/doctr/models/layout/predictor/__init__.py
new file mode 100644
index 0000000000..e3c861310c
--- /dev/null
+++ b/doctr/models/layout/predictor/__init__.py
@@ -0,0 +1 @@
+from .pytorch import *
diff --git a/doctr/models/layout/predictor/pytorch.py b/doctr/models/layout/predictor/pytorch.py
new file mode 100644
index 0000000000..3cf61f7406
--- /dev/null
+++ b/doctr/models/layout/predictor/pytorch.py
@@ -0,0 +1,80 @@
+# Copyright (C) 2021-2026, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to for full license details. + +from typing import Any + +import numpy as np +import torch +from torch import nn + +from doctr.models.detection._utils import _remove_padding +from doctr.models.preprocessor import PreProcessor +from doctr.models.utils import set_device_and_dtype + +__all__ = ["LayoutPredictor"] + + +class LayoutPredictor(nn.Module): + """Implements an object able to localize layout elements in a document + + Args: + pre_processor: transform inputs for easier batched model inference + model: core layout architecture + """ + + def __init__( + self, + pre_processor: PreProcessor, + model: nn.Module, + ) -> None: + super().__init__() + self.pre_processor = pre_processor + self.model = model.eval() + + @torch.inference_mode() + def forward( + self, + pages: list[np.ndarray], + **kwargs: Any, + ) -> list[dict[str, np.ndarray]]: + # Extract parameters from the preprocessor + preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio + symmetric_pad = self.pre_processor.resize.symmetric_pad + assume_straight_pages = self.model.assume_straight_pages + # This flag is needed to return the padding mask from the preprocessor + self.pre_processor.resize.return_padding_mask = True + + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + processed_batches = self.pre_processor(pages) + _params = next(self.model.parameters()) + self.model, processed_batches = set_device_and_dtype( + self.model, processed_batches, _params.device, _params.dtype + ) + predicted_batches = [ + self.model(batch[0], batch[1], return_preds=True, return_model_output=True, **kwargs) + for batch in processed_batches + ] + # remap idx to class names + class_names = [ + [self.model.class_names[int(i)] for i in pred[0]] for batch in predicted_batches for pred in batch["preds"] + ] + boxes = [pred[1] for batch in predicted_batches for pred in batch["preds"]] + scores = [pred[2] for batch in predicted_batches for pred in batch["preds"]] + + # Remove padding from loc predictions + preds = _remove_padding( + pages, + [{"pred": box} for box in boxes], + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + assume_straight_pages=assume_straight_pages, # type: ignore[arg-type] + ) + return [ + {"class_names": class_name, "boxes": pred["pred"], "scores": score} + for class_name, pred, score in zip(class_names, preds, scores) + ] diff --git a/doctr/models/layout/zoo.py b/doctr/models/layout/zoo.py new file mode 100644 index 0000000000..52d7c54666 --- /dev/null +++ b/doctr/models/layout/zoo.py @@ -0,0 +1,88 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any + +from doctr.models.utils import _CompiledModule + +from .. 
import layout +from ..preprocessor import PreProcessor +from .predictor import LayoutPredictor + +__all__ = ["layout_predictor"] + +ARCHS: list[str] + +ARCHS = ["lw_detr_s", "lw_detr_m"] + + +def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> LayoutPredictor: + if isinstance(arch, str): + if arch not in ARCHS: + raise ValueError(f"unknown architecture '{arch}'") + + _model = layout.__dict__[arch]( + pretrained=pretrained, + assume_straight_pages=assume_straight_pages, + ) + else: + # Adding the type for torch compiled models to the allowed architectures + allowed_archs = [layout.LWDETR, _CompiledModule] + + if not isinstance(arch, tuple(allowed_archs)): + raise ValueError(f"unknown architecture: {type(arch)}") + _model = arch + + kwargs.pop("pretrained_backbone", None) + + kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) + kwargs["std"] = kwargs.get("std", _model.cfg["std"]) + kwargs["batch_size"] = kwargs.get("batch_size", 2) + predictor = LayoutPredictor( + PreProcessor(_model.cfg["input_shape"][1:], **kwargs), + _model, + ) + return predictor + + +def layout_predictor( + arch: Any = "lw_detr_s", + pretrained: bool = False, + assume_straight_pages: bool = True, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + batch_size: int = 2, + **kwargs: Any, +) -> LayoutPredictor: + """Layout prediction architecture. + + >>> import numpy as np + >>> from doctr.models import layout_predictor + >>> model = layout_predictor(arch='lw_detr_s', pretrained=True) + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + arch: name of the architecture or model itself to use (e.g. 'lw_detr_s') + pretrained: If True, returns a model pre-trained on our layout prediction dataset + assume_straight_pages: If True, fit straight boxes to the page + preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before + running the detection model on it + symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right + batch_size: number of samples the model processes in parallel + **kwargs: optional keyword arguments passed to the architecture + + Returns: + Layout predictor + """ + return _predictor( + arch=arch, + pretrained=pretrained, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + batch_size=batch_size, + **kwargs, + ) diff --git a/tests/common/test_models_layout.py b/tests/common/test_models_layout.py new file mode 100644 index 0000000000..6549a147fc --- /dev/null +++ b/tests/common/test_models_layout.py @@ -0,0 +1,66 @@ +import numpy as np +import pytest + +from doctr.models.layout.lw_detr.base import LWDETRPostProcessor + + +def test_lwdetr_postprocessor(): + postprocessor = LWDETRPostProcessor( + num_classes=5, + score_thresh=0.2, + iou_thresh=0.5, + topk=50, + assume_straight_pages=True, + ) + + r_postprocessor = LWDETRPostProcessor( + num_classes=5, + score_thresh=0.2, + iou_thresh=0.5, + topk=50, + assume_straight_pages=False, + ) + + # Input validation + with pytest.raises(Exception): + postprocessor(np.random.rand(2, 20, 5).astype(np.float32), np.random.rand(2, 20, 5).astype(np.float32)) + + # Forward pass + batch_size, num_queries = 2, 20 + logits = np.random.randn(batch_size, num_queries, 6).astype(np.float32) + boxes = np.random.rand(batch_size, num_queries, 6).astype(np.float32) + + out = postprocessor(logits, boxes) + r_out = 
r_postprocessor(logits, boxes) + + # Batch composition + assert isinstance(out, list) + assert len(out) == batch_size + + assert isinstance(r_out, list) + assert len(r_out) == batch_size + + assert all(isinstance(sample, tuple) and len(sample) == 3 for sample in out) + + assert all(isinstance(sample, tuple) and len(sample) == 3 for sample in r_out) + + labels, bboxes, scores = out[0] + r_labels, r_bboxes, r_scores = r_out[0] + + assert isinstance(labels, list) + assert isinstance(scores, list) + assert isinstance(bboxes, np.ndarray) + + # straight pages → (K, 4) + assert bboxes.ndim == 2 + assert bboxes.shape[1] == 4 + + # rotated pages → (K, 4, 2) + assert isinstance(r_bboxes, np.ndarray) + assert r_bboxes.ndim == 3 + assert r_bboxes.shape[2] == 2 + + # Score / label validity + assert all(isinstance(s, float) for s in scores) + assert all(s >= 0 for s in scores) + assert len(labels) == len(scores) diff --git a/tests/pytorch/test_models_layout.py b/tests/pytorch/test_models_layout.py new file mode 100644 index 0000000000..8cbb0459ef --- /dev/null +++ b/tests/pytorch/test_models_layout.py @@ -0,0 +1,182 @@ +import os +import tempfile + +import numpy as np +import onnxruntime +import pytest +import torch + +from doctr.io import DocumentFile +from doctr.models import layout +from doctr.models.layout.predictor import LayoutPredictor +from doctr.models.utils import _CompiledModule, export_model_to_onnx + + +@pytest.mark.parametrize("train_mode", [True, False]) +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["lw_detr_s", (3, 512, 512)], + ["lw_detr_m", (3, 1024, 1024)], + ], +) +def test_layout_models(arch_name, input_shape, train_mode): + batch_size = 2 + model = layout.__dict__[arch_name](pretrained=True) + model = model.train() if train_mode else model.eval() + assert isinstance(model, torch.nn.Module) + input_tensor = torch.rand((batch_size, *input_shape)) + input_masks = torch.ones((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool) + target = [] + for _ in range(batch_size): + num_boxes = 5 + + class_ids = torch.randint(0, 10, (num_boxes,)).tolist() + + # random boxes in normalized-ish space + boxes = [] + for _ in range(num_boxes): + cx, cy = torch.rand(2) * 0.8 + 0.1 + w, h = torch.rand(2) * 0.2 + 0.05 + + x1, y1 = cx - w / 2, cy - h / 2 + x2, y2 = cx + w / 2, cy - h / 2 + x3, y3 = cx + w / 2, cy + h / 2 + x4, y4 = cx - w / 2, cy + h / 2 + + boxes.append([ + [x1, y1], + [x2, y2], + [x3, y3], + [x4, y4], + ]) + + target.append((class_ids, torch.tensor(boxes, dtype=torch.float32))) + if torch.cuda.is_available(): + model.cuda() + input_tensor = input_tensor.cuda() + input_masks = input_masks.cuda() + out = model(input_tensor, input_masks, target, return_model_output=True, return_preds=not train_mode) + assert isinstance(out, dict) + assert len(out) == 3 if not train_mode else len(out) == 2 + # Check logits + assert "logits" in out + assert isinstance(out["logits"], torch.Tensor) + + # Check Preds + if not train_mode: + for results in out["preds"]: + assert isinstance(results, tuple) and len(results) == 3 + assert isinstance(results[0], list) and all(isinstance(idxs, int) for idxs in results[0]) + assert isinstance(results[1], np.ndarray) and results[1].shape == (len(results[0]), 4) + assert isinstance(results[2], list) and all(isinstance(scores, float) for scores in results[2]) + # Check class idxs are in the model's num_classes + assert all(0 <= idx < model.num_classes for idx in results[0]) + # Check scores are between 0 and 1 + assert all(0 <= score <= 1 for 
score in results[2]) + # Check that the number of boxes, labels and scores are the same + assert len(results[0]) == len(results[1]) == len(results[2]) + # Check that boxes are in the range [0, 1] + assert all((box >= 0).all() and (box <= 1).all() for box in results[1]) + # Check loss + assert isinstance(out["loss"], torch.Tensor) + assert hasattr(model, "from_pretrained") + + +@pytest.mark.parametrize( + "arch_name", + [ + "lw_detr_s", + "lw_detr_m", + ], +) +def test_layout_zoo(arch_name): + # Model + predictor = layout.zoo.layout_predictor(arch_name, pretrained=False) + predictor.model.eval() + # object check + assert isinstance(predictor, LayoutPredictor) + input_tensor = np.random.rand(2, 1024, 1024, 3).astype(np.float32) + if torch.cuda.is_available(): + predictor.model.cuda() + + with torch.no_grad(): + out = predictor(input_tensor) + assert isinstance(out, list) and len(out) == 2 + assert all(isinstance(sample, dict) for sample in out) + assert all("class_names" in sample and "boxes" in sample and "scores" in sample for sample in out) + assert all(isinstance(sample["class_names"], list) for sample in out) + assert all(isinstance(sample["boxes"], np.ndarray) for sample in out) + assert all(isinstance(sample["scores"], list) for sample in out) + assert all(sample["boxes"].shape[1] == 4 for sample in out) + assert all(len(sample["class_names"]) == len(sample["scores"]) == sample["boxes"].shape[0] for sample in out) + assert all(all(isinstance(score, float) and 0 <= score <= 1 for score in sample["scores"]) for sample in out) + + +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["lw_detr_s", (3, 1024, 1024)], + ["lw_detr_m", (3, 1024, 1024)], + ], +) +def test_models_onnx_export(arch_name, input_shape): + # Model + batch_size = 2 + model = layout.__dict__[arch_name](pretrained=True, exportable=True).eval() + dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32) + dummy_masks = torch.ones((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool) + pt = model(dummy_input, dummy_masks) + pt_logits = pt["logits"].detach().cpu().numpy() + pt_boxes = pt["pred_boxes"].detach().cpu().numpy() + with tempfile.TemporaryDirectory() as tmpdir: + # Export + model_path = export_model_to_onnx( + model, model_name=os.path.join(tmpdir, "model"), dummy_input=(dummy_input, dummy_masks) + ) + assert os.path.exists(model_path) + # Inference + ort_session = onnxruntime.InferenceSession( + os.path.join(tmpdir, "model.onnx"), providers=["CPUExecutionProvider"] + ) + ort_outs = ort_session.run( + ["logits", "pred_boxes"], {"input": dummy_input.numpy(), "masks": dummy_masks.numpy()} + ) + + assert isinstance(ort_outs, list) and len(ort_outs) == 2 + # Check boxes shape + assert ort_outs[0].shape == pt_logits.shape + assert ort_outs[1].shape == pt_boxes.shape + # Check that the output is close to the PyTorch output - only warn if not close + try: + assert np.allclose(pt_logits, ort_outs[0], atol=1e-4) + assert np.allclose(pt_boxes, ort_outs[1], atol=1e-4) + except AssertionError: + pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}") + + +@pytest.mark.parametrize( + "arch_name", + [ + "lw_detr_s", + "lw_detr_m", + ], +) +def test_torch_compiled_models(arch_name, mock_payslip): + doc = DocumentFile.from_images([mock_payslip]) + predictor = layout.zoo.layout_predictor(arch_name, pretrained=True) + assert isinstance(predictor, LayoutPredictor) + out = predictor(doc) + + # Compile the model + compiled_model = 
torch.compile(layout.__dict__[arch_name](pretrained=True).eval()) + assert isinstance(compiled_model, _CompiledModule) + compiled_predictor = layout.zoo.layout_predictor(compiled_model) + compiled_out = compiled_predictor(doc) + + # Compare that outputs are close + assert len(out) == len(compiled_out) == 1 + # TODO: Enable if the model has a pretrained version + # assert out[0]["class_names"] == compiled_out[0]["class_names"] + # assert np.allclose(out[0]["boxes"], compiled_out[0]["boxes"], atol=1e-4) + # assert np.allclose(out[0]["scores"], compiled_out[0]["scores"], atol=1e-4) From 5d182d4723cfe9cef51c2b42327e9701110be481 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 7 May 2026 10:02:45 +0200 Subject: [PATCH 03/15] Add hf hub layout model logic --- doctr/models/factory/hub.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py index 7560233b2d..fc93f574e4 100644 --- a/doctr/models/factory/hub.py +++ b/doctr/models/factory/hub.py @@ -30,6 +30,7 @@ "classification": models.classification.zoo.ARCHS + models.classification.zoo.ORIENTATION_ARCHS, "detection": models.detection.zoo.ARCHS, "recognition": models.recognition.zoo.ARCHS, + "layout": models.layout.zoo.ARCHS, } @@ -97,7 +98,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: # if run_config is None and arch is None: raise ValueError("run_config or arch must be specified") if task not in ["classification", "detection", "recognition"]: - raise ValueError("task must be one of classification, detection, recognition") + raise ValueError("task must be one of classification, detection, recognition, layout") # default readme readme = textwrap.dedent( @@ -208,6 +209,8 @@ def from_hub(repo_id: str, **kwargs: Any): model = models.detection.__dict__[arch](pretrained=False) elif task == "recognition": model = models.recognition.__dict__[arch](pretrained=False, input_shape=cfg["input_shape"], vocab=cfg["vocab"]) + elif task == "layout": + model = models.layout.__dict__[arch](pretrained=False, class_names=cfg["class_names"]) # update model cfg model.cfg = cfg From 87111aabe12ea16a50485032c0111f79117ee094 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 7 May 2026 10:19:13 +0200 Subject: [PATCH 04/15] Add hf hub layout model logic --- doctr/models/factory/hub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py index fc93f574e4..92b26b5c02 100644 --- a/doctr/models/factory/hub.py +++ b/doctr/models/factory/hub.py @@ -97,7 +97,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: # if run_config is None and arch is None: raise ValueError("run_config or arch must be specified") - if task not in ["classification", "detection", "recognition"]: + if task not in ["classification", "detection", "recognition", "layout"]: raise ValueError("task must be one of classification, detection, recognition, layout") # default readme From abedd954fa18222e7bed3a49da5f71d46594d165 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 7 May 2026 11:17:24 +0200 Subject: [PATCH 05/15] Update lw-detr onnx test shape --- tests/pytorch/test_models_layout.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_models_layout.py b/tests/pytorch/test_models_layout.py index 8cbb0459ef..aa7e687114 100644 --- a/tests/pytorch/test_models_layout.py +++ b/tests/pytorch/test_models_layout.py @@ -116,8 +116,8 @@ def test_layout_zoo(arch_name): 
@pytest.mark.parametrize( "arch_name, input_shape", [ - ["lw_detr_s", (3, 1024, 1024)], - ["lw_detr_m", (3, 1024, 1024)], + ["lw_detr_s", (3, 512, 512)], + ["lw_detr_m", (3, 512, 512)], ], ) def test_models_onnx_export(arch_name, input_shape): From bb683acd07b4f31b9c6e6ea48c409db7d28bdc78 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 11 May 2026 10:55:45 +0200 Subject: [PATCH 06/15] rebase From 3b653cead05d167448bc1b58ee7bfcb80cbe2087 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 11 May 2026 14:24:02 +0200 Subject: [PATCH 07/15] Rebase From 94cd9a5b24446b3ad636d3a3bafdfd7bee461cfa Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 11 May 2026 15:27:50 +0200 Subject: [PATCH 08/15] Fix mypy --- doctr/models/layout/lw_detr/layers/pytorch.py | 27 +++++----- doctr/models/layout/lw_detr/pytorch.py | 50 ++++++++----------- doctr/models/layout/predictor/pytorch.py | 8 +-- tests/common/test_models_layout.py | 4 +- 4 files changed, 43 insertions(+), 46 deletions(-) diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index ed5ec086f6..b8f7adc4a9 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -489,15 +489,16 @@ def get_reference(self, reference_points, valid_ratios): def forward( self, - inputs_embeds: torch.Tensor | None = None, - reference_points: torch.Tensor | None = None, - spatial_shapes_list: torch.Tensor | None = None, - valid_ratios: torch.Tensor | None = None, - encoder_hidden_states: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None, + reference_points: torch.Tensor, + spatial_shapes_list: torch.Tensor, + valid_ratios: torch.Tensor, + encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor | None = None, ): - intermediate = () - intermediate_reference_points = (reference_points,) + intermediate: list[torch.Tensor] = [] + + intermediate_reference_points: list[torch.Tensor] = [reference_points] if inputs_embeds is not None: hidden_states = inputs_embeds @@ -513,14 +514,16 @@ def forward( reference_points=reference_points_inputs, spatial_shapes_list=spatial_shapes_list, ) + intermediate_hidden_states = self.layernorm(hidden_states) - intermediate += (intermediate_hidden_states,) + intermediate.append(intermediate_hidden_states) + + intermediate_stack = torch.stack(intermediate) + last_hidden_state = intermediate_stack[-1] - intermediate = torch.stack(intermediate) - last_hidden_state = intermediate[-1] - intermediate_reference_points = torch.stack(intermediate_reference_points) + intermediate_reference_points_stack = torch.stack(intermediate_reference_points) - return last_hidden_state, intermediate, intermediate_reference_points + return last_hidden_state, intermediate_stack, intermediate_reference_points_stack class MultiScaleProjector(nn.Module): diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index a17592313b..512a8509a8 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -177,7 +177,7 @@ def __init__( ) -> None: super().__init__() - self.class_names = ["__background__"] + class_names + self.class_names: list[str] = ["__background__"] + class_names self.num_classes = len(self.class_names) self.cfg = cfg self.exportable = exportable @@ -426,11 +426,12 @@ def forward( ) -> dict[str, Any]: feats = self.feat_extractor(input, masks) - sources = [] - masks = [] + sources: list[torch.Tensor] = [] + feats_masks: list[torch.Tensor] = [] + for source, 
mask in feats: sources.append(source) - masks.append(mask) + feats_masks.append(mask) if mask is None: raise ValueError("No attention mask was provided") @@ -443,20 +444,20 @@ def forward( query_feat = self.query_feat.weight[: self.num_queries] # Prepare encoder inputs (by flattening) - source_flatten: list[torch.Tensor] | torch.Tensor = [] - mask_flatten: list[torch.Tensor] | torch.Tensor = [] + source_flatten_list: list[torch.Tensor] = [] + mask_flatten_list: list[torch.Tensor] = [] spatial_shapes_list: list[tuple[int, int]] = [] - for source, mask in zip(sources, masks): + for source, mask in zip(sources, feats_masks): batch_size, num_channels, height, width = source.shape spatial_shape = (height, width) spatial_shapes_list.append(spatial_shape) source = source.flatten(2).transpose(1, 2) mask = mask.flatten(1) - source_flatten.append(source) - mask_flatten.append(mask) - source_flatten = torch.cat(source_flatten, 1) - mask_flatten = torch.cat(mask_flatten, 1) - valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1) + source_flatten_list.append(source) + mask_flatten_list.append(mask) + source_flatten = torch.cat(source_flatten_list, 1) + mask_flatten = torch.cat(mask_flatten_list, 1) + valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in feats_masks], 1) tgt = query_feat.unsqueeze(0).expand(batch_size, -1, -1) reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1) @@ -467,9 +468,8 @@ def forward( group_detr = self.group_detr if self.training else 1 topk = self.num_queries - topk_coords_logits = [] - topk_coords_logits_undetach = [] - object_query_undetach = [] + + topk_coords_logits_list: list[torch.Tensor] = [] # For each group, predict class logits and bbox deltas from the object query embeddings, # and select the top-k proposals based on the predicted class logits. 
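
As a side note on the hunk above: the flattening is the standard deformable-attention
input preparation, turning each (B, C, H, W) feature level into a (B, H*W, C) sequence
and concatenating the levels. A minimal standalone sketch with illustrative shapes:

    import torch

    batch_size, channels = 2, 256
    shapes = [(64, 64), (32, 32)]
    sources = [torch.rand(batch_size, channels, h, w) for h, w in shapes]
    feats_masks = [torch.zeros(batch_size, h, w, dtype=torch.bool) for h, w in shapes]

    source_flatten_list, mask_flatten_list, spatial_shapes_list = [], [], []
    for source, mask in zip(sources, feats_masks):
        spatial_shapes_list.append(source.shape[-2:])  # (H, W) per level
        source_flatten_list.append(source.flatten(2).transpose(1, 2))  # (B, H*W, C)
        mask_flatten_list.append(mask.flatten(1))  # (B, H*W)

    source_flatten = torch.cat(source_flatten_list, 1)  # (B, 64*64 + 32*32, C)
    mask_flatten = torch.cat(mask_flatten_list, 1)
    print(source_flatten.shape, mask_flatten.shape)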
@@ -489,17 +489,9 @@ def forward( group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6), ) group_topk_coords_logits = group_topk_coords_logits_undetach.detach() - group_object_query_undetach = torch.gather( - group_object_query, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model) - ) - - topk_coords_logits.append(group_topk_coords_logits) - topk_coords_logits_undetach.append(group_topk_coords_logits_undetach) - object_query_undetach.append(group_object_query_undetach) + topk_coords_logits_list.append(group_topk_coords_logits) - topk_coords_logits = torch.cat(topk_coords_logits, 1) - topk_coords_logits_undetach = torch.cat(topk_coords_logits_undetach, 1) - object_query_undetach = torch.cat(object_query_undetach, 1) + topk_coords_logits = torch.cat(topk_coords_logits_list, 1) reference_points = self.refine_bboxes(topk_coords_logits, reference_points) @@ -629,22 +621,22 @@ def box_iou(a, b): dynamic_k = torch.clamp(topk_iou.sum(0).int(), min=1) pos_mask = torch.zeros(Q, dtype=torch.bool, device=device) - gt_indices = [] + gt_indices_list: list[torch.Tensor] = [] # SimOTA assignment for gt_idx in range(num_gt): - k = dynamic_k[gt_idx].item() + k = int(dynamic_k[gt_idx].item()) _, query_idx = torch.topk(-cost[:, gt_idx], k=k) pos_mask[query_idx] = True - gt_indices.append(torch.full((k,), gt_idx, device=device, dtype=torch.long)) + gt_indices_list.append(torch.full((k,), gt_idx, device=device, dtype=torch.long)) if pos_mask.any(): pos_idx = pos_mask.nonzero(as_tuple=False).squeeze(1) # map each positive query to a GT (best-effort stable assignment) - gt_indices = torch.cat(gt_indices)[: len(pos_idx)] + gt_indices = torch.cat(gt_indices_list)[: len(pos_idx)] target_classes = torch.full((Q,), bg_idx, device=device) target_classes[pos_idx] = tgt_cls[gt_indices] diff --git a/doctr/models/layout/predictor/pytorch.py b/doctr/models/layout/predictor/pytorch.py index 3cf61f7406..edd33ded6b 100644 --- a/doctr/models/layout/predictor/pytorch.py +++ b/doctr/models/layout/predictor/pytorch.py @@ -52,7 +52,7 @@ def forward( processed_batches = self.pre_processor(pages) _params = next(self.model.parameters()) - self.model, processed_batches = set_device_and_dtype( + self.model, processed_batches = set_device_and_dtype( # type: ignore[assignment] self.model, processed_batches, _params.device, _params.dtype ) predicted_batches = [ @@ -61,7 +61,9 @@ def forward( ] # remap idx to class names class_names = [ - [self.model.class_names[int(i)] for i in pred[0]] for batch in predicted_batches for pred in batch["preds"] + [self.model.class_names[int(i)] for i in pred[0]] # type: ignore[index] + for batch in predicted_batches + for pred in batch["preds"] ] boxes = [pred[1] for batch in predicted_batches for pred in batch["preds"]] scores = [pred[2] for batch in predicted_batches for pred in batch["preds"]] @@ -75,6 +77,6 @@ def forward( assume_straight_pages=assume_straight_pages, # type: ignore[arg-type] ) return [ - {"class_names": class_name, "boxes": pred["pred"], "scores": score} + {"class_names": class_name, "boxes": pred["pred"], "scores": score} # type: ignore[dict-item] for class_name, pred, score in zip(class_names, preds, scores) ] diff --git a/tests/common/test_models_layout.py b/tests/common/test_models_layout.py index 6549a147fc..61ac2fc116 100644 --- a/tests/common/test_models_layout.py +++ b/tests/common/test_models_layout.py @@ -51,11 +51,11 @@ def test_lwdetr_postprocessor(): assert isinstance(scores, list) assert isinstance(bboxes, np.ndarray) - # straight pages → (K, 4) + # straight 
pages: (K, 4) assert bboxes.ndim == 2 assert bboxes.shape[1] == 4 - # rotated pages → (K, 4, 2) + # rotated pages: (K, 4, 2) assert isinstance(r_bboxes, np.ndarray) assert r_bboxes.ndim == 3 assert r_bboxes.shape[2] == 2 From 125b1c1e5f0f7e50fab526be1e2be9e757c740d9 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 11 May 2026 15:34:28 +0200 Subject: [PATCH 09/15] Update tests --- docs/source/modules/models.rst | 8 ++++++++ tests/pytorch/test_models_factory.py | 2 ++ 2 files changed, 10 insertions(+) diff --git a/docs/source/modules/models.rst b/docs/source/modules/models.rst index de5a34f604..55ce88a365 100644 --- a/docs/source/modules/models.rst +++ b/docs/source/modules/models.rst @@ -76,6 +76,14 @@ doctr.models.detection .. autofunction:: doctr.models.detection.detection_predictor +doctr.models.layout +------------------- + +.. autofunction:: doctr.models.layout.lw_detr_s + +.. autofunction:: doctr.models.layout.lw_detr_m + + doctr.models.recognition ------------------------ diff --git a/tests/pytorch/test_models_factory.py b/tests/pytorch/test_models_factory.py index 1497ba0f12..88cbd67dc9 100644 --- a/tests/pytorch/test_models_factory.py +++ b/tests/pytorch/test_models_factory.py @@ -49,6 +49,8 @@ def test_push_to_hf_hub(): ["vitstr_small", "recognition", "Felix92/doctr-dummy-torch-vitstr-small"], ["parseq", "recognition", "Felix92/doctr-dummy-torch-parseq"], ["viptr_tiny", "recognition", "Felix92/doctr-dummy-torch-viptr-tiny"], + ["lw_detr_s", "layout", "Felix92/doctr-dummy-torch-lw-detr-s"], + ["lw_detr_m", "layout", "Felix92/doctr-dummy-torch-lw-detr-m"], ], ) def test_models_huggingface_hub(arch_name, task_name, dummy_model_id, tmpdir): From 4d5da4596265b2dd51d1bea5018b6c36103c7052 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 11 May 2026 15:39:12 +0200 Subject: [PATCH 10/15] mypy and clean --- doctr/models/layout/lw_detr/layers/pytorch.py | 1 - doctr/models/layout/lw_detr/pytorch.py | 11 +++-------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index b8f7adc4a9..92f8bbaf46 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -232,7 +232,6 @@ def forward( encoder_hidden_states=None, position_embeddings: torch.Tensor | None = None, reference_points=None, - spatial_shapes=None, spatial_shapes_list=None, ) -> tuple[torch.Tensor, torch.Tensor]: # add position embeddings to the hidden states before projecting to queries and keys diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 512a8509a8..05ae1bf355 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -125,7 +125,7 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> list[tup feats = self.encoder(x) feats = self.projector(feats) # [(B, C, H, W)] - if mask is None: + if mask is None: # pragma: no cover mask = torch.zeros((x.shape[0], x.shape[2], x.shape[3]), dtype=torch.bool, device=x.device) return [ (feat, F.interpolate(mask.unsqueeze(1).float(), size=feat.shape[-2:], mode="nearest").squeeze(1).bool()) @@ -236,7 +236,7 @@ def __init__( nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) - elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + elif isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: nn.init.zeros_(m.bias) @@ -432,7 +432,7 
@@ def forward( for source, mask in feats: sources.append(source) feats_masks.append(mask) - if mask is None: + if mask is None: # pragma: no cover raise ValueError("No attention mask was provided") if self.training: @@ -582,11 +582,6 @@ def compute_loss( pred_xy = pred_boxes_b[:, :4] pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1) - if num_gt == 0: - target_classes = torch.full((Q,), bg_idx, device=device) - total_cls += F.cross_entropy(pred_logits, target_classes) - continue - tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1) tgt_boxes_xy = tgt_boxes[:, :4] From b64f9a345810961742a41a24f65412fdd578af5d Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 11 May 2026 16:05:45 +0200 Subject: [PATCH 11/15] minor reference fix --- doctr/models/layout/lw_detr/layers/pytorch.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index 92f8bbaf46..e2f3fc112a 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -475,12 +475,15 @@ def __init__( def get_reference(self, reference_points, valid_ratios): obj_center = reference_points[..., :4] - reference_points_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] + spatial_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] + # Extract angles + angle = reference_points[..., 4:6] # (sin, cos) + angle_expanded = angle[:, :, None] + reference_points_inputs = torch.cat([spatial_inputs, angle_expanded], dim=-1) # DETR positional encoding - query_sine_embed = gen_sine_position_embeddings(reference_points_inputs[:, :, 0, :], self.d_model) + query_sine_embed = gen_sine_position_embeddings(spatial_inputs[:, :, 0, :], self.d_model) base_query_pos = self.ref_point_head(query_sine_embed) - # angle embedding - angle = reference_points[..., 4:6] # (sin, cos) + # Angle embedding angle_emb = self.angle_proj(angle) # Combine query_pos = base_query_pos + angle_emb From 75815cc5c2399a6bae9dba2fe72513ac3adbcb12 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 12 May 2026 08:31:21 +0200 Subject: [PATCH 12/15] Fix hf upload --- doctr/models/factory/hub.py | 15 +++++++++++---- tests/pytorch/test_models_factory.py | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py index 92b26b5c02..45d57a326d 100644 --- a/doctr/models/factory/hub.py +++ b/doctr/models/factory/hub.py @@ -102,9 +102,14 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: # # default readme readme = textwrap.dedent( - f""" - + f"""--- language: en + tags: + - ocr + - pytorch + - doctr + - {task} + ---

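For reference, after dedenting, the header added above should render as standard
model-card frontmatter; a standalone sketch building the same text (task value
illustrative, indentation approximated from the diff):

    task = "layout"  # illustrative
    readme_header = "\n".join([
        "---",
        "language: en",
        "tags:",
        "- ocr",
        "- pytorch",
        "- doctr",
        f"- {task}",
        "---",
    ])
    print(readme_header)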
@@ -162,7 +167,8 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: # # Create repository api = HfApi() - api.create_repo(model_name, token=get_token(), exist_ok=False) + repo_url = api.create_repo(model_name, token=get_token(), repo_type="model", exist_ok=False) + full_repo_id = repo_url.repo_id # Save model files to a temporary directory with tempfile.TemporaryDirectory() as tmp_dir: @@ -173,7 +179,8 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: # # Upload all files to the hub api.upload_folder( folder_path=tmp_dir, - repo_id=model_name, + repo_id=full_repo_id, + repo_type="model", commit_message=commit_message, token=get_token(), ) diff --git a/tests/pytorch/test_models_factory.py b/tests/pytorch/test_models_factory.py index 88cbd67dc9..db9aed1cdf 100644 --- a/tests/pytorch/test_models_factory.py +++ b/tests/pytorch/test_models_factory.py @@ -34,6 +34,7 @@ def test_push_to_hf_hub(): ["mobilenet_v3_small", "classification", "Felix92/doctr-dummy-torch-mobilenet-v3-small"], ["mobilenet_v3_large", "classification", "Felix92/doctr-dummy-torch-mobilenet-v3-large"], ["vit_s", "classification", "Felix92/doctr-dummy-torch-vit-s"], + ["vit_det_s", "classification", "Felix92/doctr-dummy-torch-vit-det-s"], ["textnet_tiny", "classification", "Felix92/doctr-dummy-torch-textnet-tiny"], ["db_resnet34", "detection", "Felix92/doctr-dummy-torch-db-resnet34"], ["db_resnet50", "detection", "Felix92/doctr-dummy-torch-db-resnet50"], @@ -50,7 +51,6 @@ def test_push_to_hf_hub(): ["parseq", "recognition", "Felix92/doctr-dummy-torch-parseq"], ["viptr_tiny", "recognition", "Felix92/doctr-dummy-torch-viptr-tiny"], ["lw_detr_s", "layout", "Felix92/doctr-dummy-torch-lw-detr-s"], - ["lw_detr_m", "layout", "Felix92/doctr-dummy-torch-lw-detr-m"], ], ) def test_models_huggingface_hub(arch_name, task_name, dummy_model_id, tmpdir): From 1d0c9a6e6e8356985a0a86d100351d6772631f40 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 12 May 2026 11:46:15 +0200 Subject: [PATCH 13/15] Update loss and rot improvements --- doctr/models/layout/lw_detr/layers/pytorch.py | 77 +++++- doctr/models/layout/lw_detr/pytorch.py | 239 ++++++++++++------ 2 files changed, 229 insertions(+), 87 deletions(-) diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index e2f3fc112a..6aeda038e4 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -469,9 +469,14 @@ def __init__( for i in range(num_layers) ]) self.layernorm = nn.LayerNorm(self.d_model) + self.bbox_embed = None self.ref_point_head = LWDETRHead(2 * self.d_model, self.d_model, self.d_model, num_layers=2) - self.angle_proj = nn.Linear(2, self.d_model) + self.angle_proj = nn.Sequential( + nn.Linear(4, self.d_model), + nn.ReLU(), + nn.Linear(self.d_model, self.d_model), + ) def get_reference(self, reference_points, valid_ratios): obj_center = reference_points[..., :4] @@ -484,11 +489,56 @@ def get_reference(self, reference_points, valid_ratios): query_sine_embed = gen_sine_position_embeddings(spatial_inputs[:, :, 0, :], self.d_model) base_query_pos = self.ref_point_head(query_sine_embed) # Angle embedding - angle_emb = self.angle_proj(angle) + sin_t = angle[..., 0:1] + cos_t = angle[..., 1:2] + + angle_feat = torch.cat( + [ + sin_t, + cos_t, + 2 * sin_t * cos_t, + cos_t**2 - sin_t**2, + ], + dim=-1, + ) + + angle_emb = self.angle_proj(angle_feat) # Combine query_pos = base_query_pos + angle_emb return 
reference_points_inputs, query_pos + def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: + """Refine the reference points using the predicted deltas. + + Args: + reference_points: (batch_size, num_queries, 6) + tensor containing the current reference points in the format (cx, cy, w, h, sinθ, cosθ) + deltas: (batch_size, num_queries, 6) + tensor containing the predicted deltas for the reference points in the same format as reference_points + + Returns: + refined_reference_points: (batch_size, num_queries, 6) + tensor containing the refined reference points in the same format as reference_points + """ + cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] + + wh = deltas[..., 2:4].exp() * reference_points[..., 2:4] + + delta_rot = F.normalize(deltas[..., 4:6], dim=-1) + + sin_delta = delta_rot[..., 0:1] + cos_delta = delta_rot[..., 1:2] + + sin_ref = reference_points[..., 4:5] + cos_ref = reference_points[..., 5:6] + + sin_new = sin_ref * cos_delta + cos_ref * sin_delta + cos_new = cos_ref * cos_delta - sin_ref * sin_delta + + rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1) + + return torch.cat((cxcy, wh, rot), dim=-1) + def forward( self, inputs_embeds: torch.Tensor | None, @@ -507,7 +557,7 @@ def forward( reference_points_inputs, query_pos = self.get_reference(reference_points, valid_ratios) - for decoder_layer in self.layers: + for lid, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -517,8 +567,25 @@ def forward( spatial_shapes_list=spatial_shapes_list, ) - intermediate_hidden_states = self.layernorm(hidden_states) - intermediate.append(intermediate_hidden_states) + hidden_states_norm = self.layernorm(hidden_states) + + # iterative refinement + if self.bbox_embed is not None: + delta = self.bbox_embed(hidden_states_norm) + + reference_points = self.refine_boxes( + reference_points.squeeze(2), + delta, + ) + + intermediate_reference_points.append(reference_points) + + reference_points_inputs, query_pos = self.get_reference( + reference_points, + valid_ratios, + ) + + intermediate.append(hidden_states_norm) intermediate_stack = torch.stack(intermediate) last_hidden_state = intermediate_stack[-1] diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 05ae1bf355..7dcb7e0491 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -5,6 +5,7 @@ import math from collections.abc import Callable +from copy import deepcopy from typing import Any import numpy as np @@ -219,6 +220,7 @@ def __init__( ]) self.class_embed = nn.Linear(self.d_model, self.num_classes) self.bbox_embed = LWDETRHead(self.d_model, self.d_model, 6, num_layers=3) + self.decoder.bbox_embed = self.bbox_embed # type: ignore[assignment] self.postprocessor = LWDETRPostProcessor( num_classes=self.num_classes, @@ -317,20 +319,20 @@ def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] # size wh = deltas[..., 2:4].exp() * reference_points[..., 2:4] - # rotation (sin, cos) + # normalize predicted delta rotation + delta_rot = F.normalize(deltas[..., 4:6], dim=-1) + sin_delta = delta_rot[..., 0:1] + cos_delta = delta_rot[..., 1:2] sin_ref = reference_points[..., 4:5] cos_ref = reference_points[..., 5:6] - sin_delta = deltas[..., 4:5] - cos_delta = deltas[..., 5:6] - # compose 
rotations (like adding angles) + + # compose rotations sin_new = sin_ref * cos_delta + cos_ref * sin_delta cos_new = cos_ref * cos_delta - sin_ref * sin_delta - # normalize - norm = torch.sqrt(sin_new**2 + cos_new**2 + 1e-6) - sin_new = sin_new / norm - cos_new = cos_new / norm + # normalize final rotation + rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1) - return torch.cat((cxcy, wh, sin_new, cos_new), dim=-1) + return torch.cat((cxcy, wh, rot), dim=-1) def get_valid_ratio(self, mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: """Get the valid ratio of all feature maps. @@ -343,8 +345,8 @@ def get_valid_ratio(self, mask: torch.Tensor, dtype: torch.dtype = torch.float32 valid_ratio: (N, 2) tensor containing the valid ratio of width and height for each image in the batch """ _, height, width = mask.shape - valid_height = torch.sum(mask[:, :, 0], 1) - valid_width = torch.sum(mask[:, 0, :], 1) + valid_height = torch.sum(~mask[:, :, 0], 1) + valid_width = torch.sum(~mask[:, 0, :], 1) valid_ratio_height = valid_height.to(dtype) / height valid_ratio_width = valid_width.to(dtype) / width valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1) @@ -558,97 +560,168 @@ def compute_loss( Returns: loss: the computed loss value """ - device = logits.device - B, Q, C = logits.shape - targets = self.build_target(target) - bg_idx = 0 + def _sigmoid_focal_loss( + inputs: torch.Tensor, targets: torch.Tensor, alpha: float = 0.25, gamma: float = 2.0 + ) -> torch.Tensor: + """Compute the sigmoid focal loss between `inputs` and `targets`.""" + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.sum(-1).mean() + + def rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format to Gaussian distribution parameters + (mean and covariance). 
+ """ + cxcy = boxes[..., :2] + + w = boxes[..., 2].clamp(min=1e-6) + h = boxes[..., 3].clamp(min=1e-6) + + sin = boxes[..., 4] + cos = boxes[..., 5] + + R = torch.stack( + [ + torch.stack([cos, -sin], dim=-1), + torch.stack([sin, cos], dim=-1), + ], + dim=-2, + ) - total_cls, total_box, total_rot = ( - torch.tensor(0.0, device=device), - torch.tensor(0.0, device=device), - torch.tensor(0.0, device=device), - ) + sx = (w / 2) ** 2 + sy = (h / 2) ** 2 - for b in range(B): - tgt_boxes = torch.as_tensor(targets[b]["boxes"], device=device) - tgt_cls = torch.as_tensor(targets[b]["labels"], device=device) + S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device) - num_gt = len(tgt_cls) + S[..., 0, 0] = sx + S[..., 1, 1] = sy - pred_logits = logits[b] - pred_boxes_b = pred_boxes[b] + covariance = R @ S @ R.transpose(-1, -2) + return cxcy, covariance - pred_xy = pred_boxes_b[:, :4] - pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1) + def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Tensor: + """Compute the ProbIoU loss between predicted boxes and target boxes.""" + mu1, sigma1 = rotated_boxes_to_gaussian(pred_boxes) + mu2, sigma2 = rotated_boxes_to_gaussian(tgt_boxes) - tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1) - tgt_boxes_xy = tgt_boxes[:, :4] + delta = (mu1 - mu2).unsqueeze(-1) - with torch.no_grad(): - prob = pred_logits.softmax(-1) + sigma = (sigma1 + sigma2) * 0.5 - cost_cls = -prob[:, tgt_cls] # (Q, G) - cost_box = torch.cdist(pred_xy, tgt_boxes_xy, p=1) - cost_rot = 1 - pred_rot @ tgt_rot.T + sigma_inv = torch.linalg.inv(sigma) - cost = cost_cls + cost_box + 0.5 * cost_rot + mahalanobis = (delta.transpose(-1, -2) @ sigma_inv @ delta).squeeze(-1).squeeze(-1) - # IoU for dynamic k - def box_iou(a, b): - a = torch.cat([a[:, :2] - a[:, 2:] / 2, a[:, :2] + a[:, 2:] / 2], -1) - b = torch.cat([b[:, :2] - b[:, 2:] / 2, b[:, :2] + b[:, 2:] / 2], -1) + det_sigma = torch.linalg.det(sigma).clamp(min=1e-6) + det_sigma1 = torch.linalg.det(sigma1).clamp(min=1e-6) + det_sigma2 = torch.linalg.det(sigma2).clamp(min=1e-6) - lt = torch.max(a[:, None, :2], b[:, :2]) - rb = torch.min(a[:, None, 2:], b[:, 2:]) - wh = (rb - lt).clamp(min=0) - inter = wh[..., 0] * wh[..., 1] + bhattacharyya = 0.125 * mahalanobis + 0.5 * torch.log(det_sigma / torch.sqrt(det_sigma1 * det_sigma2)) - area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]) - area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1]) + probiou = torch.exp(-bhattacharyya) - union = area_a[:, None] + area_b - inter + 1e-6 - return inter / union + return 1 - probiou - iou = box_iou(pred_xy, tgt_boxes_xy) + device = logits.device + B, Q, C = logits.shape + # Build targets + targets = self.build_target(target) - topk_iou, _ = torch.topk(iou, k=min(10, Q), dim=0) - dynamic_k = torch.clamp(topk_iou.sum(0).int(), min=1) + total_cls = torch.tensor(0.0, device=device) + total_box = torch.tensor(0.0, device=device) + total_rot = torch.tensor(0.0, device=device) - pos_mask = torch.zeros(Q, dtype=torch.bool, device=device) - gt_indices_list: list[torch.Tensor] = [] + for b in range(B): + pred_logits = logits[b] + pred_boxes_b = pred_boxes[b] - # SimOTA assignment - for gt_idx in range(num_gt): - k = int(dynamic_k[gt_idx].item()) + tgt_boxes = torch.as_tensor( + targets[b]["boxes"], + device=device, + dtype=pred_boxes.dtype, + ) + tgt_cls = torch.as_tensor( + targets[b]["labels"], + device=device, + dtype=torch.long, + ) - _, query_idx = torch.topk(-cost[:, gt_idx], k=k) - pos_mask[query_idx] = True + num_gt = len(tgt_cls) + if 
num_gt == 0: + target_onehot = torch.zeros_like(pred_logits) + total_cls += _sigmoid_focal_loss(pred_logits, target_onehot) + continue - gt_indices_list.append(torch.full((k,), gt_idx, device=device, dtype=torch.long)) + pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1) + tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1) - if pos_mask.any(): - pos_idx = pos_mask.nonzero(as_tuple=False).squeeze(1) + with torch.no_grad(): + cls_prob = pred_logits.sigmoid() + cost_cls = -torch.log(cls_prob[:, tgt_cls].clamp(min=1e-6)) + cost_l1 = torch.cdist( + pred_boxes_b[:, :4], + tgt_boxes[:, :4], + p=1, + ) + cost_rot = 1 - (pred_rot @ tgt_rot.T).abs() + total_cost = 2.0 * cost_cls + 5.0 * cost_l1 + 2.0 * cost_rot + matching_matrix = torch.zeros( + (Q, num_gt), + dtype=torch.bool, + device=device, + ) - # map each positive query to a GT (best-effort stable assignment) - gt_indices = torch.cat(gt_indices_list)[: len(pos_idx)] + iou_like = 1 - cost_l1 # proxy + dynamic_k = (iou_like.sum(0).int() + 1).clamp(min=5, max=20) - target_classes = torch.full((Q,), bg_idx, device=device) - target_classes[pos_idx] = tgt_cls[gt_indices] + for gt_idx in range(num_gt): + _, candidate_idx = torch.topk(-total_cost[:, gt_idx], k=int(dynamic_k[gt_idx].item())) + matching_matrix[candidate_idx, gt_idx] = True - total_cls += F.cross_entropy(pred_logits, target_classes) + # resolve duplicate matches + multiple_match_mask = matching_matrix.sum(1) > 1 - total_box += F.smooth_l1_loss( - pred_xy[pos_idx], - tgt_boxes_xy[gt_indices], - ) + if multiple_match_mask.any(): + duplicate_idx = multiple_match_mask.nonzero(as_tuple=False).squeeze(1) + min_cost_idx = total_cost[duplicate_idx].argmin(dim=1) + # Set all matches to False for the duplicate indices, + # then set the match with the lowest cost to True + matching_matrix[duplicate_idx] = False + matching_matrix[duplicate_idx, min_cost_idx] = True + + pos_idx, gt_indices = matching_matrix.nonzero(as_tuple=True) + + target_onehot = torch.zeros_like(pred_logits) + if len(pos_idx) > 0: + target_onehot[pos_idx, tgt_cls[gt_indices]] = 1 - total_rot += (1 - (pred_rot[pos_idx] * tgt_rot[gt_indices]).sum(-1)).mean() - else: - target_classes = torch.full((Q,), bg_idx, device=device) - total_cls += F.cross_entropy(pred_logits, target_classes) + total_cls += _sigmoid_focal_loss(pred_logits, target_onehot) - return torch.as_tensor((total_cls + 5.0 * total_box + total_rot) / B, device=device) + if len(pos_idx) == 0: + continue + + pred_sel = pred_boxes_b[pos_idx] + tgt_sel = tgt_boxes[gt_indices] + # L1 loss on (cx, cy, w, h) + l1_loss = F.smooth_l1_loss(pred_sel[:, :4], tgt_sel[:, :4]) + # ProbIoU loss on the whole box (including rotation) + probiou_loss = _probiou_loss(pred_sel, tgt_sel).mean() + total_box += 2.0 * l1_loss + 0.5 * probiou_loss + # Rotation loss + cos_sim = (pred_rot[pos_idx] * tgt_rot[gt_indices]).sum(-1).abs() + rot_loss = (1 - cos_sim).mean() + total_rot += 0.5 * rot_loss + # Average the loss over the batch + return (total_cls + total_box + total_rot) / B def _lw_detr( @@ -658,6 +731,12 @@ def _lw_detr( ignore_keys: list[str] | None = None, **kwargs: Any, ) -> LWDETR: + # Patch the config + kwargs["class_names"] = kwargs.get("class_names", default_cfgs[arch].get("class_names", [])) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["class_names"] = kwargs["class_names"] + kwargs.pop("class_names") # Build the feature extractor backbone = backbone_fn( # type: ignore[call-arg] @@ -671,19 +750,15 @@ def _lw_detr( # Build the model model = LWDETR( feat_extractor, - 
cfg=default_cfgs[arch], - class_names=kwargs.get("class_names", default_cfgs[arch].get("class_names", "")), + cfg=_cfg, + class_names=_cfg["class_names"], **kwargs, ) # Load pretrained parameters if pretrained: # The number of class_names is not the same as the number of classes in the pretrained model => # remove the layer weights - _ignore_keys = ( - ignore_keys - if kwargs.get("class_names", default_cfgs[arch].get("class_names")) != default_cfgs[arch].get("class_names") - else None - ) + _ignore_keys = ignore_keys if _cfg["class_names"] != default_cfgs[arch].get("class_names") else None model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys) return model From bac8b6d21db628edc63eb99ce067b76bb5c7d469 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 12 May 2026 12:24:54 +0200 Subject: [PATCH 14/15] Update docstrings --- doctr/models/layout/lw_detr/base.py | 16 ++++++++-- doctr/models/layout/lw_detr/layers/pytorch.py | 31 +++++++++++++++---- 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index 3a7ab89ad9..e103021f5f 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -3,6 +3,8 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. +from typing import Any + import cv2 import numpy as np @@ -112,7 +114,6 @@ def _nms(self, polys: np.ndarray, scores: np.ndarray) -> list[int]: return keep def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int], np.ndarray, list[float]]]: - logits = np.asarray(logits) boxes = np.asarray(boxes) @@ -194,8 +195,19 @@ class _LWDETR(BaseModel): def build_target( self, target: list[tuple[list[int], np.ndarray]], - ): + ) -> list[dict[str, Any]]: + """Build the target for LW-DETR training + Args: + target: list of tuples (class_ids, boxes) where class_ids is a list of class ids for the boxes + and boxes is an array of shape (num_boxes, 8) containing the coordinates of the 4 corners of the box + in the format (x1, y1, x2, y2, x3, y3, x4, y4) + + Returns: + list of dictionaries with keys "boxes" and "labels" where "boxes" is an array of shape (num_boxes, 6) + containing the box parameters in OBB format (cx, cy, w, h, sin(theta), cos(theta)) + and "labels" is an array of shape (num_boxes,) containing the class labels + """ targets = [] def _quad_to_obb(poly: np.ndarray): diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index 6aeda038e4..b964ebfd73 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -17,7 +17,7 @@ class LWDETRHead(nn.Module): """ - Simple MLP used to predict the normalized center coordinates, height and width of a bounding box w.r.t. an image. + Simple MLP used as the reference point head in LW-DETR. 
Args: input_dim: number of input features @@ -285,7 +285,7 @@ def forward( sampling_locations = center + rotated_offsets else: - raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + raise ValueError(f"Last dim of reference_points must be 4 or 6, but got {reference_points.shape[-1]}") output = self.attn( value, @@ -354,8 +354,8 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: torch.Tensor | None = None, - reference_points=None, - spatial_shapes_list=None, + reference_points: torch.Tensor | None = None, + spatial_shapes_list: list[tuple] | None = None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: @@ -478,7 +478,26 @@ def __init__( nn.Linear(self.d_model, self.d_model), ) - def get_reference(self, reference_points, valid_ratios): + def get_reference( + self, reference_points: torch.Tensor, valid_ratios: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """This function computes the reference point inputs and positional embeddings for the decoder layers. + + Args: + reference_points: (batch_size, num_queries, 6) + tensor containing the current reference points in the format (cx, cy, w, h, sinθ, cosθ) + valid_ratios: (batch_size, num_levels, 2) + tensor containing the valid ratios for each level of the input feature maps + + Returns: + reference_points_inputs: (batch_size, num_queries, 1, num_levels, 4) + tensor containing the reference point inputs for the decoder layers, + which are the normalized center coordinates, + width and height of the bounding boxes w.r.t. the valid ratios of the input feature maps + query_pos: (batch_size, num_queries, d_model) + tensor containing the positional embeddings for the decoder layers, + which are computed from the reference points using sine and cosine functions and a linear projection + """ obj_center = reference_points[..., :4] spatial_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] # Extract angles @@ -640,7 +659,7 @@ def forward(self, x: torch.Tensor) -> list[tuple[torch.Tensor, torch.Tensor]]: class C2fBottleneck(nn.Module): - """Faster implementation of CSP bottleneck with 2 convolutions and 1 residual connection + """Faster implementation of CSP bottleneck with 2 convolutions and 1 residual connection. Args: input_dim: number of input channels From fc8fb635d5707c277c48e32277e47a65a0ce81b8 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 12 May 2026 12:29:50 +0200 Subject: [PATCH 15/15] Update docstring --- doctr/models/layout/lw_detr/pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 7dcb7e0491..8c97fc626b 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -304,8 +304,8 @@ def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> cy' = cy + delta_cy * h w' = w * exp(delta_w) h' = h * exp(delta_h) - sinθ' = sinθ * cos(delta_θ) + cosθ * sin(delta_θ) - cosθ' = cosθ * cos(delta_θ) - sinθ * sin(delta_θ) + sinθ' = sinθ * cosΔ + cosθ * sinΔ + cosθ' = cosθ * cosΔ - sinθ * sinΔ Args: reference_points: (N, S, 6) tensor containing the reference points
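
Closing remark on the rewritten docstring: the sinθ'/cosθ' expressions are the
angle-addition identities applied to the (sin, cos) rotation encoding, so composing
a reference rotation θ with a predicted delta Δ is equivalent to rotating by θ + Δ.
A minimal check (values illustrative):

    import torch

    theta, delta = torch.tensor(0.3), torch.tensor(0.2)
    sin_ref, cos_ref = theta.sin(), theta.cos()
    sin_delta, cos_delta = delta.sin(), delta.cos()

    # compose rotations exactly as refine_bboxes does
    sin_new = sin_ref * cos_delta + cos_ref * sin_delta
    cos_new = cos_ref * cos_delta - sin_ref * sin_delta

    assert torch.allclose(sin_new, (theta + delta).sin())
    assert torch.allclose(cos_new, (theta + delta).cos())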