import sys
sys.path.insert(0,"../")
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.transforms import Activations, AsDiscrete, Compose
from monai.utils.enums import MetricReduction
from monai.networks.nets import BasicUNetPlusPlus
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import os
from abc import ABC, abstractmethod

# model = SwinUNETR(
#         in_channels=1,
#         out_channels=2,
#         feature_size=48,
#         spatial_dims=3,
#     )


class Basicunetplusplus(BasicUNetPlusPlus):
    r"""
    ``BasicUNetPlusPlus`` subclass that adds checkpointing and small utilities.

    On top of the parent network this class provides checkpoint save/restore,
    parameter counting, a no-grad inference helper and a smoke test.

    Attributes:
        in_channels (int): Number of input channels given at construction.
        classes (int): Number of output channels given at construction.
        best_loss (float): Lowest loss passed to ``save_checkpoint`` so far.
    """

    def __init__(self, in_channels=1, classes=2, dropout=0, dropout_p=0.0):
        r"""
        Builds the network.

        Args:
            in_channels (int): Forwarded to the parent as ``in_channels``.
            classes (int): Forwarded to the parent as ``out_channels``.
            dropout: Forwarded to the parent as ``dropout``.
            dropout_p (float): Forwarded to the parent as ``dropout_p``.
        """
        # NOTE(review): the parent class must accept a ``dropout_p`` keyword
        # for this call to succeed -- confirm against the MONAI version (or
        # local copy) that is actually imported.
        super(Basicunetplusplus, self).__init__(
            in_channels=in_channels,
            out_channels=classes,
            dropout_p=dropout_p,
            dropout=dropout,
        )
        # test() reads these two attributes, so they must exist on the instance.
        self.in_channels = in_channels
        self.classes = classes
        # Start at +inf so the first checkpoint always counts as the best one.
        self.best_loss = float("inf")

    def forward(self, x):
        r"""
        Runs the parent class's forward pass on ``x`` and returns its result
        unchanged.
        """
        return super(Basicunetplusplus, self).forward(x)

    @property
    def device(self):
        r"""
        Device of the first parameter yielded by ``self.parameters()``.
        """
        return next(self.parameters()).device

    def restore_checkpoint(self, ckpt_file, optimizer=None):
        r"""
        Restores checkpoint from a pth file and restores optimizer state.

        Args:
            ckpt_file (str): A PyTorch pth file containing model weights.
            optimizer (Optimizer): A vanilla optimizer to have its state restored from.

        Returns:
            int: The ``'epoch'`` value stored in the checkpoint.

        Raises:
            ValueError: If ``ckpt_file`` is empty, or if an optimizer is given
                but the checkpoint holds no optimizer state.
        """
        if not ckpt_file:
            raise ValueError("No checkpoint file to be restored.")

        try:
            ckpt_dict = torch.load(ckpt_file)
        except RuntimeError:
            # Retry with every storage mapped to the CPU (presumably for
            # checkpoints saved on a device that is unavailable here).
            ckpt_dict = torch.load(ckpt_file, map_location=torch.device("cpu"))

        # Restore model weights
        self.load_state_dict(ckpt_dict['model_state_dict'])

        # Restore optimizer status if requested. Evaluation doesn't need this.
        if optimizer is not None:
            optimizer_state = ckpt_dict.get('optimizer_state_dict')
            if optimizer_state is None:
                # save_checkpoint() stores None when it is called without an
                # optimizer; fail with a clear message instead of passing
                # None to optimizer.load_state_dict().
                raise ValueError(
                    "Checkpoint contains no optimizer state to be restored.")
            optimizer.load_state_dict(optimizer_state)

        return ckpt_dict['epoch']

    def save_checkpoint(self,
                        directory,
                        epoch, loss,
                        optimizer=None,
                        name=None):
        r"""
        Saves a checkpoint and, when ``loss`` improves on ``self.best_loss``,
        an additional ``<dirname>_BEST.pth`` copy. Optimizer state is saved
        together with the model weights.

        Args:
            directory (str): Path to save checkpoint to.
            epoch (int): The training epoch.
            loss: Loss used to decide whether this checkpoint is the best so
                far; must be convertible with ``float()``.
            optimizer (Optimizer): Optimizer state to be saved concurrently.
            name (str): The name to save the checkpoint file as. Defaults to
                ``<dirname>_last_epoch.pth``.

        Returns:
            str: Name of the last file written (the BEST file when the loss
            improved, otherwise the regular checkpoint file).
        """
        # Create directory to save to; exist_ok avoids a check-then-create race.
        os.makedirs(directory, exist_ok=True)

        # Build checkpoint dict to save.
        ckpt_dict = {
            'model_state_dict':
                self.state_dict(),
            'optimizer_state_dict':
                optimizer.state_dict() if optimizer is not None else None,
            'epoch':
                epoch
        }

        # normpath drops a trailing separator so that "runs/netG/" still
        # yields the prefix "netG" instead of an empty string.
        prefix = os.path.basename(os.path.normpath(directory))

        # Save the file with specific name
        if name is None:
            name = "{}_{}_epoch.pth".format(prefix, 'last')
        torch.save(ckpt_dict, os.path.join(directory, name))

        # Keep a plain float so no tensor (and its autograd graph) is retained.
        loss_value = float(loss)
        if self.best_loss > loss_value:
            self.best_loss = loss_value
            name = "{}_BEST.pth".format(prefix)
            torch.save(ckpt_dict, os.path.join(directory, name))
        return name

    def count_params(self):
        r"""
        Computes the number of parameters in this model.

        Args: None

        Returns:
            int: Total number of weight parameters for this model.
            int: Total number of trainable parameters for this model.

        """
        num_total_params = sum(p.numel() for p in self.parameters())
        num_trainable_params = sum(p.numel() for p in self.parameters()
                                   if p.requires_grad)

        return num_total_params, num_trainable_params

    def inference(self, input_tensor):
        r"""
        Runs a forward pass without gradient tracking.

        The module's train/eval mode is left untouched. The first element of
        the forward result is returned; if that element is itself a tuple,
        its first element is returned instead.
        """
        with torch.no_grad():
            output = self.forward(input_tensor)[0]
            if isinstance(output, tuple):
                output = output[0]
            return output

    def test(self, device='cpu'):
        r"""
        Smoke test: pushes a random ``32x32x32`` volume through the network,
        checks the output shape and prints a layer summary.

        Args:
            device (str): Device to move the model to and run the test on.

        Raises:
            AssertionError: If the output shape is not
                ``(1, classes, 32, 32, 32)``.
        """
        target = torch.device(device)
        # Move the model first so the input and the weights share one device.
        net = self.to(target)
        input_tensor = torch.rand(1, self.in_channels, 32, 32, 32, device=target)
        expected_shape = (1, self.classes, 32, 32, 32)

        with torch.no_grad():
            out = self.forward(input_tensor)
        # inference() indexes the forward result with [0], i.e. forward is
        # expected to return a sequence; unwrap it before the shape check.
        if isinstance(out, (list, tuple)):
            out = out[0]
        if tuple(out.shape) != expected_shape:
            raise AssertionError(
                "Unexpected output shape {}, expected {}".format(
                    tuple(out.shape), expected_shape))

        summary(net, (self.in_channels, 32, 32, 32), device=device)
        # NOTE(review): the message says "Vnet" although this class wraps
        # UNet++; the text is kept unchanged.
        print("Vnet test is complete")
        





import torch
from torchviz import make_dot
# from your_module import BasicUNetPlusPlus  # replace with your module's import path

if __name__ == "__main__":
    # Script entry point: render the autograd graph of a UNet++ forward pass
    # to ./unetpp_graph.png using torchviz.

    # 1) Build the network and switch it to eval mode.
    net = BasicUNetPlusPlus(spatial_dims=3, in_channels=1, out_channels=2)
    net.eval()

    # 2) Create a dummy input matching what forward expects (the main input).
    dummy_volume = torch.randn(1, 1, 128, 128, 128)
    # If the model's forward returns [output], outputs[0] is the Tensor.
    outputs = net(dummy_volume)

    # 3) Build the computation graph from the first output.
    named_params = dict(net.named_parameters())
    graph = make_dot(outputs[0], params=named_params)

    # 4) Save as PNG in the current directory; cleanup=True is passed to render.
    graph.directory = './'
    graph.format = 'png'
    graph.render('unetpp_graph', cleanup=True)
    print("已生成：unetpp_graph.png")