U
    Phj                     @  s  d dl mZ d dlZd dlmZmZ d dlmZmZ d dl	m
Z
 d dlmZmZ d dlmZ d dlZd dlZd dlmZ d dlmZ d d	lmZmZ d d
lmZ d dlmZ d dlm Z m!Z!m"Z"m#Z#m$Z$ d dl%m&Z& d dl'm(Z( d dl)m*Z*m+Z+m,Z, d dl-m.Z. d dl/m0Z0 d dl1m2Z2m3Z3 d dl4m5Z5m6Z6 e6ddd\Z7Z8ee9dZ:G dd deZ;G dd de;Z<G dd de;Z=G dd dZ>G dd  d Z?dS )!    )annotationsN)ABCabstractmethod)MappingSequence)deepcopy)Anycast)warn)
BundleAlgo)get_name_from_algo_idimport_bundle_algo_history)
get_logger)concat_val_to_np)_prepare_cmd_bcprun_prepare_cmd_torchrun_run_cmd_bcprun_run_cmd_torchrundatafold_read)ConfigParser)partition_dataset)MeanEnsemble	SaveImageVoteEnsemble)
RankFilter)AlgoKeys) check_kwargs_exist_in_class_init
prob2class)look_up_optionoptional_importtqdm)name)module_namec                   @  s   e Zd ZdZdd Zdd Zdd Zdd	 Zd dddddddZd!ddZ	dddddddZ
d"dddddZedd ZdS )#AlgoEnsemblez,
    The base class of Ensemble methods
    c                 C  s   g | _ d| _g | _g | _d S )Nmean)algosmodeinfer_filesalgo_ensembleself r+   Z/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/apps/auto3dseg/ensemble_builder.py__init__6   s    zAlgoEnsemble.__init__c                 C  s   t || _dS )z0
        Register model in the ensemble
        N)r   r%   )r*   infer_algosr+   r+   r,   	set_algos<   s    zAlgoEnsemble.set_algosc                 C  s&   | j D ]}||tj kr|  S qdS )zn
        Get a model by identifier.

        Args:
            identifier: the name of the bundleAlgo
        N)r%   r   ID)r*   
identifieralgor+   r+   r,   get_algoB   s    
zAlgoEnsemble.get_algoc                 C  s   | j S )z
        Get the algo ensemble after ranking or a empty list if ranking was not started.

        Returns:
            A list of Algo
        )r(   r)   r+   r+   r,   get_algo_ensembleM   s    zAlgoEnsemble.get_algo_ensembletestingstrz
str | listNone)datarootdata_list_or_pathdata_keyreturnc                 C  s   g | _ t|tr|| _ ndt|trtt|}||krLt||d|d\| _ }q|t| dr`| jdkr|t	
d| d ntddS )	z
        Set the files to perform model inference.

        Args:
            dataroot: the path of the files
            data_list_or_path: the data source file path
        )datalistbasedirfoldkeyrankr   z#Datalist file has no testing key - z$. No data for inference is specifiedzUnsupported parameter typeN)r'   
isinstancelistr6   r   load_config_filer   hasattrrA   loggerinfo
ValueError)r*   r8   r9   r:   r=   _r+   r+   r,   set_infer_filesV   s    	


zAlgoEnsemble.set_infer_filesFc                   s   t dd |D r dd |D }| jdkrLt |}tttj|dd dS | jd	kr fd
d|D } rvt |S t|d jd d|S dS )a  
        ensemble the results using either "mean" or "vote" method

        Args:
            preds: a list of probability prediction in Tensor-Like format.
            sigmoid: use the sigmoid function to threshold probability one-hot map,
                otherwise argmax is used. Defaults to False

        Returns:
            a tensor which is the ensembled prediction.
        c                 s  s   | ]}|j  V  qd S N)is_cuda.0pr+   r+   r,   	<genexpr>z   s     z-AlgoEnsemble.ensemble_pred.<locals>.<genexpr>c                 S  s   g | ]}|  qS r+   cpurM   r+   r+   r,   
<listcomp>{   s     z.AlgoEnsemble.ensemble_pred.<locals>.<listcomp>r$   r   Tdimkeepdimsigmoidvotec                   s   g | ]}t |d d dqS )r   TrT   )r   rM   rW   r+   r,   rS      s     )num_classesN)	anyr&   r   r   r	   torchTensorr   shape)r*   predsrW   probclassesr+   rY   r,   ensemble_predm   s    



zAlgoEnsemble.ensemble_preddict)algo_spec_paramparam	algo_namer;   c                 C  s@   t |}t |}| D ]"\}}| | kr|| q|S )a  
        Apply the model-specific params to the prediction params based on the name of the Algo.

        Args:
            algo_spec_param: a dict that has structure of {"<name of algo>": "<pred_params for that algo>"}.
            param: the prediction params to override.
            algo_name: name of the Algo

        Returns:
            param after being updated with the model-specific param
        )r   itemslowerupdate)r*   rd   re   rf   Z_param_to_override_paramkvr+   r+   r,   _apply_algo_specific_param   s    z'AlgoEnsemble._apply_algo_specific_paramNdict | NonerC   )
pred_paramr;   c              	   C  s  |dkri nt |}| j}d|kr,|d}d|krF|d}|| }d|krj|d}t|ddgd| _|dd	}d
|krt|d
  }|di }g }	tr|r|dddkrt	t
|ddnt	|D ]\}
}g }| jD ]H}t|tj }|tj }| |||}|j|g|d}||d  qd
|krz| j||d}W n. tk
rv   | jdd |D |d}Y nX ||}t|drd|j kr|jd }ntd d}ntd | j||d}|	| q|	S )aG  
        Use the ensembled model to predict result.

        Args:
            pred_param: prediction parameter dictionary. The key has two groups: the first one will be consumed
                in this function, and the second group will be passed to the `InferClass` to override the
                parameters of the class functions.
                The first group contains:

                    - ``"infer_files"``: file paths to the images to read in a list.
                    - ``"files_slices"``: a value type of `slice`. The files_slices will slice the ``"infer_files"`` and
                      only make prediction on the infer_files[file_slices].
                    - ``"mode"``: ensemble mode. Currently "mean" and "vote" (majority voting) schemes are supported.
                    - ``"image_save_func"``: a dictionary used to instantiate the ``SaveImage`` transform. When specified,
                      the ensemble prediction will save the prediction files, instead of keeping the files in the memory.
                      Example: `{"_target_": "SaveImage", "output_dir": "./"}`
                    - ``"sigmoid"``: use the sigmoid function (e.g. x > 0.5) to convert the prediction probability map
                      to the label class prediction, otherwise argmax(x) is used.
                    - ``"algo_spec_params"``: a dictionary to add pred_params that are specific to a model.
                      The dict has a format of {"<name of algo>": "<pred_params for that algo>"}.

                The parameters in the second group is defined in the ``config`` of each Algo templates. Please check:
                https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg/algorithm_templates

        Returns:
            A list of tensors or file paths, depending on whether ``"image_save_func"`` is set.
        Nr'   Zfiles_slicesr&   r$   rX   	supportedrW   Fimage_save_funcalgo_spec_paramsrA   r   zEnsembling (rank 0)...)desc)predict_filespredict_paramsrY   c                 S  s   g | ]}| d qS rQ   )to)rN   rI   r+   r+   r,   rS      s     z)AlgoEnsemble.__call__.<locals>.<listcomp>metasaved_tozImage save path not returned.z\Prediction returned in list instead of disk, provide image_save_func to avoid out of memory.)r   r'   popr   r&   r   get_parsed_contenthas_tqdmget	enumerater    r(   r   r   r0   ALGOrm   predictappendrb   BaseExceptionrE   rx   keysr
   )r*   ro   re   filesslicesr&   rW   Z	img_saverrs   outputsrI   filer_   r2   Zinfer_algo_nameZinfer_instancerj   predZensemble_predsresr+   r+   r,   __call__   sX    





zAlgoEnsemble.__call__c                 O  s   t d S rK   )NotImplementedError)r*   argskwargsr+   r+   r,   collect_algos   s    zAlgoEnsemble.collect_algos)r5   )F)N)__name__
__module____qualname____doc__r-   r/   r3   r4   rJ   rb   rm   r   r   r   r+   r+   r+   r,   r#   1   s   	
Pr#   c                      sB   e Zd ZdZddd fddZdd Zddd
dddZ  ZS )AlgoEnsembleBestNz
    Ensemble method that select N model out of all using the models' best_metric scores

    Args:
        n_best: number of models to pick for ensemble (N).
       intn_bestc                   s   t    || _d S rK   )superr-   r   )r*   r   	__class__r+   r,   r-      s    
zAlgoEnsembleBestN.__init__c                 C  s   t | jtjg}t| S )z'
        Sort the best_metrics
        )r   r%   r   SCOREnpargsorttolist)r*   scoresr+   r+   r,   
sort_score   s    zAlgoEnsembleBestN.sort_scorer<   r7   )r   r;   c                   s    dkr| j  |  t k rNtdt d  dt d t  fddtD }t|dd	}t| j| _|D ]}|t| jk r| j	| qd
S )zQ
        Rank the algos by finding the top N (n_best) validation scores.
        r   zFound z% available algos (pre-defined n_best=z). All z will be used.c                   s$   g | ]\}}|t   k r|qS r+   )len)rN   irr   ranksr+   r,   rS     s      z3AlgoEnsembleBestN.collect_algos.<locals>.<listcomp>T)reverseN)
r   r   r   r
   r~   sortedr   r%   r(   rz   )r*   r   indicesidxr+   r   r,   r     s    $zAlgoEnsembleBestN.collect_algos)r   )r<   )r   r   r   r   r-   r   r   __classcell__r+   r+   r   r,   r      s   r   c                      s6   e Zd ZdZddd fddZddd	d
Z  ZS )AlgoEnsembleBestByFoldz
    Ensemble method that select the best models that are the tops in each fold.

    Args:
        n_fold: number of cross-validation folds used in training
    r   r   n_foldc                   s   t    || _d S rK   )r   r-   r   )r*   r   r   r+   r,   r-   #  s    
zAlgoEnsembleBestByFold.__init__r7   r;   c                 C  s   g | _ t| jD ]}d}d}| jD ]~}|tj dd }zt|}W n4 tk
rz } ztd| d|W 5 d}~X Y nX ||kr"|tj	 |kr"|}|tj	 }q"| j 
| qdS )zX
        Rank the algos by finding the best model in each cross-validation fold
        g      NrI      zmodel identifier z is not number.)r(   ranger   r%   r   r0   splitr   rH   r   r   )r*   f_idx
best_scoreZ
best_modelr2   r1   Zalgo_iderrr+   r+   r,   r   '  s    
$z$AlgoEnsembleBestByFold.collect_algos)r   )r   r   r   r   r-   r   r   r+   r+   r   r,   r     s   r   c                   @  sT   e Zd ZdZddddddZddd	d
ddddZdddddddZdd ZdS )AlgoEnsembleBuildera  
    Build ensemble workflow from configs and arguments.

    Args:
        history: a collection of trained bundleAlgo algorithms.
        data_src_cfg_name: filename of the data source.

    Examples:

        .. code-block:: python

            builder = AlgoEnsembleBuilder(history, data_src_cfg)
            builder.set_ensemble_method(BundleAlgoEnsembleBestN(3))
            ensemble = builder.get_ensemble()

    NzSequence[dict[str, Any]]z
str | None)historydata_src_cfg_namec           	      C  s   g | _ |  tdd| _|d k	r:tjt|r:| j| |D ]z}|tj	 }|tj
 }| }|j}tj|dd}tj|st|j d tj|st| d | ||| q>d S )NF)globalsscriptszinfer.pyz+ is not a directory. Please check the path.z% is not found. Please check the path.)r.   r   data_src_cfgospathexistsr6   read_configr   r0   r   	get_scoreoutput_pathjoinisdirr
   isfileadd_inferer)	r*   r   r   	algo_dictr!   gen_algobest_metric	algo_pathZ
infer_pathr+   r+   r,   r-   O  s     

zAlgoEnsembleBuilder.__init__r6   r   zfloat | Noner7   )r1   r   r   r;   c                 C  s6   |dkrt dtj|tj|tj|i}| j| dS )z
        Add model inferer to the builder.

        Args:
            identifier: name of the bundleAlgo.
            gen_algo: a trained BundleAlgo model object.
            best_metric: the best metric in validation of the trained model.
        Nz+Feature to re-validate is to be implemented)rH   r   r0   r   r   r.   r   )r*   r1   r   r   r2   r+   r+   r,   r   i  s    
zAlgoEnsembleBuilder.add_infererr#   r   )ensembler   r   r;   c                 O  s:   | | j |j|| || jd | jd  || _dS )zj
        Set the ensemble method.

        Args:
            ensemble: the AlgoEnsemble to build.
        r8   r=   N)r/   r.   r   rJ   r   r   )r*   r   r   r   r+   r+   r,   set_ensemble_methody  s    z'AlgoEnsembleBuilder.set_ensemble_methodc                 C  s   | j S )zGet the ensemble)r   r)   r+   r+   r,   get_ensemble  s    z AlgoEnsembleBuilder.get_ensemble)N)N)r   r   r   r   r-   r   r   r   r+   r+   r+   r,   r   =  s
   r   c                	   @  s   e Zd ZdZd#dddddd	d
dddZd$dd	d
dddZdd Zd	d
dddZd%dd
dddZdd Z	d&dd
dddZ
d
d d!d"ZdS )'EnsembleRunnera  
    The Runner for ensembler. It ensembles predictions and saves them to the disk with a support of using multi-GPU.

    Args:
        data_src_cfg_name: filename of the data source.
        work_dir: working directory to save the intermediate and final results. Default is `./work_dir`.
        num_fold: number of fold. Default is 5.
        ensemble_method_name: method to ensemble predictions from different model. Default is AlgoEnsembleBestByFold.
                              Supported methods: ["AlgoEnsembleBestN", "AlgoEnsembleBestByFold"].
        mgpu: if using multi-gpu. Default is True.
        kwargs: additional image writing, ensembling parameters and prediction parameters for the ensemble inference.
              - for image saving, please check the supported parameters in SaveImage transform.
              - for prediction parameters, please check the supported parameters in the ``AlgoEnsemble`` callables.
              - for ensemble parameters, please check the documentation of the selected AlgoEnsemble callable.

    Example:

        .. code-block:: python

            ensemble_runner = EnsembleRunner(data_src_cfg_name,
                                             work_dir,
                                             ensemble_method_name,
                                             mgpu=device_setting['n_devices']>1,
                                             **kwargs,
                                             **pred_params)
            ensemble_runner.run(device_setting)

    
./work_dirr   r   Tr6   r   boolr   r7   )r   work_dirnum_foldensemble_method_namemgpur   r;   c                 K  s   || _ || _|| _|| _|| _t|| _d| _d| _d	dd t
tj D tj ttjddtjddtjd	d
d| _d S )Nr   r   ,c                 S  s   g | ]}t |qS r+   )r6   )rN   xr+   r+   r,   rS     s     z+EnsembleRunner.__init__.<locals>.<listcomp>	NUM_NODESMN_START_METHODbcprun
CMD_PREFIX )CUDA_VISIBLE_DEVICES	n_devicesr   r   r   )r   r   r   r   r   r   r   rA   
world_sizer   r   r\   cudadevice_countr   r   environr}   device_setting)r*   r   r   r   r   r   r   r+   r+   r,   r-     s    	
zEnsembleRunner.__init__)r   r   r;   c                 K  sf   t |ddgd| _| jdkr6|dd}t|d| _n,| jdkrPt| jd| _ntd| j d	d
S )a  
        Set the bundle ensemble method

        Args:
            ensemble_method_name: the name of the ensemble method. Only two methods are supported "AlgoEnsembleBestN"
                and "AlgoEnsembleBestByFold".
            kwargs: the keyword arguments used to define the ensemble method. Currently only ``n_best`` for
                ``AlgoEnsembleBestN`` is supported.

        r   r   rp   r      r   r   zEnsemble method z is not implemented.N)r   r   rz   r   ensemble_methodr   r   r   )r*   r   r   r   r+   r+   r,   r     s     

z"EnsembleRunner.set_ensemble_methodc           	      K  s  | dd}|dkr6tj| jd}td| d tj|sbtj|dd td| d	 t	
| j}|d
d}d|| dd| dd| dddd| d|| ddd	}tt|\}}|r|| n*t|D ] }||kr||| |i q|S )a  
        Pop the kwargs used to define ImageSave class for the ensemble output.

        Args:
            kwargs: image writing parameters for the ensemble inference. The kwargs format follows SaveImage
                transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage .

        Returns:
            save_image: a dictionary that can be used to instantiate a SaveImage class in ConfigParser.
        
output_dirNZensemble_outputz!The output_dir is not specified. z+ will be used to save ensemble predictions.T)exist_okz
Directory z( is created to save ensemble predictionsr8   r   r   output_postfixr   output_dtypez	$np.uint8resampleFdata_root_dirseparate_folder)	_target_r   r   r   r   	print_logsavepath_in_metadictr   r   )rz   r   r   r   r   rF   rG   r   makedirsr   rD   r   r}   r   r   ri   rC   )	r*   r   r   Z
input_yamlr   
save_imageZare_all_args_save_image
extra_argsr   r+   r+   r,   '_pop_kwargs_to_get_image_save_transform  s4    




z6EnsembleRunner._pop_kwargs_to_get_image_save_transform)r   r;   c                 K  s2   t t|\}}|r | j| nt| ddS )a$  
        Set the ensemble output transform.

        Args:
            kwargs: image writing parameters for the ensemble inference. The kwargs format follows SaveImage
                transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage .

        z are not supported in monai.transforms.SaveImage,Check https://docs.monai.io/en/stable/transforms.html#saveimage for more information.N)r   r   r   ri   rH   )r*   r   are_all_args_presentr   r+   r+   r,   set_image_save_transform
  s    	z'EnsembleRunner.set_image_save_transform)r   r;   c                 C  s    |dkrt d| || _dS )z
        Set the number of cross validation folds for all algos.

        Args:
            num_fold: a positive integer to define the number of folds.
        r   zEnum_fold is expected to be an integer greater than zero. Now it gets N)rH   r   )r*   r   r+   r+   r,   set_num_fold  s    zEnsembleRunner.set_num_foldc                 C  s  | j r4tjddd t | _t | _tt	  | j
| jd | j| jf| j | jf | j}t| jdd}dd |D }|rtd	d
d |D  d dd |D }t|dkrtd| j dt|| j}|| j | | _| jj}t|| jk rBt|dkr td d S | jt|k r<|| j gng }nt|d| jdd| j }|| j_| j| jd< || jd< td | j D ]}t|tj  q|d }td| d | j| jd | j rt   d S )Nncclzenv://)backendinit_method)r   F)only_trainedc                 S  s   g | ]}|t j s|qS r+   r   
IS_TRAINEDrN   hr+   r+   r,   rS   6  s     
 z+EnsembleRunner.ensemble.<locals>.<listcomp>zEnsembling step will skip c                 S  s   g | ]}|t j qS r+   )r   r0   r   r+   r+   r,   rS   9  s     zJ untrained algos.Generally it means these algos did not complete training.c                 S  s   g | ]}|t j r|qS r+   r   r   r+   r+   r,   rS   <  s     
 r   z&Could not find the trained results in z8. Possibly the required training step was not completed.z=No testing files for inference is provided. Ensembler ending.)datashufflenum_partitionseven_divisiblerA   rr   z4Auto3Dseg picked the following networks to ensemble:r   z7Auto3Dseg ensemble prediction outputs will be saved in .)ro   )!r   distinit_process_groupget_world_sizer   get_rankrA   rF   	addFilterr   r   r   r   r   r   r   r   r   warningr   rH   r   r   r   r   Z	ensemblerr'   rG   r   r4   r   r0   destroy_process_group)r*   r   r   Zhistory_untrainedbuilderr'   r2   r   r+   r+   r,   r   (  s^    



"   

zEnsembleRunner.ensembleNrn   )r   r;   c                 C  s>   |dk	r2| j | tt| j d d| j d< |   dS )a  
        Load the run function in the training script of each model. Training parameter is predefined by the
        algo_config.yaml file, which is pre-filled by the fill_template_config function in the same instance.

        Args:
            device_setting: device related settings, should follow the device_setting in auto_runner.set_device_info.
                'CUDA_VISIBLE_DEVICES' should be a string e.g. '0,1,2,3'
        Nr   r   r   )r   ri   r   r6   r   _create_cmd)r*   r   r+   r+   r,   run`  s    
zEnsembleRunner.runr   c              	   C  sx  t | jd dkr:t | jd dkr:td |   d S d| j d| j d| j d| j d		}| j	rt
| j	tr| j	 D ]\}}|d
| d| 7 }q|tj }t| jd |d< t | jd dkr8| jd dkrt| jd  dtd| jd  d td| | jd  d}t|| jd | jd d n<td| jd  d td| }t|d| jd |dd d S )Nr   r   r   zEnsembling using single GPU!zQmonai.apps.auto3dseg EnsembleRunner ensemble                 --data_src_cfg_name z                 --work_dir z                 --num_fold z(                 --ensemble_method_name z                 --mgpu Truez --=r   r   r   zN is not supported yet. Try modify EnsembleRunner._create_cmd for your cluster.zEnsembling on z nodes!z-m r   )
cmd_prefix)nrO   zEnsembling using z GPU!T)nnodesnproc_per_nodeenvcheck)r   r   rF   rG   r   r   r   r   r   r   rB   r   rg   r   r   copyr6   r   r   r   r   r   )r*   Zbase_cmdrk   rl   
ps_environcmdr+   r+   r,   r	  o  sH    $

    zEnsembleRunner._create_cmd)r   r   r   T)r   )r   )N)r   r   r   r   r-   r   r   r   r   r   r
  r	  r+   r+   r+   r,   r     s        08r   )@
__future__r   r   abcr   r   collections.abcr   r   r  r   typingr   r	   warningsr
   numpyr   r\   torch.distributeddistributedr  monai.apps.auto3dseg.bundle_genr   monai.apps.auto3dseg.utilsr   r   monai.apps.utilsr   monai.auto3dsegr   monai.auto3dseg.utilsr   r   r   r   r   monai.bundler   
monai.datar   monai.transformsr   r   r   monai.utilsr   monai.utils.enumsr   monai.utils.miscr   r   monai.utils.moduler   r   r    r|   r   rF   r#   r   r   r   r   r+   r+   r+   r,   <module>   s:   
 ?,"P