refactor(agents): working e2e now
I guess some repeated code doesnt hurt anyone, if it keeps things simpler. Trying to be fancy with the interfaces didn't work so well.
This commit is contained in:
@ -206,6 +206,10 @@ func CreateAgentClient(prompt string) (AgentClient, error) {
|
|||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
return client.Do(req)
|
return client.Do(req)
|
||||||
},
|
},
|
||||||
|
|
||||||
|
ToolHandler: ToolsHandlers{
|
||||||
|
handlers: &map[string]ToolHandler{},
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -267,23 +271,18 @@ func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, err
|
|||||||
func (client AgentClient) Process(info ToolHandlerInfo, request AgentRequestBody) error {
|
func (client AgentClient) Process(info ToolHandlerInfo, request AgentRequestBody) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
for err == nil {
|
for {
|
||||||
log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1])
|
err = client.ToolHandler.Handle(info, &request)
|
||||||
|
if err != nil {
|
||||||
response, requestError := client.Request(&request)
|
break
|
||||||
if requestError != nil {
|
|
||||||
return requestError
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println(response)
|
_, err = client.Request(&request)
|
||||||
|
if err != nil {
|
||||||
a, innerErr := client.ToolHandler.Handle(info, &request)
|
break
|
||||||
|
}
|
||||||
err = innerErr
|
|
||||||
|
|
||||||
log.Println(a)
|
|
||||||
log.Println("--------------------------")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
log.Println(err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
@ -21,49 +21,45 @@ type ToolsHandlers struct {
|
|||||||
handlers *map[string]ToolHandler
|
handlers *map[string]ToolHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) {
|
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) error {
|
||||||
agentMessage := request.Messages[len(request.Messages)-1]
|
agentMessage := request.Messages[len(request.Messages)-1]
|
||||||
|
|
||||||
toolCall, ok := agentMessage.(AgentAssistantToolCall)
|
toolCall, ok := agentMessage.(AgentAssistantToolCall)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", errors.New("Latest message was not a tool call.")
|
return errors.New("Latest message was not a tool call.")
|
||||||
}
|
}
|
||||||
|
|
||||||
fnName := toolCall.ToolCalls[0].Function.Name
|
fnName := toolCall.ToolCalls[0].Function.Name
|
||||||
arguments := toolCall.ToolCalls[0].Function.Arguments
|
arguments := toolCall.ToolCalls[0].Function.Arguments
|
||||||
|
|
||||||
|
log.Println(handler.handlers)
|
||||||
|
|
||||||
fnHandler, exists := (*handler.handlers)[fnName]
|
fnHandler, exists := (*handler.handlers)[fnName]
|
||||||
if !exists {
|
if !exists {
|
||||||
return "", errors.New("Could not find tool with this name.")
|
return errors.New("Could not find tool with this name.")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Calling: %s\n", fnName)
|
log.Printf("Calling: %s\n", fnName)
|
||||||
res, err := fnHandler.Fn(info, arguments, toolCall.ToolCalls[0])
|
res, err := fnHandler.Fn(info, arguments, toolCall.ToolCalls[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return err
|
||||||
}
|
}
|
||||||
|
log.Println(res)
|
||||||
|
|
||||||
request.AddText(AgentTextMessage{
|
request.AddText(AgentTextMessage{
|
||||||
Role: "tool",
|
Role: "tool",
|
||||||
Name: "createLocation",
|
Name: fnName,
|
||||||
Content: res,
|
Content: res,
|
||||||
ToolCallId: toolCall.ToolCalls[0].Id,
|
ToolCallId: toolCall.ToolCalls[0].Id,
|
||||||
})
|
})
|
||||||
|
|
||||||
return res, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler ToolsHandlers) AddTool(name string, getArgs func() any, fn func(info ToolHandlerInfo, args any, call ToolCall) (any, error)) {
|
func (handler ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) {
|
||||||
(*handler.handlers)["createLocation"] = ToolHandler{
|
(*handler.handlers)[name] = ToolHandler{
|
||||||
Fn: func(info ToolHandlerInfo, args string, call ToolCall) (string, error) {
|
Fn: func(info ToolHandlerInfo, args string, call ToolCall) (string, error) {
|
||||||
argsStruct := getArgs()
|
res, err := fn(info, args, call)
|
||||||
|
|
||||||
err := json.Unmarshal([]byte(args), &argsStruct)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := fn(info, argsStruct, call)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -3,8 +3,6 @@ package agents
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"log"
|
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
"screenmark/screenmark/agents/client"
|
"screenmark/screenmark/agents/client"
|
||||||
"screenmark/screenmark/models"
|
"screenmark/screenmark/models"
|
||||||
@ -202,22 +200,17 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
|
|||||||
}
|
}
|
||||||
|
|
||||||
agentClient.ToolHandler.AddTool("listLocations",
|
agentClient.ToolHandler.AddTool("listLocations",
|
||||||
func() any {
|
func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
return ListLocationArguments{}
|
|
||||||
},
|
|
||||||
func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) {
|
|
||||||
return agent.locationModel.List(context.Background(), info.UserId)
|
return agent.locationModel.List(context.Background(), info.UserId)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
agentClient.ToolHandler.AddTool("createLocation",
|
agentClient.ToolHandler.AddTool("createLocation",
|
||||||
func() any {
|
func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
|
||||||
return CreateLocationArguments{}
|
args := CreateLocationArguments{}
|
||||||
},
|
err := json.Unmarshal([]byte(_args), &args)
|
||||||
func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) {
|
if err != nil {
|
||||||
args, ok := _args.(CreateLocationArguments)
|
return model.Locations{}, err
|
||||||
if !ok {
|
|
||||||
return _args, errors.New("Type error, arguments are not of the correct struct type")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
@ -238,13 +231,11 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
|
|||||||
)
|
)
|
||||||
|
|
||||||
agentClient.ToolHandler.AddTool("createEvent",
|
agentClient.ToolHandler.AddTool("createEvent",
|
||||||
func() any {
|
func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
|
||||||
return CreateEventArguments{}
|
args := CreateEventArguments{}
|
||||||
},
|
err := json.Unmarshal([]byte(_args), &args)
|
||||||
func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) {
|
if err != nil {
|
||||||
args, ok := _args.(CreateEventArguments)
|
return model.Locations{}, err
|
||||||
if !ok {
|
|
||||||
return _args, errors.New("Type error, arguments are not of the correct struct type")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -106,8 +106,6 @@ func main() {
|
|||||||
|
|
||||||
orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
// TODO: this can very much be parallel
|
// TODO: this can very much be parallel
|
||||||
|
|
||||||
log.Println("Calling locationAgent!")
|
log.Println("Calling locationAgent!")
|
||||||
|
Reference in New Issue
Block a user