{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "import glob\n",
    "import math\n",
    "import os\n",
    "import time\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from astropy.io import fits\n",
    "from matplotlib.font_manager import FontProperties\n",
    "from scipy.signal import convolve2d\n",
    "from sklearn.decomposition import PCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_PsfBasicSample(directory_path):\n",
    "    \"\"\"\n",
    "    Load every '*.fits' file in a folder as a PSF basis sample.\n",
    "\n",
    "    Each image is min-max normalised to [0, 1] and then re-centred with\n",
    "    center_psf so that its brightest pixel sits at the array centre.\n",
    "\n",
    "    Parameters:\n",
    "    - directory_path: folder containing the '*.fits' basis samples.\n",
    "\n",
    "    Returns:\n",
    "    - A NumPy array of shape (n_files, H, W); all images must share one shape.\n",
    "\n",
    "    Raises:\n",
    "    - ValueError: if the folder holds no '*.fits' files or the images differ in shape.\n",
    "    \"\"\"\n",
    "    # glob order is filesystem-dependent; sort for a reproducible sample order.\n",
    "    fits_files = sorted(glob.glob(os.path.join(directory_path, '*.fits')))\n",
    "    if not fits_files:\n",
    "        raise ValueError(f'No *.fits files found in {directory_path!r}')\n",
    "\n",
    "    data_list = []\n",
    "    for fits_filename in fits_files:\n",
    "        # The context manager closes the file even if reading fails; np.array\n",
    "        # copies the pixels so they do not depend on the file being open.\n",
    "        with fits.open(fits_filename) as hdul:\n",
    "            data = np.array(hdul[0].data, dtype=float)\n",
    "\n",
    "        data_min = np.min(data)\n",
    "        data_range = np.max(data) - data_min\n",
    "        if data_range > 0:\n",
    "            data = (data - data_min) / data_range\n",
    "        else:\n",
    "            # A flat image would give 0/0 = NaN and poison the downstream PCA.\n",
    "            data = np.zeros_like(data)\n",
    "        data_list.append(center_psf(data))\n",
    "\n",
    "    # np.stack raises ValueError if the images do not all share one shape.\n",
    "    return np.stack(data_list, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_fits_data_from_folders(base_folder):\n",
    "    \"\"\"\n",
    "    Read the grid PSFs of every sub-folder of base_folder whose name starts with 'S'.\n",
    "\n",
    "    Files read from each sub-folder S...:\n",
    "    - S.../*.xlsx: collected (paths only) and returned alongside the PSFs.\n",
    "    - S.../psf/*.fit: PSF images; x and y are parsed from the first two\n",
    "      '_'-separated fields of the file name with its extension removed.\n",
    "\n",
    "    Parameters:\n",
    "    - base_folder: The base folder path containing multiple subfolders.\n",
    "\n",
    "    Returns:\n",
    "    - psf_data_dict: maps each sub-folder name to (psf_data_list, xlsx_files),\n",
    "      where psf_data_list holds (psf, x, y) tuples for square 2-D images only\n",
    "      (each psf normalised to unit sum) and xlsx_files is the sorted list of\n",
    "      '*.xlsx' paths found in that sub-folder.\n",
    "    \"\"\"\n",
    "    psf_data_dict = {}\n",
    "\n",
    "    # Directories only; glob order is filesystem-dependent, so sort it.\n",
    "    subfolders = sorted(\n",
    "        path for path in glob.glob(os.path.join(base_folder, 'S*'))\n",
    "        if os.path.isdir(path)\n",
    "    )\n",
    "\n",
    "    for folder_path in subfolders:\n",
    "        print(f'Processing folder: {folder_path}')\n",
    "        folder_name = os.path.basename(folder_path)\n",
    "        xlsx_files = sorted(glob.glob(os.path.join(folder_path, '*.xlsx')))\n",
    "        fits_files = sorted(glob.glob(os.path.join(folder_path, 'psf', '*.fit')))\n",
    "\n",
    "        psf_data_list = []\n",
    "        for fits_filename in fits_files:\n",
    "            # The context manager closes the file even if reading fails.\n",
    "            with fits.open(fits_filename) as hdul:\n",
    "                data = np.array(hdul[0].data, dtype=float)\n",
    "            if data.ndim != 2 or data.shape[0] != data.shape[1]:\n",
    "                continue  # process square images only\n",
    "\n",
    "            total = np.sum(data)\n",
    "            if total != 0:\n",
    "                data = data / total  # normalise to unit flux\n",
    "\n",
    "            # Strip the extension first so a name like '12.5_30.fit' still parses.\n",
    "            stem = os.path.splitext(os.path.basename(fits_filename))[0]\n",
    "            file_parts = stem.split('_')\n",
    "            x_coordinate = float(file_parts[0])\n",
    "            y_coordinate = float(file_parts[1])\n",
    "            psf_data_list.append((data, x_coordinate, y_coordinate))\n",
    "\n",
    "        psf_data_dict[folder_name] = psf_data_list, xlsx_files\n",
    "\n",
    "    return psf_data_dict\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_process_mul(psf_xy_dict, sample_dir=os.path.join('process_data', 'Sample')):\n",
    "    \"\"\"\n",
    "    Centre the grid PSFs of each folder and trim them to a square count.\n",
    "\n",
    "    For each folder only the first floor(sqrt(n))**2 PSFs are kept, so they can\n",
    "    later be reshaped into a square grid. Every kept PSF is also written to\n",
    "    sample_dir as '<folder_name>_<i>.fit'.\n",
    "\n",
    "    Parameters:\n",
    "    - psf_xy_dict: maps folder names to (psf_xy_data, xlsx_path), where\n",
    "      psf_xy_data is a list of (psf, x, y) tuples.\n",
    "    - sample_dir: folder for the per-PSF FITS dumps (created if missing).\n",
    "\n",
    "    Returns:\n",
    "    - psf_data_dict: folder name -> (array of centred PSFs, xlsx_path).\n",
    "    - coordinates_dict: folder name -> list of (x, y) tuples in PSF order.\n",
    "    \"\"\"\n",
    "    os.makedirs(sample_dir, exist_ok=True)\n",
    "    psf_data_dict = {}\n",
    "    coordinates_dict = {}\n",
    "\n",
    "    for folder_name, (psf_xy_data, xlsx_path) in psf_xy_dict.items():\n",
    "        # NOTE(review): the kept PSFs are later reshaped into a grid in list\n",
    "        # order; confirm that this order is the grid's row-major order.\n",
    "        grid_side = math.isqrt(len(psf_xy_data))\n",
    "        kept = psf_xy_data[:grid_side * grid_side]\n",
    "\n",
    "        psf_data = []\n",
    "        coordinates = []\n",
    "        for i, (psf, x_cord, y_cord) in enumerate(kept):\n",
    "            psf = center_psf(psf)\n",
    "            sample_path = os.path.join(sample_dir, f'{folder_name}_{i}.fit')\n",
    "            fits.PrimaryHDU(psf).writeto(sample_path, overwrite=True)\n",
    "            psf_data.append(psf)\n",
    "            coordinates.append((x_cord, y_cord))\n",
    "\n",
    "        psf_data_dict[folder_name] = np.array(psf_data), xlsx_path\n",
    "        coordinates_dict[folder_name] = coordinates\n",
    "\n",
    "    return psf_data_dict, coordinates_dict\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def center_psf(psf_image):\n",
    "    \"\"\"\n",
    "    Centers the PSF image in the middle of the image.\n",
    "    \n",
    "    Parameters:\n",
    "    - psf_image: A NumPy array containing the PSF image.\n",
    "    \n",
    "    Returns:\n",
    "    - A PSF image that has been centered.\n",
    "    \"\"\"\n",
    "    # Find the index of the maximum value (center of PSF)\n",
    "    max_index = np.unravel_index(np.argmax(psf_image), psf_image.shape)\n",
    "    \n",
    "    # Calculate the offset to center the PSF\n",
    "    offset_x = (psf_image.shape[1] - 1) / 2 - max_index[1]\n",
    "    offset_y = (psf_image.shape[0] - 1) / 2 - max_index[0]\n",
    "\n",
    "    # Shift the PSF to the center\n",
    "    centered_psf = np.roll(psf_image, int(offset_y), axis=0)\n",
    "    centered_psf = np.roll(centered_psf, int(offset_x), axis=1)\n",
    "\n",
    "    return centered_psf\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pca_if(resized_arrays, n_components=100,\n",
    "           save_path=os.path.join('Data', '0930_Compare_Out_Path', 'cumulative_explained_variance.png')):\n",
    "    \"\"\"\n",
    "    Plot the cumulative explained-variance ratio of a PCA over the samples.\n",
    "\n",
    "    Each sample is flattened and used as one PCA feature (pixels are the\n",
    "    observations).\n",
    "\n",
    "    Parameters:\n",
    "    - resized_arrays: array of images whose first axis indexes samples.\n",
    "    - n_components: component count to mark on the plot and to return.\n",
    "    - save_path: where to save the figure (its folder is created if needed);\n",
    "      pass None to skip saving.\n",
    "\n",
    "    Returns:\n",
    "    - The cumulative explained variance of the first n_components components\n",
    "      (fewer if the data yields fewer components).\n",
    "    \"\"\"\n",
    "    data_new = resized_arrays.reshape(resized_arrays.shape[0], -1).T\n",
    "    pca = PCA()\n",
    "    pca.fit(data_new)\n",
    "    cumulative_explained_variance = np.cumsum(pca.explained_variance_ratio_)\n",
    "\n",
    "    plt.figure(figsize=(8, 6))\n",
    "    plt.plot(range(1, len(cumulative_explained_variance) + 1), cumulative_explained_variance,\n",
    "             label='Cumulative Explained Variance')\n",
    "\n",
    "    # Clamp the marker so data with fewer than n_components components does\n",
    "    # not raise IndexError.\n",
    "    x_value = min(n_components, len(cumulative_explained_variance))\n",
    "    y_value = cumulative_explained_variance[x_value - 1]\n",
    "    plt.scatter(x_value, y_value, color='blue')\n",
    "    plt.annotate(f'({x_value}, {y_value:.4f})',\n",
    "                 xy=(x_value, y_value),\n",
    "                 xytext=(x_value + 5, y_value - 0.05),\n",
    "                 fontsize=14,\n",
    "                 arrowprops=dict(facecolor='black', arrowstyle='->'))\n",
    "\n",
    "    font_kwargs = dict(fontsize=22, fontfamily='Times New Roman')\n",
    "    plt.xlabel('Number of Components', **font_kwargs)\n",
    "    plt.ylabel('Cumulative Explained Variance', **font_kwargs)\n",
    "    plt.xticks(**font_kwargs)\n",
    "    plt.yticks(**font_kwargs)\n",
    "    plt.legend(prop=FontProperties(family='Times New Roman', size=22))\n",
    "    plt.grid(True)\n",
    "    if save_path:\n",
    "        save_dir = os.path.dirname(save_path)\n",
    "        if save_dir:\n",
    "            os.makedirs(save_dir, exist_ok=True)\n",
    "        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=1000)\n",
    "    plt.show()\n",
    "\n",
    "    return cumulative_explained_variance[:n_components]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def PCA_(psf_basis_sample, n_components):\n",
    "    \"\"\"\n",
    "    Reduce the PSF basis samples to n_components unit-sum basis images with PCA.\n",
    "\n",
    "    Pixels are the PCA observations and samples are the features, so each\n",
    "    column of the PCA scores can be reshaped back into an image.\n",
    "\n",
    "    Parameters:\n",
    "    - psf_basis_sample: array of shape (n_samples, H, W).\n",
    "    - n_components: The number of principal components to retain.\n",
    "\n",
    "    Returns:\n",
    "    - Array of shape (n_components, H, W): absolute PCA scores, each image\n",
    "      normalised to sum to 1.\n",
    "    \"\"\"\n",
    "    n_samples, height, width = psf_basis_sample.shape\n",
    "    data_new = psf_basis_sample.reshape(n_samples, -1).T\n",
    "    pca = PCA(n_components, svd_solver='full')\n",
    "    pca_components = np.abs(pca.fit_transform(data_new))\n",
    "    # Use height and width separately so non-square samples reshape correctly.\n",
    "    basis = pca_components.reshape(height, width, -1).transpose(2, 0, 1)\n",
    "    return basis / basis.sum(axis=(1, 2), keepdims=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def coeff_matrix(img_size, resized_arrays, normalized_psf_basic):\n",
    "    \"\"\"\n",
    "    Compute the coefficient matrix.\n",
    "    \n",
    "    Parameters:\n",
    "    - resized_arrays: A NumPy array containing all resized galaxy images.\n",
    "    - normalized_psf_basic: A NumPy array containing all the normalized PSF basis images.\n",
    "    \n",
    "    Returns:\n",
    "    - A NumPy array representing the coefficient matrix.\n",
    "    \"\"\"\n",
    "    X = np.stack([matrix.flatten() for matrix in normalized_psf_basic]).T\n",
    "    y = []\n",
    "    \n",
    "    coefficient_matrices = []\n",
    "    \n",
    "    for i in range(resized_arrays.shape[0]):\n",
    "        # Flatten each PSF and output the matrix as a 1D array\n",
    "        y = resized_arrays[i, :, :].flatten()\n",
    "        # Solve for the coefficients using least squares\n",
    "        coefficients, _, _, _ = np.linalg.lstsq(X, y, rcond=None)\n",
    "        coefficient_matrices.append(coefficients)\n",
    "\n",
    "    # Get the coefficient matrix\n",
    "    coeff_matrix = np.array(coefficient_matrices).reshape(resized_arrays.shape[0], normalized_psf_basic.shape[0])\n",
    "        \n",
    "    print(np.sum(coeff_matrix[0, :]))\n",
    "    print(np.sum(coeff_matrix))\n",
    "    # Reshape the coefficient matrix and interpolate to match the image size\n",
    "    sq = int(np.sqrt(coeff_matrix.shape[0]))\n",
    "    coeff_matrix = np.array(coeff_matrix).reshape(\n",
    "        sq, sq, normalized_psf_basic.shape[0])\n",
    "    \n",
    "    kk = torch.from_numpy(coeff_matrix).permute(2, 0, 1).unsqueeze(0)\n",
    "    print(np.shape(kk))\n",
    "    # Resize the coefficients using bilinear interpolation\n",
    "    zer = F.interpolate(kk, size=img_size,\n",
    "                        mode='bilinear', align_corners=False)\n",
    "    print(np.shape(zer))\n",
    "    coef = zer.squeeze(0).permute(1, 2, 0)\n",
    "    \n",
    "    # The interpolated coefficient matrix\n",
    "    output_array = np.array(coef)\n",
    "    return output_array\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Gen_Psf(coefficients, psf_basic, x_cord, y_cord):\n",
    "    \"\"\"\n",
    "    Generate a PSF image.\n",
    "    \n",
    "    Parameters:\n",
    "    - coefficients: A NumPy array containing all the coefficients.\n",
    "    - psf_basic: A NumPy array containing all the basic PSFs.\n",
    "    - x_cord: The x coordinate of the PSF.\n",
    "    - y_cord: The y coordinate of the PSF.\n",
    "    \n",
    "    Returns:\n",
    "    - A NumPy array representing the PSF image.\n",
    "    \"\"\"\n",
    "    x_cord = round(x_cord)\n",
    "    y_cord = round(y_cord)\n",
    "    coefficient = np.array(coefficients[x_cord, y_cord, :])\n",
    "    \n",
    "    result_psf = np.zeros_like(psf_basic[0])\n",
    "    # Multiply corresponding PSF and coefficients and sum them\n",
    "    for i in range(psf_basic.shape[0]):\n",
    "        result_psf += psf_basic[i].astype(float) * coefficient[i].astype(float) \n",
    "    \n",
    "    return result_psf\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_noise_around_psf(psf_image, kernel_size=15, sigma=5, region_half_size=25):\n",
    "    \"\"\"\n",
    "    Subtract a background offset from a PSF and clip negative pixels to zero.\n",
    "\n",
    "    The offset is (mean of a square window of the Gaussian-blurred image around\n",
    "    the array centre) minus (minimum of the original image).\n",
    "\n",
    "    Parameters:\n",
    "    - psf_image: A NumPy array containing the PSF image.\n",
    "    - kernel_size: size of the Gaussian kernel (OpenCV requires an odd value).\n",
    "    - sigma: The standard deviation of the Gaussian filter.\n",
    "    - region_half_size: half the side length of the central averaging window.\n",
    "\n",
    "    Returns:\n",
    "    - The offset-corrected PSF with negative values set to 0.\n",
    "    \"\"\"\n",
    "    blurred_image = cv2.GaussianBlur(psf_image, (kernel_size, kernel_size), sigma)\n",
    "\n",
    "    center_y, center_x = psf_image.shape[0] // 2, psf_image.shape[1] // 2\n",
    "    # Clamp the window start at 0. For PSFs smaller than 2 * region_half_size,\n",
    "    # a negative start would index from the far end and select the wrong pixels.\n",
    "    top = max(center_y - region_half_size, 0)\n",
    "    left = max(center_x - region_half_size, 0)\n",
    "    central_region = blurred_image[top:center_y + region_half_size,\n",
    "                                   left:center_x + region_half_size]\n",
    "\n",
    "    offset = np.mean(central_region) - np.min(psf_image)\n",
    "    return np.maximum(psf_image - offset, 0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_psf_image(coordinates, coefficients, psf_basis, img_size,\n",
    "                       psf_output_dir=os.path.join('process_data', 'Simulate'),\n",
    "                       brightness_range=(6000, 20000)):\n",
    "    \"\"\"\n",
    "    Render a star image by stamping a position-dependent PSF at each coordinate.\n",
    "\n",
    "    Parameters:\n",
    "    - coordinates: iterable of (x, y) positions; x indexes rows and y columns.\n",
    "    - coefficients: basis coefficients indexed as [x, y, k].\n",
    "    - psf_basis: array of basis PSFs, shape (K, h, w).\n",
    "    - img_size: shape of the output image.\n",
    "    - psf_output_dir: folder where each denoised PSF is written as '<x>_<y>.fit'\n",
    "      (created if missing); pass None to skip writing.\n",
    "    - brightness_range: (low, high) of the uniform random flux scale per star.\n",
    "\n",
    "    Returns:\n",
    "    - The rendered float image of shape img_size.\n",
    "    \"\"\"\n",
    "    image = np.zeros(img_size)\n",
    "    if psf_output_dir:\n",
    "        os.makedirs(psf_output_dir, exist_ok=True)\n",
    "\n",
    "    for x_cord, y_cord in coordinates:\n",
    "        x_cord = int(round(x_cord))\n",
    "        y_cord = int(round(y_cord))\n",
    "        psf = remove_noise_around_psf(Gen_Psf(coefficients, psf_basis, x_cord, y_cord))\n",
    "\n",
    "        if psf_output_dir:\n",
    "            psf_path = os.path.join(psf_output_dir, f'{x_cord}_{y_cord}.fit')\n",
    "            fits.PrimaryHDU(psf).writeto(psf_path, overwrite=True)\n",
    "\n",
    "        psf = psf * np.random.uniform(*brightness_range)\n",
    "\n",
    "        # Crop the stamp to the part that overlaps the image so that its centre\n",
    "        # pixel (shape // 2) lands on (x_cord, y_cord).\n",
    "        center_x, center_y = psf.shape[0] // 2, psf.shape[1] // 2\n",
    "        stamp = psf[max(0, center_x - x_cord):min(psf.shape[0], image.shape[0] - x_cord + center_x),\n",
    "                    max(0, center_y - y_cord):min(psf.shape[1], image.shape[1] - y_cord + center_y)]\n",
    "        if stamp.size == 0:\n",
    "            continue  # star lies entirely outside the image\n",
    "\n",
    "        start_x = max(0, x_cord - center_x)\n",
    "        start_y = max(0, y_cord - center_y)\n",
    "        # Accumulate so overlapping stars add up instead of overwriting each other.\n",
    "        image[start_x:start_x + stamp.shape[0], start_y:start_y + stamp.shape[1]] += stamp\n",
    "\n",
    "    return image\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_gen_psf(output_folder_path, coefficients_dict, psf_basis, img_size):\n",
    "    \"\"\"\n",
    "    Render and save one simulated star image per entry of coefficients_dict.\n",
    "\n",
    "    For each entry, star positions come from the first two columns of the first\n",
    "    listed .xlsx file. Uniform noise in [20, 40) is added, and the image is\n",
    "    written to <output_folder_path>/<YYYYMMDD>/<xlsx file stem>.fit.\n",
    "\n",
    "    Parameters:\n",
    "    - output_folder_path: root folder for the simulated images.\n",
    "    - coefficients_dict: maps a folder name to (coefficients, xlsx_paths).\n",
    "    - psf_basis: basis PSFs passed through to generate_psf_image.\n",
    "    - img_size: shape of each simulated image.\n",
    "\n",
    "    Raises:\n",
    "    - FileNotFoundError: if an entry lists no .xlsx file.\n",
    "    \"\"\"\n",
    "    # One dated output folder per run, computed once.\n",
    "    today = datetime.date.today().strftime('%Y%m%d')\n",
    "    new_folder_path = os.path.join(output_folder_path, today)\n",
    "\n",
    "    for folder_name, (coefficients, xlsx_paths) in coefficients_dict.items():\n",
    "        if not xlsx_paths:\n",
    "            raise FileNotFoundError(f'No .xlsx coordinate file found for {folder_name!r}')\n",
    "        xlsx_path = xlsx_paths[0]\n",
    "        df = pd.read_excel(xlsx_path)\n",
    "        file_name = os.path.splitext(os.path.basename(xlsx_path))[0]\n",
    "\n",
    "        # The first and second columns are the star coordinates.\n",
    "        coordinates = list(zip(df.iloc[:, 0], df.iloc[:, 1]))\n",
    "\n",
    "        t0 = time.perf_counter()\n",
    "        psf_img = generate_psf_image(coordinates, coefficients, psf_basis, img_size)\n",
    "        psf_img += np.random.uniform(20, 40, psf_img.shape)\n",
    "        t1 = time.perf_counter()\n",
    "        print('Time taken to generate a single simulated image: %s seconds' % (t1 - t0))\n",
    "\n",
    "        os.makedirs(new_folder_path, exist_ok=True)\n",
    "        output_file_path = os.path.join(new_folder_path, file_name + '.fit')\n",
    "        fits.PrimaryHDU(psf_img).writeto(output_file_path, overwrite=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main():\n",
    "    \"\"\"\n",
    "    Run the simulation pipeline: load the grid PSFs, load and PCA-reduce the\n",
    "    basis samples, fit per-position coefficients, and render one simulated\n",
    "    image per grid folder.\n",
    "    \"\"\"\n",
    "    # ---------- Paths and settings (GWAC data folders) ----------\n",
    "    psf_basis_sample_folder_path = 'psf_basic_sample'\n",
    "    base_folder_path = 'gird_data'  # spelling matches the folder on disk\n",
    "    sim_img_output_path = 'output_folder'\n",
    "    img_size = (4136, 4196)\n",
    "    n_components = 100\n",
    "\n",
    "    # ---------- Load Grid Data ----------\n",
    "    print('Loading PSF Data...')\n",
    "    psf_data_dict = get_fits_data_from_folders(base_folder_path)\n",
    "    # The dict has one entry per folder, so count the PSFs inside them.\n",
    "    n_psf = sum(len(psf_list) for psf_list, _ in psf_data_dict.values())\n",
    "    print(f'Loaded {n_psf} PSF images from {len(psf_data_dict)} folders.')\n",
    "\n",
    "    # ---------- Process Data ----------\n",
    "    print('Processing PSF Data...')\n",
    "    psf_grid_data_m, _coordinates = data_process_mul(psf_data_dict)\n",
    "\n",
    "    # ---------- Load Basis Sample Data ----------\n",
    "    print('Loading PSF Basis Samples...')\n",
    "    psf_basis_sample = get_PsfBasicSample(psf_basis_sample_folder_path)\n",
    "\n",
    "    # ---------- PCA Dimensionality Reduction ----------\n",
    "    # Optional diagnostic: pca_if(psf_basis_sample) plots the explained variance.\n",
    "    print('Performing PCA...')\n",
    "    psf_basis = PCA_(psf_basis_sample, n_components)\n",
    "\n",
    "    # ---------- Compute Coefficients ----------\n",
    "    print('Computing Coefficients...')\n",
    "    coeff_matrix_dict = {}\n",
    "    for foldername, (psf_data, xlsx_path) in psf_grid_data_m.items():\n",
    "        coefficients = coeff_matrix(img_size, psf_data, psf_basis)\n",
    "        coeff_matrix_dict[foldername] = coefficients, xlsx_path\n",
    "\n",
    "    # ---------- Generate Simulated Figures ----------\n",
    "    print('Generating Simulated PSFs...')\n",
    "    process_gen_psf(sim_img_output_path, coeff_matrix_dict, psf_basis, img_size)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading PSF Data...\n",
      "Processing folder: gird_data\\SN2023ixf_ML_200115_G044\n",
      "Processing folder: gird_data\\SN2023ixf_ML_200117_G024\n",
      "Processing folder: gird_data\\SN2023ixf_ML_200123_G044\n",
      "Processing folder: gird_data\\SN2023ixf_ML_200125_G024\n",
      "Loaded 4 PSF images.\n",
      "Processing PSF Data...\n",
      "Loading PSF Basis Samples...\n",
      "Performing PCA...\n",
      "Computing Coefficients...\n",
      "0.9825328150815279\n",
      "199.14730788791672\n",
      "torch.Size([1, 100, 14, 14])\n",
      "torch.Size([1, 100, 4136, 4196])\n",
      "0.9775865927596932\n",
      "199.49113196979056\n",
      "torch.Size([1, 100, 14, 14])\n",
      "torch.Size([1, 100, 4136, 4196])\n",
      "1.0088870186669263\n",
      "122.64755859982456\n",
      "torch.Size([1, 100, 11, 11])\n",
      "torch.Size([1, 100, 4136, 4196])\n",
      "1.027526895663851\n",
      "81.68010544353821\n",
      "torch.Size([1, 100, 9, 9])\n",
      "torch.Size([1, 100, 4136, 4196])\n",
      "Generating Simulated PSFs...\n",
      "Time taken to generate a single simulated image: 1.7903130054473877 seconds\n",
      "Time taken to generate a single simulated image: 1.857999563217163 seconds\n",
      "Time taken to generate a single simulated image: 2.0200347900390625 seconds\n",
      "Time taken to generate a single simulated image: 1.7674732208251953 seconds\n"
     ]
    }
   ],
   "source": [
    "# In a Jupyter kernel __name__ is \"__main__\", so executing this cell runs the whole pipeline.\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
