|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | + * contributor license agreements. See the NOTICE file distributed with |
| 4 | + * this work for additional information regarding copyright ownership. |
| 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | + * (the "License"); you may not use this file except in compliance with |
| 7 | + * the License. You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | +package org.apache.commons.text.similarity; |
| 18 | + |
| 19 | +import java.util.Collection; |
| 20 | +import java.util.HashMap; |
| 21 | +import java.util.Map; |
| 22 | +import java.util.Map.Entry; |
| 23 | +import java.util.Set; |
| 24 | +import java.util.function.Function; |
| 25 | + |
| 26 | +/** |
| 27 | + * Measures the intersection of two sets created from a pair of character sequences. |
| 28 | + * |
| 29 | + * <p>It is assumed that the type {@code T} correctly conforms to the requirements for storage |
| 30 | + * within a {@link Set} or {@link HashMap}. Ideally the type is immutable and implements |
| 31 | + * {@link Object#equals(Object)} and {@link Object#hashCode()}.</p> |
| 32 | + * |
| 33 | + * @param <T> the type of the elements extracted from the character sequence |
| 34 | + * @since 1.7 |
| 35 | + * @see Set |
| 36 | + * @see HashMap |
| 37 | + */ |
| 38 | +public class IntersectionSimilarity<T> implements SimilarityScore<IntersectionResult> { |
| 39 | + /** The converter used to create the elements from the characters. */ |
| 40 | + private final Function<CharSequence, Collection<T>> converter; |
| 41 | + |
| 42 | + // The following is adapted from commons-collections for a Bag. |
| 43 | + // A Bag is a collection that can store the count of the number |
| 44 | + // of copies of each element. |
| 45 | + |
| 46 | + /** |
| 47 | + * Mutable counter class for storing the count of elements. |
| 48 | + */ |
| 49 | + private static class BagCount { |
| 50 | + /** The count. This is initialised to 1 upon construction. */ |
| 51 | + int count = 1; |
| 52 | + } |
| 53 | + |
| 54 | + /** |
| 55 | + * A minimal implementation of a Bag that can store elements and a count. |
| 56 | + * |
| 57 | + * <p>For the intended purpose the Bag does not have to be a {@link Collection}. It does not |
| 58 | + * even have to know its own size. |
| 59 | + */ |
| 60 | + private class TinyBag { |
| 61 | + /** The backing map. */ |
| 62 | + private final Map<T, BagCount> map; |
| 63 | + |
| 64 | + /** |
| 65 | + * Create a new tiny bag. |
| 66 | + * |
| 67 | + * @param initialCapacity the initial capacity |
| 68 | + */ |
| 69 | + TinyBag(int initialCapacity) { |
| 70 | + map = new HashMap<>(initialCapacity); |
| 71 | + } |
| 72 | + |
| 73 | + /** |
| 74 | + * Adds a new element to the bag, incrementing its count in the underlying map. |
| 75 | + * |
| 76 | + * @param object the object to add |
| 77 | + */ |
| 78 | + void add(T object) { |
| 79 | + final BagCount mut = map.get(object); |
| 80 | + if (mut == null) { |
| 81 | + map.put(object, new BagCount()); |
| 82 | + } else { |
| 83 | + mut.count++; |
| 84 | + } |
| 85 | + } |
| 86 | + |
| 87 | + /** |
| 88 | + * Returns the number of occurrence of the given element in this bag by |
| 89 | + * looking up its count in the underlying map. |
| 90 | + * |
| 91 | + * @param object the object to search for |
| 92 | + * @return the number of occurrences of the object, zero if not found |
| 93 | + */ |
| 94 | + int getCount(final Object object) { |
| 95 | + final BagCount count = map.get(object); |
| 96 | + if (count != null) { |
| 97 | + return count.count; |
| 98 | + } |
| 99 | + return 0; |
| 100 | + } |
| 101 | + |
| 102 | + /** |
| 103 | + * Returns a Set view of the mappings contained in this bag. |
| 104 | + * |
| 105 | + * @return the Set view |
| 106 | + */ |
| 107 | + Set<Entry<T, BagCount>> entrySet() { |
| 108 | + return map.entrySet(); |
| 109 | + } |
| 110 | + |
| 111 | + /** |
| 112 | + * Get the number of unique elements in the bag. |
| 113 | + * |
| 114 | + * @return the unique element size |
| 115 | + */ |
| 116 | + int uniqueElementSize() { |
| 117 | + return map.size(); |
| 118 | + } |
| 119 | + } |
| 120 | + |
| 121 | + /** |
| 122 | + * Create a new intersection similarity using the provided converter. |
| 123 | + * |
| 124 | + * <p>If the converter returns a {@link Set} then the intersection result will |
| 125 | + * not include duplicates. Any other {@link Collection} is used to produce a result |
| 126 | + * that will include duplicates in the intersect and union. |
| 127 | + * |
| 128 | + * @param converter the converter used to create the elements from the characters |
| 129 | + * @throws IllegalArgumentException if the converter is null |
| 130 | + */ |
| 131 | + public IntersectionSimilarity(Function<CharSequence, Collection<T>> converter) { |
| 132 | + if (converter == null) { |
| 133 | + throw new IllegalArgumentException("Converter must not be null"); |
| 134 | + } |
| 135 | + this.converter = converter; |
| 136 | + } |
| 137 | + |
| 138 | + /** |
| 139 | + * Calculates the intersection of two character sequences passed as input. |
| 140 | + * |
| 141 | + * @param left first character sequence |
| 142 | + * @param right second character sequence |
| 143 | + * @return the intersection result |
| 144 | + * @throws IllegalArgumentException if either input sequence is {@code null} |
| 145 | + */ |
| 146 | + @Override |
| 147 | + public IntersectionResult apply(final CharSequence left, final CharSequence right) { |
| 148 | + if (left == null || right == null) { |
| 149 | + throw new IllegalArgumentException("Input cannot be null"); |
| 150 | + } |
| 151 | + |
| 152 | + // Create the elements from the sequences |
| 153 | + final Collection<T> objectsA = converter.apply(left); |
| 154 | + final Collection<T> objectsB = converter.apply(right); |
| 155 | + final int sizeA = objectsA.size(); |
| 156 | + final int sizeB = objectsB.size(); |
| 157 | + |
| 158 | + // Short-cut if either collection is empty |
| 159 | + if (Math.min(sizeA, sizeB) == 0) { |
| 160 | + // No intersection |
| 161 | + return new IntersectionResult(sizeA, sizeB, 0); |
| 162 | + } |
| 163 | + |
| 164 | + // Intersection = count the number of shared elements |
| 165 | + int intersection; |
| 166 | + if (objectsA instanceof Set && objectsB instanceof Set) { |
| 167 | + // If a Set then the elements will only have a count of 1. |
| 168 | + // Iterate over the smaller set. |
| 169 | + intersection = (sizeA < sizeB) |
| 170 | + ? getIntersection((Set<T>) objectsA, (Set<T>) objectsB) |
| 171 | + : getIntersection((Set<T>) objectsB, (Set<T>) objectsA); |
| 172 | + } else { |
| 173 | + // Create a bag for each collection |
| 174 | + final TinyBag bagA = toBag(objectsA); |
| 175 | + final TinyBag bagB = toBag(objectsB); |
| 176 | + // Iterate over the smaller number of unique elements |
| 177 | + intersection = (bagA.uniqueElementSize() < bagB.uniqueElementSize()) |
| 178 | + ? getIntersection(bagA, bagB) |
| 179 | + : getIntersection(bagB, bagA); |
| 180 | + } |
| 181 | + |
| 182 | + return new IntersectionResult(sizeA, sizeB, intersection); |
| 183 | + } |
| 184 | + |
| 185 | + /** |
| 186 | + * Convert the collection to a bag. The bag will contain the count of each element |
| 187 | + * in the collection. |
| 188 | + * |
| 189 | + * @param objects the objects |
| 190 | + * @return the bag |
| 191 | + */ |
| 192 | + private TinyBag toBag(Collection<T> objects) { |
| 193 | + final TinyBag bag = new TinyBag(objects.size()); |
| 194 | + for (T t : objects) { |
| 195 | + bag.add(t); |
| 196 | + } |
| 197 | + return bag; |
| 198 | + } |
| 199 | + |
| 200 | + /** |
| 201 | + * Compute the intersection between two sets. This is the count of all the elements |
| 202 | + * that are within both sets. |
| 203 | + * |
| 204 | + * @param <T> the type of the elements in the set |
| 205 | + * @param setA the set A |
| 206 | + * @param setB the set B |
| 207 | + * @return the intersection |
| 208 | + */ |
| 209 | + private static <T> int getIntersection(Set<T> setA, Set<T> setB) { |
| 210 | + int intersection = 0; |
| 211 | + for (T element : setA) { |
| 212 | + if (setB.contains(element)) { |
| 213 | + intersection++; |
| 214 | + } |
| 215 | + } |
| 216 | + return intersection; |
| 217 | + } |
| 218 | + |
| 219 | + /** |
| 220 | + * Compute the intersection between two bags. This is the sum of the minimum |
| 221 | + * count of each element that is within both sets. |
| 222 | + * |
| 223 | + * @param bagA the bag A |
| 224 | + * @param bagB the bag B |
| 225 | + * @return the intersection |
| 226 | + */ |
| 227 | + private int getIntersection(TinyBag bagA, TinyBag bagB) { |
| 228 | + int intersection = 0; |
| 229 | + for (Entry<T, BagCount> entry : bagA.entrySet()) { |
| 230 | + final T element = entry.getKey(); |
| 231 | + final int count = entry.getValue().count; |
| 232 | + // The intersection of this entry in both bags is the minimum count |
| 233 | + intersection += Math.min(count, bagB.getCount(element)); |
| 234 | + } |
| 235 | + return intersection; |
| 236 | + } |
| 237 | +} |
0 commit comments