diff --git a/dataset/dataset.py b/dataset/dataset.py index 0e5dc6068..817eb5671 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -280,12 +280,21 @@ def get_domain_info(self): def get_inferred_values(self): tic = time.clock() - query = "SELECT t1._tid_, t1.attribute, domain[inferred_assignment + 1] as rv_value " \ - "FROM " \ - "(SELECT _tid_, attribute, " \ - "_vid_, init_value, string_to_array(regexp_replace(domain, \'[{\"\"}]\', \'\', \'gi\'), \'|||\') as domain " \ - "FROM %s) as t1, %s as t2 " \ - "WHERE t1._vid_ = t2._vid_"%(AuxTables.cell_domain.name, AuxTables.inf_values_idx.name) + query = """ + SELECT t1._tid_, + t1.attribute, + domain[inferred_assignment + 1] AS rv_value + FROM ( + SELECT _tid_, + attribute, + _vid_, + current_value, + string_to_array(regexp_replace(domain, \'[{{\"\"}}]\', \'\', \'gi\'), \'|||\') AS domain + FROM {cell_domain}) AS t1, + {inf_values_idx} AS t2 + WHERE t1._vid_ = t2._vid_ + """.format(cell_domain=AuxTables.cell_domain.name, + inf_values_idx=AuxTables.inf_values_idx.name) self.generate_aux_table_sql(AuxTables.inf_values_dom, query, index_attrs=['_tid_']) self.aux_table[AuxTables.inf_values_dom].create_db_index(self.engine, ['attribute']) status = "DONE collecting the inferred values." diff --git a/domain/domain.py b/domain/domain.py index d45b8d97d..17fbe9380 100644 --- a/domain/domain.py +++ b/domain/domain.py @@ -7,7 +7,6 @@ from dataset import AuxTables - class DomainEngine: def __init__(self, env, dataset, cor_strength = 0.1, sampling_prob=0.3, max_sample=5): """ @@ -74,21 +73,33 @@ def store_domains(self, domain): pos_values schema: _tid_: entity/tuple ID _cid_: cell ID - _vid_: random variable ID (all cells with more than 1 domain value) - _ - + _vid_: random variable ID (1-1 with _cid_) + attribute: name of attribute + rv_val: cell value + val_id: domain index of rv_val """ if domain.empty: raise Exception("ERROR: Generated domain is empty.") - else: - self.ds.generate_aux_table(AuxTables.cell_domain, domain, store=True, index_attrs=['_vid_']) - self.ds.aux_table[AuxTables.cell_domain].create_db_index(self.ds.engine, ['_tid_']) - self.ds.aux_table[AuxTables.cell_domain].create_db_index(self.ds.engine, ['_cid_']) - query = "SELECT _vid_, _cid_, _tid_, attribute, a.rv_val, a.val_id from %s , unnest(string_to_array(regexp_replace(domain,\'[{\"\"}]\',\'\',\'gi\'),\'|||\')) WITH ORDINALITY a(rv_val,val_id)" % AuxTables.cell_domain.name - self.ds.generate_aux_table_sql(AuxTables.pos_values, query, index_attrs=['_tid_', 'attribute']) + + self.ds.generate_aux_table(AuxTables.cell_domain, domain, store=True, index_attrs=['_vid_']) + self.ds.aux_table[AuxTables.cell_domain].create_db_index(self.ds.engine, ['_tid_']) + self.ds.aux_table[AuxTables.cell_domain].create_db_index(self.ds.engine, ['_cid_']) + query = """ + SELECT + _vid_, + _cid_, + _tid_, + attribute, + a.rv_val, + a.val_id + FROM + {cell_domain}, + unnest(string_to_array(regexp_replace(domain,\'[{{\"\"}}]\',\'\',\'gi\'),\'|||\')) WITH ORDINALITY a(rv_val,val_id) + """.format(cell_domain=AuxTables.cell_domain.name) + self.ds.generate_aux_table_sql(AuxTables.pos_values, query, index_attrs=['_tid_', 'attribute']) def setup_attributes(self): - self.active_attributes = self.get_active_attributes() + self.active_attributes = self.fetch_active_attributes() total, single_stats, pair_stats = self.ds.get_statistics() self.total = total self.single_stats = single_stats @@ -132,9 +143,9 @@ def _topk_pair_stats(self, pair_stats): out[attr1][attr2][val1] = top_cands return out - def get_active_attributes(self): + def fetch_active_attributes(self): """ - get_active_attributes returns the attributes to be modeled. + fetch_active_attributes fetches/refetches the attributes to be modeled. These attributes correspond only to attributes that contain at least one potentially erroneous cell. """ @@ -177,10 +188,13 @@ def generate_domain(self): _cid_: cell ID (unique for every entity-attribute) _vid_: variable ID (1-1 correspondence with _cid_) attribute: attribute name + attribute_idx: index of attribute domain: ||| seperated string of domain values domain_size: length of domain - init_value: initial value for this cell - init_value_idx: domain index of init_value + init_values: initial values for this cell + init_values_idx: domain indexes of init_values + current_value: current value (current predicted) + current_value_idx: domain index for current value fixed: 1 if a random sample was taken since no correlated attributes/top K values """ @@ -195,24 +209,33 @@ def generate_domain(self): for row in tqdm(list(records)): tid = row['_tid_'] app = [] + + # Iterate over each active attribute (attributes that have at + # least one dk cell) and generate for this cell: + # 1) the domain values + # 2) the initial values (taken from raw data) + # 3) the current value (best predicted value) for attr in self.active_attributes: - init_value, dom = self.get_domain_cell(attr, row) - init_value_idx = dom.index(init_value) - if len(dom) > 1: - cid = self.ds.get_cell_id(tid, attr) - app.append({"_tid_": tid, "attribute": attr, "_cid_": cid, "_vid_":vid, "domain": "|||".join(dom), "domain_size": len(dom), - "init_value": init_value, "init_index": init_value_idx, "fixed":0}) - vid += 1 - else: - add_domain = self.get_random_domain(attr,init_value) - # Check if attribute has more than one unique values - if len(add_domain) > 0: - dom.extend(self.get_random_domain(attr,init_value)) - cid = self.ds.get_cell_id(tid, attr) - app.append({"_tid_": tid, "attribute": attr, "_cid_": cid, "_vid_": vid, "domain": "|||".join(dom), - "domain_size": len(dom), - "init_value": init_value, "init_index": init_value_idx, "fixed": 1}) - vid += 1 + init_values, current_value, dom = self.get_domain_cell(attr, row) + init_values_idx = [dom.index(val) for val in init_values] + current_value_idx = dom.index(current_value) + cid = self.ds.get_cell_id(tid, attr) + fixed = 0 + + # If domain could not be generated from correlated attributes, + # randomly choose values to add to our domain. + if len(dom) == 1: + fixed = 1 + add_domain = self.get_random_domain(attr, init_values) + dom.extend(add_domain) + + app.append({"_tid_": tid, "_cid_": cid, "_vid_":vid, + "attribute": attr, "attribute_idx": self.ds.attr_to_idx[attr], + "domain": '|||'.join(dom), "domain_size": len(dom), + "init_values": '|||'.join(init_values), "init_values_idx": '|||'.join(map(str,init_values_idx)), + "current_value": current_value, "current_value_idx": current_value_idx, + "fixed": fixed}) + vid+=1 cells.extend(app) domain_df = pd.DataFrame(data=cells) logging.info('DONE generating domain') @@ -220,8 +243,8 @@ def generate_domain(self): def get_domain_cell(self, attr, row): """ - get_domain_cell returns a list of all domain values for the given - entity (row) and attribute. + get_domain_cell returns list of init values, current (best predicted) + value, and list of domain values for the given cell. We define domain values as values in 'attr' that co-occur with values in attributes ('cond_attr') that are correlated with 'attr' at least in @@ -237,10 +260,10 @@ def get_domain_cell(self, attr, row): This would produce [B,C,E] as domain values. - :return: (initial value of entity-attribute, domain values for entity-attribute). + :return: (list of initial values, current value, list of domain values). """ - domain = set([]) + domain = set() correlated_attributes = self.get_corr_attributes(attr) # Iterate through all attributes correlated at least self.cor_strength ('cond_attr') # and take the top K co-occurrence values for 'attr' with the current @@ -265,25 +288,35 @@ def get_domain_cell(self, attr, row): # Remove _nan_ if added due to correlated attributes domain.discard('_nan_') + # Add initial value in domain - if pd.isnull(row[attr]): - domain.update(set(['_nan_'])) - init_value = '_nan_' - else: - domain.update(set([row[attr]])) - init_value = row[attr] - return init_value, list(domain) + init_values = ['_nan_'] + if not pd.isnull(row[attr]): + # Assume value in raw dataset is given as ||| separate initial values + init_values = row[attr].split('|||') + domain.update(set(init_values)) - def get_random_domain(self, attr, cur_value): + # Take the first initial value as the current value + # TODO(richardwu): revisit how we should initialize 'current' + current_value = init_values[0] + + return init_values, current_value, list(domain) + + def get_random_domain(self, attr, init_values): """ get_random_domain returns a random sample of at most size - 'self.max_sample' of domain values for 'attr' that is NOT 'cur_value'. + 'self.max_sample' of domain values for :param attr: that is NOT any + of :param init_values: + + :param attr: (str) name of attribute to generate random domain for + :param init_values: (list[str]) list of initial values """ if random.random() > self.sampling_prob: return [] domain_pool = set(self.single_stats[attr].keys()) - domain_pool.discard(cur_value) + # Do not include initial values in random domain + domain_pool = domain_pool.difference(init_values) size = len(domain_pool) if size > 0: k = min(self.max_sample, size) diff --git a/evaluate/eval.py b/evaluate/eval.py index 5f1c58df2..7554020b4 100644 --- a/evaluate/eval.py +++ b/evaluate/eval.py @@ -91,26 +91,46 @@ def eval_report(self): return report, report_time, report_list def compute_total_repairs(self): - query = "SELECT count(*) FROM " \ - "(SELECT _vid_ " \ - "FROM %s as t1, %s as t2 " \ - "WHERE t1._tid_ = t2._tid_ " \ - "AND t1.attribute = t2.attribute " \ - "AND t1.init_value != t2.rv_value) AS t"\ - %(AuxTables.cell_domain.name, AuxTables.inf_values_dom.name) + query = """ + SELECT + count(*) + FROM + (SELECT + _vid_ + FROM + {cell_domain} AS t1, + {inf_values_dom} as t2 + WHERE + t1._tid_ = t2._tid_ + AND t1.attribute = t2.attribute + AND t1.current_value != t2.rv_value + ) AS t + """.format(cell_domain=AuxTables.cell_domain.name, + inf_values_dom=AuxTables.inf_values_dom.name) res = self.ds.engine.execute_query(query) self.total_repairs = float(res[0][0]) def compute_total_repairs_grdt(self): - query = "SELECT count(*) FROM " \ - "(SELECT _vid_ " \ - "FROM %s as t1, %s as t2, %s as t3 " \ - "WHERE t1._tid_ = t2._tid_ " \ - "AND t1.attribute = t2.attribute " \ - "AND t1.init_value != t2.rv_value " \ - "AND t1._tid_ = t3._tid_ " \ - "AND t1.attribute = t3._attribute_) AS t"\ - %(AuxTables.cell_domain.name, AuxTables.inf_values_dom.name, self.clean_data.name) + query = """ + SELECT + count(*) + FROM + (SELECT + _vid_ + FROM + {cell_domain} AS t1, + {inf_values_dom} AS t2, + {clean_data} AS t3 + WHERE + t1._tid_ = t2._tid_ + AND t1.attribute = t2.attribute + AND t1.current_value != t2.rv_value + AND t1._tid_ = t3._tid_ + AND t1.attribute = t3._attribute_ + ) AS t + """.format(cell_domain=AuxTables.cell_domain.name, + inf_values_dom=AuxTables.inf_values_dom.name, + clean_data=self.clean_data.name) res = self.ds.engine.execute_query(query) self.total_repairs_grdt = float(res[0][0]) @@ -139,13 +159,25 @@ def compute_total_errors_grdt(self): self.total_errors = total_errors def compute_detected_errors(self): - query = "SELECT count(*) FROM " \ - "(SELECT _vid_ " \ - "FROM %s as t1, %s as t2, %s as t3 " \ - "WHERE t1._tid_ = t2._tid_ AND t1._cid_ = t3._cid_ " \ - "AND t1.attribute = t2._attribute_ " \ - "AND t1.init_value != t2._value_) AS t" \ - % (AuxTables.cell_domain.name, self.clean_data.name, AuxTables.dk_cells.name) + query = """ + SELECT + count(*) + FROM + (SELECT + _vid_ + FROM + {cell_domain} AS t1, + {clean_data} AS t2, + {dk_cells} AS t3 + WHERE + t1._tid_ = t2._tid_ + AND t1._cid_ = t3._cid_ + AND t1.attribute = t2._attribute_ + AND t1.current_value != t2._value_ + ) AS t + """.format(cell_domain=AuxTables.cell_domain.name, + clean_data=self.clean_data.name, + dk_cells=AuxTables.dk_cells.name) res = self.ds.engine.execute_query(query) self.detected_errors = float(res[0][0]) diff --git a/examples/holoclean_repair_example.py b/examples/holoclean_repair_example.py index ee43ee55e..ef2490933 100644 --- a/examples/holoclean_repair_example.py +++ b/examples/holoclean_repair_example.py @@ -1,8 +1,8 @@ import holoclean from detect import NullDetector, ViolationDetector -from repair.featurize import InitFeaturizer -from repair.featurize import InitAttFeaturizer -from repair.featurize import InitSimFeaturizer +from repair.featurize import CurrentFeaturizer +from repair.featurize import CurrentAttrFeaturizer +from repair.featurize import CurrentSimFeaturizer from repair.featurize import FreqFeaturizer from repair.featurize import OccurFeaturizer from repair.featurize import ConstraintFeat @@ -23,7 +23,7 @@ # 4. Repair errors utilizing the defined features. hc.setup_domain() -featurizers = [InitAttFeaturizer(learnable=False), InitSimFeaturizer(), FreqFeaturizer(), OccurFeaturizer(), LangModelFeat(), ConstraintFeat()] +featurizers = [CurrentAttrFeaturizer(learnable=False), CurrentSimFeaturizer(), FreqFeaturizer(), OccurFeaturizer(), LangModelFeat(), ConstraintFeat()] hc.repair_errors(featurizers) # 5. Evaluate the correctness of the results. diff --git a/repair/featurize/__init__.py b/repair/featurize/__init__.py index 23a9d438e..fab366780 100644 --- a/repair/featurize/__init__.py +++ b/repair/featurize/__init__.py @@ -1,13 +1,13 @@ from .featurize import FeaturizedDataset from .featurizer import Featurizer -from .initfeat import InitFeaturizer -from .initsimfeat import InitSimFeaturizer -from .freqfeat import FreqFeaturizer -from .occurfeat import OccurFeaturizer from .constraintfeat import ConstraintFeat +from .currentfeat import CurrentFeaturizer +from .currentattrfeat import CurrentAttrFeaturizer +from .currentsimfeat import CurrentSimFeaturizer +from .freqfeat import FreqFeaturizer from .langmodel import LangModelFeat -from .initattfeat import InitAttFeaturizer +from .occurfeat import OccurFeaturizer from .occurattrfeat import OccurAttrFeaturizer -__all__ = ['FeaturizedDataset', 'Featurizer', 'InitFeaturizer', 'InitSimFeaturizer', 'FreqFeaturizer', - 'OccurFeaturizer', 'ConstraintFeat', 'LangModelFeat', 'InitAttFeaturizer', 'OccurAttrFeaturizer'] +__all__ = ['FeaturizedDataset', 'Featurizer', 'CurrentFeaturizer', 'CurrentSimFeaturizer', 'FreqFeaturizer', + 'OccurFeaturizer', 'ConstraintFeat', 'LangModelFeat', 'CurrentAttrFeaturizer', 'OccurAttrFeaturizer'] diff --git a/repair/featurize/initattfeat.py b/repair/featurize/currentattrfeat.py similarity index 58% rename from repair/featurize/initattfeat.py rename to repair/featurize/currentattrfeat.py index 10d0369b6..b7d0eaac4 100644 --- a/repair/featurize/initattfeat.py +++ b/repair/featurize/currentattrfeat.py @@ -7,23 +7,27 @@ def gen_feat_tensor(input, classes, total_attrs): vid = int(input[0]) attr_idx = input[1] - init_idx = int(input[2]) + current_idx = int(input[2]) tensor = -1.0*torch.ones(1,classes,total_attrs) - tensor[0][init_idx][attr_idx] = 1.0 + tensor[0][current_idx][attr_idx] = 1.0 return tensor -class InitAttFeaturizer(Featurizer): +class CurrentAttrFeaturizer(Featurizer): def specific_setup(self): - self.name = 'InitAttFeaturizer' + self.name = 'CurrentAttrFeaturizer' self.attr_to_idx = self.ds.attr_to_idx self.total_attrs = len(self.ds.attr_to_idx) def create_tensor(self): - query = 'SELECT _vid_, attribute, init_index FROM %s ORDER BY _vid_'%AuxTables.cell_domain.name + query = """ + SELECT + _vid_, + attribute_idx, + current_value_idx + FROM {cell_domain} + ORDER BY _vid_ + """.format(cell_domain=AuxTables.cell_domain.name) results = self.ds.engine.execute_query(query) - map_input = [] - for res in results: - map_input.append((res[0], self.attr_to_idx[res[1]], res[2])) - tensors = self.pool.map(partial(gen_feat_tensor, classes=self.classes, total_attrs=self.total_attrs), map_input) + tensors = self.pool.map(partial(gen_feat_tensor, classes=self.classes, total_attrs=self.total_attrs), results) combined = torch.cat(tensors) return combined diff --git a/repair/featurize/initfeat.py b/repair/featurize/currentfeat.py similarity index 59% rename from repair/featurize/initfeat.py rename to repair/featurize/currentfeat.py index e21be4989..72d610356 100644 --- a/repair/featurize/initfeat.py +++ b/repair/featurize/currentfeat.py @@ -7,18 +7,24 @@ def gen_feat_tensor(input, classes): vid = int(input[0]) - init_idx = int(input[1]) + current_idx = int(input[1]) tensor = -1.0*torch.ones(1,classes,1) - tensor[0][init_idx][0] = 1.0 + tensor[0][current_idx][0] = 1.0 return tensor -class InitFeaturizer(Featurizer): +class CurrentFeaturizer(Featurizer): def specific_setup(self): - self.name = 'InitFeaturizer' + self.name = 'CurrentFeaturizer' def create_tensor(self): - query = 'SELECT _vid_, init_index FROM %s ORDER BY _vid_'%AuxTables.cell_domain.name + query = """ + SELECT + _vid_, + current_value_idx + FROM {cell_domain} + ORDER BY _vid_ + """.format(cell_domain=AuxTables.cell_domain.name) results = self.ds.engine.execute_query(query) tensors = self.pool.map(partial(gen_feat_tensor, classes=self.classes), results) combined = torch.cat(tensors) diff --git a/repair/featurize/initsimfeat.py b/repair/featurize/currentsimfeat.py similarity index 58% rename from repair/featurize/initsimfeat.py rename to repair/featurize/currentsimfeat.py index a29529621..59c2a06a1 100644 --- a/repair/featurize/initsimfeat.py +++ b/repair/featurize/currentsimfeat.py @@ -9,31 +9,37 @@ def gen_feat_tensor(input, classes, total_attrs): vid = int(input[0]) attr_idx = input[1] - init_value = input[2] + current_value = input[2] + domain = input[3].split('|||') # TODO: To add more similarity metrics increase the last dimension of tensor. tensor = torch.zeros(1, classes, total_attrs) - domain = input[2].split('|||') for idx, val in enumerate(domain): - if val == init_value: + if val == current_value: sim = -1.0 else: - sim = 2*Levenshtein.ratio(val, init_value) - 1 + sim = 2*Levenshtein.ratio(val, current_value) - 1 tensor[0][idx][attr_idx] = sim return tensor -class InitSimFeaturizer(Featurizer): +class CurrentSimFeaturizer(Featurizer): def specific_setup(self): - self.name = 'InitSimFeaturizer' + self.name = 'CurrentSimFeaturizer' self.attr_to_idx = self.ds.attr_to_idx self.total_attrs = len(self.ds.attr_to_idx) def create_tensor(self): - query = 'SELECT _vid_, attribute, init_value, domain FROM %s ORDER BY _vid_'%AuxTables.cell_domain.name + query = """ + SELECT + _vid_, + attribute_idx, + current_value, + domain + FROM {cell_domain} + ORDER BY _vid_ + """.format(cell_domain=AuxTables.cell_domain.name) results = self.ds.engine.execute_query(query) - map_input = [] - for res in results: - map_input.append((res[0],self.attr_to_idx[res[1]],res[2])) - tensors = self.pool.map(partial(gen_feat_tensor, classes=self.classes, total_attrs=self.total_attrs), map_input) + # Map attribute to their attribute indexes + tensors = self.pool.map(partial(gen_feat_tensor, classes=self.classes, total_attrs=self.total_attrs), results) combined = torch.cat(tensors) return combined diff --git a/repair/featurize/featurize.py b/repair/featurize/featurize.py index 45b3f0310..441e0e73f 100644 --- a/repair/featurize/featurize.py +++ b/repair/featurize/featurize.py @@ -6,6 +6,7 @@ FeatInfo = namedtuple('FeatInfo', ['name', 'size', 'learnable', 'init_weight']) + class FeaturizedDataset: def __init__(self, dataset, env, featurizers): self.ds = dataset @@ -26,16 +27,27 @@ def __init__(self, dataset, env, featurizers): def generate_weak_labels(self): """ generate_weak_labels returns a tensor where for each VID we have the - domain index of the initial value. + domain index of the current value. :return: Torch.Tensor of size (# of variables) X 1 where tensor[i][0] - contains the domain index of the initial value for the i-th + contains the domain index of the current value for the i-th variable/VID. """ logging.debug("Generating weak labels.") - query = 'SELECT _vid_, init_index FROM %s AS t1 LEFT JOIN %s AS t2 ' \ - 'ON t1._cid_ = t2._cid_ WHERE t2._cid_ is NULL OR t1.fixed = 1;' % ( - AuxTables.cell_domain.name, AuxTables.dk_cells.name) + query = """ + SELECT + _vid_, + current_value_idx + FROM + {cell_domain} AS t1 + LEFT JOIN + {dk_cells} AS t2 + ON t1._cid_ = t2._cid_ + WHERE + t2._cid_ is NULL + OR t1.fixed = 1 + """.format(cell_domain=AuxTables.cell_domain.name, + dk_cells=AuxTables.dk_cells.name) res = self.ds.engine.execute_query(query) if len(res) == 0: raise Exception("No weak labels available. Reduce pruning threshold.") @@ -82,10 +94,10 @@ def get_training_data(self): get_training_data returns X_train, y_train, and mask_train where each row of each tensor is a variable/VID and y_train are weak labels for each variable i.e. they are - set as the initial values. + set as the current value. - This assumes that we have a larger proportion of correct initial values - and only a small amount of incorrect initial values which allow us + This assumes that we have a larger proportion of correct current values + and only a small amount of incorrect current values which allow us to train to convergence. """ train_idx = (self.weak_labels != -1).nonzero()[:,0] diff --git a/repair/featurize/occurfeat.py b/repair/featurize/occurfeat.py index 0233c00bc..d2e038acc 100644 --- a/repair/featurize/occurfeat.py +++ b/repair/featurize/occurfeat.py @@ -23,7 +23,7 @@ def setup_stats(self): """ Memoize single (frequency of attribute-value) and pairwise stats (frequency of attr1-value1-attr2-value2) - from Dataset. + for the current values from loaded dataset. self.single_stats is a dict { attribute -> { value -> count } }. self.pair_stats is a dict { attr1 -> { attr2 -> { val1 -> {val2 -> co-occur frequency } } } }. @@ -38,12 +38,12 @@ def setup_stats(self): def create_tensor(self): """ For each unique VID (cell) returns the co-occurrence probability between - each possible domain value for this VID and the initial/raw values for the + each possible domain value for this VID and the current value for the corresponding entity/tuple of this cell. :return: Torch.Tensor of shape (# of VIDs) X (max domain) X (# of attributes) where tensor[i][j][k] contains the co-occur probability between the j-th domain value - of the i-th random variable (VID) and the initial/raw value of the k-th + of the i-th random variable (VID) and the current value of the k-th attribute for the corresponding entity. """ # Iterate over tuples in domain diff --git a/tests/test_holoclean_repair.py b/tests/test_holoclean_repair.py index 9f2259647..5bdef72de 100644 --- a/tests/test_holoclean_repair.py +++ b/tests/test_holoclean_repair.py @@ -2,9 +2,9 @@ import holoclean from detect import NullDetector, ViolationDetector -from repair.featurize import InitFeaturizer -from repair.featurize import InitAttFeaturizer -from repair.featurize import InitSimFeaturizer +from repair.featurize import CurrentFeaturizer +from repair.featurize import CurrentAttrFeaturizer +from repair.featurize import CurrentSimFeaturizer from repair.featurize import FreqFeaturizer from repair.featurize import OccurFeaturizer from repair.featurize import ConstraintFeat @@ -27,7 +27,7 @@ def test_hospital(self): # 4. Repair errors utilizing the defined features. hc.setup_domain() - featurizers = [InitAttFeaturizer(), InitSimFeaturizer(), FreqFeaturizer(), OccurFeaturizer(), LangModelFeat(), ConstraintFeat()] + featurizers = [CurrentAttrFeaturizer(), CurrentSimFeaturizer(), FreqFeaturizer(), OccurFeaturizer(), LangModelFeat(), ConstraintFeat()] hc.repair_errors(featurizers) # 5. Evaluate the correctness of the results.