Ticket #17258: 17258.thread-local-connections.5.diff

File 17258.thread-local-connections.5.diff, 12.7 KB (added by Julien Phalip, 13 years ago)
  • django/db/__init__.py

    diff --git a/django/db/__init__.py b/django/db/__init__.py
    index 8395468..30e1c3d 100644
    a b router = ConnectionRouter(settings.DATABASE_ROUTERS)  
    2222# we manually create the dictionary from the settings, passing only the
    2323# settings that the database backends care about. Note that TIME_ZONE is used
    2424# by the PostgreSQL backends.
    25 # we load all these up for backwards compatibility, you should use
     25# We load all these up for backwards compatibility, you should use
    2626# connections['default'] instead.
    27 connection = connections[DEFAULT_DB_ALIAS]
     27class DefaultConnectionProxy(object):
     28    """
     29    Proxy for the thread-local default connection.
     30    """
     31    def __getattr__(self, item):
     32        return getattr(connections[DEFAULT_DB_ALIAS], item)
     33
     34    def __setattr__(self, name, value):
     35        return setattr(connections[DEFAULT_DB_ALIAS], name, value)
     36
     37connection = DefaultConnectionProxy()
    2838backend = load_backend(connection.settings_dict['ENGINE'])
    2939
    3040# Register an event that closes the database connection
  • django/db/backends/__init__.py

    diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
    index f2bde84..a218127 100644
    a b  
     1from django.db.utils import DatabaseError
     2
    13try:
    24    import thread
    35except ImportError:
    46    import dummy_thread as thread
    5 from threading import local
    67from contextlib import contextmanager
    78
    89from django.conf import settings
    from django.utils.importlib import import_module  
    1314from django.utils.timezone import is_aware
    1415
    1516
    16 class BaseDatabaseWrapper(local):
     17class BaseDatabaseWrapper(object):
    1718    """
    1819    Represents a database connection.
    1920    """
    2021    ops = None
    2122    vendor = 'unknown'
    2223
    23     def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
     24    def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS,
     25                 allow_thread_sharing=False):
    2426        # `settings_dict` should be a dictionary containing keys such as
    2527        # NAME, USER, etc. It's called `settings_dict` instead of `settings`
    2628        # to disambiguate it from Django settings modules.
    class BaseDatabaseWrapper(local):  
    3436        self.transaction_state = []
    3537        self.savepoint_state = 0
    3638        self._dirty = None
     39        self._thread_ident = thread.get_ident()
     40        self.allow_thread_sharing = allow_thread_sharing
    3741
    3842    def __eq__(self, other):
    3943        return self.alias == other.alias
    class BaseDatabaseWrapper(local):  
    116120                "pending COMMIT/ROLLBACK")
    117121        self._dirty = False
    118122
     123    def validate_thread_sharing(self):
     124        if (not self.allow_thread_sharing
     125            and self._thread_ident != thread.get_ident()):
     126                raise DatabaseError ("DatabaseWrapper objects created in a "
     127                    "thread can only be used in that same thread, unless you "
     128                    "explicitly set its 'check_same_thread' property to "
     129                    "False. The object was created in thread id %s and this "
     130                    "is thread id %s."
     131                    % (self._thread_ident, thread.get_ident()))
     132
    119133    def is_dirty(self):
    120134        """
    121135        Returns True if the current transaction requires a commit for changes to
    class BaseDatabaseWrapper(local):  
    179193        """
    180194        Commits changes if the system is not in managed transaction mode.
    181195        """
     196        self.validate_thread_sharing()
    182197        if not self.is_managed():
    183198            self._commit()
    184199            self.clean_savepoints()
    class BaseDatabaseWrapper(local):  
    189204        """
    190205        Rolls back changes if the system is not in managed transaction mode.
    191206        """
     207        self.validate_thread_sharing()
    192208        if not self.is_managed():
    193209            self._rollback()
    194210        else:
    class BaseDatabaseWrapper(local):  
    198214        """
    199215        Does the commit itself and resets the dirty flag.
    200216        """
     217        self.validate_thread_sharing()
    201218        self._commit()
    202219        self.set_clean()
    203220
    class BaseDatabaseWrapper(local):  
    205222        """
    206223        This function does the rollback itself and resets the dirty flag.
    207224        """
     225        self.validate_thread_sharing()
    208226        self._rollback()
    209227        self.set_clean()
    210228
    class BaseDatabaseWrapper(local):  
    228246        Rolls back the most recent savepoint (if one exists). Does nothing if
    229247        savepoints are not supported.
    230248        """
     249        self.validate_thread_sharing()
    231250        if self.savepoint_state:
    232251            self._savepoint_rollback(sid)
    233252
    class BaseDatabaseWrapper(local):  
    236255        Commits the most recent savepoint (if one exists). Does nothing if
    237256        savepoints are not supported.
    238257        """
     258        self.validate_thread_sharing()
    239259        if self.savepoint_state:
    240260            self._savepoint_commit(sid)
    241261
    class BaseDatabaseWrapper(local):  
    269289        pass
    270290
    271291    def close(self):
     292        self.validate_thread_sharing()
    272293        if self.connection is not None:
    273294            self.connection.close()
    274295            self.connection = None
    275296
    276297    def cursor(self):
     298        self.validate_thread_sharing()
    277299        if (self.use_debug_cursor or
    278300            (self.use_debug_cursor is None and settings.DEBUG)):
    279301            cursor = self.make_debug_cursor(self._cursor())
  • django/db/backends/sqlite3/base.py

    diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py
    index a610606..e440e8b 100644
    a b standard library.  
    77
    88import datetime
    99import decimal
     10import warnings
    1011import re
    1112import sys
    1213
    13 from django.conf import settings
    1414from django.db import utils
    1515from django.db.backends import *
    1616from django.db.backends.signals import connection_created
    class DatabaseWrapper(BaseDatabaseWrapper):  
    241241                'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
    242242            }
    243243            kwargs.update(settings_dict['OPTIONS'])
     244            # Always allow the underlying SQLite connection to be shareable
     245            # between multiple threads. The safe-guarding will be handled at a
     246            # higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
     247            # property (i.e. the official API). This is necessary as the
     248            # shareability is disabled by default in pysqlite and it cannot be
     249            # changed once a connection is opened.
     250            if 'check_same_thread' in kwargs and kwargs['check_same_thread']:
     251                warnings.warn(
     252                    'The `check_same_thread` option was provided and set to '
     253                    'True. It will be overriden with True. Use the '
     254                    '`DatabaseWrapper.allow_thread_sharing` property instead '
     255                    'for controlling thread shareability.',
     256                    RuntimeWarning
     257                )
     258            kwargs.update({'check_same_thread': False})
    244259            self.connection = Database.connect(**kwargs)
    245260            # Register extract, date_trunc, and regexp functions.
    246261            self.connection.create_function("django_extract", 2, _sqlite_extract)
  • django/db/utils.py

    diff --git a/django/db/utils.py b/django/db/utils.py
    index f0c13e3..41ad6df 100644
    a b  
    11import os
     2from threading import local
    23
    34from django.conf import settings
    45from django.core.exceptions import ImproperlyConfigured
    class ConnectionDoesNotExist(Exception):  
    5051class ConnectionHandler(object):
    5152    def __init__(self, databases):
    5253        self.databases = databases
    53         self._connections = {}
     54        self._connections = local()
    5455
    5556    def ensure_defaults(self, alias):
    5657        """
    class ConnectionHandler(object):  
    7374            conn.setdefault(setting, None)
    7475
    7576    def __getitem__(self, alias):
    76         if alias in self._connections:
    77             return self._connections[alias]
     77        if hasattr(self._connections, alias):
     78            return getattr(self._connections, alias)
    7879
    7980        self.ensure_defaults(alias)
    8081        db = self.databases[alias]
    8182        backend = load_backend(db['ENGINE'])
    8283        conn = backend.DatabaseWrapper(db, alias)
    83         self._connections[alias] = conn
     84        setattr(self._connections, alias, conn)
    8485        return conn
    8586
     87    def __setitem__(self, key, value):
     88        setattr(self._connections, key, value)
     89
    8690    def __iter__(self):
    8791        return iter(self.databases)
    8892
  • tests/regressiontests/backends/tests.py

    diff --git a/tests/regressiontests/backends/tests.py b/tests/regressiontests/backends/tests.py
    index 936f010..2c7dbd0 100644
    a b  
    33from __future__ import with_statement, absolute_import
    44
    55import datetime
     6import threading
    67
    78from django.conf import settings
    89from django.core.management.color import no_style
    class ConnectionCreatedSignalTest(TestCase):  
    283284        connection_created.connect(receiver)
    284285        connection.close()
    285286        cursor = connection.cursor()
    286         self.assertTrue(data["connection"] is connection)
     287        self.assertTrue(data["connection"].connection is connection.connection)
    287288
    288289        connection_created.disconnect(receiver)
    289290        data.clear()
    class FkConstraintsTests(TransactionTestCase):  
    446447                        connection.check_constraints()
    447448            finally:
    448449                transaction.rollback()
     450
     451class ThreadTests(TestCase):
     452
     453    def test_default_connection_thread_local(self):
     454        """
     455        Ensure that the default connection (i.e. django.db.connection) is
     456        different for each thread.
     457        Refs #17258.
     458        """
     459        connections_set = set()
     460        connection.cursor()
     461        connections_set.add(connection.connection)
     462        def runner():
     463            from django.db import connection
     464            connection.cursor()
     465            connections_set.add(connection.connection)
     466        for x in xrange(2):
     467            t = threading.Thread(target=runner)
     468            t.start()
     469            t.join()
     470        self.assertEquals(len(connections_set), 3)
     471        # Finish by closing the connections opened by the other threads (the
     472        # connection opened in the main thread will automatically be closed on
     473        # teardown).
     474        for conn in connections_set:
     475            if conn != connection.connection:
     476                conn.close()
     477
     478    def test_connections_thread_local(self):
     479        """
     480        Ensure that the connections are different for each thread.
     481        Refs #17258.
     482        """
     483        connections_set = set()
     484        for conn in connections.all():
     485            connections_set.add(conn)
     486        def runner():
     487            from django.db import connections
     488            for conn in connections.all():
     489                connections_set.add(conn)
     490        for x in xrange(2):
     491            t = threading.Thread(target=runner)
     492            t.start()
     493            t.join()
     494        self.assertEquals(len(connections_set), 6)
     495        # Finish by closing the connections opened by the other threads (the
     496        # connection opened in the main thread will automatically be closed on
     497        # teardown).
     498        for conn in connections_set:
     499            if conn != connection:
     500                conn.close()
     501
     502    def test_pass_connection_between_threads(self):
     503        """
     504        Ensure that one can pass a connection from one thread to the other.
     505        Refs #17258.
     506        """
     507        models.Person.objects.create(first_name="John", last_name="Doe")
     508
     509        def do_thread():
     510            def runner(main_thread_connection):
     511                from django.db import connections
     512                connections['default'] = main_thread_connection
     513                try:
     514                    models.Person.objects.get(first_name="John", last_name="Doe")
     515                except DatabaseError, e:
     516                    exceptions.append(e)
     517            t = threading.Thread(target=runner, args=[connections['default']])
     518            t.start()
     519            t.join()
     520
     521        # Without touching allow_thread_sharing, which should be False by default.
     522        exceptions = []
     523        do_thread()
     524        # Forbidden!
     525        self.assertTrue(isinstance(exceptions[0], DatabaseError))
     526
     527        # If explicitly setting allow_thread_sharing to False
     528        connections['default'].allow_thread_sharing = False
     529        exceptions = []
     530        do_thread()
     531        # Forbidden!
     532        self.assertTrue(isinstance(exceptions[0], DatabaseError))
     533
     534        # If explicitly setting allow_thread_sharing to True
     535        connections['default'].allow_thread_sharing = True
     536        exceptions = []
     537        do_thread()
     538        # All good
     539        self.assertEqual(len(exceptions), 0)
     540 No newline at end of file
Back to Top