diff --git a/crates/tensor4all-treetn/src/treetn/fit.rs b/crates/tensor4all-treetn/src/treetn/fit.rs index 09d9c3df..7f4bad4f 100644 --- a/crates/tensor4all-treetn/src/treetn/fit.rs +++ b/crates/tensor4all-treetn/src/treetn/fit.rs @@ -149,6 +149,99 @@ where .ok_or_else(|| anyhow::anyhow!("Tensor for node {:?} not found in {}", node, tree_name)) } +fn tensors_share_contractable_index(left: &T, right: &T) -> bool +where + T: TensorLike, +{ + let left_indices = left.external_indices(); + let right_indices = right.external_indices(); + left_indices.iter().any(|left_index| { + right_indices + .iter() + .any(|right_index| left_index.is_contractable(right_index)) + }) +} + +fn tensor_connected_components(tensors: &[&T]) -> Vec> +where + T: TensorLike, +{ + let mut visited = vec![false; tensors.len()]; + let mut components = Vec::new(); + + for start in 0..tensors.len() { + if visited[start] { + continue; + } + + let mut component = Vec::new(); + let mut stack = vec![start]; + visited[start] = true; + + while let Some(current) = stack.pop() { + component.push(current); + for candidate in 0..tensors.len() { + if visited[candidate] { + continue; + } + if tensors_share_contractable_index(tensors[current], tensors[candidate]) { + visited[candidate] = true; + stack.push(candidate); + } + } + } + + component.sort_unstable(); + components.push(component); + } + + components +} + +fn contract_fit_tensor_refs(tensors: &[&T]) -> Result +where + T: TensorLike, + ::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync, +{ + if tensors.is_empty() { + return Err(anyhow::anyhow!( + "fit contraction requires at least one tensor" + )); + } + + let components = tensor_connected_components(tensors); + if components.len() == 1 { + return T::contract(tensors).map_err(|e| anyhow::anyhow!("contract failed: {}", e)); + } + + let mut contracted_components = Vec::with_capacity(components.len()); + for component in components { + let component_refs = component + .iter() + .map(|&tensor_index| tensors[tensor_index]) + .collect::>(); + let contracted = if component_refs.len() == 1 { + component_refs[0].clone() + } else { + T::contract(&component_refs) + .map_err(|e| anyhow::anyhow!("component contract failed: {}", e))? + }; + contracted_components.push(contracted); + } + + let mut iter = contracted_components.into_iter(); + let mut result = iter + .next() + .ok_or_else(|| anyhow::anyhow!("fit contraction produced no components"))?; + for component in iter { + result = result + .outer_product(&component) + .map_err(|e| anyhow::anyhow!("component outer product failed: {}", e))?; + } + + Ok(result) +} + // ============================================================================ // FitEnvironment: Environment tensor cache // ============================================================================ @@ -424,8 +517,7 @@ where // A, B, and C must form one connected local environment. let c_conj = tensor_c.conj(); - let env = T::contract(&[tensor_a, tensor_b, &c_conj]) - .map_err(|e| anyhow::anyhow!("contract failed: {}", e))?; + let env = contract_fit_tensor_refs(&[tensor_a, tensor_b, &c_conj])?; if let Some(started) = started { with_fit_profile(|profile| { @@ -470,8 +562,7 @@ where let c_conj = tensor_c.conj(); let mut tensor_refs: Vec<&T> = vec![tensor_a, tensor_b, &c_conj]; tensor_refs.extend(child_envs.iter()); - let result = - T::contract(&tensor_refs).map_err(|e| anyhow::anyhow!("contract failed: {}", e))?; + let result = contract_fit_tensor_refs(&tensor_refs)?; if let Some(started) = started { with_fit_profile(|profile| { @@ -673,8 +764,7 @@ where let contract_started = fit_profile_enabled().then(Instant::now); let mut tensor_refs: Vec<&T> = vec![a_u, b_u, a_v, b_v]; tensor_refs.extend(env_tensors.iter()); - let ab_uv = - T::contract(&tensor_refs).map_err(|e| anyhow::anyhow!("contract failed: {}", e))?; + let ab_uv = contract_fit_tensor_refs(&tensor_refs)?; if let Some(contract_started) = contract_started { with_fit_profile(|profile| { profile.two_site_contract_time += contract_started.elapsed(); diff --git a/crates/tensor4all-treetn/src/treetn/fit/tests/mod.rs b/crates/tensor4all-treetn/src/treetn/fit/tests/mod.rs index b1c9d28f..f8df457b 100644 --- a/crates/tensor4all-treetn/src/treetn/fit/tests/mod.rs +++ b/crates/tensor4all-treetn/src/treetn/fit/tests/mod.rs @@ -431,7 +431,7 @@ fn test_contract_fit_positive_sweeps_do_not_skip_without_truncation_options() { } #[test] -fn test_contract_fit_rejects_leaf_site_space_that_contracts_away() { +fn test_contract_fit_handles_leaf_site_space_that_contracts_away() { let left = DynIndex::new_dyn(2); let right = DynIndex::new_dyn(2); let shared_left = DynIndex::new_dyn(2); @@ -492,14 +492,16 @@ fn test_contract_fit_rejects_leaf_site_space_that_contracts_away() { ) .unwrap(); - let err = contract_fit( + let fitted = contract_fit( &tn_a, &tn_b, &"A".to_string(), FitContractionOptions::new(1), ) - .unwrap_err() - .to_string(); + .unwrap(); - assert!(err.contains("Disconnected tensor network")); + assert_eq!(fitted.node_count(), 3); + let fitted_dense = fitted.to_dense().unwrap(); + let expected_dense = tn_a.contract_naive(&tn_b).unwrap(); + assert!(fitted_dense.distance(&expected_dense).unwrap() < 1e-10); } diff --git a/crates/tensor4all-treetn/src/treetn/partial_contraction.rs b/crates/tensor4all-treetn/src/treetn/partial_contraction.rs index 65f690ff..9deda608 100644 --- a/crates/tensor4all-treetn/src/treetn/partial_contraction.rs +++ b/crates/tensor4all-treetn/src/treetn/partial_contraction.rs @@ -442,8 +442,9 @@ where validate_union_topology(&node_names, &union_edges)?; let structural_result = (|| { - let aligned_a = align_to_union_topology(a, &node_names, &union_edges)?; - let aligned_b = align_to_union_topology(b, &node_names, &union_edges)?; + let mut aligned_a = align_to_union_topology(a, &node_names, &union_edges)?; + let mut aligned_b = align_to_union_topology(b, &node_names, &union_edges)?; + connect_nodewise_outer_product_spectators(&mut aligned_a, &mut aligned_b)?; contract(&aligned_a, &aligned_b, center, options.clone()) .context("partial_contract: failed contraction after aligning mismatched topologies") })(); @@ -732,6 +733,77 @@ where Ok(()) } +fn has_contractable_index_pair(left: &[DynIndex], right: &[DynIndex]) -> bool { + left.iter().any(|idx_left| { + right + .iter() + .any(|idx_right| idx_left.is_contractable(idx_right)) + }) +} + +fn connect_nodewise_outer_product_spectators( + a: &mut TreeTN, + b: &mut TreeTN, +) -> Result<()> +where + V: Clone + Hash + Eq + Send + Sync + Debug + Ord, + ::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync, +{ + let mut node_names = a.node_names(); + node_names.retain(|node_name| b.node_index(node_name).is_some()); + node_names.sort(); + + for node_name in node_names { + let node_a = a + .node_index(&node_name) + .context("partial_contract: node disappeared while adding dummy contraction links")?; + let node_b = b.node_index(&node_name).context( + "partial_contract: matching node disappeared while adding dummy contraction links", + )?; + + let tensor_a = a + .tensor(node_a) + .cloned() + .context("partial_contract: missing tensor while adding dummy contraction links")?; + let tensor_b = b.tensor(node_b).cloned().context( + "partial_contract: missing matching tensor while adding dummy contraction links", + )?; + + if has_contractable_index_pair(&tensor_a.external_indices(), &tensor_b.external_indices()) { + continue; + } + + let (dummy_a, dummy_b) = DynIndex::create_dummy_link_pair(); + let dummy_tensor_a = + ::ones(std::slice::from_ref(&dummy_a)) + .context("partial_contract: failed to build dummy contraction link")?; + let dummy_tensor_b = + ::ones(std::slice::from_ref(&dummy_b)) + .context("partial_contract: failed to build matching dummy contraction link")?; + let expanded_a = tensor_a + .outer_product(&dummy_tensor_a) + .context("partial_contract: failed to attach dummy contraction link")?; + let expanded_b = tensor_b + .outer_product(&dummy_tensor_b) + .context("partial_contract: failed to attach matching dummy contraction link")?; + + a.replace_tensor(node_a, expanded_a) + .context( + "partial_contract: failed to replace tensor after adding dummy contraction link", + )? + .context("partial_contract: node disappeared after adding dummy contraction link")?; + b.replace_tensor(node_b, expanded_b) + .context( + "partial_contract: failed to replace matching tensor after adding dummy contraction link", + )? + .context( + "partial_contract: matching node disappeared after adding dummy contraction link", + )?; + } + + Ok(()) +} + /// Partially contract two TreeTNs according to the given specification. /// /// # Arguments @@ -795,7 +867,7 @@ where { validate_partial_contraction_spec(a, b, spec)?; - let (a_modified, mut b_modified, restore_from, restore_to) = + let (mut a_modified, mut b_modified, restore_from, restore_to) = apply_diagonal_pairs(a, b, &spec.diagonal_pairs)?; for (idx_a, idx_b) in &spec.contract_pairs { @@ -810,6 +882,7 @@ where let mut result = if a_modified.same_topology(&b_modified) { align_contract_pair_site_nodes(&a_modified, &mut b_modified, &spec.contract_pairs)?; + connect_nodewise_outer_product_spectators(&mut a_modified, &mut b_modified)?; contract(&a_modified, &b_modified, center, options) .context("partial_contract: contraction failed")? } else { diff --git a/crates/tensor4all-treetn/src/treetn/partial_contraction/tests/mod.rs b/crates/tensor4all-treetn/src/treetn/partial_contraction/tests/mod.rs index 56d46da3..1595b15f 100644 --- a/crates/tensor4all-treetn/src/treetn/partial_contraction/tests/mod.rs +++ b/crates/tensor4all-treetn/src/treetn/partial_contraction/tests/mod.rs @@ -488,6 +488,69 @@ fn test_partial_contract_contract_only() { assert_eq!(result.external_indices().len(), 2); } +#[test] +fn partial_contract_fit_inserts_dummy_links_for_nodewise_outer_product_spectators() { + let k_left = DynIndex::new_dyn(2); + let k_right = DynIndex::new_dyn(2); + let i = DynIndex::new_dyn(2); + let j = DynIndex::new_dyn(2); + let bond_left = DynIndex::new_dyn(2); + let bond_right = DynIndex::new_dyn(2); + + let a0 = TensorDynLen::from_dense( + vec![k_left.clone(), bond_left.clone()], + vec![1.0, 2.0, 3.0, 4.0], + ) + .unwrap(); + let a1 = TensorDynLen::from_dense(vec![bond_left.clone(), i.clone()], vec![5.0, 6.0, 7.0, 8.0]) + .unwrap(); + let b0 = TensorDynLen::from_dense( + vec![k_right.clone(), bond_right.clone()], + vec![0.5, 1.5, 2.5, 3.5], + ) + .unwrap(); + let b1 = TensorDynLen::from_dense( + vec![bond_right.clone(), j.clone()], + vec![1.0, -1.0, 2.0, -2.0], + ) + .unwrap(); + let lhs = TreeTN::::from_tensors( + vec![a0, a1], + vec!["left".to_string(), "right".to_string()], + ) + .unwrap(); + let rhs = TreeTN::::from_tensors( + vec![b0, b1], + vec!["left".to_string(), "right".to_string()], + ) + .unwrap(); + let spec = PartialContractionSpec { + contract_pairs: vec![(k_left.clone(), k_right.clone())], + diagonal_pairs: vec![], + output_order: None, + }; + + let rhs_aligned = rhs.replaceind(&k_right, &k_left).unwrap(); + let reference = tensordot( + &lhs.contract_to_tensor().unwrap(), + &rhs_aligned.contract_to_tensor().unwrap(), + &[(k_left.clone(), k_left)], + ) + .unwrap(); + let actual = partial_contract( + &lhs, + &rhs, + &spec, + &"right".to_string(), + ContractionOptions::fit(), + ) + .unwrap() + .to_dense() + .unwrap(); + + assert!(actual.distance(&reference).unwrap() < 1.0e-10); +} + #[test] fn test_partial_contract_empty_spec() { // Empty spec: no contract, no diagonal pair → full outer product