# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math

import torch
import torch.nn as nn

from monai.networks.blocks import Convolution
from monai.networks.layers.factories import Act, Conv, Norm, Pool, split_args


class ChannelSELayer(nn.Module):
    """
    Re-implementation of the Squeeze-and-Excitation block based on:
    "Hu et al., Squeeze-and-Excitation Networks, https://arxiv.org/abs/1709.01507".

    The input is adaptively average-pooled to a single value per channel, passed through two
    linear layers (``in_channels -> in_channels // r -> in_channels``), and the resulting
    per-channel weights are multiplied back onto the input.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        r: int = 2,
        acti_type_1: tuple[str, dict] | str = ("relu", {"inplace": True}),
        acti_type_2: tuple[str, dict] | str = "sigmoid",
        add_residual: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: activation type of the hidden squeeze layer. Defaults to ``("relu", {"inplace": True})``.
            acti_type_2: activation type of the output squeeze layer. Defaults to "sigmoid".
            add_residual: if True, the input is added to the rescaled output in ``forward``.
                Defaults to False.

        Raises:
            ValueError: When ``r`` is nonpositive or larger than ``in_channels``.

        See also:

            :py:class:`monai.networks.layers.Act`

        """
        super().__init__()

        # Validate before building any layers. A zero ``r`` must not reach the floor division,
        # otherwise it surfaces as a ZeroDivisionError instead of the documented ValueError.
        channels = int(in_channels // r) if r > 0 else 0
        if channels <= 0:
            raise ValueError(f"r must be positive and smaller than in_channels, got r={r} in_channels={in_channels}.")

        self.add_residual = add_residual

        pool_type = Pool[Pool.ADAPTIVEAVG, spatial_dims]
        self.avg_pool = pool_type(1)  # spatial size (1, 1, ...)

        act_1, act_1_args = split_args(acti_type_1)
        act_2, act_2_args = split_args(acti_type_2)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, channels, bias=True),
            Act[act_1](**act_1_args),
            nn.Linear(channels, in_channels, bias=True),
            Act[act_2](**act_2_args),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]).

        Returns:
            ``x`` multiplied by the per-channel weights, plus ``x`` itself when
            ``add_residual`` is True.
        """
        b, c = x.shape[:2]
        # squeeze: one pooled value per channel, flattened to (batch, channels) for the linear layers
        y: torch.Tensor = self.avg_pool(x).view(b, c)
        # excite: reshape the weights to (batch, channels, 1, ...) so they broadcast over the spatial dims
        y = self.fc(y).view([b, c] + [1] * (x.ndim - 2))
        result = x * y

        # Residual connection is moved here instead of providing an override of forward in ResidualSELayer since
        # Torchscript has an issue with using super().
        if self.add_residual:
            # ``result`` is the new tensor produced by the multiplication above, so adding in place leaves ``x`` intact
            result += x

        return result


class ResidualSELayer(ChannelSELayer):
    """
    A "squeeze-and-excitation"-like layer with a residual connection::

        --+-- SE --o--
          |        |
          +--------+

    This is a :py:class:`ChannelSELayer` constructed with ``add_residual=True`` and with
    different default activations.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        r: int = 2,
        acti_type_1: tuple[str, dict] | str = "leakyrelu",
        acti_type_2: tuple[str, dict] | str = "relu",
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: activation type of the hidden squeeze layer. Defaults to "leakyrelu".
            acti_type_2: activation type of the output squeeze layer. Defaults to "relu".

        See also:
            :py:class:`monai.networks.blocks.ChannelSELayer`
        """
        # Arguments are forwarded in the parent's declared order; only the residual flag is fixed here.
        super().__init__(spatial_dims, in_channels, r, acti_type_1, acti_type_2, add_residual=True)


class SEBlock(nn.Module):
    """
    Residual module enhanced with Squeeze-and-Excitation::

        ----+- conv1 --  conv2 -- conv3 -- SE -o---
            |                                  |
            +---(channel project if needed)----+

    Re-implementation of the SE-Resnet block based on:
    "Hu et al., Squeeze-and-Excitation Networks, https://arxiv.org/abs/1709.01507".
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        n_chns_1: int,
        n_chns_2: int,
        n_chns_3: int,
        conv_param_1: dict | None = None,
        conv_param_2: dict | None = None,
        conv_param_3: dict | None = None,
        project: Convolution | None = None,
        r: int = 2,
        acti_type_1: tuple[str, dict] | str = ("relu", {"inplace": True}),
        acti_type_2: tuple[str, dict] | str = "sigmoid",
        acti_type_final: tuple[str, dict] | str | None = ("relu", {"inplace": True}),
    ):
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            n_chns_1: number of output channels in the 1st convolution.
            n_chns_2: number of output channels in the 2nd convolution.
            n_chns_3: number of output channels in the 3rd convolution.
            conv_param_1: additional parameters to the 1st convolution. When None or empty, defaults to
                ``{"kernel_size": 1, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}``
            conv_param_2: additional parameters to the 2nd convolution. When None or empty, defaults to
                ``{"kernel_size": 3, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}``
            conv_param_3: additional parameters to the 3rd convolution. When None or empty, defaults to
                ``{"kernel_size": 1, "norm": Norm.BATCH, "act": None}``
            project: when the number of residual channels and output channels do not match, a project
                (Conv) layer/block is used to adjust the number of channels. In SENET, it
                consists of a Conv layer as well as a Norm layer.
                When None, an identity is used if ``in_channels == n_chns_3``, otherwise a Conv layer
                with kernel size 1.
            r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: activation type of the hidden squeeze layer. Defaults to ``("relu", {"inplace": True})``.
            acti_type_2: activation type of the output squeeze layer. Defaults to "sigmoid".
            acti_type_final: activation type of the end of the block, or None for no final activation.
                Defaults to ``("relu", {"inplace": True})``.

        See also:

            :py:class:`monai.networks.blocks.ChannelSELayer`

        """
        super().__init__()

        # Submodules are created in a fixed order (conv1, conv2, conv3, SE, project, act).
        self.conv1 = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=n_chns_1,
            **(conv_param_1 or {"kernel_size": 1, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}),
        )
        self.conv2 = Convolution(
            spatial_dims=spatial_dims,
            in_channels=n_chns_1,
            out_channels=n_chns_2,
            **(conv_param_2 or {"kernel_size": 3, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}),
        )
        self.conv3 = Convolution(
            spatial_dims=spatial_dims,
            in_channels=n_chns_2,
            out_channels=n_chns_3,
            **(conv_param_3 or {"kernel_size": 1, "norm": Norm.BATCH, "act": None}),
        )

        self.se_layer = ChannelSELayer(
            spatial_dims=spatial_dims, in_channels=n_chns_3, r=r, acti_type_1=acti_type_1, acti_type_2=acti_type_2
        )

        # Shortcut branch: a caller-supplied module wins; otherwise project only if the channel counts differ.
        if project is not None:
            self.project = project
        elif in_channels == n_chns_3:
            self.project = nn.Identity()
        else:
            self.project = Conv[Conv.CONV, spatial_dims](in_channels, n_chns_3, kernel_size=1)

        if acti_type_final is None:
            self.act = nn.Identity()
        else:
            final_name, final_kwargs = split_args(acti_type_final)
            self.act = Act[final_name](**final_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]).

        Returns:
            the final activation applied to the sum of the SE-weighted convolution branch and
            the projected input.
        """
        shortcut = self.project(x)
        out = self.conv3(self.conv2(self.conv1(x)))
        out = self.se_layer(out)
        out += shortcut
        return self.act(out)


class SEBottleneck(SEBlock):
    """
    Bottleneck for SENet154.

    The three convolutions produce ``planes * 2``, ``planes * 4`` and ``planes * 4`` channels;
    ``stride`` and ``groups`` are applied in the 2nd (kernel size 3) convolution.
    """

    expansion = 4

    def __init__(
        self,
        spatial_dims: int,
        inplanes: int,
        planes: int,
        groups: int,
        reduction: int,
        stride: int = 1,
        downsample: Convolution | None = None,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            inplanes: number of input channels.
            planes: base channel count from which the convolution widths are derived.
            groups: number of groups in the 2nd convolution.
            reduction: reduction ratio ``r`` passed to the squeeze-and-excitation layer.
            stride: stride of the 2nd convolution. Defaults to 1.
            downsample: optional module used as the ``project`` layer on the residual branch.
        """
        out_planes = planes * 4

        super().__init__(
            spatial_dims=spatial_dims,
            in_channels=inplanes,
            n_chns_1=planes * 2,
            n_chns_2=out_planes,
            n_chns_3=out_planes,
            conv_param_1={
                "kernel_size": 1,
                "strides": 1,
                "norm": Norm.BATCH,
                "act": ("relu", {"inplace": True}),
                "bias": False,
            },
            conv_param_2={
                "kernel_size": 3,
                "strides": stride,
                "groups": groups,
                "norm": Norm.BATCH,
                "act": ("relu", {"inplace": True}),
                "bias": False,
            },
            conv_param_3={"kernel_size": 1, "strides": 1, "norm": Norm.BATCH, "act": None, "bias": False},
            project=downsample,
            r=reduction,
        )


class SEResNetBottleneck(SEBlock):
    """
    ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
    implementation and uses `strides=stride` in `conv1` and not in `conv2`
    (the latter is used in the torchvision implementation of ResNet).

    The three convolutions produce ``planes``, ``planes`` and ``planes * 4`` channels.
    """

    expansion = 4

    def __init__(
        self,
        spatial_dims: int,
        inplanes: int,
        planes: int,
        groups: int,
        reduction: int,
        stride: int = 1,
        downsample: Convolution | None = None,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            inplanes: number of input channels.
            planes: number of output channels of the 1st and 2nd convolutions; the 3rd outputs ``planes * 4``.
            groups: number of groups in the 2nd convolution.
            reduction: reduction ratio ``r`` passed to the squeeze-and-excitation layer.
            stride: stride of the 1st convolution. Defaults to 1.
            downsample: optional module used as the ``project`` layer on the residual branch.
        """
        super().__init__(
            spatial_dims=spatial_dims,
            in_channels=inplanes,
            n_chns_1=planes,
            n_chns_2=planes,
            n_chns_3=planes * 4,
            # the stride sits on the 1x1 convolution here, unlike the other bottleneck variants
            conv_param_1={
                "kernel_size": 1,
                "strides": stride,
                "norm": Norm.BATCH,
                "act": ("relu", {"inplace": True}),
                "bias": False,
            },
            conv_param_2={
                "kernel_size": 3,
                "strides": 1,
                "groups": groups,
                "norm": Norm.BATCH,
                "act": ("relu", {"inplace": True}),
                "bias": False,
            },
            conv_param_3={"kernel_size": 1, "strides": 1, "norm": Norm.BATCH, "act": None, "bias": False},
            project=downsample,
            r=reduction,
        )


class SEResNeXtBottleneck(SEBlock):
    """
    ResNeXt bottleneck type C with a Squeeze-and-Excitation module.

    The 1st and 2nd convolutions share a width derived from ``planes``, ``base_width`` and
    ``groups``; the 3rd convolution outputs ``planes * 4`` channels.
    """

    expansion = 4

    def __init__(
        self,
        spatial_dims: int,
        inplanes: int,
        planes: int,
        groups: int,
        reduction: int,
        stride: int = 1,
        downsample: Convolution | None = None,
        base_width: int = 4,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            inplanes: number of input channels.
            planes: base channel count; the block outputs ``planes * 4`` channels.
            groups: number of groups in the 2nd convolution, also a factor of the inner width.
            reduction: reduction ratio ``r`` passed to the squeeze-and-excitation layer.
            stride: stride of the 2nd convolution. Defaults to 1.
            downsample: optional module used as the ``project`` layer on the residual branch.
            base_width: width per group relative to a reference of 64 planes. Defaults to 4.
        """
        # channels of the two inner convolutions: floor(planes * base_width / 64) per group, times groups
        inner_width = math.floor(planes * (base_width / 64)) * groups

        super().__init__(
            spatial_dims=spatial_dims,
            in_channels=inplanes,
            n_chns_1=inner_width,
            n_chns_2=inner_width,
            n_chns_3=planes * 4,
            conv_param_1={
                "kernel_size": 1,
                "strides": 1,
                "norm": Norm.BATCH,
                "act": ("relu", {"inplace": True}),
                "bias": False,
            },
            conv_param_2={
                "kernel_size": 3,
                "strides": stride,
                "groups": groups,
                "norm": Norm.BATCH,
                "act": ("relu", {"inplace": True}),
                "bias": False,
            },
            conv_param_3={"kernel_size": 1, "strides": 1, "norm": Norm.BATCH, "act": None, "bias": False},
            project=downsample,
            r=reduction,
        )
