package main

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"time"

	"screenmark/screenmark/.gen/haystack/haystack/model"
)

// ImageInfo is the structured result extracted from one image. It is decoded
// from the JSON document the model returns as the content of its reply.
type ImageInfo struct {
	Tags      []string          `json:"tags"`
	Text      []string          `json:"text"`
	Links     []string          `json:"links"`
	Locations []model.Locations `json:"locations"`
	Events    []model.Events    `json:"events"`
}

// ResponseFormat is the "response_format" member of the request body.
// JsonSchema is typed as any so it can carry either a raw string or an
// already-decoded JSON value.
type ResponseFormat struct {
	Type       string `json:"type"`
	JsonSchema any    `json:"json_schema"`
}

// OpenAiRequestBody is the JSON body sent to the chat completions endpoint.
// OpenAiMessages is embedded, so its "messages" array is encoded as a
// top-level member of the request object.
type OpenAiRequestBody struct {
	Model          string         `json:"model"`
	Temperature    float64        `json:"temperature"`
	ResponseFormat ResponseFormat `json:"response_format"`

	OpenAiMessages
}

// OpenAiMessages is the ordered list of chat messages in a request.
type OpenAiMessages struct {
	Messages []OpenAiMessage `json:"messages"`
}

// OpenAiMessage is implemented by every message kind that can be placed in
// OpenAiMessages.
//
// NOTE(review): nothing in this file calls MessageToJson; encoding/json
// encodes the concrete struct through its field tags instead.
type OpenAiMessage interface {
	MessageToJson() ([]byte, error)
}

// OpenAiTextMessage is a message whose content is a single plain string.
type OpenAiTextMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// MessageToJson returns the JSON encoding of the text message.
func (textContent OpenAiTextMessage) MessageToJson() ([]byte, error) {
	// TODO: Validate the `Role`.
	return json.Marshal(textContent)
}

// OpenAiArrayMessage is a message whose content is a list of content parts
// (for example an image).
type OpenAiArrayMessage struct {
	Role    string          `json:"role"`
	Content []OpenAiContent `json:"content"`
}

// MessageToJson returns the JSON encoding of the array message.
func (arrayContent OpenAiArrayMessage) MessageToJson() ([]byte, error) {
	return json.Marshal(arrayContent)
}

// AddImage appends a user message containing image as a base64 data URL.
//
// The media subtype of the data URL is the extension of imageName without its
// leading dot. An error is returned, and nothing is appended, when imageName
// has no extension or ends in a bare ".".
func (m *OpenAiMessages) AddImage(imageName string, image []byte) error {
	extension := filepath.Ext(imageName)
	// filepath.Ext keeps the leading dot, so a length of 1 means the name ends
	// in "." with nothing after it. That would yield "data:image/;base64,...",
	// so it is rejected the same way as a missing extension.
	if len(extension) <= 1 {
		// TODO: could also validate for image types we support.
		return errors.New("Image does not have extension")
	}
	extension = extension[1:]

	imagePart := OpenAiImage{
		ImageType: IMAGE_TYPE,
		ImageUrl:  fmt.Sprintf("data:image/%s;base64,%s", extension, base64.StdEncoding.EncodeToString(image)),
	}
	m.Messages = append(m.Messages, OpenAiArrayMessage{
		Role:    ROLE_USER,
		Content: []OpenAiContent{imagePart},
	})
	return nil
}

// AddSystem appends prompt as a system message. It fails if any message has
// already been added, because the system prompt must come first.
func (m *OpenAiMessages) AddSystem(prompt string) error {
	if len(m.Messages) != 0 {
		return errors.New("You can only add a system prompt at the beginning")
	}

	m.Messages = append(m.Messages, OpenAiTextMessage{
		Role:    ROLE_SYSTEM,
		Content: prompt,
	})
	return nil
}

// OpenAiContent is implemented by every content part that can appear in an
// OpenAiArrayMessage.
type OpenAiContent interface {
	ToJson() ([]byte, error)
}

// ImageUrl wraps a URL in an object with a single "url" member.
//
// NOTE(review): this type is not referenced anywhere else in this file;
// OpenAiImage sends its URL as a plain string.
type ImageUrl struct {
	Url string `json:"url"`
}

// OpenAiImage is an image content part. ImageUrl holds the image location;
// AddImage fills it with a base64 data URL.
type OpenAiImage struct {
	ImageType string `json:"type"`
	ImageUrl  string `json:"image_url"`
}

// ToJson returns the JSON encoding of the image part. The "type" member is
// always IMAGE_TYPE, whatever the caller stored in ImageType; the receiver is
// a copy, so the caller's value is not modified.
func (imageMessage OpenAiImage) ToJson() ([]byte, error) {
	imageMessage.ImageType = IMAGE_TYPE
	return json.Marshal(imageMessage)
}

// AiClient extracts structured information from an image.
type AiClient interface {
	GetImageInfo(imageName string, imageData []byte) (ImageInfo, error)
}

// OpenAiClient is an AiClient that talks to a chat completions HTTP endpoint.
type OpenAiClient struct {
	url            string
	apiKey         string
	systemPrompt   string
	responseFormat string

	// Do sends the HTTP request. It is a field rather than a method so that a
	// different transport can be substituted (for example in tests).
	Do func(req *http.Request) (*http.Response, error)
}

const (
	// OPENAI_API_KEY is the name of the environment variable holding the API key.
	OPENAI_API_KEY = "OPENAI_API_KEY"

	// ROLE_USER and ROLE_SYSTEM are the message roles used in requests.
	ROLE_USER   = "user"
	ROLE_SYSTEM = "system"

	// IMAGE_TYPE is the "type" value of an image content part.
	IMAGE_TYPE = "image_url"
)

// PROMPT is the system prompt sent with every image.
// TODO: extract to text file probably
const PROMPT = `
You are an image information extractor. The user will provide you with screenshots and your job is to extract any relevant
links and text that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags.
Be sure to extract every link (URL) that you find.
Use generic tags.

You also want to extract events in the image, and the location/locations this event is hosted in.

You need to extract locations in the image if any exist, and give the approximate coordinates for this location.
`

// RESPONSE_FORMAT is the JSON schema the model's reply must follow. Its
// members mirror the JSON tags of ImageInfo.
const RESPONSE_FORMAT = `
{
	"name": "image_info",
	"strict": true,
	"schema": {
		"type": "object",
		"title": "image",
		"required": ["tags", "text", "links"],
		"additionalProperties": false,
		"properties": {
			"tags": {
				"type": "array",
				"title": "tags",
				"description": "A list of tags you think the image is relevant to.",
				"items": {
					"type": "string"
				}
			},
			"text": {
				"type": "array",
				"title": "text",
				"description": "A list of sentences the image contains.",
				"items": {
					"type": "string"
				}
			},
			"links": {
				"type": "array",
				"title": "links",
				"description": "A list of all the links you can find in the image.",
				"items": {
					"type": "string"
				}
			},
			"locations": {
				"title": "locations",
				"type": "array",
				"description": "A list of locations you can find on the image, if any",
				"items": {
					"type": "object",
					"required": ["name"],
					"additionalProperties": false,
					"properties": {
						"name": {
							"title": "name",
							"type": "string"
						},
						"coordinates": {
							"title": "coordinates",
							"type": "string"
						},
						"address": {
							"title": "address",
							"type": "string"
						},
						"description": {
							"title": "description",
							"type": "string"
						}
					}
				}
			},
			"events": {
				"title": "events",
				"type": "array",
				"description": "A list of events you find on the image, if any",
				"items": {
					"type": "object",
					"required": ["name"],
					"additionalProperties": false,
					"properties": {
						"name": {
							"type": "string",
							"title": "name"
						},
						"locations": {
							"title": "locations",
							"type": "array",
							"description": "A list of locations on this event, if any",
							"items": {
								"type": "object",
								"required": ["name"],
								"additionalProperties": false,
								"properties": {
									"name": {
										"title": "name",
										"type": "string"
									},
									"coordinates": {
										"title": "coordinates",
										"type": "string"
									},
									"address": {
										"title": "address",
										"type": "string"
									},
									"description": {
										"title": "description",
										"type": "string"
									}
								}
							}
						}
					}
				}
			}
		}
	}
}
`

// CreateOpenAiClient builds a client from the OPENAI_API_KEY environment
// variable. It returns an error when the variable is unset or empty.
//
// Despite the naming, the client is pointed at the Mistral chat completions
// endpoint.
func CreateOpenAiClient() (OpenAiClient, error) {
	// Upper bound for one completion round-trip, so a stalled connection
	// cannot block the caller forever. Kept generous because the request
	// carries a whole base64-encoded image.
	const requestTimeout = 2 * time.Minute

	apiKey := os.Getenv(OPENAI_API_KEY)
	if len(apiKey) == 0 {
		return OpenAiClient{}, errors.New(OPENAI_API_KEY + " was not found.")
	}

	// One client is shared by every request made through the returned value.
	httpClient := &http.Client{Timeout: requestTimeout}

	return OpenAiClient{
		apiKey:       apiKey,
		url:          "https://api.mistral.ai/v1/chat/completions",
		systemPrompt: PROMPT,
		Do:           httpClient.Do,
	}, nil
}

// getRequest builds the authenticated POST request that carries body as JSON.
func (c OpenAiClient) getRequest(body []byte) (*http.Request, error) {
	req, err := http.NewRequest(http.MethodPost, c.url, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Authorization", "Bearer "+c.apiKey)
	req.Header.Set("Content-Type", "application/json")
	return req, nil
}

// getCompletionsForImage assembles a request body made of the system prompt
// followed by one user message holding the image.
//
// jsonSchema is stored as given, i.e. as a string; GetImageInfo replaces it
// with the decoded schema before the request is sent.
func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (OpenAiRequestBody, error) {
	request := OpenAiRequestBody{
		Model:       model,
		Temperature: temperature,
		ResponseFormat: ResponseFormat{
			Type:       "json_schema",
			JsonSchema: jsonSchema,
		},
	}

	// TODO: Add build pattern here that deals with errors in some internal state?
	// I want a monad!!!
	if err := request.AddSystem(prompt); err != nil {
		return request, err
	}
	if err := request.AddImage(imageName, imageData); err != nil {
		return request, err
	}
	return request, nil
}

// ResponseChoiceMessage is the message inside one completion choice.
type ResponseChoiceMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// ResponseChoice is one entry of the "choices" array in a completion response.
type ResponseChoice struct {
	Index        int                   `json:"index"`
	Message      ResponseChoiceMessage `json:"message"`
	FinishReason string                `json:"finish_reason"`
}

// OpenAiResponse is the subset of the completion response this file decodes.
type OpenAiResponse struct {
	Id      string           `json:"id"`
	Object  string           `json:"object"`
	Choices []ResponseChoice `json:"choices"`
	Created int              `json:"created"`
}

// parseOpenAiResponse decodes a completion response and then decodes the
// content of its only choice into an ImageInfo. It fails when the response is
// not valid JSON, does not hold exactly one choice, or when that choice's
// content cannot be decoded.
// TODO: add usage parsing
func parseOpenAiResponse(jsonResponse []byte) (ImageInfo, error) {
	var response OpenAiResponse
	if err := json.Unmarshal(jsonResponse, &response); err != nil {
		return ImageInfo{}, err
	}

	if len(response.Choices) != 1 {
		log.Println(string(jsonResponse))
		return ImageInfo{}, errors.New("Expected exactly one choice.")
	}

	var imageInfo ImageInfo
	if err := json.Unmarshal([]byte(response.Choices[0].Message.Content), &imageInfo); err != nil {
		return ImageInfo{}, errors.New("Could not parse content into image type.")
	}
	return imageInfo, nil
}

// GetImageInfo sends the image to the completions endpoint and returns the
// information the model extracted from it.
//
// imageName is only used for its extension (see AddImage). An error is
// returned when the request cannot be built or sent, when the endpoint answers
// with an HTTP error status, or when the reply cannot be parsed.
func (c OpenAiClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
	aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, c.systemPrompt, imageName, RESPONSE_FORMAT, imageData)
	if err != nil {
		return ImageInfo{}, err
	}

	// Decode the schema so it is encoded as a JSON object in the request
	// rather than as a quoted string.
	var jsonSchema any
	if err := json.Unmarshal([]byte(RESPONSE_FORMAT), &jsonSchema); err != nil {
		return ImageInfo{}, err
	}
	log.Println(jsonSchema)

	aiRequest.ResponseFormat = ResponseFormat{
		Type:       "json_schema",
		JsonSchema: jsonSchema,
	}

	jsonAiRequest, err := json.Marshal(aiRequest)
	if err != nil {
		return ImageInfo{}, err
	}

	request, err := c.getRequest(jsonAiRequest)
	if err != nil {
		return ImageInfo{}, err
	}

	// A zero-value OpenAiClient has no transport; report that as an error
	// instead of panicking on a nil function call.
	if c.Do == nil {
		return ImageInfo{}, errors.New("no HTTP transport configured on OpenAiClient")
	}

	resp, err := c.Do(request)
	if err != nil {
		return ImageInfo{}, err
	}
	// The body must always be closed, otherwise the connection is leaked.
	defer resp.Body.Close()

	response, err := io.ReadAll(resp.Body)
	if err != nil {
		return ImageInfo{}, err
	}
	log.Println(string(response))

	// Report HTTP-level failures directly instead of letting the error payload
	// fail later in the parser with an unrelated message. Only statuses of 400
	// and above are rejected, so a transport that leaves StatusCode unset is
	// still accepted.
	if resp.StatusCode >= http.StatusBadRequest {
		return ImageInfo{}, fmt.Errorf("completion request failed with HTTP status %d", resp.StatusCode)
	}

	return parseOpenAiResponse(response)
}