Keras中的MultiStepLR

举报
悲恋花丶无心之人 发表于 2021/02/03 01:14:15 2021/02/03
【摘要】 Keras中没有多步调整学习率(MultiStepLR)的调度器,但是博主这里提供一个自己写的: 1.代码 from tensorflow.python.keras.callbacks import Callbackfrom tensorflow.python.keras import backend as Kimport numpy as npimport argpar...

Keras中没有多步调整学习率(MultiStepLR)的调度器,但是博主这里提供一个自己写的:

1.代码


  
  1. from tensorflow.python.keras.callbacks import Callback
  2. from tensorflow.python.keras import backend as K
  3. import numpy as np
  4. import argparse
  5. parser = argparse.ArgumentParser()
  6. parser.add_argument('--lr_decay_epochs', type=list, default=[2, 5, 7], help="For MultiFactorScheduler step")
  7. parser.add_argument('--lr_decay_factor', type=float, default=0.1)
  8. args, _ = parser.parse_known_args()
  9. def get_lr_scheduler(args):
  10. lr_scheduler = MultiStepLR(args=args)
  11. return lr_scheduler
  12. class MultiStepLR(Callback):
  13. """Learning rate scheduler.
  14. Arguments:
  15. args: parser_setting
  16. verbose: int. 0: quiet, 1: update messages.
  17. """
  18. def __init__(self, args, verbose=0):
  19. super(MultiStepLR, self).__init__()
  20. self.args = args
  21. self.steps = args.lr_decay_epochs
  22. self.factor = args.lr_decay_factor
  23. self.verbose = verbose
  24. def on_epoch_begin(self, epoch, logs=None):
  25. if not hasattr(self.model.optimizer, 'lr'):
  26. raise ValueError('Optimizer must have a "lr" attribute.')
  27. lr = self.schedule(epoch)
  28. if not isinstance(lr, (float, np.float32, np.float64)):
  29. raise ValueError('The output of the "schedule" function '
  30. 'should be float.')
  31. K.set_value(self.model.optimizer.lr, lr)
  32. print("learning rate: {:.7f}".format(K.get_value(self.model.optimizer.lr)).rstrip('0'))
  33. if self.verbose > 0:
  34. print('\nEpoch %05d: MultiStepLR reducing learning '
  35. 'rate to %s.' % (epoch + 1, lr))
  36. def schedule(self, epoch):
  37. lr = K.get_value(self.model.optimizer.lr)
  38. for i in range(len(self.steps)):
  39. if epoch == self.steps[i]:
  40. lr = lr * self.factor
  41. return lr

2.调用(callbacks里append这个lr_scheduler,fit_generator里callbacks传入这个变量)


  
  1. callbacks = []
  2. lr_scheduler = get_lr_scheduler(args=args)
  3. callbacks.append(lr_scheduler)
  4. ...
  5. model.fit_generator(train_generator,
  6. steps_per_epoch=train_generator.samples // args.batch_size,
  7. validation_data=test_generator,
  8. validation_steps=test_generator.samples // args.batch_size,
  9. workers=args.num_workers,
  10. callbacks=callbacks, # 你的callbacks, 包含了lr_scheduler
  11. epochs=args.epochs,
  12. )

大家可以拿去用~

文章来源: nickhuang1996.blog.csdn.net,作者:悲恋花丶无心之人,版权归原作者所有,如需转载,请联系作者。

原文链接:nickhuang1996.blog.csdn.net/article/details/103645204

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。