Haystack/openai.go

281 lines
6.3 KiB
Go

package main
import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"
)
// ImageInfo is the structured information extracted from one screenshot.
// Its json tags are the property names that RESPONSE_FORMAT requires.
type ImageInfo struct {
	Tags  []string `json:"tags"`  // tags the model thinks the image is relevant to
	Text  []string `json:"text"`  // sentences the image contains
	Links []string `json:"links"` // links found in the image
}
// ResponseFormat is the "response_format" member of a chat completions
// request. JsonSchema is marshaled as-is under the "json_schema" key, so it
// should hold an already-decoded JSON value.
type ResponseFormat struct {
	Type       string `json:"type"`
	JsonSchema any    `json:"json_schema"`
}
// OpenAiRequestBody is the JSON body POSTed to the chat completions endpoint.
// OpenAiMessages is embedded, so encoding/json promotes its "messages" field
// to the top level of the encoded object.
type OpenAiRequestBody struct {
	Model          string         `json:"model"`
	Temperature    float64        `json:"temperature"`
	ResponseFormat ResponseFormat `json:"response_format"`
	OpenAiMessages
}
// OpenAiMessages is the ordered conversation sent to the model. Elements are
// encoded by encoding/json from their concrete types' fields and json tags.
type OpenAiMessages struct {
	Messages []OpenAiMessage `json:"messages"`
}
// OpenAiMessage is a single entry of a chat request's message list.
// Implementations expose their own JSON encoding via MessageToJson.
type OpenAiMessage interface {
	MessageToJson() ([]byte, error)
}
// OpenAiTextMessage is a chat message whose content is a plain string.
type OpenAiTextMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// MessageToJson encodes the message as {"role":...,"content":...}.
func (m OpenAiTextMessage) MessageToJson() ([]byte, error) {
	// TODO: Validate the `Role`.
	encoded, err := json.Marshal(m)
	return encoded, err
}
// OpenAiArrayMessage is a chat message whose content is a list of parts
// (such as image parts) rather than a single string.
type OpenAiArrayMessage struct {
	Role    string          `json:"role"`
	Content []OpenAiContent `json:"content"`
}

// MessageToJson encodes the message with encoding/json. Each content part is
// encoded from its concrete type's fields and tags; NOTE(review): the parts'
// ToJson methods are not consulted, since they are not json.Marshaler.
func (m OpenAiArrayMessage) MessageToJson() ([]byte, error) {
	encoded, err := json.Marshal(m)
	return encoded, err
}
// AddImage appends a user message whose single content part is the image
// bytes embedded as a base64 data URL ("data:image/<subtype>;base64,...").
//
// The MIME subtype comes from imageName's extension. It is lowercased, and
// "jpg" is normalized to the registered "jpeg". It returns an error when
// imageName has no usable extension, including a name that ends in a bare
// ".".
func (content *OpenAiMessages) AddImage(imageName string, image []byte) error {
	// filepath.Ext yields "" for no extension and "." for a trailing dot;
	// trimming the dot first makes both cases show up as an empty subtype.
	extension := strings.ToLower(strings.TrimPrefix(filepath.Ext(imageName), "."))
	if len(extension) == 0 {
		// TODO: could also validate for image types we support.
		return errors.New("Image does not have extension")
	}
	if extension == "jpg" {
		// "image/jpg" is not a registered media type; JPEG is "image/jpeg".
		extension = "jpeg"
	}
	encodedString := base64.StdEncoding.EncodeToString(image)
	content.Messages = append(content.Messages, OpenAiArrayMessage{
		Role: ROLE_USER,
		Content: []OpenAiContent{
			OpenAiImage{
				ImageType: IMAGE_TYPE,
				ImageUrl: ImageUrl{
					Url: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
				},
			},
		},
	})
	return nil
}
// AddSystem appends prompt as a system-role text message. A system prompt is
// only allowed as the very first message. If any message is already present,
// it returns an error and leaves Messages unchanged.
func (content *OpenAiMessages) AddSystem(prompt string) error {
	if len(content.Messages) > 0 {
		return errors.New("You can only add a system prompt at the beginning")
	}
	system := OpenAiTextMessage{Role: ROLE_SYSTEM, Content: prompt}
	content.Messages = append(content.Messages, system)
	return nil
}
// OpenAiContent is one part of an array-valued message's content list.
// Implementations expose their own JSON encoding via ToJson.
type OpenAiContent interface {
	ToJson() ([]byte, error)
}
// ImageUrl wraps the URL of an image content part; AddImage stores a base64
// data URL here.
type ImageUrl struct {
	Url string `json:"url"`
}
// OpenAiImage is an image content part: {"type": ..., "image_url": {...}}.
type OpenAiImage struct {
	ImageType string   `json:"type"`
	ImageUrl  ImageUrl `json:"image_url"`
}

// ToJson encodes the image part with its "type" forced to IMAGE_TYPE,
// whatever ImageType currently holds. The receiver is not modified.
func (imageMessage OpenAiImage) ToJson() ([]byte, error) {
	normalized := OpenAiImage{
		ImageType: IMAGE_TYPE,
		ImageUrl:  imageMessage.ImageUrl,
	}
	return json.Marshal(normalized)
}
// OpenAiClient carries what is needed to call the chat completions endpoint:
// the endpoint URL, API key, and system prompt. Do performs the HTTP round
// trip. It is a field rather than a method, presumably so callers or tests
// can substitute their own transport (TODO confirm).
type OpenAiClient struct {
	url          string
	apiKey       string
	systemPrompt string
	// NOTE(review): responseFormat is not referenced by the visible code.
	responseFormat string

	Do func(req *http.Request) (*http.Response, error)
}
// func (client OpenAiClient) Do(req *http.Request) () {
// httpClient := http.Client{}
// return httpClient.Do(req)
// }
// Environment variable name, chat roles, and content-part type used when
// building chat completions requests.
const (
	OPENAI_API_KEY = "OPENAI_API_KEY"
	ROLE_USER      = "user"
	ROLE_SYSTEM    = "system"
	IMAGE_TYPE     = "image_url"
)

// PROMPT is the system prompt OpenAiClient sends ahead of each screenshot.
// TODO: extract to text file probably
const PROMPT = `
You are an image information extractor. The user will provide you with screenshots and your job is to extract any relevant links and text
that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags.
This system is part of a bookmark manager, whose main goal is to allow the user to search through various screenshots.
`

// RESPONSE_FORMAT is the strict JSON schema the model's reply must follow.
// Its required properties (tags, text, links) must stay in sync with the
// json tags on ImageInfo.
const RESPONSE_FORMAT = `
{
"name": "schema_description",
"schema": {
"type": "object",
"properties": {
"tags": {
"type": "array",
"description": "A list of tags you think the image is relevant to.",
"items": {
"type": "string"
}
},
"text": {
"type": "array",
"description": "A list of sentences the image contains.",
"items": {
"type": "string"
}
},
"links": {
"type": "array",
"description": "A list of all the links you can find in the image.",
"items": {
"type": "string"
}
}
},
"required": [
"tags",
"text",
"links"
],
"additionalProperties": false
},
"strict": true
}
`
// CreateOpenAiClient builds an OpenAiClient for the chat completions endpoint.
// It reads the API key from the OPENAI_API_KEY environment variable and
// returns an error if that variable is unset or empty.
//
// Requests go through a single shared http.Client with an overall timeout.
// Without a timeout, a stalled connection would block GetImageInfo forever.
func CreateOpenAiClient() (OpenAiClient, error) {
	// Covers connecting, any redirects, waiting on the model, and reading the
	// body (per http.Client.Timeout). Vision requests can be slow, so the
	// value is generous.
	const requestTimeout = 2 * time.Minute

	apiKey := os.Getenv(OPENAI_API_KEY)
	if len(apiKey) == 0 {
		return OpenAiClient{}, errors.New(OPENAI_API_KEY + " was not found.")
	}
	httpClient := &http.Client{Timeout: requestTimeout}
	return OpenAiClient{
		apiKey:       apiKey,
		url:          "https://api.openai.com/v1/chat/completions",
		systemPrompt: PROMPT,
		Do:           httpClient.Do,
	}, nil
}
// getRequest wraps body in a POST to the client's endpoint, with bearer-token
// authorization and a JSON content type. If http.NewRequest fails, its
// result is returned unchanged together with the error.
func (client OpenAiClient) getRequest(body []byte) (*http.Request, error) {
	req, err := http.NewRequest(http.MethodPost, client.url, bytes.NewBuffer(body))
	if err == nil {
		req.Header.Add("Authorization", "Bearer "+client.apiKey)
		req.Header.Add("Content-Type", "application/json")
	}
	return req, err
}
// getCompletionsForImage assembles a chat completions request body for a
// single image. The body carries the given model and temperature, a leading
// system message holding prompt, and a user message with the image. Building
// stops at the first failing step. The partially built body is returned
// together with that error.
func getCompletionsForImage(model string, temperature float64, prompt, imageName string, imageData []byte) (OpenAiRequestBody, error) {
	var body OpenAiRequestBody
	body.Model = model
	body.Temperature = temperature

	// TODO: Add build pattern here that deals with errors in some internal state?
	// The steps run in order and mutate body through closures.
	steps := []func() error{
		func() error { return body.AddSystem(prompt) },
		func() error { return body.AddImage(imageName, imageData) },
	}
	for _, step := range steps {
		if err := step(); err != nil {
			return body, err
		}
	}
	return body, nil
}
func (client OpenAiClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
aiRequest, err := getCompletionsForImage("gpt-4o-mini", 1.0, client.systemPrompt, imageName, imageData)
if err != nil {
return ImageInfo{}, err
}
var jsonSchema any
err = json.Unmarshal([]byte(RESPONSE_FORMAT), &jsonSchema)
if err != nil {
return ImageInfo{}, err
}
aiRequest.ResponseFormat = ResponseFormat{
Type: "json_schema",
JsonSchema: jsonSchema,
}
jsonAiRequest, err := json.Marshal(aiRequest)
if err != nil {
return ImageInfo{}, err
}
request, err := client.getRequest(jsonAiRequest)
if err != nil {
return ImageInfo{}, err
}
resp, err := client.Do(request)
if err != nil {
return ImageInfo{}, err
}
response, err := io.ReadAll(resp.Body)
if err != nil {
return ImageInfo{}, err
}
info := ImageInfo{}
err = json.Unmarshal(response, &info)
if err != nil {
return ImageInfo{}, err
}
log.Println(string(response))
return info, nil
}