Describe the bug
DeserializerWrapper overwrites the content_type provided to deserialize() with the value of Accept. This DeserializerWrapper is used for both input and output deserialization when using SchemaBuilder. The comments mention:
We need to overwrite the accept type because model servers like XGBOOST always returns "text/html"
but this effectively prevents endpoints deployed via ModelBuilder and SchemaBuilder from supporting additional Content-Types, despite the various implementations of BaseDeserializer supporting multiple Content-Types.
For example, an inference endpoint deployed via ModelBuilder and SchemaBuilder that takes a np.array as input cannot be invoked with a Content-Type such as application/json or text/csv because of this overwriting. This also makes the developer experience confusing: the stack trace shows the execution path for the application/x-npy Content-Type even though a different Content-Type was explicitly provided.
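For context, based on the traceback below, the wrapper's deserialize() effectively behaves like the following minimal sketch (the _accept attribute name is my assumption, not the actual field name in the SDK):

class DeserializerWrapper:
    """Minimal sketch of the behavior described above (not the actual source)."""

    def __init__(self, deserializer, accept):
        self._deserializer = deserializer
        self._accept = accept  # e.g. "application/x-npy"

    def deserialize(self, stream, content_type: str = None):
        """Deserialize stream into object"""
        # The caller's content_type argument is dropped and the configured
        # Accept value is passed to the underlying deserializer instead.
        return self._deserializer.deserialize(stream, self._accept)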
To reproduce
import io

import numpy as np
import sagemaker

schema_builder = sagemaker.serve.SchemaBuilder(
    sample_input=np.array([[6.4, 2.8, 5.6, 2.2]]),
    sample_output=np.array(
        [[0.09394703, 0.4797692, 0.42628378]], dtype=np.float32
    ),
)

# similar usage to that defined in unit tests for base NumpyDeserializer: https://github.com/aws/sagemaker-python-sdk/blob/master/tests/unit/sagemaker/deserializers/test_deserializers.py#L113
schema_builder.input_deserializer.deserialize(
    io.BytesIO(b"[[6.4, 2.8, 5.6, 2.2]]"), content_type="application/json"
)

# Below is not needed to reproduce, but is provided as an example of how I am using SchemaBuilder
model_builder = sagemaker.serve.ModelBuilder(
    mode=sagemaker.serve.mode.function_pointers.Mode.SAGEMAKER_ENDPOINT,
    model_metadata={
        "MLFLOW_MODEL_PATH": f"models:/{model_name}/{model_version}",
        "MLFLOW_TRACKING_ARN": TRACKING_SERVER_ARN,
    },
    model_server=sagemaker.serve.ModelServer.TENSORFLOW_SERVING,
    role_arn=sagemaker.get_execution_role(),
    schema_builder=schema_builder,
)
model = model_builder.build()
model.deploy(
    initial_instance_count=1,
    instance_type="ml.m5.xlarge",
)
Expected behavior
I would like the content_type to not be overwritten for input deserialization, so that I can use SchemaBuilder for inference endpoints while providing a Content-Type other than application/x-npy, such as application/json, via cURL.
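For illustration, this is roughly how I would expect to invoke the deployed endpoint from a generic client; the endpoint name is a placeholder, and this is the call path that currently fails because the supplied Content-Type is discarded:

import json

import boto3

runtime = boto3.client("sagemaker-runtime")

# "my-endpoint" is a placeholder for the endpoint created by model.deploy()
response = runtime.invoke_endpoint(
    EndpointName="my-endpoint",
    ContentType="application/json",  # currently ignored by the input deserializer
    Body=json.dumps([[6.4, 2.8, 5.6, 2.2]]),
)
print(response["Body"].read())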
Screenshots or logs
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /Users/dillodon/.venv/lib/python3.12/site-packages/sagemaker/base_deserializers.py:231 in │
│ deserialize │
│ │
│ 228 │ │ │ │ return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self │
│ 229 │ │ │ if content_type == "application/x-npy": │
│ 230 │ │ │ │ try: │
│ ❱ 231 │ │ │ │ │ return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pi │
│ 232 │ │ │ │ except ValueError as ve: │
│ 233 │ │ │ │ │ raise ValueError( │
│ 234 │ │ │ │ │ │ "Please set the param allow_pickle=True \ │
│ │
│ /Users/dillodon/.venv/lib/python3.12/site-packages/numpy/lib/npyio.py:462 in load │
│ │
│ 459 │ │ else: │
│ 460 │ │ │ # Try a pickle │
│ 461 │ │ │ if not allow_pickle: │
│ ❱ 462 │ │ │ │ raise ValueError("Cannot load file containing pickled data " │
│ 463 │ │ │ │ │ │ │ │ "when allow_pickle=False") │
│ 464 │ │ │ try: │
│ 465 │ │ │ │ return pickle.load(fid, **pickle_kwargs) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Cannot load file containing pickled data when allow_pickle=False
During handling of the above exception, another exception occurred:
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:1 │
│ │
│ /Users/dillodon/.venv/lib/python3.12/site-packages/sagemaker/serve/builder/schema_builder.py:77 │
│ in deserialize │
│ │
│ 74 │ │
│ 75 │ def deserialize(self, stream, content_type: str = None): │
│ 76 │ │ """Deserialize stream into object""" │
│ ❱ 77 │ │ return self._deserializer.deserialize( │
│ 78 │ │ │ stream, │
│ 79 │ │ │ # We need to overwrite the accept type because model │
│ 80 │ │ │ # servers like XGBOOST always returns "text/html" │
│ │
│ /Users/dillodon/.venv/lib/python3.12/site-packages/sagemaker/base_deserializers.py:233 in │
│ deserialize │
│ │
│ 230 │ │ │ │ try: │
│ 231 │ │ │ │ │ return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pi │
│ 232 │ │ │ │ except ValueError as ve: │
│ ❱ 233 │ │ │ │ │ raise ValueError( │
│ 234 │ │ │ │ │ │ "Please set the param allow_pickle=True \ │
│ 235 │ │ │ │ │ │ to deserialize pickle objects in NumpyDeserializer" │
│ 236 │ │ │ │ │ ).with_traceback(ve.__traceback__) │
│ │
│ /Users/dillodon/.venv/lib/python3.12/site-packages/sagemaker/base_deserializers.py:231 in │
│ deserialize │
│ │
│ 228 │ │ │ │ return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self │
│ 229 │ │ │ if content_type == "application/x-npy": │
│ 230 │ │ │ │ try: │
│ ❱ 231 │ │ │ │ │ return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pi │
│ 232 │ │ │ │ except ValueError as ve: │
│ 233 │ │ │ │ │ raise ValueError( │
│ 234 │ │ │ │ │ │ "Please set the param allow_pickle=True \ │
│ │
│ /Users/dillodon/.venv/lib/python3.12/site-packages/numpy/lib/npyio.py:462 in load │
│ │
│ 459 │ │ else: │
│ 460 │ │ │ # Try a pickle │
│ 461 │ │ │ if not allow_pickle: │
│ ❱ 462 │ │ │ │ raise ValueError("Cannot load file containing pickled data " │
│ 463 │ │ │ │ │ │ │ │ "when allow_pickle=False") │
│ 464 │ │ │ try: │
│ 465 │ │ │ │ return pickle.load(fid, **pickle_kwargs) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Please set the param allow_pickle=True to deserialize pickle objects in NumpyDeserializer
System information
- SageMaker Python SDK version: 2.237.3
- Framework name (e.g. PyTorch) or algorithm (e.g. KMeans): TensorFlow (but applies to all)
- Framework version: 2.16 (applies to all)
- Python version: 3.12.7
- CPU or GPU: CPU
- Custom Docker image (Y/N): N
Additional context
model.deploy() returns an instance of Predictor, which does work properly. However, this limits inference to Python clients that have created a Predictor instance; clients in other languages, for example, are unable to invoke the inference endpoint.
A workaround may be possible by implementing CustomPayloadTranslator and providing it via input_translator, but I have not yet tested this.