o
    #il                     @  s$  d dl mZ d dlZd dlZd dlmZ d dlmZ d dlm	Z	 d dl
Zd dlZd dlZd dlmZ d dlmZmZ d dlmZmZmZmZ d d	lmZ d d
lmZmZmZmZ d dlm Z m!Z!m"Z"m#Z# dgZ$e%ddd Z&dd Z'G dd deej(Z)e*ej+drej+,ee)eeg dS dS )    )annotationsN)Sequence)deepcopy)Any)NdarrayTensor)MetaObjget_track_meta)affine_to_spacingdecollate_batchlist_data_collateremove_extra_metadata)look_up_option)LazyAttrMetaKeysPostFix	SpaceKeys)convert_data_typeconvert_to_dst_typeconvert_to_numpyconvert_to_tensor
MetaTensorc                 C  sH   t tdr"t | dr"t tj| jr"tttj| jtr"ttj| jS d S )Nreturn_types__name__)hasattrtorchr   r   
isinstancegetattrtype)func r   X/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/data/meta_tensor.py_get_named_tuple_like_type"   s   r!   c              	   C  sH   t | ttttjtjtjtj	fp#t | t
p"t | to"tdd | D  S )Nc                 s  s    | ]}t |tV  qd S N)r   r   .0xr   r   r    	<genexpr>0   s    z*_not_requiring_metadata.<locals>.<genexpr>)r   intstrbytesr   Sizedtypedevicenpndarrayr   r   any)retr   r   r    _not_requiring_metadata.   s    $r1   c                      s  e Zd ZdZe			dVdWd
dZ			dVdX fddZedYddZedd Z	edZd[ fddZ
edd Zdd Zdd Zeejfd\d d!Zd\d"d#Zejddfd$d%Zd]d^d)d*Zed+d, Zejd_d-d,Zejdfd`d1d2Zdad3d4Zed\d5d6Zejdbd9d6Zed:d; Zd<d= Zd>d? Zd@dA Z dcdBdCZ!dDdE Z"e	FdddedLdMZ#dNdO Z$dPdQ Z%dRdS Z&d_dTdUZ'  Z(S )fr   a	  
    Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for metadata.

    Metadata is stored in the form of a dictionary. Nested, an affine matrix will be
    stored. This should be in the form of `torch.Tensor`.

    Behavior should be the same as `torch.Tensor` aside from the extended
    meta functionality.

    Copying of information:

        * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the
          first instance of `MetaTensor` if `a.is_batch` is False
          (For batched data, the metadata will be shallow copied for efficiency purposes).

    Example:
        .. code-block:: python

            import torch
            from monai.data import MetaTensor

            t = torch.tensor([1,2,3])
            affine = torch.as_tensor([[2,0,0,0],
                                      [0,2,0,0],
                                      [0,0,2,0],
                                      [0,0,0,1]], dtype=torch.float64)
            meta = {"some": "info"}
            m = MetaTensor(t, affine=affine, meta=meta)
            m2 = m + m
            assert isinstance(m2, MetaTensor)
            assert m2.meta["some"] == "info"
            assert torch.all(m2.affine == affine)

    Notes:
        - Requires pytorch 1.9 or newer for full compatibility.
        - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may
          not work if `im` is of type `MetaTensor`. This can be resolved with
          `torch.jit.trace(net, im.as_tensor())`.
        - For pytorch < 1.8, sharing `MetaTensor` instances across processes may not be supported.
        - For pytorch < 1.9, next(iter(meta_tensor)) returns a torch.Tensor.
          see: https://github.com/pytorch/pytorch/issues/54457
        - A warning will be raised if in the constructor `affine` is not `None` and
          `meta` already contains the key `affine`.
        - You can query whether the `MetaTensor` is a batch with the `is_batch` attribute.
        - With a batch of data, `batch[0]` will return the 0th image
          with the 0th metadata. When the batch dimension is non-singleton, e.g.,
          `batch[:, 0]`, `batch[..., -1]` and `batch[1:3]`, then all (or a subset in the
          last example) of the metadata will be returned, and `is_batch` will return `True`.
        - When creating a batch with this class, use `monai.data.DataLoader` as opposed
          to `torch.utils.data.DataLoader`, as this will take care of collating the
          metadata properly.
    Naffinetorch.Tensor | Nonemetadict | Noneapplied_operationslist | Nonereturnc                 O  s@   |r| dd | dd dni }tj|g|R i || S )Nr,   r+   r,   r+   )popr   	as_tensoras_subclass)clsr%   r2   r4   r6   argskwargs_kwargsr   r   r    __new__j   s   "
zMetaTensor.__new__Nonec                   s   t    |dur|| _nt|trt|j| _|dur+tj| jv r't	
d || _ntj| jv r9| jtj | _n|  | _|durF|| _nt | _t|tjr[t|ts[| |  tj| jvrjtj| jtj< dS dS )a  
        Args:
            x: initial array for the MetaTensor. Can be a list, tuple, NumPy ndarray, scalar, and other types.
            affine: optional 4x4 array.
            meta: dictionary of metadata.
            applied_operations: list of previously applied operations on the MetaTensor,
                the list is typically maintained by `monai.transforms.TraceableTransform`.
                See also: :py:class:`monai.transforms.TraceableTransform`
            _args: additional args (currently not in use in this constructor).
            _kwargs: additional kwargs (currently not in use in this constructor).

        Note:
            If a `meta` dictionary is given, use it. Else, if `meta` exists in the input tensor `x`, use it.
            Else, use the default value. Similar for the affine, except this could come from
            four places, priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`.

        NzRSetting affine, but the applied meta contains an affine. This will be overwritten.)super__init__r4   r   r   r   __dict__r   AFFINEwarningswarnr2   get_default_affiner6   get_default_applied_operationsr   Tensorr   copy_meta_fromSPACEr   RAS)selfr%   r2   r4   r6   _argsr@   	__class__r   r    rD   w   s(   





zMetaTensor.__init__retsr   c           
   	   C  s   g }d}t dd t|| D }t| D ]6\}}t|ts!n't s)| }nt|| }	||_	|j
|	| d |rHt||||||}|| qt| trWt|S |S )a  
        Update the metadata from the output of `MetaTensor.__torch_function__`.

        The output of `torch.Tensor.__torch_function__` could be a single object or a
        sequence of them. Hence, in `MetaTensor.__torch_function__` we convert them to a
        list of not already, and then we loop across each element, processing metadata
        as necessary. For each element, if not of type `MetaTensor`, then nothing to do.

        Args:
            rets: the output from `torch.Tensor.__torch_function__`, which has been
                converted to a list in `MetaTensor.__torch_function__` if it wasn't
                already a `Sequence`.
            func: the torch function that was applied. Examples might be `torch.squeeze`
                or `torch.Tensor.__add__`. We need this since the metadata need to be
                treated differently if a batch of data is considered. For example,
                slicing (`torch.Tensor.__getitem__`) the ith element of the 0th
                dimension of a batch of data should return a ith tensor with the ith
                metadata.
            args: positional arguments that were passed to `func`.
            kwargs: keyword arguments that were passed to `func`.

        Returns:
            A sequence with the same number of elements as `rets`. For each element, if
            the input type was not `MetaTensor`, then no modifications will have been
            made. If global parameters have been set to false (e.g.,
            `not get_track_meta()`), then any `MetaTensor` will be converted to
            `torch.Tensor`. Else, metadata will be propagated as necessary (see
            :py:func:`MetaTensor._copy_meta`).
        Nc                 s  s     | ]}t |d r|jV  qdS )is_batchN)r   rT   r#   r   r   r    r&      s    z)MetaTensor.update_meta.<locals>.<genexpr>)	copy_attr)r/   r   flatten_meta_objsvalues	enumerater   r   r   r;   rT   rL   _handle_batchedappendtuple)
rS   r   r>   r?   outmetasrT   idxr0   Z	meta_argsr   r   r    update_meta   s   

zMetaTensor.update_metac                 C  s  |t jjkr|dkst|dk st|d dk r|S t|d tr'|d d n|d }|tdddtdfv s<t|t jr>|S t|d dd}|| }	t|	t	rm|	rmzt
|	}	W n ttttfyl }
 ztd|
d}
~
ww t|	trud|	_t|	dr|	j |_|S |t jjkrt|dkr|d }nd	|v r|d	 }nd}|dkr|du rt|d dd}t|| dr|| j |_d|_|S )
z/utility function to handle batched MetaTensors.r         NF)detachzInconsistent batched metadata dicts when slicing a batch of MetaTensors, please consider converting it into a torch Tensor using `x.as_tensor()` or a numpy array using `x.array`.rE   dim)r   rK   __getitem__lenr   r   sliceEllipsisr
   listr   	TypeError
ValueErrorRuntimeError
IndexErrorr   rT   r   rE   copyunbind)r=   r0   r^   r]   r   r>   r?   	batch_idxZ	dec_batchZret_metaerc   r   r   r    rY      sJ   $""



zMetaTensor._handle_batchedr   r   c           	        s   |du ri }t  ||||}t|r|S t|durDt|t|rDt||||}t|jD ]}|| j	|| _	|| j
|| _
q/|S t|tsO|g}d}nd}t||||}|r_|d S |S )zWraps all torch functions.NTFr   )rC   __torch_function__r1   r!   r   r   r_   rangen_fieldsr4   r6   r   )	r=   r   typesr>   r?   r0   Z	out_itemsr^   unpackrQ   r   r    rq     s"   
zMetaTensor.__torch_function__c                 C  s,   t | ttjttfrt| tjddd S | S )NF)output_typewrap_sequencer   )	r   r   r   rK   r[   rh   r   r-   r.   )r%   r   r   r    _convert3  s   zMetaTensor._convertc                 C  s^   z|j ds
tW S W n ty   t Y S w tttj|}dd | D }||i |S )zQfor numpy Interoperability, so that we can compute ``np.sum(MetaTensor([1.0]))``.numpyc                 S     i | ]
\}}|t |qS r   r   rx   r$   kvr   r   r    
<dictcomp>A      z1MetaTensor.__array_function__.<locals>.<dictcomp>)	
__module__
startswithNotImplementedAttributeErrorrh   mapr   rx   items)rO   r   rt   r>   r?   rP   r@   r   r   r    __array_function__9  s   zMetaTensor.__array_function__c                 O  s   zt |jdstW S W n ty   t Y S w |dkrtS ttj|}dd | D }d|v r4tS zt	|||i |W S  tyJ   t Y S w )z
        For numpy interoperability, so that we can compute ``MetaTensor([1.0]) >= np.asarray([1.0])``.
        This is for pytorch > 1.8.
        ry   __call__c                 S  rz   r   r{   r|   r   r   r    r   Q  r   z.MetaTensor.__array_ufunc__.<locals>.<dictcomp>r\   )
r   r   r   r   r   r   r   rx   r   r   )rO   ufuncmethodinputsr?   Z_inputsr@   r   r   r    __array_ufunc__D  s$   zMetaTensor.__array_ufunc__torch.Tensorc                 C  s   t jdt d| dS )N   cpur9   )r   eyer,   r+   r   r   r    rI   Y  s   zMetaTensor.get_default_affinec                 C  s   |  tjS )z
        Return the `MetaTensor` as a `torch.Tensor`.
        It is OS dependent as to whether this will be a deep copy or not.
        )r<   r   rK   rO   r   r   r    r;   ]  s   zMetaTensor.as_tensorc                 O  s   t | |||ddd S )a  
        Returns a new array in `output_type`, the array shares the same underlying storage when the output is a
        numpy array. Changes to self tensor will be reflected in the ndarray and vice versa.

        Args:
            output_type: output type, see also: :py:func:`monai.utils.convert_data_type`.
            dtype: dtype of output data. Converted to correct library type (e.g.,
                `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
                If left blank, it remains unchanged.
            device: if the output is a `torch.Tensor`, select device (if `None`, unchanged).
            _args: currently unused parameters.
            _kwargs: currently unused parameters.
        T)rv   r+   r,   rw   r   )r   )rO   rv   r+   r,   rP   r@   r   r   r    	get_arrayd  s   zMetaTensor.get_arrayFnon_blockingboolc                 O  s<   t |ddd}z| j||dW S  ty   || _|  Y S w )a  
        Copies the elements from src into self tensor and returns self.
        The src tensor must be broadcastable with the self tensor.
        It may be of a different data type or reside on a different device.

        See also: `https://pytorch.org/docs/stable/generated/torch.Tensor.copy_.html`

        Args:
            src: the source tensor to copy from.
            non_blocking: if True and this copy is between CPU and GPU, the copy may occur
                asynchronously with respect to the host. For other cases, this argument has no effect.
            _args: currently unused parameters.
            _kwargs:  currently unused parameters.
        FT)
track_metarw   )r   )r   copy_rk   data)rO   srcr   rP   r@   	convertedr   r   r    	set_arrayt  s   zMetaTensor.set_arrayc                 C  s   |   S )a  
        Returns a numpy array of ``self``. The array and ``self`` shares the same underlying storage if self is on cpu.
        Changes to ``self`` (it's a subclass of torch.Tensor) will be reflected in the ndarray and vice versa.
        If ``self`` is not on cpu, the call will move the array to cpu and then the storage is not shared.

        :getter: see also: :py:func:`MetaTensor.get_array()`
        :setter: see also: :py:func:`MetaTensor.set_array()`
        )r   r   r   r   r    array  s   
zMetaTensor.arrayc                 C  s   |  | dS )z+A default setter using ``self.set_array()``N)r   )rO   r   r   r   r    r     s   keyr(   dictc                 C  sJ   |t jtjfvrtd| d|| j||dt|| jt|| j	iS )a  
        Get the object as a dictionary for backwards compatibility.
        This method does not make a deep copy of the objects.

        Args:
            key: Base key to store main data. The key for the metadata will be determined using `PostFix`.
            output_type: `torch.Tensor` or `np.ndarray` for the main data.
            dtype: dtype of output data. Converted to correct library type (e.g.,
                `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
                If left blank, it remains unchanged.

        Return:
            A dictionary consisting of three keys, the main data (stored under `key`) and the metadata.
        z4output_type must be torch.Tensor or np.ndarray, got .)rv   r+   )
r   rK   r-   r.   rj   r   r   r4   
transformsr6   )rO   r   rv   r+   r   r   r    as_dict  s   zMetaTensor.as_dictc                 O  s~   t |tr|dd^}}|s|n|d }nt|dd}t|h ddd}|dkr-tj}n
|d	v r5tj}nd
}| j	|||dS )a  
        Cast to ``dtype``, sharing data whenever possible.

        Args:
            dtype: dtypes such as np.float32, torch.float, "np.float32", float.
            device: the device if `dtype` is a torch data type.
            _args: additional args (currently unused).
            _kwargs: additional kwargs (currently unused).

        Returns:
            data array instance
        r   ra   r   r   r   >   ry   r   r-   ry   )default)ry   r-   N)rv   r+   r,   )
r   r(   splitr   r   r   rK   r-   r.   r   )rO   r+   r,   rP   r@   mod_strout_typer   r   r    astype  s   
zMetaTensor.astypec                 C  s   | j tj|  S )zAGet the affine. Defaults to ``torch.eye(4, dtype=torch.float64)``)r4   getr   rF   rI   r   r   r   r    r2     s   zMetaTensor.affinedr   c                 C  s$   t j|t dt jd| jtj< dS )zSet the affine.r   r9   N)r   r;   r,   float64r4   r   rF   )rO   r   r   r   r    r2     s   $c                 C  s    | j rdd | jD S t| jS )zGet the spacingc                 S  s   g | ]}t |qS r   )r	   )r$   ar   r   r    
<listcomp>  s    z%MetaTensor.pixdim.<locals>.<listcomp>)rT   r2   r	   r   r   r   r    pixdim  s   
zMetaTensor.pixdimc                 C  sH   d}| j r| j d tjd}|du r"tt| jdd dd S |S )z
        Get the currently expected spatial shape as if all the pending operations are executed.
        For tensors that have more than 3 spatial dimensions, only the shapes of the top 3 dimensions will be returned.
        NT)rw   ra   )pending_operationsr   r   SHAPEr[   r   shapetolist)rO   resr   r   r    peek_pending_shape  s   *zMetaTensor.peek_pending_shapec                 C  s   | j }t|d }|dvrtd| d | jD ]*}t|tjt	j
d}|d u r+qt||d }tjj||}tjjj||}q|S )Nra   )r`      z)Only 2d and 3d affine are supported, got zd input.r   r   )r2   re   rG   rH   r   r   r   r   rF   r   r   r   monair   utilsto_affine_ndr   lazyZcombine_transforms)rO   r   rpZnext_matrixr   r   r    peek_pending_affine  s   
zMetaTensor.peek_pending_affinec                 C  sB   | j r| j d tjd n| j}|d u rdS ttdt|d S )Nr   ra   )r   r   r   rF   r2   r'   maxre   )rO   r   r   r   r    peek_pending_rank  s    "zMetaTensor.peek_pending_rankc                 C  s   t | |  j||||dS )z
        must be defined for deepcopy to work

        See:
            - https://pytorch.org/docs/stable/generated/torch.Tensor.new_empty.html#torch-tensor-new-empty
        )sizer+   r,   requires_grad)r   r;   	new_empty)rO   r   r+   r,   r   r   r   r    r     s   zMetaTensor.new_emptyc                 K  s(   t |  jdi |}t| j|_|S )z
        Returns a copy of the MetaTensor instance.

        Args:
            kwargs: additional keyword arguments to `torch.clone`.

        See also: https://pytorch.org/docs/stable/generated/torch.clone.html
        Nr   )r   r;   cloner   rE   )rO   r?   Znew_instr   r   r    r     s   	zMetaTensor.cloner   imsimple_keyspattern
str | Nonesepc                 C  s   t | t o|dud}t|ts|S |du ri }|r-tj|v r)t |tj |tj< t| |dur<tjj	||dd|}|du rBi }||_
tj|v rR|tj |_|S t |_|S )aX  
        Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary,
        convert that to `torch.Tensor`, too. Remove any superfluous metadata.

        Args:
            im: Input image (`np.ndarray` or `torch.Tensor`)
            meta: Metadata dictionary. When it's None, the metadata is not tracked, this method returns a torch.Tensor.
            simple_keys: whether to keep only a simple subset of metadata keys.
            pattern: combined with `sep`, a regular expression used to match and prune keys
                in the metadata (nested dictionary), default to None, no key deletion.
            sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary).
                default is ".", see also :py:class:`monai.transforms.DeleteItemsd`.
                e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``.

        Returns:
            By default, a `MetaTensor` is returned.
            However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
        N)r   T)keysr   Zuse_re)r   r   r   r   r   rF   r   r   r   ZDeleteItemsdr4   r2   rI   )r   r4   r   r   r   imgr   r   r    ensure_torch_and_prune_meta  s&   



z&MetaTensor.ensure_torch_and_prune_metac                 C  s   d|     S )z
        Prints a representation of the tensor.
        Prepends "meta" to ``torch.Tensor.__repr__``.
        Use ``print_verbose`` for associated metadata.
        r4   )r;   __repr__r   r   r   r    r   G     zMetaTensor.__repr__c                 C  s   dt |   S )z
        Prints a representation of the tensor.
        Prepends "meta" to ``torch.Tensor.__str__``.
        Use ``print_verbose`` for associated metadata.
        r4   )r(   r;   r   r   r   r    __str__O  r   zMetaTensor.__str__c                 C  s   |   |S )zO
        returns the output of pytorch tensor's ``__format__`` method.
        )r;   
__format__)rO   format_specr   r   r    r   W  s   zMetaTensor.__format__c                 C  s(   t |  | jdurt | j  dS dS )zVerbose print with meta data.N)printr4   r   r   r   r   r    print_verbose]  s   
zMetaTensor.print_verbose)NNN)r2   r3   r4   r5   r6   r7   r8   r   )r2   r3   r4   r5   r6   r7   r8   rB   )rS   r   r8   r   )r   N)r8   r   )r8   r   )F)r   r   )r8   rB   )r   r(   r8   r   r"   )r   r   r8   rB   )NNF)FNr   )
r   r   r4   r5   r   r   r   r   r   r(   ))r   r   __qualname____doc__staticmethodrA   rD   r_   classmethodrY   rq   rx   r   r   r   r   rI   r;   r-   r.   r   r   propertyr   setterrK   r   r   r2   r   r   r   r   r   r   r   r   r   r   r   __classcell__r   r   rQ   r    r   4   sd    576
0





1add_safe_globals)-
__future__r   	functoolsrG   collections.abcr   rm   r   typingr   ry   r-   r   r   monai.config.type_definitionsr   monai.data.meta_objr   r   monai.data.utilsr	   r
   r   r   monai.utilsr   monai.utils.enumsr   r   r   r   Zmonai.utils.type_conversionr   r   r   r   __all__	lru_cacher!   r1   rK   r   r   serializationr   r   r   r   r    <module>   s6   
    5