o
    i                     @  s   d dl mZ d dlZd dlZd dlmZ d dlmZ d dlZd dl	m
Z
mZ d dlmZ d dlmZ d dlmZmZmZmZ ed	ejed
\ZZedejed\ZZer[d dlmZ n
edejed\ZZG dd dZdS )    )annotationsN)Callable)TYPE_CHECKING)CSVSaverdecollate_batch)
IgniteInfo)ImageMetaKey)evenly_divisible_all_gathermin_versionoptional_importstring_list_all_gatherignitedistributedzignite.engineEvents)Enginer   c                	   @  s^   e Zd ZdZdddddd dd d	d
d	f	d)ddZd*d d!Zd+d#d$Zd*d%d&Zd+d'd(Zd	S ),ClassificationSaverz
    Event handler triggered on completing every iteration to save the classification predictions as CSV file.
    If running in distributed data parallel, only saves CSV file in the specified rank.

    z./zpredictions.csv,Tc                 C     | S N xr   r   e/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/handlers/classification_saver.py<lambda>/       zClassificationSaver.<lambda>c                 C  r   r   r   r   r   r   r   r   0   r   Nr   
output_dirstrfilename	delimiter	overwriteboolbatch_transformr   output_transformname
str | None	save_rankintsaverCSVSaver | NonereturnNonec
           
      C  sR   || _ || _|| _|| _|| _|| _|| _|	| _t	|| _
|| _g | _g | _dS )a  
        Args:
            output_dir: if `saver=None`, output CSV file directory.
            filename: if `saver=None`, name of the saved CSV file name.
            delimiter: the delimiter character in the saved file, default to "," as the default output type is `csv`.
                to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
            overwrite: if `saver=None`, whether to overwriting existing file content, if True,
                will clear the file before saving. otherwise, will append new content to the file.
            batch_transform: a callable that is used to extract the `meta_data` dictionary of
                the input images from `ignite.engine.state.batch`. the purpose is to get the input
                filenames from the `meta_data` and store with classification results together.
                `engine.state` and `batch_transform` inherit from the ignite concept:
                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
            output_transform: a callable that is used to extract the model prediction data from
                `ignite.engine.state.output`. the first dimension of its output will be treated as
                the batch dimension. each item in the batch will be saved individually.
                `engine.state` and `output_transform` inherit from the ignite concept:
                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
            name: identifier of logging.logger to use, defaulting to `engine.logger`.
            save_rank: only the handler on specified rank will save to CSV file in multi-gpus validation,
                default to 0.
            saver: the saver instance to save classification results, if None, create a CSVSaver internally.
                the saver must provide `save_batch(batch_data, meta_data)` and `finalize()` APIs.

        N)r%   r   r   r   r   r!   r"   r'   logging	getLoggerlogger_name_outputs
_filenames)
selfr   r   r   r   r!   r"   r#   r%   r'   r   r   r   __init__)   s   '
zClassificationSaver.__init__enginer   c                 C  sv   | j du r	|j| _|| jtjs|tj| j || tjs'|tj|  || jtj	s9|tj	| j dS dS )zg
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        N)
r.   r-   Zhas_event_handler_startedr   EPOCH_STARTEDadd_event_handlerITERATION_COMPLETED	_finalizeEPOCH_COMPLETED)r1   r3   r   r   r   attach^   s   
zClassificationSaver.attach_enginec                 C  s   g | _ g | _dS )zs
        Initialize internal buffers.

        Args:
            _engine: Ignite Engine, unused argument.

        N)r/   r0   )r1   r;   r   r   r   r4   l   s   
zClassificationSaver._startedc                 C  s|   |  |jj}t|trt|}| |jj}t||D ]\}}| j	
|tj  t|tjr5| }| j
| qdS )z
        This method assumes self.batch_transform will extract metadata from the input batch.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        N)r!   statebatch
isinstancedictr   r"   outputzipr0   appendgetKeyFILENAME_OR_OBJtorchTensordetachr/   )r1   r3   	meta_dataZengine_outputmor   r   r   __call__w   s   
zClassificationSaver.__call__c                 C  s   t  }| j|krtdtj| jdd}| j}|dkr&t|dd}t	|}t
|dkr/d}nt
|t
|krGtdt
| d	t
| d
 tj|i}t  | jkrn| jpat| j| j| j| jd}||| |  dS dS )z
        All gather classification results from ranks and save to CSV file.

        Args:
            _engine: Ignite Engine, unused argument.
        z<target save rank is greater than the distributed group size.r   )dim   T)concatNzfilenames length: z doesn't match outputs length: .)r   r   r   r   )idistget_world_sizer%   
ValueErrorrF   stackr/   r0   r	   r   lenwarningswarnrD   rE   get_rankr'   r   r   r   r   r   
save_batchfinalize)r1   r;   wsoutputs	filenames	meta_dictr'   r   r   r   r8      s(   
 
zClassificationSaver._finalize)r   r   r   r   r   r   r   r    r!   r   r"   r   r#   r$   r%   r&   r'   r(   r)   r*   )r3   r   r)   r*   )r;   r   r)   r*   )	__name__
__module____qualname____doc__r2   r:   r4   rL   r8   r   r   r   r   r   "   s     
5

r   )
__future__r   r+   rV   collections.abcr   typingr   rF   
monai.datar   r   monai.utilsr   r   rD   r	   r
   r   r   OPT_IMPORT_VERSIONrQ   _r   ignite.enginer   r   r   r   r   r   <module>   s    