Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion django/db/backends/sqlite3/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,9 @@ def setup_worker_connection(self, _worker_id):
connection_str = (
f"file:memorydb_{alias}_{_worker_id}?mode=memory&cache=shared"
)
source_db_name = settings_dict["NAME"]
source_db = self.connection.Database.connect(
f"file:{alias}_{_worker_id}.sqlite3?mode=ro", uri=True
f"file:{source_db_name}?mode=ro", uri=True
)
target_db = sqlite3.connect(connection_str, uri=True)
source_db.backup(target_db)
Expand Down
20 changes: 10 additions & 10 deletions django/db/models/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ def as_sql(self, compiler, connection):
# order of precedence
expression_wrapper = "(%s)"
sql = connection.ops.combine_expression(self.connector, expressions)
return expression_wrapper % sql, expression_params
return expression_wrapper % sql, tuple(expression_params)

def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
Expand Down Expand Up @@ -835,7 +835,7 @@ def as_sql(self, compiler, connection):
# order of precedence
expression_wrapper = "(%s)"
sql = connection.ops.combine_duration_expression(self.connector, expressions)
return expression_wrapper % sql, expression_params
return expression_wrapper % sql, tuple(expression_params)

def as_sqlite(self, compiler, connection, **extra_context):
sql, params = self.as_sql(compiler, connection, **extra_context)
Expand Down Expand Up @@ -1179,8 +1179,8 @@ def as_sql(self, compiler, connection):
# oracledb does not always convert None to the appropriate
# NULL type (like in case expressions using numbers), so we
# use a literal SQL NULL
return "NULL", []
return "%s", [val]
return "NULL", ()
return "%s", (val,)

def as_sqlite(self, compiler, connection, **extra_context):
sql, params = self.as_sql(compiler, connection, **extra_context)
Expand Down Expand Up @@ -1273,7 +1273,7 @@ def __repr__(self):
return "'*'"

def as_sql(self, compiler, connection):
return "*", []
return "*", ()


class DatabaseDefault(Expression):
Expand Down Expand Up @@ -1313,7 +1313,7 @@ def resolve_expression(
def as_sql(self, compiler, connection):
if not connection.features.supports_default_keyword_in_insert:
return compiler.compile(self.expression)
return "DEFAULT", []
return "DEFAULT", ()


class Col(Expression):
Expand Down Expand Up @@ -1398,7 +1398,7 @@ def as_sql(self, compiler, connection):
cols_sql.append(sql)
cols_params.extend(params)

return ", ".join(cols_sql), cols_params
return ", ".join(cols_sql), tuple(cols_params)

def relabeled_clone(self, relabels):
return self.__class__(
Expand Down Expand Up @@ -1447,7 +1447,7 @@ def relabeled_clone(self, relabels):
return clone

def as_sql(self, compiler, connection):
return connection.ops.quote_name(self.refs), []
return connection.ops.quote_name(self.refs), ()

def get_group_by_cols(self):
return [self]
Expand Down Expand Up @@ -1764,7 +1764,7 @@ def as_sql(
sql = template % template_params
if self._output_field_or_none is not None:
sql = connection.ops.unification_cast_sql(self.output_field) % sql
return sql, sql_params
return sql, tuple(sql_params)

def get_group_by_cols(self):
if not self.cases:
Expand Down Expand Up @@ -2148,7 +2148,7 @@ def as_sql(self, compiler, connection):
"end": end,
"exclude": self.get_exclusion(),
},
[],
(),
)

def __repr__(self):
Expand Down
50 changes: 50 additions & 0 deletions tests/backends/sqlite/test_creation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import multiprocessing
import sqlite3
import unittest
from unittest import mock

Expand Down Expand Up @@ -41,3 +42,52 @@ def test_get_test_db_clone_settings_not_supported(self, *mocked_objects):
msg = "Cloning with start method 'unsupported' is not supported."
with self.assertRaisesMessage(NotSupportedError, msg):
connection.creation.get_test_db_clone_settings(1)

@mock.patch.object(multiprocessing, "get_start_method", return_value="spawn")
def test_setup_worker_connection_respects_test_database_name(self, *mocked_objects):
test_connection = copy.copy(connections[DEFAULT_DB_ALIAS])
test_connection.settings_dict = copy.deepcopy(
connections[DEFAULT_DB_ALIAS].settings_dict
)
tests = [
("mytest.db", "mytest_2.db"),
("mytest", "mytest_2"),
]
for test_db_name, expected_source_db_name in tests:
with self.subTest(test_db_name=test_db_name):
# When calling setup_worker_connection(), the test db has been
# created already and its name has been copied to
# settings_dict["NAME"], so no need to set ["TEST"]["NAME"].
test_connection.settings_dict["NAME"] = test_db_name
creation_class = test_connection.creation_class(test_connection)
worker_id = 2
mock_source_db = mock.MagicMock()
mock_target_db = mock.MagicMock()
with (
# Mock connection to source test database.
mock.patch.object(
test_connection.Database,
"connect",
return_value=mock_source_db,
) as mock_source_connect,
# Mock connection to target in-memory db for copying.
mock.patch.object(
sqlite3,
"connect",
return_value=mock_target_db,
) as mock_target_connect,
# Mock reconnection to target in-memory db after copying.
mock.patch.object(test_connection, "connect"),
):
creation_class.setup_worker_connection(worker_id)
mock_source_connect.assert_called_once_with(
f"file:{expected_source_db_name}?mode=ro",
uri=True,
)
mock_target_connect.assert_called_once_with(
"file:memorydb_default_2?mode=memory&cache=shared",
uri=True,
)
mock_source_db.backup.assert_called_once_with(mock_target_db)
mock_source_db.close.assert_called_once()
mock_target_db.close.assert_called_once()
4 changes: 2 additions & 2 deletions tests/expressions/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2504,9 +2504,9 @@ def test_compile_unresolved(self):
# This test might need to be revisited later on if #25425 is enforced.
compiler = Time.objects.all().query.get_compiler(connection=connection)
value = Value("foo")
self.assertEqual(value.as_sql(compiler, connection), ("%s", ["foo"]))
self.assertEqual(value.as_sql(compiler, connection), ("%s", ("foo",)))
value = Value("foo", output_field=CharField())
self.assertEqual(value.as_sql(compiler, connection), ("%s", ["foo"]))
self.assertEqual(value.as_sql(compiler, connection), ("%s", ("foo",)))

def test_output_field_decimalfield(self):
Time.objects.create()
Expand Down