Keras中的MultiStepLR
【摘要】 Keras中没有多步调整学习率(MultiStepLR)的调度器,但是博主这里提供一个自己写的:
1.代码
from tensorflow.python.keras.callbacks import Callbackfrom tensorflow.python.keras import backend as Kimport numpy as npimport argpar...
Keras中没有多步调整学习率(MultiStepLR)的调度器,但是博主这里提供一个自己写的:
1.代码
-
from tensorflow.python.keras.callbacks import Callback
-
from tensorflow.python.keras import backend as K
-
import numpy as np
-
import argparse
-
-
-
parser = argparse.ArgumentParser()
-
parser.add_argument('--lr_decay_epochs', type=list, default=[2, 5, 7], help="For MultiFactorScheduler step")
-
parser.add_argument('--lr_decay_factor', type=float, default=0.1)
-
args, _ = parser.parse_known_args()
-
-
-
def get_lr_scheduler(args):
-
lr_scheduler = MultiStepLR(args=args)
-
return lr_scheduler
-
-
-
class MultiStepLR(Callback):
-
"""Learning rate scheduler.
-
-
Arguments:
-
args: parser_setting
-
verbose: int. 0: quiet, 1: update messages.
-
"""
-
-
def __init__(self, args, verbose=0):
-
super(MultiStepLR, self).__init__()
-
self.args = args
-
self.steps = args.lr_decay_epochs
-
self.factor = args.lr_decay_factor
-
self.verbose = verbose
-
-
def on_epoch_begin(self, epoch, logs=None):
-
if not hasattr(self.model.optimizer, 'lr'):
-
raise ValueError('Optimizer must have a "lr" attribute.')
-
lr = self.schedule(epoch)
-
if not isinstance(lr, (float, np.float32, np.float64)):
-
raise ValueError('The output of the "schedule" function '
-
'should be float.')
-
K.set_value(self.model.optimizer.lr, lr)
-
print("learning rate: {:.7f}".format(K.get_value(self.model.optimizer.lr)).rstrip('0'))
-
if self.verbose > 0:
-
print('\nEpoch %05d: MultiStepLR reducing learning '
-
'rate to %s.' % (epoch + 1, lr))
-
-
def schedule(self, epoch):
-
lr = K.get_value(self.model.optimizer.lr)
-
for i in range(len(self.steps)):
-
if epoch == self.steps[i]:
-
lr = lr * self.factor
-
-
return lr
2.调用(callbacks里append这个lr_scheduler,fit_generator里callbacks传入这个变量)
-
callbacks = []
-
lr_scheduler = get_lr_scheduler(args=args)
-
callbacks.append(lr_scheduler)
-
-
...
-
model.fit_generator(train_generator,
-
steps_per_epoch=train_generator.samples // args.batch_size,
-
validation_data=test_generator,
-
validation_steps=test_generator.samples // args.batch_size,
-
workers=args.num_workers,
-
callbacks=callbacks, # 你的callbacks, 包含了lr_scheduler
-
epochs=args.epochs,
-
)
大家可以拿去用~
文章来源: nickhuang1996.blog.csdn.net,作者:悲恋花丶无心之人,版权归原作者所有,如需转载,请联系作者。
原文链接:nickhuang1996.blog.csdn.net/article/details/103645204
【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)