feat(tool-calls): handle the listLocations tool call

This commit is contained in:
2025-03-22 11:14:00 +00:00
parent 7c473e054a
commit aad45fcf52
4 changed files with 130 additions and 13 deletions

View File

@ -50,6 +50,8 @@ type AgentMessage interface {
// AgentTextMessage is a chat message whose content is a single string.
//
// ToolCallId and Name are left out of the JSON encoding when they are empty.
type AgentTextMessage struct {
	Role       string `json:"role"`
	Content    string `json:"content"`
	ToolCallId string `json:"tool_call_id,omitempty"`
	Name       string `json:"name,omitempty"`
}
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
@ -57,6 +59,16 @@ func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
return json.Marshal(textContent)
}
// AgentAssistantToolCall is a chat message that carries a list of tool calls
// in its "tool_calls" field, next to the usual role and content strings.
type AgentAssistantToolCall struct {
	Role      string     `json:"role"`
	Content   string     `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls"`
}
// MessageToJson returns the JSON encoding of the message.
func (m AgentAssistantToolCall) MessageToJson() ([]byte, error) {
	return json.Marshal(m)
}
type AgentArrayMessage struct {
Role string `json:"role"`
Content []AgentContent `json:"content"`
@ -66,6 +78,14 @@ func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
return json.Marshal(arrayContent)
}
// AddText appends message to the end of the Messages slice.
func (msgs *AgentMessages) AddText(message AgentTextMessage) {
	msgs.Messages = append(msgs.Messages, message)
}
// AddToolCall appends toolCall to the end of the Messages slice.
func (msgs *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) {
	msgs.Messages = append(msgs.Messages, toolCall)
}
func (content *AgentMessages) AddImage(imageName string, image []byte) error {
extension := filepath.Ext(imageName)
if len(extension) == 0 {
@ -320,9 +340,21 @@ func getCompletionsForImage(model string, temperature float64, prompt string, im
return request, nil
}
// FunctionCall is the function part of a tool call: the name of the function
// and its arguments, held as the raw string found in the "arguments" field.
type FunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}

// ToolCall is one entry of a "tool_calls" array.
//
// Type holds the entry's "type" value (e.g. "function"). Keeping it means a
// call decoded from a response can be re-encoded without losing that field
// when it is sent back inside an assistant message. It is omitted from the
// encoding when empty.
type ToolCall struct {
	Index    int          `json:"index"`
	Id       string       `json:"id"`
	Type     string       `json:"type,omitempty"`
	Function FunctionCall `json:"function"`
}
// ResponseChoiceMessage is the message object decoded from a response choice:
// its role, its text content, and any tool calls listed under "tool_calls".
type ResponseChoiceMessage struct {
	Role      string     `json:"role"`
	Content   string     `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls"`
}
type ResponseChoice struct {

View File

@ -1,9 +1,13 @@
package agents
import (
"context"
"encoding/json"
"io"
"log"
"screenmark/screenmark/models"
"github.com/google/uuid"
)
const eventLocationPrompt = `
@ -13,6 +17,8 @@ Your job is to check if an image has an event or a location and use the correct
If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location.
It is possible that there is no location or event on an image.
You should ask for a list of locations, as the user is likely to have this location saved. Reuse existing locations where possible.
`
// TODO: this should be read directly from a file on load.
@ -77,9 +83,14 @@ const TOOLS = `
]
`
type EventLocationAgent = AgentClient
// EventLocationAgent groups an AgentClient with the event and location models.
type EventLocationAgent struct {
client AgentClient
func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte) error {
eventModel models.EventModel
locationModel models.LocationModel
}
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string, imageData []byte) error {
var tools any
err := json.Unmarshal([]byte(TOOLS), &tools)
@ -109,12 +120,12 @@ func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte)
return err
}
httpRequest, err := agent.getRequest(jsonAiRequest)
httpRequest, err := agent.client.getRequest(jsonAiRequest)
if err != nil {
return err
}
resp, err := agent.Do(httpRequest)
resp, err := agent.client.Do(httpRequest)
if err != nil {
return err
}
@ -126,14 +137,70 @@ func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte)
log.Println(string(response))
agentResponse := AgentResponse{}
err = json.Unmarshal(response, &agentResponse)
toolCalls := agentResponse.Choices[0].Message.ToolCalls[0]
if toolCalls.Function.Name == "listLocations" {
locations, err := agent.locationModel.List(context.Background(), userId)
if err != nil {
return err
}
jsonLocations, err := json.Marshal(locations)
if err != nil {
return err
}
request.AddToolCall(AgentAssistantToolCall{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{toolCalls},
})
request.AddText(AgentTextMessage{
Role: "tool",
Name: "listLocations",
Content: string(jsonLocations),
ToolCallId: toolCalls.Id,
})
jsonAiRequest, err := json.Marshal(request)
if err != nil {
return err
}
httpRequest, err := agent.client.getRequest(jsonAiRequest)
if err != nil {
return err
}
resp, err := agent.client.Do(httpRequest)
if err != nil {
return err
}
response, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
log.Println(string(response))
}
return nil
}
func NewLocationEventAgent() (EventLocationAgent, error) {
agent, err := CreateAgentClient(eventLocationPrompt)
// NewLocationEventAgent creates an agent client from eventLocationPrompt and
// returns an EventLocationAgent holding that client together with the given
// location and event models. If the client cannot be created, the error is
// returned along with a zero-value EventLocationAgent.
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) {
client, err := CreateAgentClient(eventLocationPrompt)
if err != nil {
return EventLocationAgent{}, err
}
return agent, nil
return EventLocationAgent{
client: client,
locationModel: locationModel,
eventModel: eventModel,
}, nil
}

View File

@ -108,7 +108,7 @@ func main() {
panic(err)
}
locationAgent, err := agents.NewLocationEventAgent()
locationAgent, err := agents.NewLocationEventAgent(locationModel, eventModel)
if err != nil {
panic(err)
}
@ -121,7 +121,9 @@ func main() {
}
log.Println("Calling locationAgent!")
locationAgent.GetLocations(image.Image.ImageName, image.Image.Image)
locationAgent.GetLocations(image.UserID, image.Image.ImageName, image.Image.Image)
return
imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image)
if err != nil {

View File

@ -3,6 +3,7 @@ package models
import (
"context"
"database/sql"
. "github.com/go-jet/jet/v2/postgres"
"log"
"screenmark/screenmark/.gen/haystack/haystack/model"
. "screenmark/screenmark/.gen/haystack/haystack/table"
@ -39,6 +40,21 @@ func getValues(location model.Locations) []any {
return arr
}
// List queries the locations that are linked, through ImageLocations and
// UserImages, to rows whose UserImages.UserID equals userId.
//
// The returned slice is never nil. When the query fails, the slice is returned
// together with the query error.
func (m LocationModel) List(ctx context.Context, userId uuid.UUID) ([]model.Locations, error) {
	// Locations -> ImageLocations (by location ID) -> UserImages (by image ID).
	joined := Locations.
		INNER_JOIN(ImageLocations, ImageLocations.LocationID.EQ(Locations.ID)).
		INNER_JOIN(UserImages, UserImages.ImageID.EQ(ImageLocations.ImageID))

	stmt := SELECT(Locations.AllColumns, ImageLocations.AllColumns, UserImages.AllColumns).
		FROM(joined).
		WHERE(UserImages.UserID.EQ(UUID(userId)))

	result := []model.Locations{}
	if err := stmt.QueryContext(ctx, m.dbPool, &result); err != nil {
		return result, err
	}
	return result, nil
}
func (m LocationModel) Save(ctx context.Context, locations []model.Locations) (model.Locations, error) {
insertLocationStmt := Locations.
INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description)