fix: saving image schema items

This commit is contained in:
2025-10-05 13:44:50 +01:00
parent 980b42aa44
commit 015a7cb5cd
6 changed files with 48 additions and 14 deletions

View File

@ -262,7 +262,7 @@ func (client *AgentClient) RunAgent(userId uuid.UUID, imageId uuid.UUID, imageNa
request.Chat.AddImage(imageName, imageData, client.Options.Query) request.Chat.AddImage(imageName, imageData, client.Options.Query)
toolHandlerInfo := ToolHandlerInfo{ toolHandlerInfo := ToolHandlerInfo{
ImageId: imageId, ImageID: imageId,
ImageName: imageName, ImageName: imageName,
UserId: userId, UserId: userId,
Image: &imageData, Image: &imageData,

View File

@ -9,7 +9,7 @@ import (
type ToolHandlerInfo struct { type ToolHandlerInfo struct {
UserId uuid.UUID UserId uuid.UUID
ImageId uuid.UUID ImageID uuid.UUID
ImageName string ImageName string
// Pointer because we don't want to copy this around too much. // Pointer because we don't want to copy this around too much.

View File

@ -40,7 +40,7 @@ func (suite *ToolTestSuite) TestSingleToolCall() {
response := suite.handler.Handle( response := suite.handler.Handle(
ToolHandlerInfo{ ToolHandlerInfo{
UserId: uuid.Nil, UserId: uuid.Nil,
ImageId: uuid.Nil, ImageID: uuid.Nil,
}, },
ToolCall{ ToolCall{
Index: 0, Index: 0,
@ -91,7 +91,7 @@ func (suite *ToolTestSuite) TestMultipleToolCalls() {
err := suite.client.Process( err := suite.client.Process(
ToolHandlerInfo{ ToolHandlerInfo{
UserId: uuid.Nil, UserId: uuid.Nil,
ImageId: uuid.Nil, ImageID: uuid.Nil,
}, },
&AgentRequestBody{ &AgentRequestBody{
Chat: &chat, Chat: &chat,
@ -154,7 +154,7 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
err := suite.client.Process( err := suite.client.Process(
ToolHandlerInfo{ ToolHandlerInfo{
UserId: uuid.Nil, UserId: uuid.Nil,
ImageId: uuid.Nil, ImageID: uuid.Nil,
}, },
&AgentRequestBody{ &AgentRequestBody{
Chat: &chat, Chat: &chat,

View File

@ -176,7 +176,7 @@ type addToListArguments struct {
Schema []models.IDValue Schema []models.IDValue
} }
func NewListAgent(log *log.Logger, listModel models.StackModel, limitsMethods limits.LimitsManagerMethods) client.AgentClient { func NewListAgent(log *log.Logger, stackModel models.StackModel, limitsMethods limits.LimitsManagerMethods) client.AgentClient {
agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{ agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
SystemPrompt: listPrompt, SystemPrompt: listPrompt,
JsonTools: listTools, JsonTools: listTools,
@ -206,7 +206,7 @@ func NewListAgent(log *log.Logger, listModel models.StackModel, limitsMethods li
} }
ctx := context.Background() ctx := context.Background()
savedList, err := listModel.Save(ctx, info.UserId, args.Name, args.Desription, model.Progress_Complete) savedList, err := stackModel.Save(ctx, info.UserId, args.Name, args.Desription, model.Progress_Complete)
if err != nil { if err != nil {
log.Error("saving list", "err", err) log.Error("saving list", "err", err)
return "", err return "", err
@ -216,7 +216,7 @@ func NewListAgent(log *log.Logger, listModel models.StackModel, limitsMethods li
args.Schema[i].StackID = savedList.ID args.Schema[i].StackID = savedList.ID
} }
err = listModel.SaveItems(ctx, args.Schema) err = stackModel.SaveItems(ctx, args.Schema)
if err != nil { if err != nil {
log.Error("saving items", "err", err) log.Error("saving items", "err", err)
return "", err return "", err
@ -226,7 +226,7 @@ func NewListAgent(log *log.Logger, listModel models.StackModel, limitsMethods li
}) })
agentClient.ToolHandler.AddTool("listLists", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("listLists", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
return listModel.List(context.Background(), info.UserId) return stackModel.List(context.Background(), info.UserId)
}) })
agentClient.ToolHandler.AddTool("addToList", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("addToList", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
@ -243,7 +243,12 @@ func NewListAgent(log *log.Logger, listModel models.StackModel, limitsMethods li
return "", err return "", err
} }
if err := listModel.SaveImage(ctx, info.ImageId, listUUID); err != nil { imageStack, err := stackModel.SaveImage(ctx, info.ImageID, listUUID)
if err != nil {
return "", err
}
if err := stackModel.SaveSchemaItems(ctx, imageStack.ID, args.Schema); err != nil {
return "", err return "", err
} }

View File

@ -3,6 +3,7 @@ package models
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/.gen/haystack/haystack/model"
. "screenmark/screenmark/.gen/haystack/haystack/table" . "screenmark/screenmark/.gen/haystack/haystack/table"
@ -103,12 +104,40 @@ func (m StackModel) SaveItems(ctx context.Context, items []model.SchemaItems) er
return err return err
} }
func (m StackModel) SaveImage(ctx context.Context, imageID uuid.UUID, stackID uuid.UUID) error { func (m StackModel) SaveImage(ctx context.Context, imageID uuid.UUID, stackID uuid.UUID) (model.ImageStacks, error) {
saveImageStmt := ImageStacks. saveImageStmt := ImageStacks.
INSERT(ImageStacks.ImageID, ImageStacks.StackID). INSERT(ImageStacks.ImageID, ImageStacks.StackID).
VALUES(imageID, stackID) VALUES(imageID, stackID).
RETURNING(ImageStacks.AllColumns)
_, err := saveImageStmt.ExecContext(ctx, m.dbPool) imageStack := model.ImageStacks{}
err := saveImageStmt.QueryContext(ctx, m.dbPool, &imageStack)
return imageStack, err
}
func (m StackModel) SaveSchemaItems(ctx context.Context, imageID uuid.UUID, items []IDValue) error {
if len(items) == 0 {
return fmt.Errorf("items cannot be empty")
}
saveSchemaItemStmt := ImageSchemaItems.
INSERT(
ImageSchemaItems.ImageID,
ImageSchemaItems.SchemaItemID,
ImageSchemaItems.Value,
)
for _, item := range items {
saveSchemaItemStmt = saveSchemaItemStmt.VALUES(
imageID,
item.ID,
item.Value,
)
}
_, err := saveSchemaItemStmt.ExecContext(ctx, m.dbPool)
return err return err
} }

View File

@ -222,7 +222,7 @@ const stackValidator = strictObject({
Name: string(), Name: string(),
Images: array(stackImage), Images: pipe(nullable(array(stackImage)), transform(l => l ?? [])),
SchemaItems: array(stackSchemaItem), SchemaItems: array(stackSchemaItem),
}); });