From 89cf760293a4fb09a3c7e1a0b22b6736402ef638 Mon Sep 17 00:00:00 2001
From: Letong Han <106566639+letonghan@users.noreply.github.com>
Date: Fri, 27 Oct 2023 10:07:44 +0800
Subject: [PATCH] [NeuralChat] Add askdoc retrieval api & example (#514)
---
.../chatbot-inference-llama-2-7b-chat-hf.yml | 1 +
.../chatbot-inference-mpt-7b-chat.yml | 1 +
.../neural_chat/examples/askdoc/README.md | 79 ++++++++++
.../neural_chat/examples/askdoc/askdoc.py | 29 ++++
.../neural_chat/examples/askdoc/askdoc.yaml | 41 +++++
.../examples/askdoc/docs/test_doc.txt | 13 ++
.../neural_chat/examples/askdoc/run.sh | 38 +++++
.../neural_chat/models/base_model.py | 31 ++--
.../plugins/prompt/prompt_template.py | 14 ++
.../retrieval/detector/intent_detection.py | 2 +-
.../retrieval/indexing/context_utils.py | 37 ++++-
.../plugins/retrieval/indexing/indexing.py | 55 +++++++
.../plugins/retrieval/retrieval_agent.py | 41 ++++-
.../plugins/retrieval/retrieval_base.py | 4 +-
.../plugins/retrieval/retrieval_bm25.py | 4 +-
.../plugins/retrieval/retrieval_chroma.py | 4 +-
.../neural_chat/prompts/prompt.py | 19 ++-
.../neural_chat/requirements.txt | 3 +-
.../neural_chat/requirements_cpu.txt | 1 +
.../neural_chat/server/restful/request.py | 24 +++
.../server/restful/retrieval_api.py | 140 +++++++++++++++++-
.../server/restful/textchat_api.py | 2 +-
.../server/restful/voicechat_api.py | 2 +-
.../neural_chat/tests/api/test_inference.py | 3 +-
.../neural_chat/tests/api/test_rag.py | 2 +-
.../neural_chat/tests/requirements.txt | 3 +
.../neural_chat/tests/server/askdoc.yaml | 38 +++++
.../tests/server/askdoc/test_doc.txt | 13 ++
.../tests/server/test_askdoc_server.py | 63 ++++++++
.../neural_chat/utils/database/__init__.py | 16 ++
.../neural_chat/utils/database/config.py | 52 +++++++
.../utils/database/init_db_ai_photos.sql | 62 ++++++++
.../utils/database/init_db_askdoc.sql | 20 +++
.../neural_chat/utils/database/mysqldb.py | 76 ++++++++++
34 files changed, 887 insertions(+), 46 deletions(-)
create mode 100644 intel_extension_for_transformers/neural_chat/examples/askdoc/README.md
create mode 100644 intel_extension_for_transformers/neural_chat/examples/askdoc/askdoc.py
create mode 100644 intel_extension_for_transformers/neural_chat/examples/askdoc/askdoc.yaml
create mode 100644 intel_extension_for_transformers/neural_chat/examples/askdoc/docs/test_doc.txt
create mode 100644 intel_extension_for_transformers/neural_chat/examples/askdoc/run.sh
create mode 100644 intel_extension_for_transformers/neural_chat/tests/server/askdoc.yaml
create mode 100644 intel_extension_for_transformers/neural_chat/tests/server/askdoc/test_doc.txt
create mode 100644 intel_extension_for_transformers/neural_chat/tests/server/test_askdoc_server.py
create mode 100644 intel_extension_for_transformers/neural_chat/utils/database/__init__.py
create mode 100644 intel_extension_for_transformers/neural_chat/utils/database/config.py
create mode 100644 intel_extension_for_transformers/neural_chat/utils/database/init_db_ai_photos.sql
create mode 100644 intel_extension_for_transformers/neural_chat/utils/database/init_db_askdoc.sql
create mode 100644 intel_extension_for_transformers/neural_chat/utils/database/mysqldb.py
diff --git a/.github/workflows/chatbot-inference-llama-2-7b-chat-hf.yml b/.github/workflows/chatbot-inference-llama-2-7b-chat-hf.yml
index e577bffc50d..c134bda5b2e 100644
--- a/.github/workflows/chatbot-inference-llama-2-7b-chat-hf.yml
+++ b/.github/workflows/chatbot-inference-llama-2-7b-chat-hf.yml
@@ -39,6 +39,7 @@ jobs:
pip uninstall intel-extension-for-transformers -y; \
pip install -r requirements.txt; \
python setup.py install; \
+ pip install -r intel_extension_for_transformers/neural_chat/requirements.txt; \
python workflows/chatbot/inference/generate.py --base_model_path \"meta-llama/Llama-2-7b-chat-hf\" --hf_access_token \"${{ env.HF_ACCESS_TOKEN }}\" --instructions \"Transform the following sentence into one that shows contrast. The tree is rotten.\" "
- name: Stop Container
diff --git a/.github/workflows/chatbot-inference-mpt-7b-chat.yml b/.github/workflows/chatbot-inference-mpt-7b-chat.yml
index 27237137960..e38c83f84e5 100644
--- a/.github/workflows/chatbot-inference-mpt-7b-chat.yml
+++ b/.github/workflows/chatbot-inference-mpt-7b-chat.yml
@@ -39,6 +39,7 @@ jobs:
pip uninstall intel-extension-for-transformers -y; \
pip install -r requirements.txt; \
python setup.py install; \
+ pip install -r intel_extension_for_transformers/neural_chat/requirements.txt; \
python workflows/chatbot/inference/generate.py --base_model_path \"mosaicml/mpt-7b-chat\" --instructions \"Transform the following sentence into one that shows contrast. The tree is rotten.\" "
- name: Stop Container
diff --git a/intel_extension_for_transformers/neural_chat/examples/askdoc/README.md b/intel_extension_for_transformers/neural_chat/examples/askdoc/README.md
new file mode 100644
index 00000000000..4bcdc0f7a7c
--- /dev/null
+++ b/intel_extension_for_transformers/neural_chat/examples/askdoc/README.md
@@ -0,0 +1,79 @@
+This README is intended to guide you through setting up the server for the AskDoc demo using the NeuralChat framework. You can deploy it on various platforms, including Intel Xeon Scalable Processors, Habana Gaudi processors (HPU), Intel Data Center and Client GPUs, and Nvidia Data Center and Client GPUs.
+
+# Introduction
+The popularity of applications like ChatGPT has attracted many users seeking to address everyday problems. However, some users have encountered a challenge known as "model hallucination," where LLMs generate incorrect or nonexistent information, raising concerns about content accuracy. This example introduces our solution for building a retrieval-based chatbot backend server. With just a few lines of code, our API helps you build a local reference database that enhances the accuracy of the generated results, as sketched below.
+
+Before deploying this example, please follow the instructions in the [README](../../README.md) to install the necessary dependencies.
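+
+For reference, the snippet below is a rough sketch of that flow using the NeuralChat Python API (`build_chatbot`, `PipelineConfig`, and the `plugins` registry used elsewhere in this patch); treat the exact import paths and field names as illustrative and adjust them to your installed version:
+
+```python
+from intel_extension_for_transformers.neural_chat import build_chatbot, PipelineConfig
+from intel_extension_for_transformers.neural_chat.plugins import plugins
+
+# Enable the retrieval plugin and point it at a folder of local reference documents.
+plugins["retrieval"]["enable"] = True
+plugins["retrieval"]["args"] = {"input_path": "./docs", "persist_dir": "./example_persist"}
+
+# Build a chatbot whose answers are grounded in the indexed documents.
+chatbot = build_chatbot(PipelineConfig(model_name_or_path="./Llama-2-7b-chat-hf", plugins=plugins))
+print(chatbot.predict("What is the Intel oneAPI DPC++/C++ Compiler?"))
+```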
+
+# Setup Environment
+
+## Setup Conda
+
+First, you need to install and configure the Conda environment:
+
+```shell
+# Download and install Miniconda
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
+bash Miniconda3-latest-Linux-x86_64.sh
+source ~/.bashrc
+```
+
+## Install numactl
+
+Next, install the numactl library:
+
+```shell
+sudo apt install numactl
+```
+
+## Install Python dependencies
+
+Install the following Python dependencies using Conda:
+
+```shell
+conda install astunparse ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses -y
+conda install jemalloc gperftools -c conda-forge -y
+conda install git-lfs -y
+```
+
+Install other dependencies using pip:
+
+```bash
+pip install -r ../../../requirements.txt
+```
+
+
+## Download Models
+```shell
+git-lfs install
+git clone https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
+```
+
+
+# Configure YAML
+
+You can customize the configuration file `askdoc.yaml` to match your environment setup. Here's a table to help you understand the configurable options:
+
+| Item | Value |
+| --------------------------------- | ---------------------------------------|
+| host | 127.0.0.1 |
+| port | 8000 |
+| model_name_or_path | "./Llama-2-7b-chat-hf" |
+| device | "auto" |
+| retrieval.enable | true |
+| retrieval.args.input_path | "./docs" |
+| retrieval.args.persist_dir | "./example_persist" |
+| retrieval.args.response_template  | "We cannot find suitable content to answer your query, please contact AskGM to find help. Mail: ask.gm.zizhu@intel.com." |
+| retrieval.args.append | True |
+| tasks_list | ['textchat', 'retrieval'] |
+
+
+# Run the AskDoc server
+The Neural Chat API offers an easy way to create and utilize chatbot models while integrating local documents. Our API simplifies the process of automatically handling and storing local documents in a document store. In this example, we use `./docs/test_doc.txt` as the sample document. You can construct your own retrieval documents for the Intel® oneAPI DPC++/C++ Compiler by following [this link](https://www.intel.com/content/www/us/en/docs/dpcpp-cpp-compiler/developer-guide-reference/2023-2/overview.html).
+
+
+To start the AskDoc server, run the following command:
+
+```shell
+nohup bash run.sh &
+```
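+
+Once the server is up, you can exercise the `/v1/askdoc/upload` and `/v1/askdoc/chat` endpoints added in this example. Below is a minimal client sketch using the `requests` library; the host, port, and JSON fields (`query`, `domain`) follow `askdoc.yaml` and the `AskDocRequest` schema in this patch, and the answer is assumed to be streamed back as `data: ...` lines:
+
+```python
+import requests
+
+base_url = "http://127.0.0.1:8000"
+
+# Upload a local document so it gets appended to the knowledge base.
+with open("./docs/test_doc.txt", "rb") as f:
+    print(requests.post(f"{base_url}/v1/askdoc/upload", files={"file": f}).json())
+
+# Ask a question; the server streams the answer back as server-sent events.
+payload = {"query": "What is the Intel oneAPI DPC++/C++ Compiler?", "domain": "oneapi"}
+with requests.post(f"{base_url}/v1/askdoc/chat", json=payload, stream=True) as resp:
+    for line in resp.iter_lines():
+        if line:
+            print(line.decode("utf-8"))
+```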
diff --git a/intel_extension_for_transformers/neural_chat/examples/askdoc/askdoc.py b/intel_extension_for_transformers/neural_chat/examples/askdoc/askdoc.py
new file mode 100644
index 00000000000..55f2a10bfa2
--- /dev/null
+++ b/intel_extension_for_transformers/neural_chat/examples/askdoc/askdoc.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from intel_extension_for_transformers.neural_chat import NeuralChatServerExecutor
+
+def main():
+    server_executor = NeuralChatServerExecutor()
+    server_executor(
+        config_file="./askdoc.yaml",
+        log_file="./askdoc.log")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/intel_extension_for_transformers/neural_chat/examples/askdoc/askdoc.yaml b/intel_extension_for_transformers/neural_chat/examples/askdoc/askdoc.yaml
new file mode 100644
index 00000000000..26ffe3e7e72
--- /dev/null
+++ b/intel_extension_for_transformers/neural_chat/examples/askdoc/askdoc.yaml
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This is the parameter configuration file for NeuralChat Serving.
+
+#################################################################################
+# SERVER SETTING #
+#################################################################################
+host: 127.0.0.1
+port: 8000
+
+model_name_or_path: "./Llama-2-7b-chat-hf"
+device: "auto"
+
+retrieval:
+ enable: true
+ args:
+ input_path: "./docs"
+ persist_dir: "./example_persist"
+ response_template: "We cannot find suitable content to answer your query, please contact AskGM to find help. Mail: ask.gm.zizhu@intel.com."
+ append: True
+
+safety_checker:
+ enable: true
+
+tasks_list: ['textchat', 'retrieval']
+
diff --git a/intel_extension_for_transformers/neural_chat/examples/askdoc/docs/test_doc.txt b/intel_extension_for_transformers/neural_chat/examples/askdoc/docs/test_doc.txt
new file mode 100644
index 00000000000..7c725f1c9c3
--- /dev/null
+++ b/intel_extension_for_transformers/neural_chat/examples/askdoc/docs/test_doc.txt
@@ -0,0 +1,13 @@
+This guide provides information about the Intel® oneAPI DPC++/C++ Compiler and runtime environment. This document is valid for version 2024.0 of the compilers.
+
+The Intel® oneAPI DPC++/C++ Compiler is available as part of the Intel® oneAPI Base Toolkit, Intel® oneAPI HPC Toolkit, Intel® oneAPI IoT Toolkit, or as a standalone compiler.
+
+Refer to the Intel® oneAPI DPC++/C++ Compiler product page and the Release Notes for more information about features, specifications, and downloads.
+
+
+The compiler supports these key features:
+Intel® oneAPI Level Zero: The Intel® oneAPI Level Zero (Level Zero) Application Programming Interface (API) provides direct-to-metal interfaces to offload accelerator devices.
+OpenMP* Support: Compiler support for OpenMP 5.0 Version TR4 features and some OpenMP Version 5.1 features.
+Pragmas: Information about directives to provide the compiler with instructions for specific tasks, including splitting large loops into smaller ones, enabling or disabling optimization for code, or offloading computation to the target.
+Offload Support: Information about SYCL*, OpenMP, and parallel processing options you can use to affect optimization, code generation, and more.
+Latest Standards: Use the latest standards including C++ 20, SYCL, and OpenMP 5.0 and 5.1 for GPU offload.
\ No newline at end of file
diff --git a/intel_extension_for_transformers/neural_chat/examples/askdoc/run.sh b/intel_extension_for_transformers/neural_chat/examples/askdoc/run.sh
new file mode 100644
index 00000000000..d2905ba2875
--- /dev/null
+++ b/intel_extension_for_transformers/neural_chat/examples/askdoc/run.sh
@@ -0,0 +1,38 @@
+#!/usr/bin/env bash
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Kill any existing server process and re-run
+ps -ef | grep 'askdoc' | awk '{print $2}' | xargs kill -9
+
+# KMP
+export KMP_BLOCKTIME=1
+export KMP_SETTINGS=1
+export KMP_AFFINITY=granularity=fine,compact,1,0
+
+# OMP
+export OMP_NUM_THREADS=56
+export LD_PRELOAD=${CONDA_PREFIX}/lib/libiomp5.so
+
+# tc malloc
+export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so
+
+# database
+export MYSQL_PASSWORD="root"
+export MYSQL_HOST="127.0.0.1"
+export MYSQL_DB="fastrag"
+
+numactl -l -C 0-55 python askdoc.py 2>&1 | tee run.log
diff --git a/intel_extension_for_transformers/neural_chat/models/base_model.py b/intel_extension_for_transformers/neural_chat/models/base_model.py
index d4cd6f650db..92e59eb6d40 100644
--- a/intel_extension_for_transformers/neural_chat/models/base_model.py
+++ b/intel_extension_for_transformers/neural_chat/models/base_model.py
@@ -148,6 +148,7 @@ def predict_stream(self, query, config=None):
query_include_prompt = True
# plugin pre actions
+ link = []
for plugin_name in get_registered_plugins():
if is_plugin_enabled(plugin_name):
plugin_instance = get_plugin_instance(plugin_name)
@@ -156,11 +157,13 @@ def predict_stream(self, query, config=None):
if plugin_name == "asr" and not is_audio_file(query):
continue
if plugin_name == "retrieval":
- response = plugin_instance.pre_llm_inference_actions(self.model_name, query)
+ response, link = plugin_instance.pre_llm_inference_actions(self.model_name, query)
+ if response == "Response with template.":
+ return plugin_instance.response_template, link
else:
response = plugin_instance.pre_llm_inference_actions(query)
if plugin_name == "safety_checker" and response:
- return "Your query contains sensitive words, please try another query."
+ return "Your query contains sensitive words, please try another query.", link
else:
if response != None and response != False:
query = response
@@ -183,16 +186,7 @@ def is_generator(obj):
continue
response = plugin_instance.post_llm_inference_actions(response)
- # clear plugins config
- for key in plugins:
- plugins[key] = {
- "enable": False,
- "class": None,
- "args": {},
- "instance": None
- }
-
- return response
+ return response, link
def predict(self, query, config=None):
"""
@@ -230,7 +224,9 @@ def predict(self, query, config=None):
if plugin_name == "asr" and not is_audio_file(query):
continue
if plugin_name == "retrieval":
- response = plugin_instance.pre_llm_inference_actions(self.model_name, query)
+ response, link = plugin_instance.pre_llm_inference_actions(self.model_name, query)
+ if response == "Response with template.":
+ return plugin_instance.response_template
else:
response = plugin_instance.pre_llm_inference_actions(query)
if plugin_name == "safety_checker" and response:
@@ -253,15 +249,6 @@ def predict(self, query, config=None):
if hasattr(plugin_instance, 'post_llm_inference_actions'):
response = plugin_instance.post_llm_inference_actions(response)
- # clear plugins config
- for key in plugins:
- plugins[key] = {
- "enable": False,
- "class": None,
- "args": {},
- "instance": None
- }
-
return response
def chat_stream(self, query, config=None):
diff --git a/intel_extension_for_transformers/neural_chat/pipeline/plugins/prompt/prompt_template.py b/intel_extension_for_transformers/neural_chat/pipeline/plugins/prompt/prompt_template.py
index e657869be71..63f1d2baea9 100644
--- a/intel_extension_for_transformers/neural_chat/pipeline/plugins/prompt/prompt_template.py
+++ b/intel_extension_for_transformers/neural_chat/pipeline/plugins/prompt/prompt_template.py
@@ -37,6 +37,20 @@ def generate_qa_prompt(query, context=None, history=None):
conv.append_message(conv.roles[1], None)
return conv.get_prompt()
+def generate_qa_enterprise(query, context=None, history=None):
+ if context and history:
+ conv = PromptTemplate("rag_with_threshold")
+ conv.append_message(conv.roles[0], query)
+ conv.append_message(conv.roles[1], context)
+ conv.append_message(conv.roles[2], history)
+ conv.append_message(conv.roles[3], None)
+ else:
+ conv = PromptTemplate("rag_with_threshold")
+ conv.append_message(conv.roles[0], query)
+ conv.append_message(conv.roles[1], context)
+ conv.append_message(conv.roles[3], None)
+ return conv.get_prompt()
+
def generate_prompt(query, history=None):
if history:
diff --git a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/detector/intent_detection.py b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/detector/intent_detection.py
index 078a557f329..cfff5382755 100644
--- a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/detector/intent_detection.py
+++ b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/detector/intent_detection.py
@@ -32,7 +32,7 @@ def intent_detection(self, model_name, query):
params["prompt"] = prompt
params["temperature"] = 0.001
params["top_k"] = 1
- params["max_new_tokens"] = 5
+ params["max_new_tokens"] = 10
intent = predict(**params)
return intent
diff --git a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/indexing/context_utils.py b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/indexing/context_utils.py
index baf06af284b..92c8068c795 100644
--- a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/indexing/context_utils.py
+++ b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/indexing/context_utils.py
@@ -121,7 +121,6 @@ def load_xlsx(input):
df = pd.read_excel(input)
all_data = []
documents = []
-
for index, row in df.iterrows():
sub = "User Query: " + row['Questions'] + "Answer: " + row["Answers"]
all_data.append(sub)
@@ -134,6 +133,38 @@ def load_xlsx(input):
return documents
+def load_faq_xlsx(input):
+ """Load and process faq xlsx file."""
+ df = pd.read_excel(input)
+ all_data = []
+
+ for index, row in df.iterrows():
+ sub = "Question: " + row['question'] + " Answer: " + row["answer"]
+ sub = sub.replace('#', " ")
+ sub = sub.replace(r'\t', " ")
+ sub = sub.replace('\n', ' ')
+ sub = sub.replace('\n\n', ' ')
+ sub = re.sub(r'\s+', ' ', sub)
+ all_data.append([sub, row['link']])
+ return all_data
+
+
+def load_general_xlsx(input):
+ """Load and process doc xlsx file."""
+ df = pd.read_excel(input)
+ all_data = []
+
+ for index, row in df.iterrows():
+ sub = row['context']
+ sub = sub.replace('#', " ")
+ sub = sub.replace(r'\t', " ")
+ sub = sub.replace('\n', ' ')
+ sub = sub.replace('\n\n', ' ')
+ sub = re.sub(r'\s+', ' ', sub)
+ all_data.append([sub, row['link']])
+ return all_data
+
+
def load_unstructured_data(input):
"""Load unstructured context."""
if input.endswith("pdf"):
@@ -158,6 +189,10 @@ def laod_structured_data(input, process, max_length):
"""Load structured context."""
if input.endswith("jsonl"):
content = load_json(input, process, max_length)
+ elif "faq" in input and input.endswith("xlsx"):
+ content = load_faq_xlsx(input)
+ elif "enterprise_docs" in input and input.endswith("xlsx"):
+ content = load_general_xlsx(input)
else:
content = load_xlsx(input)
return content
diff --git a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/indexing/indexing.py b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/indexing/indexing.py
index dd32aa0636f..c27bb401cdc 100644
--- a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/indexing/indexing.py
+++ b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/indexing/indexing.py
@@ -150,3 +150,58 @@ def KB_construct(self, input):
else:
print("There might be some errors, please wait and try again!")
+
+    def KB_append(self, input):  # for an in-memory document store, please use KB_construct
+        if self.retrieval_type == "dense":
+            if os.path.exists(input):
+                if os.path.isfile(input):
+                    data_collection = self.parse_document(input)
+                elif os.path.isdir(input):
+                    data_collection = self.batch_parse_document(input)
+                else:
+                    print("Please check your upload file and try again!")
+
+                documents = []
+                for data, meta in data_collection:
+                    if len(data) < 5:
+                        continue
+                    metadata = {"source": meta}
+                    new_doc = Document(page_content=data, metadata=metadata)
+                    documents.append(new_doc)
+                assert documents != [], "The given file/files cannot be loaded."
+                embedding = HuggingFaceInstructEmbeddings(model_name=self.embedding_model)
+                vectordb = Chroma.from_documents(documents=documents, embedding=embedding,
+                                                 persist_directory=self.persist_dir)
+                vectordb.persist()
+                print("The local knowledge base has been successfully built!")
+                return Chroma(persist_directory=self.persist_dir, embedding_function=embedding)
+            else:
+                print("There might be some errors, please wait and try again!")
+        else:
+            if os.path.exists(input):
+                if os.path.isfile(input):
+                    data_collection = self.parse_document(input)
+                elif os.path.isdir(input):
+                    data_collection = self.batch_parse_document(input)
+                else:
+                    print("Please check your upload file and try again!")
+
+                if self.document_store == "Elasticsearch":
+                    document_store = ElasticsearchDocumentStore(host="localhost", index=self.index_name,
+                                                                port=9200, search_fields=["content", "title"])
+                    documents = []
+                    for data, meta in data_collection:
+                        metadata = {"source": meta}
+                        if len(data) < 5:
+                            continue
+                        new_doc = SDocument(content=data, meta=metadata)
+                        documents.append(new_doc)
+                    assert documents != [], "The given file/files cannot be loaded."
+                    document_store.write_documents(documents)
+                    print("The local knowledge base has been successfully built!")
+                    return ElasticsearchDocumentStore(host="localhost", index=self.index_name,
+                                                      port=9200, search_fields=["content", "title"])
+                else:
+                    print("Unsupported document store type, please change to Elasticsearch!")
+            else:
+                print("There might be some errors, please wait and try again!")
diff --git a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_agent.py b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_agent.py
index 3b8c16898c0..cb4995fab02 100644
--- a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_agent.py
+++ b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_agent.py
@@ -20,20 +20,24 @@
from .detector.intent_detection import IntentDetector
from .indexing.indexing import DocumentIndexing
from intel_extension_for_transformers.neural_chat.pipeline.plugins.prompt.prompt_template \
- import generate_qa_prompt, generate_prompt
+ import generate_qa_prompt, generate_prompt, generate_qa_enterprise
class Agent_QA():
def __init__(self, persist_dir="./output", process=True, input_path=None,
embedding_model="hkunlp/instructor-large", max_length=2048, retrieval_type="dense",
document_store=None, top_k=1, search_type="mmr", search_kwargs={"k": 1, "fetch_k": 5},
- append=True, index_name="elastic_index_1",
- asset_path="/intel-extension-for-transformers/intel_extension_for_transformers/neural_chat/assets"):
+ append=True, index_name="elastic_index_1", append_path=None,
+ response_template = "Please reformat your query to regenerate the answer.",
+ asset_path="/intel-extension-for-transformers/intel_extension_for_transformers/neural_chat/assets",):
self.model = None
self.tokenizer = None
self.retrieval_type = retrieval_type
self.retriever = None
+ self.append_path = append_path
self.intent_detector = IntentDetector()
script_dir = os.path.dirname(os.path.abspath(__file__))
+ self.response_template = response_template
+ self.search_type = search_type
if os.path.exists(input_path):
self.input_path = input_path
@@ -69,23 +73,44 @@ def __init__(self, persist_dir="./output", process=True, input_path=None,
embedding_model=embedding_model, max_length=max_length,
index_name = index_name)
self.db = self.doc_parser.KB_construct(self.input_path)
+ else:
+ self.doc_parser = DocumentIndexing(retrieval_type=self.retrieval_type,
+ document_store=document_store,
+ persist_dir=persist_dir, process=process,
+ embedding_model=embedding_model, max_length=max_length,
+ index_name = index_name)
+ self.db = self.doc_parser.KB_construct(self.input_path)
self.retriever = Retriever(retrieval_type=self.retrieval_type, document_store=self.db, top_k=top_k,
search_type=search_type, search_kwargs=search_kwargs)
+ def append_localdb(self,
+ append_path,
+ top_k=1,
+ search_type="similarity_score_threshold",
+ search_kwargs={"score_threshold": 0.9, "k": 1}):
+ self.db = self.doc_parser.KB_append(append_path)
+ self.retriever = Retriever(retrieval_type=self.retrieval_type, document_store=self.db, top_k=top_k,
+ search_type=search_type, search_kwargs=search_kwargs)
+
def pre_llm_inference_actions(self, model_name, query):
intent = self.intent_detector.intent_detection(model_name, query)
-
+ links = []
+ docs = []
if 'qa' not in intent.lower():
print("Chat with AI Agent.")
prompt = generate_prompt(query)
else:
print("Chat with QA agent.")
if self.retriever:
- context = self.retriever.get_context(query)
- prompt = generate_qa_prompt(query, context)
+ context, links = self.retriever.get_context(query)
+ if len(context) == 0:
+ return "Response with template.", links
+ if self.search_type == "similarity_score_threshold":
+ prompt = generate_qa_enterprise(query, context)
+ else:
+ prompt = generate_qa_prompt(query, context)
else:
prompt = generate_prompt(query)
- return prompt
-
+ return prompt, links
\ No newline at end of file
diff --git a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_base.py b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_base.py
index 76d4b77ce39..a8dc83fb2ec 100644
--- a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_base.py
+++ b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_base.py
@@ -36,5 +36,5 @@ def __init__(self, retrieval_type="dense", document_store=None,
self.retriever = SparseBM25Retriever(document_store=document_store, top_k=top_k)
def get_context(self, query):
- context = self.retriever.query_the_database(query)
- return context
+ context, links = self.retriever.query_the_database(query)
+ return context, links
diff --git a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_bm25.py b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_bm25.py
index 67e5336f2e3..620feddc49d 100644
--- a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_bm25.py
+++ b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_bm25.py
@@ -27,6 +27,8 @@ def __init__(self, document_store = None, top_k = 1):
def query_the_database(self, query):
documents = self.retriever.retrieve(query)
context = ""
+ links = []
for doc in documents:
context = context + doc.content + " "
- return context.strip()
+ links.append(doc.meta)
+ return context.strip(), links
diff --git a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_chroma.py b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_chroma.py
index 948d11ed362..41d0886864b 100644
--- a/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_chroma.py
+++ b/intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_chroma.py
@@ -29,6 +29,8 @@ def __init__(self, database=None, search_type="mmr", search_kwargs={"k": 1, "fet
def query_the_database(self, query):
documents = self.retriever.get_relevant_documents(query)
context = ""
+ links = []
for doc in documents:
context = context + doc.page_content + " "
- return context.strip()
+ links.append(doc.metadata)
+ return context.strip(), links
diff --git a/intel_extension_for_transformers/neural_chat/prompts/prompt.py b/intel_extension_for_transformers/neural_chat/prompts/prompt.py
index 9fd859e525a..20140295ffa 100644
--- a/intel_extension_for_transformers/neural_chat/prompts/prompt.py
+++ b/intel_extension_for_transformers/neural_chat/prompts/prompt.py
@@ -115,14 +115,29 @@
)
)
+
+# Rag with threshold
+register_conv_template(
+ Conversation(
+ name="rag_with_threshold",
+ system_message="You are served as an AI agent to help the user complete a task." + \
+ " You are required to comprehend the usr query and then use the given context to" + \
+ " generate a suitable response.\n\n",
+ roles=("### User Query: ", "### Context: ", "### Chat History: ", "### Response: "),
+ sep_style=SeparatorStyle.NO_COLON_SINGLE,
+ sep="\n",
+ )
+)
+
+
# Intent template
register_conv_template(
Conversation(
name="intent",
- system_message="Please identify the intent of the provided context." + \
+ system_message="Please identify the intent of the user query." + \
" You may only respond with \"chitchat\" or \"QA\" without explanations" + \
" or engaging in conversation.\n",
- roles=("Context:", "Intent:"),
+ roles=("### User Query: ", "### Response: "),
sep_style=SeparatorStyle.NO_COLON_SINGLE,
sep="\n",
)
diff --git a/intel_extension_for_transformers/neural_chat/requirements.txt b/intel_extension_for_transformers/neural_chat/requirements.txt
index 86ebb418392..bd4583747cd 100644
--- a/intel_extension_for_transformers/neural_chat/requirements.txt
+++ b/intel_extension_for_transformers/neural_chat/requirements.txt
@@ -40,4 +40,5 @@ lm_eval
accelerate
cchardet
spacy
-neural-compressor
\ No newline at end of file
+neural-compressor
+pymysql
\ No newline at end of file
diff --git a/intel_extension_for_transformers/neural_chat/requirements_cpu.txt b/intel_extension_for_transformers/neural_chat/requirements_cpu.txt
index dcb68f13a3b..ed4c9e825d8 100644
--- a/intel_extension_for_transformers/neural_chat/requirements_cpu.txt
+++ b/intel_extension_for_transformers/neural_chat/requirements_cpu.txt
@@ -41,3 +41,4 @@ torch==2.1.0
torchaudio==2.1.0
spacy
neural-compressor
+pymysql
\ No newline at end of file
diff --git a/intel_extension_for_transformers/neural_chat/server/restful/request.py b/intel_extension_for_transformers/neural_chat/server/restful/request.py
index 2eadbdde490..f249223c6d0 100644
--- a/intel_extension_for_transformers/neural_chat/server/restful/request.py
+++ b/intel_extension_for_transformers/neural_chat/server/restful/request.py
@@ -47,3 +47,27 @@ class FinetuneRequest(RequestBaseModel):
overwrite_output_dir: bool = True
dataset_concatenation: bool = False
peft: str = 'lora'
+
+
+class AskDocRequest(RequestBaseModel):
+ query: str
+ domain: str
+ blob: Optional[str]
+ filename: Optional[str]
+ knowledge_base_id: Optional[str] = 'default'
+ embedding: Optional[str] = 'dense'
+ params: Optional[dict] = None
+ debug: Optional[bool] = False
+ stream: bool = True
+
+
+class FeedbackRequest(RequestBaseModel):
+ """
+ Request class for feedback api
+ 'feedback_id': set to be auto_increment, no need to pass as argument
+ 'feedback': 0 for 'like', 1 for 'dislike'
+ """
+ # feedback_id: Optional[int] = None
+ question: str
+ answer: str
+ feedback: Optional[int] = 0
\ No newline at end of file
diff --git a/intel_extension_for_transformers/neural_chat/server/restful/retrieval_api.py b/intel_extension_for_transformers/neural_chat/server/restful/retrieval_api.py
index de989a6301a..4f863eb787a 100644
--- a/intel_extension_for_transformers/neural_chat/server/restful/retrieval_api.py
+++ b/intel_extension_for_transformers/neural_chat/server/restful/retrieval_api.py
@@ -15,12 +15,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import traceback
+import io
+import re
+import csv
+import datetime
from typing import Optional, Dict
-from fastapi import APIRouter
+from fastapi import APIRouter, UploadFile, File
+from ...config import GenerationConfig
from ...cli.log import logger
-from ...server.restful.request import RetrievalRequest
+from ...server.restful.request import RetrievalRequest, AskDocRequest, FeedbackRequest
from ...server.restful.response import RetrievalResponse
+from fastapi.responses import StreamingResponse
+from ...utils.database.mysqldb import MysqlDb
+from ...plugins import plugins
def check_retrieval_params(request: RetrievalRequest) -> Optional[str]:
@@ -60,3 +67,130 @@ async def retrieval_endpoint(request: RetrievalRequest) -> RetrievalResponse:
if ret is not None:
raise RuntimeError("Invalid parametery.")
return await router.handle_retrieval_request(request)
+
+
+@router.post("/v1/askdoc/upload")
+async def retrieval_upload(file: UploadFile = File(...)):
+ global plugins
+ filename = file.filename
+ path_prefix = "/home/sdp/askdoc_upload/enterprise_docs/"
+ print(f"[askdoc - upload] filename: {filename}")
+ if '/' in filename:
+ filename = filename.split('/')[-1]
+ with open(f"{path_prefix+filename}", 'wb') as fout:
+ content = await file.read()
+ fout.write(content)
+ print("[askdoc - upload] file saved to local path.")
+
+ try:
+ print("[askdoc - upload] starting to append local db...")
+ instance = plugins['retrieval']["instance"]
+ instance.append_localdb(append_path=path_prefix)
+ print(f"[askdoc - upload] kb appended successfully")
+ except Exception as e:
+ logger.info(f"[askdoc - upload] create knowledge base failes! {e}")
+ return "Error occurred while uploading files."
+ fake_kb_id = "fake_knowledge_base_id"
+ return {"knowledge_base_id": fake_kb_id}
+
+
+@router.post("/v1/askdoc/chat")
+async def retrieval_chat(request: AskDocRequest):
+ chatbot = router.get_chatbot()
+
+ logger.info(f"[askdoc - chat] Predicting chat completion using kb '{request.knowledge_base_id}'")
+ logger.info(f"[askdoc - chat] Predicting chat completion using prompt '{request.query}'")
+ config = GenerationConfig()
+ # Set attributes of the config object from the request
+ for attr, value in request.__dict__.items():
+ if attr == "stream":
+ continue
+ setattr(config, attr, value)
+ generator, link = chatbot.predict_stream(query=request.query, config=config)
+ logger.info(f"[askdoc - chat] chatbot predicted: {generator}")
+ if isinstance(generator, str):
+ def stream_generator():
+ yield f"data: {generator}\n\n"
+ yield f"data: [DONE]\n\n"
+ else:
+ def stream_generator():
+ for output in generator:
+ ret = {
+ "text": output,
+ "error_code": 0,
+ }
+ logger.info(f"[askdoc - chat] {ret}")
+ res = re.match("(http|https|ftp)://[^\s]+", output)
+ if res != None:
+ formatted_link = f''
+ yield f"data: {formatted_link}\n\n"
+ else:
+ formatted_str = ret['text'].replace('\n', '<br/>')
+ yield f"data: {formatted_str}\n\n"
+ if link != []:
+ yield f"data: