Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changelog/e7d0a1e41e544ce1a7a9e92b4233ee3f.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "e7d0a1e4-1e54-4ce1-a7a9-e92b4233ee3f",
"type": "bugfix",
"description": "Fixed a panic when deserializing nested unions in JSON- and CBOR-based protocols.",
"modules": [
"."
]
}
8 changes: 8 additions & 0 deletions .changelog/fc69b4a3361441429959c7580656e5fe.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "fc69b4a3-3614-4142-9959-c7580656e5fe",
"type": "bugfix",
"description": "Fixed a deserialization failure in all protocols when encountering a union with explicit null members.",
"modules": [
"."
]
}
8 changes: 3 additions & 5 deletions serde.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,12 @@ type DeserializableError interface {
// ReadUnion is a utility API for generated clients.
func ReadUnion(d ShapeDeserializer, schema *Schema, memberFn func(*Schema) error) error {
ms, err := d.ReadUnion(schema)
if err != nil {
if ms == nil || err != nil {
return err
}

if ms != nil {
if err := memberFn(ms); err != nil {
return err
}
if err := memberFn(ms); err != nil {
return err
}

for {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ func (d *ShapeDeserializer) ReadStructMember() (*smithy.Schema, error) {
// ReadUnion implements [smithy.ShapeDeserializer].
func (d *ShapeDeserializer) ReadUnion(s *smithy.Schema) (*smithy.Schema, error) {
top := d.head.Top()
if top == nil || top.kind != deserCtxUnion { // first call: open the map
if top == nil || top.kind != deserCtxUnion || top.schema != s { // first call: open the map
if d.eof() {
return nil, errUnexpectedEOF
}
Expand Down
56 changes: 56 additions & 0 deletions transport/http/protocol/internal/cbor/union_in_union_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package cbor

import (
"testing"

"github.com/aws/smithy-go"
smithycbor "github.com/aws/smithy-go/encoding/cbor"
"github.com/aws/smithy-go/prelude"
)

var (
testSchemaInnerUnion = smithy.NewSchema(smithy.ShapeID{
Namespace: "com.test", Name: "InnerUnion",
}, smithy.ShapeTypeUnion, 1)

testSchemaOuterUnion = smithy.NewSchema(smithy.ShapeID{
Namespace: "com.test", Name: "OuterUnion",
}, smithy.ShapeTypeUnion, 1)

testSchemaInnerUnion_Lambda *smithy.Schema
testSchemaOuterUnion_Mcp *smithy.Schema
)

func init() {
testSchemaInnerUnion_Lambda = testSchemaInnerUnion.AddMember("lambda", prelude.String)
testSchemaOuterUnion_Mcp = testSchemaOuterUnion.AddMember("mcp", testSchemaInnerUnion)
}

func TestReadUnion_NestedUnionValue(t *testing.T) {
// CBOR equivalent of {"mcp":{"lambda":"arn:aws:lambda:fn"}}
payload := smithycbor.Encode(smithycbor.Map{
"mcp": smithycbor.Map{
"lambda": smithycbor.String("arn:aws:lambda:fn"),
},
})

d := NewShapeDeserializer(payload)

var member, value string
err := smithy.ReadUnion(d, testSchemaOuterUnion, func(ms *smithy.Schema) error {
member = ms.MemberName()
return smithy.ReadUnion(d, testSchemaInnerUnion, func(inner *smithy.Schema) error {
member += "." + inner.MemberName()
return d.ReadString(inner, &value)
})
})
if err != nil {
t.Fatalf("expected success, got error: %v", err)
}
if member != "mcp.lambda" {
t.Errorf("expected member mcp.lambda, got %q", member)
}
if value != "arn:aws:lambda:fn" {
t.Errorf("expected value arn:aws:lambda:fn, got %q", value)
}
}
6 changes: 3 additions & 3 deletions transport/http/protocol/internal/json/shape_deserializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (
"time"

"github.com/aws/smithy-go"
"github.com/aws/smithy-go/transport/http/protocol/internal/json/internal/stdlib"
"github.com/aws/smithy-go/document"
"github.com/aws/smithy-go/internal/serde"
smithytime "github.com/aws/smithy-go/time"
"github.com/aws/smithy-go/traits"
"github.com/aws/smithy-go/transport/http/protocol/internal/json/internal/stdlib"
)

type ctxKind int8
Expand Down Expand Up @@ -500,7 +500,7 @@ func (d *ShapeDeserializer) ReadStructMember() (*smithy.Schema, error) {

// ReadUnion implements [smithy.ShapeDeserializer].
func (d *ShapeDeserializer) ReadUnion(s *smithy.Schema) (*smithy.Schema, error) {
if top := d.head.Top(); top == nil || top.kind != ctxUnion {
if top := d.head.Top(); top == nil || top.kind != ctxUnion || top.schema != s {
if isNil, err := d.ReadNil(s); isNil || err != nil {
return nil, err
}
Expand All @@ -512,7 +512,7 @@ func (d *ShapeDeserializer) ReadUnion(s *smithy.Schema) (*smithy.Schema, error)
if !isLCB(tok) {
return nil, fmt.Errorf("expected '{', got %s", tok)
}
d.head.Push(deserCtx{kind: ctxUnion})
d.head.Push(deserCtx{kind: ctxUnion, schema: s})
}

for {
Expand Down
137 changes: 137 additions & 0 deletions transport/http/protocol/internal/json/union_in_union_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package json

import (
"testing"

"github.com/aws/smithy-go"
"github.com/aws/smithy-go/prelude"
)

// Schemas modeling a union member whose value is itself a union, e.g.
// bedrock-agentcore's TargetConfiguration -> mcp -> McpTargetConfiguration.
var (
testSchemaInnerUnion = smithy.NewSchema(smithy.ShapeID{
Namespace: "com.test", Name: "InnerUnion",
}, smithy.ShapeTypeUnion, 1)

testSchemaOuterUnion = smithy.NewSchema(smithy.ShapeID{
Namespace: "com.test", Name: "OuterUnion",
}, smithy.ShapeTypeUnion, 1)

testSchemaInnerStruct = smithy.NewSchema(smithy.ShapeID{
Namespace: "com.test", Name: "InnerStruct",
}, smithy.ShapeTypeStructure, 1)

testSchemaInnerUnion_Lambda *smithy.Schema
testSchemaOuterUnion_Mcp *smithy.Schema
testSchemaInnerStruct_Value *smithy.Schema
)

func init() {
testSchemaInnerUnion_Lambda = testSchemaInnerUnion.AddMember("lambda", prelude.String)
testSchemaOuterUnion_Mcp = testSchemaOuterUnion.AddMember("mcp", testSchemaInnerUnion)
testSchemaInnerStruct_Value = testSchemaInnerStruct.AddMember("value", prelude.String)
}

// readNestedUnion mimics the calling pattern of SDK-generated code
// (smithy.ReadUnion in serde.go): repeatedly call ReadUnion until it returns
// no member, deserializing each member value in between.
func readNestedUnion(d *ShapeDeserializer) (member string, value string, err error) {
err = smithy.ReadUnion(d, testSchemaOuterUnion, func(ms *smithy.Schema) error {
member = ms.MemberName()
return smithy.ReadUnion(d, testSchemaInnerUnion, func(inner *smithy.Schema) error {
member += "." + inner.MemberName()
return d.ReadString(inner, &value)
})
})
return member, value, err
}

func TestReadUnion_NestedUnionValue(t *testing.T) {
// A union whose member value is itself a union. Before the fix,
// ReadUnion mistook the parent's union context for its own, skipped the
// inner '{' and panicked in memberFromToken (slice bounds [1:0]).
d := NewShapeDeserializer([]byte(`{"mcp":{"lambda":"arn:aws:lambda:fn"}}`))
member, value, err := readNestedUnion(d)
if err != nil {
t.Fatalf("expected success, got error: %v", err)
}
if member != "mcp.lambda" {
t.Errorf("expected member mcp.lambda, got %q", member)
}
if value != "arn:aws:lambda:fn" {
t.Errorf("expected value arn:aws:lambda:fn, got %q", value)
}
}

func TestReadUnion_NestedUnionNullValue(t *testing.T) {
// A union member with a null value is skipped entirely; the member
// callback must not fire.
d := NewShapeDeserializer([]byte(`{"mcp":null}`))
member, _, err := readNestedUnion(d)
if err != nil {
t.Fatalf("expected success, got error: %v", err)
}
if member != "" {
t.Errorf("expected no member, got %q", member)
}
}

func TestReadUnion_FlatUnionStillWorks(t *testing.T) {
// Regression guard: a plain (non-nested) union read.
d := NewShapeDeserializer([]byte(`{"lambda":"v"}`))
var member, value string
err := smithy.ReadUnion(d, testSchemaInnerUnion, func(ms *smithy.Schema) error {
member = ms.MemberName()
return d.ReadString(ms, &value)
})
if err != nil {
t.Fatalf("expected success, got error: %v", err)
}
if member != "lambda" || value != "v" {
t.Errorf("expected lambda/v, got %q/%q", member, value)
}
}

func TestReadUnion_NestedUnionWithStruct(t *testing.T) {
// Union -> union -> struct, matching bedrockagentcorecontrol's
// Action -> ConfigurationBundleAction -> StaticOverride pattern.
schemaDeepOuter := smithy.NewSchema(smithy.ShapeID{
Namespace: "com.test", Name: "DeepOuter",
}, smithy.ShapeTypeUnion, 1)
schemaDeepInner := smithy.NewSchema(smithy.ShapeID{
Namespace: "com.test", Name: "DeepInner",
}, smithy.ShapeTypeUnion, 1)
schemaDeepOuter_inner := schemaDeepOuter.AddMember("inner", schemaDeepInner)
schemaDeepInner_leaf := schemaDeepInner.AddMember("leaf", testSchemaInnerStruct)

payload := []byte(`{"inner": {"leaf": {"value": "hello"}}}`)

d := NewShapeDeserializer(payload)
defer d.Close()

var result string
err := smithy.ReadUnion(d, schemaDeepOuter, func(ms *smithy.Schema) error {
if ms != schemaDeepOuter_inner {
t.Fatalf("unexpected outer member %v", ms)
}
return smithy.ReadUnion(d, schemaDeepInner, func(ms *smithy.Schema) error {
if ms != schemaDeepInner_leaf {
t.Fatalf("unexpected inner member %v", ms)
}
return smithy.ReadStruct(d, testSchemaInnerStruct, func(ms *smithy.Schema) error {
switch ms {
case testSchemaInnerStruct_Value:
return d.ReadString(ms, &result)
}
return nil
})
})
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != "hello" {
t.Fatalf("expected %q, got %q", "hello", result)
}
}
Loading