feat(tool-calls): listLocation tool call handling

This commit is contained in:
2025-03-22 11:14:00 +00:00
parent 87869543f7
commit 7b6c7090f8
4 changed files with 130 additions and 13 deletions

View File

@ -48,8 +48,10 @@ type AgentMessage interface {
} }
type AgentTextMessage struct { type AgentTextMessage struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
ToolCallId string `json:"tool_call_id,omitempty"`
Name string `json:"name,omitempty"`
} }
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) { func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
@ -57,6 +59,16 @@ func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
return json.Marshal(textContent) return json.Marshal(textContent)
} }
type AgentAssistantToolCall struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls"`
}
func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) {
return json.Marshal(toolCall)
}
type AgentArrayMessage struct { type AgentArrayMessage struct {
Role string `json:"role"` Role string `json:"role"`
Content []AgentContent `json:"content"` Content []AgentContent `json:"content"`
@ -66,6 +78,14 @@ func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
return json.Marshal(arrayContent) return json.Marshal(arrayContent)
} }
func (content *AgentMessages) AddText(message AgentTextMessage) {
content.Messages = append(content.Messages, message)
}
func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) {
content.Messages = append(content.Messages, toolCall)
}
func (content *AgentMessages) AddImage(imageName string, image []byte) error { func (content *AgentMessages) AddImage(imageName string, image []byte) error {
extension := filepath.Ext(imageName) extension := filepath.Ext(imageName)
if len(extension) == 0 { if len(extension) == 0 {
@ -320,9 +340,21 @@ func getCompletionsForImage(model string, temperature float64, prompt string, im
return request, nil return request, nil
} }
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type ToolCall struct {
Index int `json:"index"`
Id string `json:"id"`
Function FunctionCall `json:"function"`
}
type ResponseChoiceMessage struct { type ResponseChoiceMessage struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls"`
} }
type ResponseChoice struct { type ResponseChoice struct {

View File

@ -1,9 +1,13 @@
package agents package agents
import ( import (
"context"
"encoding/json" "encoding/json"
"io" "io"
"log" "log"
"screenmark/screenmark/models"
"github.com/google/uuid"
) )
const eventLocationPrompt = ` const eventLocationPrompt = `
@ -13,6 +17,8 @@ Your job is to check if an image has an event or a location and use the correct
If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location. If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location.
It is possible that there is no location or event on an image. It is possible that there is no location or event on an image.
You should ask for a list of locations, as the user is likely to have this location saved. Reuse existing locations where possible.
` `
// TODO: this should be read directly from a file on load. // TODO: this should be read directly from a file on load.
@ -77,9 +83,14 @@ const TOOLS = `
] ]
` `
type EventLocationAgent = AgentClient type EventLocationAgent struct {
client AgentClient
func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte) error { eventModel models.EventModel
locationModel models.LocationModel
}
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string, imageData []byte) error {
var tools any var tools any
err := json.Unmarshal([]byte(TOOLS), &tools) err := json.Unmarshal([]byte(TOOLS), &tools)
@ -109,12 +120,12 @@ func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte)
return err return err
} }
httpRequest, err := agent.getRequest(jsonAiRequest) httpRequest, err := agent.client.getRequest(jsonAiRequest)
if err != nil { if err != nil {
return err return err
} }
resp, err := agent.Do(httpRequest) resp, err := agent.client.Do(httpRequest)
if err != nil { if err != nil {
return err return err
} }
@ -126,14 +137,70 @@ func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte)
log.Println(string(response)) log.Println(string(response))
agentResponse := AgentResponse{}
err = json.Unmarshal(response, &agentResponse)
toolCalls := agentResponse.Choices[0].Message.ToolCalls[0]
if toolCalls.Function.Name == "listLocations" {
locations, err := agent.locationModel.List(context.Background(), userId)
if err != nil {
return err
}
jsonLocations, err := json.Marshal(locations)
if err != nil {
return err
}
request.AddToolCall(AgentAssistantToolCall{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{toolCalls},
})
request.AddText(AgentTextMessage{
Role: "tool",
Name: "listLocations",
Content: string(jsonLocations),
ToolCallId: toolCalls.Id,
})
jsonAiRequest, err := json.Marshal(request)
if err != nil {
return err
}
httpRequest, err := agent.client.getRequest(jsonAiRequest)
if err != nil {
return err
}
resp, err := agent.client.Do(httpRequest)
if err != nil {
return err
}
response, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
log.Println(string(response))
}
return nil return nil
} }
func NewLocationEventAgent() (EventLocationAgent, error) { func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) {
agent, err := CreateAgentClient(eventLocationPrompt) client, err := CreateAgentClient(eventLocationPrompt)
if err != nil { if err != nil {
return EventLocationAgent{}, err return EventLocationAgent{}, err
} }
return agent, nil return EventLocationAgent{
client: client,
locationModel: locationModel,
eventModel: eventModel,
}, nil
} }

View File

@ -108,7 +108,7 @@ func main() {
panic(err) panic(err)
} }
locationAgent, err := agents.NewLocationEventAgent() locationAgent, err := agents.NewLocationEventAgent(locationModel, eventModel)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -121,7 +121,9 @@ func main() {
} }
log.Println("Calling locationAgent!") log.Println("Calling locationAgent!")
locationAgent.GetLocations(image.Image.ImageName, image.Image.Image) locationAgent.GetLocations(image.UserID, image.Image.ImageName, image.Image.Image)
return
imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image) imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image)
if err != nil { if err != nil {

View File

@ -3,6 +3,7 @@ package models
import ( import (
"context" "context"
"database/sql" "database/sql"
. "github.com/go-jet/jet/v2/postgres"
"log" "log"
"screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/.gen/haystack/haystack/model"
. "screenmark/screenmark/.gen/haystack/haystack/table" . "screenmark/screenmark/.gen/haystack/haystack/table"
@ -39,6 +40,21 @@ func getValues(location model.Locations) []any {
return arr return arr
} }
func (m LocationModel) List(ctx context.Context, userId uuid.UUID) ([]model.Locations, error) {
listLocationsStmt := SELECT(Locations.AllColumns, ImageLocations.AllColumns, UserImages.AllColumns).
FROM(
Locations.
INNER_JOIN(ImageLocations, ImageLocations.LocationID.EQ(Locations.ID)).
INNER_JOIN(UserImages, UserImages.ImageID.EQ(ImageLocations.ImageID)),
).WHERE(UserImages.UserID.EQ(UUID(userId)))
locations := []model.Locations{}
err := listLocationsStmt.QueryContext(ctx, m.dbPool, &locations)
return locations, err
}
func (m LocationModel) Save(ctx context.Context, locations []model.Locations) (model.Locations, error) { func (m LocationModel) Save(ctx context.Context, locations []model.Locations) (model.Locations, error) {
insertLocationStmt := Locations. insertLocationStmt := Locations.
INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description) INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description)