From eca0f15999c5a4517ec6bf7bbec9a4ea77e5771d Mon Sep 17 00:00:00 2001
From: Kdump
Date: Tue, 24 Sep 2024 13:23:16 +0800
Subject: [PATCH 1/4] ## Add vllm_worker support for lora_modules

## usage

### start

```bash
export VLLM_WORKER_MULTIPROC_METHOD=spawn
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m fastchat.serve.vllm_worker \
    --model-path /data/models/Qwen/Qwen2-72B-Instruct \
    --tokenizer /data/dpo/lora/b15s1/saves/Qwen2-72B-Instruct/v7.9/v7.3 \
    --enable-lora \
    --lora-modules m1=/data/modules/lora/adapter/m1 m2=/data/modules/lora/adapter/m2 m3=/data/modules/lora/adapter/m3 \
    --model-names qwen2-72b-instruct,m1,m2,m3\
    --controller http://localhost:21001 \
    --host 0.0.0.0 \
    --num-gpus 8 \
    --port 31034 \
    --limit-worker-concurrency 100 \
    --worker-address http://localhost:31034
```

### post

- example1

```bash
curl --location --request POST 'http://llm-gw.sunlinecloud.cn/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-xxx' \
--data-raw '{
    "model": "m1",
    "stream": false,
    "temperature": 0.7,
    "top_p": 0.1,
    "max_tokens": 4096,
    "messages": [
        {
            "role": "user",
            "content": "Hi?"
        }
    ]
}'
```

- example2

```bash
curl --location --request POST 'http://llm-gw.sunlinecloud.cn/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-xxx' \
--data-raw '{
    "model": "qwen2-72b-instruct",
    "stream": false,
    "temperature": 0.7,
    "top_p": 0.1,
    "max_tokens": 4096,
    "messages": [
        {
            "role": "user",
            "content": "Hi?"
        }
    ]
}'
```
---
 fastchat/serve/vllm_worker.py | 42 +++++++++++++++++++++++++++++++----
 1 file changed, 38 insertions(+), 4 deletions(-)

diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py
index 0af680bb5..09dced415 100644
--- a/fastchat/serve/vllm_worker.py
+++ b/fastchat/serve/vllm_worker.py
@@ -13,7 +13,9 @@
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn
 from vllm import AsyncLLMEngine
-from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
+from vllm.entrypoints.openai.cli_args import LoRAParserAction
+from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid

@@ -24,10 +26,8 @@
 )
 from fastchat.utils import get_context_length, is_partial_stop

-
 app = FastAPI()

-
 class VLLMWorker(BaseModelWorker):
     def __init__(
         self,
@@ -40,6 +40,7 @@ def __init__(
         no_register: bool,
         llm_engine: AsyncLLMEngine,
         conv_template: str,
+        lora_requests: LoRARequest,
     ):
         super().__init__(
             controller_addr,
@@ -55,6 +56,7 @@ def __init__(
             f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker..."
         )
         self.tokenizer = llm_engine.engine.tokenizer
+        self.lora_requests = lora_requests
         # This is to support vllm >= 0.2.7 where TokenizerGroup was introduced
         # and llm_engine.engine.tokenizer was no longer a raw tokenizer
         if hasattr(self.tokenizer, "tokenizer"):
@@ -64,9 +66,20 @@ def __init__(
         if not no_register:
             self.init_heart_beat()

+    def find_lora(self, model):
+        lora_request = next((item for item in lora_requests if item.lora_name == model), None)
+
+        if lora_request:
+            logger.info(f"Successfully selected LoRA adapter: {model}")
+            return lora_request
+        else:
+            logger.warning(f"Corresponding LoRA not found: {model}, will perform inference without LoRA adapter.")
+            return None
+
     async def generate_stream(self, params):
         self.call_ct += 1

+        model = params.pop("model")
         context = params.pop("prompt")
         request_id = params.pop("request_id")
         temperature = float(params.get("temperature", 1.0))
@@ -116,7 +129,9 @@ async def generate_stream(self, params):
             frequency_penalty=frequency_penalty,
             best_of=best_of,
         )
-        results_generator = engine.generate(context, sampling_params, request_id)
+        if self.lora_requests and len(self.lora_requests) > 0:
+            lora_request = self.find_lora(model)
+        results_generator = engine.generate(context, sampling_params, request_id, lora_request = lora_request)

         async for request_output in results_generator:
             prompt = request_output.prompt
@@ -278,6 +293,14 @@ async def api_model_details(request: Request):
         "throughput. However, if the value is too high, it may cause out-of-"
         "memory (OOM) errors.",
     )
+    parser.add_argument(
+        "--lora-modules",
+        type=nullable_str,
+        default=None,
+        nargs='+',
+        action=LoRAParserAction,
+        help="LoRA module configurations in the format name=path. "
+        "Multiple modules can be specified.")
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()

@@ -286,6 +309,16 @@ async def api_model_details(request: Request):
     if args.num_gpus > 1:
         args.tensor_parallel_size = args.num_gpus

+    lora_requests = None
+    if args.lora_modules is not None:
+        lora_requests = [
+            LoRARequest(
+                lora_name=lora.name,
+                lora_int_id=i,
+                lora_path=lora.path,
+            ) for i, lora in enumerate(args.lora_modules, start=1)
+        ]
+
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     worker = VLLMWorker(
@@ -298,5 +331,6 @@ async def api_model_details(request: Request):
         args.no_register,
         engine,
         args.conv_template,
+        lora_requests
     )
     uvicorn.run(app, host=args.host, port=args.port, log_level="info")

From 2f90685b915ed21d6bd841b0c45c04d85ee03516 Mon Sep 17 00:00:00 2001
From: Kdump
Date: Tue, 24 Sep 2024 14:30:01 +0800
Subject: [PATCH 2/4] ## Add vllm_worker support for lora_modules

## usage

### start

```bash
export VLLM_WORKER_MULTIPROC_METHOD=spawn
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m fastchat.serve.vllm_worker \
    --model-path /data/models/Qwen/Qwen2-72B-Instruct \
    --tokenizer /data/dpo/lora/b15s1/saves/Qwen2-72B-Instruct/v7.9/v7.3 \
    --enable-lora \
    --lora-modules m1=/data/modules/lora/adapter/m1 m2=/data/modules/lora/adapter/m2 m3=/data/modules/lora/adapter/m3 \
    --model-names qwen2-72b-instruct,m1,m2,m3\
    --controller http://localhost:21001 \
    --host 0.0.0.0 \
    --num-gpus 8 \
    --port 31034 \
    --limit-worker-concurrency 100 \
    --worker-address http://localhost:31034
```

### post

- example1

```bash
curl --location --request POST 'http://llm-gw.sunlinecloud.cn/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-xxx' \
--data-raw '{
    "model": "m1",
    "stream": false,
    "temperature": 0.7,
    "top_p": 0.1,
    "max_tokens": 4096,
    "messages": [
        {
            "role": "user",
            "content": "Hi?"
        }
    ]
}'
```

- example2

```bash
curl --location --request POST 'http://llm-gw.sunlinecloud.cn/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-xxx' \
--data-raw '{
    "model": "qwen2-72b-instruct",
    "stream": false,
    "temperature": 0.7,
    "top_p": 0.1,
    "max_tokens": 4096,
    "messages": [
        {
            "role": "user",
            "content": "Hi?"
        }
    ]
}'
```
---
 fastchat/serve/vllm_worker.py | 31 +++++++++++++++++++++----------
 1 file changed, 21 insertions(+), 10 deletions(-)

diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py
index 09dced415..0df37c127 100644
--- a/fastchat/serve/vllm_worker.py
+++ b/fastchat/serve/vllm_worker.py
@@ -28,6 +28,7 @@

 app = FastAPI()

+
 class VLLMWorker(BaseModelWorker):
     def __init__(
         self,
@@ -67,13 +68,17 @@ def __init__(
         self.init_heart_beat()

     def find_lora(self, model):
-        lora_request = next((item for item in lora_requests if item.lora_name == model), None)
+        lora_request = next(
+            (item for item in lora_requests if item.lora_name == model), None
+        )

         if lora_request:
             logger.info(f"Successfully selected LoRA adapter: {model}")
             return lora_request
         else:
-            logger.warning(f"Corresponding LoRA not found: {model}, will perform inference without LoRA adapter.")
+            logger.warning(
+                f"Corresponding LoRA not found: {model}, will perform inference without LoRA adapter."
+            )
             return None

     async def generate_stream(self, params):
@@ -131,7 +136,9 @@ async def generate_stream(self, params):
         )
         if self.lora_requests and len(self.lora_requests) > 0:
             lora_request = self.find_lora(model)
-        results_generator = engine.generate(context, sampling_params, request_id, lora_request = lora_request)
+        results_generator = engine.generate(
+            context, sampling_params, request_id, lora_request=lora_request
+        )

         async for request_output in results_generator:
             prompt = request_output.prompt
@@ -171,9 +178,11 @@ async def generate_stream(self, params):
                 "cumulative_logprob": [
                     output.cumulative_logprob for output in request_output.outputs
                 ],
-                "finish_reason": request_output.outputs[0].finish_reason
-                if len(request_output.outputs) == 1
-                else [output.finish_reason for output in request_output.outputs],
+                "finish_reason": (
+                    request_output.outputs[0].finish_reason
+                    if len(request_output.outputs) == 1
+                    else [output.finish_reason for output in request_output.outputs]
+                ),
             }
             # Emit twice here to ensure a 'finish_reason' with empty content in the OpenAI API response.
             # This aligns with the behavior of model_worker.
@@ -297,10 +306,11 @@ async def api_model_details(request: Request):
         "--lora-modules",
         type=nullable_str,
         default=None,
-        nargs='+',
+        nargs="+",
         action=LoRAParserAction,
         help="LoRA module configurations in the format name=path. "
-        "Multiple modules can be specified.")
+        "Multiple modules can be specified.",
+    )
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()

@@ -316,7 +326,8 @@ async def api_model_details(request: Request):
                 lora_name=lora.name,
                 lora_int_id=i,
                 lora_path=lora.path,
-            ) for i, lora in enumerate(args.lora_modules, start=1)
+            )
+            for i, lora in enumerate(args.lora_modules, start=1)
         ]

     engine_args = AsyncEngineArgs.from_cli_args(args)
@@ -331,6 +342,6 @@ async def api_model_details(request: Request):
         args.no_register,
         engine,
         args.conv_template,
-        lora_requests
+        lora_requests,
     )
     uvicorn.run(app, host=args.host, port=args.port, log_level="info")

From d36dc7427630c3e96e41f33ad8d3895e19b34ea2 Mon Sep 17 00:00:00 2001
From: Kdump
Date: Fri, 27 Sep 2024 10:44:24 +0800
Subject: [PATCH 3/4] add doc

---
 docs/vllm_integration.md | 66 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)

diff --git a/docs/vllm_integration.md b/docs/vllm_integration.md
index 7d3205bb8..b99515246 100644
--- a/docs/vllm_integration.md
+++ b/docs/vllm_integration.md
@@ -23,3 +23,69 @@ See the supported models [here](https://vllm.readthedocs.io/en/latest/models/sup
 '''
 python3 -m fastchat.serve.vllm_worker --model-path TheBloke/vicuna-7B-v1.5-AWQ --quantization awq
 '''
+
+## Add vllm_worker support for lora_modules
+
+### usage
+
+1. start
+
+```bash
+export VLLM_WORKER_MULTIPROC_METHOD=spawn
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m fastchat.serve.vllm_worker \
+    --model-path /data/models/Qwen/Qwen2-72B-Instruct \
+    --tokenizer /data/models/Qwen/Qwen2-72B-Instruct \
+    --enable-lora \
+    --lora-modules m1=/data/modules/lora/adapter/m1 m2=/data/modules/lora/adapter/m2 m3=/data/modules/lora/adapter/m3 \
+    --model-names qwen2-72b-instruct,m1,m2,m3\
+    --controller http://localhost:21001 \
+    --host 0.0.0.0 \
+    --num-gpus 8 \
+    --port 31034 \
+    --limit-worker-concurrency 100 \
+    --worker-address http://localhost:31034
+```
+
+1. post
+
+- example1
+
+```bash
+curl --location --request POST 'http://fastchat_address:port/v1/chat/completions' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer sk-xxx' \
+--data-raw '{
+    "model": "m1",
+    "stream": false,
+    "temperature": 0.7,
+    "top_p": 0.1,
+    "max_tokens": 4096,
+    "messages": [
+        {
+            "role": "user",
+            "content": "Hi?"
+        }
+    ]
+}'
+```
+
+- example2
+
+```bash
+curl --location --request POST 'http://fastchat_address:port/v1/chat/completions' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer sk-xxx' \
+--data-raw '{
+    "model": "qwen2-72b-instruct",
+    "stream": false,
+    "temperature": 0.7,
+    "top_p": 0.1,
+    "max_tokens": 4096,
+    "messages": [
+        {
+            "role": "user",
+            "content": "Hi?"
+        }
+    ]
+}'
+```

From 4591d5ef666e4d30d27d86309c82b2c02507dfc2 Mon Sep 17 00:00:00 2001
From: Kdump
Date: Fri, 11 Oct 2024 11:35:25 +0800
Subject: [PATCH 4/4] fix lora_request variable is not declared

---
 fastchat/serve/vllm_worker.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py
index 0df37c127..82b449451 100644
--- a/fastchat/serve/vllm_worker.py
+++ b/fastchat/serve/vllm_worker.py
@@ -134,6 +134,7 @@ async def generate_stream(self, params):
             frequency_penalty=frequency_penalty,
             best_of=best_of,
         )
+        lora_request = None
         if self.lora_requests and len(self.lora_requests) > 0:
             lora_request = self.find_lora(model)
         results_generator = engine.generate(