diff --git a/include/sampler_id/interp_decomp.h b/include/sampler_id/interp_decomp.h index 7456d8b..d63bfb0 100644 --- a/include/sampler_id/interp_decomp.h +++ b/include/sampler_id/interp_decomp.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include namespace ttid { @@ -16,9 +17,7 @@ struct RowID { std::vector sv; }; -ColID interp_decomp_cols(const torch::Tensor& A, int64_t k); -ColID interp_decomp_cols(const torch::Tensor& A, double tol); -RowID interp_decomp_rows(const torch::Tensor& A, int64_t k); -RowID interp_decomp_rows(const torch::Tensor& A, double tol); +ColID interp_decomp_cols(const torch::Tensor& A, double tol = 0.0, int64_t k = INT64_MAX); +RowID interp_decomp_rows(const torch::Tensor& A, double tol = 0.0, int64_t k = INT64_MAX); } // namespace ttid diff --git a/include/sampler_id/tt.h b/include/sampler_id/tt.h index c57e040..8629d75 100644 --- a/include/sampler_id/tt.h +++ b/include/sampler_id/tt.h @@ -32,7 +32,7 @@ class TensorTrain { auto M = C.reshape({r * dk, rest}); auto [rows, Q, sv] = (tol > 0.0) ? interp_decomp_rows(M, tol) - : interp_decomp_rows(M, std::min(r * dk, rest)); + : interp_decomp_rows(M, 0.0, std::min(r * dk, rest)); int64_t r_new = (int64_t)rows.size(); cores.push_back(Q.reshape({r, dk, r_new})); C = M.index({torch::tensor(rows, torch::kLong), torch::indexing::Slice()}); diff --git a/include/sampler_id/tt_id.h b/include/sampler_id/tt_id.h index 9cfe853..98ddd92 100644 --- a/include/sampler_id/tt_id.h +++ b/include/sampler_id/tt_id.h @@ -103,18 +103,18 @@ struct tt_id { auto M = tensor_fun_to_mat(Ib, Jb); if (b < center) { - auto res = interp_decomp_rows(M, param.reltol); + auto res = interp_decomp_rows(M, param.reltol, param.bondDim); int64_t k = (int64_t)res.rows.size(); - auto res_col = interp_decomp_cols(M, k); + auto res_col = interp_decomp_cols(M, 0.0, k); Iset[b + 1] = Ib.at(res.rows); Jset[b] = Jb.at(res_col.cols); tt.cores[b] = res.P; tt.cores[b + 1] = M.index({torch::tensor(res.rows, torch::kLong), torch::indexing::Slice()}); collectPivotError(res.sv); } else { - auto res = interp_decomp_cols(M, param.reltol); + auto res = interp_decomp_cols(M, param.reltol, param.bondDim); int64_t k = (int64_t)res.cols.size(); - auto res_row = interp_decomp_rows(M, k); + auto res_row = interp_decomp_rows(M, 0.0, k); Jset[b] = Jb.at(res.cols); Iset[b + 1] = Ib.at(res_row.rows); tt.cores[b] = M.index({torch::indexing::Slice(), torch::tensor(res.cols, torch::kLong)}); @@ -133,14 +133,14 @@ struct tt_id { auto M = tensor_fun_to_mat(Ib, Jb); if (isLeft) { - auto res = interp_decomp_rows(M, param.reltol); + auto res = interp_decomp_rows(M, param.reltol, param.bondDim); Iset[b + 1] = Ib.at(res.rows); Jset[b] = Iset[b + 1]; tt.cores[b] = res.P; tt.cores[b + 1] = M.index({torch::tensor(res.rows, torch::kLong), torch::indexing::Slice()}); collectPivotError(res.sv); } else { - auto res = interp_decomp_cols(M, param.reltol); + auto res = interp_decomp_cols(M, param.reltol, param.bondDim); Jset[b] = Jb.at(res.cols); tt.cores[b] = M.index({torch::indexing::Slice(), torch::tensor(res.cols, torch::kLong)}); tt.cores[b + 1] = res.P; diff --git a/src/interp_decomp.cpp b/src/interp_decomp.cpp index 513f749..425001e 100644 --- a/src/interp_decomp.cpp +++ b/src/interp_decomp.cpp @@ -37,27 +37,24 @@ static ColID build_col_id(const torch::Tensor& A, const torch::Tensor& R, return {std::move(cols), Pmat, std::move(sv)}; } -ColID interp_decomp_cols(const torch::Tensor& A, int64_t k) { +ColID interp_decomp_cols(const torch::Tensor& A, double tol, int64_t k) { TORCH_CHECK(A.dim() == 2, "A must be 2D"); - TORCH_CHECK(k > 0 && k <= std::min(A.size(0), A.size(1)), "k out of range"); + const int64_t max_k = std::min(A.size(0), A.size(1)); auto [R, P] = rrQR(A); - return build_col_id(A, R, P, k); + int64_t rank = max_k; + if (tol > 0.0) { + TORCH_CHECK(tol < 1.0, "tol must be in (0,1)"); + rank = std::min(rank, rank_from_diag(R, tol)); + } + if (k != INT64_MAX) { + TORCH_CHECK(k > 0, "k must be positive"); + rank = std::min(rank, k); + } + return build_col_id(A, R, P, rank); } -ColID interp_decomp_cols(const torch::Tensor& A, double tol) { - TORCH_CHECK(A.dim() == 2, "A must be 2D"); - TORCH_CHECK(tol > 0.0 && tol < 1.0, "tol must be in (0,1)"); - auto [R, P] = rrQR(A); - return build_col_id(A, R, P, rank_from_diag(R, tol)); -} - -RowID interp_decomp_rows(const torch::Tensor& A, int64_t k) { - auto col = interp_decomp_cols(A.t().contiguous(), k); - return {std::move(col.cols), col.P.t().contiguous(), std::move(col.sv)}; -} - -RowID interp_decomp_rows(const torch::Tensor& A, double tol) { - auto col = interp_decomp_cols(A.t().contiguous(), tol); +RowID interp_decomp_rows(const torch::Tensor& A, double tol, int64_t k) { + auto col = interp_decomp_cols(A.t().contiguous(), tol, k); return {std::move(col.cols), col.P.t().contiguous(), std::move(col.sv)}; } diff --git a/src/main.cpp b/src/main.cpp index 691269c..ba2633d 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -11,7 +11,7 @@ int main() { auto V = torch::randn({n, r}, torch::kFloat64); auto A = torch::mm(U, V.t()); // rank-r matrix - auto [cols, P, sv] = interp_decomp_cols(A, r); + auto [cols, P, sv] = interp_decomp_cols(A, 0.0, r); auto C = A.index({torch::indexing::Slice(), torch::tensor(cols, torch::kLong)}); auto A_approx = torch::mm(C, P); diff --git a/tests/test_interp_decomp.cpp b/tests/test_interp_decomp.cpp index beb24f2..05cedaa 100644 --- a/tests/test_interp_decomp.cpp +++ b/tests/test_interp_decomp.cpp @@ -19,7 +19,7 @@ TEST_CASE("interp_decomp_cols exact rank-k", "[id]") { SECTION("float64") { auto A = torch::mm(torch::randn({m, r}, torch::kFloat64), torch::randn({n, r}, torch::kFloat64).t()); - auto res = interp_decomp_cols(A, r); + auto res = interp_decomp_cols(A, 0.0, r); REQUIRE((int64_t)res.cols.size() == r); REQUIRE(res.P.sizes() == torch::IntArrayRef({r, n})); REQUIRE(rel_err(A, reconstruct(A, res)) < 1e-10); @@ -28,20 +28,20 @@ TEST_CASE("interp_decomp_cols exact rank-k", "[id]") { SECTION("float32") { auto A = torch::mm(torch::randn({m, r}, torch::kFloat32), torch::randn({n, r}, torch::kFloat32).t()); - REQUIRE(rel_err(A, reconstruct(A, interp_decomp_cols(A, r))) < 1e-4f); + REQUIRE(rel_err(A, reconstruct(A, interp_decomp_cols(A, 0.0, r))) < 1e-4f); } SECTION("complex128") { auto A = torch::mm(torch::randn({m, r}, torch::kComplexDouble), torch::randn({n, r}, torch::kComplexDouble).t()); - REQUIRE(rel_err(A, reconstruct(A, interp_decomp_cols(A, r))) < 1e-10); + REQUIRE(rel_err(A, reconstruct(A, interp_decomp_cols(A, 0.0, r))) < 1e-10); } } TEST_CASE("interp_decomp_cols selected columns are actual columns of A", "[id]") { torch::manual_seed(1); auto A = torch::randn({15, 12}, torch::kFloat64); - auto res = interp_decomp_cols(A, INT64_C(5)); + auto res = interp_decomp_cols(A, 0.0, INT64_C(5)); for (int64_t i = 0; i < 5; ++i) { int64_t c = res.cols[i]; @@ -54,7 +54,7 @@ TEST_CASE("interp_decomp_cols selected columns are actual columns of A", "[id]") TEST_CASE("interp_decomp_cols distinct indices", "[id]") { torch::manual_seed(2); - auto res = interp_decomp_cols(torch::randn({25, 18}, torch::kFloat64), INT64_C(6)); + auto res = interp_decomp_cols(torch::randn({25, 18}, torch::kFloat64), 0.0, INT64_C(6)); auto s = res.cols; std::sort(s.begin(), s.end()); REQUIRE(std::adjacent_find(s.begin(), s.end()) == s.end()); } @@ -64,7 +64,7 @@ TEST_CASE("interp_decomp_cols quality improves with rank", "[id]") { auto A = torch::randn({40, 30}, torch::kFloat64); double prev = 1.0; for (int64_t k : {3, 6, 10, 15}) { - double e = rel_err(A, reconstruct(A, interp_decomp_cols(A, k))); + double e = rel_err(A, reconstruct(A, interp_decomp_cols(A, 0.0, k))); REQUIRE(e < prev); prev = e; } } @@ -72,14 +72,14 @@ TEST_CASE("interp_decomp_cols quality improves with rank", "[id]") { TEST_CASE("interp_decomp_cols full rank → near-zero error", "[id]") { torch::manual_seed(4); auto A = torch::randn({10, 8}, torch::kFloat64); - REQUIRE(rel_err(A, reconstruct(A, interp_decomp_cols(A, INT64_C(8)))) < 1e-10); + REQUIRE(rel_err(A, reconstruct(A, interp_decomp_cols(A, 0.0, INT64_C(8)))) < 1e-10); } TEST_CASE("interp_decomp_cols input validation", "[id]") { auto good = torch::randn({5, 4}); - REQUIRE_THROWS(interp_decomp_cols(torch::randn({3, 3, 3}), INT64_C(2))); - REQUIRE_THROWS(interp_decomp_cols(good, INT64_C(0))); - REQUIRE_THROWS(interp_decomp_cols(good, INT64_C(5))); + REQUIRE_THROWS(interp_decomp_cols(torch::randn({3, 3, 3}), 0.0, INT64_C(2))); + REQUIRE_THROWS(interp_decomp_cols(good, 0.0, INT64_C(0))); + REQUIRE_NOTHROW(interp_decomp_cols(good, 0.0, INT64_C(5))); // k > max_k silently caps } TEST_CASE("interp_decomp_cols singular values", "[id]") { @@ -93,7 +93,7 @@ TEST_CASE("interp_decomp_cols singular values", "[id]") { A.slice(1, k, n) = torch::randn({m, n - k}, torch::kFloat64) * 0.1; SECTION("fixed rank") { - auto res = interp_decomp_cols(A, k); + auto res = interp_decomp_cols(A, 0.0, k); auto sv = res.sv; std::sort(sv.rbegin(), sv.rend()); for (int64_t i = 0; i < k; ++i) REQUIRE(sv[i] == Catch::Approx(sv_exp[i]).epsilon(0.1).margin(0.5));