feat: using tools for event location agent

This commit is contained in:
2025-03-22 10:12:51 +00:00
parent 1cd4698969
commit 87869543f7
5 changed files with 222 additions and 15 deletions

View File

@ -1,4 +1,4 @@
package main package agents
import ( import (
"bytes" "bytes"
@ -33,6 +33,9 @@ type AgentRequestBody struct {
Temperature float64 `json:"temperature"` Temperature float64 `json:"temperature"`
ResponseFormat ResponseFormat `json:"response_format"` ResponseFormat ResponseFormat `json:"response_format"`
Tools *any `json:"tools,omitempty"`
ToolChoice *string `json:"tool_choice,omitempty"`
AgentMessages AgentMessages
} }
@ -145,10 +148,6 @@ You are an image information extractor. The user will provide you with screensho
that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags. that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags.
Be sure to extract every link (URL) that you find. Be sure to extract every link (URL) that you find.
Use generic tags. Use generic tags.
You also want to extract events in the image, and the location/locations this event is hosted in.
You need to extract locations in the image if any exist, and give the approximate coordinates for this location.
` `
const RESPONSE_FORMAT = ` const RESPONSE_FORMAT = `
@ -262,7 +261,7 @@ const RESPONSE_FORMAT = `
} }
` `
func CreateAgentClient() (AgentClient, error) { func CreateAgentClient(prompt string) (AgentClient, error) {
apiKey := os.Getenv(OPENAI_API_KEY) apiKey := os.Getenv(OPENAI_API_KEY)
if len(apiKey) == 0 { if len(apiKey) == 0 {
@ -272,7 +271,7 @@ func CreateAgentClient() (AgentClient, error) {
return AgentClient{ return AgentClient{
apiKey: apiKey, apiKey: apiKey,
url: "https://api.mistral.ai/v1/chat/completions", url: "https://api.mistral.ai/v1/chat/completions",
systemPrompt: PROMPT, systemPrompt: prompt,
Do: func(req *http.Request) (*http.Response, error) { Do: func(req *http.Request) (*http.Response, error) {
client := &http.Client{} client := &http.Client{}
return client.Do(req) return client.Do(req)
@ -309,11 +308,15 @@ func getCompletionsForImage(model string, temperature float64, prompt string, im
return request, err return request, err
} }
log.Println(request)
err = request.AddImage(imageName, imageData) err = request.AddImage(imageName, imageData)
if err != nil { if err != nil {
return request, err return request, err
} }
request.Tools = nil
return request, nil return request, nil
} }
@ -370,8 +373,6 @@ func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (Imag
return ImageInfo{}, err return ImageInfo{}, err
} }
log.Println(jsonSchema)
aiRequest.ResponseFormat = ResponseFormat{ aiRequest.ResponseFormat = ResponseFormat{
Type: "json_schema", Type: "json_schema",
JsonSchema: jsonSchema, JsonSchema: jsonSchema,

View File

@ -1,4 +1,4 @@
package main package agents
import ( import (
"bytes" "bytes"

View File

@ -0,0 +1,139 @@
package agents
import (
"encoding/json"
"io"
"log"
)
// eventLocationPrompt is the system prompt for the event/location agent. It
// tells the model to check an image for an event and/or a location and to use
// the supplied tools to extract them, noting that either may be absent.
const eventLocationPrompt = `
You are an agent that extracts events and locations from an image.
Your job is to check if an image has an event or a location and use the correct tools to extract this information.
If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location.
It is possible that there is no location or event on an image.
`
// TOOLS is a JSON array of function-tool definitions offered to the model:
// createLocation, listLocations and createEvent. It is kept as a raw string
// and decoded with json.Unmarshal before being attached to a request.
// TODO: this should be read directly from a file on load.
const TOOLS = `
[
{
"type": "function",
"function": {
"name": "createLocation",
"description": "Creates a location",
"parameters": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"coordinates": {
"type": "string"
},
"address": {
"type": "string"
}
},
"required": ["name"]
}
}
},
{
"type": "function",
"function": {
"name": "listLocations",
"description": "Lists the locations available",
"parameters": {
"type": "object",
"properties": {}
}
}
},
{
"type": "function",
"function": {
"name": "createEvent",
"description": "Creates a new event",
"parameters": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"datetime": {
"type": "string"
},
"locationId": {
"type": "string",
"description": "The ID of the location, available by listLocations"
}
},
"required": ["name"]
}
}
}
]
`
// EventLocationAgent is an alias (not a distinct type) for AgentClient, so
// methods declared on EventLocationAgent are methods of AgentClient itself.
type EventLocationAgent = AgentClient
// GetLocations sends the given image to the model together with the
// event/location system prompt and the TOOLS function definitions, then logs
// the raw response body.
//
// imageName and imageData are forwarded to AgentRequestBody.AddImage. The
// response is only logged, not parsed. A non-nil error is returned when TOOLS
// cannot be decoded, the request cannot be assembled or serialised, the HTTP
// call fails, or the response body cannot be read.
func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte) error {
	var tools any
	// Check the decode error right away; it used to be overwritten by the
	// AddSystem call below before anyone looked at it.
	if err := json.Unmarshal([]byte(TOOLS), &tools); err != nil {
		return err
	}

	toolChoice := "auto"

	request := AgentRequestBody{
		Tools:       &tools,
		ToolChoice:  &toolChoice,
		Model:       "pixtral-12b-2409",
		Temperature: 0.3,
		ResponseFormat: ResponseFormat{
			Type: "text",
		},
	}

	if err := request.AddSystem(eventLocationPrompt); err != nil {
		return err
	}

	// Logged before the image is attached, so the image payload is not part
	// of this log line.
	log.Println(request)

	// AddImage returns an error; do not send a request without its image.
	if err := request.AddImage(imageName, imageData); err != nil {
		return err
	}

	jsonAiRequest, err := json.Marshal(request)
	if err != nil {
		return err
	}

	httpRequest, err := agent.getRequest(jsonAiRequest)
	if err != nil {
		return err
	}

	resp, err := agent.Do(httpRequest)
	if err != nil {
		return err
	}
	// Always close the body so the underlying connection is released.
	defer resp.Body.Close()

	response, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}

	log.Println(string(response))

	return nil
}
// NewLocationEventAgent builds an agent client whose system prompt is
// eventLocationPrompt. If the underlying client cannot be created, the zero
// EventLocationAgent is returned along with the error.
func NewLocationEventAgent() (EventLocationAgent, error) {
	client, createErr := CreateAgentClient(eventLocationPrompt)
	if createErr == nil {
		return client, nil
	}

	return EventLocationAgent{}, createErr
}

View File

@ -12,6 +12,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents"
"screenmark/screenmark/models" "screenmark/screenmark/models"
"time" "time"
@ -23,21 +24,21 @@ import (
) )
type TestAiClient struct { type TestAiClient struct {
ImageInfo ImageInfo ImageInfo agents.ImageInfo
} }
func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) { func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (agents.ImageInfo, error) {
return client.ImageInfo, nil return client.ImageInfo, nil
} }
func GetAiClient() (AiClient, error) { func GetAiClient() (agents.AiClient, error) {
mode := os.Getenv("MODE") mode := os.Getenv("MODE")
if mode == "TESTING" { if mode == "TESTING" {
address := "10 Downing Street" address := "10 Downing Street"
description := "Cheese and Crackers" description := "Cheese and Crackers"
return TestAiClient{ return TestAiClient{
ImageInfo: ImageInfo{ ImageInfo: agents.ImageInfo{
Tags: []string{"tag"}, Tags: []string{"tag"},
Links: []string{"links"}, Links: []string{"links"},
Text: []string{"text"}, Text: []string{"text"},
@ -55,7 +56,7 @@ func GetAiClient() (AiClient, error) {
}, nil }, nil
} }
return CreateOpenAiClient() return agents.CreateAgentClient(agents.PROMPT)
} }
func main() { func main() {
@ -107,6 +108,11 @@ func main() {
panic(err) panic(err)
} }
locationAgent, err := agents.NewLocationEventAgent()
if err != nil {
panic(err)
}
image, err := imageModel.GetToProcessWithData(ctx, imageId) image, err := imageModel.GetToProcessWithData(ctx, imageId)
if err != nil { if err != nil {
log.Println("Failed to GetToProcessWithData") log.Println("Failed to GetToProcessWithData")
@ -114,6 +120,9 @@ func main() {
return return
} }
log.Println("Calling locationAgent!")
locationAgent.GetLocations(image.Image.ImageName, image.Image.Image)
imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image) imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image)
if err != nil { if err != nil {
log.Println("Failed to GetImageInfo") log.Println("Failed to GetImageInfo")

58
backend/tools.json Normal file
View File

@ -0,0 +1,58 @@
[
{
"type": "function",
"function": {
"name": "createLocation",
"description": "Creates a location",
"parameters": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"coordinates": {
"type": "string"
},
"address": {
"type": "string"
}
},
"required": ["name"]
}
}
},
{
"type": "function",
"function": {
"name": "listLocations",
"description": "Lists the locations available",
"parameters": {
"type": "object",
"properties": {}
}
}
},
{
"type": "function",
"function": {
"name": "createEvent",
"description": "Creates a new event",
"parameters": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"datetime": {
"type": "string"
},
"locationId": {
"type": "string",
"description": "The ID of the location, available by `listLocations`"
}
},
"required": ["name"]
}
}
}
]