diff --git a/openai-hs/src/OpenAI/Client.hs b/openai-hs/src/OpenAI/Client.hs index d1a4f6b..5fc557a 100644 --- a/openai-hs/src/OpenAI/Client.hs +++ b/openai-hs/src/OpenAI/Client.hs @@ -31,6 +31,8 @@ module OpenAI.Client ChatFunction (..), ChatFunctionCall (..), ChatFunctionCallStrategy (..), + ChatMessageContent (..), + ChatMessageContentPart (..), ChatMessage (..), ChatCompletionRequest (..), ChatChoice (..), diff --git a/openai-servant/src/OpenAI/Resources.hs b/openai-servant/src/OpenAI/Resources.hs index 76275c2..2d06d76 100644 --- a/openai-servant/src/OpenAI/Resources.hs +++ b/openai-servant/src/OpenAI/Resources.hs @@ -23,6 +23,8 @@ module OpenAI.Resources ChatFunction (..), ChatFunctionCall (..), ChatFunctionCallStrategy (..), + ChatMessageContent (..), + ChatMessageContentPart (..), ChatMessage (..), ChatCompletionRequest (..), ChatChoice (..), @@ -86,8 +88,10 @@ module OpenAI.Resources where import qualified Data.Aeson as A +import qualified Data.Aeson.Types as A import qualified Data.ByteString.Lazy as BSL import Data.Maybe (catMaybes) +import Data.String (IsString(fromString)) import qualified Data.Text as T import qualified Data.Text.Encoding as T import Data.Time @@ -251,8 +255,54 @@ instance A.ToJSON ChatFunctionCall where "arguments" A..= T.decodeUtf8 (BSL.toStrict (A.encode arguments)) ] +newtype ChatMessageContent = ChatMessageContent [ChatMessageContentPart] + deriving (Eq, Show, Semigroup, Monoid) + +instance IsString ChatMessageContent where + fromString s = ChatMessageContent [fromString s] + +instance A.FromJSON ChatMessageContent where + parseJSON (A.String s) = pure $ ChatMessageContent [CMCP_Text s] + parseJSON (A.Array a) = ChatMessageContent <$> traverse A.parseJSON (V.toList a) + parseJSON invalid = A.typeMismatch "String or Array" invalid + +instance A.ToJSON ChatMessageContent where + toJSON (ChatMessageContent xs) = go xs where + go [] = A.Array V.empty + go [CMCP_Text s] = A.String s + go xs' = A.Array . V.fromList $ map A.toJSON xs' + +data ChatMessageContentPart = CMCP_Text T.Text + | CMCP_Image { imageUrl :: T.Text, imageDetail :: Maybe T.Text } + deriving (Eq, Show) + +instance IsString ChatMessageContentPart where + fromString = CMCP_Text . T.pack + +instance A.FromJSON ChatMessageContentPart where + parseJSON (A.String s) = pure $ CMCP_Text s + parseJSON (A.Object o) = o A..: "type" >>= A.withText "Type" go where + go "text" = CMCP_Text <$> o A..: "text" + go "image_url" = o A..: "image_url" >>= img + go ty = A.unexpected $ A.String ty + img (A.String s) = pure $ CMCP_Image s Nothing + img (A.Object o') = CMCP_Image <$> o' A..: "url" <*> o' A..:? "detail" + img invalid = A.typeMismatch "String or Object" invalid + parseJSON invalid = A.typeMismatch "ChatMessageContentPart" invalid + +instance A.ToJSON ChatMessageContentPart where + toJSON (CMCP_Text s) = A.object ["type" A..= ("text" :: T.Text), "text" A..= s] + toJSON (CMCP_Image u d) = + A.object [ "type" A..= ("image_url" :: T.Text), + "image_url" A..= case d of + Nothing -> A.String u + Just detail -> A.object [ "url" A..= u, + "detail" A..= detail + ] + ] + data ChatMessage = ChatMessage - { chmContent :: Maybe T.Text, + { chmContent :: Maybe ChatMessageContent, chmRole :: T.Text, chmFunctionCall :: Maybe ChatFunctionCall, chmName :: Maybe T.Text