diff --git a/Cargo.toml b/Cargo.toml index 2520c9a3..cefa075b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,7 +55,8 @@ hdf5-metno = { version = "0.12", default-features = false } ndarray = "0.17" quanticsgrids = { git = "https://github.com/tensor4all/quanticsgrids-rs", rev = "a76b8fb" } hdf5-rt = { git = "https://github.com/tensor4all/hdf5-rt", default-features = false } -tenferro = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "f07a3a08dd188ea5a726a43c6b53de5d417eb0a6", default-features = false } -tenferro-device = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "f07a3a08dd188ea5a726a43c6b53de5d417eb0a6" } -tenferro-einsum = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "f07a3a08dd188ea5a726a43c6b53de5d417eb0a6", default-features = false } -tenferro-tensor = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "f07a3a08dd188ea5a726a43c6b53de5d417eb0a6", default-features = false } +tenferro = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "95fb0724c81cfb2f7640b2cd9255d1700c100358", default-features = false } +tenferro-device = { package = "tenferro-internal-device", git = "https://github.com/tensor4all/tenferro-rs.git", rev = "95fb0724c81cfb2f7640b2cd9255d1700c100358" } +tenferro-einsum = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "95fb0724c81cfb2f7640b2cd9255d1700c100358", default-features = false } +tenferro-linalg = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "95fb0724c81cfb2f7640b2cd9255d1700c100358", default-features = false } +tenferro-tensor = { package = "tenferro-internal-tensor", git = "https://github.com/tensor4all/tenferro-rs.git", rev = "95fb0724c81cfb2f7640b2cd9255d1700c100358", default-features = false } diff --git a/benchmarks/rust/benchmark_tt_ops.rs b/benchmarks/rust/benchmark_tt_ops.rs index c99ecceb..49a112ca 100644 --- a/benchmarks/rust/benchmark_tt_ops.rs +++ b/benchmarks/rust/benchmark_tt_ops.rs @@ -25,7 +25,7 @@ use tensor4all_core::{ TensorContractionLike, TensorDynLen, }; use tensor4all_itensorlike::{CanonicalForm, ContractOptions, TensorTrain}; -use tenferro::{CpuBackend, DotGeneralConfig, EagerContext, EagerTensor, Tensor, TypedTensor}; +use tenferro::{CpuBackend, DotGeneralConfig, EagerRuntime, EagerTensor, Tensor, TypedTensor}; #[derive(Debug, Clone)] struct Options { @@ -162,7 +162,7 @@ fn deterministic_native_tensor(shape: Vec, seed: usize) -> Tensor { let data = (0..len) .map(|idx| deterministic_value(idx, seed)) .collect::>(); - Tensor::C64(TypedTensor::from_vec(shape, data)) + Tensor::C64(TypedTensor::from_vec_col_major(shape, data)) } fn make_sites(length: usize, phys_dim: usize) -> Vec { @@ -214,10 +214,7 @@ fn make_native_mps_t4a_shapes( .collect() } -fn eager_mps_tensors( - ctx: &Arc, - tensors: Vec, -) -> Vec { +fn eager_mps_tensors(ctx: &Arc, tensors: Vec) -> Vec { tensors .into_iter() .map(|tensor| EagerTensor::from_tensor_in(tensor, Arc::clone(ctx))) @@ -666,7 +663,7 @@ fn main() -> Result<()> { let bra = make_mps(&sites, chi, 0)?; let ket = make_mps(&sites, chi, opts.length)?; let bra_conj = preconjugate_sites(&bra)?; - let raw_ctx = EagerContext::with_cpu_backend(CpuBackend::with_threads(1)); + let raw_ctx = EagerRuntime::with_cpu_backend(CpuBackend::with_threads(1)); let raw_bra = eager_mps_tensors( &raw_ctx, make_native_mps_t4a_shapes(opts.length, opts.phys_dim, chi, 0), diff --git a/crates/tensor4all-core/Cargo.toml b/crates/tensor4all-core/Cargo.toml index 873ecb9a..a3f2e498 100644 --- a/crates/tensor4all-core/Cargo.toml +++ b/crates/tensor4all-core/Cargo.toml @@ -13,19 +13,34 @@ backend-tenferro = ["tensor4all-tensorbackend/backend-tenferro"] tenferro-cpu-faer = [ "tensor4all-tensorbackend/tenferro-cpu-faer", "tensor4all-tcicore/tenferro-cpu-faer", + "tenferro/autodiff", "tenferro/cpu-faer", + "tenferro-einsum/autodiff", + "tenferro-einsum/cpu-faer", + "tenferro-linalg/autodiff", + "tenferro-linalg/cpu-faer", "tenferro-tensor/cpu-faer", ] tenferro-system-blas = [ "tensor4all-tensorbackend/tenferro-system-blas", "tensor4all-tcicore/tenferro-system-blas", + "tenferro/autodiff", "tenferro/cpu-blas", + "tenferro-einsum/autodiff", + "tenferro-einsum/cpu-blas", + "tenferro-linalg/autodiff", + "tenferro-linalg/cpu-blas", "tenferro-tensor/cpu-blas", ] tenferro-provider-inject = [ "tensor4all-tensorbackend/tenferro-provider-inject", "tensor4all-tcicore/tenferro-provider-inject", + "tenferro/autodiff", "tenferro/cpu-blas", + "tenferro-einsum/autodiff", + "tenferro-einsum/cpu-blas", + "tenferro-linalg/autodiff", + "tenferro-linalg/cpu-blas", "tenferro-tensor/provider-inject", ] @@ -43,6 +58,8 @@ omeco.workspace = true petgraph.workspace = true smallvec.workspace = true tenferro.workspace = true +tenferro-einsum.workspace = true +tenferro-linalg.workspace = true tenferro-tensor.workspace = true [dev-dependencies] diff --git a/crates/tensor4all-core/src/defaults/contract.rs b/crates/tensor4all-core/src/defaults/contract.rs index 85e3bd09..de8074f0 100644 --- a/crates/tensor4all-core/src/defaults/contract.rs +++ b/crates/tensor4all-core/src/defaults/contract.rs @@ -27,8 +27,8 @@ use std::time::{Duration, Instant}; use anyhow::Result; use petgraph::algo::connected_components; use petgraph::prelude::*; -use tenferro::eager_tensor::einsum_subscripts as eager_einsum_ad; -use tenferro::EinsumSubscripts; +use tenferro_einsum::eager_tensor::einsum_subscripts as eager_einsum_ad; +use tenferro_einsum::EinsumSubscripts; use tensor4all_tensorbackend::{einsum_native_tensors, einsum_native_tensors_owned}; use crate::defaults::{DynId, DynIndex, TensorDynLen}; diff --git a/crates/tensor4all-core/src/defaults/factorize.rs b/crates/tensor4all-core/src/defaults/factorize.rs index 4589ff37..e9ac9324 100644 --- a/crates/tensor4all-core/src/defaults/factorize.rs +++ b/crates/tensor4all-core/src/defaults/factorize.rs @@ -35,7 +35,10 @@ use crate::defaults::tensordynlen::unfold_split_inner; use crate::defaults::DynIndex; use crate::{contract_pair, unfold_split, TensorDynLen}; +use anyhow::Result as AnyhowResult; use num_complex::{Complex64, ComplexFloat}; +use tenferro::EagerTensor; +use tenferro_linalg::eager_tensor::full_piv_lu_solve as eager_full_piv_lu_solve; use tensor4all_tcicore::{rrlu, AbstractMatrixCI, MatrixLUCI, RrLUOptions, Scalar as MatrixScalar}; use tensor4all_tensorbackend::{Matrix, TensorElement}; @@ -563,7 +566,7 @@ where })?; let (left_inner, right_inner) = match canonical { Canonical::Left => { - let left = pivot_inner.right_solve(&cols_inner).map_err(|e| { + let left = eager_right_solve(&pivot_inner, &cols_inner).map_err(|e| { FactorizeError::ComputationError(anyhow::anyhow!( "fixed-pivot CI right solve failed: {e}" )) @@ -571,7 +574,7 @@ where (left, rows_inner) } Canonical::Right => { - let right = pivot_inner.solve(&rows_inner).map_err(|e| { + let right = eager_full_piv_lu_solve(&pivot_inner, &rows_inner).map_err(|e| { FactorizeError::ComputationError(anyhow::anyhow!( "fixed-pivot CI solve failed: {e}" )) @@ -613,6 +616,12 @@ where }) } +fn eager_right_solve(a: &EagerTensor, rhs: &EagerTensor) -> AnyhowResult { + let a_t = a.transpose(&[1, 0])?; + let rhs_t = rhs.transpose(&[1, 0])?; + Ok(eager_full_piv_lu_solve(&a_t, &rhs_t)?.transpose(&[1, 0])?) +} + /// Convert a native rank-2 tensor into a backend [`Matrix`]. fn native_tensor_to_matrix( tensor: &tenferro::Tensor, diff --git a/crates/tensor4all-core/src/defaults/qr.rs b/crates/tensor4all-core/src/defaults/qr.rs index c6c4394b..959cabf0 100644 --- a/crates/tensor4all-core/src/defaults/qr.rs +++ b/crates/tensor4all-core/src/defaults/qr.rs @@ -8,6 +8,7 @@ use crate::global_default::GlobalDefault; use crate::TensorDynLen; use num_complex::ComplexFloat; use tenferro::DType; +use tenferro_linalg::eager_tensor::qr as eager_qr; use tensor4all_tensorbackend::{ native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major, }; @@ -258,9 +259,8 @@ pub fn qr_with( .map_err(|e| anyhow::anyhow!("Failed to unfold tensor: {}", e)) .map_err(QrError::ComputationError)?; let k = m.min(n); - let (mut q_inner, mut r_inner) = matrix_inner - .qr() - .map_err(|e| QrError::ComputationError(anyhow::anyhow!("{e}")))?; + let (mut q_inner, mut r_inner) = + eager_qr(&matrix_inner).map_err(|e| QrError::ComputationError(anyhow::anyhow!("{e}")))?; let r = if options.truncate { // Determine rtol to use diff --git a/crates/tensor4all-core/src/defaults/structured_contraction.rs b/crates/tensor4all-core/src/defaults/structured_contraction.rs index edc6cffc..0652deec 100644 --- a/crates/tensor4all-core/src/defaults/structured_contraction.rs +++ b/crates/tensor4all-core/src/defaults/structured_contraction.rs @@ -331,14 +331,14 @@ pub(crate) fn normalize_payload_read_for_roots<'a>( pub(crate) fn storage_payload_native(storage: &Storage) -> Result { if storage.is_f64() { - Ok(NativeTensor::from_vec( + Ok(NativeTensor::from_vec_col_major( storage.payload_dims().to_vec(), storage .payload_f64_col_major_vec() .map_err(anyhow::Error::msg)?, )) } else if storage.is_c64() { - Ok(NativeTensor::from_vec( + Ok(NativeTensor::from_vec_col_major( storage.payload_dims().to_vec(), storage .payload_c64_col_major_vec() @@ -527,7 +527,8 @@ mod tests { #[test] fn normalizes_repeated_payload_roots_by_extracting_diagonal() { - let payload = tenferro::Tensor::from_vec(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]); + let payload = + tenferro::Tensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 2.0, 3.0, 4.0]); let (normalized, roots) = normalize_payload_for_roots(&payload, &[0, 0]).unwrap(); assert_eq!(normalized.shape(), &[2]); diff --git a/crates/tensor4all-core/src/defaults/svd.rs b/crates/tensor4all-core/src/defaults/svd.rs index 5e4d1be0..ba460dcc 100644 --- a/crates/tensor4all-core/src/defaults/svd.rs +++ b/crates/tensor4all-core/src/defaults/svd.rs @@ -17,6 +17,7 @@ use crate::truncation::{ use crate::TensorDynLen; use std::sync::Mutex; use tenferro::{DType, EagerTensor}; +use tenferro_linalg::eager_tensor::svd as eager_svd; use tensor4all_tensorbackend::{ native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_f64_col_major, }; @@ -238,9 +239,8 @@ fn svd_truncated_inner( .map_err(SvdError::ComputationError)?; let k = m.min(n); - let (mut u_inner, mut s_inner, mut vt_inner) = matrix_inner - .svd() - .map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?; + let (mut u_inner, mut s_inner, mut vt_inner) = + eager_svd(&matrix_inner).map_err(|e| SvdError::ComputationError(anyhow::anyhow!("{e}")))?; let s_full = singular_values_from_native(s_inner.data())?; let mut r = if options.truncate { let policy = options.policy.unwrap_or_else(default_svd_truncation_policy); diff --git a/crates/tensor4all-core/src/defaults/svd/tests/mod.rs b/crates/tensor4all-core/src/defaults/svd/tests/mod.rs index ab78cd03..d89214a6 100644 --- a/crates/tensor4all-core/src/defaults/svd/tests/mod.rs +++ b/crates/tensor4all-core/src/defaults/svd/tests/mod.rs @@ -41,10 +41,10 @@ fn compute_retained_rank_supports_all_policy_axes() { #[test] fn singular_values_from_native_accepts_real_and_complex_dense() { - let dense = NativeTensor::from_vec(vec![2], vec![3.0_f64, 1.5]); + let dense = NativeTensor::from_vec_col_major(vec![2], vec![3.0_f64, 1.5]); assert_eq!(singular_values_from_native(&dense).unwrap(), vec![3.0, 1.5]); - let complex = NativeTensor::from_vec( + let complex = NativeTensor::from_vec_col_major( vec![2], vec![Complex64::new(1.0, 2.0), Complex64::new(0.5, -4.0)], ); @@ -89,7 +89,7 @@ fn svd_options_accessors_roundtrip() { #[test] fn singular_values_from_native_rejects_unsupported_scalar_types() { - let tensor = NativeTensor::from_vec(vec![2], vec![1.0_f32, 2.0]); + let tensor = NativeTensor::from_vec_col_major(vec![2], vec![1.0_f32, 2.0]); let err = singular_values_from_native(&tensor).unwrap_err(); assert!(err .to_string() diff --git a/crates/tensor4all-core/src/defaults/tensordynlen.rs b/crates/tensor4all-core/src/defaults/tensordynlen.rs index d27f2fd7..2fa4872f 100644 --- a/crates/tensor4all-core/src/defaults/tensordynlen.rs +++ b/crates/tensor4all-core/src/defaults/tensordynlen.rs @@ -14,8 +14,9 @@ use std::collections::{HashMap, HashSet}; use std::env; use std::sync::{Arc, OnceLock}; use std::time::{Duration, Instant}; -use tenferro::eager_tensor::einsum_subscripts as eager_einsum_ad; -use tenferro::{DType, DotGeneralConfig, EagerTensor, EinsumSubscripts, Tensor as NativeTensor}; +use tenferro::{DType, DotGeneralConfig, EagerTensor, Tensor as NativeTensor}; +use tenferro_einsum::eager_tensor::einsum_subscripts as eager_einsum_ad; +use tenferro_einsum::EinsumSubscripts; use tensor4all_tensorbackend::{ axpby_native_tensor, contract_native_tensor, default_eager_ctx, dense_native_tensor_from_col_major, diag_native_tensor_from_col_major, @@ -125,7 +126,9 @@ fn native_tensor_profile_bytes(native: &NativeTensor) -> usize { DType::F64 => 8, DType::C32 => 8, DType::C64 => 16, + DType::I32 => 4, DType::I64 => 8, + DType::Bool => 1, }; native.shape().iter().product::() * element_size } @@ -1048,10 +1051,12 @@ impl TensorDynLen { ) -> Result { if Self::is_diag_axis_classes(axis_classes) { match native.dtype() { - DType::F32 | DType::F64 | DType::I64 => Storage::from_diag_col_major( - native_tensor_primal_to_diag_f64(native)?, - logical_rank, - ), + DType::F32 | DType::F64 | DType::I32 | DType::I64 | DType::Bool => { + Storage::from_diag_col_major( + native_tensor_primal_to_diag_f64(native)?, + logical_rank, + ) + } DType::C32 | DType::C64 => Storage::from_diag_col_major( native_tensor_primal_to_diag_c64(native)?, logical_rank, @@ -1535,7 +1540,7 @@ impl TensorDynLen { } let starts_tensor = EagerTensor::from_tensor_in( - NativeTensor::from_vec(vec![rank], starts), + NativeTensor::from_vec_col_major(vec![rank], starts), default_eager_ctx(), ); let sliced = self diff --git a/crates/tensor4all-simplett/src/einsum_helper.rs b/crates/tensor4all-simplett/src/einsum_helper.rs index 55778474..afe17045 100644 --- a/crates/tensor4all-simplett/src/einsum_helper.rs +++ b/crates/tensor4all-simplett/src/einsum_helper.rs @@ -1,9 +1,9 @@ use std::error::Error; use std::fmt; -use tenferro_einsum::typed_eager_einsum; -use tenferro_tensor::{TensorScalar, TypedTensor}; -use tensor4all_tensorbackend::with_default_backend; +use tenferro_einsum::Subscripts; +use tenferro_tensor::{Tensor, TensorScalar, TypedTensor}; +use tensor4all_tensorbackend::einsum_native_tensors; pub(crate) type Result = std::result::Result; @@ -91,12 +91,44 @@ pub(crate) fn einsum_tensors( subscripts: &str, operands: &[&TypedTensor], ) -> Result> { - with_default_backend(|backend| typed_eager_einsum(backend, operands, subscripts)).map_err( - |err| EinsumHelperError::Backend { + let parsed = Subscripts::parse(subscripts).map_err(|err| EinsumHelperError::Backend { + subscripts: subscripts.to_string(), + message: err.to_string(), + })?; + let tensors: Vec = operands + .iter() + .map(|tensor| T::into_tensor(tensor.shape.clone(), tensor.host_data().to_vec())) + .collect(); + let input_ids = parsed + .inputs + .iter() + .map(|ids| ids.iter().map(|&id| id as usize).collect::>()) + .collect::>(); + let operand_refs = tensors + .iter() + .zip(input_ids.iter()) + .map(|(tensor, ids)| (tensor, ids.as_slice())) + .collect::>(); + let output_ids = parsed + .output + .iter() + .map(|&id| id as usize) + .collect::>(); + + let result = einsum_native_tensors(&operand_refs, &output_ids).map_err(|err| { + EinsumHelperError::Backend { subscripts: subscripts.to_string(), message: err.to_string(), - }, - ) + } + })?; + let actual = result.dtype(); + T::try_into_typed(result).ok_or_else(|| EinsumHelperError::Backend { + subscripts: subscripts.to_string(), + message: format!( + "dtype mismatch: result has {actual:?}, expected {:?}", + T::dtype() + ), + }) } pub(crate) fn typed_tensor_from_col_major_slice( @@ -112,7 +144,10 @@ pub(crate) fn typed_tensor_from_col_major_slice( }); } - Ok(TypedTensor::from_vec(shape.to_vec(), data.to_vec())) + Ok(TypedTensor::from_vec_col_major( + shape.to_vec(), + data.to_vec(), + )) } pub(crate) fn typed_tensor_reshape( @@ -130,7 +165,7 @@ pub(crate) fn typed_tensor_reshape( }); } - Ok(TypedTensor::from_vec( + Ok(TypedTensor::from_vec_col_major( shape.to_vec(), tensor.host_data().to_vec(), )) diff --git a/crates/tensor4all-simplett/src/mpo/types.rs b/crates/tensor4all-simplett/src/mpo/types.rs index 82480e40..af1b312b 100644 --- a/crates/tensor4all-simplett/src/mpo/types.rs +++ b/crates/tensor4all-simplett/src/mpo/types.rs @@ -183,7 +183,7 @@ pub fn tensor4_from_data( }); } let dims = [left_dim, site_dim_1, site_dim_2, right_dim]; - let inner = TfTensor::from_vec(dims.to_vec(), data); + let inner = TfTensor::from_vec_col_major(dims.to_vec(), data); Ok(Tensor::from_tenferro_unchecked(inner)) } diff --git a/crates/tensor4all-simplett/src/tensor.rs b/crates/tensor4all-simplett/src/tensor.rs index c2453bbe..46af3fcb 100644 --- a/crates/tensor4all-simplett/src/tensor.rs +++ b/crates/tensor4all-simplett/src/tensor.rs @@ -61,7 +61,7 @@ fn col_major_data_to_tensor( dims: [usize; N], data: Vec, ) -> Tensor { - let inner = TfTensor::from_vec(dims.to_vec(), data); + let inner = TfTensor::from_vec_col_major(dims.to_vec(), data); Tensor::from_tenferro_unchecked(inner) } @@ -221,11 +221,11 @@ impl Tensor { /// use tensor4all_simplett::tensor::{Tensor2, Tensor3}; /// use tenferro_tensor::TypedTensor; /// - /// let rank_2 = TypedTensor::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]); + /// let rank_2 = TypedTensor::from_vec_col_major(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]); /// let tensor = Tensor2::try_from_tenferro(rank_2).unwrap(); /// assert_eq!(tensor.dims(), &[2, 2]); /// - /// let rank_2 = TypedTensor::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]); + /// let rank_2 = TypedTensor::from_vec_col_major(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]); /// assert!(Tensor3::try_from_tenferro(rank_2).is_err()); /// ``` pub fn try_from_tenferro(tensor: TfTensor) -> Result { @@ -244,11 +244,11 @@ impl Tensor { /// use tensor4all_simplett::tensor::{Tensor2, Tensor3}; /// use tenferro_tensor::TypedTensor; /// - /// let rank_2 = TypedTensor::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]); + /// let rank_2 = TypedTensor::from_vec_col_major(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]); /// let tensor = Tensor2::from_tenferro(rank_2).unwrap(); /// assert_eq!(tensor.dims(), &[2, 2]); /// - /// let rank_2 = TypedTensor::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]); + /// let rank_2 = TypedTensor::from_vec_col_major(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]); /// assert!(Tensor3::from_tenferro(rank_2).is_err()); /// ``` pub fn from_tenferro(tensor: TfTensor) -> Result { diff --git a/crates/tensor4all-simplett/src/types.rs b/crates/tensor4all-simplett/src/types.rs index dc2f4cff..28ac115c 100644 --- a/crates/tensor4all-simplett/src/types.rs +++ b/crates/tensor4all-simplett/src/types.rs @@ -184,7 +184,7 @@ pub fn tensor3_from_data( }); } let dims = [left_dim, site_dim, right_dim]; - let inner = TfTensor::from_vec(dims.to_vec(), data); + let inner = TfTensor::from_vec_col_major(dims.to_vec(), data); Ok(Tensor::from_tenferro_unchecked(inner)) } diff --git a/crates/tensor4all-tcicore/benches/dense_vs_tenferro.rs b/crates/tensor4all-tcicore/benches/dense_vs_tenferro.rs index c47a71ae..788bd4f1 100644 --- a/crates/tensor4all-tcicore/benches/dense_vs_tenferro.rs +++ b/crates/tensor4all-tcicore/benches/dense_vs_tenferro.rs @@ -46,7 +46,7 @@ fn bench_dense_vs_tenferro(c: &mut Criterion) { &size, |b, &n| { b.iter(|| { - let mat = Tensor::from_vec(vec![n, n], data.clone()); + let mat = Tensor::from_vec_col_major(vec![n, n], data.clone()); black_box(with_default_backend(|backend| mat.full_piv_lu(backend)).unwrap()); }); }, diff --git a/crates/tensor4all-tcicore/benches/rrlu_bench.rs b/crates/tensor4all-tcicore/benches/rrlu_bench.rs index e90713fd..fa429d93 100644 --- a/crates/tensor4all-tcicore/benches/rrlu_bench.rs +++ b/crates/tensor4all-tcicore/benches/rrlu_bench.rs @@ -24,7 +24,7 @@ fn random_tenferro_matrix(n: usize, m: usize, seed: u64) -> Tensor { data[row + n * col] = rng.random::(); } } - Tensor::from_vec(vec![n, m], data) + Tensor::from_vec_col_major(vec![n, m], data) } fn bench_rrlu(c: &mut Criterion) { diff --git a/crates/tensor4all-tensorbackend/Cargo.toml b/crates/tensor4all-tensorbackend/Cargo.toml index 3c40d078..7de2feea 100644 --- a/crates/tensor4all-tensorbackend/Cargo.toml +++ b/crates/tensor4all-tensorbackend/Cargo.toml @@ -12,9 +12,26 @@ categories = ["science", "mathematics"] [features] default = ["backend-tenferro", "tenferro-cpu-faer"] backend-tenferro = [] -tenferro-cpu-faer = ["tenferro/cpu-faer"] -tenferro-system-blas = ["tenferro/cpu-blas", "tenferro-tensor/cpu-blas"] -tenferro-provider-inject = ["tenferro/cpu-blas", "tenferro-tensor/provider-inject"] +tenferro-cpu-faer = [ + "tenferro/autodiff", + "tenferro/cpu-faer", + "tenferro-einsum/autodiff", + "tenferro-einsum/cpu-faer", +] +tenferro-system-blas = [ + "tenferro/autodiff", + "tenferro/cpu-blas", + "tenferro-einsum/autodiff", + "tenferro-einsum/cpu-blas", + "tenferro-tensor/cpu-blas", +] +tenferro-provider-inject = [ + "tenferro/autodiff", + "tenferro/cpu-blas", + "tenferro-einsum/autodiff", + "tenferro-einsum/cpu-blas", + "tenferro-tensor/provider-inject", +] einsum-dispatch-profile = [] [dependencies] diff --git a/crates/tensor4all-tensorbackend/src/any_scalar.rs b/crates/tensor4all-tensorbackend/src/any_scalar.rs index 0fce8662..1a220ea6 100644 --- a/crates/tensor4all-tensorbackend/src/any_scalar.rs +++ b/crates/tensor4all-tensorbackend/src/any_scalar.rs @@ -14,7 +14,9 @@ use crate::tensor_element::TensorElement; enum ScalarValue { F32(f32), F64(f64), + I32(i32), I64(i64), + Bool(bool), C32(Complex32), C64(Complex64), } @@ -24,7 +26,15 @@ impl ScalarValue { match self { Self::F32(value) => value as f64, Self::F64(value) => value, + Self::I32(value) => value as f64, Self::I64(value) => value as f64, + Self::Bool(value) => { + if value { + 1.0 + } else { + 0.0 + } + } Self::C32(value) => value.re as f64, Self::C64(value) => value.re, } @@ -32,7 +42,7 @@ impl ScalarValue { fn imag(self) -> f64 { match self { - Self::F32(_) | Self::F64(_) | Self::I64(_) => 0.0, + Self::F32(_) | Self::F64(_) | Self::I32(_) | Self::I64(_) | Self::Bool(_) => 0.0, Self::C32(value) => value.im as f64, Self::C64(value) => value.im, } @@ -42,7 +52,15 @@ impl ScalarValue { match self { Self::F32(value) => value.abs() as f64, Self::F64(value) => value.abs(), + Self::I32(value) => value.abs() as f64, Self::I64(value) => value.abs() as f64, + Self::Bool(value) => { + if value { + 1.0 + } else { + 0.0 + } + } Self::C32(value) => value.norm() as f64, Self::C64(value) => value.norm(), } @@ -56,7 +74,9 @@ impl ScalarValue { match self { Self::F32(value) => value == 0.0, Self::F64(value) => value == 0.0, + Self::I32(value) => value == 0, Self::I64(value) => value == 0, + Self::Bool(value) => !value, Self::C32(value) => value == Complex32::new(0.0, 0.0), Self::C64(value) => value == Complex64::new(0.0, 0.0), } @@ -66,7 +86,9 @@ impl ScalarValue { match self { Self::F32(value) => Complex64::new(value as f64, 0.0), Self::F64(value) => Complex64::new(value, 0.0), + Self::I32(value) => Complex64::new(value as f64, 0.0), Self::I64(value) => Complex64::new(value as f64, 0.0), + Self::Bool(value) => Complex64::new(if value { 1.0 } else { 0.0 }, 0.0), Self::C32(value) => Complex64::new(value.re as f64, value.im as f64), Self::C64(value) => value, } @@ -99,11 +121,21 @@ fn scalar_value_from_native(native: &NativeTensor) -> Result { .and_then(|values| values.first().copied()) .map(ScalarValue::F64) .ok_or_else(|| anyhow!("failed to read f64 scalar tensor value")), + DType::I32 => native + .as_slice::() + .and_then(|values| values.first().copied()) + .map(ScalarValue::I32) + .ok_or_else(|| anyhow!("failed to read i32 scalar tensor value")), DType::I64 => native .as_slice::() .and_then(|values| values.first().copied()) .map(ScalarValue::I64) .ok_or_else(|| anyhow!("failed to read i64 scalar tensor value")), + DType::Bool => native + .as_slice::() + .and_then(|values| values.first().copied()) + .map(ScalarValue::Bool) + .ok_or_else(|| anyhow!("failed to read bool scalar tensor value")), DType::C32 => native .as_slice::() .and_then(|values| values.first().copied()) @@ -158,6 +190,11 @@ pub(crate) fn promote_scalar_native(native: &NativeTensor, target: DType) -> Res "cannot promote f32 scalar to i64 without truncation" )); } + (ScalarValue::F32(_), DType::I32 | DType::Bool) => { + return Err(anyhow!( + "cannot promote f32 scalar to integer/bool without truncation" + )); + } (ScalarValue::F64(value), DType::F32) => Scalar::from_value(value as f32), (ScalarValue::F64(value), DType::F64) => Scalar::from_value(value), (ScalarValue::F64(_), DType::I64) => { @@ -165,12 +202,35 @@ pub(crate) fn promote_scalar_native(native: &NativeTensor, target: DType) -> Res "cannot promote f64 scalar to i64 without truncation" )); } + (ScalarValue::F64(_), DType::I32 | DType::Bool) => { + return Err(anyhow!( + "cannot promote f64 scalar to integer/bool without truncation" + )); + } (ScalarValue::F64(value), DType::C32) => { Scalar::from_value(Complex32::new(value as f32, 0.0)) } (ScalarValue::F64(value), DType::C64) => Scalar::from_value(Complex64::new(value, 0.0)), + (ScalarValue::I32(value), DType::F32) => Scalar::from_value(value as f32), + (ScalarValue::I32(value), DType::F64) => Scalar::from_value(value as f64), + (ScalarValue::I32(value), DType::I32) => { + return Ok(NativeTensor::from_vec_col_major(vec![], vec![value])); + } + (ScalarValue::I32(value), DType::I64) => Scalar::from_i64(value as i64), + (ScalarValue::I32(value), DType::C32) => { + Scalar::from_value(Complex32::new(value as f32, 0.0)) + } + (ScalarValue::I32(value), DType::C64) => { + Scalar::from_value(Complex64::new(value as f64, 0.0)) + } + (ScalarValue::I32(_), DType::Bool) => { + return Err(anyhow!("cannot promote i32 scalar to bool")); + } (ScalarValue::I64(value), DType::F32) => Scalar::from_value(value as f32), (ScalarValue::I64(value), DType::F64) => Scalar::from_value(value as f64), + (ScalarValue::I64(_), DType::I32 | DType::Bool) => { + return Err(anyhow!("cannot promote i64 scalar to i32/bool")); + } (ScalarValue::I64(value), DType::I64) => Scalar::from_i64(value), (ScalarValue::I64(value), DType::C32) => { Scalar::from_value(Complex32::new(value as f32, 0.0)) @@ -178,10 +238,30 @@ pub(crate) fn promote_scalar_native(native: &NativeTensor, target: DType) -> Res (ScalarValue::I64(value), DType::C64) => { Scalar::from_value(Complex64::new(value as f64, 0.0)) } + (ScalarValue::Bool(value), DType::F32) => { + Scalar::from_value(if value { 1.0_f32 } else { 0.0_f32 }) + } + (ScalarValue::Bool(value), DType::F64) => Scalar::from_value(if value { 1.0 } else { 0.0 }), + (ScalarValue::Bool(value), DType::I32) => { + return Ok(NativeTensor::from_vec_col_major( + vec![], + vec![if value { 1 } else { 0 }], + )); + } + (ScalarValue::Bool(value), DType::I64) => Scalar::from_i64(if value { 1 } else { 0 }), + (ScalarValue::Bool(value), DType::Bool) => { + return Ok(NativeTensor::from_vec_col_major(vec![], vec![value])); + } + (ScalarValue::Bool(value), DType::C32) => { + Scalar::from_value(Complex32::new(if value { 1.0 } else { 0.0 }, 0.0)) + } + (ScalarValue::Bool(value), DType::C64) => { + Scalar::from_value(Complex64::new(if value { 1.0 } else { 0.0 }, 0.0)) + } (ScalarValue::C32(value), DType::F32) => Scalar::from_value(value.re), (ScalarValue::C32(value), DType::F64) => Scalar::from_value(value.re as f64), - (ScalarValue::C32(_), DType::I64) => { - return Err(anyhow!("cannot promote c32 scalar to i64")); + (ScalarValue::C32(_), DType::I32 | DType::I64 | DType::Bool) => { + return Err(anyhow!("cannot promote c32 scalar to integer/bool")); } (ScalarValue::C32(value), DType::C32) => Scalar::from_value(value), (ScalarValue::C32(value), DType::C64) => { @@ -189,8 +269,8 @@ pub(crate) fn promote_scalar_native(native: &NativeTensor, target: DType) -> Res } (ScalarValue::C64(value), DType::F32) => Scalar::from_value(value.re as f32), (ScalarValue::C64(value), DType::F64) => Scalar::from_value(value.re), - (ScalarValue::C64(_), DType::I64) => { - return Err(anyhow!("cannot promote c64 scalar to i64")); + (ScalarValue::C64(_), DType::I32 | DType::I64 | DType::Bool) => { + return Err(anyhow!("cannot promote c64 scalar to integer/bool")); } (ScalarValue::C64(value), DType::C32) => { Scalar::from_value(Complex32::new(value.re as f32, value.im as f32)) @@ -254,11 +334,25 @@ impl Scalar { fn from_i64(value: i64) -> Self { Self { - native: NativeTensor::from_vec(vec![], vec![value]), + native: NativeTensor::from_vec_col_major(vec![], vec![value]), value: ScalarValue::I64(value), } } + fn from_i32(value: i32) -> Self { + Self { + native: NativeTensor::from_vec_col_major(vec![], vec![value]), + value: ScalarValue::I32(value), + } + } + + fn from_bool(value: bool) -> Self { + Self { + native: NativeTensor::from_vec_col_major(vec![], vec![value]), + value: ScalarValue::Bool(value), + } + } + pub(crate) fn from_native(value: NativeTensor) -> Result { Self::wrap_native(value) } @@ -285,7 +379,7 @@ impl Scalar { /// ``` #[allow(private_bounds)] pub fn from_value(value: T) -> Self { - let native = NativeTensor::from_vec(vec![], vec![value]); + let native = NativeTensor::from_vec_col_major(vec![], vec![value]); Self { native, value: T::scalar_value(value), @@ -479,7 +573,9 @@ impl Scalar { match self.value() { ScalarValue::F32(value) => Some(value as f64), ScalarValue::F64(value) => Some(value), + ScalarValue::I32(value) => Some(value as f64), ScalarValue::I64(value) => Some(value as f64), + ScalarValue::Bool(value) => Some(if value { 1.0 } else { 0.0 }), ScalarValue::C32(_) | ScalarValue::C64(_) => None, } } @@ -502,7 +598,11 @@ impl Scalar { /// ``` pub fn as_c64(&self) -> Option { match self.value() { - ScalarValue::F32(_) | ScalarValue::F64(_) | ScalarValue::I64(_) => None, + ScalarValue::F32(_) + | ScalarValue::F64(_) + | ScalarValue::I32(_) + | ScalarValue::I64(_) + | ScalarValue::Bool(_) => None, ScalarValue::C32(value) => Some(Complex64::new(value.re as f64, value.im as f64)), ScalarValue::C64(value) => Some(value), } @@ -529,7 +629,9 @@ impl Scalar { match self.value() { ScalarValue::F32(value) => Self::from_value(value), ScalarValue::F64(value) => Self::from_value(value), + ScalarValue::I32(value) => Self::from_i32(value), ScalarValue::I64(value) => Self::from_i64(value), + ScalarValue::Bool(value) => Self::from_bool(value), ScalarValue::C32(value) => Self::from_value(value.conj()), ScalarValue::C64(value) => Self::from_value(value.conj()), } @@ -674,7 +776,9 @@ impl SumFromStorage for Scalar { match scalar_value_from_storage(storage) { ScalarValue::F32(value) => Self::from_value(value), ScalarValue::F64(value) => Self::from_value(value), + ScalarValue::I32(value) => Self::from_i32(value), ScalarValue::I64(value) => Self::from_i64(value), + ScalarValue::Bool(value) => Self::from_bool(value), ScalarValue::C32(value) => Self::from_value(value), ScalarValue::C64(value) => Self::from_value(value), } @@ -712,7 +816,9 @@ impl TryFrom for f64 { match value.value() { ScalarValue::F32(real) => Ok(real as f64), ScalarValue::F64(real) => Ok(real), + ScalarValue::I32(real) => Ok(real as f64), ScalarValue::I64(real) => Ok(real as f64), + ScalarValue::Bool(real) => Ok(if real { 1.0 } else { 0.0 }), ScalarValue::C32(_) | ScalarValue::C64(_) => { Err("cannot convert complex scalar to f64") } @@ -783,7 +889,9 @@ impl Neg for Scalar { match self.value() { ScalarValue::F32(value) => Self::from_value(-value), ScalarValue::F64(value) => Self::from_value(-value), + ScalarValue::I32(value) => Self::from_i32(-value), ScalarValue::I64(value) => Self::from_i64(-value), + ScalarValue::Bool(value) => Self::from_real(if value { -1.0 } else { 0.0 }), ScalarValue::C32(value) => Self::from_value(-value), ScalarValue::C64(value) => Self::from_value(-value), } @@ -847,17 +955,57 @@ impl PartialOrd for Scalar { match (self.value(), other.value()) { (ScalarValue::F32(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&rhs), (ScalarValue::F32(lhs), ScalarValue::F64(rhs)) => (lhs as f64).partial_cmp(&rhs), + (ScalarValue::F32(lhs), ScalarValue::I32(rhs)) => { + (lhs as f64).partial_cmp(&(rhs as f64)) + } (ScalarValue::F32(lhs), ScalarValue::I64(rhs)) => { (lhs as f64).partial_cmp(&(rhs as f64)) } + (ScalarValue::F32(lhs), ScalarValue::Bool(rhs)) => { + (lhs as f64).partial_cmp(&(if rhs { 1.0 } else { 0.0 })) + } (ScalarValue::F64(lhs), ScalarValue::F32(rhs)) => lhs.partial_cmp(&(rhs as f64)), (ScalarValue::F64(lhs), ScalarValue::F64(rhs)) => lhs.partial_cmp(&rhs), + (ScalarValue::F64(lhs), ScalarValue::I32(rhs)) => lhs.partial_cmp(&(rhs as f64)), (ScalarValue::F64(lhs), ScalarValue::I64(rhs)) => lhs.partial_cmp(&(rhs as f64)), + (ScalarValue::F64(lhs), ScalarValue::Bool(rhs)) => { + lhs.partial_cmp(&(if rhs { 1.0 } else { 0.0 })) + } + (ScalarValue::I32(lhs), ScalarValue::F32(rhs)) => { + (lhs as f64).partial_cmp(&(rhs as f64)) + } + (ScalarValue::I32(lhs), ScalarValue::F64(rhs)) => (lhs as f64).partial_cmp(&rhs), + (ScalarValue::I32(lhs), ScalarValue::I32(rhs)) => lhs.partial_cmp(&rhs), + (ScalarValue::I32(lhs), ScalarValue::I64(rhs)) => { + (lhs as f64).partial_cmp(&(rhs as f64)) + } + (ScalarValue::I32(lhs), ScalarValue::Bool(rhs)) => { + (lhs as f64).partial_cmp(&(if rhs { 1.0 } else { 0.0 })) + } (ScalarValue::I64(lhs), ScalarValue::F32(rhs)) => { (lhs as f64).partial_cmp(&(rhs as f64)) } (ScalarValue::I64(lhs), ScalarValue::F64(rhs)) => (lhs as f64).partial_cmp(&rhs), + (ScalarValue::I64(lhs), ScalarValue::I32(rhs)) => { + (lhs as f64).partial_cmp(&(rhs as f64)) + } (ScalarValue::I64(lhs), ScalarValue::I64(rhs)) => lhs.partial_cmp(&rhs), + (ScalarValue::I64(lhs), ScalarValue::Bool(rhs)) => { + (lhs as f64).partial_cmp(&(if rhs { 1.0 } else { 0.0 })) + } + (ScalarValue::Bool(lhs), ScalarValue::F32(rhs)) => { + (if lhs { 1.0 } else { 0.0 }).partial_cmp(&(rhs as f64)) + } + (ScalarValue::Bool(lhs), ScalarValue::F64(rhs)) => { + (if lhs { 1.0 } else { 0.0 }).partial_cmp(&rhs) + } + (ScalarValue::Bool(lhs), ScalarValue::I32(rhs)) => { + (if lhs { 1.0 } else { 0.0 }).partial_cmp(&(rhs as f64)) + } + (ScalarValue::Bool(lhs), ScalarValue::I64(rhs)) => { + (if lhs { 1.0 } else { 0.0 }).partial_cmp(&(rhs as f64)) + } + (ScalarValue::Bool(lhs), ScalarValue::Bool(rhs)) => lhs.partial_cmp(&rhs), _ => None, } } @@ -868,7 +1016,9 @@ impl fmt::Display for Scalar { match self.value() { ScalarValue::F32(value) => value.fmt(f), ScalarValue::F64(value) => value.fmt(f), + ScalarValue::I32(value) => value.fmt(f), ScalarValue::I64(value) => value.fmt(f), + ScalarValue::Bool(value) => value.fmt(f), ScalarValue::C32(value) => value.fmt(f), ScalarValue::C64(value) => value.fmt(f), } diff --git a/crates/tensor4all-tensorbackend/src/any_scalar/tests/mod.rs b/crates/tensor4all-tensorbackend/src/any_scalar/tests/mod.rs index 42e27ec9..d8246b1c 100644 --- a/crates/tensor4all-tensorbackend/src/any_scalar/tests/mod.rs +++ b/crates/tensor4all-tensorbackend/src/any_scalar/tests/mod.rs @@ -165,10 +165,10 @@ fn promote_scalar_native_covers_all_scalar_type_pairs() { #[test] fn i64_native_scalar_is_supported_without_public_tensor_element() { - let scalar = Scalar::from_native(NativeTensor::from_vec(vec![], vec![-3_i64])) + let scalar = Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![-3_i64])) .expect("i64 native scalar"); - let zero = - Scalar::from_native(NativeTensor::from_vec(vec![], vec![0_i64])).expect("i64 zero scalar"); + let zero = Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![0_i64])) + .expect("i64 zero scalar"); assert_eq!(scalar.native.dtype(), DType::I64); assert_eq!(scalar.real(), -3.0); @@ -185,10 +185,63 @@ fn i64_native_scalar_is_supported_without_public_tensor_element() { assert_eq!(format!("{}", zero), "0"); } +#[test] +fn i32_native_scalar_is_supported_without_public_tensor_element() { + let scalar = Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![-4_i32])) + .expect("i32 native scalar"); + let zero = Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![0_i32])) + .expect("i32 zero scalar"); + + assert_eq!(scalar.native.dtype(), DType::I32); + assert_eq!(scalar.real(), -4.0); + assert_eq!(scalar.imag(), 0.0); + assert_eq!(scalar.abs(), 4.0); + assert!(!scalar.is_complex()); + assert!(scalar.is_real()); + assert!(!scalar.is_zero()); + assert!(zero.is_zero()); + assert_eq!(scalar.as_f64(), Some(-4.0)); + assert_eq!(scalar.as_c64(), None); + assert_eq!(Complex64::from(scalar.clone()), Complex64::new(-4.0, 0.0)); + assert_eq!(scalar.conj(), scalar); + assert_eq!((-scalar).as_f64(), Some(4.0)); + assert_eq!(format!("{}", zero), "0"); +} + +#[test] +fn bool_native_scalar_is_supported_without_public_tensor_element() { + let true_scalar = Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![true])) + .expect("bool true scalar"); + let false_scalar = Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![false])) + .expect("bool false scalar"); + + assert_eq!(true_scalar.native.dtype(), DType::Bool); + assert_eq!(true_scalar.real(), 1.0); + assert_eq!(true_scalar.imag(), 0.0); + assert_eq!(true_scalar.abs(), 1.0); + assert!(true_scalar.is_real()); + assert!(!true_scalar.is_zero()); + assert_eq!(true_scalar.as_f64(), Some(1.0)); + assert_eq!(true_scalar.as_c64(), None); + assert_eq!( + Complex64::from(true_scalar.clone()), + Complex64::new(1.0, 0.0) + ); + assert_eq!(true_scalar.conj(), true_scalar); + assert_eq!((-true_scalar).as_f64(), Some(-1.0)); + assert_eq!(format!("{}", false_scalar), "false"); + + assert_eq!(false_scalar.real(), 0.0); + assert_eq!(false_scalar.abs(), 0.0); + assert!(false_scalar.is_zero()); + assert_eq!(false_scalar.as_f64(), Some(0.0)); + assert_eq!((-false_scalar).as_f64(), Some(0.0)); +} + #[test] fn promote_i64_native_scalar_covers_supported_targets_and_rejections() { - let i64_scalar = - Scalar::from_native(NativeTensor::from_vec(vec![], vec![7_i64])).expect("i64 scalar"); + let i64_scalar = Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![7_i64])) + .expect("i64 scalar"); let promoted_f32 = Scalar::from_native(promote_scalar_native(i64_scalar.as_native(), DType::F32).unwrap()) @@ -231,9 +284,117 @@ fn promote_i64_native_scalar_covers_supported_targets_and_rejections() { .is_err()); } +#[test] +fn promote_i32_native_scalar_covers_supported_targets_and_rejections() { + let i32_scalar = Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![5_i32])) + .expect("i32 scalar"); + + let promoted_f32 = + Scalar::from_native(promote_scalar_native(i32_scalar.as_native(), DType::F32).unwrap()) + .expect("promoted f32"); + let promoted_f64 = + Scalar::from_native(promote_scalar_native(i32_scalar.as_native(), DType::F64).unwrap()) + .expect("promoted f64"); + let promoted_i32 = + Scalar::from_native(promote_scalar_native(i32_scalar.as_native(), DType::I32).unwrap()) + .expect("promoted i32"); + let promoted_i64 = + Scalar::from_native(promote_scalar_native(i32_scalar.as_native(), DType::I64).unwrap()) + .expect("promoted i64"); + let promoted_c32 = + Scalar::from_native(promote_scalar_native(i32_scalar.as_native(), DType::C32).unwrap()) + .expect("promoted c32"); + let promoted_c64 = + Scalar::from_native(promote_scalar_native(i32_scalar.as_native(), DType::C64).unwrap()) + .expect("promoted c64"); + + assert_eq!(promoted_f32.native.dtype(), DType::F32); + assert_eq!(promoted_f32.as_f64(), Some(5.0)); + assert_eq!(promoted_f64.native.dtype(), DType::F64); + assert_eq!(promoted_f64.as_f64(), Some(5.0)); + assert_eq!(promoted_i32.native.dtype(), DType::I32); + assert_eq!(promoted_i32.as_f64(), Some(5.0)); + assert_eq!(promoted_i64.native.dtype(), DType::I64); + assert_eq!(promoted_i64.as_f64(), Some(5.0)); + assert_eq!(promoted_c32.native.dtype(), DType::C32); + assert_eq!(promoted_c32.as_c64(), Some(Complex64::new(5.0, 0.0))); + assert_eq!(promoted_c64.native.dtype(), DType::C64); + assert_eq!(promoted_c64.as_c64(), Some(Complex64::new(5.0, 0.0))); + + assert!(promote_scalar_native(i32_scalar.as_native(), DType::Bool).is_err()); + assert!(promote_scalar_native(Scalar::from_value(1.25_f32).as_native(), DType::I32).is_err()); + assert!(promote_scalar_native(Scalar::from_value(1.25_f64).as_native(), DType::I32).is_err()); + assert!(promote_scalar_native( + Scalar::from_value(Complex32::new(1.0, 0.0)).as_native(), + DType::I32 + ) + .is_err()); + assert!(promote_scalar_native( + Scalar::from_value(Complex64::new(1.0, 0.0)).as_native(), + DType::I32 + ) + .is_err()); +} + +#[test] +fn promote_bool_native_scalar_covers_supported_targets_and_rejections() { + let true_scalar = Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![true])) + .expect("bool scalar"); + + let promoted_f32 = + Scalar::from_native(promote_scalar_native(true_scalar.as_native(), DType::F32).unwrap()) + .expect("promoted f32"); + let promoted_f64 = + Scalar::from_native(promote_scalar_native(true_scalar.as_native(), DType::F64).unwrap()) + .expect("promoted f64"); + let promoted_i32 = + Scalar::from_native(promote_scalar_native(true_scalar.as_native(), DType::I32).unwrap()) + .expect("promoted i32"); + let promoted_i64 = + Scalar::from_native(promote_scalar_native(true_scalar.as_native(), DType::I64).unwrap()) + .expect("promoted i64"); + let promoted_bool = + Scalar::from_native(promote_scalar_native(true_scalar.as_native(), DType::Bool).unwrap()) + .expect("promoted bool"); + let promoted_c32 = + Scalar::from_native(promote_scalar_native(true_scalar.as_native(), DType::C32).unwrap()) + .expect("promoted c32"); + let promoted_c64 = + Scalar::from_native(promote_scalar_native(true_scalar.as_native(), DType::C64).unwrap()) + .expect("promoted c64"); + + assert_eq!(promoted_f32.native.dtype(), DType::F32); + assert_eq!(promoted_f32.as_f64(), Some(1.0)); + assert_eq!(promoted_f64.native.dtype(), DType::F64); + assert_eq!(promoted_f64.as_f64(), Some(1.0)); + assert_eq!(promoted_i32.native.dtype(), DType::I32); + assert_eq!(promoted_i32.as_f64(), Some(1.0)); + assert_eq!(promoted_i64.native.dtype(), DType::I64); + assert_eq!(promoted_i64.as_f64(), Some(1.0)); + assert_eq!(promoted_bool.native.dtype(), DType::Bool); + assert_eq!(promoted_bool.as_f64(), Some(1.0)); + assert_eq!(promoted_c32.native.dtype(), DType::C32); + assert_eq!(promoted_c32.as_c64(), Some(Complex64::new(1.0, 0.0))); + assert_eq!(promoted_c64.native.dtype(), DType::C64); + assert_eq!(promoted_c64.as_c64(), Some(Complex64::new(1.0, 0.0))); + + assert!(promote_scalar_native(Scalar::from_value(1.25_f32).as_native(), DType::Bool).is_err()); + assert!(promote_scalar_native(Scalar::from_value(1.25_f64).as_native(), DType::Bool).is_err()); + assert!(promote_scalar_native( + Scalar::from_value(Complex32::new(1.0, 0.0)).as_native(), + DType::Bool + ) + .is_err()); + assert!(promote_scalar_native( + Scalar::from_value(Complex64::new(1.0, 0.0)).as_native(), + DType::Bool + ) + .is_err()); +} + #[test] fn promote_scalar_native_rejects_non_scalar_tensor() { - let tensor = NativeTensor::from_vec(vec![2], vec![1.0_f64, 2.0]); + let tensor = NativeTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]); let err = promote_scalar_native(&tensor, DType::F64).unwrap_err(); @@ -242,8 +403,8 @@ fn promote_scalar_native_rejects_non_scalar_tensor() { #[test] fn i64_native_scalar_participates_in_real_ordering() { - let i64_scalar = - Scalar::from_native(NativeTensor::from_vec(vec![], vec![3_i64])).expect("i64 scalar"); + let i64_scalar = Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![3_i64])) + .expect("i64 scalar"); assert_eq!( i64_scalar.partial_cmp(&Scalar::from_value(2.5_f32)), @@ -255,7 +416,63 @@ fn i64_native_scalar_participates_in_real_ordering() { ); assert_eq!( i64_scalar.partial_cmp( - &Scalar::from_native(NativeTensor::from_vec(vec![], vec![3_i64])).unwrap() + &Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![3_i64])).unwrap() + ), + Some(Ordering::Equal) + ); +} + +#[test] +fn i32_and_bool_native_scalars_participate_in_real_ordering() { + let i32_scalar = Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![2_i32])) + .expect("i32 scalar"); + let true_scalar = Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![true])) + .expect("bool true scalar"); + let false_scalar = Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![false])) + .expect("bool false scalar"); + + assert_eq!( + i32_scalar.partial_cmp(&Scalar::from_value(1.5_f32)), + Some(Ordering::Greater) + ); + assert_eq!( + i32_scalar.partial_cmp(&Scalar::from_value(2.5_f64)), + Some(Ordering::Less) + ); + assert_eq!( + i32_scalar.partial_cmp( + &Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![2_i32])).unwrap() + ), + Some(Ordering::Equal) + ); + assert_eq!( + i32_scalar.partial_cmp(&true_scalar), + Some(Ordering::Greater) + ); + assert_eq!( + true_scalar.partial_cmp(&Scalar::from_value(0.5_f32)), + Some(Ordering::Greater) + ); + assert_eq!( + true_scalar.partial_cmp(&Scalar::from_value(1.5_f64)), + Some(Ordering::Less) + ); + assert_eq!(true_scalar.partial_cmp(&i32_scalar), Some(Ordering::Less)); + assert_eq!( + true_scalar.partial_cmp( + &Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![1_i64])).unwrap() + ), + Some(Ordering::Equal) + ); + assert_eq!( + false_scalar.partial_cmp( + &Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![0_i32])).unwrap() + ), + Some(Ordering::Equal) + ); + assert_eq!( + false_scalar.partial_cmp( + &Scalar::from_native(NativeTensor::from_vec_col_major(vec![], vec![false])).unwrap() ), Some(Ordering::Equal) ); diff --git a/crates/tensor4all-tensorbackend/src/backend.rs b/crates/tensor4all-tensorbackend/src/backend.rs index 8f3fdf4f..b78abd51 100644 --- a/crates/tensor4all-tensorbackend/src/backend.rs +++ b/crates/tensor4all-tensorbackend/src/backend.rs @@ -21,7 +21,7 @@ use crate::matrix::Matrix; /// use tensor4all_tensorbackend::svd_backend; /// use tenferro::TypedTensor; /// -/// let a = TypedTensor::::from_vec(vec![2, 2], vec![1.0, 0.0, 0.0, 2.0]); +/// let a = TypedTensor::::from_vec_col_major(vec![2, 2], vec![1.0, 0.0, 0.0, 2.0]); /// let result = svd_backend(&a).unwrap(); /// /// assert_eq!(result.u.shape, vec![2, 2]); @@ -479,7 +479,7 @@ fn matrix_to_typed_tensor(matrix: &Matrix) -> TypedTensor where T: TensorScalar + Copy, { - TypedTensor::from_vec( + TypedTensor::from_vec_col_major( vec![matrix.nrows(), matrix.ncols()], matrix.as_col_major_slice().to_vec(), ) diff --git a/crates/tensor4all-tensorbackend/src/backend/tests/mod.rs b/crates/tensor4all-tensorbackend/src/backend/tests/mod.rs index 857934b0..cd16eb27 100644 --- a/crates/tensor4all-tensorbackend/src/backend/tests/mod.rs +++ b/crates/tensor4all-tensorbackend/src/backend/tests/mod.rs @@ -54,7 +54,7 @@ fn scale_columns_complex( #[test] fn qr_backend_reconstructs_real_matrix() { - let input = TypedTensor::from_vec(vec![2, 2], vec![1.0_f64, 3.0, 2.0, 4.0]); + let input = TypedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 3.0, 2.0, 4.0]); let (q, r) = qr_backend(&input).unwrap(); assert_eq!(q.shape, vec![2, 2]); @@ -75,7 +75,7 @@ fn qr_backend_reconstructs_real_matrix() { #[test] fn svd_backend_reconstructs_complex_matrix() { - let input = TypedTensor::from_vec( + let input = TypedTensor::from_vec_col_major( vec![2, 2], vec![ Complex64::new(1.0, -0.5), @@ -107,8 +107,8 @@ fn svd_backend_reconstructs_complex_matrix() { #[test] fn solve_backend_solves_real_system() { - let a = TypedTensor::from_vec(vec![2, 2], vec![2.0_f64, 1.0, 1.0, 2.0]); - let b = TypedTensor::from_vec(vec![2, 1], vec![1.0_f64, 0.0]); + let a = TypedTensor::from_vec_col_major(vec![2, 2], vec![2.0_f64, 1.0, 1.0, 2.0]); + let b = TypedTensor::from_vec_col_major(vec![2, 1], vec![1.0_f64, 0.0]); let x = solve_backend(&a, &b).unwrap(); @@ -293,8 +293,8 @@ fn triangular_solve_matrix_owned_solves_complex64_system() { #[test] fn triangular_solve_backend_solves_typed_tensor_system() { - let a = TypedTensor::from_vec(vec![2, 2], vec![2.0_f64, 1.0, 0.0, 3.0]); - let b = TypedTensor::from_vec(vec![2, 1], vec![2.0_f64, 7.0]); + let a = TypedTensor::from_vec_col_major(vec![2, 2], vec![2.0_f64, 1.0, 0.0, 3.0]); + let b = TypedTensor::from_vec_col_major(vec![2, 1], vec![2.0_f64, 7.0]); let x = triangular_solve_backend(&a, &b, true, true, false, false).unwrap(); @@ -314,7 +314,7 @@ fn try_into_typed_result_reports_dtype_mismatch() { #[test] fn full_piv_lu_backend_returns_square_factors() { - let input = TypedTensor::from_vec(vec![2, 2], vec![0.0_f64, 2.0, 1.0, 3.0]); + let input = TypedTensor::from_vec_col_major(vec![2, 2], vec![0.0_f64, 2.0, 1.0, 3.0]); let decomp = full_piv_lu_backend(&input).unwrap(); diff --git a/crates/tensor4all-tensorbackend/src/context.rs b/crates/tensor4all-tensorbackend/src/context.rs index 4bf5d27b..805b6a58 100644 --- a/crates/tensor4all-tensorbackend/src/context.rs +++ b/crates/tensor4all-tensorbackend/src/context.rs @@ -3,21 +3,22 @@ //! tensor4all-rs routes tenferro CPU execution through one process-global //! `CpuContext`, matching tenferro's `cpu:0` default-global thread-pool model. //! Plain tensor operations, cached traced execution, and eager AD currently use -//! separate `CpuBackend` values because tenferro does not yet expose a public -//! API for borrowing the backend owned by an `EagerContext`. All -//! backends are created from the same global CPU context, so thread-pool -//! configuration is shared. +//! separate `CpuBackend` values because tenferro does not expose a public API +//! for borrowing the backend owned by an `EagerRuntime`. All backends are +//! created from the same global CPU context, so thread-pool configuration is +//! shared. use std::sync::{Arc, Mutex, OnceLock}; -use tenferro::{CpuBackend, EagerContext, Engine}; +use tenferro::{CpuBackend, EagerRuntime, GraphCompiler, GraphExecutor}; use tenferro_tensor::buffer_pool::BufferPoolStats; use tenferro_tensor::cpu::CpuContext; static DEFAULT_CPU_CONTEXT: OnceLock> = OnceLock::new(); static DEFAULT_BACKEND: OnceLock> = OnceLock::new(); -static DEFAULT_ENGINE: OnceLock>> = OnceLock::new(); -static DEFAULT_EAGER_CTX: OnceLock> = OnceLock::new(); +static DEFAULT_GRAPH_COMPILER: OnceLock> = OnceLock::new(); +static DEFAULT_GRAPH_EXECUTOR: OnceLock>> = OnceLock::new(); +static DEFAULT_EAGER_RUNTIME: OnceLock> = OnceLock::new(); fn default_cpu_context() -> Arc { DEFAULT_CPU_CONTEXT @@ -29,13 +30,27 @@ fn default_backend() -> &'static Mutex { DEFAULT_BACKEND.get_or_init(|| Mutex::new(CpuBackend::from_context(default_cpu_context()))) } -fn default_engine() -> &'static Mutex> { - DEFAULT_ENGINE - .get_or_init(|| Mutex::new(Engine::new(CpuBackend::from_context(default_cpu_context())))) +fn default_graph_compiler() -> &'static Mutex { + DEFAULT_GRAPH_COMPILER.get_or_init(|| Mutex::new(GraphCompiler::new())) } -fn lock_default_engine() -> std::sync::MutexGuard<'static, Engine> { - match default_engine().lock() { +fn default_graph_executor() -> &'static Mutex> { + DEFAULT_GRAPH_EXECUTOR.get_or_init(|| { + Mutex::new(GraphExecutor::new(CpuBackend::from_context( + default_cpu_context(), + ))) + }) +} + +fn lock_default_graph_compiler() -> std::sync::MutexGuard<'static, GraphCompiler> { + match default_graph_compiler().lock() { + Ok(guard) => guard, + Err(poisoned) => poisoned.into_inner(), + } +} + +fn lock_default_graph_executor() -> std::sync::MutexGuard<'static, GraphExecutor> { + match default_graph_executor().lock() { Ok(guard) => guard, Err(poisoned) => poisoned.into_inner(), } @@ -53,44 +68,49 @@ pub fn with_default_backend(f: impl FnOnce(&mut CpuBackend) -> R) -> R { f(&mut backend) } -/// Run a closure against the process-global tenferro execution engine. +/// Run a closure against the process-global tenferro graph compiler/executor. /// /// This is used for native tensor operations that benefit from tenferro's /// persistent execution caches, such as N-ary einsum contraction paths. -pub(crate) fn with_default_engine(f: impl FnOnce(&mut Engine) -> R) -> R { - let mut engine = lock_default_engine(); - f(&mut engine) +pub(crate) fn with_default_graph_runtime( + f: impl FnOnce(&mut GraphCompiler, &mut GraphExecutor) -> R, +) -> R { + let mut compiler = lock_default_graph_compiler(); + let mut executor = lock_default_graph_executor(); + f(&mut compiler, &mut executor) } -/// Return retained-buffer statistics for the process-global execution engine. +/// Return retained-buffer statistics for the process-global graph executor. pub(crate) fn default_engine_buffer_pool_stats() -> BufferPoolStats { - lock_default_engine().buffer_pool_stats() + lock_default_graph_executor().buffer_pool_stats() } -/// Reset retained buffers in the process-global execution engine. +/// Reset retained buffers in the process-global graph executor. pub(crate) fn reset_default_engine_buffer_pool() { - lock_default_engine().reset_buffer_pool(); + lock_default_graph_executor().reset_buffer_pool(); } -/// Drop and recreate the process-global execution engine. +/// Drop and recreate the process-global graph compiler/executor. /// /// This releases tenferro's retained execution buffers and cached contraction /// paths. It is intended for diagnostics and memory-pressure recovery, not for /// normal hot loops where the caches are valuable. pub(crate) fn reset_default_engine() { - let mut engine = lock_default_engine(); - *engine = Engine::new(CpuBackend::from_context(default_cpu_context())); + let mut compiler = lock_default_graph_compiler(); + *compiler = GraphCompiler::new(); + let mut executor = lock_default_graph_executor(); + *executor = GraphExecutor::new(CpuBackend::from_context(default_cpu_context())); } /// Return the process-global eager context used for reverse-mode AD. /// /// This context owns a separate `CpuBackend` from [`with_default_backend`] and -/// the cached execution engine, but all backends share the same process-global +/// the cached graph executor, but all backends share the same process-global /// tenferro CPU context. -pub fn default_eager_ctx() -> Arc { - DEFAULT_EAGER_CTX +pub fn default_eager_ctx() -> Arc { + DEFAULT_EAGER_RUNTIME .get_or_init(|| { - EagerContext::with_cpu_backend(CpuBackend::from_context(default_cpu_context())) + EagerRuntime::with_cpu_backend(CpuBackend::from_context(default_cpu_context())) }) .clone() } @@ -130,11 +150,13 @@ mod tests { #[test] fn default_engine_is_shared_across_threads() { - let main_threads = with_default_engine(|engine| engine.backend().num_threads()); - let worker_threads = - std::thread::spawn(|| with_default_engine(|engine| engine.backend().num_threads())) - .join() - .expect("worker thread should complete"); + let main_threads = + with_default_graph_runtime(|_, executor| executor.backend().num_threads()); + let worker_threads = std::thread::spawn(|| { + with_default_graph_runtime(|_, executor| executor.backend().num_threads()) + }) + .join() + .expect("worker thread should complete"); assert_eq!(main_threads, worker_threads); } diff --git a/crates/tensor4all-tensorbackend/src/matrix.rs b/crates/tensor4all-tensorbackend/src/matrix.rs index 4db5f684..d615138c 100644 --- a/crates/tensor4all-tensorbackend/src/matrix.rs +++ b/crates/tensor4all-tensorbackend/src/matrix.rs @@ -61,7 +61,7 @@ pub struct Matrix { /// use tenferro::TypedTensor; /// use tensor4all_tensorbackend::Matrix; /// -/// let tensor = TypedTensor::from_vec(vec![2, 1, 1], vec![1.0_f64, 2.0]); +/// let tensor = TypedTensor::from_vec_col_major(vec![2, 1, 1], vec![1.0_f64, 2.0]); /// let err = Matrix::try_from_typed_tensor(tensor).unwrap_err(); /// assert!(err.to_string().contains("rank-2 tensor")); /// ``` @@ -175,7 +175,7 @@ impl Matrix { where T: TensorScalar, { - TypedTensor::from_vec(vec![self.nrows, self.ncols], self.data.clone()) + TypedTensor::from_vec_col_major(vec![self.nrows, self.ncols], self.data.clone()) } /// Consume this matrix as a tenferro [`TypedTensor`] without cloning. @@ -197,7 +197,7 @@ impl Matrix { where T: TensorScalar, { - TypedTensor::from_vec(vec![self.nrows, self.ncols], self.data) + TypedTensor::from_vec_col_major(vec![self.nrows, self.ncols], self.data) } /// Consume a rank-2 tenferro [`TypedTensor`] as a [`Matrix`]. @@ -218,7 +218,7 @@ impl Matrix { /// use tenferro::TypedTensor; /// use tensor4all_tensorbackend::Matrix; /// - /// let tensor = TypedTensor::from_vec(vec![2, 2], vec![1.0_f64, 3.0, 2.0, 4.0]); + /// let tensor = TypedTensor::from_vec_col_major(vec![2, 2], vec![1.0_f64, 3.0, 2.0, 4.0]); /// let m = Matrix::try_from_typed_tensor(tensor).unwrap(); /// assert_eq!(m.nrows(), 2); /// assert_eq!(m.ncols(), 2); @@ -230,12 +230,11 @@ impl Matrix { where T: Clone, { - let (shape, data) = - tensor - .try_into_vec() - .map_err(|source| MatrixTensorConversionError::HostBuffer { - message: source.to_string(), - })?; + let (shape, data) = tensor.try_into_vec_col_major().map_err(|source| { + MatrixTensorConversionError::HostBuffer { + message: source.to_string(), + } + })?; if shape.len() != 2 { return Err(MatrixTensorConversionError::Rank { shape }); } @@ -814,7 +813,7 @@ where .context("batched matrix multiplication failed")?; let c = T::try_into_typed(c) .ok_or_else(|| anyhow::anyhow!("batched matrix multiplication returned wrong dtype"))?; - let (_shape, data) = c.try_into_vec()?; + let (_shape, data) = c.try_into_vec_col_major()?; let expected_len = batch .checked_mul(m) .and_then(|value| value.checked_mul(n)) diff --git a/crates/tensor4all-tensorbackend/src/matrix/tests/mod.rs b/crates/tensor4all-tensorbackend/src/matrix/tests/mod.rs index c3b11ada..42f44112 100644 --- a/crates/tensor4all-tensorbackend/src/matrix/tests/mod.rs +++ b/crates/tensor4all-tensorbackend/src/matrix/tests/mod.rs @@ -58,7 +58,7 @@ fn matrix_into_typed_tensor_consumes_column_major_layout() { #[test] fn matrix_from_typed_tensor_consumes_column_major_layout() { - let tensor = TypedTensor::from_vec(vec![2, 2], vec![1.0, 3.0, 2.0, 4.0]); + let tensor = TypedTensor::from_vec_col_major(vec![2, 2], vec![1.0, 3.0, 2.0, 4.0]); let m = Matrix::try_from_typed_tensor(tensor).unwrap(); @@ -70,7 +70,7 @@ fn matrix_from_typed_tensor_consumes_column_major_layout() { #[test] fn matrix_from_typed_tensor_rejects_non_matrix_rank() { - let tensor = TypedTensor::from_vec(vec![2, 1, 1], vec![1.0, 2.0]); + let tensor = TypedTensor::from_vec_col_major(vec![2, 1, 1], vec![1.0, 2.0]); let err = Matrix::try_from_typed_tensor(tensor).unwrap_err(); diff --git a/crates/tensor4all-tensorbackend/src/tenferro_bridge.rs b/crates/tensor4all-tensorbackend/src/tenferro_bridge.rs index bb951fc0..617c255f 100644 --- a/crates/tensor4all-tensorbackend/src/tenferro_bridge.rs +++ b/crates/tensor4all-tensorbackend/src/tenferro_bridge.rs @@ -9,17 +9,17 @@ use std::time::{Duration, Instant}; use anyhow::{anyhow, ensure, Result}; use num_complex::{Complex32, Complex64}; use omeco::ScoreFunction; -use tenferro::traced_tensor::{einsum_subscripts_with, EinsumOptimize}; use tenferro::{ - DType, EinsumSubscripts, Tensor as NativeTensor, TensorBackend, TensorRead, TensorView, - TracedTensor, + DType, Tensor as NativeTensor, TensorBackend, TensorRead, TensorView, TracedTensor, +}; +use tenferro_einsum::{ + ContractionOptimizerOptions, ContractionTree, EinsumOptimize, EinsumSubscripts, Subscripts, }; -use tenferro_einsum::{ContractionOptimizerOptions, ContractionTree, Subscripts}; use crate::any_scalar::promote_scalar_native; use crate::context::{ default_engine_buffer_pool_stats, reset_default_engine, reset_default_engine_buffer_pool, - with_default_backend, with_default_engine, + with_default_backend, with_default_graph_runtime, }; use crate::memory::release_process_allocator_cached_memory; use crate::storage::Storage; @@ -209,7 +209,9 @@ fn dtype_size_bytes(dtype: DType) -> usize { DType::F64 => 8, DType::C32 => 8, DType::C64 => 16, + DType::I32 => 4, DType::I64 => 8, + DType::Bool => 1, } } @@ -499,14 +501,18 @@ fn common_dtype(dtypes: &[DType]) -> DType { let has_f64 = dtypes.contains(&DType::F64); let has_c64 = dtypes.contains(&DType::C64); let has_c32 = dtypes.contains(&DType::C32); + let has_i32 = dtypes.contains(&DType::I32); let has_i64 = dtypes.contains(&DType::I64); + let has_bool = dtypes.contains(&DType::Bool); let has_complex = has_c64 || has_c32; if has_c64 || (has_f64 && has_complex) { DType::C64 } else if has_c32 { DType::C32 - } else if has_f64 || has_i64 { + } else if has_f64 || has_i64 || has_i32 { DType::F64 + } else if has_bool { + DType::Bool } else { DType::F32 } @@ -559,20 +565,34 @@ fn cached_einsum_native_tensors( .zip(inputs.iter()) .map(|(placeholder, tensor)| (placeholder, *tensor)) .collect::>(); + let input_specs = placeholders + .iter() + .zip(inputs.iter()) + .map(|(placeholder, tensor)| (placeholder, tensor.dtype(), tensor.shape())) + .collect::>(); let trace_pool = native_einsum_pool_trace_enabled(); let pool_before = trace_pool.then(default_engine_buffer_pool_stats); - let result = with_default_engine(|engine| { - let mut result = einsum_subscripts_with( - engine, + let result = with_default_graph_runtime(|compiler, executor| { + executor + .register_extension(tenferro_einsum::register_runtime) + .map_err(|e| anyhow!("native einsum runtime registration failed: {e}"))?; + let result = tenferro_einsum::einsum_subscripts_with( + compiler, &placeholder_refs, subscripts, EinsumOptimize::default(), ) .map_err(|e| anyhow!("native einsum failed: {e}"))?; - result - .eval_with_inputs(engine, &bindings) - .cloned() + // Native einsum creates fresh placeholder tensors for each call. + // tenferro's compiled-program cache retains those placeholder keys, so + // reuse can leave the program expecting bindings from a previous call. + compiler.clear_compile_cache(); + let program = compiler + .compile_with_input_specs(&result, &input_specs) + .map_err(|e| anyhow!("native einsum graph compilation failed: {e}"))?; + executor + .run_with_inputs(&program, &bindings) .map_err(|e| anyhow!("native einsum failed: {e}")) })?; if trace_pool { @@ -648,10 +668,11 @@ fn cached_einsum_native_reads( inputs: &[TensorRead<'_>], subscripts: &Subscripts, ) -> Result { - with_default_backend(|backend| { - tenferro_einsum::eager_einsum_read_subscripts(backend, inputs, subscripts) - .map_err(|e| anyhow!("native read einsum failed: {e}")) - }) + let tensors = inputs.iter().map(TensorRead::to_tensor).collect::>(); + let tensor_refs = tensors.iter().collect::>(); + let einsum_subscripts = EinsumSubscripts::from(subscripts); + cached_einsum_native_tensors(&tensor_refs, &einsum_subscripts) + .map_err(|e| anyhow!("native read einsum failed: {e}")) } /// Build native einsum ids for a binary contraction. @@ -769,12 +790,14 @@ pub fn storage_payload_native_read_input(storage: &Storage) -> Result Result Result .to_vec(), tensor.shape(), ), + DType::I32 => Storage::from_dense_col_major( + tensor + .as_slice::() + .ok_or_else(|| anyhow!("failed to read i32 native tensor"))? + .iter() + .map(|&value| value as f64) + .collect::>(), + tensor.shape(), + ), DType::I64 => Storage::from_dense_col_major( tensor .as_slice::() @@ -823,6 +857,15 @@ pub fn native_tensor_primal_to_storage(tensor: &NativeTensor) -> Result .collect::>(), tensor.shape(), ), + DType::Bool => Storage::from_dense_col_major( + tensor + .as_slice::() + .ok_or_else(|| anyhow!("failed to read bool native tensor"))? + .iter() + .map(|&value| if value { 1.0 } else { 0.0 }) + .collect::>(), + tensor.shape(), + ), DType::C32 => Storage::from_dense_col_major( tensor .as_slice::() @@ -853,12 +896,24 @@ pub fn native_tensor_primal_to_dense_f64_col_major(tensor: &NativeTensor) -> Res .map(|&value| value as f64) .collect()), DType::F64 => ::dense_values_from_native_col_major(tensor), + DType::I32 => Ok(tensor + .as_slice::() + .ok_or_else(|| anyhow!("failed to read i32 native tensor"))? + .iter() + .map(|&value| value as f64) + .collect()), DType::I64 => Ok(tensor .as_slice::() .ok_or_else(|| anyhow!("failed to read i64 native tensor"))? .iter() .map(|&value| value as f64) .collect()), + DType::Bool => Ok(tensor + .as_slice::() + .ok_or_else(|| anyhow!("failed to read bool native tensor"))? + .iter() + .map(|&value| if value { 1.0 } else { 0.0 }) + .collect()), other => Err(anyhow!("expected real native tensor, got dtype {other:?}")), } } @@ -896,10 +951,18 @@ pub fn native_tensor_primal_to_diag_f64(tensor: &NativeTensor) -> Result::diag_values_from_native_temp(&promoted) } DType::F64 => ::diag_values_from_native_temp(tensor), + DType::I32 => { + let promoted = convert_tensor(tensor, DType::F64)?; + ::diag_values_from_native_temp(&promoted) + } DType::I64 => { let promoted = convert_tensor(tensor, DType::F64)?; ::diag_values_from_native_temp(&promoted) } + DType::Bool => { + let promoted = convert_tensor(tensor, DType::F64)?; + ::diag_values_from_native_temp(&promoted) + } other => Err(anyhow!("expected real native tensor, got dtype {other:?}")), } } @@ -978,7 +1041,10 @@ pub fn scale_native_tensor(tensor: &NativeTensor, scalar: &AnyScalar) -> Result< .iter() .map(|&value| value * factor) .collect::>(); - Ok(NativeTensor::from_vec(tensor.shape().to_vec(), values)) + Ok(NativeTensor::from_vec_col_major( + tensor.shape().to_vec(), + values, + )) } DType::F64 => { let factor = scalar @@ -991,7 +1057,10 @@ pub fn scale_native_tensor(tensor: &NativeTensor, scalar: &AnyScalar) -> Result< .iter() .map(|&value| value * factor) .collect::>(); - Ok(NativeTensor::from_vec(tensor.shape().to_vec(), values)) + Ok(NativeTensor::from_vec_col_major( + tensor.shape().to_vec(), + values, + )) } DType::C32 => { let factor = scalar @@ -1004,7 +1073,10 @@ pub fn scale_native_tensor(tensor: &NativeTensor, scalar: &AnyScalar) -> Result< .iter() .map(|&value| value * factor) .collect::>(); - Ok(NativeTensor::from_vec(tensor.shape().to_vec(), values)) + Ok(NativeTensor::from_vec_col_major( + tensor.shape().to_vec(), + values, + )) } DType::C64 => { let factor = scalar @@ -1017,9 +1089,14 @@ pub fn scale_native_tensor(tensor: &NativeTensor, scalar: &AnyScalar) -> Result< .iter() .map(|&value| value * factor) .collect::>(); - Ok(NativeTensor::from_vec(tensor.shape().to_vec(), values)) + Ok(NativeTensor::from_vec_col_major( + tensor.shape().to_vec(), + values, + )) } - DType::I64 => Err(anyhow!("scale_native_tensor does not support i64 tensors")), + DType::I32 | DType::I64 | DType::Bool => Err(anyhow!( + "scale_native_tensor does not support integer/bool tensors" + )), } } @@ -1069,7 +1146,10 @@ pub fn axpby_native_tensor( .zip(rhs_values.iter()) .map(|(&x, &y)| a * x + b * y) .collect::>(); - Ok(NativeTensor::from_vec(lhs.shape().to_vec(), values)) + Ok(NativeTensor::from_vec_col_major( + lhs.shape().to_vec(), + values, + )) } DType::F64 => { let a = a @@ -1091,7 +1171,10 @@ pub fn axpby_native_tensor( .zip(rhs_values.iter()) .map(|(&x, &y)| a * x + b * y) .collect::>(); - Ok(NativeTensor::from_vec(lhs.shape().to_vec(), values)) + Ok(NativeTensor::from_vec_col_major( + lhs.shape().to_vec(), + values, + )) } DType::C32 => { let a = a @@ -1113,7 +1196,10 @@ pub fn axpby_native_tensor( .zip(rhs_values.iter()) .map(|(&x, &y)| a * x + b * y) .collect::>(); - Ok(NativeTensor::from_vec(lhs.shape().to_vec(), values)) + Ok(NativeTensor::from_vec_col_major( + lhs.shape().to_vec(), + values, + )) } DType::C64 => { let a = a @@ -1135,9 +1221,14 @@ pub fn axpby_native_tensor( .zip(rhs_values.iter()) .map(|(&x, &y)| a * x + b * y) .collect::>(); - Ok(NativeTensor::from_vec(lhs.shape().to_vec(), values)) + Ok(NativeTensor::from_vec_col_major( + lhs.shape().to_vec(), + values, + )) } - DType::I64 => Err(anyhow!("axpby_native_tensor does not support i64 tensors")), + DType::I32 | DType::I64 | DType::Bool => Err(anyhow!( + "axpby_native_tensor does not support integer/bool tensors" + )), } } @@ -1165,8 +1256,8 @@ pub fn axpby_native_tensor( /// use tensor4all_tensorbackend::einsum_native_tensors_owned; /// use tenferro::Tensor as NativeTensor; /// -/// let lhs = NativeTensor::from_vec(vec![2, 3], vec![1.0_f64; 6]); -/// let rhs = NativeTensor::from_vec(vec![3, 2], vec![1.0_f64; 6]); +/// let lhs = NativeTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]); +/// let rhs = NativeTensor::from_vec_col_major(vec![3, 2], vec![1.0_f64; 6]); /// let result = einsum_native_tensors_owned(vec![(lhs, vec![0, 1]), (rhs, vec![1, 2])], &[0, 2]).unwrap(); /// /// assert_eq!(result.shape(), &[2, 2]); @@ -1258,8 +1349,8 @@ pub fn einsum_native_tensors_owned( /// use tensor4all_tensorbackend::einsum_native_tensors; /// use tenferro::Tensor as NativeTensor; /// -/// let lhs = NativeTensor::from_vec(vec![2, 3], vec![1.0_f64; 6]); -/// let rhs = NativeTensor::from_vec(vec![3, 2], vec![1.0_f64; 6]); +/// let lhs = NativeTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64; 6]); +/// let rhs = NativeTensor::from_vec_col_major(vec![3, 2], vec![1.0_f64; 6]); /// let result = einsum_native_tensors(&[(&lhs, &[0, 1]), (&rhs, &[1, 2])], &[0, 2]).unwrap(); /// /// assert_eq!(result.shape(), &[2, 2]); @@ -1412,8 +1503,8 @@ pub fn outer_product_native_tensor(lhs: &NativeTensor, rhs: &NativeTensor) -> Re /// Conjugate a native tensor. pub fn conj_native_tensor(tensor: &NativeTensor) -> Result { match tensor.dtype() { - DType::F32 | DType::F64 | DType::I64 => Ok(tensor.clone()), - DType::C32 => Ok(NativeTensor::from_vec( + DType::F32 | DType::F64 | DType::I32 | DType::I64 | DType::Bool => Ok(tensor.clone()), + DType::C32 => Ok(NativeTensor::from_vec_col_major( tensor.shape().to_vec(), tensor .as_slice::() @@ -1422,7 +1513,7 @@ pub fn conj_native_tensor(tensor: &NativeTensor) -> Result { .map(|&value| value.conj()) .collect::>(), )), - DType::C64 => Ok(NativeTensor::from_vec( + DType::C64 => Ok(NativeTensor::from_vec_col_major( tensor.shape().to_vec(), tensor .as_slice::() diff --git a/crates/tensor4all-tensorbackend/src/tenferro_bridge/tests/mod.rs b/crates/tensor4all-tensorbackend/src/tenferro_bridge/tests/mod.rs index 16440446..9ad5c0d2 100644 --- a/crates/tensor4all-tensorbackend/src/tenferro_bridge/tests/mod.rs +++ b/crates/tensor4all-tensorbackend/src/tenferro_bridge/tests/mod.rs @@ -36,13 +36,9 @@ fn recorded_native_einsum_call_count(path: NativeEinsumPath) -> usize { }) } -fn default_engine_contains_einsum_subscripts_key( - inputs: &[&[u32]], - output: &[u32], - shapes: Vec>, -) -> bool { - crate::context::with_default_engine(|engine| { - engine.einsum_cache_contains_subscripts(&(EinsumSubscripts::new(inputs, output), shapes)) +fn default_graph_runtime_has_einsum_extension_cache_entries() -> bool { + crate::context::with_default_graph_runtime(|compiler, _| { + compiler.cache_stats().extensions.entries > 0 }) } @@ -258,9 +254,9 @@ fn native_einsum_accepts_unsorted_nonfirst_operand_labels() { #[test] fn einsum_native_tensors_supports_retained_shared_nary_label() { - let a = NativeTensor::from_vec(vec![2, 2], vec![5.0_f64, 7.0, 11.0, 13.0]); - let b = NativeTensor::from_vec(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]); - let c = NativeTensor::from_vec(vec![2, 2], vec![11.0_f64, 13.0, 17.0, 19.0]); + let a = NativeTensor::from_vec_col_major(vec![2, 2], vec![5.0_f64, 7.0, 11.0, 13.0]); + let b = NativeTensor::from_vec_col_major(vec![2, 3], vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]); + let c = NativeTensor::from_vec_col_major(vec![2, 2], vec![11.0_f64, 13.0, 17.0, 19.0]); let out = einsum_native_tensors( &[(&a, &[0, 1]), (&b, &[0, 2]), (&c, &[0, 3])], @@ -293,9 +289,9 @@ fn einsum_native_tensors_supports_retained_shared_nary_label() { #[test] fn einsum_native_tensors_populates_process_global_path_cache() { - let a = NativeTensor::from_vec(vec![2, 3, 4], vec![1.0_f64; 24]); - let b = NativeTensor::from_vec(vec![4, 5], vec![2.0_f64; 20]); - let c = NativeTensor::from_vec(vec![3, 2], vec![3.0_f64; 6]); + let a = NativeTensor::from_vec_col_major(vec![2, 3, 4], vec![1.0_f64; 24]); + let b = NativeTensor::from_vec_col_major(vec![4, 5], vec![2.0_f64; 20]); + let c = NativeTensor::from_vec_col_major(vec![3, 2], vec![3.0_f64; 6]); let out = einsum_native_tensors(&[(&a, &[0, 1, 2]), (&b, &[2, 3]), (&c, &[1, 0])], &[3]).unwrap(); @@ -305,18 +301,14 @@ fn einsum_native_tensors_populates_process_global_path_cache() { native_tensor_primal_to_dense_f64_col_major(&out).unwrap(), vec![144.0; 5] ); - assert!(default_engine_contains_einsum_subscripts_key( - &[&[0, 1, 2], &[2, 3], &[1, 0]], - &[3], - vec![vec![2, 3, 4], vec![4, 5], vec![3, 2]] - )); + assert!(default_graph_runtime_has_einsum_extension_cache_entries()); } #[test] fn einsum_native_tensors_mixed_dtype_records_borrowed_conversion_profile() { let _guard = ProfileGuard::enable(); - let lhs = NativeTensor::from_vec(vec![2, 2], vec![1.0_f32, 2.0, 3.0, 4.0]); - let rhs = NativeTensor::from_vec(vec![2, 3], vec![5.0_f64, 6.0, 7.0, 8.0, 9.0, 10.0]); + let lhs = NativeTensor::from_vec_col_major(vec![2, 2], vec![1.0_f32, 2.0, 3.0, 4.0]); + let rhs = NativeTensor::from_vec_col_major(vec![2, 3], vec![5.0_f64, 6.0, 7.0, 8.0, 9.0, 10.0]); let owned = einsum_native_tensors_owned( vec![(lhs.clone(), vec![0, 1]), (rhs.clone(), vec![1, 2])], @@ -369,7 +361,7 @@ fn einsum_native_tensors_dense_binary_records_borrowed_profile() { #[test] fn native_read_input_owned_and_plan_helpers_cover_debug_paths() { - let tensor = NativeTensor::from_vec(vec![2], vec![1.0_f64, 2.0]); + let tensor = NativeTensor::from_vec_col_major(vec![2], vec![1.0_f64, 2.0]); let input = NativeTensorReadInput::Owned(tensor.clone()); assert_eq!(input.dtype(), DType::F64); assert_eq!(input.shape(), &[2]); @@ -431,11 +423,11 @@ fn native_read_input_owned_and_plan_helpers_cover_debug_paths() { fn native_einsum_profile_print_and_c32_arithmetic_paths() { let _guard = ProfileGuard::enable(); - let lhs = NativeTensor::from_vec( + let lhs = NativeTensor::from_vec_col_major( vec![2], vec![Complex32::new(1.0, 2.0), Complex32::new(-3.0, 0.5)], ); - let rhs = NativeTensor::from_vec( + let rhs = NativeTensor::from_vec_col_major( vec![2], vec![Complex32::new(0.5, -1.0), Complex32::new(4.0, 2.0)], ); diff --git a/crates/tensor4all-tensorbackend/src/tensor_element.rs b/crates/tensor4all-tensorbackend/src/tensor_element.rs index 34557a9e..4be6da12 100644 --- a/crates/tensor4all-tensorbackend/src/tensor_element.rs +++ b/crates/tensor4all-tensorbackend/src/tensor_element.rs @@ -41,7 +41,9 @@ fn tensor_dtype_name(dtype: DType) -> &'static str { match dtype { DType::F32 => "f32", DType::F64 => "f64", + DType::I32 => "i32", DType::I64 => "i64", + DType::Bool => "bool", DType::C32 => "c32", DType::C64 => "c64", } @@ -84,7 +86,10 @@ macro_rules! impl_tensor_element { dims, expected_len ); - Ok(NativeTensor::from_vec(dims.to_vec(), data.to_vec())) + Ok(NativeTensor::from_vec_col_major( + dims.to_vec(), + data.to_vec(), + )) } fn diag_native_tensor_from_col_major( @@ -97,7 +102,7 @@ macro_rules! impl_tensor_element { } fn scalar_native_tensor(value: Self) -> Result { - Ok(NativeTensor::from_vec(vec![], vec![value])) + Ok(NativeTensor::from_vec_col_major(vec![], vec![value])) } fn dense_values_from_native_col_major(tensor: &NativeTensor) -> Result> { diff --git a/crates/tensor4all-treetci/src/materialize.rs b/crates/tensor4all-treetci/src/materialize.rs index 5e077e8e..0f1e9001 100644 --- a/crates/tensor4all-treetci/src/materialize.rs +++ b/crates/tensor4all-treetci/src/materialize.rs @@ -52,8 +52,9 @@ macro_rules! impl_full_piv_lu_scalar { let lhs_t = transpose_column_major(lhs_values, lhs_rows, lhs_cols); let pivot_t = transpose_column_major(pivot_values, pivot_rows, pivot_cols); - let pivot_tensor = Tensor::from_vec(vec![pivot_cols, pivot_rows], pivot_t); - let lhs_tensor = Tensor::from_vec(vec![lhs_cols, lhs_rows], lhs_t); + let pivot_tensor = + Tensor::from_vec_col_major(vec![pivot_cols, pivot_rows], pivot_t); + let lhs_tensor = Tensor::from_vec_col_major(vec![lhs_cols, lhs_rows], lhs_t); let solved_t = with_default_backend(|backend| { pivot_tensor.full_piv_lu_solve(&lhs_tensor, backend) })?;