U
    Ph=                     @  s  d dl mZ d dlmZmZmZ d dlZd dlmZ d dl	m
Z
 d dlmZmZ d dlmZmZ d dlmZ d d	lmZ d d
lmZ d dlmZ d dlmZmZ d dlmZ d dlmZ G dd deZ G dd deeZ!G dd de!Z"G dd deeZ#G dd deZ$dS )    )annotations)HashableMappingSequenceN)ndarray)Tensor)EquispacedKspaceMaskRandomKspaceMask)	DtypeLikeKeysCollection)NdarrayOrTensor)InvertibleTransform)SpatialCrop)NormalizeIntensity)MapTransformRandomizableTransform)FastMRIKeys)convert_to_tensorc                   @  s6   e Zd ZdZdddddddd	Zd
ddddZdS )ExtractDataKeyFromMetaKeyday  
    Moves keys from meta to data. It is useful when a dataset of paired samples
    is loaded and certain keys should be moved from meta to data.

    Args:
        keys: keys to be transferred from meta to data
        meta_key: the meta key where all the meta-data is stored
        allow_missing_keys: don't raise exception if key is missing

    Example:
        When the fastMRI dataset is loaded, "kspace" is stored in the data dictionary,
        but the ground-truth image with the key "reconstruction_rss" is stored in the meta data.
        In this case, ExtractDataKeyFromMetaKeyd moves "reconstruction_rss" to data.
    Fr   strboolNone)keysmeta_keyallow_missing_keysreturnc                 C  s   t | || || _d S N)r   __init__r   )selfr   r   r    r   d/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/apps/reconstruction/transforms/dictionary.pyr   /   s    z#ExtractDataKeyFromMetaKeyd.__init__"Mapping[Hashable, NdarrayOrTensor]dict[Hashable, Tensor]datar   c                 C  sZ   t |}| jD ]F}||| j kr4|| j | ||< q| jstd| d| jj dq|S )
        Args:
            data: is a dictionary containing (key,value) pairs from the
                loaded dataset

        Returns:
            the new data dictionary
        zKey `z` of transform `z=` was missing in the meta data and allow_missing_keys==False.)dictr   r   r   KeyError	__class____name__r   r$   dkeyr   r   r    __call__3   s    	
z#ExtractDataKeyFromMetaKeyd.__call__N)Fr)   
__module____qualname____doc__r   r-   r   r   r   r    r      s   r   c                	      s^   e Zd ZdZejZdddddddd	d
ddZdddd d fddZdddddZ  Z	S )RandomKspaceMaskda}  
    Dictionary-based wrapper of :py:class:`monai.apps.reconstruction.transforms.array.RandomKspacemask`.
    Other mask transforms can inherit from this class, for example:
    :py:class:`monai.apps.reconstruction.transforms.dictionary.EquispacedKspaceMaskd`.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        center_fractions: Fraction of low-frequency columns to be retained.
            If multiple values are provided, then one of these numbers is
            chosen uniformly each time.
        accelerations: Amount of under-sampling. This should have the
            same length as center_fractions. If multiple values are provided,
            then one of these is chosen uniformly each time.
        spatial_dims: Number of spatial dims (e.g., it's 2 for a 2D data; it's
            also 2 for pseudo-3D datasets like the fastMRI dataset).
            The last spatial dim is selected for sampling. For the fastMRI
            dataset, k-space has the form (...,num_slices,num_coils,H,W)
            and sampling is done along W. For a general 3D data with the
            shape (...,num_coils,H,W,D), sampling is done along D.
        is_complex: if True, then the last dimension will be reserved
            for real/imaginary parts.
        allow_missing_keys: don't raise exception if key is missing.
       TFr   Sequence[float]intr   r   r   center_fractionsaccelerationsspatial_dims
is_complexr   r   c                 C  s$   t | || t||||d| _d S N)r7   r8   r9   r:   )r   r   r	   maskerr   r   r7   r8   r9   r:   r   r   r   r    r   d   s    	zRandomKspaceMaskd.__init__N
int | Nonenp.random.RandomState | Noneseedstater   c                   s    t  || | j|| | S r   superset_random_stater<   r   rA   rB   r(   r   r    rE   u   s    z"RandomKspaceMaskd.set_random_stater!   r"   r#   c                 C  sL   t |}| |D ]4}| || \||d < ||d < | jj|tj< q|S )r%   _maskedZ_masked_ifft)r&   key_iteratorr<   maskr   MASKr*   r   r   r    r-   |   s
    	"zRandomKspaceMaskd.__call__)r3   TF)NN)
r)   r/   r0   r1   r	   backendr   rE   r-   __classcell__r   r   rG   r    r2   H   s         r2   c                	      sN   e Zd ZdZejZdddddddd	d
ddZdddd d fddZ  ZS )EquispacedKspaceMaskda  
    Dictionary-based wrapper of
    :py:class:`monai.apps.reconstruction.transforms.array.EquispacedKspaceMask`.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        center_fractions: Fraction of low-frequency columns to be retained.
            If multiple values are provided, then one of these numbers is
            chosen uniformly each time.
        accelerations: Amount of under-sampling. This should have the same
            length as center_fractions. If multiple values are provided,
            then one of these is chosen uniformly each time.
        spatial_dims: Number of spatial dims (e.g., it's 2 for a 2D data;
            it's also 2 for  pseudo-3D datasets like the fastMRI dataset).
            The last spatial dim is selected for sampling. For the fastMRI
            dataset, k-space has the form (...,num_slices,num_coils,H,W)
            and sampling is done along W. For a general 3D data with the shape
            (...,num_coils,H,W,D), sampling is done along D.
        is_complex: if True, then the last dimension will be reserved
            for real/imaginary parts.
        allow_missing_keys: don't raise exception if key is missing.
    r3   TFr   r4   r5   r   r   r6   c                 C  s$   t | || t||||d| _d S r;   )r   r   r   r<   r=   r   r   r    r      s    	zEquispacedKspaceMaskd.__init__Nr>   r?   r@   c                   s    t  || | j|| | S r   rC   rF   rG   r   r    rE      s    z&EquispacedKspaceMaskd.set_random_state)r3   TF)NN)	r)   r/   r0   r1   r   rL   r   rE   rM   r   r   rG   r    rN      s         rN   c                   @  s6   e Zd ZdZdddddddd	Zd
ddddZdS )ReferenceBasedSpatialCropda  
    Dictionary-based wrapper of :py:class:`monai.transforms.SpatialCrop`.
    This is similar to :py:class:`monai.transforms.SpatialCropd` which is a
    general purpose cropper to produce sub-volume region of interest (ROI).
    Their difference is that this transform does cropping according to a reference image.

    If a dimension of the expected ROI size is larger than the input image size, will not crop that dimension.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: :py:class:`monai.transforms.compose.MapTransform`
        ref_key: key of the item to be used to crop items of "keys"
        allow_missing_keys: don't raise exception if key is missing.

    Example:
        In an image reconstruction task, let keys=["image"] and ref_key=["target"].
        Also, let data be the data dictionary. Then, ReferenceBasedSpatialCropd
        center-crops data["image"] based on the spatial size of data["target"] by
        calling :py:class:`monai.transforms.SpatialCrop`.
    Fr   r   r   r   )r   ref_keyr   r   c                 C  s   t | || || _d S r   )r   r   rP   )r   r   rP   r   r   r   r    r      s    z#ReferenceBasedSpatialCropd.__init__zMapping[Hashable, Tensor]r"   r#   c                 C  st   t |}|| j jdd }| |D ]H}|| }tdd |jdd D }t||d}t||| ||< q&|S )a  
        This transform can support to crop ND spatial (channel-first) data.
        It also supports pseudo ND spatial data (e.g., (C,H,W) is a pseudo-3D
        data point where C is the number of slices)

        Args:
            data: is a dictionary containing (key,value) pairs from
                the loaded dataset

        Returns:
            the new data dictionary
           Nc                 s  s   | ]}|d  V  qdS )r3   Nr   ).0ir   r   r    	<genexpr>   s     z6ReferenceBasedSpatialCropd.__call__.<locals>.<genexpr>)
roi_centerroi_size)r&   rP   shaperI   tupler   r   )r   r$   r+   rV   r,   imagerU   cropperr   r   r    r-      s    z#ReferenceBasedSpatialCropd.__call__N)Fr.   r   r   r   r    rO      s   rO   c                      s\   e Zd ZdZejZddddejdfddddddddd	d
	 fddZdddddZ	  Z
S )!ReferenceBasedNormalizeIntensitydao  
    Dictionary-based wrapper of
    :py:class:`monai.transforms.NormalizeIntensity`.
    This is similar to :py:class:`monai.transforms.NormalizeIntensityd`
    and can normalize non-zero values or the entire image. The difference
    is that this transform does normalization according to a reference image.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        ref_key: key of the item to be used to normalize items of "keys"
        subtrahend: the amount to subtract by (usually the mean)
        divisor: the amount to divide by (usually the standard deviation)
        nonzero: whether only normalize non-zero values.
        channel_wise: if True, calculate on each channel separately,
            otherwise, calculate on the entire image directly. default
            to False.
        dtype: output data type, if None, same as input image. defaults
            to float32.
        allow_missing_keys: don't raise exception if key is missing.

    Example:
        In an image reconstruction task, let keys=["image", "target"] and ref_key=["image"].
        Also, let data be the data dictionary. Then, ReferenceBasedNormalizeIntensityd
        normalizes data["target"] and data["image"] based on the mean-std of data["image"] by
        calling :py:class:`monai.transforms.NormalizeIntensity`.
    NFr   r   zNdarrayOrTensor | Noner   r
   r   )	r   rP   
subtrahenddivisornonzerochannel_wisedtyper   r   c	           	        s*   t  || t|||||| _|| _d S r   )rD   r   r   default_normalizerrP   )	r   r   rP   r\   r]   r^   r_   r`   r   rG   r   r    r     s    z*ReferenceBasedNormalizeIntensityd.__init__r!   zdict[Hashable, NdarrayOrTensor]r#   c                 C  sb  t |}| jjrr| jjdkr8tdd || j D }n| jj}| jjdkrhtdd || j D }n| jj}n| jjdkrt|| j t	r|| j 
 }q|| j  
  }n| jj}| jjdkrt|| j t	r|| j  }n|| j  jdd }n| jj}t||| jj| jj| jj}||d< ||d< | |D ]}||| ||< qF|S )	a  
        This transform can support to normalize ND spatial (channel-first) data.
        It also supports pseudo ND spatial data (e.g., (C,H,W) is a pseudo-3D
        data point where C is the number of slices)

        Args:
            data: is a dictionary containing (key,value) pairs from
                the loaded dataset

        Returns:
            the new data dictionary
        Nc                 S  s.   g | ]&}t |tr| n|   qS r   )
isinstancer   meanfloatitemrR   valr   r   r    
<listcomp>:  s     z>ReferenceBasedNormalizeIntensityd.__call__.<locals>.<listcomp>c                 S  s2   g | ]*}t |tr| n| jd d qS )Funbiased)rb   r   stdrd   re   rf   r   r   r    rh   D  s   Fri   rc   rk   )r&   ra   r_   r\   nparrayrP   r]   rb   r   rc   rd   re   rk   r   r^   r`   rI   )r   r$   r+   r\   r]   
normalizerr,   r   r   r    r-   $  sH    
	z*ReferenceBasedNormalizeIntensityd.__call__)r)   r/   r0   r1   r   rL   rl   float32r   r-   rM   r   r   rG   r    r[      s   $r[   )%
__future__r   collections.abcr   r   r   numpyrl   r   torchr   Z*monai.apps.reconstruction.transforms.arrayr   r	   monai.configr
   r   monai.config.type_definitionsr   monai.transformsr   Zmonai.transforms.croppad.arrayr   Z monai.transforms.intensity.arrayr   monai.transforms.transformr   r   monai.utilsr   monai.utils.type_conversionr   r   r2   rN   rO   r[   r   r   r   r    <module>   s$   )E45