o
    i`1                     @  s   d Z ddlmZ ddlZddlZddlZddlmZ	 ddl
mZ ddlmZ ddlmZ ddlmZ dd	lmZ dd
lmZmZ ddlmZmZmZmZmZ g dZdd Zd.ddZd.ddZ ej!fd/ddZ"d0d d!Z#ej!d"fd1d(d)Z$d2d,d-Z%dS )3zA
A collection of "functional" transforms for spatial operations.
    )annotationsN)pad)NdarrayTensor)get_track_meta)
MetaTensor)to_affine_nd)TraceableTransform)convert_pad_modecreate_translate)PytorchPadModeconvert_to_dst_typeconvert_to_numpyconvert_to_tensorensure_tuple)pad_ndpad_func	crop_funccrop_or_pad_ndc                 C  s8   | du s| dv rt jS | dv rt jS | dv rt jS t jS )zSget the most similar mode of `pad` from ``padding_mode`` of the spatial resampling.N)zerosconstantzgrid-constant)
reflectionreflectmirrorzgrid-mirror)wrapz	grid-wrap)r   CONSTANTREFLECTCIRCULAR	REPLICATE)padding_mode r   e/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/transforms/croppad/functional.py_convert_pt_pad_mode"   s   r!   imgr   	pad_widthlist[tuple[int, int]]modestrreturnc              	   K  s   t | tjr#| jrtd| j d| j d| d |  	 
 }nt| }t||dj}|dkr>d|v r>|d|d< tj||fd	|i|}t|| d
d S )NzPadding: moving img z from cuda to cpu for dtype=z mode=.dstr%   r   valueconstant_valuesr%   r*   r   )
isinstancetorchTensoris_cudawarningswarnshapedtypedetachcpunumpynpasarrayr	   r+   popr   r   )r"   r#   r%   kwargsZimg_npr   r   r    _np_pad-   s   "
r=   c                 K  s   t | }t||dj}|dkr d|v r | }|d|d< n|}dd |dd  D d d d }t|d	|fd
|i|d	}t	|| dd	 S )Nr)   r   r,   r+   c                 S  s$   g | ]}|d d d D ]}|qqS )Nr   ).0sublistvalr   r   r    
<listcomp>C      $ z_pt_pad.<locals>.<listcomp>   r>   r   r%   r-   )
r/   	as_tensorr	   r+   copyr;   pad_pt	unsqueezesqueezer   )r"   r#   r%   r<   Zimg_pt_kwargsZpt_pad_widthr   r   r    _pt_pad;   s   
 "rK   to_padc                   s  |dv rt | f||d|S z!t }|dv r%| jtjtjtjtjhvr%t}|| f||d|W S  tt	t
fy   zBt tsKt fdddD r\t | f||d|W  Y d  S t| j d| d| d| d| j dt| tjrw| jnd  d  ww )	a  
    Pad `img` for a given an amount of padding in each dimension.

    `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch,
    in which case `np.pad` will be used.

    Args:
        img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.
        to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].
            default to `self.to_pad`.
        mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
            ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
            (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
            See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        kwargs: other arguments for the `np.pad` or `torch.pad` function.
            note that `np.pad` treats channel dimension as the first dimension.
    >   emptymaximummeanlinear_rampminimum	symmetricmedian)r#   r%   >   r   circularr   r   	replicateedgec                 3  s    | ]	}|t  v V  qd S )N)r&   )r?   kerrr   r    	<genexpr>l   s    
zpad_nd.<locals>.<genexpr>)	supportedzunexpected keywordZimplementedr+   N )r=   r5   r/   int16int64booluint8rK   
ValueError	TypeErrorRuntimeErrorr.   NotImplementedErroranyr4   r0   device)r"   rL   r%   r<   _padr   rX   r    r   I   s0   ":r   torch.Tensorspatial_sizetuple[int, ...]c                 K  s  t | jd }tt|t|dd }tt ||}ttjdd |D ddi}|	t |df}t
|t|t|d	d f }|jdd
|jdd
}	}
dgtd	gddf\}}}}t|	|
| jdd	 D ]S\}}}|p~|dk p~||d k|p|dkp||d k }}||dkrdnt| ||d k rdnt|| d fg7 }|ttt|dt|d |d d  g7 }qn|rt|}t| |fd|i|} |r| | } | S )a  
    Crop or pad using the translation matrix and spatial size. The translation coefficients are rounded
    to the nearest integers. For a more generic implementation, please see :py:class:`monai.transforms.SpatialResample`.

    Args:
        img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.
        translation_mat: the translation matrix to be applied to the image. A translation matrix generated by,
            for example, :py:func:`monai.transforms.utils.create_translate`. The translation coefficients are rounded
            to the nearest integers.
        spatial_size: the spatial size of the output image.
        mode: the padding mode.
        kwargs: other arguments for the `np.pad` or `torch.pad` function.
    rD   T)wrap_sequencec                 S  s   g | ]}d |d  gqS )g      ?r   )r?   xr   r   r    rB      s    z"crop_or_pad_nd.<locals>.<listcomp>indexingijr>   N)axisr   r   Fr   r%   )lenr4   r9   roundr   r   rF   r:   meshgridreshapefloorconcatenate	ones_likeminmaxslicezipintr!   r   )r"   Ztranslation_matri   r%   r<   ndim	matrix_npccZsrc_ccZ	src_startZsrc_endrL   Zto_cropdo_padZdo_cropsesp_moder   r   r    r   u   s$    & 2<0r   Ftuple[tuple[int, int]]transform_infodictlazyr_   c              	   K  s  || d}t | tr|  n| jdd }t | tr|  nd}t| }	|	rgdd |D }
t|
t| jk rH|
dgt| jt|
  7 }
dd |
dd D }t	||}d	d t
||
dd D }n|}tjt|d td
tjd}tj| ||||||d}tt | tr|  n| t d}|rt |tr||S |S |	rt||
|fi |n|}t|t d}t |tr||S |S )a7  
    Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according
    to ``lazy`` (default ``False``).

    `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch,
    in which case `np.pad` will be used.

    Args:
        img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.
        to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].
            note that it including channel dimension.
        transform_info: a dictionary with the relevant information pertaining to an applied transform.
        mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
            ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
            (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
            See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        lazy: a flag indicating whether the operation should be performed in a lazy fashion or not.
        transform_info: a dictionary with the relevant information pertaining to an applied transform.
        kwargs: other arguments for the `np.pad` or `torch.pad` function.
            note that `np.pad` treats channel dimension as the first dimension.
    )paddedr%   rD   N   c                 S  s$   g | ]}t |d  t |d fqS r   rD   )r|   )r?   pr   r   r    rB      rC   zpad_func.<locals>.<listcomp>rp   c                 S  s   g | ]}|d   qS )r   r   )r?   r   r   r   r    rB      s    c                 S  s    g | ]\}\}}|| | qS r   r   )r?   dr   r   r   r   r    rB      s     r7   )rf   r5   sp_sizeaffine
extra_info	orig_sizer   r   
track_meta)r.   r   peek_pending_shaper4   peek_pending_rankr9   r:   re   rq   r
   r{   r/   eyer|   rf   float64r   track_transform_metar   rE   r   copy_meta_fromr   )r"   rL   r   r%   r   r<   r   img_sizespatial_rankr   Zto_pad_listto_shiftxformr4   	meta_infooutr   r   r    r      s8    
  	r   slicestuple[slice, ...]c              	   C  sZ  t | tr	|  n| jdd }t | tr|  nd}tdd t|dd |D }d| 	 i}g }t
t|dd D ]!\}	}
|
jdur\||
jdk rW||	 |
j n|
j q@|d q@dd t|dd |D }tj| |t||||||d	}tt | tr|  n| t d
}|rt |tr||S |S || }t |tr||S |S )aI  
    Functional implementation of cropping a MetaTensor. This function operates eagerly or lazily according
    to ``lazy`` (default ``False``).

    Args:
        img: data to be transformed, assuming `img` is channel-first and cropping doesn't apply to the channel dim.
        slices: the crop slices computed based on specified `center & size` or `start & end` or `slices`.
        lazy: a flag indicating whether the operation should be performed in a lazy fashion or not.
        transform_info: a dictionary with the relevant information pertaining to an applied transform.
    rD   Nr   c                 S  s0   g | ]\}}| |d  || |d  gqS r   indicesr?   r   or   r   r    rB      s   0 zcrop_func.<locals>.<listcomp>croppedr   c                 S  s,   g | ]\}}| |d  | |d  qS )rD   r   r   r   r   r   r    rB      s   , r   r   )r.   r   r   r4   r   r9   r:   r{   flattentolist	enumerater   startappendr   r   r
   r   rE   r   r   )r"   r   r   r   r   r   r   r   r   ir   r4   r   r   r   r   r    r      s0    "
& 	r   )r"   r   r#   r$   r%   r&   r'   r   )r"   r   rL   r$   r%   r&   r'   r   )r"   rh   ri   rj   r%   r&   )r"   rh   rL   r   r   r   r%   r&   r   r_   r'   rh   )
r"   rh   r   r   r   r_   r   r   r'   rh   )&__doc__
__future__r   r2   r8   r9   r/   torch.nn.functionalr   rG   monai.config.type_definitionsr   monai.data.meta_objr   monai.data.meta_tensorr   monai.data.utilsr   monai.transforms.inverser   monai.transforms.utilsr	   r
   monai.utilsr   r   r   r   r   __all__r!   r=   rK   r   r   r   r   r   r   r   r   r    <module>   s0   



,&>