Commit be021b86 authored by Sergei Golubchik's avatar Sergei Golubchik

mhnsw: change storage format

instead of one row per node per layer, have one row per node.
store all neighbors for all layers in that row, and the vector itself too

it completely avoids searches in the graph table and
will allow implementing deletions in the future
parent 33d79a44
...@@ -8088,8 +8088,7 @@ int handler::prepare_for_insert(bool do_create) ...@@ -8088,8 +8088,7 @@ int handler::prepare_for_insert(bool do_create)
return 1; return 1;
/* Preparation for unique of blob's */ /* Preparation for unique of blob's */
if (table->s->long_unique_table || table->s->period.unique_keys || if (table->s->long_unique_table || table->s->period.unique_keys)
table->hlindex)
{ {
if (do_create && create_lookup_handler()) if (do_create && create_lookup_handler())
return 1; return 1;
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include <my_global.h> #include <my_global.h>
#include "vector_mhnsw.h" #include "vector_mhnsw.h"
#include "item_vectorfunc.h" #include "item_vectorfunc.h"
#include "key.h"
#include <scope.h> #include <scope.h>
// Algorithm parameters // Algorithm parameters
...@@ -32,6 +31,13 @@ static constexpr uint ef_construction= 10; ...@@ -32,6 +31,13 @@ static constexpr uint ef_construction= 10;
// sizeof(double) aligned memory to SIMD_word aligned // sizeof(double) aligned memory to SIMD_word aligned
#define SIMD_margin (SIMD_word - sizeof(double)) #define SIMD_margin (SIMD_word - sizeof(double))
enum Graph_table_fields {
FIELD_LAYER, FIELD_TREF, FIELD_VEC, FIELD_NEIGHBORS
};
enum Graph_table_indices {
IDX_LAYER
};
class MHNSW_Context; class MHNSW_Context;
class FVector: public Sql_alloc class FVector: public Sql_alloc
...@@ -48,23 +54,36 @@ class FVector: public Sql_alloc ...@@ -48,23 +54,36 @@ class FVector: public Sql_alloc
class FVectorNode: public FVector class FVectorNode: public FVector
{ {
private: private:
uchar *ref; uchar *tref, *gref;
List<FVectorNode> *neighbors= nullptr; size_t max_layer;
char *neighbors_read= 0;
static uchar *gref_max;
int alloc_neighborhood(uint8_t layer);
public: public:
FVectorNode(MHNSW_Context *ctx_, const void *ref_); List<FVectorNode> *neighbors= nullptr;
FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_);
FVectorNode(MHNSW_Context *ctx_, const void *gref_);
FVectorNode(MHNSW_Context *ctx_, const void *tref_, uint8_t layer,
const void *vec_);
float distance_to(const FVector &other) const; float distance_to(const FVector &other) const;
int instantiate_vector(); int load();
int instantiate_neighbors(size_t layer); int load_from_record();
size_t get_ref_len() const; int save();
uchar *get_ref() const { return ref; } size_t get_tref_len() const;
List<FVectorNode> &get_neighbors(size_t layer) const; uchar *get_tref() const { return tref; }
bool is_new() const; size_t get_gref_len() const;
uchar *get_gref() const { return gref; }
static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool); static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool);
}; };
// this assumes that 1) rows from graph table are never deleted,
// 2) and thus a ref for a new row is larger than refs of existing rows,
// thus we can treat the not-yet-inserted row as having max possible ref.
// oh, yes, and 3) 8 bytes ought to be enough for everyone
uchar *FVectorNode::gref_max=(uchar*)"\xff\xff\xff\xff\xff\xff\xff\xff";
class MHNSW_Context class MHNSW_Context
{ {
public: public:
...@@ -73,7 +92,6 @@ class MHNSW_Context ...@@ -73,7 +92,6 @@ class MHNSW_Context
Field *vec_field; Field *vec_field;
size_t vec_len= 0; size_t vec_len= 0;
size_t byte_len= 0; size_t byte_len= 0;
FVector *target= 0;
uint err= 0; uint err= 0;
Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key}; Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};
...@@ -89,7 +107,12 @@ class MHNSW_Context ...@@ -89,7 +107,12 @@ class MHNSW_Context
free_root(&root, MYF(0)); free_root(&root, MYF(0));
} }
FVectorNode *get_node(const void *ref_); FVectorNode *get_node(const void *gref);
void set_lengths(size_t len)
{
byte_len= len;
vec_len= MY_ALIGN(byte_len/sizeof(float), SIMD_floats);
}
}; };
FVector::FVector(MHNSW_Context *ctx_, const void *vec_) : ctx(ctx_) FVector::FVector(MHNSW_Context *ctx_, const void *vec_) : ctx(ctx_)
...@@ -99,6 +122,7 @@ FVector::FVector(MHNSW_Context *ctx_, const void *vec_) : ctx(ctx_) ...@@ -99,6 +122,7 @@ FVector::FVector(MHNSW_Context *ctx_, const void *vec_) : ctx(ctx_)
void FVector::make_vec(const void *vec_) void FVector::make_vec(const void *vec_)
{ {
DBUG_ASSERT(ctx->vec_len);
vec= (float*)alloc_root(&ctx->root, vec= (float*)alloc_root(&ctx->root,
ctx->vec_len * sizeof(float) + SIMD_margin); ctx->vec_len * sizeof(float) + SIMD_margin);
if (int off= ((intptr)vec) % SIMD_word) if (int off= ((intptr)vec) % SIMD_word)
...@@ -108,22 +132,23 @@ void FVector::make_vec(const void *vec_) ...@@ -108,22 +132,23 @@ void FVector::make_vec(const void *vec_)
vec[i]=0; vec[i]=0;
} }
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *ref_) FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *gref_)
: FVector(ctx_) : FVector(ctx_), tref(nullptr)
{ {
ref= (uchar*)memdup_root(&ctx->root, ref_, get_ref_len()); gref= (uchar*)memdup_root(&ctx->root, gref_, get_gref_len());
} }
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_) FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *tref_, uint8_t layer,
: FVector(ctx_, vec_) const void *vec_)
: FVector(ctx_, vec_), gref(gref_max)
{ {
ref= (uchar*)memdup_root(&ctx->root, ref_, get_ref_len()); tref= (uchar*)memdup_root(&ctx->root, tref_, get_tref_len());
alloc_neighborhood(layer);
} }
float FVectorNode::distance_to(const FVector &other) const float FVectorNode::distance_to(const FVector &other) const
{ {
if (!vec) const_cast<FVectorNode*>(this)->load();
const_cast<FVectorNode*>(this)->instantiate_vector();
#if __GNUC__ > 7 #if __GNUC__ > 7
typedef float v8f __attribute__((vector_size(SIMD_word))); typedef float v8f __attribute__((vector_size(SIMD_word)));
v8f *p1= (v8f*)vec; v8f *p1= (v8f*)vec;
...@@ -140,85 +165,91 @@ float FVectorNode::distance_to(const FVector &other) const ...@@ -140,85 +165,91 @@ float FVectorNode::distance_to(const FVector &other) const
#endif #endif
} }
int FVectorNode::instantiate_vector() int FVectorNode::alloc_neighborhood(uint8_t layer)
{ {
DBUG_ASSERT(vec == nullptr); DBUG_ASSERT(!neighbors);
if ((ctx->err= ctx->table->file->ha_rnd_pos(ctx->table->record[0], ref))) max_layer= layer;
return ctx->err; neighbors= new (&ctx->root) List<FVectorNode>[layer+1];
String buf, *v= ctx->vec_field->val_str(&buf);
if (unlikely(ctx->byte_len == 0))
{
ctx->byte_len= v->length();
ctx->vec_len= MY_ALIGN(ctx->byte_len/sizeof(float), SIMD_floats);
}
make_vec(v->ptr());
return 0; return 0;
} }
int FVectorNode::instantiate_neighbors(size_t layer) int FVectorNode::load()
{ {
if (!neighbors) DBUG_ASSERT(gref);
{ if (tref)
neighbors= new (&ctx->root) List<FVectorNode>[layer+1]; return 0;
neighbors_read= (char*)alloc_root(&ctx->root, layer+1);
bzero(neighbors_read, layer+1);
}
if (!neighbors_read[layer])
{
if (!is_new())
{
TABLE *graph= ctx->table->hlindex; TABLE *graph= ctx->table->hlindex;
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length)); if ((ctx->err= graph->file->ha_rnd_pos(graph->record[0], gref)))
const size_t ref_len= get_ref_len();
graph->field[0]->store(layer, false);
graph->field[1]->store_binary(ref, ref_len);
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
if ((ctx->err= graph->file->ha_index_read_map(graph->record[0], key,
HA_WHOLE_KEY, HA_READ_KEY_EXACT)))
return ctx->err; return ctx->err;
return load_from_record();
}
String strbuf, *str= graph->field[2]->val_str(&strbuf); int FVectorNode::load_from_record()
if (str->length() % ref_len) {
return ctx->err= HA_ERR_CRASHED; // corrupted HNSW index TABLE *graph= ctx->table->hlindex;
String buf, *v= graph->field[FIELD_TREF]->val_str(&buf);
if (unlikely(!v || v->length() != get_tref_len()))
return ctx->err= HA_ERR_CRASHED;
tref= (uchar*)memdup_root(&ctx->root, v->ptr(), v->length());
v= graph->field[FIELD_VEC]->val_str(&buf);
if (unlikely(!v))
return ctx->err= HA_ERR_CRASHED;
DBUG_ASSERT(ctx->byte_len);
if (v->length() != ctx->byte_len)
return ctx->err= HA_ERR_CRASHED;
make_vec(v->ptr());
for (const char *pos= str->ptr(); pos < str->end(); pos+= ref_len) longlong layer= graph->field[FIELD_LAYER]->val_int();
neighbors[layer].push_back(ctx->get_node(pos), &ctx->root); if (layer > 100) // 10e30 nodes at M=2, more at larger M's
} return ctx->err= HA_ERR_CRASHED;
neighbors_read[layer]= 1;
}
return 0; if (alloc_neighborhood(static_cast<uint8_t>(layer)))
} return ctx->err;
List<FVectorNode> &FVectorNode::get_neighbors(size_t layer) const v= graph->field[FIELD_NEIGHBORS]->val_str(&buf);
{ if (unlikely(!v))
const_cast<FVectorNode*>(this)->instantiate_neighbors(layer); return ctx->err= HA_ERR_CRASHED;
return neighbors[layer];
// <N> <gref> <gref> ... <N> ...etc...
uchar *ptr= (uchar*)v->ptr(), *end= ptr + v->length();
for (size_t i=0; i <= max_layer; i++)
{
if (unlikely(ptr >= end))
return ctx->err= HA_ERR_CRASHED;
size_t grefs= *ptr++;
if (unlikely(ptr + grefs * get_gref_len() > end))
return ctx->err= HA_ERR_CRASHED;
for (; grefs--; ptr+= get_gref_len())
neighbors[i].push_back(ctx->get_node(ptr), &ctx->root);
}
return 0;
} }
size_t FVectorNode::get_ref_len() const size_t FVectorNode::get_tref_len() const
{ {
return ctx->table->file->ref_length; return ctx->table->file->ref_length;
} }
bool FVectorNode::is_new() const size_t FVectorNode::get_gref_len() const
{ {
return this == ctx->target; return ctx->table->hlindex->file->ref_length;
} }
uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool) uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool)
{ {
*key_len= elem->get_ref_len(); *key_len= elem->get_gref_len();
return elem->ref; return elem->gref;
} }
FVectorNode *MHNSW_Context::get_node(const void *ref) FVectorNode *MHNSW_Context::get_node(const void *gref)
{ {
FVectorNode *node= node_cache.find(ref, table->file->ref_length); FVectorNode *node= node_cache.find(gref, table->hlindex->file->ref_length);
if (!node) if (!node)
{ {
node= new (&root) FVectorNode(this, ref); node= new (&root) FVectorNode(this, gref);
node_cache.insert(node); node_cache.insert(node);
} }
return node; return node;
...@@ -237,7 +268,7 @@ static int cmp_vec(const FVector *target, const FVectorNode *a, const FVectorNod ...@@ -237,7 +268,7 @@ static int cmp_vec(const FVector *target, const FVectorNode *a, const FVectorNod
} }
static int select_neighbors(MHNSW_Context *ctx, size_t layer, static int select_neighbors(MHNSW_Context *ctx, size_t layer,
const FVectorNode &target, FVectorNode &target,
const List<FVectorNode> &candidates_unsafe, const List<FVectorNode> &candidates_unsafe,
size_t max_neighbor_connections) size_t max_neighbor_connections)
{ {
...@@ -245,14 +276,11 @@ static int select_neighbors(MHNSW_Context *ctx, size_t layer, ...@@ -245,14 +276,11 @@ static int select_neighbors(MHNSW_Context *ctx, size_t layer,
Queue<FVectorNode, const FVector> pq; // working queue Queue<FVectorNode, const FVector> pq; // working queue
Queue<FVectorNode, const FVector> pq_discard; // queue for discarded candidates Queue<FVectorNode, const FVector> pq_discard; // queue for discarded candidates
/* /*
make a copy of candidates in case it's target.get_neighbors(layer). make a copy of candidates in case it's target.neighbors[layer].
because we're going to modify the latter below because we're going to modify the latter below
*/ */
List<FVectorNode> candidates= candidates_unsafe; List<FVectorNode> candidates= candidates_unsafe;
List<FVectorNode> &neighbors= target.get_neighbors(layer); List<FVectorNode> &neighbors= target.neighbors[layer];
if (ctx->err)
return ctx->err;
neighbors.empty(); neighbors.empty();
...@@ -273,10 +301,11 @@ static int select_neighbors(MHNSW_Context *ctx, size_t layer, ...@@ -273,10 +301,11 @@ static int select_neighbors(MHNSW_Context *ctx, size_t layer,
{ {
const FVectorNode *vec= pq.pop(); const FVectorNode *vec= pq.pop();
const float target_dist= vec->distance_to(target); const float target_dist= vec->distance_to(target);
const float target_dista= target_dist / alpha;
bool discard= false; bool discard= false;
for (const FVectorNode &neigh : neighbors) for (const FVectorNode &neigh : neighbors)
{ {
if ((discard= vec->distance_to(neigh) * alpha < target_dist)) if ((discard= vec->distance_to(neigh) < target_dista))
break; break;
} }
if (!discard) if (!discard)
...@@ -285,51 +314,49 @@ static int select_neighbors(MHNSW_Context *ctx, size_t layer, ...@@ -285,51 +314,49 @@ static int select_neighbors(MHNSW_Context *ctx, size_t layer,
pq_discard.push(vec); pq_discard.push(vec);
} }
while (pq_discard.elements() && while (pq_discard.elements() && neighbors.elements < max_neighbor_connections)
neighbors.elements < max_neighbor_connections)
{ {
neighbors.push_back(pq_discard.pop(), &ctx->root); const FVectorNode *vec= pq_discard.pop();
neighbors.push_back(vec, &ctx->root);
} }
return 0; return 0;
} }
static int write_neighbors(MHNSW_Context *ctx, size_t layer, int FVectorNode::save()
const FVectorNode &source_node)
{ {
TABLE *graph= ctx->table->hlindex; TABLE *graph= ctx->table->hlindex;
const List<FVectorNode> &new_neighbors= source_node.get_neighbors(layer);
if (ctx->err) DBUG_ASSERT(tref);
return ctx->err; DBUG_ASSERT(vec);
DBUG_ASSERT(neighbors);
size_t total_size= new_neighbors.elements * source_node.get_ref_len(); graph->field[FIELD_LAYER]->store(max_layer, false);
graph->field[FIELD_TREF]->set_notnull();
graph->field[FIELD_TREF]->store_binary(tref, get_tref_len());
graph->field[FIELD_VEC]->store_binary((uchar*)vec, ctx->byte_len);
// Allocate memory for the struct and the flexible array member size_t total_size= 0;
char *neighbor_array_bytes= static_cast<char *>(my_safe_alloca(total_size)); for (size_t i=0; i <= max_layer; i++)
total_size+= 1 + get_gref_len() * neighbors[i].elements;
char *pos= neighbor_array_bytes; uchar *neighbor_blob= static_cast<uchar *>(my_safe_alloca(total_size));
for (const auto &node: new_neighbors) uchar *ptr= neighbor_blob;
for (size_t i= 0; i <= max_layer; i++)
{
*ptr++= (uchar)(neighbors[i].elements);
for (const auto &neigh: neighbors[i])
{ {
DBUG_ASSERT(node.get_ref_len() == source_node.get_ref_len()); memcpy(ptr, neigh.get_gref(), get_gref_len());
memcpy(pos, node.get_ref(), node.get_ref_len()); ptr+= neigh.get_gref_len();
pos+= node.get_ref_len();
} }
}
graph->field[FIELD_NEIGHBORS]->store_binary(neighbor_blob, total_size);
graph->field[0]->store(layer, false); if (gref != gref_max)
graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len());
graph->field[2]->store_binary(neighbor_array_bytes, total_size);
if (source_node.is_new())
ctx->err= graph->file->ha_write_row(graph->record[0]);
else
{ {
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length)); ctx->err= graph->file->ha_rnd_pos(graph->record[1], gref);
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
ctx->err= graph->file->ha_index_read_map(graph->record[1], key,
HA_WHOLE_KEY, HA_READ_KEY_EXACT);
if (!ctx->err) if (!ctx->err)
{ {
ctx->err= graph->file->ha_update_row(graph->record[1], graph->record[0]); ctx->err= graph->file->ha_update_row(graph->record[1], graph->record[0]);
...@@ -337,7 +364,14 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer, ...@@ -337,7 +364,14 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer,
ctx->err= 0; ctx->err= 0;
} }
} }
my_safe_afree(neighbor_array_bytes, total_size); else
{
ctx->err= graph->file->ha_write_row(graph->record[0]);
graph->file->position(graph->record[0]);
gref= (uchar*)memdup_root(&ctx->root, graph->file->ref, get_gref_len());
}
my_safe_afree(neighbor_blob, total_size);
return ctx->err; return ctx->err;
} }
...@@ -346,36 +380,23 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer, ...@@ -346,36 +380,23 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, size_t layer,
uint max_neighbors, uint max_neighbors,
const FVectorNode &node) const FVectorNode &node)
{ {
for (const FVectorNode &neigh: node.get_neighbors(layer)) for (FVectorNode &neigh: node.neighbors[layer])
{ {
List<FVectorNode> &neighneighbors= neigh.get_neighbors(layer); List<FVectorNode> &neighneighbors= neigh.neighbors[layer];
if (ctx->err)
return ctx->err;
neighneighbors.push_back(&node, &ctx->root); neighneighbors.push_back(&node, &ctx->root);
if (neighneighbors.elements > max_neighbors) if (neighneighbors.elements > max_neighbors)
{ {
if (select_neighbors(ctx, layer, neigh, neighneighbors, max_neighbors)) if (select_neighbors(ctx, layer, neigh, neighneighbors, max_neighbors))
return ctx->err; return ctx->err;
} }
if (write_neighbors(ctx, layer, neigh)) if (neigh.save())
return ctx->err; return ctx->err;
} }
return ctx->err; return 0;
}
static int update_neighbors(MHNSW_Context *ctx, size_t layer,
uint max_neighbors, const FVectorNode &node)
{
// 1. update node's neighbors
if (write_neighbors(ctx, layer, node))
return ctx->err;
// 2. update node's neighbors' neighbors (shrink before update)
return update_second_degree_neighbors(ctx, layer, max_neighbors, node);
} }
static int search_layer(MHNSW_Context *ctx, static int search_layer(MHNSW_Context *ctx, const FVector &target,
const List<FVectorNode> &start_nodes, const List<FVectorNode> &start_nodes,
uint max_candidates_return, size_t layer, uint max_candidates_return, size_t layer,
List<FVectorNode> *result) List<FVectorNode> *result)
...@@ -386,7 +407,6 @@ static int search_layer(MHNSW_Context *ctx, ...@@ -386,7 +407,6 @@ static int search_layer(MHNSW_Context *ctx,
Queue<FVectorNode, const FVector> candidates; Queue<FVectorNode, const FVector> candidates;
Queue<FVectorNode, const FVector> best; Queue<FVectorNode, const FVector> best;
Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key); Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key);
const FVector &target= *ctx->target;
candidates.init(10000, false, cmp_vec, &target); candidates.init(10000, false, cmp_vec, &target);
best.init(max_candidates_return, true, cmp_vec, &target); best.init(max_candidates_return, true, cmp_vec, &target);
...@@ -412,7 +432,7 @@ static int search_layer(MHNSW_Context *ctx, ...@@ -412,7 +432,7 @@ static int search_layer(MHNSW_Context *ctx,
// Can't get better. // Can't get better.
} }
for (const FVectorNode &neigh: cur_vec.get_neighbors(layer)) for (const FVectorNode &neigh: cur_vec.neighbors[layer])
{ {
if (visited.find(&neigh)) if (visited.find(&neigh))
continue; continue;
...@@ -436,7 +456,7 @@ static int search_layer(MHNSW_Context *ctx, ...@@ -436,7 +456,7 @@ static int search_layer(MHNSW_Context *ctx,
while (best.elements()) while (best.elements())
result->push_front(best.pop(), &ctx->root); result->push_front(best.pop(), &ctx->root);
return ctx->err; return 0;
} }
...@@ -446,7 +466,6 @@ static int bad_value_on_insert(Field *f) ...@@ -446,7 +466,6 @@ static int bad_value_on_insert(Field *f)
f->table->s->db.str, f->table->s->table_name.str, f->field_name.str, f->table->s->db.str, f->table->s->table_name.str, f->field_name.str,
f->table->in_use->get_stmt_da()->current_row_for_warning()); f->table->in_use->get_stmt_da()->current_row_for_warning());
return HA_ERR_GENERIC; return HA_ERR_GENERIC;
} }
...@@ -457,7 +476,6 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -457,7 +476,6 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set); MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set);
Field *vec_field= keyinfo->key_part->field; Field *vec_field= keyinfo->key_part->field;
String buf, *res= vec_field->val_str(&buf); String buf, *res= vec_field->val_str(&buf);
handler *h= table->file->lookup_handler;
MHNSW_Context ctx(table, vec_field); MHNSW_Context ctx(table, vec_field);
/* metadata are checked on open */ /* metadata are checked on open */
...@@ -467,7 +485,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -467,7 +485,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
DBUG_ASSERT(vec_field->binary()); DBUG_ASSERT(vec_field->binary());
DBUG_ASSERT(vec_field->cmp_type() == STRING_RESULT); DBUG_ASSERT(vec_field->cmp_type() == STRING_RESULT);
DBUG_ASSERT(res); // ER_INDEX_CANNOT_HAVE_NULL DBUG_ASSERT(res); // ER_INDEX_CANNOT_HAVE_NULL
DBUG_ASSERT(h->ref_length <= graph->field[1]->field_length); DBUG_ASSERT(table->file->ref_length <= graph->field[FIELD_TREF]->field_length);
// XXX returning an error here will rollback the insert in InnoDB // XXX returning an error here will rollback the insert in InnoDB
// but in MyISAM the row will stay inserted, making the index out of sync: // but in MyISAM the row will stay inserted, making the index out of sync:
...@@ -480,86 +498,90 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -480,86 +498,90 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
table->file->position(table->record[0]); table->file->position(table->record[0]);
if (int err= h->ha_rnd_init(0)) if (int err= graph->file->ha_index_init(IDX_LAYER, 1))
return err;
SCOPE_EXIT([h](){ h->ha_rnd_end(); });
if (int err= graph->file->ha_index_init(0, 1))
return err; return err;
SCOPE_EXIT([graph](){ graph->file->ha_index_end(); }); ctx.err= graph->file->ha_index_last(graph->record[0]);
graph->file->ha_index_end();
if ((ctx.err= graph->file->ha_index_last(graph->record[0]))) if (ctx.err)
{ {
if (ctx.err != HA_ERR_END_OF_FILE) if (ctx.err != HA_ERR_END_OF_FILE)
return ctx.err; return ctx.err;
ctx.err= 0; ctx.err= 0;
// First insert! // First insert!
FVectorNode target(&ctx, table->file->ref); ctx.set_lengths(res->length());
ctx.target= &target; FVectorNode target(&ctx, table->file->ref, 0, res->ptr());
return write_neighbors(&ctx, 0, target); return target.save();
} }
longlong max_layer= graph->field[FIELD_LAYER]->val_int();
List<FVectorNode> candidates; List<FVectorNode> candidates;
List<FVectorNode> start_nodes; List<FVectorNode> start_nodes;
String ref_str, *ref_ptr;
ref_ptr= graph->field[1]->val_str(&ref_str); graph->file->position(graph->record[0]);
FVectorNode *start_node= ctx.get_node(ref_ptr->ptr()); FVectorNode *start_node= ctx.get_node(graph->file->ref);
if (start_nodes.push_back(start_node, &ctx.root)) if (start_nodes.push_back(start_node, &ctx.root))
return HA_ERR_OUT_OF_MEM; return HA_ERR_OUT_OF_MEM;
if (int err= start_node->instantiate_vector()) ctx.set_lengths(graph->field[FIELD_VEC]->value_length());
if (int err= start_node->load_from_record())
return err; return err;
if (ctx.byte_len != res->length()) if (ctx.byte_len != res->length())
return bad_value_on_insert(vec_field); return bad_value_on_insert(vec_field);
FVectorNode target(&ctx, table->file->ref, res->ptr()); if (int err= graph->file->ha_rnd_init(0))
ctx.target= &target; return err;
SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); });
double new_num= my_rnd(&thd->rand); double new_num= my_rnd(&thd->rand);
double log= -std::log(new_num) * NORMALIZATION_FACTOR; double log= -std::log(new_num) * NORMALIZATION_FACTOR;
longlong new_node_layer= static_cast<longlong>(std::floor(log)); longlong new_node_layer= std::min<longlong>(std::floor(log), max_layer + 1);
longlong max_layer= graph->field[0]->val_int(); longlong cur_layer;
if (new_node_layer > max_layer) FVectorNode target(&ctx, table->file->ref, new_node_layer, res->ptr());
{
if (write_neighbors(&ctx, max_layer + 1, target)) for (cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--)
return ctx.err;
new_node_layer= max_layer;
}
else
{
for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--)
{ {
if (search_layer(&ctx, start_nodes, 1, cur_layer, &candidates)) if (search_layer(&ctx, target, start_nodes, 1, cur_layer, &candidates))
return ctx.err; return ctx.err;
start_nodes= candidates; start_nodes= candidates;
candidates.empty(); candidates.empty();
} }
}
for (longlong cur_layer= new_node_layer; cur_layer >= 0; cur_layer--) for (; cur_layer >= 0; cur_layer--)
{ {
uint max_neighbors= (cur_layer == 0) // heuristics from the paper uint max_neighbors= (cur_layer == 0) // heuristics from the paper
? thd->variables.mhnsw_max_edges_per_node * 2 ? thd->variables.mhnsw_max_edges_per_node * 2
: thd->variables.mhnsw_max_edges_per_node; : thd->variables.mhnsw_max_edges_per_node;
if (search_layer(&ctx, start_nodes, ef_construction, cur_layer, if (search_layer(&ctx, target, start_nodes, ef_construction, cur_layer,
&candidates)) &candidates))
return ctx.err; return ctx.err;
if (select_neighbors(&ctx, cur_layer, target, candidates, max_neighbors)) if (select_neighbors(&ctx, cur_layer, target, candidates, max_neighbors))
return ctx.err; return ctx.err;
if (update_neighbors(&ctx, cur_layer, max_neighbors, target))
return ctx.err;
start_nodes= candidates; start_nodes= candidates;
candidates.empty(); candidates.empty();
} }
if (target.save())
return ctx.err;
for (longlong cur_layer= new_node_layer; cur_layer >= 0; cur_layer--)
{
uint max_neighbors= (cur_layer == 0) // heuristics from the paper
? thd->variables.mhnsw_max_edges_per_node * 2
: thd->variables.mhnsw_max_edges_per_node;
// XXX do only one ha_update_row() per node
if (update_second_degree_neighbors(&ctx, cur_layer, max_neighbors, target))
return ctx.err;
}
dbug_tmp_restore_column_map(&table->read_set, old_map); dbug_tmp_restore_column_map(&table->read_set, old_map);
return 0; return 0;
...@@ -581,26 +603,27 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) ...@@ -581,26 +603,27 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
if (int err= graph->file->ha_index_init(0, 1)) if (int err= graph->file->ha_index_init(0, 1))
return err; return err;
ctx.err= graph->file->ha_index_last(graph->record[0]);
graph->file->ha_index_end();
SCOPE_EXIT([graph](){ graph->file->ha_index_end(); }); if (ctx.err)
if ((ctx.err= graph->file->ha_index_last(graph->record[0])))
return ctx.err; return ctx.err;
longlong max_layer= graph->field[0]->val_int(); longlong max_layer= graph->field[FIELD_LAYER]->val_int();
List<FVectorNode> candidates; List<FVectorNode> candidates;
List<FVectorNode> start_nodes; List<FVectorNode> start_nodes;
String ref_str, *ref_ptr= graph->field[1]->val_str(&ref_str);
FVectorNode *start_node= ctx.get_node(ref_ptr->ptr()); graph->file->position(graph->record[0]);
FVectorNode *start_node= ctx.get_node(graph->file->ref);
// one could put all max_layer nodes in start_nodes // one could put all max_layer nodes in start_nodes
// but it has no effect of the recall or speed // but it has no effect of the recall or speed
if (start_nodes.push_back(start_node, &ctx.root)) if (start_nodes.push_back(start_node, &ctx.root))
return HA_ERR_OUT_OF_MEM; return HA_ERR_OUT_OF_MEM;
if (int err= start_node->instantiate_vector()) ctx.set_lengths(graph->field[FIELD_VEC]->value_length());
if (int err= start_node->load_from_record())
return err; return err;
/* /*
...@@ -609,22 +632,26 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) ...@@ -609,22 +632,26 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
in any order. For simplicity let's sort by the start_node. in any order. For simplicity let's sort by the start_node.
*/ */
if (!res || ctx.byte_len != res->length()) if (!res || ctx.byte_len != res->length())
res= vec_field->val_str(&buf); (res= &buf)->set((char*)start_node->vec, ctx.byte_len, &my_charset_bin);
if (int err= graph->file->ha_rnd_init(0))
return err;
SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); });
FVector target(&ctx, res->ptr()); FVector target(&ctx, res->ptr());
ctx.target= &target;
uint ef_search= thd->variables.mhnsw_min_limit; uint ef_search= thd->variables.mhnsw_min_limit;
for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--) for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--)
{ {
if (search_layer(&ctx, start_nodes, 1, cur_layer, &candidates)) if (search_layer(&ctx, target, start_nodes, 1, cur_layer, &candidates))
return ctx.err; return ctx.err;
start_nodes= candidates; start_nodes= candidates;
candidates.empty(); candidates.empty();
} }
if (search_layer(&ctx, start_nodes, ef_search, 0, &candidates)) if (search_layer(&ctx, target, start_nodes, ef_search, 0, &candidates))
return ctx.err; return ctx.err;
size_t context_size=limit * h->ref_length + sizeof(ulonglong); size_t context_size=limit * h->ref_length + sizeof(ulonglong);
...@@ -637,7 +664,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) ...@@ -637,7 +664,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
while (limit--) while (limit--)
{ {
context-= h->ref_length; context-= h->ref_length;
memcpy(context, candidates.pop()->get_ref(), h->ref_length); memcpy(context, candidates.pop()->get_tref(), h->ref_length);
} }
DBUG_ASSERT(context - sizeof(ulonglong) == graph->context); DBUG_ASSERT(context - sizeof(ulonglong) == graph->context);
...@@ -658,13 +685,13 @@ int mhnsw_next(TABLE *table) ...@@ -658,13 +685,13 @@ int mhnsw_next(TABLE *table)
const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length) const LEX_CSTRING mhnsw_hlindex_table_def(THD *thd, uint ref_length)
{ {
const char templ[]="CREATE TABLE i ( " const char templ[]="CREATE TABLE i ( "
" layer int not null, " " layer tinyint not null, "
" src varbinary(%u) not null, " " tref varbinary(%u), "
" neighbors varbinary(%u) not null," " vec blob not null, "
" primary key (layer, src)) "; " neighbors blob not null, "
" key (layer)) ";
size_t len= sizeof(templ) + 32; size_t len= sizeof(templ) + 32;
char *s= thd->alloc(len); char *s= thd->alloc(len);
len= my_snprintf(s, len, templ, ref_length, 2 * ref_length * len= my_snprintf(s, len, templ, ref_length);
thd->variables.mhnsw_max_edges_per_node);
return {s, len}; return {s, len};
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment