-
Notifications
You must be signed in to change notification settings - Fork 88
Expand file tree
/
Copy pathtest_resolver.py
More file actions
168 lines (142 loc) · 5.28 KB
/
test_resolver.py
File metadata and controls
168 lines (142 loc) · 5.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dns.message
import dns.rdataclass
import dns.rdatatype
import dns.resolver
from mock import patch
import pytest
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import DnsResolutionError
from google.cloud.sql.connector.resolver import DefaultResolver
from google.cloud.sql.connector.resolver import DnsResolver
# Instance connection string passed to the resolvers under test.
conn_str = "my-project:my-region:my-instance"
# Expected ConnectionName when `conn_str` is resolved.
conn_name = ConnectionName("my-project", "my-region", "my-instance")
# Expected ConnectionName when the DNS name "db.example.com" is resolved
# (compared against in test_DnsResolver_with_dns_name).
conn_name_with_domain = ConnectionName(
    "my-project", "my-region", "my-instance", "db.example.com"
)
async def test_DefaultResolver() -> None:
    """DefaultResolver should only parse the instance connection string."""
    parsed = await DefaultResolver().resolve(conn_str)
    assert parsed == conn_name
async def test_DnsResolver_with_conn_str() -> None:
    """DnsResolver given an instance connection name should only parse that string."""
    dns_resolver = DnsResolver()
    parsed = await dns_resolver.resolve(conn_str)
    assert parsed == conn_name
# Text form of a DNS response to a TXT query for db.example.com carrying two
# TXT answers; fed to dns.message.from_text by test_DnsResolver_with_dns_name.
query_text = (
    "id 1234\n"
    "opcode QUERY\n"
    "rcode NOERROR\n"
    "flags QR AA RD RA\n"
    ";QUESTION\n"
    "db.example.com. IN TXT\n"
    ";ANSWER\n"
    'db.example.com. 0 IN TXT "test-project:test-region:test-instance"\n'
    'db.example.com. 0 IN TXT "my-project:my-region:my-instance"\n'
    ";AUTHORITY\n"
    ";ADDITIONAL\n"
)
async def test_DnsResolver_with_dns_name() -> None:
    """Test DnsResolver turns a DNS name's TXT records into an instance connection name.

    Valid TXT records should be sorted alphabetically and the first one used.
    """
    dns_name = "db.example.com"
    # Canned DNS answer built from the valid TXT records in `query_text`
    txt_answer = dns.resolver.Answer(
        dns_name,
        dns.rdatatype.TXT,
        dns.rdataclass.IN,
        dns.message.from_text(query_text),
    )
    # Replace the underlying DNS lookup so it returns the canned answer
    with patch("dns.asyncresolver.Resolver.resolve", return_value=txt_answer):
        dns_resolver = DnsResolver()
        dns_resolver.port = 5053
        # The alphabetically first TXT value is the expected result
        resolved = await dns_resolver.resolve(dns_name)
        assert resolved == conn_name_with_domain
# Text form of a DNS response to a TXT query for bad.example.com whose single
# TXT answer is not a valid instance connection name; fed to
# dns.message.from_text by test_DnsResolver_with_malformed_txt.
query_text_malformed = (
    "id 1234\n"
    "opcode QUERY\n"
    "rcode NOERROR\n"
    "flags QR AA RD RA\n"
    ";QUESTION\n"
    "bad.example.com. IN TXT\n"
    ";ANSWER\n"
    'bad.example.com. 0 IN TXT "malformed-instance-name"\n'
    ";AUTHORITY\n"
    ";ADDITIONAL\n"
)
async def test_DnsResolver_with_malformed_txt() -> None:
    """Test DnsResolver when the TXT record holds a malformed instance connection name.

    Resolution should raise DnsResolutionError.
    """
    dns_name = "bad.example.com"
    expected_message = (
        "Unable to parse TXT record for `bad.example.com` -> `malformed-instance-name`"
    )
    # Canned DNS answer built from the malformed TXT record
    malformed_answer = dns.resolver.Answer(
        dns_name,
        dns.rdatatype.TXT,
        dns.rdataclass.IN,
        dns.message.from_text(query_text_malformed),
    )
    # Replace the underlying DNS lookup so it returns the canned answer
    with patch("dns.asyncresolver.Resolver.resolve", return_value=malformed_answer):
        dns_resolver = DnsResolver()
        dns_resolver.port = 5053
        with pytest.raises(DnsResolutionError) as raised:
            await dns_resolver.resolve(dns_name)
        # Checked after the raises block so the assertion actually runs
        assert raised.value.args[0] == expected_message
async def test_DnsResolver_with_bad_dns_name() -> None:
    """Test DnsResolver when given a bad DNS name.

    Resolution should raise DnsResolutionError.
    """
    dns_resolver = DnsResolver()
    dns_resolver.port = 5053
    # A 1 second lifetime keeps the timeout short
    dns_resolver.lifetime = 1
    with pytest.raises(DnsResolutionError) as raised:
        await dns_resolver.resolve("bad.dns.com")
    error_message = raised.value.args[0]
    assert error_message == "Unable to resolve TXT record for `bad.dns.com`"
# Text form of a DNS response to an A query for db.example.com with a single
# answer of 127.0.0.1; fed to dns.message.from_text by
# test_DnsResolver_resolve_a_record.
a_record_query_text = (
    "id 1234\n"
    "opcode QUERY\n"
    "rcode NOERROR\n"
    "flags QR AA RD RA\n"
    ";QUESTION\n"
    "db.example.com. IN A\n"
    ";ANSWER\n"
    "db.example.com. 0 IN A 127.0.0.1\n"
    ";AUTHORITY\n"
    ";ADDITIONAL\n"
)
async def test_DnsResolver_resolve_a_record() -> None:
    """Test DnsResolver.resolve_a_record turns an A record into a list of IP addresses."""
    dns_name = "db.example.com"
    # Canned DNS answer built from the A record in `a_record_query_text`
    a_answer = dns.resolver.Answer(
        dns_name,
        dns.rdatatype.A,
        dns.rdataclass.IN,
        dns.message.from_text(a_record_query_text),
    )
    # Replace the underlying DNS lookup so it returns the canned answer
    with patch("dns.asyncresolver.Resolver.resolve", return_value=a_answer):
        dns_resolver = DnsResolver()
        addresses = await dns_resolver.resolve_a_record(dns_name)
        assert addresses == ["127.0.0.1"]
async def test_DnsResolver_resolve_a_record_empty() -> None:
    """Test DnsResolver.resolve_a_record returns an empty list when the lookup raises."""
    lookup_failure = Exception("DNS Error")
    # Make the underlying DNS lookup fail
    with patch("dns.asyncresolver.Resolver.resolve", side_effect=lookup_failure):
        dns_resolver = DnsResolver()
        addresses = await dns_resolver.resolve_a_record("db.example.com")
        assert addresses == []