o
    1i                      @  sd  d dl mZ d dlZd dlZd dlmZmZ d dlmZm	Z	 d dl
Z
d dlmZ d dlmZ d dlmZ d dlmZ d dlmZmZmZmZmZ d d	lmZmZmZ d d
lmZm Z  d dl!m"Z"m#Z#m$Z$m%Z%m&Z&m'Z' d dl(m)Z) d dl*m+Z+m,Z, d dl-m.Z.m/Z/ d dl0m1Z1 ee2Z3d&ddZ4dd Z5d'ddZ6G dd  d e Z7e/d!d"e.d#G d$d% d%ee7Z8dS )(    )annotationsN)MappingMutableMapping)Anycast)DataAnalyzer)
get_logger)SegSummarizer)BundleWorkflowConfigComponent
ConfigItemConfigParserConfigWorkflow)SupervisedEvaluatorSupervisedTrainerTrainer)
ClientAlgoClientAlgoStats)
ExtraItemsFiltersTypeFlPhaseFlStatistics	ModelType
WeightType)ExchangeObject)copy_model_stateget_state_dict)min_versionrequire_pkg)DataStatsKeysglobal_weightsr   local_var_dictr   returntuple[MutableMapping, int]c                 C  s   |   }d}|D ]6}||v r>| | }ztt||| j}|||< |d7 }W q ty= } z	td| d|d}~ww q||fS )zAHelper function to convert global weights to local weights formatr      zConvert weight from z failed.N)keystorchreshape	as_tensorshape	Exception
ValueError)r    r!   Z
model_keysn_convertedvar_nameweightse r0   \/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/fl/client/monai_algo.pyconvert_global_weights%   s   
r2   c                 C  s   | d u rt d|d u rt di }d}| D ]+}||vrq||  | |   ||< |d7 }tt|| rAt d| dq|dkrJtd|S )Nz>Cannot compute weight differences if `global_weights` is None!z>Cannot compute weight differences if `local_var_dict` is None!r   r$   zWeights for z became NaN...zNo weight differences computed!)r+   cpur&   anyisnanRuntimeError)r    r!   Zweight_diffZn_diffnamer0   r0   r1   compute_weight_diff8   s"   r8   parserr   Nonec                 C  s<   d| v r| d D ]}t |rd|d v rd|d< qd S d S )Nzvalidate#handlersCheckpointLoader_target_T
_disabled_)r   is_instantiable)r9   hr0   r0   r1   disable_ckpt_loadersM   s   
r@   c                   @  sZ   e Zd ZdZ					d d!ddZd"ddZd"d#ddZd"ddZedd Z	dd Z
dS )$MonaiAlgoStatsa7  
    Implementation of ``ClientAlgoStats`` to allow federated learning with MONAI bundle configurations.

    Args:
        bundle_root: directory path of the bundle.
        config_train_filename: bundle training config path relative to bundle_root. Can be a list of files;
            defaults to "configs/train.json". only useful when `workflow` is None.
        config_filters_filename: filter configuration file. Can be a list of files; defaults to `None`.
        data_stats_transform_list: transforms to apply for the data stats result.
        histogram_only: whether to only compute histograms. Defaults to False.
        workflow: the bundle workflow to execute, usually it's training, evaluation or inference.
            if None, will create an `ConfigWorkflow` internally based on `config_train_filename`.
    configs/train.jsonNFbundle_rootstrconfig_train_filenamestr | list | Noneconfig_filters_filenamedata_stats_transform_listlist | Nonehistogram_onlyboolworkflowBundleWorkflow | Nonec                 C  s   t | _ || _|| _|| _d| _d| _|| _|| _d | _|d ur5t	|t
s(td| d u r2td|| _d | _d| _d | _tj| _d | _d S )Ntrainevalz.workflow must be a subclass of BundleWorkflow.z"workflow doesn't specify the type. )loggerrC   rE   rG   train_data_keyeval_data_keyrH   rJ   rL   
isinstancer
   r+   get_workflow_typeclient_nameapp_rootpost_statistics_filtersr   IDLEphasedataset_root)selfrC   rE   rG   rH   rJ   rL   r0   r0   r1   __init__d   s(   	

zMonaiAlgoStats.__init__c                 C  s  |du ri }| tjd| _| tjd}| jd| j d | tjd| _t	j
| j| j| _| jdu rF| | j}t|d|dd| _| j  | j| j_| j  | | j}t }t|dkrv|| |jtjtdtjd	| _| jd
| j d dS )  
        Initialize routine to parse configuration files and extract main components such as trainer, evaluator, and filters.

        Args:
            extra: Dict with additional information that should be provided by FL system,
                i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`.
                You can diable the logging logic in the monai bundle by setting {ExtraItems.LOGGING_FILE} to False.

        NnonameInitializing  ...rP   rN   config_file	meta_filelogging_fileworkflow_typer   defaultInitialized .)getr   CLIENT_NAMErV   LOGGING_FILErQ   infoAPP_ROOTrW   ospathjoinrC   rL   _add_config_filesrE   r   
initializerG   r   lenread_configget_parsed_contentr   POST_STATISTICS_FILTERSr   rX   )r\   extrare   config_train_filesconfig_filter_filesfilter_parserr0   r0   r1   rt      s.   





zMonaiAlgoStats.initializery   dict | Noner"   r   c              	   C  sp  |du rt d| jjrtj| _| jd| jj  tj	|vr$t d|tj	 }tj
|vr2t d|tj
 }i }| j| jj| j||tj| jdd\}}|rX|| j|i d}d}| jjdurx| j| jj| j||tj| jdd\}}n| jd	 |r|| j|i |r|r| ||g||}	|tj|	i t|d
}
| jdur| jD ]}||
|}
q|
S t d)aX  
        Returns summary statistics about the local data.

        Args:
            extra: Dict with additional information that can be provided by the FL system.
                    Both FlStatistics.HIST_BINS and FlStatistics.HIST_RANGE must be provided.

        Returns:
            stats: ExchangeObject with summary statistics.

        Nz`extra` has to be setzComputing statistics on z1FlStatistics.NUM_OF_BINS not specified in `extra`z0FlStatistics.HIST_RANGE not specified in `extra`ztrain_data_stats.yaml)datadata_key	hist_bins
hist_rangeoutput_pathzeval_data_stats.yamlz0the datalist doesn't contain validation section.)
statisticszdata_root not set!)r+   rL   dataset_dirr   GET_DATA_STATSrZ   rQ   rn   r   	HIST_BINS
HIST_RANGE_get_data_key_statstrain_dataset_datarR   rp   rq   rr   rW   updateval_dataset_datarS   warning_compute_total_stats
TOTAL_DATAr   rX   )r\   ry   r   r   Z
stats_dictZtrain_summary_statsZtrain_case_statsZeval_summary_statsZeval_case_statstotal_summary_statsstats_filterr0   r0   r1   get_data_stats   sZ   








zMonaiAlgoStats.get_data_statsc           
      C  s   t ||i| jj|||| jd}| j| j d| d |j| j|d}|t	j
 }tj|t	j tjt|tjt|t| i}	|	|fS )N)datalistdatarootr   r   r   rJ   z compute data statistics on z...)transform_listkey)r   rL   r   rJ   rQ   rn   rV   get_all_case_statsrH   r   BY_CASEr   
DATA_STATSSUMMARY
DATA_COUNTru   
FAIL_COUNT)
r\   r~   r   r   r   r   analyzerZ	all_statsZ
case_statssummary_statsr0   r0   r1   r      s    	

z"MonaiAlgoStats._get_data_key_statsc                 C  sR   g }| D ]}||7 }qt dddd||d}||}tj|tjt|tjdi}|S )NimagelabelT)averagedo_ccpr   r   r   )r	   	summarizer   r   r   ru   r   )Zcase_stats_listsr   r   Ztotal_case_statsZcase_stats_list
summarizerr   r   r0   r0   r1   r     s   


z#MonaiAlgoStats._compute_total_statsc                 C  s   g }|rJt |tr|tj| j| |S t |tr>|D ]}t |tr0|tj| j| qtdt	| d| |S tdt	| d| |S )Nz/Expected config file to be of type str but got z: z8Expected config files to be of type str or list but got )
rT   rD   appendrp   rq   rr   rC   listr+   type)r\   config_filesfilesfiler0   r0   r1   rs   $  s   


z MonaiAlgoStats._add_config_files)rB   NNFN)rC   rD   rE   rF   rG   rF   rH   rI   rJ   rK   rL   rM   N)ry   r}   r"   r   )__name__
__module____qualname____doc__r]   rt   r   r   staticmethodr   rs   r0   r0   r0   r1   rA   U   s    
(
M
rA   ignitez0.4.10)pkg_nameversionversion_checkerc                   @  s   e Zd ZdZ														
		d:d;d#d$Zd<d%d&Zd<d=d,d-Zd<d.d/Zd<d>d0d1Zd<d2d3Z	d<d?d4d5Z
d6d7 Zd8d9 ZdS )@	MonaiAlgoa
  
    Implementation of ``ClientAlgo`` to allow federated learning with MONAI bundle configurations.

    Args:
        bundle_root: directory path of the bundle.
        local_epochs: number of local epochs to execute during each round of local training; defaults to 1.
        send_weight_diff: whether to send weight differences rather than full weights; defaults to `True`.
        config_train_filename: bundle training config path relative to bundle_root. can be a list of files.
            defaults to "configs/train.json". only useful when `train_workflow` is None.
        train_kwargs: other args of the `ConfigWorkflow` of train, except for `config_file`, `meta_file`,
            `logging_file`, `workflow_type`. only useful when `train_workflow` is None.
        config_evaluate_filename: bundle evaluation config path relative to bundle_root. can be a list of files.
            if "default", ["configs/train.json", "configs/evaluate.json"] will be used.
            this arg is only useful when `eval_workflow` is None.
        eval_kwargs: other args of the `ConfigWorkflow` of evaluation, except for `config_file`, `meta_file`,
            `logging_file`, `workflow_type`. only useful when `eval_workflow` is None.
        config_filters_filename: filter configuration file. Can be a list of files; defaults to `None`.
        disable_ckpt_loading: do not use any CheckpointLoader if defined in train/evaluate configs; defaults to `True`.
        best_model_filepath: location of best model checkpoint; defaults "models/model.pt" relative to `bundle_root`.
        final_model_filepath: location of final model checkpoint; defaults "models/model_final.pt" relative to `bundle_root`.
        save_dict_key: If a model checkpoint contains several state dicts,
            the one defined by `save_dict_key` will be returned by `get_weights`; defaults to "model".
            If all state dicts should be returned, set `save_dict_key` to None.
        data_stats_transform_list: transforms to apply for the data stats result.
        eval_workflow_name: the workflow name corresponding to the "config_evaluate_filename", default to "train"
            as the default "config_evaluate_filename" overrides the train workflow config.
            this arg is only useful when `eval_workflow` is None.
        train_workflow: the bundle workflow to execute training, if None, will create a `ConfigWorkflow` internally
            based on `config_train_filename` and `train_kwargs`.
        eval_workflow: the bundle workflow to execute evaluation, if None, will create a `ConfigWorkflow` internally
            based on `config_evaluate_filename`, `eval_kwargs`, `eval_workflow_name`.

    r$   TrB   Nrh   models/model.ptmodels/model_final.ptmodelrN   rC   rD   local_epochsintsend_weight_diffrK   rE   rF   train_kwargsr}   config_evaluate_filenameeval_kwargsrG   disable_ckpt_loadingbest_model_filepath
str | Nonefinal_model_filepathsave_dict_keyrH   rI   eval_workflow_nametrain_workflowrM   eval_workflowc                 C  sJ  t | _ || _|| _|| _|| _|d u ri n|| _|dkr ddg}|| _|d u r)i n|| _|| _|	| _	t
j|
t
j|i| _|| _|| _|| _d | _d | _|d uret|trY| dkrbtdtj d|| _|d ur{t|trt| d u rxtd|| _d | _d| _d | _d | _d | _d | _d | _d | _d	| _ d | _!t"j#| _$d | _%d | _&d S )
Nrh   rB   zconfigs/evaluate.jsonrN   z6train workflow must be BundleWorkflow and set type in rj   z3train workflow must be BundleWorkflow and set type.rP   r   )'rQ   rC   r   r   rE   r   r   r   rG   r   r   
BEST_MODELFINAL_MODELmodel_filepathsr   rH   r   r   r   rT   r
   rU   r+   supported_train_typestats_senderrW   r|   trainer	evaluatorpre_filterspost_weight_filterspost_evaluate_filtersiter_of_start_timer    r   rY   rZ   rV   r[   )r\   rC   r   r   rE   r   r   r   rG   r   r   r   r   rH   r   r   r   r0   r0   r1   r]   Z  sR   
zMonaiAlgo.__init__c                 C  s  |    |du r
i }|tjd| _|tjd}td}| j	d| j d |tj
d| _tj| j| j| _| jdu ri| jduri| | j}d| jvr[| j d| | jd< td|d|d	d
| j| _| jdur| j  | j| j_| j| j_| jrt| jtrt| jjd | j  | jj| _t| jtstdt| j d| j du r| j!dur| | j!}d| j"vr| j d| | j"d< td|d|| j#d
| j"| _ | j dur| j   | j| j _| jrt| j trt| j jd | j   | j j$| _$t| j$t%stdt| j$ d| | j&}t' | _(t)|dkr*| j(*| |tj+| j,| _,| j,durG| j,-| j | j,-| j$ | j(j.t/j0t1dt/j0d| _2| j(j.t/j3t1dt/j3d| _4| j(j.t/j5t1dt/j5d| _6| j(j.t/j7t1dt/j7d| _8| j	d| j d dS )r^   Nr_   z%Y%m%d_%H%M%Sr`   ra   rP   run_name_rN   rb   )r9   z,trainer must be SupervisedTrainer, but got: rj   z0evaluator must be SupervisedEvaluator, but got: r   rg   ri   r0   )9_set_cuda_devicerk   r   rl   rV   rm   timestrftimerQ   rn   ro   rW   rp   rq   rr   rC   r   rE   rs   r   r   rt   r   
max_epochsr   rT   r@   r9   r   r   r+   r   r   r   r   r   r   r   rG   r   r|   ru   rv   STATS_SENDERr   attachrw   r   PRE_FILTERSr   r   POST_WEIGHT_FILTERSr   POST_EVALUATE_FILTERSr   rx   rX   )r\   ry   re   	timestamprz   Zconfig_eval_filesr{   r0   r0   r1   rt     s   













zMonaiAlgo.initializer~   r   ry   r"   r:   c                 C  s2  |    |du r
i }t|tstdt| | jdu r!td| jdur1| jD ]}|||}q)tj| _	| j
d| j d t| jj}ttt|j|d\| _}| |j|| | jjj| j | jj_| jjj| _ttt| j| jjd\}}}t|dkr| j
d	 | j
d
| j d | j  dS )z
        Train on client's local data.

        Args:
            data: `ExchangeObject` containing the current global model weights.
            extra: Dict with additional information that can be provided by the FL system.

        N0expected data to be ExchangeObject but received z self.trainer should not be None.Load  weights...r    r!   srcdstr   No weights loaded!Start z training...) r   rT   r   r+   r   r   r   r   TRAINrZ   rQ   rn   rV   r   networkr2   r   dictr.   r    _check_convertedstateepochr   r   	iterationr   r   r   ru   r   run)r\   r~   ry   r   r!   r,   r   updated_keysr0   r0   r1   rN     s0   




zMonaiAlgo.trainc           
      C  s  |    |du r
i }tj| _tj|v r}|tj}t|ts't	dt
| || jv rrtj| jtt| j| }tj|sGt	d| tj|ddd}t|tr_| j|v r_|| j}tj}i }| jd| d| d	 nXt	d
| d| j | jrt| jj}| D ]
}||  ||< qtj}| j  }| jj!j"| j# |t$j%< | j&rt'| j(|d}tj)}| jd n| jd nd}d}t }t|tst	d| t*|d||d}| j+dur| j+D ]}	|	||}q|S )av  
        Returns the current weights of the model.

        Args:
            extra: Dict with additional information that can be provided by the FL system.

        Returns:
            return_weights: `ExchangeObject` containing current weights (default)
                or load requested model type from disk (`ModelType.BEST_MODEL` or `ModelType.FINAL_MODEL`).

        NzEExpected requested model type to be of type `ModelType` but received z#No best model checkpoint exists at r3   T)map_locationweights_onlyz
Returning z checkpoint weights from rj   zRequested model type z% not specified in `model_filepaths`: r   z%Returning current weight differences.zReturning current weights.zstats is not a dict, )r.   optimweight_typer   ),r   r   GET_WEIGHTSrZ   r   
MODEL_TYPErk   rT   r   r+   r   r   rp   rq   rr   rC   r   rD   isfiler&   loadr   r   r   WEIGHTSrQ   rn   r   r   r   r%   r3   	get_statsr   r   r   r   NUM_EXECUTED_ITERATIONSr   r8   r    WEIGHT_DIFFr   r   )
r\   ry   
model_typeZ
model_pathr.   Z
weigh_typer   kZreturn_weightsr   r0   r0   r1   get_weights#  sd   






zMonaiAlgo.get_weightsc           
      C  sZ  |    |du r
i }t|tstdt| | jdu r!td| jdur1| jD ]}|||}q)tj| _	| j
d| j d t| jj}ttt|j|d\}}| |j|| t|| jjd\}}}t|dkrq| j
d	 | j
d
| j d t| jtr| j| jjjd  n| j  t| jjjd}	| jdur| jD ]}||	|}	q|	S )aK  
        Evaluate on client's local data.

        Args:
            data: `ExchangeObject` containing the current global model weights.
            extra: Dict with additional information that can be provided by the FL system.

        Returns:
            return_metrics: `ExchangeObject` containing evaluation metrics.

        Nr   z"self.evaluator should not be None.r   r   r   r   r   r   r   z evaluating...r$   )metrics)r   rT   r   r+   r   r   r   r   EVALUATErZ   rQ   rn   rV   r   r   r2   r   r   r.   r   r   ru   r   r   r   r   r   r   r   r   )
r\   r~   ry   r   r!   r    r,   r   r   Zreturn_metricsr0   r0   r1   evaluater  s:   







zMonaiAlgo.evaluatec                 C  s~   | j d| j d| j d t| jtr%| j d| j d | j  t| jtr=| j d| j d | j  dS dS )z
        Abort the training or evaluation.
        Args:
            extra: Dict with additional information that can be provided by the FL system.
        z	Aborting  during  phase. trainer... evaluator...N)	rQ   rn   rV   rZ   rT   r   r   Z	interruptr   r\   ry   r0   r0   r1   abort  s   
zMonaiAlgo.abortc                 C  s   | j d| j d| j d t| jtr%| j d| j d | j  t| jtr;| j d| j d | j  | j	durE| j	
  | jdurQ| j
  dS dS )z
        Finalize the training or evaluation.
        Args:
            extra: Dict with additional information that can be provided by the FL system.
        zTerminating r  r  r  r  N)rQ   rn   rV   rZ   rT   r   r   	terminater   r   finalizer   r  r0   r0   r1   r	    s   




zMonaiAlgo.finalizec                 C  s@   |dkrt dt|  | jd| dt| d d S )Nr   z;No global weights converted! Received weight dict keys are z
Converted z global variables to match z local variables.)r6   r   r%   rQ   rn   ru   )r\   r    r!   r,   r0   r0   r1   r     s   zMonaiAlgo._check_convertedc                 C  s.   t  rttjd | _tj| j d S d S )N
LOCAL_RANK)	distis_initializedr   rp   environrankr&   cuda
set_device)r\   r0   r0   r1   r     s   zMonaiAlgo._set_cuda_device)r$   TrB   Nrh   NNTr   r   r   NrN   NN) rC   rD   r   r   r   rK   rE   rF   r   r}   r   rF   r   r}   rG   rF   r   rK   r   r   r   r   r   r   rH   rI   r   rD   r   rM   r   rM   r   )r~   r   ry   r}   r"   r:   )r~   r   ry   r}   r"   r   )ry   r}   r"   r:   )r   r   r   r   r]   rt   rN   r   r  r  r	  r   r   r0   r0   r0   r1   r   6  s4    %
A`
(O
0
r   )r    r   r!   r   r"   r#   )r9   r   r"   r:   )9
__future__r   rp   r   collections.abcr   r   typingr   r   r&   torch.distributeddistributedr  "monai.apps.auto3dseg.data_analyzerr   monai.apps.utilsr   monai.auto3dsegr	   monai.bundler
   r   r   r   r   monai.enginesr   r   r   Zmonai.fl.clientr   r   monai.fl.utils.constantsr   r   r   r   r   r   monai.fl.utils.exchange_objectr   monai.networks.utilsr   r   monai.utilsr   r   monai.utils.enumsr   r   rQ   r2   r8   r@   rA   r   r0   r0   r0   r1   <module>   s4    

 b