From 77be705c42705cfbecaf31f732b691e261c1dd45 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?=
Date: Fri, 21 Feb 2025 15:00:06 -0800
Subject: [PATCH] services: fix some TTS websocket service interruption
 handling

---
 CHANGELOG.md                              |   4 +
 src/pipecat/services/ai_services.py       | 100 ++++++++++++++++++++--
 src/pipecat/services/cartesia.py          |  10 +--
 src/pipecat/services/elevenlabs.py        |  12 +--
 src/pipecat/services/fish.py              |  15 ++--
 src/pipecat/services/lmnt.py              |   9 +-
 src/pipecat/services/playht.py            |  12 +--
 src/pipecat/services/rime.py              |   7 +-
 src/pipecat/services/websocket_service.py |   4 +-
 9 files changed, 121 insertions(+), 52 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a71ae5f02..3127a6a1c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -39,6 +39,10 @@ stt = DeepgramSTTService(..., live_options=LiveOptions(model="nova-2-general"))
 
 ### Fixed
 
+- Fixed an `ElevenLabsTTSService`, `FishAudioTTSService`, `LMNTTTSService` and
+  `PlayHTTTSService` issue that caused audio requested before an interruption
+  to be played after the interruption.
+
 - Fixed `match_endofsentence` support for ellipses.
 
 - Fixed an issue that would cause undesired interruptions via
diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py
index 05c637126..30c5d137f 100644
--- a/src/pipecat/services/ai_services.py
+++ b/src/pipecat/services/ai_services.py
@@ -15,6 +15,7 @@
 from pipecat.audio.utils import calculate_audio_volume, exp_smoothing
 from pipecat.frames.frames import (
     AudioRawFrame,
+    BotStartedSpeakingFrame,
     BotStoppedSpeakingFrame,
     CancelFrame,
     EndFrame,
@@ -40,6 +41,7 @@
 from pipecat.metrics.metrics import MetricsData
 from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
 from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
+from pipecat.services.websocket_service import WebsocketService
 from pipecat.transcriptions.language import Language
 from pipecat.utils.string import match_endofsentence
 from pipecat.utils.text.base_text_filter import BaseTextFilter
@@ -434,6 +436,12 @@ async def _stop_frame_handler(self):
 
 
 class WordTTSService(TTSService):
+    """This is a base class for TTS services that support word timestamps. Word
+    timestamps are useful to synchronize the audio with the text of the spoken
+    words. This way, only the spoken words are added to the conversation context.
+
+    """
+
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self._initial_word_timestamp = -1
@@ -503,11 +511,93 @@ async def _words_task_handler(self):
                     self._words_queue.task_done()
 
 
-class AudioContextWordTTSService(WordTTSService):
-    """This services allow us to send multiple TTS request to the services. Each
-    request could be multiple sentences long which are grouped by context. For
-    this to work, the TTS service needs to support handling multiple requests at
-    once (i.e. multiple simultaneous contexts).
+class WebsocketTTSService(TTSService, WebsocketService):
+    """This is a base class for websocket-based TTS services."""
+
+    def __init__(self, **kwargs):
+        TTSService.__init__(self, **kwargs)
+        WebsocketService.__init__(self)
+
+
+class InterruptibleTTSService(WebsocketTTSService):
+    """This is a base class for websocket-based TTS services that don't support
+    word timestamps and that don't offer a way to correlate the generated audio
+    with the requested text.
+
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        # Indicates whether the bot is speaking. If the bot is not
+        # speaking, we don't need to reconnect when the user speaks. If the
+        # bot is speaking and the user interrupts, we need to reconnect.
+        self._bot_speaking = False
+
+    async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
+        await super()._handle_interruption(frame, direction)
+        if self._bot_speaking:
+            await self._disconnect()
+            await self._connect()
+
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        await super().process_frame(frame, direction)
+
+        if isinstance(frame, BotStartedSpeakingFrame):
+            self._bot_speaking = True
+        elif isinstance(frame, BotStoppedSpeakingFrame):
+            self._bot_speaking = False
+
+
+class WebsocketWordTTSService(WordTTSService, WebsocketService):
+    """This is a base class for websocket-based TTS services that support word
+    timestamps.
+
+    """
+
+    def __init__(self, **kwargs):
+        WordTTSService.__init__(self, **kwargs)
+        WebsocketService.__init__(self)
+
+
+class InterruptibleWordTTSService(WebsocketWordTTSService):
+    """This is a base class for websocket-based TTS services that support word
+    timestamps but don't offer a way to correlate the generated audio with the
+    requested text.
+
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        # Indicates whether the bot is speaking. If the bot is not speaking,
+        # we don't need to reconnect when the user speaks. If the bot is
+        # speaking and the user interrupts, we need to reconnect.
+        self._bot_speaking = False
+
+    async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
+        await super()._handle_interruption(frame, direction)
+        if self._bot_speaking:
+            await self._disconnect()
+            await self._connect()
+
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        await super().process_frame(frame, direction)
+
+        if isinstance(frame, BotStartedSpeakingFrame):
+            self._bot_speaking = True
+        elif isinstance(frame, BotStoppedSpeakingFrame):
+            self._bot_speaking = False
+
+
+class AudioContextWordTTSService(WebsocketWordTTSService):
+    """This is a base class for websocket-based TTS services that support word
+    timestamps and also allow correlating the generated audio with the requested
+    text.
+
+    Each request can be multiple sentences long, and requests are grouped by
+    context. For this to work, the TTS service needs to support handling
+    multiple requests at once (i.e. multiple simultaneous contexts).
 
     The audio received from the TTS will be played in context order. That is,
     if we requested audio for a context "A" and then audio for context "B", the
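To illustrate what the new hierarchy provides, here is a minimal sketch of a
service built on `InterruptibleTTSService`. Everything in it is hypothetical
and not part of this patch (the class name, URL, and constructor parameters
are assumptions): it only shows the contract, i.e. subclasses implement the
connection hooks and `run_tts()`, and the base class reconnects when the user
interrupts while the bot is speaking.

    from typing import AsyncGenerator

    import websockets

    from pipecat.frames.frames import (
        Frame,
        TTSAudioRawFrame,
        TTSStartedFrame,
        TTSStoppedFrame,
    )
    from pipecat.services.ai_services import InterruptibleTTSService


    class ExampleTTSService(InterruptibleTTSService):
        """Hypothetical service, used only to illustrate the contract."""

        def __init__(self, *, api_key: str, sample_rate: int = 24000, **kwargs):
            super().__init__(sample_rate=sample_rate, **kwargs)
            self._api_key = api_key
            self._output_sample_rate = sample_rate
            self._websocket = None

        async def _connect(self):
            # Service-specific connection logic (abstract in WebsocketService).
            self._websocket = await websockets.connect("wss://tts.example.com/v1/stream")

        async def _disconnect(self):
            if self._websocket:
                await self._websocket.close()
                self._websocket = None

        async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
            # There is no way to tell which request produced which audio chunk,
            # so on an interruption while the bot is speaking the base class
            # disconnects and reconnects, guaranteeing that audio requested
            # before the interruption is never played afterwards.
            yield TTSStartedFrame()
            await self._websocket.send(text)
            async for message in self._websocket:
                if isinstance(message, bytes):
                    yield TTSAudioRawFrame(
                        audio=message,
                        sample_rate=self._output_sample_rate,
                        num_channels=1,
                    )
            yield TTSStoppedFrame()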
diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia.py
index f9f2bdbe6..0ec55aa0c 100644
--- a/src/pipecat/services/cartesia.py
+++ b/src/pipecat/services/cartesia.py
@@ -13,22 +13,18 @@
 from pydantic import BaseModel
 
 from pipecat.frames.frames import (
-    BotStoppedSpeakingFrame,
     CancelFrame,
     EndFrame,
     ErrorFrame,
     Frame,
-    LLMFullResponseEndFrame,
     StartFrame,
     StartInterruptionFrame,
     TTSAudioRawFrame,
-    TTSSpeakFrame,
     TTSStartedFrame,
     TTSStoppedFrame,
 )
 from pipecat.processors.frame_processor import FrameDirection
 from pipecat.services.ai_services import AudioContextWordTTSService, TTSService
-from pipecat.services.websocket_service import WebsocketService
 from pipecat.transcriptions.language import Language
 
 # See .env.example for Cartesia configuration needed
@@ -75,7 +71,7 @@ def language_to_cartesia_language(language: Language) -> Optional[str]:
     return result
 
 
-class CartesiaTTSService(AudioContextWordTTSService, WebsocketService):
+class CartesiaTTSService(AudioContextWordTTSService):
     class InputParams(BaseModel):
         language: Optional[Language] = Language.EN
         speed: Optional[Union[str, float]] = ""
@@ -105,15 +101,13 @@ def __init__(
         # if we're interrupted. Cartesia gives us word-by-word timestamps. We
         # can use those to generate text frames ourselves aligned with the
         # playout timing of the audio!
-        AudioContextWordTTSService.__init__(
-            self,
+        super().__init__(
             aggregate_sentences=True,
             push_text_frames=False,
             pause_frame_processing=True,
             sample_rate=sample_rate,
             **kwargs,
         )
-        WebsocketService.__init__(self)
 
         self._api_key = api_key
         self._cartesia_version = cartesia_version
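The design choice behind `AudioContextWordTTSService` (used by
`CartesiaTTSService` here and `RimeTTSService` below) is that reconnecting is
unnecessary when audio can be correlated with the request that produced it.
The self-contained sketch below shows the idea conceptually; it is not
pipecat's actual bookkeeping: each request is tagged with a context id, an
interruption closes all live contexts, and late audio for a closed context is
dropped instead of being played.

    import uuid
    from typing import Optional


    class AudioContexts:
        def __init__(self):
            self._live = set()

        def create_context(self) -> str:
            # Each TTS request gets its own context id.
            context_id = uuid.uuid4().hex
            self._live.add(context_id)
            return context_id

        def close_all(self) -> None:
            # Called on interruption: every in-flight request becomes stale.
            self._live.clear()

        def accept_audio(self, context_id: str, audio: bytes) -> Optional[bytes]:
            # Audio arriving for a closed context was requested before the
            # interruption, so it is dropped rather than played.
            return audio if context_id in self._live else None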
diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs.py
index 58ec91342..47fa87b5a 100644
--- a/src/pipecat/services/elevenlabs.py
+++ b/src/pipecat/services/elevenlabs.py
@@ -14,22 +14,18 @@
 from pydantic import BaseModel, model_validator
 
 from pipecat.frames.frames import (
-    BotStoppedSpeakingFrame,
     CancelFrame,
     EndFrame,
     ErrorFrame,
     Frame,
-    LLMFullResponseEndFrame,
     StartFrame,
     StartInterruptionFrame,
     TTSAudioRawFrame,
-    TTSSpeakFrame,
     TTSStartedFrame,
     TTSStoppedFrame,
 )
 from pipecat.processors.frame_processor import FrameDirection
-from pipecat.services.ai_services import TTSService, WordTTSService
-from pipecat.services.websocket_service import WebsocketService
+from pipecat.services.ai_services import InterruptibleWordTTSService, TTSService
 from pipecat.transcriptions.language import Language
 
 # See .env.example for ElevenLabs configuration needed
@@ -141,7 +137,7 @@ def calculate_word_times(
     return word_times
 
 
-class ElevenLabsTTSService(WordTTSService, WebsocketService):
+class ElevenLabsTTSService(InterruptibleWordTTSService):
     class InputParams(BaseModel):
         language: Optional[Language] = None
         optimize_streaming_latency: Optional[str] = None
@@ -186,8 +182,7 @@ def __init__(
         # Finally, ElevenLabs doesn't provide information on when the bot stops
         # speaking for a while, so we want the parent class to send TTSStopFrame
         # after a short period not receiving any audio.
-        WordTTSService.__init__(
-            self,
+        super().__init__(
             aggregate_sentences=True,
             push_text_frames=False,
             push_stop_frames=True,
@@ -195,7 +190,6 @@
             sample_rate=sample_rate,
             **kwargs,
         )
-        WebsocketService.__init__(self)
 
         self._api_key = api_key
         self._url = url
diff --git a/src/pipecat/services/fish.py b/src/pipecat/services/fish.py
index e2a75bdb2..1fd46b89c 100644
--- a/src/pipecat/services/fish.py
+++ b/src/pipecat/services/fish.py
@@ -22,8 +22,7 @@
     TTSStoppedFrame,
 )
 from pipecat.processors.frame_processor import FrameDirection
-from pipecat.services.ai_services import TTSService
-from pipecat.services.websocket_service import WebsocketService
+from pipecat.services.ai_services import InterruptibleTTSService
 from pipecat.transcriptions.language import Language
 
 try:
@@ -40,7 +39,7 @@
 FishAudioOutputFormat = Literal["opus", "mp3", "pcm", "wav"]
 
 
-class FishAudioTTSService(TTSService, WebsocketService):
+class FishAudioTTSService(InterruptibleTTSService):
     class InputParams(BaseModel):
         language: Optional[Language] = Language.EN
         latency: Optional[str] = "normal"  # "normal" or "balanced"
@@ -149,6 +148,11 @@ def _get_websocket(self):
             return self._websocket
         raise Exception("Websocket not connected")
 
+    async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
+        await super()._handle_interruption(frame, direction)
+        await self.stop_all_metrics()
+        self._request_id = None
+
     async def _receive_messages(self):
         async for message in self._get_websocket():
             try:
@@ -168,11 +172,6 @@ async def _receive_messages(self):
             except Exception as e:
                 logger.error(f"Error processing message: {e}")
 
-    async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
-        await super()._handle_interruption(frame, direction)
-        await self.stop_all_metrics()
-        self._request_id = None
-
     async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
         logger.debug(f"Generating Fish TTS: [{text}]")
         try:
diff --git a/src/pipecat/services/lmnt.py b/src/pipecat/services/lmnt.py
index 0272690e8..026ca5f16 100644
--- a/src/pipecat/services/lmnt.py
+++ b/src/pipecat/services/lmnt.py
@@ -21,8 +21,7 @@
     TTSStoppedFrame,
 )
 from pipecat.processors.frame_processor import FrameDirection
-from pipecat.services.ai_services import TTSService
-from pipecat.services.websocket_service import WebsocketService
+from pipecat.services.ai_services import InterruptibleTTSService
 from pipecat.transcriptions.language import Language
 
 # See .env.example for LMNT configuration needed
@@ -60,7 +59,7 @@ def language_to_lmnt_language(language: Language) -> Optional[str]:
     return result
 
 
-class LmntTTSService(TTSService, WebsocketService):
+class LmntTTSService(InterruptibleTTSService):
     def __init__(
         self,
         *,
@@ -70,14 +69,12 @@ def __init__(
         language: Language = Language.EN,
         **kwargs,
     ):
-        TTSService.__init__(
-            self,
+        super().__init__(
             push_stop_frames=True,
             pause_frame_processing=True,
             sample_rate=sample_rate,
             **kwargs,
         )
-        WebsocketService.__init__(self)
 
         self._api_key = api_key
         self._voice_id = voice_id
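A pattern repeated in cartesia.py, elevenlabs.py and lmnt.py above (and in
playht.py and rime.py below) is replacing the explicit
`SomeBase.__init__(self, ...)` / `WebsocketService.__init__(self)` pair with a
single `super().__init__(...)`. The stand-in classes below (simplified, not
the real pipecat classes) show why that is now enough: the two-base
initialization lives in exactly one place, `WebsocketTTSService`, so every
class beneath it only initializes its immediate parent.

    class TTSService:
        def __init__(self, **kwargs):
            print(f"TTSService.__init__({kwargs})")


    class WebsocketService:
        def __init__(self):
            print("WebsocketService.__init__()")


    class WebsocketTTSService(TTSService, WebsocketService):
        def __init__(self, **kwargs):
            # The only place that needs to know about both bases.
            TTSService.__init__(self, **kwargs)
            WebsocketService.__init__(self)


    class InterruptibleTTSService(WebsocketTTSService):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)


    class SomeTTSService(InterruptibleTTSService):
        def __init__(self, *, sample_rate: int, **kwargs):
            # Before this patch each service made the two explicit calls
            # itself; now a single call up the chain is enough.
            super().__init__(sample_rate=sample_rate, **kwargs)


    SomeTTSService(sample_rate=24000)
    # TTSService.__init__({'sample_rate': 24000})
    # WebsocketService.__init__()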
diff --git a/src/pipecat/services/playht.py b/src/pipecat/services/playht.py
index 7f0b29c87..08f31dd67 100644
--- a/src/pipecat/services/playht.py
+++ b/src/pipecat/services/playht.py
@@ -16,22 +16,18 @@
 from pydantic import BaseModel
 
 from pipecat.frames.frames import (
-    BotStoppedSpeakingFrame,
     CancelFrame,
     EndFrame,
     ErrorFrame,
     Frame,
-    LLMFullResponseEndFrame,
     StartFrame,
     StartInterruptionFrame,
     TTSAudioRawFrame,
-    TTSSpeakFrame,
     TTSStartedFrame,
     TTSStoppedFrame,
 )
 from pipecat.processors.frame_processor import FrameDirection
-from pipecat.services.ai_services import TTSService
-from pipecat.services.websocket_service import WebsocketService
+from pipecat.services.ai_services import InterruptibleTTSService, TTSService
 from pipecat.transcriptions.language import Language
 
 try:
@@ -100,7 +96,7 @@ def language_to_playht_language(language: Language) -> Optional[str]:
     return result
 
 
-class PlayHTTTSService(TTSService, WebsocketService):
+class PlayHTTTSService(InterruptibleTTSService):
     class InputParams(BaseModel):
         language: Optional[Language] = Language.EN
         speed: Optional[float] = 1.0
@@ -118,13 +114,11 @@ def __init__(
         params: InputParams = InputParams(),
         **kwargs,
     ):
-        TTSService.__init__(
-            self,
+        super().__init__(
             pause_frame_processing=True,
             sample_rate=sample_rate,
             **kwargs,
         )
-        WebsocketService.__init__(self)
 
         self._api_key = api_key
         self._user_id = user_id
diff --git a/src/pipecat/services/rime.py b/src/pipecat/services/rime.py
index cc59b132d..ad41af311 100644
--- a/src/pipecat/services/rime.py
+++ b/src/pipecat/services/rime.py
@@ -26,7 +26,6 @@
 )
 from pipecat.processors.frame_processor import FrameDirection
 from pipecat.services.ai_services import AudioContextWordTTSService, TTSService
-from pipecat.services.websocket_service import WebsocketService
 from pipecat.transcriptions.language import Language
 
 try:
@@ -55,7 +54,7 @@ def language_to_rime_language(language: Language) -> str:
     return LANGUAGE_MAP.get(language, "eng")
 
 
-class RimeTTSService(AudioContextWordTTSService, WebsocketService):
+class RimeTTSService(AudioContextWordTTSService):
     """Text-to-Speech service using Rime's websocket API.
 
     Uses Rime's websocket JSON API to convert text to speech with word-level timing
@@ -92,8 +91,7 @@ def __init__(
             params: Additional configuration parameters.
         """
         # Initialize with parent class settings for proper frame handling
-        AudioContextWordTTSService.__init__(
-            self,
+        super().__init__(
             aggregate_sentences=True,
             push_text_frames=False,
             push_stop_frames=True,
@@ -101,7 +99,6 @@
             sample_rate=sample_rate,
             **kwargs,
         )
-        WebsocketService.__init__(self)
 
         # Store service configuration
         self._api_key = api_key
diff --git a/src/pipecat/services/websocket_service.py b/src/pipecat/services/websocket_service.py
index a8dc97414..7836a7068 100644
--- a/src/pipecat/services/websocket_service.py
+++ b/src/pipecat/services/websocket_service.py
@@ -91,12 +91,12 @@ async def _receive_task_handler(self, report_error: Callable[[ErrorFrame], Await
                 continue
 
     @abstractmethod
-    async def _connect_websocket(self):
+    async def _connect(self):
         """Implement service-specific websocket connection logic."""
         pass
 
     @abstractmethod
-    async def _disconnect_websocket(self):
+    async def _disconnect(self):
         """Implement service-specific websocket disconnection logic."""
         pass
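With `_connect()` and `_disconnect()` now the abstract names on
`WebsocketService`, a subclass fragment would look roughly as follows. This is
a sketch assuming the `websockets` client package; the class name and URL
handling are illustrative only, and `WebsocketService` may declare other
abstract members that a real subclass would also need to implement.

    import websockets

    from pipecat.services.websocket_service import WebsocketService


    class ExampleWebsocketClient(WebsocketService):
        """Hypothetical subclass fragment; other abstract members omitted."""

        def __init__(self, url: str):
            super().__init__()
            self._url = url
            self._websocket = None

        async def _connect(self):
            # Service-specific websocket connection logic.
            self._websocket = await websockets.connect(self._url)

        async def _disconnect(self):
            # Service-specific websocket disconnection logic.
            if self._websocket:
                await self._websocket.close()
                self._websocket = None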