diff --git a/src/main.rs b/src/main.rs index ab4738e..cfd2cf1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,7 +11,7 @@ use crate::{ error::ServerError, server::RustDocsServer, // Import the updated RustDocsServer }; -use async_openai::Client as OpenAIClient; +use async_openai::{Client as OpenAIClient, config::OpenAIConfig}; use bincode::config; use cargo::core::PackageIdSpec; use clap::Parser; // Import clap Parser @@ -182,7 +182,12 @@ async fn main() -> Result<(), ServerError> { let mut documents_for_server: Vec = loaded_documents_from_cache.unwrap_or_default(); // --- Initialize OpenAI Client (needed for question embedding even if cache hit) --- - let openai_client = OpenAIClient::new(); + let openai_client = if let Ok(api_base) = env::var("OPENAI_API_BASE") { + let config = OpenAIConfig::new().with_api_base(api_base); + OpenAIClient::with_config(config) + } else { + OpenAIClient::new() + }; OPENAI_CLIENT .set(openai_client.clone()) // Clone the client for the OnceCell .expect("Failed to set OpenAI client"); @@ -209,12 +214,10 @@ async fn main() -> Result<(), ServerError> { documents_for_server = loaded_documents.clone(); eprintln!("Generating embeddings..."); - let (generated_embeddings, total_tokens) = generate_embeddings( - &openai_client, - &loaded_documents, - "text-embedding-3-small", - ) - .await?; + let embedding_model: String = env::var("EMBEDDING_MODEL") + .unwrap_or_else(|_| "text-embedding-3-small".to_string()); + let (generated_embeddings, total_tokens) = + generate_embeddings(&openai_client, &loaded_documents, &embedding_model).await?; let cost_per_million = 0.02; let estimated_cost = (total_tokens as f64 / 1_000_000.0) * cost_per_million; diff --git a/src/server.rs b/src/server.rs index eea72e9..9e886ca 100644 --- a/src/server.rs +++ b/src/server.rs @@ -175,8 +175,10 @@ impl RustDocsServer { .get() .ok_or_else(|| McpError::internal_error("OpenAI client not initialized", None))?; + let embedding_model: String = + env::var("EMBEDDING_MODEL").unwrap_or_else(|_| "text-embedding-3-small".to_string()); let question_embedding_request = CreateEmbeddingRequestArgs::default() - .model("text-embedding-3-small") + .model(embedding_model) .input(question.to_string()) .build() .map_err(|e| { @@ -223,8 +225,10 @@ impl RustDocsServer { doc.content, question ); + let llm_model: String = env::var("LLM_MODEL") + .unwrap_or_else(|_| "gpt-4o-mini-2024-07-18".to_string()); let chat_request = CreateChatCompletionRequestArgs::default() - .model("gpt-4o-mini-2024-07-18") + .model(llm_model) .messages(vec![ ChatCompletionRequestSystemMessageArgs::default() .content(system_prompt) @@ -379,5 +383,4 @@ impl ServerHandler for RustDocsServer { resource_templates: Vec::new(), // No templates defined yet }) } - }