Skip to content

Commit d02839d

Browse files
author
Christian Meier
committed
allow for -1 in integrated gradients
1 parent 2f36651 commit d02839d

File tree

3 files changed

+77
-23
lines changed

3 files changed

+77
-23
lines changed

pyadlml/dataset/plot/plotly/discrete.py

+75-20
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
from plotly.subplots import make_subplots
44
import pandas as pd
5+
from copy import copy
56

67
from pyadlml.constants import ACTIVITY, END_TIME, START_TIME, STRFTIME_PRECISE
78
from pyadlml.dataset._core.activities import is_activity_df
@@ -232,7 +233,7 @@ def func(row, t, dt):
232233
.rename(columns={'start_time':'start', 'lengths_time':'lengths'})
233234

234235

235-
def _plot_devices_into(fig, X, cat_col_map, row, col, time, dev_order):
236+
def _plot_devices_into(fig, X, cat_col_map, row, col, time, dev_order, init_device_markers=False):
236237
ON = 1
237238
OFF = 0
238239
# TODO refactor, generalize also for numerical labels
@@ -246,18 +247,25 @@ def _plot_devices_into(fig, X, cat_col_map, row, col, time, dev_order):
246247

247248
trace_lst = []
248249

249-
devs = X.columns.to_list() if dev_order is None else dev_order
250+
devs = X.columns.to_list() if dev_order is None else copy(dev_order)
250251
devs = devs.tolist() if isinstance(devs, np.ndarray) else devs
251252
devs_num = []
252253

254+
if init_device_markers:
255+
fig = _plot_selected_device_marker(fig, time, init=True)
256+
devs_num.append('sel_dev')
257+
dev_order.insert(0, 'sel_dev')
258+
259+
253260
for dev in devs:
254261

255-
# Check if 0-1-vector numeric
262+
256263
vals = set(X[dev].unique())
257-
is_01_dev = (vals == set([0, 1]) or vals == set([0]) or vals == set([1]))
258-
if is_01_dev:
259-
260-
for k in [0, 1]:
264+
is_binary_dev = (vals == set([0, 1]) or vals == set([0]) or vals == set([1]))\
265+
or (vals == set([-1, 1]) or vals == set([-1]) or vals == set([1]))
266+
if is_binary_dev:
267+
zero_value = list(vals - {1})[0]
268+
for k in [zero_value, 1]:
261269
if time is None:
262270
hover_template = '<b>' + dev + '</b><br>'\
263271
+ 'Int: [%{base}, %{x})<br>' \
@@ -325,7 +333,6 @@ def _plot_devices_into(fig, X, cat_col_map, row, col, time, dev_order):
325333
fig.add_trace(trace, row=row, col=col, secondary_y=True)
326334

327335

328-
329336
fig.update_layout({'barmode': 'overlay',
330337
'legend': {'tracegroupgap': 0}
331338
})
@@ -337,15 +344,15 @@ def _plot_devices_into(fig, X, cat_col_map, row, col, time, dev_order):
337344
for dev in devs_num:
338345
# Rescale numerical devices and position them at their index
339346
# Since data is scaled to 0.5 add 0.25 to center around dev_idx
340-
y = list(fig.select_traces(selector=dict(name=dev, type='scatter')))[0]['y']
347+
y = np.array((list(fig.select_traces(selector=dict(name=dev, type='scatter')))[0]['y']))
341348
y = y + dev_order.index(dev) + 0.25
342349
fig.update_traces(dict(y=y), selector=dict(type='scatter', name=dev))
343350

344351
if devs_num:
345352
axis_name = "yaxis3" if row == 2 else "yaxis2"
346353
yaxis = fig['layout'][axis_name].overlaying
347354
fig.update_yaxes(row=row, col=col, secondary_y=True,
348-
range=[0, len(devs) + 2],
355+
range=[0, len(devs) + 2 + int(init_device_markers)],
349356
scaleanchor=yaxis, # Linking the secondary y-axis to the primary y-axis
350357
scaleratio=1, # Ensuring equal scaling for both y-axes
351358
constrain='domain', # Coupling the secondary y-axis to the primary y-axis
@@ -359,7 +366,7 @@ def _plot_devices_into(fig, X, cat_col_map, row, col, time, dev_order):
359366

360367

361368
@remove_whitespace_around_fig
362-
def acts_and_devs(X, y_true=None, y_pred=None, y_conf=None, act_order=None, dev_order=None, times=None, heatmap:go.Heatmap=None):
369+
def acts_and_devs(X, y_true=None, y_pred=None, y_conf=None, act_order=None, dev_order=None, times=None, heatmap:go.Heatmap=None, sel_device=False):
363370
""" Plot activities and devices for already
364371
365372
"""
@@ -374,9 +381,10 @@ def acts_and_devs(X, y_true=None, y_pred=None, y_conf=None, act_order=None, dev_
374381
raise
375382
#N = y_true.shape[0]
376383
#assert y_pred.shape[0] == N and y_conf.shape[0] == N and times.shape[0] == N and X.shape[0] == N
384+
times = pd.Series(times) if isinstance(times, np.ndarray) else times
385+
377386

378387

379-
times = pd.Series(times) if isinstance(times, np.ndarray) else times
380388
if dev_order is None:
381389
device_order = X.columns.to_list()
382390
elif isinstance(dev_order, np.ndarray):
@@ -417,8 +425,10 @@ def acts_and_devs(X, y_true=None, y_pred=None, y_conf=None, act_order=None, dev_
417425
heatmap_row = 1 if y_conf is None else 2
418426
fig.add_trace(heatmap, row=heatmap_row, col=1, secondary_y=False)
419427

420-
fig = _plot_devices_into(fig, X, cat_col_map, row=rows, col=cols, time=times, dev_order=device_order)
421-
428+
fig = _plot_devices_into(fig, X, cat_col_map, row=rows, col=cols,
429+
time=times, dev_order=device_order,
430+
init_device_markers=sel_device
431+
)
422432

423433
if y_true is not None:
424434
fig = _plot_activities_into(fig, y_true, 'y_true', cat_col_map, row=rows, col=cols, time=times, activities=act_order)
@@ -428,13 +438,14 @@ def acts_and_devs(X, y_true=None, y_pred=None, y_conf=None, act_order=None, dev_
428438
fig = _plot_confidences_into(fig, 1, 1, y_conf, act_order,
429439
cat_col_map, times)
430440

441+
431442
# Update y-axis tickfont
432-
fig.update_yaxes(row=rows, col=cols,
433-
tickfont=dict(size=8, family='Arial'),
434-
categoryorder='array',
435-
categoryarray=device_order + ['y_pred', 'y_true'],
436-
secondary_y=True,
437-
)
443+
#fig.update_yaxes(row=rows, col=cols,
444+
# tickfont=dict(size=8, family='Arial'),
445+
# categoryorder='array',
446+
# categoryarray=['sel_dev'] + device_order + ['y_pred', 'y_true'],
447+
# secondary_y=True,
448+
#)
438449
fig.update_yaxes(row=rows, col=cols, secondary_y=False,
439450
tickmode='linear',
440451
dtick=1,
@@ -443,4 +454,48 @@ def acts_and_devs(X, y_true=None, y_pred=None, y_conf=None, act_order=None, dev_
443454
if times is not None:
444455
fig.update_xaxes(range=[times.iloc[0], times.iloc[-1]])
445456

457+
458+
return fig
459+
460+
461+
def _plot_selected_device_marker(fig, times, init=False):
462+
from pyadlml.constants import TIME
463+
row = 2
464+
col = 1
465+
label= 'sel_dev'
466+
467+
if init:
468+
times = times[:3]
469+
470+
df = pd.DataFrame({TIME: times, 'y':[label]*len(times)})
471+
hover_template = 'Time: %{x|' + STRFTIME_PRECISE + '}<br>' + \
472+
'<extra></extra>'
473+
marker = dict(size=5, symbol=5, line=dict(color='Red', width=1))
474+
trace = go.Scatter(
475+
name=label,
476+
meta=label,
477+
mode='markers',
478+
y=[0.25]*len(df[TIME]),
479+
x=df[TIME],
480+
marker=marker,
481+
opacity=0.0 if init else 1.0,
482+
hovertemplate=hover_template,
483+
showlegend=False)
484+
485+
# Add dummy bar plot as placeholder
486+
fig.add_trace(go.Bar(
487+
name='-1',
488+
meta=label,
489+
base=times.loc[:2],
490+
x=[10, 30],
491+
y=np.array([label]*2),
492+
orientation='h',
493+
showlegend=False,
494+
width=0.3,
495+
textposition='auto',
496+
opacity=0.0,
497+
marker_color='black',
498+
), row=row, col=col)
499+
500+
fig.add_trace(trace, row=row, col=col, secondary_y=True)
446501
return fig

pyadlml/dataset/plot/plotly/util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ class CatColMap():
281281
cat_idx =0
282282

283283
cat_col_map = {
284-
0:COL_OFF, 1:COL_ON,
284+
0:COL_OFF, -1:COL_OFF, 1:COL_ON,
285285
'off':COL_OFF, 'on':COL_ON,
286286
False:COL_OFF, True:COL_ON,
287287
}

pyadlml/metrics.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,6 @@ def transition_accuracy(y_true: pd.DataFrame, y_pred, y_times, eps=0.8, lag='10s
585585
lag_greater_than_min_act = ((y_true[END_TIME] - y_true[START_TIME]) > lag).all(), error_msg
586586

587587
assert 0 <= eps and eps <= 1, 'Epsilon should be in range [0, 1]'
588-
assert average in ['micro', 'macro']
589588

590589
lag = pd.Timedelta(lag)
591590
y_true, y_pred, y_times = y_true.copy(), y_pred.copy(), y_times.copy()
@@ -680,7 +679,7 @@ def transition_accuracy(y_true: pd.DataFrame, y_pred, y_times, eps=0.8, lag='10s
680679
trans_end_viol_tmp = trans_end_viol[dur_trans_end_viol != trans_end_viol]
681680
dur_trans_end_viol = dur_trans_end_viol[dur_trans_end_viol != trans_end_viol]
682681
trans_end_viol = trans_end_viol_tmp
683-
assert not trans_start_viol and not trans_end_viol
682+
assert trans_start_viol.empty and trans_end_viol.empty
684683

685684

686685
df_trans_start = df.groupby('trans_start')['correct_time'].sum().iloc[1:]

0 commit comments

Comments
 (0)