diff --git a/.changelog/e7d0a1e41e544ce1a7a9e92b4233ee3f.json b/.changelog/e7d0a1e41e544ce1a7a9e92b4233ee3f.json new file mode 100644 index 00000000..246076ff --- /dev/null +++ b/.changelog/e7d0a1e41e544ce1a7a9e92b4233ee3f.json @@ -0,0 +1,8 @@ +{ + "id": "e7d0a1e4-1e54-4ce1-a7a9-e92b4233ee3f", + "type": "bugfix", + "description": "Fixed a panic when deserializing nested unions in JSON- and CBOR-based protocols.", + "modules": [ + "." + ] +} diff --git a/.changelog/fc69b4a3361441429959c7580656e5fe.json b/.changelog/fc69b4a3361441429959c7580656e5fe.json new file mode 100644 index 00000000..3d85c137 --- /dev/null +++ b/.changelog/fc69b4a3361441429959c7580656e5fe.json @@ -0,0 +1,8 @@ +{ + "id": "fc69b4a3-3614-4142-9959-c7580656e5fe", + "type": "bugfix", + "description": "Fixed a deserialization failure in all protocols when encountering a union with explicit null members.", + "modules": [ + "." + ] +} \ No newline at end of file diff --git a/serde.go b/serde.go index dacc4390..a9effc56 100644 --- a/serde.go +++ b/serde.go @@ -134,14 +134,12 @@ type DeserializableError interface { // ReadUnion is a utility API for generated clients. func ReadUnion(d ShapeDeserializer, schema *Schema, memberFn func(*Schema) error) error { ms, err := d.ReadUnion(schema) - if err != nil { + if ms == nil || err != nil { return err } - if ms != nil { - if err := memberFn(ms); err != nil { - return err - } + if err := memberFn(ms); err != nil { + return err } for { diff --git a/transport/http/protocol/internal/cbor/shape_deserializer.go b/transport/http/protocol/internal/cbor/shape_deserializer.go index eac54c11..133dab21 100644 --- a/transport/http/protocol/internal/cbor/shape_deserializer.go +++ b/transport/http/protocol/internal/cbor/shape_deserializer.go @@ -523,7 +523,7 @@ func (d *ShapeDeserializer) ReadStructMember() (*smithy.Schema, error) { // ReadUnion implements [smithy.ShapeDeserializer]. func (d *ShapeDeserializer) ReadUnion(s *smithy.Schema) (*smithy.Schema, error) { top := d.head.Top() - if top == nil || top.kind != deserCtxUnion { // first call: open the map + if top == nil || top.kind != deserCtxUnion || top.schema != s { // first call: open the map if d.eof() { return nil, errUnexpectedEOF } diff --git a/transport/http/protocol/internal/cbor/union_in_union_test.go b/transport/http/protocol/internal/cbor/union_in_union_test.go new file mode 100644 index 00000000..a3d59b58 --- /dev/null +++ b/transport/http/protocol/internal/cbor/union_in_union_test.go @@ -0,0 +1,56 @@ +package cbor + +import ( + "testing" + + "github.com/aws/smithy-go" + smithycbor "github.com/aws/smithy-go/encoding/cbor" + "github.com/aws/smithy-go/prelude" +) + +var ( + testSchemaInnerUnion = smithy.NewSchema(smithy.ShapeID{ + Namespace: "com.test", Name: "InnerUnion", + }, smithy.ShapeTypeUnion, 1) + + testSchemaOuterUnion = smithy.NewSchema(smithy.ShapeID{ + Namespace: "com.test", Name: "OuterUnion", + }, smithy.ShapeTypeUnion, 1) + + testSchemaInnerUnion_Lambda *smithy.Schema + testSchemaOuterUnion_Mcp *smithy.Schema +) + +func init() { + testSchemaInnerUnion_Lambda = testSchemaInnerUnion.AddMember("lambda", prelude.String) + testSchemaOuterUnion_Mcp = testSchemaOuterUnion.AddMember("mcp", testSchemaInnerUnion) +} + +func TestReadUnion_NestedUnionValue(t *testing.T) { + // CBOR equivalent of {"mcp":{"lambda":"arn:aws:lambda:fn"}} + payload := smithycbor.Encode(smithycbor.Map{ + "mcp": smithycbor.Map{ + "lambda": smithycbor.String("arn:aws:lambda:fn"), + }, + }) + + d := NewShapeDeserializer(payload) + + var member, value string + err := smithy.ReadUnion(d, testSchemaOuterUnion, func(ms *smithy.Schema) error { + member = ms.MemberName() + return smithy.ReadUnion(d, testSchemaInnerUnion, func(inner *smithy.Schema) error { + member += "." + inner.MemberName() + return d.ReadString(inner, &value) + }) + }) + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + if member != "mcp.lambda" { + t.Errorf("expected member mcp.lambda, got %q", member) + } + if value != "arn:aws:lambda:fn" { + t.Errorf("expected value arn:aws:lambda:fn, got %q", value) + } +} diff --git a/transport/http/protocol/internal/json/shape_deserializer.go b/transport/http/protocol/internal/json/shape_deserializer.go index b1337833..ff21e308 100644 --- a/transport/http/protocol/internal/json/shape_deserializer.go +++ b/transport/http/protocol/internal/json/shape_deserializer.go @@ -11,11 +11,11 @@ import ( "time" "github.com/aws/smithy-go" - "github.com/aws/smithy-go/transport/http/protocol/internal/json/internal/stdlib" "github.com/aws/smithy-go/document" "github.com/aws/smithy-go/internal/serde" smithytime "github.com/aws/smithy-go/time" "github.com/aws/smithy-go/traits" + "github.com/aws/smithy-go/transport/http/protocol/internal/json/internal/stdlib" ) type ctxKind int8 @@ -500,7 +500,7 @@ func (d *ShapeDeserializer) ReadStructMember() (*smithy.Schema, error) { // ReadUnion implements [smithy.ShapeDeserializer]. func (d *ShapeDeserializer) ReadUnion(s *smithy.Schema) (*smithy.Schema, error) { - if top := d.head.Top(); top == nil || top.kind != ctxUnion { + if top := d.head.Top(); top == nil || top.kind != ctxUnion || top.schema != s { if isNil, err := d.ReadNil(s); isNil || err != nil { return nil, err } @@ -512,7 +512,7 @@ func (d *ShapeDeserializer) ReadUnion(s *smithy.Schema) (*smithy.Schema, error) if !isLCB(tok) { return nil, fmt.Errorf("expected '{', got %s", tok) } - d.head.Push(deserCtx{kind: ctxUnion}) + d.head.Push(deserCtx{kind: ctxUnion, schema: s}) } for { diff --git a/transport/http/protocol/internal/json/union_in_union_test.go b/transport/http/protocol/internal/json/union_in_union_test.go new file mode 100644 index 00000000..c4409740 --- /dev/null +++ b/transport/http/protocol/internal/json/union_in_union_test.go @@ -0,0 +1,137 @@ +package json + +import ( + "testing" + + "github.com/aws/smithy-go" + "github.com/aws/smithy-go/prelude" +) + +// Schemas modeling a union member whose value is itself a union, e.g. +// bedrock-agentcore's TargetConfiguration -> mcp -> McpTargetConfiguration. +var ( + testSchemaInnerUnion = smithy.NewSchema(smithy.ShapeID{ + Namespace: "com.test", Name: "InnerUnion", + }, smithy.ShapeTypeUnion, 1) + + testSchemaOuterUnion = smithy.NewSchema(smithy.ShapeID{ + Namespace: "com.test", Name: "OuterUnion", + }, smithy.ShapeTypeUnion, 1) + + testSchemaInnerStruct = smithy.NewSchema(smithy.ShapeID{ + Namespace: "com.test", Name: "InnerStruct", + }, smithy.ShapeTypeStructure, 1) + + testSchemaInnerUnion_Lambda *smithy.Schema + testSchemaOuterUnion_Mcp *smithy.Schema + testSchemaInnerStruct_Value *smithy.Schema +) + +func init() { + testSchemaInnerUnion_Lambda = testSchemaInnerUnion.AddMember("lambda", prelude.String) + testSchemaOuterUnion_Mcp = testSchemaOuterUnion.AddMember("mcp", testSchemaInnerUnion) + testSchemaInnerStruct_Value = testSchemaInnerStruct.AddMember("value", prelude.String) +} + +// readNestedUnion mimics the calling pattern of SDK-generated code +// (smithy.ReadUnion in serde.go): repeatedly call ReadUnion until it returns +// no member, deserializing each member value in between. +func readNestedUnion(d *ShapeDeserializer) (member string, value string, err error) { + err = smithy.ReadUnion(d, testSchemaOuterUnion, func(ms *smithy.Schema) error { + member = ms.MemberName() + return smithy.ReadUnion(d, testSchemaInnerUnion, func(inner *smithy.Schema) error { + member += "." + inner.MemberName() + return d.ReadString(inner, &value) + }) + }) + return member, value, err +} + +func TestReadUnion_NestedUnionValue(t *testing.T) { + // A union whose member value is itself a union. Before the fix, + // ReadUnion mistook the parent's union context for its own, skipped the + // inner '{' and panicked in memberFromToken (slice bounds [1:0]). + d := NewShapeDeserializer([]byte(`{"mcp":{"lambda":"arn:aws:lambda:fn"}}`)) + member, value, err := readNestedUnion(d) + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + if member != "mcp.lambda" { + t.Errorf("expected member mcp.lambda, got %q", member) + } + if value != "arn:aws:lambda:fn" { + t.Errorf("expected value arn:aws:lambda:fn, got %q", value) + } +} + +func TestReadUnion_NestedUnionNullValue(t *testing.T) { + // A union member with a null value is skipped entirely; the member + // callback must not fire. + d := NewShapeDeserializer([]byte(`{"mcp":null}`)) + member, _, err := readNestedUnion(d) + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + if member != "" { + t.Errorf("expected no member, got %q", member) + } +} + +func TestReadUnion_FlatUnionStillWorks(t *testing.T) { + // Regression guard: a plain (non-nested) union read. + d := NewShapeDeserializer([]byte(`{"lambda":"v"}`)) + var member, value string + err := smithy.ReadUnion(d, testSchemaInnerUnion, func(ms *smithy.Schema) error { + member = ms.MemberName() + return d.ReadString(ms, &value) + }) + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + if member != "lambda" || value != "v" { + t.Errorf("expected lambda/v, got %q/%q", member, value) + } +} + +func TestReadUnion_NestedUnionWithStruct(t *testing.T) { + // Union -> union -> struct, matching bedrockagentcorecontrol's + // Action -> ConfigurationBundleAction -> StaticOverride pattern. + schemaDeepOuter := smithy.NewSchema(smithy.ShapeID{ + Namespace: "com.test", Name: "DeepOuter", + }, smithy.ShapeTypeUnion, 1) + schemaDeepInner := smithy.NewSchema(smithy.ShapeID{ + Namespace: "com.test", Name: "DeepInner", + }, smithy.ShapeTypeUnion, 1) + schemaDeepOuter_inner := schemaDeepOuter.AddMember("inner", schemaDeepInner) + schemaDeepInner_leaf := schemaDeepInner.AddMember("leaf", testSchemaInnerStruct) + + payload := []byte(`{"inner": {"leaf": {"value": "hello"}}}`) + + d := NewShapeDeserializer(payload) + defer d.Close() + + var result string + err := smithy.ReadUnion(d, schemaDeepOuter, func(ms *smithy.Schema) error { + if ms != schemaDeepOuter_inner { + t.Fatalf("unexpected outer member %v", ms) + } + return smithy.ReadUnion(d, schemaDeepInner, func(ms *smithy.Schema) error { + if ms != schemaDeepInner_leaf { + t.Fatalf("unexpected inner member %v", ms) + } + return smithy.ReadStruct(d, testSchemaInnerStruct, func(ms *smithy.Schema) error { + switch ms { + case testSchemaInnerStruct_Value: + return d.ReadString(ms, &result) + } + return nil + }) + }) + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "hello" { + t.Fatalf("expected %q, got %q", "hello", result) + } +}