
Design and Implementation Analysis of BLIP Multimodal Pretraining Model

BLIP (Bootstrapping Language-Image Pre-training) targets two limitations of prior vision-language pretraining. Architecturally, most existing approaches are either encoder-only or encoder-decoder: encoder-only models transfer poorly to generation tasks like image captioning, while encoder-decoder variants have not been effectively applied to image-text retrieval. On the data side, prevailing methods such as CLIP rely on large-scale noisy web-mined pairs, where irrelevant or inaccurate captions provide suboptimal supervision.

Architecture

BLIP introduces the Multimodal mixture of Encoder-Decoder (MED), a single architecture that operates in three modes to cover both understanding and generation tasks (sketched in code after the list):

  • Unimodal encoders: Separate encoders process images and text. The text encoder follows BERT, prepending a [CLS] token to summarize sentence-level semantics; the image encoder is a ViT whose [CLS] token represents the global visual context.
  • Image-grounded text encoder: Inserts a cross-attention (CA) layer between the self-attention (SA) and feed-forward layers of each transformer block to inject visual features. A task-specific [Encode] token is appended to the text; its output embedding serves as the multimodal representation of the image-text pair.
  • Image-grounded text decoder: Replaces the bidirectional SA layers with causal SA to enable autoregressive generation. A [Decode] token marks the start of a sequence, and an end-of-sequence token terminates decoding.
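
To make the mode switching concrete, here is a minimal, hypothetical sketch of a single MED transformer block (simplified for illustration; in the paper the encoder and decoder are BERT-based and share all parameters except the SA layers):

import torch
import torch.nn as nn

class MEDBlock(nn.Module):
    """One simplified MED block: the same weights serve three operating modes."""
    def __init__(self, dim=768, heads=12):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.n1, self.n2, self.n3 = nn.LayerNorm(dim), nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, x, image_embeds=None, causal=False):
        L = x.size(1)
        # A causal mask turns the bidirectional encoder into an autoregressive decoder.
        mask = torch.triu(torch.ones(L, L, dtype=torch.bool, device=x.device), 1) if causal else None
        h = self.n1(x)
        x = x + self.self_attn(h, h, h, attn_mask=mask)[0]
        if image_embeds is not None:
            # CA sits between SA and FFN, injecting visual features into the text stream.
            x = x + self.cross_attn(self.n2(x), image_embeds, image_embeds)[0]
        return x + self.ffn(self.n3(x))

blk = MEDBlock()
txt, img = torch.randn(2, 30, 768), torch.randn(2, 197, 768)
uni = blk(txt)                                   # unimodal text encoder (ITC)
fused = blk(txt, image_embeds=img)               # image-grounded text encoder (ITM)
dec = blk(txt, image_embeds=img, causal=True)    # image-grounded text decoder (LM)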

Training Objectives

Image-Text Contrastive Loss (ITC)

ITC aligns the unimodal image and text embeddings before fusion, pulling matched pairs together and pushing mismatched ones apart. Following ALBEF, it uses momentum encoders to produce softened targets, which compensates for potential positives hidden among the noisy web negatives.
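
In symbols (the ALBEF-style momentum-distillation formulation; τ is the temperature temp_val and α the mixing weight in the code below):

    \mathbf{q} = \alpha\,\mathrm{softmax}(s_m/\tau) + (1-\alpha)\,\mathbf{y}, \qquad
    \mathcal{L}_{\mathrm{itc}} = \tfrac{1}{2}\left[ H(\mathbf{q}^{\mathrm{i2t}}, \mathbf{p}^{\mathrm{i2t}}) + H(\mathbf{q}^{\mathrm{t2i}}, \mathbf{p}^{\mathrm{t2i}}) \right]

where s_m are similarities from the momentum encoders, y is the one-hot in-batch ground truth, p is the online model's softmax distribution, and H is cross-entropy.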

import torch
import torch.nn.functional as F

# The encoders, projection heads, momentum copies (*_m), feature queues, and the
# hyper-parameter alpha are assumed to be attributes of the surrounding model.

# Keep the learnable temperature in a numerically stable range.
with torch.no_grad():
    temp_val.data.clamp_(0.001, 0.5)

img_vis = img_encoder(img_tensor)              # ViT patch features [B, N, D]
img_mask = torch.ones(img_vis.shape[:-1], dtype=torch.long, device=img_tensor.device)
proj_img = F.normalize(img_proj(img_vis[:, 0, :]), dim=-1)   # projected [CLS] feature [B, d]

txt_data = txt_tokenizer(caption, padding='max_length', truncation=True,
                         max_length=30, return_tensors="pt").to(img_tensor.device)
txt_enc = txt_encoder(txt_data.input_ids, attention_mask=txt_data.attention_mask,
                      return_dict=True, mode='text')
proj_txt = F.normalize(txt_proj(txt_enc.last_hidden_state[:, 0, :]), dim=-1)

# Momentum branch: EMA-updated encoders provide soft targets, and feature queues
# enlarge the pool of negatives (MoCo-style).
with torch.no_grad():
    _update_momentum()
    img_vis_m = img_encoder_m(img_tensor)
    proj_img_m = F.normalize(img_proj_m(img_vis_m[:, 0, :]), dim=-1)
    all_img_feat = torch.cat([proj_img_m.t(), stored_img_queue.detach()], dim=1)  # [d, B+Q]

    txt_enc_m = txt_encoder_m(txt_data.input_ids, attention_mask=txt_data.attention_mask,
                              return_dict=True, mode='text')
    proj_txt_m = F.normalize(txt_proj_m(txt_enc_m.last_hidden_state[:, 0, :]), dim=-1)
    all_txt_feat = torch.cat([proj_txt_m.t(), stored_txt_queue.detach()], dim=1)  # [d, B+Q]

    sim_img2txt_m = proj_img_m @ all_txt_feat / temp_val
    sim_txt2img_m = proj_txt_m @ all_img_feat / temp_val

    # One-hot ground truth: the b-th image matches the b-th text in the batch.
    target_mat = torch.zeros_like(sim_img2txt_m)
    target_mat.fill_diagonal_(1)

    # Soft targets: blend the momentum model's distribution with the one-hot labels.
    tgt_img2txt = alpha * F.softmax(sim_img2txt_m, dim=1) + (1 - alpha) * target_mat
    tgt_txt2img = alpha * F.softmax(sim_txt2img_m, dim=1) + (1 - alpha) * target_mat

# Cross-entropy between the online similarities and the soft targets, in both directions.
sim_img2txt = proj_img @ all_txt_feat / temp_val
loss_i2t = -torch.sum(F.log_softmax(sim_img2txt, dim=1) * tgt_img2txt, dim=1).mean()

sim_txt2img = proj_txt @ all_img_feat / temp_val
loss_t2i = -torch.sum(F.log_softmax(sim_txt2img, dim=1) * tgt_txt2img, dim=1).mean()

loss_itc = (loss_i2t + loss_t2i) / 2
_enqueue_features(proj_img_m, proj_txt_m)   # push momentum features into the queues

Image-Text Matching Loss (ITM)

ITM trains a binary classifier on the multimodal [Encode] embedding to predict whether an image-text pair is matched. To make the objective informative, hard negatives are sampled in-batch from the ITC similarity distributions.

bs = img_tensor.size(0)

# Swap in the [Encode] token so the text encoder runs in multimodal mode.
etxt_ids = txt_data.input_ids.clone()
etxt_ids[:, 0] = txt_tokenizer.enc_token_id

# Positive pairs: each text cross-attends to its own image.
pos_out = txt_encoder(etxt_ids, attention_mask=txt_data.attention_mask,
                      encoder_hidden_states=img_vis, encoder_attention_mask=img_mask,
                      return_dict=True)

# Hard-negative mining: sample in-batch negatives in proportion to ITC similarity.
with torch.no_grad():
    w_txt2img = F.softmax(sim_txt2img[:, :bs], dim=1) + 1e-4
    w_txt2img.fill_diagonal_(0)   # never pick the positive as a negative
    w_img2txt = F.softmax(sim_img2txt[:, :bs], dim=1) + 1e-4
    w_img2txt.fill_diagonal_(0)

# One hard negative image per text (.item() keeps indexing from adding an extra dim).
neg_img_list = [img_vis[torch.multinomial(w_txt2img[b], 1).item()] for b in range(bs)]
neg_img_batch = torch.stack(neg_img_list, dim=0)

# One hard negative text per image.
neg_txt_ids, neg_txt_mask = [], []
for b in range(bs):
    idx = torch.multinomial(w_img2txt[b], 1).item()
    neg_txt_ids.append(etxt_ids[idx])
    neg_txt_mask.append(txt_data.attention_mask[idx])
neg_txt_ids = torch.stack(neg_txt_ids, dim=0)
neg_txt_mask = torch.stack(neg_txt_mask, dim=0)

# Negative pairs: (positive text, negative image) then (negative text, positive image).
comb_txt_ids = torch.cat([etxt_ids, neg_txt_ids], dim=0)
comb_txt_mask = torch.cat([txt_data.attention_mask, neg_txt_mask], dim=0)
comb_img = torch.cat([neg_img_batch, img_vis], dim=0)
comb_img_mask = torch.cat([img_mask, img_mask], dim=0)

neg_out = txt_encoder(comb_txt_ids, attention_mask=comb_txt_mask,
                      encoder_hidden_states=comb_img, encoder_attention_mask=comb_img_mask,
                      return_dict=True)

cat_emb = torch.cat([pos_out.last_hidden_state[:, 0, :],
                     neg_out.last_hidden_state[:, 0, :]], dim=0)
pred_match = itm_classifier(cat_emb)   # match/mismatch logits [3B, 2]

gt_label = torch.cat([torch.ones(bs, dtype=torch.long),
                      torch.zeros(2 * bs, dtype=torch.long)], dim=0).to(img_tensor.device)
loss_itm = F.cross_entropy(pred_match, gt_label)

Language Modeling Loss (LM)

LM trains the decoder to generate coherent captions given an image, optimizing autoregressive likelihood via cross-entropy.

# Replace [CLS] with the [Decode] token and mask pad positions out of the loss.
dec_in = txt_data.input_ids.clone()
dec_in[:, 0] = txt_tokenizer.bos_token_id
dec_trg = dec_in.masked_fill(dec_in == txt_tokenizer.pad_token_id, -100)

# The causal decoder cross-attends to the image features; HF-style LM heads
# return the shifted cross-entropy loss when labels are supplied.
dec_out = txt_decoder(dec_in, attention_mask=txt_data.attention_mask,
                      encoder_hidden_states=img_vis, encoder_attention_mask=img_mask,
                      labels=dec_trg, return_dict=True)
loss_lm = dec_out.loss

Downstream Applications

CapFilt

CapFilt bootstraps a cleaner training corpus: a captioner synthesizes captions for web images, and a filter removes mismatched image-text pairs. Both are initialized from the pretrained MED and fine-tuned individually on COCO, the captioner with the LM objective and the filter with ITC and ITM. The filtered web texts, the filtered synthetic captions, and the human-annotated pairs together form the improved pretraining corpus.
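
A hedged sketch of the bootstrapping loop, wired from the inference helpers defined later in this post (generate_caption and match). web_pairs and ITM_THRESHOLD are hypothetical names, and the threshold value is an assumption:

ITM_THRESHOLD = 0.5   # assumption; the paper filters with the fine-tuned ITM head

def capfilt(web_pairs):
    clean_pairs = []
    for img_tensor, web_caption in web_pairs:
        # Captioner: synthesize a caption for the web image (LM-finetuned MED).
        synth_caption = generate_caption(img_tensor, use_sampling=True)[0]
        # Filter: keep a text only if the ITM head predicts it matches the image.
        for cap in (web_caption, synth_caption):
            if match(img_tensor, cap, head_type='itm').item() > ITM_THRESHOLD:
                clean_pairs.append((img_tensor, cap))
    return clean_pairs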

Feature Extraction

def extract(mode, img_tensor, caption=None):
    if mode == 'image':
        # ViT patch features, [B, N, D]
        return img_encoder(img_tensor)
    elif mode == 'text':
        # Unimodal text features from the BERT branch
        txt_out = txt_encoder(txt_tokenizer(caption, return_tensors='pt').input_ids.to(img_tensor.device),
                              return_dict=True, mode='text')
        return txt_out.last_hidden_state
    elif mode == 'multimodal':
        # Fused features: text cross-attends to image patches in [Encode] mode
        img_vis = img_encoder(img_tensor)
        img_mask = torch.ones(img_vis.shape[:-1], dtype=torch.long, device=img_tensor.device)
        cap_ids = txt_tokenizer(caption, return_tensors='pt').input_ids.to(img_tensor.device)
        cap_ids[:, 0] = txt_tokenizer.enc_token_id
        out = txt_encoder(cap_ids, attention_mask=cap_ids.ne(0),   # pad token id assumed 0
                          encoder_hidden_states=img_vis, encoder_attention_mask=img_mask,
                          return_dict=True)
        return out.last_hidden_state
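
Example calls (assuming img_tensor is a preprocessed [1, 3, H, W] batch):

img_feat = extract('image', img_tensor)                          # [1, N, D] patch features
txt_feat = extract('text', img_tensor, caption='a dog')          # [1, L, D] token features
mm_feat = extract('multimodal', img_tensor, caption='a dog')     # [1, L, D] fused features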

Image-Text Matching Inference

def match(img_tensor, caption, head_type):
    img_vis = img_encoder(img_tensor)
    img_mask = torch.ones(img_vis.shape[:-1], dtype=torch.long, device=img_tensor.device)
    txt_data = txt_tokenizer(caption, padding='max_length', truncation=True,
                             max_length=35, return_tensors='pt').to(img_tensor.device)
    if head_type == 'itm':
        # Multimodal encoder + binary head: probability that the pair matches
        txt_data.input_ids[:, 0] = txt_tokenizer.enc_token_id
        enc_out = txt_encoder(txt_data.input_ids, attention_mask=txt_data.attention_mask,
                              encoder_hidden_states=img_vis, encoder_attention_mask=img_mask,
                              return_dict=True)
        score = itm_classifier(enc_out.last_hidden_state[:, 0, :])
        return torch.softmax(score, dim=1)[:, 1]
    else:
        # ITC head: cosine similarity of the projected unimodal features
        txt_out = txt_encoder(txt_data.input_ids, attention_mask=txt_data.attention_mask,
                              return_dict=True, mode='text')
        img_feat = F.normalize(img_proj(img_vis[:, 0, :]), dim=-1)
        txt_feat = F.normalize(txt_proj(txt_out.last_hidden_state[:, 0, :]), dim=-1)
        return (img_feat @ txt_feat.t()).item()   # scalar; assumes a single pair
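
Example calls for both heads (single image-text pair assumed):

p_match = match(img_tensor, 'a dog running on grass', head_type='itm')   # matching probability
cos_sim = match(img_tensor, 'a dog running on grass', head_type='itc')   # cosine similarity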

Image Captioning Generation

def generate_caption(img_tensor, use_sampling=False, num_beams=3, max_len=30, min_len=10, top_p=0.9, rep_penalty=1.0):
    img_vis = img_encoder(img_tensor)
    img_mask = torch.ones(img_vis.shape[:-1], dtype=torch.long, device=img_tensor.device)
    kwargs = {"encoder_hidden_states": img_vis, "encoder_attention_mask": img_mask}
    # Prompt each caption with "a picture of": swap in the [Decode] token
    # and drop the trailing [SEP] so generation continues the prompt.
    prompt_ids = txt_tokenizer(['a picture of'] * img_tensor.size(0), return_tensors='pt').input_ids.to(img_tensor.device)
    prompt_ids[:, 0] = txt_tokenizer.bos_token_id
    prompt_ids = prompt_ids[:, :-1]
    if use_sampling:
        # Nucleus sampling: more diverse captions
        seq = txt_decoder.generate(input_ids=prompt_ids, max_length=max_len, min_length=min_len,
                                   do_sample=True, top_p=top_p, num_return_sequences=1,
                                   eos_token_id=txt_tokenizer.sep_token_id,
                                   pad_token_id=txt_tokenizer.pad_token_id,
                                   repetition_penalty=1.1, **kwargs)
    else:
        # Beam search: a single high-likelihood caption
        seq = txt_decoder.generate(input_ids=prompt_ids, max_length=max_len, min_length=min_len,
                                   num_beams=num_beams, eos_token_id=txt_tokenizer.sep_token_id,
                                   pad_token_id=txt_tokenizer.pad_token_id,
                                   repetition_penalty=rep_penalty, **kwargs)
    # Strip the prompt from each decoded caption
    return [txt_tokenizer.decode(s, skip_special_tokens=True)[len('a picture of'):] for s in seq]
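
Example calls; beam search yields a single high-likelihood caption, while nucleus sampling yields more diverse ones (the setting the paper prefers for CapFilt's captioner):

caption_beam = generate_caption(img_tensor, num_beams=3)
caption_sampled = generate_caption(img_tensor, use_sampling=True, top_p=0.9)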

Visual Question Answering

def vqa_forward(img_tensor, question, answer=None, train=False, infer_mode='generate'):
    img_vis = img_encoder(img_tensor)
    img_mask = torch.ones(img_vis.shape[:-1], dtype=torch.long, device=img_tensor.device)
    q_ids = txt_tokenizer(question, padding='longest', truncation=True, max_length=35,
                          return_tensors='pt').input_ids.to(img_tensor.device)
    q_ids[:, 0] = txt_tokenizer.enc_token_id
    if train:
        ans_ids = txt_tokenizer(answer, padding='longest', return_tensors='pt').input_ids.to(img_tensor.device)
        ans_ids[:, 0] = txt_tokenizer.bos_token_id
        ans_trg = ans_ids.masked_fill(ans_ids == txt_tokenizer.pad_token_id, -100)
        # Encode the question conditioned on the image
        q_out = txt_encoder(q_ids, attention_mask=q_ids.ne(0), encoder_hidden_states=img_vis,
                            encoder_attention_mask=img_mask, return_dict=True)
        # question_states, question_att_mask, and ans_weights come from repeating each
        # question's encoder states once per annotated answer; omitted here for brevity
        ans_out = txt_decoder(ans_ids, attention_mask=ans_ids.ne(0),
                              encoder_hidden_states=question_states,
                              encoder_attention_mask=question_att_mask,
                              labels=ans_trg, return_dict=True, reduction='none')
        # Per-answer losses weighted by annotation frequency, averaged over the batch
        loss = (ans_weights * ans_out.loss).sum() / img_tensor.size(0)
        return loss
    else:
        q_out = txt_encoder(q_ids, attention_mask=q_ids.ne(0), encoder_hidden_states=img_vis,
                            encoder_attention_mask=img_mask, return_dict=True)
        if infer_mode == 'generate':
            # The decoder cross-attends to the question states; repeat them per beam
            beams = 3
            q_states = q_out.last_hidden_state.repeat_interleave(beams, dim=0)
            q_att = torch.ones(q_states.shape[:-1], dtype=torch.long, device=q_states.device)
            gen_kwargs = {"encoder_hidden_states": q_states, "encoder_attention_mask": q_att}
            bos = torch.full((img_tensor.size(0), 1), txt_tokenizer.bos_token_id, device=img_tensor.device)
            seq = txt_decoder.generate(input_ids=bos, max_length=10, min_length=1,
                                       num_beams=beams, eos_token_id=txt_tokenizer.sep_token_id,
                                       pad_token_id=txt_tokenizer.pad_token_id, **gen_kwargs)
            return [txt_tokenizer.decode(s, skip_special_tokens=True) for s in seq]
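
Example inference call (assuming a batch of one image; answers come back as free-form text):

answers = vqa_forward(img_tensor, ['what is the dog doing?'], train=False)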
