diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index 05d085fafe937..6054632c4f14a 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -47,6 +47,8 @@ struct SimplifyMatch<'tcx, 'a> { discr: &'a Operand<'tcx>, discr_local: Option, discr_ty: Ty<'tcx>, + /// Extra statements to emit after the unified statements (e.g., range assumes). + extra_stmts: Vec>, } impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> { @@ -206,6 +208,48 @@ impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> { } else { Rvalue::Cast(CastKind::IntToInt, operand, first_const.ty()) }; + + // Emit range assume so that subsequent passes (and LLVM) can + // eliminate bounds checks that depend on the cast result. + // We know the result is one of the constant values, so we can + // assert `dest <= max_value`. + let dest_ty = first_const.ty(); + if !dest_ty.is_signed() { + let max_val = consts + .iter() + .filter_map(|(_, c)| { + c.const_.try_eval_scalar_int(self.tcx, self.typing_env).map(|s| { + s.to_uint( + self.tcx + .layout_of(self.typing_env.as_query_input(dest_ty)) + .unwrap() + .size, + ) + }) + }) + .max() + .unwrap(); + let max_const = Operand::const_from_scalar( + self.tcx, + dest_ty, + rustc_const_eval::interpret::Scalar::from_uint( + max_val, + self.tcx.layout_of(self.typing_env.as_query_input(dest_ty)).unwrap().size, + ), + rustc_span::DUMMY_SP, + ); + let bool_local = self.patch.new_temp( + self.tcx.types.bool, + self.body.basic_blocks[self.switch_bb].terminator().source_info.span, + ); + let cmp = Rvalue::BinaryOp(BinOp::Le, Box::new((Operand::Copy(dest), max_const))); + self.extra_stmts + .push(StatementKind::Assign(Box::new((Place::from(bool_local), cmp)))); + self.extra_stmts.push(StatementKind::Intrinsic(Box::new( + NonDivergingIntrinsic::Assume(Operand::Move(Place::from(bool_local))), + ))); + } + Some(StatementKind::Assign(Box::new((dest, rval)))) } else { None @@ -381,6 +425,7 @@ fn simplify_match<'tcx>( discr, discr_local: None, discr_ty: discr.ty(body.local_decls(), tcx), + extra_stmts: Vec::new(), }; let reachable_cases: Vec<_> = targets.iter().filter(|&(_, bb)| !body.basic_blocks[bb].is_empty_unreachable()).collect(); @@ -432,6 +477,9 @@ fn simplify_match<'tcx>( for new_stmt in new_stmts { patch.add_statement(parent_end, new_stmt); } + for extra_stmt in simplify_match.extra_stmts { + patch.add_statement(parent_end, extra_stmt); + } if let Some(discr_local) = simplify_match.discr_local { patch.add_statement(parent_end, StatementKind::StorageDead(discr_local)); } diff --git a/tests/codegen-llvm/exhaustive-match-bounds-check-issue-149480.rs b/tests/codegen-llvm/exhaustive-match-bounds-check-issue-149480.rs new file mode 100644 index 0000000000000..b2e7fe7a886bd --- /dev/null +++ b/tests/codegen-llvm/exhaustive-match-bounds-check-issue-149480.rs @@ -0,0 +1,53 @@ +//@ compile-flags: -O +// Regression test for https://github.com/rust-lang/rust/issues/149480: +// the bounds check should be eliminated when indexing an array with +// the result of an exhaustive match over nested enums. The range +// assume emitted by MatchBranchSimplification after the IntToInt cast +// allows LLVM to prove the index is in-bounds. + +#![crate_type = "lib"] + +pub enum Foo { + A(A), + B(B), +} +pub enum A { + A0, + A1, + A2, +} +pub enum B { + B0, + B1, +} + +// CHECK-LABEL: @bar +#[no_mangle] +pub fn bar(foo: Foo, arr: &[u8; 5]) -> u8 { + let offset: usize = match foo { + Foo::A(A::A0) => 0, + Foo::A(A::A1) => 1, + Foo::A(A::A2) => 2, + Foo::B(B::B0) => 3, + Foo::B(B::B1) => 4, + }; + // The bounds check must be eliminated. + // CHECK-NOT: panic_bounds_check + // Positive check: the indexing must lower to a plain load from `arr`, + // so the test cannot pass accidentally if `bar` is optimized into + // another kind of panicking path or if `panic_bounds_check` is + // renamed. + // CHECK: load i8, ptr + // CHECK: ret i8 + arr[offset] +} + +// Sanity check: make sure `panic_bounds_check` is still the symbol LLVM +// emits for a non-elidable out-of-bounds index, so the `CHECK-NOT` above +// is guarding against something real. +// CHECK-LABEL: @test_check +#[no_mangle] +pub fn test_check(arr: &[u8], i: usize) -> u8 { + // CHECK: panic_bounds_check + arr[i] +} diff --git a/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff index cc7e863d13546..19295692c4f40 100644 --- a/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff @@ -6,6 +6,7 @@ let mut _0: u128; let mut _2: i128; + let mut _3: i128; ++ let mut _4: bool; bb0: { _2 = discriminant(_1); @@ -40,6 +41,8 @@ + StorageLive(_3); + _3 = move _2; + _0 = copy _3 as u128 (IntToInt); ++ _4 = Le(copy _0, const u128::MAX); ++ assume(move _4); + StorageDead(_3); return; } diff --git a/tests/mir-opt/matches_reduce_branches.match_sext_i8_u16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_sext_i8_u16.MatchBranchSimplification.diff index f273d5388350e..9aaeabbd29105 100644 --- a/tests/mir-opt/matches_reduce_branches.match_sext_i8_u16.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_sext_i8_u16.MatchBranchSimplification.diff @@ -6,6 +6,7 @@ let mut _0: u16; let mut _2: i8; + let mut _3: i8; ++ let mut _4: bool; bb0: { _2 = discriminant(_1); @@ -45,6 +46,8 @@ + StorageLive(_3); + _3 = move _2; + _0 = copy _3 as u16 (IntToInt); ++ _4 = Le(copy _0, const u16::MAX); ++ assume(move _4); + StorageDead(_3); return; } diff --git a/tests/mir-opt/matches_reduce_branches.match_trunc_i16_u8.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_trunc_i16_u8.MatchBranchSimplification.diff index c18719ebb55eb..4e32fa8b7ff36 100644 --- a/tests/mir-opt/matches_reduce_branches.match_trunc_i16_u8.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_trunc_i16_u8.MatchBranchSimplification.diff @@ -6,6 +6,7 @@ let mut _0: u8; let mut _2: i16; + let mut _3: i16; ++ let mut _4: bool; bb0: { _2 = discriminant(_1); @@ -70,6 +71,8 @@ + StorageLive(_3); + _3 = move _2; + _0 = copy _3 as u8 (IntToInt); ++ _4 = Le(copy _0, const u8::MAX); ++ assume(move _4); + StorageDead(_3); return; } diff --git a/tests/mir-opt/matches_reduce_branches.match_trunc_u16_u8.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_trunc_u16_u8.MatchBranchSimplification.diff index d4dafbd886fcb..8064b3175dd5d 100644 --- a/tests/mir-opt/matches_reduce_branches.match_trunc_u16_u8.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_trunc_u16_u8.MatchBranchSimplification.diff @@ -6,6 +6,7 @@ let mut _0: u8; let mut _2: u16; + let mut _3: u16; ++ let mut _4: bool; bb0: { _2 = discriminant(_1); @@ -60,6 +61,8 @@ + StorageLive(_3); + _3 = move _2; + _0 = copy _3 as u8 (IntToInt); ++ _4 = Le(copy _0, const u8::MAX); ++ assume(move _4); + StorageDead(_3); return; } diff --git a/tests/mir-opt/matches_reduce_branches.match_zext_u8_u16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_zext_u8_u16.MatchBranchSimplification.diff index a888d247275fd..3e7d56ae187ba 100644 --- a/tests/mir-opt/matches_reduce_branches.match_zext_u8_u16.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_zext_u8_u16.MatchBranchSimplification.diff @@ -6,6 +6,7 @@ let mut _0: u16; let mut _2: u8; + let mut _3: u8; ++ let mut _4: bool; bb0: { _2 = discriminant(_1); @@ -40,6 +41,8 @@ + StorageLive(_3); + _3 = move _2; + _0 = copy _3 as u16 (IntToInt); ++ _4 = Le(copy _0, const 255_u16); ++ assume(move _4); + StorageDead(_3); return; } diff --git a/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff b/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff index 99985b28382f5..0f9ec81b9d4cb 100644 --- a/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_u8.exhaustive_match.MatchBranchSimplification.diff @@ -6,6 +6,7 @@ let mut _0: u8; let mut _2: isize; + let mut _3: isize; ++ let mut _4: bool; bb0: { _2 = discriminant(_1); @@ -30,6 +31,8 @@ + StorageLive(_3); + _3 = move _2; + _0 = copy _3 as u8 (IntToInt); ++ _4 = Le(copy _0, const 1_u8); ++ assume(move _4); + StorageDead(_3); return; }