482 lines
13 KiB
Go

package agents
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"screenmark/screenmark/.gen/haystack/haystack/model"
)
// ImageInfo is the structured data extracted from an image. It is the type
// parseAgentResponse decodes the model's message content into, and its JSON
// keys match the property names declared in RESPONSE_FORMAT.
type ImageInfo struct {
// Tags holds the tags assigned to the image.
Tags []string `json:"tags"`
// Text holds the text found in the image.
Text []string `json:"text"`
// Links holds the links (URLs) found in the image.
Links []string `json:"links"`
// Locations holds the locations found in the image, if any.
Locations []model.Locations `json:"locations"`
// Events holds the events found in the image, if any.
Events []model.Events `json:"events"`
}
// ResponseFormat is the "response_format" section of a request body.
type ResponseFormat struct {
// Type names the kind of response format; this file always sets it to
// "json_schema".
Type string `json:"type"`
// JsonSchema is the schema payload. It is typed as any, so whatever value
// is stored here is encoded as-is (a Go string is encoded as a JSON string,
// a decoded map as a JSON object).
JsonSchema any `json:"json_schema"`
}
// AgentRequestBody is the JSON body sent to the chat completions endpoint.
type AgentRequestBody struct {
// Model is the identifier of the model to query.
Model string `json:"model"`
// Temperature is the sampling temperature passed to the model.
Temperature float64 `json:"temperature"`
// ResponseFormat describes the output format requested from the model.
ResponseFormat ResponseFormat `json:"response_format"`
// Tools is omitted from the JSON when nil.
Tools *any `json:"tools,omitempty"`
// ToolChoice is omitted from the JSON when nil.
ToolChoice *string `json:"tool_choice,omitempty"`
// ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
// AgentMessages is embedded, so its Messages field is encoded at the top
// level of the body under "messages" and its Add* methods are promoted.
AgentMessages
}
// AgentMessages is the ordered list of conversation messages of a request.
// Messages are appended through AddSystem, AddText, AddImage and AddToolCall.
type AgentMessages struct {
Messages []AgentMessage `json:"messages"`
}
// AgentMessage is implemented by every message kind that can be stored in
// AgentMessages (text, assistant tool call and array/content messages).
type AgentMessage interface {
// MessageToJson returns the JSON encoding of the message.
MessageToJson() ([]byte, error)
}
// AgentTextMessage is a message whose content is a single string.
type AgentTextMessage struct {
// Role is the message role, e.g. ROLE_SYSTEM or ROLE_USER.
Role string `json:"role"`
// Content is the message text.
Content string `json:"content"`
// ToolCallId is omitted from the JSON when empty.
// NOTE(review): presumably the id of the tool call this message answers — confirm against the API.
ToolCallId string `json:"tool_call_id,omitempty"`
// Name is omitted from the JSON when empty.
Name string `json:"name,omitempty"`
}
// MessageToJson encodes the text message as JSON according to its struct tags.
// It returns whatever json.Marshal returns, including its error.
func (m AgentTextMessage) MessageToJson() ([]byte, error) {
	// TODO: Validate the `Role`.
	encoded, err := json.Marshal(m)
	return encoded, err
}
// AgentAssistantToolCall is a message that carries a list of tool calls.
// AgentClient.Request appends one with Role "assistant" when a response
// contains tool calls.
type AgentAssistantToolCall struct {
// Role is the message role.
Role string `json:"role"`
// Content is the message text; Request leaves it empty.
Content string `json:"content"`
// ToolCalls lists the tool calls attached to the message.
ToolCalls []ToolCall `json:"tool_calls"`
}
// MessageToJson encodes the tool-call message as JSON according to its struct
// tags. It returns whatever json.Marshal returns, including its error.
func (m AgentAssistantToolCall) MessageToJson() ([]byte, error) {
	encoded, err := json.Marshal(m)
	return encoded, err
}
// AgentArrayMessage is a message whose content is a list of content parts
// rather than a single string. AddImage builds one with a single AgentImage.
type AgentArrayMessage struct {
// Role is the message role.
Role string `json:"role"`
// Content holds the content parts of the message.
Content []AgentContent `json:"content"`
}
// MessageToJson encodes the array message, including its content parts, as
// JSON. It returns whatever json.Marshal returns, including its error.
func (m AgentArrayMessage) MessageToJson() ([]byte, error) {
	encoded, err := json.Marshal(m)
	return encoded, err
}
// AddText appends the given text message to the end of the message list.
func (content *AgentMessages) AddText(message AgentTextMessage) {
	var entry AgentMessage = message
	content.Messages = append(content.Messages, entry)
}
// AddToolCall appends the given assistant tool-call message to the end of the
// message list.
func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) {
	var entry AgentMessage = toolCall
	content.Messages = append(content.Messages, entry)
}
// AddImage appends a user message containing the image as a base64 data URL.
//
// The media subtype of the data URL is the file extension of imageName taken
// verbatim (without the dot), so "shot.png" becomes "data:image/png;base64,...".
//
// It returns an error when imageName has no extension. A name that ends in a
// bare dot (e.g. "shot.") is treated the same way, because it would otherwise
// produce a data URL with an empty media subtype.
func (content *AgentMessages) AddImage(imageName string, image []byte) error {
	// filepath.Ext returns either "" or a string starting with ".", so a
	// length of one means there is nothing after the dot.
	extension := filepath.Ext(imageName)
	if len(extension) <= 1 {
		// TODO: could also validate for image types we support.
		return errors.New("Image does not have extension")
	}
	subtype := extension[1:]

	encoded := base64.StdEncoding.EncodeToString(image)
	message := AgentArrayMessage{
		Role: ROLE_USER,
		Content: []AgentContent{
			AgentImage{
				ImageType: IMAGE_TYPE,
				ImageUrl:  fmt.Sprintf("data:image/%s;base64,%s", subtype, encoded),
			},
		},
	}

	content.Messages = append(content.Messages, message)
	return nil
}
// AddSystem adds prompt as a system message. The system message must be the
// first message, so an error is returned if any message was already added.
func (content *AgentMessages) AddSystem(prompt string) error {
	if len(content.Messages) > 0 {
		return errors.New("You can only add a system prompt at the beginning")
	}

	systemMessage := AgentTextMessage{Role: ROLE_SYSTEM, Content: prompt}
	content.Messages = append(content.Messages, systemMessage)
	return nil
}
// AgentContent is implemented by the content parts that can be placed in an
// AgentArrayMessage; AgentImage is the implementation in this file.
type AgentContent interface {
// ToJson returns the JSON encoding of the content part.
ToJson() ([]byte, error)
}
// ImageUrl wraps a URL under the JSON key "url".
// NOTE(review): nothing in the visible part of this file uses this type;
// AgentImage stores its URL as a plain string instead — confirm whether it is
// still needed.
type ImageUrl struct {
Url string `json:"url"`
}
// AgentImage is an image content part of an AgentArrayMessage.
type AgentImage struct {
// ImageType is encoded as "type"; AddImage and ToJson set it to IMAGE_TYPE.
ImageType string `json:"type"`
// ImageUrl is the image location; AddImage fills it with a base64 data URL.
ImageUrl string `json:"image_url"`
}
// ToJson encodes the image part as JSON, always emitting IMAGE_TYPE as its
// type regardless of the value stored in ImageType. The receiver is a copy, so
// the caller's value is not modified.
func (imageMessage AgentImage) ToJson() ([]byte, error) {
	normalized := imageMessage
	normalized.ImageType = IMAGE_TYPE
	return json.Marshal(normalized)
}
// AiClient abstracts a client able to extract ImageInfo from an image.
// AgentClient implements it.
type AiClient interface {
// GetImageInfo returns the information extracted from the image whose
// file name is imageName and whose raw bytes are imageData.
GetImageInfo(imageName string, imageData []byte) (ImageInfo, error)
}
// AgentClient talks to a chat completions HTTP endpoint.
type AgentClient struct {
// url is the endpoint requests are POSTed to.
url string
// apiKey is sent as a Bearer token in the Authorization header.
apiKey string
// systemPrompt is the system prompt used by GetImageInfo.
systemPrompt string
// responseFormat is not set by CreateAgentClient.
// NOTE(review): appears unused in the visible code — confirm.
responseFormat string
// Do executes an HTTP request. CreateAgentClient sets it to a function
// backed by http.Client.
// NOTE(review): presumably a field so tests can substitute it — confirm.
Do func(req *http.Request) (*http.Response, error)
}
// func (client AgentClient) Do(req *http.Request) () {
// httpClient := http.Client{}
// return httpClient.Do(req)
// }
// OPENAI_API_KEY is the name of the environment variable CreateAgentClient
// reads the API key from.
// NOTE(review): the endpoint configured in CreateAgentClient is
// api.mistral.ai, so the variable name may be misleading — confirm.
const OPENAI_API_KEY = "OPENAI_API_KEY"
// ROLE_USER is the role assigned to the image message built by AddImage.
const ROLE_USER = "user"
// ROLE_SYSTEM is the role assigned to the message built by AddSystem.
const ROLE_SYSTEM = "system"
// IMAGE_TYPE is the "type" value of AgentImage content parts.
const IMAGE_TYPE = "image_url"
// PROMPT is a system prompt instructing the model to extract links, text and
// a small set of generic tags from a screenshot.
// NOTE(review): the visible code never references PROMPT directly;
// CreateAgentClient takes the prompt as a parameter — presumably callers pass
// this constant, confirm.
// TODO: extract to text file probably
const PROMPT = `
You are an image information extractor. The user will provide you with screenshots and your job is to extract any relevant links and text
that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags.
Be sure to extract every link (URL) that you find.
Use generic tags.
`
// RESPONSE_FORMAT is the JSON text of the "image_info" response schema.
// GetImageInfo decodes it and sends it as the request's
// response_format.json_schema value.
//
// The schema describes an object with "tags", "text" and "links" (all
// required string arrays) plus optional "locations" and "events" arrays.
// Location objects require only "name"; event objects require only "name" and
// may carry their own nested "locations" list.
const RESPONSE_FORMAT = `
{
"name": "image_info",
"strict": true,
"schema": {
"type": "object",
"title": "image",
"required": ["tags", "text", "links"],
"additionalProperties": false,
"properties": {
"tags": {
"type": "array",
"title": "tags",
"description": "A list of tags you think the image is relevant to.",
"items": {
"type": "string"
}
},
"text": {
"type": "array",
"title": "text",
"description": "A list of sentences the image contains.",
"items": {
"type": "string"
}
},
"links": {
"type": "array",
"title": "links",
"description": "A list of all the links you can find in the image.",
"items": {
"type": "string"
}
},
"locations": {
"title": "locations",
"type": "array",
"description": "A list of locations you can find on the image, if any",
"items": {
"type": "object",
"required": ["name"],
"additionalProperties": false,
"properties": {
"name": {
"title": "name",
"type": "string"
},
"coordinates": {
"title": "coordinates",
"type": "string"
},
"address": {
"title": "address",
"type": "string"
},
"description": {
"title": "description",
"type": "string"
}
}
}
},
"events": {
"title": "events",
"type": "array",
"description": "A list of events you find on the image, if any",
"items": {
"type": "object",
"required": ["name"],
"additionalProperties": false,
"properties": {
"name": {
"type": "string",
"title": "name"
},
"locations": {
"title": "locations",
"type": "array",
"description": "A list of locations on this event, if any",
"items": {
"type": "object",
"required": ["name"],
"additionalProperties": false,
"properties": {
"name": {
"title": "name",
"type": "string"
},
"coordinates": {
"title": "coordinates",
"type": "string"
},
"address": {
"title": "address",
"type": "string"
},
"description": {
"title": "description",
"type": "string"
}
}
}
}
}
}
}
}
}
}
`
// CreateAgentClient builds an AgentClient that uses prompt as its system
// prompt and authenticates with the API key found in the environment variable
// named by OPENAI_API_KEY.
//
// It returns an error when that variable is unset or empty.
func CreateAgentClient(prompt string) (AgentClient, error) {
	key := os.Getenv(OPENAI_API_KEY)
	if key == "" {
		return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
	}

	agent := AgentClient{
		url:          "https://api.mistral.ai/v1/chat/completions",
		apiKey:       key,
		systemPrompt: prompt,
	}
	// NOTE(review): the zero-value http.Client has no timeout, so a stalled
	// connection can block a caller indefinitely — consider configuring one.
	agent.Do = func(req *http.Request) (*http.Response, error) {
		return (&http.Client{}).Do(req)
	}
	return agent, nil
}
// getRequest builds the POST request for the client's endpoint with body as
// its payload, a Bearer Authorization header carrying the API key and a JSON
// Content-Type header.
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
	payload := bytes.NewBuffer(body)
	request, err := http.NewRequest(http.MethodPost, client.url, payload)
	if err != nil {
		return request, err
	}

	headers := request.Header
	headers.Add("Authorization", "Bearer "+client.apiKey)
	headers.Add("Content-Type", "application/json")
	return request, nil
}
// getCompletionsForImage assembles a request body for the given model and
// temperature containing, in order, a system message with prompt and a user
// message with the image. The request is logged after the system message has
// been added.
//
// jsonSchema is stored in the response format as a Go string, so encoding the
// returned body as-is emits it as a JSON string; GetImageInfo replaces it with
// the decoded schema before sending.
//
// On failure the partially built request is returned together with the error.
func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (AgentRequestBody, error) {
	format := ResponseFormat{Type: "json_schema", JsonSchema: jsonSchema}
	request := AgentRequestBody{
		Model:          model,
		Temperature:    temperature,
		ResponseFormat: format,
	}

	// TODO: Add builder pattern here that deals with errors in some internal state?
	if err := request.AddSystem(prompt); err != nil {
		return request, err
	}
	log.Println(request)

	if err := request.AddImage(imageName, imageData); err != nil {
		return request, err
	}

	// Image extraction does not use tools.
	request.Tools = nil
	return request, nil
}
// FunctionCall is the "function" part of a ToolCall.
type FunctionCall struct {
// Name is the name of the function.
Name string `json:"name"`
// Arguments holds the call arguments as a string.
// NOTE(review): presumably JSON-encoded — confirm against the API.
Arguments string `json:"arguments"`
}
// ToolCall is one tool call found in a response message; Request copies these
// into the follow-up assistant message.
type ToolCall struct {
// Index is decoded from the "index" key.
Index int `json:"index"`
// Id is decoded from the "id" key.
Id string `json:"id"`
// Function describes the function to call.
Function FunctionCall `json:"function"`
}
// ResponseChoiceMessage is the message of a response choice.
type ResponseChoiceMessage struct {
// Role is the role of the message author.
Role string `json:"role"`
// Content is the message text; parseAgentResponse decodes it as ImageInfo JSON.
Content string `json:"content"`
// ToolCalls lists the tool calls carried by the message, if any.
ToolCalls []ToolCall `json:"tool_calls"`
}
// ResponseChoice is one entry of the "choices" list of a response.
type ResponseChoice struct {
// Index is decoded from the "index" key.
Index int `json:"index"`
// Message is the message produced for this choice.
Message ResponseChoiceMessage `json:"message"`
// FinishReason is decoded from the "finish_reason" key.
FinishReason string `json:"finish_reason"`
}
// AgentResponse is the decoded body of a chat completions response.
type AgentResponse struct {
// Id is decoded from the "id" key.
Id string `json:"id"`
// Object is decoded from the "object" key.
Object string `json:"object"`
// Choices holds the returned choices.
Choices []ResponseChoice `json:"choices"`
// Created is decoded from the "created" key.
// NOTE(review): presumably a Unix timestamp — confirm against the API.
Created int `json:"created"`
}
// TODO: add usage parsing

// parseAgentResponse decodes a raw chat completions response and then decodes
// the content of its single choice as ImageInfo.
//
// It returns an error when jsonResponse is not valid JSON, when the response
// does not contain exactly one choice (the raw response is logged in that
// case), or when the choice's content cannot be decoded as ImageInfo. The last
// error wraps the underlying decode error so callers can see why the content
// was rejected.
func parseAgentResponse(jsonResponse []byte) (ImageInfo, error) {
	var response AgentResponse
	if err := json.Unmarshal(jsonResponse, &response); err != nil {
		return ImageInfo{}, err
	}

	if len(response.Choices) != 1 {
		log.Println(string(jsonResponse))
		return ImageInfo{}, errors.New("Expected exactly one choice.")
	}

	var imageInfo ImageInfo
	content := response.Choices[0].Message.Content
	if err := json.Unmarshal([]byte(content), &imageInfo); err != nil {
		return ImageInfo{}, fmt.Errorf("Could not parse content into image type: %w", err)
	}
	return imageInfo, nil
}
// Request sends request to the client's endpoint and returns the decoded
// response. The raw response body is logged once it has been decoded.
//
// When the first choice of the response carries tool calls, an assistant
// message holding those tool calls is appended to request, so the caller can
// continue the conversation with the same request value.
//
// It returns an error when the request cannot be encoded or built, when the
// HTTP call fails, when the body cannot be read or decoded, or when the
// response contains no choices (for example an API error payload).
func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) {
	jsonAiRequest, err := json.Marshal(request)
	if err != nil {
		return AgentResponse{}, err
	}

	httpRequest, err := client.getRequest(jsonAiRequest)
	if err != nil {
		return AgentResponse{}, err
	}

	resp, err := client.Do(httpRequest)
	if err != nil {
		return AgentResponse{}, err
	}
	// The body must be closed so the underlying connection can be released.
	defer resp.Body.Close()

	response, err := io.ReadAll(resp.Body)
	if err != nil {
		return AgentResponse{}, err
	}

	agentResponse := AgentResponse{}
	if err := json.Unmarshal(response, &agentResponse); err != nil {
		return AgentResponse{}, err
	}
	log.Println(string(response))

	// Guard the Choices[0] access below: a body without choices would
	// otherwise cause an index-out-of-range panic.
	if len(agentResponse.Choices) == 0 {
		return AgentResponse{}, fmt.Errorf("agent response contained no choices (HTTP status %d)", resp.StatusCode)
	}

	toolCalls := agentResponse.Choices[0].Message.ToolCalls
	if len(toolCalls) > 0 {
		// Should for sure be more flexible.
		request.AddToolCall(AgentAssistantToolCall{
			Role:      "assistant",
			Content:   "",
			ToolCalls: toolCalls,
		})
	}
	return agentResponse, nil
}
// GetImageInfo sends the image to the "pixtral-12b-2409" model together with
// the client's system prompt and the RESPONSE_FORMAT schema, logs the raw
// response and returns the ImageInfo decoded from it.
//
// imageName is used to derive the image's media type from its extension and
// imageData holds the raw image bytes.
//
// It returns an error when the request cannot be built or encoded, when the
// HTTP call or reading the body fails, or when the response cannot be parsed
// by parseAgentResponse.
func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
	aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData)
	if err != nil {
		return ImageInfo{}, err
	}

	// getCompletionsForImage stores the schema as a string; replace it with
	// the decoded value so it is encoded as a JSON object, not a JSON string.
	var jsonSchema any
	if err := json.Unmarshal([]byte(RESPONSE_FORMAT), &jsonSchema); err != nil {
		return ImageInfo{}, err
	}
	aiRequest.ResponseFormat = ResponseFormat{
		Type:       "json_schema",
		JsonSchema: jsonSchema,
	}

	jsonAiRequest, err := json.Marshal(aiRequest)
	if err != nil {
		return ImageInfo{}, err
	}

	request, err := client.getRequest(jsonAiRequest)
	if err != nil {
		return ImageInfo{}, err
	}

	resp, err := client.Do(request)
	if err != nil {
		return ImageInfo{}, err
	}
	// The body must be closed so the underlying connection can be released.
	defer resp.Body.Close()

	response, err := io.ReadAll(resp.Body)
	if err != nil {
		return ImageInfo{}, err
	}
	log.Println(string(response))

	return parseAgentResponse(response)
}