o
    iI                     @  s  d dl mZ d dlmZ d dlmZ d dlZd dlZd dl	m
Z
mZ d dlmZmZ g dZdkddZdldmddZdnddZdd ZdoddZ	dpdqd#d$Zdrdsd&d'Zdtd(d)Zdudvd,d-Zdkd.d/Zdod0d1Zdwd2d3Zdwd4d5Zdkd6d7Zdxd8d9Zdyd:d;Zdzd{d?d@Z d|dodAdBZ!dkdCdDZ"d}d~dFdGZ#d|ddIdJZ$dkdKdLZ%edMZ&ddPdQZ'ddSdTZ(dddWdXZ)ddYdZZ*dd[d\Z+d|dd^d_Z,d|dd`daZ-d|ddbdcZ.d|ddddeZ/dpddgdhZ0d|ddidjZ1dS )    )annotations)Sequence)TypeVarN)NdarrayOrTensorNdarrayTensor)convert_data_typeconvert_to_dst_type)allclosemoveaxisin1dclip
percentilewhereargwhereargsortnonzerofloor_divideunravel_indexunravel_indicesravel	any_np_ptmaximumconcatenatecumsumisfinitesearchsortedrepeatisnanascontiguousarraystackmodeuniquemaxminmedianmeanstdsoftplusxr   returnc                 C  s0   t | tjrtt| | S tt| | S )zstable softplus through `np.logaddexp` with equivalent implementation for torch.

    Args:
        x: array/tensor.

    Returns:
        Softplus of the input.
    )
isinstancenpndarray	logaddexp
zeros_liketorchr(    r1   r/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/transforms/utils_pytorch_numpy_unification.pyr'   :   s   	r'   h㈵>:0yE>Far   bboolc                 C  sF   t || dd^}}t| tjrtj| ||||dS tj| ||||dS )z7`np.allclose` with equivalent implementation for torch.T)wrap_sequence)rtolatol	equal_nan)r   r*   r+   r,   r	   r/   )r5   r6   r9   r:   r;   _r1   r1   r2   r	   H   s   r	   srcint | Sequence[int]dstc                 C  s(   t | tjrt| ||S t| ||S )z `moveaxis` for pytorch and numpy)r*   r/   Tensormovedimr+   r
   )r(   r=   r?   r1   r1   r2   r
   P   s   r
   c                 C  s<   t | tjrt| |S | d tj|| jdkddS )z3`np.in1d` with equivalent implementation for torch.).Ndevice)	r*   r+   r,   isinr/   tensorrC   anyview)r(   yr1   r1   r2   r   W   s   $r   c                 C  s0   t | tjrt| ||}|S t| ||}|S )z3`np.clip` with equivalent implementation for torch.)r*   r+   r,   r   r/   clamp)r5   a_mina_maxresultr1   r1   r2   r   ^   s
   r   dim
int | NonekeepdimNdarrayOrTensor | float | intc                 K  s   t |tjddd }|dk |dkB  rtd| dt| tjs/t| tjrNt| dkrNt | tjdd }tj	||f||d	|}t
|| d }|S t
|d
 | d }tj| |||d}|S )a(  `np.percentile` with equivalent implementation for torch.

    Pytorch uses `quantile`. For more details please refer to:
    https://pytorch.org/docs/stable/generated/torch.quantile.html.
    https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.

    Args:
        x: input data.
        q: percentile to compute (should in range 0 <= q <= 100).
        dim: the dim along which the percentiles are computed. default is to compute the percentile
            along a flattened version of the array.
        keepdim: whether the output data has dim retained or not.
        kwargs: if `x` is numpy array, additional args for `np.percentile`, more details:
            https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.

    Returns:
        Resulting value (scalar)
    T)output_typer8   r   d   z*q values must be in [0, 100], got values: .i@B )rR   )axiskeepdimsg      Y@)rN   rP   )r   r+   r,   rG   
ValueErrorr*   r/   r@   numelr   r   quantile)r(   qrN   rP   kwargsZq_np_xrM   r1   r1   r2   r   h   s   &r   	conditionc                 C  s   t | tjr|durt| ||}|S t| }|S |dur9tj|| jd}tj|| j|jd}t| ||}|S t| }|S )zA
    Note that `torch.where` may convert y.dtype to x.dtype.
    NrB   )rC   dtype)r*   r+   r,   r   r/   	as_tensorrC   r^   )r]   r(   rI   rM   r1   r1   r2   r      s   


r   c                 C      t | tjrt| S t| S )a  `np.argwhere` with equivalent implementation for torch.

    Args:
        a: input data.

    Returns:
        Indices of elements that are non-zero. Indices are grouped by element.
        This array will have shape (N, a.ndim) where N is the number of non-zero items.
    )r*   r+   r,   r   r/   )r5   r1   r1   r2   r      s   


r   rD   rU   c                 C  s(   t | tjrtj| |dS tj| |dS )z`np.argsort` with equivalent implementation for torch.

    Args:
        a: the array/tensor to sort.
        axis: axis along which to sort.

    Returns:
        Array/Tensor of indices that sort a along the specified axis.
    rU   )rN   )r*   r+   r,   r   r/   )r5   rU   r1   r1   r2   r      s   
r   c                 C  s(   t | tjrt| d S t|  S )z`np.nonzero` with equivalent implementation for torch.

    Args:
        x: array/tensor.

    Returns:
        Index unravelled for given shape
    r   )r*   r+   r,   r   r/   flattenr0   r1   r1   r2   r      s   	r   c                 C  s*   t | tjrt| |S tt| |S )aD  `np.floor_divide` with equivalent implementation for torch.

    As of pt1.8, use `torch.div(..., rounding_mode="floor")`, and
    before that, use `torch.floor_divide`.

    Args:
        a: first array/tensor
        b: scalar to divide by

    Returns:
        Element-wise floor division between two arrays/tensors.
    )r*   r/   r@   r   r+   asarrayr5   r6   r1   r1   r2   r      s   r   c                 C  s\   t | tjr%g }t|D ]}|| |  t| |} qt|ddd S tt	| |S )z`np.unravel_index` with equivalent implementation for torch.

    Args:
        idx: index to unravel.
        shape: shape of array/tensor.

    Returns:
        Index unravelled for given shape
    NrD   )
r*   r/   r@   reversedappendr   r   r+   rc   r   )idxshapecoordrN   r1   r1   r2   r      s   
r   c                   s2   t | d tjrtjntj}| fdd| D S )zComputing unravel coordinates from indices.

    Args:
        idx: a sequence of indices to unravel.
        shape: shape of array/tensor.

    Returns:
        Stacked indices unravelled for given shape
    r   c                   s   g | ]}t | qS r1   )r   ).0irh   r1   r2   
<listcomp>   s    z#unravel_indices.<locals>.<listcomp>)r*   r/   r@   r   r+   )rg   rh   Z	lib_stackr1   rl   r2   r      s   
r   c                 C  s4   t | tjrttdr|  S |   S t| S )z`np.ravel` with equivalent implementation for torch.

    Args:
        x: array/tensor to ravel.

    Returns:
        Return a contiguous flattened array/tensor.
    r   )r*   r/   r@   hasattrr   rb   
contiguousr+   r0   r1   r1   r2   r      s
   	

r   c              	   C  sl   t | tjrt| |S t |ts|gn|}|D ]}zt| |} W q ty3   t|  |} Y qw | S )a  `np.any` with equivalent implementation for torch.

    For pytorch, convert to boolean for compatibility with older versions.

    Args:
        x: input array/tensor.
        axis: axis to perform `any` over.

    Returns:
        Return a contiguous flattened array/tensor.
    )r*   r+   r,   rG   r   r/   RuntimeErrorr7   )r(   rU   axr1   r1   r2   r     s   r   c                 C  s0   t | tjrt |tjrt| |S t| |S )z`np.maximum` with equivalent implementation for torch.

    Args:
        a: first array/tensor.
        b: second array/tensor.

    Returns:
        Element-wise maximum between two arrays/tensors.
    )r*   r/   r@   r   r+   rd   r1   r1   r2   r   )  s   
r   to_catSequence[NdarrayOrTensor]intc                 C  s.   t | d tjrt| ||S tj| ||dS )zH`np.concatenate` with equivalent implementation for torch (`torch.cat`).r   )rN   out)r*   r+   r,   r   r/   cat)rr   rU   ru   r1   r1   r2   r   8  s   r   c                 K  sR   t | tjrt| |S |du rtj| dd dfi |S tj| fd|i|S )aH  
    `np.cumsum` with equivalent implementation for torch.

    Args:
        a: input data to compute cumsum.
        axis: expected axis to compute cumsum.
        kwargs: if `a` is PyTorch Tensor, additional args for `torch.cumsum`, more details:
            https://pytorch.org/docs/stable/generated/torch.cumsum.html.

    Nr   rN   )r*   r+   r,   r   r/   )r5   rU   r[   r1   r1   r2   r   ?  s
   r   c                 C  s    t | tjst| S t| S )z7`np.isfinite` with equivalent implementation for torch.)r*   r/   r@   r+   r   r0   r1   r1   r2   r   R  s   

r   vc                 K  s@   |rdnd}t | tjrt| |||S tj| |fd|i|S )ay  
    `np.searchsorted` with equivalent implementation for torch.

    Args:
        a: numpy array or tensor, containing monotonically increasing sequence on the innermost dimension.
        v: containing the search values.
        right: if False, return the first suitable location that is found, if True, return the last such index.
        sorter: if `a` is numpy array, optional array of integer indices that sort array `a` into ascending order.
        kwargs: if `a` is PyTorch Tensor, additional args for `torch.searchsorted`, more details:
            https://pytorch.org/docs/stable/generated/torch.searchsorted.html.

    rightleft)r*   r+   r,   r   r/   )r5   rw   rx   sorterr[   sider1   r1   r2   r   Y  s   r   repeatsc                 K  s2   t | tjrt| ||S tj| |fd|i|S )a  
    `np.repeat` with equivalent implementation for torch (`repeat_interleave`).

    Args:
        a: input data to repeat.
        repeats: number of repetitions for each element, repeats is broadcast to fit the shape of the given axis.
        axis: axis along which to repeat values.
        kwargs: if `a` is PyTorch Tensor, additional args for `torch.repeat_interleave`, more details:
            https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html.

    rN   )r*   r+   r,   r   r/   repeat_interleave)r5   r|   rU   r[   r1   r1   r2   r   l  s   r   c                 C  r`   )z^`np.isnan` with equivalent implementation for torch.

    Args:
        x: array/tensor.

    )r*   r+   r,   r   r/   r0   r1   r1   r2   r   }  s   

r   TNdarrayTensor | TNdarrayOrTensor | Tc                 K  sD   t | tjr| jdkr| S t| S t | tjr | jdi |S | S )a-  `np.ascontiguousarray` with equivalent implementation for torch (`contiguous`).

    Args:
        x: array/tensor.
        kwargs: if `x` is PyTorch Tensor, additional args for `torch.contiguous`, more details:
            https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html.

    r   Nr1   )r*   r+   r,   ndimr   r/   r@   ro   r(   r[   r1   r1   r2   r     s   	

r   Sequence[NdarrayTensor]c                 C  s(   t | d tjrt| |S t| |S )z`np.stack` with equivalent implementation for torch.

    Args:
        x: array/tensor.
        dim: dimension along which to perform the stack (referred to as `axis` by numpy).
    r   )r*   r+   r,   r   r/   )r(   rN   r1   r1   r2   r     s   r   Tto_longc                 C  sB   |rt jnd}t| t j|d^}}t ||j}t|| ^}}|S )z`torch.mode` with equivalent implementation for numpy.

    Args:
        x: array/tensor.
        dim: dimension along which to perform `mode` (referred to as `axis` by numpy).
        to_long: convert input to long before performing mode.
    N)r^   )r/   int64r   r@   r    valuesr   )r(   rN   r   r^   Zx_tr<   Zo_tor1   r1   r2   r      s
   r    c                 K  s4   t | tjtfrtj| fi |S tj| fi |S )za`torch.unique` with equivalent implementation for numpy.

    Args:
        x: array/tensor.
    )r*   r+   r,   listr!   r/   r   r1   r1   r2   r!     s   4r!   c                 C  sD   t | tjrttdrt| S t | tjrtj| S tj| S )ze`torch.linalg.inv` with equivalent implementation for numpy.

    Args:
        x: array/tensor.
    inverse)r*   r/   r@   rn   r   linalginvr+   r0   r1   r1   r2   
linalg_inv  s   
$r   int | tuple | Nonec                 K     |du rt | tjtfrtj| fi |ntj| fi |}n t | tjtfr3tj| fd|i|}ntj| t|fi |}t |trH|d S |S )z`torch.max` with equivalent implementation for numpy

    Args:
        x: array/tensor.

    Returns:
        the maximum of x.

    NrU   r   )r*   r+   r,   r   r"   r/   rt   tupler(   rN   r[   retr1   r1   r2   r"     s   6r"   c                 K     |du r"t | tjtfrtj| fi |}|S tj| fi |}|S t | tjtfr7tj| fd|i|}|S tj| t|fi |}|S )z`torch.mean` with equivalent implementation for numpy

    Args:
        x: array/tensor.

    Returns:
        the mean of x
    NrU   )r*   r+   r,   r   r%   r/   rt   r   r1   r1   r2   r%        "r%   c                 K  r   )z`torch.median` with equivalent implementation for numpy

    Args:
        x: array/tensor.

    Returns
        the median of x.
    NrU   )r*   r+   r,   r   r$   r/   rt   r   r1   r1   r2   r$     r   r$   c                 K  r   )z`torch.min` with equivalent implementation for numpy

    Args:
        x: array/tensor.

    Returns:
        the minimum of x.
    NrU   r   )r*   r+   r,   r   r#   r/   rt   r   r   r1   r1   r2   r#     s   6r#   unbiasedc                 C  sn   |du rt | tjtfrt| }|S t| |}|S t | tjtfr,tj| |d}|S t| t||}|S )z`torch.std` with equivalent implementation for numpy

    Args:
        x: array/tensor.

    Returns:
        the standard deviation of x.
    Nra   )r*   r+   r,   r   r&   r/   rt   )r(   rN   r   r   r1   r1   r2   r&   &  s   r&   c                 K  r   )z`torch.sum` with equivalent implementation for numpy

    Args:
        x: array/tensor.

    Returns:
        the sum of x.
    NrU   )r*   r+   r,   r   sumr/   rt   r   r1   r1   r2   r   <  r   r   )r(   r   r)   r   )r3   r4   F)r5   r   r6   r   r)   r7   )r(   r   r=   r>   r?   r>   r)   r   )r5   r   r)   r   )NF)r(   r   rN   rO   rP   r7   r)   rQ   )NN)r]   r   r)   r   )r5   r   r)   r   )rD   )r5   r   rU   rO   r)   r   )r)   r   )r(   r   rU   r>   r)   r   )r5   r   r6   r   r)   r   )r   N)rr   rs   rU   rt   r)   r   )N)FN)r5   r   rw   r   r)   r   )r5   r   r|   rt   rU   rO   r)   r   )r(   r   r)   r   )r(   r   rN   rt   r)   r   )rD   T)r(   r   rN   rt   r   r7   r)   r   )r(   r   r)   r   )r(   r   rN   r   r)   r   )r(   r   rN   r   r   r7   r)   r   )2
__future__r   collections.abcr   typingr   numpyr+   r/   monai.config.type_definitionsr   r   Zmonai.utils.type_conversionr   r   __all__r'   r	   r
   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r~   r   r   r    r!   r   r"   r%   r$   r#   r&   r   r1   r1   r1   r2   <module>   sV   
#

#













	