diff --git a/nah/src/chat.rs b/nah/src/chat.rs index ca48537..f277d38 100644 --- a/nah/src/chat.rs +++ b/nah/src/chat.rs @@ -17,7 +17,7 @@ use crate::AppContext; use crate::ModelConfig; use futures_util::pin_mut; use futures_util::stream::StreamExt; -use nah_chat::{ChatClient, ChatMessage, ToolCallRequest}; +use nah_chat::{ChatClient, ChatCompletionParamsBuilder, ChatMessage, ToolCallRequest}; use serde_json::{json, Value}; use std::fs::{File, OpenOptions}; use tokio::runtime::{Builder, Runtime}; @@ -197,14 +197,13 @@ impl ChatContext { * Generate assistant message. */ pub fn generate(&mut self) -> Result<&ChatMessage, NahError> { - let mut params = HashMap::from([ - ("max_token".to_owned(), json!(4096)), - ("tools".to_owned(), json!(self.tools.clone())), - ("n".to_owned(), json!(1)), - ("temperature".to_owned(), json!(0.7)), - ("top_p".to_owned(), json!(0.9)), - ("frequency_penalty".to_owned(), json!(0.5)), - ]); + let mut params_builder = ChatCompletionParamsBuilder::new(); + params_builder + .max_token(4096) + .temperature(0.7) + .top_p(0.9) + .frequency_penalty(0.5) + .insert("tools", json!(self.tools.clone())); self .model_config @@ -213,7 +212,7 @@ impl ChatContext { .and_then(|v| v.as_object()) .and_then(|extra_params| { extra_params.iter().for_each(|(key, value)| { - params.insert(key.to_owned(), value.to_owned()); + params_builder.insert(key, value.to_owned()); }); Some(()) }); @@ -221,7 +220,7 @@ impl ChatContext { let message: Result = self.tokio_runtime.block_on(async { let stream = match self .chat_client - .chat_completion_stream(&self.model_config.model, &self.messages, ¶ms) + .chat_completion_stream(&self.model_config.model, &self.messages, ¶ms_builder) .await { Ok(s) => s, diff --git a/nah_chat/src/lib.rs b/nah_chat/src/lib.rs index 987b07e..f314fd6 100644 --- a/nah_chat/src/lib.rs +++ b/nah_chat/src/lib.rs @@ -61,7 +61,7 @@ use async_stream::stream; use bytes::Bytes; use futures_core::stream::Stream; use serde::{Deserialize, Serialize}; -use serde_json::{Value, json}; +use serde_json::{Number, Value, json}; /** * Error kinds that may occur in `nah_chat`. @@ -282,6 +282,93 @@ pub struct FunctionCallRequestChunkDelta { pub arguments: Option, } +/** + * A builder for creating parameters for chat completion requests. + */ +#[derive(Debug, Clone)] +pub struct ChatCompletionParamsBuilder { + data: std::collections::HashMap, +} + +impl ChatCompletionParamsBuilder { + /** + * Initialize a [ChatCompletionParamsBuilder] object. + */ + pub fn new() -> Self { + ChatCompletionParamsBuilder { + data: std::collections::HashMap::new(), + } + } + + /** + * Consume the data builder to get a hash map of the parameters for chat completion requests. + */ + pub fn build(self) -> std::collections::HashMap { + self.data + } + + /** + * Set max token parameter. + */ + pub fn max_token(&mut self, n: usize) -> &mut Self { + self.data.insert( + "max_token".to_owned(), + Value::Number(Number::from_u128(n as u128).unwrap()), + ); + self + } + + /** + * Set temperature parameter. + */ + pub fn temperature(&mut self, t: f64) -> &mut Self { + self.data.insert( + "temperature".to_owned(), + Value::Number(Number::from_f64(t).unwrap()), + ); + self + } + + /** + * Set top_p parameter. + */ + pub fn top_p(&mut self, p: f64) -> &mut Self { + self.data.insert( + "top_p".to_owned(), + Value::Number(Number::from_f64(p).unwrap()), + ); + self + } + + /** + * Set frequency_penalty parameter. + */ + pub fn frequency_penalty(&mut self, p: f64) -> &mut Self { + self.data.insert( + "frequency_penalty".to_owned(), + Value::Number(Number::from_f64(p).unwrap()), + ); + self + } + + /** + * Set a parameter with key of `name` and value of `value`. + */ + pub fn insert(&mut self, name: &str, value: Value) -> &mut Self { + self.data.insert(name.to_owned(), value); + self + } +} + +impl<'a> std::iter::IntoIterator for &'a ChatCompletionParamsBuilder { + type Item = (&'a String, &'a Value); + type IntoIter = std::collections::hash_map::Iter<'a, String, Value>; + + fn into_iter(self) -> Self::IntoIter { + (&self.data).into_iter() + } +} + /** * The object to hold information about the model server and `reqwest` HTTP client. */