Commit df646a64 authored by Julien Jerphanion's avatar Julien Jerphanion

[WIP] Adapt querying logic

This change the logic to query the tree for nearest neighbors.
Start with a simple sequential query for each point.
parent 6c0e5a62
This diff is collapsed.
......@@ -14,20 +14,29 @@ def test_creation_deletion(n, d, leaf_size):
tree = kdtree.KDTree(X, leaf_size=256)
del tree
@pytest.mark.skip(reason="The query is being refactored.")
@pytest.mark.parametrize("n", [10, 100, 1000, 10000])
@pytest.mark.parametrize("d", [10, 100])
@pytest.mark.parametrize("k", [1, 2, 5, 10])
@pytest.mark.parametrize("leaf_size", [256, 1024])
def test_against_sklearn(n, d, k, leaf_size):
np.random.seed(1)
def test_against_sklearn(n, d, k, leaf_size, n_query=1):
np.random.seed(2)
X = np.random.rand(n, d)
query_points = np.random.rand(n, d)
query_points = np.random.rand(n_query, d)
tree = kdtree.KDTree(X, leaf_size=256)
skl_tree = KDTree(X, leaf_size=256)
closests = np.zeros((n, k), dtype=np.int32)
tree.query(query_points, closests)
skl_closests = skl_tree.query(query_points, k=k, return_distance=False).astype(np.int32)
np.testing.assert_equal(closests, skl_closests)
knn_indices = np.zeros((n_query, k), dtype=np.int32)
knn_distances = np.zeros((n_query, k), dtype=np.float64)
tree.query(query_points, knn_indices, knn_distances)
skl_knn_distances, skl_knn_indices = skl_tree.query(
query_points,
k=k,
return_distance=True
)
# Adapting types
skl_knn_indices = skl_knn_indices.astype(np.int32)
np.testing.assert_equal(knn_indices, skl_knn_indices)
np.testing.assert_almost_equal(knn_distances, skl_knn_distances)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment