Commit e43ac760 authored by konstantin@mysql.com's avatar konstantin@mysql.com

Fixes for bugs #2274 "mysqld gets SIGSEGV during processing of malformed

COM_EXECUTE packet" and #2795 "prepare + execute without bind_param crashes
 server" and #2473 "seg fault running tests/client_test.c": 
- length checking added to packet parser 
- default impelemntation of Item_param::set_param_func will work in
case of malformed packet.
No test cases are possible in our test suite, as there are no tests 
operating on protocol layer.
parent 7e2bb11d
...@@ -517,7 +517,33 @@ String *Item_null::val_str(String *str) ...@@ -517,7 +517,33 @@ String *Item_null::val_str(String *str)
{ null_value=1; return 0;} { null_value=1; return 0;}
/* Item_param related */ /*********************** Item_param related ******************************/
/*
Default function of Item_param::set_param_func, so in case
of malformed packet the server won't SIGSEGV
*/
static void
default_set_param_func(Item_param *param,
uchar **pos __attribute__((unused)),
ulong len __attribute__((unused)))
{
param->set_null();
}
Item_param::Item_param(unsigned position) :
value_is_set(FALSE),
item_result_type(STRING_RESULT),
item_type(STRING_ITEM),
item_is_time(FALSE),
long_data_supplied(FALSE),
pos_in_query(position),
set_param_func(default_set_param_func)
{
name= (char*) "?";
}
void Item_param::set_null() void Item_param::set_null()
{ {
DBUG_ENTER("Item_param::set_null"); DBUG_ENTER("Item_param::set_null");
......
...@@ -348,16 +348,7 @@ class Item_param :public Item ...@@ -348,16 +348,7 @@ class Item_param :public Item
bool long_data_supplied; bool long_data_supplied;
uint pos_in_query; uint pos_in_query;
Item_param(uint position) Item_param(uint position);
{
name= (char*) "?";
pos_in_query= position;
item_type= STRING_ITEM;
item_result_type = STRING_RESULT;
item_is_time= false;
long_data_supplied= false;
value_is_set= 0;
}
enum Type type() const { return item_type; } enum Type type() const { return item_type; }
double val(); double val();
longlong val_int(); longlong val_int();
...@@ -374,11 +365,14 @@ class Item_param :public Item ...@@ -374,11 +365,14 @@ class Item_param :public Item
void set_time(TIME *tm, timestamp_type type); void set_time(TIME *tm, timestamp_type type);
bool get_time(TIME *tm); bool get_time(TIME *tm);
void reset() {} void reset() {}
#ifndef EMBEDDED_LIBRARY /*
void (*set_param_func)(Item_param *param, uchar **pos); Assign placeholder value from bind data.
#else Note, that 'len' has different semantics in embedded library (as we
void (*set_param_func)(Item_param *param, uchar **pos, ulong data_len); don't need to check that packet is not broken there). See
#endif sql_prepare.cc for details.
*/
void (*set_param_func)(Item_param *param, uchar **pos, ulong len);
enum Item_result result_type () const enum Item_result result_type () const
{ return item_result_type; } { return item_result_type; }
String *query_val_str(String *str); String *query_val_str(String *str);
......
...@@ -621,7 +621,7 @@ int mysqld_help (THD *thd, const char *text); ...@@ -621,7 +621,7 @@ int mysqld_help (THD *thd, const char *text);
/* sql_prepare.cc */ /* sql_prepare.cc */
void mysql_stmt_prepare(THD *thd, char *packet, uint packet_length); void mysql_stmt_prepare(THD *thd, char *packet, uint packet_length);
void mysql_stmt_execute(THD *thd, char *packet); void mysql_stmt_execute(THD *thd, char *packet, uint packet_length);
void mysql_stmt_free(THD *thd, char *packet); void mysql_stmt_free(THD *thd, char *packet);
void mysql_stmt_reset(THD *thd, char *packet); void mysql_stmt_reset(THD *thd, char *packet);
void mysql_stmt_get_longdata(THD *thd, char *pos, ulong packet_length); void mysql_stmt_get_longdata(THD *thd, char *pos, ulong packet_length);
......
...@@ -1405,7 +1405,7 @@ bool dispatch_command(enum enum_server_command command, THD *thd, ...@@ -1405,7 +1405,7 @@ bool dispatch_command(enum enum_server_command command, THD *thd,
} }
case COM_EXECUTE: case COM_EXECUTE:
{ {
mysql_stmt_execute(thd, packet); mysql_stmt_execute(thd, packet, packet_length);
break; break;
} }
case COM_LONG_DATA: case COM_LONG_DATA:
......
...@@ -94,7 +94,8 @@ class Prepared_statement: public Statement ...@@ -94,7 +94,8 @@ class Prepared_statement: public Statement
bool long_data_used; bool long_data_used;
bool log_full_query; bool log_full_query;
#ifndef EMBEDDED_LIBRARY #ifndef EMBEDDED_LIBRARY
bool (*set_params)(Prepared_statement *st, uchar *pos, uchar *read_pos); bool (*set_params)(Prepared_statement *st, uchar *data, uchar *data_end,
uchar *read_pos);
#else #else
bool (*set_params_data)(Prepared_statement *st); bool (*set_params_data)(Prepared_statement *st);
#endif #endif
...@@ -117,14 +118,6 @@ inline bool is_param_null(const uchar *pos, ulong param_no) ...@@ -117,14 +118,6 @@ inline bool is_param_null(const uchar *pos, ulong param_no)
enum { STMT_QUERY_LOG_LENGTH= 8192 }; enum { STMT_QUERY_LOG_LENGTH= 8192 };
#ifdef EMBEDDED_LIBRARY
#define SET_PARAM_FUNCTION(fn_name) \
static void fn_name(Item_param *param, uchar **pos, ulong data_len)
#else
#define SET_PARAM_FUNCTION(fn_name) \
static void fn_name(Item_param *param, uchar **pos)
#endif
enum enum_send_error { DONT_SEND_ERROR= 0, SEND_ERROR }; enum enum_send_error { DONT_SEND_ERROR= 0, SEND_ERROR };
/* /*
...@@ -186,29 +179,38 @@ static bool send_prep_stmt(Prepared_statement *stmt, ...@@ -186,29 +179,38 @@ static bool send_prep_stmt(Prepared_statement *stmt,
*/ */
#ifndef EMBEDDED_LIBRARY #ifndef EMBEDDED_LIBRARY
static ulong get_param_length(uchar **packet) static ulong get_param_length(uchar **packet, ulong len)
{ {
reg1 uchar *pos= *packet; reg1 uchar *pos= *packet;
if (len < 1)
return 0;
if (*pos < 251) if (*pos < 251)
{ {
(*packet)++; (*packet)++;
return (ulong) *pos; return (ulong) *pos;
} }
if (len < 3)
return 0;
if (*pos == 252) if (*pos == 252)
{ {
(*packet)+=3; (*packet)+=3;
return (ulong) uint2korr(pos+1); return (ulong) uint2korr(pos+1);
} }
if (len < 4)
return 0;
if (*pos == 253) if (*pos == 253)
{ {
(*packet)+=4; (*packet)+=4;
return (ulong) uint3korr(pos+1); return (ulong) uint3korr(pos+1);
} }
if (len < 5)
return 0;
(*packet)+=9; // Must be 254 when here (*packet)+=9; // Must be 254 when here
/* TODO: why uint4korr here? (should be uint8korr) */
return (ulong) uint4korr(pos+1); return (ulong) uint4korr(pos+1);
} }
#else #else
#define get_param_length(A) data_len #define get_param_length(packet, len) len
#endif /*!EMBEDDED_LIBRARY*/ #endif /*!EMBEDDED_LIBRARY*/
/* /*
...@@ -230,55 +232,80 @@ static ulong get_param_length(uchar **packet) ...@@ -230,55 +232,80 @@ static ulong get_param_length(uchar **packet)
none none
*/ */
SET_PARAM_FUNCTION(set_param_tiny) void set_param_tiny(Item_param *param, uchar **pos, ulong len)
{ {
#ifndef EMBEDDED_LIBRARY
if (len < 1)
return;
#endif
param->set_int((longlong)(**pos)); param->set_int((longlong)(**pos));
*pos+= 1; *pos+= 1;
} }
SET_PARAM_FUNCTION(set_param_short) void set_param_short(Item_param *param, uchar **pos, ulong len)
{ {
#ifndef EMBEDDED_LIBRARY
if (len < 2)
return;
#endif
param->set_int((longlong)sint2korr(*pos)); param->set_int((longlong)sint2korr(*pos));
*pos+= 2; *pos+= 2;
} }
SET_PARAM_FUNCTION(set_param_int32) void set_param_int32(Item_param *param, uchar **pos, ulong len)
{ {
#ifndef EMBEDDED_LIBRARY
if (len < 4)
return;
#endif
param->set_int((longlong)sint4korr(*pos)); param->set_int((longlong)sint4korr(*pos));
*pos+= 4; *pos+= 4;
} }
SET_PARAM_FUNCTION(set_param_int64) void set_param_int64(Item_param *param, uchar **pos, ulong len)
{ {
#ifndef EMBEDDED_LIBRARY
if (len < 8)
return;
#endif
param->set_int((longlong)sint8korr(*pos)); param->set_int((longlong)sint8korr(*pos));
*pos+= 8; *pos+= 8;
} }
SET_PARAM_FUNCTION(set_param_float) void set_param_float(Item_param *param, uchar **pos, ulong len)
{ {
#ifndef EMBEDDED_LIBRARY
if (len < 4)
return;
#endif
float data; float data;
float4get(data,*pos); float4get(data,*pos);
param->set_double((double) data); param->set_double((double) data);
*pos+= 4; *pos+= 4;
} }
SET_PARAM_FUNCTION(set_param_double) void set_param_double(Item_param *param, uchar **pos, ulong len)
{ {
#ifndef EMBEDDED_LIBRARY
if (len < 8)
return;
#endif
double data; double data;
float8get(data,*pos); float8get(data,*pos);
param->set_double((double) data); param->set_double((double) data);
*pos+= 8; *pos+= 8;
} }
SET_PARAM_FUNCTION(set_param_time) void set_param_time(Item_param *param, uchar **pos, ulong len)
{ {
ulong length; ulong length;
if ((length= get_param_length(pos))) if ((length= get_param_length(pos, len)) >= 8)
{ {
uchar *to= *pos; uchar *to= *pos;
TIME tm; TIME tm;
/* TODO: why length is compared with 8 here? */
tm.second_part= (length > 8 ) ? (ulong) sint4korr(to+7): 0; tm.second_part= (length > 8 ) ? (ulong) sint4korr(to+7): 0;
tm.day= (ulong) sint4korr(to+1); tm.day= (ulong) sint4korr(to+1);
...@@ -294,11 +321,11 @@ SET_PARAM_FUNCTION(set_param_time) ...@@ -294,11 +321,11 @@ SET_PARAM_FUNCTION(set_param_time)
*pos+= length; *pos+= length;
} }
SET_PARAM_FUNCTION(set_param_datetime) void set_param_datetime(Item_param *param, uchar **pos, ulong len)
{ {
uint length; uint length;
if ((length= get_param_length(pos))) if ((length= get_param_length(pos, len)) >= 4)
{ {
uchar *to= *pos; uchar *to= *pos;
TIME tm; TIME tm;
...@@ -324,11 +351,11 @@ SET_PARAM_FUNCTION(set_param_datetime) ...@@ -324,11 +351,11 @@ SET_PARAM_FUNCTION(set_param_datetime)
*pos+= length; *pos+= length;
} }
SET_PARAM_FUNCTION(set_param_date) void set_param_date(Item_param *param, uchar **pos, ulong len)
{ {
ulong length; ulong length;
if ((length= get_param_length(pos))) if ((length= get_param_length(pos, len)) >= 4)
{ {
uchar *to= *pos; uchar *to= *pos;
TIME tm; TIME tm;
...@@ -346,11 +373,11 @@ SET_PARAM_FUNCTION(set_param_date) ...@@ -346,11 +373,11 @@ SET_PARAM_FUNCTION(set_param_date)
*pos+= length; *pos+= length;
} }
SET_PARAM_FUNCTION(set_param_str) void set_param_str(Item_param *param, uchar **pos, ulong len)
{ {
ulong len= get_param_length(pos); ulong length= get_param_length(pos, len);
param->set_value((const char *)*pos, len); param->set_value((const char *)*pos, length);
*pos+= len; *pos+= length;
} }
static void setup_one_conversion_function(Item_param *param, uchar param_type) static void setup_one_conversion_function(Item_param *param, uchar param_type)
...@@ -405,8 +432,8 @@ static void setup_one_conversion_function(Item_param *param, uchar param_type) ...@@ -405,8 +432,8 @@ static void setup_one_conversion_function(Item_param *param, uchar param_type)
and if binary/update log is set, generate the valid query. and if binary/update log is set, generate the valid query.
*/ */
static bool insert_params_withlog(Prepared_statement *stmt, uchar *pos, static bool insert_params_withlog(Prepared_statement *stmt, uchar *null_array,
uchar *read_pos) uchar *read_pos, uchar *data_end)
{ {
THD *thd= stmt->thd; THD *thd= stmt->thd;
Item_param **begin= stmt->param_array; Item_param **begin= stmt->param_array;
...@@ -428,7 +455,7 @@ static bool insert_params_withlog(Prepared_statement *stmt, uchar *pos, ...@@ -428,7 +455,7 @@ static bool insert_params_withlog(Prepared_statement *stmt, uchar *pos,
res= param->query_val_str(&str); res= param->query_val_str(&str);
else else
{ {
if (is_param_null(pos, it - begin)) if (is_param_null(null_array, it - begin))
{ {
param->maybe_null= param->null_value= 1; param->maybe_null= param->null_value= 1;
res= &my_null_string; res= &my_null_string;
...@@ -436,7 +463,9 @@ static bool insert_params_withlog(Prepared_statement *stmt, uchar *pos, ...@@ -436,7 +463,9 @@ static bool insert_params_withlog(Prepared_statement *stmt, uchar *pos,
else else
{ {
param->maybe_null= param->null_value= 0; param->maybe_null= param->null_value= 0;
param->set_param_func(param, &read_pos); if (read_pos >= data_end)
DBUG_RETURN(1);
param->set_param_func(param, &read_pos, data_end - read_pos);
res= param->query_val_str(&str); res= param->query_val_str(&str);
} }
} }
...@@ -452,8 +481,8 @@ static bool insert_params_withlog(Prepared_statement *stmt, uchar *pos, ...@@ -452,8 +481,8 @@ static bool insert_params_withlog(Prepared_statement *stmt, uchar *pos,
} }
static bool insert_params(Prepared_statement *stmt, uchar *pos, static bool insert_params(Prepared_statement *stmt, uchar *null_array,
uchar *read_pos) uchar *read_pos, uchar *data_end)
{ {
Item_param **begin= stmt->param_array; Item_param **begin= stmt->param_array;
Item_param **end= begin + stmt->param_count; Item_param **end= begin + stmt->param_count;
...@@ -465,20 +494,23 @@ static bool insert_params(Prepared_statement *stmt, uchar *pos, ...@@ -465,20 +494,23 @@ static bool insert_params(Prepared_statement *stmt, uchar *pos,
Item_param *param= *it; Item_param *param= *it;
if (!param->long_data_supplied) if (!param->long_data_supplied)
{ {
if (is_param_null(pos, it - begin)) if (is_param_null(null_array, it - begin))
param->maybe_null= param->null_value= 1; param->maybe_null= param->null_value= 1;
else else
{ {
param->maybe_null= param->null_value= 0; param->maybe_null= param->null_value= 0;
param->set_param_func(param, &read_pos); if (read_pos >= data_end)
DBUG_RETURN(1);
param->set_param_func(param, &read_pos, data_end - read_pos);
} }
} }
} }
DBUG_RETURN(0); DBUG_RETURN(0);
} }
static bool setup_conversion_functions(Prepared_statement *stmt, static bool setup_conversion_functions(Prepared_statement *stmt,
uchar **data) uchar **data, uchar *data_end)
{ {
/* skip null bits */ /* skip null bits */
uchar *read_pos= *data + (stmt->param_count+7) / 8; uchar *read_pos= *data + (stmt->param_count+7) / 8;
...@@ -495,6 +527,8 @@ static bool setup_conversion_functions(Prepared_statement *stmt, ...@@ -495,6 +527,8 @@ static bool setup_conversion_functions(Prepared_statement *stmt,
Item_param **end= it + stmt->param_count; Item_param **end= it + stmt->param_count;
for (; it < end; ++it) for (; it < end; ++it)
{ {
if (read_pos >= data_end)
DBUG_RETURN(1);
setup_one_conversion_function(*it, *read_pos); setup_one_conversion_function(*it, *read_pos);
read_pos+= 2; read_pos+= 2;
} }
...@@ -1072,7 +1106,7 @@ static void reset_stmt_for_execute(Prepared_statement *stmt) ...@@ -1072,7 +1106,7 @@ static void reset_stmt_for_execute(Prepared_statement *stmt)
*/ */
void mysql_stmt_execute(THD *thd, char *packet) void mysql_stmt_execute(THD *thd, char *packet, uint packet_length)
{ {
ulong stmt_id= uint4korr(packet); ulong stmt_id= uint4korr(packet);
Prepared_statement *stmt; Prepared_statement *stmt;
...@@ -1097,10 +1131,11 @@ void mysql_stmt_execute(THD *thd, char *packet) ...@@ -1097,10 +1131,11 @@ void mysql_stmt_execute(THD *thd, char *packet)
#ifndef EMBEDDED_LIBRARY #ifndef EMBEDDED_LIBRARY
if (stmt->param_count) if (stmt->param_count)
{ {
uchar *packet_end= (uchar *) packet + packet_length - 1;
packet+= 4; packet+= 4;
uchar *null_array= (uchar *) packet; uchar *null_array= (uchar *) packet;
if (setup_conversion_functions(stmt, (uchar **) &packet) || if (setup_conversion_functions(stmt, (uchar **) &packet, packet_end) ||
stmt->set_params(stmt, null_array, (uchar *) packet)) stmt->set_params(stmt, null_array, (uchar *) packet, packet_end))
goto set_params_data_err; goto set_params_data_err;
} }
#else #else
...@@ -1159,6 +1194,7 @@ void mysql_stmt_execute(THD *thd, char *packet) ...@@ -1159,6 +1194,7 @@ void mysql_stmt_execute(THD *thd, char *packet)
void mysql_stmt_reset(THD *thd, char *packet) void mysql_stmt_reset(THD *thd, char *packet)
{ {
/* There is always space for 4 bytes in buffer */
ulong stmt_id= uint4korr(packet); ulong stmt_id= uint4korr(packet);
Prepared_statement *stmt; Prepared_statement *stmt;
...@@ -1189,6 +1225,7 @@ void mysql_stmt_reset(THD *thd, char *packet) ...@@ -1189,6 +1225,7 @@ void mysql_stmt_reset(THD *thd, char *packet)
void mysql_stmt_free(THD *thd, char *packet) void mysql_stmt_free(THD *thd, char *packet)
{ {
/* There is always space for 4 bytes in packet buffer */
ulong stmt_id= uint4korr(packet); ulong stmt_id= uint4korr(packet);
Prepared_statement *stmt; Prepared_statement *stmt;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment