논문 리뷰

Class-aware Information for Logit-based Knowledge Distillation 논문 리뷰

HyunMaru 2024. 3. 8. 02:43

논문 제목 : Class-aware Information for Logit-based Knowledge Distillation

컨퍼런스 : ??

저자 : Shuoxi Zhang, Hanpeng Liu

대학 : School of Computer Science and Technology Wuhan

초록

  • 지금까지의 logit-based distillation은 instance level에서 다루었다면, 논문은 다른 의미적인 정보들을 간과했던 것들을 관찰해보려고함
  • 논문은 간과점 문제를 다루기 위해 Class aware Logit KD(CLKD)를 제안함. 이는 instance-level과 class-level을 동시에 logit distillation하기 위한 모듈임
  • CLKD는 distillation performance를 높이면서 Teacher로부터의 상당한 의미적인 정보를 흉내낼 수 있음
  • 또한, 논문은 Class Correlation Loss라는 것을 제안하여 Teacher의 내재된 class-level 상관관계를 Student에 주입시키도록 할 것임

서론

  • KD의 접근으로 logit-based / feature-based로 쪼개져나옴
  • 이때 핵심 아이디어는 vanilla KD를 기반으로 다져지는데, 이는 logit-based 에 속함.
    • soft label 확률 분포는 ground-truth 레이블 단독으로 쓰는 것보다 많은 정보를 포함하고 있음
    • 이는 student model이 distillation 이후에도 성능이 유지되게끔 도움을 줌.
  • 현존하는 연구들(23, 41)은 logit 기반의 distillation의 약함을 위해 logit으로부터의 정보 추출의 개선방안들을 제안함
    • 23 : Seyed Iman Mirzadeh, Mehrdad Farajtabar, Ang Li, Nir Levine, Akihiro Matsukawa, and Hassan Ghasemzadeh. Improved knowledge distillation via teacher assistant. In AAAI, pages 5191–5198, 2020.
    • 41 : Borui Zhao, Quan Cui, Renjie Song, Yiyu Qiu, and Jiajun Liang. Decoupled knowledge distillation. In CVPR, pages 11953–11962, 2022.
  • 본 논문에서는 위와 같은 현존하는 연구들은 오직 instance-level 정보만 사용함을 인지함
  • 그 instance-level은 특정 이미지에 대한 확률 분포들이 나오게 되는데, 이를 inter-instance knowledge 라고 부른다고 함
  • 위의 고찰을 토대로, 논문은 간단하나 효과적인 logit distillation method를 제안. 이는 class-aware logit KD(CLKD)라고 부르며, inter-instance와 intra-instance knowledge를 짬뽕시켰음
  • 그리고 Class representation에 의해 클래스 상관관계 정보가 교정이 됨
  • CLKD 그림
    • inter-instance knowledge 추출을 위해 logit matrix를 역전시킴
      • 이는 intra-instance 정보의 보완적인 내용(?) 이라카네
    • class correlation loss를 통해 Teacher의 class correlation을 Student가 따라할 수 있도록함
    • 근데 그림을 보면, Teacher와 Student의 예측 확률의 차이 폭이 너무 큼. 따라서, 논문은 그 차이 폭을 줄이기 위해 normalized metric 디자인함
  • 기존 logit-based 관련 연구들 중 Chen 저자의 논문은 teacher의 분류기를 도입해서 Student와 Teacher간 표현력 차이를 좁히려고 했으나, 이는 Classifier 사용에 대한 비용 및 지식 증류의 원칙을 위배하는 사항이 됨에 따라 그렇게 효과적이진 않음
  • 본 논문의 핵심 아이디어는 이전 연구인 Contrastive Clustering(CC)와 연관있다고 함. CC는 cluster 표현으로 소개되었으며, mode collapse를 피하기 위한 instance-wise contrastive learning이다.

Proposed Method

본론


  • vanilla kd에 대해 소개를 한 뒤, Teacher Knowledge를 효과적으로 파악하기 위해, 출력 행렬을 역전(transpose) 시켜 class-instance 의존성을 포착하려고 함.
  • Teacher와 Student의 출력 표현의 크기차를 위해 Normalized Mean-Sqaured Error(NMSE)를 제안함.
  • Correlation Congruence의 아이디어를 따와서 class correlation information을 파악하기 위한 전략을 제안함.
  • 논문은 또한, feature-based distillation에 적용할 수 있었다고 함

Vanilla Knowledge Distillation

  • vanilla KD는 soft logit 확률 분포를 전이시키는 과정으로 loss를 아래와 같이 표현함

$$ \mathcal{L} = (1-\alpha)\mathcal{L}{CE}(\sigma(z_S, y))+ \alpha\mathcal{L}{KD}(z_S, z_T), $$

  • vanilla KD는 temperature-scaled KL divergence로 softhend categorical probability 분포를 흉내내게끔 하기 위해 아래와 같이 KD loss를 표현함

$$ \mathcal{L}{KD} = \gamma^2\mathcal{L}{KL}(\sigma(z_S/\gamma), \sigma(z_T/\gamma)) $$

  • temperature 상수가 1이면, 그냥 softmax가 됨. 상수가 커지면 커질수록 더 분포도는 softer해짐.
  • vanilla KD는 하나의 이미지에 대한 Teacher의 예측은 instance-level class 유사성 지식을 담고 있다고 함
  • 본 논문에서는 vanilla KD가 오직 instance level만 이용하고 있음에 따라 완전히 이용해먹을 수 있다는 것을 알아챔
  • 이를 위해 논문은 2가지 관점으로부터 output 행렬을 학습하는 목표를 세움
    1. instance-wise 확률 분포
    2. 더 높은 의미적인 수준에서의 지식 전이를 위한 범주형적인 표현

Motivation

  • instance(data)의 mini-batch B 가 주어졌을때 logit matrix B x C ( B : 배치 데이터 개수, C : 범주형 개수) 가 만들어진다고 하자, 여기서 행 벡터는 instance-level의 범주형 확률 예측값이다. vanilla KD는 Teacher의 행에 대한 logit matrix를 통해 범주형 확률 정보를 포착한다.
  • 하지만, logit matrix의 열벡터는 instance 사이의 class-level 정보를 포함하고 있음.
  • 현존하는 logit 방법론들은 하나의 이미지에 대한 관계성을 표현하는 반면, 다른 이미지 사이의 의미적인 의존성(semantic dependency)를 간과함
  • 이를 기반으로 논문은 instance-wise level을 넘어서 logit matrix를 이용하는 새로운 KD 방법을 제안함
  • 특히, 논문은 instance-wise 정보와 class-wise 정보를 이용해서 Teacher의 지식 정보를 더 포괄적이게 하고자함

Class Distillation

  • mini-batch 데이터가 주어졌을 때 softmax 전의 logit matrix를 따로 저장함. 그런 다음 각 logit matrix의 열 벡터는 클래스에 일치하는 표현으로 간주함.
  • 각 열 벡터는 instance와 class 사이의 유사성을 표현하며, 이를 class-instance similiarity knowledge 라고함
  • 이를 구현하기 위해, matrix transposition을 수행하며, 이를 통해 inter-instance 정보는 Student가 Teacher의 비슷한 클래스 표현을 생성할 수 있도록 강요받게끔 전이된다.
  • 이때의 Loss Function은 아래와 같이 표현됨

$$ \mathcal{L}{KD} = \mathcal{L}{ins} + \beta\mathcal{L}_{cla}, $$

$$ \mathcal{L}{ins} = \mathcal{L}{NMSE}(Z_s, Z_t), $$

$$ \mathcal{L}{cla} = \mathcal{L}{NMSE}(norm^T(Z_s), norm^T(Z_T)), $$

$$ \mathcal{L}_{NMSE}(p, z) = \left\lVert \frac{{p}}{\left\lVert p \right\rVert_2} - \frac{{z}}{\left\lVert z \right\rVert_2} \right\rVert^2_2$$

  • L_ins는 instance-wise distillation loss이고, L_cla는 논문에서 제안하는 class-wise distillation loss임. 계수 beta는 두 loss 간 trade-off임
  • p와 z는 서로 다른 encodings, distribution을 이야기하는 건데, 여기서는 teacher와 student의 logit을 이야기하는 것임.
  • 다음으로 transposition 하기전의 normalization(L2 norm)은 서로 다른 행의 크기 차에 의한 클래스 표현력 손상 완화를 위해 사용함
  • Kim 저자는 logit matching 전략으로 logit 사이의 MSE를 활용했는데 이는 KL loss에 비해 좋았다고 함
    • 다만, 불가피한 모델 크기에 의한 차이 및 Student와 Teacher 사이의 파라미터 수, 출력의 norm은 비슷한 크기를 공유하기에는 어려움이 있으며, 결국 간단한 logit matching에 국한되어 사용됨
  • 따라서, MSE를 계산하기 전에 logit을 정규화함

Class Correlation Loss

  • logit transposition으로 class representation을 인코딩했는데, 논문에서의 class representation은 다른 클래스 사이의 상관관계를 담고 있음.
  • 그럼 이를 어떻게 Student의 클래스 상관관계를 학습시킬 수 있을까?
  • Correlation Congruence에 영감을 받아 instance correlation 포착 방법론을 제안함.

Class Correlation Matrix

$$ \mathcal{B}(Z) = \frac{1}{C-1}\sum^C_{j=1} (Z._j-\bar{Z})^T (Z._j-\bar{Z}), $$

  • 위 수식에서 \bar(Z)는 colum벡터의 평균을 의미한다. 그리고 Z.j는 j번째 column 벡터를 의미한다. 위와 같은 수식은 student의 클래스 상관관계가 Teacher를 닮도록 강요함
  • Class Correlation(CC) Loss는 결국 아래와 같이 표현하게 됨. 이는 클래스 상관관계 차이를 계산함

$$ \mathcal{L}_{CC}(S, T) = \frac{1}{C^2}\left\lVert\mathcal{B}(Z_S) - \mathcal{B}(Z_T)\right\lVert_2^2 $$

  • Class Correlation Loss까지 표현했을때, 결국 총 Loss function은 아래와 같이 표현됨

$$ \mathcal{L} = \lambda\mathcal{L}{CE} + \mu\mathcal{L}{KD} + \nu\mathcal{L}_{CC} $$

  • 이때, 붙는 상수들은 총 합이 1이 되며 이는 각 term에 대한 importance를 제어하기 위함임.