# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings

import torch
from torch import nn
from torch.nn import functional as F

from monai.config.deviceconfig import USE_COMPILED
from monai.networks.layers.spatial_transforms import grid_pull
from monai.networks.utils import meshgrid_ij
from monai.utils import GridSampleMode, GridSamplePadMode, optional_import

# Optionally load the compiled extension module ``monai._C`` (presumably the MONAI C++/CUDA extension
# mentioned in ``Warp``'s docstring); the second value returned by ``optional_import`` is discarded.
# NOTE(review): ``_C`` is not referenced anywhere else in this module -- confirm whether the import is
# still needed (e.g. for a load-time side effect) before removing it.
_C, _ = optional_import("monai._C")

# Public names exported by ``from ... import *``.
__all__ = ["Warp", "DVF2DDF"]


class Warp(nn.Module):
    """
    Warp an image with given dense displacement field (DDF).
    """

    def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.BORDER.value, jitter=False):
        """
        For pytorch native APIs, the possible values are:

            - mode: ``"nearest"``, ``"bilinear"``, ``"bicubic"``.
            - padding_mode: ``"zeros"``, ``"border"``, ``"reflection"``

        See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

        For MONAI C++/CUDA extensions, the possible values are:

            - mode: ``"nearest"``, ``"bilinear"``, ``"bicubic"``, 0, 1, ...
            - padding_mode: ``"zeros"``, ``"border"``, ``"reflection"``, 0, 1, ...

        See also: :py:class:`monai.networks.layers.grid_pull`

        - jitter: bool, default=False
            Define reference grid on non-integer values
            Reference: B. Likar and F. Pernus. A hierarchical approach to elastic registration
            based on mutual information. Image and Vision Computing, 19:33-44, 2001.
        """
        super().__init__()

        # resolves _interp_mode for different methods
        if USE_COMPILED:
            # Named modes are translated to the integer codes passed to ``grid_pull``;
            # values that are not a GridSampleMode name (e.g. ints) are passed through unchanged.
            if mode in (inter.value for inter in GridSampleMode):
                interp_codes = {
                    GridSampleMode.NEAREST: 0,
                    GridSampleMode.BILINEAR: 1,
                    GridSampleMode.BICUBIC: 3,
                }
                mode = interp_codes.get(GridSampleMode(mode), 1)  # default to linear
            self._interp_mode = mode
        else:
            warnings.warn("monai.networks.blocks.Warp: Using PyTorch native grid_sample.")
            self._interp_mode = GridSampleMode(mode).value

        # resolves _padding_mode for different methods
        if USE_COMPILED:
            # Named padding modes are translated to the integer ``bound`` codes passed to ``grid_pull``;
            # other values are passed through unchanged.
            if padding_mode in (pad.value for pad in GridSamplePadMode):
                bound_codes = {
                    GridSamplePadMode.BORDER: 0,
                    GridSamplePadMode.REFLECTION: 1,
                    GridSamplePadMode.ZEROS: 7,
                }
                padding_mode = bound_codes.get(GridSamplePadMode(padding_mode), 0)  # default to nearest
            self._padding_mode = padding_mode
        else:
            self._padding_mode = GridSamplePadMode(padding_mode).value

        # Cached identity grid built by ``get_reference_grid``, plus the (jitter, seed) pair it was built with,
        # so a cached grid is only reused for an identical request.
        self.ref_grid = None
        self._ref_grid_key = None
        self.jitter = jitter

    def get_reference_grid(self, ddf: torch.Tensor, jitter: bool = False, seed: int = 0) -> torch.Tensor:
        """
        Return the identity sampling grid matching ``ddf``, reusing the cached grid when possible.

        Args:
            ddf: displacement field in shape (batch, ``spatial_dims``, H, W[, D]). Only its batch size,
                spatial size, device and dtype are used.
            jitter: if True, add uniform noise in [0, 1) to every grid coordinate, so the reference grid
                is defined on non-integer values.
            seed: seed of the private random generator that draws the jitter noise. The global RNG state
                is left untouched.

        Returns:
            Tensor in shape (batch, ``spatial_dims``, H, W[, D]) on the device and with the dtype of ``ddf``.
            Channel ``i`` holds the voxel index along spatial axis ``i`` (plus noise when ``jitter`` is True).
            It never requires grad.
        """
        key = (jitter, seed)
        cached = self.ref_grid
        # ``cached`` has one channel per spatial axis, so matching the batch and spatial sizes implies
        # matching channel count too. Device, dtype and jitter settings must also agree before reuse.
        if (
            cached is not None
            and cached.shape[0] == ddf.shape[0]
            and cached.shape[2:] == ddf.shape[2:]
            and cached.device == ddf.device
            and cached.dtype == ddf.dtype
            and getattr(self, "_ref_grid_key", None) == key
        ):
            return cached

        mesh_points = [torch.arange(0, dim) for dim in ddf.shape[2:]]
        grid = torch.stack(meshgrid_ij(*mesh_points), dim=0)  # (spatial_dims, ...)
        grid = torch.stack([grid] * ddf.shape[0], dim=0)  # (batch, spatial_dims, ...)
        # Cast to ddf's floating dtype/device *before* jittering, so the noise is not truncated
        # and actually ends up in the returned grid.
        grid = grid.to(ddf)
        if jitter:
            # Define reference grid on non-integer values. A dedicated generator keeps the noise
            # reproducible for a given seed without touching the global RNG state.
            generator = torch.Generator(device=grid.device).manual_seed(seed)
            noise = torch.rand(grid.shape, generator=generator, device=grid.device, dtype=grid.dtype)
            grid = grid + noise
        grid.requires_grad = False

        self.ref_grid = grid
        self._ref_grid_key = key
        return grid

    def forward(self, image: torch.Tensor, ddf: torch.Tensor):
        """
        Args:
            image: Tensor in shape (batch, num_channels, H, W[, D])
            ddf: Tensor in the same spatial size as image, in shape (batch, ``spatial_dims``, H, W[, D])

        Returns:
            warped_image in the same shape as image (batch, num_channels, H, W[, D])

        Raises:
            NotImplementedError: when ``image`` does not have 2 or 3 spatial dimensions.
            ValueError: when ``ddf`` is not in shape (batch, ``spatial_dims``, H, W[, D]) matching ``image``.
        """
        spatial_dims = len(image.shape) - 2
        if spatial_dims not in (2, 3):
            raise NotImplementedError(f"got unsupported spatial_dims={spatial_dims}, currently support 2 or 3.")
        ddf_shape = (image.shape[0], spatial_dims) + tuple(image.shape[2:])
        if ddf.shape != ddf_shape:
            raise ValueError(
                f"Given input {spatial_dims}-d image shape {image.shape}, the input DDF shape must be {ddf_shape}, "
                f"Got {ddf.shape} instead."
            )
        # absolute sampling locations in voxel units = identity grid + displacement
        grid = self.get_reference_grid(ddf, jitter=self.jitter) + ddf
        grid = grid.permute([0] + list(range(2, 2 + spatial_dims)) + [1])  # (batch, ..., spatial_dims)

        if not USE_COMPILED:  # pytorch native grid_sample
            # With align_corners=True, voxel index 0 maps to -1 and index (size - 1) maps to +1.
            # The denominator is clamped to 1 so a size-1 axis gives finite coordinates instead of NaN (0 / 0).
            normalized = [grid[..., i] * 2 / max(size - 1, 1) - 1 for i, size in enumerate(grid.shape[1:-1])]
            # grid_sample expects the last axis ordered x, y[, z], i.e. reversed w.r.t. the tensor axes
            grid = torch.stack(normalized[::-1], dim=-1)
            return F.grid_sample(
                image, grid, mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True
            )

        # using csrc resampling
        return grid_pull(image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode)


class DVF2DDF(nn.Module):
    """
    Integrate a dense velocity field (DVF) into a dense displacement field (DDF)
    via scaling and squaring.

    The DVF is first scaled down by ``2**num_steps``. The scaled field is then
    composed with itself ``num_steps`` times: each step warps the current field
    by itself and adds the warped result back.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)

    """

    def __init__(
        self, num_steps: int = 7, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.ZEROS.value
    ):
        """
        Args:
            num_steps: number of squaring iterations; must be positive.
            mode: interpolation mode forwarded to the internal ``Warp`` layer.
            padding_mode: padding mode forwarded to the internal ``Warp`` layer.

        Raises:
            ValueError: if ``num_steps`` is not positive.
        """
        super().__init__()
        if num_steps <= 0:
            raise ValueError(f"expecting positive num_steps, got {num_steps}")
        self.num_steps = num_steps
        self.warp_layer = Warp(padding_mode=padding_mode, mode=mode)

    def forward(self, dvf: torch.Tensor) -> torch.Tensor:
        """
        Args:
            dvf: velocity field to integrate, in shape (batch, ``spatial_dims``, H, W[,D])

        Returns:
            the integrated dense displacement field
        """
        steps = self.num_steps
        displacement: torch.Tensor = dvf / (2**steps)
        for _step in range(steps):
            # compose the current field with itself: u <- u + u(x + u)
            warped = self.warp_layer(image=displacement, ddf=displacement)
            displacement = displacement + warped
        return displacement
