9 Commits

32 changed files with 904 additions and 153 deletions

View File

@ -0,0 +1,22 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
import (
"github.com/google/uuid"
"time"
)
type ProcessingLists struct {
ID uuid.UUID `sql:"primary_key"`
UserID uuid.UUID
Title string
Fields string
Status Progress
CreatedAt *time.Time
}

View File

@ -20,10 +20,11 @@ type imageTable struct {
ID postgres.ColumnString
ImageName postgres.ColumnString
Description postgres.ColumnString
Image postgres.ColumnString
Image postgres.ColumnBytea
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type ImageTable struct {
@ -64,9 +65,10 @@ func newImageTableImpl(schemaName, tableName, alias string) imageTable {
IDColumn = postgres.StringColumn("id")
ImageNameColumn = postgres.StringColumn("image_name")
DescriptionColumn = postgres.StringColumn("description")
ImageColumn = postgres.StringColumn("image")
ImageColumn = postgres.ByteaColumn("image")
allColumns = postgres.ColumnList{IDColumn, ImageNameColumn, DescriptionColumn, ImageColumn}
mutableColumns = postgres.ColumnList{ImageNameColumn, DescriptionColumn, ImageColumn}
defaultColumns = postgres.ColumnList{IDColumn}
)
return imageTable{
@ -80,5 +82,6 @@ func newImageTableImpl(schemaName, tableName, alias string) imageTable {
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}

View File

@ -23,6 +23,7 @@ type imageListsTable struct {
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type ImageListsTable struct {
@ -65,6 +66,7 @@ func newImageListsTableImpl(schemaName, tableName, alias string) imageListsTable
ListIDColumn = postgres.StringColumn("list_id")
allColumns = postgres.ColumnList{IDColumn, ImageIDColumn, ListIDColumn}
mutableColumns = postgres.ColumnList{ImageIDColumn, ListIDColumn}
defaultColumns = postgres.ColumnList{IDColumn}
)
return imageListsTable{
@ -77,5 +79,6 @@ func newImageListsTableImpl(schemaName, tableName, alias string) imageListsTable
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}

View File

@ -24,6 +24,7 @@ type imageSchemaItemsTable struct {
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type ImageSchemaItemsTable struct {
@ -67,6 +68,7 @@ func newImageSchemaItemsTableImpl(schemaName, tableName, alias string) imageSche
ImageIDColumn = postgres.StringColumn("image_id")
allColumns = postgres.ColumnList{IDColumn, ValueColumn, SchemaItemIDColumn, ImageIDColumn}
mutableColumns = postgres.ColumnList{ValueColumn, SchemaItemIDColumn, ImageIDColumn}
defaultColumns = postgres.ColumnList{IDColumn}
)
return imageSchemaItemsTable{
@ -80,5 +82,6 @@ func newImageSchemaItemsTableImpl(schemaName, tableName, alias string) imageSche
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}

View File

@ -25,6 +25,7 @@ type listsTable struct {
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type ListsTable struct {
@ -69,6 +70,7 @@ func newListsTableImpl(schemaName, tableName, alias string) listsTable {
CreatedAtColumn = postgres.TimestampzColumn("created_at")
allColumns = postgres.ColumnList{IDColumn, UserIDColumn, NameColumn, DescriptionColumn, CreatedAtColumn}
mutableColumns = postgres.ColumnList{UserIDColumn, NameColumn, DescriptionColumn, CreatedAtColumn}
defaultColumns = postgres.ColumnList{IDColumn, CreatedAtColumn}
)
return listsTable{
@ -83,5 +85,6 @@ func newListsTableImpl(schemaName, tableName, alias string) listsTable {
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}

View File

@ -23,6 +23,7 @@ type logsTable struct {
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type LogsTable struct {
@ -65,6 +66,7 @@ func newLogsTableImpl(schemaName, tableName, alias string) logsTable {
CreatedAtColumn = postgres.TimestampzColumn("created_at")
allColumns = postgres.ColumnList{LogColumn, ImageIDColumn, CreatedAtColumn}
mutableColumns = postgres.ColumnList{LogColumn, ImageIDColumn, CreatedAtColumn}
defaultColumns = postgres.ColumnList{CreatedAtColumn}
)
return logsTable{
@ -77,5 +79,6 @@ func newLogsTableImpl(schemaName, tableName, alias string) logsTable {
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}

View File

@ -0,0 +1,93 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
"github.com/go-jet/jet/v2/postgres"
)
var ProcessingLists = newProcessingListsTable("haystack", "processing_lists", "")
type processingListsTable struct {
postgres.Table
// Columns
ID postgres.ColumnString
UserID postgres.ColumnString
Title postgres.ColumnString
Fields postgres.ColumnString
Status postgres.ColumnString
CreatedAt postgres.ColumnTimestampz
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type ProcessingListsTable struct {
processingListsTable
EXCLUDED processingListsTable
}
// AS creates new ProcessingListsTable with assigned alias
func (a ProcessingListsTable) AS(alias string) *ProcessingListsTable {
return newProcessingListsTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new ProcessingListsTable with assigned schema name
func (a ProcessingListsTable) FromSchema(schemaName string) *ProcessingListsTable {
return newProcessingListsTable(schemaName, a.TableName(), a.Alias())
}
// WithPrefix creates new ProcessingListsTable with assigned table prefix
func (a ProcessingListsTable) WithPrefix(prefix string) *ProcessingListsTable {
return newProcessingListsTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
}
// WithSuffix creates new ProcessingListsTable with assigned table suffix
func (a ProcessingListsTable) WithSuffix(suffix string) *ProcessingListsTable {
return newProcessingListsTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
}
func newProcessingListsTable(schemaName, tableName, alias string) *ProcessingListsTable {
return &ProcessingListsTable{
processingListsTable: newProcessingListsTableImpl(schemaName, tableName, alias),
EXCLUDED: newProcessingListsTableImpl("", "excluded", ""),
}
}
func newProcessingListsTableImpl(schemaName, tableName, alias string) processingListsTable {
var (
IDColumn = postgres.StringColumn("id")
UserIDColumn = postgres.StringColumn("user_id")
TitleColumn = postgres.StringColumn("title")
FieldsColumn = postgres.StringColumn("fields")
StatusColumn = postgres.StringColumn("status")
CreatedAtColumn = postgres.TimestampzColumn("created_at")
allColumns = postgres.ColumnList{IDColumn, UserIDColumn, TitleColumn, FieldsColumn, StatusColumn, CreatedAtColumn}
mutableColumns = postgres.ColumnList{UserIDColumn, TitleColumn, FieldsColumn, StatusColumn, CreatedAtColumn}
defaultColumns = postgres.ColumnList{IDColumn, StatusColumn, CreatedAtColumn}
)
return processingListsTable{
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
ID: IDColumn,
UserID: UserIDColumn,
Title: TitleColumn,
Fields: FieldsColumn,
Status: StatusColumn,
CreatedAt: CreatedAtColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}

View File

@ -25,6 +25,7 @@ type schemaItemsTable struct {
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type SchemaItemsTable struct {
@ -69,6 +70,7 @@ func newSchemaItemsTableImpl(schemaName, tableName, alias string) schemaItemsTab
SchemaIDColumn = postgres.StringColumn("schema_id")
allColumns = postgres.ColumnList{IDColumn, ItemColumn, ValueColumn, DescriptionColumn, SchemaIDColumn}
mutableColumns = postgres.ColumnList{ItemColumn, ValueColumn, DescriptionColumn, SchemaIDColumn}
defaultColumns = postgres.ColumnList{IDColumn}
)
return schemaItemsTable{
@ -83,5 +85,6 @@ func newSchemaItemsTableImpl(schemaName, tableName, alias string) schemaItemsTab
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}

View File

@ -22,6 +22,7 @@ type schemasTable struct {
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type SchemasTable struct {
@ -63,6 +64,7 @@ func newSchemasTableImpl(schemaName, tableName, alias string) schemasTable {
ListIDColumn = postgres.StringColumn("list_id")
allColumns = postgres.ColumnList{IDColumn, ListIDColumn}
mutableColumns = postgres.ColumnList{ListIDColumn}
defaultColumns = postgres.ColumnList{IDColumn}
)
return schemasTable{
@ -74,5 +76,6 @@ func newSchemasTableImpl(schemaName, tableName, alias string) schemasTable {
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}

View File

@ -15,6 +15,7 @@ func UseSchema(schema string) {
ImageSchemaItems = ImageSchemaItems.FromSchema(schema)
Lists = Lists.FromSchema(schema)
Logs = Logs.FromSchema(schema)
ProcessingLists = ProcessingLists.FromSchema(schema)
SchemaItems = SchemaItems.FromSchema(schema)
Schemas = Schemas.FromSchema(schema)
UserImages = UserImages.FromSchema(schema)

View File

@ -24,6 +24,7 @@ type userImagesTable struct {
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type UserImagesTable struct {
@ -67,6 +68,7 @@ func newUserImagesTableImpl(schemaName, tableName, alias string) userImagesTable
CreatedAtColumn = postgres.TimestampzColumn("created_at")
allColumns = postgres.ColumnList{IDColumn, ImageIDColumn, UserIDColumn, CreatedAtColumn}
mutableColumns = postgres.ColumnList{ImageIDColumn, UserIDColumn, CreatedAtColumn}
defaultColumns = postgres.ColumnList{IDColumn, CreatedAtColumn}
)
return userImagesTable{
@ -80,5 +82,6 @@ func newUserImagesTableImpl(schemaName, tableName, alias string) userImagesTable
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}

View File

@ -24,6 +24,7 @@ type userImagesToProcessTable struct {
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type UserImagesToProcessTable struct {
@ -67,6 +68,7 @@ func newUserImagesToProcessTableImpl(schemaName, tableName, alias string) userIm
UserIDColumn = postgres.StringColumn("user_id")
allColumns = postgres.ColumnList{IDColumn, StatusColumn, ImageIDColumn, UserIDColumn}
mutableColumns = postgres.ColumnList{StatusColumn, ImageIDColumn, UserIDColumn}
defaultColumns = postgres.ColumnList{IDColumn, StatusColumn}
)
return userImagesToProcessTable{
@ -80,5 +82,6 @@ func newUserImagesToProcessTableImpl(schemaName, tableName, alias string) userIm
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}

View File

@ -22,6 +22,7 @@ type usersTable struct {
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
DefaultColumns postgres.ColumnList
}
type UsersTable struct {
@ -63,6 +64,7 @@ func newUsersTableImpl(schemaName, tableName, alias string) usersTable {
EmailColumn = postgres.StringColumn("email")
allColumns = postgres.ColumnList{IDColumn, EmailColumn}
mutableColumns = postgres.ColumnList{EmailColumn}
defaultColumns = postgres.ColumnList{IDColumn}
)
return usersTable{
@ -74,5 +76,6 @@ func newUsersTableImpl(schemaName, tableName, alias string) usersTable {
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}

View File

@ -187,6 +187,15 @@ func (chat *Chat) AddSystem(prompt string) {
})
}
func (chat *Chat) AddUser(msg string) {
chat.Messages = append(chat.Messages, ChatUserMessage{
Role: User,
MessageContent: SingleMessage{
Content: msg,
},
})
}
func (chat *Chat) AddImage(imageName string, image []byte, query *string) error {
extension := filepath.Ext(imageName)
if len(extension) == 0 {

View File

@ -270,3 +270,38 @@ func (client *AgentClient) RunAgent(userId uuid.UUID, imageId uuid.UUID, imageNa
return client.ToolLoop(toolHandlerInfo, &request)
}
func (client *AgentClient) RunAgentAlone(userID uuid.UUID, userReq string) error {
var tools any
err := json.Unmarshal([]byte(client.Options.JsonTools), &tools)
if err != nil {
return err
}
toolChoice := "auto"
seed := 42
request := AgentRequestBody{
Tools: &tools,
ToolChoice: &toolChoice,
Model: "google/gemini-2.5-flash",
RandomSeed: &seed,
Temperature: 0.3,
EndToolCall: client.Options.EndToolCall,
ResponseFormat: ResponseFormat{
Type: "text",
},
Chat: &Chat{
Messages: make([]ChatMessage, 0),
},
}
request.Chat.AddSystem(client.Options.SystemPrompt)
request.Chat.AddUser(userReq)
toolHandlerInfo := ToolHandlerInfo{
UserId: userID,
}
return client.ToolLoop(toolHandlerInfo, &request)
}

View File

@ -0,0 +1,140 @@
package agents
import (
"context"
"encoding/json"
"fmt"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents/client"
"screenmark/screenmark/models"
"github.com/charmbracelet/log"
"github.com/google/uuid"
)
const createListAgentPrompt = `
You are an agent who's job is to produce a reasonable output for an unstructured input.
Your job is to create lists for the user, the user will give you a title and some fields they want
as part of the list. Your job is to take these fields, adjust their names so they have good names,
and add a good description for each one.
You can add fields if you think they make a lot of sense.
You can remove fields if they are not correct, but be sure before you do this.
`
const listJsonSchema = `
{
"type": "object",
"properties": {
"title": {
"type": "string",
"description": "the title of the list"
},
"description": {
"type": "string",
"description": "the description of the list"
},
"fields": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "The name of the field."
},
"description": {
"type": "string",
"description": "A description of the field."
}
},
"required": [
"name",
"description"
]
},
"description": "An array of field objects."
}
},
"required": [
"fields"
]
}
`
type createNewListArguments struct {
Title string `json:"title"`
Description string `json:"description"`
Fields []struct {
Name string `json:"name"`
Description string `json:"description"`
} `json:"fields"`
}
type CreateListAgent struct {
client client.AgentClient
listModel models.ListModel
}
func (agent *CreateListAgent) CreateList(log *log.Logger, userID uuid.UUID, userReq string) error {
request := client.AgentRequestBody{
Model: "google/gemini-2.5-flash",
Temperature: 0.3,
ResponseFormat: client.ResponseFormat{
Type: "json_object",
JsonSchema: listJsonSchema,
},
Chat: &client.Chat{
Messages: make([]client.ChatMessage, 0),
},
}
request.Chat.AddSystem(agent.client.Options.SystemPrompt)
request.Chat.AddUser(userReq)
resp, err := agent.client.Request(&request)
if err != nil {
return fmt.Errorf("request: %w", err)
}
ctx := context.Background()
structuredOutput := resp.Choices[0].Message.Content
var createListArgs createNewListArguments
err = json.Unmarshal([]byte(structuredOutput), &createListArgs)
if err != nil {
return err
}
schemaItems := make([]model.SchemaItems, 0)
for _, field := range createListArgs.Fields {
schemaItems = append(schemaItems, model.SchemaItems{
Item: field.Name,
Description: field.Description,
Value: "string", // keep it simple for now.
})
}
agent.listModel.Save(ctx, userID, createListArgs.Title, createListArgs.Description, schemaItems)
return nil
}
func NewCreateListAgent(log *log.Logger, listModel models.ListModel) CreateListAgent {
client := client.CreateAgentClient(client.CreateAgentClientOptions{
SystemPrompt: createListAgentPrompt,
Log: log,
})
agent := CreateListAgent{
client,
listModel,
}
return agent
}

View File

@ -15,6 +15,9 @@ You are an AI agent who's job is to describe the image you see.
You should also add any text you see in the image, if no text exists, just add a description.
Be consise and don't add too much extra information or formatting characters, simple text.
You must write this text in Markdown. You can add extra information for the user.
You must organise this text nicely, not be all over the place.
`
type DescriptionAgent struct {

View File

@ -33,6 +33,8 @@ and extract some meaning about what the image is.
You must call "listLists" to see which available lists are already available.
Use "createList" only once, don't create multiple lists for one image.
You can add an image to multiple lists, this is also true if you already created a list. But only add to a list if it makes sense to do so.
**Tools:**
* think: Internal reasoning/planning step.
* listLists: Get existing lists
@ -184,10 +186,6 @@ func NewListAgent(log *log.Logger, listModel models.ListModel) client.AgentClien
return "Thought", nil
})
agentClient.ToolHandler.AddTool("listLists", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
return listModel.List(context.Background(), info.UserId)
})
agentClient.ToolHandler.AddTool("createList", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
args := createListArguments{}
err := json.Unmarshal([]byte(_args), &args)
@ -208,6 +206,10 @@ func NewListAgent(log *log.Logger, listModel models.ListModel) client.AgentClien
return savedList, nil
})
agentClient.ToolHandler.AddTool("listLists", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
return listModel.List(context.Background(), info.UserId)
})
agentClient.ToolHandler.AddTool("addToList", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
args := addToListArguments{}
err := json.Unmarshal([]byte(_args), &args)

View File

@ -8,8 +8,10 @@ import (
"net/http"
"os"
"screenmark/screenmark/agents"
"screenmark/screenmark/middleware"
"screenmark/screenmark/models"
"strconv"
"sync"
"time"
"github.com/charmbracelet/log"
@ -64,14 +66,24 @@ func ListenNewImageEvents(db *sql.DB, notifier *Notifier[Notification]) {
}
descriptionAgent := agents.NewDescriptionAgent(createLogger("Description 📝", splitWriter), imageModel)
err = descriptionAgent.Describe(createLogger("Description 📓", splitWriter), image.Image.ID, image.Image.ImageName, image.Image.Image)
if err != nil {
log.Error(err)
}
listAgent := agents.NewListAgent(createLogger("Lists 🖋️", splitWriter), listModel)
listAgent.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
descriptionAgent.Describe(createLogger("Description 📓", splitWriter), image.Image.ID, image.Image.ImageName, image.Image.Image)
}()
go func() {
defer wg.Done()
listAgent.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
}()
wg.Wait()
_, err = imageModel.FinishProcessing(ctx, image.ID)
if err != nil {
@ -128,6 +140,57 @@ func ListenProcessingImageStatus(db *sql.DB, images models.ImageModel, notifier
}
}
func ListenNewStackEvents(db *sql.DB) {
listener := pq.NewListener(os.Getenv("DB_CONNECTION"), time.Second, time.Second, func(event pq.ListenerEventType, err error) {
if err != nil {
panic(err)
}
})
defer listener.Close()
stackModel := models.NewListModel(db)
newStacksLogger := createLogger("New Stacks 🤖", os.Stdout)
newStacksLogger.SetLevel(log.DebugLevel)
err := listener.Listen("new_stack")
if err != nil {
panic(err)
}
for parameters := range listener.Notify {
stackID := uuid.MustParse(parameters.Extra)
newStacksLogger.Debug("Starting processing stack", "StackID", stackID)
ctx := context.Background()
go func() {
stack, err := stackModel.GetProcessing(ctx, stackID)
if err != nil {
newStacksLogger.Error("failed to get processing", "error", err)
return
}
if err := stackModel.StartProcessing(ctx, stackID); err != nil {
newStacksLogger.Error("failed to start processing", "error", err)
return
}
listAgent := agents.NewCreateListAgent(newStacksLogger, stackModel)
userListRequest := fmt.Sprintf("title=%s,fields=%s", stack.Title, stack.Fields)
err = listAgent.CreateList(newStacksLogger, stack.UserID, userListRequest)
if err != nil {
newStacksLogger.Error("running agent", "err", err)
return
}
newStacksLogger.Debug("Finished processing stack", "StackID", stackID)
}()
}
}
/*
* TODO: We have channels open every a user sends an image.
* We never close these channels.
@ -140,7 +203,7 @@ func CreateEventsHandler(notifier *Notifier[Notification]) http.HandlerFunc {
userSplitters := make(map[string]*ChannelSplitter[Notification])
return func(w http.ResponseWriter, r *http.Request) {
_userId := r.Context().Value(USER_ID).(uuid.UUID)
_userId := r.Context().Value(middleware.USER_ID).(uuid.UUID)
if _userId == uuid.Nil {
w.WriteHeader(http.StatusUnauthorized)
return

View File

@ -12,6 +12,9 @@ import (
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents/client"
"screenmark/screenmark/models"
"screenmark/screenmark/stacks"
ourmiddleware "screenmark/screenmark/middleware"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
@ -41,6 +44,8 @@ func main() {
imageModel := models.NewImageModel(db)
userModel := models.NewUserModel(db)
stackHandler := stacks.CreateStackHandler(db)
mail, err := CreateMailClient()
if err != nil {
panic(err)
@ -52,14 +57,14 @@ func main() {
go ListenNewImageEvents(db, &notifier)
go ListenProcessingImageStatus(db, imageModel, &notifier)
go ListenNewStackEvents(db)
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Use(CorsMiddleware)
r.Options("/*", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
r.Use(ourmiddleware.CorsMiddleware)
r.Route("/stacks", stackHandler.CreateRoutes)
// Temporarily not in protect route because we aren't using cookies.
// Therefore they don't get automatically attached to the request.
@ -102,7 +107,7 @@ func main() {
})
r.Group(func(r chi.Router) {
r.Use(ProtectedRoute)
r.Use(ourmiddleware.ProtectedRoute)
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
@ -112,7 +117,7 @@ func main() {
})
r.Get("/image", func(w http.ResponseWriter, r *http.Request) {
userId := r.Context().Value(USER_ID).(uuid.UUID)
userId := r.Context().Value(ourmiddleware.USER_ID).(uuid.UUID)
if err != nil {
w.WriteHeader(http.StatusForbidden)
fmt.Fprintf(w, "You cannot read this")
@ -168,7 +173,7 @@ func main() {
r.Post("/image/{name}", func(w http.ResponseWriter, r *http.Request) {
imageName := r.PathValue("name")
userId := r.Context().Value(USER_ID).(uuid.UUID)
userId := r.Context().Value(ourmiddleware.USER_ID).(uuid.UUID)
if len(imageName) == 0 {
w.WriteHeader(http.StatusBadRequest)
@ -256,7 +261,7 @@ func main() {
})
r.Route("/notifications", func(r chi.Router) {
r.Use(GetUserIdFromUrl)
r.Use(ourmiddleware.GetUserIdFromUrl)
r.Get("/", CreateEventsHandler(&notifier))
})
@ -322,8 +327,8 @@ func main() {
return
}
refresh := CreateRefreshToken(uuid)
access := CreateAccessToken(uuid)
refresh := ourmiddleware.CreateRefreshToken(uuid)
access := ourmiddleware.CreateAccessToken(uuid)
codeReturn := CodeReturn{
Access: access,
@ -355,8 +360,8 @@ func main() {
return
}
refresh := CreateRefreshToken(uuid)
access := CreateAccessToken(uuid)
refresh := ourmiddleware.CreateRefreshToken(uuid)
access := ourmiddleware.CreateAccessToken(uuid)
codeReturn := CodeReturn{
Access: access,

View File

@ -1,61 +0,0 @@
package main
import (
"context"
"net/http"
)
func CorsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Access-Control-Allow-Origin", "*")
w.Header().Add("Access-Control-Allow-Credentials", "*")
w.Header().Add("Access-Control-Allow-Headers", "*")
next.ServeHTTP(w, r)
})
}
const USER_ID = "UserID"
func ProtectedRoute(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")
if len(token) < len("Bearer ") {
w.WriteHeader(http.StatusUnauthorized)
return
}
userId, err := GetUserIdFromAccess(token[len("Bearer "):])
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
newR := r.WithContext(contextWithUserId)
next.ServeHTTP(w, newR)
})
}
func GetUserIdFromUrl(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
if len(token) == 0 {
w.WriteHeader(http.StatusUnauthorized)
return
}
userId, err := GetUserIdFromAccess(token)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
newR := r.WithContext(contextWithUserId)
next.ServeHTTP(w, newR)
})
}

View File

@ -0,0 +1,11 @@
package middleware
import "net/http"
func SetJson(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
next.ServeHTTP(w, r)
})
}

View File

@ -1,4 +1,4 @@
package main
package middleware
import (
"errors"

View File

@ -0,0 +1,116 @@
package middleware
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/charmbracelet/log"
"github.com/google/uuid"
)
func CorsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Access-Control-Allow-Origin", "*")
w.Header().Add("Access-Control-Allow-Headers", "*")
// Access-Control-Allow-Methods is often needed for preflight OPTIONS requests
w.Header().Add("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
// The client makes an OPTIONS preflight request before a complex request.
// We must handle this and respond with the appropriate headers.
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
next.ServeHTTP(w, r)
})
}
const USER_ID = "UserID"
func GetUserID(ctx context.Context, logger *log.Logger, w http.ResponseWriter) (uuid.UUID, error) {
userId := ctx.Value(USER_ID)
if userId == nil {
w.WriteHeader(http.StatusUnauthorized)
logger.Warn("UserID not present in request")
return uuid.Nil, errors.New("context does not contain a user id")
}
userIdUuid, ok := userId.(uuid.UUID)
if !ok {
w.WriteHeader(http.StatusUnauthorized)
logger.Warn("UserID not of correct type")
return uuid.Nil, fmt.Errorf("context user id is not of type uuid, got: %t", userId)
}
return userIdUuid, nil
}
func ProtectedRoute(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")
if len(token) < len("Bearer ") {
w.WriteHeader(http.StatusUnauthorized)
return
}
userId, err := GetUserIdFromAccess(token[len("Bearer "):])
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
newR := r.WithContext(contextWithUserId)
next.ServeHTTP(w, newR)
})
}
func GetUserIdFromUrl(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
if len(token) == 0 {
w.WriteHeader(http.StatusUnauthorized)
return
}
userId, err := GetUserIdFromAccess(token)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
newR := r.WithContext(contextWithUserId)
next.ServeHTTP(w, newR)
})
}
func GetPathParamID(logger *log.Logger, param string, w http.ResponseWriter, r *http.Request) (uuid.UUID, error) {
pathParam := r.PathValue(param)
if len(pathParam) == 0 {
w.WriteHeader(http.StatusBadRequest)
err := fmt.Errorf("%s was not present", param)
logger.Warn(err)
return uuid.Nil, err
}
uuidParam, err := uuid.Parse(pathParam)
if len(pathParam) == 0 {
w.WriteHeader(http.StatusBadRequest)
err := fmt.Errorf("could not parse param: %w", err)
logger.Warn(err)
return uuid.Nil, err
}
return uuidParam, nil
}

View File

@ -26,6 +26,90 @@ type ListWithItems struct {
}
}
type ImageWithSchema struct {
model.ImageLists
Items []model.ImageSchemaItems
}
type IDValue struct {
ID string `json:"id"`
Value string `json:"value"`
}
// ========================================
// SELECT for lists
// ========================================
func (m ListModel) List(ctx context.Context, userId uuid.UUID) ([]ListWithItems, error) {
getListsWithItems := SELECT(
Lists.AllColumns,
Schemas.AllColumns,
SchemaItems.AllColumns,
).
FROM(
Lists.
INNER_JOIN(Schemas, Schemas.ListID.EQ(Lists.ID)).
INNER_JOIN(SchemaItems, SchemaItems.SchemaID.EQ(Schemas.ID)),
).
WHERE(Lists.UserID.EQ(UUID(userId)))
lists := []ListWithItems{}
err := getListsWithItems.QueryContext(ctx, m.dbPool, &lists)
return lists, err
}
func (m ListModel) ListItems(ctx context.Context, listID uuid.UUID) ([]ImageWithSchema, error) {
getListItems := SELECT(
ImageLists.AllColumns,
ImageSchemaItems.AllColumns,
).
FROM(
ImageLists.
INNER_JOIN(ImageSchemaItems, ImageSchemaItems.ImageID.EQ(ImageLists.ImageID)),
).
WHERE(ImageLists.ListID.EQ(UUID(listID)))
listItems := make([]ImageWithSchema, 0)
err := getListItems.QueryContext(ctx, m.dbPool, &listItems)
return listItems, err
}
// ========================================
// SELECT for specific items
// ========================================
func (m ListModel) GetProcessing(ctx context.Context, processingListID uuid.UUID) (model.ProcessingLists, error) {
getProcessingListStmt := ProcessingLists.
SELECT(ProcessingLists.AllColumns).
WHERE(ProcessingLists.ID.EQ(UUID(processingListID)))
list := model.ProcessingLists{}
err := getProcessingListStmt.QueryContext(ctx, m.dbPool, &list)
return list, err
}
// ========================================
// UPDATE
// ========================================
func (m ListModel) StartProcessing(ctx context.Context, processingListID uuid.UUID) error {
startProcessingStmt := ProcessingLists.
UPDATE(ProcessingLists.Status).
SET(model.Progress_InProgress).
WHERE(ProcessingLists.ID.EQ(UUID(processingListID)))
_, err := startProcessingStmt.ExecContext(ctx, m.dbPool)
return err
}
// ========================================
// INSERT methods
// ========================================
func (m ListModel) Save(ctx context.Context, userId uuid.UUID, name string, description string, schemaItems []model.SchemaItems) (ListWithItems, error) {
tx, err := m.dbPool.BeginTx(ctx, nil)
@ -86,30 +170,6 @@ func (m ListModel) Save(ctx context.Context, userId uuid.UUID, name string, desc
return listWithItems, err
}
func (m ListModel) List(ctx context.Context, userId uuid.UUID) ([]ListWithItems, error) {
getListsWithItems := SELECT(
Lists.AllColumns,
Schemas.AllColumns,
SchemaItems.AllColumns,
).
FROM(
Lists.
INNER_JOIN(Schemas, Schemas.ListID.EQ(Lists.ID)).
INNER_JOIN(SchemaItems, SchemaItems.SchemaID.EQ(Schemas.ID)),
).
WHERE(Lists.UserID.EQ(UUID(userId)))
lists := []ListWithItems{}
err := getListsWithItems.QueryContext(ctx, m.dbPool, &lists)
return lists, err
}
type IDValue struct {
ID string `json:"id"`
Value string `json:"value"`
}
func (m ListModel) SaveInto(ctx context.Context, listId uuid.UUID, imageId uuid.UUID, schemaValues []IDValue) error {
imageSchemaItems := make([]model.ImageSchemaItems, len(schemaValues))
@ -152,6 +212,16 @@ func (m ListModel) SaveInto(ctx context.Context, listId uuid.UUID, imageId uuid.
return err
}
func (m ListModel) SaveProcessing(ctx context.Context, userID uuid.UUID, title string, fields string) error {
insertListToProcess := ProcessingLists.
INSERT(ProcessingLists.UserID, ProcessingLists.Title, ProcessingLists.Fields).
VALUES(userID, title, fields)
_, err := insertListToProcess.ExecContext(ctx, m.dbPool)
return err
}
func NewListModel(db *sql.DB) ListModel {
return ListModel{dbPool: db}
}

View File

@ -51,7 +51,10 @@ func (m UserModel) Save(ctx context.Context, user model.Users) (model.Users, err
type UserImageWithImage struct {
model.UserImages
Image model.Image
Image struct {
model.Image
ImageLists []model.ImageLists
}
}
func (m UserModel) GetUserImages(ctx context.Context, userId uuid.UUID) ([]UserImageWithImage, error) {
@ -60,8 +63,13 @@ func (m UserModel) GetUserImages(ctx context.Context, userId uuid.UUID) ([]UserI
Image.ID,
Image.ImageName,
Image.Description,
ImageLists.AllColumns,
).
FROM(UserImages.INNER_JOIN(Image, Image.ID.EQ(UserImages.ImageID))).
FROM(
UserImages.
INNER_JOIN(Image, Image.ID.EQ(UserImages.ImageID)).
INNER_JOIN(ImageLists, ImageLists.ImageID.EQ(UserImages.ImageID)),
).
WHERE(UserImages.UserID.EQ(UUID(userId)))
userImages := []UserImageWithImage{}

View File

@ -52,6 +52,18 @@ CREATE TABLE haystack.lists (
created_at TIMESTAMP WITH TIME ZONE DEFAULT now()
);
CREATE TABLE haystack.processing_lists (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES haystack.users (id),
title TEXT NOT NULL,
fields TEXT NOT NULL,
status haystack.progress NOT NULL DEFAULT 'not-started',
created_at TIMESTAMP WITH TIME ZONE DEFAULT now()
);
CREATE TABLE haystack.image_lists (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
@ -104,6 +116,14 @@ PERFORM pg_notify('new_processing_image_status', NEW.id::text || NEW.status::tex
END
$$ LANGUAGE plpgsql;
CREATE OR REPLACE FUNCTION notify_new_stacks()
RETURNS TRIGGER AS $$
BEGIN
PERFORM pg_notify('new_stack', NEW.id::text);
RETURN NEW;
END
$$ LANGUAGE plpgsql;
/* -----| Triggers |----- */
CREATE OR REPLACE TRIGGER on_new_image AFTER INSERT
@ -117,4 +137,9 @@ ON haystack.user_images_to_process
FOR EACH ROW
EXECUTE PROCEDURE notify_new_processing_image_status();
CREATE OR REPLACE TRIGGER on_new_image AFTER INSERT
ON haystack.processing_lists
FOR EACH ROW
EXECUTE PROCEDURE notify_new_stacks();
/* -----| Test Data |----- */

153
backend/stacks/handler.go Normal file
View File

@ -0,0 +1,153 @@
package stacks
import (
"database/sql"
"encoding/json"
"io"
"net/http"
"os"
"screenmark/screenmark/middleware"
"screenmark/screenmark/models"
"github.com/charmbracelet/log"
"github.com/go-chi/chi/v5"
)
func writeJsonOrError[K any](logger *log.Logger, object K, w http.ResponseWriter) {
jsonObject, err := json.Marshal(object)
if err != nil {
logger.Warn("could not marshal json object", "err", err)
w.WriteHeader(http.StatusBadRequest)
return
}
w.Write(jsonObject)
w.WriteHeader(http.StatusOK)
}
type StackHandler struct {
logger *log.Logger
stackModel models.ListModel
}
func withValidatedPost[K any](
fn func(request K, w http.ResponseWriter, r *http.Request),
) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
request := new(K)
body, err := io.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
err = json.Unmarshal(body, request)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
fn(*request, w, r)
}
}
func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userID, err := middleware.GetUserID(ctx, h.logger, w)
if err != nil {
return
}
lists, err := h.stackModel.List(ctx, userID)
if err != nil {
h.logger.Warn("could not get stacks", "err", err)
w.WriteHeader(http.StatusBadRequest)
return
}
writeJsonOrError(h.logger, lists, w)
}
func (h *StackHandler) getStackItems(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
_, err := middleware.GetUserID(ctx, h.logger, w)
if err != nil {
return
}
listID, err := middleware.GetPathParamID(h.logger, "listID", w, r)
if err != nil {
return
}
// TODO: must check for permission here.
lists, err := h.stackModel.ListItems(ctx, listID)
if err != nil {
h.logger.Warn("could not get list items", "err", err)
w.WriteHeader(http.StatusBadRequest)
return
}
writeJsonOrError(h.logger, lists, w)
}
type EditStack struct {
Hello string `json:"hello"`
}
func (h *StackHandler) editStack(req EditStack, w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotImplemented)
}
type CreateStackBody struct {
Title string `json:"title"`
// We want a regular string because AI will take care of creating these for us.
Fields string `json:"fields"`
}
func (h *StackHandler) createStack(body CreateStackBody, w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userID, err := middleware.GetUserID(ctx, h.logger, w)
if err != nil {
return
}
err = h.stackModel.SaveProcessing(ctx, userID, body.Title, body.Fields)
if err != nil {
h.logger.Warn("could not save processing", "err", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}
func (h *StackHandler) CreateRoutes(r chi.Router) {
h.logger.Info("Mounting stack router")
r.Group(func(r chi.Router) {
r.Use(middleware.ProtectedRoute)
r.Use(middleware.SetJson)
r.Get("/", h.getAllStacks)
r.Get("/{listID}", h.getStackItems)
r.Post("/", withValidatedPost(h.createStack))
r.Patch("/{listID}", withValidatedPost(h.editStack))
})
}
func CreateStackHandler(db *sql.DB) StackHandler {
stackModel := models.NewListModel(db)
logger := log.New(os.Stdout).WithPrefix("Stacks")
return StackHandler{
logger,
stackModel,
}
}

View File

@ -0,0 +1,35 @@
import { List } from "@network/index";
import { Component } from "solid-js";
import fastHashCode from "../../utils/hash";
import { A } from "@solidjs/router";
const colors = [
"bg-emerald-50",
"bg-lime-50",
"bg-indigo-50",
"bg-sky-50",
"bg-amber-50",
"bg-teal-50",
"bg-fuchsia-50",
"bg-pink-50",
];
export const ListCard: Component<{ list: List }> = (props) => {
return (
<A
href={`/list/${props.list.ID}`}
class={
"flex flex-col p-4 border border-neutral-200 rounded-lg " +
colors[
fastHashCode(props.list.Name, { forcePositive: true }) % colors.length
]
}
>
<p class="text-xl font-bold">{props.list.Name}</p>
<p class="text-lg">{props.list.Images.length}</p>
</A>
);
};

View File

@ -96,7 +96,16 @@ const userImageValidator = strictObject({
CreatedAt: pipe(string()),
ImageID: pipe(string(), uuid()),
UserID: pipe(string(), uuid()),
Image: imageMetaValidator,
Image: strictObject({
...imageMetaValidator.entries,
ImageLists: array(
strictObject({
ID: pipe(string(), uuid()),
ImageID: pipe(string(), uuid()),
ListID: pipe(string(), uuid()),
}),
),
}),
});
const userProcessingImageValidator = strictObject({

View File

@ -1,21 +1,6 @@
import { Component, For } from "solid-js";
import { A } from "@solidjs/router";
import { useSearchImageContext } from "@contexts/SearchImageContext";
import fastHashCode from "../../utils/hash";
const colors = [
"bg-emerald-50",
"bg-lime-50",
"bg-indigo-50",
"bg-sky-50",
"bg-amber-50",
"bg-teal-50",
"bg-fuchsia-50",
"bg-pink-50",
];
import { ListCard } from "@components/list-card";
export const Categories: Component = () => {
const { lists } = useSearchImageContext();
@ -24,23 +9,7 @@ export const Categories: Component = () => {
<div class="rounded-xl bg-white p-4 flex flex-col gap-2">
<h2 class="text-xl font-bold">Generated Lists</h2>
<div class="w-full grid grid-cols-3 auto-rows-[minmax(100px,1fr)] gap-4">
<For each={lists()}>
{(list) => (
<A
href={`/list/${list.ID}`}
class={
"flex flex-col p-4 border border-neutral-200 rounded-lg " +
colors[
fastHashCode(list.Name, { forcePositive: true }) %
colors.length
]
}
>
<p class="text-xl font-bold">{list.Name}</p>
<p class="text-lg">{list.Images.length}</p>
</A>
)}
</For>
<For each={lists()}>{(list) => <ListCard list={list} />}</For>
</div>
</div>
);

View File

@ -1,13 +1,15 @@
import { ImageComponent } from "@components/image";
import { useSearchImageContext } from "@contexts/SearchImageContext";
import { useParams } from "@solidjs/router";
import { type Component } from "solid-js";
import { For, type Component } from "solid-js";
import SolidjsMarkdown from "solidjs-markdown";
import { List } from "../list";
import { ListCard } from "@components/list-card";
export const ImagePage: Component = () => {
const { imageId } = useParams<{ imageId: string }>();
const { userImages } = useSearchImageContext();
const { userImages, lists } = useSearchImageContext();
const image = () => userImages().find((i) => i.ImageID === imageId);
@ -16,11 +18,22 @@ export const ImagePage: Component = () => {
<div class="w-full bg-white rounded-xl p-4">
<ImageComponent ID={imageId} />
</div>
<div>
<h2 class="font-bold text-xl">Description</h2>
<div class="w-full bg-white rounded-xl p-4 flex flex-col gap-4">
<h2 class="font-bold text-2xl">Description</h2>
<div class="grid grid-cols-3 gap-4">
<For each={image()?.Image.ImageLists}>
{(imageList) => (
<ListCard
list={lists().find((l) => l.ID === imageList.ListID)!}
/>
)}
</For>
</div>
</div>
<div class="w-full bg-white rounded-xl p-4">
<h2 class="font-bold text-2xl">Description</h2>
<SolidjsMarkdown>{image()?.Image.Description}</SolidjsMarkdown>
</div>
<div class="w-full grid grid-cols-3 gap-2 grid-flow-row-dense p-4 bg-white rounded-xl"></div>
</main>
);
};