diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3342a85
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+.env
+db
+screenmark
diff --git a/go.mod b/go.mod
index e898e15..d88507b 100644
--- a/go.mod
+++ b/go.mod
@@ -1,3 +1,5 @@
 module screenmark/screenmark
 
 go 1.24.0
+
+require github.com/joho/godotenv v1.5.1
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..d61b19e
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,2 @@
+github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
+github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
diff --git a/main.go b/main.go
index 51a2dfc..c065d1b 100644
--- a/main.go
+++ b/main.go
@@ -6,9 +6,21 @@ import (
 	"log"
 	"net/http"
 	"os"
+
+	"github.com/joho/godotenv"
 )
 
 func main() {
+	err := godotenv.Load()
+	if err != nil {
+		log.Printf("Could not load .env file: %v\n", err)
+	}
+
+	openAiClient, err := CreateOpenAiClient()
+	if err != nil {
+		panic(err)
+	}
+
 	mux := http.NewServeMux()
 
 	mux.HandleFunc("OPTIONS /image/{name}", func(w http.ResponseWriter, r *http.Request) {
@@ -55,6 +67,16 @@ func main() {
 			fmt.Fprintf(w, "Couldnt write the image")
 			return
 		}
+
+		resp, err := openAiClient.GetImageInfo(imageName, image)
+		if err != nil {
+			log.Println(err)
+			w.WriteHeader(http.StatusInternalServerError)
+			fmt.Fprintf(w, "Couldnt extract information from the image")
+			return
+		}
+
+		log.Printf("%+v\n", resp)
 	})
 
 	log.Println("Listening and serving.")
diff --git a/openai.go b/openai.go
new file mode 100644
index 0000000..1bcd3aa
--- /dev/null
+++ b/openai.go
@@ -0,0 +1,342 @@
+package main
+
+import (
+	"bytes"
+	"encoding/base64"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"os"
+	"path/filepath"
+	"strings"
+	"time"
+)
+
+// ImageInfo is the structured data the model extracts from a screenshot. Its
+// JSON shape mirrors the schema declared in RESPONSE_FORMAT.
+type ImageInfo struct {
+	Tags  []string `json:"tags"`
+	Text  []string `json:"text"`
+	Links []string `json:"links"`
+}
+
+// ResponseFormat is the "response_format" member of a chat-completions
+// request. JsonSchema holds the decoded RESPONSE_FORMAT document.
+type ResponseFormat struct {
+	Type       string `json:"type"`
+	JsonSchema any    `json:"json_schema"`
+}
+
+// OpenAiRequestBody is the JSON payload posted to the chat-completions
+// endpoint. The embedded OpenAiMessages contributes the "messages" array.
+type OpenAiRequestBody struct {
+	Model          string         `json:"model"`
+	Temperature    float64        `json:"temperature"`
+	ResponseFormat ResponseFormat `json:"response_format"`
+
+	OpenAiMessages
+}
+
+// OpenAiMessages is the ordered list of messages sent to the model.
+type OpenAiMessages struct {
+	Messages []OpenAiMessage `json:"messages"`
+}
+
+// OpenAiMessage is any message that can appear in the "messages" array.
+type OpenAiMessage interface {
+	MessageToJson() ([]byte, error)
+}
+
+// OpenAiTextMessage is a message whose content is a plain string.
+type OpenAiTextMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+// MessageToJson encodes the message as JSON.
+func (textContent OpenAiTextMessage) MessageToJson() ([]byte, error) {
+	// TODO: Validate the `Role`.
+	return json.Marshal(textContent)
+}
+
+// OpenAiArrayMessage is a message whose content is a list of parts.
+type OpenAiArrayMessage struct {
+	Role    string          `json:"role"`
+	Content []OpenAiContent `json:"content"`
+}
+
+// MessageToJson encodes the message as JSON.
+func (arrayContent OpenAiArrayMessage) MessageToJson() ([]byte, error) {
+	return json.Marshal(arrayContent)
+}
+
+// AddImage appends a user message carrying image as a base64 data URL. The
+// media type is derived from the extension of imageName, so a name without
+// an extension is rejected.
+func (content *OpenAiMessages) AddImage(imageName string, image []byte) error {
+	extension := strings.ToLower(strings.TrimPrefix(filepath.Ext(imageName), "."))
+	if len(extension) == 0 {
+		// TODO: could also validate for image types we support.
+		return errors.New("Image does not have extension")
+	}
+
+	// "jpg" is only a file extension; the media type is image/jpeg.
+	if extension == "jpg" {
+		extension = "jpeg"
+	}
+
+	encodedString := base64.StdEncoding.EncodeToString(image)
+
+	arrayMessage := OpenAiArrayMessage{Role: ROLE_USER, Content: make([]OpenAiContent, 1)}
+	arrayMessage.Content[0] = OpenAiImage{
+		ImageType: IMAGE_TYPE,
+		ImageUrl: ImageUrl{
+			Url: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
+		},
+	}
+
+	content.Messages = append(content.Messages, arrayMessage)
+
+	return nil
+}
+
+// AddSystem appends the system prompt. It must be the first message added.
+func (content *OpenAiMessages) AddSystem(prompt string) error {
+	if len(content.Messages) != 0 {
+		return errors.New("You can only add a system prompt at the beginning")
+	}
+
+	content.Messages = append(content.Messages, OpenAiTextMessage{
+		Role:    ROLE_SYSTEM,
+		Content: prompt,
+	})
+
+	return nil
+}
+
+// OpenAiContent is one part of an OpenAiArrayMessage.
+type OpenAiContent interface {
+	ToJson() ([]byte, error)
+}
+
+// ImageUrl wraps the URL (here a data URL) of an image part.
+type ImageUrl struct {
+	Url string `json:"url"`
+}
+
+// OpenAiImage is an image content part.
+type OpenAiImage struct {
+	ImageType string   `json:"type"`
+	ImageUrl  ImageUrl `json:"image_url"`
+}
+
+// ToJson encodes the image part as JSON, forcing its type to IMAGE_TYPE.
+func (imageMessage OpenAiImage) ToJson() ([]byte, error) {
+	imageMessage.ImageType = IMAGE_TYPE
+	return json.Marshal(imageMessage)
+}
+
+// OpenAiClient talks to the chat-completions endpoint. Do performs the HTTP
+// round trip and is a field so tests can replace it.
+type OpenAiClient struct {
+	url            string
+	apiKey         string
+	systemPrompt   string
+	responseFormat string
+
+	Do func(req *http.Request) (*http.Response, error)
+}
+
+// openAiResponse is the subset of the chat-completions response envelope that
+// GetImageInfo reads.
+type openAiResponse struct {
+	Choices []struct {
+		Message struct {
+			Content string `json:"content"`
+		} `json:"message"`
+	} `json:"choices"`
+}
+
+const OPENAI_API_KEY = "OPENAI_API_KEY"
+const ROLE_USER = "user"
+const ROLE_SYSTEM = "system"
+const IMAGE_TYPE = "image_url"
+
+// requestTimeout bounds a whole chat-completions round trip.
+const requestTimeout = 60 * time.Second
+
+// TODO: extract to text file probably
+const PROMPT = `
+You are an image information extractor. The user will provide you with screenshots and your job is to extract any relevant links and text
+that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags.
+
+This system is part of a bookmark manager, who's main goal is to allow the user to search through various screenshots.
+`
+
+const RESPONSE_FORMAT = `
+{
+    "name": "schema_description",
+    "schema": {
+        "type": "object",
+        "properties": {
+            "tags": {
+                "type": "array",
+                "description": "A list of tags you think the image is relevant to.",
+                "items": {
+                    "type": "string"
+                }
+            },
+            "text": {
+                "type": "array",
+                "description": "A list of sentences the image contains.",
+                "items": {
+                    "type": "string"
+                }
+            },
+            "links": {
+                "type": "array",
+                "description": "A list of all the links you can find in the image.",
+                "items": {
+                    "type": "string"
+                }
+            }
+        },
+        "required": [
+            "tags",
+            "text",
+            "links"
+        ],
+        "additionalProperties": false
+    },
+    "strict": true
+}
+`
+
+// CreateOpenAiClient builds a client from the OPENAI_API_KEY environment
+// variable and returns an error when the variable is unset or empty.
+func CreateOpenAiClient() (OpenAiClient, error) {
+	apiKey := os.Getenv(OPENAI_API_KEY)
+
+	if len(apiKey) == 0 {
+		return OpenAiClient{}, errors.New(OPENAI_API_KEY + " was not found.")
+	}
+
+	// One client for the lifetime of the process, so connections are reused
+	// and every call is bounded by requestTimeout.
+	httpClient := &http.Client{Timeout: requestTimeout}
+
+	return OpenAiClient{
+		apiKey:       apiKey,
+		url:          "https://api.openai.com/v1/chat/completions",
+		systemPrompt: PROMPT,
+		Do:           httpClient.Do,
+	}, nil
+}
+
+// getRequest builds the authenticated POST request carrying body.
+func (client OpenAiClient) getRequest(body []byte) (*http.Request, error) {
+	req, err := http.NewRequest(http.MethodPost, client.url, bytes.NewReader(body))
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Add("Authorization", "Bearer "+client.apiKey)
+	req.Header.Add("Content-Type", "application/json")
+
+	return req, nil
+}
+
+// getCompletionsForImage assembles a request body holding the system prompt
+// followed by the image. The response format is left for the caller to set.
+func getCompletionsForImage(model string, temperature float64, prompt, imageName string, imageData []byte) (OpenAiRequestBody, error) {
+	request := OpenAiRequestBody{
+		Model:       model,
+		Temperature: temperature,
+	}
+
+	// TODO: Add build pattern here that deals with errors in some internal state?
+	// I want a monad!!!
+	err := request.AddSystem(prompt)
+	if err != nil {
+		return request, err
+	}
+
+	err = request.AddImage(imageName, imageData)
+	if err != nil {
+		return request, err
+	}
+
+	return request, nil
+}
+
+// GetImageInfo sends the image to the model and decodes the structured reply.
+//
+// The endpoint wraps the model output in an envelope: the JSON document that
+// follows RESPONSE_FORMAT is the string at choices[0].message.content, so it
+// is decoded in two steps. A non-2xx status or a reply without choices is
+// reported as an error.
+func (client OpenAiClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
+	aiRequest, err := getCompletionsForImage("gpt-4o-mini", 1.0, client.systemPrompt, imageName, imageData)
+	if err != nil {
+		return ImageInfo{}, err
+	}
+
+	var jsonSchema any
+	err = json.Unmarshal([]byte(RESPONSE_FORMAT), &jsonSchema)
+	if err != nil {
+		return ImageInfo{}, err
+	}
+
+	aiRequest.ResponseFormat = ResponseFormat{
+		Type:       "json_schema",
+		JsonSchema: jsonSchema,
+	}
+
+	jsonAiRequest, err := json.Marshal(aiRequest)
+	if err != nil {
+		return ImageInfo{}, err
+	}
+
+	request, err := client.getRequest(jsonAiRequest)
+	if err != nil {
+		return ImageInfo{}, err
+	}
+
+	resp, err := client.Do(request)
+	if err != nil {
+		return ImageInfo{}, err
+	}
+	defer resp.Body.Close()
+
+	response, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return ImageInfo{}, err
+	}
+
+	if resp.StatusCode < 200 || resp.StatusCode > 299 {
+		return ImageInfo{}, fmt.Errorf("openai request failed with status %d: %s", resp.StatusCode, response)
+	}
+
+	var envelope openAiResponse
+	err = json.Unmarshal(response, &envelope)
+	if err != nil {
+		return ImageInfo{}, fmt.Errorf("decoding openai response: %w", err)
+	}
+
+	if len(envelope.Choices) == 0 {
+		return ImageInfo{}, errors.New("openai response contained no choices")
+	}
+
+	info := ImageInfo{}
+	err = json.Unmarshal([]byte(envelope.Choices[0].Message.Content), &info)
+	if err != nil {
+		return ImageInfo{}, fmt.Errorf("decoding image info: %w", err)
+	}
+
+	log.Println(string(response))
+
+	return info, nil
+}
diff --git a/openai_test.go b/openai_test.go
new file mode 100644
index 0000000..3a2f663
--- /dev/null
+++ b/openai_test.go
@@ -0,0 +1,176 @@
+package main
+
+import (
+	"bytes"
+	"encoding/base64"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"testing"
+)
+
+func TestMessageBuilder(t *testing.T) {
+	content := OpenAiMessages{}
+
+	err := content.AddSystem("Some prompt")
+
+	if err != nil {
+		t.Log(err)
+		t.FailNow()
+	}
+
+	if len(content.Messages) != 1 {
+		t.Logf("Expected length 1, got %d.\n", len(content.Messages))
+		t.FailNow()
+	}
+}
+
+func TestMessageBuilderImage(t *testing.T) {
+	content := OpenAiMessages{}
+
+	prompt := "some prompt"
+	imageTitle := "image.png"
+	data := []byte("some data")
+
+	if err := content.AddSystem(prompt); err != nil {
+		t.Log(err)
+		t.FailNow()
+	}
+
+	if err := content.AddImage(imageTitle, data); err != nil {
+		t.Log(err)
+		t.FailNow()
+	}
+
+	if len(content.Messages) != 2 {
+		t.Logf("Expected length 2, got %d.\n", len(content.Messages))
+		t.FailNow()
+	}
+
+	promptMessage, ok := content.Messages[0].(OpenAiTextMessage)
+	if !ok {
+		t.Logf("Expected text content message, got %T\n", content.Messages[0])
+		t.FailNow()
+	}
+
+	if promptMessage.Role != ROLE_SYSTEM {
+		t.Log("Prompt message role is incorrect.")
+		t.FailNow()
+	}
+
+	if promptMessage.Content != prompt {
+		t.Log("Prompt message content is incorrect.")
+		t.FailNow()
+	}
+
+	arrayContentMessage, ok := content.Messages[1].(OpenAiArrayMessage)
+	if !ok {
+		t.Logf("Expected array content message, got %T\n", content.Messages[1])
+		t.FailNow()
+	}
+
+	if arrayContentMessage.Role != ROLE_USER {
+		t.Log("Array content message role is incorrect.")
+		t.FailNow()
+	}
+
+	if len(arrayContentMessage.Content) != 1 {
+		t.Logf("Expected length 1, got %d.\n", len(arrayContentMessage.Content))
+		t.FailNow()
+	}
+
+	imageContent, ok := arrayContentMessage.Content[0].(OpenAiImage)
+	if !ok {
+		t.Logf("Expected image content, got %T\n", arrayContentMessage.Content[0])
+		t.FailNow()
+	}
+
+	base64data := base64.StdEncoding.EncodeToString(data)
+	url := fmt.Sprintf("data:image/%s;base64,%s", "png", base64data)
+
+	if imageContent.ImageUrl.Url != url {
+		t.Logf("Expected %s, but got %s.\n", url, imageContent.ImageUrl.Url)
+		t.FailNow()
+	}
+}
+
+func TestFullImageRequest(t *testing.T) {
+	request, err := getCompletionsForImage("model", 0.1, "You are an assistant", "image.png", []byte("some data"))
+	if err != nil {
+		t.Log(err)
+		t.FailNow()
+	}
+
+	jsonData, err := json.Marshal(request)
+	if err != nil {
+		t.Log(err)
+		t.FailNow()
+	}
+
+	// A nil json_schema is encoded as null, not as an empty string.
+	expectedJson := `{"model":"model","temperature":0.1,"response_format":{"type":"","json_schema":null},"messages":[{"role":"system","content":"You are an assistant"},{"role":"user","content":[{"type":"image_url","image_url":{"url":"data:image/png;base64,c29tZSBkYXRh"}}]}]}`
+
+	if string(jsonData) != expectedJson {
+		t.Logf("Expected:\n%s\n Got:\n%s\n", expectedJson, string(jsonData))
+		t.FailNow()
+	}
+}
+
+func TestResponse(t *testing.T) {
+	// The model output travels as a JSON string inside the completion envelope.
+	testResponse := `{"choices": [{"message": {"role": "assistant", "content": "{\"tags\": [\"tag1\", \"tag2\"], \"text\": [\"text1\"], \"links\": []}"}}]}`
+	buffer := bytes.NewReader([]byte(testResponse))
+
+	body := io.NopCloser(buffer)
+
+	client := OpenAiClient{
+		url:    "http://localhost:1234",
+		apiKey: "some-key",
+		Do: func(_req *http.Request) (*http.Response, error) {
+			return &http.Response{StatusCode: http.StatusOK, Body: body}, nil
+		},
+	}
+
+	info, err := client.GetImageInfo("image.png", []byte("some data"))
+	if err != nil {
+		t.Log(err)
+		t.FailNow()
+	}
+
+	if len(info.Tags) != 2 || len(info.Text) != 1 || len(info.Links) != 0 {
+		t.Logf("Some lengths are wrong.\nTags: %d\nText: %d\nLinks: %d\n", len(info.Tags), len(info.Text), len(info.Links))
+		t.FailNow()
+	}
+
+	if info.Tags[0] != "tag1" {
+		t.Log("0th tag is wrong.")
+		t.FailNow()
+	}
+
+	if info.Tags[1] != "tag2" {
+		t.Log("1th tag is wrong.")
+		t.FailNow()
+	}
+
+	if info.Text[0] != "text1" {
+		t.Log("0th text is wrong.")
+		t.FailNow()
+	}
+}
+
+func TestResponseErrorStatus(t *testing.T) {
+	client := OpenAiClient{
+		url:    "http://localhost:1234",
+		apiKey: "some-key",
+		Do: func(_req *http.Request) (*http.Response, error) {
+			body := io.NopCloser(bytes.NewReader([]byte(`{"error": {"message": "invalid api key"}}`)))
+			return &http.Response{StatusCode: http.StatusUnauthorized, Body: body}, nil
+		},
+	}
+
+	if _, err := client.GetImageInfo("image.png", []byte("some data")); err == nil {
+		t.Log("Expected an error for a non-2xx status.")
+		t.FailNow()
+	}
+}