From 09355686ddb49e94bb8ffd83cfb84fb18dc73903 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Wed, 18 Dec 2024 10:59:34 +0100 Subject: [PATCH 1/6] Simplify logic for download backend notification --- src/utils/mod.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/utils/mod.rs b/src/utils/mod.rs index e1e144cfa2..bbcfae035d 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -252,8 +252,8 @@ async fn download_file_( let use_rustls = process .var_os("RUSTUP_USE_RUSTLS") .is_none_or(|it| it != "0"); - let (backend, notification) = if use_curl_backend { - (Backend::Curl, Notification::UsingCurl) + let backend = if use_curl_backend { + Backend::Curl } else { let tls_backend = if use_rustls { TlsBackend::Rustls @@ -267,9 +267,14 @@ async fn download_file_( TlsBackend::Rustls } }; - (Backend::Reqwest(tls_backend), Notification::UsingReqwest) + Backend::Reqwest(tls_backend) }; - notify_handler(notification); + + notify_handler(match backend { + Backend::Curl => Notification::UsingCurl, + Backend::Reqwest(_) => Notification::UsingReqwest, + }); + let res = download_to_path_with_backend(backend, url, path, resume_from_partial, Some(callback)) .await; From 45ab1fd8d33f115f2f98a1536b9e56398798deff Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Wed, 18 Dec 2024 11:07:00 +0100 Subject: [PATCH 2/6] Implement more complete backend selection --- src/utils/mod.rs | 78 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 56 insertions(+), 22 deletions(-) diff --git a/src/utils/mod.rs b/src/utils/mod.rs index bbcfae035d..0b30019965 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -12,6 +12,8 @@ use anyhow::{anyhow, bail, Context, Result}; use retry::delay::{jitter, Fibonacci}; use retry::{retry, OperationResult}; use sha2::Sha256; +#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] +use tracing::info; use url::Url; use crate::errors::*; @@ -210,8 +212,9 @@ async fn download_file_( notify_handler: &dyn Fn(Notification<'_>), process: &Process, ) -> Result<()> { - use download::download_to_path_with_backend; - use download::{Backend, Event, TlsBackend}; + #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] + use download::TlsBackend; + use download::{download_to_path_with_backend, Backend, Event}; use sha2::Digest; use std::cell::RefCell; @@ -246,28 +249,59 @@ async fn download_file_( // Download the file // Keep the curl env var around for a bit - let use_curl_backend = process - .var_os("RUSTUP_USE_CURL") - .is_some_and(|it| it != "0"); - let use_rustls = process - .var_os("RUSTUP_USE_RUSTLS") - .is_none_or(|it| it != "0"); - let backend = if use_curl_backend { - Backend::Curl - } else { - let tls_backend = if use_rustls { - TlsBackend::Rustls - } else { - #[cfg(feature = "reqwest-native-tls")] - { - TlsBackend::NativeTls + let use_curl_backend = process.var_os("RUSTUP_USE_CURL").map(|it| it != "0"); + let use_rustls = process.var_os("RUSTUP_USE_RUSTLS").map(|it| it != "0"); + + let backend = match (use_curl_backend, use_rustls) { + // If environment specifies a backend that's unavailable, error out + #[cfg(not(feature = "reqwest-rustls-tls"))] + (_, Some(true)) => { + return Err(anyhow!( + "RUSTUP_USE_RUSTLS is set, but this rustup distribution was not built with the reqwest-rustls-tls feature" + )); + } + #[cfg(not(feature = "reqwest-native-tls"))] + (_, Some(false)) => { + return Err(anyhow!( + "RUSTUP_USE_RUSTLS is set to false, but this rustup distribution was not built with the reqwest-native-tls feature" + )); + } + #[cfg(not(feature = "curl-backend"))] + (Some(true), _) => { + return Err(anyhow!( + "RUSTUP_USE_CURL is set, but this rustup distribution was not built with the curl-backend feature" + )); + } + + // Positive selections, from least preferred to most preferred + #[cfg(feature = "curl-backend")] + (Some(true), None) => Backend::Curl, + #[cfg(feature = "reqwest-native-tls")] + (_, Some(false)) => { + if use_curl_backend == Some(true) { + info!("RUSTUP_USE_CURL is set and RUSTUP_USE_RUSTLS is set to off, using reqwest with native-tls"); } - #[cfg(not(feature = "reqwest-native-tls"))] - { - TlsBackend::Rustls + Backend::Reqwest(TlsBackend::NativeTls) + } + #[cfg(feature = "reqwest-rustls-tls")] + _ => { + if use_curl_backend == Some(true) { + info!( + "both RUSTUP_USE_CURL and RUSTUP_USE_RUSTLS are set, using reqwest with rustls" + ); } - }; - Backend::Reqwest(tls_backend) + Backend::Reqwest(TlsBackend::Rustls) + } + + // Falling back if only one backend is available + #[cfg(all(not(feature = "reqwest-rustls-tls"), feature = "reqwest-native-tls"))] + _ => Backend::Reqwest(TlsBackend::NativeTls), + #[cfg(all( + not(feature = "reqwest-rustls-tls"), + not(feature = "reqwest-native-tls"), + feature = "curl-backend" + ))] + _ => Backend::Curl, }; notify_handler(match backend { From 02a94c4cb8db386c852ca95a85c127bf893aaa0e Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 20 Dec 2024 15:04:21 +0100 Subject: [PATCH 3/6] download: remove intermediate reqwest-backend feature --- Cargo.toml | 10 ++-------- ci/run.bash | 4 ++-- download/Cargo.toml | 7 ++++--- download/src/lib.rs | 15 +++++++++++---- download/tests/download-reqwest-resume.rs | 2 +- download/tests/read-proxy-env.rs | 2 +- 6 files changed, 21 insertions(+), 19 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7b45f6c870..87d2a1a211 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,14 +12,8 @@ build = "build.rs" [features] curl-backend = ["download/curl-backend"] -default = [ - "curl-backend", - "reqwest-backend", - "reqwest-native-tls", - "reqwest-rustls-tls", -] +default = ["curl-backend", "reqwest-native-tls", "reqwest-rustls-tls"] -reqwest-backend = ["download/reqwest-backend"] vendored-openssl = ['openssl/vendored'] reqwest-native-tls = ["download/reqwest-native-tls"] @@ -183,4 +177,4 @@ opt-level = 0 [package.metadata.cargo-all-features] # Building with no web backend will error. -always_include_features = ["reqwest-backend", "reqwest-rustls-tls"] +always_include_features = ["reqwest-rustls-tls"] diff --git a/ci/run.bash b/ci/run.bash index 58eb064471..f0193d3448 100644 --- a/ci/run.bash +++ b/ci/run.bash @@ -8,7 +8,7 @@ rustc -vV cargo -vV -FEATURES=('--no-default-features' '--features' 'curl-backend,reqwest-backend,reqwest-native-tls') +FEATURES=('--no-default-features' '--features' 'curl-backend,reqwest-native-tls') case "$(uname -s)" in *NT* ) ;; # Windows NT * ) FEATURES+=('--features' 'vendored-openssl') ;; @@ -38,7 +38,7 @@ target_cargo() { target_cargo build download_pkg_test() { - features=('--no-default-features' '--features' 'curl-backend,reqwest-backend,reqwest-native-tls') + features=('--no-default-features' '--features' 'curl-backend,reqwest-native-tls') case "$TARGET" in # these platforms aren't supported by ring: powerpc* ) ;; diff --git a/download/Cargo.toml b/download/Cargo.toml index f7f3374cf2..ced09b067a 100644 --- a/download/Cargo.toml +++ b/download/Cargo.toml @@ -5,12 +5,13 @@ edition.workspace = true license.workspace = true [features] -default = ["reqwest-backend", "reqwest-rustls-tls", "reqwest-native-tls"] +default = ["reqwest-rustls-tls", "reqwest-native-tls"] curl-backend = ["curl"] -reqwest-backend = ["reqwest", "env_proxy"] -reqwest-native-tls = ["reqwest/native-tls"] +reqwest-native-tls = ["reqwest/native-tls", "dep:reqwest", "dep:env_proxy"] reqwest-rustls-tls = [ "reqwest/rustls-tls-manual-roots-no-provider", + "dep:env_proxy", + "dep:reqwest", "dep:rustls", "dep:rustls-platform-verifier", ] diff --git a/download/src/lib.rs b/download/src/lib.rs index 3a74157128..6e72da2172 100644 --- a/download/src/lib.rs +++ b/download/src/lib.rs @@ -167,7 +167,11 @@ pub async fn download_to_path_with_backend_( Ok::<(), anyhow::Error>(()) } -#[cfg(all(not(feature = "reqwest-backend"), not(feature = "curl-backend")))] +#[cfg(all( + not(feature = "reqwest-rustls-tls"), + not(feature = "reqwest-native-tls"), + not(feature = "curl-backend") +))] compile_error!("Must enable at least one backend"); /// Download via libcurl; encrypt with the native (or OpenSSl) TLS @@ -284,7 +288,7 @@ pub mod curl { } } -#[cfg(feature = "reqwest-backend")] +#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] pub mod reqwest_be { #[cfg(all( not(feature = "reqwest-rustls-tls"), @@ -480,7 +484,7 @@ pub enum DownloadError { Message(String), #[error(transparent)] IoError(#[from] std::io::Error), - #[cfg(feature = "reqwest-backend")] + #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] #[error(transparent)] Reqwest(#[from] ::reqwest::Error), #[cfg(feature = "curl-backend")] @@ -504,7 +508,10 @@ pub mod curl { } } -#[cfg(not(feature = "reqwest-backend"))] +#[cfg(all( + not(feature = "reqwest-rustls-tls"), + not(feature = "reqwest-native-tls") +))] pub mod reqwest_be { use anyhow::{anyhow, Result}; use url::Url; diff --git a/download/tests/download-reqwest-resume.rs b/download/tests/download-reqwest-resume.rs index dec87a8ce1..5cd6ba69b7 100644 --- a/download/tests/download-reqwest-resume.rs +++ b/download/tests/download-reqwest-resume.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "reqwest-backend")] +#![cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Mutex; diff --git a/download/tests/read-proxy-env.rs b/download/tests/read-proxy-env.rs index fbe284fef4..0b92155047 100644 --- a/download/tests/read-proxy-env.rs +++ b/download/tests/read-proxy-env.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "reqwest-backend")] +#![cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] use std::env::{remove_var, set_var}; use std::error::Error; From 5f1009831ae94f6cba42e3e908cde000a05a2d02 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 20 Dec 2024 15:10:10 +0100 Subject: [PATCH 4/6] download: attach download functions to Backend type --- download/src/lib.rs | 241 +++++++++++----------- download/tests/download-curl-resume.rs | 53 ++--- download/tests/download-reqwest-resume.rs | 63 +++--- src/utils/mod.rs | 8 +- 4 files changed, 182 insertions(+), 183 deletions(-) diff --git a/download/src/lib.rs b/download/src/lib.rs index 6e72da2172..01080eb087 100644 --- a/download/src/lib.rs +++ b/download/src/lib.rs @@ -30,143 +30,146 @@ pub enum Backend { Reqwest(TlsBackend), } -#[derive(Debug, Copy, Clone)] -pub enum TlsBackend { - Rustls, - NativeTls, -} - -#[derive(Debug, Copy, Clone)] -pub enum Event<'a> { - ResumingPartialDownload, - /// Received the Content-Length of the to-be downloaded data. - DownloadContentLengthReceived(u64), - /// Received some data. - DownloadDataReceived(&'a [u8]), -} - -type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> Result<()>; +impl Backend { + pub async fn download_to_path( + self, + url: &Url, + path: &Path, + resume_from_partial: bool, + callback: Option>, + ) -> Result<()> { + let Err(err) = self + .download_impl(url, path, resume_from_partial, callback) + .await + else { + return Ok(()); + }; -async fn download_with_backend( - backend: Backend, - url: &Url, - resume_from: u64, - callback: DownloadCallback<'_>, -) -> Result<()> { - match backend { - Backend::Curl => curl::download(url, resume_from, callback), - Backend::Reqwest(tls) => reqwest_be::download(url, resume_from, callback, tls).await, + // TODO: We currently clear up the cached download on any error, should we restrict it to a subset? + Err( + if let Err(file_err) = remove_file(path).context("cleaning up cached downloads") { + file_err.context(err) + } else { + err + }, + ) } -} - -pub async fn download_to_path_with_backend( - backend: Backend, - url: &Url, - path: &Path, - resume_from_partial: bool, - callback: Option>, -) -> Result<()> { - let Err(err) = - download_to_path_with_backend_(backend, url, path, resume_from_partial, callback).await - else { - return Ok(()); - }; - - // TODO: We currently clear up the cached download on any error, should we restrict it to a subset? - Err( - if let Err(file_err) = remove_file(path).context("cleaning up cached downloads") { - file_err.context(err) - } else { - err - }, - ) -} -pub async fn download_to_path_with_backend_( - backend: Backend, - url: &Url, - path: &Path, - resume_from_partial: bool, - callback: Option>, -) -> Result<()> { - use std::cell::RefCell; - use std::fs::OpenOptions; - use std::io::{Read, Seek, SeekFrom, Write}; - - let (file, resume_from) = if resume_from_partial { - // TODO: blocking call - let possible_partial = OpenOptions::new().read(true).open(path); - - let downloaded_so_far = if let Ok(mut partial) = possible_partial { - if let Some(cb) = callback { - cb(Event::ResumingPartialDownload)?; - - let mut buf = vec![0; 32768]; - let mut downloaded_so_far = 0; - loop { - let n = partial.read(&mut buf)?; - downloaded_so_far += n as u64; - if n == 0 { - break; + async fn download_impl( + self, + url: &Url, + path: &Path, + resume_from_partial: bool, + callback: Option>, + ) -> Result<()> { + use std::cell::RefCell; + use std::fs::OpenOptions; + use std::io::{Read, Seek, SeekFrom, Write}; + + let (file, resume_from) = if resume_from_partial { + // TODO: blocking call + let possible_partial = OpenOptions::new().read(true).open(path); + + let downloaded_so_far = if let Ok(mut partial) = possible_partial { + if let Some(cb) = callback { + cb(Event::ResumingPartialDownload)?; + + let mut buf = vec![0; 32768]; + let mut downloaded_so_far = 0; + loop { + let n = partial.read(&mut buf)?; + downloaded_so_far += n as u64; + if n == 0 { + break; + } + cb(Event::DownloadDataReceived(&buf[..n]))?; } - cb(Event::DownloadDataReceived(&buf[..n]))?; - } - downloaded_so_far + downloaded_so_far + } else { + let file_info = partial.metadata()?; + file_info.len() + } } else { - let file_info = partial.metadata()?; - file_info.len() - } + 0 + }; + + // TODO: blocking call + let mut possible_partial = OpenOptions::new() + .write(true) + .create(true) + .truncate(false) + .open(path) + .context("error opening file for download")?; + + possible_partial.seek(SeekFrom::End(0))?; + + (possible_partial, downloaded_so_far) } else { - 0 + ( + OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(path) + .context("error creating file for download")?, + 0, + ) }; - // TODO: blocking call - let mut possible_partial = OpenOptions::new() - .write(true) - .create(true) - .truncate(false) - .open(path) - .context("error opening file for download")?; + let file = RefCell::new(file); - possible_partial.seek(SeekFrom::End(0))?; + // TODO: the sync callback will stall the async runtime if IO calls block, which is OS dependent. Rearrange. + self.download(url, resume_from, &|event| { + if let Event::DownloadDataReceived(data) = event { + file.borrow_mut() + .write_all(data) + .context("unable to write download to disk")?; + } + match callback { + Some(cb) => cb(event), + None => Ok(()), + } + }) + .await?; - (possible_partial, downloaded_so_far) - } else { - ( - OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open(path) - .context("error creating file for download")?, - 0, - ) - }; + file.borrow_mut() + .sync_data() + .context("unable to sync download to disk")?; - let file = RefCell::new(file); + Ok::<(), anyhow::Error>(()) + } - // TODO: the sync callback will stall the async runtime if IO calls block, which is OS dependent. Rearrange. - download_with_backend(backend, url, resume_from, &|event| { - if let Event::DownloadDataReceived(data) = event { - file.borrow_mut() - .write_all(data) - .context("unable to write download to disk")?; - } - match callback { - Some(cb) => cb(event), - None => Ok(()), + async fn download( + self, + url: &Url, + resume_from: u64, + callback: DownloadCallback<'_>, + ) -> Result<()> { + match self { + Self::Curl => curl::download(url, resume_from, callback), + Self::Reqwest(tls) => reqwest_be::download(url, resume_from, callback, tls).await, } - }) - .await?; + } +} - file.borrow_mut() - .sync_data() - .context("unable to sync download to disk")?; +#[derive(Debug, Copy, Clone)] +pub enum TlsBackend { + Rustls, + NativeTls, +} - Ok::<(), anyhow::Error>(()) +#[derive(Debug, Copy, Clone)] +pub enum Event<'a> { + ResumingPartialDownload, + /// Received the Content-Length of the to-be downloaded data. + DownloadContentLengthReceived(u64), + /// Received some data. + DownloadDataReceived(&'a [u8]), } +type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> Result<()>; + #[cfg(all( not(feature = "reqwest-rustls-tls"), not(feature = "reqwest-native-tls"), diff --git a/download/tests/download-curl-resume.rs b/download/tests/download-curl-resume.rs index 85cdffd367..c31e94323c 100644 --- a/download/tests/download-curl-resume.rs +++ b/download/tests/download-curl-resume.rs @@ -20,7 +20,8 @@ async fn partially_downloaded_file_gets_resumed_from_byte_offset() { write_file(&target_path, "123"); let from_url = Url::from_file_path(&from_path).unwrap(); - download_to_path_with_backend(Backend::Curl, &from_url, &target_path, true, None) + Backend::Curl + .download_to_path(&from_url, &target_path, true, None) .await .expect("Test download failed"); @@ -41,34 +42,34 @@ async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() { let callback_len = Mutex::new(None); let received_in_callback = Mutex::new(Vec::new()); - download_to_path_with_backend( - Backend::Curl, - &from_url, - &target_path, - true, - Some(&|msg| { - match msg { - Event::ResumingPartialDownload => { - assert!(!callback_partial.load(Ordering::SeqCst)); - callback_partial.store(true, Ordering::SeqCst); - } - Event::DownloadContentLengthReceived(len) => { - let mut flag = callback_len.lock().unwrap(); - assert!(flag.is_none()); - *flag = Some(len); - } - Event::DownloadDataReceived(data) => { - for b in data.iter() { - received_in_callback.lock().unwrap().push(*b); + Backend::Curl + .download_to_path( + &from_url, + &target_path, + true, + Some(&|msg| { + match msg { + Event::ResumingPartialDownload => { + assert!(!callback_partial.load(Ordering::SeqCst)); + callback_partial.store(true, Ordering::SeqCst); + } + Event::DownloadContentLengthReceived(len) => { + let mut flag = callback_len.lock().unwrap(); + assert!(flag.is_none()); + *flag = Some(len); + } + Event::DownloadDataReceived(data) => { + for b in data.iter() { + received_in_callback.lock().unwrap().push(*b); + } } } - } - Ok(()) - }), - ) - .await - .expect("Test download failed"); + Ok(()) + }), + ) + .await + .expect("Test download failed"); assert!(callback_partial.into_inner()); assert_eq!(*callback_len.lock().unwrap(), Some(5)); diff --git a/download/tests/download-reqwest-resume.rs b/download/tests/download-reqwest-resume.rs index 5cd6ba69b7..727ee93b81 100644 --- a/download/tests/download-reqwest-resume.rs +++ b/download/tests/download-reqwest-resume.rs @@ -20,15 +20,10 @@ async fn resume_partial_from_file_url() { write_file(&target_path, "123"); let from_url = Url::from_file_path(&from_path).unwrap(); - download_to_path_with_backend( - Backend::Reqwest(TlsBackend::NativeTls), - &from_url, - &target_path, - true, - None, - ) - .await - .expect("Test download failed"); + Backend::Reqwest(TlsBackend::NativeTls) + .download_to_path(&from_url, &target_path, true, None) + .await + .expect("Test download failed"); assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345"); } @@ -47,34 +42,34 @@ async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() { let callback_len = Mutex::new(None); let received_in_callback = Mutex::new(Vec::new()); - download_to_path_with_backend( - Backend::Reqwest(TlsBackend::NativeTls), - &from_url, - &target_path, - true, - Some(&|msg| { - match msg { - Event::ResumingPartialDownload => { - assert!(!callback_partial.load(Ordering::SeqCst)); - callback_partial.store(true, Ordering::SeqCst); - } - Event::DownloadContentLengthReceived(len) => { - let mut flag = callback_len.lock().unwrap(); - assert!(flag.is_none()); - *flag = Some(len); - } - Event::DownloadDataReceived(data) => { - for b in data.iter() { - received_in_callback.lock().unwrap().push(*b); + Backend::Reqwest(TlsBackend::NativeTls) + .download_to_path( + &from_url, + &target_path, + true, + Some(&|msg| { + match msg { + Event::ResumingPartialDownload => { + assert!(!callback_partial.load(Ordering::SeqCst)); + callback_partial.store(true, Ordering::SeqCst); + } + Event::DownloadContentLengthReceived(len) => { + let mut flag = callback_len.lock().unwrap(); + assert!(flag.is_none()); + *flag = Some(len); + } + Event::DownloadDataReceived(data) => { + for b in data.iter() { + received_in_callback.lock().unwrap().push(*b); + } } } - } - Ok(()) - }), - ) - .await - .expect("Test download failed"); + Ok(()) + }), + ) + .await + .expect("Test download failed"); assert!(callback_partial.into_inner()); assert_eq!(*callback_len.lock().unwrap(), Some(5)); diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 0b30019965..adb9e0b063 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -214,7 +214,7 @@ async fn download_file_( ) -> Result<()> { #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] use download::TlsBackend; - use download::{download_to_path_with_backend, Backend, Event}; + use download::{Backend, Event}; use sha2::Digest; use std::cell::RefCell; @@ -309,9 +309,9 @@ async fn download_file_( Backend::Reqwest(_) => Notification::UsingReqwest, }); - let res = - download_to_path_with_backend(backend, url, path, resume_from_partial, Some(callback)) - .await; + let res = backend + .download_to_path(url, path, resume_from_partial, Some(callback)) + .await; notify_handler(Notification::DownloadFinished); From 75d17e200a28345022acf6b1dce5558b703f6d40 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 20 Dec 2024 15:14:19 +0100 Subject: [PATCH 5/6] download: simplify feature guards --- download/src/lib.rs | 55 ++++++++++----------------------------------- src/utils/mod.rs | 2 ++ 2 files changed, 14 insertions(+), 43 deletions(-) diff --git a/download/src/lib.rs b/download/src/lib.rs index 01080eb087..aa0899d5f6 100644 --- a/download/src/lib.rs +++ b/download/src/lib.rs @@ -26,7 +26,9 @@ const REQWEST_RUSTLS_TLS_USER_AGENT: &str = #[derive(Debug, Copy, Clone)] pub enum Backend { + #[cfg(feature = "curl-backend")] Curl, + #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] Reqwest(TlsBackend), } @@ -140,6 +142,14 @@ impl Backend { Ok::<(), anyhow::Error>(()) } + #[cfg_attr( + all( + not(feature = "curl-backend"), + not(feature = "reqwest-rustls-tls"), + not(feature = "reqwest-native-tls") + ), + allow(unused_variables) + )] async fn download( self, url: &Url, @@ -147,7 +157,9 @@ impl Backend { callback: DownloadCallback<'_>, ) -> Result<()> { match self { + #[cfg(feature = "curl-backend")] Self::Curl => curl::download(url, resume_from, callback), + #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] Self::Reqwest(tls) => reqwest_be::download(url, resume_from, callback, tls).await, } } @@ -170,13 +182,6 @@ pub enum Event<'a> { type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> Result<()>; -#[cfg(all( - not(feature = "reqwest-rustls-tls"), - not(feature = "reqwest-native-tls"), - not(feature = "curl-backend") -))] -compile_error!("Must enable at least one backend"); - /// Download via libcurl; encrypt with the native (or OpenSSl) TLS /// stack via libcurl #[cfg(feature = "curl-backend")] @@ -494,39 +499,3 @@ pub enum DownloadError { #[error(transparent)] CurlError(#[from] ::curl::Error), } - -#[cfg(not(feature = "curl-backend"))] -pub mod curl { - use anyhow::{anyhow, Result}; - use url::Url; - - use super::{DownloadError, Event}; - - pub fn download( - _url: &Url, - _resume_from: u64, - _callback: &dyn Fn(Event<'_>) -> Result<()>, - ) -> Result<()> { - Err(anyhow!(DownloadError::BackendUnavailable("curl"))) - } -} - -#[cfg(all( - not(feature = "reqwest-rustls-tls"), - not(feature = "reqwest-native-tls") -))] -pub mod reqwest_be { - use anyhow::{anyhow, Result}; - use url::Url; - - use super::{DownloadError, Event, TlsBackend}; - - pub async fn download( - _url: &Url, - _resume_from: u64, - _callback: &dyn Fn(Event<'_>) -> Result<()>, - _tls: TlsBackend, - ) -> Result<()> { - Err(anyhow!(DownloadError::BackendUnavailable("reqwest"))) - } -} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index adb9e0b063..5f4eb1db7f 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -305,7 +305,9 @@ async fn download_file_( }; notify_handler(match backend { + #[cfg(feature = "curl-backend")] Backend::Curl => Notification::UsingCurl, + #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] Backend::Reqwest(_) => Notification::UsingReqwest, }); From 168574f997b78a7504838cb3d3b50a897cdc9aef Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 20 Dec 2024 15:25:43 +0100 Subject: [PATCH 6/6] download: clean up TLS feature guards --- download/src/lib.rs | 56 +++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/download/src/lib.rs b/download/src/lib.rs index aa0899d5f6..73faeaa6d9 100644 --- a/download/src/lib.rs +++ b/download/src/lib.rs @@ -160,17 +160,39 @@ impl Backend { #[cfg(feature = "curl-backend")] Self::Curl => curl::download(url, resume_from, callback), #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] - Self::Reqwest(tls) => reqwest_be::download(url, resume_from, callback, tls).await, + Self::Reqwest(tls) => tls.download(url, resume_from, callback).await, } } } +#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] #[derive(Debug, Copy, Clone)] pub enum TlsBackend { + #[cfg(feature = "reqwest-rustls-tls")] Rustls, + #[cfg(feature = "reqwest-native-tls")] NativeTls, } +#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] +impl TlsBackend { + async fn download( + self, + url: &Url, + resume_from: u64, + callback: DownloadCallback<'_>, + ) -> Result<()> { + let client = match self { + #[cfg(feature = "reqwest-rustls-tls")] + Self::Rustls => &reqwest_be::CLIENT_RUSTLS_TLS, + #[cfg(feature = "reqwest-native-tls")] + Self::NativeTls => &reqwest_be::CLIENT_NATIVE_TLS, + }; + + reqwest_be::download(url, resume_from, callback, client).await + } +} + #[derive(Debug, Copy, Clone)] pub enum Event<'a> { ResumingPartialDownload, @@ -298,12 +320,6 @@ pub mod curl { #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] pub mod reqwest_be { - #[cfg(all( - not(feature = "reqwest-rustls-tls"), - not(feature = "reqwest-native-tls") - ))] - compile_error!("Must select a reqwest TLS backend"); - use std::io; #[cfg(feature = "reqwest-rustls-tls")] use std::sync::Arc; @@ -320,20 +336,20 @@ pub mod reqwest_be { use tokio_stream::StreamExt; use url::Url; - use super::{DownloadError, Event, TlsBackend}; + use super::{DownloadError, Event}; pub async fn download( url: &Url, resume_from: u64, callback: &dyn Fn(Event<'_>) -> Result<()>, - tls: TlsBackend, + client: &Client, ) -> Result<()> { // Short-circuit reqwest for the "file:" URL scheme if download_from_file_url(url, resume_from, callback)? { return Ok(()); } - let res = request(url, resume_from, tls) + let res = request(url, resume_from, client) .await .context("failed to make network request")?; @@ -367,7 +383,7 @@ pub mod reqwest_be { } #[cfg(feature = "reqwest-rustls-tls")] - static CLIENT_RUSTLS_TLS: LazyLock = LazyLock::new(|| { + pub(super) static CLIENT_RUSTLS_TLS: LazyLock = LazyLock::new(|| { let catcher = || { client_generic() .use_preconfigured_tls( @@ -393,7 +409,7 @@ pub mod reqwest_be { }); #[cfg(feature = "reqwest-native-tls")] - static CLIENT_DEFAULT_TLS: LazyLock = LazyLock::new(|| { + pub(super) static CLIENT_NATIVE_TLS: LazyLock = LazyLock::new(|| { let catcher = || { client_generic() .user_agent(super::REQWEST_DEFAULT_TLS_USER_AGENT) @@ -416,22 +432,8 @@ pub mod reqwest_be { async fn request( url: &Url, resume_from: u64, - backend: TlsBackend, + client: &Client, ) -> Result { - let client: &Client = match backend { - #[cfg(feature = "reqwest-rustls-tls")] - TlsBackend::Rustls => &CLIENT_RUSTLS_TLS, - #[cfg(not(feature = "reqwest-rustls-tls"))] - TlsBackend::Rustls => { - return Err(DownloadError::BackendUnavailable("reqwest rustls")); - } - #[cfg(feature = "reqwest-native-tls")] - TlsBackend::NativeTls => &CLIENT_DEFAULT_TLS, - #[cfg(not(feature = "reqwest-native-tls"))] - TlsBackend::NativeTls => { - return Err(DownloadError::BackendUnavailable("reqwest default TLS")); - } - }; let mut req = client.get(url.as_str()); if resume_from != 0 {