Skip to content

Commit

Permalink
add null-byte check middleware (#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
Panaetius authored Apr 8, 2024
1 parent 56c725b commit cd909af
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 12 deletions.
5 changes: 5 additions & 0 deletions bases/renku_data_services/data_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from renku_data_services.migrations.core import run_migrations_for_app
from renku_data_services.storage.rclone import RCloneValidator
from renku_data_services.utils.middleware import validate_null_byte


def create_app() -> Sanic:
Expand Down Expand Up @@ -65,6 +66,8 @@ async def setup_sentry(_):
app.signal("http.routing.after")(_set_transaction)

app = register_all_handlers(app, config)

# Setup prometheus
monitor(app, endpoint_type="url", multiprocess_mode="all", is_middleware=True).expose_endpoint()

if environ.get("CORS_ALLOW_ALL_ORIGINS", "false").lower() == "true":
Expand All @@ -73,6 +76,8 @@ async def setup_sentry(_):
app.config.CORS_ORIGINS = "*"
Extend(app)

app.register_middleware(validate_null_byte, "request")

@app.main_process_start
async def do_migrations(*_):
logger.info("running migrations")
Expand Down
4 changes: 0 additions & 4 deletions components/renku_data_services/crc/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ class ResourceClass(ResourcesCompareMixin):
tolerations: list[str] = field(default_factory=list)

def __post_init__(self):
if "\x00" in self.name:
raise ValidationError(message="'\x00' is not allowed in 'name' field.")
if len(self.name) > 40:
raise ValidationError(message="'name' cannot be longer than 40 characters.")
if self.default_storage > self.max_storage:
Expand Down Expand Up @@ -179,8 +177,6 @@ class ResourcePool:

def __post_init__(self):
"""Validate the resource pool after initialization."""
if "\x00" in self.name:
raise ValidationError(message="'\x00' is not allowed in 'name' field.")
if len(self.name) > 40:
raise ValidationError(message="'name' cannot be longer than 40 characters.")
if self.default and not self.public:
Expand Down
4 changes: 0 additions & 4 deletions components/renku_data_services/storage/rclone.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,10 +403,6 @@ def validate_config(self, configuration: Union["RCloneConfig", dict[str, Any]],
for key in keys:
value = configuration[key]

if isinstance(value, str) and "\x00" in value:
# validate strings for Postgresql compatibility
raise errors.ValidationError(message=f"Null byte found in value '{value}' for key '{key}'")

option: RCloneOption | None = self.get_option_for_provider(key, provider)

if option is None:
Expand Down
11 changes: 11 additions & 0 deletions components/renku_data_services/utils/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Custom Sanic Middleware."""

from sanic import Request

from renku_data_services import errors


async def validate_null_byte(request: Request) -> None:
    """Reject any request whose raw body contains an escaped null character.

    The body is scanned for the six-byte ASCII escape sequence that a JSON
    document uses for U+0000 (a backslash followed by ``u0000``). A literal
    0x00 byte in the body is not matched by this check.

    Raises:
        errors.ValidationError: If the escape sequence occurs anywhere in the
            request body.
    """
    # The backslash is escaped explicitly: "\u" is not a recognised escape in a
    # bytes literal, so the unescaped spelling only matched by accident and
    # emits an invalid-escape-sequence warning on current Python versions.
    if b"\\u0000" in request.body:
        raise errors.ValidationError(message="Null byte found in request")
2 changes: 2 additions & 0 deletions test/bases/renku_data_services/data_api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sanic import Sanic
from sanic_testing.testing import SanicASGITestClient

from components.renku_data_services.utils.middleware import validate_null_byte
from renku_data_services.app_config.config import Config
from renku_data_services.data_api.app import register_all_handlers
from renku_data_services.users.dummy_kc_api import DummyKeycloakAPI
Expand All @@ -15,6 +16,7 @@ async def sanic_client(app_config: Config, users: list[UserInfo]) -> SanicASGITe
app_config.kc_api = DummyKeycloakAPI(users=get_kc_users(users))
app = Sanic(app_config.app_name)
app = register_all_handlers(app, app_config)
app.register_middleware(validate_null_byte, "request")
await app_config.kc_user_repo.initialize(app_config.kc_api)
await app_config.group_repo.generate_user_namespaces()
return SanicASGITestClient(app)
17 changes: 17 additions & 0 deletions test/bases/renku_data_services/data_api/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,20 @@ async def test_delete_project_members(create_project, sanic_client, user_headers
"last_name": "Doe",
"role": "owner",
}


@pytest.mark.asyncio
async def test_null_byte_middleware(sanic_client, user_headers, regular_user, app_config):
    """Posting a project whose name holds a null character must yield a 422 with the null-byte message."""
    owner_namespace = f"{regular_user.first_name}.{regular_user.last_name}"
    repository_urls = ["http://renkulab.io/repository-1", "http://renkulab.io/repository-2"]

    # The null character sits in the middle of the project name.
    project_body = dict(
        name="Renku Native \x00Project",
        slug="project-slug",
        description="First Renku native project",
        visibility="public",
        repositories=repository_urls,
        namespace=owner_namespace,
    )

    _, response = await sanic_client.post(
        "/api/data/projects",
        headers=user_headers,
        json=project_body,
    )

    assert response.status_code == 422, response.text
    assert "Null byte found in request" in response.text
3 changes: 1 addition & 2 deletions test/components/renku_data_services/crc_models/hypothesis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from hypothesis import assume
from hypothesis import strategies as st

Expand All @@ -24,7 +23,7 @@ def make_cpu_float(data) -> dict[str, int | float]:
a_quota_storage = st.integers(min_value=2000, max_value=10000)
a_quota_memory = st.integers(min_value=64, max_value=1000)
a_row_id = st.integers(min_value=1, max_value=SQL_BIGINT_MAX)
a_name = st.text(min_size=5)
a_name = st.text(min_size=5, alphabet=st.characters(codec="utf-8", exclude_characters=["\x00"]))
a_uuid_string = st.uuids(version=4).map(lambda x: str(x))
a_bool = st.booleans()
a_tolerations_list = st.lists(a_uuid_string, min_size=3, max_size=3)
Expand Down
16 changes: 14 additions & 2 deletions test/components/renku_data_services/storage_models/hypothesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,26 @@ def s3_configuration(draw):
keys=st.just("provider"), values=st.sampled_from(["Other", "AWS", "GCS"]), min_size=0, max_size=1
)
)
region = draw(st.dictionaries(keys=st.just("region"), values=st.text(), min_size=0, max_size=1))
region = draw(
st.dictionaries(
keys=st.just("region"),
values=st.text(alphabet=st.characters(codec="utf-8", exclude_characters=["\x00"])),
min_size=0,
max_size=1,
)
)
endpoint = draw(st.dictionaries(keys=st.just("endpoint"), values=urls(), min_size=0))
return {"type": "s3", **providers, **region, **endpoint}


@st.composite
def azure_configuration(draw):
account = draw(st.dictionaries(keys=st.just("account"), values=st.text(min_size=5)))
account = draw(
st.dictionaries(
keys=st.just("account"),
values=st.text(min_size=5, alphabet=st.characters(codec="utf-8", exclude_characters=["\x00"])),
)
)
endpoint = draw(st.dictionaries(keys=st.just("endpoint"), values=urls(), min_size=0))
return {"type": "azureblob", **account, **endpoint}

Expand Down

0 comments on commit cd909af

Please sign in to comment.