#!/usr/bin/env python3
"""
Import AWS EC2 and Security Group data directly from AWS accounts using boto3
Supports MFA/OTP authentication
"""

import configparser
import json
import os
import sqlite3
import sys
from getpass import getpass
from pathlib import Path

try:
    import boto3
except ImportError:
    # Keep the parsing/database helpers importable without the AWS SDK;
    # main() and get_session_with_mfa() report the missing dependency.
    boto3 = None


# EC2 reports IpProtocol either as a lowercase name or as an IANA protocol
# number, so both spellings map to the same display label.
_PROTOCOL_LABELS = {
    '6': 'TCP',
    'tcp': 'TCP',
    '17': 'UDP',
    'udp': 'UDP',
    '1': 'ICMP',
    'icmp': 'ICMP',
    '58': 'ICMPv6',
    'icmpv6': 'ICMPv6',
}

_BOTO3_MISSING_MESSAGE = "boto3 is required for AWS access (pip install boto3)"


def _aws_config_path():
    """Return the path of the shared AWS config file (~/.aws/config)."""
    return Path.home() / '.aws' / 'config'


def _load_aws_config(config_path):
    """Parse the AWS config file, returning None (after printing) if malformed."""
    # AWS config values are literal; disable '%' interpolation so a value
    # containing '%' cannot raise when it is read.
    config = configparser.ConfigParser(interpolation=None)
    try:
        config.read(config_path)
    except configparser.Error as e:
        print(f"Error: could not parse {config_path}: {e}")
        return None
    return config


def _lookup_mfa_serial(profile_name):
    """Return the mfa_serial configured for a profile, or None if absent."""
    config = _load_aws_config(_aws_config_path())
    if config is None:
        return None
    section_name = 'default' if profile_name == 'default' else f'profile {profile_name}'
    if section_name in config:
        return config[section_name].get('mfa_serial')
    return None


def get_aws_profiles():
    """Read available AWS profiles from ~/.aws/config.

    Returns:
        Profile names sorted case-insensitively with 'default' first, or an
        empty list when the config file is missing or cannot be parsed.
    """
    config_path = _aws_config_path()
    if not config_path.exists():
        print(f"Error: AWS config file not found at {config_path}")
        return []

    config = _load_aws_config(config_path)
    if config is None:
        return []

    prefix = 'profile '
    profiles = []
    for section in config.sections():
        if section.startswith(prefix):
            # Strip only the leading marker; str.replace would also mangle
            # a profile whose own name contains "profile ".
            name = section[len(prefix):].strip()
            if name:
                profiles.append(name)
        elif section == 'default':
            profiles.append('default')

    # Sort profiles alphabetically, but keep 'default' at the top
    profiles.sort(key=lambda name: (name != 'default', name.lower()))
    return profiles


def get_session_with_mfa(profile_name):
    """Create a boto3 session for a profile, prompting for MFA when required.

    The profile is first tried as-is. If the identity check fails with an
    error that mentions MFA, the MFA device ARN is read from the AWS config
    (or prompted for), a token code is requested, and temporary credentials
    are obtained through STS.

    Args:
        profile_name: Name of the profile in ~/.aws/config.

    Returns:
        An authenticated boto3 session, or None if authentication failed.

    Raises:
        ImportError: If boto3 is not installed.
    """
    if boto3 is None:
        raise ImportError(_BOTO3_MISSING_MESSAGE)

    print(f"\nAuthenticating with profile: {profile_name}")

    # Creating the session can itself fail (e.g. unknown profile); treat that
    # as an authentication failure instead of crashing the whole run.
    try:
        session = boto3.Session(profile_name=profile_name)
        sts = session.client('sts')
    except Exception as e:
        print(f"Error: Authentication failed: {e}")
        return None

    try:
        # Try to get caller identity (will fail if MFA is required)
        identity = sts.get_caller_identity()
    except Exception as e:
        message = str(e)
        if 'MultiFactorAuthentication' not in message and 'MFA' not in message:
            print(f"Error: Authentication failed: {e}")
            return None
    else:
        print(f"✓ Authenticated as: {identity['Arn']}")
        return session

    print("MFA/OTP required for this profile")

    # Get MFA device ARN from config or prompt
    mfa_serial = _lookup_mfa_serial(profile_name)
    if not mfa_serial:
        print("\nMFA device ARN not found in config.")
        print("Enter MFA device ARN (e.g., arn:aws:iam::123456789012:mfa/username):")
        mfa_serial = input("MFA ARN: ").strip()
    else:
        print(f"Using MFA device: {mfa_serial}")

    # Get OTP token
    token_code = getpass("Enter MFA token code: ").strip()

    # Get temporary credentials
    try:
        response = sts.get_session_token(
            DurationSeconds=3600,  # 1 hour
            SerialNumber=mfa_serial,
            TokenCode=token_code,
        )
        credentials = response['Credentials']

        # Carry the profile's region over: a session built from raw
        # credentials would otherwise lose it.
        mfa_session = boto3.Session(
            aws_access_key_id=credentials['AccessKeyId'],
            aws_secret_access_key=credentials['SecretAccessKey'],
            aws_session_token=credentials['SessionToken'],
            region_name=session.region_name,
        )
    except Exception as mfa_error:
        print(f"Error: MFA authentication failed: {mfa_error}")
        return None

    print("✓ MFA authentication successful")
    return mfa_session


def get_account_info(session):
    """Get AWS account ID and alias.

    Args:
        session: Authenticated boto3 session.

    Returns:
        Tuple (account_id, account_name). account_name is the first account
        alias, or the account ID when no alias exists or it cannot be read.
    """
    sts = session.client('sts')
    account_id = sts.get_caller_identity()['Account']

    # The alias lookup is best-effort: fall back to the account ID on any
    # failure, but do not swallow KeyboardInterrupt/SystemExit.
    try:
        aliases = session.client('iam').list_account_aliases()['AccountAliases']
    except Exception:
        aliases = []

    account_name = aliases[0] if aliases else account_id
    return account_id, account_name


def fetch_security_groups(session, account_id, account_name):
    """Fetch all security groups visible to the session's EC2 client.

    Args:
        session: Authenticated boto3 session.
        account_id: Account ID recorded on every row.
        account_name: Account alias/name recorded on every row.

    Returns:
        Tuple (security_groups, sg_rules): one dict per group, and the
        flattened rule dicts produced by parse_sg_rule for all groups.
    """
    ec2 = session.client('ec2')

    print("Fetching security groups...")
    paginator = ec2.get_paginator('describe_security_groups')

    security_groups = []
    sg_rules = []

    for page in paginator.paginate():
        for sg in page['SecurityGroups']:
            group_id = sg['GroupId']
            tags = {tag['Key']: tag['Value'] for tag in sg.get('Tags', [])}

            # Parse rules first to get accurate counts
            ingress_rules = [
                entry
                for rule in sg.get('IpPermissions', [])
                for entry in parse_sg_rule(group_id, 'ingress', rule)
            ]
            egress_rules = [
                entry
                for rule in sg.get('IpPermissionsEgress', [])
                for entry in parse_sg_rule(group_id, 'egress', rule)
            ]

            security_groups.append({
                'account_id': account_id,
                'account_name': account_name,
                'group_id': group_id,
                'group_name': sg['GroupName'],
                'tag_name': tags.get('Name', ''),
                'tag_wave': tags.get('ucsb:dept:INFR:wave', 'none'),
                'tag_git_repo': tags.get('git_repo', 'none'),
                'tag_git_org': tags.get('git_org', ''),
                'tag_git_file': tags.get('git_file', ''),
                'tags_json': tags,
                'ingress_rule_count': len(ingress_rules),
                'egress_rule_count': len(egress_rules),
            })

            sg_rules.extend(ingress_rules)
            sg_rules.extend(egress_rules)

    print(f"✓ Found {len(security_groups)} security groups with {len(sg_rules)} rules")
    return security_groups, sg_rules


def parse_sg_rule(group_id, direction, rule):
    """Parse a security group rule into individual entries.

    One entry is produced per source in the rule (IPv4 CIDR, IPv6 CIDR,
    referenced security group, prefix list).

    Args:
        group_id: ID of the security group owning the rule.
        direction: 'ingress' or 'egress'.
        rule: A permission dict with the keys read below ('IpProtocol',
            'FromPort', 'ToPort', 'IpRanges', 'Ipv6Ranges',
            'UserIdGroupPairs', 'PrefixListIds').

    Returns:
        List of dicts with keys group_id, direction, protocol, port_range,
        source_type, source and description.
    """
    protocol = str(rule.get('IpProtocol', '-1'))
    from_port = rule.get('FromPort', '')
    to_port = rule.get('ToPort', '')

    # Normalize protocol (accepting both names and protocol numbers)
    label = _PROTOCOL_LABELS.get(protocol.lower())
    if protocol == '-1':
        protocol_str = 'All'
        port_range = 'All'
    elif label in ('TCP', 'UDP'):
        protocol_str = label
        port_range = f"{from_port}-{to_port}" if from_port != to_port else str(from_port)
    elif label is not None:
        # ICMP / ICMPv6: FromPort/ToPort carry type/code, not ports.
        protocol_str = label
        port_range = 'N/A'
    else:
        protocol_str = protocol
        # Compare against "missing" explicitly: port 0 is a valid port and
        # must not be treated as absent.
        has_ports = from_port not in ('', None) and to_port not in ('', None)
        port_range = f"{from_port}-{to_port}" if has_ports else 'N/A'

    def make_entry(source_type, source, description):
        return {
            'group_id': group_id,
            'direction': direction,
            'protocol': protocol_str,
            'port_range': port_range,
            'source_type': source_type,
            'source': source,
            'description': description,
        }

    rules = []

    # Parse IP ranges
    for ip_range in rule.get('IpRanges', []):
        rules.append(make_entry('CIDR', ip_range['CidrIp'], ip_range.get('Description', '')))

    # Parse IPv6 ranges
    for ip_range in rule.get('Ipv6Ranges', []):
        rules.append(make_entry('CIDR', ip_range['CidrIpv6'], ip_range.get('Description', '')))

    # Parse security group references
    for sg_ref in rule.get('UserIdGroupPairs', []):
        source = sg_ref.get('GroupId', '')
        if sg_ref.get('GroupName'):
            source += f" ({sg_ref['GroupName']})"
        rules.append(make_entry('Security Group', source, sg_ref.get('Description', '')))

    # Parse prefix lists
    for prefix in rule.get('PrefixListIds', []):
        rules.append(make_entry('Prefix List', prefix['PrefixListId'], prefix.get('Description', '')))

    return rules


def fetch_ec2_instances(session, account_id, account_name):
    """Fetch all EC2 instances visible to the session's EC2 client.

    Args:
        session: Authenticated boto3 session.
        account_id: Account ID recorded on every row.
        account_name: Account alias/name recorded on every row.

    Returns:
        List of dicts, one per instance, with the columns stored in the
        ec2_instances table (tags_json holds the raw tag dict).
    """
    ec2 = session.client('ec2')

    print("Fetching EC2 instances...")
    paginator = ec2.get_paginator('describe_instances')

    instances = []
    for page in paginator.paginate():
        for reservation in page['Reservations']:
            for instance in reservation['Instances']:
                tags = {tag['Key']: tag['Value'] for tag in instance.get('Tags', [])}
                attached = instance.get('SecurityGroups', [])

                instances.append({
                    'account_id': account_id,
                    'account_name': account_name,
                    'tag_name': tags.get('Name', ''),
                    'instance_id': instance['InstanceId'],
                    'state': instance['State']['Name'],
                    'private_ip_address': instance.get('PrivateIpAddress', ''),
                    'security_groups_id_list': ';'.join(sg['GroupId'] for sg in attached),
                    'security_groups_name_list': ';'.join(sg['GroupName'] for sg in attached),
                    'tag_git_repo': tags.get('git_repo', 'none'),
                    'tag_git_org': tags.get('git_org', ''),
                    'tag_git_file': tags.get('git_file', ''),
                    'tags_json': tags,
                })

    print(f"✓ Found {len(instances)} EC2 instances")
    return instances


def get_db(db_path):
    """Get database connection and create schema if needed.

    Args:
        db_path: Path of the SQLite database file.

    Returns:
        An open sqlite3 connection; the caller is responsible for closing it.
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # Create tables if they don't exist
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS security_groups (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            account_id TEXT,
            account_name TEXT,
            group_id TEXT UNIQUE,
            group_name TEXT,
            tag_name TEXT,
            tag_wave TEXT,
            tag_git_repo TEXT,
            tag_git_org TEXT,
            tag_git_file TEXT,
            tags_json TEXT,
            ingress_rule_count INTEGER,
            egress_rule_count INTEGER
        )
    """)

    cursor.execute("""
        CREATE TABLE IF NOT EXISTS ec2_instances (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            account_id TEXT,
            account_name TEXT,
            tag_name TEXT,
            instance_id TEXT UNIQUE,
            state TEXT,
            private_ip_address TEXT,
            security_groups_id_list TEXT,
            security_groups_name_list TEXT,
            tag_git_repo TEXT,
            tag_git_org TEXT,
            tag_git_file TEXT,
            tags_json TEXT
        )
    """)

    cursor.execute("""
        CREATE TABLE IF NOT EXISTS sg_rules (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            group_id TEXT,
            direction TEXT,
            protocol TEXT,
            port_range TEXT,
            source_type TEXT,
            source TEXT,
            description TEXT,
            FOREIGN KEY (group_id) REFERENCES security_groups(group_id)
        )
    """)

    cursor.execute("""
        CREATE TABLE IF NOT EXISTS refresh_timestamps (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            account_id TEXT,
            account_name TEXT,
            last_refresh TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE(account_id)
        )
    """)

    # Create indexes
    index_statements = (
        "CREATE INDEX IF NOT EXISTS idx_sg_group_id ON security_groups(group_id)",
        "CREATE INDEX IF NOT EXISTS idx_sg_account_name ON security_groups(account_name)",
        "CREATE INDEX IF NOT EXISTS idx_ec2_instance_id ON ec2_instances(instance_id)",
        "CREATE INDEX IF NOT EXISTS idx_ec2_account_name ON ec2_instances(account_name)",
        "CREATE INDEX IF NOT EXISTS idx_sg_rules_group_id ON sg_rules(group_id)",
        "CREATE INDEX IF NOT EXISTS idx_sg_rules_direction ON sg_rules(direction)",
    )
    for statement in index_statements:
        cursor.execute(statement)

    conn.commit()
    return conn


def import_to_database(db_path, security_groups, ec2_instances, sg_rules=None, append=False):
    """Import data into SQLite database.

    All writes happen in one transaction: on any error the transaction is
    rolled back and the exception re-raised; the connection is always closed.

    Args:
        db_path: Path of the SQLite database file (schema is created if needed).
        security_groups: Security group dicts as built by fetch_security_groups.
        ec2_instances: Instance dicts as built by fetch_ec2_instances.
        sg_rules: Optional rule dicts as built by parse_sg_rule.
        append: When False, existing groups, instances and rules are deleted
            first (refresh_timestamps is kept). When True, rows are upserted
            and the rules of every imported group are replaced.
    """
    sg_rules = sg_rules or []

    conn = get_db(db_path)
    try:
        cursor = conn.cursor()

        if not append:
            # Clear existing data (but keep refresh_timestamps)
            print("Clearing existing data...")
            cursor.execute("DELETE FROM security_groups")
            cursor.execute("DELETE FROM ec2_instances")
            cursor.execute("DELETE FROM sg_rules")

        # Import security groups
        print(f"Importing {len(security_groups)} security groups...")
        cursor.executemany("""
            INSERT OR REPLACE INTO security_groups
            (account_id, account_name, group_id, group_name, tag_name, tag_wave, tag_git_repo,
             tag_git_org, tag_git_file, tags_json, ingress_rule_count, egress_rule_count)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, [
            (
                sg['account_id'],
                sg['account_name'],
                sg['group_id'],
                sg['group_name'],
                sg['tag_name'],
                sg['tag_wave'],
                sg['tag_git_repo'],
                sg.get('tag_git_org', ''),
                sg.get('tag_git_file', ''),
                json.dumps(sg.get('tags_json', {})),
                sg['ingress_rule_count'],
                sg.get('egress_rule_count', 0),
            )
            for sg in security_groups
        ])

        # Import EC2 instances
        print(f"Importing {len(ec2_instances)} EC2 instances...")
        cursor.executemany("""
            INSERT OR REPLACE INTO ec2_instances
            (account_id, account_name, tag_name, instance_id, state, private_ip_address,
             security_groups_id_list, security_groups_name_list, tag_git_repo,
             tag_git_org, tag_git_file, tags_json)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, [
            (
                instance['account_id'],
                instance['account_name'],
                instance['tag_name'],
                instance['instance_id'],
                instance['state'],
                instance['private_ip_address'],
                instance['security_groups_id_list'],
                instance['security_groups_name_list'],
                instance['tag_git_repo'],
                instance.get('tag_git_org', ''),
                instance.get('tag_git_file', ''),
                json.dumps(instance.get('tags_json', {})),
            )
            for instance in ec2_instances
        ])

        if append:
            # Replace the rules of every imported group. Groups are taken from
            # security_groups too, so a group whose rules were all removed
            # does not keep its stale rows.
            refreshed_group_ids = {sg['group_id'] for sg in security_groups}
            refreshed_group_ids.update(rule['group_id'] for rule in sg_rules)
            cursor.executemany(
                "DELETE FROM sg_rules WHERE group_id = ?",
                [(group_id,) for group_id in refreshed_group_ids],
            )

        # Import security group rules
        if sg_rules:
            print(f"Importing {len(sg_rules)} security group rules...")
            cursor.executemany("""
                INSERT INTO sg_rules
                (group_id, direction, protocol, port_range, source_type, source, description)
                VALUES (?, ?, ?, ?, ?, ?, ?)
            """, [
                (
                    rule['group_id'],
                    rule['direction'],
                    rule['protocol'],
                    rule['port_range'],
                    rule['source_type'],
                    rule['source'],
                    rule['description'],
                )
                for rule in sg_rules
            ])

        # Update refresh timestamps for all accounts
        print("Updating refresh timestamps...")
        accounts = {(sg['account_id'], sg['account_name']) for sg in security_groups}
        accounts.update(
            (instance['account_id'], instance['account_name']) for instance in ec2_instances
        )
        cursor.executemany("""
            INSERT INTO refresh_timestamps (account_id, account_name, last_refresh)
            VALUES (?, ?, CURRENT_TIMESTAMP)
            ON CONFLICT(account_id) DO UPDATE SET
                last_refresh = CURRENT_TIMESTAMP,
                account_name = excluded.account_name
        """, list(accounts))

        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()

    print("✓ Import complete")


def parse_profile_selection(selection, profiles):
    """Resolve the user's profile selection into a list of profile names.

    Args:
        selection: Either 'all' (case-insensitive) or comma-separated
            1-based positions into profiles.
        profiles: Available profile names, in displayed order.

    Returns:
        Selected profile names in the order given, without duplicates, or
        None when the selection is not valid.
    """
    if selection.strip().lower() == 'all':
        return list(profiles)

    chosen = []
    for token in selection.split(','):
        try:
            position = int(token.strip())
        except ValueError:
            return None
        # Reject 0 and negatives explicitly: they would otherwise index from
        # the end of the list and silently pick the wrong profile.
        if not 1 <= position <= len(profiles):
            return None
        name = profiles[position - 1]
        # Importing a profile twice would duplicate its rules.
        if name not in chosen:
            chosen.append(name)
    return chosen


def main():
    """Interactively import security groups and EC2 instances into SQLite."""
    if boto3 is None:
        print(f"Error: {_BOTO3_MISSING_MESSAGE}")
        sys.exit(1)

    # Database path
    db_path = os.path.join(os.path.dirname(__file__), 'data', 'aws_export.db')
    os.makedirs(os.path.dirname(db_path), exist_ok=True)

    print("=" * 60)
    print("AWS Direct Import Tool")
    print("=" * 60)

    # Get available profiles
    profiles = get_aws_profiles()
    if not profiles:
        print("No AWS profiles found in ~/.aws/config")
        sys.exit(1)

    print("\nAvailable AWS profiles:")
    for i, profile in enumerate(profiles, 1):
        print(f" {i}. {profile}")

    # Let user select profile(s)
    print("\nEnter profile number(s) to import (comma-separated, or 'all'):")
    selection = input("Selection: ").strip()

    selected_profiles = parse_profile_selection(selection, profiles)
    if not selected_profiles:
        print("Invalid selection")
        sys.exit(1)

    # Ask if should append or replace
    append_mode = False
    if len(selected_profiles) > 1:
        append_choice = input("\nAppend to existing data? (y/N): ").strip().lower()
        append_mode = append_choice == 'y'

    # Process each profile
    all_security_groups = []
    all_ec2_instances = []
    all_sg_rules = []
    imported_profiles = 0

    for i, profile in enumerate(selected_profiles):
        print(f"\n{'=' * 60}")
        print(f"Processing profile {i+1}/{len(selected_profiles)}: {profile}")
        print('=' * 60)

        # Authenticate
        session = get_session_with_mfa(profile)
        if not session:
            print(f"✗ Skipping profile {profile} due to authentication failure")
            continue

        # A failure in one account (e.g. missing permissions) must not
        # discard what was already fetched from the others.
        try:
            # Get account info
            account_id, account_name = get_account_info(session)
            print(f"Account: {account_name} ({account_id})")

            # Fetch data
            security_groups, sg_rules = fetch_security_groups(session, account_id, account_name)
            ec2_instances = fetch_ec2_instances(session, account_id, account_name)
        except Exception as e:
            print(f"✗ Skipping profile {profile}: failed to fetch data: {e}")
            continue

        all_security_groups.extend(security_groups)
        all_ec2_instances.extend(ec2_instances)
        all_sg_rules.extend(sg_rules)
        imported_profiles += 1

    # Import to database
    if all_security_groups or all_ec2_instances:
        print(f"\n{'=' * 60}")
        print("Importing to database...")
        print('=' * 60)

        import_to_database(db_path, all_security_groups, all_ec2_instances, all_sg_rules,
                           append=append_mode)

        print(f"\n✓ Successfully imported data from {imported_profiles} profile(s)")
        print(f" Database: {db_path}")
        print(f" Total Security Groups: {len(all_security_groups)}")
        print(f" Total EC2 Instances: {len(all_ec2_instances)}")
        print(f" Total SG Rules: {len(all_sg_rules)}")
    else:
        print("\n✗ No data imported")


if __name__ == "__main__":
    main()