U
    Ph"                     @  sz   d dl mZ d dlmZ d dlmZmZmZmZm	Z	m
Z
mZmZmZmZ d dlmZ d dlmZ dgZG dd deZdS )	    )annotations)Any)
AnalyzerFgImageStatsFgImageStatsSummFilenameStatsImageHistogramImageHistogramSumm
ImageStatsImageStatsSumm
LabelStatsLabelStatsSumm)Compose)DataStatsKeysSegSummarizerc                
      sX   e Zd ZdZddddddd	dd
d fddZddd
dddZdddddZ  ZS )r   a
  
    SegSummarizer serializes the operations for data analysis in Auto3Dseg pipeline. It loads
    two types of analyzer functions and execute differently. The first type of analyzer is
    CaseAnalyzer which is similar to traditional monai transforms. It can be composed with other
    transforms to process the data dict which has image/label keys. The second type of analyzer
    is SummaryAnalyzer which works only on a list of dictionary. Each dictionary is the output
    of the case analyzers on a single dataset.

    Args:
        image_key: a string that user specify for the image. The DataAnalyzer will look it up in the
            datalist to locate the image files of the dataset.
        label_key: a string that user specify for the label. The DataAnalyzer will look it up in the
            datalist to locate the label files of the dataset. If label_key is None, the DataAnalyzer
            will skip looking for labels and all label-related operations.
        do_ccp: apply the connected component algorithm to process the labels/images.
        hist_bins: list of positive integers (one for each channel) for setting the number of bins used to
            compute the histogram. Defaults to [100].
        hist_range: list of lists of two floats (one for each channel) setting the intensity range to
            compute the histogram. Defaults to [-500, 500].
        histogram_only: whether to only compute histograms. Defaults to False.

    Examples:
        .. code-block:: python

            # imports

            summarizer = SegSummarizer("image", "label")
            transform_list = [
                LoadImaged(keys=keys),
                EnsureChannelFirstd(keys=keys),  # this creates label to be (1,H,W,D)
                ToDeviced(keys=keys, device=device, non_blocking=True),
                Orientationd(keys=keys, axcodes="RAS"),
                EnsureTyped(keys=keys, data_type="tensor"),
                Lambdad(keys="label", func=lambda x: torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x),
                SqueezeDimd(keys=["label"], dim=0),
                summarizer,
            ]
            ...
            # skip some steps to set up data loader
            dataset = data.DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)
            transform = Compose(transform_list)
            stats = []
            for batch_data in dataset:
                d = transform(batch_data[0])
                stats.append(d)
            report = summarizer.summarize(stats)
    TNFstrz
str | Noneboolzlist[int] | int | Nonezlist | NoneNone)	image_key	label_keyaveragedo_ccp	hist_bins
hist_rangehistogram_onlyreturnc                   s   || _ || _|d krdgn|| _|d kr0ddgn|| _|| _g | _t   | t	|t
jd  | t	|t
jd  | js| t|t|d |d krd S | t||t|d | t|||dt||d | jdkr| t|||dt  d S )	Nd   ii  )r   )r   )r   r   r   )r   r   r   )r   r   r   r   r   summary_analyzerssuper__init__add_analyzerr   r   BY_CASE_IMAGE_PATHBY_CASE_LABEL_PATHr
   r   r   r   r   r   r   r	   )selfr   r   r   r   r   r   r   	__class__ S/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/auto3dseg/seg_summarizer.pyr   S   s.    

 

 zSegSummarizer.__init__r   zAnalyzer | None)case_analyzersummary_analyzerr   c                 C  s(   |  j |f7  _ |dk	r$| j| dS )a  
        Add new analyzers to the engine so that the callable and summarize functions will
        utilize the new analyzers for stats computations.

        Args:
            case_analyzer: analyzer that works on each data.
            summary_analyzer: analyzer that works on list of stats dict (output from case_analyzers).

        Examples:

            .. code-block:: python

                from monai.auto3dseg import Analyzer
                from monai.auto3dseg.utils import concat_val_to_np
                from monai.auto3dseg.analyzer_engine import SegSummarizer

                class UserAnalyzer(Analyzer):
                    def __init__(self, image_key="image", stats_name="user_stats"):
                        self.image_key = image_key
                        report_format = {"ndims": None}
                        super().__init__(stats_name, report_format)

                    def __call__(self, data):
                        d = dict(data)
                        report = deepcopy(self.get_report_format())
                        report["ndims"] = d[self.image_key].ndim
                        d[self.stats_name] = report
                        return d

                class UserSummaryAnalyzer(Analyzer):
                    def __init__(stats_name="user_stats"):
                        report_format = {"ndims": None}
                        super().__init__(stats_name, report_format)
                        self.update_ops("ndims", SampleOperations())

                    def __call__(self, data):
                        report = deepcopy(self.get_report_format())
                        v_np = concat_val_to_np(data, [self.stats_name, "ndims"])
                        report["ndims"] = self.ops["ndims"].evaluate(v_np)
                        return report

                summarizer = SegSummarizer()
                summarizer.add_analyzer(UserAnalyzer, UserSummaryAnalyzer)

        N)
transformsr   append)r#   r(   r)   r&   r&   r'   r    {   s    .zSegSummarizer.add_analyzerz
list[dict]zdict[str, dict])datar   c                 C  s   t |tst| j di }t|dkr.|S t |d tsXt| j dt|d  | jD ] }t|r^|	|j
||i q^|S )a  
        Summarize the input list of data and generates a report ready for json/yaml export.

        Args:
            data: a list of data dicts.

        Returns:
            a dict that summarizes the stats across data samples.

        Examples:
            stats_summary:
                image_foreground_stats:
                    intensity: {...}
                image_stats:
                    channels: {...}
                    cropped_shape: {...}
                    ...
                label_stats:
                    image_intensity: {...}
                    label:
                    - image_intensity: {...}
                    - image_intensity: {...}
                    - image_intensity: {...}
                    - image_intensity: {...}
        z4 summarize function needs input to be a list of dictr   z6 summarize function needs a list of dict. Now we have )
isinstancelist
ValueErrorr%   lendicttyper   callableupdate
stats_name)r#   r,   reportanalyzerr&   r&   r'   	summarize   s    

zSegSummarizer.summarize)TTNNF)__name__
__module____qualname____doc__r   r    r8   __classcell__r&   r&   r$   r'   r   "   s   4     "(2N)
__future__r   typingr   Zmonai.auto3dseg.analyzerr   r   r   r   r   r	   r
   r   r   r   monai.transformsr   monai.utils.enumsr   __all__r   r&   r&   r&   r'   <module>   s   0