import argparse
import copy
import os
import random
from datetime import datetime
from random import sample

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader

from data_generators.data_loader import SegmentationDataSet
from medzoo_imports import Trainer, create_model
from monai.losses.dice import DiceLoss

# Directories that main() lists for the input cubes and their label masks
# (every entry whose name contains ".fits" is used). They are empty
# placeholders and must be set to real directories before training.
fits_path = ""
label_path = ""

# Restrict this process to GPU index 2. This has to be in the environment
# before the first CUDA call, which is why it sits at module level.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# Print the current GPU device's memory summary. memory_summary() initializes
# CUDA and raises when no device is visible (CPU-only host, or a host without
# a GPU at index 2), so only call it when CUDA is actually usable.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary(device=None, abbreviated=False))


def _load_feature_extractor(model, checkpoint_path):
    """Load every checkpoint weight whose key does not contain "final" into ``model``."""
    checkpoint = torch.load(checkpoint_path)
    backbone_state = {
        key: value
        for key, value in checkpoint["model_state_dict"].items()
        if "final" not in key
    }

    print("\nFeature extractor state dict keys (excluding 'final'):")
    for key in backbone_state:
        print(key)

    # strict=False because the "final" entries were left out on purpose.
    model.load_state_dict(backbone_state, strict=False)


def _resume_from_checkpoint(model, optimizer, checkpoint_path, feature_extraction, lr):
    """Restore ``model`` from a checkpoint and return the optimizer to train with.

    Without ``feature_extraction`` the optimizer state is restored as well and
    the same optimizer is returned. With it, every restored parameter is
    frozen, ``model.conv1`` is replaced by a fresh trainable ``Conv3d`` and a
    new AdamW optimizer over the trainable parameters is returned.
    """
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint["model_state_dict"])
    if not feature_extraction:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return optimizer

    # Remember where the model lives: a freshly built layer starts on the CPU
    # and would otherwise not follow a model that was already moved to the GPU.
    device = next(model.parameters()).device
    for param in model.parameters():
        param.requires_grad = False
    model.conv1 = nn.Conv3d(
        model.in_channels, model.num_features, kernel_size=5, padding=2
    ).to(device)
    trainable = [p for p in model.parameters() if p.requires_grad]
    return optim.AdamW(trainable, lr=lr, weight_decay=0.001)


def main(
    batch_size,
    shuffle,
    num_workers,
    dims,
    overlaps,
    root,
    random_seed,
    train_size,
    model,
    opt,
    lr,
    inChannels,
    classes,
    log_dir,
    dataset_name,
    terminal_show_freq,
    nEpochs,
    cuda,
    scale,
    subsample,
    k_folds,
    pretrained,
    retrain,
    load_data_loc,
    jobid,
    feature_extraction=False,
    augmentation=False,
    amp=False,
):
    """Build the model and the train/validation/test datasets, then train.

    All arguments are also collected into an ``argparse.Namespace`` that is
    handed to ``create_model`` and ``Trainer``. Several of them (``opt``,
    ``inChannels``, ``classes``, ``log_dir``, ``dataset_name``,
    ``terminal_show_freq``, ``nEpochs``, ``k_folds``, ``amp``) are not read
    directly in this function and only travel through that namespace.

    Args:
        batch_size (int): Batch size of the training and validation loaders
        shuffle (bool): Passed as ``shuffle`` to every DataLoader
        num_workers (int): The number of DataLoader workers
        dims (list): The dimensions of the subcubes
        overlaps (list): The dimensions of the overlap of subcubes
        root (str): The root directory of the data
        random_seed (int): Seed for ``random`` and ``torch``
        train_size (float): Fraction of the windows used for training; the
            remaining ``1 - train_size`` share is used for validation
        model (str): The 3D segmentation model to use
        opt (str): The type of optimizer
        lr (float): The learning rate
        inChannels (int): The desired modalities/channels that you want to use
        classes (int): The number of classes
        log_dir (str): The directory to output the logs
        dataset_name (str): The name of the dataset
        terminal_show_freq (int): How often it shows the output
        nEpochs (int): The number of epochs
        cuda (bool): Whether to move the model to the GPU
        scale (str): Loud or soft - S-N ratio; part of the output folder name
        subsample (int): Number of cubes to sample for each of the test and
            train/val pools; only applied when smaller than 7
        k_folds (int): The number of folds for cross validation
        pretrained (str): Checkpoint whose non-"final" weights initialise the
            model; takes precedence over ``retrain``
        retrain (str): Checkpoint to restore the full model (and, without
            ``feature_extraction``, the optimizer) from
        load_data_loc (str): The location of the data windows. The train and
            validation lists are only built when this is the empty string
        jobid (str): Prefix of the output folder name
        feature_extraction (bool): With ``retrain``, freeze the restored
            weights and train only a newly created ``conv1``
        augmentation (bool): Passed to the train/val ``SegmentationDataSet``
        amp (bool): Whether to use automatic mixed precision

    Returns:
        None
    """
    # Must stay the first statement: here locals() holds exactly the call
    # arguments. Building the namespace locally keeps main() usable when the
    # module is imported, instead of relying on a global set by the CLI guard.
    args = argparse.Namespace(**locals())

    # Seed before the model is created and before sample()/shuffle() below,
    # so both the initial weights and the data split are reproducible.
    random.seed(random_seed)
    torch.manual_seed(random_seed)

    # NOTE(review): anomaly detection is a debugging aid and slows training;
    # consider turning it off for production runs.
    torch.autograd.set_detect_anomaly(True)
    date_str = datetime.now().strftime("%d%m%Y_%H%M%S")

    # `model` is rebound to the network below; keep the name for the save path.
    model_name = model
    model, optimizer = create_model(args)
    if cuda:
        model = model.cuda()
        print("Model transferred in GPU.....")

    run_dir = "%s_saved_models_%s_%s_%s/" % (jobid, date_str, scale, subsample)
    if pretrained:
        _load_feature_extractor(model, pretrained)
        save = "/home/cl/sdc2/HISourceFinder-master-l/" + run_dir
    else:
        if retrain:
            optimizer = _resume_from_checkpoint(
                model, optimizer, retrain, feature_extraction, lr
            )
        save = "./" + run_dir
    start_epoch = 1
    # makedirs(exist_ok=True) also creates missing parents and does not fail
    # when the directory already exists.
    os.makedirs(save, exist_ok=True)

    # input and target files
    inputs = [
        os.path.join(fits_path, name)
        for name in sorted(os.listdir(fits_path))
        if ".fits" in name
    ]
    targets = [
        os.path.join(label_path, name)
        for name in sorted(os.listdir(label_path))
        if ".fits" in name
    ]

    dataset_full = SegmentationDataSet(
        inputs=inputs,
        targets=targets,
        dims=dims,
        overlaps=overlaps,
        arr_shape=(128, 128, 128),
        root=root,
        mode="full",
        save_name=save,
    )

    # Files whose name contains "_<r>_<c>." with r and c in 6..12 form the
    # held-out test pool; everything else is available for train/validation.
    suffixes = [f"_{r}_{c}." for r in range(6, 13) for c in range(6, 13)]
    print(len(suffixes))

    def is_held_out(path):
        return any(tag in path for tag in suffixes)

    # dataset test
    dataset_test = copy.deepcopy(dataset_full)
    dataset_test.list = [
        item for item in dataset_full.list if is_held_out(item[0][0])
    ]
    # NOTE(review): the threshold 7 is a magic number; presumably tied to the
    # 7-value range above - confirm.
    if subsample < 7:
        held_out_names = [os.path.basename(p) for p in inputs if is_held_out(p)]
        test_cubes = set(sample(held_out_names, subsample))
        dataset_test.list = [
            item
            for item in dataset_test.list
            if os.path.basename(item[0][0]) in test_cubes
        ]
    print(len(dataset_test.list))
    # dataloader test (always one window per batch)
    test_params = {"batch_size": 1, "shuffle": shuffle, "num_workers": num_workers}
    dataloader_test = DataLoader(dataset=dataset_test, **test_params)

    cubes = [os.path.basename(p) for p in inputs if not is_held_out(p)]
    if subsample < 7:
        cubes = sample(cubes, subsample)

    dataset_train_val = SegmentationDataSet(
        inputs=inputs,
        targets=targets,
        dims=dims,
        overlaps=overlaps,
        root=root,
        arr_shape=(128, 128, 128),
        mode="train_val",
        save_name=save,
        augmentation=augmentation,
    )
    print("--------------------------------")
    print("--------------------------------")
    # NOTE(review): this path uses "fold_checkpoints/" while the directory
    # created below is "fold__checkpoints" - confirm which one Trainer expects.
    args.save = save + "fold" + "_checkpoints/" + model_name + "_"

    train_list, val_list = [], []
    # NOTE(review): a non-empty load_data_loc is not handled here, so both
    # lists stay empty in that case.
    if load_data_loc == "":
        file_list = []
        for cube in cubes:
            cube_path = os.path.join(fits_path, cube)
            file_list += [
                item for item in dataset_train_val.list if cube_path in item[0][0]
            ]
        random.shuffle(file_list)
        print(len(file_list))
        num_val = int(len(file_list) * (1 - train_size))
        num_train = int(len(file_list) * train_size)
        print(num_train, num_val)
        train_list += file_list[:num_train]
        val_list += file_list[num_train : num_train + num_val]
        os.makedirs(save + "fold_" + "_checkpoints", exist_ok=True)

    # dataset training
    dataset_train = copy.deepcopy(dataset_train_val)
    dataset_train.list = train_list
    # dataset validation
    dataset_valid = copy.deepcopy(dataset_train_val)
    dataset_valid.list = val_list

    params = {"batch_size": batch_size, "shuffle": shuffle, "num_workers": num_workers}
    # dataloader training
    dataloader_training = DataLoader(dataset=dataset_train, **params)
    # dataloader validation
    dataloader_validation = DataLoader(dataset=dataset_valid, **params)
    print(
        len(dataloader_training),
        len(dataloader_validation),
        len(dataloader_test),
    )
    criterion = DiceLoss(to_onehot_y=True, softmax=True, smooth_nr=1e-5, smooth_dr=1e-5)

    trainer = Trainer(
        args,
        model,
        criterion,
        optimizer,
        train_data_loader=dataloader_training,
        valid_data_loader=dataloader_validation,
        lr_scheduler=None,
        start_epoch=start_epoch,
    )
    print("START TRAINING...")

    trainer.training()

    return


def _str2bool(value):
    """Parse a command-line boolean such as "true", "False", "1" or "no".

    ``type=bool`` cannot be used for this: ``bool("False")`` is True because
    any non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays False, as it was under type=bool.
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def _int_list(value):
    """Parse "128,128,128" (commas and/or spaces) into ``[128, 128, 128]``.

    ``type=list`` cannot be used for this: ``list("128")`` is ['1', '2', '8'].
    """
    if isinstance(value, (list, tuple)):
        return [int(item) for item in value]
    parts = value.replace(",", " ").split()
    if not parts:
        raise argparse.ArgumentTypeError("expected a list of integers, got %r" % (value,))
    try:
        return [int(part) for part in parts]
    except ValueError:
        raise argparse.ArgumentTypeError(
            "expected a list of integers, got %r" % (value,)
        ) from None


# One row per command-line option: (name, type, default, help text).
# The names match the parameters of main() one-to-one.
_CLI_OPTIONS = (
    ("batch_size", int, 1, "Batch size"),
    ("shuffle", _str2bool, True, "Whether or not to shuffle the train/val split"),
    ("num_workers", int, 0, "The number of workers to use"),
    ("dims", _int_list, [128, 128, 128], "The dimensions of the subcubes"),
    ("overlaps", _int_list, [0, 0, 0], "The dimensions of the overlap of subcubes"),
    ("root", str, "/home/cl/sdc2/dataset", "The root directory of the data"),
    ("random_seed", int, 42, "Random Seed"),
    ("train_size", float, 0.8, "Ratio of training to validation split"),
    ("model", str, "VNET", "The 3D segmentation model to use"),
    ("opt", str, "adam", "The type of optimizer"),
    ("lr", float, 1e-4, "The learning rate"),
    ("inChannels", int, 1, "The desired modalities/channels that you want to use"),
    ("classes", int, 2, "The number of classes"),
    (
        "log_dir",
        str,
        "/home/cl/sdc2/HISourceFinder-master-l",
        "The directory to output the logs",
    ),
    ("dataset_name", str, "hi_source", "The name of the dataset"),
    ("terminal_show_freq", int, 500, "Show when to print progress"),
    ("nEpochs", int, 50, "The number of epochs"),
    ("scale", str, "", "The scale of inserted galaxies to noise"),
    ("subsample", int, 10, "The size of subset to train on"),
    ("cuda", _str2bool, True, "Memory allocation"),
    ("k_folds", int, 5, "Number of folds for k folds cross-validations"),
    ("pretrained", str, None, "The location of the pretrained model"),
    ("retrain", str, None, "The location of the pretrained model to re-train with"),
    ("load_data_loc", str, "", "The location of the data windows"),
    ("jobid", str, "y6-12x6-12_attention_trans", "The job ID"),
)


def build_parser():
    """Return the argument parser for the training script.

    Every option takes at most one value. An option given without a value
    falls back to its default (``const`` is the default itself), so a bare
    ``--batch_size`` means 1 rather than the literal string "default".
    """
    parser = argparse.ArgumentParser(
        description="Train model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for name, kind, default, help_text in _CLI_OPTIONS:
        parser.add_argument(
            "--" + name,
            type=kind,
            nargs="?",
            const=default,
            default=default,
            help=help_text,
        )
    return parser


if __name__ == "__main__":
    # `args` is deliberately a module-level name: it is the parsed namespace
    # for this run.
    args = build_parser().parse_args()

    # The option names are exactly main()'s parameter names, so the namespace
    # can be expanded into keyword arguments.
    main(**vars(args))
