-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathupdate_telephony_vm_sg.py
More file actions
259 lines (222 loc) · 9.04 KB
/
update_telephony_vm_sg.py
File metadata and controls
259 lines (222 loc) · 9.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
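"""
update_telephony_vm_sg.py
Updates security group ingress rules when AWS IP ranges change.
Triggered by the SNS topic AmazonIpSpaceChanged
(arn:aws:sns:us-east-1:806199016981:AmazonIpSpaceChanged).

The current ranges are downloaded over HTTPS from
https://ip-ranges.amazonaws.com/ip-ranges.json, so the function needs outbound
internet access but no S3 permissions. The `requests` package is not part of
the AWS Lambda Python runtime and must be bundled with the deployment package
or provided by a layer.

IAM policy (also attach AWSLambdaBasicExecutionRole, or equivalent CloudWatch
Logs permissions, so the function's log output is recorded):
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Sid": "UpdateSecurityGroups",
            "Effect": "Allow",
            "Action": [
                "ec2:DescribeSecurityGroups",
                "ec2:AuthorizeSecurityGroupIngress",
                "ec2:RevokeSecurityGroupIngress"
            ],
            "Resource": "*"
        }
    ]
}

Test event (SNS message in the AmazonIpSpaceChanged format; the handler only
logs the message body, so any SNS record triggers a full re-sync):
{
    "Records": [
        {
            "Sns": {
                "Message": "{\"create-time\":\"2024-01-01T00:00:00+00:00\",\"synctoken\":\"1704067200\",\"md5\":\"6a45316e8bc9463c9e926d5d37836d33\",\"url\":\"https://ip-ranges.amazonaws.com/ip-ranges.json\"}"
            }
        }
    ]
}
"""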
import boto3
import json
import logging
from typing import Any
import requests
from botocore.exceptions import ClientError
# This script logs through the root logger and raises it to INFO.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Security groups to reconcile. Despite the name, entries may be bare IDs
# ("sg-...") or full ARNs: the handler keeps only the text after the last "/".
SECURITY_GROUP_ARNS = [
    # TODO: replace with actual security group ID
    "sg-0a1b2c3d4e5f6g7h8",  # security group for VM `client0-telephony-t2.medium-1`
]

# (service, region) pairs, matched against the "service"/"region" fields of
# ip-ranges.json, whose IPv4 prefixes are allowed in on every PORT_CONFIGS rule.
# NOTE(review): in ip-ranges.json "AMAZON" appears to be a superset that also
# lists EC2 / CHIME_VOICECONNECTOR prefixes, and in us-east-1 it likely holds
# far more prefixes than a security group's inbound-rule quota permits —
# confirm the authorize call can succeed with these entries.
SERVICES_TO_UPDATE = [
    dict(service="CHIME_VOICECONNECTOR", region="us-east-1"),
    dict(service="AMAZON", region="us-east-1"),
    dict(service="EC2", region="us-east-1"),
]

# Ingress rules (inclusive port range + protocol) that receive those prefixes.
PORT_CONFIGS = [
    dict(from_port=5061, to_port=5061, protocol="tcp", description="SIP (TCP)"),
    dict(from_port=10000, to_port=10299, protocol="udp", description="RTP UDP"),
]
# Shared EC2 client, created once at import time and used by
# update_security_group_ingress(). No region_name is passed, so boto3's default
# region resolution applies; the managed security groups must live in that region.
ec2_client = boto3.client("ec2")
def get_aws_ip_ranges() -> dict[str, Any]:
    """
    Download and decode the AWS-published IP range document.

    The document is requested from https://ip-ranges.amazonaws.com/ip-ranges.json
    with a 30-second timeout; a 4xx/5xx response is turned into an exception.
    Returns:
        dict[str, Any]: The decoded JSON body.
    Raises:
        Exception: Whatever the request, the status check, or JSON decoding
            raised; it is logged first and then re-raised unchanged.
    """
    source_url = "https://ip-ranges.amazonaws.com/ip-ranges.json"
    try:
        logger.debug("Fetching AWS IP ranges from official endpoint")
        reply = requests.get(url=source_url, timeout=30)
        reply.raise_for_status()
        document = reply.json()
    except Exception as exc:
        logger.error("Failed to fetch IP ranges: %s", str(exc))
        raise
    return document
def get_service_ip_ranges(ip_ranges: dict[str, Any], service: str, region: str) -> list[str]:
    """
    Select the IPv4 prefixes published for one service in one region.

    Entries are read from the "prefixes" list of ``ip_ranges`` (a missing list
    yields no matches). Input order is preserved and duplicates are kept.
    Args:
        ip_ranges (dict[str, Any]): Decoded ip-ranges.json document.
        service (str): Value to match against each entry's "service" field.
        region (str): Value to match against each entry's "region" field.
    Returns:
        list[str]: The "ip_prefix" of every matching entry.
    """
    prefixes = ip_ranges.get("prefixes", [])
    return [
        entry["ip_prefix"]
        for entry in prefixes
        if entry.get("service") == service and entry.get("region") == region
    ]
def update_security_group_ingress(sg_id: str, ip_ranges: list[str], from_port: int, to_port: int, protocol: str) -> None:
    """
    Reconcile the IPv4 ingress CIDRs of one port/protocol rule on a security group.

    Every CIDR currently allowed for exactly this (from_port, to_port, protocol)
    that is not in ``ip_ranges`` is revoked; every CIDR in ``ip_ranges`` that is
    not yet allowed is authorized. Only ``IpRanges``/``CidrIp`` sources are read
    and changed; other ports, protocols and source kinds are left alone. Note
    that an empty ``ip_ranges`` revokes every CIDR on this port/protocol.
    Args:
        sg_id (str): Security group ID.
        ip_ranges (list[str]): Desired IPv4 CIDRs; duplicates are ignored.
        from_port (int): First port of the rule.
        to_port (int): Last port of the rule.
        protocol (str): IP protocol as EC2 reports it (e.g. "tcp", "udp").
    Raises:
        ClientError: Logged and re-raised when an EC2 call fails.
    """
    def as_permissions(cidrs):
        # A single IpPermission entry covering this port/protocol for the given CIDRs.
        return [{
            "IpProtocol": protocol,
            "FromPort": from_port,
            "ToPort": to_port,
            "IpRanges": [{"CidrIp": cidr} for cidr in cidrs],
        }]

    try:
        logger.info("Getting current security group rules for %s", sg_id)
        described = ec2_client.describe_security_groups(GroupIds=[sg_id])
        group = described["SecurityGroups"][0]

        # CIDRs already allowed on a rule with exactly this port range and protocol.
        wanted_key = (from_port, to_port, protocol)
        allowed_now = {
            source["CidrIp"]
            for permission in group.get("IpPermissions", [])
            if (permission.get("FromPort"), permission.get("ToPort"), permission.get("IpProtocol")) == wanted_key
            for source in permission.get("IpRanges", [])
        }
        desired = set(ip_ranges)
        missing = desired - allowed_now
        stale = allowed_now - desired
        logger.debug(
            "Security group %s: %d ranges to add, %d ranges to remove",
            sg_id, len(missing), len(stale),
        )

        if stale:
            logger.info("Removing %d outdated IP ranges from %s", len(stale), sg_id)
            ec2_client.revoke_security_group_ingress(GroupId=sg_id, IpPermissions=as_permissions(stale))

        if missing:
            logger.info("Adding %d new IP ranges to %s", len(missing), sg_id)
            ec2_client.authorize_security_group_ingress(GroupId=sg_id, IpPermissions=as_permissions(missing))

        logger.info("Security group %s updated successfully", sg_id)
    except ClientError as err:
        logger.error("Failed to update security group %s: %s", sg_id, str(err))
        raise
def _collect_target_ranges(ip_ranges: dict[str, Any]) -> list[str]:
    """
    Build the union of IPv4 prefixes for every entry in SERVICES_TO_UPDATE.

    Args:
        ip_ranges (dict[str, Any]): Decoded ip-ranges.json document.
    Returns:
        list[str]: Sorted, de-duplicated CIDRs; empty if no service matched.
    """
    combined: set[str] = set()
    for service_config in SERVICES_TO_UPDATE:
        service = service_config["service"]
        region = service_config["region"]
        logger.debug("Processing service %s in region %s", service, region)
        service_ip_ranges = get_service_ip_ranges(ip_ranges, service, region)
        logger.info("Found %d IP ranges for %s in %s", len(service_ip_ranges), service, region)
        if not service_ip_ranges:
            logger.warning("No IP ranges found for %s in %s", service, region)
        combined.update(service_ip_ranges)
    return sorted(combined)


def _sync_security_groups(target_ranges: list[str]) -> list[str]:
    """
    Reconcile every PORT_CONFIGS rule on every SECURITY_GROUP_ARNS entry with target_ranges.

    Failures are logged and collected rather than raised, so one bad group or
    port rule does not stop the remaining updates.
    Args:
        target_ranges (list[str]): Complete desired CIDR set for each port rule.
    Returns:
        list[str]: One description per failed update; empty when all succeeded.
    """
    failures: list[str] = []
    for sg_arn in SECURITY_GROUP_ARNS:
        logger.debug("Processing security group ARN: %s", sg_arn)
        # Accepts a bare "sg-..." ID as well as an ARN ending in ".../sg-...".
        sg_id = sg_arn.split("/")[-1]
        group_failures = 0
        for port_config in PORT_CONFIGS:
            description = port_config["description"]
            try:
                update_security_group_ingress(
                    sg_id=sg_id,
                    ip_ranges=target_ranges,
                    from_port=port_config["from_port"],
                    to_port=port_config["to_port"],
                    protocol=port_config["protocol"],
                )
            except Exception as port_error:
                # Best effort: record the failure and continue with the other rules.
                group_failures += 1
                logger.error("Failed to update %s on %s: %s", sg_id, description, str(port_error))
                failures.append(f"{sg_id} {description}: {port_error}")
                continue
            logger.info("Successfully updated %s on %s", sg_id, description)
        if group_failures:
            logger.error(
                "Security group %s: %d of %d port updates failed",
                sg_id, group_failures, len(PORT_CONFIGS),
            )
        else:
            logger.info("Successfully updated security group %s", sg_id)
    return failures


def lambda_handler(event, context):
    """
    Re-sync the managed security groups with the current AWS IP ranges.

    Expects an SNS-style event with a "Records" list. The SNS message body is
    only logged; if at least one SNS record is present, the IP ranges are
    downloaded once and every security group / port rule is reconciled once.

    The desired CIDR set for each rule is the UNION of all SERVICES_TO_UPDATE
    prefixes. update_security_group_ingress() revokes every CIDR on the rule
    that is not in the list it receives, so reconciling one service at a time
    would make each pass remove the previous service's ranges.
    Args:
        event (dict): Lambda event payload.
        context: Lambda context object (unused).
    Returns:
        dict: statusCode 200 when every update succeeded (or nothing needed
        doing), 400 when the event has no "Records", and 500 when any update
        failed or an unexpected error occurred.
    """
    logger.info("Received event: %s", json.dumps(event))
    try:
        if "Records" not in event:
            logger.error("No SNS records found in event")
            return {"statusCode": 400, "body": "No SNS records found"}

        sns_seen = False
        for record in event["Records"]:
            if "Sns" not in record:
                logger.warning("Record is not an SNS message, skipping")
                continue
            logger.info("Processing SNS message: %s", record["Sns"]["Message"])
            sns_seen = True

        failures: list[str] = []
        if sns_seen:
            # One sync covers every notification in the event: the work depends
            # only on the freshly downloaded ranges, not on the message contents.
            ip_ranges = get_aws_ip_ranges()
            logger.info("Fetched IP ranges for %d prefixes", len(ip_ranges.get("prefixes", [])))
            target_ranges = _collect_target_ranges(ip_ranges)
            if target_ranges:
                failures = _sync_security_groups(target_ranges)
            else:
                # Reconciling against an empty set would revoke every managed CIDR.
                logger.warning("No IP ranges found for any configured service; security groups left unchanged")

        if failures:
            logger.error("Security group update finished with %d failure(s)", len(failures))
            return {
                "statusCode": 500,
                "body": json.dumps({"error": "Some security group updates failed", "failures": failures}),
            }

        logger.info("Security group update process completed")
        return {
            "statusCode": 200,
            "body": json.dumps({"message": "Security groups updated successfully"}),
        }
    except Exception as e:
        logger.exception("Unhandled error: %s", str(e))
        return {
            "statusCode": 500,
            "body": json.dumps({"error": str(e)}),
        }