
[Draft] Add support for seq split in Domino #961


Draft · wants to merge 1 commit into master
5 changes: 5 additions & 0 deletions training/DeepSpeed-Domino/domino/arguments.py
@@ -206,6 +206,8 @@ def parse_args():
                         help='Report loss and timing interval.')
     parser.add_argument('--save-interval', type=int, default=None,
                         help='Number of iterations between checkpoint saves.')
+    parser.add_argument('--input-split-dim', type=str, default='batch',
+                        help='Dimension for input split.')


Review comment: the help text should spell out the valid values, e.g. "Dimension for input split. ['batch', 'seq']" (see the sketch below).
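A minimal sketch of that suggestion, using argparse's standard `choices` option to document and enforce the valid values; the flag name matches the diff, everything else here is illustrative and not the PR's code:

```python
import argparse

parser = argparse.ArgumentParser()
# Restricting to the two supported values makes the help text self-documenting
# and rejects typos at parse time instead of deep inside the model code.
parser.add_argument('--input-split-dim', type=str, default='batch',
                    choices=['batch', 'seq'],
                    help="Dimension for input split. ['batch', 'seq']")

args = parser.parse_args(['--input-split-dim', 'seq'])
print(args.input_split_dim)  # seq
```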


     args = parser.parse_args()

@@ -355,6 +357,8 @@ class TransformerConfig():
     no_sync_func: Callable = None
     # grad_sync_func: Callable = None
     # param_sync_func: Callable = None
+
+    input_split_dim: str = 'batch'

     def __post_init__(self):
         """ Python dataclass method that is used to modify attributes after initialization.
@@ -396,5 +400,6 @@ def core_transformer_config_from_args(args):
     kw_args['init_method'] = args.init_method
     kw_args['output_layer_init_method'] = args.init_method
     kw_args['params_dtype'] = args.params_dtype
+    kw_args['input_split_dim'] = args.input_split_dim

     return TransformerConfig(**kw_args)
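For readers unfamiliar with this plumbing, here is a minimal, self-contained sketch of the pattern the file follows: the parsed flag is copied into a dataclass field and then consumed by the model. Names are simplified; Domino's real `TransformerConfig` and `core_transformer_config_from_args` carry many more fields.

```python
from dataclasses import dataclass

# Simplified stand-in for TransformerConfig with only the new field.
@dataclass
class MiniConfig:
    hidden_size: int = 1024
    input_split_dim: str = 'batch'

def config_from_args(args):
    # Mirrors core_transformer_config_from_args: copy selected args into kwargs.
    kw_args = {}
    kw_args['hidden_size'] = args.hidden_size
    kw_args['input_split_dim'] = args.input_split_dim
    return MiniConfig(**kw_args)

class Args:  # stand-in for the parsed argparse namespace
    hidden_size = 2048
    input_split_dim = 'seq'

print(config_from_args(Args()))
# MiniConfig(hidden_size=2048, input_split_dim='seq')
```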
22 changes: 18 additions & 4 deletions training/DeepSpeed-Domino/domino/language_model.py
@@ -127,6 +127,7 @@ def __init__(self,
         self.init_method = config.init_method
         self.encoder_attn_mask_type = encoder_attn_mask_type
         self.encoder_hidden_state = None
+        self.input_split_dim = config.input_split_dim


Review comment: one thing we need to verify is whether, when the split dim is 'seq', we can still initialize the same RoPE embedding (rotary position embedding, a very popular position-embedding function), because for RoPE:

        if self.use_rotary_position_embeddings:
            self.seq_length = args.seq_length

our self.seq_length may no longer be correct (ours might be half of the original seq_length).

         if self.pre_process:
             self.embedding = Embedding(self.hidden_size,
@@ -177,17 +178,30 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,

         encoder_out_size = encoder_input.shape
         p_batch_size = encoder_out_size[1] // 2
+        p_seq_size = encoder_out_size[0] // 2
         dtype = encoder_input.dtype
         encoder_output_t = torch.empty(encoder_out_size, dtype=dtype, device=torch.cuda.current_device())
         intra_partitions = 2
-        encoder_inputs = torch.tensor_split(encoder_input, intra_partitions, dim=1)
+        if self.input_split_dim == 'batch':
+            encoder_inputs = torch.tensor_split(encoder_input, intra_partitions, dim=1)
+        elif self.input_split_dim == 'seq':
+            encoder_inputs = torch.tensor_split(encoder_input, intra_partitions, dim=0)
+        else:
+            raise NotImplementedError
         encoder_outputs = self.encoder(
             encoder_inputs,
             enc_attn_mask,
             rotary_pos_emb=rotary_pos_emb)
-        encoder_output_t[:, 0:p_batch_size, :] = encoder_outputs[0]
-        encoder_output_t[:, p_batch_size:2*p_batch_size, :] = encoder_outputs[1]
+
+        if self.input_split_dim == 'batch':
+            encoder_output_t[:, 0:p_batch_size, :] = encoder_outputs[0]
+            encoder_output_t[:, p_batch_size:2*p_batch_size, :] = encoder_outputs[1]
+        elif self.input_split_dim == 'seq':
+            encoder_output_t[0:p_seq_size, :, :] = encoder_outputs[0]
+            encoder_output_t[p_seq_size:2*p_seq_size, :, :] = encoder_outputs[1]
+        else:
+            raise NotImplementedError

         encoder_output = encoder_output_t

         return encoder_output
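A standalone sketch of the split-and-recombine pattern in the hunk above, with an identity function standing in for `self.encoder` and the `[seq, batch, hidden]` activation layout assumed by Megatron/Domino; it only checks that both split modes reassemble the original tensor, and is not the PR's code.

```python
import torch

def split_and_merge(encoder_input, input_split_dim, encoder=lambda x: x, intra_partitions=2):
    # Preallocate the output buffer, mirroring encoder_output_t in the diff.
    out = torch.empty_like(encoder_input)
    if input_split_dim == 'batch':
        dim = 1
    elif input_split_dim == 'seq':
        dim = 0
    else:
        raise NotImplementedError(input_split_dim)
    # Split along the chosen dimension, run each chunk, then write the chunk
    # outputs back into the corresponding slices of the buffer.
    chunks = torch.tensor_split(encoder_input, intra_partitions, dim=dim)
    outputs = [encoder(c) for c in chunks]
    step = encoder_input.shape[dim] // intra_partitions
    for i, o in enumerate(outputs):
        if input_split_dim == 'batch':
            out[:, i * step:(i + 1) * step, :] = o
        else:
            out[i * step:(i + 1) * step, :, :] = o
    return out

x = torch.randn(6, 4, 8)  # [seq, batch, hidden]
assert torch.equal(split_and_merge(x, 'batch'), x)
assert torch.equal(split_and_merge(x, 'seq'), x)
```

With a real encoder the two modes are of course not interchangeable: a 'seq' split hands each chunk only half of the sequence positions, which is exactly why the RoPE initialization question raised in the review comment above needs to be settled.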