논문 제목 : Class-aware Information for Logit-based Knowledge Distillation
컨퍼런스 : ??
저자 : Shuoxi Zhang, Hanpeng Liu
대학 : School of Computer Science and Technology Wuhan
초록
- 지금까지의 logit-based distillation은 instance level에서 다루었다면, 논문은 다른 의미적인 정보들을 간과했던 것들을 관찰해보려고함
- 논문은 간과점 문제를 다루기 위해 Class aware Logit KD(CLKD)를 제안함. 이는 instance-level과 class-level을 동시에 logit distillation하기 위한 모듈임
- CLKD는 distillation performance를 높이면서 Teacher로부터의 상당한 의미적인 정보를 흉내낼 수 있음
- 또한, 논문은 Class Correlation Loss라는 것을 제안하여 Teacher의 내재된 class-level 상관관계를 Student에 주입시키도록 할 것임
서론
- KD의 접근으로 logit-based / feature-based로 쪼개져나옴
- 이때 핵심 아이디어는 vanilla KD를 기반으로 다져지는데, 이는 logit-based 에 속함.
- soft label 확률 분포는 ground-truth 레이블 단독으로 쓰는 것보다 많은 정보를 포함하고 있음
- 이는 student model이 distillation 이후에도 성능이 유지되게끔 도움을 줌.
- 현존하는 연구들(23, 41)은 logit 기반의 distillation의 약함을 위해 logit으로부터의 정보 추출의 개선방안들을 제안함
- 23 : Seyed Iman Mirzadeh, Mehrdad Farajtabar, Ang Li, Nir Levine, Akihiro Matsukawa, and Hassan Ghasemzadeh. Improved knowledge distillation via teacher assistant. In AAAI, pages 5191–5198, 2020.
- 41 : Borui Zhao, Quan Cui, Renjie Song, Yiyu Qiu, and Jiajun Liang. Decoupled knowledge distillation. In CVPR, pages 11953–11962, 2022.
- 본 논문에서는 위와 같은 현존하는 연구들은 오직 instance-level 정보만 사용함을 인지함
- 그 instance-level은 특정 이미지에 대한 확률 분포들이 나오게 되는데, 이를 inter-instance knowledge 라고 부른다고 함
- 위의 고찰을 토대로, 논문은 간단하나 효과적인 logit distillation method를 제안. 이는 class-aware logit KD(CLKD)라고 부르며, inter-instance와 intra-instance knowledge를 짬뽕시켰음
- 그리고 Class representation에 의해 클래스 상관관계 정보가 교정이 됨
- CLKD 그림
- inter-instance knowledge 추출을 위해 logit matrix를 역전시킴
- 이는 intra-instance 정보의 보완적인 내용(?) 이라카네
- class correlation loss를 통해 Teacher의 class correlation을 Student가 따라할 수 있도록함
- 근데 그림을 보면, Teacher와 Student의 예측 확률의 차이 폭이 너무 큼. 따라서, 논문은 그 차이 폭을 줄이기 위해 normalized metric 디자인함
- inter-instance knowledge 추출을 위해 logit matrix를 역전시킴
- 기존 logit-based 관련 연구들 중 Chen 저자의 논문은 teacher의 분류기를 도입해서 Student와 Teacher간 표현력 차이를 좁히려고 했으나, 이는 Classifier 사용에 대한 비용 및 지식 증류의 원칙을 위배하는 사항이 됨에 따라 그렇게 효과적이진 않음
- 본 논문의 핵심 아이디어는 이전 연구인 Contrastive Clustering(CC)와 연관있다고 함. CC는 cluster 표현으로 소개되었으며, mode collapse를 피하기 위한 instance-wise contrastive learning이다.
본론
- vanilla kd에 대해 소개를 한 뒤, Teacher Knowledge를 효과적으로 파악하기 위해, 출력 행렬을 역전(transpose) 시켜 class-instance 의존성을 포착하려고 함.
- Teacher와 Student의 출력 표현의 크기차를 위해 Normalized Mean-Sqaured Error(NMSE)를 제안함.
- Correlation Congruence의 아이디어를 따와서 class correlation information을 파악하기 위한 전략을 제안함.
- 논문은 또한, feature-based distillation에 적용할 수 있었다고 함
Vanilla Knowledge Distillation
- vanilla KD는 soft logit 확률 분포를 전이시키는 과정으로 loss를 아래와 같이 표현함
$$ \mathcal{L} = (1-\alpha)\mathcal{L}{CE}(\sigma(z_S, y))+ \alpha\mathcal{L}{KD}(z_S, z_T), $$
- vanilla KD는 temperature-scaled KL divergence로 softhend categorical probability 분포를 흉내내게끔 하기 위해 아래와 같이 KD loss를 표현함
$$ \mathcal{L}{KD} = \gamma^2\mathcal{L}{KL}(\sigma(z_S/\gamma), \sigma(z_T/\gamma)) $$
- temperature 상수가 1이면, 그냥 softmax가 됨. 상수가 커지면 커질수록 더 분포도는 softer해짐.
- vanilla KD는 하나의 이미지에 대한 Teacher의 예측은 instance-level class 유사성 지식을 담고 있다고 함
- 본 논문에서는 vanilla KD가 오직 instance level만 이용하고 있음에 따라 완전히 이용해먹을 수 있다는 것을 알아챔
- 이를 위해 논문은 2가지 관점으로부터 output 행렬을 학습하는 목표를 세움
- instance-wise 확률 분포
- 더 높은 의미적인 수준에서의 지식 전이를 위한 범주형적인 표현
Motivation
- instance(data)의 mini-batch B 가 주어졌을때 logit matrix B x C ( B : 배치 데이터 개수, C : 범주형 개수) 가 만들어진다고 하자, 여기서 행 벡터는 instance-level의 범주형 확률 예측값이다. vanilla KD는 Teacher의 행에 대한 logit matrix를 통해 범주형 확률 정보를 포착한다.
- 하지만, logit matrix의 열벡터는 instance 사이의 class-level 정보를 포함하고 있음.
- 현존하는 logit 방법론들은 하나의 이미지에 대한 관계성을 표현하는 반면, 다른 이미지 사이의 의미적인 의존성(semantic dependency)를 간과함
- 이를 기반으로 논문은 instance-wise level을 넘어서 logit matrix를 이용하는 새로운 KD 방법을 제안함
- 특히, 논문은 instance-wise 정보와 class-wise 정보를 이용해서 Teacher의 지식 정보를 더 포괄적이게 하고자함
Class Distillation
- mini-batch 데이터가 주어졌을 때 softmax 전의 logit matrix를 따로 저장함. 그런 다음 각 logit matrix의 열 벡터는 클래스에 일치하는 표현으로 간주함.
- 각 열 벡터는 instance와 class 사이의 유사성을 표현하며, 이를 class-instance similiarity knowledge 라고함
- 이를 구현하기 위해, matrix transposition을 수행하며, 이를 통해 inter-instance 정보는 Student가 Teacher의 비슷한 클래스 표현을 생성할 수 있도록 강요받게끔 전이된다.
- 이때의 Loss Function은 아래와 같이 표현됨
$$ \mathcal{L}{KD} = \mathcal{L}{ins} + \beta\mathcal{L}_{cla}, $$
$$ \mathcal{L}{ins} = \mathcal{L}{NMSE}(Z_s, Z_t), $$
$$ \mathcal{L}{cla} = \mathcal{L}{NMSE}(norm^T(Z_s), norm^T(Z_T)), $$
$$ \mathcal{L}_{NMSE}(p, z) = \left\lVert \frac{{p}}{\left\lVert p \right\rVert_2} - \frac{{z}}{\left\lVert z \right\rVert_2} \right\rVert^2_2$$
- L_ins는 instance-wise distillation loss이고, L_cla는 논문에서 제안하는 class-wise distillation loss임. 계수 beta는 두 loss 간 trade-off임
- p와 z는 서로 다른 encodings, distribution을 이야기하는 건데, 여기서는 teacher와 student의 logit을 이야기하는 것임.
- 다음으로 transposition 하기전의 normalization(L2 norm)은 서로 다른 행의 크기 차에 의한 클래스 표현력 손상 완화를 위해 사용함
- Kim 저자는 logit matching 전략으로 logit 사이의 MSE를 활용했는데 이는 KL loss에 비해 좋았다고 함
- 다만, 불가피한 모델 크기에 의한 차이 및 Student와 Teacher 사이의 파라미터 수, 출력의 norm은 비슷한 크기를 공유하기에는 어려움이 있으며, 결국 간단한 logit matching에 국한되어 사용됨
- 따라서, MSE를 계산하기 전에 logit을 정규화함
Class Correlation Loss
- logit transposition으로 class representation을 인코딩했는데, 논문에서의 class representation은 다른 클래스 사이의 상관관계를 담고 있음.
- 그럼 이를 어떻게 Student의 클래스 상관관계를 학습시킬 수 있을까?
- Correlation Congruence에 영감을 받아 instance correlation 포착 방법론을 제안함.
$$ \mathcal{B}(Z) = \frac{1}{C-1}\sum^C_{j=1} (Z._j-\bar{Z})^T (Z._j-\bar{Z}), $$
- 위 수식에서 \bar(Z)는 colum벡터의 평균을 의미한다. 그리고 Z.j는 j번째 column 벡터를 의미한다. 위와 같은 수식은 student의 클래스 상관관계가 Teacher를 닮도록 강요함
- Class Correlation(CC) Loss는 결국 아래와 같이 표현하게 됨. 이는 클래스 상관관계 차이를 계산함
$$ \mathcal{L}_{CC}(S, T) = \frac{1}{C^2}\left\lVert\mathcal{B}(Z_S) - \mathcal{B}(Z_T)\right\lVert_2^2 $$
- Class Correlation Loss까지 표현했을때, 결국 총 Loss function은 아래와 같이 표현됨
$$ \mathcal{L} = \lambda\mathcal{L}{CE} + \mu\mathcal{L}{KD} + \nu\mathcal{L}_{CC} $$
- 이때, 붙는 상수들은 총 합이 1이 되며 이는 각 term에 대한 importance를 제어하기 위함임.
'논문 리뷰' 카테고리의 다른 글
Distilling Knowledge via Knowledge Review 논문 리뷰 (2) | 2024.02.29 |
---|---|
Decoupled Knowledge Distillation 논문 리뷰 (0) | 2024.02.23 |
GCN paper review (0) | 2023.06.02 |
GCN paper review 전 공부 - semi-supervised learning (0) | 2023.05.30 |
MobileNet paper review (0) | 2023.02.08 |