# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transforms using a smooth spatial field generated by interpolating from smaller randomized fields."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

import numpy as np
import torch
from torch.nn.functional import grid_sample, interpolate

from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.networks.utils import meshgrid_ij
from monai.transforms.transform import Randomizable, RandomizableTransform
from monai.transforms.utils_pytorch_numpy_unification import moveaxis
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode
from monai.utils.enums import TransformBackends
from monai.utils.module import look_up_option
from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor

# Public names exported by `from <this module> import *`: the field generator and the three transforms built on it.
__all__ = ["SmoothField", "RandSmoothFieldAdjustContrast", "RandSmoothFieldAdjustIntensity", "RandSmoothDeform"]


class SmoothField(Randomizable):
    """
    Generate a smooth field array by defining a smaller randomized field and then reinterpolating to the desired size.

    This exploits interpolation to create a smoothly varying field used for other applications. An initial randomized
    field is defined with `rand_size` dimensions with `pad` number of values padding it along each dimension using
    `pad_val` as the value. If `spatial_size` is given this is interpolated to that size, otherwise if None the random
    array is produced uninterpolated. The output is always a Pytorch tensor allocated on the specified device.

    When interpolated, the result has shape `(channels, *spatial_size)` and its values are linearly rescaled to span
    the same range as the uninterpolated field. When not interpolated, the result is a copy of the internal field with
    shape `(1, channels, *[rs + 2 * pad for rs in rand_size])`.

    Args:
        rand_size: size of the randomized field to start from
        pad: number of pixels/voxels along the edges of the field to pad with `pad_val`
        pad_val: value with which to pad field edges
        low: low value for randomized field
        high: high value for randomized field
        channels: number of channels of final output
        spatial_size: final output size of the array, None to produce original uninterpolated field
        mode: interpolation mode for resizing the field
        align_corners: if True align the corners when upsampling field
        device: Pytorch device to define field on

    Raises:
        ValueError: if `low` is not less than `high`.
    """

    backend = [TransformBackends.TORCH]

    def __init__(
        self,
        rand_size: Sequence[int],
        pad: int = 0,
        pad_val: float = 0,
        low: float = -1.0,
        high: float = 1.0,
        channels: int = 1,
        spatial_size: Sequence[int] | None = None,
        mode: str = InterpolateMode.AREA,
        align_corners: bool | None = None,
        device: torch.device | None = None,
    ):
        self.rand_size = tuple(rand_size)
        self.pad = pad
        self.low = low
        self.high = high
        self.channels = channels
        self.mode = mode
        self.align_corners = align_corners
        self.device = device

        self.spatial_size: Sequence[int] | None = None
        self.spatial_zoom: Sequence[float] | None = None

        if low >= high:
            raise ValueError("Value for `low` must be less than `high` otherwise field will be zeros")

        # size of the random region plus `pad` voxels on both sides of every dimension
        self.total_rand_size = tuple(rs + self.pad * 2 for rs in self.rand_size)

        # the full field starts as `pad_val` everywhere; only the interior is overwritten by `randomize`
        self.field = torch.ones((1, self.channels) + self.total_rand_size, device=self.device) * pad_val

        self.crand_size = (self.channels,) + self.rand_size

        # slices selecting the un-padded interior of the field (batch index 0, all channels)
        pad_slice = slice(None) if self.pad == 0 else slice(self.pad, -self.pad)
        self.rand_slices = (0, slice(None)) + (pad_slice,) * len(self.rand_size)

        self.set_spatial_size(spatial_size)

    def randomize(self, data: Any | None = None) -> None:
        """Fill the un-padded interior of the field with uniform random values in `[low, high)`."""
        self.field[self.rand_slices] = torch.from_numpy(self.R.uniform(self.low, self.high, self.crand_size))  # type: ignore[index]

    def set_spatial_size(self, spatial_size: Sequence[int] | None) -> None:
        """
        Set the `spatial_size` and `spatial_zoom` attributes used for interpolating the field to the given
        dimension, or not interpolate at all if None.

        Args:
            spatial_size: new size to interpolate to, or None to not interpolate
        """
        if spatial_size is None:
            self.spatial_size = None
            self.spatial_zoom = None
        else:
            self.spatial_size = tuple(spatial_size)
            self.spatial_zoom = tuple(s / f for s, f in zip(self.spatial_size, self.total_rand_size))

    def set_mode(self, mode: str) -> None:
        """Set the interpolation mode used when resizing the field."""
        self.mode = mode

    def __call__(self, randomize: bool = False) -> torch.Tensor:
        """
        Produce the smooth field, optionally re-randomizing it first.

        Args:
            randomize: if True, call `randomize()` before producing the field.

        Returns:
            when a spatial size is set, the field interpolated to shape `(channels, *spatial_size)` and rescaled to
            the value range of the uninterpolated field; otherwise a copy of the uninterpolated field with shape
            `(1, channels, ...)`.
        """
        if randomize:
            self.randomize()

        field = self.field.clone()

        if self.spatial_zoom is None:
            return field

        # Request the target size directly. Passing `scale_factor` makes torch floor `in_size * zoom`, which can fall
        # one voxel short when the float ratio rounds down (e.g. 49 -> 64 gives 49 * (64 / 49) == 63.99999999999999).
        resized_field = interpolate(
            input=field,
            size=self.spatial_size,
            mode=look_up_option(self.mode, InterpolateMode),
            align_corners=self.align_corners,
        ).squeeze(0)

        mina = resized_field.min()
        maxa = resized_field.max()
        minv = self.field.min()
        maxv = self.field.max()

        if maxa == mina:
            # constant field (e.g. never randomized): there is no range to rescale and dividing would produce NaN
            return resized_field

        # faster than rescale_array, this uses in-place operations and doesn't perform unneeded range checks
        return (resized_field - mina).div_(maxa - mina).mul_(maxv - minv).add_(minv)


class RandSmoothFieldAdjustContrast(RandomizableTransform):
    """
    Randomly adjust the contrast of input images by calculating a randomized smooth field for each invocation.

    This uses SmoothField internally to define the adjustment over the image. If `pad` is greater than 0 the
    edges of the input volume of that width will be mostly unchanged. Contrast is changed by raising input
    values by the power of the smooth field so the range of values given by `gamma` should be chosen with this
    in mind. For example, a minimum value of 0 in `gamma` will produce white areas so this should be avoided.
    After the contrast is adjusted the values of the result are rescaled to the range of the original input.

    Args:
        spatial_size: size of input array's spatial dimensions
        rand_size: size of the randomized field to start from
        pad: number of pixels/voxels along the edges of the field to pad with 1
        mode: interpolation mode to use when upsampling
        align_corners: if True align the corners when upsampling field
        prob: probability transform is applied
        gamma: (min, max) range for exponential field, or a single number in which case the range spans
            between 0.5 and that number
        device: Pytorch device to define field on

    Raises:
        ValueError: if `gamma` is a sequence whose length is not 2, or if the resulting range is empty
            (e.g. a single `gamma` equal to 0.5).
    """

    backend = [TransformBackends.TORCH]

    def __init__(
        self,
        spatial_size: Sequence[int],
        rand_size: Sequence[int],
        pad: int = 0,
        mode: str = InterpolateMode.AREA,
        align_corners: bool | None = None,
        prob: float = 0.1,
        gamma: Sequence[float] | float = (0.5, 4.5),
        device: torch.device | None = None,
    ):
        super().__init__(prob)

        if isinstance(gamma, (int, float)):
            # a single number is paired with 0.5; order the pair so values below 0.5 are accepted too
            self.gamma = (min(0.5, gamma), max(0.5, gamma))
        else:
            if len(gamma) != 2:
                raise ValueError("Argument `gamma` should be a number or pair of numbers.")

            self.gamma = (min(gamma), max(gamma))

        # padding with 1 leaves edge values unchanged since x ** 1 == x
        self.sfield = SmoothField(
            rand_size=rand_size,
            pad=pad,
            pad_val=1,
            low=self.gamma[0],
            high=self.gamma[1],
            channels=1,
            spatial_size=spatial_size,
            mode=mode,
            align_corners=align_corners,
            device=device,
        )

    def set_random_state(
        self, seed: int | None = None, state: np.random.RandomState | None = None
    ) -> RandSmoothFieldAdjustContrast:
        """Seed both this transform and its internal smooth field generator."""
        super().set_random_state(seed, state)
        self.sfield.set_random_state(seed, state)
        return self

    def randomize(self, data: Any | None = None) -> None:
        """Decide whether to apply the transform and, if so, draw a new smooth field."""
        super().randomize(None)

        if self._do_transform:
            self.sfield.randomize()

    def set_mode(self, mode: str) -> None:
        """Set the interpolation mode used to upsample the field."""
        self.sfield.set_mode(mode)

    def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
        """
        Apply the transform to `img`, if `randomize` randomizing the smooth field otherwise reusing the previous.
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        if randomize:
            self.randomize()

        if not self._do_transform:
            return img

        img_min = img.min()
        img_max = img.max()
        img_rng = img_max - img_min

        field = self.sfield()
        # NOTE(review): rfield presumably takes img's dtype here; for integer images that would truncate the
        # fractional exponents of the field — confirm whether integer inputs need to be supported.
        rfield, *_ = convert_to_dst_type(field, img)

        # everything below here is to be computed using the destination type (numpy, tensor, etc.)

        img = (img - img_min) / (img_rng + 1e-10)  # rescale to unit values
        img = img**rfield  # contrast is changed by raising image data to a power, in this case the field

        out = (img * img_rng) + img_min  # rescale back to the original image value range

        return out


class RandSmoothFieldAdjustIntensity(RandomizableTransform):
    """
    Randomly adjust the intensity of input images by calculating a randomized smooth field for each invocation.

    This uses SmoothField internally to define the adjustment over the image. If `pad` is greater than 0 the
    edges of the input volume of that width will be mostly unchanged. Intensity is changed by multiplying the
    inputs by the smooth field, so the values of `gamma` should be chosen with this in mind. The default values
    of `(0.1, 1.0)` are sensible in that values will not be zeroed out by the field nor multiplied greater than
    the original value range.

    Args:
        spatial_size: size of input array's spatial dimensions
        rand_size: size of the randomized field to start from
        pad: number of pixels/voxels along the edges of the field to pad with 1
        mode: interpolation mode to use when upsampling
        align_corners: if True align the corners when upsampling field
        prob: probability transform is applied
        gamma: (min, max) range of intensity multipliers, or a single number in which case the range spans
            between 0.5 and that number
        device: Pytorch device to define field on

    Raises:
        ValueError: if `gamma` is a sequence whose length is not 2, or if the resulting range is empty
            (e.g. a single `gamma` equal to 0.5).
    """

    backend = [TransformBackends.TORCH]

    def __init__(
        self,
        spatial_size: Sequence[int],
        rand_size: Sequence[int],
        pad: int = 0,
        mode: str = InterpolateMode.AREA,
        align_corners: bool | None = None,
        prob: float = 0.1,
        gamma: Sequence[float] | float = (0.1, 1.0),
        device: torch.device | None = None,
    ):
        super().__init__(prob)

        if isinstance(gamma, (int, float)):
            # a single number is paired with 0.5; order the pair so values below 0.5 are accepted too
            self.gamma = (min(0.5, gamma), max(0.5, gamma))
        else:
            if len(gamma) != 2:
                raise ValueError("Argument `gamma` should be a number or pair of numbers.")

            self.gamma = (min(gamma), max(gamma))

        # padding with 1 leaves edge intensities unchanged under multiplication
        self.sfield = SmoothField(
            rand_size=rand_size,
            pad=pad,
            pad_val=1,
            low=self.gamma[0],
            high=self.gamma[1],
            channels=1,
            spatial_size=spatial_size,
            mode=mode,
            align_corners=align_corners,
            device=device,
        )

    def set_random_state(
        self, seed: int | None = None, state: np.random.RandomState | None = None
    ) -> RandSmoothFieldAdjustIntensity:
        """Seed both this transform and its internal smooth field generator."""
        super().set_random_state(seed, state)
        self.sfield.set_random_state(seed, state)
        return self

    def randomize(self, data: Any | None = None) -> None:
        """Decide whether to apply the transform and, if so, draw a new smooth field."""
        super().randomize(None)

        if self._do_transform:
            self.sfield.randomize()

    def set_mode(self, mode: str) -> None:
        """Set the interpolation mode used to upsample the field."""
        self.sfield.set_mode(mode)

    def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
        """
        Apply the transform to `img`, if `randomize` randomizing the smooth field otherwise reusing the previous.
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())

        if randomize:
            self.randomize()

        if not self._do_transform:
            return img

        field = self.sfield()
        # NOTE(review): rfield presumably takes img's dtype here; for integer images multipliers below 1 would be
        # truncated — confirm whether integer inputs need to be supported.
        rfield, *_ = convert_to_dst_type(field, img)

        # everything below here is to be computed using the destination type (numpy, tensor, etc.)

        out = img * rfield

        return out


class RandSmoothDeform(RandomizableTransform):
    """
    Deform an image using a random smooth field and Pytorch's grid_sample.

    The amount of deformation is given by `def_range` in fractions of the size of the image. The size of each dimension
    of the input image is always defined as 2 regardless of actual image voxel dimensions, that is the coordinates in
    every dimension range from -1 to 1. A value of 0.1 means pixels/voxels can be moved by up to 5% of the image's size.

    Args:
        spatial_size: input array size to which deformation grid is interpolated
        rand_size: size of the randomized field to start from
        pad: number of pixels/voxels along the edges of the field to pad with 0
        field_mode: interpolation mode to use when upsampling the deformation field
        align_corners: if True align the corners when upsampling field
        prob: probability transform is applied
        def_range: deformation range in image size fractions, either a single value `v` giving the symmetric range
            `(-|v|, |v|)` or a (min, max) pair
        grid_dtype: type for the deformation grid calculated from the field; the image is sampled in this type
        grid_mode: interpolation mode used for sampling input using deformation grid
        grid_padding_mode: padding mode used for sampling input using deformation grid
        grid_align_corners: if True align the corners when sampling the deformation grid
        device: Pytorch device to define field on

    Raises:
        ValueError: if `def_range` is a sequence whose length is not 2, or if the resulting range is empty.
    """

    backend = [TransformBackends.TORCH]

    def __init__(
        self,
        spatial_size: Sequence[int],
        rand_size: Sequence[int],
        pad: int = 0,
        field_mode: str = InterpolateMode.AREA,
        align_corners: bool | None = None,
        prob: float = 0.1,
        def_range: Sequence[float] | float = 1.0,
        grid_dtype: torch.dtype = torch.float32,
        grid_mode: str = GridSampleMode.NEAREST,
        grid_padding_mode: str = GridSamplePadMode.BORDER,
        grid_align_corners: bool | None = False,
        device: torch.device | None = None,
    ):
        super().__init__(prob)

        self.grid_dtype = grid_dtype
        self.grid_mode = grid_mode
        self.device = device
        self.grid_align_corners = grid_align_corners
        self.grid_padding_mode = grid_padding_mode

        if isinstance(def_range, (int, float)):
            # symmetric range; abs() keeps low <= high when a negative magnitude is given
            self.def_range = (-abs(def_range), abs(def_range))
        else:
            if len(def_range) != 2:
                raise ValueError("Argument `def_range` should be a number or pair of numbers.")

            self.def_range = (min(def_range), max(def_range))

        # one field channel per spatial dimension: each channel offsets the coordinate along that dimension
        self.sfield = SmoothField(
            spatial_size=spatial_size,
            rand_size=rand_size,
            pad=pad,
            low=self.def_range[0],
            high=self.def_range[1],
            channels=len(rand_size),
            mode=field_mode,
            align_corners=align_corners,
            device=device,
        )

        # identity sampling grid in normalized [-1, 1] coordinates, shape (1, ndim, *grid_space)
        # NOTE(review): linspace(-1, 1, d) hits pixel centres only when `grid_align_corners` is True; with the
        # default False the zero-deformation grid slightly rescales the image — confirm this is intended.
        grid_space = tuple(spatial_size) if spatial_size is not None else self.sfield.field.shape[2:]
        grid_ranges = [torch.linspace(-1, 1, d) for d in grid_space]

        grid = meshgrid_ij(*grid_ranges)

        self.grid = torch.stack(grid).unsqueeze(0).to(self.device, self.grid_dtype)

    def set_random_state(
        self, seed: int | None = None, state: np.random.RandomState | None = None
    ) -> RandSmoothDeform:
        """Seed both this transform and its internal smooth field generator."""
        super().set_random_state(seed, state)
        self.sfield.set_random_state(seed, state)
        return self

    def randomize(self, data: Any | None = None) -> None:
        """Decide whether to apply the transform and, if so, draw a new deformation field."""
        super().randomize(None)

        if self._do_transform:
            self.sfield.randomize()

    def set_field_mode(self, mode: str) -> None:
        """Set the interpolation mode used to upsample the deformation field."""
        self.sfield.set_mode(mode)

    def set_grid_mode(self, mode: str) -> None:
        """Set the interpolation mode used by `grid_sample` when resampling the image."""
        self.grid_mode = mode

    def __call__(
        self, img: NdarrayOrTensor, randomize: bool = True, device: torch.device | None = None
    ) -> NdarrayOrTensor:
        """
        Deform `img`, if `randomize` randomizing the deformation field otherwise reusing the previous.

        Args:
            img: channel-first image to deform
            randomize: if True draw a new field (and new apply/skip decision) before transforming
            device: device on which to perform the sampling, defaults to the device given at construction

        Returns:
            the deformed image converted back to the type and device of `img`, or `img` itself if the transform
            is not applied.
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        if randomize:
            self.randomize()

        if not self._do_transform:
            return img

        device = device if device is not None else self.device

        field = self.sfield()

        dgrid = self.grid + field.to(self.grid_dtype)
        dgrid = moveaxis(dgrid, 1, -1)  # type: ignore
        # grid_sample orders the last-dim coordinates (x, y[, z]), the reverse of the tensor's spatial axis order
        dgrid = dgrid.flip(-1)

        # grid_sample requires input and grid to share dtype and device: sample in the grid dtype and move the
        # grid (built on `self.device`) to wherever the image ended up
        img_t = convert_to_tensor(img[None], self.grid_dtype, device)
        dgrid = dgrid.to(img_t.device)

        out = grid_sample(
            input=img_t,
            grid=dgrid,
            mode=look_up_option(self.grid_mode, GridSampleMode),
            align_corners=self.grid_align_corners,
            padding_mode=look_up_option(self.grid_padding_mode, GridSamplePadMode),
        )

        out_t, *_ = convert_to_dst_type(out.squeeze(0), img)

        return out_t
