package main import ( "bytes" "encoding/base64" "encoding/json" "fmt" "io" "net/http" "testing" ) func TestMessageBuilder(t *testing.T) { content := OpenAiMessages{} err := content.AddSystem("Some prompt") if err != nil { t.Log(err) t.FailNow() } if len(content.Messages) != 1 { t.Logf("Expected length 1, got %d.\n", len(content.Messages)) t.FailNow() } } func TestMessageBuilderImage(t *testing.T) { content := OpenAiMessages{} prompt := "some prompt" imageTitle := "image.png" data := []byte("some data") content.AddSystem(prompt) content.AddImage(imageTitle, data) if len(content.Messages) != 2 { t.Logf("Expected length 2, got %d.\n", len(content.Messages)) t.FailNow() } promptMessage, ok := content.Messages[0].(OpenAiTextMessage) if !ok { t.Logf("Expected text content message, got %T\n", content.Messages[0]) t.FailNow() } if promptMessage.Role != ROLE_SYSTEM { t.Log("Prompt message role is incorrect.") t.FailNow() } if promptMessage.Content != prompt { t.Log("Prompt message content is incorrect.") t.FailNow() } arrayContentMessage, ok := content.Messages[1].(OpenAiArrayMessage) if !ok { t.Logf("Expected text content message, got %T\n", content.Messages[1]) t.FailNow() } if arrayContentMessage.Role != ROLE_USER { t.Log("Array content message role is incorrect.") t.FailNow() } if len(arrayContentMessage.Content) != 1 { t.Logf("Expected length 1, got %d.\n", len(arrayContentMessage.Content)) t.FailNow() } imageContent, ok := arrayContentMessage.Content[0].(OpenAiImage) if !ok { t.Logf("Expected text content message, got %T\n", arrayContentMessage.Content[0]) t.FailNow() } base64data := base64.StdEncoding.EncodeToString(data) url := fmt.Sprintf("data:image/%s;base64,%s", "png", base64data) if imageContent.ImageUrl != url { t.Logf("Expected %s, but got %s.\n", url, imageContent.ImageUrl) t.FailNow() } } func TestFullImageRequest(t *testing.T) { request, err := getCompletionsForImage("model", 0.1, "You are an assistant", "image.png", "", []byte("some data")) if err != nil { t.Log(request) t.FailNow() } jsonData, err := json.Marshal(request) if err != nil { t.Log(err) t.FailNow() } expectedJson := `{"model":"model","temperature":0.1,"response_format":{"type":"json_schema","json_schema":""},"messages":[{"role":"system","content":"You are an assistant"},{"role":"user","content":[{"type":"image_url","image_url":{"url":"data:image/png;base64,c29tZSBkYXRh"}}]}]}` if string(jsonData) != expectedJson { t.Logf("Expected:\n%s\n Got:\n%s\n", expectedJson, string(jsonData)) t.FailNow() } } func TestResponse(t *testing.T) { testResponse := `{"tags": ["tag1", "tag2"], "text": ["text1"], "links": []}` buffer := bytes.NewReader([]byte(testResponse)) body := io.NopCloser(buffer) client := OpenAiClient{ url: "http://localhost:1234", apiKey: "some-key", Do: func(_req *http.Request) (*http.Response, error) { return &http.Response{Body: body}, nil }, } info, err := client.GetImageInfo("image.png", []byte("some data")) if err != nil { t.Log(err) t.FailNow() } if len(info.Tags) != 2 || len(info.Text) != 1 || len(info.Links) != 0 { t.Logf("Some lengths are wrong.\nTags: %d\nText: %d\nLinks: %d\n", len(info.Tags), len(info.Text), len(info.Links)) t.FailNow() } if info.Tags[0] != "tag1" { t.Log("0th tag is wrong.") t.FailNow() } if info.Tags[1] != "tag2" { t.Log("1th tag is wrong.") t.FailNow() } if info.Text[0] != "text1" { t.Log("0th text is wrong.") t.FailNow() } } func TestResponseParsing(t *testing.T) { response := `{ "id": "chatcmpl-B4XgiHcd7A2nyK7eyARdggSvfFuWQ", "object": "chat.completion", "created": 1740422508, "model": "gpt-4o-mini-2024-07-18", "choices": [ { "index": 0, "message": { "role": "assistant", "content": "{\"links\":[\"link\"],\"tags\":[\"tag\"],\"text\":[\"text\"]}", "refusal": null }, "logprobs": null, "finish_reason": "stop" } ], "usage": { "prompt_tokens": 775, "completion_tokens": 33, "total_tokens": 808, "prompt_tokens_details": { "cached_tokens": 0, "audio_tokens": 0 }, "completion_tokens_details": { "reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0 } }, "service_tier": "default", "system_fingerprint": "fp_7fcd609668" }` imageParsed, err := parseOpenAiResponse([]byte(response)) if err != nil { t.Log(err) t.FailNow() } if len(imageParsed.Links) != 1 || imageParsed.Links[0] != "link" { t.Log(imageParsed) t.Log("Should have one link called 'link'.") t.FailNow() } if len(imageParsed.Tags) != 1 || imageParsed.Tags[0] != "tag" { t.Log(imageParsed) t.Log("Should have one tag called 'tag'.") t.FailNow() } if len(imageParsed.Text) != 1 || imageParsed.Text[0] != "text" { t.Log(imageParsed) t.Log("Should have one text called 'text'.") t.FailNow() } }