diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index 59ea381..767cbed 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -206,6 +206,10 @@ func CreateAgentClient(prompt string) (AgentClient, error) { client := &http.Client{} return client.Do(req) }, + + ToolHandler: ToolsHandlers{ + handlers: &map[string]ToolHandler{}, + }, }, nil } @@ -267,23 +271,18 @@ func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, err func (client AgentClient) Process(info ToolHandlerInfo, request AgentRequestBody) error { var err error - for err == nil { - log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1]) - - response, requestError := client.Request(&request) - if requestError != nil { - return requestError + for { + err = client.ToolHandler.Handle(info, &request) + if err != nil { + break } - log.Println(response) - - a, innerErr := client.ToolHandler.Handle(info, &request) - - err = innerErr - - log.Println(a) - log.Println("--------------------------") + _, err = client.Request(&request) + if err != nil { + break + } } - return nil + log.Println(err) + return err } diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index b9c8f42..5f67a0c 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -21,49 +21,45 @@ type ToolsHandlers struct { handlers *map[string]ToolHandler } -func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) { +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) error { agentMessage := request.Messages[len(request.Messages)-1] toolCall, ok := agentMessage.(AgentAssistantToolCall) if !ok { - return "", errors.New("Latest message was not a tool call.") + return errors.New("Latest message was not a tool call.") } fnName := toolCall.ToolCalls[0].Function.Name arguments := toolCall.ToolCalls[0].Function.Arguments + log.Println(handler.handlers) + fnHandler, exists := (*handler.handlers)[fnName] if !exists { - return "", errors.New("Could not find tool with this name.") + return errors.New("Could not find tool with this name.") } log.Printf("Calling: %s\n", fnName) res, err := fnHandler.Fn(info, arguments, toolCall.ToolCalls[0]) if err != nil { - return "", err + return err } + log.Println(res) request.AddText(AgentTextMessage{ Role: "tool", - Name: "createLocation", + Name: fnName, Content: res, ToolCallId: toolCall.ToolCalls[0].Id, }) - return res, nil + return nil } -func (handler ToolsHandlers) AddTool(name string, getArgs func() any, fn func(info ToolHandlerInfo, args any, call ToolCall) (any, error)) { - (*handler.handlers)["createLocation"] = ToolHandler{ +func (handler ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) { + (*handler.handlers)[name] = ToolHandler{ Fn: func(info ToolHandlerInfo, args string, call ToolCall) (string, error) { - argsStruct := getArgs() - - err := json.Unmarshal([]byte(args), &argsStruct) - if err != nil { - return "", err - } - - res, err := fn(info, argsStruct, call) + res, err := fn(info, args, call) if err != nil { return "", err } diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index eb2a41b..c489cdb 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -3,8 +3,6 @@ package agents import ( "context" "encoding/json" - "errors" - "log" "screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/agents/client" "screenmark/screenmark/models" @@ -202,22 +200,17 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models } agentClient.ToolHandler.AddTool("listLocations", - func() any { - return ListLocationArguments{} - }, - func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) { + func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { return agent.locationModel.List(context.Background(), info.UserId) }, ) agentClient.ToolHandler.AddTool("createLocation", - func() any { - return CreateLocationArguments{} - }, - func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) { - args, ok := _args.(CreateLocationArguments) - if !ok { - return _args, errors.New("Type error, arguments are not of the correct struct type") + func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) { + args := CreateLocationArguments{} + err := json.Unmarshal([]byte(_args), &args) + if err != nil { + return model.Locations{}, err } ctx := context.Background() @@ -238,13 +231,11 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models ) agentClient.ToolHandler.AddTool("createEvent", - func() any { - return CreateEventArguments{} - }, - func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) { - args, ok := _args.(CreateEventArguments) - if !ok { - return _args, errors.New("Type error, arguments are not of the correct struct type") + func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) { + args := CreateEventArguments{} + err := json.Unmarshal([]byte(_args), &args) + if err != nil { + return model.Locations{}, err } ctx := context.Background() diff --git a/backend/main.go b/backend/main.go index dbc80df..8e4230a 100644 --- a/backend/main.go +++ b/backend/main.go @@ -106,8 +106,6 @@ func main() { orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) - return - // TODO: this can very much be parallel log.Println("Calling locationAgent!")