Skip to content

Commit af2c86b

Browse files
authored
Add some more merge tests (#135)
1 parent 57d5a90 commit af2c86b

3 files changed

Lines changed: 55 additions & 6 deletions

File tree

src/UnweightedSamplingSingle.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ function Base.merge(ss::SingleAlgRSWRSKIPSampler...)
4444
ns = [nobs(s) for s in ss]
4545
n_tot = sum(ns)
4646
ps = cumsum(ns ./ n_tot)
47-
r = rand(s1.rng)
48-
value = ss[findfirst(p -> r < p, ps)].value
49-
return typeof(s1)(n_tot, sum(s.skip_k for s in ss), ss[1].rng, value)
47+
r = rand(ss[1].rng)
48+
value = ss[findfirst(p -> r < p, ps)].rvalue
49+
return typeof(ss[1])(n_tot, sum(s.skip_k for s in ss), ss[1].rng, value)
5050
end
5151

5252
function Base.merge!(s1::SingleAlgRSWRSKIPSampler_Mut, ss::SingleAlgRSWRSKIPSampler_Mut...)

src/WeightedSamplingSingle.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ function Base.merge(ss::SingleAlgWRSWRSKIPSampler...)
4949
ns = [s.total_w for s in ss]
5050
n_tot = sum(ns)
5151
ps = cumsum(ns ./ n_tot)
52-
r = rand(s1.rng)
53-
value = ss[findfirst(p -> r < p, ps)].value
54-
return typeof(s1)(sum(s.seen_k for s in ss), sum(s.total_w for s in ss), sum(s.skip_w for s in ss),
52+
r = rand(ss[1].rng)
53+
value = ss[findfirst(p -> r < p, ps)].rvalue
54+
return typeof(ss[1])(sum(s.seen_k for s in ss), sum(s.total_w for s in ss), sum(s.skip_w for s in ss),
5555
ss[1].rng, value)
5656
end
5757

test/merge_tests.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,53 @@
4343
m == AlgRSWRSKIP() ? fit!(s2, 2) : fit!(s2, 2, 1.0)
4444
@test value(merge!(s1, s2)) in (1, 2)
4545
end
46+
47+
iters = (1:10, 11:30)
48+
reps = 10000
49+
for m in (AlgRSWRSKIP(),)
50+
count_s1 = 0
51+
for _ in 1:reps
52+
s1 = ReservoirSampler{Int}(rng, m)
53+
s2 = ReservoirSampler{Int}(rng, m)
54+
for x in iters[1] fit!(s1, x) end
55+
for x in iters[2] fit!(s2, x) end
56+
s_merged = merge(s1, s2)
57+
if value(s_merged) <= 10
58+
count_s1 += 1
59+
end
60+
end
61+
chisq_test = ChisqTest([count_s1, reps - count_s1], [1/3, 2/3])
62+
@test pvalue(chisq_test) > 0.05
63+
end
64+
65+
for m in (AlgWRSWRSKIP(),)
66+
count_s1 = 0
67+
for _ in 1:reps
68+
s1 = ReservoirSampler{Int}(rng, m)
69+
s2 = ReservoirSampler{Int}(rng, m)
70+
for x in iters[1] fit!(s1, x, 1.0) end
71+
for x in iters[2] fit!(s2, x, 1.0) end
72+
s_merged = merge(s1, s2)
73+
if value(s_merged) <= 10
74+
count_s1 += 1
75+
end
76+
end
77+
chisq_test = ChisqTest([count_s1, reps - count_s1], [1/3, 2/3])
78+
@test pvalue(chisq_test) > 0.05
79+
80+
rng = StableRNG(45)
81+
count_s1 = 0
82+
for _ in 1:reps
83+
s1 = ReservoirSampler{Int}(rng, m)
84+
s2 = ReservoirSampler{Int}(rng, m)
85+
fit!(s1, 1, 10.0)
86+
fit!(s2, 2, 20.0)
87+
s_merged = merge(s1, s2)
88+
if value(s_merged) == 1
89+
count_s1 += 1
90+
end
91+
end
92+
chisq_test = ChisqTest([count_s1, reps - count_s1], [1/3, 2/3])
93+
@test pvalue(chisq_test) > 0.05
94+
end
4695
end

0 commit comments

Comments
 (0)