o
    -i%                     @  sL   d dl mZ d dlmZmZ d dlZd dlmZmZ dgZ		ddddZ
dS )    )annotations)CallableSequenceN)ensure_tupleensure_tuple_repgenerate_param_groupsTnetworktorch.nn.Modulelayer_matchesSequence[Callable]match_typesSequence[str]	lr_valuesSequence[float]include_othersboolreturn
list[dict]c                   s   t |}t|t|}t|t|}fdd}fdd}g }g  t|||D ]7\}}	}
|	 dkr8||}n|	 dkrC||}ntd|	 d|| |
d	  d
d | D  q(|rr|dt fdd	 i |S )a  
    Utility function to generate parameter groups with different LR values for optimizer.
    The output parameter groups have the same order as `layer_match` functions.

    Args:
        network: source network to generate parameter groups from.
        layer_matches: a list of callable functions to select or filter out network layer groups,
            for "select" type, the input will be the `network`, for "filter" type,
            the input will be every item of `network.named_parameters()`.
            for "select", the parameters will be
            `select_func(network).parameters()`.
            for "filter", the parameters will be
            `(x[1] for x in filter(f, network.named_parameters()))`
        match_types: a list of tags to identify the matching type corresponding to the `layer_matches` functions,
            can be "select" or "filter".
        lr_values: a list of LR values corresponding to the `layer_matches` functions.
        include_others: whether to include the rest layers as the last group, default to True.

    It's mainly used to set different LR values for different network elements, for example:

    .. code-block:: python

        net = Unet(spatial_dims=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1])
        print(net)  # print out network components to select expected items
        print(net.named_parameters())  # print out all the named parameters to filter out expected items
        params = generate_param_groups(
            network=net,
            layer_matches=[lambda x: x.model[0], lambda x: "2.0.conv" in x[0]],
            match_types=["select", "filter"],
            lr_values=[1e-2, 1e-3],
        )
        # the groups will be a list of dictionaries:
        # [{'params': <generator object Module.parameters at 0x7f9090a70bf8>, 'lr': 0.01},
        #  {'params': <filter object at 0x7f9088fd0dd8>, 'lr': 0.001},
        #  {'params': <filter object at 0x7f9088fd0da0>}]
        optimizer = torch.optim.Adam(params, 1e-4)

    c                       fdd}|S )Nc                     s      S N)
parameters fr   r   X/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/optimizers/utils.py_selectJ   s   z;generate_param_groups.<locals>._get_select.<locals>._selectr   )r   r   r   r   r   _get_selectH   s   z*generate_param_groups.<locals>._get_selectc                   r   )Nc                     s   dd t   D S )Nc                 s  s    | ]}|d  V  qdS )   Nr   .0xr   r   r   	<genexpr>S   s    zNgenerate_param_groups.<locals>._get_filter.<locals>._filter.<locals>.<genexpr>)filternamed_parametersr   r   r   r   _filterQ   s   z;generate_param_groups.<locals>._get_filter.<locals>._filterr   )r   r&   r   r   r   _get_filterO   s   z*generate_param_groups.<locals>._get_filterselectr$   zunsupported layer match type: .)paramslrc                 S  s   g | ]}t |qS r   idr    r   r   r   
<listcomp>b   s    z)generate_param_groups.<locals>.<listcomp>r*   c                   s   t |  vS r   r,   )p)_layersr   r   <lambda>e   s    z'generate_param_groups.<locals>.<lambda>)
r   r   lenziplower
ValueErrorappendextendr$   r   )r   r
   r   r   r   r   r'   r*   functyr+   layer_paramsr   )r0   r   r   r      s$   -

 )T)r   r	   r
   r   r   r   r   r   r   r   r   r   )
__future__r   collections.abcr   r   torchmonai.utilsr   r   __all__r   r   r   r   r   <module>   s   