diff --git a/.sqlx/query-df57cc22f7d63847abce1d0d15675ba8951faa1be2ea6b2bf6714b1aa9127a6f.json b/.sqlx/query-df57cc22f7d63847abce1d0d15675ba8951faa1be2ea6b2bf6714b1aa9127a6f.json
new file mode 100644
index 00000000..b6fd2fc8
--- /dev/null
+++ b/.sqlx/query-df57cc22f7d63847abce1d0d15675ba8951faa1be2ea6b2bf6714b1aa9127a6f.json
@@ -0,0 +1,44 @@
+{
+  "db_name": "PostgreSQL",
+  "query": "-- we need to join tables from the pg_catalog since \"TRUNCATE\" triggers are \n-- not available in the information_schema.trigger table.\nselect \n  t.tgname as \"name!\",\n  c.relname as \"table_name!\",\n  p.proname as \"proc_name!\",\n  n.nspname as \"schema_name!\",\n  t.tgtype as \"details_bitmask!\"\nfrom \n  pg_catalog.pg_trigger t \n  left join pg_catalog.pg_proc p on t.tgfoid = p.oid\n  left join pg_catalog.pg_class c on t.tgrelid = c.oid\n  left join pg_catalog.pg_namespace n on c.relnamespace = n.oid\nwhere \n  -- triggers enforcing constraints (e.g. unique fields) should not be included.\n  t.tgisinternal = false and \n  t.tgconstraint = 0;\n",
+  "describe": {
+    "columns": [
+      {
+        "ordinal": 0,
+        "name": "name!",
+        "type_info": "Name"
+      },
+      {
+        "ordinal": 1,
+        "name": "table_name!",
+        "type_info": "Name"
+      },
+      {
+        "ordinal": 2,
+        "name": "proc_name!",
+        "type_info": "Name"
+      },
+      {
+        "ordinal": 3,
+        "name": "schema_name!",
+        "type_info": "Name"
+      },
+      {
+        "ordinal": 4,
+        "name": "details_bitmask!",
+        "type_info": "Int2"
+      }
+    ],
+    "parameters": {
+      "Left": []
+    },
+    "nullable": [
+      false,
+      true,
+      true,
+      true,
+      false
+    ]
+  },
+  "hash": "df57cc22f7d63847abce1d0d15675ba8951faa1be2ea6b2bf6714b1aa9127a6f"
+}
diff --git a/Cargo.lock b/Cargo.lock
index 55db1b6f..10f45b7c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2788,6 +2788,7 @@ dependencies = [
  "serde",
  "serde_json",
  "sqlx",
+ "strum",
  "tokio",
 ]
 
diff --git a/crates/pgt_schema_cache/Cargo.toml b/crates/pgt_schema_cache/Cargo.toml
index 291f80ca..c5fadb3e 100644
--- a/crates/pgt_schema_cache/Cargo.toml
+++ b/crates/pgt_schema_cache/Cargo.toml
@@ -20,6 +20,7 @@ pgt_diagnostics.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 sqlx.workspace = true
+strum = { workspace = true }
 tokio.workspace = true
 
 [dev-dependencies]
diff --git a/crates/pgt_schema_cache/src/columns.rs b/crates/pgt_schema_cache/src/columns.rs
index 6e2e2adf..de7c2d4a 100644
--- a/crates/pgt_schema_cache/src/columns.rs
+++ b/crates/pgt_schema_cache/src/columns.rs
@@ -37,7 +37,7 @@ impl From for ColumnClassKind {
     }
 }
 
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, PartialEq, Eq)]
 pub struct Column {
     pub name: String,
diff --git a/crates/pgt_schema_cache/src/functions.rs b/crates/pgt_schema_cache/src/functions.rs
index 36db011d..5e40709f 100644
--- a/crates/pgt_schema_cache/src/functions.rs
+++ b/crates/pgt_schema_cache/src/functions.rs
@@ -58,7 +58,7 @@ impl From> for FunctionArgs {
     }
 }
 
-#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[derive(Debug, Default, Serialize, Deserialize)]
 pub struct Function {
     /// The Id (`oid`).
     pub id: i64,
diff --git a/crates/pgt_schema_cache/src/lib.rs b/crates/pgt_schema_cache/src/lib.rs
index fc717fbe..d978a94b 100644
--- a/crates/pgt_schema_cache/src/lib.rs
+++ b/crates/pgt_schema_cache/src/lib.rs
@@ -8,6 +8,7 @@ mod policies;
 mod schema_cache;
 mod schemas;
 mod tables;
+mod triggers;
 mod types;
 mod versions;
 
@@ -16,3 +17,4 @@ pub use functions::{Behavior, Function, FunctionArg, FunctionArgs};
 pub use schema_cache::SchemaCache;
 pub use schemas::Schema;
 pub use tables::{ReplicaIdentity, Table};
+pub use triggers::{Trigger, TriggerAffected, TriggerEvent};
diff --git a/crates/pgt_schema_cache/src/policies.rs b/crates/pgt_schema_cache/src/policies.rs
index 46a3ab18..641dad12 100644
--- a/crates/pgt_schema_cache/src/policies.rs
+++ b/crates/pgt_schema_cache/src/policies.rs
@@ -54,7 +54,7 @@ impl From for Policy {
     }
 }
 
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, PartialEq, Eq)]
 pub struct Policy {
     name: String,
     table_name: String,
diff --git a/crates/pgt_schema_cache/src/queries/triggers.sql b/crates/pgt_schema_cache/src/queries/triggers.sql
new file mode 100644
index 00000000..c28cc39f
--- /dev/null
+++ b/crates/pgt_schema_cache/src/queries/triggers.sql
@@ -0,0 +1,17 @@
+-- we need to join tables from the pg_catalog since "TRUNCATE" triggers are
+-- not available in the information_schema.trigger table.
+select
+  t.tgname as "name!",
+  c.relname as "table_name!",
+  p.proname as "proc_name!",
+  n.nspname as "schema_name!",
+  t.tgtype as "details_bitmask!"
+from
+  pg_catalog.pg_trigger t
+  left join pg_catalog.pg_proc p on t.tgfoid = p.oid
+  left join pg_catalog.pg_class c on t.tgrelid = c.oid
+  left join pg_catalog.pg_namespace n on c.relnamespace = n.oid
+where
+  -- triggers enforcing constraints (e.g. unique fields) should not be included.
+  t.tgisinternal = false and
+  t.tgconstraint = 0;
diff --git a/crates/pgt_schema_cache/src/schema_cache.rs b/crates/pgt_schema_cache/src/schema_cache.rs
index 8a5c1a93..b21d2baf 100644
--- a/crates/pgt_schema_cache/src/schema_cache.rs
+++ b/crates/pgt_schema_cache/src/schema_cache.rs
@@ -1,5 +1,6 @@
 use sqlx::postgres::PgPool;
 
+use crate::Trigger;
 use crate::columns::Column;
 use crate::functions::Function;
 use crate::policies::Policy;
@@ -8,7 +9,7 @@ use crate::tables::Table;
 use crate::types::PostgresType;
 use crate::versions::Version;
 
-#[derive(Debug, Clone, Default)]
+#[derive(Debug, Default)]
 pub struct SchemaCache {
     pub schemas: Vec<Schema>,
     pub tables: Vec<Table>,
@@ -17,11 +18,12 @@ pub struct SchemaCache {
     pub versions: Vec<Version>,
     pub columns: Vec<Column>,
     pub policies: Vec<Policy>,
+    pub triggers: Vec<Trigger>,
 }
 
 impl SchemaCache {
     pub async fn load(pool: &PgPool) -> Result<SchemaCache, sqlx::Error> {
-        let (schemas, tables, functions, types, versions, columns, policies) = futures_util::try_join!(
+        let (schemas, tables, functions, types, versions, columns, policies, triggers) = futures_util::try_join!(
             Schema::load(pool),
             Table::load(pool),
             Function::load(pool),
@@ -29,6 +31,7 @@ impl SchemaCache {
             Version::load(pool),
             Column::load(pool),
             Policy::load(pool),
+            Trigger::load(pool),
         )?;
 
         Ok(SchemaCache {
@@ -39,6 +42,7 @@ impl SchemaCache {
             versions,
             columns,
             policies,
+            triggers,
         })
     }
 
diff --git a/crates/pgt_schema_cache/src/schemas.rs b/crates/pgt_schema_cache/src/schemas.rs
index 41747194..5a007e51 100644
--- a/crates/pgt_schema_cache/src/schemas.rs
+++ b/crates/pgt_schema_cache/src/schemas.rs
@@ -2,7 +2,7 @@ use sqlx::PgPool;
 
 use crate::schema_cache::SchemaCacheItem;
 
-#[derive(Debug, Clone, Default)]
+#[derive(Debug, Default)]
 pub struct Schema {
     pub id: i64,
     pub name: String,
diff --git a/crates/pgt_schema_cache/src/tables.rs b/crates/pgt_schema_cache/src/tables.rs
index ea889ca9..99061384 100644
--- a/crates/pgt_schema_cache/src/tables.rs
+++ b/crates/pgt_schema_cache/src/tables.rs
@@ -23,7 +23,7 @@ impl From for ReplicaIdentity {
     }
 }
 
-#[derive(Debug, Clone, Default, PartialEq, Eq)]
+#[derive(Debug, Default, PartialEq, Eq)]
 pub struct Table {
     pub id: i64,
     pub schema: String,
diff --git a/crates/pgt_schema_cache/src/triggers.rs b/crates/pgt_schema_cache/src/triggers.rs
new file mode 100644
index 00000000..0a5241d6
--- /dev/null
+++ b/crates/pgt_schema_cache/src/triggers.rs
@@ -0,0 +1,300 @@
+use crate::schema_cache::SchemaCacheItem;
+use strum::{EnumIter, IntoEnumIterator};
+
+#[derive(Debug, PartialEq, Eq)]
+pub enum TriggerAffected {
+    Row,
+    Statement,
+}
+
+impl From<i16> for TriggerAffected {
+    fn from(value: i16) -> Self {
+        let is_row = 0b0000_0001;
+        if value & is_row == is_row {
+            Self::Row
+        } else {
+            Self::Statement
+        }
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, EnumIter)]
+pub enum TriggerEvent {
+    Insert,
+    Delete,
+    Update,
+    Truncate,
+}
+
+struct TriggerEvents(Vec<TriggerEvent>);
+
+impl From<i16> for TriggerEvents {
+    fn from(value: i16) -> Self {
+        Self(
+            TriggerEvent::iter()
+                .filter(|variant| {
+                    #[rustfmt::skip]
+                    let mask = match variant {
+                        TriggerEvent::Insert   => 0b0000_0100,
+                        TriggerEvent::Delete   => 0b0000_1000,
+                        TriggerEvent::Update   => 0b0001_0000,
+                        TriggerEvent::Truncate => 0b0010_0000,
+                    };
+                    mask & value == mask
+                })
+                .collect(),
+        )
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, EnumIter)]
+pub enum TriggerTiming {
+    Before,
+    After,
+    Instead,
+}
+
+impl TryFrom<i16> for TriggerTiming {
+    type Error = ();
+    fn try_from(value: i16) -> Result<Self, Self::Error> {
+        TriggerTiming::iter()
+            .find(|variant| {
+                match variant {
+                    TriggerTiming::Instead => {
+                        let mask = 0b0100_0000;
+                        mask & value == mask
+                    }
+                    TriggerTiming::Before => {
+                        let mask = 0b0000_0010;
+                        mask & value == mask
+                    }
+                    TriggerTiming::After => {
+                        let mask = 0b1011_1101;
+                        // timing is "AFTER" if neither the INSTEAD nor the BEFORE bit is set.
+                        mask | value == mask
+                    }
+                }
+            })
+            .ok_or(())
+    }
+}
+
+pub struct TriggerQueried {
+    name: String,
+    table_name: String,
+    schema_name: String,
+    proc_name: String,
+    details_bitmask: i16,
+}
+
+#[derive(Debug, PartialEq, Eq)]
+pub struct Trigger {
+    name: String,
+    table_name: String,
+    schema_name: String,
+    proc_name: String,
+    affected: TriggerAffected,
+    timing: TriggerTiming,
+    events: Vec<TriggerEvent>,
+}
+
+impl From<TriggerQueried> for Trigger {
+    fn from(value: TriggerQueried) -> Self {
+        Self {
+            name: value.name,
+            table_name: value.table_name,
+            proc_name: value.proc_name,
+            schema_name: value.schema_name,
+            affected: value.details_bitmask.into(),
+            timing: value.details_bitmask.try_into().unwrap(),
+            events: TriggerEvents::from(value.details_bitmask).0,
+        }
+    }
+}
+
+impl SchemaCacheItem for Trigger {
+    type Item = Trigger;
+
+    async fn load(pool: &sqlx::PgPool) -> Result<Vec<Trigger>, sqlx::Error> {
+        let results = sqlx::query_file_as!(TriggerQueried, "src/queries/triggers.sql")
+            .fetch_all(pool)
+            .await?;
+
+        Ok(results.into_iter().map(|r| r.into()).collect())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use pgt_test_utils::test_database::get_new_test_db;
+    use sqlx::Executor;
+
+    use crate::{
+        SchemaCache,
+        triggers::{TriggerAffected, TriggerEvent, TriggerTiming},
+    };
+
+    #[tokio::test]
+    async fn loads_triggers() {
+        let test_db = get_new_test_db().await;
+
+        let setup = r#"
+            create table public.users (
+                id serial primary key,
+                name text
+            );
+
+            create or replace function public.log_user_insert()
+            returns trigger as $$
+            begin
+                -- dummy body
+                return new;
+            end;
+            $$ language plpgsql;
+
+            create trigger trg_users_insert
+            before insert on public.users
+            for each row
+            execute function public.log_user_insert();
+
+            create trigger trg_users_update
+            after update or insert on public.users
+            for each statement
+            execute function public.log_user_insert();
+
+            create trigger trg_users_delete
+            before delete on public.users
+            for each row
+            execute function public.log_user_insert();
+        "#;
+
+        test_db
+            .execute(setup)
+            .await
+            .expect("Failed to setup test database");
+
+        let cache = SchemaCache::load(&test_db)
+            .await
+            .expect("Failed to load Schema Cache");
+
+        let triggers: Vec<_> = cache
+            .triggers
+            .iter()
+            .filter(|t| t.table_name == "users")
+            .collect();
+        assert_eq!(triggers.len(), 3);
+
+        let insert_trigger = triggers
+            .iter()
+            .find(|t| t.name == "trg_users_insert")
+            .unwrap();
+        assert_eq!(insert_trigger.schema_name, "public");
+        assert_eq!(insert_trigger.table_name, "users");
+        assert_eq!(insert_trigger.timing, TriggerTiming::Before);
+        assert_eq!(insert_trigger.affected, TriggerAffected::Row);
+        assert!(insert_trigger.events.contains(&TriggerEvent::Insert));
+        assert_eq!(insert_trigger.proc_name, "log_user_insert");
+
+        let update_trigger = triggers
+            .iter()
+            .find(|t| t.name == "trg_users_update")
+            .unwrap();
+        assert_eq!(update_trigger.schema_name, "public");
+        assert_eq!(update_trigger.table_name, "users");
+        assert_eq!(update_trigger.timing, TriggerTiming::After);
+        assert_eq!(update_trigger.affected, TriggerAffected::Statement);
+        assert!(update_trigger.events.contains(&TriggerEvent::Update));
+        assert!(update_trigger.events.contains(&TriggerEvent::Insert));
+        assert_eq!(update_trigger.proc_name, "log_user_insert");
+
+        let delete_trigger = triggers
+            .iter()
+            .find(|t| t.name == "trg_users_delete")
+            .unwrap();
+        assert_eq!(delete_trigger.schema_name, "public");
+        assert_eq!(delete_trigger.table_name, "users");
+        assert_eq!(delete_trigger.timing, TriggerTiming::Before);
+        assert_eq!(delete_trigger.affected, TriggerAffected::Row);
+        assert!(delete_trigger.events.contains(&TriggerEvent::Delete));
+        assert_eq!(delete_trigger.proc_name, "log_user_insert");
+    }
+
+    #[tokio::test]
+    async fn loads_instead_and_truncate_triggers() {
+        let test_db = get_new_test_db().await;
+
+        let setup = r#"
+            create table public.docs (
+                id serial primary key,
+                content text
+            );
+
+            create view public.docs_view as
+            select * from public.docs;
+
+            create or replace function public.docs_instead_of_update()
+            returns trigger as $$
+            begin
+                -- dummy body
+                return new;
+            end;
+            $$ language plpgsql;
+
+            create trigger trg_docs_instead_update
+            instead of update on public.docs_view
+            for each row
+            execute function public.docs_instead_of_update();
+
+            create or replace function public.docs_truncate()
+            returns trigger as $$
+            begin
+                -- dummy body
+                return null;
+            end;
+            $$ language plpgsql;
+
+            create trigger trg_docs_truncate
+            after truncate on public.docs
+            for each statement
+            execute function public.docs_truncate();
+        "#;
+
+        test_db
+            .execute(setup)
+            .await
+            .expect("Failed to setup test database");
+
+        let cache = SchemaCache::load(&test_db)
+            .await
+            .expect("Failed to load Schema Cache");
+
+        let triggers: Vec<_> = cache
+            .triggers
+            .iter()
+            .filter(|t| t.table_name == "docs" || t.table_name == "docs_view")
+            .collect();
+        assert_eq!(triggers.len(), 2);
+
+        let instead_trigger = triggers
+            .iter()
+            .find(|t| t.name == "trg_docs_instead_update")
+            .unwrap();
+        assert_eq!(instead_trigger.schema_name, "public");
+        assert_eq!(instead_trigger.table_name, "docs_view");
+        assert_eq!(instead_trigger.timing, TriggerTiming::Instead);
+        assert_eq!(instead_trigger.affected, TriggerAffected::Row);
+        assert!(instead_trigger.events.contains(&TriggerEvent::Update));
+        assert_eq!(instead_trigger.proc_name, "docs_instead_of_update");
+
+        let truncate_trigger = triggers
+            .iter()
+            .find(|t| t.name == "trg_docs_truncate")
+            .unwrap();
+        assert_eq!(truncate_trigger.schema_name, "public");
+        assert_eq!(truncate_trigger.table_name, "docs");
+        assert_eq!(truncate_trigger.timing, TriggerTiming::After);
+        assert_eq!(truncate_trigger.affected, TriggerAffected::Statement);
+        assert!(truncate_trigger.events.contains(&TriggerEvent::Truncate));
+        assert_eq!(truncate_trigger.proc_name, "docs_truncate");
+    }
+}
diff --git a/crates/pgt_schema_cache/src/types.rs b/crates/pgt_schema_cache/src/types.rs
index 8b2d04bb..8df6b0cb 100644
--- a/crates/pgt_schema_cache/src/types.rs
+++ b/crates/pgt_schema_cache/src/types.rs
@@ -36,7 +36,7 @@ impl From> for Enums {
     }
 }
 
-#[derive(Debug, Clone, Default)]
+#[derive(Debug, Default)]
 pub struct PostgresType {
     pub id: i64,
     pub name: String,
diff --git a/crates/pgt_schema_cache/src/versions.rs b/crates/pgt_schema_cache/src/versions.rs
index cf2a140f..a4769c55 100644
--- a/crates/pgt_schema_cache/src/versions.rs
+++ b/crates/pgt_schema_cache/src/versions.rs
@@ -2,7 +2,7 @@ use sqlx::PgPool;
 
 use crate::schema_cache::SchemaCacheItem;
 
-#[derive(Debug, Clone, Default)]
+#[derive(Debug, Default)]
 pub struct Version {
     pub version: Option,
     pub version_num: Option,
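
Note on the `details_bitmask` decoding in `triggers.rs` above: the masks mirror the bit flags PostgreSQL packs into `pg_trigger.tgtype` (ROW, BEFORE, INSERT, DELETE, UPDATE, TRUNCATE, INSTEAD OF). The following standalone sketch is illustration only and is not part of the patch; it reuses the bit values from the code above, and everything else in it (the `main` function, variable names) is made up for the example.

// Illustrative sketch: decode a pg_trigger.tgtype value with the same masks
// used by TriggerAffected / TriggerTiming / TriggerEvents in triggers.rs.
fn main() {
    // "before insert ... for each row" => ROW (0b0000_0001) | BEFORE (0b0000_0010) | INSERT (0b0000_0100)
    let tgtype: i16 = 0b0000_0111;

    let for_each_row = tgtype & 0b0000_0001 != 0; // otherwise FOR EACH STATEMENT
    let before = tgtype & 0b0000_0010 != 0;       // BEFORE timing bit
    let instead_of = tgtype & 0b0100_0000 != 0;   // INSTEAD OF (view triggers)
    let after = !before && !instead_of;           // AFTER is the absence of both timing bits
    let on_insert = tgtype & 0b0000_0100 != 0;
    let on_truncate = tgtype & 0b0010_0000 != 0;

    assert!(for_each_row && before && on_insert);
    assert!(!after && !instead_of && !on_truncate);
}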