Commit 5147ca6f authored by Sergei Golubchik's avatar Sergei Golubchik

mhnsw: closest neighbor precalc heuristic

This is based on the heuristic  that if a candidate neighbor
has a very close neighbor of its own, than this close neighbor
is also likely a candidate neighbor itself.

Meaning, we might replace the loop that compares a candidate
with all neighbors if we know the distance between the candidate
and its closest neighbor. Which can be precalculated.

This gives the most speedup when the number of neighbors
and the number of dimensions are large. In the tests it was
2.5-3x speedup, with the recall being worse by 0.1%-1%

Incidentally, in the opposite case it gives both litle speedup
and notably worse recall. Tests have shown 1.13x speedup
with recall going down by ~20% in the worst - smallest - case.

Thus, this heuristic is only enabled above the certain threshold.
parent 1165e6a6
......@@ -21,6 +21,10 @@
#include "key.h"
#include <scope.h>
#define clo_nei_size 4
#define clo_nei_store float4store
#define clo_nei_read float4get
// Algorithm parameters
// best by test (fastest construction with recall > 99% for ef=20, limit=10)
// for random-xs-20-euclidean (9000) [ 3, 1.1, M=7 ]
......@@ -28,6 +32,7 @@
// for sift-128-euclidean (1000000) [ 4, 1.1, M>64 ] (98% with M=64)
static const double ef_construction_multiplier = 4;
static const double alpha = 1.1;
static const uint clo_nei_threshold= 10000;
// SIMD definitions
#define SIMD_word (256/8)
......@@ -54,8 +59,9 @@ class FVectorNode: public FVector
private:
uchar *ref;
List<FVectorNode> *neighbors= nullptr;
char *neighbors_read= 0;
public:
float *closest_neighbor= 0;
FVectorNode(MHNSW_Context *ctx_, const void *ref_);
FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_);
float distance_to(const FVector &other) const;
......@@ -63,6 +69,7 @@ class FVectorNode: public FVector
int instantiate_neighbors(size_t layer);
size_t get_ref_len() const;
uchar *get_ref() const { return ref; }
void update_closest_neighbor(size_t layer, float dist, const FVectorNode &v);
List<FVectorNode> &get_neighbors(size_t layer) const;
bool is_new() const;
......@@ -164,10 +171,10 @@ int FVectorNode::instantiate_neighbors(size_t layer)
if (!neighbors)
{
neighbors= new (&ctx->root) List<FVectorNode>[layer+1];
neighbors_read= (char*)alloc_root(&ctx->root, layer+1);
bzero(neighbors_read, layer+1);
closest_neighbor= (float*)alloc_root(&ctx->root, (layer+1)*sizeof(*closest_neighbor));
memset(closest_neighbor, 0xff, (layer+1)*sizeof(*closest_neighbor)); // NaN
}
if (!neighbors_read[layer])
if (isnan(closest_neighbor[layer]))
{
if (!is_new())
{
......@@ -183,13 +190,15 @@ int FVectorNode::instantiate_neighbors(size_t layer)
return ctx->err;
String strbuf, *str= graph->field[2]->val_str(&strbuf);
if (str->length() % ref_len)
if ((str->length() - clo_nei_size) % ref_len)
return ctx->err= HA_ERR_CRASHED; // corrupted HNSW index
for (const char *pos= str->ptr(); pos < str->end(); pos+= ref_len)
clo_nei_read(closest_neighbor[layer], str->ptr());
for (const char *pos= str->ptr() + clo_nei_size; pos < str->end(); pos+= ref_len)
neighbors[layer].push_back(ctx->get_node(pos), &ctx->root);
}
neighbors_read[layer]= 1;
else
closest_neighbor[layer]= FLT_MAX;
}
return 0;
......@@ -201,6 +210,14 @@ List<FVectorNode> &FVectorNode::get_neighbors(size_t layer) const
return neighbors[layer];
}
void FVectorNode::update_closest_neighbor(size_t layer, float dist,
const FVectorNode &other)
{
if (memcmp(ref, other.get_ref(), get_ref_len()) < 0 &&
closest_neighbor[layer] > dist)
closest_neighbor[layer]= dist;
}
size_t FVectorNode::get_ref_len() const
{
return ctx->table->file->ref_length;
......@@ -241,7 +258,7 @@ static int cmp_vec(const FVector *target, const FVectorNode *a, const FVectorNod
}
static int select_neighbors(MHNSW_Context *ctx, size_t layer,
const FVectorNode &target,
FVectorNode &target,
const List<FVectorNode> &candidates_unsafe,
size_t max_neighbor_connections)
{
......@@ -254,11 +271,13 @@ static int select_neighbors(MHNSW_Context *ctx, size_t layer,
*/
List<FVectorNode> candidates= candidates_unsafe;
List<FVectorNode> &neighbors= target.get_neighbors(layer);
const bool do_cn= max_neighbor_connections*ctx->vec_len > clo_nei_threshold;
if (ctx->err)
return ctx->err;
neighbors.empty();
target.closest_neighbor[layer]= FLT_MAX;
if (pq.init(10000, 0, cmp_vec, &target) ||
pq_discard.init(10000, 0, cmp_vec, &target))
......@@ -277,22 +296,35 @@ static int select_neighbors(MHNSW_Context *ctx, size_t layer,
{
const FVectorNode *vec= pq.pop();
const float target_dist= vec->distance_to(target);
const float target_dista= target_dist / alpha;
bool discard= false;
for (const FVectorNode &neigh : neighbors)
if (do_cn)
{
if ((discard= vec->distance_to(neigh) * alpha < target_dist))
break;
vec->get_neighbors(layer);
discard= vec->closest_neighbor[layer] < target_dista;
}
else
{
for (const FVectorNode &neigh : neighbors)
{
if ((discard= vec->distance_to(neigh) < target_dista))
break;
}
}
if (!discard)
{
neighbors.push_back(vec, &ctx->root);
target.update_closest_neighbor(layer, target_dist, *vec);
}
else if (pq_discard.elements() + neighbors.elements < max_neighbor_connections)
pq_discard.push(vec);
}
while (pq_discard.elements() &&
neighbors.elements < max_neighbor_connections)
while (pq_discard.elements() && neighbors.elements < max_neighbor_connections)
{
neighbors.push_back(pq_discard.pop(), &ctx->root);
const FVectorNode *vec= pq_discard.pop();
neighbors.push_back(vec, &ctx->root);
target.update_closest_neighbor(layer, vec->distance_to(target), *vec);
}
return 0;
......@@ -300,32 +332,33 @@ static int select_neighbors(MHNSW_Context *ctx, size_t layer,
static int write_neighbors(MHNSW_Context *ctx, size_t layer,
const FVectorNode &source_node)
const FVectorNode &node)
{
TABLE *graph= ctx->table->hlindex;
const List<FVectorNode> &new_neighbors= source_node.get_neighbors(layer);
const List<FVectorNode> &new_neighbors= node.get_neighbors(layer);
if (ctx->err)
return ctx->err;
size_t total_size= new_neighbors.elements * source_node.get_ref_len();
size_t total_size= new_neighbors.elements * node.get_ref_len() + clo_nei_size;
// Allocate memory for the struct and the flexible array member
char *neighbor_array_bytes= static_cast<char *>(my_safe_alloca(total_size));
char *pos= neighbor_array_bytes;
for (const auto &node: new_neighbors)
clo_nei_store(neighbor_array_bytes, node.closest_neighbor[layer]);
char *pos= neighbor_array_bytes + clo_nei_size;
for (const auto &neigh: new_neighbors)
{
DBUG_ASSERT(node.get_ref_len() == source_node.get_ref_len());
memcpy(pos, node.get_ref(), node.get_ref_len());
pos+= node.get_ref_len();
DBUG_ASSERT(neigh.get_ref_len() == node.get_ref_len());
memcpy(pos, neigh.get_ref(), neigh.get_ref_len());
pos+= neigh.get_ref_len();
}
graph->field[0]->store(layer, false);
graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len());
graph->field[1]->store_binary(node.get_ref(), node.get_ref_len());
graph->field[2]->store_binary(neighbor_array_bytes, total_size);
if (source_node.is_new())
if (node.is_new())
ctx->err= graph->file->ha_write_row(graph->record[0]);
else
{
......@@ -350,12 +383,13 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer,
uint max_neighbors,
const FVectorNode &node)
{
for (const FVectorNode &neigh: node.get_neighbors(layer))
for (FVectorNode &neigh: node.get_neighbors(layer))
{
List<FVectorNode> &neighneighbors= neigh.get_neighbors(layer);
if (ctx->err)
return ctx->err;
neighneighbors.push_back(&node, &ctx->root);
neigh.update_closest_neighbor(layer, neigh.distance_to(node), node);
if (neighneighbors.elements > max_neighbors)
{
if (select_neighbors(ctx, layer, neigh, neighneighbors, max_neighbors))
......@@ -671,7 +705,7 @@ const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length)
" primary key (layer, src)) ";
size_t len= sizeof(templ) + 32;
char *s= thd->alloc(len);
len= my_snprintf(s, len, templ, ref_length, 2 * ref_length *
thd->variables.mhnsw_max_edges_per_node);
len= my_snprintf(s, len, templ, ref_length, clo_nei_size +
2 * ref_length * thd->variables.mhnsw_max_edges_per_node);
return {s, len};
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment