diff --git a/xtask/src/check_raw.rs b/xtask/src/check_raw.rs index 1702fe04d..b6887afc6 100644 --- a/xtask/src/check_raw.rs +++ b/xtask/src/check_raw.rs @@ -16,8 +16,8 @@ use syn::spanned::Spanned; use syn::token::Comma; use syn::{ parenthesized, Abi, Attribute, Field, Fields, FieldsNamed, FieldsUnnamed, File, Item, - ItemConst, ItemMacro, ItemStruct, ItemType, LitInt, ReturnType, Type, TypeArray, TypeBareFn, - TypePtr, Visibility, + ItemConst, ItemMacro, ItemStruct, ItemType, ItemUnion, LitInt, ReturnType, Type, TypeArray, + TypeBareFn, TypePtr, Visibility, }; use walkdir::WalkDir; @@ -253,26 +253,46 @@ fn check_fields(fields: &Punctuated, src: &Path) -> Result<(), Err Ok(()) } +fn check_type_attrs(attrs: &[Attribute], spanned: &dyn Spanned, src: &Path) -> Result<(), Error> { + let attrs = parse_attrs(attrs, src)?; + let reprs = get_reprs(&attrs); + + let allowed_reprs: &[&[Repr]] = &[&[Repr::C], &[Repr::C, Repr::Packed], &[Repr::Transparent]]; + + if allowed_reprs.contains(&reprs.as_slice()) { + Ok(()) + } else { + Err(Error::new(ErrorKind::ForbiddenRepr, src, spanned)) + } +} + /// Validate a struct. fn check_struct(item: &ItemStruct, src: &Path) -> Result<(), Error> { if !is_pub(&item.vis) { return Err(Error::new(ErrorKind::MissingPub, src, &item.struct_token)); } - let attrs = parse_attrs(&item.attrs, src)?; - match &item.fields { Fields::Named(FieldsNamed { named, .. }) => check_fields(named, src)?, Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => check_fields(unnamed, src)?, Fields::Unit => {} } - let reprs = get_reprs(&attrs); - let allowed_reprs: &[&[Repr]] = &[&[Repr::C], &[Repr::C, Repr::Packed], &[Repr::Transparent]]; - if !allowed_reprs.contains(&reprs.as_slice()) { - return Err(Error::new(ErrorKind::ForbiddenRepr, src, item)); + check_type_attrs(&item.attrs, item, src)?; + + Ok(()) +} + +/// Validate a union. +fn check_union(item: &ItemUnion, src: &Path) -> Result<(), Error> { + if !is_pub(&item.vis) { + return Err(Error::new(ErrorKind::MissingPub, src, &item.union_token)); } + check_fields(&item.fields.named, src)?; + + check_type_attrs(&item.attrs, item, src)?; + Ok(()) } @@ -319,6 +339,9 @@ fn check_item(item: &Item, src: &Path) -> Result<(), Error> { Item::Struct(item) => { check_struct(item, src)?; } + Item::Union(item) => { + check_union(item, src)?; + } Item::Macro(item) => { check_macro(item, src)?; } @@ -555,4 +578,52 @@ mod tests { ErrorKind::ForbiddenType, ); } + + #[test] + fn test_union() { + // Valid union. + assert!(check_union( + &parse_quote! { + #[repr(C)] + pub union U { + pub a: u32, + pub b: u64, + } + }, + src(), + ) + .is_ok()); + + // Missing `pub` on union. + check_item_err( + parse_quote! { + #[repr(C)] + struct U { + pub f: u32, + } + }, + ErrorKind::MissingPub, + ); + + // Missing `pub` on field. + check_item_err( + parse_quote! { + #[repr(C)] + pub struct U { + f: u32, + } + }, + ErrorKind::MissingPub, + ); + + // Forbidden `repr`. + check_item_err( + parse_quote! { + pub struct S { + pub f: u32, + } + }, + ErrorKind::ForbiddenRepr, + ); + } }