Commit 25b345e6 authored by Sergei Golubchik's avatar Sergei Golubchik

mhnsw: refactor FVector* classes

Now there's an FVector class which is a pure vector, an array of floats.
It doesn't necessarily corresponds to a row in the table, and usually
there is only one FVector instance - the one we're searching for.

And there's an FVectorNode class, which is a node in the graph.
It has a ref (identifying a row in the source table), possibly an array
of floats (or not — in which case it will be read lazily from the
source table as needed). There are many FVectorNodes and they're
cached to avoid re-reading them from the disk.
parent 45b29ebc
......@@ -42,64 +42,31 @@ const LEX_CSTRING mhnsw_hlindex_table={STRING_WITH_LEN("\
")};
class FVectorRef: public Sql_alloc
class MHNSW_Context;
class FVector: public Sql_alloc
{
public:
// Shallow ref copy. Used for other ref lookups in HashSet
FVectorRef(const void *ref, size_t ref_len): ref{(uchar*)ref}, ref_len{ref_len} {}
static uchar *get_key(const FVectorRef *elem, size_t *key_len, my_bool)
{
*key_len= elem->ref_len;
return elem->ref;
}
static void free_vector(void *elem)
{
delete (FVectorRef *)elem;
}
size_t get_ref_len() const { return ref_len; }
uchar* get_ref() const { return ref; }
MHNSW_Context *ctx;
FVector(MHNSW_Context *ctx_, const void *vec_);
float *vec;
protected:
FVectorRef() = default;
uchar *ref;
size_t ref_len;
FVector(MHNSW_Context *ctx_) : ctx(ctx_), vec(nullptr) {}
};
class FVector: public FVectorRef
class FVectorNode: public FVector
{
private:
float *vec;
size_t vec_len;
uchar *ref;
public:
FVector(): vec(nullptr), vec_len(0) {}
bool init(MEM_ROOT *root, const uchar *ref_, size_t ref_len_, const void *vec_, size_t bytes)
{
ref= (uchar*)alloc_root(root, ref_len_ + bytes);
if (!ref)
return true;
vec= reinterpret_cast<float *>(ref + ref_len_);
memcpy(ref, ref_, ref_len_);
memcpy(vec, vec_, bytes);
ref_len= ref_len_;
vec_len= bytes / sizeof(float);
return false;
}
size_t size_of() const { return vec_len * sizeof(float); }
float distance_to(const FVector &other) const
{
DBUG_ASSERT(other.vec_len == vec_len);
return euclidean_vec_distance(vec, other.vec, vec_len);
}
FVectorNode(MHNSW_Context *ctx_, const void *ref_);
FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_);
float distance_to(const FVector &other) const;
int instantiate_vector();
size_t get_ref_len() const;
uchar *get_ref() const { return ref; }
static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool);
};
class MHNSW_Context
......@@ -108,8 +75,9 @@ class MHNSW_Context
MEM_ROOT root;
TABLE *table;
Field *vec_field;
Hash_set<FVectorRef> vector_cache{PSI_INSTRUMENT_MEM, FVectorRef::get_key};
Hash_set<FVectorRef> vector_ref_cache{PSI_INSTRUMENT_MEM, FVectorRef::get_key};
size_t vec_len= 0;
Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};
MHNSW_Context(TABLE *table, Field *vec_field)
: table(table), vec_field(vec_field)
......@@ -122,40 +90,67 @@ class MHNSW_Context
free_root(&root, MYF(0));
}
FVectorRef *get_fvector_ref(const uchar *ref, size_t ref_len)
{
FVectorRef tmp(ref, ref_len);
FVectorRef *v= vector_ref_cache.find(&tmp);
if (v)
return v;
uchar *buf= (uchar*)memdup_root(&root, ref, ref_len);
if ((v= new (&root) FVectorRef(buf, ref_len)))
vector_ref_cache.insert(v);
return v;
}
FVectorNode *get_node(const void *ref_);
};
FVector *get_fvector_from_source(const FVectorRef &ref)
{
FVectorRef *v= vector_cache.find(&ref);
if (v)
return (FVector *)v;
FVector::FVector(MHNSW_Context *ctx_, const void *vec_) : ctx(ctx_)
{
vec= (float*)memdup_root(&ctx->root, vec_, ctx->vec_len * sizeof(float));
}
if (table->file->ha_rnd_pos(table->record[0], ref.get_ref()))
return nullptr; // XXX the error code is lost
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *ref_)
: FVector(ctx_)
{
ref= (uchar*)memdup_root(&ctx->root, ref_, get_ref_len());
}
String buf, *vec= vec_field->val_str(&buf);
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *ref_, const void *vec_)
: FVector(ctx_, vec_)
{
ref= (uchar*)memdup_root(&ctx->root, ref_, get_ref_len());
}
FVector *new_vector= new (&root) FVector;
new_vector->init(&root, ref.get_ref(), ref.get_ref_len(), vec->ptr(), vec->length());
float FVectorNode::distance_to(const FVector &other) const
{
if (!vec)
const_cast<FVectorNode*>(this)->instantiate_vector();
return euclidean_vec_distance(vec, other.vec, ctx->vec_len);
}
vector_cache.insert(new_vector);
int FVectorNode::instantiate_vector()
{
DBUG_ASSERT(vec == nullptr);
if (int err= ctx->table->file->ha_rnd_pos(ctx->table->record[0], ref))
return err;
String buf, *v= ctx->vec_field->val_str(&buf);
ctx->vec_len= v->length() / sizeof(float);
vec= (float*)memdup_root(&ctx->root, v->ptr(), v->length());
return 0;
}
size_t FVectorNode::get_ref_len() const
{
return ctx->table->file->ref_length;
}
return new_vector;
uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool)
{
*key_len= elem->get_ref_len();
return elem->ref;
}
FVectorNode *MHNSW_Context::get_node(const void *ref)
{
FVectorNode *node= node_cache.find(ref, table->file->ref_length);
if (!node)
{
node= new (&root) FVectorNode(this, ref);
node_cache.insert(node);
}
};
return node;
}
static int cmp_vec(const FVector *target, const FVector *a, const FVector *b)
static int cmp_vec(const FVector *target, const FVectorNode *a, const FVectorNode *b)
{
float a_dist= a->distance_to(*target);
float b_dist= b->distance_to(*target);
......@@ -171,8 +166,8 @@ const bool KEEP_PRUNED_CONNECTIONS=true; // XXX why?
const bool EXTEND_CANDIDATES=true; // XXX or false?
static int get_neighbors(MHNSW_Context *ctx, size_t layer_number,
const FVectorRef &source_node,
List<FVectorRef> *neighbors)
const FVectorNode &source_node,
List<FVectorNode> *neighbors)
{
TABLE *graph= ctx->table->hlindex;
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
......@@ -189,18 +184,16 @@ static int get_neighbors(MHNSW_Context *ctx, size_t layer_number,
// mhnsw_insert() guarantees that all ref have the same length
uint ref_length= source_node.get_ref_len();
const uchar *neigh_arr_bytes= reinterpret_cast<const uchar *>(str->ptr());
const char *neigh_arr_bytes= str->ptr();
uint number_of_neighbors= HNSW_MAX_M_read(neigh_arr_bytes);
if (number_of_neighbors * ref_length + HNSW_MAX_M_WIDTH != str->length())
return HA_ERR_CRASHED; // should not happen, corrupted HNSW index
const uchar *pos= neigh_arr_bytes + HNSW_MAX_M_WIDTH;
const char *pos= neigh_arr_bytes + HNSW_MAX_M_WIDTH;
for (uint i= 0; i < number_of_neighbors; i++)
{
FVectorRef *v= ctx->get_fvector_ref(pos, ref_length);
if (!v)
return HA_ERR_OUT_OF_MEM;
neighbors->push_back(v, &ctx->root);
FVectorNode *neigh= ctx->get_node(pos);
neighbors->push_back(neigh, &ctx->root);
pos+= ref_length;
}
......@@ -210,20 +203,20 @@ static int get_neighbors(MHNSW_Context *ctx, size_t layer_number,
static int select_neighbors(MHNSW_Context *ctx,
size_t layer_number, const FVector &target,
const List<FVectorRef> &candidates,
const List<FVectorNode> &candidates,
size_t max_neighbor_connections,
List<FVectorRef> *neighbors)
List<FVectorNode> *neighbors)
{
/*
TODO: If the input neighbors list is already sorted in search_layer, then
no need to do additional queue build steps here.
*/
Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, FVectorRef::get_key);
Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key);
Queue<FVector, const FVector> pq; // working queue
Queue<FVector, const FVector> pq_discard; // queue for discarded candidates
Queue<FVector, const FVector> best; // neighbors to return
Queue<FVectorNode, const FVector> pq; // working queue
Queue<FVectorNode, const FVector> pq_discard; // queue for discarded candidates
Queue<FVectorNode, const FVector> best; // neighbors to return
// TODO(cvicentiu) this 1000 here is a hardcoded value for max queue size.
// This should not be fixed.
......@@ -232,32 +225,26 @@ static int select_neighbors(MHNSW_Context *ctx,
best.init(max_neighbor_connections, true, cmp_vec, &target))
return HA_ERR_OUT_OF_MEM;
for (const FVectorRef &candidate : candidates)
for (const FVectorNode &candidate : candidates)
{
FVector *v= ctx->get_fvector_from_source(candidate);
if (!v)
return HA_ERR_OUT_OF_MEM;
visited.insert(&candidate);
pq.push(v);
pq.push(&candidate);
}
if (EXTEND_CANDIDATES)
{
for (const FVectorRef &candidate : candidates)
for (const FVectorNode &candidate : candidates)
{
List<FVectorRef> candidate_neighbors;
List<FVectorNode> candidate_neighbors;
if (int err= get_neighbors(ctx, layer_number, candidate,
&candidate_neighbors))
return err;
for (const FVectorRef &extra_candidate : candidate_neighbors)
for (const FVectorNode &extra_candidate : candidate_neighbors)
{
if (visited.find(&extra_candidate))
continue;
visited.insert(&extra_candidate);
FVector *v= ctx->get_fvector_from_source(extra_candidate);
if (!v)
return HA_ERR_OUT_OF_MEM;
pq.push(v);
pq.push(&extra_candidate);
}
}
}
......@@ -268,7 +255,7 @@ static int select_neighbors(MHNSW_Context *ctx,
float best_top= best.top()->distance_to(target);
while (pq.elements() && best.elements() < max_neighbor_connections)
{
const FVector *vec= pq.pop();
const FVectorNode *vec= pq.pop();
const float cur_dist= vec->distance_to(target);
if (cur_dist < best_top)
{
......@@ -298,7 +285,7 @@ static int select_neighbors(MHNSW_Context *ctx,
static void dbug_print_vec_ref(const char *prefix, uint layer,
const FVectorRef &ref)
const FVectorNode &ref)
{
#ifndef DBUG_OFF
// TODO(cvicentiu) disable this in release build.
......@@ -313,21 +300,21 @@ static void dbug_print_vec_ref(const char *prefix, uint layer,
#endif
}
static void dbug_print_vec_neigh(uint layer, const List<FVectorRef> &neighbors)
static void dbug_print_vec_neigh(uint layer, const List<FVectorNode> &neighbors)
{
#ifndef DBUG_OFF
DBUG_PRINT("VECTOR", ("NEIGH: NUM: %d", neighbors.elements));
for (const FVectorRef& ref : neighbors)
for (const FVectorNode& ref : neighbors)
{
dbug_print_vec_ref("NEIGH: ", layer, ref);
}
#endif
}
static void dbug_print_hash_vec(Hash_set<FVectorRef> &h)
static void dbug_print_hash_vec(Hash_set<FVectorNode> &h)
{
#ifndef DBUG_OFF
for (FVectorRef &ptr : h)
for (FVectorNode &ptr : h)
{
DBUG_PRINT("VECTOR", ("HASH elem: %p", &ptr));
dbug_print_vec_ref("VISITED: ", 0, ptr);
......@@ -337,8 +324,8 @@ static void dbug_print_hash_vec(Hash_set<FVectorRef> &h)
static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
const FVectorRef &source_node,
const List<FVectorRef> &new_neighbors)
const FVectorNode &source_node,
const List<FVectorNode> &new_neighbors)
{
TABLE *graph= ctx->table->hlindex;
DBUG_ASSERT(new_neighbors.elements <= HNSW_MAX_M);
......@@ -390,14 +377,14 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
static int update_second_degree_neighbors(MHNSW_Context *ctx,
size_t layer_number,
uint max_neighbors,
const FVectorRef &source_node,
const List<FVectorRef> &neighbors)
const FVectorNode &source_node,
const List<FVectorNode> &neighbors)
{
//dbug_print_vec_ref("Updating second degree neighbors", layer_number, source_node);
//dbug_print_vec_neigh(layer_number, neighbors);
for (const FVectorRef &neigh: neighbors) // XXX why this loop?
for (const FVectorNode &neigh: neighbors) // XXX why this loop?
{
List<FVectorRef> new_neighbors;
List<FVectorNode> new_neighbors;
if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors))
return err;
new_neighbors.push_back(&source_node, &ctx->root);
......@@ -405,20 +392,17 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx,
return err;
}
for (const FVectorRef &neigh: neighbors)
for (const FVectorNode &neigh: neighbors)
{
List<FVectorRef> new_neighbors;
List<FVectorNode> new_neighbors;
if (int err= get_neighbors(ctx, layer_number, neigh, &new_neighbors))
return err;
if (new_neighbors.elements > max_neighbors)
{
// shrink the neighbors
List<FVectorRef> selected;
FVector *v= ctx->get_fvector_from_source(neigh);
if (!v)
return HA_ERR_OUT_OF_MEM;
if (int err= select_neighbors(ctx, layer_number, *v,
List<FVectorNode> selected;
if (int err= select_neighbors(ctx, layer_number, neigh,
new_neighbors, max_neighbors, &selected))
return err;
if (int err= write_neighbors(ctx, layer_number, neigh, selected))
......@@ -432,8 +416,8 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx,
static int update_neighbors(MHNSW_Context *ctx,
size_t layer_number, uint max_neighbors,
const FVectorRef &source_node,
const List<FVectorRef> &neighbors)
const FVectorNode &source_node,
const List<FVectorNode> &neighbors)
{
// 1. update node's neighbors
if (int err= write_neighbors(ctx, layer_number, source_node, neighbors))
......@@ -445,36 +429,35 @@ static int update_neighbors(MHNSW_Context *ctx,
static int search_layer(MHNSW_Context *ctx, const FVector &target,
const List<FVectorRef> &start_nodes,
const List<FVectorNode> &start_nodes,
uint max_candidates_return, size_t layer,
List<FVectorRef> *result)
List<FVectorNode> *result)
{
DBUG_ASSERT(start_nodes.elements > 0);
DBUG_ASSERT(result->elements == 0);
Queue<FVector, const FVector> candidates;
Queue<FVector, const FVector> best;
Hash_set<FVectorRef> visited(PSI_INSTRUMENT_MEM, FVectorRef::get_key);
Queue<FVectorNode, const FVector> candidates;
Queue<FVectorNode, const FVector> best;
Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key);
candidates.init(10000, false, cmp_vec, &target);
best.init(max_candidates_return, true, cmp_vec, &target);
for (const FVectorRef &node : start_nodes)
for (const FVectorNode &node : start_nodes)
{
FVector *v= ctx->get_fvector_from_source(node);
candidates.push(v);
candidates.push(&node);
if (best.elements() < max_candidates_return)
best.push(v);
else if (v->distance_to(target) > best.top()->distance_to(target))
best.replace_top(v);
visited.insert(v);
best.push(&node);
else if (node.distance_to(target) > best.top()->distance_to(target))
best.replace_top(&node);
visited.insert(&node);
dbug_print_vec_ref("INSERTING node in visited: ", layer, node);
}
float furthest_best= best.top()->distance_to(target);
while (candidates.elements())
{
const FVector &cur_vec= *candidates.pop();
const FVectorNode &cur_vec= *candidates.pop();
float cur_distance= cur_vec.distance_to(target);
if (cur_distance > furthest_best && best.elements() == max_candidates_return)
{
......@@ -482,27 +465,26 @@ static int search_layer(MHNSW_Context *ctx, const FVector &target,
// Can't get better.
}
List<FVectorRef> neighbors;
List<FVectorNode> neighbors;
get_neighbors(ctx, layer, cur_vec, &neighbors);
for (const FVectorRef &neigh: neighbors)
for (const FVectorNode &neigh: neighbors)
{
dbug_print_hash_vec(visited);
if (visited.find(&neigh))
continue;
FVector *clone= ctx->get_fvector_from_source(neigh);
visited.insert(clone);
visited.insert(&neigh);
if (best.elements() < max_candidates_return)
{
candidates.push(clone);
best.push(clone);
candidates.push(&neigh);
best.push(&neigh);
furthest_best= best.top()->distance_to(target);
}
else if (clone->distance_to(target) < furthest_best)
else if (neigh.distance_to(target) < furthest_best)
{
best.replace_top(clone);
candidates.push(clone);
best.replace_top(&neigh);
candidates.push(&neigh);
furthest_best= best.top()->distance_to(target);
}
}
......@@ -575,34 +557,32 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
// First insert!
h->position(table->record[0]);
return write_neighbors(&ctx, 0, {h->ref, h->ref_length}, {});
return write_neighbors(&ctx, 0, {&ctx, h->ref}, {});
}
longlong max_layer= graph->field[0]->val_int();
h->position(table->record[0]);
List<FVectorRef> candidates;
List<FVectorRef> start_nodes;
List<FVectorNode> candidates;
List<FVectorNode> start_nodes;
String ref_str, *ref_ptr;
ref_ptr= graph->field[1]->val_str(&ref_str);
FVectorRef start_node_ref{ref_ptr->ptr(), ref_ptr->length()};
FVectorNode start_node(&ctx, ref_ptr->ptr());
// TODO(cvicentiu) use a random start node in last layer.
// XXX or may be *all* nodes in the last layer? there should be few
if (start_nodes.push_back(&start_node_ref, &ctx.root))
if (start_nodes.push_back(&start_node, &ctx.root))
return HA_ERR_OUT_OF_MEM;
FVector *v= ctx.get_fvector_from_source(start_node_ref);
if (!v)
return HA_ERR_OUT_OF_MEM;
if (int err= start_node.instantiate_vector())
return err;
if (v->size_of() != res->length())
if (ctx.vec_len * sizeof(float) != res->length())
return bad_value_on_insert(vec_field);
FVector target;
target.init(&ctx.root, h->ref, h->ref_length, res->ptr(), res->length());
FVectorNode target(&ctx, h->ref, res->ptr());
double new_num= my_rnd(&thd->rand);
double log= -std::log(new_num) * NORMALIZATION_FACTOR;
......@@ -622,7 +602,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
for (longlong cur_layer= std::min(max_layer, new_node_layer);
cur_layer >= 0; cur_layer--)
{
List<FVectorRef> neighbors;
List<FVectorNode> neighbors;
if (int err= search_layer(&ctx, target, start_nodes,
thd->variables.hnsw_ef_constructor, cur_layer,
&candidates))
......@@ -679,33 +659,29 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
longlong max_layer= graph->field[0]->val_int();
List<FVectorRef> candidates; // XXX List? not Queue by distance?
List<FVectorRef> start_nodes;
String ref_str, *ref_ptr;
List<FVectorNode> candidates; // XXX List? not Queue by distance?
List<FVectorNode> start_nodes;
String ref_str, *ref_ptr= graph->field[1]->val_str(&ref_str);
ref_ptr= graph->field[1]->val_str(&ref_str);
FVectorRef start_node_ref{ref_ptr->ptr(), ref_ptr->length()};
FVectorNode start_node(&ctx, ref_ptr->ptr());
// TODO(cvicentiu) use a random start node in last layer.
// XXX or may be *all* nodes in the last layer? there should be few
if (start_nodes.push_back(&start_node_ref, &ctx.root))
if (start_nodes.push_back(&start_node, &ctx.root))
return HA_ERR_OUT_OF_MEM;
FVector *v= ctx.get_fvector_from_source(start_node_ref);
if (!v)
return HA_ERR_OUT_OF_MEM;
if (int err= start_node.instantiate_vector())
return err;
/*
if the query vector is NULL or invalid, VEC_DISTANCE will return
NULL, so the result is basically unsorted, we can return rows
in any order. For simplicity let's sort by the start_node.
*/
if (!res || v->size_of() != res->length())
if (!res || ctx.vec_len * sizeof(float) != res->length())
res= vec_field->val_str(&buf);
FVector target;
if (target.init(&ctx.root, h->ref, h->ref_length, res->ptr(), res->length()))
return HA_ERR_OUT_OF_MEM;
FVector target(&ctx, res->ptr());
ulonglong ef_search= std::max<ulonglong>( //XXX why not always limit?
thd->variables.hnsw_ef_search, limit);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment