Skip to content

Range search bug? #13

Description

@jacketsj

Was nice to see some of the recent work on this at FWCG this year.

I've wrote up some quick tests on a couple popular ANN datasets (SIFT and GloVe), and I think I've stumbled upon a bug with the range search function. Repro below.

Python repro/full test
import os
import time
import urllib.request
import numpy as np
from greedypermutation.balltree import greedy_tree
from metricspaces import MetricSpace

total_distcomps = 0

# Faster numpy point
class Point:
    def __init__(self, point):
        self.point = np.array(point)
        self.hash = hash(str(self.point))

    def __hash__(self):
        return self.hash

    def distsq(self, other):
        global total_distcomps
        total_distcomps += 1
        return np.sum((self.point - other.point)**2)

    def dist(self, other):
        global total_distcomps
        total_distcomps += 1
        return np.linalg.norm(self.point - other.point)

    def __str__(self):
        return str(self.point)

    def __iter__(self):
        return iter(self.point)


GLOVE_URL = "http://ann-benchmarks.com/glove-100-angular.hdf5"
DATA_PATH = "./glove_100_angular.hdf5"

def download_glove():
    """Download the GloVe 100 Angular dataset."""
    if not os.path.exists(DATA_PATH):
        print("Downloading GloVe 100 Angular dataset...")
        urllib.request.urlretrieve(GLOVE_URL, DATA_PATH)
        print("GloVe 100 Angular dataset downloaded.")
    else:
        print("GloVe 100 Angular dataset already exists.")

def load_glove(sample_size=5000):
    """Load the GloVe 100 Angular dataset and sample the first `sample_size` points."""
    import h5py
    download_glove()
    with h5py.File(DATA_PATH, "r") as f:
        base = np.array(f["train"][:sample_size])
        queries = np.array(f["test"][:sample_size])
        ground_truth = np.array(f["neighbors"][:sample_size])
    return base, queries, ground_truth


SIFTSMALL_URL = "ftp://ftp.irisa.fr/local/texmex/corpus/siftsmall.tar.gz"
SIFT_DATA_PATH = "./siftsmall"
SIFT_ARCHIVE_PATH = "./siftsmall.tar.gz"

def download_siftsmall():
    """Download and extract the SIFT Small dataset."""
    if not os.path.exists(SIFT_DATA_PATH):
        if not os.path.exists(SIFT_ARCHIVE_PATH):
            print("Downloading SIFT Small dataset...")
            urllib.request.urlretrieve(SIFTSMALL_URL, SIFT_ARCHIVE_PATH)
            print("SIFT Small dataset archive downloaded.")
        print("Extracting SIFT Small dataset...")
        import tarfile
        with tarfile.open(SIFT_ARCHIVE_PATH, "r:gz") as tar:
            tar.extractall(path=".")
        print("SIFT Small dataset extracted.")
    else:
        print("SIFT Small dataset already exists.")

def load_siftsmall(sample_size=10000):
    """Load the SIFT Small dataset."""
    download_siftsmall()
    base_path = os.path.join(SIFT_DATA_PATH, "siftsmall_base.fvecs")
    query_path = os.path.join(SIFT_DATA_PATH, "siftsmall_query.fvecs")
    ground_truth_path = os.path.join(SIFT_DATA_PATH, "siftsmall_groundtruth.ivecs")

    # handle fvec/ivecs format
    def load_vecs(file_path, dtype):
        data = np.fromfile(file_path, dtype=dtype)
        dim = data[0].view(np.int32)
        return data.reshape(-1, dim + 1)[:, 1:]

    base = load_vecs(base_path, np.float32)[:sample_size]
    queries = load_vecs(query_path, np.float32)[:sample_size]
    ground_truth = load_vecs(ground_truth_path, np.int32)[:sample_size]
    return base, queries, ground_truth

def compute_ground_truth(base, queries, k=10):
    print("Computing ground truth with brute-force L2 distance...")
    distances = np.linalg.norm(base[None, :, :] - queries[:, None, :], axis=2)
    ground_truth = np.argsort(distances, axis=1)[:, :k]
    print("Ground truth computed.")
    return ground_truth

def test_greedy_tree_range_search(base, queries, ground_truth, k=10, radius_multiplier=1.5):
    print("\n=== Testing Greedy Tree Range Search ===")

    global total_distcomps
    total_distcomps = 0

    base_points = [Point(vec) for vec in base]
    base_space = MetricSpace(base_points)

    point_to_index = {point: idx for idx, point in enumerate(base_points)}

    start_time = time.time()
    tree = greedy_tree(base_space)
    construction_time = time.time() - start_time
    print(f"Tree constructed in {construction_time:.2f} seconds.")
    print(f"Total distcomps (build): {total_distcomps}")

    total_distcomps = 0

    start_time = time.time()
    correct = 0
    increases = 0
    for i, query_vec in enumerate(queries):
        query_point = Point(query_vec)
        radius = 1.0
        neighbors = []

        # Increment radius until at least k neighbors are found
        while len(neighbors) < k:
            neighbors = list(tree.range_search(query_point, radius))
            radius *= radius_multiplier
            increases += 1

        neighbor_indices = [point_to_index[neighbor] for neighbor in neighbors[:k]]

        correct += len(set(neighbor_indices).intersection(set(ground_truth[i][:k])))
    search_time = time.time() - start_time

    recall = correct / (len(queries) * k)
    print(f"Greedy Tree Top-{k} accuracy: {recall:.4f}")
    print(f"Total query time: {search_time:.2f} seconds.")
    print(f"Average query time: {search_time / len(queries):.4f} seconds/query.")
    print(f"Total distcomps: {total_distcomps}")

def test_greedy_tree_ann(base, queries, ground_truth, approx):
    print("\n=== Testing Greedy Tree with ANN API ===")

    global total_distcomps
    total_distcomps = 0
    
    base_points = [Point(vec) for vec in base]
    base_space = MetricSpace(base_points)
    
    point_to_index = {point: idx for idx, point in enumerate(base_points)}
    
    start_time = time.time()
    tree = greedy_tree(base_space)
    construction_time = time.time() - start_time
    print(f"Tree constructed in {construction_time:.2f} seconds.")
    print(f"Total distcomps (build): {total_distcomps}")

    total_distcomps = 0
    
    start_time = time.time()
    correct = 0
    for i, query_vec in enumerate(queries):
        query_point = Point(query_vec)
        
        nearest_neighbor = tree.ann(query_point, approx)
        neighbor_index = point_to_index[nearest_neighbor]
        
        if neighbor_index == ground_truth[i][0]:
            correct += 1
    
    search_time = time.time() - start_time
    
    recall = correct / len(queries)
    print(f"Greedy Tree ANN Top-1 accuracy: {recall:.4f}")
    print(f"Total query time: {search_time:.2f} seconds.")
    print(f"Average query time: {search_time / len(queries):.4f} seconds/query.")
    print(f"Total distcomps: {total_distcomps}")

def test_greedy_tree_nn(base, queries, ground_truth):
    print("\n=== Testing Greedy Tree Nearest Neighbor (NN) Search ===")

    global total_distcomps
    total_distcomps = 0

    base_points = [Point(vec) for vec in base]
    base_space = MetricSpace(base_points)

    point_to_index = {point: idx for idx, point in enumerate(base_points)}

    start_time = time.time()
    tree = greedy_tree(base_space)
    construction_time = time.time() - start_time
    print(f"Tree constructed in {construction_time:.2f} seconds.")
    print(f"Total distcomps (build): {total_distcomps}")

    total_distcomps = 0

    start_time = time.time()
    correct = 0
    for i, query_vec in enumerate(queries):
        query_point = Point(query_vec)
        
        # Find the nearest neighbor
        nearest_neighbor = tree.nn(query_point)
        neighbor_index = point_to_index[nearest_neighbor]

        if neighbor_index == ground_truth[i][0]:
            correct += 1

    search_time = time.time() - start_time

    recall = correct / len(queries)
    print(f"Greedy Tree NN Top-1 accuracy: {recall:.4f}")
    print(f"Total query time: {search_time:.2f} seconds.")
    print(f"Average query time: {search_time / len(queries):.4f} seconds/query.")
    print(f"Total distcomps: {total_distcomps}")


def test_greedy_tree_knn(base, queries, ground_truth, k=10):
    print("\n=== Testing Greedy Tree k-Nearest Neighbors (kNN) Search ===")

    global total_distcomps
    total_distcomps = 0

    base_points = [Point(vec) for vec in base]
    base_space = MetricSpace(base_points)

    point_to_index = {point: idx for idx, point in enumerate(base_points)}

    start_time = time.time()
    tree = greedy_tree(base_space)
    construction_time = time.time() - start_time
    print(f"Tree constructed in {construction_time:.2f} seconds.")
    print(f"Total distcomps (build): {total_distcomps}")

    total_distcomps = 0

    start_time = time.time()
    correct = 0
    for i, query_vec in enumerate(queries):
        query_point = Point(query_vec)

        # Find the k nearest neighbors
        knn_neighbors = list(tree.knn(k, query_point))
        neighbor_indices = [point_to_index[neighbor] for neighbor in knn_neighbors]

        correct += len(set(neighbor_indices).intersection(set(ground_truth[i][:k])))

    search_time = time.time() - start_time

    recall = correct / (len(queries) * k)
    print(f"Greedy Tree kNN Top-{k} accuracy: {recall:.4f}")
    print(f"Total query time: {search_time:.2f} seconds.")
    print(f"Average query time: {search_time / len(queries):.4f} seconds/query.")
    print(f"Total distcomps: {total_distcomps}")


def test_greedy_tree_approx_knn(base, queries, ground_truth, k=10, approx=1.5):
    print("\n=== Testing Greedy Tree Approximate k-Nearest Neighbors (Approx kNN) Search ===")

    global total_distcomps
    total_distcomps = 0

    base_points = [Point(vec) for vec in base]
    base_space = MetricSpace(base_points)

    point_to_index = {point: idx for idx, point in enumerate(base_points)}

    start_time = time.time()
    tree = greedy_tree(base_space)
    construction_time = time.time() - start_time
    print(f"Tree constructed in {construction_time:.2f} seconds.")
    print(f"Total distcomps (build): {total_distcomps}")

    total_distcomps = 0

    start_time = time.time()
    correct = 0
    for i, query_vec in enumerate(queries):
        query_point = Point(query_vec)

        # Perform approximate kNN search
        knn_neighbors = list(tree.knn(k, query_point, approx))
        neighbor_indices = [point_to_index[neighbor] for neighbor in knn_neighbors]

        correct += len(set(neighbor_indices).intersection(set(ground_truth[i][:k])))

    search_time = time.time() - start_time

    recall = correct / (len(queries) * k)
    print(f"Greedy Tree Approx kNN Top-{k} accuracy: {recall:.4f}")
    print(f"Total query time: {search_time:.2f} seconds.")
    print(f"Average query time: {search_time / len(queries):.4f} seconds/query.")
    print(f"Total distcomps: {total_distcomps}")

if __name__ == "__main__":
    # SIFT
    print("!=== Running tests on SIFT ===!\n")
    base, queries, not_ground_truth = load_siftsmall(sample_size=500)
    ground_truth = compute_ground_truth(base, queries)

    # test 1-NN
    test_greedy_tree_nn(base, queries, ground_truth)
    test_greedy_tree_ann(base, queries, ground_truth, approx=1.01)
    test_greedy_tree_knn(base, queries, ground_truth, k=1)
    test_greedy_tree_approx_knn(base, queries, ground_truth, k=1, approx=1.01)
    test_greedy_tree_range_search(base, queries, ground_truth, k=1)

    # test 10-NN
    test_greedy_tree_knn(base, queries, ground_truth, k=10)
    test_greedy_tree_approx_knn(base, queries, ground_truth, k=10, approx=1.01)
    test_greedy_tree_range_search(base, queries, ground_truth, k=10)


    # GLoVe
    print("\n!=== Running tests on GloVe ===!\n")
    base, queries, not_ground_truth = load_glove(sample_size=500)
    ground_truth = compute_ground_truth(base, queries)

    # test 1-NN
    test_greedy_tree_nn(base, queries, ground_truth)
    test_greedy_tree_ann(base, queries, ground_truth, approx=1.01)
    test_greedy_tree_knn(base, queries, ground_truth, k=1)
    test_greedy_tree_approx_knn(base, queries, ground_truth, k=1, approx=1.01)
    test_greedy_tree_range_search(base, queries, ground_truth, k=1)

    # test 10-NN
    test_greedy_tree_knn(base, queries, ground_truth, k=10)
    test_greedy_tree_approx_knn(base, queries, ground_truth, k=10, approx=1.01)
    test_greedy_tree_range_search(base, queries, ground_truth, k=10)
Sample Output
!=== Running tests on SIFT ===!

SIFT Small dataset already exists.
Computing ground truth with brute-force L2 distance...
Ground truth computed.

=== Testing Greedy Tree Nearest Neighbor (NN) Search ===
Tree constructed in 8.09 seconds.
Total distcomps (build): 110024
Greedy Tree NN Top-1 accuracy: 1.0000
Total query time: 0.40 seconds.
Average query time: 0.0040 seconds/query.
Total distcomps: 31887

=== Testing Greedy Tree with ANN API ===
Tree constructed in 8.14 seconds.
Total distcomps (build): 110024
Greedy Tree ANN Top-1 accuracy: 0.9600
Total query time: 0.36 seconds.
Average query time: 0.0036 seconds/query.
Total distcomps: 31546

=== Testing Greedy Tree k-Nearest Neighbors (kNN) Search ===
Tree constructed in 8.14 seconds.
Total distcomps (build): 110024
Greedy Tree kNN Top-1 accuracy: 1.0000
Total query time: 0.65 seconds.
Average query time: 0.0065 seconds/query.
Total distcomps: 34456

=== Testing Greedy Tree Approximate k-Nearest Neighbors (Approx kNN) Search ===
Tree constructed in 8.24 seconds.
Total distcomps (build): 110024
Greedy Tree Approx kNN Top-1 accuracy: 1.0000
Total query time: 0.66 seconds.
Average query time: 0.0066 seconds/query.
Total distcomps: 34456

=== Testing Greedy Tree Range Search ===
Tree constructed in 8.19 seconds.
Total distcomps (build): 110024
Greedy Tree Top-1 accuracy: 0.2800
Total query time: 1.10 seconds.
Average query time: 0.0110 seconds/query.
Total distcomps: 35339

=== Testing Greedy Tree k-Nearest Neighbors (kNN) Search ===
Tree constructed in 8.22 seconds.
Total distcomps (build): 110024
Greedy Tree kNN Top-10 accuracy: 1.0000
Total query time: 1.08 seconds.
Average query time: 0.0108 seconds/query.
Total distcomps: 40794

=== Testing Greedy Tree Approximate k-Nearest Neighbors (Approx kNN) Search ===
Tree constructed in 8.11 seconds.
Total distcomps (build): 110024
Greedy Tree Approx kNN Top-10 accuracy: 1.0000
Total query time: 1.07 seconds.
Average query time: 0.0107 seconds/query.
Total distcomps: 40794

=== Testing Greedy Tree Range Search ===
Tree constructed in 8.18 seconds.
Total distcomps (build): 110024
Greedy Tree Top-10 accuracy: 0.3290
Total query time: 1.28 seconds.
Average query time: 0.0128 seconds/query.
Total distcomps: 39419


!=== Running tests on GloVe ===!

GloVe 100 Angular dataset already exists.
Computing ground truth with brute-force L2 distance...
Ground truth computed.

=== Testing Greedy Tree Nearest Neighbor (NN) Search ===
Tree constructed in 22.74 seconds.
Total distcomps (build): 125250
Greedy Tree NN Top-1 accuracy: 1.0000
Total query time: 2.99 seconds.
Average query time: 0.0060 seconds/query.
Total distcomps: 250000

=== Testing Greedy Tree with ANN API ===
Tree constructed in 22.58 seconds.
Total distcomps (build): 125250
Greedy Tree ANN Top-1 accuracy: 0.9460
Total query time: 2.82 seconds.
Average query time: 0.0056 seconds/query.
Total distcomps: 250000

=== Testing Greedy Tree k-Nearest Neighbors (kNN) Search ===
Tree constructed in 22.70 seconds.
Total distcomps (build): 125250
Greedy Tree kNN Top-1 accuracy: 1.0000
Total query time: 5.07 seconds.
Average query time: 0.0101 seconds/query.
Total distcomps: 249999

=== Testing Greedy Tree Approximate k-Nearest Neighbors (Approx kNN) Search ===
Tree constructed in 22.83 seconds.
Total distcomps (build): 125250
Greedy Tree Approx kNN Top-1 accuracy: 1.0000
Total query time: 5.10 seconds.
Average query time: 0.0102 seconds/query.
Total distcomps: 249999

=== Testing Greedy Tree Range Search ===
Tree constructed in 22.72 seconds.
Total distcomps (build): 125250
Greedy Tree Top-1 accuracy: 0.1540
Total query time: 11.04 seconds.
Average query time: 0.0221 seconds/query.
Total distcomps: 250000

=== Testing Greedy Tree k-Nearest Neighbors (kNN) Search ===
Tree constructed in 22.90 seconds.
Total distcomps (build): 125250
Greedy Tree kNN Top-10 accuracy: 1.0000
Total query time: 7.66 seconds.
Average query time: 0.0153 seconds/query.
Total distcomps: 250000

=== Testing Greedy Tree Approximate k-Nearest Neighbors (Approx kNN) Search ===
Tree constructed in 22.83 seconds.
Total distcomps (build): 125250
Greedy Tree Approx kNN Top-10 accuracy: 1.0000
Total query time: 7.58 seconds.
Average query time: 0.0152 seconds/query.
Total distcomps: 250000

=== Testing Greedy Tree Range Search ===
Tree constructed in 22.53 seconds.
Total distcomps (build): 125250
Greedy Tree Top-10 accuracy: 0.3166
Total query time: 12.08 seconds.
Average query time: 0.0242 seconds/query.
Total distcomps: 250000

Abbreviating, the unexpected sections/lines here are:

!=== Running tests on SIFT ===!
=== Testing Greedy Tree Range Search ===
Greedy Tree Top-1 accuracy: 0.2800
=== Testing Greedy Tree Range Search ===
Greedy Tree Top-10 accuracy: 0.3290

!=== Running tests on GloVe ===!
=== Testing Greedy Tree Range Search ===
Greedy Tree Top-1 accuracy: 0.1540
=== Testing Greedy Tree Range Search ===
Greedy Tree Top-10 accuracy: 0.3166

Unless I've misunderstood what this routine is intended to produce, I believe all four of these accuracy/recall values should be 1.0. Please let me know if I'm missing anything.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions