Commit 6c0e5a62 authored by Julien Jerphanion

[WIP] Adapt querying logic

This changes the logic used to query the tree for nearest neighbors.
Start with a simple sequential query for each point.
parent acfcdbe0
@@ -318,8 +318,11 @@ cdef cypclass NeighborsHeaps:
         self._n_pushes += 1
+        printf("Pushing for %d, (%d, %lf)\n", row, i_val, val)
 
         # check if val should be in heap
         if val > distances[0]:
+            printf("Discarding %d\n", row)
             return
 
         # insert val at position zero
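For context on this hunk: each query row owns a fixed-size max-heap whose root, `distances[0]`, is the current k-th best (i.e. worst kept) distance, so a candidate is discarded whenever it cannot beat that root. A minimal Python sketch of the push, assuming that layout (`heap_push` and its arguments are illustrative names, not from this codebase):

```python
import numpy as np

def heap_push(distances, indices, val, i_val):
    """Push candidate (val, i_val) into one row's fixed-size max-heap.

    distances[0] is always the largest of the k best distances kept so
    far, i.e. the pruning bound that `largest` exposes.
    """
    if val > distances[0]:
        return  # cannot beat the current k-th best: discard

    # Overwrite the root, then sift it down to restore the max-heap.
    distances[0], indices[0] = val, i_val
    i, n = 0, len(distances)
    while True:
        left, right = 2 * i + 1, 2 * i + 2
        largest = i
        if left < n and distances[left] > distances[largest]:
            largest = left
        if right < n and distances[right] > distances[largest]:
            largest = right
        if largest == i:
            break
        distances[i], distances[largest] = distances[largest], distances[i]
        indices[i], indices[largest] = indices[largest], indices[i]
        i = largest

# Example: keep the 3 nearest of a stream of candidates.
distances = np.full(3, np.inf)
indices = np.full(3, -1)
for i_val, val in enumerate([0.9, 0.1, 0.5, 0.3, 0.7]):
    heap_push(distances, indices, val, i_val)
# distances[0] == 0.5 is now the 3rd-best (largest kept) distance.
```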
@@ -383,8 +386,8 @@ cdef cypclass NeighborsHeaps:
         # use of getIntResult
         return 1 if self._sorted else 0
 
-    I_t largest(self, I_t index):
-        return self._indices[index * self._n_nbrs]
+    D_t largest(self, I_t index):
+        return self._distances[index * self._n_nbrs]
 
 cdef cypclass Node activable:
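The fix above matters for pruning: `largest` must return the query row's current k-th best distance, so it has to read `_distances` (a `D_t`), not `_indices`. Roughly, assuming the same flat row-major storage:

```python
def largest(distances, index, n_nbrs):
    # Row `index` owns the slice distances[index*n_nbrs : (index+1)*n_nbrs],
    # a max-heap whose first element is that row's k-th best distance so far.
    return distances[index * n_nbrs]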
@@ -407,63 +410,86 @@ cdef cypclass Node activable:
     active Node _left
     active Node _right
+    I_t _node_bounds_ptr_offset
 
-    __init__(self, NodeData_t * node_data_ptr, D_t * node_bounds_ptr):
+    __init__(self, NodeData_t * node_data_ptr, D_t * node_bounds_ptr, I_t node_bounds_ptr_offset):
         # Needed for Cython+ actors
         self._active_result_class = WaitResult.construct
         self._active_queue_class = consume BatchMailBox(scheduler)
         self._node_data_ptr = node_data_ptr
         self._node_bounds_ptr = node_bounds_ptr
+        self._node_bounds_ptr_offset = node_bounds_ptr_offset
 
     # We use this to allow using actors for initialisation
     # because __init__ can't be reified.
     void build_node(
         self,
         I_t node_index,
         D_t * data_ptr,
         I_t * indices_ptr,
         I_t leaf_size,
-        I_t n_dims,
+        I_t n_features,
         I_t dim,
-        I_t start,
-        I_t end,
+        I_t idx_start,
+        I_t idx_end,
         active Counter counter,
     ):
+        cdef NodeData_t * node_data = self._node_data_ptr + node_index
+        deref(node_data).idx_start = idx_start
+        deref(node_data).idx_end = idx_end
+        deref(node_data).is_leaf = False
+
+        cdef DTYPE_t * lower_bounds = self._node_bounds_ptr + node_index * n_features
+        cdef DTYPE_t * upper_bounds = self._node_bounds_ptr + node_index * n_features + self._node_bounds_ptr_offset
+        cdef DTYPE_t * data_row
+
+        # Determine Node bounds
+        for j in range(n_features):
+            lower_bounds[j] = INF
+            upper_bounds[j] = -INF
+
+        # Compute the actual data range. At build time, this is slightly
+        # slower than using the previously-computed bounds of the parent node,
+        # but leads to more compact trees and thus faster queries.
+        for i in range(idx_start, idx_end):
+            data_row = data_ptr + indices_ptr[i] * n_features
+            for j in range(n_features):
+                lower_bounds[j] = fmin(lower_bounds[j], data_row[j])
+                upper_bounds[j] = fmax(upper_bounds[j], data_row[j])
+
         # Choose the dimension with maximum spread at each recursion instead.
         cdef I_t next_dim = find_node_split_dim(data_ptr,
-                                                indices_ptr + start,
-                                                n_dims,
-                                                end - start)
-        cdef I_t mid = (start + end) // 2
-
-        cdef NodeData_t * node_data = self._node_data_ptr + node_index
-        deref(node_data).idx_start = start
-        deref(node_data).idx_end = end
+                                                indices_ptr + idx_start,
+                                                n_features,
+                                                idx_end - idx_start)
+        cdef I_t mid = (idx_start + idx_end) // 2
 
-        if (end - start <= leaf_size):
+        if idx_end - idx_start <= leaf_size:
             deref(node_data).is_leaf = True
             # Adding to the global counter the number
             # of samples the leaf is responsible for.
-            counter.add(NULL, end - start)
+            counter.add(NULL, idx_end - idx_start)
             return
 
         # We partition the samples in two nodes on a given dimension,
         # with the middle point as a pivot.
-        partition_node_indices(data_ptr, indices_ptr, start, mid, end, dim, n_dims)
+        partition_node_indices(data_ptr, indices_ptr, idx_start, mid, idx_end, dim, n_features)
 
-        self._left = consume Node(self._node_data_ptr, self._node_bounds_ptr)
-        self._right = consume Node(self._node_data_ptr, self._node_bounds_ptr)
+        self._left = consume Node(self._node_data_ptr, self._node_bounds_ptr, self._node_bounds_ptr_offset)
+        self._right = consume Node(self._node_data_ptr, self._node_bounds_ptr, self._node_bounds_ptr_offset)
 
         # Recursing on both partitions.
         self._left.build_node(NULL, <I_t> 2 * node_index,
                               data_ptr, indices_ptr,
-                              leaf_size, n_dims, next_dim,
-                              start, mid, counter)
+                              leaf_size, n_features, next_dim,
+                              idx_start, mid, counter)
         self._right.build_node(NULL, <I_t> (2 * node_index + 1),
                                data_ptr, indices_ptr,
-                               leaf_size, n_dims, next_dim,
-                               mid, end, counter)
+                               leaf_size, n_features, next_dim,
+                               mid, idx_end, counter)
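This hunk threads a `node_bounds_ptr_offset` through the constructor and uses it to address the upper bounds, which suggests a single flat buffer holding the lower bounds of all nodes followed by all upper bounds. A small Python sketch of that presumed layout (the buffer size and helper names are assumptions, not from the diff):

```python
import numpy as np

n_nodes, n_features = 7, 3

# One flat buffer: all lower bounds first, then all upper bounds.
node_bounds = np.empty(2 * n_nodes * n_features)
node_bounds_ptr_offset = n_nodes * n_features

def lower_bounds(node_index):
    start = node_index * n_features
    return node_bounds[start:start + n_features]

def upper_bounds(node_index):
    # Same per-node stride, shifted by the offset into the second half.
    start = node_index * n_features + node_bounds_ptr_offset
    return node_bounds[start:start + n_features]
```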
@@ -557,7 +583,7 @@ cdef cypclass KDTree:
             self._data_ptr,
             self._indices_ptr,
             self._leaf_size,
-            n_dims=self._d,
+            n_features=self._d,
             dim=0,
-            start=0,
-            end=self._n,
+            idx_start=0,
+            idx_end=self._n,
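For reference, the overall build step behind the renamed parameters: compute the node's bounding box from its actual points, stop once a slice fits in `leaf_size`, otherwise split on the widest dimension around the median and recurse. A sequential Python sketch under the conventional `2*i+1`/`2*i+2` child numbering (note the build code above uses `2*i` and `2*i+1` while the query hunk below uses `2*i+1` and `2*i+2`; all names here are illustrative):

```python
import numpy as np

def build_node(data, indices, idx_start, idx_end, leaf_size, node_data, node_index=0):
    """One build step of the kd-tree (sequential sketch)."""
    points = data[indices[idx_start:idx_end]]

    # Node bounds: the actual range of this node's data on each feature.
    # Tighter than halving the parent's box, so queries prune better.
    lower_bounds = points.min(axis=0)
    upper_bounds = points.max(axis=0)
    node_data[node_index] = (idx_start, idx_end, lower_bounds, upper_bounds)

    if idx_end - idx_start <= leaf_size:
        return  # leaf node: keep the samples as they are

    # Split on the dimension with maximum spread, median point as pivot.
    next_dim = int(np.argmax(upper_bounds - lower_bounds))
    mid = (idx_start + idx_end) // 2
    segment = indices[idx_start:idx_end]  # view: reorders `indices` in place
    order = np.argpartition(data[segment, next_dim], mid - idx_start)
    segment[:] = segment[order]

    build_node(data, indices, idx_start, mid, leaf_size, node_data, 2 * node_index + 1)
    build_node(data, indices, mid, idx_end, leaf_size, node_data, 2 * node_index + 2)

# Usage: build over 100 random 3-D points with leaves of up to 16 samples.
rng = np.random.default_rng(0)
data = rng.random((100, 3))
indices = np.arange(100)
node_data = {}
build_node(data, indices, 0, 100, leaf_size=16, node_data=node_data)
```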
@@ -595,12 +621,18 @@ cdef cypclass KDTree:
         #------------------------------------------------------------
         # Case 1: query point is outside node radius:
         #         trim it from the query
-        if reduced_dist_LB > heaps.largest(i_pt):
+        cdef D_t largest = heaps.largest(i_pt)
+        printf("reduced_dist_LB=%lf\n", reduced_dist_LB)
+
+        if reduced_dist_LB > largest:
+            printf("Discarding node %d because reduced_dist_LB=%lf > largest=%lf\n", i_node, reduced_dist_LB, largest)
             pass
 
         #------------------------------------------------------------
         # Case 2: this is a leaf node. Update set of nearby points
         elif node_info.is_leaf:
+            printf("Inspecting vectors in leaf %d\n", i_node)
             for i in range(node_info.idx_start, node_info.idx_end):
                 dist_pt = sqeuclidean_dist(
                     x1=pt,
@@ -613,6 +645,7 @@ cdef cypclass KDTree:
         # Case 3: Node is not a leaf. Recursively query subnodes
         #         starting with the closest
         else:
+            printf("Delegating to children of node %d\n", i_node)
             i1 = 2 * i_node + 1
             i2 = i1 + 1
             reduced_dist_LB_1 = self.min_rdist(i1, pt)
@@ -645,8 +678,10 @@ cdef cypclass KDTree:
         NeighborsHeaps heaps = NeighborsHeaps(<I_t *> closests.data, n_query, n_neighbors)
 
         for i in range(n_query):
+            printf("Querying vector %d\n", i)
             rdist_lower_bound = self.min_rdist(0, _query_points_ptr + i * n_features)
             self._query_single_depthfirst(0, _query_points_ptr, i, heaps, rdist_lower_bound)
+            printf("Done querying vector %d\n\n", i)
 
         heaps.sort()
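This is the simple sequential query from the commit message: for each point, walk the tree depth-first, pruning any node whose box lies farther than the current k-th best distance, scanning leaves, and visiting the closer child first. A Python sketch, with the `tree` and `heap` interfaces assumed for illustration:

```python
import numpy as np

def query_single_depthfirst(tree, i_node, pt, heap, reduced_dist_lb):
    """Depth-first single-point k-NN query (sequential sketch)."""
    node = tree.node_data[i_node]

    # Case 1: the node's box is already farther away than the current
    # k-th best distance, so nothing inside can improve the heap: prune.
    if reduced_dist_lb > heap.largest():
        return

    # Case 2: leaf node: score every sample it holds.
    if node.is_leaf:
        for i in range(node.idx_start, node.idx_end):
            candidate = tree.indices[i]
            dist = float(np.sum((tree.data[candidate] - pt) ** 2))  # squared euclidean
            heap.push(dist, candidate)
        return

    # Case 3: internal node: visit the closer child first so the heap
    # tightens early and the farther child is pruned more often.
    i1, i2 = 2 * i_node + 1, 2 * i_node + 2
    lb1, lb2 = tree.min_rdist(i1, pt), tree.min_rdist(i2, pt)
    if lb2 < lb1:
        i1, i2, lb1, lb2 = i2, i1, lb2, lb1
    query_single_depthfirst(tree, i1, pt, heap, lb1)
    query_single_depthfirst(tree, i2, pt, heap, lb2)
```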