diff --git a/benchmark/profile_generation.py b/benchmark/profile_generation.py index 3efeaa05b..61f99065b 100644 --- a/benchmark/profile_generation.py +++ b/benchmark/profile_generation.py @@ -178,7 +178,7 @@ async def _gather_tasks(tasks): out_token_throughput = np.round(token_latency_stats.size / elapsed_time, 2) total_token_throughput = np.round(concurrency * test_round * (input_seqlen + output_seqlen) / elapsed_time, 2) - print(f'\n{"-" * 50}\ntotal time: {elapsed_time:.2f}s\n' + print(f'\n{" - " * 50}\ntotal time: {elapsed_time:.2f}s\n' f'concurrency: {concurrency}, test_round: {test_round}\n' f'input_tokens: {input_seqlen}, output_tokens: {output_seqlen}\n' f'first_token latency(min, max, ave): ' @@ -188,7 +188,7 @@ async def _gather_tasks(tasks): f'{token_latency_ave}s\n' f'token_latency percentiles(50%,75%,95%,99%)(s): {percentiles}\n' f'throughput(output): {out_token_throughput} token/s\n' - f'throughput(total): {total_token_throughput} token/s\n{"-" * 50}') + f'throughput(total): {total_token_throughput} token/s\n{" - " * 50}') return model_path, \ [first_token_latency_min, first_token_latency_max, first_token_latency_ave], \ diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 2aa417cc2..8394ac2f3 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. - +from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend from lmdeploy.utils import get_max_batch_size from .cli import CLI @@ -167,6 +167,8 @@ def add_parser_api_server(): ArgumentHelper.dp_rank(pt_group) ArgumentHelper.ep(pt_group) ArgumentHelper.enable_microbatch(pt_group) + ArgumentHelper.role(pt_group) + ArgumentHelper.migration_backend(pt_group) # turbomind args tb_group = parser.add_argument_group('TurboMind engine arguments') @@ -216,7 +218,13 @@ def add_parser_proxy(): parser.set_defaults(run=SubCliServe.proxy) parser.add_argument('--server-name', type=str, default='0.0.0.0', help='Host ip for proxy serving') parser.add_argument('--server-port', type=int, default=8000, help='Server port of the proxy') - parser.add_argument('--strategy', + parser.add_argument('--serving-strategy', + type=str, + choices=['Hybrid', 'DistServe'], + default='Hybrid', + help='the strategy to serve, Hybrid for colocating Prefill and Decode' + 'workloads into same engine, DistServe for Prefill-Decode Disaggregation') + parser.add_argument('--routing-strategy', type=str, choices=['random', 'min_expected_latency', 'min_observed_latency'], default='min_expected_latency', @@ -226,6 +234,15 @@ def add_parser_proxy(): help='Whether to disable cache status of the ' 'proxy. 
If set, the proxy will forget the status ' 'of the previous time') + + # For Disaggregation + parser.add_argument('--migration-protocol', + type=str, + choices=['RDMA', 'NVLINK'], + default='RDMA', + help='transport protocol of KV migration') + parser.add_argument('--link-type', type=str, choices=['RoCE', 'IB'], default='RoCE', help='RDMA Link Type') + parser.add_argument('--disable-gdr', action='store_true', help='with GPU Direct Memory Access') ArgumentHelper.api_keys(parser) ArgumentHelper.ssl(parser) ArgumentHelper.log_level(parser) @@ -311,7 +328,9 @@ def api_server(args): quant_policy=args.quant_policy, eager_mode=args.eager_mode, max_prefill_token_num=args.max_prefill_token_num, - enable_microbatch=args.enable_microbatch) + enable_microbatch=args.enable_microbatch, + role=EngineRole[args.role], + migration_backend=MigrationBackend[args.migration_backend]) else: from lmdeploy.messages import TurbomindEngineConfig backend_config = TurbomindEngineConfig(dtype=args.dtype, diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 10e07d43f..9b6d14d19 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -527,3 +527,22 @@ def enable_microbatch(parser): return parser.add_argument('--enable-microbatch', action='store_true', help='enable microbatch for specified model') + + # For Disaggregation + @staticmethod + def role(parser): + return parser.add_argument('--role', + type=str, + default='Hybrid', + choices=['Hybrid', 'Prefill', 'Decode'], + help='Hybrid for Non-Disaggregated Engine;' + 'Prefill for Disaggregated Prefill Engine;' + 'Decode for Disaggregated Decode Engine;') + + @staticmethod + def migration_backend(parser): + return parser.add_argument('--migration-backend', + type=str, + default='DLSlime', + choices=['DLSlime'], + help='kvcache migration management backend when PD disaggregation') diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index e07309b84..b3fed7036 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -6,6 +6,9 @@ import torch from pydantic.dataclasses import dataclass as pydantic_dataclass +from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend +from lmdeploy.pytorch.disagg.request import MigrationRequest + from .tokenizer import Tokenizer from .utils import get_logger @@ -107,6 +110,11 @@ class GenerationConfig: output_logits: Literal['all', 'generation'] = None output_last_hidden_state: Literal['all', 'generation'] = None + # for disaggregation + with_cache: bool = False + preserve_cache: bool = False + migration_request: Optional[MigrationRequest] = None + def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer): """convert stop_words/bad_sords to ids and append the ids to stop_token_ids/bad_token_ids.""" @@ -298,6 +306,10 @@ class PytorchEngineConfig: distributed_executor_backend (str): backend of distributed backend, options: ['uni', 'mp', 'ray'] enable_microbatch (bool): enable microbatch for specified model + role (EngineRole): role of engin, options: ['Hybrid', 'Prefill', + 'Decode']. Default to `EngineRole.Hybrid`. + migration_backend: migration backend. options: ['DLSlime']. + Default to `MigrationBackend.DLSlime`. 
""" dtype: str = 'auto' tp: int = 1 @@ -324,6 +336,9 @@ class PytorchEngineConfig: distributed_executor_backend: str = None enable_microbatch: bool = False + role: EngineRole = EngineRole.Hybrid + migration_backend: MigrationBackend = MigrationBackend.DLSlime + def __post_init__(self): """Check input validation.""" assert self.dtype in ['auto', 'float16', 'bfloat16'] @@ -404,6 +419,8 @@ class EngineOutput: may not equal to the length of token_ids logprobs (List[Dict[int, float]]): the top logprobs for each output position. + cache_block_ids (List[int]): send cache blocks back for migration in + Disaggregated LLM Serving when Prefill Engine is Done. """ status: ResponseType token_ids: List[int] @@ -412,6 +429,8 @@ class EngineOutput: logits: torch.Tensor = None last_hidden_state: torch.Tensor = None + cache_block_ids: Optional[List[int]] = None + @dataclass class VisionConfig: diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 1a48a262a..dea2a1877 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -4,6 +4,8 @@ import torch +from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend + def _update_torch_dtype(config: 'ModelConfig', dtype: str): """Update the torch dtype from the model config. @@ -80,6 +82,10 @@ class CacheConfig: quant_policy: Literal[0, 4, 8] = 0 device_type: str = 'cuda' + # For PD Disaggregation + role: EngineRole = EngineRole.Hybrid + migration_backend: MigrationBackend = MigrationBackend.DLSlime + def __post_init__(self): """post init.""" from lmdeploy.utils import get_logger diff --git a/lmdeploy/pytorch/disagg/README.md b/lmdeploy/pytorch/disagg/README.md new file mode 100644 index 000000000..4cb435970 --- /dev/null +++ b/lmdeploy/pytorch/disagg/README.md @@ -0,0 +1,103 @@ +# LMDeploy-DistServe + +## Key Components + +1. ​**Router Service**: Coordinates between prefill/decode engines +2. ​**Migration Manager**: Facilitates high-performance memory sharing + +## Installation + +``` +# Inference Engine +pip install lmdeploy[all] >= 0.7.0 + +# Transfer Engine +pip install dlslime>=0.0.1.post7 +``` + +## Quick Start + +A PD disaggregated deployment of DeepSeekV3 is shown below: + +### 1. Launch Router Service + +```shell +lmdeploy serve proxy --server-name 0.0.0.0 --server-port 8000 --routing-strategy "min_expected_latency" --serving-strategy DistServe --log-level INFO +``` + +LMDeploy-DistServe support both NVLink and RDMA for kvcache transferring from Prefill Engine to Decode Engine. RDMA is default model. Set `--migration-protocol NVLink` for NVLink transport. + +### 2. Configure Endpoints + +First deploy your prefill and decode engines. + +```shell +# Prefill Engine +CUDA_VISIBLE_DEVICES=0 lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 --role Prefill --proxy-url http://0.0.0.0:8000 --backend pytorch +# Decode Engine +CUDA_VISIBLE_DEVICES=1 lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23334 --role Decode --proxy-url http://0.0.0.0:8000 --backend pytorch +``` + +By now, only **Pytorch backend** supports PD Disaggregation. 
+
+## API Usage
+
+```shell
+# API Invoke
+curl -X POST "http://localhost:8000/v1/completions" \
+-H "Content-Type: application/json" \
+-d '{"model": "internlm/internlm2_5-7b-chat", "temperature":0, "prompt": "Shanghai is a city that ", "max_tokens": 16, "stream": false}'
+# Output
+{
+    "id":"2",
+    "object":"text_completion",
+    "created":1743662400,
+    "model":"internlm/internlm2_5-7b-chat",
+    "choices":[
+        {
+            "index":0,
+            "text":" is very famous for its skyscrapers. It is also a city","logprobs":null,"finish_reason":"length"
+        }
+    ],
+    "usage": {
+        "prompt_tokens":7,"total_tokens":23,"completion_tokens":16
+    }
+}
+```
+
+## Troubleshooting
+
+### RDMA Connection Failed
+
+Make sure ibverbs is correctly installed:
+
+```shell
+# on Ubuntu
+sudo apt install libibverbs-dev
+# on CentOS
+sudo yum install ibverbs-devel
+```
+
+```bash
+ibstat # Verify IB device status
+ibv_devinfo # Check device capabilities
+```
+
+### Check GPU Direct RDMA
+
+Currently, lmdeploy-distserve uses GPUDirect RDMA to perform KV transfer. Make sure the GPUDirect RDMA driver is loaded into the kernel:
+
+```bash
+lsmod | grep nv_peer_mem
+# The module is listed if GPUDirect RDMA is correctly loaded.
+```
+
+### Connection Pool
+
+Currently, if the proxy disconnects, the connection pool must be warmed up again. A future enhancement could be a dedicated connection pool management server (e.g., using Raft-based tools like ETCD, as mentioned in Mooncake) to improve connection discovery and avoid repeated warmups.
+
+### Proxy
+
+Do not add an engine node to **multiple proxies**; this is not supported and is considered incorrect usage for now.
diff --git a/lmdeploy/pytorch/disagg/__init__.py b/lmdeploy/pytorch/disagg/__init__.py
new file mode 100644
index 000000000..ef101fec6
--- /dev/null
+++ b/lmdeploy/pytorch/disagg/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/lmdeploy/pytorch/disagg/backend/__init__.py b/lmdeploy/pytorch/disagg/backend/__init__.py
new file mode 100644
index 000000000..1b4584033
--- /dev/null
+++ b/lmdeploy/pytorch/disagg/backend/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from lmdeploy.logger import get_logger
+
+logger = get_logger('lmdeploy')
+
+try:
+    logger.debug('Registering DLSlime Backend')
+    from .dlslime import DLSlimeBackend
+except ImportError:
+    logger.warning('Disable DLSlime Backend')
+
+try:
+    logger.debug('Registering Mooncake Backend')
+    from .mooncake import MooncakeBackend
+except ImportError:
+    logger.warning('Disable Mooncake Backend')
+
+try:
+    logger.debug('Registering InfiniStore Backend')
+    from .infinistore import InfiniStoreBackend
+except ImportError:
+    logger.warning('Disable InfiniStore Backend')
+
+__all__ = ['DLSlimeBackend', 'MooncakeBackend', 'InfiniStoreBackend']
diff --git a/lmdeploy/pytorch/disagg/backend/backend.py b/lmdeploy/pytorch/disagg/backend/backend.py
new file mode 100644
index 000000000..a3bc2da9e
--- /dev/null
+++ b/lmdeploy/pytorch/disagg/backend/backend.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.registry import Registry
+
+MIGRATION_BACKENDS = Registry('migration_backend', locations=['lmdeploy.pytorch.disagg.backend.backend'])
diff --git a/lmdeploy/pytorch/disagg/backend/base.py b/lmdeploy/pytorch/disagg/backend/base.py
new file mode 100644
index 000000000..200443d12
--- /dev/null
+++ b/lmdeploy/pytorch/disagg/backend/base.py
@@ -0,0 +1,37 @@
+# Copyright (c) OpenMMLab.
All rights reserved. +from abc import abstractmethod + +from lmdeploy.pytorch.disagg.config import MigrationProtocol +from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment +from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest + + +class MigrationBackendImpl: + + @abstractmethod + def p2p_initialize(self, init_request: DistServeInitRequest): + raise NotImplementedError + + @abstractmethod + def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage): + raise NotImplementedError + + @abstractmethod + def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol): + return NotImplementedError + + @abstractmethod + def p2p_connect(self, conn_req: DistServeConnectionRequest): + raise NotImplementedError + + @abstractmethod + def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False): + raise NotImplementedError + + @abstractmethod + def store(self, assignment: MigrationAssignment, async_op: bool = False): + raise NotImplementedError + + @abstractmethod + def load(self, assignment: MigrationAssignment, async_op: bool = False): + raise NotImplementedError diff --git a/lmdeploy/pytorch/disagg/backend/dlslime.py b/lmdeploy/pytorch/disagg/backend/dlslime.py new file mode 100644 index 000000000..7e0e8d308 --- /dev/null +++ b/lmdeploy/pytorch/disagg/backend/dlslime.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +from typing import Dict, List + +from dlslime import Assignment as DLSlimeAssignment +from dlslime import NVLinkEndpoint, RDMAEndpoint, available_nic + +from lmdeploy.logger import get_logger +from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS +from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl +from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, MigrationBackend, MigrationProtocol +from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment +from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest + +logger = get_logger('lmdeploy') + + +class DLSlimeMigrationManagement: + + def __init__(self, init_request: DistServeInitRequest): + self.rank = init_request.rank + self.local_engine_config: DistServeEngineConfig = init_request.local_engine_config + self.remote_engine_config: DistServeEngineConfig = init_request.remote_engine_config + self.endpoint: Dict[MigrationProtocol, RDMAEndpoint] = { + MigrationProtocol.TCP: None, + MigrationProtocol.RDMA: None, + MigrationProtocol.NVLINK: None, + } + if init_request.protocol == MigrationProtocol.RDMA: + nics = available_nic() + device_name = nics[self.rank % len(nics)] + logger.info(f'use device {device_name} for kv migration') + self.endpoint[MigrationProtocol.RDMA] = RDMAEndpoint(device_name=device_name, + ib_port=1, + link_type=init_request.rdma_config.link_type.name) + elif init_request.protocol == MigrationProtocol.NVLINK: + self.endpoint[MigrationProtocol.NVLINK] = NVLinkEndpoint() + + def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage): + self.endpoint[register_mr_request.protocol].register_memory_region(register_mr_request.mr_key, + register_mr_request.addr, + register_mr_request.offset, + register_mr_request.length) + + def connect(self, connect_request: DistServeConnectionRequest): + self.endpoint[connect_request.protocol].connect(json.loads(connect_request.remote_endpoint_info)) + + def p2p_migrate(self, assignment: 
MigrationAssignment, async_op=False): + MAX_NUM_READ_BATCH = 4096 + + def split(batch: List[DLSlimeAssignment]): + batch_split = [] + for i in range(0, len(batch), MAX_NUM_READ_BATCH): + batch_split.append(batch[i:i + MAX_NUM_READ_BATCH]) + return batch_split + + batch = [ + DLSlimeAssignment( + mr_key=assign.mr_key, + target_offset=assign.target_offset, + source_offset=assign.source_offset, + length=assign.length, + ) for assign in assignment.batch + ] + batch_splited = split(batch) + for b_split in batch_splited: + self.endpoint[assignment.protocol].read_batch(b_split) + + +@MIGRATION_BACKENDS.register_module(MigrationBackend.DLSlime.name) +class DLSlimeBackend(MigrationBackendImpl): + + def __init__(self): + self.links: Dict[int, DLSlimeMigrationManagement] = {} + + def p2p_initialize(self, init_request: DistServeInitRequest): + self.links[init_request.remote_engine_id] = DLSlimeMigrationManagement(init_request) + + def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage): + self.links[register_mr_request.remote_engine_id].register_memory_region(register_mr_request) + + def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol): + return self.links[remote_engine_id].endpoint[protocol].endpoint_info + + def p2p_connect(self, conn_req: DistServeConnectionRequest): + self.links[conn_req.remote_engine_id].connect(conn_req) + + def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False): + self.links[assignment.remote_engine_id].p2p_migrate(assignment, async_op=async_op) + + def store(self, assignment: MigrationAssignment, async_op: bool = False): + raise NotImplementedError + + def load(self, assignment: MigrationAssignment, async_op: bool = False): + raise NotImplementedError diff --git a/lmdeploy/pytorch/disagg/backend/infinistore.py b/lmdeploy/pytorch/disagg/backend/infinistore.py new file mode 100644 index 000000000..f75850138 --- /dev/null +++ b/lmdeploy/pytorch/disagg/backend/infinistore.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
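+# NOTE: InfiniStore is registered as a migration backend placeholder; the methods below are not implemented yet.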
+from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS +from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl +from lmdeploy.pytorch.disagg.config import MigrationBackend, MigrationProtocol +from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment +from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest + + +@MIGRATION_BACKENDS.register_module(MigrationBackend.InfiniStore.name) +class InfiniStoreBackend(MigrationBackendImpl): + + def p2p_initialize(self, init_request: DistServeInitRequest): + raise NotImplementedError + + def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage): + raise NotImplementedError + + def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol): + return NotImplementedError + + def p2p_connect(self, conn_req: DistServeConnectionRequest): + raise NotImplementedError + + def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False): + raise NotImplementedError + + def store(self, assignment: MigrationAssignment, async_op: bool = False): + raise NotImplementedError + + def load(self, assignment: MigrationAssignment, async_op: bool = False): + raise NotImplementedError diff --git a/lmdeploy/pytorch/disagg/backend/mooncake.py b/lmdeploy/pytorch/disagg/backend/mooncake.py new file mode 100644 index 000000000..9a1348817 --- /dev/null +++ b/lmdeploy/pytorch/disagg/backend/mooncake.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS +from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl +from lmdeploy.pytorch.disagg.config import MigrationBackend, MigrationProtocol +from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment +from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest + + +@MIGRATION_BACKENDS.register_module(MigrationBackend.Mooncake.name) +class MooncakeBackend(MigrationBackendImpl): + + def p2p_initialize(self, init_request: DistServeInitRequest): + raise NotImplementedError + + def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage): + raise NotImplementedError + + def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol): + return NotImplementedError + + def p2p_connect(self, connect_request: DistServeConnectionRequest): + raise NotImplementedError + + def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False): + raise NotImplementedError + + def store(self, assignment: MigrationAssignment, async_op: bool = False): + raise NotImplementedError + + def load(self, assignment: MigrationAssignment, async_op: bool = False): + raise NotImplementedError diff --git a/lmdeploy/pytorch/disagg/config.py b/lmdeploy/pytorch/disagg/config.py new file mode 100644 index 000000000..14bc7fc29 --- /dev/null +++ b/lmdeploy/pytorch/disagg/config.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import enum +from typing import Optional + +from pydantic import BaseModel + + +class ServingStrategy(enum.Enum): + """Serving Strategy. + + Attributes: + Hybrid: Prefill and Decode workload are co-located in one engine. + DistServe: Prefill and Decode workload are assigned to different + engines. After the execution of prefill phase in Prefill Engine, + KVCache is migrated from Prefill to Decode Engine. 
+    """
+
+    Hybrid = enum.auto()
+    DistServe = enum.auto()
+
+
+class EngineRole(enum.Enum):
+    """Role of Engine.
+
+    Note: In the LMDeploy-DistServe implementation every engine is technically
+    a hybrid engine; its effective role is determined by the kind of requests
+    routed to it. Nevertheless, the role still needs to be declared when
+    starting the engine server, for two reasons:
+    1. It lets the proxy discover the engine correctly.
+    2. The construction of ModelInputs differs among hybrid, prefill and
+       decode engines in the DP Engine (DSV3 DP + EP).
+    """
+
+    Hybrid = enum.auto()
+    Prefill = enum.auto()
+    Decode = enum.auto()
+
+
+class MigrationBackend(enum.Enum):
+    """Migration Backend."""
+
+    DLSlime = enum.auto()
+    Mooncake = enum.auto()
+    InfiniStore = enum.auto()
+
+
+class MigrationProtocol(enum.Enum):
+    """Migration Transport Protocol.
+
+    Attributes:
+        TCP: TCP for general-purpose transport.
+        RDMA: IB or RoCEv1/v2.
+        NVLINK: high-bandwidth device-to-device link.
+
+    Warning: Currently, only GPU Direct RDMA is supported in DistServe.
+    The other protocols are reserved and will be implemented in the future.
+    """
+
+    TCP = enum.auto()
+    RDMA = enum.auto()
+    NVLINK = enum.auto()
+
+
+class RDMALinkType(enum.Enum):
+    """RDMA Link Type."""
+
+    IB = enum.auto()
+    RoCE = enum.auto()
+
+
+class DistServeRDMAConfig(BaseModel):
+    """DistServe RDMA Config.
+
+    Args:
+        with_gdr: defaults to True.
+        link_type: defaults to `RDMALinkType.RoCE`.
+
+    Warning: Only GDR is supported for now.
+    Warning: Technically, both RoCE and IB are supported.
+        However, IB mode is untested due to the lack of a suitable
+        testing environment.
+    """
+
+    # RDMA with GPU Direct RDMA Access
+    with_gdr: bool = True
+    link_type: RDMALinkType = RDMALinkType.RoCE
+
+
+class DistServeTCPConfig(BaseModel):
+    """TODO: Add TCP Protocol"""
+
+
+class DistServeNVLinkConfig(BaseModel):
+    """TODO: Add NVLink Protocol"""
+
+
+class DistServeEngineConfig(BaseModel):
+    """DistServe Engine Config.
+
+    In disaggregated LLM serving, we need the engine info of each PD peer
+    for the following reasons:
+    1. Cache: the cache block stride, needed to compute correct offsets for
+       KV transfer.
+    2. Parallel: prefill and decode may use different parallel strategies to
+       achieve high SLO attainment or high throughput, so we need to calculate
+       which prefill-decode worker peers must connect. For example, if the
+       prefill worker uses pp4 and the decode worker uses tp2pp2, the
+       prefill-decode connection peers are (0, 0), (0, 1), (1, 0), (1, 1),
+       (2, 2), (2, 3), (3, 2), (3, 3). By contrast, with (tp4, tp4) the peers
+       are (0, 0), (1, 1), (2, 2), (3, 3).
+    """
+
+    # parallel config
+    # (dp, pp, tp, ep)
+    tp_size: int
+    ep_size: int
+    dp_size: int
+    pp_size: Optional[int]
+
+    # Rank of DP
+    dp_rank: int
+
+    # cache config
+    block_size: int
+    num_cpu_blocks: int
+    num_gpu_blocks: int
+
+
+class DistServeConfig(BaseModel):
+    """DistServe Config."""
+
+    serving_strategy: ServingStrategy
+    distserve_transport_protocol: MigrationProtocol
+    rdma_config: Optional[DistServeRDMAConfig] = None
+    nvlink_config: Optional[DistServeNVLinkConfig] = None
+    tcp_config: Optional[DistServeTCPConfig] = None
diff --git a/lmdeploy/pytorch/disagg/conn.py b/lmdeploy/pytorch/disagg/conn.py
new file mode 100644
index 000000000..263ed873e
--- /dev/null
+++ b/lmdeploy/pytorch/disagg/conn.py
@@ -0,0 +1,216 @@
+# Copyright (c) OpenMMLab. All rights reserved.
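+# Lazy peer-to-peer connection management between Prefill and Decode engines
+# for KV cache migration; see PDConnectionPool below.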
+import asyncio +import enum +import json +import os +from typing import Dict, List, Tuple + +import aiohttp + +from lmdeploy.logger import get_logger +from lmdeploy.pytorch.disagg.config import DistServeEngineConfig +from lmdeploy.pytorch.disagg.messages import PDConnectionMessage +from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest + +logger = get_logger('lmdeploy') + +AIOHTTP_TIMEOUT = os.getenv('AIOHTTP_TIMEOUT', None) + + +class PDConnectionStatus(enum.Enum): + Disconnected = enum.auto() + Connected = enum.auto() + Connecting = enum.auto() + + +class PDConnectionState: + """PDConnectionState.""" + + def __init__(self, status: PDConnectionStatus, event: asyncio.Event): + self.status = status + self.event = event + + async def wait(self): + await self.event.wait() + + def set_status(self, status: PDConnectionStatus): + self.status = status + + +class PDConnectionPool: + """Constructing the link of Prefill and Decode engine for the migration of + KVCache. + + Note: we use Peer to Peer transportation in KVCache migration. + Note: Lazy link construction is supported, which perform connection + at the first LLM request. As a result, we don't need to construct + PD Communication group when start a engine server. + Warning: By now, only engines with same parallel configuration can be + correctly connected. + """ + + def __init__(self): + # Links of PD Connection. + self.pool: Dict[Tuple[str, str], PDConnectionState] = {} + + # conn_perform handler queue + self.waiting_conn: asyncio.Queue[Tuple[PDConnectionMessage, asyncio.Event]] = (asyncio.Queue()) + + # conn Registry Lock + self.conn_lock = asyncio.Lock() + + # Connection Retry when failure + self.max_retry_cnt = 8 + + # trigger signal when conn request arrive. + self.conn_req_event = asyncio.Event() + + # conn initialized signal + self.initialized = False + + async def perform_conn(self): + + def get_server_api(url: str, api: str): + return f'{url}/{api}' + + async def get_engine_config(server_endpoint): + async with self.conn_sem: + async with self.conn_sess.get( + get_server_api(server_endpoint, 'distserve/engine_info'), + timeout=self.aiotimeout, + ) as resp: + return DistServeEngineConfig.model_validate_json(await resp.json()) + + async def p2p_initialize(server_endpoint, init_request: DistServeInitRequest): + async with self.conn_sem: + async with self.conn_sess.post( + get_server_api(server_endpoint, 'distserve/p2p_initialize'), + json=init_request.model_dump(mode='json'), + timeout=self.aiotimeout, + ) as resp: + return await resp.json() + + async def p2p_connect(server_endpoint, conn_request: List[DistServeConnectionRequest]): + async with self.conn_sem: + async with self.conn_sess.post( + get_server_api(server_endpoint, 'distserve/p2p_connect'), + json=[req.model_dump(mode='json') for req in conn_request], + timeout=self.aiotimeout, + ) as resp: + return await resp.json() + + async def conn_worker(conn_req: PDConnectionMessage, conn_event: asyncio.Event): + try: + link = (conn_req.p_url, conn_req.d_url) + logger.debug(f'{link} connecting...') + # Step 1. Get Remote Engine Configuration + prefill_engine_config = await get_engine_config(conn_req.p_url) + decode_engine_config = await get_engine_config(conn_req.d_url) + + # Note: Only Same Parallel Configurations are supported by now + assert prefill_engine_config.tp_size == decode_engine_config.tp_size + + # Step 2. 
Construct Initialize Configuration + prefill_init_req = DistServeInitRequest( + protocol=conn_req.protocol, + local_engine_id=conn_req.p_url, + local_engine_config=prefill_engine_config, + remote_engine_id=conn_req.d_url, + remote_engine_config=decode_engine_config, + rdma_config=conn_req.rdma_config, + nvlink_config=conn_req.nvlink_config, + ) + decode_init_req = DistServeInitRequest(protocol=conn_req.protocol, + local_engine_id=conn_req.d_url, + local_engine_config=decode_engine_config, + remote_engine_id=conn_req.p_url, + remote_engine_config=prefill_engine_config, + rdma_config=conn_req.rdma_config, + nvlink_config=conn_req.nvlink_config) + + prefill_endpoint_info = await p2p_initialize(conn_req.p_url, prefill_init_req) + decode_endpoint_info = await p2p_initialize(conn_req.d_url, decode_init_req) + + # Step 3. Connection + prefill_endpoint_conn_reqs = [ + DistServeConnectionRequest( + protocol=conn_req.protocol, + remote_engine_id=conn_req.d_url, + remote_endpoint_info=json.dumps(info), + ) for info in decode_endpoint_info + ] + decode_endpoint_conn_reqs = [ + DistServeConnectionRequest( + protocol=conn_req.protocol, + remote_engine_id=conn_req.p_url, + remote_endpoint_info=json.dumps(info), + ) for info in prefill_endpoint_info + ] + await p2p_connect(conn_req.p_url, prefill_endpoint_conn_reqs) + await p2p_connect(conn_req.d_url, decode_endpoint_conn_reqs) + self.pool[link].set_status(PDConnectionStatus.Connected) + logger.debug(f'{(conn_req.p_url, conn_req.d_url)} connected') + except Exception as e: + self.pool[link].set_status(PDConnectionStatus.Disconnected) + logger.error(f'pd connection error: {e}') + conn_event.set() + + async def wait_for_conn(conn_req: PDConnectionMessage, conn_event: asyncio.Event): + await self.pool[(conn_req.p_url, conn_req.d_url)].event.wait() + conn_event.set() + + logger.debug('perform_conn start') + while True: + if self.waiting_conn.empty(): + await self.conn_req_event.wait() + + self.conn_req_event.clear() + + while not self.waiting_conn.empty(): + conn_req, conn_event = self.waiting_conn.get_nowait() + link = (conn_req.p_url, conn_req.d_url) + if link not in self.pool: + self.pool[link] = PDConnectionState( + PDConnectionStatus.Disconnected, + conn_event, + ) + if self.pool[link].status == PDConnectionStatus.Connecting: + asyncio.create_task(wait_for_conn(conn_req, conn_event)) + elif self.pool[link].status == PDConnectionStatus.Disconnected: + self.pool[link].set_status(PDConnectionStatus.Connecting) + asyncio.create_task(conn_worker(conn_req, conn_event)) + + async def connect(self, conn_req: PDConnectionMessage): + if not self.initialized: + loop = asyncio.get_event_loop() + loop.create_task(self.perform_conn()) + self.conn_sem = asyncio.Semaphore(1024) + self.conn_sess = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit_per_host=256), + timeout=aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT), + ) + self.aiotimeout = aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT) + self.initialized = True + cnt = 0 + while cnt < self.max_retry_cnt: + if self.is_connected(conn_req.p_url, conn_req.d_url): + return + if cnt > 0: + logger.warning(f'Connection failure, retry cnt: {cnt}') + conn_event = asyncio.Event() + self.waiting_conn.put_nowait((conn_req, conn_event)) + self.conn_req_event.set() + await conn_event.wait() + cnt += 1 + async with self.conn_lock: + self.pool[conn_req.p_url, conn_req.d_url].set_status(PDConnectionStatus.Disconnected) + raise TimeoutError('PDConnection Failure') + + def is_connected(self, p_url: str, d_url: str): + link = 
self.pool.get((p_url, d_url), None) + if not link: + return False + return link.status == PDConnectionStatus.Connected + + def drop(self, left: str, right: str): + self.pool.pop((left, right), None) diff --git a/lmdeploy/pytorch/disagg/messages.py b/lmdeploy/pytorch/disagg/messages.py new file mode 100644 index 000000000..b49769102 --- /dev/null +++ b/lmdeploy/pytorch/disagg/messages.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +from pydantic import BaseModel + +from lmdeploy.pytorch.disagg.config import (DistServeNVLinkConfig, DistServeRDMAConfig, DistServeTCPConfig, + MigrationProtocol) + + +class MigrationExecutionBatch(BaseModel): + """Input of the Migration.""" + + protocol: MigrationProtocol + requests: List[Tuple[str, List[Tuple[int, int]]]] = [] + + +class AssignmentInstruct(BaseModel): + """Assignment Batch.""" + mr_key: str + target_offset: int + source_offset: int + length: int + + +class MigrationAssignment(BaseModel): + """Migration Assignment.""" + protocol: MigrationProtocol + remote_engine_id: str + batch: List[AssignmentInstruct] + + +class PDConnectionMessage(BaseModel): + p_url: str + d_url: str + protocol: MigrationProtocol = MigrationProtocol.RDMA + tcp_config: Optional[DistServeTCPConfig] = None + rdma_config: Optional[DistServeRDMAConfig] = None + nvlink_config: Optional[DistServeNVLinkConfig] = None + + +class DistServeRegisterMRMessage(BaseModel): + protocol: MigrationProtocol + + remote_engine_id: str + mr_key: str + addr: int + offset: int + length: int diff --git a/lmdeploy/pytorch/disagg/request.py b/lmdeploy/pytorch/disagg/request.py new file mode 100644 index 000000000..990fe814b --- /dev/null +++ b/lmdeploy/pytorch/disagg/request.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +from pydantic import BaseModel + +from lmdeploy.pytorch.disagg.config import (DistServeEngineConfig, DistServeNVLinkConfig, DistServeRDMAConfig, + DistServeTCPConfig, MigrationProtocol) + + +class DistServeConnectionRequest(BaseModel): + protocol: MigrationProtocol + remote_engine_id: str + remote_endpoint_info: str + + +class DistServeInitRequest(BaseModel): + local_engine_id: str + local_engine_config: DistServeEngineConfig + + remote_engine_id: str + remote_engine_config: DistServeEngineConfig + + protocol: MigrationProtocol + + rank: Optional[int] = None + + tcp_config: Optional[DistServeTCPConfig] = None + rdma_config: Optional[DistServeRDMAConfig] = None + nvlink_config: Optional[DistServeNVLinkConfig] = None + + +class MigrationRequest(BaseModel): + protocol: MigrationProtocol + + remote_engine_id: str + remote_session_id: int + remote_token_id: int + remote_block_ids: List[int] diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index 0b4ea8ea5..6e8a5257d 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -1,10 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. 
# modify from: https://github.com/vllm-project/vllm -from typing import Dict, List, Literal, Tuple +from typing import Dict, List, Literal, Optional, Tuple import torch from lmdeploy.pytorch.backends import get_backend +from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS +from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl +from lmdeploy.pytorch.disagg.messages import (AssignmentInstruct, DistServeRegisterMRMessage, MigrationAssignment, + MigrationExecutionBatch) +from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest from lmdeploy.utils import get_logger from ..config import CacheConfig, ModelConfig @@ -29,10 +34,13 @@ def __init__( self, cache_config: CacheConfig, model_config: ModelConfig, + rank: int = 0, + tp_rank: int = 0, world_size: int = 1, ) -> None: self.world_size = world_size - + self.rank = rank + self.tp_rank = tp_rank self.cache_config = cache_config self.model_config = model_config @@ -51,6 +59,8 @@ def __init__( self.local_gpu_cache = self.allocate_gpu_cache() self.local_cpu_cache = self.allocate_cpu_cache() + self.migration_backend_impl: Optional[MigrationBackendImpl] = None + # Initialize the stream for caching operations. self.cache_stream = torch.cuda.Stream() assert self.cache_stream != torch.cuda.current_stream() @@ -302,3 +312,66 @@ def get_cache_block_size(cls, total = num_layers * (mem_key_block + mem_value_block) return total + + """ Metheds for PD Disaggregation Begin. """ + + def p2p_initialize(self, migration_init_request: DistServeInitRequest): + if not self.migration_backend_impl: + self.migration_backend_impl = MIGRATION_BACKENDS.module_dict[self.cache_config.migration_backend.name]() + migration_init_request.rank = self.rank + self.migration_backend_impl.p2p_initialize(migration_init_request) + for i, t in enumerate(self.full_gpu_cache): + if t.numel() == 0: + continue + register_mr_request = DistServeRegisterMRMessage(protocol=migration_init_request.protocol, + remote_engine_id=migration_init_request.remote_engine_id, + mr_key=str(i), + addr=t.data_ptr(), + offset=t.storage_offset(), + length=t.numel() * t.itemsize) + self.migration_backend_impl.register_memory_region(register_mr_request) + return self.migration_backend_impl.endpoint_info(migration_init_request.remote_engine_id, + migration_init_request.protocol) + + def p2p_connect(self, migration_conn_request: DistServeConnectionRequest): + self.migration_backend_impl.p2p_connect(migration_conn_request[self.tp_rank]) + + def migrate(self, migration_execution_inputs: MigrationExecutionBatch): + + def get_assignment_len(): + head_dim = self.model_config.get_head_size() + num_heads = self.model_config.num_key_value_heads // self.world_size + block_size = self.cache_config.block_size + return head_dim * num_heads * block_size * self.model_config.dtype.itemsize + + assignment_len = get_assignment_len() + layer_stride = self.cache_config.num_gpu_blocks * assignment_len + + def get_assignment_batch(mr_key, block_ids, assignment_len, layer_stride, remote_layer_stride): + return [ + AssignmentInstruct(mr_key=mr_key, + target_offset=block_id[0] * assignment_len + layer * remote_layer_stride, + source_offset=block_id[1] * assignment_len + layer * layer_stride, + length=assignment_len) for layer in range(self.model_config.num_layers) + for block_id in block_ids + ] + + assignment_batch: List[Tuple[str, int, int, int]] = [] # mr_key, target, source, offset + for migration_exe_req in migration_execution_inputs.requests: + remote_engine_id = 
migration_exe_req[0] + blocks_to_migration = migration_exe_req[1] + remote_layer_stride = self.migration_backend_impl.links[ + remote_engine_id].remote_engine_config.num_gpu_blocks * assignment_len + + for i, t in enumerate(self.full_gpu_cache): + assignment_batch.extend( + get_assignment_batch(str(i), blocks_to_migration, assignment_len, layer_stride, + remote_layer_stride)) + self.migration_backend_impl.p2p_migrate( + MigrationAssignment( + protocol=migration_execution_inputs.protocol, + remote_engine_id=remote_engine_id, + batch=assignment_batch, + )) + + """ Metheds for PD Disaggregation End. """ diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 95c6bd713..7ed5bd46d 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -4,12 +4,14 @@ import logging import os from dataclasses import dataclass -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple import numpy as np import torch from lmdeploy.messages import PytorchEngineConfig, ResponseType +from lmdeploy.pytorch.disagg.config import EngineRole +from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch from lmdeploy.utils import get_logger, get_max_batch_size, get_model, logging_timer from ..adapter.adapter import AdapterManager @@ -40,6 +42,10 @@ class InferOutput: finish: bool = False logits: torch.Tensor = None + # send cache blocks back for migration in Disaggregated LLM Serving + # when Prefill Engine is Done. + cache_block_ids: List[int] = None + def _tensorlize_block_offsets(block_offsets, dtype=torch.int32): """tensorlize block_offsets.""" @@ -59,17 +65,17 @@ def _build_scheduler_config(engine_config: PytorchEngineConfig): def _build_cache_config(engine_config: PytorchEngineConfig): """build cache config.""" - cache_config = CacheConfig( - max_batches=engine_config.max_batch_size, - block_size=engine_config.block_size, - num_cpu_blocks=engine_config.num_cpu_blocks, - num_gpu_blocks=engine_config.num_gpu_blocks, - cache_max_entry_count=engine_config.cache_max_entry_count, - max_prefill_token_num=engine_config.max_prefill_token_num, - enable_prefix_caching=engine_config.enable_prefix_caching, - quant_policy=engine_config.quant_policy, - device_type=engine_config.device_type, - ) + cache_config = CacheConfig(max_batches=engine_config.max_batch_size, + block_size=engine_config.block_size, + num_cpu_blocks=engine_config.num_cpu_blocks, + num_gpu_blocks=engine_config.num_gpu_blocks, + cache_max_entry_count=engine_config.cache_max_entry_count, + max_prefill_token_num=engine_config.max_prefill_token_num, + enable_prefix_caching=engine_config.enable_prefix_caching, + quant_policy=engine_config.quant_policy, + device_type=engine_config.device_type, + migration_backend=engine_config.migration_backend, + role=engine_config.role) return cache_config @@ -187,7 +193,8 @@ def do_prefill(self): num_waiting = scheduler.num_waiting() max_batches = self.scheduler_config.max_batches # prefill if too much waiting - if num_waiting >= 4: + permitted_waiting = 4 if (self.engine.engine_config.role != EngineRole.Prefill) else 1 + if num_waiting >= permitted_waiting: return True # prefill if no enough running if num_running < max_batches * 0.5: @@ -243,8 +250,11 @@ def __init__(self, engine: 'Engine'): self._is_prefill = True def do_prefill(self): - ret = self._is_prefill - self._is_prefill = not self._is_prefill + if self.engine.engine_config.role in [EngineRole.Hybrid, EngineRole.Decode]: + ret = self._is_prefill + self._is_prefill = not 
self._is_prefill + elif self.engine.engine_config.role == EngineRole.Prefill: + ret = True return ret async def send_next_inputs(self): @@ -353,6 +363,9 @@ def __init__(self, self._start_loop() self._loop_main = None + # for migration loop management + self.migration_event: asyncio.Event = None + @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, @@ -473,7 +486,11 @@ def _on_end_session(self, reqs: List[Request], **kwargs): resp = req.data.get('response', True) resp_type = ResponseType.SESSION_NOT_EXIST if session_id in self.scheduler.sessions: - self.scheduler.end_session(session_id) + msg = list(self.scheduler.sessions[session_id].sequences.values())[0] + if msg.preserve_cache: + self.scheduler._set_message_status(msg, MessageStatus.TO_BE_MIGRATED) + else: + self.scheduler.end_session(session_id) resp_type = ResponseType.SUCCESS if resp: self._response(req.resp, resp_type) @@ -523,18 +540,23 @@ def __update_max_new_tokens(msg): sampling_param = req.data['sampling_param'] return_logits = sampling_param.out_logits if len(sess.sequences) == 0: + migration_request = req.data.get('migration_request') assert len(req.data['token_ids']) > 0, ('Empty input is not allowed.') - sess.add_sequence( - req.data['token_ids'], - sampling_param=sampling_param, - adapter_name=req.data['adapter_name'], - return_logits=return_logits, - multimodals=req.data.get('input_multimodals'), - input_embeddings=req.data.get('input_embeddings'), - ) + sess.add_sequence(req.data['token_ids'], + sampling_param=sampling_param, + adapter_name=req.data['adapter_name'], + return_logits=return_logits, + multimodals=req.data.get('input_multimodals'), + input_embeddings=req.data.get('input_embeddings', ), + migration_request=migration_request, + resp_cache=req.data.get('with_cache'), + preserve_cache=req.data.get('preserve_cache')) msg = next(iter(sess.sequences.values())) __update_max_new_tokens(msg) scheduler.add_sequence(msg) + if migration_request: + self.scheduler._set_message_status(msg, MessageStatus.WAITING_MIGRATION) + self.migration_event.set() else: msg = next(iter(sess.sequences.values())) msg.update_token_ids( @@ -703,6 +725,24 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor, stopped msg.update_token_ids(update_token, model_meta=model_meta) msg.status = MessageStatus.STOPPED + def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, stopped: torch.Tensor, + model_metas: List[Dict[str, Any]]): + """update scheduler.""" + if model_metas is None: + model_metas = [None] * len(running) + for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas): + if msg.status != MessageStatus.MIGRATION_LOCKED: + continue + update_token = token + + # fill token + msg.update_token_ids(update_token, model_meta=model_meta) + msg.num_new_tokens += 1 + if stop: + update_token = _EMPTY_TOKEN + msg.update_token_ids(update_token, model_meta=model_meta) + msg.status = MessageStatus.STOPPED + def _make_infer_outputs(self, next_token_ids: torch.LongTensor, running: SeqList, logits: torch.Tensor, stopped: torch.Tensor, model_metas: List[Dict[str, Any]]): """make infer output.""" @@ -722,12 +762,15 @@ def _make_infer_outputs(self, next_token_ids: torch.LongTensor, running: SeqList if not finish and len(token_ids) == 0: continue session_id = msg.session_id - out = InferOutput( - session_id=session_id, - resp=msg.resp, - finish=finish, - token_ids=token_ids, - ) + if msg.resp_cache: + cache_block_ids = 
self.scheduler.block_manager.get_block_table(msg).tolist() + else: + cache_block_ids = None + out = InferOutput(session_id=session_id, + resp=msg.resp, + finish=finish, + token_ids=token_ids, + cache_block_ids=cache_block_ids) outputs[session_id] = out if msg.return_logits: @@ -815,7 +858,9 @@ def __make_dummy_inputs(): return None # schedule decoding if no valid prefill reqs. - if prefill and len(scheduler_output.running) == 0 and not self.should_execute_dummy_batch: + if prefill and len( + scheduler_output.running + ) == 0 and not self.should_execute_dummy_batch and self.engine_config.role != EngineRole.Prefill: prefill = False scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prefill_interval) @@ -824,7 +869,7 @@ def __make_dummy_inputs(): swap_in_map = scheduler_output.swap_in_map swap_out_map = scheduler_output.swap_out_map - if self.should_execute_dummy_batch and len(running) == 0: + if (self.should_execute_dummy_batch or self.engine_config.role == EngineRole.Prefill) and len(running) == 0: return __make_dummy_inputs() assert len(running) > 0 @@ -877,12 +922,14 @@ def __log_resps(outputs: List[InferOutput]): session_ids = [out.session_id for out in outputs] logger.debug(f'Response sessions: {session_ids}') elif logger.level <= logging.INFO: - logger.info(f'Response: num_outputs={len(outputs)}.') + logger.debug(f'Response: num_outputs={len(outputs)}.') def __send_resp(out: InferOutput): """send response.""" resp_type = (ResponseType.FINISH if out.finish else ResponseType.SUCCESS) - self._response(out.resp, resp_type, data=dict(token_ids=out.token_ids, logits=out.logits)) + self._response(out.resp, + resp_type, + data=dict(token_ids=out.token_ids, logits=out.logits, cache_block_ids=out.cache_block_ids)) def __send_resps(step_outputs: List[InferOutput]): """send response callback.""" @@ -901,6 +948,51 @@ def __send_resps(step_outputs: List[InferOutput]): await self._await_forward_event(forward_event) __send_resps(resps) + @torch.inference_mode() + async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event: asyncio.Event): + """async loop migration.""" + while True: + migration_running = self.scheduler._schedule_migration() + if not migration_running: + await self.migration_event.wait() + else: + self.migration_event.clear() + for msg in migration_running: + migration_execution_requests: List[Tuple[int, List[Tuple[int, int]]]] = [] + migration_request = msg.migration_request + prefill_block_ids = migration_request.remote_block_ids + decode_block_ids = list(self.scheduler.block_manager.get_block_table(msg=msg)) + + assert len(prefill_block_ids) == len(decode_block_ids) + migration_execution_requests.append(( + migration_request.remote_engine_id, + list(zip(prefill_block_ids, decode_block_ids)), + )) + migration_inputs = MigrationExecutionBatch(protocol=migration_request.protocol, + requests=migration_execution_requests) + logger.info(f'migrating session: {msg.session_id} begin') + await self.executor.migrate(migration_inputs) + logger.info(f'migrating session: {msg.session_id} done') + + # generate output + outputs: Dict[int, InferOutput] = dict() + self.scheduler.lock_running_migration(migration_running) + for _, msg in enumerate(migration_running): + session_id = msg.session_id + msg.resp.type = ResponseType.SUCCESS + token_ids = [msg.migration_request.remote_token_id] + out = InferOutput( + session_id=session_id, + resp=msg.resp, + finish=False, + token_ids=np.array(token_ids), + ) + outputs[session_id] = out + 
self.update_running_migration([msg], np.array([token_ids]), [False], [None]) + resp_que.put_nowait(outputs) + self.scheduler.unlock_running_migration(migration_running) + has_runable_event.event.set() + @torch.inference_mode() async def _async_loop_main( self, @@ -917,8 +1009,10 @@ async def _async_loop_main( next_running = None while True: + logger.info('begin loop') if next_running is None: await has_runable_event.wait() + scheduler.collect_migration_done() forward_inputs, next_running = await inputs_maker.send_next_inputs() num_loops = forward_inputs['loop_count'] running = next_running @@ -927,12 +1021,16 @@ async def _async_loop_main( for idx in range(num_loops): if idx >= num_loops - 1: forward_inputs, next_running = await inputs_maker.prefetch_next_inputs() + logger.info('inputs forwarding done') out = await self.executor.get_output_async() + logger.info('get_output_async done') if len(out) > 0: step_outputs = self._make_infer_outputs(**out, running=running) resp_que.put_nowait(step_outputs) + logger.info('send response done') scheduler.unlock_running(running) has_runable_event.set() + logger.info('end loop') @staticmethod def _add_loop_tasks_done_callback(tasks: List[asyncio.Task]): @@ -970,6 +1068,9 @@ async def async_loop(self): forward_event = asyncio.Event() forward_event.set() + # migration task + self.migration_event = asyncio.Event() + logger.info('Starting executor.') self.executor.start(forward_event) @@ -986,9 +1087,15 @@ async def async_loop(self): loop_send_resp = event_loop.create_task(self._async_loop_send_responses(resp_que, forward_event), name='MainLoopResponse') + logger.info('Starting async task MigrationLoop.') + loop_migration = event_loop.create_task( + self._async_loop_migration(resp_que, has_runable_event=has_runable_event), + name='MainLoopMigration', + ) + # binding done callback loop_main = asyncio.current_task() - loop_tasks: List[asyncio.Task] = [loop_main, loop_msg_proc, loop_send_resp] + loop_tasks: List[asyncio.Task] = [loop_main, loop_msg_proc, loop_migration, loop_send_resp] self._add_loop_tasks_done_callback(loop_tasks) self._loop_main = loop_main diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py index da0e1c2ee..985d64831 100644 --- a/lmdeploy/pytorch/engine/engine_instance.py +++ b/lmdeploy/pytorch/engine/engine_instance.py @@ -137,6 +137,9 @@ async def async_stream_infer(self, sampling_param=sampling_param, adapter_name=adapter_name, input_multimodals=multimodal, + migration_request=gen_config.migration_request, + with_cache=gen_config.with_cache, + preserve_cache=gen_config.preserve_cache, ) logger.debug(f'session[{session_id}] add message: num_input_ids={len(input_ids)}.') resp = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg) @@ -144,18 +147,19 @@ async def async_stream_infer(self, while True: resp = await self.req_sender.async_recv(resp) + cache_block_ids = resp.data.get('cache_block_ids', None) if resp.type == ResponseType.SUCCESS: token_ids = resp.data['token_ids'].tolist() num_ids = len(token_ids) logger.debug(f'session[{session_id}] success: num_out_ids={num_ids}.') - yield EngineOutput(resp.type, token_ids, num_ids) + yield EngineOutput(resp.type, token_ids, num_ids, cache_block_ids=cache_block_ids) elif resp.type == ResponseType.FINISH: resp_data = resp.data token_ids = resp_data['token_ids'].tolist() logits = resp_data['logits'] num_ids = len(token_ids) logger.debug(f'session[{session_id}] finish: num_out_ids={num_ids}.') - yield EngineOutput(resp.type, token_ids, num_ids, 
logits=logits) + yield EngineOutput(resp.type, token_ids, num_ids, logits=logits, cache_block_ids=cache_block_ids) break else: logger.debug(f'session[{session_id}] failed.') diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py index a87107bcb..5c003c302 100644 --- a/lmdeploy/pytorch/engine/executor/base.py +++ b/lmdeploy/pytorch/engine/executor/base.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. # Inspired by vLLM: https://github.com/vllm-project/vllm import asyncio -from typing import Any, Dict +from typing import Any, Dict, List from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, ModelConfig +from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch +from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest from lmdeploy.pytorch.engine.cache_engine import CacheEngine from lmdeploy.utils import get_logger @@ -90,6 +92,22 @@ async def get_output_async(self): """get output async.""" raise NotImplementedError('Not Implemented') + """ PD Disaggregation API Begin """ + + def p2p_initialize(self, remote_engine_config: DistServeInitRequest): + """init rdma link.""" + raise NotImplementedError('Not implemented') + + def p2p_connect(self, conn_request: List[DistServeConnectionRequest]): + """rdma_connect.""" + raise NotImplementedError('Not Implemented') + + async def migrate(self, batch: MigrationExecutionBatch): + """KV Cache Migration.""" + raise NotImplementedError('Not Implemented') + + """ PD Disaggregation API End """ + def _get_runtime_size(self, num_free_gpu_mem: int, cache_block_size: int, vocal_size: int): """find best prefill num.""" cache_max_entry_count = self.cache_config.cache_max_entry_count diff --git a/lmdeploy/pytorch/engine/executor/base_worker.py b/lmdeploy/pytorch/engine/executor/base_worker.py index 2dd24287d..cd6f7aab6 100644 --- a/lmdeploy/pytorch/engine/executor/base_worker.py +++ b/lmdeploy/pytorch/engine/executor/base_worker.py @@ -1,10 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import asyncio -from typing import Any, Dict +from typing import Any, Dict, List from lmdeploy.pytorch.backends.selector import get_backend from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, ModelConfig from lmdeploy.pytorch.devices import DeviceContext +from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch +from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest from lmdeploy.pytorch.distributed import DistContext from lmdeploy.pytorch.engine.model_agent import build_model_agent from lmdeploy.utils import get_logger @@ -159,3 +161,16 @@ async def get_output_async(self): def release(self): """stop engine loop.""" self.model_agent.release() + + """ PD Disaggregation API Begin """ + + def p2p_initialize(self, init_request: DistServeInitRequest): + return self.model_agent.cache_engine.p2p_initialize(init_request) + + def p2p_connect(self, conn_request: List[DistServeConnectionRequest]): + return self.model_agent.cache_engine.p2p_connect(conn_request) + + async def migrate(self, inputs: MigrationExecutionBatch): + return self.model_agent.cache_engine.migrate(inputs) + + """ PD Disaggregation API End """ diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py index e9ef101f4..be013d67b 100644 --- a/lmdeploy/pytorch/engine/executor/ray_executor.py +++ b/lmdeploy/pytorch/engine/executor/ray_executor.py @@ -15,6 +15,8 @@ from lmdeploy.pytorch.backends.selector import init_backend from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, ModelConfig from lmdeploy.pytorch.devices import DeviceContext, get_device_manager +from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch +from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest from lmdeploy.utils import get_logger from .base import ExecutorBase @@ -387,7 +389,7 @@ def _prefetch_task_callback(self, task: asyncio.Task): except KeyboardInterrupt: logger.debug(f'{task.get_name()} KeyboardInterrupt.') except BaseException: - logger.exception(f'{task.get_name()} task failed.') + logger.debug(f'{task.get_name()} task failed.') def start(self, forward_event: asyncio.Event): """start engine loop.""" @@ -552,3 +554,18 @@ def _init_distributed_environment_by_device(self, device_str: str): ray.get([worker.set_env.remote(envs) for worker in self.workers]) else: raise ValueError(f'Unsupported device type: {device_str}') + + """ PD Disaggregation API Begin """ + + def p2p_initialize(self, init_request: DistServeInitRequest): + return self.collective_rpc('p2p_initialize', (init_request, )) + + def p2p_connect(self, conn_request: List[DistServeConnectionRequest]): + """rdma connect.""" + return self.collective_rpc('p2p_connect', (conn_request, )) + + async def migrate(self, batch: MigrationExecutionBatch): + jobs = (worker.migrate.remote(batch) for worker in self.workers) + return await asyncio.gather(*jobs) + + """ PD Disaggregation API Begin """ diff --git a/lmdeploy/pytorch/engine/executor/uni_executor.py b/lmdeploy/pytorch/engine/executor/uni_executor.py index f5475be80..aaeb6342c 100644 --- a/lmdeploy/pytorch/engine/executor/uni_executor.py +++ b/lmdeploy/pytorch/engine/executor/uni_executor.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import asyncio -from typing import Any, Dict +from typing import Any, Dict, List from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, ModelConfig from lmdeploy.pytorch.devices import DeviceContext +from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch +from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest from lmdeploy.pytorch.engine.model_agent import build_model_agent from lmdeploy.utils import get_logger @@ -97,3 +99,22 @@ async def get_output_async(self, dp_rank: int = 0): def get_input_processor(self): """get input processor.""" return self.model_agent.get_input_processor() + + """ PD Disaggregation API Begin """ + + def p2p_initialize(self, init_request: DistServeInitRequest): + """init rdma link. + + note: return a list to be compatible with multiprocess executors like Ray. + """ + return [self.model_agent.cache_engine.p2p_initialize(init_request)] + + def p2p_connect(self, conn_request: List[DistServeConnectionRequest]): + """rdma_connect.""" + self.model_agent.cache_engine.p2p_connect(conn_request) + + async def migrate(self, batch: MigrationExecutionBatch): + """KV Cache Migration.""" + return self.model_agent.cache_engine.migrate(batch) + + """ PD Disaggregation API End """ diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 9cdc9256c..1ce9041fd 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -7,6 +7,7 @@ import torch import torch.distributed as dist +from lmdeploy.pytorch.disagg.config import EngineRole from lmdeploy.utils import get_logger from ..backends import get_backend @@ -367,10 +368,13 @@ async def __await_distworker(worker, timeout: float = 0.001): rank = dist_ctx.rank tp = dist_ctx.tp dp = dist_ctx.dp + tp_cpu_group = dist_ctx.tp_cpu_group + tp_gpu_group = dist_ctx.tp_gpu_group - logger.info(f' rank[{rank}]: ' - f'batch_size={inputs.seq_length.size(0)} ' - f'num_tokens={inputs.input_ids.size(-1)}') + logger.debug(f' rank[{rank}]: ' + f'batch_size={inputs.seq_length.size(0)} ' + f'num_tokens={inputs.input_ids.size(-1)} ' + f'is_decoding={inputs.is_decoding}') is_decoding = inputs.is_decoding eager_mode = self.backend_config.eager_mode @@ -447,13 +451,16 @@ async def __await_distworker(worker, timeout: float = 0.001): # Avoid adding the ADInplaceOrView dispatch key to `next_token_ids`, # as it can trigger recompilation on different ranks when using torch.compile.
with torch.inference_mode(): - next_token_ids = torch.empty_like(num_ignore_eos) + next_token_ids = torch.zeros_like(num_ignore_eos) stopped = None if need_broadcast_next: - logger.debug(f' rank[{rank}]: synchornize token ids [{idx}]') - tp_gpu_group = dist_ctx.tp_gpu_group - dist.broadcast(next_token_ids, src=rank // tp * tp, group=tp_gpu_group) + logger.info(f' rank[{rank}]: synchronize token ids [{idx}]') + if self.cache_config.role == EngineRole.Decode: + next_token_ids = next_token_ids.cpu() + dist.all_reduce(next_token_ids, op=dist.ReduceOp.SUM, group=tp_cpu_group) + else: + dist.broadcast(next_token_ids, src=0, group=tp_gpu_group) # send output model_metas = output.get('model_metas') @@ -640,7 +647,11 @@ def build_cache_engine(self): attn_dist_cfg = dist_ctx.dist_config.attn_config tp = attn_dist_cfg.tp - self.cache_engine = CacheEngine(self.cache_config, self.model_config, world_size=tp) + self.cache_engine = CacheEngine(self.cache_config, + self.model_config, + rank=self.rank, + tp_rank=self.tp_rank, + world_size=tp) def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index f59e73f4b..8542e8d18 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -8,6 +8,8 @@ from torch import Tensor from lmdeploy.messages import GenerationConfig, LogitsProcessor +from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch +from lmdeploy.pytorch.disagg.request import MigrationRequest from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs from lmdeploy.utils import get_logger @@ -135,6 +137,17 @@ class MessageStatus(enum.Enum): ABORTED = enum.auto() LOCKED = enum.auto() + # PD Disaggregation + # WAITING_MIGRATION: requests that have not been migrated yet, + # in both the prefill and decode engines + # RUNNING_MIGRATION: requests whose KV cache is being migrated, + # in the decode engine + TO_BE_MIGRATED = enum.auto() + WAITING_MIGRATION = enum.auto() + RUNNING_MIGRATION = enum.auto() + MIGRATION_LOCKED = enum.auto() + MIGRATION_DONE = enum.auto() + _SEQ_COUNT = 0 @@ -215,7 +228,10 @@ def add_sequence(self, adapter_name: str = None, return_logits: bool = False, multimodals: MultiModalInputs = None, - input_embeddings: List[InputEmbeddings] = None) -> 'SchedulerSequence': + input_embeddings: List[InputEmbeddings] = None, + migration_request: Optional[MigrationRequest] = None, + resp_cache: bool = False, + preserve_cache: bool = False) -> 'SchedulerSequence': """Add a new message.""" if isinstance(token_ids, Tensor): token_ids = token_ids.numpy() @@ -237,6 +253,9 @@ def add_sequence(self, history_embeddings=HistoryEmbeddings(input_embeddings), history_multimodals=HistoryMultiModals(multimodals), return_logits=return_logits, + migration_request=migration_request, + resp_cache=resp_cache, + preserve_cache=preserve_cache, ) self.sequences[seq.seq_id] = seq if self.seq_manager is not None: @@ -443,6 +462,12 @@ class SchedulerSequence: num_ignored_history: int = 0 model_meta: Dict[str, Any] = None + # For Disaggregation + migration_request: Optional[MigrationRequest] = None + resp_cache: bool = False + preserve_cache: bool = False + migration_inputs: Optional[MigrationExecutionBatch] = None + def __post_init__(self): """post init.""" self._num_history_ids: int = 0 diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 275a2b9ab..cae5a0969 100644 ---
a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from contextlib import contextmanager from dataclasses import dataclass, field, fields -from typing import Any, Dict, List, Literal +from typing import Any, Dict, List, Literal, Optional import torch @@ -9,6 +9,8 @@ import lmdeploy.pytorch.distributed as dist from lmdeploy.pytorch.backends import get_backend from lmdeploy.pytorch.config import ModelConfig +from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch +from lmdeploy.pytorch.disagg.request import MigrationRequest from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor @@ -151,6 +153,8 @@ class ModelInputs: history_cross_length: torch.LongTensor = None model_metas: List[Dict[str, Any]] = None dp_meta: 'DPMeta' = None + migration_inputs: Optional[MigrationExecutionBatch] = None + migration_requests: Optional[MigrationRequest] = None def update(self, input_ids: torch.LongTensor): """update input ids.""" diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 314fbcd49..bc62c1491 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -3,8 +3,9 @@ from collections import OrderedDict from dataclasses import dataclass -from typing import Dict, List +from typing import Dict, List, Tuple +from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch from lmdeploy.utils import get_logger, logging_timer from ..config import CacheConfig, SchedulerConfig @@ -41,6 +42,9 @@ def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig) self.sessions: Dict[int, SchedulerSession] = OrderedDict() + # For Disaggregation + self.locked_sessions: Dict[int, SchedulerSession] = OrderedDict() + self.block_manager = build_block_manager(cache_config) self.block_trie = BlockTrie(self.cache_config, self.block_manager) @@ -72,6 +76,24 @@ def locked(self): seq_map = self.seq_manager.get_sequences(MessageStatus.LOCKED) return list(seq_map.values()) + @property + def waiting_migration(self): + """get waiting migration sequence.""" + seq_map = self.seq_manager.get_sequences(MessageStatus.WAITING_MIGRATION) + return list(seq_map.values()) + + @property + def running_migration(self): + """get running migration sequence.""" + seq_map = self.seq_manager.get_sequences(MessageStatus.RUNNING_MIGRATION) + return list(seq_map.values()) + + @property + def migration_done(self): + """get migration done sequence.""" + seq_map = self.seq_manager.get_sequences(MessageStatus.MIGRATION_DONE) + return list(seq_map.values()) + def build_eviction_helper(self, eviction_type: str): if eviction_type == 'copy': logger.warning('`copy` eviction has been deprecated, ' @@ -114,6 +136,46 @@ def add_sequence(self, seq: SchedulerSequence): # push message to waiting queue self._set_message_status(seq, MessageStatus.WAITING) + @logging_timer('ScheduleMigration', logger) + def _schedule_migration(self): + + running_migration: SeqList = [] + migrating_token_count = 0 + + def _to_running(seq: SchedulerSequence): + """to running.""" + seq.status = MessageStatus.RUNNING_MIGRATION + running_migration.append(seq) + nonlocal migrating_token_count + migrating_token_count += seq.num_token_ids + + def __evict_for_seq(seq: SchedulerSequence, waiting): + """evict until can append.""" + from itertools import chain + + hanging = reversed(self.hanging) + waiting = reversed(waiting) + evictable = list(chain(hanging, waiting)) + return
self.eviction_helper.evict_for_seq(seq, evictable, 0) + + def _reorder_migrating(): + """reorder waiting.""" + return sorted(self.waiting_migration, key=lambda seq: seq.arrive_time) + + waiting = _reorder_migrating() + + while len(waiting) > 0: + seq = waiting.pop(0) + self.block_trie.match(waiting) + if not __evict_for_seq(seq, waiting): + break + + # allocate session memory + self.block_manager.allocate(seq) + _to_running(seq) + + return running_migration + @logging_timer('SchedulePrefilling', logger) def _schedule_prefill(self): """Schedule for prefilling.""" @@ -132,6 +194,20 @@ def _to_running(seq: SchedulerSequence): running.append(seq) nonlocal token_count token_count += seq.num_token_ids + if seq.migration_request: + migration_execution_requests: List[Tuple[int, List[Tuple[int, int]]]] = [] + migration_request = seq.migration_request + prefill_block_ids = migration_request.remote_block_ids + decode_block_ids = list(self.block_manager.get_block_table(msg=seq)) + + assert len(prefill_block_ids) == len(decode_block_ids) + migration_execution_requests.append(( + migration_request.remote_engine_id, + list(zip(prefill_block_ids, decode_block_ids)), + )) + migration_inputs = MigrationExecutionBatch(protocol=migration_request.protocol, + requests=migration_execution_requests) + seq.migration_inputs = migration_inputs def __evict_for_seq(seq: SchedulerSequence, waiting): """evict until can append.""" @@ -265,7 +341,7 @@ def end_session(self, session_id: int): def has_unfinished(self): """Check if there are any unfinished message.""" - return self.has_running() or self.has_waiting() + return self.has_running() or self.has_waiting() or self.has_migration_done() def has_running(self): return self.num_running() > 0 @@ -273,6 +349,15 @@ def has_running(self): def has_waiting(self): return self.num_waiting() > 0 + def has_migration_running(self): + return self.num_migration_running() > 0 + + def has_migration_waiting(self): + return self.num_migration_waiting() > 0 + + def has_migration_done(self): + return self.num_migration_done() > 0 + def get_block_tables(self, seqs: SeqList): """get block table of the sequences.""" return [self.block_manager.get_block_table(seq) for seq in seqs] @@ -285,6 +370,18 @@ def num_waiting(self): """num waiting.""" return self.seq_manager.num_sequences(MessageStatus.WAITING) + def num_migration_running(self): + """num migration running.""" + return self.seq_manager.num_sequences(MessageStatus.RUNNING_MIGRATION) + + def num_migration_done(self): + """num migration done.""" + return self.seq_manager.num_sequences(MessageStatus.MIGRATION_DONE) + + def num_migration_waiting(self): + """num migration waiting.""" + return self.seq_manager.num_sequences(MessageStatus.WAITING_MIGRATION) + def num_locked(self): """num locked.""" return self.seq_manager.num_sequences(MessageStatus.LOCKED) @@ -299,3 +396,20 @@ def unlock_running(self, locked: SeqList): for seq in locked: if seq.status == MessageStatus.LOCKED: self._set_message_status(seq, MessageStatus.RUNNING) + + def lock_running_migration(self, running: SeqList): + """lock running migration.""" + for seq in running: + if seq.status == MessageStatus.RUNNING_MIGRATION: + self._set_message_status(seq, MessageStatus.MIGRATION_LOCKED) + + def unlock_running_migration(self, locked: SeqList): + """unlock running migration.""" + for seq in locked: + if seq.status == MessageStatus.MIGRATION_LOCKED: + self._set_message_status(seq, MessageStatus.MIGRATION_DONE) + + def collect_migration_done(self): + migration_done = self.migration_done + for seq in
migration_done: + self._set_message_status(seq, MessageStatus.RUNNING) diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index ae8f9eb58..03c2e81b5 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -22,6 +22,7 @@ from lmdeploy.logger import RequestLogger from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, Response, ResponseType, TurbomindEngineConfig from lmdeploy.model import MODELS, BaseChatTemplate, ChatTemplateConfig, best_match_model +from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest from lmdeploy.serve.utils import LogitsMixin from lmdeploy.tokenizer import DetokenizeState from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_hf_gen_cfg, get_logger @@ -59,6 +60,9 @@ class GenOut: logits: Any = None last_hidden_state: Any = None + # for disaggregation + cache_block_ids: List[int] = None + def _gen_out_to_response(out: GenOut, index) -> Response: return Response(text=out.response, @@ -757,7 +761,13 @@ def is_error(status): spaces_between_special_tokens=gen_config.spaces_between_special_tokens) res = token_ids[ids_offset:] - out = GenOut(response, history_len, input_len, gen_len, finish_reason, res) + out = GenOut(response, + history_len, + input_len, + gen_len, + finish_reason, + token_ids=res, + cache_block_ids=outputs.cache_block_ids) if outputs.logprobs is not None: log_offset = ids_offset - start_ids_offset @@ -786,7 +796,13 @@ def is_error(status): logger.info(f'session {session_id} finished, reason ' f'"{finish_reason}", input_tokens ' f'{len(input_ids)}, outupt_tokens {gen_len}') - yield GenOut(response, self.id2step[session_id], len(input_ids), gen_len, finish_reason) + yield GenOut(response, + self.id2step[session_id], + len(input_ids), + gen_len, + finish_reason, + token_ids=token_ids, + cache_block_ids=outputs.cache_block_ids) else: logger.error(f'session {session_id} finished, ' 'reason "error"') @@ -880,3 +896,20 @@ def _gen(): session.generator = None return session + + """ DistServe Async Engine API Begin """ + + def free_cache(self, session_id: int): + if session_id in self.engine.scheduler.sessions: + self.engine.scheduler.end_session(session_id) + logger.debug(f'successfully freed session {session_id}') + else: + logger.warning(f'Invalid free_cache request: session {session_id} not found.') + + def p2p_initialize(self, init_request: DistServeInitRequest): + return self.engine.executor.p2p_initialize(init_request) + + def p2p_connect(self, conn_request: List[DistServeConnectionRequest]): + return self.engine.executor.p2p_connect(conn_request) + + """ DistServe Async Engine API End """ diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 00f463917..35b9c11a5 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -2,6 +2,7 @@ # yapf: disable import asyncio import copy +import json import os import time from functools import partial @@ -18,6 +19,8 @@ from lmdeploy.archs import get_task from lmdeploy.messages import GenerationConfig, LogitsProcessor, PytorchEngineConfig, TurbomindEngineConfig from lmdeploy.model import ChatTemplateConfig +from lmdeploy.pytorch.disagg.config import DistServeEngineConfig +from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest, MigrationRequest from lmdeploy.serve.async_engine import AsyncEngine from lmdeploy.serve.openai.protocol import ChatCompletionResponse # noqa: E501 from lmdeploy.serve.openai.protocol import
(ChatCompletionRequest, ChatCompletionResponseChoice, @@ -263,7 +266,7 @@ def _logit_bias_processor( @router.post('/v1/chat/completions', dependencies=[Depends(check_api_key)]) -async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Request = None): +async def chat_completions_v1(raw_request: Request = None): """Completion API similar to OpenAI's API. Refer to `https://platform.openai.com/docs/api-reference/chat/create` @@ -323,6 +326,14 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque - presence_penalty (replaced with repetition_penalty) - frequency_penalty (replaced with repetition_penalty) """ + json_request = await raw_request.json() + request = ChatCompletionRequest.model_validate(json_request) + migration_request = json_request.pop('migration_request', None) + with_cache = json_request.pop('with_cache', False) + preserve_cache = json_request.pop('preserve_cache', False) + if migration_request: + migration_request = MigrationRequest.model_validate(migration_request) + if request.session_id == -1: VariableInterface.session_id += 1 request.session_id = VariableInterface.session_id @@ -376,7 +387,10 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque min_new_tokens=request.min_new_tokens, min_p=request.min_p, random_seed=random_seed, - spaces_between_special_tokens=request.spaces_between_special_tokens) + spaces_between_special_tokens=request.spaces_between_special_tokens, + migration_request=migration_request, + with_cache=with_cache, + preserve_cache=preserve_cache) tools = None if request.tools and request.tool_choice != 'none': @@ -485,6 +499,9 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: finish_reason=res.finish_reason, logprobs=logprobs, usage=usage) + if res.cache_block_ids is not None: + response_json['cache_block_ids'] = res.cache_block_ids + response_json['remote_token_ids'] = res.token_ids yield f'data: {response_json}\n\n' yield 'data: [DONE]\n\n' @@ -497,6 +514,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: final_token_ids = [] final_res = None text = '' + cache_block_ids = [] + remote_token_ids = [] async for res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. 
@@ -508,6 +527,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: final_token_ids.extend(res.token_ids) if res.logprobs: final_logprobs.extend(res.logprobs) + cache_block_ids.append(res.cache_block_ids) + remote_token_ids.append(res.token_ids) tool_calls = None reasoning_content = None @@ -543,6 +564,10 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) choices.append(choice_data) + if with_cache: + cache_block_ids = cache_block_ids[0] + remote_token_ids = [remote_token_ids[0][-1]] + total_tokens = sum([final_res.history_token_len, final_res.input_token_len, final_res.generate_token_len]) usage = UsageInfo( prompt_tokens=final_res.input_token_len, @@ -555,13 +580,17 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: model=model_name, choices=choices, usage=usage, - ) + ).model_dump() + + if with_cache: + response['cache_block_ids'] = cache_block_ids + response['remote_token_ids'] = remote_token_ids return response @router.post('/v1/completions', dependencies=[Depends(check_api_key)]) -async def completions_v1(request: CompletionRequest, raw_request: Request = None): +async def completions_v1(raw_request: Request = None): """Completion API similar to OpenAI's API. Go to `https://platform.openai.com/docs/api-reference/completions/create` @@ -607,6 +636,14 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None - presence_penalty (replaced with repetition_penalty) - frequency_penalty (replaced with repetition_penalty) """ + json_request = await raw_request.json() + request = CompletionRequest.model_validate(json_request) + migration_request = json_request.pop('migration_request', None) + with_cache = json_request.pop('with_cache', False) + preserve_cache = json_request.pop('preserve_cache', False) + if migration_request: + migration_request = MigrationRequest.model_validate(migration_request) + if request.session_id == -1: VariableInterface.session_id += 1 request.session_id = VariableInterface.session_id @@ -640,7 +677,10 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None skip_special_tokens=request.skip_special_tokens, min_p=request.min_p, random_seed=random_seed, - spaces_between_special_tokens=request.spaces_between_special_tokens) + spaces_between_special_tokens=request.spaces_between_special_tokens, + migration_request=migration_request, + with_cache=with_cache, + preserve_cache=preserve_cache) generators = [] for i in range(len(request.prompt)): result_generator = VariableInterface.async_engine.generate( @@ -670,8 +710,7 @@ def create_stream_response_json(index: int, choices=[choice_data], usage=usage, ) - response_json = response.model_dump_json() - + response_json = response.model_dump() return response_json async def completion_stream_generator() -> AsyncGenerator[str, None]: @@ -702,7 +741,10 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: finish_reason=res.finish_reason, logprobs=logprobs, usage=usage) - yield f'data: {response_json}\n\n' + if res.cache_block_ids is not None: + response_json['cache_block_ids'] = res.cache_block_ids + response_json['remote_token_ids'] = res.token_ids + yield f'data: {json.dumps(response_json)}\n\n' yield 'data: [DONE]\n\n' # Streaming response @@ -712,8 +754,11 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: # Non-streaming response usage = UsageInfo() choices = [None] * len(generators) + cache_block_ids = [] + remote_token_ids = [] async def _inner_call(i, generator): + 
nonlocal cache_block_ids, remote_token_ids final_logprobs = [] final_token_ids = [] final_res = None @@ -725,6 +770,8 @@ async def _inner_call(i, generator): return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') final_res = res text += res.response + cache_block_ids.append(res.cache_block_ids) + remote_token_ids.append(res.token_ids) if res.token_ids: final_token_ids.extend(res.token_ids) if res.logprobs: @@ -748,6 +795,10 @@ async def _inner_call(i, generator): ) choices[i] = choice_data + if with_cache: + cache_block_ids = cache_block_ids[0] + remote_token_ids = [remote_token_ids[0][-1]] + total_tokens = sum([final_res.history_token_len, final_res.input_token_len, final_res.generate_token_len]) usage.prompt_tokens += final_res.input_token_len usage.completion_tokens += final_res.generate_token_len @@ -761,7 +812,11 @@ async def _inner_call(i, generator): model=model_name, choices=choices, usage=usage, - ) + ).model_dump() + + if with_cache: + response['cache_block_ids'] = cache_block_ids + response['remote_token_ids'] = remote_token_ids return response @@ -801,6 +856,46 @@ def encode(prompt: str, do_preprocess: bool, add_bos: bool): return EncodeResponse(input_ids=encoded, length=length) +""" PD Disaggregation API Begin """ + + +@router.get('/distserve/engine_info') +async def engine_info(): + engine = VariableInterface.async_engine.engine + + response = DistServeEngineConfig(tp_size=engine.engine_config.tp, + dp_size=engine.engine_config.dp, + pp_size=None, + ep_size=engine.engine_config.ep, + dp_rank=engine.engine_config.dp_rank, + block_size=engine.engine_config.block_size, + num_cpu_blocks=engine.scheduler.block_manager.num_cpu_blocks, + num_gpu_blocks=engine.scheduler.block_manager.num_gpu_blocks) + + return response.model_dump_json() + + +@router.post('/distserve/p2p_initialize') +async def p2p_initialize(init_request: DistServeInitRequest): + return VariableInterface.async_engine.p2p_initialize(init_request) + + +@router.post('/distserve/p2p_connect') +async def p2p_connect(conn_request: List[DistServeConnectionRequest]): + return VariableInterface.async_engine.p2p_connect(conn_request) + + +@router.post('/distserve/free_cache') +async def free_cache(raw_request: Request) -> JSONResponse: + config = await raw_request.json() + session_id = int(config['session_id']) + VariableInterface.async_engine.free_cache(session_id) + return {'status': 'SUCCESS'} + + +""" PD Disaggregation API End """ + + @router.post('/v1/chat/interactive', dependencies=[Depends(check_api_key)]) async def chat_interactive_v1(request: GenerateRequest, raw_request: Request = None): """Generate completion for the request. 
@@ -960,14 +1055,20 @@ async def startup_event(): return try: import requests + engine_config = VariableInterface.async_engine.engine.engine_config url = f'{VariableInterface.proxy_url}/nodes/add' - data = {'url': VariableInterface.api_server_url, 'status': {'models': get_model_list()}} + data = { + 'url': VariableInterface.api_server_url, + 'status': { + 'models': get_model_list(), + 'role': engine_config.role.value + } + } headers = {'accept': 'application/json', 'Content-Type': 'application/json'} response = requests.post(url, headers=headers, json=data) if response.status_code != 200: raise HTTPException(status_code=400, detail='Service registration failed') - print(response.text) except Exception as e: print(f'Service registration failed: {e}') diff --git a/lmdeploy/serve/proxy/constants.py b/lmdeploy/serve/proxy/constants.py index 17cd167c9..3484d7e50 100644 --- a/lmdeploy/serve/proxy/constants.py +++ b/lmdeploy/serve/proxy/constants.py @@ -15,7 +15,7 @@ 'through env variable AIOHTTP_TIMEOUT') -class Strategy(enum.Enum): +class RoutingStrategy(enum.Enum): """Strategy to dispatch requests to nodes.""" RANDOM = enum.auto() MIN_EXPECTED_LATENCY = enum.auto() diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py index 24766334f..37a145277 100644 --- a/lmdeploy/serve/proxy/proxy.py +++ b/lmdeploy/serve/proxy/proxy.py @@ -16,16 +16,20 @@ import numpy as np import requests import uvicorn -import yaml from fastapi import BackgroundTasks, Depends, FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel, Field +from lmdeploy.pytorch.disagg.config import (DistServeRDMAConfig, EngineRole, MigrationProtocol, RDMALinkType, + ServingStrategy) +from lmdeploy.pytorch.disagg.conn import PDConnectionPool +from lmdeploy.pytorch.disagg.messages import PDConnectionMessage +from lmdeploy.pytorch.disagg.request import MigrationRequest from lmdeploy.serve.openai.api_server import check_api_key, create_error_response from lmdeploy.serve.openai.protocol import ModelCard # noqa: E501 from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest, ModelList, ModelPermission -from lmdeploy.serve.proxy.constants import AIOHTTP_TIMEOUT, LATENCY_DEQUE_LEN, ErrorCodes, Strategy, err_msg +from lmdeploy.serve.proxy.constants import AIOHTTP_TIMEOUT, LATENCY_DEQUE_LEN, ErrorCodes, RoutingStrategy, err_msg from lmdeploy.utils import get_logger logger = get_logger('lmdeploy') @@ -33,6 +37,7 @@ class Status(BaseModel): """Status protocol consists of models' information.""" + role: EngineRole = EngineRole.Hybrid models: Optional[List[str]] = Field(default=[], examples=[[]]) unfinished: int = 0 latency: Deque = Field(default=deque(maxlen=LATENCY_DEQUE_LEN), examples=[[]]) @@ -72,36 +77,68 @@ class NodeManager: def __init__(self, config_path: Optional[str] = None, - strategy: str = 'min_expected_latency', + serving_strategy: str = 'Hybrid', + routing_strategy: str = 'min_expected_latency', + migration_protocol: str = 'RDMA', + link_type: str = 'RoCE', + with_gdr: bool = True, cache_status: Optional[bool] = True) -> None: self.nodes = dict() - self.strategy = Strategy.from_str(strategy) + self.serving_strategy = ServingStrategy[serving_strategy] + self.routing_strategy = RoutingStrategy.from_str(routing_strategy) + self.cache_status = cache_status self.latencies = dict() - self.config_path = osp.join(osp.dirname(osp.realpath(__file__)), 'proxy_config.yml') + self.config_path = 
osp.join(osp.dirname(osp.realpath(__file__)), 'proxy_config.json') if config_path is not None: self.config_path = config_path if osp.exists(self.config_path) and self.cache_status: with open(self.config_path, 'r') as config_file: - self.nodes = yaml.safe_load(config_file)['nodes'] - for url, status in self.nodes.items(): - latency = deque(status.get('latency', []), maxlen=LATENCY_DEQUE_LEN) - status['latency'] = latency - status = Status(**status) - self.nodes[url] = status + if os.path.getsize(self.config_path) > 0: + logger.info(f'loading node configuration: {self.config_path}') + config = json.load(config_file) + self.nodes = { + node_url: Status.model_validate_json(node_status) + for node_url, node_status in config.items() + } self.heart_beat_thread = threading.Thread(target=heart_beat_controller, args=(self, ), daemon=True) self.heart_beat_thread.start() self.aiotimeout = aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT) + # For PD Disaggregation + self.migration_protocol = MigrationProtocol[migration_protocol] + self.rdma_config = DistServeRDMAConfig(with_gdr=with_gdr, link_type=RDMALinkType[link_type]) + self.pd_connection_pool = PDConnectionPool() + self.initialized = False + + def get_nodes(self, role: EngineRole) -> Dict: + return {node_url: node_status for (node_url, node_status) in self.nodes.items() if node_status.role == role} + + @property + def hybrid_nodes(self): + return self.get_nodes(EngineRole.Hybrid) + + @property + def prefill_nodes(self): + return self.get_nodes(EngineRole.Prefill) + + @property + def decode_nodes(self): + return self.get_nodes(EngineRole.Decode) + def update_config_file(self): """Update the config file.""" nodes = copy.deepcopy(self.nodes) - for url, status in nodes.items(): - nodes[url] = status.model_dump() - nodes[url]['latency'] = list(status.latency)[-LATENCY_DEQUE_LEN:] + for _, status in nodes.items(): + status.latency = deque(list(status.latency)[-LATENCY_DEQUE_LEN:]) if self.cache_status: with open(self.config_path, 'w') as config_file: # update cfg yml - yaml.dump(dict(nodes=nodes), config_file) + json.dump({ + node_url: node_status.model_dump_json() + for node_url, node_status in nodes.items() + }, + config_file, + indent=2) def add(self, node_url: str, status: Optional[Status] = None): """Add a node to the manager. @@ -117,6 +154,7 @@ def add(self, node_url: str, status: Optional[Status] = None): if status is None: status = self.nodes.get(node_url, Status()) if status.models != []: # force register directly + self.remove(node_url) self.nodes[node_url] = status self.update_config_file() return @@ -134,6 +172,12 @@ def remove(self, node_url: str): if node_url in self.nodes.keys(): self.nodes.pop(node_url) self.update_config_file() + dropped_conn = [] + for conn in self.pd_connection_pool.pool: + if node_url in conn: + dropped_conn.append(conn) + for conn in dropped_conn: + self.pd_connection_pool.drop(*conn) def remove_stale_nodes_by_expiration(self): """remove stale nodes.""" @@ -156,8 +200,8 @@ def remove_stale_nodes_by_expiration(self): def model_list(self): """Supported model list.""" model_names = [] - for node_url, node_status in self.nodes.items(): - model_names.extend(node_status.models) + for _, status in self.nodes.items(): + model_names.extend(status.models) return model_names @property @@ -165,7 +209,7 @@ def status(self): """Return the status.""" return self.nodes - def get_node_url(self, model_name: str): + def get_node_url(self, model_name: str, role: EngineRole = EngineRole.Hybrid): """Add a node to the manager. 
Args: @@ -177,11 +221,11 @@ def get_node_url(self, model_name: str): def get_matched_urls(): urls_with_speeds, speeds, urls_without_speeds = [], [], [] - for node_url, node_status in self.nodes.items(): - if model_name in node_status.models: - if node_status.speed is not None: + for node_url, status in self.get_nodes(role).items(): + if model_name in status.models: + if status.speed is not None: urls_with_speeds.append(node_url) - speeds.append(node_status.speed) + speeds.append(status.speed) else: urls_without_speeds.append(node_url) all_matched_urls = urls_with_speeds + urls_without_speeds @@ -193,7 +237,7 @@ def get_matched_urls(): all_the_speeds = speeds + [average_speed] * len(urls_without_speeds) return all_matched_urls, all_the_speeds - if self.strategy == Strategy.RANDOM: + if self.routing_strategy == RoutingStrategy.RANDOM: all_matched_urls, all_the_speeds = get_matched_urls() if len(all_matched_urls) == 0: return None @@ -202,7 +246,7 @@ def get_matched_urls(): index = random.choices(range(len(all_matched_urls)), weights=weights)[0] url = all_matched_urls[index] return url - elif self.strategy == Strategy.MIN_EXPECTED_LATENCY: + elif self.routing_strategy == RoutingStrategy.MIN_EXPECTED_LATENCY: all_matched_urls, all_the_speeds = get_matched_urls() if len(all_matched_urls) == 0: return None @@ -212,15 +256,15 @@ def get_matched_urls(): all_indexes = [i for i in range(len(all_the_speeds))] random.shuffle(all_indexes) for index in all_indexes: - latency = self.nodes[all_matched_urls[index]].unfinished / all_the_speeds[index] + latency = self.get_nodes(role)[all_matched_urls[index]].unfinished / all_the_speeds[index] if min_latency > latency: min_latency = latency min_index = index url = all_matched_urls[min_index] return url - elif self.strategy == Strategy.MIN_OBSERVED_LATENCY: + elif self.routing_strategy == RoutingStrategy.MIN_OBSERVED_LATENCY: all_matched_urls, latencies = [], [] - for node_url, node_status in self.nodes.items(): + for node_url, node_status in self.get_nodes(role).items(): if model_name in node_status.models: if len(node_status.latency): latencies.append(np.mean(np.array(node_status.latency))) @@ -232,7 +276,7 @@ def get_matched_urls(): index = np.argmin(np.array(latencies)) return all_matched_urls[index] else: - raise ValueError(f'Invalid strategy: {self.strategy}') + raise ValueError(f'Invalid strategy: {self.routing_strategy}') async def check_request_model(self, model_name) -> Optional[JSONResponse]: """Check if a request is valid.""" @@ -263,7 +307,12 @@ def handle_api_timeout(self, node_url): } return json.dumps(ret).encode() + b'\n' - async def stream_generate(self, request: Dict, node_url: str, endpoint: str): + async def stream_generate(self, + request: Dict, + node_url: str, + endpoint: str, + prefill_url: Optional[str] = None, + remote_session_id: int = None): """Return a generator to handle the input request. 
Args: @@ -277,12 +326,16 @@ async def stream_generate(self, request: Dict, node_url: str, endpoint: str): async for line in response.content: if line.strip(): yield line + b'\n\n' + if prefill_url: + async with session.post(f'{prefill_url}/distserve/free_cache', + json={'session_id': remote_session_id}) as response: + await response.json() except (Exception, GeneratorExit, aiohttp.ClientError) as e: # noqa logger.error(f'catched an exception: {e}') # exception happened, reduce unfinished num yield self.handle_api_timeout(node_url) - async def generate(self, request: Dict, node_url: str, endpoint: str): + async def generate(self, request: Dict, node_url: str, endpoint: str, is_prefill: bool = False): """Return a the response of the input request. Args: @@ -391,6 +444,20 @@ def remove_node(node_url: str): return 'Failed to delete, please check the input url.' +@app.post('/distserve/connection_warmup') +async def connection_warmup(): + await asyncio.gather(*[ + node_manager.pd_connection_pool.connect( + PDConnectionMessage( + p_url=p_url, + d_url=d_url, + protocol=node_manager.migration_protocol, + rdma_config=node_manager.rdma_config, + )) for p_url in node_manager.prefill_nodes for d_url in node_manager.decode_nodes + ]) + return JSONResponse({'SUCCESS': True}) + + @app.post('/v1/chat/completions', dependencies=[Depends(check_api_key)]) async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Request = None): """Completion API similar to OpenAI's API. @@ -450,21 +517,89 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque check_response = await node_manager.check_request_model(request.model) if check_response is not None: return check_response - node_url = node_manager.get_node_url(request.model) - if not node_url: - return node_manager.handle_unavailable_model(request.model) - - logger.info(f'A request is dispatched to {node_url}') - request_dict = request.model_dump() - start = node_manager.pre_call(node_url) - if request.stream is True: - response = node_manager.stream_generate(request_dict, node_url, '/v1/chat/completions') - background_task = node_manager.create_background_tasks(node_url, start) - return StreamingResponse(response, background=background_task) + + if node_manager.serving_strategy == ServingStrategy.Hybrid: + node_url = node_manager.get_node_url(request.model) + if not node_url: + return node_manager.handle_unavailable_model(request.model) + + logger.info(f'A request is dispatched to {node_url}') + request_dict = request.model_dump() + start = node_manager.pre_call(node_url) + if request.stream is True: + response = node_manager.stream_generate(request_dict, node_url, '/v1/chat/completions') + background_task = node_manager.create_background_tasks(node_url, start) + return StreamingResponse(response, background=background_task) + else: + response = await node_manager.generate(request_dict, node_url, '/v1/chat/completions') + node_manager.post_call(node_url, start) + return JSONResponse(json.loads(response)) + elif node_manager.serving_strategy == ServingStrategy.DistServe: + request_dict = request.model_dump() + + # Prefill + prefill_request_dict = copy.deepcopy(request_dict) + prefill_request_dict['max_tokens'] = 1 + prefill_request_dict['stream'] = False + prefill_request_dict['with_cache'] = True + prefill_request_dict['preserve_cache'] = True + + p_url = node_manager.get_node_url(request.model, EngineRole.Prefill) + if not p_url: + return node_manager.handle_unavailable_model(request.model) + logger.info(f'A Prefill 
request is dispatched to {p_url}') + + start = node_manager.pre_call(p_url) + prefill_info = json.loads(await node_manager.generate(prefill_request_dict, + p_url, + '/v1/chat/completions', + is_prefill=True)) + node_manager.post_call(p_url, start) + + # Decode + d_url = node_manager.get_node_url(request.model, EngineRole.Decode) + if not d_url: + return node_manager.handle_unavailable_model(request.model) + logger.info(f'A Decode request is dispatched to {d_url}') + + if not node_manager.pd_connection_pool.is_connected(p_url, d_url): + await node_manager.pd_connection_pool.connect( + PDConnectionMessage( + p_url=p_url, + d_url=d_url, + protocol=node_manager.migration_protocol, + rdma_config=node_manager.rdma_config, + )) + request_dict['migration_request'] = MigrationRequest( + protocol=node_manager.migration_protocol, + remote_engine_id=p_url, + remote_session_id=int(prefill_info['id']), + remote_block_ids=prefill_info['cache_block_ids'], + remote_token_id=prefill_info['remote_token_ids'][-1], + ).model_dump(mode='json') + + start = node_manager.pre_call(d_url) + if request.stream is True: + response = node_manager.stream_generate(request_dict, + d_url, + '/v1/chat/completions', + prefill_url=p_url, + remote_session_id=int(prefill_info['id'])) + background_task = node_manager.create_background_tasks(d_url, start) + return StreamingResponse(response, background=background_task) + else: + try: + response = await node_manager.generate(request_dict, d_url, '/v1/chat/completions') + node_manager.post_call(d_url, start) + resp = JSONResponse(json.loads(response)) + finally: + async with aiohttp.ClientSession() as session: + async with session.post(f'{p_url}/distserve/free_cache', json={'session_id': + prefill_info['id']}) as response: + await response.json() + return resp else: - response = await node_manager.generate(request_dict, node_url, '/v1/chat/completions') - node_manager.post_call(node_url, start) - return JSONResponse(json.loads(response)) + raise ValueError(f'No serving strategy named {node_manager.serving_strategy}') @app.post('/v1/completions', dependencies=[Depends(check_api_key)]) @@ -507,37 +642,108 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None check_response = await node_manager.check_request_model(request.model) if check_response is not None: return check_response - node_url = node_manager.get_node_url(request.model) - if not node_url: - return node_manager.handle_unavailable_model(request.model) - - logger.info(f'A request is dispatched to {node_url}') - request_dict = request.model_dump() - start = node_manager.pre_call(node_url) - if request.stream is True: - response = node_manager.stream_generate(request_dict, node_url, '/v1/completions') - background_task = node_manager.create_background_tasks(node_url, start) - return StreamingResponse(response, background=background_task) + if node_manager.serving_strategy == ServingStrategy.Hybrid: + node_url = node_manager.get_node_url(request.model) + if not node_url: + return node_manager.handle_unavailable_model(request.model) + + logger.info(f'A request is dispatched to {node_url}') + request_dict = request.model_dump() + start = node_manager.pre_call(node_url) + if request.stream is True: + response = node_manager.stream_generate(request_dict, node_url, '/v1/completions') + background_task = node_manager.create_background_tasks(node_url, start) + return StreamingResponse(response, background=background_task) + else: + response = await node_manager.generate(request_dict, node_url,
'/v1/completions') + node_manager.post_call(node_url, start) + return JSONResponse(json.loads(response)) + elif node_manager.serving_strategy == ServingStrategy.DistServe: + request_dict = request.model_dump() + + # Prefill + prefill_request_dict = copy.deepcopy(request_dict) + prefill_request_dict['max_tokens'] = 1 + prefill_request_dict['stream'] = False + prefill_request_dict['with_cache'] = True + prefill_request_dict['preserve_cache'] = True + + p_url = node_manager.get_node_url(request.model, EngineRole.Prefill) + if not p_url: + return node_manager.handle_unavailable_model(request.model) + logger.info(f'A Prefill request is dispatched to {p_url}') + + start = node_manager.pre_call(p_url) + prefill_info = json.loads(await node_manager.generate(prefill_request_dict, + p_url, + '/v1/completions', + is_prefill=True)) + node_manager.post_call(p_url, start) + + # Decode + d_url = node_manager.get_node_url(request.model, EngineRole.Decode) + if not d_url: + return node_manager.handle_unavailable_model(request.model) + logger.info(f'A Decode request is dispatched to {d_url}') + + if not node_manager.pd_connection_pool.is_connected(p_url, d_url): + await node_manager.pd_connection_pool.connect( + PDConnectionMessage( + p_url=p_url, + d_url=d_url, + protocol=node_manager.migration_protocol, + rdma_config=node_manager.rdma_config, + )) + + request_dict['migration_request'] = MigrationRequest( + protocol=node_manager.migration_protocol, + remote_engine_id=p_url, + remote_session_id=int(prefill_info['id']), + remote_block_ids=prefill_info['cache_block_ids'], + remote_token_id=prefill_info['remote_token_ids'][-1], + ).model_dump(mode='json') + + start = node_manager.pre_call(d_url) + if request.stream is True: + response = node_manager.stream_generate(request_dict, + d_url, + '/v1/completions', + prefill_url=p_url, + remote_session_id=int(prefill_info['id'])) + background_task = node_manager.create_background_tasks(d_url, start) + return StreamingResponse(response, background=background_task) + else: + response = await node_manager.generate(request_dict, d_url, '/v1/completions') + node_manager.post_call(d_url, start) + resp = JSONResponse(json.loads(response)) + async with aiohttp.ClientSession() as session: + async with session.post(f'{p_url}/distserve/free_cache', json={'session_id': + prefill_info['id']}) as response: + await response.json() + return resp else: - response = await node_manager.generate(request_dict, node_url, '/v1/completions') - node_manager.post_call(node_url, start) - return JSONResponse(json.loads(response)) + raise ValueError(f'No serving strategy named {node_manager.serving_strategy}') def proxy(server_name: str = '0.0.0.0', server_port: int = 8000, - strategy: Literal['random', 'min_expected_latency', 'min_observed_latency'] = 'min_expected_latency', + serving_strategy: Literal['Hybrid', 'DistServe'] = 'Hybrid', + routing_strategy: Literal['random', 'min_expected_latency', 'min_observed_latency'] = 'min_expected_latency', api_keys: Optional[Union[List[str], str]] = None, ssl: bool = False, log_level: str = 'INFO', disable_cache_status: bool = False, + link_type: Literal['RoCE', 'IB'] = 'RoCE', + migration_protocol: Literal['RDMA'] = 'RDMA', **kwargs): """To launch the proxy server. Args: server_name (str): the server name of the proxy. Default to '0.0.0.0'. server_port (str): the server port. Default to 8000. - strategy ('random' | 'min_expected_latency' | 'min_observed_latency'): + serving_strategy ('Hybrid' | 'DistServe'): the serving strategy. Defaults to 'Hybrid'.
+            'DistServe' enables Prefill-Decode Disaggregation. + routing_strategy ('random' | 'min_expected_latency' | 'min_observed_latency'): the strategy to dispatch requests to nodes. Default to 'min_expected_latency' api_keys (List[str] | str | None): Optional list of API keys. Accepts string type as @@ -546,8 +752,16 @@ def proxy(server_name: str = '0.0.0.0', log_level (str): Set the log level. Default to INFO. disable_cache_status (str): Whether to cache the proxy status to proxy_config.yml. + migration_protocol ('RDMA'): the KV cache migration protocol used for PD disaggregation. Defaults to 'RDMA'. """ # noqa - node_manager.strategy = Strategy.from_str(strategy) + node_manager.serving_strategy = ServingStrategy[serving_strategy] + node_manager.routing_strategy = RoutingStrategy.from_str(routing_strategy) + node_manager.migration_protocol = MigrationProtocol[migration_protocol] + + node_manager.rdma_config = DistServeRDMAConfig( + link_type=RDMALinkType[link_type], + with_gdr=True, + ) node_manager.cache_status = not disable_cache_status if api_keys is not None: if isinstance(api_keys, str):