291 lines
6.5 KiB
Go

package client
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"screenmark/screenmark/.gen/haystack/haystack/model"
)
// ImageInfo is the value returned by AiClient.GetImageInfo. Every field is a
// slice and is (de)serialized under the lower-case JSON key given in its tag.
// Locations and Events reuse the generated model types.
type ImageInfo struct {
	Tags      []string          `json:"tags"`
	Text      []string          `json:"text"`
	Links     []string          `json:"links"`
	Locations []model.Locations `json:"locations"`
	Events    []model.Events    `json:"events"`
}
// ResponseFormat is the value serialized under the "response_format" key of
// AgentRequestBody. JsonSchema is left as `any`, so whatever value the caller
// stores there is marshalled as-is under "json_schema".
type ResponseFormat struct {
	Type       string `json:"type"`
	JsonSchema any    `json:"json_schema"`
}
// AgentRequestBody is the JSON document that AgentClient.Request marshals and
// POSTs.
//
// Tools and ToolChoice are pointers tagged `omitempty`, so a nil value leaves
// the corresponding key out of the encoded body. AgentMessages is embedded:
// its "messages" key is emitted at the top level of the object and its Add*
// methods are callable directly on an AgentRequestBody.
type AgentRequestBody struct {
	Model          string         `json:"model"`
	Temperature    float64        `json:"temperature"`
	ResponseFormat ResponseFormat `json:"response_format"`

	Tools      *any    `json:"tools,omitempty"`
	ToolChoice *string `json:"tool_choice,omitempty"`
	// ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`

	AgentMessages
}
// AgentMessages holds the ordered conversation that is sent under the
// "messages" key. Entries are appended through AddSystem, AddText, AddImage
// and AddToolCall.
type AgentMessages struct {
	Messages []AgentMessage `json:"messages"`
}
// AgentMessage is the element type of AgentMessages.Messages. It is satisfied
// by AgentTextMessage, AgentAssistantToolCall and AgentArrayMessage.
//
// NOTE(review): nothing in the visible code calls MessageToJson; the request
// body is encoded with json.Marshal directly — confirm whether the method is
// still needed beyond acting as a marker.
type AgentMessage interface {
	// MessageToJson returns the JSON encoding of the message.
	MessageToJson() ([]byte, error)
}
// AgentTextMessage is a chat message whose content is a single string.
// ToolCallId and Name are tagged `omitempty`, so they only appear in the JSON
// when they are non-empty.
type AgentTextMessage struct {
	Role       string `json:"role"`
	Content    string `json:"content"`
	ToolCallId string `json:"tool_call_id,omitempty"`
	Name       string `json:"name,omitempty"`
}
// MessageToJson encodes the message with encoding/json using the struct tags
// declared on AgentTextMessage. No validation is performed.
func (msg AgentTextMessage) MessageToJson() ([]byte, error) {
	// TODO: Validate the `Role`.
	encoded, err := json.Marshal(msg)
	return encoded, err
}
// ToolCall is one element of a "tool_calls" array. The same type is used when
// decoding a response (ResponseChoiceMessage) and when echoing the calls back
// in an AgentAssistantToolCall message.
type ToolCall struct {
	Index    int          `json:"index"`
	Id       string       `json:"id"`
	Function FunctionCall `json:"function"`
}
// FunctionCall is the "function" member of a ToolCall: the function name and
// its arguments carried as a single string.
//
// NOTE(review): Arguments is presumably a JSON-encoded object — confirm
// against the tool handlers that consume it.
type FunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}
// AgentAssistantToolCall is a message that carries a list of tool calls.
// AgentClient.Request appends one (with Role "assistant" and empty Content)
// whenever the first response choice contains tool calls.
type AgentAssistantToolCall struct {
	Role      string     `json:"role"`
	Content   string     `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls"`
}
// MessageToJson encodes the message with encoding/json using the struct tags
// declared on AgentAssistantToolCall.
func (msg AgentAssistantToolCall) MessageToJson() ([]byte, error) {
	encoded, err := json.Marshal(msg)
	return encoded, err
}
// AgentArrayMessage is a chat message whose "content" is a list of parts
// rather than a single string. AddImage builds one holding a single
// AgentImage.
type AgentArrayMessage struct {
	Role    string         `json:"role"`
	Content []AgentContent `json:"content"`
}
// MessageToJson encodes the message with encoding/json using the struct tags
// declared on AgentArrayMessage.
func (msg AgentArrayMessage) MessageToJson() ([]byte, error) {
	encoded, err := json.Marshal(msg)
	return encoded, err
}
// AddText appends message to the end of the conversation.
func (content *AgentMessages) AddText(message AgentTextMessage) {
	var entry AgentMessage = message
	content.Messages = append(content.Messages, entry)
}
// AddToolCall appends toolCall to the end of the conversation.
func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) {
	var entry AgentMessage = toolCall
	content.Messages = append(content.Messages, entry)
}
// AddImage appends a user message containing image as a base64 data URL.
//
// The media subtype of the data URL is taken from the extension of imageName
// (without the leading dot), so "shot.png" produces "data:image/png;base64,…".
// The extension is used verbatim; it is not checked against a list of
// supported image types.
//
// It returns an error, and appends nothing, when imageName has no extension
// or ends in a bare ".".
func (content *AgentMessages) AddImage(imageName string, image []byte) error {
	// filepath.Ext keeps the leading dot, so a result of length <= 1 means
	// there is no usable extension: either none at all, or a name ending in
	// "." which would otherwise yield the malformed "data:image/;base64,…".
	extension := filepath.Ext(imageName)
	if len(extension) <= 1 {
		// TODO: could also validate for image types we support.
		return errors.New("Image does not have extension")
	}
	subtype := extension[1:]

	encodedString := base64.StdEncoding.EncodeToString(image)

	arrayMessage := AgentArrayMessage{
		Role: ROLE_USER,
		Content: []AgentContent{
			AgentImage{
				ImageType: IMAGE_TYPE,
				ImageUrl:  fmt.Sprintf("data:image/%s;base64,%s", subtype, encodedString),
			},
		},
	}

	content.Messages = append(content.Messages, arrayMessage)
	return nil
}
// AddSystem appends prompt as a message with the system role. It is only
// allowed while the conversation is still empty; otherwise an error is
// returned and nothing is appended.
func (content *AgentMessages) AddSystem(prompt string) error {
	if len(content.Messages) > 0 {
		return errors.New("You can only add a system prompt at the beginning")
	}

	systemMessage := AgentTextMessage{Role: ROLE_SYSTEM, Content: prompt}
	content.Messages = append(content.Messages, systemMessage)

	return nil
}
// AgentContent is the element type of AgentArrayMessage.Content. AgentImage
// is the implementation visible in this file.
type AgentContent interface {
	// ToJson returns the JSON encoding of the content part.
	ToJson() ([]byte, error)
}
// ImageUrl wraps a URL under the JSON key "url".
//
// NOTE(review): not referenced by the visible code — AgentImage stores its
// URL as a plain string; confirm whether this type is still needed.
type ImageUrl struct {
	Url string `json:"url"`
}
// AgentImage is an image content part: ImageType is serialized as "type" and
// ImageUrl, a plain string, as "image_url". AddImage fills ImageUrl with a
// base64 data URL.
type AgentImage struct {
	ImageType string `json:"type"`
	ImageUrl  string `json:"image_url"`
}
// ToJson encodes the image part as JSON. The "type" value is always
// IMAGE_TYPE regardless of what ImageType held; because the receiver is a
// value, the caller's AgentImage is not modified.
func (imageMessage AgentImage) ToJson() ([]byte, error) {
	normalized := imageMessage
	normalized.ImageType = IMAGE_TYPE

	encoded, err := json.Marshal(normalized)
	return encoded, err
}
// AiClient is implemented by anything that can turn a named image and its raw
// bytes into an ImageInfo.
type AiClient interface {
	// GetImageInfo returns the ImageInfo for the image identified by
	// imageName whose contents are imageData.
	GetImageInfo(imageName string, imageData []byte) (ImageInfo, error)
}
// ResponseChoiceMessage is the "message" object of a response choice.
// AgentClient.Request inspects ToolCalls on the first choice to decide
// whether to append an assistant tool-call message to the request.
type ResponseChoiceMessage struct {
	Role      string     `json:"role"`
	Content   string     `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls"`
}
// ResponseChoice is one element of the "choices" array of an AgentResponse.
type ResponseChoice struct {
	Index        int                   `json:"index"`
	Message      ResponseChoiceMessage `json:"message"`
	FinishReason string                `json:"finish_reason"`
}
// AgentResponse is the shape AgentClient.Request decodes the HTTP response
// body into. JSON keys not listed here are ignored by encoding/json.
//
// NOTE(review): Created is presumably a Unix timestamp — confirm against the
// API documentation.
type AgentResponse struct {
	Id      string           `json:"id"`
	Object  string           `json:"object"`
	Choices []ResponseChoice `json:"choices"`
	Created int              `json:"created"`
}
// AgentClient sends AgentRequestBody values to a chat-completions endpoint.
// Build one with CreateAgentClient.
//
//   - url is the endpoint every request is POSTed to.
//   - apiKey is sent as "Authorization: Bearer <apiKey>".
//   - systemPrompt is stored by CreateAgentClient.
//   - responseFormat is neither set nor read in the visible code.
//   - ToolHandler is invoked by Process before each request.
//   - Do performs the HTTP round trip; it is a field so callers can replace
//     the transport.
type AgentClient struct {
	url            string
	apiKey         string
	systemPrompt   string
	responseFormat string

	ToolHandler ToolsHandlers

	Do func(req *http.Request) (*http.Response, error)
}
const (
	// OPENAI_API_KEY is the name of the environment variable that
	// CreateAgentClient reads the API key from.
	OPENAI_API_KEY = "OPENAI_API_KEY"

	// ROLE_USER is the role AddImage assigns to the messages it creates.
	ROLE_USER = "user"

	// ROLE_SYSTEM is the role AddSystem assigns to the system prompt message.
	ROLE_SYSTEM = "system"

	// IMAGE_TYPE is the "type" value used for AgentImage content parts.
	IMAGE_TYPE = "image_url"
)
// CreateAgentClient builds an AgentClient that targets
// https://api.mistral.ai/v1/chat/completions, stores prompt as its system
// prompt and starts with an empty set of tool handlers.
//
// The API key is read from the environment variable named by OPENAI_API_KEY;
// if it is unset or empty, a zero AgentClient and an error are returned.
func CreateAgentClient(prompt string) (AgentClient, error) {
	apiKey := os.Getenv(OPENAI_API_KEY)
	if len(apiKey) == 0 {
		return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
	}

	// A single http.Client is created here and shared by every request made
	// through the returned AgentClient, instead of allocating a new one per
	// call. NOTE(review): no Timeout is configured, so a request can block
	// for as long as the server keeps the connection open.
	httpClient := &http.Client{}

	return AgentClient{
		apiKey:       apiKey,
		url:          "https://api.mistral.ai/v1/chat/completions",
		systemPrompt: prompt,
		Do:           httpClient.Do,
		ToolHandler: ToolsHandlers{
			handlers: &map[string]ToolHandler{},
		},
	}, nil
}
// getRequest builds the POST request for client.url with body as its payload,
// the API key as a bearer token and a JSON content type. The error from
// http.NewRequest is returned unchanged.
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
	payload := bytes.NewBuffer(body)

	httpRequest, err := http.NewRequest(http.MethodPost, client.url, payload)
	if err != nil {
		// http.NewRequest returns a nil request alongside any error.
		return nil, err
	}

	// The header map of a freshly created request is empty, so Set and Add
	// are equivalent here.
	httpRequest.Header.Set("Authorization", "Bearer "+client.apiKey)
	httpRequest.Header.Set("Content-Type", "application/json")

	return httpRequest, nil
}
// Request marshals request to JSON, POSTs it through client.Do and decodes
// the response body into an AgentResponse. The raw response body is written
// to the standard logger once it has been decoded successfully.
//
// Side effect: when the first choice of the response contains tool calls, an
// assistant message carrying those calls is appended to request, so that the
// next call made with the same request includes them.
//
// An error is returned when marshalling, building or sending the request,
// reading the body or decoding it fails, and when the decoded response
// contains no choices. On error the returned AgentResponse is the zero value.
func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) {
	jsonAiRequest, err := json.Marshal(request)
	if err != nil {
		return AgentResponse{}, err
	}

	httpRequest, err := client.getRequest(jsonAiRequest)
	if err != nil {
		return AgentResponse{}, err
	}

	resp, err := client.Do(httpRequest)
	if err != nil {
		return AgentResponse{}, err
	}
	// The body has to be closed on every path, otherwise the underlying
	// connection is leaked and cannot be reused.
	defer resp.Body.Close()

	response, err := io.ReadAll(resp.Body)
	if err != nil {
		return AgentResponse{}, err
	}

	agentResponse := AgentResponse{}
	if err = json.Unmarshal(response, &agentResponse); err != nil {
		return AgentResponse{}, err
	}

	log.Println(string(response))

	// A payload without choices (for example an API error object) decodes
	// without error; report it instead of panicking on Choices[0] below.
	if len(agentResponse.Choices) == 0 {
		return AgentResponse{}, fmt.Errorf("agent response contained no choices (HTTP status %d)", resp.StatusCode)
	}

	toolCalls := agentResponse.Choices[0].Message.ToolCalls
	if len(toolCalls) > 0 {
		// Should for sure be more flexible.
		request.AddToolCall(AgentAssistantToolCall{
			Role:      "assistant",
			Content:   "",
			ToolCalls: toolCalls,
		})
	}

	return agentResponse, nil
}
// Process alternates between client.ToolHandler.Handle and client.Request on
// its own copy of request until one of them returns an error. That error is
// written to the standard logger and returned.
//
// The loop has no other exit, so Process never returns a nil error.
func (client AgentClient) Process(info ToolHandlerInfo, request AgentRequestBody) error {
	for {
		if handleErr := client.ToolHandler.Handle(info, &request); handleErr != nil {
			log.Println(handleErr)
			return handleErr
		}

		if _, requestErr := client.Request(&request); requestErr != nil {
			log.Println(requestErr)
			return requestErr
		}
	}
}