o
    ij                     @  s  d dl mZ d dlZd dlmZmZ d dlmZmZ d dl	m
Z
 d dlmZmZ d dlmZ d dlZd dlZd dlmZ d dlmZ d d	lmZmZ d d
lmZ d dlmZ d dlm Z m!Z!m"Z"m#Z#m$Z$ d dl%m&Z& d dl'm(Z( d dl)m*Z*m+Z+m,Z, d dl-m.Z. d dl/m0Z0 d dl1m2Z2m3Z3 d dl4m5Z5m6Z6 e6ddd\Z7Z8ee9dZ:G dd deZ;G dd de;Z<G dd de;Z=G dd dZ>G dd  d Z?dS )!    )annotationsN)ABCabstractmethod)MappingSequence)deepcopy)Anycast)warn)
BundleAlgo)get_name_from_algo_idimport_bundle_algo_history)
get_logger)concat_val_to_np)_prepare_cmd_bcprun_prepare_cmd_torchrun_run_cmd_bcprun_run_cmd_torchrundatafold_read)ConfigParser)partition_dataset)MeanEnsemble	SaveImageVoteEnsemble)
RankFilter)AlgoKeys) check_kwargs_exist_in_class_init
prob2class)look_up_optionoptional_importtqdm)name)module_namec                   @  sh   e Zd ZdZdd Zdd Zdd Zdd	 Zd%d&ddZd'ddZ	d(ddZ
d)d*d!d"Zed#d$ ZdS )+AlgoEnsemblez,
    The base class of Ensemble methods
    c                 C  s   g | _ d| _g | _g | _d S )Nmean)algosmodeinfer_filesalgo_ensembleself r+   g/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/auto3dseg/ensemble_builder.py__init__6   s   
zAlgoEnsemble.__init__c                 C  s   t || _dS )z0
        Register model in the ensemble
        N)r   r%   )r*   infer_algosr+   r+   r,   	set_algos<   s   zAlgoEnsemble.set_algosc                 C  s&   | j D ]}||tj kr|  S qdS )zn
        Get a model by identifier.

        Args:
            identifier: the name of the bundleAlgo
        N)r%   r   ID)r*   
identifieralgor+   r+   r,   get_algoB   s
   
zAlgoEnsemble.get_algoc                 C     | j S )z
        Get the algo ensemble after ranking or a empty list if ranking was not started.

        Returns:
            A list of Algo
        )r(   r)   r+   r+   r,   get_algo_ensembleM   s   zAlgoEnsemble.get_algo_ensembletestingdatarootstrdata_list_or_path
str | listdata_keyreturnNonec                 C  s   g | _ t|tr|| _ dS t|tr?t|}||v r(t||d|d\| _ }dS t| dr2| jdkr=t	
d| d dS dS td)	z
        Set the files to perform model inference.

        Args:
            dataroot: the path of the files
            data_list_or_path: the data source file path
        )datalistbasedirfoldkeyrankr   z#Datalist file has no testing key - z$. No data for inference is specifiedzUnsupported parameter typeN)r'   
isinstancelistr8   r   load_config_filer   hasattrrC   loggerinfo
ValueError)r*   r7   r9   r;   r?   _r+   r+   r,   set_infer_filesV   s   	



zAlgoEnsemble.set_infer_filesFc                   s   t dd |D rdd |D }| jdkr&t |}tttj|dd dS | jd	krG fd
d|D } r;t |S t|d jd d|S dS )a  
        ensemble the results using either "mean" or "vote" method

        Args:
            preds: a list of probability prediction in Tensor-Like format.
            sigmoid: use the sigmoid function to threshold probability one-hot map,
                otherwise argmax is used. Defaults to False

        Returns:
            a tensor which is the ensembled prediction.
        c                 s  s    | ]}|j  V  qd S N)is_cuda.0pr+   r+   r,   	<genexpr>z   s    z-AlgoEnsemble.ensemble_pred.<locals>.<genexpr>c                 S  s   g | ]}|  qS r+   cpurO   r+   r+   r,   
<listcomp>{       z.AlgoEnsemble.ensemble_pred.<locals>.<listcomp>r$   r   Tdimkeepdimsigmoidvotec                   s   g | ]
}t |d d dqS )r   TrW   )r   rO   rZ   r+   r,   rU      s    )num_classesN)	anyr&   r   r   r	   torchTensorr   shape)r*   predsrZ   probclassesr+   r\   r,   ensemble_predm   s   



zAlgoEnsemble.ensemble_predalgo_spec_paramdictparam	algo_namec                 C  s@   t |}t |}| D ]\}}| | kr|| q|S )a  
        Apply the model-specific params to the prediction params based on the name of the Algo.

        Args:
            algo_spec_param: a dict that has structure of {"<name of algo>": "<pred_params for that algo>"}.
            param: the prediction params to override.
            algo_name: name of the Algo

        Returns:
            param after being updated with the model-specific param
        )r   itemslowerupdate)r*   rf   rh   ri   Z_param_to_override_paramkvr+   r+   r,   _apply_algo_specific_param   s   
z'AlgoEnsemble._apply_algo_specific_paramN
pred_paramdict | NonerE   c              	   C  s  |du ri nt |}| j}d|v r|d}d|v r#|d}|| }d|v r5|d}t|ddgd| _|dd	}d
|v rGt|d
  }|di }g }	trc|rc|dddkrct	t
|ddnt	|D ]~\}
}g }| jD ]$}t|tj }|tj }| |||}|j|g|d}||d  qpd
|v rz	| j||d}W n ty   | jdd |D |d}Y nw ||}t|drd|j v r|jd }ntd d}ntd | j||d}|	| qg|	S )aG  
        Use the ensembled model to predict result.

        Args:
            pred_param: prediction parameter dictionary. The key has two groups: the first one will be consumed
                in this function, and the second group will be passed to the `InferClass` to override the
                parameters of the class functions.
                The first group contains:

                    - ``"infer_files"``: file paths to the images to read in a list.
                    - ``"files_slices"``: a value type of `slice`. The files_slices will slice the ``"infer_files"`` and
                      only make prediction on the infer_files[file_slices].
                    - ``"mode"``: ensemble mode. Currently "mean" and "vote" (majority voting) schemes are supported.
                    - ``"image_save_func"``: a dictionary used to instantiate the ``SaveImage`` transform. When specified,
                      the ensemble prediction will save the prediction files, instead of keeping the files in the memory.
                      Example: `{"_target_": "SaveImage", "output_dir": "./"}`
                    - ``"sigmoid"``: use the sigmoid function (e.g. x > 0.5) to convert the prediction probability map
                      to the label class prediction, otherwise argmax(x) is used.
                    - ``"algo_spec_params"``: a dictionary to add pred_params that are specific to a model.
                      The dict has a format of {"<name of algo>": "<pred_params for that algo>"}.

                The parameters in the second group is defined in the ``config`` of each Algo templates. Please check:
                https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg/algorithm_templates

        Returns:
            A list of tensors or file paths, depending on whether ``"image_save_func"`` is set.
        Nr'   Zfiles_slicesr&   r$   r[   	supportedrZ   Fimage_save_funcalgo_spec_paramsrC   r   zEnsembling (rank 0)...)desc)predict_filespredict_paramsr\   c                 S  s   g | ]}| d qS rS   )to)rP   rK   r+   r+   r,   rU          z)AlgoEnsemble.__call__.<locals>.<listcomp>metasaved_tozImage save path not returned.z\Prediction returned in list instead of disk, provide image_save_func to avoid out of memory.)r   r'   popr   r&   r   get_parsed_contenthas_tqdmget	enumerater    r(   r   r   r0   ALGOrp   predictappendre   BaseExceptionrG   r|   keysr
   )r*   rq   rh   filesslicesr&   rZ   Z	img_saverrv   outputsrK   filerb   r2   Zinfer_algo_nameZinfer_instancerm   predZensemble_predsresr+   r+   r,   __call__   sZ   





zAlgoEnsemble.__call__c                 O  s   t rM   )NotImplementedError)r*   argskwargsr+   r+   r,   collect_algos   s   zAlgoEnsemble.collect_algos)r6   )r7   r8   r9   r:   r;   r8   r<   r=   )F)rf   rg   rh   rg   ri   r8   r<   rg   rM   )rq   rr   r<   rE   )__name__
__module____qualname____doc__r-   r/   r3   r5   rL   re   rp   r   r   r   r+   r+   r+   r,   r#   1   s    	

Pr#   c                      s8   e Zd ZdZdd fddZdd ZddddZ  ZS )AlgoEnsembleBestNz
    Ensemble method that select N model out of all using the models' best_metric scores

    Args:
        n_best: number of models to pick for ensemble (N).
       n_bestintc                      t    || _d S rM   )superr-   r   )r*   r   	__class__r+   r,   r-         

zAlgoEnsembleBestN.__init__c                 C  s   t | jtjg}t| S )z'
        Sort the best_metrics
        )r   r%   r   SCOREnpargsorttolist)r*   scoresr+   r+   r,   
sort_score   s   zAlgoEnsembleBestN.sort_scorer>   r<   r=   c                   s    dkr| j  |  t k r'tdt d  dt d t  fddtD }t|dd	}t| j| _|D ]}|t| jk rP| j	| qAd
S )zQ
        Rank the algos by finding the top N (n_best) validation scores.
        r   zFound z% available algos (pre-defined n_best=z). All z will be used.c                   s$   g | ]\}}|t   k r|qS r+   )len)rP   irr   ranksr+   r,   rU     s   $ z3AlgoEnsembleBestN.collect_algos.<locals>.<listcomp>T)reverseN)
r   r   r   r
   r   sortedr   r%   r(   r~   )r*   r   indicesidxr+   r   r,   r     s   $zAlgoEnsembleBestN.collect_algosr   )r   r   )r>   )r   r   r<   r=   )r   r   r   r   r-   r   r   __classcell__r+   r+   r   r,   r      s
    r   c                      s.   e Zd ZdZdd fddZdd	d
Z  ZS )AlgoEnsembleBestByFoldz
    Ensemble method that select the best models that are the tops in each fold.

    Args:
        n_fold: number of cross-validation folds used in training
    r   n_foldr   c                   r   rM   )r   r-   r   )r*   r   r   r+   r,   r-   #  r   zAlgoEnsembleBestByFold.__init__r<   r=   c                 C  s   g | _ t| jD ]J}d}d}| jD ]:}|tj dd }zt|}W n ty8 } z	td| d|d}~ww ||krK|tj	 |krK|}|tj	 }q| j 
| qdS )zX
        Rank the algos by finding the best model in each cross-validation fold
        g      NrK      zmodel identifier z is not number.)r(   ranger   r%   r   r0   splitr   rJ   r   r   )r*   f_idx
best_scoreZ
best_modelr2   r1   Zalgo_iderrr+   r+   r,   r   '  s$   

z$AlgoEnsembleBestByFold.collect_algosr   )r   r   r<   r=   )r   r   r   r   r-   r   r   r+   r+   r   r,   r     s    r   c                   @  s:   e Zd ZdZddddZddddZdddZdd ZdS ) AlgoEnsembleBuildera  
    Build ensemble workflow from configs and arguments.

    Args:
        history: a collection of trained bundleAlgo algorithms.
        data_src_cfg_name: filename of the data source.

    Examples:

        .. code-block:: python

            builder = AlgoEnsembleBuilder(history, data_src_cfg)
            builder.set_ensemble_method(BundleAlgoEnsembleBestN(3))
            ensemble = builder.get_ensemble()

    NhistorySequence[dict[str, Any]]data_src_cfg_name
str | Nonec           	      C  s   g | _ |  tdd| _|d urtjt|r| j| |D ]=}|tj	 }|tj
 }| }|j}tj|dd}tj|sHt|j d tj|sUt| d | ||| qd S )NF)globalsscriptszinfer.pyz+ is not a directory. Please check the path.z% is not found. Please check the path.)r.   r   data_src_cfgospathexistsr8   read_configr   r0   r   	get_scoreoutput_pathjoinisdirr
   isfileadd_inferer)	r*   r   r   	algo_dictr!   gen_algobest_metric	algo_pathZ
infer_pathr+   r+   r,   r-   O  s"   

zAlgoEnsembleBuilder.__init__r1   r8   r   r   r   float | Noner<   r=   c                 C  s6   |du rt dtj|tj|tj|i}| j| dS )z
        Add model inferer to the builder.

        Args:
            identifier: name of the bundleAlgo.
            gen_algo: a trained BundleAlgo model object.
            best_metric: the best metric in validation of the trained model.
        Nz+Feature to re-validate is to be implemented)rJ   r   r0   r   r   r.   r   )r*   r1   r   r   r2   r+   r+   r,   r   i  s   
zAlgoEnsembleBuilder.add_infererensembler#   r   r   r   c                 O  s>   | | j |j|i | || jd | jd  || _dS )zj
        Set the ensemble method.

        Args:
            ensemble: the AlgoEnsemble to build.
        r7   r?   N)r/   r.   r   rL   r   r   )r*   r   r   r   r+   r+   r,   set_ensemble_methody  s   
z'AlgoEnsembleBuilder.set_ensemble_methodc                 C  r4   )zGet the ensemble)r   r)   r+   r+   r,   get_ensemble  s   z AlgoEnsembleBuilder.get_ensemblerM   )r   r   r   r   )r1   r8   r   r   r   r   r<   r=   )r   r#   r   r   r   r   r<   r=   )r   r   r   r   r-   r   r   r   r+   r+   r+   r,   r   =  s    
r   c                   @  sl   e Zd ZdZ				d%d&ddZd'd(ddZdd Zd)ddZd*d+ddZdd Z	d,d-d!d"Z
d.d#d$ZdS )/EnsembleRunnera  
    The Runner for ensembler. It ensembles predictions and saves them to the disk with a support of using multi-GPU.

    Args:
        data_src_cfg_name: filename of the data source.
        work_dir: working directory to save the intermediate and final results. Default is `./work_dir`.
        num_fold: number of fold. Default is 5.
        ensemble_method_name: method to ensemble predictions from different model. Default is AlgoEnsembleBestByFold.
                              Supported methods: ["AlgoEnsembleBestN", "AlgoEnsembleBestByFold"].
        mgpu: if using multi-gpu. Default is True.
        kwargs: additional image writing, ensembling parameters and prediction parameters for the ensemble inference.
              - for image saving, please check the supported parameters in SaveImage transform.
              - for prediction parameters, please check the supported parameters in the ``AlgoEnsemble`` callables.
              - for ensemble parameters, please check the documentation of the selected AlgoEnsemble callable.

    Example:

        .. code-block:: python

            ensemble_runner = EnsembleRunner(data_src_cfg_name,
                                             work_dir,
                                             ensemble_method_name,
                                             mgpu=device_setting['n_devices']>1,
                                             **kwargs,
                                             **pred_params)
            ensemble_runner.run(device_setting)

    
./work_dirr   r   Tr   r8   work_dirnum_foldr   ensemble_method_namemgpuboolr   r   r<   r=   c                 K  s   || _ || _|| _|| _|| _t|| _d| _d| _d	dd t
tj D tj ttjddtjddtjd	d
d| _d S )Nr   r   ,c                 S  s   g | ]}t |qS r+   )r8   )rP   xr+   r+   r,   rU     rV   z+EnsembleRunner.__init__.<locals>.<listcomp>	NUM_NODESMN_START_METHODbcprun
CMD_PREFIX )CUDA_VISIBLE_DEVICES	n_devicesr   r   r   )r   r   r   r   r   r   r   rC   
world_sizer   r   r_   cudadevice_countr   r   environr   device_setting)r*   r   r   r   r   r   r   r+   r+   r,   r-     s   	
zEnsembleRunner.__init__c                 K  sf   t |ddgd| _| jdkr|dd}t|d| _d
S | jdkr*t| jd| _d
S td| j d	)a  
        Set the bundle ensemble method

        Args:
            ensemble_method_name: the name of the ensemble method. Only two methods are supported "AlgoEnsembleBestN"
                and "AlgoEnsembleBestByFold".
            kwargs: the keyword arguments used to define the ensemble method. Currently only ``n_best`` for
                ``AlgoEnsembleBestN`` is supported.

        r   r   rs   r      )r   )r   zEnsemble method z is not implemented.N)r   r   r~   r   ensemble_methodr   r   r   )r*   r   r   r   r+   r+   r,   r     s   

z"EnsembleRunner.set_ensemble_methodc           	      K  s  | dd}|du rtj| jd}td| d tj|s1tj|dd td| d	 t	
| j}|d
d}d|| dd| dd| dddd| d|| ddd	}tt|\}}|rm|| |S t|D ]}||vr||| |i qq|S )a  
        Pop the kwargs used to define ImageSave class for the ensemble output.

        Args:
            kwargs: image writing parameters for the ensemble inference. The kwargs format follows SaveImage
                transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage .

        Returns:
            save_image: a dictionary that can be used to instantiate a SaveImage class in ConfigParser.
        
output_dirNZensemble_outputz!The output_dir is not specified. z+ will be used to save ensemble predictions.T)exist_okz
Directory z( is created to save ensemble predictionsr7   r   r   output_postfixr   output_dtypez	$np.uint8resampleFdata_root_dirseparate_folder)	_target_r   r   r   r   	print_logsavepath_in_metadictr   r   )r~   r   r   r   r   rH   rI   r   makedirsr   rF   r   r   r   r   rl   rE   )	r*   r   r   Z
input_yamlr   
save_imageZare_all_args_save_image
extra_argsr   r+   r+   r,   '_pop_kwargs_to_get_image_save_transform  s8   





z6EnsembleRunner._pop_kwargs_to_get_image_save_transformc                 K  s0   t t|\}}|r| j| dS t| d)a$  
        Set the ensemble output transform.

        Args:
            kwargs: image writing parameters for the ensemble inference. The kwargs format follows SaveImage
                transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage .

        z are not supported in monai.transforms.SaveImage,Check https://docs.monai.io/en/stable/transforms.html#saveimage for more information.N)r   r   r   rl   rJ   )r*   r   are_all_args_presentr   r+   r+   r,   set_image_save_transform
  s   	z'EnsembleRunner.set_image_save_transformc                 C  s    |dkrt d| || _dS )z
        Set the number of cross validation folds for all algos.

        Args:
            num_fold: a positive integer to define the number of folds.
        r   zEnum_fold is expected to be an integer greater than zero. Now it gets N)rJ   r   )r*   r   r+   r+   r,   set_num_fold  s   
zEnsembleRunner.set_num_foldc                 C  s  | j rtjddd t | _t | _tt	  | j
| jd | j| jfi | j | jdi | j}t| jdd}dd |D }|rZtd	d
d |D  d dd |D }t|dkritd| j dt|| j}|| j | | _| jj}t|| jk rt|dkrtd d S | jt|k r|| j gng }nt|d| jdd| j }|| j_| j| jd< || jd< td | j D ]
}t|tj  q|d }td| d | j| jd | j rt   d S d S )Nncclzenv://)backendinit_method)r   F)only_trainedc                 S  s   g | ]	}|t j s|qS r+   r   
IS_TRAINEDrP   hr+   r+   r,   rU   6      z+EnsembleRunner.ensemble.<locals>.<listcomp>zEnsembling step will skip c                 S  s   g | ]}|t j qS r+   )r   r0   r  r+   r+   r,   rU   9  r{   zJ untrained algos.Generally it means these algos did not complete training.c                 S  s   g | ]	}|t j r|qS r+   r   r  r+   r+   r,   rU   <  r  r   z&Could not find the trained results in z8. Possibly the required training step was not completed.z=No testing files for inference is provided. Ensembler ending.)datashufflenum_partitionseven_divisiblerC   ru   z4Auto3Dseg picked the following networks to ensemble:r   z7Auto3Dseg ensemble prediction outputs will be saved in .)rq   r+   )!r   distinit_process_groupget_world_sizer   get_rankrC   rH   	addFilterr   r   r   r   r   r   r   r   r   warningr   rJ   r   r   r   r   Z	ensemblerr'   rI   r   r5   r   r0   destroy_process_group)r*   r   r   Zhistory_untrainedbuilderr'   r2   r   r+   r+   r,   r   (  sZ   



 


zEnsembleRunner.ensembleNr   rr   c                 C  s>   |dur| j | tt| j d d| j d< |   dS )a  
        Load the run function in the training script of each model. Training parameter is predefined by the
        algo_config.yaml file, which is pre-filled by the fill_template_config function in the same instance.

        Args:
            device_setting: device related settings, should follow the device_setting in auto_runner.set_device_info.
                'CUDA_VISIBLE_DEVICES' should be a string e.g. '0,1,2,3'
        Nr   r   r   )r   rl   r   r8   r   _create_cmd)r*   r   r+   r+   r,   run`  s   
zEnsembleRunner.runc              	   C  sx  t | jd dkrt | jd dkrtd |   d S d| j d| j d| j d| j d		}| j	rMt
| j	trM| j	 D ]\}}|d
| d| 7 }q>tj }t| jd |d< t | jd dkr| jd dkrut| jd  dtd| jd  d td| | jd  d}t|| jd | jd d d S td| jd  d td| }t|d| jd |dd d S )Nr   r   r   zEnsembling using single GPU!zQmonai.apps.auto3dseg EnsembleRunner ensemble                 --data_src_cfg_name z                 --work_dir z                 --num_fold z(                 --ensemble_method_name z                 --mgpu Truez --=r   r   r   zN is not supported yet. Try modify EnsembleRunner._create_cmd for your cluster.zEnsembling on z nodes!z-m r   )
cmd_prefix)nrQ   zEnsembling using z GPU!T)nnodesnproc_per_nodeenvcheck)r   r   rH   rI   r   r   r   r   r   r   rD   r   rj   r   r   copyr8   r   r   r   r   r   )r*   Zbase_cmdrn   ro   
ps_environcmdr+   r+   r,   r  o  sB   $

zEnsembleRunner._create_cmd)r   r   r   T)r   r8   r   r8   r   r   r   r8   r   r   r   r   r<   r=   )r   )r   r8   r   r   r<   r=   )r   r   r<   r=   r   )r   r   r<   r=   rM   )r   rr   r<   r=   r   )r   r   r   r   r-   r   r   r   r   r   r  r  r+   r+   r+   r,   r     s     
08r   )@
__future__r   r   abcr   r   collections.abcr   r   r  r   typingr   r	   warningsr
   numpyr   r_   torch.distributeddistributedr
  monai.apps.auto3dseg.bundle_genr   monai.apps.auto3dseg.utilsr   r   monai.apps.utilsr   monai.auto3dsegr   monai.auto3dseg.utilsr   r   r   r   r   monai.bundler   
monai.datar   monai.transformsr   r   r   monai.utilsr   monai.utils.enumsr   monai.utils.miscr   r   monai.utils.moduler   r   r    r   r   rH   r#   r   r   r   r   r+   r+   r+   r,   <module>   s<   
 ?,"P