feat: working e2e system making requests to open ai
This commit is contained in:
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
.env
|
||||||
|
db
|
||||||
|
screenmark
|
2
go.mod
2
go.mod
@ -1,3 +1,5 @@
|
|||||||
module screenmark/screenmark
|
module screenmark/screenmark
|
||||||
|
|
||||||
go 1.24.0
|
go 1.24.0
|
||||||
|
|
||||||
|
require github.com/joho/godotenv v1.5.1 // indirect
|
||||||
|
2
go.sum
Normal file
2
go.sum
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||||
|
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
22
main.go
22
main.go
@ -6,9 +6,21 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/joho/godotenv"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
err := godotenv.Load()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
openAiClient, err := CreateOpenAiClient()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
mux.HandleFunc("OPTIONS /image/{name}", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("OPTIONS /image/{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -55,6 +67,16 @@ func main() {
|
|||||||
fmt.Fprintf(w, "Couldnt write the image")
|
fmt.Fprintf(w, "Couldnt write the image")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resp, err := openAiClient.GetImageInfo(imageName, image)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
fmt.Fprintf(w, "some shit happened")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("%+v\n", resp)
|
||||||
})
|
})
|
||||||
|
|
||||||
log.Println("Listening and serving.")
|
log.Println("Listening and serving.")
|
||||||
|
280
openai.go
Normal file
280
openai.go
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ImageInfo struct {
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
Text []string `json:"text"`
|
||||||
|
Links []string `json:"links"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResponseFormat struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
JsonSchema any `json:"json_schema"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAiRequestBody struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Temperature float64 `json:"temperature"`
|
||||||
|
ResponseFormat ResponseFormat `json:"response_format"`
|
||||||
|
|
||||||
|
OpenAiMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAiMessages struct {
|
||||||
|
Messages []OpenAiMessage `json:"messages"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAiMessage interface {
|
||||||
|
MessageToJson() ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAiTextMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (textContent OpenAiTextMessage) MessageToJson() ([]byte, error) {
|
||||||
|
// TODO: Validate the `Role`.
|
||||||
|
return json.Marshal(textContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAiArrayMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content []OpenAiContent `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (arrayContent OpenAiArrayMessage) MessageToJson() ([]byte, error) {
|
||||||
|
return json.Marshal(arrayContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (content *OpenAiMessages) AddImage(imageName string, image []byte) error {
|
||||||
|
extension := filepath.Ext(imageName)
|
||||||
|
if len(extension) == 0 {
|
||||||
|
// TODO: could also validate for image types we support.
|
||||||
|
return errors.New("Image does not have extension")
|
||||||
|
}
|
||||||
|
|
||||||
|
extension = extension[1:]
|
||||||
|
|
||||||
|
encodedString := base64.StdEncoding.EncodeToString(image)
|
||||||
|
|
||||||
|
arrayMessage := OpenAiArrayMessage{Role: ROLE_USER, Content: make([]OpenAiContent, 1)}
|
||||||
|
arrayMessage.Content[0] = OpenAiImage{
|
||||||
|
ImageType: IMAGE_TYPE,
|
||||||
|
ImageUrl: ImageUrl{
|
||||||
|
Url: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
content.Messages = append(content.Messages, arrayMessage)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (content *OpenAiMessages) AddSystem(prompt string) error {
|
||||||
|
if len(content.Messages) != 0 {
|
||||||
|
return errors.New("You can only add a system prompt at the beginning")
|
||||||
|
}
|
||||||
|
|
||||||
|
content.Messages = append(content.Messages, OpenAiTextMessage{
|
||||||
|
Role: ROLE_SYSTEM,
|
||||||
|
Content: prompt,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAiContent interface {
|
||||||
|
ToJson() ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageUrl struct {
|
||||||
|
Url string `json:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAiImage struct {
|
||||||
|
ImageType string `json:"type"`
|
||||||
|
ImageUrl ImageUrl `json:"image_url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (imageMessage OpenAiImage) ToJson() ([]byte, error) {
|
||||||
|
imageMessage.ImageType = IMAGE_TYPE
|
||||||
|
return json.Marshal(imageMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAiClient struct {
|
||||||
|
url string
|
||||||
|
apiKey string
|
||||||
|
systemPrompt string
|
||||||
|
responseFormat string
|
||||||
|
|
||||||
|
Do func(req *http.Request) (*http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// func (client OpenAiClient) Do(req *http.Request) () {
|
||||||
|
// httpClient := http.Client{}
|
||||||
|
// return httpClient.Do(req)
|
||||||
|
// }
|
||||||
|
|
||||||
|
const OPENAI_API_KEY = "OPENAI_API_KEY"
|
||||||
|
const ROLE_USER = "user"
|
||||||
|
const ROLE_SYSTEM = "system"
|
||||||
|
const IMAGE_TYPE = "image_url"
|
||||||
|
|
||||||
|
// TODO: extract to text file probably
|
||||||
|
const PROMPT = `
|
||||||
|
You are an image information extractor. The user will provide you with screenshots and your job is to extract any relevant links and text
|
||||||
|
that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags.
|
||||||
|
|
||||||
|
This system is part of a bookmark manager, who's main goal is to allow the user to search through various screenshots.
|
||||||
|
`
|
||||||
|
|
||||||
|
const RESPONSE_FORMAT = `
|
||||||
|
{
|
||||||
|
"name": "schema_description",
|
||||||
|
"schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"tags": {
|
||||||
|
"type": "array",
|
||||||
|
"description": "A list of tags you think the image is relevant to.",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"text": {
|
||||||
|
"type": "array",
|
||||||
|
"description": "A list of sentences the image contains.",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"links": {
|
||||||
|
"type": "array",
|
||||||
|
"description": "A list of all the links you can find in the image.",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"tags",
|
||||||
|
"text",
|
||||||
|
"links"
|
||||||
|
],
|
||||||
|
"additionalProperties": false
|
||||||
|
},
|
||||||
|
"strict": true
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
func CreateOpenAiClient() (OpenAiClient, error) {
|
||||||
|
apiKey := os.Getenv(OPENAI_API_KEY)
|
||||||
|
|
||||||
|
if len(apiKey) == 0 {
|
||||||
|
return OpenAiClient{}, errors.New(OPENAI_API_KEY + " was not found.")
|
||||||
|
}
|
||||||
|
|
||||||
|
return OpenAiClient{
|
||||||
|
apiKey: apiKey,
|
||||||
|
url: "https://api.openai.com/v1/chat/completions",
|
||||||
|
systemPrompt: PROMPT,
|
||||||
|
Do: func(req *http.Request) (*http.Response, error) {
|
||||||
|
client := &http.Client{}
|
||||||
|
return client.Do(req)
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client OpenAiClient) getRequest(body []byte) (*http.Request, error) {
|
||||||
|
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body))
|
||||||
|
if err != nil {
|
||||||
|
return req, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Add("Authorization", "Bearer "+client.apiKey)
|
||||||
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCompletionsForImage(model string, temperature float64, prompt, imageName string, imageData []byte) (OpenAiRequestBody, error) {
|
||||||
|
request := OpenAiRequestBody{
|
||||||
|
Model: model,
|
||||||
|
Temperature: temperature,
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Add build pattern here that deals with errors in some internal state?
|
||||||
|
// I want a monad!!!
|
||||||
|
err := request.AddSystem(prompt)
|
||||||
|
if err != nil {
|
||||||
|
return request, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = request.AddImage(imageName, imageData)
|
||||||
|
if err != nil {
|
||||||
|
return request, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return request, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client OpenAiClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
|
||||||
|
aiRequest, err := getCompletionsForImage("gpt-4o-mini", 1.0, client.systemPrompt, imageName, imageData)
|
||||||
|
if err != nil {
|
||||||
|
return ImageInfo{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var jsonSchema any
|
||||||
|
err = json.Unmarshal([]byte(RESPONSE_FORMAT), &jsonSchema)
|
||||||
|
if err != nil {
|
||||||
|
return ImageInfo{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
aiRequest.ResponseFormat = ResponseFormat{
|
||||||
|
Type: "json_schema",
|
||||||
|
JsonSchema: jsonSchema,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonAiRequest, err := json.Marshal(aiRequest)
|
||||||
|
if err != nil {
|
||||||
|
return ImageInfo{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
request, err := client.getRequest(jsonAiRequest)
|
||||||
|
if err != nil {
|
||||||
|
return ImageInfo{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(request)
|
||||||
|
if err != nil {
|
||||||
|
return ImageInfo{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return ImageInfo{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
info := ImageInfo{}
|
||||||
|
err = json.Unmarshal(response, &info)
|
||||||
|
if err != nil {
|
||||||
|
return ImageInfo{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println(string(response))
|
||||||
|
|
||||||
|
return info, nil
|
||||||
|
}
|
151
openai_test.go
Normal file
151
openai_test.go
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMessageBuilder(t *testing.T) {
|
||||||
|
content := OpenAiMessages{}
|
||||||
|
|
||||||
|
err := content.AddSystem("Some prompt")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Log(err)
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(content.Messages) != 1 {
|
||||||
|
t.Logf("Expected length 1, got %d.\n", len(content.Messages))
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageBuilderImage(t *testing.T) {
|
||||||
|
content := OpenAiMessages{}
|
||||||
|
|
||||||
|
prompt := "some prompt"
|
||||||
|
imageTitle := "image.png"
|
||||||
|
data := []byte("some data")
|
||||||
|
|
||||||
|
content.AddSystem(prompt)
|
||||||
|
content.AddImage(imageTitle, data)
|
||||||
|
|
||||||
|
if len(content.Messages) != 2 {
|
||||||
|
t.Logf("Expected length 2, got %d.\n", len(content.Messages))
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
promptMessage, ok := content.Messages[0].(OpenAiTextMessage)
|
||||||
|
if !ok {
|
||||||
|
t.Logf("Expected text content message, got %T\n", content.Messages[0])
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
if promptMessage.Role != ROLE_SYSTEM {
|
||||||
|
t.Log("Prompt message role is incorrect.")
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
if promptMessage.Content != prompt {
|
||||||
|
t.Log("Prompt message content is incorrect.")
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
arrayContentMessage, ok := content.Messages[1].(OpenAiArrayMessage)
|
||||||
|
if !ok {
|
||||||
|
t.Logf("Expected text content message, got %T\n", content.Messages[1])
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
if arrayContentMessage.Role != ROLE_USER {
|
||||||
|
t.Log("Array content message role is incorrect.")
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(arrayContentMessage.Content) != 1 {
|
||||||
|
t.Logf("Expected length 1, got %d.\n", len(arrayContentMessage.Content))
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
imageContent, ok := arrayContentMessage.Content[0].(OpenAiImage)
|
||||||
|
if !ok {
|
||||||
|
t.Logf("Expected text content message, got %T\n", arrayContentMessage.Content[0])
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
base64data := base64.StdEncoding.EncodeToString(data)
|
||||||
|
url := fmt.Sprintf("data:image/%s;base64,%s", "png", base64data)
|
||||||
|
|
||||||
|
if imageContent.ImageUrl.Url != url {
|
||||||
|
t.Logf("Expected %s, but got %s.\n", url, imageContent.ImageUrl.Url)
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFullImageRequest(t *testing.T) {
|
||||||
|
request, err := getCompletionsForImage("model", 0.1, "You are an assistant", "image.png", []byte("some data"))
|
||||||
|
if err != nil {
|
||||||
|
t.Log(request)
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
t.Log(err)
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedJson := `{"model":"model","temperature":0.1,"response_format":{"type":"","json_schema":""},"messages":[{"role":"system","content":"You are an assistant"},{"role":"user","content":[{"type":"image_url","image_url":{"url":""}}]}]}`
|
||||||
|
|
||||||
|
if string(jsonData) != expectedJson {
|
||||||
|
t.Logf("Expected:\n%s\n Got:\n%s\n", expectedJson, string(jsonData))
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponse(t *testing.T) {
|
||||||
|
testResponse := `{"tags": ["tag1", "tag2"], "text": ["text1"], "links": []}`
|
||||||
|
buffer := bytes.NewReader([]byte(testResponse))
|
||||||
|
|
||||||
|
body := io.NopCloser(buffer)
|
||||||
|
|
||||||
|
client := OpenAiClient{
|
||||||
|
url: "http://localhost:1234",
|
||||||
|
apiKey: "some-key",
|
||||||
|
Do: func(_req *http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{Body: body}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := client.GetImageInfo("image.png", []byte("some data"))
|
||||||
|
if err != nil {
|
||||||
|
t.Log(err)
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(info.Tags) != 2 || len(info.Text) != 1 || len(info.Links) != 0 {
|
||||||
|
t.Logf("Some lengths are wrong.\nTags: %d\nText: %d\nLinks: %d\n", len(info.Tags), len(info.Text), len(info.Links))
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.Tags[0] != "tag1" {
|
||||||
|
t.Log("0th tag is wrong.")
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.Tags[1] != "tag2" {
|
||||||
|
t.Log("1th tag is wrong.")
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.Text[0] != "text1" {
|
||||||
|
t.Log("0th text is wrong.")
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user