diff --git a/Cargo.lock b/Cargo.lock index c284de056b..1c3134d195 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -543,6 +543,7 @@ dependencies = [ "tempfile", "thiserror", "tokio", + "tokio-stream", "url", ] @@ -1827,6 +1828,7 @@ dependencies = [ "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "winreg", ] @@ -1985,6 +1987,7 @@ dependencies = [ "thiserror", "threadpool", "tokio", + "tokio-retry", "toml", "tracing", "tracing-opentelemetry", @@ -2451,6 +2454,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-retry" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f" +dependencies = [ + "pin-project", + "rand", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.25.0" @@ -2876,6 +2890,19 @@ version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +[[package]] +name = "wasm-streams" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.69" diff --git a/Cargo.toml b/Cargo.toml index 6c722f3824..ed2b7d6e0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,8 +35,6 @@ otel = [ "dep:tracing-subscriber", "dep:opentelemetry", "dep:opentelemetry_sdk", - "dep:tokio", - "dep:tracing", ] # Exports code dependent on private interfaces for the integration test suite @@ -85,11 +83,12 @@ tempfile.workspace = true termcolor.workspace = true thiserror.workspace = true threadpool = "1" -tokio = { workspace = true, optional = true } +tokio-retry.workspace = true +tokio.workspace = true toml = "0.8" tracing-opentelemetry = { workspace = true, optional = true } tracing-subscriber = { workspace = true, optional = true, features = ["env-filter"] } -tracing = { workspace = true, optional = true } +tracing.workspace = true url.workspace = true wait-timeout = "0.2" walkdir = { workspace = true, optional = true } @@ -145,7 +144,9 @@ rustup-macros = { path = "rustup-macros" } tempfile = "3.8" termcolor = "1.2" thiserror = "1.0" -tokio = { version = "1.26.0", default-features = false, features = ["rt-multi-thread"] } +tokio = { version = "1.26.0", default-features = false, features = ["macros", "rt-multi-thread"] } +tokio-retry = { version = "0.3.0" } +tokio-stream = { version = "0.1.14" } tracing = "0.1" tracing-opentelemetry = "0.24" tracing-subscriber = "0.3.16" diff --git a/download/Cargo.toml b/download/Cargo.toml index 3f3fb2455d..a431c0c483 100644 --- a/download/Cargo.toml +++ b/download/Cargo.toml @@ -18,8 +18,10 @@ anyhow.workspace = true curl = { version = "0.4.44", optional = true } env_proxy = { version = "0.4.1", optional = true } once_cell = { workspace = true, optional = true } -reqwest = { version = "0.12", default-features = false, features = ["blocking", "gzip", "socks"], optional = true } +reqwest = { version = "0.12", default-features = false, features = ["blocking", "gzip", "socks", "stream"], optional = true } thiserror.workspace = true +tokio = { workspace = true, default-features = false, features = ["sync"] } +tokio-stream.workspace = true url.workspace = true [dev-dependencies] @@ -27,4 +29,3 @@ http-body-util = "0.1.0" hyper = { version = "1.0", default-features = false, features = ["server", "http1"] } hyper-util = { version = "0.1.1", features = ["tokio"] } tempfile.workspace = true -tokio = { workspace = true, default-features = false, features = ["sync"] } diff --git a/download/src/lib.rs b/download/src/lib.rs index d14162df82..20b1bf2dc3 100644 --- a/download/src/lib.rs +++ b/download/src/lib.rs @@ -5,6 +5,7 @@ use std::path::Path; use anyhow::Context; pub use anyhow::Result; +use std::fs::remove_file; use url::Url; mod errors; @@ -49,7 +50,7 @@ pub enum Event<'a> { type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> Result<()>; -fn download_with_backend( +async fn download_with_backend( backend: Backend, url: &Url, resume_from: u64, @@ -57,11 +58,34 @@ fn download_with_backend( ) -> Result<()> { match backend { Backend::Curl => curl::download(url, resume_from, callback), - Backend::Reqwest(tls) => reqwest_be::download(url, resume_from, callback, tls), + Backend::Reqwest(tls) => reqwest_be::download(url, resume_from, callback, tls).await, } } -pub fn download_to_path_with_backend( +pub async fn download_to_path_with_backend( + backend: Backend, + url: &Url, + path: &Path, + resume_from_partial: bool, + callback: Option>, +) -> Result<()> { + let Err(err) = + download_to_path_with_backend_(backend, url, path, resume_from_partial, callback).await + else { + return Ok(()); + }; + + // TODO: We currently clear up the cached download on any error, should we restrict it to a subset? + Err( + if let Err(file_err) = remove_file(path).context("cleaning up cached downloads") { + file_err.context(err) + } else { + err + }, + ) +} + +pub async fn download_to_path_with_backend_( backend: Backend, url: &Url, path: &Path, @@ -69,88 +93,81 @@ pub fn download_to_path_with_backend( callback: Option>, ) -> Result<()> { use std::cell::RefCell; - use std::fs::remove_file; use std::fs::OpenOptions; use std::io::{Read, Seek, SeekFrom, Write}; - || -> Result<()> { - let (file, resume_from) = if resume_from_partial { - let possible_partial = OpenOptions::new().read(true).open(path); - - let downloaded_so_far = if let Ok(mut partial) = possible_partial { - if let Some(cb) = callback { - cb(Event::ResumingPartialDownload)?; - - let mut buf = vec![0; 32768]; - let mut downloaded_so_far = 0; - loop { - let n = partial.read(&mut buf)?; - downloaded_so_far += n as u64; - if n == 0 { - break; - } - cb(Event::DownloadDataReceived(&buf[..n]))?; + let (file, resume_from) = if resume_from_partial { + // TODO: blocking call + let possible_partial = OpenOptions::new().read(true).open(path); + + let downloaded_so_far = if let Ok(mut partial) = possible_partial { + if let Some(cb) = callback { + cb(Event::ResumingPartialDownload)?; + + let mut buf = vec![0; 32768]; + let mut downloaded_so_far = 0; + loop { + let n = partial.read(&mut buf)?; + downloaded_so_far += n as u64; + if n == 0 { + break; } - - downloaded_so_far - } else { - let file_info = partial.metadata()?; - file_info.len() + cb(Event::DownloadDataReceived(&buf[..n]))?; } - } else { - 0 - }; - - let mut possible_partial = OpenOptions::new() - .write(true) - .create(true) - .truncate(false) - .open(path) - .context("error opening file for download")?; - possible_partial.seek(SeekFrom::End(0))?; - - (possible_partial, downloaded_so_far) + downloaded_so_far + } else { + let file_info = partial.metadata()?; + file_info.len() + } } else { - ( - OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open(path) - .context("error creating file for download")?, - 0, - ) + 0 }; - let file = RefCell::new(file); + // TODO: blocking call + let mut possible_partial = OpenOptions::new() + .write(true) + .create(true) + .truncate(false) + .open(path) + .context("error opening file for download")?; - download_with_backend(backend, url, resume_from, &|event| { - if let Event::DownloadDataReceived(data) = event { - file.borrow_mut() - .write_all(data) - .context("unable to write download to disk")?; - } - match callback { - Some(cb) => cb(event), - None => Ok(()), - } - })?; + possible_partial.seek(SeekFrom::End(0))?; - file.borrow_mut() - .sync_data() - .context("unable to sync download to disk")?; - - Ok(()) - }() - .map_err(|e| { - // TODO: We currently clear up the cached download on any error, should we restrict it to a subset? - if let Err(file_err) = remove_file(path).context("cleaning up cached downloads") { - file_err.context(e) - } else { - e + (possible_partial, downloaded_so_far) + } else { + ( + OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(path) + .context("error creating file for download")?, + 0, + ) + }; + + let file = RefCell::new(file); + + // TODO: the sync callback will stall the async runtime if IO calls block, which is OS dependent. Rearrange. + download_with_backend(backend, url, resume_from, &|event| { + if let Event::DownloadDataReceived(data) = event { + file.borrow_mut() + .write_all(data) + .context("unable to write download to disk")?; + } + match callback { + Some(cb) => cb(event), + None => Ok(()), } }) + .await?; + + file.borrow_mut() + .sync_data() + .context("unable to sync download to disk")?; + + Ok::<(), anyhow::Error>(()) } #[cfg(all(not(feature = "reqwest-backend"), not(feature = "curl-backend")))] @@ -285,15 +302,15 @@ pub mod reqwest_be { use anyhow::{anyhow, Context, Result}; #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-default-tls"))] use once_cell::sync::Lazy; - use reqwest::blocking::{Client, ClientBuilder, Response}; - use reqwest::{header, Proxy}; + use reqwest::{header, Client, ClientBuilder, Proxy, Response}; + use tokio_stream::StreamExt; use url::Url; use super::Event; use super::TlsBackend; use crate::errors::*; - pub fn download( + pub async fn download( url: &Url, resume_from: u64, callback: &dyn Fn(Event<'_>) -> Result<()>, @@ -304,31 +321,26 @@ pub mod reqwest_be { return Ok(()); } - let mut res = request(url, resume_from, tls).context("failed to make network request")?; + let res = request(url, resume_from, tls) + .await + .context("failed to make network request")?; if !res.status().is_success() { let code: u16 = res.status().into(); return Err(anyhow!(DownloadError::HttpStatus(u32::from(code)))); } - let buffer_size = 0x10000; - let mut buffer = vec![0u8; buffer_size]; - - if let Some(len) = res.headers().get(header::CONTENT_LENGTH) { - // TODO possible issues during unwrap? - let len = len.to_str().unwrap().parse::().unwrap() + resume_from; + if let Some(len) = res.content_length() { + let len = len + resume_from; callback(Event::DownloadContentLengthReceived(len))?; } - loop { - let bytes_read = io::Read::read(&mut res, &mut buffer)?; - - if bytes_read != 0 { - callback(Event::DownloadDataReceived(&buffer[0..bytes_read]))?; - } else { - return Ok(()); - } + let mut stream = res.bytes_stream(); + while let Some(item) = stream.next().await { + let bytes = item?; + callback(Event::DownloadDataReceived(&bytes))?; } + Ok(()) } fn client_generic() -> ClientBuilder { @@ -377,7 +389,7 @@ pub mod reqwest_be { env_proxy::for_url(url).to_url() } - fn request( + async fn request( url: &Url, resume_from: u64, backend: TlsBackend, @@ -402,7 +414,7 @@ pub mod reqwest_be { req = req.header(header::RANGE, format!("bytes={resume_from}-")); } - Ok(req.send()?) + Ok(req.send().await?) } fn download_from_file_url( diff --git a/download/tests/download-curl-resume.rs b/download/tests/download-curl-resume.rs index 3ac0e78871..85cdffd367 100644 --- a/download/tests/download-curl-resume.rs +++ b/download/tests/download-curl-resume.rs @@ -10,8 +10,8 @@ use download::*; mod support; use crate::support::{serve_file, tmp_dir, write_file}; -#[test] -fn partially_downloaded_file_gets_resumed_from_byte_offset() { +#[tokio::test] +async fn partially_downloaded_file_gets_resumed_from_byte_offset() { let tmpdir = tmp_dir(); let from_path = tmpdir.path().join("download-source"); write_file(&from_path, "xxx45"); @@ -21,13 +21,14 @@ fn partially_downloaded_file_gets_resumed_from_byte_offset() { let from_url = Url::from_file_path(&from_path).unwrap(); download_to_path_with_backend(Backend::Curl, &from_url, &target_path, true, None) + .await .expect("Test download failed"); assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345"); } -#[test] -fn callback_gets_all_data_as_if_the_download_happened_all_at_once() { +#[tokio::test] +async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() { let tmpdir = tmp_dir(); let target_path = tmpdir.path().join("downloaded"); write_file(&target_path, "123"); @@ -66,6 +67,7 @@ fn callback_gets_all_data_as_if_the_download_happened_all_at_once() { Ok(()) }), ) + .await .expect("Test download failed"); assert!(callback_partial.into_inner()); diff --git a/download/tests/download-reqwest-resume.rs b/download/tests/download-reqwest-resume.rs index d326da6dc3..189e078f0f 100644 --- a/download/tests/download-reqwest-resume.rs +++ b/download/tests/download-reqwest-resume.rs @@ -10,8 +10,8 @@ use download::*; mod support; use crate::support::{serve_file, tmp_dir, write_file}; -#[test] -fn resume_partial_from_file_url() { +#[tokio::test] +async fn resume_partial_from_file_url() { let tmpdir = tmp_dir(); let from_path = tmpdir.path().join("download-source"); write_file(&from_path, "xxx45"); @@ -27,13 +27,14 @@ fn resume_partial_from_file_url() { true, None, ) + .await .expect("Test download failed"); assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345"); } -#[test] -fn callback_gets_all_data_as_if_the_download_happened_all_at_once() { +#[tokio::test] +async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() { let tmpdir = tmp_dir(); let target_path = tmpdir.path().join("downloaded"); write_file(&target_path, "123"); @@ -72,6 +73,7 @@ fn callback_gets_all_data_as_if_the_download_happened_all_at_once() { Ok(()) }), ) + .await .expect("Test download failed"); assert!(callback_partial.into_inner()); diff --git a/download/tests/read-proxy-env.rs b/download/tests/read-proxy-env.rs index b55bd326f5..435ef2692b 100644 --- a/download/tests/read-proxy-env.rs +++ b/download/tests/read-proxy-env.rs @@ -9,7 +9,7 @@ use std::thread; use std::time::Duration; use env_proxy::for_url; -use reqwest::{blocking::Client, Proxy}; +use reqwest::{Client, Proxy}; use url::Url; static SERIALISE_TESTS: Mutex<()> = Mutex::new(()); @@ -27,8 +27,8 @@ fn scrub_env() { } // Tests for correctly retrieving the proxy (host, port) tuple from $https_proxy -#[test] -fn read_basic_proxy_params() { +#[tokio::test] +async fn read_basic_proxy_params() { let _guard = SERIALISE_TESTS .lock() .expect("Unable to lock the test guard"); @@ -42,8 +42,8 @@ fn read_basic_proxy_params() { } // Tests to verify if socks feature is available and being used -#[test] -fn socks_proxy_request() { +#[tokio::test] +async fn socks_proxy_request() { static CALL_COUNT: AtomicUsize = AtomicUsize::new(0); let _guard = SERIALISE_TESTS .lock() @@ -68,7 +68,7 @@ fn socks_proxy_request() { .timeout(Duration::from_secs(1)) .build() .unwrap(); - let res = client.get(url.as_str()).send(); + let res = client.get(url.as_str()).send().await; if let Err(e) = res { let s = e.source().unwrap(); diff --git a/rustup-macros/src/lib.rs b/rustup-macros/src/lib.rs index 3e1531c7c1..5ccb3ae5cf 100644 --- a/rustup-macros/src/lib.rs +++ b/rustup-macros/src/lib.rs @@ -77,6 +77,8 @@ pub fn unit_test( .into() } +// False positive from clippy :/ +#[allow(clippy::redundant_clone)] fn test_inner(mod_path: String, mut input: ItemFn) -> syn::Result { if input.sig.asyncness.is_some() { let before_ident = format!("{}::before_test_async", mod_path); diff --git a/src/bin/rustup-init.rs b/src/bin/rustup-init.rs index 678ce8cd12..9b624a48a7 100644 --- a/src/bin/rustup-init.rs +++ b/src/bin/rustup-init.rs @@ -19,6 +19,7 @@ use cfg_if::cfg_if; use rs_tracing::{ close_trace_file, close_trace_file_internal, open_trace_file, trace_to_file_internal, }; +use tokio::runtime::Builder; use rustup::cli::common; use rustup::cli::proxy_mode; @@ -26,7 +27,7 @@ use rustup::cli::rustup_mode; #[cfg(windows)] use rustup::cli::self_update; use rustup::cli::setup_mode; -use rustup::currentprocess::{process, varsource::VarSource, with, OSProcess}; +use rustup::currentprocess::{process, varsource::VarSource, with_runtime, OSProcess}; use rustup::env_var::RUST_RECURSION_COUNT_MAX; use rustup::is_proxyable_tools; use rustup::utils::utils::{self, ExitCode}; @@ -36,19 +37,25 @@ fn main() { pre_rustup_main_init(); let process = OSProcess::default(); - with(process.into(), || match maybe_trace_rustup() { - Err(e) => { - common::report_error(&e); - std::process::exit(1); + let mut builder = Builder::new_multi_thread(); + builder.enable_all(); + with_runtime(process.into(), builder, { + async { + match maybe_trace_rustup().await { + Err(e) => { + common::report_error(&e); + std::process::exit(1); + } + Ok(utils::ExitCode(c)) => std::process::exit(c), + } } - Ok(utils::ExitCode(c)) => std::process::exit(c), }); } -fn maybe_trace_rustup() -> Result { +async fn maybe_trace_rustup() -> Result { #[cfg(not(feature = "otel"))] { - run_rustup() + run_rustup().await } #[cfg(feature = "otel")] { @@ -63,52 +70,37 @@ fn maybe_trace_rustup() -> Result { }; use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Registry}; - // Background submission requires a runtime, and since we're probably - // going to want async eventually, we just use tokio. - let threaded_rt = tokio::runtime::Runtime::new()?; - - let result = threaded_rt.block_on(async { - global::set_text_map_propagator(TraceContextPropagator::new()); - let tracer = opentelemetry_otlp::new_pipeline() - .tracing() - .with_exporter( - opentelemetry_otlp::new_exporter() - .tonic() - .with_timeout(Duration::from_secs(3)), - ) - .with_trace_config( - trace::config() - .with_sampler(Sampler::AlwaysOn) - .with_resource(Resource::new(vec![KeyValue::new( - "service.name", - "rustup", - )])), - ) - .install_batch(opentelemetry_sdk::runtime::Tokio)?; - let env_filter = EnvFilter::try_from_default_env().unwrap_or(EnvFilter::new("INFO")); - let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); - let subscriber = Registry::default().with(env_filter).with(telemetry); - tracing::subscriber::set_global_default(subscriber)?; - let result = run_rustup(); - // We're tracing, so block until all spans are exported. - opentelemetry::global::shutdown_tracer_provider(); - result - }); - // default runtime behaviour is to block until nothing is running; - // instead we supply a timeout, as we're either already errored and are - // reporting back without care for lost threads etc... or everything - // completed. - threaded_rt.shutdown_timeout(Duration::from_millis(5)); + global::set_text_map_propagator(TraceContextPropagator::new()); + let tracer = opentelemetry_otlp::new_pipeline() + .tracing() + .with_exporter( + opentelemetry_otlp::new_exporter() + .tonic() + .with_timeout(Duration::from_secs(3)), + ) + .with_trace_config( + trace::config() + .with_sampler(Sampler::AlwaysOn) + .with_resource(Resource::new(vec![KeyValue::new("service.name", "rustup")])), + ) + .install_batch(opentelemetry_sdk::runtime::Tokio)?; + let env_filter = EnvFilter::try_from_default_env().unwrap_or(EnvFilter::new("INFO")); + let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); + let subscriber = Registry::default().with(env_filter).with(telemetry); + tracing::subscriber::set_global_default(subscriber)?; + let result = run_rustup().await; + // We're tracing, so block until all spans are exported. + opentelemetry::global::shutdown_tracer_provider(); result } } #[cfg_attr(feature = "otel", tracing::instrument)] -fn run_rustup() -> Result { +async fn run_rustup() -> Result { if let Ok(dir) = process().var("RUSTUP_TRACE_DIR") { open_trace_file!(dir)?; } - let result = run_rustup_inner(); + let result = run_rustup_inner().await; if process().var("RUSTUP_TRACE_DIR").is_ok() { close_trace_file!(); } @@ -116,7 +108,7 @@ fn run_rustup() -> Result { } #[cfg_attr(feature = "otel", tracing::instrument(err))] -fn run_rustup_inner() -> Result { +async fn run_rustup_inner() -> Result { // Guard against infinite proxy recursion. This mostly happens due to // bugs in rustup. do_recursion_guard()?; @@ -127,13 +119,13 @@ fn run_rustup_inner() -> Result { utils::current_exe()?; match process().name().as_deref() { - Some("rustup") => rustup_mode::main(), + Some("rustup") => rustup_mode::main().await, Some(n) if n.starts_with("rustup-setup") || n.starts_with("rustup-init") => { // NB: The above check is only for the prefix of the file // name. Browsers rename duplicates to // e.g. rustup-setup(2), and this allows all variations // to work. - setup_mode::main() + setup_mode::main().await } Some(n) if n.starts_with("rustup-gc-") => { // This is the final uninstallation stage on windows where diff --git a/src/cli/common.rs b/src/cli/common.rs index e8fc3a1ec3..7591fe6f5e 100644 --- a/src/cli/common.rs +++ b/src/cli/common.rs @@ -285,7 +285,7 @@ fn show_channel_updates( Ok(()) } -pub(crate) fn update_all_channels( +pub(crate) async fn update_all_channels( cfg: &Cfg, do_self_update: bool, force_update: bool, @@ -310,7 +310,7 @@ pub(crate) fn update_all_channels( }; if do_self_update { - self_update(show_channel_updates) + self_update(show_channel_updates).await } else { show_channel_updates() } @@ -350,7 +350,8 @@ pub(crate) fn self_update_permitted(explicit: bool) -> Result(before_restart: F) -> Result +/// Performs all of a self-update: check policy, download, apply and exit. +pub(crate) async fn self_update(before_restart: F) -> Result where F: FnOnce() -> Result, { @@ -363,7 +364,7 @@ where SelfUpdatePermission::Permit => {} } - let setup_path = self_update::prepare_update()?; + let setup_path = self_update::prepare_update().await?; before_restart()?; diff --git a/src/cli/rustup_mode.rs b/src/cli/rustup_mode.rs index 19ab09da4c..524d2f0f0e 100644 --- a/src/cli/rustup_mode.rs +++ b/src/cli/rustup_mode.rs @@ -526,7 +526,7 @@ enum SetSubcmd { } #[cfg_attr(feature = "otel", tracing::instrument(fields(args = format!("{:?}", process().args_os().collect::>()))))] -pub fn main() -> Result { +pub async fn main() -> Result { self_update::cleanup_self_updater()?; use clap::error::ErrorKind::*; @@ -595,7 +595,7 @@ pub fn main() -> Result { match subcmd { RustupSubcmd::DumpTestament => common::dump_testament(), - RustupSubcmd::Install { opts } => update(cfg, opts), + RustupSubcmd::Install { opts } => update(cfg, opts).await, RustupSubcmd::Uninstall { opts } => toolchain_remove(cfg, opts), RustupSubcmd::Show { verbose, subcmd } => handle_epipe(match subcmd { None => show(cfg, verbose), @@ -611,33 +611,40 @@ pub fn main() -> Result { no_self_update, force, force_non_host, - } => update( - cfg, - UpdateOpts { - toolchain, - no_self_update, - force, - force_non_host, - ..UpdateOpts::default() - }, - ), + } => { + update( + cfg, + UpdateOpts { + toolchain, + no_self_update, + force, + force_non_host, + ..UpdateOpts::default() + }, + ) + .await + } RustupSubcmd::Toolchain { subcmd } => match subcmd { - ToolchainSubcmd::Install { opts } => update(cfg, opts), + ToolchainSubcmd::Install { opts } => update(cfg, opts).await, ToolchainSubcmd::List { verbose } => { handle_epipe(common::list_toolchains(cfg, verbose)) } - ToolchainSubcmd::Link { toolchain, path } => toolchain_link(cfg, &toolchain, &path), + ToolchainSubcmd::Link { toolchain, path } => { + toolchain_link(cfg, &toolchain, &path).await + } ToolchainSubcmd::Uninstall { opts } => toolchain_remove(cfg, opts), }, - RustupSubcmd::Check => check_updates(cfg), - RustupSubcmd::Default { toolchain } => default_(cfg, toolchain), + RustupSubcmd::Check => check_updates(cfg).await, + RustupSubcmd::Default { toolchain } => default_(cfg, toolchain).await, RustupSubcmd::Target { subcmd } => match subcmd { TargetSubcmd::List { toolchain, installed, } => handle_epipe(target_list(cfg, toolchain, installed)), - TargetSubcmd::Add { target, toolchain } => target_add(cfg, target, toolchain), - TargetSubcmd::Remove { target, toolchain } => target_remove(cfg, target, toolchain), + TargetSubcmd::Add { target, toolchain } => target_add(cfg, target, toolchain).await, + TargetSubcmd::Remove { target, toolchain } => { + target_remove(cfg, target, toolchain).await + } }, RustupSubcmd::Component { subcmd } => match subcmd { ComponentSubcmd::List { @@ -648,17 +655,17 @@ pub fn main() -> Result { component, toolchain, target, - } => component_add(cfg, component, toolchain, target), + } => component_add(cfg, component, toolchain, target).await, ComponentSubcmd::Remove { component, toolchain, target, - } => component_remove(cfg, component, toolchain, target), + } => component_remove(cfg, component, toolchain, target).await, }, RustupSubcmd::Override { subcmd } => match subcmd { OverrideSubcmd::List => handle_epipe(common::list_overrides(cfg)), OverrideSubcmd::Set { toolchain, path } => { - override_add(cfg, toolchain, path.as_deref()) + override_add(cfg, toolchain, path.as_deref()).await } OverrideSubcmd::Unset { path, nonexistent } => { override_remove(cfg, path.as_deref(), nonexistent) @@ -679,7 +686,7 @@ pub fn main() -> Result { #[cfg(not(windows))] RustupSubcmd::Man { command, toolchain } => man(cfg, &command, toolchain), RustupSubcmd::Self_ { subcmd } => match subcmd { - SelfSubcmd::Update => self_update::update(cfg), + SelfSubcmd::Update => self_update::update(cfg).await, SelfSubcmd::Uninstall { no_prompt } => self_update::uninstall(no_prompt), SelfSubcmd::UpgradeData => cfg.upgrade_data().map(|_| ExitCode(0)), }, @@ -698,7 +705,10 @@ pub fn main() -> Result { } } -fn default_(cfg: &Cfg, toolchain: Option) -> Result { +async fn default_( + cfg: &Cfg, + toolchain: Option, +) -> Result { common::warn_if_host_is_emulated(); if let Some(toolchain) = toolchain { @@ -712,7 +722,7 @@ fn default_(cfg: &Cfg, toolchain: Option) -> Resul } MaybeResolvableToolchainName::Some(ResolvableToolchainName::Official(toolchain)) => { let desc = toolchain.resolve(&cfg.get_default_host_triple()?)?; - let status = DistributableToolchain::install_if_not_installed(cfg, &desc)?; + let status = DistributableToolchain::install_if_not_installed(cfg, &desc).await?; cfg.set_default(Some(&(&desc).into()))?; @@ -743,14 +753,14 @@ fn default_(cfg: &Cfg, toolchain: Option) -> Resul Ok(utils::ExitCode(0)) } -fn check_updates(cfg: &Cfg) -> Result { +async fn check_updates(cfg: &Cfg) -> Result { let mut t = process().stdout().terminal(); let channels = cfg.list_channels()?; for channel in channels { let (name, distributable) = channel; let current_version = distributable.show_version()?; - let dist_version = distributable.show_dist_version()?; + let dist_version = distributable.show_dist_version().await?; let _ = t.attr(terminalsource::Attr::Bold); write!(t.lock(), "{name} - ")?; match (current_version, dist_version) { @@ -779,12 +789,12 @@ fn check_updates(cfg: &Cfg) -> Result { } } - check_rustup_update()?; + check_rustup_update().await?; Ok(utils::ExitCode(0)) } -fn update(cfg: &mut Cfg, opts: UpdateOpts) -> Result { +async fn update(cfg: &mut Cfg, opts: UpdateOpts) -> Result { common::warn_if_host_is_emulated(); let self_update_mode = cfg.get_self_update_mode()?; // Priority: no-self-update feature > self_update_mode > no-self-update args. @@ -833,7 +843,8 @@ fn update(cfg: &mut Cfg, opts: UpdateOpts) -> Result { desc.clone(), ) { Ok(mut d) => { - d.update_extra(&components, &targets, profile, force, allow_downgrade)? + d.update_extra(&components, &targets, profile, force, allow_downgrade) + .await? } Err(RustupError::ToolchainNotInstalled(_)) => { crate::toolchain::distributable::DistributableToolchain::install( @@ -843,7 +854,8 @@ fn update(cfg: &mut Cfg, opts: UpdateOpts) -> Result { &targets, profile, force, - )? + ) + .await? .0 } Err(e) => Err(e)?, @@ -860,17 +872,17 @@ fn update(cfg: &mut Cfg, opts: UpdateOpts) -> Result { } } if self_update { - common::self_update(|| Ok(utils::ExitCode(0)))?; + common::self_update(|| Ok(utils::ExitCode(0))).await?; } } else { - common::update_all_channels(cfg, self_update, opts.force)?; + common::update_all_channels(cfg, self_update, opts.force).await?; info!("cleaning up downloads & tmp directories"); utils::delete_dir_contents_following_links(&cfg.download_dir); cfg.tmp_cx.clean(); } if !self_update::NEVER_SELF_UPDATE && self_update_mode == SelfUpdateMode::CheckOnly { - check_rustup_update()?; + check_rustup_update().await?; } if self_update::NEVER_SELF_UPDATE { @@ -1100,7 +1112,7 @@ fn target_list( ) } -fn target_add( +async fn target_add( cfg: &Cfg, mut targets: Vec, toolchain: Option, @@ -1143,13 +1155,13 @@ fn target_add( Some(TargetTriple::new(target)), false, ); - distributable.add_component(new_component)?; + distributable.add_component(new_component).await?; } Ok(utils::ExitCode(0)) } -fn target_remove( +async fn target_remove( cfg: &Cfg, targets: Vec, toolchain: Option, @@ -1177,7 +1189,7 @@ fn target_remove( warn!("after removing the last target, no build targets will be available"); } let new_component = Component::new("rust-std".to_string(), Some(target), false); - distributable.remove_component(new_component)?; + distributable.remove_component(new_component).await?; } Ok(utils::ExitCode(0)) @@ -1194,7 +1206,7 @@ fn component_list( Ok(utils::ExitCode(0)) } -fn component_add( +async fn component_add( cfg: &Cfg, components: Vec, toolchain: Option, @@ -1205,7 +1217,7 @@ fn component_add( for component in &components { let new_component = Component::try_new(component, &distributable, target.as_ref())?; - distributable.add_component(new_component)?; + distributable.add_component(new_component).await?; } Ok(utils::ExitCode(0)) @@ -1220,7 +1232,7 @@ fn get_target( .or_else(|| Some(distributable.desc().target.clone())) } -fn component_remove( +async fn component_remove( cfg: &Cfg, components: Vec, toolchain: Option, @@ -1231,13 +1243,17 @@ fn component_remove( for component in &components { let new_component = Component::try_new(component, &distributable, target.as_ref())?; - distributable.remove_component(new_component)?; + distributable.remove_component(new_component).await?; } Ok(utils::ExitCode(0)) } -fn toolchain_link(cfg: &Cfg, dest: &CustomToolchainName, src: &Path) -> Result { +async fn toolchain_link( + cfg: &Cfg, + dest: &CustomToolchainName, + src: &Path, +) -> Result { cfg.ensure_toolchains_dir()?; let mut pathbuf = PathBuf::from(src); @@ -1255,9 +1271,10 @@ fn toolchain_link(cfg: &Cfg, dest: &CustomToolchainName, src: &Path) -> Result Result, @@ -1289,15 +1306,10 @@ fn override_add( Err(e @ RustupError::ToolchainNotInstalled(_)) => match &toolchain_name { ToolchainName::Custom(_) => Err(e)?, ToolchainName::Official(desc) => { - let status = DistributableToolchain::install( - cfg, - desc, - &[], - &[], - cfg.get_profile()?, - false, - )? - .0; + let status = + DistributableToolchain::install(cfg, desc, &[], &[], cfg.get_profile()?, false) + .await? + .0; writeln!(process().stdout().lock())?; common::show_channel_update( cfg, diff --git a/src/cli/self_update.rs b/src/cli/self_update.rs index 7a3db755d9..d889ad14f9 100644 --- a/src/cli/self_update.rs +++ b/src/cli/self_update.rs @@ -364,7 +364,7 @@ fn canonical_cargo_home() -> Result> { /// Installing is a simple matter of copying the running binary to /// `CARGO_HOME`/bin, hard-linking the various Rust tools to it, /// and adding `CARGO_HOME`/bin to PATH. -pub(crate) fn install( +pub(crate) async fn install( no_prompt: bool, verbose: bool, quiet: bool, @@ -399,7 +399,7 @@ pub(crate) fn install( md(&mut term, MSVC_AUTO_INSTALL_MESSAGE); match windows::choose_vs_install()? { Some(VsInstallPlan::Automatic) => { - match try_install_msvc(&opts) { + match try_install_msvc(&opts).await { Err(e) => { // Make sure the console doesn't exit before the user can // see the error and give the option to continue anyway. @@ -869,16 +869,16 @@ fn maybe_install_rust( // - delete the partial install and start over // For now, we error. let mut toolchain = DistributableToolchain::new(&cfg, desc.clone())?; - toolchain.update(components, targets, cfg.get_profile()?)? + utils::run_future(toolchain.update(components, targets, cfg.get_profile()?))? } else { - DistributableToolchain::install( + utils::run_future(DistributableToolchain::install( &cfg, desc, components, targets, cfg.get_profile()?, true, - )? + ))? .0 }; @@ -1081,7 +1081,7 @@ pub(crate) fn uninstall(no_prompt: bool) -> Result { /// (and on windows this process will not be running to do it), /// rustup-init is stored in `CARGO_HOME`/bin, and then deleted next /// time rustup runs. -pub(crate) fn update(cfg: &Cfg) -> Result { +pub(crate) async fn update(cfg: &Cfg) -> Result { common::warn_if_host_is_emulated(); use common::SelfUpdatePermission::*; @@ -1104,7 +1104,7 @@ pub(crate) fn update(cfg: &Cfg) -> Result { Permit => {} } - match prepare_update()? { + match prepare_update().await? { Some(setup_path) => { let Some(version) = get_and_parse_new_rustup_version(&setup_path) else { err!("failed to get rustup version"); @@ -1160,7 +1160,7 @@ fn parse_new_rustup_version(version: String) -> String { String::from(matched_version) } -pub(crate) fn prepare_update() -> Result> { +pub(crate) async fn prepare_update() -> Result> { let cargo_home = utils::cargo_home()?; let rustup_path = cargo_home.join(format!("bin{MAIN_SEPARATOR}rustup{EXE_SUFFIX}")); let setup_path = cargo_home.join(format!("bin{MAIN_SEPARATOR}rustup-init{EXE_SUFFIX}")); @@ -1193,7 +1193,7 @@ pub(crate) fn prepare_update() -> Result> { // Get available version info!("checking for self-update"); - let available_version = get_available_rustup_version()?; + let available_version = get_available_rustup_version().await?; // If up-to-date if available_version == current_version { @@ -1208,7 +1208,7 @@ pub(crate) fn prepare_update() -> Result> { // Download new version info!("downloading self-update"); - utils::download_file(&download_url, &setup_path, None, &|_| ())?; + utils::download_file(&download_url, &setup_path, None, &|_| ()).await?; // Mark as executable utils::make_executable(&setup_path)?; @@ -1216,7 +1216,7 @@ pub(crate) fn prepare_update() -> Result> { Ok(Some(setup_path)) } -pub(crate) fn get_available_rustup_version() -> Result { +async fn get_available_rustup_version() -> Result { let update_root = update_root(); let tempdir = tempfile::Builder::new() .prefix("rustup-update") @@ -1227,7 +1227,7 @@ pub(crate) fn get_available_rustup_version() -> Result { let release_file_url = format!("{update_root}/release-stable.toml"); let release_file_url = utils::parse_url(&release_file_url)?; let release_file = tempdir.path().join("release-stable.toml"); - utils::download_file(&release_file_url, &release_file, None, &|_| ())?; + utils::download_file(&release_file_url, &release_file, None, &|_| ()).await?; let release_toml_str = utils::read_file("rustup release", &release_file)?; let release_toml: toml::Value = toml::from_str(&release_toml_str).context("unable to parse rustup release file")?; @@ -1254,13 +1254,13 @@ pub(crate) fn get_available_rustup_version() -> Result { Ok(String::from(available_version)) } -pub(crate) fn check_rustup_update() -> Result<()> { +pub(crate) async fn check_rustup_update() -> Result<()> { let mut t = process().stdout().terminal(); // Get current rustup version let current_version = env!("CARGO_PKG_VERSION"); // Get available rustup version - let available_version = get_available_rustup_version()?; + let available_version = get_available_rustup_version().await?; let _ = t.attr(terminalsource::Attr::Bold); write!(t.lock(), "rustup - ")?; diff --git a/src/cli/self_update/windows.rs b/src/cli/self_update/windows.rs index c7097ea895..2102afd418 100644 --- a/src/cli/self_update/windows.rs +++ b/src/cli/self_update/windows.rs @@ -169,7 +169,7 @@ pub(crate) enum ContinueInstall { /// /// Returns `Ok(ContinueInstall::No)` if installing Visual Studio was successful /// but the rustup install should not be continued at this time. -pub(crate) fn try_install_msvc(opts: &InstallOpts<'_>) -> Result { +pub(crate) async fn try_install_msvc(opts: &InstallOpts<'_>) -> Result { // download the installer let visual_studio_url = utils::parse_url("https://aka.ms/vs/17/release/vs_community.exe")?; @@ -187,7 +187,8 @@ pub(crate) fn try_install_msvc(opts: &InstallOpts<'_>) -> Result Result { +pub async fn main() -> Result { use clap::error::ErrorKind; let RustupInit { @@ -122,5 +122,5 @@ pub fn main() -> Result { targets: &target.iter().map(|s| &**s).collect::>(), }; - self_update::install(no_prompt, verbose, quiet, opts) + self_update::install(no_prompt, verbose, quiet, opts).await } diff --git a/src/config.rs b/src/config.rs index e2b5b0905a..bcfa565e88 100644 --- a/src/config.rs +++ b/src/config.rs @@ -786,23 +786,23 @@ impl Cfg { let targets: Vec<_> = targets.iter().map(AsRef::as_ref).collect(); let toolchain = match DistributableToolchain::new(self, toolchain.clone()) { Err(RustupError::ToolchainNotInstalled(_)) => { - DistributableToolchain::install( + utils::run_future(DistributableToolchain::install( self, &toolchain, &components, &targets, profile.unwrap_or(Profile::Default), false, - )? + ))? .1 } Ok(mut distributable) => { if !distributable.components_exist(&components, &targets)? { - distributable.update( + utils::run_future(distributable.update( &components, &targets, profile.unwrap_or(Profile::Default), - )?; + ))?; } distributable } @@ -895,7 +895,13 @@ impl Cfg { // Update toolchains and collect the results let channels = channels.map(|(desc, mut distributable)| { - let st = distributable.update_extra(&[], &[], profile, force_update, false); + let st = utils::run_future(distributable.update_extra( + &[], + &[], + profile, + force_update, + false, + )); if let Err(ref e) = st { (self.notify_handler)(Notification::NonFatalError(e)); @@ -938,14 +944,14 @@ impl Cfg { match DistributableToolchain::new(self, desc.clone()) { Err(RustupError::ToolchainNotInstalled(_)) => { if install_if_missing { - DistributableToolchain::install( + utils::run_future(DistributableToolchain::install( self, desc, &[], &[], self.get_profile()?, true, - )?; + ))?; } } o => { diff --git a/src/currentprocess.rs b/src/currentprocess.rs index 5400bbabcb..4523e4e6f8 100644 --- a/src/currentprocess.rs +++ b/src/currentprocess.rs @@ -1,11 +1,12 @@ -use std::cell::RefCell; use std::env; use std::ffi::OsString; use std::fmt::Debug; -use std::io::{self, IsTerminal}; +use std::future::Future; +use std::io; use std::panic; use std::path::PathBuf; use std::sync::Once; +use std::{cell::RefCell, io::IsTerminal}; #[cfg(feature = "test")] use std::{ collections::HashMap, @@ -135,6 +136,20 @@ pub fn with(process: Process, f: F) -> R where F: FnOnce() -> R, { + ensure_hook(); + + PROCESS.with(|p| { + if let Some(old_p) = &*p.borrow() { + panic!("current process already set {old_p:?}"); + } + *p.borrow_mut() = Some(process); + let result = f(); + *p.borrow_mut() = None; + result + }) +} + +fn ensure_hook() { HOOK_INSTALLED.call_once(|| { let orig_hook = panic::take_hook(); panic::set_hook(Box::new(move |info| { @@ -142,13 +157,65 @@ where orig_hook(info); })); }); +} - PROCESS.with(|p| { +/// Run a function in the context of a process definition and a tokio runtime. +/// +/// The process state is injected into a thread-local in every work thread of +/// the runtime, but this requires access to the runtime builder, so this +/// function must be the one to create the runtime. +pub fn with_runtime<'a, R>( + process: Process, + mut runtime_builder: tokio::runtime::Builder, + fut: impl Future + 'a, +) -> R { + ensure_hook(); + + let start_process = process.clone(); + let unpark_process = process.clone(); + let runtime = runtime_builder + // propagate to blocking threads + .on_thread_start(move || { + // assign the process persistently to the thread local. + PROCESS.with(|p| { + if let Some(old_p) = &*p.borrow() { + panic!("current process already set {old_p:?}"); + } + *p.borrow_mut() = Some(start_process.clone()); + // Thread exits will clear the process. + }); + }) + .on_thread_stop(move || { + PROCESS.with(|p| { + *p.borrow_mut() = None; + }); + }) + // propagate to async worker threads + .on_thread_unpark(move || { + // assign the process persistently to the thread local. + PROCESS.with(|p| { + if let Some(old_p) = &*p.borrow() { + panic!("current process already set {old_p:?}"); + } + *p.borrow_mut() = Some(unpark_process.clone()); + // Thread exits will clear the process. + }); + }) + .on_thread_park(move || { + PROCESS.with(|p| { + *p.borrow_mut() = None; + }); + }) + .build() + .unwrap(); + + // The current thread doesn't get hooks run on it. + PROCESS.with(move |p| { if let Some(old_p) = &*p.borrow() { panic!("current process already set {old_p:?}"); } *p.borrow_mut() = Some(process); - let result = f(); + let result = runtime.block_on(fut); *p.borrow_mut() = None; result }) @@ -233,6 +300,7 @@ impl TestProcess { stderr: Arc::new(Mutex::new(Vec::new())), } } + fn new_id() -> u64 { let low_bits: u64 = std::process::id() as u64; let mut rng = thread_rng(); diff --git a/src/dist/dist.rs b/src/dist/dist.rs index f53b24a736..3980ebb9b0 100644 --- a/src/dist/dist.rs +++ b/src/dist/dist.rs @@ -703,7 +703,7 @@ pub(crate) fn valid_profile_names() -> String { // // Returns the manifest's hash if anything changed. #[cfg_attr(feature = "otel", tracing::instrument(err, skip_all, fields(profile=format!("{profile:?}"), prefix=prefix.path().to_string_lossy().to_string())))] -pub(crate) fn update_from_dist( +pub(crate) async fn update_from_dist( download: DownloadCfg<'_>, update_hash: Option<&Path>, toolchain: &ToolchainDesc, @@ -735,7 +735,8 @@ pub(crate) fn update_from_dist( old_date, components, targets, - ); + ) + .await; // Don't leave behind an empty / broken installation directory if res.is_err() && fresh_install { @@ -746,7 +747,7 @@ pub(crate) fn update_from_dist( res } -fn update_from_dist_( +async fn update_from_dist_( download: DownloadCfg<'_>, update_hash: Option<&Path>, toolchain: &ToolchainDesc, @@ -812,7 +813,9 @@ fn update_from_dist_( components, targets, &mut fetched, - ) { + ) + .await + { Ok(v) => break Ok(v), Err(e) => { if !backtrack { @@ -882,7 +885,7 @@ fn update_from_dist_( } } -fn try_update_from_dist_( +async fn try_update_from_dist_( download: DownloadCfg<'_>, update_hash: Option<&Path>, toolchain: &ToolchainDesc, @@ -909,7 +912,9 @@ fn try_update_from_dist_( None }, toolchain, - ) { + ) + .await + { Ok(Some((m, hash))) => { (download.notify_handler)(Notification::DownloadedManifest( &m.date, @@ -962,14 +967,17 @@ fn try_update_from_dist_( fetched.clone_from(&m.date); - return match manifestation.update( - &m, - changes, - force_update, - &download, - &toolchain.manifest_name(), - true, - ) { + return match manifestation + .update( + &m, + changes, + force_update, + &download, + &toolchain.manifest_name(), + true, + ) + .await + { Ok(status) => match status { UpdateStatus::Unchanged => Ok(None), UpdateStatus::Changed => Ok(Some(hash)), @@ -1012,7 +1020,7 @@ fn try_update_from_dist_( } // If the v2 manifest is not found then try v1 - let manifest = match dl_v1_manifest(download, toolchain) { + let manifest = match dl_v1_manifest(download, toolchain).await { Ok(m) => m, Err(any) => { enum Cases { @@ -1043,12 +1051,14 @@ fn try_update_from_dist_( } } }; - let result = manifestation.update_v1( - &manifest, - update_hash, - download.tmp_cx, - &download.notify_handler, - ); + let result = manifestation + .update_v1( + &manifest, + update_hash, + download.tmp_cx, + &download.notify_handler, + ) + .await; // inspect, determine what context to add, then process afterwards. let mut download_not_exists = false; match &result { @@ -1068,13 +1078,16 @@ fn try_update_from_dist_( } } -pub(crate) fn dl_v2_manifest( +pub(crate) async fn dl_v2_manifest( download: DownloadCfg<'_>, update_hash: Option<&Path>, toolchain: &ToolchainDesc, ) -> Result> { let manifest_url = toolchain.manifest_v2_url(download.dist_root); - match download.download_and_check(&manifest_url, update_hash, ".toml") { + match download + .download_and_check(&manifest_url, update_hash, ".toml") + .await + { Ok(manifest_dl) => { // Downloaded ok! let (manifest_file, manifest_hash) = if let Some(m) = manifest_dl { @@ -1097,7 +1110,10 @@ pub(crate) fn dl_v2_manifest( } } -fn dl_v1_manifest(download: DownloadCfg<'_>, toolchain: &ToolchainDesc) -> Result> { +async fn dl_v1_manifest( + download: DownloadCfg<'_>, + toolchain: &ToolchainDesc, +) -> Result> { let root_url = toolchain.package_dir(download.dist_root); if !["nightly", "beta", "stable"].contains(&&*toolchain.channel) { @@ -1111,7 +1127,7 @@ fn dl_v1_manifest(download: DownloadCfg<'_>, toolchain: &ToolchainDesc) -> Resul } let manifest_url = toolchain.manifest_v1_url(download.dist_root); - let manifest_dl = download.download_and_check(&manifest_url, None, "")?; + let manifest_dl = download.download_and_check(&manifest_url, None, "").await?; let (manifest_file, _) = manifest_dl.unwrap(); let manifest_str = utils::read_file("manifest", &manifest_file)?; let urls = manifest_str diff --git a/src/dist/download.rs b/src/dist/download.rs index 8980d180df..0ee96f95d2 100644 --- a/src/dist/download.rs +++ b/src/dist/download.rs @@ -38,7 +38,7 @@ impl<'a> DownloadCfg<'a> { /// Partial downloads are stored in `self.download_dir`, keyed by hash. If the /// target file already exists, then the hash is checked and it is returned /// immediately without re-downloading. - pub(crate) fn download(&self, url: &Url, hash: &str) -> Result { + pub(crate) async fn download(&self, url: &Url, hash: &str) -> Result { utils::ensure_dir_exists( "Download Directory", self.download_dir, @@ -77,7 +77,9 @@ impl<'a> DownloadCfg<'a> { Some(&mut hasher), true, &|n| (self.notify_handler)(n.into()), - ) { + ) + .await + { let err = Err(e); if partial_file_existed { return err.context(RustupError::BrokenPartialFile); @@ -124,13 +126,14 @@ impl<'a> DownloadCfg<'a> { Ok(()) } - fn download_hash(&self, url: &str) -> Result { + async fn download_hash(&self, url: &str) -> Result { let hash_url = utils::parse_url(&(url.to_owned() + ".sha256"))?; let hash_file = self.tmp_cx.new_file()?; utils::download_file(&hash_url, &hash_file, None, &|n| { (self.notify_handler)(n.into()) - })?; + }) + .await?; utils::read_file("hash", &hash_file).map(|s| s[0..64].to_owned()) } @@ -140,13 +143,13 @@ impl<'a> DownloadCfg<'a> { /// and if they match, the download is skipped. /// Verifies the signature found at the same url with a `.asc` suffix, and prints a /// warning when the signature does not verify, or is not found. - pub(crate) fn download_and_check( + pub(crate) async fn download_and_check( &self, url_str: &str, update_hash: Option<&Path>, ext: &str, ) -> Result, String)>> { - let hash = self.download_hash(url_str)?; + let hash = self.download_hash(url_str).await?; let partial_hash: String = hash.chars().take(UPDATE_HASH_LEN).collect(); if let Some(hash_file) = update_hash { @@ -170,7 +173,8 @@ impl<'a> DownloadCfg<'a> { let mut hasher = Sha256::new(); utils::download_file(&url, &file, Some(&mut hasher), &|n| { (self.notify_handler)(n.into()) - })?; + }) + .await?; let actual_hash = format!("{:x}", hasher.finalize()); if hash != actual_hash { diff --git a/src/dist/manifestation.rs b/src/dist/manifestation.rs index 273a71902f..7568e34d38 100644 --- a/src/dist/manifestation.rs +++ b/src/dist/manifestation.rs @@ -7,8 +7,7 @@ mod tests; use std::path::Path; use anyhow::{anyhow, bail, Context, Result}; -use retry::delay::NoDelay; -use retry::{retry, OperationResult}; +use tokio_retry::{strategy::FixedInterval, RetryIf}; use crate::currentprocess::{process, varsource::VarSource}; use crate::dist::component::{ @@ -21,7 +20,7 @@ use crate::dist::manifest::{Component, CompressionKind, Manifest, TargetedPackag use crate::dist::notifications::*; use crate::dist::prefix::InstallPrefix; use crate::dist::temp; -use crate::errors::{OperationError, RustupError}; +use crate::errors::RustupError; use crate::utils::utils; pub(crate) const DIST_MANIFEST: &str = "multirust-channel-manifest.toml"; @@ -99,7 +98,10 @@ impl Manifestation { /// distribution manifest to "rustlib/rustup-dist.toml" and a /// configuration containing the component name-target pairs to /// "rustlib/rustup-config.toml". - pub fn update( + /// + /// It is *not* safe to run two updates concurrently. See + /// https://github.com/rust-lang/rustup/issues/988 for the details. + pub async fn update( &self, new_manifest: &Manifest, changes: Changes, @@ -173,26 +175,22 @@ impl Manifestation { let url_url = utils::parse_url(&url)?; - let downloaded_file = retry(NoDelay.take(max_retries), || { - match download_cfg.download(&url_url, &hash) { - Ok(f) => OperationResult::Ok(f), - Err(e) => { - match e.downcast_ref::() { - Some(RustupError::BrokenPartialFile) => { - (download_cfg.notify_handler)(Notification::RetryingDownload(&url)); - return OperationResult::Retry(OperationError(e)); - } - Some(RustupError::DownloadingFile { .. }) => { - (download_cfg.notify_handler)(Notification::RetryingDownload(&url)); - return OperationResult::Retry(OperationError(e)); - } - Some(_) => return OperationResult::Err(OperationError(e)), - None => (), - }; - OperationResult::Err(OperationError(e)) + let downloaded_file = RetryIf::spawn( + FixedInterval::from_millis(0).take(max_retries), + || download_cfg.download(&url_url, &hash), + |e: &anyhow::Error| { + // retry only known retriable cases + match e.downcast_ref::() { + Some(RustupError::BrokenPartialFile) + | Some(RustupError::DownloadingFile { .. }) => { + (download_cfg.notify_handler)(Notification::RetryingDownload(&url)); + true + } + _ => false, } - } - }) + }, + ) + .await .with_context(|| RustupError::ComponentDownloadFailed(component.name(new_manifest)))?; things_downloaded.push(hash); @@ -380,7 +378,7 @@ impl Manifestation { } /// Installation using the legacy v1 manifest format - pub(crate) fn update_v1( + pub(crate) async fn update_v1( &self, new_manifest: &[String], update_hash: Option<&Path>, @@ -423,7 +421,9 @@ impl Manifestation { notify_handler, }; - let dl = dlcfg.download_and_check(&url, update_hash, ".tar.gz")?; + let dl = dlcfg + .download_and_check(&url, update_hash, ".tar.gz") + .await?; if dl.is_none() { return Ok(None); }; diff --git a/src/dist/manifestation/tests.rs b/src/dist/manifestation/tests.rs index 54cd010ea3..86aef3f0f2 100644 --- a/src/dist/manifestation/tests.rs +++ b/src/dist/manifestation/tests.rs @@ -323,7 +323,7 @@ fn rename_component() { )]; change_channel_date(url, "nightly", "2016-02-01"); - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -332,12 +332,12 @@ fn rename_component() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists(prefix.path().join("bin/bonus"))); assert!(!utils::path_exists(prefix.path().join("bin/bobo"))); change_channel_date(url, "nightly", "2016-02-02"); - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -346,7 +346,7 @@ fn rename_component() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists(prefix.path().join("bin/bonus"))); assert!(!utils::path_exists(prefix.path().join("bin/bobo"))); @@ -388,7 +388,7 @@ fn rename_component_new() { )]; // Install the basics from day 1 change_channel_date(url, "nightly", "2016-02-01"); - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -397,7 +397,7 @@ fn rename_component_new() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); // Neither bonus nor bobo are installed at this point. assert!(!utils::path_exists(prefix.path().join("bin/bonus"))); @@ -405,7 +405,7 @@ fn rename_component_new() { // Now we move to day 2, where bobo is part of the set of things we want // to have installed change_channel_date(url, "nightly", "2016-02-02"); - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -414,7 +414,7 @@ fn rename_component_new() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); // As a result `bin/bonus` is present but not `bin/bobo` which we'd // expect since the bonus component installs `bin/bonus` regardless of @@ -430,7 +430,7 @@ fn rename_component_new() { // an upgrade then all the existing components will be upgraded. // FIXME: Unify this with dist::update_from_dist #[allow(clippy::too_many_arguments)] -fn update_from_dist( +async fn update_from_dist( dist_server: &Url, toolchain: &ToolchainDesc, prefix: &InstallPrefix, @@ -443,7 +443,7 @@ fn update_from_dist( // Download the dist manifest and place it into the installation prefix let manifest_url = make_manifest_url(dist_server, toolchain)?; let manifest_file = tmp_cx.new_file()?; - utils::download_file(&manifest_url, &manifest_file, None, &|_| {})?; + utils::download_file(&manifest_url, &manifest_file, None, &|_| {}).await?; let manifest_str = utils::read_file("manifest", &manifest_file)?; let manifest = Manifest::parse(&manifest_str)?; @@ -461,14 +461,16 @@ fn update_from_dist( remove_components: remove.to_owned(), }; - manifestation.update( - &manifest, - changes, - force, - download_cfg, - &toolchain.manifest_name(), - true, - ) + manifestation + .update( + &manifest, + changes, + force, + download_cfg, + &toolchain.manifest_name(), + true, + ) + .await } fn make_manifest_url(dist_server: &Url, toolchain: &ToolchainDesc) -> Result { @@ -574,7 +576,7 @@ fn initial_install(comps: Compressions) { prefix, download_cfg, tmp_cx| { - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -583,7 +585,7 @@ fn initial_install(comps: Compressions) { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists(prefix.path().join("bin/rustc"))); @@ -613,7 +615,7 @@ fn test_uninstall() { prefix, download_cfg, tmp_cx| { - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -622,7 +624,7 @@ fn test_uninstall() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); uninstall(toolchain, prefix, tmp_cx, &|_| ()).unwrap(); @@ -638,7 +640,7 @@ fn uninstall_removes_config_file() { prefix, download_cfg, tmp_cx| { - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -647,7 +649,7 @@ fn uninstall_removes_config_file() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists( prefix.manifest_file("multirust-config.toml") @@ -667,7 +669,7 @@ fn upgrade() { download_cfg, tmp_cx| { change_channel_date(url, "nightly", "2016-02-01"); - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -676,14 +678,14 @@ fn upgrade() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert_eq!( "2016-02-01", fs::read_to_string(prefix.path().join("bin/rustc")).unwrap() ); change_channel_date(url, "nightly", "2016-02-02"); - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -692,7 +694,7 @@ fn upgrade() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert_eq!( "2016-02-02", @@ -745,7 +747,7 @@ fn unavailable_component() { change_channel_date(url, "nightly", "2016-02-01"); // Update with bonus. - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -754,13 +756,13 @@ fn unavailable_component() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists(prefix.path().join("bin/bonus"))); change_channel_date(url, "nightly", "2016-02-02"); // Update without bonus, should fail. - let err = update_from_dist( + let err = utils::run_future(update_from_dist( url, toolchain, prefix, @@ -769,7 +771,7 @@ fn unavailable_component() { download_cfg, tmp_cx, false, - ) + )) .unwrap_err(); match err.downcast::() { Ok(RustupError::RequestedComponentsUnavailable { @@ -815,7 +817,7 @@ fn unavailable_component_from_profile() { &|url, toolchain, prefix, download_cfg, tmp_cx| { change_channel_date(url, "nightly", "2016-02-01"); // Update with rustc. - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -824,13 +826,13 @@ fn unavailable_component_from_profile() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists(prefix.path().join("bin/rustc"))); change_channel_date(url, "nightly", "2016-02-02"); // Update without rustc, should fail. - let err = update_from_dist( + let err = utils::run_future(update_from_dist( url, toolchain, prefix, @@ -839,7 +841,7 @@ fn unavailable_component_from_profile() { download_cfg, tmp_cx, false, - ) + )) .unwrap_err(); match err.downcast::() { Ok(RustupError::RequestedComponentsUnavailable { @@ -857,7 +859,17 @@ fn unavailable_component_from_profile() { _ => panic!(), } - update_from_dist(url, toolchain, prefix, &[], &[], download_cfg, tmp_cx, true).unwrap(); + utils::run_future(update_from_dist( + url, + toolchain, + prefix, + &[], + &[], + download_cfg, + tmp_cx, + true, + )) + .unwrap(); }, ); } @@ -894,7 +906,7 @@ fn removed_component() { // Update with bonus. change_channel_date(url, "nightly", "2016-02-01"); - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -903,13 +915,13 @@ fn removed_component() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists(prefix.path().join("bin/bonus"))); // Update without bonus, should fail with RequestedComponentsUnavailable change_channel_date(url, "nightly", "2016-02-02"); - let err = update_from_dist( + let err = utils::run_future(update_from_dist( url, toolchain, prefix, @@ -918,7 +930,7 @@ fn removed_component() { download_cfg, tmp_cx, false, - ) + )) .unwrap_err(); match err.downcast::() { Ok(RustupError::RequestedComponentsUnavailable { @@ -976,7 +988,7 @@ fn unavailable_components_is_target() { // Update with rust-std change_channel_date(url, "nightly", "2016-02-01"); - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -985,7 +997,7 @@ fn unavailable_components_is_target() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists( @@ -997,7 +1009,7 @@ fn unavailable_components_is_target() { // Update without rust-std change_channel_date(url, "nightly", "2016-02-02"); - let err = update_from_dist( + let err = utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1006,7 +1018,7 @@ fn unavailable_components_is_target() { download_cfg, tmp_cx, false, - ) + )) .unwrap_err(); match err.downcast::() { Ok(RustupError::RequestedComponentsUnavailable { @@ -1071,7 +1083,7 @@ fn unavailable_components_with_same_target() { &|url, toolchain, prefix, download_cfg, tmp_cx| { // Update with rust-std and rustc change_channel_date(url, "nightly", "2016-02-01"); - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1080,14 +1092,14 @@ fn unavailable_components_with_same_target() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists(prefix.path().join("bin/rustc"))); assert!(utils::path_exists(prefix.path().join("lib/libstd.rlib"))); // Update without rust-std and rustc change_channel_date(url, "nightly", "2016-02-02"); - let err = update_from_dist( + let err = utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1096,7 +1108,7 @@ fn unavailable_components_with_same_target() { download_cfg, tmp_cx, false, - ) + )) .unwrap_err(); match err.downcast::() { Ok(RustupError::RequestedComponentsUnavailable { @@ -1144,7 +1156,7 @@ fn update_preserves_extensions() { ]; change_channel_date(url, "nightly", "2016-02-01"); - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1153,7 +1165,7 @@ fn update_preserves_extensions() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists( @@ -1164,7 +1176,7 @@ fn update_preserves_extensions() { )); change_channel_date(url, "nightly", "2016-02-02"); - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1173,7 +1185,7 @@ fn update_preserves_extensions() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists( @@ -1192,7 +1204,7 @@ fn update_makes_no_changes_for_identical_manifest() { prefix, download_cfg, tmp_cx| { - let status = update_from_dist( + let status = utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1201,10 +1213,10 @@ fn update_makes_no_changes_for_identical_manifest() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert_eq!(status, UpdateStatus::Changed); - let status = update_from_dist( + let status = utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1213,7 +1225,7 @@ fn update_makes_no_changes_for_identical_manifest() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert_eq!(status, UpdateStatus::Unchanged); }); @@ -1239,7 +1251,7 @@ fn add_extensions_for_initial_install() { ), ]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1248,7 +1260,7 @@ fn add_extensions_for_initial_install() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists( prefix.path().join("lib/i686-apple-darwin/libstd.rlib") @@ -1266,7 +1278,7 @@ fn add_extensions_for_same_manifest() { prefix, download_cfg, tmp_cx| { - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1275,7 +1287,7 @@ fn add_extensions_for_same_manifest() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); let adds = vec![ @@ -1291,7 +1303,7 @@ fn add_extensions_for_same_manifest() { ), ]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1300,7 +1312,7 @@ fn add_extensions_for_same_manifest() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists( @@ -1321,7 +1333,7 @@ fn add_extensions_for_upgrade() { tmp_cx| { change_channel_date(url, "nightly", "2016-02-01"); - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1330,7 +1342,7 @@ fn add_extensions_for_upgrade() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); change_channel_date(url, "nightly", "2016-02-02"); @@ -1348,7 +1360,7 @@ fn add_extensions_for_upgrade() { ), ]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1357,7 +1369,7 @@ fn add_extensions_for_upgrade() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists( @@ -1383,7 +1395,7 @@ fn add_extension_not_in_manifest() { true, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1392,7 +1404,7 @@ fn add_extension_not_in_manifest() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); }); } @@ -1411,7 +1423,7 @@ fn add_extension_that_is_required_component() { false, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1420,7 +1432,7 @@ fn add_extension_that_is_required_component() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); }); } @@ -1439,17 +1451,17 @@ fn add_extensions_does_not_remove_other_components() { toolchain, prefix, download_cfg, - tmp_cx| { - update_from_dist( + temp_cx| { + utils::run_future(update_from_dist( url, toolchain, prefix, &[], &[], download_cfg, - tmp_cx, + temp_cx, false, - ) + )) .unwrap(); let adds = vec![Component::new( @@ -1458,16 +1470,16 @@ fn add_extensions_does_not_remove_other_components() { false, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, &adds, &[], download_cfg, - tmp_cx, + temp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists(prefix.path().join("bin/rustc"))); @@ -1489,7 +1501,7 @@ fn remove_extensions_for_initial_install() { false, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1498,7 +1510,7 @@ fn remove_extensions_for_initial_install() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); }); } @@ -1523,7 +1535,7 @@ fn remove_extensions_for_same_manifest() { ), ]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1532,7 +1544,7 @@ fn remove_extensions_for_same_manifest() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); let removes = vec![Component::new( @@ -1541,7 +1553,7 @@ fn remove_extensions_for_same_manifest() { false, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1550,7 +1562,7 @@ fn remove_extensions_for_same_manifest() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(!utils::path_exists( @@ -1584,7 +1596,7 @@ fn remove_extensions_for_upgrade() { ), ]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1593,7 +1605,7 @@ fn remove_extensions_for_upgrade() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); change_channel_date(url, "nightly", "2016-02-02"); @@ -1604,7 +1616,7 @@ fn remove_extensions_for_upgrade() { false, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1613,7 +1625,7 @@ fn remove_extensions_for_upgrade() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(!utils::path_exists( @@ -1635,7 +1647,7 @@ fn remove_extension_not_in_manifest() { tmp_cx| { change_channel_date(url, "nightly", "2016-02-01"); - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1644,7 +1656,7 @@ fn remove_extension_not_in_manifest() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); change_channel_date(url, "nightly", "2016-02-02"); @@ -1655,7 +1667,7 @@ fn remove_extension_not_in_manifest() { true, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1664,7 +1676,7 @@ fn remove_extension_not_in_manifest() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); }); } @@ -1700,7 +1712,7 @@ fn remove_extension_not_in_manifest_but_is_already_installed() { Some(TargetTriple::new("x86_64-apple-darwin")), true, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1709,7 +1721,7 @@ fn remove_extension_not_in_manifest_but_is_already_installed() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists(prefix.path().join("bin/bonus"))); @@ -1720,7 +1732,7 @@ fn remove_extension_not_in_manifest_but_is_already_installed() { Some(TargetTriple::new("x86_64-apple-darwin")), true, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1729,7 +1741,7 @@ fn remove_extension_not_in_manifest_but_is_already_installed() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); }, ); @@ -1742,17 +1754,17 @@ fn remove_extension_that_is_required_component() { toolchain, prefix, download_cfg, - tmp_cx| { - update_from_dist( + temp_cx| { + utils::run_future(update_from_dist( url, toolchain, prefix, &[], &[], download_cfg, - tmp_cx, + temp_cx, false, - ) + )) .unwrap(); let removes = vec![Component::new( @@ -1761,16 +1773,16 @@ fn remove_extension_that_is_required_component() { false, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, &[], &removes, download_cfg, - tmp_cx, + temp_cx, false, - ) + )) .unwrap(); }); } @@ -1782,17 +1794,17 @@ fn remove_extension_not_installed() { toolchain, prefix, download_cfg, - tmp_cx| { - update_from_dist( + temp_cx| { + utils::run_future(update_from_dist( url, toolchain, prefix, &[], &[], download_cfg, - tmp_cx, + temp_cx, false, - ) + )) .unwrap(); let removes = vec![Component::new( @@ -1801,16 +1813,16 @@ fn remove_extension_not_installed() { false, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, &[], &removes, download_cfg, - tmp_cx, + temp_cx, false, - ) + )) .unwrap(); }); } @@ -1832,7 +1844,7 @@ fn remove_extensions_does_not_remove_other_components() { false, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1841,7 +1853,7 @@ fn remove_extensions_does_not_remove_other_components() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); let removes = vec![Component::new( @@ -1850,7 +1862,7 @@ fn remove_extensions_does_not_remove_other_components() { false, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1859,7 +1871,7 @@ fn remove_extensions_does_not_remove_other_components() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists(prefix.path().join("bin/rustc"))); @@ -1881,7 +1893,7 @@ fn add_and_remove_for_upgrade() { false, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1890,7 +1902,7 @@ fn add_and_remove_for_upgrade() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); change_channel_date(url, "nightly", "2016-02-02"); @@ -1907,7 +1919,7 @@ fn add_and_remove_for_upgrade() { false, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1916,7 +1928,7 @@ fn add_and_remove_for_upgrade() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists( @@ -1941,7 +1953,7 @@ fn add_and_remove() { false, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1950,7 +1962,7 @@ fn add_and_remove() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); let adds = vec![Component::new( @@ -1965,7 +1977,7 @@ fn add_and_remove() { false, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -1974,7 +1986,7 @@ fn add_and_remove() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists( @@ -1992,17 +2004,17 @@ fn add_and_remove_same_component() { toolchain, prefix, download_cfg, - tmp_cx| { - update_from_dist( + temp_cx| { + utils::run_future(update_from_dist( url, toolchain, prefix, &[], &[], download_cfg, - tmp_cx, + temp_cx, false, - ) + )) .unwrap(); let adds = vec![Component::new( @@ -2017,16 +2029,16 @@ fn add_and_remove_same_component() { false, )]; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, &adds, &removes, download_cfg, - tmp_cx, + temp_cx, false, - ) + )) .expect_err("can't both add and remove components"); }); } @@ -2042,7 +2054,7 @@ fn bad_component_hash() { let path = path.join("dist/2016-02-02/rustc-nightly-x86_64-apple-darwin.tar.gz"); utils_raw::write_file(&path, "bogus").unwrap(); - let err = update_from_dist( + let err = utils::run_future(update_from_dist( url, toolchain, prefix, @@ -2051,7 +2063,7 @@ fn bad_component_hash() { download_cfg, tmp_cx, false, - ) + )) .unwrap_err(); match err.downcast::() { @@ -2072,7 +2084,7 @@ fn unable_to_download_component() { let path = path.join("dist/2016-02-02/rustc-nightly-x86_64-apple-darwin.tar.gz"); fs::remove_file(path).unwrap(); - let err = update_from_dist( + let err = utils::run_future(update_from_dist( url, toolchain, prefix, @@ -2081,7 +2093,7 @@ fn unable_to_download_component() { download_cfg, tmp_cx, false, - ) + )) .unwrap_err(); match err.downcast::() { @@ -2129,7 +2141,7 @@ fn reuse_downloaded_file() { }, }; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -2138,13 +2150,13 @@ fn reuse_downloaded_file() { &download_cfg, tmp_cx, false, - ) + )) .unwrap_err(); assert!(!reuse_notification_fired.get()); allow_installation(prefix); - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -2153,7 +2165,7 @@ fn reuse_downloaded_file() { &download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(reuse_notification_fired.get()); @@ -2196,7 +2208,7 @@ fn checks_files_hashes_before_reuse() { }, }; - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -2205,7 +2217,7 @@ fn checks_files_hashes_before_reuse() { &download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(noticed_bad_checksum.get()); @@ -2243,7 +2255,7 @@ fn handle_corrupt_partial_downloads() { ) .unwrap(); - update_from_dist( + utils::run_future(update_from_dist( url, toolchain, prefix, @@ -2252,7 +2264,7 @@ fn handle_corrupt_partial_downloads() { download_cfg, tmp_cx, false, - ) + )) .unwrap(); assert!(utils::path_exists(prefix.path().join("bin/rustc"))); diff --git a/src/install.rs b/src/install.rs index abd29b30bf..922570ede7 100644 --- a/src/install.rs +++ b/src/install.rs @@ -59,7 +59,7 @@ pub(crate) enum InstallMethod<'a> { impl<'a> InstallMethod<'a> { // Install a toolchain #[cfg_attr(feature = "otel", tracing::instrument(err, skip_all))] - pub(crate) fn install(&self) -> Result { + pub(crate) async fn install(&self) -> Result { let nh = self.cfg().notify_handler.clone(); match self { InstallMethod::Copy { .. } @@ -72,9 +72,11 @@ impl<'a> InstallMethod<'a> { } (self.cfg().notify_handler)(RootNotification::ToolchainDirectory(&self.dest_path())); - let updated = self.run(&self.dest_path(), &|n| { - (self.cfg().notify_handler)(n.into()) - })?; + let updated = self + .run(&self.dest_path(), &|n| { + (self.cfg().notify_handler)(n.into()) + }) + .await?; let status = match updated { false => { @@ -102,7 +104,7 @@ impl<'a> InstallMethod<'a> { } } - fn run(&self, path: &Path, notify_handler: &dyn Fn(Notification<'_>)) -> Result { + async fn run(&self, path: &Path, notify_handler: &dyn Fn(Notification<'_>)) -> Result { if path.exists() { // Don't uninstall first for Dist method match self { @@ -147,7 +149,8 @@ impl<'a> InstallMethod<'a> { old_date_version.as_ref().map(|dv| dv.0.as_str()), components, targets, - )?; + ) + .await?; if let Some(hash) = maybe_new_hash { if let Some(hash_file) = update_hash { diff --git a/src/test/mock/clitools.rs b/src/test/mock/clitools.rs index efa071e297..c43016afb8 100644 --- a/src/test/mock/clitools.rs +++ b/src/test/mock/clitools.rs @@ -16,6 +16,7 @@ use std::{ use enum_map::{enum_map, Enum, EnumMap}; use once_cell::sync::Lazy; +use tokio::runtime::Builder; use url::Url; use crate::cli::rustup_mode; @@ -722,7 +723,13 @@ impl Config { ); } let tp = currentprocess::TestProcess::new(&*self.workdir.borrow(), &arg_strings, vars, ""); - let process_res = currentprocess::with(tp.clone().into(), rustup_mode::main); + let mut builder = Builder::new_multi_thread(); + builder + .enable_all() + .worker_threads(2) + .max_blocking_threads(2); + let process_res = + currentprocess::with_runtime(tp.clone().into(), builder, rustup_mode::main()); // convert Err's into an ec let ec = match process_res { Ok(process_res) => process_res, diff --git a/src/toolchain/distributable.rs b/src/toolchain/distributable.rs index a3bee5e03c..faa58736fe 100644 --- a/src/toolchain/distributable.rs +++ b/src/toolchain/distributable.rs @@ -52,7 +52,7 @@ impl<'a> DistributableToolchain<'a> { &self.desc } - pub(crate) fn add_component(&self, mut component: Component) -> anyhow::Result<()> { + pub(crate) async fn add_component(&self, mut component: Component) -> anyhow::Result<()> { // TODO: take multiple components? let manifestation = self.get_manifestation()?; let manifest = self.get_manifest()?; @@ -110,14 +110,16 @@ impl<'a> DistributableToolchain<'a> { &|n: crate::dist::Notification<'_>| (self.cfg.notify_handler)(n.into()); let download_cfg = self.cfg.download_cfg(¬ify_handler); - manifestation.update( - &manifest, - changes, - false, - &download_cfg, - &self.desc.manifest_name(), - false, - )?; + manifestation + .update( + &manifest, + changes, + false, + &download_cfg, + &self.desc.manifest_name(), + false, + ) + .await?; Ok(()) } @@ -326,7 +328,7 @@ impl<'a> DistributableToolchain<'a> { } #[cfg_attr(feature = "otel", tracing::instrument(err, skip_all))] - pub(crate) fn install( + pub(crate) async fn install( cfg: &'a Cfg, desc: &'_ ToolchainDesc, components: &[&str], @@ -350,12 +352,13 @@ impl<'a> DistributableToolchain<'a> { components, targets, } - .install()?; + .install() + .await?; Ok((status, Self::new(cfg, desc.clone())?)) } #[cfg_attr(feature = "otel", tracing::instrument(err, skip_all))] - pub fn install_if_not_installed( + pub async fn install_if_not_installed( cfg: &'a Cfg, desc: &'a ToolchainDesc, ) -> anyhow::Result { @@ -364,23 +367,28 @@ impl<'a> DistributableToolchain<'a> { (cfg.notify_handler)(Notification::UsingExistingToolchain(desc)); Ok(UpdateStatus::Unchanged) } else { - Ok(Self::install(cfg, desc, &[], &[], cfg.get_profile()?, false)?.0) + Ok( + Self::install(cfg, desc, &[], &[], cfg.get_profile()?, false) + .await? + .0, + ) } } #[cfg_attr(feature = "otel", tracing::instrument(err, skip_all))] - pub(crate) fn update( + pub(crate) async fn update( &mut self, components: &[&str], targets: &[&str], profile: Profile, ) -> anyhow::Result { self.update_extra(components, targets, profile, true, false) + .await } /// Update a toolchain with control over the channel behaviour #[cfg_attr(feature = "otel", tracing::instrument(err, skip_all))] - pub(crate) fn update_extra( + pub(crate) async fn update_extra( &mut self, components: &[&str], targets: &[&str], @@ -422,6 +430,7 @@ impl<'a> DistributableToolchain<'a> { targets, } .install() + .await } pub fn recursion_error(&self, binary_lossy: String) -> Result { @@ -459,7 +468,7 @@ impl<'a> DistributableToolchain<'a> { } } - pub(crate) fn remove_component(&self, mut component: Component) -> anyhow::Result<()> { + pub(crate) async fn remove_component(&self, mut component: Component) -> anyhow::Result<()> { // TODO: take multiple components? let manifestation = self.get_manifestation()?; let config = manifestation.read_config()?.unwrap_or_default(); @@ -508,25 +517,29 @@ impl<'a> DistributableToolchain<'a> { &|n: crate::dist::Notification<'_>| (self.cfg.notify_handler)(n.into()); let download_cfg = self.cfg.download_cfg(¬ify_handler); - manifestation.update( - &manifest, - changes, - false, - &download_cfg, - &self.desc.manifest_name(), - false, - )?; + manifestation + .update( + &manifest, + changes, + false, + &download_cfg, + &self.desc.manifest_name(), + false, + ) + .await?; Ok(()) } - pub fn show_dist_version(&self) -> anyhow::Result> { + pub async fn show_dist_version(&self) -> anyhow::Result> { let update_hash = self.cfg.get_hash_file(&self.desc, false)?; let notify_handler = &|n: crate::dist::Notification<'_>| (self.cfg.notify_handler)(n.into()); let download_cfg = self.cfg.download_cfg(¬ify_handler); - match crate::dist::dist::dl_v2_manifest(download_cfg, Some(&update_hash), &self.desc)? { + match crate::dist::dist::dl_v2_manifest(download_cfg, Some(&update_hash), &self.desc) + .await? + { Some((manifest, _)) => Ok(Some(manifest.get_rust_version()?.to_string())), None => Ok(None), } diff --git a/src/utils/utils.rs b/src/utils/utils.rs index 6554a73816..ae04c0c8b2 100644 --- a/src/utils/utils.rs +++ b/src/utils/utils.rs @@ -1,5 +1,6 @@ use std::env; use std::fs::{self, File}; +use std::future::Future; use std::io::{self, BufReader, Write}; use std::path::{Path, PathBuf}; use std::process::ExitStatus; @@ -9,6 +10,8 @@ use home::env as home; use retry::delay::{jitter, Fibonacci}; use retry::{retry, OperationResult}; use sha2::Sha256; +use tokio::runtime::Handle; +use tokio::task; use url::Url; use crate::currentprocess::{ @@ -142,16 +145,16 @@ where }) } -pub fn download_file( +pub async fn download_file( url: &Url, path: &Path, hasher: Option<&mut Sha256>, notify_handler: &dyn Fn(Notification<'_>), ) -> Result<()> { - download_file_with_resume(url, path, hasher, false, ¬ify_handler) + download_file_with_resume(url, path, hasher, false, ¬ify_handler).await } -pub(crate) fn download_file_with_resume( +pub(crate) async fn download_file_with_resume( url: &Url, path: &Path, hasher: Option<&mut Sha256>, @@ -159,7 +162,7 @@ pub(crate) fn download_file_with_resume( notify_handler: &dyn Fn(Notification<'_>), ) -> Result<()> { use download::DownloadError as DEK; - match download_file_(url, path, hasher, resume_from_partial, notify_handler) { + match download_file_(url, path, hasher, resume_from_partial, notify_handler).await { Ok(_) => Ok(()), Err(e) => { if e.downcast_ref::().is_some() { @@ -189,7 +192,7 @@ pub(crate) fn download_file_with_resume( } } -fn download_file_( +async fn download_file_( url: &Url, path: &Path, hasher: Option<&mut Sha256>, @@ -257,13 +260,32 @@ fn download_file_( }; notify_handler(notification); let res = - download_to_path_with_backend(backend, url, path, resume_from_partial, Some(callback)); + download_to_path_with_backend(backend, url, path, resume_from_partial, Some(callback)) + .await; notify_handler(Notification::DownloadFinished); res } +/// Temporary thunk to support asyncifying from underneath. +pub(crate) fn run_future(f: F) -> Result +where + F: Future>, + E: std::convert::From, +{ + match Handle::try_current() { + Ok(current) => { + // hide the asyncness for now. + task::block_in_place(|| current.block_on(f)) + } + Err(_) => { + // Make a runtime to hide the asyncness. + tokio::runtime::Runtime::new()?.block_on(f) + } + } +} + pub(crate) fn parse_url(url: &str) -> Result { Url::parse(url).with_context(|| format!("failed to parse url: {url}")) }