U
    |Ph                     @  s   d dl mZ d dlZd dlZd dlmZ d dlZd dlm	Z	m
Z
 d dlmZ d dlmZmZmZ d dlmZ ede
jed	\ZZerd d
lmZ nede
jed\ZZG dd dZdS )    )annotationsN)TYPE_CHECKING)	DtypeLike
IgniteInfo)FolderLayout)ProbMapKeysmin_versionoptional_import)
CommonKeyszignite.engineEvents)Enginer   c                   @  sv   e Zd ZdZdddejdfdddddd	d
ddZdd	dddZdd	dddZdd	dddZ	dd	dddZ
dS )ProbMapProducera;  
    Event handler triggered on completing every iteration to calculate and save the probability map.
    This handler use metadata from MetaTensor to create the probability map. This can be simply achieved by using
    `monai.data.SlidingPatchWSIDataset` or `monai.data.MaskedPatchWSIDataset` as the dataset.

    z./ predNstrr   z
str | NoneNone)
output_diroutput_postfixprob_keydtypenamereturnc                 C  sZ   t ||ddddd| _t|| _|| _|| _|| _i | _i | _	d| _
d| _t | _dS )a  
        Args:
            output_dir: output directory to save probability maps.
            output_postfix: a string appended to all output file names.
            prob_key: the key associated to the probability output of the model
            dtype: the data type in which the probability map is stored. Default np.float64.
            name: identifier of logging.logger to use, defaulting to `engine.logger`.

        z.npyFTr   )r   postfix	extensionparentmakedirsdata_root_dirr   N)r   folder_layoutlogging	getLoggerlogger_namer   r   prob_mapcounternum_done_images
num_images	threadingLocklock)selfr   r   r   r   r    r*   T/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/handlers/probability_maps.py__init__(   s"    	zProbMapProducer.__init__r   )enginer   c                 C  s   |j jj}t|| _|D ]:}|tj }|tj | j|< t	j
|tj | jd| j|< q| jdkrf|j| _|| tjs|tj|  || jtjs|tj| j dS )zg
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        )r   N)data_loaderdataset
image_datalenr%   r   NAMECOUNTr#   npzerosSIZEr   r"   r!   r    has_event_handlerr   ITERATION_COMPLETEDadd_event_handlerfinalize	COMPLETED)r)   r-   r0   sampler   r*   r*   r+   attachL   s    



zProbMapProducer.attachc              
   C  s   t |jjtrt |jjts$td|jjtj jt	j
 }|jjtj jt	j }|jj| j }t|||D ]X\}}}|| j| t|< | j0 | j|  d8  < | j| dkr| | W 5 Q R X qjdS )z
        This method assumes self.batch_transform will extract metadata from the input batch.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        z@engine.state.batch and engine.state.output must be dictionaries.   r   N)
isinstancestatebatchdictoutput
ValueErrorr
   IMAGEmetar   r2   LOCATIONr   zipr"   tupler(   r#   save_prob_map)r)   r-   nameslocsprobsr   locprobr*   r*   r+   __call__b   s    zProbMapProducer.__call__)r   r   c              	   C  sd   | j |}t|| j|  |  jd7  _| jd| d| j d| j d | j|= | j	|= dS )z
        This method save the probability map for an image, when its inference is finished,
        and delete that probability map from memory.

        Args:
            name: the name of image to be saved.
        r>   zInference of 'z' is done [/z]!N)
r   filenamer4   saver"   r$   r    infor%   r#   )r)   r   	file_pathr*   r*   r+   rJ   u   s    $zProbMapProducer.save_prob_mapc                 C  s"   | j d| j d| j d d S )NzProbability map is created for rQ   z images!)r    rT   r$   r%   )r)   r-   r*   r*   r+   r:      s    zProbMapProducer.finalize)__name__
__module____qualname____doc__r4   float64r,   r=   rP   rJ   r:   r*   r*   r*   r+   r       s   	$r   )
__future__r   r   r&   typingr   numpyr4   monai.configr   r   monai.data.folder_layoutr   monai.utilsr   r   r	   monai.utils.enumsr
   OPT_IMPORT_VERSIONr   _ignite.enginer   r   r*   r*   r*   r+   <module>   s   