IMG2TEXT-Part2. OFA, CLIP Interrogator and ViT

Whale Learning
May 14, 2023

Continuing from Part 1, we will look at the CLIP Interrogator, the OFA model, and a ViT model, and then ensemble them. Most of the code comes from the notebook; I modified it and added explanations.

OFA

Imports

# Use the precompiled wheel because internet access is disabled at submission time
!pip install -q /kaggle/input/stable-diffusion-data/transformers-4.18.0.dev0-py3-none-any.whl

This installs a wheel file containing a specific development build of the Transformers library (4.18.0.dev0), which provides the OFA classes imported below.

import os
import sys
import glob
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import OFATokenizer, OFAModel
from transformers.models.ofa.generate import sequence_generator

import gc

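It imports the standard-library modules (os, sys, glob, pathlib), NumPy, Pandas, Matplotlib, PIL, PyTorch, torchvision transforms, and the OFATokenizer and OFAModel classes from this build of Transformers. It also imports sequence_generator from OFA's generate module, although the code in this article calls model.generate directly.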

Config

We will perform image captioning with pre-trained OFA on the test images.

CKPT_DIR = "/kaggle/input/stable-diffusion-data/OFA-large-caption/"
IMAGE_DIR = "/kaggle/input/stable-diffusion-image-to-prompts/images"

# Print the names of files in CKPT_DIR folder
print(f'Files in CKPT_DIR folder:')
for filename in os.listdir(CKPT_DIR):
    print(filename)

# Print the names of files in IMAGE_DIR folder
print(f'\nFiles in IMAGE_DIR folder:')
for filename in os.listdir(IMAGE_DIR):
    print(filename)


'''
Files in CKPT_DIR folder:
.git
.gitattributes
README.md
config.json
merges.txt
vocab.json
pytorch_model.bin

Files in IMAGE_DIR folder:
c98f79f71.png
92e911621.png
a4e1c55a9.png
f27825b2c.png
d8edf2e40.png
20057f34d.png
227ef0887.png
'''

OFA-large-caption is the large version of the OFA model fine-tuned for image captioning. OFA is a pretrained model that unifies modalities and tasks in a simple sequence-to-sequence learning framework. The checkpoint folder contains config.json (model configuration), vocab.json and merges.txt (the tokenizer's vocabulary and BPE merge rules), and pytorch_model.bin (model weights). This version of the checkpoint is packaged for Transformers and is meant to address potential mismatches between the Fairseq and Transformers implementations.

Loading the pre-trained OFA model

We use the OFA-large-caption model to generate captions for an image. The mean and std values normalize the input image to a standard range. patch_resize_transform is a torchvision Compose pipeline that converts the image to RGB, resizes it to 480x480 with bicubic interpolation, converts it to a tensor, and normalizes it. tokenizer and model load the pre-trained OFA tokenizer and model from the checkpoint directory, and the model is moved to the GPU. Finally, inputs holds the token IDs of the text prompt " what does the image describe?"; this prompt is later passed to model.generate together with an image, and the caption is decoded with beam search.

mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
#mean, std = [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]
resolution = 480
patch_resize_transform = transforms.Compose([
    lambda image: image.convert("RGB"),
    # torchvision expects an InterpolationMode rather than a PIL constant
    transforms.Resize((resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

tokenizer = OFATokenizer.from_pretrained(CKPT_DIR)
model = OFAModel.from_pretrained(CKPT_DIR, use_cache=False).cuda()
txt = " what does the image describe?"
inputs = tokenizer([txt], return_tensors="pt").input_ids
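To make the mean/std choice concrete: ToTensor scales 8-bit pixel values to [0, 1], and Normalize with mean = std = 0.5 then computes (x - 0.5) / 0.5 per channel, which maps pixels to [-1, 1]. The NumPy sketch below mimics those two steps on a few hypothetical pixel values; it illustrates the arithmetic and is not torchvision's implementation.

```python
import numpy as np


def to_tensor_and_normalize(pixels_uint8, mean=0.5, std=0.5):
    """Mimic ToTensor (scale uint8 to [0, 1]) followed by Normalize(mean, std)."""
    x = pixels_uint8.astype(np.float32) / 255.0  # ToTensor-style scaling
    return (x - mean) / std                       # Normalize-style standardization


pixels = np.array([0, 51, 255], dtype=np.uint8)  # hypothetical pixel values
print(to_tensor_and_normalize(pixels))  # approximately -1.0, -0.6, 1.0
```

With the commented-out alternative (mean 0, std 1), the same function would leave the values in [0, 1].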

Model EDA

I tried captioning the photos on my Instagram. The for loop iterates through each of the 4 images. In each iteration, the image is opened, resized to 500x500 pixels, transformed using a patch resize transform, and passed to the image captioning model to generate a caption.

# Load 4 PNG images from /content/drive/MyDrive/23_diffusion/data/myimages
sample_images = glob.glob('/content/drive/MyDrive/23_diffusion/data/myimages/*.png')[:4]

# Create a subplot with a square layout to show the 4 images and their captions
fig, ax = plt.subplots(2, 2, figsize=(8, 8))

# Loop through each image and generate a caption using the model
for i, impath in enumerate(sample_images):
    # Open the image and resize it to 500x500
    image = Image.open(impath).resize((500, 500))
    # Transform the image and run it through the model to generate a caption
    image_t = patch_resize_transform(image).cuda().unsqueeze(0)
    out = model.generate(inputs.cuda(), patch_images=image_t.cuda(), num_beams=5, no_repeat_ngram_size=2)
    out_captions = tokenizer.batch_decode(out, skip_special_tokens=True)
    # Show the image and caption in the corresponding subplot
    row = i // 2
    col = i % 2
    ax[row, col].imshow(image)
    ax[row, col].text(0.5, -0.1, out_captions[0], horizontalalignment='center', verticalalignment='center', transform=ax[row, col].transAxes)
    ax[row, col].set_xticks([])
    ax[row, col].set_yticks([])

# Adjust the layout of the subplots
plt.tight_layout()
plt.show()

Here’s a brief explanation of how a caption is generated for an input image.

The line image_t = patch_resize_transform(image).cuda().unsqueeze(0) applies the patch resize transform to the input image, moves the resulting tensor to the GPU with “cuda()”, and adds a batch dimension with “unsqueeze(0)”. The tensor of shape (3, 480, 480) becomes a batch of size 1 with shape (1, 3, 480, 480), which is stored in the “image_t” variable.

out = model.generate(inputs.cuda(), patch_images=image_t.cuda(), num_beams=5, no_repeat_ngram_size=2)

  • “inputs.cuda()”: The tokenized text prompt (“ what does the image describe?”), moved to the GPU.
  • “patch_images=image_t.cuda()”: The transformed input image tensor, also on the GPU.
  • “num_beams=5”: The number of beams used during beam search decoding, i.e. how many candidate sequences the model keeps at each step.
  • “no_repeat_ngram_size=2”: No n-gram of this size (here, no pair of consecutive tokens) may occur more than once in the generated sequence.
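The no_repeat_ngram_size rule can be illustrated with a small, self-contained sketch. Assuming generation has produced some tokens so far, the hypothetical helper banned_next_tokens returns the tokens that would complete an n-gram already present in the sequence; a decoder applying the rule gives those tokens zero probability at the next step. This mirrors the idea behind the Transformers option but is not the library's code.

```python
def banned_next_tokens(tokens, n):
    """Return the tokens that would repeat an existing n-gram if generated next.

    `tokens` is the sequence generated so far; `n` plays the role of
    no_repeat_ngram_size. This is an illustrative helper, not a library API.
    """
    if len(tokens) < n - 1:
        return set()
    # The last n-1 tokens form the prefix of the n-gram about to be completed
    prefix = tuple(tokens[len(tokens) - (n - 1):])
    banned = set()
    for i in range(len(tokens) - n + 1):
        # Any earlier n-gram starting with the same prefix bans its last token
        if tuple(tokens[i:i + n - 1]) == prefix:
            banned.add(tokens[i + n - 1])
    return banned


caption_so_far = ["a", "cat", "on", "a"]
print(banned_next_tokens(caption_so_far, 2))  # {'cat'}: "a cat" already occurred
```

With n=2, the model may still produce "a" again, but it cannot follow it with "cat" a second time.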

The generated sequence of tokens is decoded into text using the tokenizer’s “batch_decode” method, which takes the following arguments:

  • “out”: The generated token sequences (one row per image).
  • “skip_special_tokens=True”: Removes special tokens (e.g., beginning- and end-of-sequence tokens) from the decoded text.

“batch_decode” returns a list of strings, one caption per image, which is stored in the “out_captions” variable.

Inference

sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

comp_path = Path('../input/stable-diffusion-image-to-prompts/')
st_model = SentenceTransformer('/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2')

st_model is an instance of the SentenceTransformer class loaded from a local copy of the pre-trained all-MiniLM-L6-v2 model, which maps each sentence to a 384-dimensional embedding.

class ImageGen(Dataset):
    def __init__(self, root, batch_size=32):
        self.root = root
        self.im_paths = os.listdir(self.root)
        self.batch_size = batch_size
        self.sz = len(self.im_paths)
        self.genlen = self.sz//self.batch_size + int(self.sz%self.batch_size > 0)

    def __getitem__(self, index):
        if index >= self.genlen:
            raise IndexError("Out of bounds")

        l, r = index*self.batch_size, min(self.sz, (index+1)*self.batch_size)

        f_paths = [os.path.join(self.root, self.im_paths[i]) for i in range(l,r)]
        f_ids = [self.im_paths[i][:-4] for i in range(l,r)]
        ims = [Image.open(f_path) for f_path in f_paths]
        ims = [patch_resize_transform(im).cuda().unsqueeze(0) for im in ims]
        ims = torch.cat(ims)

        return ims, f_ids

    def __len__(self):
        return self.genlen

ImageGen is a PyTorch dataset that generates batches of images. The __init__ method takes the root directory where the images are stored and the batch size as inputs. It initializes some variables such as self.im_paths, which is a list of all the image file names in the directory, self.batch_size, and self.sz, which is the total number of images. self.genlen is the total number of batches that can be generated given the batch size.

The __getitem__ method takes an index as input and returns a batch of images and their respective IDs. The index is used to determine which batch to return. First, the method computes the range of image file names that correspond to the batch.

f_ids is a list comprehension that creates a list of image file names without their extensions. self.im_paths is a list of all the image file names in the directory. self.im_paths[i] retrieves the i-th image file name, and [:-4] is a slice notation that returns all the characters of the string except for the last four.

It then opens each image file with Image.open and transforms it. For each image, patch_resize_transform() converts it to RGB, resizes it, converts it to a tensor, and normalizes it; cuda() moves the tensor to the GPU; and unsqueeze(0) adds a new dimension at the front. This turns a 3-dimensional tensor of shape (channels, height, width) into a 4-dimensional tensor of shape (1, channels, height, width), which is necessary because most deep learning models expect input data in batches. torch.cat then concatenates all tensors in the ims list along the first (batch) dimension into a single tensor. Finally, the method returns the tensor of images and their IDs.

The __len__ method returns the total number of batches that can be generated using this dataset.
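ImageGen works in a plain for loop because of Python's legacy sequence protocol: iterating over an object that defines __getitem__ but not __iter__ calls __getitem__(0), __getitem__(1), and so on until IndexError is raised. The sketch below uses a hypothetical BatchIndexer class that keeps ImageGen's batching arithmetic but slices a plain list, so it runs without images or a GPU.

```python
class BatchIndexer:
    """Hypothetical stand-in for ImageGen that batches a list instead of image files."""

    def __init__(self, items, batch_size=32):
        self.items = items
        self.batch_size = batch_size
        self.sz = len(items)
        # Ceiling division: one extra, smaller batch if there is a remainder
        self.genlen = self.sz // batch_size + int(self.sz % batch_size > 0)

    def __getitem__(self, index):
        if index >= self.genlen:
            raise IndexError("Out of bounds")  # this is what ends the for loop
        l, r = index * self.batch_size, min(self.sz, (index + 1) * self.batch_size)
        return self.items[l:r]

    def __len__(self):
        return self.genlen


batches = [batch for batch in BatchIndexer(list(range(7)), batch_size=3)]
print(batches)  # [[0, 1, 2], [3, 4, 5], [6]]
```

Seven items with a batch size of 3 give 7 // 3 + 1 = 3 batches, the last one holding the single leftover item.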

from gensim.parsing.preprocessing import remove_stopwords

BATCH_SIZE = 8  # assumed value: BATCH_SIZE is not defined earlier in this article

sub_ids = []
sub_embeds = []

imgen = ImageGen(IMAGE_DIR, BATCH_SIZE)

for b in imgen:
    # One ID per embedding dimension (0-383) for every image in the batch
    for j in range(len(b[1])):
        sub_ids.extend([f"{b[1][j]}_{i}" for i in range(384)])

    img_batch = b[0]
    out = model.generate(inputs.repeat(len(img_batch), 1).cuda(), patch_images=img_batch, num_beams=5, no_repeat_ngram_size=2)
    out_captions = tokenizer.batch_decode(out, skip_special_tokens=True)
    out_captions = [remove_stopwords(text) for text in out_captions]

    embeddings = st_model.encode(out_captions).flatten()
    sub_embeds.extend(embeddings)

The code loops over each batch, and for each image in the batch it generates 384 IDs by concatenating the image ID b[1][j] with an index ranging from 0 to 383. These IDs are appended to the sub_ids list.

The number 384 is the dimensionality of the sentence embeddings produced by all-MiniLM-L6-v2; it has nothing to do with image patches. Each image receives one caption, and that caption is encoded into one 384-dimensional vector. The submission format expects one row per embedding dimension per image, identified as imgId_eId, so each image contributes 384 rows.

For example, for an input image with an ID of image_001, the IDs are image_001_0, image_001_1, image_001_2, and so on up to image_001_383, one for each dimension of its caption embedding. The loop then moves on to the next image and repeats the process.
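The IDs line up with the embedding values because both are ordered image-major: st_model.encode returns one row per caption, and flatten() reads that matrix row by row. The sketch below checks this with hypothetical stand-ins (two image IDs, 4 dimensions instead of 384, and arange values in place of real embeddings).

```python
import numpy as np

# Stand-in values: 2 hypothetical images and 4 embedding dimensions instead of 384
img_ids = ["image_001", "image_002"]
dim = 4
emb = np.arange(len(img_ids) * dim, dtype=np.float32).reshape(len(img_ids), dim)

# One ID per (image, dimension) pair, in the same order the loop builds sub_ids
sub_ids = [f"{img_id}_{e}" for img_id in img_ids for e in range(dim)]
# Row-major flatten: all dimensions of image_001, then all of image_002
flat = emb.flatten()

for k in (0, 5, 7):
    print(sub_ids[k], flat[k])  # ID k is paired with value k
```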

To generate a caption for every image in the current batch, inputs.repeat(len(img_batch), 1) repeats the tokenized prompt inputs once per image, producing a tensor of shape (batch_size, sequence_length). Setting num_beams=5 means that at each decoding step the model keeps the 5 highest-scoring candidate sequences and keeps extending them until generation finishes.

After generating the captions for the input images, the next step is to preprocess them before generating embeddings. tokenizer.batch_decode is used to decode the generated captions into human-readable text by removing the special tokens. Then, the remove_stopwords function is applied to each caption text to remove any stopwords, which are common words that do not carry much meaning (e.g. "the", "a", "and").
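As a rough illustration of what stopword removal does to a caption, the sketch below splits on whitespace and drops tokens found in a stopword set. The set STOPWORDS_SAMPLE and the function remove_stopwords_sketch are hypothetical; gensim ships its own, much larger list, so its real outputs can differ.

```python
# Hypothetical, tiny stopword set; gensim's STOPWORDS list is much larger
STOPWORDS_SAMPLE = {"a", "an", "the", "and", "of", "on", "with"}


def remove_stopwords_sketch(text, stopwords=STOPWORDS_SAMPLE):
    """Drop whitespace-separated tokens that appear in the stopword set (case-sensitive)."""
    return " ".join(token for token in text.split() if token not in stopwords)


print(remove_stopwords_sketch("a man riding a horse on the beach"))  # man riding horse beach
```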

Once the captions have been preprocessed, they are passed to the pre-trained sentence-transformers model (st_model) through its encode method, which returns an array of shape (number of captions, 384). This array is flattened to a 1-dimensional vector and added to the sub_embeds list with extend, so the values of each caption's embedding stay contiguous.

The sub_embeds list will eventually contain embeddings for all of the generated captions, which can be used for downstream tasks such as clustering or retrieval.

embeddings1 = np.array(sub_embeds)
embeddings1.shape


'''
Output: (2688,)
'''


embeddings1_list = embeddings1[:100].tolist()
print('The First Embeddings:\n',embeddings1_list)

'''
The First Embeddings:
[-0.0561382882297039, 0.05179676041007042, -0.010111154057085514, 0.009372996166348457, -0.08668708056211472, 0.02042912133038044, 0.05539735034108162, 0.07150783389806747, 0.037339936941862106, -0.046097010374069214, 0.012080357410013676, -0.018545761704444885, 0.022009504958987236, -0.008211500942707062, 0.009951111860573292, -0.045629262924194336, 0.0670464038848877, 0.018947778269648552, 0.11232531070709229, -0.025967668741941452, 0.0038218721747398376, 0.08071696013212204, 0.07032617181539536, 0.0011246695648878813, -0.02383158914744854, -0.08036167919635773, -0.009118749760091305, 0.0202063899487257, 0.10231679677963257, 0.005335730500519276, 0.05832022801041603, 0.009123396128416061, -0.02153872326016426, 0.022793728858232498, -0.05844773352146149, -0.10858351737260818, 0.05635203793644905, 0.03066965751349926, -0.0220758318901062, 0.02577693574130535, 0.007247749250382185, -0.033000536262989044, -0.028591414913535118, -0.027001265436410904, 0.0009703710093162954, 0.08752570301294327, 0.0021870543714612722, -0.011224618181586266, -0.04847567901015282, 0.00015174485452007502, 0.041111256927251816, 0.017190856859087944, 0.09749161452054977, 0.012434051372110844, 0.03300405666232109, 0.00894326064735651, 0.03134206682443619, -0.03647271543741226, 0.023139864206314087, 0.0797228217124939, -0.022575175389647484, -0.02784871496260166, 0.00744879525154829, 0.043413709849119186, 0.004357294179499149, 0.002728877356275916, -0.08391834795475006, 0.08548334985971451, 0.013969373889267445, 0.058525536209344864, -0.02816159464418888, 0.03193134069442749, 0.06430048495531082, -0.03445425257086754, -0.01881301961839199, -0.03835446760058403, -0.009387059137225151, -0.008853904902935028, -0.001368737081065774, 0.04366125538945198, -0.04213671386241913, -0.006765689235180616, -0.039187368005514145, 0.0076749068684875965, -0.06753271818161011, -0.07330788671970367, 0.012115951627492905, 0.05949762836098671, -0.01170273870229721, -0.055939510464668274, -0.06842418015003204, 
0.022529324516654015, -0.03895615413784981, 0.03687118738889694, -0.01273308601230383, -0.058583661913871765, -0.030613863840699196, 0.028653480112552643, -0.05504542216658592, 0.015734845772385597]
'''

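The sub_embeds list holds one 384-dimensional embedding per generated caption, flattened into a single vector. With 7 test images (one caption each), the total length is 384 x 7 = 2688, matching the shape above; the printed values are the first 100 dimensions of the first image's embedding.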

del model, tokenizer, st_model
torch.cuda.empty_cache()
gc.collect()

In order to free up memory, we can delete variables and models that are no longer needed, clear the GPU memory cache, and run the garbage collector.

CLIP Interrogator

Install & Import all dependencies

# Install clip_interrogator
wheels_path = "/kaggle/input/clip-interrogator-wheels-x"
clip_interrogator_whl_path = f"{wheels_path}/clip_interrogator-0.4.3-py3-none-any.whl"

This defines the path to the clip_interrogator 0.4.3 wheel file, which the next command installs.

!pip install --no-index --find-links $wheels_path $clip_interrogator_whl_path -q

The !pip install command is used to install the package, and the --no-index flag is used to prevent pip from searching PyPI (Python Package Index) for the package. Instead, pip will only look for the package in the specified directory using the --find-links flag.

!pip list | grep transformers

The grep command is a Unix utility that searches for a specified pattern in a file or input stream, and outputs any lines that contain the pattern. The command !pip list | grep transformers will show a list of all installed Python packages that have "transformers" in their name.

import inspect
import importlib

from blip.models import blip
from clip_interrogator import clip_interrogator

blip.models.blip and clip_interrogator.clip_interrogator are imported.

# replace tokenizer path to prevent downloading
blip_path = inspect.getfile(blip)

fin = open(blip_path, "rt")
data = fin.read()
data = data.replace(
"BertTokenizer.from_pretrained('bert-base-uncased')",
"BertTokenizer.from_pretrained('/kaggle/input/clip-interrogator-models-x/bert-base-uncased')"
)
fin.close()

fin = open(blip_path, "wt")
fin.write(data)
fin.close()

# reload module
importlib.reload(blip)

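This block patches the source file of blip.models.blip so that BertTokenizer loads bert-base-uncased from a local Kaggle input directory instead of downloading it (internet access is not available at submission time), and then reloads the module.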

# fix clip_interrogator bug
clip_interrogator_path = inspect.getfile(clip_interrogator.Interrogator)

fin = open(clip_interrogator_path, "rt")
data = fin.read()
data = data.replace(
'open_clip.get_tokenizer(clip_model_name)',
'open_clip.get_tokenizer(config.clip_model_name.split("/", 2)[0])'
)
fin.close()

fin = open(clip_interrogator_path, "wt")
fin.write(data)
fin.close()

# reload module
importlib.reload(clip_interrogator)

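This block locates the source file of clip_interrogator.Interrogator and fixes a bug by replacing open_clip.get_tokenizer(clip_model_name) with a call that derives the architecture name from config.clip_model_name (the part before "/", e.g. "ViT-H-14"). The module is then reloaded.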
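The patch-and-reload pattern used in the two blocks above can be tried on a throwaway module. The sketch below writes a hypothetical module toy_module to a temporary directory, imports it, rewrites a string in its source file, and reloads it: the same sequence of inspect.getfile, string replacement, and importlib.reload used for blip and clip_interrogator. Editing installed packages this way is a quick workaround for an offline notebook, not something to rely on in long-lived code.

```python
import importlib
import inspect
import sys
import tempfile
from pathlib import Path

# Avoid writing .pyc files so the reload below always recompiles from source
sys.dont_write_bytecode = True

# Create a hypothetical module in a temporary directory and make it importable
tmp_dir = Path(tempfile.mkdtemp())
(tmp_dir / "toy_module.py").write_text("def source():\n    return 'remote'\n")
sys.path.insert(0, str(tmp_dir))

toy_module = importlib.import_module("toy_module")
print(toy_module.source())  # remote

# Same steps as the notebook: locate the file, replace a string, write it back
module_path = Path(inspect.getfile(toy_module))
patched = module_path.read_text().replace("'remote'", "'local'")
module_path.write_text(patched)

# Re-execute the module's source so the change takes effect
importlib.reload(toy_module)
print(toy_module.source())  # local
```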

import open_clip

sys.path.append('../input/sentence-transformers-222/sentence-transformers')

#from sentence_transformers import SentenceTransformer, models

comp_path = Path('/kaggle/input/stable-diffusion-image-to-prompts/')

We’re importing the open_clip library and defining a Path object comp_path.

Set configs

class CFG:
    device = "cuda"
    seed = 42
    embedding_length = 384
    sentence_model_path = "/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2"
    blip_model_path = "/kaggle/input/clip-interrogator-models-x/model_large_caption.pth"
    ci_clip_model_name = "ViT-H-14/laion2b_s32b_b79k"
    clip_model_name = "ViT-H-14"
    clip_model_path = "/kaggle/input/clip-interrogator-models-x/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"
    cache_path = "/kaggle/input/clip-interrogator-models-x"

CFG defines several class-level variables. Here is a description of each:

  • device: the device the models run on, set to "cuda" (GPU).
  • seed: an integer random seed used for reproducibility, set to 42.
  • embedding_length: the length of the sentence embeddings, set to 384 (the output dimension of all-MiniLM-L6-v2).
  • sentence_model_path: the file path to the SentenceTransformer model used for text embeddings.
  • blip_model_path: the file path to the BLIP checkpoint used for image captioning.
  • ci_clip_model_name: the CLIP model name in the "architecture/pretrained-tag" form that CLIP Interrogator expects, set to "ViT-H-14/laion2b_s32b_b79k" (an OpenCLIP ViT-H/14 trained on LAION-2B).
  • clip_model_name: the OpenCLIP architecture name alone, set to "ViT-H-14".
  • clip_model_path: the file path to the OpenCLIP ViT-H-14 weights.
  • cache_path: the directory holding CLIP Interrogator's cached files, such as precomputed text embeddings for its label lists, so they do not have to be downloaded or recomputed. A cache stores data so that future requests for that data can be served faster.

Build index from images

#  read the sample submission file with the index set to 'imgId_eId'
df_submission = pd.read_csv(comp_path / 'sample_submission.csv', index_col='imgId_eId')

# get the list of all images in the 'images' directory
images = os.listdir(comp_path / 'images')

# extract image ids from the list of images
imgIds = [i.split('.')[0] for i in images]

# create a list of numbers from 0 to CFG.embedding_length
eIds = list(range(CFG.embedding_length))

# create a list of 'imgId_eId' pairs for each image id and embedding id
imgId_eId = [
'_'.join(map(str, i)) for i in zip(
np.repeat(imgIds, CFG.embedding_length),
np.tile(range(CFG.embedding_length), len(imgIds))
)
]

# check that the list of 'imgId_eId' pairs matches the index in the sample submission file
assert sorted(imgId_eId) == sorted(df_submission.index)

The sample submission is read with ‘imgId_eId’ as its index. The code then lists the image filenames in the images folder and extracts the image IDs by removing the file extensions. It creates a list of embedding IDs from 0 to ‘CFG.embedding_length’ - 1. Then it builds the list of ‘imgId_eId’ strings by pairing image IDs and embedding IDs with ‘np.repeat’ and ‘np.tile’ and joining each pair with ‘'_'.join’, and finally asserts that this list matches the index of the sample submission.

np.repeat() is a NumPy function that repeats each element of an input array a specified number of times along a given axis. For example, np.repeat([1, 2, 3], 2) would return the array [1, 1, 2, 2, 3, 3].

np.tile() is a NumPy function that constructs a new array by repeating a given array a specified number of times along a specified axis. For example, np.tile([1, 2], (2, 2)) would return the 2D array [[1, 2, 1, 2], [1, 2, 1, 2]].
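Putting np.repeat and np.tile together on a toy input shows how the IDs are ordered: the image IDs vary slowly and the embedding indices vary quickly, which is the same order as flattening a (number of images, embedding length) matrix. The values below are hypothetical stand-ins (two image IDs and 3 instead of 384 dimensions).

```python
import numpy as np

img_ids = ["a1", "b2"]  # hypothetical image IDs
dim = 3                 # stands in for CFG.embedding_length (384)

pairs = [
    "_".join(map(str, p))
    for p in zip(np.repeat(img_ids, dim), np.tile(range(dim), len(img_ids)))
]
print(pairs)  # ['a1_0', 'a1_1', 'a1_2', 'b2_0', 'b2_1', 'b2_2']

# Same order as a nested comprehension: image-major, dimension-minor
assert pairs == [f"{i}_{e}" for i in img_ids for e in range(dim)]
```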

Prepare CLIP interrogator tool

Define CLIP interrogator config

model_config = clip_interrogator.Config(clip_model_name=CFG.ci_clip_model_name)
model_config.cache_path = CFG.cache_path

The model_config variable is set as an instance of the Config class from the clip_interrogator module. It is initialized with the clip_model_name attribute set to the value of CFG.ci_clip_model_name, which is set to ViT-H-14/laion2b_s32b_b79k. Then, the cache_path attribute of model_config is set to the value of CFG.cache_path.

Define BLIP model

configs_path = os.path.join(os.path.dirname(os.path.dirname(blip_path)), 'configs')

med_config = os.path.join(configs_path, 'med_config.json')

blip_model = blip.blip_decoder(
pretrained=CFG.blip_model_path,
image_size=model_config.blip_image_eval_size,
vit=model_config.blip_model_type,
med_config=med_config
)
blip_model.eval()
blip_model = blip_model.to(model_config.device)
model_config.blip_model = blip_model

A pre-trained BLIP captioning model is loaded from the specified path, set to evaluation mode, and moved to the device specified in the configuration. med_config.json configures BLIP's BERT-based text encoder/decoder (the "MED", multimodal mixture of encoder-decoder, in the BLIP paper). The loaded model is then assigned to the blip_model attribute of the model_config object.

Define CLIP model

clip_model = open_clip.create_model(CFG.clip_model_name, precision='fp16' if model_config.device == 'cuda' else 'fp32')
open_clip.load_checkpoint(clip_model, CFG.clip_model_path)
clip_model.to(model_config.device).eval()
model_config.clip_model = clip_model

The CLIP model is created with the create_model() function from the open_clip module, using CFG.clip_model_name and a precision of 'fp16' if model_config.device is 'cuda' and 'fp32' otherwise. load_checkpoint() then loads the weights from CFG.clip_model_path. Finally, to() moves the model to model_config.device and eval() puts it in evaluation mode. The resulting CLIP model is stored in model_config.clip_model.

clip_preprocess = open_clip.image_transform(
clip_model.visual.image_size,
is_train = False,
mean = getattr(clip_model.visual, 'image_mean', None),
std = getattr(clip_model.visual, 'image_std', None),
)
model_config.clip_preprocess = clip_preprocess

We initialize the clip_preprocess variable, which is a data preprocessing pipeline for images to be used with the CLIP model. The open_clip.image_transform() function is called with the following arguments:

  • clip_model.visual.image_size: the size of the input images expected by the CLIP model
  • is_train=False: specifies that this preprocessing pipeline is not meant for training data
  • mean=getattr(clip_model.visual, 'image_mean', None): the per-channel mean subtracted from the input image. If clip_model.visual has an attribute called image_mean, it is used; otherwise None is passed and open_clip falls back to its default (OpenAI CLIP) mean.
  • std=getattr(clip_model.visual, 'image_std', None): the per-channel standard deviation the input image is divided by. If clip_model.visual has an attribute called image_std, it is used; otherwise None is passed and open_clip falls back to its default (OpenAI CLIP) standard deviation.

The resulting preprocessing pipeline is then stored in the model_config.clip_preprocess variable.

getattr() is a built-in Python function that returns the value of a named attribute of an object. It takes two arguments: the object and the name of the attribute. If the named attribute does not exist, getattr() can also return a default value, specified as a third argument.

ci = clip_interrogator.Interrogator(model_config)

ci is an instance of clip_interrogator.Interrogator built from model_config, which holds the BLIP and CLIP models, the CLIP preprocessing pipeline, and the related paths. It is used to generate captions, compute image features, and access the label tables.

Define interrogate function

Get labels embeddings

Instead of following the original CLIP Interrogator exactly, a different approach was used to measure the similarity between the image and the text labels: rather than multiplying the image features by the label embeddings, cosine similarity was used, as it was found to be faster and produced similar results. Since CLIP Interrogator L2-normalizes both the image features and the label embeddings, the two measures should rank the labels identically.

cos = torch.nn.CosineSimilarity(dim=1)

mediums_features_array = torch.stack([torch.from_numpy(t) for t in ci.mediums.embeds]).to(ci.device)
movements_features_array = torch.stack([torch.from_numpy(t) for t in ci.movements.embeds]).to(ci.device)
flavors_features_array = torch.stack([torch.from_numpy(t) for t in ci.flavors.embeds]).to(ci.device)

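Here, a PyTorch cosine similarity module is defined along dim=1. For each label table (mediums, movements, flavors), the precomputed CLIP text embeddings in ci.<table>.embeds are converted from NumPy arrays to tensors, stacked into a (number of labels, embedding dimension) tensor, and moved to the device used by the interrogator (CPU or GPU) with to(). When cos is applied to image features of shape (1, embedding dimension) and one of these tensors, broadcasting yields one similarity score per label.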

The mediums, movements, and flavors are fixed label lists shipped with CLIP Interrogator, not separate models. Mediums describe the kind of artwork (such as a painting or a photograph), movements name art movements, and flavors are descriptive modifiers for the artwork. Each list is ranked by its similarity to the image features.

In the original CLIP Interrogator, the LabelTable.rank method finds the closest labels for each category (medium, movement, flavor) by computing the similarity between the image features and the pre-computed label features and returning the top-ranked labels. Here the same selection is done with cos(...).topk(k): the labels are ranked by cosine similarity to the image features, and the top-ranked label or labels are chosen for the image.
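For L2-normalized vectors, cosine similarity and the dot product are the same quantity, so ranking labels by either gives the same order. The NumPy sketch below checks this with random vectors as hypothetical stand-ins for one image feature and five label embeddings, and selects the top 3 labels both ways, similar to what topk(3) does for the flavors.

```python
import numpy as np

rng = np.random.default_rng(0)
image_feat = rng.normal(size=(1, 8))   # hypothetical image feature, shape (1, d)
label_feats = rng.normal(size=(5, 8))  # hypothetical label embeddings, shape (n_labels, d)


def l2_normalize(x):
    """Scale each row to unit length."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


# Cosine similarity of the image feature with every label (like CosineSimilarity(dim=1))
cos_sim = (label_feats @ image_feat[0]) / (
    np.linalg.norm(label_feats, axis=1) * np.linalg.norm(image_feat[0])
)
# Dot product after normalizing both sides (the matrix-multiplication approach)
dot_sim = l2_normalize(label_feats) @ l2_normalize(image_feat)[0]

# Indices of the 3 most similar labels, highest first (like .topk(3).indices)
top3_cos = np.argsort(-cos_sim)[:3]
top3_dot = np.argsort(-dot_sim)[:3]
print(top3_cos, top3_dot)
```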

Create main interrogation function

def interrogate(image: Image.Image) -> str:
    caption = ci.generate_caption(image)
    image_features = ci.image_to_features(image)

    medium = [ci.mediums.labels[i] for i in cos(image_features, mediums_features_array).topk(1).indices][0]
    movement = [ci.movements.labels[i] for i in cos(image_features, movements_features_array).topk(1).indices][0]
    flaves = ", ".join([ci.flavors.labels[i] for i in cos(image_features, flavors_features_array).topk(3).indices])

    if caption.startswith(medium):
        prompt = f"{caption}, {movement}, in the style of {flaves}"
    else:
        prompt = f"{caption}, {medium}, {movement}, in the style of {flaves}"

    return clip_interrogator._truncate_to_fit(prompt, ci.tokenize)

interrogate takes an image and returns a string. Here's what the code does:

  1. Generates a caption for the image using the generate_caption method of the clip_interrogator object.
  2. Converts the image to a feature vector using the image_to_features method of the clip_interrogator object.
  3. Computes the cosine similarity between the image feature vector and the feature vectors of the mediums, movements, and flavors using the cos function defined earlier and the mediums_features_array, movements_features_array, and flavors_features_array arrays respectively.
  4. Retrieves the label with the highest similarity score from each category using the topk method and indexes the labels attribute of the corresponding category object (ci.mediums.labels, ci.movements.labels, ci.flavors.labels) to get the label text; for flavors, the top three labels are kept.
  5. Joins the three flavor names with commas using the join method.
  6. Constructs a string prompt by concatenating the caption, medium, movement, and flavor information, with the format "caption, medium, movement, in the style of flavor1, flavor2, flavor3", or "caption, movement, in the style of flavor1, flavor2, flavor3" if the caption starts with the medium.
  7. Returns the prompt, truncated with clip_interrogator's _truncate_to_fit helper so that it fits within the CLIP tokenizer's context length.
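Steps 5 to 7 can be sketched without any models. In the sketch below, build_prompt follows the same formatting rule as interrogate, and truncate_to_fit keeps leading comma-separated parts while they fit a budget. The budget is a hypothetical word count standing in for the CLIP tokenizer's token limit, and both function names are illustrative, not part of clip_interrogator.

```python
def build_prompt(caption, medium, movement, flavors):
    """Assemble the prompt the same way as interrogate()."""
    flaves = ", ".join(flavors)
    if caption.startswith(medium):
        # Avoid repeating the medium when the caption already starts with it
        return f"{caption}, {movement}, in the style of {flaves}"
    return f"{caption}, {medium}, {movement}, in the style of {flaves}"


def truncate_to_fit(text, max_words):
    """Keep leading comma-separated parts while the word budget allows."""
    parts = text.split(", ")
    new_text = parts[0]
    for part in parts[1:]:
        candidate = new_text + ", " + part
        if len(candidate.split()) > max_words:
            break
        new_text = candidate
    return new_text


prompt = build_prompt("a painting of a fox", "a painting", "art nouveau",
                      ["detailed", "vibrant", "soft light"])
print(prompt)
# a painting of a fox, art nouveau, in the style of detailed, vibrant, soft light
print(truncate_to_fit(prompt, max_words=10))
# a painting of a fox, art nouveau
```

Dropping whole comma-separated parts, rather than cutting mid-phrase, keeps the truncated prompt readable.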

Extract prompt from images

prompts = []

images_path = "../input/stable-diffusion-image-to-prompts/images/"

for image_name in images:
    img = Image.open(images_path + image_name).convert("RGB")

    generated = interrogate(img)

    prompts.append(generated)

The interrogate() function is applied to generate prompts for each image. The resulting prompts are then appended to a list called prompts.

def add_text_limiters(text: str) -> str:
    return " ".join([
        word + "\n" if i % 15 == 0 else word
        for i, word in enumerate(text.split(" "), start=1)
    ])

This function takes a string text and adds newline characters every 15 words.

def plot_image(image: np.ndarray, original_prompt: str, generated_prompt: str) -> None:
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    plt.annotate(
        "Original prompt:\n" + add_text_limiters(original_prompt) + "\n\nGenerated prompt:\n" + add_text_limiters(generated_prompt),
        xy=(1.05, 0.5), xycoords='axes fraction', ha='left', va='center',
        fontsize=16, rotation=0, color="#104a6e"
    )

The plot_image function takes an image and two prompts as input and plots the image along with the two prompts as annotations. The add_text_limiters function is used to add line breaks to the prompts, making them easier to read in the plot.

original_prompts_df = pd.read_csv("/kaggle/input/stable-diffusion-image-to-prompts/prompts.csv")

for image_name, prompt in zip(images, prompts):
    img = Image.open(images_path + image_name).convert("RGB")
    original_prompt = original_prompts_df[
        original_prompts_df.imgId == image_name.split(".")[0]
    ].prompt.iloc[0]
    plot_image(img, original_prompt, prompt)

The loop iterates over each image in the images list together with its generated prompt from the prompts list. For each image, the original prompt is looked up in prompts.csv using the image ID as the key, and plot_image is called with the image, the original prompt, and the generated prompt.

clip_interrogator_prompts = []

images_path = '/kaggle/input/cxr-samples/'

image_files = glob.glob(images_path + '*.PNG')

for image_file in image_files:
    img = Image.open(image_file).convert("RGB")
    generated = interrogate(img)
    clip_interrogator_prompts.append(generated)

# Create a DataFrame from the generated prompts
prompts_df = pd.DataFrame({'Image Path': image_files, 'Caption': clip_interrogator_prompts})

# Save the DataFrame as a CSV file
prompts_df.to_csv('/kaggle/working/clip_int_prompts.csv', index=False)

# Print the prompts
print("List of CLIP Interrogator Prompts:")
for prompt in clip_interrogator_prompts:
    print(prompt)

Let me compare OFA and CLIP Interrogator on these chest X-ray images.

def plot_image(image: Image.Image, ofa_prompts: str, clip_int_prompts: str) -> None:
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    plt.annotate(
        "OFA prompts:\n" + add_text_limiters(ofa_prompts) + "\n\nCLIP Interrogator:\n" + add_text_limiters(clip_int_prompts),
        xy=(1.05, 0.5), xycoords='axes fraction', ha='left', va='center',
        fontsize=16, rotation=0, color="black"  # Set the font color to black
    )
    plt.axis('off')  # Turn off the axes

# Read the prompts from the ofa_prompts_df.csv and clip_int_prompts.csv files
ofa_prompts_df = pd.read_csv("/kaggle/working/ofa_prompts_df.csv")
clip_int_prompts_df = pd.read_csv('/kaggle/working/clip_int_prompts.csv')

# Select the images from the /kaggle/input/cxr-samples/*.PNG directory
sample_images = glob.glob('/kaggle/input/cxr-samples/*.PNG')

for image_path in sample_images:
    image_name = os.path.basename(image_path)
    # regex=False matches the file name literally ("." would otherwise be a regex wildcard)
    prompt = clip_int_prompts_df[clip_int_prompts_df['Image Path'].str.contains(image_name, regex=False)].iloc[0]['Caption']

    img = Image.open(image_path).convert("RGB")
    ofa_prompt = ofa_prompts_df[ofa_prompts_df['Image Path'].str.contains(image_name, regex=False)].iloc[0]['Caption']

    # Pad the image onto a square canvas (centered, not stretched)
    size = max(img.size)
    new_img = Image.new("RGB", (size, size))
    new_img.paste(img, ((size - img.size[0]) // 2, (size - img.size[1]) // 2))

    plot_image(new_img, ofa_prompt, prompt)
    plt.show()  # Display the plot

As you can see, the prompts from OFA are more detailed.

Create a sample submission

st_model = SentenceTransformer(CFG.sentence_model_path)
embeddings2 = st_model.encode(prompts).flatten()

The SentenceTransformer library encodes the generated prompts into numerical vectors. The encode() method takes a list of sentences and returns a matrix with one row per sentence, and every row has the same dimensionality (384 here, the same EMBEDDING_LENGTH used for the submission). .flatten() then joins the rows into a single 1-D array of length n_images * 384, which is the layout the submission file expects.
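The toy example below shows why .flatten() keeps each prompt's values together. NumPy flattens in row-major (C) order, so all of prompt 0's dimensions come first, then prompt 1's. The 2 x 4 matrix is made-up data standing in for the real n_prompts x 384 output of st_model.encode.

```python
import numpy as np

# Toy stand-in for st_model.encode(prompts): 2 prompts x 4 dimensions.
# The real output would be n_prompts x 384.
emb_matrix = np.array([[0.1, 0.2, 0.3, 0.4],
                       [0.5, 0.6, 0.7, 0.8]])
n_prompts, dim = emb_matrix.shape

# flatten() uses row-major (C) order: all of prompt 0's values, then prompt 1's
flat = emb_matrix.flatten()

# Position i * dim + j holds dimension j of prompt i
i, j = 1, 2
assert flat[i * dim + j] == emb_matrix[i, j]
print(flat.shape)  # (8,)
```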

Release GPU resources

del ci
del blip_model, clip_model
del st_model
gc.collect()  # free unreachable Python objects first so their CUDA memory can be released
torch.cuda.empty_cache()

Ensemble with ViT

Config for Ensemble

The ratios are typically chosen based on the performance of each individual model on a validation set. Models that perform well on the validation set are assigned a higher weight in the ensemble, while models that perform poorly are assigned a lower weight or excluded altogether. By adjusting the ratios, it is possible to balance the contribution of each model in the ensemble and obtain the best possible overall performance.

Since we blend three kinds of models (ViT, CLIP Interrogator, and OFA), we set a ratio for each.

ratio_ViT_B_16 = 0.26  # 0.25
ratio_CLIP_Interrogator = 0.15
ratio_OFA = 0.09

In this case, the three ratios sum to 0.5 (0.26 + 0.15 + 0.09), so the ensemble prediction is a weighted combination of the models' predictions with these weights. The specific values reflect each model's performance and the desired balance of their contributions. They do not need to add up to 1: the competition scores each image's predicted embedding by its cosine similarity to the true prompt embedding, and cosine similarity ignores the overall scale of a vector, so only the relative weights matter.
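A quick numerical check of the scale argument, assuming the cosine-similarity metric described above. The random vectors are placeholders for one image's true prompt embedding and three models' predictions. Rescaling the weights so they sum to 1 multiplies the blended vector by a positive constant, which leaves the cosine similarity unchanged.

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two 1-D vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
target = rng.normal(size=384)               # placeholder for a true prompt embedding
e_a, e_b, e_c = rng.normal(size=(3, 384))   # placeholders for three models' predictions

weights = np.array([0.26, 0.15, 0.09])      # sums to 0.5, as in the config above
pred_half = weights[0] * e_a + weights[1] * e_b + weights[2] * e_c
pred_unit = pred_half / weights.sum()       # same relative weights, rescaled to sum to 1

print(cosine_similarity(pred_half, target))
print(cosine_similarity(pred_unit, target))  # identical up to floating-point rounding
```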

Ensemble OFA and CLIP Interrogator

embeddings12 = ratio_OFA * embeddings1 + ratio_CLIP_Interrogator * embeddings2

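embeddings1 (OFA) and embeddings2 (CLIP Interrogator) are not concatenated. They are combined element-wise as a weighted sum, with ratio_OFA and ratio_CLIP_Interrogator as the weights, so embeddings12 has the same length as each input.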
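The sketch below contrasts the element-wise weighted sum with concatenation, using random placeholder arrays in place of embeddings1 and embeddings2 for two images. The weighted sum keeps one value per (image, dimension) position; concatenation would double the length and no longer line up with the submission rows.

```python
import numpy as np

# Placeholder flattened embeddings for 2 images x 384 dims each,
# standing in for embeddings1 (OFA) and embeddings2 (CLIP Interrogator)
n_images, dim = 2, 384
rng = np.random.default_rng(42)
emb_ofa = rng.normal(size=n_images * dim)
emb_ci = rng.normal(size=n_images * dim)

ratio_ofa, ratio_ci = 0.09, 0.15

# Weighted sum: element-wise, so the length stays n_images * dim
blended = ratio_ofa * emb_ofa + ratio_ci * emb_ci

# Concatenation, by contrast, doubles the length
concatenated = np.concatenate([emb_ofa, emb_ci])

print(blended.shape, concatenated.shape)  # (768,) (1536,)
```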

del embeddings1
del embeddings2
gc.collect()

ViT-B-16

Imports

#import numpy as np
#import pandas as pd
#from pathlib import Path
#from PIL import Image
from tqdm.notebook import tqdm
#import torch
from torch.utils.data import Dataset, DataLoader
#from torchvision import transforms
import timm
from sklearn.preprocessing import normalize

Configuration

class CFG:
    model_path = '/kaggle/input/saved-vit-model/vit_large_patch16_384_1_64_0.0001_0.6564.pth'
    model_name = 'vit_large_patch16_384'
    model_path1 = '/kaggle/input/mystable-diffusion-vit-baseline-train/vit_base_patch16_224.pth'
    model_name1 = 'vit_base_patch16_224'
    input_size1 = 224
    input_size = 384
    batch_size = 64

This new CFG class replaces the earlier one. It sets the checkpoint paths and timm model names for the two ViT models, vit_large_patch16_384 and vit_base_patch16_224, their input sizes (384 and 224), and a batch size shared by both.

Dataset

class DiffusionTestDataset(Dataset):
    def __init__(self, images, transform):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # convert("RGB") guards against RGBA or grayscale PNGs, which would break Normalize
        image = Image.open(self.images[idx]).convert("RGB")
        image = self.transform(image)
        return image

This PyTorch Dataset feeds the test images (generated by Stable Diffusion) to the ViT models. It takes a list of image file paths and a transform, opens each image, applies the transform, and returns the resulting tensor.

Prediction

def predict(
    images,
    model_path,
    model_name,
    input_size,
    batch_size
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.RandomHorizontalFlip(p=0.5),
        #transforms.RandomRotation(degrees=10),

        #transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    dataset = DiffusionTestDataset(images, transform)
    dataloader = DataLoader(
        dataset=dataset,
        shuffle=False,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=2,
        drop_last=False
    )

    # Build the architecture with a 384-dim output head, then load the fine-tuned weights
    model = timm.create_model(
        model_name,
        pretrained=False,
        num_classes=384
    )
    state_dict = torch.load(model_path, map_location=device)  # also works on a CPU-only machine
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    # Two TTA passes; RandomHorizontalFlip re-samples its flips on every pass
    tta_preds = None
    for _ in range(2):
        preds = []
        for X in tqdm(dataloader, leave=False):
            X = X.to(device)

            with torch.no_grad():
                X_out = model(X).cpu().numpy()
                # L2 normalize -- Start
                X_out = X_out / (np.abs(X_out).max(axis=-1, keepdims=True) + 0.0000001)  # To avoid overflow in normalize()
                X_out = normalize(X_out)
                # L2 normalize -- End
                preds.append(X_out)

        if tta_preds is None:
            tta_preds = np.vstack(preds).flatten()
        else:
            tta_preds += np.vstack(preds).flatten()

    return tta_preds / 2

The predict function takes in several arguments:

  • images: a list of image file paths to run inference on.
  • model_path: the path to the saved (fine-tuned) PyTorch weights.
  • model_name: the timm architecture name used to build the model.
  • input_size: the size the images are resized to before inference.
  • batch_size: the batch size to use for inference.

The function first selects the GPU if one is available, otherwise the CPU. It then builds a transform pipeline that resizes each image, randomly flips it horizontally, converts it to a tensor, and normalizes it with the ImageNet mean and standard deviation. The images are wrapped in a PyTorch Dataset and DataLoader.

Next, it builds the architecture with timm.create_model from model_name, using pretrained=False and a 384-unit output head (num_classes=384, one output per dimension of the target sentence embedding), and loads the fine-tuned weights from model_path. The model is then moved to the selected device and set to evaluation mode.

Finally, it applies test-time augmentation (TTA) by running two passes over the dataset. Because the flip is random, each pass sees a different random mix of flipped and unflipped images. In each pass, the model predicts on batches of images and the predictions are stored. The flattened predictions of the two passes are summed, divided by 2, and returned as the final output.
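Inside each pass, predict() also rescales the raw model outputs before storing them: each row is divided by its maximum absolute value and then L2-normalized with sklearn's normalize() (whose default norm is 'l2'). The sketch mirrors those two lines in plain NumPy. The helper name and the float32 test batch are illustrative assumptions. The extreme row shows the point of the first step: squaring 1e30 in float32 would overflow, but after dividing by the max the values are at most 1.

```python
import numpy as np

def scale_then_l2_normalize(x: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Mirror of the two normalization lines in predict():
    1) divide each row by its max absolute value (keeps values in [-1, 1]),
    2) divide each row by its L2 norm, as sklearn's normalize() does by default."""
    x = x / (np.abs(x).max(axis=-1, keepdims=True) + eps)
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

# float32, like the model outputs; the second row would overflow if squared directly
batch = np.array([[3.0, 4.0], [1e30, -1e30]], dtype=np.float32)
out = scale_then_l2_normalize(batch)
print(out)
```

Because L2 normalization divides out any positive scale factor, the first step does not change the final direction of each vector; it only keeps the intermediate numbers finite.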

Test-time augmentation (TTA)

TTA is a technique applied at inference time: the model predicts on several augmented versions of each test image, and the predictions are averaged. This usually gives more robust predictions than a single pass on the unmodified image.

Here, the augmentation is random horizontal flipping. The transform built in predict includes transforms.RandomHorizontalFlip(p=0.5), and DiffusionTestDataset applies it in __getitem__, so each image is flipped with probability 0.5 every time it is loaded. predict runs this pass twice and averages the results.
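Because the flip is random, a given image can end up with the same orientation in both passes. A common deterministic alternative, which is a swapped-in technique and not what this notebook does, runs one pass on the original images and one on their mirror images, then averages. The sketch uses NumPy arrays shaped (N, C, H, W) and a hypothetical toy model. In the real PyTorch loop, the mirrored batch would be torch.flip(X, dims=[-1]).

```python
import numpy as np

def predict_with_flip_tta(model_fn, batch: np.ndarray) -> np.ndarray:
    """Average predictions over the batch and its horizontal mirror.

    batch: array of shape (N, C, H, W); axis -1 is the image width,
    so flipping it mirrors every image left-to-right.
    model_fn: any callable mapping a batch to per-image predictions
    (a hypothetical stand-in for the ViT model).
    """
    preds_original = model_fn(batch)
    preds_flipped = model_fn(np.flip(batch, axis=-1))
    return (preds_original + preds_flipped) / 2

# Hypothetical toy "model": per-channel mean of the LEFT half of each image,
# deliberately not flip-invariant so the averaging has a visible effect
def toy_model(x: np.ndarray) -> np.ndarray:
    half = x.shape[-1] // 2
    return x[..., :half].mean(axis=(-2, -1))

toy_batch = np.arange(2 * 3 * 4 * 4, dtype=np.float64).reshape(2, 3, 4, 4)
tta_out = predict_with_flip_tta(toy_model, toy_batch)
print(tta_out.shape)  # (2, 3)
```

For this toy model, averaging the left-half mean of the original and of the mirror gives the mean of the whole image, which illustrates how flip TTA cancels out a left/right bias.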

images = list(Path('/kaggle/input/stable-diffusion-image-to-prompts/images').glob('*.png'))
imgIds = [i.stem for i in images]
EMBEDDING_LENGTH = 384
imgId_eId = [
    '_'.join(map(str, i)) for i in zip(
        np.repeat(imgIds, EMBEDDING_LENGTH),
        np.tile(range(EMBEDDING_LENGTH), len(imgIds)))]

embeddings3 = predict(images, CFG.model_path, CFG.model_name, CFG.input_size, CFG.batch_size)
embeddings4 = predict(images, CFG.model_path1, CFG.model_name1, CFG.input_size1, CFG.batch_size)

The test image paths and their IDs are collected first. Then both ViT models are run on the same images with their fine-tuned weights: vit_large_patch16_384 (CFG.model_path) produces embeddings3, and vit_base_patch16_224 (CFG.model_path1) produces embeddings4. Each result is a flattened array of n_images * 384 predicted sentence-embedding values, in the same order as imgId_eId. The two arrays are kept separate here and combined with the other models' embeddings in the final ensemble step.

Final Ensemble and Submission

embeddings = embeddings12 + ratio_ViT_B_16 * embeddings3 + embeddings4*0.1

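The final embeddings are a weighted sum of embeddings12 (the blended OFA and CLIP Interrogator embeddings, embeddings1 and embeddings2) and the two ViT predictions: embeddings3 weighted by ratio_ViT_B_16 and embeddings4 weighted by a fixed 0.1. Note that, despite its name, ratio_ViT_B_16 is applied to embeddings3, which comes from vit_large_patch16_384; the ViT-B/16 model (vit_base_patch16_224) produces embeddings4.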

submission = pd.DataFrame(
index=imgId_eId,
data=embeddings,
columns=['val']
).rename_axis('imgId_eId')

submission.to_csv('submission.csv')

Finally, the embeddings are stored in a pandas DataFrame and saved as submission.csv. The index, imgId_eId, has one entry per (image ID, embedding dimension) pair, formatted as the image ID, an underscore, and the dimension index (e.g. 20057f34d_0). The val column holds the corresponding predicted value.
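A toy run of the same imgId_eId construction makes the ordering visible. It uses two image IDs from the test folder and an embedding length of 3 instead of 384 (chosen only to keep the output short). All dimensions of the first image come first, then all of the second, which matches the row-major order of the flattened embeddings.

```python
import numpy as np

# Toy version of the imgId_eId construction: 2 image IDs, 3 dimensions
img_ids = ["20057f34d", "227ef0887"]
embedding_length = 3

ids = [
    "_".join(map(str, pair))
    for pair in zip(
        np.repeat(img_ids, embedding_length),           # each ID repeated 3 times
        np.tile(range(embedding_length), len(img_ids)),  # 0, 1, 2 cycled once per image
    )
]
print(ids)
# ['20057f34d_0', '20057f34d_1', '20057f34d_2', '227ef0887_0', '227ef0887_1', '227ef0887_2']
```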

Thank you for reading!
