44import numpy as np
55import imageio
66import pandas as pd
7- from matplotlib import pyplot as plt
7+ import matplotlib . pyplot as plt
88from scipy .ndimage import gaussian_filter
99from tqdm import tqdm
1010from PIL import Image , ImageFilter
1111from scipy .stats import beta
1212
1313from bootplot .backend .base import Backend , create_backend
1414from bootplot .sorting import sort_images
15- from collections import Counter
1615
16+ import jax .numpy as jnp
17+ from jax import jit , vmap , device_get
18+ from jax .scipy .special import betainc
1719
18- def plot (plot_function : callable ,
19- data : Union [np .ndarray , pd .DataFrame ],
20- indices : np .ndarray ,
21- backend : Backend ,
22- ** kwargs ):
23- if isinstance (data , pd .DataFrame ):
24- plot_function (data .iloc [indices ], data , * backend .plot_args , kwargs )
25- else :
26- plot_function (data [indices ], data , * backend .plot_args , ** kwargs )
27-
28- def symmetric_transformation_new (x ,
29- k ,
30- threshold ):
31- y = beta .cdf (x , k , k )
32- return (1 - 2 * threshold ) * y + threshold
33-
34- def adjust_relative_frequencies_opt (relative_frequencies ,
35- k ,
36- threshold ):
37- dominant_color = max (relative_frequencies , key = relative_frequencies .get )
38- transformed_dominant = symmetric_transformation_new (relative_frequencies [dominant_color ], k , threshold )
39- sum_other = 1 - relative_frequencies [dominant_color ]
40- transformed_other = 1 - transformed_dominant
41- return {
42- color : transformed_other * rel_freq / sum_other if color != dominant_color else
43- transformed_dominant
44- for color , rel_freq in relative_frequencies .items ()
45- }
46-
47- def merge_images (images : np .ndarray ,
48- k : int ,
49- threshold : int ) -> np .ndarray :
50- num_images , rows , cols , _ = images .shape
51- new_image = np .zeros ((rows , cols , 3 ), dtype = np .uint8 )
52-
53- # Iterate over each pixel location
54- for i in range (rows ):
55- for j in range (cols ):
56- # Extract the colors at the current pixel location across all images
57- pixel_colors = [tuple (images [img , i , j ]) for img in range (num_images )]
58- # Count the occurrence of each color in this list of colors
59- color_counts = Counter (pixel_colors )
60- percentages_old = {color : count / sum (color_counts .values ()) for color , count in color_counts .items ()}
61- if len (percentages_old ) > 1 :
62- percentages = adjust_relative_frequencies_opt (percentages_old , k , threshold )
63- new_color = np .sum ([np .array (c ) * p for c , p in percentages .items ()], axis = 0 )
64- new_color = np .clip (new_color , 0 , 255 ).astype (np .uint8 )
65- new_image [i , j ] = new_color
66- else :
67- new_image [i ,j ] = list (percentages_old .keys ())[0 ]
68- return new_image
20+
21+ def symmetric_transformation_new (x : float ,
22+ k : float ,
23+ threshold : float ) -> float :
24+ y = betainc (k , k , x )
25+ return (1 - 2 * threshold ) * y + threshold
26+
27+ def adjust_freqs (freqs : jnp .ndarray ,
28+ k : float ,
29+ threshold : float ) -> jnp .ndarray :
30+ dom_idx = jnp .argmax (freqs )
31+ dom = freqs [dom_idx ]
32+
33+ t_dom = symmetric_transformation_new (dom , k , threshold )
34+ sum_other = 1.0 - dom
35+ scale = (1.0 - t_dom ) / sum_other
36+
37+ out = freqs * scale
38+ return out .at [dom_idx ].set (t_dom )
39+
40+
41+ def process_pixel (pixel_stack : jnp .ndarray ,
42+ k : float ,
43+ threshold : float ) -> jnp .ndarray :
44+ mn = pixel_stack .shape [0 ]
45+
46+ r = pixel_stack [:, 0 ].astype (jnp .int32 )
47+ g = pixel_stack [:, 1 ].astype (jnp .int32 )
48+ b = pixel_stack [:, 2 ].astype (jnp .int32 )
49+
50+ idx = (r << 16 ) + (g << 8 ) + b
51+
52+ uniq , counts = jnp .unique (idx , size = mn , fill_value = 0 , return_counts = True )
53+
54+ n_unique = jnp .sum (counts > 0 )
55+
56+ ur = ((uniq >> 16 ) & 255 ).astype (jnp .float32 )
57+ ug = ((uniq >> 8 ) & 255 ).astype (jnp .float32 )
58+ ub = (uniq & 255 ).astype (jnp .float32 )
59+
60+ colors = jnp .stack ([ur , ug , ub ], axis = 1 )
61+
62+ freqs = counts .astype (jnp .float32 ) / mn
63+
64+ only_one = (n_unique == 1 )
65+ one_color = colors [0 ].astype (jnp .uint8 )
66+
67+ freqs_adj = adjust_freqs (freqs , k , threshold )
68+
69+ rgb = jnp .sum (colors * freqs_adj [:, None ], axis = 0 )
70+ rgb = jnp .clip (rgb , 0 , 255 ).astype (jnp .uint8 )
71+
72+ return jnp .where (only_one , one_color , rgb )
73+
74+
75+
76+ @jit
77+ def merge_images (images : np .ndarray ,
78+ k : float ,
79+ threshold : float ) -> jnp .ndarray :
80+ mn , rows , cols , _ = images .shape
81+
82+ pixels = images .transpose (1 , 2 , 0 , 3 )
83+
84+ #each of the rows * cols elements is a list of RGB pixels from all images at the same location:
85+ pixels = pixels .reshape (rows * cols , mn , 3 )
86+
87+ fused = vmap (process_pixel , in_axes = (0 , None , None ))(pixels , k , threshold )
88+ return fused .reshape (rows , cols , 3 )
6989
7090
7191def merge_images_original (images : np .ndarray ) -> np .ndarray :
@@ -84,6 +104,7 @@ def merge_images_original(images: np.ndarray) -> np.ndarray:
84104 return merged
85105
86106
107+
87108def decay_images (images : np .ndarray ,
88109 m : int ,
89110 decay_length : int ) -> np .ndarray :
@@ -107,12 +128,16 @@ def decay_images(images: np.ndarray,
107128 return decayed_images
108129
109130
131+
132+
133+
110134def bootplot (f : callable ,
111135 data : Union [np .ndarray , pd .DataFrame ],
112136 m : int = 100 ,
113137 k : int = 2.5 ,
114138 threshold : int = 0.3 ,
115139 output_size_px : Tuple [int , int ] = (512 , 512 ),
140+ single_sample : bool = False ,
116141 output_image_path : Union [str , Path ] = None ,
117142 transformation : bool = True ,
118143 output_animation_path : Union [str , Path ] = None ,
@@ -147,9 +172,12 @@ def bootplot(f: callable,
147172 :param threshold: input transformation parameter. Controls the codomain of the transformation. It lies between 0 and 0.5. Default: ``0,3``.
148173 :type threshold: int
149174
150- :param output_size_px: output size (height, width ) in pixels. Default: ``(512, 512)``.
175+ :param output_size_px: output size (width, heigth ) in pixels. Default: ``(512, 512)``.
151176 :type output_size_px: tuple[int, int]
152177
178+ :param single_sample: if true data_subset consists of a single sample. Default: ``False``.
179+ :type single_sample: bool
180+
153181 :param output_image_path: path where the image should be stored. The image format is inferred from the filename
154182 extension. If None, the image is not stored. Default: ``None``.
155183 :type output_image_path: str or pathlib.Path
@@ -223,28 +251,36 @@ def bootplot(f: callable,
223251 >>> image.shape
224252 (512, 512, 3)
225253 """
254+
255+
226256 if isinstance (backend , str ):
227- backend = create_backend (backend , f , data , m , output_size_px = output_size_px )
257+ backend_class = create_backend (backend , f , data , m , output_size_px = output_size_px , single_sample = single_sample )
228258
229- backend .create_figure ()
259+ backend_class .create_figure ()
230260 images = []
231261 for _ in tqdm (range (m ), desc = 'Generating plots' , disable = not verbose ):
232- backend .plot ()
233- image = backend .plot_to_array ()
262+ backend_class .plot ()
263+ image = backend_class .plot_to_array ()
234264 images .append (image )
235- backend .clear_figure ()
236- backend .close_figure ()
265+ backend_class .clear_figure ()
266+ backend_class .close_figure ()
237267 images = np .stack (images )
238268
269+
239270 if transformation :
240- merged_image = merge_images (images [..., :3 ], k , threshold )
271+ merged_image = np .array (merge_images (images [..., :3 ], k , threshold ))
272+
241273 else :
242274 merged_image = merge_images_original (images [..., :3 ])
243275
244276 if output_image_path is not None :
245277 if verbose :
246278 print (f'> Saving bootstrapped image to { output_image_path } ' )
247- Image .fromarray (merged_image ).save (output_image_path )
279+ if isinstance (backend , str ) and backend .lower () == "matplotlib" :
280+ dpi = plt .rcParams ['figure.dpi' ]
281+ Image .fromarray (merged_image ).save (output_image_path , dpi = (dpi , dpi ))
282+ else :
283+ Image .fromarray (merged_image ).save (output_image_path )
248284 if output_animation_path is not None :
249285 sort_kwargs = dict () if sort_kwargs is None else sort_kwargs
250286 order = sort_images (images , sort_type , verbose = verbose , ** sort_kwargs )
0 commit comments