U
    |Ph                     @  s   d dl mZ d dlZd dlZd dlmZ d dlmZ d dlZd dl	m
Z
 d dlmZmZ d dlmZ d dlmZmZmZmZ ed	e
jed
\ZZede
jed\ZZerd dlmZ nede
jed\ZZG dd dZdS )    )annotationsN)Callable)TYPE_CHECKING)
IgniteInfo)CSVSaverdecollate_batch)ImageMetaKey)evenly_divisible_all_gathermin_versionoptional_importstring_list_all_gatherignitedistributedzignite.engineEvents)Enginer   c                   @  s   e Zd ZdZdddddd dd d	d
d	f	ddddddddddd
ddZdddddZdddddZdddddZdddddZd	S ) ClassificationSaverz
    Event handler triggered on completing every iteration to save the classification predictions as CSV file.
    If running in distributed data parallel, only saves CSV file in the specified rank.

    z./zpredictions.csv,Tc                 C  s   | S N xr   r   X/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/handlers/classification_saver.py<lambda>/       zClassificationSaver.<lambda>c                 C  s   | S r   r   r   r   r   r   r   0   r   Nr   strboolr   z
str | NoneintzCSVSaver | NoneNone)

output_dirfilename	delimiter	overwritebatch_transformoutput_transformname	save_ranksaverreturnc
           
      C  sR   || _ || _|| _|| _|| _|| _|| _|	| _t	|| _
|| _g | _g | _dS )a  
        Args:
            output_dir: if `saver=None`, output CSV file directory.
            filename: if `saver=None`, name of the saved CSV file name.
            delimiter: the delimiter character in the saved file, default to "," as the default output type is `csv`.
                to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
            overwrite: if `saver=None`, whether to overwriting existing file content, if True,
                will clear the file before saving. otherwise, will append new content to the file.
            batch_transform: a callable that is used to extract the `meta_data` dictionary of
                the input images from `ignite.engine.state.batch`. the purpose is to get the input
                filenames from the `meta_data` and store with classification results together.
                `engine.state` and `batch_transform` inherit from the ignite concept:
                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
            output_transform: a callable that is used to extract the model prediction data from
                `ignite.engine.state.output`. the first dimension of its output will be treated as
                the batch dimension. each item in the batch will be saved individually.
                `engine.state` and `output_transform` inherit from the ignite concept:
                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
            name: identifier of logging.logger to use, defaulting to `engine.logger`.
            save_rank: only the handler on specified rank will save to CSV file in multi-gpus validation,
                default to 0.
            saver: the saver instance to save classification results, if None, create a CSVSaver internally.
                the saver must provide `save_batch(batch_data, meta_data)` and `finalize()` APIs.

        N)r%   r   r   r    r!   r"   r#   r&   logging	getLoggerlogger_name_outputs
_filenames)
selfr   r   r    r!   r"   r#   r$   r%   r&   r   r   r   __init__)   s    'zClassificationSaver.__init__r   )enginer'   c                 C  sr   | j dkr|j| _|| jtjs2|tj| j || tjsN|tj|  || jtj	sn|tj	| j dS )zg
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        N)
r+   r*   Zhas_event_handler_startedr   EPOCH_STARTEDadd_event_handlerITERATION_COMPLETED	_finalizeEPOCH_COMPLETED)r.   r0   r   r   r   attach^   s    
zClassificationSaver.attach)_enginer'   c                 C  s   g | _ g | _dS )zs
        Initialize internal buffers.

        Args:
            _engine: Ignite Engine, unused argument.

        N)r,   r-   )r.   r8   r   r   r   r1   l   s    zClassificationSaver._startedc                 C  s|   |  |jj}t|tr t|}| |jj}t||D ]>\}}| j	
|tj  t|tjrj| }| j
| q8dS )z
        This method assumes self.batch_transform will extract metadata from the input batch.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        N)r"   statebatch
isinstancedictr   r#   outputzipr-   appendgetKeyFILENAME_OR_OBJtorchTensordetachr,   )r.   r0   	meta_dataZengine_outputmor   r   r   __call__w   s    
zClassificationSaver.__call__c                 C  s   t  }| j|krtdtj| jdd}| j}|dkrLt|dd}t	|}t
|dkr^d}n:t
|t
|krtdt
| d	t
| d
 tj|i}t  | jkr| jpt| j| j| j| jd}||| |  dS )z
        All gather classification results from ranks and save to CSV file.

        Args:
            _engine: Ignite Engine, unused argument.
        z<target save rank is greater than the distributed group size.r   )dim   T)concatNzfilenames length: z doesn't match outputs length: .)r   r   r!   r    )idistget_world_sizer%   
ValueErrorrC   stackr,   r-   r	   r   lenwarningswarnrA   rB   get_rankr&   r   r   r   r!   r    
save_batchfinalize)r.   r8   wsoutputs	filenames	meta_dictr&   r   r   r   r5      s,    
 
   zClassificationSaver._finalize)	__name__
__module____qualname____doc__r/   r7   r1   rI   r5   r   r   r   r   r   "   s   "5r   )
__future__r   r(   rS   collections.abcr   typingr   rC   monai.configr   
monai.datar   r   monai.utilsr   rA   r	   r
   r   r   OPT_IMPORT_VERSIONrN   _r   ignite.enginer   r   r   r   r   r   <module>   s   