refactor(naming): using Agent instead of openai

This commit is contained in:
2025-03-21 17:07:00 +00:00
parent 4ea817e81f
commit 1cd4698969
2 changed files with 39 additions and 39 deletions

View File

@ -28,42 +28,42 @@ type ResponseFormat struct {
JsonSchema any `json:"json_schema"` JsonSchema any `json:"json_schema"`
} }
type OpenAiRequestBody struct { type AgentRequestBody struct {
Model string `json:"model"` Model string `json:"model"`
Temperature float64 `json:"temperature"` Temperature float64 `json:"temperature"`
ResponseFormat ResponseFormat `json:"response_format"` ResponseFormat ResponseFormat `json:"response_format"`
OpenAiMessages AgentMessages
} }
type OpenAiMessages struct { type AgentMessages struct {
Messages []OpenAiMessage `json:"messages"` Messages []AgentMessage `json:"messages"`
} }
type OpenAiMessage interface { type AgentMessage interface {
MessageToJson() ([]byte, error) MessageToJson() ([]byte, error)
} }
type OpenAiTextMessage struct { type AgentTextMessage struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
} }
func (textContent OpenAiTextMessage) MessageToJson() ([]byte, error) { func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
// TODO: Validate the `Role`. // TODO: Validate the `Role`.
return json.Marshal(textContent) return json.Marshal(textContent)
} }
type OpenAiArrayMessage struct { type AgentArrayMessage struct {
Role string `json:"role"` Role string `json:"role"`
Content []OpenAiContent `json:"content"` Content []AgentContent `json:"content"`
} }
func (arrayContent OpenAiArrayMessage) MessageToJson() ([]byte, error) { func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
return json.Marshal(arrayContent) return json.Marshal(arrayContent)
} }
func (content *OpenAiMessages) AddImage(imageName string, image []byte) error { func (content *AgentMessages) AddImage(imageName string, image []byte) error {
extension := filepath.Ext(imageName) extension := filepath.Ext(imageName)
if len(extension) == 0 { if len(extension) == 0 {
// TODO: could also validate for image types we support. // TODO: could also validate for image types we support.
@ -74,8 +74,8 @@ func (content *OpenAiMessages) AddImage(imageName string, image []byte) error {
encodedString := base64.StdEncoding.EncodeToString(image) encodedString := base64.StdEncoding.EncodeToString(image)
arrayMessage := OpenAiArrayMessage{Role: ROLE_USER, Content: make([]OpenAiContent, 1)} arrayMessage := AgentArrayMessage{Role: ROLE_USER, Content: make([]AgentContent, 1)}
arrayMessage.Content[0] = OpenAiImage{ arrayMessage.Content[0] = AgentImage{
ImageType: IMAGE_TYPE, ImageType: IMAGE_TYPE,
ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString), ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
} }
@ -85,12 +85,12 @@ func (content *OpenAiMessages) AddImage(imageName string, image []byte) error {
return nil return nil
} }
func (content *OpenAiMessages) AddSystem(prompt string) error { func (content *AgentMessages) AddSystem(prompt string) error {
if len(content.Messages) != 0 { if len(content.Messages) != 0 {
return errors.New("You can only add a system prompt at the beginning") return errors.New("You can only add a system prompt at the beginning")
} }
content.Messages = append(content.Messages, OpenAiTextMessage{ content.Messages = append(content.Messages, AgentTextMessage{
Role: ROLE_SYSTEM, Role: ROLE_SYSTEM,
Content: prompt, Content: prompt,
}) })
@ -98,7 +98,7 @@ func (content *OpenAiMessages) AddSystem(prompt string) error {
return nil return nil
} }
type OpenAiContent interface { type AgentContent interface {
ToJson() ([]byte, error) ToJson() ([]byte, error)
} }
@ -106,12 +106,12 @@ type ImageUrl struct {
Url string `json:"url"` Url string `json:"url"`
} }
type OpenAiImage struct { type AgentImage struct {
ImageType string `json:"type"` ImageType string `json:"type"`
ImageUrl string `json:"image_url"` ImageUrl string `json:"image_url"`
} }
func (imageMessage OpenAiImage) ToJson() ([]byte, error) { func (imageMessage AgentImage) ToJson() ([]byte, error) {
imageMessage.ImageType = IMAGE_TYPE imageMessage.ImageType = IMAGE_TYPE
return json.Marshal(imageMessage) return json.Marshal(imageMessage)
} }
@ -120,7 +120,7 @@ type AiClient interface {
GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error)
} }
type OpenAiClient struct { type AgentClient struct {
url string url string
apiKey string apiKey string
systemPrompt string systemPrompt string
@ -129,7 +129,7 @@ type OpenAiClient struct {
Do func(req *http.Request) (*http.Response, error) Do func(req *http.Request) (*http.Response, error)
} }
// func (client OpenAiClient) Do(req *http.Request) () { // func (client AgentClient) Do(req *http.Request) () {
// httpClient := http.Client{} // httpClient := http.Client{}
// return httpClient.Do(req) // return httpClient.Do(req)
// } // }
@ -262,14 +262,14 @@ const RESPONSE_FORMAT = `
} }
` `
func CreateOpenAiClient() (OpenAiClient, error) { func CreateAgentClient() (AgentClient, error) {
apiKey := os.Getenv(OPENAI_API_KEY) apiKey := os.Getenv(OPENAI_API_KEY)
if len(apiKey) == 0 { if len(apiKey) == 0 {
return OpenAiClient{}, errors.New(OPENAI_API_KEY + " was not found.") return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
} }
return OpenAiClient{ return AgentClient{
apiKey: apiKey, apiKey: apiKey,
url: "https://api.mistral.ai/v1/chat/completions", url: "https://api.mistral.ai/v1/chat/completions",
systemPrompt: PROMPT, systemPrompt: PROMPT,
@ -280,7 +280,7 @@ func CreateOpenAiClient() (OpenAiClient, error) {
}, nil }, nil
} }
func (client OpenAiClient) getRequest(body []byte) (*http.Request, error) { func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body)) req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body))
if err != nil { if err != nil {
return req, err return req, err
@ -292,8 +292,8 @@ func (client OpenAiClient) getRequest(body []byte) (*http.Request, error) {
return req, nil return req, nil
} }
func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (OpenAiRequestBody, error) { func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (AgentRequestBody, error) {
request := OpenAiRequestBody{ request := AgentRequestBody{
Model: model, Model: model,
Temperature: temperature, Temperature: temperature,
ResponseFormat: ResponseFormat{ ResponseFormat: ResponseFormat{
@ -328,7 +328,7 @@ type ResponseChoice struct {
FinishReason string `json:"finish_reason"` FinishReason string `json:"finish_reason"`
} }
type OpenAiResponse struct { type AgentResponse struct {
Id string `json:"id"` Id string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
Choices []ResponseChoice `json:"choices"` Choices []ResponseChoice `json:"choices"`
@ -336,8 +336,8 @@ type OpenAiResponse struct {
} }
// TODO: add usage parsing // TODO: add usage parsing
func parseOpenAiResponse(jsonResponse []byte) (ImageInfo, error) { func parseAgentResponse(jsonResponse []byte) (ImageInfo, error) {
response := OpenAiResponse{} response := AgentResponse{}
err := json.Unmarshal(jsonResponse, &response) err := json.Unmarshal(jsonResponse, &response)
if err != nil { if err != nil {
@ -358,7 +358,7 @@ func parseOpenAiResponse(jsonResponse []byte) (ImageInfo, error) {
return imageInfo, nil return imageInfo, nil
} }
func (client OpenAiClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) { func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData) aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData)
if err != nil { if err != nil {
return ImageInfo{}, err return ImageInfo{}, err
@ -399,5 +399,5 @@ func (client OpenAiClient) GetImageInfo(imageName string, imageData []byte) (Ima
log.Println(string(response)) log.Println(string(response))
return parseOpenAiResponse(response) return parseAgentResponse(response)
} }

View File

@ -11,7 +11,7 @@ import (
) )
func TestMessageBuilder(t *testing.T) { func TestMessageBuilder(t *testing.T) {
content := OpenAiMessages{} content := AgentMessages{}
err := content.AddSystem("Some prompt") err := content.AddSystem("Some prompt")
@ -27,7 +27,7 @@ func TestMessageBuilder(t *testing.T) {
} }
func TestMessageBuilderImage(t *testing.T) { func TestMessageBuilderImage(t *testing.T) {
content := OpenAiMessages{} content := AgentMessages{}
prompt := "some prompt" prompt := "some prompt"
imageTitle := "image.png" imageTitle := "image.png"
@ -41,7 +41,7 @@ func TestMessageBuilderImage(t *testing.T) {
t.FailNow() t.FailNow()
} }
promptMessage, ok := content.Messages[0].(OpenAiTextMessage) promptMessage, ok := content.Messages[0].(AgentTextMessage)
if !ok { if !ok {
t.Logf("Expected text content message, got %T\n", content.Messages[0]) t.Logf("Expected text content message, got %T\n", content.Messages[0])
t.FailNow() t.FailNow()
@ -57,7 +57,7 @@ func TestMessageBuilderImage(t *testing.T) {
t.FailNow() t.FailNow()
} }
arrayContentMessage, ok := content.Messages[1].(OpenAiArrayMessage) arrayContentMessage, ok := content.Messages[1].(AgentArrayMessage)
if !ok { if !ok {
t.Logf("Expected text content message, got %T\n", content.Messages[1]) t.Logf("Expected text content message, got %T\n", content.Messages[1])
t.FailNow() t.FailNow()
@ -73,7 +73,7 @@ func TestMessageBuilderImage(t *testing.T) {
t.FailNow() t.FailNow()
} }
imageContent, ok := arrayContentMessage.Content[0].(OpenAiImage) imageContent, ok := arrayContentMessage.Content[0].(AgentImage)
if !ok { if !ok {
t.Logf("Expected text content message, got %T\n", arrayContentMessage.Content[0]) t.Logf("Expected text content message, got %T\n", arrayContentMessage.Content[0])
t.FailNow() t.FailNow()
@ -115,7 +115,7 @@ func TestResponse(t *testing.T) {
body := io.NopCloser(buffer) body := io.NopCloser(buffer)
client := OpenAiClient{ client := AgentClient{
url: "http://localhost:1234", url: "http://localhost:1234",
apiKey: "some-key", apiKey: "some-key",
Do: func(_req *http.Request) (*http.Response, error) { Do: func(_req *http.Request) (*http.Response, error) {
@ -187,7 +187,7 @@ func TestResponseParsing(t *testing.T) {
"system_fingerprint": "fp_7fcd609668" "system_fingerprint": "fp_7fcd609668"
}` }`
imageParsed, err := parseOpenAiResponse([]byte(response)) imageParsed, err := parseAgentResponse([]byte(response))
if err != nil { if err != nil {
t.Log(err) t.Log(err)
t.FailNow() t.FailNow()