Skip to content

Expose discriminant values in stable_mir #141639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions compiler/rustc_smir/src/rustc_smir/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ use rustc_middle::ty::layout::{
};
use rustc_middle::ty::print::{with_forced_trimmed_paths, with_no_trimmed_paths};
use rustc_middle::ty::{
GenericPredicates, Instance, List, ScalarInt, TyCtxt, TypeVisitableExt, ValTree,
CoroutineArgsExt, GenericPredicates, Instance, List, ScalarInt, TyCtxt, TypeVisitableExt,
ValTree,
};
use rustc_middle::{mir, ty};
use rustc_span::def_id::LOCAL_CRATE;
Expand All @@ -22,9 +23,9 @@ use stable_mir::mir::mono::{InstanceDef, StaticDef};
use stable_mir::mir::{BinOp, Body, Place, UnOp};
use stable_mir::target::{MachineInfo, MachineSize};
use stable_mir::ty::{
AdtDef, AdtKind, Allocation, ClosureDef, ClosureKind, FieldDef, FnDef, ForeignDef,
ForeignItemKind, GenericArgs, IntrinsicDef, LineInfo, MirConst, PolyFnSig, RigidTy, Span, Ty,
TyConst, TyKind, UintTy, VariantDef,
AdtDef, AdtKind, Allocation, ClosureDef, ClosureKind, CoroutineDef, Discr, FieldDef, FnDef,
ForeignDef, ForeignItemKind, GenericArgs, IntrinsicDef, LineInfo, MirConst, PolyFnSig, RigidTy,
Span, Ty, TyConst, TyKind, UintTy, VariantDef, VariantIdx,
};
use stable_mir::{Crate, CrateDef, CrateItem, CrateNum, DefId, Error, Filename, ItemKind, Symbol};

Expand Down Expand Up @@ -440,6 +441,30 @@ impl<'tcx> SmirCtxt<'tcx> {
def.internal(&mut *tables, tcx).variants().len()
}

/// Discriminant for a given variant index of AdtDef
pub fn adt_discr_for_variant(&self, adt: AdtDef, variant: VariantIdx) -> Discr {
let mut tables = self.0.borrow_mut();
let tcx = tables.tcx;
let adt = adt.internal(&mut *tables, tcx);
let variant = variant.internal(&mut *tables, tcx);
adt.discriminant_for_variant(tcx, variant).stable(&mut *tables)
}

/// Discriminant for a given variand index and args of a coroutine
pub fn coroutine_discr_for_variant(
&self,
coroutine: CoroutineDef,
args: &GenericArgs,
variant: VariantIdx,
) -> Discr {
let mut tables = self.0.borrow_mut();
let tcx = tables.tcx;
let coroutine = coroutine.def_id().internal(&mut *tables, tcx);
let args = args.internal(&mut *tables, tcx);
let variant = variant.internal(&mut *tables, tcx);
args.as_coroutine().discriminant_for_variant(coroutine, tcx, variant).stable(&mut *tables)
}

/// The name of a variant.
pub fn variant_name(&self, def: VariantDef) -> Symbol {
let mut tables = self.0.borrow_mut();
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_smir/src/rustc_smir/convert/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -959,3 +959,11 @@ impl<'tcx> Stable<'tcx> for ty::ImplTraitInTraitData {
}
}
}

impl<'tcx> Stable<'tcx> for rustc_middle::ty::util::Discr<'tcx> {
type T = stable_mir::ty::Discr;

fn stable(&self, tables: &mut Tables<'_>) -> Self::T {
stable_mir::ty::Discr { val: self.val, ty: self.ty.stable(tables) }
}
}
23 changes: 19 additions & 4 deletions compiler/rustc_smir/src/stable_mir/compiler_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ use stable_mir::mir::mono::{Instance, InstanceDef, StaticDef};
use stable_mir::mir::{BinOp, Body, Place, UnOp};
use stable_mir::target::MachineInfo;
use stable_mir::ty::{
AdtDef, AdtKind, Allocation, ClosureDef, ClosureKind, FieldDef, FnDef, ForeignDef,
ForeignItemKind, ForeignModule, ForeignModuleDef, GenericArgs, GenericPredicates, Generics,
ImplDef, ImplTrait, IntrinsicDef, LineInfo, MirConst, PolyFnSig, RigidTy, Span, TraitDecl,
TraitDef, Ty, TyConst, TyConstId, TyKind, UintTy, VariantDef,
AdtDef, AdtKind, Allocation, ClosureDef, ClosureKind, CoroutineDef, Discr, FieldDef, FnDef,
ForeignDef, ForeignItemKind, ForeignModule, ForeignModuleDef, GenericArgs, GenericPredicates,
Generics, ImplDef, ImplTrait, IntrinsicDef, LineInfo, MirConst, PolyFnSig, RigidTy, Span,
TraitDecl, TraitDef, Ty, TyConst, TyConstId, TyKind, UintTy, VariantDef, VariantIdx,
};
use stable_mir::{
AssocItems, Crate, CrateItem, CrateItems, CrateNum, DefId, Error, Filename, ImplTraitDecls,
Expand Down Expand Up @@ -225,6 +225,21 @@ impl<'tcx> SmirInterface<'tcx> {
self.cx.adt_variants_len(def)
}

/// Discriminant for a given variant index of AdtDef
pub(crate) fn adt_discr_for_variant(&self, adt: AdtDef, variant: VariantIdx) -> Discr {
self.cx.adt_discr_for_variant(adt, variant)
}

/// Discriminant for a given variand index and args of a coroutine
pub(crate) fn coroutine_discr_for_variant(
&self,
coroutine: CoroutineDef,
args: &GenericArgs,
variant: VariantIdx,
) -> Discr {
self.cx.coroutine_discr_for_variant(coroutine, args, variant)
}

/// The name of a variant.
pub(crate) fn variant_name(&self, def: VariantDef) -> Symbol {
self.cx.variant_name(def)
Expand Down
15 changes: 15 additions & 0 deletions compiler/rustc_smir/src/stable_mir/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,12 @@ crate_def! {
pub CoroutineDef;
}

impl CoroutineDef {
pub fn discriminant_for_variant(&self, args: &GenericArgs, idx: VariantIdx) -> Discr {
with(|cx| cx.coroutine_discr_for_variant(*self, args, idx))
}
}

crate_def! {
#[derive(Serialize)]
pub CoroutineClosureDef;
Expand Down Expand Up @@ -818,6 +824,15 @@ impl AdtDef {
pub fn variant(&self, idx: VariantIdx) -> Option<VariantDef> {
(idx.to_index() < self.num_variants()).then_some(VariantDef { idx, adt_def: *self })
}

pub fn discriminant_for_variant(&self, idx: VariantIdx) -> Discr {
with(|cx| cx.adt_discr_for_variant(*self, idx))
}
}

pub struct Discr {
pub val: u128,
pub ty: Ty,
}

/// Definition of a variant, which can be either a struct / union field or an enum variant.
Expand Down
183 changes: 183 additions & 0 deletions tests/ui-fulldeps/stable-mir/check_variant.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
//@ run-pass
//! Test that users are able to use stable mir APIs to retrieve
//! discriminant value and type for AdtDef and Coroutine variants

//@ ignore-stage1
//@ ignore-cross-compile
//@ ignore-remote
//@ edition: 2024

#![feature(rustc_private)]
#![feature(assert_matches)]

extern crate rustc_middle;
#[macro_use]
extern crate rustc_smir;
extern crate rustc_driver;
extern crate rustc_interface;
extern crate stable_mir;

use std::io::Write;
use std::ops::ControlFlow;

use stable_mir::CrateItem;
use stable_mir::crate_def::CrateDef;
use stable_mir::mir::{AggregateKind, Rvalue, Statement, StatementKind};
use stable_mir::ty::{IntTy, RigidTy, Ty};

const CRATE_NAME: &str = "crate_variant_ty";

/// Test if we can retrieve discriminant info for different types.
fn test_def_tys() -> ControlFlow<()> {
check_adt_mono();
check_adt_poly();
check_adt_poly2();

ControlFlow::Continue(())
}

fn check_adt_mono() {
let mono = get_fn("mono").expect_body();

check_statement_is_aggregate_assign(
&mono.blocks[0].statements[0],
0,
RigidTy::Int(IntTy::Isize),
);
check_statement_is_aggregate_assign(
&mono.blocks[1].statements[0],
1,
RigidTy::Int(IntTy::Isize),
);
check_statement_is_aggregate_assign(
&mono.blocks[2].statements[0],
2,
RigidTy::Int(IntTy::Isize),
);
}

fn check_adt_poly() {
let poly = get_fn("poly").expect_body();

check_statement_is_aggregate_assign(
&poly.blocks[0].statements[0],
0,
RigidTy::Int(IntTy::Isize),
);
check_statement_is_aggregate_assign(
&poly.blocks[1].statements[0],
1,
RigidTy::Int(IntTy::Isize),
);
check_statement_is_aggregate_assign(
&poly.blocks[2].statements[0],
2,
RigidTy::Int(IntTy::Isize),
);
}

fn check_adt_poly2() {
let poly = get_fn("poly2").expect_body();

check_statement_is_aggregate_assign(
&poly.blocks[0].statements[0],
0,
RigidTy::Int(IntTy::Isize),
);
check_statement_is_aggregate_assign(
&poly.blocks[1].statements[0],
1,
RigidTy::Int(IntTy::Isize),
);
check_statement_is_aggregate_assign(
&poly.blocks[2].statements[0],
2,
RigidTy::Int(IntTy::Isize),
);
}

fn get_fn(name: &str) -> CrateItem {
stable_mir::all_local_items().into_iter().find(|it| it.name().eq(name)).unwrap()
}

fn check_statement_is_aggregate_assign(
statement: &Statement,
expected_discr_val: u128,
expected_discr_ty: RigidTy,
) {
if let Statement { kind: StatementKind::Assign(_, rvalue), .. } = statement
&& let Rvalue::Aggregate(aggregate, _) = rvalue
&& let AggregateKind::Adt(adt_def, variant_idx, ..) = aggregate
{
let discr = adt_def.discriminant_for_variant(*variant_idx);

assert_eq!(discr.val, expected_discr_val);
assert_eq!(discr.ty, Ty::from_rigid_kind(expected_discr_ty));
} else {
unreachable!("Unexpected statement");
}
}

/// This test will generate and analyze a dummy crate using the stable mir.
/// For that, it will first write the dummy crate into a file.
/// Then it will create a `StableMir` using custom arguments and then
/// it will run the compiler.
fn main() {
let path = "defs_ty_input.rs";
generate_input(&path).unwrap();
let args = &[
"rustc".to_string(),
"-Cpanic=abort".to_string(),
"--crate-name".to_string(),
CRATE_NAME.to_string(),
path.to_string(),
];
run!(args, test_def_tys).unwrap();
}

fn generate_input(path: &str) -> std::io::Result<()> {
let mut file = std::fs::File::create(path)?;
write!(
file,
r#"
use std::hint::black_box;

enum Mono {{
A,
B(i32),
C {{ a: i32, b: u32 }},
}}

enum Poly<T> {{
A,
B(T),
C {{ t: T }},
}}

pub fn main() {{
mono();
poly();
poly2::<i32>(1);
}}

fn mono() {{
black_box(Mono::A);
black_box(Mono::B(6));
black_box(Mono::C {{a: 1, b: 10 }});
}}

fn poly() {{
black_box(Poly::<i32>::A);
black_box(Poly::B(1i32));
black_box(Poly::C {{ t: 1i32 }});
}}

fn poly2<T: Copy>(t: T) {{
black_box(Poly::<T>::A);
black_box(Poly::B(t));
black_box(Poly::C {{ t: t }});
}}
"#
)?;
Ok(())
}
Loading