import torch
import torch.nn as nn
import os
from abc import ABC, abstractmethod
from torchsummary import summary
# -------------------------------------------------------------
# Helper blocks
# -------------------------------------------------------------

def _group_norm(channels: int) -> int:
    """Choose reasonable number of groups for GroupNorm."""
    if channels % 8 == 0:
        return 8
    if channels % 4 == 0:
        return 4
    return 1  # fallback → InstanceNorm‑like


class ConvGNReLU(nn.Module):
    """3D convolution followed by GroupNorm, ReLU and an optional Dropout3d.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        k: cubic kernel size. Padding is ``k // 2``, so odd ``k`` with
            ``stride == 1`` preserves the spatial size.
        stride: convolution stride.
        drop: Dropout3d probability. Values ``<= 0`` disable dropout (Identity).
    """

    def __init__(self, in_c: int, out_c: int, k: int = 3, stride: int = 1, drop: float = 0.0):
        super().__init__()
        # No conv bias: the GroupNorm that follows carries its own per-channel shift.
        self.conv = nn.Conv3d(in_c, out_c, k, stride=stride, padding=k // 2, bias=False)
        self.gn = nn.GroupNorm(_group_norm(out_c), out_c)
        self.act = nn.ReLU(inplace=True)
        if drop > 0:
            self.do = nn.Dropout3d(drop)
        else:
            self.do = nn.Identity()

    def forward(self, x):
        y = self.conv(x)
        y = self.gn(y)
        y = self.act(y)
        return self.do(y)


class Down(nn.Module):
    """Downsampling stage: a stride-2, 3x3x3 ConvGNReLU.

    With padding 1 each spatial dim ``n`` becomes ``ceil(n / 2)``.
    """

    def __init__(self, in_c: int, out_c: int, drop: float = 0.0):
        super().__init__()
        self.op = ConvGNReLU(in_c, out_c, k=3, stride=2, drop=drop)

    def forward(self, x):
        out = self.op(x)
        return out


class Up(nn.Module):
    """Upsampling stage: stride-2 transposed 3D convolution followed by ReLU.

    Each spatial dim ``n`` becomes ``(n - 1) * 2 - 2 * 1 + 3 + 1 = 2 * n``.
    There is no normalisation layer and no bias.
    """

    def __init__(self, in_c: int, out_c: int):
        super().__init__()
        self.deconv = nn.ConvTranspose3d(
            in_c,
            out_c,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            bias=False,
        )
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.deconv(x)
        return self.act(y)


# -------------------------------------------------------------
# Attention modules
# -------------------------------------------------------------
class PosAttn(nn.Module):
    """Element-wise "position" attention with a residual connection.

    Computes ``out = proj(sigmoid(q(x) * k(x)) * v(x)) + x``. Here ``q``, ``k``
    and ``v`` are 1x1x1 convolutions to ``inter_c`` channels, and ``proj`` maps
    back to ``in_c``, so the output has the same shape as ``x``.

    NOTE(review): unlike DANet's position-attention module, this never forms a
    (DHW x DHW) affinity matrix. ``q`` and ``k`` interact voxel by voxel, so
    memory stays linear in the number of voxels.
    """

    def __init__(self, in_c: int, inter_c: int):
        super().__init__()
        # Names and creation order are unchanged so state_dicts and seeded
        # initialisation stay compatible.
        self.q = nn.Conv3d(in_c, inter_c, 1)
        self.k = nn.Conv3d(in_c, inter_c, 1)
        self.v = nn.Conv3d(in_c, inter_c, 1)
        self.proj = nn.Conv3d(inter_c, in_c, 1)
        self.act = nn.Sigmoid()

    def forward(self, x):
        # Every op here is element-wise, so no flatten/unflatten is needed.
        # The module therefore accepts any input Conv3d accepts, including
        # unbatched (C, D, H, W).
        gate = self.act(self.q(x) * self.k(x))
        return self.proj(gate * self.v(x)) + x


class ChaAttn(nn.Module):
    """Parameter-free channel attention with a residual connection.

    Let ``F`` be the input flattened to (B, C, N), where N is the product of
    the spatial dims. The module computes ``sigmoid(F @ F^T) @ F``, a (B, C, C)
    affinity applied to ``F``, reshapes it back to the input shape and adds
    the input.

    NOTE(review): the affinity is not normalised. For non-negative inputs
    (e.g. after ReLU) the entries of ``F @ F^T`` grow with N. The sigmoid can
    then saturate to 1, so the attention term of every channel degenerates to
    the plain sum over all channels. DANet's channel attention normalises with
    a softmax instead. Changing this would alter trained models, so it is only
    flagged here.
    """

    def __init__(self):
        super().__init__()
        self.act = nn.Sigmoid()

    def forward(self, x):
        b, c = x.shape[:2]
        # reshape (not view): also works for non-contiguous inputs such as
        # volumes with transposed spatial axes, where .view() raises.
        flat = x.reshape(b, c, -1)
        attn = self.act(torch.bmm(flat, flat.transpose(1, 2)))
        return torch.bmm(attn, flat).reshape(x.shape) + x


class DualAttn(nn.Module):
    """Sum of a position-attention and a channel-attention branch, fused by a 1x1x1 conv.

    Args:
        in_c: channels of the input (and of the output).
        ratio: channel reduction for the position branch's q/k/v projections
            (``in_c // ratio``, but at least 1 channel).

    NOTE(review): both branches already add ``x`` as a residual, so the input
    reaches ``fuse`` twice.
    """

    def __init__(self, in_c: int, ratio: int = 2):
        super().__init__()
        self.pos = PosAttn(in_c, max(1, in_c // ratio))
        self.cha = ChaAttn()
        self.fuse = nn.Conv3d(in_c, in_c, 1)

    def forward(self, x):
        combined = self.pos(x) + self.cha(x)
        return self.fuse(combined)


# -------------------------------------------------------------
# Utility – crop centre region of src to match tgt spatial dims
# -------------------------------------------------------------

def crop(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
    """Centre-crop the spatial dims of ``src`` to the spatial dims of ``tgt``.

    Spatial dims are every dim after (batch, channel), so 1D, 2D and 3D
    feature maps are supported. When the size difference along a dim is odd,
    the extra element is dropped from the far end.

    Args:
        src: tensor to crop, shape (B, C, *spatial).
        tgt: tensor whose spatial shape is the target; only ``tgt.shape[2:]`` is read.

    Returns:
        ``src`` itself if the spatial shapes already match, otherwise a view of ``src``.

    Raises:
        ValueError: if the tensors have different numbers of spatial dims, or if
            ``src`` is smaller than ``tgt`` along any spatial dim. Previously a
            negative offset silently produced a wrongly sized slice.
    """
    src_sp, tgt_sp = src.shape[2:], tgt.shape[2:]
    if src_sp == tgt_sp:
        return src
    if len(src_sp) != len(tgt_sp):
        raise ValueError(
            f"crop: rank mismatch, src spatial {tuple(src_sp)} vs tgt spatial {tuple(tgt_sp)}"
        )
    index = [slice(None), slice(None)]  # keep batch and channel dims whole
    for s, t in zip(src_sp, tgt_sp):
        if s < t:
            raise ValueError(
                f"crop: src spatial {tuple(src_sp)} is smaller than tgt spatial {tuple(tgt_sp)}"
            )
        start = (s - t) // 2
        index.append(slice(start, start + t))
    return src[tuple(index)]


class BaseModel(nn.Module, ABC):
    r"""
    BaseModel with basic functionalities for checkpointing and restoration.

    Subclasses implement :meth:`forward` and :meth:`test`. ``best_loss`` holds
    the lowest ``loss`` passed to :meth:`save_checkpoint` on this instance. It is
    not stored in the checkpoint, so it resets when a new instance is created.
    """

    def __init__(self):
        super().__init__()
        # +inf guarantees the first saved checkpoint becomes BEST, whatever the loss scale.
        self.best_loss = float('inf')

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def test(self):
        """
        To be implemented by the subclass so that
        models can perform a forward propagation
        :return:
        """
        pass

    @property
    def device(self):
        """Device of the first parameter (raises StopIteration if the model has none)."""
        return next(self.parameters()).device

    def restore_checkpoint(self, ckpt_file, optimizer=None):
        r"""
        Restores model weights, and optionally optimizer state, from a pth file.

        Args:
            ckpt_file (str): Path to a checkpoint written by :meth:`save_checkpoint`.
            optimizer (Optimizer): Optional optimizer whose state is restored in place.
                If the checkpoint holds no optimizer state (it was saved without an
                optimizer), a warning is issued and the optimizer is left untouched.

        Returns:
            int: The epoch stored in the checkpoint.

        Raises:
            ValueError: If ``ckpt_file`` is empty or None.
        """
        if not ckpt_file:
            raise ValueError("No checkpoint file to be restored.")

        try:
            ckpt_dict = torch.load(ckpt_file)
        except RuntimeError:
            # Retry keeping every storage on CPU, e.g. a GPU checkpoint on a CPU-only machine.
            ckpt_dict = torch.load(ckpt_file, map_location=lambda storage, loc: storage)

        self.load_state_dict(ckpt_dict['model_state_dict'])

        # The optimizer is updated in place, so there is nothing to return for it.
        if optimizer is not None:
            optim_state = ckpt_dict.get('optimizer_state_dict')
            if optim_state is None:
                import warnings
                warnings.warn(
                    "Checkpoint {} has no optimizer state; optimizer left unchanged.".format(ckpt_file)
                )
            else:
                optimizer.load_state_dict(optim_state)

        return ckpt_dict['epoch']

    def save_checkpoint(self,
                        directory,
                        epoch, loss,
                        optimizer=None,
                        name=None):
        r"""
        Saves a checkpoint, plus a BEST copy whenever ``loss`` improves on ``best_loss``.

        Args:
            directory (str): Directory to save into (created if missing). Its last
                path component prefixes the default file names.
            epoch (int): The training epoch, stored in the checkpoint.
            loss (float): Loss compared against ``self.best_loss`` to decide whether
                to also write ``<dirname>_BEST.pth``.
            optimizer (Optimizer): Optimizer whose state is saved alongside (None stores None).
            name (str): File name for the regular checkpoint; defaults to
                ``<dirname>_last_epoch.pth``.

        Returns:
            str: File name (not the full path) of the last file written. This is the
            BEST file name when this call improved ``best_loss``, otherwise the
            regular checkpoint name.
        """
        os.makedirs(directory, exist_ok=True)

        ckpt_dict = {
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict() if optimizer is not None else None,
            'epoch': epoch,
        }

        # normpath drops a trailing separator; otherwise basename("runs/netG/") == "".
        prefix = os.path.basename(os.path.normpath(directory))  # e.g. netD or netG
        if name is None:
            name = f"{prefix}_last_epoch.pth"

        torch.save(ckpt_dict, os.path.join(directory, name))
        if self.best_loss > loss:
            self.best_loss = loss
            name = f"{prefix}_BEST.pth"
            torch.save(ckpt_dict, os.path.join(directory, name))
        return name

    def count_params(self):
        r"""
        Computes the number of parameters in this model.

        Args: None

        Returns:
            int: Total number of weight parameters for this model.
            int: Total number of trainable parameters for this model.

        """
        num_total_params = sum(p.numel() for p in self.parameters())
        num_trainable_params = sum(p.numel() for p in self.parameters()
                                   if p.requires_grad)

        return num_total_params, num_trainable_params

    def inference(self, input_tensor):
        r"""
        Runs a gradient-free forward pass and returns the output on CPU.

        Switches the model to eval mode and leaves it there. If ``forward``
        returns a tuple, only its first element is returned.
        """
        self.eval()
        with torch.no_grad():
            # Call the module (not .forward) so registered hooks run.
            output = self(input_tensor)
        if isinstance(output, tuple):
            output = output[0]
        return output.cpu().detach()


def passthrough(x, **kwargs):
    """Identity function: return ``x`` unchanged, ignoring any keyword arguments."""
    del kwargs  # accepted only so it can stand in for callables that take options
    return x


def ELUCons(elu, nchan):
    """Build an activation layer.

    Args:
        elu: if truthy, return ``nn.ELU(inplace=True)``; otherwise return
            ``nn.PReLU(nchan)`` (one learnable slope per channel).
        nchan: channel count, used only for the PReLU case.
    """
    return nn.ELU(inplace=True) if elu else nn.PReLU(nchan)

# -------------------------------------------------------------
# Dual‑Attention V‑Net 3D (clean, residual‑safe)
# -------------------------------------------------------------
class DualAttentionVNet3D(BaseModel):
    """3D V-Net-style encoder/decoder with dual attention on every skip connection.

    The encoder has five levels with ``f, 2f, 4f, 8f, 16f`` channels
    (``f = base``), each one downsampled by a stride-2 conv. Each decoder stage:

    1. upsamples the deeper features;
    2. concatenates the attended skip features with the upsampled ones;
    3. refines them with three ConvGNReLU layers;
    4. adds the upsampled features back as a residual.

    Each upsampled map is centre-cropped to its skip connection's size. Input
    spatial sizes therefore need not be divisible by 16, and the output keeps
    the input's spatial size.

    Args:
        in_ch: number of input channels.
        n_cls: number of output channels. ``1`` gives a sigmoid output;
            otherwise a softmax over channels is applied.
        base: channel width ``f`` of the first level.
        drop: Dropout3d probability used in every ConvGNReLU (``0`` disables it).
    """

    def __init__(self, in_ch: int = 1, n_cls: int = 1, base: int = 16, drop: float = 0.0):
        super().__init__()
        # Read by test(). These are plain ints, not parameters, so checkpoints are unaffected.
        self.in_channels = in_ch
        self.classes = n_cls
        f = base  # shorthand

        # Attribute names and creation order match the original layout, so
        # existing state_dicts and optimizer states (keyed by parameter order) still load.

        # Encoder
        self.e0 = self._stack(in_ch, f, 2, drop)
        self.d1 = Down(f, 2 * f, drop)
        self.e1 = self._stack(2 * f, 2 * f, 2, drop)
        self.d2 = Down(2 * f, 4 * f, drop)
        self.e2 = self._stack(4 * f, 4 * f, 3, drop)
        self.d3 = Down(4 * f, 8 * f, drop)
        self.e3 = self._stack(8 * f, 8 * f, 3, drop)
        self.d4 = Down(8 * f, 16 * f, drop)
        self.e4 = self._stack(16 * f, 16 * f, 3, drop)

        # Decoder: each dec* sees [attended skip, upsampled], i.e. twice the stage width.
        self.up3 = Up(16 * f, 8 * f)
        self.att3 = DualAttn(8 * f)
        self.dec3 = self._stack(16 * f, 8 * f, 3, drop)

        self.up2 = Up(8 * f, 4 * f)
        self.att2 = DualAttn(4 * f)
        self.dec2 = self._stack(8 * f, 4 * f, 3, drop)

        self.up1 = Up(4 * f, 2 * f)
        self.att1 = DualAttn(2 * f)
        self.dec1 = self._stack(4 * f, 2 * f, 3, drop)

        self.up0 = Up(2 * f, f)
        self.att0 = DualAttn(f)
        self.dec0 = self._stack(2 * f, f, 3, drop)

        self.outc = nn.Conv3d(f, n_cls, 1)
        self.act = nn.Sigmoid() if n_cls == 1 else nn.Softmax(dim=1)

    @staticmethod
    def _stack(in_c: int, out_c: int, n: int, drop: float) -> nn.Sequential:
        """Return ``n`` ConvGNReLU layers: the first maps in_c -> out_c, the rest keep out_c."""
        layers = [ConvGNReLU(in_c, out_c, drop=drop)]
        layers += [ConvGNReLU(out_c, out_c, drop=drop) for _ in range(n - 1)]
        return nn.Sequential(*layers)

    @staticmethod
    def _decode(up, att, dec, deep, skip):
        """Run one decoder stage and return features at ``skip``'s spatial size.

        Each Down maps ``n -> ceil(n / 2)`` and each Up doubles the size, so
        ``up(deep)`` is never smaller than ``skip``. It is equal when the input
        size is divisible by 16. The upsampled map is cropped to the skip size.
        The old code cropped the skip to the upsampled map, which is impossible
        whenever the upsampled map is larger.
        """
        u = crop(up(deep), skip)
        return dec(torch.cat([att(skip), u], dim=1)) + u

    # ---------------------------------------------------------
    # Forward (residual shapes ensured for any input size)
    # ---------------------------------------------------------
    def forward(self, x):
        # ---------- Encode ----------
        e0 = self.e0(x)              # (B, f)
        e1 = self.e1(self.d1(e0))    # (B, 2f)
        e2 = self.e2(self.d2(e1))    # (B, 4f)
        e3 = self.e3(self.d3(e2))    # (B, 8f)
        e4 = self.e4(self.d4(e3))    # (B, 16f)

        # ---------- Decode ----------
        d3 = self._decode(self.up3, self.att3, self.dec3, e4, e3)  # (B, 8f), e3's size
        d2 = self._decode(self.up2, self.att2, self.dec2, d3, e2)  # (B, 4f), e2's size
        d1 = self._decode(self.up1, self.att1, self.dec1, d2, e1)  # (B, 2f), e1's size
        d0 = self._decode(self.up0, self.att0, self.dec0, d1, e0)  # (B, f), input size

        return self.act(self.outc(d0))

    def test(self, device='cpu'):
        """Smoke test on a random 32^3 volume.

        Moves the model to ``device``, runs one forward pass, checks the output
        shape and prints a torchsummary table. ``device`` is passed to
        torchsummary as-is (it expects ``'cpu'`` or ``'cuda'``).
        """
        dev = torch.device(device)
        self.to(dev)  # move the model before the forward pass, not after
        input_tensor = torch.rand(1, self.in_channels, 32, 32, 32, device=dev)
        with torch.no_grad():
            out = self(input_tensor)
        expected = torch.Size((1, self.classes, 32, 32, 32))
        assert out.shape == expected, \
            f"expected output shape {tuple(expected)}, got {tuple(out.shape)}"
        summary(self, (self.in_channels, 32, 32, 32), device=device)
        print("Vnet test is complete")

# -------------------------------------------------------------
# Quick check
# -------------------------------------------------------------
if __name__ == "__main__":
    # Quick check: one forward pass on a random 32^3 volume with the defaults
    # (1 input channel, 1 class), then report output shape and parameter counts.
    net = DualAttentionVNet3D()
    volume = torch.rand(1, 1, 32, 32, 32)
    with torch.no_grad():
        prediction = net(volume)
    n_total, n_trainable = net.count_params()
    print(f"output shape: {tuple(prediction.shape)}")
    print(f"parameters: {n_total:,} total, {n_trainable:,} trainable")
