diff --git a/dig_test.go b/dig_test.go index 647c1494..37e0c8bc 100644 --- a/dig_test.go +++ b/dig_test.go @@ -264,39 +264,63 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("param recurse", func(t *testing.T) { - type anotherParam struct { + type anotherParamEmbedded struct { dig.In Buffer *bytes.Buffer } - type someParam struct { - dig.In + type anotherParam struct { + Reader *bytes.Reader + } - Buffer *bytes.Buffer - Another anotherParam + type someParam struct { + Buffer *bytes.Buffer + Reader *bytes.Reader + AnotherEmbedded anotherParamEmbedded + Another anotherParam } var ( - buff *bytes.Buffer - called bool + buff *bytes.Buffer + reader *bytes.Reader + calledBuffer bool + calledReader bool ) c := digtest.New(t) c.RequireProvide(func() *bytes.Buffer { - require.False(t, called, "constructor must be called exactly once") - called = true + require.False(t, calledBuffer, "constructor must be calledBuffer exactly once") + calledBuffer = true buff = new(bytes.Buffer) return buff }) + c.RequireProvide(func() *bytes.Reader { + require.False(t, calledReader, "constructor must be calledBuffer exactly once") + calledReader = true + reader = new(bytes.Reader) + return reader + }) + + c.RequireProvide(dig.AsIn(someParam{})) + c.RequireProvide(dig.AsIn(reflect.TypeOf(anotherParam{}))) + c.RequireInvoke(func(p someParam) { - require.True(t, called, "constructor must be called first") + require.True(t, calledReader, "constructor must be calledBuffer first") + require.True(t, calledReader, "constructor must be calledReader first") + + require.NotNil(t, p.Buffer, "someParam.Reader must not be nil") + require.NotNil(t, p.Reader, "someParam.Reader must not be nil") + + require.NotNil(t, p.Another.Reader, "anotherParam.Reader must not be nil") + require.True(t, p.Reader == p.Another.Reader, "readers fields must match") + + require.True(t, p.Reader == reader, "buffer must match constructor's return value") - require.NotNil(t, p.Buffer, "someParam.Buffer must not be nil") - require.NotNil(t, p.Another.Buffer, "anotherParam.Buffer must not be nil") + require.NotNil(t, p.AnotherEmbedded.Buffer, "anotherParamEmbedded.Reader must not be nil") + require.True(t, p.Buffer == p.AnotherEmbedded.Buffer, "buffers fields must match") - require.True(t, p.Buffer == p.Another.Buffer, "buffers fields must match") require.True(t, p.Buffer == buff, "buffer must match constructor's return value") }) }) @@ -638,6 +662,13 @@ func TestEndToEndSuccess(t *testing.T) { A1 A `name:"first"` // should come from ret1 through ret2 A2 A `name:"second"` // should come from ret2 } + + type paramAsIn struct { + A1 A `name:"first"` // should come from ret1 through ret2 + A2 A `name:"second"` // should come from ret2 + } + c.RequireProvide(dig.AsIn(paramAsIn{})) + c.RequireProvide(func() Ret2 { return Ret2{ Ret1: Ret1{ @@ -651,6 +682,11 @@ func TestEndToEndSuccess(t *testing.T) { assert.Equal(t, 1, p.A1.idx) assert.Equal(t, 2, p.A2.idx) }) + + c.RequireInvoke(func(p paramAsIn) { + assert.Equal(t, 1, p.A1.idx) + assert.Equal(t, 2, p.A2.idx) + }) }) t.Run("named instances do not cause cycles", func(t *testing.T) { @@ -709,7 +745,7 @@ func TestEndToEndSuccess(t *testing.T) { require.Error(t, c.Invoke(func(*bytes.Buffer) { t.Fatalf("must not be called") - }), "must not have a *bytes.Buffer in the container") + }), "must not have a *bytes.Reader in the container") }) t.Run("As with Name", func(t *testing.T) { diff --git a/inout.go b/inout.go index 3d575842..d71de608 100644 --- a/inout.go +++ b/inout.go @@ -158,6 +158,89 @@ func embedsType(i interface{}, e reflect.Type) bool { return false } +// AsIn marks struct as In by creating reflect.StructOf. +func AsIn(i any) any { + t, ok := inType(i) + if !ok { + return nil + } + + embeddingType := reflect.TypeOf(embeddingIn(t)) + fnType := reflect.FuncOf([]reflect.Type{embeddingType}, []reflect.Type{t}, false) + + fn := reflect.MakeFunc(fnType, func(args []reflect.Value) []reflect.Value { + in := args[0] + out := reflect.New(t).Elem() + + outIndex := 0 + for inIndex := 0; inIndex < in.NumField(); inIndex++ { + if in.Field(inIndex).Type() == _inType { + continue + } + + out.Field(outIndex).Set(in.Field(inIndex)) + outIndex++ + } + + return []reflect.Value{out} + }) + + return fn.Interface() +} + +func embeddingIn(t reflect.Type) any { + return embedding(t, "In", _inType) +} + +func embedding(i any, name string, _type reflect.Type) any { + t, ok := inType(i) + if !ok { + return nil + } + + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + return nil + } + + // Build fields: start with embedded In + fields := make([]reflect.StructField, 0, t.NumField()+1) + fields = append(fields, reflect.StructField{ + Name: name, + Type: _type, + Anonymous: true, + }) + + // Add all original fields + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + fields = append(fields, reflect.StructField{ + Name: f.Name, + Type: f.Type, + Tag: f.Tag, + }) + } + + newType := reflect.StructOf(fields) + return reflect.New(newType).Elem().Interface() +} + +func inType(i any) (reflect.Type, bool) { + if i == nil { + return nil, false + } + + t, ok := i.(reflect.Type) + if !ok { + t = reflect.TypeOf(i) + } + + return t, true +} + // Checks if a field of an In struct is optional. func isFieldOptional(f reflect.StructField) (bool, error) { tag := f.Tag.Get(_optionalTag)