英文

全面注意力网络(Holistic Attention Network,HAN)

HAN模型在DIV2K上进行了预训练(800个图像训练,增强到4000个图像,100个图像验证),用于2x、3x和4x图像超分辨率。该模型由Niu等人在2020年的论文 Single Image Super-Resolution via a Holistic Attention Network 中提出,并于 this repository 首次发布。

图像超分辨率的目标是从单个低分辨率(LR)图像恢复高分辨率(HR)图像。下图显示了真实图像(HR)、双三次上采样和模型上采样。

模型描述

信息特征在单图像超分辨率任务中起着关键作用。通道注意力已被证明对于保留每层中的信息丰富特征是有效的。然而,通道注意力将每个卷积层视为单独的处理过程,忽略了不同层之间的相关性。为了解决这个问题,我们提出了一种新的全面注意力网络(Holistic Attention Network,HAN),它包括层注意力模块(Layer Attention Module,LAM)和通道-空间注意力模块(Channel-Spatial Attention Module,CSAM),以建模层、通道和位置之间的整体相互依赖关系。具体而言,所提出的LAM通过考虑层之间的相关性自适应地强调层次特征。同时,CSAM学习每个通道的所有位置的置信度,以选择性地捕获更多信息丰富的特征。大量实验证明,所提出的HAN与最先进的单图像超分辨率方法相比表现优越。

使用目的和限制

您可以使用预训练模型将图像放大2x、3x和4x。您还可以使用训练器在自己的数据集上训练模型。

如何使用

您可以使用 super_image 库使用该模型:

pip install super-image

以下是如何使用预训练模型放大图像的示例:

from super_image import HanModel, ImageLoader
from PIL import Image
import requests

url = 'https://paperswithcode.com/media/datasets/Set5-0000002728-07a9793f_zA3bDjj.jpg'
image = Image.open(requests.get(url, stream=True).raw)

model = HanModel.from_pretrained('eugenesiow/han', scale=2)      # scale 2, 3 and 4 models available
inputs = ImageLoader.load_image(image)
preds = model(inputs)

ImageLoader.save_image(preds, './scaled_2x.png')                        # save the output 2x scaled image to `./scaled_2x.png`
ImageLoader.save_compare(inputs, preds, './scaled_2x_compare.png')      # save an output comparing the super-image with a bicubic scaling

训练数据

2x、3x和4x图像超分辨率的模型是在 DIV2K 上进行预训练的,该数据集包含800张高质量(2K分辨率)的训练图像,通过增强生成了4000张图像,并使用了100张验证图像(图像编号为801到900)。

训练过程

预处理

我们遵循 Wang et al. 的预处理和训练方法。通过使用双三次插值作为调整大小方法,将高分辨率(HR)图像的大小缩小2倍、3倍和4倍,创建低分辨率(LR)图像。在训练过程中,使用LR输入的大小为64×64的RGB块以及它们对应的HR块。在预处理阶段,对训练集进行数据增强,从原始图像的四个角和中心创建五个图像。

我们需要使用huggingface datasets 库来下载数据:

pip install datasets

以下代码获取数据并预处理/增强数据。

from datasets import load_dataset
from super_image.data import EvalDataset, TrainDataset, augment_five_crop

augmented_dataset = load_dataset('eugenesiow/Div2k', 'bicubic_x4', split='train')\
    .map(augment_five_crop, batched=True, desc="Augmenting Dataset")                                # download and augment the data with the five_crop method
train_dataset = TrainDataset(augmented_dataset)                                                     # prepare the train dataset for loading PyTorch DataLoader
eval_dataset = EvalDataset(load_dataset('eugenesiow/Div2k', 'bicubic_x4', split='validation'))      # prepare the eval dataset for the PyTorch DataLoader

预训练

该模型是在GPU上训练的。下面提供了训练代码:

from super_image import Trainer, TrainingArguments, HanModel, HanConfig

training_args = TrainingArguments(
    output_dir='./results',                 # output directory
    num_train_epochs=1000,                  # total number of training epochs
)

config = HanConfig(
    scale=4,                                # train a model to upscale 4x
)
model = HanModel(config)

trainer = Trainer(
    model=model,                         # the instantiated model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=eval_dataset            # evaluation dataset
)

trainer.train()

评估结果

评估指标包括 PSNR SSIM

评估数据集包括:

以下结果列以PSNR/SSIM表示,与双三次基准进行比较:

|数据集 |比例 |双三次 |HAN ||--- |--- |--- |--- ||Set5 |2x |33.64/0.9292 |**** ||Set5 |3x |30.39/0.8678 |**** ||Set5 |4x |28.42/0.8101 | 31.21/0.8778 ||Set14 |2x |30.22/0.8683 |**** ||Set14 |3x |27.53/0.7737 |**** ||Set14 |4x |25.99/0.7023 | 28.18/0.7712 ||BSD100 |2x |29.55/0.8425 |**** ||BSD100 |3x |27.20/0.7382 |**** ||BSD100 |4x |25.96/0.6672 | 28.09/0.7533 ||Urban100 |2x |26.66/0.8408 |**** ||Urban100 |3x | |**** ||Urban100 |4x |23.14/0.6573 | 25.1/0.7497 |

您可以在预训练模型上找到一个笔记本,以便轻松进行评估:

BibTeX条目和引文信息

@misc{niu2020single,
      title={Single Image Super-Resolution via a Holistic Attention Network}, 
      author={Ben Niu and Weilei Wen and Wenqi Ren and Xiangde Zhang and Lianping Yang and Shuzhen Wang and Kaihao Zhang and Xiaochun Cao and Haifeng Shen},
      year={2020},
      eprint={2008.08767},
      archivePrefix={arXiv},
      primaryClass={eess.IV}
}