From 19330ea23f8a2d5e591563bd48e6809902be7672 Mon Sep 17 00:00:00 2001 From: rbrugaro Date: Tue, 29 Oct 2024 22:45:44 -0700 Subject: [PATCH] GraphRAG with llama-index (#793) * graphRAG dataprep llama-index validated w openai endpoints Signed-off-by: rbrugaro * llama-index graphRAG retrieval validated with openai models Signed-off-by: rbrugaro * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * validated code usng TGI/TEI instead of openai Signed-off-by: Rita Brugarolas * compose.yaml for dataprep validated with neo4j, TGI/TEI, openai Signed-off-by: Rita Brugarolas * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * graphRAG retriever validated and full compose.yaml Signed-off-by: Rita Brugarolas * minor fix Signed-off-by: Rita Brugarolas * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add graphRAGGateway WIP Signed-off-by: Rita Brugarolas * graphragGateway working in E2E Example Signed-off-by: Rita Brugarolas * fix schedule in orchestrator to support ChatCompletionRequest input Signed-off-by: Rita Brugarolas * change default to TGI instead of openAI and add test code for neo4jretriever Signed-off-by: Rita Brugarolas * test code for dataprep-neo4j microservice Signed-off-by: Rita Brugarolas * improved READMES Signed-off-by: Rita Brugarolas * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update docker build path for tests Signed-off-by: Rita Brugarolas * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix Signed-off-by: Rita Brugarolas * fix typo in container name Signed-off-by: Rita Brugarolas * resolve image name conflict for hub publishing Signed-off-by: Rita Brugarolas * add tgi validation to miicroservice tests Signed-off-by: Rita Brugarolas * rename test script to trigger cicd correctly w hpu Signed-off-by: Rita Brugarolas * rename test containers for cicd Signed-off-by: Rita Brugarolas * update HF_TOKEN in TGI/TEI test scripts Signed-off-by: Rita Brugarolas * swaped larger model so the graph isnt empty in ci test Signed-off-by: Rita Brugarolas * set 4 hpu for 70B model in ci test Signed-off-by: Rita Brugarolas * add extra time for large model loading cicd Signed-off-by: Rita Brugarolas * fix tgi gaudi shard args Signed-off-by: Rita Brugarolas * switch to chat cause chat template is needed Signed-off-by: Rita Brugarolas * enable logs in test Signed-off-by: Rita Brugarolas * use locally downloaded model in CI machine Signed-off-by: Rita Brugarolas * use local model path and reduce wait time Signed-off-by: Rita Brugarolas * clear ports before ci run Signed-off-by: Rita Brugarolas * fix cache model access Signed-off-by: Rita Brugarolas * fix cache model access Signed-off-by: Rita Brugarolas * incrased wait time for tgi shards ready Signed-off-by: Rita Brugarolas * wait until tgi connected Signed-off-by: Rita Brugarolas * switch back to small model for testing Signed-off-by: Rita Brugarolas * minor readability fixes Signed-off-by: Rita Brugarolas * README fixes Signed-off-by: Rita Brugarolas --------- Signed-off-by: rbrugaro Signed-off-by: Rita Brugarolas Co-authored-by: pre-commit-ci[bot] 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../docker/compose/dataprep-compose-cd.yaml | 4 + .../docker/compose/retrievers-compose-cd.yaml | 4 + comps/__init__.py | 1 + comps/cores/mega/constants.py | 1 + comps/cores/mega/gateway.py | 75 +++ comps/dataprep/neo4j/__init__.py | 2 + comps/dataprep/neo4j/llama_index/Dockerfile | 39 ++ comps/dataprep/neo4j/llama_index/README.md | 94 +++ comps/dataprep/neo4j/llama_index/__init__.py | 2 + comps/dataprep/neo4j/llama_index/compose.yaml | 97 +++ comps/dataprep/neo4j/llama_index/config.py | 18 + .../neo4j/llama_index/extract_graph_neo4j.py | 593 ++++++++++++++++++ .../neo4j/llama_index/requirements.txt | 38 ++ comps/dataprep/neo4j/llama_index/set_env.sh | 18 + comps/retrievers/neo4j/llama_index/Dockerfile | 35 ++ comps/retrievers/neo4j/llama_index/README.md | 65 ++ .../retrievers/neo4j/llama_index/__init__.py | 2 + .../retrievers/neo4j/llama_index/compose.yaml | 124 ++++ comps/retrievers/neo4j/llama_index/config.py | 18 + .../neo4j/llama_index/requirements.txt | 36 ++ .../retriever_community_answers_neo4j.py | 231 +++++++ comps/retrievers/neo4j/llama_index/set_env.sh | 19 + ...dataprep_neo4j_llama_index_on_intel_hpu.sh | 162 +++++ ...trievers_neo4j_llama_index_on_intel_hpu.sh | 192 ++++++ 24 files changed, 1870 insertions(+) create mode 100644 comps/dataprep/neo4j/__init__.py create mode 100644 comps/dataprep/neo4j/llama_index/Dockerfile create mode 100644 comps/dataprep/neo4j/llama_index/README.md create mode 100644 comps/dataprep/neo4j/llama_index/__init__.py create mode 100644 comps/dataprep/neo4j/llama_index/compose.yaml create mode 100644 comps/dataprep/neo4j/llama_index/config.py create mode 100644 comps/dataprep/neo4j/llama_index/extract_graph_neo4j.py create mode 100644 comps/dataprep/neo4j/llama_index/requirements.txt create mode 100644 comps/dataprep/neo4j/llama_index/set_env.sh create mode 100644 comps/retrievers/neo4j/llama_index/Dockerfile create mode 100644 comps/retrievers/neo4j/llama_index/README.md create mode 100644 comps/retrievers/neo4j/llama_index/__init__.py create mode 100644 comps/retrievers/neo4j/llama_index/compose.yaml create mode 100644 comps/retrievers/neo4j/llama_index/config.py create mode 100644 comps/retrievers/neo4j/llama_index/requirements.txt create mode 100644 comps/retrievers/neo4j/llama_index/retriever_community_answers_neo4j.py create mode 100644 comps/retrievers/neo4j/llama_index/set_env.sh create mode 100755 tests/dataprep/test_dataprep_neo4j_llama_index_on_intel_hpu.sh create mode 100755 tests/retrievers/test_retrievers_neo4j_llama_index_on_intel_hpu.sh diff --git a/.github/workflows/docker/compose/dataprep-compose-cd.yaml b/.github/workflows/docker/compose/dataprep-compose-cd.yaml index b7589a12c5..61f9a92da1 100644 --- a/.github/workflows/docker/compose/dataprep-compose-cd.yaml +++ b/.github/workflows/docker/compose/dataprep-compose-cd.yaml @@ -27,3 +27,7 @@ services: build: dockerfile: comps/dataprep/neo4j/langchain/Dockerfile image: ${REGISTRY:-opea}/dataprep-neo4j:${TAG:-latest} + dataprep-neo4j-llamaindex: + build: + dockerfile: comps/dataprep/neo4j/llama_index/Dockerfile + image: ${REGISTRY:-opea}/dataprep-neo4j-llamaindex:${TAG:-latest} diff --git a/.github/workflows/docker/compose/retrievers-compose-cd.yaml b/.github/workflows/docker/compose/retrievers-compose-cd.yaml index 8f4de8bbf4..5b1718cf66 100644 --- a/.github/workflows/docker/compose/retrievers-compose-cd.yaml +++ b/.github/workflows/docker/compose/retrievers-compose-cd.yaml @@ -27,3 +27,7 @@ services: build: dockerfile: 
comps/retrievers/neo4j/langchain/Dockerfile image: ${REGISTRY:-opea}/retriever-neo4j:${TAG:-latest} + retriever-neo4j-llamaindex: + build: + dockerfile: comps/retrievers/neo4j/llama_index/Dockerfile + image: ${REGISTRY:-opea}/retriever-neo4j-llamaindex:${TAG:-latest} diff --git a/comps/__init__.py b/comps/__init__.py index 34cfe8d18c..153acad497 100644 --- a/comps/__init__.py +++ b/comps/__init__.py @@ -59,6 +59,7 @@ VideoQnAGateway, VisualQnAGateway, MultimodalQnAGateway, + GraphragGateway, AvatarChatbotGateway, ) diff --git a/comps/cores/mega/constants.py b/comps/cores/mega/constants.py index a0523daba3..e657ba6f45 100644 --- a/comps/cores/mega/constants.py +++ b/comps/cores/mega/constants.py @@ -51,6 +51,7 @@ class MegaServiceEndpoint(Enum): TRANSLATION = "/v1/translation" RETRIEVALTOOL = "/v1/retrievaltool" FAQ_GEN = "/v1/faqgen" + GRAPH_RAG = "/v1/graphrag" # Follow OPENAI EMBEDDINGS = "/v1/embeddings" TTS = "/v1/audio/speech" diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py index 021e85d111..3c57928434 100644 --- a/comps/cores/mega/gateway.py +++ b/comps/cores/mega/gateway.py @@ -156,9 +156,12 @@ def __init__(self, megaservice, host="0.0.0.0", port=8888): async def handle_request(self, request: Request): data = await request.json() + print("data in handle request", data) stream_opt = data.get("stream", True) chat_request = ChatCompletionRequest.parse_obj(data) + print("chat request in handle request", chat_request) prompt = self._handle_message(chat_request.messages) + print("prompt in gateway", prompt) parameters = LLMParams( max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, top_k=chat_request.top_k if chat_request.top_k else 10, @@ -959,3 +962,75 @@ async def handle_request(self, request: Request): last_node = runtime_graph.all_leaves()[-1] response = result_dict[last_node]["video_path"] return response + + +class GraphragGateway(Gateway): + def __init__(self, megaservice, host="0.0.0.0", port=8888): + super().__init__( + megaservice, host, port, str(MegaServiceEndpoint.GRAPH_RAG), ChatCompletionRequest, ChatCompletionResponse + ) + + async def handle_request(self, request: Request): + data = await request.json() + stream_opt = data.get("stream", True) + chat_request = ChatCompletionRequest.parse_obj(data) + + def parser_input(data, TypeClass, key): + chat_request = None + try: + chat_request = TypeClass.parse_obj(data) + query = getattr(chat_request, key) + except: + query = None + return query, chat_request + + query = None + for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]): + query, chat_request = parser_input(data, TypeClass, key) + if query is not None: + break + if query is None: + raise ValueError(f"Unknown request type: {data}") + if chat_request is None: + raise ValueError(f"Unknown request type: {data}") + prompt = self._handle_message(chat_request.messages) + parameters = LLMParams( + max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, + top_k=chat_request.top_k if chat_request.top_k else 10, + top_p=chat_request.top_p if chat_request.top_p else 0.95, + temperature=chat_request.temperature if chat_request.temperature else 0.01, + frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, + presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, + repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, + streaming=stream_opt, + 
chat_template=chat_request.chat_template if chat_request.chat_template else None, + ) + retriever_parameters = RetrieverParms( + search_type=chat_request.search_type if chat_request.search_type else "similarity", + k=chat_request.k if chat_request.k else 4, + distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None, + fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20, + lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5, + score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2, + ) + initial_inputs = chat_request + result_dict, runtime_graph = await self.megaservice.schedule( + initial_inputs=initial_inputs, + llm_parameters=parameters, + retriever_parameters=retriever_parameters, + ) + for node, response in result_dict.items(): + if isinstance(response, StreamingResponse): + return response + last_node = runtime_graph.all_leaves()[-1] + response = result_dict[last_node]["text"] + choices = [] + usage = UsageInfo() + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + ) + return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage) diff --git a/comps/dataprep/neo4j/__init__.py b/comps/dataprep/neo4j/__init__.py new file mode 100644 index 0000000000..916f3a44b2 --- /dev/null +++ b/comps/dataprep/neo4j/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/comps/dataprep/neo4j/llama_index/Dockerfile b/comps/dataprep/neo4j/llama_index/Dockerfile new file mode 100644 index 0000000000..77f912ed12 --- /dev/null +++ b/comps/dataprep/neo4j/llama_index/Dockerfile @@ -0,0 +1,39 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +FROM python:3.11-slim + +ENV LANG=C.UTF-8 + +ARG ARCH="cpu" + +RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \ + build-essential \ + default-jre \ + libgl1-mesa-glx \ + libjemalloc-dev \ + vim + +RUN useradd -m -s /bin/bash user && \ + mkdir -p /home/user && \ + chown -R user /home/user/ + +USER user + +COPY comps /home/user/comps + +RUN pip install --no-cache-dir --upgrade pip setuptools && \ + if [ ${ARCH} = "cpu" ]; then pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu; fi && \ + pip install --no-cache-dir -r /home/user/comps/dataprep/neo4j/llama_index/requirements.txt + +ENV PYTHONPATH=$PYTHONPATH:/home/user + +USER root + +RUN mkdir -p /home/user/comps/dataprep/neo4j/llama_index/uploaded_files && chown -R user /home/user/comps/dataprep/neo4j/llama_index/uploaded_files + +USER user + +WORKDIR /home/user/comps/dataprep/neo4j/llama_index + +ENTRYPOINT ["python", "extract_graph_neo4j.py"] diff --git a/comps/dataprep/neo4j/llama_index/README.md b/comps/dataprep/neo4j/llama_index/README.md new file mode 100644 index 0000000000..d6de9dabd1 --- /dev/null +++ b/comps/dataprep/neo4j/llama_index/README.md @@ -0,0 +1,94 @@ +# Dataprep Microservice with Neo4J + +This dataprep microservice ingests the input files and uses LLM (TGI or OpenAI model when OPENAI_API_KEY is set) to extract entities, relationships and descriptions of those to build a graph-based text index. 
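For reference, the graph extraction relies on an LLM prompt whose output is parsed with the `$$$$`-delimited entity/relationship format defined in `extract_graph_neo4j.py` later in this patch. The minimal sketch below is illustrative only (the sample LLM response is made up); it shows how those regexes turn a raw response into entity and relationship tuples:

```python
import re

# Patterns copied from extract_graph_neo4j.py in this patch.
entity_pattern = r'\("entity"\$\$\$\$(.+?)\$\$\$\$(.+?)\$\$\$\$(.+?)\)'
relationship_pattern = r'\("relationship"\$\$\$\$(.+?)\$\$\$\$(.+?)\$\$\$\$(.+?)\$\$\$\$(.+?)\)'

# Hypothetical LLM response that follows the prompt's output format.
llm_response = (
    '("entity"$$$$Apple$$$$Company$$$$Apple Inc. is an American multinational technology company.)\n'
    '("relationship"$$$$Apple$$$$Steve Jobs$$$$Founded$$$$Steve Jobs co-founded Apple Inc. in 1976.)'
)

entities = re.findall(entity_pattern, llm_response)             # [(name, type, description)]
relationships = re.findall(relationship_pattern, llm_response)  # [(source, target, relation, description)]
print(entities)
print(relationships)
```

In the microservice itself, these tuples are turned into `EntityNode` and `Relation` objects and written to the Neo4j property graph store.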
+ +### Setup Environment Variables + +```bash +# Manually set private environment settings +export host_ip=${your_hostname IP} # local IP +export no_proxy=$no_proxy,${host_ip} # important to add {host_ip} for containers communication +export http_proxy=${your_http_proxy} +export https_proxy=${your_http_proxy} +export NEO4J_URI=${your_neo4j_url} +export NEO4J_USERNAME=${your_neo4j_username} +export NEO4J_PASSWORD=${your_neo4j_password} # should match what was used in NEO4J_AUTH when running the neo4j-apoc +export PYTHONPATH=${path_to_comps} +export OPENAI_KEY=${your_openai_api_key} # optional, when not provided will use smaller models TGI/TEI +export HUGGINGFACEHUB_API_TOKEN=${your_hf_token} +# set additional environment settings +source ./set_env.sh +``` + +## 🚀Start Microservice with Docker + +### 1. Build Docker Image + +```bash +cd ../../../../ +docker build -t opea/dataprep-neo4j-llamaindex:latest --build-arg no_proxy=$no_proxy --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/dataprep/neo4j/llama_index/Dockerfile . +``` + +### 2. Setup Environment Variables + +```bash +# Set private environment settings +export host_ip=${your_hostname IP} # local IP +export no_proxy=$no_proxy,${host_ip} # important to add {host_ip} for containers communication +export http_proxy=${your_http_proxy} +export https_proxy=${your_http_proxy} +export NEO4J_URI=${your_neo4j_url} +export NEO4J_USERNAME=${your_neo4j_username} +export NEO4J_PASSWORD=${your_neo4j_password} +export PYTHONPATH=${path_to_comps} +export OPENAI_KEY=${your_openai_api_key} # optional, when not provided will use smaller models TGI/TEI +export HUGGINGFACEHUB_API_TOKEN=${your_hf_token} +# set additional environment settings +source ./set_env.sh +``` + +### 3. Run Docker with Docker Compose + +Docker Compose will start 4 services: dataprep-neo4j-llamaindex, neo4j-apoc, tgi-gaudi-service and tei-embedding-service. TGI and TEI are needed because dataprep relies on an LLM to extract entities and relationships from text to build the graph and the Neo4j Property Graph Index. The Neo4j database supports embeddings natively, so we do not need a separate vector store. Check out the blog [Introducing the Property Graph Index: A Powerful New Way to Build Knowledge Graphs with LLMs](https://www.llamaindex.ai/blog/introducing-the-property-graph-index-a-powerful-new-way-to-build-knowledge-graphs-with-llms) for a better understanding of the Property Graph Store and Index. + +```bash +cd comps/dataprep/neo4j/llama_index +docker compose -f compose.yaml up -d +``` + +## Invoke Microservice + +Once the document preparation microservice for Neo4j is started, you can use the command below to invoke the microservice, which extracts the graph from the document and saves it, together with embeddings, to the database. + +```bash +curl -X POST \ + -H "Content-Type: multipart/form-data" \ + -F "files=@./file1.txt" \ + http://${host_ip}:6004/v1/dataprep +``` + +You can specify chunk_size and chunk_overlap with the following command. + +```bash +curl -X POST \ + -H "Content-Type: multipart/form-data" \ + -F "files=@./file1.txt" \ + -F "chunk_size=1500" \ + -F "chunk_overlap=100" \ + http://${host_ip}:6004/v1/dataprep +``` + +We support table extraction from pdf documents. You can specify process_table and table_strategy with the following command. "table_strategy" refers to the strategies used to understand tables for table retrieval. As the setting progresses from "fast" to "hq" to "llm," the focus shifts towards deeper table understanding at the expense of processing speed.
The default strategy is "fast". + +Note: If you specify "table_strategy=llm", the TGI service will be used. + +To ensure the quality and comprehensiveness of the extracted entities, we recommend using `gpt-4o` as the default model for parsing the document. To enable the OpenAI service, please `export OPENAI_KEY=xxxx` before using this service. + +```bash +curl -X POST \ + -H "Content-Type: multipart/form-data" \ + -F "files=@./your_file.pdf" \ + -F "process_table=true" \ + -F "table_strategy=hq" \ + http://localhost:6004/v1/dataprep +``` diff --git a/comps/dataprep/neo4j/llama_index/__init__.py b/comps/dataprep/neo4j/llama_index/__init__.py new file mode 100644 index 0000000000..916f3a44b2 --- /dev/null +++ b/comps/dataprep/neo4j/llama_index/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/comps/dataprep/neo4j/llama_index/compose.yaml b/comps/dataprep/neo4j/llama_index/compose.yaml new file mode 100644 index 0000000000..ac160f6997 --- /dev/null +++ b/comps/dataprep/neo4j/llama_index/compose.yaml @@ -0,0 +1,97 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +version: "3.8" +services: + neo4j-apoc: + image: neo4j:latest + container_name: neo4j-apoc + volumes: + - /$HOME/neo4j/logs:/logs + - /$HOME/neo4j/config:/config + - /$HOME/neo4j/data:/data + - /$HOME/neo4j/plugins:/plugins + ipc: host + environment: + - NEO4J_AUTH=${NEO4J_USERNAME}/${NEO4J_PASSWORD} + - NEO4J_PLUGINS=["apoc"] + - NEO4J_apoc_export_file_enabled=true + - NEO4J_apoc_import_file_enabled=true + - NEO4J_apoc_import_file_use__neo4j__config=true + - NEO4J_dbms_security_procedures_unrestricted=apoc.\* + ports: + - "7474:7474" + - "7687:7687" + restart: always + tei-embedding-service: + image: ghcr.io/huggingface/text-embeddings-inference:cpu-1.5 + container_name: tei-embedding-server + ports: + - "6006:80" + volumes: + - "./data:/data" + shm_size: 1g + environment: + no_proxy: ${no_proxy} + NO_PROXY: ${no_proxy} + http_proxy: ${http_proxy} + https_proxy: ${https_proxy} + HUGGING_FACE_HUB_TOKEN: ${HUGGINGFACEHUB_API_TOKEN} + ipc: host + command: --model-id ${EMBEDDING_MODEL_ID} --auto-truncate + tgi-gaudi-service: + image: ghcr.io/huggingface/tgi-gaudi:2.0.5 + container_name: tgi-gaudi-server + ports: + - "6005:80" + volumes: + - "./data:/data" + environment: + no_proxy: ${no_proxy} + NO_PROXY: ${no_proxy} + http_proxy: ${http_proxy} + https_proxy: ${https_proxy} + HUGGING_FACE_HUB_TOKEN: ${HUGGINGFACEHUB_API_TOKEN} + HF_HUB_DISABLE_PROGRESS_BARS: 1 + HF_HUB_ENABLE_HF_TRANSFER: 0 + HABANA_VISIBLE_DEVICES: all + OMPI_MCA_btl_vader_single_copy_mechanism: none + ENABLE_HPU_GRAPH: true + LIMIT_HPU_GRAPH: true + USE_FLASH_ATTENTION: true + FLASH_ATTENTION_RECOMPUTE: true + runtime: habana + cap_add: + - SYS_NICE + ipc: host + command: --model-id ${LLM_MODEL_ID} --max-input-length 2048 --max-total-tokens 4096 + dataprep-neo4j-llamaindex: + image: opea/dataprep-neo4j-llamaindex:latest + container_name: dataprep-neo4j-server + depends_on: + - neo4j-apoc + - tgi-gaudi-service + - tei-embedding-service + ports: + - "6004:6004" + ipc: host + environment: + no_proxy: ${no_proxy} + http_proxy: ${http_proxy} + https_proxy: ${https_proxy} + host_ip: ${host_ip} + NEO4J_URL: ${NEO4J_URL} + NEO4J_USERNAME: ${NEO4J_USERNAME} + NEO4J_PASSWORD: ${NEO4J_PASSWORD} + TGI_LLM_ENDPOINT: ${TGI_LLM_ENDPOINT} + TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT} + OPENAI_API_KEY: ${OPENAI_API_KEY} + OPENAI_EMBEDDING_MODEL: ${OPENAI_EMBEDDING_MODEL} +
OPENAI_LLM_MODEL: ${OPENAI_LLM_MODEL} + EMBEDDING_MODEL_ID: ${EMBEDDING_MODEL_ID} + LLM_MODEL_ID: ${LLM_MODEL_ID} + LOGFLAG: ${LOGFLAG} + restart: unless-stopped +networks: + default: + driver: bridge diff --git a/comps/dataprep/neo4j/llama_index/config.py b/comps/dataprep/neo4j/llama_index/config.py new file mode 100644 index 0000000000..3037b8f9fb --- /dev/null +++ b/comps/dataprep/neo4j/llama_index/config.py @@ -0,0 +1,18 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os + +host_ip = os.getenv("host_ip") +# Neo4J configuration +NEO4J_URL = os.getenv("NEO4J_URL", f"bolt://{host_ip}:7687") +NEO4J_USERNAME = os.getenv("NEO4J_USERNAME", "neo4j") +NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "neo4jtest") + +# LLM/Embedding endpoints +TGI_LLM_ENDPOINT = os.getenv("TGI_LLM_ENDPOINT", f"http://{host_ip}:6005") +TEI_EMBEDDING_ENDPOINT = os.getenv("TEI_EMBEDDING_ENDPOINT ", f"http://{host_ip}:6006") + +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +OPENAI_EMBEDDING_MODEL = os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small") +OPENAI_LLM_MODEL = os.getenv("OPENAI_LLM_MODEL", "gpt-4o") diff --git a/comps/dataprep/neo4j/llama_index/extract_graph_neo4j.py b/comps/dataprep/neo4j/llama_index/extract_graph_neo4j.py new file mode 100644 index 0000000000..db5c5151f7 --- /dev/null +++ b/comps/dataprep/neo4j/llama_index/extract_graph_neo4j.py @@ -0,0 +1,593 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# GraphRAGExtractor dependencies +import asyncio +import json +import os + +# GraphRAGStore dependencies +import re +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Union + +import nest_asyncio +import networkx as nx +import openai +import requests +from config import ( + NEO4J_PASSWORD, + NEO4J_URL, + NEO4J_USERNAME, + OPENAI_API_KEY, + OPENAI_EMBEDDING_MODEL, + OPENAI_LLM_MODEL, + TEI_EMBEDDING_ENDPOINT, + TGI_LLM_ENDPOINT, + host_ip, +) +from fastapi import File, Form, HTTPException, UploadFile +from graspologic.partition import hierarchical_leiden +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_text_splitters import HTMLHeaderTextSplitter +from llama_index.core import Document, PropertyGraphIndex, Settings +from llama_index.core.llms import ChatMessage +from llama_index.core.node_parser import LangchainNodeParser +from llama_index.embeddings.openai import OpenAIEmbedding +from llama_index.embeddings.text_embeddings_inference import TextEmbeddingsInference +from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore +from llama_index.llms.openai import OpenAI +from llama_index.llms.text_generation_inference import TextGenerationInference +from neo4j import GraphDatabase +from openai import Client + +from comps import CustomLogger, DocPath, opea_microservices, register_microservice +from comps.dataprep.utils import ( + document_loader, + encode_filename, + get_separators, + get_tables_result, + parse_html, + save_content_to_local_disk, +) + +nest_asyncio.apply() + +from llama_index.core.async_utils import run_jobs +from llama_index.core.bridge.pydantic import BaseModel, Field +from llama_index.core.graph_stores.types import KG_NODES_KEY, KG_RELATIONS_KEY, EntityNode, Relation +from llama_index.core.indices.property_graph.utils import default_parse_triplets_fn +from llama_index.core.llms.llm import LLM +from llama_index.core.prompts import PromptTemplate +from llama_index.core.prompts.default_prompts import 
DEFAULT_KG_TRIPLET_EXTRACT_PROMPT +from llama_index.core.schema import BaseNode, TransformComponent + + +class GraphRAGStore(Neo4jPropertyGraphStore): + # https://github.com/run-llama/llama_index/blob/main/docs/docs/examples/cookbooks/GraphRAG_v2.ipynb + community_summary = {} + entity_info = None + max_cluster_size = 5 + + def __init__(self, username: str, password: str, url: str, llm: LLM): + super().__init__(username=username, password=password, url=url) + self.llm = llm + + def generate_community_summary(self, text): + """Generate summary for a given text using an LLM.""" + messages = [ + ChatMessage( + role="system", + content=( + "You are provided with a set of relationships from a knowledge graph, each represented as " + "entity1->entity2->relation->relationship_description. Your task is to create a summary of these " + "relationships. The summary should include the names of the entities involved and a concise synthesis " + "of the relationship descriptions. The goal is to capture the most critical and relevant details that " + "highlight the nature and significance of each relationship. Ensure that the summary is coherent and " + "integrates the information in a way that emphasizes the key aspects of the relationships." + ), + ), + ChatMessage(role="user", content=text), + ] + if OPENAI_API_KEY: + response = OpenAI().chat(messages) + else: + response = self.llm.chat(messages) + + clean_response = re.sub(r"^assistant:\s*", "", str(response)).strip() + return clean_response + + def build_communities(self): + """Builds communities from the graph and summarizes them.""" + nx_graph = self._create_nx_graph() + community_hierarchical_clusters = hierarchical_leiden(nx_graph, max_cluster_size=self.max_cluster_size) + self.entity_info, community_info = self._collect_community_info(nx_graph, community_hierarchical_clusters) + self._summarize_communities(community_info) + + def _create_nx_graph(self): + """Converts internal graph representation to NetworkX graph.""" + nx_graph = nx.Graph() + triplets = self.get_triplets() + for entity1, relation, entity2 in triplets: + nx_graph.add_node(entity1.name) + nx_graph.add_node(entity2.name) + nx_graph.add_edge( + relation.source_id, + relation.target_id, + relationship=relation.label, + description=relation.properties["relationship_description"], + ) + return nx_graph + + def _collect_community_info(self, nx_graph, clusters): + """Collect information for each node based on their community, + allowing entities to belong to multiple clusters.""" + entity_info = defaultdict(set) + community_info = defaultdict(list) + + for item in clusters: + node = item.node + cluster_id = item.cluster + + # Update entity_info + entity_info[node].add(cluster_id) + + for neighbor in nx_graph.neighbors(node): + edge_data = nx_graph.get_edge_data(node, neighbor) + if edge_data: + detail = f"{node} -> {neighbor} -> {edge_data['relationship']} -> {edge_data['description']}" + community_info[cluster_id].append(detail) + + # Convert sets to lists for easier serialization if needed + entity_info = {k: list(v) for k, v in entity_info.items()} + + return dict(entity_info), dict(community_info) + + def _summarize_communities(self, community_info): + """Generate and store summaries for each community.""" + for community_id, details in community_info.items(): + details_text = "\n".join(details) + "." 
# Ensure it ends with a period + self.community_summary[community_id] = self.generate_community_summary(details_text) + + # To store summaries in neo4j + # summary = self.generate_community_summary(details_text) + # self.community_summary[ + # community_id + # ] = self.store_community_summary_in_neo4j(community_id, summary) + + def store_community_summary_in_neo4j(self, community_id, summary): + """Store the community summary in Neo4j.""" + with driver.session() as session: + session.run( + """ + MERGE (c:Community {id: $community_id}) + SET c.summary = $summary + """, + community_id=community_id, + summary=summary, + ) + + def get_community_summaries(self): + """Returns the community summaries, building them if not already done.""" + if not self.community_summary: + self.build_communities() + return self.community_summary + + def query_community_summaries(self): + """Query and print community summaries from Neo4j.""" + with driver.session() as session: + result = session.run( + """ + MATCH (c:Community) + RETURN c.id AS community_id, c.summary AS summary + """ + ) + for record in result: + print(f"Community ID: {record['community_id']}") + print(f"Community Summary: {record['summary']}") + + def query_schema(self): + """Query and print the schema information from Neo4j.""" + with driver.session() as session: + result = session.run("CALL apoc.meta.schema()") + schema = result.single()["value"] + + for label, properties in schema.items(): + if "properties" in properties: + print(f"Node Label: {label}") + for prop, prop_info in properties["properties"].items(): + print(f" Property Key: {prop}, Type: {prop_info['type']}") + if "relationships" in properties: + for rel_type, rel_info in properties["relationships"].items(): + print(f"Relationship Type: {rel_type}") + for prop, prop_info in rel_info["properties"].items(): + print(f" Property Key: {prop}, Type: {prop_info['type']}") + + +class GraphRAGExtractor(TransformComponent): + """Extract triples from a graph. + + Uses an LLM and a simple prompt + output parsing to extract paths (i.e. triples) and entity, relation descriptions from text. + + Args: + llm (LLM): + The language model to use. + extract_prompt (Union[str, PromptTemplate]): + The prompt to use for extracting triples. + parse_fn (callable): + A function to parse the output of the language model. + num_workers (int): + The number of workers to use for parallel processing. + max_paths_per_chunk (int): + The maximum number of paths to extract per chunk. 
+ """ + + llm: LLM + extract_prompt: PromptTemplate + parse_fn: Callable + num_workers: int + max_paths_per_chunk: int + + def __init__( + self, + llm: Optional[LLM] = None, + extract_prompt: Optional[Union[str, PromptTemplate]] = None, + parse_fn: Callable = default_parse_triplets_fn, + max_paths_per_chunk: int = 10, + num_workers: int = 4, + ) -> None: + """Init params.""" + from llama_index.core import Settings + + if isinstance(extract_prompt, str): + extract_prompt = PromptTemplate(extract_prompt) + + super().__init__( + llm=llm, + extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT, + parse_fn=parse_fn, + num_workers=num_workers, + max_paths_per_chunk=max_paths_per_chunk, + ) + + @classmethod + def class_name(cls) -> str: + return "GraphExtractor" + + def __call__(self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any) -> List[BaseNode]: + """Extract triples from nodes.""" + return asyncio.run(self.acall(nodes, show_progress=show_progress, **kwargs)) + + async def _aextract(self, node: BaseNode) -> BaseNode: + """Extract triples from a node.""" + assert hasattr(node, "text") + + text = node.get_content(metadata_mode="llm") + try: + llm_response = await self.llm.apredict( + self.extract_prompt, + text=text, + max_knowledge_triplets=self.max_paths_per_chunk, + ) + entities, entities_relationship = self.parse_fn(llm_response) + except ValueError: + entities = [] + entities_relationship = [] + + existing_nodes = node.metadata.pop(KG_NODES_KEY, []) + existing_relations = node.metadata.pop(KG_RELATIONS_KEY, []) + entity_metadata = node.metadata.copy() + for entity, entity_type, description in entities: + entity_metadata["entity_description"] = description + entity_node = EntityNode(name=entity, label=entity_type, properties=entity_metadata) + existing_nodes.append(entity_node) + + relation_metadata = node.metadata.copy() + for triple in entities_relationship: + subj, obj, rel, description = triple + relation_metadata["relationship_description"] = description + rel_node = Relation( + label=rel, + source_id=subj, + target_id=obj, + properties=relation_metadata, + ) + + existing_relations.append(rel_node) + + node.metadata[KG_NODES_KEY] = existing_nodes + node.metadata[KG_RELATIONS_KEY] = existing_relations + return node + + async def acall(self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any) -> List[BaseNode]: + """Extract triples from nodes async.""" + jobs = [] + for node in nodes: + jobs.append(self._aextract(node)) + + return await run_jobs( + jobs, + workers=self.num_workers, + show_progress=show_progress, + desc="Extracting paths from text", + ) + + +KG_TRIPLET_EXTRACT_TMPL = """ +-Goal- +Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities. +Given the text, extract up to {max_knowledge_triplets} entity-relation triplets. + +-Steps- +1. Identify all entities. For each identified entity, extract the following information: +- entity_name: Name of the entity, capitalized +- entity_type: Type of the entity +- entity_description: Comprehensive description of the entity's attributes and activities +Format each entity strictly as ("entity"$$$$$$$$$$$$). Pay attention to the dollar signs ($$$$) separating the fields and the parentheses surrounding the entire entity. +one example: ("entity"$$$$Apple$$$$Company$$$$Apple Inc. is an American multinational technology company that specializes in consumer electronics, computer software, and online services.) + +2. 
From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. +For each pair of related entities, extract the following information: +- source_entity: name of the source entity, as identified in step 1 +- target_entity: name of the target entity, as identified in step 1 +- relation: relationship between source_entity and target_entity +- relationship_description: explanation as to why you think the source entity and the target entity are related to each other + +Format each relationship strictly as ("relationship"$$$$$$$$$$$$$$$$). Pay attention to the dollar signs ($$$$) separating the fields and the parentheses surrounding the entire entity. +one example: ("relationship"$$$$Apple$$$$Steve Jobs$$$$Founded$$$$Steve Jobs co-founded Apple Inc. in 1976.) + +3. When finished, output. + +-Real Data- +###################### +text: {text} +###################### +output:""" + +entity_pattern = r'\("entity"\$\$\$\$(.+?)\$\$\$\$(.+?)\$\$\$\$(.+?)\)' +relationship_pattern = r'\("relationship"\$\$\$\$(.+?)\$\$\$\$(.+?)\$\$\$\$(.+?)\$\$\$\$(.+?)\)' + + +def inspect_db(): + try: + with driver.session() as session: + # Check for property keys + result = session.run("CALL db.propertyKeys()") + property_keys = [record["propertyKey"] for record in result] + print("Property Keys:", property_keys) + + # Check for node labels + result = session.run("CALL db.labels()") + labels = [record["label"] for record in result] + print("Node Labels:", labels) + + # Check for relationship types + result = session.run("CALL db.relationshipTypes()") + relationship_types = [record["relationshipType"] for record in result] + print("Relationship Types:", relationship_types) + except Exception as e: + print(f"Error: {e}") + finally: + driver.close() + + +def parse_fn(response_str: str) -> Any: + entities = re.findall(entity_pattern, response_str) + relationships = re.findall(relationship_pattern, response_str) + if logflag: + logger.info(f"len of parsed entities: {len(entities)} and relationships: {len(relationships)}") + return entities, relationships + + +def get_model_name_from_tgi_endpoint(url): + try: + response = requests.get(f"{url}/info") + response.raise_for_status() # Ensure we notice bad responses + try: + model_info = response.json() + model_name = model_info.get("model_id") + if model_name: + return model_name + else: + logger.error(f"model_id not found in the response from {url}") + return None + except ValueError: + logger.error(f"Invalid JSON response from {url}") + return None + except requests.RequestException as e: + logger.error(f"Request to {url} failed: {e}") + return None + + +logger = CustomLogger("prepare_doc_neo4j") +logflag = os.getenv("LOGFLAG", False) + +upload_folder = "./uploaded_files/" +driver = GraphDatabase.driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) +client = OpenAI() + + +def ingest_data_to_neo4j(doc_path: DocPath): + """Ingest document to Neo4J.""" + path = doc_path.path + if logflag: + logger.info(f"Parsing document {path}.") + + if path.endswith(".html"): + headers_to_split_on = [ + ("h1", "Header 1"), + ("h2", "Header 2"), + ("h3", "Header 3"), + ] + text_splitter = HTMLHeaderTextSplitter(headers_to_split_on=headers_to_split_on) + else: + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=doc_path.chunk_size, + chunk_overlap=doc_path.chunk_overlap, + add_start_index=True, + separators=get_separators(), + ) + + content = document_loader(path) # single doc string + document = 
Document(text=content) + + structured_types = [".xlsx", ".csv", ".json", "jsonl"] + _, ext = os.path.splitext(path) + + # create llama-index nodes (chunks) + if ext in structured_types: + nodes = [document] + else: + parser = LangchainNodeParser(text_splitter) # wrap text splitting from langchain w node parser + nodes = parser.get_nodes_from_documents([document]) + + if doc_path.process_table and path.endswith(".pdf"): + table_chunks = get_tables_result(path, doc_path.table_strategy) # list of text + if table_chunks: + table_docs = [Document(text=chunk) for chunk in table_chunks] + nodes = nodes + table_docs + if logflag: + logger.info(f"extract tables nodes: len of table_docs {len(table_docs)}") + + if logflag: + logger.info(f"Done preprocessing. Created {len(nodes)} chunks of the original file.") + + if OPENAI_API_KEY: + logger.info("OpenAI API Key is set. Verifying its validity...") + openai.api_key = OPENAI_API_KEY + try: + llm = OpenAI(temperature=0, model=OPENAI_LLM_MODEL) + embed_model = OpenAIEmbedding(model=OPENAI_EMBEDDING_MODEL, embed_batch_size=100) + logger.info("OpenAI API Key is valid.") + except openai.AuthenticationError: + logger.info("OpenAI API Key is invalid.") + except Exception as e: + logger.info(f"An error occurred while verifying the API Key: {e}") + else: + logger.info("NO OpenAI API Key. TGI/TEI endpoints will be used.") + llm_name = get_model_name_from_tgi_endpoint(TGI_LLM_ENDPOINT) + llm = TextGenerationInference( + model_url=TGI_LLM_ENDPOINT, + model_name=llm_name, + temperature=0.7, + max_tokens=1512, # 512otherwise too shor + ) + emb_name = get_model_name_from_tgi_endpoint(TEI_EMBEDDING_ENDPOINT) + embed_model = TextEmbeddingsInference( + base_url=TEI_EMBEDDING_ENDPOINT, + model_name=emb_name, + timeout=60, # timeout in seconds + embed_batch_size=10, # batch size for embedding + ) + Settings.embed_model = embed_model + Settings.llm = llm + kg_extractor = GraphRAGExtractor( + llm=llm, + extract_prompt=KG_TRIPLET_EXTRACT_TMPL, + max_paths_per_chunk=2, + parse_fn=parse_fn, + ) + graph_store = GraphRAGStore(username=NEO4J_USERNAME, password=NEO4J_PASSWORD, url=NEO4J_URL, llm=llm) + + # nodes are the chunked docs to insert + index = PropertyGraphIndex( + nodes=nodes, + llm=llm, + kg_extractors=[kg_extractor], + property_graph_store=graph_store, + embed_model=embed_model or Settings.embed_model, + show_progress=True, + ) + inspect_db() + if logflag: + logger.info(f"Total number of triplets {len(index.property_graph_store.get_triplets())}") + + # index.property_graph_store.build_communities() + # print("done building communities") + + if logflag: + logger.info("The graph is built.") + + return True + + +@register_microservice( + name="opea_service@extract_graph_neo4j", + endpoint="/v1/dataprep", + host="0.0.0.0", + port=6004, + input_datatype=DocPath, + output_datatype=None, +) +async def ingest_documents( + files: Optional[Union[UploadFile, List[UploadFile]]] = File(None), + link_list: Optional[str] = Form(None), + chunk_size: int = Form(1500), + chunk_overlap: int = Form(100), + process_table: bool = Form(False), + table_strategy: str = Form("fast"), +): + if logflag: + logger.info(f"files:{files}") + logger.info(f"link_list:{link_list}") + + if files: + if not isinstance(files, list): + files = [files] + uploaded_files = [] + for file in files: + encode_file = encode_filename(file.filename) + save_path = upload_folder + encode_file + await save_content_to_local_disk(save_path, file) + ingest_data_to_neo4j( + DocPath( + path=save_path, + chunk_size=chunk_size, + 
chunk_overlap=chunk_overlap, + process_table=process_table, + table_strategy=table_strategy, + ) + ) + uploaded_files.append(save_path) + if logflag: + logger.info(f"Successfully saved file {save_path}") + result = {"status": 200, "message": "Data preparation succeeded"} + if logflag: + logger.info(result) + return result + + if link_list: + link_list = json.loads(link_list) # Parse JSON string to list + if not isinstance(link_list, list): + raise HTTPException(status_code=400, detail="link_list should be a list.") + for link in link_list: + encoded_link = encode_filename(link) + save_path = upload_folder + encoded_link + ".txt" + content = parse_html([link])[0][0] + try: + await save_content_to_local_disk(save_path, content) + ingest_data_to_neo4j( + DocPath( + path=save_path, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + process_table=process_table, + table_strategy=table_strategy, + ) + ) + except json.JSONDecodeError: + raise HTTPException(status_code=500, detail="Failed to ingest data into Neo4j.") + + if logflag: + logger.info(f"Successfully saved link {link}") + + result = {"status": 200, "message": "Data preparation succeeded"} + if logflag: + logger.info(result) + return result + + raise HTTPException(status_code=400, detail="Must provide either a file or a string list.") + + +if __name__ == "__main__": + opea_microservices["opea_service@extract_graph_neo4j"].start() diff --git a/comps/dataprep/neo4j/llama_index/requirements.txt b/comps/dataprep/neo4j/llama_index/requirements.txt new file mode 100644 index 0000000000..fc5f7b8d61 --- /dev/null +++ b/comps/dataprep/neo4j/llama_index/requirements.txt @@ -0,0 +1,38 @@ +beautifulsoup4 +cairosvg +docarray[full] +docx2txt +easyocr +fastapi +future +graspologic +huggingface_hub +ipython +langchain +langchain-text-splitters +langchain_community +llama-index +llama-index-core +llama-index-embeddings-text-embeddings-inference +llama-index-llms-openai +llama-index-llms-text-generation-inference +llama_index_graph_stores_neo4j==0.3.3 +markdown +neo4j +numpy +openai +opentelemetry-api +opentelemetry-exporter-otlp +opentelemetry-sdk +pandas +Pillow +prometheus-fastapi-instrumentator +pymupdf +pytesseract +python-docx +python-pptx +scipy +sentence_transformers +shortuuid +unstructured[all-docs]==0.15.7 +uvicorn diff --git a/comps/dataprep/neo4j/llama_index/set_env.sh b/comps/dataprep/neo4j/llama_index/set_env.sh new file mode 100644 index 0000000000..dd5d2a15d5 --- /dev/null +++ b/comps/dataprep/neo4j/llama_index/set_env.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# Remember to set your private variables mentioned in README +# host_ip, OPENAI_KEY, HUGGINGFACEHUB_API_TOKEN, proxies...
+ +export EMBEDDING_MODEL_ID="BAAI/bge-base-en-v1.5" +export OPENAI_EMBEDDING_MODEL="text-embedding-3-small" +export LLM_MODEL_ID="Intel/neural-chat-7b-v3-3" +export OPENAI_LLM_MODEL="gpt-4o" +export TEI_EMBEDDING_ENDPOINT="http://${host_ip}:6006" +export TGI_LLM_ENDPOINT="http://${host_ip}:6005" +export NEO4J_URL="bolt://${host_ip}:7687" +export NEO4J_USERNAME=neo4j +export DATAPREP_SERVICE_ENDPOINT="http://${host_ip}:6004/v1/dataprep" +export LOGFLAG=True diff --git a/comps/retrievers/neo4j/llama_index/Dockerfile b/comps/retrievers/neo4j/llama_index/Dockerfile new file mode 100644 index 0000000000..1b601805d0 --- /dev/null +++ b/comps/retrievers/neo4j/llama_index/Dockerfile @@ -0,0 +1,35 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +FROM python:3.11-slim + +ENV LANG=C.UTF-8 + +ENV HUGGINGFACEHUB_API_TOKEN=dummy + +ARG ARCH="cpu" + +RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \ + build-essential \ + libgl1-mesa-glx \ + libjemalloc-dev \ + libcairo2 \ + libglib2.0-0 + +RUN useradd -m -s /bin/bash user && \ + mkdir -p /home/user && \ + chown -R user /home/user/ + +USER user + +COPY comps /home/user/comps + +RUN pip install --no-cache-dir --upgrade pip && \ + if [ ${ARCH} = "cpu" ]; then pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu; fi && \ + pip install --no-cache-dir -r /home/user/comps/retrievers/neo4j/llama_index/requirements.txt + +ENV PYTHONPATH=$PYTHONPATH:/home/user + +WORKDIR /home/user/comps/retrievers/neo4j/llama_index + +ENTRYPOINT ["python", "retriever_community_answers_neo4j.py"] diff --git a/comps/retrievers/neo4j/llama_index/README.md b/comps/retrievers/neo4j/llama_index/README.md new file mode 100644 index 0000000000..05368a3d07 --- /dev/null +++ b/comps/retrievers/neo4j/llama_index/README.md @@ -0,0 +1,65 @@ +# Retriever Microservice with Neo4J + +This retrieval microservice is intended for use in a GraphRAG pipeline and assumes a GraphRAGStore exists. +Retrieval follows these steps: + +- Performs hierarchical_leiden clustering to identify communities in the knowledge graph +- Performs a similarity search to find the entities relevant to the input query +- Generates a community summary for each community +- Generates a partial answer to the query for each community summary. This will later be used as context to generate a final query response. Please refer to [GenAIExamples/GraphRAG](https://github.com/opea-project/GenAIExamples). + +## 🚀Start Microservice with Docker + +### 1. Build Docker Image + +```bash +cd ../../ +docker build -t opea/retriever-community-answers-neo4j:latest --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/retrievers/neo4j/llama_index/Dockerfile . +``` + +### 2. Setup Environment Variables + +```bash +# Set private environment settings +export host_ip=${your_hostname IP} # local IP +export no_proxy=$no_proxy,${host_ip} # important to add {host_ip} for containers communication +export http_proxy=${your_http_proxy} +export https_proxy=${your_http_proxy} +export NEO4J_URI=${your_neo4j_url} +export NEO4J_USERNAME=${your_neo4j_username} +export NEO4J_PASSWORD=${your_neo4j_password} +export PYTHONPATH=${path_to_comps} +export OPENAI_KEY=${your_openai_api_key} # optional, when not provided will use smaller models TGI/TEI +export HUGGINGFACEHUB_API_TOKEN=${your_hf_token} +# set additional environment settings +source ./set_env.sh +``` + +### 3.
Run Docker with Docker Compose + +Docker Compose will start 5 services: retriever-neo4j-llamaindex, dataprep-neo4j-llamaindex, neo4j-apoc, tgi-gaudi-service and tei-embedding-service. TGI and TEI are needed because the retriever relies on an LLM to generate community summaries from the community triplets that are identified as relevant to the input query. The Neo4j database supports embeddings natively, so we do not need a separate vector store. Check out the blog [Introducing the Property Graph Index: A Powerful New Way to Build Knowledge Graphs with LLMs](https://www.llamaindex.ai/blog/introducing-the-property-graph-index-a-powerful-new-way-to-build-knowledge-graphs-with-llms) for a better understanding of the Property Graph Store and Index. + +```bash +cd comps/retrievers/neo4j/llama_index +docker compose -f compose.yaml up -d +``` + +## Invoke Microservice + +### 3.1 Check Service Status + +```bash +curl http://${host_ip}:6009/v1/health_check \ + -X GET \ + -H 'Content-Type: application/json' +``` + +### 3.2 Consume Retriever Service + +If OPENAI_KEY is provided, the service will use OpenAI endpoints for the LLM and embeddings; otherwise it will use the TGI and TEI endpoints. If a model name is not provided in the request, it will use the default specified by the set_env.sh script. + +```bash +curl -X POST http://${host_ip}:6009/v1/retrieval \ + -H "Content-Type: application/json" \ + -d '{"model": "gpt-3.5-turbo","messages": [{"role": "user","content": "Who is John Brady and has he had any confrontations?"}]}' +``` diff --git a/comps/retrievers/neo4j/llama_index/__init__.py b/comps/retrievers/neo4j/llama_index/__init__.py new file mode 100644 index 0000000000..916f3a44b2 --- /dev/null +++ b/comps/retrievers/neo4j/llama_index/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/comps/retrievers/neo4j/llama_index/compose.yaml b/comps/retrievers/neo4j/llama_index/compose.yaml new file mode 100644 index 0000000000..cc3dfa7f1c --- /dev/null +++ b/comps/retrievers/neo4j/llama_index/compose.yaml @@ -0,0 +1,124 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +version: "3.8" +services: + neo4j-apoc: + image: neo4j:latest + container_name: neo4j-apoc + volumes: + - /$HOME/neo4j/logs:/logs + - /$HOME/neo4j/config:/config + - /$HOME/neo4j/data:/data + - /$HOME/neo4j/plugins:/plugins + ipc: host + environment: + - NEO4J_AUTH=${NEO4J_USERNAME}/${NEO4J_PASSWORD} + - NEO4J_PLUGINS=["apoc"] + - NEO4J_apoc_export_file_enabled=true + - NEO4J_apoc_import_file_enabled=true + - NEO4J_apoc_import_file_use__neo4j__config=true + - NEO4J_dbms_security_procedures_unrestricted=apoc.\* + ports: + - "7474:7474" + - "7687:7687" + restart: always + tei-embedding-service: + image: ghcr.io/huggingface/text-embeddings-inference:cpu-1.5 + container_name: tei-embedding-server + ports: + - "6006:80" + volumes: + - "./data:/data" + shm_size: 1g + environment: + no_proxy: ${no_proxy} + NO_PROXY: ${no_proxy} + http_proxy: ${http_proxy} + https_proxy: ${https_proxy} + HUGGING_FACE_HUB_TOKEN: ${HUGGINGFACEHUB_API_TOKEN} + ipc: host + command: --model-id ${EMBEDDING_MODEL_ID} --auto-truncate + tgi-gaudi-service: + image: ghcr.io/huggingface/tgi-gaudi:2.0.5 + container_name: tgi-gaudi-server + ports: + - "6005:80" + volumes: + - "./data:/data" + environment: + no_proxy: ${no_proxy} + NO_PROXY: ${no_proxy} + http_proxy: ${http_proxy} + https_proxy: ${https_proxy} + HUGGING_FACE_HUB_TOKEN: ${HUGGINGFACEHUB_API_TOKEN} + HF_HUB_DISABLE_PROGRESS_BARS: 1 +
HF_HUB_ENABLE_HF_TRANSFER: 0 + HABANA_VISIBLE_DEVICES: all + OMPI_MCA_btl_vader_single_copy_mechanism: none + ENABLE_HPU_GRAPH: true + LIMIT_HPU_GRAPH: true + USE_FLASH_ATTENTION: true + FLASH_ATTENTION_RECOMPUTE: true + runtime: habana + cap_add: + - SYS_NICE + ipc: host + command: --model-id ${LLM_MODEL_ID} --max-input-length 2048 --max-total-tokens 4096 + dataprep-neo4j-llamaindex: + image: opea/dataprep-neo4j-llamaindex:latest + container_name: dataprep-neo4j-server + depends_on: + - neo4j-apoc + - tgi-gaudi-service + - tei-embedding-service + ports: + - "6004:6004" + ipc: host + environment: + no_proxy: ${no_proxy} + http_proxy: ${http_proxy} + https_proxy: ${https_proxy} + host_ip: ${host_ip} + NEO4J_URL: ${NEO4J_URL} + NEO4J_USERNAME: ${NEO4J_USERNAME} + NEO4J_PASSWORD: ${NEO4J_PASSWORD} + TGI_LLM_ENDPOINT: ${TGI_LLM_ENDPOINT} + TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT} + OPENAI_API_KEY: ${OPENAI_API_KEY} + OPENAI_EMBEDDING_MODEL: ${OPENAI_EMBEDDING_MODEL} + OPENAI_LLM_MODEL: ${OPENAI_LLM_MODEL} + EMBEDDING_MODEL_ID: ${EMBEDDING_MODEL_ID} + LLM_MODEL_ID: ${LLM_MODEL_ID} + LOGFLAG: ${LOGFLAG} + restart: unless-stopped + retriever-neo4j-llamaindex: + image: opea/retriever-neo4j-llamaindex:latest + container_name: retriever-neo4j-server + depends_on: + - neo4j-apoc + - tgi-gaudi-service + - tei-embedding-service + ports: + - "6009:6009" + ipc: host + environment: + no_proxy: ${no_proxy} + http_proxy: ${http_proxy} + https_proxy: ${https_proxy} + host_ip: ${host_ip} + NEO4J_URL: ${NEO4J_URL} + NEO4J_USERNAME: ${NEO4J_USERNAME} + NEO4J_PASSWORD: ${NEO4J_PASSWORD} + TGI_LLM_ENDPOINT: ${TGI_LLM_ENDPOINT} + TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT} + OPENAI_API_KEY: ${OPENAI_API_KEY} + OPENAI_EMBEDDING_MODEL: ${OPENAI_EMBEDDING_MODEL} + OPENAI_LLM_MODEL: ${OPENAI_LLM_MODEL} + EMBEDDING_MODEL_ID: ${EMBEDDING_MODEL_ID} + LLM_MODEL_ID: ${LLM_MODEL_ID} + LOGFLAG: ${LOGFLAG} + restart: unless-stopped +networks: + default: + driver: bridge diff --git a/comps/retrievers/neo4j/llama_index/config.py b/comps/retrievers/neo4j/llama_index/config.py new file mode 100644 index 0000000000..3037b8f9fb --- /dev/null +++ b/comps/retrievers/neo4j/llama_index/config.py @@ -0,0 +1,18 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os + +host_ip = os.getenv("host_ip") +# Neo4J configuration +NEO4J_URL = os.getenv("NEO4J_URL", f"bolt://{host_ip}:7687") +NEO4J_USERNAME = os.getenv("NEO4J_USERNAME", "neo4j") +NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "neo4jtest") + +# LLM/Embedding endpoints +TGI_LLM_ENDPOINT = os.getenv("TGI_LLM_ENDPOINT", f"http://{host_ip}:6005") +TEI_EMBEDDING_ENDPOINT = os.getenv("TEI_EMBEDDING_ENDPOINT ", f"http://{host_ip}:6006") + +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +OPENAI_EMBEDDING_MODEL = os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small") +OPENAI_LLM_MODEL = os.getenv("OPENAI_LLM_MODEL", "gpt-4o") diff --git a/comps/retrievers/neo4j/llama_index/requirements.txt b/comps/retrievers/neo4j/llama_index/requirements.txt new file mode 100644 index 0000000000..c91f71ba73 --- /dev/null +++ b/comps/retrievers/neo4j/llama_index/requirements.txt @@ -0,0 +1,36 @@ +bs4 +cairosvg +docarray[full] +docx2txt +fastapi +frontend +future +graspologic +huggingface_hub +langchain +langchain-community +llama-index-core +llama-index-embeddings-openai +llama-index-embeddings-text-embeddings-inference +llama-index-llms-openai +llama-index-llms-text-generation-inference +llama_index_graph_stores_neo4j==0.3.3 +neo4j +numpy 
+opencv-python +opentelemetry-api +opentelemetry-exporter-otlp +opentelemetry-sdk +pandas +Pillow +prometheus-fastapi-instrumentator +pydantic +pymupdf +pytesseract +python-docx +python-multipart +python-pptx +sentence_transformers +shortuuid +tiktoken +uvicorn diff --git a/comps/retrievers/neo4j/llama_index/retriever_community_answers_neo4j.py b/comps/retrievers/neo4j/llama_index/retriever_community_answers_neo4j.py new file mode 100644 index 0000000000..837cd87e9a --- /dev/null +++ b/comps/retrievers/neo4j/llama_index/retriever_community_answers_neo4j.py @@ -0,0 +1,231 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import os +import re +import time +from typing import List, Union + +import openai +from config import ( + NEO4J_PASSWORD, + NEO4J_URL, + NEO4J_USERNAME, + OPENAI_API_KEY, + OPENAI_EMBEDDING_MODEL, + OPENAI_LLM_MODEL, + TEI_EMBEDDING_ENDPOINT, + TGI_LLM_ENDPOINT, +) +from llama_index.core import PropertyGraphIndex, Settings +from llama_index.core.indices.property_graph.sub_retrievers.vector import VectorContextRetriever +from llama_index.core.llms import LLM, ChatMessage +from llama_index.core.query_engine import CustomQueryEngine +from llama_index.embeddings.openai import OpenAIEmbedding +from llama_index.embeddings.text_embeddings_inference import TextEmbeddingsInference +from llama_index.llms.openai import OpenAI +from llama_index.llms.text_generation_inference import TextGenerationInference +from pydantic import BaseModel, PrivateAttr + +from comps import ( + CustomLogger, + EmbedDoc, + SearchedDoc, + ServiceType, + TextDoc, + opea_microservices, + register_microservice, + register_statistics, + statistics_dict, +) +from comps.cores.proto.api_protocol import ( + ChatCompletionRequest, + RetrievalRequest, + RetrievalResponse, + RetrievalResponseData, +) +from comps.dataprep.neo4j.llama_index.extract_graph_neo4j import GraphRAGStore, get_model_name_from_tgi_endpoint + +logger = CustomLogger("retriever_neo4j") +logflag = os.getenv("LOGFLAG", False) + + +class GraphRAGQueryEngine(CustomQueryEngine): + # https://github.com/run-llama/llama_index/blob/main/docs/docs/examples/cookbooks/GraphRAG_v2.ipynb + # private attr because inherits from BaseModel + _graph_store: GraphRAGStore = PrivateAttr() + _index: PropertyGraphIndex = PrivateAttr() + _llm: LLM = PrivateAttr() + _similarity_top_k: int = PrivateAttr() + + def __init__(self, graph_store: GraphRAGStore, llm: LLM, index: PropertyGraphIndex, similarity_top_k: int = 20): + super().__init__() + self._graph_store = graph_store + self._index = index + self._llm = llm + self._similarity_top_k = similarity_top_k + + def custom_query(self, query_str: str) -> RetrievalResponseData: + """Process all community summaries to generate answers to a specific query.""" + + entities = self.get_entities(query_str, self._similarity_top_k) + + community_ids = self.retrieve_entity_communities(self._graph_store.entity_info, entities) + community_summaries = self._graph_store.get_community_summaries() + if logflag: + logger.info(f"Community ids: {community_ids}") + community_answers = [ + self.generate_answer_from_summary(community_summary, query_str) + for id, community_summary in community_summaries.items() + if id in community_ids + ] + # Convert answers to RetrievalResponseData objects + response_data = [RetrievalResponseData(text=answer, metadata={}) for answer in community_answers] + return response_data + + def get_entities(self, query_str, similarity_top_k): + if logflag: + logger.info(f"Retrieving 
entities for query: {query_str} with top_k: {similarity_top_k}") + # TODO: make retriever configurable [VectorContextRetriever] or [LLMSynonymRetriever] + vecContext_retriever = VectorContextRetriever( + graph_store=self._graph_store, + embed_model=self._index._embed_model, + similarity_top_k=self._similarity_top_k, + # similarity_score=0.6 + ) + nodes_retrieved = self._index.as_retriever( + sub_retrievers=[vecContext_retriever], similarity_top_k=self._similarity_top_k + ).retrieve(query_str) + # if sub_retrievers is not specified, LLMSynonymRetriever is used with the Settings.llm model + # nodes_retrieved = self._index.as_retriever(similarity_top_k=self._similarity_top_k).retrieve(query_str) + entities = set() + pattern = r"(\w+(?:\s+\w+)*)\s*->\s*(\w+(?:\s+\w+)*)\s*->\s*(\w+(?:\s+\w+)*)" + + for node in nodes_retrieved: + matches = re.findall(pattern, node.text, re.DOTALL) + + for match in matches: + subject = match[0] + obj = match[2] + entities.add(subject) + entities.add(obj) + if logflag: + logger.info(f"entities from query {entities}") + return list(entities) + + def retrieve_entity_communities(self, entity_info, entities): + """Retrieve cluster information for given entities, allowing for multiple clusters per entity. + + Args: + entity_info (dict): Dictionary mapping entities to their cluster IDs (list). + entities (list): List of entity names to retrieve information for. + + Returns: + List of community or cluster IDs to which an entity belongs. + """ + community_ids = [] + + for entity in entities: + if entity in entity_info: + community_ids.extend(entity_info[entity]) + + return list(set(community_ids)) + + def generate_answer_from_summary(self, community_summary, query): + """Generate an answer from a community summary based on a given query using LLM.""" + prompt = ( + f"Given the community summary: {community_summary}, " + f"how would you answer the following query? Query: {query}" + ) + messages = [ + ChatMessage(role="system", content=prompt), + ChatMessage( + role="user", + content="I need an answer based on the above information.", + ), + ] + response = self._llm.chat(messages) + cleaned_response = re.sub(r"^assistant:\s*", "", str(response)).strip() + return cleaned_response + + +@register_microservice( + name="opea_service@retriever_community_answers_neo4j", + service_type=ServiceType.RETRIEVER, + endpoint="/v1/retrieval", + host="0.0.0.0", + port=6009, +) +@register_statistics(names=["opea_service@retriever_community_answers_neo4j"]) +async def retrieve(input: Union[ChatCompletionRequest]) -> Union[ChatCompletionRequest]: + if logflag: + logger.info(input) + start = time.time() + query = input.messages[0]["content"] + logger.info(f"Query received in retriever: {query}") + + if OPENAI_API_KEY: + logger.info("OpenAI API Key is set. Verifying its validity...") + openai.api_key = OPENAI_API_KEY + try: + llm = OpenAI(temperature=0, model=OPENAI_LLM_MODEL) + embed_model = OpenAIEmbedding(model=OPENAI_EMBEDDING_MODEL, embed_batch_size=100) + logger.info("OpenAI API Key is valid.") + except openai.AuthenticationError: + logger.info("OpenAI API Key is invalid.") + except Exception as e: + logger.info(f"An error occurred while verifying the API Key: {e}") + else: + logger.info("No OpenAI API key provided. 
Will use TGI and TEI endpoints") + llm_name = get_model_name_from_tgi_endpoint(TGI_LLM_ENDPOINT) + llm = TextGenerationInference( + model_url=TGI_LLM_ENDPOINT, + model_name=llm_name, + temperature=0.7, + max_tokens=1512, # 512 was too short + ) + emb_name = get_model_name_from_tgi_endpoint(TEI_EMBEDDING_ENDPOINT) + embed_model = TextEmbeddingsInference( + base_url=TEI_EMBEDDING_ENDPOINT, + model_name=emb_name, + timeout=60, # timeout in seconds + embed_batch_size=10, # batch size for embedding + ) + Settings.embed_model = embed_model + Settings.llm = llm + # pre-existing graph store (created with comps/dataprep/neo4j/llama_index/extract_graph_neo4j.py) + graph_store = GraphRAGStore(username=NEO4J_USERNAME, password=NEO4J_PASSWORD, url=NEO4J_URL, llm=llm) + + index = PropertyGraphIndex.from_existing( + property_graph_store=graph_store, + embed_model=embed_model or Settings.embed_model, + embed_kg_nodes=True, + ) + index.property_graph_store.build_communities() + query_engine = GraphRAGQueryEngine( + graph_store=index.property_graph_store, + llm=llm, + index=index, + similarity_top_k=3, + ) + + # these are the answers from the community summaries + answers_by_community = query_engine.query(query) + input.retrieved_docs = answers_by_community + input.documents = [doc.text for doc in answers_by_community] + result = ChatCompletionRequest( + messages="Retrieval of answers from community summaries successful", + retrieved_docs=input.retrieved_docs, + documents=input.documents, + ) + + statistics_dict["opea_service@retriever_community_answers_neo4j"].append_latency(time.time() - start, None) + + if logflag: + logger.info(result) + return result + + +if __name__ == "__main__": + opea_microservices["opea_service@retriever_community_answers_neo4j"].start() diff --git a/comps/retrievers/neo4j/llama_index/set_env.sh b/comps/retrievers/neo4j/llama_index/set_env.sh new file mode 100644 index 0000000000..dcaaad0fbe --- /dev/null +++ b/comps/retrievers/neo4j/llama_index/set_env.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +# Remember to set your private variables mentioned in README +# host_ip, OPENAI_API_KEY, HUGGINGFACEHUB_API_TOKEN, proxies... + +export EMBEDDING_MODEL_ID="BAAI/bge-base-en-v1.5" +export OPENAI_EMBEDDING_MODEL="text-embedding-3-small" +export LLM_MODEL_ID="meta-llama/Meta-Llama-3-8B-Instruct" +export OPENAI_LLM_MODEL="gpt-4o" +export TEI_EMBEDDING_ENDPOINT="http://${host_ip}:6006" +export TGI_LLM_ENDPOINT="http://${host_ip}:6005" +export NEO4J_URL="bolt://${host_ip}:7687" +export NEO4J_USERNAME=neo4j +export DATAPREP_SERVICE_ENDPOINT="http://${host_ip}:6004/v1/dataprep" +export LOGFLAG=True diff --git a/tests/dataprep/test_dataprep_neo4j_llama_index_on_intel_hpu.sh b/tests/dataprep/test_dataprep_neo4j_llama_index_on_intel_hpu.sh new file mode 100755 index 0000000000..81b716993d --- /dev/null +++ b/tests/dataprep/test_dataprep_neo4j_llama_index_on_intel_hpu.sh @@ -0,0 +1,162 @@ +#!/usr/bin/env bash +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +set -x + +WORKPATH=$(dirname "$PWD") +LOG_PATH="$WORKPATH/tests" +ip_address=$(hostname -I | awk '{print $1}') + +function build_docker_images() { + cd $WORKPATH + echo $(pwd) + docker build --no-cache -t opea/dataprep-neo4j-llamaindex:comps --build-arg no_proxy=$no_proxy --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/dataprep/neo4j/llama_index/Dockerfile . + if [ $? 
-ne 0 ]; then + echo "opea/dataprep-neo4j-llamaindex build failed" + exit 1 + else + echo "opea/dataprep-neo4j-llamaindex built successfully" + fi + docker pull ghcr.io/huggingface/tgi-gaudi:2.0.5 + docker pull ghcr.io/huggingface/text-embeddings-inference:cpu-1.5 +} + +function start_service() { + # neo4j-apoc + docker run -d -p 7474:7474 -p 7687:7687 --name test-comps-neo4j-apoc --env NEO4J_AUTH=neo4j/neo4jtest -e NEO4J_apoc_export_file_enabled=true -e NEO4J_apoc_import_file_enabled=true -e NEO4J_apoc_import_file_use__neo4j__config=true -e NEO4J_PLUGINS=\[\"apoc\"\] neo4j:latest + #sleep 30s + + # tei endpoint + emb_model="BAAI/bge-base-en-v1.5" + docker run -d --name="test-comps-dataprep-neo4j-tei-endpoint" -p 6006:80 -v ./data:/data -e no_proxy=$no_proxy -e http_proxy=$http_proxy \ + -e https_proxy=$https_proxy -e HF_TOKEN=$HF_TOKEN --pull always ghcr.io/huggingface/text-embeddings-inference:cpu-1.5 --model-id $emb_model + sleep 30s + export TEI_EMBEDDING_ENDPOINT="http://${ip_address}:6006" + + # tgi gaudi endpoint + model="meta-llama/Meta-Llama-3-8B-Instruct" + docker run -d --name="test-comps-dataprep-neo4j-tgi-endpoint" -p 6005:80 -v ./data:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all \ + -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e HF_TOKEN=$HF_TOKEN -e ENABLE_HPU_GRAPH=true -e LIMIT_HPU_GRAPH=true \ + -e USE_FLASH_ATTENTION=true -e FLASH_ATTENTION_RECOMPUTE=true --cap-add=sys_nice -e no_proxy=$no_proxy -e http_proxy=$http_proxy -e https_proxy=$https_proxy \ + --ipc=host --pull always ghcr.io/huggingface/tgi-gaudi:2.0.5 --model-id $model --max-input-tokens 1024 --max-total-tokens 3000 + sleep 30s + export TGI_LLM_ENDPOINT="http://${ip_address}:6005" + + # dataprep neo4j + # Not testing openai code path since not able to provide key for cicd + docker run -d --name="test-comps-dataprep-neo4j-server" -p 6004:6004 -v ./data:/data --ipc=host -e TGI_LLM_ENDPOINT=$TGI_LLM_ENDPOINT \ + -e TEI_EMBEDDING_ENDPOINT=$TEI_EMBEDDING_ENDPOINT -e EMBEDDING_MODEL_ID=$emb_model -e LLM_MODEL_ID=$model -e host_ip=$ip_address -e no_proxy=$no_proxy \ + -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e NEO4J_URI="bolt://${ip_address}:7687" -e NEO4J_USERNAME="neo4j" \ + -e NEO4J_PASSWORD="neo4jtest" -e LOGFLAG=True opea/dataprep-neo4j-llamaindex:comps + sleep 30s + export DATAPREP_SERVICE_ENDPOINT="http://${ip_address}:6004" + +} + +function validate_service() { + local URL="$1" + local EXPECTED_RESULT="$2" + local SERVICE_NAME="$3" + local DOCKER_NAME="$4" + local INPUT_DATA="$5" + + if [[ $SERVICE_NAME == *"extract_graph_neo4j"* ]]; then + cd $LOG_PATH + HTTP_RESPONSE=$(curl --silent --write-out "HTTPSTATUS:%{http_code}" -X POST -F 'files=@./dataprep_file.txt' -H 'Content-Type: multipart/form-data' "$URL") + elif [[ $SERVICE_NAME == *"neo4j-apoc"* ]]; then + HTTP_RESPONSE=$(curl --silent --write-out "HTTPSTATUS:%{http_code}" "$URL") + else + HTTP_RESPONSE=$(curl --silent --write-out "HTTPSTATUS:%{http_code}" -X POST -d "$INPUT_DATA" -H 'Content-Type: application/json' "$URL") + fi + HTTP_STATUS=$(echo $HTTP_RESPONSE | tr -d '\n' | sed -e 's/.*HTTPSTATUS://') + RESPONSE_BODY=$(echo $HTTP_RESPONSE | sed -e 's/HTTPSTATUS\:.*//g') + + docker logs ${DOCKER_NAME} >> ${LOG_PATH}/${SERVICE_NAME}.log + + # check response status + if [ "$HTTP_STATUS" -ne "200" ]; then + echo "[ $SERVICE_NAME ] HTTP status is not 200. Received status was $HTTP_STATUS" + exit 1 + else + echo "[ $SERVICE_NAME ] HTTP status is 200. Checking content..." 
+ fi + # check response body + if [[ "$SERVICE_NAME" == "neo4j-apoc" ]]; then + echo "[ $SERVICE_NAME ] Skipping content check for neo4j-apoc." + else + if [[ "$RESPONSE_BODY" != *"$EXPECTED_RESULT"* ]]; then + echo "[ $SERVICE_NAME ] Content does not match the expected result: $RESPONSE_BODY" + exit 1 + else + echo "[ $SERVICE_NAME ] Content is as expected." + fi + fi + + sleep 5s +} + +function validate_microservice() { + # validate neo4j-apoc + validate_service \ + "${ip_address}:7474" \ + "200 OK" \ + "neo4j-apoc" \ + "test-comps-neo4j-apoc" \ + "" + sleep 1m # retrieval can't curl as expected, try to wait for more time + # tgi for llm service + validate_service \ + "${ip_address}:6005/generate" \ + "generated_text" \ + "tgi-gaudi-service" \ + "test-comps-dataprep-neo4j-tgi-endpoint" \ + '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17, "do_sample": true}}' + + # test /v1/dataprep graph extraction + echo "Like many companies in the O&G sector, the stock of Chevron (NYSE:CVX) has declined about 10% over the past 90-days despite the fact that Q2 consensus earnings estimates have risen sharply (~25%) during that same time frame. Over the years, Chevron has kept a very strong balance sheet. FirstEnergy (NYSE:FE – Get Rating) posted its earnings results on Tuesday. The utilities provider reported $0.53 earnings per share for the quarter, topping the consensus estimate of $0.52 by $0.01, RTT News reports. FirstEnergy had a net margin of 10.85% and a return on equity of 17.17%. The Dáil was almost suspended on Thursday afternoon after Sinn Féin TD John Brady walked across the chamber and placed an on-call pager in front of the Minister for Housing Darragh O’Brien during a debate on retained firefighters. Mr O’Brien said Mr Brady had taken part in an act of theatre that was obviously choreographed.Around 2,000 retained firefighters around the country staged a second day of industrial action on Tuesday and are due to start all out-strike action from next Tuesday. The mostly part-time workers, who keep the services going outside of Ireland’s larger urban centres, are taking industrial action in a dispute over pay and working conditions. Speaking in the Dáil, Sinn Féin deputy leader Pearse Doherty said firefighters had marched on Leinster House today and were very angry at the fact the Government will not intervene. Reintroduction of tax relief on mortgages needs to be considered, O’Brien says. Martin withdraws comment after saying People Before Profit would ‘put the jackboot on people’ Taoiseach ‘propagated fears’ farmers forced to rewet land due to nature restoration law – Cairns An intervention is required now. I’m asking you to make an improved offer in relation to pay for retained firefighters, Mr Doherty told the housing minister.I’m also asking you, and challenging you, to go outside after this Order of Business and meet with the firefighters because they are just fed up to the hilt in relation to what you said.Some of them have handed in their pagers to members of the Opposition and have challenged you to wear the pager for the next number of weeks, put up with an €8,600 retainer and not leave your community for the two and a half kilometres and see how you can stand over those type of pay and conditions. At this point, Mr Brady got up from his seat, walked across the chamber and placed the pager on the desk in front of Mr O’Brien. 
Ceann Comhairle Seán Ó Fearghaíl said the Sinn Féin TD was completely out of order and told him not to carry out a charade in this House, adding it was absolutely outrageous behaviour and not to be encouraged.Mr O’Brien said Mr Brady had engaged in an act of theatre here today which was obviously choreographed and was then interrupted with shouts from the Opposition benches. Mr Ó Fearghaíl said he would suspend the House if this racket continues.Mr O’Brien later said he said he was confident the dispute could be resolved and he had immense regard for firefighters. The minister said he would encourage the unions to re-engage with the State’s industrial relations process while also accusing Sinn Féin of using the issue for their own political gain." > $LOG_PATH/dataprep_file.txt + validate_service \ + "http://${ip_address}:6004/v1/dataprep" \ + "Data preparation succeeded" \ + "extract_graph_neo4j" \ + "test-comps-dataprep-neo4j-server" + +} +function kill_process_on_port() { + local port=$1 + local pid=$(lsof -t -i:$port) + if [[ ! -z "$pid" ]]; then + echo "Killing process $pid on port $port" + kill -9 $pid + else + echo "No process found on port $port" + fi +} + +function stop_docker() { + cid_retrievers=$(docker ps -aq --filter "name=test-comps-dataprep-neo4j*") + if [[ ! -z "$cid_retrievers" ]]; then + docker stop $cid_retrievers && docker rm $cid_retrievers && sleep 1s + fi + cid_db=$(docker ps -aq --filter "name=test-comps-neo4j-apoc") + if [[ ! -z "$cid_db" ]]; then + docker stop $cid_db && docker rm $cid_db && sleep 1s + fi +} + +function main() { + kill_process_on_port 6006 + + stop_docker + + build_docker_images + start_service + + validate_microservice + + stop_docker + echo y | docker system prune + +} + +main diff --git a/tests/retrievers/test_retrievers_neo4j_llama_index_on_intel_hpu.sh b/tests/retrievers/test_retrievers_neo4j_llama_index_on_intel_hpu.sh new file mode 100755 index 0000000000..e4606bde8b --- /dev/null +++ b/tests/retrievers/test_retrievers_neo4j_llama_index_on_intel_hpu.sh @@ -0,0 +1,192 @@ +#!/usr/bin/env bash +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +set -x + +WORKPATH=$(dirname "$PWD") +LOG_PATH="$WORKPATH/tests" +ip_address=$(hostname -I | awk '{print $1}') + +function build_docker_images() { + cd $WORKPATH + echo "current dir: $PWD" + docker build --no-cache -t opea/retriever-neo4j-llamaindex:comps --build-arg no_proxy=$no_proxy --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/retrievers/neo4j/llama_index/Dockerfile . + if [ $? -ne 0 ]; then + echo "opea/retriever-neo4j-llamaindex build failed" + exit 1 + else + echo "opea/retriever-neo4j-llamaindex built successfully" + fi + + docker build --no-cache -t opea/dataprep-neo4j-llamaindex:comps --build-arg no_proxy=$no_proxy --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/dataprep/neo4j/llama_index/Dockerfile . + if [ $? 
-ne 0 ]; then + echo "opea/dataprep-neo4j-llamaindex build failed" + exit 1 + else + echo "opea/dataprep-neo4j-llamaindex built successfully" + fi + docker pull ghcr.io/huggingface/tgi-gaudi:2.0.5 + docker pull ghcr.io/huggingface/text-embeddings-inference:cpu-1.5 +} + +function start_service() { + # neo4j-apoc + docker run -d -p 7474:7474 -p 7687:7687 --name test-comps-retrievers-neo4j-llama-index-neo4j-apoc --env NEO4J_AUTH=neo4j/neo4jtest -e NEO4J_apoc_export_file_enabled=true -e NEO4J_apoc_import_file_enabled=true -e NEO4J_apoc_import_file_use__neo4j__config=true -e NEO4J_PLUGINS=\[\"apoc\"\] neo4j:latest + + # tei endpoint + emb_model="BAAI/bge-base-en-v1.5" + docker run -d --name="test-comps-retrievers-neo4j-llama-index-tei" -p 6006:80 -v ./data:/data -e no_proxy=$no_proxy -e http_proxy=$http_proxy \ + -e https_proxy=$https_proxy -e HF_TOKEN=$HF_TOKEN --pull always ghcr.io/huggingface/text-embeddings-inference:cpu-1.5 --model-id $emb_model + sleep 30s + export TEI_EMBEDDING_ENDPOINT="http://${ip_address}:6006" + + # tgi gaudi endpoint + # Meta-Llama-3-8B-Instruct IS NOT GOOD ENOUGH FOR EXTRACTING HIGH QUALITY GRAPH BUT OK FOR CI TESTING + model="meta-llama/Meta-Llama-3-8B-Instruct" + docker run -d --name="test-comps-retrievers-neo4j-llama-index-tgi" -p 6005:80 -v ./data:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all \ + -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e HF_TOKEN=$HF_TOKEN -e ENABLE_HPU_GRAPH=true -e LIMIT_HPU_GRAPH=true \ + -e USE_FLASH_ATTENTION=true -e FLASH_ATTENTION_RECOMPUTE=true --cap-add=sys_nice -e no_proxy=$no_proxy -e http_proxy=$http_proxy -e https_proxy=$https_proxy \ + --ipc=host --pull always ghcr.io/huggingface/tgi-gaudi:2.0.5 --model-id $model --max-input-tokens 1024 --max-total-tokens 3000 + # extra time to load large model + echo "Waiting for tgi gaudi ready" + n=0 + until [[ "$n" -ge 300 ]]; do + docker logs test-comps-retrievers-neo4j-llama-index-tgi &> ${LOG_PATH}/tgi-gaudi-service.log + n=$((n+1)) + if grep -q Connected ${LOG_PATH}/tgi-gaudi-service.log; then + break + fi + sleep 5s + done + sleep 5s + echo "Service started successfully" + export TGI_LLM_ENDPOINT="http://${ip_address}:6005" + + # dataprep neo4j + # Not testing openai code path since not able to provide key for cicd + docker run -d --name="test-comps-retrievers-neo4j-llama-index-dataprep" -p 6004:6004 -v ./data:/data --ipc=host -e TGI_LLM_ENDPOINT=$TGI_LLM_ENDPOINT \ + -e TEI_EMBEDDING_ENDPOINT=$TEI_EMBEDDING_ENDPOINT -e EMBEDDING_MODEL_ID=$emb_model -e LLM_MODEL_ID=$model -e host_ip=$ip_address -e no_proxy=$no_proxy \ + -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e NEO4J_URI="bolt://${ip_address}:7687" -e NEO4J_USERNAME="neo4j" \ + -e NEO4J_PASSWORD="neo4jtest" -e LOGFLAG=True opea/dataprep-neo4j-llamaindex:comps + sleep 30s + export DATAPREP_SERVICE_ENDPOINT="http://${ip_address}:6004" + + # Neo4J retriever + # Not testing openai code path since not able to provide key for cicd + export NEO4J_URI="bolt://${ip_address}:7687" + export NEO4J_USERNAME="neo4j" + export NEO4J_PASSWORD="neo4jtest" + export no_proxy="localhost,127.0.0.1,"${ip_address} + docker run -d --name="test-comps-retrievers-neo4j-llama-index-server" -p 6009:6009 --ipc=host -e TGI_LLM_ENDPOINT=$TGI_LLM_ENDPOINT -e TEI_EMBEDDING_ENDPOINT=$TEI_EMBEDDING_ENDPOINT \ + -e EMBEDDING_MODEL_ID=$emb_model -e LLM_MODEL_ID=$model -e host_ip=$ip_address -e http_proxy=$http_proxy -e no_proxy=$no_proxy -e https_proxy=$https_proxy \ + -e NEO4J_URL="bolt://${ip_address}:7687" -e 
NEO4J_USERNAME="neo4j" -e NEO4J_PASSWORD="neo4jtest" -e LOGFLAG=True opea/retriever-neo4j-llamaindex:comps + + sleep 1m + +} + +function validate_service() { + local URL="$1" + local EXPECTED_RESULT="$2" + local SERVICE_NAME="$3" + local DOCKER_NAME="$4" + local INPUT_DATA="$5" + + if [[ $SERVICE_NAME == *"extract_graph_neo4j"* ]]; then + cd $LOG_PATH + HTTP_RESPONSE=$(curl --silent --write-out "HTTPSTATUS:%{http_code}" -X POST -F 'files=@./dataprep_file.txt' -H 'Content-Type: multipart/form-data' "$URL") + elif [[ $SERVICE_NAME == *"neo4j-apoc"* ]]; then + HTTP_RESPONSE=$(curl --silent --write-out "HTTPSTATUS:%{http_code}" "$URL") + else + HTTP_RESPONSE=$(curl --silent --write-out "HTTPSTATUS:%{http_code}" -X POST -d "$INPUT_DATA" -H 'Content-Type: application/json' "$URL") + fi + HTTP_STATUS=$(echo $HTTP_RESPONSE | tr -d '\n' | sed -e 's/.*HTTPSTATUS://') + RESPONSE_BODY=$(echo $HTTP_RESPONSE | sed -e 's/HTTPSTATUS\:.*//g') + + docker logs ${DOCKER_NAME} >> ${LOG_PATH}/${SERVICE_NAME}.log + + # check response status + if [ "$HTTP_STATUS" -ne "200" ]; then + echo "[ $SERVICE_NAME ] HTTP status is not 200. Received status was $HTTP_STATUS" + exit 1 + else + echo "[ $SERVICE_NAME ] HTTP status is 200. Checking content..." + fi + # check response body + if [[ "$SERVICE_NAME" == "neo4j-apoc" ]]; then + echo "[ $SERVICE_NAME ] Skipping content check for neo4j-apoc." + else + if [[ "$RESPONSE_BODY" != *"$EXPECTED_RESULT"* ]]; then + echo "[ $SERVICE_NAME ] Content does not match the expected result: $RESPONSE_BODY" + exit 1 + else + echo "[ $SERVICE_NAME ] Content is as expected." + fi + fi + + sleep 1s +} + +function validate_microservice() { + # validate neo4j-apoc + validate_service \ + "${ip_address}:7474" \ + "200 OK" \ + "neo4j-apoc" \ + "test-comps-retrievers-neo4j-llama-index-neo4j-apoc" \ + "" + sleep 1m # retrieval can't curl as expected, try to wait for more time + + # tgi for llm service + validate_service \ + "${ip_address}:6005/generate" \ + "generated_text" \ + "tgi-gaudi-service" \ + "test-comps-retrievers-neo4j-llama-index-tgi" \ + '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17, "do_sample": true}}' + + # test /v1/dataprep graph extraction + echo "The stock of company Chevron has declined about 10% over the past 90-days despite the fact that Q2 consensus earnings estimates have risen sharply (~25%) during that same time frame. Over the years, Chevron has kept a very strong balance sheet. FirstEnergy company posted its earnings results on Tuesday. The utilities provider reported $0.53 earnings per share for the quarter, topping the consensus estimate of $0.52 by $0.01, RTT News reports. FirstEnergy had a net margin of 10.85% and a return on equity of 17.17%. The Dáil was almost suspended on Thursday afternoon after Sinn Féin TD John Brady walked across the chamber and placed an on-call pager in front of the Minister for Housing Darragh O’Brien during a debate on retained firefighters. Darragh O’Brien said John Brady had taken part in an act of theatre that was obviously choreographed. Around 2,000 retained firefighters around the country staged a second day of industrial action on Tuesday and are due to start all out-strike action from next Tuesday. The mostly part-time workers, who keep the services going outside of Ireland’s larger urban centres, are taking industrial action in a dispute over pay and working conditions. 
Speaking in the Dáil, Sinn Féin deputy leader Pearse Doherty said firefighters had marched on Leinster House today and were very angry at the fact the Government will not intervene. Reintroduction of tax relief on mortgages needs to be considered, Darragh O’Brien says. Martin withdraws comment after saying People Before Profit would ‘put the jackboot on people’ Taoiseach ‘propagated fears’ farmers forced to rewet land due to nature restoration law – Cairns An intervention is required now. I’m asking you to make an improved offer in relation to pay for retained firefighters, Mr Doherty told the housing minister. I’m also asking you, and challenging you, to go outside after this Order of Business and meet with the firefighters because they are just fed up to the hilt in relation to what you said. Some of them have handed in their pagers to members of the Opposition and have challenged you to wear the pager for the next number of weeks, put up with an €8,600 retainer and not leave your community for the two and a half kilometres and see how you can stand over those type of pay and conditions. At this point, John Brady got up from his seat, walked across the chamber and placed the pager on the desk in front of Darragh O’Brien. Ceann Comhairle Seán Ó Fearghaíl said the Sinn Féin TD was completely out of order and told him not to carry out a charade in this House, adding it was absolutely outrageous behaviour and not to be encouraged. Darragh O’Brien said John Brady had engaged in an act of theatre here today which was obviously choreographed and was then interrupted with shouts from the Opposition benches. Mr Ó Fearghaíl said he would suspend the House if this racket continues. Darragh O’Brien later said he was confident the dispute could be resolved and he had immense regard for firefighters. The minister said he would encourage the unions to re-engage with the State’s industrial relations process while also accusing Sinn Féin of using the issue for their own political gain." > $LOG_PATH/dataprep_file.txt + validate_service \ + "http://${ip_address}:6004/v1/dataprep" \ + "Data preparation succeeded" \ + "extract_graph_neo4j" \ + "test-comps-retrievers-neo4j-llama-index-dataprep" + + # retrieval microservice + validate_service \ + "${ip_address}:6009/v1/retrieval" \ + "Retrieval of answers from community summaries successful" \ + "retriever_community_answers_neo4j" \ + "test-comps-retrievers-neo4j-llama-index-server" \ + "{\"messages\": [{\"role\": \"user\",\"content\": \"Who is John Brady and has he had any confrontations?\"}]}" + +} + +function stop_docker() { + cid_retrievers=$(docker ps -aq --filter "name=test-comps-retrievers-neo4j*") + if [[ ! -z "$cid_retrievers" ]]; then + docker stop $cid_retrievers && docker rm $cid_retrievers && sleep 1s + fi + cid_db=$(docker ps -aq --filter "name=test-comps-retrievers-neo4j-llama-index-neo4j-apoc") + if [[ ! -z "$cid_db" ]]; then + docker stop $cid_db && docker rm $cid_db && sleep 1s + fi +} + +function main() { + + stop_docker + + build_docker_images + start_service + + validate_microservice + + stop_docker + echo y | docker system prune + +} + +main
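
For reviewers who want a quick manual check outside of CI, a minimal smoke test of the two new endpoints could look like the sketch below. It assumes the dataprep/retriever services are already running (for example via the compose.yaml files in this patch), that host_ip is exported as in set_env.sh, and that sample.txt is any small text file you supply; the ports and payload shapes mirror the test scripts above, everything else is illustrative.

# 1) Ingest a document and build the property graph (dataprep microservice, port 6004)
curl -X POST -F 'files=@./sample.txt' -H 'Content-Type: multipart/form-data' \
    "http://${host_ip}:6004/v1/dataprep"

# 2) Query the GraphRAG retriever (port 6009) with a ChatCompletionRequest-style payload
curl -X POST -H 'Content-Type: application/json' \
    -d '{"messages": [{"role": "user", "content": "Who is John Brady and has he had any confrontations?"}]}' \
    "http://${host_ip}:6009/v1/retrieval"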