diff --git a/training/DeepSpeed-Domino/README.md b/training/DeepSpeed-Domino/README.md
index 92f6d1ecc..3c1f4040b 100644
--- a/training/DeepSpeed-Domino/README.md
+++ b/training/DeepSpeed-Domino/README.md
@@ -6,7 +6,7 @@ pip install -r requirements.txt
 ```
 
 ## Prepare the Dataset
-Follow the instructions from [Megatron-DeepSpeed](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/universal_checkpointing#download-and-pre-process-training-dataset) to prepare the training dataset.
+Follow the instructions from [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples_deepspeed/universal_checkpointing#download-and-pre-process-training-dataset) to prepare the training dataset.
 
 ## Execute Domino Training
 
diff --git a/training/DeepSpeed-Domino/domino/arguments.py b/training/DeepSpeed-Domino/domino/arguments.py
index 8bc59223a..7c1938e0b 100644
--- a/training/DeepSpeed-Domino/domino/arguments.py
+++ b/training/DeepSpeed-Domino/domino/arguments.py
@@ -204,9 +204,26 @@ def parse_args():
                        'validation set.')
     parser.add_argument('--log-interval', type=int, default=100,
                        help='Report loss and timing interval.')
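+    # Checkpoint save/load options. The flag names follow Megatron-LM's, since the
+    # reused megatron.checkpointing code reads them from the parsed args.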
+    parser.add_argument('--save', type=str, default=None,
+                       help='Output directory to save checkpoints to.')
+    parser.add_argument('--no-save-optim', action='store_true', default=None,
+                       help='Do not save current optimizer.')
+    parser.add_argument('--no-save-rng', action='store_true', default=None,
+                       help='Do not save current rng state.')
+    parser.add_argument('--load', type=str, default=None,
+                       help='Directory containing a model checkpoint.')
+    parser.add_argument('--no-load-optim', action='store_true', default=None,
+                       help='Do not load optimizer when loading checkpoint.')
+    parser.add_argument('--no-load-rng', action='store_true', default=None,
+                       help='Do not load rng state when loading checkpoint.')
+    parser.add_argument('--exit-on-missing-checkpoint', action='store_true',
+                       help="If '--load' is set, but checkpoint is not found "
+                       "(e.g., path typo), then exit instead of random "
+                       "initialization.")
     parser.add_argument('--save-interval', type=int, default=None,
                        help='Number of iterations between checkpoint saves.')
-    
     args = parser.parse_args()
 
     args.rank = int(os.getenv('RANK', '0'))
diff --git a/training/DeepSpeed-Domino/domino/initialize.py b/training/DeepSpeed-Domino/domino/initialize.py
index 36e0fa1bc..51f2213d2 100644
--- a/training/DeepSpeed-Domino/domino/initialize.py
+++ b/training/DeepSpeed-Domino/domino/initialize.py
@@ -14,6 +14,7 @@
 from domino.modules.fused_bias_gelu import bias_gelu
 
 from megatron import fused_kernels
+import deepspeed
 
 
 def initialize_domino():
@@ -37,6 +38,10 @@ def initialize_domino():
         world_size=args.world_size,
         rank=args.rank
     )
+
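+    # DeepSpeed distributed init; reuses the torch.distributed process group created above.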
+    deepspeed.init_distributed()
+
     mpu.initialize_model_parallel(args.tensor_model_parallel_size)
     seed_ = args.seed
     data_parallel_random_init = False
diff --git a/training/DeepSpeed-Domino/domino/language_model.py b/training/DeepSpeed-Domino/domino/language_model.py
index 2cfb2f9fd..5cbee692f 100644
--- a/training/DeepSpeed-Domino/domino/language_model.py
+++ b/training/DeepSpeed-Domino/domino/language_model.py
@@ -85,6 +85,71 @@ def forward(self, input_ids, position_ids):
         return combined_embeds
 
 
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
+        """For easy load."""
+
+        state_dict_ = {}
+        state_dict_[self._word_embeddings_key] \
+            = self.word_embeddings.state_dict(prefix=prefix,
+                                              keep_vars=keep_vars)
+        if self.add_position_embedding:
+            state_dict_[self._position_embeddings_key] \
+                = self.position_embeddings.state_dict(prefix=prefix,
+                                                      keep_vars=keep_vars)
+        if self.num_tokentypes > 0:
+            state_dict_[self._tokentype_embeddings_key] \
+                = self.tokentype_embeddings.state_dict(prefix=prefix,
+                                                       keep_vars=keep_vars)
+
+        return state_dict_
+
+    def load_state_dict(self, state_dict, strict=True):
+        """Customized load."""
+
+        # Word embedding.
+        if self._word_embeddings_key in state_dict:
+            state_dict_ = state_dict[self._word_embeddings_key]
+        else:
+            # for backward compatibility.
+            state_dict_ = {}
+            for key in state_dict.keys():
+                if 'word_embeddings' in key:
+                    state_dict_[key.split('word_embeddings.')[1]] \
+                        = state_dict[key]
+        self.word_embeddings.load_state_dict(state_dict_, strict=strict)
+
+        # Position embedding.
+        if self.add_position_embedding:
+            if self._position_embeddings_key in state_dict:
+                state_dict_ = state_dict[self._position_embeddings_key]
+            else:
+                # for backward compatibility.
+                state_dict_ = {}
+                for key in state_dict.keys():
+                    if 'position_embeddings' in key:
+                        state_dict_[key.split('position_embeddings.')[1]] \
+                            = state_dict[key]
+            self.position_embeddings.load_state_dict(state_dict_, strict=strict)
+
+        # Tokentype embedding.
+        if self.num_tokentypes > 0:
+            state_dict_ = {}
+            if self._tokentype_embeddings_key in state_dict:
+                state_dict_ = state_dict[self._tokentype_embeddings_key]
+            else:
+                # for backward compatibility.
+                for key in state_dict.keys():
+                    if 'tokentype_embeddings' in key:
+                        state_dict_[key.split('tokentype_embeddings.')[1]] \
+                            = state_dict[key]
+            if len(state_dict_.keys()) > 0:
+                self.tokentype_embeddings.load_state_dict(state_dict_,
+                                                          strict=strict)
+            else:
+                print('***WARNING*** expected tokentype embeddings in the '
+                      'checkpoint but could not find it', flush=True)
+
+
 class RotaryEmbedding(nn.Module):
     def __init__(self, dim, seq_len_interpolation_factor=None):
         super().__init__()
@@ -190,4 +255,91 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
         encoder_output = encoder_output_t
 
         return encoder_output
-        
\ No newline at end of file
+
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
+        """For easy load."""
+
+        state_dict_ = {}
+        if self.pre_process:
+            state_dict_[self._embedding_key] \
+                = self.embedding.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                keep_vars=keep_vars)
+        if self.add_encoder:
+            state_dict_[self._encoder_key] \
+                = self.encoder.state_dict_for_save_checkpoint(prefix=prefix,
+                                                              keep_vars=keep_vars)
+        if self.post_process:
+            if self.add_pooler:
+                state_dict_[self._pooler_key] \
+                    = self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                 keep_vars=keep_vars)
+            if self.untie_embeddings_and_output_weights:
+                state_dict_[self._output_layer_key] \
+                    = self.output_layer.state_dict(prefix=prefix, keep_vars=keep_vars)
+
+        if self.add_decoder:
+            state_dict_[self._decoder_key] \
+                = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
+                                                              keep_vars=keep_vars)
+
+        return state_dict_
+
+    def load_state_dict(self, state_dict, strict=True):
+        """Customized load."""
+
+        # Embedding.
+        if self.pre_process:
+            if self._embedding_key in state_dict:
+                state_dict_ = state_dict[self._embedding_key]
+            else:
+                # for backward compatibility.
+                state_dict_ = {}
+                for key in state_dict.keys():
+                    if '_embeddings' in key:
+                        state_dict_[key] = state_dict[key]
+            self.embedding.load_state_dict(state_dict_, strict=strict)
+
+        # Encoder.
+        if self.add_encoder:
+            if self._encoder_key in state_dict:
+                state_dict_ = state_dict[self._encoder_key]
+            # For backward compatibility.
+            elif 'transformer' in state_dict:
+                state_dict_ = state_dict['transformer']
+            else:
+                # For backward compatibility.
+                state_dict_ = {}
+                for key in state_dict.keys():
+                    if 'transformer.' in key:
+                        state_dict_[key.split('transformer.')[1]] = state_dict[key]
+
+            # For backward compatibility.
+            state_dict_self_attention = {}
+            for key in state_dict_.keys():
+                if '.attention.' in key:
+                    state_dict_self_attention[key.replace(".attention.",
+                        ".self_attention.")] = state_dict_[key]
+                else:
+                    state_dict_self_attention[key] = state_dict_[key]
+            state_dict_ = state_dict_self_attention
+
+            self.encoder.load_state_dict(state_dict_, strict=strict)
+
+        # Pooler.
+        if self.post_process:
+            if self.add_pooler:
+                assert 'pooler' in state_dict, \
+                    'could not find data for pooler in the checkpoint'
+                self.pooler.load_state_dict(state_dict[self._pooler_key],
+                                            strict=strict)
+            if self.untie_embeddings_and_output_weights:
+                assert 'output_layer' in state_dict, \
+                    'could not find data for output_layer in the checkpoint'
+                self.output_layer.load_state_dict(state_dict[self._output_layer_key],
+                                                  strict=strict)
+        # Decoder.
+        if self.add_decoder:
+            assert 'decoder' in state_dict, \
+                'could not find data for decoder in the checkpoint'
+            self.decoder.load_state_dict(state_dict[self._decoder_key],
+                                         strict=strict)
diff --git a/training/DeepSpeed-Domino/domino/modules/module.py b/training/DeepSpeed-Domino/domino/modules/module.py
index b89bbc21f..0f42ca764 100644
--- a/training/DeepSpeed-Domino/domino/modules/module.py
+++ b/training/DeepSpeed-Domino/domino/modules/module.py
@@ -25,6 +25,14 @@ def __init__(self, config=None, share_embeddings_and_output_weights=True):
         self.config = config
         self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
 
+
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
+        """Use this function to override the state dict for
+           saving checkpoints.
+        """
+
+        return self.state_dict(prefix=prefix, keep_vars=keep_vars)
+
     def initialize_word_embeddings(self):
         self.share_embeddings_and_output_weights = True
         return
@@ -74,7 +82,7 @@ def float_conversion(val):
     return conversion_helper(val, float_conversion)
 
 
-class Float16Module(torch.nn.Module):
+class Float16Module(DominoModule):
 
     def __init__(self, module, args):
         super(Float16Module, self).__init__()
@@ -91,3 +100,10 @@ def forward(self, *inputs, **kwargs):
             outputs = float16_to_fp32(outputs)
         return outputs
 
+
+    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
+        """ Retrieve state_dict from the module being wrapped."""
+        return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
+
+    def load_state_dict(self, state_dict, strict=True):
+        self.module.load_state_dict(state_dict, strict=strict)
diff --git a/training/DeepSpeed-Domino/domino/training.py b/training/DeepSpeed-Domino/domino/training.py
index 59e253fcf..142a30cf7 100644
--- a/training/DeepSpeed-Domino/domino/training.py
+++ b/training/DeepSpeed-Domino/domino/training.py
@@ -17,6 +17,9 @@
 from domino.tensor_parallel.data import broadcast_data
 
 
+from megatron.checkpointing import load_checkpoint
+from megatron.checkpointing import save_checkpoint
+
 def is_rank_0():
     # if torch.cuda.current_device() == 0:
     if torch.distributed.get_rank() == 0:
@@ -109,7 +112,11 @@ def setup_model_and_optimizer(base_model,
     optimizer = get_megatron_optimizer(models, no_wd_decay_cond, scale_lr_cond)
     opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
 
-    args.iteration = 0
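+    # Resume from a checkpoint when --load is set; load_checkpoint returns the saved iteration.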
+    if args.load is not None:
+        args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
+    else:
+        args.iteration = 0
 
     return model, optimizer, opt_param_scheduler
 
@@ -297,6 +303,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                         config)
 
         iteration += 1
+
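+        # Periodically save a checkpoint when both --save and --save-interval are set.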
+        if args.save and args.save_interval and \
+           iteration % args.save_interval == 0:
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
+
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
             args.micro_batch_size * get_num_microbatches()
         
diff --git a/training/DeepSpeed-Domino/megatron/checkpointing.py b/training/DeepSpeed-Domino/megatron/checkpointing.py
index e88b58513..5c174624f 100644
--- a/training/DeepSpeed-Domino/megatron/checkpointing.py
+++ b/training/DeepSpeed-Domino/megatron/checkpointing.py
@@ -9,12 +9,12 @@
 
 import torch
 
-from megatron import update_num_microbatches
-from megatron.core import mpu, tensor_parallel
-from .global_vars import get_args
-from .utils import (unwrap_model,
-                    print_rank_0)
-
+# from megatron import update_num_microbatches
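+# Use Domino's parallel state, RNG tracker, and argument helpers in place of the Megatron ones.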
+import domino.parallel_state as mpu
+from domino.tensor_parallel.random import get_cuda_rng_tracker
+from domino.arguments import get_args
+from domino.utils import unwrap_model, print_rank_0
 
 _CHECKPOINT_VERSION = None
 
@@ -194,7 +193,7 @@ def get_rng_state():
         'np_rng_state': np.random.get_state(),
         'torch_rng_state': torch.get_rng_state(),
         'cuda_rng_state': torch.cuda.get_rng_state(),
-        'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}
+        'rng_tracker_states': get_cuda_rng_tracker().get_states()}
 
     rng_state_list = None
     if torch.distributed.is_initialized() and \
@@ -218,6 +217,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
 
     # Only rank zero of the data parallel writes to the disk.
     model = unwrap_model(model)
+    model_module = model.module  # unwrap one more level to the underlying Domino model
+    model = [model_module]       # the checkpoint code below expects a list of model chunks
 
     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))
@@ -241,7 +242,10 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
 
         # Arguments, iteration, and model.
         state_dict = {}
-        state_dict['args'] = args
+        import copy
+        t_args = copy.copy(args)  # shallow copy so we do not mutate the live args namespace
+        t_args.init_method = t_args.output_layer_init_method = None  # these hold unpicklable closures
+        state_dict['args'] = t_args
         state_dict['checkpoint_version'] = 3.0
         state_dict['iteration'] = iteration
         if len(model) == 1:
@@ -503,6 +507,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     load_dir = getattr(args, load_arg)
 
     model = unwrap_model(model)
+    model_module = model.module  # unwrap one more level to the underlying Domino model
+    model = [model_module]       # the loading code below expects a list of model chunks
 
     state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=False)
 
@@ -522,6 +528,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     set_checkpoint_version(state_dict.get('checkpoint_version', 0))
 
     # Set iteration.
+    args.finetune = False  # force resuming the saved iteration; the Megatron finetune path is not used here
     if args.finetune or release:
         iteration = 0
     else:
@@ -544,7 +551,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
         check_checkpoint_args(checkpoint_args)
         args.consumed_train_samples = getattr(checkpoint_args,
                                               'consumed_train_samples', 0)
-        update_num_microbatches(consumed_samples=args.consumed_train_samples)
+        # update_num_microbatches(consumed_samples=args.consumed_train_samples)
         args.consumed_valid_samples = getattr(checkpoint_args,
                                               'consumed_valid_samples', 0)
     else:
@@ -614,7 +621,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                 # Check for empty states array
                 if not rng_state['rng_tracker_states']:
                     raise KeyError
-                tensor_parallel.get_cuda_rng_tracker().set_states(
+                get_cuda_rng_tracker().set_states(
                     rng_state['rng_tracker_states'])
             else:  # backward compatability
                 random.setstate(state_dict['random_rng_state'])
@@ -624,7 +631,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                 # Check for empty states array
                 if not state_dict['rng_tracker_states']:
                     raise KeyError
-                tensor_parallel.get_cuda_rng_tracker().set_states(
+                get_cuda_rng_tracker().set_states(
                     state_dict['rng_tracker_states'])
         except KeyError:
             print_rank_0('Unable to load rng state from checkpoint {}. '
diff --git a/training/DeepSpeed-Domino/megatron/initialize.py b/training/DeepSpeed-Domino/megatron/initialize.py
index 367ba85cb..b8aa6d3e8 100644
--- a/training/DeepSpeed-Domino/megatron/initialize.py
+++ b/training/DeepSpeed-Domino/megatron/initialize.py
@@ -16,7 +16,7 @@
 from megatron import get_tensorboard_writer
 from megatron.core import mpu, tensor_parallel
 from megatron.arguments import parse_args, validate_args
-from megatron.checkpointing import load_args_from_checkpoint
+# from megatron.checkpointing import load_args_from_checkpoint
 from megatron.global_vars import set_global_variables
 from megatron.model.transformer import bias_dropout_add_fused_train
 from megatron.model.fused_bias_gelu import bias_gelu