Skip to content

Commit 6d9c588

Browse files
jamjamjon
authored and committed
update
1 parent 18648c8 commit 6d9c588

5 files changed

Lines changed: 595 additions & 247 deletions

File tree

benches/bench_standardize.rs

Lines changed: 149 additions & 159 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
use criterion::{criterion_group, criterion_main, Criterion};
33
use ndarray::{Array, ArrayView1, IxDyn, Zip};
44
use slsl::Tensor;
5+
use std::collections::HashSet;
56
use std::hint::black_box;
67

78
// ndarray implementation from user
@@ -53,7 +54,7 @@ pub fn standardize_ndarray(
5354
Zip::from(x)
5455
.and(mean_broadcast)
5556
.and(std_broadcast)
56-
.par_for_each(|x_val, &mean_val, &std_val| {
57+
.for_each(|x_val, &mean_val, &std_val| {
5758
*x_val = (*x_val - mean_val) / std_val;
5859
});
5960

@@ -62,7 +63,7 @@ pub fn standardize_ndarray(
6263

6364
// Test configurations for 3D tensors
6465
struct TestConfig3D {
65-
name: &'static str,
66+
name: String,
6667
shape: [usize; 3],
6768
dim: usize,
6869
mean: [f32; 3],
@@ -71,210 +72,201 @@ struct TestConfig3D {
7172

7273
// Test configurations for 4D tensors
7374
struct TestConfig4D {
74-
name: &'static str,
75+
name: String,
7576
shape: [usize; 4],
7677
dim: usize,
7778
mean: [f32; 3],
7879
std: [f32; 3],
7980
}
8081

81-
const TEST_CONFIGS_3D: &[TestConfig3D] = &[
82-
// 224x224x3 configurations
83-
TestConfig3D {
84-
name: "224x224x3_hwc_zeros",
85-
shape: [224, 224, 3],
86-
dim: 2,
87-
mean: [0.0, 0.0, 0.0],
88-
std: [1.0, 1.0, 1.0],
89-
},
90-
TestConfig3D {
91-
name: "224x224x3_hwc_half",
92-
shape: [224, 224, 3],
93-
dim: 2,
94-
mean: [0.5, 0.5, 0.5],
95-
std: [0.5, 0.5, 0.5],
96-
},
97-
TestConfig3D {
98-
name: "224x224x3_hwc_imagenet",
99-
shape: [224, 224, 3],
100-
dim: 2,
101-
mean: [0.48145466, 0.4578275, 0.40821073],
102-
std: [0.26862954, 0.261_302_6, 0.275_777_1],
103-
},
104-
// 3x224x224 configurations
105-
TestConfig3D {
106-
name: "3x224x224_chw_zeros",
107-
shape: [3, 224, 224],
108-
dim: 0,
109-
mean: [0.0, 0.0, 0.0],
110-
std: [1.0, 1.0, 1.0],
111-
},
112-
TestConfig3D {
113-
name: "3x224x224_chw_half",
114-
shape: [3, 224, 224],
115-
dim: 0,
116-
mean: [0.5, 0.5, 0.5],
117-
std: [0.5, 0.5, 0.5],
118-
},
119-
TestConfig3D {
120-
name: "3x224x224_chw_imagenet",
121-
shape: [3, 224, 224],
122-
dim: 0,
123-
mean: [0.48145466, 0.4578275, 0.40821073],
124-
std: [0.26862954, 0.261_302_6, 0.275_777_1],
125-
},
126-
// Other sizes
127-
TestConfig3D {
128-
name: "256x256x3_hwc_zeros",
129-
shape: [256, 256, 3],
130-
dim: 2,
131-
mean: [0.0, 0.0, 0.0],
132-
std: [1.0, 1.0, 1.0],
133-
},
134-
TestConfig3D {
135-
name: "3x256x256_chw_zeros",
136-
shape: [3, 256, 256],
137-
dim: 0,
138-
mean: [0.0, 0.0, 0.0],
139-
std: [1.0, 1.0, 1.0],
140-
},
141-
TestConfig3D {
142-
name: "512x512x3_hwc_zeros",
143-
shape: [512, 512, 3],
144-
dim: 2,
145-
mean: [0.0, 0.0, 0.0],
146-
std: [1.0, 1.0, 1.0],
147-
},
148-
TestConfig3D {
149-
name: "3x512x512_chw_zeros",
150-
shape: [3, 512, 512],
151-
dim: 0,
152-
mean: [0.0, 0.0, 0.0],
153-
std: [1.0, 1.0, 1.0],
154-
},
155-
TestConfig3D {
156-
name: "1024x1024x3_hwc_zeros",
157-
shape: [1024, 1024, 3],
158-
dim: 2,
159-
mean: [0.0, 0.0, 0.0],
160-
std: [1.0, 1.0, 1.0],
161-
},
162-
TestConfig3D {
163-
name: "3x1024x1024_chw_zeros",
164-
shape: [3, 1024, 1024],
165-
dim: 0,
166-
mean: [0.0, 0.0, 0.0],
167-
std: [1.0, 1.0, 1.0],
168-
},
169-
];
82+
fn gen_3d_configs() -> Vec<TestConfig3D> {
83+
let means = [
84+
("zeros", [0.0f32, 0.0, 0.0], [1.0f32, 1.0, 1.0]),
85+
("half", [0.5f32, 0.5, 0.5], [0.5f32, 0.5, 0.5]),
86+
(
87+
"imagenet",
88+
[0.48145466, 0.4578275, 0.40821073],
89+
[0.26862954, 0.261_302_6, 0.275_777_1],
90+
),
91+
];
17092

171-
const TEST_CONFIGS_4D: &[TestConfig4D] = &[
172-
// 4D NCHW format (batch, channels, height, width) - typical for PyTorch/vision models
173-
TestConfig4D {
174-
name: "1x3x224x224_nchw_imagenet",
175-
shape: [1, 3, 224, 224],
176-
dim: 1,
177-
mean: [0.48145466, 0.4578275, 0.40821073],
178-
std: [0.26862954, 0.2613026, 0.2757771],
179-
},
180-
TestConfig4D {
181-
name: "1x3x640x640_nchw_imagenet",
182-
shape: [1, 3, 640, 640],
183-
dim: 1,
184-
mean: [0.48145466, 0.4578275, 0.40821073],
185-
std: [0.26862954, 0.2613026, 0.2757771],
186-
},
187-
];
93+
let sizes = [(224usize, 224usize), (640, 640), (1024, 1024)];
94+
let mut v = Vec::new();
95+
for (h, w) in sizes {
96+
// Generate three permutations where channel size 3 is placed at each axis
97+
let shapes = [
98+
([3, h, w], 0, format!("3x{}x{}", h, w)),
99+
([h, 3, w], 1, format!("{}x3x{}", h, w)),
100+
([h, w, 3], 2, format!("{}x{}x3", h, w)),
101+
];
102+
for (shape, dim, tag) in shapes {
103+
for (case, mean, std) in means.iter() {
104+
v.push(TestConfig3D {
105+
name: format!("{}_{case}", tag),
106+
shape,
107+
dim,
108+
mean: *mean,
109+
std: *std,
110+
});
111+
}
112+
}
113+
}
114+
v
115+
}
188116

189-
fn bench_slsl_standardize_3d(c: &mut Criterion) {
190-
let mut group = c.benchmark_group("standardize_slsl_3d");
117+
fn gen_4d_configs() -> Vec<TestConfig4D> {
118+
// Start from 3D base (H,W,3) sizes and insert 10 at all possible positions,
119+
// then also permute channel (3) to different axes by using the three 3D permutations and inserting 10.
120+
let sizes = [(224usize, 224usize), (640, 640), (1024, 1024)];
121+
let mut v = Vec::new();
122+
let mut seen: HashSet<(usize, usize, usize, usize, usize)> = HashSet::new();
123+
let mean = [0.48145466, 0.4578275, 0.40821073];
124+
let std = [0.26862954, 0.2613026, 0.2757771];
191125

192-
for config in TEST_CONFIGS_3D {
193-
group.bench_function(config.name, |b| {
194-
b.iter(|| {
195-
// Create test data
196-
let data: Vec<f32> = (0..config.shape.iter().product::<usize>())
197-
.map(|i| (i as f32) * 0.01)
198-
.collect();
199-
let tensor = Tensor::from_vec(data, config.shape).unwrap();
126+
for (h, w) in sizes {
127+
// Base 3D permutations
128+
let bases = [
129+
([3, h, w], 0, format!("3x{}x{}", h, w)),
130+
([h, 3, w], 1, format!("{}x3x{}", h, w)),
131+
([h, w, 3], 2, format!("{}x{}x3", h, w)),
132+
];
200133

201-
// Perform standardization
202-
black_box(
203-
tensor
204-
.standardize(&config.mean, &config.std, config.dim)
205-
.unwrap(),
206-
)
207-
});
208-
});
209-
}
134+
for (base, ch_dim, tag3d) in bases {
135+
// Insert 10 at all four positions
136+
let shapes_4d = [
137+
(
138+
[10, base[0], base[1], base[2]],
139+
ch_dim + 1,
140+
format!("10x{}", tag3d),
141+
),
142+
(
143+
[base[0], 10, base[1], base[2]],
144+
ch_dim,
145+
format!(
146+
"{}x10x{}",
147+
if ch_dim == 0 {
148+
format!("{}x{}", base[1], base[2])
149+
} else {
150+
format!("{}", base[0])
151+
},
152+
if ch_dim == 0 {
153+
format!("{}", base[0])
154+
} else {
155+
format!("{}x{}", base[1], base[2])
156+
}
157+
),
158+
),
159+
(
160+
[base[0], base[1], 10, base[2]],
161+
ch_dim,
162+
format!("{}x{}x10x{}", base[0], base[1], base[2]),
163+
),
164+
(
165+
[base[0], base[1], base[2], 10],
166+
ch_dim,
167+
format!("{}x{}x{}x10", base[0], base[1], base[2]),
168+
),
169+
];
210170

211-
group.finish();
171+
for (shape4, dim4, _tag4) in shapes_4d {
172+
// Derive channel dim by locating the index of value 3
173+
let mut channel_dim = None;
174+
for (i, &d) in shape4.iter().enumerate() {
175+
if d == 3 {
176+
channel_dim = Some(i);
177+
break;
178+
}
179+
}
180+
let dim = channel_dim.unwrap_or(dim4);
181+
let key = (shape4[0], shape4[1], shape4[2], shape4[3], dim);
182+
if seen.insert(key) {
183+
let name = format!(
184+
"{}x{}x{}x{}:dim{}",
185+
shape4[0], shape4[1], shape4[2], shape4[3], dim
186+
);
187+
v.push(TestConfig4D {
188+
name,
189+
shape: shape4,
190+
dim,
191+
mean,
192+
std,
193+
});
194+
}
195+
}
196+
}
197+
}
198+
v
212199
}
213200

214-
fn bench_slsl_standardize_4d(c: &mut Criterion) {
215-
let mut group = c.benchmark_group("standardize_slsl_4d");
201+
fn bench_compare_standardize_3d(c: &mut Criterion) {
202+
let mut group = c.benchmark_group("standardize_3d_compare");
203+
let configs = gen_3d_configs();
216204

217-
for config in TEST_CONFIGS_4D {
218-
group.bench_function(config.name, |b| {
205+
for config in configs.iter() {
206+
// slsl
207+
group.bench_function(format!("slsl/{}", config.name), |b| {
219208
b.iter(|| {
220-
// Create test data
221209
let data: Vec<f32> = (0..config.shape.iter().product::<usize>())
222210
.map(|i| (i as f32) * 0.01)
223211
.collect();
224212
let tensor = Tensor::from_vec(data, config.shape).unwrap();
225-
226-
// Perform standardization
227213
black_box(
228214
tensor
229215
.standardize(&config.mean, &config.std, config.dim)
230216
.unwrap(),
231217
)
232218
});
233219
});
234-
}
235-
236-
group.finish();
237-
}
238-
239-
fn bench_ndarray_standardize_3d(c: &mut Criterion) {
240-
let mut group = c.benchmark_group("standardize_ndarray_3d");
241220

242-
for config in TEST_CONFIGS_3D {
243-
group.bench_function(config.name, |b| {
221+
// ndarray
222+
group.bench_function(format!("ndarray/{}", config.name), |b| {
244223
b.iter(|| {
245-
// Create test data
246224
let data: Vec<f32> = (0..config.shape.iter().product::<usize>())
247225
.map(|i| (i as f32) * 0.01)
248226
.collect();
249227
let mut array = Array::from_shape_vec(IxDyn(&config.shape), data).unwrap();
250228
let mean = ArrayView1::from(&config.mean);
251229
let std = ArrayView1::from(&config.std);
252230
standardize_ndarray(&mut array, mean, std, config.dim).unwrap();
253-
black_box(());
254-
array
231+
black_box(&array);
255232
});
256233
});
257234
}
258235

259236
group.finish();
260237
}
261238

262-
fn bench_ndarray_standardize_4d(c: &mut Criterion) {
263-
let mut group = c.benchmark_group("standardize_ndarray_4d");
239+
fn bench_compare_standardize_4d(c: &mut Criterion) {
240+
let mut group = c.benchmark_group("standardize_4d_compare");
241+
let configs = gen_4d_configs();
242+
243+
for config in configs.iter() {
244+
// slsl
245+
group.bench_function(format!("slsl/{}", config.name), |b| {
246+
b.iter(|| {
247+
let data: Vec<f32> = (0..config.shape.iter().product::<usize>())
248+
.map(|i| (i as f32) * 0.01)
249+
.collect();
250+
let tensor = Tensor::from_vec(data, config.shape).unwrap();
251+
black_box(
252+
tensor
253+
.standardize(&config.mean, &config.std, config.dim)
254+
.unwrap(),
255+
)
256+
});
257+
});
264258

265-
for config in TEST_CONFIGS_4D {
266-
group.bench_function(config.name, |b| {
259+
// ndarray
260+
group.bench_function(format!("ndarray/{}", config.name), |b| {
267261
b.iter(|| {
268-
// Create test data
269262
let data: Vec<f32> = (0..config.shape.iter().product::<usize>())
270263
.map(|i| (i as f32) * 0.01)
271264
.collect();
272265
let mut array = Array::from_shape_vec(IxDyn(&config.shape), data).unwrap();
273266
let mean = ArrayView1::from(&config.mean);
274267
let std = ArrayView1::from(&config.std);
275268
standardize_ndarray(&mut array, mean, std, config.dim).unwrap();
276-
black_box(());
277-
array
269+
black_box(&array);
278270
});
279271
});
280272
}
@@ -284,9 +276,7 @@ fn bench_ndarray_standardize_4d(c: &mut Criterion) {
284276

285277
criterion_group!(
286278
benches,
287-
bench_slsl_standardize_3d,
288-
bench_slsl_standardize_4d,
289-
bench_ndarray_standardize_3d,
290-
bench_ndarray_standardize_4d
279+
bench_compare_standardize_3d,
280+
bench_compare_standardize_4d
291281
);
292282
criterion_main!(benches);

0 commit comments

Comments
 (0)