Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update when needed #861

Merged
merged 2 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 41 additions & 99 deletions shinkai-bin/shinkai-node/src/managers/tool_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provide
use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
use shinkai_message_primitives::schemas::shinkai_preferences::ShinkaiInternalComms;
use shinkai_message_primitives::schemas::shinkai_tool_offering::{
AssetPayment, ToolPrice, UsageType, UsageTypeInquiry
AssetPayment, ToolPrice, UsageType, UsageTypeInquiry,
};
use shinkai_message_primitives::schemas::shinkai_tools::CodeLanguage;
use shinkai_message_primitives::schemas::wallet_mixed::{Asset, NetworkIdentifier};
Expand Down Expand Up @@ -224,108 +224,42 @@ impl ToolRouter {
let signing_secret_key = signing_secret_key.clone();
async move {
// Try to see if a tool with the same routerKey is already installed.
match db.get_tool_by_key(router_key) {
let do_install = match db.get_tool_by_key(router_key) {
Ok(existing_tool) => {
// Compare version numbers:
// The local version is existing_tool.version(),
// the remote version is new_version (string from the JSON).
// We parse them into IndexableVersion and compare.
let local_ver = match existing_tool.version_indexable() {
Ok(iv) => iv,
Err(e) => {
eprintln!("Failed to parse local version: {}", e);
return Ok::<(), ToolError>(());
}
};
let remote_ver = match IndexableVersion::from_string(new_version) {
Ok(iv) => iv,
Err(e) => {
eprintln!("Failed to parse remote version: {}", e);
return Ok::<(), ToolError>(());
}
};

if remote_ver > local_ver {
eprintln!(
"A newer version is available for tool {} (local: {}, remote: {}). Upgrading...",
router_key,
local_ver.to_string(),
remote_ver.to_string()
);

// Node::v2_api_import_tool_internal fetches the new tool code,
// builds a new ShinkaiTool, and returns it.
match Node::v2_api_import_tool_internal(
db.clone(),
node_env.clone(),
tool_url.to_string(),
node_name,
signing_secret_key,
)
.await
{
Ok(val) => {
// We stored the tool under val["tool"] in the JSON response
let new_tool: ShinkaiTool =
match serde_json::from_value::<ShinkaiTool>(val["tool"].clone()) {
Ok(tool) => tool,
Err(err) => {
eprintln!("Couldn't parse 'tool' field as ShinkaiTool: {}", err);
return Ok(());
}
};

match db.upgrade_tool(new_tool).await {
Ok(_) => {
println!("Upgraded tool {} to version {}", router_key, new_version);
}
Err(e) => {
eprintln!("Failed to upgrade tool {}: {:?}", router_key, e);
}
}
}
Err(e) => {
eprintln!("Failed to download tool {} for upgrade: {:?}", router_key, e);
}
}
} else {
// Versions are the same or the local one is newer:
println!("Tool already up-to-date: {} (version: {})", router_key, local_ver);
}
let local_ver = existing_tool.version_indexable()?;
let remote_ver = IndexableVersion::from_string(new_version)?;
Ok(remote_ver > local_ver)
}
Err(SqliteManagerError::ToolNotFound(_)) => {
// If the tool isn't found locally, import it anew
match Node::v2_api_import_tool_internal(
db.clone(),
node_env.clone(),
tool_url.to_string(),
node_name,
signing_secret_key,
)
.await
{
Ok(val) => {
// We stored the tool under val["tool"] in the JSON response
match serde_json::from_value::<ShinkaiTool>(val["tool"].clone()) {
Ok(_tool) => {
println!(
"Successfully imported tool {} (version: {})",
tool_name, new_version
);
}
Err(err) => {
eprintln!("Couldn't parse 'tool' field as ShinkaiTool: {}", err);
}
}
}
Err(e) => {
eprintln!("Failed to import tool {}: {:?}", tool_name, e);
}
}
Err(SqliteManagerError::ToolNotFound(_)) => Ok(true), // Update needed
Err(e) => Err(ToolError::DatabaseError(e.to_string())),
}?;

if !do_install {
// Skip installation
return Ok::<(), ToolError>(());
}

let val: Value = Node::v2_api_import_tool_internal(
db.clone(),
node_env.clone(),
tool_url.to_string(),
node_name,
signing_secret_key,
)
.await
.map_err(|e| ToolError::ExecutionError(e.message))?;

// We stored the tool under val["tool"] in the JSON response
match serde_json::from_value::<ShinkaiTool>(val["tool"].clone()) {
Ok(_tool) => {
println!("Successfully imported tool {} (version: {})", tool_name, new_version);
}
Err(e) => {
// Some DB error or other
eprintln!("Failed to get tool {}: {:?}", router_key, e);
Err(err) => {
eprintln!("Couldn't parse 'tool' field as ShinkaiTool: {}", err);
}
}
Ok::<(), ToolError>(())
Expand Down Expand Up @@ -407,9 +341,17 @@ impl ToolRouter {
None,
tool.tool_router_key,
);
if let Err(e) = self.sqlite_manager.add_tool(ShinkaiTool::Rust(rust_tool, true)).await {
eprintln!("Error adding rust tool: {}", e);
}

let _ = match self.sqlite_manager.get_tool_by_key(&rust_tool.tool_router_key) {
// TODO We have no good mechanism to check if the tool is up to date.
Err(SqliteManagerError::ToolNotFound(_)) => self
.sqlite_manager
.add_tool(ShinkaiTool::Rust(rust_tool, true))
.await
.map_err(|e| ToolError::DatabaseError(e.to_string())),
Err(e) => Err(ToolError::DatabaseError(e.to_string())),
Ok(_db_tool) => continue,
}?;
}
Ok(())
}
Expand Down
153 changes: 91 additions & 62 deletions shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1883,77 +1883,106 @@ impl Node {
tool.enable();
}

// check if any version of the tool exists in the database
let db_tool = match db.get_tool_by_key(&tool.tool_router_key().to_string_without_version()) {
Ok(tool) => Some(tool),
Err(_) => None,
};

// if the tool exists in the database, check if the version is the same or newer
if let Some(db_tool) = db_tool.clone() {
let version_db = db_tool.version_number()?;
let version_zip = tool.version_number()?;
if version_db >= version_zip {
// No need to update
return Ok(json!({
"status": "success",
"message": "Tool already up-to-date",
"tool_key": tool.tool_router_key().to_string_without_version(),
"tool": tool.clone()
}));
}
}

// Save the tool to the database
match db.add_tool(tool).await {
Ok(tool) => {
let archive_clone = zip_contents.archive.clone();
let files = archive_clone.file_names();
for file in files {
if file == "__tool.json" {
continue;
}
let mut buffer = Vec::new();
{
let file = zip_contents.archive.by_name(file);
let mut tool_file = match file {
Ok(file) => file,
Err(_) => {
return Err(APIError {
code: StatusCode::BAD_REQUEST.as_u16(),
error: "Invalid Tool Archive".to_string(),
message: "Archive does not contain tool.json".to_string(),
});
}
};

// Read the tool file contents into a buffer
if let Err(err) = tool_file.read_to_end(&mut buffer) {
return Err(APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Read Error".to_string(),
message: format!("Failed to read tool.json contents: {}", err),
});
}
} // `tool_file` goes out of scope here
let tool = match db_tool {
None => db.add_tool(tool).await.map_err(|e| APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Database Error".to_string(),
message: format!("Failed to save tool to database: {}", e),
})?,
Some(_) => db.upgrade_tool(tool).await.map_err(|e| APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Database Error".to_string(),
message: format!("Failed to upgrade tool: {}", e),
})?,
};

let mut file_path = PathBuf::from(&node_env.node_storage_path.clone().unwrap_or_default())
.join(".tools_storage")
.join("tools")
.join(tool.tool_router_key().convert_to_path());
if !file_path.exists() {
let s = std::fs::create_dir_all(&file_path);
if s.is_err() {
return Err(APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Failed to create directory".to_string(),
message: format!("Failed to create directory: {}", s.err().unwrap()),
});
}
}
file_path.push(file);
let s = std::fs::write(&file_path, &buffer);
if s.is_err() {
let archive_clone = zip_contents.archive.clone();
let files = archive_clone.file_names();
for file in files {
if file.contains("__MACOSX/") {
continue;
}
if file == "__tool.json" {
continue;
}
let mut buffer = Vec::new();
{
let file = zip_contents.archive.by_name(file);
let mut tool_file = match file {
Ok(file) => file,
Err(_) => {
return Err(APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Failed to write file".to_string(),
message: format!("Failed to write file: {}", s.err().unwrap()),
code: StatusCode::BAD_REQUEST.as_u16(),
error: "Invalid Tool Archive".to_string(),
message: "Archive does not contain tool.json".to_string(),
});
}
}
};

Ok(json!({
"status": "success",
"message": "Tool imported successfully",
"tool_key": tool.tool_router_key().to_string_without_version(),
"tool": tool
}))
// Read the tool file contents into a buffer
if let Err(err) = tool_file.read_to_end(&mut buffer) {
return Err(APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Read Error".to_string(),
message: format!("Failed to read tool.json contents: {}", err),
});
}
} // `tool_file` goes out of scope here

let mut file_path = PathBuf::from(&node_env.node_storage_path.clone().unwrap_or_default())
.join(".tools_storage")
.join("tools")
.join(tool.tool_router_key().convert_to_path());
if !file_path.exists() {
let s = std::fs::create_dir_all(&file_path);
if s.is_err() {
return Err(APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Failed to create directory".to_string(),
message: format!("Failed to create directory: {}", s.err().unwrap()),
});
}
}
file_path.push(file);
let s = std::fs::write(&file_path, &buffer);
if s.is_err() {
return Err(APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Failed to write file".to_string(),
message: format!("Failed to write file: {}", s.err().unwrap()),
});
}
Err(err) => Err(APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Database Error".to_string(),
message: format!("Failed to save tool to database: {}", err),
}),
}

Ok(json!({
"status": "success",
"message": "Tool imported successfully",
"tool_key": tool.tool_router_key().to_string_without_version(),
"tool": tool
}))
}

/// Resolves a Shinkai file protocol URL into actual file bytes.
Expand Down