o
    iBq                     @  s  d dl mZ d dlZd dlZd dlZd dlZd dlZd dlZd dlZd dl	Z	d dl
mZ d dlmZ d dlmZ d dlmZ d dlmZ d dlZd dlmZ d d	lmZ d d
lmZmZ d dlmZmZmZm Z m!Z!m"Z" d dl#m$Z$ d dl%m&Z& d dl'm(Z(m)Z)m*Z* d dl+m,Z, d dl-m.Z. ee/dZ0e.1 Z2ddgZ3G dd deZ4de2 dZ5e6dde6dde6dde6dddZ7d'd"d#Z8d$d% Z9G d&d deZ:dS )(    )annotationsN)deepcopy)Path)TemporaryDirectory)Any)urlparse)download_and_extract)
get_logger)AlgoAlgoGen)_prepare_cmd_bcprun_prepare_cmd_default_prepare_cmd_torchrun_run_cmd_bcprun_run_cmd_torchrunalgo_to_pickle)ConfigParser)PathLike)ensure_tuplelook_up_optionrun_cmd)AlgoKeys)MONAIEnvVars)module_name
BundleAlgo	BundleGenc                   @  s   e Zd ZdZdCddZdDdEddZdFddZdGddZdHddZdIddZ	dJd#d$Z
dKd'd(ZdLdMd-d.ZdNdOd2d3Z	)dPdQd5d6Zd7d8 Zd9d: ZdLdRd?d@ZdAdB Zd)S )Sr   a5  
    An algorithm represented by a set of bundle configurations and scripts.

    ``BundleAlgo.cfg`` is a ``monai.bundle.ConfigParser`` instance.

    .. code-block:: python

        from monai.apps.auto3dseg import BundleAlgo

        data_stats_yaml = "../datastats.yaml"
        algo = BundleAlgo(template_path="../algorithm_templates")
        algo.set_data_stats(data_stats_yaml)
        # algo.set_data_src("../data_src.json")
        algo.export_to_disk(".", algo_name="segresnet2d_1")

    This class creates MONAI bundles from a directory of 'bundle template'. Different from the regular MONAI bundle
    format, the bundle template may contain placeholders that must be filled using ``fill_template_config`` during
    ``export_to_disk``. Then created bundle keeps the same file structure as the template.

    template_pathr   c                 C  s   || _ d| _d| _d| _d| _d| _d| _d| _i | _d	dd t
tj D ttj ttjddtjdd	tjd
dd| _dS )a0  
        Create an Algo instance based on the predefined Algo template.

        Args:
            template_path: path to a folder that contains the algorithm templates.
                Please check https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg/algorithm_templates

         N,c                 S  s   g | ]}t |qS  )str).0xr   r   a/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/auto3dseg/bundle_gen.py
<listcomp>a   s    z'BundleAlgo.__init__.<locals>.<listcomp>	NUM_NODES   MN_START_METHODbcprun
CMD_PREFIX)CUDA_VISIBLE_DEVICES	n_devicesr%   r'   r)   )r   data_stats_filesdata_list_filemlflow_tracking_urimlflow_experiment_nameoutput_pathnamebest_metricfill_recordsjoinrangetorchcudadevice_countintosenvirongetdevice_setting)selfr   r   r   r#   __init__K   s   
zBundleAlgo.__init__Fr   skip_bundlegenbool	skip_infor    returntuple[bool, str]c                 C  s   ||fS )a  
        Analyse the data analysis report and check if the algorithm needs to be skipped.
        This function is overriden within algo.
        Args:
            skip_bundlegen: skip generating bundles for this algo if true.
            skip_info: info to print when skipped.
        r   )r>   r@   rB   r   r   r#   pre_check_skip_algoh   s   zBundleAlgo.pre_check_skip_algor,   Nonec                 C  
   || _ dS )z
        Set the data analysis report (generated by DataAnalyzer).

        Args:
            data_stats_files: path to the datastats yaml file
        N)r,   )r>   r,   r   r   r#   set_data_statsr      
zBundleAlgo.set_data_statsdata_src_cfgc                 C  rG   )aN  
        Set the data source configuration file

        Args:
            data_src_cfg: path to a configuration file (yaml) that contains datalist, dataroot, and other params.
                The config will be in a form of {"modality": "ct", "datalist": "path_to_json_datalist", "dataroot":
                "path_dir_data"}
        N)r-   )r>   rJ   r   r   r#   set_data_source{      
	zBundleAlgo.set_data_sourcer.   
str | Nonec                 C  rG   aI  
        Set the tracking URI for MLflow server

        Args:
            mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of
                the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if
                the value is None.
        Nr.   r>   r.   r   r   r#   set_mlflow_tracking_uri   rL   z"BundleAlgo.set_mlflow_tracking_urir/   c                 C  rG   z
        Set the experiment name for MLflow server

        Args:
            mlflow_experiment_name: a string to specify the experiment name for MLflow server.
        Nr/   r>   r/   r   r   r#   set_mlflow_experiment_name   rI   z%BundleAlgo.set_mlflow_experiment_namedata_stats_filename	algo_pathkwargsr   dictc                 K  s   i S )a  
        The configuration files defined when constructing this Algo instance might not have a complete training
        and validation pipelines. Some configuration components and hyperparameters of the pipelines depend on the
        training data and other factors. This API is provided to allow the creation of fully functioning config files.
        Return the records of filling template config: {"<config name>": {"<placeholder key>": value, ...}, ...}.

        Args:
            data_stats_filename: filename of the data stats report (generated by DataAnalyzer)

        Notes:
            Template filling is optional. The user can construct a set of pre-filled configs without replacing values
            by using the data analysis results. It is also intended to be re-implemented in subclasses of BundleAlgo
            if the user wants their own way of auto-configured template filling.
        r   )r>   rV   rW   rX   r   r   r#   fill_template_config   s   zBundleAlgo.fill_template_configr0   	algo_namec                 K  s   | ddr4tj||| _tj| jdd tj| jr#t| j t	tjt
| j| j| j nt
| j| _| ddrM| j| j| jfi || _td| j  dS )a  
        Fill the configuration templates, write the bundle (configs + scripts) to folder `output_path/algo_name`.

        Args:
            output_path: Path to export the 'scripts' and 'configs' directories.
            algo_name: the identifier of the algorithm (usually contains the name and extra info like fold ID).
            kwargs: other parameters, including: "copy_dirs=True/False" means whether to copy the template as output
                instead of inplace operation, "fill_template=True/False" means whether to fill the placeholders
                in the template. other parameters are for `fill_template_config` function.

        Z	copy_dirsT)exist_okZfill_templatez
Generated:N)popr:   pathr4   r0   makedirsisdirshutilrmtreecopytreer    r   r1   rZ   r,   r3   loggerinfo)r>   r0   r[   rX   r   r   r#   export_to_disk   s   "zBundleAlgo.export_to_diskNtrain_paramsNone | dicttuple[str, str]c              
   C  sb  |du ri }t |}tj| jdd}tj| jd}g }tj|rEtt|D ]}|ds6|drD|	t
tj||  q*t| jd dkrzt| jd	 d
g W n typ } zt| jd	  d|d}~ww t| df| jd  |d|dfS t| jd dkrt| dfd|i|dfS t| df| jd  |d|dfS )z:
        Create the command to execute training.

        Nscriptsztrain.pyconfigsyamljsonr%   r&   r'   r(   zI is not supported yet.Try modify BundleAlgo._create_cmd for your cluster.z runr)   )
cmd_prefixconfig_filer   r+   ro   )r   r:   r^   r4   r0   r`   sortedlistdirendswithappendr   as_posixr9   r=   r   
ValueErrorNotImplementedErrorr   r   r   )r>   rg   paramsZtrain_py
config_dirZconfig_filesfileerrr   r   r#   _create_cmd   sX   
	
zBundleAlgo._create_cmdcmddevices_infosubprocess.CompletedProcessc              
   C  s   |rt d| d tj }t| jd |d< tdd|}t	| jd dkrYzt
| jd d	g W n tyK } zt| jd  d
|d}~ww t|| jd | jd dS t	| jd dkrnt|d| jd |ddS t| d|ddS )zP
        Execute the training command with target devices information.

        zinput devices_info z is deprecated and ignored.r*   z^\s*\w+=.*?\s+r   r%   r&   r'   r(   zF is not supported yet.Try modify BundleAlgo._run_cmd for your cluster.Nr+   )npT)nnodesZnproc_per_nodeenvcheck)run_cmd_verboser   r   )warningswarnr:   r;   copyr    r=   resubr9   r   ru   rv   r   r   r   split)r>   r|   r}   Z
ps_environrz   r   r   r#   _run_cmd   s,   
zBundleAlgo._run_cmdr=   c                 C  sn   |dur| j | tt| j d d| j d< |dur+d|v r+td |d | |\}}| 	|S )a  
        Load the run function in the training script of each model. Training parameter is predefined by the
        algo_config.yaml file, which is pre-filled by the fill_template_config function in the same instance.

        Args:
            train_params:  training parameters
            device_setting: device related settings, should follow the device_setting in auto_runner.set_device_info.
                'CUDA_VISIBLE_DEVICES' should be a string e.g. '0,1,2,3'
        Nr*   r   r+   z5CUDA_VISIBLE_DEVICES is deprecated from train_params!)
r=   updatelenr    r   r   r   r]   r{   r   )r>   rg   r=   r|   Z_unused_returnr   r   r#   train  s   


zBundleAlgo.trainc                 O  sR   t j| jdd}t }|| |jd| jd}tt j|d}|d d S )zU
        Returns validation scores of the model trained by the current Algo.
        rk   zhyper_parameters.yaml	ckpt_path)defaultzprogress.yamlZbest_avg_dice_score)r:   r^   r4   r0   r   Zread_configget_parsed_contentload_config_file)r>   argsrX   Zconfig_yamlparserr   Z	dict_filer   r   r#   	get_score.  s   
zBundleAlgo.get_scorec                   s   t j| jdd}t j|st| dt j| jd  fddt  D }tj	d|}tj
|}|tjd< |j| |j|g|R i |S )a  
        Load the InferClass from the infer.py. The InferClass should be defined in the template under the path of
        `"scripts/infer.py"`. It is required to define the "InferClass" (name is fixed) with two functions at least
        (``__init__`` and ``infer``). The init class has an override kwargs that can be used to override parameters in
        the run-time optionally.

        Examples:

        .. code-block:: python

            class InferClass
                def __init__(self, config_file: Optional[Union[str, Sequence[str]]] = None, **override):
                    # read configs from config_file (sequence)
                    # set up transforms
                    # set up model
                    # set up other hyper parameters
                    return

                @torch.no_grad()
                def infer(self, image_file):
                    # infer the model and save the results to output
                    return output

        rj   zinfer.pyz% is not found, please check the path.rk   c                   s   g | ]	}t j |qS r   )r:   r^   r4   r!   frx   r   r#   r$   Y  s    z*BundleAlgo.get_inferer.<locals>.<listcomp>
InferClass)r:   r^   r4   r0   isfileru   rq   	importlibutilspec_from_file_locationmodule_from_specsysmodulesloaderexec_moduler   )r>   r   rX   Zinfer_pyZconfigs_pathspecZinfer_classr   r   r#   get_inferer;  s   
zBundleAlgo.get_infererpredict_fileslistpredict_paramsdict | Nonec                   s:   |du ri nt |}| jdi |  fddt|D S )aL  
        Use the trained model to predict the outputs with a given input image.

        Args:
            predict_files: a list of paths to files to run inference on ["path_to_image_1", "path_to_image_2"]
            predict_params: a dict to override the parameters in the bundle config (including the files to predict).

        Nc                   s   g | ]}  |qS r   )inferr   infererr   r#   r$   l  s    z&BundleAlgo.predict.<locals>.<listcomp>r   )r   r   r   )r>   r   r   rw   r   r   r#   predicta  s   	zBundleAlgo.predictc                 C     | j S )zCReturns the algo output paths to find the algo scripts and configs.)r0   r>   r   r   r#   get_output_pathn     zBundleAlgo.get_output_path)r   r   )Fr   )r@   rA   rB   r    rC   rD   )r,   r    rC   rF   )rJ   r    rC   rF   )r.   rM   rC   rF   )r/   rM   rC   rF   )rV   r    rW   r    rX   r   rC   rY   )r0   r    r[   r    rX   r   rC   rF   )N)rg   rh   rC   ri   )r   )r|   r    r}   r    rC   r~   )NN)rg   rh   r=   rh   rC   r~   )r   r   r   r   rC   r   )__name__
__module____qualname____doc__r?   rE   rH   rK   rQ   rU   rZ   rf   r{   r   r   r   r   r   r   r   r   r   r#   r   5   s$    



	


	
4 &zYhttps://github.com/Project-MONAI/research-contributions/releases/download/algo_templates/z.tar.gzz(segresnet2d.scripts.algo.Segresnet2dAlgo)_target_zdints.scripts.algo.DintsAlgoz$swinunetr.scripts.algo.SwinunetrAlgoz$segresnet.scripts.algo.SegresnetAlgo)Zsegresnet2ddintsZ	swinunetr	segresneturlr    at_pathrC   dict[str, dict[str, str]]c           
      C  s   t j|}t }t j|jd}d}t|D ]J}zt| |t j|d W n: t	y` } z.d|  d|d  d| d}||d k rMt
| t| n	|  t||W Y d	}~qd	}~ww  |  tt}|D ]}	|||	 d
< ql|S )z
    Downloads the algorithm templates release archive, and extracts it into a parent directory of the at_path folder.
    Returns a dictionary of the algorithm templates.
    zalgo_templates.tar.gz   )r   filepath
output_dirzDownload and extract of z failed, attempt r&   /.Nr   )r:   r^   abspathr   r4   r1   r5   r   dirname	Exceptionr   r   timesleepcleanupru   r   default_algos)
r   r   Zzip_download_dirZalgo_compressed_fileZdownload_attemptsiemsg	algos_allr1   r   r   r#   _download_algos_url  s.   

	r   c              	   C  s   t j| } t j|}| |kr!t j|rt| t| | i }t |D ],}t jt j| |ddrTt	| d|
  d|d||< td| d||   q(|s^td|  |S )	zl
    Copies the algorithm templates folder to at_path.
    Returns a dictionary of algorithm templates.
    rj   zalgo.pyz.scripts.algo.r
   )r   r   zCopying template: z -- zUnable to find any algos in )r:   r^   r   existsra   rb   rc   rq   r4   rY   
capitalizerd   re   ru   )folderr   r   r1   r   r   r#   _copy_algos_folder  s   
 r   c                   @  s   e Zd ZdZ							d2d3ddZd4ddZdd Zdd Zdd Zdd Z	dd Z
dd Zd d! Zd5d#d$Z		%	&		'd6d7d0d1ZdS )8r   a  
    This class generates a set of bundles according to the cross-validation folds, each of them can run independently.

    Args:
        algo_path: the directory path to save the algorithm templates. Default is the current working dir.
        algos: If dictionary, it outlines the algorithm to use. If a list or a string, defines a subset of names of
            the algorithms to use, e.g. ('segresnet', 'dints') out of the full set of algorithm templates provided
            by templates_path_or_url. Defaults to None - to use all available algorithms.
        templates_path_or_url: the folder with the algorithm templates or a url. If None provided, the default template
            zip url will be downloaded and extracted into the algo_path. The current default options are released at:
            https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg.
        data_stats_filename: the path to the data stats file (generated by DataAnalyzer).
        data_src_cfg_name: the path to the data source config YAML file. The config will be in a form of
                           {"modality": "ct", "datalist": "path_to_json_datalist", "dataroot": "path_dir_data"}.
        mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of
            the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if
            the value is None.
        mlfow_experiment_name: a string to specify the experiment name for MLflow server.
    .. code-block:: bash

        python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename="../algorithms/datastats.yaml"
    r   NrW   r    algosdict | list | str | Nonetemplates_path_or_urlrM   rV   data_src_cfg_namer.   r/   c                   s   d u st  tttfrq|d u rt}tjtj|d}tj	|r2t
d|  t||d}	n t|jdv rHt
d|  t||d}	n
t| j d|  d uro fdd	|	 D  t d
krntd|	 n|	 g | _t  trt  D ]?\}
}|dd}t|d
kr|tjvrtj| zt| }|
|_| j| W q ty } zd}t||d }~ww ntd|| _|| _|| _ || _!g | _"d S )NZalgorithm_templateszBundleGen from directory )r   r   )httphttpszBundleGen from )r   r   z) received invalid templates_path_or_url: c                   s"   i | ]\}}|t  v r||qS r   )r   )r!   kvr   r   r#   
<dictcomp>  s   " z&BundleGen.__init__.<locals>.<dictcomp>r   z!Unable to find provided algos in r   r   u  Please make sure the folder structure of an Algo Template follows
                        [algo_name]
                        ├── configs
                        │   ├── hyper_parameters.yaml  # automatically generated yaml from a set of ``template_configs``
                        └── scripts
                            ├── test.py
                            ├── __init__.py
                            └── validate.py
                    z$Unexpected error algos is not a dict)#
isinstancer   tupler    default_algo_zipr:   r^   r4   r   r`   rd   re   r   r   schemer   ru   	__class__itemsr   r   rY   rp   r<   r   rs   r   r   r1   RuntimeErrorrV   r   r.   r/   history)r>   rW   r   r   rV   r   r.   r/   r   r   r[   Zalgo_paramsr   Zonealgor   r   r   r   r#   r?     sN   


	
zBundleGen.__init__rC   rF   c                 C  rG   )zs
        Set the data stats filename

        Args:
            data_stats_filename: filename of datastats
        NrV   )r>   rV   r   r   r#   rH     rI   zBundleGen.set_data_statsc                 C  r   )z"Get the filename of the data statsr   r   r   r   r#   get_data_stats  r   zBundleGen.get_data_statsc                 C  rG   )zy
        Set the data source filename

        Args:
            data_src_cfg_name: filename of data_source file
        Nr   )r>   r   r   r   r#   set_data_src   rI   zBundleGen.set_data_srcc                 C  r   )zGet the data source filenamer   r   r   r   r#   get_data_src)  r   zBundleGen.get_data_srcc                 C  rG   rN   rO   rP   r   r   r#   rQ   -  rL   z!BundleGen.set_mlflow_tracking_uric                 C  rG   rR   rS   rT   r   r   r#   rU   8  rI   z$BundleGen.set_mlflow_experiment_namec                 C  r   )z&Get the tracking URI for MLflow serverrO   r   r   r   r#   get_mlflow_tracking_uriA  r   z!BundleGen.get_mlflow_tracking_uric                 C  r   )z)Get the experiment name for MLflow serverrS   r   r   r   r#   get_mlflow_experiment_nameE  r   z$BundleGen.get_mlflow_experiment_namer   c                 C  r   )zEGet the history of the bundleAlgo object with their names/identifiers)r   r   r   r   r#   get_historyI  r   zBundleGen.get_history   FToutput_foldernum_foldr9   gpu_customizationrA   gpu_customization_specsdict[str, Any] | None
allow_skipc              	   C  s  t t|}| jD ]v}t|D ]o}|  }	|  }
|  }|  }t|}|	|	 |
|
 || || |j d| }|rV| \}}|rVt| d|  q|rc|j|||d|d n|j|||d t||jd | jtj|tj|i qq	dS )a0  
        Generate the bundle scripts/configs for each bundleAlgo

        Args:
            output_folder: the output folder to save each algorithm.
            num_fold: the number of cross validation fold.
            gpu_customization: the switch to determine automatically customize/optimize bundle script/config
                parameters for each bundleAlgo based on gpus. Custom parameters are obtained through dummy
                training to simulate the actual model training process and hyperparameter optimization (HPO)
                experiments.
            gpu_customization_specs: the dictionary to enable users overwrite the HPO settings. user can
                overwrite part of variables as follows or all of them. The structure is as follows.
            allow_skip: a switch to determine if some Algo in the default templates can be skipped based on the
                analysis on the dataset from Auto3DSeg DataAnalyzer.

                .. code-block:: python

                    gpu_customization_specs = {
                        'ALGO': {
                            'num_trials': 6,
                            'range_num_images_per_batch': [1, 20],
                            'range_num_sw_batch_size': [1, 20]
                        }
                    }

            ALGO: the name of algorithm. It could be one of algorithm names (e.g., 'dints') or 'universal' which
                would apply changes to all algorithms. Possible options are

                - {``"universal"``, ``"dints"``, ``"segresnet"``, ``"segresnet2d"``, ``"swinunetr"``}.

            num_trials: the number of HPO trials/experiments to run.
            range_num_images_per_batch: the range of number of images per mini-batch.
            range_num_sw_batch_size: the range of batch size in sliding-window inferer.
        _z is skipped! T)foldr   r   )r   )r   N)r   r5   r   r   r   r   r   r   r   rH   rK   rQ   rU   r1   rE   rd   re   rf   r   r   r   rs   r   IDALGO)r>   r   r   r   r   r   Zfold_idxalgoZf_idZ
data_statsrJ   r.   r/   Zgen_algor1   r@   rB   r   r   r#   generateM  sB   *




zBundleGen.generate)r   NNNNNN)rW   r    r   r   r   rM   rV   rM   r   rM   r.   rM   r/   rM   )rV   r    rC   rF   )rC   r   )r   r   FNT)r   r    r   r9   r   rA   r   r   r   rA   rC   rF   )r   r   r   r   r?   rH   r   r   r   rQ   rU   r   r   r   r   r   r   r   r#   r     s2    
A			
)r   r    r   r    rC   r   );
__future__r   r   r:   r   ra   
subprocessr   r   r   r   r   pathlibr   tempfiler   typingr   urllib.parser   r6   Z
monai.appsr   monai.apps.utilsr	   Zmonai.auto3dseg.algo_genr
   r   monai.auto3dseg.utilsr   r   r   r   r   r   Zmonai.bundle.config_parserr   monai.configr   monai.utilsr   r   r   monai.utils.enumsr   monai.utils.miscr   r   rd   	algo_hashZ	ALGO_HASH__all__r   r   rY   r   r   r   r   r   r   r   r#   <module>   sP    
  
B
!