o
    i3A                     @  s  d dl mZ d dlZd dlmZ d dlmZ d dlmZm	Z	 d dl
mZ d dlmZ d dlmZ d d	lmZmZmZmZ d d
lmZ d dlmZ d dlmZ d dlmZ ed\ZZed\ZZ ee!dZ"g dZ#G dd deZ$G dd de$Z%G dd de$Z&dS )    )annotationsN)abstractmethod)deepcopy)Anycast)warn)
BundleAlgo)
get_logger)AlgoAlgoGenalgo_from_picklealgo_to_pickle)ConfigParser)PathLike)optional_import)AlgoKeysnnioptuna)module_name)HPOGenNNIGen	OptunaGenc                   @  s@   e Zd ZdZedd Zedd Zedd Zedd	 Zd
S )r   a>  
    The base class for hyperparameter optimization (HPO) interfaces to generate algos in the Auto3Dseg pipeline.
    The auto-generated algos are saved at their ``output_path`` on the disk. The files in the ``output_path``
    may contain scripts that define the algo, configuration files, and pickle files that save the internal states
    of the algo before/after the training. Compared to the BundleGen class, HPOGen generates Algo on-the-fly, so
    training and algo generation may be executed alternatively and take a long time to finish the generation process.

    c                 C     t )z Get the hyperparameter from HPO.NotImplementedErrorself r   ^/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/auto3dseg/hpo_gen.pyget_hyperparameters-      zHPOGen.get_hyperparametersc                 O  r   )zHUpdate Algo parameters according to the hyperparameters to be evaluated.r   r   argskwargsr   r   r   update_params2   r    zHPOGen.update_paramsc                 O  r   )z$Report the evaluated results to HPO.r   r!   r   r   r   	set_score7   r    zHPOGen.set_scorec                 O  r   )zDInterface for launch the training given the fetched hyperparameters.r   r!   r   r   r   run_algo<   r    zHPOGen.run_algoN)	__name__
__module____qualname____doc__r   r   r$   r%   r&   r   r   r   r   r   #   s    	


r   c                   @  sf   e Zd ZdZd"d#ddZd	d
 Zdd Zdd Zd$ddZdd Z	d%d&ddZ
dd Zd'd(d d!ZdS ))r   u  
    Generate algorithms for the NNI to automate hyperparameter tuning. The module has two major interfaces:
    ``__init__`` which prints out how to set up the NNI, and a trialCommand function ``run_algo`` for the NNI library to
    start the trial of the algo. More about trialCommand function can be found in ``trail code`` section in NNI webpage
    https://nni.readthedocs.io/en/latest/tutorials/hpo_quickstart_pytorch/main.html .

    Args:
        algo: an Algo object (e.g. BundleAlgo) with defined methods: ``get_output_path`` and train
            and supports saving to and loading from pickle files via ``algo_from_pickle`` and ``algo_to_pickle``.
        params: a set of parameter to override the algo if override is supported by Algo subclass.

    Examples::

        The experiment will keep generating new folders to save the model checkpoints, scripts, and configs if available.
        ├── algorithm_templates
        │   └── unet
        ├── unet_0
        │   ├── algo_object.pkl
        │   ├── configs
        │   └── scripts
        ├── unet_0_learning_rate_0.01
        │   ├── algo_object.pkl
        │   ├── configs
        │   ├── model_fold0
        │   └── scripts
        └── unet_0_learning_rate_0.1
            ├── algo_object.pkl
            ├── configs
            ├── model_fold0
            └── scripts

        .. code-block:: python
            # Bundle Algorithms are already generated by BundleGen in work_dir
            import_bundle_algo_history(work_dir, only_trained=False)
            algo_dict = self.history[0]  # pick the first algorithm
            algo_name = algo_dict[AlgoKeys.ID]
            onealgo = algo_dict[AlgoKeys.ALGO]
            nni_gen = NNIGen(algo=onealgo)
            nni_gen.print_bundle_algo_instruction()

    Notes:
        The NNIGen will prepare the algorithms in a folder and suggest a command to replace trialCommand in the experiment
        config. However, NNIGen will not trigger NNI. User needs to write their NNI experiment configs, and then run the
        NNI command manually.
    NalgoAlgo | Noneparamsdict | Nonec                 C  s   |  d| _ d| _|d urRt|trC|d u r|| _n-t|| _tj|	 d }tj
|	 }|ddi | jj||fi | n|| _t| j| jjd| _d S d S N 	_overridefill_with_datastatsFtemplate_path)hintobj_filename
isinstancer   r+   r   ospathbasenameget_output_pathdirnameupdateexport_to_diskr   r4   r   r+   r-   nameoutput_folderr   r   r   __init__q   s   

zNNIGen.__init__c                 C     | j S )z5Return the filename of the dumped pickle algo object.r6   r   r   r   r   get_obj_filename      zNNIGen.get_obj_filenamec                 C  s   d}t d t d t d t | d| j d t d t d t d	tt| jj d
 t d| j  d t d t | d t d dS )zH
        Print how to write the trial commands for Bundle Algo.
        z/python -m monai.apps.auto3dseg NNIGen run_algo z============================================================================================================================================z#If NNI will run in your local env: zB1. Add the following line to the trialCommand in your NNI config:  z {result_dir}z--------------------------------------------------------------------------------------------------------------------------------------------z!If NNI will run in a remote env: z'1. Copy the algorithm_templates folder z+ to remote {remote_algorithm_templates_dir}z2. Copy the older z( to the remote machine {remote_algo_dir}zDThen add the following line to the trialCommand in your NNI config: z@ {remote_algo_dir} {result_dir} {remote_algorithm_templates_dir}N)loggerinfor6   r   r   r+   r4   r;   )r   r5   r   r   r   print_bundle_algo_instruction   s   





z$NNIGen.print_bundle_algo_instructionc                 C  s   t rt S td i S )zK
        Get parameter for next round of training from NNI server.
        ?NNI is not detected. The code will continue to run without NNI.)has_nnir   Zget_next_parameterr   r   r   r   r   r      s   zNNIGen.get_hyperparametersdictreturnNonec                 C  
   || _ dS )z
        Translate the parameter from monai bundle to meet NNI requirements.

        Args:
            params: a dict of parameters.
        Nr-   r   r-   r   r   r   r$         
zNNIGen.update_paramsc                 C     d dd | j D pdS )
        Get the identifier of the current experiment. In the format of listing the searching parameter name and values
        connected by underscore in the file name.
        r0   c                 s  $    | ]\}}d | d | V  qdS _Nr   .0kvr   r   r   	<genexpr>      " z%NNIGen.get_task_id.<locals>.<genexpr>_Nonejoinr-   itemsr   r   r   r   get_task_id      zNNIGen.get_task_id.rA   strc                 C  s   |   }tj| j }tj||| }tj|d| _t| jt	r2| jj
||| |dd dS t| j| t| dS )
        Generate the record for each Algo. If it is a BundleAlgo, it will generate the config files.

        Args:
            output_folder: the directory nni will save the results to.
        algo_object.pklF)bundle_rootr2   Nrc   r8   r9   r:   r+   r;   ra   r6   r7   r   r>   r   export_config_filer-   rH   rI   r   rA   Ztask_idZtask_prefix
write_pathr   r   r   generate   s   
zNNIGen.generatec                 C  s   t r	t| dS td dS )z/
        Report the acc to NNI server.
        rK   N)rL   r   Zreport_final_resultr   r   accr   r   r   r%      s   zNNIGen.set_scorer6   r4   PathLike | Nonec                 C     t j|st| dt||d\| _}|  }| | | | | j	| j
 | j }ttj|i}t| jfd| jji| | | dS am  
        The python interface for NNI to run.

        Args:
            obj_filename: the pickle-exported Algo object.
            output_folder: the root path of the algorithms templates.
            template_path: the algorithm_template. It must contain algo.py in the follow path:
                ``{algorithm_templates_dir}/{network}/scripts/algo.py``
        z is not foundr3   r4   Nr8   r9   isfile
ValueErrorr   r+   r   r$   rn   trainr-   	get_scorerf   r   SCOREr   r4   r%   r   r6   rA   r4   algo_meta_datar-   rp   r   r   r   r&      s   



zNNIGen.run_algoNN)r+   r,   r-   r.   r-   rM   rN   rO   re   rA   rf   rN   rO   re   Nr6   rf   rA   rf   r4   rq   rN   rO   )r'   r(   r)   r*   rB   rE   rJ   r   r$   rc   rn   r%   r&   r   r   r   r   r   B   s    .
			r   c                   @  st   e Zd ZdZd&d'd	d
Zdd Zdd Zdd Zdd Z	d(d)ddZ	d*ddZ
d d! Zd+d,d"d#Zd(d-d$d%ZdS ).r   u  
    Generate algorithms for the Optuna to automate hyperparameter tuning. Please refer to NNI and Optuna
    (https://optuna.readthedocs.io/en/stable/) for more information. Optuna has different running scheme
    compared to NNI. The hyperparameter samples come from a trial object (trial.suggest...) created by Optuna,
    so OptunaGen needs to accept this trial object as input. Meanwhile, Optuna calls OptunaGen,
    thus OptunaGen.__call__() should return the accuracy. Use functools.partial to wrap OptunaGen
    for addition input arguments.

    Args:
        algo: an Algo object (e.g. BundleAlgo). The object must at least define two methods: get_output_path and train
            and supports saving to and loading from pickle files via ``algo_from_pickle`` and ``algo_to_pickle``.
        params: a set of parameter to override the algo if override is supported by Algo subclass.

    Examples::

        The experiment will keep generating new folders to save the model checkpoints, scripts, and configs if available.
        ├── algorithm_templates
        │   └── unet
        ├── unet_0
        │   ├── algo_object.pkl
        │   ├── configs
        │   └── scripts
        ├── unet_0_learning_rate_0.01
        │   ├── algo_object.pkl
        │   ├── configs
        │   ├── model_fold0
        │   └── scripts
        └── unet_0_learning_rate_0.1
            ├── algo_object.pkl
            ├── configs
            ├── model_fold0
            └── scripts

    Notes:
        Different from NNI and NNIGen, OptunaGen and Optuna can be ran within the Python process.

    Nr+   r,   r-   r.   rN   rO   c                 C  s   |  d| _ |d urOt|tr@|d u r|| _n-t|| _tj| d }tj	| }|
ddi | jj||fi | n|| _t| j| jjd| _ d S d S r/   )r6   r7   r   r+   r   r8   r9   r:   r;   r<   r=   r>   r   r4   r?   r   r   r   rB     s   

zOptunaGen.__init__c                 C  rC   )z(Return the dumped pickle object of algo.rD   r   r   r   r   rE   -  rF   zOptunaGen.get_obj_filenamec                 C  s.   t rtd d| jdddiS td i S )z
        Get parameter for next round of training from optuna trial object.
        This function requires user rewrite during usage for different search space.
        z2Please rewrite this code by creating a child classZlearning_rateg-C6?g?zEOptuna is not detected. The code will continue to run without Optuna.)
has_optunarH   rI   trialZsuggest_floatr   r   r   r   r   r   1  s
   
zOptunaGen.get_hyperparametersc                 C  rP   )zSet the accuracy scoreN)rp   ro   r   r   r   r%   =     
zOptunaGen.set_scorec                 C  rP   )zSet the Optuna trialN)r   )r   r   r   r   r   	set_trialA  r   zOptunaGen.set_trialre   r   r   r6   rf   rA   r4   rq   c                 C  s   |  | | ||| | jS )a  
        Callable that Optuna will use to optimize the hyper-parameters

        Args:
            obj_filename: the pickle-exported Algo object.
            output_folder: the root path of the algorithms templates.
            template_path: the algorithm_template. It must contain algo.py in the follow path:
                ``{algorithm_templates_dir}/{network}/scripts/algo.py``
        )r   r&   rp   )r   r   r6   rA   r4   r   r   r   __call__E  s   
zOptunaGen.__call__rM   c                 C  rP   )zu
        Translate the parameter from monai bundle.

        Args:
            params: a dict of parameters.
        NrQ   rR   r   r   r   r$   U  rS   zOptunaGen.update_paramsc                 C  rT   )rU   r0   c                 s  rV   rW   r   rY   r   r   r   r]   c  r^   z(OptunaGen.get_task_id.<locals>.<genexpr>r_   r`   r   r   r   r   rc   ^  rd   zOptunaGen.get_task_idc                 C  s~   |   }tj| j }tj||| }tj|d| _t| jt	r1| jj
||| dd dS t| j| t| dS )rg   rh   F)r2   Nrj   rl   r   r   r   rn   e  s   zOptunaGen.generatec                 C  rr   rs   rt   rz   r   r   r   r&   w  s   



zOptunaGen.run_algor|   )r+   r,   r-   r.   rN   rO   r   )
r   r   r6   rf   rA   rf   r4   rq   rN   r   r}   r~   r   r   )r'   r(   r)   r*   rB   rE   r   r%   r   r   r$   rc   rn   r&   r   r   r   r   r      s    &
	r   )'
__future__r   r8   abcr   copyr   typingr   r   warningsr   monai.apps.auto3dseg.bundle_genr   monai.apps.utilsr	   monai.auto3dsegr
   r   r   r   monai.bundle.config_parserr   monai.configr   monai.utilsr   monai.utils.enumsr   r   rL   r   r   r'   rH   __all__r   r   r   r   r   r   r   <module>   s*   
 1