22use criterion:: { criterion_group, criterion_main, Criterion } ;
33use ndarray:: { Array , ArrayView1 , IxDyn , Zip } ;
44use slsl:: Tensor ;
5+ use std:: collections:: HashSet ;
56use std:: hint:: black_box;
67
78// ndarray implementation from user
@@ -53,7 +54,7 @@ pub fn standardize_ndarray(
5354 Zip :: from ( x)
5455 . and ( mean_broadcast)
5556 . and ( std_broadcast)
56- . par_for_each ( |x_val, & mean_val, & std_val| {
57+ . for_each ( |x_val, & mean_val, & std_val| {
5758 * x_val = ( * x_val - mean_val) / std_val;
5859 } ) ;
5960
@@ -62,7 +63,7 @@ pub fn standardize_ndarray(
6263
6364// Test configurations for 3D tensors
6465struct TestConfig3D {
65- name : & ' static str ,
66+ name : String ,
6667 shape : [ usize ; 3 ] ,
6768 dim : usize ,
6869 mean : [ f32 ; 3 ] ,
@@ -71,210 +72,201 @@ struct TestConfig3D {
7172
7273// Test configurations for 4D tensors
7374struct TestConfig4D {
74- name : & ' static str ,
75+ name : String ,
7576 shape : [ usize ; 4 ] ,
7677 dim : usize ,
7778 mean : [ f32 ; 3 ] ,
7879 std : [ f32 ; 3 ] ,
7980}
8081
81- const TEST_CONFIGS_3D : & [ TestConfig3D ] = & [
82- // 224x224x3 configurations
83- TestConfig3D {
84- name : "224x224x3_hwc_zeros" ,
85- shape : [ 224 , 224 , 3 ] ,
86- dim : 2 ,
87- mean : [ 0.0 , 0.0 , 0.0 ] ,
88- std : [ 1.0 , 1.0 , 1.0 ] ,
89- } ,
90- TestConfig3D {
91- name : "224x224x3_hwc_half" ,
92- shape : [ 224 , 224 , 3 ] ,
93- dim : 2 ,
94- mean : [ 0.5 , 0.5 , 0.5 ] ,
95- std : [ 0.5 , 0.5 , 0.5 ] ,
96- } ,
97- TestConfig3D {
98- name : "224x224x3_hwc_imagenet" ,
99- shape : [ 224 , 224 , 3 ] ,
100- dim : 2 ,
101- mean : [ 0.48145466 , 0.4578275 , 0.40821073 ] ,
102- std : [ 0.26862954 , 0.261_302_6 , 0.275_777_1 ] ,
103- } ,
104- // 3x224x224 configurations
105- TestConfig3D {
106- name : "3x224x224_chw_zeros" ,
107- shape : [ 3 , 224 , 224 ] ,
108- dim : 0 ,
109- mean : [ 0.0 , 0.0 , 0.0 ] ,
110- std : [ 1.0 , 1.0 , 1.0 ] ,
111- } ,
112- TestConfig3D {
113- name : "3x224x224_chw_half" ,
114- shape : [ 3 , 224 , 224 ] ,
115- dim : 0 ,
116- mean : [ 0.5 , 0.5 , 0.5 ] ,
117- std : [ 0.5 , 0.5 , 0.5 ] ,
118- } ,
119- TestConfig3D {
120- name : "3x224x224_chw_imagenet" ,
121- shape : [ 3 , 224 , 224 ] ,
122- dim : 0 ,
123- mean : [ 0.48145466 , 0.4578275 , 0.40821073 ] ,
124- std : [ 0.26862954 , 0.261_302_6 , 0.275_777_1 ] ,
125- } ,
126- // Other sizes
127- TestConfig3D {
128- name : "256x256x3_hwc_zeros" ,
129- shape : [ 256 , 256 , 3 ] ,
130- dim : 2 ,
131- mean : [ 0.0 , 0.0 , 0.0 ] ,
132- std : [ 1.0 , 1.0 , 1.0 ] ,
133- } ,
134- TestConfig3D {
135- name : "3x256x256_chw_zeros" ,
136- shape : [ 3 , 256 , 256 ] ,
137- dim : 0 ,
138- mean : [ 0.0 , 0.0 , 0.0 ] ,
139- std : [ 1.0 , 1.0 , 1.0 ] ,
140- } ,
141- TestConfig3D {
142- name : "512x512x3_hwc_zeros" ,
143- shape : [ 512 , 512 , 3 ] ,
144- dim : 2 ,
145- mean : [ 0.0 , 0.0 , 0.0 ] ,
146- std : [ 1.0 , 1.0 , 1.0 ] ,
147- } ,
148- TestConfig3D {
149- name : "3x512x512_chw_zeros" ,
150- shape : [ 3 , 512 , 512 ] ,
151- dim : 0 ,
152- mean : [ 0.0 , 0.0 , 0.0 ] ,
153- std : [ 1.0 , 1.0 , 1.0 ] ,
154- } ,
155- TestConfig3D {
156- name : "1024x1024x3_hwc_zeros" ,
157- shape : [ 1024 , 1024 , 3 ] ,
158- dim : 2 ,
159- mean : [ 0.0 , 0.0 , 0.0 ] ,
160- std : [ 1.0 , 1.0 , 1.0 ] ,
161- } ,
162- TestConfig3D {
163- name : "3x1024x1024_chw_zeros" ,
164- shape : [ 3 , 1024 , 1024 ] ,
165- dim : 0 ,
166- mean : [ 0.0 , 0.0 , 0.0 ] ,
167- std : [ 1.0 , 1.0 , 1.0 ] ,
168- } ,
169- ] ;
82+ fn gen_3d_configs ( ) -> Vec < TestConfig3D > {
83+ let means = [
84+ ( "zeros" , [ 0.0f32 , 0.0 , 0.0 ] , [ 1.0f32 , 1.0 , 1.0 ] ) ,
85+ ( "half" , [ 0.5f32 , 0.5 , 0.5 ] , [ 0.5f32 , 0.5 , 0.5 ] ) ,
86+ (
87+ "imagenet" ,
88+ [ 0.48145466 , 0.4578275 , 0.40821073 ] ,
89+ [ 0.26862954 , 0.261_302_6 , 0.275_777_1 ] ,
90+ ) ,
91+ ] ;
17092
171- const TEST_CONFIGS_4D : & [ TestConfig4D ] = & [
172- // 4D NCHW format (batch, channels, height, width) - typical for PyTorch/vision models
173- TestConfig4D {
174- name : "1x3x224x224_nchw_imagenet" ,
175- shape : [ 1 , 3 , 224 , 224 ] ,
176- dim : 1 ,
177- mean : [ 0.48145466 , 0.4578275 , 0.40821073 ] ,
178- std : [ 0.26862954 , 0.2613026 , 0.2757771 ] ,
179- } ,
180- TestConfig4D {
181- name : "1x3x640x640_nchw_imagenet" ,
182- shape : [ 1 , 3 , 640 , 640 ] ,
183- dim : 1 ,
184- mean : [ 0.48145466 , 0.4578275 , 0.40821073 ] ,
185- std : [ 0.26862954 , 0.2613026 , 0.2757771 ] ,
186- } ,
187- ] ;
93+ let sizes = [ ( 224usize , 224usize ) , ( 640 , 640 ) , ( 1024 , 1024 ) ] ;
94+ let mut v = Vec :: new ( ) ;
95+ for ( h, w) in sizes {
96+ // Generate three permutations where channel size 3 is placed at each axis
97+ let shapes = [
98+ ( [ 3 , h, w] , 0 , format ! ( "3x{}x{}" , h, w) ) ,
99+ ( [ h, 3 , w] , 1 , format ! ( "{}x3x{}" , h, w) ) ,
100+ ( [ h, w, 3 ] , 2 , format ! ( "{}x{}x3" , h, w) ) ,
101+ ] ;
102+ for ( shape, dim, tag) in shapes {
103+ for ( case, mean, std) in means. iter ( ) {
104+ v. push ( TestConfig3D {
105+ name : format ! ( "{}_{case}" , tag) ,
106+ shape,
107+ dim,
108+ mean : * mean,
109+ std : * std,
110+ } ) ;
111+ }
112+ }
113+ }
114+ v
115+ }
188116
189- fn bench_slsl_standardize_3d ( c : & mut Criterion ) {
190- let mut group = c. benchmark_group ( "standardize_slsl_3d" ) ;
117+ fn gen_4d_configs ( ) -> Vec < TestConfig4D > {
118+ // Start from 3D base (H,W,3) sizes and insert 10 at all possible positions,
119+ // then also permute channel (3) to different axes by using the three 3D permutations and inserting 10.
120+ let sizes = [ ( 224usize , 224usize ) , ( 640 , 640 ) , ( 1024 , 1024 ) ] ;
121+ let mut v = Vec :: new ( ) ;
122+ let mut seen: HashSet < ( usize , usize , usize , usize , usize ) > = HashSet :: new ( ) ;
123+ let mean = [ 0.48145466 , 0.4578275 , 0.40821073 ] ;
124+ let std = [ 0.26862954 , 0.2613026 , 0.2757771 ] ;
191125
192- for config in TEST_CONFIGS_3D {
193- group. bench_function ( config. name , |b| {
194- b. iter ( || {
195- // Create test data
196- let data: Vec < f32 > = ( 0 ..config. shape . iter ( ) . product :: < usize > ( ) )
197- . map ( |i| ( i as f32 ) * 0.01 )
198- . collect ( ) ;
199- let tensor = Tensor :: from_vec ( data, config. shape ) . unwrap ( ) ;
126+ for ( h, w) in sizes {
127+ // Base 3D permutations
128+ let bases = [
129+ ( [ 3 , h, w] , 0 , format ! ( "3x{}x{}" , h, w) ) ,
130+ ( [ h, 3 , w] , 1 , format ! ( "{}x3x{}" , h, w) ) ,
131+ ( [ h, w, 3 ] , 2 , format ! ( "{}x{}x3" , h, w) ) ,
132+ ] ;
200133
201- // Perform standardization
202- black_box (
203- tensor
204- . standardize ( & config. mean , & config. std , config. dim )
205- . unwrap ( ) ,
206- )
207- } ) ;
208- } ) ;
209- }
134+ for ( base, ch_dim, tag3d) in bases {
135+ // Insert 10 at all four positions
136+ let shapes_4d = [
137+ (
138+ [ 10 , base[ 0 ] , base[ 1 ] , base[ 2 ] ] ,
139+ ch_dim + 1 ,
140+ format ! ( "10x{}" , tag3d) ,
141+ ) ,
142+ (
143+ [ base[ 0 ] , 10 , base[ 1 ] , base[ 2 ] ] ,
144+ ch_dim,
145+ format ! (
146+ "{}x10x{}" ,
147+ if ch_dim == 0 {
148+ format!( "{}x{}" , base[ 1 ] , base[ 2 ] )
149+ } else {
150+ format!( "{}" , base[ 0 ] )
151+ } ,
152+ if ch_dim == 0 {
153+ format!( "{}" , base[ 0 ] )
154+ } else {
155+ format!( "{}x{}" , base[ 1 ] , base[ 2 ] )
156+ }
157+ ) ,
158+ ) ,
159+ (
160+ [ base[ 0 ] , base[ 1 ] , 10 , base[ 2 ] ] ,
161+ ch_dim,
162+ format ! ( "{}x{}x10x{}" , base[ 0 ] , base[ 1 ] , base[ 2 ] ) ,
163+ ) ,
164+ (
165+ [ base[ 0 ] , base[ 1 ] , base[ 2 ] , 10 ] ,
166+ ch_dim,
167+ format ! ( "{}x{}x{}x10" , base[ 0 ] , base[ 1 ] , base[ 2 ] ) ,
168+ ) ,
169+ ] ;
210170
211- group. finish ( ) ;
171+ for ( shape4, dim4, _tag4) in shapes_4d {
172+ // Derive channel dim by locating the index of value 3
173+ let mut channel_dim = None ;
174+ for ( i, & d) in shape4. iter ( ) . enumerate ( ) {
175+ if d == 3 {
176+ channel_dim = Some ( i) ;
177+ break ;
178+ }
179+ }
180+ let dim = channel_dim. unwrap_or ( dim4) ;
181+ let key = ( shape4[ 0 ] , shape4[ 1 ] , shape4[ 2 ] , shape4[ 3 ] , dim) ;
182+ if seen. insert ( key) {
183+ let name = format ! (
184+ "{}x{}x{}x{}:dim{}" ,
185+ shape4[ 0 ] , shape4[ 1 ] , shape4[ 2 ] , shape4[ 3 ] , dim
186+ ) ;
187+ v. push ( TestConfig4D {
188+ name,
189+ shape : shape4,
190+ dim,
191+ mean,
192+ std,
193+ } ) ;
194+ }
195+ }
196+ }
197+ }
198+ v
212199}
213200
214- fn bench_slsl_standardize_4d ( c : & mut Criterion ) {
215- let mut group = c. benchmark_group ( "standardize_slsl_4d" ) ;
201+ fn bench_compare_standardize_3d ( c : & mut Criterion ) {
202+ let mut group = c. benchmark_group ( "standardize_3d_compare" ) ;
203+ let configs = gen_3d_configs ( ) ;
216204
217- for config in TEST_CONFIGS_4D {
218- group. bench_function ( config. name , |b| {
205+ for config in configs. iter ( ) {
206+ // slsl
207+ group. bench_function ( format ! ( "slsl/{}" , config. name) , |b| {
219208 b. iter ( || {
220- // Create test data
221209 let data: Vec < f32 > = ( 0 ..config. shape . iter ( ) . product :: < usize > ( ) )
222210 . map ( |i| ( i as f32 ) * 0.01 )
223211 . collect ( ) ;
224212 let tensor = Tensor :: from_vec ( data, config. shape ) . unwrap ( ) ;
225-
226- // Perform standardization
227213 black_box (
228214 tensor
229215 . standardize ( & config. mean , & config. std , config. dim )
230216 . unwrap ( ) ,
231217 )
232218 } ) ;
233219 } ) ;
234- }
235-
236- group. finish ( ) ;
237- }
238-
239- fn bench_ndarray_standardize_3d ( c : & mut Criterion ) {
240- let mut group = c. benchmark_group ( "standardize_ndarray_3d" ) ;
241220
242- for config in TEST_CONFIGS_3D {
243- group. bench_function ( config. name , |b| {
221+ // ndarray
222+ group. bench_function ( format ! ( "ndarray/{}" , config. name) , |b| {
244223 b. iter ( || {
245- // Create test data
246224 let data: Vec < f32 > = ( 0 ..config. shape . iter ( ) . product :: < usize > ( ) )
247225 . map ( |i| ( i as f32 ) * 0.01 )
248226 . collect ( ) ;
249227 let mut array = Array :: from_shape_vec ( IxDyn ( & config. shape ) , data) . unwrap ( ) ;
250228 let mean = ArrayView1 :: from ( & config. mean ) ;
251229 let std = ArrayView1 :: from ( & config. std ) ;
252230 standardize_ndarray ( & mut array, mean, std, config. dim ) . unwrap ( ) ;
253- black_box ( ( ) ) ;
254- array
231+ black_box ( & array) ;
255232 } ) ;
256233 } ) ;
257234 }
258235
259236 group. finish ( ) ;
260237}
261238
262- fn bench_ndarray_standardize_4d ( c : & mut Criterion ) {
263- let mut group = c. benchmark_group ( "standardize_ndarray_4d" ) ;
239+ fn bench_compare_standardize_4d ( c : & mut Criterion ) {
240+ let mut group = c. benchmark_group ( "standardize_4d_compare" ) ;
241+ let configs = gen_4d_configs ( ) ;
242+
243+ for config in configs. iter ( ) {
244+ // slsl
245+ group. bench_function ( format ! ( "slsl/{}" , config. name) , |b| {
246+ b. iter ( || {
247+ let data: Vec < f32 > = ( 0 ..config. shape . iter ( ) . product :: < usize > ( ) )
248+ . map ( |i| ( i as f32 ) * 0.01 )
249+ . collect ( ) ;
250+ let tensor = Tensor :: from_vec ( data, config. shape ) . unwrap ( ) ;
251+ black_box (
252+ tensor
253+ . standardize ( & config. mean , & config. std , config. dim )
254+ . unwrap ( ) ,
255+ )
256+ } ) ;
257+ } ) ;
264258
265- for config in TEST_CONFIGS_4D {
266- group. bench_function ( config. name , |b| {
259+ // ndarray
260+ group. bench_function ( format ! ( "ndarray/{}" , config. name) , |b| {
267261 b. iter ( || {
268- // Create test data
269262 let data: Vec < f32 > = ( 0 ..config. shape . iter ( ) . product :: < usize > ( ) )
270263 . map ( |i| ( i as f32 ) * 0.01 )
271264 . collect ( ) ;
272265 let mut array = Array :: from_shape_vec ( IxDyn ( & config. shape ) , data) . unwrap ( ) ;
273266 let mean = ArrayView1 :: from ( & config. mean ) ;
274267 let std = ArrayView1 :: from ( & config. std ) ;
275268 standardize_ndarray ( & mut array, mean, std, config. dim ) . unwrap ( ) ;
276- black_box ( ( ) ) ;
277- array
269+ black_box ( & array) ;
278270 } ) ;
279271 } ) ;
280272 }
@@ -284,9 +276,7 @@ fn bench_ndarray_standardize_4d(c: &mut Criterion) {
284276
285277criterion_group ! (
286278 benches,
287- bench_slsl_standardize_3d,
288- bench_slsl_standardize_4d,
289- bench_ndarray_standardize_3d,
290- bench_ndarray_standardize_4d
279+ bench_compare_standardize_3d,
280+ bench_compare_standardize_4d
291281) ;
292282criterion_main ! ( benches) ;
0 commit comments