import PIL.Image as Image
from astropy.io import fits
import torchvision.transforms as transforms
import numpy as np
import torch
import os
from skimage import transform
import torch.utils.data as Data
from skimage import exposure
import random

def cut(image, width, height, start=7, stop=249, out_size=256, detect_border=False):
    """Crop a (channels, rows, cols) image and resize it to (1, out_size, out_size).

    By default the crop is the fixed window ``image[:, start:stop, start:stop]``
    (7:249 on both spatial axes). With ``detect_border=True`` the crop instead
    spans the bounding box of the non-zero pixels inside
    ``image[:, :height, :width]``, i.e. the black border is trimmed. An all-zero
    image is left uncropped rather than raising.

    Args:
        image: array indexed as ``image[channel, row, col]``.
        width: number of columns scanned when ``detect_border`` is True.
        height: number of rows scanned when ``detect_border`` is True.
        start: first row/column of the fixed crop window.
        stop: end (exclusive) row/column of the fixed crop window.
        out_size: spatial size of the resized output.
        detect_border: trim zero-valued border rows/columns instead of using
            the fixed window.

    Returns:
        The cropped image resized (anti-aliased) with ``skimage.transform.resize``
        to shape ``(1, out_size, out_size)``.
    """
    if detect_border:
        region = image[:, :height, :width]
        nonzero = region != 0
        # A row/column belongs to the border when it is zero in every channel.
        rows = np.flatnonzero(nonzero.any(axis=(0, 2)))
        cols = np.flatnonzero(nonzero.any(axis=(0, 1)))
        if rows.size:
            cropped_image = region[:, rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]
        else:
            cropped_image = region  # all-zero image: nothing to trim
    else:
        cropped_image = image[:, start:stop, start:stop]
    return transform.resize(cropped_image, (1, out_size, out_size), anti_aliasing=True)
    
    
    
class LoadSaveFits:
    """Helpers for reading, normalizing and writing FITS / JPEG images.

    Every helper is a static method and can be called on the class
    (``LoadSaveFits.read_fits(path)``) or on an instance. The constructor only
    stores its arguments; none of the helpers read them.
    """

    def __init__(self, path, img, name):
        self.path = path
        self.img = img
        self.name = name

    @staticmethod
    def norm(img):
        """Min-max scale ``img`` to [0, 1], then standardize to zero mean / unit std.

        Args:
            img: numeric array-like.

        Returns:
            A new float32 array with the same shape as ``img``. A constant
            input (max == min) yields all zeros instead of NaN.
        """
        img = np.asarray(img)
        lo, hi = np.min(img), np.max(img)
        if hi == lo:
            # Avoid 0/0, which would otherwise turn the whole array into NaN.
            return np.zeros(img.shape, dtype='float32')
        img = (img - lo) / (hi - lo)  # normalization to [0, 1]
        img = img - np.mean(img)  # zero mean
        img = img / np.std(img)  # unit standard deviation
        return np.array(img, dtype='float32')

    @staticmethod
    def norm2(img, z=None):
        """Standardize each of the first ``z`` slices ``img[i]`` independently.

        Each slice is min-max scaled to [0, 1], then shifted to zero mean and
        divided by its standard deviation.

        Args:
            img: numpy array indexed as ``img[i]`` per slice.
            z: number of leading slices to normalize; defaults to ``img.shape[0]``.

        Returns:
            ``img`` normalized in place when it already has a floating dtype;
            otherwise a normalized float32 copy. Constant slices become zeros.
        """
        if z is None:
            z = img.shape[0]
        if not np.issubdtype(img.dtype, np.floating):
            # Writing float results into an integer array would truncate them
            # (or fail for the in-place -=/ /= steps), so work on a float copy.
            img = img.astype(np.float32)
        for i in range(z):
            lo, hi = np.min(img[i]), np.max(img[i])
            if hi == lo:
                img[i] = 0  # avoid 0/0 -> NaN on a constant slice
                continue
            img[i] = (img[i] - lo) / (hi - lo)  # normalization to [0, 1]
            img[i] -= np.mean(img[i])  # zero mean
            img[i] /= np.std(img[i])  # unit standard deviation
        return img

    @staticmethod
    def read_fits(path):
        """Return the primary-HDU data of the FITS file at ``path`` as float32.

        The data is copied into a new array before the file is closed, so the
        result stays valid after closing (``fits.open`` may memory-map data).
        The file is closed even if reading fails.
        """
        with fits.open(path) as hdu:
            img = np.array(hdu[0].data, dtype=np.float32)
        return img

    @staticmethod
    def _write_fits(img, name, path):
        """Write ``img`` as a single primary HDU to ``path + name + '.fit'``, replacing any existing file."""
        greyHDU = fits.HDUList([fits.PrimaryHDU(img)])
        greyHDU.writeto(path + name + '.fit', overwrite=True)

    @staticmethod
    def save_fit_cpu(img, name, path):
        """Save an array to ``path + name + '.fit'`` (overwriting)."""
        LoadSaveFits._write_fits(img, name, path)

    @staticmethod
    def save_fit(img, name, path):
        """Save a torch tensor or array-like to ``path + name + '.fit'`` (overwriting).

        Tensors are detached and copied to host memory whatever device they
        are on; any other input is converted with ``np.array``.
        """
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu().numpy()
        else:
            img = np.array(img)
        LoadSaveFits._write_fits(img, name, path)

    @staticmethod
    def save_jpg(img, name, path):
        """Min-max scale a tensor to [0, 1] and save it as ``path + name + '.jpg'``.

        ``img`` must be a torch tensor in a layout accepted by
        ``transforms.ToPILImage`` — presumably (C, H, W) or (H, W); confirm
        against callers. Min-max scaling is used rather than a fixed
        [-1, 1] -> [0, 1] de-normalization. A constant image is saved as black
        instead of producing NaN.
        """
        img = img.detach().float()
        lo, hi = img.min(), img.max()
        if hi > lo:
            img = (img - lo) / (hi - lo)
        else:
            img = torch.zeros_like(img)
        img = transforms.ToPILImage()(img)
        img.save(path + name + ".jpg")

# Dataset that loads FITS images for CycleGAN training
class DATASET_fits():
    """Map-style dataset over the ``.fit`` files in one directory.

    Each item is the file's primary-HDU image with a leading channel axis
    added, normalized per channel by ``LoadSaveFits.norm2``, and returned as a
    torch tensor. Files that fail to load or are not 2-D yield a zero tensor
    of shape ``FALLBACK_SHAPE`` instead.
    """

    # Returned in place of an image that fails to load or normalize.
    FALLBACK_SHAPE = (1, 256, 256)

    def __init__(self, dataPath='', fineSize=512):
        super(DATASET_fits, self).__init__()
        # Keep only .fit files (case-insensitive), sorted for a deterministic order.
        self.list = sorted(
            filename for filename in os.listdir(dataPath)
            if filename.lower().endswith('.fit')
        )
        self.dataPath = dataPath
        # Stored for callers; no cropping to this size is currently applied.
        self.fineSize = fineSize

    def __getitem__(self, index):
        path = os.path.join(self.dataPath, self.list[index])
        try:
            img = LoadSaveFits.read_fits(path)
            # Add a leading channel axis: (h, w) -> (1, h, w).
            img = np.expand_dims(img, axis=0)
            # Unpacking also rejects non-2-D data, sending it to the fallback below.
            z, _, _ = img.shape
            img = LoadSaveFits.norm2(img, z)
            return torch.from_numpy(img)
        except Exception as e:
            # Best-effort: a bad file yields a zero image instead of aborting the epoch.
            print(f"Error processing file {path}: {e}")
            return torch.zeros(self.FALLBACK_SHAPE, dtype=torch.float32)

    def __len__(self):
        """Return the number of ``.fit`` files found at construction time."""
        return len(self.list)











