-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy path test_operator.py
83 lines (66 loc) · 2.89 KB
/
test_operator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import datetime
import os
import unittest
from tempfile import NamedTemporaryFile
from unittest import mock
from airflow.models import DAG
from airflow_clickhouse_plugin.hooks.clickhouse_hook import ClickHouseHook
from airflow_clickhouse_plugin.operators.clickhouse_operator import ClickHouseOperator
_TEST_START_DATE = datetime.datetime.now()
class ClickHouseOperatorTestCase(unittest.TestCase):
    """Unit tests for ``ClickHouseOperator``.

    The hook tests patch ``ClickHouseHook`` where the operator module looks it
    up, so no real connection is made; they check only how the operator builds
    the hook and what it passes to ``run``. The template tests check that
    ``.sql`` file names given as ``sql`` are replaced by the file contents when
    ``resolve_template_files`` is called.
    """

    @mock.patch(
        'airflow_clickhouse_plugin'
        '.operators.clickhouse_operator.ClickHouseHook',
    )
    def test(self, clickhouse_hook_mock: mock.MagicMock):
        """Every constructor argument is forwarded unchanged to the hook."""
        # Unique sentinels: the assertions only pass if these exact objects
        # (not copies or defaults) reach the hook.
        sql = object()
        clickhouse_conn_id = object()
        parameters = object()
        database = object()
        op = ClickHouseOperator(
            task_id='_', sql=sql, clickhouse_conn_id=clickhouse_conn_id,
            parameters=parameters, database=database,
        )
        op.execute(context={})
        clickhouse_hook_mock.assert_called_once_with(
            clickhouse_conn_id=clickhouse_conn_id,
            database=database,
        )
        # ``return_value`` is the hook instance the operator received; using
        # it avoids recording an extra call on the patched class.
        clickhouse_hook_mock.return_value.run.assert_called_once_with(
            sql, parameters,
        )

    @mock.patch(
        'airflow_clickhouse_plugin'
        '.operators.clickhouse_operator.ClickHouseHook',
    )
    def test_defaults(self, clickhouse_hook_mock: mock.MagicMock):
        """Omitted arguments fall back to the hook's default connection,
        no database and no query parameters."""
        sql = 'SELECT 1'
        op = ClickHouseOperator(task_id='_', sql=sql)
        op.execute(context={})
        clickhouse_hook_mock.assert_called_once_with(
            clickhouse_conn_id=ClickHouseHook.default_conn_name,
            database=None,
        )
        clickhouse_hook_mock.return_value.run.assert_called_once_with(sql, None)

    def test_template_fields_overrides(self):
        """Only the private ``_sql`` attribute is declared as a template field."""
        self.assertEqual(ClickHouseOperator.template_fields, ('_sql',))

    def _make_sql_template_file(self, content: bytes = b'{{ ds }}') -> str:
        """Create a ``.sql`` file containing *content* and return its path.

        The file is created with ``delete=False`` and closed before use, so
        the template loader can reopen it by name on every platform (an
        still-open ``NamedTemporaryFile`` cannot be reopened on Windows).
        ``addCleanup`` removes it when the test finishes.
        """
        with NamedTemporaryFile(suffix='.sql', delete=False) as sql_file:
            # Register removal first so the file is cleaned up even if the
            # write below fails.
            self.addCleanup(os.remove, sql_file.name)
            sql_file.write(content)
        return sql_file.name

    def test_resolve_template_files_value(self):
        """A single ``.sql`` file name is replaced by the file's contents."""
        sql_file_dir, sql_file_name = os.path.split(
            self._make_sql_template_file(b'{{ ds }}'),
        )
        with DAG('test-dag', start_date=_TEST_START_DATE, template_searchpath=sql_file_dir):
            task = ClickHouseOperator(task_id='test_task', sql=sql_file_name)
        task.resolve_template_files()
        # The raw template text is loaded; it is not rendered at this stage.
        self.assertEqual(task._sql, '{{ ds }}')

    def test_resolve_template_files_list(self):
        """In a list of statements, only ``.sql`` file names are replaced."""
        sql_file_dir, sql_file_name = os.path.split(
            self._make_sql_template_file(b'{{ ds }}'),
        )
        with DAG('test-dag', start_date=_TEST_START_DATE, template_searchpath=sql_file_dir):
            task = ClickHouseOperator(
                task_id='test_task', sql=[sql_file_name, 'some_string'],
            )
        task.resolve_template_files()
        # The plain string has no template extension, so it is kept as-is.
        self.assertEqual(task._sql, ['{{ ds }}', 'some_string'])
if __name__ == '__main__':
    # Running the file directly hands control to unittest's command-line
    # runner, which collects the TestCase classes defined in this module.
    unittest.main()