U
    Ph[I                     @  s  d dl mZ d dlmZ d dlmZ d dlZd dlZd dl	m
Z
mZ d dlmZ d dlmZmZ dd	d
dddddddddddddddddddddd d!d"d#d$d%d&gZd'd'd(d)d&Zdkd-d'd.d/d0dZd'd1d1d'd2d3d	Zd4d
 Zd'd'd5d6dZdld'd7d.d8d9d:dZdmd'd'd;d<dZd-d-d5d=dZdnd-d7d-d?d@dZd'd'd(dAdZd'd'd5dBdZd'dCdDdZd'dCdEdZd'd'd(dFdZd'd1d'dGdHdZ d'd'd'd/dIdZ!dodJdKd'dLdMdZ"dpd'd'd5dNdZ#d'd'd(dOdZ$dqd-d'd-dPdQdZ%drd'dKd7d'dRdSdZ&d'd'd(dTdZ'edUZ(dVdWd(dXdZ)dYdKd-dZd[dZ*dsd-dKd.d-d]d^dZ+d-d-d(d_d Z,d-d-d(d`daZ-dtd-dbd-dZdcd!Z.dud-dbd-dZddd$Z/dvd-dbd-dZded#Z0dwd-dbd-dZdfd"Z1dxd-dbd.d-dgdhd%Z2dyd-dbd-dZdidjZ3dS )z    )annotations)Sequence)TypeVarN)NdarrayOrTensorNdarrayTensor)is_module_ver_at_least)convert_data_typeconvert_to_dst_typeallclosemoveaxisin1dclip
percentilewhereargwhereargsortnonzerofloor_divideunravel_indexunravel_indicesravel	any_np_ptmaximumconcatenatecumsumisfinitesearchsortedrepeatisnanascontiguousarraystackmodeuniquemaxminmedianmeanstdsoftplusr   )xreturnc                 C  s0   t | tjrtt| | S tt| | S )zstable softplus through `np.logaddexp` with equivalent implementation for torch.

    Args:
        x: array/tensor.

    Returns:
        Softplus of the input.
    )
isinstancenpndarray	logaddexp
zeros_liketorchr)    r2   e/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/transforms/utils_pytorch_numpy_unification.pyr(   ;   s    	h㈵>:0yE>Fr   bool)abr*   c                 C  sF   t || dd^}}t| tjr2tj| ||||dS tj| ||||dS )z7`np.allclose` with equivalent implementation for torch.T)wrap_sequence)rtolatol	equal_nan)r	   r+   r,   r-   r
   r0   )r7   r8   r:   r;   r<   _r2   r2   r3   r
   I   s    zint | Sequence[int])r)   srcdstr*   c                 C  s(   t | tjrt| ||S t| ||S )z `moveaxis` for pytorch and numpy)r+   r0   Tensormovedimr,   r   )r)   r>   r?   r2   r2   r3   r   Q   s    c                 C  s<   t | tjrt| |S | d tj|| jdkddS )z3`np.in1d` with equivalent implementation for torch.).Ndevice)	r+   r,   r-   r   r0   tensorrC   anyview)r)   yr2   r2   r3   r   X   s    )r7   r*   c                 C  s.   t | tjrt| ||}nt| ||}|S )z3`np.clip` with equivalent implementation for torch.)r+   r,   r-   r   r0   clamp)r7   a_mina_maxresultr2   r2   r3   r   _   s    z
int | NonezNdarrayOrTensor | float | int)r)   dimkeepdimr*   c                 K  s   t |tjddd }|dk |dkB  r8td| dt| tjs^t| tjrt| dkrt | tjdd }tj	||f||d	|}t
|| d }n$t
|d
 | d }tj| |||d}|S )a(  `np.percentile` with equivalent implementation for torch.

    Pytorch uses `quantile`. For more details please refer to:
    https://pytorch.org/docs/stable/generated/torch.quantile.html.
    https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.

    Args:
        x: input data.
        q: percentile to compute (should in range 0 <= q <= 100).
        dim: the dim along which the percentiles are computed. default is to compute the percentile
            along a flattened version of the array.
        keepdim: whether the output data has dim retained or not.
        kwargs: if `x` is numpy array, additional args for `np.percentile`, more details:
            https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.

    Returns:
        Resulting value (scalar)
    T)output_typer9   r   d   z*q values must be in [0, 100], got values: .i@B )rO   )axiskeepdimsg      Y@)rM   rN   )r   r,   r-   rF   
ValueErrorr+   r0   r@   numelr   r	   quantile)r)   qrM   rN   kwargsZq_np_xrL   r2   r2   r3   r   i   s    &)	conditionr*   c                 C  sz   t | tjr0|dk	r$t| ||}qvt| }nF|dk	rltj|| jd}tj|| j|jd}t| ||}n
t| }|S )zA
    Note that `torch.where` may convert y.dtype to x.dtype.
    NrB   )rC   dtype)r+   r,   r-   r   r0   	as_tensorrC   r[   )rZ   r)   rH   rL   r2   r2   r3   r      s    
c                 C  s    t | tjrt| S t| S )a  `np.argwhere` with equivalent implementation for torch.

    Args:
        a: input data.

    Returns:
        Indices of elements that are non-zero. Indices are grouped by element.
        This array will have shape (N, a.ndim) where N is the number of non-zero items.
    )r+   r,   r-   r   r0   )r7   r2   r2   r3   r      s    

rD   )r7   rR   r*   c                 C  s(   t | tjrtj| |dS tj| |dS )z`np.argsort` with equivalent implementation for torch.

    Args:
        a: the array/tensor to sort.
        axis: axis along which to sort.

    Returns:
        Array/Tensor of indices that sort a along the specified axis.
    rR   )rM   )r+   r,   r-   r   r0   )r7   rR   r2   r2   r3   r      s    
c                 C  s(   t | tjrt| d S t|  S )z`np.nonzero` with equivalent implementation for torch.

    Args:
        x: array/tensor.

    Returns:
        Index unravelled for given shape
    r   )r+   r,   r-   r   r0   flattenr1   r2   r2   r3   r      s    	c                 C  s>   t | tjr2ttdr&tj| |ddS t| |S t| |S )aD  `np.floor_divide` with equivalent implementation for torch.

    As of pt1.8, use `torch.div(..., rounding_mode="floor")`, and
    before that, use `torch.floor_divide`.

    Args:
        a: first array/tensor
        b: scalar to divide by

    Returns:
        Element-wise floor division between two arrays/tensors.
    )      r   floor)rounding_mode)r+   r0   r@   r   divr   r,   r7   r8   r2   r2   r3   r      s
    
)r*   c                 C  s\   t | tjrJg }t|D ]}|| |  t| |} qt|ddd S tt	| |S )z`np.unravel_index` with equivalent implementation for torch.

    Args:
        idx: index to unravel.
        shape: shape of array/tensor.

    Returns:
        Index unravelled for given shape
    NrD   )
r+   r0   r@   reversedappendr   r    r,   asarrayr   )idxshapecoordrM   r2   r2   r3   r      s    
c                   s2   t | d tjrtjntj}| fdd| D S )zComputing unravel coordinates from indices.

    Args:
        idx: a sequence of indices to unravel.
        shape: shape of array/tensor.

    Returns:
        Stacked indices unravelled for given shape
    r   c                   s   g | ]}t | qS r2   )r   ).0iri   r2   r3   
<listcomp>   s     z#unravel_indices.<locals>.<listcomp>)r+   r0   r@   r    r,   )rh   ri   Z	lib_stackr2   rm   r3   r      s    
c                 C  s4   t | tjr*ttdr|  S |   S t| S )z`np.ravel` with equivalent implementation for torch.

    Args:
        x: array/tensor to ravel.

    Returns:
        Return a contiguous flattened array/tensor.
    r   )r+   r0   r@   hasattrr   r^   
contiguousr,   r1   r2   r2   r3   r     s
    	
)r)   rR   r*   c              	   C  sp   t | tjrt| |S t |ts(|gn|}|D ]:}zt| |} W q0 tk
rh   t|  |} Y q0X q0| S )a  `np.any` with equivalent implementation for torch.

    For pytorch, convert to boolean for compatibility with older versions.

    Args:
        x: input array/tensor.
        axis: axis to perform `any` over.

    Returns:
        Return a contiguous flattened array/tensor.
    )r+   r,   r-   rF   r   r0   RuntimeErrorr6   )r)   rR   axr2   r2   r3   r     s    c                 C  s0   t | tjr$t |tjr$t| |S t| |S )z`np.maximum` with equivalent implementation for torch.

    Args:
        a: first array/tensor.
        b: second array/tensor.

    Returns:
        Element-wise maximum between two arrays/tensors.
    )r+   r0   r@   r   r,   rd   r2   r2   r3   r   +  s    
zSequence[NdarrayOrTensor]int)to_catrR   r*   c                 C  s.   t | d tjrt| ||S tj| ||dS )zH`np.concatenate` with equivalent implementation for torch (`torch.cat`).r   )rM   out)r+   r,   r-   r   r0   cat)rt   rR   ru   r2   r2   r3   r   :  s    c                 K  sN   t | tjrt| |S |dkr8tj| dd df|S tj| fd|i|S )aH  
    `np.cumsum` with equivalent implementation for torch.

    Args:
        a: input data to compute cumsum.
        axis: expected axis to compute cumsum.
        kwargs: if `a` is PyTorch Tensor, additional args for `torch.cumsum`, more details:
            https://pytorch.org/docs/stable/generated/torch.cumsum.html.

    Nr   rM   )r+   r,   r-   r   r0   )r7   rR   rX   r2   r2   r3   r   A  s
    c                 C  s    t | tjst| S t| S )z7`np.isfinite` with equivalent implementation for torch.)r+   r0   r@   r,   r   r1   r2   r2   r3   r   T  s    
)r7   vr*   c                 K  s@   |rdnd}t | tjr(t| |||S tj| |fd|i|S )ay  
    `np.searchsorted` with equivalent implementation for torch.

    Args:
        a: numpy array or tensor, containing monotonically increasing sequence on the innermost dimension.
        v: containing the search values.
        right: if False, return the first suitable location that is found, if True, return the last such index.
        sorter: if `a` is numpy array, optional array of integer indices that sort array `a` into ascending order.
        kwargs: if `a` is PyTorch Tensor, additional args for `torch.searchsorted`, more details:
            https://pytorch.org/docs/stable/generated/torch.searchsorted.html.

    rightleft)r+   r,   r-   r   r0   )r7   rw   rx   sorterrX   sider2   r2   r3   r   [  s    )r7   repeatsrR   r*   c                 K  s2   t | tjrt| ||S tj| |fd|i|S )a  
    `np.repeat` with equivalent implementation for torch (`repeat_interleave`).

    Args:
        a: input data to repeat.
        repeats: number of repetitions for each element, repeats is broadcast to fit the shape of the given axis.
        axis: axis along which to repeat values.
        kwargs: if `a` is PyTorch Tensor, additional args for `torch.repeat_interleave`, more details:
            https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html.

    rM   )r+   r,   r-   r   r0   repeat_interleave)r7   r|   rR   rX   r2   r2   r3   r   n  s    c                 C  s    t | tjrt| S t| S )z^`np.isnan` with equivalent implementation for torch.

    Args:
        x: array/tensor.

    )r+   r,   r-   r   r0   r1   r2   r2   r3   r     s    
TzNdarrayTensor | TzNdarrayOrTensor | Tc                 K  s@   t | tjr$| jdkr| S t| S t | tjr<| jf |S | S )a-  `np.ascontiguousarray` with equivalent implementation for torch (`contiguous`).

    Args:
        x: array/tensor.
        kwargs: if `x` is PyTorch Tensor, additional args for `torch.contiguous`, more details:
            https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html.

    r   )r+   r,   r-   ndimr   r0   r@   rp   r)   rX   r2   r2   r3   r     s    	

zSequence[NdarrayTensor])r)   rM   r*   c                 C  s(   t | d tjrt| |S t| |S )z`np.stack` with equivalent implementation for torch.

    Args:
        x: array/tensor.
        dim: dimension along which to perform the stack (referred to as `axis` by numpy).
    r   )r+   r,   r-   r    r0   )r)   rM   r2   r2   r3   r      s    T)r)   rM   to_longr*   c                 C  sB   |r
t jnd}t| t j|d^}}t ||j}t|| ^}}|S )z`torch.mode` with equivalent implementation for numpy.

    Args:
        x: array/tensor.
        dim: dimension along which to perform `mode` (referred to as `axis` by numpy).
        to_long: convert input to long before performing mode.
    N)r[   )r0   int64r   r@   r!   valuesr	   )r)   rM   r   r[   Zx_tr=   o_tor2   r2   r3   r!     s
    c                 K  s,   t | tjtfrtj| f|S tj| f|S )za`torch.unique` with equivalent implementation for numpy.

    Args:
        x: array/tensor.
    )r+   r,   r-   listr"   r0   r   r2   r2   r3   r"     s    c                 C  sD   t | tjr ttdr t| S t | tjr8tj| S tj| S )ze`torch.linalg.inv` with equivalent implementation for numpy.

    Args:
        x: array/tensor.
    inverse)r+   r0   r@   ro   r   linalginvr,   r1   r2   r2   r3   
linalg_inv  s    
r   zint | tuple | Nonec                 K  sv   |dkr6t | tjtfr&tj| f|ntj| f|}n<t | tjtfr^tj| fd|i|}ntj| t|f|}|S )z`torch.max` with equivalent implementation for numpy

    Args:
        x: array/tensor.

    Returns:
        the maximum of x.

    NrR   )r+   r,   r-   r   r#   r0   rs   r)   rM   rX   retr2   r2   r3   r#     s    .c                 K  sv   |dkr6t | tjtfr&tj| f|ntj| f|}n<t | tjtfr^tj| fd|i|}ntj| t|f|}|S )z`torch.mean` with equivalent implementation for numpy

    Args:
        x: array/tensor.

    Returns:
        the mean of x
    NrR   )r+   r,   r-   r   r&   r0   rs   r   r2   r2   r3   r&     s    .c                 K  sv   |dkr6t | tjtfr&tj| f|ntj| f|}n<t | tjtfr^tj| fd|i|}ntj| t|f|}|S )z`torch.median` with equivalent implementation for numpy

    Args:
        x: array/tensor.

    Returns
        the median of x.
    NrR   )r+   r,   r-   r   r%   r0   rs   r   r2   r2   r3   r%     s    .c                 K  sv   |dkr6t | tjtfr&tj| f|ntj| f|}n<t | tjtfr^tj| fd|i|}ntj| t|f|}|S )z`torch.min` with equivalent implementation for numpy

    Args:
        x: array/tensor.

    Returns:
        the minimum of x.
    NrR   )r+   r,   r-   r   r$   r0   rs   r   r2   r2   r3   r$     s    .)r)   rM   unbiasedr*   c                 C  sf   |dkr0t | tjtfr"t| n
t| |}n2t | tjtfrPtj| |d}nt| t||}|S )z`torch.std` with equivalent implementation for numpy

    Args:
        x: array/tensor.

    Returns:
        the standard deviation of x.
    Nr]   )r+   r,   r-   r   r'   r0   rs   )r)   rM   r   r   r2   r2   r3   r'   (  s    (c                 K  sv   |dkr6t | tjtfr&tj| f|ntj| f|}n<t | tjtfr^tj| fd|i|}ntj| t|f|}|S )z`torch.sum` with equivalent implementation for numpy

    Args:
        x: array/tensor.

    Returns:
        the sum of x.
    NrR   )r+   r,   r-   r   sumr0   rs   r   r2   r2   r3   r   >  s    .r   )r4   r5   F)NF)NN)rD   )r   N)N)FN)N)rD   T)N)N)N)N)NF)N)4
__future__r   collections.abcr   typingr   numpyr,   r0   monai.config.type_definitionsr   r   monai.utils.miscr   Zmonai.utils.type_conversionr   r	   __all__r(   r
   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r~   r   r    r!   r"   r   r#   r&   r%   r$   r'   r   r2   r2   r2   r3   <module>   s   #   #	