diff --git a/parser/argument.go b/parser/argument.go index 93420b1..0e68c3f 100644 --- a/parser/argument.go +++ b/parser/argument.go @@ -28,10 +28,6 @@ func NewArgument(input *Input, field reflect.StructField) (*Argument, error) { defaultValue: field.Tag.Get("default"), } - if err := validateArgumentType(argument); err != nil { - return nil, err - } - type_, err := getOrCreateArgumentType(field.Type) if err != nil { return nil, err @@ -79,45 +75,25 @@ func (arg *Argument) StructField() reflect.StructField { return arg.structField } -func validateArgumentType(arg *Argument) error { - kind, err := getTypeKind(arg.structField.Type) +func getOrCreateArgumentType(t reflect.Type) (Type, error) { + unupportedErr := fmt.Errorf("interface and union not supported for argument type") + kind, err := getTypeKind(t) if err != nil { - return err - } - - switch kind { - case KindInterface, KindUnion, KindInterfaceDefinition: - return fmt.Errorf( - "argument type %s not supported for field %s on struct %s \nif you think this is a mistake please open an issue at github.com/shreyas44/groot", - arg.structField.Type.Name(), - arg.structField.Name, - arg.Input().reflectType.Name(), - ) + return nil, err } - return nil -} - -func getOrCreateArgumentType(t reflect.Type) (Type, error) { parserType, ok := cache.get(t) if ok { - kind, err := getTypeKind(t) - if err != nil { - return nil, err - } - switch kind { + case KindObject: + if _, ok := parserType.(*Input); ok { + return parserType, nil + } case KindInterface, KindUnion, KindInterfaceDefinition: - err := fmt.Errorf("") - return nil, err + return nil, unupportedErr + default: + return parserType, nil } - - return parserType, nil - } - - kind, err := getTypeKind(t) - if err != nil { - return nil, err } switch kind { @@ -132,7 +108,7 @@ func getOrCreateArgumentType(t reflect.Type) (Type, error) { case KindNullable: return NewNullable(t, true) case KindInterface, KindUnion, KindInterfaceDefinition: - return nil, fmt.Errorf("interface and union not supported for argument type") + return nil, unupportedErr } panic("parser: unexpected error occurred") diff --git a/parser/argument_test.go b/parser/argument_test.go new file mode 100644 index 0000000..49c244c --- /dev/null +++ b/parser/argument_test.go @@ -0,0 +1,231 @@ +package parser + +import ( + "errors" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type ArgTestCustomScalar string +type ArgTestEnum string +type ArgTestEmptyInput struct{} +type ArgTestUnionMember1 struct{} +type ArgTestUnionMember2 struct{} + +type ArgTestUnion struct { + UnionType + ArgTestUnionMember1 + ArgTestUnionMember2 +} + +type ArgTestInterfaceDefinition struct { + InterfaceType +} + +type ArgTestInterface interface { + ImplementsArgTestInterface() ArgTestInterfaceDefinition +} + +type ArgTestInput struct { + StringArg string `json:"stringArg"` + NilJsonArg string `json:"-"` + ArgWithoutJSON string + //lint:ignore U1000 argument is used through reflection + unexportedArg string +} + +const ( + ArgTestEnum_One ArgTestEnum = "One" + ArgTestEnum_Two ArgTestEnum = "Two" +) + +func (e ArgTestEnum) Values() []string { + return []string{ + string(ArgTestEnum_One), + string(ArgTestEnum_Two), + } +} + +func TestNewArgument(t *testing.T) { + inputType := reflect.TypeOf(ArgTestInput{}) + stringType := reflect.TypeOf("") + stringArg, _ := inputType.FieldByName("StringArg") + input := &Input{reflectType: inputType} + + arg, err := NewArgument(input, stringArg) + require.Nil(t, err) + + expectedArg := &Argument{ + input: input, + structField: stringArg, + type_: &Scalar{stringType}, + jsonName: "stringArg", + } + + assert.Equal(t, expectedArg, arg) + + t.Run("TestNilArgumentReturned", func(t *testing.T) { + testCases := []struct { + name string + fieldName string + }{ + { + name: "WithUnexportedArg", + fieldName: "unexportedArg", + }, + { + name: "WithNilJsonArg", + fieldName: "NilJsonArg", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + field, _ := inputType.FieldByName(testCase.fieldName) + arg, err := NewArgument(input, field) + assert.Nil(t, arg) + assert.Nil(t, err) + }) + } + }) + + t.Run("TestJSONName", func(t *testing.T) { + testCases := []struct { + name string + fieldName string + expectedJSONName string + }{ + { + name: "WithJSONStructTag", + fieldName: "StringArg", + expectedJSONName: "stringArg", + }, + { + name: "WithoutJSONStructTag", + fieldName: "ArgWithoutJSON", + expectedJSONName: "ArgWithoutJSON", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + field, _ := inputType.FieldByName(testCase.fieldName) + arg, err := NewArgument(input, field) + assert.Nil(t, err) + assert.Equal(t, testCase.expectedJSONName, arg.JSONName()) + }) + } + + }) +} + +func TestGetOrCreateArgumentType(t *testing.T) { + var ( + stringType = reflect.TypeOf("") + customScalarType = reflect.TypeOf(ArgTestCustomScalar("")) + enumType = reflect.TypeOf(ArgTestEnum_One) + listType = reflect.TypeOf([]string{}) + structType = reflect.TypeOf(ArgTestEmptyInput{}) + interfaceType = reflect.TypeOf((*ArgTestInterface)(nil)).Elem() + interfaceDefType = reflect.TypeOf(ArgTestInterfaceDefinition{}) + nullableStringType = reflect.TypeOf((*string)(nil)) + unionType = reflect.TypeOf(ArgTestUnion{}) + unsupportedErr = errors.New("interface and union not supported for argument type") + testCases = []struct { + name string + typ reflect.Type + expectedErr error + expectedType Type + }{ + { + name: "Scalar", + typ: stringType, + expectedErr: nil, + expectedType: &Scalar{stringType}, + }, + { + name: "CustomScalar", + typ: customScalarType, + expectedErr: nil, + expectedType: &Scalar{customScalarType}, + }, + { + name: "Enum", + typ: enumType, + expectedErr: nil, + expectedType: &Enum{enumType, ArgTestEnum_One.Values()}, + }, + { + name: "List", + typ: listType, + expectedErr: nil, + expectedType: &Array{listType, &Scalar{stringType}}, + }, + { + name: "Input", + typ: structType, + expectedErr: nil, + expectedType: &Input{structType, nil, []*Argument{}}, + }, + { + name: "Nullable", + typ: nullableStringType, + expectedErr: nil, + expectedType: &Nullable{nullableStringType, &Scalar{stringType}}, + }, + { + name: "Interface", + typ: interfaceType, + expectedErr: unsupportedErr, + expectedType: nil, + }, + { + name: "InterfaceDefinition", + typ: interfaceDefType, + expectedErr: unsupportedErr, + expectedType: nil, + }, + { + name: "Union", + typ: unionType, + expectedErr: unsupportedErr, + expectedType: nil, + }, + } + ) + + t.Run("WithEmptyCache", func(t *testing.T) { + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + resetCache() + typ, err := getOrCreateArgumentType(testCase.typ) + assert.Equal(t, testCase.expectedErr, err) + assert.Equal(t, testCase.expectedType, typ) + }) + } + }) + + t.Run("WithCacheContainingFieldTypes", func(t *testing.T) { + resetCache() + + // fill cache + cache.set(stringType, &Scalar{stringType}) + cache.set(structType, &Object{structType, []*Field{}, []*Interface{}}) + cache.set(interfaceType, &Interface{interfaceType, []*Field{}}) + cache.set(unionType, &Union{unionType, []*Object{ + {reflect.TypeOf(ArgTestUnionMember1{}), []*Field{}, []*Interface{}}, + {reflect.TypeOf(ArgTestUnionMember2{}), []*Field{}, []*Interface{}}, + }}) + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + typ, err := getOrCreateArgumentType(testCase.typ) + assert.Equal(t, testCase.expectedErr, err) + assert.Equal(t, testCase.expectedType, typ) + }) + } + }) +} diff --git a/parser/array_test.go b/parser/array_test.go new file mode 100644 index 0000000..9ac60a9 --- /dev/null +++ b/parser/array_test.go @@ -0,0 +1,40 @@ +package parser + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type ExampleArrayElem struct{} + +func TestParsedArray(t *testing.T) { + var ( + stringListType = reflect.TypeOf([]string{}) + stringType = reflect.TypeOf("") + structListType = reflect.TypeOf([]ExampleArrayElem{}) + structElem = structListType.Elem() + + testCases = []struct { + name string + isArg bool + reflectTyp reflect.Type + expectedType Type + }{ + {"FieldWithStructElement", false, structListType, &Array{structListType, &Object{structElem, []*Field{}, []*Interface{}}}}, + {"ArgWithStructElement", true, structListType, &Array{structListType, &Input{structElem, nil, []*Argument{}}}}, + {"FieldWithStringElement", false, stringListType, &Array{stringListType, &Scalar{stringType}}}, + {"ArgWithStringElement", true, stringListType, &Array{stringListType, &Scalar{stringType}}}, + } + ) + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + typ, err := NewArray(testCase.reflectTyp, testCase.isArg) + require.Nil(t, err) + assert.Equal(t, testCase.expectedType, typ, testCase) + }) + } +} diff --git a/parser/cache_test.go b/parser/cache_test.go new file mode 100644 index 0000000..b185b0f --- /dev/null +++ b/parser/cache_test.go @@ -0,0 +1,33 @@ +package parser + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func resetCache() { + cache = map[reflect.Type]Type{} +} + +func TestCache(t *testing.T) { + resetCache() + stringType := reflect.TypeOf("") + intType := reflect.TypeOf(0) + stringScalar := &Scalar{stringType} + + t.Run("CacheExists", func(t *testing.T) { + cache.set(stringType, stringScalar) + + cacheVal, exists := cache.get(stringType) + assert.Equal(t, stringScalar, cacheVal) + assert.True(t, exists) + }) + + t.Run("CacheNotExists", func(t *testing.T) { + cacheVal, exists := cache.get(intType) + assert.Nil(t, cacheVal) + assert.False(t, exists) + }) +} diff --git a/parser/enum_test.go b/parser/enum_test.go new file mode 100644 index 0000000..57f70a8 --- /dev/null +++ b/parser/enum_test.go @@ -0,0 +1,54 @@ +package parser + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type ExampleEnumWithValues string +type ExampleEnumWithoutValues string + +const ( + ExampleEnumWithValues_One ExampleEnumWithValues = "ONE" + ExampleEnumWithValues_Two ExampleEnumWithValues = "TWO" +) + +const ( + ExampleEnumWithoutValues_One ExampleEnumWithoutValues = "ONE" + ExampleEnumWithoutValues_Two ExampleEnumWithoutValues = "TWO" +) + +func (ExampleEnumWithValues) Values() []string { + return []string{ + string(ExampleEnumWithValues_One), + string(ExampleEnumWithValues_Two), + } +} + +func TestParsedEnumWithoutValues(t *testing.T) { + defer func() { recover() }() + + reflectT := reflect.TypeOf(ExampleEnumWithoutValues("")) + NewEnum(reflectT) + + t.Error("expected enum without Values method to panic") +} + +func TestParsedEnumWithValues(t *testing.T) { + reflectT := reflect.TypeOf(ExampleEnumWithValues("")) + enum, err := NewEnum(reflectT) + require.Nil(t, err) + + expectedEnum := &Enum{ + reflectT, + []string{ + string(ExampleEnumWithValues_One), + string(ExampleEnumWithValues_Two), + }, + } + + assert.Equal(t, expectedEnum, enum) +} diff --git a/parser/field_test.go b/parser/field_test.go new file mode 100644 index 0000000..1afae4f --- /dev/null +++ b/parser/field_test.go @@ -0,0 +1,7 @@ +package parser + +import "testing" + +func TestXxx(t *testing.T) { + +} diff --git a/parser/type.go b/parser/type.go index 5434cda..f25c523 100644 --- a/parser/type.go +++ b/parser/type.go @@ -62,16 +62,23 @@ func validateTypeKind(t reflect.Type, expected ...Kind) error { } func getOrCreateType(t reflect.Type) (Type, error) { - parserType, ok := cache.get(t) - if ok { - return parserType, nil - } - kind, err := getTypeKind(t) if err != nil { return nil, err } + parserType, ok := cache.get(t) + if ok { + switch kind { + case KindObject: + if _, ok := parserType.(*Object); ok { + return parserType, nil + } + default: + return parserType, nil + } + } + switch kind { case KindScalar, KindCustomScalar: return NewScalar(t) diff --git a/parser/validator.go b/parser/validator.go index a17fc63..7beba53 100644 --- a/parser/validator.go +++ b/parser/validator.go @@ -55,8 +55,10 @@ func (v *InputValidator) ReflectMethod() reflect.Method { func validateArgValidator(method reflect.Method, arg *Argument) error { errorInterface := reflect.TypeOf((*error)(nil)).Elem() + numIn := method.Type.NumIn() + numOut := method.Type.NumOut() - if method.Type.NumIn() != 2 || method.Type.In(1) != arg.structField.Type { + if numIn != 2 || (numIn > 1 && method.Type.In(1) != arg.structField.Type) { return fmt.Errorf( "method %s on struct %s expected to have 1 argument of type (%s)", method.Name, @@ -65,7 +67,7 @@ func validateArgValidator(method reflect.Method, arg *Argument) error { ) } - if method.Type.NumOut() != 1 || method.Type.Out(0) != errorInterface { + if numOut != 1 || (numOut > 0 && method.Type.Out(0) != errorInterface) { return fmt.Errorf( "method %s on struct %s expected to return only error", method.Name, @@ -97,5 +99,3 @@ func validateInputValidator(method reflect.Method) error { return nil } - -// TODO: validate return type diff --git a/parser/validator_test.go b/parser/validator_test.go new file mode 100644 index 0000000..ca82f73 --- /dev/null +++ b/parser/validator_test.go @@ -0,0 +1,180 @@ +package parser + +import ( + "errors" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +type InputWithValidValidator struct{} +type InputWithNoValidator struct{} +type InputWithInvalidValidator struct{} + +func (InputWithValidValidator) Validate() error { + return nil +} + +// should return error +func (InputWithInvalidValidator) Validate() {} + +type ArgValidatorTestInput struct { + ArgWithValidator string + ArgWithoutValidator string + ArgWithInvalidValidatorNoReturn string + ArgWithInvalidValidatorGreaterReturnCount string + ArgWithInvalidValidatorInvalidReturnType string + ArgWithInvalidValidatorNoArg string + ArgWithInvalidValidatorGreaterArgCount string + ArgWithInvalidValidatorInvalidArgType string +} + +func (ArgValidatorTestInput) ValidateArgWithValidator(stringArg string) error { + return nil +} + +func (ArgValidatorTestInput) ValidateArgWithInvalidValidatorNoReturn(stringArg string) {} + +func (ArgValidatorTestInput) ValidateArgWithInvalidValidatorGreaterReturnCount(stringArg string) (string, error) { + return "", nil +} + +func (ArgValidatorTestInput) ValidateArgWithInvalidValidatorInvalidReturnType(stringArg string) string { + return "" +} + +func (ArgValidatorTestInput) ValidateArgWithInvalidValidatorNoArg() error { + return nil +} + +func (ArgValidatorTestInput) ValidateArgWithInvalidValidatorGreaterArgCount(stringArg string, intArg int) error { + return nil +} + +func (ArgValidatorTestInput) ValidateArgWithInvalidValidatorInvalidArgType(intArg int) error { + return nil +} + +func TestNewArgumentValidator(t *testing.T) { + inputReflectTyp := reflect.TypeOf(ArgValidatorTestInput{}) + input := &Input{reflectType: inputReflectTyp} + setup := func(fieldName string) (reflect.Method, *Argument, *ArgumentValidator, error) { + field, _ := inputReflectTyp.FieldByName(fieldName) + method, _ := inputReflectTyp.MethodByName("Validate" + fieldName) + arg := &Argument{structField: field, input: input} + + validator, err := NewArgumentValidator(arg) + return method, arg, validator, err + } + + t.Run("ReturnsValidator", func(t *testing.T) { + method, arg, validator, err := setup("ArgWithValidator") + assert.Equal(t, method.Type, validator.reflectMethod.Type) + assert.Equal(t, arg, validator.argument) + assert.Nil(t, err) + }) + + t.Run("ReturnsNilValidatorNilErr", func(t *testing.T) { + _, _, validator, err := setup("ArgWithoutValidator") + assert.Nil(t, validator) + assert.Nil(t, err) + }) + + t.Run("ReturnsErr", func(t *testing.T) { + _, _, validator, err := setup("ArgWithInvalidValidatorNoReturn") + assert.NotNil(t, err) + assert.Nil(t, validator) + }) +} + +func TestNewInputValidator(t *testing.T) { + setup := func(typ interface{}) (reflect.Type, *Input, *InputValidator, error) { + inputType := reflect.TypeOf(typ) + input := &Input{reflectType: inputType} + validator, err := NewInputValidator(input) + return inputType, input, validator, err + } + + t.Run("ReturnsValidator", func(t *testing.T) { + inputType, input, validator, err := setup(InputWithValidValidator{}) + assert.Equal(t, inputType.Method(0).Type, validator.reflectMethod.Type) + assert.Equal(t, input, validator.input) + assert.Nil(t, err) + }) + + t.Run("ReturnsNilValidatorNilErr", func(t *testing.T) { + _, _, validator, err := setup(InputWithNoValidator{}) + assert.Nil(t, validator) + assert.Nil(t, err) + }) + + t.Run("ReturnsErr", func(t *testing.T) { + _, _, validator, err := setup(InputWithInvalidValidator{}) + assert.Nil(t, validator) + assert.NotNil(t, err) + }) +} + +func TestValidateArgsValidator(t *testing.T) { + inputTyp := reflect.TypeOf(ArgValidatorTestInput{}) + testCases := []struct { + name string + fieldName string + expectedErr error + }{ + { + name: "NoErr", + fieldName: "ArgWithValidator", + expectedErr: nil, + }, + { + name: "NoReturnValues", + fieldName: "ArgWithInvalidValidatorNoReturn", + expectedErr: errors.New(""), + }, + { + name: "TooManyReturnValues", + fieldName: "ArgWithInvalidValidatorGreaterReturnCount", + expectedErr: errors.New(""), + }, + { + name: "InvalidReturnType", + fieldName: "ArgWithInvalidValidatorInvalidReturnType", + expectedErr: errors.New(""), + }, + { + name: "NoArguments", + fieldName: "ArgWithInvalidValidatorNoArg", + expectedErr: errors.New(""), + }, + { + name: "TooManyArguments", + fieldName: "ArgWithInvalidValidatorGreaterArgCount", + expectedErr: errors.New(""), + }, + { + name: "InvalidArgumentType", + fieldName: "ArgWithInvalidValidatorInvalidArgType", + expectedErr: errors.New(""), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + field, _ := inputTyp.FieldByName(tc.fieldName) + method, _ := inputTyp.MethodByName("Validate" + tc.fieldName) + + input := &Input{reflectType: inputTyp} + arg := &Argument{structField: field, input: input} + + _ = validateArgValidator(method, arg) + // TODO: assert + // assert.Equal(t, tc.expectedErr, err) + }) + } +} + +func TestValidateInputValidator(t *testing.T) { + +}