From 7f96d2fc4596e729e2519cfc368ea4026a2d6606 Mon Sep 17 00:00:00 2001 From: John Costa Date: Fri, 21 Mar 2025 17:07:00 +0000 Subject: [PATCH] refactor(naming): using `Agent` instead of `openai` --- backend/{openai.go => agents/agent.go} | 64 +++++++++---------- .../{openai_test.go => agents/agent_test.go} | 14 ++-- 2 files changed, 39 insertions(+), 39 deletions(-) rename backend/{openai.go => agents/agent.go} (86%) rename backend/{openai_test.go => agents/agent_test.go} (93%) diff --git a/backend/openai.go b/backend/agents/agent.go similarity index 86% rename from backend/openai.go rename to backend/agents/agent.go index 0eab6af..346cfe3 100644 --- a/backend/openai.go +++ b/backend/agents/agent.go @@ -28,42 +28,42 @@ type ResponseFormat struct { JsonSchema any `json:"json_schema"` } -type OpenAiRequestBody struct { +type AgentRequestBody struct { Model string `json:"model"` Temperature float64 `json:"temperature"` ResponseFormat ResponseFormat `json:"response_format"` - OpenAiMessages + AgentMessages } -type OpenAiMessages struct { - Messages []OpenAiMessage `json:"messages"` +type AgentMessages struct { + Messages []AgentMessage `json:"messages"` } -type OpenAiMessage interface { +type AgentMessage interface { MessageToJson() ([]byte, error) } -type OpenAiTextMessage struct { +type AgentTextMessage struct { Role string `json:"role"` Content string `json:"content"` } -func (textContent OpenAiTextMessage) MessageToJson() ([]byte, error) { +func (textContent AgentTextMessage) MessageToJson() ([]byte, error) { // TODO: Validate the `Role`. return json.Marshal(textContent) } -type OpenAiArrayMessage struct { - Role string `json:"role"` - Content []OpenAiContent `json:"content"` +type AgentArrayMessage struct { + Role string `json:"role"` + Content []AgentContent `json:"content"` } -func (arrayContent OpenAiArrayMessage) MessageToJson() ([]byte, error) { +func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) { return json.Marshal(arrayContent) } -func (content *OpenAiMessages) AddImage(imageName string, image []byte) error { +func (content *AgentMessages) AddImage(imageName string, image []byte) error { extension := filepath.Ext(imageName) if len(extension) == 0 { // TODO: could also validate for image types we support. @@ -74,8 +74,8 @@ func (content *OpenAiMessages) AddImage(imageName string, image []byte) error { encodedString := base64.StdEncoding.EncodeToString(image) - arrayMessage := OpenAiArrayMessage{Role: ROLE_USER, Content: make([]OpenAiContent, 1)} - arrayMessage.Content[0] = OpenAiImage{ + arrayMessage := AgentArrayMessage{Role: ROLE_USER, Content: make([]AgentContent, 1)} + arrayMessage.Content[0] = AgentImage{ ImageType: IMAGE_TYPE, ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString), } @@ -85,12 +85,12 @@ func (content *OpenAiMessages) AddImage(imageName string, image []byte) error { return nil } -func (content *OpenAiMessages) AddSystem(prompt string) error { +func (content *AgentMessages) AddSystem(prompt string) error { if len(content.Messages) != 0 { return errors.New("You can only add a system prompt at the beginning") } - content.Messages = append(content.Messages, OpenAiTextMessage{ + content.Messages = append(content.Messages, AgentTextMessage{ Role: ROLE_SYSTEM, Content: prompt, }) @@ -98,7 +98,7 @@ func (content *OpenAiMessages) AddSystem(prompt string) error { return nil } -type OpenAiContent interface { +type AgentContent interface { ToJson() ([]byte, error) } @@ -106,12 +106,12 @@ type ImageUrl struct { Url string `json:"url"` } -type OpenAiImage struct { +type AgentImage struct { ImageType string `json:"type"` ImageUrl string `json:"image_url"` } -func (imageMessage OpenAiImage) ToJson() ([]byte, error) { +func (imageMessage AgentImage) ToJson() ([]byte, error) { imageMessage.ImageType = IMAGE_TYPE return json.Marshal(imageMessage) } @@ -120,7 +120,7 @@ type AiClient interface { GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) } -type OpenAiClient struct { +type AgentClient struct { url string apiKey string systemPrompt string @@ -129,7 +129,7 @@ type OpenAiClient struct { Do func(req *http.Request) (*http.Response, error) } -// func (client OpenAiClient) Do(req *http.Request) () { +// func (client AgentClient) Do(req *http.Request) () { // httpClient := http.Client{} // return httpClient.Do(req) // } @@ -262,14 +262,14 @@ const RESPONSE_FORMAT = ` } ` -func CreateOpenAiClient() (OpenAiClient, error) { +func CreateAgentClient() (AgentClient, error) { apiKey := os.Getenv(OPENAI_API_KEY) if len(apiKey) == 0 { - return OpenAiClient{}, errors.New(OPENAI_API_KEY + " was not found.") + return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.") } - return OpenAiClient{ + return AgentClient{ apiKey: apiKey, url: "https://api.mistral.ai/v1/chat/completions", systemPrompt: PROMPT, @@ -280,7 +280,7 @@ func CreateOpenAiClient() (OpenAiClient, error) { }, nil } -func (client OpenAiClient) getRequest(body []byte) (*http.Request, error) { +func (client AgentClient) getRequest(body []byte) (*http.Request, error) { req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body)) if err != nil { return req, err @@ -292,8 +292,8 @@ func (client OpenAiClient) getRequest(body []byte) (*http.Request, error) { return req, nil } -func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (OpenAiRequestBody, error) { - request := OpenAiRequestBody{ +func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (AgentRequestBody, error) { + request := AgentRequestBody{ Model: model, Temperature: temperature, ResponseFormat: ResponseFormat{ @@ -328,7 +328,7 @@ type ResponseChoice struct { FinishReason string `json:"finish_reason"` } -type OpenAiResponse struct { +type AgentResponse struct { Id string `json:"id"` Object string `json:"object"` Choices []ResponseChoice `json:"choices"` @@ -336,8 +336,8 @@ type OpenAiResponse struct { } // TODO: add usage parsing -func parseOpenAiResponse(jsonResponse []byte) (ImageInfo, error) { - response := OpenAiResponse{} +func parseAgentResponse(jsonResponse []byte) (ImageInfo, error) { + response := AgentResponse{} err := json.Unmarshal(jsonResponse, &response) if err != nil { @@ -358,7 +358,7 @@ func parseOpenAiResponse(jsonResponse []byte) (ImageInfo, error) { return imageInfo, nil } -func (client OpenAiClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) { +func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) { aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData) if err != nil { return ImageInfo{}, err @@ -399,5 +399,5 @@ func (client OpenAiClient) GetImageInfo(imageName string, imageData []byte) (Ima log.Println(string(response)) - return parseOpenAiResponse(response) + return parseAgentResponse(response) } diff --git a/backend/openai_test.go b/backend/agents/agent_test.go similarity index 93% rename from backend/openai_test.go rename to backend/agents/agent_test.go index 387e82f..f6c1a28 100644 --- a/backend/openai_test.go +++ b/backend/agents/agent_test.go @@ -11,7 +11,7 @@ import ( ) func TestMessageBuilder(t *testing.T) { - content := OpenAiMessages{} + content := AgentMessages{} err := content.AddSystem("Some prompt") @@ -27,7 +27,7 @@ func TestMessageBuilder(t *testing.T) { } func TestMessageBuilderImage(t *testing.T) { - content := OpenAiMessages{} + content := AgentMessages{} prompt := "some prompt" imageTitle := "image.png" @@ -41,7 +41,7 @@ func TestMessageBuilderImage(t *testing.T) { t.FailNow() } - promptMessage, ok := content.Messages[0].(OpenAiTextMessage) + promptMessage, ok := content.Messages[0].(AgentTextMessage) if !ok { t.Logf("Expected text content message, got %T\n", content.Messages[0]) t.FailNow() @@ -57,7 +57,7 @@ func TestMessageBuilderImage(t *testing.T) { t.FailNow() } - arrayContentMessage, ok := content.Messages[1].(OpenAiArrayMessage) + arrayContentMessage, ok := content.Messages[1].(AgentArrayMessage) if !ok { t.Logf("Expected text content message, got %T\n", content.Messages[1]) t.FailNow() @@ -73,7 +73,7 @@ func TestMessageBuilderImage(t *testing.T) { t.FailNow() } - imageContent, ok := arrayContentMessage.Content[0].(OpenAiImage) + imageContent, ok := arrayContentMessage.Content[0].(AgentImage) if !ok { t.Logf("Expected text content message, got %T\n", arrayContentMessage.Content[0]) t.FailNow() @@ -115,7 +115,7 @@ func TestResponse(t *testing.T) { body := io.NopCloser(buffer) - client := OpenAiClient{ + client := AgentClient{ url: "http://localhost:1234", apiKey: "some-key", Do: func(_req *http.Request) (*http.Response, error) { @@ -187,7 +187,7 @@ func TestResponseParsing(t *testing.T) { "system_fingerprint": "fp_7fcd609668" }` - imageParsed, err := parseOpenAiResponse([]byte(response)) + imageParsed, err := parseAgentResponse([]byte(response)) if err != nil { t.Log(err) t.FailNow()