diff --git a/Cargo.lock b/Cargo.lock index f04951e0..08da5283 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1722,6 +1722,15 @@ dependencies = [ "wasmtime-wizer", ] +[[package]] +name = "javy-profiler-lib" +version = "0.0.0" +dependencies = [ + "anyhow", + "walrus", + "wat", +] + [[package]] name = "javy-release" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index f352078d..4b3709e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ "crates/runner", "fuzz", "release", + "crates/profiler-lib", ] resolver = "2" diff --git a/crates/profiler-lib/Cargo.toml b/crates/profiler-lib/Cargo.toml new file mode 100644 index 00000000..9ea91be7 --- /dev/null +++ b/crates/profiler-lib/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "javy-profiler-lib" +version = "0.0.0" +authors.workspace = true +edition.workspace = true +license.workspace = true +publish = false + +[lib] +name = "javy_profiler_lib" +crate-type = ["cdylib"] + +[dependencies] +anyhow = { workspace = true } +walrus = { workspace = true } + +[dev-dependencies] +wat = "1" + +[package.metadata.javy] +targets = ["wasip1"] \ No newline at end of file diff --git a/crates/profiler-lib/src/ai.rs b/crates/profiler-lib/src/ai.rs new file mode 100644 index 00000000..a7467e4a --- /dev/null +++ b/crates/profiler-lib/src/ai.rs @@ -0,0 +1,400 @@ +//! Abstract interpretation for `br_table` provenance. +//! +//! Given a WebAssembly module and a local function id, perform +//! abstract interpretation to determine the byte offset provenance of +//! the br_table index argument. +//! The target br_table instruction is chosen using a fixed number of +//! branch targets as heuristic. 
+ +use std::collections::{BTreeSet, HashMap}; +use walrus::ir::{ + Binop, Block, Br, BrIf, BrTable, Call, CallIndirect, Const, Drop, GlobalGet, GlobalSet, IfElse, + Instr, InstrSeq, InstrSeqId, Load, LoadKind, LocalGet, LocalSet, LocalTee, Loop, MemoryGrow, + MemorySize, Return, ReturnCall, ReturnCallIndirect, Select, Store, Unop, Unreachable, Visitor, + dfs_in_order, +}; +use walrus::{InstrLocId, LocalFunction, LocalId, Module}; + +/// The set of offsets of byte-load instructions that contribute to +/// the br_table index argument. +#[derive(Clone, Debug, Default, Eq, PartialEq)] +struct Provenance(BTreeSet<u32>); + +impl Provenance { + fn new() -> Self { + Self(BTreeSet::new()) + } + + fn with(pos: u32) -> Self { + let mut s = BTreeSet::new(); + s.insert(pos); + Self(s) + } + + fn join_in_place(&mut self, other: &Provenance) { + for &v in &other.0 { + self.0.insert(v); + } + } + + fn joined(&self, other: &Provenance) -> Provenance { + let mut out = self.clone(); + out.join_in_place(other); + out + } +} + +#[derive(Clone, Debug, Default)] +struct AbstractState { + stack: Vec<Provenance>, + locals: HashMap<LocalId, Provenance>, +} + +impl AbstractState { + fn join(&mut self, other: &AbstractState) { + // Wasm spec ensures that stack length must match at merge + // points. + let n = self.stack.len().min(other.stack.len()); + for i in 0..n { + self.stack[i].join_in_place(&other.stack[i]); + } + for (k, v) in &other.locals { + self.locals + .entry(*k) + .and_modify(|cur| cur.join_in_place(v)) + .or_insert_with(|| v.clone()); + } + } + + fn pop(&mut self) -> Provenance { + self.stack.pop().unwrap_or_default() + } + + fn push(&mut self, p: Provenance) { + self.stack.push(p); + } +} + +/// Control frames. 
+enum ControlFrame { + Block { + seq_id: InstrSeqId, + target: AbstractState, + }, + Loop { + seq_id: InstrSeqId, + header: AbstractState, + }, + If { + seq_id: InstrSeqId, + entry_state: AbstractState, + target: AbstractState, + }, + Else { + seq_id: InstrSeqId, + if_exit: AbstractState, + target: AbstractState, + }, +} + +impl ControlFrame { + fn seq_id(&self) -> InstrSeqId { + match self { + Self::Block { seq_id, .. } + | Self::Loop { seq_id, .. } + | Self::If { seq_id, .. } + | Self::Else { seq_id, .. } => *seq_id, + } + } +} + +/// Analyze the function with a custom `br_table` threshold. +pub(crate) fn analyze(module: &Module, func: &LocalFunction, threshold: u32) -> BTreeSet<u32> { + let mut interp = AbstractInterp::new(module, func, threshold); + dfs_in_order(&mut interp, func, func.entry_block()); + interp.dispatch_loads +} + +struct AbstractInterp<'a> { + /// The target module. + module: &'a Module, + /// Interpreter state. + state: AbstractState, + /// Control frames. + frames: Vec<ControlFrame>, + /// Program counter (byte offset of the instruction being visited). + pc: u32, + /// The set of loads that contribute the index argument to the + /// `br_table`. + dispatch_loads: BTreeSet<u32>, + /// Threshold above which a `br_table` qualifies as the + /// interpreter dispatch loop. + dispatch_target_threshold: u32, +} + +impl<'a> AbstractInterp<'a> { + fn new(module: &'a Module, func: &'a LocalFunction, threshold: u32) -> Self { + let mut state = AbstractState::default(); + for arg in &func.args { + state.locals.insert(*arg, Provenance::new()); + } + + let frames = vec![ + // Push the implicit start control block. + ControlFrame::Block { + seq_id: func.entry_block(), + target: AbstractState::default(), + }, + ]; + + Self { + module, + state, + frames, + pc: 0, + dispatch_target_threshold: threshold, + dispatch_loads: BTreeSet::new(), + } + } + + /// Blanket pop/push for operators deemed not to affect + /// provenance. 
+ fn pop_push_n(&mut self, n_pop: usize, n_push: usize) { + for _ in 0..n_pop { + self.state.pop(); + } + for _ in 0..n_push { + self.state.push(Provenance::new()); + } + } + + /// Merge the current state into the target branch state. + fn join_into_target(&mut self, target: InstrSeqId) { + let snapshot = self.state.clone(); + if let Some(frame) = self.frames.iter_mut().rev().find(|f| f.seq_id() == target) { + match frame { + ControlFrame::Loop { header, .. } => header.join(&snapshot), + ControlFrame::Block { target, .. } + | ControlFrame::If { target, .. } + | ControlFrame::Else { target, .. } => target.join(&snapshot), + } + } + } +} + +pub fn is_byte_load(load: &Load) -> bool { + matches!(load.kind, LoadKind::I32_8 { .. } | LoadKind::I64_8 { .. }) +} + +impl<'f, 'instr> Visitor<'instr> for AbstractInterp<'f> { + fn visit_instr(&mut self, _: &'instr Instr, loc: &'instr InstrLocId) { + // Save the program counter before visiting each operator. + self.pc = loc.data(); + } + + fn end_instr_seq(&mut self, _: &'instr InstrSeq) { + let frame = match self.frames.pop() { + Some(f) => f, + None => return, + }; + match frame { + // On block end, join the state of the ending block into + // the current state. + ControlFrame::Block { target, .. } => { + self.state.join(&target); + } + // On loop end, we have a fall-through. + ControlFrame::Loop { .. } => {} + // On if end, replace the frame with else, and restore the + // state to the if entry state. + // Also, merge and store the exit state by merging the + // current state with the target state. + ControlFrame::If { + entry_state, + target, + .. + } => { + let mut exit_state = self.state.clone(); + exit_state.join(&target); + match self.frames.last_mut() { + Some(ControlFrame::Else { if_exit: slot, .. 
}) => *slot = exit_state, + _ => panic!("If frame must be followed by matching Else frame on the stack"), + } + self.state = entry_state; + } + // On else end, merge the current state with the state + // from both the if and else. + ControlFrame::Else { + if_exit, target, .. + } => { + self.state.join(&target); + self.state.join(&if_exit); + } + } + } + + fn visit_block(&mut self, b: &Block) { + self.frames.push(ControlFrame::Block { + seq_id: b.seq, + target: AbstractState::default(), + }); + } + + fn visit_loop(&mut self, l: &Loop) { + self.frames.push(ControlFrame::Loop { + seq_id: l.seq, + header: self.state.clone(), + }); + } + + fn visit_if_else(&mut self, ie: &IfElse) { + self.state.pop(); + let snapshot = self.state.clone(); + + self.frames.push(ControlFrame::Else { + seq_id: ie.alternative, + if_exit: AbstractState::default(), + target: AbstractState::default(), + }); + self.frames.push(ControlFrame::If { + seq_id: ie.consequent, + entry_state: snapshot, + target: AbstractState::default(), + }); + } + + fn visit_br(&mut self, br: &Br) { + self.join_into_target(br.block); + } + + fn visit_br_if(&mut self, br_if: &BrIf) { + self.state.pop(); + self.join_into_target(br_if.block); + } + + fn visit_br_table(&mut self, bt: &BrTable) { + let index = self.state.pop(); + if bt.blocks.len() >= self.dispatch_target_threshold as usize { + self.dispatch_loads.extend(index.0.iter().copied()); + } + let snapshot = self.state.clone(); + let mut seen: BTreeSet<InstrSeqId> = BTreeSet::new(); + for &t in bt.blocks.iter().chain(std::iter::once(&bt.default)) { + if !seen.insert(t) { + continue; + } + if let Some(frame) = self.frames.iter_mut().rev().find(|f| f.seq_id() == t) { + match frame { + ControlFrame::Loop { header, .. } => header.join(&snapshot), + ControlFrame::Block { target, .. } + | ControlFrame::If { target, .. } + | ControlFrame::Else { target, .. 
} => target.join(&snapshot), + } + } + } + } + + fn visit_load(&mut self, l: &Load) { + self.state.pop(); + if is_byte_load(l) { + self.state.push(Provenance::with(self.pc)); + } else { + self.state.push(Provenance::new()); + } + } + + fn visit_store(&mut self, _: &Store) { + self.pop_push_n(2, 0); + } + + fn visit_const(&mut self, _: &Const) { + self.state.push(Provenance::new()); + } + + fn visit_binop(&mut self, _: &Binop) { + let rhs = self.state.pop(); + let lhs = self.state.pop(); + self.state.push(lhs.joined(&rhs)); + } + + fn visit_unop(&mut self, _: &Unop) { + let v = self.state.pop(); + self.state.push(v); + } + + fn visit_local_get(&mut self, lg: &LocalGet) { + let p = self + .state + .locals + .get(&lg.local) + .cloned() + .unwrap_or_default(); + self.state.push(p); + } + + fn visit_local_set(&mut self, ls: &LocalSet) { + let v = self.state.pop(); + self.state.locals.insert(ls.local, v); + } + + fn visit_local_tee(&mut self, lt: &LocalTee) { + let v = self.state.stack.last().cloned().unwrap_or_default(); + self.state.locals.insert(lt.local, v); + } + + fn visit_global_get(&mut self, _: &GlobalGet) { + self.state.push(Provenance::new()); + } + + fn visit_global_set(&mut self, _: &GlobalSet) { + self.state.pop(); + } + + fn visit_drop(&mut self, _: &Drop) { + self.state.pop(); + } + + fn visit_select(&mut self, _: &Select) { + self.state.pop(); + + let r = self.state.pop(); + let l = self.state.pop(); + self.state.push(l.joined(&r)); + } + + fn visit_call(&mut self, c: &Call) { + let ty_id = self.module.funcs.get(c.func).ty(); + let ty = self.module.types.get(ty_id); + self.pop_push_n(ty.params().len(), ty.results().len()); + } + + fn visit_call_indirect(&mut self, c: &CallIndirect) { + let ty = self.module.types.get(c.ty); + // +1 for the function index popped before the args. 
+ self.pop_push_n(ty.params().len() + 1, ty.results().len()); + } + + fn visit_return_call(&mut self, c: &ReturnCall) { + let ty_id = self.module.funcs.get(c.func).ty(); + let ty = self.module.types.get(ty_id); + self.pop_push_n(ty.params().len(), 0); + } + + fn visit_return_call_indirect(&mut self, c: &ReturnCallIndirect) { + let ty = self.module.types.get(c.ty); + self.pop_push_n(ty.params().len() + 1, 0); + } + + fn visit_return(&mut self, _: &Return) {} + + fn visit_unreachable(&mut self, _: &Unreachable) {} + + fn visit_memory_size(&mut self, _: &MemorySize) { + self.state.push(Provenance::new()); + } + + fn visit_memory_grow(&mut self, _: &MemoryGrow) { + self.state.pop(); + self.state.push(Provenance::new()); + } +} diff --git a/crates/profiler-lib/src/lib.rs b/crates/profiler-lib/src/lib.rs new file mode 100644 index 00000000..0b726576 --- /dev/null +++ b/crates/profiler-lib/src/lib.rs @@ -0,0 +1,31 @@ +mod ai; +mod state; + +use state::State; +use std::sync::LazyLock; + +// TODO: Passing empty bytes is temporary. Whamm currently does not +// offer a mechanism to pass in the target application bytes. +// Prior to hooking the bytes we need to either find the best +// way to accomplish that in Whamm's instrumentation pass or +// create a custom pass in Javy, e.g., through Wizer. Ideally we +// want the former: arguably, having access to the bytes is +// something that other libraries might need. +static STATE: std::sync::LazyLock<State> = + LazyLock::new(|| State::from_bytes(&[]).expect("State initialization to work")); + +/// Returns true if the given index corresponds to the dispatch +/// function according to the heuristics defined in the [`state`] +/// module. +#[unsafe(no_mangle)] +pub extern "C" fn is_dispatch_func(index: u32) -> bool { + STATE.is_dispatch_func(index) +} + +/// Returns true if the given program counter offset is classified as +/// a dispatch load which feeds the index argument for the `br_table` +/// dispatch. 
+#[unsafe(no_mangle)] +pub extern "C" fn is_dispatch_load(fid: u32, pc: u32) -> bool { + STATE.is_dispatch_load(fid, pc) +} diff --git a/crates/profiler-lib/src/state.rs b/crates/profiler-lib/src/state.rs new file mode 100644 index 00000000..1512619b --- /dev/null +++ b/crates/profiler-lib/src/state.rs @@ -0,0 +1,389 @@ +//! Instrumentation state to derive probe insertion. + +use anyhow::{Result, bail}; +use std::collections::BTreeSet; +use walrus::ir::{BrTable, Visitor, dfs_in_order}; +use walrus::{FunctionId, FunctionKind, LocalFunction, ModuleConfig}; + +use crate::ai; + +/// Threshold above which a `br_table` qualifies as the interpreter +/// dispatch loop. +pub const DISPATCH_TARGET_THRESHOLD: u32 = 250; + +pub struct State { + /// Wasm function index, which contains a `br_table` with at least + /// the configured target threshold. There should be a single + /// function which meets this criterion. + pub dispatch_func_idx: u32, + /// Byte offsets of the `i32.load8_u` instructions in the dispatch + /// function whose values feed the dispatch `br_table`'s index. + dispatch_loads: BTreeSet<u32>, +} + +impl State { + /// Construct a `State` from the given Wasm bytes. + pub fn from_bytes(bytes: &[u8]) -> Result<Self> { + Self::from_bytes_with_threshold(bytes, DISPATCH_TARGET_THRESHOLD) + } + + /// Construct a `State` with a custom `br_table` target threshold. + pub(crate) fn from_bytes_with_threshold(bytes: &[u8], threshold: u32) -> Result<Self> { + let module = ModuleConfig::new().parse(bytes)?; + + let candidates: Vec<(u32, FunctionId)> = module + .funcs + .iter() + .enumerate() + .filter_map(|(idx, func)| match &func.kind { + FunctionKind::Local(local) if has_large_br_table(local, threshold) => { + Some((idx as u32, func.id())) + } + _ => None, + }) + .collect(); + + if candidates.len() != 1 { + bail!( + "Unexpected number of dispatch functions. 
Expected 1, found {}", + candidates.len() + ); + } + + let (dispatch_func_idx, dispatch_func_id) = candidates[0]; + + let local = match &module.funcs.get(dispatch_func_id).kind { + FunctionKind::Local(l) => l, + // Mostly for completeness, this should not be possible, + // given the filtering above. + _ => unreachable!("filtered to local functions only"), + }; + let dispatch_loads = ai::analyze(&module, local, threshold); + + Ok(Self { + dispatch_func_idx, + dispatch_loads, + }) + } + + /// Given a function id in the module, return whether it matches + /// the dispatch function heuristics. + pub fn is_dispatch_func(&self, id: u32) -> bool { + self.dispatch_func_idx == id + } + + /// Given a function id and an instruction offset, return whether + /// the instruction at `pc` is the memory load responsible for + /// fetching the next QuickJS opcode. + pub fn is_dispatch_load(&self, id: u32, pc: u32) -> bool { + id == self.dispatch_func_idx && self.dispatch_loads.contains(&pc) + } +} + +/// True iff `func` contains a `br_table` whose target list has at +/// least `threshold` entries. 
+fn has_large_br_table(func: &LocalFunction, threshold: u32) -> bool { + struct Detect { + threshold: u32, + found: bool, + } + impl<'instr> Visitor<'instr> for Detect { + fn visit_br_table(&mut self, br_table: &BrTable) { + if br_table.blocks.len() >= self.threshold as usize { + self.found = true; + } + } + } + let mut d = Detect { + threshold, + found: false, + }; + dfs_in_order(&mut d, func, func.entry_block()); + d.found +} + +#[cfg(test)] +mod tests { + use super::*; + use anyhow::{Result, anyhow}; + use std::collections::HashMap; + use walrus::InstrLocId; + use walrus::Module; + use walrus::ir::Instr; + + fn make(wat: &str, threshold: u32) -> Result<(Module, State)> { + let bytes = wat::parse_str(wat)?; + let module = ModuleConfig::new().parse(&bytes)?; + let state = State::from_bytes_with_threshold(&bytes, threshold)?; + Ok((module, state)) + } + + /// Map every instruction's byte offset to the corresponding + /// instruction in the function at index `fid`. + fn pc2instr(module: &Module, fid: u32) -> Result<HashMap<u32, Instr>> { + let func = module + .funcs + .iter() + .nth(fid as usize) + .ok_or_else(|| anyhow!("no function at index {fid}"))?; + let local = match &func.kind { + FunctionKind::Local(l) => l, + _ => bail!("function at index {fid} is not a local function"), + }; + + #[derive(Default)] + struct Collect { + map: HashMap<u32, Instr>, + } + impl<'i> Visitor<'i> for Collect { + fn visit_instr(&mut self, instr: &'i Instr, loc: &'i InstrLocId) { + self.map.insert(loc.data(), instr.clone()); + } + } + let mut v = Collect::default(); + dfs_in_order(&mut v, local, local.entry_block()); + Ok(v.map) + } + + fn br_table(n: usize) -> String { + let labels = vec!["0"; n].join(" "); + format!("br_table {labels} 0") + } + + fn assert_all_byte_loads(module: &Module, state: &State) -> Result<()> { + let map = pc2instr(module, state.dispatch_func_idx)?; + for &pc in &state.dispatch_loads { + let instr = map + .get(&pc) + .ok_or_else(|| anyhow!("pc {pc} not found in function"))?; + match instr { 
+ Instr::Load(l) => { + if !ai::is_byte_load(l) { + bail!("pc {pc} is a load but not a byte load: {:?}", l.kind); + } + } + other => bail!("pc {pc} is {other:?}, not a load"), + } + } + Ok(()) + } + + #[test] + fn straight_line_load() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) + (block + local.get $p + i32.load8_u + {br_table}))) + "#, + br_table = br_table(3) + ); + let (module, state) = make(&wat, 3)?; + + assert!(state.is_dispatch_func(state.dispatch_func_idx)); + assert_eq!(state.dispatch_loads.len(), 1, "expected one dispatch load"); + assert_all_byte_loads(&module, &state)?; + Ok(()) + } + + #[test] + fn provenance_survives_i32_and() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) + (block + local.get $p + i32.load8_u + i32.const 0xff + i32.and + {br_table}))) + "#, + br_table = br_table(3) + ); + let (module, state) = make(&wat, 3)?; + + assert_eq!( + state.dispatch_loads.len(), + 1, + "and must not drop provenance" + ); + assert_all_byte_loads(&module, &state)?; + Ok(()) + } + + #[test] + fn provenance_flows_through_local_roundtrip() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) (local $byte i32) + (block + local.get $p + i32.load8_u + local.set $byte + local.get $byte + {br_table}))) + "#, + br_table = br_table(3) + ); + let (module, state) = make(&wat, 3)?; + + assert_eq!(state.dispatch_loads.len(), 1); + assert_all_byte_loads(&module, &state)?; + Ok(()) + } + + #[test] + fn if_else_merge_collects_both_loads() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) (local $byte i32) + (block + local.get $p + i32.const 1 + i32.lt_s + if + local.get $p + i32.load8_u offset=0 + local.set $byte + else + local.get $p + i32.load8_u offset=4 + local.set $byte + end + local.get $byte + {br_table}))) + "#, + br_table = br_table(3) + ); + let (module, state) = make(&wat, 3)?; + + 
assert_eq!(state.dispatch_loads.len(), 2, "both loads must be recorded"); + assert_all_byte_loads(&module, &state)?; + Ok(()) + } + + #[test] + fn conditional_value() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) (local $byte i32) + (block + local.get $p + i32.load8_u offset=0 + local.set $byte + local.get $p + i32.const 200 + i32.lt_s + if + local.get $p + i32.load8_u offset=1 + local.set $byte + end + local.get $byte + {br_table}))) + "#, + br_table = br_table(3) + ); + let (module, state) = make(&wat, 3)?; + + assert_eq!(state.dispatch_loads.len(), 2); + assert_all_byte_loads(&module, &state)?; + Ok(()) + } + + #[test] + fn non_byte_load_is_not_recorded() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) + (block + local.get $p + i32.load + {br_table}))) + "#, + br_table = br_table(3) + ); + let (_module, state) = make(&wat, 3)?; + + assert!( + state.dispatch_loads.is_empty(), + "non-byte load must not be recorded" + ); + Ok(()) + } + + #[test] + fn br_table_without_load_is_empty() -> Result<()> { + let wat = format!( + r#" + (module + (func (param $p i32) + (block + i32.const 0 + {br_table}))) + "#, + br_table = br_table(3) + ); + let (_module, state) = make(&wat, 3)?; + + assert!(state.dispatch_loads.is_empty()); + Ok(()) + } + + #[test] + fn correctly_identifies_dispatch_func() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) + (block + local.get $p + i32.load8_u + {br_table}))) + "#, + br_table = br_table(3) + ); + let (_module, state) = make(&wat, 3)?; + + assert!(state.is_dispatch_func(state.dispatch_func_idx)); + assert!(!state.is_dispatch_func(state.dispatch_func_idx + 1)); + Ok(()) + } + + #[test] + fn dispatch_load_is_scoped_to_dispatch_func() -> Result<()> { + let wat = format!( + r#" + (module + (memory 1) + (func (param $p i32) + (block + local.get $p + i32.load8_u + {br_table}))) + "#, + br_table = br_table(3) + ); + let 
(_module, state) = make(&wat, 3)?; + + let pc = *state.dispatch_loads.iter().next().unwrap(); + assert!(state.is_dispatch_load(state.dispatch_func_idx, pc)); + assert!(!state.is_dispatch_load(state.dispatch_func_idx + 1, pc)); + Ok(()) + } +}