diff --git a/packages/http-client-python/generator/pygen/codegen/models/operation.py b/packages/http-client-python/generator/pygen/codegen/models/operation.py index 902f34fe978..54934671d6d 100644 --- a/packages/http-client-python/generator/pygen/codegen/models/operation.py +++ b/packages/http-client-python/generator/pygen/codegen/models/operation.py @@ -449,7 +449,7 @@ def imports( # pylint: disable=too-many-branches, disable=too-many-statements file_import.add_import("json", ImportType.STDLIB) if self.enable_import_deserialize_xml: file_import.add_submodule_import(relative_path, "_deserialize_xml", ImportType.LOCAL) - elif self.need_deserialize: + if self.need_deserialize: file_import.add_submodule_import(relative_path, "_deserialize", ImportType.LOCAL) if self.default_error_deserialization(serialize_namespace) or self.non_default_errors: xml_non_default_errors = any( diff --git a/packages/http-client-python/generator/pygen/codegen/serializers/parameter_serializer.py b/packages/http-client-python/generator/pygen/codegen/serializers/parameter_serializer.py index 9091f86748d..f01566d0f94 100644 --- a/packages/http-client-python/generator/pygen/codegen/serializers/parameter_serializer.py +++ b/packages/http-client-python/generator/pygen/codegen/serializers/parameter_serializer.py @@ -162,13 +162,14 @@ def serialize_query_header( param.wire_name, self.serialize_parameter(param, serializer_name), ) - if not param.optional and (param.in_method_signature or param.constant): - retval = [set_parameter] - else: + is_content_type = getattr(param, "is_content_type", False) + if is_content_type or param.optional or not (param.in_method_signature or param.constant): retval = [ f"if {param.full_client_name} is not None:", f" {set_parameter}", ] + else: + retval = [set_parameter] return retval @staticmethod diff --git a/packages/http-client-python/generator/test/unittests/test_operation_imports.py b/packages/http-client-python/generator/test/unittests/test_operation_imports.py new file mode 100644 index 00000000000..0d2043c8334 --- /dev/null +++ b/packages/http-client-python/generator/test/unittests/test_operation_imports.py @@ -0,0 +1,184 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +from pygen.codegen.models import ( + Operation, + Response, + ParameterList, + CodeModel, + RequestBuilder, + Client, +) +from pygen.codegen.models.parameter_list import RequestBuilderParameterList +from pygen.codegen.models.primitive_types import StringType + + +@pytest.fixture +def code_model(): + return CodeModel( + { + "clients": [ + { + "name": "client", + "namespace": "blah", + "moduleName": "blah", + "parameters": [], + "url": "", + "operationGroups": [], + } + ], + "namespace": "namespace", + }, + options={ + "show-send-request": True, + "builders-visibility": "public", + "show-operations": True, + "models-mode": "dpg", + "version-tolerant": True, + "flavor": "azure", + "tracing": False, + "azure-arm": False, + "head-as-boolean": False, + "combine-operation-files": False, + "validate-versioning": False, + }, + ) + + +@pytest.fixture +def client(code_model): + return Client( + { + "name": "client", + "namespace": "blah", + "moduleName": "blah", + "parameters": [], + "url": "", + "operationGroups": [], + }, + code_model, + parameters=[], + ) + + +@pytest.fixture +def request_builder(code_model, client): + return RequestBuilder( + yaml_data={ + "url": "http://fake.com", + "method": "GET", + "groupName": "blah", + "isOverload": False, + "apiVersions": [], + }, + client=client, + code_model=code_model, + name="test_imports_operation", + parameters=RequestBuilderParameterList({}, code_model, parameters=[]), + ) + + +@pytest.fixture +def base_type(code_model): + return StringType({"type": "string"}, code_model) + + +def _make_operation(code_model, client, request_builder, responses, exceptions=None): + return Operation( + yaml_data={ + "url": "http://fake.com", + "method": "GET", + "groupName": "blah", + "isOverload": False, + "apiVersions": [], + }, + client=client, + code_model=code_model, + request_builder=request_builder, + name="test_imports_operation", + parameters=ParameterList({}, code_model, []), + responses=responses, + exceptions=exceptions or [], + ) + + +def _has_import(file_import, submodule_name): + """Check whether a FileImport contains an import with the given submodule_name.""" + return any(imp.submodule_name == submodule_name for imp in file_import.imports) + + +def test_operation_imports_both_deserialize_xml_and_json( + code_model, client, request_builder, base_type +): + """When an operation has both XML and JSON/typed responses, both + _deserialize_xml and _deserialize should be imported.""" + xml_response = Response( + yaml_data={"statusCodes": [200], "defaultContentType": "application/xml"}, + code_model=code_model, + headers=[], + type=base_type, + ) + json_response = Response( + yaml_data={"statusCodes": [201], "defaultContentType": "application/json"}, + code_model=code_model, + headers=[], + type=base_type, + ) + operation = _make_operation( + code_model, client, request_builder, responses=[xml_response, json_response] + ) + + assert operation.enable_import_deserialize_xml + assert operation.need_deserialize + + file_import = operation.imports(async_mode=False, serialize_namespace="namespace") + assert _has_import(file_import, "_deserialize_xml") + assert _has_import(file_import, "_deserialize") + + +def test_operation_imports_only_xml( + code_model, client, request_builder, base_type +): + """When an operation has only XML responses, _deserialize_xml should be + imported. _deserialize should also be present because the response has + a non-binary type (triggers need_deserialize).""" + xml_response = Response( + yaml_data={"statusCodes": [200], "defaultContentType": "application/xml"}, + code_model=code_model, + headers=[], + type=base_type, + ) + operation = _make_operation( + code_model, client, request_builder, responses=[xml_response] + ) + + assert operation.enable_import_deserialize_xml + + file_import = operation.imports(async_mode=False, serialize_namespace="namespace") + assert _has_import(file_import, "_deserialize_xml") + + +def test_operation_imports_only_json( + code_model, client, request_builder, base_type +): + """When an operation has only JSON responses (no XML), _deserialize + should be imported but _deserialize_xml should NOT.""" + json_response = Response( + yaml_data={"statusCodes": [200], "defaultContentType": "application/json"}, + code_model=code_model, + headers=[], + type=base_type, + ) + operation = _make_operation( + code_model, client, request_builder, responses=[json_response] + ) + + assert not operation.enable_import_deserialize_xml + assert operation.need_deserialize + + file_import = operation.imports(async_mode=False, serialize_namespace="namespace") + assert not _has_import(file_import, "_deserialize_xml") + assert _has_import(file_import, "_deserialize") diff --git a/packages/http-client-python/generator/test/unittests/test_parameter_serializer.py b/packages/http-client-python/generator/test/unittests/test_parameter_serializer.py new file mode 100644 index 00000000000..abc4a7a7346 --- /dev/null +++ b/packages/http-client-python/generator/test/unittests/test_parameter_serializer.py @@ -0,0 +1,83 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from pygen.codegen.models import Parameter, CodeModel +from pygen.codegen.models.primitive_types import StringType +from pygen.codegen.serializers.parameter_serializer import ParameterSerializer + + +def get_code_model(): + return CodeModel( + { + "clients": [ + { + "name": "client", + "namespace": "blah", + "moduleName": "blah", + "parameters": [], + "url": "", + "operationGroups": [], + } + ], + "namespace": "namespace", + }, + options={ + "show-send-request": True, + "builders-visibility": "public", + "show-operations": True, + "models-mode": "dpg", + }, + ) + + +def make_header_param(wire_name, optional=False): + cm = get_code_model() + return Parameter( + yaml_data={ + "wireName": wire_name, + "clientName": wire_name.replace("-", "_").lower(), + "location": "header", + "optional": optional, + "implementation": "Method", + "inOverload": False, + "inOverloaded": False, + }, + code_model=cm, + type=StringType({"type": "string"}, cm), + ) + + +def test_content_type_header_has_none_guard(): + """Content-Type header should always have an 'if is not None' guard, + even when the parameter is required, to handle the case where + content_type is set to None for optional bodies.""" + param = make_header_param("Content-Type", optional=False) + result = ParameterSerializer("namespace").serialize_query_header( + param, "headers", "_SERIALIZER", is_legacy=False + ) + joined = "\n".join(result) + assert "is not None" in joined + + +def test_non_content_type_required_header_no_guard(): + """A required non-Content-Type header should be a direct assignment + with no None guard.""" + param = make_header_param("x-ms-version", optional=False) + result = ParameterSerializer("namespace").serialize_query_header( + param, "headers", "_SERIALIZER", is_legacy=False + ) + joined = "\n".join(result) + assert "is not None" not in joined + + +def test_optional_header_has_none_guard(): + """An optional header parameter should have an 'if is not None' guard + regardless of its wire name.""" + param = make_header_param("x-ms-version", optional=True) + result = ParameterSerializer("namespace").serialize_query_header( + param, "headers", "_SERIALIZER", is_legacy=False + ) + joined = "\n".join(result) + assert "is not None" in joined