Skip to content

Commit 3b5fe7c

Browse files
tests etc
1 parent 3426fcf commit 3b5fe7c

File tree

1 file changed

+107
-8
lines changed

1 file changed

+107
-8
lines changed

crates/pgt_schema_cache/src/triggers.rs

Lines changed: 107 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,21 @@ impl TryFrom<i16> for TriggerTiming {
5959
fn try_from(value: i16) -> Result<Self, ()> {
6060
TriggerTiming::iter()
6161
.find(|variant| {
62-
#[rustfmt::skip]
63-
let mask = match variant {
64-
TriggerTiming::Instead => 0b0100_0000,
65-
TriggerTiming::Before => 0b0000_0010,
66-
TriggerTiming::After => 0b0000_0000, // before/after share same bit
67-
};
68-
mask & value == mask
62+
match variant {
63+
TriggerTiming::Instead => {
64+
let mask = 0b0100_0000;
65+
mask & value == mask
66+
}
67+
TriggerTiming::Before => {
68+
let mask = 0b0000_0010;
69+
mask & value == mask
70+
}
71+
TriggerTiming::After => {
72+
let mask = 0b1011_1101;
73+
// timing is "AFTER" if neither INSTEAD nor BEFORE bit are set.
74+
mask | value == mask
75+
}
76+
}
6977
})
7078
.ok_or(())
7179
}
@@ -84,6 +92,7 @@ pub struct Trigger {
8492
name: String,
8593
table_name: String,
8694
schema_name: String,
95+
proc_name: String,
8796
affected: TriggerAffected,
8897
timing: TriggerTiming,
8998
events: Vec<TriggerEvent>,
@@ -94,6 +103,7 @@ impl From<TriggerQueried> for Trigger {
94103
Self {
95104
name: value.name,
96105
table_name: value.table_name,
106+
proc_name: value.proc_name,
97107
schema_name: value.schema_name,
98108
affected: value.details_bitmask.into(),
99109
timing: value.details_bitmask.try_into().unwrap(),
@@ -148,7 +158,7 @@ mod tests {
148158
execute function public.log_user_insert();
149159
150160
create trigger trg_users_update
151-
after update on public.users
161+
after update or insert on public.users
152162
for each statement
153163
execute function public.log_user_insert();
154164
@@ -178,24 +188,113 @@ mod tests {
178188
.iter()
179189
.find(|t| t.name == "trg_users_insert")
180190
.unwrap();
191+
assert_eq!(insert_trigger.schema_name, "public");
192+
assert_eq!(insert_trigger.table_name, "users");
181193
assert_eq!(insert_trigger.timing, TriggerTiming::Before);
182194
assert_eq!(insert_trigger.affected, TriggerAffected::Row);
183195
assert!(insert_trigger.events.contains(&TriggerEvent::Insert));
196+
assert_eq!(insert_trigger.proc_name, "log_user_insert");
184197

185198
let update_trigger = triggers
186199
.iter()
187200
.find(|t| t.name == "trg_users_update")
188201
.unwrap();
202+
assert_eq!(insert_trigger.schema_name, "public");
203+
assert_eq!(insert_trigger.table_name, "users");
189204
assert_eq!(update_trigger.timing, TriggerTiming::After);
190205
assert_eq!(update_trigger.affected, TriggerAffected::Statement);
191206
assert!(update_trigger.events.contains(&TriggerEvent::Update));
207+
assert!(update_trigger.events.contains(&TriggerEvent::Insert));
208+
assert_eq!(update_trigger.proc_name, "log_user_insert");
192209

193210
let delete_trigger = triggers
194211
.iter()
195212
.find(|t| t.name == "trg_users_delete")
196213
.unwrap();
214+
assert_eq!(insert_trigger.schema_name, "public");
215+
assert_eq!(insert_trigger.table_name, "users");
197216
assert_eq!(delete_trigger.timing, TriggerTiming::Before);
198217
assert_eq!(delete_trigger.affected, TriggerAffected::Row);
199218
assert!(delete_trigger.events.contains(&TriggerEvent::Delete));
219+
assert_eq!(delete_trigger.proc_name, "log_user_insert");
220+
}
221+
222+
#[tokio::test]
223+
async fn loads_instead_and_truncate_triggers() {
224+
let test_db = get_new_test_db().await;
225+
226+
let setup = r#"
227+
create table public.docs (
228+
id serial primary key,
229+
content text
230+
);
231+
232+
create view public.docs_view as
233+
select * from public.docs;
234+
235+
create or replace function public.docs_instead_of_update()
236+
returns trigger as $$
237+
begin
238+
-- dummy body
239+
return new;
240+
end;
241+
$$ language plpgsql;
242+
243+
create trigger trg_docs_instead_update
244+
instead of update on public.docs_view
245+
for each row
246+
execute function public.docs_instead_of_update();
247+
248+
create or replace function public.docs_truncate()
249+
returns trigger as $$
250+
begin
251+
-- dummy body
252+
return null;
253+
end;
254+
$$ language plpgsql;
255+
256+
create trigger trg_docs_truncate
257+
after truncate on public.docs
258+
for each statement
259+
execute function public.docs_truncate();
260+
"#;
261+
262+
test_db
263+
.execute(setup)
264+
.await
265+
.expect("Failed to setup test database");
266+
267+
let cache = SchemaCache::load(&test_db)
268+
.await
269+
.expect("Failed to load Schema Cache");
270+
271+
let triggers: Vec<_> = cache
272+
.triggers
273+
.iter()
274+
.filter(|t| t.table_name == "docs" || t.table_name == "docs_view")
275+
.collect();
276+
assert_eq!(triggers.len(), 2);
277+
278+
let instead_trigger = triggers
279+
.iter()
280+
.find(|t| t.name == "trg_docs_instead_update")
281+
.unwrap();
282+
assert_eq!(instead_trigger.schema_name, "public");
283+
assert_eq!(instead_trigger.table_name, "docs_view");
284+
assert_eq!(instead_trigger.timing, TriggerTiming::Instead);
285+
assert_eq!(instead_trigger.affected, TriggerAffected::Row);
286+
assert!(instead_trigger.events.contains(&TriggerEvent::Update));
287+
assert_eq!(instead_trigger.proc_name, "docs_instead_of_update");
288+
289+
let truncate_trigger = triggers
290+
.iter()
291+
.find(|t| t.name == "trg_docs_truncate")
292+
.unwrap();
293+
assert_eq!(truncate_trigger.schema_name, "public");
294+
assert_eq!(truncate_trigger.table_name, "docs");
295+
assert_eq!(truncate_trigger.timing, TriggerTiming::After);
296+
assert_eq!(truncate_trigger.affected, TriggerAffected::Statement);
297+
assert!(truncate_trigger.events.contains(&TriggerEvent::Truncate));
298+
assert_eq!(truncate_trigger.proc_name, "docs_truncate");
200299
}
201300
}

0 commit comments

Comments
 (0)