Skip to content

Commit bacb8e8

Browse files
committed
Make the enum check work for negative discriminants
The discriminant check was not working correctly for negative numbers. This change fixes that by masking out the relevant bits correctly.
1 parent ad3b725 commit bacb8e8

File tree

3 files changed

+94
-4
lines changed

3 files changed

+94
-4
lines changed

compiler/rustc_mir_transform/src/check_enums.rs

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ enum EnumCheckType<'tcx> {
120120
},
121121
}
122122

123+
#[derive(Debug, Copy, Clone)]
123124
struct TyAndSize<'tcx> {
124125
pub ty: Ty<'tcx>,
125126
pub size: Size,
@@ -338,7 +339,7 @@ fn insert_direct_enum_check<'tcx>(
338339
let invalid_discr_block_data = BasicBlockData::new(None, false);
339340
let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
340341
let block_data = &mut basic_blocks[current_block];
341-
let discr = insert_discr_cast_to_u128(
342+
let discr_place = insert_discr_cast_to_u128(
342343
tcx,
343344
local_decls,
344345
block_data,
@@ -349,13 +350,35 @@ fn insert_direct_enum_check<'tcx>(
349350
source_info,
350351
);
351352

353+
// Mask out the bits of the discriminant type.
354+
let mask = discr.size.unsigned_int_max();
355+
let discr_masked =
356+
local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
357+
let rvalue = Rvalue::BinaryOp(
358+
BinOp::BitAnd,
359+
Box::new((
360+
Operand::Copy(discr_place),
361+
Operand::Constant(Box::new(ConstOperand {
362+
span: source_info.span,
363+
user_ty: None,
364+
const_: Const::Val(ConstValue::from_u128(mask), tcx.types.u128),
365+
})),
366+
)),
367+
);
368+
block_data.statements.push(Statement {
369+
source_info,
370+
kind: StatementKind::Assign(Box::new((discr_masked, rvalue))),
371+
});
372+
352373
// Branch based on the discriminant value.
353374
block_data.terminator = Some(Terminator {
354375
source_info,
355376
kind: TerminatorKind::SwitchInt {
356-
discr: Operand::Copy(discr),
377+
discr: Operand::Copy(discr_masked),
357378
targets: SwitchTargets::new(
358-
discriminants.into_iter().map(|discr| (discr, new_block)),
379+
discriminants
380+
.into_iter()
381+
.map(|discr_val| (discr.size.truncate(discr_val), new_block)),
359382
invalid_discr_block,
360383
),
361384
},
@@ -372,7 +395,7 @@ fn insert_direct_enum_check<'tcx>(
372395
})),
373396
expected: true,
374397
target: new_block,
375-
msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))),
398+
msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr_masked))),
376399
// This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
377400
// We never want to insert an unwind into unsafe code, because unwinding could
378401
// make a failing UB check turn into much worse UB when we start unwinding.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//@ run-fail
2+
//@ compile-flags: -C debug-assertions
3+
//@ error-pattern: trying to construct an enum from an invalid value 0xfd
4+
5+
#[allow(dead_code)]
6+
enum Foo {
7+
A = -2,
8+
B = -1,
9+
C = 1,
10+
}
11+
12+
fn main() {
13+
let _val: Foo = unsafe { std::mem::transmute::<i8, Foo>(-3) };
14+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//@ run-pass
2+
//@ compile-flags: -C debug-assertions
3+
4+
#[allow(dead_code)]
5+
#[derive(Debug, PartialEq)]
6+
enum Foo {
7+
A = -12121,
8+
B = -2,
9+
C = -1,
10+
D = 1,
11+
E = 2,
12+
F = 12121,
13+
}
14+
15+
#[allow(dead_code)]
16+
#[repr(i64)]
17+
#[derive(Debug, PartialEq)]
18+
enum Bar {
19+
A = i64::MIN,
20+
B = -2,
21+
C = -1,
22+
D = 1,
23+
E = 2,
24+
F = i64::MAX,
25+
}
26+
27+
fn main() {
28+
let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(-12121) };
29+
assert_eq!(val, Foo::A);
30+
let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(-2) };
31+
assert_eq!(val, Foo::B);
32+
let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(-1) };
33+
assert_eq!(val, Foo::C);
34+
let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(1) };
35+
assert_eq!(val, Foo::D);
36+
let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(2) };
37+
assert_eq!(val, Foo::E);
38+
let val: Foo = unsafe { std::mem::transmute::<i16, Foo>(12121) };
39+
assert_eq!(val, Foo::F);
40+
41+
let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(i64::MIN) };
42+
assert_eq!(val, Bar::A);
43+
let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(-2) };
44+
assert_eq!(val, Bar::B);
45+
let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(-1) };
46+
assert_eq!(val, Bar::C);
47+
let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(1) };
48+
assert_eq!(val, Bar::D);
49+
let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(2) };
50+
assert_eq!(val, Bar::E);
51+
let val: Bar = unsafe { std::mem::transmute::<i64, Bar>(i64::MAX) };
52+
assert_eq!(val, Bar::F);
53+
}

0 commit comments

Comments
 (0)