Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix the logic for creating LLM adapters. #97

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 10 additions & 31 deletions src/pipecat_flows/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,8 +540,8 @@ async def generate_summary(
def create_adapter(llm) -> LLMAdapter:
"""Create appropriate adapter based on LLM service type.

Uses lazy imports to avoid requiring all provider dependencies at runtime.
Only the dependency for the chosen provider needs to be installed.
    Use string-based class-name comparisons to determine the LLM service type instead of
    importing each provider dependency, since those imports can emit confusing console
    errors even for providers that are not in use.

Args:
llm: LLM service instance
Expand All @@ -552,35 +552,14 @@ def create_adapter(llm) -> LLMAdapter:
Raises:
ValueError: If LLM type is not supported or required dependency not installed
"""
# Try OpenAI
try:
from pipecat.services.openai import OpenAILLMService

if isinstance(llm, OpenAILLMService):
logger.debug("Creating OpenAI adapter")
return OpenAIAdapter()
except ImportError as e:
logger.debug(f"OpenAI import failed: {e}")

# Try Anthropic
try:
from pipecat.services.anthropic import AnthropicLLMService

if isinstance(llm, AnthropicLLMService):
logger.debug("Creating Anthropic adapter")
return AnthropicAdapter()
except ImportError as e:
logger.debug(f"Anthropic import failed: {e}")

# Try Google
try:
from pipecat.services.google import GoogleLLMService

if isinstance(llm, GoogleLLMService):
logger.debug("Creating Google adapter")
return GeminiAdapter()
except ImportError as e:
logger.debug(f"Google import failed: {e}")
if type(llm).__name__ == "OpenAILLMService":
return OpenAIAdapter()

if type(llm).__name__ == "AnthropicLLMService":
return AnthropicAdapter()

if type(llm).__name__ == "GoogleLLMService":
return GeminiAdapter()

# If we get here, either the LLM type is not supported or the required dependency is not installed
llm_type = type(llm).__name__
Expand Down
6 changes: 3 additions & 3 deletions tests/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,13 @@ def test_gemini_adapter(self):
def test_adapter_factory(self):
"""Test adapter creation based on LLM service type."""
# Test with valid LLM services
openai_llm = MagicMock(spec=OpenAILLMService)
openai_llm = OpenAILLMService(api_key="")
self.assertIsInstance(create_adapter(openai_llm), OpenAIAdapter)

anthropic_llm = MagicMock(spec=AnthropicLLMService)
anthropic_llm = AnthropicLLMService(api_key="")
self.assertIsInstance(create_adapter(anthropic_llm), AnthropicAdapter)

gemini_llm = MagicMock(spec=GoogleLLMService)
gemini_llm = GoogleLLMService(api_key="")
self.assertIsInstance(create_adapter(gemini_llm), GeminiAdapter)

def test_adapter_factory_error_cases(self):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_context_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def asyncSetUp(self):
self.mock_task = AsyncMock()

# Set up mock LLM with client
self.mock_llm = MagicMock(spec=OpenAILLMService)
self.mock_llm = OpenAILLMService(api_key="")
self.mock_llm._client = MagicMock()
self.mock_llm._client.chat = MagicMock()
self.mock_llm._client.chat.completions = MagicMock()
Expand Down Expand Up @@ -195,7 +195,7 @@ async def test_provider_specific_summary_formatting(self):
# Test OpenAI format
flow_manager = FlowManager(
task=self.mock_task,
llm=MagicMock(spec=OpenAILLMService),
llm=OpenAILLMService(api_key=""),
context_aggregator=self.mock_context_aggregator,
)
openai_message = flow_manager.adapter.format_summary_message(summary)
Expand All @@ -204,7 +204,7 @@ async def test_provider_specific_summary_formatting(self):
# Test Anthropic format
flow_manager = FlowManager(
task=self.mock_task,
llm=MagicMock(spec=AnthropicLLMService),
llm=AnthropicLLMService(api_key=""),
context_aggregator=self.mock_context_aggregator,
)
anthropic_message = flow_manager.adapter.format_summary_message(summary)
Expand All @@ -213,7 +213,7 @@ async def test_provider_specific_summary_formatting(self):
# Test Gemini format
flow_manager = FlowManager(
task=self.mock_task,
llm=MagicMock(spec=GoogleLLMService),
llm=GoogleLLMService(api_key=""),
context_aggregator=self.mock_context_aggregator,
)
gemini_message = flow_manager.adapter.format_summary_message(summary)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class TestFlowManager(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
"""Set up test fixtures before each test."""
self.mock_task = AsyncMock()
self.mock_llm = MagicMock(spec=OpenAILLMService)
self.mock_llm = OpenAILLMService(api_key="")
self.mock_llm.register_function = MagicMock()
self.mock_tts = AsyncMock()

# Create mock context aggregator
Expand Down