import os
import sys

import torch
from torchsummary import summary
from torchviz import make_dot

from monai.networks.nets.basic_unetplusplus_modified import (
    BasicUNetPlusPlusKernelModified,
)
# from monai.networks.nets.basic_unetplusplus_modified_trans import BasicUNetPlusPlusTrans

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from monai.networks.nets.basic_unetplusplus import BasicUNetPlusPlus

# model = SwinUNETR(
#         in_channels=1,
#         out_channels=2,
#         feature_size=48,
#         spatial_dims=3,
#     )


# class Basicunetplusplus(BasicUNetPlusPlus):
# class Basicunetplusplus(BasicUNetPlusPlusKernelModified):
class Basicunetplusplus(BasicUNetPlusPlus):
    r"""
    Wrapper around ``BasicUNetPlusPlus`` that adds checkpoint save/restore,
    parameter counting, a no-grad inference helper and a quick self-test.

    Args:
        in_channels (int): Number of input channels.
        classes (int): Number of output channels (passed as ``out_channels``).
        dropout (float): Dropout probability forwarded to the base network.
    """

    def __init__(self, in_channels=1, classes=2, dropout=0.0):
        super().__init__(
            in_channels=in_channels,
            out_channels=classes,
            dropout=dropout,
        )
        # Lowest loss seen by save_checkpoint. Starting at +inf guarantees the
        # first save always writes a *_BEST.pth file, whatever the loss scale.
        self.best_loss = float("inf")
        self.classes = classes
        self.in_channels = in_channels

    def forward(self, x):
        """Run the base ``BasicUNetPlusPlus`` forward pass unchanged.

        NOTE(review): callers in this file index the result with ``[0]``, so the
        base class presumably returns a list of output tensors — confirm.
        """
        return super().forward(x)

    @property
    def device(self):
        """Device of the first registered parameter (raises StopIteration if none)."""
        return next(self.parameters()).device

    def restore_checkpoint(self, ckpt_file, optimizer=None):
        r"""
        Restores model weights, and optionally optimizer state, from a pth file.

        The optimizer is updated in place via ``load_state_dict``, so it is not
        returned.

        Args:
            ckpt_file (str): Path to a checkpoint written by ``save_checkpoint``.
            optimizer (Optimizer, optional): Optimizer whose state is restored
                from the checkpoint's ``"optimizer_state_dict"`` entry.

        Returns:
            int: The ``"epoch"`` value stored in the checkpoint.

        Raises:
            ValueError: If ``ckpt_file`` is empty or None.
        """
        if not ckpt_file:
            raise ValueError("No checkpoint file to be restored.")

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        try:
            ckpt_dict = torch.load(ckpt_file)
        except RuntimeError:
            # Retry with every storage mapped to CPU, e.g. when the file holds
            # CUDA tensors but CUDA is not available in this process.
            ckpt_dict = torch.load(ckpt_file, map_location=lambda storage, loc: storage)

        # Restore model weights.
        self.load_state_dict(ckpt_dict["model_state_dict"])

        # save_checkpoint stores None here when it was called without an
        # optimizer; handing None to load_state_dict would fail, so skip it.
        optimizer_state = ckpt_dict.get("optimizer_state_dict")
        if optimizer is not None:
            if optimizer_state is None:
                import warnings

                warnings.warn(
                    "Checkpoint {} has no optimizer state; optimizer left unchanged.".format(
                        ckpt_file
                    )
                )
            else:
                optimizer.load_state_dict(optimizer_state)

        return ckpt_dict["epoch"]

    def save_checkpoint(self, directory, epoch, loss, optimizer=None, name=None):
        r"""
        Saves a checkpoint, plus a ``*_BEST.pth`` copy when ``loss`` improves.

        Args:
            directory (str): Directory to save checkpoints to (created if missing).
            epoch (int): The training epoch stored in the checkpoint.
            loss (float): Loss used to decide whether this is the best model so far.
            optimizer (Optimizer, optional): Optimizer whose state is saved too.
            name (str, optional): File name for the regular checkpoint. Defaults
                to ``"<dirname>_last_epoch.pth"``.

        Returns:
            str: File name of the last file written. This is the ``*_BEST.pth``
            name when this call improved ``best_loss``, otherwise the regular
            checkpoint name.
        """
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(directory, exist_ok=True)

        ckpt_dict = {
            "model_state_dict": self.state_dict(),
            "optimizer_state_dict": (
                optimizer.state_dict() if optimizer is not None else None
            ),
            "epoch": epoch,
        }

        # normpath strips a trailing separator so "ckpts/" yields "ckpts", not "".
        prefix = os.path.basename(os.path.normpath(directory))
        if name is None:
            name = "{}_last_epoch.pth".format(prefix)

        torch.save(ckpt_dict, os.path.join(directory, name))

        # float() turns a scalar tensor loss into a plain number so best_loss
        # never holds a reference to a tensor (or its autograd graph).
        loss = float(loss)
        if loss < self.best_loss:
            self.best_loss = loss
            name = "{}_BEST.pth".format(prefix)
            torch.save(ckpt_dict, os.path.join(directory, name))
        return name

    def count_params(self):
        r"""
        Computes the number of parameters in this model.

        Returns:
            int: Total number of weight parameters for this model.
            int: Total number of trainable (``requires_grad``) parameters.
        """
        num_total_params = sum(p.numel() for p in self.parameters())
        num_trainable_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        return num_total_params, num_trainable_params

    def inference(self, input_tensor):
        """Run a forward pass without gradient tracking and return the first output.

        The model is NOT switched to eval mode here; call ``self.eval()`` first
        if dropout/normalisation layers should run in inference mode.
        """
        with torch.no_grad():
            output = self.forward(input_tensor)[0]
            # Unwrap one extra level in case the first output is itself a tuple.
            if isinstance(output, tuple):
                output = output[0]
            return output

    def test(self, device="cpu"):
        """Smoke test: check the output shape on a 32^3 dummy input, then print a summary.

        Args:
            device (str): ``"cpu"`` or ``"cuda"``; the model is moved there first
                so the forward pass and the summary run on the same device.

        Raises:
            AssertionError: If the first output does not have shape
                ``(1, classes, 32, 32, 32)``.
        """
        target = torch.device(device)
        # Move the model before the forward pass so model and input share a device.
        self.to(target)
        input_tensor = torch.rand(1, self.in_channels, 32, 32, 32, device=target)
        expected_shape = (1, self.classes, 32, 32, 32)
        with torch.no_grad():
            out = self.forward(input_tensor)[0]
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if tuple(out.shape) != expected_shape:
            raise AssertionError(
                "Unexpected output shape {}, expected {}".format(
                    tuple(out.shape), expected_shape
                )
            )
        summary(self, (self.in_channels, 32, 32, 32), device=device)
        print("Basicunetplusplus test is complete")


# from your_module import BasicUNetPlusPlus  # replace with your own module's import path

def main(spatial_size=128, output_name="unetpp_graph"):
    """Render the autograd graph of a 3D ``BasicUNetPlusPlus`` to a PNG file.

    Args:
        spatial_size (int): Edge length of the cubic dummy input volume.
            Smaller values need much less activation memory.
            NOTE(review): presumably it should be divisible by 16 to match the
            network's downsampling depth — confirm before changing.
        output_name (str): File name (without extension) of the rendered graph,
            written to the current working directory.
    """
    # 1) Build the network. eval() only changes layer behaviour; autograd must
    #    stay enabled (no torch.no_grad) because make_dot walks the grad_fn graph.
    model = BasicUNetPlusPlus(spatial_dims=3, in_channels=1, out_channels=2)
    model.eval()

    # 2) Dummy input: batch=1, in_channels=1, then three equal spatial dims.
    x = torch.randn(1, 1, spatial_size, spatial_size, spatial_size)
    # NOTE(review): forward presumably returns a list of outputs; y[0] is the
    # tensor whose graph gets plotted.
    y = model(x)

    # 3) Build the computation graph, labelling nodes with parameter names.
    dot = make_dot(y[0], params=dict(model.named_parameters()))

    # 4) Render as PNG into the current directory; cleanup removes the DOT source.
    dot.format = "png"
    dot.directory = "./"
    dot.render(output_name, cleanup=True)
    print("已生成：{}.png".format(output_name))


if __name__ == "__main__":
    main()
