Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions validator/core/walk.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ type Events struct {
directiveList []func(walker *Walker, directives []*ast.Directive)
value []func(walker *Walker, value *ast.Value)
variable []func(walker *Walker, variable *ast.VariableDefinition)

// Stopped indicates traversal should stop early. This is set by validators
// that wish to abort walking once an error has been encountered.
Stopped bool
}

func (o *Events) OnOperation(f func(walker *Walker, operation *ast.OperationDefinition)) {
Expand Down Expand Up @@ -76,6 +80,9 @@ type Walker struct {
}

func (w *Walker) walk() {
if w.Observers != nil && w.Observers.Stopped {
return
}
for _, child := range w.Document.Operations {
w.validatedFragmentSpreads = make(map[string]bool)
w.walkOperation(child)
Expand All @@ -87,6 +94,9 @@ func (w *Walker) walk() {
}

func (w *Walker) walkOperation(operation *ast.OperationDefinition) {
if w.Observers != nil && w.Observers.Stopped {
return
}
w.CurrentOperation = operation
for _, varDef := range operation.VariableDefinitions {
varDef.Definition = w.Schema.Types[varDef.Type.Name()]
Expand Down Expand Up @@ -130,6 +140,9 @@ func (w *Walker) walkOperation(operation *ast.OperationDefinition) {
}

func (w *Walker) walkFragment(it *ast.FragmentDefinition) {
if w.Observers != nil && w.Observers.Stopped {
return
}
def := w.Schema.Types[it.TypeCondition]

it.Definition = def
Expand All @@ -143,6 +156,9 @@ func (w *Walker) walkFragment(it *ast.FragmentDefinition) {
}

func (w *Walker) walkDirectives(parentDef *ast.Definition, directives []*ast.Directive, location ast.DirectiveLocation) {
if w.Observers != nil && w.Observers.Stopped {
return
}
for _, dir := range directives {
def := w.Schema.Directives[dir.Name]
dir.Definition = def
Expand All @@ -169,6 +185,9 @@ func (w *Walker) walkDirectives(parentDef *ast.Definition, directives []*ast.Dir
}

func (w *Walker) walkValue(value *ast.Value) {
if w.Observers != nil && w.Observers.Stopped {
return
}
if value.Kind == ast.Variable && w.CurrentOperation != nil {
value.VariableDefinition = w.CurrentOperation.VariableDefinitions.ForName(value.Raw)
if value.VariableDefinition != nil {
Expand Down Expand Up @@ -207,6 +226,9 @@ func (w *Walker) walkValue(value *ast.Value) {
}

func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument) {
if w.Observers != nil && w.Observers.Stopped {
return
}
if argDef != nil {
arg.Value.ExpectedType = argDef.Type
arg.Value.ExpectedTypeHasDefault = argDef.DefaultValue != nil && argDef.DefaultValue.Kind != ast.NullValue
Expand All @@ -217,12 +239,18 @@ func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument)
}

func (w *Walker) walkSelectionSet(parentDef *ast.Definition, it ast.SelectionSet) {
if w.Observers != nil && w.Observers.Stopped {
return
}
for _, child := range it {
w.walkSelection(parentDef, child)
}
}

func (w *Walker) walkSelection(parentDef *ast.Definition, it ast.Selection) {
if w.Observers != nil && w.Observers.Stopped {
return
}
switch it := it.(type) {
case *ast.Field:
var def *ast.FieldDefinition
Expand Down
11 changes: 11 additions & 0 deletions validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ func Validate(schema *Schema, doc *QueryDocument, rules ...Rule) gqlerror.List {
}

func ValidateWithRules(schema *Schema, doc *QueryDocument, rules *validatorrules.Rules) gqlerror.List {
return ValidateWithRulesAndMaximumErrors(schema, doc, rules, 0)
}

func ValidateWithRulesAndMaximumErrors(schema *Schema, doc *QueryDocument, rules *validatorrules.Rules, maximumErrors int) gqlerror.List {
if rules == nil {
rules = validatorrules.NewDefaultRules()
}
Expand All @@ -129,6 +133,9 @@ func ValidateWithRules(schema *Schema, doc *QueryDocument, rules *validatorrules
if doc == nil {
errs = append(errs, gqlerror.Errorf("cannot validate as QueryDocument is nil"))
}
if maximumErrors < 0 {
errs = append(errs, gqlerror.Errorf("maximumErrors cannot be negative"))
}
if len(errs) > 0 {
return errs
}
Expand All @@ -150,6 +157,10 @@ func ValidateWithRules(schema *Schema, doc *QueryDocument, rules *validatorrules
o(err)
}
errs = append(errs, err)

if maximumErrors > 0 && len(errs) >= maximumErrors {
observers.Stopped = true
}
})
}

Expand Down
212 changes: 212 additions & 0 deletions validator/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,215 @@ func TestRemoveRule(t *testing.T) {
// no error
validator.RemoveRule("Rule that should no longer exist")
}

func TestValidateWithRulesAndMaximumErrors(t *testing.T) {
t.Run("maximumErrors limits error count", func(t *testing.T) {
s := gqlparser.MustLoadSchema(
&ast.Source{
Name: "graph/schema.graphqls",
Input: `
type Query {
field1: String!
field2: String!
field3: String!
}
`, BuiltIn: false},
)

q, err := parser.ParseQuery(&ast.Source{
Name: "SomeQuery",
Input: `
query {
field1
field2
field3
}
`})
require.NoError(t, err)

// Create a rule that generates errors for each field
errorRule := validator.Rule{
Name: "ErrorRule",
RuleFunc: func(observers *validator.Events, addError validator.AddErrFunc) {
observers.OnField(func(walker *validator.Walker, field *ast.Field) {
addError(validator.Message("Error for field %s", field.Name))
})
},
}

rules := rules.NewRules(errorRule)
errList := validator.ValidateWithRulesAndMaximumErrors(s, q, rules, 2)

// Should only return 2 errors even though there are 3 fields
require.Len(t, errList, 2)
})

t.Run("maximumErrors zero means no limit", func(t *testing.T) {
s := gqlparser.MustLoadSchema(
&ast.Source{
Name: "graph/schema.graphqls",
Input: `
type Query {
field1: String!
field2: String!
field3: String!
}
`, BuiltIn: false},
)

q, err := parser.ParseQuery(&ast.Source{
Name: "SomeQuery",
Input: `
query {
field1
field2
field3
}
`})
require.NoError(t, err)

// Create a rule that generates errors for each field
errorRule := validator.Rule{
Name: "ErrorRule",
RuleFunc: func(observers *validator.Events, addError validator.AddErrFunc) {
observers.OnField(func(walker *validator.Walker, field *ast.Field) {
addError(validator.Message("Error for field %s", field.Name))
})
},
}

rules := rules.NewRules(errorRule)
errList := validator.ValidateWithRulesAndMaximumErrors(s, q, rules, 0)

// Should return all errors when maximumErrors is 0
require.Len(t, errList, 3)
})

t.Run("negative maximumErrors returns error", func(t *testing.T) {
s := gqlparser.MustLoadSchema(
&ast.Source{
Name: "graph/schema.graphqls",
Input: `
type Query {
field1: String!
}
`, BuiltIn: false},
)

q, err := parser.ParseQuery(&ast.Source{
Name: "SomeQuery",
Input: `
query {
field1
}
`})
require.NoError(t, err)

errList := validator.ValidateWithRulesAndMaximumErrors(s, q, nil, -1)

// Should return an error about negative maximumErrors
require.Len(t, errList, 1)
require.Contains(t, errList[0].Message, "maximumErrors cannot be negative")
})

t.Run("maximumErrors stops traversal early", func(t *testing.T) {
s := gqlparser.MustLoadSchema(
&ast.Source{
Name: "graph/schema.graphqls",
Input: `
type Query {
field1: String!
field2: String!
field3: String!
field4: String!
field5: String!
}
`, BuiltIn: false},
)

q, err := parser.ParseQuery(&ast.Source{
Name: "SomeQuery",
Input: `
query {
field1
field2
field3
field4
field5
}
`})
require.NoError(t, err)

fieldCount := 0
// Create a rule that generates errors and counts fields
errorRule := validator.Rule{
Name: "ErrorRule",
RuleFunc: func(observers *validator.Events, addError validator.AddErrFunc) {
observers.OnField(func(walker *validator.Walker, field *ast.Field) {
fieldCount++
addError(validator.Message("Error for field %s", field.Name))
})
},
}

rules := rules.NewRules(errorRule)
errList := validator.ValidateWithRulesAndMaximumErrors(s, q, rules, 2)

// Should only return 2 errors
require.Len(t, errList, 2)
// Should have stopped traversal early after exactly 2 fields processed
require.Equal(t, 2, fieldCount)
})

t.Run("maximumErrors with multiple rules", func(t *testing.T) {
s := gqlparser.MustLoadSchema(
&ast.Source{
Name: "graph/schema.graphqls",
Input: `
type Query {
field1: String!
field2: String!
}
`, BuiltIn: false},
)

q, err := parser.ParseQuery(&ast.Source{
Name: "SomeQuery",
Input: `
query {
field1
field2
}
`})
require.NoError(t, err)

// Create two rules that each generate errors and count fields
fieldCount := 0
rule1 := validator.Rule{
Name: "Rule1",
RuleFunc: func(observers *validator.Events, addError validator.AddErrFunc) {
observers.OnField(func(walker *validator.Walker, field *ast.Field) {
fieldCount++
addError(validator.Message("Rule1 error for field %s", field.Name))
})
},
}
rule2 := validator.Rule{
Name: "Rule2",
RuleFunc: func(observers *validator.Events, addError validator.AddErrFunc) {
observers.OnField(func(walker *validator.Walker, field *ast.Field) {
fieldCount++
addError(validator.Message("Rule2 error for field %s", field.Name))
})
},
}

rules := rules.NewRules(rule1, rule2)
errList := validator.ValidateWithRulesAndMaximumErrors(s, q, rules, 3)

// Although we set maximumErrors to 3, we expect 4 errors here (2 rules × 2 fields).
// The limit is evaluated after the batch is processed, allowing a final overflow.
require.Equal(t, 4, fieldCount)
require.Equal(t, 4, len(errList))
})
}
20 changes: 20 additions & 0 deletions validator/walk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,23 @@ func TestWalkInlineFragment(t *testing.T) {

require.True(t, called)
}

func TestWalkStoppedEarly(t *testing.T) {
schema, err := LoadSchema(Prelude, &ast.Source{Input: "type Query { name: String, age: Int }\n schema { query: Query }"})
require.NoError(t, err)
query, err := parser.ParseQuery(&ast.Source{Input: "{ name age }"})
require.NoError(t, err)

fieldCount := 0
observers := &Events{}
observers.OnField(func(walker *Walker, field *ast.Field) {
fieldCount++
if fieldCount == 1 {
observers.Stopped = true
}
})

Walk(schema, query, observers)

require.Equal(t, 1, fieldCount)
}
Loading