feat/inter-agent-communication #10
@ -65,8 +65,8 @@ func (m ChatUserMessage) MarshalJSON() ([]byte, error) {
|
||||
})
|
||||
case ArrayMessage:
|
||||
return json.Marshal(&struct {
|
||||
Role UserRole `json:"role"`
|
||||
Content []ImageMessageContent `json:"content"`
|
||||
Role UserRole `json:"role"`
|
||||
Content []MessageContentMessage `json:"content"`
|
||||
}{
|
||||
Role: User,
|
||||
Content: t.Content,
|
||||
@ -121,18 +121,35 @@ func (m SingleMessage) IsSingleMessage() bool {
|
||||
}
|
||||
|
||||
type ArrayMessage struct {
|
||||
Content []ImageMessageContent `json:"content"`
|
||||
Content []MessageContentMessage `json:"content"`
|
||||
}
|
||||
|
||||
func (m ArrayMessage) IsSingleMessage() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type MessageContentMessage interface {
|
||||
IsImageMessage() bool
|
||||
}
|
||||
|
||||
type TextMessageContent struct {
|
||||
TextType string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
func (m TextMessageContent) IsImageMessage() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type ImageMessageContent struct {
|
||||
ImageType string `json:"type"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
}
|
||||
|
||||
func (m ImageMessageContent) IsImageMessage() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type ImageContentUrl struct {
|
||||
Url string `json:"url"`
|
||||
}
|
||||
@ -165,7 +182,7 @@ func (chat *Chat) AddSystem(prompt string) {
|
||||
})
|
||||
}
|
||||
|
||||
func (chat *Chat) AddImage(imageName string, image []byte) error {
|
||||
func (chat *Chat) AddImage(imageName string, image []byte, query *string) error {
|
||||
extension := filepath.Ext(imageName)
|
||||
if len(extension) == 0 {
|
||||
// TODO: could also validate for image types we support.
|
||||
@ -173,14 +190,28 @@ func (chat *Chat) AddImage(imageName string, image []byte) error {
|
||||
}
|
||||
|
||||
extension = extension[1:]
|
||||
|
||||
encodedString := base64.StdEncoding.EncodeToString(image)
|
||||
|
||||
messageContent := ArrayMessage{
|
||||
Content: make([]ImageMessageContent, 1),
|
||||
contentLength := 1
|
||||
if query != nil {
|
||||
contentLength = 2
|
||||
}
|
||||
|
||||
messageContent.Content[0] = ImageMessageContent{
|
||||
messageContent := ArrayMessage{
|
||||
Content: make([]MessageContentMessage, contentLength),
|
||||
}
|
||||
|
||||
index := 0
|
||||
|
||||
if query != nil {
|
||||
messageContent.Content[index] = TextMessageContent{
|
||||
TextType: "text",
|
||||
Text: *query,
|
||||
}
|
||||
index += 1
|
||||
}
|
||||
|
||||
messageContent.Content[index] = ImageMessageContent{
|
||||
ImageType: "image_url",
|
||||
ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
|
||||
}
|
||||
|
@ -73,16 +73,28 @@ type AgentClient struct {
|
||||
|
||||
Log *log.Logger
|
||||
|
||||
Reply string
|
||||
|
||||
Do func(req *http.Request) (*http.Response, error)
|
||||
|
||||
Options CreateAgentClientOptions
|
||||
}
|
||||
|
||||
const OPENAI_API_KEY = "OPENAI_API_KEY"
|
||||
|
||||
func CreateAgentClient(log *log.Logger) (AgentClient, error) {
|
||||
type CreateAgentClientOptions struct {
|
||||
Log *log.Logger
|
||||
SystemPrompt string
|
||||
JsonTools string
|
||||
EndToolCall string
|
||||
Query *string
|
||||
}
|
||||
|
||||
func CreateAgentClient(options CreateAgentClientOptions) AgentClient {
|
||||
apiKey := os.Getenv(OPENAI_API_KEY)
|
||||
|
||||
if len(apiKey) == 0 {
|
||||
return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
|
||||
panic("No api key")
|
||||
}
|
||||
|
||||
return AgentClient{
|
||||
@ -93,12 +105,14 @@ func CreateAgentClient(log *log.Logger) (AgentClient, error) {
|
||||
return client.Do(req)
|
||||
},
|
||||
|
||||
Log: log,
|
||||
Log: options.Log,
|
||||
|
||||
ToolHandler: ToolsHandlers{
|
||||
handlers: map[string]ToolHandler{},
|
||||
},
|
||||
}, nil
|
||||
|
||||
Options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
|
||||
@ -146,8 +160,6 @@ func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error)
|
||||
return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
|
||||
}
|
||||
|
||||
client.Log.SetLevel(log.DebugLevel)
|
||||
|
||||
msg := agentResponse.Choices[0].Message
|
||||
|
||||
if len(msg.Content) > 0 {
|
||||
@ -170,7 +182,7 @@ func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error)
|
||||
return agentResponse, nil
|
||||
}
|
||||
|
||||
func (client AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error {
|
||||
func (client *AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error {
|
||||
for {
|
||||
err := client.Process(info, req)
|
||||
if err != nil {
|
||||
@ -186,7 +198,7 @@ func (client AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody)
|
||||
|
||||
var FinishedCall = errors.New("Last tool tool was called")
|
||||
|
||||
func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
|
||||
func (client *AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
|
||||
var err error
|
||||
|
||||
message, err := req.Chat.GetLatest()
|
||||
@ -211,7 +223,10 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e
|
||||
|
||||
toolResponse := client.ToolHandler.Handle(info, toolCall)
|
||||
|
||||
client.Log.SetLevel(log.DebugLevel)
|
||||
if toolCall.Function.Name == "reply" {
|
||||
client.Reply = toolCall.Function.Arguments
|
||||
}
|
||||
|
||||
client.Log.Debugf("Response: %s", toolResponse.Content)
|
||||
|
||||
req.Chat.AddToolResponse(toolResponse)
|
||||
@ -220,9 +235,9 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e
|
||||
return err
|
||||
}
|
||||
|
||||
func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToolCall string, userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||
func (client *AgentClient) RunAgent(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||
var tools any
|
||||
err := json.Unmarshal([]byte(jsonTools), &tools)
|
||||
err := json.Unmarshal([]byte(client.Options.JsonTools), &tools)
|
||||
|
||||
toolChoice := "any"
|
||||
|
||||
@ -231,7 +246,7 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo
|
||||
ToolChoice: &toolChoice,
|
||||
Model: "pixtral-12b-2409",
|
||||
Temperature: 0.3,
|
||||
EndToolCall: endToolCall,
|
||||
EndToolCall: client.Options.EndToolCall,
|
||||
ResponseFormat: ResponseFormat{
|
||||
Type: "text",
|
||||
},
|
||||
@ -240,8 +255,8 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo
|
||||
},
|
||||
}
|
||||
|
||||
request.Chat.AddSystem(systemPrompt)
|
||||
request.Chat.AddImage(imageName, imageData)
|
||||
request.Chat.AddSystem(client.Options.SystemPrompt)
|
||||
request.Chat.AddImage(imageName, imageData, client.Options.Query)
|
||||
|
||||
_, err = client.Request(&request)
|
||||
if err != nil {
|
||||
@ -249,8 +264,10 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo
|
||||
}
|
||||
|
||||
toolHandlerInfo := ToolHandlerInfo{
|
||||
ImageId: imageId,
|
||||
UserId: userId,
|
||||
ImageId: imageId,
|
||||
ImageName: imageName,
|
||||
UserId: userId,
|
||||
Image: &imageData,
|
||||
}
|
||||
|
||||
return client.ToolLoop(toolHandlerInfo, &request)
|
||||
|
@ -8,8 +8,12 @@ import (
|
||||
)
|
||||
|
||||
type ToolHandlerInfo struct {
|
||||
UserId uuid.UUID
|
||||
ImageId uuid.UUID
|
||||
UserId uuid.UUID
|
||||
ImageId uuid.UUID
|
||||
ImageName string
|
||||
|
||||
// Pointer because we don't want to copy this around too much.
|
||||
Image *[]byte
|
||||
}
|
||||
|
||||
type ToolHandler struct {
|
||||
|
@ -3,11 +3,9 @@ package agents
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||
"screenmark/screenmark/agents/client"
|
||||
"screenmark/screenmark/models"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/google/uuid"
|
||||
@ -81,12 +79,6 @@ const contactTools = `
|
||||
]
|
||||
`
|
||||
|
||||
type ContactAgent struct {
|
||||
client client.AgentClient
|
||||
|
||||
contactModel models.ContactModel
|
||||
}
|
||||
|
||||
type listContactsArguments struct{}
|
||||
type createContactsArguments struct {
|
||||
Name string `json:"name"`
|
||||
@ -98,23 +90,16 @@ type linkContactArguments struct {
|
||||
ContactID string `json:"contactId"`
|
||||
}
|
||||
|
||||
func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
|
||||
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportTimestamp: true,
|
||||
TimeFormat: time.Kitchen,
|
||||
Prefix: "Contacts 👥",
|
||||
}))
|
||||
if err != nil {
|
||||
return ContactAgent{}, err
|
||||
}
|
||||
|
||||
agent := ContactAgent{
|
||||
client: agentClient,
|
||||
contactModel: contactModel,
|
||||
}
|
||||
func NewContactAgent(log *log.Logger, contactModel models.ContactModel) client.AgentClient {
|
||||
agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||
SystemPrompt: contactPrompt,
|
||||
JsonTools: contactTools,
|
||||
Log: log,
|
||||
EndToolCall: "finish",
|
||||
})
|
||||
|
||||
agentClient.ToolHandler.AddTool("listContacts", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
return agent.contactModel.List(context.Background(), info.UserId)
|
||||
return contactModel.List(context.Background(), info.UserId)
|
||||
})
|
||||
|
||||
agentClient.ToolHandler.AddTool("createContact", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
|
||||
@ -126,7 +111,7 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
contact, err := agent.contactModel.Save(ctx, info.UserId, model.Contacts{
|
||||
contact, err := contactModel.Save(ctx, info.UserId, model.Contacts{
|
||||
Name: args.Name,
|
||||
PhoneNumber: args.PhoneNumber,
|
||||
Email: args.Email,
|
||||
@ -136,7 +121,7 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
|
||||
return model.Contacts{}, err
|
||||
}
|
||||
|
||||
_, err = agent.contactModel.SaveToImage(ctx, info.ImageId, contact.ID)
|
||||
_, err = contactModel.SaveToImage(ctx, info.ImageId, contact.ID)
|
||||
if err != nil {
|
||||
return model.Contacts{}, err
|
||||
}
|
||||
@ -158,7 +143,7 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
_, err = agent.contactModel.SaveToImage(ctx, info.ImageId, contactUuid)
|
||||
_, err = contactModel.SaveToImage(ctx, info.ImageId, contactUuid)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -166,5 +151,5 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
|
||||
return "Saved", nil
|
||||
})
|
||||
|
||||
return agent, nil
|
||||
return agentClient
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ package agents
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||
"screenmark/screenmark/agents/client"
|
||||
"screenmark/screenmark/models"
|
||||
@ -27,6 +26,9 @@ Lists the users already existing events.
|
||||
createEvent
|
||||
Use this to create a new events.
|
||||
|
||||
getEventLocationId
|
||||
Use this if the image contains a location or place. This tool will return the locationId.
|
||||
|
||||
finish
|
||||
Call when there is nothing else to do.
|
||||
`
|
||||
@ -63,11 +65,28 @@ const eventTools = `
|
||||
"endDateTime": {
|
||||
"type": "string",
|
||||
"description": "The end time as an ISO string"
|
||||
}
|
||||
},
|
||||
"locationId": {
|
||||
"type": "string",
|
||||
"description": "The UUID of this location. You should use getEventLocationId to get this information, but only if you believe the event contains a location"
|
||||
|
||||
}
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "getEventLocationId",
|
||||
"description": "Get the ID of the location on the image, only use if the event contains a location or place.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -83,12 +102,6 @@ const eventTools = `
|
||||
}
|
||||
]`
|
||||
|
||||
type EventAgent struct {
|
||||
client client.AgentClient
|
||||
|
||||
eventsModel models.EventModel
|
||||
}
|
||||
|
||||
type listEventArguments struct{}
|
||||
type createEventArguments struct {
|
||||
Name string `json:"name"`
|
||||
@ -100,24 +113,20 @@ type linkEventArguments struct {
|
||||
EventID string `json:"eventId"`
|
||||
}
|
||||
|
||||
func NewEventAgent(eventsModel models.EventModel) (EventAgent, error) {
|
||||
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportTimestamp: true,
|
||||
TimeFormat: time.Kitchen,
|
||||
Prefix: "Events 📍",
|
||||
}))
|
||||
func NewEventAgent(log *log.Logger, eventsModel models.EventModel, locationModel models.LocationModel) client.AgentClient {
|
||||
agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||
SystemPrompt: eventPrompt,
|
||||
JsonTools: eventTools,
|
||||
Log: log,
|
||||
EndToolCall: "finish",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return EventAgent{}, err
|
||||
}
|
||||
|
||||
agent := EventAgent{
|
||||
client: agentClient,
|
||||
eventsModel: eventsModel,
|
||||
}
|
||||
locationAgent := NewLocationAgent(log.WithPrefix("Events 📅 > Locations 📍"), locationModel)
|
||||
locationQuery := "Can you get me the ID of the location present in this image?"
|
||||
locationAgent.Options.Query = &locationQuery
|
||||
|
||||
agentClient.ToolHandler.AddTool("listEvents", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
return agent.eventsModel.List(context.Background(), info.UserId)
|
||||
return eventsModel.List(context.Background(), info.UserId)
|
||||
})
|
||||
|
||||
agentClient.ToolHandler.AddTool("createEvent", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
|
||||
@ -141,7 +150,7 @@ func NewEventAgent(eventsModel models.EventModel) (EventAgent, error) {
|
||||
return model.Events{}, err
|
||||
}
|
||||
|
||||
events, err := agent.eventsModel.Save(ctx, info.UserId, model.Events{
|
||||
events, err := eventsModel.Save(ctx, info.UserId, model.Events{
|
||||
Name: args.Name,
|
||||
StartDateTime: &startTime,
|
||||
EndDateTime: &endTime,
|
||||
@ -151,7 +160,7 @@ func NewEventAgent(eventsModel models.EventModel) (EventAgent, error) {
|
||||
return model.Events{}, err
|
||||
}
|
||||
|
||||
_, err = agent.eventsModel.SaveToImage(ctx, info.ImageId, events.ID)
|
||||
_, err = eventsModel.SaveToImage(ctx, info.ImageId, events.ID)
|
||||
if err != nil {
|
||||
return model.Events{}, err
|
||||
}
|
||||
@ -173,9 +182,17 @@ func NewEventAgent(eventsModel models.EventModel) (EventAgent, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
agent.eventsModel.SaveToImage(ctx, info.ImageId, contactUuid)
|
||||
eventsModel.SaveToImage(ctx, info.ImageId, contactUuid)
|
||||
return "Saved", nil
|
||||
})
|
||||
|
||||
return agent, nil
|
||||
agentClient.ToolHandler.AddTool("getEventLocationId", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
// TODO: reenable this when I'm creating the agent locally instead of getting it from above.
|
||||
locationAgent.RunAgent(info.UserId, info.ImageId, info.ImageName, *info.Image)
|
||||
|
||||
log.Debugf("Reply from location %s\n", locationAgent.Reply)
|
||||
return locationAgent.Reply, nil
|
||||
})
|
||||
|
||||
return agentClient
|
||||
}
|
||||
|
@ -3,11 +3,9 @@ package agents
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||
"screenmark/screenmark/agents/client"
|
||||
"screenmark/screenmark/models"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/google/uuid"
|
||||
@ -27,6 +25,9 @@ Lists the users already existing locations.
|
||||
createLocation
|
||||
Use this to create a new location, when you don't see a matching one from listLocations call.
|
||||
|
||||
reply
|
||||
Use this only if the user has asked a question about a location.
|
||||
|
||||
finish
|
||||
Call when there is nothing else to do.
|
||||
`
|
||||
@ -63,6 +64,22 @@ const locationTools = `
|
||||
"required": ["name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "reply",
|
||||
"description": "Reply to a user query, only if the user has asked something",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"locationId": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["locationId"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -78,12 +95,6 @@ const locationTools = `
|
||||
}
|
||||
]`
|
||||
|
||||
type LocationAgent struct {
|
||||
client client.AgentClient
|
||||
|
||||
locationModel models.LocationModel
|
||||
}
|
||||
|
||||
type listLocationArguments struct{}
|
||||
type createLocationArguments struct {
|
||||
Name string `json:"name"`
|
||||
@ -93,24 +104,16 @@ type linkLocationArguments struct {
|
||||
LocationID string `json:"locationId"`
|
||||
}
|
||||
|
||||
func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error) {
|
||||
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportTimestamp: true,
|
||||
TimeFormat: time.Kitchen,
|
||||
Prefix: "Locations 📍",
|
||||
}))
|
||||
|
||||
if err != nil {
|
||||
return LocationAgent{}, err
|
||||
}
|
||||
|
||||
agent := LocationAgent{
|
||||
client: agentClient,
|
||||
locationModel: locationModel,
|
||||
}
|
||||
func NewLocationAgent(log *log.Logger, locationModel models.LocationModel) client.AgentClient {
|
||||
agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||
SystemPrompt: locationPrompt,
|
||||
JsonTools: locationTools,
|
||||
Log: log,
|
||||
EndToolCall: "finish",
|
||||
})
|
||||
|
||||
agentClient.ToolHandler.AddTool("listLocations", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
return agent.locationModel.List(context.Background(), info.UserId)
|
||||
return locationModel.List(context.Background(), info.UserId)
|
||||
})
|
||||
|
||||
agentClient.ToolHandler.AddTool("createLocation", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
|
||||
@ -122,7 +125,7 @@ func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
location, err := agent.locationModel.Save(ctx, info.UserId, model.Locations{
|
||||
location, err := locationModel.Save(ctx, info.UserId, model.Locations{
|
||||
Name: args.Name,
|
||||
Address: args.Address,
|
||||
})
|
||||
@ -131,7 +134,7 @@ func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error)
|
||||
return model.Locations{}, err
|
||||
}
|
||||
|
||||
_, err = agent.locationModel.SaveToImage(ctx, info.ImageId, location.ID)
|
||||
_, err = locationModel.SaveToImage(ctx, info.ImageId, location.ID)
|
||||
if err != nil {
|
||||
return model.Locations{}, err
|
||||
}
|
||||
@ -153,9 +156,13 @@ func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error)
|
||||
return "", err
|
||||
}
|
||||
|
||||
agent.locationModel.SaveToImage(ctx, info.ImageId, contactUuid)
|
||||
locationModel.SaveToImage(ctx, info.ImageId, contactUuid)
|
||||
return "Saved", nil
|
||||
})
|
||||
|
||||
return agent, nil
|
||||
agentClient.ToolHandler.AddTool("reply", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
return "ok", nil
|
||||
})
|
||||
|
||||
return agentClient
|
||||
}
|
||||
|
@ -2,11 +2,9 @@ package agents
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||
"screenmark/screenmark/agents/client"
|
||||
"screenmark/screenmark/models"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/google/uuid"
|
||||
@ -43,7 +41,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
|
||||
}
|
||||
|
||||
request.Chat.AddSystem(noteAgentPrompt)
|
||||
request.Chat.AddImage(imageName, imageData)
|
||||
request.Chat.AddImage(imageName, imageData, nil)
|
||||
|
||||
resp, err := agent.client.Request(&request)
|
||||
if err != nil {
|
||||
@ -70,20 +68,16 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) {
|
||||
client, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportTimestamp: true,
|
||||
TimeFormat: time.Kitchen,
|
||||
Prefix: "Notes 📝",
|
||||
}))
|
||||
if err != nil {
|
||||
return NoteAgent{}, err
|
||||
}
|
||||
func NewNoteAgent(log *log.Logger, noteModel models.NoteModel) NoteAgent {
|
||||
client := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||
SystemPrompt: noteAgentPrompt,
|
||||
Log: log,
|
||||
})
|
||||
|
||||
agent := NoteAgent{
|
||||
client: client,
|
||||
noteModel: noteModel,
|
||||
}
|
||||
|
||||
return agent, nil
|
||||
return agent
|
||||
}
|
||||
|
@ -2,14 +2,12 @@ package agents
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"screenmark/screenmark/agents/client"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
)
|
||||
|
||||
const OrchestratorPrompt = `
|
||||
const orchestratorPrompt = `
|
||||
You are an Orchestrator for various AI agents.
|
||||
|
||||
The user will send you images and you have to determine which agents you have to call, in order to best help the user.
|
||||
@ -40,7 +38,7 @@ Always call agents in parallel if you need to call more than 1.
|
||||
Do not call the agent if you do not think it is relevant for the image.
|
||||
`
|
||||
|
||||
const OrchestratorTools = `
|
||||
const orchestratorTools = `
|
||||
[
|
||||
{
|
||||
"type": "function",
|
||||
@ -114,16 +112,13 @@ type Status struct {
|
||||
Ok bool `json:"ok"`
|
||||
}
|
||||
|
||||
func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locationAgent LocationAgent, eventAgent EventAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {
|
||||
agent, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportTimestamp: true,
|
||||
TimeFormat: time.Kitchen,
|
||||
Prefix: "Orchestrator 🎼",
|
||||
}))
|
||||
|
||||
if err != nil {
|
||||
return OrchestratorAgent{}, err
|
||||
}
|
||||
func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent client.AgentClient, locationAgent client.AgentClient, eventAgent client.AgentClient, imageName string, imageData []byte) client.AgentClient {
|
||||
agent := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||
SystemPrompt: orchestratorPrompt,
|
||||
JsonTools: orchestratorTools,
|
||||
Log: log,
|
||||
EndToolCall: "noAction",
|
||||
})
|
||||
|
||||
agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
go noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData)
|
||||
@ -134,7 +129,7 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
|
||||
})
|
||||
|
||||
agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
go contactAgent.client.RunAgent(contactPrompt, contactTools, "finish", info.UserId, info.ImageId, imageName, imageData)
|
||||
go contactAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
|
||||
|
||||
return Status{
|
||||
Ok: true,
|
||||
@ -142,7 +137,7 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
|
||||
})
|
||||
|
||||
agent.ToolHandler.AddTool("locationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
go locationAgent.client.RunAgent(locationPrompt, locationTools, "finish", info.UserId, info.ImageId, imageName, imageData)
|
||||
go locationAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
|
||||
|
||||
return Status{
|
||||
Ok: true,
|
||||
@ -150,7 +145,7 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
|
||||
})
|
||||
|
||||
agent.ToolHandler.AddTool("eventAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
go eventAgent.client.RunAgent(eventPrompt, eventTools, "finish", info.UserId, info.ImageId, imageName, imageData)
|
||||
go eventAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
|
||||
|
||||
return Status{
|
||||
Ok: true,
|
||||
@ -165,7 +160,5 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
|
||||
}, errors.New("Finished! Kinda bad return type but...")
|
||||
})
|
||||
|
||||
return OrchestratorAgent{
|
||||
Client: agent,
|
||||
}, nil
|
||||
return agent
|
||||
}
|
||||
|
@ -3,17 +3,28 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"screenmark/screenmark/agents"
|
||||
"screenmark/screenmark/models"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
func createLogger(prefix string) *log.Logger {
|
||||
logger := log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportTimestamp: true,
|
||||
TimeFormat: time.Kitchen,
|
||||
Prefix: prefix,
|
||||
})
|
||||
|
||||
logger.SetLevel(log.DebugLevel)
|
||||
|
||||
return logger
|
||||
}
|
||||
|
||||
func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
|
||||
listener := pq.NewListener(os.Getenv("DB_CONNECTION"), time.Second, time.Second, func(event pq.ListenerEventType, err error) {
|
||||
if err != nil {
|
||||
@ -28,6 +39,9 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
|
||||
imageModel := models.NewImageModel(db)
|
||||
contactModel := models.NewContactModel(db)
|
||||
|
||||
databaseEventLog := createLogger("Database Events 🤖")
|
||||
databaseEventLog.SetLevel(log.DebugLevel)
|
||||
|
||||
err := listener.Listen("new_image")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@ -39,55 +53,41 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
|
||||
imageId := uuid.MustParse(parameters.Extra)
|
||||
eventManager.listeners[parameters.Extra] = make(chan string)
|
||||
|
||||
databaseEventLog.Debug("Starting processing image", "ImageID", imageId)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
go func() {
|
||||
noteAgent, err := agents.NewNoteAgent(noteModel)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
contactAgent, err := agents.NewContactAgent(contactModel)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
locationAgent, err := agents.NewLocationAgent(locationModel)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
eventAgent, err := agents.NewEventAgent(eventModel)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
noteAgent := agents.NewNoteAgent(createLogger("Notes 📝"), noteModel)
|
||||
contactAgent := agents.NewContactAgent(createLogger("Contacts 👥"), contactModel)
|
||||
locationAgent := agents.NewLocationAgent(createLogger("Locations 📍"), locationModel)
|
||||
eventAgent := agents.NewEventAgent(createLogger("Events 📅"), eventModel, locationModel)
|
||||
|
||||
image, err := imageModel.GetToProcessWithData(ctx, imageId)
|
||||
if err != nil {
|
||||
log.Println("Failed to GetToProcessWithData")
|
||||
log.Println(err)
|
||||
databaseEventLog.Error("Failed to GetToProcessWithData", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := imageModel.StartProcessing(ctx, image.ID); err != nil {
|
||||
log.Println("Failed to FinishProcessing")
|
||||
log.Println(err)
|
||||
databaseEventLog.Error("Failed to FinishProcessing", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
orchestrator, err := agents.NewOrchestratorAgent(noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image)
|
||||
orchestrator := agents.NewOrchestratorAgent(createLogger("Orchestrator 🎼"), noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image)
|
||||
err = orchestrator.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
databaseEventLog.Error("Orchestrator failed", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Still need to find some way to hide this complexity away.
|
||||
// I don't think wrapping agents in structs actually works too well.
|
||||
err = orchestrator.Client.RunAgent(agents.OrchestratorPrompt, agents.OrchestratorTools, "noAction", image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||
_, err = imageModel.FinishProcessing(ctx, image.ID)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
databaseEventLog.Error("Failed to finish processing", "ImageID", imageId)
|
||||
return
|
||||
}
|
||||
|
||||
imageModel.FinishProcessing(ctx, image.ID)
|
||||
databaseEventLog.Debug("Starting processing image", "ImageID", imageId)
|
||||
}()
|
||||
}
|
||||
}
|
||||
@ -122,9 +122,6 @@ func ListenProcessingImageStatus(db *sql.DB, eventManager *EventManager) {
|
||||
stringUuid := data.Extra[0:36]
|
||||
status := data.Extra[36:]
|
||||
|
||||
fmt.Printf("UUID: %s\n", stringUuid)
|
||||
fmt.Printf("Receiving :s\n", data.Extra)
|
||||
|
||||
imageListener, exists := eventManager.listeners[stringUuid]
|
||||
if !exists {
|
||||
continue
|
||||
|
Reference in New Issue
Block a user