import torch
import numpy as np
import os
from skimage import io
import torch.utils.data as data
from os import listdir
from os.path import join
# from utils import is_image_file
import os
from PIL import Image
import random
import cv2
from torchvision import transforms as T
import matplotlib.pyplot as plt
import torchvision.utils as vutils


def is_image_file(filename):
    """Return True if ``filename`` ends in .png, .jpg or .jpeg (case-insensitive).

    Case-insensitive so that files such as ``IMG_001.JPG`` are not silently
    skipped when building the dataset listing.
    """
    return filename.lower().endswith((".png", ".jpg", ".jpeg"))


def default_loader(path):
    """Open the image file at ``path`` and return it converted to RGB mode.

    The file is opened in a ``with`` block so the underlying handle is closed
    right away instead of lingering until garbage collection. ``convert``
    returns a new, fully loaded image, so the result remains usable after
    the file is closed.
    """
    with Image.open(path) as im:
        return im.convert('RGB')


def save_jpg(img, name, path):
    """Min-max normalize an image tensor to [0, 1] and save it as a JPEG.

    Args:
        img: Image tensor; presumably ``C x H x W`` as accepted by
            ``torchvision.transforms.ToPILImage`` — confirm at call sites.
            The input tensor is not modified.
        name: File name without the ``.jpg`` extension.
        path: Output prefix. It is concatenated directly with ``name``, so it
            should end with a path separator (e.g. ``"out/"``).
    """
    # Note: any earlier [-1, 1] scaling does not need undoing here, because
    # min-max normalization maps whatever range the tensor has onto [0, 1].
    img = img.float()
    lo = img.min()
    span = img.max() - lo
    if span > 0:
        img = (img - lo) / span
    else:
        # A constant image would divide by zero and produce NaNs; save it as
        # an all-black image instead.
        img = torch.zeros_like(img)
    pil_img = T.ToPILImage()(img)
    pil_img.save(path + name + ".jpg")


# Image-folder dataset: loads each image, resizes/crops/flips it, and scales pixel values to [-1, 1].
class DATASET(data.Dataset):
    """Dataset over every image file directly inside one directory.

    For each item, the image is:
      1. loaded as RGB;
      2. resized to ``loadSize x loadSize`` if it is not already that size;
      3. randomly cropped to ``fineSize x fineSize`` when the two sizes differ;
      4. flipped horizontally with probability 0.5 when ``flip == 1``;
      5. converted to a ``C x H x W`` float tensor and rescaled from [0, 1]
         to [-1, 1].

    Args:
        dataPath: Directory scanned (non-recursively) for .png/.jpg/.jpeg
            files. An empty string means the current working directory.
        loadSize: Side length images are resized to before cropping.
        fineSize: Side length of the random square crop. It must not exceed
            ``loadSize``, otherwise ``random.randint`` raises ValueError.
        flip: When equal to 1, apply a random horizontal flip.
    """

    def __init__(self, dataPath='', loadSize=72, fineSize=64, flip=1):
        super(DATASET, self).__init__()
        # os.listdir returns entries in arbitrary order. Sorting makes the
        # item order (and therefore shuffle=False iteration) reproducible.
        # listdir('') raises, so an empty path falls back to '.'.
        self.list = sorted(
            x for x in listdir(dataPath or '.') if is_image_file(x)
        )
        self.dataPath = dataPath
        self.loadSize = loadSize
        self.fineSize = fineSize
        self.flip = flip

    def __getitem__(self, index):
        """Load, augment and normalize the image at ``index``.

        Returns:
            A float tensor of shape ``(3, fineSize, fineSize)`` with values
            in [-1, 1].
        """
        path = os.path.join(self.dataPath, self.list[index])
        img = default_loader(path)

        # Compare both sides: checking only the height let non-square images
        # that already had the right height through unresized.
        if img.size != (self.loadSize, self.loadSize):
            img = img.resize((self.loadSize, self.loadSize), Image.BILINEAR)
        if self.loadSize != self.fineSize:
            x1 = random.randint(0, self.loadSize - self.fineSize)
            y1 = random.randint(0, self.loadSize - self.fineSize)
            img = img.crop((x1, y1, x1 + self.fineSize, y1 + self.fineSize))

        # Short-circuit keeps random.random() from being drawn when flip is off.
        if self.flip == 1 and random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

        img = T.ToTensor()(img)  # 3 x fineSize x fineSize, values in [0, 1]

        # Map [0, 1] -> [-1, 1].
        img = img.mul_(2).add_(-1)
        return img

    def __len__(self):
        """Return the number of image files found in ``dataPath``."""
        return len(self.list)


class LoadSaveJpg:
    """Bundle a batched image tensor with its output name and location.

    Args:
        path: Output prefix. It is concatenated directly with ``name``, so it
            should end with a path separator.
        img: Batched image tensor (indexed as ``img[0, :, :, :]``); only the
            first item of the batch is saved.
        name: File name without the ``.jpg`` extension.
    """

    def __init__(self, path, img, name):
        self.path = path
        self.img = img
        self.name = name

    @staticmethod
    def save_jpg(img, name, path):
        """Save the first item of the batch ``img`` via the module-level ``save_jpg``.

        Declared static so that both ``LoadSaveJpg.save_jpg(img, name, path)``
        and ``instance.save_jpg(img, name, path)`` accept the three explicit
        arguments. Without a ``self`` parameter, the instance form raised
        TypeError.
        """
        # Inside a method, the bare name resolves to the module-level function,
        # not to this staticmethod, because class scope is skipped in lookup.
        save_jpg(img[0, :, :, :], name, path)

    def save(self):
        """Save the stored image using the stored ``name`` and ``path``."""
        LoadSaveJpg.save_jpg(self.img, self.name, self.path)


if __name__ == "__main__":
    # Demo: load the images in ../DATA1/ at 424x424. There is no random crop
    # because loadSize == fineSize. The first image of each batch is written
    # to ./1.jpg.
    path = "../DATA1/"
    datasetA = DATASET(path, 424, 424)
    loader_A = torch.utils.data.DataLoader(dataset=datasetA,
                                           batch_size=2,
                                           shuffle=False,
                                           num_workers=0)

    # Iterating the DataLoader directly ends cleanly when it is exhausted.
    # This replaces the manual iter()/next()/StopIteration loop.
    for x in loader_A:
        im = x[0]

        # Undo the dataset's [-1, 1] scaling back to [0, 1] before saving.
        im = im.add(1).mul(0.5)

        # NOTE: every batch writes to the same file, so only the last batch's
        # first image remains on disk.
        vutils.save_image(im.cpu().float(), "./1.jpg")