import cv2
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image
import os
import ssl

# Process-wide override of the default HTTPS context, presumably to work around
# certificate errors when TorchHub downloads repos/weights ("if needed").
# SECURITY: this disables certificate verification for *every* HTTPS request
# made by this process, and torch.hub downloads and executes repository code.
# Set CARTOON_ENGINE_VERIFY_SSL=1 (or true/yes/on) to keep verification enabled.
if os.environ.get("CARTOON_ENGINE_VERIFY_SSL", "").strip().lower() not in ("1", "true", "yes", "on"):
    ssl._create_default_https_context = ssl._create_unverified_context

class CartoonEngine:
    """Image stylization engine: OpenCV/PIL filters plus AnimeGAN generators.

    Torch models run on CPU and are cached per style key in ``self.models``
    after their first successful load.
    """

    # TorchHub repository providing the standard AnimeGANv2 ``generator`` entry point.
    _HUB_REPO = 'bryandlee/animegan2-pytorch:main'

    # UI style label -> pretrained weight name passed to the hub generator.
    # Any other label falls back to "paprika".
    _WEIGHT_MAP = {
        "Manga (K-ON)": "celeba_distill",
        "Ghibli (Paprika)": "paprika",
        "Shinkai": "face_paint_512_v2",
    }

    # Key prefixes stripped from custom checkpoint keys when matching the
    # standard generator's parameter names.
    _KEY_PREFIXES = ("module.", "generator.", "model.", "g_ema.")

    # Wrapper keys under which checkpoints may nest the generator weights,
    # checked in this order.
    _STATE_DICT_WRAPPERS = ('generator', 'model', 'model_state_dict')

    def __init__(self):
        self.device = 'cpu'  # Force CPU as requested
        self.models = {}  # style key / checkpoint path -> loaded model (eval mode)

        # Keep the TorchHub cache in <project_root>/models, where project_root
        # is the parent of the directory containing this file.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        models_dir = os.path.join(project_root, 'models')
        os.makedirs(models_dir, exist_ok=True)  # race-free "create if missing"
        torch.hub.set_dir(models_dir)
        print(f"TorchHub cache directory set to: {models_dir}")

    # ------------------------------------------------------------------
    # Small pure helpers (no I/O)
    # ------------------------------------------------------------------

    @staticmethod
    def _clamp01(value):
        """Return *value* as a float clamped to the closed range [0.0, 1.0]."""
        return min(max(float(value), 0.0), 1.0)

    @classmethod
    def _unwrap_state_dict(cls, checkpoint):
        """Return the weights nested under the first known wrapper key, else *checkpoint*."""
        for key in cls._STATE_DICT_WRAPPERS:
            if key in checkpoint:
                return checkpoint[key]
        return checkpoint

    @classmethod
    def _match_standard_keys(cls, state_dict, model_keys):
        """Map checkpoint entries onto *model_keys* (exact match, else one stripped prefix).

        Entries that match no model key are dropped.
        """
        matched = {}
        for key, value in state_dict.items():
            if key in model_keys:
                matched[key] = value
                continue
            for prefix in cls._KEY_PREFIXES:
                if key.startswith(prefix) and key[len(prefix):] in model_keys:
                    matched[key[len(prefix):]] = value
                    break
        return matched

    @staticmethod
    def _match_legacy_keys(state_dict, legacy_keys):
        """Map checkpoint entries onto *legacy_keys* after removing every "module." occurrence."""
        matched = {}
        for key, value in state_dict.items():
            cleaned = key.replace("module.", "")
            if cleaned in legacy_keys:
                matched[cleaned] = value
        return matched

    @staticmethod
    def _pixel_grid_size(width, height, block_size):
        """Return the downscaled (width, height) for pixel art, each side at least 1 px.

        Raises:
            ValueError: If ``block_size`` is smaller than 1.
        """
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        return max(1, int(width // block_size)), max(1, int(height // block_size))

    @staticmethod
    def _inference_sizes(width, height, max_dim, multiple=32):
        """Compute the output size and the network input size for an image.

        The output ("target") size shrinks the image so its longest side is at
        most ``max_dim`` (never upscaled, each side at least 1 px). The network
        input size rounds each target side down to a multiple of ``multiple``,
        but never below ``multiple`` itself, so tiny images stay processable.

        Returns:
            ``((target_w, target_h), (align_w, align_h))``.
        """
        ratio = max_dim / max(width, height)
        if ratio < 1:
            target_w, target_h = int(width * ratio), int(height * ratio)
        else:
            target_w, target_h = width, height
        target_w, target_h = max(1, target_w), max(1, target_h)

        align_w = max(multiple, target_w - target_w % multiple)
        align_h = max(multiple, target_h - target_h % multiple)
        return (target_w, target_h), (align_w, align_h)

    @classmethod
    def _correct_orange_hsv(cls, h, s, strength):
        """Desaturate warm hues and nudge orange/red hues towards yellow.

        Args:
            h: Float array of OpenCV 8-bit hue values (0-180 scale).
            s: Float array of saturation values, same shape as ``h``.
            strength: Correction strength, clamped to [0.0, 1.0]. Clamping keeps
                the reduced saturation non-negative, so the later uint8 cast
                cannot wrap around.

        Returns:
            The corrected ``(h, s)`` arrays.
        """
        import numpy as np

        strength = cls._clamp01(strength)

        # Reds/oranges/yellows at the start of the hue wheel (0-50).
        # OpenCV hue is 0-180: reds ~0-10, oranges ~10-25, yellows ~25-35.
        low_warm = np.logical_and(h >= 0, h <= 50).astype(np.float32)
        # Reds/pinks that wrap around at the end of the hue wheel (170-180).
        high_red = (h >= 170).astype(np.float32)
        warm_mask = np.maximum(low_warm, high_red)

        # Cut saturation of warm tones by up to 70%.
        s = s * (1 - warm_mask * strength * 0.7)

        # Push the 0-50 range slightly towards yellow. The 170-180 range is not
        # shifted: doing so cleanly would require wrapping through 0.
        h = h + low_warm * strength * 5
        h = np.clip(h, 0, 180)
        return h, s

    # ------------------------------------------------------------------
    # Model loading
    # ------------------------------------------------------------------

    def _load_torch_model(self, style_key):
        """Load and cache the generator for *style_key*.

        *style_key* is either a path to an existing custom ``.pt`` checkpoint
        or a style label looked up in ``_WEIGHT_MAP`` (unknown labels use the
        "paprika" weights).

        Returns:
            The model in eval mode, or ``None`` if loading failed. The error is
            printed rather than raised.
        """
        if style_key in self.models:
            return self.models[style_key]

        print(f"Loading model for {style_key}...")
        try:
            if style_key.endswith('.pt') and os.path.exists(style_key):
                model = self._load_custom_model(style_key)
            else:
                model = self._load_hub_model(style_key)
            model.eval()
        except Exception as e:
            # Best effort: callers such as process_torch check for None.
            print(f"Error loading model: {e}")
            return None

        self.models[style_key] = model
        return model

    def _load_hub_model(self, style_key):
        """Fetch the standard TorchHub generator with the weights mapped from *style_key*."""
        pretrained_name = self._WEIGHT_MAP.get(style_key)
        if pretrained_name is None:
            pretrained_name = "paprika"
            print(f"Unknown style '{style_key}', falling back to '{pretrained_name}' weights")
        return torch.hub.load(
            self._HUB_REPO,
            'generator',
            pretrained=pretrained_name,
            device=self.device,
        )

    def _load_custom_model(self, path):
        """Load a custom checkpoint into the standard generator, else the legacy one.

        Raises:
            Exception: The standard-architecture error if both attempts fail.
        """
        print(f"Loading custom model from {path}")
        model = torch.hub.load(
            self._HUB_REPO,
            'generator',
            pretrained=None,
            device=self.device,
        )

        try:
            # NOTE(review): torch.load unpickles the file, so only load trusted
            # checkpoints (consider weights_only=True if the torch version allows).
            state_dict = self._unwrap_state_dict(torch.load(path, map_location=self.device))

            file_keys = list(state_dict.keys())
            if file_keys:
                print(f"File first key: {file_keys[0]}")

            # 1. First attempt: standard architecture (bryandlee).
            model_keys = set(model.state_dict().keys())
            matched = self._match_standard_keys(state_dict, model_keys)
            if not matched:
                raise RuntimeError("No keys matched standard architecture")
            model.load_state_dict(matched, strict=False)
            print(f"Keys matched (Standard): {len(matched)}/{len(model_keys)}")
            return model
        except Exception as e:
            print(f"Standard load failed ({e}), trying Legacy (ptran1203)...")
            try:
                # 2. Fallback: legacy architecture (ptran1203).
                return self._load_legacy_model(path)
            except Exception as e2:
                print(f"Legacy load failed too: {e2}")
                raise e

    def _load_legacy_model(self, path):
        """Load *path* into the local legacy ``GeneratorV2`` (ptran1203) architecture."""
        import sys

        # Imported lazily so that a missing anime_gan_v2.py only breaks this fallback.
        # Add this file's directory to sys.path only once.
        here = os.path.dirname(os.path.abspath(__file__))
        if here not in sys.path:
            sys.path.append(here)
        from anime_gan_v2 import GeneratorV2

        legacy_model = GeneratorV2().to(self.device)

        # Reload the checkpoint fresh.
        state_dict = self._unwrap_state_dict(torch.load(path, map_location=self.device))
        legacy_keys = set(legacy_model.state_dict().keys())
        clean_dict = self._match_legacy_keys(state_dict, legacy_keys)
        if not clean_dict:
            raise RuntimeError("No keys matched Legacy architecture either.")

        legacy_model.load_state_dict(clean_dict, strict=False)
        print(f"Loaded using Legacy Architecture. Keys matched: {len(clean_dict)}")
        return legacy_model

    # ------------------------------------------------------------------
    # Filters
    # ------------------------------------------------------------------

    def process_opencv(self, input_path, output_path, intensity=0.5):
        """Apply a comic-book effect using pure OpenCV.

        Args:
            input_path: Path of the source image.
            output_path: Path the result is written to.
            intensity: Smoothing strength in [0.0, 1.0]. Values outside the
                range are clamped.

        Returns:
            ``output_path``.

        Raises:
            FileNotFoundError: If the input image cannot be read.
        """
        img = cv2.imread(input_path)
        if img is None:
            raise FileNotFoundError(f"Impossible de lire: {input_path}")

        # Map intensity to bilateral sigmaColor / sigmaSpace in [10, 150].
        sigma = int(10 + self._clamp01(intensity) * 140)

        # 1. Edges: adaptive threshold of the median-blurred grayscale image.
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        gray = cv2.medianBlur(gray, 5)
        edges = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 9, 9)

        # 2. Colour smoothing: d=9 is standard, sigma controls the "flatness".
        color = cv2.bilateralFilter(img, 9, sigma, sigma)

        # 3. Keep colour only where the edge mask is non-zero.
        cartoon = cv2.bitwise_and(color, color, mask=edges)

        cv2.imwrite(output_path, cartoon)
        return output_path

    def process_sketch(self, input_path, output_path):
        """Apply a black-and-white pencil-sketch effect using OpenCV.

        Returns:
            ``output_path``.

        Raises:
            FileNotFoundError: If the input image cannot be read.
        """
        img = cv2.imread(input_path)
        if img is None:
            raise FileNotFoundError(f"Impossible de lire: {input_path}")

        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        inv = cv2.bitwise_not(gray)
        blur = cv2.GaussianBlur(inv, (21, 21), 0)

        # Colour-dodge blend: gray * 256 / (255 - blur).
        sketch = cv2.divide(gray, 255 - blur, scale=256)

        cv2.imwrite(output_path, sketch)
        return output_path

    def process_pixel_art(self, input_path, output_path, block_size=10):
        """Apply a pixel-art effect by downscaling and then upscaling with nearest neighbour.

        Args:
            block_size: Approximate size, in source pixels, of each output
                "pixel" (>= 1). The downscaled image is at least 1x1, so a
                block larger than the image yields a single flat colour.

        Returns:
            ``output_path``.

        Raises:
            ValueError: If ``block_size`` is smaller than 1.
        """
        with Image.open(input_path) as img:
            w, h = img.size
            small_size = self._pixel_grid_size(w, h, block_size)
            # Downscale (loses detail).
            img_small = img.resize(small_size, Image.Resampling.BILINEAR)

        # Upscale back with hard edges (visible pixelation).
        img_pixel = img_small.resize((w, h), Image.Resampling.NEAREST)
        img_pixel.save(output_path)
        return output_path

    def reduce_orange_tint(self, image_path, output_path, strength=0.5):
        """Reduce an orange/red colour cast in an image.

        Args:
            image_path: Path of the source image.
            output_path: Output path. It may equal ``image_path``, because the
                image is fully read before being written.
            strength: Correction strength, 0.0 = none to 1.0 = maximum. Values
                outside that range are clamped.

        Returns:
            ``output_path``.

        Raises:
            FileNotFoundError: If the input image cannot be read.
        """
        import numpy as np

        img = cv2.imread(image_path)
        if img is None:
            raise FileNotFoundError(f"Impossible de lire: {image_path}")

        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
        h, s, v = cv2.split(hsv)
        h, s = self._correct_orange_hsv(h, s, strength)

        hsv_fixed = cv2.merge([h, s, v])
        result = cv2.cvtColor(hsv_fixed.astype(np.uint8), cv2.COLOR_HSV2BGR)

        cv2.imwrite(output_path, result)
        return output_path

    def create_reveal_video(self, original_path, cartoon_path, output_video_path, duration=3.0, fps=30):
        """Render a left-to-right "curtain" reveal video from the original to the cartoon.

        Uses the 'mp4v' codec, which is typically available without ffmpeg.
        After the reveal, the cartoon frame is held for about one second.

        Returns:
            ``output_video_path``.

        Raises:
            FileNotFoundError: If either image cannot be read.
            RuntimeError: If OpenCV cannot open a writer for the output path/codec.
        """
        img_orig = cv2.imread(original_path)
        img_cart = cv2.imread(cartoon_path)
        if img_orig is None or img_cart is None:
            raise FileNotFoundError("Images sources introuvables pour la vidéo.")

        # Resize the cartoon to the original's size so the frames line up.
        h, w = img_orig.shape[:2]
        img_cart = cv2.resize(img_cart, (w, h))

        total_frames = int(duration * fps)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (w, h))
        if not out.isOpened():
            # Without this check every write would be a silent no-op.
            raise RuntimeError(f"Could not open video writer for: {output_video_path}")

        print(f"Generating video: {w}x{h} @ {fps}fps, {total_frames} frames")

        try:
            for i in range(total_frames):
                # Curtain x position: left = cartoon (revealed), right = original.
                split_x = int(i / total_frames * w)

                frame = img_orig.copy()
                if split_x > 0:
                    frame[:, :split_x] = img_cart[:, :split_x]

                # White separator line at the curtain position.
                if 0 < split_x < w:
                    cv2.line(frame, (split_x, 0), (split_x, h), (255, 255, 255), 2)

                out.write(frame)

            # Hold the final cartoon frame for ~1 second.
            for _ in range(int(fps)):
                out.write(img_cart)
        finally:
            out.release()

        return output_video_path

    def process_torch(self, input_path, output_path, style, orange_correction=None, hd_mode=False):
        """Apply an anime effect with a TorchHub / custom AnimeGAN generator.

        Args:
            input_path: Path of the source image.
            output_path: Path the result is written to.
            style: Style label or custom ``.pt`` checkpoint path (see
                ``_load_torch_model``).
            orange_correction: If given and > 0, the strength (0.0-1.0) of the
                orange-tint correction applied after stylization.
            hd_mode: If True, process at a higher resolution (longest side up
                to 2560 px instead of 1024 px).

        Returns:
            ``output_path``.

        Raises:
            RuntimeError: If the model for *style* cannot be loaded.
        """
        model = self._load_torch_model(style)
        if model is None:
            raise RuntimeError(f"Impossible de charger le modèle pour {style}")

        with Image.open(input_path) as src:
            img = src.convert("RGB")
        w, h = img.size

        # Cap resolution for CPU speed: max 1024 px standard, 2560 px in HD mode.
        max_dim = 2560 if hd_mode else 1024
        (target_w, target_h), (align_w, align_h) = self._inference_sizes(w, h, max_dim)

        # The network input is aligned to multiples of 32 (required for inference).
        img_resized = img.resize((align_w, align_h), Image.Resampling.LANCZOS)

        # [0, 1] -> [-1, 1], the generator's input range.
        input_tensor = (to_tensor(img_resized).unsqueeze(0) * 2 - 1).to(self.device)

        with torch.no_grad():
            output_tensor = model(input_tensor)

        # [-1, 1] -> [0, 1] for conversion back to an image.
        output_img = to_pil_image((output_tensor[0].cpu() * 0.5 + 0.5).clip(0, 1))

        # Restore the exact aspect ratio (undo the 32-alignment distortion).
        if output_img.size != (target_w, target_h):
            output_img = output_img.resize((target_w, target_h), Image.Resampling.LANCZOS)

        output_img.save(output_path)

        # Optional colour post-processing, e.g. for Paprika.
        if orange_correction is not None and orange_correction > 0:
            print(f"Applying color correction (strength={orange_correction:.2f})...")
            # In place is safe because reduce_orange_tint reads the whole image
            # before writing. If it fails, the uncorrected result is left intact.
            self.reduce_orange_tint(output_path, output_path, strength=orange_correction)

        return output_path
