Некоторые оптимизаторы не включают свои имена в конфиги.
Вот полный пример того, как получить конфигурации и как реконструировать (т.е. клонировать) оптимизатор из их конфигураций (включая скорость обучения).
import keras.optimizers as opt
def get_opt_config(optimizer):
"""
Extract Optimizer Configs from an instance of
keras Optimizer
:param optimizer: instance of keras Optimizer.
:return: dict of optimizer configs.
"""
if not isinstance(optimizer, opt.Optimizer):
raise TypeError('optimizer should be instance of '
'keras.optimizers.Optimizer '
'Got {}.'.format(type(optimizer)))
opt_config = optimizer.get_config()
if 'name' not in opt_config.keys():
_name = str(optimizer.__class__).split('.')[-1] \
.replace('\'', '').replace('>', '')
opt_config.update({'name': _name})
return opt_config
def clone_opt(opt_config):
"""
Clone keras optimizer from its configurations.
:param opt_config: dict, keras optimizer configs.
:return: instance of keras optimizer.
"""
if not isinstance(opt_config, dict):
raise TypeError('opt_config must be a dict. '
'Got {}'.format(type(opt_config)))
if 'name' not in opt_config.keys():
raise ValueError('could not find the name of optimizer in opt_config')
name = opt_config.get('name')
params = {k: opt_config[k] for k in opt_config.keys() if k != 'name'}
if name.upper() == 'ADAM':
return opt.Adam(**params)
if name.upper() == 'NADAM':
return opt.Nadam(**params)
if name.upper() == 'ADAMAX':
return opt.Adamax(**params)
if name.upper() == 'ADADELTA':
return opt.Adadelta(**params)
if name.upper() == 'ADAGRAD':
return opt.Adagrad(**params)
if name.upper() == 'RMSPROP':
return opt.RMSprop()
if name.upper() == 'SGD':
return opt.SGD(**params)
raise ValueError('Unknown optimizer name. Available are: '
'(\'adam\',\'sgd\', \'rmsprop\', \'adagrad\', '
'\'adadelta\', \'adamax\', \'nadam\'). '
'Got {}.'.format(name))
Контрольная работа
if __name__ == '__main__':
rmsprop = opt.RMSprop()
configs = get_opt_config(rmsprop)
print(configs)
cloned_rmsprop = clone_opt(configs)
print(cloned_rmsprop)
print(cloned_rmsprop.get_config())
Выходы
{'lr': 0.0010000000474974513, 'rho': 0.8999999761581421, 'decay': 0.0, 'epsilon': 1e-07, 'name': 'RMSprop'}
<keras.optimizers.RMSprop object at 0x7f96370a9358>
{'lr': 0.0010000000474974513, 'rho': 0.8999999761581421, 'decay': 0.0, 'epsilon': 1e-07}
person
Yahya
schedule
09.10.2019