From 38d5f840ac9a6cc536c1f4d7615126e4a7c4be49 Mon Sep 17 00:00:00 2001 From: Steven Loria Date: Mon, 9 Jan 2017 01:38:38 -0500 Subject: [PATCH] Add url parser close #6 --- environs.py | 16 ++++++++++++++++ tests/test_environs.py | 36 +++++++++++++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/environs.py b/environs.py index 4a15494..3d3b056 100644 --- a/environs.py +++ b/environs.py @@ -5,6 +5,11 @@ import json as pyjson import os import re +try: + import urllib.parse as urlparse +except ImportError: + # Python 2 + import urlparse import marshmallow as ma from read_env import read_env as _read_env @@ -94,6 +99,16 @@ def _preprocess_dict(value, **kwargs): def _preprocess_json(value, **kwargs): return pyjson.loads(value) +class URLField(ma.fields.URL): + def _serialize(self, value, attr, obj): + return value.geturl() + + # Override deserialize rather than _deserialize because we need + # to call urlparse *after* validation has occurred + def deserialize(self, value, attr=None, data=None): + ret = super(URLField, self).deserialize(value, attr, data) + return urlparse.urlparse(ret) + class Env(object): """An environment variable reader.""" __call__ = _field2method(ma.fields.Field, '__call__') @@ -115,6 +130,7 @@ def __init__(self): date=_field2method(ma.fields.Date, 'date'), timedelta=_field2method(ma.fields.TimeDelta, 'timedelta'), uuid=_field2method(ma.fields.UUID, 'uuid'), + url=_field2method(URLField, 'url'), ) def __repr__(self): diff --git a/tests/test_environs.py b/tests/test_environs.py index 8f2cb29..712b585 100644 --- a/tests/test_environs.py +++ b/tests/test_environs.py @@ -3,6 +3,11 @@ import uuid from decimal import Decimal import datetime as dt +try: + import urllib.parse as urlparse +except ImportError: + # Python 2 + import urlparse import pytest from marshmallow import fields, validate @@ -17,7 +22,7 @@ def _set_env(envvars): return _set_env -@pytest.fixture +@pytest.fixture(scope='function') def env(): return environs.Env() @@ -44,7 +49,7 @@ def test_int_cast(self, set_env, env): def test_invalid_int(self, set_env, env): set_env({'INT': 'invalid'}) with pytest.raises(environs.EnvError) as excinfo: - env.int('INT') == 42 + env.int('INT') assert 'Environment variable "INT" invalid' in excinfo.value.args[0] def test_float_cast(self, set_env, env): @@ -122,6 +127,23 @@ def test_uuid_cast(self, set_env, env): set_env({'UUID': str(uid)}) assert env.uuid('UUID') == uid + def test_url_cast(self, set_env, env): + set_env({'URL': 'http://stevenloria.com/projects/?foo=42'}) + res = env.url('URL') + assert isinstance(res, urlparse.ParseResult) + + @pytest.mark.parametrize('url', + [ + 'foo', + '42', + 'foo@bar', + ]) + def test_invalid_url(self, url, set_env, env): + set_env({'URL': url}) + with pytest.raises(environs.EnvError) as excinfo: + env.url('URL') + assert 'Environment variable "URL" invalid' in excinfo.value.args[0] + class TestProxiedVariables: @@ -248,17 +270,25 @@ def _deserialize(self, value, *args, **kwargs): class TestDumping: def test_dump(self, set_env, env): dtime = dt.datetime.utcnow() - set_env({'STR': 'foo', 'INT': '42', 'DTIME': dtime.isoformat()}) + set_env({ + 'STR': 'foo', + 'INT': '42', + 'DTIME': dtime.isoformat(), + 'URLPARSE': 'http://stevenloria.com/projects/?foo=42', + }) env.str('STR') env.int('INT') env.datetime('DTIME') + env.url('URLPARSE') result = env.dump() assert result['STR'] == 'foo' assert result['INT'] == 42 assert 'DTIME' in result assert type(result['DTIME']) is str + assert type(result['URLPARSE']) is str + assert result['URLPARSE'] == 'http://stevenloria.com/projects/?foo=42' def test_env_with_custom_parser(self, set_env, env): @env.parser_for('url')