Skip to content

Commit 32ae5ef

Browse files
committed
Re-compute single and co-occur stats after every EM iteration.
1 parent 7e18d2c commit 32ae5ef

5 files changed

Lines changed: 159 additions & 75 deletions

File tree

dataset/dataset.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,21 @@ def load_data(self, name, fpath, na_values=None, entity_col=None, src_col=None):
135135
def set_constraints(self, constraints):
136136
self.constraints = constraints
137137

138+
def aux_table_exists(self, aux_table):
139+
"""
140+
get_aux_table returns True if :param aux_table: has been generated.
141+
142+
:param aux_table: (AuxTables(Enum)) auxiliary table to check
143+
"""
144+
return aux_table in self.aux_tables
145+
138146
def get_aux_table(self, aux_table):
139147
"""
140148
get_aux_table returns the Table associated with :param aux_table:.
141149
142-
:param aux_table: (AuxTables(Enum)) auxiliary table to check
150+
:param aux_table: (AuxTables(Enum)) auxiliary table to retrieve
143151
"""
144-
if aux_table not in self.aux_tables:
152+
if not self.aux_table_exists(aux_table):
145153
raise Exception("{} auxiliary table has not been generated".format(aux_table))
146154
return self.aux_tables[aux_table]
147155

@@ -218,14 +226,14 @@ def get_cell_id(self, tuple_id, attr_name):
218226

219227
def get_statistics(self):
220228
if not self.stats_ready:
221-
self._collect_stats()
229+
self.collect_stats()
222230
stats = (self.total_tuples, self.single_attr_stats, self.pair_attr_stats)
223231
self.stats_ready = True
224232
return stats
225233

226-
def _collect_stats(self):
234+
def collect_stats(self):
227235
"""
228-
_collect_stats memoizes:
236+
collect_stats calculates and memoizes: (based on current statistics)
229237
1. self.single_attr_stats ({ attribute -> Series (value -> count) })
230238
the frequency (# of entities) of a given attribute-value
231239
2. self.pair_attr_stats ({ attr1 -> { attr2 -> DataFrame } } where
@@ -236,7 +244,7 @@ def _collect_stats(self):
236244
Also known as co-occurrence count.
237245
"""
238246

239-
self.total_tuples = self.get_raw_data().shape[0]
247+
self.total_tuples = self.get_raw_data()['_tid_'].nunique()
240248
# Single attribute-value frequency
241249
for attr in self.get_attributes():
242250
self.single_attr_stats[attr] = self._get_stats_single(attr)
@@ -251,8 +259,21 @@ def _get_stats_single(self, attr):
251259
"""
252260
Returns a Series indexed on possible values for 'attr' and contains the frequency.
253261
"""
254-
tmp_df = self.get_raw_data()[[attr]].groupby([attr]).size()
255-
return tmp_df
262+
263+
# If cell_domain has not been initialized yet, retrieve statistics
264+
# from raw data (this happens when the domain is just being setup)
265+
if not self.aux_table_exists(AuxTables.cell_domain):
266+
return self.get_raw_data()[[attr]].groupby([attr]).size()
267+
268+
# Retrieve statistics on current value from cell_domain
269+
270+
df_domain = self.get_aux_table(AuxTables.cell_domain).df
271+
df_count = df_domain.loc[df_domain['attribute'] == attr, 'current_value'].value_counts()
272+
# We do not store attributes with only NULL values in cell_domain:
273+
# we require _nan_ in our single stats however
274+
if df_count.empty:
275+
return pd.Series(self.total_tuples, index=['_nan_'])
276+
return df_count
256277

257278
def _get_stats_pair(self, cond_attr, trg_attr):
258279
"""
@@ -261,8 +282,26 @@ def _get_stats_pair(self, cond_attr, trg_attr):
261282
<trg_attr>: all values for trg_attr that appeared at least once with <val1> ('val2')
262283
<count>: frequency (# of entities) where cond_attr: val1 AND trg_attr: val2
263284
"""
264-
tmp_df = self.get_raw_data()[[cond_attr,trg_attr]].groupby([cond_attr,trg_attr]).size().reset_index(name="count")
265-
return tmp_df
285+
# If cell_domain has not been initialized yet, retrieve statistics
286+
# from raw data (this happens when the domain is just being setup)
287+
if not self.aux_table_exists(AuxTables.cell_domain):
288+
return self.get_raw_data()[[cond_attr,trg_attr]].groupby([cond_attr,trg_attr]).size().reset_index(name="count")
289+
290+
# Retrieve pairwise statistics on current value from cell_domain
291+
292+
df_domain = self.get_aux_table(AuxTables.cell_domain).df
293+
# Filter cell_domain for only the attributes we care about
294+
df_domain = df_domain[df_domain['attribute'].isin([cond_attr, trg_attr])]
295+
# Convert to wide form so we have our :param cond_attr:
296+
# and :trg_attr: as columns along with the _tid_ column
297+
df_domain = df_domain[['_tid_', 'attribute', 'current_value']].pivot(index='_tid_', columns='attribute', values='current_value')
298+
# We do not store cells for attributes consisting of only NULL values in cell_domain.
299+
# We require this for pair stats though.
300+
if cond_attr not in df_domain.columns:
301+
df_domain[cond_attr] = '_nan_'
302+
if trg_attr not in df_domain.columns:
303+
df_domain[trg_attr] = '_nan_'
304+
return df_domain.groupby([cond_attr, trg_attr]).size().reset_index(name="count")
266305

267306
def get_domain_info(self):
268307
"""

evaluate/eval.py

Lines changed: 85 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,18 @@
77
from dataset import AuxTables
88
from dataset.table import Table, Source
99

10-
errors_template = Template('SELECT count(*) '\
11-
'FROM $init_table as t1, $grdt_table as t2 '\
12-
'WHERE t1._tid_ = t2._tid_ '\
13-
'AND t2._attribute_ = \'$attr\' '\
14-
'AND t1."$attr" != t2._value_')
10+
errors_template = Template("""
11+
SELECT
12+
count(*)
13+
FROM
14+
$raw_table as t1
15+
INNER JOIN
16+
$clean_table as t2
17+
ON
18+
t1._tid_ = t2._tid_
19+
AND t2._attribute_ = '$attr'
20+
AND t1."$attr" != t2._value_
21+
""")
1522

1623
"""
1724
The 'errors' aliased subquery returns the (_tid_, _attribute_, _value_)
@@ -23,15 +30,30 @@
2330
We then count the number of cells that we repaired to the correct ground
2431
truth value.
2532
"""
26-
correct_repairs_template = Template('SELECT COUNT(*) FROM'\
27-
'(SELECT t2._tid_, t2._attribute_, t2._value_ '\
28-
'FROM $init_table as t1, $grdt_table as t2 '\
29-
'WHERE t1._tid_ = t2._tid_ '\
30-
'AND t2._attribute_ = \'$attr\' '\
31-
'AND t1."$attr" != t2._value_ ) as errors, $inf_dom as repairs '\
32-
'WHERE errors._tid_ = repairs._tid_ '\
33-
'AND errors._attribute_ = repairs.attribute '\
34-
'AND errors._value_ = repairs.rv_value')
33+
correct_repairs_template = Template("""
34+
SELECT
35+
count(*)
36+
FROM (
37+
SELECT
38+
t2._tid_,
39+
t2._attribute_,
40+
t2._value_
41+
FROM
42+
$raw_table AS t1
43+
INNER JOIN
44+
$clean_table AS t2
45+
ON
46+
t1._tid_ = t2._tid_
47+
AND t2._attribute_ = '$attr'
48+
AND t1."$attr" != t2._value_
49+
) AS errors
50+
INNER JOIN
51+
$inf_dom AS repairs
52+
ON
53+
errors._tid_ = repairs._tid_
54+
AND errors._attribute_ = repairs.attribute
55+
AND errors._value_ = repairs.rv_value
56+
""")
3557

3658

3759
class EvalEngine:
@@ -64,7 +86,7 @@ def load_data(self, name, fpath, tid_col, attr_col, val_col, na_values=None):
6486

6587
def evaluate_repairs(self):
6688
self.compute_total_repairs()
67-
self.compute_total_repairs_grdt()
89+
self.compute_total_repairs_clean()
6890
self.compute_total_errors()
6991
self.compute_detected_errors()
7092
self.compute_correct_repairs()
@@ -79,10 +101,10 @@ def eval_report(self):
79101
tic = time.clock()
80102
try:
81103
prec, rec, rep_recall, f1, rep_f1 = self.evaluate_repairs()
82-
report = "Precision = %.2f, Recall = %.2f, Repairing Recall = %.2f, F1 = %.2f, Repairing F1 = %.2f, Detected Errors = %d, Total Errors = %d, Correct Repairs = %d, Total Repairs = %d, Total Repairs (Grdth present) = %d" % (
83-
prec, rec, rep_recall, f1, rep_f1, self.detected_errors, self.total_errors, self.correct_repairs, self.total_repairs, self.total_repairs_grdt)
104+
report = "Precision = %.2f, Recall = %.2f, Repairing Recall = %.2f, F1 = %.2f, Repairing F1 = %.2f, Detected Errors = %d, Total Errors = %d, Correct Repairs = %d, Total Repairs = %d, Total Repairs (clean data) = %d" % (
105+
prec, rec, rep_recall, f1, rep_f1, self.detected_errors, self.total_errors, self.correct_repairs, self.total_repairs, self.total_repairs_clean)
84106
report_list = [prec, rec, rep_recall, f1, rep_f1, self.detected_errors, self.total_errors,
85-
self.correct_repairs, self.total_repairs, self.total_repairs_grdt]
107+
self.correct_repairs, self.total_repairs, self.total_repairs_clean]
86108
except Exception as e:
87109
logging.error("ERROR generating evaluation report %s" % e)
88110
raise
@@ -91,70 +113,78 @@ def eval_report(self):
91113
return report, report_time, report_list
92114

93115
def compute_total_repairs(self):
116+
"""
117+
compute_total_repairs memoizes into self.total_repairs
118+
the number of cells where the initial value differs from the inferred
119+
value (i.e. the number of repairs) for the entities in the TRAINING data.
120+
"""
94121
# TODO(richardwu): how do we define a "repair" if we have multiple
95122
# init values?
96123
query = """
97124
SELECT
98125
count(*)
99126
FROM
100-
(SELECT
101-
_vid_
102-
FROM
103-
{cell_domain} AS t1,
104-
{inf_values_dom} as t2
105-
WHERE
106-
t1._tid_ = t2._tid_
107-
AND t1.attribute = t2.attribute
108-
AND t1.init_values != t2.rv_value
109-
) AS t
127+
{cell_domain} AS t1
128+
INNER JOIN
129+
{inf_values_dom} as t2
130+
ON
131+
t1._tid_ = t2._tid_
132+
AND t1.attribute = t2.attribute
133+
WHERE
134+
t1.init_values != t2.rv_value
110135
""".format(cell_domain=AuxTables.cell_domain.name,
111136
inf_values_dom=AuxTables.inf_values_dom.name)
112137
res = self.ds.engine.execute_query(query)
113138
self.total_repairs = float(res[0][0])
114139

115-
def compute_total_repairs_grdt(self):
140+
def compute_total_repairs_clean(self):
141+
"""
142+
compute_total_repairs_clean memoizes into self.total_repairs_clean
143+
the number of cells where the initial value differs from the inferred
144+
value (i.e. the number of repairs) for the entities in the TEST (clean) data.
145+
"""
116146
# TODO(richardwu): how do we define a "repair" if we have multiple
117147
# init values?
118148
query = """
119149
SELECT
120150
count(*)
121151
FROM
122-
(SELECT
123-
_vid_
124-
FROM
125-
{cell_domain} AS t1,
126-
{inf_values_dom} AS t2,
127-
{clean_data} AS t3
128-
WHERE
129-
t1._tid_ = t2._tid_
130-
AND t1.attribute = t2.attribute
131-
AND t1.init_values != t2.rv_value
132-
AND t1._tid_ = t3._tid_
133-
AND t1.attribute = t3._attribute_
134-
) AS t
152+
{cell_domain} AS t1
153+
INNER JOIN
154+
{inf_values_dom} AS t2
155+
ON
156+
t1._tid_ = t2._tid_
157+
AND t1.attribute = t2.attribute
158+
INNER JOIN
159+
{clean_data} AS t3
160+
ON
161+
t1._tid_ = t3._tid_
162+
AND t1.attribute = t3._attribute_
163+
WHERE
164+
t1.init_values != t2.rv_value
135165
""".format(cell_domain=AuxTables.cell_domain.name,
136166
inf_values_dom=AuxTables.inf_values_dom.name,
137167
clean_data=self.clean_data.name)
138168
res = self.ds.engine.execute_query(query)
139-
self.total_repairs_grdt = float(res[0][0])
169+
self.total_repairs_clean = float(res[0][0])
140170

141171
def compute_total_errors(self):
142172
queries = []
143173
total_errors = 0.0
144174
for attr in self.ds.get_attributes():
145-
query = errors_template.substitute(init_table=self.ds.raw_data.name, grdt_table=self.clean_data.name,
175+
query = errors_template.substitute(raw_table=self.ds.raw_data.name, clean_table=self.clean_data.name,
146176
attr=attr)
147177
queries.append(query)
148178
results = self.ds.engine.execute_queries(queries)
149179
for res in results:
150180
total_errors += float(res[0][0])
151181
self.total_errors = total_errors
152182

153-
def compute_total_errors_grdt(self):
183+
def compute_total_errors_clean(self):
154184
queries = []
155185
total_errors = 0.0
156186
for attr in self.ds.get_attributes():
157-
query = errors_template.substitute(init_table=self.ds.raw_data.name, grdt_table=self.clean_data.name,
187+
query = errors_template.substitute(raw_table=self.ds.raw_data.name, clean_table=self.clean_data.name,
158188
attr=attr)
159189
queries.append(query)
160190
results = self.ds.engine.execute_queries(queries)
@@ -163,6 +193,11 @@ def compute_total_errors_grdt(self):
163193
self.total_errors = total_errors
164194

165195
def compute_detected_errors(self):
196+
"""
197+
compute_detected_errors
198+
"""
199+
# TODO(richardwu): how do we define a "repair" if we have multiple
200+
# init values?
166201
query = """
167202
SELECT
168203
count(*)
@@ -177,7 +212,7 @@ def compute_detected_errors(self):
177212
t1._tid_ = t2._tid_
178213
AND t1._cid_ = t3._cid_
179214
AND t1.attribute = t2._attribute_
180-
AND t1.current_value != t2._value_
215+
AND t1.init_values != t2._value_
181216
) AS t
182217
""".format(cell_domain=AuxTables.cell_domain.name,
183218
clean_data=self.clean_data.name,
@@ -189,7 +224,7 @@ def compute_correct_repairs(self):
189224
queries = []
190225
correct_repairs = 0.0
191226
for attr in self.ds.get_attributes():
192-
query = correct_repairs_template.substitute(init_table=self.ds.raw_data.name, grdt_table=self.clean_data.name,
227+
query = correct_repairs_template.substitute(raw_table=self.ds.raw_data.name, clean_table=self.clean_data.name,
193228
attr=attr, inf_dom=AuxTables.inf_values_dom.name)
194229
queries.append(query)
195230
results = self.ds.engine.execute_queries(queries)
@@ -208,9 +243,9 @@ def compute_repairing_recall(self):
208243
return self.correct_repairs / self.detected_errors
209244

210245
def compute_precision(self):
211-
if self.total_repairs_grdt == 0:
246+
if self.total_repairs_clean == 0:
212247
return 0
213-
return self.correct_repairs / self.total_repairs_grdt
248+
return self.correct_repairs / self.total_repairs_clean
214249

215250
def compute_f1(self):
216251
prec = self.compute_precision()

holoclean.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ def repair_errors(self, featurizers, em_iterations=1, em_iter_func=None):
272272
logging.debug('Time to retrieve featurizer weights: %.2f secs' % time)
273273
# Update current values with inferred values
274274
self.ds.update_current_values()
275+
# Re-compute statistics with new current values
276+
self.ds.collect_stats()
275277

276278
# Call em_iter_func if provided at the end of every EM iteration
277279
if em_iter_func is not None:

0 commit comments

Comments
 (0)