Skip to content

Commit ad35c61

Browse files
authored
Add tests for StreamSampler functionality (#136)
1 parent 71b2f60 commit ad35c61

2 files changed

Lines changed: 79 additions & 0 deletions

File tree

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ using StreamSampling
1616
include("unweighted_sampling_multi_tests.jl")
1717
include("weighted_sampling_single_tests.jl")
1818
include("weighted_sampling_multi_tests.jl")
19+
include("stream_sampling_tests.jl")
1920
include("sequential_sampling_tests.jl")
2021
include("merge_tests.jl")
2122
include("empty_tests.jl")
2223
include("benchmark_tests.jl")
2324
end
25+

test/stream_sampling_tests.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
@testset "StreamSampler tests" begin
2+
rng = StableRNG(45)
3+
N = 10
4+
n = 3
5+
reps = 100000
6+
7+
for alg in (AlgD(), AlgHiddenShuffle())
8+
dict_res = Dict{Vector{Int}, Int}()
9+
for _ in 1:reps
10+
s = StreamSampler{Int}(rng, 1:N, n, N, alg)
11+
out = collect(s)
12+
dict_res[out] = get(dict_res, out, 0) + 1
13+
end
14+
15+
valid_triples = 120
16+
count_est = Int[]
17+
for i in 1:N, j in i+1:N, k in j+1:N
18+
push!(count_est, get(dict_res, [i, j, k], 0))
19+
end
20+
21+
chisq_test = ChisqTest(count_est, fill(1/valid_triples, valid_triples))
22+
@test pvalue(chisq_test) > 0.05
23+
end
24+
25+
dict_res = Dict{Vector{Int}, Int}()
26+
for _ in 1:reps
27+
s = StreamSampler{Int}(rng, 1:N, n, N, AlgORDSWR())
28+
out = collect(s)
29+
dict_res[out] = get(dict_res, out, 0) + 1
30+
end
31+
32+
count_est = Int[]
33+
ps_exact = Float64[]
34+
for i in 1:N, j in i:N, k in j:N
35+
push!(count_est, get(dict_res, [i, j, k], 0))
36+
if i == j == k
37+
push!(ps_exact, 1/(N^3))
38+
elseif i == j || j == k
39+
push!(ps_exact, 3/(N^3))
40+
else
41+
push!(ps_exact, 6/(N^3))
42+
end
43+
end
44+
45+
chisq_test = ChisqTest(count_est, ps_exact)
46+
@test pvalue(chisq_test) > 0.05
47+
48+
weights = [i <= 5 ? 1.0 : 2.0 for i in 1:N]
49+
W = sum(weights)
50+
wfunc(i) = weights[i]
51+
52+
dict_res = Dict{Vector{Int}, Int}()
53+
for _ in 1:reps
54+
s = StreamSampler{Int}(rng, 1:N, wfunc, n, W, AlgORDWSWR())
55+
out = collect(s)
56+
dict_res[out] = get(dict_res, out, 0) + 1
57+
end
58+
59+
count_est = Int[]
60+
ps_exact = Float64[]
61+
for i in 1:N, j in i:N, k in j:N
62+
push!(count_est, get(dict_res, [i, j, k], 0))
63+
wi, wj, wk = weights[i], weights[j], weights[k]
64+
if i == j == k
65+
push!(ps_exact, (wi^3) / (W^3))
66+
elseif i == j
67+
push!(ps_exact, 3 * (wi^2 * wk) / (W^3))
68+
elseif j == k
69+
push!(ps_exact, 3 * (wi * wj^2) / (W^3))
70+
else
71+
push!(ps_exact, 6 * (wi * wj * wk) / (W^3))
72+
end
73+
end
74+
75+
chisq_test = ChisqTest(count_est, ps_exact)
76+
@test pvalue(chisq_test) > 0.05
77+
end

0 commit comments

Comments
 (0)