Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stt and tts capabilities to agent #2007

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/silver-ends-fetch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@mastra/core': minor
---

add stt and tts capabilities to agent
111 changes: 110 additions & 1 deletion packages/core/src/agent/agent.test.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import { createOpenAI } from '@ai-sdk/openai';
import { config } from 'dotenv';
import { describe, it, expect, vi } from 'vitest';
import { PassThrough } from 'stream';

Check warning on line 3 in packages/core/src/agent/agent.test.ts

View workflow job for this annotation

GitHub Actions / Lint

`stream` import should occur before import of `@ai-sdk/openai`
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { z } from 'zod';

import { TestIntegration } from '../integration/openapi-toolset.mock';
import { Mastra } from '../mastra';
import { createTool } from '../tools';
import { CompositeVoice, MastraVoice } from '../voice';

import { Agent } from './index';

Expand Down Expand Up @@ -275,4 +277,111 @@
expect(sanitizedMessages).not.toContainEqual(toolCallThree);
expect(sanitizedMessages).toHaveLength(2);
});

describe('voice capabilities', () => {
  // Stub provider: returns a fixed audio stream for speak() and a fixed
  // transcription for listen(), so tests need no real TTS/STT backend.
  class MockVoice extends MastraVoice {
    async speak(_input: string | NodeJS.ReadableStream): Promise<NodeJS.ReadableStream> {
      const audio = new PassThrough();
      audio.end('mock audio');
      return audio;
    }

    async listen(): Promise<string> {
      return 'mock transcription';
    }

    async getSpeakers() {
      return [{ voiceId: 'mock-voice' }];
    }
  }

  // Drain a readable stream and return its full contents as one Buffer.
  async function readAll(stream: NodeJS.ReadableStream): Promise<Buffer> {
    const parts: Buffer[] = [];
    for await (const part of stream) {
      parts.push(Buffer.isBuffer(part) ? part : Buffer.from(part));
    }
    return Buffer.concat(parts);
  }

  let voiceAgent: Agent;

  beforeEach(() => {
    voiceAgent = new Agent({
      name: 'Voice Agent',
      instructions: 'You are an agent with voice capabilities',
      model: openai('gpt-4o'),
      voice: new CompositeVoice({
        speakProvider: new MockVoice({ speaker: 'mock-voice' }),
        listenProvider: new MockVoice({ speaker: 'mock-voice' }),
      }),
    });
  });

  describe('getSpeakers', () => {
    it('should list available voices', async () => {
      const speakers = await voiceAgent.getSpeakers();
      expect(speakers).toEqual([{ voiceId: 'mock-voice' }]);
    });
  });

  describe('speak', () => {
    it('should generate audio stream from text', async () => {
      const stream = await voiceAgent.speak('Hello World', {
        speaker: 'mock-voice',
      });

      const audio = await readAll(stream);
      expect(audio.toString()).toBe('mock audio');
    });

    it('should work with different parameters', async () => {
      const stream = await voiceAgent.speak('Test with parameters', {
        speaker: 'mock-voice',
        speed: 0.5,
      });

      const audio = await readAll(stream);
      expect(audio.toString()).toBe('mock audio');
    });
  });

  describe('listen', () => {
    it('should transcribe audio', async () => {
      const input = new PassThrough();
      input.end('test audio data');

      const text = await voiceAgent.listen(input);
      expect(text).toBe('mock transcription');
    });

    it('should accept options', async () => {
      const input = new PassThrough();
      input.end('test audio data');

      const text = await voiceAgent.listen(input, {
        language: 'en',
      });
      expect(text).toBe('mock transcription');
    });
  });

  describe('error handling', () => {
    it('should throw error when no voice provider is configured', async () => {
      // Agent built without a `voice` option: every voice API must reject.
      const agentWithoutVoice = new Agent({
        name: 'No Voice Agent',
        instructions: 'You are an agent without voice capabilities',
        model: openai('gpt-4o'),
      });

      await expect(agentWithoutVoice.getSpeakers()).rejects.toThrow('No voice provider configured');
      await expect(agentWithoutVoice.speak('Test')).rejects.toThrow('No voice provider configured');
      await expect(agentWithoutVoice.listen(new PassThrough())).rejects.toThrow('No voice provider configured');
    });
  });
});
});
77 changes: 77 additions & 0 deletions packages/core/src/agent/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
UserContent,
LanguageModelV1,
} from 'ai';
import { randomUUID } from 'crypto';

Check warning on line 12 in packages/core/src/agent/index.ts

View workflow job for this annotation

GitHub Actions / Lint

`crypto` import should occur before type import of `ai`
import type { JSONSchema7 } from 'json-schema';
import { z } from 'zod';
import type { ZodSchema } from 'zod';
Expand All @@ -26,6 +26,7 @@
import type { MemoryConfig, StorageThreadType } from '../memory/types';
import { InstrumentClass } from '../telemetry';
import type { CoreTool, ToolAction } from '../tools/types';
import type { CompositeVoice } from '../voice';

import type { AgentConfig, AgentGenerateOptions, AgentStreamOptions, ToolsetsInput } from './types';

Expand All @@ -47,6 +48,7 @@
/** @deprecated This property is deprecated. Use evals instead. */
metrics: TMetrics;
evals: TMetrics;
voice?: CompositeVoice;

constructor(config: AgentConfig<TTools, TMetrics>) {
super({ component: RegisteredLogger.AGENT });
Expand Down Expand Up @@ -86,6 +88,10 @@
if (config.memory) {
this.#memory = config.memory;
}

if (config.voice) {
this.voice = config.voice;
}
}

public hasOwnMemory(): boolean {
Expand Down Expand Up @@ -961,4 +967,75 @@
toolChoice,
}) as unknown as StreamReturn<Z>;
}

/**
 * Convert text to speech using the configured voice provider
 * @param input Text or text stream to convert to speech
 * @param options Speech options including speaker and provider-specific options
 * @returns Audio stream
 * @throws {Error} If no voice provider is configured
 */
async speak(
  input: string | NodeJS.ReadableStream,
  options?: {
    speaker?: string;
    [key: string]: any;
  },
): Promise<NodeJS.ReadableStream> {
  if (!this.voice) {
    throw new Error('No voice provider configured');
  }
  try {
    // `await` is required: a bare `return` would let an async rejection
    // escape this try/catch, so the error would never be logged below.
    return await this.voice.speak(input, options);
  } catch (e) {
    this.logger.error('Error during agent speak', {
      error: e,
    });
    throw e;
  }
}

/**
 * Convert speech to text using the configured voice provider
 * @param audioStream Audio stream to transcribe
 * @param options Provider-specific transcription options
 * @returns Text or text stream
 * @throws {Error} If no voice provider is configured
 */
async listen(
  audioStream: NodeJS.ReadableStream,
  options?: {
    [key: string]: any;
  },
): Promise<string | NodeJS.ReadableStream> {
  if (!this.voice) {
    throw new Error('No voice provider configured');
  }
  try {
    // `await` keeps async rejections inside this try/catch so they are
    // logged before being rethrown; bare `return` would bypass the catch.
    return await this.voice.listen(audioStream, options);
  } catch (e) {
    this.logger.error('Error during agent listen', {
      error: e,
    });
    throw e;
  }
}

/**
 * Get a list of available speakers from the configured voice provider
 * @throws {Error} If no voice provider is configured
 * @returns {Promise<Array<{voiceId: string}>>} List of available speakers
 */
async getSpeakers() {
  const provider = this.voice;
  if (!provider) {
    throw new Error('No voice provider configured');
  }

  try {
    return await provider.getSpeakers();
  } catch (err) {
    // Log with context before propagating so callers still see the failure.
    this.logger.error('Error during agent getSpeakers', { error: err });
    throw err;
  }
}
}
2 changes: 2 additions & 0 deletions packages/core/src/agent/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import type { CoreMessage, OutputType } from '../llm';
import type { MastraMemory } from '../memory/memory';
import type { MemoryConfig } from '../memory/types';
import type { ToolAction } from '../tools';
import type { CompositeVoice } from '../voice';

export type { Message as AiMessageType } from 'ai';

Expand All @@ -28,6 +29,7 @@ export interface AgentConfig<
metrics?: TMetrics;
evals?: TMetrics;
memory?: MastraMemory;
voice?: CompositeVoice;
}

export interface AgentGenerateOptions<Z extends ZodSchema | JSONSchema7 | undefined = undefined> {
Expand Down
33 changes: 33 additions & 0 deletions packages/core/src/voice/composite-voice.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import { MastraVoice } from '.';

/**
 * Combines two voice providers behind a single MastraVoice interface so that
 * text-to-speech and speech-to-text can be served by different backends.
 * Either provider may be omitted; the corresponding method then throws.
 */
export class CompositeVoice extends MastraVoice {
  protected speakProvider?: MastraVoice;
  protected listenProvider?: MastraVoice;

  constructor({ speakProvider, listenProvider }: { speakProvider?: MastraVoice; listenProvider?: MastraVoice }) {
    // NOTE(review): assumes the MastraVoice base constructor tolerates being
    // called without a config object — confirm against './voice'.
    super();
    this.speakProvider = speakProvider;
    this.listenProvider = listenProvider;
  }

  /**
   * Convert text to speech via the configured speak provider.
   * @param input Text or text stream to convert to speech
   * @param options Speech options including speaker and provider-specific options
   * @returns Audio stream
   * @throws {Error} If no speak provider is configured
   */
  async speak(
    input: string | NodeJS.ReadableStream,
    options?: {
      speaker?: string;
      [key: string]: any;
    },
  ): Promise<NodeJS.ReadableStream> {
    if (!this.speakProvider) {
      throw new Error('No speak provider configured');
    }
    return this.speakProvider.speak(input, options);
  }

  /**
   * Convert speech to text via the configured listen provider.
   * @param audioStream Audio stream to transcribe
   * @param options Provider-specific transcription options
   * @returns Text or text stream
   * @throws {Error} If no listen provider is configured
   */
  async listen(
    audioStream: NodeJS.ReadableStream,
    options?: {
      [key: string]: any;
    },
  ): Promise<string | NodeJS.ReadableStream> {
    if (!this.listenProvider) {
      throw new Error('No listen provider configured');
    }
    return this.listenProvider.listen(audioStream, options);
  }

  /**
   * List available voices. Delegates to the speak provider, since speakers
   * are a text-to-speech concept.
   * @throws {Error} If no speak provider is configured
   */
  async getSpeakers() {
    if (!this.speakProvider) {
      throw new Error('No speak provider configured');
    }
    return this.speakProvider.getSpeakers();
  }
}
83 changes: 2 additions & 81 deletions packages/core/src/voice/index.ts
Original file line number Diff line number Diff line change
@@ -1,81 +1,2 @@
import { MastraBase } from '../base';
import { InstrumentClass } from '../telemetry';

// Configuration for a built-in (first-party) model, identified by name with
// an optional API key for the hosting service.
interface BuiltInModelConfig {
  name: string;
  apiKey?: string;
}

/** Constructor options shared by all MastraVoice providers. */
export interface VoiceConfig {
  listeningModel?: BuiltInModelConfig;
  speechModel?: BuiltInModelConfig;
  speaker?: string;
}

// Telemetry wrapper: instruments public methods with a 'voice' span prefix,
// skipping the internal setters listed in excludeMethods.
@InstrumentClass({
  prefix: 'voice',
  excludeMethods: ['__setTools', '__setLogger', '__setTelemetry', '#log'],
})
export abstract class MastraVoice extends MastraBase {
  protected listeningModel?: BuiltInModelConfig;
  protected speechModel?: BuiltInModelConfig;
  protected speaker?: string;

  constructor({ listeningModel, speechModel, speaker }: VoiceConfig) {
    super({
      component: 'VOICE',
    });
    this.listeningModel = listeningModel;
    this.speechModel = speechModel;
    this.speaker = speaker;
  }

  // Wraps `method` in a telemetry trace span named `voice.<methodName>`.
  // Falls back to the raw method when no telemetry is configured.
  traced<T extends Function>(method: T, methodName: string): T {
    return (
      this.telemetry?.traceMethod(method, {
        spanName: `voice.${methodName}`,
        attributes: {
          // Prefer the speech model's name; fall back to the listening model.
          'voice.type': this.speechModel?.name || this.listeningModel?.name || 'unknown',
        },
      }) ?? method
    );
  }

  /**
   * Convert text to speech
   * @param input Text or text stream to convert to speech
   * @param options Speech options including speaker and provider-specific options
   * @returns Audio stream
   */
  abstract speak(
    input: string | NodeJS.ReadableStream,
    options?: {
      speaker?: string;
      [key: string]: any;
    },
  ): Promise<NodeJS.ReadableStream>;

  /**
   * Convert speech to text
   * @param audioStream Audio stream to transcribe
   * @param options Provider-specific transcription options
   * @returns Text or text stream
   */
  abstract listen(
    audioStream: NodeJS.ReadableStream,
    options?: {
      [key: string]: any;
    },
  ): Promise<string | NodeJS.ReadableStream>;

  /**
   * Get available speakers/voices
   * @returns Array of available voice IDs and their metadata
   */
  abstract getSpeakers(): Promise<
    Array<{
      voiceId: string;
      [key: string]: any;
    }>
  >;
}
export * from './voice';
export * from './composite-voice';
Loading
Loading