o
    i#                     @  s   d dl mZ d dlmZ d dlmZmZ d dlZd dl	Z	d dl
m  mZ d dlmZ d dlmZ d dlmZmZmZ dgZG d	d deZdS )
    )annotations)Sequence)AnyCallableN)SlidingWindowInferer)sliding_window_inference)	BlendModePytorchPadModelook_up_optionSlidingWindowHoVerNetInfererc                      sR   e Zd ZdZddejdejdddddddfd/ fd!d"Zd#d$ Zd0d-d.Z	  Z
S )1r   a  
    Sliding window method for HoVerNet model inference,
    with `sw_batch_size` windows for every model.forward().
    Usage example can be found in the :py:class:`monai.inferers.Inferer` base class.

    Args:
        roi_size: the window size to execute SlidingWindow evaluation.
            If it has non-positive components, the corresponding `inputs` size will be used.
            if the components of the `roi_size` are non-positive values, the transform will use the
            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        sw_batch_size: the batch size to run window slices.
        overlap: Amount of overlap between scans.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant``": gives equal weight to all predictions.
            - ``"gaussian``": gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for 'constant' padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a tqdm progress bar.
        cache_roi_weight_map: whether to pre-compute the ROI weight map.
        cpu_thresh: when provided, dynamically switch to stitching on cpu (to save gpu memory)
            when input image volume is larger than this threshold (in pixels/voxels).
            Otherwise use ``"device"``. Thus, the output may end-up on either cpu or gpu.
        extra_input_padding: the amount of padding for the input image, which is a tuple of even number of pads.
            Refer to to the `pad` argument of `torch.nn.functional.pad` for more details.

    Note:
        ``sw_batch_size`` denotes the max number of windows per network inference iteration,
        not the batch size of inputs.

       g      ?g      ?g        NFroi_sizeSequence[int] | intsw_batch_sizeintoverlapfloatmodeBlendMode | strsigma_scaleSequence[float] | floatpadding_modePytorchPadMode | strcval	sw_devicetorch.device | str | Nonedeviceprogressboolcache_roi_weight_map
cpu_thresh
int | Noneextra_input_paddingtuple[int] | NonereturnNonec                   s.   t  j|||||||||	|
||d || _d S )N)r   r   r   r   r   r   r   r   r   r   r   r    )super__init__r"   )selfr   r   r   r   r   r   r   r   r   r   r   r    r"   	__class__ g/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/pathology/inferers/inferer.pyr'   M   s   
z%SlidingWindowHoVerNetInferer.__init__c                   s   |j dd  }|d j dd  }g g }t||D ]$\}}t|| d}	|	d }
|	|
 }|
|g |t|
||  qtj||j|j	d}|| ||< t
 fdd|D }||fS )N   r   )dtyper   c                 3  s*    | ]}t j|t j jd V  qdS )padr   valueN)Fr0   tupler   r   ).0Zseg_probr(   Zwindow_pad_sizer+   r,   	<genexpr>~   s
    
z>SlidingWindowHoVerNetInferer.process_output.<locals>.<genexpr>)shapezipmaxextendappendslicetorchzerosr.   r   r3   )r(   Zseg_prob_tupleZwindow_dataimportance_map_window_shape	seg_shapeZwindow_pad_slicesZwindow_sZoutput_s	pad_widthZ
pad_half_1Z
pad_half_2importance_mapr+   r5   r,   process_outputm   s    z+SlidingWindowHoVerNetInferer.process_outputinputstorch.TensornetworkNCallable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]]argsr   kwargsAtorch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]c                   s  | j }|du r| jdur|jdd  | jkrd}| jr9|jdd }t|}tj|t| jt	| j
t| jd}t|| j| j|| j| j| j| j
| j| j|| j| j| j| j| jdg|R i |}| jrg  t| jd }	t|	D ]}
t| j|
d  |||
 d  | j|
d   } d| qqtt|j|	 D ]
} dtd qt|tr| D ]
\}}|  ||< q|S t|ttfrt | fdd	|D }|S t|t!j"t#j$fr|  }|S t%d
t | d|S )ag  

        Args:
            inputs: model input data for inference.
            network: target model to execute inference.
                supports callables such as ``lambda x: my_torch_model(x, additional_config)``
            args: optional args to be passed to ``network``.
            kwargs: optional keyword args to be passed to ``network``.

        Nr-   cpur/   Fr   r   c                   s   g | ]}|  qS r+   r+   )r4   resZextra_slicingr+   r,   
<listcomp>   s    z9SlidingWindowHoVerNetInferer.__call__.<locals>.<listcomp>zThe output [zC] should be either dict, list, tuple, torch.Tensor, or numpy array.)&r   r    r7   numelr"   lenr2   r0   r3   r
   r   r	   r   r   r   r   r   r   r   r   r   roi_weight_maprD   buffer_steps
buffer_dimranger<   insert
isinstancedictitemslisttyper=   Tensornpndarray
ValueError)r(   rE   rG   rI   rJ   r   Zimage_size_originalnum_spatial_dimsresultsZnum_padded_dimssp	slice_dim_kvr+   rN   r,   __call__   sx   *


z%SlidingWindowHoVerNetInferer.__call__)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    r!   r"   r#   r$   r%   )
rE   rF   rG   rH   rI   r   rJ   r   r$   rK   )__name__
__module____qualname____doc__r   CONSTANTr	   r'   rD   rg   __classcell__r+   r+   r)   r,   r      s"    3 )
__future__r   collections.abcr   typingr   r   numpyr]   r=   torch.nn.functionalnn
functionalr2   monai.inferersr   monai.inferers.utilsr   monai.utilsr   r	   r
   __all__r   r+   r+   r+   r,   <module>   s   