Commit bcbbf333 authored by Sergei Golubchik's avatar Sergei Golubchik

mhnsw: SIMD for euclidean distance

parent f579b988
......@@ -25,6 +25,13 @@
static constexpr float alpha = 1.1f;
static constexpr uint ef_construction= 10;
// SIMD definitions
#define SIMD_word (256/8)
#define SIMD_floats (SIMD_word/sizeof(float))
// how many extra bytes we need to alloc to be able to convert
// sizeof(double) aligned memory to SIMD_word aligned
#define SIMD_margin (SIMD_word - sizeof(double))
class MHNSW_Context;
class FVector: public Sql_alloc
......@@ -35,6 +42,7 @@ class FVector: public Sql_alloc
float *vec;
protected:
FVector(MHNSW_Context *ctx_) : ctx(ctx_), vec(nullptr) {}
void make_vec(const void *vec_);
};
class FVectorNode: public FVector
......@@ -64,6 +72,7 @@ class MHNSW_Context
TABLE *table;
Field *vec_field;
size_t vec_len= 0;
size_t byte_len= 0;
FVector *target= 0;
Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};
......@@ -84,7 +93,18 @@ class MHNSW_Context
FVector::FVector(MHNSW_Context *ctx_, const void *vec_) : ctx(ctx_)
{
vec= (float*)memdup_root(&ctx->root, vec_, ctx->vec_len * sizeof(float));
make_vec(vec_);
}
void FVector::make_vec(const void *vec_)
{
vec= (float*)alloc_root(&ctx->root,
ctx->vec_len * sizeof(float) + SIMD_margin);
if (int off= ((intptr)vec) % SIMD_word)
vec += (SIMD_word - off) / sizeof(float);
memcpy(vec, vec_, ctx->byte_len);
for (size_t i=ctx->byte_len/sizeof(float); i < ctx->vec_len; i++)
vec[i]=0;
}
FVectorNode::FVectorNode(MHNSW_Context *ctx_, const void *ref_)
......@@ -103,7 +123,20 @@ float FVectorNode::distance_to(const FVector &other) const
{
if (!vec)
const_cast<FVectorNode*>(this)->instantiate_vector();
#if __GNUC__ > 7
typedef float v8f __attribute__((vector_size(SIMD_word)));
v8f *p1= (v8f*)vec;
v8f *p2= (v8f*)other.vec;
v8f d= {0,0,0,0,0,0,0,0};
for (size_t i= 0; i < ctx->vec_len/SIMD_floats; p1++, p2++, i++)
{
v8f dist= *p1 - *p2;
d+= dist * dist;
}
return d[0] + d[1] + d[2] + d[3] + d[4] + d[5] + d[6] + d[7];
#else
return euclidean_vec_distance(vec, other.vec, ctx->vec_len);
#endif
}
int FVectorNode::instantiate_vector()
......@@ -112,8 +145,12 @@ int FVectorNode::instantiate_vector()
if (int err= ctx->table->file->ha_rnd_pos(ctx->table->record[0], ref))
return err;
String buf, *v= ctx->vec_field->val_str(&buf);
ctx->vec_len= v->length() / sizeof(float);
vec= (float*)memdup_root(&ctx->root, v->ptr(), v->length());
if (unlikely(ctx->byte_len == 0))
{
ctx->byte_len= v->length();
ctx->vec_len= MY_ALIGN(ctx->byte_len/sizeof(float), SIMD_floats);
}
make_vec(v->ptr());
return 0;
}
......@@ -469,7 +506,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
if (int err= start_node->instantiate_vector())
return err;
if (ctx.vec_len * sizeof(float) != res->length())
if (ctx.byte_len != res->length())
return bad_value_on_insert(vec_field);
FVectorNode target(&ctx, table->file->ref, res->ptr());
......@@ -563,7 +600,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
NULL, so the result is basically unsorted, we can return rows
in any order. For simplicity let's sort by the start_node.
*/
if (!res || ctx.vec_len * sizeof(float) != res->length())
if (!res || ctx.byte_len != res->length())
res= vec_field->val_str(&buf);
FVector target(&ctx, res->ptr());
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment