Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add image text model provider separation and fal.ai integration #650

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,7 @@ COINBASE_GENERATED_WALLET_HEX_SEED=
# TEE Configuration
DSTACK_SIMULATOR_ENDPOINT=
WALLET_SECRET_SALT=secret_salt

# fal.ai Configuration
FAL_API_KEY=
FAL_AI_LORA_PATH=
11 changes: 11 additions & 0 deletions agent/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import {
tradePlugin,
} from "@ai16z/plugin-coinbase";
import { confluxPlugin } from "@ai16z/plugin-conflux";
import { imageGenerationPlugin } from "@ai16z/plugin-image-generation";
import { evmPlugin } from "@ai16z/plugin-evm";
import { createNodePlugin } from "@ai16z/plugin-node";
import { solanaPlugin } from "@ai16z/plugin-solana";
Expand Down Expand Up @@ -225,6 +226,11 @@ export function getTokenForProvider(
character.settings?.secrets?.GROQ_API_KEY ||
settings.GROQ_API_KEY
);
case ModelProviderName.FAL:
return (
character.settings?.secrets?.FAL_API_KEY ||
settings.FAL_API_KEY
);
}
}

Expand Down Expand Up @@ -330,6 +336,11 @@ export function createAgent(
getSecret(character, "COINBASE_COMMERCE_KEY")
? coinbaseCommercePlugin
: null,
getSecret(character, "FAL_API_KEY") ||
getSecret(character, "OPENAI_API_KEY") ||
getSecret(character, "HEURIST_API_KEY")
? imageGenerationPlugin
: null,
...(getSecret(character, "COINBASE_API_KEY") &&
getSecret(character, "COINBASE_PRIVATE_KEY")
? [coinbaseMassPaymentsPlugin, tradePlugin]
Expand Down
1 change: 1 addition & 0 deletions packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"@ai-sdk/groq": "0.0.3",
"@ai-sdk/openai": "1.0.4",
"@anthropic-ai/sdk": "0.30.1",
"@fal-ai/client": "^1.2.0",
"@types/uuid": "10.0.0",
"ai": "3.4.33",
"anthropic-vertex-ai": "1.0.2",
Expand Down
69 changes: 62 additions & 7 deletions packages/core/src/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import {
ModelProviderName,
ServiceType,
} from "./types.ts";
import { fal, } from "@fal-ai/client";

/**
* Send a message to the model for a text generateText - receive a string back and parse how you'd like
Expand Down Expand Up @@ -769,15 +770,22 @@ export const generateImage = async (
count = 1;
}

const model = getModel(runtime.character.modelProvider, ModelClass.IMAGE);
const modelSettings = models[runtime.character.modelProvider].imageSettings;
const apiKey =
runtime.token ??
runtime.getSetting("HEURIST_API_KEY") ??
const model = getModel(runtime.imageModelProvider, ModelClass.IMAGE);
const modelSettings = models[runtime.imageModelProvider].imageSettings;

elizaLogger.info("Generating image with options:", {
imageModelProvider: model,
});

const apiKey = runtime.imageModelProvider === runtime.modelProvider
? runtime.token
: runtime.getSetting("HEURIST_API_KEY") ??
runtime.getSetting("TOGETHER_API_KEY") ??
runtime.getSetting("FAL_API_KEY") ??
runtime.getSetting("OPENAI_API_KEY");

try {
if (runtime.character.modelProvider === ModelProviderName.HEURIST) {
if (runtime.imageModelProvider === ModelProviderName.HEURIST) {
const response = await fetch(
"http://sequencer.heurist.xyz/submit_job",
{
Expand Down Expand Up @@ -815,7 +823,7 @@ export const generateImage = async (
const imageURL = await response.json();
return { success: true, data: [imageURL] };
} else if (
runtime.character.modelProvider === ModelProviderName.LLAMACLOUD
runtime.imageModelProvider === ModelProviderName.LLAMACLOUD
) {
const together = new Together({ apiKey: apiKey as string });
const response = await together.images.create({
Expand Down Expand Up @@ -844,6 +852,53 @@ export const generateImage = async (
})
);
return { success: true, data: base64s };
} else if (runtime.imageModelProvider === ModelProviderName.FAL) {
fal.config({
credentials: apiKey as string
});

// Prepare the input parameters according to the fal.ai model's input schema
const input = {
prompt: prompt,
image_size: "square" as const,
num_inference_steps: modelSettings?.steps ?? 50,
guidance_scale: 3.5,
num_images: count,
enable_safety_checker: true,
output_format: "png" as const,
seed: data.seed ?? 6252023,
...(runtime.getSetting("FAL_AI_LORA_PATH") ? {
loras: [
{
path: runtime.getSetting("FAL_AI_LORA_PATH"),
scale: 1
}
]
} : {})
};

// Subscribe to the model
const result = await fal.subscribe(model, {
input,
logs: true,
onQueueUpdate: (update) => {
if (update.status === "IN_PROGRESS") {
console.log(update.logs.map((log) => log.message));
}
}
});

// Convert the returned image URLs to base64 to match existing functionality
const base64Promises = result.data.images.map(async (image) => {
const response = await fetch(image.url);
const blob = await response.blob();
const buffer = await blob.arrayBuffer();
const base64 = Buffer.from(buffer).toString('base64');
return `data:${image.content_type};base64,${base64}`;
});

const base64s = await Promise.all(base64Promises);
return { success: true, data: base64s };
} else {
let targetSize = `${width}x${height}`;
if (
Expand Down
20 changes: 20 additions & 0 deletions packages/core/src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,26 @@ export const models: Models = {
[ModelClass.IMAGE]: "PepeXL",
},
},
// fal.ai provider entry. Only the IMAGE model class is populated; every text
// and embedding model slot is an empty string.
[ModelProviderName.FAL]: {
// NOTE(review): text-generation settings; presumably unused because no text
// models are configured for this provider — confirm before relying on them.
settings: {
stop: [],
maxInputTokens: 128000,
maxOutputTokens: 8192,
repetition_penalty: 0.4,
temperature: 0.7,
},
imageSettings: {
// generateImage sends this as num_inference_steps in its FAL branch and
// falls back to 50 when it is unset.
steps: 28,
},
// NOTE(review): the FAL branch of generateImage goes through the
// @fal-ai/client SDK and does not appear to read this endpoint — confirm.
endpoint: "https://api.fal.ai/v1",
model: {
[ModelClass.SMALL]: "", // FAL doesn't provide text models
[ModelClass.MEDIUM]: "",
[ModelClass.LARGE]: "",
[ModelClass.EMBEDDING]: "",
// Presumably the id that getModel(..., ModelClass.IMAGE) resolves to and that
// generateImage passes to fal.subscribe — verify against getModel.
[ModelClass.IMAGE]: "fal-ai/flux-lora",
},
},
};

export function getModel(provider: ModelProviderName, type: ModelClass) {
Expand Down
10 changes: 10 additions & 0 deletions packages/core/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ export class AgentRuntime implements IAgentRuntime {
*/
modelProvider: ModelProviderName;

/**
* The model provider to use for generateImage.
*/
imageModelProvider: ModelProviderName;

/**
* Fetch function to use
* Some environments may not have access to the global fetch function and need a custom fetch override.
Expand Down Expand Up @@ -303,7 +308,12 @@ export class AgentRuntime implements IAgentRuntime {
opts.modelProvider ??
this.modelProvider;

this.imageModelProvider =
this.character.imageModelProvider ??
this.modelProvider;

elizaLogger.info("Selected model provider:", this.modelProvider);
elizaLogger.info("Selected image model provider:", this.imageModelProvider);

// Validate model provider
if (!Object.values(ModelProviderName).includes(this.modelProvider)) {
Expand Down
6 changes: 6 additions & 0 deletions packages/core/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ export type Models = {
[ModelProviderName.OPENROUTER]: Model;
[ModelProviderName.OLLAMA]: Model;
[ModelProviderName.HEURIST]: Model;
[ModelProviderName.FAL]: Model;
};

/**
Expand All @@ -218,6 +219,7 @@ export enum ModelProviderName {
OPENROUTER = "openrouter",
OLLAMA = "ollama",
HEURIST = "heurist",
FAL = "falai"
}

/**
Expand Down Expand Up @@ -610,6 +612,9 @@ export type Character = {
/** Model provider to use */
modelProvider: ModelProviderName;

/** Image model provider to use, if different from modelProvider */
imageModelProvider?: ModelProviderName;

/** Optional model endpoint override */
modelEndpointOverride?: string;

Expand Down Expand Up @@ -959,6 +964,7 @@ export interface IAgentRuntime {
databaseAdapter: IDatabaseAdapter;
token: string | null;
modelProvider: ModelProviderName;
imageModelProvider: ModelProviderName;
character: Character;
providers: Provider[];
actions: Action[];
Expand Down
14 changes: 12 additions & 2 deletions packages/plugin-image-generation/src/enviroment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@ export const imageGenEnvSchema = z
ANTHROPIC_API_KEY: z.string().optional(),
TOGETHER_API_KEY: z.string().optional(),
HEURIST_API_KEY: z.string().optional(),
FAL_API_KEY: z.string().optional(),
OPENAI_API_KEY: z.string().optional(),
})
.refine(
(data) => {
return !!(
data.ANTHROPIC_API_KEY ||
data.TOGETHER_API_KEY ||
data.HEURIST_API_KEY
data.HEURIST_API_KEY ||
data.FAL_API_KEY ||
data.OPENAI_API_KEY
);
},
{
message:
"At least one of ANTHROPIC_API_KEY, TOGETHER_API_KEY, or HEURIST_API_KEY is required",
"At least one of ANTHROPIC_API_KEY, TOGETHER_API_KEY, HEURIST_API_KEY, FAL_API_KEY or OPENAI_API_KEY is required",
}
);

Expand All @@ -37,6 +41,12 @@ export async function validateImageGenConfig(
HEURIST_API_KEY:
runtime.getSetting("HEURIST_API_KEY") ||
process.env.HEURIST_API_KEY,
FAL_API_KEY:
runtime.getSetting("FAL_API_KEY") ||
process.env.FAL_API_KEY,
OPENAI_API_KEY:
runtime.getSetting("OPENAI_API_KEY") ||
process.env.OPENAI_API_KEY,
};

return imageGenEnvSchema.parse(config);
Expand Down
6 changes: 3 additions & 3 deletions packages/plugin-image-generation/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ const imageGeneration: Action = {
const anthropicApiKeyOk = !!runtime.getSetting("ANTHROPIC_API_KEY");
const togetherApiKeyOk = !!runtime.getSetting("TOGETHER_API_KEY");
const heuristApiKeyOk = !!runtime.getSetting("HEURIST_API_KEY");
const falApiKeyOk = !!runtime.getSetting("FAL_API_KEY");
const openAiApiKeyOk = !!runtime.getSetting("OPENAI_API_KEY");

// TODO: Add openai DALL-E generation as well

return anthropicApiKeyOk || togetherApiKeyOk || heuristApiKeyOk;
return anthropicApiKeyOk || togetherApiKeyOk || heuristApiKeyOk || falApiKeyOk || openAiApiKeyOk;
},
handler: async (
runtime: IAgentRuntime,
Expand Down