Ticket #17258: 17258.thread-local-connections.4.diff

File 17258.thread-local-connections.4.diff, 12.7 KB (added by Julien Phalip, 13 years ago)

Similar approach to pysqlite's check_same_thread

  • django/db/__init__.py

    diff --git a/django/db/__init__.py b/django/db/__init__.py
    index 8395468..30e1c3d 100644
    a b router = ConnectionRouter(settings.DATABASE_ROUTERS)  
    2222# we manually create the dictionary from the settings, passing only the
    2323# settings that the database backends care about. Note that TIME_ZONE is used
    2424# by the PostgreSQL backends.
    25 # we load all these up for backwards compatibility, you should use
     25# We load all these up for backwards compatibility, you should use
    2626# connections['default'] instead.
    27 connection = connections[DEFAULT_DB_ALIAS]
     27class DefaultConnectionProxy(object):
     28    """
     29    Proxy for the thread-local default connection.
     30    """
     31    def __getattr__(self, item):
     32        return getattr(connections[DEFAULT_DB_ALIAS], item)
     33
     34    def __setattr__(self, name, value):
     35        return setattr(connections[DEFAULT_DB_ALIAS], name, value)
     36
     37connection = DefaultConnectionProxy()
    2838backend = load_backend(connection.settings_dict['ENGINE'])
    2939
    3040# Register an event that closes the database connection
  • django/db/backends/__init__.py

    diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
    index f2bde84..460746b 100644
    a b  
     1from django.db.utils import DatabaseError
     2
    13try:
    24    import thread
    35except ImportError:
    46    import dummy_thread as thread
    5 from threading import local
    67from contextlib import contextmanager
    78
    89from django.conf import settings
    from django.utils.importlib import import_module  
    1314from django.utils.timezone import is_aware
    1415
    1516
    16 class BaseDatabaseWrapper(local):
     17class BaseDatabaseWrapper(object):
    1718    """
    1819    Represents a database connection.
    1920    """
    2021    ops = None
    2122    vendor = 'unknown'
    2223
    23     def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
     24    def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS,
     25                 check_same_thread=True):
    2426        # `settings_dict` should be a dictionary containing keys such as
    2527        # NAME, USER, etc. It's called `settings_dict` instead of `settings`
    2628        # to disambiguate it from Django settings modules.
    class BaseDatabaseWrapper(local):  
    3436        self.transaction_state = []
    3537        self.savepoint_state = 0
    3638        self._dirty = None
     39        self.thread_id = thread.get_ident()
     40        self.check_same_thread = check_same_thread
    3741
    3842    def __eq__(self, other):
    3943        return self.alias == other.alias
    class BaseDatabaseWrapper(local):  
    116120                "pending COMMIT/ROLLBACK")
    117121        self._dirty = False
    118122
     123    def validate_thread(self):
     124        if self.check_same_thread:
     125            current_thread_id = thread.get_ident()
     126            if current_thread_id != self.thread_id:
     127                raise DatabaseError ("DatabaseWrapper objects created in a "
     128                    "thread can only be used in that same thread, unless you "
     129                    "explicitly set its 'check_same_thread' property to "
     130                    "False. The object was created in thread id %s and this "
     131                    "is thread id %s" % (self.thread_id, current_thread_id))
     132        return True
     133
    119134    def is_dirty(self):
    120135        """
    121136        Returns True if the current transaction requires a commit for changes to
    class BaseDatabaseWrapper(local):  
    179194        """
    180195        Commits changes if the system is not in managed transaction mode.
    181196        """
     197        self.validate_thread()
    182198        if not self.is_managed():
    183199            self._commit()
    184200            self.clean_savepoints()
    class BaseDatabaseWrapper(local):  
    189205        """
    190206        Rolls back changes if the system is not in managed transaction mode.
    191207        """
     208        self.validate_thread()
    192209        if not self.is_managed():
    193210            self._rollback()
    194211        else:
    class BaseDatabaseWrapper(local):  
    198215        """
    199216        Does the commit itself and resets the dirty flag.
    200217        """
     218        self.validate_thread()
    201219        self._commit()
    202220        self.set_clean()
    203221
    class BaseDatabaseWrapper(local):  
    205223        """
    206224        This function does the rollback itself and resets the dirty flag.
    207225        """
     226        self.validate_thread()
    208227        self._rollback()
    209228        self.set_clean()
    210229
    class BaseDatabaseWrapper(local):  
    228247        Rolls back the most recent savepoint (if one exists). Does nothing if
    229248        savepoints are not supported.
    230249        """
     250        self.validate_thread()
    231251        if self.savepoint_state:
    232252            self._savepoint_rollback(sid)
    233253
    class BaseDatabaseWrapper(local):  
    236256        Commits the most recent savepoint (if one exists). Does nothing if
    237257        savepoints are not supported.
    238258        """
     259        self.validate_thread()
    239260        if self.savepoint_state:
    240261            self._savepoint_commit(sid)
    241262
    class BaseDatabaseWrapper(local):  
    269290        pass
    270291
    271292    def close(self):
     293        self.validate_thread()
    272294        if self.connection is not None:
    273295            self.connection.close()
    274296            self.connection = None
    275297
    276298    def cursor(self):
     299        self.validate_thread()
    277300        if (self.use_debug_cursor or
    278301            (self.use_debug_cursor is None and settings.DEBUG)):
    279302            cursor = self.make_debug_cursor(self._cursor())
  • django/db/backends/sqlite3/base.py

    diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py
    index a610606..4a96c7c 100644
    a b class DatabaseWrapper(BaseDatabaseWrapper):  
    220220        'iendswith': "LIKE %s ESCAPE '\\'",
    221221    }
    222222
    223     def __init__(self, *args, **kwargs):
    224         super(DatabaseWrapper, self).__init__(*args, **kwargs)
     223    def __init__(self, settings_dict, *args, **kwargs):
     224        super(DatabaseWrapper, self).__init__(settings_dict, *args, **kwargs)
    225225
    226226        self.features = DatabaseFeatures(self)
    227227        self.ops = DatabaseOperations(self)
    class DatabaseWrapper(BaseDatabaseWrapper):  
    229229        self.creation = DatabaseCreation(self)
    230230        self.introspection = DatabaseIntrospection(self)
    231231        self.validation = BaseDatabaseValidation(self)
     232        self.check_same_thread = (
     233            settings_dict['OPTIONS'].get('check_same_thread'), True)
    232234
    233235    def _cursor(self):
    234236        if self.connection is None:
    class DatabaseWrapper(BaseDatabaseWrapper):  
    239241            kwargs = {
    240242                'database': settings_dict['NAME'],
    241243                'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
     244                'check_same_thread': 1 if self.check_same_thread else 0,
    242245            }
    243246            kwargs.update(settings_dict['OPTIONS'])
    244247            self.connection = Database.connect(**kwargs)
  • django/db/utils.py

    diff --git a/django/db/utils.py b/django/db/utils.py
    index f0c13e3..41ad6df 100644
    a b  
    11import os
     2from threading import local
    23
    34from django.conf import settings
    45from django.core.exceptions import ImproperlyConfigured
    class ConnectionDoesNotExist(Exception):  
    5051class ConnectionHandler(object):
    5152    def __init__(self, databases):
    5253        self.databases = databases
    53         self._connections = {}
     54        self._connections = local()
    5455
    5556    def ensure_defaults(self, alias):
    5657        """
    class ConnectionHandler(object):  
    7374            conn.setdefault(setting, None)
    7475
    7576    def __getitem__(self, alias):
    76         if alias in self._connections:
    77             return self._connections[alias]
     77        if hasattr(self._connections, alias):
     78            return getattr(self._connections, alias)
    7879
    7980        self.ensure_defaults(alias)
    8081        db = self.databases[alias]
    8182        backend = load_backend(db['ENGINE'])
    8283        conn = backend.DatabaseWrapper(db, alias)
    83         self._connections[alias] = conn
     84        setattr(self._connections, alias, conn)
    8485        return conn
    8586
     87    def __setitem__(self, key, value):
     88        setattr(self._connections, key, value)
     89
    8690    def __iter__(self):
    8791        return iter(self.databases)
    8892
  • tests/regressiontests/backends/tests.py

    diff --git a/tests/regressiontests/backends/tests.py b/tests/regressiontests/backends/tests.py
    index 936f010..998e7ce 100644
    a b  
    33from __future__ import with_statement, absolute_import
    44
    55import datetime
     6import threading
    67
    78from django.conf import settings
    89from django.core.management.color import no_style
    class ConnectionCreatedSignalTest(TestCase):  
    283284        connection_created.connect(receiver)
    284285        connection.close()
    285286        cursor = connection.cursor()
    286         self.assertTrue(data["connection"] is connection)
     287        self.assertTrue(data["connection"].connection is connection.connection)
    287288
    288289        connection_created.disconnect(receiver)
    289290        data.clear()
    class FkConstraintsTests(TransactionTestCase):  
    446447                        connection.check_constraints()
    447448            finally:
    448449                transaction.rollback()
     450
     451class ThreadTests(TestCase):
     452
     453    def test_default_connection_thread_local(self):
     454        """
     455        Ensure that the default connection (i.e. django.db.connection) is
     456        different for each thread.
     457        Refs #17258.
     458        """
     459        connections_set = set()
     460        connection.cursor()
     461        connections_set.add(connection.connection)
     462        def runner():
     463            from django.db import connection
     464            connection.cursor()
     465            connections_set.add(connection.connection)
     466        for x in xrange(2):
     467            t = threading.Thread(target=runner)
     468            t.start()
     469            t.join()
     470        self.assertEquals(len(connections_set), 3)
     471        # Finish by closing the connections opened by the other threads (the
     472        # connection opened in the main thread will automatically be closed on
     473        # teardown).
     474        for conn in connections_set:
     475            if conn != connection.connection:
     476                conn.close()
     477
     478    def test_connections_thread_local(self):
     479        """
     480        Ensure that the connections are different for each thread.
     481        Refs #17258.
     482        """
     483        connections_set = set()
     484        for conn in connections.all():
     485            connections_set.add(conn)
     486        def runner():
     487            from django.db import connections
     488            for conn in connections.all():
     489                connections_set.add(conn)
     490        for x in xrange(2):
     491            t = threading.Thread(target=runner)
     492            t.start()
     493            t.join()
     494        self.assertEquals(len(connections_set), 6)
     495        # Finish by closing the connections opened by the other threads (the
     496        # connection opened in the main thread will automatically be closed on
     497        # teardown).
     498        for conn in connections_set:
     499            if conn != connection:
     500                conn.close()
     501
     502    def test_pass_connection_between_threads(self):
     503        """
     504        Ensure that one can pass a connection from one thread to the other.
     505        Refs #17258.
     506        """
     507        models.Person.objects.create(first_name="John", last_name="Doe")
     508
     509        def do_thread():
     510            def runner(main_thread_connection):
     511                from django.db import connections
     512                connections['default'] = main_thread_connection
     513                try:
     514                    models.Person.objects.get(first_name="John", last_name="Doe")
     515                except DatabaseError, e:
     516                    exceptions.append(e)
     517            t = threading.Thread(target=runner, args=[connections['default']])
     518            t.start()
     519            t.join()
     520
     521        # By default (without touching check_same_thread)
     522        exceptions = []
     523        do_thread()
     524        self.assertTrue(isinstance(exceptions[0], DatabaseError))
     525
     526        # If explicitly setting check_same_thread to True
     527        connections['default'].check_same_thread = True
     528        exceptions = []
     529        do_thread()
     530        self.assertTrue(isinstance(exceptions[0], DatabaseError))
     531
     532        # If explicitly setting check_same_thread to False
     533        connections['default'].check_same_thread = False
     534        exceptions = []
     535        do_thread()
     536        self.assertEqual(len(exceptions), 0)
     537 No newline at end of file
  • tests/test_sqlite.py

    diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py
    index de8bf93..c143edd 100644
    a b  
    1414
    1515DATABASES = {
    1616    'default': {
    17         'ENGINE': 'django.db.backends.sqlite3'
     17        'ENGINE': 'django.db.backends.sqlite3',
     18        'OPTIONS': {
     19            'check_same_thread': False,
     20        }
    1821    },
    1922    'other': {
    2023        'ENGINE': 'django.db.backends.sqlite3',
     24        'OPTIONS': {
     25            'check_same_thread': False,
     26        }
    2127    }
    2228}
Back to Top