
import argparse
import os
import sys
script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, script_dir)
import random
import torch
import numpy as np
import torch.nn as nn
import torch.backends.cudnn as cudnn
# from utils.dataset import DATASET 
from model.Generator import Generator
from utils.fitsFun import DATASET_fits
# from Tools import get_mask
from torch.autograd import Variable
import csv
import shutil
import math

def init_quality_assessment(args=None):
    """
    Parse options and build the model/state used by ``get_quality_score``.

    Must be called once before ``get_quality_score``. Sets the module globals:

    * ``opt``          -- parsed options
    * ``G_AB``         -- Generator loaded with the pre-trained weights from ``--G_AB``
    * ``device``       -- ``cuda`` when ``opt.cuda`` is set, otherwise ``cpu``
    * ``real_A``       -- reusable float32 input buffer on ``device``
    * ``criterionMSE`` -- the patch loss (``nn.L1Loss``, despite the name)

    Parameters
    ----------
    args : list of str, optional
        Command-line style arguments to parse. ``None`` (default) parses
        ``sys.argv[1:]`` exactly as before; pass an explicit list (e.g. ``[]``)
        when calling from another program whose own CLI flags this parser
        would otherwise reject.

    Raises
    ------
    RuntimeError
        CUDA is requested (the default) but no GPU is available.
    ValueError
        ``--G_AB`` is empty; scoring with an untrained generator is meaningless.
    """
    global opt, real_A, criterionMSE, G_AB, device
    parser = argparse.ArgumentParser(description='get quality score')
    parser.add_argument('--batchSize', type=int, default=4, help='with batchSize=1 equivalent to instance normalization.')
    parser.add_argument('--ngf', type=int, default=64)
    parser.add_argument('--ndf', type=int, default=64)
    parser.add_argument('--patch_size', type=int, default=50)
    parser.add_argument('--cuda', default=True, action='store_true', help='enables cuda')
    # '--cuda' defaults to True, so a store_false twin is the only way to run on CPU.
    parser.add_argument('--no-cuda', dest='cuda', action='store_false', help='run on CPU')
    parser.add_argument('--manualSeed', type=int, help='manual seed')
    parser.add_argument('--dataPath', default='datapath', type=str, help='path to training images')  # data
    parser.add_argument('--loadSize', type=int, default=500, help='scale image to this size')
    parser.add_argument('--fineSize', type=int, default=500, help='random crop image to this size')
    parser.add_argument('--flip', type=int, default=0, help='1 for flipping image randomly, 0 for not')
    parser.add_argument('--input_nc', type=int, default=1, help='channel number of input image')
    parser.add_argument('--output_nc', type=int, default=1, help='channel number of output image')
    parser.add_argument('--G_AB', default=script_dir + '/checkpoints/G_AB_latest.pth', help='path to pre-trained G_AB')
    parser.add_argument('--imgNum', type=int, default=2, help='image number')
    parser.add_argument("--gpus", default="1", type=str, help="gpu ids (default: 1)")
    opt = parser.parse_args(args)

    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    if opt.cuda:
        print("=> use gpu id: '{}'".format(opt.gpus))
        # Only effective if CUDA has not yet been initialised in this process.
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
        if not torch.cuda.is_available():
            raise RuntimeError("No GPU found or Wrong gpu id, please run with --no-cuda")
    # Derive the device from the option so the model, buffer and loss all agree.
    device = torch.device('cuda' if opt.cuda else 'cpu')

    ###########  MODEL   ###########
    if not opt.G_AB:
        raise ValueError('ERROR! G_AB must be provided!')
    G_AB = Generator(opt.input_nc, opt.output_nc, opt.ngf)
    # Load on CPU first so checkpoints saved on any device can be read.
    G_AB.load_state_dict(torch.load(opt.G_AB, map_location=torch.device('cpu'), weights_only=True))
    G_AB.to(device)
    # NOTE(review): G_AB stays in its default (training) mode, as before. If the
    # Generator uses BatchNorm/Dropout, confirm whether .eval() is intended.

    ###########   GLOBAL VARIABLES   ###########
    # Uninitialised float32 buffer; get_quality_score resizes and fills it.
    real_A = torch.empty(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize,
                         dtype=torch.float32, device=device)

    # Name kept for compatibility; this is a mean absolute error (L1) loss.
    criterionMSE = nn.L1Loss().to(device)

def norm2(img, z=None):
    """
    Min-max scale, then standardise, each of the first ``z`` slices of ``img``.

    Each slice ``img[i]`` is mapped to [0, 1] by min-max scaling, then shifted
    to zero mean and divided by its standard deviation. A constant slice
    (zero range / zero std) becomes all zeros instead of NaN.

    Parameters
    ----------
    img : numpy.ndarray or list of numpy.ndarray
        Stack of slices indexed along the first axis. Floating-point arrays
        are modified in place. Integer ndarrays cannot hold the normalised
        values, so they are first copied to float32.
    z : int, optional
        Number of leading slices to normalise. Defaults to ``len(img)``.

    Returns
    -------
    The normalised stack. It is the same object as ``img`` unless ``img``
    was an integer ndarray.
    """
    if isinstance(img, np.ndarray) and not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float32)
    if z is None:
        z = len(img)
    for i in range(z):
        lo = np.min(img[i])
        span = np.max(img[i]) - lo
        shifted = img[i] - lo
        if span != 0:
            img[i] = shifted / span  # min-max normalisation to [0, 1]
        else:
            img[i] = np.zeros_like(shifted, dtype=np.float32)  # constant slice
        img[i] -= np.mean(img[i])  # zero mean
        std = np.std(img[i])
        if std != 0:
            img[i] /= std  # unit variance
        # For an ndarray `img` this assignment is cast back to img's dtype;
        # it only changes the element type when `img` is a list.
        img[i] = np.array(img[i], dtype='float32')
    return img

def get_quality_score(img_data):
    """
    Compute a quality score for a 2-D image from generator reconstruction error.

    The image is copied and normalised with ``norm2``. It is then tiled into
    non-overlapping ``opt.fineSize`` x ``opt.fineSize`` patches; any remainder
    at the right/bottom edge is ignored. Each patch is passed through ``G_AB``,
    and the error between output and patch is measured with ``criterionMSE``.
    ``init_quality_assessment`` sets ``criterionMSE`` to ``nn.L1Loss``, i.e. a
    mean absolute error, despite the name.

    Parameters
    ----------
    img_data : numpy.ndarray
        2-D image of shape (height, width). It is copied and never modified.
        Integer or non-native-byte-order data is accepted.

    Returns
    -------
    quality_score : float
        ``1 / (1 + exp(mean_patch_loss))``. Because the L1 loss is
        non-negative, the score lies in (0, 0.5]. Higher values mean lower
        reconstruction error, i.e. better quality.

    Raises
    ------
    RuntimeError
        If ``init_quality_assessment`` has not been called.
    ValueError
        If the image is not 2-D, or is smaller than ``opt.fineSize`` in either
        dimension so that no patch can be extracted.
    """
    if globals().get('G_AB') is None:
        raise RuntimeError('call init_quality_assessment() before get_quality_score()')

    arr = np.asarray(img_data)
    if arr.ndim != 2:
        raise ValueError('expected a 2-D image, got shape {}'.format(arr.shape))
    # Copy (np.array copies) so norm2's in-place normalisation never touches
    # the caller's array. Integers are promoted to float32 so they can hold
    # normalised values. The byte order is made native because FITS data is
    # often big-endian, which torch.from_numpy rejects.
    target_dtype = arr.dtype if np.issubdtype(arr.dtype, np.floating) else np.dtype(np.float32)
    img = np.array(arr[np.newaxis], dtype=target_dtype.newbyteorder('='))  # (1, h, w)
    img = norm2(img, img.shape[0])

    tensor = torch.from_numpy(img).unsqueeze(0)  # (1, 1, h, w)
    # copy_ converts dtype and moves the data onto real_A's device.
    real_A.resize_(tensor.shape).copy_(tensor)
    _, _, image_h, image_w = real_A.shape

    size = opt.fineSize
    num_y = image_h // size
    num_x = image_w // size
    if num_x == 0 or num_y == 0:
        raise ValueError('image of shape {} is smaller than the {}x{} patch size (opt.fineSize)'
                         .format(arr.shape, size, size))

    losses = []
    # Inference only: no autograd graph is needed for the scores.
    with torch.no_grad():
        for x in range(num_x):
            for y in range(num_y):
                patch = real_A[:, :, size * y:size * (y + 1), size * x:size * (x + 1)]
                losses.append(criterionMSE(G_AB(patch), patch).item())
    loss = np.mean(losses)
    return 1 / (1 + np.exp(loss))


