Describe the bug
Inputting "mean" as a baseline value does not work.
To reproduce
- Include "mean" as the baseline for the target and related time series:
{
  "target_time_series": "mean",
  "related_time_series": "mean"
}
Expected behavior
I would expect the explainer to use the mean of each series as its baseline.
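To illustrate the expected semantics (the exact behavior is an assumption on my part; the variable names below are hypothetical), a "mean" baseline would reduce each series to the mean of its observed values:

```python
import numpy as np

# Hypothetical illustration of a "mean" baseline: the baseline for a
# series is simply the mean of its observed values.
target_ts = np.array([10.0, 12.0, 14.0])
target_baseline = target_ts.mean()  # -> 12.0
```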
Screenshots or logs
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/worker.py", line 830, in main
process()
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/worker.py", line 822, in process
serializer.dump_stream(out_iter, outfile)
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 274, in dump_stream
vs = list(itertools.islice(iterator, batch))
File "/usr/local/lib/python3.9/site-packages/analyzer/analyzers/timeseries/timeseries_asymmetric_shap_analyzer.py", line 251, in explain
instance_explanation=explainer.explain(self._to_explainer_input(row), baseline_config=baseline_config),
File "/usr/local/lib/python3.9/site-packages/explainers/shap/asymmetric_shap/asymmetric_shap.py", line 94, in explain
return self._explain_time_series(input_dataset, baseline_config or TimeSeriesBaselineConfig())
File "/usr/local/lib/python3.9/site-packages/explainers/shap/asymmetric_shap/asymmetric_shap.py", line 114, in _explain_time_series
baseline = TimeSeriesBaselineValues.create_from(
File "/usr/local/lib/python3.9/site-packages/explainers/shap/asymmetric_shap/asymmetric_shap_dataclasses.py", line 195, in create_from
target_ts_baseline = 0.0 if config.target_ts == "zero" else input_dataset.target_ts.mean()
AttributeError: 'list' object has no attribute 'mean'
System information
A description of your system. Please provide:
- SageMaker Python SDK version: 2.229.0
- Framework name (eg. PyTorch) or algorithm (eg. KMeans): N/A (TSX)
- Framework version: N/A (TSX)
- Python version: 3.11 (but TSX appears to run Python 3.9)
- CPU or GPU: CPU
- Custom Docker image (Y/N): N
Additional context
The code appears to call .mean() on a plain Python list, which has no such attribute; numpy.ndarray does. The input series is likely being passed through as a list rather than being converted to an ndarray before the mean is computed.
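A minimal sketch of the suspected failure mode and one possible fix (this is illustrative, not the actual library code): calling .mean() on a plain list raises the same AttributeError seen in the traceback, while coercing to an ndarray first, or using np.mean (which accepts any array-like), works:

```python
import numpy as np

target_ts = [10.0, 12.0, 14.0]  # plain Python list, as the traceback suggests

# Reproduces the reported error: list has no .mean() method.
try:
    target_ts.mean()
except AttributeError as e:
    print(e)  # 'list' object has no attribute 'mean'

# Possible fix: coerce to an ndarray, or call np.mean directly.
baseline = np.asarray(target_ts).mean()
assert baseline == np.mean(target_ts) == 12.0
```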