2626
2727import functools
2828import itertools
29+ import math
2930import os
3031import random
3132import re
33+ import time
3234
3335import gin
3436import gin .tf
3840from mesh_tensorflow .transformer import learning_rate_schedules
3941from mesh_tensorflow .transformer import transformer
4042import numpy as np
43+ import pandas as pd
4144import pkg_resources
4245import six
4346import tensorflow .compat .v1 as tf
@@ -1654,6 +1657,52 @@ def get_sequence_length(tokens, pad_id=0):
16541657 return scores
16551658
16561659
1660+ @gin .configurable
1661+ def save_scores_to_tfrecords (
1662+ results , vocabulary , scores_filename , shard_idx = 0 , save_ids_only = False ):
1663+ """Processes results from scoring examples and saves them to tfrecords files.
1664+
1665+ Args:
1666+ results: list of dictionaries containing the results for each scored
1667+ example.
1668+ vocabulary: a function that that returns a tf.data.Dataset with examples
1669+ containing the string field 'targets' and optionally the field 'inputs'
1670+ scores_filename: a string (path of file to write scores to).
1671+ shard_idx: an integer indicating the current index of the file for sharding.
1672+ save_ids_only: if true, save the ID that is prepended to the inputs.
1673+ """
1674+ results = _maybe_add_pretokenized_features (results , vocabulary )
1675+ scores = [r .get ("scores" , 0.0 ) for r in results ]
1676+ targets = [r .get ("targets_pretokenized" , r ["targets" ]) for r in results ]
1677+ inputs = [r .get ("targets_neg_pretokenized" , "" ) for r in results ]
1678+
1679+ if save_ids_only :
1680+ inputs = [r .split (" " , 1 )[0 ] for r in inputs ]
1681+
1682+ table_path = "{}_{}.tfrecord" .format (scores_filename , shard_idx )
1683+ tf .logging .info ("Saving results to {}" .format (table_path ))
1684+
1685+ with tf .io .TFRecordWriter (table_path ) as file_writer :
1686+ for input_ , target , score in zip (inputs , targets , scores ):
1687+ record_bytes = tf .train .Example (
1688+ features = tf .train .Features (
1689+ feature = {
1690+ "input" :
1691+ tf .train .Feature (
1692+ bytes_list = tf .train .BytesList (
1693+ value = [bytes (input_ , "utf8" )])),
1694+ "target" :
1695+ tf .train .Feature (
1696+ bytes_list = tf .train .BytesList (
1697+ value = [bytes (target , "utf8" )])),
1698+ "score" :
1699+ tf .train .Feature (
1700+ float_list = tf .train .FloatList (value = [score ])),
1701+ })).SerializeToString ()
1702+ file_writer .write (record_bytes )
1703+
1704+
1705+ @gin .configurable
16571706def score_with_estimator (estimator , input_fn , eval_checkpoint_step , model_dir ,
16581707 vocabulary , score_postprocess_fn = save_scores ,
16591708 num_examples = None ):
@@ -1691,6 +1740,70 @@ def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
16911740 return score_postprocess_fn (results , vocabulary )
16921741
16931742
1743+ @gin .configurable
1744+ def score_with_estimator_lazy (
1745+ estimator , input_fn , eval_checkpoint_step , model_dir ,
1746+ vocabulary , score_postprocess_fn = save_scores_to_tfrecords ,
1747+ num_examples = None , num_examples_per_shard = 10000 ):
1748+ """Score each example returned by input_fn lazily.
1749+
1750+ Args:
1751+ estimator: a TPUEstimator
1752+ input_fn: a function that that returns a tf.data.Dataset with examples
1753+ containing the string field 'targets' and optionally the field 'inputs'
1754+ eval_checkpoint_step: int, list of ints, or None, see `eval_model`
1755+ docstring.
1756+ model_dir: string, estimator model_dir
1757+ vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
1758+ targets_vocabulary) tuple
1759+ score_postprocess_fn: a function that takes in model outputs and
1760+ post-processes, saves, and returns them.
1761+ num_examples: int, the total # of examples being scored, None if unknown
1762+ num_examples_per_shard: int, the number of examples per file shard.
1763+
1764+ Returns:
1765+ a list of floats
1766+ """
1767+ if num_examples is not None :
1768+ num_shards = math .ceil (num_examples / num_examples_per_shard )
1769+ else :
1770+ num_shards = None
1771+ tf .logging .info (
1772+ "Scoring {} examples with {} shards at {} examples per shard" .format (
1773+ num_examples , num_shards , num_examples_per_shard ))
1774+
1775+ checkpoint_path , = get_checkpoint_iterator (
1776+ eval_checkpoint_step , model_dir )
1777+ result_iter = estimator .predict (input_fn , checkpoint_path = checkpoint_path )
1778+
1779+ start = time .time ()
1780+ results = []
1781+ shard_idx = 0
1782+
1783+ for i , result in enumerate (result_iter ):
1784+ results .append (result )
1785+ num_results = len (results )
1786+ exceeded_num_examples = num_examples is not None and i >= num_examples
1787+
1788+ if num_results >= num_examples_per_shard or exceeded_num_examples :
1789+ score_postprocess_fn (results , vocabulary , shard_idx = shard_idx )
1790+
1791+ elapsed = time .time () - start
1792+ tf .logging .info (
1793+ "Scored {} results in {} s, {} examples/s for shard {}" .format (
1794+ num_results , elapsed , num_results / elapsed , shard_idx ))
1795+
1796+ results = []
1797+ shard_idx += 1
1798+ start = time .time ()
1799+
1800+ if exceeded_num_examples :
1801+ break
1802+
1803+ if results :
1804+ score_postprocess_fn (results , vocabulary , shard_idx = shard_idx )
1805+
1806+
16941807def _maybe_add_pretokenized_features (examples , vocabulary ):
16951808 """Ensures decoded versions of "inputs" and "targets" exist in each example.
16961809
@@ -1712,9 +1825,17 @@ def _maybe_add_pretokenized_features(examples, vocabulary):
17121825 for example in examples :
17131826 for feature_name in ["inputs" , "targets" ]:
17141827 pretokenized_feature_name = feature_name + "_pretokenized"
1828+ neg_pretokenized_feature_name = feature_name + "_neg_pretokenized"
17151829 if feature_name in example and pretokenized_feature_name not in example :
1716- s = vocabulary [feature_name ].decode (example [feature_name ].tolist ())
1830+ ids = example [feature_name ].tolist ()
1831+
1832+ neg_ids = [abs (i ) for i in ids if i < 0 ]
1833+ ids = [i for i in ids if i > 0 ]
1834+
1835+ s = vocabulary [feature_name ].decode (ids )
17171836 example [pretokenized_feature_name ] = s
1837+ neg_s = vocabulary [feature_name ].decode (neg_ids )
1838+ example [neg_pretokenized_feature_name ] = neg_s
17181839
17191840 if not added_pretokenized [feature_name ]:
17201841 added_pretokenized [feature_name ] = True
@@ -1730,7 +1851,8 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
17301851 sequence_length , model_dir , eval_checkpoint_step ,
17311852 inputs = gin .REQUIRED , targets = gin .REQUIRED ,
17321853 score_postprocess_fn = gin .REQUIRED , eos_id = 1 ,
1733- score_eos = True ):
1854+ score_eos = True ,
1855+ score_with_estimator_fn = score_with_estimator ):
17341856 """Compute log likelihoods per example and write to a text file.
17351857
17361858 inputs & targets must either be the same length (in lines) or have inputs
@@ -1761,6 +1883,7 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
17611883 score_eos: a boolean - whether to score the final eos token of each line
17621884 If this is set to false, the scores can be interpreted as prefix
17631885 log-likelihoods
1886+ score_with_estimator_fn: a function to run scoring with the estimator.
17641887 Returns:
17651888 a list of floats
17661889 """
@@ -1806,7 +1929,7 @@ def input_fn(params):
18061929 dataset = dataset .batch (batch_size , drop_remainder = True )
18071930 return dataset .prefetch (tf .data .experimental .AUTOTUNE )
18081931
1809- return score_with_estimator (
1932+ return score_with_estimator_fn (
18101933 estimator , input_fn , eval_checkpoint_step , model_dir ,
18111934 vocabulary , score_postprocess_fn , len (targets ))
18121935
@@ -1815,7 +1938,8 @@ def input_fn(params):
18151938def score_from_dataset (estimator , vocabulary , batch_size , sequence_length ,
18161939 model_dir , eval_checkpoint_step , dataset_split ,
18171940 score_dataset_fn = None ,
1818- score_postprocess_fn = gin .REQUIRED ):
1941+ score_postprocess_fn = gin .REQUIRED ,
1942+ score_with_estimator_fn = score_with_estimator ):
18191943 """Compute log likelihoods per example and write to a text file.
18201944
18211945 The function returns a list of floats representing the log-likelihood of the
@@ -1837,6 +1961,7 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
18371961 See `eval_dataset_fn` argument to `eval_model` for details.
18381962 score_postprocess_fn: Function that takes in model outputs and
18391963 post-processes then returns then.
1964+ score_with_estimator_fn: a function to run scoring with the estimator.
18401965
18411966 Returns:
18421967 scores: a list of floats, the log likelihood scores
@@ -1850,9 +1975,9 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
18501975 input_fn = _get_combined_dataset_input_fn (
18511976 scoring_datasets , batch_size , sequence_length )
18521977
1853- return score_with_estimator (
1978+ return score_with_estimator_fn (
18541979 estimator , input_fn , eval_checkpoint_step , model_dir ,
1855- vocabulary , score_postprocess_fn , None )
1980+ vocabulary , score_postprocess_fn )
18561981
18571982
18581983def get_estimator (model_type , vocabulary , mesh_shape ,
@@ -2093,7 +2218,8 @@ def eval_model(estimator,
20932218 eval_checkpoint_step ,
20942219 eval_with_score = False ,
20952220 output_eval_examples = True ,
2096- eval_dir_suffix = None ):
2221+ eval_dir_suffix = None ,
2222+ score_with_estimator_fn = score_with_estimator ):
20972223 """Eval a Mesh-TF model.
20982224
20992225 Args:
@@ -2137,6 +2263,7 @@ def eval_model(estimator,
21372263 of the eval examples in plaintext to eval_summary_dir.
21382264 eval_dir_suffix: string, if not None then will appended to the
21392265 eval_summary_dir.
2266+ score_with_estimator_fn: a function to run scoring with the estimator.
21402267 """
21412268 if eval_dataset_fn is None :
21422269 raise ValueError ("Must provide eval_dataset_fn through gin for eval." )
@@ -2248,7 +2375,7 @@ def eval_model(estimator,
22482375 tf .logging .info ("Checkpoint path %s" % checkpoint_path )
22492376 global_step = int (get_step_from_checkpoint_path (checkpoint_path ))
22502377 if eval_with_score :
2251- outputs , _ = score_with_estimator (
2378+ outputs , _ = score_with_estimator_fn (
22522379 estimator , input_fn , global_step , model_dir , vocabulary ,
22532380 num_examples = sum (len (cex ) for cex in cached_examples .values ()))
22542381 else :
0 commit comments