import argparse
import copy
import os
from datetime import datetime
from os import listdir
from random import sample
import numpy as np
import pandas as pd
import scipy.ndimage as ndi
import scipy.ndimage as ndimage
import torch
import torch.nn as nn
import torch.optim as optim
from astropy.io import fits
from scipy.ndimage import label
from torch.utils.data import DataLoader
from tqdm import tqdm
from data_generators.data_loader import SegmentationDataSet
from medzoo_imports import create_model
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os
from typing import Iterable, Union
import time


def mc_inference(model, input_tensor, num_samples=10):
    """Estimate predictive mean and variance with Monte Carlo dropout.

    The model is put in training mode so that dropout stays stochastic, run
    ``num_samples`` times on the same input with gradient tracking disabled,
    and the element-wise mean and variance over those passes are returned.
    The model's previous train/eval mode is restored before returning.

    Args:
        model: ``torch.nn.Module`` whose dropout layers have already been set
            up for MC-Dropout. Its output is indexed with ``[0]``; it is
            presumably a tuple/list whose first element is the prediction
            tensor -- TODO confirm against the model definition.
        input_tensor: Input batch, passed unchanged to the model.
        num_samples: Number of stochastic forward passes; must be >= 1.

    Returns:
        A ``(mean, var)`` tuple of tensors, each shaped like a single sampled
        output. ``var`` uses torch's default (unbiased) estimator, so it is
        NaN when ``num_samples`` is 1.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    was_training = model.training
    # NOTE(review): train() also switches normalisation layers (e.g. BatchNorm)
    # to batch statistics and lets them update their running stats; confirm
    # this is intended for the architecture in use.
    model.train()

    samples = []
    try:
        with torch.no_grad():
            for _ in range(num_samples):
                # Call the module itself (not .forward) so hooks and wrappers run.
                out = model(input_tensor)[0]
                if isinstance(out, tuple):
                    out = out[0]
                    # NOTE(review): sigmoid is only applied on this nested-tuple
                    # path, exactly as in the original code; confirm whether it
                    # was meant to apply to every output.
                    out = torch.sigmoid(out)
                samples.append(out)
    finally:
        # Do not leave the caller's model in a different mode than we found it.
        model.train(was_training)

    stacked = torch.stack(samples, dim=0)  # (num_samples, *out.shape)
    return stacked.mean(dim=0), stacked.var(dim=0)

    
    
def main(
    input_dir,
    out_fits_dir,
    batch_size,
    dims,
    checkpointdir,
    jobid,
    shuffle=False,
    num_workers=0,
):
    """Run MC-Dropout inference on every FITS file in a directory.

    Loads the model and optimizer state from a checkpoint, builds a
    ``SegmentationDataSet`` over the ``.fits`` files found in ``input_dir``,
    and for every sample writes two FITS files into ``out_fits_dir``:
    ``<name>_mean.fits`` and ``<name>_var.fits``, holding channel 1 of the
    MC-Dropout mean and variance returned by ``mc_inference``.

    Args:
        input_dir: Directory scanned for file names containing ``.fits``.
        out_fits_dir: Output directory; created if it does not exist.
        batch_size: Batch size given to the ``DataLoader``.
        dims: Passed to ``SegmentationDataSet`` as both ``dims`` and
            ``arr_shape``.
        checkpointdir: Path of the checkpoint file; it must contain the keys
            ``model_state_dict`` and ``optimizer_state_dict``.
        jobid: Prefix used in the dataset's ``save_name`` path.
        shuffle: Stored in ``params`` only; the loader always uses
            ``shuffle=False`` so outputs stay aligned with input order.
        num_workers: Number of ``DataLoader`` worker processes.

    Returns:
        None.
    """
    now = datetime.now()  # current date and time
    date_str = now.strftime("%d%m%Y_%H%M%S")
    params = {"batch_size": batch_size, "shuffle": shuffle, "num_workers": num_workers}
    # NOTE(review): `args` is the module-level namespace created in the
    # `__main__` block; calling main() from another module raises NameError.
    model, optimizer = create_model(args)
    checkpoint = torch.load(checkpointdir)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    if os.path.exists(out_fits_dir):
        print("输出目录已存在:", out_fits_dir)
    else:
        os.makedirs(out_fits_dir)
        print("创建输出目录:", out_fits_dir)

    device_ids = [0]
    main_device = torch.device(f"cuda:{device_ids[0]}")
    model = model.to(main_device)

    # NOTE(review): hard-coded, machine-specific location handed to the
    # dataset as `save_name`.
    save = "/home/dell461/cl/sdc2/HISourceFinder-master-l/%s_saved_models_%s/" % (
        jobid,
        date_str,
    )
    inputs = [
        os.path.join(input_dir, x)
        for x in sorted(os.listdir(input_dir))
        if ".fits" in x
    ]
    dataset_full = SegmentationDataSet(
        inputs=inputs,
        dims=dims,
        overlaps=[0, 0, 0],
        arr_shape=dims,
        mode="full",
        save_name=save,
    )
    # The test set is an independent copy of the full dataset.
    dataset_test = copy.deepcopy(dataset_full)
    dataset_test.list = copy.deepcopy(dataset_full.list)
    print(len(dataset_full.list))
    print("Test dataset size: ", len(dataset_test.list))
    dataloader_test = DataLoader(
        dataset=dataset_test,
        batch_size=params["batch_size"],
        shuffle=False,
        num_workers=params["num_workers"],
    )
    print(len(dataset_test.list))
    model.train()  # keep dropout active for MC sampling
    with torch.no_grad():
        for batch_idx, input_tuple in enumerate(dataloader_test):
            print(dataset_test.list[batch_idx])
            input_tensor, cube_files = input_tuple

            input_tensor = input_tensor.to(main_device)
            start_time = time.time()
            out_cube, var_out = mc_inference(model, input_tensor, num_samples=10)
            end_time = time.time()
            print(f"Inference time: {end_time - start_time:.4f} seconds")

            # Iterate over the samples actually present: the final batch can
            # be smaller than the configured batch size, and indexing up to
            # params["batch_size"] would raise IndexError there.
            for i in range(input_tensor.shape[0]):
                # Channel 1 of the per-sample variance / mean volumes.
                var_np = var_out[i].cpu().squeeze()[1].numpy()
                out_np = out_cube[i].cpu().squeeze()[1].numpy()
                print(out_np.max())
                print(batch_idx)
                # Move the third array axis to the front before writing.
                var_np = np.moveaxis(var_np, 2, 0)
                mean_np = np.moveaxis(out_np, 2, 0)
                file_paths = cube_files[0][i]
                file_name = os.path.basename(file_paths)
                out_mean_fits_path = file_name.replace(".fits", "_mean.fits")
                out_var_fits_path = file_name.replace(".fits", "_var.fits")
                print("Writing to: ", out_mean_fits_path)
                fits.writeto(
                    os.path.join(out_fits_dir, out_mean_fits_path),
                    mean_np.astype(np.float32),
                    overwrite=True,
                )
                fits.writeto(
                    os.path.join(out_fits_dir, out_var_fits_path),
                    var_np.astype(np.float32),
                    overwrite=True,
                )

    return

def build_arg_parser():
    """Build the command-line parser for the MC-Dropout inference script.

    Returns:
        An ``argparse.ArgumentParser`` defining ``--input_dir``,
        ``--output_dir``, ``--batch_size``, ``--num_workers``, ``--dims``,
        ``--opt``, ``--dataset_name``, ``--checkponts_dir`` and ``--jobid``.
    """
    parser = argparse.ArgumentParser(
        description="Testmodel",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # String options keep nargs="?" / const="default" so that a bare flag
    # still parses to the literal "default", as existing invocations may do.
    parser.add_argument(
        "--input_dir",
        type=str,
        nargs="?",
        const="default",
        default="",
        help="The input directory of the data",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        nargs="?",
        const="default",
        default="./data/output/",
        help="The output directory for the results",
    )
    # Integer options take exactly one value; a bare flag was already an
    # error because the string const could not be converted to int.
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="The number of workers to use",
    )
    # type=list would split a value such as "128" into ['1', '2', '8'];
    # read three separate integers instead.
    parser.add_argument(
        "--dims",
        type=int,
        nargs=3,
        metavar="DIM",
        default=[128, 128, 128],
        help="The dimensions of the subcubes",
    )
    parser.add_argument(
        "--opt",
        type=str,
        nargs="?",
        const="default",
        default="adam",
        help="The type of optimizer",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        nargs="?",
        const="default",
        default="hi_source",
        help="The name of the dataset",
    )
    # The option name keeps its original spelling so existing command lines
    # and `args.checkponts_dir` keep working.
    parser.add_argument(
        "--checkponts_dir",
        type=str,
        nargs="?",
        const="default",
        default="",
        help="权重路径",
    )
    parser.add_argument(
        "--jobid",
        type=str,
        nargs="?",
        const="default",
        default="unetplusplus",
        help="The job ID",
    )
    return parser


if __name__ == "__main__":
    parser = build_arg_parser()
    # `args` must remain a module-level name: main() passes it to create_model().
    args = parser.parse_args()

    # Fail early with a usage message instead of a FileNotFoundError on "".
    if not args.input_dir:
        parser.error("--input_dir is required")
    if not args.checkponts_dir:
        parser.error("--checkponts_dir is required")

    main(
        args.input_dir,
        args.output_dir,
        args.batch_size,
        args.dims,
        args.checkponts_dir,
        args.jobid,
    )

