Skip to content

Commit

Permalink
switch to dict models (#1203)
Browse files Browse the repository at this point in the history
  • Loading branch information
iscai-msft authored Mar 23, 2022
1 parent 8207add commit d7fe86f
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 271 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ def test_dunder_all(package_name):
def test_imports():
# make sure we can import all of the models we've added to the customization class
from dpgcustomizationcustomizedversiontolerant.models import (
Input, LROProduct, Product, ProductResult
Input, LROProduct, Product
)
models = [Input, LROProduct, Product, ProductResult]
models = [Input, LROProduct, Product]
# check public models
public_models = [
name for name, obj in
inspect.getmembers(sys.modules["dpgcustomizationcustomizedversiontolerant.models"])
if name[0] != "_" and obj in models
]
assert len(public_models) == 4
assert len(public_models) == 3
Original file line number Diff line number Diff line change
Expand Up @@ -2,110 +2,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

from typing import Any, overload, Union, TYPE_CHECKING
from azure.core.paging import ItemPaged
from typing import Any, Iterable
from azure.core.polling import LROPoller
from ..models import * # pylint: disable=wildcard-import,unused-wildcard-import
from ._operations import DPGClientOperationsMixin as DPGClientOperationsMixinGenerated, JSONType

if TYPE_CHECKING:
from typing import Literal


def mode_checks(*args, **kwargs: Any) -> bool:
"""Return whether model_mode is True"""
if args:
return args[0] == "model"
if "mode" in kwargs:
return kwargs["mode"] == "model"
raise ValueError("You need to specify 'mode' equal to 'raw' or 'model'.")
from ._operations import DPGClientOperationsMixin as DPGClientOperationsMixinGenerated


class DPGClientOperationsMixin(DPGClientOperationsMixinGenerated):
@overload
def get_model(self, mode: "Literal['raw']", **kwargs: Any) -> JSONType:
"""Pass in mode='raw' to get raw JSON out"""

@overload
def get_model(self, mode: "Literal['model']", **kwargs: Any) -> Product:
"""Pass in mode='model' to get a handwritten model out"""

@overload
def get_model(self, mode: str, **kwargs: Any):
"""Pass in other modes"""
raise Exception("No Implementation")

def get_model(self, *args, **kwargs: Any) -> Union[JSONType, Product]:
model_mode = mode_checks(*args, **kwargs)
response = super().get_model(*args, **kwargs)
if model_mode:
return Product.deserialize(response)
return response

@overload
def post_model(self, mode: "Literal['raw']", input: JSONType, **kwargs: Any) -> JSONType:
"""Pass in mode='raw' to pass in raw json"""

@overload
def post_model(self, mode: "Literal['model']", input: Input, **kwargs: Any) -> Product:
"""Pass in mode='model' to pass in model"""

@overload
def post_model(self, mode: str, input: Input, **kwargs: Any):
"""Pass in other modes"""
raise Exception("No Implementation")

def post_model(self, *args, **kwargs: Any) -> JSONType:
model_mode = mode_checks(*args, **kwargs)
if model_mode:
if len(args) > 1:
args = list(args) # type: ignore
args[1] = Input.serialize(args[1]) # type: ignore # pylint: disable=expression-not-assigned
else:
kwargs["input"] == Input.serialize(kwargs["input"]) # pylint: disable=expression-not-assigned
response = super().post_model(*args, **kwargs)
if model_mode:
return Product.deserialize(response)
return response

@overload
def get_pages(self, mode: "Literal['raw']", **kwargs) -> ItemPaged[JSONType]:
"""Pass in mode='raw' to pass for raw json"""

@overload
def get_pages(self, mode: "Literal['model']", **kwargs) -> ItemPaged[Product]:
"""Pass in mode='model' to pass for raw json"""

@overload
def get_pages(self, mode: str, **kwargs: Any):
"""Pass in other modes"""
raise Exception("No Implementation")

def get_pages(self, *args, **kwargs):
model_mode = mode_checks(*args, **kwargs)
if model_mode:
kwargs["cls"] = lambda objs: [Product.deserialize(x) for x in objs]
return super().get_pages(*args, **kwargs)

@overload
def begin_lro(self, mode: "Literal['raw']", **kwargs) -> LROPoller[JSONType]:
"""Pass in mode='raw' to pass for raw json"""

@overload
def begin_lro(self, mode: "Literal['model']", **kwargs) -> LROPoller[LROProduct]:
"""Pass in mode='model' to pass for raw json"""

@overload
def begin_lro(self, mode: str, **kwargs: Any):
"""Pass in other modes"""
raise Exception("No Implementation")

def begin_lro(self, *args, **kwargs: Any):
model_mode = mode_checks(*args, **kwargs)
if model_mode:
kwargs["cls"] = lambda pipeline_response, deserialized, headers: LROProduct.deserialize(pipeline_response)
return super().begin_lro(*args, **kwargs)
def get_model(self, mode: str, **kwargs: Any) -> Product:
response = super().get_model(mode, **kwargs)
return Product(**response)

def post_model(self, mode: str, input: Product, **kwargs: Any) -> Product:
response = super().post_model(mode, input, **kwargs)
return Product(**response)

def get_pages(self, mode: str, **kwargs: Any) -> Iterable[Product]:
return super().get_pages(mode, cls=lambda objs: [Product(**x) for x in objs], **kwargs)

def begin_lro(self, mode: str, **kwargs: Any) -> LROPoller[LROProduct]:
return super().begin_lro(
mode,
cls=lambda pipeline_response, deserialized, headers: LROProduct._from_dict( # pylint: disable=protected-access
**deserialized
),
**kwargs
)


def patch_sdk():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,102 +3,36 @@
# Licensed under the MIT License.
# ------------------------------------

from typing import Any, Union, overload, TYPE_CHECKING
from typing import Any, AsyncIterable, TYPE_CHECKING
from azure.core.polling import AsyncLROPoller
from azure.core.async_paging import AsyncItemPaged

from ._operations import DPGClientOperationsMixin as DPGClientOperationsMixinGenerated, JSONType
from ._operations import DPGClientOperationsMixin as DPGClientOperationsMixinGenerated
from ...models import * # pylint: disable=wildcard-import,unused-wildcard-import
from ..._operations._patch import mode_checks

if TYPE_CHECKING:
from typing import Literal


class DPGClientOperationsMixin(DPGClientOperationsMixinGenerated):
@overload
async def get_model(self, mode: "Literal['raw']", **kwargs: Any) -> JSONType:
"""Pass in mode='raw' to get raw JSON out"""

@overload
async def get_model(self, mode: "Literal['model']", **kwargs: Any) -> Product:
"""Pass in mode='model' to get a handwritten model out"""

@overload
async def get_model(self, mode: str, **kwargs: Any):
"""Pass in other modes"""
raise Exception("No Implementation")

async def get_model(self, *args, **kwargs: Any) -> Union[JSONType, Product]:
model_mode = mode_checks(*args, **kwargs)
response = await super().get_model(*args, **kwargs)
if model_mode:
return Product.deserialize(response)
return response

@overload
async def post_model(self, mode: "Literal['raw']", input: JSONType, **kwargs: Any) -> JSONType:
"""Pass in mode='raw' to pass in raw json"""

@overload
async def post_model(self, mode: "Literal['model']", input: Input, **kwargs: Any) -> Product:
"""Pass in mode='model' to pass in model"""

@overload
async def post_model(self, mode: str, input: Input, **kwargs: Any):
"""Pass in other modes"""
raise Exception("No Implementation")

async def post_model(self, *args, **kwargs: Any) -> JSONType:
model_mode = mode_checks(*args, **kwargs)
if model_mode:
if len(args) > 1:
args = list(args) # type: ignore
args[1] = Input.serialize(args[1]) # type: ignore # pylint: disable=expression-not-assigned
else:
kwargs["input"] == Input.serialize(kwargs["input"]) # pylint: disable=expression-not-assigned
response = await super().post_model(*args, **kwargs)
if model_mode:
return Product.deserialize(response)
return response

@overload
def get_pages(self, mode: "Literal['raw']", **kwargs) -> AsyncItemPaged[JSONType]:
"""Pass in mode='raw' to pass for raw json"""

@overload
def get_pages(self, mode: "Literal['model']", **kwargs) -> AsyncItemPaged[Product]:
"""Pass in mode='model' to pass for raw json"""

@overload
def get_pages(self, mode: str, **kwargs: Any):
"""Pass in other modes"""
raise Exception("No Implementation")

def get_pages(self, *args, **kwargs):
model_mode = mode_checks(*args, **kwargs)
if model_mode:
kwargs["cls"] = lambda objs: [Product.deserialize(x) for x in objs]
return super().get_pages(*args, **kwargs)

@overload
async def begin_lro(self, mode: "Literal['raw']", **kwargs) -> AsyncLROPoller[JSONType]:
"""Pass in mode='raw' to pass for raw json"""

@overload
async def begin_lro(self, mode: "Literal['model']", **kwargs) -> AsyncLROPoller[LROProduct]:
"""Pass in mode='model' to pass for raw json"""

@overload
async def begin_lro(self, mode: str, **kwargs: Any):
"""Pass in other modes"""
raise Exception("No Implementation")

async def begin_lro(self, *args, **kwargs: Any):
model_mode = mode_checks(*args, **kwargs)
if model_mode:
kwargs["cls"] = lambda pipeline_response, deserialized, headers: LROProduct.deserialize(pipeline_response)
return await super().begin_lro(*args, **kwargs)
async def get_model(self, mode: str, **kwargs: Any) -> Product:
response = await super().get_model(mode, **kwargs)
return Product(**response)

async def post_model(self, mode: str, input: Product, **kwargs: Any) -> Product:
response = await super().post_model(mode, input, **kwargs)
return Product(**response)

def get_pages(self, mode: str, **kwargs) -> AsyncIterable[Product]:
return super().get_pages(mode, cls=lambda objs: [Product(**x) for x in objs], **kwargs)

async def begin_lro(self, mode: str, **kwargs: Any) -> AsyncLROPoller[LROProduct]:
return await super().begin_lro(
mode,
cls=lambda pipeline_response, deserialized, headers: LROProduct._from_dict( # pylint: disable=protected-access
**deserialized
),
**kwargs
)


def patch_sdk():
Expand Down
Loading

0 comments on commit d7fe86f

Please sign in to comment.