KNOT: Knowledge Distillation Using Optimal Transport for Solving NLP Tasks

Rishabh Bhardwaj, Tushar Vaidya, Soujanya Poria


Abstract
We propose a new approach, Knowledge Distillation using Optimal Transport (KNOT), to distill natural language semantic knowledge from multiple teacher networks to a student network. KNOT aims to train a (global) student model by learning to minimize the optimal transport cost of its assigned probability distribution over the labels to the weighted sum of probabilities predicted by the (local) teacher models, under the constraint that the student model has no access to the teacher models' parameters or training data. To evaluate the quality of knowledge transfer, we introduce a new metric, Semantic Distance (SD), that measures semantic closeness between the predicted and ground truth label distributions. The proposed method shows improvements in the global model's SD performance over the baseline across three NLP tasks while performing on par with entropy-based distillation on standard accuracy and F1 metrics. The implementation of this work is publicly available at https://github.com/declare-lab/KNOT.
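As a minimal sketch of the kind of objective the abstract describes, the Python/NumPy code below computes an entropy-regularised optimal transport cost (Sinkhorn iterations) from a student's label distribution to the weighted sum of several teachers' label distributions. The Sinkhorn solver, the regularisation strength `eps`, the 1-D label embeddings used to build the ground cost `C`, the example labels, and the teacher weights are all assumptions for illustration. The function names `sinkhorn_cost`, `aggregate_teachers`, and `knot_style_loss` are hypothetical and are not taken from the declare-lab/KNOT repository, whose actual cost matrix, solver, and weighting scheme may differ.

```python
import numpy as np


def sinkhorn_cost(p, q, cost, eps=0.1, n_iters=500):
    """Entropy-regularised OT cost <P, C> between label histograms p and q."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    K = np.exp(-cost / eps)               # Gibbs kernel of the ground cost
    u = np.ones_like(p)
    v = np.ones_like(q)
    for _ in range(n_iters):              # alternate scalings to match both marginals
        u = p / (K @ v)
        v = q / (K.T @ u)
    plan = u[:, None] * K * v[None, :]    # transport plan: row sums ~p, column sums ~q
    return float(np.sum(plan * cost))


def aggregate_teachers(teacher_probs, teacher_weights):
    """Weighted sum of teacher label distributions (weights renormalised to sum to 1)."""
    w = np.asarray(teacher_weights, dtype=float)
    w = w / w.sum()
    return w @ np.asarray(teacher_probs, dtype=float)


def knot_style_loss(student_probs, teacher_probs, teacher_weights, cost):
    """OT cost from the student's distribution to the aggregated teacher distribution."""
    target = aggregate_teachers(teacher_probs, teacher_weights)
    return sinkhorn_cost(student_probs, target, cost)


# Hypothetical 1-D label embeddings for three emotion labels: joy, surprise, anger.
label_emb = np.array([0.0, 0.2, 1.0])
C = np.abs(label_emb[:, None] - label_emb[None, :])  # ground cost between labels

teachers = [[0.7, 0.2, 0.1],   # teacher 1 (a local model's prediction)
            [0.5, 0.4, 0.1]]   # teacher 2 (a local model's prediction)
weights = [0.6, 0.4]           # hypothetical per-teacher weights

close = knot_style_loss([0.6, 0.3, 0.1], teachers, weights, C)
far = knot_style_loss([0.1, 0.1, 0.8], teachers, weights, C)
print(close < far)  # True: mass placed on "anger" must travel farther in label space
```

The design choice that matters here is the ground cost: unlike cross-entropy or KL divergence, which score each label independently, an OT cost charges less for putting probability on a semantically nearby label (here "surprise" rather than "anger" when the teachers favour "joy"), which is in the spirit of the semantic closeness that the abstract's SD metric is meant to capture.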
Anthology ID:
2022.coling-1.425
Volume:
Proceedings of the 29th International Conference on Computational Linguistics
Month:
October
Year:
2022
Address:
Gyeongju, Republic of Korea
Venue:
COLING
Publisher:
International Committee on Computational Linguistics
Pages:
4801–4820
URL:
https://aclanthology.org/2022.coling-1.425
Cite (ACL):
Rishabh Bhardwaj, Tushar Vaidya, and Soujanya Poria. 2022. KNOT: Knowledge Distillation Using Optimal Transport for Solving NLP Tasks. In Proceedings of the 29th International Conference on Computational Linguistics, pages 4801–4820, Gyeongju, Republic of Korea. International Committee on Computational Linguistics.
Cite (Informal):
KNOT: Knowledge Distillation Using Optimal Transport for Solving NLP Tasks (Bhardwaj et al., COLING 2022)
PDF:
https://aclanthology.org/2022.coling-1.425.pdf
Code
declare-lab/knot
Data
MELD, SNLI