diff --git a/functions/replace-route/app.py b/functions/replace-route/app.py index bc61ad5..8a9d345 100644 --- a/functions/replace-route/app.py +++ b/functions/replace-route/app.py @@ -20,6 +20,7 @@ LIFECYCLE_HOOK_NAME_KEY = "LifecycleHookName" AUTO_SCALING_GROUP_NAME_KEY = "AutoScalingGroupName" LIFECYCLE_ACTION_TOKEN_KEY = "LifecycleActionToken" +EC2_INSTANCE_ID_KEY = "EC2InstanceId" # Checks every CONNECTIVITY_CHECK_INTERVAL seconds, exits after 1 minute DEFAULT_CONNECTIVITY_CHECK_INTERVAL = "5" @@ -209,15 +210,25 @@ def is_source_dest_check_enabled(instance_id): logger.error(f"Error checking source/dest check: {e}") return None + +def find_default_routes(route_table_ids): + routes = [None for _ in route_table_ids] + response = ec2_client.describe_route_tables(RouteTableIds=route_table_ids) + for route_table in response["RouteTables"]: + for route in route_table["Routes"]: + if route.get("DestinationCidrBlock") == "0.0.0.0/0": + idx = route_table_ids.index(route_table["RouteTableId"]) + routes[idx] = route + break + return routes + + def are_any_routes_pointing_to_nat_gateway(route_table_ids): - ec2 = boto3.client('ec2') try: - response = ec2.describe_route_tables(RouteTableIds=route_table_ids) - for rtb in response.get('RouteTables', []): - for route in rtb.get('Routes', []): - if route.get('DestinationCidrBlock') == "0.0.0.0/0" and 'NatGatewayId' in route and route.get('State') == 'active': - return True - return False + default_routes = find_default_routes(route_table_ids) + return any( + route is not None and "NatGatewayId" in route for route in default_routes + ) except Exception as e: logger.error(f"Error checking NAT Gateway routes: {e}") return False @@ -414,10 +425,12 @@ def handler(event, _): if ( LIFECYCLE_HOOK_NAME_KEY in message and AUTO_SCALING_GROUP_NAME_KEY in message + and EC2_INSTANCE_ID_KEY in message ): asg = message[AUTO_SCALING_GROUP_NAME_KEY] lifecycle_hook_name = message[LIFECYCLE_HOOK_NAME_KEY] lifecycle_action_token = message[LIFECYCLE_ACTION_TOKEN_KEY] + instance_id = message[EC2_INSTANCE_ID_KEY] else: logger.error("Failed to find lifecycle message to parse") raise LifecycleMessageError @@ -428,16 +441,21 @@ def handler(event, _): availability_zone, vpc_zone_identifier = get_az_and_vpc_zone_identifier(asg) public_subnet_id = vpc_zone_identifier.split(",")[0] az = availability_zone.upper().replace("-", "_") - route_tables = az in os.environ and os.getenv(az).split(",") - if not route_tables: + route_table_ids = az in os.environ and os.getenv(az).split(",") + if not route_table_ids: raise MissingEnvironmentVariableError - vpc_id = get_vpc_id(route_tables[0]) + + vpc_id = get_vpc_id(route_table_ids[0]) nat_gateway_id = get_nat_gateway_id(vpc_id, public_subnet_id) - for rtb in route_tables: - replace_route(rtb, nat_gateway_id) - logger.info("Route replacement succeeded") + default_routes = find_default_routes(route_table_ids) + for route_table_id, route in zip(route_table_ids, default_routes): + if route is not None and route.get("InstanceId") == instance_id: + replace_route(route_table_id, nat_gateway_id) + logger.info("Route replacement succeeded") + else: + logger.info("Skipping route replacement in table %s", route_table_id) complete_asg_lifecycle_action( asg, lifecycle_hook_name, lifecycle_action_token, "CONTINUE" diff --git a/functions/replace-route/tests/test_replace_route.py b/functions/replace-route/tests/test_replace_route.py index 0f0b580..1d31e12 100644 --- a/functions/replace-route/tests/test_replace_route.py +++ b/functions/replace-route/tests/test_replace_route.py @@ -39,20 +39,17 @@ def setup_networking(): CidrBlock="10.1.1.0/24", AvailabilityZone=f"{az}" ) - private_subnet = ec2.create_subnet( VpcId=vpc.id, CidrBlock="10.1.2.0/24", AvailabilityZone=f"{az}", ) - private_subnet_two = ec2.create_subnet( VpcId=vpc.id, CidrBlock="10.1.3.0/24", AvailabilityZone=f"{az}", ) - route_table = ec2.create_route_table(VpcId=vpc.id) route_table_two = ec2.create_route_table(VpcId=vpc.id) sg = ec2.create_security_group(GroupName="test-sg", Description="test-sg") @@ -64,16 +61,33 @@ def setup_networking(): AllocationId=allocation_id )["NatGateway"]["NatGatewayId"] - eni = ec2_client.create_network_interface( - SubnetId=public_subnet.id, PrivateIpAddress="10.1.1.5" + launch_template = ec2_client.create_launch_template( + LaunchTemplateName="test_launch_template", + LaunchTemplateData={"ImageId": EXAMPLE_AMI_ID, "InstanceType": "t2.micro"}, + )["LaunchTemplate"] + + autoscaling_client = boto3.client("autoscaling") + autoscaling_client.create_auto_scaling_group( + AutoScalingGroupName="alternat-asg", + VPCZoneIdentifier=public_subnet.id, + MinSize=1, + MaxSize=1, + LaunchTemplate={ + "LaunchTemplateId": launch_template["LaunchTemplateId"], + "Version": str(launch_template["LatestVersionNumber"]), + }, ) + + reservations = ec2_client.describe_instances()["Reservations"] + instance_id = reservations[0]["Instances"][0]["InstanceId"] + ec2_client.associate_route_table( RouteTableId=route_table.id, SubnetId=private_subnet.id ) ec2_client.create_route( DestinationCidrBlock="0.0.0.0/0", - NetworkInterfaceId=eni["NetworkInterface"]["NetworkInterfaceId"], + InstanceId=instance_id, RouteTableId=route_table.id ) ec2_client.associate_route_table( @@ -82,19 +96,16 @@ def setup_networking(): ) ec2_client.create_route( DestinationCidrBlock="0.0.0.0/0", - NetworkInterfaceId=eni["NetworkInterface"]["NetworkInterfaceId"], + InstanceId=instance_id, RouteTableId=route_table_two.id ) return { - "vpc": vpc.id, "public_subnet": public_subnet.id, - "private_subnet": private_subnet.id, - "private_subnet_two": private_subnet_two.id, "nat_gw": nat_gw_id, "route_table": route_table.id, "route_table_two": route_table_two.id, - "sg": sg.id, + "instance": instance_id, } @@ -135,29 +146,12 @@ def verify_nat_instance_route(mocked_networking, instance_id): @mock_aws def test_handler(monkeypatch): mocked_networking = setup_networking() - ec2_client = boto3.client("ec2") - template = ec2_client.create_launch_template( - LaunchTemplateName="test_launch_template", - LaunchTemplateData={"ImageId": EXAMPLE_AMI_ID, "InstanceType": "t2.micro"}, - )["LaunchTemplate"] - - autoscaling_client = boto3.client("autoscaling") - autoscaling_client.create_auto_scaling_group( - AutoScalingGroupName="alternat-asg", - VPCZoneIdentifier=mocked_networking["public_subnet"], - MinSize=1, - MaxSize=1, - LaunchTemplate={ - "LaunchTemplateId": template["LaunchTemplateId"], - "Version": str(template["LatestVersionNumber"]), - }, - ) from app import handler script_dir = os.path.dirname(__file__) with open(os.path.join(script_dir, "../sns-event.json"), "r") as file: - asg_termination_event = file.read() + asg_termination_event = json.loads(file.read()) az = f"{os.environ['AWS_DEFAULT_REGION']}a".upper().replace("-", "_") monkeypatch.setenv(az, ",".join([mocked_networking["route_table"],mocked_networking["route_table_two"]])) @@ -169,10 +163,20 @@ def mock_make_api_call(self, operation_name, kwarg): if operation_name == "CompleteLifecycleAction": return mock_complete_lifecycle_action(self, operation_name, kwarg) return orig_make_api_call(self, operation_name, kwarg) + with mock.patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call): - handler(json.loads(asg_termination_event), {}) + handler(asg_termination_event, {}) mock_complete_lifecycle_action.assert_called_once() + verify_nat_instance_route(mocked_networking, mocked_networking["instance"]) + + sns_message = json.loads(asg_termination_event["Records"][0]["Sns"]["Message"]) + sns_message["EC2InstanceId"] = mocked_networking["instance"] + asg_termination_event["Records"][0]["Sns"]["Message"] = json.dumps(sns_message) + mock_complete_lifecycle_action.reset_mock() + with mock.patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call): + handler(asg_termination_event, {}) + mock_complete_lifecycle_action.assert_called_once() verify_nat_gateway_route(mocked_networking)