Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
M
MariaDB
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
nexedi
MariaDB
Commits
b2a6dc71
Commit
b2a6dc71
authored
Jun 05, 2024
by
Sergei Golubchik
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
mhnsw: don't guess whether it's insert or update
we know it every time
parent
ab43c7e8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
22 deletions
+32
-22
sql/vector_mhnsw.cc
sql/vector_mhnsw.cc
+32
-22
No files found.
sql/vector_mhnsw.cc
View file @
b2a6dc71
...
...
@@ -65,6 +65,7 @@ class FVectorNode: public FVector
int
instantiate_vector
();
size_t
get_ref_len
()
const
;
uchar
*
get_ref
()
const
{
return
ref
;
}
bool
is_new
()
const
;
static
uchar
*
get_key
(
const
FVectorNode
*
elem
,
size_t
*
key_len
,
my_bool
);
};
...
...
@@ -76,6 +77,7 @@ class MHNSW_Context
TABLE
*
table
;
Field
*
vec_field
;
size_t
vec_len
=
0
;
FVector
*
target
=
0
;
Hash_set
<
FVectorNode
>
node_cache
{
PSI_INSTRUMENT_MEM
,
FVectorNode
::
get_key
};
...
...
@@ -133,6 +135,11 @@ size_t FVectorNode::get_ref_len() const
return
ctx
->
table
->
file
->
ref_length
;
}
bool
FVectorNode
::
is_new
()
const
{
return
this
==
ctx
->
target
;
}
uchar
*
FVectorNode
::
get_key
(
const
FVectorNode
*
elem
,
size_t
*
key_len
,
my_bool
)
{
*
key_len
=
elem
->
get_ref_len
();
...
...
@@ -327,6 +334,7 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
const
FVectorNode
&
source_node
,
const
List
<
FVectorNode
>
&
new_neighbors
)
{
int
err
;
TABLE
*
graph
=
ctx
->
table
->
hlindex
;
DBUG_ASSERT
(
new_neighbors
.
elements
<=
HNSW_MAX_M
);
...
...
@@ -349,25 +357,24 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
graph
->
field
[
1
]
->
store_binary
(
source_node
.
get_ref
(),
source_node
.
get_ref_len
());
graph
->
field
[
2
]
->
store_binary
(
neighbor_array_bytes
,
total_size
);
uchar
*
key
=
static_cast
<
uchar
*>
(
alloca
(
graph
->
key_info
->
key_length
));
key_copy
(
key
,
graph
->
record
[
0
],
graph
->
key_info
,
graph
->
key_info
->
key_length
);
// XXX try to write first?
int
err
=
graph
->
file
->
ha_index_read_map
(
graph
->
record
[
1
],
key
,
HA_WHOLE_KEY
,
HA_READ_KEY_EXACT
);
// no record
if
(
err
==
HA_ERR_KEY_NOT_FOUND
)
if
(
source_node
.
is_new
())
{
dbug_print_vec_ref
(
"INSERT "
,
layer_number
,
source_node
);
err
=
graph
->
file
->
ha_write_row
(
graph
->
record
[
0
]);
}
else
if
(
!
err
)
else
{
dbug_print_vec_ref
(
"UPDATE "
,
layer_number
,
source_node
);
dbug_print_vec_neigh
(
layer_number
,
new_neighbors
);
err
=
graph
->
file
->
ha_update_row
(
graph
->
record
[
1
],
graph
->
record
[
0
]);
uchar
*
key
=
static_cast
<
uchar
*>
(
alloca
(
graph
->
key_info
->
key_length
));
key_copy
(
key
,
graph
->
record
[
0
],
graph
->
key_info
,
graph
->
key_info
->
key_length
);
err
=
graph
->
file
->
ha_index_read_map
(
graph
->
record
[
1
],
key
,
HA_WHOLE_KEY
,
HA_READ_KEY_EXACT
);
if
(
!
err
)
err
=
graph
->
file
->
ha_update_row
(
graph
->
record
[
1
],
graph
->
record
[
0
]);
}
my_safe_afree
(
neighbor_array_bytes
,
total_size
);
return
err
;
...
...
@@ -428,7 +435,7 @@ static int update_neighbors(MHNSW_Context *ctx,
}
static
int
search_layer
(
MHNSW_Context
*
ctx
,
const
FVector
&
target
,
static
int
search_layer
(
MHNSW_Context
*
ctx
,
const
List
<
FVectorNode
>
&
start_nodes
,
uint
max_candidates_return
,
size_t
layer
,
List
<
FVectorNode
>
*
result
)
...
...
@@ -439,6 +446,7 @@ static int search_layer(MHNSW_Context *ctx, const FVector &target,
Queue
<
FVectorNode
,
const
FVector
>
candidates
;
Queue
<
FVectorNode
,
const
FVector
>
best
;
Hash_set
<
FVectorNode
>
visited
(
PSI_INSTRUMENT_MEM
,
FVectorNode
::
get_key
);
const
FVector
&
target
=
*
ctx
->
target
;
candidates
.
init
(
10000
,
false
,
cmp_vec
,
&
target
);
best
.
init
(
max_candidates_return
,
true
,
cmp_vec
,
&
target
);
...
...
@@ -550,20 +558,21 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
SCOPE_EXIT
([
graph
](){
graph
->
file
->
ha_index_end
();
});
h
->
position
(
table
->
record
[
0
]);
if
(
int
err
=
graph
->
file
->
ha_index_last
(
graph
->
record
[
0
]))
{
if
(
err
!=
HA_ERR_END_OF_FILE
)
return
err
;
// First insert!
h
->
position
(
table
->
record
[
0
]);
return
write_neighbors
(
&
ctx
,
0
,
{
&
ctx
,
h
->
ref
},
{});
FVectorNode
target
(
&
ctx
,
h
->
ref
);
ctx
.
target
=
&
target
;
return
write_neighbors
(
&
ctx
,
0
,
target
,
{});
}
longlong
max_layer
=
graph
->
field
[
0
]
->
val_int
();
h
->
position
(
table
->
record
[
0
]);
List
<
FVectorNode
>
candidates
;
List
<
FVectorNode
>
start_nodes
;
String
ref_str
,
*
ref_ptr
;
...
...
@@ -583,6 +592,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
return
bad_value_on_insert
(
vec_field
);
FVectorNode
target
(
&
ctx
,
h
->
ref
,
res
->
ptr
());
ctx
.
target
=
&
target
;
double
new_num
=
my_rnd
(
&
thd
->
rand
);
double
log
=
-
std
::
log
(
new_num
)
*
NORMALIZATION_FACTOR
;
...
...
@@ -590,7 +600,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
for
(
longlong
cur_layer
=
max_layer
;
cur_layer
>
new_node_layer
;
cur_layer
--
)
{
if
(
int
err
=
search_layer
(
&
ctx
,
target
,
start_nodes
,
if
(
int
err
=
search_layer
(
&
ctx
,
start_nodes
,
thd
->
variables
.
hnsw_ef_constructor
,
cur_layer
,
&
candidates
))
return
err
;
...
...
@@ -603,7 +613,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
cur_layer
>=
0
;
cur_layer
--
)
{
List
<
FVectorNode
>
neighbors
;
if
(
int
err
=
search_layer
(
&
ctx
,
target
,
start_nodes
,
if
(
int
err
=
search_layer
(
&
ctx
,
start_nodes
,
thd
->
variables
.
hnsw_ef_constructor
,
cur_layer
,
&
candidates
))
return
err
;
...
...
@@ -682,6 +692,7 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
res
=
vec_field
->
val_str
(
&
buf
);
FVector
target
(
&
ctx
,
res
->
ptr
());
ctx
.
target
=
&
target
;
ulonglong
ef_search
=
std
::
max
<
ulonglong
>
(
//XXX why not always limit?
thd
->
variables
.
hnsw_ef_search
,
limit
);
...
...
@@ -689,16 +700,15 @@ int mhnsw_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
for
(
size_t
cur_layer
=
max_layer
;
cur_layer
>
0
;
cur_layer
--
)
{
//XXX in the paper ef_search=1 here
if
(
int
err
=
search_layer
(
&
ctx
,
target
,
start_nodes
,
ef_search
,
cur_layer
,
&
candidates
))
if
(
int
err
=
search_layer
(
&
ctx
,
start_nodes
,
ef_search
,
cur_layer
,
&
candidates
))
return
err
;
start_nodes
.
empty
();
start_nodes
.
push_back
(
candidates
.
head
(),
&
ctx
.
root
);
// XXX so ef_search=1 ???
candidates
.
empty
();
}
if
(
int
err
=
search_layer
(
&
ctx
,
target
,
start_nodes
,
ef_search
,
0
,
&
candidates
))
if
(
int
err
=
search_layer
(
&
ctx
,
start_nodes
,
ef_search
,
0
,
&
candidates
))
return
err
;
size_t
context_size
=
limit
*
h
->
ref_length
+
sizeof
(
ulonglong
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment