diff --git a/docs/source/modules/models.rst b/docs/source/modules/models.rst
index de5a34f604..55ce88a365 100644
--- a/docs/source/modules/models.rst
+++ b/docs/source/modules/models.rst
@@ -76,6 +76,14 @@ doctr.models.detection
.. autofunction:: doctr.models.detection.detection_predictor
+doctr.models.layout
+-------------------
+
+.. autofunction:: doctr.models.layout.lw_detr_s
+
+.. autofunction:: doctr.models.layout.lw_detr_m
+
+
doctr.models.recognition
------------------------
diff --git a/doctr/models/__init__.py b/doctr/models/__init__.py
index b6db1c0678..8bdcccd1dd 100644
--- a/doctr/models/__init__.py
+++ b/doctr/models/__init__.py
@@ -1,5 +1,6 @@
from .classification import *
from .detection import *
from .recognition import *
+from .layout import *
from .zoo import *
from .factory import *
diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py
index 7560233b2d..45d57a326d 100644
--- a/doctr/models/factory/hub.py
+++ b/doctr/models/factory/hub.py
@@ -30,6 +30,7 @@
"classification": models.classification.zoo.ARCHS + models.classification.zoo.ORIENTATION_ARCHS,
"detection": models.detection.zoo.ARCHS,
"recognition": models.recognition.zoo.ARCHS,
+ "layout": models.layout.zoo.ARCHS,
}
@@ -96,14 +97,19 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
if run_config is None and arch is None:
raise ValueError("run_config or arch must be specified")
- if task not in ["classification", "detection", "recognition"]:
- raise ValueError("task must be one of classification, detection, recognition")
+ if task not in ["classification", "detection", "recognition", "layout"]:
+ raise ValueError("task must be one of classification, detection, recognition, layout")
# default readme
readme = textwrap.dedent(
- f"""
-
+ f"""---
language: en
+ tags:
+ - ocr
+ - pytorch
+ - doctr
+ - {task}
+ ---
@@ -161,7 +167,8 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
# Create repository
api = HfApi()
- api.create_repo(model_name, token=get_token(), exist_ok=False)
+ repo_url = api.create_repo(model_name, token=get_token(), repo_type="model", exist_ok=False)
+ full_repo_id = repo_url.repo_id
# Save model files to a temporary directory
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -172,7 +179,8 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
# Upload all files to the hub
api.upload_folder(
folder_path=tmp_dir,
- repo_id=model_name,
+ repo_id=full_repo_id,
+ repo_type="model",
commit_message=commit_message,
token=get_token(),
)
@@ -208,6 +216,8 @@ def from_hub(repo_id: str, **kwargs: Any):
model = models.detection.__dict__[arch](pretrained=False)
elif task == "recognition":
model = models.recognition.__dict__[arch](pretrained=False, input_shape=cfg["input_shape"], vocab=cfg["vocab"])
+ elif task == "layout":
+ model = models.layout.__dict__[arch](pretrained=False, class_names=cfg["class_names"])
# update model cfg
model.cfg = cfg
diff --git a/doctr/models/layout/__init__.py b/doctr/models/layout/__init__.py
new file mode 100644
index 0000000000..d4147b2974
--- /dev/null
+++ b/doctr/models/layout/__init__.py
@@ -0,0 +1,2 @@
+from .zoo import *
+from .lw_detr import *
diff --git a/doctr/models/layout/lw_detr/__init__.py b/doctr/models/layout/lw_detr/__init__.py
new file mode 100644
index 0000000000..e3c861310c
--- /dev/null
+++ b/doctr/models/layout/lw_detr/__init__.py
@@ -0,0 +1 @@
+from .pytorch import *
diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py
new file mode 100644
index 0000000000..e103021f5f
--- /dev/null
+++ b/doctr/models/layout/lw_detr/base.py
@@ -0,0 +1,254 @@
+# Copyright (C) 2021-2026, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any
+
+import cv2
+import numpy as np
+
+from doctr.models.core import BaseModel
+
+__all__ = ["_LWDETR", "LWDETRPostProcessor"]
+
+
+class LWDETRPostProcessor:
+ """Implements a post processor for LW-DETR model
+
+ Args:
+ num_classes: number of classes
+ score_thresh: confidence threshold for filtering predictions
+ iou_thresh: IoU threshold for NMS
+ topk: number of top predictions to keep before NMS
+ assume_straight_pages: whether the pages are assumed to be straight (i.e., no rotation)
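+
+    Example (an illustrative sketch; shapes and values are placeholders, not an API contract):
+        >>> import numpy as np
+        >>> post_processor = LWDETRPostProcessor(num_classes=3)
+        >>> logits = np.random.rand(2, 300, 3)  # (batch, queries, classes incl. background)
+        >>> boxes = np.random.rand(2, 300, 6)  # (batch, queries, (cx, cy, w, h, sin, cos))
+        >>> labels, final_boxes, scores = post_processor(logits, boxes)[0]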
+ """
+
+ def __init__(
+ self,
+ num_classes: int,
+ score_thresh: float = 0.3,
+ iou_thresh: float = 0.5,
+ topk: int = 300,
+ assume_straight_pages: bool = True,
+ ):
+ self.num_classes = num_classes
+ self.score_thresh = score_thresh
+ self.iou_thresh = iou_thresh
+ self.topk = topk
+ self.assume_straight_pages = assume_straight_pages
+
+ def _decode_boxes(self, boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+ """Decode the predicted boxes from OBB format to polygon format
+
+ Args:
+ boxes: array of predicted boxes in OBB format (N, 6) (cx, cy, w, h, sin(theta), cos(theta))
+
+ Returns:
+ tuple of (polys, angles) where polys is an array of decoded polygons (N, 4, 2)
+ and angles is an array of angles in radians (N,)
+ """
+ cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
+ sin, cos = boxes[:, 4], boxes[:, 5]
+
+ angles = np.arctan2(sin, cos)
+
+ polys = []
+ for i in range(len(boxes)):
+ rect = ((float(cx[i]), float(cy[i])), (float(w[i]), float(h[i])), float(np.degrees(angles[i])))
+
+ poly = cv2.boxPoints(rect)
+ polys.append(poly)
+
+ return np.asarray(polys, dtype=np.float32), angles
+
+ def _iou(self, poly1: np.ndarray, poly2: np.ndarray) -> float:
+ """Compute the IoU between two polygons
+
+ Args:
+ poly1: first polygon (4, 2)
+ poly2: second polygon (4, 2)
+
+ Returns:
+ IoU between the two polygons
+ """
+ inter = cv2.intersectConvexConvex(
+ poly1.astype(np.float32),
+ poly2.astype(np.float32),
+ )[0]
+
+ if inter <= 0:
+ return 0.0
+
+ area1 = cv2.contourArea(poly1)
+ area2 = cv2.contourArea(poly2)
+
+ return inter / (area1 + area2 - inter + 1e-6)
+
+ def _nms(self, polys: np.ndarray, scores: np.ndarray) -> list[int]:
+ """Perform NMS on the predicted polygons
+
+ Args:
+ polys: array of predicted polygons (N, 4, 2)
+ scores: array of predicted scores (N,)
+
+ Returns:
+ list of indices of the polygons to keep after NMS
+ """
+ idxs = np.argsort(scores)[::-1]
+ keep = []
+
+ while idxs.size > 0:
+ i = idxs[0]
+ keep.append(i)
+
+ if idxs.size == 1:
+ break
+
+ rest = idxs[1:]
+
+ ious = np.array([self._iou(polys[i], polys[j]) for j in rest])
+
+ idxs = rest[ious < self.iou_thresh]
+
+ return keep
+
+ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int], np.ndarray, list[float]]]:
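+        """Post-process the raw model outputs into filtered predictions.
+
+        Args:
+            logits: predicted class logits of shape (batch_size, num_queries, num_classes),
+                with the background class at index 0
+            boxes: predicted boxes in OBB format of shape (batch_size, num_queries, 6)
+
+        Returns:
+            list with one tuple (labels, boxes, scores) per batch element
+        """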
+ logits = np.asarray(logits)
+ boxes = np.asarray(boxes)
+
+ results: list[tuple[list[int], np.ndarray, list[float]]] = []
+
+ for b in range(boxes.shape[0]):
+ # Convert logits to probabilities and get scores and labels
+ exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True))
+ prob = exp / exp.sum(axis=-1, keepdims=True)
+
+ scores = prob[:, 1:].max(axis=-1)
+ labels = prob[:, 1:].argmax(axis=-1) + 1
+
+ # Keep only topk predictions before NMS
+ if self.topk is not None and len(scores) > self.topk:
+ idxs = np.argsort(scores)[::-1][: self.topk]
+ else:
+ idxs = np.arange(len(scores))
+
+ scores_b = scores[idxs]
+ labels_b = labels[idxs]
+ bboxes = boxes[b][idxs]
+
+ mask = scores_b > self.score_thresh
+
+ bboxes = bboxes[mask]
+ scores_b = scores_b[mask]
+ labels_b = labels_b[mask]
+
+ polys, _ = (
+ self._decode_boxes(bboxes)
+ if len(bboxes) > 0
+ else (
+ np.zeros((0, 4, 2), dtype=np.float32),
+ np.zeros((0,), dtype=np.float32),
+ )
+ )
+
+ keep = self._nms(polys, scores_b) if len(polys) > 0 else []
+
+ final_labels = []
+ final_boxes = []
+ final_scores = []
+
+ for idx in keep:
+ poly = polys[idx].reshape(-1).tolist()
+ if self.assume_straight_pages:
+ x_coords = poly[0::2]
+ y_coords = poly[1::2]
+ xmin, xmax = min(x_coords), max(x_coords)
+ ymin, ymax = min(y_coords), max(y_coords)
+ final_boxes.append([xmin, ymin, xmax, ymax])
+ else:
+ final_boxes.append(poly)
+
+ final_labels.append(int(labels_b[idx]))
+ final_scores.append(float(scores_b[idx]))
+
+ final_boxes_arr = (
+ np.asarray(final_boxes, dtype=np.float32).reshape(-1, 4, 2)
+ if not self.assume_straight_pages
+ else np.asarray(final_boxes, dtype=np.float32).reshape(-1, 4)
+ )
+
+ results.append((
+ final_labels,
+ final_boxes_arr,
+ final_scores,
+ ))
+
+ return results
+
+
+class _LWDETR(BaseModel):
+ """LW-DETR as described in `"LW-DETR: A Transformer Replacement to YOLO for Real-Time Detection"
+    <https://arxiv.org/abs/2406.03459>`_.
+ """
+
+ def build_target(
+ self,
+ target: list[tuple[list[int], np.ndarray]],
+ ) -> list[dict[str, Any]]:
+ """Build the target for LW-DETR training
+
+ Args:
+ target: list of tuples (class_ids, boxes) where class_ids is a list of class ids for the boxes
+ and boxes is an array of shape (num_boxes, 8) containing the coordinates of the 4 corners of the box
+ in the format (x1, y1, x2, y2, x3, y3, x4, y4)
+
+ Returns:
+ list of dictionaries with keys "boxes" and "labels" where "boxes" is an array of shape (num_boxes, 6)
+ containing the box parameters in OBB format (cx, cy, w, h, sin(theta), cos(theta))
+ and "labels" is an array of shape (num_boxes,) containing the class labels
+ """
+ targets = []
+
+ def _quad_to_obb(poly: np.ndarray):
+ p1, p2, p3, p4 = poly
+
+ cx, cy = np.mean(poly, axis=0)
+
+ w = (np.linalg.norm(p2 - p1) + np.linalg.norm(p3 - p4)) / 2
+ h = (np.linalg.norm(p3 - p2) + np.linalg.norm(p4 - p1)) / 2
+
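+            # orientation of the top edge p1 -> p2: arctan2(dy, dx) via the reversed (dx, dy) vector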
+ theta = np.arctan2(*(p2 - p1)[::-1])
+
+ return np.array(
+ [cx, cy, w, h, np.sin(theta), np.cos(theta)],
+ dtype=np.float32,
+ )
+
+ for class_ids, boxes in target:
+ boxes_all = []
+ labels_all = []
+
+ if len(boxes) == 0:
+ targets.append({
+ "boxes": np.zeros((0, 6), dtype=np.float32),
+ "labels": np.zeros((0,), dtype=np.int64),
+ })
+ continue
+
+ for cls_id, box in zip(np.asarray(class_ids), np.asarray(boxes)):
+ poly = box.reshape(4, 2)
+ obb = _quad_to_obb(poly)
+
+ if obb[2] <= 1e-3 or obb[3] <= 1e-3:
+ continue
+
+ boxes_all.append(obb)
+ labels_all.append(cls_id + 1) # background = 0
+
+ targets.append({
+ "boxes": np.asarray(boxes_all, dtype=np.float32),
+ "labels": np.asarray(labels_all, dtype=np.int64),
+ })
+
+ return targets
diff --git a/doctr/models/layout/lw_detr/layers/__init__.py b/doctr/models/layout/lw_detr/layers/__init__.py
new file mode 100644
index 0000000000..e3c861310c
--- /dev/null
+++ b/doctr/models/layout/lw_detr/layers/__init__.py
@@ -0,0 +1 @@
+from .pytorch import *
diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py
new file mode 100644
index 0000000000..b964ebfd73
--- /dev/null
+++ b/doctr/models/layout/lw_detr/layers/pytorch.py
@@ -0,0 +1,745 @@
+# Copyright (C) 2021-2026, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from doctr.models.modules import ChannelLayerNorm
+from doctr.models.utils import conv_sequence_pt
+
+__all__ = ["MultiScaleProjector", "C2fBottleneck", "LWDETRHead", "LWDETRDecoder", "LWDETRMultiscaleDeformableAttention"]
+
+
+class LWDETRHead(nn.Module):
+ """
+ Simple MLP used as the reference point head in LW-DETR.
+
+ Args:
+ input_dim: number of input features
+ hidden_dim: number of hidden features
+ output_dim: number of output features
+ num_layers: number of layers in the MLP
+ """
+
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for i, layer in enumerate(self.layers):
+ x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+class LWDETRMLP(nn.Module):
+ """Simple MLP used in the decoder layers of LW-DETR.
+
+ Args:
+ d_model: number of input and output features
+ ff_dim: number of hidden features
+ dropout_prob: dropout probability
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ ff_dim: int,
+ dropout_prob: float = 0.1,
+ ):
+ super().__init__()
+ self.act = nn.ReLU()
+ self.dropout_1 = nn.Dropout(dropout_prob)
+ self.dropout_2 = nn.Dropout(dropout_prob)
+ self.fc1 = nn.Linear(d_model, ff_dim)
+ self.fc2 = nn.Linear(ff_dim, d_model)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x + self.dropout_2(self.fc2(self.dropout_1(self.act(self.fc1(x)))))
+
+
+class LWDETRAttention(nn.Module):
+ """This module implements the self-attention mechanism used in LW-DETR.
+ It performs multi-head self-attention on the input hidden states.
+ The group detr technique is used during training to add more supervision by
+ using multiple weight-sharing decoders at once for faster convergence.
+
+ Args:
+ sa_num_heads: number of attention heads for self-attention
+ d_model: number of input and output features
+ dropout_prob: dropout probability for attention weights
+ group_detr: number of weight-sharing decoders to use during training
+ layer_idx: index of the decoder layer (used for group detr)
+ """
+
+ def __init__(
+ self,
+ sa_num_heads: int = 8,
+ d_model: int = 256,
+ dropout_prob: float = 0.0,
+ group_detr: int = 13,
+ layer_idx: int = 0,
+ ):
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.head_dim = d_model // sa_num_heads
+ self.scaling = self.head_dim**-0.5
+ self.dropout = dropout_prob
+ self.group_detr = group_detr
+
+ self.q_proj = nn.Linear(d_model, sa_num_heads * self.head_dim, bias=True)
+ self.k_proj = nn.Linear(d_model, sa_num_heads * self.head_dim, bias=True)
+ self.v_proj = nn.Linear(d_model, sa_num_heads * self.head_dim, bias=True)
+ self.o_proj = nn.Linear(sa_num_heads * self.head_dim, d_model, bias=True)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size, seq_len, _ = hidden_states.shape
+
+ hidden_states_original = hidden_states
+ if position_embeddings is not None:
+            hidden_states = hidden_states + position_embeddings
+
+ if self.training:
+ # at training, we use group detr technique to
+ # add more supervision by using multiple weight-sharing decoders at once for faster convergence
+ # at inference, we only use one decoder
+ hidden_states_original = torch.cat(hidden_states_original.split(seq_len // self.group_detr, dim=1), dim=0)
+ hidden_states = torch.cat(hidden_states.split(seq_len // self.group_detr, dim=1), dim=0)
+
+ attention_input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*attention_input_shape, -1, self.head_dim)
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states_original).view(hidden_shape).transpose(1, 2)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(*attention_input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if self.training:
+ attn_output = torch.cat(torch.split(attn_output, batch_size, dim=0), dim=1)
+
+ return attn_output, attn_weights
+
+
+class MultiScaleDeformableAttention(nn.Module):
+ """This module implements MultiScaleDeformableAttention from Deformable DETR.
+ It performs multi-scale deformable attention on the input feature maps.
+ Borrowed from:
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/deformable_detr/modeling_deformable_detr.py
+ """
+
+ def forward(
+ self,
+ value: torch.Tensor,
+ value_spatial_shapes_list: list[tuple],
+ sampling_locations: torch.Tensor,
+ attention_weights: torch.Tensor,
+ ) -> torch.Tensor:
+ batch_size, _, num_heads, hidden_dim = value.shape
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+ value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for level_id, (height, width) in enumerate(value_spatial_shapes_list):
+ # batch_size, height*width, num_heads, hidden_dim
+ # -> batch_size, height*width, num_heads*hidden_dim
+ # -> batch_size, num_heads*hidden_dim, height*width
+ # -> batch_size*num_heads, hidden_dim, height, width
+ value_l_ = (
+ value_list[level_id]
+ .flatten(2)
+ .transpose(1, 2)
+ .reshape(batch_size * num_heads, hidden_dim, height, width)
+ )
+ # batch_size, num_queries, num_heads, num_points, 2
+ # -> batch_size, num_heads, num_queries, num_points, 2
+ # -> batch_size*num_heads, num_queries, num_points, 2
+ sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+ # batch_size*num_heads, hidden_dim, num_queries, num_points
+ sampling_value_l_ = nn.functional.grid_sample(
+ value_l_,
+ sampling_grid_l_,
+ mode="bilinear",
+ padding_mode="zeros",
+ align_corners=False,
+ )
+ sampling_value_list.append(sampling_value_l_)
+ # (batch_size, num_queries, num_heads, num_levels, num_points)
+ # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+ # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+ attention_weights = attention_weights.transpose(1, 2).reshape(
+ batch_size * num_heads, 1, num_queries, num_levels * num_points
+ )
+ output = (
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+ .sum(-1)
+ .view(batch_size, num_heads * hidden_dim, num_queries)
+ )
+ return output.transpose(1, 2).contiguous()
+
+
+class LWDETRMultiscaleDeformableAttention(nn.Module):
+ """
+ Multiscale deformable attention as proposed in Deformable DETR.
+
+ Args:
+ d_model: number of input and output features
+ ca_num_heads: number of attention heads for cross-attention
+ dec_n_points: number of sampling points for each attention head
+ """
+
+ def __init__(
+ self,
+ d_model: int = 256,
+ ca_num_heads: int = 16,
+ dec_n_points: int = 2,
+ ):
+ super().__init__()
+
+ self.attn = MultiScaleDeformableAttention()
+
+ self.d_model = d_model
+ self.n_levels = 1
+ self.n_heads = ca_num_heads
+ self.n_points = dec_n_points
+
+ self.sampling_offsets = nn.Linear(d_model, self.n_heads * self.n_levels * dec_n_points * 2)
+ self.attention_weights = nn.Linear(d_model, self.n_heads * self.n_levels * dec_n_points)
+ self.value_proj = nn.Linear(d_model, d_model)
+ self.output_proj = nn.Linear(d_model, d_model)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ encoder_hidden_states=None,
+ position_embeddings: torch.Tensor | None = None,
+ reference_points=None,
+ spatial_shapes_list=None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # add position embeddings to the hidden states before projecting to queries and keys
+ if position_embeddings is not None:
+ hidden_states = hidden_states + position_embeddings
+
+ batch_size, num_queries, _ = hidden_states.shape
+ batch_size, sequence_length, _ = encoder_hidden_states.shape
+
+ value = self.value_proj(encoder_hidden_states)
+ if attention_mask is not None:
+ # we invert the attention_mask
+ value = value.masked_fill(~attention_mask[..., None], float(0))
+ value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
+ sampling_offsets = self.sampling_offsets(hidden_states).view(
+ batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
+ )
+ attention_weights = self.attention_weights(hidden_states).view(
+ batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
+ )
+ attention_weights = F.softmax(attention_weights, -1).view(
+ batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
+ )
+ # batch_size, num_queries, n_heads, n_levels, n_points, 2
+ num_coordinates = reference_points.shape[-1]
+
+ if num_coordinates == 4:
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :2]
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
+ )
+ elif num_coordinates == 6:
+ ref = reference_points[:, :, None, :, None, :] # (..., 6)
+
+ center = ref[..., :2] # (cx, cy)
+ wh = ref[..., 2:4] # (w, h)
+ sin = ref[..., 4:5] # sinθ
+ cos = ref[..., 5:6] # cosθ
+
+ # normalize offsets
+ offsets = sampling_offsets / self.n_points * wh * 0.5
+
+ dx = offsets[..., 0:1]
+ dy = offsets[..., 1:2]
+
+ # rotate offsets
+ dx_rot = dx * cos - dy * sin
+ dy_rot = dx * sin + dy * cos
+
+ rotated_offsets = torch.cat([dx_rot, dy_rot], dim=-1)
+
+ sampling_locations = center + rotated_offsets
+ else:
+ raise ValueError(f"Last dim of reference_points must be 4 or 6, but got {reference_points.shape[-1]}")
+
+ output = self.attn(
+ value,
+ spatial_shapes_list,
+ sampling_locations,
+ attention_weights,
+ )
+
+ output = self.output_proj(output)
+
+ return output, attention_weights
+
+
+class LWDETRDecoderLayer(nn.Module):
+ """This module implements a single decoder layer of LW-DETR,
+ which consists of self-attention, cross-attention and an MLP.
+
+ Args:
+ d_model: number of input and output features
+ ff_dim: number of hidden features in the MLP
+ dropout_prob: dropout probability for the attention and MLP layers
+ ca_num_heads: number of attention heads for cross-attention
+ dec_n_points: number of sampling points for each attention head in cross-attention
+ sa_num_heads: number of attention heads for self-attention
+ group_detr: number of weight-sharing decoders to use during training for the group detr technique
+ layer_idx: index of the decoder layer (used for group detr)
+ """
+
+ def __init__(
+ self,
+ d_model: int = 256,
+ ff_dim: int = 2048,
+ dropout_prob: float = 0.0,
+ ca_num_heads: int = 16,
+ dec_n_points: int = 2,
+ sa_num_heads: int = 8,
+ group_detr: int = 13,
+ layer_idx: int = 0,
+ ):
+ super().__init__()
+ self.dropout = dropout_prob
+
+ # self-attention
+ self.self_attn = LWDETRAttention(
+ sa_num_heads=sa_num_heads,
+ d_model=d_model,
+ dropout_prob=dropout_prob,
+ group_detr=group_detr,
+ layer_idx=layer_idx,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(d_model)
+
+ # cross-attention
+ self.cross_attn = LWDETRMultiscaleDeformableAttention(
+ d_model=d_model,
+ ca_num_heads=ca_num_heads,
+ dec_n_points=dec_n_points,
+ )
+ self.cross_attn_layer_norm = nn.LayerNorm(d_model)
+
+ # mlp
+ self.mlp = LWDETRMLP(d_model, ff_dim, dropout_prob)
+ self.layer_norm = nn.LayerNorm(d_model)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: torch.Tensor | None = None,
+ reference_points: torch.Tensor | None = None,
+ spatial_shapes_list: list[tuple] | None = None,
+ encoder_hidden_states: torch.Tensor | None = None,
+ encoder_attention_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ self_attention_output, self_attn_weights = self.self_attn(
+ hidden_states, position_embeddings=position_embeddings
+ )
+
+ self_attention_output = F.dropout(self_attention_output, p=self.dropout, training=self.training)
+ hidden_states = hidden_states + self_attention_output
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ cross_attention_output, cross_attn_weights = self.cross_attn(
+ hidden_states=hidden_states,
+ attention_mask=encoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ position_embeddings=position_embeddings,
+ reference_points=reference_points,
+ spatial_shapes_list=spatial_shapes_list,
+ )
+ cross_attention_output = F.dropout(cross_attention_output, p=self.dropout, training=self.training)
+ hidden_states = hidden_states + cross_attention_output
+ hidden_states = self.cross_attn_layer_norm(hidden_states)
+
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+
+ return hidden_states
+
+
+# function to generate sine positional embedding for 4d coordinates
+# Borrowed from: https://github.com/Atten4Vis/LW-DETR/blob/main/models/transformer.py
+def gen_sine_position_embeddings(pos_tensor: torch.Tensor, hidden_size: int = 256) -> torch.Tensor:
+ """
+ This function computes position embeddings using sine and cosine functions from the input positional tensor,
+ which has a shape of (batch_size, num_queries, 4).
+ The last dimension of `pos_tensor` represents the following coordinates:
+ - 0: x-coord
+ - 1: y-coord
+ - 2: width
+ - 3: height
+
+ The output shape is (batch_size, num_queries, 512),
+ where final dim (hidden_size*2 = 512) is the total embedding dimension
+ achieved by concatenating the sine and cosine values for each coordinate.
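+
+    Example (an illustrative shape check):
+        >>> import torch
+        >>> pos = torch.rand(2, 100, 4)  # (batch_size, num_queries, (x, y, w, h))
+        >>> gen_sine_position_embeddings(pos, hidden_size=256).shape
+        torch.Size([2, 100, 512])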
+ """
+ scale = 2 * math.pi
+ dim = hidden_size // 2
+ dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)
+ dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
+ x_embed = pos_tensor[:, :, 0] * scale
+ y_embed = pos_tensor[:, :, 1] * scale
+ pos_x = x_embed[:, :, None] / dim_t
+ pos_y = y_embed[:, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
+ pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
+ if pos_tensor.size(-1) == 4:
+ w_embed = pos_tensor[:, :, 2] * scale
+ pos_w = w_embed[:, :, None] / dim_t
+ pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
+
+ h_embed = pos_tensor[:, :, 3] * scale
+ pos_h = h_embed[:, :, None] / dim_t
+ pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
+
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
+ else:
+ raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}")
+ return pos.to(pos_tensor.dtype)
+
+
+class LWDETRDecoder(nn.Module):
+ """This module implements the decoder of LW-DETR,
+ which consists of multiple decoder layers and a reference point head.
+
+ Args:
+ num_layers: number of decoder layers
+ d_model: number of input and output features for each decoder layer
+ sa_num_heads: number of attention heads for self-attention in each decoder layer
+ ca_num_heads: number of attention heads for cross-attention in each decoder layer
+ ff_dim: number of hidden features in the MLP of each decoder layer
+ dec_n_points: number of sampling points for each attention head in cross-attention of each decoder layer
+ group_detr: number of weight-sharing decoders to use during training for the group detr technique
+ dropout_prob: dropout probability for the attention and MLP layers in each decoder layer
+ """
+
+ def __init__(
+ self,
+ num_layers: int = 3,
+ d_model: int = 256,
+ sa_num_heads: int = 8,
+ ca_num_heads: int = 16,
+ ff_dim: int = 2048,
+ dec_n_points: int = 2,
+ group_detr: int = 13,
+ dropout_prob: float = 0.0,
+ ):
+ super().__init__()
+ self.dropout_prob = dropout_prob
+ self.d_model = d_model
+ self.layers = nn.ModuleList([
+ LWDETRDecoderLayer(
+ d_model=self.d_model,
+ sa_num_heads=sa_num_heads,
+ ca_num_heads=ca_num_heads,
+ ff_dim=ff_dim,
+ dec_n_points=dec_n_points,
+ group_detr=group_detr,
+ dropout_prob=dropout_prob,
+ layer_idx=i,
+ )
+ for i in range(num_layers)
+ ])
+ self.layernorm = nn.LayerNorm(self.d_model)
+ self.bbox_embed = None
+
+ self.ref_point_head = LWDETRHead(2 * self.d_model, self.d_model, self.d_model, num_layers=2)
+ self.angle_proj = nn.Sequential(
+ nn.Linear(4, self.d_model),
+ nn.ReLU(),
+ nn.Linear(self.d_model, self.d_model),
+ )
+
+ def get_reference(
+ self, reference_points: torch.Tensor, valid_ratios: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """This function computes the reference point inputs and positional embeddings for the decoder layers.
+
+ Args:
+ reference_points: (batch_size, num_queries, 6)
+ tensor containing the current reference points in the format (cx, cy, w, h, sinθ, cosθ)
+ valid_ratios: (batch_size, num_levels, 2)
+ tensor containing the valid ratios for each level of the input feature maps
+
+ Returns:
+            reference_points_inputs: (batch_size, num_queries, num_levels, 6)
+                tensor containing the reference point inputs for the decoder layers,
+                which are the normalized center coordinates, width and height of the bounding boxes
+                w.r.t. the valid ratios of the input feature maps, followed by (sinθ, cosθ)
+ query_pos: (batch_size, num_queries, d_model)
+ tensor containing the positional embeddings for the decoder layers,
+ which are computed from the reference points using sine and cosine functions and a linear projection
+ """
+ obj_center = reference_points[..., :4]
+ spatial_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
+ # Extract angles
+ angle = reference_points[..., 4:6] # (sin, cos)
+ angle_expanded = angle[:, :, None]
+ reference_points_inputs = torch.cat([spatial_inputs, angle_expanded], dim=-1)
+ # DETR positional encoding
+ query_sine_embed = gen_sine_position_embeddings(spatial_inputs[:, :, 0, :], self.d_model)
+ base_query_pos = self.ref_point_head(query_sine_embed)
+ # Angle embedding
+ sin_t = angle[..., 0:1]
+ cos_t = angle[..., 1:2]
+
+ angle_feat = torch.cat(
+ [
+ sin_t,
+ cos_t,
+ 2 * sin_t * cos_t,
+ cos_t**2 - sin_t**2,
+ ],
+ dim=-1,
+ )
+
+ angle_emb = self.angle_proj(angle_feat)
+ # Combine
+ query_pos = base_query_pos + angle_emb
+ return reference_points_inputs, query_pos
+
+ def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
+ """Refine the reference points using the predicted deltas.
+
+ Args:
+ reference_points: (batch_size, num_queries, 6)
+ tensor containing the current reference points in the format (cx, cy, w, h, sinθ, cosθ)
+ deltas: (batch_size, num_queries, 6)
+ tensor containing the predicted deltas for the reference points in the same format as reference_points
+
+ Returns:
+ refined_reference_points: (batch_size, num_queries, 6)
+ tensor containing the refined reference points in the same format as reference_points
+ """
+ cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2]
+
+ wh = deltas[..., 2:4].exp() * reference_points[..., 2:4]
+
+ delta_rot = F.normalize(deltas[..., 4:6], dim=-1)
+
+ sin_delta = delta_rot[..., 0:1]
+ cos_delta = delta_rot[..., 1:2]
+
+ sin_ref = reference_points[..., 4:5]
+ cos_ref = reference_points[..., 5:6]
+
+ sin_new = sin_ref * cos_delta + cos_ref * sin_delta
+ cos_new = cos_ref * cos_delta - sin_ref * sin_delta
+
+ rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1)
+
+ return torch.cat((cxcy, wh, rot), dim=-1)
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor | None,
+ reference_points: torch.Tensor,
+        spatial_shapes_list: list[tuple[int, int]],
+ valid_ratios: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_attention_mask: torch.Tensor | None = None,
+ ):
+ intermediate: list[torch.Tensor] = []
+
+ intermediate_reference_points: list[torch.Tensor] = [reference_points]
+
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+
+ reference_points_inputs, query_pos = self.get_reference(reference_points, valid_ratios)
+
+ for lid, decoder_layer in enumerate(self.layers):
+ hidden_states = decoder_layer(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ position_embeddings=query_pos,
+ reference_points=reference_points_inputs,
+ spatial_shapes_list=spatial_shapes_list,
+ )
+
+ hidden_states_norm = self.layernorm(hidden_states)
+
+ # iterative refinement
+ if self.bbox_embed is not None:
+ delta = self.bbox_embed(hidden_states_norm)
+
+ reference_points = self.refine_boxes(
+ reference_points.squeeze(2),
+ delta,
+ )
+
+ intermediate_reference_points.append(reference_points)
+
+ reference_points_inputs, query_pos = self.get_reference(
+ reference_points,
+ valid_ratios,
+ )
+
+ intermediate.append(hidden_states_norm)
+
+ intermediate_stack = torch.stack(intermediate)
+ last_hidden_state = intermediate_stack[-1]
+
+ intermediate_reference_points_stack = torch.stack(intermediate_reference_points)
+
+ return last_hidden_state, intermediate_stack, intermediate_reference_points_stack
+
+
+class MultiScaleProjector(nn.Module):
+ """
+    This module implements the MultiScaleProjector from the LW-DETR paper.
+ It creates pyramid features built on top of the input feature map.
+ This is modified from the original MultiScaleProjector to use only the levels used in LW-DETR small and medium.
+
+ Args:
+ in_channels (list[int]): list of input channels for each level of the input feature maps.
+ out_channels (int): number of channels in the output feature maps.
+ num_blocks (int): number of blocks in the C2fBottleneck.
+ """
+
+ def __init__(self, in_channels: list[int], out_channels: int, num_blocks: int = 3):
+ super().__init__()
+
+ self.use_extra_pool = False
+
+ self.stages_sampling = nn.ModuleList()
+ self.stages = nn.ModuleList()
+
+ sampling_layers = nn.ModuleList()
+ out_dim: int = 0
+
+ for in_dim in in_channels:
+ layers, out_dim = [nn.Identity()], in_dim
+ sampling_layers.append(nn.Sequential(*layers))
+
+ self.stages_sampling.append(sampling_layers)
+
+ fusion_in_dim = out_dim * len(in_channels)
+
+ self.stages.append(
+ nn.Sequential(
+ C2fBottleneck(fusion_in_dim, out_channels, num_blocks),
+ ChannelLayerNorm(out_channels),
+ )
+ )
+
+    def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
+        feats = [layer(xi) for layer, xi in zip(self.stages_sampling[0], x)]
+ fused = torch.cat(feats, dim=1)
+ return [self.stages[0](fused)]
+
+
+class C2fBottleneck(nn.Module):
+ """Faster implementation of CSP bottleneck with 2 convolutions and 1 residual connection.
+
+ Args:
+ input_dim: number of input channels
+ out_channels: number of output channels
+ num_blocks: number of bottleneck blocks
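+
+    Example (a shape sketch; the channel sizes are illustrative assumptions):
+        >>> import torch
+        >>> block = C2fBottleneck(input_dim=512, out_channels=256, num_blocks=3)
+        >>> block(torch.rand(1, 512, 32, 32)).shape
+        torch.Size([1, 256, 32, 32])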
+ """
+
+ def __init__(self, input_dim: int, out_channels: int, num_blocks: int):
+ super().__init__()
+
+ self.c = int(out_channels * 0.5)
+
+ self.conv_seq_1 = nn.Sequential(
+ *conv_sequence_pt(
+ in_channels=input_dim,
+ out_channels=2 * self.c,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ dilation=1,
+ act=True,
+ bias=False,
+ bn=True,
+ activation=nn.SiLU(inplace=True),
+ )
+ )
+
+ self.blocks = nn.ModuleList([
+ nn.Sequential(
+ *conv_sequence_pt(
+ in_channels=self.c,
+ out_channels=self.c,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ dilation=1,
+ act=True,
+ bias=False,
+ bn=True,
+ activation=nn.SiLU(inplace=True),
+ ),
+ *conv_sequence_pt(
+ in_channels=self.c,
+ out_channels=self.c,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ dilation=1,
+ act=True,
+ bias=False,
+ bn=True,
+ activation=nn.SiLU(inplace=True),
+ ),
+ )
+ for _ in range(num_blocks)
+ ])
+
+ self.conv_seq_2 = nn.Sequential(
+ *conv_sequence_pt(
+ in_channels=(2 + num_blocks) * self.c,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ dilation=1,
+ act=True,
+ bias=False,
+ bn=True,
+ activation=nn.SiLU(inplace=True),
+ )
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ y = list(self.conv_seq_1(x).split((self.c, self.c), dim=1))
+
+ for block in self.blocks:
+ y.append(block(y[-1]))
+
+ return self.conv_seq_2(torch.cat(y, dim=1))
diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py
new file mode 100644
index 0000000000..8c97fc626b
--- /dev/null
+++ b/doctr/models/layout/lw_detr/pytorch.py
@@ -0,0 +1,826 @@
+# Copyright (C) 2021-2026, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import math
+from collections.abc import Callable
+from copy import deepcopy
+from typing import Any
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from doctr.models.classification import vit_det_m, vit_det_s
+
+from ...utils import load_pretrained_params
+from .base import _LWDETR, LWDETRPostProcessor
+from .layers import LWDETRDecoder, LWDETRHead, LWDETRMultiscaleDeformableAttention, MultiScaleProjector
+
+__all__ = ["LWDETR", "lw_detr_s", "lw_detr_m"]
+
+
+default_cfgs: dict[str, dict[str, Any]] = {
+ "lw_detr_s": {
+ "input_shape": (3, 1024, 1024),
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "class_names": [
+ "Caption",
+ "Footnote",
+ "Formula",
+ "List-item",
+ "Page-footer",
+ "Page-header",
+ "Picture",
+ "Section-header",
+ "Table",
+ "Text",
+ "Title",
+ "Document Index",
+ "Code",
+ "Checkbox-Selected",
+ "Checkbox-Unselected",
+ "Form",
+ "Key-Value Region",
+ ],
+ "url": None,
+ },
+ "lw_detr_m": {
+ "input_shape": (3, 1024, 1024),
+ "mean": (0.798, 0.785, 0.772),
+ "std": (0.264, 0.2749, 0.287),
+ "class_names": [
+ "Caption",
+ "Footnote",
+ "Formula",
+ "List-item",
+ "Page-footer",
+ "Page-header",
+ "Picture",
+ "Section-header",
+ "Table",
+ "Text",
+ "Title",
+ "Document Index",
+ "Code",
+ "Checkbox-Selected",
+ "Checkbox-Unselected",
+ "Form",
+ "Key-Value Region",
+ ],
+ "url": None,
+ },
+}
+
+
+class LWDETRBackbone(nn.Module):
+ """Backbone of LW-DETR, based on a ViT Det architecture. The backbone is used as feature extractor.
+
+ Args:
+        encoder_fn: the encoder module of the backbone (a ViT Det architecture instance)
+ out_channels: number of channels in the output feature maps of the backbone.
+ num_blocks: number of blocks in the C2fBottleneck of the projector.
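+
+    Example (a minimal sketch, assuming the ViT Det encoders shipped with this module
+    follow the usual doctr ``pretrained`` keyword convention):
+        >>> from doctr.models.classification import vit_det_s
+        >>> backbone = LWDETRBackbone(vit_det_s(pretrained=False), out_channels=256)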
+ """
+
+ def __init__(
+ self,
+ encoder_fn: nn.Module,
+ out_channels: int = 256,
+ num_blocks: int = 3,
+ ) -> None:
+ super().__init__()
+ self.encoder = encoder_fn
+
+ _is_training = self.encoder.training
+
+ self.encoder.eval()
+ with torch.no_grad():
+ in_shape = (3, 512, 512)
+ out = self.encoder(torch.zeros((1, *in_shape)))
+ # Get the number of channels for each feature map output by the backbone
+ _shapes = [feat.shape[1] for feat in out]
+ self.encoder.train(_is_training)
+
+ self.projector = MultiScaleProjector(
+ in_channels=_shapes,
+ out_channels=out_channels,
+ num_blocks=num_blocks,
+ )
+
+ def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> list[tuple[torch.Tensor, torch.Tensor]]:
+ """Forward pass of the backbone.
+
+ Args:
+ x: batched images, of shape [batch_size x 3 x H x W]
+ mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
+
+ Returns:
+ A list of tuples (feat, mask) for each feature map, where:
+ - feat is the feature map of shape [batch_size x out_channels x H' x W']
+ - mask is the corresponding attention mask of shape [batch_size x H' x W'], containing 1 on padded pixels
+ """
+        # encoder output: a list of feature maps, each of shape (B, C, H, W)
+ feats = self.encoder(x)
+ feats = self.projector(feats)
+ # [(B, C, H, W)]
+ if mask is None: # pragma: no cover
+ mask = torch.zeros((x.shape[0], x.shape[2], x.shape[3]), dtype=torch.bool, device=x.device)
+ return [
+ (feat, F.interpolate(mask.unsqueeze(1).float(), size=feat.shape[-2:], mode="nearest").squeeze(1).bool())
+ for feat in feats
+ ]
+
+
+class LWDETR(nn.Module, _LWDETR):
+ """LW-DETR as described in `"LW-DETR: A Transformer Replacement to YOLO for Real-Time Detection"
+    <https://arxiv.org/abs/2406.03459>`_.
+
+ Args:
+ feat_extractor: the backbone of the model, used as feature extractor
+ class_names: list of class names to be detected by the model
+ score_thresh: the score threshold for post-processing the model outputs
+ iou_thresh: the IoU threshold for post-processing the model outputs
+ d_model: the dimension of the model
+ num_queries: the number of object queries
+ group_detr: the number of groups in the group DETR architecture
+ dec_layers: the number of decoder layers
+ sa_num_heads: the number of heads in the self-attention of the decoder
+ ca_num_heads: the number of heads in the cross-attention of the decoder
+ ff_dim: the dimension of the feedforward network in the decoder
+ dec_n_points: the number of sampling points in the deformable attention of the decoder
+ dropout_prob: the dropout probability in the decoder
+ assume_straight_pages: if True, fit straight bounding boxes only
+        exportable: if True, the forward pass only returns the raw logits and boxes (for ONNX export)
+ cfg: the configuration dict of the model
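+
+    Example (a minimal inference sketch with randomly initialized weights):
+        >>> import torch
+        >>> from doctr.models import lw_detr_s
+        >>> model = lw_detr_s(pretrained=False).eval()
+        >>> input_tensor = torch.rand((1, 3, 1024, 1024))
+        >>> masks = torch.zeros((1, 1024, 1024), dtype=torch.bool)
+        >>> out = model(input_tensor, masks)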
+ """
+
+ def __init__(
+ self,
+ feat_extractor: LWDETRBackbone,
+ class_names: list[str],
+ score_thresh: float = 0.3,
+ iou_thresh: float = 0.5,
+ d_model: int = 256,
+ num_queries: int = 300,
+ group_detr: int = 13,
+ dec_layers: int = 3,
+ sa_num_heads: int = 8,
+ ca_num_heads: int = 16,
+ ff_dim: int = 2048,
+ dec_n_points: int = 2,
+ dropout_prob: float = 0.0,
+ assume_straight_pages: bool = True,
+ exportable: bool = False,
+ cfg: dict[str, Any] | None = None,
+ ) -> None:
+ super().__init__()
+
+ self.class_names: list[str] = ["__background__"] + class_names
+ self.num_classes = len(self.class_names)
+ self.cfg = cfg
+ self.exportable = exportable
+ self.assume_straight_pages = assume_straight_pages
+
+ self.feat_extractor = feat_extractor
+
+ self.group_detr = group_detr
+ self.num_queries = num_queries
+ self.d_model = d_model
+
+ self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 6)
+ # Initialize angle to (sin=0, cos=1)
+ with torch.no_grad():
+ self.reference_point_embed.weight[:, 4] = 0.0 # sinθ
+ self.reference_point_embed.weight[:, 5] = 1.0 # cosθ
+
+ self.query_feat = nn.Embedding(self.num_queries * self.group_detr, self.d_model)
+
+ self.decoder = LWDETRDecoder(
+ num_layers=dec_layers,
+ d_model=d_model,
+ sa_num_heads=sa_num_heads,
+ ca_num_heads=ca_num_heads,
+ ff_dim=ff_dim,
+ dec_n_points=dec_n_points,
+ group_detr=group_detr,
+ dropout_prob=dropout_prob,
+ )
+
+ self.enc_output = nn.ModuleList([nn.Linear(self.d_model, self.d_model) for _ in range(self.group_detr)])
+ self.enc_output_norm = nn.ModuleList([nn.LayerNorm(self.d_model) for _ in range(self.group_detr)])
+
+ self.enc_out_bbox_embed = nn.ModuleList([
+ LWDETRHead(self.d_model, self.d_model, 6, num_layers=3) for _ in range(self.group_detr)
+ ])
+ self.enc_out_class_embed = nn.ModuleList([
+ nn.Linear(self.d_model, self.num_classes) for _ in range(self.group_detr)
+ ])
+ self.class_embed = nn.Linear(self.d_model, self.num_classes)
+ self.bbox_embed = LWDETRHead(self.d_model, self.d_model, 6, num_layers=3)
+ self.decoder.bbox_embed = self.bbox_embed # type: ignore[assignment]
+
+ self.postprocessor = LWDETRPostProcessor(
+ num_classes=self.num_classes,
+ score_thresh=score_thresh,
+ iou_thresh=iou_thresh,
+ assume_straight_pages=self.assume_straight_pages,
+ )
+
+        # Generic initialization first; modules with specialized init are handled in a
+        # second pass, so that the generic per-child init cannot overwrite them
+        # (named_modules() yields parents before their children)
+        for n, m in self.named_modules():
+            # Don't override the initialization of the backbone
+            if n.startswith("feat_extractor."):
+                continue
+
+            if isinstance(m, nn.Linear):
+                nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
+                if hasattr(m, "weight") and m.weight is not None:
+                    nn.init.ones_(m.weight)
+                if hasattr(m, "bias") and m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Embedding):
+                nn.init.normal_(m.weight, std=0.02)
+
+        # Second pass: specialized initialization that must not be clobbered
+        for n, m in self.named_modules():
+            if n.startswith("feat_extractor."):
+                continue
+
+            if isinstance(m, LWDETRMultiscaleDeformableAttention):
+                nn.init.constant_(m.sampling_offsets.weight, 0.0)
+
+                thetas = torch.arange(m.n_heads, dtype=torch.float32) * (2.0 * math.pi / m.n_heads)
+                grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+                grid_init = (
+                    (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+                    .view(m.n_heads, 1, 1, 2)
+                    .repeat(1, m.n_levels, m.n_points, 1)
+                )
+
+                for i in range(m.n_points):
+                    grid_init[:, :, i, :] *= i + 1
+
+                with torch.no_grad():
+                    m.sampling_offsets.bias.copy_(grid_init.view(-1))
+
+                nn.init.constant_(m.attention_weights.weight, 0.0)
+                nn.init.constant_(m.attention_weights.bias, 0.0)
+
+                nn.init.xavier_uniform_(m.value_proj.weight)
+                nn.init.zeros_(m.value_proj.bias)
+
+                nn.init.xavier_uniform_(m.output_proj.weight)
+                nn.init.zeros_(m.output_proj.bias)
+
+            if isinstance(m, nn.Linear) and m.out_features == self.num_classes:
+                prior_prob = 0.01
+                bias_value = -math.log((1 - prior_prob) / prior_prob)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, bias_value)
+            if isinstance(m, LWDETRHead):
+                last = m.layers[-1]
+                if isinstance(last, nn.Linear):
+                    nn.init.zeros_(last.weight)
+                    nn.init.zeros_(last.bias)
+
+        # Re-apply the reference angle init (sin=0, cos=1): the generic embedding
+        # init above would otherwise overwrite the values set at construction time
+        with torch.no_grad():
+            self.reference_point_embed.weight[:, 4] = 0.0  # sinθ
+            self.reference_point_embed.weight[:, 5] = 1.0  # cosθ
+
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+ """Load pretrained parameters onto the model
+
+ Args:
+ path_or_url: the path or URL to the model parameters (checkpoint)
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+ """
+ load_pretrained_params(self, path_or_url, **kwargs)
+
+ def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
+ """Refine bounding boxes by applying the predicted deltas to the reference points.
+ The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format.
+ The refined boxes are computed as follows:
+
+ cx' = cx + delta_cx * w
+ cy' = cy + delta_cy * h
+ w' = w * exp(delta_w)
+ h' = h * exp(delta_h)
+ sinθ' = sinθ * cosΔ + cosθ * sinΔ
+ cosθ' = cosθ * cosΔ - sinθ * sinΔ
+
+ Args:
+ reference_points: (N, S, 6) tensor containing the reference points
+ deltas: (N, S, 6) tensor containing the predicted deltas
+
+ Returns:
+ refined_boxes: (N, S, 6) tensor containing the refined bounding boxes
+ """
+ reference_points = reference_points.to(deltas.device)
+ # center
+ cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2]
+ # size
+ wh = deltas[..., 2:4].exp() * reference_points[..., 2:4]
+ # normalize predicted delta rotation
+ delta_rot = F.normalize(deltas[..., 4:6], dim=-1)
+ sin_delta = delta_rot[..., 0:1]
+ cos_delta = delta_rot[..., 1:2]
+ sin_ref = reference_points[..., 4:5]
+ cos_ref = reference_points[..., 5:6]
+
+ # compose rotations
+ sin_new = sin_ref * cos_delta + cos_ref * sin_delta
+ cos_new = cos_ref * cos_delta - sin_ref * sin_delta
+ # normalize final rotation
+ rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1)
+
+ return torch.cat((cxcy, wh, rot), dim=-1)
+
+ def get_valid_ratio(self, mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+ """Get the valid ratio of all feature maps.
+
+ Args:
+ mask: (N, H, W) binary tensor containing 1 on padded pixels
+ dtype: the desired data type of the output tensor
+
+ Returns:
+ valid_ratio: (N, 2) tensor containing the valid ratio of width and height for each image in the batch
+ """
+ _, height, width = mask.shape
+ valid_height = torch.sum(~mask[:, :, 0], 1)
+ valid_width = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_height = valid_height.to(dtype) / height
+ valid_ratio_width = valid_width.to(dtype) / width
+ valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
+ return valid_ratio
+
+ def gen_encoder_output_proposals(
+ self, enc_output: torch.Tensor, padding_mask: torch.Tensor, spatial_shapes: list[tuple[int, int]]
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Generate the encoder output proposals from encoded enc_output.
+
+ Args:
+ enc_output: Output of the encoder
+ padding_mask: Padding mask for `enc_output`
+ spatial_shapes: Spatial shapes of the feature maps
+
+ Returns:
+ A tuple of feature map and bbox prediction.
+ - object_query: Object query features. Later used to directly predict a bounding box.
+ - output_proposals: Normalized proposals in [0, 1] space.
+ Invalid positions (padding or out-of-bounds) are filled with 0.
+ - invalid_mask: Boolean mask that is True for invalid positions
+ (padded pixels or proposals whose coordinates fall outside (0.01, 0.99)).
+ """
+ batch_size = enc_output.shape[0]
+ proposals = []
+ _cur = 0
+ for level, (height, width) in enumerate(spatial_shapes):
+ mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1)
+ valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
+ valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
+
+ grid_y, grid_x = torch.meshgrid(
+ torch.linspace(
+ 0,
+ height - 1,
+ height,
+ dtype=enc_output.dtype,
+ device=enc_output.device,
+ ),
+ torch.linspace(
+ 0,
+ width - 1,
+ width,
+ dtype=enc_output.dtype,
+ device=enc_output.device,
+ ),
+ indexing="ij",
+ )
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+
+ scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2)
+ grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale
+ width_height = torch.ones_like(grid) * 0.05 * (2.0**level)
+ # add default rotation (sin=0, cos=1)
+ sin = torch.zeros_like(grid[..., :1])
+ cos = torch.ones_like(grid[..., :1])
+ proposal = torch.cat((grid, width_height, sin, cos), -1).view(batch_size, -1, 6)
+ proposals.append(proposal)
+ _cur += height * width
+ output_proposals = torch.cat(proposals, 1)
+ output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
+        invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid
+ output_proposals = output_proposals.masked_fill(invalid_mask, float(0))
+
+ # assign each pixel as an object query
+ object_query = enc_output
+ object_query = object_query.masked_fill(invalid_mask, float(0))
+ return object_query, output_proposals, invalid_mask
+
+ def forward(
+ self,
+ input: torch.Tensor,
+ masks: torch.Tensor,
+ target: list[tuple[list[int], np.ndarray]] | None = None,
+ return_model_output: bool = False,
+ return_preds: bool = False,
+ **kwargs: Any,
+ ) -> dict[str, Any]:
+ feats = self.feat_extractor(input, masks)
+
+ sources: list[torch.Tensor] = []
+ feats_masks: list[torch.Tensor] = []
+
+ for source, mask in feats:
+ sources.append(source)
+ feats_masks.append(mask)
+ if mask is None: # pragma: no cover
+ raise ValueError("No attention mask was provided")
+
+ if self.training:
+ reference_points = self.reference_point_embed.weight
+ query_feat = self.query_feat.weight
+ else:
+ # only use one group in inference
+ reference_points = self.reference_point_embed.weight[: self.num_queries]
+ query_feat = self.query_feat.weight[: self.num_queries]
+
+ # Prepare encoder inputs (by flattening)
+ source_flatten_list: list[torch.Tensor] = []
+ mask_flatten_list: list[torch.Tensor] = []
+ spatial_shapes_list: list[tuple[int, int]] = []
+ for source, mask in zip(sources, feats_masks):
+ batch_size, num_channels, height, width = source.shape
+ spatial_shape = (height, width)
+ spatial_shapes_list.append(spatial_shape)
+ source = source.flatten(2).transpose(1, 2)
+ mask = mask.flatten(1)
+ source_flatten_list.append(source)
+ mask_flatten_list.append(mask)
+ source_flatten = torch.cat(source_flatten_list, 1)
+ mask_flatten = torch.cat(mask_flatten_list, 1)
+ valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in feats_masks], 1)
+
+ tgt = query_feat.unsqueeze(0).expand(batch_size, -1, -1)
+ reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1)
+
+ object_query_embedding, output_proposals, invalid_mask = self.gen_encoder_output_proposals(
+ source_flatten, mask_flatten, spatial_shapes_list
+ )
+
+ group_detr = self.group_detr if self.training else 1
+ topk = self.num_queries
+
+ topk_coords_logits_list: list[torch.Tensor] = []
+
+ # For each group, predict class logits and bbox deltas from the object query embeddings,
+ # and select the top-k proposals based on the predicted class logits.
+ for group_id in range(group_detr):
+ group_object_query = self.enc_output[group_id](object_query_embedding)
+ group_object_query = self.enc_output_norm[group_id](group_object_query)
+
+ group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query)
+ group_enc_outputs_class = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf"))
+ group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query)
+ group_enc_outputs_coord = self.refine_bboxes(output_proposals, group_delta_bbox)
+
+ group_topk_proposals = torch.topk(group_enc_outputs_class.max(-1)[0], topk, dim=1)[1]
+ group_topk_coords_logits_undetach = torch.gather(
+ group_enc_outputs_coord,
+ 1,
+ group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6),
+ )
+ group_topk_coords_logits = group_topk_coords_logits_undetach.detach()
+ topk_coords_logits_list.append(group_topk_coords_logits)
+
+ topk_coords_logits = torch.cat(topk_coords_logits_list, 1)
+
+ reference_points = self.refine_bboxes(topk_coords_logits, reference_points)
+
+ last_hidden_states, intermediate, intermediate_reference_points = self.decoder(
+ inputs_embeds=tgt,
+ reference_points=reference_points,
+ spatial_shapes_list=spatial_shapes_list,
+ valid_ratios=valid_ratios,
+ encoder_hidden_states=source_flatten,
+ )
+
+ logits = self.class_embed(last_hidden_states)
+ pred_boxes_delta = self.bbox_embed(last_hidden_states)
+ pred_boxes = self.refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta)
+
+ out: dict[str, Any] = {}
+
+ if self.exportable:
+ out["logits"] = logits
+ out["pred_boxes"] = pred_boxes
+ return out
+
+ if return_model_output or target is None or return_preds:
+ out["logits"] = logits
+
+ if target is None or return_preds:
+ # Disable for torch.compile compatibility
+ @torch.compiler.disable
+ def _postprocess(logits, boxes):
+ return self.postprocessor(logits, boxes)
+
+ out["preds"] = _postprocess(logits.detach().cpu().numpy(), pred_boxes.detach().cpu().numpy())
+
+ if target is not None:
+ loss = self.compute_loss(logits, pred_boxes, target)
+ out["loss"] = loss
+
+ return out
+
+ def compute_loss(
+ self, logits: torch.Tensor, pred_boxes: torch.Tensor, target: list[tuple[list[int], np.ndarray]]
+ ) -> torch.Tensor:
+ """
+ Compute the loss for LW-DETR. The loss consists of three components:
+ classification loss, box regression loss, and rotation loss.
+ The classification loss is a cross-entropy loss between the predicted class logits and the target classes.
+ The box regression loss is a Smooth L1 loss between the predicted boxes and the target boxes,
+ computed only on the positive samples.
+ The rotation loss is computed as 1 - cosine similarity between the predicted rotation and the target rotation,
+ averaged over the positive samples.
+ The positive samples are determined using a SimOTA-like assignment strategy, where for each ground truth box,
+ we select the top-k queries with the lowest cost
+ (combination of classification cost, box regression cost, and rotation cost).
+
+ Args:
+ logits: (B, Q, C) tensor containing the predicted class logits for each query
+ pred_boxes: (B, Q, 6) tensor containing the predicted boxes for each query
+ target: list of length B, where each element is a tuple of (classes, boxes)
+ for the corresponding image in the batch.
+ - classes: list of length N_i containing the class indices of the N_i ground truth boxes in the image
+            - boxes: (N_i, 4, 2) array containing the ground truth boxes
+ in the format [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
+
+ Returns:
+ loss: the computed loss value
+ """
+
+ def _sigmoid_focal_loss(
+ inputs: torch.Tensor, targets: torch.Tensor, alpha: float = 0.25, gamma: float = 2.0
+ ) -> torch.Tensor:
+ """Compute the sigmoid focal loss between `inputs` and `targets`."""
+ prob = inputs.sigmoid()
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+ p_t = prob * targets + (1 - prob) * (1 - targets)
+ loss = ce_loss * ((1 - p_t) ** gamma)
+
+ if alpha >= 0:
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+ loss = alpha_t * loss
+
+ return loss.sum(-1).mean()
+
+ def rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format to Gaussian distribution parameters
+ (mean and covariance).
+ """
+ cxcy = boxes[..., :2]
+
+ w = boxes[..., 2].clamp(min=1e-6)
+ h = boxes[..., 3].clamp(min=1e-6)
+
+ sin = boxes[..., 4]
+ cos = boxes[..., 5]
+
+ R = torch.stack(
+ [
+ torch.stack([cos, -sin], dim=-1),
+ torch.stack([sin, cos], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ sx = (w / 2) ** 2
+ sy = (h / 2) ** 2
+
+ S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device)
+
+ S[..., 0, 0] = sx
+ S[..., 1, 1] = sy
+
+ covariance = R @ S @ R.transpose(-1, -2)
+ return cxcy, covariance
+
+ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Tensor:
+ """Compute the ProbIoU loss between predicted boxes and target boxes."""
+ mu1, sigma1 = rotated_boxes_to_gaussian(pred_boxes)
+ mu2, sigma2 = rotated_boxes_to_gaussian(tgt_boxes)
+
+ delta = (mu1 - mu2).unsqueeze(-1)
+
+ sigma = (sigma1 + sigma2) * 0.5
+
+ sigma_inv = torch.linalg.inv(sigma)
+
+ mahalanobis = (delta.transpose(-1, -2) @ sigma_inv @ delta).squeeze(-1).squeeze(-1)
+
+ det_sigma = torch.linalg.det(sigma).clamp(min=1e-6)
+ det_sigma1 = torch.linalg.det(sigma1).clamp(min=1e-6)
+ det_sigma2 = torch.linalg.det(sigma2).clamp(min=1e-6)
+
+ bhattacharyya = 0.125 * mahalanobis + 0.5 * torch.log(det_sigma / torch.sqrt(det_sigma1 * det_sigma2))
+
+ probiou = torch.exp(-bhattacharyya)
+
+ return 1 - probiou
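+ # Identical boxes have a Bhattacharyya distance of 0, hence ProbIoU = 1 and
+ # zero loss (up to float precision; illustrative check):
+ # >>> box = torch.tensor([[0.5, 0.5, 0.4, 0.2, 0.0, 1.0]])
+ # >>> _probiou_loss(box, box)
+ # tensor([0.])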
+
+ device = logits.device
+ B, Q, C = logits.shape
+ # Build targets
+ targets = self.build_target(target)
+
+ total_cls = torch.tensor(0.0, device=device)
+ total_box = torch.tensor(0.0, device=device)
+ total_rot = torch.tensor(0.0, device=device)
+
+ for b in range(B):
+ pred_logits = logits[b]
+ pred_boxes_b = pred_boxes[b]
+
+ tgt_boxes = torch.as_tensor(
+ targets[b]["boxes"],
+ device=device,
+ dtype=pred_boxes.dtype,
+ )
+ tgt_cls = torch.as_tensor(
+ targets[b]["labels"],
+ device=device,
+ dtype=torch.long,
+ )
+
+ num_gt = len(tgt_cls)
+ if num_gt == 0:
+ target_onehot = torch.zeros_like(pred_logits)
+ total_cls += _sigmoid_focal_loss(pred_logits, target_onehot)
+ continue
+
+ pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1)
+ tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1)
+
+ with torch.no_grad():
+ cls_prob = pred_logits.sigmoid()
+ cost_cls = -torch.log(cls_prob[:, tgt_cls].clamp(min=1e-6))
+ cost_l1 = torch.cdist(
+ pred_boxes_b[:, :4],
+ tgt_boxes[:, :4],
+ p=1,
+ )
+ cost_rot = 1 - (pred_rot @ tgt_rot.T).abs()
+ total_cost = 2.0 * cost_cls + 5.0 * cost_l1 + 2.0 * cost_rot
+ matching_matrix = torch.zeros(
+ (Q, num_gt),
+ dtype=torch.bool,
+ device=device,
+ )
+
+ # L1 cost as a cheap IoU proxy: low-cost query-gt pairs count as better overlaps
+ iou_like = 1 - cost_l1
+ dynamic_k = (iou_like.sum(0).int() + 1).clamp(min=5, max=20)
+
+ for gt_idx in range(num_gt):
+ _, candidate_idx = torch.topk(-total_cost[:, gt_idx], k=int(dynamic_k[gt_idx].item()))
+ matching_matrix[candidate_idx, gt_idx] = True
+
+ # resolve duplicate matches
+ multiple_match_mask = matching_matrix.sum(1) > 1
+
+ if multiple_match_mask.any():
+ duplicate_idx = multiple_match_mask.nonzero(as_tuple=False).squeeze(1)
+ min_cost_idx = total_cost[duplicate_idx].argmin(dim=1)
+ # Set all matches to False for the duplicate indices,
+ # then set the match with the lowest cost to True
+ matching_matrix[duplicate_idx] = False
+ matching_matrix[duplicate_idx, min_cost_idx] = True
+
+ pos_idx, gt_indices = matching_matrix.nonzero(as_tuple=True)
+
+ target_onehot = torch.zeros_like(pred_logits)
+ if len(pos_idx) > 0:
+ target_onehot[pos_idx, tgt_cls[gt_indices]] = 1
+
+ total_cls += _sigmoid_focal_loss(pred_logits, target_onehot)
+
+ if len(pos_idx) == 0:
+ continue
+
+ pred_sel = pred_boxes_b[pos_idx]
+ tgt_sel = tgt_boxes[gt_indices]
+ # L1 loss on (cx, cy, w, h)
+ l1_loss = F.smooth_l1_loss(pred_sel[:, :4], tgt_sel[:, :4])
+ # ProbIoU loss on the whole box (including rotation)
+ probiou_loss = _probiou_loss(pred_sel, tgt_sel).mean()
+ total_box += 2.0 * l1_loss + 0.5 * probiou_loss
+ # Rotation loss
+ cos_sim = (pred_rot[pos_idx] * tgt_rot[gt_indices]).sum(-1).abs()
+ rot_loss = (1 - cos_sim).mean()
+ total_rot += 0.5 * rot_loss
+ # Average the loss over the batch
+ return (total_cls + total_box + total_rot) / B
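+ # Toy walk-through of the assignment above (assumed: 3 queries, 2 ground
+ # truths, k=1). Query 2 is cheapest for both ground truths, so the duplicate
+ # is resolved by keeping only its lowest-cost match (gt 0 here):
+ # >>> cost = torch.tensor([[0.10, 0.90], [0.80, 0.20], [0.05, 0.15]])
+ # >>> torch.topk(-cost[:, 0], k=1).indices, torch.topk(-cost[:, 1], k=1).indices
+ # (tensor([2]), tensor([2]))
+ # >>> cost[2].argmin()  # query 2 keeps gt 0; gt 1 stays unmatched in this toy
+ # tensor(0)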
+
+
+def _lw_detr(
+ arch: str,
+ pretrained: bool,
+ backbone_fn: Callable[[bool], nn.Module],
+ ignore_keys: list[str] | None = None,
+ **kwargs: Any,
+) -> LWDETR:
+ # Patch the config
+ kwargs["class_names"] = kwargs.get("class_names", default_cfgs[arch].get("class_names", []))
+
+ _cfg = deepcopy(default_cfgs[arch])
+ _cfg["class_names"] = kwargs["class_names"]
+ kwargs.pop("class_names")
+
+ # Build the feature extractor
+ backbone = backbone_fn( # type: ignore[call-arg]
+ False,
+ include_top=False,
+ input_shape=default_cfgs[arch]["input_shape"],
+ patch_size=kwargs.get("patch_size", (16, 16)),
+ )
+ feat_extractor = LWDETRBackbone(encoder_fn=backbone)
+
+ # Build the model
+ model = LWDETR(
+ feat_extractor,
+ cfg=_cfg,
+ class_names=_cfg["class_names"],
+ **kwargs,
+ )
+ # Load pretrained parameters
+ if pretrained:
+ # If the requested class names differ from those of the pretrained checkpoint,
+ # drop the classification head weights so they are re-initialized
+ _ignore_keys = ignore_keys if _cfg["class_names"] != default_cfgs[arch].get("class_names") else None
+ model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+ return model
+
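+# NOTE: passing custom class names resizes the classification head, so the
+# builder drops the pretrained head weights via `ignore_keys`. Sketch with
+# assumed, purely illustrative class names:
+# >>> model = lw_detr_s(pretrained=True, class_names=["title", "table", "figure"])
+# >>> model.cfg["class_names"]
+# ['title', 'table', 'figure']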
+
+def lw_detr_s(pretrained: bool = False, **kwargs: Any) -> LWDETR:
+ """LW-DETR as described in `"LW-DETR: A Transformer Replacement to YOLO for Real-Time Detection"
+ `_.
+
+ >>> import torch
+ >>> from doctr.models import lw_detr_s
+ >>> model = lw_detr_s(pretrained=True).eval()
+ >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on our layout prediction dataset
+ **kwargs: keyword arguments of the LWDETR architecture
+
+ Returns:
+ layout prediction architecture
+ """
+ return _lw_detr(
+ "lw_detr_s",
+ pretrained,
+ vit_det_s,
+ ignore_keys=[
+ "class_embed.weight",
+ "class_embed.bias",
+ *[f"enc_out_class_embed.{i}.weight" for i in range(kwargs.get("group_detr", 13))],
+ *[f"enc_out_class_embed.{i}.bias" for i in range(kwargs.get("group_detr", 13))],
+ ],
+ **kwargs,
+ )
+
+
+def lw_detr_m(pretrained: bool = False, **kwargs: Any) -> LWDETR:
+ """LW-DETR as described in `"LW-DETR: A Transformer Replacement to YOLO for Real-Time Detection"
+ `_.
+
+ >>> import torch
+ >>> from doctr.models import lw_detr_m
+ >>> model = lw_detr_m(pretrained=True).eval()
+ >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+ >>> out = model(input_tensor)
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on our layout prediction dataset
+ **kwargs: keyword arguments of the LWDETR architecture
+
+ Returns:
+ layout prediction architecture
+ """
+ return _lw_detr(
+ "lw_detr_m",
+ pretrained,
+ vit_det_m,
+ ignore_keys=[
+ "class_embed.weight",
+ "class_embed.bias",
+ *[f"enc_out_class_embed.{i}.weight" for i in range(kwargs.get("group_detr", 13))],
+ *[f"enc_out_class_embed.{i}.bias" for i in range(kwargs.get("group_detr", 13))],
+ ],
+ **kwargs,
+ )
diff --git a/doctr/models/layout/predictor/__init__.py b/doctr/models/layout/predictor/__init__.py
new file mode 100644
index 0000000000..e3c861310c
--- /dev/null
+++ b/doctr/models/layout/predictor/__init__.py
@@ -0,0 +1 @@
+from .pytorch import *
diff --git a/doctr/models/layout/predictor/pytorch.py b/doctr/models/layout/predictor/pytorch.py
new file mode 100644
index 0000000000..edd33ded6b
--- /dev/null
+++ b/doctr/models/layout/predictor/pytorch.py
@@ -0,0 +1,82 @@
+# Copyright (C) 2021-2026, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any
+
+import numpy as np
+import torch
+from torch import nn
+
+from doctr.models.detection._utils import _remove_padding
+from doctr.models.preprocessor import PreProcessor
+from doctr.models.utils import set_device_and_dtype
+
+__all__ = ["LayoutPredictor"]
+
+
+class LayoutPredictor(nn.Module):
+ """Implements an object able to localize layout elements in a document
+
+ Args:
+ pre_processor: transform inputs for easier batched model inference
+ model: core layout architecture
+ """
+
+ def __init__(
+ self,
+ pre_processor: PreProcessor,
+ model: nn.Module,
+ ) -> None:
+ super().__init__()
+ self.pre_processor = pre_processor
+ self.model = model.eval()
+
+ @torch.inference_mode()
+ def forward(
+ self,
+ pages: list[np.ndarray],
+ **kwargs: Any,
+ ) -> list[dict[str, np.ndarray]]:
+ # Extract parameters from the preprocessor
+ preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
+ symmetric_pad = self.pre_processor.resize.symmetric_pad
+ assume_straight_pages = self.model.assume_straight_pages
+ # This flag is needed to return the padding mask from the preprocessor
+ self.pre_processor.resize.return_padding_mask = True
+
+ # Dimension check
+ if any(page.ndim != 3 for page in pages):
+ raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
+
+ processed_batches = self.pre_processor(pages)
+ _params = next(self.model.parameters())
+ self.model, processed_batches = set_device_and_dtype( # type: ignore[assignment]
+ self.model, processed_batches, _params.device, _params.dtype
+ )
+ predicted_batches = [
+ self.model(batch[0], batch[1], return_preds=True, return_model_output=True, **kwargs)
+ for batch in processed_batches
+ ]
+ # remap idx to class names
+ class_names = [
+ [self.model.class_names[int(i)] for i in pred[0]] # type: ignore[index]
+ for batch in predicted_batches
+ for pred in batch["preds"]
+ ]
+ boxes = [pred[1] for batch in predicted_batches for pred in batch["preds"]]
+ scores = [pred[2] for batch in predicted_batches for pred in batch["preds"]]
+
+ # Remove padding from loc predictions
+ preds = _remove_padding(
+ pages,
+ [{"pred": box} for box in boxes],
+ preserve_aspect_ratio=preserve_aspect_ratio,
+ symmetric_pad=symmetric_pad,
+ assume_straight_pages=assume_straight_pages, # type: ignore[arg-type]
+ )
+ return [
+ {"class_names": class_name, "boxes": pred["pred"], "scores": score} # type: ignore[dict-item]
+ for class_name, pred, score in zip(class_names, preds, scores)
+ ]
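+
+
+# Minimal usage sketch (mirrors what `layout_predictor` in the zoo wires up;
+# the 1024x1024 input size and batch size are assumptions):
+# >>> import numpy as np
+# >>> from doctr.models import lw_detr_s
+# >>> from doctr.models.preprocessor import PreProcessor
+# >>> predictor = LayoutPredictor(PreProcessor((1024, 1024), batch_size=1), lw_detr_s(pretrained=False))
+# >>> page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
+# >>> out = predictor([page])  # list of dicts with "class_names", "boxes", "scores"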
diff --git a/doctr/models/layout/zoo.py b/doctr/models/layout/zoo.py
new file mode 100644
index 0000000000..52d7c54666
--- /dev/null
+++ b/doctr/models/layout/zoo.py
@@ -0,0 +1,88 @@
+# Copyright (C) 2021-2026, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any
+
+from doctr.models.utils import _CompiledModule
+
+from .. import layout
+from ..preprocessor import PreProcessor
+from .predictor import LayoutPredictor
+
+__all__ = ["layout_predictor"]
+
+ARCHS: list[str] = ["lw_detr_s", "lw_detr_m"]
+
+
+def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> LayoutPredictor:
+ if isinstance(arch, str):
+ if arch not in ARCHS:
+ raise ValueError(f"unknown architecture '{arch}'")
+
+ _model = layout.__dict__[arch](
+ pretrained=pretrained,
+ assume_straight_pages=assume_straight_pages,
+ )
+ else:
+ # torch.compile wraps models in a dedicated type, so allow it alongside LWDETR
+ allowed_archs = [layout.LWDETR, _CompiledModule]
+
+ if not isinstance(arch, tuple(allowed_archs)):
+ raise ValueError(f"unknown architecture: {type(arch)}")
+ _model = arch
+
+ kwargs.pop("pretrained_backbone", None)
+
+ kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
+ kwargs["std"] = kwargs.get("std", _model.cfg["std"])
+ kwargs["batch_size"] = kwargs.get("batch_size", 2)
+ predictor = LayoutPredictor(
+ PreProcessor(_model.cfg["input_shape"][1:], **kwargs),
+ _model,
+ )
+ return predictor
+
+
+def layout_predictor(
+ arch: Any = "lw_detr_s",
+ pretrained: bool = False,
+ assume_straight_pages: bool = True,
+ preserve_aspect_ratio: bool = True,
+ symmetric_pad: bool = True,
+ batch_size: int = 2,
+ **kwargs: Any,
+) -> LayoutPredictor:
+ """Layout prediction architecture.
+
+ >>> import numpy as np
+ >>> from doctr.models import layout_predictor
+ >>> model = layout_predictor(arch='lw_detr_s', pretrained=True)
+ >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
+ >>> out = model([input_page])
+
+ Args:
+ arch: name of the architecture or model itself to use (e.g. 'lw_detr_s')
+ pretrained: If True, returns a model pre-trained on our layout prediction dataset
+ assume_straight_pages: If True, fit straight boxes to the page
+ preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
+ running the detection model on it
+ symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
+ batch_size: number of samples the model processes in parallel
+ **kwargs: optional keyword arguments passed to the architecture
+
+ Returns:
+ Layout predictor
+ """
+ return _predictor(
+ arch=arch,
+ pretrained=pretrained,
+ assume_straight_pages=assume_straight_pages,
+ preserve_aspect_ratio=preserve_aspect_ratio,
+ symmetric_pad=symmetric_pad,
+ batch_size=batch_size,
+ **kwargs,
+ )
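+
+
+# The `arch` argument also accepts an already-built model, including a
+# torch.compile'd one (sketch):
+# >>> import torch
+# >>> from doctr.models import layout_predictor, lw_detr_s
+# >>> compiled = torch.compile(lw_detr_s(pretrained=False).eval())
+# >>> predictor = layout_predictor(compiled)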
diff --git a/tests/common/test_models_layout.py b/tests/common/test_models_layout.py
new file mode 100644
index 0000000000..61ac2fc116
--- /dev/null
+++ b/tests/common/test_models_layout.py
@@ -0,0 +1,66 @@
+import numpy as np
+import pytest
+
+from doctr.models.layout.lw_detr.base import LWDETRPostProcessor
+
+
+def test_lwdetr_postprocessor():
+ postprocessor = LWDETRPostProcessor(
+ num_classes=5,
+ score_thresh=0.2,
+ iou_thresh=0.5,
+ topk=50,
+ assume_straight_pages=True,
+ )
+
+ r_postprocessor = LWDETRPostProcessor(
+ num_classes=5,
+ score_thresh=0.2,
+ iou_thresh=0.5,
+ topk=50,
+ assume_straight_pages=False,
+ )
+
+ # Input validation
+ with pytest.raises(Exception):
+ postprocessor(np.random.rand(2, 20, 5).astype(np.float32), np.random.rand(2, 20, 5).astype(np.float32))
+
+ # Forward pass
+ batch_size, num_queries = 2, 20
+ logits = np.random.randn(batch_size, num_queries, 6).astype(np.float32)
+ boxes = np.random.rand(batch_size, num_queries, 6).astype(np.float32)
+
+ out = postprocessor(logits, boxes)
+ r_out = r_postprocessor(logits, boxes)
+
+ # Batch composition
+ assert isinstance(out, list)
+ assert len(out) == batch_size
+
+ assert isinstance(r_out, list)
+ assert len(r_out) == batch_size
+
+ assert all(isinstance(sample, tuple) and len(sample) == 3 for sample in out)
+
+ assert all(isinstance(sample, tuple) and len(sample) == 3 for sample in r_out)
+
+ labels, bboxes, scores = out[0]
+ r_labels, r_bboxes, r_scores = r_out[0]
+
+ assert isinstance(labels, list)
+ assert isinstance(scores, list)
+ assert isinstance(bboxes, np.ndarray)
+
+ # straight pages: (K, 4)
+ assert bboxes.ndim == 2
+ assert bboxes.shape[1] == 4
+
+ # rotated pages: (K, 4, 2)
+ assert isinstance(r_bboxes, np.ndarray)
+ assert r_bboxes.ndim == 3
+ assert r_bboxes.shape[2] == 2
+
+ # Score / label validity
+ assert all(isinstance(s, float) for s in scores)
+ assert all(s >= 0 for s in scores)
+ assert len(labels) == len(scores)
diff --git a/tests/pytorch/test_models_factory.py b/tests/pytorch/test_models_factory.py
index 1497ba0f12..db9aed1cdf 100644
--- a/tests/pytorch/test_models_factory.py
+++ b/tests/pytorch/test_models_factory.py
@@ -34,6 +34,7 @@ def test_push_to_hf_hub():
["mobilenet_v3_small", "classification", "Felix92/doctr-dummy-torch-mobilenet-v3-small"],
["mobilenet_v3_large", "classification", "Felix92/doctr-dummy-torch-mobilenet-v3-large"],
["vit_s", "classification", "Felix92/doctr-dummy-torch-vit-s"],
+ ["vit_det_s", "classification", "Felix92/doctr-dummy-torch-vit-det-s"],
["textnet_tiny", "classification", "Felix92/doctr-dummy-torch-textnet-tiny"],
["db_resnet34", "detection", "Felix92/doctr-dummy-torch-db-resnet34"],
["db_resnet50", "detection", "Felix92/doctr-dummy-torch-db-resnet50"],
@@ -49,6 +50,7 @@ def test_push_to_hf_hub():
["vitstr_small", "recognition", "Felix92/doctr-dummy-torch-vitstr-small"],
["parseq", "recognition", "Felix92/doctr-dummy-torch-parseq"],
["viptr_tiny", "recognition", "Felix92/doctr-dummy-torch-viptr-tiny"],
+ ["lw_detr_s", "layout", "Felix92/doctr-dummy-torch-lw-detr-s"],
],
)
def test_models_huggingface_hub(arch_name, task_name, dummy_model_id, tmpdir):
diff --git a/tests/pytorch/test_models_layout.py b/tests/pytorch/test_models_layout.py
new file mode 100644
index 0000000000..aa7e687114
--- /dev/null
+++ b/tests/pytorch/test_models_layout.py
@@ -0,0 +1,182 @@
+import os
+import tempfile
+
+import numpy as np
+import onnxruntime
+import pytest
+import torch
+
+from doctr.io import DocumentFile
+from doctr.models import layout
+from doctr.models.layout.predictor import LayoutPredictor
+from doctr.models.utils import _CompiledModule, export_model_to_onnx
+
+
+@pytest.mark.parametrize("train_mode", [True, False])
+@pytest.mark.parametrize(
+ "arch_name, input_shape",
+ [
+ ["lw_detr_s", (3, 512, 512)],
+ ["lw_detr_m", (3, 1024, 1024)],
+ ],
+)
+def test_layout_models(arch_name, input_shape, train_mode):
+ batch_size = 2
+ model = layout.__dict__[arch_name](pretrained=True)
+ model = model.train() if train_mode else model.eval()
+ assert isinstance(model, torch.nn.Module)
+ input_tensor = torch.rand((batch_size, *input_shape))
+ input_masks = torch.ones((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool)
+ target = []
+ for _ in range(batch_size):
+ num_boxes = 5
+
+ class_ids = torch.randint(0, 10, (num_boxes,)).tolist()
+
+ # random boxes in normalized-ish space
+ boxes = []
+ for _ in range(num_boxes):
+ cx, cy = torch.rand(2) * 0.8 + 0.1
+ w, h = torch.rand(2) * 0.2 + 0.05
+
+ x1, y1 = cx - w / 2, cy - h / 2
+ x2, y2 = cx + w / 2, cy - h / 2
+ x3, y3 = cx + w / 2, cy + h / 2
+ x4, y4 = cx - w / 2, cy + h / 2
+
+ boxes.append([
+ [x1, y1],
+ [x2, y2],
+ [x3, y3],
+ [x4, y4],
+ ])
+
+ target.append((class_ids, torch.tensor(boxes, dtype=torch.float32)))
+ if torch.cuda.is_available():
+ model.cuda()
+ input_tensor = input_tensor.cuda()
+ input_masks = input_masks.cuda()
+ out = model(input_tensor, input_masks, target, return_model_output=True, return_preds=not train_mode)
+ assert isinstance(out, dict)
+ assert len(out) == (3 if not train_mode else 2)
+ # Check logits
+ assert "logits" in out
+ assert isinstance(out["logits"], torch.Tensor)
+
+ # Check Preds
+ if not train_mode:
+ for results in out["preds"]:
+ assert isinstance(results, tuple) and len(results) == 3
+ assert isinstance(results[0], list) and all(isinstance(idx, int) for idx in results[0])
+ assert isinstance(results[1], np.ndarray) and results[1].shape == (len(results[0]), 4)
+ assert isinstance(results[2], list) and all(isinstance(score, float) for score in results[2])
+ # Check class idxs are in the model's num_classes
+ assert all(0 <= idx < model.num_classes for idx in results[0])
+ # Check scores are between 0 and 1
+ assert all(0 <= score <= 1 for score in results[2])
+ # Check that the number of boxes, labels and scores are the same
+ assert len(results[0]) == len(results[1]) == len(results[2])
+ # Check that boxes are in the range [0, 1]
+ assert all((box >= 0).all() and (box <= 1).all() for box in results[1])
+ # Check loss
+ assert isinstance(out["loss"], torch.Tensor)
+ assert hasattr(model, "from_pretrained")
+
+
+@pytest.mark.parametrize(
+ "arch_name",
+ [
+ "lw_detr_s",
+ "lw_detr_m",
+ ],
+)
+def test_layout_zoo(arch_name):
+ # Model
+ predictor = layout.zoo.layout_predictor(arch_name, pretrained=False)
+ predictor.model.eval()
+ # object check
+ assert isinstance(predictor, LayoutPredictor)
+ input_tensor = np.random.rand(2, 1024, 1024, 3).astype(np.float32)
+ if torch.cuda.is_available():
+ predictor.model.cuda()
+
+ with torch.no_grad():
+ out = predictor(input_tensor)
+ assert isinstance(out, list) and len(out) == 2
+ assert all(isinstance(sample, dict) for sample in out)
+ assert all("class_names" in sample and "boxes" in sample and "scores" in sample for sample in out)
+ assert all(isinstance(sample["class_names"], list) for sample in out)
+ assert all(isinstance(sample["boxes"], np.ndarray) for sample in out)
+ assert all(isinstance(sample["scores"], list) for sample in out)
+ assert all(sample["boxes"].shape[1] == 4 for sample in out)
+ assert all(len(sample["class_names"]) == len(sample["scores"]) == sample["boxes"].shape[0] for sample in out)
+ assert all(all(isinstance(score, float) and 0 <= score <= 1 for score in sample["scores"]) for sample in out)
+
+
+@pytest.mark.parametrize(
+ "arch_name, input_shape",
+ [
+ ["lw_detr_s", (3, 512, 512)],
+ ["lw_detr_m", (3, 512, 512)],
+ ],
+)
+def test_models_onnx_export(arch_name, input_shape):
+ # Model
+ batch_size = 2
+ model = layout.__dict__[arch_name](pretrained=True, exportable=True).eval()
+ dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
+ dummy_masks = torch.ones((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool)
+ pt = model(dummy_input, dummy_masks)
+ pt_logits = pt["logits"].detach().cpu().numpy()
+ pt_boxes = pt["pred_boxes"].detach().cpu().numpy()
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Export
+ model_path = export_model_to_onnx(
+ model, model_name=os.path.join(tmpdir, "model"), dummy_input=(dummy_input, dummy_masks)
+ )
+ assert os.path.exists(model_path)
+ # Inference
+ ort_session = onnxruntime.InferenceSession(
+ os.path.join(tmpdir, "model.onnx"), providers=["CPUExecutionProvider"]
+ )
+ ort_outs = ort_session.run(
+ ["logits", "pred_boxes"], {"input": dummy_input.numpy(), "masks": dummy_masks.numpy()}
+ )
+
+ assert isinstance(ort_outs, list) and len(ort_outs) == 2
+ # Check boxes shape
+ assert ort_outs[0].shape == pt_logits.shape
+ assert ort_outs[1].shape == pt_boxes.shape
+ # Check that the output is close to the PyTorch output - only warn if not close
+ try:
+ assert np.allclose(pt_logits, ort_outs[0], atol=1e-4)
+ assert np.allclose(pt_boxes, ort_outs[1], atol=1e-4)
+ except AssertionError:
+ pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}")
+
+
+@pytest.mark.parametrize(
+ "arch_name",
+ [
+ "lw_detr_s",
+ "lw_detr_m",
+ ],
+)
+def test_torch_compiled_models(arch_name, mock_payslip):
+ doc = DocumentFile.from_images([mock_payslip])
+ predictor = layout.zoo.layout_predictor(arch_name, pretrained=True)
+ assert isinstance(predictor, LayoutPredictor)
+ out = predictor(doc)
+
+ # Compile the model
+ compiled_model = torch.compile(layout.__dict__[arch_name](pretrained=True).eval())
+ assert isinstance(compiled_model, _CompiledModule)
+ compiled_predictor = layout.zoo.layout_predictor(compiled_model)
+ compiled_out = compiled_predictor(doc)
+
+ # Compare that outputs are close
+ assert len(out) == len(compiled_out) == 1
+ # TODO: Enable if the model has a pretrained version
+ # assert out[0]["class_names"] == compiled_out[0]["class_names"]
+ # assert np.allclose(out[0]["boxes"], compiled_out[0]["boxes"], atol=1e-4)
+ # assert np.allclose(out[0]["scores"], compiled_out[0]["scores"], atol=1e-4)