feat(location): working e2e with tool calling
This commit is contained in:
@ -393,6 +393,49 @@ func parseAgentResponse(jsonResponse []byte) (ImageInfo, error) {
|
|||||||
return imageInfo, nil
|
return imageInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) {
|
||||||
|
jsonAiRequest, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return AgentResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpRequest, err := client.getRequest(jsonAiRequest)
|
||||||
|
if err != nil {
|
||||||
|
return AgentResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(httpRequest)
|
||||||
|
if err != nil {
|
||||||
|
return AgentResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return AgentResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
agentResponse := AgentResponse{}
|
||||||
|
err = json.Unmarshal(response, &agentResponse)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return AgentResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println(string(response))
|
||||||
|
|
||||||
|
toolCalls := agentResponse.Choices[0].Message.ToolCalls
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
// Should for sure be more flexible.
|
||||||
|
request.AddToolCall(AgentAssistantToolCall{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "",
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return agentResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
|
func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
|
||||||
aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData)
|
aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -3,8 +3,9 @@ package agents
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
"screenmark/screenmark/models"
|
"screenmark/screenmark/models"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@ -15,6 +16,8 @@ You are an agent that extracts events and locations from an image.
|
|||||||
Your job is to check if an image has an event or a location and use the correct tools to extract this information.
|
Your job is to check if an image has an event or a location and use the correct tools to extract this information.
|
||||||
|
|
||||||
If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location.
|
If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location.
|
||||||
|
Only create an event if you see an event on the image, not all locations have an associated event.
|
||||||
|
DO NOT CREATE EVENTS
|
||||||
|
|
||||||
It is possible that there is no location or event on an image.
|
It is possible that there is no location or event on an image.
|
||||||
|
|
||||||
@ -79,6 +82,17 @@ const TOOLS = `
|
|||||||
"required": ["name"]
|
"required": ["name"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "finish",
|
||||||
|
"description": "Nothing else to do, call this function.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
`
|
`
|
||||||
@ -90,15 +104,92 @@ type EventLocationAgent struct {
|
|||||||
locationModel models.LocationModel
|
locationModel models.LocationModel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CreateLocationArguments struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Address *string `json:"address,omitempty"`
|
||||||
|
Coordinates *string `json:"coordinates,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (agent EventLocationAgent) HandleToolCall(userId uuid.UUID, request *AgentRequestBody) error {
|
||||||
|
latestMessage := request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1]
|
||||||
|
|
||||||
|
toolCall, ok := latestMessage.(AgentAssistantToolCall)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return errors.New("Latest message is not a tool_call")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, call := range toolCall.ToolCalls {
|
||||||
|
if call.Function.Name == "listLocations" {
|
||||||
|
log.Println("Function call: listLocations")
|
||||||
|
|
||||||
|
locations, err := agent.locationModel.List(context.Background(), userId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonLocations, err := json.Marshal(locations)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
request.AddText(AgentTextMessage{
|
||||||
|
Role: "tool",
|
||||||
|
Name: "listLocations",
|
||||||
|
Content: string(jsonLocations),
|
||||||
|
ToolCallId: call.Id,
|
||||||
|
})
|
||||||
|
} else if call.Function.Name == "createLocation" {
|
||||||
|
log.Println("Function call: createLocation")
|
||||||
|
|
||||||
|
locationArguments := CreateLocationArguments{}
|
||||||
|
err := json.Unmarshal([]byte(call.Function.Arguments), &locationArguments)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
locations, err := agent.locationModel.Save(context.Background(), []model.Locations{{
|
||||||
|
Name: locationArguments.Name,
|
||||||
|
Address: locationArguments.Address,
|
||||||
|
Coordinates: locationArguments.Coordinates,
|
||||||
|
}})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
createdLocation := locations[0]
|
||||||
|
jsonCreatedLocation, err := json.Marshal(createdLocation)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
request.AddText(AgentTextMessage{
|
||||||
|
Role: "tool",
|
||||||
|
Name: "createLocation",
|
||||||
|
Content: string(jsonCreatedLocation),
|
||||||
|
ToolCallId: call.Id,
|
||||||
|
})
|
||||||
|
} else if call.Function.Name == "finish" {
|
||||||
|
log.Println("Finished!")
|
||||||
|
return errors.New("hmmm, this isnt actually an error but hey")
|
||||||
|
} else {
|
||||||
|
return errors.New("Unknown tool_call")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string, imageData []byte) error {
|
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string, imageData []byte) error {
|
||||||
var tools any
|
var tools any
|
||||||
err := json.Unmarshal([]byte(TOOLS), &tools)
|
err := json.Unmarshal([]byte(TOOLS), &tools)
|
||||||
|
|
||||||
auto := "auto"
|
toolChoice := "any"
|
||||||
|
|
||||||
request := AgentRequestBody{
|
request := AgentRequestBody{
|
||||||
Tools: &tools,
|
Tools: &tools,
|
||||||
ToolChoice: &auto,
|
ToolChoice: &toolChoice,
|
||||||
Model: "pixtral-12b-2409",
|
Model: "pixtral-12b-2409",
|
||||||
Temperature: 0.3,
|
Temperature: 0.3,
|
||||||
ResponseFormat: ResponseFormat{
|
ResponseFormat: ResponseFormat{
|
||||||
@ -115,81 +206,26 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string,
|
|||||||
|
|
||||||
request.AddImage(imageName, imageData)
|
request.AddImage(imageName, imageData)
|
||||||
|
|
||||||
jsonAiRequest, err := json.Marshal(request)
|
_, err = agent.client.Request(&request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
httpRequest, err := agent.client.getRequest(jsonAiRequest)
|
err = agent.HandleToolCall(userId, &request)
|
||||||
if err != nil {
|
for err == nil {
|
||||||
return err
|
log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1])
|
||||||
|
|
||||||
|
response, requestError := agent.client.Request(&request)
|
||||||
|
if requestError != nil {
|
||||||
|
return requestError
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println(response)
|
||||||
|
|
||||||
|
err = agent.HandleToolCall(userId, &request)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := agent.client.Do(httpRequest)
|
return err
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
response, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Println(string(response))
|
|
||||||
|
|
||||||
agentResponse := AgentResponse{}
|
|
||||||
err = json.Unmarshal(response, &agentResponse)
|
|
||||||
|
|
||||||
toolCalls := agentResponse.Choices[0].Message.ToolCalls[0]
|
|
||||||
|
|
||||||
if toolCalls.Function.Name == "listLocations" {
|
|
||||||
locations, err := agent.locationModel.List(context.Background(), userId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonLocations, err := json.Marshal(locations)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
request.AddToolCall(AgentAssistantToolCall{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "",
|
|
||||||
ToolCalls: []ToolCall{toolCalls},
|
|
||||||
})
|
|
||||||
|
|
||||||
request.AddText(AgentTextMessage{
|
|
||||||
Role: "tool",
|
|
||||||
Name: "listLocations",
|
|
||||||
Content: string(jsonLocations),
|
|
||||||
ToolCallId: toolCalls.Id,
|
|
||||||
})
|
|
||||||
|
|
||||||
jsonAiRequest, err := json.Marshal(request)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
httpRequest, err := agent.client.getRequest(jsonAiRequest)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := agent.client.Do(httpRequest)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
response, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Println(string(response))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) {
|
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) {
|
||||||
|
@ -121,7 +121,9 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Println("Calling locationAgent!")
|
log.Println("Calling locationAgent!")
|
||||||
locationAgent.GetLocations(image.UserID, image.Image.ImageName, image.Image.Image)
|
err = locationAgent.GetLocations(image.UserID, image.Image.ImageName, image.Image.Image)
|
||||||
|
|
||||||
|
log.Println(err)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ func (m LocationModel) List(ctx context.Context, userId uuid.UUID) ([]model.Loca
|
|||||||
return locations, err
|
return locations, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m LocationModel) Save(ctx context.Context, locations []model.Locations) (model.Locations, error) {
|
func (m LocationModel) Save(ctx context.Context, locations []model.Locations) ([]model.Locations, error) {
|
||||||
insertLocationStmt := Locations.
|
insertLocationStmt := Locations.
|
||||||
INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description)
|
INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description)
|
||||||
|
|
||||||
@ -67,7 +67,7 @@ func (m LocationModel) Save(ctx context.Context, locations []model.Locations) (m
|
|||||||
|
|
||||||
log.Println(insertLocationStmt.DebugSql())
|
log.Println(insertLocationStmt.DebugSql())
|
||||||
|
|
||||||
insertedLocation := model.Locations{}
|
insertedLocation := []model.Locations{}
|
||||||
err := insertLocationStmt.QueryContext(ctx, m.dbPool, &insertedLocation)
|
err := insertLocationStmt.QueryContext(ctx, m.dbPool, &insertedLocation)
|
||||||
|
|
||||||
return insertedLocation, err
|
return insertedLocation, err
|
||||||
@ -84,9 +84,11 @@ func (m LocationModel) SaveToImage(ctx context.Context, imageId uuid.UUID, locat
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: doesnt work if array is more than 1. BIG TODO
|
||||||
|
|
||||||
insertImageLocationStmt := ImageLocations.
|
insertImageLocationStmt := ImageLocations.
|
||||||
INSERT(ImageLocations.ImageID, ImageLocations.LocationID).
|
INSERT(ImageLocations.ImageID, ImageLocations.LocationID).
|
||||||
VALUES(imageId, location.ID)
|
VALUES(imageId, location[0].ID)
|
||||||
|
|
||||||
_, err = insertImageLocationStmt.ExecContext(ctx, m.dbPool)
|
_, err = insertImageLocationStmt.ExecContext(ctx, m.dbPool)
|
||||||
|
|
||||||
|
@ -54,5 +54,16 @@
|
|||||||
"required": ["name"]
|
"required": ["name"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "finish",
|
||||||
|
"description": "Nothing else to do, call this function.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
Reference in New Issue
Block a user