Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def imports( # pylint: disable=too-many-branches, disable=too-many-statements
file_import.add_import("json", ImportType.STDLIB)
if self.enable_import_deserialize_xml:
file_import.add_submodule_import(relative_path, "_deserialize_xml", ImportType.LOCAL)
elif self.need_deserialize:
if self.need_deserialize:
file_import.add_submodule_import(relative_path, "_deserialize", ImportType.LOCAL)
if self.default_error_deserialization(serialize_namespace) or self.non_default_errors:
xml_non_default_errors = any(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,14 @@ def serialize_query_header(
param.wire_name,
self.serialize_parameter(param, serializer_name),
)
if not param.optional and (param.in_method_signature or param.constant):
retval = [set_parameter]
else:
is_content_type = getattr(param, "is_content_type", False)
if is_content_type or param.optional or not (param.in_method_signature or param.constant):
retval = [
f"if {param.full_client_name} is not None:",
f" {set_parameter}",
]
else:
retval = [set_parameter]
return retval

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import pytest
from pygen.codegen.models import (
Operation,
Response,
ParameterList,
CodeModel,
RequestBuilder,
Client,
)
from pygen.codegen.models.parameter_list import RequestBuilderParameterList
from pygen.codegen.models.primitive_types import StringType


@pytest.fixture
def code_model():
return CodeModel(
{
"clients": [
{
"name": "client",
"namespace": "blah",
"moduleName": "blah",
"parameters": [],
"url": "",
"operationGroups": [],
}
],
"namespace": "namespace",
},
options={
"show-send-request": True,
"builders-visibility": "public",
"show-operations": True,
"models-mode": "dpg",
"version-tolerant": True,
"flavor": "azure",
"tracing": False,
"azure-arm": False,
"head-as-boolean": False,
"combine-operation-files": False,
"validate-versioning": False,
},
)


@pytest.fixture
def client(code_model):
return Client(
{
"name": "client",
"namespace": "blah",
"moduleName": "blah",
"parameters": [],
"url": "",
"operationGroups": [],
},
code_model,
parameters=[],
)


@pytest.fixture
def request_builder(code_model, client):
return RequestBuilder(
yaml_data={
"url": "http://fake.com",
"method": "GET",
"groupName": "blah",
"isOverload": False,
"apiVersions": [],
},
client=client,
code_model=code_model,
name="test_imports_operation",
parameters=RequestBuilderParameterList({}, code_model, parameters=[]),
)


@pytest.fixture
def base_type(code_model):
return StringType({"type": "string"}, code_model)


def _make_operation(code_model, client, request_builder, responses, exceptions=None):
return Operation(
yaml_data={
"url": "http://fake.com",
"method": "GET",
"groupName": "blah",
"isOverload": False,
"apiVersions": [],
},
client=client,
code_model=code_model,
request_builder=request_builder,
name="test_imports_operation",
parameters=ParameterList({}, code_model, []),
responses=responses,
exceptions=exceptions or [],
)


def _has_import(file_import, submodule_name):
"""Check whether a FileImport contains an import with the given submodule_name."""
return any(imp.submodule_name == submodule_name for imp in file_import.imports)


def test_operation_imports_both_deserialize_xml_and_json(
code_model, client, request_builder, base_type
):
"""When an operation has both XML and JSON/typed responses, both
_deserialize_xml and _deserialize should be imported."""
xml_response = Response(
yaml_data={"statusCodes": [200], "defaultContentType": "application/xml"},
code_model=code_model,
headers=[],
type=base_type,
)
json_response = Response(
yaml_data={"statusCodes": [201], "defaultContentType": "application/json"},
code_model=code_model,
headers=[],
type=base_type,
)
operation = _make_operation(
code_model, client, request_builder, responses=[xml_response, json_response]
)

assert operation.enable_import_deserialize_xml
assert operation.need_deserialize

file_import = operation.imports(async_mode=False, serialize_namespace="namespace")
assert _has_import(file_import, "_deserialize_xml")
assert _has_import(file_import, "_deserialize")


def test_operation_imports_only_xml(
code_model, client, request_builder, base_type
):
"""When an operation has only XML responses, _deserialize_xml should be
imported. _deserialize should also be present because the response has
a non-binary type (triggers need_deserialize)."""
xml_response = Response(
yaml_data={"statusCodes": [200], "defaultContentType": "application/xml"},
code_model=code_model,
headers=[],
type=base_type,
)
operation = _make_operation(
code_model, client, request_builder, responses=[xml_response]
)

assert operation.enable_import_deserialize_xml

file_import = operation.imports(async_mode=False, serialize_namespace="namespace")
assert _has_import(file_import, "_deserialize_xml")


def test_operation_imports_only_json(
code_model, client, request_builder, base_type
):
"""When an operation has only JSON responses (no XML), _deserialize
should be imported but _deserialize_xml should NOT."""
json_response = Response(
yaml_data={"statusCodes": [200], "defaultContentType": "application/json"},
code_model=code_model,
headers=[],
type=base_type,
)
operation = _make_operation(
code_model, client, request_builder, responses=[json_response]
)

assert not operation.enable_import_deserialize_xml
assert operation.need_deserialize

file_import = operation.imports(async_mode=False, serialize_namespace="namespace")
assert not _has_import(file_import, "_deserialize_xml")
assert _has_import(file_import, "_deserialize")
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from pygen.codegen.models import Parameter, CodeModel
from pygen.codegen.models.primitive_types import StringType
from pygen.codegen.serializers.parameter_serializer import ParameterSerializer


def get_code_model():
return CodeModel(
{
"clients": [
{
"name": "client",
"namespace": "blah",
"moduleName": "blah",
"parameters": [],
"url": "",
"operationGroups": [],
}
],
"namespace": "namespace",
},
options={
"show-send-request": True,
"builders-visibility": "public",
"show-operations": True,
"models-mode": "dpg",
},
)


def make_header_param(wire_name, optional=False):
cm = get_code_model()
return Parameter(
yaml_data={
"wireName": wire_name,
"clientName": wire_name.replace("-", "_").lower(),
"location": "header",
"optional": optional,
"implementation": "Method",
"inOverload": False,
"inOverloaded": False,
},
code_model=cm,
type=StringType({"type": "string"}, cm),
)


def test_content_type_header_has_none_guard():
"""Content-Type header should always have an 'if is not None' guard,
even when the parameter is required, to handle the case where
content_type is set to None for optional bodies."""
param = make_header_param("Content-Type", optional=False)
result = ParameterSerializer("namespace").serialize_query_header(
param, "headers", "_SERIALIZER", is_legacy=False
)
joined = "\n".join(result)
assert "is not None" in joined


def test_non_content_type_required_header_no_guard():
"""A required non-Content-Type header should be a direct assignment
with no None guard."""
param = make_header_param("x-ms-version", optional=False)
result = ParameterSerializer("namespace").serialize_query_header(
param, "headers", "_SERIALIZER", is_legacy=False
)
joined = "\n".join(result)
assert "is not None" not in joined


def test_optional_header_has_none_guard():
"""An optional header parameter should have an 'if is not None' guard
regardless of its wire name."""
param = make_header_param("x-ms-version", optional=True)
result = ParameterSerializer("namespace").serialize_query_header(
param, "headers", "_SERIALIZER", is_legacy=False
)
joined = "\n".join(result)
assert "is not None" in joined
Loading