o
    ik                    @  s  d Z ddlmZ ddlmZ ddlmZmZmZ ddl	m
Z
 ddlmZ ddlmZ ddlZddlZdd	lmZ dd
lmZmZ ddlmZ ddlmZ ddlmZmZ ddlmZm Z m!Z!m"Z" ddl#m$Z$m%Z% ddl&m'Z'm(Z(m)Z)m*Z*m+Z+ ddl,m-Z-m.Z.m/Z/ ddl0m1Z1 ddl2m3Z3m4Z4m5Z5m6Z6 ddl7m8Z8m9Z9 ddl:m;Z;m<Z<m=Z=m>Z> e9dde8\Z?Z@g dZAG dd de$ZBG dd de$ZCG dd de%ZDG dd  d e$ZEG d!d" d"e%ZFG d#d$ d$e$ZGG d%d& d&e%ZHG d'd( d(e%ZIG d)d* d*e$ZJG d+d, d,e$ZKG d-d. d.e$ZLG d/d0 d0e%ZMG d1d2 d2e%ZNG d3d4 d4e%ZOG d5d6 d6e%ZPG d7d8 d8e%ZQG d9d: d:e$ZRG d;d< d<e%ZSG d=d> d>e%ZTG d?d@ d@e%ZUG dAdB dBe%ZVG dCdD dDe%ZWG dEdF dFe%ZXG dGdH dHe$ZYG dIdJ dJe%ZZG dKdL dLe$Z[G dMdN dNe$Z\G dOdP dPe%e'Z]G dQdR dRe$Z^G dSdT dTe%e'Z_G dUdV dVe$e'Z`G dWdX dXe$ZaG dYdZ dZeaZbG d[d\ d\eaZcG d]d^ d^e%ZdG d_d` d`e$ZeG dadb dbe$ZfG dcdd dde%ZgG dedf dfe%ZhG dgdh dhe%ZidS )iz@
A collection of "vanilla" transforms for intensity adjustment.
    )annotations)abstractmethod)CallableIterableSequence)partial)Any)warnN)	DtypeLike)NdarrayOrTensorNdarrayTensor)get_track_meta)UltrasoundConfidenceMap)get_random_patchget_valid_patch_size)GaussianFilterHilbertTransformMedianFilterSavitzkyGolayFilter)RandomizableTransform	Transform)Fourierequalize_histis_positiverescale_array	soft_clip)clip
percentilewhere)TransformBackends)ensure_tupleensure_tuple_repensure_tuple_sizefall_back_tuple)min_versionoptional_import)convert_data_typeconvert_to_dst_typeconvert_to_tensorget_equivalent_dtypeskimagez0.19.0)(RandGaussianNoiseRandRicianNoiseShiftIntensityRandShiftIntensityStdShiftIntensityRandStdShiftIntensityRandBiasFieldScaleIntensityRandScaleIntensityScaleIntensityFixedMeanRandScaleIntensityFixedMeanNormalizeIntensityThresholdIntensityScaleIntensityRangeClipIntensityPercentilesAdjustContrastRandAdjustContrastScaleIntensityRangePercentilesMaskIntensityDetectEnvelopeSavitzkyGolaySmoothMedianSmoothGaussianSmoothRandGaussianSmoothGaussianSharpenRandGaussianSharpenRandHistogramShift
GibbsNoiseRandGibbsNoiseKSpaceSpikeNoiseRandKSpaceSpikeNoiseRandCoarseTransformRandCoarseDropoutRandCoarseShuffleHistogramNormalizeIntensityRemapRandIntensityRemapForegroundMaskComputeHoVerMaps UltrasoundConfidenceMapTransformc                      sT   e Zd ZdZejejgZdddej	dfdddZ
dd fddZddddZ  ZS )r+   u  
    Add Gaussian noise to image.

    Args:
        prob: Probability to add Gaussian noise.
        mean: Mean or “centre” of the distribution.
        std: Standard deviation (spread) of distribution.
        dtype: output data type, if None, same as input image. defaults to float32.
        sample_std: If True, sample the spread of the Gaussian distribution uniformly from 0 to std.

    皙?        Tprobfloatmeanstddtyper
   
sample_stdboolreturnNonec                 C  s.   t | | || _|| _|| _d | _|| _d S N)r   __init__rW   rX   rY   noiserZ   )selfrU   rW   rX   rY   rZ    rb   b/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/transforms/intensity/array.pyr_   e   s   
zRandGaussianNoise.__init__Nimgr   float | Nonec                   sl   t  d  | jsd S | jr| jd| jn| j}| jj|d u r#| jn|||j	d}t
|| jd^| _}d S )Nr   sizerY   )super	randomize_do_transformrZ   RuniformrX   normalrW   shaper&   rY   r`   )ra   rd   rW   rX   r`   _	__class__rb   rc   rj   t   s   "zRandGaussianNoise.randomizerj   c                 C  st   t |t d}|r| j||du r| jn|d | js|S | jdu r%tdt|| jd^}}t	| j|^}}|| S )/
        Apply the transform to `img`.
        
track_metaN)rd   rW   -please call the `randomize()` function first.rh   )
r(   r   rj   rW   rk   r`   RuntimeErrorr&   rY   r'   )ra   rd   rW   rj   rp   r`   rb   rb   rc   __call__}   s   
zRandGaussianNoise.__call__)rU   rV   rW   rV   rX   rV   rY   r
   rZ   r[   r\   r]   r^   )rd   r   rW   re   r\   r]   NT)rd   r   rW   re   rj   r[   r\   r   __name__
__module____qualname____doc__r   TORCHNUMPYbackendnpfloat32r_   rj   rx   __classcell__rb   rb   rq   rc   r+   V   s    	r+   c                      sV   e Zd ZdZejejgZddddddej	fdddZ
dddZdd  fddZ  ZS )!r,   a  
    Add Rician noise to image.
    Rician noise in MRI is the result of performing a magnitude operation on complex
    data with Gaussian noise of the same variance in both channels, as described in
    `Noise in Magnitude Magnetic Resonance Images <https://doi.org/10.1002/cmr.a.20124>`_.
    This transform is adapted from `DIPY <https://github.com/dipy/dipy>`_.
    See also: `The rician distribution of noisy mri data <https://doi.org/10.1002/mrm.1910340618>`_.

    Args:
        prob: Probability to add Rician noise.
        mean: Mean or "centre" of the Gaussian distributions sampled to make up
            the Rician noise.
        std: Standard deviation (spread) of the Gaussian distributions sampled
            to make up the Rician noise.
        channel_wise: If True, treats each channel of the image separately.
        relative: If True, the spread of the sampled Gaussian distributions will
            be std times the standard deviation of the image or channel's intensity
            histogram.
        sample_std: If True, sample the spread of the Gaussian distributions
            uniformly from 0 to std.
        dtype: output data type, if None, same as input image. defaults to float32.

    rS   rT         ?FTrU   rV   rW   Sequence[float] | floatrX   channel_wiser[   relativerZ   rY   r
   r\   r]   c                 C  sB   t | | || _|| _|| _|| _|| _|| _|| _|  |  d S r^   )	r   r_   rU   rW   rX   r   r   rZ   rY   )ra   rU   rW   rX   r   r   rZ   rY   rb   rb   rc   r_      s   
zRandRicianNoise.__init__rd   r   c           	      C  s   t |jtj}|j}| jr| jd|n|}| jj|||dj	|dd| _
| jj|||dj	|dd| _t|tjrYtj| j
|jd}tj| j|jd}t|| d |d  S t|| j
 d | jd  S )Nr   rf   Fcopydevice   )r)   rY   r   ndarrayro   rZ   rl   rm   rn   astypeZ_noise1Z_noise2
isinstancetorchTensortensorr   sqrt)	ra   rd   rW   rX   Zdtype_npZim_shape_stdn1n2rb   rb   rc   
_add_noise   s   zRandRicianNoise._add_noiserj   c                   s<  t |t | jd}|rt d | js|S | jrMt| jt	|}t| j
t	|}t|D ]\}}| j||| | jrB|| |
  n|| d||< q-|S t| jttfs`tdt| j dt| j
ttfsstdt| j
 d| jr| j
|
   n| j
}t|ttfstdt| d| j|| j|d}|S )rs   ru   rY   N)rW   rX   z;If channel_wise is False, mean must be a float or int, got .z:If channel_wise is False, std must be a float or int, got z'std must be a float or int number, got )r(   r   rY   ri   rj   rk   r   r!   rW   lenrX   	enumerater   r   r   intrV   rw   typeitem)ra   rd   rj   _meanr   idrX   rq   rb   rc   rx      s(   4
zRandRicianNoise.__call__)rU   rV   rW   r   rX   r   r   r[   r   r[   rZ   r[   rY   r
   r\   r]   )rd   r   rW   rV   rX   rV   Trd   r   rj   r[   r\   r   )r{   r|   r}   r~   r   r   r   r   r   r   r_   r   rx   r   rb   rb   rq   rc   r,      s    
r,   c                   @  s4   e Zd ZdZejejgZddd	d
ZddddZ	dS )r-   aq  
    Shift intensity uniformly for the entire image with specified `offset`.

    Args:
        offset: offset value to shift the intensity of image.
        safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
            E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.
    FoffsetrV   safer[   r\   r]   c                 C     || _ || _d S r^   )r   r   )ra   r   r   rb   rb   rc   r_         
zShiftIntensity.__init__Nrd   r   re   c                 C  sB   t |t d}|du r| jn|}|| }t||j| jd^}}|S )rs   rt   N)datarY   r   )r(   r   r   r&   rY   r   )ra   rd   r   outrp   rb   rb   rc   rx      s
   zShiftIntensity.__call__)F)r   rV   r   r[   r\   r]   r^   )rd   r   r   re   r\   r   
r{   r|   r}   r~   r   r   r   r   r_   rx   rb   rb   rb   rc   r-      s
    	r-   c                      sJ   e Zd ZdZejejgZ	ddddZdd fddZ	d d!ddZ
  ZS )"r.   z?
    Randomly shift intensity with randomly picked offset.
    FrS   offsetstuple[float, float] | floatr   r[   rU   rV   r   r\   r]   c                 C  s   t | | t|ttfrt| |t| |f| _nt|dkr)t	d| dt|t|f| _| jd | _
|| _t| j
|| _dS )a  
        Args:
            offsets: offset range to randomly shift.
                if single number, offset value is picked from (-offsets, offsets).
            safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
                E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.
            prob: probability of shift.
            channel_wise: if True, shift intensity on each channel separately. For each channel, a random offset will be chosen.
                Please ensure that the first dimension represents the channel of the image if True.
        r   z3offsets should be a number or pair of numbers, got r   r   N)r   r_   r   r   rV   minmaxr   r   
ValueError_offsetr   r-   _shifter)ra   r   r   rU   r   rb   rb   rc   r_     s   zRandShiftIntensity.__init__Nr   
Any | Nonec                   `   t  d   jsd S  jr fddt|jd D  _d S  jj j	d  j	d d _d S )Nc                   (   g | ]} j j jd   jd dqS r      lowhigh)rl   rm   r   .0rp   ra   rb   rc   
<listcomp>)     ( z0RandShiftIntensity.randomize.<locals>.<listcomp>r   r   r   )
ri   rj   rk   r   rangero   r   rl   rm   r   ra   r   rq   r   rc   rj   $     ""zRandShiftIntensity.randomizeTrd   r   factorre   rj   c           	      C  s   t |t d}|r| | | js|S | jrAg }t|D ]\}}| ||du r,| j| n| j| | }|| qt	
|}|S | ||du rK| jn| j| }|S )a
  
        Apply the transform to `img`.

        Args:
            img: input image to shift intensity.
            factor: a factor to multiply the random offset, then shift.
                can be some image specific value at runtime, like: max(img), etc.

        rt   N)r(   r   rj   rk   r   r   r   r   appendr   stack)	ra   rd   r   rj   r   r   r   out_channelretrb   rb   rc   rx   -  s   

(
 zRandShiftIntensity.__call__)FrS   F)
r   r   r   r[   rU   rV   r   r[   r\   r]   r^   r   r   r\   r]   ry   )rd   r   r   re   rj   r[   r\   r   )r{   r|   r}   r~   r   r   r   r   r_   rj   rx   r   rb   rb   rq   rc   r.     s    	r.   c                   @  sD   e Zd ZdZejejgZddej	fdddZ
dddZdddZdS )r/   a  
    Shift intensity for the image with a factor and the standard deviation of the image
    by: ``v = v + factor * std(v)``.
    This transform can focus on only non-zero values or the entire image,
    and can also calculate the std on each channel separately.

    Args:
        factor: factor shift by ``v = v + factor * std(v)``.
        nonzero: whether only count non-zero values.
        channel_wise: if True, calculate on each channel separately. Please ensure
            that the first dimension represents the channel of the image if True.
        dtype: output data type, if None, same as input image. defaults to float32.
    Fr   rV   nonzeror[   r   rY   r
   r\   r]   c                 C     || _ || _|| _|| _d S r^   r   r   r   rY   )ra   r   r   r   rY   rb   rb   rc   r_   [  s   
zStdShiftIntensity.__init__rd   r   c                 C  sx   t |tjrtj}ttjdd}ntj}tj}| jr|dkn||jt	d}|
 r:| j|||  }|| | ||< |S )NFunbiasedr   rh   )r   r   r   onesr   rX   r   r   ro   r[   anyr   )ra   rd   r   rX   slicesr   rb   rb   rc   	_stdshiftc  s   zStdShiftIntensity._stdshiftc                 C  sJ   t |t | jd}| jrt|D ]\}}| |||< q|S | |}|S )rs   r   )r(   r   rY   r   r   r   )ra   rd   r   r   rb   rb   rc   rx   s  s   
zStdShiftIntensity.__call__N)
r   rV   r   r[   r   r[   rY   r
   r\   r]   rd   r   r\   r   )r{   r|   r}   r~   r   r   r   r   r   r   r_   r   rx   rb   rb   rb   rc   r/   J  s    
r/   c                      sR   e Zd ZdZejejgZdddej	fdddZ
dd fddZdd ddZ  ZS )!r0   z
    Shift intensity for the image with a factor and the standard deviation of the image
    by: ``v = v + factor * std(v)`` where the `factor` is randomly picked.
    rS   Ffactorsr   rU   rV   r   r[   r   rY   r
   r\   r]   c                 C  s   t | | t|ttfrt| |t| |f| _nt|dkr)t	d| dt|t|f| _| jd | _
|| _|| _|| _dS )a  
        Args:
            factors: if tuple, the randomly picked range is (min(factors), max(factors)).
                If single number, the range is (-factors, factors).
            prob: probability of std shift.
            nonzero: whether only count non-zero values.
            channel_wise: if True, calculate on each channel separately.
            dtype: output data type, if None, same as input image. defaults to float32.

        r   3factors should be a number or pair of numbers, got r   r   N)r   r_   r   r   rV   r   r   r   r   r   r   r   r   rY   )ra   r   rU   r   r   rY   rb   rb   rc   r_     s   
zRandStdShiftIntensity.__init__Nr   r   c                   8   t  d  | jsd S | jj| jd | jd d| _d S Nr   r   r   ri   rj   rk   rl   rm   r   r   r   rq   rb   rc   rj        "zRandStdShiftIntensity.randomizeTrd   r   rj   c                 C  sJ   t |t | jd}|r|   | js|S t| j| j| j| jd}||dS )rs   r   r   rd   )	r(   r   rY   rj   rk   r/   r   r   r   )ra   rd   rj   Zshifterrb   rb   rc   rx     s   
zRandStdShiftIntensity.__call__)r   r   rU   rV   r   r[   r   r[   rY   r
   r\   r]   r^   r   r   r   rz   rb   rb   rq   rc   r0     s    r0   c                   @  s>   e Zd ZdZejejgZddddej	fdddZ
dddZdS )r2   z
    Scale the intensity of input image to the given value range (minv, maxv).
    If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``.
    rT   r   NFminvre   maxvr   r   r[   rY   r
   r\   r]   c                 C  "   || _ || _|| _|| _|| _dS )a  
        Args:
            minv: minimum value of output data.
            maxv: maximum value of output data.
            factor: factor scale by ``v = v * (1 + factor)``. In order to use
                this parameter, please set both `minv` and `maxv` into None.
            channel_wise: if True, scale on each channel separately. Please ensure
                that the first dimension represents the channel of the image if True.
            dtype: output data type, if None, same as input image. defaults to float32.
        N)r   r   r   r   rY   )ra   r   r   r   r   rY   rb   rb   rc   r_     s
   
zScaleIntensity.__init__rd   r   c                   s   t |t d}t |dd} jdus jdur5 jr) fdd|D }t|}nt| j j jd}n j	durA|d j	  n|}t
|| jpK|jdd	 }|S )
z
        Apply the transform to `img`.

        Raises:
            ValueError: When ``self.minv=None`` or ``self.maxv=None`` and ``self.factor=None``. Incompatible values.

        rt   FNc                   s"   g | ]}t | j j jd qS )rh   )r   r   r   rY   r   r   r   rb   rc   r        " z+ScaleIntensity.__call__.<locals>.<listcomp>rh   r   dstrY   r   )r(   r   r   r   r   r   r   r   rY   r   r'   )ra   rd   img_tr   r   rb   r   rc   rx     s   zScaleIntensity.__call__)r   re   r   re   r   re   r   r[   rY   r
   r\   r]   r   r{   r|   r}   r~   r   r   r   r   r   r   r_   rx   rb   rb   rb   rc   r2     s    r2   c                   @  s@   e Zd ZdZejejgZddddej	fdddZ
ddddZdS )r4   z
    Scale the intensity of input image by ``v = v * (1 + factor)``, then shift the output so that the output image has the
    same mean as the input.
    r   FTr   rV   preserve_ranger[   
fixed_meanr   rY   r
   r\   r]   c                 C  r   )a  
        Args:
            factor: factor scale by ``v = v * (1 + factor)``.
            preserve_range: clips the output array/tensor to the range of the input array/tensor
            fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
                to ensure that the output has the same mean as the input.
            channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
                on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
                channel of the image if True.
            dtype: output data type, if None, same as input image. defaults to float32.
        N)r   r   r   r   rY   )ra   r   r   r   r   rY   rb   rb   rc   r_     s
   
z ScaleIntensityFixedMean.__init__Nrd   r   c                 C  s$  |dur|n| j }t|t d}t|dd}| jrWg }|D ]3}| jr*| }| }| jr5| }|| }|d|  }	| jrB|	| }	| jrKt	|	||}	|
|	 qt|}
n,| jrb| }| }| jrm| }|| }|d|  }
| jrz|
| }
| jrt	|
||}
t|
|| jp|jdd }
|
S )z
        Apply the transform to `img`.
        Args:
            img: the input tensor/array
            factor: factor scale by ``v = v * (1 + factor)``

        Nrt   Fr   r   r   )r   r(   r   r   r   r   r   r   rW   r   r   r   r   r'   rY   )ra   rd   r   r   r   r   Zclip_minZclip_maxmnr   r   rb   rb   rc   rx     s@   	z ScaleIntensityFixedMean.__call__)r   rV   r   r[   r   r[   r   r[   rY   r
   r\   r]   r^   r   r   rb   rb   rb   rc   r4     s    r4   c                      sN   e Zd ZdZejZddddejfdddZdd fddZ	d d!ddZ
  ZS )"r5   a  
    Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor`
    is randomly picked. Subtract the mean intensity before scaling with `factor`, then add the same value after scaling
    to ensure that the output has the same mean as the input.
    rS   r   TFrU   rV   r   r   r   r[   r   rY   r
   r\   r]   c                 C  s   t | | t|ttfrt| |t| |f| _nt|dkr%t	dt|t|f| _| jd | _
|| _|| _|| _t| j
| j| j| jd| _dS )aQ  
        Args:
            factors: factor range to randomly scale by ``v = v * (1 + factor)``.
                if single number, factor value is picked from (-factors, factors).
            preserve_range: clips the output array/tensor to the range of the input array/tensor
            fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
                to ensure that the output has the same mean as the input.
            channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
            on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
            channel of the image if True.
            dtype: output data type, if None, same as input image. defaults to float32.

        r   z.factors should be a number or pair of numbers.r   )r   r   r   rY   N)r   r_   r   r   rV   r   r   r   r   r   r   r   r   rY   r4   scaler)ra   rU   r   r   r   rY   rb   rb   rc   r_   V  s   z$RandScaleIntensityFixedMean.__init__Nr   r   c                   r   r   r   r   rq   rb   rc   rj   {  r   z%RandScaleIntensityFixedMean.randomizerd   r   rj   c                 C  s@   t |t d}|r|   | jst|| jdd S | || jS )rs   rt   rh   r   )r(   r   rj   rk   r&   rY   r   r   ra   rd   rj   rb   rb   rc   rx     s   z$RandScaleIntensityFixedMean.__call__)rU   rV   r   r   r   r[   r   r[   rY   r
   r\   r]   r^   r   r   r   )r{   r|   r}   r~   r4   r   r   r   r_   rj   rx   r   rb   rb   rq   rc   r5   M  s    %r5   c                      sJ   e Zd ZdZejZddejfdddZdd fddZ	ddddZ
  ZS ) r3   z|
    Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor`
    is randomly picked.
    rS   Fr   r   rU   rV   r   r[   rY   r
   r\   r]   c                 C  s   t | | t|ttfrt| |t| |f| _nt|dkr)t	d| dt|t|f| _| jd | _
|| _|| _dS )a  
        Args:
            factors: factor range to randomly scale by ``v = v * (1 + factor)``.
                if single number, factor value is picked from (-factors, factors).
            prob: probability of scale.
            channel_wise: if True, scale on each channel separately. Please ensure
                that the first dimension represents the channel of the image if True.
            dtype: output data type, if None, same as input image. defaults to float32.

        r   r   r   r   N)r   r_   r   r   rV   r   r   r   r   r   r   r   rY   )ra   r   rU   r   rY   rb   rb   rc   r_     s   
zRandScaleIntensity.__init__Nr   r   c                   r   )Nc                   r   r   )rl   rm   r   r   r   rb   rc   r     r   z0RandScaleIntensity.randomize.<locals>.<listcomp>r   r   r   )
ri   rj   rk   r   r   ro   r   rl   rm   r   r   rq   r   rc   rj     r   zRandScaleIntensity.randomizeTrd   r   rj   c                 C  s   t |t d}|r| | | jst|| jdd S | jrBg }t|D ]\}}tdd| j	| | jd|}|
| q#t|}|S tdd| j	| jd|}|S )rs   rt   rh   r   N)r   r   r   rY   )r(   r   rj   rk   r&   rY   r   r   r2   r   r   r   r   )ra   rd   rj   r   r   r   r   r   rb   rb   rc   rx     s   

zRandScaleIntensity.__call__)
r   r   rU   rV   r   r[   rY   r
   r\   r]   r^   r   r   r   )r{   r|   r}   r~   r2   r   r   r   r_   rj   rx   r   rb   rb   rq   rc   r3     s    	r3   c                      sV   e Zd ZdZejgZddejdfd!ddZ	d"ddZ
d# fddZd$d%dd Z  ZS )&r1   a  
    Random bias field augmentation for MR images.
    The bias field is considered as a linear combination of smoothly varying basis (polynomial)
    functions, as described in `Automated Model-Based Tissue Classification of MR Images of the Brain
    <https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=811270>`_.
    This implementation adapted from `NiftyNet
    <https://github.com/NifTK/NiftyNet>`_.
    Referred to `Longitudinal segmentation of age-related white matter hyperintensities
    <https://www.sciencedirect.com/science/article/pii/S1361841517300257?via%3Dihub>`_.

    Args:
        degree: degree of freedom of the polynomials. The value should be no less than 1.
            Defaults to 3.
        coeff_range: range of the random coefficients. Defaults to (0.0, 0.1).
        dtype: output data type, if None, same as input image. defaults to float32.
        prob: probability to do random bias field.

       )rT   rS   rS   degreer   coeff_rangetuple[float, float]rY   r
   rU   rV   r\   r]   c                 C  sB   t | | |dk rtd| d|| _|| _|| _dg| _d S )Nr   z%degree should be no less than 1, got r   r   )r   r_   r   r   r   rY   _coeff)ra   r   r   rY   rU   rb   rb   rc   r_     s   zRandBiasField.__init__spatial_shapeSequence[int]coeffSequence[float]c                 C  s>  t |}t|d f| }dd |D }|dkr/||t|d < tjj|d |d |S |dkrg dg}t|d D ]"}t|d | D ]}	t|d | |	 D ]
}
|||	|
g qTqHq>t |dkrm|dd }t	|}|||dddf |dddf |dddf f< tjj
|d |d |d |S td	)
zC
        products of polynomials as bias field estimations
        r   c                 S  s    g | ]}t jd d|t jdqS )      r   rh   )r   linspacer   )r   dimrb   rb   rc   r     s     z8RandBiasField._generate_random_field.<locals>.<listcomp>r   r   r   )r   r   r   Nzonly supports 2D or 3D fields)r   r   zerostril_indices
polynomiallegendre	leggrid2dr   r   r   	leggrid3dNotImplementedError)ra   r   r   r   rankZ	coeff_matcoordsptsr   jkZnp_ptsrb   rb   rc   _generate_random_field  s(   

2 z$RandBiasField._generate_random_fieldimg_sizec                   sb   t  d   jsd S tt fddtdt|d D } jj	g  j
|R    _d S )Nc                   s   g | ]	} j | | qS rb   )r   r   r   r   rb   rc   r     s    z+RandBiasField.randomize.<locals>.<listcomp>r   )ri   rj   rk   r   r   prodr   r   rl   rm   r   tolistr   )ra   r  Zn_coeffrq   r   rc   rj     s
   *"zRandBiasField.randomizeTrd   r   rj   r[   c                   s   t |t d}|r j|jdd d  js|S |j^}tj fddt|D dd}t|tj	^}}|t
| }t|| jpF|jd	^}}|S )
rs   rt   r   N)r  c                   s    g | ]} j  j jd qS ))r   r   r   )r   r   r   r   ra   r   rb   rc   r   $  s    z*RandBiasField.__call__.<locals>.<listcomp>r   axissrcr   rY   )r(   r   rj   ro   rk   r   r   r   r&   r   expr'   rY   )ra   rd   rj   num_channelsZ_bias_fieldsimg_nprp   r   rb   r  rc   rx     s    
zRandBiasField.__call__)
r   r   r   r   rY   r
   rU   rV   r\   r]   )r   r   r   r   r   r   r  r   r\   r]   r   r   )r{   r|   r}   r~   r   r   r   r   r   r_   r   rj   rx   r   rb   rb   rq   rc   r1     s    
r1   c                   @  sb   e Zd ZdZejejgZddddej	fdddZ
edd Zedd ZddddZdddZdS )r6   a  
    Normalize input based on the `subtrahend` and `divisor`: `(img - subtrahend) / divisor`.
    Use calculated mean or std value of the input image if no `subtrahend` or `divisor` provided.
    This transform can normalize only non-zero values or entire image, and can also calculate
    mean and std on each channel separately.
    When `channel_wise` is True, the first dimension of `subtrahend` and `divisor` should
    be the number of image channels if they are not None.
    If the input is not of floating point type, it will be converted to float32

    Args:
        subtrahend: the amount to subtract by (usually the mean).
        divisor: the amount to divide by (usually the standard deviation).
        nonzero: whether only normalize non-zero values.
        channel_wise: if True, calculate on each channel separately, otherwise, calculate on
            the entire image directly. default to False.
        dtype: output data type, if None, same as input image. defaults to float32.
    NF
subtrahend!Sequence | NdarrayOrTensor | Nonedivisorr   r[   r   rY   r
   r\   r]   c                 C  "   || _ || _|| _|| _|| _d S r^   )r  r  r   r   rY   )ra   r  r  r   r   rY   rb   rb   rc   r_   E  
   
zNormalizeIntensity.__init__c                 C  s<   t | tjrt| S t|  } |  dkr|  S | S )Nr   )r   r   r   rW   r   rV   numelr   xrb   rb   rc   r   S  s   
zNormalizeIntensity._meanc                 C  s@   t | tjrt| S tj|  dd} |  dkr|  S | S )NFr   r   )r   r   r   rX   r   rV   r  r   r  rb   rb   rc   r   Z  s   
zNormalizeIntensity._stdrd   r   c           	      C  s  t |tjd^}}| jr|dk}|| }| s|S nd }|}|d ur%|n| |}t|tjtj	frBt
||^}}|d urB|| }|d urH|n| |}t|rY|dkrXd}nt|tjtj	frwt
||^}}|d urq|| }d||dk< |d ur|| | ||< |S || | }|S )Nrh   r   rT   r   )r&   r   r   r   r   r   r   r   r   r   r'   r   isscalar)	ra   rd   subdivrp   r   Z
masked_img_sub_divrb   rb   rc   
_normalizea  s:   
zNormalizeIntensity._normalizec           	      C  s"  t |t d}| jp|j}t|}| jr}| jdur.t| j|kr.td| dt| j d| jdurHt| j|krHtd| dt| j d|jjsUt	|t
jd^}}t|D ]"\}}| j|| jdurj| j| nd| jduru| j| ndd||< qYn	| || j| j}t|||dd	 }|S )
zw
        Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True,
        rt   Nzimg has z channels, but subtrahend has z components.z channels, but divisor has rh   )r  r  r   )r(   r   rY   r   r   r  r   r  is_floating_pointr&   r   r   r   r  r'   )	ra   rd   r   rY   Zimg_lenrp   r   r   r   rb   rb   rc   rx     s(   zNormalizeIntensity.__call__)r  r  r  r  r   r[   r   r[   rY   r
   r\   r]   )NNr   )r{   r|   r}   r~   r   r   r   r   r   r   r_   staticmethodr   r   r  rx   rb   rb   rb   rc   r6   0  s    

"r6   c                   @  s2   e Zd ZdZejejgZddddZdddZ	dS )r7   a  
    Filter the intensity values of whole image to below threshold or above threshold.
    And fill the remaining parts of the image to the `cval` value.

    Args:
        threshold: the threshold to filter intensity values.
        above: filter values above the threshold or below the threshold, default is True.
        cval: value to fill the remaining parts of the image, default is 0.
    TrT   	thresholdrV   abover[   cvalr\   r]   c                 C  >   t |ttfstdt| d| d|| _|| _|| _d S )Nz-threshold must be a float or int number, got  r   )r   r   rV   r   r   r  r  r   )ra   r  r  r   rb   rb   rc   r_     
   
zThresholdIntensity.__init__rd   r   c                 C  sL   t |t d}| jr|| jkn|| jk }t||| j}t||jd^}}|S )rs   rt   rh   )r(   r   r  r  r   r   r&   rY   )ra   rd   maskresrp   rb   rb   rc   rx     s
   zThresholdIntensity.__call__N)TrT   )r  rV   r  r[   r   rV   r\   r]   r   r   rb   rb   rb   rc   r7     s
    
r7   c                   @  s<   e Zd ZdZejejgZdddej	fdddZ
dddZdS )r8   a  
    Apply specific intensity scaling to the whole numpy array.
    Scaling from [a_min, a_max] to [b_min, b_max] with clip option.

    When `b_min` or `b_max` are `None`, `scaled_array * (b_max - b_min) + b_min` will be skipped.
    If `clip=True`, when `b_min`/`b_max` is None, the clipping is not performed on the corresponding edge.

    Args:
        a_min: intensity original range min.
        a_max: intensity original range max.
        b_min: intensity target range min.
        b_max: intensity target range max.
        clip: whether to perform clip after scaling.
        dtype: output data type, if None, same as input image. defaults to float32.
    NFa_minrV   a_maxb_minre   b_maxr   r[   rY   r
   r\   r]   c                 C  s(   || _ || _|| _|| _|| _|| _d S r^   r&  r'  r(  r)  r   rY   )ra   r&  r'  r(  r)  r   rY   rb   rb   rc   r_     s   	
zScaleIntensityRange.__init__rd   r   c                 C  s   t |t d}| jp|j}| j| j dkr,tdt | jdu r$|| j S || j | j S || j | j| j  }| jdurL| jdurL|| j| j  | j }| j	rWt	|| j| j}t
||dd }|S )rs   rt   rT   zDivide by zero (a_min == a_max)Nrh   r   )r(   r   rY   r'  r&  r	   Warningr(  r)  r   r&   )ra   rd   rY   r   rb   rb   rc   rx     s   


zScaleIntensityRange.__call__)r&  rV   r'  rV   r(  re   r)  re   r   r[   rY   r
   r\   r]   r   r   rb   rb   rb   rc   r8     s    r8   c                   @  sF   e Zd ZdZejejgZdddej	fdddZ
dddZdddZdS )r9   a  
    Apply clip based on the intensity distribution of input image.
    If `sharpness_factor` is provided, the intensity values will be soft clipped according to
    f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv))
    From https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291

    Soft clipping preserves the order of the values and maintains the gradient everywhere.
    For example:

    .. code-block:: python
        :emphasize-lines: 11, 22

        image = torch.Tensor(
            [[[1, 2, 3, 4, 5],
              [1, 2, 3, 4, 5],
              [1, 2, 3, 4, 5],
              [1, 2, 3, 4, 5],
              [1, 2, 3, 4, 5],
              [1, 2, 3, 4, 5]]])

        # Hard clipping from lower and upper image intensity percentiles
        hard_clipper = ClipIntensityPercentiles(30, 70)
        print(hard_clipper(image))
        metatensor([[[2., 2., 3., 4., 4.],
                [2., 2., 3., 4., 4.],
                [2., 2., 3., 4., 4.],
                [2., 2., 3., 4., 4.],
                [2., 2., 3., 4., 4.],
                [2., 2., 3., 4., 4.]]])


        # Soft clipping from lower and upper image intensity percentiles
        soft_clipper = ClipIntensityPercentiles(30, 70, 10.)
        print(soft_clipper(image))
        metatensor([[[2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
         [2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
         [2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
         [2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
         [2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
         [2.0000, 2.0693, 3.0000, 3.9307, 4.0000]]])

    See Also:

        - :py:class:`monai.transforms.ScaleIntensityRangePercentiles`
    NFlowerre   uppersharpness_factorr   r[   return_clipping_valuesrY   r
   r\   r]   c                 C  s   |du r|du rt d|dur|dk s|dkrt d|dur,|dk s(|dkr,t d|dur<|dur<||k r<t d|durH|dkrHt d|| _|| _|| _|| _|rYg | _|| _|| _dS )	a  
        Args:
            lower: lower intensity percentile. In the case of hard clipping, None will have the same effect as 0 by
                not clipping the lowest input values. However, in the case of soft clipping, None and zero will have
                two different effects: None will not apply clipping to low values, whereas zero will still transform
                the lower values according to the soft clipping transformation. Please check for more details:
                https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291.
            upper: upper intensity percentile.  The same as for lower, but this time with the highest values. If we
                are looking to perform soft clipping, if None then there will be no effect on this side whereas if set
                to 100, the values will be passed via the corresponding clipping equation.
            sharpness_factor: if not None, the intensity values will be soft clipped according to
                f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv)).
                defaults to None.
            channel_wise: if True, compute intensity percentile and normalize every channel separately.
                default to False.
            return_clipping_values: whether to return the calculated percentiles in tensor meta information.
                If soft clipping and requested percentile is None, return None as the corresponding clipping
                values in meta information. Clipping values are stored in a list with each element corresponding
                to a channel if channel_wise is set to True. defaults to False.
            dtype: output data type, if None, same as input image. defaults to float32.
        Nz+lower or upper percentiles must be providedrT         Y@)Percentiles must be in the range [0, 100]z,upper must be greater than or equal to lowerr   z'sharpness_factor must be greater than 0)r   r,  r-  r.  r   clipping_valuesr/  rY   )ra   r,  r-  r.  r   r/  rY   rb   rb   rc   r_   )  s$   
z!ClipIntensityPercentiles.__init__rd   r   c                 C  s  | j d ur*| jd urt|| jnd }| jd urt|| jnd }t|| j ||| j}n&| jd ur5t|| jnt|d}| jd urEt|| jnt|d}t|||}| jry| j	|d u r\|n
t
|dre| n||d u rl|n
t
|dru| n|f t|dd}|S )Nr   d   r   Frt   )r.  r,  r   r-  r   rY   r   r/  r2  r   hasattrr   r(   )ra   rd   Zlower_percentileZupper_percentilerb   rb   rc   _clip[  s&   
  zClipIntensityPercentiles._clipc                   sl   t |t d}t |dd} jrt fdd|D }n j|d}t||dd } jr4 j|j	d< |S )	rs   rt   Fc                      g | ]} j |d qS r   )r5  r   r   rb   rc   r   ~      z5ClipIntensityPercentiles.__call__.<locals>.<listcomp>r   r   r   r2  )
r(   r   r   r   r   r5  r'   r/  r2  metara   rd   r   rb   r   rc   rx   w  s   z!ClipIntensityPercentiles.__call__)r,  re   r-  re   r.  re   r   r[   r/  r[   rY   r
   r\   r]   r   )r{   r|   r}   r~   r   r   r   r   r   r   r_   r5  rx   rb   rb   rb   rc   r9     s    .
2r9   c                   @  s4   e Zd ZdZejejgZddd
dZddddZ	dS )r:   aN  
    Changes image intensity with gamma transform. Each pixel/voxel intensity is updated as::

        x = ((x - min) / intensity_range) ^ gamma * intensity_range + min

    Args:
        gamma: gamma value to adjust the contrast as function.
        invert_image: whether to invert the image before applying gamma augmentation. If True, multiply all intensity
            values with -1 before the gamma transform and again after the gamma transform. This behaviour is mimicked
            from `nnU-Net <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
            <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
            function.
        retain_stats: if True, applies a scaling factor and an offset to all intensity values after gamma transform to
            ensure that the output intensity distribution has the same mean and standard deviation as the intensity
            distribution of the input. This behaviour is mimicked from `nnU-Net
            <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
            <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
            function.
    FgammarV   invert_imager[   retain_statsr\   r]   c                 C  r!  )Nz)gamma must be a float or int number, got r"  r   )r   r   rV   r   r   r<  r=  r>  )ra   r<  r=  r>  rb   rb   rc   r_     r#  zAdjustContrast.__init__Nrd   r   c           	      C  s   t |t d}|dur|n| j}| jr| }| jr!| }| }d}| }| | }|| t	||  | | | }| jrT||  }|| d  }|| | }| jrZ| }|S )zn
        Apply the transform to `img`.
        gamma: gamma value to adjust the contrast as function.
        rt   NgHz>g:0yE>)
r(   r   r<  r=  r>  rW   rX   r   r   rV   )	ra   rd   r<  r   sdepsilonimg_minZ	img_ranger   rb   rb   rc   rx     s$    zAdjustContrast.__call__)FF)r<  rV   r=  r[   r>  r[   r\   r]   r^   r   r   rb   rb   rb   rc   r:     s
    r:   c                      sJ   e Zd ZdZejZ				ddddZdd fddZdd ddZ  Z	S )!r;   a  
    Randomly changes image intensity with gamma transform. Each pixel/voxel intensity is updated as:

        x = ((x - min) / intensity_range) ^ gamma * intensity_range + min

    Args:
        prob: Probability of adjustment.
        gamma: Range of gamma values.
            If single number, value is picked from (0.5, gamma), default is (0.5, 4.5).
        invert_image: whether to invert the image before applying gamma augmentation. If True, multiply all intensity
            values with -1 before the gamma transform and again after the gamma transform. This behaviour is mimicked
            from `nnU-Net <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
            <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
            function.
        retain_stats: if True, applies a scaling factor and an offset to all intensity values after gamma transform to
            ensure that the output intensity distribution has the same mean and standard deviation as the intensity
            distribution of the input. This behaviour is mimicked from `nnU-Net
            <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
            <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
            function.
    rS         ?g      @FrU   rV   r<  r   r=  r[   r>  r\   r]   c                 C  s   t | | t|ttfr|dkrtd| d|f| _nt|dkr(tdt|t	|f| _d| _
|| _|| _t| j
| j| jd| _d S )NrC  zWif gamma is a number, must greater than 0.5 and value is picked from (0.5, gamma), got r   z,gamma should be a number or pair of numbers.r   )r=  r>  )r   r_   r   r   rV   r   r<  r   r   r   gamma_valuer=  r>  r:   adjust_contrast)ra   rU   r<  r=  r>  rb   rb   rc   r_     s    zRandAdjustContrast.__init__Nr   r   c                   r   r   )ri   rj   rk   rl   rm   r<  rD  r   rq   rb   rc   rj     r   zRandAdjustContrast.randomizeTrd   r   rj   c                 C  sD   t |t d}|r|   | js|S | jdu rtd| || jS )rs   rt   Nz?gamma_value is not set, please call `randomize` function first.)r(   r   rj   rk   rD  rw   rE  r   rb   rb   rc   rx     s   
zRandAdjustContrast.__call__)rS   rB  FF)
rU   rV   r<  r   r=  r[   r>  r[   r\   r]   r^   r   r   r   )
r{   r|   r}   r~   r:   r   r_   rj   rx   r   rb   rb   rq   rc   r;     s    r;   c                   @  s@   e Zd ZdZejZdddejfdddZdddZ	dddZ
dS )r<   a	  
    Apply range scaling to a numpy array based on the intensity distribution of the input.

    By default this transform will scale from [lower_intensity_percentile, upper_intensity_percentile] to
    `[b_min, b_max]`, where {lower,upper}_intensity_percentile are the intensity values at the corresponding
    percentiles of ``img``.

    The ``relative`` parameter can also be set to scale from [lower_intensity_percentile, upper_intensity_percentile]
    to the lower and upper percentiles of the output range [b_min, b_max].

    For example:

    .. code-block:: python
        :emphasize-lines: 11, 22

        image = torch.Tensor(
            [[[1, 2, 3, 4, 5],
              [1, 2, 3, 4, 5],
              [1, 2, 3, 4, 5],
              [1, 2, 3, 4, 5],
              [1, 2, 3, 4, 5],
              [1, 2, 3, 4, 5]]])

        # Scale from lower and upper image intensity percentiles
        # to output range [b_min, b_max]
        scaler = ScaleIntensityRangePercentiles(10, 90, 0, 200, False, False)
        print(scaler(image))
        metatensor([[[  0.,  50., 100., 150., 200.],
             [  0.,  50., 100., 150., 200.],
             [  0.,  50., 100., 150., 200.],
             [  0.,  50., 100., 150., 200.],
             [  0.,  50., 100., 150., 200.],
             [  0.,  50., 100., 150., 200.]]])


        # Scale from lower and upper image intensity percentiles
        # to lower and upper percentiles of the output range [b_min, b_max]
        rel_scaler = ScaleIntensityRangePercentiles(10, 90, 0, 200, False, True)
        print(rel_scaler(image))
        metatensor([[[ 20.,  60., 100., 140., 180.],
             [ 20.,  60., 100., 140., 180.],
             [ 20.,  60., 100., 140., 180.],
             [ 20.,  60., 100., 140., 180.],
             [ 20.,  60., 100., 140., 180.],
             [ 20.,  60., 100., 140., 180.]]])

    See Also:

        - :py:class:`monai.transforms.ScaleIntensityRange`

    Args:
        lower: lower intensity percentile.
        upper: upper intensity percentile.
        b_min: intensity target range min.
        b_max: intensity target range max.
        clip: whether to perform clip after scaling.
        relative: whether to scale to the corresponding percentiles of [b_min, b_max].
        channel_wise: if True, compute intensity percentile and normalize every channel separately.
            default to False.
        dtype: output data type, if None, same as input image. defaults to float32.
    Fr,  rV   r-  r(  re   r)  r   r[   r   r   rY   r
   r\   r]   c	           	      C  sd   |dk s|dkrt d|dk s|dkrt d|| _|| _|| _|| _|| _|| _|| _|| _d S )NrT   r0  r1  )	r   r,  r-  r(  r)  r   r   r   rY   )	ra   r,  r-  r(  r)  r   r   r   rY   rb   rb   rc   r_   U  s   
z'ScaleIntensityRangePercentiles.__init__rd   r   c                 C  s   t || j}t || j}| j}| j}| jr?| jd u s| jd u r#td| j| j | jd  | j }| j| j | jd  | j }t||||| j| j	d}||}t
|dd}|S )Nz6If it is relative, b_min and b_max should not be None.r0  r*  Frt   )r   r,  r-  r(  r)  r   r   r8   r   rY   r(   )ra   rd   r&  r'  r(  r)  scalarrb   rb   rc   r  m  s   z)ScaleIntensityRangePercentiles._normalizec                   sZ   t |t d}t |dd} jrt fdd|D }n j|d}t|| jdd S )rs   rt   Fc                   r6  r7  )r  r   r   rb   rc   r     r8  z;ScaleIntensityRangePercentiles.__call__.<locals>.<listcomp>r   r   r   )r(   r   r   r   r   r  r'   rY   r;  rb   r   rc   rx     s   z'ScaleIntensityRangePercentiles.__call__N)r,  rV   r-  rV   r(  re   r)  re   r   r[   r   r[   r   r[   rY   r
   r\   r]   r   )r{   r|   r}   r~   r8   r   r   r   r_   r  rx   rb   rb   rb   rc   r<     s    >
r<   c                   @  s8   e Zd ZdZejejgZdefdd	d
Z	ddddZ
dS )r=   aP  
    Mask the intensity values of input image with the specified mask data.
    Mask data must have the same spatial size as the input image, and all
    the intensity values of input image corresponding to the selected values
    in the mask data will keep the original value, others will be set to `0`.

    Args:
        mask_data: if `mask_data` is single channel, apply to every channel
            of input image. if multiple channels, the number of channels must
            match the input data. the intensity values of input image corresponding
            to the selected values in the mask data will keep the original value,
            others will be set to `0`. if None, must specify the `mask_data` at runtime.
        select_fn: function to select valid values of the `mask_data`, default is
            to select `values > 0`.

    N	mask_dataNdarrayOrTensor | None	select_fnr   r\   r]   c                 C  r   r^   )rG  rI  )ra   rG  rI  rb   rb   rc   r_     r   zMaskIntensity.__init__rd   r   c                 C  s   t |t d}|du r| jn|}|du rtdt||d^}}| |}|jd dkrG|jd |jd krGtd|jd  d|jd  d	t|| |d
d S )a&  
        Args:
            mask_data: if mask data is single channel, apply to every channel
                of input image. if multiple channels, the channel number must
                match input data. mask_data will be converted to `bool` values
                by `mask_data > 0` before applying transform to input image.

        Raises:
            - ValueError: When both ``mask_data`` and ``self.mask_data`` are None.
            - ValueError: When ``mask_data`` and ``img`` channels differ and ``mask_data`` is not single channel.

        rt   NzImust provide the mask_data when initializing the transform or at runtime.r	  r   r   r   zZWhen mask_data is not single channel, mask_data channels must match img, got img channels=z mask_data channels=r   r9  )r(   r   rG  r   r'   rI  ro   )ra   rd   rG  Z
mask_data_rp   rb   rb   rc   rx     s   
"zMaskIntensity.__call__)rG  rH  rI  r   r\   r]   r^   )rd   r   rG  rH  r\   r   )r{   r|   r}   r~   r   r   r   r   r   r_   rx   rb   rb   rb   rc   r=     s
    r=   c                   @  s.   e Zd ZdZejgZddd
dZdddZdS )r?   aQ  
    Smooth the input data along the given axis using a Savitzky-Golay filter.

    Args:
        window_length: Length of the filter window, must be a positive odd integer.
        order: Order of the polynomial to fit to each window, must be less than ``window_length``.
        axis: Optional axis along which to apply the filter kernel. Default 1 (first spatial dimension).
        mode: Optional padding mode, passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'``
            or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information.
    r   r   window_lengthr   orderr  modestrc                 C  s8   |dk rt d|| _|| _|| _|| _td| _d S )Nr   axis must be zero or positive.rT   )r   rK  rL  r  rM  r   r   r   )ra   rK  rL  r  rM  rb   rb   rc   r_     s   zSavitzkyGolaySmooth.__init__rd   r   r\   c                 C  s`   t |t d}t |dd| _t| j| j| jd | j}|| jd	d}t
||d^}}|S )z
        Args:
            img: array containing input data. Must be real and in shape [channels, spatial1, spatial2, ...].

        Returns:
            array containing smoothed result.

        rt   Fr   r   r9  )r(   r   r   r   rK  rL  r  rM  	unsqueezesqueezer'   )ra   rd   Zsavgol_filterZsmoothedr   rp   rb   rb   rc   rx     s   	zSavitzkyGolaySmooth.__call__N)r   r   )rK  r   rL  r   r  r   rM  rN  r   	r{   r|   r}   r~   r   r   r   r_   rx   rb   rb   rb   rc   r?     s
    
r?   c                   @  s.   e Zd ZdZejgZddd
dZdddZdS )r>   aU  
    Find the envelope of the input data along the requested axis using a Hilbert transform.

    Args:
        axis: Axis along which to detect the envelope. Default 1, i.e. the first spatial dimension.
        n: FFT size. Default img.shape[axis]. Input will be zero-padded or truncated to this size along dimension
        ``axis``.

    r   Nr  r   n
int | Noner\   r]   c                 C  s    |dk rt d|| _|| _d S )Nr   rO  )r   r  rS  )ra   r  rS  rb   rb   rc   r_     s   
zDetectEnvelope.__init__rd   r   c                 C  s\   t |t d}t|tj^}}t| jd | j}||d	d
 }t||d^}}|S )z

        Args:
            img: numpy.ndarray containing input data. Must be real and in shape [channels, spatial1, spatial2, ...].

        Returns:
            np.ndarray containing envelope of data in img along the specified axis.

        rt   r   r   rJ  )r(   r   r&   r   r   r   r  rS  rP  rQ  absr'   )ra   rd   r   rp   Zhilbert_transformr   rb   rb   rc   rx     s   
zDetectEnvelope.__call__)r   N)r  r   rS  rT  r\   r]   )rd   r   rR  rb   rb   rb   rc   r>     s
    
r>   c                   @  s.   e Zd ZdZejgZddddZdddZdS )r@   a  
    Apply median filter to the input data based on specified `radius` parameter.
    A default value `radius=1` is provided for reference.

    See also: :py:func:`monai.networks.layers.median_filter`

    Args:
        radius: if a list of values, must match the count of spatial dimensions of input data,
            and apply every value in the list to 1 spatial dimension. if only 1 value provided,
            use it for all spatial dimensions.
    r   radiusSequence[int] | intr\   r]   c                 C  s
   || _ d S r^   )rV  )ra   rV  rb   rb   rc   r_   )  s   
zMedianSmooth.__init__rd   r   c           	      C  sf   t |t d}t|tjtjd^}}|jd }t| j|}t	||d}||}t
|||jd^}}|S )Nrt   rh   r   )spatial_dimsr   )r(   r   r&   r   r   rV   ndimr!   rV  r   r'   rY   )	ra   rd   r   rp   rX  rZmedian_filter_instanceout_tr   rb   rb   rc   rx   ,  s   
zMedianSmooth.__call__N)r   )rV  rW  r\   r]   rd   r   r\   r   rR  rb   rb   rb   rc   r@     s
    r@   c                   @  s.   e Zd ZdZejgZddd
dZdddZdS )rA   a:  
    Apply Gaussian smooth to the input data based on specified `sigma` parameter.
    A default value `sigma=1.0` is provided for reference.

    Args:
        sigma: if a list of values, must match the count of spatial dimensions of input data,
            and apply every value in the list to 1 spatial dimension. if only 1 value provided,
            use it for all spatial dimensions.
        approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".
            see also :py:meth:`monai.networks.layers.GaussianFilter`.

    r   erfsigmar   approxrN  r\   r]   c                 C  r   r^   r^  r_  )ra   r^  r_  rb   rb   rc   r_   G  r   zGaussianSmooth.__init__rd   r   c                   s   t |t d}t|tjtjd^ }t| jtr# fdd| jD }n	tj	| j j
d}t jd || jd}| dd}t|||jd	^}}|S )
Nrt   rh   c                   s   g | ]
}t j| jd qS )r   )r   	as_tensorr   )r   sr   rb   rc   r   P      z+GaussianSmooth.__call__.<locals>.<listcomp>r   r   r_  r   r   )r(   r   r&   r   r   rV   r   r^  r   ra  r   r   rY  r_  rP  rQ  r'   rY   )ra   rd   rp   r^  gaussian_filterr[  r   rb   rc  rc   rx   K  s   zGaussianSmooth.__call__N)r   r]  )r^  r   r_  rN  r\   r]   r\  rR  rb   rb   rb   rc   rA   7  s
    rA   c                      sL   e Zd ZdZejZ					ddddZdd  fddZd!d"ddZ  Z	S )#rB   aD  
    Apply Gaussian smooth to the input data based on randomly selected `sigma` parameters.

    Args:
        sigma_x: randomly select sigma value for the first spatial dimension.
        sigma_y: randomly select sigma value for the second spatial dimension if have.
        sigma_z: randomly select sigma value for the third spatial dimension if have.
        prob: probability of Gaussian smooth.
        approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".
            see also :py:meth:`monai.networks.layers.GaussianFilter`.

    g      ?g      ?rS   r]  sigma_xr   sigma_ysigma_zrU   rV   r_  rN  r\   r]   c                 C  sL   t | | || _|| _|| _|| _| jd | _| jd | _| jd | _d S )Nr   )	r   r_   rh  ri  rj  r_  r  yz)ra   rh  ri  rj  rU   r_  rb   rb   rc   r_   j  s   zRandGaussianSmooth.__init__Nr   r   c                   st   t  d  | jsd S | jj| jd | jd d| _| jj| jd | jd d| _| jj| j	d | j	d d| _
d S r   )ri   rj   rk   rl   rm   rh  r  ri  rk  rj  rl  r   rq   rb   rc   rj   |  s   "zRandGaussianSmooth.randomizeTrd   r   rj   r[   c                 C  sT   t |t d}|r|   | js|S t| j| j| jf|jd d}t	|| j
d|S )Nrt   r   valsr   r`  )r(   r   rj   rk   r"   r  rk  rl  rY  rA   r_  )ra   rd   rj   r^  rb   rb   rc   rx     s   zRandGaussianSmooth.__call__)rg  rg  rg  rS   r]  )rh  r   ri  r   rj  r   rU   rV   r_  rN  r\   r]   r^   r   r   r   )
r{   r|   r}   r~   rA   r   r_   rj   rx   r   rb   rb   rq   rc   rB   Z  s    rB   c                   @  s6   e Zd ZdZejgZ				ddddZdddZdS )rC   a&  
    Sharpen images using the Gaussian Blur filter.
    Referring to: http://scipy-lectures.org/advanced/image_processing/auto_examples/plot_sharpen.html.
    The algorithm is shown as below

    .. code-block:: python

        blurred_f = gaussian_filter(img, sigma1)
        filter_blurred_f = gaussian_filter(blurred_f, sigma2)
        img = blurred_f + alpha * (blurred_f - filter_blurred_f)

    A set of default values `sigma1=3.0`, `sigma2=1.0` and `alpha=30.0` is provide for reference.

    Args:
        sigma1: sigma parameter for the first gaussian kernel. if a list of values, must match the count
            of spatial dimensions of input data, and apply every value in the list to 1 spatial dimension.
            if only 1 value provided, use it for all spatial dimensions.
        sigma2: sigma parameter for the second gaussian kernel. if a list of values, must match the count
            of spatial dimensions of input data, and apply every value in the list to 1 spatial dimension.
            if only 1 value provided, use it for all spatial dimensions.
        alpha: weight parameter to compute the final result.
        approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".
            see also :py:meth:`monai.networks.layers.GaussianFilter`.

          @r         >@r]  sigma1r   sigma2alpharV   r_  rN  r\   r]   c                 C  r   r^   rq  rr  rs  r_  )ra   rq  rr  rs  r_  rb   rb   rc   r_     s   
zGaussianSharpen.__init__rd   r   c           	        s   t |t d}t|tjtjd^ } fddjjfD \}}| d}||}|j	||   
d}t|||jd^}}|S )Nrt   rh   c                 3  s.    | ]}t  jd  |jd jV  qdS )r   re  N)r   rY  r_  tor   )r   r^  r   ra   rb   rc   	<genexpr>  s
    
z+GaussianSharpen.__call__.<locals>.<genexpr>r   r   )r(   r   r&   r   r   r   rq  rr  rP  rs  rQ  r'   rY   )	ra   rd   rp   Zgf1Zgf2Z	blurred_fZfilter_blurred_fr[  r   rb   rv  rc   rx     s   

zGaussianSharpen.__call__N)ro  r   rp  r]  )
rq  r   rr  r   rs  rV   r_  rN  r\   r]   r\  rR  rb   rb   rb   rc   rC     s    rC   c                      sT   e Zd ZdZejZ									d$d%ddZd&d' fddZd(d)d"d#Z  Z	S )*rD   a  
    Sharpen images using the Gaussian Blur filter based on randomly selected `sigma1`, `sigma2` and `alpha`.
    The algorithm is :py:class:`monai.transforms.GaussianSharpen`.

    Args:
        sigma1_x: randomly select sigma value for the first spatial dimension of first gaussian kernel.
        sigma1_y: randomly select sigma value for the second spatial dimension(if have) of first gaussian kernel.
        sigma1_z: randomly select sigma value for the third spatial dimension(if have) of first gaussian kernel.
        sigma2_x: randomly select sigma value for the first spatial dimension of second gaussian kernel.
            if only 1 value `X` provided, it must be smaller than `sigma1_x` and randomly select from [X, sigma1_x].
        sigma2_y: randomly select sigma value for the second spatial dimension(if have) of second gaussian kernel.
            if only 1 value `Y` provided, it must be smaller than `sigma1_y` and randomly select from [Y, sigma1_y].
        sigma2_z: randomly select sigma value for the third spatial dimension(if have) of second gaussian kernel.
            if only 1 value `Z` provided, it must be smaller than `sigma1_z` and randomly select from [Z, sigma1_z].
        alpha: randomly select weight parameter to compute the final result.
        approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".
            see also :py:meth:`monai.networks.layers.GaussianFilter`.
        prob: probability of Gaussian sharpen.

    rC  r   rC  g      $@rp  r]  rS   sigma1_xr   sigma1_ysigma1_zsigma2_xr   sigma2_ysigma2_zrs  r_  rN  rU   rV   r\   r]   c
           
      C  sj   t | |	 || _|| _|| _|| _|| _|| _|| _|| _	d | _
d | _d | _d | _d | _d | _d | _d S r^   )r   r_   rz  r{  r|  r}  r~  r  rs  r_  x1y1z1x2y2z2a)
ra   rz  r{  r|  r}  r~  r  rs  r_  rU   rb   rb   rc   r_     s    
zRandGaussianSharpen.__init__Nr   r   c                   s:  t  d  | jsd S | jj| jd | jd d| _| jj| jd | jd d| _| jj| j	d | j	d d| _
t| jtsD| j| jfn| j}t| jtsS| j| jfn| j}t| jtsb| j| j
fn| j}| jj|d |d d| _| jj|d |d d| _| jj|d |d d| _| jj| jd | jd d| _d S r   )ri   rj   rk   rl   rm   rz  r  r{  r  r|  r  r   r}  r   r~  r  r  r  r  rs  r  )ra   r   r}  r~  r  rq   rb   rc   rj     s   "zRandGaussianSharpen.randomizeTrd   r   rj   r[   c                 C  s   t |t d}|r|   | js|S | jd u s&| jd u s&| jd u s&| jd u r*tdt	| j
| j| jf|jd d}t	| j| j| jf|jd d}t||| j| jd|S )Nrt   rv   r   rm  rt  )r(   r   rj   rk   r  r  r  r  rw   r"   r  r  r  rY  rC   r_  )ra   rd   rj   rq  rr  rb   rb   rc   rx     s   (zRandGaussianSharpen.__call__)	rx  rx  rx  rC  rC  rC  ry  r]  rS   )rz  r   r{  r   r|  r   r}  r   r~  r   r  r   rs  r   r_  rN  rU   rV   r\   r]   r^   r   r   r   )
r{   r|   r}   r~   rC   r   r_   rj   rx   r   rb   rb   rq   rc   rD     s    rD   c                      sR   e Zd ZdZejejgZddd
dZdddZ	d d! fddZ
d"d#ddZ  ZS )$rE   a  
    Apply random nonlinear transform to the image's intensity histogram.

    Args:
        num_control_points: number of control points governing the nonlinear intensity mapping.
            a smaller number of control points allows for larger intensity shifts. if two values provided, number of
            control points selecting from range (min_value, max_value).
        prob: probability of histogram shift.
    
   rS   num_control_pointstuple[int, int] | intrU   rV   r\   r]   c                 C  sx   t | | t|tr|dkrtd||f| _nt|dkr#tdt|dkr-tdt|t|f| _|  |  d S )Nr   z7num_control_points should be greater than or equal to 3z:num_control points should be a number or a pair of numbers)	r   r_   r   r   r   r  r   r   r   )ra   r  rU   rb   rb   rc   r_   (  s   
zRandHistogramShift.__init__r  r   xpfpc           	      C  s   t |tjrtnt}t |tjrt|||S |dd  |d d  |dd  |d d   }|d d ||d d   }||d|dd }||dt	|d }|| |d ||  |j
}|d |||d k < |d |||d k< |S )Nr   r   )r   r   r   r   r   interpsearchsortedreshaper   r   ro   )	ra   r  r  r  nsmbindicesfrb   rb   rc   r  8  s   0"zRandHistogramShift.interpNr   r   c                   s   t  d  | jsd S | j| jd | jd d }tdd|| _t	| j| _
td|d D ]}| j| j
|d  | j
|d  | j
|< q0d S )Nr   r   )ri   rj   rk   rl   randintr  r   r   reference_control_pointsr   floating_control_pointsr   rm   )ra   r   Znum_control_pointr   rq   rb   rc   rj   I  s   zRandHistogramShift.randomizeTrd   rj   r[   c                 C  s   t |t d}|r|   | js|S | jd u s| jd u r tdt |dd}| | }}||kr=t	d| d |S t
| j|d^}}t
| j|d^}}|||  | }	|||  | }
| ||	|
}t
||dd S )Nrt   rv   Fz(The image's intensity is a single value zD. The original image is simply returned, no histogram shift is done.r9  r   )r(   r   rj   rk   r  r  rw   r   r   r	   r'   r  )ra   rd   rj   r   rA  Zimg_maxr  rp   ypZreference_control_points_scaledZfloating_control_points_scaledrb   rb   rc   rx   U  s(   
zRandHistogramShift.__call__)r  rS   )r  r  rU   rV   r\   r]   )r  r   r  r   r  r   r\   r   r^   r   r   r   )r{   r|   r}   r~   r   r   r   r   r_   r  rj   rx   r   rb   rb   rq   rc   rE     s    

rE   c                   @  s<   e Zd ZdZejejgZddddZdddZ	dddZ
dS )rF   a  
    The transform applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts
    are one of the common type of type artifacts appearing in MRI scans.

    The transform is applied to all the channels in the data.

    For general information on Gibbs artifacts, please refer to:

    `An Image-based Approach to Understanding the Physics of MR Artifacts
    <https://pubs.rsna.org/doi/full/10.1148/rg.313105115>`_.

    `The AAPM/RSNA Physics Tutorial for Residents
    <https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949>`_

    Args:
        alpha: Parametrizes the intensity of the Gibbs noise filter applied. Takes
            values in the interval [0,1] with alpha = 0 acting as the identity mapping.
    rS   rs  rV   r\   r]   c                 C  s"   |dks|dk rt d|| _d S )Nr   r   z.alpha must take values in the interval [0, 1].)r   rs  )ra   rs  rb   rb   rc   r_     s   
zGibbsNoise.__init__rd   r   c                 C  sf   t |t d}t |dd}t|jdd  }| ||}| |}| ||}t|||jd^}}|S )Nrt   Fr   r   )	r(   r   r   ro   shift_fourier_apply_maskinv_shift_fourierr'   rY   )ra   rd   r   n_dimsr   r   rp   rb   rb   rc   rx     s   
zGibbsNoise.__call__r   c                 C  s   |j dd }d| j t| td d }t|d d }tjtdd |D  }dd t||D }tt	|}||k}tj
|d |j d	 d	d
}t|tjr`t|tj|jd^}}	|| }
|
S )zBuilds and applies a mask on the spatial dimensions.

        Args:
            k: k-space version of the image.
        Returns:
            masked version of the k-space image.
        r   Nr          @c                 s  s    | ]}t d |V  qdS r   N)slicer   r   rb   rb   rc   rw    s    z)GibbsNoise._apply_mask.<locals>.<genexpr>c                 S  s   g | ]
\}}|| d  qS )r   rb   )r   coordcrb   rb   rc   r     rd  z*GibbsNoise._apply_mask.<locals>.<listcomp>r   r  r   )ro   rs  r   r   r   arrayogridtuplezipsumrepeatr   r   r   r&   r   )ra   r   ro   rZ  centerr   Zcoords_from_center_sqZdist_from_centerr$  rp   Zk_maskedrb   rb   rc   r    s   "zGibbsNoise._apply_maskN)rS   )rs  rV   r\   r]   r   )r   r   r\   r   )r{   r|   r}   r~   r   r   r   r   r_   rx   r  rb   rb   rb   rc   rF   o  s    
rF   c                      s@   e Zd ZdZejZddd
dZd fddZddddZ  Z	S )rG   a  
    Naturalistic image augmentation via Gibbs artifacts. The transform
    randomly applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts
    are one of the common type of type artifacts appearing in MRI scans.

    The transform is applied to all the channels in the data.

    For general information on Gibbs artifacts, please refer to:
    https://pubs.rsna.org/doi/full/10.1148/rg.313105115
    https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949


    Args:
        prob (float): probability of applying the transform.
        alpha (float, Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes
            values in the interval [0,1] with alpha = 0 acting as the identity mapping.
            If a length-2 list is given as [a,b] then the value of alpha will be
            sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1.
            If a float is given, then the value of alpha will be sampled uniformly from the interval [0, alpha].
    rS   rT   r   rU   rV   rs  float | Sequence[float]r\   r]   c                 C  s   t |tr	d|f}t|}t|dkrtd|d dks#|d dk r'td|d |d kr3td|| _d| _tj| |d d S )	Nr   r   zalpha length must be 2.r   z-alpha must take values in the interval [0, 1]z!When alpha = [a,b] we need a < b.r   rU   )	r   rV   r    r   r   rs  sampled_alphar   r_   )ra   rU   rs  rb   rb   rc   r_     s   
zRandGibbsNoise.__init__r   r   c                   s6   t  d | jsdS | j| jd | jd | _dS )zr
        (1) Set random variable to apply the transform.
        (2) Get alpha from uniform distribution.
        Nr   r   )ri   rj   rk   rl   rm   rs  r  r   rq   rb   rc   rj     s    zRandGibbsNoise.randomizeTrd   r   rj   r[   c                 C  s4   t |t d}|r| d  | js|S t| j|S )Nrt   )r(   r   rj   rk   rF   r  r   rb   rb   rc   rx     s   
zRandGibbsNoise.__call__)rS   r  )rU   rV   rs  r  r\   r]   )r   r   r\   r]   r   rd   r   rj   r[   )
r{   r|   r}   r~   rF   r   r_   rj   rx   r   rb   rb   rq   rc   rG     s    
rG   c                   @  sF   e Zd ZdZejejgZddddZdddZ	dddZ
dddZdS )rH   a  
    Apply localized spikes in `k`-space at the given locations and intensities.
    Spike (Herringbone) artifact is a type of data acquisition artifact which
    may occur during MRI scans.

    For general information on spike artifacts, please refer to:

    `AAPM/RSNA physics tutorial for residents: fundamental physics of MR imaging
    <https://pubmed.ncbi.nlm.nih.gov/16009826>`_.

    `Body MRI artifacts in clinical practice: A physicist's and radiologist's
    perspective <https://doi.org/10.1002/jmri.24288>`_.

    Args:
        loc: spatial location for the spikes. For
            images with 3D spatial dimensions, the user can provide (C, X, Y, Z)
            to fix which channel C is affected, or (X, Y, Z) to place the same
            spike in all channels. For 2D cases, the user can provide (C, X, Y)
            or (X, Y).
        k_intensity: value for the log-intensity of the
            `k`-space version of the image. If one location is passed to ``loc`` or the
            channel is not specified, then this argument should receive a float. If
            ``loc`` is given a sequence of locations, then this argument should
            receive a sequence of intensities. This value should be tested as it is
            data-dependent. The default values are the 2.5 the mean of the
            log-intensity for each channel.

    Example:
        When working with 4D data, ``KSpaceSpikeNoise(loc = ((3,60,64,32), (64,60,32)), k_intensity = (13,14))``
        will place a spike at `[3, 60, 64, 32]` with `log-intensity = 13`, and
        one spike per channel located respectively at `[: , 64, 60, 32]`
        with `log-intensity = 14`.
    Nloctuple | Sequence[tuple]k_intensitySequence[float] | float | Nonec                 C  s   t || _|| _t|tr$t|d tstdt|t|kr$tdt| jd tr:|d ur<t| jts>tdd S d S d S )Nr   zZIf a sequence is passed to k_intensity, then a sequence of locations must be passed to loczJThere must be one intensity_factor value for each tuple of indices in loc.)r    r  r  r   r   r   r   )ra   r  r  rb   rb   rc   r_     s   

$zKSpaceSpikeNoise.__init__rd   r   r\   c                 C  s  t |t d}| | t|jdk rtdt| jd tr1t|jdkr1t| jdkr1tdt| jd t	rNt|jdkrNt
tt| jdkrNtdt|jdd	 }| ||}t|tjretnt}|||d
 }||}| j}|d	u rt|j|tt| ddd }t| jd t	rt| jt|D ]\}}	| |||	 qn| || j| |||d|  }t| |||d^}}
|S )zX
        Args:
            img: image with dimensions (C, H, W) or (C, H, W, D)
        rt   r   z Image needs a channel direction.r      r   zCInput images of dimension 4 need location tuple to be length 3 or 4r   N绽|=r        @y              ?r9  )r(   r   _check_indicesr   ro   rw   r   r  r   r   r   mapr  r   r   r   logrU  angler  r  rW   r   r  r    
_set_spiker
  r'   r  )ra   rd   r  r   liblog_absphaser  idxvalrp   rb   rb   rc   rx   -  s0   
,2
"zKSpaceSpikeNoise.__call__r]   c                   s   t | j}t|d ts|g}tt|D ] t|  t|jk r-dgt |   | < qtt|jD ] |j  t fdd|D krStd  d| j dq5dS )zHelper method to check consistency of self.loc and input image.

        Raises assertion error if any index in loc is out of bounds.r   c                 3  s    | ]}|  V  qd S r^   rb   )r   r  r   rb   rc   rw  c  s    z2KSpaceSpikeNoise._check_indices.<locals>.<genexpr>zThe index value at position z of one of the tuples in loc = z$ is out of bounds for current image.N)	listr  r   r   r   r   ro   r   r   )ra   rd   r  rb   r  rc   r  V  s   
 zKSpaceSpikeNoise._check_indicesr   r  r  r  r   c                 C  s   t |jt |krt|tr||d  n|||< dS t |jdkr9t |dkr9||dd|d |d |d f< dS t |jdkrUt |dkrW||dd|d |d f< dS dS dS )z
        Helper function to introduce a given intensity at given location.

        Args:
            k: intensity array to alter.
            idx: index of location where to apply change.
            val: value of intensity to write in.
        r   r  r   Nr   r   )r   ro   r   r   )ra   r   r  r  rb   rb   rc   r  h  s   	"$zKSpaceSpikeNoise._set_spiker^   )r  r  r  r  r   )r\   r]   )r   r   r  r  r  r   )r{   r|   r}   r~   r   r   r   r   r_   rx   r  r  rb   rb   rb   rc   rH     s    "

)rH   c                      s^   e Zd ZdZejZ			dd fddZddddZd  fddZd!ddZ	d"ddZ
  ZS )#rI   aF  
    Naturalistic data augmentation via spike artifacts. The transform applies
    localized spikes in `k`-space, and it is the random version of
    :py:class:`monai.transforms.KSpaceSpikeNoise`.

    Spike (Herringbone) artifact is a type of data acquisition artifact which
    may occur during MRI scans. For general information on spike artifacts,
    please refer to:

    `AAPM/RSNA physics tutorial for residents: fundamental physics of MR imaging
    <https://pubmed.ncbi.nlm.nih.gov/16009826>`_.

    `Body MRI artifacts in clinical practice: A physicist's and radiologist's
    perspective <https://doi.org/10.1002/jmri.24288>`_.

    Args:
        prob: probability of applying the transform, either on all
            channels at once, or channel-wise if ``channel_wise = True``.
        intensity_range: pass a tuple (a, b) to sample the log-intensity from the interval (a, b)
            uniformly for all channels. Or pass sequence of intervals
            ((a0, b0), (a1, b1), ...) to sample for each respective channel.
            In the second case, the number of 2-tuples must match the number of channels.
            Default ranges is `(0.95x, 1.10x)` where `x` is the mean
            log-intensity for each channel.
        channel_wise: treat each channel independently. True by
            default.

    Example:
        To apply `k`-space spikes randomly with probability 0.5, and
        log-intensity sampled from the interval [11, 12] for each channel
        independently, one uses
        ``RandKSpaceSpikeNoise(prob=0.5, intensity_range=(11, 12), channel_wise=True)``
    rS   NTrU   rV   intensity_range(Sequence[Sequence[float] | float] | Noner   r[   c                   sJ   || _ || _g | _g | _|d urt|d tr|stdt | d S )Nr   zSWhen channel_wise = False, intensity_range should be a 2-tuple (low, high) or None.)	r  r   sampled_k_intensitysampled_locsr   r   r   ri   r_   )ra   rU   r  r   rq   rb   rc   r_     s   zRandKSpaceSpikeNoise.__init__rd   r   rj   c                 C  s   | j durt| j d trt| j |jd krtdt|t d}g | _g | _	|r5| 
|}| || | js:|S t| j	| j|S )z
        Apply transform to `img`. Assumes data is in channel-first form.

        Args:
            img: image with dimensions (C, H, W) or (C, H, W, D)
        Nr   ziIf intensity_range is a sequence of sequences, then there must be one (low, high) tuple for each channel.rt   )r  r   r   r   ro   rw   r(   r   r  r  _make_sequencerj   rk   rH   )ra   rd   rj   r  rb   rb   rc   rx     s    
	
zRandKSpaceSpikeNoise.__call__Sequence[Sequence[float]]r\   r]   c                   s  t  d  jsdS  jr?t|D ]*\}} j|ft fdd|jD    j	 j
|| d || d  qdS t fdd|jdd D fddt|jd D  _t|d trq fd	d|D  _	dS  j
|d |d gt|  _	dS )
a  
        Helper method to sample both the location and intensity of the spikes.
        When not working channel wise (channel_wise=False) it use the random
        variable ``self._do_transform`` to decide whether to sample a location
        and intensity.

        When working channel wise, the method randomly samples a location and
        intensity for each channel depending on ``self._do_transform``.
        Nc                 3      | ]
} j d |V  qdS r  rl   r  r  r   rb   rc   rw        z1RandKSpaceSpikeNoise.randomize.<locals>.<genexpr>r   r   c                 3  r  r  r  r  r   rb   rc   rw    r  c                   s   g | ]}|f  qS rb   rb   r  )spatialrb   rc   r     s    z2RandKSpaceSpikeNoise.randomize.<locals>.<listcomp>c                   s"   g | ]} j |d  |d qS )r   r   )rl   rm   )r   pr   rb   rc   r     r   )ri   rj   rk   r   r   r  r   r  ro   r  rl   rm   r   r   r   r   )ra   rd   r  r   chanrq   )ra   r  rc   rj     s   
&( &zRandKSpaceSpikeNoise.randomizer  c                 C  sD   | j du r
| |S t| j d tst| j f|jd  S t| j S )zZ
        Formats the sequence of intensities ranges to Sequence[Sequence[float]].
        Nr   )r  _set_default_ranger   r   r    ro   )ra   r  rb   rb   rc   r    s
   


z#RandKSpaceSpikeNoise._make_sequencec                 C  s   t |jdd }| ||}t|tjrtnt}|||d }|	|t
t| dd }t|tjr<|d}t
dd |D S )	zr
        Sets default intensity ranges to be sampled.

        Args:
            img: image to transform.
        r   Nr  r   r  cpuc                 s  s     | ]}|d  |d fV  qdS )gffffff?g?Nrb   r  rb   rb   rc   rw  	  s    z:RandKSpaceSpikeNoise._set_default_range.<locals>.<genexpr>)r   ro   r  r   r   r   r   r  absoluterW   r  r   ru  )ra   rd   r  r   modr  Zshifted_meansrb   rb   rc   r    s   
z'RandKSpaceSpikeNoise._set_default_range)rS   NT)rU   rV   r  r  r   r[   r   r  )rd   r   r  r  r\   r]   )r  r   r\   r  )rd   r   r\   r  )r{   r|   r}   r~   rH   r   r_   rx   rj   r  r  r   rb   rb   rq   rc   rI   y  s    "
rI   c                      sV   e Zd ZdZejgZ			d d!ddZd" fddZe	d#ddZ
d$d%ddZ  ZS )&rJ   a  
    Randomly select coarse regions in the image, then execute transform operations for the regions.
    It's the base class of all kinds of region transforms.
    Refer to papers: https://arxiv.org/abs/1708.04552

    Args:
        holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to
            randomly select the expected number of regions.
        spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg
            as the minimum spatial size to randomly select size for every region.
            if some components of the `spatial_size` are non-positive values, the transform will use the
            corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        max_holes: if not None, define the maximum number to randomly select the expected number of regions.
        max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region.
            if some components of the `max_spatial_size` are non-positive values, the transform will use the
            corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        prob: probability of applying the transform.

    NrS   holesr   spatial_sizerW  	max_holesrT  max_spatial_sizeSequence[int] | int | NonerU   rV   r\   r]   c                 C  s>   t | | |dk rtd|| _|| _|| _|| _g | _d S )Nr   z'number of holes must be greater than 0.)r   r_   r   r  r  r  r  hole_coords)ra   r  r  r  r  rU   rb   rb   rc   r_   	  s   
zRandCoarseTransform.__init__r  r   c                   s   t  d  jsd S tj|g _jd u rjn
j	jjd }t
|D ]3}jd urItj| t fddt
t|D t|}jtd ft||j  q+d S )Nr   c                 3  s,    | ]}j j|  | d  dV  qdS )r   r   Nr  r  max_sizera   rg   rb   rc   rw  8	  s   * z0RandCoarseTransform.randomize.<locals>.<genexpr>)ri   rj   rk   r#   r  r  r  r  rl   r  r   r  r  r   r   r   r  r   )ra   r  Z	num_holesrp   
valid_sizerq   r  rc   rj   .	  s   &
"
"zRandCoarseTransform.randomizerd   
np.ndarrayc                 C  s   t d| jj d)zV
        Transform the randomly selected `self.hole_coords` in input images.

        z	Subclass z must implement this method.)r   rr   r{   ra   rd   rb   rb   rc   _transform_holes<	  s   z$RandCoarseTransform._transform_holesTr   rj   r[   c                 C  s`   t |t d}|r| |jdd   | js|S t|tj^}}| j|d}t	||d^}}|S )Nrt   r   r   rJ  )
r(   r   rj   ro   rk   r&   r   r   r  r'   )ra   rd   rj   r  rp   r   r   rb   rb   rc   rx   D	  s   zRandCoarseTransform.__call__)NNrS   )r  r   r  rW  r  rT  r  r  rU   rV   r\   r]   r  )rd   r  r\   r  r   r   )r{   r|   r}   r~   r   r   r   r_   rj   r   r  rx   r   rb   rb   rq   rc   rJ   	  s    rJ   c                      s8   e Zd ZdZ					dd fddZdddZ  ZS )rK   a  
    Randomly coarse dropout regions in the image, then fill in the rectangular regions with specified value.
    Or keep the rectangular regions and fill in the other areas with specified value.
    Refer to papers: https://arxiv.org/abs/1708.04552, https://arxiv.org/pdf/1604.07379
    And other implementation: https://albumentations.ai/docs/api_reference/augmentations/transforms/
    #albumentations.augmentations.transforms.CoarseDropout.

    Args:
        holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to
            randomly select the expected number of regions.
        spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg
            as the minimum spatial size to randomly select size for every region.
            if some components of the `spatial_size` are non-positive values, the transform will use the
            corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        dropout_holes: if `True`, dropout the regions of holes and fill value, if `False`, keep the holes and
            dropout the outside and fill value. default to `True`.
        fill_value: target value to fill the dropout regions, if providing a number, will use it as constant
            value to fill all the regions. if providing a tuple for the `min` and `max`, will randomly select
            value for every pixel / voxel from the range `[min, max)`. if None, will compute the `min` and `max`
            value of input image then randomly select value to fill, default to None.
        max_holes: if not None, define the maximum number to randomly select the expected number of regions.
        max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region.
            if some components of the `max_spatial_size` are non-positive values, the transform will use the
            corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        prob: probability of applying the transform.

    TNrS   r  r   r  rW  dropout_holesr[   
fill_value"tuple[float, float] | float | Noner  rT  r  r  rU   rV   r\   r]   c                   sH   t  j|||||d || _t|ttfrt|dkrtd|| _d S )N)r  r  r  r  rU   r   zEfill value should contain 2 numbers if providing the `min` and `max`.)	ri   r_   r  r   r  r  r   r   r  )ra   r  r  r  r  r  r  rU   rq   rb   rc   r_   q	  s   


zRandCoarseDropout.__init__rd   r  c                 C  s   | j du r| | fn| j }| jr;| jD ] }t|ttfr2| jj	|d |d || j
d||< q|||< q|}|S t|ttfrW| jj	|d |d |j
dj|jdd}nt||}| jD ]}|| ||< q`|S )z
        Fill the randomly selected `self.hole_coords` in input images.
        Please note that we usually only use `self.R` in `randomize()` method, here is a special case.

        Nr   r   rf   Fr   )r  r   r   r  r  r   r  r  rl   rm   ro   r   rY   r   	full_like)ra   rd   r  hr   rb   rb   rc   r  	  s    
&
*
z"RandCoarseDropout._transform_holes)TNNNrS   )r  r   r  rW  r  r[   r  r  r  rT  r  r  rU   rV   r\   r]   rd   r  )r{   r|   r}   r~   r_   r  r   rb   rb   rq   rc   rK   R	  s    "rK   c                   @  s   e Zd ZdZdddZdS )rL   a  
    Randomly select regions in the image, then shuffle the pixels within every region.
    It shuffles every channel separately.
    Refer to paper:
    Kang, Guoliang, et al. "Patchshuffle regularization." arXiv preprint arXiv:1707.07103 (2017).
    https://arxiv.org/abs/1707.07103

    Args:
        holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to
            randomly select the expected number of regions.
        spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg
            as the minimum spatial size to randomly select size for every region.
            if some components of the `spatial_size` are non-positive values, the transform will use the
            corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        max_holes: if not None, define the maximum number to randomly select the expected number of regions.
        max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region.
            if some components of the `max_spatial_size` are non-positive values, the transform will use the
            corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        prob: probability of applying the transform.

    rd   r  c                 C  sN   | j D ]!}t|| D ]\}}| }| j| ||j|| |< qq|S )z
        Shuffle the content of randomly selected `self.hole_coords` in input images.
        Please note that we usually only use `self.R` in `randomize()` method, here is a special case.

        )r  r   flattenrl   shuffler  ro   )ra   rd   r  r   r  Zpatch_channelrb   rb   rc   r  	  s   
z"RandCoarseShuffle._transform_holesNr  )r{   r|   r}   r~   r  rb   rb   rb   rc   rL   	  s    rL   c                   @  s<   e Zd ZdZejgZddddejfdddZ	ddddZ
dS )rM   a?  
    Apply the histogram normalization to input image.
    Refer to: https://github.com/facebookresearch/CovidPrognosis/blob/master/covidprognosis/data/transforms.py#L83.

    Args:
        num_bins: number of the bins to use in histogram, default to `256`. for more details:
            https://numpy.org/doc/stable/reference/generated/numpy.histogram.html.
        min: the min value to normalize input image, default to `0`.
        max: the max value to normalize input image, default to `255`.
        mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`.
            only points at which `mask==True` are used for the equalization.
            can also provide the mask along with img at runtime.
        dtype: data type of the output, if None, same as input image. default to `float32`.

       r      Nnum_binsr   r   r   r$  rH  rY   r
   r\   r]   c                 C  r  r^   )r  r   r   r$  rY   )ra   r  r   r   r$  rY   rb   rb   rc   r_   	  r  zHistogramNormalize.__init__rd   r   c                 C  s   t |t d}t|tj^}}|d ur|n| j}d }|d ur&t|tj^}}t||| j| j| j	d}t
||| jp:|jd^}}|S )Nrt   )rd   r$  r  r   r   r  )r(   r   r&   r   r   r$  r   r  r   r   r'   rY   )ra   rd   r$  r  rp   mask_npr   r   rb   rb   rc   rx   	  s   zHistogramNormalize.__call__)r  r   r   r   r   r   r$  rH  rY   r
   r\   r]   r^   rd   r   r$  rH  r\   r   )r{   r|   r}   r~   r   r   r   r   r   r_   rx   rb   rb   rb   rc   rM   	  s    rM   c                      s.   e Zd ZdZdd fdd	ZdddZ  ZS )rN   a  
    Transform for intensity remapping of images. The intensity at each
    pixel is replaced by a new values coming from an intensity remappping
    curve.

    The remapping curve is created by uniformly sampling values from the
    possible intensities for the input image and then adding a linear
    component. The curve is the rescaled to the input image intensity range.

    Intended to be used as a means to data augmentation via:
    :py:class:`monai.transforms.RandIntensityRemap`.

    Implementation is described in the work:
    `Intensity augmentation for domain transfer of whole breast segmentation
    in MRI <https://ieeexplore.ieee.org/abstract/document/9166708>`_.

    Args:
        kernel_size: window size for averaging operation for the remapping
            curve.
        slope: slope of the linear component. Easiest to leave default value
            and tune the kernel_size parameter instead.
       ffffff?kernel_sizer   sloperV   c                   s   t    || _|| _d S r^   )ri   r_   r  r  )ra   r  r  rq   rb   rc   r_   
  s   

zIntensityRemap.__init__rd   torch.Tensorr\   c                 C  s   t |t d}t |dd}t| }t| j|t|d | j	 }tj
j| j	dd|d }tt|t| }|| j| 7 }||  | |   |  |  }t|t|}t|| |d^}}|S )8
        Args:
            img: image to remap.
        rt   Fr   )strider   r9  )r(   r   r   uniquer  
from_numpyrl   choicer   r  nn	AvgPool1drP  rQ  aranger  r   r   	bucketizer   r'   )ra   rd   img_Zvals_to_sampler`   gridZ	index_imgrp   rb   rb   rc   rx   
  s   " ,zIntensityRemap.__call__)r  r  )r  r   r  rV   rd   r  r\   r  r{   r|   r}   r~   r_   rx   r   rb   rb   rq   rc   rN   	  s    rN   c                      s.   e Zd ZdZddddZd fddZ  ZS )rO   a  
    Transform for intensity remapping of images. The intensity at each
    pixel is replaced by a new values coming from an intensity remappping
    curve.

    The remapping curve is created by uniformly sampling values from the
    possible intensities for the input image and then adding a linear
    component. The curve is the rescaled to the input image intensity range.

    Implementation is described in the work:
    `Intensity augmentation for domain transfer of whole breast segmentation
    in MRI <https://ieeexplore.ieee.org/abstract/document/9166708>`_.

    Args:
        prob: probability of applying the transform.
        kernel_size: window size for averaging operation for the remapping
            curve.
        slope: slope of the linear component. Easiest to leave default value
            and tune the kernel_size parameter instead.
        channel_wise: set to True to treat each channel independently.
    rS   r  r  TrU   rV   r  r   r  r   r[   c                 C  s$   t j| |d || _|| _|| _d S )Nr  )r   r_   r  r  r   )ra   rU   r  r  r   rb   rb   rc   r_   B
  s   
zRandIntensityRemap.__init__rd   r  r\   c                   sr   t  d t t d jr7jr&t fddtt	 D   S t
jjj jg   S )r  Nrt   c                   s2   g | ]}t jjj jg | qS rb   )rN   r  rl   r  r  r  rd   ra   rb   rc   r   R
  s    $z/RandIntensityRemap.__call__.<locals>.<listcomp>)ri   rj   r(   r   rk   r   r   r   r   r   rN   r  rl   r  r  r  rq   r  rc   rx   H
  s   
	"zRandIntensityRemap.__call__)rS   r  r  T)rU   rV   r  r   r  rV   r   r[   r  r  rb   rb   rq   rc   rO   +
  s    rO   c                   @  sH   e Zd ZdZejejgZ			ddddZdd Z	dd Z
dddZdS )rP   a  
    Creates a binary mask that defines the foreground based on thresholds in RGB or HSV color space.
    This transform receives an RGB (or grayscale) image where by default it is assumed that the foreground has
    low values (dark) while the background has high values (white). Otherwise, set `invert` argument to `True`.

    Args:
        threshold: an int or a float number that defines the threshold that values less than that are foreground.
            It also can be a callable that receives each dimension of the image and calculate the threshold,
            or a string that defines such callable from `skimage.filter.threshold_...`. For the list of available
            threshold functions, please refer to https://scikit-image.org/docs/stable/api/skimage.filters.html
            Moreover, a dictionary can be passed that defines such thresholds for each channel, like
            {"R": 100, "G": "otsu", "B": skimage.filter.threshold_mean}
        hsv_threshold: similar to threshold but HSV color space ("H", "S", and "V").
            Unlike RBG, in HSV, value greater than `hsv_threshold` are considered foreground.
        invert: invert the intensity range of the input image, so that the dtype maximum is now the dtype minimum,
            and vice-versa.

    otsuNFr  #dict | Callable | str | float | inthsv_threshold*dict | Callable | str | float | int | Noneinvertr[   r\   r]   c                 C  s   i | _ |d ur0t|tr| D ]\}}| ||  qn| |d | |d | |d |d ur]t|trK| D ]\}}| ||  q=n| |d | |d | |d dd | j  D | _ | j  td	r{t	d
| j  d|| _
d S )Nrl   GBHSVc                 S  s   i | ]\}}|d ur||qS r^   rb   )r   r   vrb   rb   rc   
<dictcomp>
  s    z+ForegroundMask.__init__.<locals>.<dictcomp>ZRGBHSVzBThreshold for at least one channel of RGB or HSV needs to be set. z is provided.)
thresholdsr   dictitems_set_thresholdr-  keys
isdisjointsetr   r  )ra   r  r   r  rM  thrb   rb   rc   r_   s
  s.   


zForegroundMask.__init__c                 C  sr   t |r|| j|< d S t|trttjd|  | j|< d S t|tt	fr/t|| j|< d S t
dt| d)N
threshold_zB`threshold` should be either a callable, string, or float number, z was given.)callabler
  r   rN  getattrr*   filtersr,  rV   r   r   r   )ra   r  rM  rb   rb   rc   r  
  s   
zForegroundMask._set_thresholdc                 C  s    | j |}t|r||S |S r^   )r
  getr  )ra   imagerM  r  rb   rb   rc   _get_threshold
  s   zForegroundMask._get_thresholdr  r   c                 C  s.  t |t d}t|tj^}}| jrtj|}g }| j	 
tdsLt|d d }t|dD ]\}}| ||}|rFt|||k}q2|| | j	 
tdstjj|dd}	t|d d }
t|	dD ]\}}| ||}|rt|
||k}
ql||
 t|jdd}t||dd S )	Nrt   RGBr   HSVr   )channel_axisr  rJ  )r(   r   r&   r   r   r  r*   utilr
  r  r  r  
zeros_liker  r  
logical_orr   colorZrgb2hsvr   allr'   )ra   r  Zimg_rgbrp   ZforegroundsZrgb_foregroundrd   rM  r  Zimg_hsvZhsv_foregroundr$  rb   rb   rc   rx   
  s0   

zForegroundMask.__call__)r  NF)r  r  r   r  r  r[   r\   r]   )r  r   )r{   r|   r}   r~   r   r   r   r   r_   r  r  rx   rb   rb   rb   rc   rP   ]
  s    rP   c                      s.   e Zd ZdZdd fddZdddZ  ZS )rQ   a  Compute horizontal and vertical maps from an instance mask
    It generates normalized horizontal and vertical distances to the center of mass of each region.
    Input data with the size of [1xHxW[xD]], which channel dim will temporarily removed for calculating coordinates.

    Args:
        dtype: the data type of output Tensor. Defaults to `"float32"`.

    Return:
        A torch.Tensor with the size of [2xHxW[xD]], which is stack horizontal and vertical maps

    r   rY   r
   r\   r]   c                   s   t    || _d S r^   )ri   r_   rY   )ra   rY   rq   rb   rc   r_   
  s   

zComputeHoVerMaps.__init__r$  r   c           	      C  s(  t |tjd }|j| jdd}|j| jdd}|d}tj|D ]b}|j	d d df |j
d  }|j	d d df |j
d  }||dk   t|   < ||dk  t|  < ||dk   t|   < ||dk  t|  < ||||jk< ||||jk< q#tt||gt d}|S )Nr   Tr   r   rt   )r&   r   r   r   rY   rQ  r*   measureregionpropsr   centroidaminamaxlabelr(   concatenater   )	ra   r$  Zinstance_maskh_mapZv_mapregionZv_disth_distZhv_mapsrb   rb   rc   rx   
  s   
zComputeHoVerMaps.__call__)r   )rY   r
   r\   r]   )r$  r   r  rb   rb   rq   rc   rQ   
  s    rQ   c                   @  s8   e Zd ZdZ									ddddZddddZdS )rR   a2  Compute confidence map from an ultrasound image.
    This transform uses the method introduced by Karamalis et al. in https://doi.org/10.1016/j.media.2012.07.005.
    It generates a confidence map by setting source and sink points in the image and computing the probability
    for random walks to reach the source for each pixel.

    The official code is available at:
    https://campar.in.tum.de/Main/AthanasiosKaramalisCode

    Args:
        alpha (float, optional): Alpha parameter. Defaults to 2.0.
        beta (float, optional): Beta parameter. Defaults to 90.0.
        gamma (float, optional): Gamma parameter. Defaults to 0.05.
        mode (str, optional): 'RF' or 'B' mode data. Defaults to 'B'.
        sink_mode (str, optional): Sink mode. Defaults to 'all'. If 'mask' is selected, a mask must be when
            calling the transform. Can be one of 'all', 'mid', 'min', 'mask'.
        use_cg (bool, optional): Use Conjugate Gradient method for solving the linear system. Defaults to False.
        cg_tol (float, optional): Tolerance for the Conjugate Gradient method. Defaults to 1e-6.
            Will be used only if `use_cg` is True.
        cg_maxiter (int, optional): Maximum number of iterations for the Conjugate Gradient method. Defaults to 200.
            Will be used only if `use_cg` is True.
    r       V@皙?r  r   Fư>   rs  rV   betar<  cg_tol
cg_maxiterr   c	           	   	   C  s   || _ || _|| _|| _|| _|| _|| _|| _| jdvr&td| j d| jdvr4td| j dt	| j | j| j| j| j| j| j| j| _
d S )N)r  ZRFzUnknown mode: z#. Supported modes are 'B' and 'RF'.)r   midr   r$  zUnknown sink mode: z5. Supported modes are 'all', 'mid', 'min' and 'mask'.)rs  r/  r<  rM  	sink_modeuse_cgr0  r1  r   r   _compute_conf_map)	ra   rs  r/  r<  rM  r3  r4  r0  r1  rb   rb   rc   r_   
  s"   

 
z)UltrasoundConfidenceMapTransform.__init__Nrd   r   r$  rH  r\   c                 C  s   | j dkr|du rtd|jd dkrtdt|t d}t|tj^}}|d }d}|durFt|tj	t d}t|tj^}}|d }t
|jd	krTtj|dd
}|durb|j|jkrbtd| ||}t|tju rtt|}|S )a+  Compute confidence map from an ultrasound image.

        Args:
            img (ndarray or Tensor): Ultrasound image of shape [1, H, W] or [1, D, H, W]. If the image has channels,
                they will be averaged before computing the confidence map.
            mask (ndarray or Tensor, optional): Mask of shape [1, H, W]. Defaults to None. Must be
                provided when sink mode is 'mask'. The non-zero values of the mask are used as sink points.

        Returns:
            ndarray or Tensor: Confidence map of shape [1, H, W].
        r$  Nz1A mask must be provided when sink mode is 'mask'.r   r   z<The correct shape of the image is [1, H, W] or [1, D, H, W].rt   )rY   ru   r   r  z/The mask must have the same shape as the image.)r3  r   ro   r(   r   r&   r   r   r   r[   r   rW   r5  r   r   r  )ra   rd   r$  Z_imgr  rp   r  Zconf_maprb   rb   rc   rx     s(   
z)UltrasoundConfidenceMapTransform.__call__)r  r+  r,  r  r   Fr-  r.  )
rs  rV   r/  rV   r<  rV   r0  rV   r1  r   r^   r  )r{   r|   r}   r~   r_   rx   rb   rb   rb   rc   rR   
  s     rR   )jr~   
__future__r   abcr   collections.abcr   r   r   	functoolsr   typingr   warningsr	   numpyr   r   monai.configr
   monai.config.type_definitionsr   r   monai.data.meta_objr   Z$monai.data.ultrasound_confidence_mapr   monai.data.utilsr   r   monai.networks.layersr   r   r   r   monai.transforms.transformr   r   monai.transforms.utilsr   r   r   r   r   0monai.transforms.utils_pytorch_numpy_unificationr   r   r   monai.utils.enumsr   monai.utils.miscr    r!   r"   r#   monai.utils.moduler$   r%   monai.utils.type_conversionr&   r'   r(   r)   r*   rp   __all__r+   r,   r-   r.   r/   r0   r2   r4   r5   r3   r1   r6   r7   r8   r9   r:   r;   r<   r=   r?   r>   r@   rA   rB   rC   rD   rE   rF   rG   rH   rI   rJ   rK   rL   rM   rN   rO   rP   rQ   rR   rb   rb   rb   rc   <module>   s   
,9YE6=7YBD]p9 ?Lz6-)#68STL>  NK(/72a)