Skip to content

Commit 13ad978

Browse files
fix: keep sagemaker_session from being overridden to None (#5021)
* fix: keep sagemaker_session from being overridden to None, add unit/integ tests * remove commented code * fix styling issues --------- Co-authored-by: Zhaoqi <jzhaoqwa@amazon.com>
1 parent c0b740c commit 13ad978

File tree

5 files changed

+60
-2
lines changed

5 files changed

+60
-2
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ def __init__(
150150
if s3_client_config
151151
else boto3.client("s3", region_name=self._region)
152152
)
153-
self._sagemaker_session = sagemaker_session
153+
# Fallback in case a caller overrides sagemaker_session to None
154+
self._sagemaker_session = sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION
154155

155156
def set_region(self, region: str) -> None:
156157
"""Set region for cache. Clears cache after new region is set."""

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ def construct_hub_arn_from_name(
7878
account_id: Optional[str] = None,
7979
) -> str:
8080
"""Constructs a Hub arn from the Hub name using default Session values."""
81+
if session is None:
82+
# session is overridden to none by some callers
83+
session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION
8184

8285
account_id = account_id or session.account_id()
8386
region = region or session.boto_region_name
@@ -211,6 +214,9 @@ def get_hub_model_version(
211214
ClientError: If the specified model is not found in the hub.
212215
KeyError: If the specified model version is not found.
213216
"""
217+
if sagemaker_session is None:
218+
# sagemaker_session is overridden to none by some callers
219+
sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION
214220

215221
try:
216222
hub_content_summaries = sagemaker_session.list_hub_content_versions(

tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,23 @@ def test_jumpstart_hub_model(setup, add_model_references):
8282
assert sagemaker_session.endpoint_in_service_or_not(predictor.endpoint_name)
8383

8484

85+
def test_jumpstart_hub_model_with_default_session(setup, add_model_references):
86+
model_version = "*"
87+
hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
88+
89+
model_id = "catboost-classification-model"
90+
91+
sagemaker_session = get_sm_session()
92+
93+
model = JumpStartModel(model_id=model_id, model_version=model_version, hub_name=hub_name)
94+
95+
predictor = model.deploy(
96+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
97+
)
98+
99+
assert sagemaker_session.endpoint_in_service_or_not(predictor.endpoint_name)
100+
101+
85102
def test_jumpstart_hub_gated_model(setup, add_model_references):
86103

87104
model_id = "meta-textgeneration-llama-3-2-1b"

tests/unit/sagemaker/jumpstart/hub/test_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515
from unittest.mock import patch, Mock
1616
from sagemaker.jumpstart.types import HubArnExtractedInfo
17-
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
17+
from sagemaker.jumpstart.constants import (
18+
JUMPSTART_DEFAULT_REGION_NAME,
19+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
20+
)
1821
from sagemaker.jumpstart.hub import parser_utils, utils
1922

2023

@@ -80,6 +83,17 @@ def test_construct_hub_arn_from_name():
8083
)
8184

8285

86+
def test_construct_hub_arn_from_name_with_session_none():
87+
hub_name = "my-cool-hub"
88+
account_id = DEFAULT_JUMPSTART_SAGEMAKER_SESSION.account_id()
89+
boto_region_name = DEFAULT_JUMPSTART_SAGEMAKER_SESSION.boto_region_name
90+
91+
assert (
92+
utils.construct_hub_arn_from_name(hub_name=hub_name, session=None)
93+
== f"arn:aws:sagemaker:{boto_region_name}:{account_id}:hub/{hub_name}"
94+
)
95+
96+
8397
def test_construct_hub_model_arn_from_inputs():
8498
model_name, version = "pytorch-ic-imagenet-v2", "1.0.2"
8599
hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub"

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sagemaker.jumpstart.cache import (
3030
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
3131
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
32+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3233
JumpStartModelsCache,
3334
)
3435
from sagemaker.jumpstart.constants import (
@@ -57,6 +58,25 @@
5758
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
5859

5960

61+
@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region")
62+
@patch(
63+
"sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket"
64+
)
65+
@patch("boto3.client")
66+
def test_jumpstart_cache_init(mock_boto3_client):
67+
cache = JumpStartModelsCache()
68+
assert cache._region == "dummy-region"
69+
assert cache.s3_bucket_name == "dummy-bucket"
70+
assert cache._manifest_file_s3_key == JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY
71+
assert cache._proprietary_manifest_s3_key == JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY
72+
assert cache._sagemaker_session == DEFAULT_JUMPSTART_SAGEMAKER_SESSION
73+
mock_boto3_client.assert_called_once_with("s3", region_name="dummy-region")
74+
75+
# Some callers override the session to None, should still be set to default
76+
cache = JumpStartModelsCache(sagemaker_session=None)
77+
assert cache._sagemaker_session == DEFAULT_JUMPSTART_SAGEMAKER_SESSION
78+
79+
6080
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function)
6181
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
6282
def test_jumpstart_cache_get_header():

0 commit comments

Comments
 (0)