|
1 |
| -from __future__ import print_function |
2 |
| - |
3 |
| -from aix360.algorithms.lwbe import LocalWBExplainer |
4 |
| - |
5 |
| -from .CEM_aen import AEADEN |
6 |
| - |
7 |
| -import random |
8 |
| -import numpy as np |
9 |
| - |
10 |
| - |
11 |
class CEMExplainer(LocalWBExplainer):
    """
    Contrastive Explanations Method (CEM) explainer for image and tabular data.

    A contrastive explanation describes what is minimally sufficient
    (PP - Pertinent Positive) and what should be necessarily absent
    (PN - Pertinent Negative) to maintain the original classification.
    Elastic norm regularization is used to ensure minimality for both parts
    of the explanation, i.e. PPs and PNs. An autoencoder can optionally be
    used to make the explanations more realistic. [#]_

    References:
        .. [#] `Amit Dhurandhar, Pin-Yu Chen, Ronny Luss, Chun-Chen Tu,
           Paishun Ting, Karthikeyan Shanmugam, Payel Das, "Explanations based on
           the Missing: Towards Contrastive Explanations with Pertinent Negatives,"
           Advances in Neural Information Processing Systems (NeurIPS), 2018.
           <https://arxiv.org/abs/1802.07623>`_
    """

    def __init__(self, model):
        """
        Create the explainer around the classifier to be explained.

        Args:
            model: KerasClassifier model whose predictions need to be explained
        """
        super(CEMExplainer, self).__init__()
        self._wbmodel = model

    def set_params(self, *argv, **kwargs):
        """
        Set parameters for the explainer.

        Accepts any positional and keyword arguments and ignores them.
        """

    def explain_instance(self, input_X,
                         arg_mode, AE_model, arg_kappa, arg_b,
                         arg_max_iter, arg_init_const, arg_beta, arg_gamma):
        """
        Explain an input instance input_X and return contrastive explanations.

        Note that this assumes that the classifier was trained with inputs
        normalized in [-0.5,0.5] range. Every call re-seeds both Python's
        ``random`` module and NumPy's global random state with fixed values.

        Args:
            input_X (numpy.ndarray): input instance to be explained
            arg_mode (str): 'PP' or 'PN'
            AE_model: Auto-encoder model
            arg_kappa (double): Confidence gap between desired class and other classes
            arg_b (double): Number of different weightings of loss function to try
            arg_max_iter (int): For each weighting of loss function number of iterations to search
            arg_init_const (double): Initial weighting of loss function
            arg_beta (double): Weighting of L1 loss
            arg_gamma (double): Weighting of auto-encoder

        Returns:
            tuple:
                * **adv_X** (`numpy ndarray`) -- Perturbed input instance for PP/PN
                * **delta_X** (`numpy ndarray`) -- Difference between input and Perturbed instance
                * **INFO** (`str`) -- Other information about PP/PN
        """
        # Fixed seeds: both global random states are reset on every call.
        random.seed(121)
        np.random.seed(1211)

        # The class predicted for the unmodified input is the attack target.
        _, orig_class, orig_prob_str = self._wbmodel.predict_long(input_X)

        # Row of the identity matrix for that class, wrapped in a leading axis.
        class_identity = np.eye(self._wbmodel._nb_classes)
        one_hot_target = np.array([class_identity[orig_class]])

        solver_options = dict(
            mode=arg_mode,
            AE=AE_model,
            batch_size=1,  # batch size is hard coded to 1
            kappa=arg_kappa,
            init_learning_rate=1e-2,
            binary_search_steps=arg_b,
            max_iterations=arg_max_iter,
            initial_const=arg_init_const,
            beta=arg_beta,
            gamma=arg_gamma,
        )
        # Shape example: (1, 28, 28, 1) for MNIST, (1, no of columns) for tabular.
        solver = AEADEN(self._wbmodel, input_X.shape, **solver_options)

        adv_X = solver.attack(input_X, one_hot_target)
        _, adv_class, adv_prob_str = self._wbmodel.predict_long(adv_X)

        # The difference itself is also classified; its class and
        # probabilities are reported in the summary string.
        delta_X = input_X - adv_X
        _, delta_class, delta_prob_str = self._wbmodel.predict_long(delta_X)

        summary_template = "[INFO]kappa:{}, Orig class:{}, Perturbed class:{}, Delta class: {}, Orig prob:{}, Perturbed prob:{}, Delta prob:{}"
        info = summary_template.format(
            arg_kappa, orig_class, adv_class, delta_class,
            orig_prob_str, adv_prob_str, delta_prob_str)

        return adv_X, delta_X, info
| 1 | +from __future__ import print_function |
| 2 | + |
| 3 | +from aix360.algorithms.lwbe import LocalWBExplainer |
| 4 | + |
| 5 | +from .CEM_aen import AEADEN |
| 6 | + |
| 7 | +import random |
| 8 | +import numpy as np |
| 9 | + |
| 10 | + |
class CEMExplainer(LocalWBExplainer):
    """
    Contrastive Explanations Method (CEM) explainer for image and tabular data.

    A contrastive explanation describes what is minimally sufficient
    (PP - Pertinent Positive) and what should be necessarily absent
    (PN - Pertinent Negative) to maintain the original classification.
    Elastic norm regularization is used to ensure minimality for both parts
    of the explanation, i.e. PPs and PNs. An autoencoder can optionally be
    used to make the explanations more realistic. [#]_

    References:
        .. [#] `Amit Dhurandhar, Pin-Yu Chen, Ronny Luss, Chun-Chen Tu,
           Paishun Ting, Karthikeyan Shanmugam, Payel Das, "Explanations based on
           the Missing: Towards Contrastive Explanations with Pertinent Negatives,"
           Advances in Neural Information Processing Systems (NeurIPS), 2018.
           <https://arxiv.org/abs/1802.07623>`_
    """

    def __init__(self, model):
        """
        Create the explainer around the classifier to be explained.

        Args:
            model: KerasClassifier model whose predictions need to be explained
        """
        super(CEMExplainer, self).__init__()
        self._wbmodel = model

    def set_params(self, *argv, **kwargs):
        """
        Set parameters for the explainer.

        Accepts any positional and keyword arguments and ignores them.
        """

    def explain_instance(self, input_X,
                         arg_mode, AE_model, arg_kappa, arg_b,
                         arg_max_iter, arg_init_const, arg_beta, arg_gamma,
                         arg_alpha=0, arg_threshold=1, arg_offset=0):
        """
        Explain an input instance input_X and return contrastive explanations.

        Note that this assumes that the classifier was trained with inputs
        normalized in [-arg_offset, arg_offset] range, where arg_offset is
        0 or .5. Every call re-seeds both Python's ``random`` module and
        NumPy's global random state with fixed values.

        Args:
            input_X (numpy.ndarray): input instance to be explained
            arg_mode (str): 'PP' or 'PN'
            AE_model: Auto-encoder model
            arg_kappa (double): Confidence gap between desired class and other classes
            arg_b (double): Number of different weightings of loss function to try
            arg_max_iter (int): For each weighting of loss function number of iterations to search
            arg_init_const (double): Initial weighting of loss function
            arg_beta (double): Weighting of L1 loss
            arg_gamma (double): Weighting of auto-encoder
            arg_alpha (double): Weighting of L2 loss
            arg_threshold (double): automatically turn off all features less
                than arg_threshold since nothing to turn off
            arg_offset (double): shift between the classifier's input range
                and the range the attack works in; it is added to input_X
                before the attack runs and subtracted from the returned
                instance and difference

        Returns:
            tuple:
                * **adv_X** (`numpy ndarray`) -- Perturbed input instance for PP/PN
                * **delta_X** (`numpy ndarray`) -- Difference between input and Perturbed instance
                * **INFO** (`str`) -- Other information about PP/PN
        """
        # Fixed seeds: both global random states are reset on every call.
        random.seed(121)
        np.random.seed(1211)

        # The class predicted for the unmodified input is the attack target.
        _, orig_class, orig_prob_str = self._wbmodel.predict_long(input_X)

        # Row of the identity matrix for that class, wrapped in a leading axis.
        class_identity = np.eye(self._wbmodel._nb_classes)
        one_hot_target = np.array([class_identity[orig_class]])

        solver_options = dict(
            mode=arg_mode,
            AE=AE_model,
            batch_size=1,  # batch size is hard coded to 1
            kappa=arg_kappa,
            init_learning_rate=1e-2,
            binary_search_steps=arg_b,
            max_iterations=arg_max_iter,
            initial_const=arg_init_const,
            beta=arg_beta,
            gamma=arg_gamma,
            alpha=arg_alpha,
            threshold=arg_threshold,
            offset=arg_offset,
        )
        # Shape example: (1, 28, 28, 1) for MNIST, (1, no of columns) for tabular.
        solver = AEADEN(self._wbmodel, input_X.shape, **solver_options)

        # Result is discarded; the call is kept because it helps compile.
        self._wbmodel.predict(input_X)

        # The attack receives the input shifted by arg_offset.
        shifted_adv = solver.attack(input_X + arg_offset, one_hot_target)

        # NOTE(review): the perturbed class/probabilities are computed on the
        # attack output *before* arg_offset is subtracted -- confirm this is
        # the intended input range for the classifier.
        _, adv_class, adv_prob_str = self._wbmodel.predict_long(shifted_adv)

        # Difference taken in the shifted range, then moved back by
        # arg_offset (add 0.5 for the attack, subtract 0.5 to get back to
        # [-0.5, 0.5]).
        shifted_delta = (input_X + arg_offset) - shifted_adv
        delta_X = shifted_delta - arg_offset

        # Subtract arg_offset to get the perturbed instance back to
        # [-arg_offset, arg_offset].
        adv_X = shifted_adv - arg_offset

        _, delta_class, delta_prob_str = self._wbmodel.predict_long(delta_X)

        summary_template = "[INFO]kappa:{}, Orig class:{}, Perturbed class:{}, Delta class: {}, Orig prob:{}, Perturbed prob:{}, Delta prob:{}"
        info = summary_template.format(
            arg_kappa, orig_class, adv_class, delta_class,
            orig_prob_str, adv_prob_str, delta_prob_str)

        return adv_X, delta_X, info
0 commit comments