feat(tool-calls): listLocation tool call handling
This commit is contained in:
@ -48,8 +48,10 @@ type AgentMessage interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AgentTextMessage struct {
|
type AgentTextMessage struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
|
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
|
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
|
||||||
@ -57,6 +59,16 @@ func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
|
|||||||
return json.Marshal(textContent)
|
return json.Marshal(textContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AgentAssistantToolCall struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) {
|
||||||
|
return json.Marshal(toolCall)
|
||||||
|
}
|
||||||
|
|
||||||
type AgentArrayMessage struct {
|
type AgentArrayMessage struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content []AgentContent `json:"content"`
|
Content []AgentContent `json:"content"`
|
||||||
@ -66,6 +78,14 @@ func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
|
|||||||
return json.Marshal(arrayContent)
|
return json.Marshal(arrayContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (content *AgentMessages) AddText(message AgentTextMessage) {
|
||||||
|
content.Messages = append(content.Messages, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) {
|
||||||
|
content.Messages = append(content.Messages, toolCall)
|
||||||
|
}
|
||||||
|
|
||||||
func (content *AgentMessages) AddImage(imageName string, image []byte) error {
|
func (content *AgentMessages) AddImage(imageName string, image []byte) error {
|
||||||
extension := filepath.Ext(imageName)
|
extension := filepath.Ext(imageName)
|
||||||
if len(extension) == 0 {
|
if len(extension) == 0 {
|
||||||
@ -320,9 +340,21 @@ func getCompletionsForImage(model string, temperature float64, prompt string, im
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FunctionCall struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments string `json:"arguments"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Id string `json:"id"`
|
||||||
|
Function FunctionCall `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
type ResponseChoiceMessage struct {
|
type ResponseChoiceMessage struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ResponseChoice struct {
|
type ResponseChoice struct {
|
||||||
|
@ -1,9 +1,13 @@
|
|||||||
package agents
|
package agents
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"screenmark/screenmark/models"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
const eventLocationPrompt = `
|
const eventLocationPrompt = `
|
||||||
@ -13,6 +17,8 @@ Your job is to check if an image has an event or a location and use the correct
|
|||||||
If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location.
|
If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location.
|
||||||
|
|
||||||
It is possible that there is no location or event on an image.
|
It is possible that there is no location or event on an image.
|
||||||
|
|
||||||
|
You should ask for a list of locations, as the user is likely to have this location saved. Reuse existing locations where possible.
|
||||||
`
|
`
|
||||||
|
|
||||||
// TODO: this should be read directly from a file on load.
|
// TODO: this should be read directly from a file on load.
|
||||||
@ -77,9 +83,14 @@ const TOOLS = `
|
|||||||
]
|
]
|
||||||
`
|
`
|
||||||
|
|
||||||
type EventLocationAgent = AgentClient
|
type EventLocationAgent struct {
|
||||||
|
client AgentClient
|
||||||
|
|
||||||
func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte) error {
|
eventModel models.EventModel
|
||||||
|
locationModel models.LocationModel
|
||||||
|
}
|
||||||
|
|
||||||
|
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string, imageData []byte) error {
|
||||||
var tools any
|
var tools any
|
||||||
err := json.Unmarshal([]byte(TOOLS), &tools)
|
err := json.Unmarshal([]byte(TOOLS), &tools)
|
||||||
|
|
||||||
@ -109,12 +120,12 @@ func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
httpRequest, err := agent.getRequest(jsonAiRequest)
|
httpRequest, err := agent.client.getRequest(jsonAiRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := agent.Do(httpRequest)
|
resp, err := agent.client.Do(httpRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -126,14 +137,70 @@ func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte)
|
|||||||
|
|
||||||
log.Println(string(response))
|
log.Println(string(response))
|
||||||
|
|
||||||
|
agentResponse := AgentResponse{}
|
||||||
|
err = json.Unmarshal(response, &agentResponse)
|
||||||
|
|
||||||
|
toolCalls := agentResponse.Choices[0].Message.ToolCalls[0]
|
||||||
|
|
||||||
|
if toolCalls.Function.Name == "listLocations" {
|
||||||
|
locations, err := agent.locationModel.List(context.Background(), userId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonLocations, err := json.Marshal(locations)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
request.AddToolCall(AgentAssistantToolCall{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "",
|
||||||
|
ToolCalls: []ToolCall{toolCalls},
|
||||||
|
})
|
||||||
|
|
||||||
|
request.AddText(AgentTextMessage{
|
||||||
|
Role: "tool",
|
||||||
|
Name: "listLocations",
|
||||||
|
Content: string(jsonLocations),
|
||||||
|
ToolCallId: toolCalls.Id,
|
||||||
|
})
|
||||||
|
|
||||||
|
jsonAiRequest, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpRequest, err := agent.client.getRequest(jsonAiRequest)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := agent.client.Do(httpRequest)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println(string(response))
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLocationEventAgent() (EventLocationAgent, error) {
|
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) {
|
||||||
agent, err := CreateAgentClient(eventLocationPrompt)
|
client, err := CreateAgentClient(eventLocationPrompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return EventLocationAgent{}, err
|
return EventLocationAgent{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return agent, nil
|
return EventLocationAgent{
|
||||||
|
client: client,
|
||||||
|
locationModel: locationModel,
|
||||||
|
eventModel: eventModel,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -108,7 +108,7 @@ func main() {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
locationAgent, err := agents.NewLocationEventAgent()
|
locationAgent, err := agents.NewLocationEventAgent(locationModel, eventModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@ -121,7 +121,9 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Println("Calling locationAgent!")
|
log.Println("Calling locationAgent!")
|
||||||
locationAgent.GetLocations(image.Image.ImageName, image.Image.Image)
|
locationAgent.GetLocations(image.UserID, image.Image.ImageName, image.Image.Image)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image)
|
imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -3,6 +3,7 @@ package models
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
. "github.com/go-jet/jet/v2/postgres"
|
||||||
"log"
|
"log"
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
. "screenmark/screenmark/.gen/haystack/haystack/table"
|
. "screenmark/screenmark/.gen/haystack/haystack/table"
|
||||||
@ -39,6 +40,21 @@ func getValues(location model.Locations) []any {
|
|||||||
return arr
|
return arr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m LocationModel) List(ctx context.Context, userId uuid.UUID) ([]model.Locations, error) {
|
||||||
|
listLocationsStmt := SELECT(Locations.AllColumns, ImageLocations.AllColumns, UserImages.AllColumns).
|
||||||
|
FROM(
|
||||||
|
Locations.
|
||||||
|
INNER_JOIN(ImageLocations, ImageLocations.LocationID.EQ(Locations.ID)).
|
||||||
|
INNER_JOIN(UserImages, UserImages.ImageID.EQ(ImageLocations.ImageID)),
|
||||||
|
).WHERE(UserImages.UserID.EQ(UUID(userId)))
|
||||||
|
|
||||||
|
locations := []model.Locations{}
|
||||||
|
|
||||||
|
err := listLocationsStmt.QueryContext(ctx, m.dbPool, &locations)
|
||||||
|
|
||||||
|
return locations, err
|
||||||
|
}
|
||||||
|
|
||||||
func (m LocationModel) Save(ctx context.Context, locations []model.Locations) (model.Locations, error) {
|
func (m LocationModel) Save(ctx context.Context, locations []model.Locations) (model.Locations, error) {
|
||||||
insertLocationStmt := Locations.
|
insertLocationStmt := Locations.
|
||||||
INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description)
|
INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description)
|
||||||
|
Reference in New Issue
Block a user