Skip to content
21 changes: 19 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ integration-test: NORMALIZE ?= jq -S -e -f $(TESTDIR)/../normalise-filter.jq
integration-test: DIFF ?= | diff -
integration-test:
errors=""; \
report() { echo "$$1: $$2"; errors="$$errors\n$$1: $$2"; }; \
report() { printf "%s: %s\n" "$$1" "$$2"; errors="$$errors\n$$1: $$2"; }; \
for rust in ${TESTS}; do \
target=$${rust%.rs}.smir.json; \
dir=$$(dirname $${rust}); \
Expand All @@ -37,12 +37,29 @@ integration-test:
&& rm $${target} \
|| report "$$rust" "Unexpected json output"; \
done; \
[ -z "$$errors" ] || (echo "===============\nFAILING TESTS:$$errors"; exit 1)
[ -z "$$errors" ] || (printf "===============\nFAILING TESTS:%s\n" "$$errors"; exit 1)


# Re-run the integration tests with DIFF overridden to a shell redirect (">").
# NOTE(review): presumably this overwrites the expected-output files instead of
# diffing against them -- confirm against the full integration-test recipe.
# $(MAKE) is used instead of a literal `make` so that command-line flags, the
# jobserver and a non-default make binary are passed on to the sub-make.
golden:
	$(MAKE) integration-test DIFF=">"

# Run the graph converter over every Rust test program with the lang_start
# filter enabled (SKIP_LANG_START=1) and its self-checks on (ASSERT_FILTER=1).
# Each generated .smir.d2 file is removed after a successful conversion; a
# failed conversion is reported immediately and again in a final summary, and
# the target exits non-zero if anything failed.
#
# The failure list is accumulated with a real newline (via command
# substitution) rather than a literal "\n": printf's %s does not interpret
# escapes in its arguments, so the summary would otherwise be printed on a
# single line with visible "\n" sequences.
.PHONY: test-skip-lang-start
test-skip-lang-start: TESTS ?= $(shell find $(TESTDIR) -type f -name "*.rs")
test-skip-lang-start: SMIR ?= cargo run -- --d2 "-Zno-codegen"
test-skip-lang-start:
	errors=""; \
	report() { printf "FAIL: %s: %s\n" "$$1" "$$2"; errors="$$errors$$(printf '\n%s: %s' "$$1" "$$2")"; }; \
	for rust in ${TESTS}; do \
		dir=$$(dirname "$${rust}"); \
		name=$$(basename "$${rust}" .rs); \
		d2="$${dir}/$${name}.smir.d2"; \
		echo "$$rust"; \
		SKIP_LANG_START=1 ASSERT_FILTER=1 ${SMIR} --out-dir "$${dir}" "$${rust}" \
			|| { report "$$rust" "Conversion failed"; continue; }; \
		rm -f "$${d2}"; \
	done; \
	[ -z "$$errors" ] || (printf "===============\nFAILING TESTS:%s\n" "$$errors"; exit 1)

format:
cargo fmt
bash -O globstar -c 'nixfmt **/*.nix'
Expand Down
189 changes: 188 additions & 1 deletion src/mk_graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
//! This module provides functionality to generate graph visualizations
//! of Rust's MIR in various formats (DOT, D2).

use std::collections::{HashMap, HashSet, VecDeque};
use std::fs::File;
use std::io::{self, Write};

use crate::compat::middle::ty::TyCtxt;
use crate::compat::output::{mir_output_path, OutputDest};
use crate::printer::collect_smir;
use crate::compat::stable_mir::mir::{ConstOperand, Operand, TerminatorKind};
use crate::printer::{collect_smir, Item, MonoItemKind};

// Sub-modules
pub mod context;
Expand All @@ -21,6 +23,191 @@ pub use context::GraphContext;
pub use index::{AllocEntry, AllocIndex, AllocKind, TypeIndex};
pub use util::GraphLabelString;

// =============================================================================
// Item Filtering
// =============================================================================

/// A predicate that identifies items to exclude from graph output.
///
/// Each variant corresponds to an environment variable that enables it.
/// The type is a plain tag, so it is cheap to copy and can be compared or
/// used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum ItemFilter {
    /// Exclude `std::rt::lang_start` and items only reachable through it.
    ///
    /// Enabled when the `SKIP_LANG_START` environment variable is set
    /// (conventionally `SKIP_LANG_START=1`; the value itself is not inspected).
    LangStart,
}

impl ItemFilter {
/// Return the set of filters currently enabled via environment variables.
pub fn enabled() -> Vec<ItemFilter> {
[std::env::var("SKIP_LANG_START")
.ok()
.map(|_| Self::LangStart)]
.into_iter()
.flatten()
.collect()
}

/// Compute the set of symbol names this filter wants to exclude.
pub fn compute_exclusions(&self, items: &[Item], ctx: &GraphContext) -> HashSet<String> {
match self {
ItemFilter::LangStart => compute_lang_start_exclusions(items, ctx),
}
}

/// Apply all enabled filters: collect exclusions, then prune both
/// `items` and `ctx.functions` in one pass.
///
/// After this call, `ctx.resolve_call_target()` returns `None` for any
/// excluded function, so renderers don't need a separate exclusion set.
///
/// When `ASSERT_FILTER=1` is set (intended for integration tests), this
/// asserts that each filter actually matched something and that no
/// matching items survive after filtering.
pub fn apply_all(items: &mut Vec<Item>, ctx: &mut GraphContext) {
let filters = Self::enabled();
if filters.is_empty() {
return;
}
let assert_mode = std::env::var("ASSERT_FILTER").is_ok();
let mut excluded = HashSet::new();
for filter in &filters {
let filter_excluded = filter.compute_exclusions(items, ctx);
// The precondition assert checks that the test input actually
// contains items this filter targets. For LangStart, this holds
// for any crate with `fn main` because rustc always emits
// `std::rt::lang_start` as the runtime entry wrapper. If a test
// program is a library crate (no main), lang_start won't be
// present and this assert will fire; in that case, either skip
// the lib crate in test-skip-lang-start or gate LangStart on
// the presence of a main function.
if assert_mode {
assert!(
!filter_excluded.is_empty(),
"ASSERT_FILTER: {:?} matched no items. \
If the test input is a library crate (no fn main), \
std::rt::lang_start won't be present; either exclude \
lib crates from test-skip-lang-start or adjust the \
filter precondition.",
filter
);
}
excluded.extend(filter_excluded);
}
items.retain(|i| !excluded.contains(&i.symbol_name));
ctx.functions.retain(|_, name| !excluded.contains(name));
if assert_mode {
for filter in &filters {
assert!(
!filter.survives(items),
"ASSERT_FILTER: {:?} items survived filtering",
filter
);
}
}
}

/// Check whether any items matching this filter remain after filtering.
fn survives(&self, items: &[Item]) -> bool {
match self {
ItemFilter::LangStart => items
.iter()
.any(|i| is_std_rt_lang_start(&i.mono_item_kind)),
}
}
}

/// Compute the set of symbol names to exclude from graph rendering.
///
/// Excludes `std::rt::lang_start` items and items uniquely downstream
/// of them (i.e., only reachable through `lang_start` in the call graph).
///
/// The algorithm:
/// 1. Build a call graph from `Call` terminators whose callee is a constant
///    that resolves to a name through `ctx.functions`
/// 2. Identify `std::rt::lang_start` seed items (via demangled name of MonoItemFn)
/// 3. Find entry-point items (not called by any *other* item; a function that
///    only calls itself still counts as an entry point)
/// 4. BFS from non-seed entry points, not entering seed nodes
/// 5. Everything not reachable gets excluded
///
/// The returned set contains names taken from both `items` and the values of
/// `ctx.functions`.
///
/// Known limitation: a group of mutually recursive functions that nothing
/// outside the group calls has no entry point under rule 3 and is therefore
/// excluded as a whole.
fn compute_lang_start_exclusions(items: &[Item], ctx: &GraphContext) -> HashSet<String> {
    // Forward call graph: symbol name of every function item that has a body
    // -> names of the functions its Call terminators resolve to.
    let mut call_graph: HashMap<&str, Vec<&str>> = HashMap::new();
    for item in items {
        if let MonoItemKind::MonoItemFn {
            body: Some(body), ..
        } = &item.mono_item_kind
        {
            let callees: Vec<&str> = body
                .blocks
                .iter()
                .filter_map(|block| match &block.terminator.kind {
                    TerminatorKind::Call {
                        func: Operand::Constant(ConstOperand { const_, .. }),
                        ..
                    } => ctx.functions.get(&const_.ty()).map(|s| s.as_str()),
                    _ => None,
                })
                .collect();
            call_graph.insert(item.symbol_name.as_str(), callees);
        }
    }

    // Seed items: those whose demangled MonoItemFn name contains "std::rt::lang_start".
    let seed_names: HashSet<&str> = items
        .iter()
        .filter(|item| is_std_rt_lang_start(&item.mono_item_kind))
        .map(|item| item.symbol_name.as_str())
        .collect();

    // Names called by some *other* item. Self-edges are skipped on purpose:
    // otherwise a directly recursive root (e.g. a recursive `main`) would look
    // like it has a caller, never become an entry point, and be excluded
    // together with everything reachable only from it.
    let has_callers: HashSet<&str> = call_graph
        .iter()
        .flat_map(|(&caller, callees)| {
            callees
                .iter()
                .copied()
                .filter(move |&callee| callee != caller)
        })
        .collect();

    // BFS from the non-seed entry points (items with no external callers).
    let mut reachable: HashSet<&str> = HashSet::new();
    let mut queue: VecDeque<&str> = VecDeque::new();

    for item in items {
        let name = item.symbol_name.as_str();
        let is_entry = !has_callers.contains(name);
        // `insert` returns false for a name that is already queued.
        if is_entry && !seed_names.contains(name) && reachable.insert(name) {
            queue.push_back(name);
        }
    }

    while let Some(name) = queue.pop_front() {
        if let Some(callees) = call_graph.get(name) {
            for &callee in callees {
                // Never walk into a seed; enqueue each callee at most once.
                if !seed_names.contains(callee) && reachable.insert(callee) {
                    queue.push_back(callee);
                }
            }
        }
    }

    // Everything NOT reachable should be excluded.
    let all_names: HashSet<&str> = items
        .iter()
        .map(|i| i.symbol_name.as_str())
        .chain(ctx.functions.values().map(|s| s.as_str()))
        .collect();

    all_names
        .difference(&reachable)
        .map(|s| s.to_string())
        .collect()
}

/// Tell whether `kind` is a function item whose demangled name mentions
/// `std::rt::lang_start`.
///
/// The test is a substring match on the `MonoItemFn` name, so it accepts:
/// - `std::rt::lang_start::<()>` (the runtime entry point)
/// - `std::rt::lang_start::<()>::{closure#0}` (its closure)
/// - `<{closure@std::rt::lang_start<()>::{closure#0}} as ...>::call_once` (trait impls referencing it)
/// - `std::ptr::drop_in_place::<{closure@std::rt::lang_start<()>::{closure#0}}>` (drop glue)
///
/// A user-defined `lang_start` such as `crate1::something::lang_start` is not
/// matched, and neither is any item kind other than `MonoItemFn`.
fn is_std_rt_lang_start(kind: &MonoItemKind) -> bool {
    match kind {
        MonoItemKind::MonoItemFn { name, .. } => name.contains("std::rt::lang_start"),
        _ => false,
    }
}

// =============================================================================
// Entry Points
// =============================================================================
Expand Down
10 changes: 6 additions & 4 deletions src/mk_graph/output/d2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@
use crate::compat::stable_mir;
use stable_mir::mir::TerminatorKind;

use crate::printer::SmirJson;
use crate::MonoItemKind;
use crate::printer::{MonoItemKind, SmirJson};

use crate::mk_graph::context::GraphContext;
use crate::mk_graph::util::{
escape_d2, is_unqualified, name_lines, short_name, terminator_targets,
};
use crate::mk_graph::ItemFilter;

impl SmirJson {
/// Convert the MIR to D2 diagram format
pub fn to_d2_file(self) -> String {
let ctx = GraphContext::from_smir(&self);
pub fn to_d2_file(mut self) -> String {
let mut ctx = GraphContext::from_smir(&self);
ItemFilter::apply_all(&mut self.items, &mut ctx);

let mut output = String::new();

output.push_str("direction: right\n\n");
Expand Down
7 changes: 4 additions & 3 deletions src/mk_graph/output/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ use crate::MonoItemKind;

use crate::mk_graph::context::GraphContext;
use crate::mk_graph::util::{block_name, is_unqualified, name_lines, short_name, GraphLabelString};
use crate::mk_graph::ItemFilter;

impl SmirJson {
/// Convert the MIR to DOT (Graphviz) format
pub fn to_dot_file(self) -> String {
pub fn to_dot_file(mut self) -> String {
let mut bytes = Vec::new();

// Build context BEFORE consuming self
let ctx = GraphContext::from_smir(&self);
let mut ctx = GraphContext::from_smir(&self);
ItemFilter::apply_all(&mut self.items, &mut ctx);

{
let mut writer = DotWriter::from(&mut bytes);
Expand Down