diff --git a/crates/cairo-lang-semantic/src/expr/inference/canonic.rs b/crates/cairo-lang-semantic/src/expr/inference/canonic.rs index bdf4b627339..55a13c74b94 100644 --- a/crates/cairo-lang-semantic/src/expr/inference/canonic.rs +++ b/crates/cairo-lang-semantic/src/expr/inference/canonic.rs @@ -535,6 +535,9 @@ impl<'db> SemanticRewriter, MapperError> for Mapper<'db, '_> { let TypeLongId::Var(var) = value else { return value.default_rewrite(self); }; + if var.inference_id != self.mapping.source_inference_id { + return value.default_rewrite(self); + } let id = self .mapping .type_var_mapping @@ -553,6 +556,9 @@ impl<'db> SemanticRewriter, MapperError> for Mapper<'db, '_> { let ConstValue::Var(var, ty) = value else { return value.default_rewrite(self); }; + if var.inference_id != self.mapping.source_inference_id { + return value.default_rewrite(self); + } let id = self .mapping .const_var_mapping @@ -582,6 +588,9 @@ impl<'db> SemanticRewriter, MapperError> for Mapper<'db, '_> { return value.default_rewrite(self); }; let var = var_id.long(self.get_db()); + if var.inference_id != self.mapping.source_inference_id { + return value.default_rewrite(self); + } let id = self .mapping .impl_var_mapping @@ -614,6 +623,9 @@ impl<'db> SemanticRewriter, MapperError> for Mapper<'db, return value.default_rewrite(self); }; let var = var_id.long(self.get_db()); + if var.inference_id != self.mapping.source_inference_id { + return value.default_rewrite(self); + } let id = self .mapping .negative_impl_var_mapping @@ -627,3 +639,83 @@ impl<'db> SemanticRewriter, MapperError> for Mapper<'db, Ok(RewriteResult::Modified) } } + +#[cfg(test)] +mod tests { + use cairo_lang_utils::Intern; + use salsa::Database; + + use super::{Mapper, VarMapping}; + use crate::expr::inference::{ + InferenceData, InferenceId, LocalConstVarId, LocalTypeVarId, TypeVar, + }; + use crate::items::constant::ConstValue; + use crate::test_utils::SemanticDatabaseForTesting; + use crate::TypeLongId; + + fn overlapping_type_vars_from_inference<'db>( + db: &'db dyn Database, + ) -> (TypeVar<'db>, TypeVar<'db>) { + let mut source_inference_data = InferenceData::new(InferenceId::NoContext); + let mut source_inference = source_inference_data.inference(db); + let source_var = source_inference.new_type_var_raw(None); + + let mut foreign_inference_data = InferenceData::new(InferenceId::Canonical); + let mut foreign_inference = foreign_inference_data.inference(db); + let foreign_var = foreign_inference.new_type_var_raw(None); + + assert_eq!(source_var.id, foreign_var.id); + (source_var, foreign_var) + } + + #[test] + fn mapper_maps_source_type_var_created_by_inference() { + let db = SemanticDatabaseForTesting::default(); + let db: &dyn Database = &db; + let (source_var, _) = overlapping_type_vars_from_inference(db); + let mapped_var_id = LocalTypeVarId(source_var.id.0 + 7); + let original = TypeLongId::Var(source_var); + let mut mapping = VarMapping::new_to_canonic(source_var.inference_id); + mapping.type_var_mapping.insert(source_var.id, mapped_var_id); + + let mapped = Mapper::map(db, original, &mapping).unwrap(); + let TypeLongId::Var(mapped_var) = mapped else { panic!("expected type var after map") }; + assert_eq!(mapped_var.inference_id, InferenceId::Canonical); + assert_eq!(mapped_var.id, mapped_var_id); + } + + #[test] + fn mapper_ignores_foreign_type_var_with_overlapping_local_id() { + let db = SemanticDatabaseForTesting::default(); + let db: &dyn Database = &db; + let (source_var, foreign_var) = overlapping_type_vars_from_inference(db); + let mut mapping = VarMapping::new_to_canonic(source_var.inference_id); + mapping.type_var_mapping.insert(source_var.id, LocalTypeVarId(source_var.id.0 + 7)); + + let original = TypeLongId::Var(foreign_var); + let mapped = Mapper::map(db, original.clone(), &mapping).unwrap(); + assert_eq!(mapped, original); + } + + #[test] + fn mapper_ignores_foreign_const_var_with_overlapping_local_id() { + let db = SemanticDatabaseForTesting::default(); + let db: &dyn Database = &db; + let mut source_inference_data = InferenceData::new(InferenceId::NoContext); + let mut source_inference = source_inference_data.inference(db); + let source_var = source_inference.new_const_var_raw(None); + + let mut foreign_inference_data = InferenceData::new(InferenceId::Canonical); + let mut foreign_inference = foreign_inference_data.inference(db); + let foreign_var = foreign_inference.new_const_var_raw(None); + assert_eq!(source_var.id, foreign_var.id); + + let mut mapping = VarMapping::new_to_canonic(source_var.inference_id); + mapping.const_var_mapping.insert(source_var.id, LocalConstVarId(source_var.id.0 + 3)); + + let ty = TypeLongId::Tuple(vec![]).intern(db); + let original = ConstValue::Var(foreign_var, ty); + let mapped = Mapper::map(db, original.clone(), &mapping).unwrap(); + assert_eq!(mapped, original); + } +}