Ticket #17258: 17258.thread-local-connections.4.diff
File 17258.thread-local-connections.4.diff, 12.7 KB (added by , 13 years ago) |
---|
-
django/db/__init__.py
diff --git a/django/db/__init__.py b/django/db/__init__.py index 8395468..30e1c3d 100644
a b router = ConnectionRouter(settings.DATABASE_ROUTERS) 22 22 # we manually create the dictionary from the settings, passing only the 23 23 # settings that the database backends care about. Note that TIME_ZONE is used 24 24 # by the PostgreSQL backends. 25 # we load all these up for backwards compatibility, you should use25 # We load all these up for backwards compatibility, you should use 26 26 # connections['default'] instead. 27 connection = connections[DEFAULT_DB_ALIAS] 27 class DefaultConnectionProxy(object): 28 """ 29 Proxy for the thread-local default connection. 30 """ 31 def __getattr__(self, item): 32 return getattr(connections[DEFAULT_DB_ALIAS], item) 33 34 def __setattr__(self, name, value): 35 return setattr(connections[DEFAULT_DB_ALIAS], name, value) 36 37 connection = DefaultConnectionProxy() 28 38 backend = load_backend(connection.settings_dict['ENGINE']) 29 39 30 40 # Register an event that closes the database connection -
django/db/backends/__init__.py
diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index f2bde84..460746b 100644
a b 1 from django.db.utils import DatabaseError 2 1 3 try: 2 4 import thread 3 5 except ImportError: 4 6 import dummy_thread as thread 5 from threading import local6 7 from contextlib import contextmanager 7 8 8 9 from django.conf import settings … … from django.utils.importlib import import_module 13 14 from django.utils.timezone import is_aware 14 15 15 16 16 class BaseDatabaseWrapper( local):17 class BaseDatabaseWrapper(object): 17 18 """ 18 19 Represents a database connection. 19 20 """ 20 21 ops = None 21 22 vendor = 'unknown' 22 23 23 def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS): 24 def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS, 25 check_same_thread=True): 24 26 # `settings_dict` should be a dictionary containing keys such as 25 27 # NAME, USER, etc. It's called `settings_dict` instead of `settings` 26 28 # to disambiguate it from Django settings modules. … … class BaseDatabaseWrapper(local): 34 36 self.transaction_state = [] 35 37 self.savepoint_state = 0 36 38 self._dirty = None 39 self.thread_id = thread.get_ident() 40 self.check_same_thread = check_same_thread 37 41 38 42 def __eq__(self, other): 39 43 return self.alias == other.alias … … class BaseDatabaseWrapper(local): 116 120 "pending COMMIT/ROLLBACK") 117 121 self._dirty = False 118 122 123 def validate_thread(self): 124 if self.check_same_thread: 125 current_thread_id = thread.get_ident() 126 if current_thread_id != self.thread_id: 127 raise DatabaseError ("DatabaseWrapper objects created in a " 128 "thread can only be used in that same thread, unless you " 129 "explicitly set its 'check_same_thread' property to " 130 "False. The object was created in thread id %s and this " 131 "is thread id %s" % (self.thread_id, current_thread_id)) 132 return True 133 119 134 def is_dirty(self): 120 135 """ 121 136 Returns True if the current transaction requires a commit for changes to … … class BaseDatabaseWrapper(local): 179 194 """ 180 195 Commits changes if the system is not in managed transaction mode. 181 196 """ 197 self.validate_thread() 182 198 if not self.is_managed(): 183 199 self._commit() 184 200 self.clean_savepoints() … … class BaseDatabaseWrapper(local): 189 205 """ 190 206 Rolls back changes if the system is not in managed transaction mode. 191 207 """ 208 self.validate_thread() 192 209 if not self.is_managed(): 193 210 self._rollback() 194 211 else: … … class BaseDatabaseWrapper(local): 198 215 """ 199 216 Does the commit itself and resets the dirty flag. 200 217 """ 218 self.validate_thread() 201 219 self._commit() 202 220 self.set_clean() 203 221 … … class BaseDatabaseWrapper(local): 205 223 """ 206 224 This function does the rollback itself and resets the dirty flag. 207 225 """ 226 self.validate_thread() 208 227 self._rollback() 209 228 self.set_clean() 210 229 … … class BaseDatabaseWrapper(local): 228 247 Rolls back the most recent savepoint (if one exists). Does nothing if 229 248 savepoints are not supported. 230 249 """ 250 self.validate_thread() 231 251 if self.savepoint_state: 232 252 self._savepoint_rollback(sid) 233 253 … … class BaseDatabaseWrapper(local): 236 256 Commits the most recent savepoint (if one exists). Does nothing if 237 257 savepoints are not supported. 238 258 """ 259 self.validate_thread() 239 260 if self.savepoint_state: 240 261 self._savepoint_commit(sid) 241 262 … … class BaseDatabaseWrapper(local): 269 290 pass 270 291 271 292 def close(self): 293 self.validate_thread() 272 294 if self.connection is not None: 273 295 self.connection.close() 274 296 self.connection = None 275 297 276 298 def cursor(self): 299 self.validate_thread() 277 300 if (self.use_debug_cursor or 278 301 (self.use_debug_cursor is None and settings.DEBUG)): 279 302 cursor = self.make_debug_cursor(self._cursor()) -
django/db/backends/sqlite3/base.py
diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index a610606..4a96c7c 100644
a b class DatabaseWrapper(BaseDatabaseWrapper): 220 220 'iendswith': "LIKE %s ESCAPE '\\'", 221 221 } 222 222 223 def __init__(self, *args, **kwargs):224 super(DatabaseWrapper, self).__init__( *args, **kwargs)223 def __init__(self, settings_dict, *args, **kwargs): 224 super(DatabaseWrapper, self).__init__(settings_dict, *args, **kwargs) 225 225 226 226 self.features = DatabaseFeatures(self) 227 227 self.ops = DatabaseOperations(self) … … class DatabaseWrapper(BaseDatabaseWrapper): 229 229 self.creation = DatabaseCreation(self) 230 230 self.introspection = DatabaseIntrospection(self) 231 231 self.validation = BaseDatabaseValidation(self) 232 self.check_same_thread = ( 233 settings_dict['OPTIONS'].get('check_same_thread'), True) 232 234 233 235 def _cursor(self): 234 236 if self.connection is None: … … class DatabaseWrapper(BaseDatabaseWrapper): 239 241 kwargs = { 240 242 'database': settings_dict['NAME'], 241 243 'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES, 244 'check_same_thread': 1 if self.check_same_thread else 0, 242 245 } 243 246 kwargs.update(settings_dict['OPTIONS']) 244 247 self.connection = Database.connect(**kwargs) -
django/db/utils.py
diff --git a/django/db/utils.py b/django/db/utils.py index f0c13e3..41ad6df 100644
a b 1 1 import os 2 from threading import local 2 3 3 4 from django.conf import settings 4 5 from django.core.exceptions import ImproperlyConfigured … … class ConnectionDoesNotExist(Exception): 50 51 class ConnectionHandler(object): 51 52 def __init__(self, databases): 52 53 self.databases = databases 53 self._connections = {}54 self._connections = local() 54 55 55 56 def ensure_defaults(self, alias): 56 57 """ … … class ConnectionHandler(object): 73 74 conn.setdefault(setting, None) 74 75 75 76 def __getitem__(self, alias): 76 if alias in self._connections:77 return self._connections[alias]77 if hasattr(self._connections, alias): 78 return getattr(self._connections, alias) 78 79 79 80 self.ensure_defaults(alias) 80 81 db = self.databases[alias] 81 82 backend = load_backend(db['ENGINE']) 82 83 conn = backend.DatabaseWrapper(db, alias) 83 se lf._connections[alias] = conn84 setattr(self._connections, alias, conn) 84 85 return conn 85 86 87 def __setitem__(self, key, value): 88 setattr(self._connections, key, value) 89 86 90 def __iter__(self): 87 91 return iter(self.databases) 88 92 -
tests/regressiontests/backends/tests.py
diff --git a/tests/regressiontests/backends/tests.py b/tests/regressiontests/backends/tests.py index 936f010..998e7ce 100644
a b 3 3 from __future__ import with_statement, absolute_import 4 4 5 5 import datetime 6 import threading 6 7 7 8 from django.conf import settings 8 9 from django.core.management.color import no_style … … class ConnectionCreatedSignalTest(TestCase): 283 284 connection_created.connect(receiver) 284 285 connection.close() 285 286 cursor = connection.cursor() 286 self.assertTrue(data["connection"] isconnection)287 self.assertTrue(data["connection"].connection is connection.connection) 287 288 288 289 connection_created.disconnect(receiver) 289 290 data.clear() … … class FkConstraintsTests(TransactionTestCase): 446 447 connection.check_constraints() 447 448 finally: 448 449 transaction.rollback() 450 451 class ThreadTests(TestCase): 452 453 def test_default_connection_thread_local(self): 454 """ 455 Ensure that the default connection (i.e. django.db.connection) is 456 different for each thread. 457 Refs #17258. 458 """ 459 connections_set = set() 460 connection.cursor() 461 connections_set.add(connection.connection) 462 def runner(): 463 from django.db import connection 464 connection.cursor() 465 connections_set.add(connection.connection) 466 for x in xrange(2): 467 t = threading.Thread(target=runner) 468 t.start() 469 t.join() 470 self.assertEquals(len(connections_set), 3) 471 # Finish by closing the connections opened by the other threads (the 472 # connection opened in the main thread will automatically be closed on 473 # teardown). 474 for conn in connections_set: 475 if conn != connection.connection: 476 conn.close() 477 478 def test_connections_thread_local(self): 479 """ 480 Ensure that the connections are different for each thread. 481 Refs #17258. 482 """ 483 connections_set = set() 484 for conn in connections.all(): 485 connections_set.add(conn) 486 def runner(): 487 from django.db import connections 488 for conn in connections.all(): 489 connections_set.add(conn) 490 for x in xrange(2): 491 t = threading.Thread(target=runner) 492 t.start() 493 t.join() 494 self.assertEquals(len(connections_set), 6) 495 # Finish by closing the connections opened by the other threads (the 496 # connection opened in the main thread will automatically be closed on 497 # teardown). 498 for conn in connections_set: 499 if conn != connection: 500 conn.close() 501 502 def test_pass_connection_between_threads(self): 503 """ 504 Ensure that one can pass a connection from one thread to the other. 505 Refs #17258. 506 """ 507 models.Person.objects.create(first_name="John", last_name="Doe") 508 509 def do_thread(): 510 def runner(main_thread_connection): 511 from django.db import connections 512 connections['default'] = main_thread_connection 513 try: 514 models.Person.objects.get(first_name="John", last_name="Doe") 515 except DatabaseError, e: 516 exceptions.append(e) 517 t = threading.Thread(target=runner, args=[connections['default']]) 518 t.start() 519 t.join() 520 521 # By default (without touching check_same_thread) 522 exceptions = [] 523 do_thread() 524 self.assertTrue(isinstance(exceptions[0], DatabaseError)) 525 526 # If explicitly setting check_same_thread to True 527 connections['default'].check_same_thread = True 528 exceptions = [] 529 do_thread() 530 self.assertTrue(isinstance(exceptions[0], DatabaseError)) 531 532 # If explicitly setting check_same_thread to False 533 connections['default'].check_same_thread = False 534 exceptions = [] 535 do_thread() 536 self.assertEqual(len(exceptions), 0) 537 No newline at end of file -
tests/test_sqlite.py
diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py index de8bf93..c143edd 100644
a b 14 14 15 15 DATABASES = { 16 16 'default': { 17 'ENGINE': 'django.db.backends.sqlite3' 17 'ENGINE': 'django.db.backends.sqlite3', 18 'OPTIONS': { 19 'check_same_thread': False, 20 } 18 21 }, 19 22 'other': { 20 23 'ENGINE': 'django.db.backends.sqlite3', 24 'OPTIONS': { 25 'check_same_thread': False, 26 } 21 27 } 22 28 }