diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py
index 148143ba7f..bf66f661bd 100644
--- a/comps/cores/mega/gateway.py
+++ b/comps/cores/mega/gateway.py
@@ -25,6 +25,41 @@ from .micro_service import MicroService
 
 
+def read_pdf(file):
+    from langchain.document_loaders import PyPDFLoader
+
+    loader = PyPDFLoader(file)
+    docs = loader.load_and_split()
+    return docs
+
+
+def read_text_from_file(file, save_file_name):
+    import docx2txt
+    from langchain.text_splitter import CharacterTextSplitter
+
+    # read text file
+    if file.headers["content-type"] == "text/plain":
+        file.file.seek(0)
+        content = file.file.read().decode("utf-8")
+        # Split text
+        text_splitter = CharacterTextSplitter()
+        texts = text_splitter.split_text(content)
+        # Create multiple documents
+        file_content = texts
+    # read pdf file
+    elif file.headers["content-type"] == "application/pdf":
+        documents = read_pdf(save_file_name)
+        file_content = [doc.page_content for doc in documents]
+    # read docx file
+    elif (
+        file.headers["content-type"] == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+        or file.headers["content-type"] == "application/octet-stream"
+    ):
+        file_content = docx2txt.process(save_file_name)
+
+    return file_content
+
+
 class Gateway:
     def __init__(
         self,
@@ -365,39 +400,6 @@ def __init__(self, megaservice, host="0.0.0.0", port=8888):
             megaservice, host, port, str(MegaServiceEndpoint.DOC_SUMMARY), ChatCompletionRequest, ChatCompletionResponse
         )
 
-    def read_pdf(self, file):
-        from langchain.document_loaders import PyPDFLoader
-
-        loader = PyPDFLoader(file)
-        docs = loader.load_and_split()
-        return docs
-
-    def read_text_from_file(self, file, save_file_name):
-        import docx2txt
-        from langchain.text_splitter import CharacterTextSplitter
-
-        # read text file
-        if file.headers["content-type"] == "text/plain":
-            file.file.seek(0)
-            content = file.file.read().decode("utf-8")
-            # Split text
-            text_splitter = CharacterTextSplitter()
-            texts = text_splitter.split_text(content)
-            # Create multiple documents
-            file_content = texts
-        # read pdf file
-        elif file.headers["content-type"] == "application/pdf":
-            documents = self.read_pdf(save_file_name)
-            file_content = [doc.page_content for doc in documents]
-        # read docx file
-        elif (
-            file.headers["content-type"] == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
-            or file.headers["content-type"] == "application/octet-stream"
-        ):
-            file_content = docx2txt.process(save_file_name)
-
-        return file_content
-
     async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)):
         data = await request.form()
         stream_opt = data.get("stream", True)
@@ -411,7 +413,7 @@ async def handle_request(self, request: Request, files: List[UploadFile] = File(
 
                 async with aiofiles.open(file_path, "wb") as f:
                     await f.write(await file.read())
-                docs = self.read_text_from_file(file, file_path)
+                docs = read_text_from_file(file, file_path)
                 os.remove(file_path)
                 if isinstance(docs, list):
                     file_summaries.extend(docs)
@@ -547,11 +549,31 @@ def __init__(self, megaservice, host="0.0.0.0", port=8888):
             megaservice, host, port, str(MegaServiceEndpoint.FAQ_GEN), ChatCompletionRequest, ChatCompletionResponse
         )
 
-    async def handle_request(self, request: Request):
-        data = await request.json()
+    async def handle_request(self, request: Request, files: List[UploadFile] = File(default=None)):
+        data = await request.form()
         stream_opt = data.get("stream", True)
         chat_request = ChatCompletionRequest.parse_obj(data)
-        prompt = self._handle_message(chat_request.messages)
+        file_summaries = []
+        if files:
+            for file in files:
+                file_path = f"/tmp/{file.filename}"
+
+                import aiofiles
+
+                async with aiofiles.open(file_path, "wb") as f:
+                    await f.write(await file.read())
+                docs = read_text_from_file(file, file_path)
+                os.remove(file_path)
+                if isinstance(docs, list):
+                    file_summaries.extend(docs)
+                else:
+                    file_summaries.append(docs)
+
+        if file_summaries:
+            prompt = self._handle_message(chat_request.messages) + "\n".join(file_summaries)
+        else:
+            prompt = self._handle_message(chat_request.messages)
+
         parameters = LLMParams(
             max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
             top_k=chat_request.top_k if chat_request.top_k else 10,
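
Usage note (not part of the patch): a minimal sketch of how a client could exercise the new FaqGen file-upload path. The gateway's default port 8888 comes from this diff; the host localhost and the /v1/faqgen path that MegaServiceEndpoint.FAQ_GEN is assumed to resolve to are assumptions. The multipart field name "files" matches the FastAPI parameter files: List[UploadFile] = File(default=None) added to handle_request, and the text/plain content type hits the plain-text branch of read_text_from_file.

# Hypothetical client call; URL path and host are assumptions, port 8888 is the
# gateway default shown in this diff.
import requests

url = "http://localhost:8888/v1/faqgen"
with open("sample.txt", "rb") as f:
    resp = requests.post(
        url,
        # Form fields are parsed server-side via request.form() into ChatCompletionRequest
        data={"messages": "Generate FAQs from the attached document.", "max_tokens": 128},
        # "files" matches the UploadFile parameter name; text/plain selects the plain-text branch
        files=[("files", ("sample.txt", f, "text/plain"))],
    )
print(resp.text)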