o
    i?O                     @  s   d dl mZ d dlZd dlZd dlmZmZ d dlmZ d dl	m
Z
 d dlZd dlmZ d dlmZmZ d dlmZ d d	lmZ d d
lmZ d dlmZ d dlmZmZmZmZmZmZm Z  d dl!m"Z" ddgZ#G dd deZ$G dd de$eZ%dS )    )annotationsN)HashableMapping)contextmanager)Any)
transforms)MetaObjget_track_meta)
MetaTensor)to_affine_nd)InvertibleTrait)	Transform)LazyAttrMetaKeys	TraceKeysTraceStatusKeysconvert_to_dst_typeconvert_to_numpyconvert_to_tensor)MONAIEnvVarsTraceableTransformInvertibleTransformc                   @  s   e Zd ZdZdd Zdd Zed.dd	Zejd/dd	Ze	d0d1ddZ
e	dd Zd2ddZdd Ze							d3d4ddZd5d"d#Zd6d7d'd(Zd8d9d)d*Zed:d,d-ZdS );r   a.  
    Maintains a stack of applied transforms to data.

    Data can be one of two types:
        1. A `MetaTensor` (this is the preferred data type).
        2. A dictionary of data containing arrays/tensors and auxiliary metadata. In
            this case, a key must be supplied (this dictionary-based approach is deprecated).

    If `data` is of type `MetaTensor`, then the applied transform will be added to ``data.applied_operations``.

    If `data` is a dictionary, then one of two things can happen:
        1. If data[key] is a `MetaTensor`, the applied transform will be added to ``data[key].applied_operations``.
        2. Else, the applied transform will be appended to an adjacent list using
            `trace_key`. If, for example, the key is `image`, then the transform
            will be appended to `image_transforms` (this dictionary-based approach is deprecated).

    Hopefully it is clear that there are three total possibilities:
        1. data is `MetaTensor`
        2. data is dictionary, data[key] is `MetaTensor`
        3. data is dictionary, data[key] is not `MetaTensor` (this is a deprecated approach).

    The ``__call__`` method of this transform class must be implemented so
    that the transformation information is stored during the data transformation.

    The information in the stack of applied transforms must be compatible with the
    default collate, by only storing strings, numbers and arrays.

    `tracing` could be enabled by assigning to `self.tracing` or setting
    `MONAI_TRACE_TRANSFORM` when initializing the class.
    c                 C  s8   t | ds
t | _t | jdst dk| j_dS dS )zRCreate a `_tracing` instance member to store the thread-local tracing state value._tracingvalue0N)hasattr	threadinglocalr   r   trace_transformr   self r!   Z/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/transforms/inverse.py_init_trace_threadlocalJ   s
   

z*TraceableTransform._init_trace_threadlocalc                   sN   t t di } fddt dg D }|dd t|dkr#|S ||fS )zbWhen pickling, remove the `_tracing` member from the output, if present, since it's not picklable.__dict__c                   s   i | ]}|t  |qS r!   )getattr).0kr   r!   r"   
<dictcomp>X       z3TraceableTransform.__getstate__.<locals>.<dictcomp>	__slots__r   Nr   )dictr%   poplen)r    _dictZ_slotsr!   r   r"   __getstate__U   s   zTraceableTransform.__getstate__returnboolc                 C  s   |    t| jjS )z~
        Returns the tracing state, which is thread-local and initialised to `MONAIEnvVars.trace_transform() != "0"`.
        )r#   r1   r   r   r   r!   r!   r"   tracing\   s   zTraceableTransform.tracingvalc                 C  s   |    || j_dS )z-Sets the thread-local tracing state to `val`.N)r#   r   r   )r    r3   r!   r!   r"   r2   d   s   Nkeyr   c                 C  s   | du rt j S |  t j S )z1The key to store the stack of applied transforms.N)r   
KEY_SUFFIX)r4   r!   r!   r"   	trace_keyj   s   zTraceableTransform.trace_keyc                   C  s   t jt jt jt jfS )z9The keys to store necessary info of an applied transform.)r   
CLASS_NAMEIDTRACINGDO_TRANSFORMr!   r!   r!   r"   transform_info_keysq   s   z&TraceableTransform.transform_info_keysr+   c                 C  s8   | j jt| | jt| dr| jndf}tt|  |S )zg
        Return a dictionary with the relevant information pertaining to an applied transform.
        _do_transformT)		__class____name__idr2   r   r<   r+   zipr;   )r    valsr!   r!   r"   get_transform_infov   s   z%TraceableTransform.get_transform_infoc                 O  s4  | dd}|  }| tjd}|pi }|dd}|rgt rgt|trg|sC|r0| j|ddni }| j	|| tj
|d}	||	S |rT|j }| }
|| n|i }}
| j	||d|
d}	||	S ||d< d|v r~t|d tr~|d | n||d< tj|g|R i |}	t|tr||	S |S )	a  
        Push to a stack of applied transforms of ``data``.

        Args:
            data: dictionary of data or `MetaTensor`.
            args: additional positional arguments to track_transform_meta.
            kwargs: additional keyword arguments to track_transform_meta,
                set ``replace=True`` (default False) to rewrite the last transform infor in
                applied_operation/pending_operation based on ``self.get_transform_info()``.
        lazyFTreplace)check)	orig_size
extra_info)transform_inforC   rG   rH   )getrB   r   r:   r,   r	   
isinstancer
   pop_transformpush_transform	ORIG_SIZEcopy_meta_frompending_operationscopyupdater+   r   track_transform_meta)r    dataargskwargsZ	lazy_evalrH   Zdo_transformrD   Zxformmeta_objextrar!   r!   r"   rL      s.   



z!TraceableTransform.push_transformFrG   dict | NonerF   tuple | Nonec	              
   C  s  |dur|| n|}	t  }
t|	tr|
j|	|
j d |r&t s&td |s~|dur~t|	tr~|		 }t
||tjdd }z|tt|d |tjd }W n tyn } z|jdkri|	jrbd}nd	}t|| d}~ww t|td
tjd|
jtj< t r|r|tjst|trt|tst|}t|	tr|	|
n|	||< |S |
S | }|dur||tj< nt|	tr|	 |tj< nt|	dr|	j dd |tj< ||tj!< |dur|"t#j$d |"t#jd ||tj%< |rJ|du rt#j$|vr|tjg |t#j$< n||t#j$< t&t'|t#j$ dd( |t#j$< |du r0t#j|vr/t) |t#j< n||t#j< t|t#j td
d|t#j< |
*| nV|
j+rt|tr[|tj,dnd}d| ddd |
j+D  }|durw|d| 7 }|
j+d }|tj-t }|t.j/t0 }|1| ||t.j/< ||tj-< |
2| t|trt|tst|}t|	tr|	|
||< |S t34|}||vrg ||< || 1| |S |
S )am  
        Update a stack of applied/pending transforms metadata of ``data``.

        Args:
            data: dictionary of data or `MetaTensor`.
            key: if data is a dictionary, data[key] will be modified.
            sp_size: the expected output spatial size when the transform is applied.
                it can be tensor or numpy, but will be converted to a list of integers.
            affine: the affine representation of the (spatial) transform in the image space.
                When the transform is applied, meta_tensor.affine will be updated to ``meta_tensor.affine @ affine``.
            extra_info: if desired, any extra information pertaining to the applied
                transform can be stored in this dictionary. These are often needed for
                computing the inverse transformation.
            orig_size: sometimes during the inverse it is useful to know what the size
                of the original image was, in which case it can be supplied here.
            transform_info: info from self.get_transform_info().
            lazy: whether to push the transform to pending_operations or applied_operations.

        Returns:

            For backward compatibility, if ``data`` is a dictionary, it returns the dictionary with
            updated ``data[key]``. Otherwise, this function returns a MetaObj with updated transform metadata.
        N)keyszUmetadata is not tracked, please call 'set_track_meta(True)' if doing lazy evaluation.)dtyper         zHTransform applied to batched tensor, should be applied to instances onlyz\Mismatch affine matrix, ensured that the batch dimension is not included in the calculation.cpu)devicer[   shapeT)wrap_sequence)r_    z
Transform z; has been applied to a MetaTensor with pending operations: c                 S  s   g | ]}| tjqS r!   )rI   r   r7   )r&   xr!   r!   r"   
<listcomp>  r)   z;TraceableTransform.track_transform_meta.<locals>.<listcomp>z	 for key )5r   rJ   r
   rN   r$   rZ   r	   warningswarnpeek_pending_affiner   torchfloat64r   r-   RuntimeErrorndimis_batchr   r_   metar   AFFINErI   r   r9   r   r+   rP   rM   peek_pending_shaper   r`   LAZYr,   r   SHAPE
EXTRA_INFOtupler   tolistget_default_affinepush_pending_operationrO   r7   STATUSESr   PENDING_DURING_APPLYlistappendpush_applied_operationr   r6   )clsrS   r4   Zsp_sizeaffinerG   rF   rH   rC   Zdata_tZout_objZorig_affineemsginfoZtransform_nameZpendstatusesmessagesZx_kr!   r!   r"   rR      s   #


 

	







 









z'TraceableTransform.track_transform_meta	transformr   Nonec              
   C  s   | tjd}|t| krdS |tjkrdS | tjd}| tji  d}|r.t| t	j
 dv r=|| jjkr=dS td| jj d| d| dt|  d		)
z&Check transforms are of same instance.rb   Nrg   )spawnNzError z8 getting the most recently applied invertible transform  z != .)rI   r   r8   r?   NONEr7   rs   rf   rg   ri   multiprocessingget_start_methodr=   r>   rk   )r    r   Zxform_idZ
xform_namewarning_msgr!   r!   r"   check_transforms_match+  s(   

z)TraceableTransform.check_transforms_matchTrE   r,   c                 C  s   | j stdt|tr|j}n,t|tr2||v r&t|| tr&|| j}n|| |t }n
t	dt
| d|sNt	dt
| d| d| d|rW| |d  |r^|dS |d S )	aL  
        Get most recent matching transform for the current class from the sequence of applied operations.

        Args:
            data: dictionary of data or `MetaTensor`.
            key: if data is a dictionary, data[key] will be modified.
            check: if true, check that `self` is the same type as the most recently-applied transform.
            pop: if true, remove the transform as it is returned.

        Returns:
            Dictionary of most recently applied transform

        Raises:
            - RuntimeError: data is neither `MetaTensor` nor dictionary
        zCTransform Tracing must be enabled to get the most recent transform.z8`data` should be either `MetaTensor` or dictionary, got r   zItem of type z (key: z, pop: z ) has empty 'applied_operations're   )r2   rk   rJ   r
   applied_operationsr   rI   r6   get_default_applied_operations
ValueErrortyper   r,   )r    rS   r4   rE   r,   all_transformsr!   r!   r"   get_most_recent_transform?  s   

 z,TraceableTransform.get_most_recent_transformc                 C  s   | j |||ddS )a  
        Return and pop the most recent transform.

        Args:
            data: dictionary of data or `MetaTensor`
            key: if data is a dictionary, data[key] will be modified
            check: if true, check that `self` is the same type as the most recently-applied transform.

        Returns:
            Dictionary of most recently applied transform

        Raises:
            - RuntimeError: data is neither `MetaTensor` nor dictionary
        T)r,   )r   )r    rS   r4   rE   r!   r!   r"   rK   c  s   z TraceableTransform.pop_transformto_tracec                 c  s    | j }|| _ dV  || _ dS )zITemporarily set the tracing status of a transform with a context manager.N)r2   )r    r   prevr!   r!   r"   r   t  s
   
z"TraceableTransform.trace_transform)r0   r1   )r3   r1   )N)r4   r   )r0   r+   )NNNNNNF)r4   r   rG   rX   rF   rY   )r   r   r0   r   )NTF)r4   r   rE   r1   r,   r1   )NT)r4   r   rE   r1   )r   r1   )r>   
__module____qualname____doc__r#   r/   propertyr2   setterstaticmethodr6   r;   rB   rL   classmethodrR   r   r   rK   r   r   r!   r!   r!   r"   r   *   s:    

% 
$c                   @  s"   e Zd ZdZdd Zd
ddZd	S )r   a:  Classes for invertible transforms.

    This class exists so that an ``invert`` method can be implemented. This allows, for
    example, images to be cropped, rotated, padded, etc., during training and inference,
    and after be returned to their original size before saving to file for comparison in
    an external viewer.

    When the ``inverse`` method is called:

        - the inverse is called on each key individually, which allows for
          different parameters being passed to each label (e.g., different
          interpolation for image and label).

        - the inverse transforms are applied in a last-in-first-out order. As
          the inverse is applied, its entry is removed from the list detailing
          the applied transformations. That is to say that during the forward
          pass, the list of applied transforms grows, and then during the
          inverse it shrinks back down to an empty list.

    We currently check that the ``id()`` of the transform is the same in the forward and
    inverse directions. This is a useful check to ensure that the inverses are being
    processed in the correct order.

    Note to developers: When converting a transform to an invertible transform, you need to:

        #. Inherit from this class.
        #. In ``__call__``, add a call to ``push_transform``.
        #. Any extra information that might be needed for the inverse can be included with the
           dictionary ``extra_info``. This dictionary should have the same keys regardless of
           whether ``do_transform`` was `True` or `False` and can only contain objects that are
           accepted in pytorch data loader's collate function (e.g., `None` is not allowed).
        #. Implement an ``inverse`` method. Make sure that after performing the inverse,
           ``pop_transform`` is called.

    c                 C  sd   t |trt | tjs|S t|}| |D ]}tj|}||vs&|| s'qtj||dd}q|S )z
        This function is to be called before every `self.inverse(data)`,
        update each MetaTensor `data[key]` using `data[key_transforms]` and `data[key_meta_dict]`,
        for MetaTensor backward compatibility 0.9.0.
        F)t)rJ   r+   r   MapTransformZkey_iteratorr   r6   sync_meta_info)r    rS   dr'   Ztransform_keyr!   r!   r"   inverse_update  s   z"InvertibleTransform.inverse_updaterS   r   r0   c                 C  s   t d| jj d)z
        Inverse of ``__call__``.

        Raises:
            NotImplementedError: When the subclass does not override this method.

        z	Subclass z must implement this method.)NotImplementedErrorr=   r>   )r    rS   r!   r!   r"   inverse  s   zInvertibleTransform.inverseN)rS   r   r0   r   )r>   r   r   r   r   r   r!   r!   r!   r"   r   }  s    $)&
__future__r   r   rf   collections.abcr   r   
contextlibr   typingr   ri   monair   monai.data.meta_objr   r	   monai.data.meta_tensorr
   monai.data.utilsr   monai.transforms.traitsr   monai.transforms.transformr   monai.utilsr   r   r   r   r   r   r   monai.utils.miscr   __all__r   r   r!   r!   r!   r"   <module>   s(   $	  U