Commit 81f3e97b authored by Marko Mäkelä's avatar Marko Mäkelä

MDEV-33383: Corrupted red-black tree due to incorrect comparison

fts_doc_id_cmp(): Replaces several duplicated functions for
comparing two doc_id_t*. On IA-32, AMD64, ARMv7, ARMv8, RISC-V
this should make use of some conditional ALU instructions.
On POWER there will be conditional jumps. Unlike the original
functions, these will return the correct result even if the
difference of the two doc_id does not fit in the int data type.
We use static_assert() and offsetof() to check at compilation time
that this function is compatible with the rbt_create() calls.

fts_query_compare_rank(): As documented, return -1 and not 1
when the rank are equal and r1->doc_id < r2->doc_id. This will
affect the result of ha_innobase::ft_read().

fts_ptr2_cmp(), fts_ptr1_ptr2_cmp(): These replace
fts_trx_table_cmp(), fts_trx_table_id_cmp().
The fts_savepoint_t::tables will be sorted by dict_table_t*
rather than dict_table_t::id. There was no correctness bug in
the previous comparison predicates. We can avoid one level of
unnecessary pointer dereferencing in this way.
Actually, fts_savepoint_t is duplicating trx_t::mod_tables.
MDEV-33401 was filed about removing it.

The added unit test innodb_rbt-t covers both the previous buggy comparison
predicate and the revised fts_doc_id_cmp(), using keys which led to
finding the bug. Thanks to Shaohua Wang from Alibaba for providing the
example and the revised comparison predicate.

Reviewed by: Thirunarayanan Balathandayuthapani
parent 92f87f2c
...@@ -2212,6 +2212,22 @@ fts_trx_row_get_new_state( ...@@ -2212,6 +2212,22 @@ fts_trx_row_get_new_state(
return(result); return(result);
} }
/** Compare two doubly indirected pointers */
static int fts_ptr2_cmp(const void *p1, const void *p2)
{
const void *a= **static_cast<const void*const*const*>(p1);
const void *b= **static_cast<const void*const*const*>(p2);
return b > a ? -1 : a > b;
}
/** Compare a singly indirected pointer to a doubly indirected one */
static int fts_ptr1_ptr2_cmp(const void *p1, const void *p2)
{
const void *a= *static_cast<const void*const*>(p1);
const void *b= **static_cast<const void*const*const*>(p2);
return b > a ? -1 : a > b;
}
/******************************************************************//** /******************************************************************//**
Create a savepoint instance. Create a savepoint instance.
@return savepoint instance */ @return savepoint instance */
...@@ -2234,8 +2250,8 @@ fts_savepoint_create( ...@@ -2234,8 +2250,8 @@ fts_savepoint_create(
savepoint->name = mem_heap_strdup(heap, name); savepoint->name = mem_heap_strdup(heap, name);
} }
savepoint->tables = rbt_create( static_assert(!offsetof(fts_trx_table_t, table), "ABI");
sizeof(fts_trx_table_t*), fts_trx_table_cmp); savepoint->tables = rbt_create(sizeof(fts_trx_table_t*), fts_ptr2_cmp);
return(savepoint); return(savepoint);
} }
...@@ -2283,6 +2299,19 @@ fts_trx_create( ...@@ -2283,6 +2299,19 @@ fts_trx_create(
return(ftt); return(ftt);
} }
/** Compare two doc_id */
static inline int doc_id_cmp(doc_id_t a, doc_id_t b)
{
return b > a ? -1 : a > b;
}
/** Compare two DOC_ID. */
int fts_doc_id_cmp(const void *p1, const void *p2)
{
return doc_id_cmp(*static_cast<const doc_id_t*>(p1),
*static_cast<const doc_id_t*>(p2));
}
/******************************************************************//** /******************************************************************//**
Create an FTS trx table. Create an FTS trx table.
@return FTS trx table */ @return FTS trx table */
...@@ -2301,7 +2330,8 @@ fts_trx_table_create( ...@@ -2301,7 +2330,8 @@ fts_trx_table_create(
ftt->table = table; ftt->table = table;
ftt->fts_trx = fts_trx; ftt->fts_trx = fts_trx;
ftt->rows = rbt_create(sizeof(fts_trx_row_t), fts_trx_row_doc_id_cmp); static_assert(!offsetof(fts_trx_row_t, doc_id), "ABI");
ftt->rows = rbt_create(sizeof(fts_trx_row_t), fts_doc_id_cmp);
return(ftt); return(ftt);
} }
...@@ -2325,7 +2355,8 @@ fts_trx_table_clone( ...@@ -2325,7 +2355,8 @@ fts_trx_table_clone(
ftt->table = ftt_src->table; ftt->table = ftt_src->table;
ftt->fts_trx = ftt_src->fts_trx; ftt->fts_trx = ftt_src->fts_trx;
ftt->rows = rbt_create(sizeof(fts_trx_row_t), fts_trx_row_doc_id_cmp); static_assert(!offsetof(fts_trx_row_t, doc_id), "ABI");
ftt->rows = rbt_create(sizeof(fts_trx_row_t), fts_doc_id_cmp);
/* Copy the rb tree values to the new savepoint. */ /* Copy the rb tree values to the new savepoint. */
rbt_merge_uniq(ftt->rows, ftt_src->rows); rbt_merge_uniq(ftt->rows, ftt_src->rows);
...@@ -2350,13 +2381,9 @@ fts_trx_init( ...@@ -2350,13 +2381,9 @@ fts_trx_init(
{ {
fts_trx_table_t* ftt; fts_trx_table_t* ftt;
ib_rbt_bound_t parent; ib_rbt_bound_t parent;
ib_rbt_t* tables; ib_rbt_t* tables = static_cast<fts_savepoint_t*>(
fts_savepoint_t* savepoint; ib_vector_last(savepoints))->tables;
rbt_search_cmp(tables, &parent, &table, fts_ptr1_ptr2_cmp, nullptr);
savepoint = static_cast<fts_savepoint_t*>(ib_vector_last(savepoints));
tables = savepoint->tables;
rbt_search_cmp(tables, &parent, &table->id, fts_trx_table_id_cmp, NULL);
if (parent.result == 0) { if (parent.result == 0) {
fts_trx_table_t** fttp; fts_trx_table_t** fttp;
...@@ -5638,8 +5665,8 @@ fts_savepoint_rollback_last_stmt( ...@@ -5638,8 +5665,8 @@ fts_savepoint_rollback_last_stmt(
l_ftt = rbt_value(fts_trx_table_t*, node); l_ftt = rbt_value(fts_trx_table_t*, node);
rbt_search_cmp( rbt_search_cmp(
s_tables, &parent, &(*l_ftt)->table->id, s_tables, &parent, &(*l_ftt)->table,
fts_trx_table_id_cmp, NULL); fts_ptr1_ptr2_cmp, nullptr);
if (parent.result == 0) { if (parent.result == 0) {
fts_trx_table_t** s_ftt; fts_trx_table_t** s_ftt;
......
...@@ -385,22 +385,6 @@ fts_query_terms_in_document( ...@@ -385,22 +385,6 @@ fts_query_terms_in_document(
ulint* total); /*!< out: total words in document */ ulint* total); /*!< out: total words in document */
#endif #endif
/********************************************************************
Compare two fts_doc_freq_t doc_ids.
@return < 0 if n1 < n2, 0 if n1 == n2, > 0 if n1 > n2 */
UNIV_INLINE
int
fts_freq_doc_id_cmp(
/*================*/
const void* p1, /*!< in: id1 */
const void* p2) /*!< in: id2 */
{
const fts_doc_freq_t* fq1 = (const fts_doc_freq_t*) p1;
const fts_doc_freq_t* fq2 = (const fts_doc_freq_t*) p2;
return((int) (fq1->doc_id - fq2->doc_id));
}
#if 0 #if 0
/*******************************************************************//** /*******************************************************************//**
Print the table used for calculating LCS. */ Print the table used for calculating LCS. */
...@@ -506,14 +490,11 @@ fts_query_compare_rank( ...@@ -506,14 +490,11 @@ fts_query_compare_rank(
if (r2->rank < r1->rank) { if (r2->rank < r1->rank) {
return(-1); return(-1);
} else if (r2->rank == r1->rank) { } else if (r2->rank == r1->rank) {
if (r1->doc_id < r2->doc_id) { if (r1->doc_id < r2->doc_id) {
return(1); return -1;
} else if (r1->doc_id > r2->doc_id) {
return(1);
} }
return(0); return r1->doc_id > r2->doc_id;
} }
return(1); return(1);
...@@ -674,8 +655,9 @@ fts_query_add_word_freq( ...@@ -674,8 +655,9 @@ fts_query_add_word_freq(
word_freq.doc_count = 0; word_freq.doc_count = 0;
static_assert(!offsetof(fts_doc_freq_t, doc_id), "ABI");
word_freq.doc_freqs = rbt_create( word_freq.doc_freqs = rbt_create(
sizeof(fts_doc_freq_t), fts_freq_doc_id_cmp); sizeof(fts_doc_freq_t), fts_doc_id_cmp);
parent.last = rbt_add_node( parent.last = rbt_add_node(
query->word_freqs, &parent, &word_freq); query->word_freqs, &parent, &word_freq);
...@@ -1253,8 +1235,9 @@ fts_query_intersect( ...@@ -1253,8 +1235,9 @@ fts_query_intersect(
/* Create the rb tree that will hold the doc ids of /* Create the rb tree that will hold the doc ids of
the intersection. */ the intersection. */
static_assert(!offsetof(fts_ranking_t, doc_id), "ABI");
query->intersection = rbt_create( query->intersection = rbt_create(
sizeof(fts_ranking_t), fts_ranking_doc_id_cmp); sizeof(fts_ranking_t), fts_doc_id_cmp);
query->total_size += SIZEOF_RBT_CREATE; query->total_size += SIZEOF_RBT_CREATE;
...@@ -1540,8 +1523,9 @@ fts_merge_doc_ids( ...@@ -1540,8 +1523,9 @@ fts_merge_doc_ids(
to create a new result set for fts_query_intersect(). */ to create a new result set for fts_query_intersect(). */
if (query->oper == FTS_EXIST) { if (query->oper == FTS_EXIST) {
static_assert(!offsetof(fts_ranking_t, doc_id), "ABI");
query->intersection = rbt_create( query->intersection = rbt_create(
sizeof(fts_ranking_t), fts_ranking_doc_id_cmp); sizeof(fts_ranking_t), fts_doc_id_cmp);
query->total_size += SIZEOF_RBT_CREATE; query->total_size += SIZEOF_RBT_CREATE;
} }
...@@ -3012,8 +2996,9 @@ fts_query_visitor( ...@@ -3012,8 +2996,9 @@ fts_query_visitor(
if (query->oper == FTS_EXIST) { if (query->oper == FTS_EXIST) {
ut_ad(query->intersection == NULL); ut_ad(query->intersection == NULL);
static_assert(!offsetof(fts_ranking_t, doc_id), "ABI");
query->intersection = rbt_create( query->intersection = rbt_create(
sizeof(fts_ranking_t), fts_ranking_doc_id_cmp); sizeof(fts_ranking_t), fts_doc_id_cmp);
query->total_size += SIZEOF_RBT_CREATE; query->total_size += SIZEOF_RBT_CREATE;
} }
...@@ -3123,8 +3108,8 @@ fts_ast_visit_sub_exp( ...@@ -3123,8 +3108,8 @@ fts_ast_visit_sub_exp(
/* Create new result set to store the sub-expression result. We /* Create new result set to store the sub-expression result. We
will merge this result set with the parent after processing. */ will merge this result set with the parent after processing. */
query->doc_ids = rbt_create(sizeof(fts_ranking_t), static_assert(!offsetof(fts_ranking_t, doc_id), "ABI");
fts_ranking_doc_id_cmp); query->doc_ids = rbt_create(sizeof(fts_ranking_t), fts_doc_id_cmp);
query->total_size += SIZEOF_RBT_CREATE; query->total_size += SIZEOF_RBT_CREATE;
...@@ -3661,8 +3646,9 @@ fts_query_prepare_result( ...@@ -3661,8 +3646,9 @@ fts_query_prepare_result(
result = static_cast<fts_result_t*>( result = static_cast<fts_result_t*>(
ut_zalloc_nokey(sizeof(*result))); ut_zalloc_nokey(sizeof(*result)));
static_assert(!offsetof(fts_ranking_t, doc_id), "ABI");
result->rankings_by_id = rbt_create( result->rankings_by_id = rbt_create(
sizeof(fts_ranking_t), fts_ranking_doc_id_cmp); sizeof(fts_ranking_t), fts_doc_id_cmp);
query->total_size += sizeof(fts_result_t) + SIZEOF_RBT_CREATE; query->total_size += sizeof(fts_result_t) + SIZEOF_RBT_CREATE;
result_is_null = true; result_is_null = true;
...@@ -4065,8 +4051,9 @@ fts_query( ...@@ -4065,8 +4051,9 @@ fts_query(
query.heap = mem_heap_create(128); query.heap = mem_heap_create(128);
/* Create the rb tree for the doc id (current) set. */ /* Create the rb tree for the doc id (current) set. */
static_assert(!offsetof(fts_ranking_t, doc_id), "ABI");
query.doc_ids = rbt_create( query.doc_ids = rbt_create(
sizeof(fts_ranking_t), fts_ranking_doc_id_cmp); sizeof(fts_ranking_t), fts_doc_id_cmp);
query.parser = index->parser; query.parser = index->parser;
query.total_size += SIZEOF_RBT_CREATE; query.total_size += SIZEOF_RBT_CREATE;
......
...@@ -163,6 +163,9 @@ struct fts_token_t; ...@@ -163,6 +163,9 @@ struct fts_token_t;
struct fts_doc_ids_t; struct fts_doc_ids_t;
struct fts_index_cache_t; struct fts_index_cache_t;
/** Compare two DOC_ID. */
int fts_doc_id_cmp(const void *p1, const void *p2)
__attribute__((nonnull, warn_unused_result));
/** Initialize the "fts_table" for internal query into FTS auxiliary /** Initialize the "fts_table" for internal query into FTS auxiliary
tables */ tables */
......
...@@ -271,27 +271,6 @@ fts_index_fetch_nodes( ...@@ -271,27 +271,6 @@ fts_index_fetch_nodes(
word, /*!< in: the word to fetch */ word, /*!< in: the word to fetch */
fts_fetch_t* fetch) /*!< in: fetch callback.*/ fts_fetch_t* fetch) /*!< in: fetch callback.*/
MY_ATTRIBUTE((nonnull)); MY_ATTRIBUTE((nonnull));
/******************************************************************//**
Compare two fts_trx_table_t instances, we actually compare the
table id's here.
@return < 0 if n1 < n2, 0 if n1 == n2, > 0 if n1 > n2 */
UNIV_INLINE
int
fts_trx_table_cmp(
/*==============*/
const void* v1, /*!< in: id1 */
const void* v2) /*!< in: id2 */
MY_ATTRIBUTE((nonnull, warn_unused_result));
/******************************************************************//**
Compare a table id with a trx_table_t table id.
@return < 0 if n1 < n2, 0 if n1 == n2, > 0 if n1 > n2 */
UNIV_INLINE
int
fts_trx_table_id_cmp(
/*=================*/
const void* p1, /*!< in: id1 */
const void* p2) /*!< in: id2 */
MY_ATTRIBUTE((nonnull, warn_unused_result));
#define fts_sql_commit(trx) trx_commit_for_mysql(trx) #define fts_sql_commit(trx) trx_commit_for_mysql(trx)
#define fts_sql_rollback(trx) (trx)->rollback() #define fts_sql_rollback(trx) (trx)->rollback()
/******************************************************************//** /******************************************************************//**
......
...@@ -52,47 +52,3 @@ fts_read_object_id( ...@@ -52,47 +52,3 @@ fts_read_object_id(
if the id is HEX or DEC and do the right thing with it. */ if the id is HEX or DEC and do the right thing with it. */
return(sscanf(str, UINT64PFx, id) == 1); return(sscanf(str, UINT64PFx, id) == 1);
} }
/******************************************************************//**
Compare two fts_trx_table_t instances.
@return < 0 if n1 < n2, 0 if n1 == n2, > 0 if n1 > n2 */
UNIV_INLINE
int
fts_trx_table_cmp(
/*==============*/
const void* p1, /*!< in: id1 */
const void* p2) /*!< in: id2 */
{
const dict_table_t* table1
= (*static_cast<const fts_trx_table_t* const*>(p1))->table;
const dict_table_t* table2
= (*static_cast<const fts_trx_table_t* const*>(p2))->table;
return((table1->id > table2->id)
? 1
: (table1->id == table2->id)
? 0
: -1);
}
/******************************************************************//**
Compare a table id with a fts_trx_table_t table id.
@return < 0 if n1 < n2, 0 if n1 == n2,> 0 if n1 > n2 */
UNIV_INLINE
int
fts_trx_table_id_cmp(
/*=================*/
const void* p1, /*!< in: id1 */
const void* p2) /*!< in: id2 */
{
const uintmax_t* table_id = static_cast<const uintmax_t*>(p1);
const dict_table_t* table2
= (*static_cast<const fts_trx_table_t* const*>(p2))->table;
return((*table_id > table2->id)
? 1
: (*table_id == table2->id)
? 0
: -1);
}
...@@ -279,32 +279,6 @@ struct fts_token_t { ...@@ -279,32 +279,6 @@ struct fts_token_t {
/** It's defined in fts/fts0fts.c */ /** It's defined in fts/fts0fts.c */
extern const fts_index_selector_t fts_index_selector[]; extern const fts_index_selector_t fts_index_selector[];
/******************************************************************//**
Compare two fts_trx_row_t instances doc_ids. */
UNIV_INLINE
int
fts_trx_row_doc_id_cmp(
/*===================*/
/*!< out:
< 0 if n1 < n2,
0 if n1 == n2,
> 0 if n1 > n2 */
const void* p1, /*!< in: id1 */
const void* p2); /*!< in: id2 */
/******************************************************************//**
Compare two fts_ranking_t instances doc_ids. */
UNIV_INLINE
int
fts_ranking_doc_id_cmp(
/*===================*/
/*!< out:
< 0 if n1 < n2,
0 if n1 == n2,
> 0 if n1 > n2 */
const void* p1, /*!< in: id1 */
const void* p2); /*!< in: id2 */
/******************************************************************//** /******************************************************************//**
Duplicate a string. */ Duplicate a string. */
UNIV_INLINE UNIV_INLINE
......
...@@ -46,38 +46,6 @@ fts_string_dup( ...@@ -46,38 +46,6 @@ fts_string_dup(
dst->f_n_char = src->f_n_char; dst->f_n_char = src->f_n_char;
} }
/******************************************************************//**
Compare two fts_trx_row_t doc_ids.
@return < 0 if n1 < n2, 0 if n1 == n2, > 0 if n1 > n2 */
UNIV_INLINE
int
fts_trx_row_doc_id_cmp(
/*===================*/
const void* p1, /*!< in: id1 */
const void* p2) /*!< in: id2 */
{
const fts_trx_row_t* tr1 = (const fts_trx_row_t*) p1;
const fts_trx_row_t* tr2 = (const fts_trx_row_t*) p2;
return((int)(tr1->doc_id - tr2->doc_id));
}
/******************************************************************//**
Compare two fts_ranking_t doc_ids.
@return < 0 if n1 < n2, 0 if n1 == n2, > 0 if n1 > n2 */
UNIV_INLINE
int
fts_ranking_doc_id_cmp(
/*===================*/
const void* p1, /*!< in: id1 */
const void* p2) /*!< in: id2 */
{
const fts_ranking_t* rk1 = (const fts_ranking_t*) p1;
const fts_ranking_t* rk2 = (const fts_ranking_t*) p2;
return((int)(rk1->doc_id - rk2->doc_id));
}
/******************************************************************//** /******************************************************************//**
Get the first character's code position for FTS index partition */ Get the first character's code position for FTS index partition */
extern extern
......
...@@ -16,6 +16,10 @@ ...@@ -16,6 +16,10 @@
INCLUDE_DIRECTORIES(${CMAKE_SOURCE_DIR}/include INCLUDE_DIRECTORIES(${CMAKE_SOURCE_DIR}/include
${CMAKE_SOURCE_DIR}/unittest/mytap ${CMAKE_SOURCE_DIR}/unittest/mytap
${CMAKE_SOURCE_DIR}/storage/innobase/include) ${CMAKE_SOURCE_DIR}/storage/innobase/include)
ADD_EXECUTABLE(innodb_rbt-t innodb_rbt-t.cc ../ut/ut0rbt.cc)
TARGET_LINK_LIBRARIES(innodb_rbt-t mysys mytap)
ADD_DEPENDENCIES(innodb_rbt-t GenError)
MY_ADD_TEST(innodb_rbt)
ADD_EXECUTABLE(innodb_fts-t innodb_fts-t.cc) ADD_EXECUTABLE(innodb_fts-t innodb_fts-t.cc)
TARGET_LINK_LIBRARIES(innodb_fts-t mysys mytap) TARGET_LINK_LIBRARIES(innodb_fts-t mysys mytap)
ADD_DEPENDENCIES(innodb_fts-t GenError) ADD_DEPENDENCIES(innodb_fts-t GenError)
......
#include "tap.h"
#include "ut0rbt.h"
const size_t alloc_max_retries= 0;
void os_thread_sleep(ulint) { abort(); }
void ut_dbg_assertion_failed(const char *, const char *, unsigned)
{ abort(); }
namespace ib { fatal_or_error::~fatal_or_error() { abort(); } }
#ifdef UNIV_PFS_MEMORY
PSI_memory_key mem_key_other, mem_key_std;
PSI_memory_key ut_new_get_key_by_file(uint32_t) { return mem_key_std; }
#endif
static const uint64_t doc_ids[]=
{
103571, 104018, 106821, 108647, 109352, 109379,
110325, 122868, 210682130, 231275441, 234172769, 366236849,
526467159, 1675241735, 1675243405, 1947751899, 1949940363, 2033691953,
2148227299, 2256289791, 2294223591, 2367501260, 2792700091, 2792701220,
2817121627, 2820680352, 2821165664, 3253312130, 3404918378, 3532599429,
3538712078, 3539373037, 3546479309, 3566641838, 3580209634, 3580871267,
3693930556, 3693932734, 3693932983, 3781949558, 3839877411, 3930968983
};
static int fts_doc_id_cmp(const void *p1, const void *p2)
{
uint64_t a= *static_cast<const uint64_t*>(p1),
b= *static_cast<const uint64_t*>(p2);
return b > a ? -1 : a > b;
}
static int fts_doc_id_buggy_cmp(const void *p1, const void *p2)
{
return int(*static_cast<const uint64_t*>(p1) -
*static_cast<const uint64_t*>(p2));
}
typedef int (*comparator) (const void*, const void*);
static void rbt_populate(ib_rbt_t *rbt)
{
ib_rbt_bound_t parent;
for (const uint64_t &doc_id : doc_ids)
{
if (rbt_search(rbt, &parent, &doc_id))
rbt_add_node(rbt, &parent, &doc_id);
}
}
static void rbt_populate2(ib_rbt_t *rbt)
{
for (const uint64_t &doc_id : doc_ids)
rbt_insert(rbt, &doc_id, &doc_id);
}
static bool rbt_search_all(ib_rbt_t *rbt)
{
ib_rbt_bound_t parent;
for (const uint64_t &doc_id : doc_ids)
if (rbt_search(rbt, &parent, &doc_id))
return false;
return true;
}
static void rbt_test(comparator cmp, bool buggy)
{
ib_rbt_t *rbt= rbt_create(sizeof(uint64_t), cmp);
rbt_populate(rbt);
ok(rbt_search_all(rbt) != buggy, "search after populate");
rbt_free(rbt);
rbt= rbt_create(sizeof(uint64_t), cmp);
rbt_populate2(rbt);
ok(rbt_search_all(rbt) != buggy, "search after populate2");
rbt_free(rbt);
}
int main ()
{
rbt_test(fts_doc_id_buggy_cmp, true);
rbt_test(fts_doc_id_cmp, false);
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment