feat: using tools for event location agent
This commit is contained in:
@ -1,4 +1,4 @@
|
|||||||
package main
|
package agents
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -33,6 +33,9 @@ type AgentRequestBody struct {
|
|||||||
Temperature float64 `json:"temperature"`
|
Temperature float64 `json:"temperature"`
|
||||||
ResponseFormat ResponseFormat `json:"response_format"`
|
ResponseFormat ResponseFormat `json:"response_format"`
|
||||||
|
|
||||||
|
Tools *any `json:"tools,omitempty"`
|
||||||
|
ToolChoice *string `json:"tool_choice,omitempty"`
|
||||||
|
|
||||||
AgentMessages
|
AgentMessages
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,10 +148,6 @@ You are an image information extractor. The user will provide you with screensho
|
|||||||
that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags.
|
that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags.
|
||||||
Be sure to extract every link (URL) that you find.
|
Be sure to extract every link (URL) that you find.
|
||||||
Use generic tags.
|
Use generic tags.
|
||||||
|
|
||||||
You also want to extract events in the image, and the location/locations this event is hosted in.
|
|
||||||
|
|
||||||
You need to extract locations in the image if any exist, and give the approximate coordinates for this location.
|
|
||||||
`
|
`
|
||||||
|
|
||||||
const RESPONSE_FORMAT = `
|
const RESPONSE_FORMAT = `
|
||||||
@ -262,7 +261,7 @@ const RESPONSE_FORMAT = `
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
|
|
||||||
func CreateAgentClient() (AgentClient, error) {
|
func CreateAgentClient(prompt string) (AgentClient, error) {
|
||||||
apiKey := os.Getenv(OPENAI_API_KEY)
|
apiKey := os.Getenv(OPENAI_API_KEY)
|
||||||
|
|
||||||
if len(apiKey) == 0 {
|
if len(apiKey) == 0 {
|
||||||
@ -272,7 +271,7 @@ func CreateAgentClient() (AgentClient, error) {
|
|||||||
return AgentClient{
|
return AgentClient{
|
||||||
apiKey: apiKey,
|
apiKey: apiKey,
|
||||||
url: "https://api.mistral.ai/v1/chat/completions",
|
url: "https://api.mistral.ai/v1/chat/completions",
|
||||||
systemPrompt: PROMPT,
|
systemPrompt: prompt,
|
||||||
Do: func(req *http.Request) (*http.Response, error) {
|
Do: func(req *http.Request) (*http.Response, error) {
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
return client.Do(req)
|
return client.Do(req)
|
||||||
@ -309,11 +308,15 @@ func getCompletionsForImage(model string, temperature float64, prompt string, im
|
|||||||
return request, err
|
return request, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Println(request)
|
||||||
|
|
||||||
err = request.AddImage(imageName, imageData)
|
err = request.AddImage(imageName, imageData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return request, err
|
return request, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
request.Tools = nil
|
||||||
|
|
||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -370,8 +373,6 @@ func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (Imag
|
|||||||
return ImageInfo{}, err
|
return ImageInfo{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println(jsonSchema)
|
|
||||||
|
|
||||||
aiRequest.ResponseFormat = ResponseFormat{
|
aiRequest.ResponseFormat = ResponseFormat{
|
||||||
Type: "json_schema",
|
Type: "json_schema",
|
||||||
JsonSchema: jsonSchema,
|
JsonSchema: jsonSchema,
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package main
|
package agents
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
139
backend/agents/event_location_agent.go
Normal file
139
backend/agents/event_location_agent.go
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
package agents
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
const eventLocationPrompt = `
|
||||||
|
You are an agent that extracts events and locations from an image.
|
||||||
|
Your job is to check if an image has an event or a location and use the correct tools to extract this information.
|
||||||
|
|
||||||
|
If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location.
|
||||||
|
|
||||||
|
It is possible that there is no location or event on an image.
|
||||||
|
`
|
||||||
|
|
||||||
|
// TODO: load these tool definitions from backend/tools.json at startup instead of duplicating them in this constant.
|
||||||
|
const TOOLS = `
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "createLocation",
|
||||||
|
"description": "Creates a location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"coordinates": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"address": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["name"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "listLocations",
|
||||||
|
"description": "Lists the locations available",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "createEvent",
|
||||||
|
"description": "Creates a new event",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"datetime": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"locationId": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The ID of the location, available by listLocations"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["name"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
`
|
||||||
|
|
||||||
|
// EventLocationAgent is an alias for AgentClient, so the two names denote the
// same type and share the same fields and method set.
type EventLocationAgent = AgentClient
|
||||||
|
|
||||||
|
func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte) error {
|
||||||
|
var tools any
|
||||||
|
err := json.Unmarshal([]byte(TOOLS), &tools)
|
||||||
|
|
||||||
|
auto := "auto"
|
||||||
|
|
||||||
|
request := AgentRequestBody{
|
||||||
|
Tools: &tools,
|
||||||
|
ToolChoice: &auto,
|
||||||
|
Model: "pixtral-12b-2409",
|
||||||
|
Temperature: 0.3,
|
||||||
|
ResponseFormat: ResponseFormat{
|
||||||
|
Type: "text",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = request.AddSystem(eventLocationPrompt)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println(request)
|
||||||
|
|
||||||
|
request.AddImage(imageName, imageData)
|
||||||
|
|
||||||
|
jsonAiRequest, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpRequest, err := agent.getRequest(jsonAiRequest)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := agent.Do(httpRequest)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println(string(response))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLocationEventAgent() (EventLocationAgent, error) {
|
||||||
|
agent, err := CreateAgentClient(eventLocationPrompt)
|
||||||
|
if err != nil {
|
||||||
|
return EventLocationAgent{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return agent, nil
|
||||||
|
}
|
@ -12,6 +12,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
|
"screenmark/screenmark/agents"
|
||||||
"screenmark/screenmark/models"
|
"screenmark/screenmark/models"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -23,21 +24,21 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type TestAiClient struct {
|
type TestAiClient struct {
|
||||||
ImageInfo ImageInfo
|
ImageInfo agents.ImageInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
|
func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (agents.ImageInfo, error) {
|
||||||
return client.ImageInfo, nil
|
return client.ImageInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAiClient() (AiClient, error) {
|
func GetAiClient() (agents.AiClient, error) {
|
||||||
mode := os.Getenv("MODE")
|
mode := os.Getenv("MODE")
|
||||||
if mode == "TESTING" {
|
if mode == "TESTING" {
|
||||||
address := "10 Downing Street"
|
address := "10 Downing Street"
|
||||||
description := "Cheese and Crackers"
|
description := "Cheese and Crackers"
|
||||||
|
|
||||||
return TestAiClient{
|
return TestAiClient{
|
||||||
ImageInfo: ImageInfo{
|
ImageInfo: agents.ImageInfo{
|
||||||
Tags: []string{"tag"},
|
Tags: []string{"tag"},
|
||||||
Links: []string{"links"},
|
Links: []string{"links"},
|
||||||
Text: []string{"text"},
|
Text: []string{"text"},
|
||||||
@ -55,7 +56,7 @@ func GetAiClient() (AiClient, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return CreateOpenAiClient()
|
return agents.CreateAgentClient(agents.PROMPT)
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@ -107,6 +108,11 @@ func main() {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
locationAgent, err := agents.NewLocationEventAgent()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
image, err := imageModel.GetToProcessWithData(ctx, imageId)
|
image, err := imageModel.GetToProcessWithData(ctx, imageId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Failed to GetToProcessWithData")
|
log.Println("Failed to GetToProcessWithData")
|
||||||
@ -114,6 +120,9 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Println("Calling locationAgent!")
|
||||||
|
locationAgent.GetLocations(image.Image.ImageName, image.Image.Image)
|
||||||
|
|
||||||
imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image)
|
imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Failed to GetImageInfo")
|
log.Println("Failed to GetImageInfo")
|
||||||
|
58
backend/tools.json
Normal file
58
backend/tools.json
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "createLocation",
|
||||||
|
"description": "Creates a location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"coordinates": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"address": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["name"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "listLocations",
|
||||||
|
"description": "Lists the locations available",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "createEvent",
|
||||||
|
"description": "Creates a new event",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"datetime": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"locationId": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The ID of the location, available by `listLocations`"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["name"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
Reference in New Issue
Block a user