U
    Ph#                     @  s   d dl mZ d dlmZmZmZ d dlZd dlZd dl	m
  mZ d dlmZ d dlmZ d dlmZmZmZ dgZG dd deZdS )	    )annotations)AnyCallableSequenceN)SlidingWindowInferer)sliding_window_inference)	BlendModePytorchPadModelook_up_optionSlidingWindowHoVerNetInfererc                      s|   e Zd ZdZddejdejdddddddfdd	d
dddd
dddddddd fddZdd ZddddddddZ	  Z
S )r   a  
    Sliding window method for HoVerNet model inference,
    with `sw_batch_size` windows for every model.forward().
    Usage example can be found in the :py:class:`monai.inferers.Inferer` base class.

    Args:
        roi_size: the window size to execute SlidingWindow evaluation.
            If it has non-positive components, the corresponding `inputs` size will be used.
            if the components of the `roi_size` are non-positive values, the transform will use the
            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        sw_batch_size: the batch size to run window slices.
        overlap: Amount of overlap between scans.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant``": gives equal weight to all predictions.
            - ``"gaussian``": gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for 'constant' padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a tqdm progress bar.
        cache_roi_weight_map: whether to pre-compute the ROI weight map.
        cpu_thresh: when provided, dynamically switch to stitching on cpu (to save gpu memory)
            when input image volume is larger than this threshold (in pixels/voxels).
            Otherwise use ``"device"``. Thus, the output may end-up on either cpu or gpu.
        extra_input_padding: the amount of padding for the input image, which is a tuple of even number of pads.
            Refer to to the `pad` argument of `torch.nn.functional.pad` for more details.

    Note:
        ``sw_batch_size`` denotes the max number of windows per network inference iteration,
        not the batch size of inputs.

       g      ?g      ?g        NFzSequence[int] | intintfloatzBlendMode | strzSequence[float] | floatzPytorchPadMode | strztorch.device | str | Noneboolz
int | Noneztuple[int] | NoneNone)roi_sizesw_batch_sizeoverlapmodesigma_scalepadding_modecval	sw_devicedeviceprogresscache_roi_weight_map
cpu_threshextra_input_paddingreturnc                   s.   t  j|||||||||	|
||d || _d S )N)r   r   r   r   r   r   r   r   r   r   r   r   )super__init__r   )selfr   r   r   r   r   r   r   r   r   r   r   r   r   	__class__ Z/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/apps/pathology/inferers/inferer.pyr    L   s    z%SlidingWindowHoVerNetInferer.__init__c                   s   |j dd  }|d j dd  }g g }t||D ]H\}}t|| d}	|	d }
|	|
 }|
|g |t|
||  q2tj||j|j	d}|| ||< t
 fdd|D }||fS )N   r   )dtyper   c                 3  s(   | ] }t j|t j jd V  qdS )padr   valueN)Fr)   tupler   r   ).0Zseg_probr!   Zwindow_pad_sizer$   r%   	<genexpr>}   s   z>SlidingWindowHoVerNetInferer.process_output.<locals>.<genexpr>)shapezipmaxextendappendslicetorchzerosr'   r   r,   )r!   Zseg_prob_tupleZwindow_dataimportance_map_window_shape	seg_shapeZwindow_pad_slicesZwindow_sZoutput_s	pad_widthZ
pad_half_1Z
pad_half_2importance_mapr$   r.   r%   process_outputl   s     z+SlidingWindowHoVerNetInferer.process_outputztorch.TensorzNCallable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]]r   zAtorch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor])inputsnetworkargskwargsr   c                   s  | j }|dkr4| jdk	r4|jdd  | jkr4d}| jrr|jdd }t|}tj|t| jt	| j
t| jd}t|| j| j|| j| j| j| j
| j| j|| j| j| j| j| jdf||}| jrg  t| jd }	t|	D ]>}
t| j|
d  |||
 d  | j|
d   } d| qtt|j|	 D ]} dtd q0t|trv| D ]\}}|  ||< q\n^t|ttfrt | fdd	|D }n2t|t!j"t#j$fr|  }nt%d
t | d|S )ag  

        Args:
            inputs: model input data for inference.
            network: target model to execute inference.
                supports callables such as ``lambda x: my_torch_model(x, additional_config)``
            args: optional args to be passed to ``network``.
            kwargs: optional keyword args to be passed to ``network``.

        Nr&   cpur(   Fr   r   c                   s   g | ]}|  qS r$   r$   )r-   resZextra_slicingr$   r%   
<listcomp>   s     z9SlidingWindowHoVerNetInferer.__call__.<locals>.<listcomp>zThe output [zC] should be either dict, list, tuple, torch.Tensor, or numpy array.)&r   r   r0   numelr   lenr+   r)   r,   r
   r   r	   r   r   r   r   r   r   r   r   r   roi_weight_mapr=   buffer_steps
buffer_dimranger5   insert
isinstancedictitemslisttyper6   Tensornpndarray
ValueError)r!   r>   r?   r@   rA   r   Zimage_size_originalnum_spatial_dimsresultsZnum_padded_dimssp	slice_dim_kvr$   rD   r%   __call__   sr    *

z%SlidingWindowHoVerNetInferer.__call__)__name__
__module____qualname____doc__r   CONSTANTr	   r    r=   r]   __classcell__r$   r$   r"   r%   r      s    3. )
__future__r   typingr   r   r   numpyrS   r6   torch.nn.functionalnn
functionalr+   monai.inferersr   monai.inferers.utilsr   monai.utilsr   r	   r
   __all__r   r$   r$   r$   r%   <module>   s   