diff --git a/comps/guardrails/llama_guard/README.md b/comps/guardrails/llama_guard/README.md
index 94bdcd9524..084711125a 100644
--- a/comps/guardrails/llama_guard/README.md
+++ b/comps/guardrails/llama_guard/README.md
@@ -36,7 +36,7 @@ pip install -r requirements.txt
 export HF_TOKEN=${your_hf_api_token}
 export LANGCHAIN_TRACING_V2=true
 export LANGCHAIN_API_KEY=${your_langchain_api_key}
-export LANGCHAIN_PROJECT="opea/gaurdrails"
+export LANGCHAIN_PROJECT="opea/guardrails"
 volume=$PWD/data
 model_id="meta-llama/Meta-Llama-Guard-2-8B"
 docker pull ghcr.io/huggingface/tgi-gaudi:2.0.1
diff --git a/comps/guardrails/llama_guard/guardrails_tgi.py b/comps/guardrails/llama_guard/guardrails_tgi.py
index 96a89b8c8a..b415876edd 100644
--- a/comps/guardrails/llama_guard/guardrails_tgi.py
+++ b/comps/guardrails/llama_guard/guardrails_tgi.py
@@ -2,13 +2,14 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
+from typing import List, Union
 
 from langchain_community.utilities.requests import JsonRequestsWrapper
 from langchain_huggingface import ChatHuggingFace
 from langchain_huggingface.llms import HuggingFaceEndpoint
 from langsmith import traceable
 
-from comps import ServiceType, TextDoc, opea_microservices, register_microservice
+from comps import GeneratedDoc, ServiceType, TextDoc, opea_microservices, register_microservice
 
 DEFAULT_MODEL = "meta-llama/LlamaGuard-7b"
 
@@ -59,12 +60,17 @@ def get_tgi_service_model_id(endpoint_url, default=DEFAULT_MODEL):
     endpoint="/v1/guardrails",
     host="0.0.0.0",
     port=9090,
-    input_datatype=TextDoc,
+    input_datatype=Union[GeneratedDoc, TextDoc],
     output_datatype=TextDoc,
 )
 @traceable(run_type="llm")
-def safety_guard(input: TextDoc) -> TextDoc:
-    response_input_guard = llm_engine_hf.invoke([{"role": "user", "content": input.text}]).content
+def safety_guard(input: Union[GeneratedDoc, TextDoc]) -> TextDoc:
+    if isinstance(input, GeneratedDoc):
+        messages = [{"role": "user", "content": input.prompt}, {"role": "assistant", "content": input.text}]
+    else:
+        messages = [{"role": "user", "content": input.text}]
+    response_input_guard = llm_engine_hf.invoke(messages).content
+
     if "unsafe" in response_input_guard:
         unsafe_dict = get_unsafe_dict(llm_engine_hf.model_id)
         policy_violation_level = response_input_guard.split("\n")[1].strip()
@@ -75,7 +81,6 @@ def safety_guard(input: TextDoc) -> TextDoc:
         )
     else:
         res = TextDoc(text=input.text)
-
     return res
 
 
diff --git a/comps/guardrails/llama_guard/requirements.txt b/comps/guardrails/llama_guard/requirements.txt
index 5fd992e663..5eda601705 100644
--- a/comps/guardrails/llama_guard/requirements.txt
+++ b/comps/guardrails/llama_guard/requirements.txt
@@ -1,6 +1,7 @@
 docarray[full]
 fastapi
-huggingface_hub
+# Fix for issue with langchain-huggingface not using InferenceClient `base_url` kwarg
+huggingface-hub<=0.24.0
 langchain-community
 langchain-huggingface
 langsmith
diff --git a/tests/test_guardrails_llama_guard.sh b/tests/test_guardrails_llama_guard.sh
index 10b6465780..b174cbd644 100644
--- a/tests/test_guardrails_llama_guard.sh
+++ b/tests/test_guardrails_llama_guard.sh
@@ -25,7 +25,6 @@ function start_service() {
     sleep 4m
     docker run -d --name="test-comps-guardrails-langchain-service" -p 9090:9090 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e no_proxy=$no_proxy -e SAFETY_GUARD_MODEL_ID=$SAFETY_GUARD_MODEL_ID -e SAFETY_GUARD_ENDPOINT=$SAFETY_GUARD_ENDPOINT -e HUGGINGFACEHUB_API_TOKEN=$HF_TOKEN opea/guardrails-tgi:comps
     sleep 10s
-    echo "Microservice started"
 }