feat/inter-agent-communication #10
@ -65,8 +65,8 @@ func (m ChatUserMessage) MarshalJSON() ([]byte, error) {
|
|||||||
})
|
})
|
||||||
case ArrayMessage:
|
case ArrayMessage:
|
||||||
return json.Marshal(&struct {
|
return json.Marshal(&struct {
|
||||||
Role UserRole `json:"role"`
|
Role UserRole `json:"role"`
|
||||||
Content []ImageMessageContent `json:"content"`
|
Content []MessageContentMessage `json:"content"`
|
||||||
}{
|
}{
|
||||||
Role: User,
|
Role: User,
|
||||||
Content: t.Content,
|
Content: t.Content,
|
||||||
@ -121,18 +121,35 @@ func (m SingleMessage) IsSingleMessage() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ArrayMessage struct {
|
type ArrayMessage struct {
|
||||||
Content []ImageMessageContent `json:"content"`
|
Content []MessageContentMessage `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m ArrayMessage) IsSingleMessage() bool {
|
func (m ArrayMessage) IsSingleMessage() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MessageContentMessage interface {
|
||||||
|
IsImageMessage() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type TextMessageContent struct {
|
||||||
|
TextType string `json:"type"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m TextMessageContent) IsImageMessage() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type ImageMessageContent struct {
|
type ImageMessageContent struct {
|
||||||
ImageType string `json:"type"`
|
ImageType string `json:"type"`
|
||||||
ImageUrl string `json:"image_url"`
|
ImageUrl string `json:"image_url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m ImageMessageContent) IsImageMessage() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
type ImageContentUrl struct {
|
type ImageContentUrl struct {
|
||||||
Url string `json:"url"`
|
Url string `json:"url"`
|
||||||
}
|
}
|
||||||
@ -165,7 +182,7 @@ func (chat *Chat) AddSystem(prompt string) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (chat *Chat) AddImage(imageName string, image []byte) error {
|
func (chat *Chat) AddImage(imageName string, image []byte, query *string) error {
|
||||||
extension := filepath.Ext(imageName)
|
extension := filepath.Ext(imageName)
|
||||||
if len(extension) == 0 {
|
if len(extension) == 0 {
|
||||||
// TODO: could also validate for image types we support.
|
// TODO: could also validate for image types we support.
|
||||||
@ -173,14 +190,28 @@ func (chat *Chat) AddImage(imageName string, image []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
extension = extension[1:]
|
extension = extension[1:]
|
||||||
|
|
||||||
encodedString := base64.StdEncoding.EncodeToString(image)
|
encodedString := base64.StdEncoding.EncodeToString(image)
|
||||||
|
|
||||||
messageContent := ArrayMessage{
|
contentLength := 1
|
||||||
Content: make([]ImageMessageContent, 1),
|
if query != nil {
|
||||||
|
contentLength = 2
|
||||||
}
|
}
|
||||||
|
|
||||||
messageContent.Content[0] = ImageMessageContent{
|
messageContent := ArrayMessage{
|
||||||
|
Content: make([]MessageContentMessage, contentLength),
|
||||||
|
}
|
||||||
|
|
||||||
|
index := 0
|
||||||
|
|
||||||
|
if query != nil {
|
||||||
|
messageContent.Content[index] = TextMessageContent{
|
||||||
|
TextType: "text",
|
||||||
|
Text: *query,
|
||||||
|
}
|
||||||
|
index += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
messageContent.Content[index] = ImageMessageContent{
|
||||||
ImageType: "image_url",
|
ImageType: "image_url",
|
||||||
ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
|
ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
|
||||||
}
|
}
|
||||||
|
@ -73,16 +73,28 @@ type AgentClient struct {
|
|||||||
|
|
||||||
Log *log.Logger
|
Log *log.Logger
|
||||||
|
|
||||||
|
Reply string
|
||||||
|
|
||||||
Do func(req *http.Request) (*http.Response, error)
|
Do func(req *http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
Options CreateAgentClientOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
const OPENAI_API_KEY = "OPENAI_API_KEY"
|
const OPENAI_API_KEY = "OPENAI_API_KEY"
|
||||||
|
|
||||||
func CreateAgentClient(log *log.Logger) (AgentClient, error) {
|
type CreateAgentClientOptions struct {
|
||||||
|
Log *log.Logger
|
||||||
|
SystemPrompt string
|
||||||
|
JsonTools string
|
||||||
|
EndToolCall string
|
||||||
|
Query *string
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateAgentClient(options CreateAgentClientOptions) AgentClient {
|
||||||
apiKey := os.Getenv(OPENAI_API_KEY)
|
apiKey := os.Getenv(OPENAI_API_KEY)
|
||||||
|
|
||||||
if len(apiKey) == 0 {
|
if len(apiKey) == 0 {
|
||||||
return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
|
panic("No api key")
|
||||||
}
|
}
|
||||||
|
|
||||||
return AgentClient{
|
return AgentClient{
|
||||||
@ -93,12 +105,14 @@ func CreateAgentClient(log *log.Logger) (AgentClient, error) {
|
|||||||
return client.Do(req)
|
return client.Do(req)
|
||||||
},
|
},
|
||||||
|
|
||||||
Log: log,
|
Log: options.Log,
|
||||||
|
|
||||||
ToolHandler: ToolsHandlers{
|
ToolHandler: ToolsHandlers{
|
||||||
handlers: map[string]ToolHandler{},
|
handlers: map[string]ToolHandler{},
|
||||||
},
|
},
|
||||||
}, nil
|
|
||||||
|
Options: options,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
|
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
|
||||||
@ -146,8 +160,6 @@ func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error)
|
|||||||
return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
|
return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
|
||||||
}
|
}
|
||||||
|
|
||||||
client.Log.SetLevel(log.DebugLevel)
|
|
||||||
|
|
||||||
msg := agentResponse.Choices[0].Message
|
msg := agentResponse.Choices[0].Message
|
||||||
|
|
||||||
if len(msg.Content) > 0 {
|
if len(msg.Content) > 0 {
|
||||||
@ -170,7 +182,7 @@ func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error)
|
|||||||
return agentResponse, nil
|
return agentResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error {
|
func (client *AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error {
|
||||||
for {
|
for {
|
||||||
err := client.Process(info, req)
|
err := client.Process(info, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -186,7 +198,7 @@ func (client AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody)
|
|||||||
|
|
||||||
var FinishedCall = errors.New("Last tool tool was called")
|
var FinishedCall = errors.New("Last tool tool was called")
|
||||||
|
|
||||||
func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
|
func (client *AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
message, err := req.Chat.GetLatest()
|
message, err := req.Chat.GetLatest()
|
||||||
@ -211,7 +223,10 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e
|
|||||||
|
|
||||||
toolResponse := client.ToolHandler.Handle(info, toolCall)
|
toolResponse := client.ToolHandler.Handle(info, toolCall)
|
||||||
|
|
||||||
client.Log.SetLevel(log.DebugLevel)
|
if toolCall.Function.Name == "reply" {
|
||||||
|
client.Reply = toolCall.Function.Arguments
|
||||||
|
}
|
||||||
|
|
||||||
client.Log.Debugf("Response: %s", toolResponse.Content)
|
client.Log.Debugf("Response: %s", toolResponse.Content)
|
||||||
|
|
||||||
req.Chat.AddToolResponse(toolResponse)
|
req.Chat.AddToolResponse(toolResponse)
|
||||||
@ -220,9 +235,9 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToolCall string, userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
func (client *AgentClient) RunAgent(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||||
var tools any
|
var tools any
|
||||||
err := json.Unmarshal([]byte(jsonTools), &tools)
|
err := json.Unmarshal([]byte(client.Options.JsonTools), &tools)
|
||||||
|
|
||||||
toolChoice := "any"
|
toolChoice := "any"
|
||||||
|
|
||||||
@ -231,7 +246,7 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo
|
|||||||
ToolChoice: &toolChoice,
|
ToolChoice: &toolChoice,
|
||||||
Model: "pixtral-12b-2409",
|
Model: "pixtral-12b-2409",
|
||||||
Temperature: 0.3,
|
Temperature: 0.3,
|
||||||
EndToolCall: endToolCall,
|
EndToolCall: client.Options.EndToolCall,
|
||||||
ResponseFormat: ResponseFormat{
|
ResponseFormat: ResponseFormat{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
},
|
},
|
||||||
@ -240,8 +255,8 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
request.Chat.AddSystem(systemPrompt)
|
request.Chat.AddSystem(client.Options.SystemPrompt)
|
||||||
request.Chat.AddImage(imageName, imageData)
|
request.Chat.AddImage(imageName, imageData, client.Options.Query)
|
||||||
|
|
||||||
_, err = client.Request(&request)
|
_, err = client.Request(&request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -249,8 +264,10 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo
|
|||||||
}
|
}
|
||||||
|
|
||||||
toolHandlerInfo := ToolHandlerInfo{
|
toolHandlerInfo := ToolHandlerInfo{
|
||||||
ImageId: imageId,
|
ImageId: imageId,
|
||||||
UserId: userId,
|
ImageName: imageName,
|
||||||
|
UserId: userId,
|
||||||
|
Image: &imageData,
|
||||||
}
|
}
|
||||||
|
|
||||||
return client.ToolLoop(toolHandlerInfo, &request)
|
return client.ToolLoop(toolHandlerInfo, &request)
|
||||||
|
@ -8,8 +8,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ToolHandlerInfo struct {
|
type ToolHandlerInfo struct {
|
||||||
UserId uuid.UUID
|
UserId uuid.UUID
|
||||||
ImageId uuid.UUID
|
ImageId uuid.UUID
|
||||||
|
ImageName string
|
||||||
|
|
||||||
|
// Pointer because we don't want to copy this around too much.
|
||||||
|
Image *[]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolHandler struct {
|
type ToolHandler struct {
|
||||||
|
@ -3,11 +3,9 @@ package agents
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"os"
|
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
"screenmark/screenmark/agents/client"
|
"screenmark/screenmark/agents/client"
|
||||||
"screenmark/screenmark/models"
|
"screenmark/screenmark/models"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/charmbracelet/log"
|
"github.com/charmbracelet/log"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@ -81,12 +79,6 @@ const contactTools = `
|
|||||||
]
|
]
|
||||||
`
|
`
|
||||||
|
|
||||||
type ContactAgent struct {
|
|
||||||
client client.AgentClient
|
|
||||||
|
|
||||||
contactModel models.ContactModel
|
|
||||||
}
|
|
||||||
|
|
||||||
type listContactsArguments struct{}
|
type listContactsArguments struct{}
|
||||||
type createContactsArguments struct {
|
type createContactsArguments struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
@ -98,23 +90,16 @@ type linkContactArguments struct {
|
|||||||
ContactID string `json:"contactId"`
|
ContactID string `json:"contactId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
|
func NewContactAgent(log *log.Logger, contactModel models.ContactModel) client.AgentClient {
|
||||||
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||||
ReportTimestamp: true,
|
SystemPrompt: contactPrompt,
|
||||||
TimeFormat: time.Kitchen,
|
JsonTools: contactTools,
|
||||||
Prefix: "Contacts 👥",
|
Log: log,
|
||||||
}))
|
EndToolCall: "finish",
|
||||||
if err != nil {
|
})
|
||||||
return ContactAgent{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
agent := ContactAgent{
|
|
||||||
client: agentClient,
|
|
||||||
contactModel: contactModel,
|
|
||||||
}
|
|
||||||
|
|
||||||
agentClient.ToolHandler.AddTool("listContacts", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agentClient.ToolHandler.AddTool("listContacts", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
return agent.contactModel.List(context.Background(), info.UserId)
|
return contactModel.List(context.Background(), info.UserId)
|
||||||
})
|
})
|
||||||
|
|
||||||
agentClient.ToolHandler.AddTool("createContact", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
|
agentClient.ToolHandler.AddTool("createContact", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
|
||||||
@ -126,7 +111,7 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
contact, err := agent.contactModel.Save(ctx, info.UserId, model.Contacts{
|
contact, err := contactModel.Save(ctx, info.UserId, model.Contacts{
|
||||||
Name: args.Name,
|
Name: args.Name,
|
||||||
PhoneNumber: args.PhoneNumber,
|
PhoneNumber: args.PhoneNumber,
|
||||||
Email: args.Email,
|
Email: args.Email,
|
||||||
@ -136,7 +121,7 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
|
|||||||
return model.Contacts{}, err
|
return model.Contacts{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = agent.contactModel.SaveToImage(ctx, info.ImageId, contact.ID)
|
_, err = contactModel.SaveToImage(ctx, info.ImageId, contact.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.Contacts{}, err
|
return model.Contacts{}, err
|
||||||
}
|
}
|
||||||
@ -158,7 +143,7 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = agent.contactModel.SaveToImage(ctx, info.ImageId, contactUuid)
|
_, err = contactModel.SaveToImage(ctx, info.ImageId, contactUuid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -166,5 +151,5 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
|
|||||||
return "Saved", nil
|
return "Saved", nil
|
||||||
})
|
})
|
||||||
|
|
||||||
return agent, nil
|
return agentClient
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,6 @@ package agents
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"os"
|
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
"screenmark/screenmark/agents/client"
|
"screenmark/screenmark/agents/client"
|
||||||
"screenmark/screenmark/models"
|
"screenmark/screenmark/models"
|
||||||
@ -27,6 +26,9 @@ Lists the users already existing events.
|
|||||||
createEvent
|
createEvent
|
||||||
Use this to create a new events.
|
Use this to create a new events.
|
||||||
|
|
||||||
|
getEventLocationId
|
||||||
|
Use this if the image contains a location or place. This tool will return the locationId.
|
||||||
|
|
||||||
finish
|
finish
|
||||||
Call when there is nothing else to do.
|
Call when there is nothing else to do.
|
||||||
`
|
`
|
||||||
@ -63,11 +65,28 @@ const eventTools = `
|
|||||||
"endDateTime": {
|
"endDateTime": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The end time as an ISO string"
|
"description": "The end time as an ISO string"
|
||||||
}
|
},
|
||||||
|
"locationId": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The UUID of this location. You should use getEventLocationId to get this information, but only if you believe the event contains a location"
|
||||||
|
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"required": ["name"]
|
"required": ["name"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "getEventLocationId",
|
||||||
|
"description": "Get the ID of the location on the image, only use if the event contains a location or place.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"required": []
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
@ -83,12 +102,6 @@ const eventTools = `
|
|||||||
}
|
}
|
||||||
]`
|
]`
|
||||||
|
|
||||||
type EventAgent struct {
|
|
||||||
client client.AgentClient
|
|
||||||
|
|
||||||
eventsModel models.EventModel
|
|
||||||
}
|
|
||||||
|
|
||||||
type listEventArguments struct{}
|
type listEventArguments struct{}
|
||||||
type createEventArguments struct {
|
type createEventArguments struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
@ -100,24 +113,20 @@ type linkEventArguments struct {
|
|||||||
EventID string `json:"eventId"`
|
EventID string `json:"eventId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewEventAgent(eventsModel models.EventModel) (EventAgent, error) {
|
func NewEventAgent(log *log.Logger, eventsModel models.EventModel, locationModel models.LocationModel) client.AgentClient {
|
||||||
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||||
ReportTimestamp: true,
|
SystemPrompt: eventPrompt,
|
||||||
TimeFormat: time.Kitchen,
|
JsonTools: eventTools,
|
||||||
Prefix: "Events 📍",
|
Log: log,
|
||||||
}))
|
EndToolCall: "finish",
|
||||||
|
})
|
||||||
|
|
||||||
if err != nil {
|
locationAgent := NewLocationAgent(log.WithPrefix("Events 📅 > Locations 📍"), locationModel)
|
||||||
return EventAgent{}, err
|
locationQuery := "Can you get me the ID of the location present in this image?"
|
||||||
}
|
locationAgent.Options.Query = &locationQuery
|
||||||
|
|
||||||
agent := EventAgent{
|
|
||||||
client: agentClient,
|
|
||||||
eventsModel: eventsModel,
|
|
||||||
}
|
|
||||||
|
|
||||||
agentClient.ToolHandler.AddTool("listEvents", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agentClient.ToolHandler.AddTool("listEvents", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
return agent.eventsModel.List(context.Background(), info.UserId)
|
return eventsModel.List(context.Background(), info.UserId)
|
||||||
})
|
})
|
||||||
|
|
||||||
agentClient.ToolHandler.AddTool("createEvent", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
|
agentClient.ToolHandler.AddTool("createEvent", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
|
||||||
@ -141,7 +150,7 @@ func NewEventAgent(eventsModel models.EventModel) (EventAgent, error) {
|
|||||||
return model.Events{}, err
|
return model.Events{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
events, err := agent.eventsModel.Save(ctx, info.UserId, model.Events{
|
events, err := eventsModel.Save(ctx, info.UserId, model.Events{
|
||||||
Name: args.Name,
|
Name: args.Name,
|
||||||
StartDateTime: &startTime,
|
StartDateTime: &startTime,
|
||||||
EndDateTime: &endTime,
|
EndDateTime: &endTime,
|
||||||
@ -151,7 +160,7 @@ func NewEventAgent(eventsModel models.EventModel) (EventAgent, error) {
|
|||||||
return model.Events{}, err
|
return model.Events{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = agent.eventsModel.SaveToImage(ctx, info.ImageId, events.ID)
|
_, err = eventsModel.SaveToImage(ctx, info.ImageId, events.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.Events{}, err
|
return model.Events{}, err
|
||||||
}
|
}
|
||||||
@ -173,9 +182,17 @@ func NewEventAgent(eventsModel models.EventModel) (EventAgent, error) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
agent.eventsModel.SaveToImage(ctx, info.ImageId, contactUuid)
|
eventsModel.SaveToImage(ctx, info.ImageId, contactUuid)
|
||||||
return "Saved", nil
|
return "Saved", nil
|
||||||
})
|
})
|
||||||
|
|
||||||
return agent, nil
|
agentClient.ToolHandler.AddTool("getEventLocationId", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
|
// TODO: reenable this when I'm creating the agent locally instead of getting it from above.
|
||||||
|
locationAgent.RunAgent(info.UserId, info.ImageId, info.ImageName, *info.Image)
|
||||||
|
|
||||||
|
log.Debugf("Reply from location %s\n", locationAgent.Reply)
|
||||||
|
return locationAgent.Reply, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return agentClient
|
||||||
}
|
}
|
||||||
|
@ -3,11 +3,9 @@ package agents
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"os"
|
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
"screenmark/screenmark/agents/client"
|
"screenmark/screenmark/agents/client"
|
||||||
"screenmark/screenmark/models"
|
"screenmark/screenmark/models"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/charmbracelet/log"
|
"github.com/charmbracelet/log"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@ -27,6 +25,9 @@ Lists the users already existing locations.
|
|||||||
createLocation
|
createLocation
|
||||||
Use this to create a new location, when you don't see a matching one from listLocations call.
|
Use this to create a new location, when you don't see a matching one from listLocations call.
|
||||||
|
|
||||||
|
reply
|
||||||
|
Use this only if the user has asked a question about a location.
|
||||||
|
|
||||||
finish
|
finish
|
||||||
Call when there is nothing else to do.
|
Call when there is nothing else to do.
|
||||||
`
|
`
|
||||||
@ -63,6 +64,22 @@ const locationTools = `
|
|||||||
"required": ["name"]
|
"required": ["name"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "reply",
|
||||||
|
"description": "Reply to a user query, only if the user has asked something",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"locationId": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["locationId"]
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
@ -78,12 +95,6 @@ const locationTools = `
|
|||||||
}
|
}
|
||||||
]`
|
]`
|
||||||
|
|
||||||
type LocationAgent struct {
|
|
||||||
client client.AgentClient
|
|
||||||
|
|
||||||
locationModel models.LocationModel
|
|
||||||
}
|
|
||||||
|
|
||||||
type listLocationArguments struct{}
|
type listLocationArguments struct{}
|
||||||
type createLocationArguments struct {
|
type createLocationArguments struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
@ -93,24 +104,16 @@ type linkLocationArguments struct {
|
|||||||
LocationID string `json:"locationId"`
|
LocationID string `json:"locationId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error) {
|
func NewLocationAgent(log *log.Logger, locationModel models.LocationModel) client.AgentClient {
|
||||||
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||||
ReportTimestamp: true,
|
SystemPrompt: locationPrompt,
|
||||||
TimeFormat: time.Kitchen,
|
JsonTools: locationTools,
|
||||||
Prefix: "Locations 📍",
|
Log: log,
|
||||||
}))
|
EndToolCall: "finish",
|
||||||
|
})
|
||||||
if err != nil {
|
|
||||||
return LocationAgent{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
agent := LocationAgent{
|
|
||||||
client: agentClient,
|
|
||||||
locationModel: locationModel,
|
|
||||||
}
|
|
||||||
|
|
||||||
agentClient.ToolHandler.AddTool("listLocations", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agentClient.ToolHandler.AddTool("listLocations", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
return agent.locationModel.List(context.Background(), info.UserId)
|
return locationModel.List(context.Background(), info.UserId)
|
||||||
})
|
})
|
||||||
|
|
||||||
agentClient.ToolHandler.AddTool("createLocation", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
|
agentClient.ToolHandler.AddTool("createLocation", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
|
||||||
@ -122,7 +125,7 @@ func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error)
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
location, err := agent.locationModel.Save(ctx, info.UserId, model.Locations{
|
location, err := locationModel.Save(ctx, info.UserId, model.Locations{
|
||||||
Name: args.Name,
|
Name: args.Name,
|
||||||
Address: args.Address,
|
Address: args.Address,
|
||||||
})
|
})
|
||||||
@ -131,7 +134,7 @@ func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error)
|
|||||||
return model.Locations{}, err
|
return model.Locations{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = agent.locationModel.SaveToImage(ctx, info.ImageId, location.ID)
|
_, err = locationModel.SaveToImage(ctx, info.ImageId, location.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.Locations{}, err
|
return model.Locations{}, err
|
||||||
}
|
}
|
||||||
@ -153,9 +156,13 @@ func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error)
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
agent.locationModel.SaveToImage(ctx, info.ImageId, contactUuid)
|
locationModel.SaveToImage(ctx, info.ImageId, contactUuid)
|
||||||
return "Saved", nil
|
return "Saved", nil
|
||||||
})
|
})
|
||||||
|
|
||||||
return agent, nil
|
agentClient.ToolHandler.AddTool("reply", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
|
return "ok", nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return agentClient
|
||||||
}
|
}
|
||||||
|
@ -2,11 +2,9 @@ package agents
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"os"
|
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
"screenmark/screenmark/agents/client"
|
"screenmark/screenmark/agents/client"
|
||||||
"screenmark/screenmark/models"
|
"screenmark/screenmark/models"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/charmbracelet/log"
|
"github.com/charmbracelet/log"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@ -43,7 +41,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
|
|||||||
}
|
}
|
||||||
|
|
||||||
request.Chat.AddSystem(noteAgentPrompt)
|
request.Chat.AddSystem(noteAgentPrompt)
|
||||||
request.Chat.AddImage(imageName, imageData)
|
request.Chat.AddImage(imageName, imageData, nil)
|
||||||
|
|
||||||
resp, err := agent.client.Request(&request)
|
resp, err := agent.client.Request(&request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -70,20 +68,16 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) {
|
func NewNoteAgent(log *log.Logger, noteModel models.NoteModel) NoteAgent {
|
||||||
client, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
client := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||||
ReportTimestamp: true,
|
SystemPrompt: noteAgentPrompt,
|
||||||
TimeFormat: time.Kitchen,
|
Log: log,
|
||||||
Prefix: "Notes 📝",
|
})
|
||||||
}))
|
|
||||||
if err != nil {
|
|
||||||
return NoteAgent{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
agent := NoteAgent{
|
agent := NoteAgent{
|
||||||
client: client,
|
client: client,
|
||||||
noteModel: noteModel,
|
noteModel: noteModel,
|
||||||
}
|
}
|
||||||
|
|
||||||
return agent, nil
|
return agent
|
||||||
}
|
}
|
||||||
|
@ -2,14 +2,12 @@ package agents
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"os"
|
|
||||||
"screenmark/screenmark/agents/client"
|
"screenmark/screenmark/agents/client"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/charmbracelet/log"
|
"github.com/charmbracelet/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const OrchestratorPrompt = `
|
const orchestratorPrompt = `
|
||||||
You are an Orchestrator for various AI agents.
|
You are an Orchestrator for various AI agents.
|
||||||
|
|
||||||
The user will send you images and you have to determine which agents you have to call, in order to best help the user.
|
The user will send you images and you have to determine which agents you have to call, in order to best help the user.
|
||||||
@ -40,7 +38,7 @@ Always call agents in parallel if you need to call more than 1.
|
|||||||
Do not call the agent if you do not think it is relevant for the image.
|
Do not call the agent if you do not think it is relevant for the image.
|
||||||
`
|
`
|
||||||
|
|
||||||
const OrchestratorTools = `
|
const orchestratorTools = `
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
@ -114,16 +112,13 @@ type Status struct {
|
|||||||
Ok bool `json:"ok"`
|
Ok bool `json:"ok"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locationAgent LocationAgent, eventAgent EventAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {
|
func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent client.AgentClient, locationAgent client.AgentClient, eventAgent client.AgentClient, imageName string, imageData []byte) client.AgentClient {
|
||||||
agent, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
agent := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||||
ReportTimestamp: true,
|
SystemPrompt: orchestratorPrompt,
|
||||||
TimeFormat: time.Kitchen,
|
JsonTools: orchestratorTools,
|
||||||
Prefix: "Orchestrator 🎼",
|
Log: log,
|
||||||
}))
|
EndToolCall: "noAction",
|
||||||
|
})
|
||||||
if err != nil {
|
|
||||||
return OrchestratorAgent{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
go noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData)
|
go noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData)
|
||||||
@ -134,7 +129,7 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
|
|||||||
})
|
})
|
||||||
|
|
||||||
agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
go contactAgent.client.RunAgent(contactPrompt, contactTools, "finish", info.UserId, info.ImageId, imageName, imageData)
|
go contactAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
|
||||||
|
|
||||||
return Status{
|
return Status{
|
||||||
Ok: true,
|
Ok: true,
|
||||||
@ -142,7 +137,7 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
|
|||||||
})
|
})
|
||||||
|
|
||||||
agent.ToolHandler.AddTool("locationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agent.ToolHandler.AddTool("locationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
go locationAgent.client.RunAgent(locationPrompt, locationTools, "finish", info.UserId, info.ImageId, imageName, imageData)
|
go locationAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
|
||||||
|
|
||||||
return Status{
|
return Status{
|
||||||
Ok: true,
|
Ok: true,
|
||||||
@ -150,7 +145,7 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
|
|||||||
})
|
})
|
||||||
|
|
||||||
agent.ToolHandler.AddTool("eventAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agent.ToolHandler.AddTool("eventAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
go eventAgent.client.RunAgent(eventPrompt, eventTools, "finish", info.UserId, info.ImageId, imageName, imageData)
|
go eventAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
|
||||||
|
|
||||||
return Status{
|
return Status{
|
||||||
Ok: true,
|
Ok: true,
|
||||||
@ -165,7 +160,5 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
|
|||||||
}, errors.New("Finished! Kinda bad return type but...")
|
}, errors.New("Finished! Kinda bad return type but...")
|
||||||
})
|
})
|
||||||
|
|
||||||
return OrchestratorAgent{
|
return agent
|
||||||
Client: agent,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
@ -3,17 +3,28 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"os"
|
"os"
|
||||||
"screenmark/screenmark/agents"
|
"screenmark/screenmark/agents"
|
||||||
"screenmark/screenmark/models"
|
"screenmark/screenmark/models"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/log"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func createLogger(prefix string) *log.Logger {
|
||||||
|
logger := log.NewWithOptions(os.Stdout, log.Options{
|
||||||
|
ReportTimestamp: true,
|
||||||
|
TimeFormat: time.Kitchen,
|
||||||
|
Prefix: prefix,
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.SetLevel(log.DebugLevel)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
}
|
||||||
|
|
||||||
func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
|
func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
|
||||||
listener := pq.NewListener(os.Getenv("DB_CONNECTION"), time.Second, time.Second, func(event pq.ListenerEventType, err error) {
|
listener := pq.NewListener(os.Getenv("DB_CONNECTION"), time.Second, time.Second, func(event pq.ListenerEventType, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -28,6 +39,9 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
|
|||||||
imageModel := models.NewImageModel(db)
|
imageModel := models.NewImageModel(db)
|
||||||
contactModel := models.NewContactModel(db)
|
contactModel := models.NewContactModel(db)
|
||||||
|
|
||||||
|
databaseEventLog := createLogger("Database Events 🤖")
|
||||||
|
databaseEventLog.SetLevel(log.DebugLevel)
|
||||||
|
|
||||||
err := listener.Listen("new_image")
|
err := listener.Listen("new_image")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
@ -39,55 +53,41 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
|
|||||||
imageId := uuid.MustParse(parameters.Extra)
|
imageId := uuid.MustParse(parameters.Extra)
|
||||||
eventManager.listeners[parameters.Extra] = make(chan string)
|
eventManager.listeners[parameters.Extra] = make(chan string)
|
||||||
|
|
||||||
|
databaseEventLog.Debug("Starting processing image", "ImageID", imageId)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
noteAgent, err := agents.NewNoteAgent(noteModel)
|
noteAgent := agents.NewNoteAgent(createLogger("Notes 📝"), noteModel)
|
||||||
if err != nil {
|
contactAgent := agents.NewContactAgent(createLogger("Contacts 👥"), contactModel)
|
||||||
panic(err)
|
locationAgent := agents.NewLocationAgent(createLogger("Locations 📍"), locationModel)
|
||||||
}
|
eventAgent := agents.NewEventAgent(createLogger("Events 📅"), eventModel, locationModel)
|
||||||
|
|
||||||
contactAgent, err := agents.NewContactAgent(contactModel)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
locationAgent, err := agents.NewLocationAgent(locationModel)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
eventAgent, err := agents.NewEventAgent(eventModel)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
image, err := imageModel.GetToProcessWithData(ctx, imageId)
|
image, err := imageModel.GetToProcessWithData(ctx, imageId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Failed to GetToProcessWithData")
|
databaseEventLog.Error("Failed to GetToProcessWithData", "error", err)
|
||||||
log.Println(err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := imageModel.StartProcessing(ctx, image.ID); err != nil {
|
if err := imageModel.StartProcessing(ctx, image.ID); err != nil {
|
||||||
log.Println("Failed to FinishProcessing")
|
databaseEventLog.Error("Failed to FinishProcessing", "error", err)
|
||||||
log.Println(err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
orchestrator, err := agents.NewOrchestratorAgent(noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image)
|
orchestrator := agents.NewOrchestratorAgent(createLogger("Orchestrator 🎼"), noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image)
|
||||||
|
err = orchestrator.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
databaseEventLog.Error("Orchestrator failed", "error", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Still need to find some way to hide this complexity away.
|
_, err = imageModel.FinishProcessing(ctx, image.ID)
|
||||||
// I don't think wrapping agents in structs actually works too well.
|
|
||||||
err = orchestrator.Client.RunAgent(agents.OrchestratorPrompt, agents.OrchestratorTools, "noAction", image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
databaseEventLog.Error("Failed to finish processing", "ImageID", imageId)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
imageModel.FinishProcessing(ctx, image.ID)
|
databaseEventLog.Debug("Starting processing image", "ImageID", imageId)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -122,9 +122,6 @@ func ListenProcessingImageStatus(db *sql.DB, eventManager *EventManager) {
|
|||||||
stringUuid := data.Extra[0:36]
|
stringUuid := data.Extra[0:36]
|
||||||
status := data.Extra[36:]
|
status := data.Extra[36:]
|
||||||
|
|
||||||
fmt.Printf("UUID: %s\n", stringUuid)
|
|
||||||
fmt.Printf("Receiving :s\n", data.Extra)
|
|
||||||
|
|
||||||
imageListener, exists := eventManager.listeners[stringUuid]
|
imageListener, exists := eventManager.listeners[stringUuid]
|
||||||
if !exists {
|
if !exists {
|
||||||
continue
|
continue
|
||||||
|
Reference in New Issue
Block a user