214 lines
4.9 KiB
Go
214 lines
4.9 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"testing"
|
|
)
|
|
|
|
func TestMessageBuilder(t *testing.T) {
|
|
content := OpenAiMessages{}
|
|
|
|
err := content.AddSystem("Some prompt")
|
|
|
|
if err != nil {
|
|
t.Log(err)
|
|
t.FailNow()
|
|
}
|
|
|
|
if len(content.Messages) != 1 {
|
|
t.Logf("Expected length 1, got %d.\n", len(content.Messages))
|
|
t.FailNow()
|
|
}
|
|
}
|
|
|
|
// TestMessageBuilderImage adds a system prompt followed by an image and checks
// the resulting message list:
//   - Messages[0] is an OpenAiTextMessage with role ROLE_SYSTEM and the prompt
//     as content;
//   - Messages[1] is an OpenAiArrayMessage with role ROLE_USER holding exactly
//     one OpenAiImage, whose URL is a base64 data URL of the raw image bytes.
func TestMessageBuilderImage(t *testing.T) {
	content := OpenAiMessages{}

	prompt := "some prompt"
	imageTitle := "image.png"
	data := []byte("some data")

	// Check the error so a failure surfaces here rather than as a confusing
	// length mismatch further down.
	if err := content.AddSystem(prompt); err != nil {
		t.Fatal(err)
	}
	// NOTE(review): AddImage's signature is not visible from this file; if it
	// returns an error, check it here as well.
	content.AddImage(imageTitle, data)

	if len(content.Messages) != 2 {
		t.Fatalf("Expected length 2, got %d.", len(content.Messages))
	}

	promptMessage, ok := content.Messages[0].(OpenAiTextMessage)
	if !ok {
		t.Fatalf("Expected text content message, got %T", content.Messages[0])
	}
	if promptMessage.Role != ROLE_SYSTEM {
		t.Fatalf("Prompt message role is incorrect: got %v.", promptMessage.Role)
	}
	if promptMessage.Content != prompt {
		t.Fatalf("Prompt message content is incorrect: got %q, want %q.", promptMessage.Content, prompt)
	}

	arrayContentMessage, ok := content.Messages[1].(OpenAiArrayMessage)
	if !ok {
		t.Fatalf("Expected array content message, got %T", content.Messages[1])
	}
	if arrayContentMessage.Role != ROLE_USER {
		t.Fatalf("Array content message role is incorrect: got %v.", arrayContentMessage.Role)
	}
	if len(arrayContentMessage.Content) != 1 {
		t.Fatalf("Expected length 1, got %d.", len(arrayContentMessage.Content))
	}

	imageContent, ok := arrayContentMessage.Content[0].(OpenAiImage)
	if !ok {
		t.Fatalf("Expected image content, got %T", arrayContentMessage.Content[0])
	}

	// Expected URL: a base64 data URL with MIME type image/png, matching the
	// "image.png" title (presumably AddImage derives the subtype from the
	// extension — confirm).
	base64data := base64.StdEncoding.EncodeToString(data)
	url := fmt.Sprintf("data:image/%s;base64,%s", "png", base64data)
	if imageContent.ImageUrl.Url != url {
		t.Fatalf("Expected %s, but got %s.", url, imageContent.ImageUrl.Url)
	}
}
|
|
|
|
// TestFullImageRequest builds a complete completions request for an image with
// getCompletionsForImage, marshals it to JSON and compares the result
// byte-for-byte against a fixed expected payload.
func TestFullImageRequest(t *testing.T) {
	request, err := getCompletionsForImage("model", 0.1, "You are an assistant", "image.png", "", []byte("some data"))
	if err != nil {
		// Report the error itself; the returned request carries no useful
		// information when err is non-nil.
		t.Fatal(err)
	}

	jsonData, err := json.Marshal(request)
	if err != nil {
		t.Fatal(err)
	}

	// NOTE(review): the expected payload has an empty "json_schema" (presumably
	// from the "" argument above) and an empty image_url.url even though image
	// bytes are passed in, whereas TestMessageBuilderImage expects AddImage to
	// produce a base64 data URL. Confirm the empty url is intended.
	expectedJSON := `{"model":"model","temperature":0.1,"response_format":{"type":"json_schema","json_schema":""},"messages":[{"role":"system","content":"You are an assistant"},{"role":"user","content":[{"type":"image_url","image_url":{"url":""}}]}]}`

	if got := string(jsonData); got != expectedJSON {
		t.Fatalf("Expected:\n%s\n Got:\n%s", expectedJSON, got)
	}
}
|
|
|
|
// TestResponse exercises OpenAiClient.GetImageInfo against a stubbed Do hook
// that returns a canned body, so no network traffic happens. It checks that the
// "tags", "text" and "links" arrays of that body end up in the returned info.
func TestResponse(t *testing.T) {
	testResponse := `{"tags": ["tag1", "tag2"], "text": ["text1"], "links": []}`
	body := io.NopCloser(bytes.NewReader([]byte(testResponse)))

	client := OpenAiClient{
		url:    "http://localhost:1234",
		apiKey: "some-key",
		// Stub transport: ignore the outgoing request and hand back a
		// well-formed 200 response carrying testResponse.
		Do: func(_ *http.Request) (*http.Response, error) {
			return &http.Response{StatusCode: http.StatusOK, Body: body}, nil
		},
	}

	info, err := client.GetImageInfo("image.png", []byte("some data"))
	if err != nil {
		t.Fatal(err)
	}

	// Length check first so the index accesses below cannot panic.
	if len(info.Tags) != 2 || len(info.Text) != 1 || len(info.Links) != 0 {
		t.Fatalf("Some lengths are wrong.\nTags: %d\nText: %d\nLinks: %d", len(info.Tags), len(info.Text), len(info.Links))
	}

	if info.Tags[0] != "tag1" {
		t.Fatalf("0th tag is wrong: got %q, want %q.", info.Tags[0], "tag1")
	}
	if info.Tags[1] != "tag2" {
		t.Fatalf("1st tag is wrong: got %q, want %q.", info.Tags[1], "tag2")
	}
	if info.Text[0] != "text1" {
		t.Fatalf("0th text is wrong: got %q, want %q.", info.Text[0], "text1")
	}
}
|
|
|
|
func TestResponseParsing(t *testing.T) {
|
|
response := `{
|
|
"id": "chatcmpl-B4XgiHcd7A2nyK7eyARdggSvfFuWQ",
|
|
"object": "chat.completion",
|
|
"created": 1740422508,
|
|
"model": "gpt-4o-mini-2024-07-18",
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"message": {
|
|
"role": "assistant",
|
|
"content": "{\"links\":[\"link\"],\"tags\":[\"tag\"],\"text\":[\"text\"]}",
|
|
"refusal": null
|
|
},
|
|
"logprobs": null,
|
|
"finish_reason": "stop"
|
|
}
|
|
],
|
|
"usage": {
|
|
"prompt_tokens": 775,
|
|
"completion_tokens": 33,
|
|
"total_tokens": 808,
|
|
"prompt_tokens_details": {
|
|
"cached_tokens": 0,
|
|
"audio_tokens": 0
|
|
},
|
|
"completion_tokens_details": {
|
|
"reasoning_tokens": 0,
|
|
"audio_tokens": 0,
|
|
"accepted_prediction_tokens": 0,
|
|
"rejected_prediction_tokens": 0
|
|
}
|
|
},
|
|
"service_tier": "default",
|
|
"system_fingerprint": "fp_7fcd609668"
|
|
}`
|
|
|
|
imageParsed, err := parseOpenAiResponse([]byte(response))
|
|
if err != nil {
|
|
t.Log(err)
|
|
t.FailNow()
|
|
}
|
|
|
|
if len(imageParsed.Links) != 1 || imageParsed.Links[0] != "link" {
|
|
t.Log(imageParsed)
|
|
t.Log("Should have one link called 'link'.")
|
|
t.FailNow()
|
|
}
|
|
|
|
if len(imageParsed.Tags) != 1 || imageParsed.Tags[0] != "tag" {
|
|
t.Log(imageParsed)
|
|
t.Log("Should have one tag called 'tag'.")
|
|
t.FailNow()
|
|
}
|
|
|
|
if len(imageParsed.Text) != 1 || imageParsed.Text[0] != "text" {
|
|
t.Log(imageParsed)
|
|
t.Log("Should have one text called 'text'.")
|
|
t.FailNow()
|
|
}
|
|
}
|