Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

services: fix some TTS websocket service interruption handling #1272

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ stt = DeepgramSTTService(..., live_options=LiveOptions(model="nova-2-general"))

### Fixed

- Fixed an `ElevenLabsTTSService`, `FishAudioTTSService`, `LMNTTTSService` and
  `PlayHTTTSService` issue that was resulting in audio requested before an
  interruption being played after an interruption.

- Fixed `match_endofsentence` support for ellipses.

- Fixed an issue that would cause undesired interruptions via
Expand Down
100 changes: 95 additions & 5 deletions src/pipecat/services/ai_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pipecat.audio.utils import calculate_audio_volume, exp_smoothing
from pipecat.frames.frames import (
AudioRawFrame,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
Expand All @@ -40,6 +41,7 @@
from pipecat.metrics.metrics import MetricsData
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.websocket_service import WebsocketService
from pipecat.transcriptions.language import Language
from pipecat.utils.string import match_endofsentence
from pipecat.utils.text.base_text_filter import BaseTextFilter
Expand Down Expand Up @@ -434,6 +436,12 @@ async def _stop_frame_handler(self):


class WordTTSService(TTSService):
"""This is a base class for TTS services that support word timestamps. Word
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""This a base class for TTS services that support word timestamps. Word
"""This is a base class for TTS services that support word timestamps. Word

timestamps are useful to synchronize audio with text of the spoken
words. This way only the spoken words are added to the conversation context.

"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._initial_word_timestamp = -1
Expand Down Expand Up @@ -503,11 +511,93 @@ async def _words_task_handler(self):
self._words_queue.task_done()


class AudioContextWordTTSService(WordTTSService):
"""This service allows us to send multiple TTS requests to the service. Each
request could be multiple sentences long which are grouped by context. For
this to work, the TTS service needs to support handling multiple requests at
once (i.e. multiple simultaneous contexts).
class WebsocketTTSService(TTSService, WebsocketService):
    """This is a base class for websocket-based TTS services.

    It combines the TTS frame-processing behavior of `TTSService` with the
    connection handling provided by `WebsocketService`.
    """

    def __init__(self, **kwargs):
        # Initialize both bases explicitly: `TTSService` receives the service
        # configuration kwargs, while `WebsocketService` takes no arguments.
        TTSService.__init__(self, **kwargs)
        WebsocketService.__init__(self)


class InterruptibleTTSService(WebsocketTTSService):
    """This is a base class for websocket-based TTS services that don't support
    word timestamps and that don't offer a way to correlate the generated audio
    to the requested text.

    Since audio can't be correlated to a specific request, the only way to
    discard audio generated before an interruption is to reconnect the
    websocket when the user interrupts the bot.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Indicates if the bot is speaking. If the bot is not speaking we don't
        # need to reconnect when the user speaks. If the bot is speaking and the
        # user interrupts we need to reconnect.
        self._bot_speaking = False

    async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
        await super()._handle_interruption(frame, direction)
        # Reconnecting drops any audio the service generated before the
        # interruption; only needed if the bot was actually speaking.
        if self._bot_speaking:
            await self._disconnect()
            await self._connect()

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        # Track bot speaking state from the bot speaking frames.
        if isinstance(frame, BotStartedSpeakingFrame):
            self._bot_speaking = True
        elif isinstance(frame, BotStoppedSpeakingFrame):
            self._bot_speaking = False


class WebsocketWordTTSService(WordTTSService, WebsocketService):
    """This is a base class for websocket-based TTS services that support word
    timestamps.
    """

    def __init__(self, **kwargs):
        # Initialize both bases explicitly: `WordTTSService` receives the
        # service configuration kwargs, while `WebsocketService` takes no
        # arguments.
        WordTTSService.__init__(self, **kwargs)
        WebsocketService.__init__(self)


class InterruptibleWordTTSService(WebsocketWordTTSService):
    """This is a base class for websocket-based TTS services that support word
    timestamps but don't offer a way to correlate the generated audio to the
    requested text.

    Since audio can't be correlated to a specific request, the only way to
    discard audio generated before an interruption is to reconnect the
    websocket when the user interrupts the bot.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Indicates if the bot is speaking. If the bot is not speaking we don't
        # need to reconnect when the user speaks. If the bot is speaking and the
        # user interrupts we need to reconnect.
        self._bot_speaking = False

    async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
        await super()._handle_interruption(frame, direction)
        # Reconnecting drops any audio the service generated before the
        # interruption; only needed if the bot was actually speaking.
        if self._bot_speaking:
            await self._disconnect()
            await self._connect()

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        # Track bot speaking state from the bot speaking frames.
        if isinstance(frame, BotStartedSpeakingFrame):
            self._bot_speaking = True
        elif isinstance(frame, BotStoppedSpeakingFrame):
            self._bot_speaking = False


class AudioContextWordTTSService(WebsocketWordTTSService):
"""This is a base class for websocket-based TTS services that support word
timestamps and also allow correlating the generated audio with the requested
text.

Each request could be multiple sentences long which are grouped by
context. For this to work, the TTS service needs to support handling
multiple requests at once (i.e. multiple simultaneous contexts).

The audio received from the TTS will be played in context order. That is, if
we requested audio for a context "A" and then audio for context "B", the
Expand Down
10 changes: 2 additions & 8 deletions src/pipecat/services/cartesia.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,18 @@
from pydantic import BaseModel

from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import AudioContextWordTTSService, TTSService
from pipecat.services.websocket_service import WebsocketService
from pipecat.transcriptions.language import Language

# See .env.example for Cartesia configuration needed
Expand Down Expand Up @@ -75,7 +71,7 @@ def language_to_cartesia_language(language: Language) -> Optional[str]:
return result


class CartesiaTTSService(AudioContextWordTTSService, WebsocketService):
class CartesiaTTSService(AudioContextWordTTSService):
class InputParams(BaseModel):
language: Optional[Language] = Language.EN
speed: Optional[Union[str, float]] = ""
Expand Down Expand Up @@ -105,15 +101,13 @@ def __init__(
# if we're interrupted. Cartesia gives us word-by-word timestamps. We
# can use those to generate text frames ourselves aligned with the
# playout timing of the audio!
AudioContextWordTTSService.__init__(
self,
super().__init__(
aggregate_sentences=True,
push_text_frames=False,
pause_frame_processing=True,
sample_rate=sample_rate,
**kwargs,
)
WebsocketService.__init__(self)

self._api_key = api_key
self._cartesia_version = cartesia_version
Expand Down
12 changes: 3 additions & 9 deletions src/pipecat/services/elevenlabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,18 @@
from pydantic import BaseModel, model_validator

from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import TTSService, WordTTSService
from pipecat.services.websocket_service import WebsocketService
from pipecat.services.ai_services import InterruptibleWordTTSService, TTSService
from pipecat.transcriptions.language import Language

# See .env.example for ElevenLabs configuration needed
Expand Down Expand Up @@ -141,7 +137,7 @@ def calculate_word_times(
return word_times


class ElevenLabsTTSService(WordTTSService, WebsocketService):
class ElevenLabsTTSService(InterruptibleWordTTSService):
class InputParams(BaseModel):
language: Optional[Language] = None
optimize_streaming_latency: Optional[str] = None
Expand Down Expand Up @@ -186,16 +182,14 @@ def __init__(
# Finally, ElevenLabs doesn't provide information on when the bot stops
# speaking for a while, so we want the parent class to send TTSStopFrame
# after a short period not receiving any audio.
WordTTSService.__init__(
self,
super().__init__(
aggregate_sentences=True,
push_text_frames=False,
push_stop_frames=True,
pause_frame_processing=True,
sample_rate=sample_rate,
**kwargs,
)
WebsocketService.__init__(self)

self._api_key = api_key
self._url = url
Expand Down
15 changes: 7 additions & 8 deletions src/pipecat/services/fish.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import TTSService
from pipecat.services.websocket_service import WebsocketService
from pipecat.services.ai_services import InterruptibleTTSService
from pipecat.transcriptions.language import Language

try:
Expand All @@ -40,7 +39,7 @@
FishAudioOutputFormat = Literal["opus", "mp3", "pcm", "wav"]


class FishAudioTTSService(TTSService, WebsocketService):
class FishAudioTTSService(InterruptibleTTSService):
class InputParams(BaseModel):
language: Optional[Language] = Language.EN
latency: Optional[str] = "normal" # "normal" or "balanced"
Expand Down Expand Up @@ -149,6 +148,11 @@ def _get_websocket(self):
return self._websocket
raise Exception("Websocket not connected")

async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
    # Let the base class handle the interruption first (the base class is
    # responsible for the generic interruption behavior), then stop all
    # metrics and clear the stored request id — presumably so audio from the
    # pre-interruption request is no longer matched; TODO confirm against
    # `_receive_messages`, which compares against `self._request_id`.
    await super()._handle_interruption(frame, direction)
    await self.stop_all_metrics()
    self._request_id = None

async def _receive_messages(self):
async for message in self._get_websocket():
try:
Expand All @@ -168,11 +172,6 @@ async def _receive_messages(self):
except Exception as e:
logger.error(f"Error processing message: {e}")

async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
await super()._handle_interruption(frame, direction)
await self.stop_all_metrics()
self._request_id = None

async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating Fish TTS: [{text}]")
try:
Expand Down
9 changes: 3 additions & 6 deletions src/pipecat/services/lmnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import TTSService
from pipecat.services.websocket_service import WebsocketService
from pipecat.services.ai_services import InterruptibleTTSService
from pipecat.transcriptions.language import Language

# See .env.example for LMNT configuration needed
Expand Down Expand Up @@ -60,7 +59,7 @@ def language_to_lmnt_language(language: Language) -> Optional[str]:
return result


class LmntTTSService(TTSService, WebsocketService):
class LmntTTSService(InterruptibleTTSService):
def __init__(
self,
*,
Expand All @@ -70,14 +69,12 @@ def __init__(
language: Language = Language.EN,
**kwargs,
):
TTSService.__init__(
self,
super().__init__(
push_stop_frames=True,
pause_frame_processing=True,
sample_rate=sample_rate,
**kwargs,
)
WebsocketService.__init__(self)

self._api_key = api_key
self._voice_id = voice_id
Expand Down
12 changes: 3 additions & 9 deletions src/pipecat/services/playht.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,18 @@
from pydantic import BaseModel

from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
LLMFullResponseEndFrame,
StartFrame,
StartInterruptionFrame,
TTSAudioRawFrame,
TTSSpeakFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import TTSService
from pipecat.services.websocket_service import WebsocketService
from pipecat.services.ai_services import InterruptibleTTSService, TTSService
from pipecat.transcriptions.language import Language

try:
Expand Down Expand Up @@ -100,7 +96,7 @@ def language_to_playht_language(language: Language) -> Optional[str]:
return result


class PlayHTTTSService(TTSService, WebsocketService):
class PlayHTTTSService(InterruptibleTTSService):
class InputParams(BaseModel):
language: Optional[Language] = Language.EN
speed: Optional[float] = 1.0
Expand All @@ -118,13 +114,11 @@ def __init__(
params: InputParams = InputParams(),
**kwargs,
):
TTSService.__init__(
self,
super().__init__(
pause_frame_processing=True,
sample_rate=sample_rate,
**kwargs,
)
WebsocketService.__init__(self)

self._api_key = api_key
self._user_id = user_id
Expand Down
7 changes: 2 additions & 5 deletions src/pipecat/services/rime.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import AudioContextWordTTSService, TTSService
from pipecat.services.websocket_service import WebsocketService
from pipecat.transcriptions.language import Language

try:
Expand Down Expand Up @@ -55,7 +54,7 @@ def language_to_rime_language(language: Language) -> str:
return LANGUAGE_MAP.get(language, "eng")


class RimeTTSService(AudioContextWordTTSService, WebsocketService):
class RimeTTSService(AudioContextWordTTSService):
"""Text-to-Speech service using Rime's websocket API.

Uses Rime's websocket JSON API to convert text to speech with word-level timing
Expand Down Expand Up @@ -92,16 +91,14 @@ def __init__(
params: Additional configuration parameters.
"""
# Initialize with parent class settings for proper frame handling
AudioContextWordTTSService.__init__(
self,
super().__init__(
aggregate_sentences=True,
push_text_frames=False,
push_stop_frames=True,
pause_frame_processing=True,
sample_rate=sample_rate,
**kwargs,
)
WebsocketService.__init__(self)

# Store service configuration
self._api_key = api_key
Expand Down
Loading