Commit b2a6dc71 authored by Sergei Golubchik

mhnsw: don't guess whether it's insert or update

we know it every time
parent ab43c7e8
......@@ -65,6 +65,7 @@ class FVectorNode: public FVector
int instantiate_vector();
size_t get_ref_len() const;
uchar *get_ref() const { return ref; }
bool is_new() const;
static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool);
};
......@@ -76,6 +77,7 @@ class MHNSW_Context
TABLE *table;
Field *vec_field;
size_t vec_len= 0;
FVector *target= 0;
Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};
......@@ -133,6 +135,11 @@ size_t FVectorNode::get_ref_len() const
return ctx->table->file->ref_length;
}
bool FVectorNode::is_new() const
{
return this == ctx->target;
}
uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool)
{
*key_len= elem->get_ref_len();
......@@ -327,6 +334,7 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
const FVectorNode &source_node,
const List<FVectorNode> &new_neighbors)
{
int err;
TABLE *graph= ctx->table->hlindex;
DBUG_ASSERT(new_neighbors.elements <= HNSW_MAX_M);
......@@ -349,25 +357,24 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len());
graph->field[2]->store_binary(neighbor_array_bytes, total_size);
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
// XXX try to write first?
int err= graph->file->ha_index_read_map(graph->record[1], key, HA_WHOLE_KEY,
HA_READ_KEY_EXACT);
// no record
if (err == HA_ERR_KEY_NOT_FOUND)
if (source_node.is_new())
{
dbug_print_vec_ref("INSERT ", layer_number, source_node);
err= graph->file->ha_write_row(graph->record[0]);
}
else if (!err)
else
{
dbug_print_vec_ref("UPDATE ", layer_number, source_node);
dbug_print_vec_neigh(layer_number, new_neighbors);
err= graph->file->ha_update_row(graph->record[1], graph->record[0]);
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
err= graph->file->ha_index_read_map(graph->record[1], key,
HA_WHOLE_KEY, HA_READ_KEY_EXACT);
if (!err)
err= graph->file->ha_update_row(graph->record[1], graph->record[0]);
}
my_safe_afree(neighbor_array_bytes, total_size);
return err;
......@@ -428,7 +435,7 @@ static int update_neighbors(MHNSW_Context *ctx,
}
static int search_layer(MHNSW_Context *ctx, const FVector &target,
static int search_layer(MHNSW_Context *ctx,
const List<FVectorNode> &start_nodes,
uint max_candidates_return, size_t layer,
List<FVectorNode> *result)
......@@ -439,6 +446,7 @@ static int search_layer(MHNSW_Context *ctx, const FVector &target,
Queue<FVectorNode, const FVector> candidates;
Queue<FVectorNode, const FVector> best;
Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key);
const FVector &target= *ctx->target;
candidates.init(10000, false, cmp_vec, &target);
best.init(max_candidates_return, true, cmp_vec, &target);
......@@ -550,20 +558,21 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
SCOPE_EXIT([graph](){ graph->file->ha_index_end(); });
h->position(table->record[0]);
if (int err= graph->file->ha_index_last(graph->record[0]))
{
if (err != HA_ERR_END_OF_FILE)
return err;
// First insert!
h->position(table->record[0]);
return write_neighbors(&ctx, 0, {&ctx, h->ref}, {});
FVectorNode target(&ctx, h->ref);
ctx.target= &target;
return write_neighbors(&ctx, 0, target, {});
}
longlong max_layer= graph->field[0]->val_int();
h->position(table->record[0]);
List<FVectorNode> candidates;
List<FVectorNode> start_nodes;
String ref_str, *ref_ptr;
......@@ -583,6 +592,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
return bad_value_on_insert(vec_field);
FVectorNode target(&ctx, h->ref, res->ptr());
ctx.target= &target;
double new_num= my_rnd(&thd->rand);
double log= -std::log(new_num) * NORMALIZATION_FACTOR;
......@@ -590,7 +600,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--)
{
if (int err= search_layer(&ctx, target, start_nodes,
if (int err= search_layer(&ctx, start_nodes,
thd->variables.hnsw_ef_constructor, cur_layer,
&candidates))
return err;
......@@ -603,7 +613,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
cur_layer >= 0; cur_layer--)
{
List<FVectorNode> neighbors;
if (int err= search_layer(&ctx, target, start_nodes,
if (int err= search_layer(&ctx, start_nodes,
thd->variables.hnsw_ef_constructor, cur_layer,
&candidates))
return err;
......@@ -682,6 +692,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
res= vec_field->val_str(&buf);
FVector target(&ctx, res->ptr());
ctx.target= &target;
ulonglong ef_search= std::max<ulonglong>( //XXX why not always limit?
thd->variables.hnsw_ef_search, limit);
......@@ -689,16 +700,15 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--)
{
//XXX in the paper ef_search=1 here
if (int err= search_layer(&ctx, target, start_nodes, ef_search,
cur_layer, &candidates))
if (int err= search_layer(&ctx, start_nodes, ef_search, cur_layer,
&candidates))
return err;
start_nodes.empty();
start_nodes.push_back(candidates.head(), &ctx.root); // XXX so ef_search=1 ???
candidates.empty();
}
if (int err= search_layer(&ctx, target, start_nodes, ef_search, 0,
&candidates))
if (int err= search_layer(&ctx, start_nodes, ef_search, 0, &candidates))
return err;
size_t context_size=limit * h->ref_length + sizeof(ulonglong);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment