feat(location): working e2e with tool calling

This commit is contained in:
2025-03-22 12:22:31 +00:00
parent 7b6c7090f8
commit dfb4b34de3
5 changed files with 170 additions and 76 deletions

View File

@ -393,6 +393,49 @@ func parseAgentResponse(jsonResponse []byte) (ImageInfo, error) {
return imageInfo, nil
}
// Request sends request to the agent API and decodes the reply.
//
// The request is marshalled to JSON, turned into an HTTP request by
// client.getRequest and executed with client.Do. The raw response body is
// logged and then unmarshalled into an AgentResponse.
//
// Side effect: when the first choice of the response carries tool calls, they
// are appended to request as an assistant tool-call message (role "assistant",
// empty content), so the caller can answer them and send the request again.
//
// A response that contains no choices is returned as-is and request is left
// untouched. Any marshalling, transport, read or decoding failure is returned
// together with a zero AgentResponse.
func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) {
	jsonAiRequest, err := json.Marshal(request)
	if err != nil {
		return AgentResponse{}, err
	}

	httpRequest, err := client.getRequest(jsonAiRequest)
	if err != nil {
		return AgentResponse{}, err
	}

	resp, err := client.Do(httpRequest)
	if err != nil {
		return AgentResponse{}, err
	}
	// The body must be closed on every path, otherwise the underlying
	// connection is leaked and cannot be reused.
	defer resp.Body.Close()

	response, err := io.ReadAll(resp.Body)
	if err != nil {
		return AgentResponse{}, err
	}

	// Log before decoding so that a malformed or unexpected payload is still
	// visible when unmarshalling fails.
	log.Println(string(response))

	agentResponse := AgentResponse{}
	if err := json.Unmarshal(response, &agentResponse); err != nil {
		return AgentResponse{}, err
	}

	// Guard against a reply without choices (e.g. an API error payload);
	// indexing Choices[0] would panic.
	if len(agentResponse.Choices) == 0 {
		return agentResponse, nil
	}

	// NOTE: only the first choice is inspected; this should become more
	// flexible.
	toolCalls := agentResponse.Choices[0].Message.ToolCalls
	if len(toolCalls) > 0 {
		request.AddToolCall(AgentAssistantToolCall{
			Role:      "assistant",
			Content:   "",
			ToolCalls: toolCalls,
		})
	}

	return agentResponse, nil
}
func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData)
if err != nil {

View File

@ -3,8 +3,9 @@ package agents
import (
"context"
"encoding/json"
"io"
"errors"
"log"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/models"
"github.com/google/uuid"
@ -15,6 +16,8 @@ You are an agent that extracts events and locations from an image.
Your job is to check if an image has an event or a location and use the correct tools to extract this information.
If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location.
Only create an event if you see an event on the image, not all locations have an associated event.
DO NOT CREATE EVENTS
It is possible that there is no location or event on an image.
@ -79,6 +82,17 @@ const TOOLS = `
"required": ["name"]
}
}
},
{
"type": "function",
"function": {
"name": "finish",
"description": "Nothing else to do, call this function.",
"parameters": {
"type": "object",
"properties": {}
}
}
}
]
`
@ -90,15 +104,92 @@ type EventLocationAgent struct {
locationModel models.LocationModel
}
// CreateLocationArguments is the JSON shape that the arguments of a
// "createLocation" tool call are decoded into.
type CreateLocationArguments struct {
	// Name is read from the "name" key.
	Name string `json:"name"`

	// Address and Coordinates are optional: they stay nil when the key is
	// absent and are omitted from the output when nil.
	Address     *string `json:"address,omitempty"`
	Coordinates *string `json:"coordinates,omitempty"`
}
// HandleToolCall executes the tool calls found in the latest message of
// request and appends one "tool" message per executed call, carrying the
// JSON-encoded result and the id of the call it answers.
//
// Supported tools:
//   - "listLocations": lists the locations of userId via the location model.
//   - "createLocation": decodes the call arguments into
//     CreateLocationArguments, saves a new location and reports the first
//     saved row.
//   - "finish": stops processing. This is signalled with a non-nil error so
//     that the caller's request loop terminates; it is not a real failure.
//
// An error is also returned when request holds no messages, when the latest
// message is not an AgentAssistantToolCall, when a tool name is unknown, or
// when listing, saving or JSON (de)serialization fails. Processing stops at
// the first error; tool messages appended for earlier calls remain in request.
func (agent EventLocationAgent) HandleToolCall(userId uuid.UUID, request *AgentRequestBody) error {
	messages := request.AgentMessages.Messages
	// Indexing the last element of an empty slice would panic.
	if len(messages) == 0 {
		return errors.New("request has no messages")
	}

	toolCall, ok := messages[len(messages)-1].(AgentAssistantToolCall)
	if !ok {
		return errors.New("Latest message is not a tool_call")
	}

	for _, call := range toolCall.ToolCalls {
		switch call.Function.Name {
		case "listLocations":
			log.Println("Function call: listLocations")

			locations, err := agent.locationModel.List(context.Background(), userId)
			if err != nil {
				return err
			}

			jsonLocations, err := json.Marshal(locations)
			if err != nil {
				return err
			}

			request.AddText(AgentTextMessage{
				Role:       "tool",
				Name:       "listLocations",
				Content:    string(jsonLocations),
				ToolCallId: call.Id,
			})

		case "createLocation":
			log.Println("Function call: createLocation")

			locationArguments := CreateLocationArguments{}
			if err := json.Unmarshal([]byte(call.Function.Arguments), &locationArguments); err != nil {
				return err
			}

			locations, err := agent.locationModel.Save(context.Background(), []model.Locations{{
				Name:        locationArguments.Name,
				Address:     locationArguments.Address,
				Coordinates: locationArguments.Coordinates,
			}})
			if err != nil {
				return err
			}
			// Save may succeed without returning a row; indexing [0] on an
			// empty result would panic.
			if len(locations) == 0 {
				return errors.New("createLocation: save returned no locations")
			}

			jsonCreatedLocation, err := json.Marshal(locations[0])
			if err != nil {
				return err
			}

			request.AddText(AgentTextMessage{
				Role:       "tool",
				Name:       "createLocation",
				Content:    string(jsonCreatedLocation),
				ToolCallId: call.Id,
			})

		case "finish":
			log.Println("Finished!")
			// NOTE(review): a non-nil error is used as the stop signal for the
			// caller's loop; a dedicated sentinel or return value would be
			// cleaner, but callers currently rely on this.
			return errors.New("hmmm, this isnt actually an error but hey")

		default:
			return errors.New("Unknown tool_call")
		}
	}

	return nil
}
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string, imageData []byte) error {
var tools any
err := json.Unmarshal([]byte(TOOLS), &tools)
auto := "auto"
toolChoice := "any"
request := AgentRequestBody{
Tools: &tools,
ToolChoice: &auto,
ToolChoice: &toolChoice,
Model: "pixtral-12b-2409",
Temperature: 0.3,
ResponseFormat: ResponseFormat{
@ -115,83 +206,28 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string,
request.AddImage(imageName, imageData)
jsonAiRequest, err := json.Marshal(request)
_, err = agent.client.Request(&request)
if err != nil {
return err
}
httpRequest, err := agent.client.getRequest(jsonAiRequest)
if err != nil {
err = agent.HandleToolCall(userId, &request)
for err == nil {
log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1])
response, requestError := agent.client.Request(&request)
if requestError != nil {
return requestError
}
log.Println(response)
err = agent.HandleToolCall(userId, &request)
}
return err
}
resp, err := agent.client.Do(httpRequest)
if err != nil {
return err
}
response, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
log.Println(string(response))
agentResponse := AgentResponse{}
err = json.Unmarshal(response, &agentResponse)
toolCalls := agentResponse.Choices[0].Message.ToolCalls[0]
if toolCalls.Function.Name == "listLocations" {
locations, err := agent.locationModel.List(context.Background(), userId)
if err != nil {
return err
}
jsonLocations, err := json.Marshal(locations)
if err != nil {
return err
}
request.AddToolCall(AgentAssistantToolCall{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{toolCalls},
})
request.AddText(AgentTextMessage{
Role: "tool",
Name: "listLocations",
Content: string(jsonLocations),
ToolCallId: toolCalls.Id,
})
jsonAiRequest, err := json.Marshal(request)
if err != nil {
return err
}
httpRequest, err := agent.client.getRequest(jsonAiRequest)
if err != nil {
return err
}
resp, err := agent.client.Do(httpRequest)
if err != nil {
return err
}
response, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
log.Println(string(response))
}
return nil
}
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) {
client, err := CreateAgentClient(eventLocationPrompt)
if err != nil {

View File

@ -121,7 +121,9 @@ func main() {
}
log.Println("Calling locationAgent!")
locationAgent.GetLocations(image.UserID, image.Image.ImageName, image.Image.Image)
err = locationAgent.GetLocations(image.UserID, image.Image.ImageName, image.Image.Image)
log.Println(err)
return

View File

@ -55,7 +55,7 @@ func (m LocationModel) List(ctx context.Context, userId uuid.UUID) ([]model.Loca
return locations, err
}
func (m LocationModel) Save(ctx context.Context, locations []model.Locations) (model.Locations, error) {
func (m LocationModel) Save(ctx context.Context, locations []model.Locations) ([]model.Locations, error) {
insertLocationStmt := Locations.
INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description)
@ -67,7 +67,7 @@ func (m LocationModel) Save(ctx context.Context, locations []model.Locations) (m
log.Println(insertLocationStmt.DebugSql())
insertedLocation := model.Locations{}
insertedLocation := []model.Locations{}
err := insertLocationStmt.QueryContext(ctx, m.dbPool, &insertedLocation)
return insertedLocation, err
@ -84,9 +84,11 @@ func (m LocationModel) SaveToImage(ctx context.Context, imageId uuid.UUID, locat
return err
}
// TODO: only the first location is linked to the image; slices with more than one element are not handled yet. BIG TODO
insertImageLocationStmt := ImageLocations.
INSERT(ImageLocations.ImageID, ImageLocations.LocationID).
VALUES(imageId, location.ID)
VALUES(imageId, location[0].ID)
_, err = insertImageLocationStmt.ExecContext(ctx, m.dbPool)

View File

@ -54,5 +54,16 @@
"required": ["name"]
}
}
},
{
"type": "function",
"function": {
"name": "finish",
"description": "Nothing else to do, call this function.",
"parameters": {
"type": "object",
"properties": {}
}
}
}
]