U
    Ph                     @  sr  d dl mZ d dlZd dlZd dlmZmZ d dlmZm	Z	 d dl
Z
d dlmZ d dlmZ d dlmZ d dlmZ d dlmZmZmZmZmZ d d	lmZmZmZ d d
lmZm Z  d dl!m"Z"m#Z#m$Z$m%Z%m&Z&m'Z' d dl(m)Z) d dl*m+Z+m,Z, d dl-m.Z.m/Z/ d dl0m1Z1 ee2Z3ddddddZ4dd Z5dddddZ6G dd de Z7e/dd e.d!G d"d# d#ee7Z8dS )$    )annotationsN)MappingMutableMapping)Anycast)DataAnalyzer)
get_logger)SegSummarizer)BundleWorkflowConfigComponent
ConfigItemConfigParserConfigWorkflow)SupervisedEvaluatorSupervisedTrainerTrainer)
ClientAlgoClientAlgoStats)
ExtraItemsFiltersTypeFlPhaseFlStatistics	ModelType
WeightType)ExchangeObject)copy_model_stateget_state_dict)min_versionrequire_pkg)DataStatsKeysr   r   ztuple[MutableMapping, int])global_weightslocal_var_dictreturnc                 C  s   |   }d}|D ]v}||kr| | }z,tt||| j}|||< |d7 }W q tk
r } ztd| d|W 5 d}~X Y qX q||fS )zAHelper function to convert global weights to local weights formatr      zConvert weight from z failed.N)keystorchreshape	as_tensorshape	Exception
ValueError)r    r!   Z
model_keysn_convertedvar_nameweightse r/   O/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/fl/client/monai_algo.pyconvert_global_weights%   s    &r1   c                 C  s   | d krt d|d kr t di }d}| D ]V}||kr:q,||  | |   ||< |d7 }tt|| r,t d| dq,|dkrtd|S )Nz>Cannot compute weight differences if `global_weights` is None!z>Cannot compute weight differences if `local_var_dict` is None!r   r#   zWeights for z became NaN...zNo weight differences computed!)r*   cpur%   anyisnanRuntimeError)r    r!   Zweight_diffZn_diffnamer/   r/   r0   compute_weight_diff8   s     r7   r   None)parserr"   c                 C  s8   d| kr4| d D ]"}t |rd|d krd|d< qd S )Nzvalidate#handlersCheckpointLoader_target_T
_disabled_)r   is_instantiable)r9   hr/   r/   r0   disable_ckpt_loadersM   s
    
r?   c                   @  sd   e Zd ZdZddddddd	d
ddZdddZddddddZdddZedd Z	dd Z
dS )MonaiAlgoStatsa7  
    Implementation of ``ClientAlgoStats`` to allow federated learning with MONAI bundle configurations.

    Args:
        bundle_root: directory path of the bundle.
        config_train_filename: bundle training config path relative to bundle_root. Can be a list of files;
            defaults to "configs/train.json". only useful when `workflow` is None.
        config_filters_filename: filter configuration file. Can be a list of files; defaults to `None`.
        data_stats_transform_list: transforms to apply for the data stats result.
        histogram_only: whether to only compute histograms. Defaults to False.
        workflow: the bundle workflow to execute, usually it's training, evaluation or inference.
            if None, will create an `ConfigWorkflow` internally based on `config_train_filename`.
    configs/train.jsonNFstrstr | list | Nonelist | NoneboolBundleWorkflow | None)bundle_rootconfig_train_filenameconfig_filters_filenamedata_stats_transform_listhistogram_onlyworkflowc                 C  s   t | _ || _|| _|| _d| _d| _|| _|| _d | _|d k	rjt	|t
sPtd| d krdtd|| _d | _d| _d | _tj| _d | _d S )Ntrainevalz.workflow must be a subclass of BundleWorkflow.z"workflow doesn't specify the type. )loggerrG   rH   rI   train_data_keyeval_data_keyrJ   rK   rL   
isinstancer
   r*   get_workflow_typeclient_nameapp_rootpost_statistics_filtersr   IDLEphasedataset_root)selfrG   rH   rI   rJ   rK   rL   r/   r/   r0   __init__d   s(    	
zMonaiAlgoStats.__init__c                 C  s  |dkri }| tjd| _| tjd}| jd| j d | tjd| _t	j
| j| j| _| jdkr| | j}t|d|dd| _| j  | j| j_| j  | | j}t }t|dkr|| |jtjtdtjd	| _| jd
| j d dS )  
        Initialize routine to parse configuration files and extract main components such as trainer, evaluator, and filters.

        Args:
            extra: Dict with additional information that should be provided by FL system,
                i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`.
                You can diable the logging logic in the monai bundle by setting {ExtraItems.LOGGING_FILE} to False.

        NnonameInitializing  ...rO   rM   config_file	meta_filelogging_fileworkflow_typer   defaultInitialized .)getr   CLIENT_NAMErU   LOGGING_FILErP   infoAPP_ROOTrV   ospathjoinrG   rL   _add_config_filesrH   r   
initializerI   r   lenread_configget_parsed_contentr   POST_STATISTICS_FILTERSr   rW   )r[   extrard   config_train_filesconfig_filter_filesfilter_parserr/   r/   r0   rs      s6    

   



 
zMonaiAlgoStats.initializedict | Noner   rx   r"   c              	   C  s  |dkrt d| jjrxtj| _| jd| jj  tj	|krLt dn
|tj	 }tj
|krjt dn
|tj
 }i }| j| jj| j||tj| jdd\}}|r|| j|i d}d}| jjdk	r| j| jj| j||tj| jdd\}}n| jd	 |r|| j|i |rF|rF| ||g||}	|tj|	i t|d
}
| jdk	rt| jD ]}||
|}
qb|
S t ddS )aX  
        Returns summary statistics about the local data.

        Args:
            extra: Dict with additional information that can be provided by the FL system.
                    Both FlStatistics.HIST_BINS and FlStatistics.HIST_RANGE must be provided.

        Returns:
            stats: ExchangeObject with summary statistics.

        Nz`extra` has to be setzComputing statistics on z1FlStatistics.NUM_OF_BINS not specified in `extra`z0FlStatistics.HIST_RANGE not specified in `extra`ztrain_data_stats.yaml)datadata_key	hist_bins
hist_rangeoutput_pathzeval_data_stats.yamlz0the datalist doesn't contain validation section.)
statisticszdata_root not set!)r*   rL   dataset_dirr   GET_DATA_STATSrY   rP   rm   r   	HIST_BINS
HIST_RANGE_get_data_key_statstrain_dataset_datarQ   ro   rp   rq   rV   updateval_dataset_datarR   warning_compute_total_stats
TOTAL_DATAr   rW   )r[   rx   r   r   Z
stats_dictZtrain_summary_statsZtrain_case_statsZeval_summary_statsZeval_case_statstotal_summary_statsstats_filterr/   r/   r0   get_data_stats   s^    







  

zMonaiAlgoStats.get_data_statsc           
      C  s   t ||i| jj|||| jd}| j| j d| d |j| j|d}|t	j
 }tj|t	j tjt|tjt|t| i}	|	|fS )N)datalistdatarootr   r   r   rK   z compute data statistics on z...)transform_listkey)r   rL   r   rK   rP   rm   rU   get_all_case_statsrJ   r   BY_CASEr   
DATA_STATSSUMMARY
DATA_COUNTrt   
FAIL_COUNT)
r[   r~   r   r   r   r   analyzerZ	all_statsZ
case_statssummary_statsr/   r/   r0   r      s&    	
   z"MonaiAlgoStats._get_data_key_statsc                 C  sR   g }| D ]}||7 }qt dddd||d}||}tj|tjt|tjdi}|S )NimagelabelT)averagedo_ccpr   r   r   )r	   	summarizer   r   r   rt   r   )Zcase_stats_listsr   r   Ztotal_case_statsZcase_stats_list
summarizerr   r   r/   r/   r0   r     s(    
     
   z#MonaiAlgoStats._compute_total_statsc                 C  s   g }|rt |tr*|tj| j| nht |trz|D ]>}t |tr^|tj| j| q8tdt	| d| q8ntdt	| d| |S )Nz/Expected config file to be of type str but got z: z8Expected config files to be of type str or list but got )
rS   rB   appendro   rp   rq   rG   listr*   type)r[   config_filesfilesfiler/   r/   r0   rr   $  s    


z MonaiAlgoStats._add_config_files)rA   NNFN)N)N)N)__name__
__module____qualname____doc__r\   rs   r   r   staticmethodr   rr   r/   r/   r/   r0   r@   U   s        
(M

r@   ignitez0.4.10)pkg_nameversionversion_checkerc                   @  s   e Zd ZdZd*dddddddddddddddddddZd+ddZd,ddddddZd-ddZd.dddddd Zd/d!d"Z	d0ddd#d$d%Z
d&d' Zd(d) ZdS )1	MonaiAlgoa
  
    Implementation of ``ClientAlgo`` to allow federated learning with MONAI bundle configurations.

    Args:
        bundle_root: directory path of the bundle.
        local_epochs: number of local epochs to execute during each round of local training; defaults to 1.
        send_weight_diff: whether to send weight differences rather than full weights; defaults to `True`.
        config_train_filename: bundle training config path relative to bundle_root. can be a list of files.
            defaults to "configs/train.json". only useful when `train_workflow` is None.
        train_kwargs: other args of the `ConfigWorkflow` of train, except for `config_file`, `meta_file`,
            `logging_file`, `workflow_type`. only useful when `train_workflow` is None.
        config_evaluate_filename: bundle evaluation config path relative to bundle_root. can be a list of files.
            if "default", ["configs/train.json", "configs/evaluate.json"] will be used.
            this arg is only useful when `eval_workflow` is None.
        eval_kwargs: other args of the `ConfigWorkflow` of evaluation, except for `config_file`, `meta_file`,
            `logging_file`, `workflow_type`. only useful when `eval_workflow` is None.
        config_filters_filename: filter configuration file. Can be a list of files; defaults to `None`.
        disable_ckpt_loading: do not use any CheckpointLoader if defined in train/evaluate configs; defaults to `True`.
        best_model_filepath: location of best model checkpoint; defaults "models/model.pt" relative to `bundle_root`.
        final_model_filepath: location of final model checkpoint; defaults "models/model_final.pt" relative to `bundle_root`.
        save_dict_key: If a model checkpoint contains several state dicts,
            the one defined by `save_dict_key` will be returned by `get_weights`; defaults to "model".
            If all state dicts should be returned, set `save_dict_key` to None.
        data_stats_transform_list: transforms to apply for the data stats result.
        eval_workflow_name: the workflow name corresponding to the "config_evaluate_filename", default to "train"
            as the default "config_evaluate_filename" overrides the train workflow config.
            this arg is only useful when `eval_workflow` is None.
        train_workflow: the bundle workflow to execute training, if None, will create a `ConfigWorkflow` internally
            based on `config_train_filename` and `train_kwargs`.
        eval_workflow: the bundle workflow to execute evaluation, if None, will create a `ConfigWorkflow` internally
            based on `config_evaluate_filename`, `eval_kwargs`, `eval_workflow_name`.

    r#   TrA   Nrg   models/model.ptmodels/model_final.ptmodelrM   rB   intrE   rC   r|   z
str | NonerD   rF   )rG   local_epochssend_weight_diffrH   train_kwargsconfig_evaluate_filenameeval_kwargsrI   disable_ckpt_loadingbest_model_filepathfinal_model_filepathsave_dict_keyrJ   eval_workflow_nametrain_workfloweval_workflowc                 C  sJ  t | _ || _|| _|| _|| _|d kr*i n|| _|dkr@ddg}|| _|d krRi n|| _|| _|	| _	t
j|
t
j|i| _|| _|| _|| _d | _d | _|d k	rt|tr| dkrtdtj d|| _|d k	rt|tr| d krtd|| _d | _d| _d | _d | _d | _d | _d | _d | _d	| _ d | _!t"j#| _$d | _%d | _&d S )
Nrg   rA   zconfigs/evaluate.jsonrM   z6train workflow must be BundleWorkflow and set type in ri   z3train workflow must be BundleWorkflow and set type.rO   r   )'rP   rG   r   r   rH   r   r   r   rI   r   r   
BEST_MODELFINAL_MODELmodel_filepathsr   rJ   r   r   r   rS   r
   rT   r*   supported_train_typestats_senderrV   r{   trainer	evaluatorpre_filterspost_weight_filterspost_evaluate_filtersiter_of_start_timer    r   rX   rY   rU   rZ   )r[   rG   r   r   rH   r   r   r   rI   r   r   r   r   rJ   r   r   r   r/   r/   r0   r\   Z  sR    zMonaiAlgo.__init__c                 C  s*  |    |dkri }|tjd| _|tjd}td}| j	d| j d |tj
d| _tj| j| j| _| jdkr| jdk	r| | j}d| jkr| j d| | jd< tf |d|d	d
| j| _| jdk	rX| j  | j| j_| j| j_| jr t| jtr t| jjd | j  | jj| _t| jtsXtdt| j d| j dkr| j!dk	r| | j!}d| j"kr| j d| | j"d< tf |d|| j#d
| j"| _ | j dk	r8| j   | j| j _| jr t| j tr t| j jd | j   | j j$| _$t| j$t%s8tdt| j$ d| | j&}t' | _(t)|dkrf| j(*| |tj+| j,| _,| j,dk	r| j,-| j | j,-| j$ | j(j.t/j0t1dt/j0d| _2| j(j.t/j3t1dt/j3d| _4| j(j.t/j5t1dt/j5d| _6| j(j.t/j7t1dt/j7d| _8| j	d| j d dS )r]   Nr^   z%Y%m%d_%H%M%Sr_   r`   rO   run_name_rM   ra   )r9   z,trainer must be SupervisedTrainer, but got: ri   z0evaluator must be SupervisedEvaluator, but got: r   rf   rh   )9_set_cuda_devicerj   r   rk   rU   rl   timestrftimerP   rm   rn   rV   ro   rp   rq   rG   r   rH   rr   r   r   rs   r   
max_epochsr   rS   r?   r9   r   r   r*   r   r   r   r   r   r   r   rI   r   r{   rt   ru   STATS_SENDERr   attachrv   r   PRE_FILTERSr   r   POST_WEIGHT_FILTERSr   POST_EVALUATE_FILTERSr   rw   rW   )r[   rx   rd   	timestampry   Zconfig_eval_filesrz   r/   r/   r0   rs     s    











 
 
 
 
zMonaiAlgo.initializer   r8   )r~   rx   r"   c                 C  s4  |    |dkri }t|ts0tdt| | jdkrBtd| jdk	rb| jD ]}|||}qRtj| _	| j
d| j d t| jj}ttt|j|d\| _}| |j|| | jjj| j | jj_| jjj| _ttt| j| jjd\}}}t|dkr| j
d	 | j
d
| j d | j  dS )z
        Train on client's local data.

        Args:
            data: `ExchangeObject` containing the current global model weights.
            extra: Dict with additional information that can be provided by the FL system.

        N0expected data to be ExchangeObject but received z self.trainer should not be None.Load  weights...r    r!   srcdstr   No weights loaded!Start z training...) r   rS   r   r*   r   r   r   r   TRAINrY   rP   rm   rU   r   networkr1   r   dictr-   r    _check_convertedstateepochr   r   	iterationr   r   r   rt   r   run)r[   r~   rx   r   r!   r+   r   updated_keysr/   r/   r0   rM     s2    





 zMonaiAlgo.trainc           
      C  s  |    |dkri }tj| _tj|kr|tj}t|tsNt	dt
| || jkrtj| jtt| j| }tj|st	d| tj|dd}t|tr| j|kr|| j}tj}i }| jd| d| d nt	d	| d
| j n| jrt| jj}| D ]}||  ||< qtj}| j  }| jj!j"| j# |t$j%< | j&r~t'| j(|d}tj)}| jd n| jd nd}d}t }t|tst	d| t*|d||d}| j+dk	r| j+D ]}	|	||}q|S )av  
        Returns the current weights of the model.

        Args:
            extra: Dict with additional information that can be provided by the FL system.

        Returns:
            return_weights: `ExchangeObject` containing current weights (default)
                or load requested model type from disk (`ModelType.BEST_MODEL` or `ModelType.FINAL_MODEL`).

        NzEExpected requested model type to be of type `ModelType` but received z#No best model checkpoint exists at r2   )map_locationz
Returning z checkpoint weights from ri   zRequested model type z% not specified in `model_filepaths`: r   z%Returning current weight differences.zReturning current weights.zstats is not a dict, )r-   optimweight_typer   ),r   r   GET_WEIGHTSrY   r   
MODEL_TYPErj   rS   r   r*   r   r   ro   rp   rq   rG   r   rB   isfiler%   loadr   r   r   WEIGHTSrP   rm   r   r   r   r$   r2   	get_statsr   r   r   r   NUM_EXECUTED_ITERATIONSr   r7   r    WEIGHT_DIFFr   r   )
r[   rx   
model_typeZ
model_pathr-   Z
weigh_typer   kZreturn_weightsr   r/   r/   r0   get_weights#  sd    




zMonaiAlgo.get_weightsc           
      C  s`  |    |dkri }t|ts0tdt| | jdkrBtd| jdk	rb| jD ]}|||}qRtj| _	| j
d| j d t| jj}ttt|j|d\}}| |j|| t|| jjd\}}}t|dkr| j
d	 | j
d
| j d t| jtr| j| jjjd  n
| j  t| jjjd}	| jdk	r\| jD ]}||	|}	qJ|	S )aK  
        Evaluate on client's local data.

        Args:
            data: `ExchangeObject` containing the current global model weights.
            extra: Dict with additional information that can be provided by the FL system.

        Returns:
            return_metrics: `ExchangeObject` containing evaluation metrics.

        Nr   z"self.evaluator should not be None.r   r   r   r   r   r   r   z evaluating...r#   )metrics)r   rS   r   r*   r   r   r   r   EVALUATErY   rP   rm   rU   r   r   r1   r   r   r-   r   r   rt   r   r   r   r   r   r   r   r   )
r[   r~   rx   r   r!   r    r+   r   r   Zreturn_metricsr/   r/   r0   evaluater  s<    




 


zMonaiAlgo.evaluatec                 C  sz   | j d| j d| j d t| jtrJ| j d| j d | j  t| jtrv| j d| j d | j  dS )z
        Abort the training or evaluation.
        Args:
            extra: Dict with additional information that can be provided by the FL system.
        z	Aborting  during  phase. trainer... evaluator...N)	rP   rm   rU   rY   rS   r   r   Z	interruptr   r[   rx   r/   r/   r0   abort  s    
zMonaiAlgo.abortr}   c                 C  s   | j d| j d| j d t| jtrJ| j d| j d | j  t| jtrv| j d| j d | j  | j	dk	r| j	
  | jdk	r| j
  dS )z
        Finalize the training or evaluation.
        Args:
            extra: Dict with additional information that can be provided by the FL system.
        zTerminating r   r   r  r  N)rP   rm   rU   rY   rS   r   r   	terminater   r   finalizer   r  r/   r/   r0   r    s    




zMonaiAlgo.finalizec                 C  sB   |dkr t dt|  n| jd| dt| d d S )Nr   z;No global weights converted! Received weight dict keys are z
Converted z global variables to match z local variables.)r5   r   r$   rP   rm   rt   )r[   r    r!   r+   r/   r/   r0   r     s    zMonaiAlgo._check_convertedc                 C  s*   t  r&ttjd | _tj| j d S )N
LOCAL_RANK)	distis_initializedr   ro   environrankr%   cuda
set_device)r[   r/   r/   r0   r     s    zMonaiAlgo._set_cuda_device)r#   TrA   Nrg   NNTr   r   r   NrM   NN)N)N)N)N)N)N)r   r   r   r   r\   rs   rM   r   r   r  r  r   r   r/   r/   r/   r0   r   6  s2   %               .A
`(
O0

r   )9
__future__r   ro   r   collections.abcr   r   typingr   r   r%   torch.distributeddistributedr  "monai.apps.auto3dseg.data_analyzerr   monai.apps.utilsr   monai.auto3dsegr	   monai.bundler
   r   r   r   r   monai.enginesr   r   r   Zmonai.fl.clientr   r   monai.fl.utils.constantsr   r   r   r   r   r   monai.fl.utils.exchange_objectr   monai.networks.utilsr   r   monai.utilsr   r   monai.utils.enumsr   r   rP   r1   r7   r?   r@   r   r/   r/   r/   r0   <module>   s2     b