Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions nah/src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::AppContext;
use crate::ModelConfig;
use futures_util::pin_mut;
use futures_util::stream::StreamExt;
use nah_chat::{ChatClient, ChatMessage, ToolCallRequest};
use nah_chat::{ChatClient, ChatCompletionParamsBuilder, ChatMessage, ToolCallRequest};
use serde_json::{json, Value};
use std::fs::{File, OpenOptions};
use tokio::runtime::{Builder, Runtime};
Expand Down Expand Up @@ -197,14 +197,13 @@ impl ChatContext {
* Generate assistant message.
*/
pub fn generate(&mut self) -> Result<&ChatMessage, NahError> {
let mut params = HashMap::from([
("max_token".to_owned(), json!(4096)),
("tools".to_owned(), json!(self.tools.clone())),
("n".to_owned(), json!(1)),
("temperature".to_owned(), json!(0.7)),
("top_p".to_owned(), json!(0.9)),
("frequency_penalty".to_owned(), json!(0.5)),
]);
let mut params_builder = ChatCompletionParamsBuilder::new();
params_builder
.max_token(4096)
.temperature(0.7)
.top_p(0.9)
.frequency_penalty(0.5)
.insert("tools", json!(self.tools.clone()));

self
.model_config
Expand All @@ -213,15 +212,15 @@ impl ChatContext {
.and_then(|v| v.as_object())
.and_then(|extra_params| {
extra_params.iter().for_each(|(key, value)| {
params.insert(key.to_owned(), value.to_owned());
params_builder.insert(key, value.to_owned());
});
Some(())
});

let message: Result<ChatMessage, NahError> = self.tokio_runtime.block_on(async {
let stream = match self
.chat_client
.chat_completion_stream(&self.model_config.model, &self.messages, &params)
.chat_completion_stream(&self.model_config.model, &self.messages, &params_builder)
.await
{
Ok(s) => s,
Expand Down
89 changes: 88 additions & 1 deletion nah_chat/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ use async_stream::stream;
use bytes::Bytes;
use futures_core::stream::Stream;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use serde_json::{Number, Value, json};

/**
* Error kinds that may occur in `nah_chat`.
Expand Down Expand Up @@ -282,6 +282,93 @@ pub struct FunctionCallRequestChunkDelta {
pub arguments: Option<String>,
}

/**
* A builder for creating parameters for chat completion requests.
*/
#[derive(Debug, Clone)]
pub struct ChatCompletionParamsBuilder {
data: std::collections::HashMap<String, Value>,
}

impl ChatCompletionParamsBuilder {
/**
* Initialize a [ChatCompletionParamsBuilder] object.
*/
pub fn new() -> Self {
ChatCompletionParamsBuilder {
data: std::collections::HashMap::new(),
}
}

/**
* Consume the data builder to get a hash map of the parameters for chat completion requests.
*/
pub fn build(self) -> std::collections::HashMap<String, Value> {
self.data
}

/**
* Set max token parameter.
*/
pub fn max_token(&mut self, n: usize) -> &mut Self {
self.data.insert(
"max_token".to_owned(),
Value::Number(Number::from_u128(n as u128).unwrap()),
);
self
}

/**
* Set temperature parameter.
*/
pub fn temperature(&mut self, t: f64) -> &mut Self {
self.data.insert(
"temperature".to_owned(),
Value::Number(Number::from_f64(t).unwrap()),
);
self
}

/**
* Set top_p parameter.
*/
pub fn top_p(&mut self, p: f64) -> &mut Self {
self.data.insert(
"top_p".to_owned(),
Value::Number(Number::from_f64(p).unwrap()),
);
self
}

/**
* Set frequency_penalty parameter.
*/
pub fn frequency_penalty(&mut self, p: f64) -> &mut Self {
self.data.insert(
"frequency_penalty".to_owned(),
Value::Number(Number::from_f64(p).unwrap()),
);
self
}

/**
* Set a parameter with key of `name` and value of `value`.
*/
pub fn insert(&mut self, name: &str, value: Value) -> &mut Self {
self.data.insert(name.to_owned(), value);
self
}
}

impl<'a> std::iter::IntoIterator for &'a ChatCompletionParamsBuilder {
type Item = (&'a String, &'a Value);
type IntoIter = std::collections::hash_map::Iter<'a, String, Value>;

fn into_iter(self) -> Self::IntoIter {
(&self.data).into_iter()
}
}

/**
* The object to hold information about the model server and `reqwest` HTTP client.
*/
Expand Down