Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 96 additions & 6 deletions crates/tensor4all-treetn/src/treetn/fit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,99 @@ where
.ok_or_else(|| anyhow::anyhow!("Tensor for node {:?} not found in {}", node, tree_name))
}

fn tensors_share_contractable_index<T>(left: &T, right: &T) -> bool
where
T: TensorLike,
{
let left_indices = left.external_indices();
let right_indices = right.external_indices();
left_indices.iter().any(|left_index| {
right_indices
.iter()
.any(|right_index| left_index.is_contractable(right_index))
})
}

fn tensor_connected_components<T>(tensors: &[&T]) -> Vec<Vec<usize>>
where
T: TensorLike,
{
let mut visited = vec![false; tensors.len()];
let mut components = Vec::new();

for start in 0..tensors.len() {
if visited[start] {
continue;
}

let mut component = Vec::new();
let mut stack = vec![start];
visited[start] = true;

while let Some(current) = stack.pop() {
component.push(current);
for candidate in 0..tensors.len() {
if visited[candidate] {
continue;
}
if tensors_share_contractable_index(tensors[current], tensors[candidate]) {
visited[candidate] = true;
stack.push(candidate);
}
}
}

component.sort_unstable();
components.push(component);
}

components
}

fn contract_fit_tensor_refs<T>(tensors: &[&T]) -> Result<T>
where
T: TensorLike,
<T::Index as IndexLike>::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync,
{
if tensors.is_empty() {
return Err(anyhow::anyhow!(
"fit contraction requires at least one tensor"
));
}

let components = tensor_connected_components(tensors);
if components.len() == 1 {
return T::contract(tensors).map_err(|e| anyhow::anyhow!("contract failed: {}", e));
}

let mut contracted_components = Vec::with_capacity(components.len());
for component in components {
let component_refs = component
.iter()
.map(|&tensor_index| tensors[tensor_index])
.collect::<Vec<_>>();
let contracted = if component_refs.len() == 1 {
component_refs[0].clone()
} else {
T::contract(&component_refs)
.map_err(|e| anyhow::anyhow!("component contract failed: {}", e))?
};
contracted_components.push(contracted);
}

let mut iter = contracted_components.into_iter();
let mut result = iter
.next()
.ok_or_else(|| anyhow::anyhow!("fit contraction produced no components"))?;
for component in iter {
result = result
.outer_product(&component)
.map_err(|e| anyhow::anyhow!("component outer product failed: {}", e))?;
}

Ok(result)
}

// ============================================================================
// FitEnvironment: Environment tensor cache
// ============================================================================
Expand Down Expand Up @@ -424,8 +517,7 @@ where

// A, B, and C must form one connected local environment.
let c_conj = tensor_c.conj();
let env = T::contract(&[tensor_a, tensor_b, &c_conj])
.map_err(|e| anyhow::anyhow!("contract failed: {}", e))?;
let env = contract_fit_tensor_refs(&[tensor_a, tensor_b, &c_conj])?;

if let Some(started) = started {
with_fit_profile(|profile| {
Expand Down Expand Up @@ -470,8 +562,7 @@ where
let c_conj = tensor_c.conj();
let mut tensor_refs: Vec<&T> = vec![tensor_a, tensor_b, &c_conj];
tensor_refs.extend(child_envs.iter());
let result =
T::contract(&tensor_refs).map_err(|e| anyhow::anyhow!("contract failed: {}", e))?;
let result = contract_fit_tensor_refs(&tensor_refs)?;

if let Some(started) = started {
with_fit_profile(|profile| {
Expand Down Expand Up @@ -673,8 +764,7 @@ where
let contract_started = fit_profile_enabled().then(Instant::now);
let mut tensor_refs: Vec<&T> = vec![a_u, b_u, a_v, b_v];
tensor_refs.extend(env_tensors.iter());
let ab_uv =
T::contract(&tensor_refs).map_err(|e| anyhow::anyhow!("contract failed: {}", e))?;
let ab_uv = contract_fit_tensor_refs(&tensor_refs)?;
if let Some(contract_started) = contract_started {
with_fit_profile(|profile| {
profile.two_site_contract_time += contract_started.elapsed();
Expand Down
12 changes: 7 additions & 5 deletions crates/tensor4all-treetn/src/treetn/fit/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ fn test_contract_fit_positive_sweeps_do_not_skip_without_truncation_options() {
}

#[test]
fn test_contract_fit_rejects_leaf_site_space_that_contracts_away() {
fn test_contract_fit_handles_leaf_site_space_that_contracts_away() {
let left = DynIndex::new_dyn(2);
let right = DynIndex::new_dyn(2);
let shared_left = DynIndex::new_dyn(2);
Expand Down Expand Up @@ -492,14 +492,16 @@ fn test_contract_fit_rejects_leaf_site_space_that_contracts_away() {
)
.unwrap();

let err = contract_fit(
let fitted = contract_fit(
&tn_a,
&tn_b,
&"A".to_string(),
FitContractionOptions::new(1),
)
.unwrap_err()
.to_string();
.unwrap();

assert!(err.contains("Disconnected tensor network"));
assert_eq!(fitted.node_count(), 3);
let fitted_dense = fitted.to_dense().unwrap();
let expected_dense = tn_a.contract_naive(&tn_b).unwrap();
assert!(fitted_dense.distance(&expected_dense).unwrap() < 1e-10);
}
79 changes: 76 additions & 3 deletions crates/tensor4all-treetn/src/treetn/partial_contraction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,9 @@ where
validate_union_topology(&node_names, &union_edges)?;

let structural_result = (|| {
let aligned_a = align_to_union_topology(a, &node_names, &union_edges)?;
let aligned_b = align_to_union_topology(b, &node_names, &union_edges)?;
let mut aligned_a = align_to_union_topology(a, &node_names, &union_edges)?;
let mut aligned_b = align_to_union_topology(b, &node_names, &union_edges)?;
connect_nodewise_outer_product_spectators(&mut aligned_a, &mut aligned_b)?;
contract(&aligned_a, &aligned_b, center, options.clone())
.context("partial_contract: failed contraction after aligning mismatched topologies")
})();
Expand Down Expand Up @@ -732,6 +733,77 @@ where
Ok(())
}

fn has_contractable_index_pair(left: &[DynIndex], right: &[DynIndex]) -> bool {
left.iter().any(|idx_left| {
right
.iter()
.any(|idx_right| idx_left.is_contractable(idx_right))
})
}

fn connect_nodewise_outer_product_spectators<V>(
a: &mut TreeTN<TensorDynLen, V>,
b: &mut TreeTN<TensorDynLen, V>,
) -> Result<()>
where
V: Clone + Hash + Eq + Send + Sync + Debug + Ord,
<DynIndex as IndexLike>::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync,
{
let mut node_names = a.node_names();
node_names.retain(|node_name| b.node_index(node_name).is_some());
node_names.sort();

for node_name in node_names {
let node_a = a
.node_index(&node_name)
.context("partial_contract: node disappeared while adding dummy contraction links")?;
let node_b = b.node_index(&node_name).context(
"partial_contract: matching node disappeared while adding dummy contraction links",
)?;

let tensor_a = a
.tensor(node_a)
.cloned()
.context("partial_contract: missing tensor while adding dummy contraction links")?;
let tensor_b = b.tensor(node_b).cloned().context(
"partial_contract: missing matching tensor while adding dummy contraction links",
)?;

if has_contractable_index_pair(&tensor_a.external_indices(), &tensor_b.external_indices()) {
continue;
}

let (dummy_a, dummy_b) = DynIndex::create_dummy_link_pair();
let dummy_tensor_a =
<TensorDynLen as TensorConstructionLike>::ones(std::slice::from_ref(&dummy_a))
.context("partial_contract: failed to build dummy contraction link")?;
let dummy_tensor_b =
<TensorDynLen as TensorConstructionLike>::ones(std::slice::from_ref(&dummy_b))
.context("partial_contract: failed to build matching dummy contraction link")?;
let expanded_a = tensor_a
.outer_product(&dummy_tensor_a)
.context("partial_contract: failed to attach dummy contraction link")?;
let expanded_b = tensor_b
.outer_product(&dummy_tensor_b)
.context("partial_contract: failed to attach matching dummy contraction link")?;

a.replace_tensor(node_a, expanded_a)
.context(
"partial_contract: failed to replace tensor after adding dummy contraction link",
)?
.context("partial_contract: node disappeared after adding dummy contraction link")?;
b.replace_tensor(node_b, expanded_b)
.context(
"partial_contract: failed to replace matching tensor after adding dummy contraction link",
)?
.context(
"partial_contract: matching node disappeared after adding dummy contraction link",
)?;
}

Ok(())
}

/// Partially contract two TreeTNs according to the given specification.
///
/// # Arguments
Expand Down Expand Up @@ -795,7 +867,7 @@ where
{
validate_partial_contraction_spec(a, b, spec)?;

let (a_modified, mut b_modified, restore_from, restore_to) =
let (mut a_modified, mut b_modified, restore_from, restore_to) =
apply_diagonal_pairs(a, b, &spec.diagonal_pairs)?;

for (idx_a, idx_b) in &spec.contract_pairs {
Expand All @@ -810,6 +882,7 @@ where

let mut result = if a_modified.same_topology(&b_modified) {
align_contract_pair_site_nodes(&a_modified, &mut b_modified, &spec.contract_pairs)?;
connect_nodewise_outer_product_spectators(&mut a_modified, &mut b_modified)?;
contract(&a_modified, &b_modified, center, options)
.context("partial_contract: contraction failed")?
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,69 @@ fn test_partial_contract_contract_only() {
assert_eq!(result.external_indices().len(), 2);
}

#[test]
fn partial_contract_fit_inserts_dummy_links_for_nodewise_outer_product_spectators() {
let k_left = DynIndex::new_dyn(2);
let k_right = DynIndex::new_dyn(2);
let i = DynIndex::new_dyn(2);
let j = DynIndex::new_dyn(2);
let bond_left = DynIndex::new_dyn(2);
let bond_right = DynIndex::new_dyn(2);

let a0 = TensorDynLen::from_dense(
vec![k_left.clone(), bond_left.clone()],
vec![1.0, 2.0, 3.0, 4.0],
)
.unwrap();
let a1 = TensorDynLen::from_dense(vec![bond_left.clone(), i.clone()], vec![5.0, 6.0, 7.0, 8.0])
.unwrap();
let b0 = TensorDynLen::from_dense(
vec![k_right.clone(), bond_right.clone()],
vec![0.5, 1.5, 2.5, 3.5],
)
.unwrap();
let b1 = TensorDynLen::from_dense(
vec![bond_right.clone(), j.clone()],
vec![1.0, -1.0, 2.0, -2.0],
)
.unwrap();
let lhs = TreeTN::<TensorDynLen, String>::from_tensors(
vec![a0, a1],
vec!["left".to_string(), "right".to_string()],
)
.unwrap();
let rhs = TreeTN::<TensorDynLen, String>::from_tensors(
vec![b0, b1],
vec!["left".to_string(), "right".to_string()],
)
.unwrap();
let spec = PartialContractionSpec {
contract_pairs: vec![(k_left.clone(), k_right.clone())],
diagonal_pairs: vec![],
output_order: None,
};

let rhs_aligned = rhs.replaceind(&k_right, &k_left).unwrap();
let reference = tensordot(
&lhs.contract_to_tensor().unwrap(),
&rhs_aligned.contract_to_tensor().unwrap(),
&[(k_left.clone(), k_left)],
)
.unwrap();
let actual = partial_contract(
&lhs,
&rhs,
&spec,
&"right".to_string(),
ContractionOptions::fit(),
)
.unwrap()
.to_dense()
.unwrap();

assert!(actual.distance(&reference).unwrap() < 1.0e-10);
}

#[test]
fn test_partial_contract_empty_spec() {
// Empty spec: no contract, no diagonal pair → full outer product
Expand Down
Loading