Mistral's models behave strangely if `tool_choice` is set to anything other than `any`: they put the tool call inside the message `content` instead of emitting an actual tool call. Because I cannot trust the model to stop on its own, this client uses a dedicated "end" tool call as its `stop` mechanism. I still like this model: it's cheap, fast, and open source, and its answers are pretty good.
268 lines · 5.8 KiB · Go
package client
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
|
|
"github.com/charmbracelet/log"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// ResponseFormat is the "response_format" object of a chat-completions
// request. Both fields are always serialized; an unset JsonSchema is sent
// as JSON null.
type ResponseFormat struct {
	// Type selects the output mode (RunAgent sends "text").
	Type string `json:"type"`
	// JsonSchema is an optional schema, serialized under "json_schema".
	JsonSchema any `json:"json_schema"`
}
|
|
|
|
type AgentRequestBody struct {
|
|
Model string `json:"model"`
|
|
Temperature float64 `json:"temperature"`
|
|
ResponseFormat ResponseFormat `json:"response_format"`
|
|
|
|
Tools *any `json:"tools,omitempty"`
|
|
ToolChoice *string `json:"tool_choice,omitempty"`
|
|
|
|
EndToolCall string `json:"-"`
|
|
|
|
Chat *Chat `json:"messages"`
|
|
}
|
|
|
|
func (req AgentRequestBody) MarshalJSON() ([]byte, error) {
|
|
return json.Marshal(&struct {
|
|
Model string `json:"model"`
|
|
Temperature float64 `json:"temperature"`
|
|
ResponseFormat ResponseFormat `json:"response_format"`
|
|
|
|
Tools *any `json:"tools,omitempty"`
|
|
ToolChoice *string `json:"tool_choice,omitempty"`
|
|
Messages []ChatMessage `json:"messages"`
|
|
}{
|
|
Model: req.Model,
|
|
Temperature: req.Temperature,
|
|
ResponseFormat: req.ResponseFormat,
|
|
|
|
Tools: req.Tools,
|
|
ToolChoice: req.ToolChoice,
|
|
|
|
Messages: req.Chat.Messages,
|
|
})
|
|
}
|
|
|
|
type ResponseChoice struct {
|
|
Index int `json:"index"`
|
|
Message ChatAiMessage `json:"message"`
|
|
FinishReason string `json:"finish_reason"`
|
|
}
|
|
|
|
type AgentResponse struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Choices []ResponseChoice `json:"choices"`
|
|
Created int `json:"created"`
|
|
}
|
|
|
|
type AgentClient struct {
|
|
url string
|
|
apiKey string
|
|
responseFormat string
|
|
|
|
ToolHandler ToolsHandlers
|
|
|
|
Log *log.Logger
|
|
|
|
Reply string
|
|
|
|
Do func(req *http.Request) (*http.Response, error)
|
|
|
|
Options CreateAgentClientOptions
|
|
}
|
|
|
|
// OPENAI_API_KEY is the name of the environment variable CreateAgentClient
// reads the API key from.
// NOTE(review): the key is actually sent to the Mistral endpoint; the
// identifier and its value are left as-is to avoid breaking callers and
// existing environments.
const (
	OPENAI_API_KEY = "OPENAI_API_KEY"
)
|
|
|
|
type CreateAgentClientOptions struct {
|
|
Log *log.Logger
|
|
SystemPrompt string
|
|
JsonTools string
|
|
EndToolCall string
|
|
Query *string
|
|
}
|
|
|
|
func CreateAgentClient(options CreateAgentClientOptions) AgentClient {
|
|
apiKey := os.Getenv(OPENAI_API_KEY)
|
|
|
|
if len(apiKey) == 0 {
|
|
panic("No api key")
|
|
}
|
|
|
|
return AgentClient{
|
|
apiKey: apiKey,
|
|
url: "https://api.mistral.ai/v1/chat/completions",
|
|
Do: func(req *http.Request) (*http.Response, error) {
|
|
client := &http.Client{}
|
|
return client.Do(req)
|
|
},
|
|
|
|
Log: options.Log,
|
|
|
|
ToolHandler: ToolsHandlers{
|
|
handlers: map[string]ToolHandler{},
|
|
},
|
|
|
|
Options: options,
|
|
}
|
|
}
|
|
|
|
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
|
|
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body))
|
|
if err != nil {
|
|
return req, err
|
|
}
|
|
|
|
req.Header.Add("Authorization", "Bearer "+client.apiKey)
|
|
req.Header.Add("Content-Type", "application/json")
|
|
|
|
return req, nil
|
|
}
|
|
|
|
func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error) {
|
|
jsonAiRequest, err := json.Marshal(req)
|
|
if err != nil {
|
|
return AgentResponse{}, err
|
|
}
|
|
|
|
httpRequest, err := client.getRequest(jsonAiRequest)
|
|
if err != nil {
|
|
return AgentResponse{}, err
|
|
}
|
|
|
|
resp, err := client.Do(httpRequest)
|
|
if err != nil {
|
|
return AgentResponse{}, err
|
|
}
|
|
|
|
response, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return AgentResponse{}, err
|
|
}
|
|
|
|
agentResponse := AgentResponse{}
|
|
err = json.Unmarshal(response, &agentResponse)
|
|
|
|
if err != nil {
|
|
return AgentResponse{}, err
|
|
}
|
|
|
|
if len(agentResponse.Choices) != 1 {
|
|
client.Log.Errorf("Received more than 1 choice from AI \n %s\n", string(response))
|
|
return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
|
|
}
|
|
|
|
msg := agentResponse.Choices[0].Message
|
|
req.Chat.AddAiResponse(msg)
|
|
|
|
return agentResponse, nil
|
|
}
|
|
|
|
func (client *AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error {
|
|
for {
|
|
response, err := client.Request(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if response.Choices[0].FinishReason == "stop" {
|
|
client.Log.Debug("Agent is finished")
|
|
return nil
|
|
}
|
|
|
|
err = client.Process(info, req)
|
|
|
|
if err != nil {
|
|
|
|
if err == FinishedCall {
|
|
client.Log.Debug("Agent is finished")
|
|
}
|
|
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
// FinishedCall is returned by Process, and propagated by ToolLoop, when the
// model calls the request's EndToolCall tool. It signals that the agent is
// done; compare against it with errors.Is.
var FinishedCall = errors.New("last tool was called")
|
|
|
|
func (client *AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
|
|
var err error
|
|
|
|
message, err := req.Chat.GetLatest()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
aiMessage, ok := message.(ChatAiMessage)
|
|
if !ok {
|
|
return errors.New("Latest message isnt an AI message")
|
|
}
|
|
|
|
if aiMessage.ToolCalls == nil {
|
|
// Not an error, we just dont have any tool calls to process.
|
|
return nil
|
|
}
|
|
|
|
for _, toolCall := range *aiMessage.ToolCalls {
|
|
if toolCall.Function.Name == req.EndToolCall {
|
|
return FinishedCall
|
|
}
|
|
|
|
toolResponse := client.ToolHandler.Handle(info, toolCall)
|
|
|
|
if toolCall.Function.Name == "reply" {
|
|
client.Reply = toolCall.Function.Arguments
|
|
}
|
|
|
|
client.Log.Debug("Tool call response", "toolCall", toolCall.Function.Name, "response", toolResponse.Content)
|
|
|
|
req.Chat.AddToolResponse(toolResponse)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (client *AgentClient) RunAgent(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
|
var tools any
|
|
err := json.Unmarshal([]byte(client.Options.JsonTools), &tools)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
toolChoice := "any"
|
|
|
|
request := AgentRequestBody{
|
|
Tools: &tools,
|
|
ToolChoice: &toolChoice,
|
|
Model: "pixtral-12b-2409",
|
|
Temperature: 0.3,
|
|
EndToolCall: client.Options.EndToolCall,
|
|
ResponseFormat: ResponseFormat{
|
|
Type: "text",
|
|
},
|
|
Chat: &Chat{
|
|
Messages: make([]ChatMessage, 0),
|
|
},
|
|
}
|
|
|
|
request.Chat.AddSystem(client.Options.SystemPrompt)
|
|
request.Chat.AddImage(imageName, imageData, client.Options.Query)
|
|
|
|
toolHandlerInfo := ToolHandlerInfo{
|
|
ImageId: imageId,
|
|
ImageName: imageName,
|
|
UserId: userId,
|
|
Image: &imageData,
|
|
}
|
|
|
|
return client.ToolLoop(toolHandlerInfo, &request)
|
|
}
|