|
| 1 | +package com.spbsu.flamestream.example.benchmark; |
| 2 | + |
| 3 | +import akka.actor.ActorSystem; |
| 4 | +import com.spbsu.flamestream.example.bl.text_classifier.TextClassifierGraph; |
| 5 | +import com.spbsu.flamestream.example.bl.text_classifier.model.Prediction; |
| 6 | +import com.spbsu.flamestream.example.bl.text_classifier.model.TextDocument; |
| 7 | +import com.spbsu.flamestream.example.bl.text_classifier.ops.classifier.CountVectorizer; |
| 8 | +import com.spbsu.flamestream.example.bl.text_classifier.ops.classifier.TextUtils; |
| 9 | +import com.spbsu.flamestream.example.bl.text_classifier.ops.classifier.Vectorizer; |
| 10 | +import com.spbsu.flamestream.example.bl.text_classifier.ops.classifier.ftrl.FTRLProximal; |
| 11 | +import com.spbsu.flamestream.runtime.FlameRuntime; |
| 12 | +import com.spbsu.flamestream.runtime.LocalClusterRuntime; |
| 13 | +import com.spbsu.flamestream.runtime.LocalRuntime; |
| 14 | +import com.spbsu.flamestream.runtime.RemoteRuntime; |
| 15 | +import com.spbsu.flamestream.runtime.config.ClusterConfig; |
| 16 | +import com.spbsu.flamestream.runtime.config.SystemConfig; |
| 17 | +import com.spbsu.flamestream.runtime.config.ZookeeperWorkersNode; |
| 18 | +import com.spbsu.flamestream.runtime.edge.akka.AkkaFront; |
| 19 | +import com.spbsu.flamestream.runtime.edge.akka.AkkaFrontType; |
| 20 | +import com.spbsu.flamestream.runtime.edge.akka.AkkaRearType; |
| 21 | +import com.spbsu.flamestream.runtime.serialization.KryoSerializer; |
| 22 | +import com.spbsu.flamestream.runtime.utils.AwaitCountConsumer; |
| 23 | +import com.typesafe.config.Config; |
| 24 | +import com.typesafe.config.ConfigFactory; |
| 25 | +import org.apache.commons.csv.CSVFormat; |
| 26 | +import org.apache.commons.csv.CSVParser; |
| 27 | +import org.apache.commons.csv.CSVRecord; |
| 28 | +import org.apache.commons.io.FileUtils; |
| 29 | +import org.apache.curator.framework.CuratorFramework; |
| 30 | +import org.apache.curator.framework.CuratorFrameworkFactory; |
| 31 | +import org.apache.curator.retry.ExponentialBackoffRetry; |
| 32 | +import org.jooq.lambda.Seq; |
| 33 | +import org.slf4j.Logger; |
| 34 | +import org.slf4j.LoggerFactory; |
| 35 | + |
| 36 | +import java.io.BufferedReader; |
| 37 | +import java.io.File; |
| 38 | +import java.io.FileInputStream; |
| 39 | +import java.io.IOException; |
| 40 | +import java.io.InputStreamReader; |
| 41 | +import java.io.PrintWriter; |
| 42 | +import java.io.UncheckedIOException; |
| 43 | +import java.net.InetSocketAddress; |
| 44 | +import java.nio.charset.StandardCharsets; |
| 45 | +import java.nio.file.Files; |
| 46 | +import java.nio.file.Paths; |
| 47 | +import java.util.Arrays; |
| 48 | +import java.util.Collections; |
| 49 | +import java.util.HashMap; |
| 50 | +import java.util.Iterator; |
| 51 | +import java.util.LinkedHashMap; |
| 52 | +import java.util.List; |
| 53 | +import java.util.Map; |
| 54 | +import java.util.Random; |
| 55 | +import java.util.Spliterator; |
| 56 | +import java.util.Spliterators; |
| 57 | +import java.util.StringJoiner; |
| 58 | +import java.util.concurrent.ThreadLocalRandom; |
| 59 | +import java.util.concurrent.TimeUnit; |
| 60 | +import java.util.concurrent.atomic.AtomicInteger; |
| 61 | +import java.util.concurrent.locks.LockSupport; |
| 62 | +import java.util.regex.Matcher; |
| 63 | +import java.util.regex.Pattern; |
| 64 | +import java.util.stream.Collectors; |
| 65 | +import java.util.stream.Stream; |
| 66 | +import java.util.stream.StreamSupport; |
| 67 | + |
| 68 | +import static com.google.common.math.Quantiles.percentiles; |
| 69 | +import static java.util.stream.Collectors.toList; |
| 70 | + |
| 71 | +/** |
| 72 | + * User: Artem |
| 73 | + * Date: 28.12.2017 |
| 74 | + */ |
| 75 | +public class LentaBenchStand { |
| 76 | + private static final Logger LOG = LoggerFactory.getLogger(LentaBenchStand.class); |
| 77 | + |
| 78 | + private static Stream<TextDocument> documents(String path) throws IOException { |
| 79 | + final BufferedReader reader = new BufferedReader(new InputStreamReader( |
| 80 | + new FileInputStream(path), |
| 81 | + StandardCharsets.UTF_8 |
| 82 | + )); |
| 83 | + final CSVParser csvFileParser = new CSVParser(reader, CSVFormat.DEFAULT); |
| 84 | + { // skip headers |
| 85 | + Iterator<CSVRecord> iter = csvFileParser.iterator(); |
| 86 | + iter.next(); |
| 87 | + } |
| 88 | + |
| 89 | + final Spliterator<CSVRecord> csvSpliterator = Spliterators.spliteratorUnknownSize( |
| 90 | + csvFileParser.iterator(), |
| 91 | + Spliterator.IMMUTABLE |
| 92 | + ); |
| 93 | + AtomicInteger counter = new AtomicInteger(0); |
| 94 | + return StreamSupport.stream(csvSpliterator, false).map(r -> { |
| 95 | + Pattern p = Pattern.compile("\\w+", Pattern.UNICODE_CHARACTER_CLASS); |
| 96 | + String recordText = r.get(2); // text order |
| 97 | + Matcher m = p.matcher(recordText); |
| 98 | + StringJoiner text = new StringJoiner(" "); |
| 99 | + while (m.find()) { |
| 100 | + text.add(m.group()); |
| 101 | + } |
| 102 | + return new TextDocument( |
| 103 | + r.get(0), // url order |
| 104 | + text.toString().toLowerCase(), |
| 105 | + String.valueOf(ThreadLocalRandom.current().nextInt(0, 200)), |
| 106 | + counter.incrementAndGet(), |
| 107 | + r.get(4) // label |
| 108 | + ); |
| 109 | + }); |
| 110 | + |
| 111 | + } |
| 112 | + |
| 113 | + public static class StandConfig { |
| 114 | + private final String wikiDumpPath; |
| 115 | + private final String validatorClass; |
| 116 | + private final int sleepBetweenDocsMs; |
| 117 | + private final String cntVectorizerPath; |
| 118 | + private final String topicsPath; |
| 119 | + |
| 120 | + StandConfig(Config config) { |
| 121 | + wikiDumpPath = config.getString("lenta-dump-path"); |
| 122 | + cntVectorizerPath = config.getString("cnt-vectorizer-path"); |
| 123 | + topicsPath = config.getString("topics-path"); |
| 124 | + validatorClass = config.getString("validator"); |
| 125 | + sleepBetweenDocsMs = config.getInt("sleep-between-docs-ms"); |
| 126 | + } |
| 127 | + |
| 128 | + int sleepBetweenDocsMs() { |
| 129 | + return this.sleepBetweenDocsMs; |
| 130 | + } |
| 131 | + |
| 132 | + String wikiDumpPath() { |
| 133 | + return wikiDumpPath; |
| 134 | + } |
| 135 | + |
| 136 | + String validatorClass() { |
| 137 | + return validatorClass; |
| 138 | + } |
| 139 | + } |
| 140 | + |
| 141 | + public static void main(String[] args) throws Exception { |
| 142 | + final Config benchConfig; |
| 143 | + final Config deployerConfig; |
| 144 | + //noinspection Duplicates |
| 145 | + if (args.length == 2) { |
| 146 | + benchConfig = ConfigFactory.parseReader(Files.newBufferedReader(Paths.get(args[0]))).getConfig("benchmark"); |
| 147 | + deployerConfig = ConfigFactory.parseReader(Files.newBufferedReader(Paths.get(args[1]))).getConfig("deployer"); |
| 148 | + } else { |
| 149 | + benchConfig = ConfigFactory.load("bench.conf").getConfig("benchmark"); |
| 150 | + deployerConfig = ConfigFactory.load("deployer.conf").getConfig("deployer"); |
| 151 | + } |
| 152 | + final StandConfig standConfig = new StandConfig(benchConfig); |
| 153 | + final Vectorizer vectorizer = new CountVectorizer(standConfig.cntVectorizerPath); |
| 154 | + //noinspection unchecked |
| 155 | + final BenchValidator<Prediction> validator = ((Class<? extends BenchValidator>) Class.forName(standConfig.validatorClass())) |
| 156 | + .newInstance(); |
| 157 | + final AwaitCountConsumer awaitConsumer = new AwaitCountConsumer(validator.expectedOutputSize()); |
| 158 | + final Map<Integer, LatencyMeasurer> latenciesMap = Collections.synchronizedMap(new LinkedHashMap<>()); |
| 159 | + |
| 160 | + final Map<String, String> props = new HashMap<>(); |
| 161 | + final String[] localAddressHostAndPort = System.getenv("LOCAL_ADDRESS").split(":"); |
| 162 | + final InetSocketAddress socketAddress = |
| 163 | + new InetSocketAddress(localAddressHostAndPort[0], Integer.parseInt(localAddressHostAndPort[1])); |
| 164 | + props.put("akka.remote.artery.canonical.hostname", socketAddress.getHostString()); |
| 165 | + props.put("akka.remote.artery.canonical.port", String.valueOf(socketAddress.getPort())); |
| 166 | + props.put("akka.remote.artery.bind.hostname", "0.0.0.0"); |
| 167 | + props.put("akka.remote.artery.bind.port", String.valueOf(socketAddress.getPort())); |
| 168 | + try { |
| 169 | + final File shm = new File(("/dev/shm")); |
| 170 | + if (shm.exists() && shm.isDirectory()) { |
| 171 | + final String aeronDir = "/dev/shm/aeron-bench"; |
| 172 | + FileUtils.deleteDirectory(new File(aeronDir)); |
| 173 | + props.put("akka.remote.artery.advanced.aeron-dir", aeronDir); |
| 174 | + } |
| 175 | + } catch (IOException e) { |
| 176 | + throw new UncheckedIOException(e); |
| 177 | + } |
| 178 | + final ActorSystem system = ActorSystem.create( |
| 179 | + "lentaTfIdf", |
| 180 | + ConfigFactory.parseMap(props).withFallback(ConfigFactory.load("remote")) |
| 181 | + ); |
| 182 | + |
| 183 | + final FlameRuntime runtime; |
| 184 | + if (deployerConfig.hasPath("local")) { |
| 185 | + runtime = new LocalRuntime.Builder().parallelism(deployerConfig.getConfig("local").getInt("parallelism")).build(); |
| 186 | + } else if (deployerConfig.hasPath("local-cluster")) { |
| 187 | + runtime = new LocalClusterRuntime( |
| 188 | + deployerConfig.getConfig("local-cluster").getInt("parallelism"), |
| 189 | + new SystemConfig.Builder().millisBetweenCommits(1000).maxElementsInGraph(100).build() |
| 190 | + ); |
| 191 | + } else { |
| 192 | + // temporary solution to keep bench stand in working state |
| 193 | + final String zkString = deployerConfig.getConfig("remote").getString("zk"); |
| 194 | + final CuratorFramework curator = CuratorFrameworkFactory.newClient( |
| 195 | + zkString, |
| 196 | + new ExponentialBackoffRetry(1000, 3) |
| 197 | + ); |
| 198 | + curator.start(); |
| 199 | + try { |
| 200 | + final ZookeeperWorkersNode zookeeperWorkersNode = new ZookeeperWorkersNode(curator, "/workers"); |
| 201 | + runtime = new RemoteRuntime( |
| 202 | + curator, |
| 203 | + new KryoSerializer(), |
| 204 | + ClusterConfig.fromWorkers(zookeeperWorkersNode.workers()) |
| 205 | + ); |
| 206 | + } catch (Exception e) { |
| 207 | + throw new RuntimeException(e); |
| 208 | + } |
| 209 | + } |
| 210 | + |
| 211 | + final int skip = 500; |
| 212 | + try (final FlameRuntime.Flame flame = runtime.run(new TextClassifierGraph( |
| 213 | + vectorizer, |
| 214 | + FTRLProximal.builder() |
| 215 | + .alpha(132) |
| 216 | + .beta(0.1) |
| 217 | + .lambda1(0.5) |
| 218 | + .lambda2(0.095) |
| 219 | + .build(TextUtils.readTopics(standConfig.topicsPath)) |
| 220 | + ).get())) { |
| 221 | + flame.attachRear("tfidfRear", new AkkaRearType<>(system, Prediction.class)) |
| 222 | + .forEach(r -> r.addListener(prediction -> { |
| 223 | + final int docId = prediction.tfIdf().number(); |
| 224 | + latenciesMap.get(docId).finish(); |
| 225 | + validator.accept(prediction); |
| 226 | + awaitConsumer.accept(prediction); |
| 227 | + LOG.info("Progress: {}/{}", awaitConsumer.got(), awaitConsumer.expected()); |
| 228 | + })); |
| 229 | + final List<AkkaFront.FrontHandle<TextDocument>> handles = flame |
| 230 | + .attachFront("tfidfFront", new AkkaFrontType<TextDocument>(system, false)) |
| 231 | + .collect(toList()); |
| 232 | + |
| 233 | + final AkkaFront.FrontHandle<TextDocument> front = handles.get(0); |
| 234 | + for (int i = 1; i < handles.size(); i++) { |
| 235 | + handles.get(i).unregister(); |
| 236 | + } |
| 237 | + |
| 238 | + final int sleepBetweenDocsMs = standConfig.sleepBetweenDocsMs(); |
| 239 | + LOG.info("Sleep between docs: {}", sleepBetweenDocsMs); |
| 240 | + |
| 241 | + final long[] start = {-1}; |
| 242 | + final int[] skipped = {0}; |
| 243 | + documents(standConfig.wikiDumpPath()).forEach(textDocument -> { |
| 244 | + front.accept(textDocument); |
| 245 | + latenciesMap.put(textDocument.number(), new LatencyMeasurer()); |
| 246 | + skipped[0] += 1; |
| 247 | + if (skipped[0] > skip && start[0] == -1) { |
| 248 | + start[0] = System.currentTimeMillis(); |
| 249 | + } |
| 250 | + LockSupport.parkNanos(TimeUnit.MILLISECONDS.toNanos(sleepBetweenDocsMs)); |
| 251 | + }); |
| 252 | + awaitConsumer.await(60, TimeUnit.MINUTES); |
| 253 | + final double timeInSec = (System.currentTimeMillis() - start[0]) / 1000.0; |
| 254 | + LOG.info("Throughput (items/sec): {}", (validator.expectedOutputSize() - skip) / timeInSec); |
| 255 | + } |
| 256 | + |
| 257 | + final long[] latencies = latenciesMap.values().stream().mapToLong(l -> l.statistics().getMax()).toArray(); |
| 258 | + try (final PrintWriter pw = new PrintWriter(Files.newBufferedWriter(Paths.get("/tmp/lat.data")))) { |
| 259 | + final String joined = Arrays.stream(latencies).mapToObj(Long::toString).collect(Collectors.joining(", ")); |
| 260 | + pw.println(joined); |
| 261 | + LOG.info("Result: {}", Arrays.toString(latencies)); |
| 262 | + final List<Integer> collect = Seq.seq(documents(standConfig.wikiDumpPath())) |
| 263 | + .limit(validator.inputLimit()) |
| 264 | + .shuffle(new Random(1)) |
| 265 | + .map(p -> p.content().length()) |
| 266 | + .collect(Collectors.toList()); |
| 267 | + LOG.info("Page sizes: {}", collect); |
| 268 | + final long[] skipped = latenciesMap.values() |
| 269 | + .stream() |
| 270 | + .skip(skip) |
| 271 | + .mapToLong(l -> l.statistics().getMax()) |
| 272 | + .toArray(); |
| 273 | + LOG.info("Median: {}", (long) percentiles().index(50).compute(skipped)); |
| 274 | + LOG.info("75%: {}", ((long) percentiles().index(75).compute(skipped))); |
| 275 | + LOG.info("90%: {}", (long) percentiles().index(90).compute(skipped)); |
| 276 | + LOG.info("99%: {}", (long) percentiles().index(99).compute(skipped)); |
| 277 | + } catch (IOException e) { |
| 278 | + throw new UncheckedIOException(e); |
| 279 | + } |
| 280 | + |
| 281 | + System.exit(0); |
| 282 | + } |
| 283 | +} |
0 commit comments