Skip to content

Commit

Permalink
FireworksAI: support via custom OpenAI on https://api.fireworks.ai/inference/v1
Browse files Browse the repository at this point in the history
  • Loading branch information
enricoros committed Feb 18, 2025
1 parent cdf4c96 commit 9f372eb
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/common/util/dMessageUtils.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,12 @@ export function prettyShortChatModelName(model: string | undefined): string {
if (model.includes('grok-beta')) return 'Grok Beta';
if (model.includes('grok-vision-beta')) return 'Grok Vision Beta';
}
// [FireworksAI]
if (model.includes('accounts/')) {
const index = model.indexOf('accounts/');
const subStr = model.slice(index + 9);
return subStr.replaceAll('/models/', ' · ').replaceAll(/[_-]/g, ' ');
}
return model;
}

Expand Down
28 changes: 28 additions & 0 deletions src/modules/llms/server/openai/fireworksai.wiretypes.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import { z } from 'zod';


// [Fireworks AI] Models List API - Response

export const wireFireworksAIListOutputSchema = z.array(z.object({

id: z.string(),
object: z.literal('model'),
owned_by: z.union([
z.literal('fireworks'),
z.literal('yi-01-ai'),
z.string(),
]),
created: z.number(),
kind: z.union([
z.literal('HF_BASE_MODEL'),
z.literal('HF_PEFT_ADDON'),
z.literal('FLUMINA_BASE_MODEL'),
z.string(),
]).optional(),
// these seem to be there all the time, but just in case make them optional
supports_chat: z.boolean().optional(),
supports_image_input: z.boolean().optional(),
supports_tools: z.boolean().optional(),
// Not all models have this, so make it optional
context_length: z.number().optional(),
}));
89 changes: 89 additions & 0 deletions src/modules/llms/server/openai/models/fireworksai.models.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import { DModelInterfaceV1, LLM_IF_OAI_Chat, LLM_IF_OAI_Fn, LLM_IF_OAI_Vision } from '~/common/stores/llms/llms.types';

import { serverCapitalizeFirstLetter } from '~/server/wire';

import type { ModelDescriptionSchema } from '../../llm.server.types';

import { fromManualMapping, ManualMappings } from './models.data';
import { wireFireworksAIListOutputSchema } from '../fireworksai.wiretypes';


export function fireworksAIHeuristic(hostname: string) {
return hostname.includes('fireworks.ai/');
}


const _fireworksKnownModels: ManualMappings = [
// NOTE: we don't need manual patching as we have enough info for now
] as const;

const _fireworksDenyListContains: string[] = [
// nothing to deny for now
] as const;


function _prettyModelId(id: string, isVision: boolean): string {
// example: "accounts/fireworks/models/llama-v3p1-405b-instruct" => "Fireworks · Llama V3p1 405b Instruct"
let prettyName = id
.replace(/^accounts\//, '') // remove the leading "accounts/" if present
.replace(/\/models\//, ' · ') // turn the next "/models/" into " · "
.replaceAll(/[_-]/g, ' ') // replace underscores or dashes with spaces
.split(' ')
.filter(piece => piece !== 'instruct')
.map(serverCapitalizeFirstLetter)
.join(' ')
.replaceAll('/', ' · ') // replace any additional slash with " · "
.trim();
// add "Vision" to the name if it's a vision model
if (isVision && !id.includes('-vision'))
prettyName += ' Vision';
prettyName = prettyName.replace(' Vision', ' (Vision)');
return prettyName;
}


export function fireworksAIModelsToModelDescriptions(wireModels: unknown): ModelDescriptionSchema[] {
return wireFireworksAIListOutputSchema
.parse(wireModels)

.filter((model) => {
// filter-out non-llms
if (model.supports_chat === false)
return false;

return !_fireworksDenyListContains.some(contains => model.id.includes(contains));
})

.map((model): ModelDescriptionSchema => {

// heuristics
const label = _prettyModelId(model.id, !!model.supports_image_input);
const description = `${model.owned_by} \`${model.kind || 'unknown'}\` type.`;
const contextWindow = model.context_length || null;
const interfaces: DModelInterfaceV1[] = [LLM_IF_OAI_Chat];
if (model.supports_image_input)
interfaces.push(LLM_IF_OAI_Vision);
if (model.supports_tools)
interfaces.push(LLM_IF_OAI_Fn);

return fromManualMapping(_fireworksKnownModels, model.id, model.created, undefined, {
idPrefix: model.id,
label,
description,
contextWindow,
interfaces,
// parameterSpecs: ...
// maxCompletionTokens: ...
// trainingDataCutoff: ...
// benchmark: ...
// chatPrice,
hidden: false,
});
})

.sort((a: ModelDescriptionSchema, b: ModelDescriptionSchema): number => {
if (a.created !== b.created)
return (b.created || 0) - (a.created || 0);
return a.id.localeCompare(b.id);
});
}
8 changes: 7 additions & 1 deletion src/modules/llms/server/openai/openai.router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { TRPCError } from '@trpc/server';
import { createTRPCRouter, publicProcedure } from '~/server/trpc/trpc.server';
import { env } from '~/server/env.mjs';
import { fetchJsonOrTRPCThrow } from '~/server/trpc/trpc.router.fetchers';
import { serverCapitalizeFirstLetter } from '~/server/wire';

import { T2iCreateImageOutput, t2iCreateImagesOutputSchema } from '~/modules/t2i/t2i.server';

Expand All @@ -15,6 +16,7 @@ import { OpenAIWire_API_Images_Generations, OpenAIWire_API_Models_List, OpenAIWi
import { ListModelsResponse_schema, ModelDescriptionSchema } from '../llm.server.types';
import { azureModelToModelDescription, openAIModelFilter, openAIModelToModelDescription, openAISortModels } from './models/openai.models';
import { deepseekModelFilter, deepseekModelSort, deepseekModelToModelDescription } from './models/deepseek.models';
import { fireworksAIHeuristic, fireworksAIModelsToModelDescriptions } from './models/fireworksai.models';
import { groqModelFilter, groqModelSortFn, groqModelToModelDescription } from './models/groq.models';
import { lmStudioModelToModelDescription, localAIModelToModelDescription, localAIModelSortFn } from './models/models.data';
import { mistralModelsSort, mistralModelToModelDescription } from './models/mistral.models';
Expand All @@ -24,7 +26,6 @@ import { perplexityAIModelDescriptions, perplexityAIModelSort } from './models/p
import { togetherAIModelsToModelDescriptions } from './models/together.models';
import { wilreLocalAIModelsApplyOutputSchema, wireLocalAIModelsAvailableOutputSchema, wireLocalAIModelsListOutputSchema } from './localai.wiretypes';
import { xaiModelDescriptions, xaiModelSort } from './models/xai.models';
import { serverCapitalizeFirstLetter } from '~/server/wire';


const openAIDialects = z.enum([
Expand Down Expand Up @@ -190,6 +191,11 @@ export const llmOpenAIRouter = createTRPCRouter({

// [OpenAI]: chat-only models, custom sort, manual mapping
case 'openai':

// [FireworksAI] special case for model enumeration
if (fireworksAIHeuristic(access.oaiHost))
return { models: fireworksAIModelsToModelDescriptions(openAIModels) };

models = openAIModels

// limit to only 'gpt' and 'non instruct' models
Expand Down

0 comments on commit 9f372eb

Please sign in to comment.