diff --git a/controller/task_detail.go b/controller/task_detail.go index df28445..00f1abe 100644 --- a/controller/task_detail.go +++ b/controller/task_detail.go @@ -312,25 +312,11 @@ func (tc *TaskController) SaveStatus(statusDisplay string) bool { // SaveType saves the new type to the current task after validating the display value. // Returns true if the type was successfully updated, false otherwise. func (tc *TaskController) SaveType(typeDisplay string) bool { - // Parse type display back to TaskType - var newType taskpkg.Type - typeFound := false - - for _, t := range []taskpkg.Type{ - taskpkg.TypeStory, - taskpkg.TypeBug, - taskpkg.TypeSpike, - taskpkg.TypeEpic, - } { - if taskpkg.TypeDisplay(t) == typeDisplay { - newType = t - typeFound = true - break - } - } - - if !typeFound { - newType = taskpkg.NormalizeType(typeDisplay) + // reverse the display string ("Bug 💥") back to a canonical key ("bug") + newType, ok := taskpkg.ParseDisplay(typeDisplay) + if !ok { + slog.Warn("unrecognized type display", "display", typeDisplay) + return false } // Validate using TypeValidator diff --git a/controller/task_detail_test.go b/controller/task_detail_test.go index bf6803b..4eba6f1 100644 --- a/controller/task_detail_test.go +++ b/controller/task_detail_test.go @@ -222,7 +222,7 @@ func TestTaskController_SaveType(t *testing.T) { setupTask: func(tc *TaskController, s store.Store) { tc.SetDraft(newTestTask()) }, - typeDisplay: "Bug", + typeDisplay: task.TypeDisplay(task.TypeBug), wantType: task.TypeBug, wantSuccess: true, }, @@ -233,25 +233,25 @@ func TestTaskController_SaveType(t *testing.T) { _ = s.CreateTask(t) tc.StartEditSession(t.ID) }, - typeDisplay: "Spike", + typeDisplay: task.TypeDisplay(task.TypeSpike), wantType: task.TypeSpike, wantSuccess: true, }, { - name: "invalid type normalizes to default", + name: "invalid type is rejected", setupTask: func(tc *TaskController, s store.Store) { tc.SetDraft(newTestTask()) }, typeDisplay: "InvalidType", - wantType: task.TypeStory, // NormalizeType defaults to story - wantSuccess: true, + wantType: task.TypeStory, // task type unchanged from setup + wantSuccess: false, }, { name: "no active task", setupTask: func(tc *TaskController, s store.Store) { // Don't set up any task }, - typeDisplay: "Story", + typeDisplay: task.TypeDisplay(task.TypeStory), wantSuccess: false, }, } diff --git a/go.mod b/go.mod index 74fe457..359af51 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/ProtonMail/go-crypto v1.1.6 // indirect github.com/alecthomas/chroma/v2 v2.14.0 // indirect + github.com/alecthomas/participle/v2 v2.1.4 // indirect github.com/atotto/clipboard v0.1.4 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect diff --git a/go.sum b/go.sum index abe9013..c81156a 100644 --- a/go.sum +++ b/go.sum @@ -9,8 +9,11 @@ github.com/ProtonMail/go-crypto v1.1.6 h1:ZcV+Ropw6Qn0AX9brlQLAUXfqLBc7Bl+f/DmNx github.com/ProtonMail/go-crypto v1.1.6/go.mod h1:rA3QumHc/FZ8pAHreoekgiAbzpNsfQAosU5td4SnOrE= github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE= github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= +github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E= github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I= +github.com/alecthomas/participle/v2 v2.1.4 h1:W/H79S8Sat/krZ3el6sQMvMaahJ+XcM9WSI2naI7w2U= +github.com/alecthomas/participle/v2 v2.1.4/go.mod h1:8tqVbpTX20Ru4NfYQgZf4mP18eXPTBViyMWiArNEgGI= github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= diff --git a/ruki/ast.go b/ruki/ast.go new file mode 100644 index 0000000..77b569b --- /dev/null +++ b/ruki/ast.go @@ -0,0 +1,190 @@ +package ruki + +import "time" + +// --- top-level union types --- + +// Statement is the result of parsing a CRUD command. +// Exactly one variant is non-nil. +type Statement struct { + Select *SelectStmt + Create *CreateStmt + Update *UpdateStmt + Delete *DeleteStmt +} + +// SelectStmt represents "select [where ]". +type SelectStmt struct { + Where Condition // nil = select all +} + +// CreateStmt represents "create =...". +type CreateStmt struct { + Assignments []Assignment +} + +// UpdateStmt represents "update where set =...". +type UpdateStmt struct { + Where Condition + Set []Assignment +} + +// DeleteStmt represents "delete where ". +type DeleteStmt struct { + Where Condition +} + +// --- triggers --- + +// Trigger is the result of parsing a reactive rule. +type Trigger struct { + Timing string // "before" or "after" + Event string // "create", "update", or "delete" + Where Condition // optional guard (nil if omitted) + Action *Statement // after-triggers only (create/update/delete, not select) + Run *RunAction // after-triggers only (alternative to Action) + Deny *string // before-triggers only +} + +// RunAction represents "run()" as a top-level trigger action. +type RunAction struct { + Command Expr +} + +// --- conditions --- + +// Condition is the interface for all boolean condition nodes. +type Condition interface { + conditionNode() +} + +// BinaryCondition represents " and/or ". +type BinaryCondition struct { + Op string // "and" or "or" + Left Condition + Right Condition +} + +// NotCondition represents "not ". +type NotCondition struct { + Inner Condition +} + +// CompareExpr represents " ". +type CompareExpr struct { + Left Expr + Op string // "=", "!=", "<", ">", "<=", ">=" + Right Expr +} + +// IsEmptyExpr represents " is [not] empty". +type IsEmptyExpr struct { + Expr Expr + Negated bool // true = "is not empty" +} + +// InExpr represents " [not] in ". +type InExpr struct { + Value Expr + Collection Expr + Negated bool // true = "not in" +} + +// QuantifierExpr represents " any/all ". +type QuantifierExpr struct { + Expr Expr + Kind string // "any" or "all" + Condition Condition +} + +func (*BinaryCondition) conditionNode() {} +func (*NotCondition) conditionNode() {} +func (*CompareExpr) conditionNode() {} +func (*IsEmptyExpr) conditionNode() {} +func (*InExpr) conditionNode() {} +func (*QuantifierExpr) conditionNode() {} + +// --- expressions --- + +// Expr is the interface for all expression nodes. +type Expr interface { + exprNode() +} + +// FieldRef represents a bare field name like "status" or "priority". +type FieldRef struct { + Name string +} + +// QualifiedRef represents "old.field" or "new.field". +type QualifiedRef struct { + Qualifier string // "old" or "new" + Name string +} + +// StringLiteral represents a double-quoted string value. +type StringLiteral struct { + Value string +} + +// IntLiteral represents an integer value. +type IntLiteral struct { + Value int +} + +// DateLiteral represents a YYYY-MM-DD date. +type DateLiteral struct { + Value time.Time +} + +// DurationLiteral represents a number+unit like "2day" or "1week". +type DurationLiteral struct { + Value int + Unit string +} + +// ListLiteral represents ["a", "b", ...]. +type ListLiteral struct { + Elements []Expr +} + +// EmptyLiteral represents the "empty" keyword. +type EmptyLiteral struct{} + +// FunctionCall represents "name(args...)". +type FunctionCall struct { + Name string + Args []Expr +} + +// BinaryExpr represents " +/- ". +type BinaryExpr struct { + Op string // "+" or "-" + Left Expr + Right Expr +} + +// SubQuery represents "select [where ]" used inside count(). +type SubQuery struct { + Where Condition // nil = select all +} + +func (*FieldRef) exprNode() {} +func (*QualifiedRef) exprNode() {} +func (*StringLiteral) exprNode() {} +func (*IntLiteral) exprNode() {} +func (*DateLiteral) exprNode() {} +func (*DurationLiteral) exprNode() {} +func (*ListLiteral) exprNode() {} +func (*EmptyLiteral) exprNode() {} +func (*FunctionCall) exprNode() {} +func (*BinaryExpr) exprNode() {} +func (*SubQuery) exprNode() {} + +// --- assignments --- + +// Assignment represents "field=value" in create/update statements. +type Assignment struct { + Field string + Value Expr +} diff --git a/ruki/grammar.go b/ruki/grammar.go new file mode 100644 index 0000000..5168357 --- /dev/null +++ b/ruki/grammar.go @@ -0,0 +1,173 @@ +package ruki + +// grammar.go — unexported participle grammar structs. +// these encode operator precedence via grammar layering. +// consumers never see these; lower.go converts them to clean AST types. + +// --- top-level statement grammar --- + +type statementGrammar struct { + Select *selectGrammar `parser:" @@"` + Create *createGrammar `parser:"| @@"` + Update *updateGrammar `parser:"| @@"` + Delete *deleteGrammar `parser:"| @@"` +} + +type selectGrammar struct { + Where *orCond `parser:"'select' ( 'where' @@ )?"` +} + +type createGrammar struct { + Assignments []assignmentGrammar `parser:"'create' @@+"` +} + +type updateGrammar struct { + Where orCond `parser:"'update' 'where' @@"` + Set []assignmentGrammar `parser:"'set' @@+"` +} + +type deleteGrammar struct { + Where orCond `parser:"'delete' 'where' @@"` +} + +type assignmentGrammar struct { + Field string `parser:"@Ident '='"` + Value exprGrammar `parser:"@@"` +} + +// --- trigger grammar --- + +type triggerGrammar struct { + Timing string `parser:"@( 'before' | 'after' )"` + Event string `parser:"@( 'create' | 'update' | 'delete' )"` + Where *orCond `parser:"( 'where' @@ )?"` + Action *actionGrammar `parser:"( @@"` + Deny *denyGrammar `parser:"| @@ )?"` +} + +type actionGrammar struct { + Run *runGrammar `parser:" @@"` + Create *createGrammar `parser:"| @@"` + Update *updateGrammar `parser:"| @@"` + Delete *deleteGrammar `parser:"| @@"` +} + +type runGrammar struct { + Command exprGrammar `parser:"'run' '(' @@ ')'"` +} + +type denyGrammar struct { + Message string `parser:"'deny' @String"` +} + +// --- condition grammar (precedence layers) --- + +// orCond is the lowest-precedence condition layer. +type orCond struct { + Left andCond `parser:"@@"` + Right []andCond `parser:"( 'or' @@ )*"` +} + +type andCond struct { + Left notCond `parser:"@@"` + Right []notCond `parser:"( 'and' @@ )*"` +} + +type notCond struct { + Not *notCond `parser:" 'not' @@"` + Primary *primaryCond `parser:"| @@"` +} + +type primaryCond struct { + Paren *orCond `parser:" '(' @@ ')'"` + Expr *exprCond `parser:"| @@"` +} + +// exprCond parses an expression followed by a condition operator. +type exprCond struct { + Left exprGrammar `parser:"@@"` + Compare *compareTail `parser:"( @@"` + IsEmpty *isEmptyTail `parser:"| @@"` + IsNotEmpty *isNotEmptyTail `parser:"| @@"` + NotIn *notInTail `parser:"| @@"` + In *inTail `parser:"| @@"` + Any *quantifierTail `parser:"| @@"` + All *allQuantTail `parser:"| @@ )?"` +} + +type compareTail struct { + Op string `parser:"@CompareOp"` + Right exprGrammar `parser:"@@"` +} + +type isEmptyTail struct { + Is string `parser:"@'is' 'empty'"` +} + +type isNotEmptyTail struct { + Is string `parser:"@'is' 'not' 'empty'"` +} + +type inTail struct { + Collection exprGrammar `parser:"'in' @@"` +} + +type notInTail struct { + Collection exprGrammar `parser:"'not' 'in' @@"` +} + +type quantifierTail struct { + Condition primaryCond `parser:"'any' @@"` +} + +type allQuantTail struct { + Condition primaryCond `parser:"'all' @@"` +} + +// --- expression grammar --- + +type exprGrammar struct { + Left unaryExpr `parser:"@@"` + Tail []exprBinTail `parser:"@@*"` +} + +type exprBinTail struct { + Op string `parser:"@( Plus | Minus )"` + Right unaryExpr `parser:"@@"` +} + +type unaryExpr struct { + FuncCall *funcCallExpr `parser:" @@"` + SubQuery *subQueryExpr `parser:"| @@"` + QualRef *qualRefExpr `parser:"| @@"` + ListLit *listLitExpr `parser:"| @@"` + StrLit *string `parser:"| @String"` + DateLit *string `parser:"| @Date"` + DurLit *string `parser:"| @Duration"` + IntLit *int `parser:"| @Int"` + Empty *emptyExpr `parser:"| @@"` + FieldRef *string `parser:"| @Ident"` + Paren *exprGrammar `parser:"| '(' @@ ')'"` +} + +type funcCallExpr struct { + Name string `parser:"@Ident '('"` + Args []exprGrammar `parser:"( @@ ( ',' @@ )* )? ')'"` +} + +type subQueryExpr struct { + Where *orCond `parser:"'select' ( 'where' @@ )?"` +} + +type qualRefExpr struct { + Qualifier string `parser:"@( 'old' | 'new' ) '.'"` + Name string `parser:"@Ident"` +} + +type listLitExpr struct { + Elements []exprGrammar `parser:"'[' ( @@ ( ',' @@ )* )? ']'"` +} + +type emptyExpr struct { + Keyword string `parser:"@'empty'"` +} diff --git a/ruki/lexer.go b/ruki/lexer.go new file mode 100644 index 0000000..e236709 --- /dev/null +++ b/ruki/lexer.go @@ -0,0 +1,24 @@ +package ruki + +import "github.com/alecthomas/participle/v2/lexer" + +// rukiLexer defines the token rules for the ruki DSL. +// rule ordering is critical: longer/more-specific patterns first. +var rukiLexer = lexer.MustSimple([]lexer.SimpleRule{ + {Name: "Comment", Pattern: `--[^\n]*`}, + {Name: "Whitespace", Pattern: `\s+`}, + {Name: "Duration", Pattern: `\d+(?:sec|min|hour|day|week|month|year)s?`}, + {Name: "Date", Pattern: `\d{4}-\d{2}-\d{2}`}, + {Name: "Int", Pattern: `\d+`}, + {Name: "String", Pattern: `"(?:[^"\\]|\\.)*"`}, + {Name: "CompareOp", Pattern: `!=|<=|>=|[=<>]`}, + {Name: "Plus", Pattern: `\+`}, + {Name: "Minus", Pattern: `-`}, + {Name: "Dot", Pattern: `\.`}, + {Name: "LParen", Pattern: `\(`}, + {Name: "RParen", Pattern: `\)`}, + {Name: "LBracket", Pattern: `\[`}, + {Name: "RBracket", Pattern: `\]`}, + {Name: "Comma", Pattern: `,`}, + {Name: "Ident", Pattern: `[a-zA-Z_][a-zA-Z0-9_]*`}, +}) diff --git a/ruki/lower.go b/ruki/lower.go new file mode 100644 index 0000000..cfc7885 --- /dev/null +++ b/ruki/lower.go @@ -0,0 +1,383 @@ +package ruki + +import ( + "fmt" + "strconv" + "strings" + "time" +) + +// lower.go converts participle grammar structs into clean AST types. + +func lowerStatement(g *statementGrammar) (*Statement, error) { + switch { + case g.Select != nil: + s, err := lowerSelect(g.Select) + if err != nil { + return nil, err + } + return &Statement{Select: s}, nil + case g.Create != nil: + s, err := lowerCreate(g.Create) + if err != nil { + return nil, err + } + return &Statement{Create: s}, nil + case g.Update != nil: + s, err := lowerUpdate(g.Update) + if err != nil { + return nil, err + } + return &Statement{Update: s}, nil + case g.Delete != nil: + s, err := lowerDelete(g.Delete) + if err != nil { + return nil, err + } + return &Statement{Delete: s}, nil + default: + return nil, fmt.Errorf("empty statement") + } +} + +func lowerSelect(g *selectGrammar) (*SelectStmt, error) { + var where Condition + if g.Where != nil { + var err error + where, err = lowerOrCond(g.Where) + if err != nil { + return nil, err + } + } + return &SelectStmt{Where: where}, nil +} + +func lowerCreate(g *createGrammar) (*CreateStmt, error) { + assignments, err := lowerAssignments(g.Assignments) + if err != nil { + return nil, err + } + return &CreateStmt{Assignments: assignments}, nil +} + +func lowerUpdate(g *updateGrammar) (*UpdateStmt, error) { + where, err := lowerOrCond(&g.Where) + if err != nil { + return nil, err + } + set, err := lowerAssignments(g.Set) + if err != nil { + return nil, err + } + return &UpdateStmt{Where: where, Set: set}, nil +} + +func lowerDelete(g *deleteGrammar) (*DeleteStmt, error) { + where, err := lowerOrCond(&g.Where) + if err != nil { + return nil, err + } + return &DeleteStmt{Where: where}, nil +} + +func lowerAssignments(gs []assignmentGrammar) ([]Assignment, error) { + result := make([]Assignment, len(gs)) + for i, g := range gs { + val, err := lowerExpr(&g.Value) + if err != nil { + return nil, err + } + result[i] = Assignment{Field: g.Field, Value: val} + } + return result, nil +} + +// --- trigger lowering --- + +func lowerTrigger(g *triggerGrammar) (*Trigger, error) { + t := &Trigger{ + Timing: g.Timing, + Event: g.Event, + } + + if g.Where != nil { + where, err := lowerOrCond(g.Where) + if err != nil { + return nil, err + } + t.Where = where + } + + if g.Action != nil { + if err := lowerTriggerAction(g.Action, t); err != nil { + return nil, err + } + } + + if g.Deny != nil { + msg := unquoteString(g.Deny.Message) + t.Deny = &msg + } + + return t, nil +} + +func lowerTriggerAction(g *actionGrammar, t *Trigger) error { + switch { + case g.Run != nil: + cmd, err := lowerExpr(&g.Run.Command) + if err != nil { + return err + } + t.Run = &RunAction{Command: cmd} + case g.Create != nil: + s, err := lowerCreate(g.Create) + if err != nil { + return err + } + t.Action = &Statement{Create: s} + case g.Update != nil: + s, err := lowerUpdate(g.Update) + if err != nil { + return err + } + t.Action = &Statement{Update: s} + case g.Delete != nil: + s, err := lowerDelete(g.Delete) + if err != nil { + return err + } + t.Action = &Statement{Delete: s} + default: + return fmt.Errorf("empty trigger action") + } + return nil +} + +// --- condition lowering --- + +func lowerOrCond(g *orCond) (Condition, error) { + left, err := lowerAndCond(&g.Left) + if err != nil { + return nil, err + } + for _, r := range g.Right { + right, err := lowerAndCond(&r) + if err != nil { + return nil, err + } + left = &BinaryCondition{Op: "or", Left: left, Right: right} + } + return left, nil +} + +func lowerAndCond(g *andCond) (Condition, error) { + left, err := lowerNotCond(&g.Left) + if err != nil { + return nil, err + } + for _, r := range g.Right { + right, err := lowerNotCond(&r) + if err != nil { + return nil, err + } + left = &BinaryCondition{Op: "and", Left: left, Right: right} + } + return left, nil +} + +func lowerNotCond(g *notCond) (Condition, error) { + if g.Not != nil { + inner, err := lowerNotCond(g.Not) + if err != nil { + return nil, err + } + return &NotCondition{Inner: inner}, nil + } + return lowerPrimaryCond(g.Primary) +} + +func lowerPrimaryCond(g *primaryCond) (Condition, error) { + if g.Paren != nil { + return lowerOrCond(g.Paren) + } + return lowerExprCond(g.Expr) +} + +func lowerExprCond(g *exprCond) (Condition, error) { + left, err := lowerExpr(&g.Left) + if err != nil { + return nil, err + } + + switch { + case g.Compare != nil: + right, err := lowerExpr(&g.Compare.Right) + if err != nil { + return nil, err + } + return &CompareExpr{Left: left, Op: g.Compare.Op, Right: right}, nil + + case g.IsEmpty != nil: + return &IsEmptyExpr{Expr: left, Negated: false}, nil + + case g.IsNotEmpty != nil: + return &IsEmptyExpr{Expr: left, Negated: true}, nil + + case g.In != nil: + coll, err := lowerExpr(&g.In.Collection) + if err != nil { + return nil, err + } + return &InExpr{Value: left, Collection: coll, Negated: false}, nil + + case g.NotIn != nil: + coll, err := lowerExpr(&g.NotIn.Collection) + if err != nil { + return nil, err + } + return &InExpr{Value: left, Collection: coll, Negated: true}, nil + + case g.Any != nil: + cond, err := lowerPrimaryCond(&g.Any.Condition) + if err != nil { + return nil, err + } + return &QuantifierExpr{Expr: left, Kind: "any", Condition: cond}, nil + + case g.All != nil: + cond, err := lowerPrimaryCond(&g.All.Condition) + if err != nil { + return nil, err + } + return &QuantifierExpr{Expr: left, Kind: "all", Condition: cond}, nil + + default: + // bare expression used as condition — this is a parse error + return nil, fmt.Errorf("expression used as condition without comparison operator") + } +} + +// --- expression lowering --- + +func lowerExpr(g *exprGrammar) (Expr, error) { + left, err := lowerUnary(&g.Left) + if err != nil { + return nil, err + } + for _, tail := range g.Tail { + right, err := lowerUnary(&tail.Right) + if err != nil { + return nil, err + } + left = &BinaryExpr{Op: tail.Op, Left: left, Right: right} + } + return left, nil +} + +func lowerUnary(g *unaryExpr) (Expr, error) { + switch { + case g.FuncCall != nil: + return lowerFuncCall(g.FuncCall) + case g.SubQuery != nil: + return lowerSubQuery(g.SubQuery) + case g.QualRef != nil: + return &QualifiedRef{Qualifier: g.QualRef.Qualifier, Name: g.QualRef.Name}, nil + case g.ListLit != nil: + return lowerListLit(g.ListLit) + case g.StrLit != nil: + return &StringLiteral{Value: unquoteString(*g.StrLit)}, nil + case g.DateLit != nil: + return parseDateLiteral(*g.DateLit) + case g.DurLit != nil: + return parseDurationLiteral(*g.DurLit) + case g.IntLit != nil: + return &IntLiteral{Value: *g.IntLit}, nil + case g.Empty != nil: + return &EmptyLiteral{}, nil + case g.FieldRef != nil: + return &FieldRef{Name: *g.FieldRef}, nil + case g.Paren != nil: + return lowerExpr(g.Paren) + default: + return nil, fmt.Errorf("empty expression") + } +} + +func lowerFuncCall(g *funcCallExpr) (Expr, error) { + args := make([]Expr, len(g.Args)) + for i, a := range g.Args { + arg, err := lowerExpr(&a) + if err != nil { + return nil, err + } + args[i] = arg + } + return &FunctionCall{Name: g.Name, Args: args}, nil +} + +func lowerSubQuery(g *subQueryExpr) (Expr, error) { + var where Condition + if g.Where != nil { + var err error + where, err = lowerOrCond(g.Where) + if err != nil { + return nil, err + } + } + return &SubQuery{Where: where}, nil +} + +func lowerListLit(g *listLitExpr) (Expr, error) { + elems := make([]Expr, len(g.Elements)) + for i, e := range g.Elements { + elem, err := lowerExpr(&e) + if err != nil { + return nil, err + } + elems[i] = elem + } + return &ListLiteral{Elements: elems}, nil +} + +// --- literal helpers --- + +func unquoteString(s string) string { + // strip surrounding quotes and unescape + if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' { + unquoted, err := strconv.Unquote(s) + if err == nil { + return unquoted + } + // fallback: just strip quotes + return s[1 : len(s)-1] + } + return s +} + +func parseDateLiteral(s string) (Expr, error) { + t, err := time.Parse("2006-01-02", s) + if err != nil { + return nil, fmt.Errorf("invalid date literal %q: %w", s, err) + } + return &DateLiteral{Value: t}, nil +} + +func parseDurationLiteral(s string) (Expr, error) { + // find where digits end and unit begins + i := 0 + for i < len(s) && s[i] >= '0' && s[i] <= '9' { + i++ + } + if i == 0 || i == len(s) { + return nil, fmt.Errorf("invalid duration literal %q", s) + } + + val, err := strconv.Atoi(s[:i]) + if err != nil { + return nil, fmt.Errorf("invalid duration value in %q: %w", s, err) + } + + unit := strings.TrimSuffix(s[i:], "s") // normalize "days" → "day" + return &DurationLiteral{Value: val, Unit: unit}, nil +} diff --git a/ruki/parser.go b/ruki/parser.go new file mode 100644 index 0000000..5d2cdc3 --- /dev/null +++ b/ruki/parser.go @@ -0,0 +1,101 @@ +package ruki + +import ( + "github.com/alecthomas/participle/v2" +) + +// Schema provides the canonical field catalog and normalization functions +// that the parser uses for validation. Production code adapts this from +// workflow.Fields(), workflow.StatusRegistry, and workflow.TypeRegistry. +type Schema interface { + // Field returns the field spec for a given field name. + Field(name string) (FieldSpec, bool) + // NormalizeStatus validates and normalizes a raw status string. + // returns the canonical key and true, or ("", false) for unknown values. + NormalizeStatus(raw string) (string, bool) + // NormalizeType validates and normalizes a raw type string. + // returns the canonical key and true, or ("", false) for unknown values. + NormalizeType(raw string) (string, bool) +} + +// ValueType identifies the semantic type of a field in the DSL. +type ValueType int + +const ( + ValueString ValueType = iota + ValueInt // priority, points + ValueDate // due + ValueTimestamp // createdAt, updatedAt + ValueDuration // duration literals + ValueBool // contains() return type + ValueID // task identifier + ValueRef // reference to another task + ValueRecurrence // recurrence pattern + ValueListString // tags + ValueListRef // dependsOn + ValueStatus // status enum + ValueTaskType // type enum +) + +// FieldSpec describes a single task field for the parser. +type FieldSpec struct { + Name string + Type ValueType +} + +// Parser parses ruki DSL statements and triggers. +type Parser struct { + stmtParser *participle.Parser[statementGrammar] + triggerParser *participle.Parser[triggerGrammar] + schema Schema + qualifiers qualifierPolicy // set before each validation pass +} + +// NewParser constructs a Parser with the given schema for validation. +// panics if the grammar is invalid (programming error, not user error). +func NewParser(schema Schema) *Parser { + opts := []participle.Option{ + participle.Lexer(rukiLexer), + participle.Elide("Comment", "Whitespace"), + participle.UseLookahead(3), + } + return &Parser{ + stmtParser: participle.MustBuild[statementGrammar](opts...), + triggerParser: participle.MustBuild[triggerGrammar](opts...), + schema: schema, + } +} + +// ParseStatement parses a CRUD statement and returns a validated AST. +func (p *Parser) ParseStatement(input string) (*Statement, error) { + g, err := p.stmtParser.ParseString("", input) + if err != nil { + return nil, err + } + stmt, err := lowerStatement(g) + if err != nil { + return nil, err + } + p.qualifiers = noQualifiers + if err := p.validateStatement(stmt); err != nil { + return nil, err + } + return stmt, nil +} + +// ParseTrigger parses a reactive trigger rule and returns a validated AST. +func (p *Parser) ParseTrigger(input string) (*Trigger, error) { + g, err := p.triggerParser.ParseString("", input) + if err != nil { + return nil, err + } + trig, err := lowerTrigger(g) + if err != nil { + return nil, err + } + p.qualifiers = triggerQualifiers(trig.Event) + if err := p.validateTrigger(trig); err != nil { + return nil, err + } + return trig, nil +} diff --git a/ruki/parser_test.go b/ruki/parser_test.go new file mode 100644 index 0000000..c857c27 --- /dev/null +++ b/ruki/parser_test.go @@ -0,0 +1,633 @@ +package ruki + +import ( + "testing" + "time" +) + +// testSchema implements Schema for tests with standard tiki fields. +type testSchema struct{} + +func (testSchema) Field(name string) (FieldSpec, bool) { + fields := map[string]FieldSpec{ + "id": {Name: "id", Type: ValueID}, + "title": {Name: "title", Type: ValueString}, + "description": {Name: "description", Type: ValueString}, + "status": {Name: "status", Type: ValueStatus}, + "type": {Name: "type", Type: ValueTaskType}, + "tags": {Name: "tags", Type: ValueListString}, + "dependsOn": {Name: "dependsOn", Type: ValueListRef}, + "due": {Name: "due", Type: ValueDate}, + "recurrence": {Name: "recurrence", Type: ValueRecurrence}, + "assignee": {Name: "assignee", Type: ValueString}, + "priority": {Name: "priority", Type: ValueInt}, + "points": {Name: "points", Type: ValueInt}, + "createdBy": {Name: "createdBy", Type: ValueString}, + "createdAt": {Name: "createdAt", Type: ValueTimestamp}, + "updatedAt": {Name: "updatedAt", Type: ValueTimestamp}, + } + f, ok := fields[name] + return f, ok +} + +func (testSchema) NormalizeStatus(raw string) (string, bool) { + valid := map[string]string{ + "backlog": "backlog", + "ready": "ready", + "todo": "ready", + "in progress": "in_progress", + "in_progress": "in_progress", + "review": "review", + "done": "done", + "cancelled": "cancelled", + } + canonical, ok := valid[raw] + return canonical, ok +} + +func (testSchema) NormalizeType(raw string) (string, bool) { + valid := map[string]string{ + "story": "story", + "feature": "story", + "task": "story", + "bug": "bug", + "spike": "spike", + "epic": "epic", + } + canonical, ok := valid[raw] + return canonical, ok +} + +func newTestParser() *Parser { + return NewParser(testSchema{}) +} + +func TestParseSelect(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantWhere bool + }{ + {"select all", "select", false}, + {"select with where", `select where status = "done"`, true}, + {"select with and", `select where status = "done" and priority <= 2`, true}, + {"select with in", `select where "bug" in tags`, true}, + {"select with quantifier", `select where dependsOn any status != "done"`, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmt, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if stmt.Select == nil { + t.Fatal("expected Select, got nil") + } + if tt.wantWhere && stmt.Select.Where == nil { + t.Fatal("expected Where condition, got nil") + } + if !tt.wantWhere && stmt.Select.Where != nil { + t.Fatal("expected nil Where, got condition") + } + }) + } +} + +func TestParseCreate(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantFields int + }{ + { + "basic create", + `create title="Fix login" priority=2 status="ready" tags=["bug"]`, + 4, + }, + { + "single field", + `create title="hello"`, + 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmt, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if stmt.Create == nil { + t.Fatal("expected Create, got nil") + } + if len(stmt.Create.Assignments) != tt.wantFields { + t.Fatalf("expected %d assignments, got %d", tt.wantFields, len(stmt.Create.Assignments)) + } + }) + } +} + +func TestParseUpdate(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantSet int + }{ + { + "update by id", + `update where id = "TIKI-ABC123" set status="done"`, + 1, + }, + { + "update with complex where", + `update where status = "ready" and "sprint-3" in tags set status="cancelled"`, + 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmt, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if stmt.Update == nil { + t.Fatal("expected Update, got nil") + } + if len(stmt.Update.Set) != tt.wantSet { + t.Fatalf("expected %d set assignments, got %d", tt.wantSet, len(stmt.Update.Set)) + } + }) + } +} + +func TestParseDelete(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + }{ + {"delete by id", `delete where id = "TIKI-ABC123"`}, + {"delete with complex where", `delete where status = "cancelled" and "old" in tags`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmt, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if stmt.Delete == nil { + t.Fatal("expected Delete, got nil") + } + if stmt.Delete.Where == nil { + t.Fatal("expected Where condition, got nil") + } + }) + } +} + +func TestParseExpressions(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + check func(t *testing.T, stmt *Statement) + }{ + { + "string literal in assignment", + `create title="hello world"`, + func(t *testing.T, stmt *Statement) { + t.Helper() + sl, ok := stmt.Create.Assignments[0].Value.(*StringLiteral) + if !ok { + t.Fatalf("expected StringLiteral, got %T", stmt.Create.Assignments[0].Value) + } + if sl.Value != "hello world" { + t.Fatalf("expected %q, got %q", "hello world", sl.Value) + } + }, + }, + { + "int literal in assignment", + `create title="x" priority=2`, + func(t *testing.T, stmt *Statement) { + t.Helper() + il, ok := stmt.Create.Assignments[1].Value.(*IntLiteral) + if !ok { + t.Fatalf("expected IntLiteral, got %T", stmt.Create.Assignments[1].Value) + } + if il.Value != 2 { + t.Fatalf("expected 2, got %d", il.Value) + } + }, + }, + { + "date literal in assignment", + `create title="x" due=2026-03-25`, + func(t *testing.T, stmt *Statement) { + t.Helper() + dl, ok := stmt.Create.Assignments[1].Value.(*DateLiteral) + if !ok { + t.Fatalf("expected DateLiteral, got %T", stmt.Create.Assignments[1].Value) + } + expected := time.Date(2026, 3, 25, 0, 0, 0, 0, time.UTC) + if !dl.Value.Equal(expected) { + t.Fatalf("expected %v, got %v", expected, dl.Value) + } + }, + }, + { + "list literal in assignment", + `create title="x" tags=["bug", "frontend"]`, + func(t *testing.T, stmt *Statement) { + t.Helper() + ll, ok := stmt.Create.Assignments[1].Value.(*ListLiteral) + if !ok { + t.Fatalf("expected ListLiteral, got %T", stmt.Create.Assignments[1].Value) + } + if len(ll.Elements) != 2 { + t.Fatalf("expected 2 elements, got %d", len(ll.Elements)) + } + }, + }, + { + "empty literal in assignment", + `create title="x" assignee=empty`, + func(t *testing.T, stmt *Statement) { + t.Helper() + if _, ok := stmt.Create.Assignments[1].Value.(*EmptyLiteral); !ok { + t.Fatalf("expected EmptyLiteral, got %T", stmt.Create.Assignments[1].Value) + } + }, + }, + { + "function call in assignment", + `create title="x" due=next_date(recurrence)`, + func(t *testing.T, stmt *Statement) { + t.Helper() + fc, ok := stmt.Create.Assignments[1].Value.(*FunctionCall) + if !ok { + t.Fatalf("expected FunctionCall, got %T", stmt.Create.Assignments[1].Value) + } + if fc.Name != "next_date" { + t.Fatalf("expected next_date, got %s", fc.Name) + } + if len(fc.Args) != 1 { + t.Fatalf("expected 1 arg, got %d", len(fc.Args)) + } + }, + }, + { + "binary plus expression", + `create title="x" tags=tags + ["new"]`, + func(t *testing.T, stmt *Statement) { + t.Helper() + be, ok := stmt.Create.Assignments[1].Value.(*BinaryExpr) + if !ok { + t.Fatalf("expected BinaryExpr, got %T", stmt.Create.Assignments[1].Value) + } + if be.Op != "+" { + t.Fatalf("expected +, got %s", be.Op) + } + }, + }, + { + "duration literal", + `create title="x" due=2026-03-25 + 2day`, + func(t *testing.T, stmt *Statement) { + t.Helper() + be, ok := stmt.Create.Assignments[1].Value.(*BinaryExpr) + if !ok { + t.Fatalf("expected BinaryExpr, got %T", stmt.Create.Assignments[1].Value) + } + dur, ok := be.Right.(*DurationLiteral) + if !ok { + t.Fatalf("expected DurationLiteral, got %T", be.Right) + } + if dur.Value != 2 || dur.Unit != "day" { + t.Fatalf("expected 2day, got %d%s", dur.Value, dur.Unit) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmt, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("parse error: %v", err) + } + tt.check(t, stmt) + }) + } +} + +func TestParseConditions(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + check func(t *testing.T, stmt *Statement) + }{ + { + "simple compare", + `select where status = "done"`, + func(t *testing.T, stmt *Statement) { + t.Helper() + cmp, ok := stmt.Select.Where.(*CompareExpr) + if !ok { + t.Fatalf("expected CompareExpr, got %T", stmt.Select.Where) + } + if cmp.Op != "=" { + t.Fatalf("expected =, got %s", cmp.Op) + } + }, + }, + { + "is empty", + `select where assignee is empty`, + func(t *testing.T, stmt *Statement) { + t.Helper() + ie, ok := stmt.Select.Where.(*IsEmptyExpr) + if !ok { + t.Fatalf("expected IsEmptyExpr, got %T", stmt.Select.Where) + } + if ie.Negated { + t.Fatal("expected Negated=false") + } + }, + }, + { + "is not empty", + `select where description is not empty`, + func(t *testing.T, stmt *Statement) { + t.Helper() + ie, ok := stmt.Select.Where.(*IsEmptyExpr) + if !ok { + t.Fatalf("expected IsEmptyExpr, got %T", stmt.Select.Where) + } + if !ie.Negated { + t.Fatal("expected Negated=true") + } + }, + }, + { + "value in field", + `select where "bug" in tags`, + func(t *testing.T, stmt *Statement) { + t.Helper() + in, ok := stmt.Select.Where.(*InExpr) + if !ok { + t.Fatalf("expected InExpr, got %T", stmt.Select.Where) + } + if in.Negated { + t.Fatal("expected Negated=false") + } + }, + }, + { + "value not in list", + `select where status not in ["done", "cancelled"]`, + func(t *testing.T, stmt *Statement) { + t.Helper() + in, ok := stmt.Select.Where.(*InExpr) + if !ok { + t.Fatalf("expected InExpr, got %T", stmt.Select.Where) + } + if !in.Negated { + t.Fatal("expected Negated=true") + } + }, + }, + { + "and precedence", + `select where status = "done" and priority <= 2`, + func(t *testing.T, stmt *Statement) { + t.Helper() + bc, ok := stmt.Select.Where.(*BinaryCondition) + if !ok { + t.Fatalf("expected BinaryCondition, got %T", stmt.Select.Where) + } + if bc.Op != "and" { + t.Fatalf("expected and, got %s", bc.Op) + } + }, + }, + { + "or precedence — and binds tighter", + `select where priority = 1 or priority = 2 and status = "done"`, + func(t *testing.T, stmt *Statement) { + t.Helper() + // should parse as: priority=1 or (priority=2 and status="done") + bc, ok := stmt.Select.Where.(*BinaryCondition) + if !ok { + t.Fatalf("expected BinaryCondition, got %T", stmt.Select.Where) + } + if bc.Op != "or" { + t.Fatalf("expected or at top, got %s", bc.Op) + } + // right side should be an and + right, ok := bc.Right.(*BinaryCondition) + if !ok { + t.Fatalf("expected BinaryCondition on right, got %T", bc.Right) + } + if right.Op != "and" { + t.Fatalf("expected and on right, got %s", right.Op) + } + }, + }, + { + "not condition", + `select where not status = "done"`, + func(t *testing.T, stmt *Statement) { + t.Helper() + nc, ok := stmt.Select.Where.(*NotCondition) + if !ok { + t.Fatalf("expected NotCondition, got %T", stmt.Select.Where) + } + if _, ok := nc.Inner.(*CompareExpr); !ok { + t.Fatalf("expected CompareExpr inside not, got %T", nc.Inner) + } + }, + }, + { + "parenthesized condition", + `select where (status = "done" or status = "cancelled") and priority = 1`, + func(t *testing.T, stmt *Statement) { + t.Helper() + bc, ok := stmt.Select.Where.(*BinaryCondition) + if !ok { + t.Fatalf("expected BinaryCondition, got %T", stmt.Select.Where) + } + if bc.Op != "and" { + t.Fatalf("expected and at top, got %s", bc.Op) + } + // left should be an or (the parenthesized group) + left, ok := bc.Left.(*BinaryCondition) + if !ok { + t.Fatalf("expected BinaryCondition on left, got %T", bc.Left) + } + if left.Op != "or" { + t.Fatalf("expected or on left, got %s", left.Op) + } + }, + }, + { + "quantifier any", + `select where dependsOn any status != "done"`, + func(t *testing.T, stmt *Statement) { + t.Helper() + qe, ok := stmt.Select.Where.(*QuantifierExpr) + if !ok { + t.Fatalf("expected QuantifierExpr, got %T", stmt.Select.Where) + } + if qe.Kind != "any" { + t.Fatalf("expected any, got %s", qe.Kind) + } + }, + }, + { + "quantifier all", + `select where dependsOn all status = "done"`, + func(t *testing.T, stmt *Statement) { + t.Helper() + qe, ok := stmt.Select.Where.(*QuantifierExpr) + if !ok { + t.Fatalf("expected QuantifierExpr, got %T", stmt.Select.Where) + } + if qe.Kind != "all" { + t.Fatalf("expected all, got %s", qe.Kind) + } + }, + }, + { + "quantifier binds to primary — and separates", + `select where dependsOn any status != "done" and priority = 1`, + func(t *testing.T, stmt *Statement) { + t.Helper() + // should parse as: (dependsOn any (status != "done")) and (priority = 1) + bc, ok := stmt.Select.Where.(*BinaryCondition) + if !ok { + t.Fatalf("expected BinaryCondition at top, got %T", stmt.Select.Where) + } + if bc.Op != "and" { + t.Fatalf("expected and, got %s", bc.Op) + } + if _, ok := bc.Left.(*QuantifierExpr); !ok { + t.Fatalf("expected QuantifierExpr on left, got %T", bc.Left) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmt, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("parse error: %v", err) + } + tt.check(t, stmt) + }) + } +} + +func TestParseQualifiedRefs(t *testing.T) { + p := newTestParser() + + input := `select where status = "done"` + stmt, err := p.ParseStatement(input) + if err != nil { + t.Fatalf("parse error: %v", err) + } + cmp, ok := stmt.Select.Where.(*CompareExpr) + if !ok { + t.Fatalf("expected CompareExpr, got %T", stmt.Select.Where) + } + fr, ok := cmp.Left.(*FieldRef) + if !ok { + t.Fatalf("expected FieldRef, got %T", cmp.Left) + } + if fr.Name != "status" { + t.Fatalf("expected status, got %s", fr.Name) + } +} + +func TestParseSubQuery(t *testing.T) { + p := newTestParser() + + input := `select where count(select where status = "in progress" and assignee = "bob") >= 3` + stmt, err := p.ParseStatement(input) + if err != nil { + t.Fatalf("parse error: %v", err) + } + + cmp, ok := stmt.Select.Where.(*CompareExpr) + if !ok { + t.Fatalf("expected CompareExpr, got %T", stmt.Select.Where) + } + + fc, ok := cmp.Left.(*FunctionCall) + if !ok { + t.Fatalf("expected FunctionCall, got %T", cmp.Left) + } + if fc.Name != "count" { + t.Fatalf("expected count, got %s", fc.Name) + } + + sq, ok := fc.Args[0].(*SubQuery) + if !ok { + t.Fatalf("expected SubQuery arg, got %T", fc.Args[0]) + } + if sq.Where == nil { + t.Fatal("expected SubQuery Where, got nil") + } +} + +func TestParseStatementErrors(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + }{ + {"empty input", ""}, + {"unknown keyword", "drop where id = 1"}, + {"missing where in update", `update set status="done"`}, + {"missing set in update", `update where id = "x"`}, + {"missing where in delete", `delete id = "x"`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + }) + } +} + +func TestParseComment(t *testing.T) { + p := newTestParser() + + input := `-- this is a comment +select where status = "done"` + stmt, err := p.ParseStatement(input) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if stmt.Select == nil { + t.Fatal("expected Select") + } +} diff --git a/ruki/trigger_test.go b/ruki/trigger_test.go new file mode 100644 index 0000000..fb0b953 --- /dev/null +++ b/ruki/trigger_test.go @@ -0,0 +1,290 @@ +package ruki + +import "testing" + +func TestParseTrigger_BeforeDeny(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + event string + }{ + { + "block completion with open deps", + `before update where new.status = "done" and dependsOn any status != "done" deny "cannot complete task with open dependencies"`, + "update", + }, + { + "deny delete high priority", + `before delete where old.priority <= 2 deny "cannot delete high priority tasks"`, + "delete", + }, + { + "require description for high priority", + `before update where new.priority <= 2 and new.description is empty deny "high priority tasks need a description"`, + "update", + }, + { + "require description for stories", + `before create where new.type = "story" and new.description is empty deny "stories must have a description"`, + "create", + }, + { + "prevent skipping review", + `before update where old.status = "in progress" and new.status = "done" deny "tasks must go through review before completion"`, + "update", + }, + { + "protect high priority from demotion", + `before update where old.priority = 1 and old.status = "in progress" and new.priority > 1 deny "cannot demote priority of active critical tasks"`, + "update", + }, + { + "no empty epics", + `before update where new.status = "done" and new.type = "epic" and blocks(new.id) is empty deny "epic has no dependencies"`, + "update", + }, + { + "WIP limit", + `before update where new.status = "in progress" and count(select where assignee = new.assignee and status = "in progress") >= 3 deny "WIP limit reached for this assignee"`, + "update", + }, + { + "points required before start", + `before update where new.status = "in progress" and new.points = 0 deny "tasks must be estimated before starting work"`, + "update", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + trig, err := p.ParseTrigger(tt.input) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if trig.Timing != "before" { + t.Fatalf("expected before, got %s", trig.Timing) + } + if trig.Event != tt.event { + t.Fatalf("expected %s, got %s", tt.event, trig.Event) + } + if trig.Deny == nil { + t.Fatal("expected Deny, got nil") + } + if trig.Action != nil { + t.Fatal("expected nil Action in before-trigger") + } + }) + } +} + +func TestParseTrigger_AfterAction(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + event string + wantCreate bool + wantUpdate bool + wantDelete bool + wantRun bool + }{ + { + "recurring task create next", + `after update where new.status = "done" and old.recurrence is not empty create title=old.title priority=old.priority tags=old.tags recurrence=old.recurrence due=next_date(old.recurrence) status="ready"`, + "update", + true, false, false, false, + }, + { + "recurring task clear recurrence", + `after update where new.status = "done" and old.recurrence is not empty update where id = old.id set recurrence=empty`, + "update", + false, true, false, false, + }, + { + "auto assign urgent", + `after create where new.priority <= 2 and new.assignee is empty update where id = new.id set assignee="booleanmaybe"`, + "create", + false, true, false, false, + }, + { + "cascade epic completion", + `after update where new.status = "done" update where id in blocks(old.id) and type = "epic" and dependsOn all status = "done" set status="done"`, + "update", + false, true, false, false, + }, + { + "reopen epic on regression", + `after update where old.status = "done" and new.status != "done" update where id in blocks(old.id) and type = "epic" and status = "done" set status="in progress"`, + "update", + false, true, false, false, + }, + { + "auto tag bugs", + `after create where new.type = "bug" update where id = new.id set tags=new.tags + ["needs-triage"]`, + "create", + false, true, false, false, + }, + { + "propagate cancellation", + `after update where new.status = "cancelled" update where id in blocks(old.id) and status in ["backlog", "ready"] set status="cancelled"`, + "update", + false, true, false, false, + }, + { + "unblock on last blocker", + `after update where new.status = "done" update where old.id in dependsOn and dependsOn all status = "done" and status = "backlog" set status="ready"`, + "update", + false, true, false, false, + }, + { + "cleanup on delete", + `after delete update where old.id in dependsOn set dependsOn=dependsOn - [old.id]`, + "delete", + false, true, false, false, + }, + { + "auto delete stale", + `after update where new.status = "done" and old.updatedAt < now() - 2day delete where id = old.id`, + "update", + false, false, true, false, + }, + { + "run action", + `after update where new.status = "in progress" and "claude" in new.tags run("claude -p 'implement tiki " + old.id + "'")`, + "update", + false, false, false, true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + trig, err := p.ParseTrigger(tt.input) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if trig.Timing != "after" { + t.Fatalf("expected after, got %s", trig.Timing) + } + if trig.Event != tt.event { + t.Fatalf("expected %s, got %s", tt.event, trig.Event) + } + if trig.Deny != nil { + t.Fatal("expected nil Deny in after-trigger") + } + + if tt.wantRun { + if trig.Run == nil { + t.Fatal("expected Run action, got nil") + } + } else { + if trig.Action == nil { + t.Fatal("expected Action, got nil") + } + if tt.wantCreate && trig.Action.Create == nil { + t.Fatal("expected Create action") + } + if tt.wantUpdate && trig.Action.Update == nil { + t.Fatal("expected Update action") + } + if tt.wantDelete && trig.Action.Delete == nil { + t.Fatal("expected Delete action") + } + } + }) + } +} + +func TestParseTrigger_StructuralErrors(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + }{ + { + "before with action", + `before update where new.status = "done" update where id = old.id set status="done"`, + }, + { + "after with deny", + `after update where new.status = "done" deny "no"`, + }, + { + "before without deny", + `before update where new.status = "done"`, + }, + { + "after without action", + `after update where new.status = "done"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseTrigger(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + }) + } +} + +func TestParseTrigger_QualifiedRefsInWhere(t *testing.T) { + p := newTestParser() + + input := `before update where old.status = "in progress" and new.status = "done" deny "skip"` + trig, err := p.ParseTrigger(input) + if err != nil { + t.Fatalf("parse error: %v", err) + } + + bc, ok := trig.Where.(*BinaryCondition) + if !ok { + t.Fatalf("expected BinaryCondition, got %T", trig.Where) + } + + // check left side has old.status + leftCmp, ok := bc.Left.(*CompareExpr) + if !ok { + t.Fatalf("expected CompareExpr on left, got %T", bc.Left) + } + qr, ok := leftCmp.Left.(*QualifiedRef) + if !ok { + t.Fatalf("expected QualifiedRef, got %T", leftCmp.Left) + } + if qr.Qualifier != "old" || qr.Name != "status" { + t.Fatalf("expected old.status, got %s.%s", qr.Qualifier, qr.Name) + } + + // check right side has new.status + rightCmp, ok := bc.Right.(*CompareExpr) + if !ok { + t.Fatalf("expected CompareExpr on right, got %T", bc.Right) + } + qr2, ok := rightCmp.Left.(*QualifiedRef) + if !ok { + t.Fatalf("expected QualifiedRef, got %T", rightCmp.Left) + } + if qr2.Qualifier != "new" || qr2.Name != "status" { + t.Fatalf("expected new.status, got %s.%s", qr2.Qualifier, qr2.Name) + } +} + +func TestParseTrigger_NoWhereGuard(t *testing.T) { + p := newTestParser() + + input := `after delete update where old.id in dependsOn set dependsOn=dependsOn - [old.id]` + trig, err := p.ParseTrigger(input) + if err != nil { + t.Fatalf("parse error: %v", err) + } + if trig.Where != nil { + t.Fatal("expected nil Where guard") + } + if trig.Action == nil || trig.Action.Update == nil { + t.Fatal("expected Update action") + } +} diff --git a/ruki/validate.go b/ruki/validate.go new file mode 100644 index 0000000..3df6d11 --- /dev/null +++ b/ruki/validate.go @@ -0,0 +1,792 @@ +package ruki + +import "fmt" + +// validate.go — structural validation and semantic type-checking. + +// qualifierPolicy controls which old./new. qualifiers are allowed during validation. +type qualifierPolicy struct { + allowOld bool + allowNew bool +} + +// no qualifiers allowed (standalone statements). +var noQualifiers = qualifierPolicy{} + +func triggerQualifiers(event string) qualifierPolicy { + switch event { + case "create": + return qualifierPolicy{allowNew: true} + case "delete": + return qualifierPolicy{allowOld: true} + default: // "update" + return qualifierPolicy{allowOld: true, allowNew: true} + } +} + +// known builtins and their return types. +var builtinFuncs = map[string]struct { + returnType ValueType + minArgs int + maxArgs int +}{ + "count": {ValueInt, 1, 1}, + "now": {ValueTimestamp, 0, 0}, + "next_date": {ValueDate, 1, 1}, + "blocks": {ValueListRef, 1, 1}, + "contains": {ValueBool, 2, 2}, + "call": {ValueString, 1, 1}, + "user": {ValueString, 0, 0}, +} + +// --- structural validation --- + +func (p *Parser) validateStatement(s *Statement) error { + switch { + case s.Create != nil: + if len(s.Create.Assignments) == 0 { + return fmt.Errorf("create must have at least one assignment") + } + return p.validateAssignments(s.Create.Assignments) + case s.Update != nil: + if len(s.Update.Set) == 0 { + return fmt.Errorf("update must have at least one assignment in set") + } + if err := p.validateCondition(s.Update.Where); err != nil { + return err + } + return p.validateAssignments(s.Update.Set) + case s.Delete != nil: + return p.validateCondition(s.Delete.Where) + case s.Select != nil: + if s.Select.Where != nil { + return p.validateCondition(s.Select.Where) + } + return nil + default: + return fmt.Errorf("empty statement") + } +} + +func (p *Parser) validateTrigger(t *Trigger) error { + if t.Timing == "before" { + if t.Action != nil || t.Run != nil { + return fmt.Errorf("before-trigger must not have an action") + } + if t.Deny == nil { + return fmt.Errorf("before-trigger must have deny") + } + } + if t.Timing == "after" { + if t.Deny != nil { + return fmt.Errorf("after-trigger must not have deny") + } + if t.Action == nil && t.Run == nil { + return fmt.Errorf("after-trigger must have an action") + } + } + + if t.Where != nil { + if err := p.validateCondition(t.Where); err != nil { + return err + } + } + + if t.Action != nil { + if t.Action.Select != nil { + return fmt.Errorf("trigger action must not be select") + } + if err := p.validateStatement(t.Action); err != nil { + return err + } + } + + if t.Run != nil { + typ, err := p.inferExprType(t.Run.Command) + if err != nil { + return fmt.Errorf("run command: %w", err) + } + if typ != ValueString { + return fmt.Errorf("run command must be string, got %s", typeName(typ)) + } + } + + return nil +} + +func (p *Parser) validateAssignments(assignments []Assignment) error { + seen := make(map[string]struct{}, len(assignments)) + for _, a := range assignments { + if _, dup := seen[a.Field]; dup { + return fmt.Errorf("duplicate assignment to field %q", a.Field) + } + seen[a.Field] = struct{}{} + fs, ok := p.schema.Field(a.Field) + if !ok { + return fmt.Errorf("unknown field %q in assignment", a.Field) + } + rhsType, err := p.inferExprType(a.Value) + if err != nil { + return fmt.Errorf("field %q: %w", a.Field, err) + } + if err := p.checkAssignmentCompat(fs.Type, rhsType, a.Value); err != nil { + return fmt.Errorf("field %q: %w", a.Field, err) + } + } + return nil +} + +// --- condition validation with type-checking --- + +func (p *Parser) validateCondition(c Condition) error { + switch c := c.(type) { + case *BinaryCondition: + if err := p.validateCondition(c.Left); err != nil { + return err + } + return p.validateCondition(c.Right) + + case *NotCondition: + return p.validateCondition(c.Inner) + + case *CompareExpr: + return p.validateCompare(c) + + case *IsEmptyExpr: + _, err := p.inferExprType(c.Expr) + return err + + case *InExpr: + return p.validateIn(c) + + case *QuantifierExpr: + return p.validateQuantifier(c) + + default: + return fmt.Errorf("unknown condition type %T", c) + } +} + +func (p *Parser) validateCompare(c *CompareExpr) error { + leftType, err := p.inferExprType(c.Left) + if err != nil { + return err + } + rightType, err := p.inferExprType(c.Right) + if err != nil { + return err + } + + // resolve empty from context + leftType, rightType = resolveEmptyPair(leftType, rightType) + + if !typesCompatible(leftType, rightType) { + return fmt.Errorf("cannot compare %s %s %s", typeName(leftType), c.Op, typeName(rightType)) + } + + // reject cross-type comparisons involving enum fields, + // unless the other side is a string literal (e.g. status = "done") + if err := p.checkCompareCompat(leftType, rightType, c.Left, c.Right); err != nil { + return err + } + + // use the most specific type for operator and enum validation + enumType := leftType + if rightType == ValueStatus || rightType == ValueTaskType { + enumType = rightType + } + + if err := checkCompareOp(enumType, c.Op); err != nil { + return err + } + + return p.validateEnumLiterals(c.Left, c.Right, enumType) +} + +func (p *Parser) validateIn(c *InExpr) error { + valType, err := p.inferExprType(c.Value) + if err != nil { + return err + } + + // infer collection type first — this validates list homogeneity + collType, err := p.inferExprType(c.Collection) + if err != nil { + return err + } + + // get the actual element type, checking literal elements directly + elemType, err := p.inferListElementType(c.Collection) + if err != nil { + return err + } + + if listElementType(collType) == -1 { + return fmt.Errorf("%s is not a collection type; use contains() for substring checks", typeName(collType)) + } + + if !membershipCompatible(valType, elemType) { + // allow string-like values in list literals whose elements are all string literals + ll, isLiteral := c.Collection.(*ListLiteral) + if !isLiteral || !isStringLike(valType) || !allStringLiterals(ll) { + return fmt.Errorf("element type mismatch: %s in %s", typeName(valType), typeName(collType)) + } + } + + return p.validateEnumListElements(c.Collection, valType) +} + +func (p *Parser) validateQuantifier(q *QuantifierExpr) error { + exprType, err := p.inferExprType(q.Expr) + if err != nil { + return err + } + if exprType != ValueListRef { + return fmt.Errorf("quantifier %s requires list, got %s", q.Kind, typeName(exprType)) + } + saved := p.qualifiers + p.qualifiers = noQualifiers + err = p.validateCondition(q.Condition) + p.qualifiers = saved + return err +} + +// --- type inference --- + +func (p *Parser) inferExprType(e Expr) (ValueType, error) { + switch e := e.(type) { + case *FieldRef: + fs, ok := p.schema.Field(e.Name) + if !ok { + return 0, fmt.Errorf("unknown field %q", e.Name) + } + return fs.Type, nil + + case *QualifiedRef: + if e.Qualifier == "old" && !p.qualifiers.allowOld { + return 0, fmt.Errorf("old. qualifier is not valid in this context") + } + if e.Qualifier == "new" && !p.qualifiers.allowNew { + return 0, fmt.Errorf("new. qualifier is not valid in this context") + } + fs, ok := p.schema.Field(e.Name) + if !ok { + return 0, fmt.Errorf("unknown field %q in %s.%s", e.Name, e.Qualifier, e.Name) + } + return fs.Type, nil + + case *StringLiteral: + return ValueString, nil + + case *IntLiteral: + return ValueInt, nil + + case *DateLiteral: + return ValueDate, nil + + case *DurationLiteral: + return ValueDuration, nil + + case *ListLiteral: + return p.inferListType(e) + + case *EmptyLiteral: + return -1, nil // sentinel: resolved from context + + case *FunctionCall: + return p.inferFuncCallType(e) + + case *BinaryExpr: + return p.inferBinaryExprType(e) + + case *SubQuery: + return 0, fmt.Errorf("subquery is only valid as argument to count()") + + default: + return 0, fmt.Errorf("unknown expression type %T", e) + } +} + +func (p *Parser) inferListType(l *ListLiteral) (ValueType, error) { + if len(l.Elements) == 0 { + return ValueListString, nil // default empty list type + } + firstType, err := p.inferExprType(l.Elements[0]) + if err != nil { + return 0, err + } + for i := 1; i < len(l.Elements); i++ { + t, err := p.inferExprType(l.Elements[i]) + if err != nil { + return 0, err + } + if !typesCompatible(firstType, t) { + return 0, fmt.Errorf("list elements must be the same type: got %s and %s", typeName(firstType), typeName(t)) + } + } + switch firstType { + case ValueRef, ValueID: + return ValueListRef, nil + default: + return ValueListString, nil + } +} + +// inferListElementType returns the element type of a list expression, +// checking literal elements directly when the list type enum is too coarse. +func (p *Parser) inferListElementType(e Expr) (ValueType, error) { + if ll, ok := e.(*ListLiteral); ok && len(ll.Elements) > 0 { + return p.inferExprType(ll.Elements[0]) + } + collType, err := p.inferExprType(e) + if err != nil { + return 0, err + } + elem := listElementType(collType) + if elem == -1 { + return collType, nil // not a list type — return as-is for error reporting + } + return elem, nil +} + +func (p *Parser) inferFuncCallType(fc *FunctionCall) (ValueType, error) { + builtin, ok := builtinFuncs[fc.Name] + if !ok { + return 0, fmt.Errorf("unknown function %q", fc.Name) + } + if len(fc.Args) < builtin.minArgs || len(fc.Args) > builtin.maxArgs { + if builtin.minArgs == builtin.maxArgs { + return 0, fmt.Errorf("%s() expects %d argument(s), got %d", fc.Name, builtin.minArgs, len(fc.Args)) + } + return 0, fmt.Errorf("%s() expects %d-%d arguments, got %d", fc.Name, builtin.minArgs, builtin.maxArgs, len(fc.Args)) + } + + // validate argument types for specific functions + switch fc.Name { + case "count": + sq, ok := fc.Args[0].(*SubQuery) + if !ok { + return 0, fmt.Errorf("count() argument must be a select subquery") + } + if sq.Where != nil { + if err := p.validateCondition(sq.Where); err != nil { + return 0, fmt.Errorf("count() subquery: %w", err) + } + } + case "blocks": + argType, err := p.inferExprType(fc.Args[0]) + if err != nil { + return 0, err + } + if argType != ValueID && argType != ValueRef && argType != ValueString { + return 0, fmt.Errorf("blocks() argument must be an id or ref, got %s", typeName(argType)) + } + if argType == ValueString { + if _, ok := fc.Args[0].(*StringLiteral); !ok { + return 0, fmt.Errorf("blocks() argument must be an id or ref, got %s", typeName(argType)) + } + } + case "contains": + for i, arg := range fc.Args { + t, err := p.inferExprType(arg) + if err != nil { + return 0, err + } + if t != ValueString { + return 0, fmt.Errorf("contains() argument %d must be string, got %s", i+1, typeName(t)) + } + } + case "call": + t, err := p.inferExprType(fc.Args[0]) + if err != nil { + return 0, err + } + if t != ValueString { + return 0, fmt.Errorf("call() argument must be string, got %s", typeName(t)) + } + case "next_date": + t, err := p.inferExprType(fc.Args[0]) + if err != nil { + return 0, err + } + if t != ValueRecurrence { + return 0, fmt.Errorf("next_date() argument must be recurrence, got %s", typeName(t)) + } + } + + return builtin.returnType, nil +} + +func (p *Parser) inferBinaryExprType(b *BinaryExpr) (ValueType, error) { + leftType, err := p.inferExprType(b.Left) + if err != nil { + return 0, err + } + rightType, err := p.inferExprType(b.Right) + if err != nil { + return 0, err + } + + leftType, rightType = resolveEmptyPair(leftType, rightType) + + switch b.Op { + case "+": + return p.inferPlusType(leftType, rightType, b.Right) + case "-": + return p.inferMinusType(leftType, rightType, b.Right) + default: + return 0, fmt.Errorf("unknown binary operator %q", b.Op) + } +} + +func isStringLike(t ValueType) bool { + switch t { + case ValueString, ValueStatus, ValueTaskType, ValueID, ValueRef: + return true + default: + return false + } +} + +func (p *Parser) inferPlusType(left, right ValueType, rightExpr Expr) (ValueType, error) { + switch { + case isStringLike(left) && isStringLike(right): + return ValueString, nil + case left == ValueInt && right == ValueInt: + return ValueInt, nil + case left == ValueListString && (right == ValueString || right == ValueListString): + return ValueListString, nil + case left == ValueListRef && (isRefCompatible(right) || right == ValueListRef): + return ValueListRef, nil + case left == ValueListRef && right == ValueString: + if _, ok := rightExpr.(*StringLiteral); ok { + return ValueListRef, nil + } + return 0, fmt.Errorf("cannot add %s + %s", typeName(left), typeName(right)) + case left == ValueListRef && right == ValueListString: + if _, ok := rightExpr.(*ListLiteral); ok { + return ValueListRef, nil + } + return 0, fmt.Errorf("cannot add list field to list") + case left == ValueDate && right == ValueDuration: + return ValueDate, nil + case left == ValueTimestamp && right == ValueDuration: + return ValueTimestamp, nil + default: + return 0, fmt.Errorf("cannot add %s + %s", typeName(left), typeName(right)) + } +} + +func (p *Parser) inferMinusType(left, right ValueType, rightExpr Expr) (ValueType, error) { + switch { + case left == ValueListString && (right == ValueString || right == ValueListString): + return ValueListString, nil + case left == ValueListRef && (isRefCompatible(right) || right == ValueListRef): + return ValueListRef, nil + case left == ValueListRef && right == ValueString: + if _, ok := rightExpr.(*StringLiteral); ok { + return ValueListRef, nil + } + return 0, fmt.Errorf("cannot subtract %s - %s", typeName(left), typeName(right)) + case left == ValueListRef && right == ValueListString: + if _, ok := rightExpr.(*ListLiteral); ok { + return ValueListRef, nil + } + return 0, fmt.Errorf("cannot subtract list field from list") + case left == ValueInt && right == ValueInt: + return ValueInt, nil + case left == ValueDate && right == ValueDuration: + return ValueDate, nil + case left == ValueDate && right == ValueDate: + return ValueDuration, nil + case left == ValueTimestamp && right == ValueDuration: + return ValueTimestamp, nil + case left == ValueTimestamp && right == ValueTimestamp: + return ValueDuration, nil + default: + return 0, fmt.Errorf("cannot subtract %s - %s", typeName(left), typeName(right)) + } +} + +// --- enum literal validation --- + +func (p *Parser) validateEnumLiterals(left, right Expr, resolvedType ValueType) error { + if resolvedType == ValueStatus { + if s, ok := right.(*StringLiteral); ok { + if _, valid := p.schema.NormalizeStatus(s.Value); !valid { + return fmt.Errorf("unknown status %q", s.Value) + } + } + if s, ok := left.(*StringLiteral); ok { + if _, valid := p.schema.NormalizeStatus(s.Value); !valid { + return fmt.Errorf("unknown status %q", s.Value) + } + } + } + if resolvedType == ValueTaskType { + if s, ok := right.(*StringLiteral); ok { + if _, valid := p.schema.NormalizeType(s.Value); !valid { + return fmt.Errorf("unknown type %q", s.Value) + } + } + if s, ok := left.(*StringLiteral); ok { + if _, valid := p.schema.NormalizeType(s.Value); !valid { + return fmt.Errorf("unknown type %q", s.Value) + } + } + } + return nil +} + +// validateEnumListElements checks string literals inside a list expression +// against the appropriate enum normalizer, based on the value type being checked. +func (p *Parser) validateEnumListElements(collection Expr, valType ValueType) error { + ll, ok := collection.(*ListLiteral) + if !ok { + return nil + } + for _, elem := range ll.Elements { + s, ok := elem.(*StringLiteral) + if !ok { + continue + } + if valType == ValueStatus { + if _, valid := p.schema.NormalizeStatus(s.Value); !valid { + return fmt.Errorf("unknown status %q", s.Value) + } + } + if valType == ValueTaskType { + if _, valid := p.schema.NormalizeType(s.Value); !valid { + return fmt.Errorf("unknown type %q", s.Value) + } + } + } + return nil +} + +// --- assignment compatibility --- + +func (p *Parser) checkAssignmentCompat(fieldType, rhsType ValueType, rhs Expr) error { + // empty is assignable to anything + if _, ok := rhs.(*EmptyLiteral); ok { + return nil + } + if rhsType == -1 { // unresolved empty + return nil + } + + if typesCompatible(fieldType, rhsType) { + // enum fields only accept same-type or string literals + if (fieldType == ValueStatus || fieldType == ValueTaskType) && rhsType != fieldType { + if _, ok := rhs.(*StringLiteral); !ok { + return fmt.Errorf("cannot assign %s to %s field", typeName(rhsType), typeName(fieldType)) + } + } + // non-enum string-like fields reject enum-typed RHS + if (fieldType == ValueString || fieldType == ValueID || fieldType == ValueRef) && + (rhsType == ValueStatus || rhsType == ValueTaskType) { + return fmt.Errorf("cannot assign %s to %s field", typeName(rhsType), typeName(fieldType)) + } + + // list field rejects list literals with non-string elements + if fieldType == ValueListString { + if ll, ok := rhs.(*ListLiteral); ok { + for _, elem := range ll.Elements { + elemType, err := p.inferExprType(elem) + if err == nil && elemType != ValueString { + if _, isLit := elem.(*StringLiteral); !isLit { + return fmt.Errorf("cannot assign %s to list field", typeName(elemType)) + } + } + } + } + } + + // validate enum values + if fieldType == ValueStatus { + if s, ok := rhs.(*StringLiteral); ok { + if _, valid := p.schema.NormalizeStatus(s.Value); !valid { + return fmt.Errorf("unknown status %q", s.Value) + } + } + } + if fieldType == ValueTaskType { + if s, ok := rhs.(*StringLiteral); ok { + if _, valid := p.schema.NormalizeType(s.Value); !valid { + return fmt.Errorf("unknown type %q", s.Value) + } + } + } + return nil + } + + // list literal is assignable to list, but only if all elements are string literals + if fieldType == ValueListRef && rhsType == ValueListString { + if ll, ok := rhs.(*ListLiteral); ok && allStringLiterals(ll) { + return nil + } + } + + return fmt.Errorf("cannot assign %s to %s field", typeName(rhsType), typeName(fieldType)) +} + +// --- type helpers --- + +func typesCompatible(a, b ValueType) bool { + if a == b { + return true + } + if a == -1 || b == -1 { // unresolved empty + return true + } + // string-like types are compatible with each other + stringLike := map[ValueType]bool{ + ValueString: true, + ValueStatus: true, + ValueTaskType: true, + ValueID: true, + ValueRef: true, + } + return stringLike[a] && stringLike[b] +} + +func isEnumType(t ValueType) bool { + return t == ValueStatus || t == ValueTaskType +} + +// allStringLiterals returns true if every element in the list is a *StringLiteral. +func allStringLiterals(ll *ListLiteral) bool { + for _, elem := range ll.Elements { + if _, ok := elem.(*StringLiteral); !ok { + return false + } + } + return true +} + +// checkCompareCompat rejects nonsensical cross-type comparisons in WHERE clauses. +// e.g. status = title (enum vs string field) is rejected, +// but status = "done" (enum vs string literal) is allowed. +func (p *Parser) checkCompareCompat(leftType, rightType ValueType, left, right Expr) error { + if isEnumType(leftType) && rightType != leftType { + if err := checkEnumOperand(leftType, rightType, right); err != nil { + return err + } + } + if isEnumType(rightType) && leftType != rightType { + if err := checkEnumOperand(rightType, leftType, left); err != nil { + return err + } + } + return nil +} + +func checkEnumOperand(enumType, otherType ValueType, other Expr) error { + if otherType == ValueString { + if _, ok := other.(*StringLiteral); !ok { + return fmt.Errorf("cannot compare %s with %s field", typeName(enumType), typeName(otherType)) + } + return nil + } + return fmt.Errorf("cannot compare %s with %s", typeName(enumType), typeName(otherType)) +} + +// membershipCompatible checks strict type compatibility for in/not in +// expressions. Unlike typesCompatible, it does not treat all string-like +// types as interchangeable — only ID and Ref are interchangeable. +func membershipCompatible(a, b ValueType) bool { + if a == b { + return true + } + if a == -1 || b == -1 { + return true + } + // ID and Ref are the same concept + if (a == ValueID || a == ValueRef) && (b == ValueID || b == ValueRef) { + return true + } + return false +} + +// isRefCompatible returns true for types that can appear as operands +// in list add/remove operations. +func isRefCompatible(t ValueType) bool { + switch t { + case ValueRef, ValueID: + return true + default: + return false + } +} + +func resolveEmptyPair(a, b ValueType) (ValueType, ValueType) { + if a == -1 && b != -1 { + a = b + } + if b == -1 && a != -1 { + b = a + } + return a, b +} + +func listElementType(t ValueType) ValueType { + switch t { + case ValueListString: + return ValueString + case ValueListRef: + return ValueRef + default: + return -1 + } +} + +func checkCompareOp(t ValueType, op string) error { + switch op { + case "=", "!=": + return nil // all types support equality + case "<", ">", "<=", ">=": + switch t { + case ValueInt, ValueDate, ValueTimestamp, ValueDuration: + return nil + default: + return fmt.Errorf("operator %s not supported for %s", op, typeName(t)) + } + default: + return fmt.Errorf("unknown operator %q", op) + } +} + +func typeName(t ValueType) string { + switch t { + case ValueString: + return "string" + case ValueInt: + return "int" + case ValueDate: + return "date" + case ValueTimestamp: + return "timestamp" + case ValueDuration: + return "duration" + case ValueBool: + return "bool" + case ValueID: + return "id" + case ValueRef: + return "ref" + case ValueRecurrence: + return "recurrence" + case ValueListString: + return "list" + case ValueListRef: + return "list" + case ValueStatus: + return "status" + case ValueTaskType: + return "type" + case -1: + return "empty" + default: + return "unknown" + } +} diff --git a/ruki/validate_test.go b/ruki/validate_test.go new file mode 100644 index 0000000..bf7174e --- /dev/null +++ b/ruki/validate_test.go @@ -0,0 +1,1561 @@ +package ruki + +import ( + "strings" + "testing" +) + +func TestValidation_TypeMismatch(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "priority equals string", + `select where priority = "high"`, + "cannot compare", + }, + { + "string field ordered compare", + `select where status < "done"`, + "operator < not supported", + }, + { + "int to string assignment", + `create title="x" priority="high"`, + "cannot assign string to int field", + }, + { + "string to int field", + `create title="x" points="five"`, + "cannot assign string to int field", + }, + { + "int to string field", + `create title="x" assignee=42`, + "cannot assign int to string field", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestValidation_UnknownField(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + {"unknown field in where", `select where foo = "bar"`, "unknown field"}, + {"unknown field in assignment", `create title="x" foo="bar"`, "unknown field"}, + {"unknown qualified field in statement", `select where old.foo = "bar"`, "old. qualifier is not valid"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestValidation_UnknownFunction(t *testing.T) { + p := newTestParser() + + _, err := p.ParseStatement(`select where foo(1) = 1`) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "unknown function") { + t.Fatalf("expected unknown function error, got: %v", err) + } +} + +func TestValidation_FunctionArgCount(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + }{ + {"now with args", `select where now(1) = now()`}, + {"count no args", `select where count() >= 1`}, + {"contains one arg", `select where contains("a") = "b"`}, + {"user with args", `select where user(1) = "bob"`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "argument") { + t.Fatalf("expected argument count error, got: %v", err) + } + }) + } +} + +func TestValidation_QuantifierRequiresListRef(t *testing.T) { + p := newTestParser() + + // tags is list, not list + _, err := p.ParseStatement(`select where tags any status = "done"`) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "requires list") { + t.Fatalf("expected list error, got: %v", err) + } +} + +func TestValidation_CountRequiresSubquery(t *testing.T) { + p := newTestParser() + + _, err := p.ParseStatement(`select where count(1) >= 3`) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "subquery") { + t.Fatalf("expected subquery error, got: %v", err) + } +} + +func TestValidation_UnknownStatus(t *testing.T) { + p := newTestParser() + + _, err := p.ParseStatement(`select where status = "nonexistent"`) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "unknown status") { + t.Fatalf("expected unknown status error, got: %v", err) + } +} + +func TestValidation_UnknownType(t *testing.T) { + p := newTestParser() + + _, err := p.ParseStatement(`select where type = "nonexistent"`) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "unknown type") { + t.Fatalf("expected unknown type error, got: %v", err) + } +} + +func TestValidation_ValidStatusAndType(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + }{ + {"valid status", `select where status = "done"`}, + {"valid status alias", `select where status = "in progress"`}, + {"valid type", `select where type = "bug"`}, + {"valid type alias", `select where type = "feature"`}, + {"valid status in assignment", `create title="x" status="done"`}, + {"valid type in assignment", `create title="x" type="bug"`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_BinaryExprTypes(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "date minus date yields duration — assigned to date is wrong", + `create title="x" due=2026-03-25 - 2026-03-20`, + "cannot assign duration to date field", + }, + { + "string minus string", + `create title="x" - "y"`, + "cannot subtract", + }, + { + "int plus string", + `create title="x" priority=1 + "a"`, + "cannot add", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestValidation_ValidBinaryExprTypes(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + }{ + {"string concat", `create title="hello" + " world"`}, + {"list append", `create title="x" tags=tags + ["new"]`}, + {"list remove", `create title="x" tags=tags - ["old"]`}, + {"date plus duration", `create title="x" due=2026-03-25 + 2day`}, + {"date minus duration", `create title="x" due=2026-03-25 - 1week`}, + {"list ref append", `create title="x" dependsOn=dependsOn + ["TIKI-ABC123"]`}, + {"list ref remove", `create title="x" dependsOn=dependsOn - ["TIKI-ABC123"]`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_RunCommandMustBeString(t *testing.T) { + p := newTestParser() + + // int expression in run() should be rejected + _, err := p.ParseTrigger(`after update run(1 + 2)`) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "must be string") { + t.Fatalf("expected string type error, got: %v", err) + } + + // valid: string expression in run() + _, err = p.ParseTrigger(`after update run("echo hello")`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidation_EmptyAssignments(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + }{ + {"empty to string field", `create title="x" assignee=empty`}, + {"empty to list field", `create title="x" tags=empty`}, + {"empty to date field", `create title="x" due=empty`}, + {"empty to int field", `create title="x" priority=empty`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_IsEmptyOnAllTypes(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + }{ + {"string is empty", `select where assignee is empty`}, + {"list is empty", `select where tags is empty`}, + {"date is empty", `select where due is empty`}, + {"int is empty", `select where priority is empty`}, + {"string is not empty", `select where title is not empty`}, + {"function result is empty", `select where blocks(id) is empty`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_SelectNotAllowedAsTriggerAction(t *testing.T) { + p := newTestParser() + + // select is rejected at parse level since the action grammar doesn't include it + _, err := p.ParseTrigger(`after update select`) + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestValidation_FunctionArgTypes(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "blocks with non-id arg", + `select where blocks(priority) is empty`, + "blocks() argument must be an id or ref", + }, + { + "contains with non-string first arg", + `select where contains(1, "a") = contains("a", "b")`, + "contains() argument 1 must be string", + }, + { + "contains with non-string second arg", + `select where contains("a", 1) = contains("a", "b")`, + "contains() argument 2 must be string", + }, + { + "call with non-string arg", + `create title=call(42)`, + "call() argument must be string", + }, + { + "next_date with non-recurrence arg", + `create title="x" due=next_date(42)`, + "next_date() argument must be recurrence", + }, + { + "next_date with string field arg", + `create title="x" due=next_date(title)`, + "next_date() argument must be recurrence", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestValidation_ValidFunctionUsages(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + }{ + {"blocks with id field", `select where blocks(id) is empty`}, + {"blocks with id ref", `select where blocks("TIKI-ABC123") is empty`}, + {"call with string", `create title=call("echo hi")`}, + {"contains", `select where contains(title, "bug") = contains(title, "fix")`}, + {"user", `select where assignee = user()`}, + {"now", `select where updatedAt < now()`}, + {"count with subquery", `select where count(select where status = "done") >= 1`}, + {"count with bare select", `select where count(select) >= 0`}, + {"next_date with recurrence field", `create title="x" due=next_date(recurrence)`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_InExprTypes(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "string in list of ints — element type mismatch", + `select where "bug" in [1, 2]`, + "element type mismatch", + }, + { + "scalar field as collection — not a collection", + `select where "d" in title`, + "not a collection", + }, + { + "scalar string field as collection", + `select where "x" in assignee`, + "not a collection", + }, + { + "scalar int field as collection", + `select where 1 in priority`, + "not a collection", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestValidation_EnumInList(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "invalid status in list", + `select where status in ["done", "bogus"]`, + "unknown status", + }, + { + "invalid type in list", + `select where type in ["bug", "bogus"]`, + "unknown type", + }, + { + "all invalid statuses in list", + `select where status in ["nope", "nada"]`, + "unknown status", + }, + { + "invalid status in not-in list", + `select where status not in ["done", "bogus"]`, + "unknown status", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestValidation_ValidInExpr(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + }{ + {"string in tags field", `select where "bug" in tags`}, + {"id in dependsOn field", `select where id in dependsOn`}, + {"status in list", `select where status in ["done", "cancelled"]`}, + {"status not in list", `select where status not in ["done"]`}, + {"int in list", `select where priority in [1, 2, 3]`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_TimestampArithmetic(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + }{ + {"timestamp plus duration", `select where updatedAt < now() + 1day`}, + {"timestamp minus duration", `select where updatedAt > now() - 1week`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_EmptyComparisons(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + }{ + {"field equals empty", `select where assignee = empty`}, + {"empty not equal field", `select where priority != empty`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_UnknownStatusInAssignment(t *testing.T) { + p := newTestParser() + + _, err := p.ParseStatement(`create title="x" status="nonexistent"`) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "unknown status") { + t.Fatalf("expected unknown status error, got: %v", err) + } +} + +func TestValidation_UnknownTypeInAssignment(t *testing.T) { + p := newTestParser() + + _, err := p.ParseStatement(`create title="x" type="nonexistent"`) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "unknown type") { + t.Fatalf("expected unknown type error, got: %v", err) + } +} + +func TestValidation_StatusOnLeftSide(t *testing.T) { + p := newTestParser() + + // status literal on the left side of comparison + _, err := p.ParseStatement(`select where "nonexistent" = status`) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "unknown status") { + t.Fatalf("expected unknown status error, got: %v", err) + } +} + +func TestValidation_TypeOnLeftSide(t *testing.T) { + p := newTestParser() + + _, err := p.ParseStatement(`select where "nonexistent" = type`) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "unknown type") { + t.Fatalf("expected unknown type error, got: %v", err) + } +} + +func TestValidation_DurationCompare(t *testing.T) { + p := newTestParser() + + // duration supports ordering operators + _, err := p.ParseStatement(`select where updatedAt - createdAt > 7day`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidation_ListHomogeneity(t *testing.T) { + p := newTestParser() + + _, err := p.ParseStatement(`select where status in ["done", 1]`) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "list elements must be the same type") { + t.Fatalf("expected homogeneity error, got: %v", err) + } +} + +func TestValidation_NestedConditions(t *testing.T) { + p := newTestParser() + + // exercise not + or paths + tests := []struct { + name string + input string + }{ + {"not with or", `select where not (status = "done" or priority = 1)`}, + {"double not", `select where not not status = "done"`}, + {"or chain", `select where status = "done" or status = "ready" or status = "backlog"`}, + {"and chain", `select where priority = 1 and status = "done" and assignee = "bob"`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_TriggerCreateAction(t *testing.T) { + p := newTestParser() + + // after-trigger with create action + _, err := p.ParseTrigger(`after update where new.status = "done" create title="follow-up" priority=3`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidation_TriggerDeleteAction(t *testing.T) { + p := newTestParser() + + _, err := p.ParseTrigger(`after update where new.status = "done" delete where id = old.id`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidation_ParenExpr(t *testing.T) { + p := newTestParser() + + // parenthesized expression + _, err := p.ParseStatement(`create title="x" priority=(1 + 2)`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidation_MoreBinaryExprErrors(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "date plus date", + `select where due = 2026-03-25 + 2026-03-20`, + "cannot add", + }, + { + "int minus string", + `create title="x" priority=1 - "a"`, + "cannot subtract", + }, + { + "duration minus duration", + `create title="x" due=1day - 2day`, + "cannot subtract", + }, + { + "bool in comparison", + `select where contains("a", "b") < contains("c", "d")`, + "operator < not supported for bool", + }, + { + "list ordered compare", + `select where tags < ["a"]`, + "operator < not supported", + }, + { + "recurrence ordered compare", + `select where recurrence < recurrence`, + "operator < not supported", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestValidation_TimestampMinusTimestamp(t *testing.T) { + p := newTestParser() + + // timestamp - timestamp = duration; comparing to duration + _, err := p.ParseStatement(`select where updatedAt - createdAt > 1day`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidation_EmptyInBinaryExpr(t *testing.T) { + p := newTestParser() + + // empty + empty — should resolve but might fail on the operator + _, err := p.ParseStatement(`create title="x" tags=empty + empty`) + // this may error — just exercise the code path + _ = err +} + +func TestValidation_DateCompareOps(t *testing.T) { + p := newTestParser() + + ops := []string{"=", "!=", "<", ">", "<=", ">="} + for _, op := range ops { + t.Run(op, func(t *testing.T) { + input := `select where due ` + op + ` 2026-03-25` + _, err := p.ParseStatement(input) + if err != nil { + t.Fatalf("unexpected error for %s: %v", op, err) + } + }) + } +} + +func TestValidation_IntCompareOps(t *testing.T) { + p := newTestParser() + + ops := []string{"=", "!=", "<", ">", "<=", ">="} + for _, op := range ops { + t.Run(op, func(t *testing.T) { + input := `select where priority ` + op + ` 3` + _, err := p.ParseStatement(input) + if err != nil { + t.Fatalf("unexpected error for %s: %v", op, err) + } + }) + } +} + +func TestValidation_StringCompareOps(t *testing.T) { + p := newTestParser() + + // equality should work + _, err := p.ParseStatement(`select where title = "hello"`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // inequality should work + _, err = p.ParseStatement(`select where title != "hello"`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // ordering should fail + _, err = p.ParseStatement(`select where title < "hello"`) + if err == nil { + t.Fatal("expected error for title < string") + } +} + +func TestValidation_IDCompare(t *testing.T) { + p := newTestParser() + + // id equality + _, err := p.ParseStatement(`select where id = "TIKI-ABC123"`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // id ordering should fail + _, err = p.ParseStatement(`select where id < "TIKI-ABC123"`) + if err == nil { + t.Fatal("expected error for id < string") + } +} + +func TestValidation_BareSubqueryRejected(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "subquery in where comparison", + `select where select = 1`, + "subquery", + }, + { + "subquery in create assignment", + `create title=select`, + "subquery", + }, + { + "subquery in update assignment", + `update where status = "done" set title=select`, + "subquery", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestValidation_QualifiedRefInStatement(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "old ref in select", + `select where old.status = "done"`, + "old.", + }, + { + "new ref in select", + `select where new.status = "done"`, + "new.", + }, + { + "old ref in create", + `create title=old.title`, + "old.", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestValidation_EnumAssignmentStrictness(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "string field to status", + `create title="x" status=title`, + "cannot assign string to status", + }, + { + "id field to status", + `create title="x" status=id`, + "cannot assign id to status", + }, + { + "string field to type", + `create title="x" type=title`, + "cannot assign string to type", + }, + { + "status field to string", + `update where id="x" set title=status`, + "cannot assign status to string", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestValidation_InExprStrictTypes(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "id in tags", + `select where id in tags`, + "element type mismatch", + }, + { + "status in tags", + `select where status in tags`, + "element type mismatch", + }, + { + "status in dependsOn", + `select where status in dependsOn`, + "element type mismatch", + }, + { + "type in tags", + `select where type in tags`, + "element type mismatch", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestValidation_ListRefOperandStrictness(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "list ref plus status", + `create title="x" dependsOn=dependsOn + status`, + "cannot add", + }, + { + "list ref plus type", + `create title="x" dependsOn=dependsOn + type`, + "cannot add", + }, + { + "list ref minus status", + `create title="x" dependsOn=dependsOn - status`, + "cannot subtract", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestValidation_QualifiedRefInTrigger(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "old ref in after create guard", + `after create where old.status = "done" update where id = new.id set status="done"`, + "old.", + }, + { + "new ref in before delete", + `before delete where new.status = "done" deny "x"`, + "new.", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseTrigger(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestValidation_DependsOnListLiteralAssignment(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + }{ + {"single ref", `create title="x" dependsOn=["TIKI-ABC123"]`}, + {"multiple refs", `create title="x" dependsOn=["TIKI-ABC123", "TIKI-DEF456"]`}, + {"update set dependsOn", `update where id="TIKI-1" set dependsOn=["TIKI-ABC123"]`}, + {"empty list", `create title="x" dependsOn=[]`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_ListStringRejectsNonStringElements(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "int elements in tags", + `create title="x" tags=[1, 2]`, + "cannot assign", + }, + { + "date elements in tags", + `create title="x" tags=[2026-03-25]`, + "cannot assign", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestValidation_ListRefRejectsListStringField(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "list ref plus list string field", + `create title="x" dependsOn=dependsOn + tags`, + "cannot add list field to list", + }, + { + "list ref minus list string field", + `create title="x" dependsOn=dependsOn - tags`, + "cannot subtract list field from list", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } + + // regression: list literals still allowed with list + valid := []struct { + name string + input string + }{ + {"list ref plus list ref field", `create title="x" dependsOn=dependsOn + dependsOn`}, + {"list ref plus string literal list", `create title="x" dependsOn=dependsOn + ["TIKI-ABC123"]`}, + {"list ref minus string literal list", `create title="x" dependsOn=dependsOn - ["TIKI-ABC123"]`}, + } + + for _, tt := range valid { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_CountSubqueryValidated(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "unknown field in count subquery", + `select where count(select where nosuchfield = "x") >= 1`, + "unknown field", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } + + // valid: count with valid subquery still works + _, err := p.ParseStatement(`select where count(select where status = "done") >= 1`) + if err != nil { + t.Fatalf("unexpected error for valid count subquery: %v", err) + } + + // valid: count subquery can reference new. in trigger context (parameterized query) + _, err = p.ParseTrigger(`before update where new.status = "in progress" and count(select where assignee = new.assignee and status = "in progress") >= 3 deny "WIP limit"`) + if err != nil { + t.Fatalf("unexpected error for count subquery with new.: %v", err) + } +} + +func TestValidation_QuantifierNoQualifiers(t *testing.T) { + p := newTestParser() + + // old. inside quantifier body should be rejected even in update trigger + _, err := p.ParseTrigger(`before update where dependsOn any old.status = "done" deny "blocked"`) + if err == nil { + t.Fatal("expected error for old. in quantifier, got nil") + } + if !strings.Contains(err.Error(), "old.") { + t.Fatalf("expected old. qualifier error, got: %v", err) + } + + // new. inside quantifier body should also be rejected + _, err = p.ParseTrigger(`before update where dependsOn any new.status = "done" deny "blocked"`) + if err == nil { + t.Fatal("expected error for new. in quantifier, got nil") + } + if !strings.Contains(err.Error(), "new.") { + t.Fatalf("expected new. qualifier error, got: %v", err) + } + + // unqualified field inside quantifier should still work + _, err = p.ParseTrigger(`before update where dependsOn any status = "done" deny "blocked"`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidation_QualifiedRefValidInTrigger(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + }{ + { + "old and new in before update", + `before update where old.status = "in progress" and new.status = "done" deny "skip"`, + }, + { + "new in after create", + `after create where new.priority <= 2 update where id = new.id set assignee="bob"`, + }, + { + "old in after delete", + `after delete update where old.id in dependsOn set dependsOn=dependsOn - [old.id]`, + }, + { + "old and new in after update", + `after update where new.status = "done" update where id = old.id set recurrence=empty`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseTrigger(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_ListRefRejectsStringFields(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "list ref plus title field", + `create title="x" dependsOn=dependsOn + title`, + "cannot add", + }, + { + "list ref plus assignee field", + `create title="x" dependsOn=dependsOn + assignee`, + "cannot add", + }, + { + "list ref minus title field", + `create title="x" dependsOn=dependsOn - title`, + "cannot subtract", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } + + // string literals should still be allowed + valid := []struct { + name string + input string + }{ + {"list ref plus string literal", `create title="x" dependsOn=dependsOn + "TIKI-ABC123"`}, + {"list ref minus string literal", `create title="x" dependsOn=dependsOn - "TIKI-ABC123"`}, + } + + for _, tt := range valid { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_CompareEnumStrictness(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "status equals title field", + `select where status = title`, + "cannot compare", + }, + { + "status equals id field", + `select where status = id`, + "cannot compare", + }, + { + "status equals type field", + `select where status = type`, + "cannot compare", + }, + { + "type equals assignee field", + `select where type = assignee`, + "cannot compare", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } + + // these must remain valid + valid := []struct { + name string + input string + }{ + {"status equals string literal", `select where status = "done"`}, + {"type equals string literal", `select where type = "bug"`}, + {"id equals string literal", `select where id = "TIKI-ABC123"`}, + {"string field equals string literal", `select where assignee = "alice"`}, + } + + for _, tt := range valid { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_BlocksRejectsStringFields(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "blocks with title field", + `select where blocks(title) is empty`, + "blocks() argument must be an id or ref", + }, + { + "blocks with assignee field", + `select where blocks(assignee) is empty`, + "blocks() argument must be an id or ref", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } + + // these must remain valid + valid := []struct { + name string + input string + }{ + {"blocks with id field", `select where blocks(id) is empty`}, + {"blocks with string literal", `select where blocks("TIKI-ABC123") is empty`}, + } + + for _, tt := range valid { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_DuplicateAssignments(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "duplicate in create", + `create title="a" title="b"`, + "duplicate assignment", + }, + { + "duplicate in update set", + `update where id="x" set status="ready" status="done"`, + "duplicate assignment", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} + +func TestValidation_EnumInRejectsFieldRefs(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "status in list containing field ref", + `select where status in [title]`, + "element type mismatch", + }, + { + "type in list containing field ref", + `select where type in [assignee]`, + "element type mismatch", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } + + // string literals should still be allowed + valid := []struct { + name string + input string + }{ + {"status in string literal list", `select where status in ["done", "ready"]`}, + {"type in string literal list", `select where type in ["bug", "epic"]`}, + } + + for _, tt := range valid { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestValidation_ListAssignmentRejectsFieldRefs(t *testing.T) { + p := newTestParser() + + tests := []struct { + name string + input string + wantErr string + }{ + { + "tags with field ref element", + `create title="x" tags=["bug", id]`, + "cannot assign", + }, + { + "dependsOn with non-literal element", + `create title="x" dependsOn=["TIKI-ABC123", title]`, + "cannot assign", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := p.ParseStatement(tt.input) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got: %v", tt.wantErr, err) + } + }) + } +} diff --git a/task/type.go b/task/type.go index 7ab423d..6101a77 100644 --- a/task/type.go +++ b/task/type.go @@ -64,3 +64,14 @@ func TypeEmoji(taskType Type) string { func TypeDisplay(taskType Type) string { return currentTypeRegistry().TypeDisplay(taskType) } + +// ParseDisplay reverses a TypeDisplay() string back to a canonical key. +// Returns (key, true) on match, or (fallback, false) for unrecognized display strings. +func ParseDisplay(display string) (Type, bool) { + return currentTypeRegistry().ParseDisplay(display) +} + +// AllTypes returns the ordered list of all configured type keys. +func AllTypes() []Type { + return currentTypeRegistry().Keys() +} diff --git a/view/taskdetail/task_edit_fields.go b/view/taskdetail/task_edit_fields.go index fa893da..f44b138 100644 --- a/view/taskdetail/task_edit_fields.go +++ b/view/taskdetail/task_edit_fields.go @@ -33,11 +33,10 @@ func (ev *TaskEditView) ensureStatusSelectList(task *taskpkg.Task) *component.Ed func (ev *TaskEditView) ensureTypeSelectList(task *taskpkg.Task) *component.EditSelectList { if ev.typeSelectList == nil { - typeOptions := []string{ - taskpkg.TypeDisplay(taskpkg.TypeStory), - taskpkg.TypeDisplay(taskpkg.TypeBug), - taskpkg.TypeDisplay(taskpkg.TypeSpike), - taskpkg.TypeDisplay(taskpkg.TypeEpic), + allTypes := taskpkg.AllTypes() + typeOptions := make([]string, len(allTypes)) + for i, t := range allTypes { + typeOptions[i] = taskpkg.TypeDisplay(t) } colors := config.GetColors() diff --git a/workflow/fields.go b/workflow/fields.go new file mode 100644 index 0000000..2577a85 --- /dev/null +++ b/workflow/fields.go @@ -0,0 +1,68 @@ +package workflow + +// ValueType identifies the semantic type of a task field. +type ValueType int + +const ( + TypeString ValueType = iota + TypeInt // numeric (priority, points) + TypeDate // midnight-UTC date (e.g. due) + TypeTimestamp // full timestamp (e.g. createdAt, updatedAt) + TypeDuration // reserved for future use + TypeBool // reserved for future use + TypeID // task identifier + TypeRef // reference to another task ID + TypeRecurrence // structured cron-based recurrence pattern + TypeListString // []string (e.g. tags) + TypeListRef // []string of task ID references (e.g. dependsOn) + TypeStatus // workflow status enum backed by StatusRegistry + TypeTaskType // task type enum backed by TypeRegistry +) + +// FieldDef describes a single task field's name and semantic type. +type FieldDef struct { + Name string + Type ValueType +} + +// fieldCatalog is the authoritative list of DSL-visible task fields. +var fieldCatalog = []FieldDef{ + {Name: "id", Type: TypeID}, + {Name: "title", Type: TypeString}, + {Name: "description", Type: TypeString}, + {Name: "status", Type: TypeStatus}, + {Name: "type", Type: TypeTaskType}, + {Name: "tags", Type: TypeListString}, + {Name: "dependsOn", Type: TypeListRef}, + {Name: "due", Type: TypeDate}, + {Name: "recurrence", Type: TypeRecurrence}, + {Name: "assignee", Type: TypeString}, + {Name: "priority", Type: TypeInt}, + {Name: "points", Type: TypeInt}, + {Name: "createdBy", Type: TypeString}, + {Name: "createdAt", Type: TypeTimestamp}, + {Name: "updatedAt", Type: TypeTimestamp}, +} + +// pre-built lookup for Field() +var fieldByName map[string]FieldDef + +func init() { + fieldByName = make(map[string]FieldDef, len(fieldCatalog)) + for _, f := range fieldCatalog { + fieldByName[f.Name] = f + } +} + +// Field returns the FieldDef for a given field name and whether it exists. +func Field(name string) (FieldDef, bool) { + f, ok := fieldByName[name] + return f, ok +} + +// Fields returns the ordered list of all DSL-visible task fields. +func Fields() []FieldDef { + result := make([]FieldDef, len(fieldCatalog)) + copy(result, fieldCatalog) + return result +} diff --git a/workflow/fields_test.go b/workflow/fields_test.go new file mode 100644 index 0000000..1f3d8c6 --- /dev/null +++ b/workflow/fields_test.go @@ -0,0 +1,74 @@ +package workflow + +import "testing" + +func TestField(t *testing.T) { + tests := []struct { + name string + want ValueType + wantOK bool + }{ + {"id", TypeID, true}, + {"title", TypeString, true}, + {"description", TypeString, true}, + {"status", TypeStatus, true}, + {"type", TypeTaskType, true}, + {"tags", TypeListString, true}, + {"dependsOn", TypeListRef, true}, + {"due", TypeDate, true}, + {"recurrence", TypeRecurrence, true}, + {"assignee", TypeString, true}, + {"priority", TypeInt, true}, + {"points", TypeInt, true}, + {"createdBy", TypeString, true}, + {"createdAt", TypeTimestamp, true}, + {"updatedAt", TypeTimestamp, true}, + {"nonexistent", 0, false}, + {"comments", 0, false}, // excluded from DSL catalog + {"loadedMtime", 0, false}, // excluded from DSL catalog + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f, ok := Field(tt.name) + if ok != tt.wantOK { + t.Errorf("Field(%q) ok = %v, want %v", tt.name, ok, tt.wantOK) + return + } + if ok && f.Type != tt.want { + t.Errorf("Field(%q).Type = %v, want %v", tt.name, f.Type, tt.want) + } + }) + } +} + +func TestFields(t *testing.T) { + fields := Fields() + if len(fields) != 15 { + t.Fatalf("expected 15 fields, got %d", len(fields)) + } + + // verify it returns a copy + fields[0].Name = "modified" + original, _ := Field("id") + if original.Name == "modified" { + t.Error("Fields() should return a copy, not a reference to the internal slice") + } +} + +func TestDateVsTimestamp(t *testing.T) { + due, _ := Field("due") + if due.Type != TypeDate { + t.Errorf("due should be TypeDate, got %v", due.Type) + } + + createdAt, _ := Field("createdAt") + if createdAt.Type != TypeTimestamp { + t.Errorf("createdAt should be TypeTimestamp, got %v", createdAt.Type) + } + + updatedAt, _ := Field("updatedAt") + if updatedAt.Type != TypeTimestamp { + t.Errorf("updatedAt should be TypeTimestamp, got %v", updatedAt.Type) + } +} diff --git a/workflow/status.go b/workflow/status.go new file mode 100644 index 0000000..59cbe4a --- /dev/null +++ b/workflow/status.go @@ -0,0 +1,156 @@ +package workflow + +import ( + "fmt" + "log/slog" + "strings" +) + +// StatusKey is a named type for workflow status keys. +// All status keys are normalized: lowercase, underscores as separators. +type StatusKey string + +// well-known status constants (defaults from workflow.yaml template) +const ( + StatusBacklog StatusKey = "backlog" + StatusReady StatusKey = "ready" + StatusInProgress StatusKey = "in_progress" + StatusReview StatusKey = "review" + StatusDone StatusKey = "done" +) + +// StatusDef defines a single workflow status. +type StatusDef struct { + Key string `yaml:"key"` + Label string `yaml:"label"` + Emoji string `yaml:"emoji"` + Active bool `yaml:"active"` + Default bool `yaml:"default"` + Done bool `yaml:"done"` +} + +// StatusRegistry is an ordered collection of valid statuses. +// It is constructed from a list of StatusDef and provides lookup and query methods. +// StatusRegistry holds no global state — the populated singleton lives in config/. +type StatusRegistry struct { + statuses []StatusDef + byKey map[StatusKey]StatusDef + defaultKey StatusKey + doneKey StatusKey +} + +// NormalizeStatusKey lowercases, trims, and replaces "-" and " " with "_". +// This preserves multi-word keys (e.g. "in-progress" → "in_progress"). +func NormalizeStatusKey(key string) StatusKey { + normalized := strings.ToLower(strings.TrimSpace(key)) + normalized = strings.ReplaceAll(normalized, "-", "_") + normalized = strings.ReplaceAll(normalized, " ", "_") + return StatusKey(normalized) +} + +// NewStatusRegistry constructs a StatusRegistry from the given definitions. +// Returns an error if keys are empty, duplicated, or the list is empty. +func NewStatusRegistry(defs []StatusDef) (*StatusRegistry, error) { + if len(defs) == 0 { + return nil, fmt.Errorf("statuses list is empty") + } + + reg := &StatusRegistry{ + statuses: make([]StatusDef, 0, len(defs)), + byKey: make(map[StatusKey]StatusDef, len(defs)), + } + + for i, def := range defs { + if def.Key == "" { + return nil, fmt.Errorf("status at index %d has empty key", i) + } + + normalized := NormalizeStatusKey(def.Key) + def.Key = string(normalized) + + if _, exists := reg.byKey[normalized]; exists { + return nil, fmt.Errorf("duplicate status key %q", normalized) + } + + if def.Default { + if reg.defaultKey != "" { + slog.Warn("multiple statuses marked default; using first", "first", reg.defaultKey, "duplicate", normalized) + } else { + reg.defaultKey = normalized + } + } + if def.Done { + if reg.doneKey != "" { + slog.Warn("multiple statuses marked done; using first", "first", reg.doneKey, "duplicate", normalized) + } else { + reg.doneKey = normalized + } + } + + reg.byKey[normalized] = def + reg.statuses = append(reg.statuses, def) + } + + // if no explicit default, use the first status + if reg.defaultKey == "" { + reg.defaultKey = StatusKey(reg.statuses[0].Key) + slog.Warn("no status marked default; using first status", "key", reg.defaultKey) + } + + if reg.doneKey == "" { + slog.Warn("no status marked done; task completion features may not work correctly") + } + + return reg, nil +} + +// All returns the ordered list of status definitions. +// returns a copy to prevent callers from mutating internal state. +func (r *StatusRegistry) All() []StatusDef { + result := make([]StatusDef, len(r.statuses)) + copy(result, r.statuses) + return result +} + +// Lookup returns the StatusDef for a given key (normalized) and whether it exists. +func (r *StatusRegistry) Lookup(key string) (StatusDef, bool) { + def, ok := r.byKey[NormalizeStatusKey(key)] + return def, ok +} + +// IsValid reports whether key is a recognized status. +func (r *StatusRegistry) IsValid(key string) bool { + _, ok := r.byKey[NormalizeStatusKey(key)] + return ok +} + +// IsActive reports whether the status has the active flag set. +func (r *StatusRegistry) IsActive(key string) bool { + def, ok := r.byKey[NormalizeStatusKey(key)] + return ok && def.Active +} + +// IsDone reports whether the status has the done flag set. +func (r *StatusRegistry) IsDone(key string) bool { + def, ok := r.byKey[NormalizeStatusKey(key)] + return ok && def.Done +} + +// DefaultKey returns the key of the status with default: true. +func (r *StatusRegistry) DefaultKey() StatusKey { + return r.defaultKey +} + +// DoneKey returns the key of the status with done: true. +func (r *StatusRegistry) DoneKey() StatusKey { + return r.doneKey +} + +// Keys returns all status keys in definition order. +func (r *StatusRegistry) Keys() []StatusKey { + keys := make([]StatusKey, len(r.statuses)) + for i, s := range r.statuses { + keys[i] = StatusKey(s.Key) + } + return keys +} diff --git a/workflow/status_test.go b/workflow/status_test.go new file mode 100644 index 0000000..f328669 --- /dev/null +++ b/workflow/status_test.go @@ -0,0 +1,263 @@ +package workflow + +import "testing" + +func defaultTestStatuses() []StatusDef { + return []StatusDef{ + {Key: "backlog", Label: "Backlog", Emoji: "📥", Default: true}, + {Key: "ready", Label: "Ready", Emoji: "📋", Active: true}, + {Key: "in_progress", Label: "In Progress", Emoji: "⚙️", Active: true}, + {Key: "review", Label: "Review", Emoji: "👀", Active: true}, + {Key: "done", Label: "Done", Emoji: "✅", Done: true}, + } +} + +func mustBuildStatusRegistry(t *testing.T, defs []StatusDef) *StatusRegistry { + t.Helper() + reg, err := NewStatusRegistry(defs) + if err != nil { + t.Fatalf("NewStatusRegistry: %v", err) + } + return reg +} + +func TestNormalizeStatusKey(t *testing.T) { + tests := []struct { + input string + want StatusKey + }{ + {"backlog", "backlog"}, + {"BACKLOG", "backlog"}, + {"In-Progress", "in_progress"}, + {"in progress", "in_progress"}, + {" DONE ", "done"}, + {"In_Review", "in_review"}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + if got := NormalizeStatusKey(tt.input); got != tt.want { + t.Errorf("NormalizeStatusKey(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestNewStatusRegistry_DefaultStatuses(t *testing.T) { + reg := mustBuildStatusRegistry(t, defaultTestStatuses()) + + if len(reg.All()) != 5 { + t.Fatalf("expected 5 statuses, got %d", len(reg.All())) + } + if reg.DefaultKey() != "backlog" { + t.Errorf("expected default key 'backlog', got %q", reg.DefaultKey()) + } + if reg.DoneKey() != "done" { + t.Errorf("expected done key 'done', got %q", reg.DoneKey()) + } +} + +func TestNewStatusRegistry_CustomStatuses(t *testing.T) { + custom := []StatusDef{ + {Key: "new", Label: "New", Emoji: "🆕", Default: true}, + {Key: "wip", Label: "Work In Progress", Emoji: "🔧", Active: true}, + {Key: "closed", Label: "Closed", Emoji: "🔒", Done: true}, + } + reg := mustBuildStatusRegistry(t, custom) + + if len(reg.All()) != 3 { + t.Fatalf("expected 3 statuses, got %d", len(reg.All())) + } + if reg.DefaultKey() != "new" { + t.Errorf("expected default key 'new', got %q", reg.DefaultKey()) + } + if reg.DoneKey() != "closed" { + t.Errorf("expected done key 'closed', got %q", reg.DoneKey()) + } +} + +func TestStatusRegistry_IsValid(t *testing.T) { + reg := mustBuildStatusRegistry(t, defaultTestStatuses()) + + tests := []struct { + key string + want bool + }{ + {"backlog", true}, + {"ready", true}, + {"in_progress", true}, + {"In-Progress", true}, + {"review", true}, + {"done", true}, + {"unknown", false}, + {"", false}, + {"todo", false}, + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + if got := reg.IsValid(tt.key); got != tt.want { + t.Errorf("IsValid(%q) = %v, want %v", tt.key, got, tt.want) + } + }) + } +} + +func TestStatusRegistry_IsActive(t *testing.T) { + reg := mustBuildStatusRegistry(t, defaultTestStatuses()) + + tests := []struct { + key string + want bool + }{ + {"backlog", false}, + {"ready", true}, + {"in_progress", true}, + {"review", true}, + {"done", false}, + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + if got := reg.IsActive(tt.key); got != tt.want { + t.Errorf("IsActive(%q) = %v, want %v", tt.key, got, tt.want) + } + }) + } +} + +func TestStatusRegistry_IsDone(t *testing.T) { + reg := mustBuildStatusRegistry(t, defaultTestStatuses()) + + if !reg.IsDone("done") { + t.Error("expected 'done' to be marked as done") + } + if reg.IsDone("backlog") { + t.Error("expected 'backlog' to not be marked as done") + } +} + +func TestStatusRegistry_Lookup(t *testing.T) { + reg := mustBuildStatusRegistry(t, defaultTestStatuses()) + + def, ok := reg.Lookup("ready") + if !ok { + t.Fatal("expected to find 'ready'") + } + if def.Label != "Ready" { + t.Errorf("expected label 'Ready', got %q", def.Label) + } + if def.Emoji != "📋" { + t.Errorf("expected emoji '📋', got %q", def.Emoji) + } + + _, ok = reg.Lookup("nonexistent") + if ok { + t.Error("expected Lookup to return false for nonexistent key") + } +} + +func TestStatusRegistry_Keys(t *testing.T) { + reg := mustBuildStatusRegistry(t, defaultTestStatuses()) + + keys := reg.Keys() + expected := []StatusKey{"backlog", "ready", "in_progress", "review", "done"} + + if len(keys) != len(expected) { + t.Fatalf("expected %d keys, got %d", len(expected), len(keys)) + } + for i, key := range keys { + if key != expected[i] { + t.Errorf("keys[%d] = %q, want %q", i, key, expected[i]) + } + } +} + +func TestStatusRegistry_NormalizesKeys(t *testing.T) { + custom := []StatusDef{ + {Key: "In-Progress", Label: "In Progress", Default: true}, + {Key: " DONE ", Label: "Done", Done: true}, + } + reg := mustBuildStatusRegistry(t, custom) + + if !reg.IsValid("in_progress") { + t.Error("expected 'in_progress' to be valid after normalization") + } + if !reg.IsValid("done") { + t.Error("expected 'done' to be valid after normalization") + } +} + +func TestNewStatusRegistry_EmptyKey(t *testing.T) { + defs := []StatusDef{ + {Key: "", Label: "No Key"}, + } + _, err := NewStatusRegistry(defs) + if err == nil { + t.Error("expected error for empty key") + } +} + +func TestNewStatusRegistry_DuplicateKey(t *testing.T) { + defs := []StatusDef{ + {Key: "ready", Label: "Ready", Default: true}, + {Key: "ready", Label: "Ready 2"}, + } + _, err := NewStatusRegistry(defs) + if err == nil { + t.Error("expected error for duplicate key") + } +} + +func TestNewStatusRegistry_Empty(t *testing.T) { + _, err := NewStatusRegistry(nil) + if err == nil { + t.Error("expected error for empty statuses") + } +} + +func TestNewStatusRegistry_DefaultFallsToFirst(t *testing.T) { + defs := []StatusDef{ + {Key: "alpha", Label: "Alpha"}, + {Key: "beta", Label: "Beta"}, + } + reg, err := NewStatusRegistry(defs) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if reg.DefaultKey() != "alpha" { + t.Errorf("expected default to fall back to first status 'alpha', got %q", reg.DefaultKey()) + } +} + +func TestStatusRegistry_AllReturnsCopy(t *testing.T) { + reg := mustBuildStatusRegistry(t, defaultTestStatuses()) + + all := reg.All() + all[0].Key = "mutated" + + // internal state must be unchanged + keys := reg.Keys() + if keys[0] != "backlog" { + t.Errorf("All() mutation leaked into registry: first key = %q, want %q", keys[0], "backlog") + } +} + +func TestStatusKeyConstants(t *testing.T) { + if StatusBacklog != "backlog" { + t.Errorf("StatusBacklog = %q", StatusBacklog) + } + if StatusReady != "ready" { + t.Errorf("StatusReady = %q", StatusReady) + } + if StatusInProgress != "in_progress" { + t.Errorf("StatusInProgress = %q", StatusInProgress) + } + if StatusReview != "review" { + t.Errorf("StatusReview = %q", StatusReview) + } + if StatusDone != "done" { + t.Errorf("StatusDone = %q", StatusDone) + } +} diff --git a/workflow/tasktype.go b/workflow/tasktype.go new file mode 100644 index 0000000..db8fe36 --- /dev/null +++ b/workflow/tasktype.go @@ -0,0 +1,195 @@ +package workflow + +import ( + "fmt" + "strings" +) + +// TaskType is a named type for workflow task type keys. +type TaskType string + +// well-known built-in type constants. +const ( + TypeStory TaskType = "story" + TypeBug TaskType = "bug" + TypeSpike TaskType = "spike" + TypeEpic TaskType = "epic" +) + +// TypeDef defines a single task type with metadata and aliases. +type TypeDef struct { + Key TaskType + Label string + Emoji string + Aliases []string // e.g. "feature" and "task" → story +} + +// DefaultTypeDefs returns the built-in type definitions. +func DefaultTypeDefs() []TypeDef { + return []TypeDef{ + {Key: TypeStory, Label: "Story", Emoji: "🌀", Aliases: []string{"feature", "task"}}, + {Key: TypeBug, Label: "Bug", Emoji: "💥"}, + {Key: TypeSpike, Label: "Spike", Emoji: "🔍"}, + {Key: TypeEpic, Label: "Epic", Emoji: "🗂️"}, + } +} + +// TypeRegistry is an ordered collection of valid task types. +// It is constructed from a list of TypeDef and provides lookup and normalization. +type TypeRegistry struct { + types []TypeDef + byKey map[TaskType]TypeDef + byAlias map[string]TaskType // normalized alias → canonical key + fallback TaskType // returned for unknown types +} + +// NormalizeTypeKey lowercases, trims, and strips all separators ("-", "_", " "). +// Built-in type keys are single words, so stripping is lossless. +func NormalizeTypeKey(s string) TaskType { + s = strings.ToLower(strings.TrimSpace(s)) + s = strings.ReplaceAll(s, "_", "") + s = strings.ReplaceAll(s, "-", "") + s = strings.ReplaceAll(s, " ", "") + return TaskType(s) +} + +// NewTypeRegistry constructs a TypeRegistry from the given definitions. +// The first definition's key is used as the fallback for unknown types. +func NewTypeRegistry(defs []TypeDef) (*TypeRegistry, error) { + if len(defs) == 0 { + return nil, fmt.Errorf("type definitions list is empty") + } + + reg := &TypeRegistry{ + types: make([]TypeDef, 0, len(defs)), + byKey: make(map[TaskType]TypeDef, len(defs)), + byAlias: make(map[string]TaskType), + fallback: NormalizeTypeKey(string(defs[0].Key)), + } + + // first pass: register all primary keys + for i, def := range defs { + if def.Key == "" { + return nil, fmt.Errorf("type at index %d has empty key", i) + } + + normalized := NormalizeTypeKey(string(def.Key)) + def.Key = normalized + defs[i] = def + + if _, exists := reg.byKey[normalized]; exists { + return nil, fmt.Errorf("duplicate type key %q", normalized) + } + + reg.byKey[normalized] = def + reg.types = append(reg.types, def) + } + + // second pass: register aliases against the complete key set + for _, def := range defs { + for _, alias := range def.Aliases { + normAlias := string(NormalizeTypeKey(alias)) + if existing, ok := reg.byAlias[normAlias]; ok { + return nil, fmt.Errorf("duplicate alias %q (already maps to %q)", alias, existing) + } + if _, ok := reg.byKey[TaskType(normAlias)]; ok { + return nil, fmt.Errorf("alias %q collides with primary key", alias) + } + reg.byAlias[normAlias] = def.Key + } + } + + return reg, nil +} + +// Lookup returns the TypeDef for a given key (normalized) and whether it exists. +func (r *TypeRegistry) Lookup(key TaskType) (TypeDef, bool) { + def, ok := r.byKey[NormalizeTypeKey(string(key))] + return def, ok +} + +// ParseType parses a raw string into a TaskType with validation. +// Returns the canonical key and true if recognized (including aliases), +// or (fallback, false) for unknown types. +func (r *TypeRegistry) ParseType(s string) (TaskType, bool) { + normalized := NormalizeTypeKey(s) + + // check primary keys + if _, ok := r.byKey[normalized]; ok { + return normalized, true + } + + // check aliases + if canonical, ok := r.byAlias[string(normalized)]; ok { + return canonical, true + } + + return r.fallback, false +} + +// NormalizeType normalizes a raw string into a TaskType. +// Unknown types default to the fallback (first registered type). +func (r *TypeRegistry) NormalizeType(s string) TaskType { + t, _ := r.ParseType(s) + return t +} + +// TypeLabel returns the human-readable label for a task type. +func (r *TypeRegistry) TypeLabel(t TaskType) string { + if def, ok := r.Lookup(t); ok { + return def.Label + } + return string(t) +} + +// TypeEmoji returns the emoji for a task type. +func (r *TypeRegistry) TypeEmoji(t TaskType) string { + if def, ok := r.Lookup(t); ok { + return def.Emoji + } + return "" +} + +// TypeDisplay returns "Label Emoji" for a task type. +func (r *TypeRegistry) TypeDisplay(t TaskType) string { + label := r.TypeLabel(t) + emoji := r.TypeEmoji(t) + if emoji == "" { + return label + } + return label + " " + emoji +} + +// ParseDisplay reverses a TypeDisplay() string (e.g. "Bug 💥") back to +// its canonical key. Returns (key, true) on match, or (fallback, false). +func (r *TypeRegistry) ParseDisplay(display string) (TaskType, bool) { + for _, def := range r.types { + if r.TypeDisplay(def.Key) == display { + return def.Key, true + } + } + return r.fallback, false +} + +// Keys returns all type keys in definition order. +func (r *TypeRegistry) Keys() []TaskType { + keys := make([]TaskType, len(r.types)) + for i, td := range r.types { + keys[i] = td.Key + } + return keys +} + +// All returns the ordered list of type definitions. +// returns a copy to prevent callers from mutating internal state. +func (r *TypeRegistry) All() []TypeDef { + result := make([]TypeDef, len(r.types)) + copy(result, r.types) + return result +} + +// IsValid reports whether key is a recognized type (primary key only, not alias). +func (r *TypeRegistry) IsValid(key TaskType) bool { + _, ok := r.byKey[NormalizeTypeKey(string(key))] + return ok +} diff --git a/workflow/tasktype_test.go b/workflow/tasktype_test.go new file mode 100644 index 0000000..bb69fc0 --- /dev/null +++ b/workflow/tasktype_test.go @@ -0,0 +1,363 @@ +package workflow + +import "testing" + +func mustBuildTypeRegistry(t *testing.T, defs []TypeDef) *TypeRegistry { + t.Helper() + reg, err := NewTypeRegistry(defs) + if err != nil { + t.Fatalf("NewTypeRegistry: %v", err) + } + return reg +} + +func TestNormalizeTypeKey(t *testing.T) { + tests := []struct { + input string + want TaskType + }{ + {"story", "story"}, + {"Story", "story"}, + {"BUG", "bug"}, + {"SPIKE", "spike"}, + {"in_progress", "inprogress"}, + {"some-type", "sometype"}, + {" EPIC ", "epic"}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + if got := NormalizeTypeKey(tt.input); got != tt.want { + t.Errorf("NormalizeTypeKey(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestTypeRegistry_ParseType(t *testing.T) { + reg := mustBuildTypeRegistry(t, DefaultTypeDefs()) + + tests := []struct { + name string + input string + want TaskType + wantOK bool + }{ + {"story", "story", TypeStory, true}, + {"bug", "bug", TypeBug, true}, + {"spike", "spike", TypeSpike, true}, + {"epic", "epic", TypeEpic, true}, + {"feature alias", "feature", TypeStory, true}, + {"task alias", "task", TypeStory, true}, + {"case insensitive", "Story", TypeStory, true}, + {"uppercase", "BUG", TypeBug, true}, + {"unknown", "unknown", TypeStory, false}, + {"empty", "", TypeStory, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := reg.ParseType(tt.input) + if got != tt.want || ok != tt.wantOK { + t.Errorf("ParseType(%q) = (%q, %v), want (%q, %v)", tt.input, got, ok, tt.want, tt.wantOK) + } + }) + } +} + +func TestTypeRegistry_NormalizeType(t *testing.T) { + reg := mustBuildTypeRegistry(t, DefaultTypeDefs()) + + tests := []struct { + name string + input string + want TaskType + }{ + {"story", "story", TypeStory}, + {"bug", "bug", TypeBug}, + {"spike", "spike", TypeSpike}, + {"epic", "epic", TypeEpic}, + {"feature alias", "feature", TypeStory}, + {"task alias", "task", TypeStory}, + {"case insensitive", "EPIC", TypeEpic}, + {"unknown defaults to story", "unknown", TypeStory}, + {"empty defaults to story", "", TypeStory}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := reg.NormalizeType(tt.input); got != tt.want { + t.Errorf("NormalizeType(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestTypeRegistry_TypeLabel(t *testing.T) { + reg := mustBuildTypeRegistry(t, DefaultTypeDefs()) + + tests := []struct { + input TaskType + want string + }{ + {TypeStory, "Story"}, + {TypeBug, "Bug"}, + {TypeSpike, "Spike"}, + {TypeEpic, "Epic"}, + {TaskType("unknown"), "unknown"}, + } + + for _, tt := range tests { + t.Run(string(tt.input), func(t *testing.T) { + if got := reg.TypeLabel(tt.input); got != tt.want { + t.Errorf("TypeLabel(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestTypeRegistry_TypeEmoji(t *testing.T) { + reg := mustBuildTypeRegistry(t, DefaultTypeDefs()) + + tests := []struct { + input TaskType + want string + }{ + {TypeStory, "🌀"}, + {TypeBug, "💥"}, + {TypeSpike, "🔍"}, + {TypeEpic, "🗂️"}, + {TaskType("unknown"), ""}, + } + + for _, tt := range tests { + t.Run(string(tt.input), func(t *testing.T) { + if got := reg.TypeEmoji(tt.input); got != tt.want { + t.Errorf("TypeEmoji(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestTypeRegistry_TypeDisplay(t *testing.T) { + reg := mustBuildTypeRegistry(t, DefaultTypeDefs()) + + tests := []struct { + input TaskType + want string + }{ + {TypeStory, "Story 🌀"}, + {TypeBug, "Bug 💥"}, + {TypeSpike, "Spike 🔍"}, + {TypeEpic, "Epic 🗂️"}, + {TaskType("unknown"), "unknown"}, + } + + for _, tt := range tests { + t.Run(string(tt.input), func(t *testing.T) { + if got := reg.TypeDisplay(tt.input); got != tt.want { + t.Errorf("TypeDisplay(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestTypeRegistry_ParseDisplay(t *testing.T) { + reg := mustBuildTypeRegistry(t, DefaultTypeDefs()) + + tests := []struct { + name string + input string + want TaskType + wantOK bool + }{ + {"story display", "Story 🌀", TypeStory, true}, + {"bug display", "Bug 💥", TypeBug, true}, + {"spike display", "Spike 🔍", TypeSpike, true}, + {"epic display", "Epic 🗂️", TypeEpic, true}, + {"unknown display", "Unknown", TypeStory, false}, + {"label only", "Bug", TypeStory, false}, + {"empty", "", TypeStory, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := reg.ParseDisplay(tt.input) + if got != tt.want || ok != tt.wantOK { + t.Errorf("ParseDisplay(%q) = (%q, %v), want (%q, %v)", tt.input, got, ok, tt.want, tt.wantOK) + } + }) + } +} + +func TestTypeRegistry_Keys(t *testing.T) { + reg := mustBuildTypeRegistry(t, DefaultTypeDefs()) + + keys := reg.Keys() + expected := []TaskType{TypeStory, TypeBug, TypeSpike, TypeEpic} + + if len(keys) != len(expected) { + t.Fatalf("expected %d keys, got %d", len(expected), len(keys)) + } + for i, key := range keys { + if key != expected[i] { + t.Errorf("keys[%d] = %q, want %q", i, key, expected[i]) + } + } +} + +func TestTypeRegistry_IsValid(t *testing.T) { + reg := mustBuildTypeRegistry(t, DefaultTypeDefs()) + + if !reg.IsValid(TypeStory) { + t.Error("expected story to be valid") + } + if reg.IsValid("feature") { + t.Error("expected alias 'feature' to not be valid as primary key") + } + if reg.IsValid("unknown") { + t.Error("expected unknown to not be valid") + } +} + +func TestTypeRegistry_LookupNormalizesInput(t *testing.T) { + reg := mustBuildTypeRegistry(t, DefaultTypeDefs()) + + // Lookup should normalize inputs just like StatusRegistry does + tests := []struct { + name string + input TaskType + want bool + }{ + {"lowercase", "story", true}, + {"uppercase", "STORY", true}, + {"mixed case", "Bug", true}, + {"unknown", "nope", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, ok := reg.Lookup(tt.input) + if ok != tt.want { + t.Errorf("Lookup(%q) ok = %v, want %v", tt.input, ok, tt.want) + } + }) + } + + // TypeLabel/TypeEmoji/IsValid should also normalize + if label := reg.TypeLabel("BUG"); label != "Bug" { + t.Errorf("TypeLabel(BUG) = %q, want %q", label, "Bug") + } + if emoji := reg.TypeEmoji("EPIC"); emoji != "🗂️" { + t.Errorf("TypeEmoji(EPIC) = %q, want %q", emoji, "🗂️") + } + if !reg.IsValid("SPIKE") { + t.Error("expected IsValid(SPIKE) to be true after normalization") + } +} + +func TestNewTypeRegistry_EmptyKey(t *testing.T) { + defs := []TypeDef{{Key: "", Label: "No Key"}} + _, err := NewTypeRegistry(defs) + if err == nil { + t.Error("expected error for empty key") + } +} + +func TestNewTypeRegistry_DuplicateKey(t *testing.T) { + defs := []TypeDef{ + {Key: "story", Label: "Story"}, + {Key: "story", Label: "Story 2"}, + } + _, err := NewTypeRegistry(defs) + if err == nil { + t.Error("expected error for duplicate key") + } +} + +func TestNewTypeRegistry_DuplicateAlias(t *testing.T) { + defs := []TypeDef{ + {Key: "story", Label: "Story", Aliases: []string{"feature"}}, + {Key: "bug", Label: "Bug", Aliases: []string{"feature"}}, + } + _, err := NewTypeRegistry(defs) + if err == nil { + t.Error("expected error for duplicate alias") + } +} + +func TestNewTypeRegistry_AliasCollidesWithKey(t *testing.T) { + defs := []TypeDef{ + {Key: "story", Label: "Story"}, + {Key: "bug", Label: "Bug", Aliases: []string{"story"}}, + } + _, err := NewTypeRegistry(defs) + if err == nil { + t.Error("expected error when alias collides with primary key") + } +} + +func TestNewTypeRegistry_AliasCollidesWithLaterKey(t *testing.T) { + // alias "feature" on story should collide with later primary key "feature" + defs := []TypeDef{ + {Key: "story", Label: "Story", Aliases: []string{"feature"}}, + {Key: "feature", Label: "Feature"}, + } + _, err := NewTypeRegistry(defs) + if err == nil { + t.Error("expected error when alias collides with a later primary key") + } +} + +func TestNewTypeRegistry_FallbackNormalized(t *testing.T) { + defs := []TypeDef{ + {Key: "Story", Label: "Story"}, + {Key: "bug", Label: "Bug"}, + } + reg := mustBuildTypeRegistry(t, defs) + + // fallback should be normalized even though the input key was "Story" + got, ok := reg.ParseType("unknown-thing") + if ok { + t.Fatal("expected ok=false for unknown type") + } + if got != "story" { + t.Errorf("ParseType(unknown) = %q, want %q (normalized fallback)", got, "story") + } +} + +func TestNewTypeRegistry_Empty(t *testing.T) { + _, err := NewTypeRegistry(nil) + if err == nil { + t.Error("expected error for empty type definitions") + } +} + +func TestTypeRegistry_AllReturnsCopy(t *testing.T) { + reg := mustBuildTypeRegistry(t, DefaultTypeDefs()) + + all := reg.All() + all[0].Key = "mutated" + + // internal state must be unchanged + keys := reg.Keys() + if keys[0] != TypeStory { + t.Errorf("All() mutation leaked into registry: first key = %q, want %q", keys[0], TypeStory) + } +} + +func TestTypeConstants(t *testing.T) { + if TypeStory != "story" { + t.Errorf("TypeStory = %q", TypeStory) + } + if TypeBug != "bug" { + t.Errorf("TypeBug = %q", TypeBug) + } + if TypeSpike != "spike" { + t.Errorf("TypeSpike = %q", TypeSpike) + } + if TypeEpic != "epic" { + t.Errorf("TypeEpic = %q", TypeEpic) + } +}