refactor(ai-client): moving tool handling and client into seperate folders
This commit is contained in:
265
backend/agents/client/client.go
Normal file
265
backend/agents/client/client.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||
)
|
||||
|
||||
type ImageInfo struct {
|
||||
Tags []string `json:"tags"`
|
||||
Text []string `json:"text"`
|
||||
Links []string `json:"links"`
|
||||
|
||||
Locations []model.Locations `json:"locations"`
|
||||
Events []model.Events `json:"events"`
|
||||
}
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type"`
|
||||
JsonSchema any `json:"json_schema"`
|
||||
}
|
||||
|
||||
type AgentRequestBody struct {
|
||||
Model string `json:"model"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
ResponseFormat ResponseFormat `json:"response_format"`
|
||||
|
||||
Tools *any `json:"tools,omitempty"`
|
||||
ToolChoice *string `json:"tool_choice,omitempty"`
|
||||
|
||||
// ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
|
||||
AgentMessages
|
||||
}
|
||||
|
||||
type AgentMessages struct {
|
||||
Messages []AgentMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type AgentMessage interface {
|
||||
MessageToJson() ([]byte, error)
|
||||
}
|
||||
|
||||
type AgentTextMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
|
||||
// TODO: Validate the `Role`.
|
||||
return json.Marshal(textContent)
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
Index int `json:"index"`
|
||||
Id string `json:"id"`
|
||||
Function FunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
type AgentAssistantToolCall struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls"`
|
||||
}
|
||||
|
||||
func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) {
|
||||
return json.Marshal(toolCall)
|
||||
}
|
||||
|
||||
type AgentArrayMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []AgentContent `json:"content"`
|
||||
}
|
||||
|
||||
func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
|
||||
return json.Marshal(arrayContent)
|
||||
}
|
||||
|
||||
func (content *AgentMessages) AddText(message AgentTextMessage) {
|
||||
content.Messages = append(content.Messages, message)
|
||||
}
|
||||
|
||||
func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) {
|
||||
content.Messages = append(content.Messages, toolCall)
|
||||
}
|
||||
|
||||
func (content *AgentMessages) AddImage(imageName string, image []byte) error {
|
||||
extension := filepath.Ext(imageName)
|
||||
if len(extension) == 0 {
|
||||
// TODO: could also validate for image types we support.
|
||||
return errors.New("Image does not have extension")
|
||||
}
|
||||
|
||||
extension = extension[1:]
|
||||
|
||||
encodedString := base64.StdEncoding.EncodeToString(image)
|
||||
|
||||
arrayMessage := AgentArrayMessage{Role: ROLE_USER, Content: make([]AgentContent, 1)}
|
||||
arrayMessage.Content[0] = AgentImage{
|
||||
ImageType: IMAGE_TYPE,
|
||||
ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
|
||||
}
|
||||
|
||||
content.Messages = append(content.Messages, arrayMessage)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (content *AgentMessages) AddSystem(prompt string) error {
|
||||
if len(content.Messages) != 0 {
|
||||
return errors.New("You can only add a system prompt at the beginning")
|
||||
}
|
||||
|
||||
content.Messages = append(content.Messages, AgentTextMessage{
|
||||
Role: ROLE_SYSTEM,
|
||||
Content: prompt,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type AgentContent interface {
|
||||
ToJson() ([]byte, error)
|
||||
}
|
||||
|
||||
type ImageUrl struct {
|
||||
Url string `json:"url"`
|
||||
}
|
||||
|
||||
type AgentImage struct {
|
||||
ImageType string `json:"type"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
}
|
||||
|
||||
func (imageMessage AgentImage) ToJson() ([]byte, error) {
|
||||
imageMessage.ImageType = IMAGE_TYPE
|
||||
return json.Marshal(imageMessage)
|
||||
}
|
||||
|
||||
type AiClient interface {
|
||||
GetImageInfo(imageName string, imageData []byte) (ImageInfo, error)
|
||||
}
|
||||
|
||||
type ResponseChoiceMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls"`
|
||||
}
|
||||
|
||||
type ResponseChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message ResponseChoiceMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type AgentResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Choices []ResponseChoice `json:"choices"`
|
||||
Created int `json:"created"`
|
||||
}
|
||||
|
||||
type AgentClient struct {
|
||||
url string
|
||||
apiKey string
|
||||
systemPrompt string
|
||||
responseFormat string
|
||||
|
||||
ToolHandler ToolsHandlers
|
||||
|
||||
Do func(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
const OPENAI_API_KEY = "OPENAI_API_KEY"
|
||||
const ROLE_USER = "user"
|
||||
const ROLE_SYSTEM = "system"
|
||||
const IMAGE_TYPE = "image_url"
|
||||
|
||||
func CreateAgentClient(prompt string) (AgentClient, error) {
|
||||
apiKey := os.Getenv(OPENAI_API_KEY)
|
||||
|
||||
if len(apiKey) == 0 {
|
||||
return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
|
||||
}
|
||||
|
||||
return AgentClient{
|
||||
apiKey: apiKey,
|
||||
url: "https://api.mistral.ai/v1/chat/completions",
|
||||
systemPrompt: prompt,
|
||||
Do: func(req *http.Request) (*http.Response, error) {
|
||||
client := &http.Client{}
|
||||
return client.Do(req)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
|
||||
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return req, err
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", "Bearer "+client.apiKey)
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) {
|
||||
jsonAiRequest, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return AgentResponse{}, err
|
||||
}
|
||||
|
||||
httpRequest, err := client.getRequest(jsonAiRequest)
|
||||
if err != nil {
|
||||
return AgentResponse{}, err
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpRequest)
|
||||
if err != nil {
|
||||
return AgentResponse{}, err
|
||||
}
|
||||
|
||||
response, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return AgentResponse{}, err
|
||||
}
|
||||
|
||||
agentResponse := AgentResponse{}
|
||||
err = json.Unmarshal(response, &agentResponse)
|
||||
|
||||
if err != nil {
|
||||
return AgentResponse{}, err
|
||||
}
|
||||
|
||||
log.Println(string(response))
|
||||
|
||||
toolCalls := agentResponse.Choices[0].Message.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
// Should for sure be more flexible.
|
||||
request.AddToolCall(AgentAssistantToolCall{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
ToolCalls: toolCalls,
|
||||
})
|
||||
}
|
||||
|
||||
return agentResponse, nil
|
||||
}
|
||||
79
backend/agents/client/tools.go
Normal file
79
backend/agents/client/tools.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type ToolHandlerInfo struct {
|
||||
UserId uuid.UUID
|
||||
ImageId uuid.UUID
|
||||
}
|
||||
|
||||
type ToolHandler struct {
|
||||
Fn func(info ToolHandlerInfo, args string, call ToolCall) (string, error)
|
||||
}
|
||||
|
||||
type ToolsHandlers struct {
|
||||
handlers *map[string]ToolHandler
|
||||
}
|
||||
|
||||
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) {
|
||||
agentMessage := request.Messages[len(request.Messages)-1]
|
||||
|
||||
toolCall, ok := agentMessage.(AgentAssistantToolCall)
|
||||
if !ok {
|
||||
return "", errors.New("Latest message was not a tool call.")
|
||||
}
|
||||
|
||||
fnName := toolCall.ToolCalls[0].Function.Name
|
||||
arguments := toolCall.ToolCalls[0].Function.Arguments
|
||||
|
||||
fnHandler, exists := (*handler.handlers)[fnName]
|
||||
if !exists {
|
||||
return "", errors.New("Could not find tool with this name.")
|
||||
}
|
||||
|
||||
log.Printf("Calling: %s\n", fnName)
|
||||
res, err := fnHandler.Fn(info, arguments, toolCall.ToolCalls[0])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
request.AddText(AgentTextMessage{
|
||||
Role: "tool",
|
||||
Name: "createLocation",
|
||||
Content: res,
|
||||
ToolCallId: toolCall.ToolCalls[0].Id,
|
||||
})
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (handler ToolsHandlers) AddTool(name string, getArgs func() any, fn func(info ToolHandlerInfo, args any, call ToolCall) (any, error)) {
|
||||
(*handler.handlers)["createLocation"] = ToolHandler{
|
||||
Fn: func(info ToolHandlerInfo, args string, call ToolCall) (string, error) {
|
||||
argsStruct := getArgs()
|
||||
|
||||
err := json.Unmarshal([]byte(args), &argsStruct)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
res, err := fn(info, argsStruct, call)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
marshalledRes, err := json.Marshal(res)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(marshalledRes), nil
|
||||
},
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user