Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions include/sampler_id/interp_decomp.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include <torch/torch.h>
#include <cstdint>
#include <vector>

namespace ttid {
Expand All @@ -16,9 +17,7 @@ struct RowID {
std::vector<double> sv;
};

ColID interp_decomp_cols(const torch::Tensor& A, int64_t k);
ColID interp_decomp_cols(const torch::Tensor& A, double tol);
RowID interp_decomp_rows(const torch::Tensor& A, int64_t k);
RowID interp_decomp_rows(const torch::Tensor& A, double tol);
ColID interp_decomp_cols(const torch::Tensor& A, double tol = 0.0, int64_t k = INT64_MAX);
RowID interp_decomp_rows(const torch::Tensor& A, double tol = 0.0, int64_t k = INT64_MAX);

} // namespace ttid
2 changes: 1 addition & 1 deletion include/sampler_id/tt.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class TensorTrain {
auto M = C.reshape({r * dk, rest});
auto [rows, Q, sv] = (tol > 0.0)
? interp_decomp_rows(M, tol)
: interp_decomp_rows(M, std::min(r * dk, rest));
: interp_decomp_rows(M, 0.0, std::min(r * dk, rest));
int64_t r_new = (int64_t)rows.size();
cores.push_back(Q.reshape({r, dk, r_new}));
C = M.index({torch::tensor(rows, torch::kLong), torch::indexing::Slice()});
Expand Down
12 changes: 6 additions & 6 deletions include/sampler_id/tt_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,18 @@ struct tt_id {
auto M = tensor_fun_to_mat(Ib, Jb);

if (b < center) {
auto res = interp_decomp_rows(M, param.reltol);
auto res = interp_decomp_rows(M, param.reltol, param.bondDim);
int64_t k = (int64_t)res.rows.size();
auto res_col = interp_decomp_cols(M, k);
auto res_col = interp_decomp_cols(M, 0.0, k);
Iset[b + 1] = Ib.at(res.rows);
Jset[b] = Jb.at(res_col.cols);
tt.cores[b] = res.P;
tt.cores[b + 1] = M.index({torch::tensor(res.rows, torch::kLong), torch::indexing::Slice()});
collectPivotError(res.sv);
} else {
auto res = interp_decomp_cols(M, param.reltol);
auto res = interp_decomp_cols(M, param.reltol, param.bondDim);
int64_t k = (int64_t)res.cols.size();
auto res_row = interp_decomp_rows(M, k);
auto res_row = interp_decomp_rows(M, 0.0, k);
Jset[b] = Jb.at(res.cols);
Iset[b + 1] = Ib.at(res_row.rows);
tt.cores[b] = M.index({torch::indexing::Slice(), torch::tensor(res.cols, torch::kLong)});
Expand All @@ -133,14 +133,14 @@ struct tt_id {
auto M = tensor_fun_to_mat(Ib, Jb);

if (isLeft) {
auto res = interp_decomp_rows(M, param.reltol);
auto res = interp_decomp_rows(M, param.reltol, param.bondDim);
Iset[b + 1] = Ib.at(res.rows);
Jset[b] = Iset[b + 1];
tt.cores[b] = res.P;
tt.cores[b + 1] = M.index({torch::tensor(res.rows, torch::kLong), torch::indexing::Slice()});
collectPivotError(res.sv);
} else {
auto res = interp_decomp_cols(M, param.reltol);
auto res = interp_decomp_cols(M, param.reltol, param.bondDim);
Jset[b] = Jb.at(res.cols);
tt.cores[b] = M.index({torch::indexing::Slice(), torch::tensor(res.cols, torch::kLong)});
tt.cores[b + 1] = res.P;
Expand Down
31 changes: 14 additions & 17 deletions src/interp_decomp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,24 @@ static ColID build_col_id(const torch::Tensor& A, const torch::Tensor& R,
return {std::move(cols), Pmat, std::move(sv)};
}

ColID interp_decomp_cols(const torch::Tensor& A, int64_t k) {
ColID interp_decomp_cols(const torch::Tensor& A, double tol, int64_t k) {
TORCH_CHECK(A.dim() == 2, "A must be 2D");
TORCH_CHECK(k > 0 && k <= std::min(A.size(0), A.size(1)), "k out of range");
const int64_t max_k = std::min(A.size(0), A.size(1));
auto [R, P] = rrQR(A);
return build_col_id(A, R, P, k);
int64_t rank = max_k;
if (tol > 0.0) {
TORCH_CHECK(tol < 1.0, "tol must be in (0,1)");
rank = std::min(rank, rank_from_diag(R, tol));
}
if (k != INT64_MAX) {
TORCH_CHECK(k > 0, "k must be positive");
rank = std::min(rank, k);
}
return build_col_id(A, R, P, rank);
}

ColID interp_decomp_cols(const torch::Tensor& A, double tol) {
TORCH_CHECK(A.dim() == 2, "A must be 2D");
TORCH_CHECK(tol > 0.0 && tol < 1.0, "tol must be in (0,1)");
auto [R, P] = rrQR(A);
return build_col_id(A, R, P, rank_from_diag(R, tol));
}

RowID interp_decomp_rows(const torch::Tensor& A, int64_t k) {
auto col = interp_decomp_cols(A.t().contiguous(), k);
return {std::move(col.cols), col.P.t().contiguous(), std::move(col.sv)};
}

RowID interp_decomp_rows(const torch::Tensor& A, double tol) {
auto col = interp_decomp_cols(A.t().contiguous(), tol);
RowID interp_decomp_rows(const torch::Tensor& A, double tol, int64_t k) {
auto col = interp_decomp_cols(A.t().contiguous(), tol, k);
return {std::move(col.cols), col.P.t().contiguous(), std::move(col.sv)};
}

Expand Down
2 changes: 1 addition & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ int main() {
auto V = torch::randn({n, r}, torch::kFloat64);
auto A = torch::mm(U, V.t()); // rank-r matrix

auto [cols, P, sv] = interp_decomp_cols(A, r);
auto [cols, P, sv] = interp_decomp_cols(A, 0.0, r);

auto C = A.index({torch::indexing::Slice(), torch::tensor(cols, torch::kLong)});
auto A_approx = torch::mm(C, P);
Expand Down
22 changes: 11 additions & 11 deletions tests/test_interp_decomp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ TEST_CASE("interp_decomp_cols exact rank-k", "[id]") {
SECTION("float64") {
auto A = torch::mm(torch::randn({m, r}, torch::kFloat64),
torch::randn({n, r}, torch::kFloat64).t());
auto res = interp_decomp_cols(A, r);
auto res = interp_decomp_cols(A, 0.0, r);
REQUIRE((int64_t)res.cols.size() == r);
REQUIRE(res.P.sizes() == torch::IntArrayRef({r, n}));
REQUIRE(rel_err(A, reconstruct(A, res)) < 1e-10);
Expand All @@ -28,20 +28,20 @@ TEST_CASE("interp_decomp_cols exact rank-k", "[id]") {
SECTION("float32") {
auto A = torch::mm(torch::randn({m, r}, torch::kFloat32),
torch::randn({n, r}, torch::kFloat32).t());
REQUIRE(rel_err(A, reconstruct(A, interp_decomp_cols(A, r))) < 1e-4f);
REQUIRE(rel_err(A, reconstruct(A, interp_decomp_cols(A, 0.0, r))) < 1e-4f);
}

SECTION("complex128") {
auto A = torch::mm(torch::randn({m, r}, torch::kComplexDouble),
torch::randn({n, r}, torch::kComplexDouble).t());
REQUIRE(rel_err(A, reconstruct(A, interp_decomp_cols(A, r))) < 1e-10);
REQUIRE(rel_err(A, reconstruct(A, interp_decomp_cols(A, 0.0, r))) < 1e-10);
}
}

TEST_CASE("interp_decomp_cols selected columns are actual columns of A", "[id]") {
torch::manual_seed(1);
auto A = torch::randn({15, 12}, torch::kFloat64);
auto res = interp_decomp_cols(A, INT64_C(5));
auto res = interp_decomp_cols(A, 0.0, INT64_C(5));

for (int64_t i = 0; i < 5; ++i) {
int64_t c = res.cols[i];
Expand All @@ -54,7 +54,7 @@ TEST_CASE("interp_decomp_cols selected columns are actual columns of A", "[id]")

TEST_CASE("interp_decomp_cols distinct indices", "[id]") {
torch::manual_seed(2);
auto res = interp_decomp_cols(torch::randn({25, 18}, torch::kFloat64), INT64_C(6));
auto res = interp_decomp_cols(torch::randn({25, 18}, torch::kFloat64), 0.0, INT64_C(6));
auto s = res.cols; std::sort(s.begin(), s.end());
REQUIRE(std::adjacent_find(s.begin(), s.end()) == s.end());
}
Expand All @@ -64,22 +64,22 @@ TEST_CASE("interp_decomp_cols quality improves with rank", "[id]") {
auto A = torch::randn({40, 30}, torch::kFloat64);
double prev = 1.0;
for (int64_t k : {3, 6, 10, 15}) {
double e = rel_err(A, reconstruct(A, interp_decomp_cols(A, k)));
double e = rel_err(A, reconstruct(A, interp_decomp_cols(A, 0.0, k)));
REQUIRE(e < prev); prev = e;
}
}

TEST_CASE("interp_decomp_cols full rank → near-zero error", "[id]") {
torch::manual_seed(4);
auto A = torch::randn({10, 8}, torch::kFloat64);
REQUIRE(rel_err(A, reconstruct(A, interp_decomp_cols(A, INT64_C(8)))) < 1e-10);
REQUIRE(rel_err(A, reconstruct(A, interp_decomp_cols(A, 0.0, INT64_C(8)))) < 1e-10);
}

TEST_CASE("interp_decomp_cols input validation", "[id]") {
auto good = torch::randn({5, 4});
REQUIRE_THROWS(interp_decomp_cols(torch::randn({3, 3, 3}), INT64_C(2)));
REQUIRE_THROWS(interp_decomp_cols(good, INT64_C(0)));
REQUIRE_THROWS(interp_decomp_cols(good, INT64_C(5)));
REQUIRE_THROWS(interp_decomp_cols(torch::randn({3, 3, 3}), 0.0, INT64_C(2)));
REQUIRE_THROWS(interp_decomp_cols(good, 0.0, INT64_C(0)));
REQUIRE_NOTHROW(interp_decomp_cols(good, 0.0, INT64_C(5))); // k > max_k silently caps
}

TEST_CASE("interp_decomp_cols singular values", "[id]") {
Expand All @@ -93,7 +93,7 @@ TEST_CASE("interp_decomp_cols singular values", "[id]") {
A.slice(1, k, n) = torch::randn({m, n - k}, torch::kFloat64) * 0.1;

SECTION("fixed rank") {
auto res = interp_decomp_cols(A, k);
auto res = interp_decomp_cols(A, 0.0, k);
auto sv = res.sv; std::sort(sv.rbegin(), sv.rend());
for (int64_t i = 0; i < k; ++i)
REQUIRE(sv[i] == Catch::Approx(sv_exp[i]).epsilon(0.1).margin(0.5));
Expand Down
Loading