Orchestrator + Tooling rework #4
@ -1,479 +0,0 @@
|
|||||||
package agents
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ImageInfo struct {
|
|
||||||
Tags []string `json:"tags"`
|
|
||||||
Text []string `json:"text"`
|
|
||||||
Links []string `json:"links"`
|
|
||||||
|
|
||||||
Locations []model.Locations `json:"locations"`
|
|
||||||
Events []model.Events `json:"events"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ResponseFormat struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
JsonSchema any `json:"json_schema"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AgentRequestBody struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Temperature float64 `json:"temperature"`
|
|
||||||
ResponseFormat ResponseFormat `json:"response_format"`
|
|
||||||
|
|
||||||
Tools *any `json:"tools,omitempty"`
|
|
||||||
ToolChoice *string `json:"tool_choice,omitempty"`
|
|
||||||
|
|
||||||
AgentMessages
|
|
||||||
}
|
|
||||||
|
|
||||||
type AgentMessages struct {
|
|
||||||
Messages []AgentMessage `json:"messages"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AgentMessage interface {
|
|
||||||
MessageToJson() ([]byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type AgentTextMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
|
||||||
Name string `json:"name,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
|
|
||||||
// TODO: Validate the `Role`.
|
|
||||||
return json.Marshal(textContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
type AgentAssistantToolCall struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
ToolCalls []ToolCall `json:"tool_calls"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) {
|
|
||||||
return json.Marshal(toolCall)
|
|
||||||
}
|
|
||||||
|
|
||||||
type AgentArrayMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content []AgentContent `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
|
|
||||||
return json.Marshal(arrayContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (content *AgentMessages) AddText(message AgentTextMessage) {
|
|
||||||
content.Messages = append(content.Messages, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) {
|
|
||||||
content.Messages = append(content.Messages, toolCall)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (content *AgentMessages) AddImage(imageName string, image []byte) error {
|
|
||||||
extension := filepath.Ext(imageName)
|
|
||||||
if len(extension) == 0 {
|
|
||||||
// TODO: could also validate for image types we support.
|
|
||||||
return errors.New("Image does not have extension")
|
|
||||||
}
|
|
||||||
|
|
||||||
extension = extension[1:]
|
|
||||||
|
|
||||||
encodedString := base64.StdEncoding.EncodeToString(image)
|
|
||||||
|
|
||||||
arrayMessage := AgentArrayMessage{Role: ROLE_USER, Content: make([]AgentContent, 1)}
|
|
||||||
arrayMessage.Content[0] = AgentImage{
|
|
||||||
ImageType: IMAGE_TYPE,
|
|
||||||
ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
|
|
||||||
}
|
|
||||||
|
|
||||||
content.Messages = append(content.Messages, arrayMessage)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (content *AgentMessages) AddSystem(prompt string) error {
|
|
||||||
if len(content.Messages) != 0 {
|
|
||||||
return errors.New("You can only add a system prompt at the beginning")
|
|
||||||
}
|
|
||||||
|
|
||||||
content.Messages = append(content.Messages, AgentTextMessage{
|
|
||||||
Role: ROLE_SYSTEM,
|
|
||||||
Content: prompt,
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type AgentContent interface {
|
|
||||||
ToJson() ([]byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ImageUrl struct {
|
|
||||||
Url string `json:"url"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AgentImage struct {
|
|
||||||
ImageType string `json:"type"`
|
|
||||||
ImageUrl string `json:"image_url"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (imageMessage AgentImage) ToJson() ([]byte, error) {
|
|
||||||
imageMessage.ImageType = IMAGE_TYPE
|
|
||||||
return json.Marshal(imageMessage)
|
|
||||||
}
|
|
||||||
|
|
||||||
type AiClient interface {
|
|
||||||
GetImageInfo(imageName string, imageData []byte) (ImageInfo, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type AgentClient struct {
|
|
||||||
url string
|
|
||||||
apiKey string
|
|
||||||
systemPrompt string
|
|
||||||
responseFormat string
|
|
||||||
|
|
||||||
Do func(req *http.Request) (*http.Response, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// func (client AgentClient) Do(req *http.Request) () {
|
|
||||||
// httpClient := http.Client{}
|
|
||||||
// return httpClient.Do(req)
|
|
||||||
// }
|
|
||||||
|
|
||||||
const OPENAI_API_KEY = "OPENAI_API_KEY"
|
|
||||||
const ROLE_USER = "user"
|
|
||||||
const ROLE_SYSTEM = "system"
|
|
||||||
const IMAGE_TYPE = "image_url"
|
|
||||||
|
|
||||||
// TODO: extract to text file probably
|
|
||||||
const PROMPT = `
|
|
||||||
You are an image information extractor. The user will provide you with screenshots and your job is to extract any relevant links and text
|
|
||||||
that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags.
|
|
||||||
Be sure to extract every link (URL) that you find.
|
|
||||||
Use generic tags.
|
|
||||||
`
|
|
||||||
|
|
||||||
const RESPONSE_FORMAT = `
|
|
||||||
{
|
|
||||||
"name": "image_info",
|
|
||||||
"strict": true,
|
|
||||||
"schema": {
|
|
||||||
"type": "object",
|
|
||||||
"title": "image",
|
|
||||||
"required": ["tags", "text", "links"],
|
|
||||||
"additionalProperties": false,
|
|
||||||
"properties": {
|
|
||||||
"tags": {
|
|
||||||
"type": "array",
|
|
||||||
"title": "tags",
|
|
||||||
"description": "A list of tags you think the image is relevant to.",
|
|
||||||
"items": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"text": {
|
|
||||||
"type": "array",
|
|
||||||
"title": "text",
|
|
||||||
"description": "A list of sentences the image contains.",
|
|
||||||
"items": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"links": {
|
|
||||||
"type": "array",
|
|
||||||
"title": "links",
|
|
||||||
"description": "A list of all the links you can find in the image.",
|
|
||||||
"items": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"locations": {
|
|
||||||
"title": "locations",
|
|
||||||
"type": "array",
|
|
||||||
"description": "A list of locations you can find on the image, if any",
|
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"required": ["name"],
|
|
||||||
"additionalProperties": false,
|
|
||||||
"properties": {
|
|
||||||
"name": {
|
|
||||||
"title": "name",
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"coordinates": {
|
|
||||||
"title": "coordinates",
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"address": {
|
|
||||||
"title": "address",
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"description": {
|
|
||||||
"title": "description",
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"events": {
|
|
||||||
"title": "events",
|
|
||||||
"type": "array",
|
|
||||||
"description": "A list of events you find on the image, if any",
|
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"required": ["name"],
|
|
||||||
"additionalProperties": false,
|
|
||||||
"properties": {
|
|
||||||
"name": {
|
|
||||||
"type": "string",
|
|
||||||
"title": "name"
|
|
||||||
},
|
|
||||||
"locations": {
|
|
||||||
"title": "locations",
|
|
||||||
"type": "array",
|
|
||||||
"description": "A list of locations on this event, if any",
|
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"required": ["name"],
|
|
||||||
"additionalProperties": false,
|
|
||||||
"properties": {
|
|
||||||
"name": {
|
|
||||||
"title": "name",
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"coordinates": {
|
|
||||||
"title": "coordinates",
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"address": {
|
|
||||||
"title": "address",
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"description": {
|
|
||||||
"title": "description",
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
`
|
|
||||||
|
|
||||||
func CreateAgentClient(prompt string) (AgentClient, error) {
|
|
||||||
apiKey := os.Getenv(OPENAI_API_KEY)
|
|
||||||
|
|
||||||
if len(apiKey) == 0 {
|
|
||||||
return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
|
|
||||||
}
|
|
||||||
|
|
||||||
return AgentClient{
|
|
||||||
apiKey: apiKey,
|
|
||||||
url: "https://api.mistral.ai/v1/chat/completions",
|
|
||||||
systemPrompt: prompt,
|
|
||||||
Do: func(req *http.Request) (*http.Response, error) {
|
|
||||||
client := &http.Client{}
|
|
||||||
return client.Do(req)
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
|
|
||||||
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body))
|
|
||||||
if err != nil {
|
|
||||||
return req, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Add("Authorization", "Bearer "+client.apiKey)
|
|
||||||
req.Header.Add("Content-Type", "application/json")
|
|
||||||
|
|
||||||
return req, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (AgentRequestBody, error) {
|
|
||||||
request := AgentRequestBody{
|
|
||||||
Model: model,
|
|
||||||
Temperature: temperature,
|
|
||||||
ResponseFormat: ResponseFormat{
|
|
||||||
Type: "json_schema",
|
|
||||||
JsonSchema: jsonSchema,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Add build pattern here that deals with errors in some internal state?
|
|
||||||
// I want a monad!!!
|
|
||||||
err := request.AddSystem(prompt)
|
|
||||||
if err != nil {
|
|
||||||
return request, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Println(request)
|
|
||||||
|
|
||||||
err = request.AddImage(imageName, imageData)
|
|
||||||
if err != nil {
|
|
||||||
return request, err
|
|
||||||
}
|
|
||||||
|
|
||||||
request.Tools = nil
|
|
||||||
|
|
||||||
return request, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type FunctionCall struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Arguments string `json:"arguments"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolCall struct {
|
|
||||||
Index int `json:"index"`
|
|
||||||
Id string `json:"id"`
|
|
||||||
Function FunctionCall `json:"function"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ResponseChoiceMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
ToolCalls []ToolCall `json:"tool_calls"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ResponseChoice struct {
|
|
||||||
Index int `json:"index"`
|
|
||||||
Message ResponseChoiceMessage `json:"message"`
|
|
||||||
FinishReason string `json:"finish_reason"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AgentResponse struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Choices []ResponseChoice `json:"choices"`
|
|
||||||
Created int `json:"created"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: add usage parsing
|
|
||||||
func parseAgentResponse(jsonResponse []byte) (ImageInfo, error) {
|
|
||||||
response := AgentResponse{}
|
|
||||||
|
|
||||||
err := json.Unmarshal(jsonResponse, &response)
|
|
||||||
if err != nil {
|
|
||||||
return ImageInfo{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(response.Choices) != 1 {
|
|
||||||
log.Println(string(jsonResponse))
|
|
||||||
return ImageInfo{}, errors.New("Expected exactly one choice.")
|
|
||||||
}
|
|
||||||
|
|
||||||
imageInfo := ImageInfo{}
|
|
||||||
err = json.Unmarshal([]byte(response.Choices[0].Message.Content), &imageInfo)
|
|
||||||
if err != nil {
|
|
||||||
return ImageInfo{}, errors.New("Could not parse content into image type.")
|
|
||||||
}
|
|
||||||
|
|
||||||
return imageInfo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) {
|
|
||||||
jsonAiRequest, err := json.Marshal(request)
|
|
||||||
if err != nil {
|
|
||||||
return AgentResponse{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
httpRequest, err := client.getRequest(jsonAiRequest)
|
|
||||||
if err != nil {
|
|
||||||
return AgentResponse{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(httpRequest)
|
|
||||||
if err != nil {
|
|
||||||
return AgentResponse{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
response, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return AgentResponse{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
agentResponse := AgentResponse{}
|
|
||||||
err = json.Unmarshal(response, &agentResponse)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return AgentResponse{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Println(string(response))
|
|
||||||
|
|
||||||
toolCalls := agentResponse.Choices[0].Message.ToolCalls
|
|
||||||
if len(toolCalls) > 0 {
|
|
||||||
// Should for sure be more flexible.
|
|
||||||
request.AddToolCall(AgentAssistantToolCall{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: "",
|
|
||||||
ToolCalls: toolCalls,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return agentResponse, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
|
|
||||||
aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData)
|
|
||||||
if err != nil {
|
|
||||||
return ImageInfo{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var jsonSchema any
|
|
||||||
err = json.Unmarshal([]byte(RESPONSE_FORMAT), &jsonSchema)
|
|
||||||
if err != nil {
|
|
||||||
return ImageInfo{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
aiRequest.ResponseFormat = ResponseFormat{
|
|
||||||
Type: "json_schema",
|
|
||||||
JsonSchema: jsonSchema,
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonAiRequest, err := json.Marshal(aiRequest)
|
|
||||||
if err != nil {
|
|
||||||
return ImageInfo{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
request, err := client.getRequest(jsonAiRequest)
|
|
||||||
if err != nil {
|
|
||||||
return ImageInfo{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(request)
|
|
||||||
if err != nil {
|
|
||||||
return ImageInfo{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
response, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return ImageInfo{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Println(string(response))
|
|
||||||
|
|
||||||
return parseAgentResponse(response)
|
|
||||||
}
|
|
@ -1,213 +0,0 @@
|
|||||||
package agents
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMessageBuilder(t *testing.T) {
|
|
||||||
content := AgentMessages{}
|
|
||||||
|
|
||||||
err := content.AddSystem("Some prompt")
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Log(err)
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(content.Messages) != 1 {
|
|
||||||
t.Logf("Expected length 1, got %d.\n", len(content.Messages))
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMessageBuilderImage(t *testing.T) {
|
|
||||||
content := AgentMessages{}
|
|
||||||
|
|
||||||
prompt := "some prompt"
|
|
||||||
imageTitle := "image.png"
|
|
||||||
data := []byte("some data")
|
|
||||||
|
|
||||||
content.AddSystem(prompt)
|
|
||||||
content.AddImage(imageTitle, data)
|
|
||||||
|
|
||||||
if len(content.Messages) != 2 {
|
|
||||||
t.Logf("Expected length 2, got %d.\n", len(content.Messages))
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
promptMessage, ok := content.Messages[0].(AgentTextMessage)
|
|
||||||
if !ok {
|
|
||||||
t.Logf("Expected text content message, got %T\n", content.Messages[0])
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
if promptMessage.Role != ROLE_SYSTEM {
|
|
||||||
t.Log("Prompt message role is incorrect.")
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
if promptMessage.Content != prompt {
|
|
||||||
t.Log("Prompt message content is incorrect.")
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
arrayContentMessage, ok := content.Messages[1].(AgentArrayMessage)
|
|
||||||
if !ok {
|
|
||||||
t.Logf("Expected text content message, got %T\n", content.Messages[1])
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
if arrayContentMessage.Role != ROLE_USER {
|
|
||||||
t.Log("Array content message role is incorrect.")
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(arrayContentMessage.Content) != 1 {
|
|
||||||
t.Logf("Expected length 1, got %d.\n", len(arrayContentMessage.Content))
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
imageContent, ok := arrayContentMessage.Content[0].(AgentImage)
|
|
||||||
if !ok {
|
|
||||||
t.Logf("Expected text content message, got %T\n", arrayContentMessage.Content[0])
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
base64data := base64.StdEncoding.EncodeToString(data)
|
|
||||||
url := fmt.Sprintf("data:image/%s;base64,%s", "png", base64data)
|
|
||||||
|
|
||||||
if imageContent.ImageUrl != url {
|
|
||||||
t.Logf("Expected %s, but got %s.\n", url, imageContent.ImageUrl)
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFullImageRequest(t *testing.T) {
|
|
||||||
request, err := getCompletionsForImage("model", 0.1, "You are an assistant", "image.png", "", []byte("some data"))
|
|
||||||
if err != nil {
|
|
||||||
t.Log(request)
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(request)
|
|
||||||
if err != nil {
|
|
||||||
t.Log(err)
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedJson := `{"model":"model","temperature":0.1,"response_format":{"type":"json_schema","json_schema":""},"messages":[{"role":"system","content":"You are an assistant"},{"role":"user","content":[{"type":"image_url","image_url":{"url":"data:image/png;base64,c29tZSBkYXRh"}}]}]}`
|
|
||||||
|
|
||||||
if string(jsonData) != expectedJson {
|
|
||||||
t.Logf("Expected:\n%s\n Got:\n%s\n", expectedJson, string(jsonData))
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResponse(t *testing.T) {
|
|
||||||
testResponse := `{"tags": ["tag1", "tag2"], "text": ["text1"], "links": []}`
|
|
||||||
buffer := bytes.NewReader([]byte(testResponse))
|
|
||||||
|
|
||||||
body := io.NopCloser(buffer)
|
|
||||||
|
|
||||||
client := AgentClient{
|
|
||||||
url: "http://localhost:1234",
|
|
||||||
apiKey: "some-key",
|
|
||||||
Do: func(_req *http.Request) (*http.Response, error) {
|
|
||||||
return &http.Response{Body: body}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
info, err := client.GetImageInfo("image.png", []byte("some data"))
|
|
||||||
if err != nil {
|
|
||||||
t.Log(err)
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(info.Tags) != 2 || len(info.Text) != 1 || len(info.Links) != 0 {
|
|
||||||
t.Logf("Some lengths are wrong.\nTags: %d\nText: %d\nLinks: %d\n", len(info.Tags), len(info.Text), len(info.Links))
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
if info.Tags[0] != "tag1" {
|
|
||||||
t.Log("0th tag is wrong.")
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
if info.Tags[1] != "tag2" {
|
|
||||||
t.Log("1th tag is wrong.")
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
if info.Text[0] != "text1" {
|
|
||||||
t.Log("0th text is wrong.")
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResponseParsing(t *testing.T) {
|
|
||||||
response := `{
|
|
||||||
"id": "chatcmpl-B4XgiHcd7A2nyK7eyARdggSvfFuWQ",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1740422508,
|
|
||||||
"model": "gpt-4o-mini-2024-07-18",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "{\"links\":[\"link\"],\"tags\":[\"tag\"],\"text\":[\"text\"]}",
|
|
||||||
"refusal": null
|
|
||||||
},
|
|
||||||
"logprobs": null,
|
|
||||||
"finish_reason": "stop"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 775,
|
|
||||||
"completion_tokens": 33,
|
|
||||||
"total_tokens": 808,
|
|
||||||
"prompt_tokens_details": {
|
|
||||||
"cached_tokens": 0,
|
|
||||||
"audio_tokens": 0
|
|
||||||
},
|
|
||||||
"completion_tokens_details": {
|
|
||||||
"reasoning_tokens": 0,
|
|
||||||
"audio_tokens": 0,
|
|
||||||
"accepted_prediction_tokens": 0,
|
|
||||||
"rejected_prediction_tokens": 0
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"service_tier": "default",
|
|
||||||
"system_fingerprint": "fp_7fcd609668"
|
|
||||||
}`
|
|
||||||
|
|
||||||
imageParsed, err := parseAgentResponse([]byte(response))
|
|
||||||
if err != nil {
|
|
||||||
t.Log(err)
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(imageParsed.Links) != 1 || imageParsed.Links[0] != "link" {
|
|
||||||
t.Log(imageParsed)
|
|
||||||
t.Log("Should have one link called 'link'.")
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(imageParsed.Tags) != 1 || imageParsed.Tags[0] != "tag" {
|
|
||||||
t.Log(imageParsed)
|
|
||||||
t.Log("Should have one tag called 'tag'.")
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(imageParsed.Text) != 1 || imageParsed.Text[0] != "text" {
|
|
||||||
t.Log(imageParsed)
|
|
||||||
t.Log("Should have one text called 'text'.")
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
}
|
|
208
backend/agents/client/chat.go
Normal file
208
backend/agents/client/chat.go
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Chat struct {
|
||||||
|
Messages []ChatMessage `json:"messages"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatMessage interface {
|
||||||
|
IsResponse() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: the role could be inferred from the type.
|
||||||
|
// This would solve some bugs.
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
Is there a world where this actually becomes the product?
|
||||||
|
Where we build such a resilient system of AI calls that we
|
||||||
|
can build some app builder, or even just an API system,
|
||||||
|
with a fancy UI?
|
||||||
|
|
||||||
|
Manage all the complexity for the user?
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
// =============================================
|
||||||
|
// Messages from us to the AI.
|
||||||
|
// =============================================
|
||||||
|
|
||||||
|
type UserRole = string
|
||||||
|
|
||||||
|
const (
|
||||||
|
User UserRole = "user"
|
||||||
|
System UserRole = "system"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ToolRole = string
|
||||||
|
|
||||||
|
const (
|
||||||
|
Tool ToolRole = "tool"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChatUserMessage struct {
|
||||||
|
Role UserRole `json:"role"`
|
||||||
|
|
||||||
|
MessageContent `json:"MessageContent"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m ChatUserMessage) MarshalJSON() ([]byte, error) {
|
||||||
|
switch t := m.MessageContent.(type) {
|
||||||
|
case SingleMessage:
|
||||||
|
return json.Marshal(&struct {
|
||||||
|
Role UserRole `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}{
|
||||||
|
Role: User,
|
||||||
|
Content: t.Content,
|
||||||
|
})
|
||||||
|
case ArrayMessage:
|
||||||
|
return json.Marshal(&struct {
|
||||||
|
Role UserRole `json:"role"`
|
||||||
|
Content []ImageMessageContent `json:"content"`
|
||||||
|
}{
|
||||||
|
Role: User,
|
||||||
|
Content: t.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte{}, errors.New("Unreachable")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r ChatUserMessage) IsResponse() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatUserToolResponse struct {
|
||||||
|
Role ToolRole `json:"role"`
|
||||||
|
|
||||||
|
// The name of the function we are responding to.
|
||||||
|
Name string `json:"name"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
ToolCallId string `json:"tool_call_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r ChatUserToolResponse) IsResponse() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatAiMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
ToolCalls *[]ToolCall `json:"tool_calls,omitempty"`
|
||||||
|
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m ChatAiMessage) IsResponse() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================
|
||||||
|
// Unique interface for message content.
|
||||||
|
// =============================================
|
||||||
|
|
||||||
|
type MessageContent interface {
|
||||||
|
IsSingleMessage() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type SingleMessage struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m SingleMessage) IsSingleMessage() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type ArrayMessage struct {
|
||||||
|
Content []ImageMessageContent `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m ArrayMessage) IsSingleMessage() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageMessageContent struct {
|
||||||
|
ImageType string `json:"type"`
|
||||||
|
ImageUrl string `json:"image_url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageContentUrl struct {
|
||||||
|
Url string `json:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================
|
||||||
|
// Adjacent interfaces.
|
||||||
|
// =============================================
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Id string `json:"id"`
|
||||||
|
Function FunctionCall `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionCall struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments string `json:"arguments"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================
|
||||||
|
// Chat methods
|
||||||
|
// =============================================
|
||||||
|
|
||||||
|
func (chat *Chat) AddSystem(prompt string) {
|
||||||
|
chat.Messages = append(chat.Messages, ChatUserMessage{
|
||||||
|
Role: System,
|
||||||
|
MessageContent: SingleMessage{
|
||||||
|
Content: prompt,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (chat *Chat) AddImage(imageName string, image []byte) error {
|
||||||
|
extension := filepath.Ext(imageName)
|
||||||
|
if len(extension) == 0 {
|
||||||
|
// TODO: could also validate for image types we support.
|
||||||
|
return errors.New("Image does not have extension")
|
||||||
|
}
|
||||||
|
|
||||||
|
extension = extension[1:]
|
||||||
|
|
||||||
|
encodedString := base64.StdEncoding.EncodeToString(image)
|
||||||
|
|
||||||
|
messageContent := ArrayMessage{
|
||||||
|
Content: make([]ImageMessageContent, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
messageContent.Content[0] = ImageMessageContent{
|
||||||
|
ImageType: "image_url",
|
||||||
|
ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
|
||||||
|
}
|
||||||
|
|
||||||
|
arrayMessage := ChatUserMessage{Role: User, MessageContent: messageContent}
|
||||||
|
chat.Messages = append(chat.Messages, arrayMessage)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (chat *Chat) AddAiResponse(res ChatAiMessage) {
|
||||||
|
chat.Messages = append(chat.Messages, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (chat *Chat) AddToolResponse(res ChatUserToolResponse) {
|
||||||
|
chat.Messages = append(chat.Messages, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (chat Chat) GetLatest() (ChatMessage, error) {
|
||||||
|
if len(chat.Messages) == 0 {
|
||||||
|
return nil, errors.New("Not enough messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
return chat.Messages[len(chat.Messages)-1], nil
|
||||||
|
}
|
24
backend/agents/client/chat_test.go
Normal file
24
backend/agents/client/chat_test.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFlatMarshallSingleMessage(t *testing.T) {
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
message := ChatUserMessage{
|
||||||
|
Role: User,
|
||||||
|
MessageContent: SingleMessage{
|
||||||
|
Content: "Hello",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
json, err := json.Marshal(message)
|
||||||
|
require.NoError(err)
|
||||||
|
|
||||||
|
require.Equal(string(json), "{\"role\":\"user\",\"content\":\"Hello\"}")
|
||||||
|
}
|
194
backend/agents/client/client.go
Normal file
194
backend/agents/client/client.go
Normal file
@ -0,0 +1,194 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ResponseFormat struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
JsonSchema any `json:"json_schema"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AgentRequestBody struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Temperature float64 `json:"temperature"`
|
||||||
|
ResponseFormat ResponseFormat `json:"response_format"`
|
||||||
|
|
||||||
|
Tools *any `json:"tools,omitempty"`
|
||||||
|
ToolChoice *string `json:"tool_choice,omitempty"`
|
||||||
|
|
||||||
|
EndToolCall string `json:"-"`
|
||||||
|
|
||||||
|
Chat *Chat `json:"messages"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req AgentRequestBody) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(&struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Temperature float64 `json:"temperature"`
|
||||||
|
ResponseFormat ResponseFormat `json:"response_format"`
|
||||||
|
|
||||||
|
Tools *any `json:"tools,omitempty"`
|
||||||
|
ToolChoice *string `json:"tool_choice,omitempty"`
|
||||||
|
Messages []ChatMessage `json:"messages"`
|
||||||
|
}{
|
||||||
|
Model: req.Model,
|
||||||
|
Temperature: req.Temperature,
|
||||||
|
ResponseFormat: req.ResponseFormat,
|
||||||
|
|
||||||
|
Tools: req.Tools,
|
||||||
|
ToolChoice: req.ToolChoice,
|
||||||
|
|
||||||
|
Messages: req.Chat.Messages,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResponseChoice struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Message ChatAiMessage `json:"message"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AgentResponse struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Choices []ResponseChoice `json:"choices"`
|
||||||
|
Created int `json:"created"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AgentClient struct {
|
||||||
|
url string
|
||||||
|
apiKey string
|
||||||
|
responseFormat string
|
||||||
|
|
||||||
|
ToolHandler ToolsHandlers
|
||||||
|
|
||||||
|
Do func(req *http.Request) (*http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
const OPENAI_API_KEY = "OPENAI_API_KEY"
|
||||||
|
|
||||||
|
func CreateAgentClient() (AgentClient, error) {
|
||||||
|
apiKey := os.Getenv(OPENAI_API_KEY)
|
||||||
|
|
||||||
|
if len(apiKey) == 0 {
|
||||||
|
return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
|
||||||
|
}
|
||||||
|
|
||||||
|
return AgentClient{
|
||||||
|
apiKey: apiKey,
|
||||||
|
url: "https://api.mistral.ai/v1/chat/completions",
|
||||||
|
Do: func(req *http.Request) (*http.Response, error) {
|
||||||
|
client := &http.Client{}
|
||||||
|
return client.Do(req)
|
||||||
|
},
|
||||||
|
|
||||||
|
ToolHandler: ToolsHandlers{
|
||||||
|
handlers: map[string]ToolHandler{},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
|
||||||
|
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body))
|
||||||
|
if err != nil {
|
||||||
|
return req, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Add("Authorization", "Bearer "+client.apiKey)
|
||||||
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error) {
|
||||||
|
jsonAiRequest, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return AgentResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpRequest, err := client.getRequest(jsonAiRequest)
|
||||||
|
if err != nil {
|
||||||
|
return AgentResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(httpRequest)
|
||||||
|
if err != nil {
|
||||||
|
return AgentResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return AgentResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(string(response))
|
||||||
|
|
||||||
|
agentResponse := AgentResponse{}
|
||||||
|
err = json.Unmarshal(response, &agentResponse)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return AgentResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(agentResponse.Choices) != 1 {
|
||||||
|
return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Chat.AddAiResponse(agentResponse.Choices[0].Message)
|
||||||
|
|
||||||
|
return agentResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error {
|
||||||
|
for {
|
||||||
|
err := client.Process(info, req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.Request(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var FinishedCall = errors.New("Last tool tool was called")
|
||||||
|
|
||||||
|
func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
message, err := req.Chat.GetLatest()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
aiMessage, ok := message.(ChatAiMessage)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("Latest message isnt an AI message")
|
||||||
|
}
|
||||||
|
|
||||||
|
if aiMessage.ToolCalls == nil {
|
||||||
|
// Not an error, we just dont have any tool calls to process.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, toolCall := range *aiMessage.ToolCalls {
|
||||||
|
if toolCall.Function.Name == req.EndToolCall {
|
||||||
|
return FinishedCall
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResponse := client.ToolHandler.Handle(info, toolCall)
|
||||||
|
|
||||||
|
req.Chat.AddToolResponse(toolResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
70
backend/agents/client/tools.go
Normal file
70
backend/agents/client/tools.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ToolHandlerInfo struct {
|
||||||
|
UserId uuid.UUID
|
||||||
|
ImageId uuid.UUID
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolHandler struct {
|
||||||
|
Fn func(info ToolHandlerInfo, args string, call ToolCall) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolsHandlers struct {
|
||||||
|
handlers map[string]ToolHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
var NoToolCallError = errors.New("An assistant tool call with no tool calls was provided.")
|
||||||
|
|
||||||
|
const NonExistantTool = "This tool does not exist"
|
||||||
|
|
||||||
|
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCall) ChatUserToolResponse {
|
||||||
|
fnName := toolCallMessage.Function.Name
|
||||||
|
arguments := toolCallMessage.Function.Arguments
|
||||||
|
|
||||||
|
responseMessage := ChatUserToolResponse{
|
||||||
|
Role: "tool",
|
||||||
|
Name: fnName,
|
||||||
|
ToolCallId: toolCallMessage.Id,
|
||||||
|
}
|
||||||
|
|
||||||
|
fnHandler, exists := handler.handlers[fnName]
|
||||||
|
if !exists {
|
||||||
|
responseMessage.Content = NonExistantTool
|
||||||
|
return responseMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := fnHandler.Fn(info, arguments, toolCallMessage)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
responseMessage.Content = err.Error()
|
||||||
|
} else {
|
||||||
|
responseMessage.Content = res
|
||||||
|
}
|
||||||
|
|
||||||
|
return responseMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (handler *ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) {
|
||||||
|
handler.handlers[name] = ToolHandler{
|
||||||
|
Fn: func(info ToolHandlerInfo, args string, call ToolCall) (string, error) {
|
||||||
|
res, err := fn(info, args, call)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
marshalledRes, err := json.Marshal(res)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(marshalledRes), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
188
backend/agents/client/tools_test.go
Normal file
188
backend/agents/client/tools_test.go
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ToolTestSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
|
||||||
|
handler ToolsHandlers
|
||||||
|
client AgentClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *ToolTestSuite) SetupTest() {
|
||||||
|
suite.handler = ToolsHandlers{
|
||||||
|
handlers: map[string]ToolHandler{},
|
||||||
|
}
|
||||||
|
|
||||||
|
suite.handler.AddTool("a", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) {
|
||||||
|
return args, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
suite.handler.AddTool("error", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) {
|
||||||
|
return false, errors.New("I will always error")
|
||||||
|
})
|
||||||
|
|
||||||
|
suite.client.ToolHandler = suite.handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *ToolTestSuite) TestSingleToolCall() {
|
||||||
|
require := suite.Require()
|
||||||
|
|
||||||
|
response := suite.handler.Handle(
|
||||||
|
ToolHandlerInfo{
|
||||||
|
UserId: uuid.Nil,
|
||||||
|
ImageId: uuid.Nil,
|
||||||
|
},
|
||||||
|
ToolCall{
|
||||||
|
Index: 0,
|
||||||
|
Id: "1",
|
||||||
|
Function: FunctionCall{
|
||||||
|
Name: "a",
|
||||||
|
Arguments: "return",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.EqualValues(response, ChatUserToolResponse{
|
||||||
|
Role: "tool",
|
||||||
|
Content: "\"return\"",
|
||||||
|
ToolCallId: "1",
|
||||||
|
Name: "a",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *ToolTestSuite) TestMultipleToolCalls() {
|
||||||
|
assert := suite.Assert()
|
||||||
|
require := suite.Require()
|
||||||
|
|
||||||
|
chat := Chat{
|
||||||
|
Messages: []ChatMessage{ChatAiMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "",
|
||||||
|
ToolCalls: &[]ToolCall{
|
||||||
|
{
|
||||||
|
Index: 0,
|
||||||
|
Id: "1",
|
||||||
|
Function: FunctionCall{
|
||||||
|
Name: "a",
|
||||||
|
Arguments: "first-call",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Index: 1,
|
||||||
|
Id: "2",
|
||||||
|
Function: FunctionCall{
|
||||||
|
Name: "a",
|
||||||
|
Arguments: "second-call",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := suite.client.Process(
|
||||||
|
ToolHandlerInfo{
|
||||||
|
UserId: uuid.Nil,
|
||||||
|
ImageId: uuid.Nil,
|
||||||
|
},
|
||||||
|
&AgentRequestBody{
|
||||||
|
Chat: &chat,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(err, "Tool call shouldnt return an error")
|
||||||
|
assert.EqualValues(chat.Messages[1:], []ChatMessage{
|
||||||
|
ChatUserToolResponse{
|
||||||
|
Role: "tool",
|
||||||
|
Content: "\"first-call\"",
|
||||||
|
ToolCallId: "1",
|
||||||
|
Name: "a",
|
||||||
|
},
|
||||||
|
ChatUserToolResponse{
|
||||||
|
Role: "tool",
|
||||||
|
Content: "\"second-call\"",
|
||||||
|
ToolCallId: "2",
|
||||||
|
Name: "a",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
|
||||||
|
assert := suite.Assert()
|
||||||
|
require := suite.Require()
|
||||||
|
|
||||||
|
chat := Chat{
|
||||||
|
Messages: []ChatMessage{ChatAiMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "",
|
||||||
|
ToolCalls: &[]ToolCall{
|
||||||
|
{
|
||||||
|
Index: 0,
|
||||||
|
Id: "1",
|
||||||
|
Function: FunctionCall{
|
||||||
|
Name: "error",
|
||||||
|
Arguments: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Index: 1,
|
||||||
|
Id: "2",
|
||||||
|
Function: FunctionCall{
|
||||||
|
Name: "non-existant",
|
||||||
|
Arguments: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Index: 2,
|
||||||
|
Id: "3",
|
||||||
|
Function: FunctionCall{
|
||||||
|
Name: "a",
|
||||||
|
Arguments: "no-error",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := suite.client.Process(
|
||||||
|
ToolHandlerInfo{
|
||||||
|
UserId: uuid.Nil,
|
||||||
|
ImageId: uuid.Nil,
|
||||||
|
},
|
||||||
|
&AgentRequestBody{
|
||||||
|
Chat: &chat,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(err, "Tool call shouldnt return an error")
|
||||||
|
|
||||||
|
assert.EqualValues(chat.Messages[1:], []ChatMessage{
|
||||||
|
ChatUserToolResponse{
|
||||||
|
Role: "tool",
|
||||||
|
Content: "I will always error",
|
||||||
|
ToolCallId: "1",
|
||||||
|
Name: "error",
|
||||||
|
},
|
||||||
|
ChatUserToolResponse{
|
||||||
|
Role: "tool",
|
||||||
|
Content: "This tool does not exist",
|
||||||
|
ToolCallId: "2",
|
||||||
|
Name: "non-existant",
|
||||||
|
},
|
||||||
|
ChatUserToolResponse{
|
||||||
|
Role: "tool",
|
||||||
|
Content: "\"no-error\"",
|
||||||
|
ToolCallId: "3",
|
||||||
|
Name: "a",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolSuite(t *testing.T) {
|
||||||
|
suite.Run(t, &ToolTestSuite{
|
||||||
|
client: AgentClient{},
|
||||||
|
})
|
||||||
|
}
|
@ -3,10 +3,8 @@ package agents
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"log"
|
|
||||||
"reflect"
|
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
|
"screenmark/screenmark/agents/client"
|
||||||
"screenmark/screenmark/models"
|
"screenmark/screenmark/models"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -105,24 +103,21 @@ const TOOLS = `
|
|||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "finish",
|
"name": "finish",
|
||||||
"description": "Nothing else to do, call this function.",
|
"description": "Nothing else to do. call this function.",
|
||||||
"parameters": {
|
"parameters": {}
|
||||||
"type": "object",
|
|
||||||
"properties": {}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
`
|
`
|
||||||
|
|
||||||
type EventLocationAgent struct {
|
type EventLocationAgent struct {
|
||||||
client AgentClient
|
client client.AgentClient
|
||||||
|
|
||||||
eventModel models.EventModel
|
eventModel models.EventModel
|
||||||
locationModel models.LocationModel
|
locationModel models.LocationModel
|
||||||
contactModel models.ContactModel
|
contactModel models.ContactModel
|
||||||
|
|
||||||
toolHandler ToolsHandlers
|
toolHandler client.ToolsHandlers
|
||||||
}
|
}
|
||||||
|
|
||||||
type ListLocationArguments struct{}
|
type ListLocationArguments struct{}
|
||||||
@ -158,157 +153,66 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
|
|||||||
|
|
||||||
toolChoice := "any"
|
toolChoice := "any"
|
||||||
|
|
||||||
request := AgentRequestBody{
|
request := client.AgentRequestBody{
|
||||||
Tools: &tools,
|
Tools: &tools,
|
||||||
ToolChoice: &toolChoice,
|
ToolChoice: &toolChoice,
|
||||||
Model: "pixtral-12b-2409",
|
Model: "pixtral-12b-2409",
|
||||||
Temperature: 0.3,
|
Temperature: 0.3,
|
||||||
ResponseFormat: ResponseFormat{
|
EndToolCall: "finish",
|
||||||
|
ResponseFormat: client.ResponseFormat{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
},
|
},
|
||||||
|
Chat: &client.Chat{
|
||||||
|
Messages: make([]client.ChatMessage, 0),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err = request.AddSystem(eventLocationPrompt)
|
request.Chat.AddSystem(eventLocationPrompt)
|
||||||
if err != nil {
|
request.Chat.AddImage(imageName, imageData)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
request.AddImage(imageName, imageData)
|
|
||||||
|
|
||||||
_, err = agent.client.Request(&request)
|
_, err = agent.client.Request(&request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
toolHandlerInfo := ToolHandlerInfo{
|
toolHandlerInfo := client.ToolHandlerInfo{
|
||||||
imageId: imageId,
|
ImageId: imageId,
|
||||||
userId: userId,
|
UserId: userId,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = agent.toolHandler.Handle(toolHandlerInfo, &request)
|
return agent.client.ToolLoop(toolHandlerInfo, &request)
|
||||||
|
|
||||||
for err == nil {
|
|
||||||
log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1])
|
|
||||||
|
|
||||||
response, requestError := agent.client.Request(&request)
|
|
||||||
if requestError != nil {
|
|
||||||
return requestError
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Println(response)
|
|
||||||
|
|
||||||
a, innerErr := agent.toolHandler.Handle(toolHandlerInfo, &request)
|
|
||||||
|
|
||||||
err = innerErr
|
|
||||||
|
|
||||||
log.Println(a)
|
|
||||||
log.Println("--------------------------")
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: extract this into a more general tool handler package.
|
|
||||||
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) {
|
|
||||||
agentMessage := request.Messages[len(request.Messages)-1]
|
|
||||||
|
|
||||||
toolCall, ok := agentMessage.(AgentAssistantToolCall)
|
|
||||||
if !ok {
|
|
||||||
return "", errors.New("Latest message was not a tool call.")
|
|
||||||
}
|
|
||||||
|
|
||||||
fnName := toolCall.ToolCalls[0].Function.Name
|
|
||||||
arguments := toolCall.ToolCalls[0].Function.Arguments
|
|
||||||
|
|
||||||
if fnName == "finish" {
|
|
||||||
return "", errors.New("This is the end! Maybe we just return a boolean.")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn, exists := handler.Handlers[fnName]
|
|
||||||
if !exists {
|
|
||||||
return "", errors.New("Could not find tool with this name.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// holy jesus what the fuck.
|
|
||||||
parseMethod := reflect.ValueOf(fn).Field(1)
|
|
||||||
if !parseMethod.IsValid() {
|
|
||||||
return "", errors.New("Parse method not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedArgs := parseMethod.Call([]reflect.Value{reflect.ValueOf(arguments)})
|
|
||||||
if !parsedArgs[1].IsNil() {
|
|
||||||
return "", parsedArgs[1].Interface().(error)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Calling: %s\n", fnName)
|
|
||||||
|
|
||||||
fnMethod := reflect.ValueOf(fn).Field(2)
|
|
||||||
if !fnMethod.IsValid() {
|
|
||||||
return "", errors.New("Fn method not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
response := fnMethod.Call([]reflect.Value{reflect.ValueOf(info), parsedArgs[0], reflect.ValueOf(toolCall.ToolCalls[0])})
|
|
||||||
if !response[1].IsNil() {
|
|
||||||
return "", response[1].Interface().(error)
|
|
||||||
}
|
|
||||||
|
|
||||||
stringResponse, err := json.Marshal(response[0].Interface())
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
request.AddText(AgentTextMessage{
|
|
||||||
Role: "tool",
|
|
||||||
Name: "createLocation",
|
|
||||||
Content: string(stringResponse),
|
|
||||||
ToolCallId: toolCall.ToolCalls[0].Id,
|
|
||||||
})
|
|
||||||
|
|
||||||
return string(stringResponse), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
|
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
|
||||||
client, err := CreateAgentClient(eventLocationPrompt)
|
agentClient, err := client.CreateAgentClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return EventLocationAgent{}, err
|
return EventLocationAgent{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
agent := EventLocationAgent{
|
agent := EventLocationAgent{
|
||||||
client: client,
|
client: agentClient,
|
||||||
locationModel: locationModel,
|
locationModel: locationModel,
|
||||||
eventModel: eventModel,
|
eventModel: eventModel,
|
||||||
contactModel: contactModel,
|
contactModel: contactModel,
|
||||||
}
|
}
|
||||||
|
|
||||||
toolHandler := ToolsHandlers{
|
agentClient.ToolHandler.AddTool("listLocations",
|
||||||
Handlers: make(map[string]ToolHandlerInterface),
|
func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
}
|
return agent.locationModel.List(context.Background(), info.UserId)
|
||||||
|
|
||||||
toolHandler.Handlers["listLocations"] = ToolHandler[ListLocationArguments, []model.Locations]{
|
|
||||||
FunctionName: "listLocations",
|
|
||||||
Parse: func(stringArgs string) (ListLocationArguments, error) {
|
|
||||||
args := ListLocationArguments{}
|
|
||||||
err := json.Unmarshal([]byte(stringArgs), &args)
|
|
||||||
|
|
||||||
return args, err
|
|
||||||
},
|
},
|
||||||
Fn: func(info ToolHandlerInfo, _args ListLocationArguments, call ToolCall) ([]model.Locations, error) {
|
)
|
||||||
return agent.locationModel.List(context.Background(), info.userId)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
toolHandler.Handlers["createLocation"] = ToolHandler[CreateLocationArguments, model.Locations]{
|
agentClient.ToolHandler.AddTool("createLocation",
|
||||||
FunctionName: "createLocation",
|
func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
|
||||||
Parse: func(stringArgs string) (CreateLocationArguments, error) {
|
|
||||||
args := CreateLocationArguments{}
|
args := CreateLocationArguments{}
|
||||||
err := json.Unmarshal([]byte(stringArgs), &args)
|
err := json.Unmarshal([]byte(_args), &args)
|
||||||
|
if err != nil {
|
||||||
|
return model.Locations{}, err
|
||||||
|
}
|
||||||
|
|
||||||
return args, err
|
|
||||||
},
|
|
||||||
Fn: func(info ToolHandlerInfo, args CreateLocationArguments, call ToolCall) (model.Locations, error) {
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
location, err := agent.locationModel.Save(ctx, info.userId, model.Locations{
|
location, err := agent.locationModel.Save(ctx, info.UserId, model.Locations{
|
||||||
Name: args.Name,
|
Name: args.Name,
|
||||||
Address: args.Address,
|
Address: args.Address,
|
||||||
})
|
})
|
||||||
@ -317,21 +221,20 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
|
|||||||
return location, err
|
return location, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = agent.locationModel.SaveToImage(ctx, info.imageId, location.ID)
|
_, err = agent.locationModel.SaveToImage(ctx, info.ImageId, location.ID)
|
||||||
|
|
||||||
return location, err
|
return location, err
|
||||||
},
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
agentClient.ToolHandler.AddTool("createEvent",
|
||||||
|
func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
|
||||||
|
args := CreateEventArguments{}
|
||||||
|
err := json.Unmarshal([]byte(_args), &args)
|
||||||
|
if err != nil {
|
||||||
|
return model.Locations{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
toolHandler.Handlers["createEvent"] = ToolHandler[CreateEventArguments, model.Events]{
|
|
||||||
FunctionName: "createEvent",
|
|
||||||
Parse: func(stringArgs string) (CreateEventArguments, error) {
|
|
||||||
args := CreateEventArguments{}
|
|
||||||
err := json.Unmarshal([]byte(stringArgs), &args)
|
|
||||||
|
|
||||||
return args, err
|
|
||||||
},
|
|
||||||
Fn: func(info ToolHandlerInfo, args CreateEventArguments, call ToolCall) (model.Events, error) {
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
layout := "2006-01-02T15:04:05Z"
|
layout := "2006-01-02T15:04:05Z"
|
||||||
@ -346,7 +249,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
|
|||||||
return model.Events{}, err
|
return model.Events{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
event, err := agent.eventModel.Save(ctx, info.userId, model.Events{
|
event, err := agent.eventModel.Save(ctx, info.UserId, model.Events{
|
||||||
Name: args.Name,
|
Name: args.Name,
|
||||||
StartDateTime: &startTime,
|
StartDateTime: &startTime,
|
||||||
EndDateTime: &endTime,
|
EndDateTime: &endTime,
|
||||||
@ -356,7 +259,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
|
|||||||
return event, err
|
return event, err
|
||||||
}
|
}
|
||||||
|
|
||||||
organizer, err := agent.contactModel.Save(ctx, info.userId, model.Contacts{
|
organizer, err := agent.contactModel.Save(ctx, info.UserId, model.Contacts{
|
||||||
Name: args.Name,
|
Name: args.Name,
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -364,12 +267,12 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
|
|||||||
return event, err
|
return event, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = agent.eventModel.SaveToImage(ctx, info.imageId, event.ID)
|
_, err = agent.eventModel.SaveToImage(ctx, info.ImageId, event.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return event, err
|
return event, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = agent.contactModel.SaveToImage(ctx, info.imageId, organizer.ID)
|
_, err = agent.contactModel.SaveToImage(ctx, info.ImageId, organizer.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return event, err
|
return event, err
|
||||||
}
|
}
|
||||||
@ -386,9 +289,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
|
|||||||
|
|
||||||
return agent.eventModel.UpdateOrganizer(ctx, event.ID, organizer.ID)
|
return agent.eventModel.UpdateOrganizer(ctx, event.ID, organizer.ID)
|
||||||
},
|
},
|
||||||
}
|
)
|
||||||
|
|
||||||
agent.toolHandler = toolHandler
|
|
||||||
|
|
||||||
return agent, nil
|
return agent, nil
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package agents
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
|
"screenmark/screenmark/agents/client"
|
||||||
"screenmark/screenmark/models"
|
"screenmark/screenmark/models"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@ -19,26 +20,26 @@ Do not return anything except markdown.
|
|||||||
`
|
`
|
||||||
|
|
||||||
type NoteAgent struct {
|
type NoteAgent struct {
|
||||||
client AgentClient
|
client client.AgentClient
|
||||||
|
|
||||||
noteModel models.NoteModel
|
noteModel models.NoteModel
|
||||||
}
|
}
|
||||||
|
|
||||||
func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||||
request := AgentRequestBody{
|
request := client.AgentRequestBody{
|
||||||
Model: "pixtral-12b-2409",
|
Model: "pixtral-12b-2409",
|
||||||
Temperature: 0.3,
|
Temperature: 0.3,
|
||||||
ResponseFormat: ResponseFormat{
|
ResponseFormat: client.ResponseFormat{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
},
|
},
|
||||||
|
Chat: &client.Chat{
|
||||||
|
Messages: make([]client.ChatMessage, 0),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := request.AddSystem(noteAgentPrompt)
|
request.Chat.AddSystem(noteAgentPrompt)
|
||||||
if err != nil {
|
request.Chat.AddImage(imageName, imageData)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
request.AddImage(imageName, imageData)
|
|
||||||
resp, err := agent.client.Request(&request)
|
resp, err := agent.client.Request(&request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -65,7 +66,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) {
|
func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) {
|
||||||
client, err := CreateAgentClient(noteAgentPrompt)
|
client, err := client.CreateAgentClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return NoteAgent{}, err
|
return NoteAgent{}, err
|
||||||
}
|
}
|
||||||
|
167
backend/agents/orchestrator.go
Normal file
167
backend/agents/orchestrator.go
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
package agents
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"screenmark/screenmark/agents/client"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const orchestratorPrompt = `
|
||||||
|
You are an Orchestrator for various AI agents.
|
||||||
|
|
||||||
|
The user will send you images and you have to determine which agents you have to call, in order to best help the user.
|
||||||
|
|
||||||
|
You might decide no agent needs to be called.
|
||||||
|
|
||||||
|
The agents are available as tool calls.
|
||||||
|
|
||||||
|
Agents available:
|
||||||
|
|
||||||
|
eventLocationAgent
|
||||||
|
|
||||||
|
Use it when you think the image contains an event or a location of any sort. This can be an event page, a map, an address or a date.
|
||||||
|
|
||||||
|
noteAgent
|
||||||
|
|
||||||
|
Use it when there is text on the screen. Any text, always use this. Use me!
|
||||||
|
|
||||||
|
defaultAgent
|
||||||
|
|
||||||
|
When none of the above apply.
|
||||||
|
|
||||||
|
Always call agents in parallel if you need to call more than 1.
|
||||||
|
`
|
||||||
|
|
||||||
|
const MY_TOOLS = `
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "eventLocationAgent",
|
||||||
|
"description": "Uses the event location agent",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"required": []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "noteAgent",
|
||||||
|
"description": "Uses the note agent",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"required": []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "defaultAgent",
|
||||||
|
"description": "Used when you dont think its a good idea to call other agents",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"required": []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]`
|
||||||
|
|
||||||
|
type OrchestratorAgent struct {
|
||||||
|
client client.AgentClient
|
||||||
|
}
|
||||||
|
|
||||||
|
type Status struct {
|
||||||
|
Ok bool `json:"ok"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: the primary function of the agent could be extracted outwards.
|
||||||
|
// This is basically the same function as we have in the `event_location_agent.go`
|
||||||
|
func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||||
|
toolChoice := "any"
|
||||||
|
|
||||||
|
var tools any
|
||||||
|
err := json.Unmarshal([]byte(MY_TOOLS), &tools)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
request := client.AgentRequestBody{
|
||||||
|
Model: "pixtral-12b-2409",
|
||||||
|
Temperature: 0.3,
|
||||||
|
ResponseFormat: client.ResponseFormat{
|
||||||
|
Type: "text",
|
||||||
|
},
|
||||||
|
ToolChoice: &toolChoice,
|
||||||
|
Tools: &tools,
|
||||||
|
|
||||||
|
EndToolCall: "defaultAgent",
|
||||||
|
|
||||||
|
Chat: &client.Chat{
|
||||||
|
Messages: make([]client.ChatMessage, 0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
request.Chat.AddSystem(orchestratorPrompt)
|
||||||
|
request.Chat.AddImage(imageName, imageData)
|
||||||
|
|
||||||
|
res, err := agent.client.Request(&request)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(res)
|
||||||
|
|
||||||
|
toolHandlerInfo := client.ToolHandlerInfo{
|
||||||
|
ImageId: imageId,
|
||||||
|
UserId: userId,
|
||||||
|
}
|
||||||
|
|
||||||
|
return agent.client.ToolLoop(toolHandlerInfo, &request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {
|
||||||
|
agent, err := client.CreateAgentClient()
|
||||||
|
if err != nil {
|
||||||
|
return OrchestratorAgent{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
agent.ToolHandler.AddTool("eventLocationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
|
// We need a way to keep track of this async?
|
||||||
|
// Probably just a DB, because we don't want to wait. The orchistrator shouldnt wait for this stuff to finish.
|
||||||
|
|
||||||
|
go eventLocationAgent.GetLocations(info.UserId, info.ImageId, imageName, imageData)
|
||||||
|
|
||||||
|
return Status{
|
||||||
|
Ok: true,
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
|
go noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData)
|
||||||
|
|
||||||
|
return Status{
|
||||||
|
Ok: true,
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
agent.ToolHandler.AddTool("defaultAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
|
// To nothing
|
||||||
|
|
||||||
|
return Status{
|
||||||
|
Ok: true,
|
||||||
|
}, errors.New("Finished! Kinda bad return type but...")
|
||||||
|
})
|
||||||
|
|
||||||
|
return OrchestratorAgent{
|
||||||
|
client: agent,
|
||||||
|
}, nil
|
||||||
|
}
|
@ -1,26 +0,0 @@
|
|||||||
package agents
|
|
||||||
|
|
||||||
import "github.com/google/uuid"
|
|
||||||
|
|
||||||
type ToolHandlerInfo struct {
|
|
||||||
userId uuid.UUID
|
|
||||||
imageId uuid.UUID
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolHandler[TArgs any, TResp any] struct {
|
|
||||||
FunctionName string
|
|
||||||
Parse func(args string) (TArgs, error)
|
|
||||||
Fn func(info ToolHandlerInfo, args TArgs, call ToolCall) (TResp, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolHandlerInterface interface {
|
|
||||||
GetFunctionName() string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (handler ToolHandler[TArgs, TResp]) GetFunctionName() string {
|
|
||||||
return handler.FunctionName
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolsHandlers struct {
|
|
||||||
Handlers map[string]ToolHandlerInterface
|
|
||||||
}
|
|
@ -10,6 +10,6 @@ require (
|
|||||||
github.com/joho/godotenv v1.5.1 // indirect
|
github.com/joho/godotenv v1.5.1 // indirect
|
||||||
github.com/lib/pq v1.10.9 // indirect
|
github.com/lib/pq v1.10.9 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/stretchr/testify v1.9.0 // indirect
|
github.com/stretchr/testify v1.10.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
@ -14,6 +14,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
|||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
"screenmark/screenmark/agents"
|
"screenmark/screenmark/agents"
|
||||||
|
"screenmark/screenmark/agents/client"
|
||||||
"screenmark/screenmark/models"
|
"screenmark/screenmark/models"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -24,41 +25,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type TestAiClient struct {
|
type TestAiClient struct {
|
||||||
ImageInfo agents.ImageInfo
|
ImageInfo client.ImageMessageContent
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (agents.ImageInfo, error) {
|
func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (client.ImageMessageContent, error) {
|
||||||
return client.ImageInfo, nil
|
return client.ImageInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAiClient() (agents.AiClient, error) {
|
|
||||||
mode := os.Getenv("MODE")
|
|
||||||
if mode == "TESTING" {
|
|
||||||
address := "10 Downing Street"
|
|
||||||
description := "Cheese and Crackers"
|
|
||||||
|
|
||||||
return TestAiClient{
|
|
||||||
ImageInfo: agents.ImageInfo{
|
|
||||||
Tags: []string{"tag"},
|
|
||||||
Links: []string{"links"},
|
|
||||||
Text: []string{"text"},
|
|
||||||
Locations: []model.Locations{{
|
|
||||||
ID: uuid.Nil,
|
|
||||||
Name: "London",
|
|
||||||
Address: &address,
|
|
||||||
}},
|
|
||||||
Events: []model.Events{{
|
|
||||||
ID: uuid.Nil,
|
|
||||||
Name: "Party",
|
|
||||||
Description: &description,
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return agents.CreateAgentClient(agents.PROMPT)
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
err := godotenv.Load()
|
err := godotenv.Load()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -74,9 +47,6 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
imageModel := models.NewImageModel(db)
|
imageModel := models.NewImageModel(db)
|
||||||
linkModel := models.NewLinkModel(db)
|
|
||||||
tagModel := models.NewTagModel(db)
|
|
||||||
textModel := models.NewTextModel(db)
|
|
||||||
locationModel := models.NewLocationModel(db)
|
locationModel := models.NewLocationModel(db)
|
||||||
eventModel := models.NewEventModel(db)
|
eventModel := models.NewEventModel(db)
|
||||||
userModel := models.NewUserModel(db)
|
userModel := models.NewUserModel(db)
|
||||||
@ -105,11 +75,6 @@ func main() {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
openAiClient, err := GetAiClient()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
locationAgent, err := agents.NewLocationEventAgent(locationModel, eventModel, contactModel)
|
locationAgent, err := agents.NewLocationEventAgent(locationModel, eventModel, contactModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
@ -127,51 +92,21 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userImage, err := imageModel.FinishProcessing(ctx, image.ID)
|
_, err = imageModel.FinishProcessing(ctx, image.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Failed to FinishProcessing")
|
log.Println("Failed to FinishProcessing")
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: this can very much be parallel
|
orchestrator, err := agents.NewOrchestratorAgent(locationAgent, noteAgent, image.Image.ImageName, image.Image.Image)
|
||||||
|
|
||||||
log.Println("Calling locationAgent!")
|
|
||||||
err = locationAgent.GetLocations(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
|
||||||
log.Println(err)
|
|
||||||
|
|
||||||
log.Println("Calling noteAgent!")
|
|
||||||
err = noteAgent.GetNotes(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
|
||||||
log.Println(err)
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Failed to GetImageInfo")
|
panic(err)
|
||||||
log.Println(err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tagModel.SaveToImage(ctx, userImage.ImageID, imageInfo.Tags)
|
err = orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Failed to save tags")
|
fmt.Println(err)
|
||||||
log.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = linkModel.Save(ctx, userImage.ImageID, imageInfo.Links)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("Failed to save links")
|
|
||||||
log.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = textModel.Save(ctx, userImage.ImageID, imageInfo.Text)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("Failed to save text")
|
|
||||||
log.Println(err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user