o
    i{<                     @  sV  d dl mZ d dlZd dlmZmZmZmZ d dlm	Z	m
Z
 d dlZd dlmZ d dlmZ d dlmZ d dlmZmZmZ d dlmZ d d	lmZmZmZmZmZ d
dlm Z  edej!ed\Z"Z#edej!ed\Z$Z#e	r}d dl%m&Z&m'Z' d dl(m)Z) n$edej!eddd\Z&Z#edej!eddd\Z)Z#edej!eddd\Z'Z#G dd de&Z*dS )    )annotationsN)CallableIterableSequenceSized)TYPE_CHECKINGAny)
DataLoader)DistributedSampler)IterationEventsdefault_metric_cmp_fndefault_prepare_batch)
Decollated)
IgniteInfoensure_tuple	is_scalarmin_versionoptional_import   )engine_apply_transformzignite.engineStateEvents)Engine	EventEnum)Metricr   	decorator)as_typezignite.metricsr   r   c                      s   e Zd ZdZddeddddedddddddfd> fd%d&Zd'd( Zd?d*d+Zd@dAd/d0Z	dBd2d3Z
dC fd4d5ZdDd:d;Zd<d= Z  ZS )EWorkflowa  
    Workflow defines the core work process inheriting from Ignite engine.
    All trainer, validator and evaluator share this same workflow as base class,
    because they all can be treated as same Ignite engine loops.
    It initializes all the sharable data in Ignite engine.state.
    And attach additional processing logics to Ignite engine based on Event-Handler mechanism.

    Users should consider inheriting from `trainer` or `evaluator` to develop more trainers or evaluators.

    Args:
        device: an object representing the device on which to run.
        max_epochs: the total epoch number for engine to run, validator and evaluator have only 1 epoch.
        data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.
        epoch_length: number of iterations for one epoch, default to `len(data_loader)`.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
            from `engine.state.batch` for every iteration, for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
        iteration_update: the callable function for every iteration, expect to accept `engine`
            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
            if not provided, use `self._iteration()` instead. for more details please refer to:
            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
        postprocessing: execute additional transformation for the model output data.
            Typically, several Tensor based transforms composed by `Compose`.
        key_metric: compute metric when every iteration completed, and save average value to
            engine.state.metrics when epoch completed. key_metric is the main metric to compare and save the
            checkpoint into files.
        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
        metric_cmp_fn: function to compare current key metric with previous best key metric value,
            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
        handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
            CheckpointHandler, StatsHandler, etc.
        amp: whether to enable auto-mixed-precision training or inference, default is False.
        event_names: additional custom ignite events that will register to the engine.
            new events can be a list of str or `ignite.engine.events.EventEnum`.
        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
            #ignite.engine.engine.Engine.register_events.
        decollate: whether to decollate the batch-first data to a list of data after model computation,
            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
            default to `True`.
        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
            `device`, `non_blocking`.
        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
            https://pytorch.org/docs/stable/amp.html#torch.autocast.

    Raises:
        TypeError: When ``data_loader`` is not a ``torch.utils.data.DataLoader``.
        TypeError: When ``key_metric`` is not a ``Optional[dict]``.
        TypeError: When ``additional_metrics`` is not a ``Optional[dict]``.

    NFTdevicetorch.device | str
max_epochsintdata_loaderIterable | DataLoaderepoch_length
int | Nonenon_blockingboolprepare_batchr   iteration_update#Callable[[Engine, Any], Any] | NonepostprocessingCallable | None
key_metricdict[str, Metric] | Noneadditional_metricsmetric_cmp_fnhandlersSequence | Noneampevent_names.list[str | EventEnum | type[EventEnum]] | Noneevent_to_attrdict | None	decollate	to_kwargs
amp_kwargsreturnNonec                   s  t  |d u r
| jn| t|tr*t|dd  t tr*| tj	d fdd}|d u rCt|t
rCzt|}W n	 tyB   Y nw tt rPt rPt ndddd||d d i i d t|tjse|d u rg|nt|d d	d	d
| _|| _|| _|| _|| _|| _|d u ri n|| _|d u ri n|| _d | _|d u rtg}nt|tstd|tg7 }|D ]"}t|t t!fr| j"||d qt#|t!r| j"|d|i qtd|r| $  |d ur| %| |	d ur| &|	|
 |d ur| '| d S d S )Nsamplerenginer   r;   r<   c                   s     | jj d S N)	set_epochstateepoch)r>   r=    X/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/engines/workflow.pyset_sampler_epoch   s   z,Workflow.__init__.<locals>.set_sampler_epochr   )rankseed	iterationrB   r    r$   outputbatchmetricsZmetric_details
dataloaderr   key_metric_namebest_metricbest_metric_epochz6`event_names` must be a list of strings or EventEnums.)r6   r6   r>   r   r;   r<   )(super__init__
_iteration
isinstancer	   getattrr
   onr   ZEPOCH_STARTEDr   len	TypeErrorr   distis_availableis_initializedget_ranktorchr   rA   r"   r&   r(   r0   r3   r9   r:   scalerr   list
ValueErrorstrr   Zregister_events
issubclass_register_decollate_register_postprocessing_register_metrics_register_handlers)selfr   r    r"   r$   r&   r(   r)   r+   r-   r/   r0   r1   r3   r4   r6   r8   r9   r:   rF   name	__class__rC   rE   rT   g   sr   


 



zWorkflow.__init__c                 C  s   |  tjddd}dS )	zv
        Register the decollate operation for batch data, will execute after model forward and loss forward.

        r>   r   r;   r<   c                 S  sX   t d dd}t| jjttfr|| jj| j_t| jjttfr*|| jj| j_d S d S )NT)keysdetach)r   rV   rA   rL   ra   dictrK   )r>   	transformrD   rD   rE   _decollate_data   s   z5Workflow._register_decollate.<locals>._decollate_dataNrR   rX   r   MODEL_COMPLETED)ri   rq   rD   rD   rE   re      s   
zWorkflow._register_decollate	posttransc                   s   |  tjd fdd}dS )	zz
        Register the postprocessing logic to the engine, will execute them as a chain when iteration completed.

        r>   r   r;   r<   c                   s   t | jjtrt | jjts!t| jj| jj d\| j_| j_d S tt| jj| jjD ]\}\}}t|| \| jj|< | jj|< q,d S )N)rL   rK   rp   )rV   rA   rL   ra   rK   r   	enumeratezip)r>   ibort   rD   rE   _run_postprocessing   s   ""z>Workflow._register_postprocessing.<locals>._run_postprocessingNrR   rr   )ri   rt   r{   rD   rz   rE   rf      s   
z!Workflow._register_postprocessingk_metricro   add_metricsc                   s   t |tstdt|j dt| d  j_t|}|dur=t	|dkr=t |ts8tdt|j d|
| | D ]
\}}| | qA tjd fd
d}dS )zi
        Register the key metric and additional metrics to the engine, supports ignite Metrics.

        z+`key_metric` must be None or a dict but is .r   Nz1Additional metrics must be None or a dict but is r>   r   r;   r<   c                   s   | j j}|d ur?| j j| }t|std d S | j jdks' || j jrA j	
d| d|  || j _| j j| j _d S d S d S )NzKey metric is not a scalar value, skip the metric comparison with the current best metric.Please set other metrics as the key metric, or change the `reduction` mode to 'mean'.rG   zGot new best metric of z: )rA   rO   rM   r   warningswarnrQ   r0   rP   loggerinforB   )r>   rO   Zcurrent_val_metricri   rD   rE   _compare_metrics   s    	z4Workflow._register_metrics.<locals>._compare_metrics)r>   r   r;   r<   )rV   ro   rZ   type__name__ra   rm   rA   rO   rY   updateitemsattachrX   r   ZEPOCH_COMPLETED)ri   r|   r}   rM   rj   metricr   rD   r   rE   rg      s   



zWorkflow._register_metricsr   c                 C  s    t |}|D ]}||  qdS )zc
        Register the handlers to the engine, supports ignite Handlers with `attach` API.

        N)r   r   )ri   r1   Z	handlers_handlerrD   rD   rE   rh     s   zWorkflow._register_handlersc                   s4   | j jdkrtd dS t j| j| j jd dS )zT
        Execute training, validation or evaluation based on Ignite Engine.
        r   z`dataloader` is empty or the specified `epoch_length` is 0, skip the `run`. If running distributed training, the program may hang in `all-gather`, `all-reduce`, etc. because not all the ranks run the same computation logic.N)datar    )rA   r$   r   r   rS   runr"   r    r   rk   rD   rE   r     s   zWorkflow.runr>   r   	batchdatadict[str, torch.Tensor]c                 C  s   t d| jj d)a  
        Abstract callback function for the processing logic of 1 iteration in Ignite Engine.
        Need subclass to implement different logics, like SupervisedTrainer/Evaluator, GANTrainer, etc.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

        Raises:
            NotImplementedError: When the subclass does not override this method.

        z	Subclass z must implement this method.)NotImplementedErrorrl   r   )ri   r>   r   rD   rD   rE   rU     s   zWorkflow._iterationc                   s    fdd|D S )a!  
        Get the statistics information of the workflow process.

        Args:
            vars: variables name in the `self.state`, will use the variable name as the key
                and the state content as the value. if the variable doesn't exist, default value is `None`.

        c                   s   i | ]
}|t  j|d qS r?   )rW   rA   ).0kr   rD   rE   
<dictcomp>4  s    z&Workflow.get_stats.<locals>.<dictcomp>rD   )ri   varsrD   r   rE   	get_stats+  s   	zWorkflow.get_stats)&r   r   r    r!   r"   r#   r$   r%   r&   r'   r(   r   r)   r*   r+   r,   r-   r.   r/   r.   r0   r   r1   r2   r3   r'   r4   r5   r6   r7   r8   r'   r9   r7   r:   r7   r;   r<   )rt   r   r;   r<   r?   )r|   ro   r}   r7   r;   r<   )r1   r   r;   r<   )r;   r<   )r>   r   r   r   r;   ro   )r   
__module____qualname____doc__r   r   rT   re   rf   rg   rh   r   rU   r   __classcell__rD   rD   rk   rE   r   /   s2    <]

#	
r   )+
__future__r   r   collections.abcr   r   r   r   typingr   r   r_   torch.distributeddistributedr[   torch.utils.datar	   torch.utils.data.distributedr
   monai.engines.utilsr   r   r   monai.transformsr   monai.utilsr   r   r   r   r   utilsr   OPT_IMPORT_VERSIONr   _r   ignite.enginer   r   ignite.metricsr   r   rD   rD   rD   rE   <module>   s6   


