o
    i                     @  sr   d dl mZ d dlZd dlmZ d dlmZ d dlmZ d dl	m
Z
 d dlmZ d dlmZ G dd	 d	ejZdS )
    )annotationsN)Tensor)complex_abs_t)root_sum_of_squares_t)VarNetBlock)ifftn_centered_tc                      s2   e Zd ZdZ		dd fd
dZdddZ  ZS )VariationalNetworkModelac  
    The end-to-end variational network (or simply e2e-VarNet) based on Sriram et. al., "End-to-end variational
    networks for accelerated MRI reconstruction".
    It comprises several cascades each consisting of refinement and data consistency steps. The network takes in
    the under-sampled kspace and estimates the ground-truth reconstruction.

    Modified and adopted from: https://github.com/facebookresearch/fastMRI

    Args:
        coil_sensitivity_model: A convolutional model for learning coil sensitivity maps. An example is
            :py:class:`monai.apps.reconstruction.networks.nets.coil_sensitivity_model.CoilSensitivityModel`.
        refinement_model: A convolutional network used in the refinement step of e2e-VarNet. An example
            is :py:class:`monai.apps.reconstruction.networks.nets.complex_unet.ComplexUnet`.
        num_cascades: Number of cascades. Each cascade is a
            :py:class:`monai.apps.reconstruction.networks.blocks.varnetblock.VarNetBlock` which consists of
            refinement and data consistency steps.
        spatial_dims: number of spatial dimensions.
          coil_sensitivity_model	nn.Modulerefinement_modelnum_cascadesintspatial_dimsc                   s8   t    || _t fddt|D | _|| _d S )Nc                   s   g | ]	}t t qS  )r   copydeepcopy).0ir   r   p/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/reconstruction/networks/nets/varnet.py
<listcomp>6   s    z4VariationalNetworkModel.__init__.<locals>.<listcomp>)super__init__r   nn
ModuleListrangecascadesr   )selfr   r   r   r   	__class__r   r   r   -   s   

z VariationalNetworkModel.__init__masked_kspacer   maskreturnc                 C  sL   |  ||}| }| jD ]	}|||||}qttt|| jddd}|S )ad  
        Args:
            masked_kspace: The under-sampled kspace. It's a 2D kspace (B,C,H,W,2)
                with the last dimension being 2 (for real/imaginary parts) and C denoting the
                coil dimension. 3D data will have the shape (B,C,H,W,D,2).
            mask: The under-sampling mask with shape (1,1,1,W,1) for 2D data or (1,1,1,1,D,1) for 3D data.

        Returns:
            The reconstructed image which is the root sum of squares (rss) of the absolute value
                of the inverse fourier of the predicted kspace (note that rss combines coil images into one image).
        )r      )spatial_dim)r   cloner   r   r   r   r   )r   r"   r#   Zsensitivity_mapsZkspace_predcascadeZoutput_imager   r   r   forward9   s   
zVariationalNetworkModel.forward)r	   r
   )r   r   r   r   r   r   r   r   )r"   r   r#   r   r$   r   )__name__
__module____qualname____doc__r   r)   __classcell__r   r   r    r   r      s    r   )
__future__r   r   torch.nnr   torchr   Z'monai.apps.reconstruction.complex_utilsr   Z#monai.apps.reconstruction.mri_utilsr   Z5monai.apps.reconstruction.networks.blocks.varnetblockr   !monai.networks.blocks.fft_utils_tr   Moduler   r   r   r   r   <module>   s   