Skip to content

Commit 332af10

Browse files
Merge pull request #14 from davidnabergoj/update
parallelization, new single_sample param, fix small bugs
2 parents a611e9c + 24fba7b commit 332af10

6 files changed

Lines changed: 135 additions & 80 deletions

File tree

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
- name: Set up Python
1919
uses: actions/setup-python@v2
2020
with:
21-
python-version: "3.8"
21+
python-version: "3.12"
2222

2323
- name: Build source and wheel distributions
2424
run: |

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
- name: Set up Python
2222
uses: actions/setup-python@v2
2323
with:
24-
python-version: "3.8"
24+
python-version: "3.12"
2525

2626
- name: Build source and wheel distributions
2727
run: |

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
strategy:
1515
fail-fast: false
1616
matrix:
17-
python-version: ["3.8", "3.9", "3.10"]
17+
python-version: ["3.11", "3.12"]
1818

1919
steps:
2020
- uses: actions/checkout@v3

bootplot/backend/base.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,38 @@ def __init__(self,
1414
f: callable,
1515
data: Union[np.ndarray, pd.DataFrame],
1616
m: int,
17-
output_size_px: Tuple[int, int]):
17+
output_size_px: Tuple[int, int],
18+
single_sample: bool):
1819
self.output_size_px = output_size_px
1920
self.f = f
2021
self.data = data
2122
self.m = m
23+
self.single_sample = single_sample
2224

2325
@abstractmethod
2426
def create_figure(self):
2527
raise NotImplemented
2628

2729
def plot(self):
28-
indices = np.random.randint(low=0, high=len(self.data), size=len(self.data))
30+
size = 1 if self.single_sample else len(self.data)
31+
indices = np.random.randint(low=0, high=len(self.data), size=size)
2932
if isinstance(self.data, pd.DataFrame):
3033
return self.f(self.data.iloc[indices], self.data, *self.plot_args)
3134
elif isinstance(self.data, np.ndarray):
3235
return self.f(self.data[indices], self.data, *self.plot_args)
3336

37+
def sample_all_indices(self):
38+
size = 1 if self.single_sample else len(self.data)
39+
return np.random.randint(0, len(self.data), size=(self.m, size))
40+
41+
42+
def plot_from_indices(self, indices):
43+
if isinstance(self.data, pd.DataFrame):
44+
subset = self.data.iloc[indices]
45+
else:
46+
subset = self.data[indices]
47+
return self.f(subset, self.data, *self.plot_args)
48+
3449
@abstractmethod
3550
def plot_to_array(self) -> np.ndarray:
3651
raise NotImplemented
@@ -50,8 +65,8 @@ def plot_args(self):
5065

5166

5267
class Basic(Backend):
53-
def __init__(self, f: callable, data: Union[np.ndarray, pd.DataFrame], m: int, output_size_px: Tuple[int, int]):
54-
super().__init__(f, data, m, output_size_px)
68+
def __init__(self, f: callable, data: Union[np.ndarray, pd.DataFrame], m: int, output_size_px: Tuple[int, int], single_sample: bool):
69+
super().__init__(f, data, m, output_size_px, single_sample)
5570
self.cached_image = None
5671

5772
def plot(self):
@@ -79,10 +94,11 @@ def __init__(self,
7994
f: callable,
8095
data: Union[np.ndarray, pd.DataFrame],
8196
m: int,
82-
output_size_px: Tuple[int, int] = (512, 512)):
97+
output_size_px: Tuple[int, int] = (512, 512),
98+
single_sample: bool = False):
8399
self.fig = None
84100
self.ax = None
85-
super().__init__(f, data, m, output_size_px)
101+
super().__init__(f, data, m, output_size_px, single_sample)
86102

87103
def create_figure(self):
88104
self.fig, self.ax = bootplot.backend.matplotlib.create_figure(self.output_size_px)
@@ -106,8 +122,9 @@ def __init__(self,
106122
f: callable,
107123
data: Union[np.ndarray, pd.DataFrame],
108124
m: int,
109-
output_size_px: Tuple[int, int] = (512, 512)):
110-
super().__init__(f, data, m, output_size_px)
125+
output_size_px: Tuple[int, int] = (512, 512),
126+
single_sample: bool = False):
127+
super().__init__(f, data, m, output_size_px, single_sample)
111128

112129
def create_figure(self):
113130
raise NotImplemented
@@ -131,7 +148,8 @@ def __init__(self,
131148
f: callable,
132149
data: Union[np.ndarray, pd.DataFrame],
133150
m: int,
134-
output_size_px: Tuple[int, int] = (512, 512)):
151+
output_size_px: Tuple[int, int] = (512, 512),
152+
single_sample: bool = False):
135153
super().__init__(f, data, m, output_size_px)
136154

137155
def create_figure(self):

bootplot/base.py

Lines changed: 98 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -4,68 +4,88 @@
44
import numpy as np
55
import imageio
66
import pandas as pd
7-
from matplotlib import pyplot as plt
7+
import matplotlib.pyplot as plt
88
from scipy.ndimage import gaussian_filter
99
from tqdm import tqdm
1010
from PIL import Image, ImageFilter
1111
from scipy.stats import beta
1212

1313
from bootplot.backend.base import Backend, create_backend
1414
from bootplot.sorting import sort_images
15-
from collections import Counter
1615

16+
import jax.numpy as jnp
17+
from jax import jit, vmap, device_get
18+
from jax.scipy.special import betainc
1719

18-
def plot(plot_function: callable,
19-
data: Union[np.ndarray, pd.DataFrame],
20-
indices: np.ndarray,
21-
backend: Backend,
22-
**kwargs):
23-
if isinstance(data, pd.DataFrame):
24-
plot_function(data.iloc[indices], data, *backend.plot_args, kwargs)
25-
else:
26-
plot_function(data[indices], data, *backend.plot_args, **kwargs)
27-
28-
def symmetric_transformation_new(x,
29-
k,
30-
threshold):
31-
y = beta.cdf(x, k, k)
32-
return (1-2*threshold) * y + threshold
33-
34-
def adjust_relative_frequencies_opt(relative_frequencies,
35-
k,
36-
threshold):
37-
dominant_color = max(relative_frequencies, key=relative_frequencies.get)
38-
transformed_dominant = symmetric_transformation_new(relative_frequencies[dominant_color], k, threshold)
39-
sum_other = 1-relative_frequencies[dominant_color]
40-
transformed_other = 1-transformed_dominant
41-
return {
42-
color: transformed_other * rel_freq / sum_other if color != dominant_color else
43-
transformed_dominant
44-
for color, rel_freq in relative_frequencies.items()
45-
}
46-
47-
def merge_images(images: np.ndarray,
48-
k: int,
49-
threshold: int) -> np.ndarray:
50-
num_images, rows, cols, _ = images.shape
51-
new_image = np.zeros((rows, cols, 3), dtype=np.uint8)
52-
53-
# Iterate over each pixel location
54-
for i in range(rows):
55-
for j in range(cols):
56-
# Extract the colors at the current pixel location across all images
57-
pixel_colors = [tuple(images[img, i, j]) for img in range(num_images)]
58-
# Count the occurrence of each color in this list of colors
59-
color_counts = Counter(pixel_colors)
60-
percentages_old = {color: count / sum(color_counts.values()) for color, count in color_counts.items()}
61-
if len(percentages_old) > 1:
62-
percentages = adjust_relative_frequencies_opt(percentages_old, k, threshold)
63-
new_color = np.sum([np.array(c) * p for c, p in percentages.items()], axis=0)
64-
new_color = np.clip(new_color, 0, 255).astype(np.uint8)
65-
new_image[i, j] = new_color
66-
else:
67-
new_image[i,j] = list(percentages_old.keys())[0]
68-
return new_image
20+
21+
def symmetric_transformation_new(x: float,
22+
k: float,
23+
threshold: float) -> float:
24+
y = betainc(k, k, x)
25+
return (1 - 2 * threshold) * y + threshold
26+
27+
def adjust_freqs(freqs: jnp.ndarray,
28+
k: float,
29+
threshold: float) -> jnp.ndarray:
30+
dom_idx = jnp.argmax(freqs)
31+
dom = freqs[dom_idx]
32+
33+
t_dom = symmetric_transformation_new(dom, k, threshold)
34+
sum_other = 1.0 - dom
35+
scale = (1.0 - t_dom) / sum_other
36+
37+
out = freqs * scale
38+
return out.at[dom_idx].set(t_dom)
39+
40+
41+
def process_pixel(pixel_stack: jnp.ndarray,
42+
k: float,
43+
threshold: float) -> jnp.ndarray:
44+
mn = pixel_stack.shape[0]
45+
46+
r = pixel_stack[:, 0].astype(jnp.int32)
47+
g = pixel_stack[:, 1].astype(jnp.int32)
48+
b = pixel_stack[:, 2].astype(jnp.int32)
49+
50+
idx = (r << 16) + (g << 8) + b
51+
52+
uniq, counts = jnp.unique(idx, size=mn, fill_value=0, return_counts=True)
53+
54+
n_unique = jnp.sum(counts > 0)
55+
56+
ur = ((uniq >> 16) & 255).astype(jnp.float32)
57+
ug = ((uniq >> 8) & 255).astype(jnp.float32)
58+
ub = (uniq & 255).astype(jnp.float32)
59+
60+
colors = jnp.stack([ur, ug, ub], axis=1)
61+
62+
freqs = counts.astype(jnp.float32) / mn
63+
64+
only_one = (n_unique == 1)
65+
one_color = colors[0].astype(jnp.uint8)
66+
67+
freqs_adj = adjust_freqs(freqs, k, threshold)
68+
69+
rgb = jnp.sum(colors * freqs_adj[:, None], axis=0)
70+
rgb = jnp.clip(rgb, 0, 255).astype(jnp.uint8)
71+
72+
return jnp.where(only_one, one_color, rgb)
73+
74+
75+
76+
@jit
77+
def merge_images(images: np.ndarray,
78+
k: float,
79+
threshold: float) -> jnp.ndarray:
80+
mn, rows, cols, _ = images.shape
81+
82+
pixels = images.transpose(1, 2, 0, 3)
83+
84+
#each of the rows * cols elements is a list of RGB pixels from all images at the same location:
85+
pixels = pixels.reshape(rows * cols, mn, 3)
86+
87+
fused = vmap(process_pixel, in_axes=(0, None, None))(pixels, k, threshold)
88+
return fused.reshape(rows, cols, 3)
6989

7090

7191
def merge_images_original(images: np.ndarray) -> np.ndarray:
@@ -84,6 +104,7 @@ def merge_images_original(images: np.ndarray) -> np.ndarray:
84104
return merged
85105

86106

107+
87108
def decay_images(images: np.ndarray,
88109
m: int,
89110
decay_length: int) -> np.ndarray:
@@ -107,12 +128,16 @@ def decay_images(images: np.ndarray,
107128
return decayed_images
108129

109130

131+
132+
133+
110134
def bootplot(f: callable,
111135
data: Union[np.ndarray, pd.DataFrame],
112136
m: int = 100,
113137
k: int = 2.5,
114138
threshold: int = 0.3,
115139
output_size_px: Tuple[int, int] = (512, 512),
140+
single_sample: bool = False,
116141
output_image_path: Union[str, Path] = None,
117142
transformation: bool = True,
118143
output_animation_path: Union[str, Path] = None,
@@ -147,9 +172,12 @@ def bootplot(f: callable,
147172
:param threshold: input transformation parameter. Controls the codomain of the transformation. It lies between 0 and 0.5. Default: ``0,3``.
148173
:type threshold: int
149174
150-
:param output_size_px: output size (height, width) in pixels. Default: ``(512, 512)``.
175+
:param output_size_px: output size (width, heigth) in pixels. Default: ``(512, 512)``.
151176
:type output_size_px: tuple[int, int]
152177
178+
:param single_sample: if true data_subset consists of a single sample. Default: ``False``.
179+
:type single_sample: bool
180+
153181
:param output_image_path: path where the image should be stored. The image format is inferred from the filename
154182
extension. If None, the image is not stored. Default: ``None``.
155183
:type output_image_path: str or pathlib.Path
@@ -223,28 +251,36 @@ def bootplot(f: callable,
223251
>>> image.shape
224252
(512, 512, 3)
225253
"""
254+
255+
226256
if isinstance(backend, str):
227-
backend = create_backend(backend, f, data, m, output_size_px=output_size_px)
257+
backend_class = create_backend(backend, f, data, m, output_size_px=output_size_px, single_sample=single_sample)
228258

229-
backend.create_figure()
259+
backend_class.create_figure()
230260
images = []
231261
for _ in tqdm(range(m), desc='Generating plots', disable=not verbose):
232-
backend.plot()
233-
image = backend.plot_to_array()
262+
backend_class.plot()
263+
image = backend_class.plot_to_array()
234264
images.append(image)
235-
backend.clear_figure()
236-
backend.close_figure()
265+
backend_class.clear_figure()
266+
backend_class.close_figure()
237267
images = np.stack(images)
238268

269+
239270
if transformation:
240-
merged_image = merge_images(images[..., :3], k, threshold)
271+
merged_image = np.array(merge_images(images[..., :3], k, threshold))
272+
241273
else:
242274
merged_image = merge_images_original(images[..., :3])
243275

244276
if output_image_path is not None:
245277
if verbose:
246278
print(f'> Saving bootstrapped image to {output_image_path}')
247-
Image.fromarray(merged_image).save(output_image_path)
279+
if isinstance(backend, str) and backend.lower() == "matplotlib":
280+
dpi = plt.rcParams['figure.dpi']
281+
Image.fromarray(merged_image).save(output_image_path, dpi=(dpi, dpi))
282+
else:
283+
Image.fromarray(merged_image).save(output_image_path)
248284
if output_animation_path is not None:
249285
sort_kwargs = dict() if sort_kwargs is None else sort_kwargs
250286
order = sort_images(images, sort_type, verbose=verbose, **sort_kwargs)

requirements.txt

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
numpy>=1.3,<2
2-
imageio~=2.9.0
3-
imageio-ffmpeg==0.4.7
1+
numpy>=1.3
2+
imageio>=2.9.0
3+
imageio-ffmpeg>=0.4.7
44
matplotlib>=3
5-
tqdm~=4.64.0
5+
tqdm>=4.64.0
66
pillow>=8
77
scipy>=1.5
88
scikit-image>=0.17
99
networkx>=2.7.1
1010
scikit-learn>=0.24
11-
opencv-python~=4.5.5
12-
pandas~=1.4.3
11+
opencv-python>=4.5.5
12+
pandas>=1.4.3
13+
jax>=0.8.1

0 commit comments

Comments
 (0)