diff --git a/Makefile b/Makefile index 6d573be..0db4555 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ .PHONY: all build test test: - @PROJECT_ROOT=$(PWD) APP_ENV=test go test -timeout 5s -tags grpc,pulsar,newrelic,kitex ./... + @PROJECT_ROOT=$(PWD) APP_ENV=test go test -timeout 5s -tags grpc,pulsar,newrelic,kitex,sqs,sqs_worker,otel ./... diff --git a/plugins/grpc/presets/grpc_with_error_reporting.go b/plugins/grpc/presets/grpc_with_error_reporting.go index 82bc791..a78e5ee 100644 --- a/plugins/grpc/presets/grpc_with_error_reporting.go +++ b/plugins/grpc/presets/grpc_with_error_reporting.go @@ -1,5 +1,5 @@ //go:build grpc && newrelic && otel && sentry -// +build grpc,newrelic,sentry,otel +// +build grpc,newrelic,otel,sentry package presets diff --git a/plugins/sqs/README.md b/plugins/sqs/README.md index 89b8c1d..04560db 100644 --- a/plugins/sqs/README.md +++ b/plugins/sqs/README.md @@ -26,3 +26,26 @@ func main() { } ``` +## OpenTelemetry + +When built with `-tags "sqs otel"`, every `*sqs.SQS` client produced by +`AwsTopicManager.AddTopic` is transparently instrumented via the +`aws-sdk-go` v1 request handler chain. 
Each SQS API call emits a span +following the OpenTelemetry Semantic Conventions 1.41.0 — [messaging spans](https://opentelemetry.io/docs/specs/semconv/messaging/messaging-spans/) and [AWS SQS](https://opentelemetry.io/docs/specs/semconv/messaging/sqs/): + +| Operation | Span kind | `messaging.operation.type` | `messaging.operation.name` | +| --------------------------------------------- | --------- | -------------------------- | -------------------------- | +| `SendMessage` | Producer | `send` | `send` | +| `SendMessageBatch` | Producer | `send` | `send_batch` | +| `ReceiveMessage` | Consumer | `receive` | `receive` | +| `DeleteMessage(Batch)` | Client | `settle` | `delete[_batch]` | +| `ChangeMessageVisibility(Batch)` | Client | `settle` | `change_visibility[_batch]`| + +Common attributes include `messaging.system=aws_sqs`, +`messaging.destination.name`, `aws.sqs.queue.url`, `aws.request_id`, +`server.address`, and — for `SendMessage` — `messaging.message.id`. + +W3C trace context (`traceparent`) is injected into the outgoing message's +`MessageAttributes` (per entry for `SendMessageBatch`), enabling downstream +consumers to continue the trace. + diff --git a/plugins/sqs/otel_carrier.go b/plugins/sqs/otel_carrier.go new file mode 100644 index 0000000..09d290d --- /dev/null +++ b/plugins/sqs/otel_carrier.go @@ -0,0 +1,39 @@ +//go:build sqs && otel +// +build sqs,otel + +package sqs + +import ( + "github.com/aws/aws-sdk-go/aws" + aws_sqs "github.com/aws/aws-sdk-go/service/sqs" + "go.opentelemetry.io/otel/propagation" +) + +// SQSMessageCarrier adapts an SQS MessageAttributeValue map to the +// propagation.TextMapCarrier interface so W3C trace context can be injected +// and extracted from SQS messages. 
+type SQSMessageCarrier map[string]*aws_sqs.MessageAttributeValue
+
+// Compile-time check that SQSMessageCarrier satisfies the carrier interface
+// used by OpenTelemetry propagators.
+var _ propagation.TextMapCarrier = SQSMessageCarrier{}
+
+// Get returns the string value stored under key, or "" when the attribute
+// is absent, nil, or has no StringValue.
+func (c SQSMessageCarrier) Get(key string) string {
+	attr := c[key]
+	if attr == nil || attr.StringValue == nil {
+		return ""
+	}
+	return *attr.StringValue
+}
+
+// Set stores val under key as an attribute whose DataType is "String",
+// replacing any attribute already present under that name.
+func (c SQSMessageCarrier) Set(key, val string) {
+	attr := new(aws_sqs.MessageAttributeValue)
+	attr.DataType = aws.String("String")
+	attr.StringValue = aws.String(val)
+	c[key] = attr
+}
+
+// Keys returns the names of all attributes in the carrier, in map iteration
+// order.
+func (c SQSMessageCarrier) Keys() []string {
+	names := make([]string, len(c))
+	next := 0
+	for name := range c {
+		names[next] = name
+		next++
+	}
+	return names
+}
diff --git a/plugins/sqs/otel_client.go b/plugins/sqs/otel_client.go
new file mode 100644
index 0000000..1fcd664
--- /dev/null
+++ b/plugins/sqs/otel_client.go
@@ -0,0 +1,262 @@
+//go:build sqs && otel
+// +build sqs,otel
+
+package sqs
+
+import (
+	"context"
+	"fmt"
+	"net/url"
+	"strconv"
+	"strings"
+
+	"github.com/aws/aws-sdk-go/aws/request"
+	aws_sqs "github.com/aws/aws-sdk-go/service/sqs"
+	"go.opentelemetry.io/otel"
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/codes"
+	"go.opentelemetry.io/otel/trace"
+)
+
+// tracerName is the instrumentation scope name handed to otel.Tracer.
+const tracerName = "github.com/shoplineapp/go-app/plugins/sqs"
+
+// opClassification is the messaging-semconv view of one SQS API operation:
+// which span kind it produces and which operation type/name it reports.
+type opClassification struct {
+	spanKind trace.SpanKind
+	opType   string // messaging.operation.type
+	opName   string // messaging.operation.name
+}
+
+// classifyOperation returns the classification for a given SQS API operation
+// name. The second return value is false for operations that are not
+// messaging-relevant and should be skipped.
+func classifyOperation(name string) (opClassification, bool) {
+	// Local constructors for the two span kinds shared by several operations,
+	// so each case below only has to state its operation name.
+	producer := func(opName string) (opClassification, bool) {
+		return opClassification{spanKind: trace.SpanKindProducer, opType: "send", opName: opName}, true
+	}
+	settle := func(opName string) (opClassification, bool) {
+		return opClassification{spanKind: trace.SpanKindClient, opType: "settle", opName: opName}, true
+	}
+
+	switch name {
+	case "SendMessage":
+		return producer("send")
+	case "SendMessageBatch":
+		return producer("send_batch")
+	case "ReceiveMessage":
+		return opClassification{spanKind: trace.SpanKindConsumer, opType: "receive", opName: "receive"}, true
+	case "DeleteMessage":
+		return settle("delete")
+	case "DeleteMessageBatch":
+		return settle("delete_batch")
+	case "ChangeMessageVisibility":
+		return settle("change_visibility")
+	case "ChangeMessageVisibilityBatch":
+		return settle("change_visibility_batch")
+	default:
+		// Anything else is not traced.
+		return opClassification{}, false
+	}
+}
+
+// InstrumentSQSClient attaches OpenTelemetry instrumentation to a v1
+// aws-sdk-go SQS client. It:
+//   - starts a span in the Validate phase (before params are serialized)
+//   - injects W3C trace context into outgoing message attributes for send
+//     operations (also Validate phase, after span start)
+//   - ends the span in the Complete phase, recording any error and setting
+//     AWS/messaging semconv attributes.
+func InstrumentSQSClient(c *aws_sqs.SQS) {
+	instrument(c)
+}
+
+// instrument registers the tracing handlers on c's request handler lists.
+// A nil client is left untouched.
+func instrument(c *aws_sqs.SQS) {
+	if c == nil {
+		return
+	}
+	// The span is started at the front of Validate and the trace context is
+	// injected at the back, so injection always sees the span's own context.
+	c.Handlers.Validate.PushFrontNamed(request.NamedHandler{
+		Name: "otel.sqs.StartSpan",
+		Fn:   startSpan,
+	})
+	c.Handlers.Validate.PushBackNamed(request.NamedHandler{
+		Name: "otel.sqs.InjectTraceContext",
+		Fn:   injectTraceContext,
+	})
+	c.Handlers.Complete.PushBackNamed(request.NamedHandler{
+		Name: "otel.sqs.EndSpan",
+		Fn:   endSpan,
+	})
+}
+
+// startSpan begins a span for operations known to classifyOperation and
+// stores the span's context on the request. Other operations are left
+// untraced and their request context is not modified.
+func startSpan(r *request.Request) {
+	class, ok := classifyOperation(r.Operation.Name)
+	if !ok {
+		return
+	}
+
+	queueURL := extractQueueURL(r.Params)
+	queueName := queueNameFromURL(queueURL)
+
+	attrs := []attribute.KeyValue{
+		attribute.String("messaging.system", "aws_sqs"),
+		attribute.String("messaging.operation.name", class.opName),
+		attribute.String("messaging.operation.type", class.opType),
+	}
+	if queueURL != "" {
+		attrs = append(attrs, attribute.String("aws.sqs.queue.url", queueURL))
+	}
+	if queueName != "" {
+		attrs = append(attrs, attribute.String("messaging.destination.name", queueName))
+	}
+	if host, port := HostPortFromURL(queueURL); host != "" {
+		attrs = append(attrs, attribute.String("server.address", host))
+		if port > 0 {
+			attrs = append(attrs, attribute.Int("server.port", port))
+		}
+	}
+
+	// Span name is "<operation> <queue>" when the queue name is known.
+	spanName := class.opName
+	if queueName != "" {
+		spanName = class.opName + " " + queueName
+	}
+
+	tracer := otel.Tracer(tracerName)
+	ctx, _ := tracer.Start(r.Context(), spanName,
+		trace.WithSpanKind(class.spanKind),
+		trace.WithAttributes(attrs...),
+	)
+	r.SetContext(ctx)
+}
+
+// maxMessageAttributes is the number of message attributes SQS accepts on a
+// single message. Injection is skipped rather than pushing a message over
+// this quota, because an over-quota send is rejected by the service.
+const maxMessageAttributes = 10
+
+// injectTraceContext writes the request's trace context into the message
+// attributes of SendMessage and SendMessageBatch inputs (once per batch
+// entry). All other parameter types are ignored.
+func injectTraceContext(r *request.Request) {
+	switch p := r.Params.(type) {
+	case *aws_sqs.SendMessageInput:
+		if p == nil {
+			return
+		}
+		p.MessageAttributes = withTraceContext(r.Context(), p.MessageAttributes)
+	case *aws_sqs.SendMessageBatchInput:
+		if p == nil {
+			return
+		}
+		for _, entry := range p.Entries {
+			if entry == nil {
+				continue
+			}
+			entry.MessageAttributes = withTraceContext(r.Context(), entry.MessageAttributes)
+		}
+	}
+}
+
+// withTraceContext returns attrs with the propagation fields for ctx merged
+// in, allocating the map when attrs is nil. If adding the fields would take
+// the message past maxMessageAttributes, attrs is returned unchanged so the
+// send itself still succeeds, at the cost of an unlinked trace.
+func withTraceContext(
+	ctx context.Context,
+	attrs map[string]*aws_sqs.MessageAttributeValue,
+) map[string]*aws_sqs.MessageAttributeValue {
+	// Inject into a scratch carrier first so the final attribute count is
+	// known before the caller's map is touched.
+	injected := SQSMessageCarrier{}
+	otel.GetTextMapPropagator().Inject(ctx, injected)
+
+	total := len(attrs)
+	for key := range injected {
+		// Overwriting an existing key does not add to the count.
+		if _, exists := attrs[key]; !exists {
+			total++
+		}
+	}
+	if total > maxMessageAttributes {
+		return attrs
+	}
+
+	if attrs == nil {
+		attrs = make(map[string]*aws_sqs.MessageAttributeValue, len(injected))
+	}
+	for key, value := range injected {
+		attrs[key] = value
+	}
+	return attrs
+}
+
+// endSpan finishes the span opened by startSpan, adding the AWS request ID,
+// the message ID (SendMessage), links to producer contexts (ReceiveMessage)
+// and error details when the request failed.
+func endSpan(r *request.Request) {
+	// startSpan opens no span for unclassified operations, so any span found
+	// in the context belongs to the caller and must not be ended here.
+	if _, ok := classifyOperation(r.Operation.Name); !ok {
+		return
+	}
+
+	span := trace.SpanFromContext(r.Context())
+	if !span.IsRecording() {
+		return
+	}
+
+	if r.RequestID != "" {
+		span.SetAttributes(attribute.String("aws.request_id", r.RequestID))
+	}
+
+	switch out := r.Data.(type) {
+	case *aws_sqs.SendMessageOutput:
+		if out != nil && out.MessageId != nil {
+			span.SetAttributes(attribute.String("messaging.message.id", *out.MessageId))
+		}
+	case *aws_sqs.ReceiveMessageOutput:
+		addReceiveLinks(span, out)
+	}
+
+	if r.Error != nil {
+		span.RecordError(r.Error)
+		span.SetStatus(codes.Error, r.Error.Error())
+		span.SetAttributes(attribute.String("error.type", fmt.Sprintf("%T", r.Error)))
+	}
+
+	span.End()
+}
+
+// addReceiveLinks adds one span link per received message whose attributes
+// carry a valid trace context. Messages without one are skipped.
+func addReceiveLinks(span trace.Span, out *aws_sqs.ReceiveMessageOutput) {
+	if out == nil {
+		return
+	}
+	propagator := otel.GetTextMapPropagator()
+	for _, msg := range out.Messages {
+		if msg == nil {
+			continue
+		}
+		// Extract against a fresh background context so only the context
+		// carried by the message is considered.
+		carrier := SQSMessageCarrier(msg.MessageAttributes)
+		ctx := propagator.Extract(context.Background(), carrier)
+		sc := trace.SpanContextFromContext(ctx)
+		if sc.IsValid() {
+			span.AddLink(trace.Link{SpanContext: sc})
+		}
+	}
+}
+
+// extractQueueURL returns the QueueUrl from a known SQS input type, or "".
+func extractQueueURL(params interface{}) string {
+	// A multi-type case cannot read a struct field, so each traced input
+	// type gets its own arm; unknown types leave queueURL nil.
+	var queueURL *string
+	switch in := params.(type) {
+	case *aws_sqs.SendMessageInput:
+		queueURL = in.QueueUrl
+	case *aws_sqs.SendMessageBatchInput:
+		queueURL = in.QueueUrl
+	case *aws_sqs.ReceiveMessageInput:
+		queueURL = in.QueueUrl
+	case *aws_sqs.DeleteMessageInput:
+		queueURL = in.QueueUrl
+	case *aws_sqs.DeleteMessageBatchInput:
+		queueURL = in.QueueUrl
+	case *aws_sqs.ChangeMessageVisibilityInput:
+		queueURL = in.QueueUrl
+	case *aws_sqs.ChangeMessageVisibilityBatchInput:
+		queueURL = in.QueueUrl
+	}
+	return derefString(queueURL)
+}
+
+// queueNameFromURL returns the final path segment of a queue URL, e.g.
+// "MyQueue" for https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue.
+// It returns "" for an empty or unparsable URL, or one with no path.
+func queueNameFromURL(queueURL string) string {
+	if queueURL == "" {
+		return ""
+	}
+	parsed, err := url.Parse(queueURL)
+	if err != nil {
+		return ""
+	}
+	trimmed := strings.Trim(parsed.Path, "/")
+	// LastIndex yields -1 when no slash remains, so the slice starts at 0
+	// and covers the whole (possibly empty) trimmed path.
+	return trimmed[strings.LastIndex(trimmed, "/")+1:]
+}
+
+// HostPortFromURL returns the hostname and port of an SQS queue URL.
+// Port is 0 when the URL has no explicit port.
+func HostPortFromURL(queueURL string) (host string, port int) { + if queueURL == "" { + return "", 0 + } + u, err := url.Parse(queueURL) + if err != nil { + return "", 0 + } + host = u.Hostname() + if p := u.Port(); p != "" { + if n, err := strconv.Atoi(p); err == nil { + port = n + } + } + return host, port +} + +func derefString(s *string) string { + if s == nil { + return "" + } + return *s +} diff --git a/plugins/sqs/otel_client_noop.go b/plugins/sqs/otel_client_noop.go new file mode 100644 index 0000000..bb8903d --- /dev/null +++ b/plugins/sqs/otel_client_noop.go @@ -0,0 +1,9 @@ +//go:build sqs && !otel + +package sqs + +import aws_sqs "github.com/aws/aws-sdk-go/service/sqs" + +func InstrumentSQSClient(c *aws_sqs.SQS) {} + +func instrument(c *aws_sqs.SQS) {} diff --git a/plugins/sqs/otel_client_test.go b/plugins/sqs/otel_client_test.go new file mode 100644 index 0000000..e8fac48 --- /dev/null +++ b/plugins/sqs/otel_client_test.go @@ -0,0 +1,401 @@ +//go:build sqs && otel +// +build sqs,otel + +package sqs + +import ( + "context" + "crypto/md5" + "encoding/hex" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + aws_session "github.com/aws/aws-sdk-go/aws/session" + aws_sqs "github.com/aws/aws-sdk-go/service/sqs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/propagation" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + "go.opentelemetry.io/otel/trace" +) + +// --- helpers --- + +func setupTestTracer(t *testing.T) *tracetest.InMemoryExporter { + exporter := tracetest.NewInMemoryExporter() + tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter)) + otel.SetTracerProvider(tp) + otel.SetTextMapPropagator(propagation.TraceContext{}) + t.Cleanup(func() { + _ = 
tp.Shutdown(context.Background()) + }) + return exporter +} + +func attrMap(attrs []attribute.KeyValue) map[string]string { + m := make(map[string]string, len(attrs)) + for _, a := range attrs { + m[string(a.Key)] = a.Value.Emit() + } + return m +} + +// fakeSQSServer serves canned AWS SQS query-protocol responses. It routes +// based on the "Action" form field and captures the last request body for +// inspection. +type fakeSQSServer struct { + *httptest.Server + lastBody string +} + +func newFakeSQSServer(t *testing.T) *fakeSQSServer { + f := &fakeSQSServer{} + f.Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + f.lastBody = string(body) + form, _ := url.ParseQuery(f.lastBody) + action := form.Get("Action") + + md5Of := func(s string) string { + sum := md5.Sum([]byte(s)) + return hex.EncodeToString(sum[:]) + } + + w.Header().Set("Content-Type", "text/xml") + w.Header().Set("x-amzn-RequestId", "test-request-id") + switch action { + case "SendMessage": + bodyMD5 := md5Of(form.Get("MessageBody")) + _, _ = w.Write([]byte(` + + + test-message-id + ` + bodyMD5 + ` + + test-request-id +`)) + case "SendMessageBatch": + var entries strings.Builder + for i := 1; ; i++ { + prefix := fmt.Sprintf("SendMessageBatchRequestEntry.%d.", i) + id := form.Get(prefix + "Id") + if id == "" { + break + } + body := form.Get(prefix + "MessageBody") + fmt.Fprintf(&entries, `%smid-%d%s`, id, i, md5Of(body)) + } + _, _ = w.Write([]byte(` + + ` + entries.String() + ` + test-request-id +`)) + case "ReceiveMessage": + _, _ = w.Write([]byte(` + + + + recv-msg-id + recv-handle + ` + md5Of("test body") + ` + test body + + traceparent + + 00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01 + String + + + + + test-request-id +`)) + case "DeleteMessage": + _, _ = w.Write([]byte(` + + test-request-id +`)) + default: + w.WriteHeader(http.StatusBadRequest) + } + })) + t.Cleanup(f.Server.Close) + return f +} + +func 
newTestSQSClient(t *testing.T, endpoint string) *aws_sqs.SQS { + sess, err := aws_session.NewSession(&aws.Config{ + Region: aws.String("us-east-1"), + Endpoint: aws.String(endpoint), + Credentials: credentials.NewStaticCredentials("akid", "secret", ""), + DisableSSL: aws.Bool(true), + S3ForcePathStyle: aws.Bool(true), + MaxRetries: aws.Int(0), + }) + require.NoError(t, err) + client := aws_sqs.New(sess) + InstrumentSQSClient(client) + return client +} + +const testQueueURL = "https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue" + +// --- tests --- + +func TestClassifyOperation(t *testing.T) { + cases := []struct { + op string + expected opClassification + ok bool + }{ + {"SendMessage", opClassification{trace.SpanKindProducer, "send", "send"}, true}, + {"SendMessageBatch", opClassification{trace.SpanKindProducer, "send", "send_batch"}, true}, + {"ReceiveMessage", opClassification{trace.SpanKindConsumer, "receive", "receive"}, true}, + {"DeleteMessage", opClassification{trace.SpanKindClient, "settle", "delete"}, true}, + {"DeleteMessageBatch", opClassification{trace.SpanKindClient, "settle", "delete_batch"}, true}, + {"ChangeMessageVisibility", opClassification{trace.SpanKindClient, "settle", "change_visibility"}, true}, + {"ChangeMessageVisibilityBatch", opClassification{trace.SpanKindClient, "settle", "change_visibility_batch"}, true}, + {"CreateQueue", opClassification{}, false}, + } + for _, c := range cases { + got, ok := classifyOperation(c.op) + assert.Equal(t, c.ok, ok, c.op) + if c.ok { + assert.Equal(t, c.expected, got, c.op) + } + } +} + +func TestQueueNameFromURL(t *testing.T) { + assert.Equal(t, "MyQueue", queueNameFromURL(testQueueURL)) + assert.Equal(t, "", queueNameFromURL("")) + assert.Equal(t, "OnlyName", queueNameFromURL("https://sqs.us-east-1.amazonaws.com/OnlyName")) +} + +func TestHostPortFromURL(t *testing.T) { + host, port := HostPortFromURL(testQueueURL) + assert.Equal(t, "sqs.us-east-1.amazonaws.com", host) + assert.Equal(t, 0, port) + 
+ host, port = HostPortFromURL("https://sqs.us-east-1.amazonaws.com:9359/MyQueue") + assert.Equal(t, "sqs.us-east-1.amazonaws.com", host) + assert.Equal(t, 9359, port) + + host, port = HostPortFromURL("") + assert.Equal(t, "", host) + assert.Equal(t, 0, port) +} + +func TestInstrument_SendMessage_CreatesProducerSpan(t *testing.T) { + exporter := setupTestTracer(t) + fake := newFakeSQSServer(t) + client := newTestSQSClient(t, fake.URL) + + _, err := client.SendMessage(&aws_sqs.SendMessageInput{ + QueueUrl: aws.String(testQueueURL), + MessageBody: aws.String("hello"), + }) + require.NoError(t, err) + + spans := exporter.GetSpans() + require.Len(t, spans, 1) + + s := spans[0] + assert.Equal(t, "send MyQueue", s.Name) + assert.Equal(t, trace.SpanKindProducer, s.SpanKind) + + attrs := attrMap(s.Attributes) + assert.Equal(t, "aws_sqs", attrs["messaging.system"]) + assert.Equal(t, "send", attrs["messaging.operation.name"]) + assert.Equal(t, "send", attrs["messaging.operation.type"]) + assert.Equal(t, "MyQueue", attrs["messaging.destination.name"]) + assert.Equal(t, testQueueURL, attrs["aws.sqs.queue.url"]) + assert.Equal(t, "sqs.us-east-1.amazonaws.com", attrs["server.address"]) + assert.Equal(t, "test-message-id", attrs["messaging.message.id"]) + assert.Equal(t, "test-request-id", attrs["aws.request_id"]) +} + +func TestInstrument_SendMessage_InjectsTraceparent(t *testing.T) { + _ = setupTestTracer(t) + fake := newFakeSQSServer(t) + client := newTestSQSClient(t, fake.URL) + + _, err := client.SendMessage(&aws_sqs.SendMessageInput{ + QueueUrl: aws.String(testQueueURL), + MessageBody: aws.String("hello"), + }) + require.NoError(t, err) + + assert.Contains(t, fake.lastBody, "MessageAttribute.1.Name=traceparent") + assert.Contains(t, fake.lastBody, "MessageAttribute.1.Value.DataType=String") +} + +func TestInstrument_SendMessage_PreservesExistingAttributes(t *testing.T) { + _ = setupTestTracer(t) + fake := newFakeSQSServer(t) + client := newTestSQSClient(t, fake.URL) + + _, 
err := client.SendMessage(&aws_sqs.SendMessageInput{ + QueueUrl: aws.String(testQueueURL), + MessageBody: aws.String("hello"), + MessageAttributes: map[string]*aws_sqs.MessageAttributeValue{ + "custom": {DataType: aws.String("String"), StringValue: aws.String("value")}, + }, + }) + require.NoError(t, err) + + assert.Contains(t, fake.lastBody, "Name=custom") + assert.Contains(t, fake.lastBody, "Name=traceparent") +} + +func TestInstrument_SendMessageBatch_InjectsPerEntry(t *testing.T) { + _ = setupTestTracer(t) + fake := newFakeSQSServer(t) + client := newTestSQSClient(t, fake.URL) + + _, err := client.SendMessageBatch(&aws_sqs.SendMessageBatchInput{ + QueueUrl: aws.String(testQueueURL), + Entries: []*aws_sqs.SendMessageBatchRequestEntry{ + {Id: aws.String("e1"), MessageBody: aws.String("a")}, + {Id: aws.String("e2"), MessageBody: aws.String("b")}, + }, + }) + require.NoError(t, err) + + assert.Contains(t, fake.lastBody, "SendMessageBatchRequestEntry.1.MessageAttribute.1.Name=traceparent") + assert.Contains(t, fake.lastBody, "SendMessageBatchRequestEntry.2.MessageAttribute.1.Name=traceparent") +} + +func TestInstrument_SendMessageBatch_SpanAttributes(t *testing.T) { + exporter := setupTestTracer(t) + fake := newFakeSQSServer(t) + client := newTestSQSClient(t, fake.URL) + + _, err := client.SendMessageBatch(&aws_sqs.SendMessageBatchInput{ + QueueUrl: aws.String(testQueueURL), + Entries: []*aws_sqs.SendMessageBatchRequestEntry{ + {Id: aws.String("e1"), MessageBody: aws.String("a")}, + {Id: aws.String("e2"), MessageBody: aws.String("b")}, + }, + }) + require.NoError(t, err) + + spans := exporter.GetSpans() + require.Len(t, spans, 1) + s := spans[0] + assert.Equal(t, "send_batch MyQueue", s.Name) + assert.Equal(t, trace.SpanKindProducer, s.SpanKind) + + attrs := attrMap(s.Attributes) + assert.Equal(t, "send_batch", attrs["messaging.operation.name"]) +} + +func TestInstrument_ReceiveMessage_ConsumerKind(t *testing.T) { + exporter := setupTestTracer(t) + fake := 
newFakeSQSServer(t) + client := newTestSQSClient(t, fake.URL) + + _, err := client.ReceiveMessage(&aws_sqs.ReceiveMessageInput{ + QueueUrl: aws.String(testQueueURL), + }) + require.NoError(t, err) + + spans := exporter.GetSpans() + require.Len(t, spans, 1) + s := spans[0] + assert.Equal(t, "receive MyQueue", s.Name) + assert.Equal(t, trace.SpanKindConsumer, s.SpanKind) + attrs := attrMap(s.Attributes) + assert.Equal(t, "receive", attrs["messaging.operation.type"]) + assert.Equal(t, "receive", attrs["messaging.operation.name"]) +} + +func TestInstrument_ReceiveMessage_LinksToCreationContext(t *testing.T) { + exporter := setupTestTracer(t) + fake := newFakeSQSServer(t) + client := newTestSQSClient(t, fake.URL) + + _, err := client.ReceiveMessage(&aws_sqs.ReceiveMessageInput{ + QueueUrl: aws.String(testQueueURL), + }) + require.NoError(t, err) + + spans := exporter.GetSpans() + require.Len(t, spans, 1) + s := spans[0] + require.Len(t, s.Links, 1) + linkedSC := s.Links[0].SpanContext + assert.True(t, linkedSC.IsValid()) + assert.Equal(t, "4bf92f3577b34da6a3ce929d0e0e4736", linkedSC.TraceID().String()) +} + +func TestInstrument_DeleteMessage_SettleKind(t *testing.T) { + exporter := setupTestTracer(t) + fake := newFakeSQSServer(t) + client := newTestSQSClient(t, fake.URL) + + _, err := client.DeleteMessage(&aws_sqs.DeleteMessageInput{ + QueueUrl: aws.String(testQueueURL), + ReceiptHandle: aws.String("rh"), + }) + require.NoError(t, err) + + spans := exporter.GetSpans() + require.Len(t, spans, 1) + s := spans[0] + assert.Equal(t, "delete MyQueue", s.Name) + assert.Equal(t, trace.SpanKindClient, s.SpanKind) + attrs := attrMap(s.Attributes) + assert.Equal(t, "settle", attrs["messaging.operation.type"]) + assert.Equal(t, "delete", attrs["messaging.operation.name"]) +} + +func TestInstrument_RecordsErrorOnFailure(t *testing.T) { + exporter := setupTestTracer(t) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + 
w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`SenderInvalidParameterValueboomerr-req`)) + })) + t.Cleanup(srv.Close) + + client := newTestSQSClient(t, srv.URL) + _, err := client.SendMessage(&aws_sqs.SendMessageInput{ + QueueUrl: aws.String(testQueueURL), + MessageBody: aws.String("hello"), + }) + require.Error(t, err) + + spans := exporter.GetSpans() + require.Len(t, spans, 1) + attrs := attrMap(spans[0].Attributes) + assert.Contains(t, attrs, "error.type") +} + +func TestInstrument_SkipsUnclassifiedOperation(t *testing.T) { + exporter := setupTestTracer(t) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/xml") + _, _ = w.Write([]byte(` + + ` + testQueueURL + ` + r +`)) + })) + t.Cleanup(srv.Close) + + client := newTestSQSClient(t, srv.URL) + _, err := client.GetQueueUrl(&aws_sqs.GetQueueUrlInput{QueueName: aws.String("MyQueue")}) + require.NoError(t, err) + + assert.Empty(t, exporter.GetSpans(), "non-messaging operations should not be instrumented") +} diff --git a/plugins/sqs/topic.go b/plugins/sqs/topic.go index e06423e..44d8abb 100644 --- a/plugins/sqs/topic.go +++ b/plugins/sqs/topic.go @@ -54,6 +54,9 @@ func (t *AwsTopicManager) AddTopic(name, arn string) *Topic { Arn: arn, SQSAPI: aws_sqs.New(session), } + if client, ok := topic.SQSAPI.(*aws_sqs.SQS); ok { + instrument(client) + } t.TopicMaps[name] = topic diff --git a/plugins/sqs_worker/README.md b/plugins/sqs_worker/README.md index 3ce7718..ec46a94 100644 --- a/plugins/sqs_worker/README.md +++ b/plugins/sqs_worker/README.md @@ -24,3 +24,13 @@ func main() { } ``` +## OpenTelemetry + +When built with `-tags "sqs sqs_worker otel"`, the worker starts a +`SpanKindConsumer` "process" span for each incoming message, following the +OpenTelemetry Semantic Conventions 1.41.0 — [messaging spans](https://opentelemetry.io/docs/specs/semconv/messaging/messaging-spans/) and [AWS 
SQS](https://opentelemetry.io/docs/specs/semconv/messaging/sqs/). +The span is parented to the upstream producer's trace context extracted from +the message's `MessageAttributes` (W3C `traceparent`), so traces span +naturally from the producer through to message processing. The subsequent +`DeleteMessage` call is emitted as a child settle span. + diff --git a/plugins/sqs_worker/otel_consumer.go b/plugins/sqs_worker/otel_consumer.go new file mode 100644 index 0000000..1e9e77b --- /dev/null +++ b/plugins/sqs_worker/otel_consumer.go @@ -0,0 +1,98 @@ +//go:build sqs && sqs_worker && otel +// +build sqs,sqs_worker,otel + +package sqs_worker + +import ( + "context" + "fmt" + + aws_sqs "github.com/aws/aws-sdk-go/service/sqs" + "github.com/shoplineapp/go-app/common" + sqsplugin "github.com/shoplineapp/go-app/plugins/sqs" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" +) + +const tracerName = "github.com/shoplineapp/go-app/plugins/sqs_worker" + +func init() { + processHook = otelProcessHook +} + +// otelProcessHook extracts the upstream trace context from the message's +// MessageAttributes and starts a SpanKindConsumer "process" span as a child +// of that upstream context. The returned context carries the new span, so +// any subsequent SQS call (e.g. DeleteMessage) made with it will be +// captured as a child settle span. +func otelProcessHook( + ctx context.Context, + topic, msgID, queueURL string, + attrs map[string]*aws_sqs.MessageAttributeValue, +) (context.Context, func(error)) { + ctx = extractMessageContext(ctx, attrs) + return startProcessSpan(ctx, topic, msgID, queueURL) +} + +// extractMessageContext extracts W3C trace context (traceparent, +// tracestate, baggage) from SQS message attributes via the configured +// TextMapPropagator. 
+func extractMessageContext(
+	ctx context.Context,
+	attrs map[string]*aws_sqs.MessageAttributeValue,
+) context.Context {
+	// With nil attributes nothing is extracted and the trace ID stays empty.
+	traceID := ""
+	if attrs != nil {
+		carrier := sqsplugin.SQSMessageCarrier(attrs)
+		ctx = otel.GetTextMapPropagator().Extract(ctx, carrier)
+		if sc := trace.SpanContextFromContext(ctx); sc.IsValid() {
+			traceID = sc.TraceID().String()
+		}
+	}
+	return common.NewContextWithTraceID(ctx, traceID)
+}
+
+// startProcessSpan opens a SpanKindConsumer span named "process <topic>"
+// for one SQS message. It returns the context carrying that span and a
+// function that ends the span, recording the error passed to it if non-nil.
+func startProcessSpan(ctx context.Context, topic, msgID, queueURL string) (context.Context, func(error)) {
+	spanAttrs := []attribute.KeyValue{
+		attribute.String("messaging.system", "aws_sqs"),
+		attribute.String("messaging.operation.name", "process"),
+		attribute.String("messaging.operation.type", "process"),
+		attribute.String("messaging.destination.name", topic),
+	}
+	if queueURL != "" {
+		spanAttrs = append(spanAttrs, attribute.String("aws.sqs.queue.url", queueURL))
+	}
+	// The port is reported only together with a host, and only when the
+	// URL names one explicitly.
+	host, port := sqsplugin.HostPortFromURL(queueURL)
+	if host != "" {
+		spanAttrs = append(spanAttrs, attribute.String("server.address", host))
+	}
+	if host != "" && port > 0 {
+		spanAttrs = append(spanAttrs, attribute.Int("server.port", port))
+	}
+	if msgID != "" {
+		spanAttrs = append(spanAttrs, attribute.String("messaging.message.id", msgID))
+	}
+
+	spanCtx, span := otel.Tracer(tracerName).Start(ctx, "process "+topic,
+		trace.WithSpanKind(trace.SpanKindConsumer),
+		trace.WithAttributes(spanAttrs...),
+	)
+
+	finish := func(err error) {
+		if err == nil {
+			span.End()
+			return
+		}
+		span.RecordError(err)
+		span.SetStatus(codes.Error, err.Error())
+		span.SetAttributes(attribute.String("error.type", fmt.Sprintf("%T", err)))
+		span.End()
+	}
+	return spanCtx, finish
+}
diff --git a/plugins/sqs_worker/otel_consumer_test.go b/plugins/sqs_worker/otel_consumer_test.go
new file mode 100644
index 0000000..5f82feb
---
/dev/null +++ b/plugins/sqs_worker/otel_consumer_test.go @@ -0,0 +1,147 @@ +//go:build sqs && sqs_worker && otel +// +build sqs,sqs_worker,otel + +package sqs_worker + +import ( + "context" + "errors" + "testing" + + "github.com/aws/aws-sdk-go/aws" + aws_sqs "github.com/aws/aws-sdk-go/service/sqs" + "github.com/stretchr/testify/assert" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/propagation" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + "go.opentelemetry.io/otel/trace" +) + +func setupTestTracer(t *testing.T) *tracetest.InMemoryExporter { + exporter := tracetest.NewInMemoryExporter() + tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter)) + otel.SetTracerProvider(tp) + otel.SetTextMapPropagator(propagation.TraceContext{}) + t.Cleanup(func() { + _ = tp.Shutdown(context.Background()) + }) + return exporter +} + +func attrMap(attrs []attribute.KeyValue) map[string]string { + m := make(map[string]string, len(attrs)) + for _, a := range attrs { + m[string(a.Key)] = a.Value.Emit() + } + return m +} + +func stringAttr(s string) *aws_sqs.MessageAttributeValue { + return &aws_sqs.MessageAttributeValue{ + DataType: aws.String("String"), + StringValue: aws.String(s), + } +} + +func TestExtractMessageContext_WithTraceparent(t *testing.T) { + _ = setupTestTracer(t) + attrs := map[string]*aws_sqs.MessageAttributeValue{ + "traceparent": stringAttr("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"), + } + ctx := extractMessageContext(context.Background(), attrs) + + spanCtx := trace.SpanContextFromContext(ctx) + assert.True(t, spanCtx.IsValid()) + assert.Equal(t, "4bf92f3577b34da6a3ce929d0e0e4736", spanCtx.TraceID().String()) +} + +func TestExtractMessageContext_WithoutTraceparent(t *testing.T) { + _ = setupTestTracer(t) + attrs := map[string]*aws_sqs.MessageAttributeValue{ + "unrelated": stringAttr("value"), + } + ctx := 
extractMessageContext(context.Background(), attrs) + spanCtx := trace.SpanContextFromContext(ctx) + assert.False(t, spanCtx.IsValid()) +} + +func TestExtractMessageContext_NilAttrs(t *testing.T) { + _ = setupTestTracer(t) + ctx := extractMessageContext(context.Background(), nil) + assert.NotNil(t, ctx) +} + +func TestExtractMessageContext_BackwardCompatTraceID(t *testing.T) { + _ = setupTestTracer(t) + attrs := map[string]*aws_sqs.MessageAttributeValue{ + "traceparent": stringAttr("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"), + } + ctx := extractMessageContext(context.Background(), attrs) + assert.Equal(t, "4bf92f3577b34da6a3ce929d0e0e4736", ctx.Value("trace_id")) +} + +func TestStartProcessSpan_CreatesConsumerSpan(t *testing.T) { + exporter := setupTestTracer(t) + + queueURL := "https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue" + _, end := startProcessSpan(context.Background(), "my-queue", "msg-123", queueURL) + end(nil) + + spans := exporter.GetSpans() + assert.Len(t, spans, 1) + s := spans[0] + assert.Equal(t, "process my-queue", s.Name) + assert.Equal(t, trace.SpanKindConsumer, s.SpanKind) + + attrs := attrMap(s.Attributes) + assert.Equal(t, "aws_sqs", attrs["messaging.system"]) + assert.Equal(t, "process", attrs["messaging.operation.name"]) + assert.Equal(t, "process", attrs["messaging.operation.type"]) + assert.Equal(t, "my-queue", attrs["messaging.destination.name"]) + assert.Equal(t, "msg-123", attrs["messaging.message.id"]) + assert.Equal(t, queueURL, attrs["aws.sqs.queue.url"]) + assert.Equal(t, "sqs.us-east-1.amazonaws.com", attrs["server.address"]) +} + +func TestStartProcessSpan_RecordsError(t *testing.T) { + exporter := setupTestTracer(t) + + _, end := startProcessSpan(context.Background(), "my-queue", "msg-123", "") + end(errors.New("processing failed")) + + spans := exporter.GetSpans() + assert.Len(t, spans, 1) + assert.Equal(t, codes.Error, spans[0].Status.Code) + assert.Equal(t, "processing failed", 
spans[0].Status.Description) + assert.Contains(t, attrMap(spans[0].Attributes), "error.type") +} + +func TestStartProcessSpan_ContinuesProducerTrace(t *testing.T) { + exporter := setupTestTracer(t) + + attrs := map[string]*aws_sqs.MessageAttributeValue{ + "traceparent": stringAttr("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"), + } + ctx := extractMessageContext(context.Background(), attrs) + _, end := startProcessSpan(ctx, "my-queue", "msg-123", "") + end(nil) + + spans := exporter.GetSpans() + assert.Len(t, spans, 1) + assert.Equal(t, "4bf92f3577b34da6a3ce929d0e0e4736", spans[0].SpanContext.TraceID().String()) +} + +func TestOtelProcessHook_InstalledAsDefault(t *testing.T) { + exporter := setupTestTracer(t) + + ctx, end := processHook(context.Background(), "my-queue", "msg-1", "", nil) + _ = ctx + end(nil) + + spans := exporter.GetSpans() + assert.Len(t, spans, 1, "otel init() should override processHook") + assert.Equal(t, "process my-queue", spans[0].Name) +} diff --git a/plugins/sqs_worker/worker.go b/plugins/sqs_worker/worker.go index 4b3805f..2391fed 100644 --- a/plugins/sqs_worker/worker.go +++ b/plugins/sqs_worker/worker.go @@ -34,6 +34,19 @@ type awsMessage struct { *aws_sqs.Message } +// processHook wraps per-message processing so plugins (e.g. the OTel +// instrumentation under the `otel` build tag) can inject a +// SpanKindConsumer "process" span around the handler invocation and any +// subsequent settle operations. The default implementation is a no-op; +// otel_consumer.go overrides it via an init(). 
+var processHook = func( + ctx context.Context, + topic, msgID, queueURL string, + attrs map[string]*aws_sqs.MessageAttributeValue, +) (context.Context, func(error)) { + return ctx, func(error) {} +} + func init() { plugins.Registry = append(plugins.Registry, NewAwsSqsWorker) } @@ -119,6 +132,13 @@ func (w *AwsSqsWorker) handleMessage(messages []*awsMessage) { w.logger.WithFields(log.Fields{"message": *awsMsg.Body}).Debug("Message received") topic := w.topicMgr.GetTopic(awsMsg.topicName) + + var msgID string + if awsMsg.MessageId != nil { + msgID = *awsMsg.MessageId + } + ctx, endProcess := processHook(w.ctx, topic.Name, msgID, topic.Arn, awsMsg.MessageAttributes) + deleteMessage := true err := w.handler.OnEvent(topic, *awsMsg.Body) if err != nil { @@ -127,11 +147,13 @@ func (w *AwsSqsWorker) handleMessage(messages []*awsMessage) { if deleteMessage { deleteMessageInput := aws_sqs.DeleteMessageInput{QueueUrl: &topic.Arn, ReceiptHandle: awsMsg.ReceiptHandle} - _, err = topic.DeleteMessage(&deleteMessageInput) - if err != nil { - w.logger.WithFields(log.Fields{"message": *awsMsg.Body, "error": err}).Error("Fail to delete message") + _, delErr := topic.DeleteMessageWithContext(ctx, &deleteMessageInput) + if delErr != nil { + w.logger.WithFields(log.Fields{"message": *awsMsg.Body, "error": delErr}).Error("Fail to delete message") } } + + endProcess(err) }(msg, &wg) }