CLIP模型微调:实现卡通图像与笑话字幕匹配

2024年09月13日 由 alex 发表 60 0

像CLIP这样的多模型模型通过连接图像和易于理解、生成和解析的文本描述打开了新的AI用例。然而,像CLIP这样的现成模型可能不能代表特定领域中通常遇到的数据,这种情况下可能需要通过微调来将模型适应该领域。


本文介绍了如何在《纽约客》杂志的卡通图片和与之对应的笑话字幕上微调CLIP模型。该工作基于https://www.capcon.dev/,这是一个与《纽约客》卡通比赛相关的各种任务的数据集。其中一个任务是从一系列可能的字幕中选择适当的字幕来描述一张卡通图片。让我们看看如何为这个任务微调CLIP。


数据

数据托管并公开可用于gs://datachain-demo/newyorker_caption_contest,其中包含两个部分:

  1. images:一个包含JPEG文件的文件夹,每个文件表示一张卡通图片。
  2. new_yorker_meta.parquet:一个包含有关图片的元数据的Parquet文件,包括图片的多个选项字幕和正确的字幕选择。


为了处理这些数据,我们将使用开源库datachain,该库有助于将这样的非结构化数据转换成更结构化的格式(免责声明:我参与了datachain的开发)。


首先,我们从数据源读取图像和元数据,然后根据文件名(在元数据中作为列名)将它们进行联接:


from datachain import C, DataChain
from datachain.sql.functions import path


img_dc = DataChain.from_storage("gs://datachain-demo/newyorker_caption_contest/images", type="image", anon=True)
meta_dc = DataChain.from_parquet("gs://datachain-demo/newyorker_caption_contest/new_yorker_meta.parquet")
dc = img_dc.mutate(filename=path.name(C("file.path"))).merge(meta_dc, on="filename")


该代码首先从目录中的图像创建了一个名为img_dc的数据集,存储了每个文件的基本信息,我们稍后将使用这些信息来读取图像。然后,它从元数据的parquet文件创建了一个名为meta_dc的数据集。最后,它基于图像文件名将这两个数据集进行了合并。img_dc包含一个名为file.path的列,其中包含文件的完整路径,而img_dc.mutate(filename=path.name(C("file.path")))只提取了该路径的最后一部分,它与meta_dc中的filename列的内容相匹配。合并后的dc数据集同时包含文件信息和每张图像的元数据。


我们可以通过过滤和收集数据来查看数据的示例,就像这样:


sample = dc.filter(C("file.path").endswith("/371.jpeg")).limit(1)filter(C("file.path").endswith("/371.jpeg")).limit(1)
sample_results = list(sample.collect("file", "caption_choices", "label"))


这将数据限制为以/371.jpeg结尾的图像,并只集"file"、"caption_choices"和"label"三列。结果输出包括一个ImageFile(见下文),一个可能的字幕列表,并且有一个标签表示正确字幕的选项字母。由于每个图像有多行不同的字幕选择,你可能会得到稍微不同的结果。


[(ImageFile(source='gs://datachain-demo', path='newyorker_caption_contest/images/371.jpeg', size=25555, version='1719848719616822', etag='CLaWgOCXhocDEAE=', is_latest=True, last_modified=datetime.datetime(2024, 7, 1, 15, 45, 19, 669000, tzinfo=datetime.timezone.utc), location=None, vtype=''),'gs://datachain-demo', path='newyorker_caption_contest/images/371.jpeg', size=25555, version='1719848719616822', etag='CLaWgOCXhocDEAE=', is_latest=True, last_modified=datetime.datetime(2024, 7, 1, 15, 45, 19, 669000, tzinfo=datetime.timezone.utc), location=None, vtype=''),
  ["I feel like we've gotten a little soft, Lex.",
   "Hold on, the Senate Committee on Women's Health is getting out.",
   "I know a specialist, but he's in prison.",
   'Six rounds. Nine lives. You do the math.',
   'Growth has exceeded our projections.'],
  'D')]


我们可以使用ImageFile对象的read()方法获取图像本身。


应用基础CLIP模型

我们可以将CLIP应用于这些数据,预测每个字幕的可能性。这类似于CLIP的基本架构,它使用对比学习来接收一张图像并从一批文本字幕中找出最可能的字幕(反之亦然)。在训练过程中,CLIP将图像-文本配对作为输入,每个图像都映射到其相应的文本字幕。对于每个批次,CLIP计算每个图像与批次中每个文本的余弦相似度,这样它不仅拥有匹配项的相似性,还有每个不匹配的图像-文本配对的相似性(见下图)。然后,将其视为一个分类问题,其中匹配项被视为正确标签,不匹配项被视为错误标签。在推理过程中,通过将图像和一批字幕输入,可以将其用作零样本预测器,CLIP将返回每个字幕的概率。


2


3


对于卡通数据集,我们可以将示例图像和字幕选项输入,以获得每个选项是正确匹配的概率。以下是代码示例:


import clip
import torch


device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
image = example[0].read()
image = preprocess(image).unsqueeze(0).to(device)
text = clip.tokenize(example[1]).to(device)
logits_per_image, logits_per_text = model(image, text)
logits_per_image.softmax(dim=1)[0]


首先,我们将ViT-B/32预训练模型和图像预处理器加载到设备上。然后,我们将图像转换成预期的张量输入,并对文本字幕进行标记化以进行相同的处理。接下来,我们对这些转换后的输入运行模型,以获得图像与每个文本的逻辑相似度分数,并最终通过softmax函数获得每个文本字幕的相对概率。


输出结果显示,CLIP可以自信地预测出这个示例的正确字幕,因为字幕D(第四个字幕)的概率为0.9844(如果你自己尝试,可能在示例中有不同的字幕选项,可能会得到不同的结果):


tensor([0.0047, 0.0013, 0.0029, 0.9844, 0.0067], grad_fn=<SelectBackward0>)0.0047, 0.0013, 0.0029, 0.9844, 0.0067], grad_fn=<SelectBackward0>)


创建训练数据集

现在我们知道如何应用CLIP来预测字幕,我们可以构建一个用于微调模型的训练数据集。让我们为随机选择的10张图片获取相似度分数(你可以增加到更大的规模,但我们在这里将保持较小的规模,以便在笔记本CPU上快速进行跟随)。以下是代码示例:


from datachain.torch import clip_similarity_scores


train_dc = dc.shuffle().limit(10).save("newyorker_caption_contest_train")
train_dc = train_dc.map(
    func=lambda img_file, txt: clip_similarity_scores(img_file.read(), txt, model, preprocess, clip.tokenize, prob=True)[0],
    params=["file", "caption_choices"],
    output={"scores": list[float]}
)


首先,我们从数据集中随机选择并保存10张图像。然后,我们使用map()方法对每条记录应用一个函数,并将结果保存为一个新列。我们使用名为clip_similarity_scores的实用函数,它在一行中执行前面部分的步骤,以获取字幕的概率。map()函数的输入由params=["file", "caption_choices"]定义,输出列由output={"scores": list[float]}定义。


对于训练,我们还需要正确字幕的真实标签,因此我们再次使用map()函数计算每条记录的正确字幕的索引,以及该字幕的CLIP概率,这样我们就可以看到基准CLIP的性能如何:


import string


def label_ind(label):
    return string.ascii_uppercase.index(label)
def label_prob(scores, label_ind):
    return scores[label_ind]
train_dc = (
    train_dc.map(label_ind, params=["label"], output={"label_ind": int})
    .map(label_prob, params=["scores", "label_ind"], output={"label_prob": float})
)
train_dc = train_dc.save()


我们可以运行train_dc.avg("label_prob")来获取训练样本中正确字幕的平均概率。平均值将取决于训练数据集中的随机样本,但你应该会看到比上述示例图像低得多的值,因此似乎对于基准CLIP来说,其他图像不太容易正确预测。


微调

为了微调CLIP,我们需要创建一个train()函数来循环遍历训练数据并更新模型:


def train(loader, model, optimizer, epochs=5):
    if device == "cuda":
        model = model.float()
    loss_func = torch.nn.CrossEntropyLoss()


    for epoch in range(epochs):
        total_loss = 0
        for images, texts, labels in loader:
            optimizer.zero_grad()
            batch_loss = 0
            for image, text, label in zip(images, texts, labels):
                image = image.to(device).unsqueeze(0)
                text = text.to(device)
                label = label.to(device).unsqueeze(0)
                logits_per_image, logits_per_text = model(image, text)
                batch_loss += loss_func(logits_per_image, label)
            batch_loss.backward()
            optimizer.step()
            batch_loss = batch_loss.item()
            total_loss += batch_loss
        print(f"loss for epoch {epoch}: {total_loss}")


对于每个图像和文本字幕的配对,该函数计算logit相似度分数,使用正确的标签索引应用损失函数,并执行反向传播来更新模型。


这与基本的CLIP工作方式非常相似,只有一个区别。基本的CLIP期望每个批次包含图像-文本配对,其中每个图像都有一个对应的文本,CLIP必须从批次中的其他样本中获取错误的文本进行对比学习(见上图)。使用卡通数据集时,每个图像不仅已经有了对应的正确文本字幕,还有多个不正确的文本字幕。因此,与依赖于批次中的其他样本进行对比学习不同,上述函数仅依赖于为该图像提供的文本字幕选择。


为了将训练数据输入到此函数中,我们需要生成一个PyTorch数据集和数据加载器,并将加载器与优化器一起传递给train()函数:


from torch.utils.data import DataLoader


ds = train_dc.select("file", "caption_choices", "label_ind").to_pytorch(
    transform=preprocess,
    tokenizer=clip.tokenize,
)
loader = DataLoader(ds, batch_size=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
train(loader, model, optimizer)


上述代码选择了训练所需的列("file"、"caption_choices"、"label_ind"),然后使用CLIP预处理器和分词器调用to_pytorch(),它返回一个带有预处理的图像张量、标记化的文本和标签索引的PyTorch IterableDataset。接下来,代码创建了一个PyTorch DataLoader和优化器,并将它们传递给train()函数开始训练。


由于我们使用的是一个小数据集,我们可以快速观察到模型适应示例并且损失显著减少:


loss for epoch 0: 5.243085099384018for epoch 0: 5.243085099384018
loss for epoch 1: 6.937912189641793e-05
loss for epoch 2: 0.0006402461804100312
loss for epoch 3: 0.0009484810252615716
loss for epoch 4: 0.00019728825191123178


这应该引起我们对过拟合的警觉,但对于这个练习来说,能够看到train()函数正在按照我们的预期进行学习:从训练数据集中学习正确的字幕。我们可以通过使用微调后的模型计算训练数据中每张图片的正确字幕的预测概率来确认:


train_dc = train_dc.map(map(
    func=lambda img_file, txt: clip_similarity_scores(img_file.read(), txt, model, preprocess, clip.tokenize, prob=True)[0],
    params=["file", "caption_choices"],
    output={"scores_fine_tune": list[float]}
)


train_dc = train_dc.map(label_prob, params=["scores_fine_tune", "label_ind"], output={"label_prob_fine_tune": float})


上述代码与之前用于在微调之前计算概率的代码相同。运行train_dc.avg("label_prob_fine_tune")将输出一个平均预测概率大于0.99的值,看起来微调工作如预期所料。


这是一个人为构建的示例,但希望能够让你了解如何对CLIP进行微调。要以更稳健的方式解决预测正确字幕的任务,你需要选择一个更大的样本,并针对未在训练过程中看到的图像和文本的保留样本进行评估。在尝试时,你可能会发现CLIP在泛化到字幕预测问题上表现不佳,这并不太令人意外,因为CLIP是为了理解图像的内容而不是理解笑话而构建的。CLIP依赖于相对简单的文本编码器,可能值得尝试不同的文本编码器来解决该任务。这超出了微调和本文的范围,但现在你已经知道如何训练CLIP,你可以尝试这个想法,或者根据自己的想法将CLIP适应于多模态的用例中。

文章来源:https://medium.com/@dave_101/you-do-the-math-fine-tuning-multimodal-models-clip-to-match-cartoon-images-to-joke-captions-2f03393e0b80
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
写评论取消
回复取消