U
    Ph%                     @  sX   d dl mZ d dlmZmZ d dlZd dlmZmZ dgZ	dddd	d
dddddZ
dS )    )annotations)CallableSequenceN)ensure_tupleensure_tuple_repgenerate_param_groupsTztorch.nn.ModulezSequence[Callable]zSequence[str]zSequence[float]boolz
list[dict])networklayer_matchesmatch_types	lr_valuesinclude_othersreturnc                   s   t |}t|t|}t|t|}fdd}fdd}g }g  t|||D ]n\}}	}
|	 dkrp||}n&|	 dkr||}ntd|	 d|| |
d	  d
d | D  qP|r|dt fdd	 i |S )a  
    Utility function to generate parameter groups with different LR values for optimizer.
    The output parameter groups have the same order as `layer_match` functions.

    Args:
        network: source network to generate parameter groups from.
        layer_matches: a list of callable functions to select or filter out network layer groups,
            for "select" type, the input will be the `network`, for "filter" type,
            the input will be every item of `network.named_parameters()`.
            for "select", the parameters will be
            `select_func(network).parameters()`.
            for "filter", the parameters will be
            `(x[1] for x in filter(f, network.named_parameters()))`
        match_types: a list of tags to identify the matching type corresponding to the `layer_matches` functions,
            can be "select" or "filter".
        lr_values: a list of LR values corresponding to the `layer_matches` functions.
        include_others: whether to include the rest layers as the last group, default to True.

    It's mainly used to set different LR values for different network elements, for example:

    .. code-block:: python

        net = Unet(spatial_dims=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1])
        print(net)  # print out network components to select expected items
        print(net.named_parameters())  # print out all the named parameters to filter out expected items
        params = generate_param_groups(
            network=net,
            layer_matches=[lambda x: x.model[0], lambda x: "2.0.conv" in x[0]],
            match_types=["select", "filter"],
            lr_values=[1e-2, 1e-3],
        )
        # the groups will be a list of dictionaries:
        # [{'params': <generator object Module.parameters at 0x7f9090a70bf8>, 'lr': 0.01},
        #  {'params': <filter object at 0x7f9088fd0dd8>, 'lr': 0.001},
        #  {'params': <filter object at 0x7f9088fd0da0>}]
        optimizer = torch.optim.Adam(params, 1e-4)

    c                   s    fdd}|S )Nc                     s      S N)
parameters fr	   r   K/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/optimizers/utils.py_selectJ   s    z;generate_param_groups.<locals>._get_select.<locals>._selectr   )r   r   r	   r   r   _get_selectH   s    z*generate_param_groups.<locals>._get_selectc                   s    fdd}|S )Nc                     s   dd t   D S )Nc                 s  s   | ]}|d  V  qdS )   Nr   .0xr   r   r   	<genexpr>S   s     zNgenerate_param_groups.<locals>._get_filter.<locals>._filter.<locals>.<genexpr>)filternamed_parametersr   r   r   r   _filterQ   s    z;generate_param_groups.<locals>._get_filter.<locals>._filterr   )r   r    r   r   r   _get_filterO   s    z*generate_param_groups.<locals>._get_filterselectr   zunsupported layer match type: .)paramslrc                 S  s   g | ]}t |qS r   idr   r   r   r   
<listcomp>b   s     z)generate_param_groups.<locals>.<listcomp>r$   c                   s   t |  kS r   r&   )p)_layersr   r   <lambda>e       z'generate_param_groups.<locals>.<lambda>)
r   r   lenziplower
ValueErrorappendextendr   r   )r	   r
   r   r   r   r   r!   r$   functyr%   layer_paramsr   )r*   r	   r   r      s$    -

 )T)
__future__r   collections.abcr   r   torchmonai.utilsr   r   __all__r   r   r   r   r   <module>   s    