import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image
import os
import ssl

# Work around SSL certificate errors that sometimes occur during the internal
# torch hub download.
# NOTE(review): this replaces the default HTTPS context factory with the
# unverified one, so certificate verification is turned off for every stdlib
# HTTPS request made by this process (not only the hub download). Acceptable
# for a local experiment; confirm before reusing this module elsewhere.
ssl._create_default_https_context = ssl._create_unverified_context

def _fit_size(width, height, long_side=512, multiple=32):
    """Return the ``(width, height)`` an image is resized to before inference.

    The longer side is scaled to ``long_side`` (images smaller than that are
    scaled up as well) and each side is then rounded down to a multiple of
    ``multiple``. Each side is kept at least ``multiple`` pixels so that a
    very elongated image never ends up with a zero-sized dimension.
    """
    longest = max(width, height)
    # Integer arithmetic on purpose: ``int(w * (512 / w))`` can evaluate to
    # 511 because of float rounding (e.g. w = 784), which the rounding below
    # would then turn into 480 instead of 512.
    scaled = (width * long_side // longest, height * long_side // longest)
    return tuple(max(multiple, side - side % multiple) for side in scaled)


def cartoonize_torch(image_path, output_path, model_name='paprika'):
    """Stylize an image with a pretrained animegan2-pytorch generator.

    The generator is fetched through TorchHub and run on the CPU. The image
    is resized (long side 512, both sides multiples of 32) before inference,
    so the saved result has the resized dimensions, not the original ones.

    Args:
        image_path: Path of the image to read.
        output_path: Path the stylized image is written to. Missing parent
            directories are created.
        model_name: Name of the pretrained weights, passed to the hub entry
            point as ``pretrained``.

    Returns:
        None. Failures while loading the model or processing the image are
        reported on stdout instead of being raised.
    """
    print(f"Loading model {model_name} from TorchHub...")
    device = 'cpu'

    try:
        model = torch.hub.load('bryandlee/animegan2-pytorch:main', 'generator', pretrained=model_name, device=device)
        model.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    print(f"Processing {image_path}...")
    try:
        # Close the source file promptly; convert() returns an independent
        # in-memory copy, so nothing else needs the open handle.
        with Image.open(image_path) as source:
            img = source.convert("RGB")

        # Resize for speed and memory on CPU.
        img = img.resize(_fit_size(*img.size), Image.Resampling.LANCZOS)

        # Add a batch dimension and rescale pixel values from [0, 1] to [-1, 1].
        input_tensor = to_tensor(img).unsqueeze(0) * 2 - 1

        with torch.no_grad():
            output_tensor = model(input_tensor)

        # Undo the input scaling ([-1, 1] -> [0, 1]) and clip any overshoot
        # before converting back to an image.
        output_img = to_pil_image((output_tensor[0] * 0.5 + 0.5).clip(0, 1))

        # Make sure the destination folder exists so the (expensive) result
        # is not lost to a missing directory.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        output_img.save(output_path)
        print(f"Saved to {output_path}")

    except Exception as e:
        print(f"Error processing image: {e}")

if __name__ == "__main__":
    # Script entry point: stylize one hard-coded sample photo with the
    # 'paprika' weights, provided the photo is present on disk.
    input_file, output_file = (
        "photos/IMG-20230813-WA0028.jpg",
        "output/test_ai_paprika.jpg",
    )

    if not os.path.exists(input_file):
        # Nothing to process; report the missing input and stop.
        print(f"File not found: {input_file}")
    else:
        cartoonize_torch(input_file, output_file, model_name='paprika')
