Commit f661b93c authored by Sergei Golubchik

mhnsw: modify target's neighbors directly

parent f6bc9879
......@@ -226,10 +226,9 @@ const bool KEEP_PRUNED_CONNECTIONS=true; // XXX why?
const bool EXTEND_CANDIDATES=true; // XXX or false?
static int select_neighbors(MHNSW_Context *ctx,
size_t layer, const FVector &target,
size_t layer, const FVectorNode &target,
const List<FVectorNode> &candidates,
size_t max_neighbor_connections,
List<FVectorNode> *neighbors)
size_t max_neighbor_connections)
{
/*
TODO: If the input neighbors list is already sorted in search_layer, then
......@@ -297,8 +296,10 @@ static int select_neighbors(MHNSW_Context *ctx,
}
DBUG_ASSERT(best.elements() <= max_neighbor_connections);
while (best.elements()) // XXX why not to return best directly?
neighbors->push_front(best.pop(), &ctx->root);
List<FVectorNode> &neighbors= target.get_neighbors(layer);
neighbors.empty();
while (best.elements())
neighbors.push_front(best.pop(), &ctx->root);
return 0;
}
......@@ -344,11 +345,11 @@ static void dbug_print_hash_vec(Hash_set<FVectorNode> &h)
static int write_neighbors(MHNSW_Context *ctx, size_t layer,
const FVectorNode &source_node,
const List<FVectorNode> &new_neighbors)
const FVectorNode &source_node)
{
int err;
TABLE *graph= ctx->table->hlindex;
const List<FVectorNode> &new_neighbors= source_node.get_neighbors(layer);
DBUG_ASSERT(new_neighbors.elements <= HNSW_MAX_M);
size_t total_size= HNSW_MAX_M_WIDTH + new_neighbors.elements * source_node.get_ref_len();
......@@ -396,31 +397,25 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer,
static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer,
uint max_neighbors,
const FVectorNode &source_node,
const List<FVectorNode> &neighbors)
const FVectorNode &node)
{
//dbug_print_vec_ref("Updating second degree neighbors", layer, source_node);
//dbug_print_vec_neigh(layer, neighbors);
for (const FVectorNode &neigh: neighbors) // XXX why this loop?
for (const FVectorNode &neigh: node.get_neighbors(layer)) // XXX why this loop?
{
neigh.get_neighbors(layer).push_back(&source_node, &ctx->root);
if (int err= write_neighbors(ctx, layer, neigh, neigh.get_neighbors(layer)))
neigh.get_neighbors(layer).push_back(&node, &ctx->root);
if (int err= write_neighbors(ctx, layer, neigh))
return err;
}
for (const FVectorNode &neigh: neighbors)
for (const FVectorNode &neigh: node.get_neighbors(layer))
{
if (neigh.get_neighbors(layer).elements > max_neighbors)
{
// shrink the neighbors
List<FVectorNode> selected;
if (int err= select_neighbors(ctx, layer, neigh,
neigh.get_neighbors(layer),
max_neighbors, &selected))
neigh.get_neighbors(layer), max_neighbors))
return err;
if (int err= write_neighbors(ctx, layer, neigh, selected))
if (int err= write_neighbors(ctx, layer, neigh))
return err;
// XXX neigh.get_neighbors(layer)= selected;
}
}
......@@ -428,17 +423,14 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer,
}
/*
  Store the neighbor list of one node on the given layer, then update the
  nodes that list points to.

  NOTE(review): this span is a rendered diff, so the pre-change and the
  post-change lines appear one after another. The older form takes the
  node (source_node) plus an explicit neighbors list; the newer form takes
  only the node and passes no list to the helpers it calls.

  ctx            index context, forwarded unchanged to both helpers
  layer          graph layer being updated
  max_neighbors  neighbor limit, forwarded to update_second_degree_neighbors()

  Returns the non-zero error of write_neighbors() if that call fails,
  otherwise whatever update_second_degree_neighbors() returns.
*/
static int update_neighbors(MHNSW_Context *ctx,
size_t layer, uint max_neighbors,
const FVectorNode &source_node,
const List<FVectorNode> &neighbors)
static int update_neighbors(MHNSW_Context *ctx, size_t layer,
uint max_neighbors, const FVectorNode &node)
{
// 1. update node's neighbors
// (stop right away on failure, the second step is then not attempted)
if (int err= write_neighbors(ctx, layer, source_node, neighbors))
if (int err= write_neighbors(ctx, layer, node))
return err;
// 2. update node's neighbors' neighbors (shrink before update)
return update_second_degree_neighbors(ctx, layer,
max_neighbors, source_node, neighbors);
return update_second_degree_neighbors(ctx, layer, max_neighbors, node);
}
......@@ -571,7 +563,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
// First insert!
FVectorNode target(&ctx, h->ref);
ctx.target= &target;
return write_neighbors(&ctx, 0, target, {});
return write_neighbors(&ctx, 0, target);
}
longlong max_layer= graph->field[0]->val_int();
......@@ -601,6 +593,14 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
double log= -std::log(new_num) * NORMALIZATION_FACTOR;
longlong new_node_layer= static_cast<longlong>(std::floor(log));
// XXX what is that?
for (longlong cur_layer= new_node_layer; cur_layer >= max_layer + 1;
cur_layer--)
{
if (int err= write_neighbors(&ctx, cur_layer, target))
return err;
}
for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--)
{
if (int err= search_layer(&ctx, start_nodes,
......@@ -615,7 +615,6 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
for (longlong cur_layer= std::min(max_layer, new_node_layer);
cur_layer >= 0; cur_layer--)
{
List<FVectorNode> neighbors;
if (int err= search_layer(&ctx, start_nodes,
thd->variables.hnsw_ef_constructor, cur_layer,
&candidates))
......@@ -626,23 +625,14 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
: thd->variables.hnsw_max_connection_per_layer;
if (int err= select_neighbors(&ctx, cur_layer, target, candidates,
max_neighbors, &neighbors))
max_neighbors))
return err;
if (int err= update_neighbors(&ctx, cur_layer, max_neighbors, target,
neighbors))
if (int err= update_neighbors(&ctx, cur_layer, max_neighbors, target))
return err;
start_nodes= candidates;
}
start_nodes.empty();
// XXX what is that?
for (longlong cur_layer= max_layer + 1; cur_layer <= new_node_layer;
cur_layer++)
{
if (int err= write_neighbors(&ctx, cur_layer, target, {}))
return err;
}
dbug_tmp_restore_column_map(&table->read_set, old_map);
return 0;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment