add faster implementations of Mean and MCKP estimators#458

Open
caelen00000 wants to merge 19 commits into
broadinstitute:mainfrom
caelen00000:fast-mckp

@caelen00000 caelen00000 commented Jun 26, 2026

Contributor

Note: I accidentally hit the merge button when looking at the diff, so this whole thing is somewhat half-cooked and should definitely not be merged. However, I would like to gauge interest in the changes I've made, so I'm leaving it up.

Summary

This PR is a proof-of-concept rewrite of the Mean and MultipleChoiceKnapsack EstimationMethods. On my system (using the --cuda and --estimator-multiple-cpu arguments), the time taken to estimate noise targets and compute denoised counts is greatly reduced. On the heart10k dataset, this implementation takes ~8s, compared with ~60s on the main branch and ~40s on the sf-modernization-estimation branch.

I will note that most of the speedup vs the sf-modernization-estimation branch comes from the improved mean estimator, whereas MCKP is only slightly faster here.

Mean Estimator

I use the inverse indices returned by torch.unique(noise_log_prob_coo.row, return_inverse=True) to group the entries of noise_log_prob_coo according to their m-index value. Then torch.bincount, called with weights, computes the sum of probability * c within each group.

This can be applied to the entire posterior, without chunking or densifying, and works seamlessly with either CPU (runtime 1.5s) or GPU (runtime 0.68s) tensors. Compare with the original at ~30s.
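The grouping idea can be illustrated with a minimal sketch. This is not the code from the PR: the function name `mean_noise_counts` and the toy inputs are hypothetical, and it assumes the posterior is given as three 1D tensors (`row` holding the m-index, `col` holding the noise count c, and `log_prob` holding the log probability).

```python
import torch


def mean_noise_counts(row: torch.Tensor, col: torch.Tensor, log_prob: torch.Tensor):
    """Posterior mean noise count per unique m-index (hypothetical helper)."""
    # inverse[i] is the group id (0..n_groups-1) of COO entry i
    unique_m, inverse = torch.unique(row, return_inverse=True)
    # Each entry contributes probability * c to its group's mean
    weights = log_prob.exp() * col.to(log_prob.dtype)
    # bincount with weights sums the contributions within each group
    mean = torch.bincount(inverse, weights=weights, minlength=unique_m.numel())
    return unique_m, mean


# Toy posterior: m-index 5 has probabilities over c = 0, 1, 2; m-index 9 over c = 0, 3
row = torch.tensor([5, 5, 5, 9, 9])
col = torch.tensor([0, 1, 2, 0, 3])
log_prob = torch.tensor([0.5, 0.3, 0.2, 0.9, 0.1]).log()

m_index, mean = mean_noise_counts(row, col, log_prob)
# Means are 0.3*1 + 0.2*2 = 0.7 and 0.1*3 = 0.3, up to floating-point rounding
```

Because everything is expressed as whole-tensor operations, the same function runs unchanged on CPU or GPU tensors, as long as all three inputs live on the same device.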

As far as correctness goes, it needs more testing for sure, but tentatively looks OK. All tests pass, and I compared the sum of all estimated counts and found a difference on the order of 10^-6 percent, presumably due to floating-point arithmetic differences.

MCKP Estimator

My goal here was to remove the pandas dataframes. As soon as I saw them, I had some flashbacks to the time I tried to implement XGBoost in dataframes, which, performance-wise, did not go well 🫠. Unfortunately, I hadn't seen the work already done in the modernization branch, which is probably a better solution. But, if you want to avoid adding dependencies, I have an alternative.

Algorithm Overview

Similarly to the mean estimator, torch.unique inverse indices are used to group the posterior COO, but this time according to the gene. Genes with zero probability are excluded, then the remaining genes are iterated over as follows:

  1. A dense 2D tensor is created by subsetting the COO to the current gene and removing cells with zero probability. Great care was taken to minimize the size and the number of these tensors before proceeding to the next step.
  2. The MAP solution is iteratively improved as in the CellBender paper. Note that I have probably missed some corner cases here, so my implementation currently deviates somewhat from the original.
    Here is one of the original output plots: [image]
    And here is mine: [image]
  3. The indices of the solutions are used to construct the final CSR matrix and we're done.
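To make the per-gene step more concrete, here is a simplified greedy sketch. It is neither the PR's code nor the exact procedure from the CellBender paper; it only illustrates the general shape of "start from the per-cell MAP, then nudge choices one unit at a time until the gene's total matches a target". The function name `greedy_adjust`, the dense layout `[n_cells, n_counts]`, and the `target` argument are assumptions made for illustration, and every row is assumed to have at least one finite log probability.

```python
import torch


def greedy_adjust(log_prob: torch.Tensor, target: int) -> torch.Tensor:
    """Pick one noise count per cell so the counts sum to `target` (hypothetical sketch).

    log_prob: dense [n_cells, n_counts] tensor for one gene, -inf where impossible.
    """
    n_cells, n_counts = log_prob.shape
    rows = torch.arange(n_cells, device=log_prob.device)
    choice = log_prob.argmax(dim=1)  # per-cell MAP as the starting point
    step = 1 if int(choice.sum()) < target else -1
    while int(choice.sum()) != target:
        nxt = choice + step
        valid = (nxt >= 0) & (nxt < n_counts)
        safe_nxt = nxt.clamp(0, n_counts - 1)
        # Log-probability lost by moving each cell one unit toward the target
        cost = log_prob[rows, choice] - log_prob[rows, safe_nxt]
        cost = torch.where(valid, cost, torch.full_like(cost, float("inf")))
        best = int(cost.argmin())
        if not torch.isfinite(cost[best]):
            break  # no cell can move any further; target is unreachable
        choice[best] = nxt[best]
    return choice


# Toy gene with 3 cells and noise counts 0..2; the MAP choice is [0, 1, 0]
log_prob = torch.tensor([[0.7, 0.2, 0.1],
                         [0.1, 0.8, 0.1],
                         [0.6, 0.3, 0.1]]).log()
choice_up = greedy_adjust(log_prob, target=2)
choice_down = greedy_adjust(log_prob, target=0)
```

In the toy example, raising the total from 1 to 2 moves the third cell (the cheapest one-unit change), and lowering it to 0 moves the second cell, since it is the only one not already at zero.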

This loop can be run on the GPU, but is slow. As such, I only run the initial grouping step on the GPU.

Some Thoughts on Multiprocessing

If use_multiple_processes=True, this iterative method is easy enough to chunk. The challenge lies in efficient inter-process communication. In fact, the original MultipleChoiceKnapsack.estimate_noise has a comment that mentions multiprocessing causes a slowdown (which on my PC is not true, but it is still slower than it could be). This is because each argument to _mckp_chunk_estimate_noise must be pickled and copied into each child process.

The solution to this is to load the arguments into shared memory. For arbitrary objects, this is tedious but doable, and leads to moderate speedups. Check out my early commits if you want to see an admittedly very ugly example.

Luckily, torch.multiprocessing exists: it's a wrapper around the Python multiprocessing library that automatically handles tensor shared memory management. In a single process, I see MCKP estimation taking ~10s, but with 14 processes, it takes ~3s.

Finding the optimal number of processes and threads per process still needs work, and currently requires individual tuning via some hardcoded variables. By default, torch uses as many threads as there are physical cores, which in my case is 14. If you launch 6 processes, that is 84 threads competing for resources. In my testing, when using multiple processes, it is best to limit each to a single thread.

Determining the total number of processes is also an issue. On Linux, I find that using the number of physical cores is best, but on Windows, using approximately half as many works better. I'm assuming this is due to the increased overhead of spawning processes on Windows compared to forking on Linux.

Closing Notes

Like I said earlier, I didn't intend for the world to see this code in its current state, so please don't judge it too harshly. Hopefully at least some of it will prove useful. Either way, I had fun writing it and learned a lot in the process! Let me know if you want me to clean any of it up and I can submit a new PR.

@caelen00000 caelen00000 changed the title from "Fast mckp" to "add faster implementations of Mean and MCKP estimators" Jun 26, 2026