U
    vPh\                     @  s~  d dl mZ d dlZd dlmZmZmZmZmZ d dl	Z	d dl
mZ d dlmZ d dlmZ d dlmZ d dlmZmZmZmZ d d	lmZ d d
lmZmZ d dlmZ d dlmZm Z m!Z! d dl"m#Z$ d dl"m%Z& d dl'm(Z( erd dl)m*Z*m+Z+ d dl,m-Z- n<e!dej.e d\Z*Z/e!dej.e d\Z-Z/e!dej.e d\Z+Z/dddgZ0G dd deZ1G dd de1Z2G dd de1Z3dS )    )annotationsN)TYPE_CHECKINGAnyCallableIterableSequence)	Optimizer)
DataLoader)
IgniteInfo)
MetaTensor)IterationEventsdefault_make_latentdefault_metric_cmp_fndefault_prepare_batch)Workflow)InfererSimpleInferer)	Transform)GanKeysmin_versionoptional_import)
CommonKeys)EngineStatsKeys)pytorch_after)Engine	EventEnum)Metriczignite.enginer   zignite.metricsr   r   TrainerSupervisedTrainer
GanTrainerc                      s.   e Zd ZdZdd fddZdd Z  ZS )r   zH
    Base class for all kinds of trainers, inherits from Workflow.

    None)returnc                   s&   | j rtjj  nd| _t   dS )z
        Execute training based on Ignite Engine.
        If call this function multiple times, it will continuously run from the previous state.

        N)amptorchcuda
GradScalerscalersuperrun)self	__class__ J/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/engines/trainer.pyr(   1   s    zTrainer.runc              
   G  sV   t j| jjt j| jjt j| jjt j| jj	t j
| jji}|D ]}t| j|d||< q:|S )a  
        Get the statistics information of the training process.
        Default to return the `rank`, `current_epoch`, `current_iteration`, `total_epochs`, `total_iterations`.

        Args:
            vars: except for the default stats, other variables name in the `self.state` to return,
                will use the variable name as the key and the state content as the value.
                if the variable doesn't exist, default value is `None`.

        N)ESKeysRANKstaterankCURRENT_EPOCHepochCURRENT_ITERATION	iterationTOTAL_EPOCHS
max_epochsTOTAL_ITERATIONSepoch_lengthgetattr)r)   varsstatskr,   r,   r-   	get_stats:   s         zTrainer.get_stats)__name__
__module____qualname____doc__r(   r>   __classcell__r,   r,   r*   r-   r   +   s   	c                      s   e Zd ZdZddedddddeddddddddddfddddd	d
ddd
dddddd
dddddddddddd fddZd dddddZ  ZS )r   ao  
    Standard supervised training method with image and label, inherits from ``Trainer`` and ``Workflow``.

    Args:
        device: an object representing the device on which to run.
        max_epochs: the total epoch number for trainer to run.
        train_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.
        network: network to train in the trainer, should be regular PyTorch `torch.nn.Module`.
        optimizer: the optimizer associated to the network, should be regular PyTorch optimizer from `torch.optim`
            or its subclass.
        loss_function: the loss function associated to the optimizer, should be regular PyTorch loss,
            which inherit from `torch.nn.modules.loss`.
        epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
            from `engine.state.batch` for every iteration, for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
            if not provided, use `self._iteration()` instead. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
        inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_train_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        metric_cmp_fn: function to compare current key metric with previous best key metric value,
            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
        train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, etc.
        amp: whether to enable auto-mixed-precision training, default is False.
        event_names: additional custom ignite events that will register to the engine.
            new events can be a list of str or `ignite.engine.events.EventEnum`.
        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
            #ignite.engine.engine.Engine.register_events.
        decollate: whether to decollate the batch-first data to a list of data after model computation,
            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
            default to `True`.
        optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
            more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
            `device`, `non_blocking`.
        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
        compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
            `torch.Tensor` before forward pass,  then converted back afterward with copied meta information.
        compile_kwargs: dict of the args for `torch.compile()` API, for more details:
            https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.
    NFTstr | torch.deviceintzIterable | DataLoadertorch.nn.Moduler   r   
int | Nonebool#Callable[[Engine, Any], Any] | NoneInferer | NoneTransform | Nonedict[str, Metric] | NoneSequence | Nonez.list[str | EventEnum | type[EventEnum]] | Nonedict | Noner    )devicer7   train_data_loadernetwork	optimizerloss_functionr9   non_blockingprepare_batchiteration_updateinfererpostprocessingkey_train_metricadditional_metricsmetric_cmp_fntrain_handlersr"   event_namesevent_to_attr	decollateoptim_set_to_none	to_kwargs
amp_kwargscompilecompile_kwargsr!   c                   s   t  j||||||	|
|||||||||||d |rhtddr^|d krJi n|}tj|f|}n
td || _|| _|| _|| _	|d krt
 n|| _|| _d S )N)rO   r7   data_loaderr9   rT   rU   rV   rX   
key_metricrZ   r[   handlersr"   r]   r^   r_   ra   rb         zeNetwork compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done)r'   __init__r   r#   rc   warningswarnrQ   rR   rS   r   rW   r`   )r)   rO   r7   rP   rQ   rR   rS   r9   rT   rU   rV   rW   rX   rY   rZ   r[   r\   r"   r]   r^   r_   r`   ra   rb   rc   rd   r*   r,   r-   rj      sB    
zSupervisedTrainer.__init__zdict[str, torch.Tensor]dictengine	batchdatar!   c           	   	     s  |dkrt dj|jjjfj}t|dkrH|\d i n|\ | jrd\}}}}tt	rt
d  jj  }}tt	r jj  }}tjtjij_ fdd}j  jjjd	 jrnjdk	rntjjjf j |  W 5 Q R X jjjtj    !t"j# j$j j%  n.|  jjtj    !t"j# j$  | jr
|dk	rt	||d
jjtj< t	jjtj& ||d
jjtj&< |dk	r
t	||d
jjtj< !t"j' jjS )a  
        Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
        Return below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - LABEL: label Tensor data corresponding to the image, already moved to device.
            - PRED: prediction result of model.
            - LOSS: loss value computed by loss function.

        Args:
            engine: `SupervisedTrainer` to execute operation for an iteration.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        Nz.Must provide batch data for current iteration.rh   r,   )NNNNzgWill convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass.c                     s`   j jf jjtj< tj 	jjtj 
 jjtj< tj d S )N)rW   rQ   r0   outputKeysPRED
fire_eventr   FORWARD_COMPLETEDrS   meanLOSSLOSS_COMPLETEDr,   argsro   inputskwargstargetsr,   r-   _compute_pred_loss   s     $z8SupervisedTrainer._iteration.<locals>._compute_pred_lossset_to_none)metaapplied_operations)(
ValueErrorrU   r0   rO   rT   ra   lenrc   
isinstancer   rk   rl   	as_tensorr   r   rr   IMAGELABELrq   rQ   trainrR   	zero_gradr`   r"   r&   r#   r$   autocastrb   scalerw   backwardrt   r   BACKWARD_COMPLETEDstepupdaters   MODEL_COMPLETED)	r)   ro   rp   batchinputs_metatargets_metainputs_applied_operationstargets_applied_operationsr~   r,   ry   r-   
_iteration   st    






    
  zSupervisedTrainer._iteration)	r?   r@   rA   rB   r   r   rj   r   rC   r,   r,   r*   r-   r   Q   s,   ?F?c                      s   e Zd ZdZddddddeedddddedddddfddd	d
ddd
dddddddddddddddddddddd fddZd dddddZ  Z	S )r   av  
    Generative adversarial network training based on Goodfellow et al. 2014 https://arxiv.org/abs/1406.266,
    inherits from ``Trainer`` and ``Workflow``.

    Training Loop: for each batch of data size `m`
        1. Generate `m` fakes from random latent codes.
        2. Update discriminator with these fakes and current batch reals, repeated d_train_steps times.
        3. If g_update_latents, generate `m` fakes from new random latent codes.
        4. Update generator with these fakes using discriminator feedback.

    Args:
        device: an object representing the device on which to run.
        max_epochs: the total epoch number for engine to run.
        train_data_loader: Core ignite engines uses `DataLoader` for training loop batchdata.
        g_network: generator (G) network architecture.
        g_optimizer: G optimizer function.
        g_loss_function: G loss function for optimizer.
        d_network: discriminator (D) network architecture.
        d_optimizer: D optimizer function.
        d_loss_function: D loss function for optimizer.
        epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
        g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``.
        d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``.
        d_train_steps: number of times to update D with real data minibatch. Defaults to ``1``.
        latent_shape: size of G input latent code. Defaults to ``64``.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        d_prepare_batch: callback function to prepare batchdata for D inferer.
            Defaults to return ``GanKeys.REALS`` in batchdata dict. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
        g_prepare_batch: callback function to create batch of latent input for G inferer.
            Defaults to return random latents. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
        g_update_latents: Calculate G loss with new latent codes. Defaults to ``True``.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
            if not provided, use `self._iteration()` instead. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_train_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        metric_cmp_fn: function to compare current key metric with previous best key metric value,
            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
        train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, etc.
        decollate: whether to decollate the batch-first data to a list of data after model computation,
            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
            default to `True`.
        optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
            more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
            `device`, `non_blocking`.
        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.

    Nri   @   FTrD   rE   r	   rF   r   r   rG   rJ   rH   rI   rK   rL   rM   rN   )rO   r7   rP   	g_networkg_optimizerg_loss_function	d_networkd_optimizerd_loss_functionr9   	g_inferer	d_infererd_train_stepslatent_shaperT   d_prepare_batchg_prepare_batchg_update_latentsrV   rX   rY   rZ   r[   r\   r_   r`   ra   rb   c                   s   t |tstdt j||||
|||||||||||d || _|| _|| _|d kr\t n|| _	|| _
|| _|	| _|d krt n|| _|| _|| _|| _|| _|| _d S )Nz-train_data_loader must be PyTorch DataLoader.)rO   r7   re   r9   rT   rU   rV   rf   rZ   r[   rg   rX   r_   ra   rb   )r   r	   r   r'   rj   r   r   r   r   r   r   r   r   r   r   r   r   r   r`   )r)   rO   r7   rP   r   r   r   r   r   r   r9   r   r   r   r   rT   r   r   r   rV   rX   rY   rZ   r[   r\   r_   r`   ra   rb   r*   r,   r-   rj   ]  s@    
zGanTrainer.__init__zdict | Sequencez,dict[str, torch.Tensor | int | float | bool]rn   c                 C  sF  |dkrt d|j||jj|jf|j}|jj}|jf ||j	|jj|jd|j}|
||j}td}t|jD ]>}|jj|jd |||}	|	  |j  ||	 7 }qx|jr|jf ||j	|jj|jd|j}|
||j}|jj|jd ||}
|
  |j  tj|tj|tj|tj|
 tj| iS )a  
        Callback function for Adversarial Training processing logic of 1 iteration in Ignite Engine.

        Args:
            engine: `GanTrainer` to execute operation for an iteration.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: must provide batch data for current iteration.

        Nz.must provide batch data for current iteration.)num_latentslatent_sizerO   rT   ri   r   ) r   rU   r0   rO   rT   ra   re   
batch_sizer   r   r   r   r#   zerosranger   r   r   r`   r   r   r   itemr   r   r   r   REALSFAKESLATENTSGLOSSDLOSS)r)   ro   rp   d_inputr   Zg_inputZg_outputZd_total_loss_Zdlossg_lossr,   r,   r-   r     sZ    



     zGanTrainer._iteration)
r?   r@   rA   rB   r   r   r   rj   r   rC   r,   r,   r*   r-   r     s,   HJB)4
__future__r   rk   typingr   r   r   r   r   r#   Ztorch.optim.optimizerr   torch.utils.datar	   monai.configr
   
monai.datar   monai.engines.utilsr   r   r   r   monai.engines.workflowr   monai.inferersr   r   monai.transformsr   monai.utilsr   r   r   monai.utils.enumsr   rr   r   r.   monai.utils.moduler   ignite.enginer   r   ignite.metricsr   OPT_IMPORT_VERSIONr   __all__r   r   r   r,   r,   r,   r-   <module>   s4   
& O