Commit e339fd00 authored by Sergei Golubchik's avatar Sergei Golubchik

cleanup search_layer()

to return only as many elements as needed, the caller no longer needs to
overallocate result arrays for throwaway nodes
parent d94d4fb9
......@@ -885,16 +885,29 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, TABLE *graph,
}
static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
Neighborhood *start_nodes, uint ef, size_t layer,
Neighborhood *result, bool skip_deleted)
Neighborhood *start_nodes, uint result_size,
size_t layer, Neighborhood *result, bool construction)
{
DBUG_ASSERT(start_nodes->num > 0);
result->num= 0;
MEM_ROOT * const root= graph->in_use->mem_root;
Queue<Visited> candidates, best;
bool skip_deleted;
uint ef= result_size;
Queue<Visited> candidates;
Queue<Visited> best;
if (construction)
{
skip_deleted= false;
if (ef > 1)
ef= std::max(ef_construction, ef);
}
else
{
skip_deleted= layer == 0;
if (ef > 1 || layer == 0)
ef= std::max(graph->in_use->variables.mhnsw_min_limit, ef);
}
// WARNING! heuristic here
const double est_heuristic= 8 * std::sqrt(ctx->max_neighbors(layer));
......@@ -904,23 +917,21 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
candidates.init(10000, false, Visited::cmp);
best.init(ef, true, Visited::cmp);
DBUG_ASSERT(start_nodes->num <= result_size);
for (size_t i=0; i < start_nodes->num; i++)
{
Visited *v= visited.create(start_nodes->links[i]);
candidates.push(v);
if (skip_deleted && v->node->deleted)
continue;
if (best.elements() < ef)
best.push(v);
else if (v->distance_to_target < best.top()->distance_to_target)
best.replace_top(v);
best.push(v);
}
float furthest_best= FLT_MAX;
while (candidates.elements())
{
const Visited &cur= *candidates.pop();
if (cur.distance_to_target > furthest_best && best.elements() == ef)
if (cur.distance_to_target > furthest_best && best.is_full())
break; // All possible candidates are worse than what we have
visited.flush();
......@@ -940,7 +951,7 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
if (int err= links[i]->load(graph))
return err;
Visited *v= visited.create(links[i]);
if (best.elements() < ef)
if (!best.is_full())
{
candidates.push(v);
if (skip_deleted && v->node->deleted)
......@@ -965,6 +976,9 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
set_if_bigger(ctx->ef_power, ef_power); // not atomic, but it's ok
}
while (best.elements() > result_size)
best.pop();
result->num= best.elements();
for (FVectorNode **links= result->links + result->num; best.elements();)
*--links= best.pop()->node;
......@@ -1028,9 +1042,10 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
if (ctx->byte_len != res->length())
return bad_value_on_insert(vec_field);
const size_t max_found= ctx->max_neighbors(0);
Neighborhood candidates, start_nodes;
candidates.init(thd->alloc<FVectorNode*>(ef_construction + 7), ef_construction);
start_nodes.init(thd->alloc<FVectorNode*>(ef_construction + 7), ef_construction);
candidates.init(thd->alloc<FVectorNode*>(max_found + 7), max_found);
start_nodes.init(thd->alloc<FVectorNode*>(max_found + 7), max_found);
start_nodes.links[start_nodes.num++]= ctx->start;
const double NORMALIZATION_FACTOR= 1 / std::log(ctx->M);
......@@ -1058,7 +1073,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
{
uint max_neighbors= ctx->max_neighbors(cur_layer);
if (int err= search_layer(ctx, graph, target->vec, &start_nodes,
ef_construction, cur_layer, &candidates, false))
max_neighbors, cur_layer, &candidates, true))
return err;
if (int err= select_neighbors(ctx, graph, cur_layer, *target, candidates,
......@@ -1100,11 +1115,9 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
return err;
SCOPE_EXIT([ctx, table](){ ctx->release(table); });
size_t ef= thd->variables.mhnsw_min_limit;
Neighborhood candidates, start_nodes;
candidates.init(thd->alloc<FVectorNode*>(ef + 7), ef);
start_nodes.init(thd->alloc<FVectorNode*>(ef + 7), ef);
candidates.init(thd->alloc<FVectorNode*>(limit + 7), limit);
start_nodes.init(thd->alloc<FVectorNode*>(limit + 7), limit);
// one could put all max_layer nodes in start_nodes
// but it has no effect of the recall or speed
......@@ -1139,8 +1152,8 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
std::swap(start_nodes, candidates);
}
if (int err= search_layer(ctx, graph, target, &start_nodes, ef, 0,
&candidates, true))
if (int err= search_layer(ctx, graph, target, &start_nodes,
static_cast<uint>(limit), 0, &candidates, false))
return err;
if (limit > candidates.num)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment