Commit 86b7ebbd authored by Julien Muchembled's avatar Julien Muchembled

storage: prevent 2 nodes from working with the same database

parent 8d42a2e6
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import struct, threading import os, errno, socket, struct, sys, threading
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from functools import wraps from functools import wraps
...@@ -55,6 +55,10 @@ class DatabaseManager(object): ...@@ -55,6 +55,10 @@ class DatabaseManager(object):
ENGINES = () ENGINES = ()
UNSAFE = False UNSAFE = False
__lock = None
LOCK = "neostorage"
LOCKED = "error: database is locked"
_deferred = 0 _deferred = 0
_duplicating = _repairing = None _duplicating = _repairing = None
...@@ -84,6 +88,7 @@ class DatabaseManager(object): ...@@ -84,6 +88,7 @@ class DatabaseManager(object):
def _duplicate(self): def _duplicate(self):
cls = self.__class__ cls = self.__class__
db = cls.__new__(cls) db = cls.__new__(cls)
db.LOCK = None
db._duplicating = self db._duplicating = self
try: try:
db._connect() db._connect()
...@@ -102,6 +107,26 @@ class DatabaseManager(object): ...@@ -102,6 +107,26 @@ class DatabaseManager(object):
def _connect(self): def _connect(self):
"""Connect to the database""" """Connect to the database"""
def lock(self, db_path):
if self.LOCK:
assert self.__lock is None, self.__lock
# For platforms that don't support anonymous sockets,
# we can either use zc.lockfile or an empty SQLite db
# (with BEGIN EXCLUSIVE).
try:
stat = os.stat(db_path)
except OSError as e:
if e.errno != errno.ENOENT:
raise
return # in-memory or temporary database
s = self.__lock = socket.socket(socket.AF_UNIX)
try:
s.bind('\0%s:%s:%s' % (self.LOCK, stat.st_dev, stat.st_ino))
except socket.error as e:
if e.errno != errno.EADDRINUSE:
raise
sys.exit(self.LOCKED)
@abstract @abstract
def erase(self): def erase(self):
"""""" """"""
...@@ -152,6 +177,9 @@ class DatabaseManager(object): ...@@ -152,6 +177,9 @@ class DatabaseManager(object):
def close(self): def close(self):
self._deferredCommit() self._deferredCommit()
self._close() self._close()
if self.__lock:
self.__lock.close()
del self.__lock
def _commit(self): def _commit(self):
"""Backend-specific code to commit the pending changes""" """Backend-specific code to commit the pending changes"""
......
...@@ -29,6 +29,7 @@ import os ...@@ -29,6 +29,7 @@ import os
import re import re
import string import string
import struct import struct
import sys
import time import time
from . import LOG_QUERIES from . import LOG_QUERIES
...@@ -102,9 +103,17 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -102,9 +103,17 @@ class MySQLDatabaseManager(DatabaseManager):
conn.autocommit(False) conn.autocommit(False)
conn.query("SET SESSION group_concat_max_len = %u" % (2**32-1)) conn.query("SET SESSION group_concat_max_len = %u" % (2**32-1))
conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION") conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION")
conn.query("SHOW VARIABLES WHERE variable_name='max_allowed_packet'") def query(sql):
r = conn.store_result() conn.query(sql)
(name, value), = r.fetch_row(r.num_rows()) r = conn.store_result()
return r.fetch_row(r.num_rows())
if self.LOCK:
(locked,), = query("SELECT GET_LOCK('%s.%s', 0)"
% (self.db, self.LOCK))
if not locked:
sys.exit(self.LOCKED)
(name, value), = query(
"SHOW VARIABLES WHERE variable_name='max_allowed_packet'")
if int(value) < self._max_allowed_packet: if int(value) < self._max_allowed_packet:
raise DatabaseFailure("Global variable %r is too small." raise DatabaseFailure("Global variable %r is too small."
" Minimal value must be %uk." " Minimal value must be %uk."
......
...@@ -78,6 +78,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -78,6 +78,7 @@ class SQLiteDatabaseManager(DatabaseManager):
def _connect(self): def _connect(self):
logging.info('connecting to SQLite database %r', self.db) logging.info('connecting to SQLite database %r', self.db)
self.conn = sqlite3.connect(self.db, check_same_thread=False) self.conn = sqlite3.connect(self.db, check_same_thread=False)
self.lock(self.db)
if self.UNSAFE: if self.UNSAFE:
q = self.query q = self.query
q("PRAGMA synchronous = OFF") q("PRAGMA synchronous = OFF")
......
...@@ -217,7 +217,8 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -217,7 +217,8 @@ class NeoUnitTestBase(NeoTestBase):
temp_dir = getTempDirectory() temp_dir = getTempDirectory()
for i in xrange(number): for i in xrange(number):
try: try:
os.remove(os.path.join(temp_dir, 'test_neo%s.sqlite' % i)) os.remove(os.path.join(temp_dir,
'%s%s.sqlite' % (prefix, i)))
except OSError, e: except OSError, e:
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
......
...@@ -37,10 +37,11 @@ from neo.lib import logging ...@@ -37,10 +37,11 @@ from neo.lib import logging
from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates, \ from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates, \
UUID_NAMESPACES UUID_NAMESPACES
from neo.lib.util import dump from neo.lib.util import dump
from .. import ADDRESS_TYPE, DB_SOCKET, DB_USER, IP_VERSION_FORMAT_DICT, SSL, \ from .. import (ADDRESS_TYPE, DB_SOCKET, DB_USER, IP_VERSION_FORMAT_DICT, SSL,
buildUrlFromString, cluster, getTempDirectory, NeoTestBase, setupMySQLdb buildUrlFromString, cluster, getTempDirectory, NeoTestBase, Patch,
setupMySQLdb)
from neo.client.Storage import Storage from neo.client.Storage import Storage
from neo.storage.database import buildDatabaseManager from neo.storage.database import manager, buildDatabaseManager
try: try:
coverage = sys.modules['neo.scripts.runner'].coverage coverage = sys.modules['neo.scripts.runner'].coverage
...@@ -483,7 +484,8 @@ class NEOCluster(object): ...@@ -483,7 +484,8 @@ class NEOCluster(object):
def getSQLConnection(self, db): def getSQLConnection(self, db):
assert db is not None and db in self.db_list assert db is not None and db in self.db_list
return buildDatabaseManager(self.adapter, (self.db_template(db),)) with Patch(manager.DatabaseManager, LOCK=None):
return buildDatabaseManager(self.adapter, (self.db_template(db),))
def getMasterProcessList(self): def getMasterProcessList(self):
return self.process_dict.get(NodeTypes.MASTER) return self.process_dict.get(NodeTypes.MASTER)
......
...@@ -131,6 +131,15 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -131,6 +131,15 @@ class StorageDBTests(NeoUnitTestBase):
def checkSet(self, list1, list2): def checkSet(self, list1, list2):
self.assertEqual(set(list1), set(list2)) self.assertEqual(set(list1), set(list2))
def _test_lockDatabase_open(self):
raise NotImplementedError
def test_lockDatabase(self):
db = self._test_lockDatabase_open()
self.assertRaises(SystemExit, self._test_lockDatabase_open)
db.close()
self._test_lockDatabase_open().close()
def test_getUnfinishedTIDDict(self): def test_getUnfinishedTIDDict(self):
tid1, tid2, tid3, tid4 = self.getTIDs(4) tid1, tid2, tid3, tid4 = self.getTIDs(4)
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
......
...@@ -29,11 +29,13 @@ class StorageMySQLdbTests(StorageDBTests): ...@@ -29,11 +29,13 @@ class StorageMySQLdbTests(StorageDBTests):
engine = None engine = None
def getDB(self, reset=0): def _test_lockDatabase_open(self):
self.prepareDatabase(number=1, prefix=DB_PREFIX) self.prepareDatabase(number=1, prefix=DB_PREFIX)
# db manager
database = '%s@%s0%s' % (DB_USER, DB_PREFIX, DB_SOCKET) database = '%s@%s0%s' % (DB_USER, DB_PREFIX, DB_SOCKET)
db = MySQLDatabaseManager(database, self.engine) return MySQLDatabaseManager(database, self.engine)
def getDB(self, reset=0):
db = self._test_lockDatabase_open()
self.assertEqual(db.db, DB_PREFIX + '0') self.assertEqual(db.db, DB_PREFIX + '0')
self.assertEqual(db.user, DB_USER) self.assertEqual(db.user, DB_USER)
try: try:
...@@ -129,11 +131,13 @@ class StorageMySQLdbTests(StorageDBTests): ...@@ -129,11 +131,13 @@ class StorageMySQLdbTests(StorageDBTests):
class StorageMySQLdbRocksDBTests(StorageMySQLdbTests): class StorageMySQLdbRocksDBTests(StorageMySQLdbTests):
engine = "RocksDB" engine = "RocksDB"
test_lockDatabase = None
class StorageMySQLdbTokuDBTests(StorageMySQLdbTests): class StorageMySQLdbTokuDBTests(StorageMySQLdbTests):
engine = "TokuDB" engine = "TokuDB"
test_lockDatabase = None
del StorageDBTests del StorageDBTests
......
...@@ -14,17 +14,29 @@ ...@@ -14,17 +14,29 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import os, unittest
from .. import getTempDirectory, DB_PREFIX
from .testStorageDBTests import StorageDBTests from .testStorageDBTests import StorageDBTests
from neo.storage.database.sqlite import SQLiteDatabaseManager from neo.storage.database.sqlite import SQLiteDatabaseManager
class StorageSQLiteTests(StorageDBTests): class StorageSQLiteTests(StorageDBTests):
def _test_lockDatabase_open(self):
db = os.path.join(getTempDirectory(), DB_PREFIX + '0.sqlite')
return SQLiteDatabaseManager(db)
def getDB(self, reset=0): def getDB(self, reset=0):
db = SQLiteDatabaseManager(':memory:') db = SQLiteDatabaseManager(':memory:')
db.setup(reset) db.setup(reset)
return db return db
def test_lockDatabase(self):
super(StorageSQLiteTests, self).test_lockDatabase()
# No lock on temporary databases.
db = self.getDB()
self.getDB().close()
db.close()
del StorageDBTests del StorageDBTests
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment