U
    PhG                      @  s   d dl mZ d dlZd dlmZ dddddddZdd	d	dd
ddZdd	ddddZdd	ddddZddddddddZddddddddZ	dS )    )annotationsN)Tensorr   int)xshift	shift_dimreturnc                 C  s\   ||  | }|dkr| S | |d|  || }| ||  || |}tj||f|dS )a{  
    Similar to roll but for only one dim.

    Args:
        x: input data (k-space or image) that can be
            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or
            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.
        shift: the amount of shift along each of shift_dims dimension
        shift_dim: the dimension over which the shift is applied

    Returns:
        1d-shifted version of x

    Note:
        This function is called when fftshift and ifftshift are not available in the running pytorch version
    r   )dim)sizenarrowtorchcat)r   r   r   leftright r   V/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/blocks/fft_utils_t.pyroll_1d   s    r   z	list[int])r   r   
shift_dimsr   c                 C  sR   t |t |kr.tdt | dt | dt||D ]\}}t| ||} q8| S )a  
    Similar to np.roll but applies to PyTorch Tensors

    Args:
        x: input data (k-space or image) that can be
            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or
            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.
        shift: the amount of shift along each of shift_dims dimensions
        shift_dims: dimensions over which the shift is applied

    Returns:
        shifted version of x

    Note:
        This function is called when fftshift and ifftshift are not available in the running pytorch version
    z$len(shift) != len(shift_dims), got fz and f.)len
ValueErrorzipr   )r   r   r   sdr   r   r   roll-   s
    r   )r   r   r   c                 C  s>   dgt | }t|D ]\}}| j| d ||< qt| ||S )a9  
    Similar to np.fft.fftshift but applies to PyTorch Tensors

    Args:
        x: input data (k-space or image) that can be
            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or
            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.
        shift_dims: dimensions over which the shift is applied

    Returns:
        fft-shifted version of x

    Note:
        This function is called when fftshift is not available in the running pytorch version
    r      r   	enumerateshaper   r   r   r   iZdim_numr   r   r   fftshiftE   s    r!   c                 C  sB   dgt | }t|D ]\}}| j| d d ||< qt| ||S )a<  
    Similar to np.fft.ifftshift but applies to PyTorch Tensors

    Args:
        x: input data (k-space or image) that can be
            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or
            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.
        shift_dims: dimensions over which the shift is applied

    Returns:
        ifft-shifted version of x

    Note:
        This function is called when ifftshift is not available in the running pytorch version
    r      r   r   r   r   r   r   	ifftshift[   s    r#   Tbool)kspspatial_dims
is_complexr   c                 C  s   t t| d}|rL| jd dkr8td| jd  dt t| d d}t t| d}t| |}|rttjjt	||dd}nttjj||dd}t
||}|S )	aC  
    Pytorch-based ifft for spatial_dims-dim signals. "centered" means this function automatically takes care
    of the required ifft and fft shifts.
    This is equivalent to do fft in numpy based on numpy.fft.ifftn, numpy.fft.fftshift, and numpy.fft.ifftshift

    Args:
        ksp: k-space data that can be
            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or
            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.
        spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume)
        is_complex: if True, then the last dimension of the input ksp is expected to be 2 (representing real and imaginary channels)

    Returns:
        "out" which is the output image (inverse fourier of ksp)

    Example:

        .. code-block:: python

            import torch
            ksp = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts
            # output1 and output2 will be identical
            output1 = torch.fft.ifftn(torch.view_as_complex(torch.fft.ifftshift(ksp,dim=(-3,-2))), dim=(-2,-1), norm="ortho")
            output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) )

            output2 = ifftn_centered(ksp, spatial_dims=2, is_complex=True)
    r   r   zksp.shape[-1] is not 2 ().r"   orthor	   norm)listranger   r   r#   r   view_as_realfftifftnview_as_complexr!   )r%   r&   r'   r   dimsr   outr   r   r   ifftn_centered_tq   s    
 
r5   )imr&   r'   r   c                 C  s   t t| d}|rL| jd dkr8td| jd  dt t| d d}t t| d}t| |}|rttjjt	||dd}nttjj||dd}t
||}|S )	a,  
    Pytorch-based fft for spatial_dims-dim signals. "centered" means this function automatically takes care
    of the required ifft and fft shifts.
    This is equivalent to do ifft in numpy based on numpy.fft.fftn, numpy.fft.fftshift, and numpy.fft.ifftshift

    Args:
        im: image that can be
            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or
            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.
        spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume)
        is_complex: if True, then the last dimension of the input im is expected to be 2 (representing real and imaginary channels)

    Returns:
        "out" which is the output kspace (fourier of im)

    Example:

        .. code-block:: python

            import torch
            im = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts
            # output1 and output2 will be identical
            output1 = torch.fft.fftn(torch.view_as_complex(torch.fft.ifftshift(im,dim=(-3,-2))), dim=(-2,-1), norm="ortho")
            output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) )

            output2 = fftn_centered(im, spatial_dims=2, is_complex=True)
    r   r(   r   zimg.shape[-1] is not 2 (r)   r"   r*   r+   )r-   r.   r   r   r#   r   r/   r0   fftnr2   r!   )r6   r&   r'   r   r3   r   r4   r   r   r   fftn_centered_t   s    
 
r8   )T)T)

__future__r   r   r   r   r   r!   r#   r5   r8   r   r   r   r   <module>   s   0