From 8b96fa02492f581dc34643440b0a737c0aa74e24 Mon Sep 17 00:00:00 2001 From: Judah Wyllie Date: Thu, 8 Jan 2026 15:15:43 -0800 Subject: [PATCH 1/2] Stop walk on error --- validator/core/walk.go | 31 +++++++++++++++++++++++++++++++ validator/validator.go | 12 +++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/validator/core/walk.go b/validator/core/walk.go index 53245086..508e59df 100644 --- a/validator/core/walk.go +++ b/validator/core/walk.go @@ -17,6 +17,13 @@ type Events struct { directiveList []func(walker *Walker, directives []*ast.Directive) value []func(walker *Walker, value *ast.Value) variable []func(walker *Walker, variable *ast.VariableDefinition) + + // StopOnFirstError indicates whether to stop traversal on the first error. + StopOnFirstError bool + + // Stopped indicates traversal should stop early. This is set by validators + // that wish to abort walking once an error has been encountered. + Stopped bool } func (o *Events) OnOperation(f func(walker *Walker, operation *ast.OperationDefinition)) { @@ -76,6 +83,9 @@ type Walker struct { } func (w *Walker) walk() { + if w.Observers != nil && w.Observers.Stopped { + return + } for _, child := range w.Document.Operations { w.validatedFragmentSpreads = make(map[string]bool) w.walkOperation(child) @@ -87,6 +97,9 @@ func (w *Walker) walk() { } func (w *Walker) walkOperation(operation *ast.OperationDefinition) { + if w.Observers != nil && w.Observers.Stopped { + return + } w.CurrentOperation = operation for _, varDef := range operation.VariableDefinitions { varDef.Definition = w.Schema.Types[varDef.Type.Name()] @@ -130,6 +143,9 @@ func (w *Walker) walkOperation(operation *ast.OperationDefinition) { } func (w *Walker) walkFragment(it *ast.FragmentDefinition) { + if w.Observers != nil && w.Observers.Stopped { + return + } def := w.Schema.Types[it.TypeCondition] it.Definition = def @@ -143,6 +159,9 @@ func (w *Walker) walkFragment(it *ast.FragmentDefinition) { } func (w *Walker) walkDirectives(parentDef *ast.Definition, directives []*ast.Directive, location ast.DirectiveLocation) { + if w.Observers != nil && w.Observers.Stopped { + return + } for _, dir := range directives { def := w.Schema.Directives[dir.Name] dir.Definition = def @@ -169,6 +188,9 @@ func (w *Walker) walkDirectives(parentDef *ast.Definition, directives []*ast.Dir } func (w *Walker) walkValue(value *ast.Value) { + if w.Observers != nil && w.Observers.Stopped { + return + } if value.Kind == ast.Variable && w.CurrentOperation != nil { value.VariableDefinition = w.CurrentOperation.VariableDefinitions.ForName(value.Raw) if value.VariableDefinition != nil { @@ -207,6 +229,9 @@ func (w *Walker) walkValue(value *ast.Value) { } func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument) { + if w.Observers != nil && w.Observers.Stopped { + return + } if argDef != nil { arg.Value.ExpectedType = argDef.Type arg.Value.ExpectedTypeHasDefault = argDef.DefaultValue != nil && argDef.DefaultValue.Kind != ast.NullValue @@ -217,12 +242,18 @@ func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument) } func (w *Walker) walkSelectionSet(parentDef *ast.Definition, it ast.SelectionSet) { + if w.Observers != nil && w.Observers.Stopped { + return + } for _, child := range it { w.walkSelection(parentDef, child) } } func (w *Walker) walkSelection(parentDef *ast.Definition, it ast.Selection) { + if w.Observers != nil && w.Observers.Stopped { + return + } switch it := it.(type) { case *ast.Field: var def *ast.FieldDefinition diff --git a/validator/validator.go b/validator/validator.go index 1214ed16..40094345 100644 --- a/validator/validator.go +++ b/validator/validator.go @@ -118,6 +118,10 @@ func Validate(schema *Schema, doc *QueryDocument, rules ...Rule) gqlerror.List { } func ValidateWithRules(schema *Schema, doc *QueryDocument, rules *validatorrules.Rules) gqlerror.List { + return ValidateWithRulesAndStopOnFirstError(schema, doc, rules, false) +} + +func ValidateWithRulesAndStopOnFirstError(schema *Schema, doc *QueryDocument, rules *validatorrules.Rules, stopOnFirstError bool) gqlerror.List { if rules == nil { rules = validatorrules.NewDefaultRules() } @@ -132,7 +136,9 @@ func ValidateWithRules(schema *Schema, doc *QueryDocument, rules *validatorrules if len(errs) > 0 { return errs } - observers := &core.Events{} + observers := &core.Events{ + StopOnFirstError: stopOnFirstError, + } var currentRules []Rule // nolint:prealloc // would require extra local refs for len for name, ruleFunc := range rules.GetInner() { @@ -150,6 +156,10 @@ func ValidateWithRules(schema *Schema, doc *QueryDocument, rules *validatorrules o(err) } errs = append(errs, err) + + if observers.StopOnFirstError { + observers.Stopped = true + } }) } From 5cf80bad753f14e48ea2ef230c17589d57b2839b Mon Sep 17 00:00:00 2001 From: Judah Wyllie Date: Thu, 15 Jan 2026 15:21:41 -0800 Subject: [PATCH 2/2] Allow configurable error limit --- validator/core/walk.go | 3 - validator/validator.go | 13 ++- validator/validator_test.go | 212 ++++++++++++++++++++++++++++++++++++ validator/walk_test.go | 20 ++++ 4 files changed, 239 insertions(+), 9 deletions(-) diff --git a/validator/core/walk.go b/validator/core/walk.go index 508e59df..7e58dd35 100644 --- a/validator/core/walk.go +++ b/validator/core/walk.go @@ -18,9 +18,6 @@ type Events struct { value []func(walker *Walker, value *ast.Value) variable []func(walker *Walker, variable *ast.VariableDefinition) - // StopOnFirstError indicates whether to stop traversal on the first error. - StopOnFirstError bool - // Stopped indicates traversal should stop early. This is set by validators // that wish to abort walking once an error has been encountered. Stopped bool diff --git a/validator/validator.go b/validator/validator.go index 40094345..d241f4ec 100644 --- a/validator/validator.go +++ b/validator/validator.go @@ -118,10 +118,10 @@ func Validate(schema *Schema, doc *QueryDocument, rules ...Rule) gqlerror.List { } func ValidateWithRules(schema *Schema, doc *QueryDocument, rules *validatorrules.Rules) gqlerror.List { - return ValidateWithRulesAndStopOnFirstError(schema, doc, rules, false) + return ValidateWithRulesAndMaximumErrors(schema, doc, rules, 0) } -func ValidateWithRulesAndStopOnFirstError(schema *Schema, doc *QueryDocument, rules *validatorrules.Rules, stopOnFirstError bool) gqlerror.List { +func ValidateWithRulesAndMaximumErrors(schema *Schema, doc *QueryDocument, rules *validatorrules.Rules, maximumErrors int) gqlerror.List { if rules == nil { rules = validatorrules.NewDefaultRules() } @@ -133,12 +133,13 @@ func ValidateWithRulesAndStopOnFirstError(schema *Schema, doc *QueryDocument, ru if doc == nil { errs = append(errs, gqlerror.Errorf("cannot validate as QueryDocument is nil")) } + if maximumErrors < 0 { + errs = append(errs, gqlerror.Errorf("maximumErrors cannot be negative")) + } if len(errs) > 0 { return errs } - observers := &core.Events{ - StopOnFirstError: stopOnFirstError, - } + observers := &core.Events{} var currentRules []Rule // nolint:prealloc // would require extra local refs for len for name, ruleFunc := range rules.GetInner() { @@ -157,7 +158,7 @@ func ValidateWithRulesAndStopOnFirstError(schema *Schema, doc *QueryDocument, ru } errs = append(errs, err) - if observers.StopOnFirstError { + if maximumErrors > 0 && len(errs) >= maximumErrors { observers.Stopped = true } }) diff --git a/validator/validator_test.go b/validator/validator_test.go index 6befa032..a9482d37 100644 --- a/validator/validator_test.go +++ b/validator/validator_test.go @@ -318,3 +318,215 @@ func TestRemoveRule(t *testing.T) { // no error validator.RemoveRule("Rule that should no longer exist") } + +func TestValidateWithRulesAndMaximumErrors(t *testing.T) { + t.Run("maximumErrors limits error count", func(t *testing.T) { + s := gqlparser.MustLoadSchema( + &ast.Source{ + Name: "graph/schema.graphqls", + Input: ` + type Query { + field1: String! + field2: String! + field3: String! + } + `, BuiltIn: false}, + ) + + q, err := parser.ParseQuery(&ast.Source{ + Name: "SomeQuery", + Input: ` + query { + field1 + field2 + field3 + } + `}) + require.NoError(t, err) + + // Create a rule that generates errors for each field + errorRule := validator.Rule{ + Name: "ErrorRule", + RuleFunc: func(observers *validator.Events, addError validator.AddErrFunc) { + observers.OnField(func(walker *validator.Walker, field *ast.Field) { + addError(validator.Message("Error for field %s", field.Name)) + }) + }, + } + + rules := rules.NewRules(errorRule) + errList := validator.ValidateWithRulesAndMaximumErrors(s, q, rules, 2) + + // Should only return 2 errors even though there are 3 fields + require.Len(t, errList, 2) + }) + + t.Run("maximumErrors zero means no limit", func(t *testing.T) { + s := gqlparser.MustLoadSchema( + &ast.Source{ + Name: "graph/schema.graphqls", + Input: ` + type Query { + field1: String! + field2: String! + field3: String! + } + `, BuiltIn: false}, + ) + + q, err := parser.ParseQuery(&ast.Source{ + Name: "SomeQuery", + Input: ` + query { + field1 + field2 + field3 + } + `}) + require.NoError(t, err) + + // Create a rule that generates errors for each field + errorRule := validator.Rule{ + Name: "ErrorRule", + RuleFunc: func(observers *validator.Events, addError validator.AddErrFunc) { + observers.OnField(func(walker *validator.Walker, field *ast.Field) { + addError(validator.Message("Error for field %s", field.Name)) + }) + }, + } + + rules := rules.NewRules(errorRule) + errList := validator.ValidateWithRulesAndMaximumErrors(s, q, rules, 0) + + // Should return all errors when maximumErrors is 0 + require.Len(t, errList, 3) + }) + + t.Run("negative maximumErrors returns error", func(t *testing.T) { + s := gqlparser.MustLoadSchema( + &ast.Source{ + Name: "graph/schema.graphqls", + Input: ` + type Query { + field1: String! + } + `, BuiltIn: false}, + ) + + q, err := parser.ParseQuery(&ast.Source{ + Name: "SomeQuery", + Input: ` + query { + field1 + } + `}) + require.NoError(t, err) + + errList := validator.ValidateWithRulesAndMaximumErrors(s, q, nil, -1) + + // Should return an error about negative maximumErrors + require.Len(t, errList, 1) + require.Contains(t, errList[0].Message, "maximumErrors cannot be negative") + }) + + t.Run("maximumErrors stops traversal early", func(t *testing.T) { + s := gqlparser.MustLoadSchema( + &ast.Source{ + Name: "graph/schema.graphqls", + Input: ` + type Query { + field1: String! + field2: String! + field3: String! + field4: String! + field5: String! + } + `, BuiltIn: false}, + ) + + q, err := parser.ParseQuery(&ast.Source{ + Name: "SomeQuery", + Input: ` + query { + field1 + field2 + field3 + field4 + field5 + } + `}) + require.NoError(t, err) + + fieldCount := 0 + // Create a rule that generates errors and counts fields + errorRule := validator.Rule{ + Name: "ErrorRule", + RuleFunc: func(observers *validator.Events, addError validator.AddErrFunc) { + observers.OnField(func(walker *validator.Walker, field *ast.Field) { + fieldCount++ + addError(validator.Message("Error for field %s", field.Name)) + }) + }, + } + + rules := rules.NewRules(errorRule) + errList := validator.ValidateWithRulesAndMaximumErrors(s, q, rules, 2) + + // Should only return 2 errors + require.Len(t, errList, 2) + // Should have stopped traversal early after exactly 2 fields processed + require.Equal(t, 2, fieldCount) + }) + + t.Run("maximumErrors with multiple rules", func(t *testing.T) { + s := gqlparser.MustLoadSchema( + &ast.Source{ + Name: "graph/schema.graphqls", + Input: ` + type Query { + field1: String! + field2: String! + } + `, BuiltIn: false}, + ) + + q, err := parser.ParseQuery(&ast.Source{ + Name: "SomeQuery", + Input: ` + query { + field1 + field2 + } + `}) + require.NoError(t, err) + + // Create two rules that each generate errors and count fields + fieldCount := 0 + rule1 := validator.Rule{ + Name: "Rule1", + RuleFunc: func(observers *validator.Events, addError validator.AddErrFunc) { + observers.OnField(func(walker *validator.Walker, field *ast.Field) { + fieldCount++ + addError(validator.Message("Rule1 error for field %s", field.Name)) + }) + }, + } + rule2 := validator.Rule{ + Name: "Rule2", + RuleFunc: func(observers *validator.Events, addError validator.AddErrFunc) { + observers.OnField(func(walker *validator.Walker, field *ast.Field) { + fieldCount++ + addError(validator.Message("Rule2 error for field %s", field.Name)) + }) + }, + } + + rules := rules.NewRules(rule1, rule2) + errList := validator.ValidateWithRulesAndMaximumErrors(s, q, rules, 3) + + // Although we set maximumErrors to 3, we expect 4 errors here (2 rules × 2 fields). + // The limit is evaluated after the batch is processed, allowing a final overflow. + require.Equal(t, 4, fieldCount) + require.Equal(t, 4, len(errList)) + }) +} diff --git a/validator/walk_test.go b/validator/walk_test.go index d92b8858..168e1584 100644 --- a/validator/walk_test.go +++ b/validator/walk_test.go @@ -50,3 +50,23 @@ func TestWalkInlineFragment(t *testing.T) { require.True(t, called) } + +func TestWalkStoppedEarly(t *testing.T) { + schema, err := LoadSchema(Prelude, &ast.Source{Input: "type Query { name: String, age: Int }\n schema { query: Query }"}) + require.NoError(t, err) + query, err := parser.ParseQuery(&ast.Source{Input: "{ name age }"}) + require.NoError(t, err) + + fieldCount := 0 + observers := &Events{} + observers.OnField(func(walker *Walker, field *ast.Field) { + fieldCount++ + if fieldCount == 1 { + observers.Stopped = true + } + }) + + Walk(schema, query, observers) + + require.Equal(t, 1, fieldCount) +}