refactor(naming): using Agent
instead of openai
This commit is contained in:
@ -28,42 +28,42 @@ type ResponseFormat struct {
|
|||||||
JsonSchema any `json:"json_schema"`
|
JsonSchema any `json:"json_schema"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAiRequestBody struct {
|
type AgentRequestBody struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Temperature float64 `json:"temperature"`
|
Temperature float64 `json:"temperature"`
|
||||||
ResponseFormat ResponseFormat `json:"response_format"`
|
ResponseFormat ResponseFormat `json:"response_format"`
|
||||||
|
|
||||||
OpenAiMessages
|
AgentMessages
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAiMessages struct {
|
type AgentMessages struct {
|
||||||
Messages []OpenAiMessage `json:"messages"`
|
Messages []AgentMessage `json:"messages"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAiMessage interface {
|
type AgentMessage interface {
|
||||||
MessageToJson() ([]byte, error)
|
MessageToJson() ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAiTextMessage struct {
|
type AgentTextMessage struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (textContent OpenAiTextMessage) MessageToJson() ([]byte, error) {
|
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
|
||||||
// TODO: Validate the `Role`.
|
// TODO: Validate the `Role`.
|
||||||
return json.Marshal(textContent)
|
return json.Marshal(textContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAiArrayMessage struct {
|
type AgentArrayMessage struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content []OpenAiContent `json:"content"`
|
Content []AgentContent `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (arrayContent OpenAiArrayMessage) MessageToJson() ([]byte, error) {
|
func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
|
||||||
return json.Marshal(arrayContent)
|
return json.Marshal(arrayContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (content *OpenAiMessages) AddImage(imageName string, image []byte) error {
|
func (content *AgentMessages) AddImage(imageName string, image []byte) error {
|
||||||
extension := filepath.Ext(imageName)
|
extension := filepath.Ext(imageName)
|
||||||
if len(extension) == 0 {
|
if len(extension) == 0 {
|
||||||
// TODO: could also validate for image types we support.
|
// TODO: could also validate for image types we support.
|
||||||
@ -74,8 +74,8 @@ func (content *OpenAiMessages) AddImage(imageName string, image []byte) error {
|
|||||||
|
|
||||||
encodedString := base64.StdEncoding.EncodeToString(image)
|
encodedString := base64.StdEncoding.EncodeToString(image)
|
||||||
|
|
||||||
arrayMessage := OpenAiArrayMessage{Role: ROLE_USER, Content: make([]OpenAiContent, 1)}
|
arrayMessage := AgentArrayMessage{Role: ROLE_USER, Content: make([]AgentContent, 1)}
|
||||||
arrayMessage.Content[0] = OpenAiImage{
|
arrayMessage.Content[0] = AgentImage{
|
||||||
ImageType: IMAGE_TYPE,
|
ImageType: IMAGE_TYPE,
|
||||||
ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
|
ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
|
||||||
}
|
}
|
||||||
@ -85,12 +85,12 @@ func (content *OpenAiMessages) AddImage(imageName string, image []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (content *OpenAiMessages) AddSystem(prompt string) error {
|
func (content *AgentMessages) AddSystem(prompt string) error {
|
||||||
if len(content.Messages) != 0 {
|
if len(content.Messages) != 0 {
|
||||||
return errors.New("You can only add a system prompt at the beginning")
|
return errors.New("You can only add a system prompt at the beginning")
|
||||||
}
|
}
|
||||||
|
|
||||||
content.Messages = append(content.Messages, OpenAiTextMessage{
|
content.Messages = append(content.Messages, AgentTextMessage{
|
||||||
Role: ROLE_SYSTEM,
|
Role: ROLE_SYSTEM,
|
||||||
Content: prompt,
|
Content: prompt,
|
||||||
})
|
})
|
||||||
@ -98,7 +98,7 @@ func (content *OpenAiMessages) AddSystem(prompt string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAiContent interface {
|
type AgentContent interface {
|
||||||
ToJson() ([]byte, error)
|
ToJson() ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,12 +106,12 @@ type ImageUrl struct {
|
|||||||
Url string `json:"url"`
|
Url string `json:"url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAiImage struct {
|
type AgentImage struct {
|
||||||
ImageType string `json:"type"`
|
ImageType string `json:"type"`
|
||||||
ImageUrl string `json:"image_url"`
|
ImageUrl string `json:"image_url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (imageMessage OpenAiImage) ToJson() ([]byte, error) {
|
func (imageMessage AgentImage) ToJson() ([]byte, error) {
|
||||||
imageMessage.ImageType = IMAGE_TYPE
|
imageMessage.ImageType = IMAGE_TYPE
|
||||||
return json.Marshal(imageMessage)
|
return json.Marshal(imageMessage)
|
||||||
}
|
}
|
||||||
@ -120,7 +120,7 @@ type AiClient interface {
|
|||||||
GetImageInfo(imageName string, imageData []byte) (ImageInfo, error)
|
GetImageInfo(imageName string, imageData []byte) (ImageInfo, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAiClient struct {
|
type AgentClient struct {
|
||||||
url string
|
url string
|
||||||
apiKey string
|
apiKey string
|
||||||
systemPrompt string
|
systemPrompt string
|
||||||
@ -129,7 +129,7 @@ type OpenAiClient struct {
|
|||||||
Do func(req *http.Request) (*http.Response, error)
|
Do func(req *http.Request) (*http.Response, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// func (client OpenAiClient) Do(req *http.Request) () {
|
// func (client AgentClient) Do(req *http.Request) () {
|
||||||
// httpClient := http.Client{}
|
// httpClient := http.Client{}
|
||||||
// return httpClient.Do(req)
|
// return httpClient.Do(req)
|
||||||
// }
|
// }
|
||||||
@ -262,14 +262,14 @@ const RESPONSE_FORMAT = `
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
|
|
||||||
func CreateOpenAiClient() (OpenAiClient, error) {
|
func CreateAgentClient() (AgentClient, error) {
|
||||||
apiKey := os.Getenv(OPENAI_API_KEY)
|
apiKey := os.Getenv(OPENAI_API_KEY)
|
||||||
|
|
||||||
if len(apiKey) == 0 {
|
if len(apiKey) == 0 {
|
||||||
return OpenAiClient{}, errors.New(OPENAI_API_KEY + " was not found.")
|
return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
|
||||||
}
|
}
|
||||||
|
|
||||||
return OpenAiClient{
|
return AgentClient{
|
||||||
apiKey: apiKey,
|
apiKey: apiKey,
|
||||||
url: "https://api.mistral.ai/v1/chat/completions",
|
url: "https://api.mistral.ai/v1/chat/completions",
|
||||||
systemPrompt: PROMPT,
|
systemPrompt: PROMPT,
|
||||||
@ -280,7 +280,7 @@ func CreateOpenAiClient() (OpenAiClient, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client OpenAiClient) getRequest(body []byte) (*http.Request, error) {
|
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
|
||||||
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body))
|
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return req, err
|
return req, err
|
||||||
@ -292,8 +292,8 @@ func (client OpenAiClient) getRequest(body []byte) (*http.Request, error) {
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (OpenAiRequestBody, error) {
|
func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (AgentRequestBody, error) {
|
||||||
request := OpenAiRequestBody{
|
request := AgentRequestBody{
|
||||||
Model: model,
|
Model: model,
|
||||||
Temperature: temperature,
|
Temperature: temperature,
|
||||||
ResponseFormat: ResponseFormat{
|
ResponseFormat: ResponseFormat{
|
||||||
@ -328,7 +328,7 @@ type ResponseChoice struct {
|
|||||||
FinishReason string `json:"finish_reason"`
|
FinishReason string `json:"finish_reason"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAiResponse struct {
|
type AgentResponse struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Choices []ResponseChoice `json:"choices"`
|
Choices []ResponseChoice `json:"choices"`
|
||||||
@ -336,8 +336,8 @@ type OpenAiResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: add usage parsing
|
// TODO: add usage parsing
|
||||||
func parseOpenAiResponse(jsonResponse []byte) (ImageInfo, error) {
|
func parseAgentResponse(jsonResponse []byte) (ImageInfo, error) {
|
||||||
response := OpenAiResponse{}
|
response := AgentResponse{}
|
||||||
|
|
||||||
err := json.Unmarshal(jsonResponse, &response)
|
err := json.Unmarshal(jsonResponse, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -358,7 +358,7 @@ func parseOpenAiResponse(jsonResponse []byte) (ImageInfo, error) {
|
|||||||
return imageInfo, nil
|
return imageInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client OpenAiClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
|
func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
|
||||||
aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData)
|
aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ImageInfo{}, err
|
return ImageInfo{}, err
|
||||||
@ -399,5 +399,5 @@ func (client OpenAiClient) GetImageInfo(imageName string, imageData []byte) (Ima
|
|||||||
|
|
||||||
log.Println(string(response))
|
log.Println(string(response))
|
||||||
|
|
||||||
return parseOpenAiResponse(response)
|
return parseAgentResponse(response)
|
||||||
}
|
}
|
@ -11,7 +11,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestMessageBuilder(t *testing.T) {
|
func TestMessageBuilder(t *testing.T) {
|
||||||
content := OpenAiMessages{}
|
content := AgentMessages{}
|
||||||
|
|
||||||
err := content.AddSystem("Some prompt")
|
err := content.AddSystem("Some prompt")
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ func TestMessageBuilder(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMessageBuilderImage(t *testing.T) {
|
func TestMessageBuilderImage(t *testing.T) {
|
||||||
content := OpenAiMessages{}
|
content := AgentMessages{}
|
||||||
|
|
||||||
prompt := "some prompt"
|
prompt := "some prompt"
|
||||||
imageTitle := "image.png"
|
imageTitle := "image.png"
|
||||||
@ -41,7 +41,7 @@ func TestMessageBuilderImage(t *testing.T) {
|
|||||||
t.FailNow()
|
t.FailNow()
|
||||||
}
|
}
|
||||||
|
|
||||||
promptMessage, ok := content.Messages[0].(OpenAiTextMessage)
|
promptMessage, ok := content.Messages[0].(AgentTextMessage)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Logf("Expected text content message, got %T\n", content.Messages[0])
|
t.Logf("Expected text content message, got %T\n", content.Messages[0])
|
||||||
t.FailNow()
|
t.FailNow()
|
||||||
@ -57,7 +57,7 @@ func TestMessageBuilderImage(t *testing.T) {
|
|||||||
t.FailNow()
|
t.FailNow()
|
||||||
}
|
}
|
||||||
|
|
||||||
arrayContentMessage, ok := content.Messages[1].(OpenAiArrayMessage)
|
arrayContentMessage, ok := content.Messages[1].(AgentArrayMessage)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Logf("Expected text content message, got %T\n", content.Messages[1])
|
t.Logf("Expected text content message, got %T\n", content.Messages[1])
|
||||||
t.FailNow()
|
t.FailNow()
|
||||||
@ -73,7 +73,7 @@ func TestMessageBuilderImage(t *testing.T) {
|
|||||||
t.FailNow()
|
t.FailNow()
|
||||||
}
|
}
|
||||||
|
|
||||||
imageContent, ok := arrayContentMessage.Content[0].(OpenAiImage)
|
imageContent, ok := arrayContentMessage.Content[0].(AgentImage)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Logf("Expected text content message, got %T\n", arrayContentMessage.Content[0])
|
t.Logf("Expected text content message, got %T\n", arrayContentMessage.Content[0])
|
||||||
t.FailNow()
|
t.FailNow()
|
||||||
@ -115,7 +115,7 @@ func TestResponse(t *testing.T) {
|
|||||||
|
|
||||||
body := io.NopCloser(buffer)
|
body := io.NopCloser(buffer)
|
||||||
|
|
||||||
client := OpenAiClient{
|
client := AgentClient{
|
||||||
url: "http://localhost:1234",
|
url: "http://localhost:1234",
|
||||||
apiKey: "some-key",
|
apiKey: "some-key",
|
||||||
Do: func(_req *http.Request) (*http.Response, error) {
|
Do: func(_req *http.Request) (*http.Response, error) {
|
||||||
@ -187,7 +187,7 @@ func TestResponseParsing(t *testing.T) {
|
|||||||
"system_fingerprint": "fp_7fcd609668"
|
"system_fingerprint": "fp_7fcd609668"
|
||||||
}`
|
}`
|
||||||
|
|
||||||
imageParsed, err := parseOpenAiResponse([]byte(response))
|
imageParsed, err := parseAgentResponse([]byte(response))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
t.FailNow()
|
t.FailNow()
|
Reference in New Issue
Block a user