package client
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
)
|
|
|
|
// ResponseFormat selects the output format the model must produce
// (e.g. a JSON-schema-constrained response), serialized into the
// "response_format" field of the request body.
type ResponseFormat struct {
	// Type is the format discriminator expected by the API.
	Type string `json:"type"`
	// JsonSchema carries the schema definition when Type requests
	// schema-constrained output; left untyped so any schema shape
	// can be marshalled as-is.
	JsonSchema any `json:"json_schema"`
}
|
|
|
|
// AgentRequestBody is the JSON payload for a chat-completions request.
type AgentRequestBody struct {
	// Model names the model to run the completion against.
	Model string `json:"model"`
	// Temperature controls sampling randomness.
	Temperature float64 `json:"temperature"`
	// ResponseFormat constrains the shape of the model's output.
	ResponseFormat ResponseFormat `json:"response_format"`

	// Tools and ToolChoice are optional; pointers + omitempty keep them
	// out of the JSON entirely when unset.
	Tools      *any    `json:"tools,omitempty"`
	ToolChoice *string `json:"tool_choice,omitempty"`

	// Chat holds the conversation history, sent as "messages".
	Chat Chat `json:"messages"`
}
|
|
|
|
// ResponseChoice is one candidate completion returned by the API.
type ResponseChoice struct {
	// Index is the position of this choice in the response.
	Index int `json:"index"`
	// Message is the AI-authored message for this choice.
	Message ChatAiMessage `json:"message"`
	// FinishReason reports why generation stopped (as sent by the API).
	FinishReason string `json:"finish_reason"`
}
|
|
|
|
// AgentResponse is the top-level JSON body of a chat-completions reply.
type AgentResponse struct {
	// Id is the API-assigned identifier for this completion.
	Id string `json:"id"`
	// Object is the API's object-type tag.
	Object string `json:"object"`
	// Choices holds the candidate completions; this client expects
	// exactly one (see AgentClient.Request).
	Choices []ResponseChoice `json:"choices"`
	// Created is the creation timestamp as provided by the API
	// (presumably Unix seconds — TODO confirm against the provider docs).
	Created int `json:"created"`
}
|
|
|
|
// AgentClient talks to a chat-completions HTTP endpoint and dispatches
// any tool calls the model requests.
type AgentClient struct {
	// url is the chat-completions endpoint requests are POSTed to.
	url string
	// apiKey is sent as a Bearer token on every request.
	apiKey string
	// responseFormat is currently unset by CreateAgentClient — NOTE(review):
	// appears unused in this file; verify whether other code populates it.
	responseFormat string

	// ToolHandler maps tool names to their handlers for Process.
	ToolHandler ToolsHandlers

	// Do executes an HTTP request; exposed as a field so tests can
	// substitute a fake transport.
	Do func(req *http.Request) (*http.Response, error)
}
|
|
|
|
// OPENAI_API_KEY is the environment variable the API key is read from.
// NOTE(review): the name says OpenAI but CreateAgentClient targets the
// Mistral endpoint — confirm which provider is intended.
const OPENAI_API_KEY = "OPENAI_API_KEY"
|
|
|
|
func CreateAgentClient() (AgentClient, error) {
|
|
apiKey := os.Getenv(OPENAI_API_KEY)
|
|
|
|
if len(apiKey) == 0 {
|
|
return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
|
|
}
|
|
|
|
return AgentClient{
|
|
apiKey: apiKey,
|
|
url: "https://api.mistral.ai/v1/chat/completions",
|
|
Do: func(req *http.Request) (*http.Response, error) {
|
|
client := &http.Client{}
|
|
return client.Do(req)
|
|
},
|
|
|
|
ToolHandler: ToolsHandlers{
|
|
handlers: map[string]ToolHandler{},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
|
|
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body))
|
|
if err != nil {
|
|
return req, err
|
|
}
|
|
|
|
req.Header.Add("Authorization", "Bearer "+client.apiKey)
|
|
req.Header.Add("Content-Type", "application/json")
|
|
|
|
return req, nil
|
|
}
|
|
|
|
func (client AgentClient) Request(chat *Chat) (AgentResponse, error) {
|
|
jsonAiRequest, err := json.Marshal(chat)
|
|
if err != nil {
|
|
return AgentResponse{}, err
|
|
}
|
|
|
|
httpRequest, err := client.getRequest(jsonAiRequest)
|
|
if err != nil {
|
|
return AgentResponse{}, err
|
|
}
|
|
|
|
resp, err := client.Do(httpRequest)
|
|
if err != nil {
|
|
return AgentResponse{}, err
|
|
}
|
|
|
|
response, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return AgentResponse{}, err
|
|
}
|
|
|
|
agentResponse := AgentResponse{}
|
|
err = json.Unmarshal(response, &agentResponse)
|
|
|
|
if err != nil {
|
|
return AgentResponse{}, err
|
|
}
|
|
|
|
if len(agentResponse.Choices) != 1 {
|
|
return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
|
|
}
|
|
|
|
chat.AddAiResponse(agentResponse.Choices[0].Message)
|
|
|
|
return agentResponse, nil
|
|
}
|
|
|
|
func (client AgentClient) Process(info ToolHandlerInfo, chat *Chat) error {
|
|
var err error
|
|
|
|
message, err := chat.GetLatest()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
aiMessage, ok := message.(ChatAiMessage)
|
|
if !ok {
|
|
return errors.New("Latest message isnt an AI message")
|
|
}
|
|
|
|
if aiMessage.ToolCalls == nil {
|
|
// Not an error, we just dont have any tool calls to process.
|
|
return nil
|
|
}
|
|
|
|
for _, toolCall := range *aiMessage.ToolCalls {
|
|
toolResponse, err := client.ToolHandler.Handle(info, toolCall)
|
|
if err != nil {
|
|
break
|
|
}
|
|
|
|
chat.AddToolResponse(toolResponse)
|
|
|
|
_, err = client.Request(chat)
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
|
|
if err != nil {
|
|
log.Println(err)
|
|
}
|
|
return err
|
|
}
|