o
    i1h                     @  s  d dl mZ d dlZd dlmZmZ d dlmZmZm	Z	 d dl
Z
d dlmZ d dlmZ d dlmZ d dlmZmZmZ d d	lmZ d d
lmZmZ d dlmZmZ d dlmZ d dlm Z m!Z!m"Z"m#Z#m$Z$ d dl%m&Z' d dl%m(Z) d dl*m+Z+ erd dl,m-Z-m.Z. d dl/m0Z0 ne$de!j1e#d\Z-Z2e$de!j1e#d\Z0Z2e$de!j1e#d\Z.Z2g dZ3G dd deZ4G dd de4Z5G dd de4Z6dS )    )annotationsN)IterableSequence)TYPE_CHECKINGAnyCallable)
DataLoader)KeysCollection)
MetaTensor)IterationEventsdefault_metric_cmp_fndefault_prepare_batch)Workflow)InfererSimpleInferer)	eval_mode
train_mode)	Transform)ForwardMode
IgniteInfoensure_tuplemin_versionoptional_import)
CommonKeys)EngineStatsKeys)look_up_option)Engine	EventEnum)Metriczignite.enginer   zignite.metricsr   r   )	EvaluatorSupervisedEvaluatorEnsembleEvaluatorc                      s^   e Zd ZdZddeddddeddejdddddfd. fd%d&Zd/d0 fd*d+Z	d,d- Z
  ZS )1r   a9  
    Base class for all kinds of evaluators, inherits from Workflow.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.
        epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
            from `engine.state.batch` for every iteration, for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
            if not provided, use `self._iteration()` instead. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_val_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        metric_cmp_fn: function to compare current key metric with previous best key metric value,
            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.
        mode: model forward mode during evaluation, should be 'eval' or 'train',
            which maps to `model.eval()` or `model.train()`, default to 'eval'.
        event_names: additional custom ignite events that will register to the engine.
            new events can be a list of str or `ignite.engine.events.EventEnum`.
        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
            #ignite.engine.engine.Engine.register_events.
        decollate: whether to decollate the batch-first data to a list of data after model computation,
            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
            default to `True`.
        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
            `device`, `non_blocking`.
        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
            https://pytorch.org/docs/stable/amp.html#torch.autocast.

    NFTdevicetorch.device | strval_data_loaderIterable | DataLoaderepoch_length
int | Nonenon_blockingboolprepare_batchr   iteration_update#Callable[[Engine, Any], Any] | NonepostprocessingTransform | Nonekey_val_metricdict[str, Metric] | Noneadditional_metricsmetric_cmp_fnval_handlersSequence | NoneampmodeForwardMode | strevent_names.list[str | EventEnum | type[EventEnum]] | Noneevent_to_attrdict | None	decollate	to_kwargs
amp_kwargsreturnNonec                   s   t  jdi d|ddd|d|d|d|d|d	|d
|d|	d|
d|d|d|d|d|d|d| t|t}|tjkrLt| _d S |tjkrVt| _d S t	d| d)Nr"   
max_epochs   data_loaderr&   r(   r*   r+   r-   
key_metricr1   r2   handlersr5   r8   r:   r<   r=   r>   zunsupported mode: z, should be 'eval' or 'train'. )
super__init__r   r   EVALr   r6   TRAINr   
ValueError)selfr"   r$   r&   r(   r*   r+   r-   r/   r1   r2   r3   r5   r6   r8   r:   r<   r=   r>   	__class__rF   Y/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/engines/evaluator.pyrH   Z   sV   	





zEvaluator.__init__rB   global_epochintc                   s0   t |d| j_|d | j_d| j_t   dS )z
        Execute validation/evaluation based on Ignite Engine.

        Args:
            global_epoch: the overall epoch if during a training. evaluator engine can get it from trainer.

        rB   r   N)maxstaterA   epoch	iterationrG   run)rL   rP   rM   rF   rO   rV      s   	zEvaluator.runc                 G  sB   t j| jjt j| jjt j| jji}|D ]}t| j|d||< q|S )a  
        Get the statistics information of the validation process.
        Default to return the `rank`, `best_validation_epoch` and `best_validation_metric`.

        Args:
            vars: except for the default stats, other variables name in the `self.state` to return,
                will use the variable name as the key and the state content as the value.
                if the variable doesn't exist, default value is `None`.

        N)	ESKeysRANKrS   rankBEST_VALIDATION_EPOCHZbest_metric_epochBEST_VALIDATION_METRICbest_metricgetattr)rL   varsstatskrF   rF   rO   	get_stats   s   


zEvaluator.get_stats)&r"   r#   r$   r%   r&   r'   r(   r)   r*   r   r+   r,   r-   r.   r/   r0   r1   r0   r2   r   r3   r4   r5   r)   r6   r7   r8   r9   r:   r;   r<   r)   r=   r;   r>   r;   r?   r@   )rB   )rP   rQ   r?   r@   )__name__
__module____qualname____doc__r   r   r   rI   rH   rV   ra   __classcell__rF   rF   rM   rO   r   ,   s*    11r   c                      sV   e Zd ZdZddedddddeddejdddddddfd3 fd+d,Zd4d1d2Z	  Z
S )5r    a  
    Standard supervised evaluation method with image and label(optional), inherits from evaluator and Workflow.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: Ignite engine use data_loader to run, must be Iterable, typically be torch.DataLoader.
        network: network to evaluate in the evaluator, should be regular PyTorch `torch.nn.Module`.
        epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
            from `engine.state.batch` for every iteration, for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
            if not provided, use `self._iteration()` instead. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
        inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_val_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        metric_cmp_fn: function to compare current key metric with previous best key metric value,
            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.
        mode: model forward mode during evaluation, should be 'eval' or 'train',
            which maps to `model.eval()` or `model.train()`, default to 'eval'.
        event_names: additional custom ignite events that will register to the engine.
            new events can be a list of str or `ignite.engine.events.EventEnum`.
        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
            #ignite.engine.engine.Engine.register_events.
        decollate: whether to decollate the batch-first data to a list of data after model computation,
            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
            default to `True`.
        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
            `device`, `non_blocking`.
        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
            https://pytorch.org/docs/stable/amp.html#torch.autocast.
        compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
            `torch.Tensor` before forward pass,  then converted back afterward with copied meta information.
        compile_kwargs: dict of the args for `torch.compile()` API, for more details:
            https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.

    NFTr"   torch.devicer$   r%   networktorch.nn.Moduler&   r'   r(   r)   r*   r   r+   r,   infererInferer | Noner-   r.   r/   r0   r1   r2   r3   r4   r5   r6   r7   r8   r9   r:   r;   r<   r=   r>   compilecompile_kwargsr?   r@   c                   s   t  jdi d|d|d|d|d|d|d|	d|
d	|d
|d|d|d|d|d|d|d|d| |rP|d u rEi n|}tj|fi |}|| _|| _|d u r`t | _d S || _d S )Nr"   r$   r&   r(   r*   r+   r-   r/   r1   r2   r3   r5   r6   r8   r:   r<   r=   r>   rF   )rG   rH   torchrl   rh   r   rj   )rL   r"   r$   rh   r&   r(   r*   r+   rj   r-   r/   r1   r2   r3   r5   r6   r8   r:   r<   r=   r>   rl   rm   rM   rF   rO   rH      sV   	
zSupervisedEvaluator.__init__engine	batchdatadict[str, torch.Tensor]dictc              	   C  s  |du rt d|j||jj|jfi |j}t|dkr&|\}}d}i }n|\}}}}| jr\d\}}	}
}t|t	rKt
d | |j|j}}}
t|t	r\| |j|j}}	}tj|tj|i|j_||jK |jrtjd	i |j |j||jg|R i ||jjtj< W d   n1 sw   Y  n|j||jg|R i ||jjtj< W d   n1 sw   Y  | jr|durt	|||
d|jjtj< t	|jjtj ||
d|jjtj< |	durt	||	|d|jjtj< |tj |tj |jjS )
a  
        callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
        Return below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - LABEL: label Tensor data corresponding to the image, already moved to device.
            - PRED: prediction result of model.

        Args:
            engine: `SupervisedEvaluator` to execute operation for an iteration.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        N.Must provide batch data for current iteration.   rF   )NNNNzgWill convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass.cuda)metaapplied_operationsru   )rK   r*   rS   r"   r(   r=   lenrl   
isinstancer
   warningswarn	as_tensorrv   rw   KeysIMAGELABELoutputr6   rh   r5   rn   autocastr>   rj   PRED
fire_eventr   FORWARD_COMPLETEDMODEL_COMPLETED)rL   ro   rp   batchinputstargetsargskwargsZinputs_metaZtargets_metaZinputs_applied_operationsZtargets_applied_operationsrF   rF   rO   
_iteration  s`   



(&zSupervisedEvaluator._iteration).r"   rg   r$   r%   rh   ri   r&   r'   r(   r)   r*   r   r+   r,   rj   rk   r-   r.   r/   r0   r1   r0   r2   r   r3   r4   r5   r)   r6   r7   r8   r9   r:   r;   r<   r)   r=   r;   r>   r;   rl   r)   rm   r;   r?   r@   )ro   r    rp   rq   r?   rr   rb   rc   rd   re   r   r   r   rI   rH   r   rf   rF   rF   rM   rO   r       s.    84r    c                      sT   e Zd ZdZdddedddddeddejdddddfd3 fd+d,Zd4d1d2Z	  Z
S )5r!   as  
    Ensemble evaluation for multiple models, inherits from evaluator and Workflow.
    It accepts a list of models for inference and outputs a list of predictions for further operations.

    Args:
        device: an object representing the device on which to run.
        val_data_loader: Ignite engine use data_loader to run, must be Iterable, typically be torch.DataLoader.
        epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
        networks: networks to evaluate in order in the evaluator, should be regular PyTorch `torch.nn.Module`.
        pred_keys: the keys to store every prediction data.
            the length must exactly match the number of networks.
            if None, use "pred_{index}" as key corresponding to N networks, index from `0` to `N-1`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
            from `engine.state.batch` for every iteration, for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
            if not provided, use `self._iteration()` instead. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
        inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_val_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        metric_cmp_fn: function to compare current key metric with previous best key metric value,
            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, etc.
        amp: whether to enable auto-mixed-precision evaluation, default is False.
        mode: model forward mode during evaluation, should be 'eval' or 'train',
            which maps to `model.eval()` or `model.train()`, default to 'eval'.
        event_names: additional custom ignite events that will register to the engine.
            new events can be a list of str or `ignite.engine.events.EventEnum`.
        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
            #ignite.engine.engine.Engine.register_events.
        decollate: whether to decollate the batch-first data to a list of data after model computation,
            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
            default to `True`.
        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
            `device`, `non_blocking`.
        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
            https://pytorch.org/docs/stable/amp.html#torch.autocast.

    NFTr"   rg   r$   r%   networksSequence[torch.nn.Module]	pred_keysKeysCollection | Noner&   r'   r(   r)   r*   r   r+   r,   rj   rk   r-   r.   r/   r0   r1   r2   r3   r4   r5   r6   r7   r8   r9   r:   r;   r<   r=   r>   r?   r@   c                   s   t  jdi d|d|d|d|d|d|d|
d|d	|d
|d|d|d|d|d|d|d|d| t|| _|d u rRdd tt| jD nt|| _t| jt| jkretd|	d u rot | _	d S |	| _	d S )Nr"   r$   r&   r(   r*   r+   r-   r/   r1   r2   r3   r5   r6   r8   r:   r<   r=   r>   c                 S  s   g | ]
}t j d | qS )_)r~   r   ).0irF   rF   rO   
<listcomp>  s    z.EnsembleEvaluator.__init__.<locals>.<listcomp>z?length of `pred_keys` must be same as the length of `networks`.rF   )
rG   rH   r   r   rangery   r   rK   r   rj   )rL   r"   r$   r   r   r&   r(   r*   r+   rj   r-   r/   r1   r2   r3   r5   r6   r8   r:   r<   r=   r>   rM   rF   rO   rH     sV   	

&zEnsembleEvaluator.__init__ro   rp   rq   rr   c           
   
   C  s|  |du rt d|j||jj|jfi |j}t|dkr&|\}}d}i }n|\}}}}tj|tj	|i|j_
t|jD ]r\}}	||	a |jrtjdi |j& t|jj
trp|jj
|j| |j||	g|R i |i W d   n1 szw   Y  nt|jj
tr|jj
|j| |j||	g|R i |i W d   n1 sw   Y  q;|tj |tj |jj
S )a#  
        callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
        Return below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - LABEL: label Tensor data corresponding to the image, already moved to device.
            - pred_keys[0]: prediction result of network 0.
            - pred_keys[1]: prediction result of network 1.
            - ... ...
            - pred_keys[N]: prediction result of network N.

        Args:
            engine: `EnsembleEvaluator` to execute operation for an iteration.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        Nrs   rt   rF   ru   rx   )rK   r*   rS   r"   r(   r=   ry   r~   r   r   r   	enumerater   r6   r5   rn   r   r>   rz   rr   updater   rj   r   r   r   r   )
rL   ro   rp   r   r   r   r   r   idxrh   rF   rF   rO   r     s<   ""zEnsembleEvaluator._iteration),r"   rg   r$   r%   r   r   r   r   r&   r'   r(   r)   r*   r   r+   r,   rj   rk   r-   r.   r/   r0   r1   r0   r2   r   r3   r4   r5   r)   r6   r7   r8   r9   r:   r;   r<   r)   r=   r;   r>   r;   r?   r@   )ro   r!   rp   rq   r?   rr   r   rF   rF   rM   rO   r!   ^  s,    85r!   )7
__future__r   r{   collections.abcr   r   typingr   r   r   rn   torch.utils.datar   monai.configr	   
monai.datar
   Zmonai.engines.utilsr   r   r   Zmonai.engines.workflowr   Zmonai.inferersr   r   monai.networks.utilsr   r   monai.transformsr   monai.utilsr   r   r   r   r   monai.utils.enumsr   r~   r   rW   monai.utils.moduler   ignite.enginer   r   Zignite.metricsr   OPT_IMPORT_VERSIONr   __all__r   r    r!   rF   rF   rF   rO   <module>   s:     1