2025-04-09 13:56:30 +01:00

152 lines
3.2 KiB
Go

package client
import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"time"
)
// ResponseFormat is encoded as the "response_format" object of an
// AgentRequestBody. Type is written under "type" and JsonSchema under
// "json_schema"; neither key is omitted, so an unset JsonSchema is sent as
// null.
type ResponseFormat struct {
	Type       string `json:"type"`
	JsonSchema any    `json:"json_schema"`
}
// AgentRequestBody is the JSON payload that AgentClient.Request marshals and
// POSTs to the chat-completions endpoint.
//
// Tools and ToolChoice are dropped from the JSON when nil; every other field
// is always serialized. The conversation in Chat is sent under "messages" and
// is appended to in place by Request (AI replies) and Process (tool replies).
type AgentRequestBody struct {
	Model          string         `json:"model"`
	Temperature    float64        `json:"temperature"`
	ResponseFormat ResponseFormat `json:"response_format"`
	Tools          *any           `json:"tools,omitempty"`
	ToolChoice     *string        `json:"tool_choice,omitempty"`
	Chat           *Chat          `json:"messages"`
}
// ResponseChoice is one element of AgentResponse.Choices, decoded from the
// "index", "message" and "finish_reason" keys of the server reply.
type ResponseChoice struct {
	Index        int           `json:"index"`
	Message      ChatAiMessage `json:"message"`
	FinishReason string        `json:"finish_reason"`
}
// AgentResponse is the decoded body of a chat-completions reply.
// AgentClient.Request rejects replies whose Choices does not hold exactly
// one entry.
type AgentResponse struct {
	Id      string           `json:"id"`
	Object  string           `json:"object"`
	Choices []ResponseChoice `json:"choices"`
	Created int              `json:"created"`
}
// AgentClient sends chat-completions requests and dispatches the tool calls
// found in the replies. Construct it with CreateAgentClient.
//
// Do performs the HTTP round trip; because it is an exported function field,
// callers can substitute their own implementation (e.g. a fake in tests).
// ToolHandler is used by Process to answer tool calls.
//
// NOTE(review): responseFormat is neither set by CreateAgentClient nor read
// by the methods in this file — confirm it is used elsewhere.
type AgentClient struct {
	url            string
	apiKey         string
	responseFormat string
	ToolHandler    ToolsHandlers
	Do             func(req *http.Request) (*http.Response, error)
}
// OPENAI_API_KEY names the environment variable that CreateAgentClient reads
// the API key from.
//
// NOTE(review): the client built below talks to api.mistral.ai, so this name
// looks like a leftover; it is kept unchanged because callers and existing
// environments may depend on it.
const OPENAI_API_KEY = "OPENAI_API_KEY"

// CreateAgentClient returns an AgentClient for the Mistral chat-completions
// endpoint. The client authenticates with the key stored in the
// OPENAI_API_KEY environment variable and starts with an empty tool-handler
// registry.
//
// It returns an error if that variable is unset or empty.
func CreateAgentClient() (AgentClient, error) {
	// Generous upper bound for one completion round trip. Without it a stalled
	// connection could block the caller indefinitely: a zero http.Client
	// Timeout means "no timeout".
	const requestTimeout = 2 * time.Minute

	apiKey := os.Getenv(OPENAI_API_KEY)
	if apiKey == "" {
		return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
	}

	// One client shared by every request made through this AgentClient,
	// instead of a new one per call; http.Client is safe for concurrent use.
	httpClient := &http.Client{Timeout: requestTimeout}

	return AgentClient{
		apiKey: apiKey,
		url:    "https://api.mistral.ai/v1/chat/completions",
		Do:     httpClient.Do,
		ToolHandler: ToolsHandlers{
			handlers: map[string]ToolHandler{},
		},
	}, nil
}
// getRequest builds a JSON POST to client.url carrying body as its payload and
// the client's API key as a Bearer token. If http.NewRequest fails (for
// example because client.url is malformed) it returns a nil request and that
// error unchanged.
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
	request, err := http.NewRequest(http.MethodPost, client.url, bytes.NewBuffer(body))
	if err != nil {
		return nil, err
	}
	headers := request.Header
	headers.Add("Content-Type", "application/json")
	headers.Add("Authorization", "Bearer "+client.apiKey)
	return request, nil
}
// Request sends req to the chat-completions endpoint and decodes the reply.
//
// On success the message of the single returned choice is appended to
// req.Chat via AddAiResponse, so req is mutated. An error is returned in any
// of these cases:
//   - req cannot be encoded;
//   - the HTTP request cannot be built or sent;
//   - the server answers with a non-2xx status;
//   - the body does not decode into an AgentResponse;
//   - the reply does not contain exactly one choice.
//
// Wrapped errors keep their cause reachable through errors.Is / errors.As.
func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error) {
	payload, err := json.Marshal(req)
	if err != nil {
		return AgentResponse{}, fmt.Errorf("encoding agent request: %w", err)
	}
	httpRequest, err := client.getRequest(payload)
	if err != nil {
		return AgentResponse{}, fmt.Errorf("building agent request: %w", err)
	}
	resp, err := client.Do(httpRequest)
	if err != nil {
		return AgentResponse{}, fmt.Errorf("sending agent request: %w", err)
	}
	// The response body must be closed or the underlying connection cannot be
	// reused by the transport.
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return AgentResponse{}, fmt.Errorf("reading agent response: %w", err)
	}
	// A non-2xx reply is not a completion. Report the status and raw body
	// rather than a misleading decode or choice-count error further down.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return AgentResponse{}, fmt.Errorf("agent request failed with status %d: %s", resp.StatusCode, body)
	}

	var agentResponse AgentResponse
	if err := json.Unmarshal(body, &agentResponse); err != nil {
		return AgentResponse{}, fmt.Errorf("decoding agent response: %w", err)
	}
	if len(agentResponse.Choices) != 1 {
		return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
	}
	req.Chat.AddAiResponse(agentResponse.Choices[0].Message)
	return agentResponse, nil
}
// Process answers the tool calls requested by the latest message in
// req.Chat. Each call is passed to client.ToolHandler.Handle and its result is
// appended to req.Chat with AddToolResponse, in the order the calls appear.
//
// It returns the error from req.Chat.GetLatest if that fails, and an error if
// the latest message is not a ChatAiMessage. An AI message whose ToolCalls is
// nil is not an error; Process then leaves the chat unchanged and returns nil.
func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
	latest, err := req.Chat.GetLatest()
	if err != nil {
		return err
	}
	aiMessage, isAI := latest.(ChatAiMessage)
	if !isAI {
		return errors.New("Latest message isnt an AI message")
	}
	// Only walk the calls when the model actually requested tools.
	if calls := aiMessage.ToolCalls; calls != nil {
		for _, call := range *calls {
			req.Chat.AddToolResponse(client.ToolHandler.Handle(info, call))
		}
	}
	return nil
}