diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_decode.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_decode.py index afe58f311a0c..839cc1796ae8 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_decode.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_decode.py @@ -22,6 +22,7 @@ from typing_extensions import Literal +from . import described from .message import Message, Header, Properties if TYPE_CHECKING: @@ -153,6 +154,7 @@ def _decode_binary_large(buffer: memoryview) -> Tuple[memoryview, bytes]: length_index = c_unsigned_long.unpack(buffer[:4])[0] + 4 return buffer[length_index:], buffer[4:length_index].tobytes() + def _decode_decimal128(buffer: memoryview) -> Tuple[memoryview, decimal.Decimal]: """ Decode a Decimal128 value from the buffer. @@ -254,11 +256,20 @@ def _decode_map_large(buffer: memoryview) -> Tuple[memoryview, Dict[Any, Any]]: def _decode_array_small(buffer: memoryview) -> Tuple[memoryview, List[Any]]: count = buffer[1] # Ignore first byte (size) and just rely on count if count: - subconstructor = buffer[2] - buffer = buffer[3:] values = [None] * count - for i in range(count): - buffer, values[i] = _DECODE_BY_CONSTRUCTOR[subconstructor](buffer) + subconstructor = buffer[2] + + if subconstructor == 0: + composite_type = buffer[3] + buffer, descriptor = _DECODE_BY_CONSTRUCTOR[composite_type](buffer[4:]) + subconstructor = buffer[0] + buffer = buffer[1:] + for i in range(count): + buffer, values[i] = _decode_described_array(buffer, subconstructor, descriptor) + else: + buffer = buffer[3:] + for i in range(count): + buffer, values[i] = _DECODE_BY_CONSTRUCTOR[subconstructor](buffer) return buffer, values return buffer[2:], [] @@ -266,11 +277,20 @@ def _decode_array_small(buffer: memoryview) -> Tuple[memoryview, List[Any]]: def _decode_array_large(buffer: memoryview) -> Tuple[memoryview, List[Any]]: count = c_unsigned_long.unpack(buffer[4:8])[0] if count: - subconstructor = buffer[8] - buffer = buffer[9:] values = [None] * count - for i in range(count): - buffer, values[i] = _DECODE_BY_CONSTRUCTOR[subconstructor](buffer) + subconstructor = buffer[8] + + if subconstructor == 0: + composite_type = buffer[9] + buffer, descriptor = _DECODE_BY_CONSTRUCTOR[composite_type](buffer[10:]) + subconstructor = buffer[0] + buffer = buffer[1:] + for i in range(count): + buffer, values[i] = _decode_described_array(buffer, subconstructor, descriptor) + else: + buffer = buffer[9:] + for i in range(count): + buffer, values[i] = _DECODE_BY_CONSTRUCTOR[subconstructor](buffer) return buffer, values return buffer[8:], [] @@ -280,7 +300,25 @@ def _decode_described(buffer: memoryview) -> Tuple[memoryview, object]: # descriptor without decoding descriptor value composite_type = buffer[0] buffer, descriptor = _DECODE_BY_CONSTRUCTOR[composite_type](buffer[1:]) - buffer, value = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) + tp = buffer[0] + buffer, value = _DECODE_BY_CONSTRUCTOR[tp](buffer[1:]) + try: + value = _DESCR_BY_CONSTRUCTOR[tp](value, descriptor=descriptor) + except KeyError: + pass + try: + composite_type = cast(int, _COMPOSITES[descriptor]) + return buffer, {composite_type: value} + except KeyError: + return buffer, value + + +def _decode_described_array(buffer: memoryview, tp, descriptor) -> Tuple[memoryview, object]: + buffer, value = _DECODE_BY_CONSTRUCTOR[tp](buffer) + try: + value = _DESCR_BY_CONSTRUCTOR[tp](value, descriptor=descriptor) + except KeyError: + pass try: composite_type = cast(int, _COMPOSITES[descriptor]) return buffer, {composite_type: value} @@ -394,3 +432,36 @@ def decode_empty_frame(header: memoryview) -> Tuple[int, bytes]: _DECODE_BY_CONSTRUCTOR[209] = _decode_map_large _DECODE_BY_CONSTRUCTOR[224] = _decode_array_small _DECODE_BY_CONSTRUCTOR[240] = _decode_array_large + +_DESCR_BY_CONSTRUCTOR = { + 67: described.DescribedInt, + 68: described.DescribedInt, + 69: described.DescribedList, + 80: described.DescribedInt, + 81: described.DescribedInt, + 82: described.DescribedInt, + 83: described.DescribedInt, + 84: described.DescribedInt, + 85: described.DescribedInt, + 96: described.DescribedInt, + 97: described.DescribedInt, + 112: described.DescribedInt, + 113: described.DescribedInt, + 114: described.DescribedFloat, + 128: described.DescribedInt, + 129: described.DescribedInt, + 130: described.DescribedFloat, + 131: described.DescribedInt, + 160: described.DescribedBytes, + 161: described.DescribedBytes, + 163: described.DescribedBytes, + 176: described.DescribedBytes, + 177: described.DescribedBytes, + 179: described.DescribedBytes, + 192: described.DescribedList, + 193: described.DescribedDict, + 208: described.DescribedList, + 209: described.DescribedDict, + 224: described.DescribedList, + 240: described.DescribedList, +} diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/described.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/described.py new file mode 100644 index 000000000000..a86a15f119ed --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/described.py @@ -0,0 +1,31 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +class Described: + def __new__(cls, value, descriptor=None): + obj = super().__new__(cls, value) + obj.descriptor = descriptor + return obj + + +class DescribedInt(Described, int): + pass + + +class DescribedFloat(Described, float): + pass + + +class DescribedBytes(Described, bytes): + pass + + +class DescribedList(Described, list): + pass + + +class DescribedDict(Described, dict): + pass diff --git a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/unittest/test_decode.py b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/unittest/test_decode.py index 05b77014c20c..4254715abffa 100644 --- a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/unittest/test_decode.py +++ b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/unittest/test_decode.py @@ -1,5 +1,5 @@ import pytest -from azure.eventhub._pyamqp._decode import _decode_decimal128 +from azure.eventhub._pyamqp._decode import _decode_decimal128, _decode_described, _decode_array_small, _decode_array_large from decimal import Decimal @@ -18,3 +18,30 @@ def test_decimal_decode(value, expected): assert output[1] == expected +def test_described(): + value = b"\x80\0\0\x017\0\0\x07\xd3\xd0\0\0\0\x12\0\0\0\x02\xa1\ntest/topicP\0" + buffer, output = _decode_described(memoryview(value)) + assert output.descriptor == 1335734831059 + assert output == [b'test/topic', 0] + + +def test_array_of_described(): + value = b"\0\x03\0\x80\0\0\x017\0\0\x07\xd4\xd0\0\0\0\x0c\0\0\0\x02\xa1\x02n1\xa1\x02v1\0\0\0\x0c\0\0\0\x02\xa1\x02n2\xa1\x02v2\0\0\0\n\0\0\0\x02\xa1\x02n1\xa1\0" + + buffer, output = _decode_array_small(memoryview(value)) + assert output == [[b'n1', b'v1'], [b'n2', b'v2'], [b'n1', b'']] + assert output[0].descriptor == 1335734831060 + assert output[1].descriptor == 1335734831060 + assert output[2].descriptor == 1335734831060 + + +def test_array_of_described_large(): + value = b"\0\0\x0e\x0f\0\0\x01\0\0\x80\0\0\x017\0\0\x07\xd4\xd0" + for i in range(256): + value += b"\0\0\0\n\0\0\0\x02\xa1\x01n\xa1\x01v" + + buffer, output = _decode_array_large(memoryview(value)) + assert len(output) == 256 + for i in range(256): + assert output[i] == [b'n', b'v'] + assert output[i].descriptor == 1335734831060