o
    i)                     @  s   d dl mZ d dlZd dlmZ d dlZd dlZd dlm	Z	m
Z
 d dlmZ d dlmZ d dlmZ d dlmZ g d	ZdddZdddZG dd deZG dd deZG dd deZdS )    )annotationsN)Sequence)	DtypeLikeKeysCollection)MapLabelValue)MapTransform)$keep_components_with_positive_points)look_up_option)VistaPreTransformdVistaPostTransformdRelabeldlabels_dictdict | Nonereturndictc                 C  s"   i }| durdd |   D }|S )z#get the label name to index mappingNc                 S  s   i | ]\}}|  t|qS  )lowerint).0kvr   r   _/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/vista3d/transforms.py
<dictcomp>!   s    z._get_name_to_index_mapping.<locals>.<dictcomp>)items)r   name_to_index_mappingr   r   r   _get_name_to_index_mapping   s   r   r   label_promptlist | Nonec                 C  s   |durWt |trWg }|D ]}t |tr&| s&| | vr&t| | | < q|D ]+}t |ttfrO|t |trI| | | rFt|ndnt| q)|| q)|S |S )zconvert the label name to indexNr   )	
isinstanceliststrisdigitr   lenr   appendget)r   r   Zconverted_label_promptlr   r   r   _convert_name_to_index%   s   0r&   c                      s0   e Zd Z				dd fddZdd Z  ZS )r
   F               u   Nkeysr   allow_missing_keysboolspecial_indexSequence[int]r   r   subclassr   Nonec                   s(   t  || || _|| _t|| _dS )a  
        Pre-transform for Vista3d.

        It performs two functionalities:

        1. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels),
           convert point labels from 0 (negative), 1 (positive) to special 2 (negative), 3 (positive).

        2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key].
           e.g. "lung" label is converted to ["left lung", "right lung"].

        The `label_prompt` is a list of int values of length [B] and `point_labels` is a list of length B,
        where each element is an int value of length [B, N].

        Args:
            keys: keys of the corresponding items to be transformed.
            special_index: the index that defines the special class.
            subclass: a dictionary that maps a label prompt to its subclasses.
            allow_missing_keys: don't raise exception if key is missing.
        N)super__init__r1   r3   r   r   )selfr.   r/   r1   r   r3   	__class__r   r   r6   :   s   zVistaPreTransformd.__init__c                 C  s  | dd }| dd }t| j|}zm| jd urO|d urOg }ttt| j }tt	|D ]}|| |v rC|
| jt||   q.|||  q.||d< |d urz|d ur}|d | jv rst|}d||dk< d||dk< | }||d< W |S W |S W |S  ty   td Y |S w )Nr   point_labelsr            zDVistaPreTransformd failed to transform label prompt or point labels.)r$   r&   r   r3   r   mapr   r.   ranger"   extendr    r#   r1   nparraytolist	Exceptionwarningswarn)r7   datar   r:   Z_label_promptZsubclass_keysir   r   r   __call__[   s8   

zVistaPreTransformd.__call__)Fr'   NN)r.   r   r/   r0   r1   r2   r   r   r3   r   r   r4   __name__
__module____qualname__r6   rI   __classcell__r   r   r8   r   r
   9   s    !r
   c                      s(   e Zd Zdd fdd	Zd
d Z  ZS )r   Fr.   r   r/   r0   r   r4   c                   s   t  || dS )as  
        Post-transform for Vista3d. It converts the model output logits into final segmentation masks.
        If `label_prompt` is None, the output will be thresholded to be sequential indexes [0,1,2,...],
        else the indexes will be [0, label_prompt[0], label_prompt[1], ...].
        If `label_prompt` is None while `points` are provided, the model will perform postprocess to remove
        regions that does not contain positive points.

        Args:
            keys: keys of the corresponding items to be transformed.
            dataset_transforms: a dictionary specifies the transform for corresponding dataset:
                key: dataset name, value: list of data transforms.
            dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name".
            allow_missing_keys: don't raise exception if key is missing.

        N)r5   r6   )r7   r.   r/   r8   r   r   r6   |   s   zVistaPostTransformd.__init__c           
      C  sL  | j D ]}||v r|| }|jd }|j}|dddu r=|dddur=t|d|d||d|dd }d||dk < |dkrbtj|dkdd	d
}|	dd
 d }d||< nd||dk< d|v r|d dur|d7 }|d |}td|d D ]}|d }	||d  |j|||	k< qd||dk< |||< q|S )z)data["label_prompt"] should not contain 0r   r   Npointsr:   )point_coordsr:   g        r=   T)dimkeepdimg      ?g      ?)r.   shapedevicer$   r   	unsqueezetotorchallargmaxfloatr?   dtype)
r7   rG   r.   predZ
object_numrT   Zis_bkr   rH   fracr   r   r   rI      s:   

 
zVistaPostTransformd.__call__)F)r.   r   r/   r0   r   r4   rJ   r   r   r8   r   r   {   s    r   c                      s0   e Zd Zejddfd fddZdd Z  ZS )r   dataset_nameFr.   r   label_mappings dict[str, list[tuple[int, int]]]r[   r   dataset_keyr    r/   r0   r   r4   c                   sX   t  || i | _|| _| D ]\}}tdd |D dd |D |d| j|< qdS )aS  
        Remap the voxel labels in the input data dictionary based on the specified mapping.

        This list of local -> global label mappings will be applied to each input `data[keys]`.
        if `data[dataset_key]` is not in `label_mappings`, label_mappings['default']` will be used.
        if `label_mappings[data[dataset_key]]` is None, no relabeling will be performed.

        Args:
            keys: keys of the corresponding items to be transformed.
            label_mappings: a dictionary specifies how local dataset class indices are mapped to the
                global class indices. The dictionary keys are dataset names and the values are lists of
                list of (local label, global label) pairs. This list of local -> global label mappings
                will be applied to each input `data[keys]`. If `data[dataset_key]` is not in `label_mappings`,
                label_mappings['default']` will be used. if `label_mappings[data[dataset_key]]` is None,
                no relabeling will be performed. Please set `label_mappings={}` to completely skip this transform.
            dtype: convert the output data to dtype, default to float32.
            dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name".
            allow_missing_keys: don't raise exception if key is missing.

        c                 S     g | ]}t |d  qS )r   r   r   pairr   r   r   
<listcomp>       z%Relabeld.__init__.<locals>.<listcomp>c                 S  rb   )r=   rc   rd   r   r   r   rf      rg   )orig_labelstarget_labelsr[   N)r5   r6   mappersra   r   r   )r7   r.   r_   r[   ra   r/   namemappingr8   r   r   r6      s   zRelabeld.__init__c                 C  sV   t |}|| jd}t|| jd d}|d u r|S | |D ]
}||| ||< q|S )Ndefault)rm   )r   r$   ra   r	   rj   key_iterator)r7   rG   dr^   _mkeyr   r   r   rI      s   zRelabeld.__call__)r.   r   r_   r`   r[   r   ra   r    r/   r0   r   r4   )rK   rL   rM   rA   int16r6   rI   rN   r   r   r8   r   r      s    &r   )r   r   r   r   )r   r   r   r   r   r   )
__future__r   rE   collections.abcr   numpyrA   rW   monai.configr   r   monai.transformsr   monai.transforms.transformr   monai.transforms.utilsr   monai.utilsr	   __all__r   r&   r
   r   r   r   r   r   r   <module>   s    

B6