1- import os
21import urllib .parse
3- from enum import Enum
42from io import BytesIO
5- from typing import Any , Literal , TypeAlias , Union
3+ from pathlib import Path
4+ from typing import Literal
65
76import numpy as np
8- import requests # type: ignore
9- from numpy .typing import NDArray
7+ import requests
108from PIL import Image
119
1210from Pylette .src .extractors .k_means import k_means_extraction
1311from Pylette .src .extractors .median_cut import median_cut_extraction
1412from Pylette .src .palette import Palette
15-
16- ImageType_T : TypeAlias = Union ["os.PathLike[Any]" , bytes , NDArray [float ], str , Image .Image ]
17-
18-
19- class ImageType (str , Enum ):
20- PATH = "path"
21- BYTES = "bytes"
22- ARRAY = "array"
23- URL = "url"
24- PIL = "pil"
25- NONE = "none"
26-
27-
28- def _parse_image_type (image : ImageType_T ) -> ImageType :
29- """
30- Determines the type of the input image.
31-
32- Parameters:
33- image (ImageType_T): The input image.
34-
35- Returns:
36- ImageType: The type of the input image.
37- """
38- match image :
39- case Image .Image ():
40- image_type = ImageType .PIL
41- case np .ndarray ():
42- image_type = ImageType .ARRAY
43- case os .PathLike ():
44- image_type = ImageType .PATH
45- case bytes ():
46- image_type = ImageType .BYTES
47- case str ():
48- try :
49- result = urllib .parse .urlparse (image )
50- if all ([result .scheme , result .netloc ]):
51- image_type = ImageType .URL
52- else :
53- image_type = ImageType .PATH
54- except ValueError :
55- image_type = ImageType .PATH
56- case _:
57- image_type = ImageType .NONE
58- return image_type
13+ from Pylette .src .types import ImageInput , PILImage
14+
15+
16+ def _is_url (image_str : str ) -> bool :
17+ """Check if a string is a valid URL."""
18+ try :
19+ result = urllib .parse .urlparse (image_str )
20+ return all ([result .scheme , result .netloc ])
21+ except Exception :
22+ return False
23+
24+
25+ def _normalize_image_input (image : ImageInput ) -> PILImage :
26+ """Convert any valid image input to PIL Image."""
27+ if isinstance (image , Image .Image ):
28+ return image
29+ elif isinstance (image , (str , Path )):
30+ image_str = str (image )
31+ if _is_url (image_str ):
32+ return request_image (image_str )
33+ else :
34+ return Image .open (image )
35+ elif isinstance (image , bytes ):
36+ return Image .open (BytesIO (image ))
37+ elif hasattr (image , "__array__" ): # More general check for array-like objects
38+ return Image .fromarray (image )
39+ else :
40+ raise TypeError (f"Unsupported image type: { type (image )} " )
5941
6042
6143def extract_colors (
62- image : ImageType_T ,
44+ image : ImageInput ,
6345 palette_size : int = 5 ,
6446 resize : bool = True ,
65- mode : Literal ["KM" ] | Literal [ "MC" ] = "KM" ,
47+ mode : Literal ["KM" , "MC" ] = "KM" ,
6648 sort_mode : Literal ["luminance" , "frequency" ] | None = None ,
6749 alpha_mask_threshold : int | None = None ,
6850) -> Palette :
@@ -87,27 +69,9 @@ def extract_colors(
8769 >>> extract_colors(b"image_bytes", palette_size=5, resize=True, mode="KM", sort_mode="luminance")
8870 """
8971
90- image_type = _parse_image_type (image )
91-
92- match image_type :
93- case ImageType .PATH :
94- img_obj = Image .open (image )
95- case ImageType .BYTES :
96- assert isinstance (image , bytes )
97- img_obj = Image .open (BytesIO (image ))
98- case ImageType .URL :
99- assert isinstance (image , str )
100- img_obj = request_image (image )
101- case ImageType .ARRAY :
102- img_obj = Image .fromarray (image )
103- case ImageType .PIL :
104- img_obj = image
105- case ImageType .NONE :
106- raise ValueError (f"Unable to parse image source. Got image type { type (image )} " )
107-
108- # Convert to RGBA
72+ # Normalize input to PIL Image and convert to RGBA
73+ img_obj = _normalize_image_input (image )
10974 img = img_obj .convert ("RGBA" )
110-
11175 # open the image
11276 if resize :
11377 img = img .resize ((256 , 256 ))
@@ -121,12 +85,11 @@ def extract_colors(
12185 alpha_mask = arr [:, :, 3 ] <= alpha_mask_threshold
12286 valid_pixels = arr [~ alpha_mask ]
12387
124- if mode == "KM" :
125- colors = k_means_extraction (valid_pixels , height , width , palette_size )
126- elif mode == "MC" :
127- colors = median_cut_extraction (valid_pixels , height , width , palette_size )
128- else :
129- raise NotImplementedError ("Extraction mode not implemented" )
88+ match mode :
89+ case "KM" :
90+ colors = k_means_extraction (valid_pixels , height , width , palette_size )
91+ case "MC" :
92+ colors = median_cut_extraction (valid_pixels , height , width , palette_size )
13093
13194 if sort_mode == "luminance" :
13295 colors .sort (key = lambda c : c .luminance , reverse = False )
0 commit comments