# Description
# Codes for "Inflate and Shrink: Enriching and Reducing Interactions for Fast
# Text-Image Retrieval".
#
# Implementation of key methods
# > APIs are provided by [PaddlePaddle](https://www.paddlepaddle.org.cn/)

import paddle


class EncoderText(paddle.nn.Layer):
    """Text/vision encoder with max-sum fragment scoring and distillation.

    Text tokens are embedded with ``self.encoder.embeddings``; vision features
    are projected from 2048 to 768 dimensions and layer-normalised.  Both
    sequences are then run through a BERT encoder stack, L2-normalised along
    the feature axis and truncated to a mode-dependent number of fragments.
    The score of a (text, vision) pair is, for every text fragment, the
    maximum dot product over the vision fragments, summed over text fragments.

    Modes (``opt.MODE``):

    * ``"base"``  -- ``BASE_LENGTH`` fragments per side.
    * ``"fast"``  -- a single fragment per side.
    * ``"inflate"`` -- ``CODE_LENGTH`` learned code embeddings are appended to
      both sequences; ``BASE_LENGTH + CODE_LENGTH`` fragments are kept.
    * ``"shrinktobase"`` -- teacher: ``self.encoder`` on the inflated inputs;
      student: ``self.encoder2`` on the plain inputs.
    * ``"shrinktofast"`` -- teacher: ``self.encoder2``; student:
      ``self.encoder3`` keeping a single fragment.

    NOTE(review): ``Bert`` is neither defined nor imported in this snippet;
    presumably a project-local BERT wrapper exposing ``from_pretrained``,
    ``.embeddings`` and ``.encoder`` -- confirm against the full project.
    """

    # There are five optional modes.
    _MODES = ('fast', 'base', 'inflate', 'shrinktobase', 'shrinktofast')

    def __init__(self, opt):
        """Build the sub-layers required by ``opt.MODE``.

        Args:
            opt: configuration object; ``opt.MODE`` (one of ``_MODES``) and
                ``opt.margin`` (ranking-loss margin) are read.

        Raises:
            ValueError: if ``opt.MODE`` is not one of the supported modes.
        """
        super(EncoderText, self).__init__()
        if opt.MODE not in self._MODES:
            # Fail early instead of hitting an unbound local inside forward().
            raise ValueError(
                'Unknown MODE {!r}; expected one of {}'.format(opt.MODE, self._MODES))

        model_path = 'bert/'
        self.MODE = opt.MODE
        self.MARGIN = opt.margin
        self.BASE_LENGTH = 32

        # Layer objects (not the functional static-graph ops) so that they can
        # be called on tensors in forward() and own their parameters.
        self.fc = paddle.nn.Linear(2048, 768)
        self.norm = paddle.nn.LayerNorm(768, epsilon=1e-5)
        self.encoder = Bert.from_pretrained(model_path)

        if self.MODE in ('inflate', 'shrinktobase'):
            self.CODE_LENGTH = 32
            self.code_vision_embeddings = paddle.nn.Embedding(self.CODE_LENGTH, 768)
            self.code_text_embeddings = paddle.nn.Embedding(self.CODE_LENGTH, 768)

        if self.MODE == 'shrinktobase':
            self.encoder2 = Bert.from_pretrained(model_path)
            # The teacher sees the inflated sequence: 32 base + CODE_LENGTH codes.
            self.TEACHER_LENGTH = 32 + self.CODE_LENGTH
            self.STUDENT_LENGTH = 32
            # "TEATURE" (sic) is the teacher temperature; the attribute name is
            # kept unchanged because external code may read it.
            self.TEATURE_TEMPERATURE = 16
            self.STUDENT_TEMPERATURE = 16
        elif self.MODE == 'shrinktofast':
            self.encoder2 = Bert.from_pretrained(model_path)
            self.encoder3 = Bert.from_pretrained(model_path)
            self.TEACHER_LENGTH = 32
            self.STUDENT_LENGTH = 1
            self.TEATURE_TEMPERATURE = 2
            self.STUDENT_TEMPERATURE = 16

    def forward(self, input_ids, token_type_ids, text_mask, vision_feat, vision_mask, IS_TEST=False):
        """Encode a batch and return a training loss or test-time scores.

        Args:
            input_ids: token ids passed to ``self.encoder.embeddings``.
            token_type_ids: segment ids; cast to int64 and squeezed.
            text_mask: text attention mask (1 = keep, 0 = pad); concatenated
                along axis 1 with the code mask, so presumably ``[BS, L_text]``
                -- TODO confirm.
            vision_feat: vision features whose last dimension is 2048 (input
                size of ``self.fc``).
            vision_mask: vision attention mask, same convention as
                ``text_mask``.
            IS_TEST: when false, score every text against every vision item of
                the batch and return a scalar loss; when true, score only the
                aligned pairs and return the scores.

        Returns:
            Training: ``retrieval_loss`` (base / fast / inflate) or
            ``distillation_loss`` (shrink modes).
            Test: per-pair scores of shape ``[BS]`` -- from the student in the
            shrink modes, otherwise from the single (teacher) encoder.
        """
        text_feat = self.encoder.embeddings(
            input_ids=input_ids,
            position_ids=None,
            token_type_ids=paddle.squeeze(paddle.cast(token_type_ids, 'int64')),
        )
        head_mask = [None] * 20
        vision_feat = self.norm(self.fc(vision_feat))
        BS = text_feat.shape[0]

        student = None
        student_length = None

        if self.MODE in ('base', 'fast', 'inflate'):
            # Encoding for base, fast, inflated model
            if self.MODE == 'base':
                teacher_length = self.BASE_LENGTH
            elif self.MODE == 'fast':
                teacher_length = 1
            else:
                teacher_length = self.BASE_LENGTH + self.CODE_LENGTH
                text_feat, text_mask, vision_feat, vision_mask = self._append_codes(
                    text_feat, text_mask, vision_feat, vision_mask, BS)
            teacher = self._encode(
                self.encoder, text_feat, text_mask, vision_feat, vision_mask,
                head_mask, teacher_length)
        elif self.MODE == 'shrinktobase':
            # Encoding for shrinking: inflated teacher, plain student.
            text_feat_plus, text_mask_plus, vision_feat_plus, vision_mask_plus = \
                self._append_codes(text_feat, text_mask, vision_feat, vision_mask, BS)
            teacher_length = self.TEACHER_LENGTH
            teacher = self._encode(
                self.encoder, text_feat_plus, text_mask_plus, vision_feat_plus,
                vision_mask_plus, head_mask, teacher_length)
            student_length = self.STUDENT_LENGTH
            student = self._encode(
                self.encoder2, text_feat, text_mask, vision_feat, vision_mask,
                head_mask, student_length)
        else:
            # Encoding for shrinking ('shrinktofast'): both see the plain inputs.
            teacher_length = self.TEACHER_LENGTH
            teacher = self._encode(
                self.encoder2, text_feat, text_mask, vision_feat, vision_mask,
                head_mask, teacher_length)
            student_length = self.STUDENT_LENGTH
            student = self._encode(
                self.encoder3, text_feat, text_mask, vision_feat, vision_mask,
                head_mask, student_length)

        # Scoring and computing loss
        pairwise = not IS_TEST
        if IS_TEST and student is not None:
            # At test time only the student's scores are returned.
            return self._max_sum_scores(student[0], student[1], BS, student_length, pairwise)

        scores_teacher = self._max_sum_scores(
            teacher[0], teacher[1], BS, teacher_length, pairwise)
        if IS_TEST:
            return scores_teacher
        if student is None:
            return self.retrieval_loss(scores_teacher, self.MARGIN, BS)
        scores_student = self._max_sum_scores(
            student[0], student[1], BS, student_length, pairwise)
        return self.distillation_loss(scores_teacher, scores_student)

    def _append_codes(self, text_feat, text_mask, vision_feat, vision_mask, BS):
        """Append the normalised code embeddings (and all-ones masks) on axis 1."""
        code_ids = paddle.expand(
            paddle.unsqueeze(paddle.arange(self.CODE_LENGTH, dtype='int64'), 0),
            [BS, self.CODE_LENGTH])
        code_mask = paddle.ones([BS, self.CODE_LENGTH], dtype='int64')
        code_vision_emb = self.norm(self.code_vision_embeddings(code_ids))
        code_text_emb = self.norm(self.code_text_embeddings(code_ids))

        text_mask = paddle.concat([paddle.cast(text_mask, 'int64'), code_mask], axis=1)
        vision_mask = paddle.concat([paddle.cast(vision_mask, 'int64'), code_mask], axis=1)
        vision_feat = paddle.concat([vision_feat, code_vision_emb], axis=1)
        text_feat = paddle.concat([text_feat, code_text_emb], axis=1)
        return text_feat, text_mask, vision_feat, vision_mask

    @staticmethod
    def _extended_attention_mask(mask):
        """Turn a 1/0 keep-mask into an additive mask (0 for keep, -10000 for pad)."""
        # NOTE(review): squeeze() without an axis also removes a batch
        # dimension of size 1 -- confirm callers never pass BS == 1.
        mask = paddle.cast(paddle.squeeze(mask), 'float32')[:, None, None, :]
        return (1.0 - mask) * -10000.0

    def _encode(self, bert, text_feat, text_mask, vision_feat, vision_mask, head_mask, frag_length):
        """Run both sequences through ``bert.encoder``; L2-normalise and truncate."""
        text_out = bert.encoder(
            text_feat, self._extended_attention_mask(text_mask), head_mask)[0]
        vision_out = bert.encoder(
            vision_feat, self._extended_attention_mask(vision_mask), head_mask)[0]
        vision_out = paddle.nn.functional.normalize(vision_out, p=2, axis=2)
        text_out = paddle.nn.functional.normalize(text_out, p=2, axis=2)
        return text_out[:, :frag_length], vision_out[:, :frag_length]

    @staticmethod
    def _max_sum_scores(text_feature, vision_feature, BS, frag_length, pairwise):
        """Max over vision fragments, summed over text fragments.

        With ``pairwise`` the result is ``[BS, BS]`` where entry ``[i, j]``
        scores text ``j`` against vision ``i``; otherwise it is ``[BS]`` for
        the aligned pairs only.
        """
        if pairwise:
            # Tile so that row i * BS + j holds (text j, vision i).
            text_feature = paddle.reshape(
                paddle.expand(paddle.unsqueeze(text_feature, 0), [BS, -1, -1, -1]),
                [BS * BS, frag_length, -1])
            vision_feature = paddle.reshape(
                paddle.expand(paddle.unsqueeze(vision_feature, 1), [-1, BS, -1, -1]),
                [BS * BS, frag_length, -1])
            out_shape = [BS, BS, frag_length, frag_length]
        else:
            out_shape = [BS, frag_length, frag_length]

        similarity = paddle.reshape(
            paddle.bmm(text_feature, paddle.transpose(vision_feature, [0, 2, 1])),
            out_shape)
        # paddle.max returns only the values tensor (no indices), so it must
        # not be indexed with [0].
        return paddle.sum(paddle.max(similarity, axis=-1), axis=-1)

    def retrieval_loss(self, scores, margin, BS):
        """Bidirectional margin ranking loss over a ``BS x BS`` score matrix.

        Hinge costs ``max(0, margin + s[i, j] - s[i, i])`` and
        ``max(0, margin + s[i, j] - s[j, j])`` are aggregated per row and per
        column with an 8-norm (8th power, sum, three square roots) and summed.
        The diagonal is not masked, so every row/column also contains one
        constant ``margin`` term.

        Args:
            scores: pairwise scores, reshaped to ``[BS, BS]``.
            margin: ranking margin.
            BS: batch size.

        Returns:
            Scalar loss tensor.
        """
        scores_c = paddle.reshape(scores, [BS, BS])
        diagonal = paddle.reshape(paddle.diag(scores_c), [BS, 1])
        d1 = paddle.expand_as(diagonal, scores_c)
        d2 = paddle.expand_as(paddle.transpose(diagonal, [1, 0]), scores_c)

        cost_s = paddle.clip(margin + scores_c - d1, min=0)
        cost_im = paddle.clip(margin + scores_c - d2, min=0)

        eps = 1e-5  # keeps the root differentiable when a row/column sums to 0
        cost_s = paddle.sum(paddle.pow(cost_s, 8), axis=1) + eps
        cost_im = paddle.sum(paddle.pow(cost_im, 8), axis=0) + eps
        for _ in range(3):  # three square roots == 8th root
            cost_s = paddle.sqrt(cost_s)
            cost_im = paddle.sqrt(cost_im)
        return paddle.sum(cost_s) + paddle.sum(cost_im)

    def distillation_loss(self, scores_gt, scores_pred):
        """Cross-entropy between teacher and student score distributions.

        Teacher and student scores are multiplied by their respective
        temperature attributes and soft-maxed along axis 1 and along axis 0;
        the two cross-entropies (teacher detached) are averaged.

        Args:
            scores_gt: teacher score matrix.
            scores_pred: student score matrix.

        Returns:
            Scalar loss tensor.
        """
        teacher_logits = scores_gt * self.TEATURE_TEMPERATURE
        student_logits = scores_pred * self.STUDENT_TEMPERATURE

        loss = 0.0
        for axis in (1, 0):
            target = paddle.nn.functional.softmax(teacher_logits, axis=axis).detach()
            # log_softmax is the numerically stable form of log(softmax(x)).
            log_pred = paddle.nn.functional.log_softmax(student_logits, axis=axis)
            loss = loss - 0.5 * paddle.sum(target * log_pred)
        return loss