Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions functions/replace-route/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
LIFECYCLE_HOOK_NAME_KEY = "LifecycleHookName"
AUTO_SCALING_GROUP_NAME_KEY = "AutoScalingGroupName"
LIFECYCLE_ACTION_TOKEN_KEY = "LifecycleActionToken"
EC2_INSTANCE_ID_KEY = "EC2InstanceId"

# Checks every CONNECTIVITY_CHECK_INTERVAL seconds, exits after 1 minute
DEFAULT_CONNECTIVITY_CHECK_INTERVAL = "5"
Expand Down Expand Up @@ -209,15 +210,25 @@ def is_source_dest_check_enabled(instance_id):
logger.error(f"Error checking source/dest check: {e}")
return None


def find_default_routes(route_table_ids):
routes = [None for _ in route_table_ids]
response = ec2_client.describe_route_tables(RouteTableIds=route_table_ids)
for route_table in response["RouteTables"]:
for route in route_table["Routes"]:
if route.get("DestinationCidrBlock") == "0.0.0.0/0":
idx = route_table_ids.index(route_table["RouteTableId"])
routes[idx] = route
break
return routes


def are_any_routes_pointing_to_nat_gateway(route_table_ids):
ec2 = boto3.client('ec2')
try:
response = ec2.describe_route_tables(RouteTableIds=route_table_ids)
for rtb in response.get('RouteTables', []):
for route in rtb.get('Routes', []):
if route.get('DestinationCidrBlock') == "0.0.0.0/0" and 'NatGatewayId' in route and route.get('State') == 'active':
return True
return False
default_routes = find_default_routes(route_table_ids)
return any(
route is not None and "NatGatewayId" in route for route in default_routes
)
except Exception as e:
logger.error(f"Error checking NAT Gateway routes: {e}")
return False
Expand Down Expand Up @@ -414,10 +425,12 @@ def handler(event, _):
if (
LIFECYCLE_HOOK_NAME_KEY in message
and AUTO_SCALING_GROUP_NAME_KEY in message
and EC2_INSTANCE_ID_KEY in message
):
asg = message[AUTO_SCALING_GROUP_NAME_KEY]
lifecycle_hook_name = message[LIFECYCLE_HOOK_NAME_KEY]
lifecycle_action_token = message[LIFECYCLE_ACTION_TOKEN_KEY]
instance_id = message[EC2_INSTANCE_ID_KEY]
else:
logger.error("Failed to find lifecycle message to parse")
raise LifecycleMessageError
Expand All @@ -428,16 +441,21 @@ def handler(event, _):
availability_zone, vpc_zone_identifier = get_az_and_vpc_zone_identifier(asg)
public_subnet_id = vpc_zone_identifier.split(",")[0]
az = availability_zone.upper().replace("-", "_")
route_tables = az in os.environ and os.getenv(az).split(",")
if not route_tables:
route_table_ids = az in os.environ and os.getenv(az).split(",")
if not route_table_ids:
raise MissingEnvironmentVariableError
vpc_id = get_vpc_id(route_tables[0])

vpc_id = get_vpc_id(route_table_ids[0])

nat_gateway_id = get_nat_gateway_id(vpc_id, public_subnet_id)

for rtb in route_tables:
replace_route(rtb, nat_gateway_id)
logger.info("Route replacement succeeded")
default_routes = find_default_routes(route_table_ids)
for route_table_id, route in zip(route_table_ids, default_routes):
if route is not None and route.get("InstanceId") == instance_id:
replace_route(route_table_id, nat_gateway_id)
logger.info("Route replacement succeeded")
else:
logger.info("Skipping route replacement in table %s", route_table_id)

complete_asg_lifecycle_action(
asg, lifecycle_hook_name, lifecycle_action_token, "CONTINUE"
Expand Down
64 changes: 34 additions & 30 deletions functions/replace-route/tests/test_replace_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,17 @@ def setup_networking():
CidrBlock="10.1.1.0/24",
AvailabilityZone=f"{az}"
)

private_subnet = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock="10.1.2.0/24",
AvailabilityZone=f"{az}",
)

private_subnet_two = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock="10.1.3.0/24",
AvailabilityZone=f"{az}",
)


route_table = ec2.create_route_table(VpcId=vpc.id)
route_table_two = ec2.create_route_table(VpcId=vpc.id)
sg = ec2.create_security_group(GroupName="test-sg", Description="test-sg")
Expand All @@ -64,16 +61,33 @@ def setup_networking():
AllocationId=allocation_id
)["NatGateway"]["NatGatewayId"]

eni = ec2_client.create_network_interface(
SubnetId=public_subnet.id, PrivateIpAddress="10.1.1.5"
launch_template = ec2_client.create_launch_template(
LaunchTemplateName="test_launch_template",
LaunchTemplateData={"ImageId": EXAMPLE_AMI_ID, "InstanceType": "t2.micro"},
)["LaunchTemplate"]

autoscaling_client = boto3.client("autoscaling")
autoscaling_client.create_auto_scaling_group(
AutoScalingGroupName="alternat-asg",
VPCZoneIdentifier=public_subnet.id,
MinSize=1,
MaxSize=1,
LaunchTemplate={
"LaunchTemplateId": launch_template["LaunchTemplateId"],
"Version": str(launch_template["LatestVersionNumber"]),
},
)

reservations = ec2_client.describe_instances()["Reservations"]
instance_id = reservations[0]["Instances"][0]["InstanceId"]

ec2_client.associate_route_table(
RouteTableId=route_table.id,
SubnetId=private_subnet.id
)
ec2_client.create_route(
DestinationCidrBlock="0.0.0.0/0",
NetworkInterfaceId=eni["NetworkInterface"]["NetworkInterfaceId"],
InstanceId=instance_id,
RouteTableId=route_table.id
)
ec2_client.associate_route_table(
Expand All @@ -82,19 +96,16 @@ def setup_networking():
)
ec2_client.create_route(
DestinationCidrBlock="0.0.0.0/0",
NetworkInterfaceId=eni["NetworkInterface"]["NetworkInterfaceId"],
InstanceId=instance_id,
RouteTableId=route_table_two.id
)

return {
"vpc": vpc.id,
"public_subnet": public_subnet.id,
"private_subnet": private_subnet.id,
"private_subnet_two": private_subnet_two.id,
"nat_gw": nat_gw_id,
"route_table": route_table.id,
"route_table_two": route_table_two.id,
"sg": sg.id,
"instance": instance_id,
}


Expand Down Expand Up @@ -135,29 +146,12 @@ def verify_nat_instance_route(mocked_networking, instance_id):
@mock_aws
def test_handler(monkeypatch):
mocked_networking = setup_networking()
ec2_client = boto3.client("ec2")
template = ec2_client.create_launch_template(
LaunchTemplateName="test_launch_template",
LaunchTemplateData={"ImageId": EXAMPLE_AMI_ID, "InstanceType": "t2.micro"},
)["LaunchTemplate"]

autoscaling_client = boto3.client("autoscaling")
autoscaling_client.create_auto_scaling_group(
AutoScalingGroupName="alternat-asg",
VPCZoneIdentifier=mocked_networking["public_subnet"],
MinSize=1,
MaxSize=1,
LaunchTemplate={
"LaunchTemplateId": template["LaunchTemplateId"],
"Version": str(template["LatestVersionNumber"]),
},
)

from app import handler

script_dir = os.path.dirname(__file__)
with open(os.path.join(script_dir, "../sns-event.json"), "r") as file:
asg_termination_event = file.read()
asg_termination_event = json.loads(file.read())

az = f"{os.environ['AWS_DEFAULT_REGION']}a".upper().replace("-", "_")
monkeypatch.setenv(az, ",".join([mocked_networking["route_table"],mocked_networking["route_table_two"]]))
Expand All @@ -169,10 +163,20 @@ def mock_make_api_call(self, operation_name, kwarg):
if operation_name == "CompleteLifecycleAction":
return mock_complete_lifecycle_action(self, operation_name, kwarg)
return orig_make_api_call(self, operation_name, kwarg)

with mock.patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call):
handler(json.loads(asg_termination_event), {})
handler(asg_termination_event, {})
mock_complete_lifecycle_action.assert_called_once()
verify_nat_instance_route(mocked_networking, mocked_networking["instance"])

sns_message = json.loads(asg_termination_event["Records"][0]["Sns"]["Message"])
sns_message["EC2InstanceId"] = mocked_networking["instance"]
asg_termination_event["Records"][0]["Sns"]["Message"] = json.dumps(sns_message)

mock_complete_lifecycle_action.reset_mock()
with mock.patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call):
handler(asg_termination_event, {})
mock_complete_lifecycle_action.assert_called_once()
verify_nat_gateway_route(mocked_networking)


Expand Down