U
    Phq,                     @  s  d Z ddlmZ ddlZddlmZ ddlmZ ddl	m
Z
mZ ddlmZmZ ddd	d
dZddd	ddZddd	ddZddddddZddd	ddZd-ddddddZdddddd Zd!d"d#d$d%Zd.ddddd'd(d)Zd/ddddd*d+d,ZdS )0zW
This script contains utility functions for developing new networks/blocks in PyTorch.
    )annotationsN)Tensor)
functional)complex_conj_tcomplex_mul_t)fftn_centered_tifftn_centered_tr   )xreturnc                 C  s   | j d dkr$td| j d  dt| j dkrh| j \}}}}}| ddddd	 |d| ||S t| j d
kr| j \}}}}}}| ddddd	d |d| |||S td| j  dS )aC  
    Swaps the complex dimension with the channel dimension so that the network treats real/imaginary
    parts as two separate channels.

    Args:
        x: input of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data

    Returns:
        output of shape (B,C*2,H,W) for 2D data or (B,C*2,H,W,D) for 3D data
       z'last dim must be 2, but x.shape[-1] is .   r               Lonly 2D (B,C,H,W,2) and 3D (B,C,H,W,D,2) data are supported but x has shape N)shape
ValueErrorlenpermute
contiguousviewr	   bchwtwod r!   b/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/apps/reconstruction/networks/nets/utils.pyreshape_complex_to_channel_dim   s    &*r#   c                 C  s   | j d d dkr(td| j d  dt| j dkrl| j \}}}}|d }| |d|||dddddS t| j dkr| j \}}}}}|d }| |d||||ddddddS td	| j  d
S )a,  
    Swaps the complex dimension with the channel dimension so that the network output has 2 as its last dimension

    Args:
        x: input of shape (B,C*2,H,W) for 2D data or (B,C*2,H,W,D) for 3D data

    Returns:
        output of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data
    r   r   r   z&channel dimension should be even but (z	) is odd.r   r   r   zLonly 2D (B,C*2,H,W) and 3D (B,C*2,H,W,D) data are supported but x has shape N)r   r   r   r   r   )r	   r   c2r   r   r   r    r!   r!   r"   #reshape_channel_complex_to_last_dim4   s    
 $r%   ztuple[Tensor, int]c                 C  s   t | jdkr<| j\}}}}}|  || d||||fS t | jdkr|| j\}}}}}}|  || d|||||fS td| j dS )z
    Combines batch and channel dimensions.

    Args:
        x: input of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data

    Returns:
        A tuple containing:
            (1) output of shape (B*C,1,...)
            (2) batch size
    r   r   r   r   N)r   r   r   r   r   r   r!   r!   r"   reshape_channel_to_batch_dimO   s     r&   int)r	   
batch_sizer
   c           	      C  s   t | jdkr8| j\}}}}}|| }| |||||S t | jdkrt| j\}}}}}}|| }| ||||||S td| j dS )z
    Detaches batch and channel dimensions.

    Args:
        x: input of shape (B*C,1,H,W,2) for 2D data or (B*C,1,H,W,D,2) for 3D data
        batch_size: batch size

    Returns:
        output of shape (B,C,...)
    r   r   zPonly 2D (B*C,1,H,W,2) and 3D (B*C,1,H,W,D,2) data are supported but x has shape N)r   r   r   r   )	r	   r(   bconer   r   r   r   r    r!   r!   r"   $reshape_batch_channel_to_channel_dimh   s    r+   ztuple[Tensor, Tensor, Tensor]c                 C  s  t | jdkr| j\}}}}|  |d|d | | } | jdd|dddd|d|d dd ||dd}| jddd|dddd|d|d dd ||dd}| ||||} | | | ||fS t | jdkr| j\}}}}}|  |d|d | | | } | jdd|ddddd|d|d ddd ||ddd}| jddd|ddddd|d|d ddd ||ddd}| |||||} | | | ||fS td| j d	S )
a  
    Performs layer mean-std normalization for complex data. Normalization is done for each batch member
    along each part (part refers to real and imaginary parts), separately.

    Args:
        x: input of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data

    Returns:
        A tuple containing
            (1) normalized output of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data
            (2) mean
            (3) std
    r   r   )dimr   F)r,   unbiasedr   Honly 2D (B,C,H,W) and 3D (B,C,H,W,D) data are supported but x has shape N)r   r   r   r   meanexpandstdr   )r	   r   r   r   r   r/   r1   r    r!   r!   r"   complex_normalize   sH    <>"B              r2      zVtuple[Tensor, tuple[tuple[int, int], tuple[int, int], tuple[int, int], int, int, int]])r	   kr
   c                 C  sX  t | jdkr| j\}}}}|d |d B d }|d |d B d }t|| d }t|| d }	t| ||	 } d}
d}|	|||||
f}nt | jdkr@| j\}}}}}|d |d B d }|d |d B d }|d |d B d }
t|| d }t|| d }	t|
| d }t| || |	 } |	|||||
f}ntd| j | |fS )ah  
    Pad input to feed into the network (torch script compatible)

    Args:
        x: input of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data
        k: padding factor. each padded dimension will be divisible by k.

    Returns:
        A tuple containing
            (1) padded input
            (2) pad sizes (in order to reverse padding if needed)

    Example:
        .. code-block:: python

            import torch

            # 2D data
            x = torch.ones([3,2,50,70])
            x_pad,padding_sizes = divisible_pad_t(x, k=16)
            # the following line should print (3, 2, 64, 80)
            print(x_pad.shape)

            # 3D data
            x = torch.ones([3,2,50,70,80])
            x_pad,padding_sizes = divisible_pad_t(x, k=16)
            # the following line should print (3, 2, 64, 80, 80)
            print(x_pad.shape)

    r   r   r   r   )r   r   r   r.   )r   r   
floor_ceilFpadr   )r	   r4   r   r   r   r   w_multh_multw_padh_padd_multd_pad	pad_sizesr    r!   r!   r"   divisible_pad_t   s,    !r?   zGtuple[tuple[int, int], tuple[int, int], tuple[int, int], int, int, int])r	   r>   r
   c                 C  s   |\}}}}}}t | jdkrL| d|d ||d  |d ||d  f S t | jdkr| d|d ||d  |d ||d  |d ||d  f S td| j dS )z
    De-pad network output to match its original shape

    Args:
        x: input of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data
        pad_sizes: padding values

    Returns:
        de-padded input
    r   .r   r   r   r.   N)r   r   r   )r	   r>   r;   r:   r=   r9   r8   r<   r!   r!   r"   inverse_divisible_pad_t   s    .@r@   floatztuple[int, int])nr
   c                 C  s   t | t | fS )z
    Returns floor and ceil of the input

    Args:
        n: input number

    Returns:
        A tuple containing:
            (1) floor(n)
            (2) ceil(n)
    )mathfloorceil)rB   r!   r!   r"   r5      s    r5   r   )kspace	sens_mapsspatial_dimsr
   c                 C  s&   t | |dd}t|t|jdddS )a  
    Reduces coil measurements to a corresponding image based on the given sens_maps. Let's say there
    are C coil measurements inside kspace, then this function multiplies the conjugate of each coil sensitivity map with the
    corresponding coil image. The result of this process will be C images. Summing those images together gives the
    resulting "reduced image."

    Args:
        kspace: 2D kspace (B,C,H,W,2) with the last dimension being 2 (for real/imaginary parts) and C denoting the
            coil dimension. 3D data will have the shape (B,C,H,W,D,2).
        sens_maps: sensitivity maps of the same shape as input x.
        spatial_dims: is 2 for 2D data and is 3 for 3D data

    Returns:
        reduction of x to (B,1,H,W,2) for 2D data or (B,1,H,W,D,2) for 3D data.
    TrH   
is_complexr   )r,   keepdim)r   r   r   sum)rF   rG   rH   imgr!   r!   r"   sensitivity_map_reduce  s    rN   )rM   rG   rH   r
   c                 C  s   t t| ||ddS )an  
    Expands an image to its corresponding coil images based on the given sens_maps. Let's say there
    are C coils. This function multiples image img with each coil sensitivity map in sens_maps and stacks
    the resulting C coil images along the channel dimension which is reserved for coils.

    Args:
        img: 2D image (B,1,H,W,2) with the last dimension being 2 (for real/imaginary parts). 3D data will have
            the shape (B,1,H,W,D,2).
        sens_maps: Sensitivity maps for combining coil images. The shape is (B,C,H,W,2) for 2D data
            or (B,C,H,W,D,2) for 3D data (C denotes the coil dimension).
        spatial_dims: is 2 for 2D data and is 3 for 3D data

    Returns:
        Expansion of x to (B,C,H,W,2) for 2D data and (B,C,H,W,D,2) for 3D data. The output is transferred
            to the frequency domain to yield coil measurements.
    TrI   )r   r   )rM   rG   rH   r!   r!   r"   sensitivity_map_expand#  s    rO   )r3   )r   )r   )__doc__
__future__r   rC   torchr   torch.nnr   r6   Z'monai.apps.reconstruction.complex_utilsr   r   Z!monai.networks.blocks.fft_utils_tr   r   r#   r%   r&   r+   r2   r?   r@   r5   rN   rO   r!   r!   r!   r"   <module>   s"   ) >