AlexTuring010/neural-lsh

neural-lsh

Implementation of Neural LSH (Dong, Indyk, Razenshteyn, Wagner — Learning Space Partitions for Nearest Neighbor Search, ICLR 2020 — arXiv:1901.08544) for SIFT and MNIST.

Idea: instead of hashing points into buckets with a random LSH function, learn the bucket assignment with a neural network. Partition the dataset's k-NN graph into balanced parts, train an MLP to predict each point's part, and use the trained model as the bucketing function at query time.

Second homework of the Software Development for Algorithmic Problems course at the University of Athens (Department of Informatics & Telecommunications). Graded ~10/10.

Built with @soc9999 (Sokratis Papargyris) — partner across all three homeworks.

Headline result — Neural LSH vs. our HW1 IVF-Flat (on MNIST)

Same dataset (MNIST, 60k × 784), same brute-force ground truth, our two implementations side by side:

| Regime | HW1 IVF-Flat best | HW2 Neural LSH best | Takeaway |
|---|---|---|---|
| High-accuracy (recall ≥ 0.99) | 0.99 @ ~3× speedup | 0.99 @ 10–12× speedup | ~4× faster |
| High-speed (~30× speedup) | recall drops to ~0.65–0.70 | 0.96–0.97 recall at 38–40× speedup | recall doesn't crater |
| Pushed further | n/a (IVF-Flat gets worse with more clusters) | up to ~200× speedup at lower recall with m=512 | scales further |
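The table compares methods on recall and speedup. The sketch below assumes the conventional definitions (recall@N is the fraction of the true N nearest neighbours that the approximate search also returns, averaged over queries; speedup is brute-force time divided by approximate-search time). These are not necessarily the exact formulas in nlsh_search.py, and the function names are hypothetical.

```python
import numpy as np

def recall_at_n(approx_ids: np.ndarray, true_ids: np.ndarray) -> float:
    """Mean fraction of the true N nearest neighbours recovered per query.

    approx_ids, true_ids: (num_queries, N) integer arrays of point ids.
    """
    hits = [
        len(set(a.tolist()) & set(t.tolist())) / len(t)
        for a, t in zip(approx_ids, true_ids)
    ]
    return float(np.mean(hits))

def speedup(t_brute_seconds: float, t_approx_seconds: float) -> float:
    """Wall-clock speedup of the approximate search over brute force."""
    return t_brute_seconds / t_approx_seconds
```

For example, if query 1 finds 2 of its 3 true neighbours and query 2 finds all 3, recall@3 is (2/3 + 1) / 2 ≈ 0.83.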

The intuition: every approximate nearest-neighbour method buckets points and searches a few buckets at query time. LSH uses random hashes, IVF-Flat uses k-means. Neural LSH just learns a much better bucketing function by treating "which partition does this point belong to" as a supervised classification problem.

Full SIFT-1M experiments were aspirational: we got the pipeline working on siftsmall to confirm correctness, but kept losing the long SIFT-1M runs to Colab's GPU disconnects ("disconnected for inactivity" after 4 hours of training is a special kind of pain). All headline numbers in this README are MNIST.

Detailed sweeps in EXPERIMENTS.md.

How it works — pipeline

              ┌─────────────────────────────────────────────────────────────┐
              │                          BUILD                              │
              ├─────────────────────────────────────────────────────────────┤
              │                                                             │
input.dat ──▶ │  C++  allnn  (reused HW1 code, IVF-Flat or brute build)     │
              │      │                                                      │
              │      ▼                                                      │
              │  k-NN graph (binary)                                        │
              │      │                                                      │
              │      ▼                                                      │
              │  KaHIP graph partitioning  ──▶ m balanced partitions        │
              │      │                          (each point ↔ partition id) │
              │      ▼                                                      │
              │  PyTorch MLP training:                                      │
              │   input  = vector                                           │
              │   target = partition id                                     │
              │      │                                                      │
              │      ▼                                                      │
              │  index.bin = trained model + partition info                 │
              └─────────────────────────────────────────────────────────────┘

              ┌─────────────────────────────────────────────────────────────┐
              │                          SEARCH                             │
              ├─────────────────────────────────────────────────────────────┤
              │  query  ──▶  MLP  ──▶  top-T predicted partition ids        │
              │                              │                              │
              │                              ▼                              │
              │              brute-force NN within those partitions         │
              └─────────────────────────────────────────────────────────────┘
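The SEARCH box can be sketched in a few lines of numpy. This is a minimal illustration, not the code in python/search.py: it assumes the MLP's per-partition scores for the query have already been computed, and that partition membership is stored as one partition id per data point. The function and argument names are hypothetical.

```python
import numpy as np

def nlsh_search(query, data, part_of, logits, T=5, N=10):
    """Probe the T partitions the model ranks highest, then brute-force
    the points inside them.

    query:   (d,) float array
    data:    (n, d) float array, the indexed dataset
    part_of: (n,) int array, partition id of every data point
    logits:  (m,) float array, model scores for this query, one per partition
    """
    # Top-T partition ids by model score.
    probe = np.argsort(-logits)[:T]
    # Candidate set = union of the probed partitions.
    cand = np.flatnonzero(np.isin(part_of, probe))
    # Exact squared Euclidean distances on the candidates only.
    d2 = np.sum((data[cand] - query) ** 2, axis=1)
    order = np.argsort(d2)[:N]
    return cand[order], np.sqrt(d2[order])
```

The only learned component is the choice of `probe`; everything after it is the same exhaustive search as brute force, restricted to far fewer points.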

What's interesting about the implementation

  • Hybrid C++/Python pipeline. The expensive bits (k-NN graph build, brute-force search at query time) stay in C++ (reused from HW1). The model and orchestration are Python. The two halves communicate via a custom binary k-NN format that the Python side reads with numpy.fromfile.
  • Bayesian HPO with Optuna (beyond the brief). --mode tune runs k-fold cross-validation, samples MLP architectures via TPE, and saves the top-k models for downstream comparison. This was where most of the real findings came from.
  • Multi-model evaluation in one pass. nlsh_search.py --models_index <dir>/index.json evaluates all tuned models against the same query set in a single run, producing per-model output files and a models_compare.csv summary. Without this, sweeping the top-5 models would mean five separate search runs.
  • Caching at three levels. (1) Brute-force results (HW1's cache, retained); (2) k-NN graph results — if you've already built a k=10 graph, you can re-cut to k=5 without re-running the C++ build; (3) KaHIP partitions — if data and partition params are unchanged, skip re-partitioning. Without this, the experiment loop would have been multi-day instead of multi-hour.
  • Probabilistic bucket selection at query time. The standard T parameter picks the top-T partitions. We added T < 1 to mean "take all partitions where the model's probability exceeds T" — a soft variant we wanted to compare. (Spoiler: no clear win, but it was a fun afternoon.)
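The two meanings of T in the last bullet can be sketched as one selection function. This is a hedged illustration, not the repo's implementation: the names are hypothetical, and the fallback to the single most probable partition when nothing clears the threshold is our own assumption (so a query never searches zero buckets), not something described above.

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    z = logits - logits.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def select_partitions(probs: np.ndarray, T: float) -> np.ndarray:
    """Pick which partitions to probe for one query.

    probs: (m,) softmax probabilities, one per partition.
    T >= 1: take the T most probable partitions (standard top-T probing).
    T < 1:  take every partition whose probability exceeds T.
    Assumption: if no partition exceeds T, fall back to the most probable one.
    """
    if T >= 1:
        return np.argsort(-probs)[: int(T)]
    chosen = np.flatnonzero(probs > T)
    if chosen.size == 0:
        chosen = np.array([int(np.argmax(probs))])
    return chosen
```

With the threshold variant, confident queries probe few partitions and uncertain ones probe more, so the number of candidates varies per query instead of being fixed by T.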

Architectural findings from the 320-model sweep

We ran a Bayesian HPO sweep across MLP architectures, dropout, learning rate, schedulers, and the two graph-construction methods. The dominant survivors:

  • 2 layers beat 3+ in the top-10 across nearly every setting.
  • 512 nodes per hidden layer is the consistent winner.
  • Dropout = 0.0 wins. Regularization isn't earning its keep on this problem.
  • brute and ivfflat graph construction give equivalent final recall (~0.98). Building the graph faster doesn't cost final accuracy — useful for scaling up.
  • Partition parameters dominate model parameters. The MLP architecture moves recall by a few points; the partition m moves recall by tens of points. Bottleneck = partition quality. The NN learns whatever partition you give it, well or badly.

That last finding is what motivated the "two-stage tuning" experiment — find good MLP hyperparameters once, then sweep the partition space cheaply.
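One cheap way to see why partition quality bounds recall: count how many k-NN edges a partition keeps inside a single part. If the model predicted a point's own partition perfectly and only one partition were probed (T = 1), neighbours in other parts could never be found. The diagnostic below is our own illustration, not a metric computed by the repo, and the name is hypothetical.

```python
import numpy as np

def knn_locality(knn_ids: np.ndarray, part_of: np.ndarray) -> float:
    """Fraction of (point, neighbour) pairs in the k-NN graph whose two
    endpoints land in the same partition.

    knn_ids: (n, k) int array, row i = the k nearest neighbours of point i.
    part_of: (n,) int array, partition id of every point.
    Higher is better; it is a rough proxy for the recall a perfect model
    could reach at T = 1 on queries that resemble the data points.
    """
    same = part_of[knn_ids] == part_of[:, None]
    return float(same.mean())
```

A partition with a low value here caps recall no matter how well the MLP is tuned, which matches the observation that m moves recall far more than the architecture does.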

Build & run

Dependencies

pip install -r requirements.txt          # numpy, scikit-learn, optuna
# PyTorch — install per https://pytorch.org/get-started/locally/
# KaHIP — pip install kahip, or conda install -c conda-forge kahip

Compile the C++ helper

make                                      # produces ./allnn

Build an index

python nlsh_build.py \
    -d sift/sift_base.fvecs \
    -i my_index.bin \
    -type sift \
    --knn 10 -m 128 --imbalance 0.03
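The -d and -q paths point at SIFT files in the standard .fvecs format: each vector is stored as a little-endian int32 dimension d followed by d float32 values. The repo's own reader lives in python/dataloader.py; the following is an independent minimal sketch of reading that format with numpy.fromfile, not a copy of that file.

```python
import numpy as np

def read_fvecs(path: str) -> np.ndarray:
    """Read a .fvecs file into an (n, d) float32 array.

    Every record is (d + 1) 4-byte words: the int32 dimension d, then
    d float32 values. All records in a file share the same d.
    """
    raw = np.fromfile(path, dtype="<i4")
    if raw.size == 0:
        return np.zeros((0, 0), dtype=np.float32)
    d = int(raw[0])
    rows = raw.reshape(-1, d + 1)
    # Drop the dimension column; copy to make it contiguous, then
    # reinterpret the remaining bits as float32.
    return rows[:, 1:].copy().view("<f4")
```

Reading the whole file as int32 and reinterpreting the payload avoids a Python loop over millions of records.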

Search

python nlsh_search.py \
    -d sift/sift_base.fvecs \
    -q sift/sift_query.fvecs \
    -i my_index.bin \
    -o results.txt \
    -type sift -N 10 -T 5
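A back-of-envelope view of how -m, --imbalance and -T bound the work per query. This assumes --imbalance is passed through as KaHIP's ε, whose balance constraint (with unit vertex weights) caps every part at (1 + ε)·⌈n/m⌉ points; the helper names are hypothetical and the bound ignores the MLP forward pass. The measured speedups in the headline table are wall-clock timings, not this bound.

```python
import math

def max_partition_size(n: int, m: int, imbalance: float) -> int:
    """Largest part allowed by |part| <= (1 + eps) * ceil(n / m)."""
    return math.floor((1 + imbalance) * math.ceil(n / m))

def candidates_per_query(n: int, m: int, T: int, imbalance: float = 0.0) -> int:
    """Upper bound on the points one query brute-forces when it probes
    T partitions."""
    return T * max_partition_size(n, m, imbalance)

# MNIST-sized example: n = 60000, m = 128, T = 5.
print(max_partition_size(60000, 128, 0.03))       # 483
print(candidates_per_query(60000, 128, 5))        # 2345
print(candidates_per_query(60000, 128, 5, 0.03))  # 2415
```

So with m=128 and T=5 a query scans roughly 2.3k to 2.4k of 60k points, about 25× fewer distance computations than brute force.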

Tune (Bayesian HPO across MLP architectures)

python nlsh_build.py --mode tune \
    -d sift/sift_base.fvecs -i my_tuned -type sift \
    --tune_trials 50 --tune_folds 3 --tune_epochs 10 \
    --top_k 5 --final_epochs 100

Then evaluate the top-5 models in one pass:

python nlsh_search.py \
    -d sift/sift_base.fvecs -q sift/sift_query.fvecs \
    --models_index outputs/tuning/<run-folder>/index.json \
    -type sift -N 10 -T 5

A consolidated models_compare.csv is dumped alongside the per-model output files.

Project layout

nlsh_build.py     # build script: graph → partition → train MLP → save index
nlsh_search.py    # search script: load index, evaluate, optionally compare models
src/              # C++ allnn (cut-down version of HW1's search.cpp)
include/          # C++ headers (IVFFlat, brute, utils, Config)
python/
  graph_utils.py  # binary k-NN reader, KaHIP CSR builder
  dataloader.py   # MNIST / SIFT readers (Python ports of the HW1 C++ parsers)
  model.py        # MLP definition
  search.py       # NLSH search algorithm + brute-force reference
scripts/          # parameter sweep harnesses, two-stage tuning, results regen
outputs/          # tuning runs, best-params caches, comparison CSVs
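python/graph_utils.py holds the binary k-NN reader and the KaHIP CSR builder. The sketch below shows the shape of both steps plus the "re-cut to a smaller k" cache trick. The binary layout here (int32 n, int32 k, then n·k int32 neighbour ids, nearest first) is hypothetical, since the actual format isn't documented in this README. The CSR part follows the convention KaHIP's kaffpa and METIS expect: an undirected graph given as xadj/adjncy arrays, with every edge listed in both endpoints' adjacency lists.

```python
import numpy as np

def read_knn_graph(path: str) -> np.ndarray:
    """HYPOTHETICAL layout: int32 n, int32 k, then n*k int32 ids
    (row i = neighbours of point i, nearest first)."""
    raw = np.fromfile(path, dtype="<i4")
    n, k = int(raw[0]), int(raw[1])
    return raw[2:2 + n * k].reshape(n, k)

def recut(knn_ids: np.ndarray, k: int) -> np.ndarray:
    """Turn a k'-NN graph into a k-NN graph (k <= k') by keeping the first
    k columns. Only valid because rows are sorted nearest first."""
    return knn_ids[:, :k]

def knn_to_csr(knn_ids: np.ndarray):
    """Directed k-NN graph -> undirected, self-loop-free CSR (xadj, adjncy)."""
    n, k = knn_ids.shape
    src = np.repeat(np.arange(n), k)
    dst = knn_ids.reshape(-1)
    keep = src != dst                      # drop self-loops
    src, dst = src[keep], dst[keep]
    # Symmetrise: add the reverse of every edge, then deduplicate.
    u = np.concatenate([src, dst])
    v = np.concatenate([dst, src])
    pairs = np.unique(np.stack([u, v], axis=1), axis=0)  # rows sorted by (u, v)
    adjncy = pairs[:, 1].astype(np.int32)
    counts = np.bincount(pairs[:, 0], minlength=n)
    xadj = np.concatenate([[0], np.cumsum(counts)]).astype(np.int32)
    return xadj, adjncy
```

Symmetrising matters because k-NN relations are not mutual: if b is among a's neighbours but not vice versa, the partitioner should still see a single undirected edge a-b.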

Sequence — the trilogy

This is part two of a three-homework arc:

  1. ann-search-cpp — 4 ANN algorithms (LSH, Hypercube, IVF-Flat, IVF-PQ) in C++ on MNIST + SIFT
  2. neural-lsh (you are here) — Neural LSH implementation with Bayesian HPO
  3. protein-homolog-search — capstone: protein remote-homolog detection on SwissProt using ESM-2 embeddings + the ANN methods from parts 1 and 2

License

MIT — applies to my and Sokratis's joint work in this repo. Assignment-distributed materials retain their original course copyright.

About

Implementation of Neural LSH (Dong et al., ICLR 2020) for ANN search: C++ k-NN graph + KaHIP graph partitioning + PyTorch MLP for learned bucket assignment, with Bayesian HPO via Optuna. Beats HW1's IVF-Flat ~4× at high recall on MNIST.