Skip to content

Commit 9285b6a

Browse files
pnpnpnlaurenyu
authored andcommitted
Adding Object2Vec support to SageMaker Python SDK (#467)
1 parent 5201c60 commit 9285b6a

File tree

10 files changed

+1580
-5
lines changed

10 files changed

+1580
-5
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
1.14.1
6+
======
7+
8+
* feature: Estimators: add support for Amazon Object2Vec algorithm
9+
510
1.14.0
611
======
712

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ Amazon SageMaker provides several built-in machine learning algorithms that you
394394
The full list of algorithms is available at: https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html
395395
396396
The SageMaker Python SDK includes estimator wrappers for the AWS K-means, Principal Components Analysis (PCA), Linear Learner, Factorization Machines,
397-
Latent Dirichlet Allocation (LDA), Neural Topic Model (NTM) Random Cut Forest and k-nearest neighbors (k-NN) algorithms.
397+
Latent Dirichlet Allocation (LDA), Neural Topic Model (NTM), Random Cut Forest, k-nearest neighbors (k-NN), and Object2Vec algorithms.
398398
399399
For more information, see `AWS SageMaker Estimators and Models`_.
400400

src/sagemaker/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sagemaker.amazon.randomcutforest import (RandomCutForest, RandomCutForestModel, # noqa: F401
2424
RandomCutForestPredictor)
2525
from sagemaker.amazon.knn import KNN, KNNModel, KNNPredictor # noqa: F401
26+
from sagemaker.amazon.object2vec import Object2Vec, Object2VecModel # noqa: F401
2627

2728
from sagemaker.analytics import TrainingJobAnalytics, HyperparameterTuningJobAnalytics # noqa: F401
2829
from sagemaker.local.local_session import LocalSession # noqa: F401
@@ -35,4 +36,4 @@
3536
from sagemaker.session import s3_input # noqa: F401
3637
from sagemaker.session import get_execution_role # noqa: F401
3738

38-
__version__ = '1.14.0'
39+
__version__ = '1.14.1'

src/sagemaker/amazon/README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Amazon SageMaker provides several built-in machine learning algorithms that you
77

88
The full list of algorithms is available on the AWS website: https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html
99

10-
SageMaker Python SDK includes Estimator wrappers for the AWS K-means, Principal Components Analysis(PCA), Linear Learner, Factorization Machines, Latent Dirichlet Allocation(LDA), Neural Topic Model(NTM), Random Cut Forest algorithms and k-nearest neighbors (k-NN).
10+
SageMaker Python SDK includes Estimator wrappers for the AWS K-means, Principal Components Analysis(PCA), Linear Learner, Factorization Machines, Latent Dirichlet Allocation(LDA), Neural Topic Model(NTM), Random Cut Forest algorithms, k-nearest neighbors (k-NN) and Object2Vec.
1111

1212
Definition and usage
1313
~~~~~~~~~~~~~~~~~~~~

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def registry(region_name, algorithm=None):
280280
https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/amazon
281281
"""
282282
if algorithm in [None, "pca", "kmeans", "linear-learner", "factorization-machines", "ntm",
283-
"randomcutforest", "knn"]:
283+
"randomcutforest", "knn", "object2vec"]:
284284
account_id = {
285285
"us-east-1": "382416733822",
286286
"us-east-2": "404615174143",

src/sagemaker/amazon/object2vec.py

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
16+
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
17+
from sagemaker.amazon.validation import ge, le, isin
18+
from sagemaker.predictor import RealTimePredictor
19+
from sagemaker.model import Model
20+
from sagemaker.session import Session
21+
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
22+
23+
24+
class Object2Vec(AmazonAlgorithmEstimatorBase):
25+
26+
repo_name = 'object2vec'
27+
repo_version = 1
28+
MINI_BATCH_SIZE = 32
29+
30+
enc_dim = hp('enc_dim', (ge(4), le(10000)),
31+
'An integer in [4, 10000]', int)
32+
mini_batch_size = hp('mini_batch_size', (ge(1), le(10000)),
33+
'An integer in [1, 10000]', int)
34+
epochs = hp('epochs', (ge(1), le(100)),
35+
'An integer in [1, 100]', int)
36+
early_stopping_patience = hp('early_stopping_patience', (ge(1), le(5)),
37+
'An integer in [1, 5]', int)
38+
early_stopping_tolerance = hp('early_stopping_tolerance', (ge(1e-06), le(0.1)),
39+
'A float in [1e-06, 0.1]', float)
40+
dropout = hp('dropout', (ge(0.0), le(1.0)),
41+
'A float in [0.0, 1.0]', float)
42+
weight_decay = hp('weight_decay', (ge(0.0), le(10000.0)),
43+
'A float in [0.0, 10000.0]', float)
44+
bucket_width = hp('bucket_width', (ge(0), le(100)),
45+
'An integer in [0, 100]', int)
46+
num_classes = hp('num_classes', (ge(2), le(30)),
47+
'An integer in [2, 30]', int)
48+
mlp_layers = hp('mlp_layers', (ge(1), le(10)),
49+
'An integer in [1, 10]', int)
50+
mlp_dim = hp('mlp_dim', (ge(2), le(10000)),
51+
'An integer in [2, 10000]', int)
52+
mlp_activation = hp('mlp_activation', isin("tanh", "relu", "linear"),
53+
'One of "tanh", "relu", "linear"', str)
54+
output_layer = hp('output_layer', isin("softmax", "mean_squared_error"),
55+
'One of "softmax", "mean_squared_error"', str)
56+
optimizer = hp('optimizer', isin("adagrad", "adam", "rmsprop", "sgd", "adadelta"),
57+
'One of "adagrad", "adam", "rmsprop", "sgd", "adadelta"', str)
58+
learning_rate = hp('learning_rate', (ge(1e-06), le(1.0)),
59+
'A float in [1e-06, 1.0]', float)
60+
enc0_network = hp('enc0_network', isin("hcnn", "bilstm", "pooled_embedding"),
61+
'One of "hcnn", "bilstm", "pooled_embedding"', str)
62+
enc1_network = hp('enc1_network', isin("hcnn", "bilstm", "pooled_embedding", "enc0"),
63+
'One of "hcnn", "bilstm", "pooled_embedding", "enc0"', str)
64+
enc0_cnn_filter_width = hp('enc0_cnn_filter_width', (ge(1), le(9)),
65+
'An integer in [1, 9]', int)
66+
enc1_cnn_filter_width = hp('enc1_cnn_filter_width', (ge(1), le(9)),
67+
'An integer in [1, 9]', int)
68+
enc0_max_seq_len = hp('enc0_max_seq_len', (ge(1), le(5000)),
69+
'An integer in [1, 5000]', int)
70+
enc1_max_seq_len = hp('enc1_max_seq_len', (ge(1), le(5000)),
71+
'An integer in [1, 5000]', int)
72+
enc0_token_embedding_dim = hp('enc0_token_embedding_dim', (ge(2), le(1000)),
73+
'An integer in [2, 1000]', int)
74+
enc1_token_embedding_dim = hp('enc1_token_embedding_dim', (ge(2), le(1000)),
75+
'An integer in [2, 1000]', int)
76+
enc0_vocab_size = hp('enc0_vocab_size', (ge(2), le(3000000)),
77+
'An integer in [2, 3000000]', int)
78+
enc1_vocab_size = hp('enc1_vocab_size', (ge(2), le(3000000)),
79+
'An integer in [2, 3000000]', int)
80+
enc0_layers = hp('enc0_layers', (ge(1), le(4)),
81+
'An integer in [1, 4]', int)
82+
enc1_layers = hp('enc1_layers', (ge(1), le(4)),
83+
'An integer in [1, 4]', int)
84+
enc0_freeze_pretrained_embedding = hp('enc0_freeze_pretrained_embedding', (),
85+
'Either True or False', bool)
86+
enc1_freeze_pretrained_embedding = hp('enc1_freeze_pretrained_embedding', (),
87+
'Either True or False', bool)
88+
89+
def __init__(self, role, train_instance_count, train_instance_type,
90+
epochs,
91+
enc0_max_seq_len,
92+
enc0_vocab_size,
93+
enc_dim=None,
94+
mini_batch_size=None,
95+
early_stopping_patience=None,
96+
early_stopping_tolerance=None,
97+
dropout=None,
98+
weight_decay=None,
99+
bucket_width=None,
100+
num_classes=None,
101+
mlp_layers=None,
102+
mlp_dim=None,
103+
mlp_activation=None,
104+
output_layer=None,
105+
optimizer=None,
106+
learning_rate=None,
107+
enc0_network=None,
108+
enc1_network=None,
109+
enc0_cnn_filter_width=None,
110+
enc1_cnn_filter_width=None,
111+
enc1_max_seq_len=None,
112+
enc0_token_embedding_dim=None,
113+
enc1_token_embedding_dim=None,
114+
enc1_vocab_size=None,
115+
enc0_layers=None,
116+
enc1_layers=None,
117+
enc0_freeze_pretrained_embedding=None,
118+
enc1_freeze_pretrained_embedding=None,
119+
**kwargs):
120+
"""Object2Vec is :class:`Estimator` used for anomaly detection.
121+
122+
This Estimator may be fit via calls to
123+
:meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit`.
124+
There is an utility :meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.record_set` that
125+
can be used to upload data to S3 and creates :class:`~sagemaker.amazon.amazon_estimator.RecordSet` to be passed
126+
to the `fit` call.
127+
128+
After this Estimator is fit, model data is stored in S3. The model may be deployed to an Amazon SageMaker
129+
Endpoint by invoking :meth:`~sagemaker.amazon.estimator.EstimatorBase.deploy`. As well as deploying an
130+
Endpoint, deploy returns a :class:`~sagemaker.amazon.RealTimePredictor` object that can be used
131+
for inference calls using the trained model hosted in the SageMaker Endpoint.
132+
133+
Object2Vec Estimators can be configured by setting hyperparameters. The available hyperparameters for
134+
Object2Vec are documented below.
135+
136+
For further information on the AWS Object2Vec algorithm,
137+
please consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/object2vec.html
138+
139+
Args:
140+
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and
141+
APIs that create Amazon SageMaker endpoints use this role to access
142+
training data and model artifacts. After the endpoint is created,
143+
the inference code might use the IAM role, if accessing AWS resource.
144+
train_instance_count (int): Number of Amazon EC2 instances to use for training.
145+
train_instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
146+
147+
epochs(int): Total number of epochs for SGD training
148+
enc0_max_seq_len(int): Maximum sequence length
149+
enc0_vocab_size(int): Vocabulary size of tokens
150+
151+
enc_dim(int): Optional. Dimension of the output of the embedding layer
152+
mini_batch_size(int): Optional. mini batch size for SGD training
153+
early_stopping_patience(int): Optional. The allowed number of consecutive epochs without improvement
154+
before early stopping is applied
155+
early_stopping_tolerance(float): Optional. The value used to determine whether the algorithm has made
156+
improvement between two consecutive epochs for early stopping
157+
dropout(float): Optional. Dropout probability on network layers
158+
weight_decay(float): Optional. Weight decay parameter during optimization
159+
bucket_width(int): Optional. The allowed difference between data sequence length when bucketing is enabled
160+
num_classes(int): Optional. Number of classes for classification training (ignored for regression problems)
161+
mlp_layers(int): Optional. Number of MLP layers in the network
162+
mlp_dim(int): Optional. Dimension of the output of MLP layer
163+
mlp_activation(str): Optional. Type of activation function for the MLP layer
164+
output_layer(str): Optional. Type of output layer
165+
optimizer(str): Optional. Type of optimizer for training
166+
learning_rate(float): Optional. Learning rate for SGD training
167+
enc0_network(str): Optional. Network model of encoder "enc0"
168+
enc1_network(str): Optional. Network model of encoder "enc1"
169+
enc0_cnn_filter_width(int): Optional. CNN filter width
170+
enc1_cnn_filter_width(int): Optional. CNN filter width
171+
enc1_max_seq_len(int): Optional. Maximum sequence length
172+
enc0_token_embedding_dim(int): Optional. Output dimension of token embedding layer
173+
enc1_token_embedding_dim(int): Optional. Output dimension of token embedding layer
174+
enc1_vocab_size(int): Optional. Vocabulary size of tokens
175+
enc0_layers(int): Optional. Number of layers in encoder
176+
enc1_layers(int): Optional. Number of layers in encoder
177+
enc0_freeze_pretrained_embedding(bool): Optional. Freeze pretrained embedding weights
178+
enc1_freeze_pretrained_embedding(bool): Optional. Freeze pretrained embedding weights
179+
180+
**kwargs: base class keyword argument values.
181+
"""
182+
183+
super(Object2Vec, self).__init__(role, train_instance_count, train_instance_type, **kwargs)
184+
185+
self.enc_dim = enc_dim
186+
self.mini_batch_size = mini_batch_size
187+
self.epochs = epochs
188+
self.early_stopping_patience = early_stopping_patience
189+
self.early_stopping_tolerance = early_stopping_tolerance
190+
self.dropout = dropout
191+
self.weight_decay = weight_decay
192+
self.bucket_width = bucket_width
193+
self.num_classes = num_classes
194+
self.mlp_layers = mlp_layers
195+
self.mlp_dim = mlp_dim
196+
self.mlp_activation = mlp_activation
197+
self.output_layer = output_layer
198+
self.optimizer = optimizer
199+
self.learning_rate = learning_rate
200+
self.enc0_network = enc0_network
201+
self.enc1_network = enc1_network
202+
self.enc0_cnn_filter_width = enc0_cnn_filter_width
203+
self.enc1_cnn_filter_width = enc1_cnn_filter_width
204+
self.enc0_max_seq_len = enc0_max_seq_len
205+
self.enc1_max_seq_len = enc1_max_seq_len
206+
self.enc0_token_embedding_dim = enc0_token_embedding_dim
207+
self.enc1_token_embedding_dim = enc1_token_embedding_dim
208+
self.enc0_vocab_size = enc0_vocab_size
209+
self.enc1_vocab_size = enc1_vocab_size
210+
self.enc0_layers = enc0_layers
211+
self.enc1_layers = enc1_layers
212+
self.enc0_freeze_pretrained_embedding = enc0_freeze_pretrained_embedding
213+
self.enc1_freeze_pretrained_embedding = enc1_freeze_pretrained_embedding
214+
215+
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
216+
"""Return a :class:`~sagemaker.amazon.Object2VecModel` referencing the latest
217+
s3 model data produced by this Estimator.
218+
219+
Args:
220+
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on the model.
221+
Default: use subnets and security groups from this Estimator.
222+
* 'Subnets' (list[str]): List of subnet ids.
223+
* 'SecurityGroupIds' (list[str]): List of security group ids.
224+
"""
225+
return Object2VecModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session,
226+
vpc_config=self.get_vpc_config(vpc_config_override))
227+
228+
def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
229+
if mini_batch_size is None:
230+
mini_batch_size = self.MINI_BATCH_SIZE
231+
232+
super(Object2Vec, self)._prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name)
233+
234+
235+
class Object2VecModel(Model):
236+
"""Reference Object2Vec s3 model data. Calling :meth:`~sagemaker.model.Model.deploy` creates an
237+
Endpoint and returns a Predictor that calculates anomaly scores for datapoints."""
238+
239+
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
240+
sagemaker_session = sagemaker_session or Session()
241+
repo = '{}:{}'.format(Object2Vec.repo_name, Object2Vec.repo_version)
242+
image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name,
243+
Object2Vec.repo_name), repo)
244+
super(Object2VecModel, self).__init__(model_data, image, role,
245+
predictor_cls=RealTimePredictor,
246+
sagemaker_session=sagemaker_session,
247+
**kwargs)

src/sagemaker/tuner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
'linear-learner': 'LinearLearner',
3333
'ntm': 'NTM',
3434
'randomcutforest': 'RandomCutForest',
35-
'knn': 'KNN'
35+
'knn': 'KNN',
36+
'object2vec': 'Object2Vec',
3637
}
3738

3839

0 commit comments

Comments
 (0)