import pygame
import threading
import time
import os
import wave
import struct
import copy
import json
import shutil

try:
    import numpy as np
    import pygame.sndarray
    HAS_NUMPY = True
except ImportError:
    HAS_NUMPY = False

class AudioController:
    def __init__(self, model):
        """Set up the mixer, the playback state and the model's initial kit.

        Args:
            model: Sequencer model. The constructor reads
                ``model.current_kit`` and hands the loaded track names to
                ``model.set_tracks``.

        If the mixer cannot be initialised the error is reported and the
        controller is still built, with no channels available.
        """
        self.model = model
        # Guards access to the mixer channels. Created first so that it
        # exists no matter which method gets called next.
        self.lock = threading.Lock()

        num_channels = 32
        try:
            pygame.mixer.init(frequency=44100, size=-16, channels=2, buffer=512)
            # Raise the number of mixer channels.
            pygame.mixer.set_num_channels(num_channels)
            # Pre-allocate one Channel object per mixer channel.
            self.channels = [pygame.mixer.Channel(i) for i in range(num_channels)]
        except pygame.error:
            print("Erreur: Impossible d'initialiser pygame.mixer")
            # Channel objects cannot be created without a working mixer;
            # continue with none instead of failing right after the message.
            self.channels = []

        self.samples = []
        self.original_samples = []  # Unmodified sounds, source for pitch resampling
        self.pitch_cache = {}
        self.is_playing = False
        self.current_step = 0
        self.thread = None
        # Two directory levels above this file. Per the original note the file
        # lives at PyBeat/src/audio.py, which makes this the PyBeat folder.
        self.base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        names = self._load_kit(model.current_kit)
        self.model.set_tracks(names)

    def _get_config_path(self, kit_name):
        """Return the path of ``kit_config.json`` inside the given kit folder."""
        kit_dir = os.path.join(self.base_path, "assets", "kits", kit_name)
        return os.path.join(kit_dir, "kit_config.json")

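    # NOTE: leftover patch marker ("skip to update_volumes"). The method that
    # follows is defined a second time further down in this class.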

    def update_volumes(self):
        """Apply the model's volume and pan settings to the mixer channels.

        For every track that has a channel, the gain is
        ``track_volume * master_volume``; gains below 0.005 are snapped to
        silence. A positive pan attenuates the left side, a negative pan
        attenuates the right side. The whole update runs while holding
        ``self.lock``.

        NOTE: this class defines ``update_volumes`` a second time further
        down; being the later definition, that one is what the class exposes.
        """
        with self.lock:
            for i in range(min(self.model.num_tracks, len(self.channels))):
                vol = self.model.track_volumes[i] * self.model.master_volume
                if vol < 0.005:
                    vol = 0.0

                # Tracks without a pan entry are centred.
                pan = 0.0
                if i < len(self.model.track_pan):
                    pan = self.model.track_pan[i]

                left_vol = vol
                right_vol = vol

                if pan > 0.0:
                    left_vol *= (1.0 - pan)
                elif pan < 0.0:
                    right_vol *= (1.0 + pan)

                try:
                    self.channels[i].set_volume(left_vol, right_vol)
                except Exception as e:
                    # Best effort: a failing channel must not stop the others,
                    # but the failure is reported instead of being hidden.
                    print(f"Erreur volume canal {i}: {e}")

    def _save_kit_config(self, kit_name, tracks_data):
        """Write the kit's track list to its ``kit_config.json``.

        The file holds one JSON object, ``{"tracks": tracks_data}``, indented
        by four spaces. Any failure is printed and otherwise ignored.
        """
        payload = {"tracks": tracks_data}
        try:
            target = self._get_config_path(kit_name)
            with open(target, 'w') as handle:
                json.dump(payload, handle, indent=4)
        except Exception as err:
            print(f"Erreur saving kit config: {err}")

    def _load_kit(self, kit_name):
        """Load the samples of a kit and return the matching track names.

        The track list comes from the kit's ``kit_config.json``. When that
        file is missing, unreadable or empty, the kit folder is scanned for
        ``.wav`` files (sorted by name) and, if any are found, a default
        config is written.

        Entries that have no usable ``"file"`` value, or whose sample cannot
        be loaded, are reported and skipped, so ``self.samples`` and the
        returned names always have the same length and order. An entry
        without a ``"name"`` gets a name derived from its file name. If
        nothing could be loaded, a single silent track is returned.

        Args:
            kit_name: Folder name of the kit under ``assets/kits``.

        Returns:
            list: Track names, one per entry of ``self.samples``.
        """
        kit_path = os.path.join(self.base_path, "assets", "kits", kit_name)
        cfg_path = self._get_config_path(kit_name)

        tracks_data = []  # List of dicts {name, file}

        if os.path.exists(cfg_path):
            try:
                with open(cfg_path, 'r') as f:
                    data = json.load(f)
                    tracks_data = data.get("tracks", [])
            except Exception as e:
                print(f"Error loading kit config: {e}")
            if not isinstance(tracks_data, list):
                tracks_data = []

        # No config or an empty one: fall back to scanning the folder.
        if not tracks_data and os.path.exists(kit_path):
            files = sorted(f for f in os.listdir(kit_path) if f.lower().endswith(".wav"))
            for f in files:
                name = os.path.splitext(f)[0].replace("_", " ").title()
                tracks_data.append({"name": name, "file": f})

            # Create the default config if we found files.
            if files:
                self._save_kit_config(kit_name, tracks_data)

        temp_samples = []
        names = []
        for track in tracks_data:
            file_name = track.get("file") if isinstance(track, dict) else None
            if not isinstance(file_name, str) or not file_name:
                print(f"Warning: kit entry without a sample file skipped: {track!r}")
                continue

            full_path = os.path.join(kit_path, file_name)
            try:
                sound = pygame.mixer.Sound(full_path)
            except Exception as e:
                print(f"Erreur lors du chargement de {full_path} : {e}")
                continue

            name = track.get("name")
            if name is None:
                name = os.path.splitext(file_name)[0].replace("_", " ").title()

            # Append both together so the two lists can never drift apart.
            temp_samples.append(sound)
            names.append(name)

        if not temp_samples:
            print(f"Warning: No samples found in {kit_path}")
            names = ["Silence"]
            temp_samples = [None]

        # Publish the new samples by rebinding the attributes in one go.
        self.samples = temp_samples
        self.original_samples = list(temp_samples)  # Backup copy
        self.pitch_cache = {}  # Reset cache
        return names

    def update_kit(self, kit_name):
        """Reload the samples of ``kit_name`` and return its track names."""
        return self._load_kit(kit_name)

    def add_track_to_kit(self, kit_name, source_file, track_name):
        """Copy a sample into a kit, register it in the config and reload.

        The file is copied into the kit folder unless a file with the same
        name is already there. The new track is appended to the end of the
        kit's track list.

        Args:
            kit_name: Folder name of the kit under ``assets/kits``.
            source_file: Path of the sample to add.
            track_name: Display name stored for the new track.

        Returns:
            list: Track names of the reloaded kit.
        """
        kit_path = os.path.join(self.base_path, "assets", "kits", kit_name)
        filename = os.path.basename(source_file)
        dest_path = os.path.join(kit_path, filename)

        # Copy file
        if not os.path.exists(dest_path):
            shutil.copy(source_file, dest_path)

        # Update config
        cfg_path = self._get_config_path(kit_name)
        tracks = []
        if os.path.exists(cfg_path):
            with open(cfg_path, 'r') as f:
                data = json.load(f)
            # A config without a "tracks" list is treated as empty, the same
            # way _load_kit reads it, instead of raising KeyError.
            if isinstance(data, dict) and isinstance(data.get("tracks"), list):
                tracks = data["tracks"]

        tracks.append({"name": track_name, "file": filename})
        self._save_kit_config(kit_name, tracks)

        return self.update_kit(kit_name)

    def remove_track_from_kit(self, kit_name, index):
        """Remove the track at ``index`` from the kit config and reload.

        Only the config entry is removed; the sample file is left in place.
        Nothing is written when the config is missing, has no track list, or
        ``index`` is out of range.

        Returns:
            list: Track names of the reloaded kit.
        """
        cfg_path = self._get_config_path(kit_name)
        if os.path.exists(cfg_path):
            with open(cfg_path, 'r') as f:
                data = json.load(f)

            # A config without a "tracks" list has nothing to remove; treat
            # it as empty, the same way _load_kit reads it.
            tracks = data.get("tracks") if isinstance(data, dict) else None
            if isinstance(tracks, list) and 0 <= index < len(tracks):
                del tracks[index]
                self._save_kit_config(kit_name, tracks)

        return self.update_kit(kit_name)

    def rename_track_in_kit(self, kit_name, index, new_name):
        """Change the display name of the track at ``index`` and reload.

        Nothing is written when the config is missing, has no track list,
        ``index`` is out of range, or the entry is not a JSON object.

        Returns:
            list: Track names of the reloaded kit.
        """
        cfg_path = self._get_config_path(kit_name)
        if os.path.exists(cfg_path):
            with open(cfg_path, 'r') as f:
                data = json.load(f)

            # A config without a "tracks" list has nothing to rename; treat
            # it as empty, the same way _load_kit reads it.
            tracks = data.get("tracks") if isinstance(data, dict) else None
            if (isinstance(tracks, list)
                    and 0 <= index < len(tracks)
                    and isinstance(tracks[index], dict)):
                tracks[index]["name"] = new_name
                self._save_kit_config(kit_name, tracks)

        return self.update_kit(kit_name)

    def replace_sample(self, track_idx, file_path):
        """Swap in a new sample for one track.

        Args:
            track_idx: Index of the track whose sample is replaced.
            file_path: Path of the audio file to load.

        Returns:
            tuple: ``(True, name)`` where ``name`` is the title-cased file
            stem, meant for display by the caller, or ``(False, message)``
            when the file is missing or anything goes wrong while loading.
            An index past the end of the sample list changes nothing but is
            still reported as success.
        """
        if not os.path.exists(file_path):
            return False, "Fichier introuvable"

        try:
            new_sound = pygame.mixer.Sound(file_path)
            if track_idx < len(self.samples):
                self.samples[track_idx] = new_sound
                # Keep the pitch-resampling source in step with the new sample.
                if track_idx < len(self.original_samples):
                    self.original_samples[track_idx] = new_sound
                # Pitched variants cached for this track are now stale.
                stale_keys = [key for key in self.pitch_cache if key[0] == track_idx]
                for key in stale_keys:
                    self.pitch_cache.pop(key)

            stem, _extension = os.path.splitext(os.path.basename(file_path))
            return True, stem.title()
        except Exception as err:
            print(f"Erreur chargement sample: {err}")
            return False, str(err)

    def start_playback(self):
        """Start the playback loop in a daemon thread unless already playing."""
        if self.is_playing:
            return

        self.is_playing = True
        worker = threading.Thread(target=self._run_loop, daemon=True)
        self.thread = worker
        worker.start()

    def stop_playback(self):
        """Clear the playing flag and rewind to the first step."""
        self.is_playing, self.current_step = False, 0

    def update_volumes(self):
        """Apply the model's volume and pan settings to the mixer channels.

        For every track that has a channel, the gain is
        ``track_volume * master_volume``; gains below 0.005 are snapped to
        silence. A positive pan attenuates the left side, a negative pan
        attenuates the right side.

        The channels are updated while holding ``self.lock``, the same lock
        the playback loop takes around its own channel calls.
        """
        with self.lock:
            for i in range(min(self.model.num_tracks, len(self.channels))):
                vol = self.model.track_volumes[i] * self.model.master_volume
                if vol < 0.005:
                    vol = 0.0

                # Tracks without a pan entry are centred.
                pan = 0.0
                if i < len(self.model.track_pan):
                    pan = self.model.track_pan[i]

                left_vol = vol
                right_vol = vol

                if pan > 0.0:
                    left_vol *= (1.0 - pan)
                elif pan < 0.0:
                    right_vol *= (1.0 + pan)

                self.channels[i].set_volume(left_vol, right_vol)

    def export_wav(self, filename, loops=1):
        """Render the current pattern to a 16-bit stereo 44.1 kHz WAV file.

        Tracks are matched to sample files in the same order ``_load_kit``
        uses (``kit_config.json`` first, sorted folder scan as fallback).
        Every active grid cell adds its sample to the mix, scaled by track
        volume, master volume, the cell's velocity and the track's pan.
        Mute and solo states are honoured. The mix is clipped to the 16-bit
        range when written.

        Limitations: only 16-bit PCM WAV samples are read (others are
        reported and rendered as silence); samples are not resampled, so they
        are assumed to be 44.1 kHz; swing and pitch are not rendered.

        Args:
            filename: Path of the WAV file to write.
            loops: Number of times the pattern is repeated.

        Returns:
            bool: ``True`` on success; ``False`` when the kit has no sample
            files or the output file could not be written.
        """
        print(f"Exportation vers {filename} (Loops: {loops})...")

        sr = 44100
        kit_name = self.model.current_kit
        kit_path = os.path.join(self.base_path, "assets", "kits", kit_name)

        # 1. Load the raw PCM data of each track, in playback track order.
        sample_data = []  # One (left, right) pair per track, or None
        for file_name in self._export_track_files(kit_name, kit_path):
            full_path = os.path.join(kit_path, file_name)
            if not os.path.isfile(full_path):
                # _load_kit drops entries it cannot load, so a missing file
                # does not occupy a track index during playback either.
                continue
            try:
                sample_data.append(self._read_wav_channels(full_path))
            except (OSError, EOFError, ValueError, wave.Error, struct.error) as e:
                print(f"Erreur lecture Wav {file_name}: {e}")
                # Keep the slot so later tracks stay aligned.
                sample_data.append(None)

        if not sample_data:
            return False

        # 2. Timing and mix buffers.
        effective_steps = self.model.get_effective_steps()
        bpm = max(1, self.model.bpm)  # Guard against division by zero
        step_duration_sec = 60.0 / bpm / 4.0  # 16th note
        pattern_duration_sec = step_duration_sec * effective_steps
        total_samples = max(0, int(pattern_duration_sec * loops * sr))

        # Plain Python ints: sums cannot overflow before the final clipping.
        mix_left = [0] * total_samples
        mix_right = [0] * total_samples

        # 3. Per-track gains are constant for the whole render.
        is_solo_mode = any(self.model.solo_states)
        track_count = min(self.model.num_tracks, len(sample_data))
        gains = []
        for r in range(track_count):
            audible = (not self.model.mute_states[r]) and \
                      ((not is_solo_mode) or self.model.solo_states[r])
            if sample_data[r] is None or not audible:
                gains.append(None)
                continue

            vol = self.model.track_volumes[r] * self.model.master_volume
            pan = self.model.track_pan[r] if r < len(self.model.track_pan) else 0.0
            left_gain = right_gain = vol
            if pan > 0.0:
                left_gain *= (1.0 - pan)
            elif pan < 0.0:
                right_gain *= (1.0 + pan)
            gains.append((left_gain, right_gain))

        # 4. Mix every active cell of every loop.
        grid = self.model.grid
        for loop in range(loops):
            loop_offset_sec = loop * pattern_duration_sec

            for step in range(effective_steps):
                # Absolute frame index at which this step starts.
                start = int((loop_offset_sec + (step * step_duration_sec)) * sr)

                for r in range(track_count):
                    if gains[r] is None:
                        continue
                    if r >= len(grid) or step >= len(grid[r]):
                        continue
                    cell = grid[r][step]
                    if not cell:
                        continue
                    # Cells hold True or a velocity; True counts as 1.0.
                    velocity = float(cell)
                    if velocity <= 0.0:
                        continue

                    left, right = sample_data[r]
                    left_gain = gains[r][0] * velocity
                    right_gain = gains[r][1] * velocity

                    # Samples running past the end of the file are cut off.
                    count = min(len(left), len(right), total_samples - start)
                    for k in range(count):
                        mix_left[start + k] += int(left[k] * left_gain)
                        mix_right[start + k] += int(right[k] * right_gain)

        # 5. Clip, interleave and write the result.
        interleaved = [0] * (total_samples * 2)
        interleaved[0::2] = [max(-32768, min(32767, v)) for v in mix_left]
        interleaved[1::2] = [max(-32768, min(32767, v)) for v in mix_right]

        try:
            with wave.open(filename, 'w') as out:
                out.setnchannels(2)
                out.setsampwidth(2)
                out.setframerate(sr)
                out.writeframes(struct.pack(f"<{len(interleaved)}h", *interleaved))
            print("Export terminé.")
            return True
        except Exception as e:
            print(f"Erreur export: {e}")
            return False

    def _export_track_files(self, kit_name, kit_path):
        """Return the kit's sample file names in playback track order."""
        file_names = []
        cfg_path = self._get_config_path(kit_name)
        if os.path.exists(cfg_path):
            try:
                with open(cfg_path, 'r') as f:
                    tracks = json.load(f).get("tracks", [])
                file_names = [
                    t["file"] for t in tracks
                    if isinstance(t, dict) and isinstance(t.get("file"), str)
                ]
            except (OSError, ValueError, AttributeError, TypeError) as e:
                print(f"Error loading kit config: {e}")
                file_names = []

        # No usable config: same sorted folder scan as _load_kit.
        if not file_names and os.path.exists(kit_path):
            file_names = sorted(
                f for f in os.listdir(kit_path) if f.lower().endswith(".wav")
            )
        return file_names

    @staticmethod
    def _read_wav_channels(path):
        """Read a 16-bit PCM WAV file and return its (left, right) samples."""
        with wave.open(path, 'rb') as w:
            n_channels = w.getnchannels()
            sample_width = w.getsampwidth()
            raw = w.readframes(w.getnframes())

        if sample_width != 2:
            raise ValueError(f"unsupported sample width: {sample_width} bytes")

        values = struct.unpack(f"<{len(raw) // 2}h", raw)
        if n_channels == 1:
            # Mono: the same samples feed both output channels.
            return values, values
        # Interleaved frames: channel 0 is left, channel 1 is right.
        return values[0::n_channels], values[1::n_channels]

    def _run_loop(self):
        """Playback loop: trigger the active steps until ``is_playing`` is cleared.

        Each pass handles one step: it computes the step length from the
        model's BPM and swing, optionally advances the song-mode playlist
        (only when the current step is 0), plays every active and audible
        track of the step on its channel, then advances ``current_step`` and
        sleeps for the step length (at least 5 ms). While the model reports
        ``is_loading`` the loop only sleeps. ``start_playback`` runs this
        method in a daemon thread.
        """
        while self.is_playing:
            if self.model.is_loading:
                time.sleep(0.1)
                continue

            effective_steps = self.model.get_effective_steps()
            # Guard against division by zero
            bpm = max(1, self.model.bpm)
            # One step lasts a quarter of a beat
            step_duration = 60.0 / bpm / 4.0

            # Apply SWING
            # Even steps: longer (1 + swing)
            # Odd steps: shorter (1 - swing)
            if self.current_step % 2 == 0:
                actual_duration = step_duration * (1.0 + self.model.swing)
            else:
                actual_duration = step_duration * (1.0 - self.model.swing)

            with self.model.grid_lock:
                # Check whether the pattern must be switched (SONG MODE logic)
                # The switch happens at the start of the bar (step 0), BEFORE playing
                if self.current_step == 0 and self.model.song_mode and self.model.playlist:
                    self.model.song_index = (self.model.song_index + 1) % len(self.model.playlist)
                    item = self.model.playlist[self.model.song_index]

                    # Support legacy (list of grids) vs new (dict)
                    if isinstance(item, dict):
                        self.model.grid = copy.deepcopy(item['grid'])
                        self.model.current_style_name = item.get('name', 'Unknown')
                    else:
                        self.model.grid = copy.deepcopy(item)

                    self.model.grid_has_changed = True

                # Check the global Solo mode: true as soon as any track is soloed
                is_solo_mode = any(self.model.solo_states)

                # Play the current step
                for r in range(self.model.num_tracks):
                    if r < len(self.model.grid) and self.current_step < len(self.model.grid[r]):
                        # Mute/Solo logic: muted tracks never play; in solo
                        # mode only soloed tracks play
                        should_play = (not self.model.mute_states[r]) and \
                                      ((not is_solo_mode) or self.model.solo_states[r])

                        if self.model.grid[r][self.current_step] and should_play:
                            if r < len(self.channels) and r < len(self.samples) and self.samples[r]:
                                # 1. Velocity: the cell holds a bool or a number
                                val = self.model.grid[r][self.current_step]
                                if val is True: val = 1.0
                                if val is False: val = 0.0

                                if val > 0.0:
                                    # 2. Base volume; values below 0.005 are snapped to silence
                                    base_vol = self.model.track_volumes[r] * self.model.master_volume * val
                                    if base_vol < 0.005: base_vol = 0.0

                                    # 3. Pan: positive attenuates left, negative attenuates right
                                    pan = 0.0
                                    if r < len(self.model.track_pan):
                                        pan = self.model.track_pan[r]

                                    left_vol = base_vol
                                    right_vol = base_vol

                                    if pan > 0.0:
                                        left_vol *= (1.0 - pan)
                                    elif pan < 0.0:
                                        right_vol *= (1.0 + pan)

                                    # 4. Pitch processing
                                    current_sound = self.samples[r]
                                    pitch = 1.0

                                    # Check Per-Step Pitch Override (V9)
                                    step_pitch = 0.0
                                    if r < len(self.model.grid_pitch) and self.current_step < len(self.model.grid_pitch[r]):
                                        step_pitch = self.model.grid_pitch[r][self.current_step]

                                    # A positive per-step pitch wins over the track pitch
                                    if step_pitch > 0.0:
                                        pitch = step_pitch
                                    elif r < len(self.model.track_pitch):
                                        pitch = self.model.track_pitch[r]

                                    # Resample only if the pitch differs from 1.0 by more
                                    # than 0.05 and NumPy is available
                                    if abs(pitch - 1.0) > 0.05 and HAS_NUMPY and r < len(self.original_samples) and self.original_samples[r]:
                                        # Check cache
                                        cache_key = (r, round(pitch, 2)) # Round to avoid float cache miss
                                        if cache_key in self.pitch_cache:
                                            current_sound = self.pitch_cache[cache_key]
                                        else:
                                            # Generate
                                            try:
                                                snd_array = pygame.sndarray.array(self.original_samples[r])
                                                # Resample: New length = Old length / pitch
                                                new_len = int(len(snd_array) / pitch)
                                                # Indices array
                                                indices = np.linspace(0, len(snd_array) - 1, new_len)

                                                # Resample Left & Right separately if stereo, otherwise just 1D
                                                # Pygame sounds are often stereo (N, 2)
                                                if len(snd_array.shape) > 1 and snd_array.shape[1] > 1:
                                                    left = np.interp(indices, np.arange(len(snd_array)), snd_array[:, 0])
                                                    right = np.interp(indices, np.arange(len(snd_array)), snd_array[:, 1])
                                                    resampled = np.column_stack((left, right)).astype(np.int16)
                                                else:
                                                    resampled = np.interp(indices, np.arange(len(snd_array)), snd_array).astype(np.int16)

                                                current_sound = pygame.sndarray.make_sound(resampled)
                                                self.pitch_cache[cache_key] = current_sound
                                            except Exception as e:
                                                # On failure the unpitched sample is played instead
                                                print(f"Pitch error track {r}: {e}")

                                    # 5. Play (channel calls are made while holding the mixer lock)
                                    with self.lock:
                                        try:
                                            chan = self.channels[r]
                                            # DEBUG
                                            # if r >= 0: print(f"Trk{r} Vol={self.model.track_volumes[r]:.2f} Base={base_vol:.2f}")

                                            chan.play(current_sound)
                                            chan.set_volume(left_vol, right_vol)
                                        except Exception as e:
                                            # print(f"Audio Play Error: {e}")
                                            pass

            # Move on to the next step, wrapping around at the end of the pattern
            self.current_step = (self.current_step + 1) % effective_steps

            # Sleep for the step length, never less than 5 ms
            time.sleep(max(0.005, actual_duration))
