"""
Turnstile Training Module
Handles learning from Cloudflare Turnstile decisions
"""
import json
import pickle
import os
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
from threading import Thread, Event
import time

from ..config import AIAntiBotConfig
from ..models.detector import BotDetector
from ..features.extractor import FeatureExtractor

logger = logging.getLogger(__name__)

class TurnstileTrainer:
    """Train the AI-AntiBot detector from Cloudflare Turnstile verdicts.

    Each recorded request is converted into a feature vector and labelled by
    its Turnstile outcome (0 = passed / human, 1 = failed / bot). Samples are
    persisted to disk periodically. Once ``AIAntiBotConfig.TRAINING_THRESHOLD``
    samples exist, a daemon thread retrains the detector whenever new samples
    have arrived.

    ``training_data`` and ``training_labels`` are appended to by the caller's
    thread(s) while the background thread reads them. Every consumer therefore
    works on a consistent copy obtained from ``_snapshot()``.
    """

    # Maximum number of entries retained in ``live_training_log``.
    LIVE_LOG_LIMIT = 1000
    # Persist samples whenever the sample count is a multiple of this.
    SAVE_EVERY_N_SAMPLES = 10
    # Minimum number of samples required by force_training().
    MIN_FORCE_TRAINING_SAMPLES = 10
    # Seconds between background checks for new samples.
    TRAINING_POLL_SECONDS = 5
    # Seconds to back off after an unexpected error in the background loop.
    ERROR_BACKOFF_SECONDS = 300

    def __init__(self):
        self.detector = BotDetector()
        self.feature_extractor = FeatureExtractor()
        self.training_data = []      # feature vectors, parallel to training_labels
        self.training_labels = []    # 0 = human, 1 = bot
        self.live_training_log = []  # recent events, capped at LIVE_LOG_LIMIT
        self.training_thread = None
        self.stop_training = Event()
        # Sample count used by the most recent completed training call; -1 = never.
        self._trained_sample_count = -1
        # ISO timestamp of the most recent successful training, if any.
        self._last_training_time = None

        # Load existing training data
        self.load_training_data()

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def load_training_data(self):
        """Load previously saved samples from the configured training-data path.

        Both files are read before either attribute is replaced, so a failure
        part-way through leaves the in-memory state untouched. If the two files
        disagree in length (for example after a crash between the two writes),
        only the leading pairs present in both are kept. Errors are logged,
        not raised.
        """
        try:
            training_data_path = AIAntiBotConfig.get_training_data_path()
            data_file = os.path.join(training_data_path, 'training_data.pkl')
            labels_file = os.path.join(training_data_path, 'training_labels.pkl')

            if not (os.path.exists(data_file) and os.path.exists(labels_file)):
                logger.info("No existing training data found")
                return

            # SECURITY: pickle.load can execute arbitrary code. These files are
            # written by save_training_data(); the directory must not be
            # writable by untrusted parties.
            with open(data_file, 'rb') as f:
                data = list(pickle.load(f))
            with open(labels_file, 'rb') as f:
                labels = list(pickle.load(f))

            if len(data) != len(labels):
                usable = min(len(data), len(labels))
                logger.warning(
                    "Training data/label count mismatch (%d vs %d); keeping first %d pairs",
                    len(data), len(labels), usable,
                )
                data, labels = data[:usable], labels[:usable]

            self.training_data = data
            self.training_labels = labels
            logger.info("Loaded %d training samples", len(self.training_data))

        except Exception as e:
            logger.error("Failed to load training data: %s", e)

    def save_training_data(self):
        """Persist the current samples to ``training_data.pkl`` / ``training_labels.pkl``.

        Each file is replaced atomically, so a crash never leaves a truncated
        pickle behind. The pair is not updated atomically as a unit;
        load_training_data() reconciles any length mismatch. Errors are
        logged, not raised.
        """
        try:
            training_data_path = AIAntiBotConfig.get_training_data_path()
            os.makedirs(training_data_path, exist_ok=True)

            data, labels = self._snapshot()
            self._atomic_pickle(os.path.join(training_data_path, 'training_data.pkl'), data)
            self._atomic_pickle(os.path.join(training_data_path, 'training_labels.pkl'), labels)

            logger.info("Saved %d training samples", len(data))

        except Exception as e:
            logger.error("Failed to save training data: %s", e)

    # ------------------------------------------------------------------
    # Sample collection
    # ------------------------------------------------------------------

    def record_interaction(self, request, client_ip: str, turnstile_passed: bool,
                           session_data: Optional[Dict] = None,
                           extra_context: Optional[Dict] = None) -> Dict[str, Any]:
        """Record one request together with its Turnstile verdict as a training sample.

        Args:
            request: Incoming request object, passed through to
                ``FeatureExtractor.extract_request_features``.
            client_ip: Client IP address, used for feature extraction and
                debug logging.
            turnstile_passed: True if the client passed Turnstile (label 0,
                human), False otherwise (label 1, bot).
            session_data: Optional session info. Its ``session_duration`` is
                copied into the log entry when present.
            extra_context: Optional free-form data stored on the log entry.

        Returns:
            ``{'success': True, 'total_samples': int, 'log_entry': dict}`` on
            success, or ``{'success': False, 'error': str}`` if anything raised.
        """
        try:
            # Extract features from the interaction
            features = self.feature_extractor.extract_request_features(request, client_ip, session_data)
            feature_vector = self.feature_extractor.get_feature_vector(features)

            # Label: 0 = human (passed Turnstile), 1 = bot (failed Turnstile)
            label = 0 if turnstile_passed else 1

            # Data is appended before its label. _snapshot() tolerates the brief
            # window in which the two lists differ in length.
            self.training_data.append(feature_vector)
            self.training_labels.append(label)
            total = len(self.training_data)

            log_entry = {
                'timestamp': datetime.now().isoformat(),
                'ip_hash': features.get('ip_hash', 'unknown'),
                'turnstile_passed': turnstile_passed,
                'label': 'human' if label == 0 else 'bot',
                'features_count': len(features),
                'session_duration': session_data.get('session_duration', 0) if session_data else 0,
                'user_agent_length': features.get('user_agent_length', 0),
                'total_interactions': features.get('total_interactions', 0),
                'extra_context': extra_context or {}
            }
            self._append_log(log_entry)

            if total % self.SAVE_EVERY_N_SAMPLES == 0:
                self.save_training_data()

            # start_background_training() is a no-op while the thread is alive.
            if total >= AIAntiBotConfig.TRAINING_THRESHOLD:
                self.start_background_training()

            # Lazy %-formatting: this runs per request and is usually filtered out.
            logger.debug("Recorded training sample: %s -> %s (total: %s)", client_ip, label, total)

            return {
                'success': True,
                'total_samples': total,
                'log_entry': log_entry
            }

        except Exception as e:
            logger.error("Failed to record interaction: %s", e)
            return {'success': False, 'error': str(e)}

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def start_background_training(self):
        """Start the daemon training thread unless one is already running."""
        if self.training_thread is not None and self.training_thread.is_alive():
            return

        self.stop_training.clear()
        self.training_thread = Thread(target=self._background_training_loop, daemon=True)
        self.training_thread.start()

        logger.info("Started background training thread")

    def stop_background_training(self):
        """Signal the training thread to stop and wait up to 30 s for it to exit."""
        self.stop_training.set()
        if self.training_thread:
            self.training_thread.join(timeout=30)

    def _background_training_loop(self):
        """Retrain the detector whenever new samples arrive, until stopped.

        Polls every ``TRAINING_POLL_SECONDS``. After an unexpected exception it
        backs off for ``ERROR_BACKOFF_SECONDS``.
        """
        while not self.stop_training.is_set():
            try:
                sample_count = len(self.training_data)
                # Only retrain when the dataset changed since the last attempt.
                # Retraining identical data on every poll wastes CPU and floods
                # the logs.
                if (sample_count >= AIAntiBotConfig.TRAINING_THRESHOLD
                        and sample_count != self._trained_sample_count):
                    result = self._run_training('model_training')

                    if result.get('success'):
                        logger.info("Background training completed: %s", result['stats'])
                    else:
                        logger.error("Background training failed: %s", result.get('error', 'Unknown error'))

                self.stop_training.wait(self.TRAINING_POLL_SECONDS)

            except Exception as e:
                logger.error("Background training error: %s", e)
                self.stop_training.wait(self.ERROR_BACKOFF_SECONDS)

    def force_training(self) -> Dict[str, Any]:
        """Train immediately on the current samples.

        Returns:
            The detector's ``train_model`` result, or
            ``{'success': False, 'error': str}`` when there are fewer than
            ``MIN_FORCE_TRAINING_SAMPLES`` samples or training raised.
        """
        try:
            if len(self.training_data) < self.MIN_FORCE_TRAINING_SAMPLES:
                return {'success': False, 'error': 'Insufficient training data'}

            return self._run_training('forced_training')

        except Exception as e:
            logger.error("Forced training failed: %s", e)
            return {'success': False, 'error': str(e)}

    def _run_training(self, event: str) -> Dict[str, Any]:
        """Train the detector on a consistent snapshot and log successes as ``event``."""
        data, labels = self._snapshot()
        result = self.detector.train_model(data, labels)

        # Remember the dataset size we just trained on (even if training
        # reported failure) so the background loop waits for new samples
        # instead of retrying the same data every poll.
        self._trained_sample_count = len(data)

        if result.get('success'):
            timestamp = datetime.now().isoformat()
            self._last_training_time = timestamp
            self._append_log({
                'timestamp': timestamp,
                'event': event,
                'success': True,
                'stats': result['stats'],
                'training_samples': result['training_samples']
            })

        return result

    # ------------------------------------------------------------------
    # Introspection / maintenance
    # ------------------------------------------------------------------

    def get_training_stats(self) -> Dict[str, Any]:
        """Return sample counts, class balance, thread status and detector model stats.

        Keys returned by ``detector.get_model_stats()`` are merged in after the
        trainer's own counts and override them on collision. The class-ratio
        fields are computed from the trainer's own counts and added last.
        """
        total = len(self.training_data)
        human_samples = sum(1 for label in self.training_labels if label == 0)
        bot_samples = sum(1 for label in self.training_labels if label == 1)

        stats = {
            'total_samples': total,
            'total_labels': len(self.training_labels),
            'human_samples': human_samples,
            'bot_samples': bot_samples,
            'training_active': self.training_thread is not None and self.training_thread.is_alive(),
            'last_training_time': self._last_training_time
        }

        # Add model stats
        stats.update(self.detector.get_model_stats())

        # Calculate class distribution
        if total > 0:
            human_ratio = human_samples / total
            bot_ratio = bot_samples / total
            larger = max(human_ratio, bot_ratio)
            stats['human_ratio'] = human_ratio
            stats['bot_ratio'] = bot_ratio
            # Guard: both ratios are 0 if labels are missing or not 0/1.
            stats['class_balance'] = min(human_ratio, bot_ratio) / larger if larger > 0 else 0.0

        return stats

    def get_live_training_log(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recent log entries, oldest first."""
        # Explicit guard: ``lst[-0:]`` would return the entire list.
        if limit <= 0:
            return []
        return self.live_training_log[-limit:]

    def clear_training_data(self):
        """Clear all samples and the live log, then persist the empty state (for testing purposes)."""
        self.training_data = []
        self.training_labels = []
        self.live_training_log = []
        # Forget the last trained size so that a new dataset reaching the
        # threshold is trained even if it happens to have the same length.
        self._trained_sample_count = -1
        self.save_training_data()
        logger.info("Cleared all training data")

    def export_training_data(self, format: str = 'json') -> str:
        """Serialize the training samples for analysis.

        Args:
            format: ``'json'`` or ``'csv'``.

        Returns:
            The serialized text. Returns ``""`` for CSV when there are no
            samples (the column count is unknown). Also returns ``""`` when the
            format is unsupported or serialization fails; those errors are
            logged, not raised.
        """
        try:
            if format not in ('json', 'csv'):
                raise ValueError(f"Unsupported format: {format}")

            data, labels = self._snapshot()
            rows = [self._as_list(vector) for vector in data]

            if format == 'json':
                payload = {
                    'training_data': rows,
                    'training_labels': labels,
                    'export_timestamp': datetime.now().isoformat(),
                    'total_samples': len(rows)
                }
                return json.dumps(payload, indent=2)

            import csv
            import io

            if not rows:
                return ""

            output = io.StringIO()
            writer = csv.writer(output)
            writer.writerow([f'feature_{i}' for i in range(len(rows[0]))] + ['label'])
            writer.writerows(row + [label] for row, label in zip(rows, labels))
            return output.getvalue()

        except Exception as e:
            logger.error("Failed to export training data: %s", e)
            return ""

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _snapshot(self):
        """Return ``(data, labels)`` copies truncated to the same length.

        Samples are appended pairwise in order, so the common-length prefix
        always pairs each feature vector with its own label, even while
        another thread is mid-append.
        """
        data = list(self.training_data)
        labels = list(self.training_labels)
        usable = min(len(data), len(labels))
        return data[:usable], labels[:usable]

    def _append_log(self, entry: Dict[str, Any]) -> None:
        """Append ``entry`` to the live log, dropping the oldest beyond ``LIVE_LOG_LIMIT``."""
        self.live_training_log.append(entry)
        overflow = len(self.live_training_log) - self.LIVE_LOG_LIMIT
        if overflow > 0:
            del self.live_training_log[:overflow]

    @staticmethod
    def _as_list(vector) -> list:
        """Convert a feature vector (array-like with ``tolist()`` or any iterable) to a list."""
        return vector.tolist() if hasattr(vector, 'tolist') else list(vector)

    @staticmethod
    def _atomic_pickle(path: str, obj: Any) -> None:
        """Pickle ``obj`` to ``path`` through a uniquely named temp file and ``os.replace``.

        A crash mid-write leaves the previous file intact. The unique temp
        name keeps concurrent savers from writing into the same file.
        """
        import uuid

        tmp_path = f"{path}.{uuid.uuid4().hex}.tmp"
        try:
            with open(tmp_path, 'wb') as fh:
                pickle.dump(obj, fh)
            os.replace(tmp_path, path)
        except BaseException:
            # Don't leave orphaned temp files behind on failure.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise