持续学习的方法
链接
best incremental learning github
持续学习(Continual Learning)最新综述
背景与挑战
持续学习(Continual Learning, CL)旨在使模型在不断接收新任务的过程中,不遗忘旧任务的知识,同时能够高效地适应新任务。这一领域的关键挑战包括:
- 灾难性遗忘:模型在学习新任务时会覆盖旧任务的知识。
- 任务间冲突:新旧任务目标函数的梯度冲突,导致模型难以协调优化。
- 存储限制:部分场景无法存储历史数据,增加了设计的复杂性。
为了解决这些问题,研究者提出了多种技术方案,包括特征回放、提示学习(Prompting)、正则化方法以及生成模型。
模型分类与最新进展
以下是最新的持续学习模型分类及其特点,包含特征回放、生成式模型、正则化方法等不同范式。
1. 特征回放(Feature Replay)
通过保存历史任务的特征或中间表示,避免直接存储原始数据:
Experience Replay (ER)
使用历史数据的特征与新任务共同训练,简单高效。ER-ACE
优化了损失函数,避免回放数据对当前任务的干扰。Dark Experience Replay (DER) & DER++
保存网络输出的表征,用于提高特征一致性,特别是 DER++ 进一步改进了模型泛化性能。Meta-Experience Replay (MER)
基于元学习的方法优化特征回放的策略。iCaRL (Incremental Classifier and Representation Learning)
保存每个类别的特征均值,结合最近邻分类器完成任务。LiDER 系列
基于 DER++ 和 ER-ACE,优化了特征回放的选择策略。
2. 提示学习(Prompt Learning)
通过为不同任务设计动态提示,减少灾难性遗忘:
DualPrompt
结合提示学习与特征回放,在任务切换时动态调整提示。L2P (Learning to Prompt)
基于 Transformer,自动生成适应新任务的提示。CODA-Prompt
解构任务并动态生成提示,无需保存历史数据。STAR-Prompt
提出了两阶段动态提示,分别用于特征学习和分类任务。
3. 生成式模型(Generative Replay)
使用生成模型生成历史任务数据进行训练:
Continual Generative Training for Incremental Prompt-Learning (CGIL)
基于生成模型对任务特征进行增量优化。Hindsight Anchor Learning (HAL)
使用生成式方法生成任务锚点,降低任务间冲突。
4. 正则化方法(Regularization-Based Methods)
通过约束参数更新,减少灾难性遗忘:
EWC (Elastic Weight Consolidation)
通过加权惩罚约束关键参数的更新。Synaptic Intelligence (SI)
动态调整正则项的权重,以适应新任务。Function Distance Regularization (FDR)
通过约束特征空间距离,减少特征漂移。
5. 其他方法
GDumb
简单地保存少量数据用于训练,作为强基线。X-DER (eXtended-DER)
扩展了 DER 方法,结合正则化和生成式回放技术。CLIP & AttriCLIP
在多模态任务中,结合预训练模型实现持续学习。
持续学习未来方向
- 统一框架设计:整合特征回放、提示学习和生成式模型的优点。
- 小样本持续学习:在数据有限的情况下提高模型性能。
- 可解释性增强:研究参数更新和知识保存的可视化与分析。
- 多模态持续学习:扩展到图像、文本、音频等多模态数据场景。
总结
持续学习领域的研究正在快速发展,各类方法在减少灾难性遗忘、提高新任务适应性方面各有特点。结合特征回放与提示学习的技术,例如 DualPrompt 和 LiDER 系列,表现出强大的泛化能力。未来的研究将继续朝着更加高效、通用的方向发展。
Function Distance Regularization (FDR) 模型综述
Function Distance Regularization (FDR) 是持续学习中的一个关键方向,通过约束任务间特征空间的距离,减少新旧任务特征分布的漂移。以下是一些典型模型及其论文:
1. Hindsight Anchor Learning (HAL)
- 特点:
HAL 为每个任务设置“锚点”(anchor),保持新旧任务特征分布一致性,用固定参考减少特征漂移。 - 论文名称:
Hindsight Anchor Learning for Continual Learning - 技术细节:
计算当前任务特征与锚点的距离损失,优化特征空间。
2. Gradient Episodic Memory (GEM)
- 特点:
使用梯度投影技术,在优化新任务时,确保特征更新不会破坏旧任务。 - 论文名称:
Gradient Episodic Memory for Continual Learning - 技术细节:
约束新任务梯度方向,使其与旧任务梯度变化一致,从而优化特征分布。
3. Meta-Experience Replay (MER)
- 特点:
结合经验回放和元学习,对齐任务间特征空间的变化,减少遗忘。 - 论文名称:
Meta-Experience Replay for Continual Learning - 技术细节:
动态调整样本权重,优化新旧任务特征距离。
4. Synaptic Intelligence (SI)
- 特点:
使用重要性权重约束任务特征更新,对旧任务重要的特征赋予更高正则化权重。 - 论文名称:
Continual Learning Through Synaptic Intelligence - 技术细节:
计算参数更新的重要性,减少对旧任务关键特征的干扰。
5. Regular Polytope Classifier (RPC)
- 特点:
强调特征点在几何空间中的一致性,利用正则化约束新旧任务特征在多边形空间中的投影。 - 论文名称:
Regular Polytope Networks for Incremental Learning - 技术细节:
通过几何对齐优化特征分布,减少新旧任务分类界限偏移。
以上模型通过不同方式优化特征分布,减少新任务学习对旧任务的干扰,是 Function Distance Regularization 的典型实现。
三种增量学习的区别
模型明确知道当前评分属于某一个特定区间的含义
定义
当模型明确知道当前评分属于某一个特定区间时,表示在训练或测试过程中,每次输入数据时,模型会接收到一个“任务标记”或“区间信息”,明确告知模型当前任务的范围。
特点
明确任务边界
- 模型知道当前任务对应的评分区间,例如 “0-20”。
- 模型只需处理这个范围内的数据,其他区间的数据对当前任务无影响。
任务切换
- 如果切换到 “20-40” 区间,则任务标记会更新。
- 模型能够根据任务标记进行专门的优化,无需处理全局关系。
简化学习
- 模型不需要自行判断数据属于哪个区间,任务标记提供了显式的指导。
- 减少了任务间的歧义和干扰。
典型应用
- 任务增量学习(Task-Incremental Learning, TIL)
- 通过任务标记显式区分任务。
- 每个任务相对独立,模型在不同任务间的干扰较少。
- 任务增量学习(Task-Incremental Learning, TIL)
域增量学习
定义
- 域增量学习(Domain-Incremental Learning, DIL)中,模型接收不同领域(如不同数据分布)的数据流,并试图学习适应这些变化的能力。
- 但模型无法通过显式任务标记识别当前数据来自哪个领域。
特点
- 模型必须对不同领域的特征分布自适应,而不依赖于任务提示。
- 比如动作质量评估中,评分数据可能来自不同相机的拍摄,域增量学习要求模型跨相机(领域)泛化。
对比
- 与任务增量不同,域增量学习不提供显式标记,强调跨领域一致性。
- 与类增量相似,需自主应对数据分布变化,但着重领域变化而非类别增加。
示例
在动作质量评估(AQA)任务中:
- 任务增量学习:每个评分区间(如
0-20
、20-40
)是独立任务,提供区间标记。 - 域增量学习:评分区间相同,但数据分布因拍摄设备、场景等变化,模型需跨领域泛化。
- 类增量学习:评分范围内新增类别(如从动作类型 A 增加到类型 B),需对新类别进行学习。
总结
学习类别 | 特点 | 应用场景 |
---|---|---|
类增量学习 | 增加新类别,模型需学习新知识,不忘旧知识 | 分类任务,如新增动作类型 |
任务增量学习 | 明确任务边界,任务间独立 | 评分区间划分,明确每段评分区间 |
域增量学习 | 数据分布变化,需跨领域适应 | 同评分范围内,不同设备、场景或环境的数据分布差异 |
Mammoth框架
配置解析
parse_args解析配置:
add_initial_args
:
--dataset |
add_configuration_args
:
--dataset_config: |
type=field_with_aliases({‘default’: [‘base’, ‘default’], ‘best’: [‘best’]}):别名连接
load_configs
:加载配置文件
get_model_class,返回数据集类型,dataset_name 就是datasets/{name}.py。
get_default_args_for_dataset(args.dataset):加载数据集的参数,在数据集中定义。
如果模型参数中没有dataset_config,就使用数据集中的参数。
可以指定参数名称,在数据集文件夹下,否则使用default.yaml。
最后返回数据集,模型,基本参数。
add_dynamic_parsable_args(parser, args.dataset, backbone)
:动态加载参数。
如果参数是字典类型,当做参数解析出来。
add_management_args(parser)
:添加管理参数
add_experiment_args(parser)
: 添加实验参数
update_cli_defaults(parser, config)
:配置文件加载参数
get_dataset获取数据类:
extend_args(args, dataset
:加载数据类的额外参数
backbone = get_backbone(args):
基础参数要在命令行中输入:
sys.argv = ['CAQA', |
datasets
class_rg
修改dataload,改为自己的数据加载器。