// Package client defines the request/response types and the HTTP client used
// to talk to a chat-completions style agent API.
package client

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"

	"screenmark/screenmark/.gen/haystack/haystack/model"
)

// ImageInfo is the result type returned by AiClient.GetImageInfo.
type ImageInfo struct {
	Tags      []string          `json:"tags"`
	Text      []string          `json:"text"`
	Links     []string          `json:"links"`
	Locations []model.Locations `json:"locations"`
	Events    []model.Events    `json:"events"`
}

// ResponseFormat is the "response_format" section of a request body.
type ResponseFormat struct {
	Type       string `json:"type"`
	JsonSchema any    `json:"json_schema"`
}

// AgentRequestBody is the JSON payload sent to the chat completions endpoint.
// The embedded AgentMessages contributes the "messages" array and the Add*
// helper methods.
type AgentRequestBody struct {
	Model          string         `json:"model"`
	Temperature    float64        `json:"temperature"`
	ResponseFormat ResponseFormat `json:"response_format"`

	Tools      *any    `json:"tools,omitempty"`
	ToolChoice *string `json:"tool_choice,omitempty"`

	// ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`

	AgentMessages
}

// AgentMessages is the ordered list of messages of a conversation.
type AgentMessages struct {
	Messages []AgentMessage `json:"messages"`
}

// AgentMessage is implemented by every message type that can be stored in
// AgentMessages.
type AgentMessage interface {
	MessageToJson() ([]byte, error)
}

// AgentTextMessage is a message whose content is a plain string.
type AgentTextMessage struct {
	Role       string `json:"role"`
	Content    string `json:"content"`
	ToolCallId string `json:"tool_call_id,omitempty"`
	Name       string `json:"name,omitempty"`
}

// MessageToJson returns the JSON encoding of the message.
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
	// TODO: Validate the `Role`.
	return json.Marshal(textContent)
}

// ToolCall is a single tool invocation listed in a "tool_calls" array.
type ToolCall struct {
	Index    int          `json:"index"`
	Id       string       `json:"id"`
	Function FunctionCall `json:"function"`
}

// FunctionCall names the function of a ToolCall and carries its arguments as
// a string.
type FunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}

// AgentAssistantToolCall is a message that carries a list of tool calls.
type AgentAssistantToolCall struct {
	Role      string     `json:"role"`
	Content   string     `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls"`
}

// MessageToJson returns the JSON encoding of the message.
func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) {
	return json.Marshal(toolCall)
}

// AgentArrayMessage is a message whose content is a list of AgentContent
// parts rather than a single string.
type AgentArrayMessage struct {
	Role    string         `json:"role"`
	Content []AgentContent `json:"content"`
}

// MessageToJson returns the JSON encoding of the message.
func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
	return json.Marshal(arrayContent)
}

// AddText appends a text message to the conversation.
func (content *AgentMessages) AddText(message AgentTextMessage) {
	content.Messages = append(content.Messages, message)
}

// AddToolCall appends an assistant tool-call message to the conversation.
func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) {
	content.Messages = append(content.Messages, toolCall)
}

// AddImage appends a user message containing the image as a base64 data URL.
// The image subtype in the URL is the extension of imageName without its
// leading dot. An error is returned when imageName has no usable extension.
func (content *AgentMessages) AddImage(imageName string, image []byte) error {
	extension := filepath.Ext(imageName)
	// filepath.Ext yields "." for a name ending in a dot; treating that as a
	// valid extension would produce the malformed URL "data:image/;base64,...".
	if len(extension) <= 1 {
		// TODO: could also validate for image types we support.
		return errors.New("Image does not have extension")
	}
	extension = extension[1:]

	encodedString := base64.StdEncoding.EncodeToString(image)

	arrayMessage := AgentArrayMessage{
		Role: ROLE_USER,
		Content: []AgentContent{
			AgentImage{
				ImageType: IMAGE_TYPE,
				ImageUrl:  fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
			},
		},
	}

	content.Messages = append(content.Messages, arrayMessage)
	return nil
}

// AddSystem appends a system message holding prompt. It fails if the
// conversation already contains any message.
func (content *AgentMessages) AddSystem(prompt string) error {
	if len(content.Messages) != 0 {
		return errors.New("You can only add a system prompt at the beginning")
	}

	content.Messages = append(content.Messages, AgentTextMessage{
		Role:    ROLE_SYSTEM,
		Content: prompt,
	})
	return nil
}

// AgentContent is one part of an AgentArrayMessage.
type AgentContent interface {
	ToJson() ([]byte, error)
}

// ImageUrl wraps a URL in an object with a single "url" key.
type ImageUrl struct {
	Url string `json:"url"`
}

// AgentImage is an image content part; ImageUrl is encoded as a plain string
// under "image_url".
type AgentImage struct {
	ImageType string `json:"type"`
	ImageUrl  string `json:"image_url"`
}

// ToJson returns the JSON encoding of the image part. The "type" field is
// always emitted as IMAGE_TYPE, whatever the receiver holds; the receiver is a
// copy, so the caller's value is not modified.
func (imageMessage AgentImage) ToJson() ([]byte, error) {
	imageMessage.ImageType = IMAGE_TYPE
	return json.Marshal(imageMessage)
}

// AiClient extracts an ImageInfo from raw image data.
type AiClient interface {
	GetImageInfo(imageName string, imageData []byte) (ImageInfo, error)
}

// ResponseChoiceMessage is the "message" object of a response choice.
type ResponseChoiceMessage struct {
	Role      string     `json:"role"`
	Content   string     `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls"`
}

// ResponseChoice is one entry of the "choices" array of a response.
type ResponseChoice struct {
	Index        int                   `json:"index"`
	Message      ResponseChoiceMessage `json:"message"`
	FinishReason string                `json:"finish_reason"`
}

// AgentResponse is the decoded body of a chat completions response.
type AgentResponse struct {
	Id      string           `json:"id"`
	Object  string           `json:"object"`
	Choices []ResponseChoice `json:"choices"`
	Created int              `json:"created"`
}

// AgentClient sends AgentRequestBody payloads to the configured endpoint and
// passes the conversation to ToolHandler between requests.
type AgentClient struct {
	url            string
	apiKey         string
	systemPrompt   string
	responseFormat string

	ToolHandler ToolsHandlers

	// Do performs the HTTP round trip; it is a field so it can be replaced.
	Do func(req *http.Request) (*http.Response, error)
}

const (
	// OPENAI_API_KEY is the name of the environment variable that
	// CreateAgentClient reads the API key from.
	OPENAI_API_KEY = "OPENAI_API_KEY"

	// ROLE_USER is the role set on messages created by AddImage.
	ROLE_USER = "user"
	// ROLE_SYSTEM is the role set on messages created by AddSystem.
	ROLE_SYSTEM = "system"

	// IMAGE_TYPE is the "type" value of an AgentImage content part.
	IMAGE_TYPE = "image_url"
)

// CreateAgentClient builds an AgentClient that uses prompt as its system
// prompt and the API key found in the OPENAI_API_KEY environment variable.
// It returns an error when that variable is unset or empty.
func CreateAgentClient(prompt string) (AgentClient, error) {
	apiKey := os.Getenv(OPENAI_API_KEY)
	if len(apiKey) == 0 {
		return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
	}

	// One client shared by every request instead of a fresh one per call.
	// NOTE(review): no Timeout is set, so a stalled server blocks the caller
	// indefinitely; pick a limit once typical completion latency is known.
	httpClient := &http.Client{}

	return AgentClient{
		apiKey:       apiKey,
		url:          "https://api.mistral.ai/v1/chat/completions",
		systemPrompt: prompt,
		Do:           httpClient.Do,
	}, nil
}

// getRequest builds the authenticated JSON POST request carrying body.
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
	req, err := http.NewRequest(http.MethodPost, client.url, bytes.NewReader(body))
	if err != nil {
		return req, err
	}

	req.Header.Add("Authorization", "Bearer "+client.apiKey)
	req.Header.Add("Content-Type", "application/json")

	return req, nil
}

// Request sends request to the endpoint and returns the decoded response.
//
// When the first choice of the response contains tool calls, an assistant
// tool-call message holding them is appended to request, so request is
// mutated. An error is returned if the request cannot be encoded or sent, if
// the response cannot be read or decoded, or if it contains no choices.
func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) {
	jsonAiRequest, err := json.Marshal(request)
	if err != nil {
		return AgentResponse{}, fmt.Errorf("encoding agent request: %w", err)
	}

	httpRequest, err := client.getRequest(jsonAiRequest)
	if err != nil {
		return AgentResponse{}, fmt.Errorf("building agent request: %w", err)
	}

	resp, err := client.Do(httpRequest)
	if err != nil {
		return AgentResponse{}, fmt.Errorf("sending agent request: %w", err)
	}
	// The body must be closed or the underlying connection is leaked.
	defer resp.Body.Close()

	response, err := io.ReadAll(resp.Body)
	if err != nil {
		return AgentResponse{}, fmt.Errorf("reading agent response: %w", err)
	}

	agentResponse := AgentResponse{}
	if err := json.Unmarshal(response, &agentResponse); err != nil {
		return AgentResponse{}, fmt.Errorf("decoding agent response: %w", err)
	}

	log.Println(string(response))

	// A payload without choices (for example an error object) decodes without
	// error, so guard the index below instead of panicking.
	if len(agentResponse.Choices) == 0 {
		return AgentResponse{}, fmt.Errorf("agent response contained no choices (status %d): %s", resp.StatusCode, response)
	}

	toolCalls := agentResponse.Choices[0].Message.ToolCalls
	if len(toolCalls) > 0 {
		// Should for sure be more flexible.
		request.AddToolCall(AgentAssistantToolCall{
			Role:      "assistant",
			Content:   "",
			ToolCalls: toolCalls,
		})
	}

	return agentResponse, nil
}

// Process repeatedly sends request and passes it to the tool handler until
// the handler returns an error, then returns nil. A failing Request aborts
// the loop and its error is returned.
//
// NOTE(review): the handler's error only terminates the loop and is never
// surfaced to the caller; presumably it is used as the "done" signal, confirm
// against ToolsHandlers.Handle.
func (client AgentClient) Process(info ToolHandlerInfo, request AgentRequestBody) error {
	var err error

	for err == nil {
		// Nothing to log for an empty conversation; indexing would panic.
		if messages := request.AgentMessages.Messages; len(messages) > 0 {
			log.Printf("Latest message: %+v\n", messages[len(messages)-1])
		}

		response, requestError := client.Request(&request)
		if requestError != nil {
			return requestError
		}

		log.Println(response)

		a, innerErr := client.ToolHandler.Handle(info, &request)
		err = innerErr

		log.Println(a)
		log.Println("--------------------------")
	}

	return nil
}