diff --git a/dice_ml/counterfactual_explanations.py b/dice_ml/counterfactual_explanations.py
index c3a5dcea..17f7da83 100644
--- a/dice_ml/counterfactual_explanations.py
+++ b/dice_ml/counterfactual_explanations.py
@@ -1,8 +1,12 @@
 import json
 import os
 
+
 import jsonschema
+import numpy as np
+import pandas as pd
 from raiutils.exceptions import UserConfigValidationException
+from counterplots import CreatePlot
 
 from dice_ml.constants import _SchemaVersions
 from dice_ml.diverse_counterfactuals import (CounterfactualExamples,
@@ -52,11 +56,12 @@ class CounterfactualExplanations:
         based on the input set of CounterfactualExamples instances
     """
 
     def __init__(self, cf_examples_list,
                  local_importance=None,
                  summary_importance=None,
-                 version=None):
+                 version=None, predict_fn=None):
         self._cf_examples_list = cf_examples_list
+        self._predict_fn = predict_fn
         self._local_importance = local_importance
         self._summary_importance = summary_importance
         self._metadata = {'version': version if version is not None else _SchemaVersions.CURRENT_VERSION}
@@ -240,6 +245,7 @@ def from_json(json_str):
 
                 return CounterfactualExplanations(
                     cf_examples_list=cf_examples_list,
+                    predict_fn=None,
                     local_importance=json_dict[_CounterfactualExpV1SchemaConstants.LOCAL_IMPORTANCE],
                     summary_importance=json_dict[_CounterfactualExpV1SchemaConstants.SUMMARY_IMPORTANCE],
                     version=version)
@@ -295,8 +301,71 @@ def from_json(json_str):
 
                 return CounterfactualExplanations(
                     cf_examples_list=cf_examples_list,
+                    predict_fn=None,
                     local_importance=local_importance_list,
                     summary_importance=summary_importance_dict,
                     version=version)
         else:
             return json_dict
+
+    def generate_counterplots(self, predict_fn=None):
+        """Build one counterplots ``CreatePlot`` object per counterfactual.
+
+        Only the first entry of ``cf_examples_list`` is used.
+
+        :param predict_fn: Optional callable that receives a feature dataframe
+            and returns scores, either one score per row (shape ``(n,)`` or
+            ``(n, 1)``) or several columns per row. When given, it replaces
+            the prediction function stored on this object.
+        :return: A list of ``CreatePlot`` objects, one per counterfactual.
+        :raises ValueError: If no prediction function is available or there
+            are no counterfactual examples to plot.
+        """
+        if predict_fn is not None:
+            self._predict_fn = predict_fn
+        if self._predict_fn is None:
+            raise ValueError("predict_fn is required to generate counterplots")
+        if not self._cf_examples_list:
+            raise ValueError("no counterfactual examples available to generate counterplots")
+
+        cf_examples = self._cf_examples_list[0]
+        # The last column is dropped everywhere below (presumably the outcome column).
+        feature_names = list(cf_examples.test_instance_df.columns)[:-1]
+        factual = cf_examples.test_instance_df.to_numpy()[0][:-1]
+        # Zero-row frame: keeps only the feature columns and their dtypes.
+        df_structure = cf_examples.test_instance_df[:0].loc[:, feature_names]
+        int_columns = [i for i, dtype in enumerate(df_structure.dtypes) if 'int' in dtype.name]
+
+        def adjust_types(x):
+            # Truncate integer-typed feature columns to whole numbers, for any row count.
+            for i in int_columns:
+                x[:, i] = x[:, i].astype(float).astype(int)
+            return x
+
+        def model_pred(x):
+            query = pd.concat([df_structure, pd.DataFrame(x, columns=feature_names)])
+            scores = self._predict_fn(query)
+            # Tensor outputs expose .numpy(); plain arrays and lists do not.
+            scores = np.asarray(scores.numpy() if hasattr(scores, 'numpy') else scores)
+            if scores.ndim == 1:
+                scores = scores.reshape(-1, 1)
+            if scores.shape[1] > 1:
+                # NOTE(review): assumed to already hold one probability column per class.
+                return scores
+            return np.concatenate((1 - scores, scores), axis=1)
+
+        cf_data = json.loads(self.to_json())
+        class_names = {0: str(cf_examples.test_pred), 1: str(cf_examples.new_outcome)}
+
+        return [
+            CreatePlot(
+                factual=np.array(factual),
+                cf=np.array(adjust_types(np.array([raw_cf[:-1]]))[0]),
+                model_pred=model_pred,
+                feature_names=feature_names,
+                class_names=class_names)
+            for raw_cf in (cf_data['cfs_list'][0] or [])
+        ]
+
+    # Backward-compatible alias for the original (misspelled) method name.
+    generete_counterplots = generate_counterplots
diff --git a/dice_ml/explainer_interfaces/dice_tensorflow1.py b/dice_ml/explainer_interfaces/dice_tensorflow1.py
index df17b561..e81305fb 100644
--- a/dice_ml/explainer_interfaces/dice_tensorflow1.py
+++ b/dice_ml/explainer_interfaces/dice_tensorflow1.py
@@ -167,7 +167,7 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp
                                                 posthoc_sparsity_param=posthoc_sparsity_param,
                                                 desired_class=desired_class)
 
-        return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations])
+        return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations], predict_fn=self.predict_fn)
 
     def do_cf_initializations(self, total_CFs, algorithm, features_to_vary):
         """Intializes TF variables required for CF generation."""
diff --git a/dice_ml/explainer_interfaces/dice_tensorflow2.py b/dice_ml/explainer_interfaces/dice_tensorflow2.py
index aca80ade..79b54891 100644
--- a/dice_ml/explainer_interfaces/dice_tensorflow2.py
+++ b/dice_ml/explainer_interfaces/dice_tensorflow2.py
@@ -144,7 +144,7 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp
                                                 posthoc_sparsity_param=posthoc_sparsity_param,
                                                 desired_class=desired_class)
 
-        return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations])
+        return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations], predict_fn=self.predict_fn)
 
     def predict_fn(self, input_instance):
         """prediction function"""
diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py
index ad231e69..40239f80 100644
--- a/dice_ml/explainer_interfaces/explainer_base.py
+++ b/dice_ml/explainer_interfaces/explainer_base.py
@@ -172,7 +172,7 @@ def generate_counterfactuals(self, query_instances, total_CFs,
             cf_examples_arr.append(res)
         self._check_any_counterfactuals_computed(cf_examples_arr=cf_examples_arr)
 
-        return CounterfactualExplanations(cf_examples_list=cf_examples_arr)
+        return CounterfactualExplanations(cf_examples_list=cf_examples_arr, predict_fn=self.predict_fn)
 
     @abstractmethod
     def _generate_counterfactuals(self, query_instance, total_CFs,
@@ -475,6 +475,7 @@ def feature_importance(self, query_instances, cf_examples_list=None,
 
         return CounterfactualExplanations(
             cf_examples_list,
+            predict_fn=self.predict_fn,
             local_importance=local_importances,
             summary_importance=summary_importance)
 
diff --git a/requirements.txt b/requirements.txt
index 7f89d1f4..8f92b182 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,4 @@ pandas<2.0.0
 scikit-learn
 tqdm
 raiutils>=0.4.0
+counterplots
diff --git a/tests/test_counterfactual_explanations.py b/tests/test_counterfactual_explanations.py
index 4dcb5628..85a16092 100644
--- a/tests/test_counterfactual_explanations.py
+++ b/tests/test_counterfactual_explanations.py
@@ -3,11 +3,17 @@
 import pytest
 from raiutils.exceptions import UserConfigValidationException
 
+
+import numpy as np
+import dice_ml
 from dice_ml.counterfactual_explanations import CounterfactualExplanations
 
 
 class TestCounterfactualExplanations:
 
+    # Class attribute: pytest refuses to collect test classes that define __init__.
+    predict_fn = staticmethod(lambda x: np.ones(np.array(x).shape[0]))
+
     def test_sorted_summary_importance_counterfactual_explanations(self):
 
         unsorted_summary_importance = {
@@ -34,6 +40,7 @@ def test_sorted_summary_importance_counterfactual_explanations(self):
 
         counterfactual_explanations = CounterfactualExplanations(
             cf_examples_list=[],
+            predict_fn=self.predict_fn,
             local_importance=None,
             summary_importance=unsorted_summary_importance)
 
@@ -94,6 +101,7 @@ def test_sorted_local_importance_counterfactual_explanations(self):
         counterfactual_explanations = CounterfactualExplanations(
             cf_examples_list=[],
             local_importance=unsorted_local_importance,
+            predict_fn=self.predict_fn,
             summary_importance=None)
 
         for index in range(0, len(unsorted_local_importance)):
@@ -108,6 +116,10 @@ def test_sorted_local_importance_counterfactual_explanations(self):
 
 
 class TestSerializationCounterfactualExplanations:
+
+    # Class attribute: pytest refuses to collect test classes that define __init__.
+    predict_fn = staticmethod(lambda x: np.ones(np.array(x).shape[0]))
+
     @pytest.fixture(autouse=True)
     def _initiate_exp_object(self, binary_classification_exp_object):
         self.exp = binary_classification_exp_object  # explainer object
@@ -211,6 +223,7 @@ def test_empty_counterfactual_explanations_object(self, version):
 
         counterfactual_explanations = CounterfactualExplanations(
             cf_examples_list=[],
+            predict_fn=self.predict_fn,
             local_importance=None,
             summary_importance=None,
             version=version)
@@ -311,6 +324,7 @@ def test_unsupported_versions_from_json(self, unsupported_version):
     def test_unsupported_versions_to_json(self, unsupported_version):
         counterfactual_explanations = CounterfactualExplanations(
             cf_examples_list=[],
+            predict_fn=self.predict_fn,
             local_importance=None,
             summary_importance=None,
             version=unsupported_version)