Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion dice_ml/counterfactual_explanations.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import json
import os


import jsonschema
import numpy as np
import pandas as pd
from raiutils.exceptions import UserConfigValidationException
from counterplots import CreatePlot

from dice_ml.constants import _SchemaVersions
from dice_ml.diverse_counterfactuals import (CounterfactualExamples,
Expand Down Expand Up @@ -52,11 +56,12 @@ class CounterfactualExplanations:
based on the input set of CounterfactualExamples instances

"""
def __init__(self, cf_examples_list,
def __init__(self, cf_examples_list, predict_fn,
local_importance=None,
summary_importance=None,
version=None):
self._cf_examples_list = cf_examples_list
self._predict_fn = predict_fn
self._local_importance = local_importance
self._summary_importance = summary_importance
self._metadata = {'version': version if version is not None else _SchemaVersions.CURRENT_VERSION}
Expand Down Expand Up @@ -240,6 +245,7 @@ def from_json(json_str):

return CounterfactualExplanations(
cf_examples_list=cf_examples_list,
predict_fn=None,
local_importance=json_dict[_CounterfactualExpV1SchemaConstants.LOCAL_IMPORTANCE],
summary_importance=json_dict[_CounterfactualExpV1SchemaConstants.SUMMARY_IMPORTANCE],
version=version)
Expand Down Expand Up @@ -295,8 +301,50 @@ def from_json(json_str):

return CounterfactualExplanations(
cf_examples_list=cf_examples_list,
predict_fn=None,
local_importance=local_importance_list,
summary_importance=summary_importance_dict,
version=version)
else:
return json_dict

def generete_counterplots(self, predict_fn=None):

if self._predict_fn is None and predict_fn is None:
raise ValueError("predict_fn is required to generate counterplots")
elif predict_fn is not None:
self._predict_fn = predict_fn

cf_data = json.loads(self.to_json())
factual = self.cf_examples_list[0].test_instance_df.to_numpy()[0][:-1]
feature_names = list(self.cf_examples_list[0].test_instance_df.columns)[:-1]
df_structure = self.cf_examples_list[0].test_instance_df[:0].loc[:, feature_names]
data_types = df_structure.dtypes.apply(lambda x: x.name).to_dict()

factual_class_name = str(self.cf_examples_list[0].test_pred)
cf_class_name = str(self.cf_examples_list[0].new_outcome)

def adjust_types(x):
for i in range(x.shape[1]):
if 'int' in list(data_types.values())[i]:
x[:, i] = int(float(x[:, i]))
return x

def model_pred(x):
scores = self._predict_fn(df_structure.append(pd.DataFrame(x, columns=feature_names))).numpy()
return np.concatenate((1 - scores, scores), axis=1)

out_exp = []

for raw_cf in cf_data['cfs_list'][0]:
cf = adjust_types(np.array([raw_cf[:-1]]))[0]

out_exp.append(CreatePlot(
factual=np.array(factual),
cf=np.array(cf),
model_pred=model_pred,
feature_names=feature_names,
class_names={0: factual_class_name, 1: cf_class_name}))

return out_exp

2 changes: 1 addition & 1 deletion dice_ml/explainer_interfaces/dice_tensorflow1.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp
posthoc_sparsity_param=posthoc_sparsity_param,
desired_class=desired_class)

return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations])
return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations], predict_fn=self.predict_fn)

def do_cf_initializations(self, total_CFs, algorithm, features_to_vary):
"""Intializes TF variables required for CF generation."""
Expand Down
2 changes: 1 addition & 1 deletion dice_ml/explainer_interfaces/dice_tensorflow2.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp
posthoc_sparsity_param=posthoc_sparsity_param,
desired_class=desired_class)

return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations])
return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations], predict_fn=self.predict_fn)

def predict_fn(self, input_instance):
"""prediction function"""
Expand Down
3 changes: 2 additions & 1 deletion dice_ml/explainer_interfaces/explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def generate_counterfactuals(self, query_instances, total_CFs,
cf_examples_arr.append(res)
self._check_any_counterfactuals_computed(cf_examples_arr=cf_examples_arr)

return CounterfactualExplanations(cf_examples_list=cf_examples_arr)
return CounterfactualExplanations(cf_examples_list=cf_examples_arr, predict_fn=self.predict_fn)

@abstractmethod
def _generate_counterfactuals(self, query_instance, total_CFs,
Expand Down Expand Up @@ -475,6 +475,7 @@ def feature_importance(self, query_instances, cf_examples_list=None,

return CounterfactualExplanations(
cf_examples_list,
predict_fn=self.predict_fn,
local_importance=local_importances,
summary_importance=summary_importance)

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ pandas<2.0.0
scikit-learn
tqdm
raiutils>=0.4.0
counterplots
14 changes: 14 additions & 0 deletions tests/test_counterfactual_explanations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
import pytest
from raiutils.exceptions import UserConfigValidationException


import numpy as np
import dice_ml
from dice_ml.counterfactual_explanations import CounterfactualExplanations


class TestCounterfactualExplanations:

def __init__(self):
self.predict_fn = lambda x: np.array([1.0 for _ in range(np.array(x).shape[0])])

def test_sorted_summary_importance_counterfactual_explanations(self):

unsorted_summary_importance = {
Expand All @@ -34,6 +40,7 @@ def test_sorted_summary_importance_counterfactual_explanations(self):

counterfactual_explanations = CounterfactualExplanations(
cf_examples_list=[],
predict_fn=self.predict_fn,
local_importance=None,
summary_importance=unsorted_summary_importance)

Expand Down Expand Up @@ -94,6 +101,7 @@ def test_sorted_local_importance_counterfactual_explanations(self):
counterfactual_explanations = CounterfactualExplanations(
cf_examples_list=[],
local_importance=unsorted_local_importance,
predict_fn=self.predict_fn,
summary_importance=None)

for index in range(0, len(unsorted_local_importance)):
Expand All @@ -108,6 +116,10 @@ def test_sorted_local_importance_counterfactual_explanations(self):


class TestSerializationCounterfactualExplanations:

def __init__(self):
self.predict_fn = lambda x: np.array([1.0 for _ in range(np.array(x).shape[0])])

@pytest.fixture(autouse=True)
def _initiate_exp_object(self, binary_classification_exp_object):
self.exp = binary_classification_exp_object # explainer object
Expand Down Expand Up @@ -211,6 +223,7 @@ def test_empty_counterfactual_explanations_object(self, version):

counterfactual_explanations = CounterfactualExplanations(
cf_examples_list=[],
predict_fn=self.predict_fn,
local_importance=None,
summary_importance=None,
version=version)
Expand Down Expand Up @@ -311,6 +324,7 @@ def test_unsupported_versions_from_json(self, unsupported_version):
def test_unsupported_versions_to_json(self, unsupported_version):
counterfactual_explanations = CounterfactualExplanations(
cf_examples_list=[],
predict_fn=self.predict_fn,
local_importance=None,
summary_importance=None,
version=unsupported_version)
Expand Down