From 680f750e1652760d1174c47dbfa34847e9ef8f3a Mon Sep 17 00:00:00 2001 From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com> Date: Fri, 11 Jul 2025 15:00:07 -0700 Subject: [PATCH 01/18] chore(main): release 1.18.3 (#1309) Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com> --- README.md | 3 +- google/cloud/sql/connector/connector.py | 9 ++-- google/cloud/sql/connector/enums.py | 1 + google/cloud/sql/connector/proxy.py | 60 +++++++++++++++++++++ google/cloud/sql/connector/psycopg.py | 72 +++++++++++++++++++++++++ pyproject.toml | 1 + requirements-test.txt | 1 + tests/system/test_psycopg_connection.py | 60 +++++++++++++++++++++ tests/unit/test_psycopg.py | 40 ++++++++++++++ 9 files changed, 243 insertions(+), 4 deletions(-) create mode 100644 google/cloud/sql/connector/proxy.py create mode 100644 google/cloud/sql/connector/psycopg.py create mode 100644 tests/system/test_psycopg_connection.py create mode 100644 tests/unit/test_psycopg.py diff --git a/README.md b/README.md index d6921425f..2fbb25473 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ The Cloud SQL Python Connector is a package to be used alongside a database driv Currently supported drivers are: - [`pymysql`](https://github.com/PyMySQL/PyMySQL) (MySQL) - [`pg8000`](https://github.com/tlocke/pg8000) (PostgreSQL) + - [`psycopg`](https://github.com/psycopg/psycopg) (PostgreSQL) - [`asyncpg`](https://github.com/MagicStack/asyncpg) (PostgreSQL) - [`pytds`](https://github.com/denisenkom/pytds) (SQL Server) @@ -600,7 +601,7 @@ async def main(): # acquire connection and query Cloud SQL database async with pool.acquire() as conn: res = await conn.fetch("SELECT NOW()") - + # close Connector await connector.close_async() ``` diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 798969c2c..8332525e3 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -40,6 +40,7 @@ from google.cloud.sql.connector.lazy import LazyRefreshCache from google.cloud.sql.connector.monitored_cache import MonitoredCache import google.cloud.sql.connector.pg8000 as pg8000 +import google.cloud.sql.connector.psycopg as psycopg import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds from google.cloud.sql.connector.resolver import DefaultResolver @@ -234,7 +235,7 @@ def connect( Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect - with. Supported drivers are pymysql, pg8000, and pytds. + with. Supported drivers are pymysql, pg8000, psycopg, and pytds. **kwargs: Any driver-specific arguments to pass to the underlying driver .connect call. @@ -276,7 +277,8 @@ async def connect_async( Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect - with. Supported drivers are pymysql, asyncpg, pg8000, and pytds. + with. Supported drivers are pymysql, asyncpg, pg8000, psycopg, and + pytds. **kwargs: Any driver-specific arguments to pass to the underlying driver .connect call. @@ -288,7 +290,7 @@ async def connect_async( ValueError: Connection attempt with built-in database authentication and then subsequent attempt with IAM database authentication. KeyError: Unsupported database driver Must be one of pymysql, asyncpg, - pg8000, and pytds. + pg8000, psycopg, and pytds. RuntimeError: Connector has been closed. Cannot connect using a closed Connector. """ @@ -357,6 +359,7 @@ async def connect_async( connect_func = { "pymysql": pymysql.connect, "pg8000": pg8000.connect, + "psycopg": psycopg.connect, "asyncpg": asyncpg.connect, "pytds": pytds.connect, } diff --git a/google/cloud/sql/connector/enums.py b/google/cloud/sql/connector/enums.py index e6b56af0e..5926b75a2 100644 --- a/google/cloud/sql/connector/enums.py +++ b/google/cloud/sql/connector/enums.py @@ -62,6 +62,7 @@ class DriverMapping(Enum): ASYNCPG = "POSTGRES" PG8000 = "POSTGRES" + PSYCOPG = "POSTGRES" PYMYSQL = "MYSQL" PYTDS = "SQLSERVER" diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py new file mode 100644 index 000000000..1f398461a --- /dev/null +++ b/google/cloud/sql/connector/proxy.py @@ -0,0 +1,60 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import socket +import os +import threading +from pathlib import Path + +SERVER_PROXY_PORT = 3307 + +def start_local_proxy( + ssl_sock, + socket_path, +): + desired_path = Path(socket_path) + desired_path.mkdir(parents=True, exist_ok=True) + + if os.path.exists(socket_path): + os.remove(socket_path) + conn_unix = None + unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + unix_socket.bind(socket_path) + unix_socket.listen(1) + + threading.Thread(target=local_communication, args=(unix_socket, ssl_sock, socket_path)).start() + + +def local_communication( + unix_socket, ssl_sock, socket_path +): + try: + conn_unix, addr_unix = unix_socket.accept() + + while True: + data = conn_unix.recv(10485760) + if not data: + break + ssl_sock.sendall(data) + response = ssl_sock.recv(10485760) + conn_unix.sendall(response) + + finally: + if conn_unix is not None: + conn_unix.close() + unix_socket.close() + os.remove(socket_path) # Clean up the socket file diff --git a/google/cloud/sql/connector/psycopg.py b/google/cloud/sql/connector/psycopg.py new file mode 100644 index 000000000..c34dcb1e1 --- /dev/null +++ b/google/cloud/sql/connector/psycopg.py @@ -0,0 +1,72 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import ssl +from typing import Any, TYPE_CHECKING +import threading + +SERVER_PROXY_PORT = 3307 + +if TYPE_CHECKING: + import psycopg + + +def connect( + ip_address: str, sock: ssl.SSLSocket, **kwargs: Any +) -> "psycopg.Connection": + """Helper function to create a psycopg DB-API connection object. + + Args: + ip_address (str): A string containing an IP address for the Cloud SQL + instance. + sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL + server CA cert and ephemeral cert. + kwargs: Additional arguments to pass to the psycopg connect method. + + Returns: + psycopg.Connection: A psycopg connection to the Cloud SQL + instance. + + Raises: + ImportError: The psycopg module cannot be imported. + """ + try: + from psycopg.rows import dict_row + from psycopg import Connection + import threading + from google.cloud.sql.connector.proxy import start_local_proxy + except ImportError: + raise ImportError( + 'Unable to import module "psycopg." Please install and try again.' + ) + + user = kwargs.pop("user") + db = kwargs.pop("db") + passwd = kwargs.pop("password", None) + + kwargs.pop("timeout", None) + + start_local_proxy(sock, f"/tmp/connector-socket/.s.PGSQL.3307") + + conn = Connection.connect( + f"host=/tmp/connector-socket port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require", + autocommit=True, + row_factory=dict_row, + **kwargs + ) + + conn.autocommit = True + return conn diff --git a/pyproject.toml b/pyproject.toml index cbf0dd10f..1c933cf6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ Changelog = "https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/b [project.optional-dependencies] pymysql = ["PyMySQL>=1.1.0"] pg8000 = ["pg8000>=1.31.1"] +psycopg = ["psycopg>=3.2.9"] pytds = ["python-tds>=1.15.0"] asyncpg = ["asyncpg>=0.30.0"] diff --git a/requirements-test.txt b/requirements-test.txt index fa630a373..25311db58 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -7,6 +7,7 @@ sqlalchemy-pytds==1.0.2 sqlalchemy-stubs==0.4 PyMySQL==1.1.2 pg8000==1.31.5 +psycopg[binary]==3.2.9 asyncpg==0.31.0 python-tds==1.17.1 aioresponses==0.7.8 diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py new file mode 100644 index 000000000..5b3a3c79e --- /dev/null +++ b/tests/system/test_psycopg_connection.py @@ -0,0 +1,60 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from datetime import datetime +import os + +# [START cloud_sql_connector_postgres_psycopg] + +from google.cloud.sql.connector import Connector +from google.cloud.sql.connector import DefaultResolver + +from sqlalchemy.dialects.postgresql.base import PGDialect +PGDialect._get_server_version_info = lambda *args: (9, 2) + +# [END cloud_sql_connector_postgres_psycopg] + + +def test_psycopg_connection() -> None: + """Basic test to get time from database.""" + inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"] + user = os.environ["POSTGRES_USER"] + password = os.environ["POSTGRES_PASS"] + db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") + + connector = Connector(refresh_strategy="background", resolver=DefaultResolver) + + pool = connector.connect( + inst_conn_name, + "psycopg", + user=user, + password=password, + db=db, + ip_type=ip_type, # can be "public", "private" or "psc" + ) + + with pool as conn: + + # Open a cursor to perform database operations + with conn.cursor() as cur: + + # Query the database and obtain data as Python objects. + cur.execute("SELECT NOW()") + curr_time = cur.fetchone()["now"] + assert type(curr_time) is datetime + + diff --git a/tests/unit/test_psycopg.py b/tests/unit/test_psycopg.py new file mode 100644 index 000000000..8d9fe1f85 --- /dev/null +++ b/tests/unit/test_psycopg.py @@ -0,0 +1,40 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import socket +import ssl +from typing import Any + +from mock import patch, PropertyMock +import pytest + +from google.cloud.sql.connector.psycopg import connect + + +@pytest.mark.usefixtures("proxy_server") +async def test_psycopg(context: ssl.SSLContext, kwargs: Any) -> None: + """Test to verify that psycopg gets to proper connection call.""" + ip_addr = "127.0.0.1" + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + ) + with patch("psycopg.connect") as mock_connect: + type(mock_connect.return_value).autocommit = PropertyMock(return_value=True) + connection = connect(ip_addr, sock, **kwargs) + assert connection.autocommit is True + # verify that driver connection call would be made + assert mock_connect.assert_called_once From 541974292cfc4a0ce45232eaef807aff73ffc78f Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Wed, 16 Jul 2025 18:06:56 -0600 Subject: [PATCH 02/18] feat(main): add support for psycopg Changelog: - Add proxy for connections that can only be made through an unix socket, to support the TLS connection - Add support for psycopg, using the proxy server - Add unit and integration tests - Update docs --- google/cloud/sql/connector/proxy.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index 1f398461a..5f46003bb 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -25,7 +25,10 @@ def start_local_proxy( ssl_sock, socket_path, ): - desired_path = Path(socket_path) + path_parts = socket_path.rsplit('/', 1) + parent_directory = '/'.join(path_parts[:-1]) + + desired_path = Path(parent_directory) desired_path.mkdir(parents=True, exist_ok=True) if os.path.exists(socket_path): From 4d7cf6e7dbbd1d202f41bfa8733ba7db8658ba25 Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Thu, 17 Jul 2025 21:44:09 -0600 Subject: [PATCH 03/18] fix(main): fix feedback PR Changelog: - Make local_socket_path configurable - Set right file permissions - Handle exceptions properly - Use asyncio and the main loop to stop the local proxy and clear the file when the connector is stopped --- google/cloud/sql/connector/connector.py | 18 +++- google/cloud/sql/connector/exceptions.py | 6 ++ google/cloud/sql/connector/proxy.py | 91 +++++++++++++------- google/cloud/sql/connector/psycopg.py | 15 +--- tests/system/test_psycopg_connection.py | 105 ++++++++++++++++++----- 5 files changed, 166 insertions(+), 69 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 8332525e3..7df618e1c 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -47,10 +47,12 @@ from google.cloud.sql.connector.resolver import DnsResolver from google.cloud.sql.connector.utils import format_database_user from google.cloud.sql.connector.utils import generate_keys +from google.cloud.sql.connector.proxy import start_local_proxy logger = logging.getLogger(name=__name__) ASYNC_DRIVERS = ["asyncpg"] +LOCAL_PROXY_DRIVERS = ["psycopg"] SERVER_PROXY_PORT = 3307 _DEFAULT_SCHEME = "https://" _DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" @@ -435,7 +437,7 @@ async def connect_async( # async drivers are unblocking and can be awaited directly if driver in ASYNC_DRIVERS: return await connector( - ip_address, + host, await conn_info.create_ssl_context(enable_iam_auth), **kwargs, ) @@ -445,6 +447,18 @@ async def connect_async( socket.create_connection((ip_address, SERVER_PROXY_PORT)), server_hostname=ip_address, ) + + host = ip_address + # start local proxy if driver needs it + if driver in LOCAL_PROXY_DRIVERS: + local_socket_path = kwargs.pop("local_socket_path", "/tmp/connector-socket") + host = local_socket_path + start_local_proxy( + sock, + socket_path=f"{local_socket_path}/.s.PGSQL.{SERVER_PROXY_PORT}", + loop=self._loop + ) + # If this connection was opened using a domain name, then store it # for later in case we need to forcibly close it on failover. if conn_info.conn_name.domain_name: @@ -452,7 +466,7 @@ async def connect_async( # Synchronous drivers are blocking and run using executor connect_partial = partial( connector, - ip_address, + host, sock, **kwargs, ) diff --git a/google/cloud/sql/connector/exceptions.py b/google/cloud/sql/connector/exceptions.py index 1f15ced47..f4ff92dc5 100644 --- a/google/cloud/sql/connector/exceptions.py +++ b/google/cloud/sql/connector/exceptions.py @@ -91,3 +91,9 @@ class ClosedConnectorError(Exception): Exception to be raised when a Connector is closed and connect method is called on it. """ + + +class LocalProxyStartupError(Exception): + """ + Exception to be raised when a the local UNIX-socket based proxy can not be started. + """ diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index 5f46003bb..06b623e82 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -16,48 +16,75 @@ import socket import os -import threading +import ssl +import asyncio from pathlib import Path +from typing import Optional + +from google.cloud.sql.connector.exceptions import LocalProxyStartupError SERVER_PROXY_PORT = 3307 +LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760 def start_local_proxy( - ssl_sock, - socket_path, + ssl_sock: ssl.SSLSocket, + socket_path: Optional[str] = "/tmp/connector-socket", + loop: Optional[asyncio.AbstractEventLoop] = None, ): - path_parts = socket_path.rsplit('/', 1) - parent_directory = '/'.join(path_parts[:-1]) + """Helper function to start a UNIX based local proxy for + transport messages through the SSL Socket. + + Args: + ssl_sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL + server CA cert and ephemeral cert. + socket_path: A system path that is going to be used to store the socket. + loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. + + Raises: + LocalProxyStartupError: Local UNIX socket based proxy was not able to + get started. + """ + unix_socket = None - desired_path = Path(parent_directory) - desired_path.mkdir(parents=True, exist_ok=True) + try: + path_parts = socket_path.rsplit('/', 1) + parent_directory = '/'.join(path_parts[:-1]) - if os.path.exists(socket_path): - os.remove(socket_path) - conn_unix = None - unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + desired_path = Path(parent_directory) + desired_path.mkdir(parents=True, exist_ok=True) - unix_socket.bind(socket_path) - unix_socket.listen(1) + if os.path.exists(socket_path): + os.remove(socket_path) + unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - threading.Thread(target=local_communication, args=(unix_socket, ssl_sock, socket_path)).start() + unix_socket.bind(socket_path) + unix_socket.listen(1) + unix_socket.setblocking(False) + os.chmod(socket_path, 0o600) + except Exception: + raise LocalProxyStartupError( + 'Local UNIX socket based proxy was not able to get started.' + ) + loop.create_task(local_communication(unix_socket, ssl_sock, socket_path, loop)) -def local_communication( - unix_socket, ssl_sock, socket_path + +async def local_communication( + unix_socket, ssl_sock, socket_path, loop ): - try: - conn_unix, addr_unix = unix_socket.accept() - - while True: - data = conn_unix.recv(10485760) - if not data: - break - ssl_sock.sendall(data) - response = ssl_sock.recv(10485760) - conn_unix.sendall(response) - - finally: - if conn_unix is not None: - conn_unix.close() - unix_socket.close() - os.remove(socket_path) # Clean up the socket file + try: + client, _ = await loop.sock_accept(unix_socket) + + while True: + data = await loop.sock_recv(client, LOCAL_PROXY_MAX_MESSAGE_SIZE) + if not data: + client.close() + break + ssl_sock.sendall(data) + response = ssl_sock.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE) + await loop.sock_sendall(client, response) + except Exception: + pass + finally: + client.close() + os.remove(socket_path) # Clean up the socket file diff --git a/google/cloud/sql/connector/psycopg.py b/google/cloud/sql/connector/psycopg.py index c34dcb1e1..fe862bc24 100644 --- a/google/cloud/sql/connector/psycopg.py +++ b/google/cloud/sql/connector/psycopg.py @@ -25,13 +25,12 @@ def connect( - ip_address: str, sock: ssl.SSLSocket, **kwargs: Any + host: str, sock: ssl.SSLSocket, **kwargs: Any ) -> "psycopg.Connection": """Helper function to create a psycopg DB-API connection object. Args: - ip_address (str): A string containing an IP address for the Cloud SQL - instance. + host (str): A string containing the socket path used by the local proxy. sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL server CA cert and ephemeral cert. kwargs: Additional arguments to pass to the psycopg connect method. @@ -44,10 +43,7 @@ def connect( ImportError: The psycopg module cannot be imported. """ try: - from psycopg.rows import dict_row from psycopg import Connection - import threading - from google.cloud.sql.connector.proxy import start_local_proxy except ImportError: raise ImportError( 'Unable to import module "psycopg." Please install and try again.' @@ -59,14 +55,9 @@ def connect( kwargs.pop("timeout", None) - start_local_proxy(sock, f"/tmp/connector-socket/.s.PGSQL.3307") - conn = Connection.connect( - f"host=/tmp/connector-socket port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require", - autocommit=True, - row_factory=dict_row, + f"host={host} port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require", **kwargs ) - conn.autocommit = True return conn diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py index 5b3a3c79e..9d6363d5c 100644 --- a/tests/system/test_psycopg_connection.py +++ b/tests/system/test_psycopg_connection.py @@ -18,12 +18,84 @@ import os # [START cloud_sql_connector_postgres_psycopg] +from typing import Union + +import sqlalchemy from google.cloud.sql.connector import Connector from google.cloud.sql.connector import DefaultResolver +from google.cloud.sql.connector import DnsResolver + + +def create_sqlalchemy_engine( + instance_connection_name: str, + user: str, + password: str, + db: str, + ip_type: str = "public", + refresh_strategy: str = "background", + resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver, +) -> tuple[sqlalchemy.engine.Engine, Connector]: + """Creates a connection pool for a Cloud SQL instance and returns the pool + and the connector. Callers are responsible for closing the pool and the + connector. + + A sample invocation looks like: + + engine, connector = create_sqlalchemy_engine( + inst_conn_name, + user, + password, + db, + ) + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + # do something with query result + connector.close() + + Args: + instance_connection_name (str): + The instance connection name specifies the instance relative to the + project and region. For example: "my-project:my-region:my-instance" + user (str): + The database user name, e.g., root + password (str): + The database user's password, e.g., secret-password + db (str): + The name of the database, e.g., mydb + ip_type (str): + The IP type of the Cloud SQL instance to connect to. Can be one + of "public", "private", or "psc". + refresh_strategy (Optional[str]): + Refresh strategy for the Cloud SQL Connector. Can be one of "lazy" + or "background". For serverless environments use "lazy" to avoid + errors resulting from CPU being throttled. + resolver (Optional[google.cloud.sql.connector.DefaultResolver]): + Resolver class for resolving instance connection name. Use + google.cloud.sql.connector.DnsResolver when resolving DNS domain + names or google.cloud.sql.connector.DefaultResolver for regular + instance connection names ("my-project:my-region:my-instance"). + """ + connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver) + + # create SQLAlchemy connection pool + engine = sqlalchemy.create_engine( + "postgresql+psycopg://", + creator=lambda: connector.connect( + instance_connection_name, + "psycopg", + user=user, + password=password, + db=db, + local_socket_path="/tmp/conn", + ip_type=ip_type, # can be "public", "private" or "psc" + autocommit=True, + ), + ) + return engine, connector -from sqlalchemy.dialects.postgresql.base import PGDialect -PGDialect._get_server_version_info = lambda *args: (9, 2) # [END cloud_sql_connector_postgres_psycopg] @@ -36,25 +108,12 @@ def test_psycopg_connection() -> None: db = os.environ["POSTGRES_DB"] ip_type = os.environ.get("IP_TYPE", "public") - connector = Connector(refresh_strategy="background", resolver=DefaultResolver) - - pool = connector.connect( - inst_conn_name, - "psycopg", - user=user, - password=password, - db=db, - ip_type=ip_type, # can be "public", "private" or "psc" + engine, connector = create_sqlalchemy_engine( + inst_conn_name, user, password, db, ip_type ) - - with pool as conn: - - # Open a cursor to perform database operations - with conn.cursor() as cur: - - # Query the database and obtain data as Python objects. - cur.execute("SELECT NOW()") - curr_time = cur.fetchone()["now"] - assert type(curr_time) is datetime - - + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + assert type(curr_time) is datetime + connector.close() From 35ea30bda7cbfb83dc66470096beb5742749daa4 Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Tue, 22 Jul 2025 21:13:46 -0600 Subject: [PATCH 04/18] fix(main): Prevent asyncio destroyed task warning Changelog: - Return the asyncio task from `start_local_proxy` - Handle it in `close_async` to cancel it gracefully --- google/cloud/sql/connector/connector.py | 10 +++++++++- google/cloud/sql/connector/proxy.py | 7 +++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 7df618e1c..e19c85fc6 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -160,6 +160,7 @@ def __init__( self._cache: dict[tuple[str, bool], MonitoredCache] = {} self._client: Optional[CloudSQLClient] = None self._closed: bool = False + self._proxy: Optional[asyncio.Task] = None # initialize credentials scopes = ["https://www.googleapis.com/auth/sqlservice.admin"] @@ -453,7 +454,7 @@ async def connect_async( if driver in LOCAL_PROXY_DRIVERS: local_socket_path = kwargs.pop("local_socket_path", "/tmp/connector-socket") host = local_socket_path - start_local_proxy( + self._proxy = start_local_proxy( sock, socket_path=f"{local_socket_path}/.s.PGSQL.{SERVER_PROXY_PORT}", loop=self._loop @@ -538,6 +539,13 @@ async def close_async(self) -> None: self._closed = True if self._client: await self._client.close() + if self._proxy: + proxy_task = asyncio.gather(self._proxy) + try: + await asyncio.wait_for(proxy_task, timeout=0.1) + except TimeoutError: + # This task runs forever so it is expected to raise this exception + pass await asyncio.gather(*[cache.close() for cache in self._cache.values()]) diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index 06b623e82..85ece82a1 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -30,7 +30,7 @@ def start_local_proxy( ssl_sock: ssl.SSLSocket, socket_path: Optional[str] = "/tmp/connector-socket", loop: Optional[asyncio.AbstractEventLoop] = None, -): +) -> asyncio.Task: """Helper function to start a UNIX based local proxy for transport messages through the SSL Socket. @@ -40,6 +40,9 @@ def start_local_proxy( socket_path: A system path that is going to be used to store the socket. loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. + Returns: + asyncio.Task: The asyncio task containing the proxy server process. + Raises: LocalProxyStartupError: Local UNIX socket based proxy was not able to get started. @@ -66,7 +69,7 @@ def start_local_proxy( 'Local UNIX socket based proxy was not able to get started.' ) - loop.create_task(local_communication(unix_socket, ssl_sock, socket_path, loop)) + return loop.create_task(local_communication(unix_socket, ssl_sock, socket_path, loop)) async def local_communication( From 28c3d95d7d74cb7d73f955486f61795468d4f31a Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Thu, 24 Jul 2025 20:35:21 -0600 Subject: [PATCH 05/18] fix(main) Fix linting and undefined cases Changelog: - Fix linting issues - Define `self.proxy` on the constructor - Prevent issues with undefined variables --- google/cloud/sql/connector/connector.py | 4 ++-- google/cloud/sql/connector/proxy.py | 6 +++--- google/cloud/sql/connector/psycopg.py | 1 - tests/unit/test_psycopg.py | 3 ++- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index e19c85fc6..d3c7e31d9 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -40,6 +40,7 @@ from google.cloud.sql.connector.lazy import LazyRefreshCache from google.cloud.sql.connector.monitored_cache import MonitoredCache import google.cloud.sql.connector.pg8000 as pg8000 +from google.cloud.sql.connector.proxy import start_local_proxy import google.cloud.sql.connector.psycopg as psycopg import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds @@ -47,7 +48,6 @@ from google.cloud.sql.connector.resolver import DnsResolver from google.cloud.sql.connector.utils import format_database_user from google.cloud.sql.connector.utils import generate_keys -from google.cloud.sql.connector.proxy import start_local_proxy logger = logging.getLogger(name=__name__) @@ -438,7 +438,7 @@ async def connect_async( # async drivers are unblocking and can be awaited directly if driver in ASYNC_DRIVERS: return await connector( - host, + ip_address, await conn_info.create_ssl_context(enable_iam_auth), **kwargs, ) diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index 85ece82a1..cbb795ea5 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -14,11 +14,11 @@ limitations under the License. """ -import socket -import os -import ssl import asyncio +import os from pathlib import Path +import socket +import ssl from typing import Optional from google.cloud.sql.connector.exceptions import LocalProxyStartupError diff --git a/google/cloud/sql/connector/psycopg.py b/google/cloud/sql/connector/psycopg.py index fe862bc24..80e824002 100644 --- a/google/cloud/sql/connector/psycopg.py +++ b/google/cloud/sql/connector/psycopg.py @@ -16,7 +16,6 @@ import ssl from typing import Any, TYPE_CHECKING -import threading SERVER_PROXY_PORT = 3307 diff --git a/tests/unit/test_psycopg.py b/tests/unit/test_psycopg.py index 8d9fe1f85..aa30a9c4b 100644 --- a/tests/unit/test_psycopg.py +++ b/tests/unit/test_psycopg.py @@ -18,7 +18,8 @@ import ssl from typing import Any -from mock import patch, PropertyMock +from mock import patch +from mock import PropertyMock import pytest from google.cloud.sql.connector.psycopg import connect From ef46cf0526a4a20858b33f094bf5234ca03fac40 Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Thu, 24 Jul 2025 21:11:28 -0600 Subject: [PATCH 06/18] fix(main): Fix psycopg unit test --- tests/unit/test_psycopg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_psycopg.py b/tests/unit/test_psycopg.py index aa30a9c4b..8088751c5 100644 --- a/tests/unit/test_psycopg.py +++ b/tests/unit/test_psycopg.py @@ -33,7 +33,7 @@ async def test_psycopg(context: ssl.SSLContext, kwargs: Any) -> None: socket.create_connection((ip_addr, 3307)), server_hostname=ip_addr, ) - with patch("psycopg.connect") as mock_connect: + with patch("psycopg.Connection.connect") as mock_connect: type(mock_connect.return_value).autocommit = PropertyMock(return_value=True) connection = connect(ip_addr, sock, **kwargs) assert connection.autocommit is True From 38255b5167596e77542bc5990fca9ffbb60a3c43 Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Mon, 28 Jul 2025 13:06:57 -0600 Subject: [PATCH 07/18] fix(main): Fix linting --- google/cloud/sql/connector/proxy.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index cbb795ea5..385f2ff51 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -19,7 +19,6 @@ from pathlib import Path import socket import ssl -from typing import Optional from google.cloud.sql.connector.exceptions import LocalProxyStartupError @@ -28,8 +27,8 @@ def start_local_proxy( ssl_sock: ssl.SSLSocket, - socket_path: Optional[str] = "/tmp/connector-socket", - loop: Optional[asyncio.AbstractEventLoop] = None, + socket_path: str, + loop: asyncio.AbstractEventLoop ) -> asyncio.Task: """Helper function to start a UNIX based local proxy for transport messages through the SSL Socket. From 1974df51f653e1d9444227fca22d889e9cd5e8f5 Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Mon, 28 Jul 2025 18:23:15 -0600 Subject: [PATCH 08/18] fix(main) Increase code coverage to 94% Changelog: - Add unit tests for proxy - Add test case to connector for drivers that require the local proxy - Make proper adjustments to code --- google/cloud/sql/connector/connector.py | 4 +- google/cloud/sql/connector/proxy.py | 6 +- tests/unit/test_connector.py | 39 +++++++++++++ tests/unit/test_proxy.py | 75 +++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 6 deletions(-) create mode 100644 tests/unit/test_proxy.py diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index d3c7e31d9..ff0159aaa 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -40,7 +40,7 @@ from google.cloud.sql.connector.lazy import LazyRefreshCache from google.cloud.sql.connector.monitored_cache import MonitoredCache import google.cloud.sql.connector.pg8000 as pg8000 -from google.cloud.sql.connector.proxy import start_local_proxy +import google.cloud.sql.connector.proxy as proxy import google.cloud.sql.connector.psycopg as psycopg import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds @@ -454,7 +454,7 @@ async def connect_async( if driver in LOCAL_PROXY_DRIVERS: local_socket_path = kwargs.pop("local_socket_path", "/tmp/connector-socket") host = local_socket_path - self._proxy = start_local_proxy( + self._proxy = proxy.start_local_proxy( sock, socket_path=f"{local_socket_path}/.s.PGSQL.{SERVER_PROXY_PORT}", loop=self._loop diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index 385f2ff51..e3bd69614 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -74,9 +74,9 @@ def start_local_proxy( async def local_communication( unix_socket, ssl_sock, socket_path, loop ): + client, _ = await loop.sock_accept(unix_socket) + try: - client, _ = await loop.sock_accept(unix_socket) - while True: data = await loop.sock_recv(client, LOCAL_PROXY_MAX_MESSAGE_SIZE) if not data: @@ -85,8 +85,6 @@ async def local_communication( ssl_sock.sendall(data) response = ssl_sock.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE) await loop.sock_sendall(client, response) - except Exception: - pass finally: client.close() os.remove(socket_path) # Clean up the socket file diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index a09b5b72f..a3efd4e55 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -17,6 +17,8 @@ import asyncio import os from threading import Thread +import socket +import ssl from typing import Union from aiohttp import ClientResponseError @@ -35,6 +37,7 @@ from google.cloud.sql.connector.exceptions import IncompatibleDriverError from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.resolver import DnsResolver +from google.cloud.sql.connector.proxy import start_local_proxy @pytest.mark.asyncio @@ -283,6 +286,42 @@ async def test_Connector_connect_async( # verify connector made connection call assert connection is True +@pytest.mark.usefixtures("proxy_server") +@pytest.mark.asyncio +async def test_Connector_connect_local_proxy( + fake_credentials: Credentials, fake_client: CloudSQLClient, context: ssl.SSLContext +) -> None: + """Test that Connector.connect can launch start_local_proxy.""" + async with Connector( + credentials=fake_credentials, loop=asyncio.get_running_loop() + ) as connector: + connector._client = fake_client + socket_path = "/tmp/connector-socket/socket" + ip_addr = "127.0.0.1" + ssl_sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + ) + loop = asyncio.get_running_loop() + task = start_local_proxy(ssl_sock, socket_path, loop) + # patch db connection creation + with patch("google.cloud.sql.connector.proxy.start_local_proxy") as mock_proxy: + with patch("google.cloud.sql.connector.psycopg.connect") as mock_connect: + mock_connect.return_value = True + mock_proxy.return_value = task + connection = await connector.connect_async( + "test-project:test-region:test-instance", + "psycopg", + user="my-user", + password="my-pass", + db="my-db", + local_socket_path=socket_path, + ) + # verify connector called local proxy + mock_connect.assert_called_once() + mock_proxy.assert_called_once() + assert connection is True + @pytest.mark.asyncio async def test_Connector_connect_async_multiple_event_loops( diff --git a/tests/unit/test_proxy.py b/tests/unit/test_proxy.py new file mode 100644 index 000000000..7e1e5c79d --- /dev/null +++ b/tests/unit/test_proxy.py @@ -0,0 +1,75 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import socket +import ssl +from typing import Any + +from mock import Mock +import pytest + +from google.cloud.sql.connector import proxy + +LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760 + +@pytest.mark.usefixtures("proxy_server") +@pytest.mark.asyncio +async def test_proxy_creates_folder(context: ssl.SSLContext, kwargs: Any) -> None: + """Test to verify that the proxy server is getting back the task.""" + ip_addr = "127.0.0.1" + path = "/tmp/connector-socket/socket" + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + ) + loop = asyncio.get_running_loop() + + task = proxy.start_local_proxy(sock, path, loop) + assert (task is not None) + + proxy_task = asyncio.gather(task) + try: + await asyncio.wait_for(proxy_task, timeout=0.1) + except TimeoutError: + pass # This task runs forever so it is expected to throw this exception + +@pytest.mark.usefixtures("proxy_server") +@pytest.mark.asyncio +async def test_local_proxy_communication(context: ssl.SSLContext, kwargs: Any) -> None: + """Test to verify that the communication is getting through.""" + socket_path = "/tmp/connector-socket/socket" + ssl_sock = Mock(spec=ssl.SSLSocket) + loop = asyncio.get_running_loop() + + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client: + ssl_sock.recv.return_value = b"Received" + + task = proxy.start_local_proxy(ssl_sock, socket_path, loop) + + client.connect(socket_path) + client.sendall(b"Test") + await asyncio.sleep(1) + + ssl_sock.sendall.assert_called_with(b"Test") + response = client.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE) + assert (response == b"Received") + + client.close() + await asyncio.sleep(1) + + proxy_task = asyncio.gather(task) + await asyncio.wait_for(proxy_task, timeout=2) From b663bbeb968fd02673299ecb788e0c3cbc21ec1b Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Mon, 28 Jul 2025 20:26:45 -0600 Subject: [PATCH 09/18] fix(main): Add support for Python 3.9 --- google/cloud/sql/connector/connector.py | 10 ++-------- tests/unit/test_connector.py | 6 ++++++ tests/unit/test_proxy.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index ff0159aaa..fab2d0797 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -537,18 +537,12 @@ async def close_async(self) -> None: """Helper function to cancel the cache's tasks and close aiohttp.ClientSession.""" self._closed = True + if self._proxies: + await asyncio.gather(*[proxy.close() for proxy in self._proxies]) if self._client: await self._client.close() - if self._proxy: - proxy_task = asyncio.gather(self._proxy) - try: - await asyncio.wait_for(proxy_task, timeout=0.1) - except TimeoutError: - # This task runs forever so it is expected to raise this exception - pass await asyncio.gather(*[cache.close() for cache in self._cache.values()]) - async def create_async_connector( ip_type: str | IPTypes = IPTypes.PUBLIC, enable_iam_auth: bool = False, diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index a3efd4e55..06d0ab2a2 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -321,6 +321,12 @@ async def test_Connector_connect_local_proxy( mock_connect.assert_called_once() mock_proxy.assert_called_once() assert connection is True + + proxy_task = asyncio.gather(task) + try: + await asyncio.wait_for(proxy_task, timeout=0.1) + except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): + pass # This task runs forever so it is expected to throw this exception @pytest.mark.asyncio diff --git a/tests/unit/test_proxy.py b/tests/unit/test_proxy.py index 7e1e5c79d..c1143f19f 100644 --- a/tests/unit/test_proxy.py +++ b/tests/unit/test_proxy.py @@ -44,7 +44,7 @@ async def test_proxy_creates_folder(context: ssl.SSLContext, kwargs: Any) -> Non proxy_task = asyncio.gather(task) try: await asyncio.wait_for(proxy_task, timeout=0.1) - except TimeoutError: + except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): pass # This task runs forever so it is expected to throw this exception @pytest.mark.usefixtures("proxy_server") From ca6a27115600d3aea25bdc3ade0ed438cf54381c Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Mon, 18 Aug 2025 18:38:31 -0600 Subject: [PATCH 10/18] fix(main): Make local proxy to accept multiple connections (WIP) --- google/cloud/sql/connector/connector.py | 50 +++-- google/cloud/sql/connector/enums.py | 4 +- .../{psycopg.py => local_unix_socket.py} | 37 +--- google/cloud/sql/connector/proxy.py | 176 ++++++++++++------ pyproject.toml | 1 - tests/system/test_psycopg_connection.py | 17 +- ...t_psycopg.py => test_local_unix_socket.py} | 14 +- 7 files changed, 169 insertions(+), 130 deletions(-) rename google/cloud/sql/connector/{psycopg.py => local_unix_socket.py} (51%) rename tests/unit/{test_psycopg.py => test_local_unix_socket.py} (61%) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index fab2d0797..78b9f4784 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -38,10 +38,10 @@ from google.cloud.sql.connector.exceptions import ConnectorLoopError from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.lazy import LazyRefreshCache +import google.cloud.sql.connector.local_unix_socket as local_unix_socket from google.cloud.sql.connector.monitored_cache import MonitoredCache import google.cloud.sql.connector.pg8000 as pg8000 -import google.cloud.sql.connector.proxy as proxy -import google.cloud.sql.connector.psycopg as psycopg +from google.cloud.sql.connector.proxy import Proxy import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds from google.cloud.sql.connector.resolver import DefaultResolver @@ -52,7 +52,6 @@ logger = logging.getLogger(name=__name__) ASYNC_DRIVERS = ["asyncpg"] -LOCAL_PROXY_DRIVERS = ["psycopg"] SERVER_PROXY_PORT = 3307 _DEFAULT_SCHEME = "https://" _DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" @@ -160,7 +159,7 @@ def __init__( self._cache: dict[tuple[str, bool], MonitoredCache] = {} self._client: Optional[CloudSQLClient] = None self._closed: bool = False - self._proxy: Optional[asyncio.Task] = None + self._proxies: Optional[Proxy] = None # initialize credentials scopes = ["https://www.googleapis.com/auth/sqlservice.admin"] @@ -221,6 +220,29 @@ def __init__( def universe_domain(self) -> str: return self._universe_domain or _DEFAULT_UNIVERSE_DOMAIN + def start_unix_socket_proxy_async( + self, + instance_connection_name: str, + local_socket_path: str, + **kwargs: Any + ) -> None: + """Creates a new Proxy instance and stores it to properly disposal + + Args: + instance_connection_string (str): The instance connection name of the + Cloud SQL instance to connect to. Takes the form of + "project-id:region:instance-name" + + Example: "my-project:us-central1:my-instance" + + local_socket_path (str): A string representing the location of the local socket. + + **kwargs: Any driver-specific arguments to pass to the underlying + driver .connect call. + """ + # TODO: validates the local socket path is not the same as other invocation + self._proxies.append(new Proxy(self, instance_connection_name, local_socket_path, self.loop, **kwargs)) + def connect( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: @@ -238,7 +260,7 @@ def connect( Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect - with. Supported drivers are pymysql, pg8000, psycopg, and pytds. + with. Supported drivers are pymysql, pg8000, local_unix_socket, and pytds. **kwargs: Any driver-specific arguments to pass to the underlying driver .connect call. @@ -280,7 +302,7 @@ async def connect_async( Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect - with. Supported drivers are pymysql, asyncpg, pg8000, psycopg, and + with. Supported drivers are pymysql, asyncpg, pg8000, local_unix_socket, and pytds. **kwargs: Any driver-specific arguments to pass to the underlying @@ -293,7 +315,7 @@ async def connect_async( ValueError: Connection attempt with built-in database authentication and then subsequent attempt with IAM database authentication. KeyError: Unsupported database driver Must be one of pymysql, asyncpg, - pg8000, psycopg, and pytds. + pg8000, local_unix_socket, and pytds. RuntimeError: Connector has been closed. Cannot connect using a closed Connector. """ @@ -362,7 +384,7 @@ async def connect_async( connect_func = { "pymysql": pymysql.connect, "pg8000": pg8000.connect, - "psycopg": psycopg.connect, + "local_unix_socket": local_unix_socket.connect, "asyncpg": asyncpg.connect, "pytds": pytds.connect, } @@ -449,17 +471,6 @@ async def connect_async( server_hostname=ip_address, ) - host = ip_address - # start local proxy if driver needs it - if driver in LOCAL_PROXY_DRIVERS: - local_socket_path = kwargs.pop("local_socket_path", "/tmp/connector-socket") - host = local_socket_path - self._proxy = proxy.start_local_proxy( - sock, - socket_path=f"{local_socket_path}/.s.PGSQL.{SERVER_PROXY_PORT}", - loop=self._loop - ) - # If this connection was opened using a domain name, then store it # for later in case we need to forcibly close it on failover. if conn_info.conn_name.domain_name: @@ -543,6 +554,7 @@ async def close_async(self) -> None: await self._client.close() await asyncio.gather(*[cache.close() for cache in self._cache.values()]) + async def create_async_connector( ip_type: str | IPTypes = IPTypes.PUBLIC, enable_iam_auth: bool = False, diff --git a/google/cloud/sql/connector/enums.py b/google/cloud/sql/connector/enums.py index 5926b75a2..7bde045ed 100644 --- a/google/cloud/sql/connector/enums.py +++ b/google/cloud/sql/connector/enums.py @@ -62,7 +62,7 @@ class DriverMapping(Enum): ASYNCPG = "POSTGRES" PG8000 = "POSTGRES" - PSYCOPG = "POSTGRES" + LOCAL_UNIX_SOCKET = "ANY" PYMYSQL = "MYSQL" PYTDS = "SQLSERVER" @@ -79,7 +79,7 @@ def validate_engine(driver: str, engine_version: str) -> None: the given engine. """ mapping = DriverMapping[driver.upper()] - if not engine_version.startswith(mapping.value): + if not mapping.value == "ANY" and not engine_version.startswith(mapping.value): raise IncompatibleDriverError( f"Database driver '{driver}' is incompatible with database " f"version '{engine_version}'. Given driver can " diff --git a/google/cloud/sql/connector/psycopg.py b/google/cloud/sql/connector/local_unix_socket.py similarity index 51% rename from google/cloud/sql/connector/psycopg.py rename to google/cloud/sql/connector/local_unix_socket.py index 80e824002..497e503e0 100644 --- a/google/cloud/sql/connector/psycopg.py +++ b/google/cloud/sql/connector/local_unix_socket.py @@ -19,44 +19,19 @@ SERVER_PROXY_PORT = 3307 -if TYPE_CHECKING: - import psycopg - - def connect( host: str, sock: ssl.SSLSocket, **kwargs: Any -) -> "psycopg.Connection": - """Helper function to create a psycopg DB-API connection object. +) -> "ssl.SSLSocket": + """Helper function to retrieve the socket for local UNIX sockets. Args: host (str): A string containing the socket path used by the local proxy. sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL server CA cert and ephemeral cert. - kwargs: Additional arguments to pass to the psycopg connect method. + kwargs: Additional arguments to pass to the local UNIX socket connect method. Returns: - psycopg.Connection: A psycopg connection to the Cloud SQL - instance. - - Raises: - ImportError: The psycopg module cannot be imported. + ssl.SSLSocket: The same socket """ - try: - from psycopg import Connection - except ImportError: - raise ImportError( - 'Unable to import module "psycopg." Please install and try again.' - ) - - user = kwargs.pop("user") - db = kwargs.pop("db") - passwd = kwargs.pop("password", None) - - kwargs.pop("timeout", None) - - conn = Connection.connect( - f"host={host} port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require", - **kwargs - ) - - return conn + + return sock diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index e3bd69614..568c463ad 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -18,73 +18,125 @@ import os from pathlib import Path import socket +import selectors import ssl +from google.cloud.sql.connector import Connector from google.cloud.sql.connector.exceptions import LocalProxyStartupError SERVER_PROXY_PORT = 3307 LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760 -def start_local_proxy( - ssl_sock: ssl.SSLSocket, - socket_path: str, - loop: asyncio.AbstractEventLoop -) -> asyncio.Task: - """Helper function to start a UNIX based local proxy for - transport messages through the SSL Socket. - - Args: - ssl_sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL - server CA cert and ephemeral cert. - socket_path: A system path that is going to be used to store the socket. - loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. - - Returns: - asyncio.Task: The asyncio task containing the proxy server process. - - Raises: - LocalProxyStartupError: Local UNIX socket based proxy was not able to - get started. - """ - unix_socket = None - - try: - path_parts = socket_path.rsplit('/', 1) - parent_directory = '/'.join(path_parts[:-1]) - - desired_path = Path(parent_directory) - desired_path.mkdir(parents=True, exist_ok=True) - - if os.path.exists(socket_path): - os.remove(socket_path) - unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - - unix_socket.bind(socket_path) - unix_socket.listen(1) - unix_socket.setblocking(False) - os.chmod(socket_path, 0o600) - except Exception: - raise LocalProxyStartupError( - 'Local UNIX socket based proxy was not able to get started.' - ) - - return loop.create_task(local_communication(unix_socket, ssl_sock, socket_path, loop)) - - -async def local_communication( - unix_socket, ssl_sock, socket_path, loop -): - client, _ = await loop.sock_accept(unix_socket) - - try: + +class Proxy: + """Creates an "accept loop" async task which will open the unix server socket and listen for new connections.""" + + def __init__( + self, + connector: Connector, + instance_connection_string: str, + socket_path: str, + loop: asyncio.AbstractEventLoop, + **kwargs: Any + ) -> None: + """Keeps track of all the async tasks and starts the accept loop for new connections. + + Args: + connector (Connector): The instance where this Proxy class was created. + + instance_connection_string (str): The instance connection name of the + Cloud SQL instance to connect to. Takes the form of + "project-id:region:instance-name" + + Example: "my-project:us-central1:my-instance" + + socket_path (str): A system path that is going to be used to store the socket. + + loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. + + **kwargs: Any driver-specific arguments to pass to the underlying + driver .connect call. + """ + self._connection_tasks = [] + self._addr = instance_connection_string + self._kwargs = kwargs + self._connector = connector + self._task = loop.create_task(accept_loop(socket_path, loop, **kwargs)) + + async def accept_loop( + self + socket_path: str, + loop: asyncio.AbstractEventLoop + ) -> asyncio.Task: + """Starts a UNIX based local proxy for transporting messages through + the SSL Socket, and waits until there is a new connection to accept, to register it + and keep track of it. + + Args: + socket_path: A system path that is going to be used to store the socket. + + loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. + + Raises: + LocalProxyStartupError: Local UNIX socket based proxy was not able to + get started. + """ + unix_socket = None + sel = selectors.DefaultSelector() + + try: + path_parts = socket_path.rsplit('/', 1) + parent_directory = '/'.join(path_parts[:-1]) + + desired_path = Path(parent_directory) + desired_path.mkdir(parents=True, exist_ok=True) + + if os.path.exists(socket_path): + os.remove(socket_path) + + unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + unix_socket.bind(socket_path) + unix_socket.listen(1) + unix_socket.setblocking(False) + os.chmod(socket_path, 0o600) + + sel.register(unix_socket, selectors.EVENT_READ, data=None) + + except Exception: + raise LocalProxyStartupError( + 'Local UNIX socket based proxy was not able to get started.' + ) + while True: - data = await loop.sock_recv(client, LOCAL_PROXY_MAX_MESSAGE_SIZE) - if not data: - client.close() - break - ssl_sock.sendall(data) - response = ssl_sock.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE) - await loop.sock_sendall(client, response) - finally: - client.close() - os.remove(socket_path) # Clean up the socket file + client, _ = await loop.sock_accept(unix_socket) + self._connection_tasks.append(loop.create_task(self.client_socket(client, unix_socket, socket_path, loop))) + + async def close_async(self): + proxy_task = asyncio.gather(self._task) + try: + await asyncio.wait_for(proxy_task, timeout=0.1) + except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): + pass # This task runs forever so it is expected to throw this exception + + + async def client_socket( + self, client, unix_socket, socket_path, loop + ): + try: + ssl_sock = self.connector.connect( + self._addr, + 'local_unix_socket', + **self._kwargs + ) + while True: + data = await loop.sock_recv(client, LOCAL_PROXY_MAX_MESSAGE_SIZE) + if not data: + client.close() + break + ssl_sock.sendall(data) + response = ssl_sock.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE) + await loop.sock_sendall(client, response) + finally: + client.close() + os.remove(socket_path) # Clean up the socket file diff --git a/pyproject.toml b/pyproject.toml index 1c933cf6a..cbf0dd10f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,6 @@ Changelog = "https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/b [project.optional-dependencies] pymysql = ["PyMySQL>=1.1.0"] pg8000 = ["pg8000>=1.31.1"] -psycopg = ["psycopg>=3.2.9"] pytds = ["python-tds>=1.15.0"] asyncpg = ["asyncpg>=0.30.0"] diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py index 9d6363d5c..e54d957c8 100644 --- a/tests/system/test_psycopg_connection.py +++ b/tests/system/test_psycopg_connection.py @@ -20,6 +20,7 @@ # [START cloud_sql_connector_postgres_psycopg] from typing import Union +from psycopg import Connection import sqlalchemy from google.cloud.sql.connector import Connector @@ -79,21 +80,25 @@ def create_sqlalchemy_engine( instance connection names ("my-project:my-region:my-instance"). """ connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver) + unix_socket_path = "/tmp/conn" + await connector.start_unix_socket_proxy_async( + instance_connection_name, + unix_socket_path, + ip_type=ip_type, # can be "public", "private" or "psc" + ) # create SQLAlchemy connection pool engine = sqlalchemy.create_engine( "postgresql+psycopg://", - creator=lambda: connector.connect( - instance_connection_name, - "psycopg", + creator=lambda: Connection.connect( + f"host={unix_socket_path} port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require", user=user, password=password, db=db, - local_socket_path="/tmp/conn", - ip_type=ip_type, # can be "public", "private" or "psc" autocommit=True, - ), + ) ) + return engine, connector diff --git a/tests/unit/test_psycopg.py b/tests/unit/test_local_unix_socket.py similarity index 61% rename from tests/unit/test_psycopg.py rename to tests/unit/test_local_unix_socket.py index 8088751c5..8672857ec 100644 --- a/tests/unit/test_psycopg.py +++ b/tests/unit/test_local_unix_socket.py @@ -22,20 +22,16 @@ from mock import PropertyMock import pytest -from google.cloud.sql.connector.psycopg import connect +from google.cloud.sql.connector.local_unix_socket import connect @pytest.mark.usefixtures("proxy_server") -async def test_psycopg(context: ssl.SSLContext, kwargs: Any) -> None: - """Test to verify that psycopg gets to proper connection call.""" +async def test_local_unix_socket(context: ssl.SSLContext, kwargs: Any) -> None: + """Test to verify that local_unix_socket gets to proper connection call.""" ip_addr = "127.0.0.1" sock = context.wrap_socket( socket.create_connection((ip_addr, 3307)), server_hostname=ip_addr, ) - with patch("psycopg.Connection.connect") as mock_connect: - type(mock_connect.return_value).autocommit = PropertyMock(return_value=True) - connection = connect(ip_addr, sock, **kwargs) - assert connection.autocommit is True - # verify that driver connection call would be made - assert mock_connect.assert_called_once + connection = connect(ip_addr, sock, **kwargs) + assert connection == sock From 6933d994a475fca31aea209321d3617a6a97401d Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Wed, 10 Sep 2025 13:18:07 -0600 Subject: [PATCH 11/18] test: Fix compilation errors --- google/cloud/sql/connector/connector.py | 63 +++++++++------ .../cloud/sql/connector/local_unix_socket.py | 2 - google/cloud/sql/connector/proxy.py | 52 ++++++------ tests/system/test_psycopg_connection.py | 10 ++- tests/unit/test_connector.py | 80 +++++++++---------- 5 files changed, 109 insertions(+), 98 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 78b9f4784..04e7cbc12 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -159,7 +159,7 @@ def __init__( self._cache: dict[tuple[str, bool], MonitoredCache] = {} self._client: Optional[CloudSQLClient] = None self._closed: bool = False - self._proxies: Optional[Proxy] = None + self._proxies: list[proxy.Proxy] = [] # initialize credentials scopes = ["https://www.googleapis.com/auth/sqlservice.admin"] @@ -220,29 +220,6 @@ def __init__( def universe_domain(self) -> str: return self._universe_domain or _DEFAULT_UNIVERSE_DOMAIN - def start_unix_socket_proxy_async( - self, - instance_connection_name: str, - local_socket_path: str, - **kwargs: Any - ) -> None: - """Creates a new Proxy instance and stores it to properly disposal - - Args: - instance_connection_string (str): The instance connection name of the - Cloud SQL instance to connect to. Takes the form of - "project-id:region:instance-name" - - Example: "my-project:us-central1:my-instance" - - local_socket_path (str): A string representing the location of the local socket. - - **kwargs: Any driver-specific arguments to pass to the underlying - driver .connect call. - """ - # TODO: validates the local socket path is not the same as other invocation - self._proxies.append(new Proxy(self, instance_connection_name, local_socket_path, self.loop, **kwargs)) - def connect( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: @@ -478,7 +455,7 @@ async def connect_async( # Synchronous drivers are blocking and run using executor connect_partial = partial( connector, - host, + ip_address, sock, **kwargs, ) @@ -489,6 +466,42 @@ async def connect_async( await monitored_cache.force_refresh() raise + async def start_unix_socket_proxy_async( + self, instance_connection_string: str, local_socket_path: str, **kwargs: Any + ) -> None: + """Starts a local Unix socket proxy for a Cloud SQL instance. + + Args: + instance_connection_string (str): The instance connection name of the + Cloud SQL instance to connect to. + local_socket_path (str): The path to the local Unix socket. + driver (str): The database driver name. + **kwargs: Keyword arguments to pass to the underlying database + driver. + """ + if "driver" in kwargs: + driver = kwargs["driver"] + else: + driver = "proxy" + + self._init_client(driver) + + # check if a proxy is already running for this socket path + for p in self._proxies: + if p.unix_socket_path == local_socket_path: + raise ValueError( + f"Proxy for socket path {local_socket_path} already exists." + ) + + # Create a new proxy instance + proxy_instance = proxy.Proxy( + local_socket_path, + ConnectorSocketFactory(self, instance_connection_string, **kwargs), + self._loop + ) + await proxy_instance.start() + self._proxies.append(proxy_instance) + async def _remove_cached( self, instance_connection_string: str, enable_iam_auth: bool ) -> None: diff --git a/google/cloud/sql/connector/local_unix_socket.py b/google/cloud/sql/connector/local_unix_socket.py index 497e503e0..25a9d3be3 100644 --- a/google/cloud/sql/connector/local_unix_socket.py +++ b/google/cloud/sql/connector/local_unix_socket.py @@ -17,8 +17,6 @@ import ssl from typing import Any, TYPE_CHECKING -SERVER_PROXY_PORT = 3307 - def connect( host: str, sock: ssl.SSLSocket, **kwargs: Any ) -> "ssl.SSLSocket": diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index 568c463ad..e5d9ab536 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -21,10 +21,8 @@ import selectors import ssl -from google.cloud.sql.connector import Connector from google.cloud.sql.connector.exceptions import LocalProxyStartupError -SERVER_PROXY_PORT = 3307 LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760 @@ -33,11 +31,11 @@ class Proxy: def __init__( self, - connector: Connector, + connector, instance_connection_string: str, socket_path: str, loop: asyncio.AbstractEventLoop, - **kwargs: Any + **kwargs ) -> None: """Keeps track of all the async tasks and starts the accept loop for new connections. @@ -61,28 +59,8 @@ def __init__( self._addr = instance_connection_string self._kwargs = kwargs self._connector = connector - self._task = loop.create_task(accept_loop(socket_path, loop, **kwargs)) - async def accept_loop( - self - socket_path: str, - loop: asyncio.AbstractEventLoop - ) -> asyncio.Task: - """Starts a UNIX based local proxy for transporting messages through - the SSL Socket, and waits until there is a new connection to accept, to register it - and keep track of it. - - Args: - socket_path: A system path that is going to be used to store the socket. - - loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. - - Raises: - LocalProxyStartupError: Local UNIX socket based proxy was not able to - get started. - """ unix_socket = None - sel = selectors.DefaultSelector() try: path_parts = socket_path.rsplit('/', 1) @@ -100,14 +78,34 @@ async def accept_loop( unix_socket.listen(1) unix_socket.setblocking(False) os.chmod(socket_path, 0o600) - - sel.register(unix_socket, selectors.EVENT_READ, data=None) + + self._task = loop.create_task(self.accept_loop(unix_socket, socket_path, loop)) except Exception: raise LocalProxyStartupError( 'Local UNIX socket based proxy was not able to get started.' ) + async def accept_loop( + self, + unix_socket, + socket_path: str, + loop: asyncio.AbstractEventLoop + ) -> asyncio.Task: + """Starts a UNIX based local proxy for transporting messages through + the SSL Socket, and waits until there is a new connection to accept, to register it + and keep track of it. + + Args: + socket_path: A system path that is going to be used to store the socket. + + loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. + + Raises: + LocalProxyStartupError: Local UNIX socket based proxy was not able to + get started. + """ + print("on accept loop") while True: client, _ = await loop.sock_accept(unix_socket) self._connection_tasks.append(loop.create_task(self.client_socket(client, unix_socket, socket_path, loop))) @@ -124,7 +122,7 @@ async def client_socket( self, client, unix_socket, socket_path, loop ): try: - ssl_sock = self.connector.connect( + ssl_sock = self._connector.connect( self._addr, 'local_unix_socket', **self._kwargs diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py index e54d957c8..6f0e07c42 100644 --- a/tests/system/test_psycopg_connection.py +++ b/tests/system/test_psycopg_connection.py @@ -27,6 +27,7 @@ from google.cloud.sql.connector import DefaultResolver from google.cloud.sql.connector import DnsResolver +SERVER_PROXY_PORT = 3307 def create_sqlalchemy_engine( instance_connection_name: str, @@ -80,8 +81,9 @@ def create_sqlalchemy_engine( instance connection names ("my-project:my-region:my-instance"). """ connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver) - unix_socket_path = "/tmp/conn" - await connector.start_unix_socket_proxy_async( + unix_socket_folder = "/tmp/conn" + unix_socket_path = f"{unix_socket_folder}/.s.PGSQL.3307" + connector.start_unix_socket_proxy_async( instance_connection_name, unix_socket_path, ip_type=ip_type, # can be "public", "private" or "psc" @@ -91,10 +93,10 @@ def create_sqlalchemy_engine( engine = sqlalchemy.create_engine( "postgresql+psycopg://", creator=lambda: Connection.connect( - f"host={unix_socket_path} port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require", + f"host={unix_socket_folder} port={SERVER_PROXY_PORT} dbname={db} user={user} password={password} sslmode=require", user=user, password=password, - db=db, + dbname=db, autocommit=True, ) ) diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 06d0ab2a2..e3963faee 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -286,47 +286,47 @@ async def test_Connector_connect_async( # verify connector made connection call assert connection is True -@pytest.mark.usefixtures("proxy_server") -@pytest.mark.asyncio -async def test_Connector_connect_local_proxy( - fake_credentials: Credentials, fake_client: CloudSQLClient, context: ssl.SSLContext -) -> None: - """Test that Connector.connect can launch start_local_proxy.""" - async with Connector( - credentials=fake_credentials, loop=asyncio.get_running_loop() - ) as connector: - connector._client = fake_client - socket_path = "/tmp/connector-socket/socket" - ip_addr = "127.0.0.1" - ssl_sock = context.wrap_socket( - socket.create_connection((ip_addr, 3307)), - server_hostname=ip_addr, - ) - loop = asyncio.get_running_loop() - task = start_local_proxy(ssl_sock, socket_path, loop) - # patch db connection creation - with patch("google.cloud.sql.connector.proxy.start_local_proxy") as mock_proxy: - with patch("google.cloud.sql.connector.psycopg.connect") as mock_connect: - mock_connect.return_value = True - mock_proxy.return_value = task - connection = await connector.connect_async( - "test-project:test-region:test-instance", - "psycopg", - user="my-user", - password="my-pass", - db="my-db", - local_socket_path=socket_path, - ) - # verify connector called local proxy - mock_connect.assert_called_once() - mock_proxy.assert_called_once() - assert connection is True +# @pytest.mark.usefixtures("proxy_server") +# @pytest.mark.asyncio +# async def test_Connector_connect_local_proxy( +# fake_credentials: Credentials, fake_client: CloudSQLClient, context: ssl.SSLContext +# ) -> None: +# """Test that Connector.connect can launch start_local_proxy.""" +# async with Connector( +# credentials=fake_credentials, loop=asyncio.get_running_loop() +# ) as connector: +# connector._client = fake_client +# socket_path = "/tmp/connector-socket/socket" +# ip_addr = "127.0.0.1" +# ssl_sock = context.wrap_socket( +# socket.create_connection((ip_addr, 3307)), +# server_hostname=ip_addr, +# ) +# loop = asyncio.get_running_loop() +# task = start_local_proxy(ssl_sock, socket_path, loop) +# # patch db connection creation +# with patch("google.cloud.sql.connector.proxy.start_local_proxy") as mock_proxy: +# with patch("google.cloud.sql.connector.psycopg.connect") as mock_connect: +# mock_connect.return_value = True +# mock_proxy.return_value = task +# connection = await connector.connect_async( +# "test-project:test-region:test-instance", +# "psycopg", +# user="my-user", +# password="my-pass", +# db="my-db", +# local_socket_path=socket_path, +# ) +# # verify connector called local proxy +# mock_connect.assert_called_once() +# mock_proxy.assert_called_once() +# assert connection is True - proxy_task = asyncio.gather(task) - try: - await asyncio.wait_for(proxy_task, timeout=0.1) - except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): - pass # This task runs forever so it is expected to throw this exception +# proxy_task = asyncio.gather(task) +# try: +# await asyncio.wait_for(proxy_task, timeout=0.1) +# except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): +# pass # This task runs forever so it is expected to throw this exception @pytest.mark.asyncio From 173951ed7797a8761dac0f79be48b74ffaf87c46 Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Thu, 2 Oct 2025 14:51:13 -0600 Subject: [PATCH 12/18] feat: Add proxy server and fix all unit tests --- google/cloud/sql/connector/connector.py | 189 +++++++++-- google/cloud/sql/connector/enums.py | 4 +- google/cloud/sql/connector/proxy.py | 308 ++++++++++++------ pyproject.toml | 4 + tests/conftest.py | 200 +++++++++--- tests/system/test_psycopg_connection.py | 8 +- tests/unit/test_connector.py | 218 +++++++++---- tests/unit/test_proxy.py | 407 +++++++++++++++++++++--- 8 files changed, 1049 insertions(+), 289 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 04e7cbc12..944634d88 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -20,7 +20,6 @@ from functools import partial import logging import os -import socket from threading import Thread from types import TracebackType from typing import Any, Callable, Optional, Union @@ -38,10 +37,9 @@ from google.cloud.sql.connector.exceptions import ConnectorLoopError from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.lazy import LazyRefreshCache -import google.cloud.sql.connector.local_unix_socket as local_unix_socket from google.cloud.sql.connector.monitored_cache import MonitoredCache import google.cloud.sql.connector.pg8000 as pg8000 -from google.cloud.sql.connector.proxy import Proxy +import google.cloud.sql.connector.proxy as proxy import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds from google.cloud.sql.connector.resolver import DefaultResolver @@ -220,6 +218,108 @@ def __init__( def universe_domain(self) -> str: return self._universe_domain or _DEFAULT_UNIVERSE_DOMAIN + async def _get_cache( + self, + instance_connection_string: str, + enable_iam_auth: bool, + ip_type: IPTypes, + driver: str | None, + ) -> MonitoredCache: + """Helper function to get instance's cache from Connector cache.""" + + # resolve instance connection name + conn_name = await self._resolver.resolve(instance_connection_string) + cache_key = (str(conn_name), enable_iam_auth) + + # if cache entry doesn't exist or is closed, create it + if cache_key not in self._cache or self._cache[cache_key].closed: + # if lazy refresh, init keys now + if self._refresh_strategy == RefreshStrategy.LAZY and self._keys is None: + self._keys = asyncio.create_task(generate_keys()) + # create cache + if self._refresh_strategy == RefreshStrategy.LAZY: + logger.debug( + f"['{conn_name}']: Refresh strategy is set to lazy refresh" + ) + cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache( + conn_name, + self._init_client(driver), + self._keys, # type: ignore + enable_iam_auth, + ) + else: + logger.debug( + f"['{conn_name}']: Refresh strategy is set to background refresh" + ) + cache = RefreshAheadCache( + conn_name, + self._init_client(driver), + self._keys, # type: ignore + enable_iam_auth, + ) + # wrap cache as a MonitoredCache + monitored_cache = MonitoredCache( + cache, + self._failover_period, + self._resolver, + ) + logger.debug(f"['{conn_name}']: Connection info added to cache") + self._cache[cache_key] = monitored_cache + + monitored_cache = self._cache[(str(conn_name), enable_iam_auth)] + + # Check that the information is valid and matches the driver and db type + try: + conn_info = await monitored_cache.connect_info() + # validate driver matches intended database engine + if driver: + DriverMapping.validate_engine(driver, conn_info.database_version) + if ip_type: + conn_info.get_preferred_ip(ip_type) + except Exception: + await self._remove_cached(str(conn_name), enable_iam_auth) + raise + + return monitored_cache + + async def connect_socket_async( + self, + instance_connection_string: str, + protocol_fn: Callable[[], asyncio.Protocol], + **kwargs: Any, + ) -> tuple[asyncio.Transport, asyncio.Protocol]: + """Helper function to connect to a Cloud SQL instance and return a socket.""" + + enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) + ip_type = kwargs.pop("ip_type", self._ip_type) + driver = kwargs.pop("driver", None) + # if ip_type is str, convert to IPTypes enum + if isinstance(ip_type, str): + ip_type = IPTypes._from_str(ip_type) + + monitored_cache = await self._get_cache( + instance_connection_string, enable_iam_auth, ip_type, driver + ) + + try: + conn_info = await monitored_cache.connect_info() + ctx = await conn_info.create_ssl_context(enable_iam_auth) + ip_address = conn_info.get_preferred_ip(ip_type) + tx, p = await self._loop.create_connection( + protocol_fn, host=ip_address, port=3307, ssl=ctx + ) + except Exception as ex: + logger.exception("exception starting tls protocol", exc_info=ex) + # with an error from Cloud SQL Admin API call or IP type, invalidate + # the cache and re-raise the error + await self._remove_cached( + instance_connection_string, + enable_iam_auth, + ) + raise + + return tx, p + def connect( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: @@ -237,7 +337,7 @@ def connect( Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect - with. Supported drivers are pymysql, pg8000, local_unix_socket, and pytds. + with. Supported drivers are pymysql, pg8000, psycopg, and pytds. **kwargs: Any driver-specific arguments to pass to the underlying driver .connect call. @@ -261,6 +361,18 @@ def connect( ) return connect_future.result() + def _init_client(self, driver: Optional[str]) -> CloudSQLClient: + """Lazy initialize the client, setting the driver name in the user agent string.""" + if self._client is None: + self._client = CloudSQLClient( + self._sqladmin_api_endpoint, + self._quota_project, + self._credentials, + user_agent=self._user_agent, + driver=driver, + ) + return self._client + async def connect_async( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: @@ -279,7 +391,7 @@ async def connect_async( Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect - with. Supported drivers are pymysql, asyncpg, pg8000, local_unix_socket, and + with. Supported drivers are pymysql, asyncpg, pg8000, psycopg, and pytds. **kwargs: Any driver-specific arguments to pass to the underlying @@ -358,15 +470,15 @@ async def connect_async( logger.debug(f"['{conn_name}']: Connection info added to cache") self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache + # Map drivers to connect functions connect_func = { "pymysql": pymysql.connect, "pg8000": pg8000.connect, - "local_unix_socket": local_unix_socket.connect, "asyncpg": asyncpg.connect, "pytds": pytds.connect, } - # only accept supported database drivers + # Only accept supported database drivers try: connector: Callable = connect_func[driver] # type: ignore except KeyError: @@ -376,6 +488,7 @@ async def connect_async( # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes._from_str(ip_type) + enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) kwargs["timeout"] = kwargs.get("timeout", self._timeout) # Host and ssl options come from the certificates and metadata, so we don't @@ -384,7 +497,12 @@ async def connect_async( kwargs.pop("ssl", None) kwargs.pop("port", None) - # attempt to get connection info for Cloud SQL instance + monitored_cache = await self._get_cache( + instance_connection_string, enable_iam_auth, ip_type, driver + ) + conn_info = await monitored_cache.connect_info() + ip_address = conn_info.get_preferred_ip(ip_type) + try: conn_info = await monitored_cache.connect_info() # validate driver matches intended database engine @@ -430,39 +548,31 @@ async def connect_async( ) if formatted_user != kwargs["user"]: logger.debug( - f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}" + f"['{instance_connection_string}']: " + "Truncated IAM database username from " + f"{kwargs['user']} to {formatted_user}" ) kwargs["user"] = formatted_user - try: + + ctx = await conn_info.create_ssl_context(enable_iam_auth) # async drivers are unblocking and can be awaited directly if driver in ASYNC_DRIVERS: - return await connector( - ip_address, - await conn_info.create_ssl_context(enable_iam_auth), - **kwargs, + return await connector(ip_address, ctx, **kwargs) + else: + # Synchronous drivers are blocking and run using executor + tx, _ = await self.connect_socket_async( + instance_connection_string, asyncio.Protocol, **kwargs ) - # Create socket with SSLContext for sync drivers - ctx = await conn_info.create_ssl_context(enable_iam_auth) - sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) - - # If this connection was opened using a domain name, then store it - # for later in case we need to forcibly close it on failover. - if conn_info.conn_name.domain_name: - monitored_cache.sockets.append(sock) - # Synchronous drivers are blocking and run using executor - connect_partial = partial( - connector, - ip_address, - sock, - **kwargs, - ) - return await self._loop.run_in_executor(None, connect_partial) + # See https://docs.python.org/3/library/asyncio-protocol.html#asyncio.BaseTransport.get_extra_info + sock = tx.get_extra_info("ssl_object") + connect_partial = partial(connector, ip_address, sock, **kwargs) + return await self._loop.run_in_executor(None, connect_partial) except Exception: # with any exception, we attempt a force refresh, then throw the error + monitored_cache = await self._get_cache( + instance_connection_string, enable_iam_auth, ip_type, driver + ) await monitored_cache.force_refresh() raise @@ -501,7 +611,7 @@ async def start_unix_socket_proxy_async( ) await proxy_instance.start() self._proxies.append(proxy_instance) - + async def _remove_cached( self, instance_connection_string: str, enable_iam_auth: bool ) -> None: @@ -567,7 +677,6 @@ async def close_async(self) -> None: await self._client.close() await asyncio.gather(*[cache.close() for cache in self._cache.values()]) - async def create_async_connector( ip_type: str | IPTypes = IPTypes.PUBLIC, enable_iam_auth: bool = False, @@ -657,3 +766,13 @@ async def create_async_connector( resolver=resolver, failover_period=failover_period, ) + + +class ConnectorSocketFactory(proxy.ServerConnectionFactory): + def __init__(self, connector:Connector, instance_connection_string:str, **kwargs): + self._connector = connector + self._instance_connection_string = instance_connection_string + self._connect_args=kwargs + + async def connect(self, protocol_fn: Callable[[], asyncio.Protocol]): + await self._connector.connect_socket_async(self._instance_connection_string, protocol_fn, **self._connect_args) \ No newline at end of file diff --git a/google/cloud/sql/connector/enums.py b/google/cloud/sql/connector/enums.py index 7bde045ed..5926b75a2 100644 --- a/google/cloud/sql/connector/enums.py +++ b/google/cloud/sql/connector/enums.py @@ -62,7 +62,7 @@ class DriverMapping(Enum): ASYNCPG = "POSTGRES" PG8000 = "POSTGRES" - LOCAL_UNIX_SOCKET = "ANY" + PSYCOPG = "POSTGRES" PYMYSQL = "MYSQL" PYTDS = "SQLSERVER" @@ -79,7 +79,7 @@ def validate_engine(driver: str, engine_version: str) -> None: the given engine. """ mapping = DriverMapping[driver.upper()] - if not mapping.value == "ANY" and not engine_version.startswith(mapping.value): + if not engine_version.startswith(mapping.value): raise IncompatibleDriverError( f"Database driver '{driver}' is incompatible with database " f"version '{engine_version}'. Given driver can " diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index e5d9ab536..99a121782 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -14,127 +14,253 @@ limitations under the License. """ +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod import asyncio +from functools import partial +import logging import os from pathlib import Path -import socket -import selectors -import ssl +from typing import Callable, List -from google.cloud.sql.connector.exceptions import LocalProxyStartupError +logger = logging.getLogger(name=__name__) -LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760 +class BaseProxyProtocol(asyncio.Protocol): + """ + A protocol to proxy data between two transports. + """ -class Proxy: - """Creates an "accept loop" async task which will open the unix server socket and listen for new connections.""" + def __init__(self, proxy: Proxy): + super().__init__() + self.proxy = proxy + self._buffer = bytearray() + self._target: asyncio.Transport | None = None + self.transport: asyncio.Transport | None = None + self._cached: List[bytes] = [] + logger.debug(f"__init__ {self}") + + def connection_made(self, transport): + logger.debug(f"connection_made {self}") + self.transport = transport + + def data_received(self, data): + if self._target is None: + self._cached.append(data) + else: + self._target.write(data) + + def set_target(self, target: asyncio.Transport): + logger.debug(f"set_target {self}") + self._target = target + if self._cached: + self._target.writelines(self._cached) + self._cached = [] + + def eof_received(self): + logger.debug(f"eof_received {self}") + if self._target is not None: + self._target.write_eof() + + def connection_lost(self, exc: Exception | None): + logger.debug(f"connection_lost {exc} {self}") + if self._target is not None: + self._target.close() + + +class ProxyClientConnection: + """ + Holds all of the tasks and details for a client proxy + """ def __init__( self, - connector, - instance_connection_string: str, - socket_path: str, - loop: asyncio.AbstractEventLoop, - **kwargs - ) -> None: - """Keeps track of all the async tasks and starts the accept loop for new connections. - - Args: - connector (Connector): The instance where this Proxy class was created. + client_transport: asyncio.Transport, + client_protocol: ClientToServerProtocol, + ): + self.client_transport = client_transport + self.client_protocol = client_protocol + self.server_transport: asyncio.Transport | None = None + self.server_protocol: ServerToClientProtocol | None = None + self.task: asyncio.Task | None = None + + def close(self): + logger.debug(f"closing {self}") + if self.client_transport is not None: + self._close_transport(self.client_transport) + if self.server_transport is not None: + self._close_transport(self.server_transport) + + def _close_transport(self, transport:asyncio.Transport): + if transport.is_closing(): + return + if transport.can_write_eof(): + transport.write_eof() + else: + transport.close() + +class ClientToServerProtocol(BaseProxyProtocol): + """ + Protocol to copy bytes from the unix socket client to the database server + """ + + def __init__(self, proxy: Proxy): + super().__init__(proxy) + self._buffer = bytearray() + self._target: asyncio.Transport | None = None + logger.debug(f"__init__ {self}") + + def connection_made(self, transport): + # When a connection is made, open the server connection + super().connection_made(transport) + self.proxy._handle_client_connection(transport, self) - instance_connection_string (str): The instance connection name of the - Cloud SQL instance to connect to. Takes the form of - "project-id:region:instance-name" - Example: "my-project:us-central1:my-instance" +class ServerToClientProtocol(BaseProxyProtocol): + """ + Protocol to copy bytes from the database server to the client socket + """ - socket_path (str): A system path that is going to be used to store the socket. + def __init__(self, proxy: Proxy, cconn: ProxyClientConnection): + super().__init__(proxy) + self._buffer = bytearray() + self._target = cconn.client_transport + self._client_protocol = cconn.client_protocol + logger.debug(f"__init__ {self}") - loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. + def connection_made(self, transport): + super().connection_made(transport) + self._client_protocol.set_target(transport) - **kwargs: Any driver-specific arguments to pass to the underlying - driver .connect call. + def connection_lost(self, exc: Exception | None): + super().connection_lost(exc) + self.proxy._handle_server_connection_lost() + +class ServerConnectionFactory(ABC): + """ + ServerConnectionFactory is an abstract class that provides connections to the service. + """ + @abstractmethod + async def connect(self, protocol_fn: Callable[[], asyncio.Protocol]): + """ + Establishes a connection to the server and configures it to use the protocol + returned from protocol_fn, with asyncio.EventLoop.create_connection(). + :param protocol_fn: the protocol function + :return: None """ - self._connection_tasks = [] - self._addr = instance_connection_string - self._kwargs = kwargs - self._connector = connector + pass - unix_socket = None +class Proxy: + """ + A class to represent a local Unix socket proxy for a Cloud SQL instance. + This class manages a Unix socket that listens for incoming connections and + proxies them to a Cloud SQL instance. + """ - try: - path_parts = socket_path.rsplit('/', 1) - parent_directory = '/'.join(path_parts[:-1]) + def __init__( + self, + unix_socket_path: str, + server_connection_factory: ServerConnectionFactory, + loop: asyncio.AbstractEventLoop, + ): + """ + Creates a new Proxy + :param unix_socket_path: the path to listen for the proxy connection + :param loop: The event loop + :param instance_connect: A function that will establish the async connection to the server + + The instance_connect function is an asynchronous function that should set up a new connection. + It takes one argument - another function that + """ + self.unix_socket_path = unix_socket_path + self.alive = True + self._loop = loop + self._server: asyncio.AbstractServer | None = None + self._client_connections: set[ProxyClientConnection] = set() + self._server_connection_factory = server_connection_factory - desired_path = Path(parent_directory) - desired_path.mkdir(parents=True, exist_ok=True) + async def start(self) -> None: + """Starts the Unix socket server.""" + if os.path.exists(self.unix_socket_path): + os.remove(self.unix_socket_path) - if os.path.exists(socket_path): - os.remove(socket_path) + parent_dir = Path(self.unix_socket_path).parent + parent_dir.mkdir(parents=True, exist_ok=True) - unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + def new_protocol() -> ClientToServerProtocol: + return ClientToServerProtocol(self) - unix_socket.bind(socket_path) - unix_socket.listen(1) - unix_socket.setblocking(False) - os.chmod(socket_path, 0o600) + logger.debug(f"Socket path: {self.unix_socket_path}") + self._server = await self._loop.create_unix_server( + new_protocol, path=self.unix_socket_path + ) + self._loop.create_task(self._server.serve_forever()) - self._task = loop.create_task(self.accept_loop(unix_socket, socket_path, loop)) + def _handle_client_connection( + self, + client_transport: asyncio.Transport, + client_protocol: ClientToServerProtocol, + ) -> None: + """ + Register a new client connection and initiate the task to create a database connection. + This is called by ClientToServerProtocol.connection_made - except Exception: - raise LocalProxyStartupError( - 'Local UNIX socket based proxy was not able to get started.' - ) + :param client_transport: the client transport for the client unix socket + :param client_protocol: the instance for the + :return: None + """ + conn = ProxyClientConnection(client_transport, client_protocol) + self._client_connections.add(conn) + conn.task = self._loop.create_task(self._create_db_instance_connection(conn)) + conn.task.add_done_callback(lambda _: self._client_connections.discard(conn)) - async def accept_loop( + def _handle_server_connection_lost( self, - unix_socket, - socket_path: str, - loop: asyncio.AbstractEventLoop - ) -> asyncio.Task: - """Starts a UNIX based local proxy for transporting messages through - the SSL Socket, and waits until there is a new connection to accept, to register it - and keep track of it. + ) -> None: + """ + Closes the proxy server if the connection to the server is lost - Args: - socket_path: A system path that is going to be used to store the socket. + :return: None + """ + logger.debug(f"Closing proxy server due to lost connection") + self._loop.create_task(self.close()) + + async def _create_db_instance_connection(self, conn: ProxyClientConnection) -> None: + """ + Manages a single proxy connection from a client to the Cloud SQL instance. + """ + try: + logger.debug("_proxy_connection() started") + new_protocol = partial(ServerToClientProtocol, self, conn) + + # Establish connection to the database + await self._server_connection_factory.connect(new_protocol) + logger.debug("_proxy_connection() succeeded") - loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. + except Exception as e: + logger.error(f"Error handling proxy connection: {e}") + await self.close() + raise e - Raises: - LocalProxyStartupError: Local UNIX socket based proxy was not able to - get started. + async def close(self) -> None: """ - print("on accept loop") - while True: - client, _ = await loop.sock_accept(unix_socket) - self._connection_tasks.append(loop.create_task(self.client_socket(client, unix_socket, socket_path, loop))) + Shuts down the proxy server and cleans up resources. + """ + logger.info(f"Closing Unix socket proxy at {self.unix_socket_path}") - async def close_async(self): - proxy_task = asyncio.gather(self._task) - try: - await asyncio.wait_for(proxy_task, timeout=0.1) - except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): - pass # This task runs forever so it is expected to throw this exception + if self._server: + self._server.close() + await self._server.wait_closed() + if self._client_connections: + for conn in list(self._client_connections): + conn.close() + await asyncio.wait([c.task for c in self._client_connections if c.task is not None], timeout=0.1) - async def client_socket( - self, client, unix_socket, socket_path, loop - ): - try: - ssl_sock = self._connector.connect( - self._addr, - 'local_unix_socket', - **self._kwargs - ) - while True: - data = await loop.sock_recv(client, LOCAL_PROXY_MAX_MESSAGE_SIZE) - if not data: - client.close() - break - ssl_sock.sendall(data) - response = ssl_sock.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE) - await loop.sock_sendall(client, response) - finally: - client.close() - os.remove(socket_path) # Clean up the socket file + if os.path.exists(self.unix_socket_path): + os.remove(self.unix_socket_path) + + logger.info(f"Unix socket proxy for {self.unix_socket_path} closed.") + self.alive = False \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cbf0dd10f..8ffce4d63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,10 @@ exclude = ['docs/*', 'samples/*'] [tool.pytest.ini_options] asyncio_mode = "auto" +log_cli = true +log_cli_level = "DEBUG" +log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" +log_cli_date_format = "%Y-%m-%d %H:%M:%S.%f" [tool.ruff.lint] extend-select = ["I"] diff --git a/tests/conftest.py b/tests/conftest.py index 83d7a78f3..54836ee84 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,8 @@ """ import asyncio +from asyncio import Server +import logging import os import socket import ssl @@ -36,6 +38,7 @@ from google.cloud.sql.connector.utils import write_to_file SCOPES = ["https://www.googleapis.com/auth/sqlservice.admin"] +logger = logging.getLogger(name=__name__) def pytest_addoption(parser: Any) -> None: @@ -84,55 +87,138 @@ def fake_credentials() -> FakeCredentials: return FakeCredentials() -async def start_proxy_server(instance: FakeCSQLInstance) -> None: +async def start_proxy_server_async( + instance: FakeCSQLInstance, with_read_write: bool +) -> Server: """Run local proxy server capable of performing mTLS""" ip_address = "127.0.0.1" port = 3307 - # create socket - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - # create SSL/TLS context - context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - context.minimum_version = ssl.TLSVersion.TLSv1_3 - # tmpdir and its contents are automatically deleted after the CA cert - # and cert chain are loaded into the SSLcontext. The values - # need to be written to files in order to be loaded by the SSLContext - server_key_bytes = instance.server_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), + logger.debug("start_proxy_server_async started") + + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.minimum_version = ssl.TLSVersion.TLSv1_3 + # tmpdir and its contents are automatically deleted after the CA cert + # and cert chain are loaded into the SSLcontext. The values + # need to be written to files in order to be loaded by the SSLContext + server_key_bytes = instance.server_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + async with TemporaryDirectory() as tmpdir: + server_filename, _, key_filename = await write_to_file( + tmpdir, instance.server_cert_pem, "", server_key_bytes ) - async with TemporaryDirectory() as tmpdir: - server_filename, _, key_filename = await write_to_file( - tmpdir, instance.server_cert_pem, "", server_key_bytes - ) - context.load_cert_chain(server_filename, key_filename) - # allow socket to be re-used - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # bind socket to Cloud SQL proxy server port on localhost - sock.bind((ip_address, port)) - # listen for incoming connections - sock.listen(5) - - with context.wrap_socket(sock, server_side=True) as ssock: - while True: - conn, _ = ssock.accept() - conn.close() + context.load_cert_chain(server_filename, key_filename) + + async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + logger.debug("Received fake connection") + if with_read_write: + line = await reader.readline() + logger.debug(f"Received request {line}") + writer.write("world\n".encode("utf-8")) + await writer.drain() + logger.debug("Wrote response") + if writer.can_write_eof(): + writer.write_eof() + logger.debug("Closing connection") + writer.close() + await writer.wait_closed() + logger.debug("Closed connection") + + server = await asyncio.start_server( + handler, host=ip_address, port=port, ssl=context + ) + logger.debug("Listening on 127.0.0.1:3307") + asyncio.create_task(server.serve_forever()) + return server -@pytest.fixture(scope="session") -def proxy_server(fake_instance: FakeCSQLInstance) -> None: - """Run local proxy server capable of performing mTLS""" - thread = Thread( - target=asyncio.run, - args=( - start_proxy_server( - fake_instance, - ), - ), - daemon=True, +@pytest.fixture(scope="function") +def proxy_server_async(fake_instance: FakeCSQLInstance): + # Create an event loop in a different thread for the server + loop = asyncio.new_event_loop() + + def f(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + logger.debug("exiting thread") + + t = Thread(target=f, args=(loop,)) + t.start() + t.join(1) + + # Submit the server task to the thread + server_fut = asyncio.run_coroutine_threadsafe( + start_proxy_server_async(fake_instance, True), loop ) - thread.start() - thread.join(1.0) # add a delay to allow the proxy server to start + while not server_fut.done(): + t.join(0.1) + logger.debug("proxy_server_async server started") + yield + logger.debug("proxy_server_async fixture done") + + logger.debug("proxy_server_async fixture cleanup") + + # Stop the server after the test is complete + async def stop_server(): + logger.debug("inside_cleanup closing server") + server_fut.result().close() + loop.shutdown_asyncgens() + loop.stop() + logger.debug("inside_cleanup end") + + logger.debug("cleanup starting") + asyncio.run_coroutine_threadsafe(stop_server(), loop) + logger.debug("cleanup done") + while loop.is_running(): + t.join(0.1) + logger.debug("loop is not running") + loop.close() + t.join(1) + + +@pytest.fixture(scope="function") +def proxy_server(fake_instance: FakeCSQLInstance): + # Create an event loop in a different thread for the server + loop = asyncio.new_event_loop() + + def f(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + logger.debug("exiting thread") + + t = Thread(target=f, args=(loop,)) + t.start() + + # Submit the server task to the thread + server_fut = asyncio.run_coroutine_threadsafe( + start_proxy_server_async(fake_instance, False), loop + ) + while not server_fut.done(): + t.join(0.1) + logger.debug("proxy_server_async server started") + yield + logger.debug("proxy_server_async fixture done") + + logger.debug("proxy_server_async fixture cleanup") + + # Stop the server after the test is complete + async def stop_server(): + logger.debug("inside_cleanup closing server") + server_fut.result().close() + loop.shutdown_asyncgens() + loop.stop() + logger.debug("inside_cleanup end") + + logger.debug("cleanup starting") + asyncio.run_coroutine_threadsafe(stop_server(), loop) + logger.debug("cleanup done") + while loop.is_running(): + t.join(0.1) + logger.debug("loop is not running") + loop.close() @pytest.fixture @@ -191,3 +277,33 @@ async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache ) yield cache await cache.close() + + +@pytest.fixture +def connected_socket_pair() -> tuple[socket.socket, socket.socket]: + """A fixture that provides a pair of connected sockets.""" + server, client = socket.socketpair() + yield server, client + server.close() + client.close() + + +@pytest.fixture +async def echo_server() -> AsyncGenerator[tuple[str, int], None]: + """A fixture that starts an asyncio echo server.""" + + async def handle_echo(reader, writer): + while True: + data = await reader.read(100) + if not data: + break + writer.write(data) + await writer.drain() + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(handle_echo, "127.0.0.1", 0) + addr = server.sockets[0].getsockname() + yield addr + server.close() + await server.wait_closed() \ No newline at end of file diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py index 6f0e07c42..1fed06356 100644 --- a/tests/system/test_psycopg_connection.py +++ b/tests/system/test_psycopg_connection.py @@ -29,7 +29,7 @@ SERVER_PROXY_PORT = 3307 -def create_sqlalchemy_engine( +async def create_sqlalchemy_engine( instance_connection_name: str, user: str, password: str, @@ -83,7 +83,7 @@ def create_sqlalchemy_engine( connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver) unix_socket_folder = "/tmp/conn" unix_socket_path = f"{unix_socket_folder}/.s.PGSQL.3307" - connector.start_unix_socket_proxy_async( + await connector.start_unix_socket_proxy_async( instance_connection_name, unix_socket_path, ip_type=ip_type, # can be "public", "private" or "psc" @@ -107,7 +107,7 @@ def create_sqlalchemy_engine( # [END cloud_sql_connector_postgres_psycopg] -def test_psycopg_connection() -> None: +async def test_psycopg_connection() -> None: """Basic test to get time from database.""" inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"] user = os.environ["POSTGRES_USER"] @@ -115,7 +115,7 @@ def test_psycopg_connection() -> None: db = os.environ["POSTGRES_DB"] ip_type = os.environ.get("IP_TYPE", "public") - engine, connector = create_sqlalchemy_engine( + engine, connector = await create_sqlalchemy_engine( inst_conn_name, user, password, db, ip_type ) with engine.connect() as conn: diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index e3963faee..f36464b71 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -1,12 +1,9 @@ """ Copyright 2021 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -15,10 +12,10 @@ """ import asyncio +import logging import os from threading import Thread import socket -import ssl from typing import Union from aiohttp import ClientResponseError @@ -39,20 +36,28 @@ from google.cloud.sql.connector.resolver import DnsResolver from google.cloud.sql.connector.proxy import start_local_proxy +logger = logging.getLogger(name=__name__) + @pytest.mark.asyncio async def test_connect_enable_iam_auth_error( - fake_credentials: Credentials, fake_client: CloudSQLClient + fake_credentials: Credentials, + fake_client: CloudSQLClient, + connected_socket_pair: tuple[socket.socket, socket.socket], ) -> None: """Test that calling connect() with different enable_iam_auth argument values creates two cache entries.""" connect_string = "test-project:test-region:test-instance" + server, client = connected_socket_pair async with Connector( credentials=fake_credentials, loop=asyncio.get_running_loop() ) as connector: connector._client = fake_client # patch db connection creation - with patch("google.cloud.sql.connector.asyncpg.connect") as mock_connect: + with ( + patch("socket.create_connection", return_value=client), + patch("google.cloud.sql.connector.asyncpg.connect") as mock_connect, + ): mock_connect.return_value = True # connect with enable_iam_auth False connection = await connector.connect_async( @@ -85,6 +90,7 @@ async def test_connect_enable_iam_auth_error( async def test_connect_incompatible_driver_error( fake_credentials: Credentials, fake_client: CloudSQLClient, + proxy_server, ) -> None: """Test that calling connect() with driver that is incompatible with database version throws error.""" @@ -94,14 +100,8 @@ async def test_connect_incompatible_driver_error( ) as connector: connector._client = fake_client # try to connect using pymysql driver to a Postgres database - with pytest.raises(IncompatibleDriverError) as exc_info: + with pytest.raises(IncompatibleDriverError): await connector.connect_async(connect_string, "pymysql") - assert ( - exc_info.value.args[0] - == "Database driver 'pymysql' is incompatible with database version" - " 'POSTGRES_15'. Given driver can only be used with Cloud SQL MYSQL" - " databases." - ) def test_connect_with_unsupported_driver(fake_credentials: Credentials) -> None: @@ -242,13 +242,19 @@ def test_Connector_Init_bad_ip_type(fake_credentials: Credentials) -> None: def test_Connector_connect_bad_ip_type( - fake_credentials: Credentials, fake_client: CloudSQLClient + fake_credentials: Credentials, + fake_client: CloudSQLClient, + connected_socket_pair: tuple[socket.socket, socket.socket], ) -> None: """Test that Connector.connect errors due to bad ip_type str.""" + server, client = connected_socket_pair with Connector(credentials=fake_credentials) as connector: connector._client = fake_client bad_ip_type = "bad-ip-type" - with pytest.raises(ValueError) as exc_info: + with ( + patch("socket.create_connection", return_value=client), + pytest.raises(ValueError) as exc_info, + ): connector.connect( "test-project:test-region:test-instance", "pg8000", @@ -266,15 +272,21 @@ def test_Connector_connect_bad_ip_type( @pytest.mark.asyncio async def test_Connector_connect_async( - fake_credentials: Credentials, fake_client: CloudSQLClient + fake_credentials: Credentials, + fake_client: CloudSQLClient, + connected_socket_pair: tuple[socket.socket, socket.socket], ) -> None: """Test that Connector.connect_async can properly return a DB API connection.""" + server, client = connected_socket_pair async with Connector( credentials=fake_credentials, loop=asyncio.get_running_loop() ) as connector: connector._client = fake_client # patch db connection creation - with patch("google.cloud.sql.connector.asyncpg.connect") as mock_connect: + with ( + patch("socket.create_connection", return_value=client), + patch("google.cloud.sql.connector.asyncpg.connect") as mock_connect, + ): mock_connect.return_value = True connection = await connector.connect_async( "test-project:test-region:test-instance", @@ -286,48 +298,6 @@ async def test_Connector_connect_async( # verify connector made connection call assert connection is True -# @pytest.mark.usefixtures("proxy_server") -# @pytest.mark.asyncio -# async def test_Connector_connect_local_proxy( -# fake_credentials: Credentials, fake_client: CloudSQLClient, context: ssl.SSLContext -# ) -> None: -# """Test that Connector.connect can launch start_local_proxy.""" -# async with Connector( -# credentials=fake_credentials, loop=asyncio.get_running_loop() -# ) as connector: -# connector._client = fake_client -# socket_path = "/tmp/connector-socket/socket" -# ip_addr = "127.0.0.1" -# ssl_sock = context.wrap_socket( -# socket.create_connection((ip_addr, 3307)), -# server_hostname=ip_addr, -# ) -# loop = asyncio.get_running_loop() -# task = start_local_proxy(ssl_sock, socket_path, loop) -# # patch db connection creation -# with patch("google.cloud.sql.connector.proxy.start_local_proxy") as mock_proxy: -# with patch("google.cloud.sql.connector.psycopg.connect") as mock_connect: -# mock_connect.return_value = True -# mock_proxy.return_value = task -# connection = await connector.connect_async( -# "test-project:test-region:test-instance", -# "psycopg", -# user="my-user", -# password="my-pass", -# db="my-db", -# local_socket_path=socket_path, -# ) -# # verify connector called local proxy -# mock_connect.assert_called_once() -# mock_proxy.assert_called_once() -# assert connection is True - -# proxy_task = asyncio.gather(task) -# try: -# await asyncio.wait_for(proxy_task, timeout=0.1) -# except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): -# pass # This task runs forever so it is expected to throw this exception - @pytest.mark.asyncio async def test_Connector_connect_async_multiple_event_loops( @@ -395,7 +365,9 @@ def test_Connector_close_called_multiple_times(fake_credentials: Credentials) -> async def test_Connector_remove_cached_bad_instance( - fake_credentials: Credentials, fake_client: CloudSQLClient + fake_credentials: Credentials, + fake_client: CloudSQLClient, + proxy_server, ) -> None: """When a Connector attempts to retrieve connection info for a non-existent instance, it should delete the instance from @@ -420,7 +392,9 @@ async def test_Connector_remove_cached_bad_instance( async def test_Connector_remove_cached_no_ip_type( - fake_credentials: Credentials, fake_client: CloudSQLClient + fake_credentials: Credentials, + fake_client: CloudSQLClient, + proxy_server, ) -> None: """When a Connector attempts to connect and preferred IP type is not present, it should delete the instance from the cache and ensure no background refresh @@ -552,10 +526,12 @@ def test_configured_quota_project_env_var( @pytest.mark.asyncio -async def test_connect_async_closed_connector( - fake_credentials: Credentials, fake_client: CloudSQLClient +async def test_Connector_start_unix_socket_proxy_async( + fake_credentials: Credentials, + fake_client: CloudSQLClient, + proxy_server_async: None, ) -> None: - """Test that calling connect_async() on a closed connector raises an error.""" + """Test that Connector.connect_async can properly return a DB API connection.""" async with Connector( credentials=fake_credentials, loop=asyncio.get_running_loop() ) as connector: @@ -704,3 +680,117 @@ async def test_Connector_connect_async_custom_dns_resolver_fallback( fake_client.instance.ip_addrs = original_ips +async def test_Connector_start_unix_socket_proxy_async( + fake_credentials: Credentials, + fake_client: CloudSQLClient, + proxy_server_async: None, +) -> None: + """Test that Connector.connect_async can properly return a DB API connection.""" + async with Connector( + credentials=fake_credentials, loop=asyncio.get_running_loop() + ) as connector: + connector._client = fake_client + # Open proxy connection + # start the proxy server + await connector.start_unix_socket_proxy_async( + "test-project:test-region:test-instance", + "/tmp/csql-python/proxytest/.s.PGSQL.5432", + driver="asyncpg", + user="my-user", + password="my-pass", + db="my-db", + ) + # Wait for server to start + await asyncio.sleep(0.5) + + reader, writer = await asyncio.open_unix_connection( + "/tmp/csql-python/proxytest/.s.PGSQL.5432" + ) + writer.write("hello\n".encode()) + await writer.drain() + await asyncio.sleep(0.5) + msg = await reader.readline() + assert msg.decode("utf-8") == "world\n" + + +class TestProtocol(asyncio.Protocol): + """ + A protocol to proxy data between two transports. + """ + + def __init__(self): + self._buffer = bytearray() + logger.debug(f"__init__ {self}") + self.received = bytearray() + self.connected = asyncio.Future() + self.future = asyncio.Future() + + def data_received(self, data): + logger.debug("received {!r}".format(data)) + self.received = data + + def connection_made(self, transport): + logger.debug(f"connection_made called {self}") + self.transport = transport + if not self.connected.done(): + self.connected.set_result(True) + # Write the request and EOF + transport.write("hello\n".encode()) + # if transport.can_write_eof(): + # transport.write_eof() + logger.debug(f"connection_made done, wrote hello{self}") + + def eof_received(self) -> bool | None: + logger.debug(f"eof_received {self}") + # If this has received data, then close. + if len(self.received) > 0: + self.transport.close() + if not self.connected.done(): + self.connected.set_result(True) + if not self.future.done(): + self.future.set_result(True) + return True + + def connection_lost(self, exc: Exception | None) -> None: + logger.debug(f"connection_lost {exc} {self}") + self.transport.abort() + if not self.connected.done(): + self.connected.set_result(True) + if not self.future.done(): + self.future.set_result(True) + super().connection_lost(exc) + + +@pytest.mark.asyncio +async def test_Connector_connect_socket_async( + fake_credentials: Credentials, + fake_client: CloudSQLClient, + proxy_server_async: None, +) -> None: + """Test that Connector.connect_async can properly return a DB API connection.""" + async with Connector( + credentials=fake_credentials, loop=asyncio.get_running_loop() + ) as connector: + logger.info("client socket opening") + connector._client = fake_client + p = TestProtocol() + + # Open proxy connection + # start the proxy server + future = connector.connect_socket_async( + "test-project:test-region:test-instance", + lambda: p, + driver="asyncpg", + user="my-user", + password="my-pass", + db="my-db", + ) + logger.info("client socket opening") + await future + logger.info("client socket opened") + await p.connected + logger.info("client socket connected") + await p.future + logger.info("client socket done") + + assert p.received.decode() == "world\n" diff --git a/tests/unit/test_proxy.py b/tests/unit/test_proxy.py index c1143f19f..0c179186d 100644 --- a/tests/unit/test_proxy.py +++ b/tests/unit/test_proxy.py @@ -15,61 +15,366 @@ """ import asyncio -import socket -import ssl -from typing import Any +import os +import shutil +import tempfile +from unittest.mock import MagicMock -from mock import Mock import pytest -from google.cloud.sql.connector import proxy +from google.cloud.sql.connector.proxy import Proxy, ServerConnectionFactory + + +@pytest.fixture +def short_tmpdir(): + """Create a temporary directory with a short path.""" + dir_path = tempfile.mkdtemp(dir="/tmp") + yield dir_path + shutil.rmtree(dir_path) -LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760 -@pytest.mark.usefixtures("proxy_server") @pytest.mark.asyncio -async def test_proxy_creates_folder(context: ssl.SSLContext, kwargs: Any) -> None: - """Test to verify that the proxy server is getting back the task.""" - ip_addr = "127.0.0.1" - path = "/tmp/connector-socket/socket" - sock = context.wrap_socket( - socket.create_connection((ip_addr, 3307)), - server_hostname=ip_addr, - ) - loop = asyncio.get_running_loop() - - task = proxy.start_local_proxy(sock, path, loop) - assert (task is not None) - - proxy_task = asyncio.gather(task) - try: - await asyncio.wait_for(proxy_task, timeout=0.1) - except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): - pass # This task runs forever so it is expected to throw this exception - -@pytest.mark.usefixtures("proxy_server") +async def test_proxy_creates_folder_and_socket(short_tmpdir): + """ + Test to verify that the Proxy server creates the folder and socket file. + """ + socket_path = os.path.join(short_tmpdir, ".s.PGSQL.5432") + connector = MagicMock(spec=ServerConnectionFactory) + proxy = Proxy(socket_path, connector, asyncio.get_event_loop()) + await proxy.start() + + assert os.path.exists(short_tmpdir) + assert os.path.exists(socket_path) + + await proxy.close() + + +# A mock ServerConnectionFactory for testing purposes. +class MockServerConnectionFactory(ServerConnectionFactory): + def __init__(self, loop): + self.server_protocol = None + self.server_transport = None + self.connect_called = asyncio.Event() + self.connect_ran = asyncio.Event() + self.force_connect_error = False + self.loop = loop + self.server_data = bytearray() + + async def connect(self, protocol_fn): + self.connect_called.set() + if self.force_connect_error: + raise Exception("Forced connection error") + + self.server_protocol = protocol_fn() + # Create a mock transport for server-side communication + self.server_transport = MagicMock(spec=asyncio.Transport) + self.server_transport.write.side_effect = self.server_data.extend + self.server_transport.is_closing.return_value = False + + # Simulate connection made for the server protocol + self.server_protocol.connection_made(self.server_transport) + self.connect_ran.set() + return self.server_transport, self.server_protocol + + +# Test fixture for the proxy +@pytest.fixture +async def proxy_server(short_tmpdir): + socket_path = os.path.join(short_tmpdir, ".s.PGSQL.5432") + loop = asyncio.get_event_loop() + connector = MockServerConnectionFactory(loop) + proxy = Proxy(socket_path, connector, loop) + await proxy.start() + yield proxy, socket_path, connector + await proxy.close() + + @pytest.mark.asyncio -async def test_local_proxy_communication(context: ssl.SSLContext, kwargs: Any) -> None: - """Test to verify that the communication is getting through.""" - socket_path = "/tmp/connector-socket/socket" - ssl_sock = Mock(spec=ssl.SSLSocket) - loop = asyncio.get_running_loop() - - with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client: - ssl_sock.recv.return_value = b"Received" - - task = proxy.start_local_proxy(ssl_sock, socket_path, loop) - - client.connect(socket_path) - client.sendall(b"Test") - await asyncio.sleep(1) - - ssl_sock.sendall.assert_called_with(b"Test") - response = client.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE) - assert (response == b"Received") - - client.close() - await asyncio.sleep(1) - - proxy_task = asyncio.gather(task) - await asyncio.wait_for(proxy_task, timeout=2) +async def test_proxy_client_to_server(proxy_server): + """ + 1. Create a new proxy. Open a client socket to the proxy. Write data to + the client socket. Read data from the server. Check that the data was + received by the server. + """ + proxy, socket_path, connector = proxy_server + reader, writer = await asyncio.open_unix_connection(socket_path) + + # wait for server connection to be established + await connector.connect_called.wait() + + # Write data to the client socket + test_data = b"test data from client" + writer.write(test_data) + await writer.drain() + + # Check that the data was received by the server + await asyncio.sleep(0.01) # give event loop a chance to run + assert connector.server_data == test_data + + writer.close() + await writer.wait_closed() + + +@pytest.mark.asyncio +async def test_proxy_server_to_client(proxy_server): + """ + 2. Create a new proxy, Open a client socket. Write data to the server + socket. Read data from the client socket. Check that the data was + received by the client. + """ + proxy, socket_path, connector = proxy_server + reader, writer = await asyncio.open_unix_connection(socket_path) + + # wait for server connection to be established + await connector.connect_called.wait() + + # Write data from the server to the client + test_data = b"test data from server" + connector.server_protocol.data_received(test_data) + + # Read data from the client socket + received_data = await reader.read(len(test_data)) + + # Check that the data was received by the client + assert received_data == test_data + + writer.close() + await writer.wait_closed() + + +@pytest.mark.asyncio +async def test_proxy_server_connect_fails(proxy_server): + """ + 3. Create a new proxy. Open a client socket. The server socket fails to + connect. Check that the client socket is closed. + """ + proxy, socket_path, connector = proxy_server + connector.force_connect_error = True + + reader, writer = await asyncio.open_unix_connection(socket_path) + + # wait for server connection to be attempted + await connector.connect_called.wait() + + assert os.path.exists(socket_path) == True + + # The client connection should be closed by the proxy + # Reading should return EOF + data = await reader.read(100) + assert data == b"" + + await asyncio.sleep(1) # give proxy a chance to shut down + + assert os.path.exists(socket_path) == False + + +@pytest.mark.asyncio +async def test_proxy_client_closes_connection(proxy_server): + """ + 4. Create a new proxy. Open a client socket. Check that the server + socket connected. Close the client socket. Check that the server socket + is closed gracefully. + """ + proxy, socket_path, connector = proxy_server + reader, writer = await asyncio.open_unix_connection(socket_path) + + # wait for server connection to be established + await connector.connect_called.wait() + assert connector.server_transport is not None + + # Close the client socket + writer.close() + await writer.wait_closed() + + # Check that the server socket is closed + await asyncio.sleep(0.01) # give event loop a chance to run + connector.server_transport.close.assert_called_once() + + +# +# TCP Server Fixtures and Tests +# + + +@pytest.fixture +async def tcp_echo_server(): + """Fixture to create a TCP echo server.""" + + async def echo(reader, writer): + try: + while not reader.at_eof(): + data = await reader.read(1024) + if not data: + break + writer.write(data) + await writer.drain() + finally: + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(echo, "127.0.0.1", 0) + addr = server.sockets[0].getsockname() + host, port = addr[0], addr[1] + + yield host, port + + server.close() + await server.wait_closed() + + +@pytest.fixture +async def tcp_server_accept_and_close(): + """Fixture to create a TCP server that accepts and immediately closes.""" + + async def accept_and_close(reader, writer): + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(accept_and_close, "127.0.0.1", 0) + addr = server.sockets[0].getsockname() + host, port = addr[0], addr[1] + + yield host, port + + server.close() + await server.wait_closed() + + +class TCPServerConnectionFactory(ServerConnectionFactory): + """A ServerConnectionFactory that connects to a TCP server.""" + + def __init__(self, host, port, loop): + self.host = host + self.port = port + self.loop = loop + self.connect_called = asyncio.Event() + self.connect_ran = asyncio.Event() + self.server_transport: asyncio.Transport | None = None + self.server_protocol: asyncio.Protocol | None = None + + async def connect(self, protocol_fn): + self.connect_called.set() + transport, protocol = await asyncio.wait_for(self.loop.create_connection( + protocol_fn, self.host, self.port, + ), timeout=0.5) + self.server_transport = transport + self.server_protocol = protocol + self.connect_ran.set() + return transport, protocol + + +@pytest.fixture +async def tcp_proxy_server(short_tmpdir, tcp_echo_server): + """Fixture to set up a proxy with a TCP backend.""" + socket_path = os.path.join(short_tmpdir, ".s.PGSQL.5432") + loop = asyncio.get_event_loop() + host, port = tcp_echo_server + connector = TCPServerConnectionFactory(host, port, loop) + proxy = Proxy(socket_path, connector, loop) + await proxy.start() + yield proxy, socket_path, connector + await proxy.close() + + +@pytest.fixture +async def tcp_proxy_server_with_closing_backend(short_tmpdir, tcp_server_accept_and_close): + """Fixture to set up a proxy with a TCP backend that closes immediately.""" + socket_path = os.path.join(short_tmpdir, ".s.PGSQL.5432") + loop = asyncio.get_event_loop() + host, port = tcp_server_accept_and_close + connector = TCPServerConnectionFactory(host, port, loop) + proxy = Proxy(socket_path, connector, loop) + await proxy.start() + yield proxy, socket_path, connector + await proxy.close() + + +@pytest.fixture +async def tcp_proxy_server_with_no_tcp_server(short_tmpdir): + """Fixture to set up a proxy with a TCP backend that closes immediately.""" + socket_path = os.path.join(short_tmpdir, ".s.PGSQL.5432") + loop = asyncio.get_event_loop() + connector = TCPServerConnectionFactory("localhost", "34532", loop) + proxy = Proxy(socket_path, connector, loop) + await proxy.start() + yield proxy, socket_path, connector + await proxy.close() + + +@pytest.mark.asyncio +async def test_tcp_proxy_echo(tcp_proxy_server): + """ + Tests data flow from client to a TCP server and back. + """ + proxy, socket_path, connector = tcp_proxy_server + reader, writer = await asyncio.open_unix_connection(socket_path) + await connector.connect_called.wait() + + test_data = b"test data from client" + writer.write(test_data) + await writer.drain() + + # Read echoed data back from the server + received_data = await reader.read(len(test_data)) + assert received_data == test_data + + writer.close() + await writer.wait_closed() + + +@pytest.mark.asyncio +async def test_tcp_proxy_server_connection_refused(tcp_proxy_server_with_no_tcp_server): + """ + Tests that the client socket is closed when TCP connection fails. + """ + proxy, socket_path, connector = tcp_proxy_server_with_no_tcp_server + + reader, writer = await asyncio.open_unix_connection(socket_path) + await connector.connect_called.wait() + test_data = b"test data from client" + writer.write(test_data) + await writer.drain() + + await asyncio.sleep(1.5) + assert os.path.exists(socket_path) == False + + + +@pytest.mark.asyncio +async def test_tcp_proxy_server_unexpected_closed(tcp_proxy_server_with_closing_backend): + """ + Tests that the client socket is closed when TCP connection fails. + """ + proxy, socket_path, connector = tcp_proxy_server_with_closing_backend + + reader, writer = await asyncio.open_unix_connection(socket_path) + await connector.connect_called.wait() + + # The client connection should be closed by the proxy + data = await reader.read(100) + assert data == b"" + + await asyncio.sleep(0.5) # give event loop a chance to run + assert os.path.exists(socket_path) == False + + + +@pytest.mark.asyncio +async def test_tcp_proxy_client_closes_connection(tcp_proxy_server): + """ + Tests that closing the client socket closes the TCP server socket. + """ + proxy, socket_path, connector = tcp_proxy_server + reader, writer = await asyncio.open_unix_connection(socket_path) + await connector.connect_ran.wait() + + assert connector.server_transport is not None + assert not connector.server_transport.is_closing() + + # Close the client socket + writer.close() + await writer.wait_closed() + + # Check that the server socket is closing + await asyncio.sleep(0.01) + assert connector.server_transport.is_closing() \ No newline at end of file From 7cd74a536099909c5f867002625e6db2529f492c Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Fri, 17 Oct 2025 16:09:19 -0600 Subject: [PATCH 13/18] feat: replace ssl_object with sslcontext and build a socket manually --- google/cloud/sql/connector/connector.py | 7 +- tests/system/test_psycopg_connection.py | 129 +++++++++--------------- 2 files changed, 51 insertions(+), 85 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 944634d88..232cb51cf 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -20,6 +20,7 @@ from functools import partial import logging import os +import socket from threading import Thread from types import TracebackType from typing import Any, Callable, Optional, Union @@ -564,7 +565,11 @@ async def connect_async( instance_connection_string, asyncio.Protocol, **kwargs ) # See https://docs.python.org/3/library/asyncio-protocol.html#asyncio.BaseTransport.get_extra_info - sock = tx.get_extra_info("ssl_object") + ctx = tx.get_extra_info("sslcontext") + sock = ctx.wrap_socket( + socket.create_connection((ip_address, SERVER_PROXY_PORT)), + server_hostname=ip_address, + ) connect_partial = partial(connector, ip_address, sock, **kwargs) return await self._loop.run_in_executor(None, connect_partial) diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py index 1fed06356..ae744dacf 100644 --- a/tests/system/test_psycopg_connection.py +++ b/tests/system/test_psycopg_connection.py @@ -14,6 +14,7 @@ limitations under the License. """ +import asyncio from datetime import datetime import os @@ -21,6 +22,8 @@ from typing import Union from psycopg import Connection +import pytest +import logging import sqlalchemy from google.cloud.sql.connector import Connector @@ -29,98 +32,56 @@ SERVER_PROXY_PORT = 3307 -async def create_sqlalchemy_engine( - instance_connection_name: str, - user: str, - password: str, - db: str, - ip_type: str = "public", - refresh_strategy: str = "background", - resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver, -) -> tuple[sqlalchemy.engine.Engine, Connector]: - """Creates a connection pool for a Cloud SQL instance and returns the pool - and the connector. Callers are responsible for closing the pool and the - connector. - - A sample invocation looks like: - - engine, connector = create_sqlalchemy_engine( - inst_conn_name, - user, - password, - db, - ) - with engine.connect() as conn: - time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() - conn.commit() - curr_time = time[0] - # do something with query result - connector.close() - - Args: - instance_connection_name (str): - The instance connection name specifies the instance relative to the - project and region. For example: "my-project:my-region:my-instance" - user (str): - The database user name, e.g., root - password (str): - The database user's password, e.g., secret-password - db (str): - The name of the database, e.g., mydb - ip_type (str): - The IP type of the Cloud SQL instance to connect to. Can be one - of "public", "private", or "psc". - refresh_strategy (Optional[str]): - Refresh strategy for the Cloud SQL Connector. Can be one of "lazy" - or "background". For serverless environments use "lazy" to avoid - errors resulting from CPU being throttled. - resolver (Optional[google.cloud.sql.connector.DefaultResolver]): - Resolver class for resolving instance connection name. Use - google.cloud.sql.connector.DnsResolver when resolving DNS domain - names or google.cloud.sql.connector.DefaultResolver for regular - instance connection names ("my-project:my-region:my-instance"). - """ - connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver) - unix_socket_folder = "/tmp/conn" - unix_socket_path = f"{unix_socket_folder}/.s.PGSQL.3307" - await connector.start_unix_socket_proxy_async( - instance_connection_name, - unix_socket_path, - ip_type=ip_type, # can be "public", "private" or "psc" - ) - - # create SQLAlchemy connection pool - engine = sqlalchemy.create_engine( - "postgresql+psycopg://", - creator=lambda: Connection.connect( - f"host={unix_socket_folder} port={SERVER_PROXY_PORT} dbname={db} user={user} password={password} sslmode=require", - user=user, - password=password, - dbname=db, - autocommit=True, - ) - ) - - return engine, connector - +logger = logging.getLogger(name=__name__) # [END cloud_sql_connector_postgres_psycopg] +@pytest.mark.asyncio async def test_psycopg_connection() -> None: """Basic test to get time from database.""" - inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"] + instance_connection_name = os.environ["POSTGRES_CONNECTION_NAME"] user = os.environ["POSTGRES_USER"] password = os.environ["POSTGRES_PASS"] db = os.environ["POSTGRES_DB"] ip_type = os.environ.get("IP_TYPE", "public") - engine, connector = await create_sqlalchemy_engine( - inst_conn_name, user, password, db, ip_type - ) - with engine.connect() as conn: - time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() - conn.commit() - curr_time = time[0] - assert type(curr_time) is datetime - connector.close() + unix_socket_folder = "/tmp/conn" + unix_socket_path = f"{unix_socket_folder}/.s.PGSQL.3307" + + async with Connector( + refresh_strategy='lazy', resolver=DefaultResolver + ) as connector: + # Open proxy connection + # start the proxy server + + await connector.start_unix_socket_proxy_async( + instance_connection_name, + unix_socket_path, + driver="psycopg", + user=user, + password=password, + db=db, + ip_type=ip_type, # can be "public", "private" or "psc" + ) + + # Wait for server to start + await asyncio.sleep(0.5) + + engine = sqlalchemy.create_engine( + "postgresql+psycopg://", + creator=lambda: Connection.connect( + f"host={unix_socket_folder} port={SERVER_PROXY_PORT} dbname={db} user={user} password={password} sslmode=require", + user=user, + password=password, + dbname=db, + autocommit=True, + ) + ) + + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + assert type(curr_time) is datetime + connector.close() From 75f4ca9d32bded350e977ada4c7e9d3a21b03613 Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Fri, 17 Oct 2025 16:51:22 -0600 Subject: [PATCH 14/18] feat: Pass loop to aiohttp client --- google/cloud/sql/connector/client.py | 4 ++-- google/cloud/sql/connector/connector.py | 2 ++ tests/system/test_psycopg_connection.py | 4 ---- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index 11508ce17..ef748eb19 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -58,6 +58,7 @@ def __init__( client: Optional[aiohttp.ClientSession] = None, driver: Optional[str] = None, user_agent: Optional[str] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: """Establishes the client to be used for Cloud SQL Admin API requests. @@ -84,8 +85,7 @@ def __init__( } if quota_project: headers["x-goog-user-project"] = quota_project - - self._client = client if client else aiohttp.ClientSession(headers=headers) + self._client = client if client else aiohttp.ClientSession(headers=headers, loop=loop) self._credentials = credentials if sqladmin_api_endpoint is None: self._sqladmin_api_endpoint = DEFAULT_SERVICE_ENDPOINT diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 232cb51cf..0f10e35f1 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -371,6 +371,7 @@ def _init_client(self, driver: Optional[str]) -> CloudSQLClient: self._credentials, user_agent=self._user_agent, driver=driver, + loop=self._loop ) return self._client @@ -432,6 +433,7 @@ async def connect_async( self._credentials, user_agent=self._user_agent, driver=driver, + loop=self._loop ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py index ae744dacf..25a46f3b8 100644 --- a/tests/system/test_psycopg_connection.py +++ b/tests/system/test_psycopg_connection.py @@ -58,10 +58,6 @@ async def test_psycopg_connection() -> None: await connector.start_unix_socket_proxy_async( instance_connection_name, unix_socket_path, - driver="psycopg", - user=user, - password=password, - db=db, ip_type=ip_type, # can be "public", "private" or "psc" ) From de7ee14061687109717c72dc0987c72d83c9e0ce Mon Sep 17 00:00:00 2001 From: ThameezBo Date: Mon, 23 Feb 2026 23:04:25 +0100 Subject: [PATCH 15/18] fix: post merge conflict changes and linting --- google/cloud/sql/connector/connector.py | 83 +++++++++---------- .../cloud/sql/connector/local_unix_socket.py | 3 +- google/cloud/sql/connector/proxy.py | 2 +- tests/system/test_psycopg_connection.py | 5 +- tests/unit/test_connector.py | 37 +-------- tests/unit/test_local_unix_socket.py | 2 - tests/unit/test_proxy.py | 11 +-- 7 files changed, 48 insertions(+), 95 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 0f10e35f1..3d6570bbc 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -503,59 +503,52 @@ async def connect_async( monitored_cache = await self._get_cache( instance_connection_string, enable_iam_auth, ip_type, driver ) - conn_info = await monitored_cache.connect_info() - ip_address = conn_info.get_preferred_ip(ip_type) try: conn_info = await monitored_cache.connect_info() # validate driver matches intended database engine DriverMapping.validate_engine(driver, conn_info.database_version) ip_address = conn_info.get_preferred_ip(ip_type) - except Exception: - # with an error from Cloud SQL Admin API call or IP type, invalidate - # the cache and re-raise the error - await self._remove_cached(str(conn_name), enable_iam_auth) - raise - # If the connector is configured with a custom DNS name, attempt to use - # that DNS name to connect to the instance. Fall back to the metadata IP - # address if the DNS name does not resolve to an IP address. - if conn_info.conn_name.domain_name and isinstance(self._resolver, DnsResolver): - try: - ips = await self._resolver.resolve_a_record(conn_info.conn_name.domain_name) - if ips: - ip_address = ips[0] + # If the connector is configured with a custom DNS name, attempt to use + # that DNS name to connect to the instance. Fall back to the metadata IP + # address if the DNS name does not resolve to an IP address. + if conn_info.conn_name.domain_name and isinstance(self._resolver, DnsResolver): + try: + ips = await self._resolver.resolve_a_record(conn_info.conn_name.domain_name) + if ips: + ip_address = ips[0] + logger.debug( + f"['{instance_connection_string}']: Custom DNS name " + f"'{conn_info.conn_name.domain_name}' resolved to '{ip_address}', " + "using it to connect" + ) + else: + logger.debug( + f"['{instance_connection_string}']: Custom DNS name " + f"'{conn_info.conn_name.domain_name}' resolved but returned no " + f"entries, using '{ip_address}' from instance metadata" + ) + except Exception as e: logger.debug( f"['{instance_connection_string}']: Custom DNS name " - f"'{conn_info.conn_name.domain_name}' resolved to '{ip_address}', " - "using it to connect" + f"'{conn_info.conn_name.domain_name}' did not resolve to an IP " + f"address: {e}, using '{ip_address}' from instance metadata" ) - else: - logger.debug( - f"['{instance_connection_string}']: Custom DNS name " - f"'{conn_info.conn_name.domain_name}' resolved but returned no " - f"entries, using '{ip_address}' from instance metadata" - ) - except Exception as e: - logger.debug( - f"['{instance_connection_string}']: Custom DNS name " - f"'{conn_info.conn_name.domain_name}' did not resolve to an IP " - f"address: {e}, using '{ip_address}' from instance metadata" - ) - logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") - # format `user` param for automatic IAM database authn - if enable_iam_auth: - formatted_user = format_database_user( - conn_info.database_version, kwargs["user"] - ) - if formatted_user != kwargs["user"]: - logger.debug( - f"['{instance_connection_string}']: " - "Truncated IAM database username from " - f"{kwargs['user']} to {formatted_user}" + logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") + # format `user` param for automatic IAM database authn + if enable_iam_auth: + formatted_user = format_database_user( + conn_info.database_version, kwargs["user"] ) - kwargs["user"] = formatted_user + if formatted_user != kwargs["user"]: + logger.debug( + f"['{instance_connection_string}']: " + "Truncated IAM database username from " + f"{kwargs['user']} to {formatted_user}" + ) + kwargs["user"] = formatted_user ctx = await conn_info.create_ssl_context(enable_iam_auth) # async drivers are unblocking and can be awaited directly @@ -576,11 +569,9 @@ async def connect_async( return await self._loop.run_in_executor(None, connect_partial) except Exception: - # with any exception, we attempt a force refresh, then throw the error - monitored_cache = await self._get_cache( - instance_connection_string, enable_iam_auth, ip_type, driver - ) - await monitored_cache.force_refresh() + # with an error from Cloud SQL Admin API call or connection, invalidate + # the cache and re-raise the error + await self._remove_cached(str(conn_name), enable_iam_auth) raise async def start_unix_socket_proxy_async( diff --git a/google/cloud/sql/connector/local_unix_socket.py b/google/cloud/sql/connector/local_unix_socket.py index 25a9d3be3..31cbff341 100644 --- a/google/cloud/sql/connector/local_unix_socket.py +++ b/google/cloud/sql/connector/local_unix_socket.py @@ -15,7 +15,8 @@ """ import ssl -from typing import Any, TYPE_CHECKING +from typing import Any + def connect( host: str, sock: ssl.SSLSocket, **kwargs: Any diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index 99a121782..48a67ef04 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -224,7 +224,7 @@ def _handle_server_connection_lost( :return: None """ - logger.debug(f"Closing proxy server due to lost connection") + logger.debug("Closing proxy server due to lost connection") self._loop.create_task(self.close()) async def _create_db_instance_connection(self, conn: ProxyClientConnection) -> None: diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py index 25a46f3b8..00e4184b6 100644 --- a/tests/system/test_psycopg_connection.py +++ b/tests/system/test_psycopg_connection.py @@ -16,19 +16,16 @@ import asyncio from datetime import datetime +import logging import os # [START cloud_sql_connector_postgres_psycopg] -from typing import Union - from psycopg import Connection import pytest -import logging import sqlalchemy from google.cloud.sql.connector import Connector from google.cloud.sql.connector import DefaultResolver -from google.cloud.sql.connector import DnsResolver SERVER_PROXY_PORT = 3307 diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index f36464b71..f3fbd808a 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -14,8 +14,8 @@ import asyncio import logging import os -from threading import Thread import socket +from threading import Thread from typing import Union from aiohttp import ClientResponseError @@ -34,7 +34,6 @@ from google.cloud.sql.connector.exceptions import IncompatibleDriverError from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.resolver import DnsResolver -from google.cloud.sql.connector.proxy import start_local_proxy logger = logging.getLogger(name=__name__) @@ -679,40 +678,6 @@ async def test_Connector_connect_async_custom_dns_resolver_fallback( # Restore original IPs fake_client.instance.ip_addrs = original_ips - -async def test_Connector_start_unix_socket_proxy_async( - fake_credentials: Credentials, - fake_client: CloudSQLClient, - proxy_server_async: None, -) -> None: - """Test that Connector.connect_async can properly return a DB API connection.""" - async with Connector( - credentials=fake_credentials, loop=asyncio.get_running_loop() - ) as connector: - connector._client = fake_client - # Open proxy connection - # start the proxy server - await connector.start_unix_socket_proxy_async( - "test-project:test-region:test-instance", - "/tmp/csql-python/proxytest/.s.PGSQL.5432", - driver="asyncpg", - user="my-user", - password="my-pass", - db="my-db", - ) - # Wait for server to start - await asyncio.sleep(0.5) - - reader, writer = await asyncio.open_unix_connection( - "/tmp/csql-python/proxytest/.s.PGSQL.5432" - ) - writer.write("hello\n".encode()) - await writer.drain() - await asyncio.sleep(0.5) - msg = await reader.readline() - assert msg.decode("utf-8") == "world\n" - - class TestProtocol(asyncio.Protocol): """ A protocol to proxy data between two transports. diff --git a/tests/unit/test_local_unix_socket.py b/tests/unit/test_local_unix_socket.py index 8672857ec..139592ebc 100644 --- a/tests/unit/test_local_unix_socket.py +++ b/tests/unit/test_local_unix_socket.py @@ -18,8 +18,6 @@ import ssl from typing import Any -from mock import patch -from mock import PropertyMock import pytest from google.cloud.sql.connector.local_unix_socket import connect diff --git a/tests/unit/test_proxy.py b/tests/unit/test_proxy.py index 0c179186d..7441f387e 100644 --- a/tests/unit/test_proxy.py +++ b/tests/unit/test_proxy.py @@ -22,7 +22,8 @@ import pytest -from google.cloud.sql.connector.proxy import Proxy, ServerConnectionFactory +from google.cloud.sql.connector.proxy import Proxy +from google.cloud.sql.connector.proxy import ServerConnectionFactory @pytest.fixture @@ -156,7 +157,7 @@ async def test_proxy_server_connect_fails(proxy_server): # wait for server connection to be attempted await connector.connect_called.wait() - assert os.path.exists(socket_path) == True + assert os.path.exists(socket_path) # The client connection should be closed by the proxy # Reading should return EOF @@ -165,7 +166,7 @@ async def test_proxy_server_connect_fails(proxy_server): await asyncio.sleep(1) # give proxy a chance to shut down - assert os.path.exists(socket_path) == False + assert os.path.exists(socket_path) @pytest.mark.asyncio @@ -336,7 +337,7 @@ async def test_tcp_proxy_server_connection_refused(tcp_proxy_server_with_no_tcp_ await writer.drain() await asyncio.sleep(1.5) - assert os.path.exists(socket_path) == False + assert os.path.exists(socket_path) @@ -355,7 +356,7 @@ async def test_tcp_proxy_server_unexpected_closed(tcp_proxy_server_with_closing_ assert data == b"" await asyncio.sleep(0.5) # give event loop a chance to run - assert os.path.exists(socket_path) == False + assert os.path.exists(socket_path) From 269a4fe506c423f9002eb9d5e06b89a13b574956 Mon Sep 17 00:00:00 2001 From: ThameezBo Date: Mon, 23 Feb 2026 23:08:59 +0100 Subject: [PATCH 16/18] fix: unit tests --- tests/unit/test_proxy.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_proxy.py b/tests/unit/test_proxy.py index 7441f387e..a8f6f3dc3 100644 --- a/tests/unit/test_proxy.py +++ b/tests/unit/test_proxy.py @@ -166,7 +166,7 @@ async def test_proxy_server_connect_fails(proxy_server): await asyncio.sleep(1) # give proxy a chance to shut down - assert os.path.exists(socket_path) + assert not os.path.exists(socket_path) @pytest.mark.asyncio @@ -337,7 +337,7 @@ async def test_tcp_proxy_server_connection_refused(tcp_proxy_server_with_no_tcp_ await writer.drain() await asyncio.sleep(1.5) - assert os.path.exists(socket_path) + assert not os.path.exists(socket_path) @@ -356,7 +356,7 @@ async def test_tcp_proxy_server_unexpected_closed(tcp_proxy_server_with_closing_ assert data == b"" await asyncio.sleep(0.5) # give event loop a chance to run - assert os.path.exists(socket_path) + assert not os.path.exists(socket_path) From cca6973b6e366242c0ee715209444d458681a0ee Mon Sep 17 00:00:00 2001 From: ThameezBo Date: Mon, 23 Feb 2026 23:12:09 +0100 Subject: [PATCH 17/18] fix: update psycopg binary --- requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-test.txt b/requirements-test.txt index 25311db58..7ef4dd09c 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -7,7 +7,7 @@ sqlalchemy-pytds==1.0.2 sqlalchemy-stubs==0.4 PyMySQL==1.1.2 pg8000==1.31.5 -psycopg[binary]==3.2.9 +psycopg[binary]==3.3.3 asyncpg==0.31.0 python-tds==1.17.1 aioresponses==0.7.8 From 91ba7083e23c633b296c002ffaeb5725cda1d74f Mon Sep 17 00:00:00 2001 From: ThameezBo Date: Mon, 23 Feb 2026 23:19:23 +0100 Subject: [PATCH 18/18] fix: lic header --- .ci/cloudbuild.yaml | 2 +- .github/trusted-contribution.yml | 2 +- .github/workflows/cloud_build_failure_reporter.yml | 2 +- .github/workflows/schedule_reporter.yml | 2 +- build.sh | 2 +- google/cloud/sql/connector/local_unix_socket.py | 2 +- google/cloud/sql/connector/monitored_cache.py | 2 +- google/cloud/sql/connector/proxy.py | 2 +- pyproject.toml | 2 +- samples/cloudrun/mysql/main.py | 2 +- samples/cloudrun/postgres/main.py | 2 +- samples/cloudrun/sqlserver/main.py | 2 +- tests/system/test_psycopg_connection.py | 2 +- tests/unit/test_local_unix_socket.py | 2 +- tests/unit/test_monitored_cache.py | 2 +- tests/unit/test_proxy.py | 2 +- 16 files changed, 16 insertions(+), 16 deletions(-) diff --git a/.ci/cloudbuild.yaml b/.ci/cloudbuild.yaml index e31e8bebb..57a845b75 100644 --- a/.ci/cloudbuild.yaml +++ b/.ci/cloudbuild.yaml @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/.github/trusted-contribution.yml b/.github/trusted-contribution.yml index 18580d069..bd6101b09 100644 --- a/.github/trusted-contribution.yml +++ b/.github/trusted-contribution.yml @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/.github/workflows/cloud_build_failure_reporter.yml b/.github/workflows/cloud_build_failure_reporter.yml index 493ddecd2..3d155d074 100644 --- a/.github/workflows/cloud_build_failure_reporter.yml +++ b/.github/workflows/cloud_build_failure_reporter.yml @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/.github/workflows/schedule_reporter.yml b/.github/workflows/schedule_reporter.yml index bad7e46c2..44899bbbd 100644 --- a/.github/workflows/schedule_reporter.yml +++ b/.github/workflows/schedule_reporter.yml @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/build.sh b/build.sh index d3800a299..4342dee4a 100755 --- a/build.sh +++ b/build.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -# Copyright 2025 Google LLC. +# Copyright 2026 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/sql/connector/local_unix_socket.py b/google/cloud/sql/connector/local_unix_socket.py index 31cbff341..a439f2757 100644 --- a/google/cloud/sql/connector/local_unix_socket.py +++ b/google/cloud/sql/connector/local_unix_socket.py @@ -1,5 +1,5 @@ """ -Copyright 2025 Google LLC +Copyright 2026 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/google/cloud/sql/connector/monitored_cache.py b/google/cloud/sql/connector/monitored_cache.py index 0c3fc4d03..fb11a0730 100644 --- a/google/cloud/sql/connector/monitored_cache.py +++ b/google/cloud/sql/connector/monitored_cache.py @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index 48a67ef04..25686988b 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -1,5 +1,5 @@ """ -Copyright 2025 Google LLC +Copyright 2026 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/pyproject.toml b/pyproject.toml index 8ffce4d63..c2735b2da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/samples/cloudrun/mysql/main.py b/samples/cloudrun/mysql/main.py index b1b546682..cfbcee23c 100644 --- a/samples/cloudrun/mysql/main.py +++ b/samples/cloudrun/mysql/main.py @@ -1,5 +1,5 @@ """ -Copyright 2025 Google LLC +Copyright 2026 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/samples/cloudrun/postgres/main.py b/samples/cloudrun/postgres/main.py index e33d5a06c..9c9707f28 100644 --- a/samples/cloudrun/postgres/main.py +++ b/samples/cloudrun/postgres/main.py @@ -1,5 +1,5 @@ """ -Copyright 2025 Google LLC +Copyright 2026 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/samples/cloudrun/sqlserver/main.py b/samples/cloudrun/sqlserver/main.py index 0ce8162cf..7bd8f3b32 100644 --- a/samples/cloudrun/sqlserver/main.py +++ b/samples/cloudrun/sqlserver/main.py @@ -1,5 +1,5 @@ """ -Copyright 2025 Google LLC +Copyright 2026 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py index 00e4184b6..7b248f09a 100644 --- a/tests/system/test_psycopg_connection.py +++ b/tests/system/test_psycopg_connection.py @@ -1,5 +1,5 @@ """ -Copyright 2025 Google LLC +Copyright 2026 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tests/unit/test_local_unix_socket.py b/tests/unit/test_local_unix_socket.py index 139592ebc..75eccd029 100644 --- a/tests/unit/test_local_unix_socket.py +++ b/tests/unit/test_local_unix_socket.py @@ -1,5 +1,5 @@ """ -Copyright 2025 Google LLC +Copyright 2026 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tests/unit/test_monitored_cache.py b/tests/unit/test_monitored_cache.py index 1c1f1df86..61f651f90 100644 --- a/tests/unit/test_monitored_cache.py +++ b/tests/unit/test_monitored_cache.py @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/unit/test_proxy.py b/tests/unit/test_proxy.py index a8f6f3dc3..654d94b48 100644 --- a/tests/unit/test_proxy.py +++ b/tests/unit/test_proxy.py @@ -1,5 +1,5 @@ """ -Copyright 2025 Google LLC +Copyright 2026 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.