877 lines
32 KiB
Python
Executable file
877 lines
32 KiB
Python
Executable file
#!/usr/bin/env python3
|
||
"""
|
||
Flask web application for exploring AWS EC2 and Security Group exports
|
||
"""
|
||
|
||
from flask import Flask, render_template, request, jsonify, Response, stream_with_context
|
||
import sqlite3
|
||
import os
|
||
import re
|
||
import atexit
|
||
import signal
|
||
import sys
|
||
import boto3
|
||
import configparser
|
||
from pathlib import Path
|
||
import json
|
||
import time
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
import threading
|
||
import queue
|
||
|
||
app = Flask(__name__)

# SQLite database file; kept under ./data next to this module so it
# persists between runs of the application.
DB_PATH = os.path.join(os.path.dirname(__file__), 'data', 'aws_export.db')

# Process-wide flag flipped to True after a successful import; the index
# route uses it to decide between the import page and the explorer.
data_imported = False

# Cache for AWS session credentials (valid for 1 hour)
# {profile: {'session': boto3.Session, 'region': ..., 'timestamp': ...,
#            'account_id': ..., 'account_name': ...}}
session_cache = {}
||
def regexp(pattern, value):
    """Case-insensitive REGEXP implementation registered with SQLite.

    Returns True when *pattern* matches anywhere in *value*. A NULL value
    or an invalid pattern yields False instead of raising, so a bad user
    pattern never aborts a query.
    """
    if value is None:
        return False
    try:
        match = re.search(pattern, value, re.IGNORECASE)
    except re.error:
        return False
    return match is not None
|
||
|
||
|
||
def get_db():
    """Open a connection to the export database, creating schema on demand.

    Ensures the data directory exists, registers the custom REGEXP helper,
    creates any missing tables, and returns a connection whose rows behave
    like dicts (sqlite3.Row).
    """
    # The data directory may not exist on first run.
    os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)

    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    conn.create_function("REGEXP", 2, regexp)

    # Idempotent DDL: every statement is CREATE TABLE IF NOT EXISTS.
    ddl_statements = (
        """
        CREATE TABLE IF NOT EXISTS security_groups (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            account_id TEXT,
            account_name TEXT,
            group_id TEXT UNIQUE,
            group_name TEXT,
            tag_name TEXT,
            tag_wave TEXT,
            tag_git_repo TEXT,
            tag_git_org TEXT,
            tag_git_file TEXT,
            tags_json TEXT,
            ingress_rule_count INTEGER,
            egress_rule_count INTEGER
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS ec2_instances (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            account_id TEXT,
            account_name TEXT,
            tag_name TEXT,
            instance_id TEXT UNIQUE,
            state TEXT,
            private_ip_address TEXT,
            security_groups_id_list TEXT,
            security_groups_name_list TEXT,
            tag_git_repo TEXT,
            tag_git_org TEXT,
            tag_git_file TEXT,
            tags_json TEXT
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS sg_rules (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            group_id TEXT,
            direction TEXT,
            protocol TEXT,
            port_range TEXT,
            source_type TEXT,
            source TEXT,
            description TEXT
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS refresh_timestamps (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            account_id TEXT,
            account_name TEXT,
            last_refresh TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE(account_id)
        )
        """,
    )

    cursor = conn.cursor()
    for ddl in ddl_statements:
        cursor.execute(ddl)
    conn.commit()

    return conn
|
||
|
||
|
||
@app.route('/')
def index():
    """Import page - always shown first.

    Renders the explorer instead when a previous import already populated
    the database in this process.
    """
    global data_imported
    if not (data_imported and os.path.exists(DB_PATH)):
        return render_template('import.html')
    return render_template('index.html')
|
||
|
||
|
||
@app.route('/explorer')
def explorer():
    """Main explorer interface.

    Rendered unconditionally; the page itself shows an empty state when no
    data has been imported yet.
    """
    return render_template('index.html')
|
||
|
||
|
||
@app.route('/api/profiles')
def get_profiles():
    """Get the list of AWS profiles declared in ~/.aws/config.

    Returns JSON {'profiles': [{'name': ..., 'has_mfa': bool}, ...]} with
    'default' first and the rest sorted case-insensitively; 404 when the
    config file is missing, 500 on any other failure.
    """
    try:
        config_path = Path.home() / '.aws' / 'config'

        if not config_path.exists():
            return jsonify({'error': f'AWS config file not found at {config_path}'}), 404

        config = configparser.ConfigParser()
        config.read(config_path)

        profiles = []
        for section in config.sections():
            # AWS config names sections "profile <name>", except 'default'.
            if section == 'default':
                profile_name = 'default'
            elif section.startswith('profile '):
                profile_name = section.replace('profile ', '')
            else:
                continue

            # A profile with mfa_serial will need a token code at import time.
            profiles.append({
                'name': profile_name,
                'has_mfa': config.has_option(section, 'mfa_serial'),
            })

        # 'default' sorts first; everything else alphabetically, ignoring case.
        profiles.sort(key=lambda p: (p['name'] != 'default', p['name'].lower()))

        return jsonify({'profiles': profiles})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
||
|
||
|
||
def send_progress(message, status='info'):
    """Format one Server-Sent Events frame carrying a progress message.

    The client parses the JSON payload to pick message text and severity
    ('info', 'success', 'error', 'redirect', 'complete').
    """
    payload = json.dumps({'message': message, 'status': status})
    return "data: " + payload + "\n\n"
|
||
|
||
|
||
def get_account_info_inline(session):
    """Return (account_id, account_name) for a boto3 session (inline version).

    The account id comes from STS get_caller_identity. The account name is
    the first IAM account alias when one exists; otherwise it falls back to
    the raw account id.

    Args:
        session: boto3.Session (or any object whose .client('sts'/'iam')
            returns clients with the matching methods).

    Returns:
        Tuple of (account_id, account_name), both strings.
    """
    sts = session.client('sts')
    identity = sts.get_caller_identity()
    account_id = identity['Account']

    try:
        iam = session.client('iam')
        aliases = iam.list_account_aliases()
        account_name = aliases['AccountAliases'][0] if aliases['AccountAliases'] else account_id
    except Exception:
        # The caller may lack iam:ListAccountAliases permission; fall back to
        # the account id. (Was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.)
        account_name = account_id

    return account_id, account_name
|
||
|
||
|
||
def import_profile(profile, mfa_code, progress_queue):
    """Import data from a single AWS profile (runs in thread).

    Resolves credentials for *profile* from ~/.aws/config and
    ~/.aws/credentials, performs MFA authentication and/or role assumption
    as configured, verifies the resulting session, fetches security groups
    and EC2 instances, and caches the boto3 session in the module-level
    session_cache for later refreshes.

    Args:
        profile: AWS profile name as it appears in ~/.aws/config.
        mfa_code: One-time MFA token; may be '' when the profile has no MFA.
        progress_queue: queue.Queue receiving (status, message) tuples
            ('info' / 'success' / 'error') that the HTTP layer relays as
            Server-Sent Events.

    Returns:
        dict with keys 'profile', 'security_groups', 'ec2_instances' and
        'sg_rules' on success, or None on any failure. Errors are reported
        through progress_queue, never raised to the caller.
    """
    try:
        # Imported lazily so this module can load even if import_from_aws
        # is unavailable at startup.
        from import_from_aws import fetch_security_groups, fetch_ec2_instances

        progress_queue.put(('info', f"[{profile}] Starting authentication..."))

        # Read AWS config to get MFA serial, region and role settings.
        config_path = Path.home() / '.aws' / 'config'
        config = configparser.ConfigParser()
        config.read(config_path)

        # AWS config uses "profile <name>" section headers, except 'default'.
        section_name = f'profile {profile}' if profile != 'default' else 'default'
        mfa_serial = None
        region = None
        source_profile = None
        role_arn = None

        if section_name in config:
            mfa_serial = config[section_name].get('mfa_serial')
            region = config[section_name].get('region', 'us-east-1')
            source_profile = config[section_name].get('source_profile')
            role_arn = config[section_name].get('role_arn')

        # Debug output
        progress_queue.put(('info', f"[{profile}] Config: region={region}, mfa_serial={bool(mfa_serial)}, source_profile={source_profile}, role_arn={role_arn}"))

        # Read base (long-lived) credentials from ~/.aws/credentials.
        creds_path = Path.home() / '.aws' / 'credentials'
        creds_config = configparser.ConfigParser()
        creds_config.read(creds_path)

        # Determine which credentials section to use.
        # Priority: source_profile > profile name > default
        if source_profile and source_profile in creds_config:
            cred_section = source_profile
        elif profile in creds_config:
            cred_section = profile
        elif 'default' in creds_config:
            cred_section = 'default'
        else:
            progress_queue.put(('error', f"✗ [{profile}] Credentials not found in ~/.aws/credentials"))
            return None

        # NOTE(review): this re-check is unreachable — every branch above
        # either returned or selected a section present in creds_config.
        if cred_section not in creds_config:
            progress_queue.put(('error', f"✗ [{profile}] Credentials not found in ~/.aws/credentials"))
            return None

        base_access_key = creds_config[cred_section].get('aws_access_key_id')
        base_secret_key = creds_config[cred_section].get('aws_secret_access_key')

        if not base_access_key or not base_secret_key:
            progress_queue.put(('error', f"✗ [{profile}] Invalid credentials in ~/.aws/credentials"))
            return None

        # If MFA is configured and we have a code, use it.
        if mfa_serial and mfa_code:
            progress_queue.put(('info', f"[{profile}] Using MFA authentication..."))

            # Create STS client with base credentials (no session).
            sts = boto3.client(
                'sts',
                aws_access_key_id=base_access_key,
                aws_secret_access_key=base_secret_key,
                region_name=region or 'us-east-1'
            )

            try:
                # Get temporary credentials with MFA (one hour lifetime,
                # matching the session_cache validity window).
                response = sts.get_session_token(
                    DurationSeconds=3600,
                    SerialNumber=mfa_serial,
                    TokenCode=mfa_code
                )

                credentials = response['Credentials']
                progress_queue.put(('success', f"✓ [{profile}] MFA authentication successful"))

                # If there's a role to assume, assume it on top of the MFA session.
                if role_arn:
                    progress_queue.put(('info', f"[{profile}] Assuming role {role_arn}..."))

                    # Create STS client with MFA session credentials.
                    sts_with_mfa = boto3.client(
                        'sts',
                        aws_access_key_id=credentials['AccessKeyId'],
                        aws_secret_access_key=credentials['SecretAccessKey'],
                        aws_session_token=credentials['SessionToken'],
                        region_name=region or 'us-east-1'
                    )

                    try:
                        # Assume the role
                        role_response = sts_with_mfa.assume_role(
                            RoleArn=role_arn,
                            RoleSessionName=f"{profile}-session"
                        )

                        role_credentials = role_response['Credentials']
                        session = boto3.Session(
                            aws_access_key_id=role_credentials['AccessKeyId'],
                            aws_secret_access_key=role_credentials['SecretAccessKey'],
                            aws_session_token=role_credentials['SessionToken'],
                            region_name=region or 'us-east-1'
                        )
                        progress_queue.put(('success', f"✓ [{profile}] Role assumption successful"))
                    except Exception as role_error:
                        progress_queue.put(('error', f"✗ [{profile}] Role assumption failed - {str(role_error)}"))
                        return None
                else:
                    # No role to assume, use MFA session directly.
                    session = boto3.Session(
                        aws_access_key_id=credentials['AccessKeyId'],
                        aws_secret_access_key=credentials['SecretAccessKey'],
                        aws_session_token=credentials['SessionToken'],
                        region_name=region or 'us-east-1'
                    )

            except Exception as mfa_error:
                progress_queue.put(('error', f"✗ [{profile}] MFA authentication failed - {str(mfa_error)}"))
                return None
        else:
            # No MFA configured or no code provided.
            if mfa_serial and not mfa_code:
                progress_queue.put(('error', f"✗ [{profile}] MFA code required but not provided"))
                return None

            progress_queue.put(('info', f"[{profile}] Using direct authentication (no MFA)..."))

            # If there's a role to assume (without MFA).
            if role_arn:
                progress_queue.put(('info', f"[{profile}] Assuming role {role_arn}..."))

                sts = boto3.client(
                    'sts',
                    aws_access_key_id=base_access_key,
                    aws_secret_access_key=base_secret_key,
                    region_name=region or 'us-east-1'
                )

                try:
                    role_response = sts.assume_role(
                        RoleArn=role_arn,
                        RoleSessionName=f"{profile}-session"
                    )

                    role_credentials = role_response['Credentials']
                    session = boto3.Session(
                        aws_access_key_id=role_credentials['AccessKeyId'],
                        aws_secret_access_key=role_credentials['SecretAccessKey'],
                        aws_session_token=role_credentials['SessionToken'],
                        region_name=region or 'us-east-1'
                    )
                    progress_queue.put(('success', f"✓ [{profile}] Role assumption successful"))
                except Exception as role_error:
                    progress_queue.put(('error', f"✗ [{profile}] Role assumption failed - {str(role_error)}"))
                    return None
            else:
                # No role, use base credentials directly.
                session = boto3.Session(
                    aws_access_key_id=base_access_key,
                    aws_secret_access_key=base_secret_key,
                    region_name=region or 'us-east-1'
                )

        # Verify the session works before fetching anything.
        try:
            sts = session.client('sts')
            sts.get_caller_identity()
            progress_queue.put(('success', f"✓ [{profile}] Authentication successful"))
        except Exception as e:
            progress_queue.put(('error', f"✗ [{profile}] Authentication failed - {str(e)}"))
            return None

        # Get account info
        account_id, account_name = get_account_info_inline(session)
        progress_queue.put(('info', f"  [{profile}] Account: {account_name} ({account_id})"))

        # Cache the session credentials for reuse (valid for 1 hour).
        # Item assignment mutates the shared dict, so `global` is not strictly
        # required here; kept for explicitness.
        global session_cache
        session_cache[profile] = {
            'session': session,
            'region': region,
            'timestamp': time.time(),
            'account_id': account_id,
            'account_name': account_name
        }

        # Fetch data
        progress_queue.put(('info', f"  [{profile}] Fetching security groups..."))
        security_groups, sg_rules = fetch_security_groups(session, account_id, account_name)
        progress_queue.put(('success', f"  ✓ [{profile}] Found {len(security_groups)} security groups with {len(sg_rules)} rules"))

        progress_queue.put(('info', f"  [{profile}] Fetching EC2 instances..."))
        ec2_instances = fetch_ec2_instances(session, account_id, account_name)
        progress_queue.put(('success', f"  ✓ [{profile}] Found {len(ec2_instances)} EC2 instances"))

        return {
            'profile': profile,
            'security_groups': security_groups,
            'ec2_instances': ec2_instances,
            'sg_rules': sg_rules
        }

    except Exception as e:
        # Catch-all so a worker thread never raises into the executor; the
        # caller treats None as failure.
        progress_queue.put(('error', f"✗ [{profile}] Error - {str(e)}"))
        return None
|
||
|
||
|
||
@app.route('/api/import', methods=['POST'])
def import_data():
    """Import data from AWS with parallel execution and streaming progress.

    Expects JSON {'profiles': [...], 'mfa_codes': {profile: code}} and
    streams Server-Sent Events back to the client while worker threads
    authenticate each profile and fetch its data. The database is written
    in replace mode (append=False).
    """
    data = request.json
    selected_profiles = data.get('profiles', [])
    mfa_codes = data.get('mfa_codes', {})

    def generate():
        # BUG FIX: `global` must be declared in the function that performs
        # the assignment. The original declared it only in the enclosing
        # function, so `data_imported = True` below created a local in
        # generate() and the module-level flag was never set.
        global data_imported
        try:
            from import_from_aws import import_to_database

            yield send_progress(f"Starting parallel import from {len(selected_profiles)} profile(s)...", 'info')

            # Worker threads push (status, message) tuples here; we relay
            # them to the browser as SSE frames.
            progress_queue = queue.Queue()

            # max(1, ...) keeps ThreadPoolExecutor from raising ValueError
            # when no profiles were selected.
            with ThreadPoolExecutor(max_workers=max(1, len(selected_profiles))) as executor:
                # Submit all import tasks
                futures = {}
                for profile in selected_profiles:
                    mfa_code = mfa_codes.get(profile, '')
                    future = executor.submit(import_profile, profile, mfa_code, progress_queue)
                    futures[future] = profile

                all_security_groups = []
                all_ec2_instances = []
                all_sg_rules = []
                completed = 0

                while completed < len(selected_profiles):
                    # Drain progress messages produced since the last poll.
                    while not progress_queue.empty():
                        status, message = progress_queue.get()
                        yield send_progress(message, status)

                    # BUG FIX: the original iterated
                    # as_completed(futures, timeout=0.1), which raises
                    # TimeoutError while workers are still running and
                    # aborted the whole import via the outer except.
                    # Polling .done() is exception-free.
                    for future in [f for f in futures if f.done()]:
                        # import_profile catches its own exceptions and
                        # returns None on failure, so .result() won't raise.
                        result = future.result()
                        completed += 1

                        if result:
                            all_security_groups.extend(result['security_groups'])
                            all_ec2_instances.extend(result['ec2_instances'])
                            all_sg_rules.extend(result['sg_rules'])

                        del futures[future]

                    time.sleep(0.1)  # Small delay to prevent busy waiting

            # Drain any remaining progress messages
            while not progress_queue.empty():
                status, message = progress_queue.get()
                yield send_progress(message, status)

            # Import to database (replace mode: full snapshot)
            if all_security_groups or all_ec2_instances:
                yield send_progress("Importing to database...", 'info')
                import_to_database(DB_PATH, all_security_groups, all_ec2_instances, all_sg_rules, append=False)

                yield send_progress(f"✓ Import complete!", 'success')
                yield send_progress(f"  Total Security Groups: {len(all_security_groups)}", 'success')
                yield send_progress(f"  Total EC2 Instances: {len(all_ec2_instances)}", 'success')
                yield send_progress(f"  Total SG Rules: {len(all_sg_rules)}", 'success')

                data_imported = True

                yield send_progress("Redirecting to explorer...", 'complete')
            else:
                yield send_progress("✗ No data imported", 'error')

        except Exception as e:
            yield send_progress(f"✗ Import failed: {str(e)}", 'error')

    return Response(stream_with_context(generate()), mimetype='text/event-stream')
|
||
|
||
|
||
@app.route('/api/import-profile', methods=['POST'])
def import_single_profile():
    """Import data from a single AWS profile with streaming progress.

    Expects JSON {'profile': name, 'mfa_code': code} and streams
    Server-Sent Events. Unlike /api/import, the database write uses
    append mode so one profile can be (re)imported without wiping others.
    """
    data = request.json
    profile = data.get('profile')
    mfa_code = data.get('mfa_code', '')

    def generate():
        # BUG FIX: declared here (not only in the enclosing function) so the
        # assignment below updates the module-level flag instead of creating
        # a generator-local variable.
        global data_imported
        try:
            from import_from_aws import import_to_database

            yield send_progress(f"Starting import from {profile}...", 'info')

            # import_profile reports its progress through this queue.
            progress_queue = queue.Queue()
            result = import_profile(profile, mfa_code, progress_queue)

            # Relay everything the import pushed onto the queue.
            while not progress_queue.empty():
                status, message = progress_queue.get()
                yield send_progress(message, status)

            if result:
                yield send_progress("Importing to database...", 'info')
                import_to_database(
                    DB_PATH,
                    result['security_groups'],
                    result['ec2_instances'],
                    result['sg_rules'],
                    append=True  # Append mode for individual imports
                )

                yield send_progress(f"✓ Import complete for {profile}!", 'success')
                yield send_progress(f"  Security Groups: {len(result['security_groups'])}", 'success')
                yield send_progress(f"  EC2 Instances: {len(result['ec2_instances'])}", 'success')
                yield send_progress(f"  SG Rules: {len(result['sg_rules'])}", 'success')

                data_imported = True
                yield send_progress("Done", 'complete')
            else:
                yield send_progress(f"✗ Import failed for {profile}", 'error')

        except Exception as e:
            yield send_progress(f"✗ Import failed: {str(e)}", 'error')

    return Response(stream_with_context(generate()), mimetype='text/event-stream')
|
||
|
||
|
||
@app.route('/api/refresh-cached', methods=['POST'])
def refresh_cached():
    """Refresh data using cached AWS sessions (if still valid).

    Reuses boto3 sessions stored in session_cache (populated by
    import_profile) so the user does not have to re-enter MFA codes.
    Streams SSE progress; emits a 'redirect' status when all sessions
    have expired and a full re-authentication is required.
    """
    global session_cache

    if not session_cache:
        return jsonify({'error': 'No cached sessions', 'redirect': True})

    def generate():
        # BUG FIX: without this declaration, `data_imported = True` below
        # would bind a local inside generate() and the module flag would
        # never be set (the original declared global only in the outer
        # function, which does not extend to a nested one).
        global data_imported
        try:
            from import_from_aws import fetch_security_groups, fetch_ec2_instances, import_to_database

            # Check if cached sessions are still valid (< 1 hour old)
            current_time = time.time()
            valid_profiles = []

            for profile, cache_data in session_cache.items():
                age_minutes = (current_time - cache_data['timestamp']) / 60
                if age_minutes < 55:  # Use 55 minutes to be safe
                    valid_profiles.append(profile)
                else:
                    yield send_progress(f"[{profile}] Session expired ({age_minutes:.1f} min old)", 'error')

            if not valid_profiles:
                yield send_progress("All sessions expired. Please re-authenticate.", 'error')
                yield send_progress("REDIRECT", 'redirect')
                return

            yield send_progress(f"Refreshing data from {len(valid_profiles)} cached session(s)...", 'info')

            all_security_groups = []
            all_ec2_instances = []
            all_sg_rules = []

            for profile in valid_profiles:
                cache_data = session_cache[profile]
                session = cache_data['session']
                account_id = cache_data['account_id']
                account_name = cache_data['account_name']

                try:
                    yield send_progress(f"[{profile}] Fetching security groups...", 'info')
                    security_groups, sg_rules = fetch_security_groups(session, account_id, account_name)
                    yield send_progress(f"✓ [{profile}] Found {len(security_groups)} security groups", 'success')

                    yield send_progress(f"[{profile}] Fetching EC2 instances...", 'info')
                    ec2_instances = fetch_ec2_instances(session, account_id, account_name)
                    yield send_progress(f"✓ [{profile}] Found {len(ec2_instances)} EC2 instances", 'success')

                    all_security_groups.extend(security_groups)
                    all_ec2_instances.extend(ec2_instances)
                    all_sg_rules.extend(sg_rules)

                except Exception as e:
                    error_msg = str(e)
                    # Expired/invalid tokens mean the whole cached-session
                    # path is dead: tell the client to re-authenticate.
                    if 'ExpiredToken' in error_msg or 'InvalidToken' in error_msg:
                        yield send_progress(f"✗ [{profile}] Session expired", 'error')
                        yield send_progress("REDIRECT", 'redirect')
                        return
                    else:
                        yield send_progress(f"✗ [{profile}] Error: {error_msg}", 'error')

            # Import to database (replace mode: full snapshot refresh)
            if all_security_groups or all_ec2_instances:
                yield send_progress("Updating database...", 'info')
                import_to_database(DB_PATH, all_security_groups, all_ec2_instances, all_sg_rules, append=False)

                yield send_progress(f"✓ Refresh complete!", 'success')
                yield send_progress(f"  Total Security Groups: {len(all_security_groups)}", 'success')
                yield send_progress(f"  Total EC2 Instances: {len(all_ec2_instances)}", 'success')

                data_imported = True
                yield send_progress("COMPLETE", 'complete')
            else:
                yield send_progress("✗ No data refreshed", 'error')

        except Exception as e:
            yield send_progress(f"✗ Refresh failed: {str(e)}", 'error')

    return Response(stream_with_context(generate()), mimetype='text/event-stream')
|
||
|
||
|
||
@app.route('/api/refresh', methods=['POST'])
def refresh_data():
    """Refresh data from AWS - reuses existing MFA session if valid.

    Thin alias for the /api/import handler: delegates straight to
    import_data(), which reads profiles/MFA codes from the request body
    and streams progress back as Server-Sent Events.
    """
    return import_data()
|
||
|
||
|
||
@app.route('/api/tags')
def get_tags():
    """Get all available tag values for filtering.

    Returns JSON {'waves': [...], 'repos': [...]} — the distinct, non-empty
    tag_wave values (security groups only) and tag_git_repo values (from
    both tables).
    """
    conn = get_db()

    # Distinct non-empty wave tags; only security groups carry tag_wave.
    wave_rows = conn.execute("""
        SELECT DISTINCT tag_wave FROM security_groups
        WHERE tag_wave IS NOT NULL AND tag_wave != ''
        ORDER BY tag_wave
    """).fetchall()

    # Distinct non-empty git repos from either table.
    repo_rows = conn.execute("""
        SELECT DISTINCT tag_git_repo FROM security_groups
        WHERE tag_git_repo IS NOT NULL AND tag_git_repo != ''
        UNION
        SELECT DISTINCT tag_git_repo FROM ec2_instances
        WHERE tag_git_repo IS NOT NULL AND tag_git_repo != ''
        ORDER BY tag_git_repo
    """).fetchall()

    conn.close()

    waves = [row['tag_wave'] for row in wave_rows]
    repos = [row['tag_git_repo'] for row in repo_rows]
    return jsonify({'waves': waves, 'repos': repos})
|
||
|
||
|
||
@app.route('/api/search')
def search():
    """Search for EC2 instances or security groups.

    Query parameters:
        q:     search text (substring match by default, regex when regex=true)
        type:  'all' | 'sg' | 'ec2'
        regex: 'true' to interpret q as a case-insensitive regex
        wave:  exact tag_wave filter (applies to security groups only)
        repo:  exact tag_git_repo filter

    Returns JSON {'results': [...]} (capped at 500 rows per entity type),
    or {'error': ..., 'results': []} on a bad pattern or query failure.
    """
    query = request.args.get('q', '').strip()
    search_type = request.args.get('type', 'all')
    use_regex = request.args.get('regex', 'false').lower() == 'true'
    filter_wave = request.args.get('wave', '').strip()
    filter_repo = request.args.get('repo', '').strip()

    # BUG FIX: validate the regex once, up front. The original validated it
    # only in the security-group branch, so an invalid pattern with type=ec2
    # silently returned empty results (the SQLite REGEXP helper swallows
    # re.error) instead of reporting the bad pattern.
    if query and use_regex:
        try:
            re.compile(query)
        except re.error as e:
            return jsonify({'error': f'Invalid regex pattern: {str(e)}', 'results': []})

    conn = get_db()
    results = []

    try:
        if search_type in ['all', 'sg']:
            # Build WHERE clause: text match plus tag filters.
            where_clauses = []
            params = []

            if query:
                if use_regex:
                    where_clauses.append("(group_id REGEXP ? OR group_name REGEXP ? OR tag_name REGEXP ?)")
                    params.extend([query, query, query])
                else:
                    where_clauses.append("(group_id LIKE ? OR group_name LIKE ? OR tag_name LIKE ?)")
                    params.extend([f'%{query}%', f'%{query}%', f'%{query}%'])

            if filter_wave:
                where_clauses.append("tag_wave = ?")
                params.append(filter_wave)

            if filter_repo:
                where_clauses.append("tag_git_repo = ?")
                params.append(filter_repo)

            where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"

            sg_results = conn.execute(f"""
                SELECT 'sg' as type, group_id as id, group_name as name, tag_name,
                       account_name, account_id, tag_wave, tag_git_repo, tag_git_org, tag_git_file,
                       ingress_rule_count
                FROM security_groups
                WHERE {where_sql}
                ORDER BY tag_name, group_name
                LIMIT 500
            """, params).fetchall()

            for row in sg_results:
                results.append(dict(row))

        if search_type in ['all', 'ec2']:
            # Build WHERE clause: text match plus tag filters (EC2 rows
            # have no tag_wave column, so only the repo filter applies).
            where_clauses = []
            params = []

            if query:
                if use_regex:
                    where_clauses.append("(instance_id REGEXP ? OR tag_name REGEXP ? OR private_ip_address REGEXP ?)")
                    params.extend([query, query, query])
                else:
                    where_clauses.append("(instance_id LIKE ? OR tag_name LIKE ? OR private_ip_address LIKE ?)")
                    params.extend([f'%{query}%', f'%{query}%', f'%{query}%'])

            if filter_repo:
                where_clauses.append("tag_git_repo = ?")
                params.append(filter_repo)

            where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"

            ec2_results = conn.execute(f"""
                SELECT 'ec2' as type, instance_id as id, tag_name as name, tag_name,
                       account_name, account_id, state, private_ip_address,
                       security_groups_id_list, security_groups_name_list, tag_git_repo,
                       tag_git_org, tag_git_file
                FROM ec2_instances
                WHERE {where_sql}
                ORDER BY tag_name
                LIMIT 500
            """, params).fetchall()

            for row in ec2_results:
                results.append(dict(row))

    except Exception as e:
        conn.close()
        return jsonify({'error': f'Search error: {str(e)}', 'results': []})

    conn.close()
    return jsonify({'results': results})
|
||
|
||
|
||
@app.route('/api/ec2/<instance_id>')
def get_ec2_details(instance_id):
    """Get detailed information about an EC2 instance and its security groups.

    Returns JSON {'ec2': {...}, 'security_groups': [...]} or a 404 when the
    instance id is unknown.
    """
    conn = get_db()

    row = conn.execute("""
        SELECT * FROM ec2_instances WHERE instance_id = ?
    """, (instance_id,)).fetchone()

    if row is None:
        conn.close()
        return jsonify({'error': 'EC2 instance not found'}), 404

    ec2_dict = dict(row)

    # security_groups_id_list is a ';'-separated list of group ids.
    id_list = ec2_dict['security_groups_id_list']
    sg_ids = id_list.split(';') if id_list else []

    security_groups = []
    for sg_id in sg_ids:
        if not sg_id:
            continue
        sg_row = conn.execute("""
            SELECT * FROM security_groups WHERE group_id = ?
        """, (sg_id,)).fetchone()
        if sg_row:
            security_groups.append(dict(sg_row))

    conn.close()

    return jsonify({
        'ec2': ec2_dict,
        'security_groups': security_groups
    })
|
||
|
||
|
||
@app.route('/api/sg/<group_id>')
def get_sg_details(group_id):
    """Get detailed information about a security group and attached EC2 instances.

    Returns JSON {'security_group': {...}, 'ec2_instances': [...]} or a 404
    when the group id is unknown.
    """
    conn = get_db()

    sg = conn.execute("""
        SELECT * FROM security_groups WHERE group_id = ?
    """, (group_id,)).fetchone()

    if not sg:
        conn.close()
        return jsonify({'error': 'Security group not found'}), 404

    sg_dict = dict(sg)

    # LIKE is only a coarse prefilter here — '%sg-12%' would also match
    # rows containing 'sg-123'.
    candidates = conn.execute("""
        SELECT * FROM ec2_instances
        WHERE security_groups_id_list LIKE ?
    """, (f'%{group_id}%',)).fetchall()

    # BUG FIX: the original returned every LIKE match, so a group id that is
    # a prefix of another id produced false positives. Do an exact membership
    # check against the ';'-separated id list.
    ec2_list = [
        dict(row) for row in candidates
        if group_id in (row['security_groups_id_list'] or '').split(';')
    ]

    conn.close()

    return jsonify({
        'security_group': sg_dict,
        'ec2_instances': ec2_list
    })
|
||
|
||
|
||
@app.route('/api/sg/<group_id>/rules')
def get_sg_rules(group_id):
    """Get all rules for a security group, split by direction.

    Returns JSON {'ingress': [...], 'egress': [...]}, each list ordered by
    protocol, port range and source.
    """
    conn = get_db()

    inbound = conn.execute("""
        SELECT * FROM sg_rules
        WHERE group_id = ? AND direction = 'ingress'
        ORDER BY protocol, port_range, source
    """, (group_id,)).fetchall()

    outbound = conn.execute("""
        SELECT * FROM sg_rules
        WHERE group_id = ? AND direction = 'egress'
        ORDER BY protocol, port_range, source
    """, (group_id,)).fetchall()

    conn.close()

    response = {
        'ingress': [dict(rule) for rule in inbound],
        'egress': [dict(rule) for rule in outbound],
    }
    return jsonify(response)
|
||
|
||
|
||
@app.route('/api/stats')
def get_stats():
    """Get database statistics.

    Returns JSON with row counts for both tables, the distinct account
    names seen in either table, and the per-account refresh timestamps
    (most recent first).
    """
    conn = get_db()

    sg_total = conn.execute("SELECT COUNT(*) as count FROM security_groups").fetchone()['count']
    ec2_total = conn.execute("SELECT COUNT(*) as count FROM ec2_instances").fetchone()['count']

    # Every account name present in either table, deduplicated by UNION.
    account_rows = conn.execute("""
        SELECT DISTINCT account_name FROM security_groups
        UNION
        SELECT DISTINCT account_name FROM ec2_instances
        ORDER BY account_name
    """).fetchall()

    # Refresh history, newest first.
    refresh_rows = conn.execute("""
        SELECT account_name, last_refresh
        FROM refresh_timestamps
        ORDER BY last_refresh DESC
    """).fetchall()

    conn.close()

    return jsonify({
        'security_groups': sg_total,
        'ec2_instances': ec2_total,
        'accounts': [row['account_name'] for row in account_rows],
        'refresh_timestamps': [
            {'account': row['account_name'], 'timestamp': row['last_refresh']}
            for row in refresh_rows
        ],
    })
|
||
|
||
|
||
if __name__ == '__main__':
    # DEBUG env var toggles Flask's debugger/reloader (true/1/yes, any case).
    debug_mode = os.environ.get('DEBUG', 'false').lower() in ('true', '1', 'yes')

    banner = "=" * 60
    print("\n" + banner)
    print("🔭 SGO: Security Groups (and Instances) Observatory")
    print(banner)
    print(f"\nℹ️ Database location: {DB_PATH}")
    print("ℹ️ Database is persistent - data will be preserved between runs")
    print("ℹ️ Access the application at: http://localhost:5000")
    print(f"ℹ️ Debug mode: {'enabled' if debug_mode else 'disabled'}")
    print("\n" + banner + "\n")

    app.run(host='0.0.0.0', port=5000, debug=debug_mode)
|