U
    PhH                  	   @  sD  d dl mZ d dlZd dlZd dlZd dlZd dlZd dlmZ d dl	m
Z
 d dlmZmZ d dlZd dlZd dlmZ d dlmZ d dlmZ d d	lmZ d d
lmZ d dlmZmZ d dlmZm Z m!Z! dddddddddg	Z"e dde\Z#Z$e d\Z%Z&dddddZ'ddddddZ(dTdd d!d"d#dZ)dUd%d&d'd'd(dd)d*dZ*dVd%d&d+d d(d,d-d.dZ+dWd0d1d2d1d3d4d5dZ,d6d6d d7d8dZ-dXd9d:d(d1d;d<dZ.dYd1d:d(d(d=d>dZ/d?d1d@dAdBZ0d6d1dCdDdEZ1dZd1dFd(d1dGdHdIZ2d1d(d1dJdKdLZ3d[d1dFd(d1dGdMdNZ4d1d(dOdJdPdQZ5d1d(dOdJdRdSZ6dS )\    )annotationsN)deepcopy)Number)Anycast)Algo)ConfigParser)
ID_SEP_KEY)PathLike
MetaTensor)CropForegroundToCupy)min_versionoptional_importrun_cmdget_foreground_imageget_foreground_labelget_label_ccpconcat_val_to_npconcat_multikeys_to_dictdatafold_readverify_report_formatalgo_to_picklealgo_from_picklezskimage.measurez0.14.2cupyr   z
np.ndarray)imagereturnc                 C  s$   t dd dd}|| }ttj|S )ay  
    Get a foreground image by removing all-zero rectangles on the edges of the image
    Note for the developer: update select_fn if the foreground is defined differently.

    Args:
        image: ndarray image to segment.

    Returns:
        ndarray of foreground image by removing all-zero edges.

    Notes:
        the size of the output is smaller than the input.
    c                 S  s   | dkS )Nr    )xr   r   J/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/auto3dseg/utils.py<lambda>A       z&get_foreground_image.<locals>.<lambda>T)	select_fnallow_smaller)r   r   npndarray)r   copperimage_foregroundr   r   r    r   2   s    )r   labelr   c                 C  s   t | |dk }|S )a  
    Get foreground image pixel values and mask out the non-labeled area.

    Args
        image: ndarray image to segment.
        label: ndarray the image input and annotated with class IDs.

    Returns:
        1D array of foreground image with label > 0
    r   r   )r   r)   Zlabel_foregroundr   r   r    r   F   s    Tboolztuple[list[Any], int])
mask_indexuse_gpur   c                   s|  t d\}}g }| jjdkrtr|r|rt |  }|j|}t	|t
| }|D ]\}t||k}	tj|	dd tj|	dd   fddtt D }
||
 qZt|}~~~~	~t   ntrltj| j  ddd	\}}td
|d
 D ]^}t||k}	tj|	dd tj|	dd   fddtt D }
||
 q
ntd||fS )a4  
    Find all connected components and their bounding shape. Backend can be cuPy/cuCIM or Numpy
    depending on the hardware.

    Args:
        mask_index: a binary mask.
        use_gpu: a switch to use GPU/CUDA or not. If GPU is unavailable, CPU will be used
            regardless of this setting.

    zcucim.skimagecudar   )axisc                   s    g | ]} | |  d  qS    r   .0iZcomp_idx_maxZcomp_idx_minr   r    
<listcomp>l   s     z!get_label_ccp.<locals>.<listcomp>T)
background
return_numr0   c                   s    g | ]} | |  d  qS r/   r   r1   r4   r   r    r5   y   s     zVCannot find one of the following required dependencies: {cuPy+cuCIM} or {scikit-image})r   devicetypehas_cpr   shortmeasurer)   cpuniquenonzeroargwheremintolistmaxrangelenappendget_default_memory_poolfree_all_blockshas_measure
measure_npdatacpunumpyr%   RuntimeError)r+   r,   skimage	has_cucim
shape_listZ	mask_cupyZlabeledvalsZncompZcomp_idxZ
bbox_shapencomponentsr   r4   r    r   V   s2    
Fz
list[dict]zlist[str | int]zbool | Noner   )	data_list
fixed_keysraggedallow_missingkwargsr   c                 K  sF  g }| D ]}t |}t|D ]\}}	t|	||< q|t|}
|
dkrj|rZ|d qt| dqt|
t	r|t
|
 qt|
tjtfr||
   qt|
t
jr||
 qt|
tr|t
|
 qt|
j dq|rdd |D }t|dkrt
dgS |r2t
j|f|S t
j|gf|S dS )a  
    Get the nested value in a list of dictionary that shares the same structure.

    Args:
       data_list: a list of dictionary {key1: {key2: np.ndarray}}.
       fixed_keys: a list of keys that records to path to the value in the dict elements.
       ragged: if True, numbers can be in list of lists or ragged format so concat mode needs change.
       allow_missing: if True, it will return a None if the value cannot be found.

    Returns:
        nd.array of concatenated array.

    Nz  is not nested in the dictionaryz concat is not supported.c                 S  s   g | ]}|d k	r|qS )Nr   )r2   r   r   r   r    r5      s      z$concat_val_to_np.<locals>.<listcomp>r   )r   	enumeratestrgetr	   joinrG   AttributeError
isinstancelistr%   arraytorchTensorr   rM   rN   r&   r   NotImplementedError	__class__rF   concatenate)rU   rV   rW   rX   rY   Znp_listrL   parserr3   keyvalr   r   r    r      s4    

z	list[str]zdict[str, np.ndarray])rU   rV   keyszero_insertrY   r   c           	      K  sD   i }|D ]6}|rd|gn|g}t | || f|}|||i q|S )a  
    Get the nested value in a list of dictionary that shares the same structure iteratively on all keys.
    It returns a dictionary with keys with the found values in nd.ndarray.

    Args:
        data_list: a list of dictionary {key1: {key2: np.ndarray}}.
        fixed_keys: a list of keys that records to path to the value in the dict elements.
        keys: a list of string keys that will be iterated to generate a dict output.
        zero_insert: insert a zero in the list so that it can find the value in element 0 before getting the keys
        flatten: if True, numbers are flattened before concat.

    Returns:
        a dict with keys - nd.array of concatenated array pair.
    r   )r   update)	rU   rV   rj   rk   rY   ret_dictrh   Zaddonri   r   r   r    r      s    trainingz
str | dictr[   intztuple[list, list])datalistbasedirfoldrh   r   c                   s   t | trt| }n| }t|| }|D ]|}| D ]n\}}t || trh fdd|| D ||< q6t || tr6t|| dkrtj	
 || n|| ||< q6q*g }	g }
|D ].}d|kr|d |kr|
| q|	| q|	|
fS )a  
    Read a list of data dictionary `datalist`

    Args:
        datalist: the name of a JSON file listing the data, or a dictionary.
        basedir: directory of image files.
        fold: which fold to use (0..1 if in training set).
        key: usually 'training' , but can try 'validation' or 'testing' to get the list data without labels (used in challenges).

    Returns:
        A tuple of two arrays (training, validation).
    c                   s   g | ]}t j |qS r   )ospathr]   )r2   ivrq   r   r    r5      s     z!datafold_read.<locals>.<listcomp>r   rr   )r_   r[   r   load_config_filer   itemsr`   rF   rs   rt   r]   rG   )rp   rq   rr   rh   	json_dataZ	dict_datadk_trri   r   rv   r    r      s"    
2dict)reportreport_formatr   c                 C  s   |  D ]z\}}|| kr dS | | }t|trt|trt|dkrNtdt|dkr|t|dkr|t|d |d   S  dS qdS )z
    Compares the report and the report_format that has only keys.

    Args:
        report: dict that has real values.
        report_format: dict that only has keys and list-nested value.
    Fr0   z%list length in report_format is not 1r   T)rx   r_   r`   rF   UserWarningr   )r   r   Zk_fmtZv_fmtvr   r   r    r      s    r   zPathLike | None)algotemplate_pathalgo_meta_datar   c           	   	   K  st   t | t|d}tj|  d}| D ]\}}|||i q.t |}t	|d}|
| W 5 Q R X |S )a  
    Export the Algo object to pickle file.

    Args:
        algo: Algo-like object.
        template_path: a str path that is needed to be added to the sys.path to instantiate the class.
        algo_meta_data: additional keyword to save into the dictionary, for example, model training info
            such as acc/best_metrics

    Returns:
        filename of the pickled Algo object
    )
algo_bytesr   zalgo_object.pklwb)pickledumpsr[   rs   rt   r]   get_output_pathrx   rl   openwrite)	r   r   r   rL   pkl_filenamer{   r   
data_bytesf_pir   r   r    r     s    
)r   r   rY   r   c                 K  sh  t | d}| }W 5 Q R X t|}t|tsDtd|j dd|kr\td| d|d}|dd}g }t	j
t|r|t	j
t| |t	j
t	j
t|d	 t	j
t|r|t	j
| |t	j
t	j
|d	 t	j
| }	t	j
|	d	d
}
t	j
|
r6|t	j
|
 t|dkrVt|}d|_nt|D ]\}}z"tj
| t|}W  qW nh tk
r } zHtd| d tj
  |t|d krtd|  d| |W 5 d}~X Y nX q^||_t	j
|	t	j
| kr:t|  d|	 d |	|_i }| D ]\}}|||i qF||fS )a  
    Import the Algo object from a pickle file.

    Args:
        pkl_filename: the name of the pickle file.
        template_path: a folder containing files to instantiate the Algo. Besides the `template_path`,
        this function will also attempt to use the `template_path` saved in the pickle file and a directory
        named `algorithm_templates` in the parent folder of the folder containing the pickle file.

    Returns:
        algo: the Algo object saved in the pickle file.
        algo_meta_data: additional keyword saved in the pickle file, for example, acc/best_metrics.

    Raises:
        ValueError if the pkl_filename does not contain a dict, or the dict does not contain `algo_bytes`.
        ModuleNotFoundError if it is unable to instantiate the Algo class.

    rbzthe data object is z. Dict is expected.r   zkey [algo_bytes] not found in z. Unable to instantiate.r   Nz..algorithm_templatesr   zFolder z; doesn't contain the Algo templates for Algo instantiation.r0   zFailed to instantiate z with z5 is changed. Now override the Algo output_path with: .)r   readr   loadsr_   r~   
ValueErrorre   poprs   rt   isdirr[   rG   abspathr]   dirnamerF   r   rZ   sysModuleNotFoundErrorloggingdebugr   output_pathrx   rl   )r   r   rY   r   r   rL   r   Zalgo_template_pathZtemplate_paths_candidatesZpkl_dirZalgo_template_path_fuzzyr   r3   pnot_found_errr   r{   r   r   r   r    r   ,  sZ    


 


r`   )argsr   c                 C  s    d dd | D }d| dS )z
    Convert a list of arguments to a string that can be used in python-fire.

    Args:
        args: the list of arguments.

    Returns:
        the string that can be used in python-fire.
    ,c                 S  s   g | ]}t |qS r   )r[   )r2   argr   r   r    r5     s     z/list_to_python_fire_arg_str.<locals>.<listcomp>')r]   )r   args_strr   r   r    list_to_python_fire_arg_strz  s    
r   )paramsr   c                 C  sT   d}|   D ]B\}}t|tr(tdnt|tr:t|}|d| d| 7 }q|S )z;convert `params` into '--key_1=value_1 --key_2=value_2 ...' zNested dict is not supported.z --=)rx   r_   r~   r   r`   r   )r   Zcmd_mod_optr{   r   r   r   r    check_and_set_optional_args  s    


r   z
str | None)cmd
cmd_prefixrY   r   c                 K  s:   |  }|rd|krd}|ds*|d7 }||  t| S )a  
    Prepare the command for subprocess to run the script with the given arguments.

    Args:
        cmd: the command or script to run in the distributed job.
        cmd_prefix: the command prefix to run the script, e.g., "python", "python -m", "python3", "/opt/conda/bin/python3.8 ".
        kwargs: the keyword arguments to be passed to the script.

    Returns:
        the command to run with ``subprocess``.

    Examples:
        To prepare a subprocess command
        "python train.py run -k --config 'a,b'", the function can be called as
        - _prepare_cmd_default("train.py run -k", config=['a','b'])
        - _prepare_cmd_default("train.py run -k --config 'a,b'")

    Nonepython )copyendswithr   )r   r   rY   r   r   r   r    _prepare_cmd_default  s    
r   )r   rY   r   c                 K  s   |  }| t| S )a  
    Prepare the command for multi-gpu/multi-node job execution using torchrun.

    Args:
        cmd: the command or script to run in the distributed job.
        kwargs: the keyword arguments to be passed to the script.

    Returns:
        the command to append to ``torchrun``

    Examples:
        For command "torchrun --nnodes=1 --nproc_per_node=8 train.py run -k --config 'a,b'",
        it only prepares command after the torchrun arguments, i.e., "train.py run -k --config 'a,b'".
        The function can be called as
        - _prepare_cmd_torchrun("train.py run -k", config=['a','b'])
        - _prepare_cmd_torchrun("train.py run -k --config 'a,b'")
    )r   r   )r   rY   r   r   r   r    _prepare_cmd_torchrun  s    r   c                 K  s   t | fd|i|S )a  
    Prepare the command for distributed job running using bcprun.

    Args:
        script: the script to run in the distributed job.
        cmd_prefix: the command prefix to run the script, e.g., "python".
        kwargs: the keyword arguments to be passed to the script.

    Returns:
        The command to run the script in the distributed job.

    Examples:
        For command "bcprun -n 2 -p 8 -c python train.py run -k --config 'a,b'",
        it only prepares command after the bcprun arguments, i.e., "train.py run -k --config 'a,b'".
        the function can be called as
        - _prepare_cmd_bcprun("train.py run -k", config=['a','b'], n=2, p=8)
        - _prepare_cmd_bcprun("train.py run -k --config 'a,b'", n=2, p=8)
    r   )r   )r   r   rY   r   r   r    _prepare_cmd_bcprun  s    r   zsubprocess.CompletedProcessc                 K  sx   |  }|  }dg}ddg}|D ]8}||kr>td| d|d| t||g7 }q"||7 }t|fddi|S )	a  
    Run the command with torchrun.

    Args:
        cmd: the command to run. Typically it is prepared by ``_prepare_cmd_torchrun``.
        kwargs: the keyword arguments to be passed to the ``torchrun``.

    Return:
        the return code of the subprocess command.
    Ztorchrunnnodesnproc_per_nodeMissing required argument z for torchrun.z--run_cmd_verboseT)r   splitr   r[   r   r   )r   rY   r   cmd_listZtorchrun_listrequired_argsr   r   r   r    _run_cmd_torchrun  s    r   c                 K  sv   |  }dg}ddg}|D ]8}||kr6td| d|d| t||g7 }q|d| g t|fdd	i|S )
a  
    Run the command with bcprun.

    Args:
        cmd: the command to run. Typically it is prepared by ``_prepare_cmd_bcprun``.
        kwargs: the keyword arguments to be passed to the ``bcprun``.

    Returns:
        the return code of the subprocess command.
    bcprunnr   r   z for bcprun.-z-cr   T)r   r   r[   r   extendr   )r   rY   r   r   r   r   r   r   r    _run_cmd_bcprun  s    r   )T)FF)T)r   rn   )N)N)N)N)7
__future__r   r   rs   r   
subprocessr   r   r   numbersr   typingr   r   rN   r%   rb   Zmonai.auto3dsegr   monai.bundle.config_parserr   monai.bundle.utilsr	   monai.configr
   monai.data.meta_tensorr   monai.transformsr   r   monai.utilsr   r   r   __all__rK   rJ   r>   r;   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    <module>   s`   .  : 'N