o
    iW                     @  sR   d dl mZ d dlZd dlmZ d dlmZ d dlmZmZ G dd dej	Z
dS )    )annotationsN)Tensor)sensitivity_map_expandsensitivity_map_reducec                      s8   e Zd ZdZdd fddZdddZdddZ  ZS )VarNetBlockaQ  
    A variational block based on Sriram et. al., "End-to-end variational networks for accelerated MRI reconstruction".
    It applies data consistency and refinement to the intermediate kspace and combines those results.

    Modified and adopted from: https://github.com/facebookresearch/fastMRI

    Args:
        refinement_model: the model used for refinement (typically a U-Net but can be any deep learning model
            that performs well when the input and output are in image domain (e.g., a convolutional network).
        spatial_dims: is 2 for 2D data and is 3 for 3D data
       refinement_model	nn.Modulespatial_dimsintc                   sT   t    || _|| _ttd| _dd t	|d D }| 
dt| d S )N   c                 S  s   g | ]}d qS )r    ).0_r   r   w/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/reconstruction/networks/blocks/varnetblock.py
<listcomp>(   s    z(VarNetBlock.__init__.<locals>.<listcomp>   zeros)super__init__modelr
   nn	Parametertorchones	dc_weightrangeregister_bufferr   )selfr   r
   buffer_shape	__class__r   r   r   "   s   
zVarNetBlock.__init__xr   
ref_kspacemaskreturnc                 C  s   t ||| | j| j S )a!  
        Applies data consistency to input x. Suppose x is an intermediate estimate of the kspace and ref_kspace
        is the reference under-sampled measurement. This function returns mask * (x - ref_kspace). View this as the
        residual between the original under-sampled kspace and the estimate given by the network.

        Args:
            x: 2D kspace (B,C,H,W,2) with the last dimension being 2 (for real/imaginary parts) and C denoting the
                coil dimension. 3D data will have the shape (B,C,H,W,D,2).
            ref_kspace: original under-sampled kspace with the same shape as x.
            mask: the under-sampling mask with shape (1,1,1,W,1) for 2D data or (1,1,1,1,D,1) for 3D data.

        Returns:
            Output of DC block with the same shape as x
        )r   wherer   r   )r   r"   r#   r$   r   r   r   soft_dc+   s   zVarNetBlock.soft_dccurrent_kspace	sens_mapsc                 C  s@   |  |||}t| t||| jd|| jd}|| | }|S )a  
        Args:
            current_kspace: Predicted kspace from the previous block. It's a 2D kspace (B,C,H,W,2)
                with the last dimension being 2 (for real/imaginary parts) and C denoting the
                coil dimension. 3D data will have the shape (B,C,H,W,D,2).
            ref_kspace: reference kspace for applying data consistency (is the under-sampled kspace in MRI reconstruction).
                Its shape is the same as current_kspace.
            mask: the under-sampling mask with shape (1,1,1,W,1) for 2D data or (1,1,1,1,D,1) for 3D data.
            sens_maps: coil sensitivity maps with the same shape as current_kspace

        Returns:
            Output of VarNetBlock with the same shape as current_kspace
        )r
   )r'   r   r   r   r
   )r   r(   r#   r$   r)   Zdc_outZrefinement_outoutputr   r   r   forward<   s   zVarNetBlock.forward)r   )r   r	   r
   r   )r"   r   r#   r   r$   r   r%   r   )
r(   r   r#   r   r$   r   r)   r   r%   r   )__name__
__module____qualname____doc__r   r'   r+   __classcell__r   r   r    r   r      s
    
	r   )
__future__r   r   torch.nnr   r   Z-monai.apps.reconstruction.networks.nets.utilsr   r   Moduler   r   r   r   r   <module>   s   