U
    Phz                     @  s  d Z ddlmZ ddlZddlZddlmZ ddlZddl	Z	ddl
Z
ddlmZ ddlmZ ddlmZ ddlmZ dd	lmZmZmZ dd
lmZ ddlmZ ddlmZ ddlmZ ddlm Z m!Z!m"Z"m#Z# ddl$m%Z% ddl&m'Z'm(Z(m)Z)m*Z*m+Z+m,Z,m-Z-m.Z.m/Z/ e/d\Z0Z1e/d\Z2Z3e/d\Z4Z3e/d\Z5Z3ddddddddgZ6d)ddZ7dd d!dZ8dd d"dZ9d#d Z:d$d Z;d%d Z<d&d Z=d'd Z>d(d Z?dS )*zA
A collection of "functional" transforms for spatial operations.
    )annotationsN)Enum)USE_COMPILED)NdarrayOrTensor)get_track_meta)
MetaTensor)
AFFINE_TOLcompute_shape_offsetto_affine_nd)AffineTransform)ResizeWithPadOrCrop)GaussianSmooth)TraceableTransform)create_rotatecreate_translateresolves_modesscale_affine)allclose)	LazyAttr	TraceKeysconvert_to_dst_typeconvert_to_numpyconvert_to_tensorensure_tupleensure_tuple_repfall_back_tupleoptional_importnibabelcupyzcupyx.scipy.ndimagezscipy.ndimagespatial_resampleorientationflipresizerotatezoomrotate90affine_funcc                 C  s&   t t| tr|  n| ||t ddS )zgcreate a metatensor with fresh metadata if track_meta is True otherwise convert img into a torch tensorT)dtypedevice
track_metawrap_sequence)r   
isinstancer   	as_tensorr   )imgr'   r(    r.   X/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/transforms/spatial/functional.py_maybe_new_metatensor8   s    r0   ztorch.Tensor)returnc	               
   C  sx  t | tr|  n| jdd }	t | tr2|  ntd}
t| t d} t	t
| jd |
jd d d}t |tr||dkr|dk	rt	t
t|d}t||
tj}
|dk	rt||n|
}t||
d }t |tjstdt| t|	d| }t |tr|dkr|}n$|dkr<|dkr<t||
|\}}ttt|d| |d	d
 }t|dd t |tr~|jn|t |tr|jn||dk	r|ntj|
d}z:t|
}t|}|dk rt|d ntj||}W nD tjj t!fk
r, } ztd| d| d|W 5 d}~X Y nX tt||j| j"tjd}t#|
|t$drft#||pt#|tt
|t$dot#||}t%j&| ||r|sdn|||	||d}|rt'| }t |tr|(|S |S |rt'| tj)d}t |tr|(|S |S t | tr| * n| } t+| j}|d |d|d  ||d d   }}}|rpdg| }| ,|} | |} t |tst-rt.|dd |D }|t||d  }t/j0j1||dd||d}|2d || ||d} W 5 Q R X nDt3||\}}}}t4d|||dd}|| 5d|| |d6d} |rP|f||}| ,|} t'| tj)d}t |trt|(|S |S )a  
    Functional implementation of resampling the input image to the specified ``dst_affine`` matrix and ``spatial_size``.
    This function operates eagerly or lazily according to
    ``lazy`` (default ``False``).

    Args:
        img: data to be resampled, assuming `img` is channel-first.
        dst_affine: target affine matrix, if None, use the input affine matrix, effectively no resampling.
        spatial_size: output spatial size, if the component is ``-1``, use the corresponding input spatial size.
        mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
            Interpolation mode to calculate output values.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
            and the value represents the order of the spline interpolation.
            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
            Padding mode for outside grid values.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            When `mode` is an integer, using numpy/cupy backends, this argument accepts
            {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
        align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        dtype_pt: data `dtype` for resampling computation.
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.
       N   )datar)   r      z)dst_affine should be a torch.Tensor, got c                 S  s   | dkS )Nr   r.   )xr.   r.   r/   <lambda>t       z"spatial_resample.<locals>.<lambda>   )r'   modepadding_modealign_corners
src_affine   zsrc affine is not invertible z, .)r(   r'   )atolsp_sizeaffine
extra_info	orig_sizetransform_infolazyr'   c                 S  s   g | ]}t |d  d qS r2   r?   float.0dr.   r.   r/   
<listcomp>   s     z$spatial_resample.<locals>.<listcomp>T)rD   spatial_size
normalized
image_onlyr'   r=   F)r;   r<   rR   r;   r<   r=   reverse_indexing)thetarQ   )7r+   r   peek_pending_shapeshapepeek_pending_affinetorcheyer   r   minlenintr   r
   tofloat64r   Tensor
ValueErrortypetensorr	   r   strr   valuer   NONEr   nplinalgsolveLinAlgErrorRuntimeErrorr(   r   r   r   track_transform_metar0   copy_meta_fromfloat32r,   listreshaper   r   monai
transformsAffinetrace_transformr   r   	unsqueezesqueeze) r-   
dst_affinerQ   r;   r<   r=   dtype_ptrH   rG   original_spatial_shaper>   spatial_rankZin_spatial_size_rE   _s_dxformeZaffine_unchanged	meta_infooutim_sizechnsZ
in_sp_sizeZadditional_dimsZxform_shape	dst_xformZaffine_xform_m_p
full_shaper.   r.   r/   r   C   s      **"	
,


     
c              	   C  s~  t | tr|  n| jdd }tj||}t| t d} |dddf  d7  < t	
t	ddgg|g}dd t|dddf D }t	t|d }t	|dddf |dt|< d|i}	t|dd	}
|
d
d |D  }
tj| |
||	|||d}t| }|r&t |tr"||S |S |r:tj||d}t	|t	t|jksd|| }t |trz||S |S )a  
    Functional implementation of changing the input image's orientation into the specified based on `spatial_ornt`.
    This function operates eagerly or lazily according to
    ``lazy`` (default ``False``).

    Args:
        img: data to be changed, assuming `img` is channel-first.
        original_affine: original affine of the input image.
        spatial_ornt: orientations of the spatial axes,
            see also https://nipy.org/nibabel/reference/nibabel.orientations.html
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.
    r2   Nr)   r   c                 S  s   g | ]\}}|d kr|qS )r6   r.   )rN   axr!   r.   r.   r/   rP      s      zorientation.<locals>.<listcomp>original_affineTr*   c                 S  s   g | ]}|d kr|d qS )r   r2   r.   rN   ir.   r.   r/   rP      s      rB   )dims)r+   r   rW   rX   niborientationsinv_ornt_affr   r   rh   concatenatearray	enumeratearanger]   argsortr   r   rm   r0   rn   rZ   r!   allpermutetolist)r-   r   spatial_orntrH   rG   spatial_shaper   axesfull_transposerE   Zshape_npr   r   r.   r.   r/   r       s8     "	c                 C  s&  t | tr|  n| jdd }t|dd }d|i}tjj	| j
|}t | tr\|  ntjdtjd}tjt|d tjd}|D ]:}	|	d }
||
|
f d ||
 d  ||
|
f< ||
df< qtj| |||||d	}t| }|r t |tr||S |S t||}t |tr"||S |S )
a  
    Functional implementation of flip.
    This function operates eagerly or lazily according to
    ``lazy`` (default ``False``).

    Args:
        img: data to be changed, assuming `img` is channel-first.
        sp_axes: spatial axes along which to flip over.
            If None, will flip over all of the axes of the input array.
            If axis is negative it counts from the last to the first axis.
            If axis is a tuple of ints, flipping is performed on all of the axes
            specified in the tuple.
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.
    r2   NTr   r         @rI   r6   )rC   rD   rE   rG   rH   )r+   r   rW   rX   r   r   rr   rs   utilsmap_spatial_axesndimpeek_pending_rankrZ   rd   doubler[   r^   r   rm   r0   rn   r!   )r-   Zsp_axesrH   rG   rC   rE   r   rankr   axisspr   r   r.   r.   r/   r!      s,     "0     c
              	   C  sD  t | t d} t| tr |  n| jdd }
||dk	r<|ntjt|dd t	|
| d}t
j| |t|
|||
|	|d}|r|r|rtd t| }t|tr||S |S tt|
|krt| tjd}t|tr||S |S t| }t ||d	d
}|rtdd t||jdd D rttt|jdd t|}|dkrztt|j|d d  }nBtt|t	|}tt	|D ]"}|| t|| dk ||< qt |d}t ||d	d}t!|t	|jd d\}}}}tj"j#j$|%d|||d}t&|'d|tjd^}}t|tr@||S |S )a  
    Functional implementation of resize.
    This function operates eagerly or lazily according to
    ``lazy`` (default ``False``).

    Args:
        img: data to be changed, assuming `img` is channel-first.
        out_size: expected shape of spatial dimensions after resize operation.
        mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``,
            ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
            The interpolation mode.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
        align_corners: This only has an effect when mode is
            'linear', 'bilinear', 'bicubic' or 'trilinear'.
        dtype: data type for resampling computation. If None, use the data type of input data.
        input_ndim: number of spatial dimensions.
        anti_aliasing: whether to apply a Gaussian filter to smooth the image prior
            to downsampling. It is crucial to filter when downsampling
            the image to avoid aliasing artifacts. See also ``skimage.transform.resize``
        anti_aliasing_sigma: {float, tuple of floats}, optional
            Standard deviation for Gaussian filtering used when anti-aliasing.
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.
    r   r2   Nr:   )r;   r=   r'   new_dimrB   z5anti-aliasing is not compatible with lazy evaluation.rI   F)r'   r)   c                 s  s   | ]\}}||k V  qd S Nr.   )rN   r7   yr.   r.   r/   	<genexpr>B  s     zresize.<locals>.<genexpr>r?   )sigmatorch_interpolate_spatial_ndr   )inputsizer;   r=   )(r   r   r+   r   rW   rX   r   rg   re   r]   r   rm   r   warningswarnr0   rn   tupler   rZ   ro   anyzipdivra   rp   maximumzerosr   r   ranger^   r   r   nn
functionalinterpolaterv   r   rw   )r-   out_sizer;   r=   r'   
input_ndimanti_aliasinganti_aliasing_sigmarH   rG   rF   rE   r   r   img_factorsr   Zanti_aliasing_filterr|   r   resizedr.   r.   r/   r"     sX     
	
*&
" 
   c	              	   C  s(  t | tr|  n| jdd }	t|	}
|
dkr@td|
 dt||
dkrPdnd}t|
|}|dkrt	tj
dd	 |	D d
dit|	df}|ddddf | }tj	|jddd td}ntj	|td}t|
t|	d d  }t|
tj	|tdd  d  }|| | }||||dk	r6|ntjt|dd d}tj| ||||	||d}t| }|rt |tr||S |S t||\}}}}td|||dd}||}t||^}}||d|tdd |D d}| d}t||t j!d^}}t |tr$||S |S )a  
    Functional implementation of rotate.
    This function operates eagerly or lazily according to
    ``lazy`` (default ``False``).

    Args:
        img: data to be changed, assuming `img` is channel-first.
        angle: Rotation angle(s) in radians. should a float for 2D, three floats for 3D.
        output_shape: output shape of the rotated data.
        mode: {``"bilinear"``, ``"nearest"``}
            Interpolation mode to calculate output values.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
            Padding mode for outside grid values.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        align_corners: See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        dtype: data type for resampling computation.
            If None, use the data type of input data. To be compatible with other modules,
            the output data type is always ``float32``.
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    r2   N)r?   r5   zUnsupported image dimension: z, available options are [2, 3].r?   r5   c                 S  s   g | ]}d |fqS )r   r.   )rN   dimr.   r.   r/   rP   v  s     zrotate.<locals>.<listcomp>indexingijr6   )r   g      ?rI   r:   )rot_matr;   r<   r=   r'   rB   FTrT   r   c                 s  s   | ]}t |V  qd S r   r^   r   r.   r.   r/   r     s     zrotate.<locals>.<genexpr>)rQ   dstr'   )"r+   r   rW   rX   r]   rb   r   r   rh   asarraymeshgridrq   ptpr^   r   r   r   r   rg   re   r   rm   r0   rn   r   r   r_   r   rv   r   rL   rw   rZ   ro   )r-   angleoutput_shaper;   r<   r=   r'   rH   rG   im_shaper   _angle	transformcornersshiftshift_1rE   r   r   r|   r   r   r   img_ttransform_toutputr.   r.   r/   r#   V  s\     
."	    
"c	              	   C  s>  t | tr|  n| jdd }	dd t|	|D }
t|	|
}||dk	rL|ntjt|dd di d}|rt	
|
|	 }|r|rt|	|d}d	|_tg tt|
d d
}|tjt|
tj|i ||}t |tr| }|jd |d< ||d< dd |	D }
tj| |
|||	||d}t| }|rHt |trD||S |S ||}t|t|jd d\}}}}tjjjd	| dt|||d!d}t"||tj#d^}}t |tr||}t	
|
|jdd  }|rt|jdd |d}||}t$ r:|r:|j%& }d	|j%d d d< ||j%d d d< |S )a  
    Functional implementation of zoom.
    This function operates eagerly or lazily according to
    ``lazy`` (default ``False``).

    Args:
        img: data to be changed, assuming `img` is channel-first.
        scale_factor: The zoom factor along the spatial axes.
            If a float, zoom is the same for each spatial axis.
            If a sequence, zoom should contain one value for each spatial axis.
        keep_size: Whether keep original size (padding/slicing if needed).
        mode: {``"bilinear"``, ``"nearest"``}
            Interpolation mode to calculate output values.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
            Padding mode for outside grid values.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        align_corners: See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        dtype: data type for resampling computation.
            If None, use the data type of input data. To be compatible with other modules,
            the output data type is always ``float32``.
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    r2   Nc                 S  s&   g | ]\}}t tt|| qS r.   )r^   mathfloorrL   )rN   r   zr.   r.   r/   rP     s     zzoom.<locals>.<listcomp>r:   F)r;   r=   r'   
do_padcroppadcrop)rQ   r;   T)rD   r6   r   r   c                 S  s   g | ]}t |qS r.   r   r   r.   r.   r/   rP     s     rB   r   r   )recompute_scale_factorr   scale_factorr;   r=   r   rE   )'r+   r   rW   rX   r   r   r   rg   re   rh   r   r   rH   rZ   r[   r]   push_pending_operationr   SHAPErp   AFFINErY   pending_operationsr   rm   r0   rn   r_   r   r   r   r   rv   rw   r   ro   r   applied_operationspop)r-   r   	keep_sizer;   r<   r=   r'   rH   rG   r   output_sizer   rE   Zdo_pad_cropZ	_pad_cropZ_tmp_imgZlazy_croppedr   r   r   r|   r   Zzoomedpadcrop_xformr.   r.   r/   r$     sr     

	


c              	   C  s  dd |D |d}t | tr&|  n| jdd }t|}|dkrx|d d |d d  }}	||	 ||  ||< ||	< t | tr|  ntjdtjd	}
t	|
t
| }}t|t|d
d |D }t	|d t	|d  dkrdnd}|dkrt|t||tj d g}nFdddht| }dddg}|tj d || d < t|t||}t|D ]}|| }q`t|t|dd |D | }tj| ||||||d}t| }|rt |tr||S |S t|||}t |tr||S |S )ai  
    Functional implementation of rotate90.
    This function operates eagerly or lazily according to
    ``lazy`` (default ``False``).

    Args:
        img: data to be changed, assuming `img` is channel-first.
        axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.
            If axis is negative it counts from the last to the first axis.
        k: number of times to rotate by 90 degrees.
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.
    c                 S  s   g | ]}|d  qS )r2   r.   rM   r.   r.   r/   rP     s     zrotate90.<locals>.<listcomp>)r   kr2   N)r2   r5   r   r   rI   c                 S  s   g | ]}t |d   d qS rJ   rK   rM   r.   r.   r/   rP     s     )r6   r?   g      g      ?r?   r5   c                 S  s   g | ]}t |d  d qS rJ   rK   rM   r.   r.   r/   rP     s     rB   )r+   r   rW   rX   rp   r   rZ   rd   r   r^   r]   r
   r   r   rh   pisetr   r   r   rm   r0   rn   rot90)r-   r   r   rH   rG   rE   Z	ori_shapeZsp_shapeZa_0a_1r   rZsp_rr   sr   idxr   r|   r   r   r.   r.   r/   r%     sB     "$

	c              	   C  s  t | tr|  n| jdd }t | tr2|  ntjdtjd}|||||jd}t	j
j||||}tj| |||||
|	d}|	rt| }t |tr||n|}|r|S ||fS |r|| |||d}t|}nt| tj|jd}t |tr||n|}|r|S ||fS )	ai  
    Functional implementation of affine.
    This function operates eagerly or lazily according to
    ``lazy`` (default ``False``).

    Args:
        img: data to be changed, assuming `img` is channel-first.
        affine: the affine transformation to be applied, it can be a 3x3 or 4x4 matrix. This should be defined
            for the voxel space spatial centers (``float(size - 1)/2``).
        grid: used in non-lazy mode to pre-compute the grid to do the resampling.
        resampler: the resampler function, see also: :py:class:`monai.transforms.Resample`.
        sp_size: output image spatial size.
        mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
            Interpolation mode to calculate output values.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
            and the value represents the order of the spline interpolation.
            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
            Padding mode for outside grid values.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            When `mode` is an integer, using numpy/cupy backends, this argument accepts
            {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
        do_resampling: whether to do the resampling, this is a flag for the use case of updating metadata but
            skipping the actual (potentially heavy) resampling operation.
        image_only: if True return only the image volume, otherwise return (image, affine).
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    r2   Nr   rI   )rD   r;   r<   do_resamplingr=   rB   )r-   gridr;   r<   )r'   r(   )r+   r   rW   rX   r   rZ   rd   r   r=   rr   rs   rt   compute_w_affiner   rm   r0   rn   ro   r(   )r-   rD   r   	resamplerrC   r;   r<   r   rS   rH   rG   img_sizer   rE   r   r   r.   r.   r/   r&   #  s8    $ "	
)NN)@__doc__
__future__r   r   r   enumr   numpyrh   rZ   rr   monai.configr   monai.config.type_definitionsr   monai.data.meta_objr   monai.data.meta_tensorr   monai.data.utilsr   r	   r
   monai.networks.layersr   monai.transforms.croppad.arrayr   Z monai.transforms.intensity.arrayr   monai.transforms.inverser   monai.transforms.utilsr   r   r   r   0monai.transforms.utils_pytorch_numpy_unificationr   monai.utilsr   r   r   r   r   r   r   r   r   r   has_nibr   r|   cupy_ndinp_ndi__all__r0   r   r    r!   r"   r#   r$   r%   r&   r.   r.   r.   r/   <module>   sB   ,
w.$JGT2