refactor(naming): using Agent instead of openai
This commit is contained in:
403
backend/agents/agent.go
Normal file
403
backend/agents/agent.go
Normal file
@@ -0,0 +1,403 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||
)
|
||||
|
||||
type ImageInfo struct {
|
||||
Tags []string `json:"tags"`
|
||||
Text []string `json:"text"`
|
||||
Links []string `json:"links"`
|
||||
|
||||
Locations []model.Locations `json:"locations"`
|
||||
Events []model.Events `json:"events"`
|
||||
}
|
||||
|
||||
// ResponseFormat is the "response_format" object of a request body.
type ResponseFormat struct {
	// Type names the format kind; this file only ever sets "json_schema".
	Type string `json:"type"`
	// JsonSchema is marshalled as-is, so a decoded JSON value becomes a nested
	// object while a plain Go string becomes a quoted JSON string.
	JsonSchema any `json:"json_schema"`
}
|
||||
|
||||
type AgentRequestBody struct {
|
||||
Model string `json:"model"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
ResponseFormat ResponseFormat `json:"response_format"`
|
||||
|
||||
AgentMessages
|
||||
}
|
||||
|
||||
type AgentMessages struct {
|
||||
Messages []AgentMessage `json:"messages"`
|
||||
}
|
||||
|
||||
// AgentMessage is implemented by every message kind that can be stored in
// AgentMessages.
type AgentMessage interface {
	// MessageToJson returns the message encoded as JSON.
	MessageToJson() ([]byte, error)
}
|
||||
|
||||
// AgentTextMessage is a message whose content is a single plain string.
type AgentTextMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// MessageToJson encodes the message as a JSON object with "role" and
// "content" keys.
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
	// TODO: Validate the `Role`.
	encoded, err := json.Marshal(textContent)
	return encoded, err
}
|
||||
|
||||
type AgentArrayMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []AgentContent `json:"content"`
|
||||
}
|
||||
|
||||
func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
|
||||
return json.Marshal(arrayContent)
|
||||
}
|
||||
|
||||
func (content *AgentMessages) AddImage(imageName string, image []byte) error {
|
||||
extension := filepath.Ext(imageName)
|
||||
if len(extension) == 0 {
|
||||
// TODO: could also validate for image types we support.
|
||||
return errors.New("Image does not have extension")
|
||||
}
|
||||
|
||||
extension = extension[1:]
|
||||
|
||||
encodedString := base64.StdEncoding.EncodeToString(image)
|
||||
|
||||
arrayMessage := AgentArrayMessage{Role: ROLE_USER, Content: make([]AgentContent, 1)}
|
||||
arrayMessage.Content[0] = AgentImage{
|
||||
ImageType: IMAGE_TYPE,
|
||||
ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
|
||||
}
|
||||
|
||||
content.Messages = append(content.Messages, arrayMessage)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (content *AgentMessages) AddSystem(prompt string) error {
|
||||
if len(content.Messages) != 0 {
|
||||
return errors.New("You can only add a system prompt at the beginning")
|
||||
}
|
||||
|
||||
content.Messages = append(content.Messages, AgentTextMessage{
|
||||
Role: ROLE_SYSTEM,
|
||||
Content: prompt,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AgentContent is implemented by every part that can appear in the content
// list of an AgentArrayMessage.
type AgentContent interface {
	// ToJson returns the content part encoded as JSON.
	ToJson() ([]byte, error)
}
|
||||
|
||||
// ImageUrl wraps a URL in an object with a single "url" key.
// NOTE(review): nothing in the visible source references this type; confirm
// whether it is still needed.
type ImageUrl struct {
	Url string `json:"url"`
}
|
||||
|
||||
type AgentImage struct {
|
||||
ImageType string `json:"type"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
}
|
||||
|
||||
func (imageMessage AgentImage) ToJson() ([]byte, error) {
|
||||
imageMessage.ImageType = IMAGE_TYPE
|
||||
return json.Marshal(imageMessage)
|
||||
}
|
||||
|
||||
type AiClient interface {
|
||||
GetImageInfo(imageName string, imageData []byte) (ImageInfo, error)
|
||||
}
|
||||
|
||||
// AgentClient holds the settings needed to call the completions endpoint.
type AgentClient struct {
	url            string // endpoint that requests are POSTed to
	apiKey         string // sent as a Bearer token in the Authorization header
	systemPrompt   string // system message placed at the start of each request
	responseFormat string

	// Do sends the HTTP request. It is a field rather than a method so that
	// callers (such as tests) can substitute their own transport.
	Do func(req *http.Request) (*http.Response, error)
}
|
||||
|
||||
// func (client AgentClient) Do(req *http.Request) () {
|
||||
// httpClient := http.Client{}
|
||||
// return httpClient.Do(req)
|
||||
// }
|
||||
|
||||
const (
	// OPENAI_API_KEY is the name of the environment variable that
	// CreateAgentClient reads the API key from.
	OPENAI_API_KEY = "OPENAI_API_KEY"

	// Message roles.
	ROLE_USER   = "user"
	ROLE_SYSTEM = "system"

	// IMAGE_TYPE is the "type" value of an image content part.
	IMAGE_TYPE = "image_url"
)
|
||||
|
||||
// PROMPT is the system prompt that CreateAgentClient installs on the client.
//
// TODO: extract to text file probably
const PROMPT = `
You are an image information extractor. The user will provide you with screenshots and your job is to extract any relevant links and text
that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags.
Be sure to extract every link (URL) that you find.
Use generic tags.

You also want to extract events in the image, and the location/locations this event is hosted in.

You need to extract locations in the image if any exist, and give the approximate coordinates for this location.
`
|
||||
|
||||
// RESPONSE_FORMAT is the JSON schema definition sent as the "json_schema"
// value of a request. It describes an object with "tags", "text" and "links"
// (all required) plus optional "locations" and "events" arrays.
const RESPONSE_FORMAT = `
{
  "name": "image_info",
  "strict": true,
  "schema": {
    "type": "object",
    "title": "image",
    "required": ["tags", "text", "links"],
    "additionalProperties": false,
    "properties": {
      "tags": {
        "type": "array",
        "title": "tags",
        "description": "A list of tags you think the image is relevant to.",
        "items": {
          "type": "string"
        }
      },
      "text": {
        "type": "array",
        "title": "text",
        "description": "A list of sentences the image contains.",
        "items": {
          "type": "string"
        }
      },
      "links": {
        "type": "array",
        "title": "links",
        "description": "A list of all the links you can find in the image.",
        "items": {
          "type": "string"
        }
      },
      "locations": {
        "title": "locations",
        "type": "array",
        "description": "A list of locations you can find on the image, if any",
        "items": {
          "type": "object",
          "required": ["name"],
          "additionalProperties": false,
          "properties": {
            "name": {
              "title": "name",
              "type": "string"
            },
            "coordinates": {
              "title": "coordinates",
              "type": "string"
            },
            "address": {
              "title": "address",
              "type": "string"
            },
            "description": {
              "title": "description",
              "type": "string"
            }
          }
        }
      },
      "events": {
        "title": "events",
        "type": "array",
        "description": "A list of events you find on the image, if any",
        "items": {
          "type": "object",
          "required": ["name"],
          "additionalProperties": false,
          "properties": {
            "name": {
              "type": "string",
              "title": "name"
            },
            "locations": {
              "title": "locations",
              "type": "array",
              "description": "A list of locations on this event, if any",
              "items": {
                "type": "object",
                "required": ["name"],
                "additionalProperties": false,
                "properties": {
                  "name": {
                    "title": "name",
                    "type": "string"
                  },
                  "coordinates": {
                    "title": "coordinates",
                    "type": "string"
                  },
                  "address": {
                    "title": "address",
                    "type": "string"
                  },
                  "description": {
                    "title": "description",
                    "type": "string"
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
`
|
||||
|
||||
func CreateAgentClient() (AgentClient, error) {
|
||||
apiKey := os.Getenv(OPENAI_API_KEY)
|
||||
|
||||
if len(apiKey) == 0 {
|
||||
return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
|
||||
}
|
||||
|
||||
return AgentClient{
|
||||
apiKey: apiKey,
|
||||
url: "https://api.mistral.ai/v1/chat/completions",
|
||||
systemPrompt: PROMPT,
|
||||
Do: func(req *http.Request) (*http.Response, error) {
|
||||
client := &http.Client{}
|
||||
return client.Do(req)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
|
||||
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return req, err
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", "Bearer "+client.apiKey)
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (AgentRequestBody, error) {
|
||||
request := AgentRequestBody{
|
||||
Model: model,
|
||||
Temperature: temperature,
|
||||
ResponseFormat: ResponseFormat{
|
||||
Type: "json_schema",
|
||||
JsonSchema: jsonSchema,
|
||||
},
|
||||
}
|
||||
|
||||
// TODO: Add build pattern here that deals with errors in some internal state?
|
||||
// I want a monad!!!
|
||||
err := request.AddSystem(prompt)
|
||||
if err != nil {
|
||||
return request, err
|
||||
}
|
||||
|
||||
err = request.AddImage(imageName, imageData)
|
||||
if err != nil {
|
||||
return request, err
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
// ResponseChoiceMessage is the message carried by one response choice.
type ResponseChoiceMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// ResponseChoice is one entry of the "choices" array of a response.
type ResponseChoice struct {
	Index        int                   `json:"index"`
	Message      ResponseChoiceMessage `json:"message"`
	FinishReason string                `json:"finish_reason"`
}

// AgentResponse is the subset of the completion response that this package
// decodes; any other keys in the JSON are ignored.
type AgentResponse struct {
	Id      string           `json:"id"`
	Object  string           `json:"object"`
	Choices []ResponseChoice `json:"choices"`
	Created int              `json:"created"`
}
|
||||
|
||||
// TODO: add usage parsing
|
||||
func parseAgentResponse(jsonResponse []byte) (ImageInfo, error) {
|
||||
response := AgentResponse{}
|
||||
|
||||
err := json.Unmarshal(jsonResponse, &response)
|
||||
if err != nil {
|
||||
return ImageInfo{}, err
|
||||
}
|
||||
|
||||
if len(response.Choices) != 1 {
|
||||
log.Println(string(jsonResponse))
|
||||
return ImageInfo{}, errors.New("Expected exactly one choice.")
|
||||
}
|
||||
|
||||
imageInfo := ImageInfo{}
|
||||
err = json.Unmarshal([]byte(response.Choices[0].Message.Content), &imageInfo)
|
||||
if err != nil {
|
||||
return ImageInfo{}, errors.New("Could not parse content into image type.")
|
||||
}
|
||||
|
||||
return imageInfo, nil
|
||||
}
|
||||
|
||||
func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
|
||||
aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData)
|
||||
if err != nil {
|
||||
return ImageInfo{}, err
|
||||
}
|
||||
|
||||
var jsonSchema any
|
||||
err = json.Unmarshal([]byte(RESPONSE_FORMAT), &jsonSchema)
|
||||
if err != nil {
|
||||
return ImageInfo{}, err
|
||||
}
|
||||
|
||||
log.Println(jsonSchema)
|
||||
|
||||
aiRequest.ResponseFormat = ResponseFormat{
|
||||
Type: "json_schema",
|
||||
JsonSchema: jsonSchema,
|
||||
}
|
||||
|
||||
jsonAiRequest, err := json.Marshal(aiRequest)
|
||||
if err != nil {
|
||||
return ImageInfo{}, err
|
||||
}
|
||||
|
||||
request, err := client.getRequest(jsonAiRequest)
|
||||
if err != nil {
|
||||
return ImageInfo{}, err
|
||||
}
|
||||
|
||||
resp, err := client.Do(request)
|
||||
if err != nil {
|
||||
return ImageInfo{}, err
|
||||
}
|
||||
|
||||
response, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ImageInfo{}, err
|
||||
}
|
||||
|
||||
log.Println(string(response))
|
||||
|
||||
return parseAgentResponse(response)
|
||||
}
|
||||
213
backend/agents/agent_test.go
Normal file
213
backend/agents/agent_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMessageBuilder(t *testing.T) {
|
||||
content := AgentMessages{}
|
||||
|
||||
err := content.AddSystem("Some prompt")
|
||||
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if len(content.Messages) != 1 {
|
||||
t.Logf("Expected length 1, got %d.\n", len(content.Messages))
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageBuilderImage(t *testing.T) {
|
||||
content := AgentMessages{}
|
||||
|
||||
prompt := "some prompt"
|
||||
imageTitle := "image.png"
|
||||
data := []byte("some data")
|
||||
|
||||
content.AddSystem(prompt)
|
||||
content.AddImage(imageTitle, data)
|
||||
|
||||
if len(content.Messages) != 2 {
|
||||
t.Logf("Expected length 2, got %d.\n", len(content.Messages))
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
promptMessage, ok := content.Messages[0].(AgentTextMessage)
|
||||
if !ok {
|
||||
t.Logf("Expected text content message, got %T\n", content.Messages[0])
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if promptMessage.Role != ROLE_SYSTEM {
|
||||
t.Log("Prompt message role is incorrect.")
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if promptMessage.Content != prompt {
|
||||
t.Log("Prompt message content is incorrect.")
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
arrayContentMessage, ok := content.Messages[1].(AgentArrayMessage)
|
||||
if !ok {
|
||||
t.Logf("Expected text content message, got %T\n", content.Messages[1])
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if arrayContentMessage.Role != ROLE_USER {
|
||||
t.Log("Array content message role is incorrect.")
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if len(arrayContentMessage.Content) != 1 {
|
||||
t.Logf("Expected length 1, got %d.\n", len(arrayContentMessage.Content))
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
imageContent, ok := arrayContentMessage.Content[0].(AgentImage)
|
||||
if !ok {
|
||||
t.Logf("Expected text content message, got %T\n", arrayContentMessage.Content[0])
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
base64data := base64.StdEncoding.EncodeToString(data)
|
||||
url := fmt.Sprintf("data:image/%s;base64,%s", "png", base64data)
|
||||
|
||||
if imageContent.ImageUrl != url {
|
||||
t.Logf("Expected %s, but got %s.\n", url, imageContent.ImageUrl)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestFullImageRequest(t *testing.T) {
|
||||
request, err := getCompletionsForImage("model", 0.1, "You are an assistant", "image.png", "", []byte("some data"))
|
||||
if err != nil {
|
||||
t.Log(request)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
expectedJson := `{"model":"model","temperature":0.1,"response_format":{"type":"json_schema","json_schema":""},"messages":[{"role":"system","content":"You are an assistant"},{"role":"user","content":[{"type":"image_url","image_url":{"url":"data:image/png;base64,c29tZSBkYXRh"}}]}]}`
|
||||
|
||||
if string(jsonData) != expectedJson {
|
||||
t.Logf("Expected:\n%s\n Got:\n%s\n", expectedJson, string(jsonData))
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponse(t *testing.T) {
|
||||
testResponse := `{"tags": ["tag1", "tag2"], "text": ["text1"], "links": []}`
|
||||
buffer := bytes.NewReader([]byte(testResponse))
|
||||
|
||||
body := io.NopCloser(buffer)
|
||||
|
||||
client := AgentClient{
|
||||
url: "http://localhost:1234",
|
||||
apiKey: "some-key",
|
||||
Do: func(_req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{Body: body}, nil
|
||||
},
|
||||
}
|
||||
|
||||
info, err := client.GetImageInfo("image.png", []byte("some data"))
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if len(info.Tags) != 2 || len(info.Text) != 1 || len(info.Links) != 0 {
|
||||
t.Logf("Some lengths are wrong.\nTags: %d\nText: %d\nLinks: %d\n", len(info.Tags), len(info.Text), len(info.Links))
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if info.Tags[0] != "tag1" {
|
||||
t.Log("0th tag is wrong.")
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if info.Tags[1] != "tag2" {
|
||||
t.Log("1th tag is wrong.")
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if info.Text[0] != "text1" {
|
||||
t.Log("0th text is wrong.")
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseParsing(t *testing.T) {
|
||||
response := `{
|
||||
"id": "chatcmpl-B4XgiHcd7A2nyK7eyARdggSvfFuWQ",
|
||||
"object": "chat.completion",
|
||||
"created": 1740422508,
|
||||
"model": "gpt-4o-mini-2024-07-18",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "{\"links\":[\"link\"],\"tags\":[\"tag\"],\"text\":[\"text\"]}",
|
||||
"refusal": null
|
||||
},
|
||||
"logprobs": null,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 775,
|
||||
"completion_tokens": 33,
|
||||
"total_tokens": 808,
|
||||
"prompt_tokens_details": {
|
||||
"cached_tokens": 0,
|
||||
"audio_tokens": 0
|
||||
},
|
||||
"completion_tokens_details": {
|
||||
"reasoning_tokens": 0,
|
||||
"audio_tokens": 0,
|
||||
"accepted_prediction_tokens": 0,
|
||||
"rejected_prediction_tokens": 0
|
||||
}
|
||||
},
|
||||
"service_tier": "default",
|
||||
"system_fingerprint": "fp_7fcd609668"
|
||||
}`
|
||||
|
||||
imageParsed, err := parseAgentResponse([]byte(response))
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if len(imageParsed.Links) != 1 || imageParsed.Links[0] != "link" {
|
||||
t.Log(imageParsed)
|
||||
t.Log("Should have one link called 'link'.")
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if len(imageParsed.Tags) != 1 || imageParsed.Tags[0] != "tag" {
|
||||
t.Log(imageParsed)
|
||||
t.Log("Should have one tag called 'tag'.")
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if len(imageParsed.Text) != 1 || imageParsed.Text[0] != "text" {
|
||||
t.Log(imageParsed)
|
||||
t.Log("Should have one text called 'text'.")
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user