Commit d93f6633 authored by Sergei Golubchik

mhnsw: auto-tune efConstruction

* remove hard-coded ef_construction_multiplier
* instead, let ef_construction go up and down automatically as needed
* "as needed" means that expanding the queue changes the result much
* "much" is defined by the queue stiffness, as in Hooke's law
* also search_layer() now returns only as many elements as needed, the
  caller no longer needs to overallocate result arrays for throwaway nodes
* change _downheap() to return the position where the element ended up
parent a6c88428
...@@ -77,7 +77,7 @@ void queue_replace(QUEUE *queue,uint idx); ...@@ -77,7 +77,7 @@ void queue_replace(QUEUE *queue,uint idx);
#define queue_remove_all(queue) { (queue)->elements= 0; } #define queue_remove_all(queue) { (queue)->elements= 0; }
#define queue_is_full(queue) ((queue)->elements == (queue)->max_elements) #define queue_is_full(queue) ((queue)->elements == (queue)->max_elements)
void _downheap(QUEUE *queue, uint idx); uint _downheap(QUEUE *queue, uint idx);
void queue_fix(QUEUE *queue); void queue_fix(QUEUE *queue);
#define is_queue_inited(queue) ((queue)->root != 0) #define is_queue_inited(queue) ((queue)->root != 0)
......
...@@ -284,7 +284,7 @@ uchar *queue_remove(QUEUE *queue, uint idx) ...@@ -284,7 +284,7 @@ uchar *queue_remove(QUEUE *queue, uint idx)
idx Index of element to change idx Index of element to change
*/ */
void _downheap(QUEUE *queue, uint idx) uint _downheap(QUEUE *queue, uint idx)
{ {
uchar *element= queue->root[idx]; uchar *element= queue->root[idx];
uint next_index, uint next_index,
...@@ -314,6 +314,7 @@ void _downheap(QUEUE *queue, uint idx) ...@@ -314,6 +314,7 @@ void _downheap(QUEUE *queue, uint idx)
queue->root[idx]=element; queue->root[idx]=element;
if (offset_to_queue_pos) if (offset_to_queue_pos)
(*(uint*) (element + offset_to_queue_pos-1))= idx; (*(uint*) (element + offset_to_queue_pos-1))= idx;
return idx;
} }
......
...@@ -46,11 +46,11 @@ class Queue ...@@ -46,11 +46,11 @@ class Queue
void push(const Element *element) { queue_insert(&m_queue, (uchar*)element); } void push(const Element *element) { queue_insert(&m_queue, (uchar*)element); }
Element *pop() { return (Element *)queue_remove_top(&m_queue); } Element *pop() { return (Element *)queue_remove_top(&m_queue); }
void clear() { queue_remove_all(&m_queue); } void clear() { queue_remove_all(&m_queue); }
void propagate_top() { queue_replace_top(&m_queue); } uint propagate_top() { return queue_replace_top(&m_queue); }
void replace_top(const Element *element) uint replace_top(const Element *element)
{ {
queue_top(&m_queue)= (uchar*)element; queue_top(&m_queue)= (uchar*)element;
propagate_top(); return propagate_top();
} }
private: private:
QUEUE m_queue; QUEUE m_queue;
......
...@@ -29,13 +29,10 @@ ulonglong mhnsw_cache_size; ...@@ -29,13 +29,10 @@ ulonglong mhnsw_cache_size;
#define clo_nei_read float4get #define clo_nei_read float4get
// Algorithm parameters // Algorithm parameters
// best by test (fastest construction with recall > 99% for ef=20, limit=10) static constexpr double alpha = 1.1;
// for random-xs-20-euclidean (9000) [ 3, 1.1, M=7 ] static constexpr double stiffness = 0.002;
// for mnist-784-euclidean (60000) [ 4, 1.1, M=13 ] static constexpr uint ef_construction_max_factor= 16;
// for sift-128-euclidean (1000000) [ 4, 1.1, M>64 ] (98% with M=64) static constexpr uint clo_nei_threshold= 10000;
static const double ef_construction_multiplier = 4;
static const double alpha = 1.1;
static const uint clo_nei_threshold= 10000;
enum Graph_table_fields { enum Graph_table_fields {
FIELD_LAYER, FIELD_TREF, FIELD_VEC, FIELD_NEIGHBORS FIELD_LAYER, FIELD_TREF, FIELD_VEC, FIELD_NEIGHBORS
...@@ -230,6 +227,7 @@ class MHNSW_Context : public Sql_alloc ...@@ -230,6 +227,7 @@ class MHNSW_Context : public Sql_alloc
{ {
std::atomic<uint> refcnt; std::atomic<uint> refcnt;
std::atomic<double> ef_power; // for the bloom filter size heuristic std::atomic<double> ef_power; // for the bloom filter size heuristic
std::atomic<uint> ef_construction;
mysql_mutex_t cache_lock; mysql_mutex_t cache_lock;
mysql_mutex_t node_lock[8]; mysql_mutex_t node_lock[8];
...@@ -268,6 +266,7 @@ class MHNSW_Context : public Sql_alloc ...@@ -268,6 +266,7 @@ class MHNSW_Context : public Sql_alloc
mysql_mutex_init(PSI_INSTRUMENT_ME, node_lock + i, MY_MUTEX_INIT_SLOW); mysql_mutex_init(PSI_INSTRUMENT_ME, node_lock + i, MY_MUTEX_INIT_SLOW);
init_alloc_root(PSI_INSTRUMENT_MEM, &root, 1024*1024, 0, MYF(0)); init_alloc_root(PSI_INSTRUMENT_MEM, &root, 1024*1024, 0, MYF(0));
set_ef_power(0.6); set_ef_power(0.6);
set_ef_construction(0);
refcnt.store(0, std::memory_order_relaxed); refcnt.store(0, std::memory_order_relaxed);
} }
...@@ -305,6 +304,17 @@ class MHNSW_Context : public Sql_alloc ...@@ -305,6 +304,17 @@ class MHNSW_Context : public Sql_alloc
ef_power.store(x, std::memory_order_relaxed); ef_power.store(x, std::memory_order_relaxed);
} }
// Returns the ef_construction value currently stored in this context.
// The read uses relaxed memory ordering.
uint get_ef_construction()
{
  const uint current= ef_construction.load(std::memory_order_relaxed);
  return current;
}
// Stores a new ef_construction value, written with relaxed memory ordering.
// As a safety measure the requested value is first raised to at least M
// and then capped at M*ef_construction_max_factor.
void set_ef_construction(uint x)
{
  const uint lower_bound= M;
  const uint upper_bound= M*ef_construction_max_factor;
  if (x < lower_bound)
    x= lower_bound;
  if (x > upper_bound)
    x= upper_bound;
  ef_construction.store(x, std::memory_order_relaxed);
}
uint max_neighbors(size_t layer) const uint max_neighbors(size_t layer) const
{ {
return (layer ? 1 : 2) * M; // heuristic from the paper return (layer ? 1 : 2) * M; // heuristic from the paper
...@@ -490,6 +500,8 @@ int MHNSW_Trx::MHNSW_hton::do_commit(handlerton *, THD *thd, bool) ...@@ -490,6 +500,8 @@ int MHNSW_Trx::MHNSW_hton::do_commit(handlerton *, THD *thd, bool)
node->vec= nullptr; node->vec= nullptr;
ctx->start= nullptr; ctx->start= nullptr;
} }
if (ctx->get_ef_construction() < trx->get_ef_construction())
ctx->set_ef_construction(trx->get_ef_construction());
ctx->release(true, trx->table_share); ctx->release(true, trx->table_share);
} }
trx->~MHNSW_Trx(); trx->~MHNSW_Trx();
...@@ -507,6 +519,7 @@ MHNSW_Trx *MHNSW_Trx::get_from_thd(THD *thd, TABLE *table) ...@@ -507,6 +519,7 @@ MHNSW_Trx *MHNSW_Trx::get_from_thd(THD *thd, TABLE *table)
trx= new (&thd->transaction->mem_root) MHNSW_Trx(table); trx= new (&thd->transaction->mem_root) MHNSW_Trx(table);
trx->next= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &hton)); trx->next= static_cast<MHNSW_Trx*>(thd_get_ha_data(thd, &hton));
thd_set_ha_data(thd, &hton, trx); thd_set_ha_data(thd, &hton, trx);
// XXX copy ef_construction from MHNSW_Context
if (!trx->next) if (!trx->next)
{ {
bool all= thd_test_options(thd, OPTION_NOT_AUTOCOMMIT | OPTION_BEGIN); bool all= thd_test_options(thd, OPTION_NOT_AUTOCOMMIT | OPTION_BEGIN);
...@@ -703,7 +716,9 @@ struct Visited : public Sql_alloc ...@@ -703,7 +716,9 @@ struct Visited : public Sql_alloc
{ {
FVectorNode *node; FVectorNode *node;
const float distance_to_target; const float distance_to_target;
Visited(FVectorNode *n, float d) : node(n), distance_to_target(d) {} bool expand;
Visited(FVectorNode *n, float d, bool e= false)
: node(n), distance_to_target(d), expand(e) {}
static int cmp(void *, const Visited* a, const Visited *b) static int cmp(void *, const Visited* a, const Visited *b)
{ {
return a->distance_to_target < b->distance_to_target ? -1 : return a->distance_to_target < b->distance_to_target ? -1 :
...@@ -730,9 +745,9 @@ class VisitedSet ...@@ -730,9 +745,9 @@ class VisitedSet
uint count= 0; uint count= 0;
VisitedSet(MEM_ROOT *root, const FVector *target, uint size) : VisitedSet(MEM_ROOT *root, const FVector *target, uint size) :
root(root), target(target), map(size, 0.01) {} root(root), target(target), map(size, 0.01) {}
Visited *create(FVectorNode *node) Visited *create(FVectorNode *node, bool e= false)
{ {
auto *v= new (root) Visited(node, node->distance_to(target)); auto *v= new (root) Visited(node, node->distance_to(target), e);
insert(node); insert(node);
count++; count++;
return v; return v;
...@@ -889,16 +904,34 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, TABLE *graph, ...@@ -889,16 +904,34 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, TABLE *graph,
} }
static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target, static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
Neighborhood *start_nodes, uint ef, size_t layer, Neighborhood *start_nodes, uint result_size,
Neighborhood *result, bool skip_deleted) size_t layer, Neighborhood *result, bool construction)
{ {
DBUG_ASSERT(start_nodes->num > 0); DBUG_ASSERT(start_nodes->num > 0);
result->empty(); result->empty();
MEM_ROOT * const root= graph->in_use->mem_root; MEM_ROOT * const root= graph->in_use->mem_root;
Queue<Visited> candidates, best;
bool skip_deleted;
uint ef= result_size, expand_size= 0;
Queue<Visited> candidates; if (construction)
Queue<Visited> best; {
skip_deleted= false;
if (ef > 1)
{
uint efc= std::max(ctx->get_ef_construction(), ef);
// round down efc/2 to 2^n-1
expand_size= (my_round_up_to_next_power((efc >> 1) + 2) - 1) >> 1;
ef= efc + expand_size;
}
}
else
{
skip_deleted= layer == 0;
if (ef > 1 || layer == 0)
ef= ef * graph->in_use->variables.mhnsw_limit_multiplier;
}
// WARNING! heuristic here // WARNING! heuristic here
const double est_heuristic= 8 * std::sqrt(ctx->max_neighbors(layer)); const double est_heuristic= 8 * std::sqrt(ctx->max_neighbors(layer));
...@@ -908,23 +941,21 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target, ...@@ -908,23 +941,21 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
candidates.init(10000, false, Visited::cmp); candidates.init(10000, false, Visited::cmp);
best.init(ef, true, Visited::cmp); best.init(ef, true, Visited::cmp);
DBUG_ASSERT(start_nodes->num <= result_size);
for (size_t i=0; i < start_nodes->num; i++) for (size_t i=0; i < start_nodes->num; i++)
{ {
Visited *v= visited.create(start_nodes->links[i]); Visited *v= visited.create(start_nodes->links[i]);
candidates.push(v); candidates.push(v);
if (skip_deleted && v->node->deleted) if (skip_deleted && v->node->deleted)
continue; continue;
if (best.elements() < ef) best.push(v);
best.push(v);
else if (v->distance_to_target < best.top()->distance_to_target)
best.replace_top(v);
} }
float furthest_best= FLT_MAX; float furthest_best= FLT_MAX;
while (candidates.elements()) while (candidates.elements())
{ {
const Visited &cur= *candidates.pop(); const Visited &cur= *candidates.pop();
if (cur.distance_to_target > furthest_best && best.elements() == ef) if (cur.distance_to_target > furthest_best && best.is_full())
break; // All possible candidates are worse than what we have break; // All possible candidates are worse than what we have
visited.flush(); visited.flush();
...@@ -943,8 +974,8 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target, ...@@ -943,8 +974,8 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
continue; continue;
if (int err= links[i]->load(graph)) if (int err= links[i]->load(graph))
return err; return err;
Visited *v= visited.create(links[i]); Visited *v= visited.create(links[i], cur.expand);
if (best.elements() < ef) if (!best.is_full())
{ {
candidates.push(v); candidates.push(v);
if (skip_deleted && v->node->deleted) if (skip_deleted && v->node->deleted)
...@@ -957,7 +988,8 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target, ...@@ -957,7 +988,8 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
candidates.push(v); candidates.push(v);
if (skip_deleted && v->node->deleted) if (skip_deleted && v->node->deleted)
continue; continue;
best.replace_top(v); if (best.replace_top(v) <= expand_size)
v->expand= true;
furthest_best= best.top()->distance_to_target; furthest_best= best.top()->distance_to_target;
} }
} }
...@@ -966,9 +998,21 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target, ...@@ -966,9 +998,21 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
if (ef > 1 && visited.count*2 > est_size) if (ef > 1 && visited.count*2 > est_size)
ctx->set_ef_power(std::log(visited.count*2/est_heuristic) / std::log(ef)); ctx->set_ef_power(std::log(visited.count*2/est_heuristic) / std::log(ef));
while (best.elements() > result_size)
best.pop();
uint expanded= 0;
result->num= best.elements(); result->num= best.elements();
for (FVectorNode **links= result->links + result->num; best.elements();) for (FVectorNode **links= result->links + result->num; best.elements();)
{
expanded+= best.top()->expand;
*--links= best.pop()->node; *--links= best.pop()->node;
}
if (expanded && expanded > stiffness*expand_size*result_size) // Hooke's law
ctx->set_ef_construction(ef);
else if (expand_size)
ctx->set_ef_construction(ef - expand_size - 1); // decrease slowly
return 0; return 0;
} }
...@@ -1029,10 +1073,10 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -1029,10 +1073,10 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
if (ctx->byte_len != res->length()) if (ctx->byte_len != res->length())
return bad_value_on_insert(vec_field); return bad_value_on_insert(vec_field);
size_t ef= ctx->max_neighbors(0) * ef_construction_multiplier; const size_t max_found= ctx->max_neighbors(0);
Neighborhood candidates, start_nodes; Neighborhood candidates, start_nodes;
candidates.init(thd->alloc<FVectorNode*>(ef + 7), ef); candidates.init(thd->alloc<FVectorNode*>(max_found + 7), max_found);
start_nodes.init(thd->alloc<FVectorNode*>(ef + 7), ef); start_nodes.init(thd->alloc<FVectorNode*>(max_found + 7), max_found);
start_nodes.links[start_nodes.num++]= ctx->start; start_nodes.links[start_nodes.num++]= ctx->start;
const double NORMALIZATION_FACTOR= 1 / std::log(ctx->M); const double NORMALIZATION_FACTOR= 1 / std::log(ctx->M);
...@@ -1060,8 +1104,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) ...@@ -1060,8 +1104,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
{ {
uint max_neighbors= ctx->max_neighbors(cur_layer); uint max_neighbors= ctx->max_neighbors(cur_layer);
if (int err= search_layer(ctx, graph, target->vec, &start_nodes, if (int err= search_layer(ctx, graph, target->vec, &start_nodes,
ef_construction_multiplier * max_neighbors, max_neighbors, cur_layer, &candidates, true))
cur_layer, &candidates, false))
return err; return err;
if (int err= select_neighbors(ctx, graph, cur_layer, *target, candidates, if (int err= select_neighbors(ctx, graph, cur_layer, *target, candidates,
...@@ -1103,13 +1146,9 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) ...@@ -1103,13 +1146,9 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
return err; return err;
SCOPE_EXIT([ctx, table](){ ctx->release(table); }); SCOPE_EXIT([ctx, table](){ ctx->release(table); });
// this auto-scales ef with the limit, providing more adequate
// behavior than a fixed ef
size_t ef= limit * thd->variables.mhnsw_limit_multiplier;
Neighborhood candidates, start_nodes; Neighborhood candidates, start_nodes;
candidates.init(thd->alloc<FVectorNode*>(ef + 7), ef); candidates.init(thd->alloc<FVectorNode*>(limit + 7), limit);
start_nodes.init(thd->alloc<FVectorNode*>(ef + 7), ef); start_nodes.init(thd->alloc<FVectorNode*>(limit + 7), limit);
// one could put all max_layer nodes in start_nodes // one could put all max_layer nodes in start_nodes
// but it has no effect of the recall or speed // but it has no effect of the recall or speed
...@@ -1144,8 +1183,8 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit) ...@@ -1144,8 +1183,8 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
std::swap(start_nodes, candidates); std::swap(start_nodes, candidates);
} }
if (int err= search_layer(ctx, graph, target, &start_nodes, ef, 0, if (int err= search_layer(ctx, graph, target, &start_nodes, limit, 0,
&candidates, true)) &candidates, false))
return err; return err;
if (limit > candidates.num) if (limit > candidates.num)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment