Orchestrator + Tooling rework #4
@ -206,6 +206,10 @@ func CreateAgentClient(prompt string) (AgentClient, error) {
|
||||
client := &http.Client{}
|
||||
return client.Do(req)
|
||||
},
|
||||
|
||||
ToolHandler: ToolsHandlers{
|
||||
handlers: &map[string]ToolHandler{},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -267,23 +271,18 @@ func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, err
|
||||
func (client AgentClient) Process(info ToolHandlerInfo, request AgentRequestBody) error {
|
||||
var err error
|
||||
|
||||
for err == nil {
|
||||
log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1])
|
||||
|
||||
response, requestError := client.Request(&request)
|
||||
if requestError != nil {
|
||||
return requestError
|
||||
for {
|
||||
err = client.ToolHandler.Handle(info, &request)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
log.Println(response)
|
||||
|
||||
a, innerErr := client.ToolHandler.Handle(info, &request)
|
||||
|
||||
err = innerErr
|
||||
|
||||
log.Println(a)
|
||||
log.Println("--------------------------")
|
||||
_, err = client.Request(&request)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
log.Println(err)
|
||||
return err
|
||||
}
|
||||
|
@ -21,49 +21,45 @@ type ToolsHandlers struct {
|
||||
handlers *map[string]ToolHandler
|
||||
}
|
||||
|
||||
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) {
|
||||
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) error {
|
||||
agentMessage := request.Messages[len(request.Messages)-1]
|
||||
|
||||
toolCall, ok := agentMessage.(AgentAssistantToolCall)
|
||||
if !ok {
|
||||
return "", errors.New("Latest message was not a tool call.")
|
||||
return errors.New("Latest message was not a tool call.")
|
||||
}
|
||||
|
||||
fnName := toolCall.ToolCalls[0].Function.Name
|
||||
arguments := toolCall.ToolCalls[0].Function.Arguments
|
||||
|
||||
log.Println(handler.handlers)
|
||||
|
||||
fnHandler, exists := (*handler.handlers)[fnName]
|
||||
if !exists {
|
||||
return "", errors.New("Could not find tool with this name.")
|
||||
return errors.New("Could not find tool with this name.")
|
||||
}
|
||||
|
||||
log.Printf("Calling: %s\n", fnName)
|
||||
res, err := fnHandler.Fn(info, arguments, toolCall.ToolCalls[0])
|
||||
if err != nil {
|
||||
return "", err
|
||||
return err
|
||||
}
|
||||
log.Println(res)
|
||||
|
||||
request.AddText(AgentTextMessage{
|
||||
Role: "tool",
|
||||
Name: "createLocation",
|
||||
Name: fnName,
|
||||
Content: res,
|
||||
ToolCallId: toolCall.ToolCalls[0].Id,
|
||||
})
|
||||
|
||||
return res, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (handler ToolsHandlers) AddTool(name string, getArgs func() any, fn func(info ToolHandlerInfo, args any, call ToolCall) (any, error)) {
|
||||
(*handler.handlers)["createLocation"] = ToolHandler{
|
||||
func (handler ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) {
|
||||
(*handler.handlers)[name] = ToolHandler{
|
||||
Fn: func(info ToolHandlerInfo, args string, call ToolCall) (string, error) {
|
||||
argsStruct := getArgs()
|
||||
|
||||
err := json.Unmarshal([]byte(args), &argsStruct)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
res, err := fn(info, argsStruct, call)
|
||||
res, err := fn(info, args, call)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -3,8 +3,6 @@ package agents
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||
"screenmark/screenmark/agents/client"
|
||||
"screenmark/screenmark/models"
|
||||
@ -202,22 +200,17 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
|
||||
}
|
||||
|
||||
agentClient.ToolHandler.AddTool("listLocations",
|
||||
func() any {
|
||||
return ListLocationArguments{}
|
||||
},
|
||||
func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) {
|
||||
func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
return agent.locationModel.List(context.Background(), info.UserId)
|
||||
},
|
||||
)
|
||||
|
||||
agentClient.ToolHandler.AddTool("createLocation",
|
||||
func() any {
|
||||
return CreateLocationArguments{}
|
||||
},
|
||||
func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) {
|
||||
args, ok := _args.(CreateLocationArguments)
|
||||
if !ok {
|
||||
return _args, errors.New("Type error, arguments are not of the correct struct type")
|
||||
func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
|
||||
args := CreateLocationArguments{}
|
||||
err := json.Unmarshal([]byte(_args), &args)
|
||||
if err != nil {
|
||||
return model.Locations{}, err
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@ -238,13 +231,11 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models
|
||||
)
|
||||
|
||||
agentClient.ToolHandler.AddTool("createEvent",
|
||||
func() any {
|
||||
return CreateEventArguments{}
|
||||
},
|
||||
func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) {
|
||||
args, ok := _args.(CreateEventArguments)
|
||||
if !ok {
|
||||
return _args, errors.New("Type error, arguments are not of the correct struct type")
|
||||
func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
|
||||
args := CreateEventArguments{}
|
||||
err := json.Unmarshal([]byte(_args), &args)
|
||||
if err != nil {
|
||||
return model.Locations{}, err
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
@ -106,8 +106,6 @@ func main() {
|
||||
|
||||
orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||
|
||||
return
|
||||
|
||||
// TODO: this can very much be parallel
|
||||
|
||||
log.Println("Calling locationAgent!")
|
||||
|
Reference in New Issue
Block a user