From 0f8c954b1cee793185cd7457e2e354f81a9d88c6 Mon Sep 17 00:00:00 2001 From: Gaurav Gupta Date: Sat, 18 Dec 2021 23:59:53 -0800 Subject: [PATCH 01/13] Add flake8-breakpoint to avoid code checkin with active breakpoints Signed-off-by: Gaurav Gupta Signed-off-by: giandos200 --- requirements-linting.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-linting.txt b/requirements-linting.txt index 628331b8..7026726f 100644 --- a/requirements-linting.txt +++ b/requirements-linting.txt @@ -1,6 +1,7 @@ flake8==3.9.2 flake8-bugbear==21.11.29 flake8-blind-except==0.1.1 +flake8-breakpoint flake8-builtins==1.5.3 flake8-logging-format==0.6.0 flake8-nb==0.3.0 From 7874455bfd92da67b053b3ff395f119f4b237832 Mon Sep 17 00:00:00 2001 From: giandos200 Date: Thu, 13 Jan 2022 12:38:39 +0100 Subject: [PATCH 02/13] Enanchment in CF Generation Signed-off-by: giandos200 --- dice_ml/__init__.py | 2 +- dice_ml/counterfactual_explanations.py | 9 ++- .../data_interfaces/private_data_interface.py | 7 +- .../data_interfaces/public_data_interface.py | 10 ++- dice_ml/dice.py | 10 ++- dice_ml/diverse_counterfactuals.py | 7 +- dice_ml/explainer_interfaces/dice_KD.py | 13 ++-- dice_ml/explainer_interfaces/dice_genetic.py | 38 ++++++----- dice_ml/explainer_interfaces/dice_pytorch.py | 10 +-- dice_ml/explainer_interfaces/dice_random.py | 17 +++-- .../explainer_interfaces/dice_tensorflow1.py | 12 ++-- .../explainer_interfaces/dice_tensorflow2.py | 12 ++-- .../explainer_interfaces/explainer_base.py | 67 +++++++------------ .../explainer_interfaces/feasible_base_vae.py | 13 ++-- .../feasible_model_approx.py | 8 +-- dice_ml/model.py | 4 +- dice_ml/model_interfaces/base_model.py | 8 +-- .../keras_tensorflow_model.py | 5 +- dice_ml/model_interfaces/pytorch_model.py | 5 +- dice_ml/utils/helpers.py | 16 ++--- .../utils/sample_architecture/vae_model.py | 4 +- ...ing_different_CF_explanation_methods.ipynb | 4 +- 22 files changed, 132 insertions(+), 149 deletions(-) diff --git a/dice_ml/__init__.py b/dice_ml/__init__.py index 63a3dc1d..cf41c603 100644 --- a/dice_ml/__init__.py +++ b/dice_ml/__init__.py @@ -1,6 +1,6 @@ from .data import Data -from .dice import Dice from .model import Model +from .dice import Dice __all__ = ["Data", "Model", diff --git a/dice_ml/counterfactual_explanations.py b/dice_ml/counterfactual_explanations.py index ba1a95f3..2e499784 100644 --- a/dice_ml/counterfactual_explanations.py +++ b/dice_ml/counterfactual_explanations.py @@ -1,12 +1,11 @@ import json -import os - import jsonschema +import os -from dice_ml.constants import _SchemaVersions -from dice_ml.diverse_counterfactuals import (CounterfactualExamples, - _DiverseCFV2SchemaConstants) +from dice_ml.diverse_counterfactuals import CounterfactualExamples from dice_ml.utils.exception import UserConfigValidationException +from dice_ml.diverse_counterfactuals import _DiverseCFV2SchemaConstants +from dice_ml.constants import _SchemaVersions class _CommonSchemaConstants: diff --git a/dice_ml/data_interfaces/private_data_interface.py b/dice_ml/data_interfaces/private_data_interface.py index 2960c17b..f319442d 100644 --- a/dice_ml/data_interfaces/private_data_interface.py +++ b/dice_ml/data_interfaces/private_data_interface.py @@ -1,11 +1,10 @@ """Module containing meta data information about private data.""" -import collections -import logging import sys - -import numpy as np import pandas as pd +import numpy as np +import collections +import logging from dice_ml.data_interfaces.base_data_interface import _BaseData diff --git a/dice_ml/data_interfaces/public_data_interface.py b/dice_ml/data_interfaces/public_data_interface.py index ab3e5ed1..31b8e6af 100644 --- a/dice_ml/data_interfaces/public_data_interface.py +++ b/dice_ml/data_interfaces/public_data_interface.py @@ -1,15 +1,13 @@ """Module containing all required information about the interface between raw (or transformed) public data and DiCE explainers.""" +import pandas as pd +import numpy as np import logging from collections import defaultdict -import numpy as np -import pandas as pd - from dice_ml.data_interfaces.base_data_interface import _BaseData -from dice_ml.utils.exception import (SystemException, - UserConfigValidationException) +from dice_ml.utils.exception import SystemException, UserConfigValidationException class PublicData(_BaseData): @@ -260,7 +258,7 @@ def get_valid_feature_range(self, feature_range_input, normalized=True): """ feature_range = {} - for _, feature_name in enumerate(self.feature_names): + for idx, feature_name in enumerate(self.feature_names): feature_range[feature_name] = [] if feature_name in self.continuous_feature_names: max_value = self.data_df[feature_name].max() diff --git a/dice_ml/dice.py b/dice_ml/dice.py index d1c78172..961b2859 100644 --- a/dice_ml/dice.py +++ b/dice_ml/dice.py @@ -3,9 +3,9 @@ such as RandomSampling, DiCEKD or DiCEGenetic""" from dice_ml.constants import BackEndTypes, SamplingStrategy -from dice_ml.data_interfaces.private_data_interface import PrivateData -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase from dice_ml.utils.exception import UserConfigValidationException +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase +from dice_ml.data_interfaces.private_data_interface import PrivateData class Dice(ExplainerBase): @@ -67,14 +67,12 @@ def decide(model_interface, method): elif model_interface.backend == BackEndTypes.Tensorflow1: # pretrained Keras Sequential model with Tensorflow 1.x backend - from dice_ml.explainer_interfaces.dice_tensorflow1 import \ - DiceTensorFlow1 + from dice_ml.explainer_interfaces.dice_tensorflow1 import DiceTensorFlow1 return DiceTensorFlow1 elif model_interface.backend == BackEndTypes.Tensorflow2: # pretrained Keras Sequential model with Tensorflow 2.x backend - from dice_ml.explainer_interfaces.dice_tensorflow2 import \ - DiceTensorFlow2 + from dice_ml.explainer_interfaces.dice_tensorflow2 import DiceTensorFlow2 return DiceTensorFlow2 elif model_interface.backend == BackEndTypes.Pytorch: diff --git a/dice_ml/diverse_counterfactuals.py b/dice_ml/diverse_counterfactuals.py index 2dc5c044..fe8aa134 100644 --- a/dice_ml/diverse_counterfactuals.py +++ b/dice_ml/diverse_counterfactuals.py @@ -1,10 +1,8 @@ +import pandas as pd import copy import json - -import pandas as pd - -from dice_ml.constants import ModelTypes, _SchemaVersions from dice_ml.utils.serialize import DummyDataInterface +from dice_ml.constants import _SchemaVersions, ModelTypes class _DiverseCFV1SchemaConstants: @@ -117,7 +115,6 @@ def _visualize_internal(self, display_sparse_df=True, show_only_changes=False, def visualize_as_dataframe(self, display_sparse_df=True, show_only_changes=False): from IPython.display import display - # original instance print('Query instance (original outcome : %i)' % round(self.test_pred)) display(self.test_instance_df) # works only in Jupyter notebook diff --git a/dice_ml/explainer_interfaces/dice_KD.py b/dice_ml/explainer_interfaces/dice_KD.py index 532c340e..bac3bf6e 100644 --- a/dice_ml/explainer_interfaces/dice_KD.py +++ b/dice_ml/explainer_interfaces/dice_KD.py @@ -2,15 +2,14 @@ Module to generate counterfactual explanations from a KD-Tree This code is similar to 'Interpretable Counterfactual Explanations Guided by Prototypes': https://arxiv.org/pdf/1907.02584.pdf """ -import copy -import timeit - +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase import numpy as np +import timeit import pandas as pd +import copy from dice_ml import diverse_counterfactuals as exp from dice_ml.constants import ModelTypes -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DiceKD(ExplainerBase): @@ -260,10 +259,14 @@ def find_counterfactuals(self, data_df_copy, query_instance, query_instance_orig if total_cfs_found < total_CFs: self.elapsed = timeit.default_timer() - start_time m, s = divmod(self.elapsed, 60) - print('Only %d (required %d) ' % (total_cfs_found, self.total_CFs), + print('Only %d (required %d) ' % (total_cfs_found, total_CFs), 'Diverse Counterfactuals found for the given configuation, perhaps ', 'change the query instance or the features to vary...' '; total time taken: %02d' % m, 'min %02d' % s, 'sec') + elif total_cfs_found == 0: + print( + 'No Counterfactuals found for the given configuration, perhaps try with different parameters...', + '; total time taken: %02d' % m, 'min %02d' % s, 'sec') else: print('Diverse Counterfactuals found! total time taken: %02d' % m, 'min %02d' % s, 'sec') diff --git a/dice_ml/explainer_interfaces/dice_genetic.py b/dice_ml/explainer_interfaces/dice_genetic.py index ed35ee49..5928bbc2 100644 --- a/dice_ml/explainer_interfaces/dice_genetic.py +++ b/dice_ml/explainer_interfaces/dice_genetic.py @@ -2,17 +2,16 @@ Module to generate diverse counterfactual explanations based on genetic algorithm This code is similar to 'GeCo: Quality Counterfactual Explanations in Real Time': https://arxiv.org/pdf/2101.01292.pdf """ -import copy -import random -import timeit - +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase import numpy as np import pandas as pd +import random +import timeit +import copy from sklearn.preprocessing import LabelEncoder from dice_ml import diverse_counterfactuals as exp from dice_ml.constants import ModelTypes -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DiceGenetic(ExplainerBase): @@ -116,9 +115,8 @@ def do_random_init(self, num_inits, features_to_vary, query_instance, desired_cl def do_KD_init(self, features_to_vary, query_instance, cfs, desired_class, desired_range): cfs = self.label_encode(cfs) cfs = cfs.reset_index(drop=True) - - self.cfs = np.zeros((self.population_size, self.data_interface.number_of_features)) - for kx in range(self.population_size): + row = [] + for kx in range(self.population_size*5): if kx >= len(cfs): break one_init = np.zeros(self.data_interface.number_of_features) @@ -143,16 +141,21 @@ def do_KD_init(self, features_to_vary, query_instance, cfs, desired_class, desir one_init[jx] = query_instance[jx] else: one_init[jx] = np.random.choice(self.feature_range[feature]) - self.cfs[kx] = one_init + t = tuple(one_init) + if t not in row: + row.append(t) + if len(row) == self.population_size: + break kx += 1 + self.cfs = np.array(row) - new_array = [tuple(row) for row in self.cfs] - uniques = np.unique(new_array, axis=0) - - if len(uniques) != self.population_size: + #if len(self.cfs) > self.population_size: + # pass + if len(self.cfs) != self.population_size: + print("Pericolo Loop infinito....!!!!") remaining_cfs = self.do_random_init( - self.population_size - len(uniques), features_to_vary, query_instance, desired_class, desired_range) - self.cfs = np.concatenate([uniques, remaining_cfs]) + self.population_size - len(self.cfs), features_to_vary, query_instance, desired_class, desired_range) + self.cfs = np.concatenate([self.cfs, remaining_cfs]) def do_cf_initializations(self, total_CFs, initialization, algorithm, features_to_vary, desired_range, desired_class, @@ -260,7 +263,7 @@ def _generate_counterfactuals(self, query_instance, total_CFs, initialization="k (see diverse_counterfactuals.py). """ - self.population_size = 10 * total_CFs + self.population_size = 3 * total_CFs self.start_time = timeit.default_timer() @@ -514,6 +517,9 @@ def find_counterfactuals(self, query_instance, desired_range, desired_class, if len(self.final_cfs) == self.total_CFs: print('Diverse Counterfactuals found! total time taken: %02d' % m, 'min %02d' % s, 'sec') + elif len(self.final_cfs) == 0: + print('No Counterfactuals found for the given configuration, perhaps try with different parameters...', + '; total time taken: %02d' % m, 'min %02d' % s, 'sec') else: print('Only %d (required %d) ' % (len(self.final_cfs), self.total_CFs), 'Diverse Counterfactuals found for the given configuation, perhaps ', diff --git a/dice_ml/explainer_interfaces/dice_pytorch.py b/dice_ml/explainer_interfaces/dice_pytorch.py index 09d257cc..0aaa52f3 100644 --- a/dice_ml/explainer_interfaces/dice_pytorch.py +++ b/dice_ml/explainer_interfaces/dice_pytorch.py @@ -1,16 +1,16 @@ """ Module to generate diverse counterfactual explanations based on PyTorch framework """ -import copy -import random -import timeit +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase +import torch import numpy as np -import torch +import random +import timeit +import copy from dice_ml import diverse_counterfactuals as exp from dice_ml.counterfactual_explanations import CounterfactualExplanations -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DicePyTorch(ExplainerBase): diff --git a/dice_ml/explainer_interfaces/dice_random.py b/dice_ml/explainer_interfaces/dice_random.py index 2995a398..b97c8fcf 100644 --- a/dice_ml/explainer_interfaces/dice_random.py +++ b/dice_ml/explainer_interfaces/dice_random.py @@ -3,15 +3,14 @@ Module to generate diverse counterfactual explanations based on random sampling. A simple implementation. """ -import random -import timeit - +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase import numpy as np import pandas as pd +import random +import timeit from dice_ml import diverse_counterfactuals as exp from dice_ml.constants import ModelTypes -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DiceRandom(ExplainerBase): @@ -109,11 +108,17 @@ class of query_instance for binary classification. cfs_df = None candidate_cfs = pd.DataFrame( np.repeat(query_instance.values, sample_size, axis=0), columns=query_instance.columns) - # Loop to change one feature at a time, then two features, and so on. + # Loop to change one feature at a time ##->(NOT TRUE), then two features, and so on. for num_features_to_vary in range(1, len(self.features_to_vary)+1): + # commented lines allow more values to change as num_features_to_vary increases, instead of .at you should use .loc + # is deliberately left commented out to let you choose. + # is slower, but more complete and still faster than genetic/KDtree + # selected_features = np.random.choice(self.features_to_vary, (sample_size, num_features_to_vary), replace=True) selected_features = np.random.choice(self.features_to_vary, (sample_size, 1), replace=True) for k in range(sample_size): - candidate_cfs.at[k, selected_features[k][0]] = random_instances.at[k, selected_features[k][0]] + candidate_cfs.at[k, selected_features[k][0]] = random_instances._get_value(k, selected_features[k][0]) + # If you only want to change one feature, you should use _get_value + # candidate_cfs.iloc[k][selected_features[k]]=random_instances.iloc[k][selected_features[k]] scores = self.predict_fn(candidate_cfs) validity = self.decide_cf_validity(scores) if sum(validity) > 0: diff --git a/dice_ml/explainer_interfaces/dice_tensorflow1.py b/dice_ml/explainer_interfaces/dice_tensorflow1.py index 69ee8298..8ad5088b 100644 --- a/dice_ml/explainer_interfaces/dice_tensorflow1.py +++ b/dice_ml/explainer_interfaces/dice_tensorflow1.py @@ -1,17 +1,17 @@ """ Module to generate diverse counterfactual explanations based on tensorflow 1.x """ -import collections -import copy -import random -import timeit +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase +import tensorflow as tf import numpy as np -import tensorflow as tf +import random +import collections +import timeit +import copy from dice_ml import diverse_counterfactuals as exp from dice_ml.counterfactual_explanations import CounterfactualExplanations -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DiceTensorFlow1(ExplainerBase): diff --git a/dice_ml/explainer_interfaces/dice_tensorflow2.py b/dice_ml/explainer_interfaces/dice_tensorflow2.py index 58445929..b8e1ab75 100644 --- a/dice_ml/explainer_interfaces/dice_tensorflow2.py +++ b/dice_ml/explainer_interfaces/dice_tensorflow2.py @@ -1,16 +1,16 @@ """ Module to generate diverse counterfactual explanations based on tensorflow 2.x """ -import copy -import random -import timeit +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase +import tensorflow as tf import numpy as np -import tensorflow as tf +import random +import timeit +import copy from dice_ml import diverse_counterfactuals as exp from dice_ml.counterfactual_explanations import CounterfactualExplanations -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DiceTensorFlow2(ExplainerBase): @@ -177,7 +177,7 @@ def do_cf_initializations(self, total_CFs, algorithm, features_to_vary): # CF initialization if len(self.cfs) != self.total_CFs: self.cfs = [] - for _ in range(self.total_CFs): + for ix in range(self.total_CFs): one_init = [[]] for jx in range(self.minx.shape[1]): one_init[0].append(np.random.uniform(self.minx[0][jx], self.maxx[0][jx])) diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index 5f25d1cc..d1c73f4e 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -2,17 +2,17 @@ Subclasses implement interfaces for different ML frameworks such as TensorFlow or PyTorch. All methods are in dice_ml.explainer_interfaces""" +import warnings from abc import ABC, abstractmethod -from collections.abc import Iterable - import numpy as np import pandas as pd -from sklearn.neighbors import KDTree from tqdm import tqdm -from dice_ml.constants import ModelTypes +from collections.abc import Iterable +from sklearn.neighbors import KDTree from dice_ml.counterfactual_explanations import CounterfactualExplanations from dice_ml.utils.exception import UserConfigValidationException +from dice_ml.constants import ModelTypes class ExplainerBase(ABC): @@ -50,7 +50,7 @@ def generate_counterfactuals(self, query_instances, total_CFs, desired_class="opposite", desired_range=None, permitted_range=None, features_to_vary="all", stopping_threshold=0.5, posthoc_sparsity_param=0.1, - posthoc_sparsity_algorithm="linear", verbose=False, **kwargs): + posthoc_sparsity_algorithm=None, verbose=False, **kwargs): """General method for generating counterfactuals. :param query_instances: Input point(s) for which counterfactuals are to be generated. @@ -81,6 +81,16 @@ def generate_counterfactuals(self, query_instances, total_CFs, if total_CFs <= 0: raise UserConfigValidationException( "The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.") + if total_CFs > 10: + if posthoc_sparsity_algorithm == None: + posthoc_sparsity_algorithm = 'binary' + elif total_CFs >50 and posthoc_sparsity_algorithm == 'linear': + warnings.warn("The number of counterfactuals (total_CFs={}) generated per query instance could take much time; " + "if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to " + "'binary' search!".format(total_CFs)) + elif posthoc_sparsity_algorithm == None: + posthoc_sparsity_algorithm = 'linear' + cf_examples_arr = [] query_instances_list = [] if isinstance(query_instances, pd.DataFrame): @@ -88,7 +98,6 @@ def generate_counterfactuals(self, query_instances, total_CFs, query_instances_list.append(query_instances[ix:(ix+1)]) elif isinstance(query_instances, Iterable): query_instances_list = query_instances - for query_instance in tqdm(query_instances_list): self.data_interface.set_continuous_feature_indexes(query_instance) res = self._generate_counterfactuals( @@ -103,9 +112,6 @@ def generate_counterfactuals(self, query_instances, total_CFs, verbose=verbose, **kwargs) cf_examples_arr.append(res) - - self._check_any_counterfactuals_computed(cf_examples_arr=cf_examples_arr) - return CounterfactualExplanations(cf_examples_list=cf_examples_arr) @abstractmethod @@ -211,12 +217,10 @@ def local_feature_importance(self, query_instances, cf_examples_list=None, if any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]): raise UserConfigValidationException( "The number of counterfactuals generated per query instance should be " - "greater than or equal to 10 to compute feature importance for all query points") + "greater than or equal to 10") elif total_CFs < 10: - raise UserConfigValidationException( - "The number of counterfactuals requested per " - "query instance should be greater than or equal to 10 " - "to compute feature importance for all query points") + raise UserConfigValidationException("The number of counterfactuals generated per " + "query instance should be greater than or equal to 10") importances = self.feature_importance( query_instances, cf_examples_list=cf_examples_list, @@ -257,25 +261,16 @@ def global_feature_importance(self, query_instances, cf_examples_list=None, input, and the global feature importance summarized over all inputs. """ if query_instances is not None and len(query_instances) < 10: - raise UserConfigValidationException( - "The number of query instances should be greater than or equal to 10 " - "to compute global feature importance over all query points") + raise UserConfigValidationException("The number of query instances should be greater than or equal to 10") if cf_examples_list is not None: - if len(cf_examples_list) < 10: - raise UserConfigValidationException( - "The number of points for which counterfactuals generated should be " - "greater than or equal to 10 " - "to compute global feature importance") - elif any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]): + if any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]): raise UserConfigValidationException( "The number of counterfactuals generated per query instance should be " - "greater than or equal to 10 " - "to compute global feature importance over all query points") + "greater than or equal to 10") elif total_CFs < 10: raise UserConfigValidationException( - "The number of counterfactuals requested per query instance should be greater " - "than or equal to 10 " - "to compute global feature importance over all query points") + "The number of counterfactuals generated per query instance should be greater " + "than or equal to 10") importances = self.feature_importance( query_instances, cf_examples_list=cf_examples_list, @@ -354,7 +349,7 @@ def feature_importance(self, query_instances, cf_examples_list=None, continue per_query_point_cfs = 0 - for _, row in df.iterrows(): + for index, row in df.iterrows(): per_query_point_cfs += 1 for col in self.data_interface.continuous_feature_names: if not np.isclose(org_instance[col].iat[0], row[col]): @@ -535,7 +530,7 @@ def misc_init(self, stopping_threshold, desired_class, desired_range, test_pred) self.target_cf_class = np.array( [[self.infer_target_cfs_class(desired_class, test_pred, self.num_output_nodes)]], dtype=np.float32) - desired_class = int(self.target_cf_class[0][0]) + desired_class = self.target_cf_class[0][0] if self.target_cf_class == 0 and self.stopping_threshold > 0.5: self.stopping_threshold = 0.25 elif self.target_cf_class == 1 and self.stopping_threshold < 0.5: @@ -700,15 +695,3 @@ def round_to_precision(self): self.final_cfs_df[feature] = self.final_cfs_df[feature].astype(float).round(precisions[ix]) if self.final_cfs_df_sparse is not None: self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round(precisions[ix]) - - def _check_any_counterfactuals_computed(self, cf_examples_arr): - """Check if any counterfactuals were generated for any query point.""" - no_cf_generated = True - # Check if any counterfactuals were generated for any query point - for cf_examples in cf_examples_arr: - if cf_examples.final_cfs_df is not None and len(cf_examples.final_cfs_df) > 0: - no_cf_generated = False - break - if no_cf_generated: - raise UserConfigValidationException( - "No counterfactuals found for any of the query points! Kindly check your configuration.") diff --git a/dice_ml/explainer_interfaces/feasible_base_vae.py b/dice_ml/explainer_interfaces/feasible_base_vae.py index 28dc7d13..c59a26d4 100644 --- a/dice_ml/explainer_interfaces/feasible_base_vae.py +++ b/dice_ml/explainer_interfaces/feasible_base_vae.py @@ -1,15 +1,16 @@ # General Imports import numpy as np -# Pytorch -import torch -import torch.utils.data -from torch.nn import functional as F -from dice_ml import diverse_counterfactuals as exp # Dice Imports from dice_ml.explainer_interfaces.explainer_base import ExplainerBase +from dice_ml import diverse_counterfactuals as exp from dice_ml.utils.helpers import get_base_gen_cf_initialization +# Pytorch +import torch +import torch.utils.data +from torch.nn import functional as F + class FeasibleBaseVAE(ExplainerBase): @@ -186,7 +187,7 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp curr_cf_pred = [] curr_test_pred = train_y.numpy() - for _ in range(total_CFs): + for cf_count in range(total_CFs): recon_err, kl_err, x_true, x_pred, cf_label = \ self.cf_vae.compute_elbo(train_x, 1.0-train_y, self.pred_model) while(cf_label == train_y): diff --git a/dice_ml/explainer_interfaces/feasible_model_approx.py b/dice_ml/explainer_interfaces/feasible_model_approx.py index a7fffda5..01fc609a 100644 --- a/dice_ml/explainer_interfaces/feasible_model_approx.py +++ b/dice_ml/explainer_interfaces/feasible_model_approx.py @@ -1,13 +1,13 @@ # Dice Imports +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase +from dice_ml.explainer_interfaces.feasible_base_vae import FeasibleBaseVAE +from dice_ml.utils.helpers import get_base_gen_cf_initialization + # Pytorch import torch import torch.utils.data from torch.nn import functional as F -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase -from dice_ml.explainer_interfaces.feasible_base_vae import FeasibleBaseVAE -from dice_ml.utils.helpers import get_base_gen_cf_initialization - class FeasibleModelApprox(FeasibleBaseVAE, ExplainerBase): diff --git a/dice_ml/model.py b/dice_ml/model.py index 737701bf..0518cdf8 100644 --- a/dice_ml/model.py +++ b/dice_ml/model.py @@ -4,7 +4,6 @@ frameworks such as Tensorflow or PyTorch. """ import warnings - from dice_ml.constants import BackEndTypes, ModelTypes from dice_ml.utils.exception import UserConfigValidationException @@ -70,8 +69,7 @@ def decide(backend): import tensorflow # noqa: F401 except ImportError: raise UserConfigValidationException("Unable to import tensorflow. Please install tensorflow") - from dice_ml.model_interfaces.keras_tensorflow_model import \ - KerasTensorFlowModel + from dice_ml.model_interfaces.keras_tensorflow_model import KerasTensorFlowModel return KerasTensorFlowModel elif backend == BackEndTypes.Pytorch: diff --git a/dice_ml/model_interfaces/base_model.py b/dice_ml/model_interfaces/base_model.py index 3a25b5cf..09b49ddd 100644 --- a/dice_ml/model_interfaces/base_model.py +++ b/dice_ml/model_interfaces/base_model.py @@ -3,12 +3,10 @@ All model interface methods are in dice_ml.model_interfaces""" import pickle - import numpy as np - +from dice_ml.utils.helpers import DataTransfomer from dice_ml.constants import ModelTypes from dice_ml.utils.exception import SystemException -from dice_ml.utils.helpers import DataTransfomer class BaseModel: @@ -64,7 +62,7 @@ def get_num_output_nodes(self, inp_size): temp_input = np.transpose(np.array([np.random.uniform(0, 1) for i in range(inp_size)]).reshape(-1, 1)) return self.get_output(temp_input).shape[1] - def get_num_output_nodes2(self, input_instance): + def get_num_output_nodes2(self, input): if self.model_type == ModelTypes.Regressor: raise SystemException('Number of output nodes not supported for regression') - return self.get_output(input_instance).shape[1] + return self.get_output(input).shape[1] diff --git a/dice_ml/model_interfaces/keras_tensorflow_model.py b/dice_ml/model_interfaces/keras_tensorflow_model.py index df150850..72619f12 100644 --- a/dice_ml/model_interfaces/keras_tensorflow_model.py +++ b/dice_ml/model_interfaces/keras_tensorflow_model.py @@ -1,10 +1,9 @@ """Module containing an interface to trained Keras Tensorflow model.""" +from dice_ml.model_interfaces.base_model import BaseModel import tensorflow as tf from tensorflow import keras -from dice_ml.model_interfaces.base_model import BaseModel - class KerasTensorFlowModel(BaseModel): @@ -40,7 +39,7 @@ def get_output(self, input_tensor, training=False, transform_data=False): else: return self.model(input_tensor) - def get_gradient(self, input_instance): + def get_gradient(self, input): # Future Support raise NotImplementedError("Future Support") diff --git a/dice_ml/model_interfaces/pytorch_model.py b/dice_ml/model_interfaces/pytorch_model.py index 9cf577cc..e0063f5a 100644 --- a/dice_ml/model_interfaces/pytorch_model.py +++ b/dice_ml/model_interfaces/pytorch_model.py @@ -1,8 +1,7 @@ """Module containing an interface to trained PyTorch model.""" -import torch - from dice_ml.model_interfaces.base_model import BaseModel +import torch class PyTorchModel(BaseModel): @@ -38,7 +37,7 @@ def get_output(self, input_tensor, transform_data=False): def set_eval_mode(self): self.model.eval() - def get_gradient(self, input_instance): + def get_gradient(self, input): # Future Support raise NotImplementedError("Future Support") diff --git a/dice_ml/utils/helpers.py b/dice_ml/utils/helpers.py index e3ea688f..866662c3 100644 --- a/dice_ml/utils/helpers.py +++ b/dice_ml/utils/helpers.py @@ -1,17 +1,17 @@ """ This module containts helper functions to load data and get meta deta. """ -import os -import shutil - import numpy as np import pandas as pd -from sklearn.model_selection import train_test_split -# for data transformations -from sklearn.preprocessing import FunctionTransformer +import shutil +import os import dice_ml +# for data transformations +from sklearn.preprocessing import FunctionTransformer +from sklearn.model_selection import train_test_split + def load_adult_income_dataset(only_train=True): """Loads adult income dataset from https://archive.ics.uci.edu/ml/datasets/Adult and prepares @@ -168,11 +168,11 @@ def get_base_gen_cf_initialization(data_interface, encoded_size, cont_minx, cont wm1, wm2, wm3, learning_rate): # Dice Imports - TODO: keep this method for VAE as a spearate module or move it to feasible_base_vae.py. # Check dependencies. + from dice_ml.utils.sample_architecture.vae_model import CF_VAE + # Pytorch from torch import optim - from dice_ml.utils.sample_architecture.vae_model import CF_VAE - # Dataset for training Variational Encoder Decoder model for CF Generation df = data_interface.normalize_data(data_interface.one_hot_encoded_data) encoded_data = df[data_interface.ohe_encoded_feature_names + [data_interface.outcome_name]] diff --git a/dice_ml/utils/sample_architecture/vae_model.py b/dice_ml/utils/sample_architecture/vae_model.py index 3b4be568..6d461494 100644 --- a/dice_ml/utils/sample_architecture/vae_model.py +++ b/dice_ml/utils/sample_architecture/vae_model.py @@ -109,7 +109,7 @@ def forward(self, x, c): res['z'] = [] res['x_pred'] = [] res['mc_samples'] = mc_samples - for _ in range(mc_samples): + for i in range(mc_samples): z = self.sample_latent_code(em, ev) x_pred = self.decoder(torch.cat((z, c), 1)) res['z'].append(z) @@ -239,7 +239,7 @@ def forward(self, x): res['z'] = [] res['x_pred'] = [] res['mc_samples'] = mc_samples - for _ in range(mc_samples): + for i in range(mc_samples): z = self.sample_latent_code(em, ev) x_pred = self.decoder(z) res['z'].append(z) diff --git a/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb b/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb index 9786d701..fec9ce5f 100644 --- a/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb +++ b/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb @@ -274,7 +274,7 @@ " for q in query_instances:\n", " if q in d.categorical_feature_names:\n", " query_instances.loc[:, q] = \\\n", - " [random.choice(dataset[q].values.unique()) for _ in query_instances.index]\n", + " [random.choice(dataset[q].unique()) for _ in query_instances.index]\n", " else:\n", " query_instances.loc[:, q] = \\\n", " [np.random.uniform(dataset[q].min(), dataset[q].max()) for _ in query_instances.index]\n", @@ -329,4 +329,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file From 7e9cf74c08a530f443956df350125f8629fdc83c Mon Sep 17 00:00:00 2001 From: giandos200 Date: Thu, 13 Jan 2022 18:28:30 +0100 Subject: [PATCH 03/13] import update Signed-off-by: giandos200 --- dice_ml/__init__.py | 2 +- dice_ml/counterfactual_explanations.py | 9 ++-- .../data_interfaces/private_data_interface.py | 7 +-- .../data_interfaces/public_data_interface.py | 11 ++-- dice_ml/dice.py | 8 +-- dice_ml/diverse_counterfactuals.py | 7 ++- dice_ml/explainer_interfaces/dice_KD.py | 7 +-- dice_ml/explainer_interfaces/dice_genetic.py | 9 ++-- dice_ml/explainer_interfaces/dice_pytorch.py | 10 ++-- dice_ml/explainer_interfaces/dice_random.py | 7 +-- .../explainer_interfaces/dice_tensorflow1.py | 12 ++--- .../explainer_interfaces/dice_tensorflow2.py | 12 ++--- .../explainer_interfaces/explainer_base.py | 54 ++++++++++++++----- .../explainer_interfaces/feasible_base_vae.py | 12 ++--- .../feasible_model_approx.py | 8 +-- dice_ml/model.py | 4 +- dice_ml/model_interfaces/base_model.py | 8 +-- .../keras_tensorflow_model.py | 5 +- dice_ml/model_interfaces/pytorch_model.py | 5 +- dice_ml/utils/helpers.py | 17 +++--- .../utils/sample_architecture/vae_model.py | 4 +- 21 files changed, 132 insertions(+), 86 deletions(-) diff --git a/dice_ml/__init__.py b/dice_ml/__init__.py index cf41c603..63a3dc1d 100644 --- a/dice_ml/__init__.py +++ b/dice_ml/__init__.py @@ -1,6 +1,6 @@ from .data import Data -from .model import Model from .dice import Dice +from .model import Model __all__ = ["Data", "Model", diff --git a/dice_ml/counterfactual_explanations.py b/dice_ml/counterfactual_explanations.py index 2e499784..ba1a95f3 100644 --- a/dice_ml/counterfactual_explanations.py +++ b/dice_ml/counterfactual_explanations.py @@ -1,11 +1,12 @@ import json -import jsonschema import os -from dice_ml.diverse_counterfactuals import CounterfactualExamples -from dice_ml.utils.exception import UserConfigValidationException -from dice_ml.diverse_counterfactuals import _DiverseCFV2SchemaConstants +import jsonschema + from dice_ml.constants import _SchemaVersions +from dice_ml.diverse_counterfactuals import (CounterfactualExamples, + _DiverseCFV2SchemaConstants) +from dice_ml.utils.exception import UserConfigValidationException class _CommonSchemaConstants: diff --git a/dice_ml/data_interfaces/private_data_interface.py b/dice_ml/data_interfaces/private_data_interface.py index f319442d..2960c17b 100644 --- a/dice_ml/data_interfaces/private_data_interface.py +++ b/dice_ml/data_interfaces/private_data_interface.py @@ -1,10 +1,11 @@ """Module containing meta data information about private data.""" -import sys -import pandas as pd -import numpy as np import collections import logging +import sys + +import numpy as np +import pandas as pd from dice_ml.data_interfaces.base_data_interface import _BaseData diff --git a/dice_ml/data_interfaces/public_data_interface.py b/dice_ml/data_interfaces/public_data_interface.py index 31b8e6af..b5c04796 100644 --- a/dice_ml/data_interfaces/public_data_interface.py +++ b/dice_ml/data_interfaces/public_data_interface.py @@ -1,14 +1,15 @@ """Module containing all required information about the interface between raw (or transformed) public data and DiCE explainers.""" -import pandas as pd -import numpy as np import logging from collections import defaultdict -from dice_ml.data_interfaces.base_data_interface import _BaseData -from dice_ml.utils.exception import SystemException, UserConfigValidationException +import numpy as np +import pandas as pd +from dice_ml.data_interfaces.base_data_interface import _BaseData +from dice_ml.utils.exception import (SystemException, + UserConfigValidationException) class PublicData(_BaseData): """A data interface for public data. This class is an interface to DiCE explainers @@ -258,7 +259,7 @@ def get_valid_feature_range(self, feature_range_input, normalized=True): """ feature_range = {} - for idx, feature_name in enumerate(self.feature_names): + for _, feature_name in enumerate(self.feature_names): feature_range[feature_name] = [] if feature_name in self.continuous_feature_names: max_value = self.data_df[feature_name].max() diff --git a/dice_ml/dice.py b/dice_ml/dice.py index 961b2859..8a55240e 100644 --- a/dice_ml/dice.py +++ b/dice_ml/dice.py @@ -3,9 +3,9 @@ such as RandomSampling, DiCEKD or DiCEGenetic""" from dice_ml.constants import BackEndTypes, SamplingStrategy -from dice_ml.utils.exception import UserConfigValidationException from dice_ml.explainer_interfaces.explainer_base import ExplainerBase from dice_ml.data_interfaces.private_data_interface import PrivateData +from dice_ml.utils.exception import UserConfigValidationException class Dice(ExplainerBase): @@ -67,12 +67,14 @@ def decide(model_interface, method): elif model_interface.backend == BackEndTypes.Tensorflow1: # pretrained Keras Sequential model with Tensorflow 1.x backend - from dice_ml.explainer_interfaces.dice_tensorflow1 import DiceTensorFlow1 + from dice_ml.explainer_interfaces.dice_tensorflow1 import \ + DiceTensorFlow1 return DiceTensorFlow1 elif model_interface.backend == BackEndTypes.Tensorflow2: # pretrained Keras Sequential model with Tensorflow 2.x backend - from dice_ml.explainer_interfaces.dice_tensorflow2 import DiceTensorFlow2 + from dice_ml.explainer_interfaces.dice_tensorflow2 import \ + DiceTensorFlow2 return DiceTensorFlow2 elif model_interface.backend == BackEndTypes.Pytorch: diff --git a/dice_ml/diverse_counterfactuals.py b/dice_ml/diverse_counterfactuals.py index fe8aa134..2dc5c044 100644 --- a/dice_ml/diverse_counterfactuals.py +++ b/dice_ml/diverse_counterfactuals.py @@ -1,8 +1,10 @@ -import pandas as pd import copy import json + +import pandas as pd + +from dice_ml.constants import ModelTypes, _SchemaVersions from dice_ml.utils.serialize import DummyDataInterface -from dice_ml.constants import _SchemaVersions, ModelTypes class _DiverseCFV1SchemaConstants: @@ -115,6 +117,7 @@ def _visualize_internal(self, display_sparse_df=True, show_only_changes=False, def visualize_as_dataframe(self, display_sparse_df=True, show_only_changes=False): from IPython.display import display + # original instance print('Query instance (original outcome : %i)' % round(self.test_pred)) display(self.test_instance_df) # works only in Jupyter notebook diff --git a/dice_ml/explainer_interfaces/dice_KD.py b/dice_ml/explainer_interfaces/dice_KD.py index bac3bf6e..61b2220d 100644 --- a/dice_ml/explainer_interfaces/dice_KD.py +++ b/dice_ml/explainer_interfaces/dice_KD.py @@ -2,14 +2,15 @@ Module to generate counterfactual explanations from a KD-Tree This code is similar to 'Interpretable Counterfactual Explanations Guided by Prototypes': https://arxiv.org/pdf/1907.02584.pdf """ -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase -import numpy as np +import copy import timeit + +import numpy as np import pandas as pd -import copy from dice_ml import diverse_counterfactuals as exp from dice_ml.constants import ModelTypes +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DiceKD(ExplainerBase): diff --git a/dice_ml/explainer_interfaces/dice_genetic.py b/dice_ml/explainer_interfaces/dice_genetic.py index 5928bbc2..3b8514cc 100644 --- a/dice_ml/explainer_interfaces/dice_genetic.py +++ b/dice_ml/explainer_interfaces/dice_genetic.py @@ -2,16 +2,17 @@ Module to generate diverse counterfactual explanations based on genetic algorithm This code is similar to 'GeCo: Quality Counterfactual Explanations in Real Time': https://arxiv.org/pdf/2101.01292.pdf """ -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase -import numpy as np -import pandas as pd +import copy import random import timeit -import copy + +import numpy as np +import pandas as pd from sklearn.preprocessing import LabelEncoder from dice_ml import diverse_counterfactuals as exp from dice_ml.constants import ModelTypes +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DiceGenetic(ExplainerBase): diff --git a/dice_ml/explainer_interfaces/dice_pytorch.py b/dice_ml/explainer_interfaces/dice_pytorch.py index 0aaa52f3..91d2f0c3 100644 --- a/dice_ml/explainer_interfaces/dice_pytorch.py +++ b/dice_ml/explainer_interfaces/dice_pytorch.py @@ -1,16 +1,16 @@ """ Module to generate diverse counterfactual explanations based on PyTorch framework """ -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase -import torch - -import numpy as np +import copy import random import timeit -import copy +import numpy as np + +import torch from dice_ml import diverse_counterfactuals as exp from dice_ml.counterfactual_explanations import CounterfactualExplanations +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DicePyTorch(ExplainerBase): diff --git a/dice_ml/explainer_interfaces/dice_random.py b/dice_ml/explainer_interfaces/dice_random.py index b97c8fcf..49866318 100644 --- a/dice_ml/explainer_interfaces/dice_random.py +++ b/dice_ml/explainer_interfaces/dice_random.py @@ -3,14 +3,15 @@ Module to generate diverse counterfactual explanations based on random sampling. A simple implementation. """ -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase -import numpy as np -import pandas as pd import random import timeit +import numpy as np +import pandas as pd + from dice_ml import diverse_counterfactuals as exp from dice_ml.constants import ModelTypes +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DiceRandom(ExplainerBase): diff --git a/dice_ml/explainer_interfaces/dice_tensorflow1.py b/dice_ml/explainer_interfaces/dice_tensorflow1.py index 8ad5088b..69ee8298 100644 --- a/dice_ml/explainer_interfaces/dice_tensorflow1.py +++ b/dice_ml/explainer_interfaces/dice_tensorflow1.py @@ -1,17 +1,17 @@ """ Module to generate diverse counterfactual explanations based on tensorflow 1.x """ -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase -import tensorflow as tf - -import numpy as np -import random import collections -import timeit import copy +import random +import timeit + +import numpy as np +import tensorflow as tf from dice_ml import diverse_counterfactuals as exp from dice_ml.counterfactual_explanations import CounterfactualExplanations +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DiceTensorFlow1(ExplainerBase): diff --git a/dice_ml/explainer_interfaces/dice_tensorflow2.py b/dice_ml/explainer_interfaces/dice_tensorflow2.py index b8e1ab75..58445929 100644 --- a/dice_ml/explainer_interfaces/dice_tensorflow2.py +++ b/dice_ml/explainer_interfaces/dice_tensorflow2.py @@ -1,16 +1,16 @@ """ Module to generate diverse counterfactual explanations based on tensorflow 2.x """ -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase -import tensorflow as tf - -import numpy as np +import copy import random import timeit -import copy + +import numpy as np +import tensorflow as tf from dice_ml import diverse_counterfactuals as exp from dice_ml.counterfactual_explanations import CounterfactualExplanations +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DiceTensorFlow2(ExplainerBase): @@ -177,7 +177,7 @@ def do_cf_initializations(self, total_CFs, algorithm, features_to_vary): # CF initialization if len(self.cfs) != self.total_CFs: self.cfs = [] - for ix in range(self.total_CFs): + for _ in range(self.total_CFs): one_init = [[]] for jx in range(self.minx.shape[1]): one_init[0].append(np.random.uniform(self.minx[0][jx], self.maxx[0][jx])) diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index d1c73f4e..f70d8be5 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -2,17 +2,17 @@ Subclasses implement interfaces for different ML frameworks such as TensorFlow or PyTorch. All methods are in dice_ml.explainer_interfaces""" -import warnings from abc import ABC, abstractmethod +from collections.abc import Iterable + import numpy as np import pandas as pd +from sklearn.neighbors import KDTree from tqdm import tqdm -from collections.abc import Iterable -from sklearn.neighbors import KDTree +from dice_ml.constants import ModelTypes from dice_ml.counterfactual_explanations import CounterfactualExplanations from dice_ml.utils.exception import UserConfigValidationException -from dice_ml.constants import ModelTypes class ExplainerBase(ABC): @@ -85,6 +85,7 @@ def generate_counterfactuals(self, query_instances, total_CFs, if posthoc_sparsity_algorithm == None: posthoc_sparsity_algorithm = 'binary' elif total_CFs >50 and posthoc_sparsity_algorithm == 'linear': + import warnings warnings.warn("The number of counterfactuals (total_CFs={}) generated per query instance could take much time; " "if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to " "'binary' search!".format(total_CFs)) @@ -98,6 +99,7 @@ def generate_counterfactuals(self, query_instances, total_CFs, query_instances_list.append(query_instances[ix:(ix+1)]) elif isinstance(query_instances, Iterable): query_instances_list = query_instances + for query_instance in tqdm(query_instances_list): self.data_interface.set_continuous_feature_indexes(query_instance) res = self._generate_counterfactuals( @@ -112,6 +114,9 @@ def generate_counterfactuals(self, query_instances, total_CFs, verbose=verbose, **kwargs) cf_examples_arr.append(res) + + self._check_any_counterfactuals_computed(cf_examples_arr=cf_examples_arr) + return CounterfactualExplanations(cf_examples_list=cf_examples_arr) @abstractmethod @@ -217,10 +222,12 @@ def local_feature_importance(self, query_instances, cf_examples_list=None, if any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]): raise UserConfigValidationException( "The number of counterfactuals generated per query instance should be " - "greater than or equal to 10") + "greater than or equal to 10 to compute feature importance for all query points") elif total_CFs < 10: - raise UserConfigValidationException("The number of counterfactuals generated per " - "query instance should be greater than or equal to 10") + raise UserConfigValidationException( + "The number of counterfactuals requested per " + "query instance should be greater than or equal to 10 " + "to compute feature importance for all query points") importances = self.feature_importance( query_instances, cf_examples_list=cf_examples_list, @@ -261,16 +268,25 @@ def global_feature_importance(self, query_instances, cf_examples_list=None, input, and the global feature importance summarized over all inputs. """ if query_instances is not None and len(query_instances) < 10: - raise UserConfigValidationException("The number of query instances should be greater than or equal to 10") + raise UserConfigValidationException( + "The number of query instances should be greater than or equal to 10 " + "to compute global feature importance over all query points") if cf_examples_list is not None: - if any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]): + if len(cf_examples_list) < 10: + raise UserConfigValidationException( + "The number of points for which counterfactuals generated should be " + "greater than or equal to 10 " + "to compute global feature importance") + elif any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]): raise UserConfigValidationException( "The number of counterfactuals generated per query instance should be " - "greater than or equal to 10") + "greater than or equal to 10" + "to compute global feature importance over all query points") elif total_CFs < 10: raise UserConfigValidationException( "The number of counterfactuals generated per query instance should be greater " - "than or equal to 10") + "than or equal to 10" + "to compute global feature importance over all query points") importances = self.feature_importance( query_instances, cf_examples_list=cf_examples_list, @@ -349,7 +365,7 @@ def feature_importance(self, query_instances, cf_examples_list=None, continue per_query_point_cfs = 0 - for index, row in df.iterrows(): + for _, row in df.iterrows(): per_query_point_cfs += 1 for col in self.data_interface.continuous_feature_names: if not np.isclose(org_instance[col].iat[0], row[col]): @@ -530,7 +546,7 @@ def misc_init(self, stopping_threshold, desired_class, desired_range, test_pred) self.target_cf_class = np.array( [[self.infer_target_cfs_class(desired_class, test_pred, self.num_output_nodes)]], dtype=np.float32) - desired_class = self.target_cf_class[0][0] + desired_class = int(self.target_cf_class[0][0]) if self.target_cf_class == 0 and self.stopping_threshold > 0.5: self.stopping_threshold = 0.25 elif self.target_cf_class == 1 and self.stopping_threshold < 0.5: @@ -695,3 +711,15 @@ def round_to_precision(self): self.final_cfs_df[feature] = self.final_cfs_df[feature].astype(float).round(precisions[ix]) if self.final_cfs_df_sparse is not None: self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round(precisions[ix]) + + def _check_any_counterfactuals_computed(self, cf_examples_arr): + """Check if any counterfactuals were generated for any query point.""" + no_cf_generated = True + # Check if any counterfactuals were generated for any query point + for cf_examples in cf_examples_arr: + if cf_examples.final_cfs_df is not None and len(cf_examples.final_cfs_df) > 0: + no_cf_generated = False + break + if no_cf_generated: + raise UserConfigValidationException( + "No counterfactuals found for any of the query points! Kindly check your configuration.") \ No newline at end of file diff --git a/dice_ml/explainer_interfaces/feasible_base_vae.py b/dice_ml/explainer_interfaces/feasible_base_vae.py index c59a26d4..503dbb29 100644 --- a/dice_ml/explainer_interfaces/feasible_base_vae.py +++ b/dice_ml/explainer_interfaces/feasible_base_vae.py @@ -1,15 +1,15 @@ # General Imports import numpy as np +# Pytorch +import torch +import torch.utils.data +from torch.nn import functional as F +from dice_ml import diverse_counterfactuals as exp # Dice Imports from dice_ml.explainer_interfaces.explainer_base import ExplainerBase -from dice_ml import diverse_counterfactuals as exp from dice_ml.utils.helpers import get_base_gen_cf_initialization -# Pytorch -import torch -import torch.utils.data -from torch.nn import functional as F class FeasibleBaseVAE(ExplainerBase): @@ -187,7 +187,7 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp curr_cf_pred = [] curr_test_pred = train_y.numpy() - for cf_count in range(total_CFs): + for _ in range(total_CFs): recon_err, kl_err, x_true, x_pred, cf_label = \ self.cf_vae.compute_elbo(train_x, 1.0-train_y, self.pred_model) while(cf_label == train_y): diff --git a/dice_ml/explainer_interfaces/feasible_model_approx.py b/dice_ml/explainer_interfaces/feasible_model_approx.py index 01fc609a..a7fffda5 100644 --- a/dice_ml/explainer_interfaces/feasible_model_approx.py +++ b/dice_ml/explainer_interfaces/feasible_model_approx.py @@ -1,13 +1,13 @@ # Dice Imports -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase -from dice_ml.explainer_interfaces.feasible_base_vae import FeasibleBaseVAE -from dice_ml.utils.helpers import get_base_gen_cf_initialization - # Pytorch import torch import torch.utils.data from torch.nn import functional as F +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase +from dice_ml.explainer_interfaces.feasible_base_vae import FeasibleBaseVAE +from dice_ml.utils.helpers import get_base_gen_cf_initialization + class FeasibleModelApprox(FeasibleBaseVAE, ExplainerBase): diff --git a/dice_ml/model.py b/dice_ml/model.py index 0518cdf8..737701bf 100644 --- a/dice_ml/model.py +++ b/dice_ml/model.py @@ -4,6 +4,7 @@ frameworks such as Tensorflow or PyTorch. """ import warnings + from dice_ml.constants import BackEndTypes, ModelTypes from dice_ml.utils.exception import UserConfigValidationException @@ -69,7 +70,8 @@ def decide(backend): import tensorflow # noqa: F401 except ImportError: raise UserConfigValidationException("Unable to import tensorflow. Please install tensorflow") - from dice_ml.model_interfaces.keras_tensorflow_model import KerasTensorFlowModel + from dice_ml.model_interfaces.keras_tensorflow_model import \ + KerasTensorFlowModel return KerasTensorFlowModel elif backend == BackEndTypes.Pytorch: diff --git a/dice_ml/model_interfaces/base_model.py b/dice_ml/model_interfaces/base_model.py index 09b49ddd..3a25b5cf 100644 --- a/dice_ml/model_interfaces/base_model.py +++ b/dice_ml/model_interfaces/base_model.py @@ -3,10 +3,12 @@ All model interface methods are in dice_ml.model_interfaces""" import pickle + import numpy as np -from dice_ml.utils.helpers import DataTransfomer + from dice_ml.constants import ModelTypes from dice_ml.utils.exception import SystemException +from dice_ml.utils.helpers import DataTransfomer class BaseModel: @@ -62,7 +64,7 @@ def get_num_output_nodes(self, inp_size): temp_input = np.transpose(np.array([np.random.uniform(0, 1) for i in range(inp_size)]).reshape(-1, 1)) return self.get_output(temp_input).shape[1] - def get_num_output_nodes2(self, input): + def get_num_output_nodes2(self, input_instance): if self.model_type == ModelTypes.Regressor: raise SystemException('Number of output nodes not supported for regression') - return self.get_output(input).shape[1] + return self.get_output(input_instance).shape[1] diff --git a/dice_ml/model_interfaces/keras_tensorflow_model.py b/dice_ml/model_interfaces/keras_tensorflow_model.py index 72619f12..df150850 100644 --- a/dice_ml/model_interfaces/keras_tensorflow_model.py +++ b/dice_ml/model_interfaces/keras_tensorflow_model.py @@ -1,9 +1,10 @@ """Module containing an interface to trained Keras Tensorflow model.""" -from dice_ml.model_interfaces.base_model import BaseModel import tensorflow as tf from tensorflow import keras +from dice_ml.model_interfaces.base_model import BaseModel + class KerasTensorFlowModel(BaseModel): @@ -39,7 +40,7 @@ def get_output(self, input_tensor, training=False, transform_data=False): else: return self.model(input_tensor) - def get_gradient(self, input): + def get_gradient(self, input_instance): # Future Support raise NotImplementedError("Future Support") diff --git a/dice_ml/model_interfaces/pytorch_model.py b/dice_ml/model_interfaces/pytorch_model.py index e0063f5a..9cf577cc 100644 --- a/dice_ml/model_interfaces/pytorch_model.py +++ b/dice_ml/model_interfaces/pytorch_model.py @@ -1,8 +1,9 @@ """Module containing an interface to trained PyTorch model.""" -from dice_ml.model_interfaces.base_model import BaseModel import torch +from dice_ml.model_interfaces.base_model import BaseModel + class PyTorchModel(BaseModel): @@ -37,7 +38,7 @@ def get_output(self, input_tensor, transform_data=False): def set_eval_mode(self): self.model.eval() - def get_gradient(self, input): + def get_gradient(self, input_instance): # Future Support raise NotImplementedError("Future Support") diff --git a/dice_ml/utils/helpers.py b/dice_ml/utils/helpers.py index 866662c3..408248bf 100644 --- a/dice_ml/utils/helpers.py +++ b/dice_ml/utils/helpers.py @@ -1,16 +1,17 @@ """ This module containts helper functions to load data and get meta deta. """ -import numpy as np -import pandas as pd -import shutil import os +import shutil -import dice_ml - +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split # for data transformations from sklearn.preprocessing import FunctionTransformer -from sklearn.model_selection import train_test_split +import dice_ml + + def load_adult_income_dataset(only_train=True): @@ -168,11 +169,11 @@ def get_base_gen_cf_initialization(data_interface, encoded_size, cont_minx, cont wm1, wm2, wm3, learning_rate): # Dice Imports - TODO: keep this method for VAE as a spearate module or move it to feasible_base_vae.py. # Check dependencies. - from dice_ml.utils.sample_architecture.vae_model import CF_VAE - # Pytorch from torch import optim + from dice_ml.utils.sample_architecture.vae_model import CF_VAE + # Dataset for training Variational Encoder Decoder model for CF Generation df = data_interface.normalize_data(data_interface.one_hot_encoded_data) encoded_data = df[data_interface.ohe_encoded_feature_names + [data_interface.outcome_name]] diff --git a/dice_ml/utils/sample_architecture/vae_model.py b/dice_ml/utils/sample_architecture/vae_model.py index 6d461494..3b4be568 100644 --- a/dice_ml/utils/sample_architecture/vae_model.py +++ b/dice_ml/utils/sample_architecture/vae_model.py @@ -109,7 +109,7 @@ def forward(self, x, c): res['z'] = [] res['x_pred'] = [] res['mc_samples'] = mc_samples - for i in range(mc_samples): + for _ in range(mc_samples): z = self.sample_latent_code(em, ev) x_pred = self.decoder(torch.cat((z, c), 1)) res['z'].append(z) @@ -239,7 +239,7 @@ def forward(self, x): res['z'] = [] res['x_pred'] = [] res['mc_samples'] = mc_samples - for i in range(mc_samples): + for _ in range(mc_samples): z = self.sample_latent_code(em, ev) x_pred = self.decoder(z) res['z'].append(z) From 25d858c8dbf55e51f18b0614a1b3cc8d6cd649aa Mon Sep 17 00:00:00 2001 From: giandos200 Date: Fri, 14 Jan 2022 10:38:16 +0100 Subject: [PATCH 04/13] import update Signed-off-by: giandos200 --- dice_ml/data_interfaces/public_data_interface.py | 1 + dice_ml/dice.py | 2 +- dice_ml/explainer_interfaces/dice_pytorch.py | 2 +- dice_ml/explainer_interfaces/feasible_base_vae.py | 1 - dice_ml/utils/helpers.py | 1 + 5 files changed, 4 insertions(+), 3 deletions(-) diff --git a/dice_ml/data_interfaces/public_data_interface.py b/dice_ml/data_interfaces/public_data_interface.py index b5c04796..ab3e5ed1 100644 --- a/dice_ml/data_interfaces/public_data_interface.py +++ b/dice_ml/data_interfaces/public_data_interface.py @@ -11,6 +11,7 @@ from dice_ml.utils.exception import (SystemException, UserConfigValidationException) + class PublicData(_BaseData): """A data interface for public data. This class is an interface to DiCE explainers and contains methods to transform user-fed raw data into the format a DiCE explainer diff --git a/dice_ml/dice.py b/dice_ml/dice.py index 8a55240e..d1c78172 100644 --- a/dice_ml/dice.py +++ b/dice_ml/dice.py @@ -3,8 +3,8 @@ such as RandomSampling, DiCEKD or DiCEGenetic""" from dice_ml.constants import BackEndTypes, SamplingStrategy -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase from dice_ml.data_interfaces.private_data_interface import PrivateData +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase from dice_ml.utils.exception import UserConfigValidationException diff --git a/dice_ml/explainer_interfaces/dice_pytorch.py b/dice_ml/explainer_interfaces/dice_pytorch.py index 91d2f0c3..09d257cc 100644 --- a/dice_ml/explainer_interfaces/dice_pytorch.py +++ b/dice_ml/explainer_interfaces/dice_pytorch.py @@ -4,8 +4,8 @@ import copy import random import timeit -import numpy as np +import numpy as np import torch from dice_ml import diverse_counterfactuals as exp diff --git a/dice_ml/explainer_interfaces/feasible_base_vae.py b/dice_ml/explainer_interfaces/feasible_base_vae.py index 503dbb29..28dc7d13 100644 --- a/dice_ml/explainer_interfaces/feasible_base_vae.py +++ b/dice_ml/explainer_interfaces/feasible_base_vae.py @@ -11,7 +11,6 @@ from dice_ml.utils.helpers import get_base_gen_cf_initialization - class FeasibleBaseVAE(ExplainerBase): def __init__(self, data_interface, model_interface, **kwargs): diff --git a/dice_ml/utils/helpers.py b/dice_ml/utils/helpers.py index 408248bf..e9443452 100644 --- a/dice_ml/utils/helpers.py +++ b/dice_ml/utils/helpers.py @@ -9,6 +9,7 @@ from sklearn.model_selection import train_test_split # for data transformations from sklearn.preprocessing import FunctionTransformer + import dice_ml From 45023274e7f8f5b5137a121de86235a89be09016 Mon Sep 17 00:00:00 2001 From: giandos200 Date: Fri, 14 Jan 2022 10:40:06 +0100 Subject: [PATCH 05/13] import update Signed-off-by: giandos200 --- dice_ml/utils/helpers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dice_ml/utils/helpers.py b/dice_ml/utils/helpers.py index e9443452..e3ea688f 100644 --- a/dice_ml/utils/helpers.py +++ b/dice_ml/utils/helpers.py @@ -13,8 +13,6 @@ import dice_ml - - def load_adult_income_dataset(only_train=True): """Loads adult income dataset from https://archive.ics.uci.edu/ml/datasets/Adult and prepares the data for data analysis based on https://rpubs.com/H_Zhu/235617 From b07e01744519b8125581c5d7ee081a4b052818b9 Mon Sep 17 00:00:00 2001 From: giandos200 Date: Fri, 14 Jan 2022 12:01:47 +0100 Subject: [PATCH 06/13] notebook updated Signed-off-by: giandos200 --- .../Benchmarking_different_CF_explanation_methods.ipynb | 4 ++-- tests/test_notebooks.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb b/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb index fec9ce5f..9786d701 100644 --- a/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb +++ b/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb @@ -274,7 +274,7 @@ " for q in query_instances:\n", " if q in d.categorical_feature_names:\n", " query_instances.loc[:, q] = \\\n", - " [random.choice(dataset[q].unique()) for _ in query_instances.index]\n", + " [random.choice(dataset[q].values.unique()) for _ in query_instances.index]\n", " else:\n", " query_instances.loc[:, q] = \\\n", " [np.random.uniform(dataset[q].min(), dataset[q].max()) for _ in query_instances.index]\n", @@ -329,4 +329,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py index c2d0b3bf..a392d0a8 100644 --- a/tests/test_notebooks.py +++ b/tests/test_notebooks.py @@ -9,7 +9,7 @@ import nbformat import pytest -NOTEBOOKS_PATH = "docs/source/notebooks/" +NOTEBOOKS_PATH = "../docs/source/notebooks/" notebooks_list = [f.name for f in os.scandir(NOTEBOOKS_PATH) if f.name.endswith(".ipynb")] # notebooks that should not be run advanced_notebooks = [ From dc2c6ae548c50b635171c7f84e86bd88d8579e52 Mon Sep 17 00:00:00 2001 From: giandos200 Date: Fri, 14 Jan 2022 13:27:00 +0100 Subject: [PATCH 07/13] all test passed and updated Signed-off-by: giandos200 --- dice_ml/explainer_interfaces/dice_genetic.py | 4 ++-- dice_ml/explainer_interfaces/explainer_base.py | 6 +++--- tests/test_dice_interface/test_explainer_base.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dice_ml/explainer_interfaces/dice_genetic.py b/dice_ml/explainer_interfaces/dice_genetic.py index 3b8514cc..30acf09b 100644 --- a/dice_ml/explainer_interfaces/dice_genetic.py +++ b/dice_ml/explainer_interfaces/dice_genetic.py @@ -470,8 +470,8 @@ def find_counterfactuals(self, query_instance, desired_range, desired_class, if rest_members > 0: new_generation_2 = np.zeros((rest_members, self.data_interface.number_of_features)) for new_gen_idx in range(rest_members): - parent1 = random.choice(population[:int(len(population) / 2)]) - parent2 = random.choice(population[:int(len(population) / 2)]) + parent1 = random.choice(population[:max(int(len(population) / 2),1)]) + parent2 = random.choice(population[:max(int(len(population) / 2),1)]) child = self.mate(parent1, parent2, features_to_vary, query_instance) new_generation_2[new_gen_idx] = child diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index f70d8be5..55b2ea19 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -280,12 +280,12 @@ def global_feature_importance(self, query_instances, cf_examples_list=None, elif any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]): raise UserConfigValidationException( "The number of counterfactuals generated per query instance should be " - "greater than or equal to 10" + "greater than or equal to 10 " "to compute global feature importance over all query points") elif total_CFs < 10: raise UserConfigValidationException( - "The number of counterfactuals generated per query instance should be greater " - "than or equal to 10" + "The number of counterfactuals requested per query instance should be greater " + "than or equal to 10 " "to compute global feature importance over all query points") importances = self.feature_importance( query_instances, diff --git a/tests/test_dice_interface/test_explainer_base.py b/tests/test_dice_interface/test_explainer_base.py index 90d02cea..cd4009f4 100644 --- a/tests/test_dice_interface/test_explainer_base.py +++ b/tests/test_dice_interface/test_explainer_base.py @@ -181,8 +181,8 @@ def test_global_feature_importance_error_conditions_with_insufficient_cfs_per_qu with pytest.raises( UserConfigValidationException, match="The number of counterfactuals generated per query instance should be " - "greater than or equal to 10 " - "to compute global feature importance over all query points"): + "greater than or equal to 10 " + "to compute global feature importance over all query points"): exp.global_feature_importance( query_instances=None, cf_examples_list=cf_explanations.cf_examples_list) From 7d84a63f4b37e5994b769def61cdf239b3ce83b6 Mon Sep 17 00:00:00 2001 From: giandos200 Date: Fri, 14 Jan 2022 13:55:25 +0100 Subject: [PATCH 08/13] adding signoff Signed-off-by: Giandomenico Cornacchia Signed-off-by: giandos200 --- .../Benchmarking_different_CF_explanation_methods.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb b/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb index 9786d701..fec9ce5f 100644 --- a/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb +++ b/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb @@ -274,7 +274,7 @@ " for q in query_instances:\n", " if q in d.categorical_feature_names:\n", " query_instances.loc[:, q] = \\\n", - " [random.choice(dataset[q].values.unique()) for _ in query_instances.index]\n", + " [random.choice(dataset[q].unique()) for _ in query_instances.index]\n", " else:\n", " query_instances.loc[:, q] = \\\n", " [np.random.uniform(dataset[q].min(), dataset[q].max()) for _ in query_instances.index]\n", @@ -329,4 +329,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file From b046fe01159289121ccfdfe6cd274414a1c9c01a Mon Sep 17 00:00:00 2001 From: Giandomenico Cornacchia <60853532+giandos200@users.noreply.github.com> Date: Fri, 14 Jan 2022 14:14:57 +0100 Subject: [PATCH 09/13] Update Benchmarking_different_CF_explanation_methods.ipynb Signed-off-by: giandos200 --- .../Benchmarking_different_CF_explanation_methods.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb b/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb index fec9ce5f..1205796f 100644 --- a/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb +++ b/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb @@ -329,4 +329,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} From 25f126355d801b630db04eb67dde1ebc84600dec Mon Sep 17 00:00:00 2001 From: giandos200 Date: Fri, 14 Jan 2022 14:28:02 +0100 Subject: [PATCH 10/13] benchUpdated Signed-off-by: giandos200 --- .../Benchmarking_different_CF_explanation_methods.ipynb | 4 ++-- tests/test_dice_interface/test_explainer_base.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb b/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb index 1205796f..ac96571f 100644 --- a/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb +++ b/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb @@ -274,7 +274,7 @@ " for q in query_instances:\n", " if q in d.categorical_feature_names:\n", " query_instances.loc[:, q] = \\\n", - " [random.choice(dataset[q].unique()) for _ in query_instances.index]\n", + " [random.choice(dataset[q].values.unique()) for _ in query_instances.index]\n", " else:\n", " query_instances.loc[:, q] = \\\n", " [np.random.uniform(dataset[q].min(), dataset[q].max()) for _ in query_instances.index]\n", @@ -329,4 +329,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/tests/test_dice_interface/test_explainer_base.py b/tests/test_dice_interface/test_explainer_base.py index cd4009f4..90d02cea 100644 --- a/tests/test_dice_interface/test_explainer_base.py +++ b/tests/test_dice_interface/test_explainer_base.py @@ -181,8 +181,8 @@ def test_global_feature_importance_error_conditions_with_insufficient_cfs_per_qu with pytest.raises( UserConfigValidationException, match="The number of counterfactuals generated per query instance should be " - "greater than or equal to 10 " - "to compute global feature importance over all query points"): + "greater than or equal to 10 " + "to compute global feature importance over all query points"): exp.global_feature_importance( query_instances=None, cf_examples_list=cf_explanations.cf_examples_list) From dd098bcbe73fd4ba722519b83d2b6c9c3e123645 Mon Sep 17 00:00:00 2001 From: giandos200 Date: Fri, 14 Jan 2022 17:19:55 +0100 Subject: [PATCH 11/13] review Signed-off-by: giandos200 --- dice_ml/explainer_interfaces/dice_KD.py | 9 +-- dice_ml/explainer_interfaces/dice_genetic.py | 8 +-- .../explainer_interfaces/explainer_base.py | 61 +++++++++++-------- tests/test_notebooks.py | 2 +- 4 files changed, 44 insertions(+), 36 deletions(-) diff --git a/dice_ml/explainer_interfaces/dice_KD.py b/dice_ml/explainer_interfaces/dice_KD.py index 61b2220d..618e064d 100644 --- a/dice_ml/explainer_interfaces/dice_KD.py +++ b/dice_ml/explainer_interfaces/dice_KD.py @@ -240,7 +240,8 @@ def find_counterfactuals(self, data_df_copy, query_instance, query_instance_orig # post-hoc operation on continuous features to enhance sparsity - only for public data if posthoc_sparsity_param is not None and posthoc_sparsity_param > 0 and 'data_df' in self.data_interface.__dict__: self.final_cfs_df_sparse = copy.deepcopy(self.final_cfs) - self.final_cfs_df_sparse = self.do_posthoc_sparsity_enhancement(self.final_cfs_df_sparse, query_instance, + self.final_cfs_df_sparse = self.do_posthoc_sparsity_enhancement(self.final_cfs_df_sparse, + query_instance, posthoc_sparsity_param, posthoc_sparsity_algorithm) else: @@ -265,9 +266,9 @@ def find_counterfactuals(self, data_df_copy, query_instance, query_instance_orig 'change the query instance or the features to vary...' '; total time taken: %02d' % m, 'min %02d' % s, 'sec') elif total_cfs_found == 0: - print( - 'No Counterfactuals found for the given configuration, perhaps try with different parameters...', - '; total time taken: %02d' % m, 'min %02d' % s, 'sec') + print( + 'No Counterfactuals found for the given configuration, perhaps try with different parameters...', + '; total time taken: %02d' % m, 'min %02d' % s, 'sec') else: print('Diverse Counterfactuals found! total time taken: %02d' % m, 'min %02d' % s, 'sec') diff --git a/dice_ml/explainer_interfaces/dice_genetic.py b/dice_ml/explainer_interfaces/dice_genetic.py index 30acf09b..3f47cc56 100644 --- a/dice_ml/explainer_interfaces/dice_genetic.py +++ b/dice_ml/explainer_interfaces/dice_genetic.py @@ -150,8 +150,6 @@ def do_KD_init(self, features_to_vary, query_instance, cfs, desired_class, desir kx += 1 self.cfs = np.array(row) - #if len(self.cfs) > self.population_size: - # pass if len(self.cfs) != self.population_size: print("Pericolo Loop infinito....!!!!") remaining_cfs = self.do_random_init( @@ -264,7 +262,7 @@ def _generate_counterfactuals(self, query_instance, total_CFs, initialization="k (see diverse_counterfactuals.py). """ - self.population_size = 3 * total_CFs + self.population_size = 10 * total_CFs self.start_time = timeit.default_timer() @@ -470,8 +468,8 @@ def find_counterfactuals(self, query_instance, desired_range, desired_class, if rest_members > 0: new_generation_2 = np.zeros((rest_members, self.data_interface.number_of_features)) for new_gen_idx in range(rest_members): - parent1 = random.choice(population[:max(int(len(population) / 2),1)]) - parent2 = random.choice(population[:max(int(len(population) / 2),1)]) + parent1 = random.choice(population[:max(int(len(population) / 2), 1)]) + parent2 = random.choice(population[:max(int(len(population) / 2), 1)]) child = self.mate(parent1, parent2, features_to_vary, query_instance) new_generation_2[new_gen_idx] = child diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index 55b2ea19..aeac3d7c 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -82,21 +82,22 @@ def generate_counterfactuals(self, query_instances, total_CFs, raise UserConfigValidationException( "The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.") if total_CFs > 10: - if posthoc_sparsity_algorithm == None: + if posthoc_sparsity_algorithm is None: posthoc_sparsity_algorithm = 'binary' - elif total_CFs >50 and posthoc_sparsity_algorithm == 'linear': + elif total_CFs > 50 and posthoc_sparsity_algorithm == 'linear': import warnings - warnings.warn("The number of counterfactuals (total_CFs={}) generated per query instance could take much time; " - "if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to " - "'binary' search!".format(total_CFs)) - elif posthoc_sparsity_algorithm == None: + warnings.warn( + "The number of counterfactuals (total_CFs={}) generated per query instance could take much time; " + "if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to " + "'binary' search!".format(total_CFs)) + elif posthoc_sparsity_algorithm is None: posthoc_sparsity_algorithm = 'linear' cf_examples_arr = [] query_instances_list = [] if isinstance(query_instances, pd.DataFrame): for ix in range(query_instances.shape[0]): - query_instances_list.append(query_instances[ix:(ix+1)]) + query_instances_list.append(query_instances[ix:(ix + 1)]) elif isinstance(query_instances, Iterable): query_instances_list = query_instances @@ -190,11 +191,14 @@ def check_query_instance_validity(self, features_to_vary, permitted_range, query if feature not in features_to_vary and permitted_range is not None: if feature in permitted_range and feature in self.data_interface.continuous_feature_names: - if not permitted_range[feature][0] <= query_instance[feature].values[0] <= permitted_range[feature][1]: - raise ValueError("Feature:", feature, "is outside the permitted range and isn't allowed to vary.") + if not permitted_range[feature][0] <= query_instance[feature].values[0] <= permitted_range[feature][ + 1]: + raise ValueError("Feature:", feature, + "is outside the permitted range and isn't allowed to vary.") elif feature in permitted_range and feature in self.data_interface.categorical_feature_names: if query_instance[feature].values[0] not in self.feature_range[feature]: - raise ValueError("Feature:", feature, "is outside the permitted range and isn't allowed to vary.") + raise ValueError("Feature:", feature, + "is outside the permitted range and isn't allowed to vary.") def local_feature_importance(self, query_instances, cf_examples_list=None, total_CFs=10, @@ -440,12 +444,13 @@ def do_posthoc_sparsity_enhancement(self, final_cfs_sparse, query_instance, post cfs_preds_sparse = [] for cf_ix in list(final_cfs_sparse.index): - current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names]) + current_pred = self.predict_fn_for_sparsity( + final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names]) for feature in features_sorted: # current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.iat[[cf_ix]][self.data_interface.feature_names]) # feat_ix = self.data_interface.continuous_feature_names.index(feature) diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature]) - if(abs(diff) <= quantiles[feature]): + if (abs(diff) <= quantiles[feature]): if posthoc_sparsity_algorithm == "linear": final_cfs_sparse = self.do_linear_search(diff, decimal_prec, query_instance, cf_ix, feature, final_cfs_sparse, current_pred) @@ -466,13 +471,14 @@ def do_linear_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f query_instance greedily until the prediction class changes.""" old_diff = diff - change = (10**-decimal_prec[feature]) # the minimal possible change for a feature + change = (10 ** -decimal_prec[feature]) # the minimal possible change for a feature current_pred = current_pred_orig if self.model.model_type == ModelTypes.Classifier: - while((abs(diff) > 10e-4) and (np.sign(diff*old_diff) > 0) and self.is_cf_valid(current_pred)): + while ((abs(diff) > 10e-4) and (np.sign(diff * old_diff) > 0) and self.is_cf_valid(current_pred)): old_val = int(final_cfs_sparse.at[cf_ix, feature]) - final_cfs_sparse.at[cf_ix, feature] += np.sign(diff)*change - current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names]) + final_cfs_sparse.at[cf_ix, feature] += np.sign(diff) * change + current_pred = self.predict_fn_for_sparsity( + final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names]) old_diff = diff if not self.is_cf_valid(current_pred): @@ -505,11 +511,12 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f right = query_instance[feature].iat[0] while left <= right: - current_val = left + ((right - left)/2) + current_val = left + ((right - left) / 2) current_val = round(current_val, decimal_prec[feature]) final_cfs_sparse.at[cf_ix, feature] = current_val - current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names]) + current_pred = self.predict_fn_for_sparsity( + final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names]) if current_val == right or current_val == left: break @@ -524,19 +531,20 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f right = int(final_cfs_sparse.at[cf_ix, feature]) while right >= left: - current_val = right - ((right - left)/2) + current_val = right - ((right - left) / 2) current_val = round(current_val, decimal_prec[feature]) final_cfs_sparse.at[cf_ix, feature] = current_val - current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names]) + current_pred = self.predict_fn_for_sparsity( + final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names]) if current_val == right or current_val == left: break if self.is_cf_valid(current_pred): - right = current_val - (10**-decimal_prec[feature]) + right = current_val - (10 ** -decimal_prec[feature]) else: - left = current_val + (10**-decimal_prec[feature]) + left = current_val + (10 ** -decimal_prec[feature]) return final_cfs_sparse @@ -578,7 +586,7 @@ def infer_target_cfs_class(self, desired_class_input, original_pred, num_output_ raise UserConfigValidationException("Desired class not present in training data!") else: raise UserConfigValidationException("The target class for {0} could not be identified".format( - desired_class_input)) + desired_class_input)) def infer_target_cfs_range(self, desired_range_input): target_range = None @@ -597,7 +605,7 @@ def decide_cf_validity(self, model_outputs): pred = model_outputs[i] if self.model.model_type == ModelTypes.Classifier: if self.num_output_nodes == 2: # binary - pred_1 = pred[self.num_output_nodes-1] + pred_1 = pred[self.num_output_nodes - 1] validity[i] = 1 if \ ((self.target_cf_class == 0 and pred_1 <= self.stopping_threshold) or (self.target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else 0 @@ -634,7 +642,7 @@ def is_cf_valid(self, model_score): (target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else False return validity if self.num_output_nodes == 2: # binary - pred_1 = model_score[self.num_output_nodes-1] + pred_1 = model_score[self.num_output_nodes - 1] validity = True if \ ((target_cf_class == 0 and pred_1 <= self.stopping_threshold) or (target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else False @@ -710,7 +718,8 @@ def round_to_precision(self): for ix, feature in enumerate(self.data_interface.continuous_feature_names): self.final_cfs_df[feature] = self.final_cfs_df[feature].astype(float).round(precisions[ix]) if self.final_cfs_df_sparse is not None: - self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round(precisions[ix]) + self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round( + precisions[ix]) def _check_any_counterfactuals_computed(self, cf_examples_arr): """Check if any counterfactuals were generated for any query point.""" diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py index a392d0a8..c2d0b3bf 100644 --- a/tests/test_notebooks.py +++ b/tests/test_notebooks.py @@ -9,7 +9,7 @@ import nbformat import pytest -NOTEBOOKS_PATH = "../docs/source/notebooks/" +NOTEBOOKS_PATH = "docs/source/notebooks/" notebooks_list = [f.name for f in os.scandir(NOTEBOOKS_PATH) if f.name.endswith(".ipynb")] # notebooks that should not be run advanced_notebooks = [ From ef3de5687809e25c28292c97c3269bff9a96530f Mon Sep 17 00:00:00 2001 From: giandos200 Date: Sat, 15 Jan 2022 14:56:24 +0100 Subject: [PATCH 12/13] flake8 E125/W292 bestpractice reviewed Signed-off-by: giandos200 --- dice_ml/explainer_interfaces/explainer_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index aeac3d7c..e276171f 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -191,8 +191,8 @@ def check_query_instance_validity(self, features_to_vary, permitted_range, query if feature not in features_to_vary and permitted_range is not None: if feature in permitted_range and feature in self.data_interface.continuous_feature_names: - if not permitted_range[feature][0] <= query_instance[feature].values[0] <= permitted_range[feature][ - 1]: + if not permitted_range[feature][0] <= query_instance[feature].values[0] <= permitted_range[feature][\ + 1]: raise ValueError("Feature:", feature, "is outside the permitted range and isn't allowed to vary.") elif feature in permitted_range and feature in self.data_interface.categorical_feature_names: @@ -731,4 +731,4 @@ def _check_any_counterfactuals_computed(self, cf_examples_arr): break if no_cf_generated: raise UserConfigValidationException( - "No counterfactuals found for any of the query points! Kindly check your configuration.") \ No newline at end of file + "No counterfactuals found for any of the query points! Kindly check your configuration.") From b12c369a8327e876f4e14cdc2d01e2f53e471b46 Mon Sep 17 00:00:00 2001 From: giandos200 Date: Mon, 24 Jan 2022 14:26:44 +0100 Subject: [PATCH 13/13] update Signed-off-by: giandos200 --- dice_ml/explainer_interfaces/dice_genetic.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dice_ml/explainer_interfaces/dice_genetic.py b/dice_ml/explainer_interfaces/dice_genetic.py index 3f47cc56..f3d19860 100644 --- a/dice_ml/explainer_interfaces/dice_genetic.py +++ b/dice_ml/explainer_interfaces/dice_genetic.py @@ -151,7 +151,6 @@ def do_KD_init(self, features_to_vary, query_instance, cfs, desired_class, desir self.cfs = np.array(row) if len(self.cfs) != self.population_size: - print("Pericolo Loop infinito....!!!!") remaining_cfs = self.do_random_init( self.population_size - len(self.cfs), features_to_vary, query_instance, desired_class, desired_range) self.cfs = np.concatenate([self.cfs, remaining_cfs])