o
    i                     @  s  d dl mZ d dlZd dlmZmZ d dlmZmZm	Z	 d dl
Z
d dlmZ d dlmZ d dlmZ d dlmZmZmZmZ d d	lmZ d d
lmZmZ d dlmZ d dlmZmZm Z m!Z!m"Z"m#Z# d dl$m%Z& d dl$m'Z( erd dl)m*Z*m+Z+ d dl,m-Z- ne#de!j.e"d\Z*Z/e#de!j.e"d\Z-Z/e#de!j.e"d\Z+Z/g dZ0G dd deZ1G dd de1Z2G dd de1Z3G dd de1Z4dS )    )annotationsN)IterableSequence)TYPE_CHECKINGAnyCallable)	Optimizer)
DataLoader)
MetaTensor)IterationEventsdefault_make_latentdefault_metric_cmp_fndefault_prepare_batch)Workflow)InfererSimpleInferer)	Transform)AdversarialIterationEventsAdversarialKeysGanKeys
IgniteInfomin_versionoptional_import)
CommonKeys)EngineStatsKeys)Engine	EventEnum)Metriczignite.enginer   zignite.metricsr   r   )TrainerSupervisedTrainer
GanTrainerAdversarialTrainerc                      s*   e Zd ZdZd fddZdd Z  ZS )	r   zH
    Base class for all kinds of trainers, inherits from Workflow.

    returnNonec                   s&   | j r	tjj  nd| _t   dS )z
        Execute training based on Ignite Engine.
        If call this function multiple times, it will continuously run from the previous state.

        N)amptorchcuda
GradScalerscalersuperrun)self	__class__ W/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/engines/trainer.pyr*   0   s   zTrainer.runc              
   G  sV   t j| jjt j| jjt j| jjt j| jj	t j
| jji}|D ]}t| j|d||< q|S )a  
        Get the statistics information of the training process.
        Default to return the `rank`, `current_epoch`, `current_iteration`, `total_epochs`, `total_iterations`.

        Args:
            vars: except for the default stats, other variables name in the `self.state` to return,
                will use the variable name as the key and the state content as the value.
                if the variable doesn't exist, default value is `None`.

        N)ESKeysRANKstaterankCURRENT_EPOCHepochCURRENT_ITERATION	iterationTOTAL_EPOCHS
max_epochsTOTAL_ITERATIONSepoch_lengthgetattr)r+   varsstatskr.   r.   r/   	get_stats9   s   




zTrainer.get_statsr"   r#   )__name__
__module____qualname____doc__r*   r@   __classcell__r.   r.   r,   r/   r   *   s    	r   c                      sT   e Zd ZdZddedddddeddddddddddfd7 fd/d0Zd8d5d6Z  ZS )9r   ac  
    Standard supervised training method with image and label, inherits from ``Trainer`` and ``Workflow``.

    Args:
        device: an object representing the device on which to run.
        max_epochs: the total epoch number for trainer to run.
        train_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.
        network: network to train in the trainer, should be regular PyTorch `torch.nn.Module`.
        optimizer: the optimizer associated to the network, should be regular PyTorch optimizer from `torch.optim`
            or its subclass.
        loss_function: the loss function associated to the optimizer, should be regular PyTorch loss,
            which inherit from `torch.nn.modules.loss`.
        epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
            from `engine.state.batch` for every iteration, for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
            if not provided, use `self._iteration()` instead. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
        inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_train_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        metric_cmp_fn: function to compare current key metric with previous best key metric value,
            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
        train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, etc.
        amp: whether to enable auto-mixed-precision training, default is False.
        event_names: additional custom ignite events that will register to the engine.
            new events can be a list of str or `ignite.engine.events.EventEnum`.
        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
            #ignite.engine.engine.Engine.register_events.
        decollate: whether to decollate the batch-first data to a list of data after model computation,
            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
            default to `True`.
        optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
            more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
            `device`, `non_blocking`.
        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
            https://pytorch.org/docs/stable/amp.html#torch.autocast.
        compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
            `torch.Tensor` before forward pass,  then converted back afterward with copied meta information.
        compile_kwargs: dict of the args for `torch.compile()` API, for more details:
            https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.
    NFTdevicestr | torch.devicer9   inttrain_data_loaderIterable | DataLoadernetworktorch.nn.Module	optimizerr   loss_functionr   r;   
int | Nonenon_blockingboolprepare_batchiteration_update#Callable[[Engine, Any], Any] | NoneinfererInferer | NonepostprocessingTransform | Nonekey_train_metricdict[str, Metric] | Noneadditional_metricsmetric_cmp_fntrain_handlersSequence | Noner$   event_names.list[str | EventEnum | type[EventEnum]] | Noneevent_to_attrdict | None	decollateoptim_set_to_none	to_kwargs
amp_kwargscompilecompile_kwargsr"   r#   c                   s   t  jdi d|d|d|d|d|d|	d|
d|d	|d
|d|d|d|d|d|d|d|d| |rP|d u rEi n|}tj|fi |}|| _|| _|| _|| _|d u rct n|| _|| _	d S NrG   r9   data_loaderr;   rQ   rS   rT   rX   
key_metricr\   r]   handlersr$   r`   rb   rd   rf   rg   r.   )
r)   __init__r%   rh   rL   rN   rO   r   rV   re   )r+   rG   r9   rJ   rL   rN   rO   r;   rQ   rS   rT   rV   rX   rZ   r\   r]   r^   r$   r`   rb   rd   re   rf   rg   rh   ri   r,   r.   r/   rn      s\   	

zSupervisedTrainer.__init__engine	batchdatadict[str, torch.Tensor]dictc           	        s0  |du rt dj|jjjfi j}t|dkr&|\d i n|\ | jr\d\}}}}tt	rKt
d  jj}}tt	r\ jj}}tjtjij_ fdd}j  jjjd	 jrjdurtjdi j |  W d   n1 sw   Y  jjjtj    t!j" j#j j$  n|  jjtj    t!j" j#  | jr|durt	||djjtj< t	jjtj% ||djjtj%< |durt	||djjtj<  t!j& jjS )a  
        Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
        Return below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - LABEL: label Tensor data corresponding to the image, already moved to device.
            - PRED: prediction result of model.
            - LOSS: loss value computed by loss function.

        Args:
            engine: `SupervisedTrainer` to execute operation for an iteration.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: When ``batchdata`` is None.

        N.Must provide batch data for current iteration.   r.   )NNNNzgWill convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass.c                     sf   j jg R i jjtj< tj 	jjtj 
 jjtj< tj d S N)rV   rL   r2   outputKeysPRED
fire_eventr   FORWARD_COMPLETEDrO   meanLOSSLOSS_COMPLETEDr.   argsro   inputskwargstargetsr.   r/   _compute_pred_loss   s   &$z8SupervisedTrainer._iteration.<locals>._compute_pred_lossset_to_noner&   )metaapplied_operationsr&   )'
ValueErrorrS   r2   rG   rQ   rf   lenrh   
isinstancer
   warningswarn	as_tensorr   r   rw   IMAGELABELrv   rL   trainrN   	zero_gradre   r$   r(   r%   autocastrg   scaler|   backwardry   r   BACKWARD_COMPLETEDstepupdaterx   MODEL_COMPLETED)	r+   ro   rp   batchinputs_metatargets_metainputs_applied_operationstargets_applied_operationsr   r.   r~   r/   
_iteration   sj   






zSupervisedTrainer._iteration)4rG   rH   r9   rI   rJ   rK   rL   rM   rN   r   rO   r   r;   rP   rQ   rR   rS   r   rT   rU   rV   rW   rX   rY   rZ   r[   r\   r[   r]   r   r^   r_   r$   rR   r`   ra   rb   rc   rd   rR   re   rR   rf   rc   rg   rc   rh   rR   ri   rc   r"   r#   )ro   r   rp   rq   r"   rr   )	rB   rC   rD   rE   r   r   rn   r   rF   r.   r.   r,   r/   r   P   s.    ?:r   c                      sT   e Zd ZdZddddddeedddddedddddfd: fd1d2Zd;d8d9Z  Z	S )<r    aj  
    Generative adversarial network training based on Goodfellow et al. 2014 https://arxiv.org/abs/1406.266,
    inherits from ``Trainer`` and ``Workflow``.

    Training Loop: for each batch of data size `m`
        1. Generate `m` fakes from random latent codes.
        2. Update discriminator with these fakes and current batch reals, repeated d_train_steps times.
        3. If g_update_latents, generate `m` fakes from new random latent codes.
        4. Update generator with these fakes using discriminator feedback.

    Args:
        device: an object representing the device on which to run.
        max_epochs: the total epoch number for engine to run.
        train_data_loader: Core ignite engines uses `DataLoader` for training loop batchdata.
        g_network: generator (G) network architecture.
        g_optimizer: G optimizer function.
        g_loss_function: G loss function for optimizer.
        d_network: discriminator (D) network architecture.
        d_optimizer: D optimizer function.
        d_loss_function: D loss function for optimizer.
        epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
        g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``.
        d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``.
        d_train_steps: number of times to update D with real data minibatch. Defaults to ``1``.
        latent_shape: size of G input latent code. Defaults to ``64``.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        d_prepare_batch: callback function to prepare batchdata for D inferer.
            Defaults to return ``GanKeys.REALS`` in batchdata dict. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
        g_prepare_batch: callback function to create batch of latent input for G inferer.
            Defaults to return random latents. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
        g_update_latents: Calculate G loss with new latent codes. Defaults to ``True``.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
            if not provided, use `self._iteration()` instead. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_train_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        metric_cmp_fn: function to compare current key metric with previous best key metric value,
            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
        train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, etc.
        decollate: whether to decollate the batch-first data to a list of data after model computation,
            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
            default to `True`.
        optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
            more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
            `device`, `non_blocking`.
        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
            https://pytorch.org/docs/stable/amp.html#torch.autocast.

    N   @   FTrG   rH   r9   rI   rJ   r	   	g_networkrM   g_optimizerr   g_loss_functionr   	d_networkd_optimizerd_loss_functionr;   rP   	g_infererrW   	d_infererd_train_stepslatent_shaperQ   rR   d_prepare_batchg_prepare_batchg_update_latentsrT   rU   rX   rY   rZ   r[   r\   r]   r^   r_   rd   re   rf   rc   rg   c                   s   t |ts	tdt j||||
|||||||||||d || _|| _|| _|d u r.t n|| _	|| _
|| _|	| _|d u rAt n|| _|| _|| _|| _|| _|| _d S )Nz-train_data_loader must be PyTorch DataLoader.)rG   r9   rk   r;   rQ   rS   rT   rl   r\   r]   rm   rX   rd   rf   rg   )r   r	   r   r)   rn   r   r   r   r   r   r   r   r   r   r   r   r   r   re   )r+   rG   r9   rJ   r   r   r   r   r   r   r;   r   r   r   r   rQ   r   r   r   rT   rX   rZ   r\   r]   r^   rd   re   rf   rg   r,   r.   r/   rn   W  s@   

zGanTrainer.__init__ro   rp   dict | Sequencer"   ,dict[str, torch.Tensor | int | float | bool]c                 C  sJ  |du rt d|j||jj|jfi |j}|jj}|jd||j	|jj|jd|j}|
||j}td}t|jD ]}|jj|jd |||}	|	  |j  ||	 7 }q>|jrs|jd||j	|jj|jd|j}|
||j}|jj|jd ||}
|
  |j  tj|tj|tj|tj|
 tj| iS )a  
        Callback function for Adversarial Training processing logic of 1 iteration in Ignite Engine.

        Args:
            engine: `GanTrainer` to execute operation for an iteration.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: must provide batch data for current iteration.

        Nz.must provide batch data for current iteration.)num_latentslatent_sizerG   rQ   r   r   r.   ) r   rS   r2   rG   rQ   rf   rk   
batch_sizer   r   r   r   r%   zerosranger   r   r   re   r   r   r   itemr   r   r   r   REALSFAKESLATENTSGLOSSDLOSS)r+   ro   rp   d_inputr   Zg_inputZg_outputZd_total_loss_Zdlossg_lossr.   r.   r/   r     sP   





zGanTrainer._iteration)8rG   rH   r9   rI   rJ   r	   r   rM   r   r   r   r   r   rM   r   r   r   r   r;   rP   r   rW   r   rW   r   rI   r   rI   rQ   rR   r   r   r   r   r   rR   rT   rU   rX   rY   rZ   r[   r\   r[   r]   r   r^   r_   rd   rR   re   rR   rf   rc   rg   rc   )ro   r    rp   r   r"   r   )
rB   rC   rD   rE   r   r   r   rn   r   rF   r.   r.   r,   r/   r      s.    HBr    c                      s\   e Zd ZdZddeddddddeddddddddfd< fd0d1Zd=d4d5Zd>d:d;Z  Z	S )?r!   a  
    Standard supervised training workflow for adversarial loss enabled neural networks.

    Args:
        device: an object representing the device on which to run.
        max_epochs: the total epoch number for engine to run.
        train_data_loader: Core ignite engines uses `DataLoader` for training loop batchdata.
        g_network: ''generator'' (G) network architecture.
        g_optimizer: G optimizer function.
        g_loss_function: G loss function for adversarial training.
        recon_loss_function: G loss function for reconstructions.
        d_network: discriminator (D) network architecture.
        d_optimizer: D optimizer function.
        d_loss_function: D loss function for adversarial training..
        epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to
            the host. For other cases, this argument has no effect.
        prepare_batch: function to parse image and label for current iteration.
        iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input
            parameters. if not provided, use `self._iteration()` instead.
        g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``.
        d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``.
        postprocessing: execute additional transformation for the model output data. Typically, several Tensor based
            transforms composed by `Compose`. Defaults to None
        key_train_metric: compute metric when every iteration completed, and save average value to engine.state.metrics
            when epoch completed. key_train_metric is the main metric to compare and save the checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        metric_cmp_fn: function to compare current key metric with previous best key metric value, it must accept 2 args
            (current_metric, previous_best) and return a bool result: if `True`, will update 'best_metric` and
            `best_metric_epoch` with current metric and epoch, default to `greater than`.
        train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, etc.
        amp: whether to enable auto-mixed-precision training, default is False.
        event_names: additional custom ignite events that will register to the engine.
            new events can be a list of str or `ignite.engine.events.EventEnum`.
        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
            #ignite.engine.engine.Engine.register_events.
        decollate: whether to decollate the batch-first data to a list of data after model computation, recommend
            `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`.
        optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
            more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
            `device`, `non_blocking`.
        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
            https://pytorch.org/docs/stable/amp.html#torch.autocast.
    NFTrG   torch.device | strr9   rI   rJ   rK   r   rM   r   r   r   r   recon_loss_functionr   r   r   r;   rP   rQ   rR   rS   rT   Callable | Noner   rW   r   rX   rY   rZ   r[   r\   r]   r^   r_   r$   r`   ra   rb   rc   rd   re   rf   rg   c                   s*  t  jdi d|d|d|d|d|d|d|d|d	|d
|d|d|d|d|d|d|d|d| | jt  || j_|| j_|| j_|| j_|| j_	|	| j_
|
| j_|d u ret n|| _|d u rot n|| _| jr{tjj nd | j_| jrtjj nd | j_|| _|   d S rj   )r)   rn   register_eventsr   r2   r   r   r   r   r   r   r   r   r   r   r$   r%   r&   r'   g_scalerd_scalerre   _complete_state_dict_user_keys)r+   rG   r9   rJ   r   r   r   r   r   r   r   r;   rQ   rS   rT   r   r   rX   rZ   r\   r]   r^   r$   r`   rb   rd   re   rf   rg   r,   r.   r/   rn     sf   	

zAdversarialTrainer.__init__r"   r#   c                 C  s   | j g d t| jjdd}t|r| j d t| jjdd}t|r,| j d t| jjdd}t|r@| j d dS dS )a1  
        This method appends to the _state_dict_user_keys AdversarialTrainer's elements that are required for
        checkpoint saving.

        Follows the example found at:
            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine.state_dict
        )r   r   r   r   r   r   
state_dictNr   r   r   )	Z_state_dict_user_keysextendr<   r2   r   callableappendr   r   )r+   Zg_loss_state_dictZd_loss_state_dictZrecon_loss_state_dictr.   r.   r/   r   O  s   z1AdversarialTrainer._complete_state_dict_user_keysro   rp   rq   r   c                   s  |du rt dj|jjjfi j}t|dkr&|\d i n|\ tjtj	t
jij_d fdd}jj  jjjjd	 jrjjdurtjdi j |  W d   n1 ssw   Y  jjt
j jjt
j  jjtj< jjjjtj   tj jjjj jj   n|  jjt
j jjt
j    tj jj  tj! d fdd}jj"  jj"jjd	 jr;jj#dur;tjdi j |  W d   n	1 sw   Y  jj#jjt
j$   tj% jj#jj& jj#   jjS |  jjt
j$   jj&  jjS )a  
        Callback function for the Adversarial Training processing logic of 1 iteration in Ignite Engine.
        Return below items in a dictionary:
            - IMAGE: image Tensor data for model input, already moved to device.
            - LABEL: label Tensor data corresponding to the image, already moved to device. In case of Unsupervised
                Learning this is equal to IMAGE.
            - PRED: prediction result of model.
            - LOSS: loss value computed by loss functions of the generator (reconstruction and adversarial summed up).
            - AdversarialKeys.REALS: real images from the batch. Are the same as IMAGE.
            - AdversarialKeys.FAKES: fake images generated by the generator. Are the same as PRED.
            - AdversarialKeys.REAL_LOGITS: logits of the discriminator for the real images.
            - AdversarialKeys.FAKE_LOGITS: logits of the discriminator for the fake images.
            - AdversarialKeys.RECONSTRUCTION_LOSS: loss value computed by the reconstruction loss function.
            - AdversarialKeys.GENERATOR_LOSS: loss value computed by the generator loss function. It is the
                discriminator loss for the fake images. That is backpropagated through the generator only.
            - AdversarialKeys.DISCRIMINATOR_LOSS: loss value computed by the discriminator loss function. It is the
                discriminator loss for the real images and the fake images. That is backpropagated through the
                discriminator only.

        Args:
            engine: `AdversarialTrainer` to execute operation for an iteration.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            ValueError: must provide batch data for current iteration.

        Nrs   rt   r.   r"   r#   c                     s   j jjg R i jjtj< jjtj jjtj< t	j
 jjjtj   jjg R i jjtj< t	j jjjtj  jjtj< t	j jjjtj  jjtj< t	j d S ru   )r   r2   r   rv   r   r   rw   rx   ry   r   GENERATOR_FORWARD_COMPLETEDr   float
contiguousr   FAKE_LOGITS)GENERATOR_DISCRIMINATOR_FORWARD_COMPLETEDr   r{   RECONSTRUCTION_LOSSRECONSTRUCTION_LOSS_COMPLETEDr   GENERATOR_LOSSGENERATOR_LOSS_COMPLETEDr.   r~   r.   r/   _compute_generator_loss  s:   z>AdversarialTrainer._iteration.<locals>._compute_generator_lossr   r&   c                     s   j jjtj   jjg R i jjtj< 	t
j j jjtj   jjg R i jjtj< 	t
j jjjtj jjtj  jjtj< 	t
j d S ru   )r   r2   rv   r   r   r   detachr   REAL_LOGITSry   r   %DISCRIMINATOR_REALS_FORWARD_COMPLETEDr   r   %DISCRIMINATOR_FAKES_FORWARD_COMPLETEDr   r{   DISCRIMINATOR_LOSSDISCRIMINATOR_LOSS_COMPLETEDr.   )r   ro   r   r.   r/   _compute_discriminator_loss  s0   zBAdversarialTrainer._iteration.<locals>._compute_discriminator_lossrA   r   )'r   rS   r2   rG   rQ   rf   r   rw   r   r   r   r   rv   r   r   r   r   re   r$   r   r%   r   rg   r   r   r|   r   r   ry   r   GENERATOR_BACKWARD_COMPLETEDr   r   GENERATOR_MODEL_COMPLETEDr   r   r    DISCRIMINATOR_BACKWARD_COMPLETEDr   )r+   ro   rp   r   r   r   r.   r~   r/   r   g  s`   zAdversarialTrainer._iteration)8rG   r   r9   rI   rJ   rK   r   rM   r   r   r   r   r   r   r   rM   r   r   r   r   r;   rP   rQ   rR   rS   r   rT   r   r   rW   r   rW   rX   rY   rZ   r[   r\   r[   r]   r   r^   r_   r$   rR   r`   ra   rb   rc   rd   rR   re   rR   rf   rc   rg   rc   rA   )ro   r!   rp   rq   r"   r   )
rB   rC   rD   rE   r   r   rn   r   r   rF   r.   r.   r,   r/   r!     s.    <
Hr!   )5
__future__r   r   collections.abcr   r   typingr   r   r   r%   Ztorch.optim.optimizerr   torch.utils.datar	   
monai.datar
   monai.engines.utilsr   r   r   r   monai.engines.workflowr   monai.inferersr   r   monai.transformsr   monai.utilsr   r   r   r   r   r   monai.utils.enumsr   rw   r   r0   ignite.enginer   r   ignite.metricsr   OPT_IMPORT_VERSIONr   __all__r   r   r    r!   r.   r.   r.   r/   <module>   s8    & J >