-
Notifications
You must be signed in to change notification settings - Fork 1
Description
1. 概要(基本アイデア)
従来法:特徴セットからglobal contextを取得するためにAttentionを使って色々提案している
提案法:global context = 特徴セットの低ランク表現では?→行列分解をNNの中に入れ込んで学習
- 行列分解(MD)をNNの中で実行
- MDの最適化処理を,RNNで出てくる計算グラフのようにとらえて,BPTTで一貫学習
- 加えて,t(最適化の繰り返し回数)-> \infの時勾配消失が発生するので,回避するためにOneStep Gradientを提案
2. 新規性
- 行列分解によってglobal contextを取得する
- NN + 行列分解の処理をBPTTによって一貫学習可能にする
- 勾配消失を避ける手法を提案する
3. 手法詳細
行列分解(低ランク表現)によるglobal contextの取得
Forward
- 入力:特徴行列 X \in R^{d\times n}
- 出力:D C, D\in R{d\times r}, C \in R^{r\times n}
- 行列分解:D,C = argmin_{D,C} L(X, DC) + R1(D) + R2(C)
L: 損失関数(大体再構成誤差),R1,R2はそれぞれの行列に対する制約項
例えば,
- D,Cに非負制約を加えたら,NMF
- (C)_i = e_iとしたら,Vector Quantization
Vector Quantization, NMFの場合はAlgorithm1,2 にある繰り返し処理で最適化できる
Backward
Algorithm1,2 の繰り返し処理によって計算グラフが生成されている.
BPTTを使うと一貫学習できる.
一方,t->\infとすると,勾配消失やランク落ち行列の逆行列が出てくる.(Proposition1,2,3周辺)
本来の勾配(式12)を線形近似(式13, One Step Gradient)して学習する.
Hamburger
提案法全体の名前はHamburger.
名前の由来は,行列分解(Ham)の前後に線形変換(burgers)をつけるから.
ということで,全体の処理は下記(入力:XはCNNなどによって得られる特徴行列)
- Z1 = W1 X (線形変換, lower burger)
- Z2 = DC (Z1の低ランク表現, ham)
- Z3 = W2 Z2 (線形変換,upper burger)
- 最終出力= X + BN(Z3) (バッチ正規化+スキップコネクション)
4. 結果
使ったNNの構成
- ResNet 50で特徴抽出
- 3x3conv + BN + ReLUで2048チャネルを512チャネルに削減
- Hamburger
Sec 3.1 & 3.3
PASCAL VOCを使ったsemantic segmentationで下記を検証
- d=1024, r=8程度で良い結果
- MDの繰り返し回数は6あれば十分
- Vector QuantizationよりNMFの方がいい
- Attention系と比べて,パラメータ,GPU Load, GPU time全て低い
Sec 3.4
PASCAL VOC, PASCAL Contextでsemantic segmentation
結果従来用より良い結果(mIoU)
Sec 3.5
画像生成でもいい結果(FID)
5. 論文,コード等へのリンク
論文:https://openreview.net/forum?id=1FvkSpWosOl
コード:https://github.com/Gsunshine/Enjoy-Hamburger
6. 感想,コメント
行列分解の入れ込み方がすごい
OneStep Gradientは汎用性が高そうかつ実装が容易なので使えそう
名前がHamburger
7. bibtex
@inproceedings{
geng2021is,
title={Is Attention Better Than Matrix Decomposition?},
author={Zhengyang Geng and Meng-Hao Guo and Hongxu Chen and Xia Li and Ke Wei and Zhouchen Lin},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=1FvkSpWosOl}
}