from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.transforms import Activations, AsDiscrete, Compose
from monai.utils.enums import MetricReduction
import torch
from monai.networks.nets import DynUNet
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import os
from abc import ABC, abstractmethod

# model = SwinUNETR(
#         in_channels=1,
#         out_channels=2,
#         feature_size=48,
#         spatial_dims=3,
#     )
# Reference 3D DynUNet configuration, instantiated once at module import time.
model = DynUNet(
    spatial_dims=3,  # 3D convolutions
    in_channels=1,  # input channels, e.g. 1 for grayscale CT volumes
    out_channels=2,  # output channels, e.g. 2 classes (foreground/background)
    kernel_size=[3, 3, 3, 3, 3],  # conv kernel size per stage (may differ per stage)
    strides=[1, 2, 2, 2, 2],  # first stage keeps resolution, each later stage halves it
    upsample_kernel_size=[2, 2, 2, 2],  # upsampling stride per stage, normally equal to strides[1:]
    filters=[32, 64, 128, 256, 512],  # channels per stage, growing with depth
    dropout=0.1,  # dropout rate (optional)
    norm_name=("INSTANCE", {"affine": True}),  # instance normalization
    act_name=("leakyrelu", {"inplace": True, "negative_slope": 0.01}),  # activation function
    deep_supervision=True,  # deep supervision, helps intermediate layers learn
    deep_supr_num=1,  # number of deep-supervision outputs
    res_block=True,  # residual conv blocks (recommended)
    trans_bias=False,  # bias of the transposed convolutions, usually False
)

class Dynunet(DynUNet):
    r"""
    3D DynUNet with checkpointing, inference and self-test helpers.

    Wraps MONAI's ``DynUNet`` with a fixed five-stage layout. Every stage uses
    kernel size 3, the strides are ``[1, 2, 2, 2, 2]``, the filters are
    ``[32, 64, 128, 256, 512]`` and dropout is 0.1. All other ``DynUNet``
    arguments keep their library defaults.

    Args:
        in_channels (int): Number of input image channels.
        classes (int): Number of output channels (segmentation classes).
    """

    def __init__(self, in_channels=1, classes=2):
        super().__init__(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=classes,
            kernel_size=[3, 3, 3, 3, 3],  # conv kernel size per stage
            strides=[1, 2, 2, 2, 2],  # first stage keeps resolution, later stages halve it
            upsample_kernel_size=[2, 2, 2, 2],  # upsampling stride, normally equal to strides[1:]
            filters=[32, 64, 128, 256, 512],  # channels per stage, growing with depth
            dropout=0.1,
        )
        # Stored explicitly because test() reads both attributes.
        self.in_channels = in_channels
        self.classes = classes
        # Lowest loss seen by save_checkpoint(); any finite loss improves on it.
        self.best_loss = float("inf")

    def forward(self, x):
        """Run the parent ``DynUNet`` forward pass unchanged."""
        return super().forward(x)

    @property
    def device(self):
        """torch.device: Device holding the model's first parameter."""
        return next(self.parameters()).device

    def restore_checkpoint(self, ckpt_file, optimizer=None):
        r"""
        Restores model weights, and optionally optimizer state, from a pth file.

        Args:
            ckpt_file (str): Path to a checkpoint written by ``save_checkpoint``.
                The file holds a dict with ``'model_state_dict'``,
                ``'optimizer_state_dict'`` and ``'epoch'`` keys.
            optimizer (Optimizer): Optional optimizer whose state is restored
                in place. It is skipped if the checkpoint holds no optimizer state.

        Returns:
            int: The epoch stored in the checkpoint.

        Raises:
            ValueError: If ``ckpt_file`` is empty.
        """
        if not ckpt_file:
            raise ValueError("No checkpoint file to be restored.")

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        try:
            ckpt_dict = torch.load(ckpt_file)
        except RuntimeError:
            # Loading fails e.g. when tensors were saved on a device missing here.
            # Retry with every storage mapped to CPU.
            ckpt_dict = torch.load(ckpt_file, map_location=lambda storage, loc: storage)

        self.load_state_dict(ckpt_dict['model_state_dict'])

        # Optimizer state is only needed to resume training, not for evaluation.
        # save_checkpoint() stores None when no optimizer was passed, and
        # load_state_dict(None) would crash, so skip that case.
        optimizer_state = ckpt_dict.get('optimizer_state_dict')
        if optimizer is not None and optimizer_state is not None:
            optimizer.load_state_dict(optimizer_state)

        return ckpt_dict['epoch']

    def save_checkpoint(self,
                        directory,
                        epoch, loss,
                        optimizer=None,
                        name=None):
        r"""
        Saves a checkpoint and, on a new best loss, an extra "best" copy.

        The checkpoint is always written to ``directory/name``. When ``loss``
        is strictly lower than every loss previously passed to this method,
        ``self.best_loss`` is updated. The same checkpoint is then also written
        to ``directory/<dirname>_BEST.pth``, where ``<dirname>`` is the last
        path component of ``directory``.

        Args:
            directory (str): Directory to save into; created if missing.
            epoch (int): The training epoch stored in the checkpoint.
            loss (float): Loss used to decide whether this is the best checkpoint.
            optimizer (Optimizer): Optional optimizer whose state is saved too.
            name (str): File name for the regular checkpoint. Defaults to
                ``<dirname>_last_epoch.pth``.

        Returns:
            str: Name of the last file written. This is the ``_BEST`` file name
            if this checkpoint set a new best loss, otherwise ``name``.
        """
        os.makedirs(directory, exist_ok=True)

        ckpt_dict = {
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict() if optimizer is not None else None,
            'epoch': epoch,
        }

        # normpath strips a trailing separator. Without it, basename() returns ''
        # and the files would be named '_last_epoch.pth' / '_BEST.pth'.
        prefix = os.path.basename(os.path.normpath(directory))

        if name is None:
            name = "{}_{}_epoch.pth".format(prefix, 'last')
        torch.save(ckpt_dict, os.path.join(directory, name))

        if self.best_loss > loss:
            self.best_loss = loss
            name = "{}_BEST.pth".format(prefix)
            torch.save(ckpt_dict, os.path.join(directory, name))
        return name

    def count_params(self):
        r"""
        Computes the number of parameters in this model.

        Returns:
            int: Total number of weight parameters for this model.
            int: Total number of trainable parameters for this model.
        """
        num_total_params = sum(p.numel() for p in self.parameters())
        num_trainable_params = sum(p.numel() for p in self.parameters()
                                   if p.requires_grad)

        return num_total_params, num_trainable_params

    def inference(self, input_tensor):
        r"""
        Runs a gradient-free forward pass and returns the prediction on CPU.

        The model is switched to eval mode for the pass. Afterwards it is put
        back into whichever train/eval mode it was in before the call, so a
        surrounding training loop is not left with dropout disabled.

        Args:
            input_tensor (torch.Tensor): Input batch; moved to the model's device.

        Returns:
            torch.Tensor: The network output on CPU. If the forward pass returns
            a tuple, only its first element is returned.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self.forward(input_tensor.to(self.device))
        finally:
            self.train(was_training)
        if isinstance(output, tuple):
            output = output[0]
        # Produced under no_grad, so no detach() is needed.
        return output.cpu()

    def test(self, device='cpu'):
        r"""
        Smoke-tests the network on a random 32^3 volume and prints a summary.

        Args:
            device (str): Device to run the check on; the model is moved there.

        Raises:
            AssertionError: If the output shape is not ``(1, classes, 32, 32, 32)``.
        """
        target = torch.device(device)
        # Move the model first so the input and the weights share a device.
        self.to(target)
        input_tensor = torch.rand(1, self.in_channels, 32, 32, 32, device=target)
        with torch.no_grad():
            out = self.forward(input_tensor)
        expected_shape = (1, self.classes, 32, 32, 32)
        assert tuple(out.shape) == expected_shape, (tuple(out.shape), expected_shape)
        summary(self, (self.in_channels, 32, 32, 32), device=device)
        print("Dynunet test is complete")