diff --git a/backend/agents/agent.go b/backend/agents/agent.go index 346cfe3..cbb5b2f 100644 --- a/backend/agents/agent.go +++ b/backend/agents/agent.go @@ -1,4 +1,4 @@ -package main +package agents import ( "bytes" @@ -33,6 +33,9 @@ type AgentRequestBody struct { Temperature float64 `json:"temperature"` ResponseFormat ResponseFormat `json:"response_format"` + Tools *any `json:"tools,omitempty"` + ToolChoice *string `json:"tool_choice,omitempty"` + AgentMessages } @@ -145,10 +148,6 @@ You are an image information extractor. The user will provide you with screensho that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags. Be sure to extract every link (URL) that you find. Use generic tags. - -You also want to extract events in the image, and the location/locations this event is hosted in. - -You need to extract locations in the image if any exist, and give the approximate coordinates for this location. ` const RESPONSE_FORMAT = ` @@ -262,7 +261,7 @@ const RESPONSE_FORMAT = ` } ` -func CreateAgentClient() (AgentClient, error) { +func CreateAgentClient(prompt string) (AgentClient, error) { apiKey := os.Getenv(OPENAI_API_KEY) if len(apiKey) == 0 { @@ -272,7 +271,7 @@ func CreateAgentClient() (AgentClient, error) { return AgentClient{ apiKey: apiKey, url: "https://api.mistral.ai/v1/chat/completions", - systemPrompt: PROMPT, + systemPrompt: prompt, Do: func(req *http.Request) (*http.Response, error) { client := &http.Client{} return client.Do(req) @@ -309,11 +308,15 @@ func getCompletionsForImage(model string, temperature float64, prompt string, im return request, err } + log.Println(request) + err = request.AddImage(imageName, imageData) if err != nil { return request, err } + request.Tools = nil + return request, nil } @@ -370,8 +373,6 @@ func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (Imag return ImageInfo{}, err } - log.Println(jsonSchema) - aiRequest.ResponseFormat = ResponseFormat{ Type: "json_schema", JsonSchema: jsonSchema, diff --git a/backend/agents/agent_test.go b/backend/agents/agent_test.go index f6c1a28..770e0ab 100644 --- a/backend/agents/agent_test.go +++ b/backend/agents/agent_test.go @@ -1,4 +1,4 @@ -package main +package agents import ( "bytes" diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go new file mode 100644 index 0000000..2ec4d23 --- /dev/null +++ b/backend/agents/event_location_agent.go @@ -0,0 +1,139 @@ +package agents + +import ( + "encoding/json" + "io" + "log" +) + +const eventLocationPrompt = ` +You are an agent that extracts events and locations from an image. +Your job is to check if an image has an event or a location and use the correct tools to extract this information. + +If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location. + +It is possible that there is no location or event on an image. +` + +// TODO: this should be read directly from a file on load. +const TOOLS = ` +[ + { + "type": "function", + "function": { + "name": "createLocation", + "description": "Creates a location", + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "coordinates": { + "type": "string" + }, + "address": { + "type": "string" + } + }, + "required": ["name"] + } + } + }, + { + "type": "function", + "function": { + "name": "listLocations", + "description": "Lists the locations available", + "parameters": { + "type": "object", + "properties": {} + } + } + }, + { + "type": "function", + "function": { + "name": "createEvent", + "description": "Creates a new event", + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "datetime": { + "type": "string" + }, + "locationId": { + "type": "string", + "description": "The ID of the location, available by listLocations" + } + }, + "required": ["name"] + } + } + } +] +` + +type EventLocationAgent = AgentClient + +func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte) error { + var tools any + err := json.Unmarshal([]byte(TOOLS), &tools) + + auto := "auto" + + request := AgentRequestBody{ + Tools: &tools, + ToolChoice: &auto, + Model: "pixtral-12b-2409", + Temperature: 0.3, + ResponseFormat: ResponseFormat{ + Type: "text", + }, + } + + err = request.AddSystem(eventLocationPrompt) + if err != nil { + return err + } + + log.Println(request) + + request.AddImage(imageName, imageData) + + jsonAiRequest, err := json.Marshal(request) + if err != nil { + return err + } + + httpRequest, err := agent.getRequest(jsonAiRequest) + if err != nil { + return err + } + + resp, err := agent.Do(httpRequest) + if err != nil { + return err + } + + response, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + + log.Println(string(response)) + + return nil +} + +func NewLocationEventAgent() (EventLocationAgent, error) { + agent, err := CreateAgentClient(eventLocationPrompt) + if err != nil { + return EventLocationAgent{}, err + } + + return agent, nil +} diff --git a/backend/main.go b/backend/main.go index 4063404..e5e5b06 100644 --- a/backend/main.go +++ b/backend/main.go @@ -12,6 +12,7 @@ import ( "os" "path/filepath" "screenmark/screenmark/.gen/haystack/haystack/model" + "screenmark/screenmark/agents" "screenmark/screenmark/models" "time" @@ -23,21 +24,21 @@ import ( ) type TestAiClient struct { - ImageInfo ImageInfo + ImageInfo agents.ImageInfo } -func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) { +func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (agents.ImageInfo, error) { return client.ImageInfo, nil } -func GetAiClient() (AiClient, error) { +func GetAiClient() (agents.AiClient, error) { mode := os.Getenv("MODE") if mode == "TESTING" { address := "10 Downing Street" description := "Cheese and Crackers" return TestAiClient{ - ImageInfo: ImageInfo{ + ImageInfo: agents.ImageInfo{ Tags: []string{"tag"}, Links: []string{"links"}, Text: []string{"text"}, @@ -55,7 +56,7 @@ func GetAiClient() (AiClient, error) { }, nil } - return CreateOpenAiClient() + return agents.CreateAgentClient(agents.PROMPT) } func main() { @@ -107,6 +108,11 @@ func main() { panic(err) } + locationAgent, err := agents.NewLocationEventAgent() + if err != nil { + panic(err) + } + image, err := imageModel.GetToProcessWithData(ctx, imageId) if err != nil { log.Println("Failed to GetToProcessWithData") @@ -114,6 +120,9 @@ func main() { return } + log.Println("Calling locationAgent!") + locationAgent.GetLocations(image.Image.ImageName, image.Image.Image) + imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image) if err != nil { log.Println("Failed to GetImageInfo") diff --git a/backend/tools.json b/backend/tools.json new file mode 100644 index 0000000..13652dc --- /dev/null +++ b/backend/tools.json @@ -0,0 +1,58 @@ +[ + { + "type": "function", + "function": { + "name": "createLocation", + "description": "Creates a location", + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "coordinates": { + "type": "string" + }, + "address": { + "type": "string" + } + }, + "required": ["name"] + } + } + }, + { + "type": "function", + "function": { + "name": "listLocations", + "description": "Lists the locations available", + "parameters": { + "type": "object", + "properties": {} + } + } + }, + { + "type": "function", + "function": { + "name": "createEvent", + "description": "Creates a new event", + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "datetime": { + "type": "string" + }, + "locationId": { + "type": "string", + "description": "The ID of the location, available by `listLocations`" + } + }, + "required": ["name"] + } + } + } +]