U
    PhT1                     @  s>  d Z ddlmZ ddlZddlZddlZddlmZ	 ddl
mZ ddlmZ ddlmZ ddlmZ dd	lmZ dd
lmZmZ ddlmZmZmZmZmZ ddddgZdd ZdddddddZdddddddZ ej!fdddddddZ"ddddddZ#ej!d fdd!d"dd#dd$d%dZ$dd&d#d"dd'd(dZ%dS ))zA
A collection of "functional" transforms for spatial operations.
    )annotationsN)pad)NdarrayTensor)get_track_meta)
MetaTensor)to_affine_nd)TraceableTransform)convert_pad_modecreate_translate)PytorchPadModeconvert_to_dst_typeconvert_to_numpyconvert_to_tensorensure_tuplepad_ndpad_func	crop_funccrop_or_pad_ndc                 C  s8   | dks| dkrt jS | dkr$t jS | dkr2t jS t jS )zSget the most similar mode of `pad` from ``padding_mode`` of the spatial resampling.N)zerosconstantzgrid-constant)
reflectionreflectmirrorzgrid-mirror)wrapz	grid-wrap)r   CONSTANTREFLECTCIRCULAR	REPLICATE)padding_mode r   X/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/transforms/croppad/functional.py_convert_pt_pad_mode"   s    r!   r   zlist[tuple[int, int]]str)img	pad_widthmodereturnc              	   K  s   t | tjrF| jr4td| j d| j d| d |  	 
 }n| }t||dj}|dkrvd|krv|d|d< tj||fd	|i|}t|| d
d S )NzPadding: moving img z from cuda to cpu for dtype=z mode=.dstr%   r   valueconstant_valuesr%   r)   r   )
isinstancetorchTensoris_cudawarningswarnshapedtypedetachcpunumpyr	   r*   popnpr   r   )r#   r$   r%   kwargsZimg_npr   r   r    _np_pad-   s    "r;   c                 K  s   t | }t||dj}|dkr@d|kr@| }|d|d< n|}dd |dd  D d d d }t|d	|fd
|i|d	}t	|| dd	 S )Nr(   r   r+   r*   c                 S  s$   g | ]}|d d d D ]}|qqS )Nr   ).0sublistvalr   r   r    
<listcomp>C   s       z_pt_pad.<locals>.<listcomp>   r<   r   r%   r,   )
r.   	as_tensorr	   r*   copyr8   pad_pt	unsqueezesqueezer   )r#   r$   r%   r:   Zimg_pt_kwargsZpt_pad_widthr   r   r    _pt_pad;   s    
 "rH   )r#   to_padr%   r&   c                   s  |dkrt | f||d|S zBt }|dkrJ| jtjtjtjtjhkrJt}|| f||d|W S  tt	t
fk
r
   zt tst fdddD rt | f||d| W Y HS t| j d| d| d| d| j dt| tjr| jnd  W 5 d  X Y nX dS )	a  
    Pad `img` for a given an amount of padding in each dimension.

    `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch,
    in which case `np.pad` will be used.

    Args:
        img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.
        to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].
            default to `self.to_pad`.
        mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
            ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
            (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
            See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        kwargs: other arguments for the `np.pad` or `torch.pad` function.
            note that `np.pad` treats channel dimension as the first dimension.
    >   median	symmetriclinear_rampminimummaximummeanempty)r$   r%   >   r   	replicater   r   edgecircularc                 3  s   | ]}|t  kV  qd S )N)r"   )r=   kerrr   r    	<genexpr>l   s    zpad_nd.<locals>.<genexpr>)	supportedzunexpected keywordimplementedr*    N)r;   r4   r.   int16int64booluint8rH   
ValueError	TypeErrorRuntimeErrorr-   NotImplementedErroranyr3   r/   device)r#   rI   r%   r:   _padr   rU   r    r   I   s,    :ztorch.Tensorztuple[int, ...])r#   spatial_sizer%   c                 K  s  t | jd }tt|t|dd }tt ||}ttjdd |D ddi}|	t |df}t
|t|t|d	d f }|jdd
|jdd
 }	}
dgtd	gddf\}}}}t|	|
| jdd	 D ]\}}}|p|dk p||d k|p|dkp||d k  }}||dkr,dnt| ||d k rFdnt|| d fg7 }|ttt|dt|d |d d  g7 }q|rt|}t| |fd|i|} |r| | } | S )a  
    Crop or pad using the translation matrix and spatial size. The translation coefficients are rounded
    to the nearest integers. For a more generic implementation, please see :py:class:`monai.transforms.SpatialResample`.

    Args:
        img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.
        translation_mat: the translation matrix to be applied to the image. A translation matrix generated by,
            for example, :py:func:`monai.transforms.utils.create_translate`. The translation coefficients are rounded
            to the nearest integers.
        spatial_size: the spatial size of the output image.
        mode: the padding mode.
        kwargs: other arguments for the `np.pad` or `torch.pad` function.
    rA   T)wrap_sequencec                 S  s   g | ]}d |d  gqS )g      ?r   )r=   xr   r   r    r@      s     z"crop_or_pad_nd.<locals>.<listcomp>indexingijr<   N)axisr   r   Fr   r%   )lenr3   r9   roundr   r   rC   asarraymeshgridreshapefloorconcatenate	ones_likeminmaxslicezipintr!   r   )r#   Ztranslation_matrf   r%   r:   ndim	matrix_npccZsrc_ccZ	src_startZsrc_endrI   Zto_cropdo_padZdo_cropsesp_moder   r   r    r   u   s$     & 6@0Fztuple[tuple[int, int]]dictr]   )r#   rI   transform_infor%   lazyr&   c              	   K  s  || d}t | tr|  n| jdd }t | tr>|  nd}t| }	|	rdd |D }
t|
t| jk r|
dgt| jt|
  7 }
dd |
dd D }t	||}d	d t
||
dd D }n$|}tjt|d td
tjd}tj| ||||||d}tt | tr |  n| t d}|rLt |trH||S |S |	rbt||
|f|n|}t|t d}t |tr||S |S )a7  
    Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according
    to ``lazy`` (default ``False``).

    `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch,
    in which case `np.pad` will be used.

    Args:
        img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.
        to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].
            note that it including channel dimension.
        transform_info: a dictionary with the relevant information pertaining to an applied transform.
        mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
            ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
            (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
            See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        lazy: a flag indicating whether the operation should be performed in a lazy fashion or not.
        transform_info: a dictionary with the relevant information pertaining to an applied transform.
        kwargs: other arguments for the `np.pad` or `torch.pad` function.
            note that `np.pad` treats channel dimension as the first dimension.
    )paddedr%   rA   N   c                 S  s$   g | ]}t |d  t |d fqS r   rA   )ry   )r=   pr   r   r    r@      s     zpad_func.<locals>.<listcomp>rl   c                 S  s   g | ]}|d   qS )r   r   )r=   r~   r   r   r    r@      s     c                 S  s    g | ]\}\}}|| | qS r   r   )r=   dr~   r   r   r   r    r@      s    
 r6   )rd   r4   sp_sizeaffine
extra_info	orig_sizer   r   
track_meta)r-   r   peek_pending_shaper3   peek_pending_rankr9   ro   rc   rm   r
   rx   r.   eyery   rd   float64r   track_transform_metar   rB   r   copy_meta_fromr   )r#   rI   r   r%   r   r:   r   img_sizespatial_rankr}   Zto_pad_listto_shiftxformr3   	meta_infooutr   r   r    r      s8     
 	"ztuple[slice, ...])r#   slicesr   r   r&   c              	   C  sb  t | tr|  n| jdd }t | tr2|  nd}tdd t|dd |D }d| 	 i}g }t
t|dd D ]B\}	}
|
jdk	r||
jdk r||	 |
j n|
j q|d qdd t|dd |D }tj| |t||||||d	}tt | tr|  n| t d
}|r@t |tr<||S |S || }t |tr^||S |S )aI  
    Functional implementation of cropping a MetaTensor. This function operates eagerly or lazily according
    to ``lazy`` (default ``False``).

    Args:
        img: data to be transformed, assuming `img` is channel-first and cropping doesn't apply to the channel dim.
        slices: the crop slices computed based on specified `center & size` or `start & end` or `slices`.
        lazy: a flag indicating whether the operation should be performed in a lazy fashion or not.
        transform_info: a dictionary with the relevant information pertaining to an applied transform.
    rA   Nr   c                 S  s0   g | ](\}}| |d  || |d  gqS r   indicesr=   r~   or   r   r    r@      s     zcrop_func.<locals>.<listcomp>croppedr   c                 S  s,   g | ]$\}}| |d  | |d  qS )rA   r   r   r   r   r   r    r@      s     r   r   )r-   r   r   r3   r   r9   ro   rx   flattentolist	enumerater   startappendr   r   r
   r   rB   r   r   )r#   r   r   r   r   r   r   r   r   ir~   r3   r   r   r   r   r    r      s0     "
&	")&__doc__
__future__r   r1   r7   r9   r.   torch.nn.functionalr   rD   monai.config.type_definitionsr   monai.data.meta_objr   monai.data.meta_tensorr   monai.data.utilsr   monai.transforms.inverser   monai.transforms.utilsr	   r
   monai.utilsr   r   r   r   r   __all__r!   r;   rH   r   r   r   r   r   r   r   r   r    <module>   s.   ,&>