U
    PhE                     @  s  d Z ddlmZ ddlmZ ddlmZ ddlZddl	Z	ddl
mZmZ ddlmZ ddlmZ dd	lmZ dd
lmZmZ ddlmZ ddlmZmZmZ ddlmZ ddlmZ ddl m!Z!m"Z" ddddgZ#G dd deZ$G dd deZ%G dd deZ&G dd deZ'dS )zbTransforms using a smooth spatial field generated by interpolating from smaller randomized fields.    )annotations)Sequence)AnyN)grid_sampleinterpolate)NdarrayOrTensor)get_track_meta)meshgrid_ij)RandomizableRandomizableTransform)moveaxis)GridSampleModeGridSamplePadModeInterpolateMode)TransformBackends)look_up_option)convert_to_dst_typeconvert_to_tensorSmoothFieldRandSmoothFieldAdjustContrastRandSmoothFieldAdjustIntensityRandSmoothDeformc                   @  s   e Zd ZdZejgZddddddejddf	ddd	d	d	dd
dddd
ddZ	d!dddddZ
d
ddddZdddddZd"dddd ZdS )#r   a  
    Generate a smooth field array by defining a smaller randomized field and then reinterpolating to the desired size.

    This exploits interpolation to create a smoothly varying field used for other applications. An initial randomized
    field is defined with `rand_size` dimensions with `pad` number of values padding it along each dimension using
    `pad_val` as the value. If `spatial_size` is given this is interpolated to that size, otherwise if None the random
    array is produced uninterpolated. The output is always a Pytorch tensor allocated on the specified device.

    Args:
        rand_size: size of the randomized field to start from
        pad: number of pixels/voxels along the edges of the field to pad with `pad_val`
        pad_val: value with which to pad field edges
        low: low value for randomized field
        high: high value for randomized field
        channels: number of channels of final output
        spatial_size: final output size of the array, None to produce original uninterpolated field
        mode: interpolation mode for resizing the field
        align_corners: if True align the corners when upsampling field
        device: Pytorch device to define field on
    r   g            ?   NSequence[int]intfloatzSequence[int] | Nonestrbool | Nonetorch.device | None
	rand_sizepadpad_vallowhighchannelsspatial_sizemodealign_cornersdevicec                   s   t | _| _| _| _| _| _|	 _|
 _d  _	d  _
||krPtdt  fdd jD  _tjd jf j  jd|  _ jf j  _ jdkrtd nt j j }dtd f|ft j   _ | d S )NzFValue for `low` must be less than `high` otherwise field will be zerosc                 3  s   | ]}| j d   V  qdS )   N)r"   ).0rsself X/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/transforms/smooth_field/array.py	<genexpr>W   s     z'SmoothField.__init__.<locals>.<genexpr>r   )r*   r   )tupler!   r"   r$   r%   r&   r(   r)   r*   r'   spatial_zoom
ValueErrortotal_rand_sizetorchonesfield
crand_sizeslicelenrand_slicesset_spatial_size)r/   r!   r"   r#   r$   r%   r&   r'   r(   r)   r*   Z	pad_slicer0   r.   r1   __init__;   s$    
""zSmoothField.__init__
Any | NoneNonedatareturnc                 C  s(   t | j| j| j| j| j| j< d S N)	r7   
from_numpyRuniformr$   r%   r:   r9   r=   r/   rC   r0   r0   r1   	randomizeb   s    zSmoothField.randomize)r'   rD   c                 C  sB   |dkrd| _ d| _n(t|| _ tdd t| j | jD | _dS )a
  
        Set the `spatial_size` and `spatial_zoom` attributes used for interpolating the field to the given
        dimension, or not interpolate at all if None.

        Args:
            spatial_size: new size to interpolate to, or None to not interpolate
        Nc                 s  s   | ]\}}|| V  qd S rE   r0   )r,   sfr0   r0   r1   r2   r   s     z/SmoothField.set_spatial_size.<locals>.<genexpr>)r'   r4   r3   zipr6   )r/   r'   r0   r0   r1   r>   e   s
    
zSmoothField.set_spatial_sizer(   rD   c                 C  s
   || _ d S rE   )r(   r/   r(   r0   r0   r1   set_modet   s    zSmoothField.set_modeFztorch.Tensor)rD   c           	      C  s   |r|    | j }| jd k	rt|| jt| jt| jdd}|	 }|
 }| j	 }| j
 }|d| || }||| |}|S )NF)inputscale_factorr(   r)   recompute_scale_factorr   )rJ   r9   cloner4   r   r   r(   r   r)   minmaxsqueezediv_mul_add_)	r/   rJ   r9   Zresized_fieldminamaxaminvmaxvZ
norm_fieldr0   r0   r1   __call__w   s$    




zSmoothField.__call__)N)F)__name__
__module____qualname____doc__r   TORCHbackendr   AREAr?   rJ   r>   rP   r_   r0   r0   r0   r1   r   #   s    "'c                
      s   e Zd ZdZejgZdejddddfddddd	d
ddd fddZ	d#ddd d fddZ
d$ddd fddZdddddZd%dddd d!d"Z  ZS )&r   a  
    Randomly adjust the contrast of input images by calculating a randomized smooth field for each invocation.

    This uses SmoothField internally to define the adjustment over the image. If `pad` is greater than 0 the
    edges of the input volume of that width will be mostly unchanged. Contrast is changed by raising input
    values by the power of the smooth field so the range of values given by `gamma` should be chosen with this
    in mind. For example, a minimum value of 0 in `gamma` will produce white areas so this should be avoided.
    After the contrast is adjusted the values of the result are rescaled to the range of the original input.

    Args:
        spatial_size: size of input array's spatial dimensions
        rand_size: size of the randomized field to start from
        pad: number of pixels/voxels along the edges of the field to pad with 1
        mode: interpolation mode to use when upsampling
        align_corners: if True align the corners when upsampling field
        prob: probability transform is applied
        gamma: (min, max) range for exponential field
        device: Pytorch device to define field on
    r   N皙?)      ?g      @r   r   r   r   r   Sequence[float] | floatr   r'   r!   r"   r(   r)   probgammar*   c	           	        sz   t  | t|ttfr&d|f| _n&t|dkr:tdt|t	|f| _t
||d| jd | jd d||||d
| _d S Nrh   r+   z7Argument `gamma` should be a number or pair of numbers.r   r   r    superr?   
isinstancer   r   rl   r<   r5   rU   rV   r   sfield	r/   r'   r!   r"   r(   r)   rk   rl   r*   	__class__r0   r1   r?      s$    z&RandSmoothFieldAdjustContrast.__init__
int | Nonenp.random.RandomState | NoneseedstaterD   c                   s    t  || | j|| | S rE   ro   set_random_staterq   r/   rx   ry   rs   r0   r1   r{      s    z.RandSmoothFieldAdjustContrast.set_random_stater@   rA   rB   c                   s    t  d  | jr| j  d S rE   ro   rJ   _do_transformrq   rI   rs   r0   r1   rJ      s    z'RandSmoothFieldAdjustContrast.randomizerN   c                 C  s   | j | d S rE   rq   rP   rO   r0   r0   r1   rP      s    z&RandSmoothFieldAdjustContrast.set_modeTr   boolimgrJ   rD   c           
      C  sz   t |t d}|r|   | js$|S | }| }|| }|  }t||^}}|| |d  }|| }|| | }	|	S ){
        Apply the transform to `img`, if `randomize` randomizing the smooth field otherwise reusing the previous.
        
track_metag|=)r   r   rJ   r~   rU   rV   rq   r   )
r/   r   rJ   img_minimg_maxZimg_rngr9   rfield_outr0   r0   r1   r_      s    z&RandSmoothFieldAdjustContrast.__call__)NN)N)Tr`   ra   rb   rc   r   rd   re   r   rf   r?   r{   rJ   rP   r_   __classcell__r0   r0   rs   r1   r      s   "#   c                
      s   e Zd ZdZejgZdejddddfddddd	d
ddd fddZ	d#ddd d fddZ
d$ddd fddZdddddZd%dddd d!d"Z  ZS )&r   a+  
    Randomly adjust the intensity of input images by calculating a randomized smooth field for each invocation.

    This uses SmoothField internally to define the adjustment over the image. If `pad` is greater than 0 the
    edges of the input volume of that width will be mostly unchanged. Intensity is changed by multiplying the
    inputs by the smooth field, so the values of `gamma` should be chosen with this in mind. The default values
    of `(0.1, 1.0)` are sensible in that values will not be zeroed out by the field nor multiplied greater than
    the original value range.

    Args:
        spatial_size: size of input array
        rand_size: size of the randomized field to start from
        pad: number of pixels/voxels along the edges of the field to pad with 1
        mode: interpolation mode to use when upsampling
        align_corners: if True align the corners when upsampling field
        prob: probability transform is applied
        gamma: (min, max) range of intensity multipliers
        device: Pytorch device to define field on
    r   Nrg   )rg   r   r   r   r   r   r   ri   r   rj   c	           	        sz   t  | t|ttfr&d|f| _n&t|dkr:tdt|t	|f| _t
||d| jd | jd d||||d
| _d S rm   rn   rr   rs   r0   r1   r?     s$    z'RandSmoothFieldAdjustIntensity.__init__ru   rv   rw   c                   s    t  || | j|| | S rE   rz   r|   rs   r0   r1   r{   0  s    z/RandSmoothFieldAdjustIntensity.set_random_stater@   rA   rB   c                   s    t  d  | jr| j  d S rE   r}   rI   rs   r0   r1   rJ   7  s    z(RandSmoothFieldAdjustIntensity.randomizerN   c                 C  s   | j | d S rE   r   rO   r0   r0   r1   rP   =  s    z'RandSmoothFieldAdjustIntensity.set_modeTr   r   r   c                 C  sF   t |t d}|r|   | js$|S |  }t||^}}|| }|S )r   r   )r   r   rJ   r~   rq   r   )r/   r   rJ   r9   r   r   r   r0   r0   r1   r_   @  s    z'RandSmoothFieldAdjustIntensity.__call__)NN)N)Tr   r0   r0   rs   r1   r      s   "#   c                      s   e Zd ZdZejgZdejddde	j
ejejddf
dddd	d
ddd	d	d
dd fddZd'dddd fddZd(ddd fddZd	ddddZd	dddd Zd)d"d#dd"d$d%d&Z  ZS )*r   a`  
    Deform an image using a random smooth field and Pytorch's grid_sample.

    The amount of deformation is given by `def_range` in fractions of the size of the image. The size of each dimension
    of the input image is always defined as 2 regardless of actual image voxel dimensions, that is the coordinates in
    every dimension range from -1 to 1. A value of 0.1 means pixels/voxels can be moved by up to 5% of the image's size.

    Args:
        spatial_size: input array size to which deformation grid is interpolated
        rand_size: size of the randomized field to start from
        pad: number of pixels/voxels along the edges of the field to pad with 0
        field_mode: interpolation mode to use when upsampling the deformation field
        align_corners: if True align the corners when upsampling field
        prob: probability transform is applied
        def_range: value of the deformation range in image size fractions, single min/max value  or min/max pair
        grid_dtype: type for the deformation grid calculated from the field
        grid_mode: interpolation mode used for sampling input using deformation grid
        grid_padding_mode: padding mode used for sampling input using deformation grid
        grid_align_corners: if True align the corners when sampling the deformation grid
        device: Pytorch device to define field on
    r   Nrg   r   Fr   r   r   r   r   ri   r   )r'   r!   r"   
field_moder)   rk   	def_range	grid_modegrid_padding_modegrid_align_cornersr*   c                   s   t  | || _|	| _|| _|| _|| _|
| _t|t	t
frL| |f| _n&t|dkr`tdt|t|f| _t|||| jd | jd t||||d	| _|d k	rt|n| jjjdd  }dd |D }t| }t|d| j| j| _d S )Nr+   z;Argument `def_range` should be a number or pair of numbers.r   r   )	r'   r!   r"   r$   r%   r&   r(   r)   r*   c                 S  s   g | ]}t d d|qS )r   )r7   linspace)r,   dr0   r0   r1   
<listcomp>  s     z-RandSmoothDeform.__init__.<locals>.<listcomp>)ro   r?   
grid_dtyper   r   r*   r   r   rp   r   r   r<   r5   rU   rV   r   rq   r3   r9   shaper	   r7   stack	unsqueezetogrid)r/   r'   r!   r"   r   r)   rk   r   r   r   r   r   r*   Z
grid_spaceZgrid_rangesr   rs   r0   r1   r?   o  s6    "zRandSmoothDeform.__init__ru   rv   r
   rw   c                   s    t  || | j|| | S rE   rz   r|   rs   r0   r1   r{     s    z!RandSmoothDeform.set_random_stater@   rA   rB   c                   s    t  d  | jr| j  d S rE   r}   rI   rs   r0   r1   rJ     s    zRandSmoothDeform.randomizerN   c                 C  s   | j | d S rE   r   rO   r0   r0   r1   set_field_mode  s    zRandSmoothDeform.set_field_modec                 C  s
   || _ d S rE   )r   rO   r0   r0   r1   set_grid_mode  s    zRandSmoothDeform.set_grid_modeTr   r   )r   rJ   r*   rD   c           
      C  s   t |t d}|r|   | js$|S |d k	r0|n| j}|  }| j|| j }t	|dd}|dt
t|jd d ddf }t |d  tj|}t||t| jt| jt| jtd}t|d|^}}	|S )Nr   r   r   .)rQ   r   r(   r)   padding_moder   )r   r   rJ   r~   r*   rq   r   r   r   r   listranger   r7   float32r   r   r   r   r   r   r   r   rW   )
r/   r   rJ   r*   r9   Zdgridimg_tr   out_tr   r0   r0   r1   r_     s(    "

zRandSmoothDeform.__call__)NN)N)TN)r`   ra   rb   rc   r   rd   re   r   rf   r7   r   r   NEARESTr   BORDERr?   r{   rJ   r   r   r_   r   r0   r0   rs   r1   r   V  s(   (3   )(rc   
__future__r   collections.abcr   typingr   numpynpr7   torch.nn.functionalr   r   monai.config.type_definitionsr   monai.data.meta_objr   monai.networks.utilsr	   monai.transforms.transformr
   r   0monai.transforms.utils_pytorch_numpy_unificationr   monai.utilsr   r   r   monai.utils.enumsr   monai.utils.moduler   monai.utils.type_conversionr   r   __all__r   r   r   r   r0   r0   r0   r1   <module>   s(   oe_