Skip to content

Commit fa85315

Browse files
committed
fixed issue #5
1 parent 020ca85 commit fa85315

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

fasttreeshap/explainers/_tree.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def shap_values(self, X, y=None, tree_limit=None, approximate=False, check_addit
361361
if self.algorithm == "auto":
362362
# check if number of samples to be explained is sufficiently large
363363
num_samples = X.shape[0]
364-
num_samples_threshold = 2**(self.model.max_depth + 1) / self.model.max_depth
364+
num_samples_threshold = 2**int(self.model.max_depth + 1) / self.model.max_depth
365365
num_samples_check = (num_samples >= num_samples_threshold)
366366
# check if memory constraint is satisfied (check Section Notes in README.md for justifications of memory check conditions in function _memory_check)
367367
memory_check_1, memory_check_2 = self._memory_check(X)
@@ -481,7 +481,7 @@ def shap_values(self, X, y=None, tree_limit=None, approximate=False, check_addit
481481
# check if memory constraint is satisfied (check Section Notes in README.md for justifications of memory check conditions in this function)
482482
def _memory_check(self, X):
483483
max_leaves = (max(self.model.num_nodes) + 1) / 2
484-
max_combinations = 2**self.model.max_depth
484+
max_combinations = 2**int(self.model.max_depth)
485485
phi_dim = X.shape[0] * (X.shape[1] + 1) * self.model.num_outputs
486486
memory_usage_1 = (max_leaves * max_combinations + phi_dim) * 8 * self.n_jobs
487487
memory_usage_2 = max_leaves * max_combinations * self.model.values.shape[0] * 8

notebooks/FastTreeSHAP_Census_Income.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@
133133
"def memory_estimate_v2(shap_explainer, num_sample, num_feature, n_jobs):\n",
134134
" max_node = max(shap_explainer.model.num_nodes)\n",
135135
" max_leaves = (max_node + 1) // 2\n",
136-
" max_combinations = 2**shap_explainer.model.max_depth\n",
136+
" max_combinations = 2**int(shap_explainer.model.max_depth)\n",
137137
" phi_dim = num_sample * (num_feature + 1) * shap_explainer.model.num_outputs\n",
138138
" n_jobs = os.cpu_count() if n_jobs == -1 else n_jobs\n",
139139
" memory_1 = (max_leaves * max_combinations + phi_dim) * 8 * n_jobs\n",

notebooks/FastTreeSHAP_Crop_Mapping.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@
140140
"def memory_estimate_v2(shap_explainer, num_sample, num_feature, n_jobs):\n",
141141
" max_node = max(shap_explainer.model.num_nodes)\n",
142142
" max_leaves = (max_node + 1) // 2\n",
143-
" max_combinations = 2**shap_explainer.model.max_depth\n",
143+
" max_combinations = 2**int(shap_explainer.model.max_depth)\n",
144144
" phi_dim = num_sample * (num_feature + 1) * shap_explainer.model.num_outputs\n",
145145
" n_jobs = os.cpu_count() if n_jobs == -1 else n_jobs\n",
146146
" memory_1 = (max_leaves * max_combinations + phi_dim) * 8 * n_jobs\n",

notebooks/FastTreeSHAP_Superconductor.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@
116116
"def memory_estimate_v2(shap_explainer, num_sample, num_feature, n_jobs):\n",
117117
" max_node = max(shap_explainer.model.num_nodes)\n",
118118
" max_leaves = (max_node + 1) // 2\n",
119-
" max_combinations = 2**shap_explainer.model.max_depth\n",
119+
" max_combinations = 2**int(shap_explainer.model.max_depth)\n",
120120
" phi_dim = num_sample * (num_feature + 1) * shap_explainer.model.num_outputs\n",
121121
" n_jobs = os.cpu_count() if n_jobs == -1 else n_jobs\n",
122122
" memory_1 = (max_leaves * max_combinations + phi_dim) * 8 * n_jobs\n",

0 commit comments

Comments
 (0)