Commit be021b86 authored by Sergei Golubchik's avatar Sergei Golubchik

mhnsw: change storage format

Instead of one row per node per layer, have one row per node.
Store all neighbors for all layers in that row, as well as the vector itself.

This completely avoids searches in the graph table and
will allow deletions to be implemented in the future.
parent 33d79a44
......@@ -8088,8 +8088,7 @@ int handler::prepare_for_insert(bool do_create)
return 1;
/* Preparation for unique of blob's */
if (table->s->long_unique_table || table->s->period.unique_keys ||
table->hlindex)
if (table->s->long_unique_table || table->s->period.unique_keys)
{
if (do_create && create_lookup_handler())
return 1;
......
......@@ -18,7 +18,6 @@
#include <my_global.h>
#include "vector_mhnsw.h"
#include "item_vectorfunc.h"
#include "key.h"
#include <scope.h>
// Algorithm parameters
......@@ -32,6 +31,13 @@ static constexpr uint ef_construction= 10;
// sizeof(double) aligned memory to SIMD_word aligned
#define SIMD_margin (SIMD_word - sizeof(double))
// Positions of the columns in the graph (hlindex) table, in the order
// they are declared in the CREATE TABLE template; used to index graph->field[].
enum Graph_table_fields {
  FIELD_LAYER= 0,     // "layer": the node's max_layer
  FIELD_TREF= 1,      // "tref": the node's tref (length is the base table's ref_length)
  FIELD_VEC= 2,       // "vec": the raw vector bytes
  FIELD_NEIGHBORS= 3  // "neighbors": per layer, <N> followed by N grefs
};
// Index numbers of the graph (hlindex) table, as passed to ha_index_init().
enum Graph_table_indices {
  IDX_LAYER= 0  // the "key (layer)" index
};
class MHNSW_Context;
class FVector: public Sql_alloc
......@@ -48,23 +54,36 @@ class FVector: public Sql_alloc
class FVectorNode: public FVector
{
private:
uchar *ref;
List<FVectorNode> *neighbors= nullptr;
char *neighbors_read= 0;
uchar *tref, *gref;
size_t max_layer;
static uchar *gref_max;
int alloc_neighborhood(uint8_t layer);
public:
FVectorNode(MHNSW_Context *ctx_, const void *ref_);
FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_);
List<FVectorNode> *neighbors= nullptr;
FVectorNode(MHNSW_Context *ctx_, const void *gref_);
FVectorNode(MHNSW_Context *ctx_, const void *tref_, uint8_t layer,
const void *vec_);
float distance_to(const FVector &other) const;
int instantiate_vector();
int instantiate_neighbors(size_t layer);
size_t get_ref_len() const;
uchar *get_ref() const { return ref; }
List<FVectorNode> &get_neighbors(size_t layer) const;
bool is_new() const;
int load();
int load_from_record();
int save();
size_t get_tref_len() const;
uchar *get_tref() const { return tref; }
size_t get_gref_len() const;
uchar *get_gref() const { return gref; }
static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool);
};
// this assumes that 1) rows from graph table are never deleted,
// 2) and thus a ref for a new row is larger than refs of existing rows,
// thus we can treat the not-yet-inserted row as having max possible ref.
// oh, yes, and 3) 8 bytes ought to be enough for everyone
uchar *FVectorNode::gref_max=(uchar*)"\xff\xff\xff\xff\xff\xff\xff\xff";
class MHNSW_Context
{
public:
......@@ -73,7 +92,6 @@ class MHNSW_Context
Field *vec_field;
size_t vec_len= 0;
size_t byte_len= 0;
FVector *target= 0;
uint err= 0;
Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};
......@@ -89,7 +107,12 @@ class MHNSW_Context
free_root(&root, MYF(0));
}
FVectorNode *get_node(const void *ref_);
FVectorNode *get_node(const void *gref);
// Record the size of a stored vector: byte_len gets the length in bytes,
// vec_len the number of floats passed through MY_ALIGN(..., SIMD_floats)
// (presumably rounding up to a multiple of SIMD_floats -- confirm in my_global.h).
void set_lengths(size_t len)
{
  const size_t floats= len / sizeof(float);
  byte_len= len;
  vec_len= MY_ALIGN(floats, SIMD_floats);
}
};
FVector::FVector(MHNSW_Context *ctx_, const void *vec_) : ctx(ctx_)
......@@ -99,6 +122,7 @@ FVector::FVector(MHNSW_Context *ctx_, const void *vec_) : ctx(ctx_)
void FVector::make_vec(const void *vec_)
{
DBUG_ASSERT(ctx->vec_len);
vec= (float*)alloc_root(&ctx->root,
ctx->vec_len * sizeof(float) + SIMD_margin);
if (int off= ((intptr)vec) % SIMD_word)
......@@ -108,22 +132,23 @@ void FVector::make_vec(const void *vec_)
vec[i]=0;
}
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *ref_)
: FVector(ctx_)
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *gref_)
: FVector(ctx_), tref(nullptr)
{
ref= (uchar*)memdup_root(&ctx->root, ref_, get_ref_len());
gref= (uchar*)memdup_root(&ctx->root, gref_, get_gref_len());
}
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_)
: FVector(ctx_, vec_)
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *tref_, uint8_t layer,
const void *vec_)
: FVector(ctx_, vec_), gref(gref_max)
{
ref= (uchar*)memdup_root(&ctx->root, ref_, get_ref_len());
tref= (uchar*)memdup_root(&ctx->root, tref_, get_tref_len());
alloc_neighborhood(layer);
}
float FVectorNode::distance_to(const FVector &other) const
{
if (!vec)
const_cast<FVectorNode*>(this)->instantiate_vector();
const_cast<FVectorNode*>(this)->load();
#if __GNUC__ > 7
typedef float v8f __attribute__((vector_size(SIMD_word)));
v8f *p1= (v8f*)vec;
......@@ -140,85 +165,91 @@ float FVectorNode::distance_to(const FVector &other) const
#endif
}
int FVectorNode::instantiate_vector()
int FVectorNode::alloc_neighborhood(uint8_t layer)
{
DBUG_ASSERT(vec == nullptr);
if ((ctx->err= ctx->table->file->ha_rnd_pos(ctx->table->record[0], ref)))
return ctx->err;
String buf, *v= ctx->vec_field->val_str(&buf);
if (unlikely(ctx->byte_len == 0))
{
ctx->byte_len= v->length();
ctx->vec_len= MY_ALIGN(ctx->byte_len/sizeof(float), SIMD_floats);
}
make_vec(v->ptr());
DBUG_ASSERT(!neighbors);
max_layer= layer;
neighbors= new (&ctx->root) List<FVectorNode>[layer+1];
return 0;
}
int FVectorNode::instantiate_neighbors(size_t layer)
int FVectorNode::load()
{
if (!neighbors)
{
neighbors= new (&ctx->root) List<FVectorNode>[layer+1];
neighbors_read= (char*)alloc_root(&ctx->root, layer+1);
bzero(neighbors_read, layer+1);
}
if (!neighbors_read[layer])
{
if (!is_new())
{
DBUG_ASSERT(gref);
if (tref)
return 0;
TABLE *graph= ctx->table->hlindex;
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
const size_t ref_len= get_ref_len();
graph->field[0]->store(layer, false);
graph->field[1]->store_binary(ref, ref_len);
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
if ((ctx->err= graph->file->ha_index_read_map(graph->record[0], key,
HA_WHOLE_KEY, HA_READ_KEY_EXACT)))
if ((ctx->err= graph->file->ha_rnd_pos(graph->record[0], gref)))
return ctx->err;
return load_from_record();
}
String strbuf, *str= graph->field[2]->val_str(&strbuf);
if (str->length() % ref_len)
return ctx->err= HA_ERR_CRASHED; // corrupted HNSW index
int FVectorNode::load_from_record()
{
TABLE *graph= ctx->table->hlindex;
String buf, *v= graph->field[FIELD_TREF]->val_str(&buf);
if (unlikely(!v || v->length() != get_tref_len()))
return ctx->err= HA_ERR_CRASHED;
tref= (uchar*)memdup_root(&ctx->root, v->ptr(), v->length());
v= graph->field[FIELD_VEC]->val_str(&buf);
if (unlikely(!v))
return ctx->err= HA_ERR_CRASHED;
DBUG_ASSERT(ctx->byte_len);
if (v->length() != ctx->byte_len)
return ctx->err= HA_ERR_CRASHED;
make_vec(v->ptr());
for (const char *pos= str->ptr(); pos < str->end(); pos+= ref_len)
neighbors[layer].push_back(ctx->get_node(pos), &ctx->root);
}
neighbors_read[layer]= 1;
}
longlong layer= graph->field[FIELD_LAYER]->val_int();
if (layer > 100) // 10e30 nodes at M=2, more at larger M's
return ctx->err= HA_ERR_CRASHED;
return 0;
}
if (alloc_neighborhood(static_cast<uint8_t>(layer)))
return ctx->err;
List<FVectorNode> &FVectorNode::get_neighbors(size_t layer) const
{
const_cast<FVectorNode*>(this)->instantiate_neighbors(layer);
return neighbors[layer];
v= graph->field[FIELD_NEIGHBORS]->val_str(&buf);
if (unlikely(!v))
return ctx->err= HA_ERR_CRASHED;
// <N> <gref> <gref> ... <N> ...etc...
uchar *ptr= (uchar*)v->ptr(), *end= ptr + v->length();
for (size_t i=0; i <= max_layer; i++)
{
if (unlikely(ptr >= end))
return ctx->err= HA_ERR_CRASHED;
size_t grefs= *ptr++;
if (unlikely(ptr + grefs * get_gref_len() > end))
return ctx->err= HA_ERR_CRASHED;
for (; grefs--; ptr+= get_gref_len())
neighbors[i].push_back(ctx->get_node(ptr), &ctx->root);
}
return 0;
}
size_t FVectorNode::get_ref_len() const
size_t FVectorNode::get_tref_len() const
{
return ctx->table->file->ref_length;
}
bool FVectorNode::is_new() const
size_t FVectorNode::get_gref_len() const
{
return this == ctx->target;
return ctx->table->hlindex->file->ref_length;
}
uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool)
{
*key_len= elem->get_ref_len();
return elem->ref;
*key_len= elem->get_gref_len();
return elem->gref;
}
FVectorNode *MHNSW_Context::get_node(const void *ref)
FVectorNode *MHNSW_Context::get_node(const void *gref)
{
FVectorNode *node= node_cache.find(ref, table->file->ref_length);
FVectorNode *node= node_cache.find(gref, table->hlindex->file->ref_length);
if (!node)
{
node= new (&root) FVectorNode(this, ref);
node= new (&root) FVectorNode(this, gref);
node_cache.insert(node);
}
return node;
......@@ -237,7 +268,7 @@ static int cmp_vec(const FVector *target, const FVectorNode *a, const FVectorNod
}
static int select_neighbors(MHNSW_Context *ctx, size_t layer,
const FVectorNode &target,
FVectorNode &target,
const List<FVectorNode> &candidates_unsafe,
size_t max_neighbor_connections)
{
......@@ -245,14 +276,11 @@ static int select_neighbors(MHNSW_Context *ctx, size_t layer,
Queue<FVectorNode, const FVector> pq; // working queue
Queue<FVectorNode, const FVector> pq_discard; // queue for discarded candidates
/*
make a copy of candidates in case it's target.get_neighbors(layer).
make a copy of candidates in case it's target.neighbors[layer].
because we're going to modify the latter below
*/
List<FVectorNode> candidates= candidates_unsafe;
List<FVectorNode> &neighbors= target.get_neighbors(layer);
if (ctx->err)
return ctx->err;
List<FVectorNode> &neighbors= target.neighbors[layer];
neighbors.empty();
......@@ -273,10 +301,11 @@ static int select_neighbors(MHNSW_Context *ctx, size_t layer,
{
const FVectorNode *vec= pq.pop();
const float target_dist= vec->distance_to(target);
const float target_dista= target_dist / alpha;
bool discard= false;
for (const FVectorNode &neigh : neighbors)
{
if ((discard= vec->distance_to(neigh) * alpha < target_dist))
if ((discard= vec->distance_to(neigh) < target_dista))
break;
}
if (!discard)
......@@ -285,51 +314,49 @@ static int select_neighbors(MHNSW_Context *ctx, size_t layer,
pq_discard.push(vec);
}
while (pq_discard.elements() &&
neighbors.elements < max_neighbor_connections)
while (pq_discard.elements() && neighbors.elements < max_neighbor_connections)
{
neighbors.push_back(pq_discard.pop(), &ctx->root);
const FVectorNode *vec= pq_discard.pop();
neighbors.push_back(vec, &ctx->root);
}
return 0;
}
static int write_neighbors(MHNSW_Context *ctx, size_t layer,
const FVectorNode &source_node)
int FVectorNode::save()
{
TABLE *graph= ctx->table->hlindex;
const List<FVectorNode> &new_neighbors= source_node.get_neighbors(layer);
if (ctx->err)
return ctx->err;
DBUG_ASSERT(tref);
DBUG_ASSERT(vec);
DBUG_ASSERT(neighbors);
size_t total_size= new_neighbors.elements * source_node.get_ref_len();
graph->field[FIELD_LAYER]->store(max_layer, false);
graph->field[FIELD_TREF]->set_notnull();
graph->field[FIELD_TREF]->store_binary(tref, get_tref_len());
graph->field[FIELD_VEC]->store_binary((uchar*)vec, ctx->byte_len);
// Allocate memory for the struct and the flexible array member
char *neighbor_array_bytes= static_cast<char *>(my_safe_alloca(total_size));
size_t total_size= 0;
for (size_t i=0; i <= max_layer; i++)
total_size+= 1 + get_gref_len() * neighbors[i].elements;
char *pos= neighbor_array_bytes;
for (const auto &node: new_neighbors)
uchar *neighbor_blob= static_cast<uchar *>(my_safe_alloca(total_size));
uchar *ptr= neighbor_blob;
for (size_t i= 0; i <= max_layer; i++)
{
*ptr++= (uchar)(neighbors[i].elements);
for (const auto &neigh: neighbors[i])
{
DBUG_ASSERT(node.get_ref_len() == source_node.get_ref_len());
memcpy(pos, node.get_ref(), node.get_ref_len());
pos+= node.get_ref_len();
memcpy(ptr, neigh.get_gref(), get_gref_len());
ptr+= neigh.get_gref_len();
}
}
graph->field[FIELD_NEIGHBORS]->store_binary(neighbor_blob, total_size);
graph->field[0]->store(layer, false);
graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len());
graph->field[2]->store_binary(neighbor_array_bytes, total_size);
if (source_node.is_new())
ctx->err= graph->file->ha_write_row(graph->record[0]);
else
if (gref != gref_max)
{
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
ctx->err= graph->file->ha_index_read_map(graph->record[1], key,
HA_WHOLE_KEY, HA_READ_KEY_EXACT);
ctx->err= graph->file->ha_rnd_pos(graph->record[1], gref);
if (!ctx->err)
{
ctx->err= graph->file->ha_update_row(graph->record[1], graph->record[0]);
......@@ -337,7 +364,14 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer,
ctx->err= 0;
}
}
my_safe_afree(neighbor_array_bytes, total_size);
else
{
ctx->err= graph->file->ha_write_row(graph->record[0]);
graph->file->position(graph->record[0]);
gref= (uchar*)memdup_root(&ctx->root, graph->file->ref, get_gref_len());
}
my_safe_afree(neighbor_blob, total_size);
return ctx->err;
}
......@@ -346,36 +380,23 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer,
uint max_neighbors,
const FVectorNode &node)
{
for (const FVectorNode &neigh: node.get_neighbors(layer))
for (FVectorNode &neigh: node.neighbors[layer])
{
List<FVectorNode> &neighneighbors= neigh.get_neighbors(layer);
if (ctx->err)
return ctx->err;
List<FVectorNode> &neighneighbors= neigh.neighbors[layer];
neighneighbors.push_back(&node, &ctx->root);
if (neighneighbors.elements > max_neighbors)
{
if (select_neighbors(ctx, layer, neigh, neighneighbors, max_neighbors))
return ctx->err;
}
if (write_neighbors(ctx, layer, neigh))
if (neigh.save())
return ctx->err;
}
return ctx->err;
}
static int update_neighbors(MHNSW_Context *ctx, size_t layer,
uint max_neighbors, const FVectorNode &node)
{
// 1. update node's neighbors
if (write_neighbors(ctx, layer, node))
return ctx->err;
// 2. update node's neighbors' neighbors (shrink before update)
return update_second_degree_neighbors(ctx, layer, max_neighbors, node);
return 0;
}
static int search_layer(MHNSW_Context *ctx,
static int search_layer(MHNSW_Context *ctx, const FVector &target,
const List<FVectorNode> &start_nodes,
uint max_candidates_return, size_t layer,
List<FVectorNode> *result)
......@@ -386,7 +407,6 @@ static int search_layer(MHNSW_Context *ctx,
Queue<FVectorNode, const FVector> candidates;
Queue<FVectorNode, const FVector> best;
Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key);
const FVector &target= *ctx->target;
candidates.init(10000, false, cmp_vec, &target);
best.init(max_candidates_return, true, cmp_vec, &target);
......@@ -412,7 +432,7 @@ static int search_layer(MHNSW_Context *ctx,
// Can't get better.
}
for (const FVectorNode &neigh: cur_vec.get_neighbors(layer))
for (const FVectorNode &neigh: cur_vec.neighbors[layer])
{
if (visited.find(&neigh))
continue;
......@@ -436,7 +456,7 @@ static int search_layer(MHNSW_Context *ctx,
while (best.elements())
result->push_front(best.pop(), &ctx->root);
return ctx->err;
return 0;
}
......@@ -446,7 +466,6 @@ static int bad_value_on_insert(Field *f)
f->table->s->db.str, f->table->s->table_name.str, f->field_name.str,
f->table->in_use->get_stmt_da()->current_row_for_warning());
return HA_ERR_GENERIC;
}
......@@ -457,7 +476,6 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set);
Field *vec_field= keyinfo->key_part->field;
String buf, *res= vec_field->val_str(&buf);
handler *h= table->file->lookup_handler;
MHNSW_Context ctx(table, vec_field);
/* metadata are checked on open */
......@@ -467,7 +485,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
DBUG_ASSERT(vec_field->binary());
DBUG_ASSERT(vec_field->cmp_type() == STRING_RESULT);
DBUG_ASSERT(res); // ER_INDEX_CANNOT_HAVE_NULL
DBUG_ASSERT(h->ref_length <= graph->field[1]->field_length);
DBUG_ASSERT(table->file->ref_length <= graph->field[FIELD_TREF]->field_length);
// XXX returning an error here will rollback the insert in InnoDB
// but in MyISAM the row will stay inserted, making the index out of sync:
......@@ -480,86 +498,90 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
table->file->position(table->record[0]);
if (int err= h->ha_rnd_init(0))
return err;
SCOPE_EXIT([h](){ h->ha_rnd_end(); });
if (int err= graph->file->ha_index_init(0, 1))
if (int err= graph->file->ha_index_init(IDX_LAYER, 1))
return err;
SCOPE_EXIT([graph](){ graph->file->ha_index_end(); });
ctx.err= graph->file->ha_index_last(graph->record[0]);
graph->file->ha_index_end();
if ((ctx.err= graph->file->ha_index_last(graph->record[0])))
if (ctx.err)
{
if (ctx.err != HA_ERR_END_OF_FILE)
return ctx.err;
ctx.err= 0;
// First insert!
FVectorNode target(&ctx, table->file->ref);
ctx.target= &target;
return write_neighbors(&ctx, 0, target);
ctx.set_lengths(res->length());
FVectorNode target(&ctx, table->file->ref, 0, res->ptr());
return target.save();
}
longlong max_layer= graph->field[FIELD_LAYER]->val_int();
List<FVectorNode> candidates;
List<FVectorNode> start_nodes;
String ref_str, *ref_ptr;
ref_ptr= graph->field[1]->val_str(&ref_str);
FVectorNode *start_node= ctx.get_node(ref_ptr->ptr());
graph->file->position(graph->record[0]);
FVectorNode *start_node= ctx.get_node(graph->file->ref);
if (start_nodes.push_back(start_node, &ctx.root))
return HA_ERR_OUT_OF_MEM;
if (int err= start_node->instantiate_vector())
ctx.set_lengths(graph->field[FIELD_VEC]->value_length());
if (int err= start_node->load_from_record())
return err;
if (ctx.byte_len != res->length())
return bad_value_on_insert(vec_field);
FVectorNode target(&ctx, table->file->ref, res->ptr());
ctx.target= &target;
if (int err= graph->file->ha_rnd_init(0))
return err;
SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); });
double new_num= my_rnd(&thd->rand);
double log= -std::log(new_num) * NORMALIZATION_FACTOR;
longlong new_node_layer= static_cast<longlong>(std::floor(log));
longlong max_layer= graph->field[0]->val_int();
longlong new_node_layer= std::min<longlong>(std::floor(log), max_layer + 1);
longlong cur_layer;
if (new_node_layer > max_layer)
{
if (write_neighbors(&ctx, max_layer + 1, target))
return ctx.err;
new_node_layer= max_layer;
}
else
{
for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--)
FVectorNode target(&ctx, table->file->ref, new_node_layer, res->ptr());
for (cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--)
{
if (search_layer(&ctx, start_nodes, 1, cur_layer, &candidates))
if (search_layer(&ctx, target, start_nodes, 1, cur_layer, &candidates))
return ctx.err;
start_nodes= candidates;
candidates.empty();
}
}
for (longlong cur_layer= new_node_layer; cur_layer >= 0; cur_layer--)
for (; cur_layer >= 0; cur_layer--)
{
uint max_neighbors= (cur_layer == 0) // heuristics from the paper
? thd->variables.mhnsw_max_edges_per_node * 2
: thd->variables.mhnsw_max_edges_per_node;
if (search_layer(&ctx, start_nodes, ef_construction, cur_layer,
if (search_layer(&ctx, target, start_nodes, ef_construction, cur_layer,
&candidates))
return ctx.err;
if (select_neighbors(&ctx, cur_layer, target, candidates, max_neighbors))
return ctx.err;
if (update_neighbors(&ctx, cur_layer, max_neighbors, target))
return ctx.err;
start_nodes= candidates;
candidates.empty();
}
if (target.save())
return ctx.err;
for (longlong cur_layer= new_node_layer; cur_layer >= 0; cur_layer--)
{
uint max_neighbors= (cur_layer == 0) // heuristics from the paper
? thd->variables.mhnsw_max_edges_per_node * 2
: thd->variables.mhnsw_max_edges_per_node;
// XXX do only one ha_update_row() per node
if (update_second_degree_neighbors(&ctx, cur_layer, max_neighbors, target))
return ctx.err;
}
dbug_tmp_restore_column_map(&table->read_set, old_map);
return 0;
......@@ -581,26 +603,27 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
if (int err= graph->file->ha_index_init(0, 1))
return err;
ctx.err= graph->file->ha_index_last(graph->record[0]);
graph->file->ha_index_end();
SCOPE_EXIT([graph](){ graph->file->ha_index_end(); });
if ((ctx.err= graph->file->ha_index_last(graph->record[0])))
if (ctx.err)
return ctx.err;
longlong max_layer= graph->field[0]->val_int();
longlong max_layer= graph->field[FIELD_LAYER]->val_int();
List<FVectorNode> candidates;
List<FVectorNode> start_nodes;
String ref_str, *ref_ptr= graph->field[1]->val_str(&ref_str);
FVectorNode *start_node= ctx.get_node(ref_ptr->ptr());
graph->file->position(graph->record[0]);
FVectorNode *start_node= ctx.get_node(graph->file->ref);
// one could put all max_layer nodes in start_nodes
// but it has no effect on the recall or speed
if (start_nodes.push_back(start_node, &ctx.root))
return HA_ERR_OUT_OF_MEM;
if (int err= start_node->instantiate_vector())
ctx.set_lengths(graph->field[FIELD_VEC]->value_length());
if (int err= start_node->load_from_record())
return err;
/*
......@@ -609,22 +632,26 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
in any order. For simplicity let's sort by the start_node.
*/
if (!res || ctx.byte_len != res->length())
res= vec_field->val_str(&buf);
(res= &buf)->set((char*)start_node->vec, ctx.byte_len, &my_charset_bin);
if (int err= graph->file->ha_rnd_init(0))
return err;
SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); });
FVector target(&ctx, res->ptr());
ctx.target= &target;
uint ef_search= thd->variables.mhnsw_min_limit;
for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--)
{
if (search_layer(&ctx, start_nodes, 1, cur_layer, &candidates))
if (search_layer(&ctx, target, start_nodes, 1, cur_layer, &candidates))
return ctx.err;
start_nodes= candidates;
candidates.empty();
}
if (search_layer(&ctx, start_nodes, ef_search, 0, &candidates))
if (search_layer(&ctx, target, start_nodes, ef_search, 0, &candidates))
return ctx.err;
size_t context_size=limit * h->ref_length + sizeof(ulonglong);
......@@ -637,7 +664,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
while (limit--)
{
context-= h->ref_length;
memcpy(context, candidates.pop()->get_ref(), h->ref_length);
memcpy(context, candidates.pop()->get_tref(), h->ref_length);
}
DBUG_ASSERT(context - sizeof(ulonglong) == graph->context);
......@@ -658,13 +685,13 @@ int mhnsw_next(TABLE *table)
const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length)
{
const char templ[]="CREATE TABLE i ( "
" layer int not null, "
" src varbinary(%u) not null, "
" neighbors varbinary(%u) not null,"
" primary key (layer, src)) ";
" layer tinyint not null, "
" tref varbinary(%u), "
" vec blob not null, "
" neighbors blob not null, "
" key (layer)) ";
size_t len= sizeof(templ) + 32;
char *s= thd->alloc(len);
len= my_snprintf(s, len, templ, ref_length, 2 * ref_length *
thd->variables.mhnsw_max_edges_per_node);
len= my_snprintf(s, len, templ, ref_length);
return {s, len};
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment