Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add headers and params in operation and request builder kwargs #1183

Merged
merged 26 commits into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
3 changes: 2 additions & 1 deletion ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
| ----------------------------------------------------------------------- | ----------- |
| `@autorest/core` | `3.6.2` |
| `@autorest/modelerfour` | `4.19.1` |
| `azure-core` dep of generated code | `1.20.1` |
| `azure-core` dep of generated code | `1.23.0` |
| `msrest` dep of generated code | `0.6.21` |
| `azure-mgmt-core` dep of generated code (If generating mgmt plane code) | `1.3.0` |

**New Features**

- Add support for handwritten customizations of generated code. For more information, see https://aka.ms/azsdk/python/dpcodegen/python/customize #1153
- Allow `header` and `params` as kwargs in operation and request-build function to hand over REST Header and Query parameters case insensitively #1183

**Bug Fixes**

Expand Down
2 changes: 1 addition & 1 deletion autorest/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _build_exceptions_set(yaml_data: List[Dict[str, Any]]) -> Set[int]:
def _build_package_dependency() -> Dict[str, str]:
return {
"dependency_azure_mgmt_core": "azure-mgmt-core<2.0.0,>=1.3.0",
"dependency_azure_core": "azure-core<2.0.0,>=1.20.1",
"dependency_azure_core": "azure-core<2.0.0,>=1.23.0",
"dependency_msrest": "msrest>=0.6.21",
}

Expand Down
4 changes: 2 additions & 2 deletions autorest/codegen/models/lro_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def imports_for_multiapi(self, async_mode: bool) -> FileImport:
file_import.add_submodule_import(poller_import_path, poller, ImportType.AZURECORE, TypingSection.CONDITIONAL)
return file_import

def imports(self, async_mode: bool) -> FileImport:
file_import = self._imports_base(async_mode)
def imports(self, async_mode: bool, is_python3_file: bool) -> FileImport:
file_import = self._imports_base(async_mode, is_python3_file)
file_import.add_submodule_import("typing", "Union", ImportType.STDLIB, TypingSection.CONDITIONAL)

poller_import_path = ".".join(self.get_poller_path(async_mode).split(".")[:-1])
Expand Down
6 changes: 3 additions & 3 deletions autorest/codegen/models/lro_paging_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

class LROPagingOperation(PagingOperation, LROOperation):

def imports(self, async_mode: bool) -> FileImport:
lro_imports = LROOperation.imports(self, async_mode)
paging_imports = PagingOperation.imports(self, async_mode)
def imports(self, async_mode: bool, is_python3_file: bool) -> FileImport:
lro_imports = LROOperation.imports(self, async_mode, is_python3_file)
paging_imports = PagingOperation.imports(self, async_mode, is_python3_file)

file_import = lro_imports
file_import.merge(paging_imports)
Expand Down
17 changes: 13 additions & 4 deletions autorest/codegen/models/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .base_builder import BaseBuilder, create_parameters
from .imports import FileImport, ImportType, TypingSection
from .schema_response import SchemaResponse
from .parameter import Parameter, get_parameter
from .parameter import Parameter, get_parameter, ParameterLocation
from .parameter_list import ParameterList, get_parameter_list
from .base_schema import BaseSchema
from .object_schema import ObjectSchema
Expand Down Expand Up @@ -171,13 +171,17 @@ def _imports_shared(self, async_mode: bool) -> FileImport: # pylint: disable=unu
def imports_for_multiapi(self, async_mode: bool) -> FileImport: # pylint: disable=unused-argument
return self._imports_shared(async_mode)

def imports(self, async_mode: bool) -> FileImport:
file_import = self._imports_base(async_mode)
def imports(self, async_mode: bool, is_python3_file: bool) -> FileImport:
file_import = self._imports_base(async_mode, is_python3_file)
if self.has_response_body and not self.has_optional_return_type and not self.code_model.options["models_mode"]:
file_import.add_submodule_import("typing", "cast", ImportType.STDLIB)
return file_import

def _imports_base(self, async_mode: bool) -> FileImport:
@staticmethod
def has_kwargs_to_pop_with_default(kwargs_to_pop: List[Parameter], location: ParameterLocation) -> bool:
return any(kwarg.has_default_value and kwarg.location == location for kwarg in kwargs_to_pop)

def _imports_base(self, async_mode: bool, is_python3_file: bool) -> FileImport:
file_import = self._imports_shared(async_mode)

# Exceptions
Expand All @@ -192,6 +196,10 @@ def _imports_base(self, async_mode: bool) -> FileImport:
file_import.add_submodule_import("typing", "TypeVar", ImportType.STDLIB, TypingSection.CONDITIONAL)
file_import.add_submodule_import("azure.core.pipeline", "PipelineResponse", ImportType.AZURECORE)
file_import.add_submodule_import("azure.core.rest", "HttpRequest", ImportType.AZURECORE)
kwargs_to_pop = self.parameters.kwargs_to_pop(is_python3_file)
if (self.has_kwargs_to_pop_with_default(kwargs_to_pop, ParameterLocation.Header) or
self.has_kwargs_to_pop_with_default(kwargs_to_pop, ParameterLocation.Query)):
file_import.add_submodule_import("azure.core.utils", "case_insensitive_dict", ImportType.AZURECORE)
if async_mode:
file_import.add_submodule_import("azure.core.pipeline.transport", "AsyncHttpResponse", ImportType.AZURECORE)
else:
Expand Down Expand Up @@ -224,6 +232,7 @@ def _imports_base(self, async_mode: bool) -> FileImport:
file_import.add_submodule_import(
f"{relative_path}_vendor", "_convert_request", ImportType.LOCAL
)

if self.code_model.options["version_tolerant"] and (
self.parameters.has_body or
any(r for r in self.responses if r.has_body)
Expand Down
4 changes: 2 additions & 2 deletions autorest/codegen/models/operation_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ def imports_for_multiapi(self, async_mode: bool) -> FileImport:
file_import.add_submodule_import(".." if async_mode else ".", "models", ImportType.LOCAL, alias="_models")
return file_import

def imports(self, async_mode: bool) -> FileImport:
def imports(self, async_mode: bool, is_python3_file: bool) -> FileImport:
file_import = FileImport()
file_import.add_submodule_import("azure.core.exceptions", "ClientAuthenticationError", ImportType.AZURECORE)
file_import.add_submodule_import("azure.core.exceptions", "ResourceNotFoundError", ImportType.AZURECORE)
file_import.add_submodule_import("azure.core.exceptions", "ResourceExistsError", ImportType.AZURECORE)
for operation in self.operations:
file_import.merge(operation.imports(async_mode))
file_import.merge(operation.imports(async_mode, is_python3_file))
local_path = "..." if async_mode else ".."
if self.code_model.has_schemas and self.code_model.options["models_mode"]:
file_import.add_submodule_import(local_path, "models", ImportType.LOCAL, alias="_models")
Expand Down
4 changes: 2 additions & 2 deletions autorest/codegen/models/paging_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def imports_for_multiapi(self, async_mode: bool) -> FileImport:

return file_import

def imports(self, async_mode: bool) -> FileImport:
file_import = self._imports_base(async_mode)
def imports(self, async_mode: bool, is_python3_file: bool) -> FileImport:
file_import = self._imports_base(async_mode, is_python3_file)
# operation adds an import for distributed_trace_async, we don't want it
file_import.imports = [i for i in file_import.imports if not i.submodule_name == "distributed_trace_async"]

Expand Down
10 changes: 7 additions & 3 deletions autorest/codegen/models/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
self.body_kwargs: List[Parameter] = []
self.is_body_kwarg = False
self.need_import = True
self.is_kwarg = (self.rest_api_name == "Content-Type" or (self.constant and self.rest_api_name != "Accept"))
self.is_kwarg = (self.rest_api_name == "Content-Type" or (self.constant and self.inputtable_by_user))

def __hash__(self) -> int:
return hash(self.serialized_name)
Expand Down Expand Up @@ -175,6 +175,10 @@ def xml_serialization_ctxt(self) -> str:
def is_body(self) -> bool:
return self.location == ParameterLocation.Body

@property
def inputtable_by_user(self) -> bool:
return self.rest_api_name != "Accept"

@property
def pre_semicolon_content_types(self) -> List[str]:
"""Splits on semicolon of media types and returns the first half.
Expand All @@ -185,8 +189,8 @@ def pre_semicolon_content_types(self) -> List[str]:
@property
def in_method_signature(self) -> bool:
return not(
# don't put accept in signature
self.rest_api_name == "Accept"
# if not inputtable, don't put in signature
not self.inputtable_by_user
# If i'm not in the method code, no point in being in signature
or not self.in_method_code
# If I'm grouped, my grouper will be on signature, not me
Expand Down
4 changes: 1 addition & 3 deletions autorest/codegen/models/request_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ def imports(self) -> FileImport:
f"{relative_path}_vendor", "_format_url_section", ImportType.LOCAL
)
if self.parameters.headers or self.parameters.query:
file_import.add_submodule_import(
"typing", "Dict", ImportType.STDLIB, typing_section=TypingSection.CONDITIONAL
)
file_import.add_submodule_import("azure.core.utils", "case_insensitive_dict", ImportType.AZURECORE)
file_import.add_submodule_import(
"typing", "Any", ImportType.STDLIB, typing_section=TypingSection.CONDITIONAL
)
Expand Down
4 changes: 2 additions & 2 deletions autorest/codegen/models/request_builder_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class RequestBuilderParameter(ParameterOnlyPathAndBodyPositional):
@property
def in_method_signature(self) -> bool:
return not(
# don't put accept in method signature
self.rest_api_name == "Accept"
# if not inputtable, don't put in signature
not self.inputtable_by_user
# If i'm not in the method code, no point in being in signature
or not self.in_method_code
# If I'm a flattened property of a body, don't want me, want the body param
Expand Down
68 changes: 47 additions & 21 deletions autorest/codegen/serializers/builder_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,6 @@ def _serialize_files_and_data_body(builder, param_name: str) -> List[str]:
retval.append("}")
return retval

def _pop_parameters_kwarg(
function_name: str,
kwarg_name: str,
) -> str:
return f'_{function_name}_parameters = kwargs.pop("{kwarg_name}", {{}}) # type: Dict[str, Any]'

def _serialize_grouped_body(builder) -> List[str]:
retval: List[str] = []
for grouped_parameter in builder.parameters.grouped:
Expand Down Expand Up @@ -411,10 +405,11 @@ def serializer_name(self) -> str:
...

def _serialize_parameter(
self, param: Parameter, function_name: str
self, param: Parameter, kwarg_name: str
) -> List[str]:
set_parameter = "_{}_parameters['{}'] = {}".format(
function_name,
function_name = "header" if kwarg_name == "headers" else "query"
set_parameter = "_{}['{}'] = {}".format(
kwarg_name,
param.rest_api_name,
utils.build_serialize_data_call(param, function_name, self.serializer_name)
)
Expand All @@ -434,10 +429,6 @@ def _get_json_response_template(self, builder) -> List[str]:
template.extend(f"response.json() == {response_body}".splitlines())
return template


def pop_kwargs_from_signature(self, builder) -> List[str]:
return utils.pop_kwargs_from_signature(self._get_kwargs_to_pop(builder))

def serialize_path(self, builder) -> List[str]:
return utils.serialize_path(builder.parameters.path, self.serializer_name)

Expand All @@ -463,6 +454,21 @@ def description_and_summary(self, builder) -> List[str]:
def serializer_name(self) -> str:
return "_SERIALIZER"

@staticmethod
def declare_non_inputtable_constants(builder) -> List[str]:
def _get_value(param: Parameter):
if param.location in [ParameterLocation.Header, ParameterLocation.Query]:
kwarg_dict = "headers" if param.location == ParameterLocation.Header else "params"
return f"_{kwarg_dict}.pop('{param.rest_api_name}', {param.constant_declaration})"
return f"{param.constant_declaration}"
return [
f"{p.serialized_name} = {_get_value(p)}"
for p in builder.parameters.constant
if p.original_parameter is None and
p.in_method_code and
not p.in_method_signature
]

def want_example_template(self, builder) -> bool:
if self.code_model.options["builders_visibility"] != "public":
return False # if we're not exposing rest layer, don't need to generate
Expand Down Expand Up @@ -507,14 +513,24 @@ def _has_data_example_template(self, builder) -> bool:
def _body_params_to_pass_to_request_creation(self, builder) -> List[str]:
...

def pop_kwargs_from_signature(self, builder) -> List[str]:
return utils.pop_kwargs_from_signature(
self._get_kwargs_to_pop(builder),
check_kwarg_dict=True,
pop_headers_kwarg=utils.PopKwargType.CASE_INSENSITIVE if bool(builder.parameters.headers)
else utils.PopKwargType.NO,
pop_params_kwarg=utils.PopKwargType.CASE_INSENSITIVE if bool(builder.parameters.query)
else utils.PopKwargType.NO,
)

def create_http_request(self, builder) -> List[str]:
retval = ["return HttpRequest("]
retval.append(f' method="{builder.method}",')
retval.append(" url=_url,")
if builder.parameters.query:
retval.append(" params=_query_parameters,")
retval.append(" params=_params,")
if builder.parameters.headers:
retval.append(" headers=_header_parameters,")
retval.append(" headers=_headers,")
if builder.parameters.has_body:
retval.extend([
f" {body_kwarg}={body_kwarg},"
Expand All @@ -526,21 +542,19 @@ def create_http_request(self, builder) -> List[str]:

def serialize_headers(self, builder) -> List[str]:
retval = ["# Construct headers"]
retval.append(_pop_parameters_kwarg("header", "headers"))
for parameter in builder.parameters.headers:
retval.extend(self._serialize_parameter(
parameter,
function_name="header",
kwarg_name="headers",
))
return retval

def serialize_query(self, builder) -> List[str]:
retval = ["# Construct parameters"]
retval.append(_pop_parameters_kwarg("query", "params"))
for parameter in builder.parameters.query:
retval.extend(self._serialize_parameter(
parameter,
function_name="query",
kwarg_name="params",
))
return retval

Expand Down Expand Up @@ -653,7 +667,15 @@ def _response_type_annotation(self, builder, modify_if_head_as_boolean: bool = T
return response_str

def pop_kwargs_from_signature(self, builder) -> List[str]:
kwargs = utils.pop_kwargs_from_signature(self._get_kwargs_to_pop(builder))
kwargs_to_pop = self._get_kwargs_to_pop(builder)
kwargs = utils.pop_kwargs_from_signature(
kwargs_to_pop,
check_kwarg_dict=True,
pop_headers_kwarg=utils.PopKwargType.CASE_INSENSITIVE if builder.has_kwargs_to_pop_with_default(
kwargs_to_pop, ParameterLocation.Header) else utils.PopKwargType.SIMPLE,
pop_params_kwarg=utils.PopKwargType.CASE_INSENSITIVE if builder.has_kwargs_to_pop_with_default(
kwargs_to_pop, ParameterLocation.Query) else utils.PopKwargType.SIMPLE,
)
kwargs.append(f"cls = kwargs.pop('cls', None) {self.cls_type_annotation(builder)}")
return kwargs

Expand Down Expand Up @@ -860,6 +882,8 @@ def _call_request_builder_helper(
if not self.code_model.options["version_tolerant"]:
template_url = template_url or f"self.{builder.name}.metadata['url']"
retval.append(f" template_url={template_url},")
retval.append(' headers=_headers,')
retval.append(' params=_params,')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you also update line 1015 to
retval.append("error_map.update(kwargs.pop('error_map', {})) or {}")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated as this:

retval.append("error_map.update(kwargs.pop('error_map', {}) or {})")

since we want to avoid kwargs['error_map'] is None?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, if the user passes in operation(error_map=None), we would have error_map = None, then map_error gets mad

retval.append(f")")
if not self.code_model.options["version_tolerant"]:
pass_files = ""
Expand Down Expand Up @@ -1019,7 +1043,7 @@ def error_map(self, builder) -> List[str]:
else:
retval.append(" 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError")
retval.append("}")
retval.append("error_map.update(kwargs.pop('error_map', {}))")
retval.append("error_map.update(kwargs.pop('error_map', {}) or {})")
return retval

@staticmethod
Expand Down Expand Up @@ -1297,6 +1321,8 @@ def initial_call(self, builder) -> List[str]:
for parameter in builder.parameters.method
])
retval.append(" cls=lambda x,y,z: x,")
retval.append(" headers=_headers,")
retval.append(" params=_params,")
retval.append(" **kwargs")
retval.append(" )")
retval.append("kwargs.pop('error_map', None)")
Expand Down
22 changes: 16 additions & 6 deletions autorest/codegen/serializers/client_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,14 @@ def init_signature_and_response_type_annotation(self, async_mode: bool) -> str:
)

def pop_kwargs_from_signature(self, async_mode: bool) -> List[str]:
return utils.pop_kwargs_from_signature(self.code_model.service_client.parameters.kwargs_to_pop(
async_mode or self.is_python3_file
))
return utils.pop_kwargs_from_signature(
self.code_model.service_client.parameters.kwargs_to_pop(
async_mode or self.is_python3_file,
),
check_kwarg_dict=False,
pop_headers_kwarg=utils.PopKwargType.NO,
pop_params_kwarg=utils.PopKwargType.NO,
)

def class_definition(self, async_mode) -> str:
class_name = self.code_model.class_name
Expand Down Expand Up @@ -212,9 +217,14 @@ def init_signature_and_response_type_annotation(self, async_mode: bool) -> str:
)

def pop_kwargs_from_signature(self, async_mode: bool) -> List[str]:
return utils.pop_kwargs_from_signature(self.code_model.global_parameters.config_kwargs_to_pop(
async_mode or self.is_python3_file
))
return utils.pop_kwargs_from_signature(
self.code_model.global_parameters.config_kwargs_to_pop(
async_mode or self.is_python3_file
),
check_kwarg_dict=False,
pop_headers_kwarg=utils.PopKwargType.NO,
pop_params_kwarg=utils.PopKwargType.NO,
)

def set_constants(self) -> List[str]:
return [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def _is_paging(operation):
for operation_group in operation_groups:
imports.merge(operation_group.imports(
async_mode=self.async_mode,
is_python3_file=self.is_python3_file,
))

template = self.env.get_or_select_template("operation_groups_container.py.jinja2")
Expand Down
Loading