From 30e6686d8b516731593c988dba5a80850df0a091 Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 13 May 2026 11:51:08 +0200 Subject: [PATCH 1/9] [Feat] Add layout scripts and minor fixes --- .github/workflows/references.yml | 88 +++ .../using_doctr/custom_models_training.rst | 15 + docs/source/using_doctr/using_models.rst | 60 ++ doctr/datasets/datasets/pytorch.py | 12 +- doctr/datasets/layout.py | 6 +- doctr/models/layout/lw_detr/base.py | 5 +- doctr/models/layout/lw_detr/pytorch.py | 12 +- references/layout/README.md | 104 +++ references/layout/evaluate.py | 200 +++++ references/layout/latency.py | 59 ++ references/layout/train.py | 726 ++++++++++++++++++ references/layout/utils.py | 101 +++ tests/conftest.py | 2 +- tests/pytorch/test_datasets_pt.py | 25 +- 14 files changed, 1389 insertions(+), 26 deletions(-) create mode 100644 references/layout/README.md create mode 100644 references/layout/evaluate.py create mode 100644 references/layout/latency.py create mode 100644 references/layout/train.py create mode 100644 references/layout/utils.py diff --git a/.github/workflows/references.yml b/.github/workflows/references.yml index e4c9503136..9585eaac5a 100644 --- a/.github/workflows/references.yml +++ b/.github/workflows/references.yml @@ -251,3 +251,91 @@ jobs: pip install -e .[viz,html] --upgrade - name: Benchmark latency run: python references/detection/latency.py db_mobilenet_v3_large --it 5 --size 512 + + + train-layout-analysis: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.10"] + steps: + - uses: actions/checkout@v6 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v5 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('references/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}- + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[viz,html] --upgrade + pip install -r references/requirements.txt + - name: Download and extract toy set + run: | + wget https://github.com/mindee/doctr/releases/download/v0.3.1/toy_detection_set-bbbb4243.zip + sudo apt-get update && sudo apt-get install unzip -y + unzip toy_detection_set-bbbb4243.zip -d det_set + - name: Train for a short epoch + run: python references/layout/train.py lw_detr_s --train_path ./det_set --val_path ./det_set -b 2 --epochs 1 + + evaluate-layout-analysis: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.10"] + steps: + - uses: actions/checkout@v6 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v5 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[viz,html] --upgrade + pip install -r references/requirements.txt + - name: Evaluate layout analysis + run: python references/layout/evaluate.py lw_detr_s + + latency-layout-analysis: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.10"] + steps: + - uses: actions/checkout@v6 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v5 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[viz,html] --upgrade + - name: Benchmark latency + run: python references/layout/latency.py lw_detr_s --it 5 --size 512 diff --git a/docs/source/using_doctr/custom_models_training.rst b/docs/source/using_doctr/custom_models_training.rst index c67f6c2d70..9b28df0fbb 100644 --- a/docs/source/using_doctr/custom_models_training.rst +++ b/docs/source/using_doctr/custom_models_training.rst @@ -6,6 +6,7 @@ For details on the training process and the necessary data and data format, refe - `detection `_ - `recognition `_ +- `layout `_ If you’re looking for a lightweight yet efficient tool to annotate small amounts of data, especially tailored for docTR, check out the `docTR Labeling Tool `_. @@ -52,6 +53,20 @@ Load a custom recognition model trained on another vocabulary as the default one predictor = ocr_predictor(det_arch='linknet_resnet18', reco_arch=reco_model, pretrained=True) + +Load a custom layout analysis model trained on another set of classes as the default one: + +.. code:: python3 + + import torch + from doctr.models import layout_predictor, lw_detr_s + from doctr.datasets import VOCABS + + layout_model = lw_detr_s(pretrained=False, class_names=["class_name_1", "class_name_2", ...]) + layout_model.from_pretrained('') + + predictor = layout_predictor(layout_arch=layout_model, pretrained=True) + Load a custom trained KIE detection model: .. code:: python3 diff --git a/docs/source/using_doctr/using_models.rst b/docs/source/using_doctr/using_models.rst index b37434092e..18b3b5dab0 100644 --- a/docs/source/using_doctr/using_models.rst +++ b/docs/source/using_doctr/using_models.rst @@ -174,6 +174,66 @@ Recognition predictors out = model([dummy_img]) +Layout Analysis +--------------- + +The task consists of localizing and classifying visual elements in a given image. +This is a more general task than text detection, as it can be used to detect and classify any type of visual element in a document, such as tables, figures, headers, footers, etc. +Our latest layout models works with rotated and skewed documents! + +Available architectures +^^^^^^^^^^^^^^^^^^^^^^^ + +The following architectures are currently supported: + +* :py:meth:`lw_detr_s ` +* :py:meth:`lw_detr_m ` + +For a comprehensive comparison, we have compiled a detailed benchmark: + ++--------------------------------------------------+-----------------+---------------+------------------+-------------+--------------+--------------------+ +| | | | | | | | ++==================================================+=================+===============+==================+=============+==============+====================+ +| **Architecture** | **Input shape** | **# params** | **mAP@[.5:.95]** | **AP@[.5]** | **AP@[.75]** | **sec/it (B: 1)** | ++--------------------------------------------------+-----------------+---------------+------------------+-------------+--------------+--------------------+ +| lw_detr_s | (1024, 1024, 3) | 15.1 M | | | | 0.5 | ++--------------------------------------------------+-----------------+---------------+------------------+-------------+--------------+--------------------+ +| lw_detr_m | (1024, 1024, 3) | 29.5 M | | | | 0.7 | ++--------------------------------------------------+-----------------+---------------+------------------+-------------+--------------+--------------------+ + + +Explanations about the metrics being used are available in :ref:`metrics`. + +Seconds per iteration (with a batch size of 1) is computed after a warmup phase of 100 tensors, by measuring the average number of processed tensors per second over 1000 samples. Those results were obtained on a `11th Gen Intel(R) Core(TM) i7-11800H @ 2.30GHz`. + + +Layout predictors +^^^^^^^^^^^^^^^^^ + +:py:meth:`layout_predictor ` wraps your layout model to make it easily useable with your favorite deep learning framework seamlessly. + +.. code:: python3 + + import numpy as np + from doctr.models import layout_predictor + model = layout_predictor('lw_detr_s') + dummy_img = (255 * np.random.rand(800, 600, 3)).astype(np.uint8) + out = model([dummy_img]) + +You can pass specific boolean arguments to the predictor: +* `pretrained`: if you want to use a model that has been pretrained on a specific dataset, setting `pretrained=True` this will load the corresponding weights. If `pretrained=False`, which is the default, would otherwise lead to a random initialization and would lead to no/useless results. +* `assume_straight_pages`: if you work with straight documents only, it will fit straight bounding boxes to the text areas. +* `preserve_aspect_ratio`: if you want to preserve the aspect ratio of your documents while resizing before sending them to the model. +* `symmetric_pad`: if you choose to preserve the aspect ratio, it will pad the image symmetrically and not from the bottom-right. + +For instance, this snippet will instantiates a layout predictor able to detect text on rotated documents while preserving the aspect ratio: + +.. code:: python3 + + from doctr.models import layout_predictor + predictor = layout_predictor('lw_detr_s', pretrained=True, assume_straight_pages=False, preserve_aspect_ratio=True) + + End-to-End OCR -------------- diff --git a/doctr/datasets/datasets/pytorch.py b/doctr/datasets/datasets/pytorch.py index a5df0dd7f5..6418906dd7 100644 --- a/doctr/datasets/datasets/pytorch.py +++ b/doctr/datasets/datasets/pytorch.py @@ -48,9 +48,17 @@ def _read_sample(self, index: int) -> tuple[torch.Tensor, Any]: return img, deepcopy(target) @staticmethod - def collate_fn(samples: list[tuple[torch.Tensor, Any]]) -> tuple[torch.Tensor, list[Any]]: + def collate_fn( + samples: list[tuple[torch.Tensor, Any]] | list[tuple[tuple[torch.Tensor, torch.Tensor], Any]], + ) -> tuple[torch.Tensor, list[Any]]: images, targets = zip(*samples) - images = torch.stack(images, dim=0) # type: ignore[assignment] + if isinstance(images[0], tuple): + images, padding_masks = zip(*images) + images = torch.stack(images, dim=0) # type: ignore[assignment] + padding_masks = torch.stack(padding_masks, dim=0) # type: ignore[assignment] + images = (images, padding_masks) + else: + images = torch.stack(images, dim=0) # type: ignore[assignment] return images, list(targets) # type: ignore[return-value] diff --git a/doctr/datasets/layout.py b/doctr/datasets/layout.py index 0e7f5df9de..2d2b7d18cf 100644 --- a/doctr/datasets/layout.py +++ b/doctr/datasets/layout.py @@ -61,17 +61,17 @@ def __init__( raise FileNotFoundError(f"unable to locate {img_path}") polygons = label.get("polygons") - class_names = label.get("class_names") + class_names = label.get("classes") if polygons is None: raise KeyError(f"missing 'polygons' for image: {img_name}") if class_names is None: - raise KeyError(f"missing 'class_names' for image: {img_name}") + raise KeyError(f"missing 'classes' for image: {img_name}") if len(polygons) != len(class_names): raise ValueError( f"number of polygons ({len(polygons)}) does not match " - f"number of class_names ({len(class_names)}) for image: {img_name}" + f"number of classes ({len(class_names)}) for image: {img_name}" ) geoms, polygon_classes = self.format_polygons( diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index e103021f5f..fe237a05dc 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -9,6 +9,7 @@ import numpy as np from doctr.models.core import BaseModel +from doctr.utils import order_points __all__ = ["_LWDETR", "LWDETRPostProcessor"] @@ -57,7 +58,7 @@ def _decode_boxes(self, boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]: for i in range(len(boxes)): rect = ((float(cx[i]), float(cy[i])), (float(w[i]), float(h[i])), float(np.degrees(angles[i]))) - poly = cv2.boxPoints(rect) + poly = order_points(cv2.boxPoints(rect)) polys.append(poly) return np.asarray(polys, dtype=np.float32), angles @@ -237,7 +238,7 @@ def _quad_to_obb(poly: np.ndarray): continue for cls_id, box in zip(np.asarray(class_ids), np.asarray(boxes)): - poly = box.reshape(4, 2) + poly = order_points(box.reshape(4, 2)) obb = _quad_to_obb(poly) if obb[2] <= 1e-3 or obb[3] <= 1e-3: diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 8c97fc626b..fcaf52c444 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -39,12 +39,8 @@ "Table", "Text", "Title", - "Document Index", - "Code", "Checkbox-Selected", "Checkbox-Unselected", - "Form", - "Key-Value Region", ], "url": None, }, @@ -64,12 +60,8 @@ "Table", "Text", "Title", - "Document Index", - "Code", "Checkbox-Selected", "Checkbox-Unselected", - "Form", - "Key-Value Region", ], "url": None, }, @@ -735,7 +727,7 @@ def _lw_detr( kwargs["class_names"] = kwargs.get("class_names", default_cfgs[arch].get("class_names", [])) _cfg = deepcopy(default_cfgs[arch]) - _cfg["class_names"] = kwargs["class_names"] + _cfg["class_names"] = sorted(kwargs["class_names"]) kwargs.pop("class_names") # Build the feature extractor @@ -758,7 +750,7 @@ def _lw_detr( if pretrained: # The number of class_names is not the same as the number of classes in the pretrained model => # remove the layer weights - _ignore_keys = ignore_keys if _cfg["class_names"] != default_cfgs[arch].get("class_names") else None + _ignore_keys = ignore_keys if _cfg["class_names"] != sorted(default_cfgs[arch].get("class_names", [])) else None model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys) return model diff --git a/references/layout/README.md b/references/layout/README.md new file mode 100644 index 0000000000..dfd080f0bb --- /dev/null +++ b/references/layout/README.md @@ -0,0 +1,104 @@ +# Layout detection + +The sample training script was made to train layout detection model with docTR. + +## Setup + +First, you need to install `doctr` (with pip, for instance) + +```shell +pip install -e . --upgrade +pip install -r references/requirements.txt +``` + +## Usage + +You can start your training in PyTorch: + +```shell +python references/layout/train.py lw_detr_s --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 +``` + +### Multi-GPU support + +We now use the built-in [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) launcher to spawn your DDP workers. `torchrun` will set all the necessary environment variables (`LOCAL_RANK`, `RANK`, etc.) for you. Arguments are the same than the ones from single GPU, except: + +- `--backend`: you can specify another `backend` for `DistributedDataParallel` if the default one is not available on +your operating system. Fastest one is `nccl` according to [PyTorch Documentation](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). + +#### Key `torchrun` parameters + +- `--nproc_per_node=` + Spawn `` processes on the local machine (typically equal to the number of GPUs you want to use). +- `--nnodes=` + (Optional) Total number of nodes in your job. Default is 1. +- `--rdzv_backend`, `--rdzv_endpoint`, `--rdzv_id` + (Optional) Rendezvous settings for multi-node jobs. See the [torchrun docs](https://pytorch.org/docs/stable/elastic/run.html) for details. + +#### GPU selection + +By default all visible GPUs will be used. To limit which GPUs participate, set the `CUDA_VISIBLE_DEVICES` environment variable **before** running `torchrun`. For example, to use only CUDA devices 0 and 2: + +```shell +CUDA_VISIBLE_DEVICES=0,2 \ +torchrun --nproc_per_node=2 references/layout/train.py \ + lw_detr_s \ + --train_path path/to/train \ + --val_path path/to/val \ + --epochs 5 \ + --backend nccl + ``` + +## Data format + +You need to provide both `train_path` and `val_path` arguments to start training. +Each path must lead to folder with 1 subfolder and 1 file: + +```shell +├── images +│ ├── sample_img_01.png +│ ├── sample_img_02.png +│ ├── sample_img_03.png +│ └── ... +└── labels.json +``` + +Each JSON file must be a dictionary, where the keys are the image file names and the value is a dictionary with 4 entries: `img_dimensions` (spatial shape of the image), `img_hash` (SHA256 of the image file), `polygons` (the set of 2D points forming the localization polygon), `classes` (list of class names for each polygon). +The order of the points does not matter inside a polygon. Points are (x, y) absolutes coordinates. + +labels.json + +```shell +{ + "sample_img_01.png" = { + 'img_dimensions': (900, 600), + 'img_hash': "theimagedumpmyhash", + 'polygons': [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], ...], + 'classes': ["class_name_1", "class_name_2", ...] + }, + "sample_img_02.png" = { + 'img_dimensions': (900, 600), + 'img_hash': "thisisahash", + 'polygons': [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], ...], + 'classes': ["class_name_1", "class_name_2", ...] + } + ... +} +``` + +## Slack Logging with tqdm + +To enable Slack logging using `tqdm`, you need to set the following environment variables: + +- `TQDM_SLACK_TOKEN`: the Slack Bot Token +- `TQDM_SLACK_CHANNEL`: you can retrieve it using `Right Click on Channel > Copy > Copy link`. You should get something like `https://xxxxxx.slack.com/archives/yyyyyyyy`. Keep only the `yyyyyyyy` part. + +You can follow this page on [how to create a Slack App](https://api.slack.com/quickstart). + +## Advanced options + +Feel free to inspect the multiple script option to customize your training to your own needs! + +```python +python references/layout/train.py --help +``` diff --git a/references/layout/evaluate.py b/references/layout/evaluate.py new file mode 100644 index 0000000000..7888b9c816 --- /dev/null +++ b/references/layout/evaluate.py @@ -0,0 +1,200 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import multiprocessing as mp +import os +import time + +import torch +from torch.utils.data import DataLoader, SequentialSampler +from torchvision.transforms import Normalize + +if os.getenv("TQDM_SLACK_TOKEN") and os.getenv("TQDM_SLACK_CHANNEL"): + from tqdm.contrib.slack import tqdm +else: + from tqdm.auto import tqdm + +from doctr import transforms as T +from doctr.datasets import LayoutDataset +from doctr.models import layout +from doctr.utils.metrics import ObjectDetectionMetric + + +@torch.inference_mode() +def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): + # Model in eval mode + model.eval() + # Reset val metric + val_metric.reset() + # Validation loop + val_loss, batch_cnt = 0, 0 + for images, targets in tqdm(val_loader): + imgs, padding_masks = images + if torch.cuda.is_available(): + imgs = imgs.cuda() + padding_masks = padding_masks.cuda() + imgs = batch_transforms(imgs) + if amp: + with torch.cuda.amp.autocast(): + out = model(imgs, padding_masks, targets, return_preds=True) + else: + out = model(imgs, padding_masks, targets, return_preds=True) + # Compute metric + loc_preds = out["preds"] + for target, pred in zip(targets, loc_preds): + assert pred["boxes"].shape[0] == pred["scores"].shape[0] + assert pred["boxes"].shape[0] == pred["labels"].shape[0] + val_metric.update( + gt_boxes=target["boxes"], + pred_boxes=pred["boxes"], + gt_labels=target["labels"], + pred_labels=pred["labels"], + pred_scores=pred["scores"], + ) + + val_loss += out["loss"].item() + batch_cnt += 1 + + val_loss /= batch_cnt + metrics = val_metric.summary() + return ( + val_loss, + metrics["mAP@[.5:.95]"], + metrics["AP@[.5]"], + metrics["AP@[.75]"], + ) + + +def main(args): + slack_token = os.getenv("TQDM_SLACK_TOKEN") + slack_channel = os.getenv("TQDM_SLACK_CHANNEL") + pbar = tqdm(disable=False if slack_token and slack_channel else True) + if slack_token and slack_channel: + # Monkey patch tqdm write method to send messages directly to Slack + pbar.write = lambda msg: pbar.sio.client.chat_postMessage( + channel=slack_channel, + text=msg, + ) + pbar.write(str(args)) + + if not isinstance(args.workers, int): + args.workers = min(16, mp.cpu_count()) + + torch.backends.cudnn.benchmark = True + + # Temporary model to recover configuration + tmp_model = layout.__dict__[args.arch]( + pretrained=False, + assume_straight_pages=not args.rotation, + ) + + if isinstance(args.size, int): + input_shape = (args.size, args.size) + else: + input_shape = tmp_model.cfg["input_shape"][-2:] + mean, std = tmp_model.cfg["mean"], tmp_model.cfg["std"] + + st = time.time() + ds = LayoutDataset( + img_folder=os.path.join(args.dataset_path, "images"), + label_path=os.path.join(args.dataset_path, "labels.json"), + use_polygons=args.rotation, + sample_transforms=T.Resize( + input_shape, + preserve_aspect_ratio=args.keep_ratio, + symmetric_pad=args.symmetric_pad, + return_padding_mask=True, + ), + ) + class_names = ds.class_names + + test_loader = DataLoader( + ds, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, + sampler=SequentialSampler(ds), + pin_memory=torch.cuda.is_available(), + collate_fn=ds.collate_fn, + ) + + pbar.write(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in {len(test_loader)} batches)") + + # Load docTR model + model = layout.__dict__[args.arch]( + pretrained=not isinstance(args.resume, str), + assume_straight_pages=not args.rotation, + class_names=class_names, + ).eval() + + batch_transforms = Normalize(mean=mean, std=std) + + # Resume weights + if isinstance(args.resume, str): + pbar.write(f"Resuming {args.resume}") + model.from_pretrained(args.resume) + + # GPU + if isinstance(args.device, int): + if not torch.cuda.is_available(): + raise AssertionError("PyTorch cannot access your GPU. Please investigate!") + if args.device >= torch.cuda.device_count(): + raise ValueError("Invalid device index") + # Silent default switch to GPU if available + elif torch.cuda.is_available(): + args.device = 0 + else: + pbar.write("No accessible GPU, target device set to CPU.") + + if torch.cuda.is_available(): + torch.cuda.set_device(args.device) + model = model.cuda() + + # Metrics + metric = ObjectDetectionMetric( + num_classes=len(class_names), + use_polygons=args.rotation, + ) + + pbar.write("Running evaluation") + val_loss, map5095, ap50, ap75 = evaluate( + model, + test_loader, + batch_transforms, + metric, + amp=args.amp, + ) + pbar.write( + f"Validation loss: {val_loss:.6f} | mAP@[.5:.95]: {map5095:.2%} | AP@[.5]: {ap50:.2%} | AP@[.75]: {ap75:.2%}" + ) + + +def parse_args(): + import argparse + + parser = argparse.ArgumentParser( + description="docTR evaluation script for text detection (PyTorch)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument("arch", type=str, help="text-detection model to evaluate") + parser.add_argument("dataset_path", type=str, help="path to the dataset to evaluate on") + parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for evaluation") + parser.add_argument("--device", default=None, type=int, help="device") + parser.add_argument("--size", type=int, default=None, help="model input size, H = W") + parser.add_argument("--keep_ratio", action="store_true", help="keep the aspect ratio of the input image") + parser.add_argument("--symmetric_pad", action="store_true", help="pad the image symmetrically") + parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading") + parser.add_argument("--rotation", dest="rotation", action="store_true", help="inference with rotated bbox") + parser.add_argument("--resume", type=str, default=None, help="Checkpoint to resume") + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/layout/latency.py b/references/layout/latency.py new file mode 100644 index 0000000000..14c73720c5 --- /dev/null +++ b/references/layout/latency.py @@ -0,0 +1,59 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +"""Layout detection latency benchmark""" + +import argparse +import time + +import numpy as np +import torch + +from doctr.models import layout + + +@torch.inference_mode() +def main(args): + device = torch.device("cuda:0" if args.gpu else "cpu") + + # Pretrained imagenet model + model = layout.__dict__[args.arch](pretrained=args.pretrained).eval().to(device=device) + + # Input + img_tensor = torch.rand((1, 3, args.size, args.size)).to(device=device) + padding_masks = torch.zeros((1, args.size, args.size), dtype=torch.bool).to(device=device) + + # Warmup + for _ in range(10): + _ = model(input=img_tensor, masks=padding_masks) + + timings = [] + + # Evaluation runs + for _ in range(args.it): + start_ts = time.perf_counter() + _ = model(input=img_tensor, masks=padding_masks) + timings.append(time.perf_counter() - start_ts) + + _timings = np.array(timings) + print(f"{args.arch} ({args.it} runs on ({args.size}, {args.size}) inputs)") + print(f"mean {1000 * _timings.mean():.2f}ms, std {1000 * _timings.std():.2f}ms") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="docTR latency benchmark for layout detection (PyTorch)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("arch", type=str, help="Architecture to use") + parser.add_argument("--size", type=int, default=1024, help="The image input size") + parser.add_argument("--gpu", dest="gpu", help="Should the benchmark be performed on GPU", action="store_true") + parser.add_argument("--it", type=int, default=100, help="Number of iterations to run") + parser.add_argument( + "--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", action="store_true" + ) + args = parser.parse_args() + + main(args) diff --git a/references/layout/train.py b/references/layout/train.py new file mode 100644 index 0000000000..0ef5b4fd4c --- /dev/null +++ b/references/layout/train.py @@ -0,0 +1,726 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import datetime +import hashlib +import logging +import multiprocessing +import os +import time +from pathlib import Path + +import numpy as np +import torch + +# The following import is required for DDP +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler +from torchvision.transforms.v2 import Compose, Normalize, RandomGrayscale, RandomPhotometricDistort + +if os.getenv("TQDM_SLACK_TOKEN") and os.getenv("TQDM_SLACK_CHANNEL"): + from tqdm.contrib.slack import tqdm +else: + from tqdm.auto import tqdm + +from doctr import transforms as T +from doctr.datasets import LayoutDataset +from doctr.models import layout, login_to_hub, push_to_hf_hub +from doctr.utils.metrics import ObjectDetectionMetric +from utils import EarlyStopper, plot_recorder, plot_samples + + +def record_lr( + model: torch.nn.Module, + train_loader: DataLoader, + batch_transforms, + optimizer, + start_lr: float = 1e-7, + end_lr: float = 1, + num_it: int = 100, + amp: bool = False, +): + """Gridsearch the optimal learning rate for the training. + Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + """ + if num_it > len(train_loader): + raise ValueError("the value of `num_it` needs to be lower than the number of available batches") + + model = model.train() + # Update param groups & LR + optimizer.defaults["lr"] = start_lr + for pgroup in optimizer.param_groups: + pgroup["lr"] = start_lr + + gamma = (end_lr / start_lr) ** (1 / (num_it - 1)) + scheduler = MultiplicativeLR(optimizer, lambda step: gamma) + + lr_recorder = [start_lr * gamma**idx for idx in range(num_it)] + loss_recorder = [] + + if amp: + scaler = torch.cuda.amp.GradScaler() + + for batch_idx, (images, targets) in enumerate(train_loader): + imgs, padding_masks = images + + if torch.cuda.is_available(): + imgs = imgs.cuda() + padding_masks = padding_masks.cuda() + + imgs = batch_transforms(imgs) + + # Forward, Backward & update + optimizer.zero_grad() + if amp: + with torch.cuda.amp.autocast(): + train_loss = model(imgs, padding_masks, targets)["loss"] + scaler.scale(train_loss).backward() + # Gradient clipping + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + # Update the params + scaler.step(optimizer) + scaler.update() + else: + train_loss = model(imgs, padding_masks, targets)["loss"] + train_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + optimizer.step() + # Update LR + scheduler.step() + + # Record + if not torch.isfinite(train_loss): + if batch_idx == 0: + raise ValueError("loss value is NaN or inf.") + else: + break + loss_recorder.append(train_loss.item()) + # Stop after the number of iterations + if batch_idx + 1 == num_it: + break + + return lr_recorder[: len(loss_recorder)], loss_recorder + + +def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): + if amp: + scaler = torch.cuda.amp.GradScaler() + + model.train() + # Iterate over the batches of the dataset + epoch_train_loss, batch_cnt = 0, 0 + pbar = tqdm(train_loader, dynamic_ncols=True, disable=(rank != 0)) + for images, targets in pbar: + imgs, padding_masks = images + if torch.cuda.is_available(): + imgs = imgs.cuda() + padding_masks = padding_masks.cuda() + imgs = batch_transforms(imgs) + + optimizer.zero_grad() + if amp: + with torch.cuda.amp.autocast(): + train_loss = model(imgs, padding_masks, targets)["loss"] + scaler.scale(train_loss).backward() + # Gradient clipping + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + # Update the params + scaler.step(optimizer) + scaler.update() + else: + train_loss = model(imgs, padding_masks, targets)["loss"] + train_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + optimizer.step() + + scheduler.step() + last_lr = scheduler.get_last_lr()[0] + + pbar.set_description(f"Training loss: {train_loss.item():.6f} | LR: {last_lr:.6f}") + if log: + log(train_loss=train_loss.item(), lr=last_lr) + + epoch_train_loss += train_loss.item() + batch_cnt += 1 + + epoch_train_loss /= batch_cnt + return epoch_train_loss, last_lr + + +@torch.no_grad() +def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=None): + # Model in eval mode + model.eval() + # Reset val metric + val_metric.reset() + # Validation loop + val_loss, batch_cnt = 0, 0 + pbar = tqdm(val_loader, dynamic_ncols=True) + for images, targets in pbar: + imgs, padding_masks = images + if torch.cuda.is_available(): + imgs = imgs.cuda() + padding_masks = padding_masks.cuda() + imgs = batch_transforms(imgs) + if amp: + with torch.cuda.amp.autocast(): + out = model(imgs, padding_masks, targets, return_preds=True) + else: + out = model(imgs, padding_masks, targets, return_preds=True) + # Compute metric + loc_preds = out["preds"] + for target, pred in zip(targets, loc_preds): + assert pred["boxes"].shape[0] == pred["scores"].shape[0] + assert pred["boxes"].shape[0] == pred["labels"].shape[0] + + val_metric.update( + gt_boxes=target["boxes"], + pred_boxes=pred["boxes"], + gt_labels=target["labels"], + pred_labels=pred["labels"], + pred_scores=pred["scores"], + ) + + pbar.set_description(f"Validation loss: {out['loss'].item():.6f}") + if log: + log(val_loss=out["loss"].item()) + + val_loss += out["loss"].item() + batch_cnt += 1 + + val_loss /= batch_cnt + metrics = val_metric.summary() + return ( + val_loss, + metrics["mAP@[.5:.95]"], + metrics["AP@[.5]"], + metrics["AP@[.75]"], + ) + + +def main(args): + # Detect distributed setup + # variable is set by torchrun + world_size = int(os.environ.get("WORLD_SIZE", 1)) + distributed = world_size > 1 + + # GPU setup + if distributed: + rank = int(os.environ.get("LOCAL_RANK", 0)) + dist.init_process_group(backend=args.backend) + device = torch.device("cuda", rank) + torch.cuda.set_device(device) + else: + # single process + rank = 0 + if isinstance(args.device, int): + if not torch.cuda.is_available(): + raise AssertionError("PyTorch cannot access your GPU. Please investigate!") + if args.device >= torch.cuda.device_count(): + raise ValueError("Invalid device index") + device = torch.device("cuda", args.device) + # Silent default switch to GPU if available + elif torch.cuda.is_available(): + device = torch.device("cuda", 0) + else: + logging.warning("No accessible GPU, target device set to CPU.") + device = torch.device("cpu") + + slack_token = os.getenv("TQDM_SLACK_TOKEN") + slack_channel = os.getenv("TQDM_SLACK_CHANNEL") + + pbar = tqdm(disable=False if (slack_token and slack_channel) and (rank == 0) else True) + if slack_token and slack_channel: + # Monkey patch tqdm write method to send messages directly to Slack + pbar.write = lambda msg: pbar.sio.client.chat_postMessage(channel=slack_channel, text=msg) + pbar.write(str(args)) + + if rank == 0 and args.push_to_hub: + login_to_hub() + + if not isinstance(args.workers, int): + args.workers = min(16, multiprocessing.cpu_count()) + + torch.backends.cudnn.benchmark = True + + # Temporary model to recover configuration + tmp_model = layout.__dict__[args.arch]( + pretrained=False, + assume_straight_pages=not args.rotation, + ) + + mean, std = tmp_model.cfg["mean"], tmp_model.cfg["std"] + + # placeholder for class names + cls_container = [None] + if rank == 0: + # validation dataset related code + st = time.time() + val_set = LayoutDataset( + img_folder=os.path.join(args.val_path, "images"), + label_path=os.path.join(args.val_path, "labels.json"), + sample_transforms=T.SampleCompose( + ( + # Important to return padding masks for layout models + [ + T.Resize( + (args.input_size, args.input_size), + preserve_aspect_ratio=True, + symmetric_pad=True, + return_padding_mask=True, + ) + ] + if not args.rotation or args.eval_straight + else [] + ) + + ( + [ + T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad + T.RandomApply(T.RandomRotate(90, expand=True), 0.5), + T.Resize( + (args.input_size, args.input_size), + preserve_aspect_ratio=True, + symmetric_pad=True, + return_padding_mask=True, + ), + ] + if args.rotation and not args.eval_straight + else [] + ) + ), + use_polygons=args.rotation and not args.eval_straight, + ) + val_loader = DataLoader( + val_set, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, + sampler=SequentialSampler(val_set), + pin_memory=torch.cuda.is_available(), + collate_fn=val_set.collate_fn, + ) + pbar.write( + f"Validation set loaded in {time.time() - st:.4f}s ({len(val_set)} samples in {len(val_loader)} batches)" + ) + with open(os.path.join(args.val_path, "labels.json"), "rb") as f: + val_hash = hashlib.sha256(f.read()).hexdigest() + + cls_container[0] = val_set.class_names + if distributed: + # broadcast class names to all ranks + dist.broadcast_object_list(cls_container, src=0) + # unpack class names on all ranks + class_names = cls_container[0] + + batch_transforms = Normalize(mean=mean, std=std) + + # Load docTR model + model = layout.__dict__[args.arch]( + pretrained=args.pretrained, + assume_straight_pages=not args.rotation, + class_names=class_names, + ) + + # Resume weights + if isinstance(args.resume, str): + pbar.write(f"Resuming {args.resume}") + model.from_pretrained(args.resume) + + if rank == 0: + # Metrics + val_metric = ObjectDetectionMetric( + num_classes=len(class_names), + use_polygons=args.rotation and not args.eval_straight, + ) + + if rank == 0 and args.test_only: + pbar.write("Running evaluation") + val_loss, map5095, ap50, ap75 = evaluate( + model, + val_loader, + batch_transforms, + val_metric, + amp=args.amp, + ) + pbar.write( + f"Validation loss: {val_loss:.6f} | " + f"mAP@[.5:.95]: {map5095:.2%} | " + f"AP@[.5]: {ap50:.2%} | " + f"AP@[.75]: {ap75:.2%}" + ) + return + + st = time.time() + # Augmentations + # Image augmentations + img_transforms = T.OneOf([ + Compose([ + T.RandomApply(T.ColorInversion(), 0.3), + T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.2), + ]), + Compose([ + T.RandomApply(T.RandomShadow(), 0.3), + T.RandomApply(T.GaussianNoise(), 0.1), + T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3), + RandomGrayscale(p=0.15), + ]), + RandomPhotometricDistort(p=0.3), + lambda x: x, # Identity no transformation + ]) + # Image + target augmentations + sample_transforms = T.SampleCompose( + ( + [ + T.RandomHorizontalFlip(0.15), + T.OneOf([ + T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25), + T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), + ]), + T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), + ] + if not args.rotation + else [ + T.RandomHorizontalFlip(0.15), + T.OneOf([ + T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25), + T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), + ]), + # Rotation augmentation + T.Resize(args.input_size, preserve_aspect_ratio=True), + T.RandomApply(T.RandomRotate(90, expand=True), 0.5), + # Important to return padding masks for layout models + T.Resize( + (args.input_size, args.input_size), + preserve_aspect_ratio=True, + symmetric_pad=True, + return_padding_mask=True, + ), + ] + ) + ) + + # Load both train and val data generators + train_set = LayoutDataset( + img_folder=os.path.join(args.train_path, "images"), + label_path=os.path.join(args.train_path, "labels.json"), + img_transforms=img_transforms, + sample_transforms=sample_transforms, + use_polygons=args.rotation, + ) + + if distributed: + sampler = DistributedSampler(train_set, rank=rank, shuffle=False, drop_last=True) + else: + sampler = RandomSampler(train_set) + + train_loader = DataLoader( + train_set, + batch_size=args.batch_size, + drop_last=True, + num_workers=args.workers, + sampler=sampler, + pin_memory=torch.cuda.is_available(), + collate_fn=train_set.collate_fn, + ) + + # Sanity class names check between train and val sets + if set(class_names) != set(train_set.class_names): + raise ValueError( + "Class names mismatch between train and val sets. " + f"Train classes: {train_set.class_names}, Val classes: {class_names}" + ) + + if rank == 0: + pbar.write( + f"Train set loaded in {time.time() - st:.4f}s ({len(train_set)} samples in {len(train_loader)} batches)" + ) + + with open(os.path.join(args.train_path, "labels.json"), "rb") as f: + train_hash = hashlib.sha256(f.read()).hexdigest() + + if rank == 0 and args.show_samples: + x, target = next(iter(train_loader)) + plot_samples(x, target) + return + + # Backbone freezing + if args.freeze_backbone: + for p in model.feat_extractor.parameters(): + p.requires_grad = False + + if torch.cuda.is_available(): + torch.cuda.set_device(device) + model = model.to(device) + + if distributed: + # construct DDP model + model = DDP(model, device_ids=[rank]) + + # Optimizer + if args.optim == "adam": + optimizer = torch.optim.Adam( + [p for p in model.parameters() if p.requires_grad], + args.lr, + betas=(0.95, 0.999), + eps=1e-6, + weight_decay=args.weight_decay, + ) + + elif args.optim == "adamw": + optimizer = torch.optim.AdamW( + [p for p in model.parameters() if p.requires_grad], + args.lr, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=args.weight_decay or 1e-4, + ) + + # LR Finder + if rank == 0 and args.find_lr: + lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp) + plot_recorder(lrs, losses) + return + + # Scheduler + if args.sched == "cosine": + scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4) + elif args.sched == "onecycle": + scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader)) + elif args.sched == "poly": + scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader)) + + # Training monitoring + current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name + + if rank == 0: + config = { + "learning_rate": args.lr, + "epochs": args.epochs, + "weight_decay": args.weight_decay, + "batch_size": args.batch_size, + "architecture": args.arch, + "input_size": args.input_size, + "optimizer": args.optim, + "framework": "pytorch", + "scheduler": args.sched, + "train_hash": train_hash, + "val_hash": val_hash, + "pretrained": args.pretrained, + "rotation": args.rotation, + "amp": args.amp, + } + + global global_step + global_step = 0 # Shared global step counter + + # W&B + if args.wb: + import wandb + + run = wandb.init(name=exp_name, project="layout-detection", config=config) + + def wandb_log_at_step(train_loss=None, val_loss=None, lr=None): + wandb.log({ + **({"train_loss_step": train_loss} if train_loss is not None else {}), + **({"val_loss_step": val_loss} if val_loss is not None else {}), + **({"step_lr": lr} if lr is not None else {}), + }) + + # ClearML + if args.clearml: + from clearml import Logger, Task + + task = Task.init(project_name="docTR/layout-detection", task_name=exp_name, reuse_last_task_id=False) + task.upload_artifact("config", config) + + def clearml_log_at_step(train_loss=None, val_loss=None, lr=None): + logger = Logger.current_logger() + + if train_loss is not None: + logger.report_scalar( + title="Training Step Loss", + series="train_loss_step", + iteration=global_step, + value=train_loss, + ) + if val_loss is not None: + logger.report_scalar( + title="Validation Step Loss", + series="val_loss_step", + iteration=global_step, + value=val_loss, + ) + if lr is not None: + logger.report_scalar( + title="Step Learning Rate", + series="step_lr", + iteration=global_step, + value=lr, + ) + + # Unified logger + def log_at_step(train_loss=None, val_loss=None, lr=None): + global global_step + if args.wb: + wandb_log_at_step(train_loss, val_loss, lr) + if args.clearml: + clearml_log_at_step(train_loss, val_loss, lr) + global_step += 1 # Increment the shared global step counter + + # Create loss queue + min_loss = np.inf + if args.early_stop: + early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta) + + # Training loop + for epoch in range(args.epochs): + train_loss, actual_lr = fit_one_epoch( + model, + train_loader, + batch_transforms, + optimizer, + scheduler, + amp=args.amp, + log=log_at_step, + rank=rank, + ) + + if rank == 0: + pbar.write(f"Epoch {epoch + 1}/{args.epochs} - Training loss: {train_loss:.6f} | LR: {actual_lr:.6f}") + + # Validation loop at the end of each epoch + val_loss, map5095, ap50, ap75 = evaluate( + model, + val_loader, + batch_transforms, + val_metric, + amp=args.amp, + log=log_at_step, + ) + params = model.module if hasattr(model, "module") else model + if val_loss < min_loss: + pbar.write(f"Validation loss decreased {min_loss:.6f} --> {val_loss:.6f}: saving state...") + torch.save(params.state_dict(), Path(args.output_dir) / f"{exp_name}.pt") + min_loss = val_loss + if args.save_interval_epoch: + pbar.write(f"Saving state at epoch: {epoch + 1}") + torch.save(params.state_dict(), Path(args.output_dir) / f"{exp_name}_epoch{epoch + 1}.pt") + log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " + if any(val is None for val in (map5095, ap50, ap75)): + log_msg += "(Undefined metric value, caused by empty GTs or predictions)" + else: + log_msg += f"| mAP@[.5:.95]: {map5095:.2%} | AP@[.5]: {ap50:.2%} | AP@[.75]: {ap75:.2%}" + pbar.write(log_msg) + # W&B + if args.wb: + wandb.log({ + "train_loss": train_loss, + "val_loss": val_loss, + "learning_rate": actual_lr, + "mAP@[.5:.95]": map5095, + "AP@[.5]": ap50, + "AP@[.75]": ap75, + }) + + # ClearML + if args.clearml: + from clearml import Logger + + logger = Logger.current_logger() + logger.report_scalar(title="Training Loss", series="train_loss", value=train_loss, iteration=epoch) + logger.report_scalar(title="Validation Loss", series="val_loss", value=val_loss, iteration=epoch) + logger.report_scalar(title="Learning Rate", series="lr", value=actual_lr, iteration=epoch) + logger.report_scalar(title="mAP@[.5:.95]", series="mAP@[.5:.95]", value=map5095, iteration=epoch) + logger.report_scalar(title="AP@[.5]", series="AP@[.5]", value=ap50, iteration=epoch) + logger.report_scalar(title="AP@[.75]", series="AP@[.75]", value=ap75, iteration=epoch) + + if args.early_stop and early_stopper.early_stop(val_loss): + pbar.write("Training halted early due to reaching patience limit.") + break + + if rank == 0: + if args.wb: + run.finish() + + if args.push_to_hub: + push_to_hf_hub(model, exp_name, task="layout", run_config=args) + + +def parse_args(): + import argparse + + parser = argparse.ArgumentParser( + description="DocTR training script for layout detection (PyTorch)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # DDP related args + parser.add_argument("--backend", default="nccl", type=str, help="Backend to use for torch.distributed") + parser.add_argument( + "--device", + default=None, + type=int, + help="Specify gpu device for single-gpu training. In destributed setting, this parameter is ignored", + ) + parser.add_argument("arch", type=str, help="text-detection model to train") + parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model") + parser.add_argument("--train_path", type=str, required=True, help="path to training data folder") + parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") + parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") + parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on") + parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training") + parser.add_argument( + "--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch" + ) + parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W") + parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam or AdamW)") + parser.add_argument("--wd", "--weight-decay", default=0, type=float, help="weight decay", dest="weight_decay") + parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading") + parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint") + parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop") + parser.add_argument( + "--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning" + ) + parser.add_argument( + "--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples" + ) + parser.add_argument("--wb", dest="wb", action="store_true", help="Log to Weights & Biases") + parser.add_argument("--clearml", dest="clearml", action="store_true", help="Log to ClearML") + parser.add_argument("--push-to-hub", dest="push_to_hub", action="store_true", help="Push to Huggingface Hub") + parser.add_argument( + "--pretrained", + dest="pretrained", + action="store_true", + help="Load pretrained parameters before starting the training", + ) + parser.add_argument("--rotation", dest="rotation", action="store_true", help="train with rotated documents") + parser.add_argument( + "--eval-straight", + action="store_true", + help="metrics evaluation with straight boxes instead of polygons to save time + memory", + ) + parser.add_argument("--optim", type=str, default="adam", choices=["adam", "adamw"], help="optimizer to use") + parser.add_argument( + "--sched", type=str, default="poly", choices=["cosine", "onecycle", "poly"], help="scheduler to use" + ) + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR") + parser.add_argument("--early-stop", action="store_true", help="Enable early stopping") + parser.add_argument("--early-stop-epochs", type=int, default=5, help="Patience for early stopping") + parser.add_argument("--early-stop-delta", type=float, default=0.01, help="Minimum Delta for early stopping") + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/layout/utils.py b/references/layout/utils.py new file mode 100644 index 0000000000..218d5548ea --- /dev/null +++ b/references/layout/utils.py @@ -0,0 +1,101 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +import cv2 +import matplotlib.pyplot as plt +import numpy as np + + +def plot_samples(images, targets: list[dict[str, np.ndarray]]) -> None: + # Unnormalize image + nb_samples = min(len(images), 4) + _, axes = plt.subplots(2, nb_samples, figsize=(20, 5)) + for idx in range(nb_samples): + img = (255 * images[idx].numpy()).round().clip(0, 255).astype(np.uint8) + if img.shape[0] == 3 and img.shape[2] != 3: + img = img.transpose(1, 2, 0) + + target = np.zeros(img.shape[:2], np.uint8) + tgts = targets[idx].copy() + for boxes in tgts.values(): + boxes[:, [0, 2]] = boxes[:, [0, 2]] * img.shape[1] + boxes[:, [1, 3]] = boxes[:, [1, 3]] * img.shape[0] + boxes[:, :4] = boxes[:, :4].round().astype(int) + + for box in boxes: + if boxes.ndim == 3: + cv2.fillPoly(target, [np.intp(box)], 1) + else: + target[int(box[1]) : int(box[3]) + 1, int(box[0]) : int(box[2]) + 1] = 1 + if nb_samples > 1: + axes[0][idx].imshow(img) + axes[1][idx].imshow(target.astype(bool)) + else: + axes[0].imshow(img) + axes[1].imshow(target.astype(bool)) + + # Disable axis + for ax in axes.ravel(): + ax.axis("off") + plt.show() + + +def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> None: + """Display the results of the LR grid search. + Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + + Args: + lr_recorder: list of LR values + loss_recorder: list of loss values + beta (float, optional): smoothing factor + **kwargs: keyword arguments from `matplotlib.pyplot.show` + """ + if len(lr_recorder) != len(loss_recorder) or len(lr_recorder) == 0: + raise AssertionError("Both `lr_recorder` and `loss_recorder` should have the same length") + + # Exp moving average of loss + smoothed_losses = [] + avg_loss = 0.0 + for idx, loss in enumerate(loss_recorder): + avg_loss = beta * avg_loss + (1 - beta) * loss + smoothed_losses.append(avg_loss / (1 - beta ** (idx + 1))) + + # Properly rescale Y-axis + data_slice = slice( + min(len(loss_recorder) // 10, 10), + # -min(len(loss_recorder) // 20, 5) if len(loss_recorder) >= 20 else len(loss_recorder) + len(loss_recorder), + ) + vals = np.array(smoothed_losses[data_slice]) + min_idx = vals.argmin() + max_val = vals.max() if min_idx is None else vals[: min_idx + 1].max() # type: ignore[misc] + delta = max_val - vals[min_idx] + + plt.plot(lr_recorder[data_slice], smoothed_losses[data_slice]) + plt.xscale("log") + plt.xlabel("Learning Rate") + plt.ylabel("Training loss") + plt.ylim(vals[min_idx] - 0.1 * delta, max_val + 0.2 * delta) + plt.grid(True, linestyle="--", axis="x") + plt.show(**kwargs) + + +class EarlyStopper: + def __init__(self, patience: int = 5, min_delta: float = 0.01): + self.patience = patience + self.min_delta = min_delta + self.counter = 0 + self.min_validation_loss = float("inf") + + def early_stop(self, validation_loss: float) -> bool: + if validation_loss < self.min_validation_loss: + self.min_validation_loss = validation_loss + self.counter = 0 + elif validation_loss > (self.min_validation_loss + self.min_delta): + self.counter += 1 + if self.counter >= self.patience: + return True + return False diff --git a/tests/conftest.py b/tests/conftest.py index 7132c41e63..f12b604491 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -147,7 +147,7 @@ def mock_layout_label(tmpdir_factory): [[3, 2], [3, 3], [4, 1], [4, 3]], [[30, 20], [30, 30], [40, 10], [40, 30]], ], - "class_names": [ + "classes": [ "Table", "Header", "Footer", diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py index a8008b5bc6..7c1e97d808 100644 --- a/tests/pytorch/test_datasets_pt.py +++ b/tests/pytorch/test_datasets_pt.py @@ -188,15 +188,20 @@ def test_layout_dataset(mock_image_folder, mock_layout_label, use_polygons): ds = datasets.LayoutDataset( img_folder=mock_image_folder, label_path=mock_layout_label, - img_transforms=Resize(input_size), + img_transforms=Resize(input_size, return_padding_mask=True), use_polygons=use_polygons, ) assert len(ds) == 5 - img, target_dict = ds[0] + inputs, target_dict = ds[0] + assert isinstance(inputs, tuple) and len(inputs) == 2 + img, padding_mask = inputs assert isinstance(img, torch.Tensor) assert img.dtype == torch.float32 assert img.shape[-2:] == input_size + assert isinstance(padding_mask, torch.Tensor) + assert padding_mask.dtype == torch.bool + assert padding_mask.shape == input_size assert isinstance(target_dict, dict) expected_classes = {"Table", "Header", "Footer", "Text"} assert set(target_dict.keys()) == expected_classes @@ -213,8 +218,12 @@ def test_layout_dataset(mock_image_folder, mock_layout_label, use_polygons): assert ds.class_names == sorted(expected_classes) loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) images, targets = next(iter(loader)) - assert isinstance(images, torch.Tensor) - assert images.shape == (2, 3, *input_size) + assert isinstance(images, tuple) and len(images) == 2 + img, padding_mask = images + assert isinstance(img, torch.Tensor) + assert img.shape == (2, 3, *input_size) + assert isinstance(padding_mask, torch.Tensor) + assert padding_mask.shape == (2, *input_size) assert isinstance(targets, list) assert all(isinstance(target, dict) for target in targets) for target in targets: @@ -243,14 +252,14 @@ def test_layout_dataset(mock_image_folder, mock_layout_label, use_polygons): test_cases = [ ( - {"class_names": ["Text"]}, + {"classes": ["Text"]}, KeyError, "missing 'polygons'", ), ( {"polygons": [[[0, 0], [1, 0], [1, 1], [0, 1]]]}, KeyError, - "missing 'class_names'", + "missing 'classes'", ), ( { @@ -258,7 +267,7 @@ def test_layout_dataset(mock_image_folder, mock_layout_label, use_polygons): [[0, 0], [1, 0], [1, 1], [0, 1]], [[0, 0], [1, 0], [1, 1], [0, 1]], ], - "class_names": ["Text"], + "classes": ["Text"], }, ValueError, "number of polygons", @@ -266,7 +275,7 @@ def test_layout_dataset(mock_image_folder, mock_layout_label, use_polygons): ( { "polygons": [[[0, 0], [1, 0], [1, 1]]], # only 3 points - "class_names": ["Text"], + "classes": ["Text"], }, ValueError, "polygons are expected to have shape", From 320b4ae6e1b01b0ecf75bb7b00d2b6212ab8f3a3 Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 13 May 2026 11:55:27 +0200 Subject: [PATCH 2/9] typing and mypy --- doctr/datasets/datasets/pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doctr/datasets/datasets/pytorch.py b/doctr/datasets/datasets/pytorch.py index 6418906dd7..439555327b 100644 --- a/doctr/datasets/datasets/pytorch.py +++ b/doctr/datasets/datasets/pytorch.py @@ -50,7 +50,7 @@ def _read_sample(self, index: int) -> tuple[torch.Tensor, Any]: @staticmethod def collate_fn( samples: list[tuple[torch.Tensor, Any]] | list[tuple[tuple[torch.Tensor, torch.Tensor], Any]], - ) -> tuple[torch.Tensor, list[Any]]: + ) -> tuple[torch.Tensor | tuple[torch.Tensor, torch.Tensor], list[Any]]: images, targets = zip(*samples) if isinstance(images[0], tuple): images, padding_masks = zip(*images) @@ -60,7 +60,7 @@ def collate_fn( else: images = torch.stack(images, dim=0) # type: ignore[assignment] - return images, list(targets) # type: ignore[return-value] + return images, list(targets) class VisionDataset(AbstractDataset, _VisionDataset): # noqa: D101 From e93dfb7bc639a9535496fb6f9948252df2babb39 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 21 May 2026 13:46:17 +0200 Subject: [PATCH 3/9] Add train scripts and update augmentations --- .github/workflows/references.yml | 13 +- doctr/datasets/coco_text.py | 6 +- doctr/datasets/cord.py | 3 +- doctr/datasets/datasets/base.py | 35 +-- doctr/datasets/datasets/pytorch.py | 24 +- doctr/datasets/detection.py | 3 +- doctr/datasets/doc_artefacts.py | 3 +- doctr/datasets/funsd.py | 3 +- doctr/datasets/generator/base.py | 5 +- doctr/datasets/generator/pytorch.py | 9 +- doctr/datasets/ic03.py | 3 +- doctr/datasets/ic13.py | 6 +- doctr/datasets/iiit5k.py | 3 +- doctr/datasets/iiithws.py | 6 +- doctr/datasets/imgur5k.py | 6 +- doctr/datasets/layout.py | 3 +- doctr/datasets/mjsynth.py | 6 +- doctr/datasets/ocr.py | 3 +- doctr/datasets/orientation.py | 3 +- doctr/datasets/recognition.py | 3 +- doctr/datasets/sroie.py | 3 +- doctr/datasets/svhn.py | 3 +- doctr/datasets/svt.py | 3 +- doctr/datasets/synthtext.py | 3 +- doctr/datasets/wildreceipt.py | 6 +- doctr/models/layout/lw_detr/base.py | 83 ++++-- doctr/models/layout/lw_detr/pytorch.py | 14 +- doctr/models/preprocessor/pytorch.py | 7 +- doctr/transforms/modules/base.py | 250 ++++++++++-------- doctr/transforms/modules/pytorch.py | 181 ++++++++----- doctr/utils/common_types.py | 22 +- references/classification/train_character.py | 13 +- .../classification/train_orientation.py | 22 +- references/detection/train.py | 4 +- references/layout/evaluate.py | 18 +- references/layout/train.py | 39 +-- references/layout/utils.py | 69 ++++- references/recognition/train.py | 12 +- tests/common/test_datasets.py | 54 ++-- tests/common/test_transforms.py | 115 +++++--- tests/pytorch/test_datasets_pt.py | 35 ++- tests/pytorch/test_models_layout.py | 49 ++-- tests/pytorch/test_transforms_pt.py | 177 +++++++------ 43 files changed, 835 insertions(+), 493 deletions(-) diff --git a/.github/workflows/references.yml b/.github/workflows/references.yml index 9585eaac5a..0336cfd032 100644 --- a/.github/workflows/references.yml +++ b/.github/workflows/references.yml @@ -281,11 +281,11 @@ jobs: pip install -r references/requirements.txt - name: Download and extract toy set run: | - wget https://github.com/mindee/doctr/releases/download/v0.3.1/toy_detection_set-bbbb4243.zip + wget https://github.com/mindee/doctr/releases/download/v1.0.1/toy_layout_set-c79b4e69.zip sudo apt-get update && sudo apt-get install unzip -y - unzip toy_detection_set-bbbb4243.zip -d det_set + unzip toy_layout_set-c79b4e69.zip -d layout_set - name: Train for a short epoch - run: python references/layout/train.py lw_detr_s --train_path ./det_set --val_path ./det_set -b 2 --epochs 1 + run: python references/layout/train.py lw_detr_s --train_path ./layout_set --val_path ./layout_set -b 2 --epochs 1 evaluate-layout-analysis: runs-on: ${{ matrix.os }} @@ -311,8 +311,13 @@ jobs: python -m pip install --upgrade pip pip install -e .[viz,html] --upgrade pip install -r references/requirements.txt + - name: Download and extract toy set + run: | + wget https://github.com/mindee/doctr/releases/download/v1.0.1/toy_layout_set-c79b4e69.zip + sudo apt-get update && sudo apt-get install unzip -y + unzip toy_layout_set-c79b4e69.zip -d layout_set - name: Evaluate layout analysis - run: python references/layout/evaluate.py lw_detr_s + run: python references/layout/evaluate.py lw_detr_s ./layout_set latency-layout-analysis: runs-on: ${{ matrix.os }} diff --git a/doctr/datasets/coco_text.py b/doctr/datasets/coco_text.py index d2f011c141..d1df3f0c5c 100644 --- a/doctr/datasets/coco_text.py +++ b/doctr/datasets/coco_text.py @@ -27,10 +27,12 @@ class COCOTEXT(AbstractDataset): >>> from doctr.datasets import COCOTEXT >>> train_set = COCOTEXT(train=True, img_folder="/path/to/coco_text/train2014/", >>> label_path="/path/to/coco_text/cocotext.v2.json") - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target >>> test_set = COCOTEXT(train=False, img_folder="/path/to/coco_text/train2014/", >>> label_path = "/path/to/coco_text/cocotext.v2.json") - >>> img, target = test_set[0] + >>> sample = test_set[0] + >>> img, target = sample.image, sample.target Args: img_folder: folder with all the images of the dataset diff --git a/doctr/datasets/cord.py b/doctr/datasets/cord.py index c71b6a22e5..d58376bd3b 100644 --- a/doctr/datasets/cord.py +++ b/doctr/datasets/cord.py @@ -26,7 +26,8 @@ class CORD(VisionDataset): >>> from doctr.datasets import CORD >>> train_set = CORD(train=True, download=True) - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target Args: train: whether the subset should be the training one diff --git a/doctr/datasets/datasets/base.py b/doctr/datasets/datasets/base.py index 5cdcaceaab..72c28a91df 100644 --- a/doctr/datasets/datasets/base.py +++ b/doctr/datasets/datasets/base.py @@ -9,12 +9,8 @@ from pathlib import Path from typing import Any -import numpy as np - from doctr.io.image import get_img_shape -from doctr.utils.data import download_from_url - -from ...models.utils import _copy_tensor +from doctr.utils import Sample, download_from_url __all__ = ["_AbstractDataset", "_VisionDataset"] @@ -26,8 +22,8 @@ class _AbstractDataset: def __init__( self, root: str | Path, - img_transforms: Callable[[Any], Any] | None = None, - sample_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None, + img_transforms: Callable[[Sample], Sample] | None = None, + sample_transforms: Callable[[Sample], Sample] | None = None, pre_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None, ) -> None: if not Path(root).is_dir(): @@ -45,32 +41,23 @@ def __len__(self) -> int: def _read_sample(self, index: int) -> tuple[Any, Any]: raise NotImplementedError - def __getitem__(self, index: int) -> tuple[Any, Any]: + def __getitem__(self, index: int) -> Sample: # Read image img, target = self._read_sample(index) + mask = None # FIX: always defined # Pre-transforms (format conversion at run-time etc.) if self._pre_transforms is not None: img, target = self._pre_transforms(img, target) + sample = Sample(image=img, mask=mask, target=target) + if self.img_transforms is not None: - # typing issue cf. https://github.com/python/mypy/issues/5485 - img = self.img_transforms(img) + sample = self.img_transforms(sample) if self.sample_transforms is not None: - # Conditions to assess it is detection model with multiple classes and avoid confusion with other tasks. - if ( - isinstance(target, dict) - and all(isinstance(item, np.ndarray) for item in target.values()) - and set(target.keys()) != {"boxes", "labels"} # avoid confusion with obj detection target - ): - img_transformed = _copy_tensor(img) - for class_name, bboxes in target.items(): - img_transformed, target[class_name] = self.sample_transforms(img, bboxes) - img = img_transformed - else: - img, target = self.sample_transforms(img, target) - - return img, target + sample = self.sample_transforms(sample) + + return sample def extra_repr(self) -> str: return "" diff --git a/doctr/datasets/datasets/pytorch.py b/doctr/datasets/datasets/pytorch.py index 439555327b..cd4830e71f 100644 --- a/doctr/datasets/datasets/pytorch.py +++ b/doctr/datasets/datasets/pytorch.py @@ -11,6 +11,7 @@ import torch from doctr.io import read_img_as_tensor, tensor_from_numpy +from doctr.utils import Sample from .base import _AbstractDataset, _VisionDataset @@ -49,18 +50,19 @@ def _read_sample(self, index: int) -> tuple[torch.Tensor, Any]: @staticmethod def collate_fn( - samples: list[tuple[torch.Tensor, Any]] | list[tuple[tuple[torch.Tensor, torch.Tensor], Any]], - ) -> tuple[torch.Tensor | tuple[torch.Tensor, torch.Tensor], list[Any]]: - images, targets = zip(*samples) - if isinstance(images[0], tuple): - images, padding_masks = zip(*images) - images = torch.stack(images, dim=0) # type: ignore[assignment] - padding_masks = torch.stack(padding_masks, dim=0) # type: ignore[assignment] - images = (images, padding_masks) - else: - images = torch.stack(images, dim=0) # type: ignore[assignment] + samples: list[Sample], + ) -> tuple[torch.Tensor, list[Any]] | tuple[tuple[torch.Tensor, torch.Tensor], list[Any]]: + _images = [s.image for s in samples] + targets = [s.target for s in samples] + + _masks = [s.mask for s in samples if s.mask is not None] + + images = torch.stack(_images, dim=0) + if _masks: + masks = torch.stack(_masks, dim=0) + return (images, masks), targets - return images, list(targets) + return images, targets class VisionDataset(AbstractDataset, _VisionDataset): # noqa: D101 diff --git a/doctr/datasets/detection.py b/doctr/datasets/detection.py index a953bddedf..4f9b4e85e3 100644 --- a/doctr/datasets/detection.py +++ b/doctr/datasets/detection.py @@ -23,7 +23,8 @@ class DetectionDataset(AbstractDataset): >>> from doctr.datasets import DetectionDataset >>> train_set = DetectionDataset(img_folder="/path/to/images", >>> label_path="/path/to/labels.json") - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target Args: img_folder: folder with all the images of the dataset diff --git a/doctr/datasets/doc_artefacts.py b/doctr/datasets/doc_artefacts.py index 6a0d3011d4..898acbb932 100644 --- a/doctr/datasets/doc_artefacts.py +++ b/doctr/datasets/doc_artefacts.py @@ -23,7 +23,8 @@ class DocArtefacts(VisionDataset): >>> from doctr.datasets import DocArtefacts >>> train_set = DocArtefacts(train=True, download=True) - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target Args: train: whether the subset should be the training one diff --git a/doctr/datasets/funsd.py b/doctr/datasets/funsd.py index a322fcecae..1a80dd6357 100644 --- a/doctr/datasets/funsd.py +++ b/doctr/datasets/funsd.py @@ -26,7 +26,8 @@ class FUNSD(VisionDataset): >>> from doctr.datasets import FUNSD >>> train_set = FUNSD(train=True, download=True) - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target Args: train: whether the subset should be the training one diff --git a/doctr/datasets/generator/base.py b/doctr/datasets/generator/base.py index c094d7406c..b367a481fb 100644 --- a/doctr/datasets/generator/base.py +++ b/doctr/datasets/generator/base.py @@ -10,6 +10,7 @@ from PIL import Image, ImageDraw from doctr.io.image import tensor_from_pil +from doctr.utils import Sample from doctr.utils.fonts import get_font from ..datasets import AbstractDataset @@ -62,7 +63,7 @@ def __init__( cache_samples: bool = False, font_family: str | list[str] | None = None, img_transforms: Callable[[Any], Any] | None = None, - sample_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None, + sample_transforms: Callable[[Sample], Sample] | None = None, ) -> None: self.vocab = vocab self._num_samples = num_samples @@ -111,7 +112,7 @@ def __init__( cache_samples: bool = False, font_family: str | list[str] | None = None, img_transforms: Callable[[Any], Any] | None = None, - sample_transforms: Callable[[Any, Any], tuple[Any, Any]] | None = None, + sample_transforms: Callable[[Sample], Sample] | None = None, ) -> None: self.vocab = vocab self.wordlen_range = (min_chars, max_chars) diff --git a/doctr/datasets/generator/pytorch.py b/doctr/datasets/generator/pytorch.py index 81132aff0a..37cf459073 100644 --- a/doctr/datasets/generator/pytorch.py +++ b/doctr/datasets/generator/pytorch.py @@ -3,8 +3,6 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from torch.utils.data._utils.collate import default_collate - from .base import _CharacterGenerator, _WordGenerator __all__ = ["CharacterGenerator", "WordGenerator"] @@ -15,7 +13,8 @@ class CharacterGenerator(_CharacterGenerator): >>> from doctr.datasets import CharacterGenerator >>> ds = CharacterGenerator(vocab='abdef', num_samples=100) - >>> img, target = ds[0] + >>> sample = ds[0] + >>> img, target = sample.image, sample.target Args: vocab: vocabulary to take the character from @@ -28,7 +27,6 @@ class CharacterGenerator(_CharacterGenerator): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - setattr(self, "collate_fn", default_collate) class WordGenerator(_WordGenerator): @@ -36,7 +34,8 @@ class WordGenerator(_WordGenerator): >>> from doctr.datasets import WordGenerator >>> ds = WordGenerator(vocab='abdef', min_chars=1, max_chars=32, num_samples=100) - >>> img, target = ds[0] + >>> sample = ds[0] + >>> img, target = sample.image, sample.target Args: vocab: vocabulary to take the character from diff --git a/doctr/datasets/ic03.py b/doctr/datasets/ic03.py index 2aaaea5502..e61c2049c6 100644 --- a/doctr/datasets/ic03.py +++ b/doctr/datasets/ic03.py @@ -25,7 +25,8 @@ class IC03(VisionDataset): >>> from doctr.datasets import IC03 >>> train_set = IC03(train=True, download=True) - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target Args: train: whether the subset should be the training one diff --git a/doctr/datasets/ic13.py b/doctr/datasets/ic13.py index 95cd65dcd0..5795983df3 100644 --- a/doctr/datasets/ic13.py +++ b/doctr/datasets/ic13.py @@ -27,10 +27,12 @@ class IC13(AbstractDataset): >>> from doctr.datasets import IC13 >>> train_set = IC13(img_folder="/path/to/Challenge2_Training_Task12_Images", >>> label_folder="/path/to/Challenge2_Training_Task1_GT") - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target >>> test_set = IC13(img_folder="/path/to/Challenge2_Test_Task12_Images", >>> label_folder="/path/to/Challenge2_Test_Task1_GT") - >>> img, target = test_set[0] + >>> sample = test_set[0] + >>> img, target = sample.image, sample.target Args: img_folder: folder with all the images of the dataset diff --git a/doctr/datasets/iiit5k.py b/doctr/datasets/iiit5k.py index c7e9736824..5bf0929ac5 100644 --- a/doctr/datasets/iiit5k.py +++ b/doctr/datasets/iiit5k.py @@ -28,7 +28,8 @@ class IIIT5K(VisionDataset): >>> # NOTE: this dataset is for character-level localization >>> from doctr.datasets import IIIT5K >>> train_set = IIIT5K(train=True, download=True) - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target Args: train: whether the subset should be the training one diff --git a/doctr/datasets/iiithws.py b/doctr/datasets/iiithws.py index 99235dbd5d..41bc953656 100644 --- a/doctr/datasets/iiithws.py +++ b/doctr/datasets/iiithws.py @@ -25,11 +25,13 @@ class IIITHWS(AbstractDataset): >>> train_set = IIITHWS(img_folder="/path/to/iiit-hws/Images_90K_Normalized", >>> label_path="/path/to/IIIT-HWS-90K.txt", >>> train=True) - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target >>> test_set = IIITHWS(img_folder="/path/to/iiit-hws/Images_90K_Normalized", >>> label_path="/path/to/IIIT-HWS-90K.txt") >>> train=False) - >>> img, target = test_set[0] + >>> sample = test_set[0] + >>> img, target = sample.image, sample.target Args: img_folder: folder with all the images of the dataset diff --git a/doctr/datasets/imgur5k.py b/doctr/datasets/imgur5k.py index 9d09a71cb6..b44bcd34ba 100644 --- a/doctr/datasets/imgur5k.py +++ b/doctr/datasets/imgur5k.py @@ -34,10 +34,12 @@ class IMGUR5K(AbstractDataset): >>> from doctr.datasets import IMGUR5K >>> train_set = IMGUR5K(train=True, img_folder="/path/to/IMGUR5K-Handwriting-Dataset/images", >>> label_path="/path/to/IMGUR5K-Handwriting-Dataset/dataset_info/imgur5k_annotations.json") - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target >>> test_set = IMGUR5K(train=False, img_folder="/path/to/IMGUR5K-Handwriting-Dataset/images", >>> label_path="/path/to/IMGUR5K-Handwriting-Dataset/dataset_info/imgur5k_annotations.json") - >>> img, target = test_set[0] + >>> sample = test_set[0] + >>> img, target = sample.image, sample.target Args: img_folder: folder with all the images of the dataset diff --git a/doctr/datasets/layout.py b/doctr/datasets/layout.py index 2d2b7d18cf..5ca1b54b99 100644 --- a/doctr/datasets/layout.py +++ b/doctr/datasets/layout.py @@ -23,7 +23,8 @@ class LayoutDataset(AbstractDataset): >>> img_folder="/path/to/images", >>> label_path="/path/to/labels.json", >>> ) - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target Args: img_folder: folder containing the dataset images diff --git a/doctr/datasets/mjsynth.py b/doctr/datasets/mjsynth.py index 115f636eb6..1b6065b9b8 100644 --- a/doctr/datasets/mjsynth.py +++ b/doctr/datasets/mjsynth.py @@ -23,11 +23,13 @@ class MJSynth(AbstractDataset): >>> train_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px", >>> label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt", >>> train=True) - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target >>> test_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px", >>> label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt") >>> train=False) - >>> img, target = test_set[0] + >>> sample = test_set[0] + >>> img, target = sample.image, sample.target Args: img_folder: folder with all the images of the dataset diff --git a/doctr/datasets/ocr.py b/doctr/datasets/ocr.py index 1864b3767b..cebdc5aa91 100644 --- a/doctr/datasets/ocr.py +++ b/doctr/datasets/ocr.py @@ -21,7 +21,8 @@ class OCRDataset(AbstractDataset): >>> from doctr.datasets import OCRDataset >>> train_set = OCRDataset(img_folder="/path/to/images", >>> label_file="/path/to/labels.json") - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target Args: img_folder: local path to image folder (all jpg at the root) diff --git a/doctr/datasets/orientation.py b/doctr/datasets/orientation.py index d2f2b56a91..45779691f4 100644 --- a/doctr/datasets/orientation.py +++ b/doctr/datasets/orientation.py @@ -18,7 +18,8 @@ class OrientationDataset(AbstractDataset): >>> from doctr.datasets import OrientationDataset >>> train_set = OrientationDataset(img_folder="/path/to/images") - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target Args: img_folder: folder with all the images of the dataset diff --git a/doctr/datasets/recognition.py b/doctr/datasets/recognition.py index affce35e39..714be882ba 100644 --- a/doctr/datasets/recognition.py +++ b/doctr/datasets/recognition.py @@ -19,7 +19,8 @@ class RecognitionDataset(AbstractDataset): >>> from doctr.datasets import RecognitionDataset >>> train_set = RecognitionDataset(img_folder="/path/to/images", >>> labels_path="/path/to/labels.json") - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target Args: img_folder: path to the images folder diff --git a/doctr/datasets/sroie.py b/doctr/datasets/sroie.py index 5ca18e4acb..50017a5e25 100644 --- a/doctr/datasets/sroie.py +++ b/doctr/datasets/sroie.py @@ -26,7 +26,8 @@ class SROIE(VisionDataset): >>> from doctr.datasets import SROIE >>> train_set = SROIE(train=True, download=True) - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target Args: train: whether the subset should be the training one diff --git a/doctr/datasets/svhn.py b/doctr/datasets/svhn.py index 812bccf1e7..128eb9ad7a 100644 --- a/doctr/datasets/svhn.py +++ b/doctr/datasets/svhn.py @@ -25,7 +25,8 @@ class SVHN(VisionDataset): >>> from doctr.datasets import SVHN >>> train_set = SVHN(train=True, download=True) - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target Args: train: whether the subset should be the training one diff --git a/doctr/datasets/svt.py b/doctr/datasets/svt.py index 43df4b6590..1a7f5d2a6a 100644 --- a/doctr/datasets/svt.py +++ b/doctr/datasets/svt.py @@ -25,7 +25,8 @@ class SVT(VisionDataset): >>> from doctr.datasets import SVT >>> train_set = SVT(train=True, download=True) - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target Args: train: whether the subset should be the training one diff --git a/doctr/datasets/synthtext.py b/doctr/datasets/synthtext.py index 613283d8b3..cd560edf11 100644 --- a/doctr/datasets/synthtext.py +++ b/doctr/datasets/synthtext.py @@ -28,7 +28,8 @@ class SynthText(VisionDataset): >>> from doctr.datasets import SynthText >>> train_set = SynthText(train=True, download=True) - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target Args: train: whether the subset should be the training one diff --git a/doctr/datasets/wildreceipt.py b/doctr/datasets/wildreceipt.py index d650e43790..c5b86aa855 100644 --- a/doctr/datasets/wildreceipt.py +++ b/doctr/datasets/wildreceipt.py @@ -30,10 +30,12 @@ class WILDRECEIPT(AbstractDataset): >>> from doctr.datasets import WILDRECEIPT >>> train_set = WILDRECEIPT(train=True, img_folder="/path/to/wildreceipt/", >>> label_path="/path/to/wildreceipt/train.txt") - >>> img, target = train_set[0] + >>> sample = train_set[0] + >>> img, target = sample.image, sample.target >>> test_set = WILDRECEIPT(train=False, img_folder="/path/to/wildreceipt/", >>> label_path="/path/to/wildreceipt/test.txt") - >>> img, target = test_set[0] + >>> sample = test_set[0] + >>> img, target = sample.image, sample.target Args: img_folder: folder with all the images of the dataset diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index fe237a05dc..b0e9a48e09 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -195,14 +195,16 @@ class _LWDETR(BaseModel): def build_target( self, - target: list[tuple[list[int], np.ndarray]], + target: list[dict[str, np.ndarray]], + class_names: list[str], ) -> list[dict[str, Any]]: """Build the target for LW-DETR training Args: - target: list of tuples (class_ids, boxes) where class_ids is a list of class ids for the boxes - and boxes is an array of shape (num_boxes, 8) containing the coordinates of the 4 corners of the box - in the format (x1, y1, x2, y2, x3, y3, x4, y4) + target: list of dictionaries where each dictionary corresponds to a sample and has keys corresponding + to class names and values corresponding to lists of boxes in either polygon format (4, 2) + or bounding box format (4,) (xmin, ymin, xmax, ymax) + class_names: list of class names Returns: list of dictionaries with keys "boxes" and "labels" where "boxes" is an array of shape (num_boxes, 6) @@ -211,41 +213,88 @@ def build_target( """ targets = [] + class_to_id = {name: i for i, name in enumerate(class_names)} + def _quad_to_obb(poly: np.ndarray): - p1, p2, p3, p4 = poly + poly = np.asarray(poly, dtype=np.float32) cx, cy = np.mean(poly, axis=0) - w = (np.linalg.norm(p2 - p1) + np.linalg.norm(p3 - p4)) / 2 - h = (np.linalg.norm(p3 - p2) + np.linalg.norm(p4 - p1)) / 2 + edges = np.stack([ + poly[1] - poly[0], + poly[2] - poly[1], + poly[3] - poly[2], + poly[0] - poly[3], + ]) + + lengths = np.linalg.norm(edges, axis=1) + i = np.argmax(lengths) + dx, dy = edges[i] + + theta = np.arctan2(dy, dx) - theta = np.arctan2(*(p2 - p1)[::-1]) + w = np.mean([lengths[0], lengths[2]]) + h = np.mean([lengths[1], lengths[3]]) return np.array( [cx, cy, w, h, np.sin(theta), np.cos(theta)], dtype=np.float32, ) - for class_ids, boxes in target: + def to_quad(box: np.ndarray): + box = np.asarray(box, dtype=np.float32) + + if box.shape == (4,): + x1, y1, x2, y2 = box + return np.array( + [ + [x1, y1], + [x2, y1], + [x2, y2], + [x1, y2], + ], + dtype=np.float32, + ) + + if box.shape == (8,): + return box.reshape(4, 2) + + if box.shape == (4, 2): + return box.astype(np.float32) + + raise ValueError(f"Unsupported box shape: {box.shape}") + + for sample in target: boxes_all = [] labels_all = [] - if len(boxes) == 0: + if not sample: targets.append({ "boxes": np.zeros((0, 6), dtype=np.float32), "labels": np.zeros((0,), dtype=np.int64), }) continue - for cls_id, box in zip(np.asarray(class_ids), np.asarray(boxes)): - poly = order_points(box.reshape(4, 2)) - obb = _quad_to_obb(poly) + for class_name, boxes in sample.items(): + if class_name not in class_to_id: + raise ValueError(f"Unknown class name: {class_name}") + + cls_id = class_to_id[class_name] + + boxes = np.asarray(boxes) + + if boxes.ndim == 1: + boxes = boxes[None, :] + + for box in boxes: + poly = to_quad(box) + obb = _quad_to_obb(poly) - if obb[2] <= 1e-3 or obb[3] <= 1e-3: - continue + if obb[2] <= 1e-3 or obb[3] <= 1e-3: + continue - boxes_all.append(obb) - labels_all.append(cls_id + 1) # background = 0 + boxes_all.append(obb) + labels_all.append(cls_id) targets.append({ "boxes": np.asarray(boxes_all, dtype=np.float32), diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index fcaf52c444..5c7f68c0f2 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -413,7 +413,7 @@ def forward( self, input: torch.Tensor, masks: torch.Tensor, - target: list[tuple[list[int], np.ndarray]] | None = None, + target: list[dict[str, np.ndarray]] | None = None, return_model_output: bool = False, return_preds: bool = False, **kwargs: Any, @@ -526,7 +526,7 @@ def _postprocess(logits, boxes): return out def compute_loss( - self, logits: torch.Tensor, pred_boxes: torch.Tensor, target: list[tuple[list[int], np.ndarray]] + self, logits: torch.Tensor, pred_boxes: torch.Tensor, target: list[dict[str, np.ndarray]] ) -> torch.Tensor: """ Compute the loss for LW-DETR. The loss consists of three components: @@ -543,11 +543,9 @@ def compute_loss( Args: logits: (B, Q, C) tensor containing the predicted class logits for each query pred_boxes: (B, Q, 6) tensor containing the predicted boxes for each query - target: list of length B, where each element is a tuple of (classes, boxes) - for the corresponding image in the batch. - - classes: list of length N_i containing the class indices of the N_i ground truth boxes in the image - - boxes: (N_i, 4. 2) array containing the ground truth boxes - in the format [[x1, y1], [x2, y2], [x3, y3], [x4, y4]] + target: list of dictionaries where each dictionary corresponds to a sample and has keys corresponding + to class names and values corresponding to lists of boxes in either polygon format (4, 2) + or bounding box format (4,) (xmin, ymin, xmax, ymax) Returns: loss: the computed loss value @@ -626,7 +624,7 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te device = logits.device B, Q, C = logits.shape # Build targets - targets = self.build_target(target) + targets = self.build_target(target, self.class_names) total_cls = torch.tensor(0.0, device=device) total_box = torch.tensor(0.0, device=device) diff --git a/doctr/models/preprocessor/pytorch.py b/doctr/models/preprocessor/pytorch.py index 4f7c4949bd..ef0d8cf888 100644 --- a/doctr/models/preprocessor/pytorch.py +++ b/doctr/models/preprocessor/pytorch.py @@ -13,6 +13,7 @@ from torchvision.transforms import transforms as T from doctr.transforms import Resize +from doctr.utils import Sample from doctr.utils.multithreading import multithread_exec __all__ = ["PreProcessor"] @@ -100,9 +101,10 @@ def sample_transforms(self, x: np.ndarray) -> torch.Tensor | tuple[torch.Tensor, tensor = torch.from_numpy(x.copy()).permute(2, 0, 1) # Resizing if self.resize.return_padding_mask: - tensor, mask = self.resize(tensor) + sample = self.resize(Sample(image=tensor)) + tensor, mask = sample.image, sample.mask else: - tensor = self.resize(tensor) + tensor = self.resize(Sample(image=tensor)).image # Data type if tensor.dtype == torch.uint8: tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1) @@ -152,6 +154,7 @@ def __call__( samples = list(multithread_exec(self.sample_transforms, x)) # Batching if self.resize.return_padding_mask: + print(samples) img_batches, mask_batches = self.batch_inputs(samples) else: img_batches = self.batch_inputs(samples) diff --git a/doctr/transforms/modules/base.py b/doctr/transforms/modules/base.py index e28b104909..a53acf038f 100644 --- a/doctr/transforms/modules/base.py +++ b/doctr/transforms/modules/base.py @@ -10,23 +10,32 @@ import numpy as np +from doctr.utils.common_types import Sample from doctr.utils.repr import NestedObject from .. import functional as F -__all__ = ["SampleCompose", "ImageTransform", "ColorInversion", "OneOf", "RandomApply", "RandomRotate", "RandomCrop"] +__all__ = [ + "SampleCompose", + "ImageTransform", + "ColorInversion", + "OneOf", + "RandomApply", + "RandomRotate", + "RandomCrop", + "ImageTorchvisionTransform", +] class SampleCompose(NestedObject): """Implements a wrapper that will apply transformations sequentially on both image and target - .. code:: python - - >>> import numpy as np - >>> import torch - >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate - >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)]) - >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4))) + >>> import numpy as np + >>> import torch + >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate + >>> from doctr.utils import Sample + >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)]) + >>> out, out_boxes = transfos(Sample(image=torch.rand(8, 64, 64, 3), target=np.zeros((2, 4)))) Args: transforms: list of transformation modules @@ -34,25 +43,23 @@ class SampleCompose(NestedObject): _children_names: list[str] = ["sample_transforms"] - def __init__(self, transforms: list[Callable[[Any, Any], tuple[Any, Any]]]) -> None: + def __init__(self, transforms: list[Callable[[Sample], Sample]]) -> None: self.sample_transforms = transforms - def __call__(self, x: Any, target: Any) -> tuple[Any, Any]: + def __call__(self, sample: Sample) -> Sample: for t in self.sample_transforms: - x, target = t(x, target) - - return x, target + sample = t(sample) + return sample class ImageTransform(NestedObject): """Implements a transform wrapper to turn an image-only transformation into an image+target transform - .. code:: python - - >>> import torch - >>> from doctr.transforms import ImageTransform, ColorInversion - >>> transfo = ImageTransform(ColorInversion((32, 32))) - >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None) + >>> import torch + >>> from doctr.transforms import ImageTransform, ColorInversion + >>> from doctr.utils import Sample + >>> transfo = ImageTransform(ColorInversion((32, 32))) + >>> out = transfo(Sample(image=torch.rand(8, 64, 64, 3))) Args: transform: the image transformation module to wrap @@ -63,21 +70,44 @@ class ImageTransform(NestedObject): def __init__(self, transform: Callable[[Any], Any]) -> None: self.img_transform = transform - def __call__(self, img: Any, target: Any) -> tuple[Any, Any]: - img = self.img_transform(img) - return img, target + def __call__(self, sample: Sample) -> Sample: + img = self.img_transform(sample) + return sample.replace(image=img) + + +class ImageTorchvisionTransform(NestedObject): + """Implements a transform wrapper to turn a torchvision image-only transformation into an image+target transform + + >>> import torch + >>> from torchvision import transforms + >>> from doctr.transforms import ImageTorchvisionTransform + >>> from doctr.utils import Sample + >>> transfo = ImageTorchvisionTransform(transforms.ColorJitter(brightness=0.5)) + >>> out, _ = transfo(Sample(image=torch.rand(8, 64, 64, 3))) + + Args: + transform: the torchvision image transformation module to wrap + """ + + _children_names: list[str] = ["img_transform"] + + def __init__(self, transform: Callable[[Any], Any]) -> None: + self.img_transform = transform + + def __call__(self, sample: Sample) -> Sample: + img = self.img_transform(sample.image) + return sample.replace(image=img) class ColorInversion(NestedObject): """Applies the following tranformation to a tensor (image or batch of images): convert to grayscale, colorize (shift 0-values randomly), and then invert colors - .. code:: python - - >>> import torch - >>> from doctr.transforms import ColorInversion - >>> transfo = ColorInversion(min_val=0.6) - >>> out = transfo(torch.rand(8, 64, 64, 3)) + >>> import torch + >>> from doctr.transforms import ColorInversion + >>> from doctr.utils import Sample + >>> transfo = ColorInversion(min_val=0.6) + >>> out = transfo(Sample(image=torch.rand(8, 64, 64, 3))) Args: min_val: range [min_val, 1] to colorize RGB pixels @@ -89,19 +119,19 @@ def __init__(self, min_val: float = 0.5) -> None: def extra_repr(self) -> str: return f"min_val={self.min_val}" - def __call__(self, img: Any) -> Any: - return F.invert_colors(img, self.min_val) + def __call__(self, sample: Sample) -> Sample: + out = F.invert_colors(sample.image, self.min_val) + return sample.replace(image=out) class OneOf(NestedObject): """Randomly apply one of the input transformations - .. code:: python - - >>> import torch - >>> from doctr.transforms import OneOf - >>> transfo = OneOf([JpegQuality(), Gamma()]) - >>> out = transfo(torch.rand(1, 64, 64, 3)) + >>> import torch + >>> from doctr.transforms import OneOf, JpegQuality, Gamma + >>> from doctr.utils import Sample + >>> transfo = OneOf([JpegQuality(), Gamma()]) + >>> out = transfo(Sample(image=torch.rand(1, 64, 64, 3))) Args: transforms: list of transformations, one only will be picked @@ -109,48 +139,39 @@ class OneOf(NestedObject): _children_names: list[str] = ["transforms"] - def __init__(self, transforms: list[Callable[[Any], Any]]) -> None: + def __init__(self, transforms: list[Callable[[Sample], Sample]]) -> None: self.transforms = transforms - def __call__( - self, img: Any, target: np.ndarray | dict[str, np.ndarray] | None = None - ) -> Any | tuple[Any, np.ndarray | dict[str, np.ndarray]]: - # Pick transformation + def __call__(self, sample: Sample) -> Sample: transfo = self.transforms[int(random.random() * len(self.transforms))] - # Apply - return transfo(img) if target is None else transfo(img, target) # type: ignore[call-arg] + return transfo(sample) class RandomApply(NestedObject): """Apply with a probability p the input transformation - .. code:: python - - >>> import torch - >>> from doctr.transforms import RandomApply - >>> transfo = RandomApply(Gamma(), p=.5) - >>> out = transfo(torch.rand(1, 64, 64, 3)) + >>> import torch + >>> from doctr.transforms import RandomApply, Gamma + >>> from doctr.utils import Sample + >>> transfo = RandomApply(Gamma(), p=.5) + >>> out = transfo(Sample(image=torch.rand(1, 64, 64, 3), target=np.array([[0.1, 0.1, 0.9, 0.9]]), mask=None)) Args: transform: transformation to apply p: probability to apply """ - def __init__(self, transform: Callable[[Any], Any], p: float = 0.5) -> None: + def __init__(self, transform: Callable[[Sample], Sample], p: float = 0.5) -> None: self.transform = transform self.p = p def extra_repr(self) -> str: return f"transform={self.transform}, p={self.p}" - def __call__( - self, - img: Any, - target: np.ndarray | dict[str, np.ndarray] | None = None, - ) -> Any | tuple[Any, np.ndarray | dict[str, np.ndarray]]: + def __call__(self, sample: Sample) -> Sample: if random.random() < self.p: - return self.transform(img) if target is None else self.transform(img, target) # type: ignore[call-arg] - return img if target is None else (img, target) + return self.transform(sample) + return sample class RandomRotate(NestedObject): @@ -159,6 +180,11 @@ class RandomRotate(NestedObject): .. image:: https://doctr-static.mindee.com/models?id=v0.4.0/rotation_illustration.png&src=0 :align: center + >>> import torch + >>> from doctr.transforms import RandomRotate + >>> transfo = RandomRotate(max_angle=30, expand=True) + >>> out = transfo(Sample(image=torch.rand(1, 64, 64, 3), target=np.array([[0.1, 0.1, 0.9, 0.9]]), mask=None)) + Args: max_angle: maximum angle for rotation, in degrees. Angles will be uniformly picked in [-max_angle, max_angle] expand: whether the image should be padded before the rotation @@ -171,8 +197,7 @@ def __init__(self, max_angle: float = 5.0, expand: bool = False) -> None: def extra_repr(self) -> str: return f"max_angle={self.max_angle}, expand={self.expand}" - def _rotate_array(self, img: Any, target: np.ndarray, angle: float) -> tuple[Any, np.ndarray]: - """Rotate the image and the target, and keep only boxes with at least partial visibility after rotation""" + def _rotate_array(self, img: Any, target: np.ndarray, angle: float): is_polygon = target.shape[1:] == (4, 2) r_img, r_polys = F.rotate_sample(img, target, angle, self.expand) @@ -182,19 +207,29 @@ def _rotate_array(self, img: Any, target: np.ndarray, angle: float) -> tuple[Any # convert back if input was boxes if not is_polygon: - # (N, 4, 2) -> (N, 4) x1y1 = r_polys.min(axis=1) x2y2 = r_polys.max(axis=1) - r_boxes = np.concatenate([x1y1, x2y2], axis=1) - return r_img, r_boxes + return r_img, np.concatenate([x1y1, x2y2], axis=1) return r_img, r_polys - def __call__( - self, img: Any, target: np.ndarray | dict[str, np.ndarray] - ) -> tuple[Any, np.ndarray | dict[str, np.ndarray]]: + def __call__(self, sample: Sample) -> Sample: angle = random.uniform(-self.max_angle, self.max_angle) + img = sample.image + target = sample.target + mask = sample.mask + + r_mask = None + if mask is not None: + r_mask, _ = F.rotate_sample( + mask.unsqueeze(0), np.array([[0, 0, 1, 1]], dtype=np.float32), angle, self.expand + ) + + if target is None: + r_img, _ = F.rotate_sample(img, np.array([[0, 0, 1, 1]], dtype=np.float32), angle, self.expand) + return sample.replace(image=r_img, mask=r_mask) + if isinstance(target, dict): rotated_targets = {} rotated_img = None @@ -208,14 +243,23 @@ def __call__( if rotated_img is None: rotated_img = r_img rotated_targets[cls_name] = r_arr - return rotated_img if rotated_img is not None else img, rotated_targets - return self._rotate_array(img, target, angle) + final_img = rotated_img if rotated_img is not None else img + return sample.replace(image=final_img, mask=r_mask, target=rotated_targets) + + r_img, r_target = self._rotate_array(img, target, angle) + return sample.replace(image=r_img, mask=r_mask, target=r_target) class RandomCrop(NestedObject): """Randomly crop a tensor image and its boxes + >>> import torch + >>> from doctr.transforms import RandomCrop + >>> from doctr.utils import Sample + >>> transfo = RandomCrop(scale=(0.5, 1.0), ratio=(0.75, 1.33)) + >>> out = transfo(Sample(image=torch.rand(1, 64, 64, 3), target=np.array([[0.1, 0.1, 0.9, 0.9]]), mask=None)) + Args: scale: tuple of floats, relative (min_area, max_area) of the crop ratio: tuple of float, relative (min_ratio, max_ratio) where ratio = h/w @@ -228,25 +272,13 @@ def __init__(self, scale: tuple[float, float] = (0.08, 1.0), ratio: tuple[float, def extra_repr(self) -> str: return f"scale={self.scale}, ratio={self.ratio}" - def _crop_array( - self, - img: Any, - target: np.ndarray, - crop_box: tuple[float, float, float, float], - ) -> tuple[Any, np.ndarray]: + def _crop_array(self, img: Any, target: np.ndarray, crop_box): is_polygon = target.shape[1:] == (4, 2) - # For polygons, we need to reproject the coordinates into the cropped frame, - # and keep only those with at least partial visibility + if is_polygon: cropped_img, _ = F.crop_detection( img, - np.concatenate( - ( - np.min(target, axis=1), - np.max(target, axis=1), - ), - axis=1, - ), + np.concatenate((np.min(target, axis=1), np.max(target, axis=1)), axis=1), crop_box, ) @@ -278,37 +310,43 @@ def _crop_array( return cropped_img, np.clip(crop_boxes, 0, 1) - def __call__( - self, - img: Any, - target: np.ndarray | dict[str, np.ndarray], - ) -> tuple[Any, np.ndarray | dict[str, np.ndarray]]: + def __call__(self, sample: Sample) -> Sample: scale = random.uniform(self.scale[0], self.scale[1]) ratio = random.uniform(self.ratio[0], self.ratio[1]) - height, width = img.shape[-2:] + img = sample.image + target = sample.target + mask = sample.mask + + h, w = img.shape[-2:] - # Calculate crop size - crop_area = scale * width * height - aspect_ratio = ratio * (width / height) - crop_width = int(round(math.sqrt(crop_area * aspect_ratio))) - crop_height = int(round(math.sqrt(crop_area / aspect_ratio))) + crop_area = scale * w * h + aspect_ratio = ratio * (w / h) - # Ensure crop size does not exceed image dimensions - crop_width = min(crop_width, width) - crop_height = min(crop_height, height) + crop_w = int(round(math.sqrt(crop_area * aspect_ratio))) + crop_h = int(round(math.sqrt(crop_area / aspect_ratio))) - # Randomly select crop position - x = random.randint(0, width - crop_width) - y = random.randint(0, height - crop_height) + crop_w = min(crop_w, w) + crop_h = min(crop_h, h) + + x = random.randint(0, w - crop_w) + y = random.randint(0, h - crop_h) crop_box = ( - x / width, - y / height, - (x + crop_width) / width, - (y + crop_height) / height, + x / w, + y / h, + (x + crop_w) / w, + (y + crop_h) / h, ) + r_mask = None + if mask is not None: + r_mask, _ = self._crop_array(mask, np.zeros((0, 4)), crop_box) + + if target is None: + r_img, _ = self._crop_array(img, np.zeros((0, 4)), crop_box) + return sample.replace(image=r_img, mask=r_mask) + if isinstance(target, dict): cropped_targets = {} cropped_img = None @@ -325,6 +363,8 @@ def __call__( cropped_targets[cls_name] = c_arr - return cropped_img if cropped_img is not None else img, cropped_targets + final_img = cropped_img if cropped_img is not None else img + return sample.replace(image=final_img, mask=r_mask, target=cropped_targets) - return self._crop_array(img, target, crop_box) + c_img, c_target = self._crop_array(img, target, crop_box) + return sample.replace(image=c_img, mask=r_mask, target=c_target) diff --git a/doctr/transforms/modules/pytorch.py b/doctr/transforms/modules/pytorch.py index 22f3e3ddac..c7245f3a82 100644 --- a/doctr/transforms/modules/pytorch.py +++ b/doctr/transforms/modules/pytorch.py @@ -8,12 +8,13 @@ import numpy as np import torch -from PIL.Image import Image from scipy.ndimage import gaussian_filter from torch.nn.functional import pad from torchvision.transforms import functional as F from torchvision.transforms import transforms as T +from doctr.utils import Sample + from ..functional import random_shadow __all__ = [ @@ -32,8 +33,9 @@ class Resize(T.Resize): >>> import torch >>> from doctr.transforms import Resize + >>> from doctr.utils import Sample >>> transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=True) - >>> out = transfo(torch.rand((3, 64, 64))) + >>> out = transfo(Sample(image=torch.rand((3, 64, 64)))) Args: size: output size in pixels, either a tuple (height, width) or a single integer for square images @@ -92,14 +94,18 @@ def _resize_target( def forward( self, - img: torch.Tensor, - target: np.ndarray | None = None, - ) -> ( - torch.Tensor - | tuple[torch.Tensor, np.ndarray] - | tuple[torch.Tensor, np.ndarray, torch.Tensor] - | tuple[torch.Tensor, torch.Tensor] - ): + sample: Sample, + ) -> Sample: + img = sample.image + target = sample.target + mask = sample.mask + + # Resize mask alongside image if provided + # Masks should use nearest interpolation to preserve label integrity + resize_mask = mask is not None + if resize_mask and mask is not None and mask.ndim == 2: + mask = mask.unsqueeze(0) + target_ratio = self.size[0] / self.size[1] actual_ratio = img.shape[-2] / img.shape[-1] @@ -108,18 +114,25 @@ def forward( # We can use with the regular resize img = super().forward(img) + if resize_mask: + mask = F.resize( + mask, + self.size, + interpolation=F.InterpolationMode.NEAREST, + antialias=False, + ).squeeze(0) + if self.return_padding_mask: padding_mask = torch.zeros(self.size, dtype=torch.bool, device=img.device) if target is not None: if self.return_padding_mask: - return img, target, padding_mask - return img, target - + return sample.replace(image=img, target=target, mask=mask if resize_mask else padding_mask) + return sample.replace(image=img, target=target, mask=mask if resize_mask else sample.mask) if self.return_padding_mask: - return img, padding_mask + return sample.replace(image=img, mask=mask if resize_mask else padding_mask) + return sample.replace(image=img, mask=mask if resize_mask else sample.mask) - return img else: # Resize if actual_ratio > target_ratio: @@ -129,20 +142,33 @@ def forward( # Scale image img = F.resize(img, tmp_size, self.interpolation, antialias=True) + + if resize_mask: + mask = F.resize( + mask, + tmp_size, + interpolation=F.InterpolationMode.NEAREST, + antialias=False, + ).squeeze(0) + raw_shape = img.shape[-2:] + if isinstance(self.size, (tuple, list)): # Pad (inverted in pytorch) _pad = (0, self.size[1] - img.shape[-1], 0, self.size[0] - img.shape[-2]) + if self.symmetric_pad: half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2)) _pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1]) # Pad image img = pad(img, _pad) + if resize_mask and mask is not None: + mask = pad(mask, _pad) + if self.return_padding_mask: h, w = self.size padding_mask = torch.zeros((h, w), dtype=torch.bool, device=img.device) - left, right, top, bottom = _pad padding_mask[top : h - bottom, left : w - right] = True @@ -156,7 +182,10 @@ def forward( else: offset = (0, 0) - if isinstance(target, dict): + if isinstance(target, str) or (isinstance(target, np.ndarray) and target.shape == (1,)): + # Special case for orientation targets and other non-box targets, which should not be resized + pass + elif isinstance(target, dict): target = { cls_name: self._resize_target( arr, @@ -175,15 +204,14 @@ def forward( symmetric_pad=self.symmetric_pad, offset=offset, ) + if target is not None: if self.return_padding_mask: - return img, target, padding_mask - return img, target - + return sample.replace(image=img, target=target, mask=mask if resize_mask else padding_mask) + return sample.replace(image=img, target=target, mask=mask if resize_mask else sample.mask) if self.return_padding_mask: - return img, padding_mask - - return img + return sample.replace(image=img, mask=mask if resize_mask else padding_mask) + return sample.replace(image=img, mask=mask if resize_mask else sample.mask) def __repr__(self) -> str: interpolate_str = self.interpolation.value @@ -198,8 +226,9 @@ class GaussianNoise(torch.nn.Module): >>> import torch >>> from doctr.transforms import GaussianNoise + >>> from doctr.utils import Sample >>> transfo = GaussianNoise(0., 1.) - >>> out = transfo(torch.rand((3, 224, 224))) + >>> out = transfo(Sample(image=torch.rand((3, 224, 224)))) Args: mean : mean of the gaussian distribution @@ -211,13 +240,14 @@ def __init__(self, mean: float = 0.0, std: float = 1.0) -> None: self.std = std self.mean = mean - def forward(self, x: torch.Tensor) -> torch.Tensor: - # Reshape the distribution + def forward(self, sample: Sample) -> Sample: + x = sample.image noise = self.mean + 2 * self.std * torch.rand(x.shape, device=x.device) - self.std if x.dtype == torch.uint8: - return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8) + image = (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8) else: - return (x + noise.to(dtype=x.dtype)).clamp(0, 1) + image = (x + noise.to(dtype=x.dtype)).clamp(0, 1) + return sample.replace(image=image) def extra_repr(self) -> str: return f"mean={self.mean}, std={self.std}" @@ -228,7 +258,9 @@ class GaussianBlur(torch.nn.Module): >>> import torch >>> from doctr.transforms import GaussianBlur + >>> from doctr.utils import Sample >>> transfo = GaussianBlur(sigma=(0.0, 1.0)) + >>> out = transfo(Sample(image=torch.rand((3, 224, 224)))) Args: sigma : standard deviation range for the gaussian kernel @@ -238,38 +270,52 @@ def __init__(self, sigma: tuple[float, float]) -> None: super().__init__() self.sigma_range = sigma - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, sample: Sample) -> Sample: # Sample a random sigma value within the specified range sigma = torch.empty(1).uniform_(*self.sigma_range).item() # Apply Gaussian blur along spatial dimensions only blurred = torch.tensor( gaussian_filter( - x.numpy(), + sample.image.numpy(), sigma=sigma, mode="reflect", truncate=4.0, ), - dtype=x.dtype, - device=x.device, + dtype=sample.image.dtype, + device=sample.image.device, ) - return blurred + return sample.replace(image=blurred) class ChannelShuffle(torch.nn.Module): - """Randomly shuffle channel order of a given image""" + """Randomly shuffle channel order of a given image + + >>> import torch + >>> from doctr.transforms import ChannelShuffle + >>> from doctr.utils import Sample + >>> transfo = ChannelShuffle() + >>> out = transfo(Sample(image=torch.rand((3, 224, 224)))) + """ def __init__(self): super().__init__() - def forward(self, img: torch.Tensor) -> torch.Tensor: + def forward(self, sample: Sample) -> Sample: # Get a random order - chan_order = torch.rand(img.shape[0]).argsort() - return img[chan_order] + chan_order = torch.rand(sample.image.shape[0]).argsort() + return sample.replace(image=sample.image[chan_order]) class RandomHorizontalFlip(T.RandomHorizontalFlip): - """Randomly flip the input image horizontally""" + """Randomly flip the input image horizontally + + >>> import torch + >>> from doctr.transforms import RandomHorizontalFlip + >>> from doctr.utils import Sample + >>> transfo = RandomHorizontalFlip(p=1.0) + >>> out = transfo(Sample(image=torch.rand((3, 224, 224)), target=np.array([[0.1, 0.2, 0.3, 0.4]]))) + """ def _flip_array(self, target): _target = target.copy() @@ -278,24 +324,23 @@ def _flip_array(self, target): _target[:, ::2] = 1 - target[:, [2, 0]] else: _target[..., 0] = 1 - target[..., 0] - return _target - def forward( - self, - img: torch.Tensor | Image, - target: np.ndarray | dict[str, np.ndarray], - ) -> tuple[torch.Tensor | Image, np.ndarray | dict[str, np.ndarray]]: - + def forward(self, sample: Sample) -> Sample: if torch.rand(1) < self.p: - _img = F.hflip(img) + img = F.hflip(sample.image) + mask = F.hflip(sample.mask) if sample.mask is not None else None - if isinstance(target, dict): - return _img, {cls_name: self._flip_array(arr) for cls_name, arr in target.items()} + target = sample.target + if target is not None: + if isinstance(target, dict): + target = {k: self._flip_array(v) for k, v in target.items()} + else: + target = self._flip_array(target) - return _img, self._flip_array(target) + return sample.replace(image=img, mask=mask, target=target) - return img, target + return sample class RandomShadow(torch.nn.Module): @@ -303,8 +348,9 @@ class RandomShadow(torch.nn.Module): >>> import torch >>> from doctr.transforms import RandomShadow + >>> from doctr.utils import Sample >>> transfo = RandomShadow((0., 1.)) - >>> out = transfo(torch.rand((3, 64, 64))) + >>> out = transfo(Sample(image=torch.rand((3, 64, 64)))) Args: opacity_range : minimum and maximum opacity of the shade @@ -314,15 +360,15 @@ def __init__(self, opacity_range: tuple[float, float] | None = None) -> None: super().__init__() self.opacity_range = opacity_range if isinstance(opacity_range, tuple) else (0.2, 0.8) - def __call__(self, x: torch.Tensor) -> torch.Tensor: + def __call__(self, sample: Sample) -> Sample: # Reshape the distribution try: - if x.dtype == torch.uint8: - return ( + if sample.image.dtype == torch.uint8: + shadowed_image = ( ( 255 * random_shadow( - x.to(dtype=torch.float32) / 255, + sample.image.to(dtype=torch.float32) / 255, self.opacity_range, ) ) @@ -330,10 +376,12 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: .clip(0, 255) .to(dtype=torch.uint8) ) + return sample.replace(image=shadowed_image) else: - return random_shadow(x, self.opacity_range).clip(0, 1) + shadowed_image = random_shadow(sample.image, self.opacity_range).clip(0, 1) + return sample.replace(image=shadowed_image) except ValueError: - return x + return sample def extra_repr(self) -> str: return f"opacity_range={self.opacity_range}" @@ -344,8 +392,9 @@ class RandomResize(torch.nn.Module): >>> import torch >>> from doctr.transforms import RandomResize + >>> from doctr.utils import Sample >>> transfo = RandomResize((0.3, 0.9), preserve_aspect_ratio=True, symmetric_pad=True, p=0.5) - >>> out = transfo(torch.rand((3, 64, 64))) + >>> out = transfo(Sample(image=torch.rand((3, 64, 64)))) Args: scale_range: range of the resizing factor for width and height (independently) @@ -372,15 +421,14 @@ def __init__( def forward( self, - img: torch.Tensor, - target: np.ndarray | dict[str, np.ndarray], - ) -> tuple[torch.Tensor, np.ndarray | dict[str, np.ndarray]]: + sample: Sample, + ) -> Sample: if torch.rand(1) < self.p: scale_h = np.random.uniform(*self.scale_range) scale_w = np.random.uniform(*self.scale_range) - new_size = (int(img.shape[-2] * scale_h), int(img.shape[-1] * scale_w)) + new_size = (int(sample.image.shape[-2] * scale_h), int(sample.image.shape[-1] * scale_w)) - _img, _target = self._resize( + res = self._resize( new_size, preserve_aspect_ratio=self.preserve_aspect_ratio if isinstance(self.preserve_aspect_ratio, bool) @@ -388,10 +436,9 @@ def forward( symmetric_pad=self.symmetric_pad if isinstance(self.symmetric_pad, bool) else bool(torch.rand(1) <= self.symmetric_pad), - )(img, target) - - return _img, _target - return img, target + )(sample) + return res + return sample def extra_repr(self) -> str: return f"scale_range={self.scale_range}, preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}, p={self.p}" # noqa: E501 diff --git a/doctr/utils/common_types.py b/doctr/utils/common_types.py index 0cafd0dc80..df5a42de48 100644 --- a/doctr/utils/common_types.py +++ b/doctr/utils/common_types.py @@ -3,9 +3,13 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. +from dataclasses import dataclass from pathlib import Path +from typing import Any -__all__ = ["Point2D", "BoundingBox", "Polygon4P", "Polygon", "Bbox"] +import numpy as np + +__all__ = ["Point2D", "BoundingBox", "Polygon4P", "Polygon", "Bbox", "Sample"] Point2D = tuple[float, float] @@ -15,3 +19,19 @@ AbstractPath = str | Path AbstractFile = AbstractPath | bytes Bbox = tuple[float, float, float, float] + + +@dataclass +class Sample: + """Canonical data container for all transforms.""" + + image: Any + mask: Any | None = None + target: np.ndarray | dict[str, np.ndarray] | None = None + + def replace(self, **kwargs) -> "Sample": + return Sample( + image=kwargs.get("image", self.image), + mask=kwargs.get("mask", self.mask), + target=kwargs.get("target", self.target), + ) diff --git a/references/classification/train_character.py b/references/classification/train_character.py index 0b64dfdc65..c266bb2f5c 100644 --- a/references/classification/train_character.py +++ b/references/classification/train_character.py @@ -69,6 +69,7 @@ def record_lr( scaler = torch.cuda.amp.GradScaler() for batch_idx, (images, targets) in enumerate(train_loader): + targets = torch.tensor(targets) if torch.cuda.is_available(): images = images.cuda() targets = targets.cuda() @@ -116,6 +117,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a epoch_train_loss, batch_cnt = 0.0, 0.0 pbar = tqdm(train_loader, dynamic_ncols=True) for images, targets in pbar: + targets = torch.tensor(targets) if torch.cuda.is_available(): images = images.cuda() targets = targets.cuda() @@ -158,6 +160,7 @@ def evaluate(model, val_loader, batch_transforms, amp=False, log=None): val_loss, correct, samples, batch_cnt = 0, 0, 0, 0 pbar = tqdm(val_loader, dynamic_ncols=True) for images, targets in pbar: + targets = torch.tensor(targets) images = batch_transforms(images) if torch.cuda.is_available(): @@ -228,6 +231,7 @@ def main(args): num_workers=args.workers, sampler=SequentialSampler(val_set), pin_memory=torch.cuda.is_available(), + collate_fn=val_set.collate_fn, ) pbar.write(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in {len(val_loader)} batches)") @@ -273,13 +277,13 @@ def main(args): T.Resize((args.input_size, args.input_size)), # Augmentations T.RandomApply(T.ColorInversion(), 0.9), - RandomGrayscale(p=0.1), - RandomPhotometricDistort(p=0.1), + T.ImageTorchvisionTransform(RandomGrayscale(p=0.1)), + T.ImageTorchvisionTransform(RandomPhotometricDistort(p=0.1)), T.RandomApply(T.RandomShadow(), p=0.4), T.RandomApply(T.GaussianNoise(mean=0, std=0.1), 0.1), T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3), - RandomPerspective(distortion_scale=0.2, p=0.3), - RandomRotation(15, interpolation=InterpolationMode.BILINEAR), + T.ImageTorchvisionTransform(RandomPerspective(distortion_scale=0.2, p=0.3)), + T.ImageTorchvisionTransform(RandomRotation(15, interpolation=InterpolationMode.BILINEAR)), ]), font_family=fonts, ) @@ -291,6 +295,7 @@ def main(args): num_workers=args.workers, sampler=RandomSampler(train_set), pin_memory=torch.cuda.is_available(), + collate_fn=train_set.collate_fn, ) pbar.write(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)") diff --git a/references/classification/train_orientation.py b/references/classification/train_orientation.py index c7dcecc225..86dc4c5931 100644 --- a/references/classification/train_orientation.py +++ b/references/classification/train_orientation.py @@ -33,19 +33,20 @@ from doctr.datasets import OrientationDataset from doctr.models import classification, login_to_hub, push_to_hf_hub from doctr.models.utils import export_model_to_onnx +from doctr.utils import Sample from utils import EarlyStopper, plot_recorder, plot_samples CLASSES = [0, -90, 180, 90] -def rnd_rotate(img: torch.Tensor, target): +def rnd_rotate(sample: Sample) -> Sample: angle = int(np.random.choice(CLASSES)) idx = CLASSES.index(angle) # augment the angle randomly with a probability of 0.5 if np.random.rand() < 0.5: angle += float(np.random.choice(np.arange(-25, 25, 5))) - rotated_img = F.rotate(img, angle=-angle, fill=0, expand=angle not in CLASSES)[:3] - return rotated_img, idx + rotated_img = F.rotate(sample.image, angle=-angle, fill=0, expand=angle not in CLASSES)[:3] + return Sample(image=rotated_img, target=idx) def record_lr( @@ -80,6 +81,7 @@ def record_lr( scaler = torch.cuda.amp.GradScaler() for batch_idx, (images, targets) in enumerate(train_loader): + targets = torch.tensor(targets) if torch.cuda.is_available(): images = images.cuda() targets = targets.cuda() @@ -127,6 +129,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a epoch_train_loss, batch_cnt = 0.0, 0.0 pbar = tqdm(train_loader, dynamic_ncols=True) for images, targets in pbar: + targets = torch.tensor(targets) if torch.cuda.is_available(): images = images.cuda() targets = targets.cuda() @@ -169,6 +172,7 @@ def evaluate(model, val_loader, batch_transforms, amp=False, log=None): val_loss, correct, samples, batch_cnt = 0.0, 0.0, 0.0, 0.0 pbar = tqdm(val_loader, dynamic_ncols=True) for images, targets in pbar: + targets = torch.tensor(targets) images = batch_transforms(images) if torch.cuda.is_available(): @@ -225,7 +229,7 @@ def main(args): T.Resize(input_size, preserve_aspect_ratio=True, symmetric_pad=True), ]), sample_transforms=T.SampleCompose([ - lambda x, y: rnd_rotate(x, y), + rnd_rotate, T.Resize(input_size), ]), ) @@ -236,6 +240,7 @@ def main(args): num_workers=args.workers, sampler=SequentialSampler(val_set), pin_memory=torch.cuda.is_available(), + collate_fn=val_set.collate_fn, ) pbar.write(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in {len(val_loader)} batches)") @@ -280,12 +285,12 @@ def main(args): T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1), T.RandomApply(T.RandomShadow(), 0.2), T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3), - RandomPhotometricDistort(p=0.1), - RandomGrayscale(p=0.1), - RandomPerspective(distortion_scale=0.1, p=0.3), + T.ImageTorchvisionTransform(RandomPhotometricDistort(p=0.1)), + T.ImageTorchvisionTransform(RandomGrayscale(p=0.1)), + T.ImageTorchvisionTransform(RandomPerspective(distortion_scale=0.1, p=0.3)), ]), sample_transforms=T.SampleCompose([ - lambda x, y: rnd_rotate(x, y), + rnd_rotate, T.Resize(input_size), ]), ) @@ -297,6 +302,7 @@ def main(args): num_workers=args.workers, sampler=RandomSampler(train_set), pin_memory=torch.cuda.is_available(), + collate_fn=train_set.collate_fn, ) pbar.write(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in {len(train_loader)} batches)") diff --git a/references/detection/train.py b/references/detection/train.py index 8c2c429d31..43e4ec3a56 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -322,9 +322,9 @@ def main(args): T.RandomApply(T.RandomShadow(), 0.3), T.RandomApply(T.GaussianNoise(), 0.1), T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3), - RandomGrayscale(p=0.15), + T.ImageTorchvisionTransform(RandomGrayscale(p=0.15)), ]), - RandomPhotometricDistort(p=0.3), + T.ImageTorchvisionTransform(RandomPhotometricDistort(p=0.3)), lambda x: x, # Identity no transformation ]) # Image + target augmentations diff --git a/references/layout/evaluate.py b/references/layout/evaluate.py index 7888b9c816..36fb78cf80 100644 --- a/references/layout/evaluate.py +++ b/references/layout/evaluate.py @@ -7,6 +7,7 @@ import os import time +import numpy as np import torch from torch.utils.data import DataLoader, SequentialSampler from torchvision.transforms import Normalize @@ -20,6 +21,7 @@ from doctr.datasets import LayoutDataset from doctr.models import layout from doctr.utils.metrics import ObjectDetectionMetric +from utils import convert_target @torch.inference_mode() @@ -44,14 +46,16 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): # Compute metric loc_preds = out["preds"] for target, pred in zip(targets, loc_preds): - assert pred["boxes"].shape[0] == pred["scores"].shape[0] - assert pred["boxes"].shape[0] == pred["labels"].shape[0] + target_boxes, target_labels = convert_target(target, model.class_names) + pred_labels = np.asarray(pred[0], dtype=np.int64) + pred_boxes = np.asarray(pred[1], dtype=np.float32) + pred_scores = np.asarray(pred[2], dtype=np.float32) val_metric.update( - gt_boxes=target["boxes"], - pred_boxes=pred["boxes"], - gt_labels=target["labels"], - pred_labels=pred["labels"], - pred_scores=pred["scores"], + gt_boxes=target_boxes, + pred_boxes=pred_boxes, + gt_labels=target_labels, + pred_labels=pred_labels, + pred_scores=pred_scores, ) val_loss += out["loss"].item() diff --git a/references/layout/train.py b/references/layout/train.py index 0ef5b4fd4c..e4deb8977f 100644 --- a/references/layout/train.py +++ b/references/layout/train.py @@ -31,7 +31,7 @@ from doctr.datasets import LayoutDataset from doctr.models import layout, login_to_hub, push_to_hf_hub from doctr.utils.metrics import ObjectDetectionMetric -from utils import EarlyStopper, plot_recorder, plot_samples +from utils import EarlyStopper, convert_target, plot_recorder, plot_samples def record_lr( @@ -177,15 +177,16 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=Non # Compute metric loc_preds = out["preds"] for target, pred in zip(targets, loc_preds): - assert pred["boxes"].shape[0] == pred["scores"].shape[0] - assert pred["boxes"].shape[0] == pred["labels"].shape[0] - + target_boxes, target_labels = convert_target(target, model.class_names) + pred_labels = np.asarray(pred[0], dtype=np.int64) + pred_boxes = np.asarray(pred[1], dtype=np.float32) + pred_scores = np.asarray(pred[2], dtype=np.float32) val_metric.update( - gt_boxes=target["boxes"], - pred_boxes=pred["boxes"], - gt_labels=target["labels"], - pred_labels=pred["labels"], - pred_scores=pred["scores"], + gt_boxes=target_boxes, + pred_boxes=pred_boxes, + gt_labels=target_labels, + pred_labels=pred_labels, + pred_scores=pred_scores, ) pbar.set_description(f"Validation loss: {out['loss'].item():.6f}") @@ -282,7 +283,9 @@ def main(args): ) + ( [ - T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad + T.Resize( + args.input_size, preserve_aspect_ratio=True, return_padding_mask=True + ), # This does not pad T.RandomApply(T.RandomRotate(90, expand=True), 0.5), T.Resize( (args.input_size, args.input_size), @@ -369,9 +372,9 @@ def main(args): T.RandomApply(T.RandomShadow(), 0.3), T.RandomApply(T.GaussianNoise(), 0.1), T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3), - RandomGrayscale(p=0.15), + T.ImageTorchvisionTransform(RandomGrayscale(p=0.15)), ]), - RandomPhotometricDistort(p=0.3), + T.ImageTorchvisionTransform(RandomPhotometricDistort(p=0.3)), lambda x: x, # Identity no transformation ]) # Image + target augmentations @@ -383,7 +386,12 @@ def main(args): T.RandomApply(T.RandomCrop(ratio=(0.6, 1.33)), 0.25), T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), ]), - T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), + T.Resize( + (args.input_size, args.input_size), + preserve_aspect_ratio=True, + symmetric_pad=True, + return_padding_mask=True, + ), ] if not args.rotation else [ @@ -393,7 +401,7 @@ def main(args): T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), ]), # Rotation augmentation - T.Resize(args.input_size, preserve_aspect_ratio=True), + T.Resize(args.input_size, preserve_aspect_ratio=True, return_padding_mask=True), T.RandomApply(T.RandomRotate(90, expand=True), 0.5), # Important to return padding masks for layout models T.Resize( @@ -447,7 +455,8 @@ def main(args): if rank == 0 and args.show_samples: x, target = next(iter(train_loader)) - plot_samples(x, target) + img, masks = x + plot_samples(img, target, masks) return # Backbone freezing diff --git a/references/layout/utils.py b/references/layout/utils.py index 218d5548ea..6629487ab7 100644 --- a/references/layout/utils.py +++ b/references/layout/utils.py @@ -4,42 +4,85 @@ # See LICENSE or go to for full license details. +from typing import Any + import cv2 import matplotlib.pyplot as plt import numpy as np -def plot_samples(images, targets: list[dict[str, np.ndarray]]) -> None: - # Unnormalize image - nb_samples = min(len(images), 4) - _, axes = plt.subplots(2, nb_samples, figsize=(20, 5)) +def convert_target(target: dict[str, list], class_names: list[str]) -> tuple[np.ndarray, np.ndarray]: + """Convert the target from the dataset format to the format expected by the metric + + Args: + target: dictionary containing the target boxes and labels for a single sample + class_names: list of class names + + Returns: + tuple of (boxes, labels) where boxes is an array of shape (N, 4) or (N, 4, 2) depending on the use of polygons, + and labels is an array of shape (N,) containing the class indices. + """ + boxes = [] + labels = [] + + class_to_idx = {name: idx for idx, name in enumerate(class_names)} + + for class_name, class_boxes in target.items(): + if len(class_boxes) == 0: + continue + + boxes.extend(class_boxes) + labels.extend([class_to_idx[class_name]] * len(class_boxes)) + + return np.asarray(boxes, dtype=np.float32), np.asarray(labels, dtype=np.int64) + + +def plot_samples( + images: list[Any], + targets: list[dict[str, np.ndarray]], + padding_masks: list[Any] | None = None, + max_samples: int = 4, +) -> None: + nb_samples = min(len(images), max_samples) + _, axes = plt.subplots(3, nb_samples, figsize=(20, 8)) + + if nb_samples == 1: + axes = np.expand_dims(axes, axis=1) + for idx in range(nb_samples): - img = (255 * images[idx].numpy()).round().clip(0, 255).astype(np.uint8) + img = (255 * images[idx].detach().cpu().numpy()).round().clip(0, 255).astype(np.uint8) if img.shape[0] == 3 and img.shape[2] != 3: img = img.transpose(1, 2, 0) + axes[0][idx].imshow(img) + axes[0][idx].set_title("Image") + target = np.zeros(img.shape[:2], np.uint8) tgts = targets[idx].copy() for boxes in tgts.values(): boxes[:, [0, 2]] = boxes[:, [0, 2]] * img.shape[1] boxes[:, [1, 3]] = boxes[:, [1, 3]] * img.shape[0] boxes[:, :4] = boxes[:, :4].round().astype(int) - for box in boxes: if boxes.ndim == 3: cv2.fillPoly(target, [np.intp(box)], 1) else: target[int(box[1]) : int(box[3]) + 1, int(box[0]) : int(box[2]) + 1] = 1 - if nb_samples > 1: - axes[0][idx].imshow(img) - axes[1][idx].imshow(target.astype(bool)) - else: - axes[0].imshow(img) - axes[1].imshow(target.astype(bool)) - # Disable axis + axes[1][idx].imshow(target.astype(bool), cmap="gray") + axes[1][idx].set_title("GT Boxes") + + if padding_masks is not None and padding_masks[idx] is not None: + pm = padding_masks[idx].detach().cpu().numpy() + pm = pm.squeeze().astype(bool) + axes[2][idx].imshow(pm, cmap="gray") + axes[2][idx].set_title("Padding Mask") + else: + axes[2][idx].text(0.5, 0.5, "No mask", ha="center", va="center") + axes[2][idx].set_title("Padding Mask") for ax in axes.ravel(): ax.axis("off") + plt.tight_layout() plt.show() diff --git a/references/recognition/train.py b/references/recognition/train.py index a4940ad5cf..dc6b7b1b24 100644 --- a/references/recognition/train.py +++ b/references/recognition/train.py @@ -359,12 +359,12 @@ def main(args): T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True), # Augmentations T.RandomApply(T.ColorInversion(), 0.1), - RandomGrayscale(p=0.1), - RandomPhotometricDistort(p=0.1), + T.ImageTorchvisionTransform(RandomGrayscale(p=0.1)), + T.ImageTorchvisionTransform(RandomPhotometricDistort(p=0.1)), T.RandomApply(T.RandomShadow(), p=0.4), T.RandomApply(T.GaussianNoise(mean=0, std=0.1), 0.1), T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3), - RandomPerspective(distortion_scale=0.2, p=0.3), + T.ImageTorchvisionTransform(RandomPerspective(distortion_scale=0.2, p=0.3)), ]), ) if len(parts) > 1: @@ -409,12 +409,12 @@ def main(args): T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True), # Ensure we have a 90% split of white-background images T.RandomApply(T.ColorInversion(), 0.9), - RandomGrayscale(p=0.1), - RandomPhotometricDistort(p=0.1), + T.ImageTorchvisionTransform(RandomGrayscale(p=0.1)), + T.ImageTorchvisionTransform(RandomPhotometricDistort(p=0.1)), T.RandomApply(T.RandomShadow(), p=0.4), T.RandomApply(T.GaussianNoise(mean=0, std=0.1), 0.1), T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3), - RandomPerspective(distortion_scale=0.2, p=0.3), + T.ImageTorchvisionTransform(RandomPerspective(distortion_scale=0.2, p=0.3)), ]), ) if distributed: diff --git a/tests/common/test_datasets.py b/tests/common/test_datasets.py index b86fbfc8f5..669e0f140e 100644 --- a/tests/common/test_datasets.py +++ b/tests/common/test_datasets.py @@ -26,41 +26,61 @@ def test_abstractdataset(mock_image_path): # Check target format with pytest.raises(AssertionError): ds.data = [(path.name, 0)] - img, target = ds[0] + _ = ds[0] with pytest.raises(AssertionError): ds.data = [(path.name, dict(boxes=np.array([[0, 0, 1, 1]])))] - img, target = ds[0] + _ = ds[0] with pytest.raises(AssertionError): - ds.data = [(ds.data[0][0], {"label": "A"})] - img, target = ds[0] + ds.data = [(path.name, {"label": "A"})] + _ = ds[0] # Patch some data ds.data = [(path.name, np.array([0]))] # Fetch the img - img, target = ds[0] - assert isinstance(target, np.ndarray) and target == np.array([0]) + sample = ds[0] + img, target = sample.image, sample.target + assert isinstance(target, np.ndarray) + assert np.array_equal(target, np.array([0])) # Check img_transforms - ds.img_transforms = lambda x: 1 - x - img2, target2 = ds[0] + def img_transform(sample): + sample.image = 1 - sample.image + return sample + + ds.img_transforms = img_transform + + sample2 = ds[0] + img2, target2 = sample2.image, sample2.target + assert np.all(img2.numpy() == 1 - img.numpy()) - assert target == target2 + assert np.array_equal(target, target2) # Check sample_transforms ds.img_transforms = None - ds.sample_transforms = lambda x, y: (x, y + 1) - img3, target3 = ds[0] - assert np.all(img3.numpy() == img.numpy()) and (target3 == (target + 1)) + + def sample_transform(sample): + sample.target = sample.target + 1 + return sample + + ds.sample_transforms = sample_transform + + sample3 = ds[0] + img3, target3 = sample3.image, sample3.target + + assert np.all(img3.numpy() == img.numpy()) + assert np.array_equal(target3, target + 1) # Check inplace modifications ds.data = [(ds.data[0][0], "A")] - def inplace_transfo(x, target): - target += "B" - return x, target + def inplace_transfo(sample): + sample.target += "B" + return sample ds.sample_transforms = inplace_transfo - _, t = ds[0] - _, t = ds[0] + + t = ds[0].target + t = ds[0].target + assert t == "AB" diff --git a/tests/common/test_transforms.py b/tests/common/test_transforms.py index 0ad9fdb138..5a2c853f98 100644 --- a/tests/common/test_transforms.py +++ b/tests/common/test_transforms.py @@ -3,17 +3,29 @@ from doctr.transforms import modules as T from doctr.transforms.functional.base import expand_line +from doctr.utils import Sample def test_imagetransform(): - transfo = T.ImageTransform(lambda x: 1 - x) - assert transfo(0, 1) == (1, 1) + transfo = T.ImageTransform(lambda sample: 1 - sample.image) + assert transfo(Sample(image=0, target=1)) == Sample(image=1, target=1) def test_samplecompose(): - transfos = [lambda x, y: (1 - x, y), lambda x, y: (x, 2 * y)] + transfos = [ + lambda sample: Sample( + image=1 - sample.image, + target=sample.target, + mask=sample.mask, + ), + lambda sample: Sample( + image=sample.image, + target=2 * sample.target, + mask=sample.mask, + ), + ] transfo = T.SampleCompose(transfos) - assert transfo(0, 1) == (1, 2) + assert transfo(Sample(image=0, target=1)) == Sample(image=1, target=2) def test_oneof(): @@ -23,11 +35,24 @@ def test_oneof(): assert out == 0 or out == 11 # test with ndarray target - transfos = [lambda x, y: (1 - x, y), lambda x, y: (x + 10, y)] + transfos = [ + lambda sample: Sample( + image=1 - sample.image, + target=sample.target, + mask=sample.mask, + ), + lambda sample: Sample( + image=sample.image + 10, + target=sample.target, + mask=sample.mask, + ), + ] + transfo = T.OneOf(transfos) - out = transfo(1, np.array([2])) - assert out == (0, 2) or out == (11, 2) - assert isinstance(out[1], np.ndarray) + out = transfo(Sample(image=1, target=np.array([2]))) + assert out.image == 0 or out.image == 11 + assert isinstance(out.target, np.ndarray) + np.testing.assert_array_equal(out.target, np.array([2])) # test with dict target dict_target = { @@ -35,22 +60,28 @@ def test_oneof(): "labels": np.array([1], dtype=np.int64), } transfos = [ - lambda x, y: (1 - x, y), - lambda x, y: (x + 10, y), + lambda sample: Sample( + image=1 - sample.image, + target=sample.target, + mask=sample.mask, + ), + lambda sample: Sample( + image=sample.image + 10, + target=sample.target, + mask=sample.mask, + ), ] transfo = T.OneOf(transfos) - out = transfo(1, dict_target) - assert out[0] == 0 or out[0] == 11 - assert isinstance(out[1], dict) - assert set(out[1].keys()) == {"boxes", "labels"} - assert isinstance(out[1]["boxes"], np.ndarray) - assert isinstance(out[1]["labels"], np.ndarray) + out = transfo(Sample(image=1, target=dict_target)) + assert out.image == 0 or out.image == 11 + assert isinstance(out.target, dict) + assert set(out.target.keys()) == {"boxes", "labels"} np.testing.assert_array_equal( - out[1]["boxes"], + out.target["boxes"], np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32), ) np.testing.assert_array_equal( - out[1]["labels"], + out.target["labels"], np.array([1], dtype=np.int64), ) @@ -59,12 +90,21 @@ def test_randomapply(): transfo = T.RandomApply(lambda x: 1 - x) out = transfo(1) assert out == 0 or out == 1 - # test with ndarray target - transfo = T.RandomApply(lambda x, y: (1 - x, 2 * y)) - out = transfo(1, np.array([2])) - assert out == (0, 4) or out == (1, 2) - assert isinstance(out[1], np.ndarray) + transfo = T.RandomApply( + lambda sample: Sample( + image=1 - sample.image, + target=2 * sample.target, + mask=sample.mask, + ) + ) + out = transfo(Sample(image=1, target=np.array([2]))) + assert out.image == 0 or out.image == 1 + assert isinstance(out.target, np.ndarray) + if out.image == 0: + np.testing.assert_array_equal(out.target, np.array([4])) + else: + np.testing.assert_array_equal(out.target, np.array([2])) # test with dict target dict_target = { @@ -72,33 +112,30 @@ def test_randomapply(): "labels": np.array([1], dtype=np.int64), } transfo = T.RandomApply( - lambda x, y: ( - 1 - x, - { - "boxes": 2 * y["boxes"], - "labels": y["labels"], + lambda sample: Sample( + image=1 - sample.image, + target={ + "boxes": 2 * sample.target["boxes"], + "labels": sample.target["labels"], }, + mask=sample.mask, ) ) - - out = transfo(1, dict_target) - assert out[0] == 0 or out[0] == 1 - assert isinstance(out[1], dict) - assert set(out[1].keys()) == {"boxes", "labels"} - assert isinstance(out[1]["boxes"], np.ndarray) - assert isinstance(out[1]["labels"], np.ndarray) - if out[0] == 0: + out = transfo(Sample(image=1, target=dict_target)) + assert out.image == 0 or out.image == 1 + assert isinstance(out.target, dict) + if out.image == 0: np.testing.assert_array_equal( - out[1]["boxes"], + out.target["boxes"], 2 * np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32), ) else: np.testing.assert_array_equal( - out[1]["boxes"], + out.target["boxes"], np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32), ) np.testing.assert_array_equal( - out[1]["labels"], + out.target["labels"], np.array([1], dtype=np.int64), ) assert repr(transfo).endswith(", p=0.5)") diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py index 7c1e97d808..00e8b2dbce 100644 --- a/tests/pytorch/test_datasets_pt.py +++ b/tests/pytorch/test_datasets_pt.py @@ -14,7 +14,8 @@ def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False): # Fetch one sample - img, target = ds[0] + sample = ds[0] + img, target = sample.image, sample.target assert isinstance(img, torch.Tensor) assert img.shape == (3, *input_size) @@ -50,7 +51,8 @@ def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_poly def _validate_dataset_recognition_part(ds, input_size, batch_size=2): # Fetch one sample - img, label = ds[0] + sample = ds[0] + img, label = sample.image, sample.target assert isinstance(img, torch.Tensor) assert img.shape == (3, *input_size) @@ -75,7 +77,8 @@ def _validate_dataset_recognition_part(ds, input_size, batch_size=2): def _validate_dataset_detection_part(ds, input_size, batch_size=2, is_polygons=False): # Fetch one sample - img, target = ds[0] + sample = ds[0] + img, target = sample.image, sample.target assert isinstance(img, torch.Tensor) assert img.shape == (3, *input_size) @@ -118,7 +121,8 @@ def test_rotation_dataset(mock_image_folder): ds = datasets.OrientationDataset(img_folder=mock_image_folder, img_transforms=Resize(input_size)) assert len(ds) == 5 - img, target = ds[0] + sample = ds[0] + img, target = sample.image, sample.target assert isinstance(img, torch.Tensor) assert img.dtype == torch.float32 assert img.shape[-2:] == input_size @@ -143,7 +147,8 @@ def test_detection_dataset(mock_image_folder, mock_detection_label): ) assert len(ds) == 5 - img, target_dict = ds[0] + sample = ds[0] + img, target_dict = sample.image, sample.target target = target_dict[CLASS_NAME] assert isinstance(img, torch.Tensor) assert img.dtype == torch.float32 @@ -167,7 +172,7 @@ def test_detection_dataset(mock_image_folder, mock_detection_label): img_transforms=Resize(input_size), use_polygons=True, ) - _, r_target = rotated_ds[0] + r_target = rotated_ds[0].target assert r_target[CLASS_NAME].shape[1:] == (4, 2) # File existence check @@ -193,9 +198,8 @@ def test_layout_dataset(mock_image_folder, mock_layout_label, use_polygons): ) assert len(ds) == 5 - inputs, target_dict = ds[0] - assert isinstance(inputs, tuple) and len(inputs) == 2 - img, padding_mask = inputs + sample = ds[0] + img, padding_mask, target_dict = sample.image, sample.mask, sample.target assert isinstance(img, torch.Tensor) assert img.dtype == torch.float32 assert img.shape[-2:] == input_size @@ -308,7 +312,8 @@ def test_recognition_dataset(mock_image_folder, mock_recognition_label): img_transforms=Resize(input_size, preserve_aspect_ratio=True), ) assert len(ds) == 5 - image, label = ds[0] + sample = ds[0] + image, label = sample.image, sample.target assert isinstance(image, torch.Tensor) assert image.shape[-2:] == input_size assert image.dtype == torch.float32 @@ -363,7 +368,8 @@ def test_charactergenerator(): ) assert len(ds) == 10 - image, label = ds[0] + sample = ds[0] + image, label = sample.image, sample.target assert isinstance(image, torch.Tensor) assert image.shape[-2:] == input_size assert image.dtype == torch.float32 @@ -372,8 +378,8 @@ def test_charactergenerator(): loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) images, targets = next(iter(loader)) assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size) - assert isinstance(targets, torch.Tensor) and targets.shape == (2,) - assert targets.dtype == torch.int64 + assert isinstance(targets, list) and len(targets) == 2 + assert all(isinstance(t, int) for t in targets) def test_wordgenerator(): @@ -391,7 +397,8 @@ def test_wordgenerator(): ) assert len(ds) == 10 - image, target = ds[0] + sample = ds[0] + image, target = sample.image, sample.target assert isinstance(image, torch.Tensor) assert image.shape[-2:] == input_size assert image.dtype == torch.float32 diff --git a/tests/pytorch/test_models_layout.py b/tests/pytorch/test_models_layout.py index aa7e687114..d0c2865411 100644 --- a/tests/pytorch/test_models_layout.py +++ b/tests/pytorch/test_models_layout.py @@ -13,6 +13,7 @@ @pytest.mark.parametrize("train_mode", [True, False]) +@pytest.mark.parametrize("use_polygons", [True, False]) @pytest.mark.parametrize( "arch_name, input_shape", [ @@ -20,38 +21,42 @@ ["lw_detr_m", (3, 1024, 1024)], ], ) -def test_layout_models(arch_name, input_shape, train_mode): +def test_layout_models(arch_name, input_shape, train_mode, use_polygons): batch_size = 2 model = layout.__dict__[arch_name](pretrained=True) model = model.train() if train_mode else model.eval() assert isinstance(model, torch.nn.Module) input_tensor = torch.rand((batch_size, *input_shape)) input_masks = torch.ones((batch_size, input_shape[1], input_shape[2]), dtype=torch.bool) + + class_names = model.class_names + target = [] for _ in range(batch_size): + sample_target = {} num_boxes = 5 - - class_ids = torch.randint(0, 10, (num_boxes,)).tolist() - - # random boxes in normalized-ish space - boxes = [] for _ in range(num_boxes): - cx, cy = torch.rand(2) * 0.8 + 0.1 - w, h = torch.rand(2) * 0.2 + 0.05 - - x1, y1 = cx - w / 2, cy - h / 2 - x2, y2 = cx + w / 2, cy - h / 2 - x3, y3 = cx + w / 2, cy + h / 2 - x4, y4 = cx - w / 2, cy + h / 2 - - boxes.append([ - [x1, y1], - [x2, y2], - [x3, y3], - [x4, y4], - ]) - - target.append((class_ids, torch.tensor(boxes, dtype=torch.float32))) + cls_name = np.random.choice(class_names) + x1, y1 = torch.rand(2) * 0.8 + if use_polygons: + w, h = 0.1, 0.1 + + box = np.array( + [ + [x1, y1], + [x1 + w, y1], + [x1 + w, y1 + h], + [x1, y1 + h], + ], + dtype=np.float32, + ) # (4,2) + else: + x2, y2 = x1 + 0.1, y1 + 0.1 + box = np.array([x1, y1, x2, y2], dtype=np.float32) # (4,) + sample_target.setdefault(cls_name, []) + sample_target[cls_name].append(box) + target.append(sample_target) + if torch.cuda.is_available(): model.cuda() input_tensor = input_tensor.cuda() diff --git a/tests/pytorch/test_transforms_pt.py b/tests/pytorch/test_transforms_pt.py index 657e1561fc..2a41a712b1 100644 --- a/tests/pytorch/test_transforms_pt.py +++ b/tests/pytorch/test_transforms_pt.py @@ -4,13 +4,14 @@ import numpy as np import pytest import torch +from torchvision.transforms.v2 import RandomGrayscale from doctr.transforms import ( ChannelShuffle, ColorInversion, GaussianBlur, GaussianNoise, - ImageTransform, + ImageTorchvisionTransform, OneOf, RandomApply, RandomCrop, @@ -22,13 +23,14 @@ SampleCompose, ) from doctr.transforms.functional import crop_detection, rotate_sample +from doctr.utils import Sample def test_resize(): output_size = (32, 32) transfo = Resize(output_size) - input_t = torch.ones((3, 64, 64), dtype=torch.float32) - out = transfo(input_t) + input_t = Sample(image=torch.ones((3, 64, 64), dtype=torch.float32)) + out = transfo(input_t).image assert torch.all(out == 1) assert out.shape[-2:] == output_size @@ -36,7 +38,8 @@ def test_resize(): # Test return_padding_mask without aspect ratio transfo = Resize(output_size, return_padding_mask=True) - out, mask = transfo(input_t) + data = transfo(input_t) + out, mask = data.image, data.mask assert out.shape[-2:] == output_size assert mask.shape == output_size assert mask.dtype == torch.bool @@ -44,18 +47,19 @@ def test_resize(): # Test with preserve_aspect_ratio output_size = (32, 32) - input_t = torch.ones((3, 32, 64), dtype=torch.float32) + input_t = Sample(image=torch.ones((3, 32, 64), dtype=torch.float32)) # Asymmetric padding transfo = Resize(output_size, preserve_aspect_ratio=True) - out = transfo(input_t) + out = transfo(input_t).image assert out.shape[-2:] == output_size assert not torch.all(out == 1) assert torch.all(out[:, -1] == 0) and torch.all(out[:, 0] == 1) # Asymmetric padding mask transfo = Resize(output_size, preserve_aspect_ratio=True, return_padding_mask=True) - out, mask = transfo(input_t) + data = transfo(input_t) + out, mask = data.image, data.mask assert mask.shape == output_size assert mask.dtype == torch.bool assert mask.any() @@ -64,13 +68,14 @@ def test_resize(): # Symmetric padding transfo = Resize(32, preserve_aspect_ratio=True, symmetric_pad=True) - out = transfo(input_t) + out = transfo(input_t).image assert out.shape[-2:] == output_size assert torch.all(out[:, 0] == 0) and torch.all(out[:, -1] == 0) # Symmetric padding mask transfo = Resize(32, preserve_aspect_ratio=True, symmetric_pad=True, return_padding_mask=True) - out, mask = transfo(input_t) + data = transfo(input_t) + out, mask = data.image, data.mask assert mask.shape == output_size assert mask.dtype == torch.bool assert mask.any() @@ -81,20 +86,20 @@ def test_resize(): assert repr(transfo) == expected # Test with inverse resize - input_t = torch.ones((3, 64, 32), dtype=torch.float32) + input_t = Sample(image=torch.ones((3, 64, 32), dtype=torch.float32)) transfo = Resize(32, preserve_aspect_ratio=True, symmetric_pad=True) - out = transfo(input_t) + out = transfo(input_t).image assert out.shape[-2:] == (32, 32) # Test resize with same ratio transfo = Resize((32, 128), preserve_aspect_ratio=True) - out = transfo(torch.ones((3, 16, 64), dtype=torch.float32)) + out = transfo(Sample(image=torch.ones((3, 16, 64), dtype=torch.float32))).image assert out.shape[-2:] == (32, 128) # Test with fp16 input transfo = Resize((32, 128), preserve_aspect_ratio=True) - input_t = torch.ones((3, 64, 64), dtype=torch.float16) - out = transfo(input_t) + input_t = Sample(image=torch.ones((3, 64, 64), dtype=torch.float16)) + out = transfo(input_t).image assert out.dtype == torch.float16 padding = [True, False] @@ -102,8 +107,9 @@ def test_resize(): # Test with target boxes target_boxes = np.array([[0.1, 0.1, 0.3, 0.4], [0.2, 0.2, 0.8, 0.8]]) transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=symmetric_pad, return_padding_mask=True) - input_t = torch.ones((3, 32, 64), dtype=torch.float32) - out, new_target, mask = transfo(input_t, target_boxes) + input_t = Sample(image=torch.ones((3, 32, 64), dtype=torch.float32), target=target_boxes) + data = transfo(input_t) + out, mask, new_target = data.image, data.mask, data.target assert out.shape[-2:] == (64, 64) assert new_target.shape == target_boxes.shape @@ -117,8 +123,9 @@ def test_resize(): [[0.2, 0.2], [0.8, 0.2], [0.8, 0.8], [0.2, 0.8]], ]) transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=symmetric_pad, return_padding_mask=True) - input_t = torch.ones((3, 32, 64), dtype=torch.float32) - out, new_target, mask = transfo(input_t, target_boxes) + input_t = Sample(image=torch.ones((3, 32, 64), dtype=torch.float32), target=target_boxes) + data = transfo(input_t) + out, mask, new_target = data.image, data.mask, data.target assert out.shape[-2:] == (64, 64) assert new_target.shape == target_boxes.shape @@ -132,7 +139,7 @@ def test_resize(): transfo = Resize((64, 64), preserve_aspect_ratio=True) with pytest.raises(AssertionError): - transfo(input_t, target) + transfo(Sample(image=input_t, target=target)) # Test dict targets target_dict = { @@ -144,7 +151,7 @@ def test_resize(): preserve_aspect_ratio=True, symmetric_pad=True, ) - _, new_target = transfo(input_t, target_dict) + new_target = transfo(Sample(image=input_t, target=target_dict)).target assert isinstance(new_target, dict) assert set(new_target.keys()) == {"boxes", "polygons"} assert new_target["boxes"].shape == (1, 4) @@ -152,18 +159,21 @@ def test_resize(): # Test return type combinations transfo = Resize((32, 32)) - out = transfo(input_t) + out = transfo(Sample(image=input_t)).image assert isinstance(out, torch.Tensor) transfo = Resize((32, 32), return_padding_mask=True) - out = transfo(input_t) - assert isinstance(out, tuple) - assert len(out) == 2 + out = transfo(Sample(image=input_t)) + assert isinstance(out, Sample) + assert hasattr(out, "image") and hasattr(out, "mask") + assert out.image.shape[-2:] == (32, 32) + assert out.mask.shape[-2:] == (32, 32) transfo = Resize((32, 32), preserve_aspect_ratio=True) - out = transfo(input_t, target_boxes) - assert isinstance(out, tuple) - assert len(out) == 2 + out = transfo(Sample(image=input_t, target=target_boxes)) + assert isinstance(out, Sample) + assert hasattr(out, "image") and hasattr(out, "target") + assert out.image.shape[-2:] == (32, 32) transfo = Resize( (32, 32), @@ -171,9 +181,11 @@ def test_resize(): return_padding_mask=True, ) - out = transfo(input_t, target_boxes) - assert isinstance(out, tuple) - assert len(out) == 3 + out = transfo(Sample(image=input_t, target=target_boxes)) + assert isinstance(out, Sample) + assert hasattr(out, "image") and hasattr(out, "mask") and hasattr(out, "target") + assert out.image.shape[-2:] == (32, 32) + assert out.mask.shape[-2:] == (32, 32) @pytest.mark.parametrize( @@ -186,21 +198,23 @@ def test_resize(): ) def test_invert_colorize(rgb_min): transfo = ColorInversion(min_val=rgb_min) - input_t = torch.ones((8, 3, 32, 32), dtype=torch.float32) - out = transfo(input_t) + input_t = Sample(image=torch.ones((8, 3, 32, 32), dtype=torch.float32)) + out = transfo(input_t).image assert torch.all(out <= 1 - rgb_min + 1e-4) assert torch.all(out >= 0) - input_t = torch.full((8, 3, 32, 32), 255, dtype=torch.uint8) - out = transfo(input_t) + input_t = Sample(image=torch.full((8, 3, 32, 32), 255, dtype=torch.uint8)) + out = transfo(input_t).image assert torch.all(out <= int(math.ceil(255 * (1 - rgb_min + 1e-4)))) assert torch.all(out >= 0) # FP16 - input_t = torch.ones((8, 3, 32, 32), dtype=torch.float16) - out = transfo(input_t) + input_t = Sample(image=torch.ones((8, 3, 32, 32), dtype=torch.float16)) + out = transfo(input_t).image assert out.dtype == torch.float16 + assert repr(transfo) == f"ColorInversion(min_val={rgb_min})" + def test_rotate_sample(): img = torch.ones((3, 200, 100), dtype=torch.float32) @@ -254,19 +268,24 @@ def test_random_rotate(): rotator = RandomRotate(max_angle=10.0, expand=False) input_t = torch.ones((3, 50, 50), dtype=torch.float32) boxes = np.array([[15, 20, 35, 30]]) - r_img, _r_boxes = rotator(input_t, boxes) + data = rotator(Sample(image=input_t, target=boxes)) + r_img = data.image assert r_img.shape == input_t.shape rotator = RandomRotate(max_angle=10.0, expand=True) - r_img, _r_boxes = rotator(input_t, boxes) + data = rotator(Sample(image=input_t, target=boxes)) + r_img = data.image assert r_img.shape != input_t.shape + assert repr(rotator) == "RandomRotate(max_angle=10.0, expand=True)" + # Test dict targets dict_target = { "boxes": np.array([[15, 20, 35, 30]]), "polygons": np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]]), } - r_img, r_targets = rotator(input_t, dict_target) + data = rotator(Sample(image=input_t, target=dict_target)) + r_img, r_targets = data.image, data.target assert isinstance(r_targets, dict) assert set(r_targets.keys()) == {"boxes", "polygons"} assert isinstance(r_targets["boxes"], np.ndarray) @@ -287,7 +306,8 @@ def test_random_rotate(): "boxes": np.zeros((0, 4), dtype=np.float32), "polygons": np.zeros((0, 4, 2), dtype=np.float32), } - r_img, r_targets = rotator(input_t, empty_targets) + data = rotator(Sample(image=input_t, target=empty_targets)) + r_img, r_targets = data.image, data.target assert isinstance(r_targets, dict) assert r_targets["boxes"].shape == (0, 4) assert r_targets["polygons"].shape == (0, 4, 2) @@ -295,7 +315,8 @@ def test_random_rotate(): # FP16 (only on GPU) if torch.cuda.is_available(): input_t = torch.ones((3, 50, 50), dtype=torch.float16).cuda() - r_img, _ = rotator(input_t, boxes) + data = rotator(Sample(image=input_t, target=boxes)) + r_img = data.image assert r_img.dtype == torch.float16 @@ -339,9 +360,11 @@ def test_crop_detection(): ) def test_random_crop(target): cropper = RandomCrop(scale=(0.5, 1.0), ratio=(0.75, 1.33)) + assert repr(cropper) == "RandomCrop(scale=(0.5, 1.0), ratio=(0.75, 1.33))" input_t = torch.ones((3, 50, 50), dtype=torch.float32) - img, target = cropper(input_t, target) + sample = cropper(Sample(image=input_t, target=target)) + img, target = sample.image, sample.target # Check the scale assert img.shape[-1] * img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2] # Check aspect ratio @@ -358,7 +381,8 @@ def test_random_crop(target): "boxes": np.array([[15, 20, 35, 30]]), "polygons": np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]]), } - img, cropped_targets = cropper(input_t, dict_target) + sample = cropper(Sample(image=input_t, target=dict_target)) + img, cropped_targets = sample.image, sample.target assert isinstance(cropped_targets, dict) assert set(cropped_targets.keys()) == {"boxes", "polygons"} assert isinstance(cropped_targets["boxes"], np.ndarray) @@ -387,20 +411,20 @@ def test_random_crop(target): ) def test_channel_shuffle(input_dtype, input_size): transfo = ChannelShuffle() - input_t = torch.rand(input_size, dtype=torch.float32) + input_t = Sample(image=torch.rand(input_size, dtype=torch.float32)) if input_dtype == torch.uint8: - input_t = (255 * input_t).round() - input_t = input_t.to(dtype=input_dtype) - out = transfo(input_t) + input_t.image = (255 * input_t.image).round() + input_t.image = input_t.image.to(dtype=input_dtype) + out = transfo(input_t).image assert isinstance(out, torch.Tensor) assert out.shape == input_size assert out.dtype == input_dtype # Ensure that nothing has changed apart from channel order if input_dtype == torch.uint8: - assert torch.all(input_t.sum(0) == out.sum(0)) + assert torch.all(input_t.image.sum(0) == out.sum(0)) else: # Float approximation - assert (input_t.sum(0) - out.sum(0)).abs().mean() < 1e-7 + assert (input_t.image.sum(0) - out.sum(0)).abs().mean() < 1e-7 @pytest.mark.parametrize( @@ -412,15 +436,15 @@ def test_channel_shuffle(input_dtype, input_size): ) def test_gaussian_noise(input_dtype, input_shape): transform = GaussianNoise(0.0, 1.0) - input_t = torch.rand(input_shape, dtype=torch.float32) + input_t = Sample(image=torch.rand(input_shape, dtype=torch.float32)) if input_dtype == torch.uint8: - input_t = (255 * input_t).round() - input_t = input_t.to(dtype=input_dtype) - transformed = transform(input_t) + input_t.image = (255 * input_t.image).round() + input_t.image = input_t.image.to(dtype=input_dtype) + transformed = transform(input_t).image assert isinstance(transformed, torch.Tensor) assert transformed.shape == input_shape assert transformed.dtype == input_dtype - assert torch.any(transformed != input_t) + assert torch.any(transformed != input_t.image) assert torch.all(transformed >= 0) if input_dtype == torch.uint8: assert torch.all(transformed <= 255) @@ -439,23 +463,22 @@ def test_gaussian_blur(input_dtype, input_shape): sigma_range = (0.5, 1.5) transform = GaussianBlur(sigma=sigma_range) - input_t = torch.rand(input_shape, dtype=torch.float32) + input_t = Sample(image=torch.rand(input_shape, dtype=torch.float32)) if input_dtype == torch.uint8: - input_t = (255 * input_t).round().to(dtype=torch.uint8) - - blurred = transform(input_t) + input_t.image = (255 * input_t.image).round().to(dtype=torch.uint8) + blurred = transform(input_t).image assert isinstance(blurred, torch.Tensor) assert blurred.shape == input_shape assert blurred.dtype == input_dtype if input_dtype == torch.uint8: - assert torch.any(blurred != input_t) + assert torch.any(blurred != input_t.image) assert torch.all(blurred <= 255) assert torch.all(blurred >= 0) else: - assert torch.any(blurred != input_t) + assert torch.any(blurred != input_t.image) assert torch.all(blurred <= 1.0) assert torch.all(blurred >= 0.0) @@ -475,7 +498,8 @@ def test_randomhorizontalflip(p, target): input_t = torch.ones((3, 32, 32), dtype=torch.float32) input_t[..., :16] = 0 - transformed, _target = transform(input_t, target) + data = transform(Sample(image=input_t, target=target)) + transformed, _target = data.image, data.target assert isinstance(transformed, torch.Tensor) assert transformed.shape == input_t.shape assert transformed.dtype == input_t.dtype @@ -506,7 +530,8 @@ def test_randomhorizontalflip(p, target): ), } - transformed, _target = transform(input_t, dict_target) + data = transform(Sample(image=input_t, target=dict_target)) + _target = data.target assert isinstance(_target, dict) assert set(_target.keys()) == {"boxes", "polygons"} assert _target["boxes"].dtype == np.float32 @@ -542,16 +567,16 @@ def test_randomhorizontalflip(p, target): ) def test_random_shadow(input_dtype, input_shape): transform = RandomShadow((0.2, 0.8)) - input_t = torch.ones(input_shape, dtype=torch.float32) + input_t = Sample(image=torch.ones(input_shape, dtype=torch.float32)) if input_dtype == torch.uint8: - input_t = (255 * input_t).round() - input_t = input_t.to(dtype=input_dtype) - transformed = transform(input_t) + input_t.image = (255 * input_t.image).round() + input_t.image = input_t.image.to(dtype=input_dtype) + transformed = transform(input_t).image assert isinstance(transformed, torch.Tensor) assert transformed.shape == input_shape assert transformed.dtype == input_dtype # The shadow will darken the picture - assert input_t.float().mean() >= transformed.float().mean() + assert input_t.image.float().mean() >= transformed.float().mean() assert torch.all(transformed >= 0) if input_dtype == torch.uint8: assert torch.all(transformed <= 255) @@ -581,7 +606,8 @@ def test_random_resize(p, preserve_aspect_ratio, symmetric_pad, target): img = torch.rand((3, 64, 64)) # Apply the transformation - out_img, out_target = transfo(img, target) + data = transfo(Sample(image=img, target=target)) + out_img, out_target = data.image, data.target assert isinstance(out_img, torch.Tensor) assert isinstance(out_target, np.ndarray) # Resize is already well tested @@ -605,15 +631,16 @@ def _make_pipeline(): symmetric_pad=True, p=1.0, ), - ImageTransform(ColorInversion(min_val=0.7)), - ImageTransform(GaussianNoise(mean=0.0, std=0.1)), - ImageTransform(ChannelShuffle()), - ImageTransform(RandomShadow((0.2, 0.8))), + ColorInversion(min_val=0.7), + GaussianNoise(mean=0.0, std=0.1), + ChannelShuffle(), + RandomShadow((0.2, 0.8)), RandomApply(RandomHorizontalFlip(p=1.0), p=1.0), OneOf([ RandomRotate(max_angle=5.0, expand=False), RandomCrop(scale=(0.9, 1.0), ratio=(0.95, 1.05)), ]), + ImageTorchvisionTransform(RandomGrayscale(p=0.15)), ]) @@ -634,7 +661,8 @@ def test_samplecompose_end_to_end_boxes(): } transforms = _make_pipeline() - out_img, out_targets = transforms(input_t, targets) + data = transforms(Sample(image=input_t, target=targets)) + out_img, out_targets = data.image, data.target # image checks assert isinstance(out_img, torch.Tensor) @@ -688,7 +716,8 @@ def test_samplecompose_end_to_end_polygons(): } transforms = _make_pipeline() - out_img, out_targets = transforms(input_t, targets) + data = transforms(Sample(image=input_t, target=targets)) + out_img, out_targets = data.image, data.target # image checks assert isinstance(out_img, torch.Tensor) From 21e0f590ecedeea3118486432864f66621ca095e Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 21 May 2026 14:40:47 +0200 Subject: [PATCH 4/9] Fix scripts --- .github/workflows/references.yml | 8 ++++---- scripts/evaluate.py | 17 ++++++++++------- scripts/evaluate_kie.py | 20 +++++++++++++------- 3 files changed, 27 insertions(+), 18 deletions(-) diff --git a/.github/workflows/references.yml b/.github/workflows/references.yml index 0336cfd032..d010ff3310 100644 --- a/.github/workflows/references.yml +++ b/.github/workflows/references.yml @@ -281,9 +281,9 @@ jobs: pip install -r references/requirements.txt - name: Download and extract toy set run: | - wget https://github.com/mindee/doctr/releases/download/v1.0.1/toy_layout_set-c79b4e69.zip + wget https://github.com/mindee/doctr/releases/download/v1.0.1/toy_layout_set-c930013e.zip sudo apt-get update && sudo apt-get install unzip -y - unzip toy_layout_set-c79b4e69.zip -d layout_set + unzip toy_layout_set-c930013e.zip -d layout_set - name: Train for a short epoch run: python references/layout/train.py lw_detr_s --train_path ./layout_set --val_path ./layout_set -b 2 --epochs 1 @@ -313,9 +313,9 @@ jobs: pip install -r references/requirements.txt - name: Download and extract toy set run: | - wget https://github.com/mindee/doctr/releases/download/v1.0.1/toy_layout_set-c79b4e69.zip + wget https://github.com/mindee/doctr/releases/download/v1.0.1/toy_layout_set-c930013e.zip sudo apt-get update && sudo apt-get install unzip -y - unzip toy_layout_set-c79b4e69.zip -d layout_set + unzip toy_layout_set-c930013e.zip -d layout_set - name: Evaluate layout analysis run: python references/layout/evaluate.py lw_detr_s ./layout_set diff --git a/scripts/evaluate.py b/scripts/evaluate.py index 8846ac5c3e..06ca045627 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -11,6 +11,7 @@ from doctr import datasets from doctr import transforms as T from doctr.models import ocr_predictor +from doctr.utils import Sample from doctr.utils.geometry import extract_crops, extract_rcrops from doctr.utils.metrics import LocalizationConfusion, OCRMetric, TextMatch @@ -27,12 +28,13 @@ def main(args): # We define a transformation function which does transform the annotation # to the required format for the Resize transformation - def _transform(img, target): - boxes = target["boxes"] - transformed_img, transformed_boxes = T.Resize( - input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad - )(img, boxes) - return transformed_img, {"boxes": transformed_boxes, "labels": target["labels"]} + def _transform(sample: Sample) -> Sample: + boxes, labels = sample.target["boxes"], sample.target["labels"] + sample = T.Resize(input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad)( + Sample(image=sample.image, target=boxes) + ) + sample.target = {"labels": labels, "boxes": sample.target} + return sample predictor = ocr_predictor( args.detection, @@ -78,7 +80,8 @@ def _transform(img, target): extraction_fn = extract_crops if args.eval_straight else extract_rcrops for dataset in sets: - for page, target in tqdm(dataset): + for data in tqdm(dataset): + page, target = data.image, data.target if isinstance(page, torch.Tensor): page = np.transpose(page.numpy(), (1, 2, 0)) # GT diff --git a/scripts/evaluate_kie.py b/scripts/evaluate_kie.py index ba14c49915..b29d799e4d 100644 --- a/scripts/evaluate_kie.py +++ b/scripts/evaluate_kie.py @@ -11,6 +11,7 @@ from doctr import transforms as T from doctr.io.elements import KIEDocument from doctr.models import kie_predictor +from doctr.utils import Sample from doctr.utils.geometry import extract_crops, extract_rcrops from doctr.utils.metrics import LocalizationConfusion, OCRMetric, TextMatch @@ -27,12 +28,13 @@ def main(args): # We define a transformation function which does transform the annotation # to the required format for the Resize transformation - def _transform(img, target): - boxes = target["boxes"] - transformed_img, transformed_boxes = T.Resize( - input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad - )(img, boxes) - return transformed_img, {"boxes": transformed_boxes, "labels": target["labels"]} + def _transform(sample: Sample) -> Sample: + boxes, labels = sample.target["boxes"], sample.target["labels"] + sample = T.Resize(input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad)( + Sample(image=sample.image, target=boxes) + ) + sample.target = {"labels": labels, "boxes": sample.target} + return sample predictor = kie_predictor( args.detection, @@ -44,6 +46,9 @@ def _transform(img, target): assume_straight_pages=not args.rotation, ) + if torch.cuda.is_available(): + predictor = predictor.cuda() + if args.img_folder and args.label_file: testset = datasets.OCRDataset( img_folder=args.img_folder, @@ -75,7 +80,8 @@ def _transform(img, target): extraction_fn = extract_crops if args.eval_straight else extract_rcrops for dataset in sets: - for page, target in tqdm(dataset): + for data in tqdm(dataset): + page, target = data.image, data.target if isinstance(page, torch.Tensor): page = np.transpose(page.numpy(), (1, 2, 0)) # GT From 45c05c7a9b6f1cdf8e427dced0cea21428cc62c3 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 21 May 2026 15:04:34 +0200 Subject: [PATCH 5/9] Update dataset reference workflow path --- .github/workflows/references.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/references.yml b/.github/workflows/references.yml index d010ff3310..594aa6ed7c 100644 --- a/.github/workflows/references.yml +++ b/.github/workflows/references.yml @@ -281,9 +281,9 @@ jobs: pip install -r references/requirements.txt - name: Download and extract toy set run: | - wget https://github.com/mindee/doctr/releases/download/v1.0.1/toy_layout_set-c930013e.zip + wget https://github.com/mindee/doctr/releases/download/v1.0.1/toy_layout_set-d4a8d4c9.zip sudo apt-get update && sudo apt-get install unzip -y - unzip toy_layout_set-c930013e.zip -d layout_set + unzip toy_layout_set-d4a8d4c9.zip -d layout_set - name: Train for a short epoch run: python references/layout/train.py lw_detr_s --train_path ./layout_set --val_path ./layout_set -b 2 --epochs 1 @@ -313,9 +313,9 @@ jobs: pip install -r references/requirements.txt - name: Download and extract toy set run: | - wget https://github.com/mindee/doctr/releases/download/v1.0.1/toy_layout_set-c930013e.zip + wget https://github.com/mindee/doctr/releases/download/v1.0.1/toy_layout_set-d4a8d4c9.zip sudo apt-get update && sudo apt-get install unzip -y - unzip toy_layout_set-c930013e.zip -d layout_set + unzip toy_layout_set-d4a8d4c9.zip -d layout_set - name: Evaluate layout analysis run: python references/layout/evaluate.py lw_detr_s ./layout_set From d942737150bee646902626805a8842486ca4fd99 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 21 May 2026 15:21:58 +0200 Subject: [PATCH 6/9] style --- doctr/datasets/datasets/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doctr/datasets/datasets/base.py b/doctr/datasets/datasets/base.py index 72c28a91df..a999a50166 100644 --- a/doctr/datasets/datasets/base.py +++ b/doctr/datasets/datasets/base.py @@ -44,7 +44,8 @@ def _read_sample(self, index: int) -> tuple[Any, Any]: def __getitem__(self, index: int) -> Sample: # Read image img, target = self._read_sample(index) - mask = None # FIX: always defined + mask = None + # Pre-transforms (format conversion at run-time etc.) if self._pre_transforms is not None: img, target = self._pre_transforms(img, target) From c04a53359ab31a8a57560886b09c7ab79e324fcf Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 21 May 2026 16:05:28 +0200 Subject: [PATCH 7/9] model updates --- doctr/models/layout/lw_detr/base.py | 6 ++- doctr/models/layout/lw_detr/layers/pytorch.py | 24 ++++-------- doctr/models/layout/lw_detr/pytorch.py | 38 +++++++++++++------ references/layout/train.py | 4 +- 4 files changed, 39 insertions(+), 33 deletions(-) diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index b0e9a48e09..6f58f81f35 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -233,8 +233,10 @@ def _quad_to_obb(poly: np.ndarray): theta = np.arctan2(dy, dx) - w = np.mean([lengths[0], lengths[2]]) - h = np.mean([lengths[1], lengths[3]]) + # w should always be the length of the edge aligned with theta + w = np.mean([lengths[i], lengths[(i + 2) % 4]]) + # h is the perpendicular edge + h = np.mean([lengths[(i + 1) % 4], lengths[(i + 3) % 4]]) return np.array( [cx, cy, w, h, np.sin(theta), np.cos(theta)], diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index b964ebfd73..4b59c83690 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -527,34 +527,24 @@ def get_reference( return reference_points_inputs, query_pos def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: - """Refine the reference points using the predicted deltas. - - Args: - reference_points: (batch_size, num_queries, 6) - tensor containing the current reference points in the format (cx, cy, w, h, sinθ, cosθ) - deltas: (batch_size, num_queries, 6) - tensor containing the predicted deltas for the reference points in the same format as reference_points - - Returns: - refined_reference_points: (batch_size, num_queries, 6) - tensor containing the refined reference points in the same format as reference_points - """ + reference_points = reference_points.to(deltas.device) cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] - wh = deltas[..., 2:4].exp() * reference_points[..., 2:4] - - delta_rot = F.normalize(deltas[..., 4:6], dim=-1) + # Clamp deltas to prevent exp() from shooting to Infinity during early training + wh = torch.clamp(deltas[..., 2:4], min=-10.0, max=10.0).exp() * reference_points[..., 2:4] + # Add eps=1e-6 to avoid division-by-zero NaN creation + delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) sin_delta = delta_rot[..., 0:1] cos_delta = delta_rot[..., 1:2] - sin_ref = reference_points[..., 4:5] cos_ref = reference_points[..., 5:6] sin_new = sin_ref * cos_delta + cos_ref * sin_delta cos_new = cos_ref * cos_delta - sin_ref * sin_delta - rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1) + # Add eps=1e-6 here too + rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) return torch.cat((cxcy, wh, rot), dim=-1) diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 5c7f68c0f2..89fc428bc8 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -156,8 +156,8 @@ def __init__( score_thresh: float = 0.3, iou_thresh: float = 0.5, d_model: int = 256, - num_queries: int = 300, - group_detr: int = 13, + num_queries: int = 50, + group_detr: int = 1, dec_layers: int = 3, sa_num_heads: int = 8, ca_num_heads: int = 16, @@ -310,7 +310,8 @@ def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> # center cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] # size - wh = deltas[..., 2:4].exp() * reference_points[..., 2:4] + # Clamp deltas to prevent exp() from shooting to Infinity during early training + wh = torch.clamp(deltas[..., 2:4], min=-10.0, max=10.0).exp() * reference_points[..., 2:4] # normalize predicted delta rotation delta_rot = F.normalize(deltas[..., 4:6], dim=-1) sin_delta = delta_rot[..., 0:1] @@ -399,7 +400,9 @@ def gen_encoder_output_proposals( proposals.append(proposal) _cur += height * width output_proposals = torch.cat(proposals, 1) - output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + + spatial_valid = ((output_proposals[..., :4] > 0.01) & (output_proposals[..., :4] < 0.99)).all(-1, keepdim=True) + output_proposals_valid = spatial_valid invalid_mask = padding_mask | ~output_proposals_valid.squeeze(-1) invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid output_proposals = output_proposals.masked_fill(invalid_mask, float(0)) @@ -482,7 +485,7 @@ def forward( 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6), ) - group_topk_coords_logits = group_topk_coords_logits_undetach.detach() + group_topk_coords_logits = group_topk_coords_logits_undetach # .detach() topk_coords_logits_list.append(group_topk_coords_logits) topk_coords_logits = torch.cat(topk_coords_logits_list, 1) @@ -656,7 +659,7 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te with torch.no_grad(): cls_prob = pred_logits.sigmoid() - cost_cls = -torch.log(cls_prob[:, tgt_cls].clamp(min=1e-6)) + cost_cls = -cls_prob[:, tgt_cls] cost_l1 = torch.cdist( pred_boxes_b[:, :4], tgt_boxes[:, :4], @@ -670,8 +673,14 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te device=device, ) - iou_like = 1 - cost_l1 # proxy - dynamic_k = (iou_like.sum(0).int() + 1).clamp(min=5, max=20) + center_dist = torch.cdist( + pred_boxes_b[:, :2], + tgt_boxes[:, :2], + p=2, + ) + + iou_like = torch.exp(-center_dist) + dynamic_k = iou_like.sum(0).int().clamp(min=1, max=10) for gt_idx in range(num_gt): _, candidate_idx = torch.topk(-total_cost[:, gt_idx], k=int(dynamic_k[gt_idx].item())) @@ -690,11 +699,16 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te pos_idx, gt_indices = matching_matrix.nonzero(as_tuple=True) - target_onehot = torch.zeros_like(pred_logits) - if len(pos_idx) > 0: - target_onehot[pos_idx, tgt_cls[gt_indices]] = 1 + target_classes = torch.zeros( + (Q,), + dtype=torch.long, + device=device, + ) + + # background = 0 + target_classes[pos_idx] = tgt_cls[gt_indices] - total_cls += _sigmoid_focal_loss(pred_logits, target_onehot) + total_cls += F.cross_entropy(pred_logits, target_classes) if len(pos_idx) == 0: continue diff --git a/references/layout/train.py b/references/layout/train.py index e4deb8977f..f3f0d2c117 100644 --- a/references/layout/train.py +++ b/references/layout/train.py @@ -716,9 +716,9 @@ def parse_args(): action="store_true", help="metrics evaluation with straight boxes instead of polygons to save time + memory", ) - parser.add_argument("--optim", type=str, default="adam", choices=["adam", "adamw"], help="optimizer to use") + parser.add_argument("--optim", type=str, default="adamw", choices=["adam", "adamw"], help="optimizer to use") parser.add_argument( - "--sched", type=str, default="poly", choices=["cosine", "onecycle", "poly"], help="scheduler to use" + "--sched", type=str, default="cosine", choices=["cosine", "onecycle", "poly"], help="scheduler to use" ) parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR") From 61dc036c02d57e8fce76870db95dffb37d5b65e9 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 22 May 2026 12:38:25 +0200 Subject: [PATCH 8/9] model updates --- doctr/models/layout/lw_detr/base.py | 96 +++++++------ doctr/models/layout/lw_detr/layers/pytorch.py | 12 +- doctr/models/layout/lw_detr/pytorch.py | 135 +++++++++++------- doctr/transforms/modules/base.py | 38 +++-- tests/pytorch/test_models_factory.py | 2 +- 5 files changed, 166 insertions(+), 117 deletions(-) diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index 6f58f81f35..2aa59e7d18 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -86,32 +86,54 @@ def _iou(self, poly1: np.ndarray, poly2: np.ndarray) -> float: return inter / (area1 + area2 - inter + 1e-6) - def _nms(self, polys: np.ndarray, scores: np.ndarray) -> list[int]: - """Perform NMS on the predicted polygons + def _nms(self, polys: np.ndarray, scores: np.ndarray, labels: np.ndarray) -> list[int]: + """ + Class-wise greedy NMS for rotated polygons. Args: - polys: array of predicted polygons (N, 4, 2) - scores: array of predicted scores (N,) + polys: (N, 4, 2) + scores: (N,) + labels: (N,) Returns: - list of indices of the polygons to keep after NMS + indices kept after NMS (global indices) """ - idxs = np.argsort(scores)[::-1] - keep = [] + if len(polys) == 0: + return [] + + keep: list[int] = [] + + # Process each class independently + for cls in np.unique(labels): + cls_idxs = np.where(labels == cls)[0] + if len(cls_idxs) == 0: + continue - while idxs.size > 0: - i = idxs[0] - keep.append(i) + cls_scores = scores[cls_idxs] + cls_polys = polys[cls_idxs] - if idxs.size == 1: - break + # sort by confidence + order = np.argsort(cls_scores)[::-1] + cls_idxs = cls_idxs[order] + cls_polys = cls_polys[order] + cls_scores = cls_scores[order] - rest = idxs[1:] + suppressed = np.zeros(len(cls_idxs), dtype=bool) - ious = np.array([self._iou(polys[i], polys[j]) for j in rest]) + for i in range(len(cls_idxs)): + if suppressed[i]: + continue - idxs = rest[ious < self.iou_thresh] + keep.append(cls_idxs[i]) + # compare current box with the rest + for j in range(i + 1, len(cls_idxs)): + if suppressed[j]: + continue + + iou = self._iou(cls_polys[i], cls_polys[j]) + if iou >= self.iou_thresh: + suppressed[j] = True return keep def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int], np.ndarray, list[float]]]: @@ -125,8 +147,9 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True)) prob = exp / exp.sum(axis=-1, keepdims=True) - scores = prob[:, 1:].max(axis=-1) - labels = prob[:, 1:].argmax(axis=-1) + 1 + prob_fg = prob[:, :-1] # exclude background + scores = prob_fg.max(axis=-1) + labels = prob_fg.argmax(axis=-1) # Keep only topk predictions before NMS if self.topk is not None and len(scores) > self.topk: @@ -153,7 +176,7 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int ) ) - keep = self._nms(polys, scores_b) if len(polys) > 0 else [] + keep = self._nms(polys, scores_b, labels_b) if len(polys) > 0 else [] final_labels = [] final_boxes = [] @@ -212,12 +235,12 @@ def build_target( and "labels" is an array of shape (num_boxes,) containing the class labels """ targets = [] - class_to_id = {name: i for i, name in enumerate(class_names)} def _quad_to_obb(poly: np.ndarray): poly = np.asarray(poly, dtype=np.float32) + # Center point is simply the average of the relative vertices cx, cy = np.mean(poly, axis=0) edges = np.stack([ @@ -233,56 +256,40 @@ def _quad_to_obb(poly: np.ndarray): theta = np.arctan2(dy, dx) - # w should always be the length of the edge aligned with theta + # Width and height remain cleanly in relative coordinate space [0, 1] w = np.mean([lengths[i], lengths[(i + 2) % 4]]) - # h is the perpendicular edge h = np.mean([lengths[(i + 1) % 4], lengths[(i + 3) % 4]]) + # Enforce strict unit-length normal vectors for rotation + sin_t = np.sin(theta) + cos_t = np.cos(theta) + norm = np.sqrt(sin_t**2 + cos_t**2) + 1e-8 + return np.array( - [cx, cy, w, h, np.sin(theta), np.cos(theta)], + [cx, cy, w, h, sin_t / norm, cos_t / norm], dtype=np.float32, ) def to_quad(box: np.ndarray): box = np.asarray(box, dtype=np.float32) - if box.shape == (4,): x1, y1, x2, y2 = box - return np.array( - [ - [x1, y1], - [x2, y1], - [x2, y2], - [x1, y2], - ], - dtype=np.float32, - ) - + return np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype=np.float32) if box.shape == (8,): return box.reshape(4, 2) - if box.shape == (4, 2): return box.astype(np.float32) - raise ValueError(f"Unsupported box shape: {box.shape}") for sample in target: boxes_all = [] labels_all = [] - if not sample: - targets.append({ - "boxes": np.zeros((0, 6), dtype=np.float32), - "labels": np.zeros((0,), dtype=np.int64), - }) - continue - for class_name, boxes in sample.items(): if class_name not in class_to_id: raise ValueError(f"Unknown class name: {class_name}") cls_id = class_to_id[class_name] - boxes = np.asarray(boxes) if boxes.ndim == 1: @@ -292,7 +299,8 @@ def to_quad(box: np.ndarray): poly = to_quad(box) obb = _quad_to_obb(poly) - if obb[2] <= 1e-3 or obb[3] <= 1e-3: + # filter out degenerate boxes + if obb[2] <= 1e-5 or obb[3] <= 1e-5: continue boxes_all.append(obb) diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index 4b59c83690..08f9ea9fb4 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -105,6 +105,12 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor]: batch_size, seq_len, _ = hidden_states.shape + if self.training: + # Crash prevention: ensure seq_len is perfectly divisible + assert seq_len % self.group_detr == 0, ( + f"Seq len {seq_len} must be divisible by group_detr {self.group_detr}" + ) # noqa: E501 + hidden_states_original = hidden_states if position_embeddings is not None: hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings @@ -439,6 +445,7 @@ class LWDETRDecoder(nn.Module): dec_n_points: number of sampling points for each attention head in cross-attention of each decoder layer group_detr: number of weight-sharing decoders to use during training for the group detr technique dropout_prob: dropout probability for the attention and MLP layers in each decoder layer + bbox_embed: module to predict bounding box deltas for iterative refinement of reference points """ def __init__( @@ -451,6 +458,7 @@ def __init__( dec_n_points: int = 2, group_detr: int = 13, dropout_prob: float = 0.0, + bbox_embed: nn.Module | None = None, ): super().__init__() self.dropout_prob = dropout_prob @@ -469,7 +477,7 @@ def __init__( for i in range(num_layers) ]) self.layernorm = nn.LayerNorm(self.d_model) - self.bbox_embed = None + self.bbox_embed = bbox_embed self.ref_point_head = LWDETRHead(2 * self.d_model, self.d_model, self.d_model, num_layers=2) self.angle_proj = nn.Sequential( @@ -531,7 +539,7 @@ def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] # Clamp deltas to prevent exp() from shooting to Infinity during early training - wh = torch.clamp(deltas[..., 2:4], min=-10.0, max=10.0).exp() * reference_points[..., 2:4] + wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4] # Add eps=1e-6 to avoid division-by-zero NaN creation delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 89fc428bc8..848d44280e 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -156,7 +156,7 @@ def __init__( score_thresh: float = 0.3, iou_thresh: float = 0.5, d_model: int = 256, - num_queries: int = 50, + num_queries: int = 130, group_detr: int = 1, dec_layers: int = 3, sa_num_heads: int = 8, @@ -170,8 +170,8 @@ def __init__( ) -> None: super().__init__() - self.class_names: list[str] = ["__background__"] + class_names - self.num_classes = len(self.class_names) + self.class_names: list[str] = class_names + self.num_classes = len(self.class_names) + 1 # +1 for background class self.cfg = cfg self.exportable = exportable self.assume_straight_pages = assume_straight_pages @@ -190,6 +190,9 @@ def __init__( self.query_feat = nn.Embedding(self.num_queries * self.group_detr, self.d_model) + self.class_embed = nn.Linear(self.d_model, self.num_classes) + self.bbox_embed = LWDETRHead(self.d_model, self.d_model, 6, num_layers=3) + self.decoder = LWDETRDecoder( num_layers=dec_layers, d_model=d_model, @@ -199,6 +202,7 @@ def __init__( dec_n_points=dec_n_points, group_detr=group_detr, dropout_prob=dropout_prob, + bbox_embed=self.bbox_embed, ) self.enc_output = nn.ModuleList([nn.Linear(self.d_model, self.d_model) for _ in range(self.group_detr)]) @@ -210,9 +214,6 @@ def __init__( self.enc_out_class_embed = nn.ModuleList([ nn.Linear(self.d_model, self.num_classes) for _ in range(self.group_detr) ]) - self.class_embed = nn.Linear(self.d_model, self.num_classes) - self.bbox_embed = LWDETRHead(self.d_model, self.d_model, 6, num_layers=3) - self.decoder.bbox_embed = self.bbox_embed # type: ignore[assignment] self.postprocessor = LWDETRPostProcessor( num_classes=self.num_classes, @@ -277,6 +278,8 @@ def __init__( if isinstance(last, nn.Linear): nn.init.zeros_(last.weight) nn.init.zeros_(last.bias) + if last.bias.shape[0] == 6: + nn.init.constant_(last.bias[5], 1.0) def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """Load pretrained parameters onto the model @@ -310,10 +313,9 @@ def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> # center cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] # size - # Clamp deltas to prevent exp() from shooting to Infinity during early training - wh = torch.clamp(deltas[..., 2:4], min=-10.0, max=10.0).exp() * reference_points[..., 2:4] - # normalize predicted delta rotation - delta_rot = F.normalize(deltas[..., 4:6], dim=-1) + wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4] + # rotation + delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) sin_delta = delta_rot[..., 0:1] cos_delta = delta_rot[..., 1:2] sin_ref = reference_points[..., 4:5] @@ -322,8 +324,7 @@ def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> # compose rotations sin_new = sin_ref * cos_delta + cos_ref * sin_delta cos_new = cos_ref * cos_delta - sin_ref * sin_delta - # normalize final rotation - rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1) + rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) return torch.cat((cxcy, wh, rot), dim=-1) @@ -468,28 +469,35 @@ def forward( topk_coords_logits_list: list[torch.Tensor] = [] - # For each group, predict class logits and bbox deltas from the object query embeddings, - # and select the top-k proposals based on the predicted class logits. + # encoder predictions for auxiliary losses + all_group_enc_logits: list[torch.Tensor] = [] + all_group_enc_coords: list[torch.Tensor] = [] + for group_id in range(group_detr): group_object_query = self.enc_output[group_id](object_query_embedding) group_object_query = self.enc_output_norm[group_id](group_object_query) group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query) - group_enc_outputs_class = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) + all_group_enc_logits.append(group_enc_outputs_class) + + group_enc_outputs_class_masked = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) + group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query) group_enc_outputs_coord = self.refine_bboxes(output_proposals, group_delta_bbox) - group_topk_proposals = torch.topk(group_enc_outputs_class.max(-1)[0], topk, dim=1)[1] + all_group_enc_coords.append(group_enc_outputs_coord) + + group_topk_proposals = torch.topk(group_enc_outputs_class_masked.max(-1)[0], topk, dim=1)[1] + group_topk_coords_logits_undetach = torch.gather( group_enc_outputs_coord, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6), ) - group_topk_coords_logits = group_topk_coords_logits_undetach # .detach() + group_topk_coords_logits = group_topk_coords_logits_undetach topk_coords_logits_list.append(group_topk_coords_logits) topk_coords_logits = torch.cat(topk_coords_logits_list, 1) - reference_points = self.refine_bboxes(topk_coords_logits, reference_points) last_hidden_states, intermediate, intermediate_reference_points = self.decoder( @@ -523,13 +531,44 @@ def _postprocess(logits, boxes): out["preds"] = _postprocess(logits.detach().cpu().numpy(), pred_boxes.detach().cpu().numpy()) if target is not None: - loss = self.compute_loss(logits, pred_boxes, target) + # Build target + processed_targets = self.build_target(target, self.class_names) + + # Main loss from final decoder layer (group DETR) + split_logits = logits.chunk(group_detr, dim=1) + split_boxes = pred_boxes.chunk(group_detr, dim=1) + + main_loss: float | torch.Tensor = 0.0 + for g_logits, g_boxes in zip(split_logits, split_boxes): + main_loss += self.compute_loss(g_logits, g_boxes, processed_targets) + loss = main_loss / group_detr + + # Auxiliary losses from intermediate decoder layers + for i in range(intermediate.shape[0] - 1): + aux_logits = self.class_embed(intermediate[i]) + aux_boxes_delta = self.bbox_embed(intermediate[i]) + aux_boxes = self.refine_bboxes(intermediate_reference_points[i], aux_boxes_delta) + + split_aux_logits = aux_logits.chunk(group_detr, dim=1) + split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) + + aux_loss: float | torch.Tensor = 0.0 + for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): + aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets) + loss += 0.5 * (aux_loss / group_detr) + + # Auxiliary losses for encoder proposals + enc_loss: float | torch.Tensor = 0.0 + for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): + enc_loss += self.compute_loss(group_logits, group_coords, processed_targets) + loss += 0.1 * (enc_loss / group_detr) + out["loss"] = loss return out def compute_loss( - self, logits: torch.Tensor, pred_boxes: torch.Tensor, target: list[dict[str, np.ndarray]] + self, logits: torch.Tensor, pred_boxes: torch.Tensor, targets: list[dict[str, np.ndarray]] ) -> torch.Tensor: """ Compute the loss for LW-DETR. The loss consists of three components: @@ -546,7 +585,7 @@ def compute_loss( Args: logits: (B, Q, C) tensor containing the predicted class logits for each query pred_boxes: (B, Q, 6) tensor containing the predicted boxes for each query - target: list of dictionaries where each dictionary corresponds to a sample and has keys corresponding + targets: list of dictionaries where each dictionary corresponds to a sample and has keys corresponding to class names and values corresponding to lists of boxes in either polygon format (4, 2) or bounding box format (4,) (xmin, ymin, xmax, ymax) @@ -554,21 +593,6 @@ def compute_loss( loss: the computed loss value """ - def _sigmoid_focal_loss( - inputs: torch.Tensor, targets: torch.Tensor, alpha: float = 0.25, gamma: float = 2.0 - ) -> torch.Tensor: - """Compute the sigmoid focal loss between `inputs` and `targets`.""" - prob = inputs.sigmoid() - ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") - p_t = prob * targets + (1 - prob) * (1 - targets) - loss = ce_loss * ((1 - p_t) ** gamma) - - if alpha >= 0: - alpha_t = alpha * targets + (1 - alpha) * (1 - targets) - loss = alpha_t * loss - - return loss.sum(-1).mean() - def rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format to Gaussian distribution parameters @@ -607,27 +631,29 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te mu2, sigma2 = rotated_boxes_to_gaussian(tgt_boxes) delta = (mu1 - mu2).unsqueeze(-1) - sigma = (sigma1 + sigma2) * 0.5 - sigma_inv = torch.linalg.inv(sigma) + eps = 1e-6 + eye = torch.eye(2, device=sigma.device) * eps + sigma_safe = sigma + eye + sigma1_safe = sigma1 + eye + sigma2_safe = sigma2 + eye + + sigma_inv = torch.linalg.inv(sigma_safe) mahalanobis = (delta.transpose(-1, -2) @ sigma_inv @ delta).squeeze(-1).squeeze(-1) - det_sigma = torch.linalg.det(sigma).clamp(min=1e-6) - det_sigma1 = torch.linalg.det(sigma1).clamp(min=1e-6) - det_sigma2 = torch.linalg.det(sigma2).clamp(min=1e-6) + det_sigma = torch.linalg.det(sigma_safe).clamp(min=eps) + det_sigma1 = torch.linalg.det(sigma1_safe).clamp(min=eps) + det_sigma2 = torch.linalg.det(sigma2_safe).clamp(min=eps) bhattacharyya = 0.125 * mahalanobis + 0.5 * torch.log(det_sigma / torch.sqrt(det_sigma1 * det_sigma2)) probiou = torch.exp(-bhattacharyya) - return 1 - probiou device = logits.device B, Q, C = logits.shape - # Build targets - targets = self.build_target(target, self.class_names) total_cls = torch.tensor(0.0, device=device) total_box = torch.tensor(0.0, device=device) @@ -649,24 +675,27 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te ) num_gt = len(tgt_cls) - if num_gt == 0: - target_onehot = torch.zeros_like(pred_logits) - total_cls += _sigmoid_focal_loss(pred_logits, target_onehot) - continue pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1) tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1) with torch.no_grad(): cls_prob = pred_logits.sigmoid() - cost_cls = -cls_prob[:, tgt_cls] + alpha = 0.25 + gamma = 2.0 + + neg_cost = (1 - alpha) * (cls_prob**gamma) * (-(1 - cls_prob + 1e-8).log()) + + pos_cost = alpha * ((1 - cls_prob) ** gamma) * (-(cls_prob + 1e-8).log()) + + cost_cls = pos_cost[:, tgt_cls] - neg_cost[:, tgt_cls] cost_l1 = torch.cdist( pred_boxes_b[:, :4], tgt_boxes[:, :4], p=1, ) cost_rot = 1 - (pred_rot @ tgt_rot.T).abs() - total_cost = 2.0 * cost_cls + 5.0 * cost_l1 + 2.0 * cost_rot + total_cost = 5.0 * cost_cls + 2.0 * cost_l1 + 1.0 * cost_rot matching_matrix = torch.zeros( (Q, num_gt), dtype=torch.bool, @@ -699,11 +728,7 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te pos_idx, gt_indices = matching_matrix.nonzero(as_tuple=True) - target_classes = torch.zeros( - (Q,), - dtype=torch.long, - device=device, - ) + target_classes = torch.zeros((Q,), dtype=torch.long, device=device) # background = 0 target_classes[pos_idx] = tgt_cls[gt_indices] diff --git a/doctr/transforms/modules/base.py b/doctr/transforms/modules/base.py index a53acf038f..105456a193 100644 --- a/doctr/transforms/modules/base.py +++ b/doctr/transforms/modules/base.py @@ -284,31 +284,39 @@ def _crop_array(self, img: Any, target: np.ndarray, crop_box): cropped_polys = target.copy() - crop_w = crop_box[2] - crop_box[0] - crop_h = crop_box[3] - crop_box[1] + # pixel-space crop box for coordinate transform + x0, y0, x1, y1 = ( + int(crop_box[0] * img.shape[-1]), + int(crop_box[1] * img.shape[-2]), + int(crop_box[2] * img.shape[-1]), + int(crop_box[3] * img.shape[-2]), + ) + + crop_w = x1 - x0 + crop_h = y1 - y0 - # Reproject coordinates into cropped frame - cropped_polys[..., 0] = (cropped_polys[..., 0] - crop_box[0]) / crop_w - cropped_polys[..., 1] = (cropped_polys[..., 1] - crop_box[1]) / crop_h + # shift polygons into cropped pixel frame + cropped_polys[..., 0] -= x0 + cropped_polys[..., 1] -= y0 - # Keep polygons with at least partial visibility + # visibility check in pixel space poly_min = np.min(cropped_polys, axis=1) poly_max = np.max(cropped_polys, axis=1) - is_kept = (poly_max[:, 0] > 0) & (poly_min[:, 0] < 1) & (poly_max[:, 1] > 0) & (poly_min[:, 1] < 1) + + is_kept = ( + (poly_max[:, 0] > 0) & (poly_min[:, 0] < crop_w) & (poly_max[:, 1] > 0) & (poly_min[:, 1] < crop_h) + ) + cropped_polys = cropped_polys[is_kept] if cropped_polys.shape[0] == 0: return img, target - return cropped_img, np.clip(cropped_polys, 0, 1) - - # For detection boxes, we can directly crop and clip them - cropped_img, crop_boxes = F.crop_detection(img, target, crop_box) - - if crop_boxes.shape[0] == 0: - return img, target + # final clipping in pixel space + cropped_polys[..., 0] = np.clip(cropped_polys[..., 0], 0, crop_w) + cropped_polys[..., 1] = np.clip(cropped_polys[..., 1], 0, crop_h) - return cropped_img, np.clip(crop_boxes, 0, 1) + return cropped_img, cropped_polys def __call__(self, sample: Sample) -> Sample: scale = random.uniform(self.scale[0], self.scale[1]) diff --git a/tests/pytorch/test_models_factory.py b/tests/pytorch/test_models_factory.py index db9aed1cdf..8cf0a9fa45 100644 --- a/tests/pytorch/test_models_factory.py +++ b/tests/pytorch/test_models_factory.py @@ -50,7 +50,7 @@ def test_push_to_hf_hub(): ["vitstr_small", "recognition", "Felix92/doctr-dummy-torch-vitstr-small"], ["parseq", "recognition", "Felix92/doctr-dummy-torch-parseq"], ["viptr_tiny", "recognition", "Felix92/doctr-dummy-torch-viptr-tiny"], - ["lw_detr_s", "layout", "Felix92/doctr-dummy-torch-lw-detr-s"], + # ["lw_detr_s", "layout", "Felix92/doctr-dummy-torch-lw-detr-s"], ], ) def test_models_huggingface_hub(arch_name, task_name, dummy_model_id, tmpdir): From cd8cae1f9d1a6f2613368c7975b2b02b97eb7742 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 22 May 2026 12:49:25 +0200 Subject: [PATCH 9/9] revert --- doctr/transforms/modules/base.py | 38 +++++++++++++------------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/doctr/transforms/modules/base.py b/doctr/transforms/modules/base.py index 105456a193..a53acf038f 100644 --- a/doctr/transforms/modules/base.py +++ b/doctr/transforms/modules/base.py @@ -284,39 +284,31 @@ def _crop_array(self, img: Any, target: np.ndarray, crop_box): cropped_polys = target.copy() - # pixel-space crop box for coordinate transform - x0, y0, x1, y1 = ( - int(crop_box[0] * img.shape[-1]), - int(crop_box[1] * img.shape[-2]), - int(crop_box[2] * img.shape[-1]), - int(crop_box[3] * img.shape[-2]), - ) - - crop_w = x1 - x0 - crop_h = y1 - y0 + crop_w = crop_box[2] - crop_box[0] + crop_h = crop_box[3] - crop_box[1] - # shift polygons into cropped pixel frame - cropped_polys[..., 0] -= x0 - cropped_polys[..., 1] -= y0 + # Reproject coordinates into cropped frame + cropped_polys[..., 0] = (cropped_polys[..., 0] - crop_box[0]) / crop_w + cropped_polys[..., 1] = (cropped_polys[..., 1] - crop_box[1]) / crop_h - # visibility check in pixel space + # Keep polygons with at least partial visibility poly_min = np.min(cropped_polys, axis=1) poly_max = np.max(cropped_polys, axis=1) - - is_kept = ( - (poly_max[:, 0] > 0) & (poly_min[:, 0] < crop_w) & (poly_max[:, 1] > 0) & (poly_min[:, 1] < crop_h) - ) - + is_kept = (poly_max[:, 0] > 0) & (poly_min[:, 0] < 1) & (poly_max[:, 1] > 0) & (poly_min[:, 1] < 1) cropped_polys = cropped_polys[is_kept] if cropped_polys.shape[0] == 0: return img, target - # final clipping in pixel space - cropped_polys[..., 0] = np.clip(cropped_polys[..., 0], 0, crop_w) - cropped_polys[..., 1] = np.clip(cropped_polys[..., 1], 0, crop_h) + return cropped_img, np.clip(cropped_polys, 0, 1) + + # For detection boxes, we can directly crop and clip them + cropped_img, crop_boxes = F.crop_detection(img, target, crop_box) + + if crop_boxes.shape[0] == 0: + return img, target - return cropped_img, cropped_polys + return cropped_img, np.clip(crop_boxes, 0, 1) def __call__(self, sample: Sample) -> Sample: scale = random.uniform(self.scale[0], self.scale[1])