Skip to content

Commit e475c6e

Browse files
authored
(4.3.0) Add strict type checking using pyright and improve the static types (#85)
* Replace mypy with pyright * Add a types.py file containing custom types * Remove Union in favor of |, replace pre-commit hooks with local tools * Fix an issue with np.argmax and ColorBox not implementing the __array__ protocol * Bump minor version, add changelog * Remove redundant comment
1 parent 3ede903 commit e475c6e

13 files changed

Lines changed: 278 additions & 255 deletions

File tree

.github/workflows/ci_tests.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,9 @@ jobs:
2929
key: venv-${{ hashFiles('uv.lock') }}
3030
- name: Install the project dependencies
3131
run: uv sync --extra dev
32+
- name: Install pyright
33+
run: uv tool install pyright
34+
- name: Run type checking
35+
run: uv run pyright
3236
- name: Run the automated tests (for example)
3337
run: uv run pytest -v

.pre-commit-config.yaml

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,23 @@ repos:
66
exclude: ^mkdocs\.yml$
77
- id: end-of-file-fixer
88
- id: trailing-whitespace
9-
- repo: https://github.com/astral-sh/ruff-pre-commit
10-
# Ruff version.
11-
rev: v0.5.0
9+
- repo: local
1210
hooks:
13-
# Run the linter.
14-
- id: ruff
15-
args:
16-
- --config=pyproject.toml
17-
- --fix
18-
# Run the formatter.
11+
- id: ruff-check
12+
name: ruff check
13+
entry: uv run ruff check --config=pyproject.toml --fix
14+
language: system
15+
types: [python]
16+
require_serial: false
1917
- id: ruff-format
20-
- repo: https://github.com/pre-commit/mirrors-mypy
21-
rev: 'v1.10.0' # Use the sha / tag you want to point at
22-
hooks:
23-
- id: mypy
24-
args:
25-
- --show-error-codes
26-
- --pretty
27-
- --warn-redundant-casts
28-
- --check-untyped-defs
29-
- --ignore-missing-imports
30-
- --disallow-any-generics
31-
- --disallow-subclassing-any
32-
- --no-implicit-optional
18+
name: ruff format
19+
entry: uv run ruff format
20+
language: system
21+
types: [python]
22+
require_serial: false
23+
- id: pyright
24+
name: pyright
25+
entry: uv run pyright
26+
language: system
27+
types: [python]
28+
require_serial: false

CHANGELOG.md

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010

1111
# Released
1212

13-
## 4.2.0 - 08/11/2025
13+
## 4.3.0 - 11/08/2025
14+
15+
### Added
16+
17+
• Added `types.py` file containing custom type definitions
18+
19+
### Changed
20+
21+
• Replaced mypy with pyright for type checking
22+
• Updated and improved type annotations
23+
• Replaced pre-commit hooks with local tools
24+
25+
### Fixed
26+
27+
• Fixed an issue with np.argmax and ColorBox not implementing the __array__ protocol, improving NumPy compatibility
28+
29+
## 4.2.0 - 11/08/2025
1430

1531
### Added
1632

1733
- Added support for `PIL.Image` type in `extract_colors`.
1834

19-
## 4.1.0 - 4/7/2025
35+
## 4.1.0 - 04/07/2025
2036

2137
### Added
2238

Pylette/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from Pylette.src.color import Color
2-
from Pylette.src.color_extraction import ImageType_T, extract_colors
2+
from Pylette.src.color_extraction import extract_colors
33
from Pylette.src.palette import Palette
4+
from Pylette.src.types import ImageInput
45

5-
__all__ = ["extract_colors", "Palette", "Color", "ImageType_T"]
6+
__all__ = ["extract_colors", "Palette", "Color", "ImageInput"]

Pylette/cmd.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ def main(
5252
typer.echo("Please provide either a filename or an image-url, but not both.")
5353
raise typer.Exit(code=1)
5454

55-
image: pathlib.Path | str | None
56-
if filename is not None and image_url is None:
57-
image = filename
55+
if filename is not None:
56+
image = filename # Path
5857
else:
59-
image = image_url
58+
assert image_url is not None
59+
image = image_url # str (URL)
6060

6161
output_file_path = str(out_filename) if out_filename is not None else None
6262
palette = extract_colors(

Pylette/src/color_extraction.py

Lines changed: 40 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,50 @@
1-
import os
21
import urllib.parse
3-
from enum import Enum
42
from io import BytesIO
5-
from typing import Any, Literal, TypeAlias, Union
3+
from pathlib import Path
4+
from typing import Literal
65

76
import numpy as np
8-
import requests # type: ignore
9-
from numpy.typing import NDArray
7+
import requests
108
from PIL import Image
119

1210
from Pylette.src.extractors.k_means import k_means_extraction
1311
from Pylette.src.extractors.median_cut import median_cut_extraction
1412
from Pylette.src.palette import Palette
15-
16-
ImageType_T: TypeAlias = Union["os.PathLike[Any]", bytes, NDArray[float], str, Image.Image]
17-
18-
19-
class ImageType(str, Enum):
20-
PATH = "path"
21-
BYTES = "bytes"
22-
ARRAY = "array"
23-
URL = "url"
24-
PIL = "pil"
25-
NONE = "none"
26-
27-
28-
def _parse_image_type(image: ImageType_T) -> ImageType:
29-
"""
30-
Determines the type of the input image.
31-
32-
Parameters:
33-
image (ImageType_T): The input image.
34-
35-
Returns:
36-
ImageType: The type of the input image.
37-
"""
38-
match image:
39-
case Image.Image():
40-
image_type = ImageType.PIL
41-
case np.ndarray():
42-
image_type = ImageType.ARRAY
43-
case os.PathLike():
44-
image_type = ImageType.PATH
45-
case bytes():
46-
image_type = ImageType.BYTES
47-
case str():
48-
try:
49-
result = urllib.parse.urlparse(image)
50-
if all([result.scheme, result.netloc]):
51-
image_type = ImageType.URL
52-
else:
53-
image_type = ImageType.PATH
54-
except ValueError:
55-
image_type = ImageType.PATH
56-
case _:
57-
image_type = ImageType.NONE
58-
return image_type
13+
from Pylette.src.types import ImageInput, PILImage
14+
15+
16+
def _is_url(image_str: str) -> bool:
17+
"""Check if a string is a valid URL."""
18+
try:
19+
result = urllib.parse.urlparse(image_str)
20+
return all([result.scheme, result.netloc])
21+
except Exception:
22+
return False
23+
24+
25+
def _normalize_image_input(image: ImageInput) -> PILImage:
26+
"""Convert any valid image input to PIL Image."""
27+
if isinstance(image, Image.Image):
28+
return image
29+
elif isinstance(image, (str, Path)):
30+
image_str = str(image)
31+
if _is_url(image_str):
32+
return request_image(image_str)
33+
else:
34+
return Image.open(image)
35+
elif isinstance(image, bytes):
36+
return Image.open(BytesIO(image))
37+
elif hasattr(image, "__array__"): # More general check for array-like objects
38+
return Image.fromarray(image)
39+
else:
40+
raise TypeError(f"Unsupported image type: {type(image)}")
5941

6042

6143
def extract_colors(
62-
image: ImageType_T,
44+
image: ImageInput,
6345
palette_size: int = 5,
6446
resize: bool = True,
65-
mode: Literal["KM"] | Literal["MC"] = "KM",
47+
mode: Literal["KM", "MC"] = "KM",
6648
sort_mode: Literal["luminance", "frequency"] | None = None,
6749
alpha_mask_threshold: int | None = None,
6850
) -> Palette:
@@ -87,27 +69,9 @@ def extract_colors(
8769
>>> extract_colors(b"image_bytes", palette_size=5, resize=True, mode="KM", sort_mode="luminance")
8870
"""
8971

90-
image_type = _parse_image_type(image)
91-
92-
match image_type:
93-
case ImageType.PATH:
94-
img_obj = Image.open(image)
95-
case ImageType.BYTES:
96-
assert isinstance(image, bytes)
97-
img_obj = Image.open(BytesIO(image))
98-
case ImageType.URL:
99-
assert isinstance(image, str)
100-
img_obj = request_image(image)
101-
case ImageType.ARRAY:
102-
img_obj = Image.fromarray(image)
103-
case ImageType.PIL:
104-
img_obj = image
105-
case ImageType.NONE:
106-
raise ValueError(f"Unable to parse image source. Got image type {type(image)}")
107-
108-
# Convert to RGBA
72+
# Normalize input to PIL Image and convert to RGBA
73+
img_obj = _normalize_image_input(image)
10974
img = img_obj.convert("RGBA")
110-
11175
# open the image
11276
if resize:
11377
img = img.resize((256, 256))
@@ -121,12 +85,11 @@ def extract_colors(
12185
alpha_mask = arr[:, :, 3] <= alpha_mask_threshold
12286
valid_pixels = arr[~alpha_mask]
12387

124-
if mode == "KM":
125-
colors = k_means_extraction(valid_pixels, height, width, palette_size)
126-
elif mode == "MC":
127-
colors = median_cut_extraction(valid_pixels, height, width, palette_size)
128-
else:
129-
raise NotImplementedError("Extraction mode not implemented")
88+
match mode:
89+
case "KM":
90+
colors = k_means_extraction(valid_pixels, height, width, palette_size)
91+
case "MC":
92+
colors = median_cut_extraction(valid_pixels, height, width, palette_size)
13093

13194
if sort_mode == "luminance":
13295
colors.sort(key=lambda c: c.luminance, reverse=False)

Pylette/src/extractors/median_cut.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@ def extract(self, arr: NDArray[NP_T], height: int, width: int, palette_size: int
2525
valid_pixel_count = arr.shape[0]
2626
boxes = [ColorBox(arr)]
2727
while len(boxes) < palette_size:
28-
largest_box_idx = np.argmax(boxes) # type: ignore
28+
largest_box_idx = np.argmax([box.size for box in boxes])
2929
boxes = boxes[:largest_box_idx] + boxes[largest_box_idx].split() + boxes[largest_box_idx + 1 :]
30-
3130
return [Color(tuple(map(int, box.average)), box.pixel_count / valid_pixel_count) for box in boxes]
3231

3332

Pylette/src/palette.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def to_csv(
115115
palette_file.write(",{}".format(color.freq))
116116
palette_file.write("\n")
117117

118-
def random_color(self, N, mode="frequency"):
118+
def random_color(self, N: int, mode: str = "frequency") -> list[Color]:
119119
"""
120120
Returns N random colors from the palette, either using the frequency of each color, or choosing uniformly.
121121
@@ -128,11 +128,17 @@ def random_color(self, N, mode="frequency"):
128128
"""
129129

130130
if mode == "frequency":
131-
pdf = self.frequencies
131+
# Convert to numpy-compatible format for weighted selection
132+
colors_array = np.array(range(len(self.colors)))
133+
indices = np.random.choice(colors_array, size=N, p=self.frequencies)
134+
return [self.colors[i] for i in indices]
132135
elif mode == "uniform":
133-
pdf = None
134-
135-
return np.random.choice(self.colors, size=N, p=pdf)
136+
# Uniform selection without weights
137+
colors_array = np.array(range(len(self.colors)))
138+
indices = np.random.choice(colors_array, size=N)
139+
return [self.colors[i] for i in indices]
140+
else:
141+
raise ValueError(f"Invalid mode: {mode}. Must be 'frequency' or 'uniform'.")
136142

137143
def __str__(self):
138144
return "".join(["({}, {}, {}, {}) \n".format(c.rgb[0], c.rgb[1], c.rgb[2], c.freq) for c in self.colors])

Pylette/src/types.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""
2+
Centralized type definitions for Pylette.
3+
4+
This module contains all the type aliases and protocols used throughout the Pylette library
5+
to ensure type safety and consistency.
6+
"""
7+
8+
from pathlib import Path
9+
from typing import Any, Protocol, TypeAlias
10+
11+
import numpy as np
12+
from numpy.typing import NDArray
13+
from PIL import Image
14+
15+
16+
class ImageLike(Protocol):
17+
"""Protocol for image-like objects that can be converted to PIL Image."""
18+
19+
pass
20+
21+
22+
class ArrayLike(Protocol):
23+
"""Protocol for array-like objects."""
24+
25+
def __array__(self) -> NDArray[np.uint8]: ...
26+
27+
28+
# Specific image input types
29+
PathLikeImage: TypeAlias = str | Path
30+
URLImage: TypeAlias = str # URLs are strings but semantically different
31+
BytesImage: TypeAlias = bytes
32+
ArrayImage: TypeAlias = NDArray[np.uint8] # Properly typed array
33+
PILImage: TypeAlias = Image.Image
34+
35+
# Main union type - more restrictive and logical
36+
ImageInput: TypeAlias = PathLikeImage | URLImage | BytesImage | ArrayImage | PILImage
37+
38+
# Color array types
39+
ColorArray: TypeAlias = NDArray[np.uint8] # For RGB/RGBA color data
40+
FloatArray: TypeAlias = NDArray[np.floating[Any]] # For calculations
41+
IntArray: TypeAlias = NDArray[np.integer[Any]] # For integer arrays
42+
43+
# Color tuple types
44+
RGBTuple: TypeAlias = tuple[int, int, int]
45+
RGBATuple: TypeAlias = tuple[int, int, int, int]
46+
ColorTuple: TypeAlias = RGBTuple | RGBATuple

0 commit comments

Comments
 (0)