aegisai / backend / services / database_service.py
database_service.py
Raw
"""
Database Service - SQLite persistence for incidents, actions and system stats (synchronous sqlite3 calls)
"""

import sqlite3
import json
import logging
from typing import Dict, List, Optional, Any
from datetime import datetime
from pathlib import Path

from config.settings import settings

logger = logging.getLogger(__name__)


class DatabaseService:
    """SQLite-backed storage for incidents, executed actions and system stats.

    Every public method opens its own short-lived ``sqlite3`` connection and
    closes it before returning, so no connection object is shared between
    callers or threads; the instance itself only holds ``db_path``.
    Failures are logged and reported through sentinel return values
    (``-1``, ``[]``, ``None``, ``{}``, ``False`` or ``0``) instead of raising,
    except during schema creation in ``__init__``.
    """

    def __init__(self, db_path: Optional[Path] = None):
        """Bind the service to a database file and make sure the schema exists.

        Args:
            db_path: SQLite file to use; falls back to ``settings.DB_PATH``
                when omitted (or falsy).
        """
        self.db_path = db_path or settings.DB_PATH
        self._ensure_database()

    def _ensure_database(self):
        """Create the tables and indexes if they don't already exist.

        Errors propagate to the caller (the constructor).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            # Incidents table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS incidents (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp TEXT NOT NULL,
                    incident_type TEXT NOT NULL,
                    severity TEXT NOT NULL,
                    confidence REAL NOT NULL,
                    reasoning TEXT NOT NULL,
                    subjects TEXT,
                    evidence_path TEXT,
                    response_plan TEXT,
                    status TEXT DEFAULT 'active',
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Actions table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS actions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    incident_id INTEGER,
                    action_type TEXT NOT NULL,
                    action_data TEXT,
                    status TEXT DEFAULT 'pending',
                    executed_at TEXT,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (incident_id) REFERENCES incidents(id)
                )
            """)

            # System stats table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS system_stats (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    component TEXT NOT NULL,
                    metric TEXT NOT NULL,
                    value REAL NOT NULL,
                    timestamp TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Indexes for performance
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_incidents_created 
                ON incidents(created_at DESC)
            """)

            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_incidents_severity 
                ON incidents(severity)
            """)

            conn.commit()
        finally:
            # Close even when a DDL statement fails so the handle isn't leaked.
            conn.close()
        logger.info(f"Database initialized: {self.db_path}")

    def save_incident(self, incident_data: Dict[str, Any]) -> int:
        """
        Save incident to database

        Args:
            incident_data: Incident details. Keys read: ``timestamp``,
                ``type``, ``severity``, ``confidence``, ``reasoning``,
                ``subjects`` (stored JSON-encoded), ``evidence_path``,
                ``response_plan`` (stored JSON-encoded) and ``created_at``.

        Returns:
            Incident ID, or -1 if the insert failed (the error is logged).
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        try:
            # created_at: an explicit value (e.g. when back-filling) is stored
            # as given; otherwise SQLite's CURRENT_TIMESTAMP is used. That is
            # UTC, the same clock datetime('now') uses in get_statistics() and
            # cleanup_old_incidents(). A local-time default would shift those
            # windows by the UTC offset.
            cursor.execute("""
                INSERT INTO incidents (
                    timestamp, incident_type, severity, confidence,
                    reasoning, subjects, evidence_path, response_plan, created_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, COALESCE(?, CURRENT_TIMESTAMP))
            """, (
                incident_data.get('timestamp', datetime.now().isoformat()),
                incident_data.get('type', 'unknown'),
                incident_data.get('severity', 'low'),
                incident_data.get('confidence', 0),
                incident_data.get('reasoning', ''),
                json.dumps(incident_data.get('subjects', [])),
                incident_data.get('evidence_path', ''),
                json.dumps(incident_data.get('response_plan', [])),
                incident_data.get('created_at'),
            ))

            incident_id = cursor.lastrowid
            conn.commit()

            logger.info(f"Saved incident #{incident_id}")
            return incident_id

        except Exception as e:
            logger.error(f"Failed to save incident: {e}")
            conn.rollback()
            return -1
        finally:
            conn.close()

    def save_action(
        self, 
        incident_id: int, 
        action_type: str, 
        action_data: Dict[str, Any]
    ) -> int:
        """Record an executed action for an incident.

        Args:
            incident_id: Incident the action belongs to.
            action_type: Free-form action kind.
            action_data: Payload stored JSON-encoded; its ``status`` key
                (default ``'completed'``) is also copied to the status column.

        Returns:
            Action ID, or -1 if the insert failed (the error is logged).
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        try:
            cursor.execute("""
                INSERT INTO actions (
                    incident_id, action_type, action_data, 
                    status, executed_at
                ) VALUES (?, ?, ?, ?, ?)
            """, (
                incident_id,
                action_type,
                json.dumps(action_data),
                action_data.get('status', 'completed'),
                datetime.now().isoformat()
            ))

            action_id = cursor.lastrowid
            conn.commit()

            return action_id

        except Exception as e:
            logger.error(f"Failed to save action: {e}")
            conn.rollback()
            return -1
        finally:
            conn.close()

    def get_recent_incidents(
        self, 
        limit: int = 50,
        severity: Optional[str] = None
    ) -> List[Dict]:
        """Return up to ``limit`` incidents, newest ``timestamp`` first.

        Args:
            limit: Maximum number of rows.
            severity: When truthy, only incidents with exactly this severity.

        Returns:
            List of incident dicts (without ``response_plan``); an empty
            list on error (the error is logged).
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        try:
            query = """
                SELECT 
                    id, timestamp, incident_type, severity, confidence,
                    reasoning, subjects, evidence_path, status, created_at
                FROM incidents
            """

            params = []

            if severity:
                query += " WHERE severity = ?"
                params.append(severity)

            query += " ORDER BY timestamp DESC LIMIT ?"
            params.append(limit)

            cursor.execute(query, params)
            rows = cursor.fetchall()

            return [
                {
                    'id': row[0],
                    'timestamp': row[1],
                    'type': row[2],
                    'severity': row[3],
                    'confidence': row[4],
                    'reasoning': row[5],
                    'subjects': json.loads(row[6]) if row[6] else [],
                    'evidence_path': row[7],
                    'status': row[8],
                    'created_at': row[9]
                }
                for row in rows
            ]

        except Exception as e:
            logger.error(f"Failed to retrieve incidents: {e}")
            return []
        finally:
            conn.close()

    def get_incident_by_id(self, incident_id: int) -> Optional[Dict]:
        """Get a single incident, including its decoded ``response_plan``.

        Returns:
            The incident dict, or None if it doesn't exist or on error.
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        try:
            # Explicit column list: the positional row indexing below must not
            # depend on the physical column order of the table.
            cursor.execute("""
                SELECT
                    id, timestamp, incident_type, severity, confidence,
                    reasoning, subjects, evidence_path, response_plan,
                    status, created_at
                FROM incidents WHERE id = ?
            """, (incident_id,))

            row = cursor.fetchone()

            if row:
                return {
                    'id': row[0],
                    'timestamp': row[1],
                    'type': row[2],
                    'severity': row[3],
                    'confidence': row[4],
                    'reasoning': row[5],
                    'subjects': json.loads(row[6]) if row[6] else [],
                    'evidence_path': row[7],
                    'response_plan': json.loads(row[8]) if row[8] else [],
                    'status': row[9],
                    'created_at': row[10]
                }

            return None

        except Exception as e:
            logger.error(f"Failed to get incident: {e}")
            return None
        finally:
            conn.close()

    def get_statistics(self) -> Dict[str, Any]:
        """Summarize incident counts.

        Returns:
            Dict with ``total_incidents``, ``active_incidents`` (status
            ``'active'``), ``severity_breakdown`` (severity -> count) and
            ``recent_24h`` (``created_at`` within the last day, compared in
            SQLite's UTC ``'now'``). Returns ``{}`` on error.
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        try:
            # Total incidents
            cursor.execute("SELECT COUNT(*) FROM incidents")
            total_incidents = cursor.fetchone()[0]

            # Active incidents
            cursor.execute(
                "SELECT COUNT(*) FROM incidents WHERE status = 'active'"
            )
            active_incidents = cursor.fetchone()[0]

            # Severity breakdown
            cursor.execute("""
                SELECT severity, COUNT(*) 
                FROM incidents 
                GROUP BY severity
            """)
            severity_breakdown = dict(cursor.fetchall())

            # Recent incidents (last 24h)
            cursor.execute("""
                SELECT COUNT(*) 
                FROM incidents 
                WHERE datetime(created_at) > datetime('now', '-1 day')
            """)
            recent_incidents = cursor.fetchone()[0]

            return {
                'total_incidents': total_incidents,
                'active_incidents': active_incidents,
                'severity_breakdown': severity_breakdown,
                'recent_24h': recent_incidents
            }

        except Exception as e:
            logger.error(f"Failed to get statistics: {e}")
            return {}
        finally:
            conn.close()

    def update_incident_status(self, incident_id: int, status: str) -> bool:
        """Set an incident's status.

        Returns:
            True if a row was updated; False if no such incident or on error.
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        try:
            cursor.execute("""
                UPDATE incidents 
                SET status = ? 
                WHERE id = ?
            """, (status, incident_id))

            conn.commit()
            return cursor.rowcount > 0

        except Exception as e:
            logger.error(f"Failed to update incident: {e}")
            conn.rollback()
            return False
        finally:
            conn.close()

    def cleanup_old_incidents(self, days: int = 30) -> int:
        """Delete incidents whose ``created_at`` is more than ``days`` days old.

        Actions that reference those incidents are deleted in the same
        transaction. SQLite does not enforce foreign keys unless
        ``PRAGMA foreign_keys`` is enabled, so without this step they would be
        left orphaned; with enforcement enabled, the incident delete would fail.

        Returns:
            Number of incidents deleted, or 0 on failure (the error is logged).
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        try:
            cutoff_modifier = f'-{days} days'
            # Resolve the cutoff once so both deletes use the same instant.
            cursor.execute("SELECT datetime('now', ?)", (cutoff_modifier,))
            cutoff = cursor.fetchone()[0]

            cursor.execute("""
                DELETE FROM actions
                WHERE incident_id IN (
                    SELECT id FROM incidents
                    WHERE datetime(created_at) < ?
                )
            """, (cutoff,))

            cursor.execute("""
                DELETE FROM incidents 
                WHERE datetime(created_at) < ?
            """, (cutoff,))

            deleted_count = cursor.rowcount
            conn.commit()

            logger.info(f"Cleaned up {deleted_count} old incidents")
            return deleted_count

        except Exception as e:
            logger.error(f"Cleanup failed: {e}")
            conn.rollback()
            return 0
        finally:
            conn.close()


# Shared module-level instance. Constructing it resolves the path from
# settings.DB_PATH and creates the schema as an import-time side effect.
db_service = DatabaseService(db_path=None)