U
    PhVk                     @  s   d dl mZ d dlZd dlZd dlmZ d dlmZmZ d dl	Z
d dlZd dlZd dlmZ d dlmZmZ d dlmZmZmZmZ d dlmZ d d	lmZmZmZmZ d d
lmZm Z m!Z!m"Z" dgZ#e$ddd Z%dd Z&G dd deej'Z(dS )    )annotationsN)deepcopy)AnySequence)NdarrayTensor)MetaObjget_track_meta)affine_to_spacingdecollate_batchlist_data_collateremove_extra_metadata)look_up_option)LazyAttrMetaKeysPostFix	SpaceKeys)convert_data_typeconvert_to_dst_typeconvert_to_numpyconvert_to_tensor
MetaTensorc                 C  sH   t tdrDt | drDt tj| jrDtttj| jtrDttj| jS d S )Nreturn_types__name__)hasattrtorchr   r   
isinstancegetattrtype)func r   K/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/data/meta_tensor.py_get_named_tuple_like_type!   s    r!   c              	   C  sH   t | ttttjtjtjtj	fpFt | t
pDt | toDtdd | D  S )Nc                 s  s   | ]}t |tV  qd S )N)r   r   .0xr   r   r    	<genexpr>/   s     z*_not_requiring_metadata.<locals>.<genexpr>)r   intstrbytesr   Sizedtypedevicenpndarrayr   r   any)retr   r   r    _not_requiring_metadata-   s     $r0   c                      s  e Zd ZdZedQdddd dddZdRdddd	d fd
dZedddddZedd Z	edSdd fddZ
edd Zdd Zdd ZeejfddddZddd d!Zejddfd"d#ZdTd%d&d'd(Zed)d* Zejd	dd+d*Zejdfd,d-d.d/d0ZdUd1d2Zeddd3d4Zejd5d	d6d7d4Zed8d9 Zd:d; Zd<d= Zd>d? Z dVd@dAZ!dBdC Z"edWd5dd%dEd,dFdGdHZ#dIdJ Z$dKdL Z%dMdN Z&d	ddOdPZ'  Z(S )Xr   a	  
    Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for metadata.

    Metadata is stored in the form of a dictionary. Nested, an affine matrix will be
    stored. This should be in the form of `torch.Tensor`.

    Behavior should be the same as `torch.Tensor` aside from the extended
    meta functionality.

    Copying of information:

        * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the
          first instance of `MetaTensor` if `a.is_batch` is False
          (For batched data, the metadata will be shallow copied for efficiency purposes).

    Example:
        .. code-block:: python

            import torch
            from monai.data import MetaTensor

            t = torch.tensor([1,2,3])
            affine = torch.as_tensor([[2,0,0,0],
                                      [0,2,0,0],
                                      [0,0,2,0],
                                      [0,0,0,1]], dtype=torch.float64)
            meta = {"some": "info"}
            m = MetaTensor(t, affine=affine, meta=meta)
            m2 = m + m
            assert isinstance(m2, MetaTensor)
            assert m2.meta["some"] == "info"
            assert torch.all(m2.affine == affine)

    Notes:
        - Requires pytorch 1.9 or newer for full compatibility.
        - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may
          not work if `im` is of type `MetaTensor`. This can be resolved with
          `torch.jit.trace(net, im.as_tensor())`.
        - For pytorch < 1.8, sharing `MetaTensor` instances across processes may not be supported.
        - For pytorch < 1.9, next(iter(meta_tensor)) returns a torch.Tensor.
          see: https://github.com/pytorch/pytorch/issues/54457
        - A warning will be raised if in the constructor `affine` is not `None` and
          `meta` already contains the key `affine`.
        - You can query whether the `MetaTensor` is a batch with the `is_batch` attribute.
        - With a batch of data, `batch[0]` will return the 0th image
          with the 0th metadata. When the batch dimension is non-singleton, e.g.,
          `batch[:, 0]`, `batch[..., -1]` and `batch[1:3]`, then all (or a subset in the
          last example) of the metadata will be returned, and `is_batch` will return `True`.
        - When creating a batch with this class, use `monai.data.DataLoader` as opposed
          to `torch.utils.data.DataLoader`, as this will take care of collating the
          metadata properly.
    Nztorch.Tensor | Nonezdict | Nonezlist | None)affinemetaapplied_operationsreturnc                 O  s:   |r| dd | dd dni }tj|f||| S )Nr+   r*   r+   r*   )popr   	as_tensoras_subclass)clsr$   r1   r2   r3   argskwargs_kwargsr   r   r    __new__i   s    
"zMetaTensor.__new__Nonec                   s   t    |dk	r|| _nt|tr0t|j| _|dk	rVtj| jkrNt	
d || _n&tj| jkrr| jtj | _n
|  | _|dk	r|| _n
t | _t|tjrt|ts| |  tj| jkrtj| jtj< dS )a  
        Args:
            x: initial array for the MetaTensor. Can be a list, tuple, NumPy ndarray, scalar, and other types.
            affine: optional 4x4 array.
            meta: dictionary of metadata.
            applied_operations: list of previously applied operations on the MetaTensor,
                the list is typically maintained by `monai.transforms.TraceableTransform`.
                See also: :py:class:`monai.transforms.TraceableTransform`
            _args: additional args (currently not in use in this constructor).
            _kwargs: additional kwargs (currently not in use in this constructor).

        Note:
            If a `meta` dictionary is given, use it. Else, if `meta` exists in the input tensor `x`, use it.
            Else, use the default value. Similar for the affine, except this could come from
            four places, priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`.

        NzRSetting affine, but the applied meta contains an affine. This will be overwritten.)super__init__r2   r   r   r   __dict__r   AFFINEwarningswarnr1   get_default_affiner3   get_default_applied_operationsr   Tensorr   copy_meta_fromSPACEr   RAS)selfr$   r1   r2   r3   _argsr<   	__class__r   r    r@   v   s&    





zMetaTensor.__init__r   )retsr4   c           
   	   C  s   g }d}t dd t|| D }t| D ]l\}}t|tsBnNt sR| }n>t|| }	||_	|j
|	| d |rt||||||}|| q.t| trt|S |S )a  
        Update the metadata from the output of `MetaTensor.__torch_function__`.

        The output of `torch.Tensor.__torch_function__` could be a single object or a
        sequence of them. Hence, in `MetaTensor.__torch_function__` we convert them to a
        list of not already, and then we loop across each element, processing metadata
        as necessary. For each element, if not of type `MetaTensor`, then nothing to do.

        Args:
            rets: the output from `torch.Tensor.__torch_function__`, which has been
                converted to a list in `MetaTensor.__torch_function__` if it wasn't
                already a `Sequence`.
            func: the torch function that was applied. Examples might be `torch.squeeze`
                or `torch.Tensor.__add__`. We need this since the metadata need to be
                treated differently if a batch of data is considered. For example,
                slicing (`torch.Tensor.__getitem__`) the ith element of the 0th
                dimension of a batch of data should return a ith tensor with the ith
                metadata.
            args: positional arguments that were passed to `func`.
            kwargs: keyword arguments that were passed to `func`.

        Returns:
            A sequence with the same number of elements as `rets`. For each element, if
            the input type was not `MetaTensor`, then no modifications will have been
            made. If global parameters have been set to false (e.g.,
            `not get_track_meta()`), then any `MetaTensor` will be converted to
            `torch.Tensor`. Else, metadata will be propagated as necessary (see
            :py:func:`MetaTensor._copy_meta`).
        Nc                 s  s   | ]}t |d r|jV  qdS )is_batchN)r   rP   r"   r   r   r    r%      s     
 z)MetaTensor.update_meta.<locals>.<genexpr>)	copy_attr)r.   r   flatten_meta_objsvalues	enumerater   r   r   r7   rP   rH   _handle_batchedappendtuple)
rO   r   r:   r;   outmetasrP   idxr/   Z	meta_argsr   r   r    update_meta   s    

zMetaTensor.update_metac                 C  s  |t jjkr|dks2t|dk s2t|d dk r6|S t|d trP|d d n|d }|tdddtdfkszt|t jr~|S t|d dd}|| }	t|	t	r|	rzt
|	}	W q ttttfk
r }
 ztd|
W 5 d}
~
X Y qX nt|	trd|	_t|	dr|	j |_n|t jjkrt|dkr8|d }nd	|krL|d	 }nd}|dkr|dkrtt|d dd}t|| dr|| j |_d|_|S )
z/utility function to handle batched MetaTensors.r         NF)detachzInconsistent batched metadata dicts when slicing a batch of MetaTensors, please consider converting it into a torch Tensor using `x.as_tensor()` or a numpy array using `x.array`.rA   dim)r   rG   __getitem__lenr   r   sliceEllipsisr
   listr   	TypeError
ValueErrorRuntimeError
IndexErrorr   rP   r   rA   copyunbind)r9   r/   rZ   rY   r   r:   r;   	batch_idxZ	dec_batchZret_metaer_   r   r   r    rU      sD    $""





zMetaTensor._handle_batchedr   r   )r4   c           	        s   |dkri }t  ||||}t|r*|S t|dk	rt|t|rt||||}t|jD ]$}|| j	|| _	|| j
|| _
q^|S t|ts|g}d}nd}t||||}|r|d S |S )zWraps all torch functions.NTFr   )r?   __torch_function__r0   r!   r   r   r[   rangen_fieldsr2   r3   r   )	r9   r   typesr:   r;   r/   Z	out_itemsrZ   unpackrM   r   r    rm     s"    
zMetaTensor.__torch_function__c                 C  s,   t | ttjttfr(t| tjddd S | S )NF)output_typewrap_sequencer   )	r   r   r   rG   rW   rd   r   r,   r-   )r$   r   r   r    _convert2  s    zMetaTensor._convertc                 C  s\   z|j dstW S W n tk
r.   t Y S X tttj|}dd | D }|||S )zQfor numpy Interoperability, so that we can compute ``np.sum(MetaTensor([1.0]))``.numpyc                 S  s   i | ]\}}|t |qS r   r   rt   r#   kvr   r   r    
<dictcomp>@  s      z1MetaTensor.__array_function__.<locals>.<dictcomp>)	
__module__
startswithNotImplementedAttributeErrorrd   mapr   rt   items)rK   r   rp   r:   r;   rL   r<   r   r   r    __array_function__8  s    

zMetaTensor.__array_function__c                 O  s   zt |jdstW S W n tk
r2   t Y S X |dkr@tS ttj|}dd | D }d|krjtS zt	||||W S  tk
r   t Y S X dS )z
        For numpy interoperability, so that we can compute ``MetaTensor([1.0]) >= np.asarray([1.0])``.
        This is for pytorch > 1.8.
        ru   __call__c                 S  s   i | ]\}}|t |qS r   rv   rw   r   r   r    rz   P  s      z.MetaTensor.__array_ufunc__.<locals>.<dictcomp>rX   N)
r   r{   r|   r}   r~   r   r   rt   r   r   )rK   ufuncmethodinputsr;   Z_inputsr<   r   r   r    __array_ufunc__C  s    

zMetaTensor.__array_ufunc__ztorch.Tensorc                 C  s   t jdt d| dS )N   cpur5   )r   eyer+   r*   r   r   r    rE   X  s    zMetaTensor.get_default_affinec                 C  s   |  tjS )z
        Return the `MetaTensor` as a `torch.Tensor`.
        It is OS dependent as to whether this will be a deep copy or not.
        )r8   r   rG   rK   r   r   r    r7   \  s    zMetaTensor.as_tensorc                 O  s   t | |||ddd S )a  
        Returns a new array in `output_type`, the array shares the same underlying storage when the output is a
        numpy array. Changes to self tensor will be reflected in the ndarray and vice versa.

        Args:
            output_type: output type, see also: :py:func:`monai.utils.convert_data_type`.
            dtype: dtype of output data. Converted to correct library type (e.g.,
                `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
                If left blank, it remains unchanged.
            device: if the output is a `torch.Tensor`, select device (if `None`, unchanged).
            _args: currently unused parameters.
            _kwargs: currently unused parameters.
        T)rr   r*   r+   rs   r   )r   )rK   rr   r*   r+   rL   r<   r   r   r    	get_arrayc  s    zMetaTensor.get_arrayFboolnon_blockingc                 O  sB   t |ddd}z| j||dW S  tk
r<   || _|  Y S X dS )a  
        Copies the elements from src into self tensor and returns self.
        The src tensor must be broadcastable with the self tensor.
        It may be of a different data type or reside on a different device.

        See also: `https://pytorch.org/docs/stable/generated/torch.Tensor.copy_.html`

        Args:
            src: the source tensor to copy from.
            non_blocking: if True and this copy is between CPU and GPU, the copy may occur
                asynchronously with respect to the host. For other cases, this argument has no effect.
            _args: currently unused parameters.
            _kwargs:  currently unused parameters.
        FT)
track_metars   r   N)r   copy_rg   data)rK   srcr   rL   r<   	convertedr   r   r    	set_arrays  s    zMetaTensor.set_arrayc                 C  s   |   S )a  
        Returns a numpy array of ``self``. The array and ``self`` shares the same underlying storage if self is on cpu.
        Changes to ``self`` (it's a subclass of torch.Tensor) will be reflected in the ndarray and vice versa.
        If ``self`` is not on cpu, the call will move the array to cpu and then the storage is not shared.

        :getter: see also: :py:func:`MetaTensor.get_array()`
        :setter: see also: :py:func:`MetaTensor.set_array()`
        )r   r   r   r   r    array  s    
zMetaTensor.arrayc                 C  s   |  | dS )z+A default setter using ``self.set_array()``N)r   )rK   r   r   r   r    r     s    r'   dict)keyr4   c                 C  sJ   |t jtjfkr td| d|| j||dt|| jt|| j	iS )a  
        Get the object as a dictionary for backwards compatibility.
        This method does not make a deep copy of the objects.

        Args:
            key: Base key to store main data. The key for the metadata will be determined using `PostFix`.
            output_type: `torch.Tensor` or `np.ndarray` for the main data.
            dtype: dtype of output data. Converted to correct library type (e.g.,
                `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
                If left blank, it remains unchanged.

        Return:
            A dictionary consisting of three keys, the main data (stored under `key`) and the metadata.
        z4output_type must be torch.Tensor or np.ndarray, got .)rr   r*   )
r   rG   r,   r-   rf   r   r   r2   
transformsr3   )rK   r   rr   r*   r   r   r    as_dict  s       zMetaTensor.as_dictc                 O  s   t |tr,|dd^}}|s"|n|d }nt|dd}t|dddhdd}|dkr\tj}n|d	krltj}nd
}| j	|||dS )a  
        Cast to ``dtype``, sharing data whenever possible.

        Args:
            dtype: dtypes such as np.float32, torch.float, "np.float32", float.
            device: the device if `dtype` is a torch data type.
            _args: additional args (currently unused).
            _kwargs: additional kwargs (currently unused).

        Returns:
            data array instance
        r   r]   r   r{   r   ru   r,   )default)ru   r,   N)rr   r*   r+   )
r   r'   splitr   r   r   rG   r,   r-   r   )rK   r*   r+   rL   r<   mod_strout_typer   r   r    astype  s    
zMetaTensor.astypec                 C  s   | j tj|  S )zAGet the affine. Defaults to ``torch.eye(4, dtype=torch.float64)``)r2   getr   rB   rE   r   r   r   r    r1     s    zMetaTensor.affiner   )dr4   c                 C  s$   t j|t dt jd| jtj< dS )zSet the affine.r   r5   N)r   r7   r+   float64r2   r   rB   )rK   r   r   r   r    r1     s    c                 C  s    | j rdd | jD S t| jS )zGet the spacingc                 S  s   g | ]}t |qS r   )r	   )r#   ar   r   r    
<listcomp>  s     z%MetaTensor.pixdim.<locals>.<listcomp>)rP   r1   r	   r   r   r   r    pixdim  s    zMetaTensor.pixdimc                 C  sH   d}| j r| j d tjd}|dkrDtt| jdd dd S |S )z
        Get the currently expected spatial shape as if all the pending operations are executed.
        For tensors that have more than 3 spatial dimensions, only the shapes of the top 3 dimensions will be returned.
        NT)rs   r]   )pending_operationsr   r   SHAPErW   r   shapetolist)rK   resr   r   r    peek_pending_shape  s    zMetaTensor.peek_pending_shapec                 C  s   | j }t|d }|dkr,td| d | jD ]T}t|tjt	j
d}|d krVq2t||d }tjj||}tjjj||}q2|S )Nr]   )r\      z)Only 2d and 3d affine are supported, got zd input.r   r   )r1   ra   rC   rD   r   r   r   r   rB   r   r   r   monair   utilsto_affine_ndr   lazyZcombine_transforms)rK   r   rpZnext_matrixr   r   r    peek_pending_affine  s    
zMetaTensor.peek_pending_affinec                 C  sB   | j r| j d tjd n| j}|d kr,dS ttdt|d S )Nr   r]   )r   r   r   rB   r1   r&   maxra   )rK   r   r   r   r    peek_pending_rank  s     zMetaTensor.peek_pending_rankc                 C  s   t | |  j||||dS )z
        must be defined for deepcopy to work

        See:
            - https://pytorch.org/docs/stable/generated/torch.Tensor.new_empty.html#torch-tensor-new-empty
        )sizer*   r+   requires_grad)r   r7   	new_empty)rK   r   r*   r+   r   r   r   r    r     s    zMetaTensor.new_emptyc                 K  s$   t |  jf |}t| j|_|S )z
        Returns a copy of the MetaTensor instance.

        Args:
            kwargs: additional keyword arguments to `torch.clone`.

        See also: https://pytorch.org/docs/stable/generated/torch.clone.html
        )r   r7   cloner   rA   )rK   r;   Znew_instr   r   r    r     s    	zMetaTensor.cloner   z
str | None)imr2   simple_keyspatternsepc                 C  s   t | t o|dk	d}t|ts$|S |dkr0i }|rZtj|krRt |tj |tj< t| |dk	rxtjj	||dd|}|dkri }||_
tj|kr|tj |_n
t |_|S )aX  
        Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary,
        convert that to `torch.Tensor`, too. Remove any superfluous metadata.

        Args:
            im: Input image (`np.ndarray` or `torch.Tensor`)
            meta: Metadata dictionary. When it's None, the metadata is not tracked, this method returns a torch.Tensor.
            simple_keys: whether to keep only a simple subset of metadata keys.
            pattern: combined with `sep`, a regular expression used to match and prune keys
                in the metadata (nested dictionary), default to None, no key deletion.
            sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary).
                default is ".", see also :py:class:`monai.transforms.DeleteItemsd`.
                e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``.

        Returns:
            By default, a `MetaTensor` is returned.
            However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
        N)r   T)keysr   Zuse_re)r   r   r   r   r   rB   r   r   r   ZDeleteItemsdr2   r1   rE   )r   r2   r   r   r   imgr   r   r    ensure_torch_and_prune_meta  s$    



z&MetaTensor.ensure_torch_and_prune_metac                 C  s   d|     S )z
        Prints a representation of the tensor.
        Prepends "meta" to ``torch.Tensor.__repr__``.
        Use ``print_verbose`` for associated metadata.
        r2   )r7   __repr__r   r   r   r    r   G  s    zMetaTensor.__repr__c                 C  s   dt |   S )z
        Prints a representation of the tensor.
        Prepends "meta" to ``torch.Tensor.__str__``.
        Use ``print_verbose`` for associated metadata.
        r2   )r'   r7   r   r   r   r    __str__O  s    zMetaTensor.__str__c                 C  s   |   |S )zO
        returns the output of pytorch tensor's ``__format__`` method.
        )r7   
__format__)rK   format_specr   r   r    r   W  s    zMetaTensor.__format__c                 C  s$   t |  | jdk	r t | j  dS )zVerbose print with meta data.N)printr2   r   r   r   r   r    print_verbose]  s    
zMetaTensor.print_verbose)NNN)NNN)r   N)F)N)NNF)FNr   ))r   r{   __qualname____doc__staticmethodr=   r@   r[   classmethodrU   rm   rt   r   r   r   r   rE   r7   r,   r-   r   r   propertyr   setterrG   r   r   r1   r   r   r   r   r   r   r   r   r   r   r   __classcell__r   r   rM   r    r   3   sf   5      76
0




     2))
__future__r   	functoolsrC   ri   r   typingr   r   ru   r,   r   r   monai.config.type_definitionsr   monai.data.meta_objr   r   monai.data.utilsr	   r
   r   r   monai.utilsr   monai.utils.enumsr   r   r   r   Zmonai.utils.type_conversionr   r   r   r   __all__	lru_cacher!   r0   rG   r   r   r   r   r    <module>   s$   
