diff --git a/backend/openai.go b/backend/openai.go index 1bcd3aa..91ee9c4 100644 --- a/backend/openai.go +++ b/backend/openai.go @@ -210,10 +210,14 @@ func (client OpenAiClient) getRequest(body []byte) (*http.Request, error) { return req, nil } -func getCompletionsForImage(model string, temperature float64, prompt, imageName string, imageData []byte) (OpenAiRequestBody, error) { +func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (OpenAiRequestBody, error) { request := OpenAiRequestBody{ Model: model, Temperature: temperature, + ResponseFormat: ResponseFormat{ + Type: "json_schema", + JsonSchema: jsonSchema, + }, } // TODO: Add build pattern here that deals with errors in some internal state? @@ -231,8 +235,84 @@ func getCompletionsForImage(model string, temperature float64, prompt, imageName return request, nil } +// { +// "id": "chatcmpl-B4XgiHcd7A2nyK7eyARdggSvfFuWQ", +// "object": "chat.completion", +// "created": 1740422508, +// "model": "gpt-4o-mini-2024-07-18", +// "choices": [ +// { +// "index": 0, +// "message": { +// "role": "assistant", +// "content": "{\"links\":[],\"tags\":[\"Git\",\"Programming\",\"Humor\",\"Meme\"],\"text\":[\"GIT FLOW\",\"JUST USE MAIN\",\"JUST USE MAIN\"]}", +// "refusal": null +// }, +// "logprobs": null, +// "finish_reason": "stop" +// } +// ], +// "usage": { +// "prompt_tokens": 775, +// "completion_tokens": 33, +// "total_tokens": 808, +// "prompt_tokens_details": { +// "cached_tokens": 0, +// "audio_tokens": 0 +// }, +// "completion_tokens_details": { +// "reasoning_tokens": 0, +// "audio_tokens": 0, +// "accepted_prediction_tokens": 0, +// "rejected_prediction_tokens": 0 +// } +// }, +// "service_tier": "default", +// "system_fingerprint": "fp_7fcd609668" +// } + +type ResponseChoiceMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ResponseChoice struct { + Index int `json:"index"` + Message ResponseChoiceMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type OpenAiResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Choices []ResponseChoice `json:"choices"` + Created int `json:"created"` +} + +// TODO: add usage parsing +func parseOpenAiResponse(jsonResponse []byte) (ImageInfo, error) { + response := OpenAiResponse{} + + err := json.Unmarshal(jsonResponse, &response) + if err != nil { + return ImageInfo{}, err + } + + if len(response.Choices) != 1 { + return ImageInfo{}, errors.New("Expected exactly one choice.") + } + + imageInfo := ImageInfo{} + err = json.Unmarshal([]byte(response.Choices[0].Message.Content), &imageInfo) + if err != nil { + return ImageInfo{}, errors.New("Could not parse content into image type.") + } + + return imageInfo, nil +} + func (client OpenAiClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) { - aiRequest, err := getCompletionsForImage("gpt-4o-mini", 1.0, client.systemPrompt, imageName, imageData) + aiRequest, err := getCompletionsForImage("gpt-4o-mini", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData) if err != nil { return ImageInfo{}, err } diff --git a/backend/openai_test.go b/backend/openai_test.go index 3a2f663..945f507 100644 --- a/backend/openai_test.go +++ b/backend/openai_test.go @@ -89,7 +89,7 @@ func TestMessageBuilderImage(t *testing.T) { } func TestFullImageRequest(t *testing.T) { - request, err := getCompletionsForImage("model", 0.1, "You are an assistant", "image.png", []byte("some data")) + request, err := getCompletionsForImage("model", 0.1, "You are an assistant", "image.png", "", []byte("some data")) if err != nil { t.Log(request) t.FailNow() @@ -101,7 +101,7 @@ func TestFullImageRequest(t *testing.T) { t.FailNow() } - expectedJson := `{"model":"model","temperature":0.1,"response_format":{"type":"","json_schema":""},"messages":[{"role":"system","content":"You are an assistant"},{"role":"user","content":[{"type":"image_url","image_url":{"url":""}}]}]}` + expectedJson := `{"model":"model","temperature":0.1,"response_format":{"type":"json_schema","json_schema":""},"messages":[{"role":"system","content":"You are an assistant"},{"role":"user","content":[{"type":"image_url","image_url":{"url":""}}]}]}` if string(jsonData) != expectedJson { t.Logf("Expected:\n%s\n Got:\n%s\n", expectedJson, string(jsonData)) @@ -149,3 +149,65 @@ func TestResponse(t *testing.T) { t.FailNow() } } + +func TestResponseParsing(t *testing.T) { + response := `{ + "id": "chatcmpl-B4XgiHcd7A2nyK7eyARdggSvfFuWQ", + "object": "chat.completion", + "created": 1740422508, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "{\"links\":[\"link\"],\"tags\":[\"tag\"],\"text\":[\"text\"]}", + "refusal": null + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 775, + "completion_tokens": 33, + "total_tokens": 808, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "service_tier": "default", + "system_fingerprint": "fp_7fcd609668" +}` + + imageParsed, err := parseOpenAiResponse([]byte(response)) + if err != nil { + t.Log(err) + t.FailNow() + } + + if len(imageParsed.Links) != 1 || imageParsed.Links[0] != "link" { + t.Log(imageParsed) + t.Log("Should have one link called 'link'.") + t.FailNow() + } + + if len(imageParsed.Tags) != 1 || imageParsed.Tags[0] != "tag" { + t.Log(imageParsed) + t.Log("Should have one tag called 'tag'.") + t.FailNow() + } + + if len(imageParsed.Text) != 1 || imageParsed.Text[0] != "text" { + t.Log(imageParsed) + t.Log("Should have one text called 'text'.") + t.FailNow() + } +}