SGO/import_from_aws.py
Eduardo Figueroa 6886c8871c
Initial Commit
2025-11-20 12:03:30 -08:00

548 lines
19 KiB
Python
Executable file

#!/usr/bin/env python3
"""
Import AWS EC2 and Security Group data directly from AWS accounts using boto3
Supports MFA/OTP authentication
"""
import boto3
import sqlite3
import os
import sys
import configparser
from pathlib import Path
from getpass import getpass
def get_aws_profiles(config_path=None):
    """
    Return the profile names defined in the AWS CLI config file.

    Args:
        config_path: Config file to read. Defaults to ``~/.aws/config``.

    Returns:
        A list of unique profile names. ``'default'`` comes first when present;
        the rest are sorted case-insensitively. The list is empty (and an error
        is printed) when the file does not exist.
    """
    import shlex

    if config_path is None:
        config_path = Path.home() / '.aws' / 'config'
    config_path = Path(config_path)
    if not config_path.exists():
        print(f"Error: AWS config file not found at {config_path}")
        return []

    # strict=False: tolerate a repeated section header instead of crashing;
    # interpolation=None: '%' in unrelated values must not matter here.
    config = configparser.ConfigParser(strict=False, interpolation=None)
    config.read(config_path)

    # dict.fromkeys de-duplicates while keeping first-seen order, so the stable
    # sort below stays deterministic (e.g. [default] plus [profile default]).
    found = {}
    for section in config.sections():
        if section == 'default':
            found.setdefault('default', None)
            continue
        # Split the header the same way botocore does: exactly "profile <name>",
        # where <name> may be quoted and surrounding whitespace is ignored.
        try:
            parts = shlex.split(section)
        except ValueError:
            continue
        if len(parts) == 2 and parts[0] == 'profile' and parts[1]:
            found.setdefault(parts[1], None)

    # Sort profiles alphabetically, but keep 'default' at the top
    return sorted(found, key=lambda name: (name != 'default', name.lower()))
def get_session_with_mfa(profile_name):
    """
    Create a boto3 session for *profile_name*, prompting for MFA/OTP if needed.

    The profile's own credentials are tried first with STS GetCallerIdentity.
    If that call fails with an error whose text mentions MFA, the user is asked
    for an MFA token (see ``_get_mfa_session``) and a session built from the
    resulting temporary credentials is returned instead.

    Returns:
        A ``boto3.Session``, or ``None`` when the profile cannot be loaded or
        authentication fails (an error message is printed in that case).
    """
    print(f"\nAuthenticating with profile: {profile_name}")
    # Create the initial session inside a guard: an unknown/broken profile
    # should be reported and skipped, not crash the whole multi-profile run.
    try:
        session = boto3.Session(profile_name=profile_name)
        sts = session.client('sts')
    except Exception as e:
        print(f"Error: Authentication failed: {e}")
        return None

    try:
        # Try to get caller identity (will fail if MFA is required)
        identity = sts.get_caller_identity()
    except Exception as e:
        # Check if MFA is required
        if 'MultiFactorAuthentication' in str(e) or 'MFA' in str(e):
            print("MFA/OTP required for this profile")
            return _get_mfa_session(session, sts, profile_name)
        print(f"Error: Authentication failed: {e}")
        return None

    print(f"✓ Authenticated as: {identity['Arn']}")
    return session


def _get_mfa_session(session, sts, profile_name):
    """Prompt for an MFA code and return a session on 1-hour temporary credentials, or None."""
    # Get MFA device ARN from config or prompt
    config_path = Path.home() / '.aws' / 'config'
    config = configparser.ConfigParser()
    config.read(config_path)
    section_name = f'profile {profile_name}' if profile_name != 'default' else 'default'
    mfa_serial = None
    if section_name in config:
        mfa_serial = config[section_name].get('mfa_serial')
    if not mfa_serial:
        print("\nMFA device ARN not found in config.")
        print("Enter MFA device ARN (e.g., arn:aws:iam::123456789012:mfa/username):")
        mfa_serial = input("MFA ARN: ").strip()
    else:
        print(f"Using MFA device: {mfa_serial}")

    # Get OTP token (strip stray whitespace typed around the digits)
    token_code = getpass("Enter MFA token code: ").strip()

    try:
        response = sts.get_session_token(
            DurationSeconds=3600,  # 1 hour
            SerialNumber=mfa_serial,
            TokenCode=token_code
        )
    except Exception as mfa_error:
        print(f"Error: MFA authentication failed: {mfa_error}")
        return None

    credentials = response['Credentials']
    # A Session built from raw credentials does not inherit the profile's
    # region; carry it over explicitly so EC2 calls hit the profile's region
    # rather than whatever the default profile/environment specifies.
    mfa_session = boto3.Session(
        aws_access_key_id=credentials['AccessKeyId'],
        aws_secret_access_key=credentials['SecretAccessKey'],
        aws_session_token=credentials['SessionToken'],
        region_name=session.region_name
    )
    print("✓ MFA authentication successful")
    return mfa_session
def get_account_info(session):
    """
    Return ``(account_id, account_name)`` for the account behind *session*.

    ``account_name`` is the first IAM account alias when one exists and can be
    read; otherwise it falls back to the account ID. Errors from the STS
    GetCallerIdentity call propagate to the caller.
    """
    account_id = session.client('sts').get_caller_identity()['Account']

    # Try to get account alias
    try:
        aliases = session.client('iam').list_account_aliases().get('AccountAliases') or []
    except Exception:
        # Best effort (ListAccountAliases is frequently denied), but unlike a
        # bare ``except:`` this lets KeyboardInterrupt/SystemExit through.
        aliases = []

    account_name = aliases[0] if aliases else account_id
    return account_id, account_name
def fetch_security_groups(session, account_id, account_name):
    """
    Collect every security group visible to *session*, plus its flattened rules.

    Returns:
        ``(security_groups, sg_rules)``. ``security_groups`` holds one row dict
        per group, keyed like the ``security_groups`` table columns (tags kept
        as a dict under ``tags_json``). ``sg_rules`` holds, group by group, that
        group's ingress rows followed by its egress rows, as produced by
        ``parse_sg_rule``.
    """
    client = session.client('ec2')
    print("Fetching security groups...")

    group_rows = []
    rule_rows = []
    for page in client.get_paginator('describe_security_groups').paginate():
        for group in page['SecurityGroups']:
            gid = group['GroupId']
            tag_map = dict((t['Key'], t['Value']) for t in group.get('Tags', []))

            # Rules are expanded before the row is built so the counts are exact.
            inbound = [
                row
                for permission in group.get('IpPermissions', [])
                for row in parse_sg_rule(gid, 'ingress', permission)
            ]
            outbound = [
                row
                for permission in group.get('IpPermissionsEgress', [])
                for row in parse_sg_rule(gid, 'egress', permission)
            ]

            group_rows.append({
                'account_id': account_id,
                'account_name': account_name,
                'group_id': gid,
                'group_name': group['GroupName'],
                'tag_name': tag_map.get('Name', ''),
                'tag_wave': tag_map.get('ucsb:dept:INFR:wave', 'none'),
                'tag_git_repo': tag_map.get('git_repo', 'none'),
                'tag_git_org': tag_map.get('git_org', ''),
                'tag_git_file': tag_map.get('git_file', ''),
                'tags_json': tag_map,
                'ingress_rule_count': len(inbound),
                'egress_rule_count': len(outbound)
            })
            rule_rows += inbound
            rule_rows += outbound

    print(f"✓ Found {len(group_rows)} security groups with {len(rule_rows)} rules")
    return group_rows, rule_rows
def _format_port_range(from_port, to_port):
    """Render a port span as "22", "80-443", or "N/A" when either bound is absent."""
    if from_port is None or to_port is None:
        return 'N/A'
    if from_port == to_port:
        return str(from_port)
    return f"{from_port}-{to_port}"


def parse_sg_rule(group_id, direction, rule):
    """
    Flatten one EC2 ``IpPermissions`` entry into one row per source.

    Args:
        group_id: ID of the security group that owns the rule.
        direction: ``'ingress'`` or ``'egress'``; copied into every row.
        rule: An ``IpPermissions``/``IpPermissionsEgress`` dict as returned by
            ``describe_security_groups``.

    Returns:
        A list of dicts with keys ``group_id, direction, protocol, port_range,
        source_type, source, description``. The list contains one entry per
        IPv4 range, IPv6 range, security-group reference and prefix list, in
        that order, and is empty when the rule has no sources.
    """
    # EC2 reports well-known protocols by name ("tcp", "udp", "icmp", "icmpv6")
    # and everything else by number, so accept both spellings.
    protocol = str(rule.get('IpProtocol', '-1')).lower()
    from_port = rule.get('FromPort')
    to_port = rule.get('ToPort')

    # Normalize protocol
    if protocol == '-1':
        protocol_str, port_range = 'All', 'All'
    elif protocol in ('tcp', '6'):
        protocol_str, port_range = 'TCP', _format_port_range(from_port, to_port)
    elif protocol in ('udp', '17'):
        protocol_str, port_range = 'UDP', _format_port_range(from_port, to_port)
    elif protocol in ('icmp', '1'):
        # For ICMP, FromPort/ToPort carry type/code, not ports.
        protocol_str, port_range = 'ICMP', 'N/A'
    elif protocol in ('icmpv6', '58'):
        protocol_str, port_range = 'ICMPv6', 'N/A'
    else:
        protocol_str, port_range = protocol, _format_port_range(from_port, to_port)

    def make_row(source_type, source, description):
        return {
            'group_id': group_id,
            'direction': direction,
            'protocol': protocol_str,
            'port_range': port_range,
            'source_type': source_type,
            'source': source,
            'description': description,
        }

    rows = []
    # IPv4 and IPv6 CIDR ranges
    for ip_range in rule.get('IpRanges', []):
        rows.append(make_row('CIDR', ip_range['CidrIp'], ip_range.get('Description', '')))
    for ip_range in rule.get('Ipv6Ranges', []):
        rows.append(make_row('CIDR', ip_range['CidrIpv6'], ip_range.get('Description', '')))

    # Security group references, shown as "sg-id (name)" when the name is known
    for sg_ref in rule.get('UserIdGroupPairs', []):
        source = sg_ref.get('GroupId', '')
        if sg_ref.get('GroupName'):
            source += f" ({sg_ref['GroupName']})"
        rows.append(make_row('Security Group', source, sg_ref.get('Description', '')))

    # Managed prefix lists
    for prefix in rule.get('PrefixListIds', []):
        rows.append(make_row('Prefix List', prefix['PrefixListId'], prefix.get('Description', '')))

    return rows
def fetch_ec2_instances(session, account_id, account_name):
    """
    List every EC2 instance visible to *session* as ``ec2_instances`` row dicts.

    Attached security groups are flattened into ';'-joined ID and name
    strings. Tags are kept as a dict under ``tags_json``. Missing optional
    fields default to '' (or 'none' for ``tag_git_repo``).
    """
    client = session.client('ec2')
    print("Fetching EC2 instances...")

    rows = []
    for page in client.get_paginator('describe_instances').paginate():
        for reservation in page['Reservations']:
            for inst in reservation['Instances']:
                tag_map = {}
                for tag in inst.get('Tags', []):
                    tag_map[tag['Key']] = tag['Value']

                attached = inst.get('SecurityGroups', [])
                group_ids = ';'.join([g['GroupId'] for g in attached])
                group_names = ';'.join([g['GroupName'] for g in attached])

                rows.append({
                    'account_id': account_id,
                    'account_name': account_name,
                    'tag_name': tag_map.get('Name', ''),
                    'instance_id': inst['InstanceId'],
                    'state': inst['State']['Name'],
                    'private_ip_address': inst.get('PrivateIpAddress', ''),
                    'security_groups_id_list': group_ids,
                    'security_groups_name_list': group_names,
                    'tag_git_repo': tag_map.get('git_repo', 'none'),
                    'tag_git_org': tag_map.get('git_org', ''),
                    'tag_git_file': tag_map.get('git_file', ''),
                    'tags_json': tag_map
                })

    print(f"✓ Found {len(rows)} EC2 instances")
    return rows
def get_db(db_path):
    """
    Open the SQLite database at *db_path* and ensure the schema exists.

    Creates the security_groups, ec2_instances, sg_rules and refresh_timestamps
    tables and their lookup indexes when missing. Every statement uses
    IF NOT EXISTS, so re-running against an existing database is harmless.
    The schema is committed before the open ``sqlite3.Connection`` is returned;
    the caller is responsible for closing it.
    """
    schema = (
        """
        CREATE TABLE IF NOT EXISTS security_groups (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            account_id TEXT,
            account_name TEXT,
            group_id TEXT UNIQUE,
            group_name TEXT,
            tag_name TEXT,
            tag_wave TEXT,
            tag_git_repo TEXT,
            tag_git_org TEXT,
            tag_git_file TEXT,
            tags_json TEXT,
            ingress_rule_count INTEGER,
            egress_rule_count INTEGER
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS ec2_instances (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            account_id TEXT,
            account_name TEXT,
            tag_name TEXT,
            instance_id TEXT UNIQUE,
            state TEXT,
            private_ip_address TEXT,
            security_groups_id_list TEXT,
            security_groups_name_list TEXT,
            tag_git_repo TEXT,
            tag_git_org TEXT,
            tag_git_file TEXT,
            tags_json TEXT
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS sg_rules (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            group_id TEXT,
            direction TEXT,
            protocol TEXT,
            port_range TEXT,
            source_type TEXT,
            source TEXT,
            description TEXT,
            FOREIGN KEY (group_id) REFERENCES security_groups(group_id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS refresh_timestamps (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            account_id TEXT,
            account_name TEXT,
            last_refresh TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE(account_id)
        )
        """,
        # Lookup indexes
        "CREATE INDEX IF NOT EXISTS idx_sg_group_id ON security_groups(group_id)",
        "CREATE INDEX IF NOT EXISTS idx_sg_account_name ON security_groups(account_name)",
        "CREATE INDEX IF NOT EXISTS idx_ec2_instance_id ON ec2_instances(instance_id)",
        "CREATE INDEX IF NOT EXISTS idx_ec2_account_name ON ec2_instances(account_name)",
        "CREATE INDEX IF NOT EXISTS idx_sg_rules_group_id ON sg_rules(group_id)",
        "CREATE INDEX IF NOT EXISTS idx_sg_rules_direction ON sg_rules(direction)",
    )

    connection = sqlite3.connect(db_path)
    cur = connection.cursor()
    for statement in schema:
        cur.execute(statement)
    connection.commit()
    return connection
def import_to_database(db_path, security_groups, ec2_instances, sg_rules=None, append=False):
    """
    Write fetched AWS inventory into the SQLite database at *db_path*.

    Args:
        db_path: SQLite file; the schema is created via ``get_db`` if needed.
        security_groups: Row dicts shaped like ``security_groups`` table rows.
        ec2_instances: Row dicts shaped like ``ec2_instances`` table rows.
        sg_rules: Optional rule dicts shaped like ``sg_rules`` table rows.
        append: When False, every row in security_groups, ec2_instances and
            sg_rules is deleted first (refresh_timestamps is kept). When True,
            existing rows are kept: groups/instances with the same ID are
            replaced (INSERT OR REPLACE), and the rules of every imported
            security group are rewritten.

    Each account seen in the input gets its refresh_timestamps row upserted.
    All changes are committed together at the end. If anything raises first,
    the connection is closed without committing, so the import is discarded.
    """
    import json

    sg_rules = sg_rules or []
    conn = get_db(db_path)
    try:
        cursor = conn.cursor()

        if not append:
            # Clear existing data (but keep refresh_timestamps)
            print("Clearing existing data...")
            cursor.execute("DELETE FROM security_groups")
            cursor.execute("DELETE FROM ec2_instances")
            cursor.execute("DELETE FROM sg_rules")

        # Import security groups
        print(f"Importing {len(security_groups)} security groups...")
        cursor.executemany("""
            INSERT OR REPLACE INTO security_groups
            (account_id, account_name, group_id, group_name, tag_name, tag_wave, tag_git_repo,
             tag_git_org, tag_git_file, tags_json, ingress_rule_count, egress_rule_count)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, [
            (
                sg['account_id'], sg['account_name'], sg['group_id'], sg['group_name'],
                sg['tag_name'], sg['tag_wave'], sg['tag_git_repo'],
                sg.get('tag_git_org', ''), sg.get('tag_git_file', ''),
                json.dumps(sg.get('tags_json', {})),
                sg['ingress_rule_count'], sg.get('egress_rule_count', 0),
            )
            for sg in security_groups
        ])

        # Import EC2 instances
        print(f"Importing {len(ec2_instances)} EC2 instances...")
        cursor.executemany("""
            INSERT OR REPLACE INTO ec2_instances
            (account_id, account_name, tag_name, instance_id, state, private_ip_address,
             security_groups_id_list, security_groups_name_list, tag_git_repo,
             tag_git_org, tag_git_file, tags_json)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, [
            (
                instance['account_id'], instance['account_name'], instance['tag_name'],
                instance['instance_id'], instance['state'], instance['private_ip_address'],
                instance['security_groups_id_list'], instance['security_groups_name_list'],
                instance['tag_git_repo'], instance.get('tag_git_org', ''),
                instance.get('tag_git_file', ''), json.dumps(instance.get('tags_json', {})),
            )
            for instance in ec2_instances
        ])

        if append:
            # Old rule rows are not upserted, so drop them for every refreshed
            # group before inserting. Include the groups themselves, not only
            # those with rules: a group whose rules were all removed in AWS
            # must lose its stale rows too.
            refreshed_ids = {sg['group_id'] for sg in security_groups}
            refreshed_ids.update(rule['group_id'] for rule in sg_rules)
            cursor.executemany("DELETE FROM sg_rules WHERE group_id = ?",
                               [(group_id,) for group_id in refreshed_ids])

        # Import security group rules
        if sg_rules:
            print(f"Importing {len(sg_rules)} security group rules...")
            cursor.executemany("""
                INSERT INTO sg_rules
                (group_id, direction, protocol, port_range, source_type, source, description)
                VALUES (?, ?, ?, ?, ?, ?, ?)
            """, [
                (
                    rule['group_id'], rule['direction'], rule['protocol'], rule['port_range'],
                    rule['source_type'], rule['source'], rule['description'],
                )
                for rule in sg_rules
            ])

        # Update refresh timestamps for all accounts
        print("Updating refresh timestamps...")
        accounts = {(sg['account_id'], sg['account_name']) for sg in security_groups}
        accounts.update((inst['account_id'], inst['account_name']) for inst in ec2_instances)
        cursor.executemany("""
            INSERT INTO refresh_timestamps (account_id, account_name, last_refresh)
            VALUES (?, ?, CURRENT_TIMESTAMP)
            ON CONFLICT(account_id) DO UPDATE SET
                last_refresh = CURRENT_TIMESTAMP,
                account_name = excluded.account_name
        """, list(accounts))

        conn.commit()
    finally:
        # Always release the connection; uncommitted work is discarded on close.
        conn.close()
    print("✓ Import complete")
def _parse_profile_selection(selection, profiles):
    """Map 'all' or comma-separated 1-based menu numbers to unique profile names; None if invalid."""
    if selection.lower() == 'all':
        return list(profiles)
    chosen = []
    for token in selection.split(','):
        try:
            number = int(token.strip())
        except ValueError:
            return None
        # Reject 0 and negatives explicitly: as list indices they would
        # silently pick profiles from the end of the menu.
        if not 1 <= number <= len(profiles):
            return None
        name = profiles[number - 1]
        # Duplicates would import the same account twice and duplicate sg_rules rows.
        if name not in chosen:
            chosen.append(name)
    return chosen


def main():
    """Interactively pick AWS profiles, fetch their EC2/SG data, and store it in SQLite."""
    # Database path
    db_path = os.path.join(os.path.dirname(__file__), 'data', 'aws_export.db')
    os.makedirs(os.path.dirname(db_path), exist_ok=True)

    print("=" * 60)
    print("AWS Direct Import Tool")
    print("=" * 60)

    # Get available profiles
    profiles = get_aws_profiles()
    if not profiles:
        print("No AWS profiles found in ~/.aws/config")
        sys.exit(1)

    print("\nAvailable AWS profiles:")
    for i, profile in enumerate(profiles, 1):
        print(f" {i}. {profile}")

    # Let user select profile(s)
    print("\nEnter profile number(s) to import (comma-separated, or 'all'):")
    selection = input("Selection: ").strip()
    selected_profiles = _parse_profile_selection(selection, profiles)
    if selected_profiles is None:
        print("Invalid selection")
        sys.exit(1)

    # Ask if should append or replace (only offered for multi-profile imports)
    append_mode = False
    if len(selected_profiles) > 1:
        append_choice = input("\nAppend to existing data? (y/N): ").strip().lower()
        append_mode = append_choice == 'y'

    # Process each profile
    all_security_groups = []
    all_ec2_instances = []
    all_sg_rules = []
    imported_profiles = 0

    for i, profile in enumerate(selected_profiles):
        print(f"\n{'=' * 60}")
        print(f"Processing profile {i+1}/{len(selected_profiles)}: {profile}")
        print('=' * 60)

        # Authenticate
        session = get_session_with_mfa(profile)
        if not session:
            print(f"✗ Skipping profile {profile} due to authentication failure")
            continue

        try:
            # Get account info
            account_id, account_name = get_account_info(session)
            print(f"Account: {account_name} ({account_id})")

            # Fetch data
            security_groups, sg_rules = fetch_security_groups(session, account_id, account_name)
            ec2_instances = fetch_ec2_instances(session, account_id, account_name)
        except Exception as e:
            # One unreachable/denied account must not discard data already
            # fetched from the other profiles.
            print(f"✗ Skipping profile {profile}: {e}")
            continue

        all_security_groups.extend(security_groups)
        all_ec2_instances.extend(ec2_instances)
        all_sg_rules.extend(sg_rules)
        imported_profiles += 1

    # Import to database
    if all_security_groups or all_ec2_instances:
        print(f"\n{'=' * 60}")
        print("Importing to database...")
        print('=' * 60)
        import_to_database(db_path, all_security_groups, all_ec2_instances, all_sg_rules,
                           append=append_mode)
        # Report only the profiles that actually contributed data.
        print(f"\n✓ Successfully imported data from {imported_profiles} profile(s)")
        print(f" Database: {db_path}")
        print(f" Total Security Groups: {len(all_security_groups)}")
        print(f" Total EC2 Instances: {len(all_ec2_instances)}")
        print(f" Total SG Rules: {len(all_sg_rules)}")
    else:
        print("\n✗ No data imported")
# Start the interactive importer only when this file is run as a script.
if __name__ == "__main__":
    main()