diff --git a/docs/vllm_integration.md b/docs/vllm_integration.md
index 7d3205bb8..b99515246 100644
--- a/docs/vllm_integration.md
+++ b/docs/vllm_integration.md
@@ -23,3 +23,69 @@ See the supported models [here](https://vllm.readthedocs.io/en/latest/models/sup
 '''
 python3 -m fastchat.serve.vllm_worker --model-path TheBloke/vicuna-7B-v1.5-AWQ --quantization awq
 '''
+
+## Add vllm_worker support for lora_modules
+
+### Usage
+
+1. Start the worker
+
+```bash
+export VLLM_WORKER_MULTIPROC_METHOD=spawn
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m fastchat.serve.vllm_worker \
+    --model-path /data/models/Qwen/Qwen2-72B-Instruct \
+    --tokenizer /data/models/Qwen/Qwen2-72B-Instruct \
+    --enable-lora \
+    --lora-modules m1=/data/modules/lora/adapter/m1 m2=/data/modules/lora/adapter/m2 m3=/data/modules/lora/adapter/m3 \
+    --model-names qwen2-72b-instruct,m1,m2,m3 \
+    --controller http://localhost:21001 \
+    --host 0.0.0.0 \
+    --num-gpus 8 \
+    --port 31034 \
+    --limit-worker-concurrency 100 \
+    --worker-address http://localhost:31034
+```
+
+2. Send requests
+
+- Example 1: request the LoRA adapter `m1`
+
+```bash
+curl --location --request POST 'http://fastchat_address:port/v1/chat/completions' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer sk-xxx' \
+--data-raw '{
+    "model": "m1",
+    "stream": false,
+    "temperature": 0.7,
+    "top_p": 0.1,
+    "max_tokens": 4096,
+    "messages": [
+        {
+            "role": "user",
+            "content": "Hi?"
+        }
+    ]
+}'
+```
+
+- Example 2: request the base model `qwen2-72b-instruct`
+
+```bash
+curl --location --request POST 'http://fastchat_address:port/v1/chat/completions' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer sk-xxx' \
+--data-raw '{
+    "model": "qwen2-72b-instruct",
+    "stream": false,
+    "temperature": 0.7,
+    "top_p": 0.1,
+    "max_tokens": 4096,
+    "messages": [
+        {
+            "role": "user",
+            "content": "Hi?"
+        }
+    ]
+}'
+```
diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py
index 0af680bb5..82b449451 100644
--- a/fastchat/serve/vllm_worker.py
+++ b/fastchat/serve/vllm_worker.py
@@ -13,7 +13,9 @@
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn
 from vllm import AsyncLLMEngine
-from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
+from vllm.entrypoints.openai.cli_args import LoRAParserAction
+from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 
@@ -24,7 +26,6 @@
 )
 from fastchat.utils import get_context_length, is_partial_stop
 
-
 app = FastAPI()
 
 
@@ -40,6 +41,7 @@ def __init__(
         no_register: bool,
         llm_engine: AsyncLLMEngine,
         conv_template: str,
+        lora_requests: List[LoRARequest],
     ):
         super().__init__(
             controller_addr,
@@ -55,6 +57,7 @@ def __init__(
             f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker..."
         )
         self.tokenizer = llm_engine.engine.tokenizer
+        self.lora_requests = lora_requests
         # This is to support vllm >= 0.2.7 where TokenizerGroup was introduced
         # and llm_engine.engine.tokenizer was no longer a raw tokenizer
         if hasattr(self.tokenizer, "tokenizer"):
@@ -64,9 +67,24 @@ def __init__(
         if not no_register:
             self.init_heart_beat()
 
+    def find_lora(self, model):
+        # Look up the LoRA adapter whose registered name matches the requested model.
+        lora_request = next(
+            (item for item in self.lora_requests if item.lora_name == model), None
+        )
+
+        if lora_request:
+            logger.info(f"Successfully selected LoRA adapter: {model}")
+            return lora_request
+        else:
+            logger.warning(
+                f"No LoRA adapter named {model} was found; running inference without a LoRA adapter."
+            )
+            return None
+
     async def generate_stream(self, params):
         self.call_ct += 1
 
+        model = params.pop("model")
         context = params.pop("prompt")
         request_id = params.pop("request_id")
         temperature = float(params.get("temperature", 1.0))
@@ -116,7 +134,12 @@ async def generate_stream(self, params):
             frequency_penalty=frequency_penalty,
             best_of=best_of,
         )
-        results_generator = engine.generate(context, sampling_params, request_id)
+        lora_request = None
+        if self.lora_requests:
+            lora_request = self.find_lora(model)
+        results_generator = engine.generate(
+            context, sampling_params, request_id, lora_request=lora_request
+        )
 
         async for request_output in results_generator:
             prompt = request_output.prompt
@@ -156,9 +179,11 @@ async def generate_stream(self, params):
                 "cumulative_logprob": [
                     output.cumulative_logprob for output in request_output.outputs
                 ],
-                "finish_reason": request_output.outputs[0].finish_reason
-                if len(request_output.outputs) == 1
-                else [output.finish_reason for output in request_output.outputs],
+                "finish_reason": (
+                    request_output.outputs[0].finish_reason
+                    if len(request_output.outputs) == 1
+                    else [output.finish_reason for output in request_output.outputs]
+                ),
             }
             # Emit twice here to ensure a 'finish_reason' with empty content in the OpenAI API response.
             # This aligns with the behavior of model_worker.
@@ -278,6 +303,15 @@ async def api_model_details(request: Request):
        "throughput. However, if the value is too high, it may cause out-of-"
        "memory (OOM) errors.",
     )
+    parser.add_argument(
+        "--lora-modules",
+        type=nullable_str,
+        default=None,
+        nargs="+",
+        action=LoRAParserAction,
+        help="LoRA module configurations in the format name=path. "
+        "Multiple modules can be specified.",
+    )
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
 
@@ -286,6 +320,17 @@ async def api_model_details(request: Request):
     if args.num_gpus > 1:
         args.tensor_parallel_size = args.num_gpus
 
+    lora_requests = None
+    if args.lora_modules is not None:
+        lora_requests = [
+            LoRARequest(
+                lora_name=lora.name,
+                lora_int_id=i,
+                lora_path=lora.path,
+            )
+            for i, lora in enumerate(args.lora_modules, start=1)
+        ]
+
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     worker = VLLMWorker(
@@ -298,5 +343,6 @@ async def api_model_details(request: Request):
         args.no_register,
         engine,
         args.conv_template,
+        lora_requests,
     )
     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
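
For reference, the curl requests in the documentation above can also be issued from Python. The sketch below is illustrative only and is not part of the patch; it assumes the `openai` Python package (v1.x) and reuses the placeholder address `http://fastchat_address:port` and dummy key `sk-xxx` from the curl examples. Selecting a LoRA adapter is just a matter of passing its registered name (for example `m1`) as the `model` field.

```python
# Illustrative client-side sketch (not part of the patch).
# Assumes the FastChat OpenAI-compatible API server fronts the vLLM worker
# started above, and that the `openai` v1.x package is installed.
from openai import OpenAI

client = OpenAI(
    base_url="http://fastchat_address:port/v1",  # placeholder address from the docs
    api_key="sk-xxx",  # placeholder key from the docs
)

# "model" selects either a LoRA adapter name ("m1", "m2", "m3") or the base
# model name ("qwen2-72b-instruct"); the worker's find_lora() falls back to
# the base weights when no adapter matches.
response = client.chat.completions.create(
    model="m1",
    messages=[{"role": "user", "content": "Hi?"}],
    temperature=0.7,
    top_p=0.1,
    max_tokens=4096,
    stream=False,
)
print(response.choices[0].message.content)
```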