feat(tool-calls): listLocation tool call handling
This commit is contained in:
@ -50,6 +50,8 @@ type AgentMessage interface {
|
||||
type AgentTextMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
|
||||
@ -57,6 +59,16 @@ func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
|
||||
return json.Marshal(textContent)
|
||||
}
|
||||
|
||||
type AgentAssistantToolCall struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls"`
|
||||
}
|
||||
|
||||
func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) {
|
||||
return json.Marshal(toolCall)
|
||||
}
|
||||
|
||||
type AgentArrayMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []AgentContent `json:"content"`
|
||||
@ -66,6 +78,14 @@ func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
|
||||
return json.Marshal(arrayContent)
|
||||
}
|
||||
|
||||
func (content *AgentMessages) AddText(message AgentTextMessage) {
|
||||
content.Messages = append(content.Messages, message)
|
||||
}
|
||||
|
||||
func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) {
|
||||
content.Messages = append(content.Messages, toolCall)
|
||||
}
|
||||
|
||||
func (content *AgentMessages) AddImage(imageName string, image []byte) error {
|
||||
extension := filepath.Ext(imageName)
|
||||
if len(extension) == 0 {
|
||||
@ -320,9 +340,21 @@ func getCompletionsForImage(model string, temperature float64, prompt string, im
|
||||
return request, nil
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
Index int `json:"index"`
|
||||
Id string `json:"id"`
|
||||
Function FunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
type ResponseChoiceMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls"`
|
||||
}
|
||||
|
||||
type ResponseChoice struct {
|
||||
|
@ -1,9 +1,13 @@
|
||||
package agents
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"screenmark/screenmark/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const eventLocationPrompt = `
|
||||
@ -13,6 +17,8 @@ Your job is to check if an image has an event or a location and use the correct
|
||||
If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location.
|
||||
|
||||
It is possible that there is no location or event on an image.
|
||||
|
||||
You should ask for a list of locations, as the user is likely to have this location saved. Reuse existing locations where possible.
|
||||
`
|
||||
|
||||
// TODO: this should be read directly from a file on load.
|
||||
@ -77,9 +83,14 @@ const TOOLS = `
|
||||
]
|
||||
`
|
||||
|
||||
type EventLocationAgent = AgentClient
|
||||
type EventLocationAgent struct {
|
||||
client AgentClient
|
||||
|
||||
func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte) error {
|
||||
eventModel models.EventModel
|
||||
locationModel models.LocationModel
|
||||
}
|
||||
|
||||
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string, imageData []byte) error {
|
||||
var tools any
|
||||
err := json.Unmarshal([]byte(TOOLS), &tools)
|
||||
|
||||
@ -109,12 +120,12 @@ func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte)
|
||||
return err
|
||||
}
|
||||
|
||||
httpRequest, err := agent.getRequest(jsonAiRequest)
|
||||
httpRequest, err := agent.client.getRequest(jsonAiRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := agent.Do(httpRequest)
|
||||
resp, err := agent.client.Do(httpRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -126,14 +137,70 @@ func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte)
|
||||
|
||||
log.Println(string(response))
|
||||
|
||||
agentResponse := AgentResponse{}
|
||||
err = json.Unmarshal(response, &agentResponse)
|
||||
|
||||
toolCalls := agentResponse.Choices[0].Message.ToolCalls[0]
|
||||
|
||||
if toolCalls.Function.Name == "listLocations" {
|
||||
locations, err := agent.locationModel.List(context.Background(), userId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jsonLocations, err := json.Marshal(locations)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request.AddToolCall(AgentAssistantToolCall{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
ToolCalls: []ToolCall{toolCalls},
|
||||
})
|
||||
|
||||
request.AddText(AgentTextMessage{
|
||||
Role: "tool",
|
||||
Name: "listLocations",
|
||||
Content: string(jsonLocations),
|
||||
ToolCallId: toolCalls.Id,
|
||||
})
|
||||
|
||||
jsonAiRequest, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpRequest, err := agent.client.getRequest(jsonAiRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := agent.client.Do(httpRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Println(string(response))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewLocationEventAgent() (EventLocationAgent, error) {
|
||||
agent, err := CreateAgentClient(eventLocationPrompt)
|
||||
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) {
|
||||
client, err := CreateAgentClient(eventLocationPrompt)
|
||||
if err != nil {
|
||||
return EventLocationAgent{}, err
|
||||
}
|
||||
|
||||
return agent, nil
|
||||
return EventLocationAgent{
|
||||
client: client,
|
||||
locationModel: locationModel,
|
||||
eventModel: eventModel,
|
||||
}, nil
|
||||
}
|
||||
|
@ -108,7 +108,7 @@ func main() {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
locationAgent, err := agents.NewLocationEventAgent()
|
||||
locationAgent, err := agents.NewLocationEventAgent(locationModel, eventModel)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -121,7 +121,9 @@ func main() {
|
||||
}
|
||||
|
||||
log.Println("Calling locationAgent!")
|
||||
locationAgent.GetLocations(image.Image.ImageName, image.Image.Image)
|
||||
locationAgent.GetLocations(image.UserID, image.Image.ImageName, image.Image.Image)
|
||||
|
||||
return
|
||||
|
||||
imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image)
|
||||
if err != nil {
|
||||
|
@ -3,6 +3,7 @@ package models
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
. "github.com/go-jet/jet/v2/postgres"
|
||||
"log"
|
||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||
. "screenmark/screenmark/.gen/haystack/haystack/table"
|
||||
@ -39,6 +40,21 @@ func getValues(location model.Locations) []any {
|
||||
return arr
|
||||
}
|
||||
|
||||
func (m LocationModel) List(ctx context.Context, userId uuid.UUID) ([]model.Locations, error) {
|
||||
listLocationsStmt := SELECT(Locations.AllColumns, ImageLocations.AllColumns, UserImages.AllColumns).
|
||||
FROM(
|
||||
Locations.
|
||||
INNER_JOIN(ImageLocations, ImageLocations.LocationID.EQ(Locations.ID)).
|
||||
INNER_JOIN(UserImages, UserImages.ImageID.EQ(ImageLocations.ImageID)),
|
||||
).WHERE(UserImages.UserID.EQ(UUID(userId)))
|
||||
|
||||
locations := []model.Locations{}
|
||||
|
||||
err := listLocationsStmt.QueryContext(ctx, m.dbPool, &locations)
|
||||
|
||||
return locations, err
|
||||
}
|
||||
|
||||
func (m LocationModel) Save(ctx context.Context, locations []model.Locations) (model.Locations, error) {
|
||||
insertLocationStmt := Locations.
|
||||
INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description)
|
||||
|
Reference in New Issue
Block a user