Commit a197742

Merge pull request #100 from rluss/master
CEM updates
2 parents: d847420 + d01d100

File tree: 4 files changed (+1196, -4344 lines)


aix360/algorithms/contrastive/CEM.py

Lines changed: 110 additions & 102 deletions
@@ -1,102 +1,110 @@
-from __future__ import print_function
-
-from aix360.algorithms.lwbe import LocalWBExplainer
-
-from .CEM_aen import AEADEN
-
-import random
-import numpy as np
-
-
-class CEMExplainer(LocalWBExplainer):
-    """
-    CEMExplainer can be used to compute contrastive explanations for image and tabular data.
-    This is achieved by finding what is minimally sufficient (PP - Pertinent Positive) and
-    what should be necessarily absent (PN - Pertinent Negative) to maintain the original classification.
-    We use elastic norm regularization to ensure minimality for both parts of the explanation
-    i.e. PPs and PNs. An autoencoder can optionally be used to make the explanations more realistic. [#]_
-
-    References:
-        .. [#] `Amit Dhurandhar, Pin-Yu Chen, Ronny Luss, Chun-Chen Tu,
-           Paishun Ting, Karthikeyan Shanmugam, Payel Das, "Explanations based on
-           the Missing: Towards Contrastive Explanations with Pertinent Negatives,"
-           Advances in Neural Information Processing Systems (NeurIPS), 2018.
-           <https://arxiv.org/abs/1802.07623>`_
-    """
-    def __init__(self, model):
-
-        """
-        Constructor method, initializes the explainer
-
-        Args:
-            model: KerasClassifier model whose predictions needs to be explained
-        """
-        super(CEMExplainer, self).__init__()
-        self._wbmodel = model
-
-
-    def set_params(self, *argv, **kwargs):
-        """
-        Set parameters for the explainer.
-        """
-        pass
-
-
-    def explain_instance(self, input_X,
-                         arg_mode, AE_model, arg_kappa, arg_b,
-                         arg_max_iter, arg_init_const, arg_beta, arg_gamma):
-
-        """
-        Explains an input instance input_X and returns contrastive explanations.
-        Note that this assumes that the classifier was trained with inputs normalized in [-0.5,0.5] range.
-
-        Args:
-            input_X (numpy.ndarray): input instance to be explained
-            arg_mode (str): 'PP' or 'PN'
-            AE_model: Auto-encoder model
-            arg_kappa (double): Confidence gap between desired class and other classes
-            arg_b (double): Number of different weightings of loss function to try
-            arg_max_iter (int): For each weighting of loss function number of iterations to search
-            arg_init_const (double): Initial weighting of loss function
-            arg_beta (double): Weighting of L1 loss
-            arg_gamma (double): Weighting of auto-encoder
-
-        Returns:
-            tuple:
-                * **adv_X** (`numpy ndarray`) -- Perturbed input instance for PP/PN
-                * **delta_X** (`numpy ndarray`) -- Difference between input and Perturbed instance
-                * **INFO** (`str`) -- Other information about PP/PN
-        """
-
-        random.seed(121)
-        np.random.seed(1211)
-
-        (_, orig_class, orig_prob_str) = self._wbmodel.predict_long(input_X)
-        target_label = orig_class
-
-        target = np.array([np.eye(self._wbmodel._nb_classes)[target_label]])
-
-        # Hard coding batch_size=1
-        batch_size = 1
-
-        # Example: for MNIST (1, 28, 28, 1), for tabular (1, no of columns)
-        shape = input_X.shape
-
-        attack = AEADEN(self._wbmodel, shape,
-                        mode=arg_mode, AE=AE_model, batch_size=batch_size,
-                        kappa=arg_kappa, init_learning_rate=1e-2,
-                        binary_search_steps=arg_b, max_iterations=arg_max_iter,
-                        initial_const=arg_init_const, beta=arg_beta, gamma=arg_gamma)
-
-        adv_X = attack.attack(input_X, target)
-
-        adv_prob, adv_class, adv_prob_str = self._wbmodel.predict_long(adv_X)
-
-        delta_X = input_X - adv_X
-
-        _, delta_class, delta_prob_str = self._wbmodel.predict_long(delta_X)
-
-        INFO = "[INFO]kappa:{}, Orig class:{}, Perturbed class:{}, Delta class: {}, Orig prob:{}, Perturbed prob:{}, Delta prob:{}".format(
-            arg_kappa, orig_class, adv_class, delta_class, orig_prob_str, adv_prob_str, delta_prob_str)
-
-        return (adv_X, delta_X, INFO)
+from __future__ import print_function
+
+from aix360.algorithms.lwbe import LocalWBExplainer
+
+from .CEM_aen import AEADEN
+
+import random
+import numpy as np
+
+
+class CEMExplainer(LocalWBExplainer):
+    """
+    CEMExplainer can be used to compute contrastive explanations for image and tabular data.
+    This is achieved by finding what is minimally sufficient (PP - Pertinent Positive) and
+    what should be necessarily absent (PN - Pertinent Negative) to maintain the original classification.
+    We use elastic norm regularization to ensure minimality for both parts of the explanation,
+    i.e. PPs and PNs. An autoencoder can optionally be used to make the explanations more realistic. [#]_
+
+    References:
+        .. [#] `Amit Dhurandhar, Pin-Yu Chen, Ronny Luss, Chun-Chen Tu,
+           Paishun Ting, Karthikeyan Shanmugam, Payel Das, "Explanations based on
+           the Missing: Towards Contrastive Explanations with Pertinent Negatives,"
+           Advances in Neural Information Processing Systems (NeurIPS), 2018.
+           <https://arxiv.org/abs/1802.07623>`_
+    """
+    def __init__(self, model):
+
+        """
+        Constructor method, initializes the explainer
+
+        Args:
+            model: KerasClassifier model whose predictions need to be explained
+        """
+        super(CEMExplainer, self).__init__()
+        self._wbmodel = model
+
+
+    def set_params(self, *argv, **kwargs):
+        """
+        Set parameters for the explainer.
+        """
+        pass
+
+
+    def explain_instance(self, input_X,
+                         arg_mode, AE_model, arg_kappa, arg_b,
+                         arg_max_iter, arg_init_const, arg_beta, arg_gamma, arg_alpha=0, arg_threshold=1, arg_offset=0):
+
+        """
+        Explains an input instance input_X and returns contrastive explanations.
+        Note that this assumes that the classifier was trained with inputs normalized in the [-arg_offset, arg_offset] range, where arg_offset is 0 or 0.5.
+
+        Args:
+            input_X (numpy.ndarray): input instance to be explained
+            arg_mode (str): 'PP' or 'PN'
+            AE_model: Auto-encoder model
+            arg_kappa (double): Confidence gap between desired class and other classes
+            arg_b (double): Number of different weightings of loss function to try
+            arg_max_iter (int): For each weighting of loss function, number of iterations to search
+            arg_init_const (double): Initial weighting of loss function
+            arg_beta (double): Weighting of L1 loss
+            arg_gamma (double): Weighting of auto-encoder
+            arg_alpha (double): Weighting of L2 loss
+            arg_threshold (double): automatically turn off all features less than arg_threshold, since there is nothing to turn off
+            arg_offset (double): input_X is in [0, 1]; we subtract the offset when it is passed to the classifier
+
+        Returns:
+            tuple:
+                * **adv_X** (`numpy ndarray`) -- Perturbed input instance for PP/PN
+                * **delta_X** (`numpy ndarray`) -- Difference between input and perturbed instance
+                * **INFO** (`str`) -- Other information about PP/PN
+        """
+
+        random.seed(121)
+        np.random.seed(1211)
+
+        (_, orig_class, orig_prob_str) = self._wbmodel.predict_long(input_X)
+        target_label = orig_class
+
+        target = np.array([np.eye(self._wbmodel._nb_classes)[target_label]])
+
+        # Hard coding batch_size=1
+        batch_size = 1
+
+        # Example: for MNIST (1, 28, 28, 1), for tabular (1, no of columns)
+        shape = input_X.shape
+
+        attack = AEADEN(self._wbmodel, shape,
+                        mode=arg_mode, AE=AE_model, batch_size=batch_size,
+                        kappa=arg_kappa, init_learning_rate=1e-2,
+                        binary_search_steps=arg_b, max_iterations=arg_max_iter,
+                        initial_const=arg_init_const, beta=arg_beta, gamma=arg_gamma,
+                        alpha=arg_alpha, threshold=arg_threshold, offset=arg_offset)
+
+
+        self._wbmodel.predict(input_X)  # helps compile
+        adv_X = attack.attack(input_X + arg_offset, target)
+
+        adv_prob, adv_class, adv_prob_str = self._wbmodel.predict_long(adv_X)
+
+        delta_X = (input_X + arg_offset) - adv_X - arg_offset  # add arg_offset to the input for the attack, then subtract it to get back to [-arg_offset, arg_offset]
+
+        adv_X = adv_X - arg_offset  # subtract arg_offset to get it back to [-arg_offset, arg_offset]
+
+        _, delta_class, delta_prob_str = self._wbmodel.predict_long(delta_X)
+
+        INFO = "[INFO]kappa:{}, Orig class:{}, Perturbed class:{}, Delta class: {}, Orig prob:{}, Perturbed prob:{}, Delta prob:{}".format(
+            arg_kappa, orig_class, adv_class, delta_class, orig_prob_str, adv_prob_str, delta_prob_str)

+        return (adv_X, delta_X, INFO)
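
A note on the loss being re-weighted in this commit: the elastic-net regularization mentioned in the docstring refers to the objective that AEADEN optimizes. Sketching from the cited NeurIPS paper (the paper fixes the L2 weight to 1; the arg_alpha added here, documented as "Weighting of L2 loss", appears to expose it as a tunable parameter), the pertinent-negative search solves approximately

    \min_{\delta}\; c \cdot \max\Big\{ [\mathrm{Pred}(x_0+\delta)]_{t_0} - \max_{i \neq t_0} [\mathrm{Pred}(x_0+\delta)]_i,\; -\kappa \Big\} + \beta \lVert \delta \rVert_1 + \alpha \lVert \delta \rVert_2^2 + \gamma \lVert x_0 + \delta - \mathrm{AE}(x_0 + \delta) \rVert_2^2

where t_0 is the original class, c is the loss weighting found by binary search (arg_b steps starting from arg_init_const), kappa is arg_kappa, beta is arg_beta, gamma is arg_gamma, and AE is the optional autoencoder. The PP mode uses the analogous hinge term with the roles of t_0 and the competing classes swapped.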

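For orientation, here is a minimal usage sketch of the updated explain_instance signature. The classifier and autoencoder loaders below are hypothetical placeholders (any wrapper exposing the predict, predict_long, and _nb_classes members used above should work), and the hyperparameter values are illustrative only, not recommendations from this commit:

    import numpy as np
    from aix360.algorithms.contrastive import CEMExplainer

    # Hypothetical helpers -- not part of this commit or of aix360;
    # substitute your own trained Keras classifier wrapper and autoencoder.
    mymodel = load_my_mnist_classifier()
    ae_model = load_my_mnist_autoencoder()

    explainer = CEMExplainer(mymodel)
    input_X = np.random.rand(1, 28, 28, 1)  # stand-in for one MNIST image in [0, 1]

    # Pertinent negative, exercising the three parameters this commit introduces.
    (adv_X, delta_X, info) = explainer.explain_instance(
        input_X, arg_mode='PN', AE_model=ae_model,
        arg_kappa=0.9, arg_b=9, arg_max_iter=1000,
        arg_init_const=10.0, arg_beta=0.1, arg_gamma=100.0,
        arg_alpha=0.0,       # new: weight on the L2 term
        arg_threshold=1.0,   # new: features below this value are left untouched
        arg_offset=0.5)      # new: shift between the classifier's range and the attack's space
    print(info)

Note that the new delta computation, (input_X + arg_offset) - adv_X - arg_offset, is numerically just input_X minus the shifted-back adv_X; the explicit shifts document that the attack itself runs in the offset space while both returned arrays are mapped back to the classifier's input range.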