Skip to content

Commit ec8b18a

Browse files
authored
Merge pull request #52 from eli5-org/modernize-types
Modernize more type definitions related to text support
2 parents a683b34 + 9093a37 commit ec8b18a

File tree

4 files changed

+81
-113
lines changed

4 files changed

+81
-113
lines changed

eli5/base.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
@attrs
14-
class Explanation(object):
14+
class Explanation:
1515
""" An explanation for classifier or regressor,
1616
it can either explain weights or a single prediction.
1717
"""
@@ -49,7 +49,7 @@ def _repr_html_(self):
4949

5050

5151
@attrs
52-
class FeatureImportances(object):
52+
class FeatureImportances:
5353
""" Feature importances with number of remaining non-zero features.
5454
"""
5555
def __init__(self, importances, remaining):
@@ -64,7 +64,7 @@ def from_names_values(cls, names, values, std=None, **kwargs):
6464

6565

6666
@attrs
67-
class TargetExplanation(object):
67+
class TargetExplanation:
6868
""" Explanation for a single target or class.
6969
Feature weights are stored in the :feature_weights: attribute,
7070
and features highlighted in text in the :weighted_spans: attribute.
@@ -92,7 +92,7 @@ def __init__(self,
9292

9393

9494
@attrs
95-
class FeatureWeights(object):
95+
class FeatureWeights:
9696
""" Weights for top features, :pos: for positive and :neg: for negative,
9797
sorted by descending absolute value.
9898
Number of remaining positive and negative features are stored in
@@ -111,7 +111,7 @@ def __init__(self,
111111

112112

113113
@attrs
114-
class FeatureWeight(object):
114+
class FeatureWeight:
115115
def __init__(self, feature: Feature, weight: float, std: Optional[float] = None, value=None):
116116
self.feature = feature
117117
self.weight = weight
@@ -120,7 +120,7 @@ def __init__(self, feature: Feature, weight: float, std: Optional[float] = None,
120120

121121

122122
@attrs
123-
class WeightedSpans(object):
123+
class WeightedSpans:
124124
""" Holds highlighted spans for parts of document - a DocWeightedSpans
125125
object for each vectorizer, and other features not highlighted anywhere.
126126
"""
@@ -140,7 +140,7 @@ def __init__(self,
140140

141141

142142
@attrs
143-
class DocWeightedSpans(object):
143+
class DocWeightedSpans:
144144
""" Features highlighted in text. :document: is a pre-processed document
145145
before applying the analyzer. :weighted_spans: holds a list of spans
146146
for features found in text (span indices correspond to
@@ -161,15 +161,15 @@ def __init__(self,
161161

162162

163163
@attrs
164-
class TransitionFeatureWeights(object):
164+
class TransitionFeatureWeights:
165165
""" Weights matrix for transition features. """
166166
def __init__(self, class_names: list[str], coef):
167167
self.class_names = class_names
168168
self.coef = coef
169169

170170

171171
@attrs
172-
class TreeInfo(object):
172+
class TreeInfo:
173173
""" Information about the decision tree. :criterion: is the name of
174174
the function to measure the quality of a split, :tree: holds all nodes
175175
of the tree, and :graphviz: is the tree rendered in graphviz .dot format.
@@ -182,7 +182,7 @@ def __init__(self, criterion: str, tree: 'NodeInfo', graphviz: str, is_classific
182182

183183

184184
@attrs
185-
class NodeInfo(object):
185+
class NodeInfo:
186186
""" A node in a binary tree.
187187
Pointers to left and right children are in :left: and :right: attributes.
188188
"""

eli5/formatters/html.py

+37-61
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
# -*- coding: utf-8 -*-
2-
from __future__ import absolute_import
31
from itertools import groupby
4-
from typing import List, Optional, Tuple
2+
from html import escape
3+
from typing import Optional
54

65
import numpy as np
76
from jinja2 import Environment, PackageLoader
@@ -32,16 +31,15 @@
3231
))
3332

3433

35-
def format_as_html(explanation, # type: Explanation
36-
include_styles=True, # type: bool
37-
force_weights=True, # type: bool
34+
def format_as_html(explanation: Explanation,
35+
include_styles=True,
36+
force_weights=True,
3837
show=fields.ALL,
39-
preserve_density=None, # type: Optional[bool]
40-
highlight_spaces=None, # type: Optional[bool]
41-
horizontal_layout=True, # type: bool
42-
show_feature_values=False # type: bool
43-
):
44-
# type: (...) -> str
38+
preserve_density: Optional[bool] = None,
39+
highlight_spaces: Optional[bool] = None,
40+
horizontal_layout=True,
41+
show_feature_values=False,
42+
) -> str:
4543
""" Format explanation as html.
4644
Most styles are inline, but some are included separately in <style> tag,
4745
you can omit them by passing ``include_styles=False`` and call
@@ -130,42 +128,37 @@ def format_as_html(explanation, # type: Explanation
130128
'''.replace('\n', ' ')
131129

132130

133-
def format_html_styles():
134-
# type: () -> str
131+
def format_html_styles() -> str:
135132
""" Format just the styles,
136133
use with ``format_as_html(explanation, include_styles=False)``.
137134
"""
138135
return template_env.get_template('styles.html').render()
139136

140137

141138
def render_targets_weighted_spans(
142-
targets, # type: List[TargetExplanation]
143-
preserve_density, # type: Optional[bool]
144-
):
145-
# type: (...) -> List[Optional[str]]
139+
targets: list[TargetExplanation],
140+
preserve_density: Optional[bool],
141+
) -> list[Optional[str]]:
146142
""" Return a list of rendered weighted spans for targets.
147143
Function must accept a list in order to select consistent weight
148144
ranges across all targets.
149145
"""
150146
prepared_weighted_spans = prepare_weighted_spans(
151147
targets, preserve_density)
152148

153-
def _fmt_pws(pws):
154-
# type: (PreparedWeightedSpans) -> str
149+
def _fmt_pws(pws: PreparedWeightedSpans) -> str:
155150
name = ('<b>{}:</b> '.format(pws.doc_weighted_spans.vec_name)
156151
if pws.doc_weighted_spans.vec_name else '')
157152
return '{}{}'.format(name, render_weighted_spans(pws))
158153

159-
def _fmt_pws_list(pws_lst):
160-
# type: (List[PreparedWeightedSpans]) -> str
154+
def _fmt_pws_list(pws_lst: list[PreparedWeightedSpans]) -> str:
161155
return '<br/>'.join(_fmt_pws(pws) for pws in pws_lst)
162156

163157
return [_fmt_pws_list(pws_lst) if pws_lst else None
164158
for pws_lst in prepared_weighted_spans]
165159

166160

167-
def render_weighted_spans(pws):
168-
# type: (PreparedWeightedSpans) -> str
161+
def render_weighted_spans(pws: PreparedWeightedSpans) -> str:
169162
# TODO - for longer documents, an option to remove text
170163
# without active features
171164
return ''.join(
@@ -177,11 +170,10 @@ def render_weighted_spans(pws):
177170
key=lambda x: x[1]))
178171

179172

180-
def _colorize(token, # type: str
181-
weight, # type: float
182-
weight_range, # type: float
183-
):
184-
# type: (...) -> str
173+
def _colorize(token: str,
174+
weight: float,
175+
weight_range: float,
176+
) -> str:
185177
""" Return token wrapped in a span with some styles
186178
(calculated from weight and weight_range) applied.
187179
"""
@@ -208,8 +200,7 @@ def _colorize(token, # type: str
208200
)
209201

210202

211-
def _weight_opacity(weight, weight_range):
212-
# type: (float, float) -> str
203+
def _weight_opacity(weight: float, weight_range: float) -> str:
213204
""" Return opacity value for given weight as a string.
214205
"""
215206
min_opacity = 0.8
@@ -220,11 +211,10 @@ def _weight_opacity(weight, weight_range):
220211
return '{:.2f}'.format(min_opacity + (1 - min_opacity) * rel_weight)
221212

222213

223-
_HSL_COLOR = Tuple[float, float, float]
214+
_HSL_COLOR = tuple[float, float, float]
224215

225216

226-
def weight_color_hsl(weight, weight_range, min_lightness=0.8):
227-
# type: (float, float, float) -> _HSL_COLOR
217+
def weight_color_hsl(weight: float, weight_range: float, min_lightness=0.8) -> _HSL_COLOR:
228218
""" Return HSL color components for given weight,
229219
where the max absolute weight is given by weight_range.
230220
"""
@@ -235,21 +225,18 @@ def weight_color_hsl(weight, weight_range, min_lightness=0.8):
235225
return hue, saturation, lightness
236226

237227

238-
def format_hsl(hsl_color):
239-
# type: (_HSL_COLOR) -> str
228+
def format_hsl(hsl_color: _HSL_COLOR) -> str:
240229
""" Format hsl color as css color string.
241230
"""
242231
hue, saturation, lightness = hsl_color
243232
return 'hsl({}, {:.2%}, {:.2%})'.format(hue, saturation, lightness)
244233

245234

246-
def _hue(weight):
247-
# type: (float) -> float
235+
def _hue(weight: float) -> float:
248236
return 120 if weight > 0 else 0
249237

250238

251-
def get_weight_range(weights):
252-
# type: (FeatureWeights) -> float
239+
def get_weight_range(weights: FeatureWeights) -> float:
253240
""" Max absolute feature for pos and neg weights.
254241
"""
255242
return max_or_0(abs(fw.weight)
@@ -258,11 +245,10 @@ def get_weight_range(weights):
258245

259246

260247
def remaining_weight_color_hsl(
261-
ws, # type: List[FeatureWeight]
262-
weight_range, # type: float
263-
pos_neg, # type: str
264-
):
265-
# type: (...) -> _HSL_COLOR
248+
ws: list[FeatureWeight],
249+
weight_range: float,
250+
pos_neg: str,
251+
) -> _HSL_COLOR:
266252
""" Color for "remaining" row.
267253
Handles a number of edge cases: if there are no weights in ws or weight_range
268254
is zero, assume the worst (most intensive positive or negative color).
@@ -278,8 +264,7 @@ def remaining_weight_color_hsl(
278264
return weight_color_hsl(weight, weight_range)
279265

280266

281-
def _format_unhashed_feature(feature, weight, hl_spaces):
282-
# type: (...) -> str
267+
def _format_unhashed_feature(feature, weight, hl_spaces) -> str:
283268
""" Format unhashed feature: show first (most probable) candidate,
284269
display other candidates in title attribute.
285270
"""
@@ -295,8 +280,7 @@ def _format_unhashed_feature(feature, weight, hl_spaces):
295280
return html
296281

297282

298-
def _format_feature(feature, weight, hl_spaces):
299-
# type: (...) -> str
283+
def _format_feature(feature, weight, hl_spaces) -> str:
300284
""" Format any feature.
301285
"""
302286
if isinstance(feature, FormattedFeatureName):
@@ -308,14 +292,12 @@ def _format_feature(feature, weight, hl_spaces):
308292
return _format_single_feature(feature, weight, hl_spaces=hl_spaces)
309293

310294

311-
def _format_single_feature(feature, weight, hl_spaces):
312-
# type: (str, float, bool) -> str
295+
def _format_single_feature(feature: str, weight: float, hl_spaces: bool) -> str:
313296
feature = html_escape(feature)
314297
if not hl_spaces:
315298
return feature
316299

317-
def replacer(n_spaces, side):
318-
# type: (int, str) -> str
300+
def replacer(n_spaces: int, side: str) -> str:
319301
m = '0.1em'
320302
margins = {'left': (m, 0), 'right': (0, m), 'center': (m, m)}[side]
321303
style = '; '.join([
@@ -331,18 +313,12 @@ def replacer(n_spaces, side):
331313
return replace_spaces(feature, replacer)
332314

333315

334-
def _format_decision_tree(treedict):
335-
# type: (...) -> str
316+
def _format_decision_tree(treedict) -> str:
336317
if treedict.graphviz and _graphviz.is_supported():
337318
return _graphviz.dot2svg(treedict.graphviz)
338319
else:
339320
return tree2text(treedict)
340321

341322

342-
def html_escape(text):
343-
# type: (str) -> str
344-
try:
345-
from html import escape
346-
except ImportError:
347-
from cgi import escape # type: ignore
323+
def html_escape(text) -> str:
348324
return escape(text, quote=True)

eli5/formatters/text_helpers.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from collections import Counter
2-
from typing import List, Optional
2+
from typing import Optional
33

44
import numpy as np
55

6-
from eli5.base import TargetExplanation, WeightedSpans, DocWeightedSpans
6+
from eli5.base import TargetExplanation, DocWeightedSpans
77
from eli5.base_utils import attrs
88
from eli5.utils import max_or_0
99

1010

11-
def get_char_weights(doc_weighted_spans, preserve_density=None):
12-
# type: (DocWeightedSpans, Optional[bool]) -> np.ndarray
11+
def get_char_weights(
12+
doc_weighted_spans: DocWeightedSpans, preserve_density: Optional[bool] = None,
13+
) -> np.ndarray:
1314
""" Return character weights for a text document with highlighted features.
1415
If preserve_density is True, then color for longer fragments will be
1516
less intensive than for shorter fragments, so that "sum" of intensities
@@ -35,11 +36,10 @@ def get_char_weights(doc_weighted_spans, preserve_density=None):
3536
@attrs
3637
class PreparedWeightedSpans(object):
3738
def __init__(self,
38-
doc_weighted_spans, # type: DocWeightedSpans
39-
char_weights, # type: np.ndarray
40-
weight_range, # type: float
39+
doc_weighted_spans: DocWeightedSpans,
40+
char_weights: np.ndarray,
41+
weight_range: float,
4142
):
42-
# type: (...) -> None
4343
self.doc_weighted_spans = doc_weighted_spans
4444
self.char_weights = char_weights
4545
self.weight_range = weight_range
@@ -55,25 +55,24 @@ def __eq__(self, other):
5555
return False
5656

5757

58-
def prepare_weighted_spans(targets, # type: List[TargetExplanation]
59-
preserve_density=None, # type: Optional[bool]
60-
):
61-
# type: (...) -> List[Optional[List[PreparedWeightedSpans]]]
58+
def prepare_weighted_spans(targets: list[TargetExplanation],
59+
preserve_density: Optional[bool] = None,
60+
) -> list[Optional[list[PreparedWeightedSpans]]]:
6261
""" Return weighted spans prepared for rendering.
6362
Calculate a separate weight range for each different weighted
6463
span (for each different index): each target has the same number
6564
of weighted spans.
6665
"""
67-
targets_char_weights = [
66+
targets_char_weights: list[Optional[list[np.ndarray]]] = [
6867
[get_char_weights(ws, preserve_density=preserve_density)
6968
for ws in t.weighted_spans.docs_weighted_spans]
7069
if t.weighted_spans else None
71-
for t in targets] # type: List[Optional[List[np.ndarray]]]
70+
for t in targets]
7271
max_idx = max_or_0(len(ch_w or []) for ch_w in targets_char_weights)
7372

74-
targets_char_weights_not_None = [
73+
targets_char_weights_not_None: list[list[np.ndarray]] = [
7574
cw for cw in targets_char_weights
76-
if cw is not None] # type: List[List[np.ndarray]]
75+
if cw is not None]
7776

7877
spans_weight_ranges = [
7978
max_or_0(

0 commit comments

Comments
 (0)