diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py
index 9eda239038..92d47aa583 100644
--- a/comps/cores/mega/gateway.py
+++ b/comps/cores/mega/gateway.py
@@ -362,6 +362,7 @@ async def handle_request(self, request: Request):
             presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
             repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
             streaming=stream_opt,
+            language=chat_request.language if chat_request.language else "auto",
         )
         result_dict, runtime_graph = await self.megaservice.schedule(
             initial_inputs={"query": prompt}, llm_parameters=parameters
diff --git a/comps/cores/proto/api_protocol.py b/comps/cores/proto/api_protocol.py
index d2fb0adb19..cf8b2ca1df 100644
--- a/comps/cores/proto/api_protocol.py
+++ b/comps/cores/proto/api_protocol.py
@@ -175,6 +175,7 @@ class ChatCompletionRequest(BaseModel):
     tool_choice: Optional[Union[Literal["none"], ChatCompletionNamedToolChoiceParam]] = "none"
     parallel_tool_calls: Optional[bool] = True
     user: Optional[str] = None
+    language: str = "auto"  # can be "en", "zh"

     # Ordered by official OpenAI API documentation
     # default values are same with
diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py
index 3b6ce50215..ee06017a5f 100644
--- a/comps/cores/proto/docarray.py
+++ b/comps/cores/proto/docarray.py
@@ -175,6 +175,7 @@ class LLMParamsDoc(BaseDoc):
     presence_penalty: float = 0.0
     repetition_penalty: float = 1.03
     streaming: bool = True
+    language: str = "auto"  # can be "en", "zh"

     chat_template: Optional[str] = Field(
         default=None,
@@ -212,6 +213,7 @@ class LLMParams(BaseDoc):
     presence_penalty: float = 0.0
     repetition_penalty: float = 1.03
     streaming: bool = True
+    language: str = "auto"  # can be "en", "zh"

     chat_template: Optional[str] = Field(
         default=None,
diff --git a/comps/llms/summarization/tgi/langchain/README.md b/comps/llms/summarization/tgi/langchain/README.md
index d4a4fdacf1..85a9259bc5 100644
--- a/comps/llms/summarization/tgi/langchain/README.md
+++ b/comps/llms/summarization/tgi/langchain/README.md
@@ -92,12 +92,18 @@ curl http://${your_ip}:9000/v1/health_check\
 # Enable streaming to receive a streaming response. By default, this is set to True.
 curl http://${your_ip}:9000/v1/chat/docsum \
   -X POST \
-  -d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5."}' \
+  -d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5.", "max_tokens":32, "language":"en"}' \
   -H 'Content-Type: application/json'

 # Disable streaming to receive a non-streaming response.
 curl http://${your_ip}:9000/v1/chat/docsum \
   -X POST \
-  -d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5.", "streaming":false}' \
+  -d '{"query":"Text Embeddings Inference (TEI) is a toolkit for deploying and serving open source text embeddings and sequence classification models. TEI enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE and E5.", "max_tokens":32, "language":"en", "streaming":false}' \
+  -H 'Content-Type: application/json'
+
+# Use Chinese mode. By default, language is set to "auto", which uses the English prompt.
+curl http://${your_ip}:9000/v1/chat/docsum \
+  -X POST \
+  -d '{"query":"2024年9月26日,北京——今日,英特尔正式发布英特尔® 至强® 6性能核处理器(代号Granite Rapids),为AI、数据分析、科学计算等计算密集型业务提供卓越性能。", "max_tokens":32, "language":"zh", "streaming":false}' \
   -H 'Content-Type: application/json'
 ```
diff --git a/comps/llms/summarization/tgi/langchain/llm.py b/comps/llms/summarization/tgi/langchain/llm.py
index e9f85cb829..702b194326 100644
--- a/comps/llms/summarization/tgi/langchain/llm.py
+++ b/comps/llms/summarization/tgi/langchain/llm.py
@@ -4,26 +4,35 @@
 import os

 from fastapi.responses import StreamingResponse
-from langchain.chains.summarize import load_summarize_chain
-from langchain.docstore.document import Document
-from langchain.text_splitter import CharacterTextSplitter
-from langchain_huggingface import HuggingFaceEndpoint
+from huggingface_hub import AsyncInferenceClient
+from langchain.prompts import PromptTemplate

 from comps import CustomLogger, GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice

 logger = CustomLogger("llm_docsum")
 logflag = os.getenv("LOGFLAG", False)

+llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
+llm = AsyncInferenceClient(
+    model=llm_endpoint,
+    timeout=600,
+)
+
+templ_en = """Write a concise summary of the following:
+
+
+"{text}"
+
+
+CONCISE SUMMARY:"""

-def post_process_text(text: str):
-    if text == " ":
-        return "data: @#$\n\n"
-    if text == "\n":
-        return "data: <br/>\n\n"
-    if text.isspace():
-        return None
-    new_text = text.replace(" ", "@#$")
-    return f"data: {new_text}\n\n"
+templ_zh = """请简要概括以下内容:
+
+
+"{text}"
+
+
+概况:"""


 @register_microservice(
@@ -37,46 +46,51 @@ async def llm_generate(input: LLMParamsDoc):
     if logflag:
         logger.info(input)

-    llm = HuggingFaceEndpoint(
-        endpoint_url=llm_endpoint,
+    if input.language in ["en", "auto"]:
+        templ = templ_en
+    elif input.language in ["zh"]:
+        templ = templ_zh
+    else:
+        raise NotImplementedError('Please specify the input language as one of "en", "zh" or "auto"')
+
+    prompt_template = PromptTemplate.from_template(templ)
+    prompt = prompt_template.format(text=input.query)
+
+    if logflag:
+        logger.info("After prompting:")
+        logger.info(prompt)
+
+    text_generation = await llm.text_generation(
+        prompt=prompt,
+        stream=input.streaming,
         max_new_tokens=input.max_tokens,
+        repetition_penalty=input.repetition_penalty,
+        temperature=input.temperature,
         top_k=input.top_k,
         top_p=input.top_p,
         typical_p=input.typical_p,
-        temperature=input.temperature,
-        repetition_penalty=input.repetition_penalty,
-        streaming=input.streaming,
     )
-    llm_chain = load_summarize_chain(llm=llm, chain_type="map_reduce")
-    texts = text_splitter.split_text(input.query)
-
-    # Create multiple documents
-    docs = [Document(page_content=t) for t in texts]

     if input.streaming:

         async def stream_generator():
-            from langserve.serialization import WellKnownLCSerializer
-
-            _serializer = WellKnownLCSerializer()
-            async for chunk in llm_chain.astream_log(docs):
-                data = _serializer.dumps({"ops": chunk.ops}).decode("utf-8")
+            chat_response = ""
+            async for text in text_generation:
+                chat_response += text
+                chunk_repr = repr(text.encode("utf-8"))
                 if logflag:
-                    logger.info(f"[docsum - text_summarize] data: {data}")
-                yield f"data: {data}\n\n"
+                    logger.info(f"[ docsum - text_summarize ] chunk:{chunk_repr}")
+                yield f"data: {chunk_repr}\n\n"
+            if logflag:
+                logger.info(f"[ docsum - text_summarize ] stream response: {chat_response}")
             yield "data: [DONE]\n\n"

         return StreamingResponse(stream_generator(), media_type="text/event-stream")
     else:
-        response = await llm_chain.ainvoke(docs)
-        response = response["output_text"]
         if logflag:
-            logger.info(response)
-        return GeneratedDoc(text=response, prompt=input.query)
+            logger.info(text_generation)
+        return GeneratedDoc(text=text_generation, prompt=input.query)


 if __name__ == "__main__":
-    llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
-    # Split text
-    text_splitter = CharacterTextSplitter()
     opea_microservices["opea_service@llm_docsum"].start()