main.py
import logging
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from typing import List, Optional, Union
from pydantic import BaseModel
import uvicorn
import os
from dotenv import load_dotenv
from openai import OpenAI
from openinference.instrumentation.openai import OpenAIInstrumentor
from phoenix.otel import register
load_dotenv()
# Configure logging
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
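# Register an OpenTelemetry tracer provider pointed at a local Phoenix collector,
# then auto-instrument the OpenAI client so each completion call is traced.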
tracer_provider = register(
    project_name="local-llm-trace",  # Default is 'default'
    endpoint="http://localhost:6006/v1/traces",
)
OpenAIInstrumentor().instrument(tracer_provider=tracer_provider)
app = FastAPI(title="Local LLM Inspector",
              description="OpenAI-compatible API proxy with LLM trace visualization")
# Initialize OpenAI client
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url=os.getenv("OPENAI_BASE_URL"))
class ContentText(BaseModel):
    type: str
    text: str
class ChatMessage(BaseModel):
    role: str
    content: Union[str, List[ContentText]]
class StreamOptions(BaseModel):
    include_usage: bool
class ChatInput(BaseModel):
    model: str
    messages: List[ChatMessage] = []
    temperature: Optional[float] = 0.5
    stream: Optional[bool] = False
    max_tokens: Optional[int] = 4096
    # Accepted for OpenAI API compatibility; not forwarded to the upstream calls below.
    stream_options: Optional[StreamOptions] = None
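# Illustrative request body for the endpoints below (the model name is a placeholder):
# {
#   "model": "my-local-model",
#   "messages": [{"role": "user", "content": "Hello"}],
#   "stream": false
# }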
@app.post("/v1/chat/completions")
@app.post("/chat/completions")
async def chat_completion(request: ChatInput):
# Validate request
if not request.model:
logger.error("Model field is required")
raise HTTPException(
status_code=422,
detail="model field is required"
)
if not request.messages:
logger.error("No messages provided in request")
for msg in request.messages:
if not isinstance(msg, role="user"):
logger.error("Invalid message format in request")
raise HTTPException(
status_code=422,
detail="Invalid request: missing required fields"
)
    try:
        if request.stream:
            # Handle streaming response
            async def stream_response():
                try:
                    response = client.chat.completions.create(
                        model=request.model,
                        messages=[msg.model_dump() for msg in request.messages],
                        temperature=request.temperature,
                        max_tokens=request.max_tokens,
                        stream=True
                    )
                    logger.info("Streaming response from OpenAI API")
                    # Relay each chunk as a server-sent-events data line.
                    for chunk in response:
                        yield f"data: {chunk.model_dump_json()}\n\n"
                    # Signal end of stream, matching the OpenAI streaming format.
                    yield "data: [DONE]\n\n"
                except Exception as e:
                    logger.error(f"Error during streaming: {str(e)}")
                    raise
            return StreamingResponse(stream_response(), media_type="text/event-stream")
        else:
            # Handle non-streaming response
            response = client.chat.completions.create(
                model=request.model,
                messages=[msg.model_dump() for msg in request.messages],
                temperature=request.temperature,
                max_tokens=request.max_tokens
            )
            logger.info("Received successful response from OpenAI API")
            return response
    except Exception as e:
        logger.error(
            f"OpenAI API error: {str(e)}",
            exc_info=True
        )
        raise HTTPException(
            status_code=500,
            detail=f"OpenAI API error: {str(e)}"
        )
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
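# A minimal sketch of how a client could exercise this proxy, assuming the server
# is reachable at localhost:8000 and the upstream serves a model named
# "my-local-model" (both placeholders):
#
#     from openai import OpenAI
#     # Point the standard OpenAI client at the proxy instead of the upstream API.
#     proxy = OpenAI(api_key="unused-placeholder", base_url="http://localhost:8000/v1")
#     reply = proxy.chat.completions.create(
#         model="my-local-model",
#         messages=[{"role": "user", "content": "Hello"}],
#     )
#     print(reply.choices[0].message.content)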