Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,13 @@ impl<Op: GraphOp> FragmentBuilder<Op> {
.map(|input| self.resolve_input_key(input))
.collect();

let global_op_key = GlobalOpKey {
primitive: op.clone(),
inputs: global_inputs,
mode: mode.clone(),
};
let global_op_key = Arc::new(GlobalOpKey::new(op.clone(), global_inputs, mode.clone()));

let mut output_ids = Vec::with_capacity(n_outputs);
for slot in 0..n_outputs {
let val_id = self.vals.len();
let key = GlobalValKey::Derived {
op: global_op_key.clone(),
op: Arc::clone(&global_op_key),
output_slot: slot as u8,
};
self.vals.push(ValNode {
Expand Down
17 changes: 9 additions & 8 deletions src/materialize.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::sync::Arc;

use crate::resolve::{ResolvedView, ValDef};
use crate::traits::GraphOp;
Expand Down Expand Up @@ -30,7 +31,7 @@ pub struct MaterializedGraph<Op: GraphOp> {
struct Materializer<'a, Op: GraphOp> {
view: &'a ResolvedView<Op>,
val_map: HashMap<GlobalValKey<Op>, usize>,
op_map: HashMap<GlobalOpKey<Op>, usize>,
op_map: HashMap<Arc<GlobalOpKey<Op>>, usize>,
vals: Vec<MaterializedVal<Op>>,
ops: Vec<MaterializedOp<Op>>,
input_keys: Vec<GlobalValKey<Op>>,
Expand Down Expand Up @@ -89,11 +90,11 @@ impl<'a, Op: GraphOp> Materializer<'a, Op> {
mode: OpMode,
output_slot: usize,
) -> usize {
let op_key = GlobalOpKey {
primitive: op.clone(),
inputs: input_keys.clone(),
mode: mode.clone(),
};
let op_key = Arc::new(GlobalOpKey::new(
op.clone(),
input_keys.clone(),
mode.clone(),
));

if self.op_map.contains_key(&op_key) {
let output_key = GlobalValKey::Derived {
Expand All @@ -115,7 +116,7 @@ impl<'a, Op: GraphOp> Materializer<'a, Op> {

let materialized_inputs = input_keys.iter().map(|input| self.visit(input)).collect();
let op_index = self.ops.len();
self.op_map.insert(op_key.clone(), op_index);
self.op_map.insert(Arc::clone(&op_key), op_index);
self.ops.push(MaterializedOp {
op: op.clone(),
inputs: materialized_inputs,
Expand All @@ -125,7 +126,7 @@ impl<'a, Op: GraphOp> Materializer<'a, Op> {

for slot in 0..op.n_outputs() {
let output_key = GlobalValKey::Derived {
op: op_key.clone(),
op: Arc::clone(&op_key),
output_slot: slot as u8,
};
let val_index = self.vals.len();
Expand Down
161 changes: 155 additions & 6 deletions src/types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use crate::traits::GraphOp;

/// Fragment-local value identifier.
Expand All @@ -21,19 +25,164 @@ pub enum ValRef<Op: GraphOp> {
}

/// Cross-fragment structural identity for a value.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Clone, Debug)]
pub enum GlobalValKey<Op: GraphOp> {
Input(Op::InputKey),
Derived {
op: GlobalOpKey<Op>,
/// Shared structural identity of the operation that produced this value.
op: Arc<GlobalOpKey<Op>>,
output_slot: u8,
},
}

/// Cross-fragment structural identity for an operation.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
///
/// `GlobalOpKey` caches a structural fingerprint so maps keyed by recursively
/// derived values can avoid repeatedly re-hashing the whole input tree. Equality
/// still checks the full structure after the fingerprint prefilter.
#[derive(Clone, Debug)]
pub struct GlobalOpKey<Op: GraphOp> {
pub primitive: Op,
pub inputs: Vec<GlobalValKey<Op>>,
pub mode: OpMode,
primitive: Op,
inputs: Vec<GlobalValKey<Op>>,
mode: OpMode,
/// Cached hash prefilter for recursively structural keys.
///
/// This is not an identity proof: equality still compares the full
/// structure after the fingerprint matches, so hash collisions remain
/// correct.
fingerprint: u64,
}

impl<Op: GraphOp> GlobalOpKey<Op> {
/// Builds an operation key and precomputes its structural fingerprint.
///
/// # Examples
///
/// ```ignore
/// use computegraph::{GlobalOpKey, GlobalValKey, GraphOp, OpMode};
///
/// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
/// enum Op {
/// Add,
/// }
///
/// impl GraphOp for Op {
/// type Operand = f64;
/// type Context = ();
/// type InputKey = &'static str;
///
/// fn n_inputs(&self) -> usize { 2 }
/// fn n_outputs(&self) -> usize { 1 }
/// }
///
/// let key = GlobalOpKey::new(
/// Op::Add,
/// vec![GlobalValKey::Input("x"), GlobalValKey::Input("y")],
/// OpMode::Primal,
/// );
/// assert_eq!(key.inputs().len(), 2);
/// ```
pub fn new(primitive: Op, inputs: Vec<GlobalValKey<Op>>, mode: OpMode) -> Self {
let fingerprint = fingerprint_op(&primitive, &inputs, &mode);
Self {
primitive,
inputs,
mode,
fingerprint,
}
}

/// Returns the cached structural fingerprint.
pub fn fingerprint(&self) -> u64 {
self.fingerprint
}

/// Returns the operation primitive.
pub fn primitive(&self) -> &Op {
&self.primitive
}

/// Returns the structural input keys.
pub fn inputs(&self) -> &[GlobalValKey<Op>] {
&self.inputs
}

/// Returns whether this operation belongs to the primal or linear graph.
pub fn mode(&self) -> &OpMode {
&self.mode
}
}

impl<Op: GraphOp> PartialEq for GlobalValKey<Op> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Input(lhs), Self::Input(rhs)) => lhs == rhs,
(
Self::Derived {
op: lhs_op,
output_slot: lhs_slot,
},
Self::Derived {
op: rhs_op,
output_slot: rhs_slot,
},
) => {
lhs_slot == rhs_slot
&& (Arc::ptr_eq(lhs_op, rhs_op) || lhs_op.as_ref() == rhs_op.as_ref())
}
_ => false,
}
}
}

impl<Op: GraphOp> Eq for GlobalValKey<Op> {}

impl<Op: GraphOp> Hash for GlobalValKey<Op> {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
Self::Input(key) => {
0u8.hash(state);
key.hash(state);
}
Self::Derived { op, output_slot } => {
1u8.hash(state);
op.fingerprint.hash(state);
output_slot.hash(state);
}
}
}
}

impl<Op: GraphOp> PartialEq for GlobalOpKey<Op> {
fn eq(&self, other: &Self) -> bool {
self.fingerprint == other.fingerprint
&& self.primitive == other.primitive
&& self.mode == other.mode
&& self.inputs == other.inputs
}
}

impl<Op: GraphOp> Eq for GlobalOpKey<Op> {}

impl<Op: GraphOp> Hash for GlobalOpKey<Op> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.fingerprint.hash(state);
}
}

fn fingerprint_op<Op: GraphOp>(primitive: &Op, inputs: &[GlobalValKey<Op>], mode: &OpMode) -> u64 {
let mut hasher = DefaultHasher::new();
primitive.hash(&mut hasher);
mode.hash(&mut hasher);
inputs.len().hash(&mut hasher);
for input in inputs {
fingerprint_val(input).hash(&mut hasher);
}
hasher.finish()
}

fn fingerprint_val<Op: GraphOp>(key: &GlobalValKey<Op>) -> u64 {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
hasher.finish()
}
51 changes: 41 additions & 10 deletions tests/scalar_tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
mod common;

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use common::ScalarOp;
Expand Down Expand Up @@ -74,21 +76,50 @@ fn interner_get_returns_none_for_unknown() {
fn interner_derived_key() {
let mut interner = KeyInterner::<ScalarOp>::new();
let key = GlobalValKey::<ScalarOp>::Derived {
op: GlobalOpKey {
primitive: ScalarOp::Add,
inputs: vec![
op: Arc::new(GlobalOpKey::new(
ScalarOp::Add,
vec![
GlobalValKey::Input("x".to_string()),
GlobalValKey::Input("y".to_string()),
],
mode: OpMode::Primal,
},
OpMode::Primal,
)),
output_slot: 0,
};
let id = interner.intern(key.clone());
assert_eq!(interner.resolve(id), &key);
assert_eq!(interner.get(&key), Some(id));
}

#[test]
fn derived_keys_with_distinct_op_arcs_are_structurally_equal() {
let inputs = vec![
GlobalValKey::Input("x".to_string()),
GlobalValKey::Input("y".to_string()),
];
let lhs = GlobalValKey::<ScalarOp>::Derived {
op: Arc::new(GlobalOpKey::new(
ScalarOp::Add,
inputs.clone(),
OpMode::Primal,
)),
output_slot: 0,
};
let rhs = GlobalValKey::<ScalarOp>::Derived {
op: Arc::new(GlobalOpKey::new(ScalarOp::Add, inputs, OpMode::Primal)),
output_slot: 0,
};

assert_eq!(lhs, rhs);
assert_eq!(hash_key(&lhs), hash_key(&rhs));
}

fn hash_key(key: &GlobalValKey<ScalarOp>) -> u64 {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
hasher.finish()
}

// === Fragment tests ===

#[test]
Expand Down Expand Up @@ -125,14 +156,14 @@ fn fragment_builder_add_op() {

// Verify GlobalValKey structure
let expected_key = GlobalValKey::Derived {
op: GlobalOpKey {
primitive: ScalarOp::Add,
inputs: vec![
op: Arc::new(GlobalOpKey::new(
ScalarOp::Add,
vec![
GlobalValKey::Input("x".to_string()),
GlobalValKey::Input("y".to_string()),
],
mode: OpMode::Primal,
},
OpMode::Primal,
)),
output_slot: 0,
};
assert_eq!(frag.vals()[sum_id].key, expected_key);
Expand Down
Loading