U
    PhG                     @  s   d dl mZ d dlmZ d dlZd dlmZ d dlmZ d dlm	Z	 d dl
mZ d dlmZmZ d dlmZ G d	d
 d
ejZdS )    )annotations)SequenceN)Tensor)root_sum_of_squares_t)ComplexUnet)$reshape_batch_channel_to_channel_dimreshape_channel_to_batch_dim)ifftn_centered_tc                      s~   e Zd ZdZddddddfdd	difdd
dddf	dddddddddd	 fddZdddddZddddddZ  ZS ) CoilSensitivityModela|  
    This class uses a convolutional model to learn coil sensitivity maps for multi-coil MRI reconstruction.
    The convolutional model is :py:class:`monai.apps.reconstruction.networks.nets.complex_unet` by default
    but can be specified by the user as well. Learning is done on the center of the under-sampled
    kspace (that region is fully sampled).

    The data being a (complex) 2-channel tensor is a requirement for using this model.

    Modified and adopted from: https://github.com/facebookresearch/fastMRI

    Args:
        spatial_dims: number of spatial dimensions.
        features: six integers as numbers of features. denotes number of channels in each layer.
        act: activation type and arguments. Defaults to LeakyReLU.
        norm: feature normalization type and arguments. Defaults to instance norm.
        bias: whether to have a bias term in convolution blocks. Defaults to True.
        dropout: dropout ratio. Defaults to 0.0.
        upsample: upsampling mode, available options are
            ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
        coil_dim: coil dimension in the data
        conv_net: the learning model used to estimate the coil sensitivity maps. default
            is :py:class:`monai.apps.reconstruction.networks.nets.complex_unet`. The only
            requirement on the model is to have 2 as input and output number of channels.
       )    r   @         r   	LeakyReLUg?T)negative_slopeinplaceinstanceaffineg        deconv   NintzSequence[int]zstr | tupleboolzfloat | tuplestrznn.Module | None)	spatial_dimsfeaturesactnormbiasdropoutupsamplecoil_dimconv_netc
              	     s|   t    |	d kr,t|||||||d| _n@dd |	 D }
|
d d dkrftd|
d d  d|	| _|| _|| _d S )	N)r   r   r   r   r   r   r    c                 S  s   g | ]
}|j qS  shape).0pr#   r#   s/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/apps/reconstruction/networks/nets/coil_sensitivity_model.py
<listcomp>Q   s     z1CoilSensitivityModel.__init__.<locals>.<listcomp>r   r   r   z!in_channels should be 2 but it's .)super__init__r   r"   
parameters
ValueErrorr   r!   )selfr   r   r   r   r   r   r    r!   r"   params	__class__r#   r(   r,   7   s"    

zCoilSensitivityModel.__init__r   ztuple[int, int])maskreturnc                 C  sV   |j d d  }}|d|ddf r.|d7 }q|d|ddf rJ|d8 }q.|d |fS )a  
        Extracts the size of the fully-sampled part of the kspace. Note that when a kspace
        is under-sampled, a part of its center is fully sampled. This part is called the Auto
        Calibration Region (ACR). ACR is used for sensitivity map computation.

        Args:
            mask: the under-sampling mask of shape (..., S, 1) where S denotes the sampling dimension

        Returns:
            A tuple containing
                (1) left index of the region
                (2) right index of the region

        Note:
            Suppose the mask is of shape (1,1,20,1). If this function returns 8,12 as left and right
                indices, then it means that the fully-sampled center region has size 4 starting from 8 to 12.
        r   .Nr   r$   )r/   r3   leftrightr#   r#   r(   get_fully_sampled_regionX   s    

z-CoilSensitivityModel.get_fully_sampled_region)masked_kspacer3   r4   c           	      C  s   |  |\}}|| }t|}|jd | d d }|d||| ddf |d||| ddf< t|| jdd}t|\}}| |}t||}|t	|| j
d| j
 }|S )	a  
        Args:
            masked_kspace: the under-sampled kspace (which is the input measurement). Its shape
                is (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data.
            mask: the under-sampling mask with shape (1,1,1,W,1) for 2D data or (1,1,1,1,D,1) for 3D data.

        Returns:
            predicted coil sensitivity maps with shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data.
        r5   r   r   .NT)r   
is_complex)spatial_dim)r8   torch
zeros_liker%   r	   r   r   r"   r   r   r!   	unsqueeze)	r/   r9   r3   r6   r7   Znum_low_freqsxstartbr#   r#   r(   forwards   s    

0

zCoilSensitivityModel.forward)__name__
__module____qualname____doc__r,   r8   rB   __classcell__r#   r#   r1   r(   r
      s   
$!r
   )
__future__r   collections.abcr   r<   torch.nnnnr   Z#monai.apps.reconstruction.mri_utilsr   Z4monai.apps.reconstruction.networks.nets.complex_unetr   -monai.apps.reconstruction.networks.nets.utilsr   r   !monai.networks.blocks.fft_utils_tr	   Moduler
   r#   r#   r#   r(   <module>   s   