"""Monolithic AnimeGANv2 file generated to remove dependencies."""
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- FROM utils/common.py ---
import gc
import os
import urllib.request
import cv2
from tqdm import tqdm

# String prefixes that read_image() treats as remote/inline sources and
# downloads to a temporary file before decoding.
# Full schemes ("http://", "https://") are used instead of a bare "http" so that
# local paths such as "http_images/cat.jpg" are not mistaken for URLs. Any
# "data:image/..." URI is accepted, since OpenCV determines the image format
# from the file contents rather than the file extension.
HTTP_PREFIXES = [
    'http://',
    'https://',
    'data:image/',
]


# Released generator checkpoints: name -> (generator version tag, download URL).
# The bare style names ("hayao", "shinkai") are aliases of their ":v1" entries.
RELEASED_WEIGHTS = {
    "hayao:v1": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth"
    ),
    "hayao": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth"
    ),
    "shinkai:v1": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth"
    ),
    "shinkai": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth"
    ),

    ## VER 2 ##
    "hayao:v2": (
        # Trained with Google Landmarks (micro) as the real-photo dataset
        "v2",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.2/GeneratorV2_gldv2_Hayao.pt"
    ),
    "shinkai:v2": (
        # Trained with Google Landmarks (micro) as the real-photo dataset
        "v2",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.2/GeneratorV2_gldv2_Shinkai.pt"
    ),
    ## Face portrait
    "arcane:v2": (
        "v2",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.2/GeneratorV2_ffhq_Arcane_210624_e350.pt"
    )
}

def is_image_file(path):
    """Return True if *path* ends with a supported still-image extension.

    The comparison is case-insensitive; only the final extension is checked.
    """
    extension = os.path.splitext(path)[1].lower()
    return extension in (".png", ".jpg", ".jpeg", ".webp")

def is_video_file(path):
    """Return True if *path* ends with one of the accepted video extensions.

    The comparison is case-insensitive. The extension list was chosen with
    reference to MoviePy's VideoFileClip documentation:
    https://moviepy-tburrows13.readthedocs.io/en/improve-docs/ref/VideoClip/VideoFileClip.html
    """
    suffix = os.path.splitext(path)[1].lower()
    return suffix in (".mp4", ".mov", ".ogv", ".avi", ".mpeg")


def read_image(path):
    """
    Read an image from a local path or a URL-like source and return it as RGB.

    Args:
        path: A local file path, or a string starting with one of
            ``HTTP_PREFIXES`` (an http(s) URL or a ``data:image/...`` URI).
            In the second case the content is first downloaded to a private
            temporary file, which is deleted after decoding.

    Returns:
        The decoded image as an ``H x W x 3`` uint8 array in RGB order.

    Raises:
        FileNotFoundError: if OpenCV cannot read or decode the image.
    """
    import tempfile

    if any(path.startswith(p) for p in HTTP_PREFIXES):
        # Use a unique temp file instead of a fixed "temp.jpg" in the CWD, so
        # concurrent calls don't clobber each other and nothing is left behind.
        fd, tmp_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        try:
            urllib.request.urlretrieve(path, tmp_path)
            img = cv2.imread(tmp_path, cv2.IMREAD_COLOR)
        finally:
            os.remove(tmp_path)
    else:
        img = cv2.imread(path, cv2.IMREAD_COLOR)

    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")

    # IMREAD_COLOR always decodes to 3-channel BGR (any alpha channel is
    # dropped), so a single BGR -> RGB conversion covers every input.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def save_checkpoint(model, path, optimizer=None, epoch=None):
    """
    Serialize a training checkpoint to *path* with ``torch.save``.

    The saved dict always has ``'model_state_dict'`` and ``'epoch'`` (which may
    be None). ``'optimizer_state_dict'`` is added only when an optimizer is
    given.
    """
    payload = dict(model_state_dict=model.state_dict(), epoch=epoch)
    if optimizer is not None:
        payload['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(payload, path)

def maybe_remove_module(state_dict):
    """
    Return a copy of *state_dict* with one leading ``'module.'`` stripped from each key.

    Wrapping a model for (distributed) data-parallel training prefixes its
    parameter names with ``'module.'``; see
    https://discuss.pytorch.org/t/why-are-state-dict-keys-getting-prepended-with-the-string-module/104627/3
    Keys without the prefix are kept as-is. Insertion order is preserved, and
    the input mapping is not modified.
    """
    prefix = 'module.'
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in state_dict.items()
    }


def load_checkpoint(model, path, optimizer=None, strip_optimizer=False, map_location=None) -> int:
    """
    Load a checkpoint (local path or released-weight name) into ``model``.

    Args:
        model: Module to load ``'model_state_dict'`` into, using ``strict=True``.
            Any ``'module.'`` key prefixes are stripped first.
        path: Checkpoint path, or a key of ``RELEASED_WEIGHTS``.
        optimizer: If given and the checkpoint has ``'optimizer_state_dict'``,
            that state is loaded into it.
        strip_optimizer: If True and the checkpoint has optimizer state, remove
            that state and overwrite the checkpoint file without it.
        map_location: Forwarded to ``load_state_dict`` (None = auto-select).

    Returns:
        The stored epoch. Returns 0 when the checkpoint has no epoch, or when
        the epoch was stored as None (the ``save_checkpoint`` default).
    """
    state_dict, path = load_state_dict(path, map_location)
    model_state_dict = maybe_remove_module(state_dict['model_state_dict'])
    model.load_state_dict(
        model_state_dict,
        strict=True
    )
    if 'optimizer_state_dict' in state_dict:
        if optimizer is not None:
            optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        if strip_optimizer:
            del state_dict["optimizer_state_dict"]
            torch.save(state_dict, path)
            print(f"Optimizer stripped and saved to {path}")

    # save_checkpoint writes 'epoch': None by default, so .get('epoch', 0)
    # alone could still return None despite the int return annotation.
    epoch = state_dict.get('epoch')
    return 0 if epoch is None else epoch


def load_state_dict(weight, map_location=None) -> tuple:
    """
    Load a checkpoint file, downloading it first if it names a released model.

    Args:
        weight: A key of ``RELEASED_WEIGHTS`` (matched case-insensitively), or a
            path to a local checkpoint file.
        map_location: Passed to ``torch.load``. If None, ``'cuda'`` is used when
            available, otherwise ``'cpu'``.

    Returns:
        ``(checkpoint, path)``: the object returned by ``torch.load`` and the
        local path it was read from. For released weights this is the cache
        path.
    """
    if weight.lower() in RELEASED_WEIGHTS:
        weight = _download_weight(weight.lower())

    if map_location is None:
        # auto select
        map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources.
    state_dict = torch.load(weight, map_location=map_location)

    return state_dict, weight


def initialize_weights(net):
    """
    Re-initialize the learnable layers of ``net`` (and ``net`` itself) in place.

    - Conv2d, ConvTranspose2d and Linear get Xavier-uniform weights and a zero bias.
    - BatchNorm2d gets weight 1 and bias 0.
    - Other module types are left untouched.

    A parameter that a layer does not have (for example a convolution built
    with ``bias=False``, or a BatchNorm2d with ``affine=False``) is simply
    skipped. Modules are visited in ``net.modules()`` order, so the random
    initialization is reproducible under a fixed seed.
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)


def set_lr(optimizer, lr):
    """Set the learning rate of every parameter group of *optimizer* to *lr*, in place."""
    for group in optimizer.param_groups:
        group.update(lr=lr)


class DownloadProgressBar(tqdm):
    '''
    tqdm progress bar whose ``update_to`` can serve as a ``urlretrieve`` reporthook.

    Adapted from
    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
    '''
    def update_to(self, b=1, bsize=1, tsize=None):
        """
        Move the bar to ``b * bsize`` bytes downloaded.

        Args:
            b: Number of blocks transferred so far.
            bsize: Size of each block, in bytes.
            tsize: Total size in bytes. ``urlretrieve`` passes -1 when the
                server sends no Content-Length, so non-positive values are
                treated as "unknown" and the total is left unchanged.
        """
        if tsize is not None and tsize > 0:
            self.total = tsize
        downloaded = b * bsize
        if self.total:
            # Progress is reported in whole blocks, so the final callback
            # usually overshoots the real size; clamp so the bar ends at 100%.
            downloaded = min(downloaded, self.total)
        self.update(downloaded - self.n)


def _download_weight(weight):
    '''
    Download a released weight file into ``.cache/`` and return its local path.

    Args:
        weight: A key of ``RELEASED_WEIGHTS``.

    Returns:
        ``.cache/<basename of the release URL>``. If that file already exists,
        it is reused without downloading again.

    Raises:
        KeyError: if ``weight`` is not a key of ``RELEASED_WEIGHTS``.
    '''
    os.makedirs('.cache', exist_ok=True)
    url = RELEASED_WEIGHTS[weight][1]
    filename = os.path.basename(url)
    save_path = f'.cache/{filename}'

    if os.path.isfile(save_path):
        return save_path

    # Download to a side file and rename it into place only on success.
    # Otherwise an interrupted download would leave a truncated file at
    # save_path, which the isfile() check above would later accept as a
    # valid cache hit.
    part_path = save_path + '.part'
    desc = f'Downloading {url} to {save_path}'
    try:
        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=desc) as t:
            urllib.request.urlretrieve(url, part_path, reporthook=t.update_to)
        os.replace(part_path, save_path)
    except BaseException:
        # Remove the partial file (including on KeyboardInterrupt), then re-raise.
        if os.path.exists(part_path):
            os.remove(part_path)
        raise

    return save_path



# --- FROM models/conv_blocks.py ---


class DownConv(nn.Module):
    """Halve the spatial size by summing two separable-convolution paths.

    Path 1 applies a stride-2 ``SeparableConv2D`` to the input. Path 2
    bilinearly downsamples the input and then applies a stride-1
    ``SeparableConv2D``. The channel count is preserved.
    """

    def __init__(self, channels, bias=False):
        super(DownConv, self).__init__()

        self.conv1 = SeparableConv2D(channels, channels, stride=2, bias=bias)
        self.conv2 = SeparableConv2D(channels, channels, stride=1, bias=bias)

    def forward(self, x):
        out1 = self.conv1(x)
        # Resize to exactly out1's spatial size instead of scale_factor=0.5.
        # The stride-2 conv (kernel 3, padding 1) yields ceil(H/2), but
        # scale_factor=0.5 yields floor(H/2), so odd H or W made the sum below
        # fail. For even sizes the output size is the same as before.
        out2 = F.interpolate(x, size=out1.shape[-2:], mode='bilinear', align_corners=False)
        out2 = self.conv2(out2)

        return out1 + out2


class UpConv(nn.Module):
    """Double the spatial size bilinearly, then refine with a stride-1 ``SeparableConv2D``.

    The channel count is preserved (``channels`` in, ``channels`` out).
    """

    def __init__(self, channels, bias=False):
        super().__init__()
        self.conv = SeparableConv2D(channels, channels, stride=1, bias=bias)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2.0, mode='bilinear')
        return self.conv(upsampled)


class UpConvLNormLReLU(nn.Module):
    """Bilinear 2x upsample, followed by a 3x3 ``ConvBlock`` (conv + norm + LeakyReLU).

    Despite the "LNorm" in the name, the normalization defaults to
    ``norm_type="instance"``. Pass another ``norm_type`` (for example
    ``"layer"``, resolved by ``get_norm``) to change it.
    """

    def __init__(self, in_channels, out_channels, norm_type="instance", bias=False):
        super().__init__()
        self.conv_block = ConvBlock(
            in_channels, out_channels, kernel_size=3, norm_type=norm_type, bias=bias,
        )

    def forward(self, x):
        return self.conv_block(F.interpolate(x, scale_factor=2.0, mode='bilinear'))

class SeparableConv2D(nn.Module):
    """Depthwise-separable convolution: a 3x3 depthwise conv, then a 1x1 pointwise conv.

    Each conv is followed by ``InstanceNorm2d`` and ``LeakyReLU(0.2)``. The
    depthwise conv carries the ``stride`` and uses padding 1. Weights are
    re-initialized with ``initialize_weights`` at construction.
    """

    def __init__(self, in_channels, out_channels, stride=1, bias=False):
        super().__init__()
        # Keep these attribute names and this registration order: they define
        # the state_dict keys and the order in which initialize_weights visits
        # (and randomly initializes) the layers.
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=stride,
            padding=1, groups=in_channels, bias=bias,
        )
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, bias=bias,
        )
        self.ins_norm1 = nn.InstanceNorm2d(in_channels)
        self.activation1 = nn.LeakyReLU(0.2, True)
        self.ins_norm2 = nn.InstanceNorm2d(out_channels)
        self.activation2 = nn.LeakyReLU(0.2, True)

        initialize_weights(self)

    def forward(self, x):
        spatial = self.activation1(self.ins_norm1(self.depthwise(x)))
        return self.activation2(self.ins_norm2(self.pointwise(spatial)))


class ConvBlock(nn.Module):
    """Reflection padding, then Conv2d (unpadded), then normalization, then LeakyReLU(0.2).

    Args:
        channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size, stride, groups, bias: Forwarded to ``nn.Conv2d``.
        padding: Forwarded to ``nn.ReflectionPad2d``. An int pads every side;
            a 4-tuple is ``(left, right, top, bottom)``.
        norm_type: Passed to ``get_norm`` to build the normalization layer.
    """

    def __init__(
        self,
        channels,
        out_channels,
        kernel_size=3,
        stride=1,
        groups=1,
        padding=1,
        bias=False,
        norm_type="instance"
    ):
        super().__init__()

        # Padding is applied explicitly with reflection, so the convolution
        # itself runs with padding=0. Attribute names define state_dict keys.
        self.pad = nn.ReflectionPad2d(padding)
        self.conv = nn.Conv2d(
            channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            padding=0,
            bias=bias,
        )
        # NOTE(review): get_norm is not defined in the visible part of this
        # file; confirm it is provided before this block is constructed.
        self.ins_norm = get_norm(norm_type, out_channels)
        self.activation = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        padded = x if self.pad is None else self.pad(x)
        return self.activation(self.ins_norm(self.conv(padded)))


class InvertedResBlock(nn.Module):
    """Inverted residual block: 1x1 expand, then 3x3 depthwise, then 1x1 project.

    The input is expanded to ``round(expand_ratio * channels)`` channels,
    filtered depthwise, and projected to ``out_channels``. The projection is
    normalized but not activated. The input is added to the result only when
    ``out_channels`` equals the input channel count.
    """

    def __init__(
        self,
        channels=256,
        out_channels=256,
        expand_ratio=2,
        norm_type="instance",
    ):
        super().__init__()
        hidden = round(expand_ratio * channels)

        # Attribute names and registration order define the state_dict layout.
        # 1x1 expansion (conv + norm + LeakyReLU).
        self.conv_block = ConvBlock(
            channels, hidden, kernel_size=1, padding=0, norm_type=norm_type, bias=False,
        )
        # 3x3 depthwise conv (groups == channels) (conv + norm + LeakyReLU).
        self.conv_block2 = ConvBlock(
            hidden, hidden, groups=hidden, norm_type=norm_type, bias=True,
        )
        # 1x1 projection followed by a norm, with no activation afterwards.
        self.conv = nn.Conv2d(hidden, out_channels, kernel_size=1, padding=0, bias=False)
        self.norm = get_norm(norm_type, out_channels)

    def forward(self, x):
        expanded = self.conv_block(x)
        filtered = self.conv_block2(expanded)
        projected = self.norm(self.conv(filtered))
        # Residual (elementwise) addition is only possible when channel counts match.
        return projected + x if projected.shape[1] == x.shape[1] else projected


# --- FROM models/anime_gan_v2.py ---



class GeneratorV2(nn.Module):
    """AnimeGANv2 generator: maps a 3-channel image to a 3-channel stylized image.

    Layout: two stride-2 downsampling stages, a residual middle section, two
    bilinear x2 upsampling stages (done in ``forward``), and a 1x1 conv + Tanh
    head. Because of the stride-2 stages, each output spatial dimension is
    ``4 * floor(dim / 4)``, which equals the input size only when that
    dimension is divisible by 4.

    NOTE(review): submodule names and Sequential indices form the state_dict
    keys, and ``load_checkpoint`` loads with ``strict=True``. They presumably
    must match the released v2 checkpoints, so do not rename or reorder them.
    """
    def __init__(self, dataset=''):
        """Build the network; ``dataset`` is used only to form ``self.name``."""
        super(GeneratorV2, self).__init__()
        self.name = f'{self.__class__.__name__}_{dataset}'

        # 3 -> 64 channels. The stride-2 conv (with right/bottom reflection
        # padding of 1) halves H and W (floor).
        self.conv_block1 = nn.Sequential(
            ConvBlock(3, 32, kernel_size=7, stride=1, padding=3, norm_type="layer"),
            ConvBlock(32, 64, kernel_size=3, stride=2, padding=(0, 1, 0, 1), norm_type="layer"),
            ConvBlock(64, 64, kernel_size=3, stride=1, norm_type="layer"),
        )

        # 64 -> 128 channels, halving H and W again (1/4 of input resolution).
        self.conv_block2 = nn.Sequential(
            ConvBlock(64, 128, kernel_size=3, stride=2, padding=(0, 1, 0, 1), norm_type="layer"),
            ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
        )

        # Middle section at 1/4 resolution: 128 -> 256 (first inverted
        # residual, no skip since channels differ) -> 256 x3 (with skips) -> 128.
        self.res_blocks = nn.Sequential(
            ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
            InvertedResBlock(128, 256, expand_ratio=2, norm_type="layer"),
            InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"),
            InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"),
            InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"),
            ConvBlock(256, 128, kernel_size=3, stride=1, norm_type="layer"),
        )

        # Applied after the first x2 upsample in forward(). The upsampling is
        # done there with F.interpolate rather than the UpConvLNormLReLU below.
        self.conv_block3 = nn.Sequential(
            # UpConvLNormLReLU(128, 128, norm_type="layer"),
            ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
            ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
        )

        # Applied after the second x2 upsample in forward(); 128 -> 32 channels.
        self.conv_block4 = nn.Sequential(
            # UpConvLNormLReLU(128, 64, norm_type="layer"),
            ConvBlock(128, 64, kernel_size=3, stride=1, norm_type="layer"),
            ConvBlock(64, 64, kernel_size=3, stride=1, norm_type="layer"),
            ConvBlock(64, 32, kernel_size=7, padding=3, stride=1, norm_type="layer"),
        )

        # 32 -> 3 channels; Tanh bounds the output to [-1, 1].
        self.decode_blocks = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
            nn.Tanh(),
        )

        initialize_weights(self)

    def forward(self, x):
        """Stylize ``x``.

        ``x`` must have 3 channels (first conv is 3 -> 32). Presumably it is an
        ``(N, 3, H, W)`` image normalized to [-1, 1] to match the Tanh output
        range; TODO confirm against the preprocessing code.
        Returns a 3-channel tensor with values in [-1, 1].
        """
        out = self.conv_block1(x)
        out = self.conv_block2(out)
        out = self.res_blocks(out)
        # Two bilinear x2 upsamples restore the 1/4-resolution features.
        out = F.interpolate(out, scale_factor=2, mode="bilinear")
        out = self.conv_block3(out)
        out = F.interpolate(out, scale_factor=2, mode="bilinear")
        out = self.conv_block4(out)
        img = self.decode_blocks(out)

        return img

