Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode, Obliga
use tracing::{debug, instrument, trace};

use crate::deref_separator::deref_finder;
use crate::patch::MirPatch;
use crate::{abort_unwinding_calls, errors, pass_manager as pm, simplify};

pub(super) struct StateTransform;
Expand Down Expand Up @@ -210,6 +211,10 @@ struct TransformVisitor<'tcx> {
old_yield_ty: Ty<'tcx>,

old_ret_ty: Ty<'tcx>,

body_span: Span,

patch: MirPatch<'tcx>,
}

impl<'tcx> TransformVisitor<'tcx> {
Expand Down Expand Up @@ -383,6 +388,36 @@ impl<'tcx> TransformVisitor<'tcx> {
(assign, temp)
}

// Create a temporary assignment if the visited destination and RHS now
// refer to overlapping fields of the coroutine state.
// https://github.com/rust-lang/rust/issues/149748
// trans
// _x.a = copy/move _x.b
// into
// _temp = copy/move _x.b
// _x.a = move _temp
fn split_overlapping_assignment_if_needed(
&mut self,
dst: Place<'tcx>,
dst_ty: Option<Ty<'tcx>>,
rvalue: &mut Rvalue<'tcx>,
location: Location,
) {
if let Rvalue::Use(operand) = rvalue
&& let Operand::Copy(src) | Operand::Move(src) = *operand
&& !src.is_indirect()
&& !dst.is_indirect()
&& src.local == dst.local
&& let Some(ty) = dst_ty
{
let temp = Place::from(self.patch.new_temp(ty, self.body_span));
let temp_assign_stmt =
StatementKind::Assign(Box::new((temp, Rvalue::Use(operand.clone()))));
self.patch.add_statement(location, temp_assign_stmt);
*operand = Operand::Move(temp);
}
}

/// Swaps all references of `old_local` and `new_local`.
#[tracing::instrument(level = "trace", skip(self, body))]
fn replace_local(&mut self, old_local: Local, new_local: Local, body: &mut Body<'tcx>) {
Expand Down Expand Up @@ -416,6 +451,18 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
}
}

fn visit_assign(
&mut self,
dst: &mut Place<'tcx>,
rvalue: &mut Rvalue<'tcx>,
location: Location,
) {
let dst_ty = self.remap.get(dst.local).and_then(|entry| entry.map(|(ty, _, _)| ty));
self.visit_place(dst, PlaceContext::MutatingUse(MutatingUseContext::Store), location);
self.visit_rvalue(rvalue, location);
self.split_overlapping_assignment_if_needed(*dst, dst_ty, rvalue, location);
}

#[tracing::instrument(level = "trace", skip(self, stmt), ret)]
fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, location: Location) {
// Remove StorageLive and StorageDead statements for remapped locals
Expand Down Expand Up @@ -1584,8 +1631,11 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
new_ret_local,
old_ret_ty,
old_yield_ty,
body_span: body.span,
patch: MirPatch::new(body),
};
transform.visit_body(body);
std::mem::replace(&mut transform.patch, MirPatch::new(body)).apply(body);

// Swap the actual `RETURN_PLACE` and the provisional `new_ret_local`.
transform.replace_local(RETURN_PLACE, new_ret_local, body);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>) -> Poll<()> {
debug _task_context => _2;
debug x => ((*_20).0: T);
debug x => ((*_21).0: T);
let mut _0: std::task::Poll<()>;
let _3: T;
let mut _4: impl std::future::Future<Output = ()>;
Expand All @@ -20,16 +20,17 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
let mut _16: std::pin::Pin<&mut impl std::future::Future<Output = ()>>;
let mut _17: isize;
let mut _18: ();
let mut _19: u32;
let mut _20: &mut {async fn body of a<T>()};
let mut _19: T;
let mut _20: u32;
let mut _21: &mut {async fn body of a<T>()};
scope 1 {
debug x => (((*_20) as variant#4).0: T);
debug x => (((*_21) as variant#4).0: T);
}

bb0: {
_20 = copy (_1.0: &mut {async fn body of a<T>()});
_19 = discriminant((*_20));
switchInt(move _19) -> [0: bb9, 3: bb12, 4: bb13, otherwise: bb14];
_21 = copy (_1.0: &mut {async fn body of a<T>()});
_20 = discriminant((*_21));
switchInt(move _20) -> [0: bb9, 3: bb12, 4: bb13, otherwise: bb14];
}

bb1: {
Expand All @@ -45,13 +46,13 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)

bb3: {
_0 = Poll::<()>::Pending;
discriminant((*_20)) = 4;
discriminant((*_21)) = 4;
return;
}

bb4: {
StorageLive(_16);
_15 = &mut (((*_20) as variant#4).1: impl std::future::Future<Output = ()>);
_15 = &mut (((*_21) as variant#4).1: impl std::future::Future<Output = ()>);
_16 = Pin::<&mut impl Future<Output = ()>>::new_unchecked(move _15) -> [return: bb7, unwind unreachable];
}

Expand Down Expand Up @@ -83,7 +84,7 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
}

bb11: {
drop(((*_20).0: T)) -> [return: bb10, unwind unreachable];
drop(((*_21).0: T)) -> [return: bb10, unwind unreachable];
}

bb12: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>) -> Poll<()> {
debug _task_context => _2;
debug x => ((*_20).0: T);
debug x => ((*_21).0: T);
let mut _0: std::task::Poll<()>;
let _3: T;
let mut _4: impl std::future::Future<Output = ()>;
Expand All @@ -20,16 +20,17 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
let mut _16: std::pin::Pin<&mut impl std::future::Future<Output = ()>>;
let mut _17: isize;
let mut _18: ();
let mut _19: u32;
let mut _20: &mut {async fn body of a<T>()};
let mut _19: T;
let mut _20: u32;
let mut _21: &mut {async fn body of a<T>()};
scope 1 {
debug x => (((*_20) as variant#4).0: T);
debug x => (((*_21) as variant#4).0: T);
}

bb0: {
_20 = copy (_1.0: &mut {async fn body of a<T>()});
_19 = discriminant((*_20));
switchInt(move _19) -> [0: bb12, 2: bb18, 3: bb16, 4: bb17, otherwise: bb19];
_21 = copy (_1.0: &mut {async fn body of a<T>()});
_20 = discriminant((*_21));
switchInt(move _20) -> [0: bb12, 2: bb18, 3: bb16, 4: bb17, otherwise: bb19];
}

bb1: {
Expand Down Expand Up @@ -59,13 +60,13 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)

bb6: {
_0 = Poll::<()>::Pending;
discriminant((*_20)) = 4;
discriminant((*_21)) = 4;
return;
}

bb7: {
StorageLive(_16);
_15 = &mut (((*_20) as variant#4).1: impl std::future::Future<Output = ()>);
_15 = &mut (((*_21) as variant#4).1: impl std::future::Future<Output = ()>);
_16 = Pin::<&mut impl Future<Output = ()>>::new_unchecked(move _15) -> [return: bb10, unwind: bb15];
}

Expand Down Expand Up @@ -97,11 +98,11 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
}

bb14: {
drop(((*_20).0: T)) -> [return: bb13, unwind: bb4];
drop(((*_21).0: T)) -> [return: bb13, unwind: bb4];
}

bb15 (cleanup): {
discriminant((*_20)) = 2;
discriminant((*_21)) = 2;
resume;
}

Expand Down
Loading
Loading