"""
Moteur de séquençage MIDI multi-pistes.
"""

import time
import threading
from dataclasses import dataclass, field
from typing import List, Optional, Callable
from enum import Enum


class NoteEventType(Enum):
    """Type d'événement MIDI."""
    NOTE_ON = "note_on"
    NOTE_OFF = "note_off"


@dataclass
class MidiEvent:
    """Représente un événement MIDI dans le séquenceur."""
    timestamp: float  # Position en beats (pulses / PPQN)
    note: int  # Numéro MIDI de la note (0-127)
    velocity: int  # Vélocité (0-127)
    event_type: NoteEventType
    duration: Optional[float] = None  # Durée en beats (calculée pour NOTE_ON)
    
    def __lt__(self, other):
        """Permet de trier les événements par timestamp."""
        return self.timestamp < other.timestamp


@dataclass
class Track:
    """Une piste du séquenceur contenant des événements MIDI."""
    name: str
    events: List[MidiEvent] = field(default_factory=list)
    instrument_type: str = "sf2"  # "synth" ou "soundfont"
    instrument_category: str = "Piano"
    instrument_name: str = "Grand Piano"
    bank: int = 0
    preset: int = 0
    volume: float = 0.8
    solo: bool = False
    mute: bool = False
    armed: bool = False  # Piste armée pour enregistrement
    
    def add_event(self, event: MidiEvent):
        """Ajoute un événement et maintient le tri chronologique."""
        self.events.append(event)
        self.events.sort()
    
    
    def remove_events_in_range(self, start_time: float, end_time: float, note: int = None):
        """Supprime tous les événements dans une plage temporelle (pour l'écrasement).
        
        Args:
            start_time: Début de la plage
            end_time: Fin de la plage
            note: Si spécifié, supprime uniquement les événements de cette note (pitch)
        """
        if note is not None:
            # Supprimer uniquement les notes de la même hauteur
            self.events = [e for e in self.events 
                          if not (start_time <= e.timestamp <= end_time and e.note == note)]
        else:
            # Supprimer toutes les notes dans la plage
            self.events = [e for e in self.events if not (start_time <= e.timestamp <= end_time)]
    
    def get_events_in_range(self, start_time: float, end_time: float) -> List[MidiEvent]:
        """Récupère tous les événements dans une plage temporelle."""
        return [e for e in self.events if start_time <= e.timestamp <= end_time]
    
    def clear(self):
        """Efface tous les événements de la piste."""
        self.events.clear()


class TransportState(Enum):
    """État du transport du séquenceur."""
    STOPPED = "stopped"
    PLAYING = "playing"
    RECORDING = "recording"
    PAUSED = "paused"


class SequencerEngine:
    """
    Moteur de séquençage multi-pistes.
    Gère l'enregistrement, la lecture et l'édition de séquences MIDI.
    """
    
    def __init__(self, bpm: int = 120, ppqn: int = 96, sample_rate: int = 44100):
        """
        Initialise le moteur de séquenceur.
        
        Args:
            bpm: Tempo en battements par minute
            ppqn: Pulses Per Quarter Note (résolution temporelle)
            sample_rate: Taux d'échantillonnage pour synchronisation audio
        """
        self.bpm = bpm
        self.ppqn = ppqn
        self.sample_rate = sample_rate
        
        # Pistes
        self.tracks: List[Track] = []
        self._create_default_tracks()
        
        # État du transport
        self.state = TransportState.STOPPED
        self.current_position = 0.0  # Position en beats
        self.loop_enabled = False
        self.loop_start = 0.0
        self.loop_end = 16.0  # 4 mesures par défaut
        
        # Enregistrement
        self.recording_start_time = 0.0
        self.recording_notes_on = {}  # {note: (timestamp, velocity)} pour calculer durée
        
        # Métronome
        self.metronome_enabled = True
        self.count_in_bars = 1  # Nombre de mesures de count-in
        
        # Quantification
        self.quantize_enabled = False
        self.quantize_value = 16  # Quantifier sur 16ème notes
        
        # Thread de lecture
        self.playback_thread = None
        self.playback_running = False
        
        # Callbacks pour envoyer les notes aux moteurs de synthèse
        self.note_on_callback: Optional[Callable] = None
        self.note_off_callback: Optional[Callable] = None
        self.position_callback: Optional[Callable] = None  # Callback pour mise à jour UI
    
    def _create_default_tracks(self):
        """Crée les pistes par défaut."""
        for i in range(8):
            self.tracks.append(Track(
                name=f"Piste {i+1}",
                instrument_type="sf2",
                instrument_name="Grand Piano"
            ))
    
    def add_track(self, name: str = None) -> Track:
        """Ajoute une nouvelle piste."""
        if name is None:
            name = f"Piste {len(self.tracks) + 1}"
        track = Track(name=name)
        self.tracks.append(track)
        return track
    
    def remove_track(self, index: int):
        """Supprime une piste."""
        if 0 <= index < len(self.tracks):
            self.tracks.pop(index)
    
    def get_armed_track(self) -> Optional[Track]:
        """Retourne la piste armée pour l'enregistrement."""
        for track in self.tracks:
            if track.armed:
                return track
        return None

    def get_armed_track_index(self) -> int:
        """Retourne l'index de la piste armée, ou -1 si aucune."""
        for i, track in enumerate(self.tracks):
            if track.armed:
                return i
        return -1
    
    def set_bpm(self, bpm: int):
        """Définit le tempo."""
        self.bpm = max(20, min(300, bpm))
    
    def beats_to_seconds(self, beats: float) -> float:
        """Convertit des beats en secondes."""
        return (beats * 60.0) / self.bpm
    
    def seconds_to_beats(self, seconds: float) -> float:
        """Convertit des secondes en beats."""
        return (seconds * self.bpm) / 60.0
    
    def quantize(self, beat_position: float) -> float:
        """Quantifie une position temporelle."""
        if not self.quantize_enabled:
            return beat_position
        
        # Calculer la grille de quantification
        grid_size = 4.0 / self.quantize_value  # 4 beats = 1 mesure
        return round(beat_position / grid_size) * grid_size
    
    # ===== ENREGISTREMENT =====
    
    def start_recording(self):
        """Démarre l'enregistrement sur la piste armée."""
        armed_track = self.get_armed_track()
        if not armed_track:
            print("⚠️ Aucune piste armée pour l'enregistrement. Pistes:")
            for i, t in enumerate(self.tracks):
                print(f"  - [{i}] {t.name}: armed={t.armed}")
            return False
        
        self.state = TransportState.RECORDING
        self.recording_start_time = self.current_position
        self.recording_notes_on.clear()
        
        # Démarrer le playback pour monitoring
        self._start_playback_thread()
        
        print(f"🔴 Enregistrement démarré sur {armed_track.name}")
        return True
    
    def record_note_on(self, note: int, velocity: int):
        """Enregistre un événement note on pendant l'enregistrement."""
        armed_idx = self.get_armed_track_index()
        
        # Envoyer immédiatement au synthé pour monitoring (même si pas RECORDING, pour le live)
        if self.note_on_callback:
            self.note_on_callback(note, velocity, armed_idx if armed_idx >= 0 else 0)
            
        if self.state != TransportState.RECORDING:
            return
        
        current_time = self.current_position
        quantized_time = self.quantize(current_time)
        
        # Stocker timestamp ET vélocité pour calculer la durée au note_off
        self.recording_notes_on[note] = (quantized_time, velocity)
        
        print(f"📝 Note ON: {note} vel:{velocity} @{quantized_time:.2f} [Track {armed_idx}]")
    
    def record_note_off(self, note: int):
        """Enregistre un événement note off pendant l'enregistrement."""
        armed_idx = self.get_armed_track_index()
        
        # Toujours envoyer au synthé pour le monitoring
        if self.note_off_callback:
            self.note_off_callback(note, armed_idx if armed_idx >= 0 else 0)
        
        if self.state != TransportState.RECORDING or note not in self.recording_notes_on:
            return
        
        armed_track = self.get_armed_track()
        if not armed_track:
            return
        
        # Récupérer le timestamp ET la vélocité du note_on
        note_on_time, velocity = self.recording_notes_on.pop(note)
        note_off_time = self.quantize(self.current_position)
        duration = max(0.1, note_off_time - note_on_time)  # Durée minimale
        
        # Supprimer SEULEMENT les notes de la même hauteur qui chevauchent (mode écrasement)
        armed_track.remove_events_in_range(note_on_time, note_off_time, note=note)
        
        # Créer et ajouter l'événement avec la vraie vélocité
        event = MidiEvent(
            timestamp=note_on_time,
            note=note,
            velocity=velocity,  # Vélocité capturée au note_on
            event_type=NoteEventType.NOTE_ON,
            duration=duration
        )
        armed_track.add_event(event)
        
        print(f"✓ Note enregistrée: {note} vel:{velocity} @{note_on_time:.2f} dur:{duration:.2f}")
    
    def stop_recording(self):
        """Arrête l'enregistrement."""
        if self.state == TransportState.RECORDING:
            # Terminer toutes les notes en cours
            armed_track = self.get_armed_track()
            if armed_track:
                for note, (timestamp, velocity) in self.recording_notes_on.items():
                    duration = max(0.1, self.current_position - timestamp)
                    event = MidiEvent(
                        timestamp=timestamp,
                        note=note,
                        velocity=velocity,
                        event_type=NoteEventType.NOTE_ON,
                        duration=duration
                    )
                    armed_track.add_event(event)
            
            self.recording_notes_on.clear()
            if armed_track:
                print(f"⏹️ Enregistrement arrêté - {len(armed_track.events)} événements")
        
        self.stop()
    
    # ===== LECTURE =====
    
    def play(self):
        """Démarre la lecture."""
        if self.state != TransportState.PLAYING:
            self.state = TransportState.PLAYING
            self._start_playback_thread()
            print(f"▶️ Lecture démarrée à {self.current_position:.2f} beats")
    
    def stop(self):
        """Arrête la lecture/enregistrement."""
        self.state = TransportState.STOPPED
        self._stop_playback_thread()
        
        # Arrêter toutes les notes en cours
        if self.note_off_callback:
            for i in range(128):
                self.note_off_callback(i)
        
        print("⏹️ Arrêt")
    
    def pause(self):
        """Met en pause."""
        if self.state == TransportState.PLAYING:
            self.state = TransportState.PAUSED
            self._stop_playback_thread()
            print(f"⏸️ Pause à {self.current_position:.2f} beats")
    
    def set_position(self, beats: float):
        """Définit la position de lecture."""
        self.current_position = max(0.0, beats)
        if self.position_callback:
            self.position_callback(self.current_position)
    
    def _start_playback_thread(self):
        """Démarre le thread de lecture."""
        if not self.playback_running:
            self.playback_running = True
            self.playback_thread = threading.Thread(target=self._playback_loop, daemon=True)
            self.playback_thread.start()
    
    def _stop_playback_thread(self):
        """Arrête le thread de lecture."""
        self.playback_running = False
        if self.playback_thread:
            self.playback_thread.join(timeout=1.0)
            self.playback_thread = None
    
    def _playback_loop(self):
        """Boucle de lecture principale (thread séparé)."""
        last_time = time.time()
        
        while self.playback_running:
            current_time = time.time()
            delta_time = current_time - last_time
            last_time = current_time
            
            # Avancer la position
            delta_beats = self.seconds_to_beats(delta_time)
            self.current_position += delta_beats
            
            # Boucle si activée
            if self.loop_enabled and self.current_position >= self.loop_end:
                self.current_position = self.loop_start
            
            # Générer les événements MIDI pour toutes les pistes
            if self.state == TransportState.PLAYING or self.state == TransportState.RECORDING:
                self._process_events(self.current_position - delta_beats, self.current_position)
            
            # Callback de position pour UI
            if self.position_callback:
                self.position_callback(self.current_position)
            
            # Dormir un peu pour ne pas surcharger le CPU
            time.sleep(0.005)  # 5ms
    
    def _process_events(self, start_beat: float, end_beat: float):
        """Traite les événements MIDI dans la plage temporelle donnée."""
        for track_idx, track in enumerate(self.tracks):
            # Ignorer les pistes mutées ou en solo (si d'autres sont en solo)
            if track.mute:
                continue
            
            solo_tracks = [t for t in self.tracks if t.solo]
            if solo_tracks and not track.solo:
                continue
            
            # Récupérer les événements dans la plage
            events = track.get_events_in_range(start_beat, end_beat)
            
            for event in events:
                if event.event_type == NoteEventType.NOTE_ON:
                    # Note on
                    if self.note_on_callback:
                        velocity = int(event.velocity * track.volume)  # Appliquer le volume
                        self.note_on_callback(event.note, velocity, track_idx)
                    
                    # Programmer le note off après la durée
                    if event.duration and self.note_off_callback:
                        # Calculer quand arrêter la note
                        note_off_time = event.timestamp + event.duration
                        # Créer un timer pour le note_off
                        import threading
                        delay = self.beats_to_seconds(event.duration)
                        timer = threading.Timer(delay, lambda n=event.note, idx=track_idx: self.note_off_callback(n, idx) if self.note_off_callback else None)
                        timer.daemon = True
                        timer.start()
    
    def all_notes_off(self):
        """Arrête toutes les notes (bouton PANIC)."""
        if self.note_off_callback:
            for i in range(128):
                self.note_off_callback(i)
