Orchestrator + Tooling rework #4

Merged
JohnCosta27 merged 17 commits from feat/orchestrator into main 2025-04-09 17:00:53 +01:00
14 changed files with 921 additions and 949 deletions

View File

@ -1,479 +0,0 @@
package agents
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"screenmark/screenmark/.gen/haystack/haystack/model"
)
// ImageInfo is the structured result decoded from the model's JSON reply for
// a single image (see parseAgentResponse).
type ImageInfo struct {
Tags []string `json:"tags"`
Text []string `json:"text"`
Links []string `json:"links"`
Locations []model.Locations `json:"locations"`
Events []model.Events `json:"events"`
}
// ResponseFormat is serialized as the "response_format" field of a request.
type ResponseFormat struct {
Type string `json:"type"`
JsonSchema any `json:"json_schema"`
}
// AgentRequestBody is the JSON body posted to the chat-completions endpoint.
// Tools and ToolChoice are left out of the JSON when they are nil.
type AgentRequestBody struct {
Model string `json:"model"`
Temperature float64 `json:"temperature"`
ResponseFormat ResponseFormat `json:"response_format"`
Tools *any `json:"tools,omitempty"`
ToolChoice *string `json:"tool_choice,omitempty"`
// Embedded so "messages" is serialized at the top level and the Add*
// helpers can be called directly on the request.
AgentMessages
}
// AgentMessages is the ordered list of messages of one conversation.
type AgentMessages struct {
Messages []AgentMessage `json:"messages"`
}
// AgentMessage is implemented by every message kind stored in AgentMessages.
type AgentMessage interface {
MessageToJson() ([]byte, error)
}
// AgentTextMessage is a message whose content is a plain string.
// ToolCallId and Name are omitted from the JSON when empty.
type AgentTextMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCallId string `json:"tool_call_id,omitempty"`
Name string `json:"name,omitempty"`
}
// MessageToJson encodes the message with encoding/json.
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
// TODO: Validate the `Role`.
return json.Marshal(textContent)
}
// AgentAssistantToolCall records the tool calls an assistant reply asked for.
type AgentAssistantToolCall struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls"`
}
// MessageToJson encodes the tool-call message with encoding/json.
func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) {
return json.Marshal(toolCall)
}
// AgentArrayMessage is a message whose content is a list of parts
// (AddImage uses it to carry one image part).
type AgentArrayMessage struct {
Role string `json:"role"`
Content []AgentContent `json:"content"`
}
// MessageToJson encodes the array message with encoding/json.
func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
return json.Marshal(arrayContent)
}
// AddText appends a text message to the end of the conversation.
func (content *AgentMessages) AddText(message AgentTextMessage) {
content.Messages = append(content.Messages, message)
}
// AddToolCall appends an assistant tool-call message to the end of the conversation.
func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) {
content.Messages = append(content.Messages, toolCall)
}
// AddImage appends a user message that carries a single image as a base64
// data URL. The image subtype in the URL is imageName's extension without the
// leading dot.
//
// It returns an error, and appends nothing, when imageName has no usable
// extension.
func (content *AgentMessages) AddImage(imageName string, image []byte) error {
	extension := filepath.Ext(imageName)
	// filepath.Ext returns "." for a name like "photo."; that must be rejected
	// as well, otherwise the URL would start with "data:image/;base64,".
	if len(extension) <= 1 {
		// TODO: could also validate for image types we support.
		return errors.New("Image does not have extension")
	}

	extension = extension[1:]
	encodedString := base64.StdEncoding.EncodeToString(image)

	arrayMessage := AgentArrayMessage{
		Role: ROLE_USER,
		Content: []AgentContent{
			AgentImage{
				ImageType: IMAGE_TYPE,
				ImageUrl:  fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
			},
		},
	}

	content.Messages = append(content.Messages, arrayMessage)
	return nil
}
// AddSystem stores prompt as the system message of the conversation.
// A system message may only be the very first entry, so an error is returned
// when the conversation already holds any message.
func (content *AgentMessages) AddSystem(prompt string) error {
	if len(content.Messages) > 0 {
		return errors.New("You can only add a system prompt at the beginning")
	}

	systemMessage := AgentTextMessage{Role: ROLE_SYSTEM, Content: prompt}
	content.Messages = append(content.Messages, systemMessage)
	return nil
}
// AgentContent is one part of an AgentArrayMessage's content list.
type AgentContent interface {
ToJson() ([]byte, error)
}
// ImageUrl wraps a URL in an object with a single "url" field.
// NOTE(review): not referenced by the code visible here — confirm it is still needed.
type ImageUrl struct {
Url string `json:"url"`
}
// AgentImage is an image content part; ImageUrl is serialized as a plain string.
type AgentImage struct {
ImageType string `json:"type"`
ImageUrl string `json:"image_url"`
}
// ToJson encodes the image part, forcing its "type" to IMAGE_TYPE.
// The receiver is a copy, so the caller's value is not modified.
func (imageMessage AgentImage) ToJson() ([]byte, error) {
imageMessage.ImageType = IMAGE_TYPE
return json.Marshal(imageMessage)
}
// AiClient extracts ImageInfo from the raw bytes of a named image.
type AiClient interface {
GetImageInfo(imageName string, imageData []byte) (ImageInfo, error)
}
// AgentClient posts requests to url, authenticating with apiKey as a bearer
// token (see getRequest).
type AgentClient struct {
url string
apiKey string
systemPrompt string
responseFormat string
// Do performs the HTTP round trip; it is a field so a stub can be injected.
Do func(req *http.Request) (*http.Response, error)
}
// func (client AgentClient) Do(req *http.Request) () {
// httpClient := http.Client{}
// return httpClient.Do(req)
// }
// OPENAI_API_KEY is the name of the environment variable CreateAgentClient
// reads the API key from.
const OPENAI_API_KEY = "OPENAI_API_KEY"
// Values used for the "role" field of messages.
const ROLE_USER = "user"
const ROLE_SYSTEM = "system"
// IMAGE_TYPE is the "type" value of an image content part.
const IMAGE_TYPE = "image_url"
// TODO: extract to text file probably
const PROMPT = `
You are an image information extractor. The user will provide you with screenshots and your job is to extract any relevant links and text
that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags.
Be sure to extract every link (URL) that you find.
Use generic tags.
`
const RESPONSE_FORMAT = `
{
"name": "image_info",
"strict": true,
"schema": {
"type": "object",
"title": "image",
"required": ["tags", "text", "links"],
"additionalProperties": false,
"properties": {
"tags": {
"type": "array",
"title": "tags",
"description": "A list of tags you think the image is relevant to.",
"items": {
"type": "string"
}
},
"text": {
"type": "array",
"title": "text",
"description": "A list of sentences the image contains.",
"items": {
"type": "string"
}
},
"links": {
"type": "array",
"title": "links",
"description": "A list of all the links you can find in the image.",
"items": {
"type": "string"
}
},
"locations": {
"title": "locations",
"type": "array",
"description": "A list of locations you can find on the image, if any",
"items": {
"type": "object",
"required": ["name"],
"additionalProperties": false,
"properties": {
"name": {
"title": "name",
"type": "string"
},
"coordinates": {
"title": "coordinates",
"type": "string"
},
"address": {
"title": "address",
"type": "string"
},
"description": {
"title": "description",
"type": "string"
}
}
}
},
"events": {
"title": "events",
"type": "array",
"description": "A list of events you find on the image, if any",
"items": {
"type": "object",
"required": ["name"],
"additionalProperties": false,
"properties": {
"name": {
"type": "string",
"title": "name"
},
"locations": {
"title": "locations",
"type": "array",
"description": "A list of locations on this event, if any",
"items": {
"type": "object",
"required": ["name"],
"additionalProperties": false,
"properties": {
"name": {
"title": "name",
"type": "string"
},
"coordinates": {
"title": "coordinates",
"type": "string"
},
"address": {
"title": "address",
"type": "string"
},
"description": {
"title": "description",
"type": "string"
}
}
}
}
}
}
}
}
}
}
`
// CreateAgentClient builds an AgentClient that uses prompt as its system
// prompt and takes its API key from the OPENAI_API_KEY environment variable.
// An error is returned when that variable is unset or empty.
func CreateAgentClient(prompt string) (AgentClient, error) {
	key := os.Getenv(OPENAI_API_KEY)
	if key == "" {
		return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
	}

	// Default transport: every call goes through a fresh http.Client.
	send := func(req *http.Request) (*http.Response, error) {
		return (&http.Client{}).Do(req)
	}

	agentClient := AgentClient{
		url:          "https://api.mistral.ai/v1/chat/completions",
		apiKey:       key,
		systemPrompt: prompt,
		Do:           send,
	}
	return agentClient, nil
}
// getRequest wraps body in a POST request to the client's URL, carrying the
// API key as a bearer token and a JSON content type.
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
	httpRequest, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body))
	if err != nil {
		return httpRequest, err
	}

	headers := httpRequest.Header
	headers.Add("Authorization", "Bearer "+client.apiKey)
	headers.Add("Content-Type", "application/json")
	return httpRequest, nil
}
// getCompletionsForImage assembles a request for one image: a "json_schema"
// response format holding jsonSchema, then the system prompt, then the image
// as a user message. On failure the partially built request is returned
// together with the error.
func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (AgentRequestBody, error) {
	format := ResponseFormat{Type: "json_schema", JsonSchema: jsonSchema}
	request := AgentRequestBody{
		Model:          model,
		Temperature:    temperature,
		ResponseFormat: format,
	}

	// TODO: a builder that carries errors in internal state would remove the
	// repeated checks below.
	if err := request.AddSystem(prompt); err != nil {
		return request, err
	}

	log.Println(request)

	if err := request.AddImage(imageName, imageData); err != nil {
		return request, err
	}

	request.Tools = nil
	return request, nil
}
// FunctionCall names the function a tool call targets; Arguments is kept as
// the raw string received from the API.
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
// ToolCall is one tool invocation requested in an assistant reply.
type ToolCall struct {
Index int `json:"index"`
Id string `json:"id"`
Function FunctionCall `json:"function"`
}
// ResponseChoiceMessage is the message carried by a response choice.
type ResponseChoiceMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls"`
}
// ResponseChoice is one entry of the response's "choices" array.
type ResponseChoice struct {
Index int `json:"index"`
Message ResponseChoiceMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
// AgentResponse holds the fields of the chat-completions reply that are decoded.
type AgentResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Choices []ResponseChoice `json:"choices"`
Created int `json:"created"`
}
// TODO: add usage parsing
//
// parseAgentResponse decodes a chat-completions reply and then decodes the
// content of its single choice, itself a JSON document, into an ImageInfo.
//
// It fails when the reply is not valid JSON, when it does not contain exactly
// one choice, or when the choice's content cannot be decoded. In the last
// case the underlying JSON error is wrapped so callers can inspect it.
func parseAgentResponse(jsonResponse []byte) (ImageInfo, error) {
	var response AgentResponse
	if err := json.Unmarshal(jsonResponse, &response); err != nil {
		return ImageInfo{}, err
	}

	if len(response.Choices) != 1 {
		log.Println(string(jsonResponse))
		return ImageInfo{}, errors.New("Expected exactly one choice.")
	}

	var imageInfo ImageInfo
	content := response.Choices[0].Message.Content
	if err := json.Unmarshal([]byte(content), &imageInfo); err != nil {
		return ImageInfo{}, fmt.Errorf("Could not parse content into image type: %w", err)
	}

	return imageInfo, nil
}
// Request sends request to the API and returns the decoded reply.
//
// When the first choice of the reply contains tool calls, they are appended
// to request as an assistant tool-call message so the conversation can be
// continued. An error is returned when encoding, sending, reading or decoding
// fails, or when the reply contains no choices.
func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) {
	jsonAiRequest, err := json.Marshal(request)
	if err != nil {
		return AgentResponse{}, err
	}

	httpRequest, err := client.getRequest(jsonAiRequest)
	if err != nil {
		return AgentResponse{}, err
	}

	resp, err := client.Do(httpRequest)
	if err != nil {
		return AgentResponse{}, err
	}
	// The body must always be closed, otherwise the connection is leaked.
	defer resp.Body.Close()

	response, err := io.ReadAll(resp.Body)
	if err != nil {
		return AgentResponse{}, err
	}

	agentResponse := AgentResponse{}
	if err = json.Unmarshal(response, &agentResponse); err != nil {
		return AgentResponse{}, err
	}

	log.Println(string(response))

	// Error payloads decode fine but carry no choices; indexing them would panic.
	if len(agentResponse.Choices) == 0 {
		return AgentResponse{}, errors.New("Response contained no choices.")
	}

	toolCalls := agentResponse.Choices[0].Message.ToolCalls
	if len(toolCalls) > 0 {
		// Should for sure be more flexible.
		request.AddToolCall(AgentAssistantToolCall{
			Role:      "assistant",
			Content:   "",
			ToolCalls: toolCalls,
		})
	}

	return agentResponse, nil
}
// GetImageInfo sends the image to the "pixtral-12b-2409" model together with
// the client's system prompt and the RESPONSE_FORMAT schema, and decodes the
// reply into an ImageInfo.
func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
	aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData)
	if err != nil {
		return ImageInfo{}, err
	}

	// Decode the schema so it is sent as a JSON object instead of the string
	// getCompletionsForImage stored.
	var jsonSchema any
	if err = json.Unmarshal([]byte(RESPONSE_FORMAT), &jsonSchema); err != nil {
		return ImageInfo{}, err
	}

	aiRequest.ResponseFormat = ResponseFormat{
		Type:       "json_schema",
		JsonSchema: jsonSchema,
	}

	jsonAiRequest, err := json.Marshal(aiRequest)
	if err != nil {
		return ImageInfo{}, err
	}

	request, err := client.getRequest(jsonAiRequest)
	if err != nil {
		return ImageInfo{}, err
	}

	resp, err := client.Do(request)
	if err != nil {
		return ImageInfo{}, err
	}
	// The body must always be closed, otherwise the connection is leaked.
	defer resp.Body.Close()

	response, err := io.ReadAll(resp.Body)
	if err != nil {
		return ImageInfo{}, err
	}

	log.Println(string(response))

	return parseAgentResponse(response)
}

View File

@ -1,213 +0,0 @@
package agents
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"testing"
)
// TestMessageBuilder checks that a system prompt can be added to an empty
// conversation and ends up as its only message.
func TestMessageBuilder(t *testing.T) {
	var messages AgentMessages

	if err := messages.AddSystem("Some prompt"); err != nil {
		t.Fatal(err)
	}

	if count := len(messages.Messages); count != 1 {
		t.Fatalf("Expected length 1, got %d.\n", count)
	}
}
// TestMessageBuilderImage builds a conversation of a system prompt followed
// by an image and verifies the type, role and content of both messages.
func TestMessageBuilderImage(t *testing.T) {
	content := AgentMessages{}

	prompt := "some prompt"
	imageTitle := "image.png"
	data := []byte("some data")

	// Setup failures must stop the test instead of surfacing later as
	// confusing length mismatches.
	if err := content.AddSystem(prompt); err != nil {
		t.Fatal(err)
	}
	if err := content.AddImage(imageTitle, data); err != nil {
		t.Fatal(err)
	}

	if len(content.Messages) != 2 {
		t.Fatalf("Expected length 2, got %d.\n", len(content.Messages))
	}

	promptMessage, ok := content.Messages[0].(AgentTextMessage)
	if !ok {
		t.Fatalf("Expected text content message, got %T\n", content.Messages[0])
	}
	if promptMessage.Role != ROLE_SYSTEM {
		t.Fatal("Prompt message role is incorrect.")
	}
	if promptMessage.Content != prompt {
		t.Fatal("Prompt message content is incorrect.")
	}

	arrayContentMessage, ok := content.Messages[1].(AgentArrayMessage)
	if !ok {
		t.Fatalf("Expected array content message, got %T\n", content.Messages[1])
	}
	if arrayContentMessage.Role != ROLE_USER {
		t.Fatal("Array content message role is incorrect.")
	}
	if len(arrayContentMessage.Content) != 1 {
		t.Fatalf("Expected length 1, got %d.\n", len(arrayContentMessage.Content))
	}

	imageContent, ok := arrayContentMessage.Content[0].(AgentImage)
	if !ok {
		t.Fatalf("Expected image content, got %T\n", arrayContentMessage.Content[0])
	}

	base64data := base64.StdEncoding.EncodeToString(data)
	url := fmt.Sprintf("data:image/%s;base64,%s", "png", base64data)
	if imageContent.ImageUrl != url {
		t.Fatalf("Expected %s, but got %s.\n", url, imageContent.ImageUrl)
	}
}
// TestFullImageRequest builds a complete image request and compares its JSON
// encoding with the expected wire format.
func TestFullImageRequest(t *testing.T) {
	request, err := getCompletionsForImage("model", 0.1, "You are an assistant", "image.png", "", []byte("some data"))
	if err != nil {
		// Report the error itself; the request value says nothing about the failure.
		t.Fatal(err)
	}

	jsonData, err := json.Marshal(request)
	if err != nil {
		t.Fatal(err)
	}

	// NOTE(review): this expects "image_url" to be an object with a "url"
	// field, while AgentImage declares ImageUrl as a plain string — confirm
	// which shape is intended.
	expectedJson := `{"model":"model","temperature":0.1,"response_format":{"type":"json_schema","json_schema":""},"messages":[{"role":"system","content":"You are an assistant"},{"role":"user","content":[{"type":"image_url","image_url":{"url":"data:image/png;base64,c29tZSBkYXRh"}}]}]}`
	if string(jsonData) != expectedJson {
		t.Fatalf("Expected:\n%s\n Got:\n%s\n", expectedJson, string(jsonData))
	}
}
// TestResponse runs GetImageInfo against a stubbed HTTP round trip and checks
// the decoded tags, text and links.
func TestResponse(t *testing.T) {
	// GetImageInfo decodes a chat-completions envelope and reads the image
	// info from the single choice's content, so the stub has to return the
	// envelope rather than the bare image info.
	testResponse := `{"choices":[{"index":0,"message":{"role":"assistant","content":"{\"tags\": [\"tag1\", \"tag2\"], \"text\": [\"text1\"], \"links\": []}"},"finish_reason":"stop"}]}`
	buffer := bytes.NewReader([]byte(testResponse))
	body := io.NopCloser(buffer)

	client := AgentClient{
		url:    "http://localhost:1234",
		apiKey: "some-key",
		Do: func(_req *http.Request) (*http.Response, error) {
			return &http.Response{Body: body}, nil
		},
	}

	info, err := client.GetImageInfo("image.png", []byte("some data"))
	if err != nil {
		t.Fatal(err)
	}

	if len(info.Tags) != 2 || len(info.Text) != 1 || len(info.Links) != 0 {
		t.Fatalf("Some lengths are wrong.\nTags: %d\nText: %d\nLinks: %d\n", len(info.Tags), len(info.Text), len(info.Links))
	}
	if info.Tags[0] != "tag1" {
		t.Fatal("0th tag is wrong.")
	}
	if info.Tags[1] != "tag2" {
		t.Fatal("1th tag is wrong.")
	}
	if info.Text[0] != "text1" {
		t.Fatal("0th text is wrong.")
	}
}
// TestResponseParsing feeds a recorded chat-completions reply through
// parseAgentResponse and expects exactly one link, one tag and one text entry.
func TestResponseParsing(t *testing.T) {
	response := `{
"id": "chatcmpl-B4XgiHcd7A2nyK7eyARdggSvfFuWQ",
"object": "chat.completion",
"created": 1740422508,
"model": "gpt-4o-mini-2024-07-18",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "{\"links\":[\"link\"],\"tags\":[\"tag\"],\"text\":[\"text\"]}",
"refusal": null
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 775,
"completion_tokens": 33,
"total_tokens": 808,
"prompt_tokens_details": {
"cached_tokens": 0,
"audio_tokens": 0
},
"completion_tokens_details": {
"reasoning_tokens": 0,
"audio_tokens": 0,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0
}
},
"service_tier": "default",
"system_fingerprint": "fp_7fcd609668"
}`

	imageParsed, err := parseAgentResponse([]byte(response))
	if err != nil {
		t.Fatal(err)
	}

	links, tags, text := imageParsed.Links, imageParsed.Tags, imageParsed.Text

	if len(links) != 1 || links[0] != "link" {
		t.Log(imageParsed)
		t.Fatal("Should have one link called 'link'.")
	}
	if len(tags) != 1 || tags[0] != "tag" {
		t.Log(imageParsed)
		t.Fatal("Should have one tag called 'tag'.")
	}
	if len(text) != 1 || text[0] != "text" {
		t.Log(imageParsed)
		t.Fatal("Should have one text called 'text'.")
	}
}

View File

@ -0,0 +1,208 @@
package client
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"path/filepath"
)
// Chat is the ordered list of messages exchanged with the model.
type Chat struct {
Messages []ChatMessage `json:"messages"`
}
// ChatMessage is implemented by every message kind stored in a Chat.
// IsResponse reports whether the message came from the model.
type ChatMessage interface {
IsResponse() bool
}
// TODO: the role could be inferred from the type.
// This would solve some bugs.
/*
Is there a world where this actually becomes the product?
Where we build such a resilient system of AI calls that we
can build some app builder, or even just an API system,
with a fancy UI?
Manage all the complexity for the user?
*/
// =============================================
// Messages from us to the AI.
// =============================================
// UserRole is the role of a message we author; it is an alias of string.
type UserRole = string
const (
User UserRole = "user"
System UserRole = "system"
)
// ToolRole is the role of a tool-result message; it is an alias of string.
type ToolRole = string
const (
Tool ToolRole = "tool"
)
// ChatUserMessage is a message sent by us; its content is either a
// SingleMessage or an ArrayMessage. JSON encoding is customized by MarshalJSON.
type ChatUserMessage struct {
Role UserRole `json:"role"`
MessageContent `json:"MessageContent"`
}
// MarshalJSON flattens the message into {"role": ..., "content": ...}, where
// content is a string for a SingleMessage and an array for an ArrayMessage.
//
// The role written is the message's own Role, so system prompts keep the
// "system" role. An empty Role falls back to User. Any other content type,
// including nil, is an error.
func (m ChatUserMessage) MarshalJSON() ([]byte, error) {
	role := m.Role
	if role == "" {
		role = User
	}

	switch t := m.MessageContent.(type) {
	case SingleMessage:
		return json.Marshal(&struct {
			Role    UserRole `json:"role"`
			Content string   `json:"content"`
		}{
			Role:    role,
			Content: t.Content,
		})
	case ArrayMessage:
		return json.Marshal(&struct {
			Role    UserRole              `json:"role"`
			Content []ImageMessageContent `json:"content"`
		}{
			Role:    role,
			Content: t.Content,
		})
	}

	return []byte{}, errors.New("Unreachable")
}
// IsResponse is false: the message was written by us, not by the model.
func (r ChatUserMessage) IsResponse() bool {
return false
}
// ChatUserToolResponse carries the result of a tool call back to the model.
type ChatUserToolResponse struct {
Role ToolRole `json:"role"`
// The name of the function we are responding to.
Name string `json:"name"`
Content string `json:"content"`
ToolCallId string `json:"tool_call_id"`
}
// IsResponse is false: tool results are produced on our side.
func (r ChatUserToolResponse) IsResponse() bool {
return false
}
// ChatAiMessage is a message produced by the model. ToolCalls is nil when the
// reply had no "tool_calls" field and is omitted from the JSON in that case.
type ChatAiMessage struct {
Role string `json:"role"`
ToolCalls *[]ToolCall `json:"tool_calls,omitempty"`
Content string `json:"content"`
}
// IsResponse is true: the message came from the model.
func (m ChatAiMessage) IsResponse() bool {
return true
}
// =============================================
// Unique interface for message content.
// =============================================
// MessageContent is the content of a ChatUserMessage; IsSingleMessage tells
// the two implementations apart.
type MessageContent interface {
IsSingleMessage() bool
}
// SingleMessage is content made of one plain string.
type SingleMessage struct {
Content string `json:"content"`
}
// IsSingleMessage is always true for SingleMessage.
func (m SingleMessage) IsSingleMessage() bool {
return true
}
// ArrayMessage is content made of a list of image parts.
type ArrayMessage struct {
Content []ImageMessageContent `json:"content"`
}
// IsSingleMessage is always false for ArrayMessage.
func (m ArrayMessage) IsSingleMessage() bool {
return false
}
// ImageMessageContent is one image part; ImageUrl is serialized as a plain string.
type ImageMessageContent struct {
ImageType string `json:"type"`
ImageUrl string `json:"image_url"`
}
// ImageContentUrl wraps a URL in an object with a single "url" field.
// NOTE(review): not referenced by the code visible here — confirm it is still needed.
type ImageContentUrl struct {
Url string `json:"url"`
}
// =============================================
// Adjacent interfaces.
// =============================================
// ToolCall is one tool invocation requested by the model.
type ToolCall struct {
Index int `json:"index"`
Id string `json:"id"`
Function FunctionCall `json:"function"`
}
// FunctionCall names the function to run; Arguments is kept as the raw
// string received from the API.
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
// =============================================
// Chat methods
// =============================================
// AddSystem appends prompt to the chat as a single-string message whose Role
// is System.
func (chat *Chat) AddSystem(prompt string) {
	systemMessage := ChatUserMessage{
		Role:           System,
		MessageContent: SingleMessage{Content: prompt},
	}
	chat.Messages = append(chat.Messages, systemMessage)
}
// AddImage appends a user message that carries a single image as a base64
// data URL. The image subtype in the URL is imageName's extension without the
// leading dot.
//
// It returns an error, and appends nothing, when imageName has no usable
// extension.
func (chat *Chat) AddImage(imageName string, image []byte) error {
	extension := filepath.Ext(imageName)
	// filepath.Ext returns "." for a name like "photo."; that must be rejected
	// as well, otherwise the URL would start with "data:image/;base64,".
	if len(extension) <= 1 {
		// TODO: could also validate for image types we support.
		return errors.New("Image does not have extension")
	}

	extension = extension[1:]
	encodedString := base64.StdEncoding.EncodeToString(image)

	messageContent := ArrayMessage{
		Content: []ImageMessageContent{
			{
				ImageType: "image_url",
				ImageUrl:  fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
			},
		},
	}

	chat.Messages = append(chat.Messages, ChatUserMessage{Role: User, MessageContent: messageContent})
	return nil
}
// AddAiResponse appends a model-produced message to the chat.
func (chat *Chat) AddAiResponse(res ChatAiMessage) {
chat.Messages = append(chat.Messages, res)
}
// AddToolResponse appends the result of a tool call to the chat.
func (chat *Chat) AddToolResponse(res ChatUserToolResponse) {
chat.Messages = append(chat.Messages, res)
}
// GetLatest returns the most recently added message, or an error when the
// chat holds no messages yet.
func (chat Chat) GetLatest() (ChatMessage, error) {
	count := len(chat.Messages)
	if count < 1 {
		return nil, errors.New("Not enough messages")
	}

	return chat.Messages[count-1], nil
}

View File

@ -0,0 +1,24 @@
package client
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
// TestFlatMarshallSingleMessage checks that a single-string user message is
// encoded as a flat {"role", "content"} object.
func TestFlatMarshallSingleMessage(t *testing.T) {
	require := require.New(t)

	message := ChatUserMessage{
		Role: User,
		MessageContent: SingleMessage{
			Content: "Hello",
		},
	}

	// Named "encoded" so the encoding/json package is not shadowed.
	encoded, err := json.Marshal(message)
	require.NoError(err)

	// testify takes (expected, actual); the order decides how a failure is reported.
	require.Equal("{\"role\":\"user\",\"content\":\"Hello\"}", string(encoded))
}

View File

@ -0,0 +1,194 @@
package client
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
)
// ResponseFormat is serialized as the "response_format" field of a request.
type ResponseFormat struct {
Type string `json:"type"`
JsonSchema any `json:"json_schema"`
}
// AgentRequestBody is the request sent to the chat-completions endpoint.
// JSON encoding is customized by MarshalJSON.
type AgentRequestBody struct {
Model string `json:"model"`
Temperature float64 `json:"temperature"`
ResponseFormat ResponseFormat `json:"response_format"`
Tools *any `json:"tools,omitempty"`
ToolChoice *string `json:"tool_choice,omitempty"`
// Name of the tool that ends tool processing (compared in Process);
// never sent to the API.
EndToolCall string `json:"-"`
Chat *Chat `json:"messages"`
}
// MarshalJSON encodes the request with the chat's messages placed directly
// under "messages" and without EndToolCall, which is only used locally.
//
// A nil Chat is encoded like a chat without messages ("messages": null)
// instead of causing a nil-pointer panic.
func (req AgentRequestBody) MarshalJSON() ([]byte, error) {
	var messages []ChatMessage
	if req.Chat != nil {
		messages = req.Chat.Messages
	}

	return json.Marshal(&struct {
		Model          string         `json:"model"`
		Temperature    float64        `json:"temperature"`
		ResponseFormat ResponseFormat `json:"response_format"`
		Tools          *any           `json:"tools,omitempty"`
		ToolChoice     *string        `json:"tool_choice,omitempty"`
		Messages       []ChatMessage  `json:"messages"`
	}{
		Model:          req.Model,
		Temperature:    req.Temperature,
		ResponseFormat: req.ResponseFormat,
		Tools:          req.Tools,
		ToolChoice:     req.ToolChoice,
		Messages:       messages,
	})
}
// ResponseChoice is one entry of the response's "choices" array.
type ResponseChoice struct {
Index int `json:"index"`
Message ChatAiMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
// AgentResponse holds the fields of the chat-completions reply that are decoded.
type AgentResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Choices []ResponseChoice `json:"choices"`
Created int `json:"created"`
}
// AgentClient posts requests to url, authenticating with apiKey as a bearer
// token, and dispatches the model's tool calls through ToolHandler.
type AgentClient struct {
url string
apiKey string
responseFormat string
ToolHandler ToolsHandlers
// Do performs the HTTP round trip; it is a field so a stub can be injected.
Do func(req *http.Request) (*http.Response, error)
}
// OPENAI_API_KEY is the name of the environment variable the API key is read from.
const OPENAI_API_KEY = "OPENAI_API_KEY"

// CreateAgentClient builds an AgentClient with an empty tool registry, taking
// the API key from the OPENAI_API_KEY environment variable. An error is
// returned when that variable is unset or empty.
func CreateAgentClient() (AgentClient, error) {
	key := os.Getenv(OPENAI_API_KEY)
	if key == "" {
		return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
	}

	// Default transport: every call goes through a fresh http.Client.
	send := func(req *http.Request) (*http.Response, error) {
		return (&http.Client{}).Do(req)
	}

	agentClient := AgentClient{
		url:         "https://api.mistral.ai/v1/chat/completions",
		apiKey:      key,
		Do:          send,
		ToolHandler: ToolsHandlers{handlers: map[string]ToolHandler{}},
	}
	return agentClient, nil
}
// getRequest wraps body in a POST request to the client's URL, carrying the
// API key as a bearer token and a JSON content type.
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
	httpRequest, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body))
	if err != nil {
		return httpRequest, err
	}

	headers := httpRequest.Header
	headers.Add("Authorization", "Bearer "+client.apiKey)
	headers.Add("Content-Type", "application/json")
	return httpRequest, nil
}
// Request sends req to the API, appends the model's reply to req.Chat and
// returns the decoded response.
//
// An error is returned when encoding, sending, reading or decoding fails, or
// when the reply does not contain exactly one choice; in those cases nothing
// is appended to the chat.
func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error) {
	jsonAiRequest, err := json.Marshal(req)
	if err != nil {
		return AgentResponse{}, err
	}

	httpRequest, err := client.getRequest(jsonAiRequest)
	if err != nil {
		return AgentResponse{}, err
	}

	resp, err := client.Do(httpRequest)
	if err != nil {
		return AgentResponse{}, err
	}
	// The body must always be closed, otherwise the connection is leaked;
	// ToolLoop calls Request repeatedly, so leaks would add up.
	defer resp.Body.Close()

	response, err := io.ReadAll(resp.Body)
	if err != nil {
		return AgentResponse{}, err
	}

	fmt.Println(string(response))

	agentResponse := AgentResponse{}
	if err = json.Unmarshal(response, &agentResponse); err != nil {
		return AgentResponse{}, err
	}

	if len(agentResponse.Choices) != 1 {
		return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
	}

	req.Chat.AddAiResponse(agentResponse.Choices[0].Message)

	return agentResponse, nil
}
// ToolLoop alternates between processing the tool calls of the latest message
// (Process) and asking the model for its next message (Request).
//
// The loop has no nil-return path: it only stops when one of the two steps
// returns an error. Reaching the end tool is reported as FinishedCall, so
// callers that treat that as success need to compare the returned error
// against FinishedCall.
func (client AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error {
for {
err := client.Process(info, req)
if err != nil {
return err
}
_, err = client.Request(req)
if err != nil {
return err
}
}
}
// FinishedCall is returned by Process when the model calls the tool named by
// the request's EndToolCall.
// NOTE(review): the message text reads "tool tool" — likely a typo.
var FinishedCall = errors.New("Last tool tool was called")
// Process runs the tool calls requested by the latest chat message and
// appends one tool response per call to the chat.
//
// It returns an error when the chat is empty or its latest message is not a
// ChatAiMessage, and FinishedCall as soon as a call targets req.EndToolCall;
// calls listed after that one are not executed. A message without tool calls
// is not an error.
func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
	latest, err := req.Chat.GetLatest()
	if err != nil {
		return err
	}

	aiMessage, isAiMessage := latest.(ChatAiMessage)
	if !isAiMessage {
		return errors.New("Latest message isnt an AI message")
	}

	// Not an error, we just dont have any tool calls to process.
	if aiMessage.ToolCalls == nil {
		return nil
	}

	calls := *aiMessage.ToolCalls
	for i := range calls {
		if calls[i].Function.Name == req.EndToolCall {
			return FinishedCall
		}

		req.Chat.AddToolResponse(client.ToolHandler.Handle(info, calls[i]))
	}

	return nil
}

View File

@ -0,0 +1,70 @@
package client
import (
"encoding/json"
"errors"
"github.com/google/uuid"
)
// ToolHandlerInfo is passed unchanged to every tool handler invocation.
type ToolHandlerInfo struct {
UserId uuid.UUID
ImageId uuid.UUID
}
// ToolHandler wraps the function run for one tool; the function receives the
// raw argument string and returns the string sent back to the model.
type ToolHandler struct {
Fn func(info ToolHandlerInfo, args string, call ToolCall) (string, error)
}
// ToolsHandlers is a registry of tool handlers keyed by tool name.
type ToolsHandlers struct {
handlers map[string]ToolHandler
}
// NOTE(review): not referenced by the code visible here — confirm it is still needed.
var NoToolCallError = errors.New("An assistant tool call with no tool calls was provided.")
// NonExistantTool is the response content used when no handler is registered
// under the requested name.
const NonExistantTool = "This tool does not exist"
// Handle runs the handler registered for the tool call and wraps the outcome
// in a tool response addressed to that call.
//
// Failures never abort: an unknown tool yields NonExistantTool as content and
// a handler error yields the error's text, so the model gets to see what went
// wrong.
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCall) ChatUserToolResponse {
	call := toolCallMessage.Function

	reply := ChatUserToolResponse{
		Role:       "tool",
		Name:       call.Name,
		ToolCallId: toolCallMessage.Id,
	}

	tool, found := handler.handlers[call.Name]
	if !found {
		reply.Content = NonExistantTool
		return reply
	}

	result, err := tool.Fn(info, call.Arguments, toolCallMessage)
	if err != nil {
		reply.Content = err.Error()
		return reply
	}

	reply.Content = result
	return reply
}
// AddTool registers fn under name, replacing any handler already registered
// with that name. The value fn returns is JSON-encoded before it is handed
// back as the tool's response; an error from fn or from the encoding is
// passed through unchanged.
//
// AddTool also works on a zero-value ToolsHandlers: the registry map is
// created on first use instead of panicking on a nil map write.
func (handler *ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) {
	if handler.handlers == nil {
		handler.handlers = make(map[string]ToolHandler)
	}

	handler.handlers[name] = ToolHandler{
		Fn: func(info ToolHandlerInfo, args string, call ToolCall) (string, error) {
			res, err := fn(info, args, call)
			if err != nil {
				return "", err
			}

			marshalledRes, err := json.Marshal(res)
			if err != nil {
				return "", err
			}

			return string(marshalledRes), nil
		},
	}
}

View File

@ -0,0 +1,188 @@
package client
import (
"errors"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/suite"
)
// ToolTestSuite exercises tool dispatch through both ToolsHandlers.Handle and
// AgentClient.Process.
type ToolTestSuite struct {
suite.Suite
handler ToolsHandlers
client AgentClient
}
// SetupTest builds a fresh registry with two tools: "a", which returns its
// arguments, and "error", which always fails. The registry is also installed
// on the suite's client.
func (suite *ToolTestSuite) SetupTest() {
suite.handler = ToolsHandlers{
handlers: map[string]ToolHandler{},
}
suite.handler.AddTool("a", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) {
return args, nil
})
suite.handler.AddTool("error", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) {
return false, errors.New("I will always error")
})
suite.client.ToolHandler = suite.handler
}
// TestSingleToolCall checks that one call to tool "a" produces a tool
// response holding the JSON-encoded arguments.
func (suite *ToolTestSuite) TestSingleToolCall() {
	require := suite.Require()

	response := suite.handler.Handle(
		ToolHandlerInfo{
			UserId:  uuid.Nil,
			ImageId: uuid.Nil,
		},
		ToolCall{
			Index: 0,
			Id:    "1",
			Function: FunctionCall{
				Name:      "a",
				Arguments: "return",
			},
		})

	// testify takes (expected, actual); the order decides how a failure is reported.
	require.EqualValues(ChatUserToolResponse{
		Role:       "tool",
		Content:    "\"return\"",
		ToolCallId: "1",
		Name:       "a",
	}, response)
}
// TestMultipleToolCalls checks that Process answers every tool call of the
// latest AI message, in order.
func (suite *ToolTestSuite) TestMultipleToolCalls() {
	assert := suite.Assert()
	require := suite.Require()

	chat := Chat{
		Messages: []ChatMessage{ChatAiMessage{
			Role:    "assistant",
			Content: "",
			ToolCalls: &[]ToolCall{
				{
					Index: 0,
					Id:    "1",
					Function: FunctionCall{
						Name:      "a",
						Arguments: "first-call",
					},
				},
				{
					Index: 1,
					Id:    "2",
					Function: FunctionCall{
						Name:      "a",
						Arguments: "second-call",
					},
				},
			},
		}},
	}

	err := suite.client.Process(
		ToolHandlerInfo{
			UserId:  uuid.Nil,
			ImageId: uuid.Nil,
		},
		&AgentRequestBody{
			Chat: &chat,
		})

	require.NoError(err, "Tool call shouldnt return an error")

	// testify takes (expected, actual); the order decides how a failure is reported.
	assert.EqualValues([]ChatMessage{
		ChatUserToolResponse{
			Role:       "tool",
			Content:    "\"first-call\"",
			ToolCallId: "1",
			Name:       "a",
		},
		ChatUserToolResponse{
			Role:       "tool",
			Content:    "\"second-call\"",
			ToolCallId: "2",
			Name:       "a",
		},
	}, chat.Messages[1:])
}
// TestMultipleToolCallsWithErrors checks that a failing tool and an unknown
// tool are reported as response content and do not stop later calls from
// being processed.
func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
	assert := suite.Assert()
	require := suite.Require()

	chat := Chat{
		Messages: []ChatMessage{ChatAiMessage{
			Role:    "assistant",
			Content: "",
			ToolCalls: &[]ToolCall{
				{
					Index: 0,
					Id:    "1",
					Function: FunctionCall{
						Name:      "error",
						Arguments: "",
					},
				},
				{
					Index: 1,
					Id:    "2",
					Function: FunctionCall{
						Name:      "non-existant",
						Arguments: "",
					},
				},
				{
					Index: 2,
					Id:    "3",
					Function: FunctionCall{
						Name:      "a",
						Arguments: "no-error",
					},
				},
			},
		}},
	}

	err := suite.client.Process(
		ToolHandlerInfo{
			UserId:  uuid.Nil,
			ImageId: uuid.Nil,
		},
		&AgentRequestBody{
			Chat: &chat,
		})

	require.NoError(err, "Tool call shouldnt return an error")

	// testify takes (expected, actual); the order decides how a failure is reported.
	assert.EqualValues([]ChatMessage{
		ChatUserToolResponse{
			Role:       "tool",
			Content:    "I will always error",
			ToolCallId: "1",
			Name:       "error",
		},
		ChatUserToolResponse{
			Role:       "tool",
			Content:    "This tool does not exist",
			ToolCallId: "2",
			Name:       "non-existant",
		},
		ChatUserToolResponse{
			Role:       "tool",
			Content:    "\"no-error\"",
			ToolCallId: "3",
			Name:       "a",
		},
	}, chat.Messages[1:])
}
// TestToolSuite is the go test entry point that runs ToolTestSuite.
func TestToolSuite(t *testing.T) {
suite.Run(t, &ToolTestSuite{
client: AgentClient{},
})
}

View File

@ -3,10 +3,8 @@ package agents
import (
"context"
"encoding/json"
"errors"
"log"
"reflect"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents/client"
"screenmark/screenmark/models"
"time"
@ -101,28 +99,25 @@ const TOOLS = `
}
}
},
{
"type": "function",
"function": {
"name": "finish",
"description": "Nothing else to do, call this function.",
"parameters": {
"type": "object",
"properties": {}
}
}
}
{
"type": "function",
"function": {
"name": "finish",
"description": "Nothing else to do. call this function.",
"parameters": {}
}
}
]
`
type EventLocationAgent struct {
client AgentClient
client client.AgentClient
eventModel models.EventModel
locationModel models.LocationModel
contactModel models.ContactModel
toolHandler ToolsHandlers
toolHandler client.ToolsHandlers
}
type ListLocationArguments struct{}
@ -158,157 +153,66 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
toolChoice := "any"
request := AgentRequestBody{
request := client.AgentRequestBody{
Tools: &tools,
ToolChoice: &toolChoice,
Model: "pixtral-12b-2409",
Temperature: 0.3,
ResponseFormat: ResponseFormat{
EndToolCall: "finish",
ResponseFormat: client.ResponseFormat{
Type: "text",
},
Chat: &client.Chat{
Messages: make([]client.ChatMessage, 0),
},
}
err = request.AddSystem(eventLocationPrompt)
if err != nil {
return err
}
request.AddImage(imageName, imageData)
request.Chat.AddSystem(eventLocationPrompt)
request.Chat.AddImage(imageName, imageData)
_, err = agent.client.Request(&request)
if err != nil {
return err
}
toolHandlerInfo := ToolHandlerInfo{
imageId: imageId,
userId: userId,
toolHandlerInfo := client.ToolHandlerInfo{
ImageId: imageId,
UserId: userId,
}
_, err = agent.toolHandler.Handle(toolHandlerInfo, &request)
for err == nil {
log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1])
response, requestError := agent.client.Request(&request)
if requestError != nil {
return requestError
}
log.Println(response)
a, innerErr := agent.toolHandler.Handle(toolHandlerInfo, &request)
err = innerErr
log.Println(a)
log.Println("--------------------------")
}
return err
}
// TODO: extract this into a more general tool handler package.
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) {
agentMessage := request.Messages[len(request.Messages)-1]
toolCall, ok := agentMessage.(AgentAssistantToolCall)
if !ok {
return "", errors.New("Latest message was not a tool call.")
}
fnName := toolCall.ToolCalls[0].Function.Name
arguments := toolCall.ToolCalls[0].Function.Arguments
if fnName == "finish" {
return "", errors.New("This is the end! Maybe we just return a boolean.")
}
fn, exists := handler.Handlers[fnName]
if !exists {
return "", errors.New("Could not find tool with this name.")
}
// holy jesus what the fuck.
parseMethod := reflect.ValueOf(fn).Field(1)
if !parseMethod.IsValid() {
return "", errors.New("Parse method not found")
}
parsedArgs := parseMethod.Call([]reflect.Value{reflect.ValueOf(arguments)})
if !parsedArgs[1].IsNil() {
return "", parsedArgs[1].Interface().(error)
}
log.Printf("Calling: %s\n", fnName)
fnMethod := reflect.ValueOf(fn).Field(2)
if !fnMethod.IsValid() {
return "", errors.New("Fn method not found")
}
response := fnMethod.Call([]reflect.Value{reflect.ValueOf(info), parsedArgs[0], reflect.ValueOf(toolCall.ToolCalls[0])})
if !response[1].IsNil() {
return "", response[1].Interface().(error)
}
stringResponse, err := json.Marshal(response[0].Interface())
if err != nil {
return "", err
}
request.AddText(AgentTextMessage{
Role: "tool",
Name: "createLocation",
Content: string(stringResponse),
ToolCallId: toolCall.ToolCalls[0].Id,
})
return string(stringResponse), nil
return agent.client.ToolLoop(toolHandlerInfo, &request)
}
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
client, err := CreateAgentClient(eventLocationPrompt)
agentClient, err := client.CreateAgentClient()
if err != nil {
return EventLocationAgent{}, err
}
agent := EventLocationAgent{
client: client,
client: agentClient,
locationModel: locationModel,
eventModel: eventModel,
contactModel: contactModel,
}
toolHandler := ToolsHandlers{
Handlers: make(map[string]ToolHandlerInterface),
}
toolHandler.Handlers["listLocations"] = ToolHandler[ListLocationArguments, []model.Locations]{
FunctionName: "listLocations",
Parse: func(stringArgs string) (ListLocationArguments, error) {
args := ListLocationArguments{}
err := json.Unmarshal([]byte(stringArgs), &args)
return args, err
agentClient.ToolHandler.AddTool("listLocations",
func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
return agent.locationModel.List(context.Background(), info.UserId)
},
Fn: func(info ToolHandlerInfo, _args ListLocationArguments, call ToolCall) ([]model.Locations, error) {
return agent.locationModel.List(context.Background(), info.userId)
},
}
)
toolHandler.Handlers["createLocation"] = ToolHandler[CreateLocationArguments, model.Locations]{
FunctionName: "createLocation",
Parse: func(stringArgs string) (CreateLocationArguments, error) {
agentClient.ToolHandler.AddTool("createLocation",
func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
args := CreateLocationArguments{}
err := json.Unmarshal([]byte(stringArgs), &args)
err := json.Unmarshal([]byte(_args), &args)
if err != nil {
return model.Locations{}, err
}
return args, err
},
Fn: func(info ToolHandlerInfo, args CreateLocationArguments, call ToolCall) (model.Locations, error) {
ctx := context.Background()
location, err := agent.locationModel.Save(ctx, info.userId, model.Locations{
location, err := agent.locationModel.Save(ctx, info.UserId, model.Locations{
Name: args.Name,
Address: args.Address,
})
@ -317,21 +221,20 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
return location, err
}
_, err = agent.locationModel.SaveToImage(ctx, info.imageId, location.ID)
_, err = agent.locationModel.SaveToImage(ctx, info.ImageId, location.ID)
return location, err
},
}
)
toolHandler.Handlers["createEvent"] = ToolHandler[CreateEventArguments, model.Events]{
FunctionName: "createEvent",
Parse: func(stringArgs string) (CreateEventArguments, error) {
agentClient.ToolHandler.AddTool("createEvent",
func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
args := CreateEventArguments{}
err := json.Unmarshal([]byte(stringArgs), &args)
err := json.Unmarshal([]byte(_args), &args)
if err != nil {
return model.Locations{}, err
}
return args, err
},
Fn: func(info ToolHandlerInfo, args CreateEventArguments, call ToolCall) (model.Events, error) {
ctx := context.Background()
layout := "2006-01-02T15:04:05Z"
@ -346,7 +249,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
return model.Events{}, err
}
event, err := agent.eventModel.Save(ctx, info.userId, model.Events{
event, err := agent.eventModel.Save(ctx, info.UserId, model.Events{
Name: args.Name,
StartDateTime: &startTime,
EndDateTime: &endTime,
@ -356,7 +259,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
return event, err
}
organizer, err := agent.contactModel.Save(ctx, info.userId, model.Contacts{
organizer, err := agent.contactModel.Save(ctx, info.UserId, model.Contacts{
Name: args.Name,
})
@ -364,12 +267,12 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
return event, err
}
_, err = agent.eventModel.SaveToImage(ctx, info.imageId, event.ID)
_, err = agent.eventModel.SaveToImage(ctx, info.ImageId, event.ID)
if err != nil {
return event, err
}
_, err = agent.contactModel.SaveToImage(ctx, info.imageId, organizer.ID)
_, err = agent.contactModel.SaveToImage(ctx, info.ImageId, organizer.ID)
if err != nil {
return event, err
}
@ -386,9 +289,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
return agent.eventModel.UpdateOrganizer(ctx, event.ID, organizer.ID)
},
}
agent.toolHandler = toolHandler
)
return agent, nil
}

View File

@ -3,6 +3,7 @@ package agents
import (
"context"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents/client"
"screenmark/screenmark/models"
"github.com/google/uuid"
@ -19,26 +20,26 @@ Do not return anything except markdown.
`
type NoteAgent struct {
client AgentClient
client client.AgentClient
noteModel models.NoteModel
}
func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
request := AgentRequestBody{
request := client.AgentRequestBody{
Model: "pixtral-12b-2409",
Temperature: 0.3,
ResponseFormat: ResponseFormat{
ResponseFormat: client.ResponseFormat{
Type: "text",
},
Chat: &client.Chat{
Messages: make([]client.ChatMessage, 0),
},
}
err := request.AddSystem(noteAgentPrompt)
if err != nil {
return err
}
request.Chat.AddSystem(noteAgentPrompt)
request.Chat.AddImage(imageName, imageData)
request.AddImage(imageName, imageData)
resp, err := agent.client.Request(&request)
if err != nil {
return err
@ -65,7 +66,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
}
func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) {
client, err := CreateAgentClient(noteAgentPrompt)
client, err := client.CreateAgentClient()
if err != nil {
return NoteAgent{}, err
}

View File

@ -0,0 +1,167 @@
package agents
import (
"encoding/json"
"errors"
"fmt"
"screenmark/screenmark/agents/client"
"github.com/google/uuid"
)
// orchestratorPrompt is the system prompt Orchestrate adds to the chat before
// the image. The agent names it lists (eventLocationAgent, noteAgent,
// defaultAgent) are the same names used for the tool definitions in MY_TOOLS
// and for the handlers registered in NewOrchestratorAgent.
const orchestratorPrompt = `
You are an Orchestrator for various AI agents.
The user will send you images and you have to determine which agents you have to call, in order to best help the user.
You might decide no agent needs to be called.
The agents are available as tool calls.
Agents available:
eventLocationAgent
Use it when you think the image contains an event or a location of any sort. This can be an event page, a map, an address or a date.
noteAgent
Use it when there is text on the screen. Any text, always use this. Use me!
defaultAgent
When none of the above apply.
Always call agents in parallel if you need to call more than 1.
`
// MY_TOOLS is a JSON array of function-tool definitions, one per agent
// (eventLocationAgent, noteAgent, defaultAgent). Every function declares an
// empty parameter object. Orchestrate unmarshals this string and sends the
// result as the request's Tools.
const MY_TOOLS = `
[
{
"type": "function",
"function": {
"name": "eventLocationAgent",
"description": "Uses the event location agent",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
}
},
{
"type": "function",
"function": {
"name": "noteAgent",
"description": "Uses the note agent",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
}
},
{
"type": "function",
"function": {
"name": "defaultAgent",
"description": "Used when you dont think its a good idea to call other agents",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
}
}
]`
// OrchestratorAgent decides which of the other agents should process an image.
// It is built by NewOrchestratorAgent and driven through Orchestrate.
type OrchestratorAgent struct {
// client sends the request and runs the tool loop; the agent tools are
// registered on its ToolHandler in NewOrchestratorAgent.
client client.AgentClient
}
// Status is the value returned by the orchestrator's tool handlers; it is
// serialised as {"ok": <bool>}.
type Status struct {
Ok bool `json:"ok"`
}
// Orchestrate asks the model which agents should handle the given image.
//
// It builds a request containing the orchestrator system prompt, the image and
// the tool definitions from MY_TOOLS (with tool choice "any" and
// "defaultAgent" as the end tool call), sends it once, prints the response,
// and then hands the conversation to the client's tool loop together with the
// user and image ids. Any error from decoding the tools, the request or the
// tool loop is returned unchanged.
//
// TODO: the primary function of the agent could be extracted outwards.
// This is basically the same function as we have in the `event_location_agent.go`
func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
	var toolDefinitions any
	if decodeErr := json.Unmarshal([]byte(MY_TOOLS), &toolDefinitions); decodeErr != nil {
		return decodeErr
	}

	requiredToolChoice := "any"

	body := client.AgentRequestBody{
		Model:          "pixtral-12b-2409",
		Temperature:    0.3,
		ResponseFormat: client.ResponseFormat{Type: "text"},
		ToolChoice:     &requiredToolChoice,
		Tools:          &toolDefinitions,
		EndToolCall:    "defaultAgent",
		Chat:           &client.Chat{Messages: make([]client.ChatMessage, 0)},
	}
	body.Chat.AddSystem(orchestratorPrompt)
	body.Chat.AddImage(imageName, imageData)

	reply, requestErr := agent.client.Request(&body)
	if requestErr != nil {
		return requestErr
	}
	fmt.Println(reply)

	handlerInfo := client.ToolHandlerInfo{
		UserId:  userId,
		ImageId: imageId,
	}

	return agent.client.ToolLoop(handlerInfo, &body)
}
// NewOrchestratorAgent creates an agent client and registers one tool per
// downstream agent on it.
//
//   - "eventLocationAgent" starts eventLocationAgent.GetLocations in the
//     background for the given image and reports success immediately.
//   - "noteAgent" starts noteAgent.GetNotes in the background for the given
//     image and reports success immediately.
//   - "defaultAgent" does no work and returns an error alongside its status.
//     NOTE(review): this error appears to be the "finished" signal; how the
//     tool loop treats it is not visible here — confirm before changing it.
//
// imageName and imageData are captured by the tool closures, so the returned
// agent is bound to that one image. An error is returned only when the agent
// client cannot be created.
func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {
	agent, err := client.CreateAgentClient()
	if err != nil {
		return OrchestratorAgent{}, err
	}

	agent.ToolHandler.AddTool("eventLocationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
		// We need a way to keep track of this async?
		// Probably just a DB, because we don't want to wait. The orchestrator shouldn't wait for this stuff to finish.
		go func() {
			// The orchestrator does not wait for the result, so a failure is
			// reported here instead of being silently dropped.
			if err := eventLocationAgent.GetLocations(info.UserId, info.ImageId, imageName, imageData); err != nil {
				fmt.Println("eventLocationAgent failed:", err)
			}
		}()

		return Status{
			Ok: true,
		}, nil
	})

	agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
		go func() {
			// Same as above: surface background failures.
			if err := noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData); err != nil {
				fmt.Println("noteAgent failed:", err)
			}
		}()

		return Status{
			Ok: true,
		}, nil
	})

	agent.ToolHandler.AddTool("defaultAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
		// Do nothing.
		return Status{
			Ok: true,
		}, errors.New("Finished! Kinda bad return type but...")
	})

	return OrchestratorAgent{
		client: agent,
	}, nil
}

View File

@ -1,26 +0,0 @@
package agents
import "github.com/google/uuid"
// ToolHandlerInfo is the per-invocation context passed to every tool handler's
// Fn. Both fields are unexported, so they can only be read or set from inside
// this package.
type ToolHandlerInfo struct {
// userId is passed by tool handlers to the models' List and Save calls.
userId uuid.UUID
// imageId is passed by tool handlers to the models' SaveToImage calls.
imageId uuid.UUID
}
// ToolHandler describes one callable tool: its name, how to decode its raw
// argument string into TArgs, and the function that performs the work and
// produces a TResp.
//
// ToolsHandlers.Handle reaches Parse and Fn through reflection, so renaming or
// reordering these fields must be checked against that lookup.
type ToolHandler[TArgs any, TResp any] struct {
// FunctionName is the value returned by GetFunctionName.
FunctionName string
// Parse decodes the raw argument string of a tool call.
Parse func(args string) (TArgs, error)
// Fn executes the tool with the decoded arguments and the originating call.
Fn func(info ToolHandlerInfo, args TArgs, call ToolCall) (TResp, error)
}
// ToolHandlerInterface is the non-generic view of a ToolHandler. It lets
// handlers with different type arguments be stored together in
// ToolsHandlers.Handlers.
type ToolHandlerInterface interface {
// GetFunctionName returns the handler's tool name.
GetFunctionName() string
}
// GetFunctionName reports the tool name stored in FunctionName. Having this
// method is what makes every ToolHandler instantiation satisfy
// ToolHandlerInterface.
func (h ToolHandler[TArgs, TResp]) GetFunctionName() string {
	name := h.FunctionName
	return name
}
// ToolsHandlers is a registry of tool handlers keyed by the tool's function
// name.
type ToolsHandlers struct {
// Handlers is indexed by the function name taken from an incoming tool
// call. It must be initialised (e.g. with make) before entries are added.
Handlers map[string]ToolHandlerInterface
}

View File

@ -10,6 +10,6 @@ require (
github.com/joho/godotenv v1.5.1 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/testify v1.9.0 // indirect
github.com/stretchr/testify v1.10.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@ -14,6 +14,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -13,6 +13,7 @@ import (
"path/filepath"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents"
"screenmark/screenmark/agents/client"
"screenmark/screenmark/models"
"time"
@ -24,41 +25,13 @@ import (
)
type TestAiClient struct {
ImageInfo agents.ImageInfo
ImageInfo client.ImageMessageContent
}
func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (agents.ImageInfo, error) {
func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (client.ImageMessageContent, error) {
return client.ImageInfo, nil
}
// GetAiClient selects the AI client from the MODE environment variable.
//
// When MODE is "TESTING" it returns a TestAiClient preloaded with a fixed
// ImageInfo fixture (one tag, link, text, location and event). For any other
// value it returns whatever agents.CreateAgentClient yields for agents.PROMPT.
func GetAiClient() (agents.AiClient, error) {
	if os.Getenv("MODE") != "TESTING" {
		return agents.CreateAgentClient(agents.PROMPT)
	}

	var (
		address     = "10 Downing Street"
		description = "Cheese and Crackers"
	)

	fixture := agents.ImageInfo{
		Tags:  []string{"tag"},
		Links: []string{"links"},
		Text:  []string{"text"},
		Locations: []model.Locations{
			{
				ID:      uuid.Nil,
				Name:    "London",
				Address: &address,
			},
		},
		Events: []model.Events{
			{
				ID:          uuid.Nil,
				Name:        "Party",
				Description: &description,
			},
		},
	}

	return TestAiClient{ImageInfo: fixture}, nil
}
func main() {
err := godotenv.Load()
if err != nil {
@ -74,9 +47,6 @@ func main() {
}
imageModel := models.NewImageModel(db)
linkModel := models.NewLinkModel(db)
tagModel := models.NewTagModel(db)
textModel := models.NewTextModel(db)
locationModel := models.NewLocationModel(db)
eventModel := models.NewEventModel(db)
userModel := models.NewUserModel(db)
@ -105,11 +75,6 @@ func main() {
ctx := context.Background()
go func() {
openAiClient, err := GetAiClient()
if err != nil {
panic(err)
}
locationAgent, err := agents.NewLocationEventAgent(locationModel, eventModel, contactModel)
if err != nil {
panic(err)
@ -127,51 +92,21 @@ func main() {
return
}
userImage, err := imageModel.FinishProcessing(ctx, image.ID)
_, err = imageModel.FinishProcessing(ctx, image.ID)
if err != nil {
log.Println("Failed to FinishProcessing")
log.Println(err)
return
}
// TODO: this can very much be parallel
log.Println("Calling locationAgent!")
err = locationAgent.GetLocations(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
log.Println(err)
log.Println("Calling noteAgent!")
err = noteAgent.GetNotes(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
log.Println(err)
return
imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image)
orchestrator, err := agents.NewOrchestratorAgent(locationAgent, noteAgent, image.Image.ImageName, image.Image.Image)
if err != nil {
log.Println("Failed to GetImageInfo")
log.Println(err)
return
panic(err)
}
err = tagModel.SaveToImage(ctx, userImage.ImageID, imageInfo.Tags)
err = orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
if err != nil {
log.Println("Failed to save tags")
log.Println(err)
return
}
err = linkModel.Save(ctx, userImage.ImageID, imageInfo.Links)
if err != nil {
log.Println("Failed to save links")
log.Println(err)
return
}
err = textModel.Save(ctx, userImage.ImageID, imageInfo.Text)
if err != nil {
log.Println("Failed to save text")
log.Println(err)
return
fmt.Println(err)
}
}()
}