fix: improve embeddings #496

Merged · 5 commits · Nov 22, 2024
Changes from all commits
75 changes: 75 additions & 0 deletions .github/workflows/pre-release.yml
@@ -0,0 +1,75 @@
name: Release

on:
  push:
    branches:
      - main
  workflow_dispatch:
    inputs:
      release_type:
        description: "Type of release (prerelease, prepatch, patch, minor, preminor, major)"
        required: true
        default: "prerelease"

jobs:
  release:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - uses: pnpm/action-setup@v3
        with:
          version: 8

      - name: Configure Git
        run: |
          git config user.name "${{ github.actor }}"
          git config user.email "${{ github.actor }}@users.noreply.github.com"

      - name: "Setup npm for npmjs"
        run: |
          npm config set registry https://registry.npmjs.org/
          echo "//registry.npmjs.org/:_authToken=${{ secrets.NPM_TOKEN }}" > ~/.npmrc

      - name: Install Protobuf Compiler
        run: sudo apt-get install -y protobuf-compiler

      - name: Install dependencies
        run: pnpm install

      - name: Build packages
        run: pnpm run build

      - name: Tag and Publish Packages
        id: tag_publish
        run: |
          RELEASE_TYPE=${{ github.event_name == 'push' && 'prerelease' || github.event.inputs.release_type }}
          npx lerna version $RELEASE_TYPE --conventional-commits --yes --no-private --force-publish
          npx lerna publish from-git --yes --dist-tag next

      - name: Get Version Tag
        id: get_tag
        run: echo "TAG=$(git describe --tags --abbrev=0)" >> $GITHUB_OUTPUT

      - name: Generate Release Body
        id: release_body
        run: |
          if [ -f CHANGELOG.md ]; then
            echo "body=$(cat CHANGELOG.md)" >> $GITHUB_OUTPUT
          else
            echo "body=No changelog provided for this release." >> $GITHUB_OUTPUT
          fi

      - name: Create GitHub Release
        uses: actions/create-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GH_TOKEN }}
          PNPM_HOME: /home/runner/setup-pnpm/node_modules/.bin
        with:
          tag_name: ${{ steps.get_tag.outputs.TAG }}
          release_name: Release
          body_path: CHANGELOG.md
          draft: false
          prerelease: ${{ github.event_name == 'push' }}
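A note on the release-type line in the Tag and Publish Packages step: GitHub Actions' `&&`/`||` expression operators act as a conditional select, so pushes to main always cut a prerelease while manual workflow_dispatch runs honor the chosen input. A minimal TypeScript sketch of the same selection logic (names are illustrative, not part of the workflow):

    // Mirrors: github.event_name == 'push' && 'prerelease' || github.event.inputs.release_type
    type ReleaseType =
        | "prerelease" | "prepatch" | "patch"
        | "minor" | "preminor" | "major";

    function resolveReleaseType(eventName: string, dispatchInput?: ReleaseType): ReleaseType {
        // Push events ignore the input entirely; dispatch runs use it,
        // falling back to "prerelease" (the input's declared default).
        return eventName === "push" ? "prerelease" : dispatchInput ?? "prerelease";
    }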
59 changes: 49 additions & 10 deletions packages/adapter-postgres/src/index.ts
@@ -11,6 +11,7 @@ import {
     type IDatabaseCacheAdapter,
     Participant,
     DatabaseAdapter,
+    elizaLogger,
 } from "@ai16z/eliza";
 import fs from "fs";
 import { fileURLToPath } from "url";
@@ -28,15 +29,50 @@ export class PostgresDatabaseAdapter
     constructor(connectionConfig: any) {
         super();

-        this.pool = new pg.Pool({
-            ...connectionConfig,
+        const defaultConfig = {
             max: 20,
             idleTimeoutMillis: 30000,
             connectionTimeoutMillis: 2000,
+        };
+
+        this.pool = new pg.Pool({
+            ...defaultConfig,
+            ...connectionConfig, // Allow overriding defaults
         });

-        this.pool.on("error", (err) => {
-            console.error("Unexpected error on idle client", err);
+        this.pool.on("error", async (err) => {
+            elizaLogger.error("Unexpected error on idle client", err);
+
+            // Attempt to reconnect with exponential backoff
+            let retryCount = 0;
+            const maxRetries = 5;
+            const baseDelay = 1000; // Start with 1 second delay
+
+            while (retryCount < maxRetries) {
+                try {
+                    const delay = baseDelay * Math.pow(2, retryCount);
+                    elizaLogger.log(`Attempting to reconnect in ${delay}ms...`);
+                    await new Promise((resolve) => setTimeout(resolve, delay));
+
+                    // Create new pool with same config
+                    this.pool = new pg.Pool(this.pool.options);
+                    await this.testConnection();
+
+                    elizaLogger.log("Successfully reconnected to database");
+                    return;
+                } catch (error) {
+                    retryCount++;
+                    elizaLogger.error(
+                        `Reconnection attempt ${retryCount} failed:`,
+                        error
+                    );
+                }
+            }
+
+            elizaLogger.error(
+                `Failed to reconnect after ${maxRetries} attempts`
+            );
+            throw new Error("Database connection lost and unable to reconnect");
         });
     }
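Because `connectionConfig` is spread after `defaultConfig`, caller-supplied options win. A hypothetical construction call (the connection fields here are illustrative assumptions, not part of this diff):

    // Hypothetical usage sketch
    const adapter = new PostgresDatabaseAdapter({
        connectionString: process.env.POSTGRES_URL, // assumed env var
        max: 50, // overrides the default of 20
        // idleTimeoutMillis / connectionTimeoutMillis keep their defaults
    });

With the defaults alone, the backoff delays on an idle-client error run 1s, 2s, 4s, 8s, 16s before the handler gives up.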

@@ -51,7 +87,7 @@ export class PostgresDatabaseAdapter
             );
             await client.query(schema);
         } catch (error) {
-            console.error(error);
+            elizaLogger.error(error);
             throw error;
         }
     }
@@ -61,10 +97,13 @@ export class PostgresDatabaseAdapter
         try {
             client = await this.pool.connect();
             const result = await client.query("SELECT NOW()");
-            console.log("Database connection test successful:", result.rows[0]);
+            elizaLogger.log(
+                "Database connection test successful:",
+                result.rows[0]
+            );
             return true;
         } catch (error) {
-            console.error("Database connection test failed:", error);
+            elizaLogger.error("Database connection test failed:", error);
             throw new Error(`Failed to connect to database: ${error.message}`);
         } finally {
             if (client) client.release();
@@ -187,7 +226,7 @@ export class PostgresDatabaseAdapter
         if (rows.length === 0) return null;

         const account = rows[0];
-        console.log("account", account);
+        elizaLogger.log("account", account);
         return {
             ...account,
             details:
@@ -217,7 +256,7 @@ export class PostgresDatabaseAdapter
             );
             return true;
         } catch (error) {
-            console.log("Error creating account", error);
+            elizaLogger.log("Error creating account", error);
             return false;
         } finally {
             client.release();
@@ -370,7 +409,7 @@ export class PostgresDatabaseAdapter
             values.push(params.count);
         }

-        console.log("sql", sql, values);
+        elizaLogger.log("sql", sql, values);

         const { rows } = await client.query(sql, values);
         return rows.map((row) => ({
@@ -251,12 +251,7 @@ const summarizeAction = {
         const model = models[runtime.character.settings.model];
         const chunkSize = model.settings.maxContextLength - 1000;

-        const chunks = await splitChunks(
-            formattedMemories,
-            chunkSize,
-            "gpt-4o-mini",
-            0
-        );
+        const chunks = await splitChunks(formattedMemories, chunkSize, 0);

         const datestr = new Date().toUTCString().replace(/:/g, "-");
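The updated call site drops the tokenizer-model argument, implying `splitChunks` now takes only the content, a chunk size, and an overlap value. A sketch of the assumed new signature (inferred from the call, not shown in this diff):

    // Assumed signature, inferred from the updated call site
    declare function splitChunks(
        content: string,
        chunkSize: number, // maximum size of each chunk
        bleed: number // overlap carried between adjacent chunks
    ): Promise<string[]>;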

1 change: 1 addition & 0 deletions packages/core/package.json
@@ -70,6 +70,7 @@
"gaxios": "6.7.1",
"glob": "11.0.0",
"js-sha1": "0.7.0",
"langchain": "^0.3.6",
"ollama-ai-provider": "^0.16.1",
"openai": "4.69.0",
"tiktoken": "1.0.17",
44 changes: 30 additions & 14 deletions packages/core/src/embedding.ts
@@ -37,7 +37,7 @@ async function getRemoteEmbedding(
                 : {}),
         },
         body: JSON.stringify({
-            input,
+            input: trimTokens(input, 8191, "gpt-4o-mini"),
             model: options.model,
             length: options.length || 384,
         }),
@@ -70,25 +70,39 @@ async function getRemoteEmbedding(
  * @param input The input to be embedded.
  * @returns The embedding of the input.
  */
+/**
+ * Generate embeddings for input text using configured model provider
+ * @param runtime The agent runtime containing model configuration
+ * @param input The text to generate embeddings for
+ * @returns Array of embedding numbers
+ */
 export async function embed(runtime: IAgentRuntime, input: string) {
+    // Get model provider configuration
     const modelProvider = models[runtime.character.modelProvider];
-    //need to have env override for this to select what to use for embedding if provider doesnt provide or using openai
+
+    // Determine which embedding model to use:
+    // 1. OpenAI if USE_OPENAI_EMBEDDING is true
+    // 2. Provider's own embedding model if available
+    // 3. Fallback to OpenAI embedding model
     const embeddingModel = settings.USE_OPENAI_EMBEDDING
-        ? "text-embedding-3-small" // Use OpenAI if specified
-        : modelProvider.model?.[ModelClass.EMBEDDING] || // Use provider's embedding model if available
-          models[ModelProviderName.OPENAI].model[ModelClass.EMBEDDING]; // Fallback to OpenAI
+        ? "text-embedding-3-small"
+        : modelProvider.model?.[ModelClass.EMBEDDING] ||
+          models[ModelProviderName.OPENAI].model[ModelClass.EMBEDDING];

     if (!embeddingModel) {
         throw new Error("No embedding model configured");
     }

-    // // Try local embedding first
-    // Check if we're in Node.js environment
+    // Check if running in Node.js environment
     const isNode =
         typeof process !== "undefined" &&
         process.versions != null &&
         process.versions.node != null;

+    // Use local embedding if:
+    // - Running in Node.js
+    // - Not using OpenAI provider
+    // - Not forcing OpenAI embeddings
     if (
         isNode &&
         runtime.character.modelProvider !== ModelProviderName.OPENAI &&
@@ -97,28 +111,30 @@ export async function embed(runtime: IAgentRuntime, input: string) {
         return await getLocalEmbedding(input);
     }

-    // Check cache
+    // Try to get cached embedding first
     const cachedEmbedding = await retrieveCachedEmbedding(runtime, input);
     if (cachedEmbedding) {
         return cachedEmbedding;
     }

-    // Get remote embedding
+    // Generate new embedding remotely
     return await getRemoteEmbedding(input, {
         model: embeddingModel,
+        // Use OpenAI endpoint if specified, otherwise use provider endpoint
         endpoint: settings.USE_OPENAI_EMBEDDING
-            ? "https://api.openai.com/v1" // Always use OpenAI endpoint when USE_OPENAI_EMBEDDING is true
+            ? "https://api.openai.com/v1"
             : runtime.character.modelEndpointOverride || modelProvider.endpoint,
+        // Use OpenAI API key if specified, otherwise use runtime token
         apiKey: settings.USE_OPENAI_EMBEDDING
-            ? settings.OPENAI_API_KEY // Use OpenAI key from settings when USE_OPENAI_EMBEDDING is true
-            : runtime.token, // Use runtime token for other providers
+            ? settings.OPENAI_API_KEY
+            : runtime.token,
+        // Special handling for Ollama provider
         isOllama:
             runtime.character.modelProvider === ModelProviderName.OLLAMA &&
             !settings.USE_OPENAI_EMBEDDING,
     });
 }

-// TODO: Add back in when it can work in browser and locally
 async function getLocalEmbedding(input: string): Promise<number[]> {
     // Check if we're in Node.js environment
     const isNode =
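Taken together, the selection rules make the embedding path fully env-driven. A hedged usage sketch (behavior as described by the comments above; nothing here is new API):

    // Illustrative only. With USE_OPENAI_EMBEDDING unset and a non-OpenAI
    // provider, a Node.js runtime takes the local-embedding path; otherwise
    // the remote path uses the provider's embedding model, falling back to
    // OpenAI's text-embedding-3-small.
    const vector: number[] = await embed(runtime, "text to embed");
    // The remote request body defaults to length 384 per getRemoteEmbedding.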
@@ -153,7 +169,7 @@ async function getLocalEmbedding(input: string): Promise<number[]> {
             cacheDir: cacheDir,
         });

-        const trimmedInput = trimTokens(input, 8000, "gpt-4o-mini");
+        const trimmedInput = trimTokens(input, 8191, "gpt-4o-mini");
         const embedding = await embeddingModel.queryEmbed(trimmedInput);
         return embedding;
     } else {
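The bump from 8000 to 8191 matches the 8191-token input limit on OpenAI's embedding models, the same cap now applied in getRemoteEmbedding above, so inputs are trimmed to the exact maximum rather than a conservative round number. A shared constant would keep the two call sites in sync (hypothetical refactor, not part of this diff):

    // Hypothetical: one constant for both trim sites
    const MAX_EMBEDDING_INPUT_TOKENS = 8191; // OpenAI embedding-model input cap
    const trimmedInput = trimTokens(input, MAX_EMBEDDING_INPUT_TOKENS, "gpt-4o-mini");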