import customtkinter as ctk
from tkinter import filedialog, messagebox
from tkinterdnd2 import TkinterDnD, DND_FILES
from PIL import Image, ImageTk
import os
import threading
import json
import shutil

# Global theme configuration: dark appearance with the built-in "blue" colour
# theme, set at import time so it applies to the window created below.
ctk.set_appearance_mode("Dark")
ctk.set_default_color_theme("blue")

class CartoonApp(ctk.CTk, TkinterDnD.DnDWrapper):
    """Cartoonify-Me main window: import a photo, pick a style, render, export.

    Rendering is delegated to ``engine.CartoonEngine`` (imported lazily) and
    runs on daemon worker threads so the Tk event loop stays responsive.
    Worker threads never read Tk widgets/variables themselves: the settings
    they need are captured on the main thread first, and every UI update or
    dialog is scheduled back onto the main thread with ``self.after``.
    """

    # Built-in style used when nothing valid has been saved.
    DEFAULT_STYLE = "Comics (OpenCV)"
    # Menu prefix that identifies user-imported ``.pt`` models.
    CUSTOM_PREFIX = "[Custom] "
    # Extensions accepted by the file dialog, drag & drop and batch mode.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
    # Placeholder text of the result column before anything is rendered.
    RESULT_PLACEHOLDER = "Le résultat apparaîtra ici"

    def __init__(self):
        super().__init__()
        # Load the tkdnd Tcl extension so this CTk window accepts file drops.
        self.TkdndVersion = TkinterDnD._require(self)

        # Main window setup
        self.title("Cartoonify-Me")
        self.geometry("1100x700")

        # Layout: column 0 = sidebar, column 1 = image area (stretches)
        self.grid_columnconfigure(1, weight=1)
        self.grid_rowconfigure(0, weight=1)

        # State
        self.original_image_path = None
        self.processed_image = None  # best rendered result (HD if any), PIL image
        self.current_image = None  # PIL Image object for display
        # {'std'|'hd': {'raw': path, 'final': path}} for orange-corrected styles,
        # used by the slider to re-tint results live.
        self.current_outputs = {}
        self.current_final_output = None  # path of the best rendered file (video source)
        self.engine = None  # CartoonEngine, created on first use
        self.resize_timer = None  # pending debounced resize callback id
        self._busy = False  # True while a worker thread is using the engine
        # CTkLabel.configure(image=None) keeps the previous picture on screen,
        # so labels are blanked with this transparent 1x1 image instead.
        blank = Image.new("RGBA", (1, 1), (0, 0, 0, 0))
        self._blank_image = ctk.CTkImage(light_image=blank, dark_image=blank, size=(1, 1))

        # Custom models folder (must exist before create_sidebar lists styles)
        self.custom_models_dir = os.path.join("models", "custom")
        os.makedirs(self.custom_models_dir, exist_ok=True)

        # Persisted configuration (geometry, style, slider value)
        self.config_file = "config.json"
        self.saved_style = self.DEFAULT_STYLE
        self.saved_slider = None
        self.load_config()
        self.protocol("WM_DELETE_WINDOW", self.on_closing)

        self.create_sidebar()
        self.create_main_view()

        # Show/hide the slider for the initial style (this resets its value)...
        self.change_style_event(self.style_var.get())
        # ...then restore the saved slider value on top of that default.
        if self.saved_slider is not None:
            self.slider_intensity.set(self.saved_slider)

        # Drag & drop configuration
        self.drop_target_register(DND_FILES)
        self.dnd_bind('<<Drop>>', self.drop_event)

    # ------------------------------------------------------------------ input

    def drop_event(self, event):
        """Handle a file dropped on the window; only the first dropped file is used."""
        # event.data is a Tcl list: paths containing spaces arrive wrapped in {}.
        paths = self.tk.splitlist(event.data)
        if not paths:
            return
        file_path = paths[0]
        if not os.path.isfile(file_path):
            return
        if file_path.lower().endswith(self.IMAGE_EXTENSIONS):
            self._set_original_image(file_path)
        else:
            messagebox.showwarning("Format invalide", "Veuillez déposer une image (JPG, PNG).")

    def load_image(self):
        """Let the user pick a source image with the file dialog."""
        # One pattern per entry: a single "*.jpg;*.png" string only works on Windows.
        patterns = tuple(f"*{ext}" for ext in self.IMAGE_EXTENSIONS)
        file_path = filedialog.askopenfilename(filetypes=[("Images", patterns)])
        if file_path:
            self._set_original_image(file_path)

    def _set_original_image(self, file_path):
        """Make file_path the source image and clear any previous result."""
        self.original_image_path = file_path
        self.display_image(file_path, self.label_image_orig)
        self._clear_image_label(self.label_image_proc, self.RESULT_PLACEHOLDER)
        self._clear_image_label(self.label_image_hd, "")
        self.btn_save.configure(state="disabled")
        self.btn_video.configure(state="disabled")

    # ----------------------------------------------------------------- styles

    def get_style_list(self):
        """Return the built-in styles followed by one "[Custom] <name>" entry per .pt model."""
        base_styles = ["Comics (OpenCV)", "Crayon (Sketch)", "Pixel Art", "Manga (K-ON)", "Ghibli (Paprika)", "Shinkai"]
        custom_styles = []
        if os.path.isdir(self.custom_models_dir):
            # Style name = file name without extension, prefixed with [Custom].
            custom_styles = [
                f"{self.CUSTOM_PREFIX}{os.path.splitext(name)[0]}"
                for name in sorted(os.listdir(self.custom_models_dir))
                if name.endswith(".pt")
            ]
        return base_styles + custom_styles

    @classmethod
    def _is_orange_style(cls, style):
        """True for built-in styles rendered raw then orange-corrected (Paprika/Ghibli)."""
        if style.startswith(cls.CUSTOM_PREFIX):
            return False
        return "Paprika" in style or "Ghibli" in style

    def _custom_model_path(self, style):
        """Return the .pt file backing a "[Custom] <name>" style."""
        style_name = style[len(self.CUSTOM_PREFIX):]
        return os.path.join(self.custom_models_dir, f"{style_name}.pt")

    # ------------------------------------------------------------------ views

    def create_sidebar(self):
        """Build the left sidebar holding every control."""
        self.sidebar_frame = ctk.CTkFrame(self, width=200, corner_radius=0)
        self.sidebar_frame.grid(row=0, column=0, sticky="nsew")
        self.sidebar_frame.grid_rowconfigure(6, weight=1)  # Spacer

        # Title
        self.logo_label = ctk.CTkLabel(self.sidebar_frame, text="Cartoonify-Me", font=ctk.CTkFont(size=20, weight="bold"))
        self.logo_label.grid(row=0, column=0, padx=20, pady=(20, 10))

        # Import button
        self.btn_import = ctk.CTkButton(self.sidebar_frame, text="Importer une photo", command=self.load_image)
        self.btn_import.grid(row=1, column=0, padx=20, pady=10)

        # Style selection
        self.label_style = ctk.CTkLabel(self.sidebar_frame, text="Choisir un style :", anchor="w")
        self.label_style.grid(row=2, column=0, padx=20, pady=(10, 0))

        # Initial style: the saved one if it still exists (custom models may be gone)
        available_styles = self.get_style_list()
        initial_style = getattr(self, 'saved_style', self.DEFAULT_STYLE)
        if initial_style not in available_styles:
            initial_style = self.DEFAULT_STYLE

        self.style_var = ctk.StringVar(value=initial_style)
        self.option_style = ctk.CTkOptionMenu(self.sidebar_frame,
                                              values=available_styles,
                                              variable=self.style_var,
                                              command=self.change_style_event)
        self.option_style.grid(row=3, column=0, padx=20, pady=10)

        # Custom model import button
        self.btn_import_model = ctk.CTkButton(self.sidebar_frame, text="Importer un Style (IA)...", command=self.import_custom_model, fg_color="gray", hover_color="gray30")
        self.btn_import_model.grid(row=13, column=0, padx=20, pady=10)

        # Batch button
        self.btn_batch = ctk.CTkButton(self.sidebar_frame, text="Traitement par Lot (Dossier)", command=self.select_batch_folder, fg_color="#3B8ED0", hover_color="#36719F")
        self.btn_batch.grid(row=10, column=0, padx=20, pady=10)

        # Multi-purpose slider (meaning depends on the selected style)
        self.slider_frame = ctk.CTkFrame(self.sidebar_frame, fg_color="transparent")
        self.slider_frame.grid(row=4, column=0, padx=20, pady=(0, 10), sticky="ew")

        self.label_intensity = ctk.CTkLabel(self.slider_frame, text="Intensité :", anchor="w")
        self.label_intensity.pack(fill="x")
        self.slider_intensity = ctk.CTkSlider(self.slider_frame, from_=0.0, to=1.0, number_of_steps=20, command=self.on_slider_change)
        self.slider_intensity.pack(fill="x")
        self.slider_intensity.set(0.5)

        # HD mode switch
        self.switch_hd = ctk.CTkSwitch(self.sidebar_frame, text="Mode HD (Plus lent)", onvalue=True, offvalue=False)
        self.switch_hd.grid(row=9, column=0, padx=20, pady=10)

        # Transform button
        self.btn_process = ctk.CTkButton(self.sidebar_frame, text="Transformer !", command=self.start_processing, fg_color="green", hover_color="darkgreen")
        self.btn_process.grid(row=5, column=0, padx=20, pady=20)

        # Progress bar
        self.progressbar = ctk.CTkProgressBar(self.sidebar_frame, mode="indeterminate")
        self.progressbar.grid(row=6, column=0, padx=20, pady=10)
        self.progressbar.set(0)

        # Export button
        self.btn_save = ctk.CTkButton(self.sidebar_frame, text="Sauvegarder", command=self.save_image, state="disabled")
        self.btn_save.grid(row=7, column=0, padx=20, pady=(20, 10))

        # Video button
        self.btn_video = ctk.CTkButton(self.sidebar_frame, text="Créer Vidéo (MP4)", command=self.create_video_action, state="disabled", fg_color="#E59400", hover_color="#B57600")
        self.btn_video.grid(row=8, column=0, padx=20, pady=(0, 20))

    def create_main_view(self):
        """Build the image area: original | result, plus a hidden HD column."""
        self.main_frame = ctk.CTkFrame(self, corner_radius=10)
        self.main_frame.grid(row=0, column=1, padx=20, pady=20, sticky="nsew")

        # Inner grid (2 columns by default)
        self.main_frame.grid_columnconfigure(0, weight=1)
        self.main_frame.grid_columnconfigure(1, weight=1)
        self.main_frame.grid_rowconfigure(0, weight=1)

        # Original image column
        self.frame_original = ctk.CTkFrame(self.main_frame, fg_color="transparent")
        self.frame_original.grid(row=0, column=0, padx=10, pady=10, sticky="nsew")

        self.label_title_orig = ctk.CTkLabel(self.frame_original, text="Originale", font=("Arial", 14, "bold"))
        self.label_title_orig.pack(pady=5)

        self.label_image_orig = ctk.CTkLabel(self.frame_original, text="Aucune image chargée")
        self.label_image_orig.pack(expand=True, fill="both")

        # Result column (Standard when the HD column is shown)
        self.frame_processed = ctk.CTkFrame(self.main_frame, fg_color="transparent")
        self.frame_processed.grid(row=0, column=1, padx=10, pady=10, sticky="nsew")

        self.label_title_proc = ctk.CTkLabel(self.frame_processed, text="Résultat", font=("Arial", 14, "bold"))
        self.label_title_proc.pack(pady=5)

        self.label_image_proc = ctk.CTkLabel(self.frame_processed, text=self.RESULT_PLACEHOLDER)
        self.label_image_proc.pack(expand=True, fill="both")

        # HD column (3rd column, not gridded until an HD render is requested)
        self.frame_hd = ctk.CTkFrame(self.main_frame, fg_color="transparent")

        self.label_title_hd = ctk.CTkLabel(self.frame_hd, text="Résultat HD", font=("Arial", 14, "bold"))
        self.label_title_hd.pack(pady=5)

        self.label_image_hd = ctk.CTkLabel(self.frame_hd, text="")
        self.label_image_hd.pack(expand=True, fill="both")

        # Refit images whenever the area changes size (bound once, whatever the
        # way the image gets loaded: dialog or drag & drop).
        self.main_frame.bind("<Configure>", self.on_resize)

    def show_hd_view(self):
        """Switch to the 3-column layout (original | standard | HD)."""
        self.main_frame.grid_columnconfigure(2, weight=1)
        self.frame_hd.grid(row=0, column=2, padx=10, pady=10, sticky="nsew")
        self.label_title_proc.configure(text="Standard")
        self.label_title_hd.configure(text="Résultat HD")

    def hide_hd_view(self):
        """Switch back to the 2-column layout (original | result)."""
        self.frame_hd.grid_forget()
        self.main_frame.grid_columnconfigure(2, weight=0)
        self.label_title_proc.configure(text="Résultat")

    def _hd_view_visible(self):
        """True when the HD column is currently laid out."""
        return bool(self.frame_hd.winfo_manager())

    def on_resize(self, event):
        """Debounce <Configure> events: refit the images 100 ms after the last one."""
        if getattr(self, 'resize_timer', None):
            self.after_cancel(self.resize_timer)
        self.resize_timer = self.after(100, self.update_images_size)

    def _available_image_size(self):
        """Return the (width, height) box one image column may use, or None if too small."""
        if not hasattr(self, 'main_frame'):
            return None
        columns = 3 if self._hd_view_visible() else 2
        # Rough allowance for column padding (width) and the title row (height).
        avail_w = self.main_frame.winfo_width() / columns - 40
        avail_h = self.main_frame.winfo_height() - 80
        if avail_w < 100 or avail_h < 100:
            return None  # not laid out yet, or too small to be useful
        return int(avail_w), int(avail_h)

    def update_images_size(self):
        """Refit every displayed image to the current per-column space."""
        self.resize_timer = None
        size = self._available_image_size()
        if size is None:
            return
        for lbl in (self.label_image_orig, self.label_image_proc, self.label_image_hd):
            img = getattr(lbl, 'pil_image', None)
            if img is not None:
                self.display_image(img, lbl, max_size=size)

    @staticmethod
    def _load_image_file(path):
        """Read an image fully into memory so no file handle stays open on it."""
        with Image.open(path) as src:
            return src.copy()

    def display_image(self, image_source, label_widget, max_size=None):
        """Show an image (file path or PIL image) in label_widget, scaled down to fit.

        The unscaled PIL image is stored on the label as ``pil_image`` so it can
        be refitted on resize. ``max_size`` defaults to the space currently
        available per column, or 600x600 before the window is laid out.
        Errors are printed, not raised.
        """
        try:
            if isinstance(image_source, str):
                img = self._load_image_file(image_source)
            else:
                img = image_source

            # Keep the ORIGINAL PIL image for later resizing.
            label_widget.pil_image = img

            if max_size is None:
                max_size = self._available_image_size() or (600, 600)
            max_w, max_h = (int(v) for v in max_size)

            # thumbnail() works in place and preserves the aspect ratio.
            preview = img.copy()
            preview.thumbnail((max_w, max_h), Image.Resampling.LANCZOS)

            ctk_img = ctk.CTkImage(light_image=preview, dark_image=preview, size=preview.size)
            label_widget.configure(image=ctk_img, text="")
            label_widget.image = ctk_img  # keep a reference
        except Exception as e:
            print(f"Error display image: {e}")

    def _clear_image_label(self, label_widget, text=""):
        """Remove the picture shown by label_widget and display text instead."""
        label_widget.pil_image = None  # so update_images_size won't redraw it
        label_widget.image = None
        label_widget.configure(image=self._blank_image, text=text)

    # ------------------------------------------------------------- processing

    def _snapshot_settings(self):
        """Capture the current UI settings (must run on the main Tk thread)."""
        return {
            "image_path": self.original_image_path,
            "style": self.style_var.get(),
            "hd": bool(self.switch_hd.get()),
            "intensity": self.slider_intensity.get(),
        }

    def _get_engine(self):
        """Return the shared CartoonEngine, creating it on first use."""
        if getattr(self, 'engine', None) is None:
            from engine import CartoonEngine  # project module, imported on demand
            self.engine = CartoonEngine()
        return self.engine

    def _run_style(self, engine, input_path, out_path, style, intensity, hd_mode):
        """Render input_path into out_path with the given style.

        ``intensity`` is the slider value: smoothing for Comics, pixel size for
        Pixel Art, orange-correction strength for Paprika/Ghibli.
        Returns the untinted intermediate file for orange-corrected styles,
        None otherwise. Engine errors propagate to the caller.
        """
        if style == "Comics (OpenCV)":
            engine.process_opencv(input_path, out_path, intensity=intensity)
        elif style == "Crayon (Sketch)":
            engine.process_sketch(input_path, out_path)
        elif style == "Pixel Art":
            block_size = int(5 + intensity * 25)
            engine.process_pixel_art(input_path, out_path, block_size=block_size)
        elif style.startswith(self.CUSTOM_PREFIX):
            engine.process_torch(input_path, out_path, self._custom_model_path(style), hd_mode=hd_mode)
        elif self._is_orange_style(style):
            # Render raw first, then apply the colour correction on a copy, so the
            # slider can later re-tint from the raw file without re-rendering.
            base, ext = os.path.splitext(out_path)
            raw_path = f"{base}_raw{ext}"
            engine.process_torch(input_path, raw_path, style, orange_correction=None, hd_mode=hd_mode)
            engine.reduce_orange_tint(raw_path, out_path, strength=intensity)
            return raw_path
        else:
            # Other built-in Torch styles
            engine.process_torch(input_path, out_path, style, hd_mode=hd_mode)
        return None

    def start_processing(self):
        """Start rendering the loaded image on a worker thread."""
        if not self.original_image_path:
            messagebox.showwarning("Attention", "Veuillez d'abord importer une image.")
            return

        # Read every Tk value here, on the main thread; the worker gets a copy.
        settings = self._snapshot_settings()

        self._busy = True
        self.progressbar.start()
        self.btn_process.configure(state="disabled", fg_color="red", hover_color="darkred")
        self.btn_batch.configure(state="disabled")  # one engine user at a time

        # Thread so the interface does not freeze
        threading.Thread(target=self.process_image_thread, args=(settings,), daemon=True).start()

    def process_image_thread(self, settings=None):
        """Worker-thread body: render the image with the chosen style.

        In HD mode a standard render is produced first, then the HD one; each
        result is pushed to its column as soon as it is ready. ``settings`` is
        the snapshot taken by ``start_processing``; when omitted it is read
        here, which is only safe from the main thread.
        """
        if settings is None:
            settings = self._snapshot_settings()
        source = settings["image_path"]
        style = settings["style"]
        is_hd = settings["hd"]
        intensity = settings["intensity"]
        print(f"Style sélectionné: {style}")

        output_dir = "output"
        filename = os.path.basename(source)

        # Each task: render mode, output file, label to update, key in current_outputs.
        if is_hd:
            self.after(0, self.show_hd_view)
            tasks = [
                {"hd": False, "out": os.path.join(output_dir, f"cartoon_std_{filename}"),
                 "display": self.label_image_proc, "raw_prefix": "std"},
                {"hd": True, "out": os.path.join(output_dir, f"cartoon_hd_{filename}"),
                 "display": self.label_image_hd, "raw_prefix": "hd"},
            ]
        else:
            self.after(0, self.hide_hd_view)
            # Simple mode keeps the original output naming.
            tasks = [
                {"hd": False, "out": os.path.join(output_dir, f"cartoon_{filename}"),
                 "display": self.label_image_proc, "raw_prefix": "std"},
            ]

        outputs = {}
        best_image = None
        best_path = None
        try:
            os.makedirs(output_dir, exist_ok=True)
            engine = self._get_engine()
            for task in tasks:
                raw_path = self._run_style(engine, source, task["out"], style, intensity, task["hd"])
                if raw_path is not None:
                    outputs[task["raw_prefix"]] = {"raw": raw_path, "final": task["out"]}

                result = self._load_image_file(task["out"])
                # Widgets may only be touched from the main thread.
                self.after(0, lambda img=result, lbl=task["display"]: self.display_image(img, lbl))
                # Tasks run standard first, so the last one is the best available.
                best_image, best_path = result, task["out"]
        except Exception as e:
            print(f"Erreur de traitement: {e}")
            best_image = None
            self.after(0, lambda msg=str(e): messagebox.showerror("Erreur", f"Une erreur est survenue :\n{msg}"))

        # Plain attribute writes; the UI only reads them after finish_processing runs.
        self.current_outputs = outputs
        self.processed_image = best_image
        if best_image is not None:
            self.current_final_output = best_path  # for the video
        self.after(0, self.finish_processing)

    def finish_processing(self):
        """Restore the UI once the worker is done (main thread)."""
        self._busy = False
        self.progressbar.stop()
        self.progressbar.set(1)  # Full
        self.btn_process.configure(state="normal", fg_color="green", hover_color="darkgreen")
        self.btn_batch.configure(state="normal")

        has_result = self.processed_image is not None
        self.btn_save.configure(state="normal" if has_result else "disabled")
        self.btn_video.configure(state="normal" if has_result else "disabled")

        # Results were already shown per column; refit them to the current layout.
        self.update_images_size()

    def on_slider_change(self, value):
        """Live re-tint of the Paprika/Ghibli results (std and HD) when the slider moves."""
        if self._busy or self.engine is None or not self.current_outputs:
            return
        if not self._is_orange_style(self.style_var.get()):
            return
        labels = {"std": self.label_image_proc, "hd": self.label_image_hd}
        try:
            for key, data in self.current_outputs.items():
                raw = data["raw"]
                final = data["final"]
                if not os.path.exists(raw):
                    continue
                self.engine.reduce_orange_tint(raw, final, strength=value)
                label = labels.get(key)
                if label is not None:
                    self.update_processed_image_display(final, label)
        except Exception as e:
            print(f"Error realtime update: {e}")

    def update_processed_image_display(self, image_path, label_widget=None):
        """Reload a result file from disk and show it (default: the result column)."""
        if label_widget is None:
            label_widget = self.label_image_proc

        try:
            # Re-read from disk to get the freshly written pixels.
            img = self._load_image_file(image_path)

            # processed_image tracks the best result: HD when the HD column is
            # shown, otherwise the standard one.
            if label_widget == self.label_image_hd:
                self.processed_image = img
            elif label_widget == self.label_image_proc and not self._hd_view_visible():
                self.processed_image = img

            self.display_image(img, label_widget)
            label_widget.configure(text="")
            self.update_idletasks()  # repaint now, while the slider is dragged
        except Exception as e:
            print(f"Display update failed: {e}")

    # ---------------------------------------------------------------- exports

    @staticmethod
    def _open_path(path):
        """Open a file or folder with the OS default application (best effort)."""
        import subprocess
        import sys
        try:
            if sys.platform.startswith("win"):
                os.startfile(path)  # Windows only
            elif sys.platform == "darwin":
                subprocess.Popen(["open", path])
            else:
                subprocess.Popen(["xdg-open", path])
        except OSError as e:
            print(f"Could not open {path}: {e}")

    def create_video_action(self):
        """Generate an MP4 revealing the transformation (original -> best result)."""
        target_img_path = self.current_final_output
        if self.processed_image is None or not target_img_path:
            messagebox.showinfo("Info", "La vidéo nécessite une image générée récente.")
            return

        file_path = filedialog.asksaveasfilename(defaultextension=".mp4", filetypes=[("MP4 Video", "*.mp4")])
        if not file_path:
            return

        if not os.path.exists(target_img_path):
            messagebox.showerror("Erreur", "Image source introuvable pour la vidéo.")
            return

        self.btn_video.configure(state="disabled", text="Génération...")
        self.update_idletasks()
        try:
            self._get_engine().create_reveal_video(self.original_image_path, target_img_path, file_path)
        except Exception as e:
            messagebox.showerror("Erreur", f"Echec vidéo : {e}")
        else:
            messagebox.showinfo("Succès", f"Vidéo générée avec succès !\n{file_path}")
            self._open_path(file_path)
        finally:
            self.btn_video.configure(state="normal", text="Créer Vidéo (MP4)")

    def save_image(self):
        """Save the best result to a user-chosen JPEG/PNG file."""
        if self.processed_image is None:
            return
        file_path = filedialog.asksaveasfilename(defaultextension=".jpg", filetypes=[("JPEG", "*.jpg"), ("PNG", "*.png")])
        if not file_path:
            return

        image = self.processed_image
        # JPEG cannot store alpha or palette modes (e.g. RGBA/P from a PNG source).
        if file_path.lower().endswith((".jpg", ".jpeg")) and image.mode not in ("RGB", "L", "CMYK"):
            image = image.convert("RGB")
        try:
            image.save(file_path)
        except (OSError, ValueError) as e:
            messagebox.showerror("Erreur", f"Impossible de sauvegarder l'image : {e}")
            return
        messagebox.showinfo("Succès", "Image sauvegardée avec succès !")

    def change_style_event(self, new_style):
        """Adapt the slider (label, default value, visibility) to the selected style."""
        if new_style == "Comics (OpenCV)":
            label_text, default = "Intensité Lissage :", 0.5
        elif new_style == "Pixel Art":
            label_text, default = "Taille des Pixels :", 0.5
        elif self._is_orange_style(new_style):
            label_text, default = "Correction Orange :", 0.85  # default correction strength
        else:
            self.slider_frame.grid_remove()
            return
        self.slider_frame.grid()
        self.label_intensity.configure(text=label_text)
        self.slider_intensity.set(default)

    # ------------------------------------------------------------ persistence

    def load_config(self):
        """Restore geometry, style and slider value from config_file, if present."""
        if not os.path.exists(self.config_file):
            return
        try:
            with open(self.config_file, 'r', encoding="utf-8") as f:
                config = json.load(f)
            w = int(config.get("width", 1100))
            h = int(config.get("height", 700))
            x = int(config.get("x", 100))
            y = int(config.get("y", 100))
            self.geometry(f"{w}x{h}+{x}+{y}")

            self.saved_style = config.get("style", self.DEFAULT_STYLE)
            self.saved_slider = config.get("slider", 0.5)
        except Exception as e:
            print(f"Erreur chargement config: {e}")

    def save_config(self):
        """Persist geometry, style and slider value to config_file (errors are printed)."""
        try:
            config = {
                "style": self.style_var.get(),
                "slider": self.slider_intensity.get(),
            }
            # A minimized window can report off-screen coordinates; only keep
            # the geometry when the window is actually displayed.
            if self.state() != "iconic":
                config.update(
                    width=self.winfo_width(),
                    height=self.winfo_height(),
                    x=self.winfo_x(),
                    y=self.winfo_y(),
                )
            with open(self.config_file, 'w', encoding="utf-8") as f:
                json.dump(config, f)
        except Exception as e:
            print(f"Erreur sauvegarde config: {e}")

    def on_closing(self):
        """Save the configuration, then close the window."""
        self.save_config()
        self.destroy()

    def import_custom_model(self):
        """Copy a user-selected .pt model into custom_models_dir and select it."""
        file_path = filedialog.askopenfilename(filetypes=[("PyTorch Model", "*.pt")])
        if not file_path:
            return
        try:
            # Copy the file into the custom models folder
            filename = os.path.basename(file_path)
            dest = os.path.join(self.custom_models_dir, filename)
            shutil.copy2(file_path, dest)

            # Refresh the menu
            self.option_style.configure(values=self.get_style_list())

            # Select the newly added style
            style_name = f"{self.CUSTOM_PREFIX}{os.path.splitext(filename)[0]}"
            self.style_var.set(style_name)
            self.change_style_event(style_name)

            messagebox.showinfo("Succès", f"Style '{style_name}' importé avec succès !")
        except Exception as e:
            messagebox.showerror("Erreur", f"Echec de l'import : {e}")

    # ------------------------------------------------------------------ batch

    def select_batch_folder(self):
        """Ask for a folder and batch-process the images it contains."""
        folder_path = filedialog.askdirectory(title="Sélectionner le dossier contenant les images")
        if folder_path:
            self.start_batch_processing(folder_path)

    def start_batch_processing(self, input_folder):
        """Confirm, then process every image of input_folder on a worker thread."""
        style = self.style_var.get()
        if not messagebox.askyesno("Confirmation", f"Traiter toutes les images de :\n{input_folder}\n\nAvec le style : {style} ?"):
            return

        settings = self._snapshot_settings()  # main thread only

        self._busy = True
        self.progressbar.configure(mode="determinate")
        self.progressbar.set(0)
        self.btn_process.configure(state="disabled")
        self.btn_batch.configure(state="disabled")
        self.btn_import.configure(state="disabled")

        threading.Thread(target=self.process_batch_thread, args=(input_folder, settings), daemon=True).start()

    def process_batch_thread(self, input_folder, settings=None):
        """Worker-thread body: render every image of input_folder.

        Results go to a new ``cartoon_batch_<timestamp>`` sub-folder. Per-file
        errors are printed and skipped; ``settings`` is the snapshot taken by
        ``start_batch_processing`` (read here only if omitted).
        """
        import time

        if settings is None:
            settings = self._snapshot_settings()
        style = settings["style"]
        hd_mode = settings["hd"]
        intensity = settings["intensity"]

        try:
            files = sorted(
                name for name in os.listdir(input_folder)
                if name.lower().endswith(self.IMAGE_EXTENSIONS)
                and os.path.isfile(os.path.join(input_folder, name))
            )
            if files:
                engine = self._get_engine()
                # Output folder is only created when there is something to write.
                timestamp = time.strftime("%Y%m%d-%H%M%S")
                output_dir = os.path.join(input_folder, f"cartoon_batch_{timestamp}")
                os.makedirs(output_dir, exist_ok=True)
        except Exception as e:
            print(f"Batch setup failed: {e}")
            self.after(0, self.finish_batch)
            self.after(0, lambda msg=str(e): messagebox.showerror("Erreur", f"Une erreur est survenue :\n{msg}"))
            return

        total_files = len(files)
        if total_files == 0:
            self.after(0, self.finish_batch)
            self.after(0, lambda: messagebox.showinfo("Info", "Aucune image trouvée dans ce dossier."))
            return

        print(f"Batch processing {total_files} images...")

        processed = 0
        failed = 0
        for index, filename in enumerate(files):
            self.after(0, lambda p=index / total_files: self.progressbar.set(p))
            input_path = os.path.join(input_folder, filename)
            output_path = os.path.join(output_dir, f"cartoon_{filename}")
            try:
                raw_path = self._run_style(engine, input_path, output_path, style, intensity, hd_mode)
                # The untinted intermediate only matters for the live slider.
                if raw_path is not None and os.path.exists(raw_path):
                    os.remove(raw_path)
                processed += 1
            except Exception as e:
                failed += 1
                print(f"Error processing {filename}: {e}")

        self.after(0, lambda d=output_dir, c=processed, f=failed: self.finish_batch(d, c, f))

    def finish_batch(self, output_dir=None, count=0, failed=0):
        """Restore the UI after a batch and report the result (main thread)."""
        self._busy = False
        self.progressbar.set(1)
        self.btn_process.configure(state="normal")
        self.btn_batch.configure(state="normal")
        self.btn_import.configure(state="normal")
        self.progressbar.configure(mode="indeterminate")

        if output_dir:
            message = f"Traitement par lot terminé !\n{count} images traitées.\nSauvegardées dans :\n{output_dir}"
            if failed:
                message += f"\n{failed} image(s) en échec."
            messagebox.showinfo("Terminé", message)
            self._open_path(output_dir)

if __name__ == "__main__":
    # Script entry point: build the main window and hand control to Tk's event loop.
    CartoonApp().mainloop()
