From 6ddd631d1cf45952bac4b0ff23d95cac557e0e55 Mon Sep 17 00:00:00 2001 From: rmazzine Date: Fri, 16 Jun 2023 00:30:59 +0200 Subject: [PATCH 1/2] Implement CounterPlots to DiCE CF generation This commit integrates the CounterPlots package into the DiCE counterfactual (CF) generation package. The `CounterfactualExplanations` object gains a new method, `generete_counterplots` (this is the spelling used in the code below), which reads the factual instance, the counterfactuals, and the DataFrame column information from the first stored CF example and builds one CounterPlots `CreatePlot` object per counterfactual. The object's constructor now also requires a prediction function (`predict_fn`); therefore `explainer_base`, `dice_tensorflow1`, and `dice_tensorflow2` now pass their own `predict_fn` when creating the object. The `counterplots` package is added to `requirements.txt`. --- dice_ml/counterfactual_explanations.py | 42 ++++++++++++++++++- .../explainer_interfaces/dice_tensorflow1.py | 2 +- .../explainer_interfaces/dice_tensorflow2.py | 2 +- .../explainer_interfaces/explainer_base.py | 2 +- requirements.txt | 1 + 5 files changed, 45 insertions(+), 4 deletions(-) diff --git a/dice_ml/counterfactual_explanations.py b/dice_ml/counterfactual_explanations.py index c3a5dcea..ebae8832 100644 --- a/dice_ml/counterfactual_explanations.py +++ b/dice_ml/counterfactual_explanations.py @@ -1,8 +1,12 @@ import json import os + import jsonschema +import numpy as np +import pandas as pd from raiutils.exceptions import UserConfigValidationException +from counterplots import CreatePlot from dice_ml.constants import _SchemaVersions from dice_ml.diverse_counterfactuals import (CounterfactualExamples, @@ -52,11 +56,12 @@ class CounterfactualExplanations: based on the input set of CounterfactualExamples instances """ - def __init__(self, cf_examples_list, + def __init__(self, cf_examples_list, predict_fn, local_importance=None, summary_importance=None, version=None): self._cf_examples_list = cf_examples_list + self._predict_fn = predict_fn self._local_importance = local_importance self._summary_importance = summary_importance self._metadata = {'version': version if version is not None else 
_SchemaVersions.CURRENT_VERSION} @@ -300,3 +305,38 @@ def from_json(json_str): version=version) else: return json_dict + + def generete_counterplots(self): + cf_data = json.loads(self.to_json()) + factual = self.cf_examples_list[0].test_instance_df.to_numpy()[0][:-1] + feature_names = list(self.cf_examples_list[0].test_instance_df.columns)[:-1] + df_structure = self.cf_examples_list[0].test_instance_df[:0].loc[:, feature_names] + data_types = df_structure.dtypes.apply(lambda x: x.name).to_dict() + + factual_class_name = str(self.cf_examples_list[0].test_pred) + cf_class_name = str(self.cf_examples_list[0].new_outcome) + + def adjust_types(x): + for i in range(x.shape[1]): + if 'int' in list(data_types.values())[i]: + x[:, i] = int(float(x[:, i])) + return x + + def model_pred(x): + scores = self._predict_fn(df_structure.append(pd.DataFrame(x, columns=feature_names))).numpy() + return np.concatenate((1 - scores, scores), axis=1) + + out_exp = [] + + for raw_cf in cf_data['cfs_list'][0]: + cf = adjust_types(np.array([raw_cf[:-1]]))[0] + + out_exp.append(CreatePlot( + factual=np.array(factual), + cf=np.array(cf), + model_pred=model_pred, + feature_names=feature_names, + class_names={0: factual_class_name, 1: cf_class_name})) + + return out_exp + diff --git a/dice_ml/explainer_interfaces/dice_tensorflow1.py b/dice_ml/explainer_interfaces/dice_tensorflow1.py index 1e25f3ee..f61c7e89 100644 --- a/dice_ml/explainer_interfaces/dice_tensorflow1.py +++ b/dice_ml/explainer_interfaces/dice_tensorflow1.py @@ -167,7 +167,7 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp posthoc_sparsity_param=posthoc_sparsity_param, desired_class=desired_class) - return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations]) + return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations], predict_fn=self.predict_fn) def do_cf_initializations(self, total_CFs, algorithm, features_to_vary): """Intializes TF variables required 
for CF generation.""" diff --git a/dice_ml/explainer_interfaces/dice_tensorflow2.py b/dice_ml/explainer_interfaces/dice_tensorflow2.py index 8004a341..03f9c04c 100644 --- a/dice_ml/explainer_interfaces/dice_tensorflow2.py +++ b/dice_ml/explainer_interfaces/dice_tensorflow2.py @@ -144,7 +144,7 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp posthoc_sparsity_param=posthoc_sparsity_param, desired_class=desired_class) - return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations]) + return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations], predict_fn=self.predict_fn) def predict_fn(self, input_instance): """prediction function""" diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index 9926bb73..6562d1c4 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -172,7 +172,7 @@ def generate_counterfactuals(self, query_instances, total_CFs, cf_examples_arr.append(res) self._check_any_counterfactuals_computed(cf_examples_arr=cf_examples_arr) - return CounterfactualExplanations(cf_examples_list=cf_examples_arr) + return CounterfactualExplanations(cf_examples_list=cf_examples_arr, predict_fn=self.predict_fn) @abstractmethod def _generate_counterfactuals(self, query_instance, total_CFs, diff --git a/requirements.txt b/requirements.txt index 7f89d1f4..8f92b182 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ pandas<2.0.0 scikit-learn tqdm raiutils>=0.4.0 +counterplots From bbf511d0e3ed002deb2db718ed6418a8fdc790da Mon Sep 17 00:00:00 2001 From: rmazzine Date: Fri, 16 Jun 2023 12:49:38 +0200 Subject: [PATCH 2/2] Fix CounterfactualExplanations calls --- dice_ml/counterfactual_explanations.py | 10 +++++++++- dice_ml/explainer_interfaces/explainer_base.py | 1 + tests/test_counterfactual_explanations.py | 13 +++++++++++++ 3 files changed, 23 insertions(+), 1 
deletion(-) diff --git a/dice_ml/counterfactual_explanations.py b/dice_ml/counterfactual_explanations.py index ebae8832..17f7da83 100644 --- a/dice_ml/counterfactual_explanations.py +++ b/dice_ml/counterfactual_explanations.py @@ -245,6 +245,7 @@ def from_json(json_str): return CounterfactualExplanations( cf_examples_list=cf_examples_list, + predict_fn=None, local_importance=json_dict[_CounterfactualExpV1SchemaConstants.LOCAL_IMPORTANCE], summary_importance=json_dict[_CounterfactualExpV1SchemaConstants.SUMMARY_IMPORTANCE], version=version) @@ -300,13 +301,20 @@ def from_json(json_str): return CounterfactualExplanations( cf_examples_list=cf_examples_list, + predict_fn=None, local_importance=local_importance_list, summary_importance=summary_importance_dict, version=version) else: return json_dict - def generete_counterplots(self): + def generete_counterplots(self, predict_fn=None): + + if self._predict_fn is None and predict_fn is None: + raise ValueError("predict_fn is required to generate counterplots") + elif predict_fn is not None: + self._predict_fn = predict_fn + cf_data = json.loads(self.to_json()) factual = self.cf_examples_list[0].test_instance_df.to_numpy()[0][:-1] feature_names = list(self.cf_examples_list[0].test_instance_df.columns)[:-1] diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index 6562d1c4..929872bb 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -475,6 +475,7 @@ def feature_importance(self, query_instances, cf_examples_list=None, return CounterfactualExplanations( cf_examples_list, + predict_fn=self.predict_fn, local_importance=local_importances, summary_importance=summary_importance) diff --git a/tests/test_counterfactual_explanations.py b/tests/test_counterfactual_explanations.py index 65a0b013..efddf3df 100644 --- a/tests/test_counterfactual_explanations.py +++ b/tests/test_counterfactual_explanations.py @@ -3,6 
+3,8 @@ import pytest from raiutils.exceptions import UserConfigValidationException +import numpy as np + import dice_ml from dice_ml.counterfactual_explanations import CounterfactualExplanations from dice_ml.utils import helpers @@ -10,6 +12,9 @@ class TestCounterfactualExplanations: + def __init__(self): + self.predict_fn = lambda x: np.array([1.0 for _ in range(np.array(x).shape[0])]) + def test_sorted_summary_importance_counterfactual_explanations(self): unsorted_summary_importance = { @@ -36,6 +41,7 @@ def test_sorted_summary_importance_counterfactual_explanations(self): counterfactual_explanations = CounterfactualExplanations( cf_examples_list=[], + predict_fn=self.predict_fn, local_importance=None, summary_importance=unsorted_summary_importance) @@ -96,6 +102,7 @@ def test_sorted_local_importance_counterfactual_explanations(self): counterfactual_explanations = CounterfactualExplanations( cf_examples_list=[], local_importance=unsorted_local_importance, + predict_fn=self.predict_fn, summary_importance=None) for index in range(0, len(unsorted_local_importance)): @@ -121,6 +128,10 @@ def random_binary_classification_exp_object(): class TestSerializationCounterfactualExplanations: + + def __init__(self): + self.predict_fn = lambda x: np.array([1.0 for _ in range(np.array(x).shape[0])]) + @pytest.fixture(autouse=True) def _initiate_exp_object(self, random_binary_classification_exp_object): self.exp = random_binary_classification_exp_object # explainer object @@ -224,6 +235,7 @@ def test_empty_counterfactual_explanations_object(self, version): counterfactual_explanations = CounterfactualExplanations( cf_examples_list=[], + predict_fn=self.predict_fn, local_importance=None, summary_importance=None, version=version) @@ -324,6 +336,7 @@ def test_unsupported_versions_from_json(self, unsupported_version): def test_unsupported_versions_to_json(self, unsupported_version): counterfactual_explanations = CounterfactualExplanations( cf_examples_list=[], + 
predict_fn=self.predict_fn, local_importance=None, summary_importance=None, version=unsupported_version)