diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..c9e7713 --- /dev/null +++ b/.envrc @@ -0,0 +1,4 @@ +use flake + +# Add scripts directory to PATH for convenient access to minic wrapper +PATH_add "$(pwd)/scripts" diff --git a/scripts/minic b/scripts/minic new file mode 100755 index 0000000..4933564 --- /dev/null +++ b/scripts/minic @@ -0,0 +1,2 @@ +#!/bin/bash +exec "$(dirname "$0")/../target/debug/mini_c" "$@" diff --git a/src/codegen/tac_code_gen.rs b/src/codegen/tac_code_gen.rs index 2f5e6ee..c969094 100644 --- a/src/codegen/tac_code_gen.rs +++ b/src/codegen/tac_code_gen.rs @@ -93,6 +93,44 @@ pub fn translate_statement(statement: CheckedStmt, env: &mut Environment) -> Vec instructions.push(Instruction::Label(label_end_if)); instructions }, + Statement::Switch { target, cases, default } => { + let target_ty = target.ty.clone(); + let (target_addr, mut instructions) = translate_expression(*target, env); + let label_default = env.new_label(); + let label_end = env.new_label(); + + let case_labels: Vec = (0..cases.len()) + .map(|_| env.new_label()) + .collect(); + + for (i, (lit, _)) in cases.iter().enumerate() { + let lit_addr = Address::Constant(lit.clone(), target_ty.clone()); + instructions.push(Instruction::ConditionalJMPRelational( + Operator::EQ, + target_addr.clone(), + lit_addr, + case_labels[i].clone(), + )); + } + + instructions.push(Instruction::JMP(label_default.clone())); + + for (i, (_, body)) in cases.into_iter().enumerate() { + instructions.push(Instruction::Label(case_labels[i].clone())); + for stmt in body { + instructions.extend(translate_statement(stmt, env)); + } + instructions.push(Instruction::JMP(label_end.clone())); + } + + instructions.push(Instruction::Label(label_default)); + for stmt in default { + instructions.extend(translate_statement(stmt, env)); + } + + instructions.push(Instruction::Label(label_end)); + instructions + }, _ => todo!() } } diff --git a/src/interpreter/exec_stmt.rs b/src/interpreter/exec_stmt.rs index ceeda2c..ffc8615 100644 --- a/src/interpreter/exec_stmt.rs +++ b/src/interpreter/exec_stmt.rs @@ -34,7 +34,7 @@ //! This gives MiniC correct lexical block scoping without a scope stack. use crate::environment::Environment; -use crate::ir::ast::{CheckedExpr, CheckedStmt, Expr, Statement}; +use crate::ir::ast::{CheckedExpr, CheckedStmt, Expr, Literal, Statement}; use super::eval_expr::{eval_call, eval_expr}; use super::value::{RuntimeError, Value}; @@ -111,6 +111,53 @@ pub fn exec_stmt(stmt: &CheckedStmt, env: &mut Environment) -> ExecResult } } }, + + // --- Switch --- + Statement::Switch { target, cases, default } => { + // Detecta labels duplicadas como erro + let mut seen_cases = Vec::new(); + for (lit, _) in cases { + if seen_cases.contains(lit) { + return Err(RuntimeError::new(format!( + "duplicate case label in switch: {:?}", + lit + ))); + } + seen_cases.push(lit.clone()); + } + + // Avalia o valor da expressão alvo + let target_val = eval_expr(target, env)?; + + // Determina qual bloco de comandos executar (Padrão: default) + let mut stmts_to_exec = default; + + for (lit, stmts) in cases { + let case_matches = match (lit, &target_val) { + (Literal::Int(i), Value::Int(v)) => *i == *v, + (Literal::Bool(b1), Value::Bool(b2)) => *b1 == *b2, + _ => false, + }; + + if case_matches { + stmts_to_exec = stmts; // Substitui o bloco pelo case correspondente + break; // Sai do loop após encontrar o case correspondente + } + } + + // Executa o bloco escolhido isolando o escopo do switch + let outer_keys = env.names(); + + for stmt in stmts_to_exec { + if let Some(ret) = exec_stmt(stmt, env)? { + env.remove_new(&outer_keys); // Limpa o escopo se houver um 'return' no meio + return Ok(Some(ret)); + } + } + + env.remove_new(&outer_keys); // Limpa o escopo no final da execução normal + Ok(None) + }, // --- Return --- Statement::Return(Some(expr)) => { diff --git a/src/ir/ast.rs b/src/ir/ast.rs index fc400dd..f431930 100644 --- a/src/ir/ast.rs +++ b/src/ir/ast.rs @@ -9,7 +9,7 @@ //! * [`Literal`] — a constant value written directly in source code. //! * [`Expr`] / [`ExprD`] — expressions (arithmetic, comparisons, calls, …). //! * [`Statement`] / [`StatementD`] — statements (declarations, assignments, -//! `if`, `while`, `return`, blocks). +//! `if`, `while`, `switch`, `return`, blocks). //! * [`FunDecl`] — a single function declaration with its body. //! * [`Program`] — the top-level container: a list of function declarations. //! @@ -151,6 +151,11 @@ pub enum Statement { cond: Box>, body: Box>, }, + Switch { + target: Box>, + cases: Vec<(Literal, Vec>)>, + default: Vec>, + }, /// Return statement: `return [expr]`. Return(Option>>), } diff --git a/src/parser/statements.rs b/src/parser/statements.rs index 9dcfef5..4785479 100644 --- a/src/parser/statements.rs +++ b/src/parser/statements.rs @@ -5,7 +5,7 @@ //! Exposes two public functions: //! //! * [`statement`] — the top-level entry point; tries each statement form in -//! order: `return`, `if`, `while`, call-statement, block, declaration, +//! order: `return`, `if`, `while`, `switch`, call-statement, block, declaration, //! assignment. //! * [`assignment`] — parses `lvalue = expression ;`; exported separately //! because the test suite uses it directly. @@ -13,15 +13,16 @@ //! # Grammar //! //! ```text -//! statement := block | if_stmt | while_stmt | simple ';' +//! statement := block | if_stmt | while_stmt | switch_stmt | simple ';' //! block := '{' statement* '}' //! if_stmt := 'if' expr block ['else' block] //! while_stmt := 'while' expr block +//! switch (expr) { case literal: statement+; … default: statement+ } //! simple := return | decl | call | assign //! ``` //! //! Every simple statement is terminated by `;`. -//! Compound statements (`if`, `while`, block) end with `}` and need no `;`. +//! Compound statements (`if`, `while`, `switch`, block) end with `}` and need no `;`. //! //! # Design Decisions //! @@ -40,16 +41,17 @@ //! suffixes in a loop using the same pattern as the `primary` parser in //! `expressions.rs`, producing a left-associative `Index` chain. -use crate::ir::ast::{Expr, ExprD, Statement, StatementD, UncheckedExpr, UncheckedStmt}; +use crate::ir::ast::{Literal, Expr, ExprD, Statement, StatementD, UncheckedExpr, UncheckedStmt}; use crate::parser::expressions::{expression, parse_call}; use crate::parser::functions::type_name; use crate::parser::identifiers::identifier; +use crate::parser::literals::{integer_literal, boolean_literal}; use nom::{ branch::alt, bytes::complete::tag, character::complete::{char, multispace0}, combinator::{map, opt}, - multi::many0, + multi::{many0, many1}, sequence::{delimited, preceded, tuple}, IResult, }; @@ -58,7 +60,7 @@ fn wrap(s: Statement<()>) -> UncheckedStmt { StatementD { stmt: s, ty: () } } -/// Parse any statement: block | if | while | return | decl | call | assignment. +/// Parse any statement: block | if | while | switch | return | decl | call | assignment. pub fn statement(input: &str) -> IResult<&str, UncheckedStmt> { preceded( multispace0, @@ -66,6 +68,7 @@ pub fn statement(input: &str) -> IResult<&str, UncheckedStmt> { block_statement, if_statement, while_statement, + switch_statement, return_statement, decl_statement, call_statement, @@ -160,6 +163,49 @@ fn while_statement(input: &str) -> IResult<&str, UncheckedStmt> { )) } +/// Parse a switch statement: `switch (expr) { case literal: statement+; … default: statement+ }`. +fn switch_statement(input: &str) -> IResult<&str, UncheckedStmt> { + let (rest, _) = preceded(multispace0, tag("switch"))(input)?; + let (rest, target) = preceded(multispace0, expression)(rest)?; + let (rest, _) = preceded(multispace0, char('{'))(rest)?; + + let (rest, cases) = many1(map( + tuple(( + preceded(multispace0, tag("case")), + preceded( + multispace0, + alt(( + map(integer_literal, |i| Literal::Int(i)), + map(boolean_literal, |b| Literal::Bool(b)) + )) + ), + preceded(multispace0, char(':')), + many1(preceded(multispace0, statement)), + )), + |(_, literal, _, statements)| (literal, statements), + ))(rest)?; + + let (rest, default) = preceded(multispace0, map( + tuple(( + preceded(multispace0, tag("default")), + preceded(multispace0, char(':')), + many1(preceded(multispace0, statement)), + )), + |(_, _, statements)| statements, + ))(rest)?; + + let (rest, _) = preceded(multispace0, char('}'))(rest)?; + + Ok(( + rest, + wrap(Statement::Switch { + target: Box::new(target), + cases, + default, + }), + )) +} + /// Parse an lvalue: identifier followed by zero or more `[ expr ]` suffixes. fn lvalue(input: &str) -> IResult<&str, UncheckedExpr> { let (mut rest, id) = preceded(multispace0, identifier)(input)?; diff --git a/src/semantic/type_checker.rs b/src/semantic/type_checker.rs index e17c4b3..74193b1 100644 --- a/src/semantic/type_checker.rs +++ b/src/semantic/type_checker.rs @@ -232,6 +232,77 @@ fn type_check_stmt( body: Box::new(body_checked), } } + Statement::Switch { target, cases, default } => { + // Verifica o tipo da expressão alvo (target) + let target_checked = type_check_expr_to_typed(target, env)?; + + match target_checked.ty { + Type::Int | Type::Bool => {} + _ => { + return Err(TypeError::new(format!( + "switch target must be Int or Bool, got {:?}", + target_checked.ty + ))); + } + } + + // Verifica cada um dos cases + let mut checked_cases = Vec::new(); + let mut seen_cases = Vec::new(); + for (lit, stmts) in cases { + if seen_cases.contains(lit) { + return Err(TypeError::new(format!( + "duplicate case label in switch: {:?}", + lit + ))); + } + seen_cases.push(lit.clone()); + + // Descobre o tipo do literal do case atual + let lit_ty = match lit { + Literal::Int(_) => Type::Int, + Literal::Bool(_) => Type::Bool, + _ => { + return Err(TypeError::new(format!( + "switch case literal type not supported: {:?}", + lit + ))); + } + }; + + // Garante que o tipo do literal é compatível com o tipo do target + if !types_compatible(&target_checked.ty, &lit_ty) { + return Err(TypeError::new(format!( + "switch case literal type mismatch: expected {:?}, got {:?}", + target_checked.ty, lit_ty + ))); + } + + // Cria um novo escopo para as instruções deste case + let snapshot = env.snapshot(); + let mut checked_stmts = Vec::new(); + for stmt in stmts { + checked_stmts.push(type_check_stmt(stmt, env, expected_return)?); + } + env.restore(snapshot); + + checked_cases.push((lit.clone(), checked_stmts)); + } + + // Verifica o ramo padrão (default) + let snapshot = env.snapshot(); + let mut checked_default = Vec::new(); + for stmt in default { + checked_default.push(type_check_stmt(stmt, env, expected_return)?); + } + env.restore(snapshot); + + Statement::Switch { + target: Box::new(target_checked), + cases: checked_cases, + default: checked_default, + } + } Statement::Return(expr) => match expr { None => { if *expected_return != Type::Unit { diff --git a/tests/interpreter.rs b/tests/interpreter.rs index 51696c9..1eeef78 100644 --- a/tests/interpreter.rs +++ b/tests/interpreter.rs @@ -254,5 +254,157 @@ fn test_stdlib_pow_float_args() { let src = r#" void main() { float r = pow(2.0, 3.0); } "#; - assert!(run(src).is_ok(), "{}", run(src).unwrap_err()); + assert!(run(src.trim()).is_ok(), "{}", run(src.trim()).unwrap_err()); +} + +// --------------------------------------------------------------------------- +// 7.12 Switch statement execution +// --------------------------------------------------------------------------- +#[test] +fn test_switch_exec_match_case() { + let src = r#" + int run_switch(int x) { + int res = 0; + switch x { + case 1: res = 10; + case 2: res = 20; + default: res = 30; + } + return res; + } + void main() { + int r = run_switch(2); + if r != 20 { + int[] err = [0]; + int fail = err[9]; + } + } + "#; + assert!(run(src.trim()).is_ok(), "{}", run(src.trim()).unwrap_err()); +} + +#[test] +fn test_switch_exec_match_default() { + let src = r#" + int run_switch(int x) { + int res = 0; + switch x { + case 1: res = 10; + case 2: res = 20; + default: res = 30; + } + return res; + } + void main() { + int r = run_switch(5); + if r != 30 { + int[] err = [0]; + int fail = err[9]; + } + } + "#; + assert!(run(src.trim()).is_ok(), "{}", run(src.trim()).unwrap_err()); +} + +#[test] +fn test_switch_exec_early_return() { + let src = r#" + int run_switch(int x) { + switch x { + case 1: return 100; + case 2: return 200; + default: return 300; + } + } + void main() { + int r = run_switch(1); + if r != 100 { + int[] err = [0]; + int fail = err[9]; + } + } + "#; + assert!(run(src.trim()).is_ok(), "{}", run(src.trim()).unwrap_err()); +} + +#[test] +fn test_switch_exec_scoping() { + let src = r#" + void main() { + int x = 1; + switch x { + case 1: int y = 10; + default: int y = 20; + } + y = 30; + } + "#; + let result = run(src.trim()); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("type error: undeclared variable")); } + +#[test] +fn test_switch_exec_duplicate_error() { + let src = r#" + void main() { + int x = 1; + switch x { + case 1: x = 2; + case 1: x = 3; + default: x = 4; + } + } + "#; + let result = run(src.trim()); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("duplicate case label")); +} + +#[test] +fn test_switch_exec_bool_cases() { + let src = r#" + int check_bool(bool val) { + int res = 0; + switch val { + case true: res = 11; + case false: res = 22; + default: res = 33; + } + return res; + } + void main() { + int r1 = check_bool(true); + int r2 = check_bool(false); + if r1 != 11 or r2 != 22 { + int[] err = [0]; + int fail = err[9]; + } + } + "#; + assert!(run(src.trim()).is_ok(), "{}", run(src.trim()).unwrap_err()); +} + +#[test] +fn test_switch_exec_complex_target() { + let src = r#" + int check_complex(int a, int b) { + int res = 0; + switch a + b { + case 5: res = 50; + case 10: res = 100; + default: res = 999; + } + return res; + } + void main() { + int r = check_complex(2, 3); + if r != 50 { + int[] err = [0]; + int fail = err[9]; + } + } + "#; + assert!(run(src.trim()).is_ok(), "{}", run(src.trim()).unwrap_err()); +} + diff --git a/tests/parser.rs b/tests/parser.rs index eca6640..b207d51 100644 --- a/tests/parser.rs +++ b/tests/parser.rs @@ -672,3 +672,221 @@ fn test_array_in_expression() { assert!(matches!(result.exp, Expr::Index { ref base, ref index } if matches!(base.exp, Expr::ArrayLit(_)) && index.exp == Expr::Literal(Literal::Int(0)))); } + +// --- Switch --- + +#[test] +fn test_switch_statement() { + let result = statement("switch x { case 1: y = 1; default: y = 0; }").unwrap().1; + assert!(matches!(result.stmt, Statement::Switch { ref target, ref cases, ref default } + if matches!(target.exp, Expr::Ident(ref s) if s == "x") && cases.len() == 1 && default.len() == 1)); + if let Statement::Switch { ref cases, ref default, .. } = result.stmt { + assert!(matches!(cases[0].0, Literal::Int(1))); + assert!(matches!(cases[0].1.len(), 1)); + if let Statement::Assign { ref target, ref value } = cases[0].1[0].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "y")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(1))); + } + assert!(matches!(default.len(), 1)); + if let Statement::Assign { ref target, ref value } = default[0].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "y")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(0))); + } + } +} + +#[test] +fn test_switch_multiple_cases() { + let result = statement("switch x { case 1: y = 1; case 2: y = 2; default: y = 0; }").unwrap().1; + assert!(matches!(result.stmt, Statement::Switch { ref target, ref cases, ref default } + if matches!(target.exp, Expr::Ident(ref s) if s == "x") && cases.len() == 2 && default.len() == 1)); + if let Statement::Switch { ref cases, ref default, .. } = result.stmt { + assert!(matches!(cases[0].0, Literal::Int(1))); + assert!(matches!(cases[0].1.len(), 1)); + if let Statement::Assign { ref target, ref value } = cases[0].1[0].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "y")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(1))); + } + assert!(matches!(cases[1].0, Literal::Int(2))); + assert!(matches!(cases[1].1.len(), 1)); + if let Statement::Assign { ref target, ref value } = cases[1].1[0].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "y")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(2))); + } + assert!(matches!(default.len(), 1)); + if let Statement::Assign { ref target, ref value } = default[0].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "y")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(0))); + } + } +} + +#[test] +fn test_switch_multiple_statements() { + let result = statement("switch x { case 1: y = 1; z = 2; default: y = 0; z = 1; }").unwrap().1; + assert!(matches!(result.stmt, Statement::Switch { ref target, ref cases, ref default } + if matches!(target.exp, Expr::Ident(ref s) if s == "x") && cases.len() == 1 && default.len() == 2)); + if let Statement::Switch { ref cases, ref default, .. } = result.stmt { + assert!(matches!(cases[0].0, Literal::Int(1))); + assert!(matches!(cases[0].1.len(), 2)); + if let Statement::Assign { ref target, ref value } = cases[0].1[0].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "y")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(1))); + } + if let Statement::Assign { ref target, ref value } = cases[0].1[1].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "z")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(2))); + } + assert!(matches!(default.len(), 2)); + if let Statement::Assign { ref target, ref value } = default[0].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "y")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(0))); + } + if let Statement::Assign { ref target, ref value } = default[1].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "z")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(1))); + } + } +} + +#[test] +fn test_switch_multiple_cases_and_statements() { + let result = statement("switch x { case 1: y = 1; z = 2; case 2: y = 2; z = 3; default: y = 0; z = 1; }").unwrap().1; + assert!(matches!(result.stmt, Statement::Switch { ref target, ref cases, ref default } + if matches!(target.exp, Expr::Ident(ref s) if s == "x") && cases.len() == 2 && default.len() == 2)); + if let Statement::Switch { ref cases, ref default, .. } = result.stmt { + assert!(matches!(cases[0].0, Literal::Int(1))); + assert!(matches!(cases[0].1.len(), 2)); + if let Statement::Assign { ref target, ref value } = cases[0].1[0].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "y")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(1))); + } + if let Statement::Assign { ref target, ref value } = cases[0].1[1].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "z")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(2))); + } + assert!(matches!(cases[1].0, Literal::Int(2))); + assert!(matches!(cases[1].1.len(), 2)); + if let Statement::Assign { ref target, ref value } = cases[1].1[0].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "y")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(2))); + } + if let Statement::Assign { ref target, ref value } = cases[1].1[1].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "z")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(3))); + } + assert!(matches!(default.len(), 2)); + if let Statement::Assign { ref target, ref value } = default[0].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "y")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(0))); + } + if let Statement::Assign { ref target, ref value } = default[1].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "z")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(1))); + } + } +} + +#[test] +fn test_switch_boolean_cases() { + let result = statement("switch x { case true: y = 1; case false: y = 0; default: y = 3; }").unwrap().1; + assert!(matches!(result.stmt, Statement::Switch { ref target, ref cases, ref default } + if matches!(target.exp, Expr::Ident(ref s) if s == "x") && cases.len() == 2 && default.len() == 1)); + if let Statement::Switch { ref cases, ref default, .. } = result.stmt { + assert!(matches!(cases[0].0, Literal::Bool(true))); + assert!(matches!(cases[0].1.len(), 1)); + if let Statement::Assign { ref target, ref value } = cases[0].1[0].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "y")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(1))); + } + assert!(matches!(cases[1].0, Literal::Bool(false))); + assert!(matches!(cases[1].1.len(), 1)); + if let Statement::Assign { ref target, ref value } = cases[1].1[0].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "y")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(0))); + } + assert!(matches!(default.len(), 1)); + if let Statement::Assign { ref target, ref value } = default[0].stmt { + assert!(matches!(target.exp, Expr::Ident(ref s) if s == "y")); + assert_eq!(value.exp, Expr::Literal(Literal::Int(3))); + } + } +} + +#[test] +fn test_switch_multiple_defaults_err() { + assert!(statement("switch x { case 1: y = 1; default: y = 0; default: z = 2; }").is_err()); +} + +#[test] +fn test_switch_no_default_err() { + assert!(statement("switch x { case 1: y = 1; }").is_err()); +} + +#[test] +fn test_switch_no_cases_err() { + assert!(statement("switch x { default: y = 0; }").is_err()); +} + +#[test] +fn test_switch_non_expression_err() { + assert!(statement("switch { case 1: y = 1; default: y = 0; }").is_err()); +} + +#[test] +fn test_switch_invalid_case_literal() { + assert!(statement("switch x { case y: y = 1; default: y = 0; }").is_err()); + assert!(statement("switch x { case 1.5: y = 1; default: y = 0; }").is_err()); + assert!(statement("switch x { case \"str\": y = 1; default: y = 0; }").is_err()); +} + +#[test] +fn test_switch_invalid_syntax() { + assert!(statement("switch x { case 1 y = 1; default: y = 0; }").is_err()); + assert!(statement("switch x { case 1: y = 1; default y = 0; }").is_err()); + assert!(statement("switch x { case 1: y = 1; default: y = 0 ").is_err()); + assert!(statement("switch x case 1: y = 1; default: y = 0; }").is_err()); + assert!(statement("switch x { case 1: y = 1; default: y = 0;").is_err()); +} + +#[test] +fn test_switch_whitespace() { + assert!(statement("switch x { case 1: y = 1; default: y = 0; }").is_ok()); + assert!(statement("switch x { case 1 : y = 1 ; default : y = 0 ; }").is_ok()); +} + +#[test] +fn test_switch_in_function() { + let result = fun_decl("void foo() { switch x { case 1: y = 1; default: y = 0; } }").unwrap().1; + assert!(matches!(result.body.stmt, Statement::Block { ref seq } if seq.len() == 1)); + if let Statement::Block { ref seq } = result.body.stmt { + assert!(matches!(seq[0].stmt, Statement::Switch { ref target, ref cases, ref default } + if matches!(target.exp, Expr::Ident(ref s) if s == "x") && cases.len() == 1 && default.len() == 1)); + } +} + +#[test] +fn test_switch_in_if() { + let result = statement("if x { switch y { case 1: z = 1; default: z = 0; } }").unwrap().1; + assert!(matches!(result.stmt, Statement::If { .. })); + if let Statement::If { ref then_branch, .. } = &result.stmt { + assert!(matches!(then_branch.stmt, Statement::Block { ref seq } if seq.len() == 1)); + if let Statement::Block { ref seq } = &then_branch.stmt { + assert!(matches!(seq[0].stmt, Statement::Switch { ref target, ref cases, ref default } + if matches!(target.exp, Expr::Ident(ref s) if s == "y") && cases.len() == 1 && default.len() == 1)); + } + } +} + +#[test] +fn test_switch_in_while() { + let result = statement("while x { switch y { case 1: z = 1; default: z = 0; } }").unwrap().1; + assert!(matches!(result.stmt, Statement::While { .. })); + if let Statement::While { ref body, .. } = &result.stmt { + assert!(matches!(body.stmt, Statement::Block { ref seq } if seq.len() == 1)); + if let Statement::Block { ref seq } = &body.stmt { + assert!(matches!(seq[0].stmt, Statement::Switch { ref target, ref cases, ref default } + if matches!(target.exp, Expr::Ident(ref s) if s == "y") && cases.len() == 1 && default.len() == 1)); + } + } +} diff --git a/tests/tac_gen.rs b/tests/tac_gen.rs index f602e15..989c301 100644 --- a/tests/tac_gen.rs +++ b/tests/tac_gen.rs @@ -58,12 +58,138 @@ fn test_if_else_with_relational_condition() { let temp = Address::Temporary("temp1".to_string(), Type::Int); assert_eq!(instructions, vec![ - Instruction::ConditionalJMPRelational(Operator::GTE, x.clone(), y.clone(), "Label1:".to_string()), + Instruction::ConditionalJMPRelational(Operator::LT, x.clone(), y.clone(), "Label1:".to_string()), + Instruction::JMP("Label2:".to_string()), + Instruction::Label("Label1:".to_string()), Instruction::BinaryAssignment(Operator::Add, temp.clone(), x.clone(), y.clone()), Instruction::CopyAssignment(z.clone(), temp), + Instruction::JMP("Label3:".to_string()), + Instruction::Label("Label2:".to_string()), + Instruction::CopyAssignment(z, x), + Instruction::Label("Label3:".to_string()), + ]); +} + +#[test] +fn test_switch_int_tac() { + let stmt = StatementD { + stmt: Statement::Switch { + target: Box::new(int_var("x")), + cases: vec![ + (Literal::Int(1), vec![assign("z", ExprD { exp: Expr::Literal(Literal::Int(10)), ty: Type::Int })]), + (Literal::Int(2), vec![assign("z", ExprD { exp: Expr::Literal(Literal::Int(20)), ty: Type::Int })]), + ], + default: vec![ + assign("z", ExprD { exp: Expr::Literal(Literal::Int(30)), ty: Type::Int }) + ], + }, + ty: Type::Unit, + }; + + let mut env = Environment::new(); + let instructions = translate_statement(stmt, &mut env); + + let x = Address::Variable("x".to_string(), Type::Int); + let z = Address::Variable("z".to_string(), Type::Int); + + assert_eq!(instructions, vec![ + // comparisons + Instruction::ConditionalJMPRelational(Operator::EQ, x.clone(), Address::Constant(Literal::Int(1), Type::Int), "Label3:".to_string()), + Instruction::ConditionalJMPRelational(Operator::EQ, x.clone(), Address::Constant(Literal::Int(2), Type::Int), "Label4:".to_string()), + // fallback to default + Instruction::JMP("Label1:".to_string()), + // case 1 + Instruction::Label("Label3:".to_string()), + Instruction::CopyAssignment(z.clone(), Address::Constant(Literal::Int(10), Type::Int)), Instruction::JMP("Label2:".to_string()), + // case 2 + Instruction::Label("Label4:".to_string()), + Instruction::CopyAssignment(z.clone(), Address::Constant(Literal::Int(20), Type::Int)), + Instruction::JMP("Label2:".to_string()), + // default Instruction::Label("Label1:".to_string()), - Instruction::CopyAssignment(z, x), + Instruction::CopyAssignment(z.clone(), Address::Constant(Literal::Int(30), Type::Int)), + // end + Instruction::Label("Label2:".to_string()), + ]); +} + +#[test] +fn test_switch_bool_tac() { + let stmt = StatementD { + stmt: Statement::Switch { + target: Box::new(ExprD { exp: Expr::Ident("b".to_string()), ty: Type::Bool }), + cases: vec![ + (Literal::Bool(true), vec![assign("z", ExprD { exp: Expr::Literal(Literal::Int(1)), ty: Type::Int })]), + ], + default: vec![ + assign("z", ExprD { exp: Expr::Literal(Literal::Int(0)), ty: Type::Int }) + ], + }, + ty: Type::Unit, + }; + + let mut env = Environment::new(); + let instructions = translate_statement(stmt, &mut env); + + let b = Address::Variable("b".to_string(), Type::Bool); + let z = Address::Variable("z".to_string(), Type::Int); + + assert_eq!(instructions, vec![ + // comparisons + Instruction::ConditionalJMPRelational(Operator::EQ, b.clone(), Address::Constant(Literal::Bool(true), Type::Bool), "Label3:".to_string()), + // fallback to default + Instruction::JMP("Label1:".to_string()), + // case true + Instruction::Label("Label3:".to_string()), + Instruction::CopyAssignment(z.clone(), Address::Constant(Literal::Int(1), Type::Int)), + Instruction::JMP("Label2:".to_string()), + // default + Instruction::Label("Label1:".to_string()), + Instruction::CopyAssignment(z.clone(), Address::Constant(Literal::Int(0), Type::Int)), + // end + Instruction::Label("Label2:".to_string()), + ]); +} + +#[test] +fn test_switch_complex_target_tac() { + let stmt = StatementD { + stmt: Statement::Switch { + target: Box::new(add(int_var("x"), int_var("y"))), + cases: vec![ + (Literal::Int(5), vec![assign("z", ExprD { exp: Expr::Literal(Literal::Int(50)), ty: Type::Int })]), + ], + default: vec![ + assign("z", ExprD { exp: Expr::Literal(Literal::Int(999)), ty: Type::Int }) + ], + }, + ty: Type::Unit, + }; + + let mut env = Environment::new(); + let instructions = translate_statement(stmt, &mut env); + + let x = Address::Variable("x".to_string(), Type::Int); + let y = Address::Variable("y".to_string(), Type::Int); + let z = Address::Variable("z".to_string(), Type::Int); + let temp = Address::Temporary("temp1".to_string(), Type::Int); + + assert_eq!(instructions, vec![ + // evaluate target expression + Instruction::BinaryAssignment(Operator::Add, temp.clone(), x, y), + // comparisons + Instruction::ConditionalJMPRelational(Operator::EQ, temp.clone(), Address::Constant(Literal::Int(5), Type::Int), "Label3:".to_string()), + // fallback to default + Instruction::JMP("Label1:".to_string()), + // case 5 + Instruction::Label("Label3:".to_string()), + Instruction::CopyAssignment(z.clone(), Address::Constant(Literal::Int(50), Type::Int)), + Instruction::JMP("Label2:".to_string()), + // default + Instruction::Label("Label1:".to_string()), + Instruction::CopyAssignment(z.clone(), Address::Constant(Literal::Int(999), Type::Int)), + // end Instruction::Label("Label2:".to_string()), ]); } diff --git a/tests/type_checker.rs b/tests/type_checker.rs index 3357161..97ad085 100644 --- a/tests/type_checker.rs +++ b/tests/type_checker.rs @@ -202,3 +202,141 @@ fn test_type_check_print_wrong_arity() { let result = parse_and_type_check("void main() { print(1, 2); }"); assert!(result.is_err(), "expected arity error for print(1, 2)"); } + +// --------------------------------------------------------------------------- +// 7.4 Switch Statement +// --------------------------------------------------------------------------- +#[test] +fn test_type_check_switch_int_ok() { + let src = r#" + void main() { + int x = 1; + int y = 0; + switch x { + case 1: y = 2; + case 2: y = 3; + default: y = 4; + } + } + "#; + assert!(parse_and_type_check(src.trim()).is_ok()); +} + +#[test] +fn test_type_check_switch_bool_ok() { + let src = r#" + void main() { + bool x = true; + int y = 0; + switch x { + case true: y = 1; + case false: y = 2; + default: y = 3; + } + } + "#; + assert!(parse_and_type_check(src.trim()).is_ok()); +} + +#[test] +fn test_type_check_switch_mismatched_case_type() { + let src = r#" + void main() { + int x = 1; + int y = 0; + switch x { + case true: y = 1; + default: y = 0; + } + } + "#; + let result = parse_and_type_check(src.trim()); + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("mismatch")); +} + +#[test] +fn test_type_check_switch_unsupported_target_type() { + let src = r#" + void main() { + float x = 1.0; + int y = 0; + switch x { + case 1: y = 1; + default: y = 0; + } + } + "#; + let result = parse_and_type_check(src.trim()); + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("must be Int or Bool")); +} + +#[test] +fn test_type_check_switch_duplicate_case_labels() { + let src = r#" + void main() { + int x = 1; + int y = 0; + switch x { + case 1: y = 2; + case 1: y = 3; + default: y = 4; + } + } + "#; + let result = parse_and_type_check(src.trim()); + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("duplicate case label")); +} + +#[test] +fn test_type_check_switch_cross_branch_pollution_err() { + let src = r#" + void main() { + int x = 1; + switch x { + case 1: int y = 10; + case 2: y = 20; + default: int z = 30; + } + } + "#; + let result = parse_and_type_check(src.trim()); + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("undeclared variable")); +} + +#[test] +fn test_type_check_switch_independent_branches_ok() { + let src = r#" + void main() { + int x = 1; + switch x { + case 1: int y = 10; + case 2: int y = 20; + default: int y = 30; + } + } + "#; + let result = parse_and_type_check(src.trim()); + assert!(result.is_ok()); +} + +#[test] +fn test_type_check_switch_scope_leak_err() { + let src = r#" + void main() { + int x = 1; + switch x { + case 1: int y = 10; + default: int z = 30; + } + y = 20; + } + "#; + let result = parse_and_type_check(src.trim()); + assert!(result.is_err()); + assert!(result.unwrap_err().message.contains("undeclared variable")); +} +