"""
Bot Detection Model
Main ML model for detecting bots using learned patterns from Cloudflare Turnstile
"""
import pickle
import numpy as np
import os
import logging
from datetime import datetime, timedelta
from typing import Dict, Any, Tuple, Optional
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import joblib

from ..config import AIAntiBotConfig
from ..features.extractor import FeatureExtractor

# Module-level logger shared by every BotDetector method below.
logger = logging.getLogger(__name__)

class BotDetector:
    """Main bot detection model using ensemble learning.

    Wraps a soft-voting ensemble (random forest + gradient boosting) trained on
    feature vectors produced by ``FeatureExtractor``. The fitted model, its
    ``StandardScaler`` and a dict of training metrics are persisted to the
    paths supplied by ``AIAntiBotConfig`` and reloaded on construction.
    """

    # Confidence cut-offs used when turning a bot prediction into a logged
    # decision (see _classify_decision).
    BLOCK_CONFIDENCE_THRESHOLD = 0.9
    SUSPICIOUS_CONFIDENCE_THRESHOLD = 0.7

    # Human-readable names for the feature-vector columns, in column order.
    # NOTE(review): this list is maintained by hand and must stay in sync with
    # the column order of FeatureExtractor.get_feature_vector -- confirm there.
    FEATURE_NAMES = (
        'request_method', 'has_referrer', 'content_length', 'timestamp',
        'time_since_first_request', 'time_since_last_request', 'total_requests',
        'recent_requests', 'avg_request_interval', 'session_duration',
        'page_load_time', 'interaction_delay', 'form_fill_time',
        'header_count', 'has_accept_language', 'has_accept_encoding',
        'has_connection', 'has_cache_control', 'has_dnt', 'common_headers_ratio',
        'language_count', 'has_quality_values', 'user_agent_length',
        'has_user_agent', 'ua_length', 'has_bot_keywords', 'bot_keyword_count',
        'has_legitimate_browser', 'ua_entropy', 'ua_complexity',
        'mouse_movements', 'keyboard_events', 'scroll_events', 'click_events',
        'screen_width', 'screen_height', 'color_depth', 'timezone_offset',
        'total_interactions', 'has_human_behavior', 'interaction_diversity',
        'is_private_ip', 'ip_version', 'has_x_forwarded_for', 'has_x_real_ip',
        'proxy_chain_length', 'has_proxy_chain',
    )

    def __init__(self):
        self.model = None
        self.scaler = None
        self.feature_extractor = FeatureExtractor()
        self.model_version = AIAntiBotConfig.MODEL_VERSION
        self.last_training_time = None
        self.training_stats = {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1_score': 0.0,
            'total_samples': 0,
            'last_updated': None
        }

        # Load existing model if available
        self.load_model()

    @staticmethod
    def _stats_path(model_path: str) -> str:
        """Return the path of the training-stats file stored next to *model_path*.

        Only the final extension is swapped, so the result can never equal
        *model_path* (which would let the stats dump overwrite the model), and
        '.pkl' occurring in directory names is left alone.
        """
        root, _ext = os.path.splitext(model_path)
        return root + '_stats.pkl'

    @staticmethod
    def _parse_timestamp(value) -> Optional[datetime]:
        """Parse an ISO-8601 string such as the persisted ``last_updated`` stat.

        Returns None for non-string or unparseable values instead of raising.
        """
        if not isinstance(value, str):
            return None
        try:
            return datetime.fromisoformat(value)
        except ValueError:
            return None

    def load_model(self) -> bool:
        """Load the trained model, scaler and training stats from disk.

        The model and scaler are installed on ``self`` only after both have
        loaded, so a failure leaves the previous state untouched. A missing or
        unreadable stats file is not fatal. ``last_training_time`` is restored
        from the persisted ``last_updated`` stat so ``_needs_retraining``
        survives a process restart.

        Returns:
            True if a model and scaler were loaded, False otherwise (files
            missing, or a load error, which is logged).
        """
        try:
            model_path = AIAntiBotConfig.get_model_path()
            scaler_path = AIAntiBotConfig.get_scaler_path()

            if not (os.path.exists(model_path) and os.path.exists(scaler_path)):
                logger.info("No existing model found, will train new model")
                return False

            # NOTE: joblib/pickle deserialisation can execute arbitrary code;
            # these files must only come from this service's own model directory.
            model = joblib.load(model_path)
            scaler = joblib.load(scaler_path)

            stats = None
            stats_path = self._stats_path(model_path)
            if os.path.exists(stats_path):
                try:
                    with open(stats_path, 'rb') as f:
                        stats = pickle.load(f)
                except Exception as e:
                    logger.warning(f"Failed to load training stats from {stats_path}: {e}")
                    stats = None

            self.model = model
            self.scaler = scaler
            if isinstance(stats, dict):
                self.training_stats.update(stats)
                self.last_training_time = self._parse_timestamp(stats.get('last_updated'))

            logger.info(f"Loaded AI-AntiBot model from {model_path}")
            return True

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            return False

    def save_model(self) -> bool:
        """Persist the model, scaler and training stats to disk.

        Returns:
            True on success; False when no trained model/scaler is present or
            any write fails (the error is logged).
        """
        if self.model is None or self.scaler is None:
            logger.warning("No trained model to save")
            return False

        try:
            model_path = AIAntiBotConfig.get_model_path()
            scaler_path = AIAntiBotConfig.get_scaler_path()

            # Ensure target directories exist; os.makedirs('') raises, so bare
            # filenames (no directory part) are skipped.
            for directory in {os.path.dirname(model_path), os.path.dirname(scaler_path)}:
                if directory:
                    os.makedirs(directory, exist_ok=True)

            joblib.dump(self.model, model_path)
            joblib.dump(self.scaler, scaler_path)

            # Save training stats
            with open(self._stats_path(model_path), 'wb') as f:
                pickle.dump(self.training_stats, f)

            logger.info(f"Saved AI-AntiBot model to {model_path}")
            return True

        except Exception as e:
            logger.error(f"Failed to save model: {e}")
            return False

    def train_model(self, training_data: list, labels: list) -> Dict[str, Any]:
        """Train, evaluate and persist a new bot detection model.

        Uses a stratified 80/20 train/test split, fits a fresh scaler and
        ensemble on the training part and reports weighted metrics on the test
        part. The new model and scaler replace the current ones only after
        fitting and evaluation succeed, so a failed run keeps the previous
        model serving.

        Args:
            training_data: one feature vector per sample.
            labels: one class label per sample; ``predict`` treats label 1 as
                bot and 0 as human.

        Returns:
            ``{'success': True, 'stats', 'training_samples', 'test_samples'}``
            on success, otherwise ``{'success': False, 'error': <message>}``.
        """
        try:
            if len(training_data) < AIAntiBotConfig.TRAINING_THRESHOLD:
                logger.warning(f"Insufficient training data: {len(training_data)} < {AIAntiBotConfig.TRAINING_THRESHOLD}")
                return {'success': False, 'error': 'Insufficient training data'}

            X = np.array(training_data)
            y = np.array(labels)

            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42, stratify=y
            )

            # Fit into locals so a failure below cannot clobber the live model.
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)

            model = self._create_ensemble_model()
            model.fit(X_train_scaled, y_train)

            y_pred = model.predict(X_test_scaled)

            accuracy = accuracy_score(y_test, y_pred)
            precision = precision_score(y_test, y_pred, average='weighted')
            recall = recall_score(y_test, y_pred, average='weighted')
            f1 = f1_score(y_test, y_pred, average='weighted')

            # Training and evaluation succeeded: swap in the new model.
            self.model = model
            self.scaler = scaler

            trained_at = datetime.now()
            self.training_stats = {
                'accuracy': float(accuracy),
                'precision': float(precision),
                'recall': float(recall),
                'f1_score': float(f1),
                'total_samples': len(training_data),
                'last_updated': trained_at.isoformat()
            }
            self.last_training_time = trained_at

            if not self.save_model():
                logger.warning("Model trained but could not be saved; it will be lost on restart")

            logger.info(f"Model trained successfully - Accuracy: {accuracy:.3f}, F1: {f1:.3f}")

            return {
                'success': True,
                'stats': self.training_stats,
                'training_samples': len(training_data),
                'test_samples': len(X_test)
            }

        except Exception as e:
            logger.error(f"Model training failed: {e}")
            return {'success': False, 'error': str(e)}

    def _create_ensemble_model(self):
        """Build an unfitted soft-voting ensemble of random forest and gradient boosting."""
        from sklearn.ensemble import VotingClassifier

        rf_model = RandomForestClassifier(
            n_estimators=100,
            max_depth=10,
            min_samples_split=5,
            min_samples_leaf=2,
            random_state=42
        )

        gb_model = GradientBoostingClassifier(
            n_estimators=100,
            learning_rate=0.1,
            max_depth=6,
            random_state=42
        )

        # The 'rf' name is relied on by get_feature_importance.
        return VotingClassifier(
            estimators=[
                ('rf', rf_model),
                ('gb', gb_model)
            ],
            voting='soft'
        )

    def predict(self, request, client_ip: str, session_data: Optional[Dict] = None,
                deployment=None, mode='learn') -> Tuple[bool, float, Dict]:
        """Predict whether a request comes from a bot.

        Args:
            request: incoming request object passed to the feature extractor
                (and whose ``headers`` are read when a decision is logged).
            client_ip: client IP address.
            session_data: optional session information for feature extraction.
            deployment: deployment whose ``id`` is attached to logged decisions.
            mode: ``'decide'`` (with a deployment) logs the decision; any other
                value only predicts.

        Returns:
            ``(is_bot, confidence, details)`` where confidence is the
            probability of the predicted class. On any failure, or when no
            model is loaded, returns ``(False, 0.0, {'error': ..., 'mode': mode})``.
        """
        try:
            if self.model is None or self.scaler is None:
                logger.warning("Model not trained, cannot make prediction")
                return False, 0.0, {'error': 'Model not trained', 'mode': mode}

            features = self.feature_extractor.extract_request_features(request, client_ip, session_data)
            feature_vector = self.feature_extractor.get_feature_vector(features)

            feature_vector_scaled = self.scaler.transform(np.asarray(feature_vector).reshape(1, -1))

            prediction = self.model.predict(feature_vector_scaled)[0]
            probability = self.model.predict_proba(feature_vector_scaled)[0]

            # Look probabilities up by class label rather than column position,
            # so a model that only saw one class cannot raise IndexError.
            proba_by_class = dict(zip(self.model.classes_, probability))
            probability_human = float(proba_by_class.get(0, 0.0))
            probability_bot = float(proba_by_class.get(1, 0.0))
            confidence = float(proba_by_class.get(prediction, 0.0))

            is_bot = bool(prediction)

            if mode == 'decide' and deployment:
                self._log_ai_decision(request, client_ip, deployment, is_bot, confidence, features)

            prediction_details = {
                'prediction': is_bot,
                'confidence': confidence,
                'probability_human': probability_human,
                'probability_bot': probability_bot,
                'features_extracted': len(features),
                'model_version': self.model_version,
                'mode': mode
            }

            logger.debug(f"Prediction for {client_ip}: bot={is_bot}, confidence={confidence:.3f}, mode={mode}")

            return is_bot, confidence, prediction_details

        except Exception as e:
            logger.error(f"Prediction failed: {e}")
            return False, 0.0, {'error': str(e), 'mode': mode}

    @classmethod
    def _classify_decision(cls, is_bot: bool, confidence: float) -> str:
        """Map a prediction to 'block', 'suspicious' or 'allow' using the class thresholds."""
        if is_bot and confidence >= cls.BLOCK_CONFIDENCE_THRESHOLD:
            return 'block'
        if is_bot and confidence >= cls.SUSPICIOUS_CONFIDENCE_THRESHOLD:
            return 'suspicious'
        return 'allow'

    def _log_ai_decision(self, request, client_ip: str, deployment, is_bot: bool,
                         confidence: float, features: Dict):
        """Persist an AIDecisionLog row for this prediction (best effort).

        Failures are logged, never raised, so prediction is not interrupted.
        """
        db = None
        try:
            from app import db
            from models import AIDecisionLog

            decision_type = self._classify_decision(is_bot, confidence)

            decision_log = AIDecisionLog(
                ip_address=client_ip,
                deployed_app_id=deployment.id if deployment else None,
                decision_type=decision_type,
                confidence=confidence,
                ai_reason=f"AI-AntiBot decision: {decision_type} (confidence: {confidence:.3f})",
                user_agent=request.headers.get('User-Agent', ''),
            )
            decision_log.set_feature_data(features)

            db.session.add(decision_log)
            db.session.commit()

            logger.info(f"AI decision logged: {decision_type} for {client_ip} with confidence {confidence:.3f}")

        except Exception as e:
            # Assuming db is (Flask-)SQLAlchemy: after a failed add/commit the
            # session must be rolled back before it can be used again.
            if db is not None:
                try:
                    db.session.rollback()
                except Exception as rollback_error:
                    logger.error(f"Failed to roll back AI decision log: {rollback_error}")
            logger.error(f"Failed to log AI decision: {e}")

    def get_model_stats(self) -> Dict[str, Any]:
        """Return a copy of the training stats plus current model status fields."""
        stats = self.training_stats.copy()

        stats.update({
            'model_loaded': self.model is not None,
            'scaler_loaded': self.scaler is not None,
            'model_version': self.model_version,
            'last_training_time': self.last_training_time.isoformat() if self.last_training_time else None,
            'needs_retraining': self._needs_retraining(),
            'feature_count': len(self.feature_extractor.get_feature_vector({}))
        })

        return stats

    def _needs_retraining(self) -> bool:
        """Return True if never trained or the last training is older than the configured interval."""
        if not self.last_training_time:
            return True

        # RETRAINING_INTERVAL is compared against a timedelta, so it must be one.
        time_since_training = datetime.now() - self.last_training_time
        return time_since_training > AIAntiBotConfig.RETRAINING_INTERVAL

    def get_feature_importance(self) -> Dict[str, float]:
        """Return feature importances sorted from most to least important.

        For the VotingClassifier ensemble the importances come from its 'rf'
        (random forest) member. Columns are named from FEATURE_NAMES; any
        column beyond that list is named ``feature_<index>``. Returns {} if no
        model is loaded, the model exposes no importances, or an error occurs.
        """
        if self.model is None:
            return {}

        try:
            if hasattr(self.model, 'named_estimators_'):
                importance_values = self.model.named_estimators_['rf'].feature_importances_
            elif hasattr(self.model, 'feature_importances_'):
                importance_values = self.model.feature_importances_
            else:
                return {}

            names = self.FEATURE_NAMES
            feature_importance = {
                (names[i] if i < len(names) else f'feature_{i}'): float(value)
                for i, value in enumerate(importance_values)
            }

            return dict(sorted(feature_importance.items(), key=lambda item: item[1], reverse=True))

        except Exception as e:
            logger.error(f"Failed to get feature importance: {e}")
            return {}