diff --git a/src/MySQLdb/_mysql.c b/src/MySQLdb/_mysql.c index 30b111e5..085c01c2 100644 --- a/src/MySQLdb/_mysql.c +++ b/src/MySQLdb/_mysql.c @@ -84,6 +84,7 @@ typedef struct { bool open; bool reconnect; PyObject *converter; + PyThread_type_lock lock; } _mysql_ConnectionObject; #define check_connection(c) \ @@ -109,6 +110,67 @@ typedef struct { extern PyTypeObject _mysql_ResultObject_Type; +static int +_mysql_ConnectionObject_AllocateLock(_mysql_ConnectionObject *self) +{ + self->lock = PyThread_allocate_lock(); + if (self->lock == NULL) { + PyErr_NoMemory(); + return -1; + } + return 0; +} + +/* Try to acquire the connection lock without blocking. + * Returns 0 on success, -1 if the lock is already held by another thread + * (ProgrammingError is set). */ +static int +_mysql_ConnectionObject_Lock(_mysql_ConnectionObject *self) +{ + PyLockStatus status = PyThread_acquire_lock(self->lock, NOWAIT_LOCK); + if (status != PY_LOCK_ACQUIRED) { + PyErr_SetString(_mysql_ProgrammingError, + "This connection is already in use from another thread. " + "Do not use the same connection object from multiple threads simultaneously."); + return -1; + } + return 0; +} + +/* Blocking variant used only during deallocation where we must acquire the + * lock to call mysql_free_result() even if it means waiting. */ +static void +_mysql_ConnectionObject_LockWait(_mysql_ConnectionObject *self) +{ + Py_BEGIN_ALLOW_THREADS + PyThread_acquire_lock(self->lock, WAIT_LOCK); + Py_END_ALLOW_THREADS +} + +static void +_mysql_ConnectionObject_Unlock(_mysql_ConnectionObject *self) +{ + PyThread_release_lock(self->lock); +} + +#define BEGIN_CONNECTION_LOCK(c) _mysql_ConnectionObject_Lock(c) +#define END_CONNECTION_LOCK(c) _mysql_ConnectionObject_Unlock(c) + +#define BEGIN_RESULT_CONNECTION_LOCK(r) \ + _mysql_ConnectionObject_LockWait(result_connection(r)) +#define END_RESULT_CONNECTION_LOCK(r) \ + END_CONNECTION_LOCK(result_connection(r)) +#define BEGIN_CONNECTION_OPERATION(c, on_closed) \ + do { \ + if (_mysql_ConnectionObject_Lock(c) < 0) return NULL; \ + if (!(c)->open) { \ + END_CONNECTION_LOCK(c); \ + on_closed; \ + } \ + } while (0) +#define BEGIN_RESULT_OPERATION(r, on_closed) \ + BEGIN_CONNECTION_OPERATION(result_connection(r), on_closed) + PyObject * _mysql_Exception(_mysql_ConnectionObject *c) @@ -274,6 +336,14 @@ _mysql_ResultObject_Initialize( self->conn = (PyObject *) conn; Py_INCREF(conn); self->use = use; + if (BEGIN_CONNECTION_LOCK(conn) < 0) { + return -1; + } + if (!conn->open) { + END_CONNECTION_LOCK(conn); + _mysql_Exception(conn); + return -1; + } Py_BEGIN_ALLOW_THREADS ; if (use) result = mysql_use_result(&(conn->connection)); @@ -284,6 +354,7 @@ _mysql_ResultObject_Initialize( Py_END_ALLOW_THREADS ; self->encoding = _get_encoding(&(conn->connection)); + END_CONNECTION_LOCK(conn); //fprintf(stderr, "encoding=%s\n", self->encoding); if (!result) { if (mysql_errno(&(conn->connection))) { @@ -458,6 +529,7 @@ _mysql_ConnectionObject_Initialize( self->converter = NULL; self->open = false; self->reconnect = false; + self->lock = NULL; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ssssisOiiisssiOsiiissss:connect", @@ -479,6 +551,10 @@ _mysql_ConnectionObject_Initialize( )) return -1; + if (_mysql_ConnectionObject_AllocateLock(self)) { + return -1; + } + #ifndef HAVE_MYSQL_SERVER_PUBLIC_KEY if (server_public_key_path) { PyErr_SetString(_mysql_NotSupportedError, "server_public_key_path is not supported"); @@ -744,12 +820,13 @@ _mysql_ConnectionObject_close( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS mysql_close(&(self->connection)); Py_END_ALLOW_THREADS self->open = false; _mysql_ConnectionObject_clear(self); + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -764,8 +841,9 @@ _mysql_ConnectionObject_affected_rows( PyObject *noargs) { my_ulonglong ret; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); ret = mysql_affected_rows(&(self->connection)); + END_CONNECTION_LOCK(self); if (ret == (my_ulonglong)-1) return PyLong_FromLong(-1); return PyLong_FromUnsignedLongLong(ret); @@ -800,11 +878,16 @@ _mysql_ConnectionObject_dump_debug_info( PyObject *noargs) { int err; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS err = mysql_dump_debug_info(&(self->connection)); Py_END_ALLOW_THREADS - if (err) return _mysql_Exception(self); + if (err) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -818,11 +901,16 @@ _mysql_ConnectionObject_autocommit( { int flag, err; if (!PyArg_ParseTuple(args, "i", &flag)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS err = mysql_autocommit(&(self->connection), flag); Py_END_ALLOW_THREADS - if (err) return _mysql_Exception(self); + if (err) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -834,10 +922,12 @@ _mysql_ConnectionObject_get_autocommit( _mysql_ConnectionObject *self, PyObject *args) { - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); if (self->connection.server_status & SERVER_STATUS_AUTOCOMMIT) { + END_CONNECTION_LOCK(self); Py_RETURN_TRUE; } + END_CONNECTION_LOCK(self); Py_RETURN_FALSE; } @@ -850,11 +940,16 @@ _mysql_ConnectionObject_commit( PyObject *noargs) { int err; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS err = mysql_commit(&(self->connection)); Py_END_ALLOW_THREADS - if (err) return _mysql_Exception(self); + if (err) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -867,11 +962,16 @@ _mysql_ConnectionObject_rollback( PyObject *noargs) { int err; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS err = mysql_rollback(&(self->connection)); Py_END_ALLOW_THREADS - if (err) return _mysql_Exception(self); + if (err) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -894,11 +994,16 @@ _mysql_ConnectionObject_next_result( PyObject *noargs) { int err; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS err = mysql_next_result(&(self->connection)); Py_END_ALLOW_THREADS - if (err > 0) return _mysql_Exception(self); + if (err > 0) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); return PyLong_FromLong(err); } @@ -917,11 +1022,16 @@ _mysql_ConnectionObject_set_server_option( int err, flags=0; if (!PyArg_ParseTuple(args, "i", &flags)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS err = mysql_set_server_option(&(self->connection), flags); Py_END_ALLOW_THREADS - if (err) return _mysql_Exception(self); + if (err) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); return PyLong_FromLong(err); } @@ -942,8 +1052,11 @@ _mysql_ConnectionObject_sqlstate( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyUnicode_FromString(mysql_sqlstate(&(self->connection))); + PyObject *ret; + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + ret = PyUnicode_FromString(mysql_sqlstate(&(self->connection))); + END_CONNECTION_LOCK(self); + return ret; } static char _mysql_ConnectionObject_warning_count__doc__[] = @@ -957,8 +1070,11 @@ _mysql_ConnectionObject_warning_count( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyLong_FromLong(mysql_warning_count(&(self->connection))); + unsigned int count; + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + count = mysql_warning_count(&(self->connection)); + END_CONNECTION_LOCK(self); + return PyLong_FromLong(count); } static char _mysql_ConnectionObject_errno__doc__[] = @@ -972,8 +1088,11 @@ _mysql_ConnectionObject_errno( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyLong_FromLong((long)mysql_errno(&(self->connection))); + unsigned int err; + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + err = mysql_errno(&(self->connection)); + END_CONNECTION_LOCK(self); + return PyLong_FromLong((long)err); } static char _mysql_ConnectionObject_error__doc__[] = @@ -987,8 +1106,11 @@ _mysql_ConnectionObject_error( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyUnicode_FromString(mysql_error(&(self->connection))); + PyObject *ret; + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + ret = PyUnicode_FromString(mysql_error(&(self->connection))); + END_CONNECTION_LOCK(self); + return ret; } static char _mysql_escape_string__doc__[] = @@ -1008,6 +1130,7 @@ _mysql_escape_string( char *in, *out; unsigned long len; Py_ssize_t size; + int use_connection = 0; if (!PyArg_ParseTuple(args, "s#:escape_string", &in, &size)) return NULL; str = PyBytes_FromStringAndSize((char *) NULL, size*2+1); if (!str) return PyErr_NoMemory(); @@ -1015,7 +1138,14 @@ _mysql_escape_string( if (self && PyModule_Check((PyObject*)self)) self = NULL; - if (self && self->open) { + if (self) { + if (BEGIN_CONNECTION_LOCK(self) < 0) { + Py_DECREF(str); + return NULL; + } + use_connection = self->open; + } + if (use_connection) { #if MYSQL_VERSION_ID >= 50707 && !defined(MARIADB_BASE_VERSION) && !defined(MARIADB_VERSION_ID) len = mysql_real_escape_string_quote(&(self->connection), out, in, size, '\''); #else @@ -1024,6 +1154,9 @@ _mysql_escape_string( } else { len = mysql_escape_string(out, in, size); } + if (self) { + END_CONNECTION_LOCK(self); + } if (_PyBytes_Resize(&str, len) < 0) return NULL; return (str); } @@ -1044,9 +1177,15 @@ _mysql_string_literal( PyObject *o) { PyObject *s; // input string or bytes. need to decref. + int use_connection = 0; + PyObject *str = NULL; if (self && PyModule_Check((PyObject*)self)) self = NULL; + if (self) { + if (BEGIN_CONNECTION_LOCK(self) < 0) return NULL; + use_connection = self->open; + } if (PyBytes_Check(o)) { s = o; @@ -1054,9 +1193,9 @@ _mysql_string_literal( } else { PyObject *t = PyObject_Str(o); - if (!t) return NULL; + if (!t) goto error; - const char *encoding = (self && self->open) ? + const char *encoding = use_connection ? _get_encoding(&self->connection) : utf8; if (encoding == utf8) { s = t; @@ -1064,7 +1203,7 @@ _mysql_string_literal( else { s = PyUnicode_AsEncodedString(t, encoding, "strict"); Py_DECREF(t); - if (!s) return NULL; + if (!s) goto error; } } @@ -1073,6 +1212,10 @@ _mysql_string_literal( Py_ssize_t size; if (PyUnicode_Check(s)) { in = PyUnicode_AsUTF8AndSize(s, &size); + if (!in) { + Py_DECREF(s); + goto error; + } } else { assert(PyBytes_Check(s)); in = PyBytes_AsString(s); @@ -1080,16 +1223,17 @@ _mysql_string_literal( } // Prepare output buffer (str, out) - PyObject *str = PyBytes_FromStringAndSize((char *) NULL, size*2+3); + str = PyBytes_FromStringAndSize((char *) NULL, size*2+3); if (!str) { Py_DECREF(s); - return PyErr_NoMemory(); + PyErr_NoMemory(); + goto error; } char *out = PyBytes_AS_STRING(str); // escape unsigned long len; - if (self && self->open) { + if (use_connection) { #if MYSQL_VERSION_ID >= 50707 && !defined(MARIADB_BASE_VERSION) && !defined(MARIADB_VERSION_ID) len = mysql_real_escape_string_quote(&(self->connection), out+1, in, size, '\''); #else @@ -1100,12 +1244,20 @@ _mysql_string_literal( } Py_DECREF(s); + if (self) { + END_CONNECTION_LOCK(self); + } *out = *(out+len+1) = '\''; if (_PyBytes_Resize(&str, len+2) < 0) { Py_DECREF(str); return NULL; } return str; +error: + if (self) { + END_CONNECTION_LOCK(self); + } + return NULL; } static PyObject * @@ -1142,6 +1294,7 @@ _mysql_escape( PyObject *args) { PyObject *o=NULL, *d=NULL; + PyObject *converter = NULL; if (!PyArg_ParseTuple(args, "O|O:escape", &o, &d)) return NULL; if (d) { @@ -1152,13 +1305,21 @@ _mysql_escape( } return _escape_item(self, o, d); } else { - if (!self) { + if (!self || PyModule_Check(self)) { PyErr_SetString(PyExc_TypeError, "argument 2 must be a mapping"); return NULL; } - return _escape_item(self, o, - ((_mysql_ConnectionObject *) self)->converter); + if (BEGIN_CONNECTION_LOCK((_mysql_ConnectionObject *)self) < 0) return NULL; + converter = ((_mysql_ConnectionObject *) self)->converter; + Py_XINCREF(converter); + END_CONNECTION_LOCK((_mysql_ConnectionObject *)self); + if (!converter) { + return _mysql_Exception((_mysql_ConnectionObject *)self); + } + PyObject *ret = _escape_item(self, o, converter); + Py_DECREF(converter); + return ret; } } @@ -1176,7 +1337,7 @@ _mysql_ResultObject_describe( MYSQL_FIELD *fields; unsigned int i, n; - check_result_connection(self); + BEGIN_RESULT_OPERATION(self, return _mysql_Exception(result_connection(self))); n = mysql_num_fields(self->result); fields = mysql_fetch_fields(self->result); @@ -1204,8 +1365,10 @@ _mysql_ResultObject_describe( if (!t) goto error; PyTuple_SET_ITEM(d, i, t); } + END_RESULT_CONNECTION_LOCK(self); return d; error: + END_RESULT_CONNECTION_LOCK(self); Py_XDECREF(d); return NULL; } @@ -1222,7 +1385,7 @@ _mysql_ResultObject_field_flags( PyObject *d; MYSQL_FIELD *fields; unsigned int i, n; - check_result_connection(self); + BEGIN_RESULT_OPERATION(self, return _mysql_Exception(result_connection(self))); n = mysql_num_fields(self->result); fields = mysql_fetch_fields(self->result); if (!(d = PyTuple_New(n))) return NULL; @@ -1231,8 +1394,10 @@ _mysql_ResultObject_field_flags( if (!(f = PyLong_FromLong((long)fields[i].flags))) goto error; PyTuple_SET_ITEM(d, i, f); } + END_RESULT_CONNECTION_LOCK(self); return d; error: + END_RESULT_CONNECTION_LOCK(self); Py_XDECREF(d); return NULL; } @@ -1539,8 +1704,9 @@ _mysql_ResultObject_fetch_row( if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ii:fetch_row", kwlist, &maxrows, &how)) return NULL; - check_result_connection(self); + BEGIN_RESULT_OPERATION(self, return _mysql_Exception(result_connection(self))); if (how >= (int)(sizeof(row_converters) / sizeof(row_converters[0]))) { + END_RESULT_CONNECTION_LOCK(self); PyErr_SetString(PyExc_ValueError, "how out of range"); return NULL; } @@ -1561,8 +1727,10 @@ _mysql_ResultObject_fetch_row( */ PyObject *t = PyList_AsTuple(r); Py_DECREF(r); + END_RESULT_CONNECTION_LOCK(self); return t; error: + END_RESULT_CONNECTION_LOCK(self); Py_XDECREF(r); return NULL; } @@ -1575,7 +1743,7 @@ _mysql_ResultObject_discard( _mysql_ResultObject *self, PyObject *noargs) { - check_result_connection(self); + BEGIN_RESULT_OPERATION(self, return _mysql_Exception(result_connection(self))); MYSQL_ROW row; Py_BEGIN_ALLOW_THREADS @@ -1585,8 +1753,11 @@ _mysql_ResultObject_discard( Py_END_ALLOW_THREADS _mysql_ConnectionObject *conn = (_mysql_ConnectionObject *)self->conn; if (mysql_errno(&conn->connection)) { - return _mysql_Exception(conn); + PyObject *ret = _mysql_Exception(conn); + END_RESULT_CONNECTION_LOCK(self); + return ret; } + END_RESULT_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -1620,11 +1791,16 @@ _mysql_ConnectionObject_change_user( if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|ss:change_user", kwlist, &user, &pwd, &db)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS r = mysql_change_user(&(self->connection), user, pwd, db); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -1639,9 +1815,13 @@ _mysql_ConnectionObject_character_set_name( PyObject *noargs) { const char *s; - check_connection(self); + PyObject *ret; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); s = mysql_character_set_name(&(self->connection)); - return PyUnicode_FromString(s); + ret = PyUnicode_FromString(s); + END_CONNECTION_LOCK(self); + return ret; } static char _mysql_ConnectionObject_set_character_set__doc__[] = @@ -1656,12 +1836,18 @@ _mysql_ConnectionObject_set_character_set( { const char *s; int err; + if (!PyArg_ParseTuple(args, "s", &s)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS err = mysql_set_character_set(&(self->connection), s); Py_END_ALLOW_THREADS - if (err) return _mysql_Exception(self); + if (err) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -1695,9 +1881,13 @@ _mysql_ConnectionObject_get_character_set_info( PyObject *result; MY_CHARSET_INFO cs; - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); mysql_get_character_set_info(&(self->connection), &cs); - if (!(result = PyDict_New())) return NULL; + if (!(result = PyDict_New())) { + END_CONNECTION_LOCK(self); + return NULL; + } if (cs.csname) PyDict_SetItemString(result, "name", PyUnicode_FromString(cs.csname)); if (cs.name) @@ -1708,6 +1898,7 @@ _mysql_ConnectionObject_get_character_set_info( PyDict_SetItemString(result, "dir", PyUnicode_FromString(cs.dir)); PyDict_SetItemString(result, "mbminlen", PyLong_FromLong(cs.mbminlen)); PyDict_SetItemString(result, "mbmaxlen", PyLong_FromLong(cs.mbmaxlen)); + END_CONNECTION_LOCK(self); return result; } #endif @@ -1727,9 +1918,11 @@ _mysql_ConnectionObject_get_native_connection( PyObject *noargs) { PyObject *result; - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); result = PyCapsule_New(&(self->connection), "_mysql.connection.native_connection", NULL); + END_CONNECTION_LOCK(self); return result; } @@ -1755,8 +1948,12 @@ _mysql_ConnectionObject_get_host_info( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyUnicode_FromString(mysql_get_host_info(&(self->connection))); + PyObject *ret; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + ret = PyUnicode_FromString(mysql_get_host_info(&(self->connection))); + END_CONNECTION_LOCK(self); + return ret; } static char _mysql_ConnectionObject_get_proto_info__doc__[] = @@ -1769,8 +1966,12 @@ _mysql_ConnectionObject_get_proto_info( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyLong_FromLong((long)mysql_get_proto_info(&(self->connection))); + unsigned int proto; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + proto = mysql_get_proto_info(&(self->connection)); + END_CONNECTION_LOCK(self); + return PyLong_FromLong((long)proto); } static char _mysql_ConnectionObject_get_server_info__doc__[] = @@ -1783,8 +1984,12 @@ _mysql_ConnectionObject_get_server_info( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyUnicode_FromString(mysql_get_server_info(&(self->connection))); + PyObject *ret; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + ret = PyUnicode_FromString(mysql_get_server_info(&(self->connection))); + END_CONNECTION_LOCK(self); + return ret; } static char _mysql_ConnectionObject_info__doc__[] = @@ -1799,9 +2004,16 @@ _mysql_ConnectionObject_info( PyObject *noargs) { const char *s; - check_connection(self); + PyObject *ret; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); s = mysql_info(&(self->connection)); - if (s) return PyUnicode_FromString(s); + if (s) { + ret = PyUnicode_FromString(s); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -1832,8 +2044,10 @@ _mysql_ConnectionObject_insert_id( PyObject *noargs) { my_ulonglong r; - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); r = mysql_insert_id(&(self->connection)); + END_CONNECTION_LOCK(self); return PyLong_FromUnsignedLongLong(r); } @@ -1849,13 +2063,19 @@ _mysql_ConnectionObject_kill( unsigned long pid; int r; char query[50]; + if (!PyArg_ParseTuple(args, "k:kill", &pid)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); snprintf(query, 50, "KILL %lu", pid); Py_BEGIN_ALLOW_THREADS r = mysql_query(&(self->connection), query); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -1870,8 +2090,12 @@ _mysql_ConnectionObject_field_count( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyLong_FromLong((long)mysql_field_count(&(self->connection))); + unsigned int count; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + count = mysql_field_count(&(self->connection)); + END_CONNECTION_LOCK(self); + return PyLong_FromLong((long)count); } static char _mysql_ConnectionObject_fileno__doc__[] = @@ -1884,8 +2108,12 @@ _mysql_ConnectionObject_fileno( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyLong_FromLong(self->connection.net.fd); + int fd; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + fd = self->connection.net.fd; + END_CONNECTION_LOCK(self); + return PyLong_FromLong(fd); } static char _mysql_ResultObject_num_fields__doc__[] = @@ -1896,8 +2124,12 @@ _mysql_ResultObject_num_fields( _mysql_ResultObject *self, PyObject *noargs) { - check_result_connection(self); - return PyLong_FromLong((long)mysql_num_fields(self->result)); + unsigned int fields; + + BEGIN_RESULT_OPERATION(self, return _mysql_Exception(result_connection(self))); + fields = mysql_num_fields(self->result); + END_RESULT_CONNECTION_LOCK(self); + return PyLong_FromLong((long)fields); } static char _mysql_ResultObject_num_rows__doc__[] = @@ -1911,8 +2143,12 @@ _mysql_ResultObject_num_rows( _mysql_ResultObject *self, PyObject *noargs) { - check_result_connection(self); - return PyLong_FromUnsignedLongLong(mysql_num_rows(self->result)); + my_ulonglong rows; + + BEGIN_RESULT_OPERATION(self, return _mysql_Exception(result_connection(self))); + rows = mysql_num_rows(self->result); + END_RESULT_CONNECTION_LOCK(self); + return PyLong_FromUnsignedLongLong(rows); } static char _mysql_ConnectionObject_ping__doc__[] = @@ -1940,8 +2176,9 @@ _mysql_ConnectionObject_ping( PyObject *args) { int reconnect = 0; + if (!PyArg_ParseTuple(args, "|p", &reconnect)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); if (reconnect != (self->reconnect == true)) { // libmysqlclient show warning to stderr when MYSQL_OPT_RECONNECT is used. // so we avoid using it as possible for now. @@ -1956,7 +2193,12 @@ _mysql_ConnectionObject_ping( Py_BEGIN_ALLOW_THREADS r = mysql_ping(&(self->connection)); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -1974,13 +2216,19 @@ _mysql_ConnectionObject_query( char *query; Py_ssize_t len; int r; + if (!PyArg_ParseTuple(args, "s#:query", &query, &len)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS r = mysql_real_query(&(self->connection), query, len); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -1998,13 +2246,19 @@ _mysql_ConnectionObject_send_query( Py_ssize_t len; int r; MYSQL *mysql = &(self->connection); + if (!PyArg_ParseTuple(args, "s#:query", &query, &len)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS r = mysql_send_query(mysql, query, len); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -2019,12 +2273,18 @@ _mysql_ConnectionObject_read_query_result( { int r; MYSQL *mysql = &(self->connection); - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS r = (int)mysql_read_query_result(mysql); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -2047,12 +2307,18 @@ _mysql_ConnectionObject_select_db( { char *db; int r; + if (!PyArg_ParseTuple(args, "s:select_db", &db)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS r = mysql_select_db(&(self->connection), db); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -2067,11 +2333,17 @@ _mysql_ConnectionObject_shutdown( PyObject *noargs) { int r; - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS r = mysql_query(&(self->connection), "SHUTDOWN"); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -2088,12 +2360,20 @@ _mysql_ConnectionObject_stat( PyObject *noargs) { const char *s; - check_connection(self); + PyObject *ret; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS s = mysql_stat(&(self->connection)); Py_END_ALLOW_THREADS - if (!s) return _mysql_Exception(self); - return PyUnicode_FromString(s); + if (!s) { + ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + ret = PyUnicode_FromString(s); + END_CONNECTION_LOCK(self); + return ret; } static char _mysql_ConnectionObject_store_result__doc__[] = @@ -2110,8 +2390,10 @@ _mysql_ConnectionObject_store_result( PyObject *arglist=NULL, *kwarglist=NULL, *result=NULL; _mysql_ResultObject *r=NULL; - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); arglist = Py_BuildValue("(OiO)", self, 0, self->converter); + END_CONNECTION_LOCK(self); if (!arglist) goto error; kwarglist = PyDict_New(); if (!kwarglist) goto error; @@ -2150,8 +2432,10 @@ _mysql_ConnectionObject_thread_id( PyObject *noargs) { unsigned long pid; - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); pid = mysql_thread_id(&(self->connection)); + END_CONNECTION_LOCK(self); return PyLong_FromLong((long)pid); } @@ -2169,8 +2453,10 @@ _mysql_ConnectionObject_use_result( PyObject *arglist=NULL, *kwarglist=NULL, *result=NULL; _mysql_ResultObject *r=NULL; - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); arglist = Py_BuildValue("(OiO)", self, 1, self->converter); + END_CONNECTION_LOCK(self); if (!arglist) return NULL; kwarglist = PyDict_New(); if (!kwarglist) goto error; @@ -2201,31 +2487,31 @@ _mysql_ConnectionObject_discard_result( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); + MYSQL_RES *res; + MYSQL_ROW row; + int err = 0; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); MYSQL *conn = &(self->connection); Py_BEGIN_ALLOW_THREADS; - MYSQL_RES *res = mysql_use_result(conn); - if (res == NULL) { - Py_BLOCK_THREADS; - if (mysql_errno(conn) != 0) { - // fprintf(stderr, "mysql_use_result failed: %s\n", mysql_error(conn)); - return _mysql_Exception(self); + res = mysql_use_result(conn); + if (res != NULL) { + while (NULL != (row = mysql_fetch_row(res))) { + // do nothing. } - Py_RETURN_NONE; - } - - MYSQL_ROW row; - while (NULL != (row = mysql_fetch_row(res))) { - // do nothing. + mysql_free_result(res); } - mysql_free_result(res); Py_END_ALLOW_THREADS; - if (mysql_errno(conn)) { + err = mysql_errno(conn); + if (err) { // fprintf(stderr, "mysql_free_result failed: %s\n", mysql_error(conn)); - return _mysql_Exception(self); + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -2239,6 +2525,10 @@ _mysql_ConnectionObject_dealloc( self->open = false; } Py_CLEAR(self->converter); + if (self->lock != NULL) { + PyThread_free_lock(self->lock); + self->lock = NULL; + } MyFree(self); } @@ -2263,9 +2553,11 @@ _mysql_ResultObject_data_seek( PyObject *args) { unsigned int row; + if (!PyArg_ParseTuple(args, "i:data_seek", &row)) return NULL; - check_result_connection(self); + BEGIN_RESULT_OPERATION(self, return _mysql_Exception(result_connection(self))); mysql_data_seek(self->result, row); + END_RESULT_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -2273,8 +2565,15 @@ static void _mysql_ResultObject_dealloc( _mysql_ResultObject *self) { + PyObject *conn = self->conn; PyObject_GC_UnTrack((PyObject *)self); + if (conn != NULL) { + BEGIN_RESULT_CONNECTION_LOCK(self); + } mysql_free_result(self->result); + if (conn != NULL) { + END_RESULT_CONNECTION_LOCK(self); + } _mysql_ResultObject_clear(self); MyFree(self); } diff --git a/tests/test_connection.py b/tests/test_connection.py index 960de572..f707f60c 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,3 +1,6 @@ +import threading +import time + import pytest from MySQLdb._exceptions import ProgrammingError @@ -24,3 +27,80 @@ def test_multi_statements_false(): cursor.execute("select 17") rows = cursor.fetchall() assert rows == ((17,),) + + +def test_connection_concurrent_use_raises(): + """While a slow query holds the connection lock, any other access from + a second thread must raise ProgrammingError immediately (not block).""" + conn = connection_factory() + try: + thread_error = None + done = threading.Event() + + def run_slow_query(): + nonlocal thread_error + try: + conn.query("SELECT SLEEP(0.5)") + result = conn.store_result() + result.fetch_row() + except Exception as exc: # pragma: no cover + thread_error = exc + finally: + done.set() + + thread = threading.Thread(target=run_slow_query) + thread.start() + + # Give the background thread time to acquire the lock and enter SLEEP. + time.sleep(0.1) + + start = time.monotonic() + with pytest.raises(ProgrammingError, match="already in use"): + conn.thread_id() + # Should fail immediately, not wait for the SLEEP to finish. + assert time.monotonic() - start < 0.1 + + done.wait() + thread.join() + assert thread_error is None + finally: + conn.close() + + +def test_result_concurrent_use_raises(): + """While fetch_row holds the connection lock streaming a slow result, + any other access from a second thread must raise ProgrammingError immediately.""" + conn = connection_factory() + try: + conn.query("SELECT 1 UNION ALL SELECT SLEEP(0.5)") + result = conn.use_result() + + thread_error = None + done = threading.Event() + + def fetch_all_rows(): + nonlocal thread_error + try: + assert result.fetch_row(maxrows=0) == ((1,), (0,)) + except Exception as exc: # pragma: no cover + thread_error = exc + finally: + done.set() + + thread = threading.Thread(target=fetch_all_rows) + thread.start() + + # Give the background thread time to acquire the lock and enter SLEEP. + time.sleep(0.1) + + start = time.monotonic() + with pytest.raises(ProgrammingError, match="already in use"): + conn.thread_id() + # Should fail immediately, not wait for the SLEEP to finish. + assert time.monotonic() - start < 0.1 + + done.wait() + thread.join() + assert thread_error is None + finally: + conn.close()