Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions serde.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,15 @@ func ReadUnion(d ShapeDeserializer, schema *Schema, memberFn func(*Schema) error
return err
}

if ms != nil {
if err := memberFn(ms); err != nil {
return err
}
// A nil member on the first read means the union value was null, empty,
// or contained only null members -- the union has been fully consumed and
// reading again would read past it.
if ms == nil {
return nil
}

if err := memberFn(ms); err != nil {
return err
}

for {
Expand Down
19 changes: 17 additions & 2 deletions transport/http/protocol/internal/json/shape_deserializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (
"time"

"github.com/aws/smithy-go"
"github.com/aws/smithy-go/transport/http/protocol/internal/json/internal/stdlib"
"github.com/aws/smithy-go/document"
"github.com/aws/smithy-go/internal/serde"
smithytime "github.com/aws/smithy-go/time"
"github.com/aws/smithy-go/traits"
"github.com/aws/smithy-go/transport/http/protocol/internal/json/internal/stdlib"
)

type ctxKind int8
Expand Down Expand Up @@ -500,7 +500,19 @@ func (d *ShapeDeserializer) ReadStructMember() (*smithy.Schema, error) {

// ReadUnion implements [smithy.ShapeDeserializer].
func (d *ShapeDeserializer) ReadUnion(s *smithy.Schema) (*smithy.Schema, error) {
if top := d.head.Top(); top == nil || top.kind != ctxUnion {
resuming := false
if top := d.head.Top(); top != nil && top.kind == ctxUnion {
// The context on top of the stack may belong to a parent union when
// this union is itself the value of a union member. Disambiguate by
// peeking: at a value position the next token can only be '{' (or
// null), while on resume it can only be a member name or '}'.
tok, err := d.peek()
if err != nil {
return nil, err
}
resuming = !isLCB(tok) && !isN(tok)
}
if !resuming {
if isNil, err := d.ReadNil(s); isNil || err != nil {
return nil, err
}
Expand Down Expand Up @@ -588,6 +600,9 @@ func unquote(tok []byte) (string, error) {
}

func memberFromToken(s *smithy.Schema, tok []byte, escaped bool) (*smithy.Schema, error) {
if len(tok) < 2 || tok[0] != '"' {
return nil, fmt.Errorf("expected member name, got %s", tok)
}
inner := tok[1 : len(tok)-1]
if m := memberByBytes(s, inner); m != nil {
return m, nil
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package json

import (
"testing"

"github.com/aws/smithy-go"
)

// Schemas modeling a union member whose value is itself a union, e.g.
// bedrock-agentcore's TargetConfiguration -> mcp -> McpTargetConfiguration.
var (
testSchemaString = smithy.NewSchema(smithy.ShapeID{
Namespace: "smithy.api", Name: "String",
}, smithy.ShapeTypeString, 0)

testSchemaInnerUnion = smithy.NewSchema(smithy.ShapeID{
Namespace: "com.test", Name: "InnerUnion",
}, smithy.ShapeTypeUnion, 1)

testSchemaOuterUnion = smithy.NewSchema(smithy.ShapeID{
Namespace: "com.test", Name: "OuterUnion",
}, smithy.ShapeTypeUnion, 1)

testSchemaInnerUnion_Lambda *smithy.Schema
testSchemaOuterUnion_Mcp *smithy.Schema
)

func init() {
testSchemaInnerUnion_Lambda = testSchemaInnerUnion.AddMember("lambda", testSchemaString)
testSchemaOuterUnion_Mcp = testSchemaOuterUnion.AddMember("mcp", testSchemaInnerUnion)
}

// readNestedUnion mimics the calling pattern of SDK-generated code
// (smithy.ReadUnion in serde.go): repeatedly call ReadUnion until it returns
// no member, deserializing each member value in between.
func readNestedUnion(d *ShapeDeserializer) (member string, value string, err error) {
err = smithy.ReadUnion(d, testSchemaOuterUnion, func(ms *smithy.Schema) error {
member = ms.MemberName()
return smithy.ReadUnion(d, testSchemaInnerUnion, func(inner *smithy.Schema) error {
member += "." + inner.MemberName()
return d.ReadString(inner, &value)
})
})
return member, value, err
}

func TestReadUnion_NestedUnionValue(t *testing.T) {
// A union whose member value is itself a union. Before the fix,
// ReadUnion mistook the parent's union context for its own, skipped the
// inner '{' and panicked in memberFromToken (slice bounds [1:0]).
d := NewShapeDeserializer([]byte(`{"mcp":{"lambda":"arn:aws:lambda:fn"}}`))

member, value, err := readNestedUnion(d)
if err != nil {
t.Fatalf("expected success, got error: %v", err)
}
if member != "mcp.lambda" {
t.Errorf("expected member mcp.lambda, got %q", member)
}
if value != "arn:aws:lambda:fn" {
t.Errorf("expected value arn:aws:lambda:fn, got %q", value)
}
}

func TestReadUnion_NestedUnionNullValue(t *testing.T) {
// A union member with a null value is skipped entirely; the member
// callback must not fire.
d := NewShapeDeserializer([]byte(`{"mcp":null}`))

member, _, err := readNestedUnion(d)
if err != nil {
t.Fatalf("expected success, got error: %v", err)
}
if member != "" {
t.Errorf("expected no member, got %q", member)
}
}

func TestReadUnion_FlatUnionStillWorks(t *testing.T) {
// Regression guard: a plain (non-nested) union read.
d := NewShapeDeserializer([]byte(`{"lambda":"v"}`))

var member, value string
err := smithy.ReadUnion(d, testSchemaInnerUnion, func(ms *smithy.Schema) error {
member = ms.MemberName()
return d.ReadString(ms, &value)
})
if err != nil {
t.Fatalf("expected success, got error: %v", err)
}
if member != "lambda" || value != "v" {
t.Errorf("expected lambda/v, got %q/%q", member, value)
}
}