feat: working e2e system making requests to open ai

This commit is contained in:
2025-02-22 17:32:48 +00:00
parent be0258e195
commit c0b3ead79b
6 changed files with 460 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
.env
db
screenmark

2
go.mod
View File

@ -1,3 +1,5 @@
module screenmark/screenmark module screenmark/screenmark
go 1.24.0 go 1.24.0
require github.com/joho/godotenv v1.5.1 // indirect

2
go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=

22
main.go
View File

@ -6,9 +6,21 @@ import (
"log" "log"
"net/http" "net/http"
"os" "os"
"github.com/joho/godotenv"
) )
func main() { func main() {
err := godotenv.Load()
if err != nil {
panic(err)
}
openAiClient, err := CreateOpenAiClient()
if err != nil {
panic(err)
}
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("OPTIONS /image/{name}", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("OPTIONS /image/{name}", func(w http.ResponseWriter, r *http.Request) {
@ -55,6 +67,16 @@ func main() {
fmt.Fprintf(w, "Couldnt write the image") fmt.Fprintf(w, "Couldnt write the image")
return return
} }
resp, err := openAiClient.GetImageInfo(imageName, image)
if err != nil {
log.Println(err)
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "some shit happened")
return
}
log.Printf("%+v\n", resp)
}) })
log.Println("Listening and serving.") log.Println("Listening and serving.")

280
openai.go Normal file
View File

@ -0,0 +1,280 @@
package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
)
// ImageInfo is the structured result extracted from a screenshot. GetImageInfo
// decodes the HTTP response body into this type; the JSON keys mirror the
// properties declared in RESPONSE_FORMAT.
type ImageInfo struct {
// Tags the model considers relevant to the image.
Tags []string `json:"tags"`
// Text holds the sentences found in the image.
Text []string `json:"text"`
// Links holds the links found in the image.
Links []string `json:"links"`
}
// ResponseFormat is serialized under the "response_format" key of the request
// body. GetImageInfo fills it with Type "json_schema" and the decoded
// RESPONSE_FORMAT document.
type ResponseFormat struct {
Type string `json:"type"`
// JsonSchema is left as `any` so an arbitrary decoded JSON document can be
// embedded; a nil value marshals as JSON null.
JsonSchema any `json:"json_schema"`
}
// OpenAiRequestBody is the JSON payload that GetImageInfo marshals and POSTs
// to the client's url.
type OpenAiRequestBody struct {
Model string `json:"model"`
Temperature float64 `json:"temperature"`
ResponseFormat ResponseFormat `json:"response_format"`
// Embedded so that AddSystem/AddImage are promoted onto the request body and
// the "messages" key is emitted at the top level of the JSON object.
OpenAiMessages
}
// OpenAiMessages is the ordered list of conversation messages. It is built up
// through AddSystem and AddImage.
type OpenAiMessages struct {
Messages []OpenAiMessage `json:"messages"`
}
// OpenAiMessage is a single entry of OpenAiMessages.Messages. In this file it
// is implemented by OpenAiTextMessage and OpenAiArrayMessage.
type OpenAiMessage interface {
// MessageToJson returns the JSON encoding of the message.
MessageToJson() ([]byte, error)
}
// OpenAiTextMessage is a message whose content is a plain string. AddSystem
// uses it for the system prompt.
type OpenAiTextMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// MessageToJson encodes the text message with encoding/json and returns the
// resulting bytes together with any marshalling error.
func (msg OpenAiTextMessage) MessageToJson() ([]byte, error) {
	// TODO: Validate the `Role`.
	encoded, err := json.Marshal(msg)
	return encoded, err
}
// OpenAiArrayMessage is a message whose content is a list of content parts.
// AddImage uses it to attach an image as a single-element list.
type OpenAiArrayMessage struct {
Role string `json:"role"`
Content []OpenAiContent `json:"content"`
}
// MessageToJson encodes the array message, including every content part, with
// encoding/json and returns the resulting bytes together with any error.
func (msg OpenAiArrayMessage) MessageToJson() ([]byte, error) {
	encoded, err := json.Marshal(msg)
	return encoded, err
}
// AddImage appends a user message that carries the image inline as a base64
// data URL.
//
// The media subtype is derived from the file extension of imageName: the
// leading dot is dropped, the extension is lower-cased and "jpg" is mapped to
// the registered subtype "jpeg". An error is returned, and no message is
// appended, when imageName has no extension or ends in a bare dot.
func (content *OpenAiMessages) AddImage(imageName string, image []byte) error {
	extension := filepath.Ext(imageName)
	// filepath.Ext returns "." for names such as "image.", so anything of
	// length one or less carries no usable extension.
	if len(extension) <= 1 {
		// TODO: could also validate for image types we support.
		return errors.New("Image does not have extension")
	}

	// Lower-casing goes through the bytes package because it is already
	// imported by this file.
	subtype := string(bytes.ToLower([]byte(extension[1:])))
	if subtype == "jpg" {
		// "image/jpg" is not a registered media type; "image/jpeg" is.
		subtype = "jpeg"
	}

	encodedString := base64.StdEncoding.EncodeToString(image)

	content.Messages = append(content.Messages, OpenAiArrayMessage{
		Role: ROLE_USER,
		Content: []OpenAiContent{
			OpenAiImage{
				ImageType: IMAGE_TYPE,
				ImageUrl: ImageUrl{
					Url: fmt.Sprintf("data:image/%s;base64,%s", subtype, encodedString),
				},
			},
		},
	})

	return nil
}
// AddSystem appends the system prompt as the first message of the
// conversation. It fails if any message has already been added, since the
// system prompt is only accepted at the beginning.
func (content *OpenAiMessages) AddSystem(prompt string) error {
	if len(content.Messages) > 0 {
		return errors.New("You can only add a system prompt at the beginning")
	}

	systemMessage := OpenAiTextMessage{Role: ROLE_SYSTEM, Content: prompt}
	content.Messages = append(content.Messages, systemMessage)

	return nil
}
// OpenAiContent is one part of an OpenAiArrayMessage's content list. In this
// file it is implemented by OpenAiImage.
type OpenAiContent interface {
// ToJson returns the JSON encoding of the content part.
ToJson() ([]byte, error)
}
// ImageUrl wraps the URL of an image content part. AddImage stores a base64
// data URL in it.
type ImageUrl struct {
Url string `json:"url"`
}
// OpenAiImage is the image content part of a message.
type OpenAiImage struct {
// ImageType is serialized as "type"; AddImage and ToJson set it to IMAGE_TYPE.
ImageType string `json:"type"`
ImageUrl ImageUrl `json:"image_url"`
}
// ToJson encodes the image part as JSON, always emitting IMAGE_TYPE as its
// type regardless of what the receiver currently holds. The receiver is a
// value, so the caller's struct is left untouched.
func (img OpenAiImage) ToJson() ([]byte, error) {
	normalized := OpenAiImage{
		ImageType: IMAGE_TYPE,
		ImageUrl:  img.ImageUrl,
	}
	return json.Marshal(normalized)
}
// OpenAiClient holds everything needed to send an image to the completions
// endpoint. Use CreateOpenAiClient to build one from the environment.
type OpenAiClient struct {
// url is the endpoint that requests are POSTed to.
url string
// apiKey is sent as a Bearer token in the Authorization header.
apiKey string
// systemPrompt is the first message of every request made by GetImageInfo.
systemPrompt string
// NOTE(review): responseFormat is neither set nor read anywhere in this file;
// GetImageInfo uses the RESPONSE_FORMAT constant directly.
responseFormat string
// Do performs the HTTP round trip. It is a field rather than a method so that
// tests can substitute a stub.
Do func(req *http.Request) (*http.Response, error)
}
// func (client OpenAiClient) Do(req *http.Request) () {
// httpClient := http.Client{}
// return httpClient.Do(req)
// }
// OPENAI_API_KEY is the name of the environment variable that
// CreateOpenAiClient reads the API key from.
const OPENAI_API_KEY = "OPENAI_API_KEY"
// ROLE_USER is the role assigned to messages created by AddImage.
const ROLE_USER = "user"
// ROLE_SYSTEM is the role assigned to the message created by AddSystem.
const ROLE_SYSTEM = "system"
// IMAGE_TYPE is the "type" value of an image content part.
const IMAGE_TYPE = "image_url"
// PROMPT is the system prompt that CreateOpenAiClient installs on every
// client; it instructs the model to extract links, text and tags from
// screenshots.
// TODO: extract to text file probably
const PROMPT = `
You are an image information extractor. The user will provide you with screenshots and your job is to extract any relevant links and text
that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags.
This system is part of a bookmark manager, who's main goal is to allow the user to search through various screenshots.
`
// RESPONSE_FORMAT is the JSON document that GetImageInfo decodes and sends as
// the "json_schema" member of the request's response_format. The schema it
// contains describes an object with three required string arrays (tags, text,
// links), matching the fields of ImageInfo.
const RESPONSE_FORMAT = `
{
"name": "schema_description",
"schema": {
"type": "object",
"properties": {
"tags": {
"type": "array",
"description": "A list of tags you think the image is relevant to.",
"items": {
"type": "string"
}
},
"text": {
"type": "array",
"description": "A list of sentences the image contains.",
"items": {
"type": "string"
}
},
"links": {
"type": "array",
"description": "A list of all the links you can find in the image.",
"items": {
"type": "string"
}
}
},
"required": [
"tags",
"text",
"links"
],
"additionalProperties": false
},
"strict": true
}
`
// CreateOpenAiClient builds a client for the chat completions endpoint using
// the API key found in the OPENAI_API_KEY environment variable. It returns an
// error when that variable is unset or empty.
func CreateOpenAiClient() (OpenAiClient, error) {
	key := os.Getenv(OPENAI_API_KEY)
	if key == "" {
		return OpenAiClient{}, errors.New(OPENAI_API_KEY + " was not found.")
	}

	client := OpenAiClient{
		url:          "https://api.openai.com/v1/chat/completions",
		apiKey:       key,
		systemPrompt: PROMPT,
	}

	// NOTE(review): a zero-value http.Client has no timeout, so a stalled
	// connection blocks the caller indefinitely — consider setting one.
	client.Do = func(req *http.Request) (*http.Response, error) {
		httpClient := http.Client{}
		return httpClient.Do(req)
	}

	return client, nil
}
// getRequest wraps an already-encoded JSON body in a POST request aimed at the
// client's url, with the bearer token and JSON content type headers set.
func (client OpenAiClient) getRequest(body []byte) (*http.Request, error) {
	request, err := http.NewRequest(http.MethodPost, client.url, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	bearer := "Bearer " + client.apiKey
	request.Header.Set("Authorization", bearer)
	request.Header.Set("Content-Type", "application/json")

	return request, nil
}
// getCompletionsForImage assembles a request body for the given model and
// temperature containing two messages: the system prompt followed by the
// image. The response format is left at its zero value for the caller to fill
// in. On failure the partially built request is returned alongside the error.
func getCompletionsForImage(model string, temperature float64, prompt, imageName string, imageData []byte) (OpenAiRequestBody, error) {
	var request OpenAiRequestBody
	request.Model = model
	request.Temperature = temperature

	// TODO: Add build pattern here that deals with errors in some internal state?
	if err := request.AddSystem(prompt); err != nil {
		return request, err
	}

	if err := request.AddImage(imageName, imageData); err != nil {
		return request, err
	}

	return request, nil
}
// GetImageInfo sends the image to the configured endpoint and decodes the
// response body into an ImageInfo.
//
// imageName is only used to derive the image's media type from its extension;
// imageData is the raw image content. An error is returned when the request
// cannot be built or sent, when the server answers with a 4xx/5xx status, or
// when the response body is not valid JSON for ImageInfo.
func (client OpenAiClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
	aiRequest, err := getCompletionsForImage("gpt-4o-mini", 1.0, client.systemPrompt, imageName, imageData)
	if err != nil {
		return ImageInfo{}, err
	}

	var jsonSchema any
	if err := json.Unmarshal([]byte(RESPONSE_FORMAT), &jsonSchema); err != nil {
		return ImageInfo{}, fmt.Errorf("decoding response format schema: %w", err)
	}

	aiRequest.ResponseFormat = ResponseFormat{
		Type:       "json_schema",
		JsonSchema: jsonSchema,
	}

	jsonAiRequest, err := json.Marshal(aiRequest)
	if err != nil {
		return ImageInfo{}, fmt.Errorf("encoding request body: %w", err)
	}

	request, err := client.getRequest(jsonAiRequest)
	if err != nil {
		return ImageInfo{}, err
	}

	resp, err := client.Do(request)
	if err != nil {
		return ImageInfo{}, err
	}
	// The body must be closed so the underlying connection can be released.
	defer resp.Body.Close()

	response, err := io.ReadAll(resp.Body)
	if err != nil {
		return ImageInfo{}, fmt.Errorf("reading response body: %w", err)
	}

	// An error payload is valid JSON and would otherwise decode silently into
	// an empty ImageInfo. Only 4xx/5xx are rejected (rather than anything other
	// than 200) so that stubbed responses without a status code keep working.
	if resp.StatusCode >= http.StatusBadRequest {
		return ImageInfo{}, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, response)
	}

	// NOTE(review): the body is decoded directly as ImageInfo. The chat
	// completions API documents its result as nested under
	// choices[].message.content — confirm whether an unwrapping step is needed.
	info := ImageInfo{}
	if err := json.Unmarshal(response, &info); err != nil {
		return ImageInfo{}, fmt.Errorf("decoding response body: %w", err)
	}

	log.Println(string(response))

	return info, nil
}

151
openai_test.go Normal file
View File

@ -0,0 +1,151 @@
package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"testing"
)
// TestMessageBuilder checks that adding a system prompt to an empty message
// list succeeds and results in exactly one message.
func TestMessageBuilder(t *testing.T) {
	var content OpenAiMessages

	if err := content.AddSystem("Some prompt"); err != nil {
		t.Fatal(err)
	}

	if got := len(content.Messages); got != 1 {
		t.Fatalf("Expected length 1, got %d.\n", got)
	}
}
// TestMessageBuilderImage builds a system prompt followed by an image message
// and verifies the concrete type, role and content of both messages,
// including the base64 data URL generated for the image.
func TestMessageBuilderImage(t *testing.T) {
	content := OpenAiMessages{}

	prompt := "some prompt"
	imageTitle := "image.png"
	data := []byte("some data")

	// Builder errors are checked so a failure is reported at its source
	// instead of surfacing as a confusing length mismatch below.
	if err := content.AddSystem(prompt); err != nil {
		t.Fatalf("AddSystem returned an unexpected error: %v", err)
	}
	if err := content.AddImage(imageTitle, data); err != nil {
		t.Fatalf("AddImage returned an unexpected error: %v", err)
	}

	if len(content.Messages) != 2 {
		t.Fatalf("Expected length 2, got %d.", len(content.Messages))
	}

	promptMessage, ok := content.Messages[0].(OpenAiTextMessage)
	if !ok {
		t.Fatalf("Expected text content message, got %T", content.Messages[0])
	}
	if promptMessage.Role != ROLE_SYSTEM {
		t.Fatal("Prompt message role is incorrect.")
	}
	if promptMessage.Content != prompt {
		t.Fatal("Prompt message content is incorrect.")
	}

	arrayContentMessage, ok := content.Messages[1].(OpenAiArrayMessage)
	if !ok {
		t.Fatalf("Expected array content message, got %T", content.Messages[1])
	}
	if arrayContentMessage.Role != ROLE_USER {
		t.Fatal("Array content message role is incorrect.")
	}
	if len(arrayContentMessage.Content) != 1 {
		t.Fatalf("Expected length 1, got %d.", len(arrayContentMessage.Content))
	}

	imageContent, ok := arrayContentMessage.Content[0].(OpenAiImage)
	if !ok {
		t.Fatalf("Expected image content, got %T", arrayContentMessage.Content[0])
	}

	base64data := base64.StdEncoding.EncodeToString(data)
	url := fmt.Sprintf("data:image/%s;base64,%s", "png", base64data)
	if imageContent.ImageUrl.Url != url {
		t.Fatalf("Expected %s, but got %s.", url, imageContent.ImageUrl.Url)
	}
}
// TestFullImageRequest marshals a request built by getCompletionsForImage and
// compares it with the exact JSON expected on the wire.
func TestFullImageRequest(t *testing.T) {
	imageData := []byte("some data")

	request, err := getCompletionsForImage("model", 0.1, "You are an assistant", "image.png", imageData)
	if err != nil {
		t.Fatal(err)
	}

	jsonData, err := json.Marshal(request)
	if err != nil {
		t.Fatal(err)
	}

	// The response format is untouched at this stage, so its `any` schema is
	// nil and encodes as null; the image is carried as a base64 data URL.
	imageURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
	expectedJson := `{"model":"model","temperature":0.1,"response_format":{"type":"","json_schema":null},"messages":[{"role":"system","content":"You are an assistant"},{"role":"user","content":[{"type":"image_url","image_url":{"url":"` + imageURL + `"}}]}]}`

	if string(jsonData) != expectedJson {
		t.Fatalf("Expected:\n%s\n Got:\n%s\n", expectedJson, string(jsonData))
	}
}
// TestResponse stubs the client's Do function with a canned JSON body and
// verifies that GetImageInfo decodes tags, text and links from it.
func TestResponse(t *testing.T) {
	const testResponse = `{"tags": ["tag1", "tag2"], "text": ["text1"], "links": []}`

	body := io.NopCloser(bytes.NewReader([]byte(testResponse)))
	stubDo := func(_ *http.Request) (*http.Response, error) {
		return &http.Response{Body: body}, nil
	}

	client := OpenAiClient{
		url:    "http://localhost:1234",
		apiKey: "some-key",
		Do:     stubDo,
	}

	info, err := client.GetImageInfo("image.png", []byte("some data"))
	if err != nil {
		t.Fatal(err)
	}

	tagCount, textCount, linkCount := len(info.Tags), len(info.Text), len(info.Links)
	if tagCount != 2 || textCount != 1 || linkCount != 0 {
		t.Fatalf("Some lengths are wrong.\nTags: %d\nText: %d\nLinks: %d\n", tagCount, textCount, linkCount)
	}

	switch {
	case info.Tags[0] != "tag1":
		t.Fatal("0th tag is wrong.")
	case info.Tags[1] != "tag2":
		t.Fatal("1th tag is wrong.")
	case info.Text[0] != "text1":
		t.Fatal("0th text is wrong.")
	}
}