Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/download/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,14 +538,14 @@ mod curl {
})?;
}

// If we didn't get a 20x or 0 ("OK" for files) then return an error
// If we didn't get a 20x or 0 ("OK" for files) then return an error.
// If resuming a download, we need a 206, as a 200 would mean the server ignored
// the range header, resulting in corruption.
let code = handle.response_code()?;
match code {
0 | 200..=299 => {}
_ => {
return Err(DownloadError::HttpStatus(code).into());
}
};
match (resume_from > 0, code) {
(_, 0) | (true, 206) | (false, 200..=299) => {}
_ => return Err(DownloadError::HttpStatus(code).into()),
}

Ok(())
})
Expand Down Expand Up @@ -587,9 +587,13 @@ mod reqwest_be {
.await
.context("error downloading file")?;

if !res.status().is_success() {
let code: u16 = res.status().into();
return Err(anyhow!(DownloadError::HttpStatus(u32::from(code))));
// If a download is being resumed, we expect a 206 response;
// otherwise, if the server ignored the range header,
// an error is thrown preemptively to avoid corruption.
let status = res.status().into();
match (resume_from > 0, status) {
(true, 206) | (false, 200..=299) => {}
_ => return Err(DownloadError::HttpStatus(u32::from(status)).into()),
}

if let Some(len) = res.content_length() {
Expand Down
48 changes: 41 additions & 7 deletions src/download/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ mod curl {
let target_path = tmpdir.path().join("downloaded");
write_file(&target_path, "123");

let addr = serve_file(b"xxx45".to_vec());
let addr = serve_file(b"xxx45".to_vec(), true);

let from_url = format!("http://{addr}").parse().unwrap();

Expand Down Expand Up @@ -213,7 +213,7 @@ mod reqwest {
let target_path = tmpdir.path().join("downloaded");
write_file(&target_path, "123");

let addr = serve_file(b"xxx45".to_vec());
let addr = serve_file(b"xxx45".to_vec(), true);

let from_url = format!("http://{addr}").parse().unwrap();

Expand Down Expand Up @@ -258,6 +258,33 @@ mod reqwest {
assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345");
}

#[tokio::test]
async fn resume_partial_fails_if_server_ignores_range() {
let _guard = scrub_env().await;
let tmpdir = tmp_dir();
let target_path = tmpdir.path().join("downloaded");
write_file(&target_path, "123");

let addr = serve_file(b"xxx45".to_vec(), false);
let from_url = format!("http://{addr}").parse().unwrap();

Backend::Reqwest(TlsBackend::NativeTls)
.download_to_path(
&from_url,
&target_path,
true,
None,
Duration::from_secs(180),
)
.await
.expect_err("download should fail if server ignores range");

assert!(
!target_path.exists(),
"partial file should have been deleted"
);
}

#[tokio::test]
async fn network_failure_does_not_delete_partial_file() {
let _guard = scrub_env().await;
Expand Down Expand Up @@ -299,11 +326,16 @@ pub fn write_file(path: &Path, contents: &str) {
// A dead simple hyper server implementation.
// For more info, see:
// https://hyper.rs/guides/1/server/hello-world/
async fn run_server(addr_tx: Sender<SocketAddr>, addr: SocketAddr, contents: Vec<u8>) {
async fn run_server(
addr_tx: Sender<SocketAddr>,
addr: SocketAddr,
contents: Vec<u8>,
honor_range: bool,
) {
let svc = service_fn(move |req: Request<hyper::body::Incoming>| {
let contents = contents.clone();
async move {
let res = serve_contents(req, contents);
let res = serve_contents(req, contents, honor_range);
Ok::<_, Infallible>(res)
}
});
Expand Down Expand Up @@ -331,12 +363,12 @@ async fn run_server(addr_tx: Sender<SocketAddr>, addr: SocketAddr, contents: Vec
}
}

pub fn serve_file(contents: Vec<u8>) -> SocketAddr {
pub fn serve_file(contents: Vec<u8>, honor_range: bool) -> SocketAddr {
let addr = ([127, 0, 0, 1], 0).into();
let (addr_tx, addr_rx) = channel();

thread::spawn(move || {
let server = run_server(addr_tx, addr, contents);
let server = run_server(addr_tx, addr, contents, honor_range);
let rt = tokio::runtime::Runtime::new().expect("could not creating Runtime");
rt.block_on(server);
});
Expand All @@ -348,9 +380,11 @@ pub fn serve_file(contents: Vec<u8>) -> SocketAddr {
fn serve_contents(
req: Request<hyper::body::Incoming>,
contents: Vec<u8>,
honor_range: bool,
) -> hyper::Response<Full<Bytes>> {
let mut range_header = None;
let (status, body) = if let Some(range) = req.headers().get(hyper::header::RANGE) {
let (status, body) = if honor_range && let Some(range) = req.headers().get(hyper::header::RANGE)
{
// extract range "bytes={start}-"
let range = range.to_str().expect("unexpected Range header");
assert!(range.starts_with("bytes="));
Expand Down