Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,21 @@ def get_domain_info(self):

def get_inferred_values(self):
tic = time.clock()
query = "SELECT t1._tid_, t1.attribute, domain[inferred_assignment + 1] as rv_value " \
"FROM " \
"(SELECT _tid_, attribute, " \
"_vid_, init_value, string_to_array(regexp_replace(domain, \'[{\"\"}]\', \'\', \'gi\'), \'|||\') as domain " \
"FROM %s) as t1, %s as t2 " \
"WHERE t1._vid_ = t2._vid_"%(AuxTables.cell_domain.name, AuxTables.inf_values_idx.name)
query = """
SELECT t1._tid_,
t1.attribute,
domain[inferred_assignment + 1] AS rv_value
FROM (
SELECT _tid_,
attribute,
_vid_,
current_value,
string_to_array(regexp_replace(domain, \'[{{\"\"}}]\', \'\', \'gi\'), \'|||\') AS domain
FROM {cell_domain}) AS t1,
{inf_values_idx} AS t2
WHERE t1._vid_ = t2._vid_
""".format(cell_domain=AuxTables.cell_domain.name,
inf_values_idx=AuxTables.inf_values_idx.name)
self.generate_aux_table_sql(AuxTables.inf_values_dom, query, index_attrs=['_tid_'])
self.aux_table[AuxTables.inf_values_dom].create_db_index(self.engine, ['attribute'])
status = "DONE collecting the inferred values."
Expand Down
125 changes: 79 additions & 46 deletions domain/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from dataset import AuxTables


class DomainEngine:
def __init__(self, env, dataset, cor_strength = 0.1, sampling_prob=0.3, max_sample=5):
"""
Expand Down Expand Up @@ -74,21 +73,33 @@ def store_domains(self, domain):
pos_values schema:
_tid_: entity/tuple ID
_cid_: cell ID
_vid_: random variable ID (all cells with more than 1 domain value)
_

_vid_: random variable ID (1-1 with _cid_)
attribute: name of attribute
rv_val: cell value
val_id: domain index of rv_val
"""
if domain.empty:
raise Exception("ERROR: Generated domain is empty.")
else:
self.ds.generate_aux_table(AuxTables.cell_domain, domain, store=True, index_attrs=['_vid_'])
self.ds.aux_table[AuxTables.cell_domain].create_db_index(self.ds.engine, ['_tid_'])
self.ds.aux_table[AuxTables.cell_domain].create_db_index(self.ds.engine, ['_cid_'])
query = "SELECT _vid_, _cid_, _tid_, attribute, a.rv_val, a.val_id from %s , unnest(string_to_array(regexp_replace(domain,\'[{\"\"}]\',\'\',\'gi\'),\'|||\')) WITH ORDINALITY a(rv_val,val_id)" % AuxTables.cell_domain.name
self.ds.generate_aux_table_sql(AuxTables.pos_values, query, index_attrs=['_tid_', 'attribute'])

self.ds.generate_aux_table(AuxTables.cell_domain, domain, store=True, index_attrs=['_vid_'])
self.ds.aux_table[AuxTables.cell_domain].create_db_index(self.ds.engine, ['_tid_'])
self.ds.aux_table[AuxTables.cell_domain].create_db_index(self.ds.engine, ['_cid_'])
query = """
SELECT
_vid_,
_cid_,
_tid_,
attribute,
a.rv_val,
a.val_id
FROM
{cell_domain},
unnest(string_to_array(regexp_replace(domain,\'[{{\"\"}}]\',\'\',\'gi\'),\'|||\')) WITH ORDINALITY a(rv_val,val_id)
""".format(cell_domain=AuxTables.cell_domain.name)
self.ds.generate_aux_table_sql(AuxTables.pos_values, query, index_attrs=['_tid_', 'attribute'])

def setup_attributes(self):
self.active_attributes = self.get_active_attributes()
self.active_attributes = self.fetch_active_attributes()
total, single_stats, pair_stats = self.ds.get_statistics()
self.total = total
self.single_stats = single_stats
Expand Down Expand Up @@ -132,9 +143,9 @@ def _topk_pair_stats(self, pair_stats):
out[attr1][attr2][val1] = top_cands
return out

def get_active_attributes(self):
def fetch_active_attributes(self):
"""
get_active_attributes returns the attributes to be modeled.
fetch_active_attributes fetches/refetches the attributes to be modeled.
These attributes correspond only to attributes that contain at least
one potentially erroneous cell.
"""
Expand Down Expand Up @@ -177,10 +188,13 @@ def generate_domain(self):
_cid_: cell ID (unique for every entity-attribute)
_vid_: variable ID (1-1 correspondence with _cid_)
attribute: attribute name
attribute_idx: index of attribute
domain: ||| separated string of domain values
domain_size: length of domain
init_value: initial value for this cell
init_value_idx: domain index of init_value
init_values: initial values for this cell
init_values_idx: domain indexes of init_values
current_value: current value (current predicted)
current_value_idx: domain index for current value
fixed: 1 if a random sample was taken since no correlated attributes/top K values
"""

Expand All @@ -195,33 +209,42 @@ def generate_domain(self):
for row in tqdm(list(records)):
tid = row['_tid_']
app = []

# Iterate over each active attribute (attributes that have at
# least one dk cell) and generate for this cell:
# 1) the domain values
# 2) the initial values (taken from raw data)
# 3) the current value (best predicted value)
for attr in self.active_attributes:
init_value, dom = self.get_domain_cell(attr, row)
init_value_idx = dom.index(init_value)
if len(dom) > 1:
cid = self.ds.get_cell_id(tid, attr)
app.append({"_tid_": tid, "attribute": attr, "_cid_": cid, "_vid_":vid, "domain": "|||".join(dom), "domain_size": len(dom),
"init_value": init_value, "init_index": init_value_idx, "fixed":0})
vid += 1
else:
add_domain = self.get_random_domain(attr,init_value)
# Check if attribute has more than one unique values
if len(add_domain) > 0:
dom.extend(self.get_random_domain(attr,init_value))
cid = self.ds.get_cell_id(tid, attr)
app.append({"_tid_": tid, "attribute": attr, "_cid_": cid, "_vid_": vid, "domain": "|||".join(dom),
"domain_size": len(dom),
"init_value": init_value, "init_index": init_value_idx, "fixed": 1})
vid += 1
init_values, current_value, dom = self.get_domain_cell(attr, row)
init_values_idx = [dom.index(val) for val in init_values]
current_value_idx = dom.index(current_value)
cid = self.ds.get_cell_id(tid, attr)
fixed = 0

# If domain could not be generated from correlated attributes,
# randomly choose values to add to our domain.
if len(dom) == 1:
fixed = 1
add_domain = self.get_random_domain(attr, init_values)
dom.extend(add_domain)

app.append({"_tid_": tid, "_cid_": cid, "_vid_":vid,
"attribute": attr, "attribute_idx": self.ds.attr_to_idx[attr],
"domain": '|||'.join(dom), "domain_size": len(dom),
"init_values": '|||'.join(init_values), "init_values_idx": '|||'.join(map(str,init_values_idx)),
"current_value": current_value, "current_value_idx": current_value_idx,
"fixed": fixed})
vid+=1
cells.extend(app)
domain_df = pd.DataFrame(data=cells)
logging.info('DONE generating domain')
return domain_df

def get_domain_cell(self, attr, row):
"""
get_domain_cell returns a list of all domain values for the given
entity (row) and attribute.
get_domain_cell returns list of init values, current (best predicted)
value, and list of domain values for the given cell.

We define domain values as values in 'attr' that co-occur with values
in attributes ('cond_attr') that are correlated with 'attr' at least in
Expand All @@ -237,10 +260,10 @@ def get_domain_cell(self, attr, row):

This would produce [B,C,E] as domain values.

:return: (initial value of entity-attribute, domain values for entity-attribute).
:return: (list of initial values, current value, list of domain values).
"""

domain = set([])
domain = set()
correlated_attributes = self.get_corr_attributes(attr)
# Iterate through all attributes correlated at least self.cor_strength ('cond_attr')
# and take the top K co-occurrence values for 'attr' with the current
Expand All @@ -265,25 +288,35 @@ def get_domain_cell(self, attr, row):

# Remove _nan_ if added due to correlated attributes
domain.discard('_nan_')

# Add initial value in domain
if pd.isnull(row[attr]):
domain.update(set(['_nan_']))
init_value = '_nan_'
else:
domain.update(set([row[attr]]))
init_value = row[attr]
return init_value, list(domain)
init_values = ['_nan_']
if not pd.isnull(row[attr]):
# Assume the value in the raw dataset is given as |||-separated initial values
init_values = row[attr].split('|||')
domain.update(set(init_values))

def get_random_domain(self, attr, cur_value):
# Take the first initial value as the current value
# TODO(richardwu): revisit how we should initialize 'current'
current_value = init_values[0]

return init_values, current_value, list(domain)

def get_random_domain(self, attr, init_values):
"""
get_random_domain returns a random sample of at most size
'self.max_sample' of domain values for 'attr' that is NOT 'cur_value'.
'self.max_sample' of domain values for :param attr: that is NOT any
of :param init_values:

:param attr: (str) name of attribute to generate random domain for
:param init_values: (list[str]) list of initial values
"""

if random.random() > self.sampling_prob:
return []
domain_pool = set(self.single_stats[attr].keys())
domain_pool.discard(cur_value)
# Do not include initial values in random domain
domain_pool = domain_pool.difference(init_values)
size = len(domain_pool)
if size > 0:
k = min(self.max_sample, size)
Expand Down
78 changes: 55 additions & 23 deletions evaluate/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,26 +91,46 @@ def eval_report(self):
return report, report_time, report_list

def compute_total_repairs(self):
    """
    Count the repairs performed and store the count in
    ``self.total_repairs`` (as a float).

    A repair is a row of the cell-domain table joined with the
    inferred-values table on (_tid_, attribute) for which the cell's
    ``current_value`` differs from the inferred ``rv_value``.
    """
    # A single query definition: the comparison is made against
    # 'current_value', the column this method actually reads.
    query = """
        SELECT
            count(*)
        FROM
            (SELECT
                _vid_
            FROM
                {cell_domain} AS t1,
                {inf_values_dom} as t2
            WHERE
                t1._tid_ = t2._tid_
                AND t1.attribute = t2.attribute
                AND t1.current_value != t2.rv_value
            ) AS t
    """.format(cell_domain=AuxTables.cell_domain.name,
               inf_values_dom=AuxTables.inf_values_dom.name)
    res = self.ds.engine.execute_query(query)
    # The query yields one row with one column: the count.
    self.total_repairs = float(res[0][0])

def compute_total_repairs_grdt(self):
    """
    Count the repairs performed on cells that also appear in the clean
    (ground-truth) data and store the count in
    ``self.total_repairs_grdt`` (as a float).

    The cell-domain table is joined with the inferred-values table on
    (_tid_, attribute) and with ``self.clean_data`` on
    (_tid_, attribute = _attribute_); rows are counted when the cell's
    ``current_value`` differs from the inferred ``rv_value``.
    """
    # A single query definition: the comparison is made against
    # 'current_value', the column this method actually reads.
    query = """
        SELECT
            count(*)
        FROM
            (SELECT
                _vid_
            FROM
                {cell_domain} AS t1,
                {inf_values_dom} AS t2,
                {clean_data} AS t3
            WHERE
                t1._tid_ = t2._tid_
                AND t1.attribute = t2.attribute
                AND t1.current_value != t2.rv_value
                AND t1._tid_ = t3._tid_
                AND t1.attribute = t3._attribute_
            ) AS t
    """.format(cell_domain=AuxTables.cell_domain.name,
               inf_values_dom=AuxTables.inf_values_dom.name,
               clean_data=self.clean_data.name)
    res = self.ds.engine.execute_query(query)
    # The query yields one row with one column: the count.
    self.total_repairs_grdt = float(res[0][0])

Expand Down Expand Up @@ -139,13 +159,25 @@ def compute_total_errors_grdt(self):
self.total_errors = total_errors

def compute_detected_errors(self):
    """
    Count the detected errors and store the count in
    ``self.detected_errors`` (as a float).

    The cell-domain table is joined with ``self.clean_data`` on
    (_tid_, attribute = _attribute_) and with the dk_cells table on
    _cid_; rows are counted when the cell's ``current_value`` differs
    from the clean ``_value_``.
    """
    # A single query definition: the comparison is made against
    # 'current_value', the column this method actually reads.
    query = """
        SELECT
            count(*)
        FROM
            (SELECT
                _vid_
            FROM
                {cell_domain} AS t1,
                {clean_data} AS t2,
                {dk_cells} AS t3
            WHERE
                t1._tid_ = t2._tid_
                AND t1._cid_ = t3._cid_
                AND t1.attribute = t2._attribute_
                AND t1.current_value != t2._value_
            ) AS t
    """.format(cell_domain=AuxTables.cell_domain.name,
               clean_data=self.clean_data.name,
               dk_cells=AuxTables.dk_cells.name)
    res = self.ds.engine.execute_query(query)
    # The query yields one row with one column: the count.
    self.detected_errors = float(res[0][0])

Expand Down
8 changes: 4 additions & 4 deletions examples/holoclean_repair_example.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import holoclean
from detect import NullDetector, ViolationDetector
from repair.featurize import InitFeaturizer
from repair.featurize import InitAttFeaturizer
from repair.featurize import InitSimFeaturizer
from repair.featurize import CurrentFeaturizer
from repair.featurize import CurrentAttrFeaturizer
from repair.featurize import CurrentSimFeaturizer
from repair.featurize import FreqFeaturizer
from repair.featurize import OccurFeaturizer
from repair.featurize import ConstraintFeat
Expand All @@ -23,7 +23,7 @@

# 4. Repair errors utilizing the defined features.
hc.setup_domain()
featurizers = [InitAttFeaturizer(learnable=False), InitSimFeaturizer(), FreqFeaturizer(), OccurFeaturizer(), LangModelFeat(), ConstraintFeat()]
featurizers = [CurrentAttrFeaturizer(learnable=False), CurrentSimFeaturizer(), FreqFeaturizer(), OccurFeaturizer(), LangModelFeat(), ConstraintFeat()]
hc.repair_errors(featurizers)

# 5. Evaluate the correctness of the results.
Expand Down
14 changes: 7 additions & 7 deletions repair/featurize/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from .featurize import FeaturizedDataset
from .featurizer import Featurizer
from .initfeat import InitFeaturizer
from .initsimfeat import InitSimFeaturizer
from .freqfeat import FreqFeaturizer
from .occurfeat import OccurFeaturizer
from .constraintfeat import ConstraintFeat
from .currentfeat import CurrentFeaturizer
from .currentattrfeat import CurrentAttrFeaturizer
from .currentsimfeat import CurrentSimFeaturizer
from .freqfeat import FreqFeaturizer
from .langmodel import LangModelFeat
from .initattfeat import InitAttFeaturizer
from .occurfeat import OccurFeaturizer
from .occurattrfeat import OccurAttrFeaturizer

__all__ = ['FeaturizedDataset', 'Featurizer', 'InitFeaturizer', 'InitSimFeaturizer', 'FreqFeaturizer',
'OccurFeaturizer', 'ConstraintFeat', 'LangModelFeat', 'InitAttFeaturizer', 'OccurAttrFeaturizer']
__all__ = ['FeaturizedDataset', 'Featurizer', 'CurrentFeaturizer', 'CurrentSimFeaturizer', 'FreqFeaturizer',
'OccurFeaturizer', 'ConstraintFeat', 'LangModelFeat', 'CurrentAttrFeaturizer', 'OccurAttrFeaturizer']
Loading