I guess some repeated code doesn't hurt anyone if it keeps things simpler. Trying to be fancy with the interfaces didn't work so well.
289 lines
6.5 KiB
Go
289 lines
6.5 KiB
Go
package client
|
|
|
|
import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"screenmark/screenmark/.gen/haystack/haystack/model"
)
|
|
|
|
type ImageInfo struct {
|
|
Tags []string `json:"tags"`
|
|
Text []string `json:"text"`
|
|
Links []string `json:"links"`
|
|
|
|
Locations []model.Locations `json:"locations"`
|
|
Events []model.Events `json:"events"`
|
|
}
|
|
|
|
// ResponseFormat is the "response_format" object of an AgentRequestBody.
// JsonSchema is deliberately untyped so any schema value can be supplied;
// a nil value is encoded as JSON null.
type ResponseFormat struct {
	Type       string `json:"type"`
	JsonSchema any    `json:"json_schema"`
}
|
|
|
|
type AgentRequestBody struct {
|
|
Model string `json:"model"`
|
|
Temperature float64 `json:"temperature"`
|
|
ResponseFormat ResponseFormat `json:"response_format"`
|
|
|
|
Tools *any `json:"tools,omitempty"`
|
|
ToolChoice *string `json:"tool_choice,omitempty"`
|
|
|
|
// ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
|
|
|
AgentMessages
|
|
}
|
|
|
|
type AgentMessages struct {
|
|
Messages []AgentMessage `json:"messages"`
|
|
}
|
|
|
|
// AgentMessage is any value that can be stored in AgentMessages.Messages.
//
// MessageToJson is not a json.Marshaler hook: json.Marshal on an enclosing
// value encodes each message's concrete struct fields directly.
type AgentMessage interface {
	MessageToJson() ([]byte, error)
}
|
|
|
|
type AgentTextMessage struct {
|
|
Role string `json:"role"`
|
|
Content string `json:"content"`
|
|
ToolCallId string `json:"tool_call_id,omitempty"`
|
|
Name string `json:"name,omitempty"`
|
|
}
|
|
|
|
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
|
|
// TODO: Validate the `Role`.
|
|
return json.Marshal(textContent)
|
|
}
|
|
|
|
type ToolCall struct {
|
|
Index int `json:"index"`
|
|
Id string `json:"id"`
|
|
Function FunctionCall `json:"function"`
|
|
}
|
|
|
|
type FunctionCall struct {
|
|
Name string `json:"name"`
|
|
Arguments string `json:"arguments"`
|
|
}
|
|
|
|
type AgentAssistantToolCall struct {
|
|
Role string `json:"role"`
|
|
Content string `json:"content"`
|
|
ToolCalls []ToolCall `json:"tool_calls"`
|
|
}
|
|
|
|
func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) {
|
|
return json.Marshal(toolCall)
|
|
}
|
|
|
|
type AgentArrayMessage struct {
|
|
Role string `json:"role"`
|
|
Content []AgentContent `json:"content"`
|
|
}
|
|
|
|
func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
|
|
return json.Marshal(arrayContent)
|
|
}
|
|
|
|
func (content *AgentMessages) AddText(message AgentTextMessage) {
|
|
content.Messages = append(content.Messages, message)
|
|
}
|
|
|
|
func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) {
|
|
content.Messages = append(content.Messages, toolCall)
|
|
}
|
|
|
|
// AddImage appends a user message that contains a single image part. The
// image is embedded as a base64 data URL whose MIME subtype is derived from
// imageName's file extension (lower-cased; "jpg" is sent as "jpeg").
//
// It returns an error, and leaves the conversation unchanged, if imageName
// has no usable extension: either none at all, or only a trailing ".".
func (m *AgentMessages) AddImage(imageName string, image []byte) error {
	// filepath.Ext returns "." for a name such as "photo.". Trimming the dot
	// before the emptiness check rejects that case too; otherwise it would
	// produce the malformed MIME type "image/".
	extension := strings.ToLower(strings.TrimPrefix(filepath.Ext(imageName), "."))
	if extension == "" {
		// TODO: could also validate for image types we support.
		return errors.New("Image does not have extension")
	}

	// "jpg" is a common file extension but not a registered MIME subtype;
	// the registered media type is image/jpeg.
	if extension == "jpg" {
		extension = "jpeg"
	}

	encodedString := base64.StdEncoding.EncodeToString(image)

	m.Messages = append(m.Messages, AgentArrayMessage{
		Role: ROLE_USER,
		Content: []AgentContent{
			AgentImage{
				ImageType: IMAGE_TYPE,
				ImageUrl:  fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
			},
		},
	})

	return nil
}
|
|
|
|
// AddSystem adds prompt as a system-role text message. It must be the first
// message: if the conversation already has any message, it returns an error
// and leaves the conversation unchanged.
func (m *AgentMessages) AddSystem(prompt string) error {
	if len(m.Messages) > 0 {
		return errors.New("You can only add a system prompt at the beginning")
	}

	system := AgentTextMessage{Role: ROLE_SYSTEM, Content: prompt}
	m.Messages = append(m.Messages, system)
	return nil
}
|
|
|
|
// AgentContent is one element of an AgentArrayMessage's content list.
//
// ToJson is not a json.Marshaler hook: json.Marshal on the enclosing
// message encodes each part's concrete struct fields directly.
type AgentContent interface {
	ToJson() ([]byte, error)
}
|
|
|
|
// ImageUrl is an object of the form {"url": "..."}.
//
// NOTE(review): nothing in the visible code references this type.
// AgentImage carries its URL as a plain string instead.
type ImageUrl struct {
	Url string `json:"url"`
}
|
|
|
|
type AgentImage struct {
|
|
ImageType string `json:"type"`
|
|
ImageUrl string `json:"image_url"`
|
|
}
|
|
|
|
func (imageMessage AgentImage) ToJson() ([]byte, error) {
|
|
imageMessage.ImageType = IMAGE_TYPE
|
|
return json.Marshal(imageMessage)
|
|
}
|
|
|
|
type AiClient interface {
|
|
GetImageInfo(imageName string, imageData []byte) (ImageInfo, error)
|
|
}
|
|
|
|
type ResponseChoiceMessage struct {
|
|
Role string `json:"role"`
|
|
Content string `json:"content"`
|
|
ToolCalls []ToolCall `json:"tool_calls"`
|
|
}
|
|
|
|
type ResponseChoice struct {
|
|
Index int `json:"index"`
|
|
Message ResponseChoiceMessage `json:"message"`
|
|
FinishReason string `json:"finish_reason"`
|
|
}
|
|
|
|
type AgentResponse struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Choices []ResponseChoice `json:"choices"`
|
|
Created int `json:"created"`
|
|
}
|
|
|
|
type AgentClient struct {
|
|
url string
|
|
apiKey string
|
|
systemPrompt string
|
|
responseFormat string
|
|
|
|
ToolHandler ToolsHandlers
|
|
|
|
Do func(req *http.Request) (*http.Response, error)
|
|
}
|
|
|
|
const (
	// OPENAI_API_KEY names the environment variable that CreateAgentClient
	// reads the API key from.
	OPENAI_API_KEY = "OPENAI_API_KEY"

	// Message roles used when building the conversation.
	ROLE_USER   = "user"
	ROLE_SYSTEM = "system"

	// IMAGE_TYPE is the "type" value of an image content part.
	IMAGE_TYPE = "image_url"
)
|
|
|
|
// CreateAgentClient builds an AgentClient that POSTs to the Mistral
// chat-completions endpoint. It authenticates with the API key read from
// the OPENAI_API_KEY environment variable and stores prompt as the
// client's system prompt.
//
// It returns an error if the environment variable is unset or empty.
//
// NOTE(review): the key is read from OPENAI_API_KEY although the endpoint
// is Mistral's. Confirm this is intentional.
func CreateAgentClient(prompt string) (AgentClient, error) {
	apiKey := os.Getenv(OPENAI_API_KEY)
	if len(apiKey) == 0 {
		return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
	}

	// One http.Client is shared by every request made through this
	// AgentClient. A zero-value http.Client never times out, so a stalled
	// API call would block Request (and therefore Process) forever. The
	// limit is generous because image completions can be slow; tune it
	// if needed.
	httpClient := &http.Client{Timeout: 2 * time.Minute}

	return AgentClient{
		apiKey:       apiKey,
		url:          "https://api.mistral.ai/v1/chat/completions",
		systemPrompt: prompt,

		Do: httpClient.Do,

		ToolHandler: ToolsHandlers{
			handlers: &map[string]ToolHandler{},
		},
	}, nil
}
|
|
|
|
// getRequest wraps body in a JSON POST to the client's URL, authorized with
// the API key as a bearer token. On failure it returns whatever
// http.NewRequest returned alongside the error.
func (c AgentClient) getRequest(body []byte) (*http.Request, error) {
	req, err := http.NewRequest("POST", c.url, bytes.NewBuffer(body))
	if err == nil {
		req.Header.Add("Authorization", "Bearer "+c.apiKey)
		req.Header.Add("Content-Type", "application/json")
	}
	return req, err
}
|
|
|
|
// Request sends request to the client's endpoint and decodes the reply. If
// the first choice asks for tool calls, an assistant tool-call message is
// appended to request's conversation, so the next round trip carries it.
//
// It returns an error when:
//   - the request cannot be encoded, built or sent;
//   - the response body cannot be read;
//   - the API answers with a non-2xx status;
//   - the body is not valid JSON;
//   - the response contains no choices.
func (c AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) {
	jsonAiRequest, err := json.Marshal(request)
	if err != nil {
		return AgentResponse{}, fmt.Errorf("encoding agent request: %w", err)
	}

	httpRequest, err := c.getRequest(jsonAiRequest)
	if err != nil {
		return AgentResponse{}, fmt.Errorf("building agent request: %w", err)
	}

	resp, err := c.Do(httpRequest)
	if err != nil {
		return AgentResponse{}, fmt.Errorf("sending agent request: %w", err)
	}
	// Closing the body releases the connection back to the transport.
	defer resp.Body.Close()

	response, err := io.ReadAll(resp.Body)
	if err != nil {
		return AgentResponse{}, fmt.Errorf("reading agent response: %w", err)
	}

	// An error body does not match AgentResponse. It would decode into a
	// response with no choices, so report it with the body for diagnosis.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return AgentResponse{}, fmt.Errorf("agent request returned status %d: %s", resp.StatusCode, response)
	}

	agentResponse := AgentResponse{}
	if err := json.Unmarshal(response, &agentResponse); err != nil {
		return AgentResponse{}, fmt.Errorf("decoding agent response: %w", err)
	}

	log.Println(string(response))

	// Indexing Choices[0] on an empty slice would panic.
	if len(agentResponse.Choices) == 0 {
		return AgentResponse{}, errors.New("agent response contained no choices")
	}

	toolCalls := agentResponse.Choices[0].Message.ToolCalls
	if len(toolCalls) > 0 {
		// Should for sure be more flexible.
		request.AddToolCall(AgentAssistantToolCall{
			Role:      "assistant",
			Content:   "",
			ToolCalls: toolCalls,
		})
	}

	return agentResponse, nil
}
|
|
|
|
func (client AgentClient) Process(info ToolHandlerInfo, request AgentRequestBody) error {
|
|
var err error
|
|
|
|
for {
|
|
err = client.ToolHandler.Handle(info, &request)
|
|
if err != nil {
|
|
break
|
|
}
|
|
|
|
_, err = client.Request(&request)
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
|
|
log.Println(err)
|
|
return err
|
|
}
|