-
-
Notifications
You must be signed in to change notification settings - Fork 32.7k
/
Copy path__init__.py
117 lines (94 loc) · 3.82 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""The OpenAI Conversation integration."""
from __future__ import annotations

import openai
import voluptuous as vol

from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import CONF_API_KEY, Platform
from homeassistant.core import (
    HomeAssistant,
    ServiceCall,
    ServiceResponse,
    SupportsResponse,
)
from homeassistant.exceptions import (
    ConfigEntryNotReady,
    HomeAssistantError,
    ServiceValidationError,
)
from homeassistant.helpers import config_validation as cv, selector
from homeassistant.helpers.httpx_client import get_async_client
from homeassistant.helpers.typing import ConfigType

from .const import DOMAIN, LOGGER
SERVICE_GENERATE_IMAGE = "generate_image"
PLATFORMS = (Platform.CONVERSATION,)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up OpenAI Conversation.

    Registers the ``generate_image`` service once for the whole integration.
    Each service call names the config entry whose OpenAI client to use.
    """

    async def render_image(call: ServiceCall) -> ServiceResponse:
        """Render an image with dall-e.

        Returns the first generated image as a dict (``b64_json`` excluded).

        Raises:
            ServiceValidationError: ``config_entry`` is unknown, belongs to
                another integration, or is not currently loaded.
            HomeAssistantError: the OpenAI API request failed.
        """
        entry_id = call.data["config_entry"]
        entry = hass.config_entries.async_get_entry(entry_id)

        # runtime_data is only assigned once async_setup_entry has validated the
        # API key; an entry whose setup failed (or that is unloaded) has no
        # client, so reject it here instead of crashing with AttributeError.
        if (
            entry is None
            or entry.domain != DOMAIN
            or entry.state is not ConfigEntryState.LOADED
        ):
            raise ServiceValidationError(
                translation_domain=DOMAIN,
                translation_key="invalid_config_entry",
                translation_placeholders={"config_entry": entry_id},
            )

        client: openai.AsyncClient = entry.runtime_data

        try:
            response = await client.images.generate(
                model="dall-e-3",
                prompt=call.data["prompt"],
                size=call.data["size"],
                quality=call.data["quality"],
                style=call.data["style"],
                response_format="url",
                n=1,
            )
        except openai.OpenAIError as err:
            raise HomeAssistantError(f"Error generating image: {err}") from err

        # n=1 above, so the single requested image is at index 0.
        return response.data[0].model_dump(exclude={"b64_json"})

    hass.services.async_register(
        DOMAIN,
        SERVICE_GENERATE_IMAGE,
        render_image,
        schema=vol.Schema(
            {
                vol.Required("config_entry"): selector.ConfigEntrySelector(
                    {
                        "integration": DOMAIN,
                    }
                ),
                vol.Required("prompt"): cv.string,
                vol.Optional("size", default="1024x1024"): vol.In(
                    ("1024x1024", "1024x1792", "1792x1024")
                ),
                vol.Optional("quality", default="standard"): vol.In(
                    ("standard", "hd")
                ),
                vol.Optional("style", default="vivid"): vol.In(("vivid", "natural")),
            }
        ),
        supports_response=SupportsResponse.ONLY,
    )
    return True
async def async_setup_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> bool:
    """Set up OpenAI Conversation from a config entry.

    Builds an OpenAI client on the httpx client returned by
    ``get_async_client(hass)`` and verifies the API key with a models-list
    request (10 s timeout). It then stores the client on
    ``entry.runtime_data`` and forwards setup to the conversation platform.

    Returns:
        True on success; False when the API key is rejected.

    Raises:
        ConfigEntryNotReady: any other OpenAI error during the key check, so
            Home Assistant retries setup later.
    """
    client = openai.AsyncOpenAI(
        api_key=entry.data[CONF_API_KEY],
        http_client=get_async_client(hass),
    )

    # Cache current platform data which gets added to each request (caching done by library)
    _ = await hass.async_add_executor_job(client.platform_headers)

    try:
        # On the async client, models.list() returns a lazy AsyncPaginator: the
        # HTTP request is only sent when it is awaited. It must be awaited here
        # (not merely called, e.g. inside an executor job) or the key is never
        # actually checked.
        await client.with_options(timeout=10.0).models.list()
    except openai.AuthenticationError as err:
        LOGGER.error("Invalid API key: %s", err)
        return False
    except openai.OpenAIError as err:
        raise ConfigEntryNotReady(err) from err

    entry.runtime_data = client

    await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

    return True
async def async_unload_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> bool:
    """Unload OpenAI.

    Unloads the platforms forwarded in async_setup_entry and returns whether
    that succeeded.
    """
    # The OpenAI client is deliberately not closed: it wraps the httpx client
    # obtained from get_async_client(hass), which this integration does not own.
    return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)