o
    /iQ                     @  s  d dl mZ d dlZd dlmZmZmZmZ d dlm	Z	 d dl
Zd dlZd dlm  mZ d dlmZ d dlmZmZmZ d dlmZmZmZmZmZmZmZmZm Z  e ddd	\Z!Z"d
Z#dgZ$dej%dej%dddddddddfdDd4dZ&d5d6 Z'd7d8 Z(dEd>d?Z)d@dA Z*dFdBdCZ+dS )G    )annotationsN)CallableIterableMappingSequence)Any)
MetaTensor)compute_importance_mapdense_patch_slicesget_valid_patch_size)		BlendModePytorchPadModeconvert_data_typeconvert_to_dst_typeensure_tupleensure_tuple_repfall_back_tuplelook_up_optionoptional_importtqdm)nameznearest-exactsliding_window_inferenceg      ?g      ?g        Finputstorch.Tensor | MetaTensorroi_sizeSequence[int] | intsw_batch_sizeint	predictorNCallable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]]overlapSequence[float] | floatmodeBlendMode | strsigma_scalepadding_modePytorchPadMode | strcvalfloat	sw_devicetorch.device | str | Nonedeviceprogressboolroi_weight_maptorch.Tensor | None
process_fnCallable | Nonebuffer_steps
int | None
buffer_dim
with_coordargsr   kwargsreturnAtorch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]c           Q        s	  |duo|dk}t jd }|r1|| k s||kr)td|  d| d| d|dk r1||7 }t||}|D ]}|dk sB|dkrJtd	| dq8j}j^}}|
pXj}
|	p]j}	|d
d d}ttrttg j	dd}t
tjddd ttfddt|D }g }tt jd ddD ]}t|d  j|  d}|d }|||| g qt|rtj|t|t|d durtj |t|t|d t|||}t||| dt | } |sd}!td| |}"n-t||||\}#}$}"tj d}!}%|$d|# D ]}&|&d |%k r&d}! n|&d }%qt|}'|'kr>|dur>|}(n>zt|'})t|)|||	|d}(t |(j|kr[|s[|(d }(W n ty{ }* zt d|) d| d| d|
 d	|*d}*~*ww t
|(tj|	|dd }(g g g ddf\}+},}-}.}/|rt!|"n|"D ]:}0t|0t"|0| |r|$|. d n| }1fdd|1D }2|dkrt#fdd|2D $|	}3 durt# fdd|2D $|	}4|4|d
< n|2d  $|	}3 dur |2d  $|	}4|4|d
< |r||3|2g|R i |}5n||3g|R i |}5t%|5\}6}7|r1||7|3|(\}7}8n|(}8t |8j|kr?|8d }8|8j$||	d }8|r|$|. dd \}9}:|-st|7d jd }t&|};|:|9 |;|< tj'd|g|;||	d!g}-t(|7d |2D ].\}<}=|=|d  j)|9 }>t*|>|>|  |=|d < t*dd|=d< |-d |=  |<|8 7  < q{|/t |27 }/|/|$|. d k rqnt&|7}-tt |-D ]}?|-|? j}@|@d |@dd }A}Bd}C|s|Bkrd"d t(|BD }Ctj+|8|Bt,d#}8t |+|?krh||Ag}D|D|Crd$d t(||CD nt&|7 }D|!rtj-ntj'}E|+.|E|D||
d  |,.tj'ddg|Ddd  ||
d  |8$|
}FD ]'}G|CdurStd%d t(|G|CD }G|,d t*dt*dg|GR   |F7  < q@|rt*dgt j }Ht*|9|:|H|d < |.|# }It*|I|Id |Hd< |!r|+d |H j/|-d |!d& q|+d |H  |-d j$|
d'7  < q|-|?  |89  < |-|? $|
|-|?< t0|2|C|+|? |-|?  qg }-|r|.d7 }.q|!rtj1 2  tt |+D ]}?|+|?  |,d  < qt|rb|3d(|i t4|+D ]Z\}?}Jd)d t(|Jjdd D }Kg }Lt|D ]0}M||M d }Nt*t5t6||Md  |K|N  t5t6||Md  |N  |K|N  }O|L7d|O q |Jt*dt*dg|LR  |+|?< qt8|+|6}P|durwt9|P||
d'd }P|PS t9|P|
d'd }P|PS )*a  
    Sliding window inference on `inputs` with `predictor`.

    The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
    Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
    e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
    could be ([128,64,256], [64,32,128]).
    In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
    an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
    so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).

    When roi_size is larger than the inputs' spatial size, the input image are padded during inference.
    To maintain the same spatial sizes, the output image will be cropped to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D])
        roi_size: the spatial window size for inferences.
            When its components have None or non-positives, the corresponding inputs dimension will be used.
            if the components of the `roi_size` are non-positive values, the transform will use the
            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given input tensor ``patch_data`` in shape NCHW[D],
            The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
            with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
            where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
            N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
            the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
            In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
            to ensure the scaled output ROI sizes are still integers.
            If the `predictor`'s input and output spatial sizes are different,
            we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
        overlap: Amount of overlap between scans along each spatial dimension, defaults to ``0.25``.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant``": gives equal weight to all predictions.
            - ``"gaussian``": gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for 'constant' padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a `tqdm` progress bar.
        roi_weight_map: pre-computed (non-negative) weight map for each ROI.
            If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
        process_fn: process inference output and adjust the importance map per window
        buffer_steps: the number of sliding window iterations along the ``buffer_dim``
            to be buffered on ``sw_device`` before writing to ``device``.
            (Typically, ``sw_device`` is ``cuda`` and ``device`` is ``cpu``.)
            default is None, no buffering. For the buffer dim, when spatial size is divisible by buffer_steps*roi_size,
            (i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency.
        buffer_dim: the spatial dimension along which the buffers are created.
            0 indicates the first spatial dimension. Default is -1, the last spatial dimension.
        with_coord: whether to pass the window coordinates to ``predictor``. Default is False.
            If True, the signature of ``predictor`` should be ``predictor(patch_data, patch_coord, ...)``.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.

    Note:
        - input must be channel-first and have a batch dim, supports N-D sliding window.

    Nr      zbuffer_dim must be in [z, z], got .   z"overlap must be >= 0 and < 1, got 	conditionF)	copy_attrT)wrap_sequencec                 3  s"    | ]}t  | | V  qd S N)max.0i)image_size_r    V/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/inferers/utils.py	<genexpr>        z+sliding_window_inference.<locals>.<genexpr>r   )padr#   value)return_slice)r#   r%   r,   dtype)NNzpatch size z, mode=z, sigma_scale=z	, device=z^
Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'.)r,   rN   c                   s:   g | ]}t |  |  d  t dgt|    qS )r=   N)slicelist)rD   idx)num_winslicesrG   rH   
<listcomp>   s    ,z,sliding_window_inference.<locals>.<listcomp>c                      g | ]} | qS rG   rG   rD   Z	win_slice)r   rG   rH   rT          c                   rU   rG   rG   rV   )r>   rG   rH   rT      rW   )rN   r,   )sizerN   r,   c                 S  s   g | ]
\}}|t | qS rG   )r)   )rD   Zout_w_iZin_w_irG   rG   rH   rT         )r#   c                 S  s   g | ]
\}}t || qS rG   )r   )rD   _i_zrG   rG   rH   rT     rY   c                 s  s2    | ]\}}t t|j| t|j| V  qd S rA   )rO   r   startstop)rD   _siZz_srG   rG   rH   rI     s   0 )non_blocking)r,   pad_sizec                 S  s   g | ]\}}|| qS rG   rG   )rD   Z_shape_dZ_roi_size_drG   rG   rH   rT   :  s    ):lenshape
ValueErrorr   rN   r,   pop
isinstancer   copy_meta_fromr   torchTensorr   tuplerangerB   extendanyFrK   r   r   _get_scan_intervalr
   _create_buffered_slicescudais_availabler   r   r	   	ExceptionRuntimeErrorr   mincatto_flatten_structrP   zeroszipr\   rO   interpolate_nearest_modeemptyappendcopy__compute_coordscurrent_streamsynchronizeupdate	enumerater   roundinsert_pack_structr   )Qr   r   r   r   r!   r#   r%   r&   r(   r*   r,   r-   r/   r1   r3   r5   r6   r7   r8   bufferednum_spatial_dimsocompute_dtype
batch_size_Z	temp_meta
image_sizer`   kdiffhalfscan_intervalZtotal_slicesr_   windows_rangen_per_batchb_slices_ssxZvalid_patch_sizeZimportance_map_Zvalid_p_sizeeZoutput_image_listZcount_map_listZsw_device_bufferb_sb_iZslice_gZslice_rangeZunravel_sliceZwin_dataZwin_conditionZseg_prob_out	dict_keysZ	seg_tupleZw_tc_startc_endsp_sizepsoffsetssb_shapeZseg_chnsZ	seg_shapez_scaleoutput_shape
new_tensorZw_t_Z__sZo_sliceZimg_boutput_iZ
zoom_scaleZfinal_slicingspsi	slice_dimZfinal_outputrG   )r>   rF   r   rR   r   rS   rH   r   *   sN  _







&



&&

*$
""

c                   s`  t | }|t j|dd|df dd }dd |D } |dd|f }t j|dddf ddd\}}}t | }dg|ddtt|t| d	 |d	 k r\	|d	  td
   fddt
|D }	g }
t|	D ]/\}}||dkr|	|d
  jt|  nddf }||jd
 t|  d
f }|
	|j||f qutj|	 }	|  |
|	fS )zrearrange slices for bufferingNr   	mergesort)kindc                 S  s   g | ]}t d d |D qS )c                 s  s"    | ]}t |d  |d V  qdS )r   r=   N)rO   )rD   crG   rG   rH   rI   R  rJ   z5_create_buffered_slices.<locals>.<listcomp>.<genexpr>)ri   rC   rG   rG   rH   rT   R  s    z+_create_buffered_slices.<locals>.<listcomp>T)return_countsreturn_indexr   r=   c              	     sJ   g | ]!}t  D ]}t |d   |  |d   |d   qqS )r   r=   )rj   )rD   brE   r   r   r   rG   rH   rT   [  s    .)npasarrayargsortuniquecumsumtolistrt   ra   r   r}   rj   r   r]   	itertoolschain)rS   r   r   r5   r3   Z	slices_npr   Z_b_lensZb_endsr   r   _s_rZs_sZs_erG   r   rH   ro   N  s(   
 ""*
ro   c              	   C  s   t | |D ]9\}}t|}|r6tdt|D ]}tt|| j||d   t|| j||d   ||< q||  |7  < qdS )zKsliding window batch spatial scaling indexing for multi-resolution outputs.r;   N)ry   rP   rj   ra   rO   r   r\   r]   )coordsr   outpatchoriginal_idxr   Zidx_zmaxisrG   rG   rH   r   i  s   0
r   r   Sequence[int]r   Sequence[float]tuple[int, ...]c                 C  s   t | |krtdt |  d| dt ||kr&tdt | d| dg }tt||D ]+\}}|| | | krE|t||  q/t|| d|  }||dkrW|nd q/t|S )z
    Compute scan interval according to the image size, roi size and overlap.
    Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
    use 1 instead to make sure sliding window works.

    zlen(image_size) z different from spatial dims r<   zlen(roi_size) r=   r   )ra   rc   ry   rj   r}   r   ri   )r   r   r   r!   r   rE   r   intervalrG   rG   rH   rn   u  s   	rn   c                   sb   d }t  tjr f}||fS t  tr)t  }t fdd|D }||fS t }||fS )Nc                 3  s    | ]} | V  qd S rA   rG   )rD   r   seg_outrG   rH   rI     s    z"_flatten_struct.<locals>.<genexpr>)re   rg   rh   r   sortedkeysri   r   )r   r   Z	seg_probsrG   r   rH   rw     s   
rw   c                 C  s@   |d urt t|| S t| ttfrt| dkr| d S t| S )Nr=   r   )dictry   re   rP   ri   ra   r   )r   r   rG   rG   rH   r     s
   r   )(r   r   r   r   r   r   r   r    r!   r"   r#   r$   r%   r"   r&   r'   r(   r)   r*   r+   r,   r+   r-   r.   r/   r0   r1   r2   r3   r4   r5   r   r6   r.   r7   r   r8   r   r9   r:   )
r   r   r   r   r   r   r!   r   r9   r   rA   ),
__future__r   r   collections.abcr   r   r   r   typingr   numpyr   rg   torch.nn.functionalnn
functionalrm   monai.data.meta_tensorr   monai.data.utilsr	   r
   r   monai.utilsr   r   r   r   r   r   r   r   r   r   r   r{   __all__CONSTANTr   ro   r   rn   rw   r   rG   rG   rG   rH   <module>   sD   ,  &
