U
    vPh6i                     @  s  d dl mZ d dlZd dlmZmZmZmZmZ d dl	Z	d dl
mZ d dlmZmZ d dlmZ d dlmZmZmZ d dlmZ d d	lmZmZ d d
lmZmZ d dlmZ d dlm Z m!Z!m"Z"m#Z# d dl$m%Z& d dl$m'Z( d dl)m*Z*m+Z+ erd dl,m-Z-m.Z. d dl/m0Z0 n<e#dej1e"d\Z-Z2e#dej1e"d\Z0Z2e#dej1e"d\Z.Z2dddgZ3G dd deZ4G dd de4Z5G dd de4Z6dS )    )annotationsN)TYPE_CHECKINGAnyCallableIterableSequence)
DataLoader)
IgniteInfoKeysCollection)
MetaTensor)IterationEventsdefault_metric_cmp_fndefault_prepare_batch)Workflow)InfererSimpleInferer)	eval_mode
train_mode)	Transform)ForwardModeensure_tuplemin_versionoptional_import)
CommonKeys)EngineStatsKeys)look_up_optionpytorch_after)Engine	EventEnum)Metriczignite.enginer   zignite.metricsr   r   	EvaluatorSupervisedEvaluatorEnsembleEvaluatorc                      s   e Zd ZdZddeddddeddejdddddfddddd	d
dddd	dddddddddd fddZdddd fddZ	dd Z
  ZS )r    aE  
    Base class for all kinds of evaluators, inherits from Workflow.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.
        epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
            from `engine.state.batch` for every iteration, for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
            if not provided, use `self._iteration()` instead. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_val_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        metric_cmp_fn: function to compare current key metric with previous best key metric value,
            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.
        mode: model forward mode during evaluation, should be 'eval' or 'train',
            which maps to `model.eval()` or `model.train()`, default to 'eval'.
        event_names: additional custom ignite events that will register to the engine.
            new events can be a list of str or `ignite.engine.events.EventEnum`.
        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
            #ignite.engine.engine.Engine.register_events.
        decollate: whether to decollate the batch-first data to a list of data after model computation,
            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
            default to `True`.
        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
            `device`, `non_blocking`.
        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.

    NFTztorch.device | strIterable | DataLoader
int | Noneboolr   #Callable[[Engine, Any], Any] | NoneTransform | Nonedict[str, Metric] | NoneSequence | NoneForwardMode | str.list[str | EventEnum | type[EventEnum]] | Nonedict | NoneNone)deviceval_data_loaderepoch_lengthnon_blockingprepare_batchiteration_updatepostprocessingkey_val_metricadditional_metricsmetric_cmp_fnval_handlersampmodeevent_namesevent_to_attr	decollate	to_kwargs
amp_kwargsreturnc                   sr   t  j|d||||||||	|
|||||||d t|t}|tjkrLt| _n"|tjkr^t| _nt	d| dd S )N   )r.   
max_epochsdata_loaderr0   r1   r2   r3   r4   
key_metricr6   r7   handlersr9   r;   r<   r=   r>   r?   zunsupported mode: z, should be 'eval' or 'train'.)
super__init__r   r   EVALr   r:   TRAINr   
ValueError)selfr.   r/   r0   r1   r2   r3   r4   r5   r6   r7   r8   r9   r:   r;   r<   r=   r>   r?   	__class__ L/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/engines/evaluator.pyrG   Y   s4    


zEvaluator.__init__rA   int)global_epochr@   c                   s0   t |d| j_|d | j_d| j_t   dS )z
        Execute validation/evaluation based on Ignite Engine.

        Args:
            global_epoch: the overall epoch if during a training. evaluator engine can get it from trainer.

        rA   r   N)maxstaterB   epoch	iterationrF   run)rK   rQ   rL   rN   rO   rV      s    	zEvaluator.runc                 G  sB   t j| jjt j| jjt j| jji}|D ]}t| j|d||< q&|S )a  
        Get the statistics information of the validation process.
        Default to return the `rank`, `best_validation_epoch` and `best_validation_metric`.

        Args:
            vars: except for the default stats, other variables name in the `self.state` to return,
                will use the variable name as the key and the state content as the value.
                if the variable doesn't exist, default value is `None`.

        N)	ESKeysRANKrS   rankBEST_VALIDATION_EPOCHZbest_metric_epochBEST_VALIDATION_METRICbest_metricgetattr)rK   varsstatskrN   rN   rO   	get_stats   s       zEvaluator.get_stats)rA   )__name__
__module____qualname____doc__r   r   r   rH   rG   rV   ra   __classcell__rN   rN   rL   rO   r    +   s(   181c                      s   e Zd ZdZddedddddeddejdddddddfddddd	d
dddddd
dd	dddd	ddd	ddd fddZd dddddZ	  Z
S )r!   a  
    Standard supervised evaluation method with image and label(optional), inherits from evaluator and Workflow.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: Ignite engine use data_loader to run, must be Iterable, typically be torch.DataLoader.
        network: network to evaluate in the evaluator, should be regular PyTorch `torch.nn.Module`.
        epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
            from `engine.state.batch` for every iteration, for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
            if not provided, use `self._iteration()` instead. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
        inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_val_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        metric_cmp_fn: function to compare current key metric with previous best key metric value,
            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.
        mode: model forward mode during evaluation, should be 'eval' or 'train',
            which maps to `model.eval()` or `model.train()`, default to 'eval'.
        event_names: additional custom ignite events that will register to the engine.
            new events can be a list of str or `ignite.engine.events.EventEnum`.
        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
            #ignite.engine.engine.Engine.register_events.
        decollate: whether to decollate the batch-first data to a list of data after model computation,
            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
            default to `True`.
        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
            `device`, `non_blocking`.
        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
        compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
            `torch.Tensor` before forward pass,  then converted back afterward with copied meta information.
        compile_kwargs: dict of the args for `torch.compile()` API, for more details:
            https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.

    NFTtorch.devicer#   ztorch.nn.Moduler$   r%   r   r&   Inferer | Noner'   r(   r)   r*   r+   r,   r-   )r.   r/   networkr0   r1   r2   r3   infererr4   r5   r6   r7   r8   r9   r:   r;   r<   r=   r>   r?   compilecompile_kwargsr@   c                   s   t  j|||||||	|
||||||||||d |rhtddr^|d krJi n|}tj|f|}n
td || _|| _|d krt n|| _	d S )Nr.   r/   r0   r1   r2   r3   r4   r5   r6   r7   r8   r9   r:   r;   r<   r=   r>   r?      rA   zeNetwork compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done)
rF   rG   r   torchrk   warningswarnri   r   rj   )rK   r.   r/   ri   r0   r1   r2   r3   rj   r4   r5   r6   r7   r8   r9   r:   r;   r<   r=   r>   r?   rk   rl   rL   rN   rO   rG      s<    
zSupervisedEvaluator.__init__dict[str, torch.Tensor]dictengine	batchdatar@   c              
   C  s  |dkrt d|j||jj|jf|j}t|dkrH|\}}d}i }n|\}}}}| jrd\}}	}
}t|t	rt
d | |j|j  }}}
t|t	r| |j|j  }}	}tj|tj|i|j_||jn |jrtjjjf |j& |j||jf|||jjtj< W 5 Q R X n |j||jf|||jjtj< W 5 Q R X | jr|dk	rt	|||
d|jjtj< t	|jjtj ||
d|jjtj< |	dk	rt	||	|d|jjtj< |tj |tj |jjS )a  
        callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
        Return below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - LABEL: label Tensor data corresponding to the image, already moved to device.
            - PRED: prediction result of model.

        Args:
            engine: `SupervisedEvaluator` to execute operation for an iteration.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        N.Must provide batch data for current iteration.rn   rN   )NNNNzgWill convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass.)metaapplied_operations) rJ   r2   rS   r.   r1   r>   lenrk   
isinstancer   rp   rq   	as_tensorrx   ry   KeysIMAGELABELoutputr:   ri   r9   ro   cudaautocastr?   rj   PRED
fire_eventr   FORWARD_COMPLETEDMODEL_COMPLETED)rK   ru   rv   batchinputstargetsargskwargsZinputs_metaZtargets_metaZinputs_applied_operationsZtargets_applied_operationsrN   rN   rO   
_iteration  sd    



,*
    
  zSupervisedEvaluator._iterationrb   rc   rd   re   r   r   r   rH   rG   r   rf   rN   rN   rL   rO   r!      s,   8@9c                      s   e Zd ZdZdddedddddeddejdddddfddddd	d
ddddddddd
dddd
dddd fddZd dddddZ	  Z
S )r"   a  
    Ensemble evaluation for multiple models, inherits from evaluator and Workflow.
    It accepts a list of models for inference and outputs a list of predictions for further operations.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: Ignite engine use data_loader to run, must be Iterable, typically be torch.DataLoader.
        epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
        networks: networks to evaluate in order in the evaluator, should be regular PyTorch `torch.nn.Module`.
        pred_keys: the keys to store every prediction data.
            the length must exactly match the number of networks.
            if None, use "pred_{index}" as key corresponding to N networks, index from `0` to `N-1`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
            from `engine.state.batch` for every iteration, for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
            if not provided, use `self._iteration()` instead. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
        inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_val_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        metric_cmp_fn: function to compare current key metric with previous best key metric value,
            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.
        mode: model forward mode during evaluation, should be 'eval' or 'train',
            which maps to `model.eval()` or `model.train()`, default to 'eval'.
        event_names: additional custom ignite events that will register to the engine.
            new events can be a list of str or `ignite.engine.events.EventEnum`.
        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
            #ignite.engine.engine.Engine.register_events.
        decollate: whether to decollate the batch-first data to a list of data after model computation,
            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
            default to `True`.
        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
            `device`, `non_blocking`.
        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.

    NFTrg   r#   zSequence[torch.nn.Module]zKeysCollection | Noner$   r%   r   r&   rh   r'   r(   r)   r*   r+   r,   r-   )r.   r/   networks	pred_keysr0   r1   r2   r3   rj   r4   r5   r6   r7   r8   r9   r:   r;   r<   r=   r>   r?   r@   c                   s   t  j|||||||
|||||||||||d t|| _|d krZdd tt| jD nt|| _t| jt| jkrtd|	d krt n|	| _	d S )Nrm   c                 S  s   g | ]}t j d | qS )_)r}   r   ).0irN   rN   rO   
<listcomp>  s     z.EnsembleEvaluator.__init__.<locals>.<listcomp>z?length of `pred_keys` must be same as the length of `networks`.)
rF   rG   r   r   rangerz   r   rJ   r   rj   )rK   r.   r/   r   r   r0   r1   r2   r3   rj   r4   r5   r6   r7   r8   r9   r:   r;   r<   r=   r>   r?   rL   rN   rO   rG     s4    
&zEnsembleEvaluator.__init__rr   rs   rt   c           
      C  sF  |dkrt d|j||jj|jf|j}t|dkrH|\}}d}i }n|\}}}}tj|tj	|i|j_
t|jD ]\}}	||	 |jrtjjjf |j< t|jj
tr|jj
|j| |j||	f||i W 5 Q R X n8t|jj
tr|jj
|j| |j||	f||i W 5 Q R X qr|tj |tj |jj
S )a#  
        callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
        Return below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - LABEL: label Tensor data corresponding to the image, already moved to device.
            - pred_keys[0]: prediction result of network 0.
            - pred_keys[1]: prediction result of network 1.
            - ... ...
            - pred_keys[N]: prediction result of network N.

        Args:
            engine: `EnsembleEvaluator` to execute operation for an iteration.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        Nrw   rn   rN   )rJ   r2   rS   r.   r1   r>   rz   r}   r~   r   r   	enumerater   r:   r9   ro   r   r   r?   r{   rs   updater   rj   r   r   r   r   )
rK   ru   rv   r   r   r   r   r   idxri   rN   rN   rO   r     s0    zEnsembleEvaluator._iterationr   rN   rN   rL   rO   r"   b  s*   8>5)7
__future__r   rp   typingr   r   r   r   r   ro   torch.utils.datar   monai.configr	   r
   
monai.datar   Zmonai.engines.utilsr   r   r   Zmonai.engines.workflowr   Zmonai.inferersr   r   monai.networks.utilsr   r   monai.transformsr   monai.utilsr   r   r   r   monai.utils.enumsr   r}   r   rW   monai.utils.moduler   r   ignite.enginer   r   Zignite.metricsr   OPT_IMPORT_VERSIONr   __all__r    r!   r"   rN   rN   rN   rO   <module>   s6   
  6