Variables for OpenAI API and Model (#1172)
Created variables for the OpenAI base URL and the model to use. This allows
pointing at other services that adhere to the OpenAI API spec, as well as
testing additional models that work well with the function-calling requests.

I noticed that different model versions were used throughout the code
(e.g. `gpt-3.5-turbo-1106` vs. `gpt-3.5-turbo-0613`). This patch consolidates
them into a single model for all locations. If different models were chosen
for performance reasons, this can be separated out again.
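
A minimal sketch of the pattern this enables, assuming an OpenAI-compatible endpoint and the openai v1 client (the local-server hint, fallback defaults, and direct `os.environ` lookups here are illustrative; the patch itself reads these settings through lilac's `env()` helper):

```python
import os

import openai

# The two new settings; lilac reads them via env('OPENAI_API_BASE') / env('API_MODEL').
api_base = os.environ.get('OPENAI_API_BASE', 'https://api.openai.com/v1/')
api_model = os.environ.get('API_MODEL', 'gpt-3.5-turbo-1106')

# base_url points the client at any server that speaks the OpenAI API,
# e.g. a local vLLM or llama.cpp instance, instead of api.openai.com.
client = openai.OpenAI(
  api_key=os.environ.get('OPENAI_API_KEY', 'not-needed-for-local-servers'),
  base_url=api_base,
)

response = client.chat.completions.create(
  model=api_model,
  messages=[{'role': 'user', 'content': 'Say hello.'}],
)
print(response.choices[0].message.content)
```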
drikster80 authored Feb 13, 2024
1 parent f166687 commit 086bc97
Showing 5 changed files with 28 additions and 10 deletions.
7 changes: 7 additions & 0 deletions .env
@@ -13,9 +13,16 @@
# Get key from https://dashboard.cohere.ai/api-keys
# COHERE_API_KEY=

OPENAI_API_BASE=https://api.openai.com/v1/
API_MODEL='gpt-3.5-turbo-1106'

# Get key from https://platform.openai.com/account/api-keys
# OPENAI_API_KEY=

# If you are using a different inference engine that adheres to the OpenAI API spec, set it here
# OPENAI_API_BASE='http://127.0.0.1:8080/v1'
# OPENAI_VERSION='2023-05-15'

# Get key from https://makersuite.google.com/app/apikey

# HuggingFace demos: machine that uploads to HuggingFace.
11 changes: 8 additions & 3 deletions lilac/data/cluster_titling.py
@@ -24,6 +24,7 @@
)
from ..tasks import TaskInfo
from ..utils import chunks, log
from ..env import env

_TOP_K_CENTRAL_DOCS = 7
_TOP_K_CENTRAL_TITLES = 20
@@ -184,9 +185,11 @@ def request_with_retries() -> list[str]:
@functools.cache
def _openai_client() -> Any:
"""Get an OpenAI client."""
api_base = env('OPENAI_API_BASE')
try:
import openai


except ImportError:
raise ImportError(
'Could not import the "openai" python package. '
@@ -196,7 +199,7 @@ def _openai_client() -> Any:
# OpenAI requests sometimes hang, without any errors, and the default connection timeout is 10
# mins, which is too long. Set it to 7 seconds (99%-tile for latency is 3-4 sec). Also set
# `max_retries` to 0 to disable internal retries so we handle retries ourselves.
return instructor.patch(openai.OpenAI(timeout=7, max_retries=0))
return instructor.patch(openai.OpenAI(timeout=7, max_retries=0, base_url=api_base))


class Title(BaseModel):
@@ -234,11 +237,12 @@ def generate_title_openai(ranked_docs: list[tuple[str, float]]) -> str:
stop=stop_after_attempt(_NUM_RETRIES),
)
def request_with_retries() -> str:
api_model = env('API_MODEL')
max_tokens = _OPENAI_INITIAL_MAX_TOKENS
while max_tokens <= _OPENAI_FINAL_MAX_TOKENS:
try:
title = _openai_client().chat.completions.create(
model='gpt-3.5-turbo-1106',
model=api_model,
response_model=Title,
temperature=0.0,
max_tokens=max_tokens,
@@ -295,11 +299,12 @@ def generate_category_openai(ranked_docs: list[tuple[str, float]]) -> str:
stop=stop_after_attempt(_NUM_RETRIES),
)
def request_with_retries() -> str:
api_model = env('API_MODEL')
max_tokens = _OPENAI_INITIAL_MAX_TOKENS
while max_tokens <= _OPENAI_FINAL_MAX_TOKENS:
try:
category = _openai_client().chat.completions.create(
model='gpt-3.5-turbo-1106',
model=api_model,
response_model=Category,
temperature=0.0,
max_tokens=max_tokens,
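
For the cluster-titling path above, the same two variables flow into the instructor-patched client and the `model=` argument. A rough standalone sketch of that flow, assuming the openai v1 client and the instructor package (the single `title` field and the prompt text are placeholders, not the exact schema lilac uses):

```python
import os

import instructor
import openai
from pydantic import BaseModel


class Title(BaseModel):
  """Structured response the patched client parses the completion into."""

  title: str


# Same settings the patch reads via lilac's env(); defaults shown for illustration.
api_base = os.environ.get('OPENAI_API_BASE', 'https://api.openai.com/v1/')
api_model = os.environ.get('API_MODEL', 'gpt-3.5-turbo-1106')

# Mirrors _openai_client(): short timeout, no internal retries, custom base URL.
client = instructor.patch(openai.OpenAI(timeout=7, max_retries=0, base_url=api_base))

title = client.chat.completions.create(
  model=api_model,
  response_model=Title,
  temperature=0.0,
  max_tokens=50,
  messages=[{'role': 'user', 'content': 'Give these documents a short title: ...'}],
)
print(title.title)
```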
3 changes: 2 additions & 1 deletion lilac/embeddings/openai.py
@@ -39,6 +39,7 @@ class OpenAIEmbedding(TextEmbeddingSignal):
@override
def setup(self) -> None:
api_key = env('OPENAI_API_KEY')
api_base = env('OPENAI_API_BASE')
azure_api_key = env('AZURE_OPENAI_KEY')
azure_api_version = env('AZURE_OPENAI_VERSION')
azure_api_endpoint = env('AZURE_OPENAI_ENDPOINT')
@@ -64,7 +65,7 @@ def setup(self) -> None:

else:
if api_key:
self._client = openai.OpenAI(api_key=api_key)
self._client = openai.OpenAI(api_key=api_key, base_url=api_base)
self._azure = False

elif azure_api_key:
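
A side note on the embedding client above: with the openai v1 SDK, a `base_url` of `None` (i.e. `OPENAI_API_BASE` unset) falls back to the `OPENAI_BASE_URL` environment variable and then to the stock endpoint, so existing deployments should be unaffected. A quick check of that assumption:

```python
import openai

# Dummy key so the constructor doesn't complain; no request is made.
default_client = openai.OpenAI(api_key='sk-test', base_url=None)
custom_client = openai.OpenAI(api_key='sk-test', base_url='http://127.0.0.1:8080/v1')

print(default_client.base_url)  # the stock api.openai.com endpoint
print(custom_client.base_url)   # the custom OpenAI-compatible endpoint
```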
10 changes: 6 additions & 4 deletions lilac/gen/generator_openai.py
@@ -15,8 +15,7 @@

class OpenAIChatCompletionGenerator(TextGenerator):
"""An interface for OpenAI chat completion."""

model: str = 'gpt-3.5-turbo-0613'
model: str = env('API_MODEL')
response_description: str = ''

@override
@@ -25,11 +24,14 @@ def generate(self, prompt: str) -> str:
api_key = env('OPENAI_API_KEY')
api_type = env('OPENAI_API_TYPE')
api_version = env('OPENAI_API_VERSION')
api_base = env('OPENAI_API_BASE')
api_model = env('API_MODEL')
if not api_key:
raise ValueError('`OPENAI_API_KEY` environment variable not set.')

try:
import openai

except ImportError:
raise ImportError(
'Could not import the "openai" python package. '
@@ -41,15 +43,15 @@ def generate(self, prompt: str) -> str:
openai.api_version = api_version

# Enables response_model in the openai client.
client = instructor.patch(openai.OpenAI())
client = instructor.patch(openai.OpenAI(base_url=api_base))

class Completion(OpenAISchema):
"""Generated completion of a prompt."""

completion: str = Field(..., description=self.response_description)

return client.chat.completions.create(
model='gpt-3.5-turbo',
model=api_model,
response_model=Completion,
messages=[
{
7 changes: 5 additions & 2 deletions lilac/router_concept.py
@@ -248,6 +248,8 @@ def generate_examples(description: str) -> list[str]:
api_type = env('OPENAI_API_TYPE')
api_version = env('OPENAI_API_VERSION')
api_engine = env('OPENAI_API_ENGINE_CHAT')
api_base = env('OPENAI_API_BASE')
api_model = env('API_MODEL')
if not api_key:
raise ValueError('`OPENAI_API_KEY` environment variable not set.')
try:
@@ -262,16 +264,17 @@
openai.api_key = api_key
api_engine = api_engine


if api_type:
openai.api_type = api_type
openai.api_version = api_version

try:
# Enables response_model in the openai client.
client = instructor.patch(openai.OpenAI())
client = instructor.patch(openai.OpenAI(base_url=api_base))

completion = client.chat.completions.create(
model='gpt-3.5-turbo-1106',
model=api_model,
response_model=Examples,
temperature=0.0,
messages=[
