# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# from monai.networks.nets.basic_unet import Down, TwoConv, UpCat
from monai.networks.blocks.convolutions import Convolution
from monai.networks.blocks.upsample import UpSample
from monai.networks.layers.factories import Conv, Pool
from monai.utils.misc import ensure_tuple_rep

# Public names of this module. The first three entries are aliases that are
# bound to ``BasicUNetPlusPlusKernelModified`` at the bottom of the file.
__all__ = [
    "BasicUnetPlusPlus",
    "BasicunetPlusPlus",
    "basicunetplusplus",
    "BasicUNetPlusPlusKernelModified",
]


class Attention_block(nn.Module):
    """
    Additive attention gate for 3D feature maps.

    The gating signal ``g`` and the features ``x`` are each projected to
    ``F_int`` channels (1x1x1 convolution + batch norm), summed, passed through
    ReLU and reduced to a one-channel sigmoid mask that rescales ``x``.

    Args:
        F_g: number of channels of the gating signal ``g``.
        F_l: number of channels of the gated features ``x``.
        F_int: number of channels of the intermediate projections.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise(in_chns, out_chns):
            # 1x1x1 convolution followed by batch normalisation.
            return [
                nn.Conv3d(in_chns, out_chns, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm3d(out_chns),
            ]

        # Creation order (W_g, W_x, psi) is kept so parameter initialisation
        # consumes the random generator in the same sequence.
        self.W_g = nn.Sequential(*pointwise(F_g, F_int))
        self.W_x = nn.Sequential(*pointwise(F_l, F_int))
        self.psi = nn.Sequential(*pointwise(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """
        Args:
            g: gating signal with ``F_g`` channels.
            x: features to be gated, with ``F_l`` channels.

        Returns:
            ``x`` multiplied by the single-channel attention mask.
        """
        combined = self.relu(self.W_g(g) + self.W_x(x))
        mask = self.psi(combined)
        return x * mask


class MCDropout3d(nn.Dropout3d):
    """MC Dropout: random dropout stays enabled regardless of ``model.train()`` / ``model.eval()``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # The training flag is hard-wired to True, so the module's own
        # ``self.training`` state is never consulted.
        always_active = True
        return F.dropout3d(input, self.p, always_active, self.inplace)


class TwoConv(nn.Sequential):
    """Two stacked MONAI ``Convolution`` blocks, registered as ``conv_0`` and ``conv_1``."""

    def __init__(
        self,
        spatial_dims: int,
        in_chns: int,
        out_chns: int,
        act: str | tuple,
        norm: str | tuple,
        bias: bool,
        dropout: float | tuple = 0.0,
        kernel_size: Sequence[int] = (3, 3, 3),
        padding: Optional[tuple] = (1, 1, 3),
        stride: Sequence[int] | int = 1,
    ):
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_chns: number of input channels.
            out_chns: number of output channels.
            act: activation type and arguments.
            norm: feature normalization type and arguments.
            bias: whether to have a bias term in convolution blocks.
            dropout: dropout ratio. Defaults to no dropout.
            kernel_size: kernel size handed to both convolutions.
            padding: padding handed to both convolutions.
            stride: stride of the first convolution only; the second one is
                built without a ``strides`` argument.

        NOTE(review): the default ``padding=(1, 1, 3)`` matches a ``(3, 3, 7)``
        kernel rather than the default ``(3, 3, 3)`` kernel -- confirm whether
        the defaults are intended to be used together.
        """
        super().__init__()

        # Settings shared by both convolution blocks.
        shared = dict(
            act=act,
            norm=norm,
            dropout=dropout,
            bias=bias,
            kernel_size=kernel_size,
            padding=padding,
        )
        first = Convolution(spatial_dims, in_chns, out_chns, strides=stride, **shared)
        second = Convolution(spatial_dims, out_chns, out_chns, **shared)

        for name, block in (("conv_0", first), ("conv_1", second)):
            self.add_module(name, block)


class Down(nn.Sequential):
    """Downsampling stage: optional max pooling followed by a ``TwoConv`` block."""

    def __init__(
        self,
        spatial_dims: int,
        in_chns: int,
        out_chns: int,
        act: str | tuple,
        norm: str | tuple,
        bias: bool,
        dropout: float | tuple = 0.0,
        kernel_size: Sequence[int] = (3, 3, 3),
        padding: Optional[tuple] = (1, 1, 3),
        downsample_mode: str = "maxpool",
    ):
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_chns: number of input channels.
            out_chns: number of output channels.
            act: activation type and arguments.
            norm: feature normalization type and arguments.
            bias: whether to have a bias term in convolution blocks.
            dropout: dropout ratio. Defaults to no dropout.
            kernel_size: kernel size forwarded to ``TwoConv``.
            padding: padding forwarded to ``TwoConv``.
            downsample_mode: 'maxpool' or 'strideconv'. Defaults to 'maxpool'.
                'maxpool' registers a ``max_pooling`` module with kernel
                ``(2, 2, 1)`` and keeps the convolutions at stride 1;
                'strideconv' registers no pooling and gives the first
                convolution a stride of ``(2, 2, 1)``.

        Raises:
            ValueError: if ``downsample_mode`` is neither of the two options.
        """
        super().__init__()

        if downsample_mode == "strideconv":
            # The first convolution inside TwoConv performs the reduction.
            conv_stride = (2, 2, 1)
        elif downsample_mode == "maxpool":
            conv_stride = 1
            self.add_module(
                "max_pooling", Pool["MAX", spatial_dims](kernel_size=(2, 2, 1))
            )
        else:
            raise ValueError(f"Unsupported downsample mode: {downsample_mode}")

        self.add_module(
            "convs",
            TwoConv(
                spatial_dims,
                in_chns,
                out_chns,
                act,
                norm,
                bias,
                dropout,
                kernel_size=kernel_size,
                padding=padding,
                stride=conv_stride,
            ),
        )


class _ASPPConv(nn.Sequential):
    """Dilated 3x3x3 convolution (no bias) followed by batch norm and ReLU.

    ``padding`` is set equal to ``dilation``.
    """

    def __init__(self, in_channels, out_channels, dilation):
        atrous = nn.Conv3d(
            in_channels,
            out_channels,
            3,
            padding=dilation,
            dilation=dilation,
            bias=False,
        )
        layers = (atrous, nn.BatchNorm3d(out_channels), nn.ReLU())
        super().__init__(*layers)


class ASPP(nn.Module):
    """
    Atrous spatial pyramid pooling over 3D feature maps.

    ``self.convs`` holds, in order: a 1x1x1 convolution branch, one
    ``_ASPPConv`` branch per entry of ``atrous_rates``, and a global-average
    pooling branch. The branch outputs are concatenated along the channel axis
    and fused by ``self.project`` (1x1x1 convolution, batch norm, ReLU and
    ``Dropout(0.5)``).

    Args:
        in_channels: number of input channels.
        out_channels: number of channels produced by every branch and by the
            final projection.
        atrous_rates: iterable of dilation rates, one dilated branch per rate.

    NOTE(review): the pooled branch applies ``BatchNorm3d`` to a 1x1x1 map;
    in training mode with a batch of one this presumably has a single value per
    channel, which PyTorch batch norm rejects -- confirm against the training
    configuration.
    """

    def __init__(self, in_channels, out_channels, atrous_rates):
        super().__init__()

        def pointwise():
            # 1x1x1 convolution -> batch norm -> ReLU on the input channels.
            return [
                nn.Conv3d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(),
            ]

        # Branch order is part of the checkpoint layout (convs.0, convs.1, ...).
        branches = [nn.Sequential(*pointwise())]
        branches.extend(
            _ASPPConv(in_channels, out_channels, rate) for rate in tuple(atrous_rates)
        )
        branches.append(nn.Sequential(nn.AdaptiveAvgPool3d(1), *pointwise()))
        self.convs = nn.ModuleList(branches)

        self.project = nn.Sequential(
            nn.Conv3d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        """
        Args:
            x: 5D input whose trailing three dimensions are spatial.

        Returns:
            Fused features with ``out_channels`` channels and the spatial size
            of ``x``.
        """
        target_size = x.shape[2:]
        outputs = []
        for branch in self.convs:
            y = branch(x)
            # Only a branch that changed the spatial size (the globally pooled
            # one) needs resampling; resizing a map to its own size would just
            # allocate a full-size copy.
            if y.shape[2:] != target_size:
                y = F.interpolate(
                    y, size=target_size, mode="trilinear", align_corners=False
                )
            outputs.append(y)
        return self.project(torch.cat(outputs, dim=1))


class UpCat(nn.Module):
    """upsampling, optional attention gating of the encoder feature map, concatenation, two convolutions"""

    def __init__(
        self,
        spatial_dims: int,
        in_chns: int,
        cat_chns: int,
        out_chns: int,
        act: str | tuple,
        norm: str | tuple,
        bias: bool,
        dropout: float | tuple = 0.0,
        upsample: str = "deconv",
        pre_conv: nn.Module | str | None = "default",
        interp_mode: str = "linear",
        align_corners: bool | None = True,
        halves: bool = True,
        is_pad: bool = True,
        kernel_size: Sequence[int] = (3, 3, 3),
        padding: Optional[tuple] = (1, 1, 3),
        attention: bool = False,
    ):
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_chns: number of input channels to be upsampled.
            cat_chns: number of channels from the encoder.
            out_chns: number of output channels.
            act: activation type and arguments.
            norm: feature normalization type and arguments.
            bias: whether to have a bias term in convolution blocks.
            dropout: dropout ratio. Defaults to no dropout.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
            pre_conv: a conv block applied before upsampling.
                Only used in the "nontrainable" or "pixelshuffle" mode.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used in the "nontrainable" mode.
            align_corners: set the align_corners parameter for upsample. Defaults to True.
                Only used in the "nontrainable" mode.
            halves: whether to halve the number of channels during upsampling.
                This parameter does not work on ``nontrainable`` mode if ``pre_conv`` is `None`.
            is_pad: whether to pad upsampling features to fit features from encoder. Defaults to True.
            kernel_size: kernel size forwarded to both ``UpSample`` and ``TwoConv``.
            padding: padding forwarded to ``TwoConv``.
            attention: whether to use attention gate. Defaults to False.

        """
        super().__init__()
        if upsample == "nontrainable" and pre_conv is None:
            up_chns = in_chns
        else:
            up_chns = in_chns // 2 if halves else in_chns
        # The scale factor is fixed to (2, 2, 1): only the first two spatial
        # dimensions are enlarged.
        self.upsample = UpSample(
            spatial_dims,
            in_chns,
            up_chns,
            (2, 2, 1),
            mode=upsample,
            pre_conv=pre_conv,
            interp_mode=interp_mode,
            align_corners=align_corners,
            kernel_size=kernel_size,
        )
        self.convs = TwoConv(
            spatial_dims,
            cat_chns + up_chns,
            out_chns,
            act,
            norm,
            bias,
            dropout,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.is_pad = is_pad
        self.use_attention = attention
        if self.use_attention:
            # The upsampled features act as the gating signal for the encoder features.
            self.attention_gate = Attention_block(
                F_g=up_chns, F_l=cat_chns, F_int=cat_chns // 2
            )

    def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]):
        """

        Args:
            x: features to be upsampled.
            x_e: optional features from the encoder, if None, this branch is not in use.

        Returns:
            The output of the two convolutions applied to the concatenation of
            the (optionally gated) encoder features and the upsampled features,
            or to the upsampled features alone when ``x_e`` is None.
        """
        x_0 = self.upsample(x)

        if x_e is not None:
            x_e_non_none = x_e
            if self.is_pad:
                # handling spatial shapes due to the 2x maxpooling with odd edge lengths.
                dimensions = len(x.shape) - 2
                sp = [0] * (dimensions * 2)
                for i in range(dimensions):
                    if x_e_non_none.shape[-i - 1] != x_0.shape[-i - 1]:
                        # pad one voxel at the trailing side of this dimension
                        sp[i * 2 + 1] = 1
                x_0 = F.pad(x_0, sp, "replicate")
            if self.use_attention:
                # Gate only after padding: the attention block sums projections
                # of both tensors, so their spatial sizes must already agree.
                x_e_non_none = self.attention_gate(g=x_0, x=x_e_non_none)
            x = self.convs(
                torch.cat([x_e_non_none, x_0], dim=1)
            )  # input channels: (cat_chns + up_chns)
        else:
            x = self.convs(x_0)

        return x


class BasicUNetPlusPlusKernelModified(nn.Module):
    def __init__(
        self,
        spatial_dims: int = 3,
        in_channels: int = 1,
        out_channels: int = 2,
        features: Sequence[int] = (32, 32, 64, 128, 256, 32),
        deep_supervision: bool = False,
        act: str | tuple = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}),
        norm: str | tuple = ("instance", {"affine": True}),
        bias: bool = True,
        dropout: float | tuple = 0.0,
        upsample: str = "deconv",
        dropout_p: float = 0.0,
    ):
        """
        A UNet++ variant with ``(3, 3, 7)`` convolution kernels, attention-gated
        skip connections, an ASPP bottleneck and optional MC dropout before the
        final 1x1 convolution.

        Based on:

            Zhou et al. "UNet++: A Nested U-Net Architecture for Medical Image
            Segmentation". 4th Deep Learning in Medical Image Analysis (DLMIA)
            Workshop, DOI: https://doi.org/10.48550/arXiv.1807.10165


        Args:
            spatial_dims: number of spatial dimensions. Defaults to 3. The building blocks in this
                module hard-code 3D torch layers and three-element kernel/stride tuples, so only
                ``spatial_dims=3`` is expected to work.
            in_channels: number of input channels. Defaults to 1.
            out_channels: number of output channels. Defaults to 2.
            features: six integers as numbers of features.
                Defaults to ``(32, 32, 64, 128, 256, 32)``,

                - the first five values correspond to the five-level encoder feature sizes.
                - the last value corresponds to the feature size after the last upsampling.

            deep_supervision: stored on the instance, but currently ignored by ``forward``, which
                always returns a single-element list holding the output of the last node.
            act: activation type and arguments. Defaults to LeakyReLU.
            norm: feature normalization type and arguments. Defaults to instance norm.
            bias: whether to have a bias term in convolution blocks. Defaults to True.
                According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
                if a conv layer is directly followed by a batch norm layer, bias should be False.
            dropout: dropout ratio. Defaults to no dropout.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
            dropout_p: probability of the always-active ``MCDropout3d`` applied to the last decoder
                node before the final convolution. A falsy or non-positive value disables it.

        Examples::

            # for spatial 3D
            >>> net = BasicUNetPlusPlusKernelModified(spatial_dims=3, features=(32, 32, 64, 128, 256, 32))

            # for spatial 3D, with group norm
            >>> net = BasicUNetPlusPlusKernelModified(norm=("group", {"num_groups": 4}))

            # for spatial 3D, with MC dropout on the last decoder node
            >>> net = BasicUNetPlusPlusKernelModified(dropout_p=0.1)

        See Also
            - :py:class:`monai.networks.nets.BasicUNet`
            - :py:class:`monai.networks.nets.DynUNet`
            - :py:class:`monai.networks.nets.UNet`

        """
        super().__init__()

        self.deep_supervision = deep_supervision
        self.dropout_p = dropout_p
        fea = ensure_tuple_rep(features, 6)
        print(f"BasicUNetPlusPlus features: {fea}.")

        kernel = (3, 3, 7)

        def down_block(in_chns, out_chns, **extra):
            # Encoder stage with the shared activation/norm/dropout settings.
            return Down(
                spatial_dims, in_chns, out_chns, act, norm, bias, dropout, kernel_size=kernel, **extra
            )

        def up_block(in_chns, cat_chns, out_chns):
            # Attention-gated decoder node; channels are not halved while upsampling.
            return UpCat(
                spatial_dims,
                in_chns,
                cat_chns,
                out_chns,
                act,
                norm,
                bias,
                dropout,
                upsample,
                halves=False,
                kernel_size=kernel,
                attention=True,
            )

        # Modules are created in a fixed order so that parameter initialisation
        # draws from the random generator in a stable sequence.

        # ---- encoder column -------------------------------------------------
        self.conv_0_0 = TwoConv(
            spatial_dims, in_channels, fea[0], act, norm, bias, dropout, kernel_size=kernel
        )
        self.conv_1_0 = down_block(fea[0], fea[1])
        self.conv_2_0 = down_block(fea[1], fea[2])
        self.conv_3_0 = down_block(fea[2], fea[3], downsample_mode="strideconv")
        # Parameter-free reduction in front of the ASPP bottleneck. The ASPP
        # keeps the spatial size of its input, while ``upcat_3_1`` upsamples the
        # bottleneck by (2, 2, 1) before merging it with ``x_3_0``; without this
        # step the two maps cannot have the same size.
        self.pool_4_0 = Pool["MAX", spatial_dims](kernel_size=(2, 2, 1))
        self.aspp = ASPP(
            in_channels=fea[3], out_channels=fea[4], atrous_rates=[6, 12, 18]
        )

        # ---- nested decoder nodes: first dense column -----------------------
        self.upcat_0_1 = up_block(fea[1], fea[0], fea[0])
        self.upcat_1_1 = up_block(fea[2], fea[1], fea[1])
        self.upcat_2_1 = up_block(fea[3], fea[2], fea[2])
        self.upcat_3_1 = up_block(fea[4], fea[3], fea[3])

        # ---- second dense column (two concatenated skips) -------------------
        self.upcat_0_2 = up_block(fea[1], fea[0] * 2, fea[0])
        self.upcat_1_2 = up_block(fea[2], fea[1] * 2, fea[1])
        self.upcat_2_2 = up_block(fea[3], fea[2] * 2, fea[2])

        # ---- third dense column (three concatenated skips) ------------------
        self.upcat_0_3 = up_block(fea[1], fea[0] * 3, fea[0])
        self.upcat_1_3 = up_block(fea[2], fea[1] * 3, fea[1])

        # ---- last node (four concatenated skips) ----------------------------
        self.upcat_0_4 = up_block(fea[1], fea[0] * 4, fea[5])

        self.mc_dropout_out = (
            MCDropout3d(self.dropout_p)
            if self.dropout_p and self.dropout_p > 0
            else nn.Identity()
        )

        # ``final_conv_0_1`` .. ``final_conv_0_3`` are not used by ``forward``;
        # they stay registered so the parameter names of the module do not change.
        self.final_conv_0_1 = Conv["conv", spatial_dims](
            fea[0], out_channels, kernel_size=1
        )
        self.final_conv_0_2 = Conv["conv", spatial_dims](
            fea[0], out_channels, kernel_size=1
        )
        self.final_conv_0_3 = Conv["conv", spatial_dims](
            fea[0], out_channels, kernel_size=1
        )
        self.final_conv_0_4 = Conv["conv", spatial_dims](
            fea[5], out_channels, kernel_size=1
        )

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: input of shape ``(Batch, in_channels, dim_0, dim_1, dim_2)``.
                ``dim_0`` and ``dim_1`` are halved four times (``conv_1_0``, ``conv_2_0``,
                ``conv_3_0`` and ``pool_4_0``), so it is recommended that they are divisible
                by 16; ``dim_2`` is never reduced.

        Returns:
            A list with a single torch Tensor: the "raw" predictions of the last decoder node,
            with ``out_channels`` channels.
        """
        x_0_0 = self.conv_0_0(x)
        x_1_0 = self.conv_1_0(x_0_0)
        x_0_1 = self.upcat_0_1(x_1_0, x_0_0)

        x_2_0 = self.conv_2_0(x_1_0)
        x_1_1 = self.upcat_1_1(x_2_0, x_1_0)
        x_0_2 = self.upcat_0_2(x_1_1, torch.cat([x_0_0, x_0_1], dim=1))

        x_3_0 = self.conv_3_0(x_2_0)
        x_2_1 = self.upcat_2_1(x_3_0, x_2_0)
        x_1_2 = self.upcat_1_2(x_2_1, torch.cat([x_1_0, x_1_1], dim=1))
        x_0_3 = self.upcat_0_3(x_1_2, torch.cat([x_0_0, x_0_1, x_0_2], dim=1))

        # Bottleneck: reduce x_3_0 by (2, 2, 1) first so that the upsampling
        # inside upcat_3_1 lands back on the grid of x_3_0.
        x_4_0 = self.aspp(self.pool_4_0(x_3_0))
        x_3_1 = self.upcat_3_1(x_4_0, x_3_0)
        x_2_2 = self.upcat_2_2(x_3_1, torch.cat([x_2_0, x_2_1], dim=1))
        x_1_3 = self.upcat_1_3(x_2_2, torch.cat([x_1_0, x_1_1, x_1_2], dim=1))
        x_0_4 = self.upcat_0_4(x_1_3, torch.cat([x_0_0, x_0_1, x_0_2, x_0_3], dim=1))

        # Only the last node feeds an output head; the heads for x_0_1, x_0_2
        # and x_0_3 are not evaluated, whatever ``self.deep_supervision`` is.
        x_0_4 = self.mc_dropout_out(x_0_4)
        output_0_4 = self.final_conv_0_4(x_0_4)

        output = [output_0_4]

        return output


# Alternative spellings of the class name, all bound to the same class object.
BasicUnetPlusPlus = BasicUNetPlusPlusKernelModified
BasicunetPlusPlus = BasicUNetPlusPlusKernelModified
basicunetplusplus = BasicUNetPlusPlusKernelModified
