o
    i=                     @  s  d dl mZ d dlmZmZmZ d dlZd dlmZ d dl	m
Z
 d dlmZmZ d dlmZmZ d dlmZ d d	lmZ d d
lmZ d dlmZ d dlmZmZ d dlmZ d dlmZ G dd deZ G dd deeZ!G dd de!Z"G dd deeZ#G dd deZ$dS )    )annotations)HashableMappingSequenceN)ndarray)Tensor)EquispacedKspaceMaskRandomKspaceMask)	DtypeLikeKeysCollection)NdarrayOrTensor)InvertibleTransform)SpatialCrop)NormalizeIntensity)MapTransformRandomizableTransform)FastMRIKeys)convert_to_tensorc                   @  &   e Zd ZdZddddZdddZdS )ExtractDataKeyFromMetaKeyday  
    Moves keys from meta to data. It is useful when a dataset of paired samples
    is loaded and certain keys should be moved from meta to data.

    Args:
        keys: keys to be transferred from meta to data
        meta_key: the meta key where all the meta-data is stored
        allow_missing_keys: don't raise exception if key is missing

    Example:
        When the fastMRI dataset is loaded, "kspace" is stored in the data dictionary,
        but the ground-truth image with the key "reconstruction_rss" is stored in the meta data.
        In this case, ExtractDataKeyFromMetaKeyd moves "reconstruction_rss" to data.
    Fkeysr   meta_keystrallow_missing_keysboolreturnNonec                 C     t | || || _d S N)r   __init__r   )selfr   r   r    r!   q/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/reconstruction/transforms/dictionary.pyr   /      
z#ExtractDataKeyFromMetaKeyd.__init__data"Mapping[Hashable, NdarrayOrTensor]dict[Hashable, Tensor]c                 C  sZ   t |}| jD ]#}||| j v r|| j | ||< q| js*td| d| jj dq|S )
        Args:
            data: is a dictionary containing (key,value) pairs from the
                loaded dataset

        Returns:
            the new data dictionary
        zKey `z` of transform `z=` was missing in the meta data and allow_missing_keys==False.)dictr   r   r   KeyError	__class____name__r    r$   dkeyr!   r!   r"   __call__3   s   	
z#ExtractDataKeyFromMetaKeyd.__call__NF)r   r   r   r   r   r   r   r   r$   r%   r   r&   r+   
__module____qualname____doc__r   r/   r!   r!   r!   r"   r      s    r   c                      sH   e Zd ZdZejZ			dd ddZ	d!d" fddZd#ddZ  Z	S )$RandomKspaceMaskda}  
    Dictionary-based wrapper of :py:class:`monai.apps.reconstruction.transforms.array.RandomKspacemask`.
    Other mask transforms can inherit from this class, for example:
    :py:class:`monai.apps.reconstruction.transforms.dictionary.EquispacedKspaceMaskd`.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        center_fractions: Fraction of low-frequency columns to be retained.
            If multiple values are provided, then one of these numbers is
            chosen uniformly each time.
        accelerations: Amount of under-sampling. This should have the
            same length as center_fractions. If multiple values are provided,
            then one of these is chosen uniformly each time.
        spatial_dims: Number of spatial dims (e.g., it's 2 for a 2D data; it's
            also 2 for pseudo-3D datasets like the fastMRI dataset).
            The last spatial dim is selected for sampling. For the fastMRI
            dataset, k-space has the form (...,num_slices,num_coils,H,W)
            and sampling is done along W. For a general 3D data with the
            shape (...,num_coils,H,W,D), sampling is done along D.
        is_complex: if True, then the last dimension will be reserved
            for real/imaginary parts.
        allow_missing_keys: don't raise exception if key is missing.
       TFr   r   center_fractionsSequence[float]accelerationsspatial_dimsint
is_complexr   r   r   r   c                 C  $   t | || t||||d| _d S N)r8   r:   r;   r=   )r   r   r	   maskerr    r   r8   r:   r;   r=   r   r!   r!   r"   r   d      	zRandomKspaceMaskd.__init__Nseed
int | Nonestatenp.random.RandomState | Nonec                       t  || | j|| | S r   superset_random_stater@   r    rC   rE   r*   r!   r"   rJ   u      z"RandomKspaceMaskd.set_random_stater$   r%   r&   c                 C  sL   t |}| |D ]}| || \||d < ||d < | jj|tj< q	|S )r'   _maskedZ_masked_ifft)r(   key_iteratorr@   maskr   MASKr,   r!   r!   r"   r/   |   s
   	"zRandomKspaceMaskd.__call__r7   TFr   r   r8   r9   r:   r9   r;   r<   r=   r   r   r   r   r   NN)rC   rD   rE   rF   r   r6   r1   )
r+   r3   r4   r5   r	   backendr   rJ   r/   __classcell__r!   r!   rL   r"   r6   H   s    r6   c                      s>   e Zd ZdZejZ			ddddZ	dd fddZ  ZS )EquispacedKspaceMaskda  
    Dictionary-based wrapper of
    :py:class:`monai.apps.reconstruction.transforms.array.EquispacedKspaceMask`.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        center_fractions: Fraction of low-frequency columns to be retained.
            If multiple values are provided, then one of these numbers is
            chosen uniformly each time.
        accelerations: Amount of under-sampling. This should have the same
            length as center_fractions. If multiple values are provided,
            then one of these is chosen uniformly each time.
        spatial_dims: Number of spatial dims (e.g., it's 2 for a 2D data;
            it's also 2 for  pseudo-3D datasets like the fastMRI dataset).
            The last spatial dim is selected for sampling. For the fastMRI
            dataset, k-space has the form (...,num_slices,num_coils,H,W)
            and sampling is done along W. For a general 3D data with the shape
            (...,num_coils,H,W,D), sampling is done along D.
        is_complex: if True, then the last dimension will be reserved
            for real/imaginary parts.
        allow_missing_keys: don't raise exception if key is missing.
    r7   TFr   r   r8   r9   r:   r;   r<   r=   r   r   r   r   c                 C  r>   r?   )r   r   r   r@   rA   r!   r!   r"   r      rB   zEquispacedKspaceMaskd.__init__NrC   rD   rE   rF   c                   rG   r   rH   rK   rL   r!   r"   rJ      rM   z&EquispacedKspaceMaskd.set_random_staterR   rS   rT   )rC   rD   rE   rF   r   rW   )	r+   r3   r4   r5   r   rU   r   rJ   rV   r!   r!   rL   r"   rW      s    rW   c                   @  r   )ReferenceBasedSpatialCropda  
    Dictionary-based wrapper of :py:class:`monai.transforms.SpatialCrop`.
    This is similar to :py:class:`monai.transforms.SpatialCropd` which is a
    general purpose cropper to produce sub-volume region of interest (ROI).
    Their difference is that this transform does cropping according to a reference image.

    If a dimension of the expected ROI size is larger than the input image size, will not crop that dimension.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: :py:class:`monai.transforms.compose.MapTransform`
        ref_key: key of the item to be used to crop items of "keys"
        allow_missing_keys: don't raise exception if key is missing.

    Example:
        In an image reconstruction task, let keys=["image"] and ref_key=["target"].
        Also, let data be the data dictionary. Then, ReferenceBasedSpatialCropd
        center-crops data["image"] based on the spatial size of data["target"] by
        calling :py:class:`monai.transforms.SpatialCrop`.
    Fr   r   ref_keyr   r   r   r   r   c                 C  r   r   )r   r   rY   )r    r   rY   r   r!   r!   r"   r      r#   z#ReferenceBasedSpatialCropd.__init__r$   Mapping[Hashable, Tensor]r&   c                 C  st   t |}|| j jdd }| |D ]$}|| }tdd |jdd D }t||d}t||| ||< q|S )a  
        This transform can support to crop ND spatial (channel-first) data.
        It also supports pseudo ND spatial data (e.g., (C,H,W) is a pseudo-3D
        data point where C is the number of slices)

        Args:
            data: is a dictionary containing (key,value) pairs from
                the loaded dataset

        Returns:
            the new data dictionary
           Nc                 s  s    | ]}|d  V  qdS )r7   Nr!   ).0ir!   r!   r"   	<genexpr>   s    z6ReferenceBasedSpatialCropd.__call__.<locals>.<genexpr>)
roi_centerroi_size)r(   rY   shaperO   tupler   r   )r    r$   r-   r`   r.   imager_   cropperr!   r!   r"   r/      s   z#ReferenceBasedSpatialCropd.__call__Nr0   )r   r   rY   r   r   r   r   r   )r$   rZ   r   r&   r2   r!   r!   r!   r"   rX      s    rX   c                      sB   e Zd ZdZejZddddejdfd fddZdddZ	  Z
S )!ReferenceBasedNormalizeIntensitydao  
    Dictionary-based wrapper of
    :py:class:`monai.transforms.NormalizeIntensity`.
    This is similar to :py:class:`monai.transforms.NormalizeIntensityd`
    and can normalize non-zero values or the entire image. The difference
    is that this transform does normalization according to a reference image.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        ref_key: key of the item to be used to normalize items of "keys"
        subtrahend: the amount to subtract by (usually the mean)
        divisor: the amount to divide by (usually the standard deviation)
        nonzero: whether only normalize non-zero values.
        channel_wise: if True, calculate on each channel separately,
            otherwise, calculate on the entire image directly. default
            to False.
        dtype: output data type, if None, same as input image. defaults
            to float32.
        allow_missing_keys: don't raise exception if key is missing.

    Example:
        In an image reconstruction task, let keys=["image", "target"] and ref_key=["image"].
        Also, let data be the data dictionary. Then, ReferenceBasedNormalizeIntensityd
        normalizes data["target"] and data["image"] based on the mean-std of data["image"] by
        calling :py:class:`monai.transforms.NormalizeIntensity`.
    NFr   r   rY   r   
subtrahendNdarrayOrTensor | Nonedivisornonzeror   channel_wisedtyper
   r   r   r   c	           	        s*   t  || t|||||| _|| _d S r   )rI   r   r   default_normalizerrY   )	r    r   rY   rf   rh   ri   rj   rk   r   rL   r!   r"   r     s   
z*ReferenceBasedNormalizeIntensityd.__init__r$   r%   dict[Hashable, NdarrayOrTensor]c                 C  s^  t |}| jjr9| jjdu rtdd || j D }n| jj}| jjdu r4tdd || j D }nS| jj}nN| jjdu r[t|| j t	rO|| j 
 }n|| j  
  }n| jj}| jjdu rt|| j t	ru|| j  }n|| j  jdd }n| jj}t||| jj| jj| jj}||d< ||d< | |D ]
}||| ||< q|S )	a  
        This transform can support to normalize ND spatial (channel-first) data.
        It also supports pseudo ND spatial data (e.g., (C,H,W) is a pseudo-3D
        data point where C is the number of slices)

        Args:
            data: is a dictionary containing (key,value) pairs from
                the loaded dataset

        Returns:
            the new data dictionary
        Nc                 S  s.   g | ]}t |tr| n|   qS r!   )
isinstancer   meanfloatitemr\   valr!   r!   r"   
<listcomp>:  s   . z>ReferenceBasedNormalizeIntensityd.__call__.<locals>.<listcomp>c                 S  s2   g | ]}t |tr| n	| jd d qS )Funbiased)rn   r   stdrp   rq   rr   r!   r!   r"   rt   D  s    $Fru   ro   rw   )r(   rl   rj   rf   nparrayrY   rh   rn   r   ro   rp   rq   rw   r   ri   rk   rO   )r    r$   r-   rf   rh   
normalizerr.   r!   r!   r"   r/   $  sH   
	z*ReferenceBasedNormalizeIntensityd.__call__)r   r   rY   r   rf   rg   rh   rg   ri   r   rj   r   rk   r
   r   r   r   r   )r$   r%   r   rm   )r+   r3   r4   r5   r   rU   rx   float32r   r/   rV   r!   r!   rL   r"   re      s    re   )%
__future__r   collections.abcr   r   r   numpyrx   r   torchr   Z*monai.apps.reconstruction.transforms.arrayr   r	   monai.configr
   r   monai.config.type_definitionsr   monai.transformsr   monai.transforms.croppad.arrayr    monai.transforms.intensity.arrayr   monai.transforms.transformr   r   monai.utilsr   monai.utils.type_conversionr   r   r6   rW   rX   re   r!   r!   r!   r"   <module>   s&   )E45