Skip to content

Commit 964077a

Browse files
committed
TEXT-155: Add a generic IntersectionSimilarity measure
1 parent 0ada5fa commit 964077a

4 files changed

Lines changed: 591 additions & 0 deletions

File tree

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.commons.text.similarity;
18+
19+
import java.util.Objects;
20+
21+
/**
 * Container class to store the intersection results between two sets.
 *
 * <p>Stores the size of set A, the size of set B and the size of the intersection
 * of A and B (<code>|A &#8745; B|</code>). The result can be used to produce various
 * similarity metrics, for example the Jaccard index or the F1-score.</p>
 *
 * <p>This class is immutable.</p>
 *
 * @since 1.7
 * @see <a href="https://en.wikipedia.org/wiki/Jaccard_index">Jaccard index</a>
 * @see <a href="https://en.wikipedia.org/wiki/F1_score">F1 score</a>
 */
public class IntersectionResult {
    /**
     * The size of set A.
     */
    private final int sizeA;
    /**
     * The size of set B.
     */
    private final int sizeB;
    /**
     * The size of the intersection between set A and B.
     */
    private final int intersection;

    /**
     * Create the results for an intersection between two sets.
     *
     * @param sizeA the size of set A ({@code |A|})
     * @param sizeB the size of set B ({@code |B|})
     * @param intersection the size of the intersection of A and B (<code>|A &#8745; B|</code>)
     * @throws IllegalArgumentException if either size is negative, or the intersection is
     * negative or greater than the minimum of the two set sizes
     */
    public IntersectionResult(final int sizeA, final int sizeB, final int intersection) {
        if (sizeA < 0) {
            throw new IllegalArgumentException("Set size |A| is not positive: " + sizeA);
        }
        if (sizeB < 0) {
            throw new IllegalArgumentException("Set size |B| is not positive: " + sizeB);
        }
        // The intersection can never contain more elements than the smaller set.
        if (intersection < 0 || intersection > Math.min(sizeA, sizeB)) {
            throw new IllegalArgumentException("Invalid intersection of |A| and |B|: " + intersection);
        }
        this.sizeA = sizeA;
        this.sizeB = sizeB;
        this.intersection = intersection;
    }

    /**
     * Get the size of set A ({@code |A|}).
     *
     * @return {@code |A|}
     */
    public int getSizeA() {
        return sizeA;
    }

    /**
     * Get the size of set B ({@code |B|}).
     *
     * @return {@code |B|}
     */
    public int getSizeB() {
        return sizeB;
    }

    /**
     * Get the size of the intersection between set A and B.
     *
     * @return <code>|A &#8745; B|</code>
     */
    public int getIntersection() {
        return intersection;
    }

    /**
     * Get the size of the union between set A and B.
     *
     * <p>The size is computed as {@code |A| + |B| - |A intersect B|} using {@code long}
     * arithmetic, so the sum of two large {@code int} sizes does not overflow.</p>
     *
     * @return <code>|A &#8746; B|</code>
     */
    public long getUnion() {
        return (long) sizeA + sizeB - intersection;
    }

    /**
     * Gets the Jaccard index: the size of the intersection divided by the size of the union.
     *
     * <p>This implementation defines the result as zero if there is no intersection,
     * even when the size of both sets is zero.</p>
     *
     * @return the Jaccard index
     * @see <a href="https://en.wikipedia.org/wiki/Jaccard_index">Jaccard index</a>
     */
    public double getJaccard() {
        // Guarding on a zero intersection also avoids 0 / 0 when both sets are empty.
        return intersection == 0 ? 0.0 : (double) intersection / getUnion();
    }

    /**
     * Gets the F1 score: twice the size of the intersection divided by the sum of the
     * two set sizes.
     *
     * <p>This implementation defines the result as zero if there is no intersection,
     * even when the size of both sets is zero.</p>
     *
     * @return the F1 score
     * @see <a href="https://en.wikipedia.org/wiki/F1_score">F1 score</a>
     */
    public double getF1Score() {
        // The denominator is widened to long so that |A| + |B| cannot overflow an int.
        return intersection == 0 ? 0.0 : 2.0 * intersection / ((long) sizeA + sizeB);
    }

    /**
     * Tests for equality. Two results are equal when they are of exactly the same class
     * and have the same set sizes and intersection size.
     *
     * @param o the object to compare against
     * @return true if {@code o} is an equal result
     */
    @Override
    public boolean equals(final Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        final IntersectionResult result = (IntersectionResult) o;
        return sizeA == result.sizeA && sizeB == result.sizeB && intersection == result.intersection;
    }

    /**
     * Computes a hash code from the two set sizes and the intersection size.
     *
     * @return the hash code
     */
    @Override
    public int hashCode() {
        return Objects.hash(sizeA, sizeB, intersection);
    }

    /**
     * Returns a description of the two set sizes and the intersection size.
     *
     * @return the string representation
     */
    @Override
    public String toString() {
        return "Size A: " + sizeA + ", Size B: " + sizeB + ", Intersection: " + intersection;
    }
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.commons.text.similarity;
18+
19+
import java.util.Set;
20+
import java.util.function.Function;
21+
import java.util.stream.Collectors;
22+
23+
/**
24+
* Measures the intersection of two sets created from a pair of character
25+
* sequences.
26+
*
27+
* <p>It is assumed that the type {@code T} correctly conforms to the
28+
* requirements for storage within a {@link Set}, ideally the type is
29+
* immutable and implements {@link Object#equals(Object)}.</p>
30+
*
31+
* @param <T> the type of the set extracted from the character sequence
32+
* @since 1.7
33+
* @see Set
34+
*/
35+
public class IntersectionSimilarity<T> implements SimilarityScore<IntersectionResult> {
36+
/** The converter used to create the set elements. */
37+
private final Function<CharSequence, Set<T>> converter;
38+
39+
/**
40+
* Create a new set similarity using the provided converter.
41+
*
42+
* @param converter the converter used to create the set
43+
* @throws IllegalArgumentException if the converter is null
44+
*/
45+
public IntersectionSimilarity(Function<CharSequence, Set<T>> converter) {
46+
if (converter == null) {
47+
throw new IllegalArgumentException("Converter must not be null");
48+
}
49+
this.converter = converter;
50+
}
51+
52+
/**
53+
* Calculates the intersection of two character sequences passed as input.
54+
*
55+
* @param left first character sequence
56+
* @param right second character sequence
57+
* @return the intersection result
58+
* @throws IllegalArgumentException if either input sequence is {@code null}
59+
*/
60+
@Override
61+
public IntersectionResult apply(final CharSequence left, final CharSequence right) {
62+
if (left == null || right == null) {
63+
throw new IllegalArgumentException("Input cannot be null");
64+
}
65+
final Set<T> setA = converter.apply(left);
66+
final Set<T> setB = converter.apply(right);
67+
final int sizeA = setA.size();
68+
final int sizeB = setB.size();
69+
// Short-cut if either set is empty
70+
if (Math.min(sizeA, sizeB) == 0) {
71+
// No intersection
72+
return new IntersectionResult(sizeA, sizeB, 0);
73+
}
74+
// We can use intValue() to convert the Long output from the
75+
// collector as the intersection cannot be bigger than either set.
76+
final int intersection = setA.stream().filter(setB::contains).collect(Collectors.counting()).intValue();
77+
return new IntersectionResult(sizeA, sizeB, intersection);
78+
}
79+
}

0 commit comments

Comments
 (0)