Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions cuda/culling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,62 @@
#include <thrust/reduce.h>
#include <thrust/transform.h>

// Constants for 64-bit Morton code (21 bits per component)
constexpr uint32_t BITS_PER_COORD = 21;
constexpr uint32_t MAX_COORD_VAL = (1 << BITS_PER_COORD) - 1; // 2^21 - 1 = 2097151

__device__ __forceinline__ uint64_t spread_bits(uint64_t n) {
// Ensure only the lower 21 bits are used.
n &= MAX_COORD_VAL;

// Perform the bit spreading using shifts and masks.
// The sequence is designed to efficiently insert two zero bits
// between every original bit.

// Pattern: 0 Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z
// Final goal: x_20 00 x_19 00 x_18 00 ... 00 x_0

// Mask constants derived from Hacker's Delight (Chapter 7)
// The constants are carefully calculated to group operations.

n = (n | (n << 32)) & 0x1F000000FFFF;
n = (n | (n << 16)) & 0x1F0000FF0000FF;
n = (n | (n << 8)) & 0x100F807C0F807C0F;
n = (n | (n << 4)) & 0x1084210842108421;
n = (n | (n << 2)) & 0x1249249249249249;

return n;
}

__global__ void morton_codes_kernel(const int N, const float *__restrict__ d_xyz, const float x_max, const float y_max,
const float z_max, const float x_min, const float y_min, const float z_min,
uint64_t *__restrict__ codes) {
const int i = blockIdx.x * blockDim.x + threadIdx.x;

if (i >= N)
return;

// load unormalized corrdinates
const float x = d_xyz[i * 3 + 0];
const float y = d_xyz[i * 3 + 1];
const float z = d_xyz[i * 3 + 2];

// normalize and quantize using k=21 bit depth
const uint64_t x_q = (uint64_t)((x - x_min) * (MAX_COORD_VAL / (x_max - x_min)));
const uint64_t y_q = (uint64_t)((y - y_min) * (MAX_COORD_VAL / (y_max - y_min)));
const uint64_t z_q = (uint64_t)((z - z_min) * (MAX_COORD_VAL / (z_max - z_min)));

// spread bits to interleave
const uint64_t x_spread = spread_bits(x_q);
const uint64_t y_spread = spread_bits(y_q);
const uint64_t z_spread = spread_bits(z_q);

// interleave bits
const uint64_t code = (z_spread << 2) | (y_spread << 1) | x_spread;

codes[i] = code;
}

__device__ __forceinline__ bool z_distance_culling(const float z, const float near_thresh) { return z >= near_thresh; }

__device__ __forceinline__ bool frustum_culling(const float u, const float v, const int padding, const int width,
Expand Down Expand Up @@ -286,6 +342,22 @@ __global__ void find_tile_boundaries_kernel(const double *__restrict__ sorted_ke
}
}

void compute_morton_codes(const int N, const float *d_xyz, const float x_max, const float y_max, const float z_max,
const float x_min, const float y_min, const float z_min, uint64_t *codes,
cudaStream_t stream) {
ASSERT_DEVICE_POINTER(d_xyz);
ASSERT_DEVICE_POINTER(codes);

const int threads_per_block = 256;
// Calculate the number of blocks needed to cover all N points
const int num_blocks = (N + threads_per_block - 1) / threads_per_block;

dim3 gridsize(num_blocks, 1, 1);
dim3 blocksize(threads_per_block, 1, 1);

morton_codes_kernel<<<gridsize, blocksize, 0, stream>>>(N, d_xyz, x_max, y_max, z_max, x_min, y_min, z_min, codes);
}

void cull_gaussians(float *const uv, float *const xyz, const int N, const float near_thresh, const int padding,
const int width, const int height, bool *mask, cudaStream_t stream) {
ASSERT_DEVICE_POINTER(uv);
Expand Down
155 changes: 155 additions & 0 deletions cuda/trainer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@
#include <thread>
#include <thrust/count.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/gather.h>
#include <thrust/host_vector.h>
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/transform.h>

/**
Expand Down Expand Up @@ -76,6 +80,7 @@ private:
void optimizer_step(ForwardPassData pass_data);
void add_sh_band();
void adaptive_density_step();
void sort_gaussians();

// --- Async Loading Members ---
std::thread loader_thread;
Expand Down Expand Up @@ -769,6 +774,155 @@ void TrainerImpl::adaptive_density_step() {
}
}

struct BoundingBox {
float min_x, max_x;
float min_y, max_y;
float min_z, max_z;
};

struct TupleToBox {
__host__ __device__ BoundingBox operator()(const thrust::tuple<float, float, float> &t) const {
float x = thrust::get<0>(t);
float y = thrust::get<1>(t);
float z = thrust::get<2>(t);
return BoundingBox{x, x, y, y, z, z};
}
};

struct MergeBoxes {
__host__ __device__ BoundingBox operator()(const BoundingBox &a, const BoundingBox &b) const {
return BoundingBox{min(a.min_x, b.min_x), max(a.max_x, b.max_x), min(a.min_y, b.min_y),
max(a.max_y, b.max_y), min(a.min_z, b.min_z), max(a.max_z, b.max_z)};
}
};

struct StridedAddressMap {
const int *sorted_indices;
int stride;

__host__ __device__ int operator()(int i) const {
int logical_block = i / stride;
int offset = i % stride;
int target_block = sorted_indices[logical_block];

return (target_block * stride) + offset;
}
};

template <int STRIDE> struct StridedAddressMapTemplate {
const int *sorted_indices;

__host__ __device__ int operator()(int i) const {
int logical_block = i / STRIDE;
int offset = i % STRIDE;
int target_block = sorted_indices[logical_block];

return (target_block * STRIDE) + offset;
}
};

// Dynamic stride version (legacy name for compatibility with dynamic calls)
template <typename T>
void gather_with_stride(thrust::device_vector<T> &data, thrust::device_vector<int> &indices, int stride, int size) {
if (data.size() == 0 || indices.size() == 0)
return;

thrust::device_vector<T> output(data.size());
const int *raw_indices = thrust::raw_pointer_cast(indices.data());

auto map_iter =
thrust::make_transform_iterator(thrust::make_counting_iterator(0), StridedAddressMap{raw_indices, stride});

thrust::gather(thrust::device, map_iter, map_iter + size * stride, data.begin(), output.begin());
data.swap(output);
}

// Templated stride version with optimizations
template <int STRIDE, typename T>
void gather_with_stride(thrust::device_vector<T> &data, thrust::device_vector<int> &indices, int size) {
if (data.size() == 0 || indices.size() == 0)
return;

if constexpr (STRIDE == 1) {
// Direct gather
thrust::device_vector<T> output(data.size());
thrust::gather(thrust::device, indices.begin(), indices.end(), data.begin(), output.begin());
data.swap(output);
} else if constexpr (std::is_same_v<T, float> && STRIDE == 4) {
// float4 optimization (quaternions)
thrust::device_vector<float> output(data.size());
const float4 *raw_data = reinterpret_cast<const float4 *>(thrust::raw_pointer_cast(data.data()));
float4 *raw_output = reinterpret_cast<float4 *>(thrust::raw_pointer_cast(output.data()));
thrust::gather(thrust::device, indices.begin(), indices.end(), raw_data, raw_output);
data.swap(output);
} else {
// Generic templated stride (e.g. 3 for XYZ/RGB/Scale) - benefits from const stride mod/div
thrust::device_vector<T> output(data.size());
const int *raw_indices = thrust::raw_pointer_cast(indices.data());

auto map_iter = thrust::make_transform_iterator(thrust::make_counting_iterator(0),
StridedAddressMapTemplate<STRIDE>{raw_indices});

thrust::gather(thrust::device, map_iter, map_iter + size * STRIDE, data.begin(), output.begin());
data.swap(output);
}
}

void TrainerImpl::sort_gaussians() {
thrust::device_vector<uint64_t> morton_codes(num_gaussians);

auto x_it = thrust::make_strided_iterator(cuda.gaussians.d_xyz.begin(), 3);
auto y_it = thrust::make_strided_iterator(cuda.gaussians.d_xyz.begin() + 1, 3);
auto z_it = thrust::make_strided_iterator(cuda.gaussians.d_xyz.begin() + 2, 3);

auto points_iter_start = thrust::make_zip_iterator(thrust::make_tuple(x_it, y_it, z_it));
auto points_iter_end = points_iter_start + num_gaussians;

const float inf = std::numeric_limits<float>::infinity();
BoundingBox init = {inf, -inf, inf, -inf, inf, -inf};

BoundingBox bb = thrust::transform_reduce(points_iter_start, points_iter_end, TupleToBox(), init, MergeBoxes());

compute_morton_codes(num_gaussians, thrust::raw_pointer_cast(cuda.gaussians.d_xyz.data()), bb.max_x, bb.max_y,
bb.max_z, bb.min_x, bb.min_y, bb.min_z, thrust::raw_pointer_cast(morton_codes.data()));

thrust::device_vector<int> sort_ids(num_gaussians);

thrust::sequence(sort_ids.begin(), sort_ids.end());

thrust::sort_by_key(morton_codes.begin(), morton_codes.end(), sort_ids.begin());

const int num_sh_coeffs = (l_max + 1) * (l_max + 1) - 1;

// Use templated versions for fixed-size components
gather_with_stride<3>(cuda.gaussians.d_xyz, sort_ids, num_gaussians);
gather_with_stride<3>(cuda.gaussians.d_rgb, sort_ids, num_gaussians);
gather_with_stride<1>(cuda.gaussians.d_opacity, sort_ids, num_gaussians);
gather_with_stride<3>(cuda.gaussians.d_scale, sort_ids, num_gaussians);
gather_with_stride<4>(cuda.gaussians.d_quaternion, sort_ids, num_gaussians);

// Keep dynamic version for SH
gather_with_stride(cuda.gaussians.d_sh, sort_ids, num_sh_coeffs, num_gaussians);

gather_with_stride<3>(cuda.optimizer.m_grad_xyz, sort_ids, num_gaussians);
gather_with_stride<3>(cuda.optimizer.v_grad_xyz, sort_ids, num_gaussians);

gather_with_stride<3>(cuda.optimizer.m_grad_rgb, sort_ids, num_gaussians);
gather_with_stride<3>(cuda.optimizer.v_grad_rgb, sort_ids, num_gaussians);

gather_with_stride<1>(cuda.optimizer.m_grad_opacity, sort_ids, num_gaussians);
gather_with_stride<1>(cuda.optimizer.v_grad_opacity, sort_ids, num_gaussians);

gather_with_stride<3>(cuda.optimizer.m_grad_scale, sort_ids, num_gaussians);
gather_with_stride<3>(cuda.optimizer.v_grad_scale, sort_ids, num_gaussians);

gather_with_stride<4>(cuda.optimizer.m_grad_quaternion, sort_ids, num_gaussians);
gather_with_stride<4>(cuda.optimizer.v_grad_quaternion, sort_ids, num_gaussians);

gather_with_stride(cuda.optimizer.m_grad_sh, sort_ids, num_sh_coeffs, num_gaussians);
gather_with_stride(cuda.optimizer.v_grad_sh, sort_ids, num_sh_coeffs, num_gaussians);
}

float TrainerImpl::backward_pass(const Image &curr_image, const Camera &curr_camera, ForwardPassData &pass_data,
const float bg_color, const thrust::device_vector<float> &d_gt_image) {
const int width = (int)curr_camera.width;
Expand Down Expand Up @@ -1239,6 +1393,7 @@ void TrainerImpl::train() {
if (iter > config.adaptive_control_start && iter % config.adaptive_control_interval == 0 &&
iter < config.adaptive_control_end) {
adaptive_density_step();
sort_gaussians();
reset_grad_accum();
}

Expand Down
17 changes: 17 additions & 0 deletions include/gsplat_cuda/cuda_forward.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,20 @@ float compute_psnr(const float *predicted_data, const float *gt_data, int rows,
void accumulate_gradients(const int N, const float u_scale, const float v_scale, const bool *d_mask,
const float *d_grad_xyz, const float *d_grad_uv, float *d_xyz_grad_accum,
float *d_uv_grad_acuum, int *d_grad_accum_dur, cudaStream_t stream = 0);

/**
* @brief Launch CUDA kernel to compute morton codes of xyz corrdinates
* @param[in] N Total number of coordinates
* @param[in] d_xyz Device pointer to array of xyz values
* @param[in] x_max Max x value
* @param[in] y_max Max y value
* @param[in] z_max Max z value
* @param[in] x_min Min x value
* @param[in] y_min Min y value
* @param[in] z_min Min z value
* @param[out] codes Morton codes
* @param[in] stream The CUDA stream to execute on
*/
void compute_morton_codes(const int N, const float *d_xyz, const float x_max, const float y_max, const float z_max,
const float x_min, const float y_min, const float z_min, uint64_t *codes,
cudaStream_t stream = 0);
Loading