diff --git a/src/utils.rs b/src/utils.rs index b30050c..845c108 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -115,12 +115,18 @@ impl IntelObject { if resp.status().is_success() { if let Some(content_length) = resp.content_length() { if content_length <= below_size_kb * 1024 { + let content_type = resp.headers().get(header::CONTENT_TYPE).cloned(); let text = resp.text().await?; let text = f(text); - return Ok(HttpResponse::Ok() - .content_type(ContentType::octet_stream()) - .body(text) - .into()); + + return Ok(if let Some(content_type) = content_type { + HttpResponse::Ok() + .content_type(content_type) + .body(text) + .into() + } else { + HttpResponse::Ok().body(text).into() + }); } } } @@ -384,6 +390,46 @@ mod tests { .rewrite_upstream(&mission, 1, |s| s.replace("ipsu", "lore"), &config) .await .unwrap(); + match obj { + IntelResponse::Redirect(_) => panic!("must be response"), + IntelResponse::Response(resp) => { + let body = to_bytes(resp.into_body()).await.unwrap(); + let text = std::str::from_utf8(&*body).unwrap(); + assert_eq!(text, "lorem lorem", "must be rewritten"); + } + } + }) + .await; + } + + #[tokio::test] + async fn must_rewrite_upstream_respect_content_type() { + with_mock(|server, config, mission, _| async move { + let _mock = server + .mock_async(|when, then| { + when.method(Method::GET).path("/with_content_type"); + then.status(200) + .header("content-type", "text/plain") + .body("lorem ipsum"); + }) + .await; + let _mock2 = server + .mock_async(|when, then| { + when.method(Method::GET).path("/without_content_type"); + then.status(200).body("lorem ipsum"); + }) + .await; + let task = Task { + storage: "storage", + origin: server.base_url(), + path: "with_content_type".to_string(), + retry_limit: 0, + }; + let obj = task + .resolve_upstream() + .rewrite_upstream(&mission, 1, |s| s, &config) + .await + .unwrap(); match obj { IntelResponse::Redirect(_) => panic!("must be response"), IntelResponse::Response(resp) => { @@ -393,12 +439,26 @@ mod tests { .unwrap() .to_str() .unwrap(), - "application/octet-stream", - "must be octet-stream" + "text/plain", + "must be text/plain" ); - let body = to_bytes(resp.into_body()).await.unwrap(); - let text = std::str::from_utf8(&*body).unwrap(); - assert_eq!(text, "lorem lorem", "must be rewritten"); + } + } + let task2 = Task { + storage: "storage", + origin: server.base_url(), + path: "without_content_type".to_string(), + retry_limit: 0, + }; + let obj2 = task2 + .resolve_upstream() + .rewrite_upstream(&mission, 1, |s| s, &config) + .await + .unwrap(); + match obj2 { + IntelResponse::Redirect(_) => panic!("must be response"), + IntelResponse::Response(resp) => { + assert!(resp.headers().get("content-type").is_none(), "must be none"); } } })