U
    Ph                     @  sR   d dl mZ d dlZd dlmZ d dlZd dlmZ d dl	m
Z
 G dd dZdS )    )annotationsN)Any)NdarrayOrTensorc                   @  sd   e Zd ZdZddddZddddZdd	d
dddZdd	d
dddZdddddddZdS )CumulativeAveragea  
    A utility class to keep track of average values. For example during training/validation loop,
    we need to accumulate the per-batch metrics and calculate the final average value for the whole dataset.
    When training in multi-gpu environment, with DistributedDataParallel, it will average across the processes.

    Example:

    .. code-block:: python

        from monai.metrics import CumulativeAverage

        run_avg = CumulativeAverage()
        batch_size = 8
        for i in range(len(train_set)):
            ...
            val = calc_metric(x,y) #some metric value
            run_avg.append(val, count=batch_size)

        val_avg = run_avg.aggregate() #average value

    None)returnc                 C  s   |    d S )N)resetself r   U/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/metrics/cumulative_average.py__init__.   s    zCumulativeAverage.__init__c                 C  s@   d| _ tjdtjd| _tjdtjd| _t o8t | _	dS )z"
        Reset all  stats
        Nr   dtype)
valtorchtensorfloatsumcountdistis_availableis_initializedis_distributedr	   r   r   r   r   1   s    zCumulativeAverage.resetTboolr   )to_numpyr   c                 C  sX   | j dkrdS | j  }d|t| < | jrD|t  }t| |rT| 	 }|S )z
        returns the most recent value (averaged across processes)

        Args:
            to_numpy: whether to convert to numpy array. Defaults to True
        Nr   )
r   cloner   isfiniter   r   get_world_size
all_reducecpunumpy)r
   r   r   r   r   r   get_current:   s    


zCumulativeAverage.get_currentc                 C  s~   | j dkrdS | j}| j}| jrT|j| j dd}|j| j dd}t| t| t|dk|| |}|rz|	 
 }|S )z
        returns the total average value (averaged across processes)

        Args:
            to_numpy: whether to convert to numpy array. Defaults to True
        Nr   T)copy)r   r   r   r   tor   r   r   wherer    r!   )r
   r   r   r   r   r   r   r   	aggregateP   s    


zCumulativeAverage.aggregate   r   z
Any | None)r   r   r   c                 C  s   t j|t jd| _| jjr*| j  | _t j|t jdd}|jdkrn|j| jjkrnt	d| d| j
  || j
  }t |}t |std| d|  t ||t |}t ||t |}| j| | _| j| | _d	S )
a  
        Append with a new value, and an optional count. Any data type is supported that is convertable
            with torch.as_tensor() e.g. number, list, numpy array, or Tensor.

        Args:
            val: value (e.g. number, list, numpy array or Tensor) to keep track of
            count: count (e.g. number, list, numpy array or Tensor), to update the contribution count

        For example:
            # a simple constant tracking
            avg = CumulativeAverage()
            avg.append(0.6)
            avg.append(0.8)
            print(avg.aggregate()) #prints 0.7

            # an array tracking, e.g. metrics from 3 classes
            avg= CumulativeAverage()
            avg.append([0.2, 0.4, 0.4])
            avg.append([0.4, 0.6, 0.4])
            print(avg.aggregate()) #prints [0.3, 0.5. 0.4]

            # different contributions / counts
            avg= CumulativeAverage()
            avg.append(1, count=4) #avg metric 1 coming from a batch of 4
            avg.append(2, count=6) #avg metric 2 coming from a batch of 6
            print(avg.aggregate()) #prints 1.6 == (1*4 +2*6)/(4+6)

            # different contributions / counts
            avg= CumulativeAverage()
            avg.append([0.5, 0.5, 0], count=[1, 1, 0]) # last elements count is zero to ignore it
            avg.append([0.5, 0.5, 0.5], count=[1, 1, 1]) #
            print(avg.aggregate()) #prints [0.5, 0.5, 0,5] == ([0.5, 0.5, 0] + [0.5, 0.5, 0.5]) / ([1, 1, 0] + [1, 1, 1])

        r   r    )r   devicer   zCCount shape must match val shape, unless count is a single number: z val z!non-finite inputs received: val: z	, count: N)r   	as_tensorr   r   requires_graddetachr   ndimshape
ValueErrorr    r   allwarningswarnr%   
zeros_liker   r   )r
   r   r   Znfinr   r   r   appendi   s     #

zCumulativeAverage.appendN)T)T)r'   )	__name__
__module____qualname____doc__r   r   r"   r&   r3   r   r   r   r   r      s   	r   )
__future__r   r0   typingr   r   torch.distributeddistributedr   monai.configr   r   r   r   r   r   <module>   s   