chat/completions api support voice input/output #296

Open · wants to merge 4 commits into main
98 changes: 98 additions & 0 deletions async-openai/src/types/chat.rs
@@ -186,12 +186,48 @@ pub struct ChatCompletionRequestMessageContentPartImage {
pub image_url: ImageUrl,
}

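/// The format of audio data. Per the OpenAI docs, input audio currently supports `wav` and `mp3`; output audio supports `wav`, `mp3`, `flac`, `opus`, and `pcm16`.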
#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum AudioFormat {
#[default]
Mp3,
Wav,
Flac,
Opus,
Pcm16,
}

#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)]
#[builder(name = "Audio")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct AudioData {
/// Base64-encoded audio data.
pub data: String,
/// The format of the encoded audio data.
pub format: AudioFormat,
}


#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)]
#[builder(name = "ChatCompletionRequestMessageContentPartAudioArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct ChatCompletionRequestMessageContentPartAudio {
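/// The audio input to the model.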
pub input_audio: AudioData,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ChatCompletionRequestUserMessageContentPart {
Text(ChatCompletionRequestMessageContentPartText),
ImageUrl(ChatCompletionRequestMessageContentPartImage),
InputAudio(ChatCompletionRequestMessageContentPartAudio),
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
@@ -337,6 +373,21 @@ pub struct ChatCompletionMessageToolCall {
pub function: FunctionCall,
}

#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(build_fn(error = "OpenAIError"))]
pub struct AudioResponse {
/// Unique identifier for this audio response.
pub id: String,
/// Base64 encoded audio bytes generated by the model, in the format
/// specified in the request.
pub data: String,
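/// The Unix timestamp (in seconds) for when this audio response will no longer be accessible on the server for use in multi-turn conversations.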
pub expires_at: u64,
/// Transcript of the audio generated by the model.
pub transcript: String,
}

/// A chat completion message generated by the model.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ChatCompletionResponseMessage {
@@ -347,6 +398,8 @@ pub struct ChatCompletionResponseMessage {
/// The tool calls generated by the model, such as function calls.
pub tool_calls: Option<Vec<ChatCompletionMessageToolCall>>,

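/// If the audio output modality is requested, this object contains data about the audio response from the model.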
pub audio: Option<AudioResponse>,

/// The role of the author of this message.
pub role: Role,

@@ -492,6 +545,38 @@ pub enum ServiceTierResponse {
Default,
}

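/// A modality (output type) that the model can be asked to generate for a request.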
#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ChatCompletionModalities {
Text,
Audio,
}

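/// The voice the model uses to respond when audio output is requested.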
#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum VoiceType {
#[default]
Alloy,
Ash,
Ballad,
Coral,
Echo,
Sage,
Shimmer,
Verse,
}

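/// Parameters for audio output. Required when audio output is requested with `modalities: ["audio"]`.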
#[derive(Clone, Serialize, Default, Debug, Builder, Deserialize, PartialEq)]
#[builder(name = "AudioArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct AudioParams {
pub voice: VoiceType,
pub format: AudioFormat,
}

#[derive(Clone, Serialize, Default, Debug, Builder, Deserialize, PartialEq)]
#[builder(name = "CreateChatCompletionRequestArgs")]
#[builder(pattern = "mutable")]
@@ -555,6 +640,19 @@ pub struct CreateChatCompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,

/// Output types that you would like the model to generate for this request.
///
/// `gpt-4o-audio-preview` currently requires either audio input or audio output to be used.
/// Acceptable combinations of input and output are:
///
/// - text in → text + audio out
/// - audio in → text + audio out
/// - audio in → text out
/// - text + audio in → text + audio out
/// - text + audio in → text out
#[serde(skip_serializing_if = "Option::is_none")]
pub modalities: Option<Vec<ChatCompletionModalities>>,

/// Parameters for audio output. Required when audio output is requested with `modalities: ["audio"]`.
#[serde(skip_serializing_if = "Option::is_none")]
pub audio: Option<AudioParams>,

/// This feature is in Beta.
/// If specified, our system will make a best effort to sample deterministically, such that repeated requests
/// with the same `seed` and parameters should return the same result.
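As a minimal sketch (not part of this diff) of the "audio in → text out" combination listed above: only the text modality is requested, so no output `audio` parameters are set. `base64_audio` is a hypothetical `String` that already holds base64-encoded audio, and the imports are the same as in the examples/chat-audio example below.

let request = CreateChatCompletionRequestArgs::default()
    .model("gpt-4o-audio-preview")
    // Text-only output: omit the output `audio` parameters entirely.
    .modalities(vec![ChatCompletionModalities::Text])
    .messages([ChatCompletionRequestUserMessageArgs::default()
        .content(ChatCompletionRequestUserMessageContent::Array(vec![
            ChatCompletionRequestUserMessageContentPart::InputAudio(
                ChatCompletionRequestMessageContentPartAudioArgs::default()
                    .input_audio(
                        Audio::default()
                            .data(base64_audio)
                            .format(AudioFormat::Wav)
                            .build()?,
                    )
                    .build()?,
            ),
        ]))
        .build()?
        .into()])
    .build()?;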
11 changes: 11 additions & 0 deletions examples/chat-audio/Cargo.toml
@@ -0,0 +1,11 @@
[package]
name = "chat-audio"
version = "0.1.0"
edition = "2021"
publish = false

[dependencies]
async-openai = {path = "../../async-openai"}
base64 = "0.22.1"
serde_json = "1.0.117"
tokio = { version = "1.38.0", features = ["full"] }
91 changes: 91 additions & 0 deletions examples/chat-audio/src/main.rs
@@ -0,0 +1,91 @@
use std::error::Error;
use std::fs::File;
use std::io::{Read, Write};

use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};

use async_openai::{
    types::{
        Audio, AudioArgs, AudioFormat, ChatCompletionModalities,
        ChatCompletionRequestMessageContentPartAudioArgs,
        ChatCompletionRequestMessageContentPartTextArgs, ChatCompletionRequestSystemMessageArgs,
        ChatCompletionRequestSystemMessageContent, ChatCompletionRequestSystemMessageContentPart,
        ChatCompletionRequestUserMessageArgs, ChatCompletionRequestUserMessageContent,
        ChatCompletionRequestUserMessageContentPart, CreateChatCompletionRequestArgs, VoiceType,
    },
    Client,
};

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let client = Client::new();

    // Read the input audio file and base64-encode it for the `input_audio` content part.
    let file_path = "./audio/test.mp3";
    let mut file = File::open(file_path)?;
    let mut buffer = Vec::new();
    file.read_to_end(&mut buffer)?;
    let audio_data = BASE64.encode(&buffer);

    let request = CreateChatCompletionRequestArgs::default()
        .model("gpt-4o-audio-preview")
        // Request both text and audio output; `audio` configures the output voice and format.
        .modalities(vec![
            ChatCompletionModalities::Text,
            ChatCompletionModalities::Audio,
        ])
        .audio(
            AudioArgs::default()
                .voice(VoiceType::Alloy)
                .format(AudioFormat::Mp3)
                .build()?,
        )
        .messages([
            ChatCompletionRequestSystemMessageArgs::default()
                .content(ChatCompletionRequestSystemMessageContent::Array(vec![
                    ChatCompletionRequestSystemMessageContentPart::Text(
                        ChatCompletionRequestMessageContentPartTextArgs::default()
                            .text("You are a helpful assistant.")
                            .build()?,
                    ),
                ]))
                .build()?
                .into(),
            ChatCompletionRequestUserMessageArgs::default()
                .content(ChatCompletionRequestUserMessageContent::Array(vec![
                    ChatCompletionRequestUserMessageContentPart::Text(
                        ChatCompletionRequestMessageContentPartTextArgs::default()
                            .text("What is in this recording?")
                            .build()?,
                    ),
                    ChatCompletionRequestUserMessageContentPart::InputAudio(
                        ChatCompletionRequestMessageContentPartAudioArgs::default()
                            .input_audio(
                                Audio::default()
                                    .data(audio_data)
                                    .format(AudioFormat::Mp3)
                                    .build()?,
                            )
                            .build()?,
                    ),
                ]))
                .build()?
                .into(),
        ])
        .build()?;

    let response = client.chat().create(request).await?;

    // Decode the base64 audio returned by the model and write it to disk.
    let audio = response.choices[0]
        .message
        .audio
        .as_ref()
        .expect("audio output was requested");
    let audio_bytes = BASE64.decode(&audio.data)?;
    let mut file = File::create("./audio/test_response.mp3")?;
    file.write_all(&audio_bytes)?;

    Ok(())
}