diff --git a/src/background_jobs.rs b/src/background_jobs.rs
index 9a5fc44a30c..5611a5e906d 100644
--- a/src/background_jobs.rs
+++ b/src/background_jobs.rs
@@ -1,9 +1,10 @@
 use diesel::prelude::*;
+use diesel::r2d2::{ConnectionManager, PooledConnection};
 use reqwest::blocking::Client;
 use std::panic::AssertUnwindSafe;
 use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
 
-use crate::db::DieselPool;
+use crate::db::ConnectionPool;
 use crate::swirl::errors::EnqueueError;
 use crate::swirl::PerformError;
 use crate::uploaders::Uploader;
@@ -23,6 +24,20 @@ pub enum Job {
     UpdateDownloads,
 }
 
+/// Database state that is passed to `Job::perform()`.
+pub(crate) struct PerformState<'a> {
+    /// The existing connection used to lock the background job.
+    ///
+    /// Most jobs can reuse the existing connection; however, it will already be within a
+    /// transaction, so it is not appropriate in all cases.
+    pub(crate) conn: &'a PgConnection,
+    /// A connection pool for obtaining a unique connection.
+    ///
+    /// This will be `None` within our standard test framework, as everything there is expected
+    /// to run within a single transaction.
+    pub(crate) pool: Option<ConnectionPool>,
+}
+
 impl Job {
     const DAILY_DB_MAINTENANCE: &str = "daily_db_maintenance";
     const DUMP_DB: &str = "dump_db";
@@ -94,38 +109,53 @@ impl Job {
     pub(super) fn perform(
         self,
         env: &Option<Environment>,
-        conn: &DieselPool,
+        state: PerformState<'_>,
     ) -> Result<(), PerformError> {
+        let PerformState { conn, pool } = state;
         let env = env
             .as_ref()
             .expect("Application should configure a background runner environment");
         match self {
-            Job::DailyDbMaintenance => conn.with_connection(&worker::perform_daily_db_maintenance),
+            Job::DailyDbMaintenance => {
+                worker::perform_daily_db_maintenance(&*fresh_connection(pool)?)
+            }
             Job::DumpDb(args) => worker::perform_dump_db(env, args.database_url, args.target_name),
-            Job::IndexAddCrate(args) => conn
-                .with_connection(&|conn| worker::perform_index_add_crate(env, conn, &args.krate)),
+            Job::IndexAddCrate(args) => worker::perform_index_add_crate(env, conn, &args.krate),
             Job::IndexSquash => worker::perform_index_squash(env),
             Job::IndexSyncToHttp(args) => worker::perform_index_sync_to_http(env, args.crate_name),
-            Job::IndexUpdateYanked(args) => conn.with_connection(&|conn| {
+            Job::IndexUpdateYanked(args) => {
                 worker::perform_index_update_yanked(env, conn, &args.krate, &args.version_num)
-            }),
+            }
             Job::NormalizeIndex(args) => worker::perform_normalize_index(env, args),
-            Job::RenderAndUploadReadme(args) => conn.with_connection(&|conn| {
-                worker::perform_render_and_upload_readme(
-                    conn,
-                    env,
-                    args.version_id,
-                    &args.text,
-                    &args.readme_path,
-                    args.base_url.as_deref(),
-                    args.pkg_path_in_vcs.as_deref(),
-                )
-            }),
-            Job::UpdateDownloads => conn.with_connection(&worker::perform_update_downloads),
+            Job::RenderAndUploadReadme(args) => worker::perform_render_and_upload_readme(
+                conn,
+                env,
+                args.version_id,
+                &args.text,
+                &args.readme_path,
+                args.base_url.as_deref(),
+                args.pkg_path_in_vcs.as_deref(),
+            ),
+            Job::UpdateDownloads => worker::perform_update_downloads(&*fresh_connection(pool)?),
         }
     }
 }
+
+/// A helper function for jobs needing a fresh connection (i.e. not already within a transaction).
+///
+/// This will error when run from our main test framework, as most work there is expected to be
+/// done within an existing transaction.
+fn fresh_connection(
+    pool: Option<ConnectionPool>,
+) -> Result<PooledConnection<ConnectionManager<PgConnection>>, PerformError> {
+    let Some(pool) = pool else {
+        // In production a pool should be available. This can only be hit in tests, which don't
+        // provide the pool.
+        return Err(String::from("Database pool was unavailable").into());
+    };
+
+    Ok(pool.get()?)
+}
+
 #[derive(Serialize, Deserialize)]
 pub struct DumpDbJob {
     pub(super) database_url: String,
diff --git a/src/db.rs b/src/db.rs
index 9bd8760dce2..d040b786826 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -1,38 +1,28 @@
 use diesel::prelude::*;
 use diesel::r2d2::{self, ConnectionManager, CustomizeConnection};
-use parking_lot::{ReentrantMutex, ReentrantMutexGuard};
 use prometheus::Histogram;
-use std::sync::Arc;
-use std::{error::Error, ops::Deref, time::Duration};
+use std::sync::{Arc, Mutex, MutexGuard};
+use std::{ops::Deref, time::Duration};
 use thiserror::Error;
 use url::Url;
 
 use crate::config;
 
+pub type ConnectionPool = r2d2::Pool<ConnectionManager<PgConnection>>;
+
 #[derive(Clone)]
 pub enum DieselPool {
     Pool {
-        pool: r2d2::Pool<ConnectionManager<PgConnection>>,
+        pool: ConnectionPool,
         time_to_obtain_connection_metric: Histogram,
     },
     BackgroundJobPool {
-        pool: r2d2::Pool<ConnectionManager<PgConnection>>,
+        pool: ConnectionPool,
     },
-    Test(Arc<ReentrantMutex<PgConnection>>),
+    Test(Arc<Mutex<PgConnection>>),
 }
 
-type Callback<'a> = &'a dyn Fn(&PgConnection) -> Result<(), Box<dyn Error>>;
-
 impl DieselPool {
-    pub(crate) fn with_connection(&self, f: Callback<'_>) -> Result<(), Box<dyn Error>> {
-        match self {
-            DieselPool::Pool { pool, .. } | DieselPool::BackgroundJobPool { pool } => {
-                f(&*pool.get()?)
-            }
-            DieselPool::Test(connection) => f(&connection.lock()),
-        }
-    }
-
     pub(crate) fn new(
         url: &str,
         config: &config::DatabasePools,
@@ -69,12 +59,19 @@ impl DieselPool {
         Self::BackgroundJobPool { pool }
     }
 
+    pub(crate) fn to_real_pool(&self) -> Option<ConnectionPool> {
+        match self {
+            Self::Pool { pool, .. } | Self::BackgroundJobPool { pool } => Some(pool.clone()),
+            _ => None,
+        }
+    }
+
     pub(crate) fn new_test(config: &config::DatabasePools, url: &str) -> DieselPool {
         let conn = PgConnection::establish(&connection_url(config, url))
             .expect("failed to establish connection");
         conn.begin_test_transaction()
             .expect("failed to begin test transaction");
-        DieselPool::Test(Arc::new(ReentrantMutex::new(conn)))
+        DieselPool::Test(Arc::new(Mutex::new(conn)))
     }
 
     pub fn get(&self) -> Result<DieselPooledConn<'_>, PoolError> {
@@ -92,7 +89,7 @@ impl DieselPool {
                 }
             }),
             DieselPool::BackgroundJobPool { pool } => Ok(DieselPooledConn::Pool(pool.get()?)),
-            DieselPool::Test(conn) => Ok(DieselPooledConn::Test(conn.lock())),
+            DieselPool::Test(conn) => Ok(DieselPooledConn::Test(conn.try_lock().unwrap())),
         }
     }
 
@@ -136,9 +133,10 @@ pub struct PoolState {
     pub idle_connections: u32,
 }
 
+#[allow(clippy::large_enum_variant)]
 pub enum DieselPooledConn<'a> {
     Pool(r2d2::PooledConnection<ConnectionManager<PgConnection>>),
-    Test(ReentrantMutexGuard<'a, PgConnection>),
+    Test(MutexGuard<'a, PgConnection>),
 }
 
 impl Deref for DieselPooledConn<'_> {
diff --git a/src/swirl/runner.rs b/src/swirl/runner.rs
index d23c60314f7..a027013f609 100644
--- a/src/swirl/runner.rs
+++ b/src/swirl/runner.rs
@@ -1,3 +1,4 @@
+use diesel::connection::{AnsiTransactionManager, TransactionManager};
 use diesel::prelude::*;
 use diesel::r2d2;
 use diesel::r2d2::ConnectionManager;
@@ -11,7 +12,7 @@ use threadpool::ThreadPool;
 
 use super::errors::*;
 use super::storage;
-use crate::background_jobs::{Environment, Job};
+use crate::background_jobs::{Environment, Job, PerformState};
 use crate::db::{DieselPool, DieselPooledConn};
 use event::Event;
@@ -55,9 +56,7 @@ impl Runner {
             job_start_timeout: Duration::from_secs(10),
         }
     }
-}
 
-impl Runner {
     pub fn test_runner(environment: Environment, connection_pool: DieselPool) -> Self {
         Self {
             connection_pool,
@@ -66,17 +65,7 @@ impl Runner {
             job_start_timeout: Duration::from_secs(5),
         }
     }
-}
-
-impl Runner {
-    #[doc(hidden)]
-    /// For use in integration tests
-    pub(super) fn connection_pool(&self) -> &DieselPool {
-        &self.connection_pool
-    }
-}
 
-impl Runner {
     /// Runs all pending jobs in the queue
     ///
     /// This function will return once all jobs in the queue have begun running,
@@ -120,20 +109,18 @@ impl Runner {
 
     fn run_single_job(&self, sender: SyncSender<Event>) {
         let environment = self.environment.clone();
-        // FIXME: https://github.com/sfackler/r2d2/pull/70
-        let connection_pool = AssertUnwindSafe(self.connection_pool().clone());
-        self.get_single_job(sender, move |job| {
+        self.get_single_job(sender, move |job, state| {
             let job = Job::from_value(&job.job_type, job.data)?;
-
-            // Make sure to move the whole `AssertUnwindSafe`
-            let connection_pool = connection_pool;
-            job.perform(&environment, &connection_pool.0)
+            job.perform(&environment, state)
         })
     }
 
     fn get_single_job<F>(&self, sender: SyncSender<Event>, f: F)
     where
-        F: FnOnce(storage::BackgroundJob) -> Result<(), PerformError> + Send + UnwindSafe + 'static,
+        F: FnOnce(storage::BackgroundJob, PerformState<'_>) -> Result<(), PerformError>
+            + Send
+            + UnwindSafe
+            + 'static,
     {
         use diesel::result::Error::RollbackTransaction;
 
@@ -166,9 +153,53 @@ impl Runner {
         };
         let job_id = job.id;
 
-        let result = catch_unwind(|| f(job))
-            .map_err(|e| try_to_extract_panic_info(&e))
-            .and_then(|r| r);
+        let transaction_manager = conn.transaction_manager();
+        let initial_depth = <AnsiTransactionManager as TransactionManager<PgConnection>>::get_transaction_depth(
+            transaction_manager
+        );
+        if initial_depth != 1 {
+            warn!("Initial transaction depth is not 1. This is very unexpected");
+        }
+
+        let result = conn
+            .transaction(|| {
+                let pool = pool.to_real_pool();
+                let state = AssertUnwindSafe(PerformState { conn, pool });
+                catch_unwind(|| {
+                    // Ensure the whole `AssertUnwindSafe(_)` is moved
+                    let state = state;
+                    f(job, state.0)
+                })
+                .map_err(|e| try_to_extract_panic_info(&e))
+            })
+            // TODO: Replace with flatten() once that stabilizes
+            .and_then(std::convert::identity);
+
+        loop {
+            let depth = <AnsiTransactionManager as TransactionManager<PgConnection>>::get_transaction_depth(
+                transaction_manager
+            );
+            if depth == initial_depth {
+                break;
+            }
+            warn!("Rolling back a transaction due to a panic in a background task");
+            match transaction_manager.rollback_transaction(conn) {
+                Ok(_) => (),
+                Err(e) => {
+                    error!("Leaking a thread and database connection because of an error while rolling back transaction: {e}");
+                    loop {
+                        std::thread::sleep(Duration::from_secs(24 * 60 * 60));
+                        error!("How am I still alive?");
+                    }
+                }
+            }
+        }
 
         match result {
             Ok(_) => storage::delete_successful_job(conn, job_id)?,
@@ -269,7 +300,7 @@ mod tests {
         let return_barrier = Arc::new(AssertUnwindSafe(Barrier::new(2)));
         let return_barrier2 = return_barrier.clone();
 
-        runner.get_single_job(dummy_sender(), move |job| {
+        runner.get_single_job(dummy_sender(), move |job, _| {
             fetch_barrier.0.wait(); // Tell thread 2 it can lock its job
             assert_eq!(first_job_id, job.id);
             return_barrier.0.wait(); // Wait for thread 2 to lock its job
@@ -277,7 +308,7 @@
         });
 
         fetch_barrier2.0.wait(); // Wait until thread 1 locks its job
-        runner.get_single_job(dummy_sender(), move |job| {
+        runner.get_single_job(dummy_sender(), move |job, _| {
             assert_eq!(second_job_id, job.id);
             return_barrier2.0.wait(); // Tell thread 1 it can unlock its job
             Ok(())
@@ -293,7 +324,7 @@
         let runner = runner();
         create_dummy_job(&runner);
 
-        runner.get_single_job(dummy_sender(), |_| Ok(()));
+        runner.get_single_job(dummy_sender(), |_, _| Ok(()));
         runner.wait_for_jobs().unwrap();
 
         let remaining_jobs = background_jobs
@@ -311,10 +342,12 @@ mod tests {
         let barrier = Arc::new(AssertUnwindSafe(Barrier::new(2)));
         let barrier2 = barrier.clone();
 
-        runner.get_single_job(dummy_sender(), move |_| {
-            barrier.0.wait();
-            // error so the job goes back into the queue
-            Err("nope".into())
+        runner.get_single_job(dummy_sender(), move |_, state| {
+            state.conn.transaction(|| {
+                barrier.0.wait();
+                // The job should go back into the queue after a panic
+                panic!();
+            })
         });
 
         let conn = &*runner.connection().unwrap();
@@ -350,7 +383,7 @@ mod tests {
         let runner = runner();
         let job_id = create_dummy_job(&runner).id;
 
-        runner.get_single_job(dummy_sender(), |_| panic!());
+        runner.get_single_job(dummy_sender(), |_, _| panic!());
         runner.wait_for_jobs().unwrap();
 
         let tries = background_jobs
diff --git a/src/tests/all.rs b/src/tests/all.rs
index 489e0168740..d1f42d5d88e 100644
--- a/src/tests/all.rs
+++ b/src/tests/all.rs
@@ -162,16 +162,11 @@ fn new_category<'a>(category: &'a str, slug: &'a str, description: &'a str) -> N
 // This reflects the configuration of our test environment. In the production application, this
 // does not hold true.
 #[test]
-fn multiple_live_references_to_the_same_connection_can_be_checked_out() {
-    use std::ptr;
-
+#[should_panic]
+fn recursive_get_of_db_conn_in_tests_will_panic() {
     let (app, _) = TestApp::init().empty();
     let app = app.as_inner();
 
-    let conn1 = app.primary_database.get().unwrap();
-    let conn2 = app.primary_database.get().unwrap();
-    let conn1_ref: &PgConnection = &conn1;
-    let conn2_ref: &PgConnection = &conn2;
-
-    assert!(ptr::eq(conn1_ref, conn2_ref));
+    let _conn1 = app.primary_database.get().unwrap();
+    let _conn2 = app.primary_database.get().unwrap();
 }
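
Note for reviewers: below is a minimal sketch of how a job body interacts with the new `PerformState` API. The `perform_example` function and its queries are hypothetical (not part of this diff), and the sketch assumes it lives in `src/background_jobs.rs`, where the private `fresh_connection` helper and the diesel prelude are in scope.

    // Hypothetical example, not part of this diff: a job body under the new API.
    fn perform_example(state: PerformState<'_>) -> Result<(), PerformError> {
        // `state.conn` is the connection that locked the `background_jobs` row.
        // It is already inside a transaction, so anything executed on it is
        // rolled back automatically if the job returns an error or panics.
        diesel::sql_query("SELECT 1").execute(state.conn)?;

        // Statements that must not run inside that transaction (e.g. `VACUUM`,
        // which Postgres rejects inside a transaction block) check a fresh
        // connection out of the pool instead. In tests `state.pool` is `None`,
        // so this errors rather than escaping the test transaction.
        let fresh = fresh_connection(state.pool)?;
        diesel::sql_query("SELECT 1").execute(&*fresh)?;

        Ok(())
    }

This split is the point of the change: `UpdateDownloads` and `DailyDbMaintenance` take the `fresh_connection` path, while the index and readme jobs keep reusing the locking connection.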