From 0191d46142af5b7b58cb8da0cad2f5a7b5d1cfe5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sa=C3=BAl=20Cabrera?= Date: Tue, 12 May 2026 13:29:55 -0400 Subject: [PATCH 1/3] Introduce the javy-profiler-lib crate Part of https://github.com/bytecodealliance/javy/issues/1206 This is the first change towards a Wasm-native profiler for Javy generated modules. Particularly, this change introduces the profiler library crate, which will be used as a user-provided library by [Whamm](https://github.com/ejrgilbert/whamm). The library will be compiled to wasm32-wasip1 and will have the following responsibilities: * Detecting the Wasm function in the target app holding the interpreter dispatch loop * Detecting the byte-load provenance used to dispatch to the next opcode (index argument to `br_table`) * Holding the necessary state to construct the JS execution trace * Generating the execution trace output for analysis This change includes points 1 and 2. The remaining functionality will be done in separate pull requests. Note that, as stated in the issue above, this work is deemed experimental for the time being and this crate will be included through a feature flag when added to the CLI. 
--- Cargo.lock | 9 + Cargo.toml | 1 + crates/profiler-lib/Cargo.toml | 21 ++ crates/profiler-lib/src/ai.rs | 399 +++++++++++++++++++++++++++++++ crates/profiler-lib/src/lib.rs | 31 +++ crates/profiler-lib/src/state.rs | 389 ++++++++++++++++++++++++++++++ 6 files changed, 850 insertions(+) create mode 100644 crates/profiler-lib/Cargo.toml create mode 100644 crates/profiler-lib/src/ai.rs create mode 100644 crates/profiler-lib/src/lib.rs create mode 100644 crates/profiler-lib/src/state.rs diff --git a/Cargo.lock b/Cargo.lock index f04951e0..08da5283 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1722,6 +1722,15 @@ dependencies = [ "wasmtime-wizer", ] +[[package]] +name = "javy-profiler-lib" +version = "0.0.0" +dependencies = [ + "anyhow", + "walrus", + "wat", +] + [[package]] name = "javy-release" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index f352078d..4b3709e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ "crates/runner", "fuzz", "release", + "crates/profiler-lib", ] resolver = "2" diff --git a/crates/profiler-lib/Cargo.toml b/crates/profiler-lib/Cargo.toml new file mode 100644 index 00000000..9ea91be7 --- /dev/null +++ b/crates/profiler-lib/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "javy-profiler-lib" +version = "0.0.0" +authors.workspace = true +edition.workspace = true +license.workspace = true +publish = false + +[lib] +name = "javy_profiler_lib" +crate-type = ["cdylib"] + +[dependencies] +anyhow = { workspace = true } +walrus = { workspace = true } + +[dev-dependencies] +wat = "1" + +[package.metadata.javy] +targets = ["wasip1"] \ No newline at end of file diff --git a/crates/profiler-lib/src/ai.rs b/crates/profiler-lib/src/ai.rs new file mode 100644 index 00000000..88f37aea --- /dev/null +++ b/crates/profiler-lib/src/ai.rs @@ -0,0 +1,399 @@ +//! Abstract interpretation for `br_target` provenance. +//! +//! Given a WebAssembly module and a local function id, perform +//! 
abstract interpetation to determine the byte offset provenance of +//! the br_table index argument. +//! The target br_table instruction is chosen using a fixed number of +//! branch targets as heuristic. + +use std::collections::{BTreeSet, HashMap}; +use walrus::ir::{ + Binop, Block, Br, BrIf, BrTable, Call, CallIndirect, Const, Drop, GlobalGet, GlobalSet, IfElse, + Instr, InstrSeq, InstrSeqId, Load, LoadKind, LocalGet, LocalSet, LocalTee, Loop, MemoryGrow, + MemorySize, Return, ReturnCall, ReturnCallIndirect, Select, Store, Unop, Unreachable, Visitor, + dfs_in_order, +}; +use walrus::{InstrLocId, LocalFunction, LocalId, Module}; + +/// The set of offsets of byte-load instructions that contribute to +/// the br_table index argument. +#[derive(Clone, Debug, Default, Eq, PartialEq)] +struct Provenance(BTreeSet); + +impl Provenance { + fn new() -> Self { + Self(BTreeSet::new()) + } + + fn with(pos: u32) -> Self { + let mut s = BTreeSet::new(); + s.insert(pos); + Self(s) + } + + fn join_in_place(&mut self, other: &Provenance) { + for &v in &other.0 { + self.0.insert(v); + } + } + + fn joined(&self, other: &Provenance) -> Provenance { + let mut out = self.clone(); + out.join_in_place(other); + out + } +} + +#[derive(Clone, Debug, Default)] +struct AbstractState { + stack: Vec, + locals: HashMap, +} + +impl AbstractState { + fn join(&mut self, other: &AbstractState) { + // Wasm spec ensures that stack length must match at merge + // points. + let n = self.stack.len().min(other.stack.len()); + for i in 0..n { + self.stack[i].join_in_place(&other.stack[i]); + } + for (k, v) in &other.locals { + self.locals + .entry(*k) + .and_modify(|cur| cur.join_in_place(v)) + .or_insert_with(|| v.clone()); + } + } + + fn pop(&mut self) -> Provenance { + self.stack.pop().unwrap_or_default() + } + + fn push(&mut self, p: Provenance) { + self.stack.push(p); + } +} + +/// Control frames. 
+enum ControlFrame { + Block { + seq_id: InstrSeqId, + target: AbstractState, + }, + Loop { + seq_id: InstrSeqId, + header: AbstractState, + }, + If { + seq_id: InstrSeqId, + entry_state: AbstractState, + target: AbstractState, + }, + Else { + seq_id: InstrSeqId, + if_exit: AbstractState, + target: AbstractState, + }, +} + +impl ControlFrame { + fn seq_id(&self) -> InstrSeqId { + match self { + Self::Block { seq_id, .. } + | Self::Loop { seq_id, .. } + | Self::If { seq_id, .. } + | Self::Else { seq_id, .. } => *seq_id, + } + } +} + +/// Analyze the function with a custom `br_table` threshold. +pub(crate) fn analyze(module: &Module, func: &LocalFunction, threshold: u32) -> BTreeSet { + let mut interp = AbstractInterp::new(module, func, threshold); + dfs_in_order(&mut interp, func, func.entry_block()); + interp.dispatch_loads +} + +struct AbstractInterp<'a> { + /// The target module. + module: &'a Module, + /// Interpreter state. + state: AbstractState, + /// Control frames. + frames: Vec, + /// Program counter (byte offset of the instruction being visited). + pc: u32, + /// The set of loads that contribute the index argument to the + /// `br_table`. + dispatch_loads: BTreeSet, + /// Threshold above which a `br_table` qualifies as the + /// interpreter dispatch loop. + dispatch_target_threshold: u32, +} + +impl<'a> AbstractInterp<'a> { + fn new(module: &'a Module, func: &'a LocalFunction, threshold: u32) -> Self { + let mut state = AbstractState::default(); + for arg in &func.args { + state.locals.insert(*arg, Provenance::new()); + } + + let mut frames = Vec::new(); + // Push the implicit start control block. + frames.push(ControlFrame::Block { + seq_id: func.entry_block(), + target: AbstractState::default(), + }); + + Self { + module, + state, + frames, + pc: 0, + dispatch_target_threshold: threshold, + dispatch_loads: BTreeSet::new(), + } + } + + /// Blanket pop/push for operators deemed not to affect + /// provenance. 
+ fn pop_push_n(&mut self, n_pop: usize, n_push: usize) { + for _ in 0..n_pop { + self.state.pop(); + } + for _ in 0..n_push { + self.state.push(Provenance::new()); + } + } + + /// Merge the current state into the target branch state. + fn join_into_target(&mut self, target: InstrSeqId) { + let snapshot = self.state.clone(); + if let Some(frame) = self.frames.iter_mut().rev().find(|f| f.seq_id() == target) { + match frame { + ControlFrame::Loop { header, .. } => header.join(&snapshot), + ControlFrame::Block { target, .. } + | ControlFrame::If { target, .. } + | ControlFrame::Else { target, .. } => target.join(&snapshot), + } + } + } +} + +pub fn is_byte_load(load: &Load) -> bool { + matches!(load.kind, LoadKind::I32_8 { .. } | LoadKind::I64_8 { .. }) +} + +impl<'f, 'instr> Visitor<'instr> for AbstractInterp<'f> { + fn visit_instr(&mut self, _: &'instr Instr, loc: &'instr InstrLocId) { + // Save the program counter before visiting each operator. + self.pc = loc.data(); + } + + fn end_instr_seq(&mut self, _: &'instr InstrSeq) { + let frame = match self.frames.pop() { + Some(f) => f, + None => return, + }; + match frame { + // On block end, join the state of the ending block into + // the current state. + ControlFrame::Block { target, .. } => { + self.state.join(&target); + } + // On loop end, we have a fall-through. + ControlFrame::Loop { .. } => {} + // On if end, replace the frame with else, and restore the + // state to the if entry state. + // Also, merge and store the exit state by merging the + // current state with the target state. + ControlFrame::If { + entry_state, + target, + .. + } => { + let mut exit_state = self.state.clone(); + exit_state.join(&target); + match self.frames.last_mut() { + Some(ControlFrame::Else { if_exit: slot, .. 
}) => *slot = exit_state, + _ => panic!("If frame must be followed by matching Else frame on the stack"), + } + self.state = entry_state; + } + // On else end, merge the current state with the state + // from both the if and else. + ControlFrame::Else { + if_exit, target, .. + } => { + self.state.join(&target); + self.state.join(&if_exit); + } + } + } + + fn visit_block(&mut self, b: &Block) { + self.frames.push(ControlFrame::Block { + seq_id: b.seq, + target: AbstractState::default(), + }); + } + + fn visit_loop(&mut self, l: &Loop) { + self.frames.push(ControlFrame::Loop { + seq_id: l.seq, + header: self.state.clone(), + }); + } + + fn visit_if_else(&mut self, ie: &IfElse) { + self.state.pop(); + let snapshot = self.state.clone(); + + self.frames.push(ControlFrame::Else { + seq_id: ie.alternative, + if_exit: AbstractState::default(), + target: AbstractState::default(), + }); + self.frames.push(ControlFrame::If { + seq_id: ie.consequent, + entry_state: snapshot, + target: AbstractState::default(), + }); + } + + fn visit_br(&mut self, br: &Br) { + self.join_into_target(br.block); + } + + fn visit_br_if(&mut self, br_if: &BrIf) { + self.state.pop(); + self.join_into_target(br_if.block); + } + + fn visit_br_table(&mut self, bt: &BrTable) { + let index = self.state.pop(); + if bt.blocks.len() >= self.dispatch_target_threshold as usize { + self.dispatch_loads.extend(index.0.iter().copied()); + } + let snapshot = self.state.clone(); + let mut seen: BTreeSet = BTreeSet::new(); + for &t in bt.blocks.iter().chain(std::iter::once(&bt.default)) { + if !seen.insert(t) { + continue; + } + if let Some(frame) = self.frames.iter_mut().rev().find(|f| f.seq_id() == t) { + match frame { + ControlFrame::Loop { header, .. } => header.join(&snapshot), + ControlFrame::Block { target, .. } + | ControlFrame::If { target, .. } + | ControlFrame::Else { target, .. 
} => target.join(&snapshot), + } + } + } + } + + fn visit_load(&mut self, l: &Load) { + self.state.pop(); + if is_byte_load(l) { + self.state.push(Provenance::with(self.pc)); + } else { + self.state.push(Provenance::new()); + } + } + + fn visit_store(&mut self, _: &Store) { + self.pop_push_n(2, 0); + } + + fn visit_const(&mut self, _: &Const) { + self.state.push(Provenance::new()); + } + + fn visit_binop(&mut self, _: &Binop) { + let rhs = self.state.pop(); + let lhs = self.state.pop(); + self.state.push(lhs.joined(&rhs)); + } + + fn visit_unop(&mut self, _: &Unop) { + let v = self.state.pop(); + self.state.push(v); + } + + fn visit_local_get(&mut self, lg: &LocalGet) { + let p = self + .state + .locals + .get(&lg.local) + .cloned() + .unwrap_or_default(); + self.state.push(p); + } + + fn visit_local_set(&mut self, ls: &LocalSet) { + let v = self.state.pop(); + self.state.locals.insert(ls.local, v); + } + + fn visit_local_tee(&mut self, lt: &LocalTee) { + let v = self.state.stack.last().cloned().unwrap_or_default(); + self.state.locals.insert(lt.local, v); + } + + fn visit_global_get(&mut self, _: &GlobalGet) { + self.state.push(Provenance::new()); + } + + fn visit_global_set(&mut self, _: &GlobalSet) { + self.state.pop(); + } + + fn visit_drop(&mut self, _: &Drop) { + self.state.pop(); + } + + fn visit_select(&mut self, _: &Select) { + self.state.pop(); + + let r = self.state.pop(); + let l = self.state.pop(); + self.state.push(l.joined(&r)); + } + + fn visit_call(&mut self, c: &Call) { + let ty_id = self.module.funcs.get(c.func).ty(); + let ty = self.module.types.get(ty_id); + self.pop_push_n(ty.params().len(), ty.results().len()); + } + + fn visit_call_indirect(&mut self, c: &CallIndirect) { + let ty = self.module.types.get(c.ty); + // +1 for the function index popped before the args. 
+ self.pop_push_n(ty.params().len() + 1, ty.results().len()); + } + + fn visit_return_call(&mut self, c: &ReturnCall) { + let ty_id = self.module.funcs.get(c.func).ty(); + let ty = self.module.types.get(ty_id); + self.pop_push_n(ty.params().len(), 0); + } + + fn visit_return_call_indirect(&mut self, c: &ReturnCallIndirect) { + let ty = self.module.types.get(c.ty); + self.pop_push_n(ty.params().len() + 1, 0); + } + + fn visit_return(&mut self, _: &Return) {} + + fn visit_unreachable(&mut self, _: &Unreachable) {} + + fn visit_memory_size(&mut self, _: &MemorySize) { + self.state.push(Provenance::new()); + } + + fn visit_memory_grow(&mut self, _: &MemoryGrow) { + self.state.pop(); + self.state.push(Provenance::new()); + } +} diff --git a/crates/profiler-lib/src/lib.rs b/crates/profiler-lib/src/lib.rs new file mode 100644 index 00000000..893de494 --- /dev/null +++ b/crates/profiler-lib/src/lib.rs @@ -0,0 +1,31 @@ +mod ai; +mod state; + +use state::State; +use std::sync::LazyLock; + +// TODO: Passing empty bytes is temporary. Whamm currently does not +// offer a mechanism to pass in the target application bytes. +// Prior to hooking the bytes we need to either find the best +// way to accomplish that in Whamm's instrumentation pass or +// create a custom pass in Javy, e.g., through Wizer. Ideally we +// want the former: arguably, having access to the bytes is +// something that other libraries might need. +static STATE: std::sync::LazyLock = + LazyLock::new(|| State::from_bytes(&[]).expect("State initialization to work")); + +/// Returns true if the given index corresponds to the dispatch +/// function according to the heuristics defined in the [`state`] +/// module. +#[unsafe(no_mangle)] +pub unsafe extern "C" fn is_dipatch_func(index: u32) -> bool { + STATE.is_dispatch_func(index) +} + +/// Returns true if the given program counter offset is classified as +/// a dispatch load which feeds the index argument for the `br_table` +/// dispatch. 
+#[unsafe(no_mangle)] +pub unsafe extern "C" fn is_dispatch_load(fid: u32, pc: u32) -> bool { + STATE.is_dispatch_load(fid, pc) +} diff --git a/crates/profiler-lib/src/state.rs b/crates/profiler-lib/src/state.rs new file mode 100644 index 00000000..81a55334 --- /dev/null +++ b/crates/profiler-lib/src/state.rs @@ -0,0 +1,389 @@ +//! Instrumentation state to derive probe insertion. + +use anyhow::{Result, bail}; +use std::collections::BTreeSet; +use walrus::ir::{BrTable, Visitor, dfs_in_order}; +use walrus::{FunctionId, FunctionKind, LocalFunction, ModuleConfig}; + +use crate::ai; + +/// Threshold above which a `br_table` qualifies as the interpreter +/// dispatch loop. +pub const DISPATCH_TARGET_THRESHOLD: u32 = 250; + +pub struct State { + /// Wasm function index, which contains a `br_table` with at least + /// the configured target threshold. There should be a single + /// function which meets this criteria. + pub dispatch_func_idx: u32, + /// Byte offsets of the `i32.load8_u` instructions in the dispatch + /// function whose values feed the dispatch `br_table`'s index. + dispatch_loads: BTreeSet, +} + +impl State { + /// Construct a `State` from the given Wasm bytes. + pub fn from_bytes(bytes: &[u8]) -> Result { + Self::from_bytes_with_threshold(bytes, DISPATCH_TARGET_THRESHOLD) + } + + /// Construct a `State` with a custom `br_table` target threshold. + pub(crate) fn from_bytes_with_threshold(bytes: &[u8], threshold: u32) -> Result { + let module = ModuleConfig::new().parse(bytes)?; + + let candidates: Vec<(u32, FunctionId)> = module + .funcs + .iter() + .enumerate() + .filter_map(|(idx, func)| match &func.kind { + FunctionKind::Local(local) if has_large_br_table(local, threshold) => { + Some((idx as u32, func.id())) + } + _ => None, + }) + .collect(); + + if candidates.len() != 1 { + bail!( + "Unexpected number of dispatch functions. 
Expected 1, found {}", + candidates.len() + ); + } + + let (dispatch_func_idx, dispatch_func_id) = candidates[0]; + + let local = match &module.funcs.get(dispatch_func_id).kind { + FunctionKind::Local(l) => l, + // Mostly for completeness, this should not be possible, + // given the filtering above. + _ => unreachable!("filtered to local functions only"), + }; + let dispatch_loads = ai::analyze(&module, local, threshold); + + Ok(Self { + dispatch_func_idx, + dispatch_loads, + }) + } + + /// Given a function id in the module, return whether it matches + /// the dispatch function heuristics. + pub fn is_dispatch_func(&self, id: u32) -> bool { + self.dispatch_func_idx == id + } + + /// Given a function id and an instruction offset, return whether + /// the instruction at `pc` is the memory load responsible for + /// fetching the next QuickJS opcode. + pub fn is_dispatch_load(&self, id: u32, pc: u32) -> bool { + id == self.dispatch_func_idx && self.dispatch_loads.contains(&pc) + } +} + +/// True iff `func` contains a `br_table` whose target list has at +/// least `threshold` entries. +fn has_large_br_table(func: &LocalFunction, threshold: u32) -> bool { + struct Detect { + threshold: u32, + found: bool, + } + impl<'instr> Visitor<'instr> for Detect { + fn visit_br_table(&mut self, br_table: &BrTable) { + if br_table.blocks.len() >= self.threshold as usize { + self.found = true; + } + } + } + let mut d = Detect { + threshold, + found: false, + }; + dfs_in_order(&mut d, func, func.entry_block()); + d.found +} + +#[cfg(test)] +mod tests { + use super::*; + use anyhow::{Result, anyhow}; + use std::collections::HashMap; + use walrus::InstrLocId; + use walrus::ir::Instr; + + fn to_bytes(wat: &str) -> Result> { + Ok(wat::parse_str(wat)?) + } + + /// Map every instruction's byte offset to the corresponding + /// instruction in the dispatch function. 
+ fn pc2instr(state: &State) -> Result> { + let func = state + .module() + .funcs + .iter() + .nth(state.dispatch_func_idx as usize) + .ok_or_else(|| anyhow!("no function at index {}", state.dispatch_func_idx))?; + let local = match &func.kind { + FunctionKind::Local(l) => l, + _ => bail!( + "function at index {} is not a local function", + state.dispatch_func_idx + ), + }; + + #[derive(Default)] + struct Collect { + map: HashMap, + } + impl<'i> Visitor<'i> for Collect { + fn visit_instr(&mut self, instr: &'i Instr, loc: &'i InstrLocId) { + self.map.insert(loc.data(), instr.clone()); + } + } + let mut v = Collect::default(); + dfs_in_order(&mut v, local, local.entry_block()); + Ok(v.map) + } + + fn br_table(n: usize) -> String { + let labels = vec!["0"; n].join(" "); + format!("br_table {labels} 0") + } + + fn assert_all_byte_loads(state: &State) -> Result<()> { + let map = pc2instr(state)?; + for &pc in &state.dispatch_loads { + let instr = map + .get(&pc) + .ok_or_else(|| anyhow!("pc {pc} not found in function"))?; + match instr { + Instr::Load(l) => { + if !ai::is_byte_load(l) { + bail!("pc {pc} is a load but not a byte load: {:?}", l.kind); + } + } + other => bail!("pc {pc} is {other:?}, not a load"), + } + } + Ok(()) + } + + #[test] + fn straight_line_load() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) + (block + local.get $p + i32.load8_u + {br_table}))) + "#, + br_table = br_table(3) + ); + let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + + assert!(state.is_dispatch_func(state.dispatch_func_idx)); + assert_eq!(state.dispatch_loads.len(), 1, "expected one dispatch load"); + assert_all_byte_loads(&state)?; + Ok(()) + } + + #[test] + fn provenance_survives_i32_and() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) + (block + local.get $p + i32.load8_u + i32.const 0xff + i32.and + {br_table}))) + "#, + br_table = br_table(3) + ); + let state = 
State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + + assert_eq!( + state.dispatch_loads.len(), + 1, + "and must not drop provenance" + ); + assert_all_byte_loads(&state)?; + Ok(()) + } + + #[test] + fn provenance_flows_through_local_roundtrip() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) (local $byte i32) + (block + local.get $p + i32.load8_u + local.set $byte + local.get $byte + {br_table}))) + "#, + br_table = br_table(3) + ); + let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + + assert_eq!(state.dispatch_loads.len(), 1); + assert_all_byte_loads(&state)?; + Ok(()) + } + + #[test] + fn if_else_merge_collects_both_loads() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) (local $byte i32) + (block + local.get $p + i32.const 1 + i32.lt_s + if + local.get $p + i32.load8_u offset=0 + local.set $byte + else + local.get $p + i32.load8_u offset=4 + local.set $byte + end + local.get $byte + {br_table}))) + "#, + br_table = br_table(3) + ); + let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + + assert_eq!(state.dispatch_loads.len(), 2, "both loads must be recorded"); + assert_all_byte_loads(&state)?; + Ok(()) + } + + #[test] + fn conditional_value() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) (local $byte i32) + (block + local.get $p + i32.load8_u offset=0 + local.set $byte + local.get $p + i32.const 200 + i32.lt_s + if + local.get $p + i32.load8_u offset=1 + local.set $byte + end + local.get $byte + {br_table}))) + "#, + br_table = br_table(3) + ); + let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + + assert_eq!(state.dispatch_loads.len(), 2); + assert_all_byte_loads(&state)?; + Ok(()) + } + + #[test] + fn non_byte_load_is_not_recorded() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) + (block + local.get $p + i32.load + {br_table}))) + "#, 
+ br_table = br_table(3) + ); + let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + + assert!( + state.dispatch_loads.is_empty(), + "non-byte load must not be recorded" + ); + Ok(()) + } + + #[test] + fn br_table_without_load_is_empty() -> Result<()> { + let wat = format!( + r#" + (module + (func (param $p i32) + (block + i32.const 0 + {br_table}))) + "#, + br_table = br_table(3) + ); + let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + + assert!(state.dispatch_loads.is_empty()); + Ok(()) + } + + #[test] + fn correctly_identifies_dispatch_func() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) + (block + local.get $p + i32.load8_u + {br_table}))) + "#, + br_table = br_table(3) + ); + let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + + assert!(state.is_dispatch_func(state.dispatch_func_idx)); + assert!(!state.is_dispatch_func(state.dispatch_func_idx + 1)); + Ok(()) + } + + #[test] + fn dispatch_load_is_scoped_to_dispatch_func() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) + (block + local.get $p + i32.load8_u + {br_table}))) + "#, + br_table = br_table(3) + ); + let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + + let pc = *state.dispatch_loads.iter().next().unwrap(); + assert!(state.is_dispatch_load(state.dispatch_func_idx, pc)); + assert!(!state.is_dispatch_load(state.dispatch_func_idx + 1, pc)); + Ok(()) + } +} From aafbe861b9e703ebbbe7eb2a10835a7c4f48b0dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sa=C3=BAl=20Cabrera?= Date: Tue, 12 May 2026 14:06:39 -0400 Subject: [PATCH 2/3] Apply clippy suggestions --- crates/profiler-lib/src/ai.rs | 13 +++++++------ crates/profiler-lib/src/lib.rs | 4 ++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/crates/profiler-lib/src/ai.rs b/crates/profiler-lib/src/ai.rs index 88f37aea..a7467e4a 100644 --- a/crates/profiler-lib/src/ai.rs +++ 
b/crates/profiler-lib/src/ai.rs @@ -139,12 +139,13 @@ impl<'a> AbstractInterp<'a> { state.locals.insert(*arg, Provenance::new()); } - let mut frames = Vec::new(); - // Push the implicit start control block. - frames.push(ControlFrame::Block { - seq_id: func.entry_block(), - target: AbstractState::default(), - }); + let frames = vec![ + // Push the implicit start control block. + ControlFrame::Block { + seq_id: func.entry_block(), + target: AbstractState::default(), + }, + ]; Self { module, diff --git a/crates/profiler-lib/src/lib.rs b/crates/profiler-lib/src/lib.rs index 893de494..0b726576 100644 --- a/crates/profiler-lib/src/lib.rs +++ b/crates/profiler-lib/src/lib.rs @@ -18,7 +18,7 @@ static STATE: std::sync::LazyLock = /// function according to the heuristics defined in the [`state`] /// module. #[unsafe(no_mangle)] -pub unsafe extern "C" fn is_dipatch_func(index: u32) -> bool { +pub extern "C" fn is_dipatch_func(index: u32) -> bool { STATE.is_dispatch_func(index) } @@ -26,6 +26,6 @@ pub unsafe extern "C" fn is_dipatch_func(index: u32) -> bool { /// a dispatch load which feeds the index argument for the `br_table` /// dispatch. 
#[unsafe(no_mangle)] -pub unsafe extern "C" fn is_dispatch_load(fid: u32, pc: u32) -> bool { +pub extern "C" fn is_dispatch_load(fid: u32, pc: u32) -> bool { STATE.is_dispatch_load(fid, pc) } From a72a479d69de15014d126ba1a6c22c3d3af85194 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sa=C3=BAl=20Cabrera?= Date: Tue, 12 May 2026 14:18:26 -0400 Subject: [PATCH 3/3] Fix tests in state.rs --- crates/profiler-lib/src/state.rs | 56 ++++++++++++++++---------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/crates/profiler-lib/src/state.rs b/crates/profiler-lib/src/state.rs index 81a55334..1512619b 100644 --- a/crates/profiler-lib/src/state.rs +++ b/crates/profiler-lib/src/state.rs @@ -108,27 +108,27 @@ mod tests { use anyhow::{Result, anyhow}; use std::collections::HashMap; use walrus::InstrLocId; + use walrus::Module; use walrus::ir::Instr; - fn to_bytes(wat: &str) -> Result> { - Ok(wat::parse_str(wat)?) + fn make(wat: &str, threshold: u32) -> Result<(Module, State)> { + let bytes = wat::parse_str(wat)?; + let module = ModuleConfig::new().parse(&bytes)?; + let state = State::from_bytes_with_threshold(&bytes, threshold)?; + Ok((module, state)) } /// Map every instruction's byte offset to the corresponding - /// instruction in the dispatch function. - fn pc2instr(state: &State) -> Result> { - let func = state - .module() + /// instruction in the function at index `fid`. 
+ fn pc2instr(module: &Module, fid: u32) -> Result> { + let func = module .funcs .iter() - .nth(state.dispatch_func_idx as usize) - .ok_or_else(|| anyhow!("no function at index {}", state.dispatch_func_idx))?; + .nth(fid as usize) + .ok_or_else(|| anyhow!("no function at index {fid}"))?; let local = match &func.kind { FunctionKind::Local(l) => l, - _ => bail!( - "function at index {} is not a local function", - state.dispatch_func_idx - ), + _ => bail!("function at index {fid} is not a local function"), }; #[derive(Default)] @@ -150,8 +150,8 @@ mod tests { format!("br_table {labels} 0") } - fn assert_all_byte_loads(state: &State) -> Result<()> { - let map = pc2instr(state)?; + fn assert_all_byte_loads(module: &Module, state: &State) -> Result<()> { + let map = pc2instr(module, state.dispatch_func_idx)?; for &pc in &state.dispatch_loads { let instr = map .get(&pc) @@ -182,11 +182,11 @@ mod tests { "#, br_table = br_table(3) ); - let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + let (module, state) = make(&wat, 3)?; assert!(state.is_dispatch_func(state.dispatch_func_idx)); assert_eq!(state.dispatch_loads.len(), 1, "expected one dispatch load"); - assert_all_byte_loads(&state)?; + assert_all_byte_loads(&module, &state)?; Ok(()) } @@ -206,14 +206,14 @@ mod tests { "#, br_table = br_table(3) ); - let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + let (module, state) = make(&wat, 3)?; assert_eq!( state.dispatch_loads.len(), 1, "and must not drop provenance" ); - assert_all_byte_loads(&state)?; + assert_all_byte_loads(&module, &state)?; Ok(()) } @@ -233,10 +233,10 @@ mod tests { "#, br_table = br_table(3) ); - let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + let (module, state) = make(&wat, 3)?; assert_eq!(state.dispatch_loads.len(), 1); - assert_all_byte_loads(&state)?; + assert_all_byte_loads(&module, &state)?; Ok(()) } @@ -265,10 +265,10 @@ mod tests { "#, br_table = br_table(3) ); - let state = 
State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + let (module, state) = make(&wat, 3)?; assert_eq!(state.dispatch_loads.len(), 2, "both loads must be recorded"); - assert_all_byte_loads(&state)?; + assert_all_byte_loads(&module, &state)?; Ok(()) } @@ -296,10 +296,10 @@ mod tests { "#, br_table = br_table(3) ); - let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + let (module, state) = make(&wat, 3)?; assert_eq!(state.dispatch_loads.len(), 2); - assert_all_byte_loads(&state)?; + assert_all_byte_loads(&module, &state)?; Ok(()) } @@ -317,7 +317,7 @@ mod tests { "#, br_table = br_table(3) ); - let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + let (_module, state) = make(&wat, 3)?; assert!( state.dispatch_loads.is_empty(), @@ -338,7 +338,7 @@ mod tests { "#, br_table = br_table(3) ); - let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + let (_module, state) = make(&wat, 3)?; assert!(state.dispatch_loads.is_empty()); Ok(()) @@ -358,7 +358,7 @@ mod tests { "#, br_table = br_table(3) ); - let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + let (_module, state) = make(&wat, 3)?; assert!(state.is_dispatch_func(state.dispatch_func_idx)); assert!(!state.is_dispatch_func(state.dispatch_func_idx + 1)); @@ -379,7 +379,7 @@ mod tests { "#, br_table = br_table(3) ); - let state = State::from_bytes_with_threshold(&to_bytes(&wat)?, 3)?; + let (_module, state) = make(&wat, 3)?; let pc = *state.dispatch_loads.iter().next().unwrap(); assert!(state.is_dispatch_load(state.dispatch_func_idx, pc));