import argparse
import os
import numpy
import random
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.backends.cudnn as cudnn
from utils.dataset import DATASET
from torch.autograd import Variable
from utils.ImagePool import ImagePool
from model.Generator import Generator
from utils.fitsFun import DATASET_fits
from utils.fitsFun import LoadSaveFits

# Command-line options of the training script. The parsed namespace `opt`
# is read by the module-level code that follows.
parser = argparse.ArgumentParser(description='train pix2pix model')
parser.add_argument('--batchSize', type=int, default=100, help='with batchSize=1 equivalent to instance normalization.')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=150, help='number of iterations to train for')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate, default=0.0001')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay in network D, default=1e-4')
# `--cuda` is a store_true flag whose default is already True, so passing it
# changes nothing; it is kept so existing command lines keep working.
# `--no_cuda` stores False into the same destination and is the way to
# request a CPU run.
parser.add_argument('--cuda', default=True, action='store_true', help='enables cuda (enabled by default)')
parser.add_argument('--no_cuda', dest='cuda', action='store_false', help='disables cuda and runs on the CPU')
parser.add_argument('--outf', default='checkpoints/', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--dataPath_train', default='./dataset/train/', help='image data')
parser.add_argument('--dataPath_val', default='./dataset/val/', help='image data')
parser.add_argument('--loadSize', type=int, default=500, help='scale image to this size')
parser.add_argument('--fineSize', type=int, default=500, help='random crop image to this size')
parser.add_argument('--flip', type=int, default=1, help='1 for flipping image randomly, 0 for not')
parser.add_argument('--input_nc', type=int, default=1, help='channel number of input image')
parser.add_argument('--output_nc', type=int, default=1, help='channel number of output image')
parser.add_argument('--kernelSize', type=int, default=3, help='random crop kernel to this size')
parser.add_argument('--G_AB', default='', help='path to pre-trained G_AB')
parser.add_argument('--save_step', type=int, default=10, help='save interval')
parser.add_argument('--log_step', type=int, default=10, help='log interval')
parser.add_argument('--loss_type', default='mse', help='GAN loss type, bce|mse, default=mse')
parser.add_argument('--poolSize', type=int, default=50, help='size of buffer in lsGAN, poolSize=0 indicates not using history')
parser.add_argument("--gpus", default="0", type=str, help="gpu ids (default: 0)")
opt = parser.parse_args()
print(opt)

# Make sure the checkpoint folder exists. An already existing folder is
# accepted; any other failure (for example missing permissions) is raised
# here instead of surfacing later when the first checkpoint is written.
if opt.outf:
    os.makedirs(opt.outf, exist_ok=True)

# Seed Python's and torch's RNGs from --manualSeed, drawing a seed when none
# was given.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

if opt.cuda:
    print("=> use  gpu id: '{}'".format(opt.gpus))
    # The visible devices come only from --gpus; nothing below overrides
    # this value, so the user's selection is honoured.
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    if not torch.cuda.is_available():
        raise Exception("No GPU found or Wrong gpu id, please run without --cuda")

# A second seed is drawn from Python's RNG, which was seeded with
# opt.manualSeed above, so its value is determined by manualSeed. It
# replaces the torch seed set earlier.
opt.seed = random.randint(1, 10000)
print("Random Seed: ", opt.seed)
torch.manual_seed(opt.seed)
cudnn.benchmark = True

##########      dataset fits  #############
# Both splits are served with identical loader settings: batches of
# opt.batchSize, reshuffled on every pass, loaded in the main process.
_loader_options = dict(batch_size=opt.batchSize, shuffle=True, num_workers=0)

# Training split and an iterator over it, consumed by the training loop.
dataset_train = DATASET_fits(opt.dataPath_train, opt.fineSize)
loader_train = torch.utils.data.DataLoader(dataset=dataset_train, **_loader_options)
loadertrain = iter(loader_train)

# Validation split and an iterator over it.
dataset_val = DATASET_fits(opt.dataPath_val, opt.fineSize)
loader_val = torch.utils.data.DataLoader(dataset=dataset_val, **_loader_options)
loaderval = iter(loader_val)

# Image history buffer sized by --poolSize.
ABPool = ImagePool(opt.poolSize)

############   MODEL   ###########
# custom weights initialization, applied to G_AB when no pre-trained weights are loaded
def weights_init(m):
    """Initialise one module in place; intended for use with ``Module.apply``.

    The module is selected by its class name:

    * names containing ``'Conv'``: weight drawn from N(0.0, 0.02);
    * otherwise, names containing ``'BatchNorm'``: weight drawn from
      N(1.0, 0.02) and bias filled with 0.

    Modules that match by name but carry no ``weight`` (or ``bias``) tensor,
    such as a normalisation layer built with ``affine=False`` or a container
    class whose name happens to contain ``'Conv'``, are left untouched
    instead of raising ``AttributeError``.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            bias.data.fill_(0)
# Short aliases for the filter-count options.
ndf = opt.ndf
ngf = opt.ngf

# The generator is the only network built by this script.
G_AB = Generator(opt.input_nc, opt.output_nc, opt.ngf)

if opt.G_AB:
    print('Warning! Loading pre-trained weights.')
    # Load onto the CPU first so a checkpoint written on a different device
    # (for example another GPU index) can still be restored; the weights are
    # moved to the GPU below when CUDA is enabled.
    # NOTE: torch.load unpickles the file, so only load checkpoints from a
    # trusted source.
    G_AB.load_state_dict(torch.load(opt.G_AB, map_location='cpu'))
else:
    # No checkpoint given: start from the custom random initialisation.
    G_AB.apply(weights_init)

if opt.cuda:
    G_AB.cuda()

###########   LOSS & OPTIMIZER   ##########
# Criterion applied between generator output and input during training and
# validation. Despite its name it is an L1 (mean absolute error) loss.
criterionMSE = nn.L1Loss()

# Secondary criterion selected by --loss_type: BCE for 'bce', MSE for any
# other value.
criterion = nn.BCELoss() if opt.loss_type == 'bce' else nn.MSELoss()

# Adam over all generator parameters.
optimizerG = torch.optim.Adam(
    G_AB.parameters(),
    lr=opt.lr,
    betas=(opt.beta1, 0.999),
)

############   GLOBAL VARIABLES   ###########
# Shorthand copies of frequently used options.
input_nc, output_nc = opt.input_nc, opt.output_nc
fineSize, batchSize, kernelSize = opt.fineSize, opt.batchSize, opt.kernelSize

# Lowest validation loss seen so far and the per-iteration loss histories.
loss_least = float('inf')
train_loss, val_loss = [], []

# Uninitialised float buffers; real_A is resized and refilled with each
# batch before it is fed to the generator.
_buffer_shape = (opt.batchSize, input_nc, fineSize, fineSize)
real_A = Variable(torch.FloatTensor(*_buffer_shape))
AB = Variable(torch.FloatTensor(*_buffer_shape))

# Move the buffers and both criteria to the GPU when CUDA is enabled.
if opt.cuda:
    real_A, AB = real_A.cuda(), AB.cuda()
    criterion.cuda()
    criterionMSE.cuda()

###########   Val    ###########
def val(niter):
    """Score the generator on one validation batch and save the images.

    A new iterator over ``loader_val`` is created on every call and only its
    first batch is used. Channel 0 of that batch is copied into the shared
    ``real_A`` buffer, passed through ``G_AB`` and compared with the input
    using ``criterionMSE``. The forward pass runs under ``torch.no_grad()``
    so no autograd graph is built for validation.

    NOTE(review): the generator's train/eval mode is not changed here, so it
    runs in whatever mode the caller left it in -- confirm this is intended.

    Args:
        niter: iteration number, used in the saved file-name prefixes.

    Returns:
        The loss on the validation batch as a numpy array.
    """
    batch = next(iter(loader_val))
    source = batch[:, 0:1, :, :]
    real_A.data.resize_(source.size()).copy_(source)

    with torch.no_grad():
        AB = G_AB(real_A)
        errG = criterionMSE(AB, real_A)

    # Detach from the device so the value can be plotted and compared.
    errG = errG.data.cpu().numpy()

    LoadSaveFits.save_fit(real_A.data, 'realA_%03d_' % niter, './out_picture/out_image_train/')
    LoadSaveFits.save_fit(AB.data, 'AB_%03d_' % niter, './out_picture/out_image_train/')
    return errG

###########   Training   ###########
# Main optimisation loop: one generator update per iteration, followed by a
# validation pass, a refreshed loss plot, periodic logging and a checkpoint
# whenever the validation loss improves. opt.save_step is not consulted.
G_AB.train()
for iteration in range(1, opt.niter + 1):
    # Restart the training iterator once the loader is exhausted.
    try:
        imgA = next(loadertrain)
    except StopIteration:
        loadertrain = iter(loader_train)
        imgA = next(loadertrain)

    # Only channel 0 of the batch is used as the network input.
    real_A.resize_(imgA[:, 0:1, :, :].size()).copy_(imgA[:, 0:1, :, :])

    G_AB.zero_grad()
    AB = G_AB(real_A)

    ###########    Loss     ############
    # The generator output is compared against its own input.
    errG = criterionMSE(AB, real_A)
    train_loss.append(errG.item())

    errG.backward()
    optimizerG.step()

    ###########  Visualize  ############
    loss = val(iteration)
    val_loss.append(loss)
    # Clear the figure first; otherwise every iteration would stack two more
    # curves on top of the previous ones.
    plt.clf()
    plt.plot(train_loss, label='train')
    plt.plot(val_loss, label='val')
    plt.legend()
    plt.savefig("./loss.jpg")

    ###########   Logging   ############
    # Log once every opt.log_step iterations.
    if iteration % opt.log_step == 0:
        print('[%d/%d] Loss_MAE: %.4f '% (iteration, opt.niter, train_loss[-1]))

    # Keep the weights that achieved the lowest validation loss so far.
    if loss < loss_least:
        loss_least = loss
        torch.save(G_AB.state_dict(), '{}/G_AB_latest.pth'.format(opt.outf))
    




