"""
Real-time Training Monitor
Provides live dashboard for monitoring AI-AntiBot training
"""
import hashlib
import json
import logging
import time
from collections import deque
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

from ..models.detector import BotDetector
from ..training.trainer import TurnstileTrainer

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

class TrainingMonitor:
    """Real-time monitoring and dashboard for AI-AntiBot training.

    Keeps a bounded in-memory feed of recent events plus per-hour counters,
    and combines them with statistics from a trainer and a detector into
    dashboard-ready dictionaries.
    """

    # Hour-bucket key format. Every field is zero-padded, so lexicographic
    # order of keys equals chronological order (relied on when pruning).
    _HOUR_KEY_FORMAT = '%Y-%m-%d %H'

    def __init__(self, trainer: Optional[Any] = None, detector: Optional[Any] = None,
                 training_threshold: int = 20, hourly_retention_hours: int = 48):
        """Create a monitor.

        Args:
            trainer: Object providing ``get_training_stats()`` and
                ``record_interaction(...)``. Defaults to a new ``TurnstileTrainer``.
            detector: Object providing ``get_model_stats()`` and
                ``get_feature_importance()``. Defaults to a new ``BotDetector``.
            training_threshold: Sample count reported as the training threshold
                by ``_get_training_progress``.
            hourly_retention_hours: Hours of per-hour buckets kept in
                ``hourly_stats``; older buckets are discarded so the dict stays
                bounded. Clamped to at least 24 because the dashboard reports
                the last 24 hours.
        """
        self.trainer = trainer if trainer is not None else TurnstileTrainer()
        self.detector = detector if detector is not None else BotDetector()
        self.training_threshold = training_threshold
        self.hourly_retention_hours = max(24, hourly_retention_hours)
        self.real_time_feed = deque(maxlen=1000)  # Keep last 1000 events
        self.performance_metrics = deque(maxlen=100)  # Keep last 100 performance points
        self.hourly_stats = {}
        # Last event ID handed out. IDs must be strictly increasing, otherwise
        # get_live_feed(since_id) would skip events sharing a millisecond.
        self._last_event_id = 0

    def add_real_time_event(self, event_type: str, data: Dict[str, Any]):
        """Append an event to the real-time feed and update hourly counters.

        Args:
            event_type: Event category. ``'turnstile_interaction'`` and
                ``'model_training'`` also increment the current hour's counters.
            data: Event payload, stored as-is. For Turnstile interactions its
                ``'turnstile_passed'`` value decides human vs. bot counting.
        """
        # Read the clock once so the event timestamp and its hour bucket agree
        # even if this call straddles an hour boundary.
        now = datetime.now()

        # Millisecond-timestamp-style ID, bumped past the previous one so that
        # events created within the same millisecond still get unique,
        # strictly increasing IDs.
        event_id = max(int(time.time() * 1000), self._last_event_id + 1)
        self._last_event_id = event_id

        self.real_time_feed.append({
            'timestamp': now.isoformat(),
            'event_type': event_type,
            'data': data,
            'id': event_id  # Unique ID for frontend
        })

        # Update hourly stats (a bucket is created for every event type).
        bucket = self._get_hour_bucket(now)
        if event_type == 'turnstile_interaction':
            bucket['total_interactions'] += 1
            if data.get('turnstile_passed'):
                bucket['human_count'] += 1
            else:
                bucket['bot_count'] += 1
        elif event_type == 'model_training':
            bucket['training_events'] += 1

    def _get_hour_bucket(self, now: datetime) -> Dict[str, int]:
        """Return the counter dict for ``now``'s hour, creating it if missing.

        Creating a new bucket first drops buckets older than
        ``hourly_retention_hours`` so ``hourly_stats`` cannot grow forever.
        """
        hour_key = now.strftime(self._HOUR_KEY_FORMAT)
        bucket = self.hourly_stats.get(hour_key)
        if bucket is None:
            self._prune_hourly_stats(now)
            bucket = {
                'total_interactions': 0,
                'human_count': 0,
                'bot_count': 0,
                'training_events': 0
            }
            self.hourly_stats[hour_key] = bucket
        return bucket

    def _prune_hourly_stats(self, now: datetime) -> None:
        """Delete hour buckets older than the retention window ending at ``now``."""
        cutoff = (now - timedelta(hours=self.hourly_retention_hours)).strftime(
            self._HOUR_KEY_FORMAT)
        # String comparison is chronological thanks to the zero-padded format.
        for stale_key in [key for key in self.hourly_stats if key < cutoff]:
            del self.hourly_stats[stale_key]

    def get_dashboard_data(self) -> Dict[str, Any]:
        """Return comprehensive dashboard data.

        Trainer and model stats are fetched once and shared with the helper
        sections so every section reflects the same snapshot.
        """
        training_stats = self.trainer.get_training_stats()
        model_stats = self.detector.get_model_stats()

        dashboard_data = {
            'overview': {
                'total_training_samples': training_stats.get('total_samples', 0),
                'human_samples': training_stats.get('human_samples', 0),
                'bot_samples': training_stats.get('bot_samples', 0),
                'model_accuracy': model_stats.get('accuracy', 0),
                'model_loaded': model_stats.get('model_loaded', False),
                'training_active': training_stats.get('training_active', False),
                'last_training_time': training_stats.get('last_training_time'),
                'class_balance': training_stats.get('class_balance', 0)
            },
            'real_time_feed': list(self.real_time_feed)[-50:],  # Last 50 events
            'performance_metrics': {
                'accuracy': model_stats.get('accuracy', 0),
                'precision': model_stats.get('precision', 0),
                'recall': model_stats.get('recall', 0),
                'f1_score': model_stats.get('f1_score', 0)
            },
            'hourly_stats': self._get_recent_hourly_stats(),
            'feature_importance': self.detector.get_feature_importance(),
            'training_progress': self._get_training_progress(training_stats),
            'system_status': self._get_system_status(training_stats, model_stats)
        }

        return dashboard_data

    def get_live_feed(self, since_id: int = 0) -> List[Dict[str, Any]]:
        """Return feed events whose ID is strictly greater than ``since_id``."""
        return [event for event in self.real_time_feed if event['id'] > since_id]

    def _get_recent_hourly_stats(self) -> List[Dict[str, Any]]:
        """Return one stats dict per hour for the last 24 hours, oldest first.

        Hours without recorded events are reported with zero counts.
        """
        current_time = datetime.now()
        stats = []

        for i in range(24):
            hour_time = current_time - timedelta(hours=i)
            counts = self.hourly_stats.get(hour_time.strftime(self._HOUR_KEY_FORMAT))

            if counts is not None:
                stat = counts.copy()
                stat['hour'] = hour_time.strftime('%H:00')
                stat['date'] = hour_time.strftime('%Y-%m-%d')
                stats.append(stat)
            else:
                stats.append({
                    'hour': hour_time.strftime('%H:00'),
                    'date': hour_time.strftime('%Y-%m-%d'),
                    'total_interactions': 0,
                    'human_count': 0,
                    'bot_count': 0,
                    'training_events': 0
                })

        return list(reversed(stats))

    def _get_training_progress(self, stats: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Return progress towards the training threshold.

        Args:
            stats: Pre-fetched trainer stats; fetched from the trainer if omitted.
        """
        if stats is None:
            stats = self.trainer.get_training_stats()

        current_samples = stats.get('total_samples', 0)
        threshold = self.training_threshold

        # A non-positive threshold is always satisfied; avoid dividing by it.
        if threshold > 0:
            progress_percentage = min(100, (current_samples / threshold) * 100)
        else:
            progress_percentage = 100

        return {
            'current_samples': current_samples,
            'threshold': threshold,
            'progress_percentage': progress_percentage,
            'samples_needed': max(0, threshold - current_samples),
            'ready_for_training': current_samples >= threshold
        }

    def _get_system_status(self, training_stats: Optional[Dict[str, Any]] = None,
                           model_stats: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Score overall system health (0-100) and list detected issues.

        Four checks contribute up to 25 points each: model loaded, more than
        100 training samples, model accuracy, and class balance.

        Args:
            training_stats: Pre-fetched trainer stats; fetched if omitted.
            model_stats: Pre-fetched detector stats; fetched if omitted.
        """
        if training_stats is None:
            training_stats = self.trainer.get_training_stats()
        if model_stats is None:
            model_stats = self.detector.get_model_stats()

        health_score = 0
        issues = []

        # Check model status
        if model_stats.get('model_loaded', False):
            health_score += 25
        else:
            issues.append("Model not loaded")

        # Check training data
        if (training_stats.get('total_samples') or 0) > 100:
            health_score += 25
        else:
            issues.append("Insufficient training data")

        # Check model accuracy (a missing/None value counts as 0 rather than
        # raising TypeError on comparison)
        accuracy = model_stats.get('accuracy') or 0
        if accuracy > 0.9:
            health_score += 25
        elif accuracy > 0.8:
            health_score += 15
        elif accuracy > 0.7:
            health_score += 10
        else:
            issues.append("Low model accuracy")

        # Check class balance
        class_balance = training_stats.get('class_balance') or 0
        if class_balance > 0.3:
            health_score += 25
        elif class_balance > 0.2:
            health_score += 15
        else:
            issues.append("Imbalanced training data")

        # Determine status
        if health_score >= 90:
            status = "excellent"
        elif health_score >= 70:
            status = "good"
        elif health_score >= 50:
            status = "fair"
        else:
            status = "poor"

        return {
            'health_score': health_score,
            'status': status,
            'issues': issues,
            'last_updated': datetime.now().isoformat()
        }

    def record_turnstile_interaction(self, request, client_ip: str, turnstile_passed: bool,
                                     session_data: Dict = None, extra_context: Dict = None):
        """Record a Turnstile interaction with the trainer and in the live feed.

        Returns:
            Whatever ``trainer.record_interaction`` returned, unchanged.
        """
        result = self.trainer.record_interaction(request, client_ip, turnstile_passed,
                                                 session_data, extra_context)

        # Tolerate a missing/None result or log entry when building the feed item.
        result_view = result or {}
        log_entry = result_view.get('log_entry') or {}

        self.add_real_time_event('turnstile_interaction', {
            'ip_hash': log_entry.get('ip_hash', 'unknown'),
            'turnstile_passed': turnstile_passed,
            'label': 'human' if turnstile_passed else 'bot',
            'total_samples': result_view.get('total_samples', 0),
            'session_duration': session_data.get('session_duration', 0) if session_data else 0,
            'user_agent_length': len(request.headers.get('User-Agent', '')),
            'extra_context': extra_context or {}
        })

        return result

    def record_ai_prediction(self, request, client_ip: str, prediction: bool,
                             confidence: float, details: Dict = None):
        """Record an AI prediction (when AI is in autonomous mode).

        ``prediction`` is truthy when the model judged the client a bot.
        """
        # Show only a digest prefix: a truncated raw IP would expose most of an
        # IPv4 address. NOTE: an unsalted digest of an IPv4 address can still
        # be brute-forced; this is pseudonymisation, not anonymisation.
        ip_digest = hashlib.sha256(client_ip.encode('utf-8')).hexdigest()[:8]
        self.add_real_time_event('ai_prediction', {
            'ip_hash': ip_digest + '...',
            'prediction': prediction,
            'confidence': confidence,
            'label': 'bot' if prediction else 'human',
            'details': details or {}
        })

    def record_model_training(self, training_result: Dict[str, Any]):
        """Record a model training event in the feed and hourly counters."""
        self.add_real_time_event('model_training', training_result)

    def get_training_visualization_data(self) -> Dict[str, Any]:
        """Return data for the training visualization charts."""
        stats = self.trainer.get_training_stats()
        model_stats = self.detector.get_model_stats()

        # Sample distribution
        sample_distribution = {
            'human': stats.get('human_samples', 0),
            'bot': stats.get('bot_samples', 0)
        }

        # NOTE(review): placeholder values, not derived from real training
        # runs; should be populated from actual training logs.
        training_timeline = [
            {'timestamp': '2024-01-01 10:00', 'accuracy': 0.65, 'samples': 100},
            {'timestamp': '2024-01-01 11:00', 'accuracy': 0.72, 'samples': 250},
            {'timestamp': '2024-01-01 12:00', 'accuracy': 0.78, 'samples': 400},
            {'timestamp': '2024-01-01 13:00', 'accuracy': 0.85, 'samples': 650},
            {'timestamp': '2024-01-01 14:00', 'accuracy': 0.89, 'samples': 850}
        ]

        # Feature importance: the 10 highest-scoring features, highest first,
        # regardless of the order the detector returns them in.
        feature_importance = self.detector.get_feature_importance() or {}
        ranked = sorted(feature_importance.items(), key=lambda item: item[1], reverse=True)
        top_features = dict(ranked[:10])

        return {
            'sample_distribution': sample_distribution,
            'training_timeline': training_timeline,
            'feature_importance': top_features,
            # Accuracy is a model metric (as in get_dashboard_data); fall back
            # to the trainer stats only if the detector does not report it.
            'current_accuracy': model_stats.get('accuracy', stats.get('accuracy', 0)),
            'total_samples': stats.get('total_samples', 0)
        }