refactor(ai-client): moving tool handling and client into seperate folders

This commit is contained in:
2025-04-04 22:03:46 +01:00
parent 8a165c2042
commit 71d4581110
9 changed files with 401 additions and 906 deletions

View File

@ -1,481 +0,0 @@
package agents
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"screenmark/screenmark/.gen/haystack/haystack/model"
)
type ImageInfo struct {
Tags []string `json:"tags"`
Text []string `json:"text"`
Links []string `json:"links"`
Locations []model.Locations `json:"locations"`
Events []model.Events `json:"events"`
}
type ResponseFormat struct {
Type string `json:"type"`
JsonSchema any `json:"json_schema"`
}
type AgentRequestBody struct {
Model string `json:"model"`
Temperature float64 `json:"temperature"`
ResponseFormat ResponseFormat `json:"response_format"`
Tools *any `json:"tools,omitempty"`
ToolChoice *string `json:"tool_choice,omitempty"`
// ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
AgentMessages
}
type AgentMessages struct {
Messages []AgentMessage `json:"messages"`
}
type AgentMessage interface {
MessageToJson() ([]byte, error)
}
type AgentTextMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCallId string `json:"tool_call_id,omitempty"`
Name string `json:"name,omitempty"`
}
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
// TODO: Validate the `Role`.
return json.Marshal(textContent)
}
type AgentAssistantToolCall struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls"`
}
func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) {
return json.Marshal(toolCall)
}
type AgentArrayMessage struct {
Role string `json:"role"`
Content []AgentContent `json:"content"`
}
func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
return json.Marshal(arrayContent)
}
func (content *AgentMessages) AddText(message AgentTextMessage) {
content.Messages = append(content.Messages, message)
}
func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) {
content.Messages = append(content.Messages, toolCall)
}
func (content *AgentMessages) AddImage(imageName string, image []byte) error {
extension := filepath.Ext(imageName)
if len(extension) == 0 {
// TODO: could also validate for image types we support.
return errors.New("Image does not have extension")
}
extension = extension[1:]
encodedString := base64.StdEncoding.EncodeToString(image)
arrayMessage := AgentArrayMessage{Role: ROLE_USER, Content: make([]AgentContent, 1)}
arrayMessage.Content[0] = AgentImage{
ImageType: IMAGE_TYPE,
ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
}
content.Messages = append(content.Messages, arrayMessage)
return nil
}
func (content *AgentMessages) AddSystem(prompt string) error {
if len(content.Messages) != 0 {
return errors.New("You can only add a system prompt at the beginning")
}
content.Messages = append(content.Messages, AgentTextMessage{
Role: ROLE_SYSTEM,
Content: prompt,
})
return nil
}
type AgentContent interface {
ToJson() ([]byte, error)
}
type ImageUrl struct {
Url string `json:"url"`
}
type AgentImage struct {
ImageType string `json:"type"`
ImageUrl string `json:"image_url"`
}
func (imageMessage AgentImage) ToJson() ([]byte, error) {
imageMessage.ImageType = IMAGE_TYPE
return json.Marshal(imageMessage)
}
type AiClient interface {
GetImageInfo(imageName string, imageData []byte) (ImageInfo, error)
}
type AgentClient struct {
url string
apiKey string
systemPrompt string
responseFormat string
Do func(req *http.Request) (*http.Response, error)
}
// func (client AgentClient) Do(req *http.Request) () {
// httpClient := http.Client{}
// return httpClient.Do(req)
// }
const OPENAI_API_KEY = "OPENAI_API_KEY"
const ROLE_USER = "user"
const ROLE_SYSTEM = "system"
const IMAGE_TYPE = "image_url"
// TODO: extract to text file probably
const PROMPT = `
You are an image information extractor. The user will provide you with screenshots and your job is to extract any relevant links and text
that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags.
Be sure to extract every link (URL) that you find.
Use generic tags.
`
const RESPONSE_FORMAT = `
{
"name": "image_info",
"strict": true,
"schema": {
"type": "object",
"title": "image",
"required": ["tags", "text", "links"],
"additionalProperties": false,
"properties": {
"tags": {
"type": "array",
"title": "tags",
"description": "A list of tags you think the image is relevant to.",
"items": {
"type": "string"
}
},
"text": {
"type": "array",
"title": "text",
"description": "A list of sentences the image contains.",
"items": {
"type": "string"
}
},
"links": {
"type": "array",
"title": "links",
"description": "A list of all the links you can find in the image.",
"items": {
"type": "string"
}
},
"locations": {
"title": "locations",
"type": "array",
"description": "A list of locations you can find on the image, if any",
"items": {
"type": "object",
"required": ["name"],
"additionalProperties": false,
"properties": {
"name": {
"title": "name",
"type": "string"
},
"coordinates": {
"title": "coordinates",
"type": "string"
},
"address": {
"title": "address",
"type": "string"
},
"description": {
"title": "description",
"type": "string"
}
}
}
},
"events": {
"title": "events",
"type": "array",
"description": "A list of events you find on the image, if any",
"items": {
"type": "object",
"required": ["name"],
"additionalProperties": false,
"properties": {
"name": {
"type": "string",
"title": "name"
},
"locations": {
"title": "locations",
"type": "array",
"description": "A list of locations on this event, if any",
"items": {
"type": "object",
"required": ["name"],
"additionalProperties": false,
"properties": {
"name": {
"title": "name",
"type": "string"
},
"coordinates": {
"title": "coordinates",
"type": "string"
},
"address": {
"title": "address",
"type": "string"
},
"description": {
"title": "description",
"type": "string"
}
}
}
}
}
}
}
}
}
}
`
func CreateAgentClient(prompt string) (AgentClient, error) {
apiKey := os.Getenv(OPENAI_API_KEY)
if len(apiKey) == 0 {
return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
}
return AgentClient{
apiKey: apiKey,
url: "https://api.mistral.ai/v1/chat/completions",
systemPrompt: prompt,
Do: func(req *http.Request) (*http.Response, error) {
client := &http.Client{}
return client.Do(req)
},
}, nil
}
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body))
if err != nil {
return req, err
}
req.Header.Add("Authorization", "Bearer "+client.apiKey)
req.Header.Add("Content-Type", "application/json")
return req, nil
}
func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (AgentRequestBody, error) {
request := AgentRequestBody{
Model: model,
Temperature: temperature,
ResponseFormat: ResponseFormat{
Type: "json_schema",
JsonSchema: jsonSchema,
},
}
// TODO: Add build pattern here that deals with errors in some internal state?
// I want a monad!!!
err := request.AddSystem(prompt)
if err != nil {
return request, err
}
log.Println(request)
err = request.AddImage(imageName, imageData)
if err != nil {
return request, err
}
request.Tools = nil
return request, nil
}
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type ToolCall struct {
Index int `json:"index"`
Id string `json:"id"`
Function FunctionCall `json:"function"`
}
type ResponseChoiceMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls"`
}
type ResponseChoice struct {
Index int `json:"index"`
Message ResponseChoiceMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
type AgentResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Choices []ResponseChoice `json:"choices"`
Created int `json:"created"`
}
// TODO: add usage parsing
func parseAgentResponse(jsonResponse []byte) (ImageInfo, error) {
response := AgentResponse{}
err := json.Unmarshal(jsonResponse, &response)
if err != nil {
return ImageInfo{}, err
}
if len(response.Choices) != 1 {
log.Println(string(jsonResponse))
return ImageInfo{}, errors.New("Expected exactly one choice.")
}
imageInfo := ImageInfo{}
err = json.Unmarshal([]byte(response.Choices[0].Message.Content), &imageInfo)
if err != nil {
return ImageInfo{}, errors.New("Could not parse content into image type.")
}
return imageInfo, nil
}
func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) {
jsonAiRequest, err := json.Marshal(request)
if err != nil {
return AgentResponse{}, err
}
httpRequest, err := client.getRequest(jsonAiRequest)
if err != nil {
return AgentResponse{}, err
}
resp, err := client.Do(httpRequest)
if err != nil {
return AgentResponse{}, err
}
response, err := io.ReadAll(resp.Body)
if err != nil {
return AgentResponse{}, err
}
agentResponse := AgentResponse{}
err = json.Unmarshal(response, &agentResponse)
if err != nil {
return AgentResponse{}, err
}
log.Println(string(response))
toolCalls := agentResponse.Choices[0].Message.ToolCalls
if len(toolCalls) > 0 {
// Should for sure be more flexible.
request.AddToolCall(AgentAssistantToolCall{
Role: "assistant",
Content: "",
ToolCalls: toolCalls,
})
}
return agentResponse, nil
}
func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData)
if err != nil {
return ImageInfo{}, err
}
var jsonSchema any
err = json.Unmarshal([]byte(RESPONSE_FORMAT), &jsonSchema)
if err != nil {
return ImageInfo{}, err
}
aiRequest.ResponseFormat = ResponseFormat{
Type: "json_schema",
JsonSchema: jsonSchema,
}
jsonAiRequest, err := json.Marshal(aiRequest)
if err != nil {
return ImageInfo{}, err
}
request, err := client.getRequest(jsonAiRequest)
if err != nil {
return ImageInfo{}, err
}
resp, err := client.Do(request)
if err != nil {
return ImageInfo{}, err
}
response, err := io.ReadAll(resp.Body)
if err != nil {
return ImageInfo{}, err
}
log.Println(string(response))
return parseAgentResponse(response)
}

View File

@ -1,213 +0,0 @@
package agents
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"testing"
)
func TestMessageBuilder(t *testing.T) {
content := AgentMessages{}
err := content.AddSystem("Some prompt")
if err != nil {
t.Log(err)
t.FailNow()
}
if len(content.Messages) != 1 {
t.Logf("Expected length 1, got %d.\n", len(content.Messages))
t.FailNow()
}
}
func TestMessageBuilderImage(t *testing.T) {
content := AgentMessages{}
prompt := "some prompt"
imageTitle := "image.png"
data := []byte("some data")
content.AddSystem(prompt)
content.AddImage(imageTitle, data)
if len(content.Messages) != 2 {
t.Logf("Expected length 2, got %d.\n", len(content.Messages))
t.FailNow()
}
promptMessage, ok := content.Messages[0].(AgentTextMessage)
if !ok {
t.Logf("Expected text content message, got %T\n", content.Messages[0])
t.FailNow()
}
if promptMessage.Role != ROLE_SYSTEM {
t.Log("Prompt message role is incorrect.")
t.FailNow()
}
if promptMessage.Content != prompt {
t.Log("Prompt message content is incorrect.")
t.FailNow()
}
arrayContentMessage, ok := content.Messages[1].(AgentArrayMessage)
if !ok {
t.Logf("Expected text content message, got %T\n", content.Messages[1])
t.FailNow()
}
if arrayContentMessage.Role != ROLE_USER {
t.Log("Array content message role is incorrect.")
t.FailNow()
}
if len(arrayContentMessage.Content) != 1 {
t.Logf("Expected length 1, got %d.\n", len(arrayContentMessage.Content))
t.FailNow()
}
imageContent, ok := arrayContentMessage.Content[0].(AgentImage)
if !ok {
t.Logf("Expected text content message, got %T\n", arrayContentMessage.Content[0])
t.FailNow()
}
base64data := base64.StdEncoding.EncodeToString(data)
url := fmt.Sprintf("data:image/%s;base64,%s", "png", base64data)
if imageContent.ImageUrl != url {
t.Logf("Expected %s, but got %s.\n", url, imageContent.ImageUrl)
t.FailNow()
}
}
func TestFullImageRequest(t *testing.T) {
request, err := getCompletionsForImage("model", 0.1, "You are an assistant", "image.png", "", []byte("some data"))
if err != nil {
t.Log(request)
t.FailNow()
}
jsonData, err := json.Marshal(request)
if err != nil {
t.Log(err)
t.FailNow()
}
expectedJson := `{"model":"model","temperature":0.1,"response_format":{"type":"json_schema","json_schema":""},"messages":[{"role":"system","content":"You are an assistant"},{"role":"user","content":[{"type":"image_url","image_url":{"url":""}}]}]}`
if string(jsonData) != expectedJson {
t.Logf("Expected:\n%s\n Got:\n%s\n", expectedJson, string(jsonData))
t.FailNow()
}
}
func TestResponse(t *testing.T) {
testResponse := `{"tags": ["tag1", "tag2"], "text": ["text1"], "links": []}`
buffer := bytes.NewReader([]byte(testResponse))
body := io.NopCloser(buffer)
client := AgentClient{
url: "http://localhost:1234",
apiKey: "some-key",
Do: func(_req *http.Request) (*http.Response, error) {
return &http.Response{Body: body}, nil
},
}
info, err := client.GetImageInfo("image.png", []byte("some data"))
if err != nil {
t.Log(err)
t.FailNow()
}
if len(info.Tags) != 2 || len(info.Text) != 1 || len(info.Links) != 0 {
t.Logf("Some lengths are wrong.\nTags: %d\nText: %d\nLinks: %d\n", len(info.Tags), len(info.Text), len(info.Links))
t.FailNow()
}
if info.Tags[0] != "tag1" {
t.Log("0th tag is wrong.")
t.FailNow()
}
if info.Tags[1] != "tag2" {
t.Log("1th tag is wrong.")
t.FailNow()
}
if info.Text[0] != "text1" {
t.Log("0th text is wrong.")
t.FailNow()
}
}
func TestResponseParsing(t *testing.T) {
response := `{
"id": "chatcmpl-B4XgiHcd7A2nyK7eyARdggSvfFuWQ",
"object": "chat.completion",
"created": 1740422508,
"model": "gpt-4o-mini-2024-07-18",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "{\"links\":[\"link\"],\"tags\":[\"tag\"],\"text\":[\"text\"]}",
"refusal": null
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 775,
"completion_tokens": 33,
"total_tokens": 808,
"prompt_tokens_details": {
"cached_tokens": 0,
"audio_tokens": 0
},
"completion_tokens_details": {
"reasoning_tokens": 0,
"audio_tokens": 0,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0
}
},
"service_tier": "default",
"system_fingerprint": "fp_7fcd609668"
}`
imageParsed, err := parseAgentResponse([]byte(response))
if err != nil {
t.Log(err)
t.FailNow()
}
if len(imageParsed.Links) != 1 || imageParsed.Links[0] != "link" {
t.Log(imageParsed)
t.Log("Should have one link called 'link'.")
t.FailNow()
}
if len(imageParsed.Tags) != 1 || imageParsed.Tags[0] != "tag" {
t.Log(imageParsed)
t.Log("Should have one tag called 'tag'.")
t.FailNow()
}
if len(imageParsed.Text) != 1 || imageParsed.Text[0] != "text" {
t.Log(imageParsed)
t.Log("Should have one text called 'text'.")
t.FailNow()
}
}

View File

@ -0,0 +1,265 @@
package client
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"screenmark/screenmark/.gen/haystack/haystack/model"
)
type ImageInfo struct {
Tags []string `json:"tags"`
Text []string `json:"text"`
Links []string `json:"links"`
Locations []model.Locations `json:"locations"`
Events []model.Events `json:"events"`
}
type ResponseFormat struct {
Type string `json:"type"`
JsonSchema any `json:"json_schema"`
}
type AgentRequestBody struct {
Model string `json:"model"`
Temperature float64 `json:"temperature"`
ResponseFormat ResponseFormat `json:"response_format"`
Tools *any `json:"tools,omitempty"`
ToolChoice *string `json:"tool_choice,omitempty"`
// ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
AgentMessages
}
type AgentMessages struct {
Messages []AgentMessage `json:"messages"`
}
type AgentMessage interface {
MessageToJson() ([]byte, error)
}
type AgentTextMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCallId string `json:"tool_call_id,omitempty"`
Name string `json:"name,omitempty"`
}
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
// TODO: Validate the `Role`.
return json.Marshal(textContent)
}
type ToolCall struct {
Index int `json:"index"`
Id string `json:"id"`
Function FunctionCall `json:"function"`
}
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type AgentAssistantToolCall struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls"`
}
func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) {
return json.Marshal(toolCall)
}
type AgentArrayMessage struct {
Role string `json:"role"`
Content []AgentContent `json:"content"`
}
func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
return json.Marshal(arrayContent)
}
func (content *AgentMessages) AddText(message AgentTextMessage) {
content.Messages = append(content.Messages, message)
}
func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) {
content.Messages = append(content.Messages, toolCall)
}
func (content *AgentMessages) AddImage(imageName string, image []byte) error {
extension := filepath.Ext(imageName)
if len(extension) == 0 {
// TODO: could also validate for image types we support.
return errors.New("Image does not have extension")
}
extension = extension[1:]
encodedString := base64.StdEncoding.EncodeToString(image)
arrayMessage := AgentArrayMessage{Role: ROLE_USER, Content: make([]AgentContent, 1)}
arrayMessage.Content[0] = AgentImage{
ImageType: IMAGE_TYPE,
ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
}
content.Messages = append(content.Messages, arrayMessage)
return nil
}
func (content *AgentMessages) AddSystem(prompt string) error {
if len(content.Messages) != 0 {
return errors.New("You can only add a system prompt at the beginning")
}
content.Messages = append(content.Messages, AgentTextMessage{
Role: ROLE_SYSTEM,
Content: prompt,
})
return nil
}
type AgentContent interface {
ToJson() ([]byte, error)
}
type ImageUrl struct {
Url string `json:"url"`
}
type AgentImage struct {
ImageType string `json:"type"`
ImageUrl string `json:"image_url"`
}
func (imageMessage AgentImage) ToJson() ([]byte, error) {
imageMessage.ImageType = IMAGE_TYPE
return json.Marshal(imageMessage)
}
type AiClient interface {
GetImageInfo(imageName string, imageData []byte) (ImageInfo, error)
}
type ResponseChoiceMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls"`
}
type ResponseChoice struct {
Index int `json:"index"`
Message ResponseChoiceMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
type AgentResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Choices []ResponseChoice `json:"choices"`
Created int `json:"created"`
}
type AgentClient struct {
url string
apiKey string
systemPrompt string
responseFormat string
ToolHandler ToolsHandlers
Do func(req *http.Request) (*http.Response, error)
}
const OPENAI_API_KEY = "OPENAI_API_KEY"
const ROLE_USER = "user"
const ROLE_SYSTEM = "system"
const IMAGE_TYPE = "image_url"
func CreateAgentClient(prompt string) (AgentClient, error) {
apiKey := os.Getenv(OPENAI_API_KEY)
if len(apiKey) == 0 {
return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
}
return AgentClient{
apiKey: apiKey,
url: "https://api.mistral.ai/v1/chat/completions",
systemPrompt: prompt,
Do: func(req *http.Request) (*http.Response, error) {
client := &http.Client{}
return client.Do(req)
},
}, nil
}
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body))
if err != nil {
return req, err
}
req.Header.Add("Authorization", "Bearer "+client.apiKey)
req.Header.Add("Content-Type", "application/json")
return req, nil
}
func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) {
jsonAiRequest, err := json.Marshal(request)
if err != nil {
return AgentResponse{}, err
}
httpRequest, err := client.getRequest(jsonAiRequest)
if err != nil {
return AgentResponse{}, err
}
resp, err := client.Do(httpRequest)
if err != nil {
return AgentResponse{}, err
}
response, err := io.ReadAll(resp.Body)
if err != nil {
return AgentResponse{}, err
}
agentResponse := AgentResponse{}
err = json.Unmarshal(response, &agentResponse)
if err != nil {
return AgentResponse{}, err
}
log.Println(string(response))
toolCalls := agentResponse.Choices[0].Message.ToolCalls
if len(toolCalls) > 0 {
// Should for sure be more flexible.
request.AddToolCall(AgentAssistantToolCall{
Role: "assistant",
Content: "",
ToolCalls: toolCalls,
})
}
return agentResponse, nil
}

View File

@ -0,0 +1,79 @@
package client
import (
"encoding/json"
"errors"
"log"
"github.com/google/uuid"
)
type ToolHandlerInfo struct {
UserId uuid.UUID
ImageId uuid.UUID
}
type ToolHandler struct {
Fn func(info ToolHandlerInfo, args string, call ToolCall) (string, error)
}
type ToolsHandlers struct {
handlers *map[string]ToolHandler
}
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) {
agentMessage := request.Messages[len(request.Messages)-1]
toolCall, ok := agentMessage.(AgentAssistantToolCall)
if !ok {
return "", errors.New("Latest message was not a tool call.")
}
fnName := toolCall.ToolCalls[0].Function.Name
arguments := toolCall.ToolCalls[0].Function.Arguments
fnHandler, exists := (*handler.handlers)[fnName]
if !exists {
return "", errors.New("Could not find tool with this name.")
}
log.Printf("Calling: %s\n", fnName)
res, err := fnHandler.Fn(info, arguments, toolCall.ToolCalls[0])
if err != nil {
return "", err
}
request.AddText(AgentTextMessage{
Role: "tool",
Name: "createLocation",
Content: res,
ToolCallId: toolCall.ToolCalls[0].Id,
})
return res, nil
}
func (handler ToolsHandlers) AddTool(name string, getArgs func() any, fn func(info ToolHandlerInfo, args any, call ToolCall) (any, error)) {
(*handler.handlers)["createLocation"] = ToolHandler{
Fn: func(info ToolHandlerInfo, args string, call ToolCall) (string, error) {
argsStruct := getArgs()
err := json.Unmarshal([]byte(args), &argsStruct)
if err != nil {
return "", err
}
res, err := fn(info, argsStruct, call)
if err != nil {
return "", err
}
marshalledRes, err := json.Marshal(res)
if err != nil {
return "", err
}
return string(marshalledRes), nil
},
}
}

View File

@ -5,8 +5,8 @@ import (
"encoding/json"
"errors"
"log"
"reflect"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents/client"
"screenmark/screenmark/models"
"time"
@ -116,13 +116,13 @@ const TOOLS = `
`
type EventLocationAgent struct {
client AgentClient
client client.AgentClient
eventModel models.EventModel
locationModel models.LocationModel
contactModel models.ContactModel
toolHandler ToolsHandlers
toolHandler client.ToolsHandlers
}
type ListLocationArguments struct{}
@ -158,12 +158,12 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
toolChoice := "any"
request := AgentRequestBody{
request := client.AgentRequestBody{
Tools: &tools,
ToolChoice: &toolChoice,
Model: "pixtral-12b-2409",
Temperature: 0.3,
ResponseFormat: ResponseFormat{
ResponseFormat: client.ResponseFormat{
Type: "text",
},
}
@ -180,11 +180,12 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
return err
}
toolHandlerInfo := ToolHandlerInfo{
imageId: imageId,
userId: userId,
toolHandlerInfo := client.ToolHandlerInfo{
ImageId: imageId,
UserId: userId,
}
// TODO: this should go into a loop with toolHandler Handle function and not be here bruh.
_, err = agent.toolHandler.Handle(toolHandlerInfo, &request)
for err == nil {
@ -208,107 +209,41 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
return err
}
// TODO: extract this into a more general tool handler package.
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) {
agentMessage := request.Messages[len(request.Messages)-1]
toolCall, ok := agentMessage.(AgentAssistantToolCall)
if !ok {
return "", errors.New("Latest message was not a tool call.")
}
fnName := toolCall.ToolCalls[0].Function.Name
arguments := toolCall.ToolCalls[0].Function.Arguments
if fnName == "finish" {
return "", errors.New("This is the end! Maybe we just return a boolean.")
}
fn, exists := handler.Handlers[fnName]
if !exists {
return "", errors.New("Could not find tool with this name.")
}
// holy jesus what the fuck.
parseMethod := reflect.ValueOf(fn).Field(1)
if !parseMethod.IsValid() {
return "", errors.New("Parse method not found")
}
parsedArgs := parseMethod.Call([]reflect.Value{reflect.ValueOf(arguments)})
if !parsedArgs[1].IsNil() {
return "", parsedArgs[1].Interface().(error)
}
log.Printf("Calling: %s\n", fnName)
fnMethod := reflect.ValueOf(fn).Field(2)
if !fnMethod.IsValid() {
return "", errors.New("Fn method not found")
}
response := fnMethod.Call([]reflect.Value{reflect.ValueOf(info), parsedArgs[0], reflect.ValueOf(toolCall.ToolCalls[0])})
if !response[1].IsNil() {
return "", response[1].Interface().(error)
}
stringResponse, err := json.Marshal(response[0].Interface())
if err != nil {
return "", err
}
request.AddText(AgentTextMessage{
Role: "tool",
Name: "createLocation",
Content: string(stringResponse),
ToolCallId: toolCall.ToolCalls[0].Id,
})
return string(stringResponse), nil
}
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
client, err := CreateAgentClient(eventLocationPrompt)
agentClient, err := client.CreateAgentClient(eventLocationPrompt)
if err != nil {
return EventLocationAgent{}, err
}
agent := EventLocationAgent{
client: client,
client: agentClient,
locationModel: locationModel,
eventModel: eventModel,
contactModel: contactModel,
}
toolHandler := ToolsHandlers{
Handlers: make(map[string]ToolHandlerInterface),
}
toolHandler.Handlers["listLocations"] = ToolHandler[ListLocationArguments, []model.Locations]{
FunctionName: "listLocations",
Parse: func(stringArgs string) (ListLocationArguments, error) {
args := ListLocationArguments{}
err := json.Unmarshal([]byte(stringArgs), &args)
return args, err
agentClient.ToolHandler.AddTool("listLocations",
func() any {
return ListLocationArguments{}
},
Fn: func(info ToolHandlerInfo, _args ListLocationArguments, call ToolCall) ([]model.Locations, error) {
return agent.locationModel.List(context.Background(), info.userId)
func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) {
return agent.locationModel.List(context.Background(), info.UserId)
},
}
)
toolHandler.Handlers["createLocation"] = ToolHandler[CreateLocationArguments, model.Locations]{
FunctionName: "createLocation",
Parse: func(stringArgs string) (CreateLocationArguments, error) {
args := CreateLocationArguments{}
err := json.Unmarshal([]byte(stringArgs), &args)
return args, err
agentClient.ToolHandler.AddTool("createLocation",
func() any {
return CreateLocationArguments{}
},
Fn: func(info ToolHandlerInfo, args CreateLocationArguments, call ToolCall) (model.Locations, error) {
func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) {
args, ok := _args.(CreateLocationArguments)
if !ok {
return _args, errors.New("Type error, arguments are not of the correct struct type")
}
ctx := context.Background()
location, err := agent.locationModel.Save(ctx, info.userId, model.Locations{
location, err := agent.locationModel.Save(ctx, info.UserId, model.Locations{
Name: args.Name,
Address: args.Address,
})
@ -317,21 +252,22 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
return location, err
}
_, err = agent.locationModel.SaveToImage(ctx, info.imageId, location.ID)
_, err = agent.locationModel.SaveToImage(ctx, info.ImageId, location.ID)
return location, err
},
}
)
toolHandler.Handlers["createEvent"] = ToolHandler[CreateEventArguments, model.Events]{
FunctionName: "createEvent",
Parse: func(stringArgs string) (CreateEventArguments, error) {
args := CreateEventArguments{}
err := json.Unmarshal([]byte(stringArgs), &args)
return args, err
agentClient.ToolHandler.AddTool("createEvent",
func() any {
return CreateEventArguments{}
},
Fn: func(info ToolHandlerInfo, args CreateEventArguments, call ToolCall) (model.Events, error) {
func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) {
args, ok := _args.(CreateEventArguments)
if !ok {
return _args, errors.New("Type error, arguments are not of the correct struct type")
}
ctx := context.Background()
layout := "2006-01-02T15:04:05Z"
@ -346,7 +282,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
return model.Events{}, err
}
event, err := agent.eventModel.Save(ctx, info.userId, model.Events{
event, err := agent.eventModel.Save(ctx, info.UserId, model.Events{
Name: args.Name,
StartDateTime: &startTime,
EndDateTime: &endTime,
@ -356,7 +292,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
return event, err
}
organizer, err := agent.contactModel.Save(ctx, info.userId, model.Contacts{
organizer, err := agent.contactModel.Save(ctx, info.UserId, model.Contacts{
Name: args.Name,
})
@ -364,12 +300,12 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
return event, err
}
_, err = agent.eventModel.SaveToImage(ctx, info.imageId, event.ID)
_, err = agent.eventModel.SaveToImage(ctx, info.ImageId, event.ID)
if err != nil {
return event, err
}
_, err = agent.contactModel.SaveToImage(ctx, info.imageId, organizer.ID)
_, err = agent.contactModel.SaveToImage(ctx, info.ImageId, organizer.ID)
if err != nil {
return event, err
}
@ -386,9 +322,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
return agent.eventModel.UpdateOrganizer(ctx, event.ID, organizer.ID)
},
}
agent.toolHandler = toolHandler
)
return agent, nil
}

View File

@ -3,6 +3,7 @@ package agents
import (
"context"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents/client"
"screenmark/screenmark/models"
"github.com/google/uuid"
@ -19,16 +20,16 @@ Do not return anything except markdown.
`
type NoteAgent struct {
client AgentClient
client client.AgentClient
noteModel models.NoteModel
}
func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
request := AgentRequestBody{
request := client.AgentRequestBody{
Model: "pixtral-12b-2409",
Temperature: 0.3,
ResponseFormat: ResponseFormat{
ResponseFormat: client.ResponseFormat{
Type: "text",
},
}
@ -65,7 +66,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
}
func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) {
client, err := CreateAgentClient(noteAgentPrompt)
client, err := client.CreateAgentClient(noteAgentPrompt)
if err != nil {
return NoteAgent{}, err
}

View File

@ -3,6 +3,7 @@ package agents
import (
"encoding/json"
"fmt"
"screenmark/screenmark/agents/client"
"github.com/google/uuid"
)
@ -74,7 +75,7 @@ const MY_TOOLS = `
]`
type OrchestratorAgent struct {
client AgentClient
client client.AgentClient
}
func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
@ -86,10 +87,10 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID,
return err
}
request := AgentRequestBody{
request := client.AgentRequestBody{
Model: "pixtral-12b-2409",
Temperature: 0.3,
ResponseFormat: ResponseFormat{
ResponseFormat: client.ResponseFormat{
Type: "text",
},
ToolChoice: &toolChoice,
@ -113,7 +114,7 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID,
}
func NewOrchestratorAgent() (OrchestratorAgent, error) {
agent, err := CreateAgentClient(orchestratorPrompt)
agent, err := client.CreateAgentClient(orchestratorPrompt)
if err != nil {
return OrchestratorAgent{}, err
}

View File

@ -1,26 +0,0 @@
package agents
import "github.com/google/uuid"
type ToolHandlerInfo struct {
userId uuid.UUID
imageId uuid.UUID
}
type ToolHandler[TArgs any, TResp any] struct {
FunctionName string
Parse func(args string) (TArgs, error)
Fn func(info ToolHandlerInfo, args TArgs, call ToolCall) (TResp, error)
}
type ToolHandlerInterface interface {
GetFunctionName() string
}
func (handler ToolHandler[TArgs, TResp]) GetFunctionName() string {
return handler.FunctionName
}
type ToolsHandlers struct {
Handlers map[string]ToolHandlerInterface
}

View File

@ -13,6 +13,7 @@ import (
"path/filepath"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents"
"screenmark/screenmark/agents/client"
"screenmark/screenmark/models"
"time"
@ -24,41 +25,13 @@ import (
)
type TestAiClient struct {
ImageInfo agents.ImageInfo
ImageInfo client.ImageInfo
}
func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (agents.ImageInfo, error) {
func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (client.ImageInfo, error) {
return client.ImageInfo, nil
}
func GetAiClient() (agents.AiClient, error) {
mode := os.Getenv("MODE")
if mode == "TESTING" {
address := "10 Downing Street"
description := "Cheese and Crackers"
return TestAiClient{
ImageInfo: agents.ImageInfo{
Tags: []string{"tag"},
Links: []string{"links"},
Text: []string{"text"},
Locations: []model.Locations{{
ID: uuid.Nil,
Name: "London",
Address: &address,
}},
Events: []model.Events{{
ID: uuid.Nil,
Name: "Party",
Description: &description,
}},
},
}, nil
}
return agents.CreateAgentClient(agents.PROMPT)
}
func main() {
err := godotenv.Load()
if err != nil {
@ -74,9 +47,6 @@ func main() {
}
imageModel := models.NewImageModel(db)
linkModel := models.NewLinkModel(db)
tagModel := models.NewTagModel(db)
textModel := models.NewTextModel(db)
locationModel := models.NewLocationModel(db)
eventModel := models.NewEventModel(db)
userModel := models.NewUserModel(db)
@ -105,11 +75,6 @@ func main() {
ctx := context.Background()
go func() {
openAiClient, err := GetAiClient()
if err != nil {
panic(err)
}
locationAgent, err := agents.NewLocationEventAgent(locationModel, eventModel, contactModel)
if err != nil {
panic(err)
@ -132,7 +97,7 @@ func main() {
return
}
userImage, err := imageModel.FinishProcessing(ctx, image.ID)
_, err = imageModel.FinishProcessing(ctx, image.ID)
if err != nil {
log.Println("Failed to FinishProcessing")
log.Println(err)
@ -152,36 +117,6 @@ func main() {
log.Println("Calling noteAgent!")
err = noteAgent.GetNotes(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
log.Println(err)
return
imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image)
if err != nil {
log.Println("Failed to GetImageInfo")
log.Println(err)
return
}
err = tagModel.SaveToImage(ctx, userImage.ImageID, imageInfo.Tags)
if err != nil {
log.Println("Failed to save tags")
log.Println(err)
return
}
err = linkModel.Save(ctx, userImage.ImageID, imageInfo.Links)
if err != nil {
log.Println("Failed to save links")
log.Println(err)
return
}
err = textModel.Save(ctx, userImage.ImageID, imageInfo.Text)
if err != nil {
log.Println("Failed to save text")
log.Println(err)
return
}
}()
}
}