import os
import time
import urllib.request
import numpy as np
from greedypermutation.balltree import greedy_tree
from metricspaces import MetricSpace
total_distcomps = 0  # running count of distance evaluations; the benchmarks reset it


# NumPy-backed point that counts its distance computations.
class Point:
    """A point stored as a NumPy array whose distance calls are counted.

    Every call to ``dist`` or ``distsq`` increments the module-level
    ``total_distcomps`` counter, which the benchmark functions read and reset
    to report how many distance evaluations were performed.

    No ``__eq__`` is defined, so equality is object identity: two ``Point``
    objects with equal coordinates are still distinct dictionary keys.
    """

    def __init__(self, point):
        self.point = np.array(point)
        # Hash the raw bytes rather than str(self.point): str() abbreviates
        # arrays with more than 1000 elements using "...", which gave many
        # distinct high-dimensional points the same hash.
        self.hash = hash(self.point.tobytes())

    def __hash__(self):
        return self.hash

    def distsq(self, other):
        """Return the squared Euclidean distance to ``other`` (counted)."""
        global total_distcomps
        total_distcomps += 1
        return np.sum((self.point - other.point) ** 2)

    def dist(self, other):
        """Return the Euclidean distance to ``other`` (counted)."""
        global total_distcomps
        total_distcomps += 1
        return np.linalg.norm(self.point - other.point)

    def __str__(self):
        return str(self.point)

    def __iter__(self):
        return iter(self.point)
GLOVE_URL = "http://ann-benchmarks.com/glove-100-angular.hdf5"
DATA_PATH = "./glove_100_angular.hdf5"


def download_glove():
    """Download the GloVe 100 Angular dataset to ``DATA_PATH`` if it is absent.

    The file is fetched to a ``.part`` sibling first and renamed into place
    only after the transfer completes. An interrupted download therefore
    never leaves a truncated ``DATA_PATH`` that later runs would mistake for
    a complete file.
    """
    if os.path.exists(DATA_PATH):
        print("GloVe 100 Angular dataset already exists.")
        return
    print("Downloading GloVe 100 Angular dataset...")
    partial_path = DATA_PATH + ".part"
    urllib.request.urlretrieve(GLOVE_URL, partial_path)
    # os.replace overwrites any existing target; on POSIX the rename is atomic.
    os.replace(partial_path, DATA_PATH)
    print("GloVe 100 Angular dataset downloaded.")
def load_glove(sample_size=5000):
    """Load the GloVe 100 Angular dataset, keeping the first `sample_size` rows.

    Ensures the HDF5 file is present, then returns ``(base, queries,
    ground_truth)`` read from its ``"train"``, ``"test"`` and ``"neighbors"``
    datasets. Each is truncated to its first ``sample_size`` rows.

    NOTE(review): ``neighbors`` presumably indexes into the *full* train set,
    so it is only valid when ``sample_size`` covers all of it. The main block
    discards it and recomputes ground truth on the sample; confirm before
    relying on it.
    """
    import h5py

    download_glove()
    rows = slice(None, sample_size)
    with h5py.File(DATA_PATH, "r") as hdf:
        # Tuple unpacking consumes the generator while the file is still open.
        train, test, neighbors = (
            np.array(hdf[name][rows]) for name in ("train", "test", "neighbors")
        )
    return train, test, neighbors
SIFTSMALL_URL = "ftp://ftp.irisa.fr/local/texmex/corpus/siftsmall.tar.gz"
SIFT_DATA_PATH = "./siftsmall"
SIFT_ARCHIVE_PATH = "./siftsmall.tar.gz"


def download_siftsmall():
    """Download and extract the SIFT Small dataset.

    Does nothing if ``SIFT_DATA_PATH`` already exists, and skips the download
    if ``SIFT_ARCHIVE_PATH`` is already on disk. The archive is fetched to a
    ``.part`` file and renamed only after the transfer completes, so an
    interrupted download is retried rather than extracted as a truncated
    archive.
    """
    import tarfile

    if os.path.exists(SIFT_DATA_PATH):
        print("SIFT Small dataset already exists.")
        return
    if not os.path.exists(SIFT_ARCHIVE_PATH):
        print("Downloading SIFT Small dataset...")
        partial_path = SIFT_ARCHIVE_PATH + ".part"
        urllib.request.urlretrieve(SIFTSMALL_URL, partial_path)
        os.replace(partial_path, SIFT_ARCHIVE_PATH)
        print("SIFT Small dataset archive downloaded.")
    print("Extracting SIFT Small dataset...")
    with tarfile.open(SIFT_ARCHIVE_PATH, "r:gz") as tar:
        # The archive comes from the network. Where the interpreter supports
        # it, use the "data" extraction filter, which refuses absolute paths,
        # "..", and links that would escape the destination directory.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=".", filter="data")
        else:
            tar.extractall(path=".")
    print("SIFT Small dataset extracted.")
def load_siftsmall(sample_size=10000):
    """Load the SIFT Small dataset.

    Ensures the data is present, then returns ``(base, queries,
    ground_truth)`` read from the ``.fvecs`` / ``.ivecs`` files under
    ``SIFT_DATA_PATH``. Each is truncated to its first ``sample_size`` rows.

    NOTE(review): the ground-truth indices presumably refer to the full base
    file, so they are only meaningful when ``sample_size`` covers the whole
    base set. Confirm before relying on them with a smaller sample.
    """
    download_siftsmall()

    def read_vecs(path, dtype):
        # Records are a 4-byte int32 dimension header followed by `dim`
        # values. The first record's header sets the row width, so every
        # record is assumed to share that dimension.
        flat = np.fromfile(path, dtype=dtype)
        width = flat[0].view(np.int32) + 1
        # Drop the header column from every record.
        return flat.reshape(-1, width)[:, 1:]

    names = (
        "siftsmall_base.fvecs",
        "siftsmall_query.fvecs",
        "siftsmall_groundtruth.ivecs",
    )
    dtypes = (np.float32, np.float32, np.int32)
    return tuple(
        read_vecs(os.path.join(SIFT_DATA_PATH, name), dtype)[:sample_size]
        for name, dtype in zip(names, dtypes)
    )
def compute_ground_truth(base, queries, k=10, batch_size=256):
    """Return the indices of the ``k`` nearest ``base`` rows for each query.

    Distances are Euclidean (L2), computed by brute force. Queries are
    processed ``batch_size`` at a time, so the temporary difference array is
    ``batch_size x len(base) x dim`` rather than
    ``len(queries) x len(base) x dim``. This bounds peak memory without
    changing any computed distance.

    Args:
        base: array of shape ``(n, d)``.
        queries: array of shape ``(q, d)``.
        k: neighbours per query (fewer are returned when ``n < k``).
        batch_size: queries per chunk; must be positive.

    Returns:
        Integer array of shape ``(q, min(k, n))``. Each row is ordered from
        nearest to farthest, with ties broken by the lower base index.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    print("Computing ground truth with brute-force L2 distance...")
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    base = np.asarray(base)
    queries = np.asarray(queries)
    chunks = []
    for start in range(0, len(queries), batch_size):
        chunk_queries = queries[start:start + batch_size]
        distances = np.linalg.norm(
            base[None, :, :] - chunk_queries[:, None, :], axis=2
        )
        # A stable sort makes tie-breaking deterministic (lowest index first).
        chunks.append(np.argsort(distances, axis=1, kind="stable")[:, :k])
    if chunks:
        ground_truth = np.concatenate(chunks, axis=0)
    else:
        # No queries: keep the same (0, min(k, n)) shape as the unchunked form.
        ground_truth = np.empty((0, min(k, len(base))), dtype=np.intp)
    print("Ground truth computed.")
    return ground_truth
def test_greedy_tree_range_search(base, queries, ground_truth, k=10, radius_multiplier=1.5):
    """Benchmark k-NN retrieval built on ``tree.range_search``.

    For each query the radius starts at 1.0 and is multiplied by
    ``radius_multiplier`` until the range search returns at least ``k``
    points (capped at the size of ``base``). The returned points are then
    ranked by their true distance to the query, and the closest ``k`` are
    compared with ``ground_truth``.

    The ranking step is what makes the recall meaningful. A range search
    returns every point inside the ball, and nothing here guarantees that it
    returns them nearest-first. Slicing the raw result, as before, compared
    an arbitrary subset of the ball with the true neighbours. Once the ball
    holds at least ``k`` points, all true top-``k`` neighbours lie inside it.
    So if ``range_search`` is exact, the printed recall should be 1.0, up to
    distance ties.

    Prints build/query timings, distance-computation counts, the total number
    of radius increases, and the top-``k`` recall. The ranking distances are
    computed directly on the arrays. They are included in query time but not
    in the distcomp count.

    Raises:
        ValueError: if ``radius_multiplier`` is not greater than 1, because
            the radius would never grow and the search loop could not end.
    """
    global total_distcomps
    if radius_multiplier <= 1:
        raise ValueError("radius_multiplier must be greater than 1")
    print("\n=== Testing Greedy Tree Range Search ===")
    total_distcomps = 0
    base_points = [Point(vec) for vec in base]
    base_space = MetricSpace(base_points)
    point_to_index = {point: idx for idx, point in enumerate(base_points)}
    start_time = time.perf_counter()
    tree = greedy_tree(base_space)
    construction_time = time.perf_counter() - start_time
    print(f"Tree constructed in {construction_time:.2f} seconds.")
    print(f"Total distcomps (build): {total_distcomps}")
    total_distcomps = 0
    # Never wait for more points than exist, or the radius loop cannot end.
    target = min(k, len(base_points))
    start_time = time.perf_counter()
    correct = 0
    increases = 0
    for i, query_vec in enumerate(queries):
        query_point = Point(query_vec)
        radius = 1.0
        neighbors = []
        # Grow the radius until the ball contains at least `target` points.
        while len(neighbors) < target:
            neighbors = list(tree.range_search(query_point, radius))
            radius *= radius_multiplier
            increases += 1
        # Rank by true distance so the slice below is the k nearest in the ball.
        neighbors.sort(key=lambda p: np.linalg.norm(p.point - query_point.point))
        neighbor_indices = [point_to_index[neighbor] for neighbor in neighbors[:k]]
        correct += len(set(neighbor_indices).intersection(set(ground_truth[i][:k])))
    search_time = time.perf_counter() - start_time
    recall = correct / (len(queries) * k)
    print(f"Greedy Tree Top-{k} accuracy: {recall:.4f}")
    print(f"Total query time: {search_time:.2f} seconds.")
    print(f"Average query time: {search_time / len(queries):.4f} seconds/query.")
    print(f"Total distcomps: {total_distcomps}")
    print(f"Total radius increases: {increases}")
def test_greedy_tree_ann(base, queries, ground_truth, approx):
    """Benchmark the top-1 accuracy of ``tree.ann`` on a freshly built greedy tree.

    Builds a greedy tree over ``base``. For each query, it checks whether the
    point returned by ``tree.ann(query, approx)`` has index
    ``ground_truth[i][0]``. ``approx`` is passed through to ``tree.ann``
    unchanged.

    Prints build time, build and query distance-computation counts (from the
    global ``total_distcomps`` counter), query timings and top-1 accuracy.
    Intervals are measured with ``time.perf_counter``, the monotonic
    high-resolution clock meant for timing, instead of the wall clock.
    """
    print("\n=== Testing Greedy Tree with ANN API ===")
    global total_distcomps
    total_distcomps = 0
    base_points = [Point(vec) for vec in base]
    base_space = MetricSpace(base_points)
    point_to_index = {point: idx for idx, point in enumerate(base_points)}
    start_time = time.perf_counter()
    tree = greedy_tree(base_space)
    construction_time = time.perf_counter() - start_time
    print(f"Tree constructed in {construction_time:.2f} seconds.")
    print(f"Total distcomps (build): {total_distcomps}")
    total_distcomps = 0
    start_time = time.perf_counter()
    correct = 0
    for i, query_vec in enumerate(queries):
        query_point = Point(query_vec)
        nearest_neighbor = tree.ann(query_point, approx)
        neighbor_index = point_to_index[nearest_neighbor]
        if neighbor_index == ground_truth[i][0]:
            correct += 1
    search_time = time.perf_counter() - start_time
    recall = correct / len(queries)
    print(f"Greedy Tree ANN Top-1 accuracy: {recall:.4f}")
    print(f"Total query time: {search_time:.2f} seconds.")
    print(f"Average query time: {search_time / len(queries):.4f} seconds/query.")
    print(f"Total distcomps: {total_distcomps}")
def test_greedy_tree_nn(base, queries, ground_truth):
    """Benchmark exact nearest-neighbour search (``tree.nn``) on a greedy tree.

    Builds a greedy tree over ``base``. For each query, it checks whether
    ``tree.nn(query)`` returns the point with index ``ground_truth[i][0]``.

    Prints build time, build and query distance-computation counts (from the
    global ``total_distcomps`` counter), query timings and top-1 accuracy.
    Intervals are measured with ``time.perf_counter``, the monotonic
    high-resolution clock meant for timing, instead of the wall clock.

    NOTE(review): if ``base`` contains duplicate vectors, a correct NN at an
    equal distance may carry a different index and be counted as a miss.
    """
    print("\n=== Testing Greedy Tree Nearest Neighbor (NN) Search ===")
    global total_distcomps
    total_distcomps = 0
    base_points = [Point(vec) for vec in base]
    base_space = MetricSpace(base_points)
    point_to_index = {point: idx for idx, point in enumerate(base_points)}
    start_time = time.perf_counter()
    tree = greedy_tree(base_space)
    construction_time = time.perf_counter() - start_time
    print(f"Tree constructed in {construction_time:.2f} seconds.")
    print(f"Total distcomps (build): {total_distcomps}")
    total_distcomps = 0
    start_time = time.perf_counter()
    correct = 0
    for i, query_vec in enumerate(queries):
        query_point = Point(query_vec)
        # Find the nearest neighbor
        nearest_neighbor = tree.nn(query_point)
        neighbor_index = point_to_index[nearest_neighbor]
        if neighbor_index == ground_truth[i][0]:
            correct += 1
    search_time = time.perf_counter() - start_time
    recall = correct / len(queries)
    print(f"Greedy Tree NN Top-1 accuracy: {recall:.4f}")
    print(f"Total query time: {search_time:.2f} seconds.")
    print(f"Average query time: {search_time / len(queries):.4f} seconds/query.")
    print(f"Total distcomps: {total_distcomps}")
def test_greedy_tree_knn(base, queries, ground_truth, k=10):
    """Benchmark k-nearest-neighbour search (``tree.knn``) on a greedy tree.

    Builds a greedy tree over ``base``. For each query, it collects the
    points yielded by ``tree.knn(k, query)`` and counts how many of their
    indices appear in ``ground_truth[i][:k]``. Recall is that count divided
    by ``len(queries) * k``.

    Prints build time, build and query distance-computation counts (from the
    global ``total_distcomps`` counter), query timings and top-``k`` recall.
    Intervals are measured with ``time.perf_counter``, the monotonic
    high-resolution clock meant for timing, instead of the wall clock.
    """
    print("\n=== Testing Greedy Tree k-Nearest Neighbors (kNN) Search ===")
    global total_distcomps
    total_distcomps = 0
    base_points = [Point(vec) for vec in base]
    base_space = MetricSpace(base_points)
    point_to_index = {point: idx for idx, point in enumerate(base_points)}
    start_time = time.perf_counter()
    tree = greedy_tree(base_space)
    construction_time = time.perf_counter() - start_time
    print(f"Tree constructed in {construction_time:.2f} seconds.")
    print(f"Total distcomps (build): {total_distcomps}")
    total_distcomps = 0
    start_time = time.perf_counter()
    correct = 0
    for i, query_vec in enumerate(queries):
        query_point = Point(query_vec)
        # Find the k nearest neighbors
        knn_neighbors = list(tree.knn(k, query_point))
        neighbor_indices = [point_to_index[neighbor] for neighbor in knn_neighbors]
        correct += len(set(neighbor_indices).intersection(set(ground_truth[i][:k])))
    search_time = time.perf_counter() - start_time
    recall = correct / (len(queries) * k)
    print(f"Greedy Tree kNN Top-{k} accuracy: {recall:.4f}")
    print(f"Total query time: {search_time:.2f} seconds.")
    print(f"Average query time: {search_time / len(queries):.4f} seconds/query.")
    print(f"Total distcomps: {total_distcomps}")
def test_greedy_tree_approx_knn(base, queries, ground_truth, k=10, approx=1.5):
    """Benchmark approximate k-NN search (``tree.knn`` with ``approx``).

    Builds a greedy tree over ``base``. For each query, it collects the
    points yielded by ``tree.knn(k, query, approx)`` and counts how many of
    their indices appear in ``ground_truth[i][:k]``. ``approx`` is passed to
    ``tree.knn`` unchanged.

    Prints build time, build and query distance-computation counts (from the
    global ``total_distcomps`` counter), query timings and top-``k`` recall.
    Intervals are measured with ``time.perf_counter``, the monotonic
    high-resolution clock meant for timing, instead of the wall clock.
    """
    print("\n=== Testing Greedy Tree Approximate k-Nearest Neighbors (Approx kNN) Search ===")
    global total_distcomps
    total_distcomps = 0
    base_points = [Point(vec) for vec in base]
    base_space = MetricSpace(base_points)
    point_to_index = {point: idx for idx, point in enumerate(base_points)}
    start_time = time.perf_counter()
    tree = greedy_tree(base_space)
    construction_time = time.perf_counter() - start_time
    print(f"Tree constructed in {construction_time:.2f} seconds.")
    print(f"Total distcomps (build): {total_distcomps}")
    total_distcomps = 0
    start_time = time.perf_counter()
    correct = 0
    for i, query_vec in enumerate(queries):
        query_point = Point(query_vec)
        # Perform approximate kNN search
        knn_neighbors = list(tree.knn(k, query_point, approx))
        neighbor_indices = [point_to_index[neighbor] for neighbor in knn_neighbors]
        correct += len(set(neighbor_indices).intersection(set(ground_truth[i][:k])))
    search_time = time.perf_counter() - start_time
    recall = correct / (len(queries) * k)
    print(f"Greedy Tree Approx kNN Top-{k} accuracy: {recall:.4f}")
    print(f"Total query time: {search_time:.2f} seconds.")
    print(f"Average query time: {search_time / len(queries):.4f} seconds/query.")
    print(f"Total distcomps: {total_distcomps}")
if __name__ == "__main__":
    # Run the same battery of benchmarks on each dataset, in this order.
    # The ground truth shipped with each dataset is discarded; ground truth is
    # recomputed by brute-force L2 on the 500-point sample (presumably because
    # the shipped indices refer to the full dataset).
    suites = (
        ("!=== Running tests on SIFT ===!\n", load_siftsmall),
        ("\n!=== Running tests on GloVe ===!\n", load_glove),
    )
    for banner, loader in suites:
        print(banner)
        base, queries, not_ground_truth = loader(sample_size=500)
        ground_truth = compute_ground_truth(base, queries)

        # test 1-NN
        test_greedy_tree_nn(base, queries, ground_truth)
        test_greedy_tree_ann(base, queries, ground_truth, approx=1.01)
        test_greedy_tree_knn(base, queries, ground_truth, k=1)
        test_greedy_tree_approx_knn(base, queries, ground_truth, k=1, approx=1.01)
        test_greedy_tree_range_search(base, queries, ground_truth, k=1)

        # test 10-NN
        test_greedy_tree_knn(base, queries, ground_truth, k=10)
        test_greedy_tree_approx_knn(base, queries, ground_truth, k=10, approx=1.01)
        test_greedy_tree_range_search(base, queries, ground_truth, k=10)
It was nice to see some of the recent work on this at FWCG this year.
I've written up some quick tests on a couple of popular ANN datasets (SIFT and GloVe), and I think I've stumbled upon a bug in the range search function. Repro below.
Python repro/full test
Sample Output
In short, the unexpected sections/lines are:
Unless I've misunderstood what this routine is intended to produce, I believe all four of these accuracy/recall values should be 1.0. Please let me know if I'm missing anything.