o
    ie                     @  s  d dl mZ d dlZd dlZd dlZd dlZd dlZd dlZd dlZd dl	m
Z
mZmZmZmZ d dlmZmZ d dlmZ d dlmZ d dlmZ d dlmZ d d	lmZmZ d d
lmZmZ d dlZej dddkZ!ej dddkZ"ej dddkZ#dZ$g dZ%		djdkddZ&dld#d$Z'	%dmdnd+d,Z(dod0d1Z)d2d3 Z*dpdqd8d9Z+dpdrd;d<Z,G d=d> d>e-Z.G d?d@ d@e/Z0d4e+d4e$ddAdfdsdLdMZ1d4e+dfdtdQdRZ2dudTdUZ3e4ddVdW Z5dvd[d\Z6dwd]d^Z7dwd_d`Z8e4ddxdydfdgZ9e4ddxdzdhdiZ:dS ){    )annotationsN)Callable
CollectionHashableIterableMapping)partialwraps)import_module)walk_packages)locate)match)FunctionType
ModuleType)AnycastZMONAI_EVAL_EXPR10ZMONAI_DEBUG_CONFIGZMONAI_ALLOW_MISSING_REFERENCEz{})InvalidPyTorchVersionErrorOptionalImportErrorexact_versiondamerau_levenshtein_distancelook_up_optionmin_versionoptional_importrequire_pkginstantiateget_full_type_nameget_package_versionget_torch_version_tupleversion_leqversion_geqpytorch_after
no_defaultTopt_strr   	supportedCollection | enum.EnumMetadefaultr   print_all_optionsboolreturnc           
      C  s  t | tstdt|  d|  dt | tr|  } t |tjr@t | tr3| dd |D v r3|| S t | tjr?| |v r?| S nt |t	rM| |v rM||  S t |t
rX| |v rX| S |dkr^|S t |tjrldd |D }n|durtt|nt }|std	| di }|  } |D ]}t| | }|d
kr|||< q|rd| dnd}|rt||jd}	td|  d|	 dd|  d | td|  d| )a3  
    Look up the option in the supported collection and return the matched item.
    Raise a value error possibly with a guess of the closest match.

    Args:
        opt_str: The option string or Enum to look up.
        supported: The collection of supported options, it can be list, tuple, set, dict, or Enum.
        default: If it is given, this method will return `default` when `opt_str` is not found,
            instead of raising a `ValueError`. Otherwise, it defaults to `"no_default"`,
            so that the method may raise a `ValueError`.
        print_all_options: whether to print all available options when `opt_str` is not found. Defaults to True

    Examples:

    .. code-block:: python

        from enum import Enum
        from monai.utils import look_up_option
        class Color(Enum):
            RED = "red"
            BLUE = "blue"
        look_up_option("red", Color)  # <Color.RED: 'red'>
        look_up_option(Color.RED, Color)  # <Color.RED: 'red'>
        look_up_option("read", Color)
        # ValueError: By 'read', did you mean 'red'?
        # 'read' is not a valid option.
        # Available options are {'blue', 'red'}.
        look_up_option("red", {"red", "blue"})  # "red"

    Adapted from https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/utilities/util_common.py#L249
    zUnrecognized option type: :.c                 S     h | ]}|j qS  value.0itemr.   r.   T/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/utils/module.py	<setcomp>g       z!look_up_option.<locals>.<setcomp>r#   c                 S  r-   r.   r/   r1   r.   r.   r4   r5   y   r6   NzNo options available:    zAvailable options are z.
 )keyzBy 'z', did you mean 'z'?
'z' is not a valid value.
zUnsupported option 'z', )
isinstancer   
ValueErrortypestrstripenumEnumMetaEnumr   r   setr   minget)
r$   r%   r'   r(   Zset_to_checkZ
edit_distsr9   Z	edit_distZsupported_msgZguess_at_spellingr.   r.   r4   r   =   sN   
%

r   s1r>   s2intc           
   	   C  sN  | |krdS t | }t |}| s|S |s|S dd td|d D }td|d D ]
}|d |d|f< q)t| D ]d\}}t|D ][\}}||krJdnd}	t||d |f d |||d f d ||d |d f |	 |||f< |r|r|||d  kr| |d  |krt|||f ||d |d f |	 |||f< q@q8||d |d f S )u   
    Calculates the Damerau–Levenshtein distance between two strings for spelling correction.
    https://en.wikipedia.org/wiki/Damerau–Levenshtein_distance
    r   c                 S  s   i | ]	}|d f|d qS )   r.   )r2   ir.   r.   r4   
<dictcomp>   s    z0damerau_levenshtein_distance.<locals>.<dictcomp>rI   rJ      )lenrange	enumeraterD   )
rF   rG   Zstring_1_lengthZstring_2_lengthdjrK   Zs1iZs2jcostr.   r.   r4   r      s,   :(.r   (.*[tT]est.*)|(_.*)basemodr   load_allexclude_pattern"tuple[list[ModuleType], list[str]]c                 C  s   g }g }t | j| jd |jdD ]Y\}}}|s|ri|tjvrit||du rizt|}||}	|	rA|	j	rA|	j	}
|

| || W q tyK   Y q tyh } zd}t|| d| |j|d}~ww q||fS )z
    Traverse the source of the module structure starting with module `basemod`, loading all packages plus all files if
    `load_all` is True, excluding anything whose name matches `exclude_pattern`.
    r,   )prefixonerrorNz
Multiple versions of MONAI may have been installed?
Please see the installation guide: https://docs.monai.io/en/stable/installation.html

)r   __path____name__appendsysmodulesr   r
   	find_specloaderexec_moduler   ImportErrorr=   with_traceback__traceback__)rU   rV   rW   
submodulesZerr_modimporternameZis_pkgmodmod_specrb   emsgr.   r.   r4   load_submodules   s0    



 rn   __path__modekwargsc                 K  sl  ddl m} t| trt| n| }|du rtd|  dt||}zg|dds+tr:t	
d| d	| d
 t  t|sMt	
d| d| d |W S ||jkrZ|di |W S ||jkrm|rjt|fi |W S |W S ||jkrt	
d| d	| d
 tj|fi |W S W n! ty } ztd|  dd|  d|j d|d}~ww t	
d|  d |S )a  
    Create an object instance or call a callable object from a class or function represented by ``_path``.
    `kwargs` will be part of the input arguments to the class constructor or function.
    The target component must be a class or a function, if not, return the component directly.

    Args:
        __path: if a string is provided, it's interpreted as the full path of the target class or function component.
            If a callable is provided, ``__path(**kwargs)`` will be invoked and returned for ``__mode="default"``.
            For ``__mode="callable"``, the callable will be returned as ``__path`` or, if ``kwargs`` are provided,
            as ``functools.partial(__path, **kwargs)`` for future invoking.

        __mode: the operating mode for invoking the (callable) ``component`` represented by ``__path``:

            - ``"default"``: returns ``component(**kwargs)``
            - ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)``
            - ``"debug"``: returns ``pdb.runcall(component, **kwargs)``

        kwargs: keyword arguments to the callable represented by ``__path``.

    r   )CompInitModeNz'Cannot locate class or function path: 'z'.Z_debug_Fz

pdb: instantiating component=z, mode=zV
See also Debugger commands documentation: https://docs.python.org/3/library/pdb.html
z
Component z is not callable when mode=r,   z!Failed to instantiate component 'z' with keywords: ,z
 set '_mode_=z' to enter the debugging mode.zKComponent to instantiate must represent a valid class or function, but got r.   )monai.utils.enumsrr   r;   r>   r   ModuleNotFoundErrorr   pop	run_debugwarningswarn
breakpointcallableDEFAULTZCALLABLEr   DEBUGpdbruncall	ExceptionRuntimeErrorjoinkeys)ro   rp   rq   rr   	componentmrl   r.   r.   r4   r      sH   



r   c                 C  s.   | j }|du s|tjj kr| jS |d | j S )zG
    Utility to get the full path name of a class or object type.

    Nr,   )
__module__r>   	__class__r]   )typeobjmoduler.   r.   r4   r     s   r   r8   
the_modulemin_version_str_argsc                 G  s\   |rt | ds	dS tdd | jddd D }tdd |ddd D }||kS )	z
    Convert version strings into tuples of int and compare them.

    Returns True if the module's version is greater or equal to the 'min_version'.
    When min_version_str is not provided, it always returns True.
    __version__Tc                 s      | ]}t |V  qd S NrH   r2   xr.   r.   r4   	<genexpr>      zmin_version.<locals>.<genexpr>r,   NrM   c                 s  r   r   r   r   r.   r.   r4   r     r   )hasattrtupler   split)r   r   r   Zmod_versionrequiredr.   r.   r4   r     s
   " r   version_strc                 G  s,   t | dst|  d dS t| j|kS )zF
    Returns True if the module's __version__ matches version_str
    r   z5 has no attribute __version__ in exact_version check.F)r   rx   ry   r)   r   )r   r   r   r.   r.   r4   r      s   
r   c                      s    e Zd ZdZ fddZ  ZS )r   zo
    Raised when called function or method requires a more recent
    PyTorch version than that installed.
    c                   s    | d| d}t  | d S )Nz requires PyTorch version z	 or later)super__init__)selfZrequired_versionri   messager   r.   r4   r   0  s   z#InvalidPyTorchVersionError.__init__)r]   r   __qualname____doc__r   __classcell__r.   r.   r   r4   r   *  s    r   c                   @  s   e Zd ZdZdS )r   z<
    Could not import APIs from an optional dependency.
    N)r]   r   r   r   r.   r.   r.   r4   r   5  s    r   Fr   versionversion_checkerCallable[..., bool]ri   
descriptorversion_argsallow_namespace_pkgas_typetuple[Any, bool]c              
     s^  dd}|rd|  d| }	nd|  }	z$t | }
t| }|s0t|dddu o+t|d}|r0t|r7t||}W n tyP } z|j| }W Y d}~nd}~ww |r^||
| |r^|dfS |sj||
| rj|dfS ||	|rdu rd	|  d
| d|j d7 |rd| d7 G fddd} dkr| dfS G  fddd|}|dfS )a	  
    Imports an optional module specified by `module` string.
    Any importing related exceptions will be stored, and exceptions raise lazily
    when attempting to use the failed-to-import module.

    Args:
        module: name of the module to be imported.
        version: version string used by the version_checker.
        version_checker: a callable to check the module version, Defaults to monai.utils.min_version.
        name: a non-module attribute (such as method/class) to import from the imported module.
        descriptor: a format string for the final error message when using a not imported module.
        version_args: additional parameters to the version checker.
        allow_namespace_pkg: whether importing a namespace package is allowed. Defaults to False.
        as_type: there are cases where the optionally imported object is used as
            a base class, or a decorator, the exceptions should raise accordingly. The current supported values
            are "default" (call once to raise), "decorator" (call the constructor and the second call to raise),
            and anything else will return a lazy class that can be used as a base class (call the constructor to raise).

    Returns:
        The imported module and a boolean flag indicating whether the import is successful.

    Examples::

        >>> torch, flag = optional_import('torch', '1.1')
        >>> print(torch, flag)
        <module 'torch' from 'python/lib/python3.6/site-packages/torch/__init__.py'> True

        >>> the_module, flag = optional_import('unknown_module')
        >>> print(flag)
        False
        >>> the_module.method  # trying to access a module which is not imported
        OptionalImportError: import unknown_module (No module named 'unknown_module').

        >>> torch, flag = optional_import('torch', '42', exact_version)
        >>> torch.nn  # trying to access a module for which there isn't a proper version imported
        OptionalImportError: import torch (requires version '42' by 'exact_version').

        >>> conv, flag = optional_import('torch.nn.functional', '1.0', name='conv1d')
        >>> print(conv)
        <built-in method conv1d of type object at 0x11a49eac0>

        >>> conv, flag = optional_import('torch.nn.functional', '42', name='conv1d')
        >>> conv()  # trying to use a function from the not successfully imported module (due to unmatched version)
        OptionalImportError: from torch.nn.functional import conv1d (requires version '42' by 'min_version').
    Nr8   zfrom z import zimport __file__r\   Tz (requires ' z' by 'z')z ()c                      s:   e Zd Z fddZdd Zdd Zdd Zd	d
 ZdS )z#optional_import.<locals>._LazyRaisec                   s<     dd d }d u rt || _d S t || _d S )Nr,   zG

For details about installing the optional dependencies, please visit:z^
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies)r   
_exceptionre   )r   r   _kwargsZ_default_msgrm   tbr.   r4   r     s   z,optional_import.<locals>._LazyRaise.__init__c                 S     | j za
            Raises:
                OptionalImportError: When you call this method.
            r   )r   ri   r.   r.   r4   __getattr__     z/optional_import.<locals>._LazyRaise.__getattr__c                 _  r   r   r   )r   r   r   r.   r.   r4   __call__  r   z,optional_import.<locals>._LazyRaise.__call__c                 S  r   r   r   )r   r3   r.   r.   r4   __getitem__     z/optional_import.<locals>._LazyRaise.__getitem__c                 S  r   r   r   )r   r.   r.   r4   __iter__  r   z,optional_import.<locals>._LazyRaise.__iter__N)r]   r   r   r   r   r   r   r   r.   r   r.   r4   
_LazyRaise  s    r   r'   Fc                      s   e Zd Z fddZ  ZS )z!optional_import.<locals>._LazyClsc                   s   t    ds| jd S )N	decorator)r   r   
startswithr   )r   r   rq   )r   r   r.   r4   r     s   

z*optional_import.<locals>._LazyCls.__init__)r]   r   r   r   r   r.   )r   r   r4   _LazyCls  s    r   )	
__import__r
   getattrr   AssertionErrorr   rf   formatr]   )r   r   r   ri   r   r   r   r   Zexception_strZ
actual_cmdpkgr   is_namespaceZimport_exceptionr   r   r.   )r   rm   r   r4   r   ;  sD   8


!
r   pkg_nameraise_errorr   c                   s    fdd}|S )a  
    Decorator function to check the required package installation.

    Args:
        pkg_name: required package name, like: "itk", "nibabel", etc.
        version: required version string used by the version_checker.
        version_checker: a callable to check the module version, defaults to `monai.utils.min_version`.
        raise_error: if True, raise `OptionalImportError` error if the required package is not installed
            or the version doesn't match requirement, if False, print the error in a warning.

    c                   sF   t | t}|r	| n| j t  fdd}|r|S || _| S )Nc                    sF   t d\}}|sd d}rt|t|  | i |S )N)r   r   r   zrequired package `z<` is not installed or the version doesn't match requirement.)r   r   rx   ry   )argsrq   _haserr_msg)call_objr   r   r   r   r.   r4   _wrapper  s   
z1require_pkg.<locals>._decorator.<locals>._wrapper)r;   r   r   r	   )objis_funcr   r   r   r   r   )r   r4   
_decorator  s   
zrequire_pkg.<locals>._decoratorr.   )r   r   r   r   r   r.   r   r4   r     s   r   !NOT INSTALLED or UNKNOWN VERSION.c                 C  s$   t | \}}|rt|dr|jS |S )zN
    Try to load package and get version. If not found, return `default`.
    r   )r   r   r   )Zdep_namer'   depZhas_depr.   r.   r4   r     s   r   c                   C  s"   t dd tjddd D S )zT
    Returns:
        tuple of ints represents the pytorch major/minor version.
    c                 s  r   r   r   r   r.   r.   r4   r     r   z*get_torch_version_tuple.<locals>.<genexpr>r,   NrM   )r   torchr   r   r.   r.   r.   r4   r     s   "r   lhsrhs/tuple[Iterable[int | str], Iterable[int | str]]c                 C  sR   ddd}|  ddd	 } | ddd	 }t||  d
}t|| d
}||fS )z$
    Parse the version strings.
    valr>   r*   	int | strc                 S  sN   |   } ztd| }|d ur| d } t| W S | W S  ty&   |  Y S w )Nz	(\d+)(.*)r   )r?   r   groupsrH   r<   )r   r   r.   r.   r4   	_try_cast  s   

z%parse_version_strs.<locals>._try_cast+rJ   r   r,   N)r   r>   r*   r   )r   map)r   r   r   lhs_rhs_r.   r.   r4   parse_version_strs  s   
r   c                 C  s   t | t |} }td\}}|r+ztt|| ||kW S  |jy*   Y dS w t| |\}}t||D ] \}}||krWt|t	rOt|t	rO||k   S | | k   S q7dS )a  
    Returns True if version `lhs` is earlier or equal to `rhs`.

    Args:
        lhs: version name to compare with `rhs`, return True if earlier or equal to `rhs`.
        rhs: version name to compare with `lhs`, return True if later or equal to `lhs`.

    packaging.versionT
r>   r   r   r)   VersionInvalidVersionr   zipr;   rH   r   r   Zpkginghas_verr   r   lrr.   r.   r4   r      s    
r    c                 C  s   t | t |} }td\}}|r+ztt|| ||kW S  |jy*   Y dS w t| |\}}t||D ] \}}||krWt|t	rOt|t	rO||k  S | | k  S q7dS )a  
    Returns True if version `lhs` is later or equal to `rhs`.

    Args:
        lhs: version name to compare with `rhs`, return True if later or equal to `rhs`.
        rhs: version name to compare with `lhs`, return True if earlier or equal to `lhs`.

    r   Tr   r   r.   r.   r4   r!   0  s    	r!   majorminorpatchcurrent_ver_string
str | Nonec              
   C  s  zX|du rt jdd}|r|ntj}tddd\}}|r0|d|  | | f|| kW S | dd	d
 dd}t|dk rN|dg7 }t|dk sC|dd \}}	}
W n t	t
tfyk   t \}}	d}
Y nw t|t|	f}t| t|f}||kr||kS d|
  v pd|
  v }d
}ztd|
 }|rt| }W n t	tt
fy   d}Y nw t|}||kr||kS |rdS dS )aJ  
    Compute whether the current pytorch version is after or equal to the specified version.
    The current system pytorch version is determined by `torch.__version__` or
    via system environment variable `PYTORCH_VER`.

    Args:
        major: major version number to be compared with
        minor: minor version number to be compared with
        patch: patch version number to be compared with
        current_ver_string: if None, `torch.__version__` will be used.

    Returns:
        True if the current pytorch version is greater than or equal to the specified version.
    NZPYTORCH_VERr8   r   parseri   r,   r   rJ   r   r7   r   arcz\d+TF)osenvironrE   r   r   r   r   r   rN   AttributeErrorr<   	TypeErrorr   rH   lowerresearchgroup)r   r   r   r   Z_env_varverr   partsc_majorc_minorZc_patchc_mnmnis_prereleaseZc_pZp_regr.   r.   r4   r"   L  sH   &

r"   c                 C  s  |du r2t j }td\}}|sdS |sdS |  |d}||\}}| d| }|  tddd	\}	}
|
rL|	d|  | f|	| kS | 	d
dd 	dd}t
|dk rj|dg7 }t
|dk s_|dd \}}t|t|f}t| t|f}||kS )a  
    Compute whether the current system GPU CUDA compute capability is after or equal to the specified version.
    The current system GPU CUDA compute capability is determined by the first GPU in the system.
    The compared version is a string in the form of "major.minor".

    Args:
        major: major version number to be compared with.
        minor: minor version number to be compared with. Defaults to 0.
        current_ver_string: if None, the current system GPU CUDA compute capability will be used.

    Returns:
        True if the current system GPU CUDA compute capability is greater than or equal to the specified version.
    NpynvmlTFr   r,   r   r   r   r   rJ   rM   r   )r   cudais_availabler   nvmlInitnvmlDeviceGetHandleByIndexZ"nvmlDeviceGetCudaComputeCapabilityZnvmlShutdownr   r   rN   rH   )r   r   r   Zcuda_availabler   Z
has_pynvmlhandleZmajor_cZminor_cr   r   r   r   r   r   r   r.   r.   r4   compute_capabilities_after  s.   

 
r  )r#   T)
r$   r   r%   r&   r'   r   r(   r)   r*   r   )rF   r>   rG   r>   r*   rH   )TrT   )rU   r   rV   r)   rW   r>   r*   rX   )ro   r>   rp   r>   rq   r   r*   r   )r8   )r   r   r   r>   r   r   r*   r)   )r   r   r   r>   r   r   r*   r)   )r   r>   r   r>   r   r   ri   r>   r   r>   r   r   r   r)   r   r>   r*   r   )
r   r>   r   r>   r   r   r   r)   r*   r   )r   )r   r>   r   r>   r*   r   )r   r>   r   r>   r*   r)   )r   N)
r   rH   r   rH   r   rH   r   r   r*   r)   )r   rH   r   rH   r   r   r*   r)   );
__future__r   r@   	functoolsr   r~   r   r_   rx   collections.abcr   r   r   r   r   r   r	   	importlibr
   pkgutilr   pydocr   r   typesr   r   typingr   r   r   r   rE   run_evalrw   allow_missing_referenceZOPTIONAL_IMPORT_MSG_FMT__all__r   r   rn   r   r   r   r   r   r   rd   r   r   r   r   	lru_cacher   r   r    r!   r"   r  r.   r.   r.   r4   <module>   sn   
S
 9
 
'




2