Skip to content
Merged
8 changes: 8 additions & 0 deletions docs/source/modules/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ doctr.models.detection
.. autofunction:: doctr.models.detection.detection_predictor


doctr.models.layout
-------------------

.. autofunction:: doctr.models.layout.lw_detr_s

.. autofunction:: doctr.models.layout.lw_detr_m


doctr.models.recognition
------------------------

Expand Down
1 change: 1 addition & 0 deletions doctr/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .classification import *
from .detection import *
from .recognition import *
from .layout import *
from .zoo import *
from .factory import *
22 changes: 16 additions & 6 deletions doctr/models/factory/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"classification": models.classification.zoo.ARCHS + models.classification.zoo.ORIENTATION_ARCHS,
"detection": models.detection.zoo.ARCHS,
"recognition": models.recognition.zoo.ARCHS,
"layout": models.layout.zoo.ARCHS,
}


Expand Down Expand Up @@ -96,14 +97,19 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #

if run_config is None and arch is None:
raise ValueError("run_config or arch must be specified")
if task not in ["classification", "detection", "recognition"]:
raise ValueError("task must be one of classification, detection, recognition")
if task not in ["classification", "detection", "recognition", "layout"]:
raise ValueError("task must be one of classification, detection, recognition, layout")

# default readme
readme = textwrap.dedent(
f"""

f"""---
language: en
tags:
- ocr
- pytorch
- doctr
- {task}
---


<p align="center">
Expand Down Expand Up @@ -161,7 +167,8 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #

# Create repository
api = HfApi()
api.create_repo(model_name, token=get_token(), exist_ok=False)
repo_url = api.create_repo(model_name, token=get_token(), repo_type="model", exist_ok=False)
full_repo_id = repo_url.repo_id

# Save model files to a temporary directory
with tempfile.TemporaryDirectory() as tmp_dir:
Expand All @@ -172,7 +179,8 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
# Upload all files to the hub
api.upload_folder(
folder_path=tmp_dir,
repo_id=model_name,
repo_id=full_repo_id,
repo_type="model",
commit_message=commit_message,
token=get_token(),
)
Expand Down Expand Up @@ -208,6 +216,8 @@ def from_hub(repo_id: str, **kwargs: Any):
model = models.detection.__dict__[arch](pretrained=False)
elif task == "recognition":
model = models.recognition.__dict__[arch](pretrained=False, input_shape=cfg["input_shape"], vocab=cfg["vocab"])
elif task == "layout":
model = models.layout.__dict__[arch](pretrained=False, class_names=cfg["class_names"])

# update model cfg
model.cfg = cfg
Expand Down
2 changes: 2 additions & 0 deletions doctr/models/layout/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .zoo import *

Check warning on line 1 in doctr/models/layout/__init__.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/layout/__init__.py#L1

'.zoo.*' imported but unused (F401)
from .lw_detr import *
1 change: 1 addition & 0 deletions doctr/models/layout/lw_detr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .pytorch import *

Check warning on line 1 in doctr/models/layout/lw_detr/__init__.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/layout/lw_detr/__init__.py#L1

'.pytorch.*' imported but unused (F401)
254 changes: 254 additions & 0 deletions doctr/models/layout/lw_detr/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
# Copyright (C) 2021-2026, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any

import cv2
import numpy as np

from doctr.models.core import BaseModel

__all__ = ["_LWDETR", "LWDETRPostProcessor"]


class LWDETRPostProcessor:
    """Implements a post processor for LW-DETR model

    For every sample of a batch, the raw class logits are turned into probabilities with a softmax,
    class index 0 is skipped when picking the best class, the `topk` highest scoring predictions are
    kept, low confidence predictions are dropped, the oriented boxes are decoded to polygons and a
    rotated NMS (which does not look at the labels) removes overlapping polygons.

    Args:
        num_classes: number of classes
        score_thresh: confidence threshold for filtering predictions
        iou_thresh: IoU threshold for NMS
        topk: number of top predictions to keep before NMS
        assume_straight_pages: whether the pages are assumed to be straight (i.e., no rotation)
    """

    def __init__(
        self,
        num_classes: int,
        score_thresh: float = 0.3,
        iou_thresh: float = 0.5,
        topk: int = 300,
        assume_straight_pages: bool = True,
    ):
        self.num_classes = num_classes
        self.score_thresh = score_thresh
        self.iou_thresh = iou_thresh
        self.topk = topk
        self.assume_straight_pages = assume_straight_pages

    def _decode_boxes(self, boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Decode the predicted boxes from OBB format to polygon format

        Args:
            boxes: array of predicted boxes in OBB format (N, 6) (cx, cy, w, h, sin(theta), cos(theta))

        Returns:
            tuple of (polys, angles) where polys is an array of decoded polygons (N, 4, 2)
            and angles is an array of angles in radians (N,)
        """
        cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        sin, cos = boxes[:, 4], boxes[:, 5]

        angles = np.arctan2(sin, cos)
        # cv2.boxPoints expects the rotation in degrees
        degrees = np.degrees(angles)

        polys = [
            cv2.boxPoints(((float(x), float(y)), (float(bw), float(bh)), float(deg)))
            for x, y, bw, bh, deg in zip(cx, cy, w, h, degrees)
        ]

        # The reshape guarantees the documented (N, 4, 2) shape even when N == 0
        # (np.asarray([]) alone would yield shape (0,))
        return np.asarray(polys, dtype=np.float32).reshape(-1, 4, 2), angles

    def _iou(self, poly1: np.ndarray, poly2: np.ndarray) -> float:
        """Compute the IoU between two polygons

        Args:
            poly1: first polygon (4, 2)
            poly2: second polygon (4, 2)

        Returns:
            IoU between the two polygons
        """
        # Cast once so that both OpenCV calls receive the same float32 contours
        p1 = np.asarray(poly1, dtype=np.float32)
        p2 = np.asarray(poly2, dtype=np.float32)

        # First element of the returned tuple is the intersection area
        inter = cv2.intersectConvexConvex(p1, p2)[0]

        if inter <= 0:
            return 0.0

        union = cv2.contourArea(p1) + cv2.contourArea(p2) - inter

        # The epsilon avoids a division by zero for degenerate polygons
        return float(inter / (union + 1e-6))

    def _nms(self, polys: np.ndarray, scores: np.ndarray) -> list[int]:
        """Perform NMS on the predicted polygons

        Args:
            polys: array of predicted polygons (N, 4, 2)
            scores: array of predicted scores (N,)

        Returns:
            list of indices of the polygons to keep after NMS, sorted by decreasing score
        """
        order = np.argsort(scores)[::-1]
        keep: list[int] = []

        while order.size > 0:
            # The best remaining candidate is always kept
            best = int(order[0])
            keep.append(best)

            rest = order[1:]
            if rest.size == 0:
                break

            ious = np.array([self._iou(polys[best], polys[j]) for j in rest])

            # Drop every candidate overlapping too much with the one just kept
            order = rest[ious < self.iou_thresh]

        return keep

    def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int], np.ndarray, list[float]]]:
        """Post-process the raw model outputs of a batch

        Args:
            logits: class logits, indexed as (batch, query, class) where class index 0 is ignored
                when selecting the predicted label
            boxes: predicted boxes in OBB format, indexed as (batch, query, 6)

        Returns:
            one tuple (labels, boxes, scores) per sample, where labels are the selected class indices
            (starting at 1), boxes is a float32 array of shape (K, 4) holding (xmin, ymin, xmax, ymax)
            if `assume_straight_pages` is True and of shape (K, 4, 2) holding polygons otherwise,
            and scores are the matching confidences
        """
        logits = np.asarray(logits)
        boxes = np.asarray(boxes)

        results: list[tuple[list[int], np.ndarray, list[float]]] = []

        for b in range(boxes.shape[0]):
            # Numerically stable softmax over the class dimension
            exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True))
            prob = exp / exp.sum(axis=-1, keepdims=True)

            # Index 0 is skipped, hence the +1 to recover the class index
            scores = prob[:, 1:].max(axis=-1)
            labels = prob[:, 1:].argmax(axis=-1) + 1

            # Keep only topk predictions before NMS
            if self.topk is not None and len(scores) > self.topk:
                idxs = np.argsort(scores)[::-1][: self.topk]
            else:
                idxs = np.arange(len(scores))

            mask = scores[idxs] > self.score_thresh
            idxs = idxs[mask]

            scores_b = scores[idxs]
            labels_b = labels[idxs]

            # Both helpers handle the case where nothing passed the threshold
            polys, _ = self._decode_boxes(boxes[b][idxs])
            keep = np.asarray(self._nms(polys, scores_b), dtype=np.intp)

            kept_polys = polys[keep]
            if self.assume_straight_pages:
                # Axis-aligned bounding box enclosing each polygon: (xmin, ymin, xmax, ymax)
                final_boxes_arr = np.concatenate(
                    [kept_polys.min(axis=1), kept_polys.max(axis=1)],
                    axis=1,
                ).astype(np.float32)
            else:
                final_boxes_arr = kept_polys.astype(np.float32)

            results.append((
                [int(label) for label in labels_b[keep]],
                final_boxes_arr,
                [float(score) for score in scores_b[keep]],
            ))

        return results


class _LWDETR(BaseModel):
    """LW-DETR as described in `"LW-DETR: A Transformer Replacement to YOLO for Real-Time Detection"
    <https://arxiv.org/pdf/2406.03459v1>`_.
    """

    def build_target(
        self,
        target: list[tuple[list[int], np.ndarray]],
    ) -> list[dict[str, Any]]:
        """Build the target for LW-DETR training

        Boxes whose width or height is not larger than 1e-3 are discarded together with their label.

        Args:
            target: list of tuples (class_ids, boxes) where class_ids is a list of class ids for the boxes
                and boxes is an array of shape (num_boxes, 8) containing the coordinates of the 4 corners of the box
                in the format (x1, y1, x2, y2, x3, y3, x4, y4)

        Returns:
            list of dictionaries with keys "boxes" and "labels" where "boxes" is an array of shape (num_boxes, 6)
            containing the box parameters in OBB format (cx, cy, w, h, sin(theta), cos(theta))
            and "labels" is an array of shape (num_boxes,) containing the class labels shifted by one
            (0 is reserved for the background)
        """

        def _quad_to_obb(poly: np.ndarray) -> np.ndarray:
            """Convert a (4, 2) quadrangle to (cx, cy, w, h, sin(theta), cos(theta))"""
            p1, p2, p3, p4 = poly

            cx, cy = np.mean(poly, axis=0)

            # Width / height are the mean lengths of the two pairs of opposite edges
            w = (np.linalg.norm(p2 - p1) + np.linalg.norm(p3 - p4)) / 2
            h = (np.linalg.norm(p3 - p2) + np.linalg.norm(p4 - p1)) / 2

            # Orientation of the first edge (p1 -> p2): arctan2(dy, dx)
            dx, dy = p2 - p1
            theta = np.arctan2(dy, dx)

            return np.array(
                [cx, cy, w, h, np.sin(theta), np.cos(theta)],
                dtype=np.float32,
            )

        targets: list[dict[str, Any]] = []

        for class_ids, boxes in target:
            boxes_all = []
            labels_all = []

            for cls_id, box in zip(np.asarray(class_ids), np.asarray(boxes)):
                obb = _quad_to_obb(box.reshape(4, 2))

                # Skip degenerate boxes (near-zero width or height)
                if obb[2] <= 1e-3 or obb[3] <= 1e-3:
                    continue

                boxes_all.append(obb)
                labels_all.append(cls_id + 1)  # background = 0

            targets.append({
                # The reshape keeps the (num_boxes, 6) layout when no box is left, whether the
                # sample was empty from the start or every box was filtered out as degenerate
                "boxes": np.asarray(boxes_all, dtype=np.float32).reshape(-1, 6),
                "labels": np.asarray(labels_all, dtype=np.int64),
            })

        return targets
1 change: 1 addition & 0 deletions doctr/models/layout/lw_detr/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .pytorch import *

Check notice on line 1 in doctr/models/layout/lw_detr/layers/__init__.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/layout/lw_detr/layers/__init__.py#L1

Missing docstring in public package (D104)
Loading
Loading