Design and Implementation Analysis of BLIP Multimodal Pretraining Model
Most existing approaches fall into one of two categories: encoder-only or encoder-decoder architectures. Encoder-only models struggle with generation tasks such as image captioning, while encoder-decoder variants have not transferred well to image-text retrieval. On the data side, prevailing methods such as CLIP rely on large-scale, noisy web-mined image-text pairs, where irrelevant captions lead to suboptimal supervision.
Architecture
BLIP introduces a Multimodal mixture of Encoder-Decoder (MED), a unified architecture whose modules can be combined in three modes to handle both understanding and generation objectives (a minimal sketch follows the list below).
- Unimodal encoders: Separate encoders process images and text. The text encoder follows BERT, prepending a [CLS] token that summarizes sentence-level semantics; the image encoder uses ViT, with its [CLS] token representing global visual context.
- Image-conditioned text encoder: Inserts a cross-attention (CA) layer between the self-attention (SA) and feed-forward layers of each transformer block, injecting visual signals. A special [Encode] token is appended to the text input; its output embedding serves as the multimodal representation of the image-text pair.
- Image-conditioned text decoder: Replaces the bidirectional SA with causal SA to enable autoregressive generation. A [Decode] token marks the start of a sequence, and an end-of-sequence token terminates decoding.
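The key point is that the three functionalities share most of their parameters: the text encoder and decoder share everything except the SA layers, so one set of weights serves all pretraining objectives. Below is a minimal sketch of how the modes could be dispatched; the module and argument names are illustrative, not BLIP's actual API.

import torch
import torch.nn as nn

class MED(nn.Module):
    """Illustrative skeleton of the mixture of encoder-decoder (names are assumptions)."""
    def __init__(self, vision_encoder, text_encoder, text_decoder):
        super().__init__()
        self.vision_encoder = vision_encoder  # ViT; returns [B, N, D] patch features
        self.text_encoder = text_encoder      # BERT-style, with optional cross-attention layers
        self.text_decoder = text_decoder      # causal-SA variant sharing CA/FFN weights with the encoder

    def forward(self, image, text_ids, text_mask, mode):
        if mode == 'text':
            # unimodal text encoder: cross-attention is bypassed
            return self.text_encoder(text_ids, attention_mask=text_mask, mode='text')
        img_emb = self.vision_encoder(image)
        img_mask = torch.ones(img_emb.shape[:-1], dtype=torch.long, device=image.device)
        if mode == 'image':
            return img_emb
        if mode == 'multimodal':
            # image-conditioned text encoder: bidirectional SA + CA over patches ([Encode] token)
            return self.text_encoder(text_ids, attention_mask=text_mask,
                                     encoder_hidden_states=img_emb,
                                     encoder_attention_mask=img_mask)
        # image-conditioned text decoder: causal SA + CA ([Decode] token)
        return self.text_decoder(text_ids, attention_mask=text_mask,
                                 encoder_hidden_states=img_emb,
                                 encoder_attention_mask=img_mask)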
Training Objectives
Image-Text Contrastive Loss (ITC)
ITC learns discriminative unimodal embeddings before fusion. It aligns matched image-text pairs and pushes apart mismatched ones via contrastive learning, with a momentum encoder and feature queues providing softened (soft-label) targets.
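Concretely, with online similarities s, momentum similarities s_m (both already scaled by the temperature), one-hot targets y, and smoothing weight α (notation follows the code below), the soft targets and loss are:

$$
q^{i2t} = \alpha\,\mathrm{softmax}\!\left(s_m^{i2t}\right) + (1-\alpha)\,y,\qquad
\mathcal{L}_{\mathrm{ITC}} = \tfrac{1}{2}\Big[H\big(q^{i2t},\,\mathrm{softmax}(s^{i2t})\big) + H\big(q^{t2i},\,\mathrm{softmax}(s^{t2i})\big)\Big]
$$

where $H(q, p) = -\sum_k q_k \log p_k$ is the soft cross-entropy.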
# assumes: import torch; import torch.nn.functional as F; the encoders, projection heads,
# their momentum copies (suffix _m), the feature queues, temp_val and alpha are defined elsewhere
with torch.no_grad():
    temp_val.data.clamp_(0.001, 0.5)  # keep the learnable temperature in a stable range

# online (gradient-carrying) features
img_vis = img_encoder(img_tensor)  # [B, N, D] patch features
img_mask = torch.ones(img_vis.shape[:-1], dtype=torch.long, device=img_tensor.device)
proj_img = F.normalize(img_proj(img_vis[:, 0, :]), dim=-1)  # [B, d] projected [CLS] feature

txt_data = txt_tokenizer(caption, padding='max_length', truncation=True,
                         max_length=30, return_tensors="pt").to(img_tensor.device)
txt_enc = txt_encoder(txt_data.input_ids, attention_mask=txt_data.attention_mask,
                      return_dict=True, mode='text')
proj_txt = F.normalize(txt_proj(txt_enc.last_hidden_state[:, 0, :]), dim=-1)

# momentum features and soft targets (no gradients)
with torch.no_grad():
    _update_momentum()
    img_vis_m = img_encoder_m(img_tensor)
    proj_img_m = F.normalize(img_proj_m(img_vis_m[:, 0, :]), dim=-1)
    all_img_feat = torch.cat([proj_img_m.t(), stored_img_queue.detach()], dim=1)

    txt_enc_m = txt_encoder_m(txt_data.input_ids, attention_mask=txt_data.attention_mask,
                              return_dict=True, mode='text')
    proj_txt_m = F.normalize(txt_proj_m(txt_enc_m.last_hidden_state[:, 0, :]), dim=-1)
    all_txt_feat = torch.cat([proj_txt_m.t(), stored_txt_queue.detach()], dim=1)

    sim_img2txt_m = proj_img_m @ all_txt_feat / temp_val
    sim_txt2img_m = proj_txt_m @ all_img_feat / temp_val

    target_mat = torch.zeros_like(sim_img2txt_m)
    target_mat.fill_diagonal_(1)  # one-hot targets: the diagonal pairs are the true matches

    tgt_img2txt = alpha * F.softmax(sim_img2txt_m, dim=1) + (1 - alpha) * target_mat
    tgt_txt2img = alpha * F.softmax(sim_txt2img_m, dim=1) + (1 - alpha) * target_mat

# soft cross-entropy between online similarities and the smoothed targets
sim_img2txt = proj_img @ all_txt_feat / temp_val
loss_i2t = -torch.sum(F.log_softmax(sim_img2txt, dim=1) * tgt_img2txt, dim=1).mean()
sim_txt2img = proj_txt @ all_img_feat / temp_val
loss_t2i = -torch.sum(F.log_softmax(sim_txt2img, dim=1) * tgt_txt2img, dim=1).mean()
loss_itc = (loss_i2t + loss_t2i) / 2

_enqueue_features(proj_img_m, proj_txt_m)  # push momentum features into the queues
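The two helpers referenced above, _update_momentum and _enqueue_features, are not shown. A minimal sketch of what they typically do under a MoCo-style setup (EMA of the online weights and fixed-size feature queues); the queue buffers and pointer names here are assumptions:

@torch.no_grad()
def _update_momentum(momentum=0.995):
    # exponential moving average of the online parameters into the momentum copies
    for online, target in [(img_encoder, img_encoder_m), (img_proj, img_proj_m),
                           (txt_encoder, txt_encoder_m), (txt_proj, txt_proj_m)]:
        for p, p_m in zip(online.parameters(), target.parameters()):
            p_m.data = p_m.data * momentum + p.data * (1.0 - momentum)

@torch.no_grad()
def _enqueue_features(img_feat_m, txt_feat_m):
    # overwrite the oldest slots of the [d, K] queues with the newest momentum features;
    # assumes the queue size K is a multiple of the batch size
    global queue_ptr
    bs = img_feat_m.size(0)
    stored_img_queue[:, queue_ptr:queue_ptr + bs] = img_feat_m.t()
    stored_txt_queue[:, queue_ptr:queue_ptr + bs] = txt_feat_m.t()
    queue_ptr = (queue_ptr + bs) % stored_img_queue.size(1)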
Image-Text Matching Loss (ITM)
ITM trains a binary classifier over multimodal embeddings to predict whether a pair matches. Hard negatives are mined from ITC similarity distributions.
# positive pairs: [Encode] token + cross-attention over the matching image
etxt_ids = txt_data.input_ids.clone()
etxt_ids[:, 0] = txt_tokenizer.enc_token_id
pos_out = txt_encoder(etxt_ids, attention_mask=txt_data.attention_mask,
                      encoder_hidden_states=img_vis, encoder_attention_mask=img_mask,
                      return_dict=True)

# hard negative mining: sample negatives in proportion to the in-batch ITC similarities
bs = img_tensor.size(0)
with torch.no_grad():
    w_txt2img = F.softmax(sim_txt2img[:, :bs], dim=1) + 1e-4
    w_txt2img.fill_diagonal_(0)  # never sample the true match as a negative
    w_img2txt = F.softmax(sim_img2txt[:, :bs], dim=1) + 1e-4
    w_img2txt.fill_diagonal_(0)

# one hard negative image for each text
neg_img_list = [img_vis[torch.multinomial(w_txt2img[b], 1).item()] for b in range(bs)]
neg_img_batch = torch.stack(neg_img_list, dim=0)

# one hard negative text for each image
neg_txt_ids, neg_txt_mask = [], []
for b in range(bs):
    idx = torch.multinomial(w_img2txt[b], 1).item()
    neg_txt_ids.append(etxt_ids[idx])
    neg_txt_mask.append(txt_data.attention_mask[idx])
neg_txt_ids = torch.stack(neg_txt_ids, dim=0)
neg_txt_mask = torch.stack(neg_txt_mask, dim=0)

# negative pairs: (positive text, negative image) and (negative text, positive image)
comb_txt_ids = torch.cat([etxt_ids, neg_txt_ids], dim=0)
comb_txt_mask = torch.cat([txt_data.attention_mask, neg_txt_mask], dim=0)
comb_img = torch.cat([neg_img_batch, img_vis], dim=0)
comb_img_mask = torch.cat([img_mask, img_mask], dim=0)
neg_out = txt_encoder(comb_txt_ids, attention_mask=comb_txt_mask,
                      encoder_hidden_states=comb_img, encoder_attention_mask=comb_img_mask,
                      return_dict=True)

# binary classification over the fused [Encode] embeddings: bs positives, 2*bs negatives
cat_emb = torch.cat([pos_out.last_hidden_state[:, 0, :],
                     neg_out.last_hidden_state[:, 0, :]], dim=0)
pred_match = itm_classifier(cat_emb)
gt_label = torch.cat([torch.ones(bs, dtype=torch.long),
                      torch.zeros(2 * bs, dtype=torch.long)], dim=0).to(img_tensor.device)
loss_itm = F.cross_entropy(pred_match, gt_label)
Language Modeling Loss (LM)
LM trains the decoder to generate coherent captions given an image, optimizing autoregressive likelihood via cross-entropy.
# replace [CLS] with the [Decode] (bos) token to mark the start of generation
dec_in = txt_data.input_ids.clone()
dec_in[:, 0] = txt_tokenizer.bos_token_id
# padding positions are ignored in the loss
dec_trg = dec_in.masked_fill(dec_in == txt_tokenizer.pad_token_id, -100)
dec_out = txt_decoder(dec_in, attention_mask=txt_data.attention_mask,
                      encoder_hidden_states=img_vis, encoder_attention_mask=img_mask,
                      labels=dec_trg, return_dict=True)
loss_lm = dec_out.loss  # autoregressive cross-entropy, computed internally from the shifted labels
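During pretraining the three objectives are optimized jointly: each image-text pair passes once through the image encoder and three times through the text side (unimodal encoder, image-conditioned encoder, decoder), and the total loss is simply their sum:

loss = loss_itc + loss_itm + loss_lm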
Downstream Applications
CapFilt
Generates cleaner training data by synthesizing captions for web images (the captioner) and removing mismatched image-text pairs (the filter). Both modules are initialized from the pre-trained MED and fine-tuned separately on COCO: the captioner with the LM objective, the filter with ITC and ITM. The filtered web and synthetic captions, together with the human-annotated pairs, form an improved pretraining corpus.
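A minimal sketch of the bootstrapping loop, reusing the generate_caption and match functions shown later in this section; the 0.5 threshold and the direct use of the ITM probability as the filtering rule are assumptions for illustration:

def capfilt(web_images, web_captions, itm_threshold=0.5):
    # bootstrap a cleaner corpus: caption each web image, then keep only the pairs the filter accepts
    clean_pairs = []
    for img, web_txt in zip(web_images, web_captions):
        img_batch = img.unsqueeze(0)                            # [1, 3, H, W]
        synth_txt = generate_caption(img_batch)[0]              # captioner (fine-tuned with LM)
        for txt in (web_txt, synth_txt):
            p_match = match(img_batch, [txt], head_type='itm')  # filter (fine-tuned with ITC/ITM)
            if p_match.item() > itm_threshold:                  # assumed filtering rule
                clean_pairs.append((img, txt))
    return clean_pairs  # combined with human-annotated pairs for the next round of pretraining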
Feature Extraction
def extract(mode, img_tensor, caption=None):
    # returns image, text, or fused multimodal features depending on `mode`
    if mode == 'image':
        return img_encoder(img_tensor)  # [B, N, D] patch features

    txt_data = txt_tokenizer(caption, return_tensors='pt', padding=True).to(img_tensor.device)
    if mode == 'text':
        txt_out = txt_encoder(txt_data.input_ids, attention_mask=txt_data.attention_mask,
                              return_dict=True, mode='text')
        return txt_out.last_hidden_state

    if mode == 'multimodal':
        img_vis = img_encoder(img_tensor)
        img_mask = torch.ones(img_vis.shape[:-1], dtype=torch.long, device=img_tensor.device)
        cap_ids = txt_data.input_ids.clone()
        cap_ids[:, 0] = txt_tokenizer.enc_token_id  # switch [CLS] to the [Encode] token
        out = txt_encoder(cap_ids, attention_mask=txt_data.attention_mask,
                          encoder_hidden_states=img_vis, encoder_attention_mask=img_mask,
                          return_dict=True)
        return out.last_hidden_state
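A brief usage example, assuming img_tensor is an already-preprocessed image batch of shape [B, 3, H, W] and caption is a list of B strings:

img_feat = extract('image', img_tensor)                       # [B, N, D] patch features
txt_feat = extract('text', img_tensor, caption=caption)       # [B, L, D] token features
mm_feat = extract('multimodal', img_tensor, caption=caption)  # [B, L, D]; [:, 0, :] is the fused [Encode] embedding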
Image-Text Matching Inference
def match(img_tensor, caption, head_type):
    # scores image-text pairs with either the ITM head or raw ITC cosine similarity
    img_vis = img_encoder(img_tensor)
    img_mask = torch.ones(img_vis.shape[:-1], dtype=torch.long, device=img_tensor.device)
    txt_data = txt_tokenizer(caption, padding='max_length', truncation=True,
                             max_length=35, return_tensors='pt').to(img_tensor.device)
    if head_type == 'itm':
        # fused representation through cross-attention, then the binary ITM classifier
        txt_data.input_ids[:, 0] = txt_tokenizer.enc_token_id
        enc_out = txt_encoder(txt_data.input_ids, attention_mask=txt_data.attention_mask,
                              encoder_hidden_states=img_vis, encoder_attention_mask=img_mask,
                              return_dict=True)
        score = itm_classifier(enc_out.last_hidden_state[:, 0, :])
        return torch.softmax(score, dim=1)[:, 1]  # probability that each pair matches
    else:
        # ITC head: cosine similarity between the projected unimodal embeddings
        txt_out = txt_encoder(txt_data.input_ids, attention_mask=txt_data.attention_mask,
                              return_dict=True, mode='text')
        img_feat = F.normalize(img_proj(img_vis[:, 0, :]), dim=-1)
        txt_feat = F.normalize(txt_proj(txt_out.last_hidden_state[:, 0, :]), dim=-1)
        return img_feat @ txt_feat.t()  # [B_img, B_txt] similarity matrix
Image Captioning Generation
def generate_caption(img_tensor, use_sampling=False, num_beams=3, max_len=30, min_len=10, top_p=0.9, rep_penalty=1.0):
    img_vis = img_encoder(img_tensor)
    prompt = 'a picture of'
    prompt_ids = txt_tokenizer([prompt] * img_tensor.size(0), return_tensors='pt').input_ids.to(img_tensor.device)
    prompt_ids[:, 0] = txt_tokenizer.bos_token_id  # start with the [Decode] (bos) token
    prompt_ids = prompt_ids[:, :-1]                # drop the trailing [SEP] so generation continues the prompt
    if use_sampling:
        # nucleus sampling: one sequence per image, so the image features are used as-is
        img_mask = torch.ones(img_vis.shape[:-1], dtype=torch.long, device=img_tensor.device)
        kwargs = {"encoder_hidden_states": img_vis, "encoder_attention_mask": img_mask}
        seq = txt_decoder.generate(input_ids=prompt_ids, max_length=max_len, min_length=min_len,
                                   do_sample=True, top_p=top_p, num_return_sequences=1,
                                   eos_token_id=txt_tokenizer.sep_token_id,
                                   pad_token_id=txt_tokenizer.pad_token_id,
                                   repetition_penalty=1.1, **kwargs)
    else:
        # beam search: tile the image features once per beam (same pattern as the VQA inference below)
        img_vis = img_vis.repeat_interleave(num_beams, dim=0)
        img_mask = torch.ones(img_vis.shape[:-1], dtype=torch.long, device=img_tensor.device)
        kwargs = {"encoder_hidden_states": img_vis, "encoder_attention_mask": img_mask}
        seq = txt_decoder.generate(input_ids=prompt_ids, max_length=max_len, min_length=min_len,
                                   num_beams=num_beams, eos_token_id=txt_tokenizer.sep_token_id,
                                   pad_token_id=txt_tokenizer.pad_token_id,
                                   repetition_penalty=rep_penalty, **kwargs)
    # strip the prompt from each decoded caption
    return [txt_tokenizer.decode(s, skip_special_tokens=True)[len(prompt):] for s in seq]
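A brief usage example: beam search gives safer captions, while nucleus sampling produces more diverse ones, which is what CapFilt's captioner relies on.

captions = generate_caption(img_tensor, use_sampling=False, num_beams=3)  # beam search
diverse = generate_caption(img_tensor, use_sampling=True, top_p=0.9)      # nucleus sampling
print(captions[0])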
Visual Question Answering
def vqa_forward(img_tensor, question, answer=None, train=False, infer_mode='generate'):
    # encode the image, then fuse it with the question through the image-conditioned text encoder
    img_vis = img_encoder(img_tensor)
    img_mask = torch.ones(img_vis.shape[:-1], dtype=torch.long, device=img_tensor.device)
    q_ids = txt_tokenizer(question, padding='longest', truncation=True, max_length=35,
                          return_tensors='pt').input_ids.to(img_tensor.device)
    q_ids[:, 0] = txt_tokenizer.enc_token_id
    if train:
        # the answer decoder is conditioned on the fused image-question states
        ans_ids = txt_tokenizer(answer, padding='longest', return_tensors='pt').input_ids.to(img_tensor.device)
        ans_ids[:, 0] = txt_tokenizer.bos_token_id
        ans_trg = ans_ids.masked_fill(ans_ids == txt_tokenizer.pad_token_id, -100)
        q_out = txt_encoder(q_ids, attention_mask=q_ids.ne(0), encoder_hidden_states=img_vis,
                            encoder_attention_mask=img_mask, return_dict=True)
        # question_states / question_att_mask / ans_weights come from repeating each question's
        # fused states once per candidate answer and weighting the answers; omitted here for brevity
        ans_out = txt_decoder(ans_ids, attention_mask=ans_ids.ne(0),
                              encoder_hidden_states=question_states,
                              encoder_attention_mask=question_att_mask,
                              labels=ans_trg, return_dict=True, reduction='none')
        loss = (ans_weights * ans_out.loss).sum() / img_tensor.size(0)
        return loss
    else:
        q_out = txt_encoder(q_ids, attention_mask=q_ids.ne(0), encoder_hidden_states=img_vis,
                            encoder_attention_mask=img_mask, return_dict=True)
        if infer_mode == 'generate':
            # open-ended answering: beam-search decode conditioned on the fused question states,
            # tiling the states once per beam
            beams = 3
            q_states = q_out.last_hidden_state.repeat_interleave(beams, dim=0)
            q_att = torch.ones(q_states.shape[:-1], dtype=torch.long, device=q_states.device)
            gen_kwargs = {"encoder_hidden_states": q_states, "encoder_attention_mask": q_att}
            bos = torch.full((img_tensor.size(0), 1), txt_tokenizer.bos_token_id, device=img_tensor.device)
            seq = txt_decoder.generate(input_ids=bos, max_length=10, min_length=1,
                                       num_beams=beams, eos_token_id=txt_tokenizer.sep_token_id,
                                       pad_token_id=txt_tokenizer.pad_token_id, **gen_kwargs)
            return [txt_tokenizer.decode(s, skip_special_tokens=True) for s in seq]
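A brief inference example, assuming img_tensor holds one preprocessed image of shape [1, 3, H, W] (one question per image in the batch):

answers = vqa_forward(img_tensor, ['where is the dog sitting?'], train=False, infer_mode='generate')
print(answers[0])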