模型:
FredZhang7/google-safesearch-mini-v2
Google Safesearch Mini V2在训练过程中采用了不同的方法,使用了InceptionResNetV2架构以及大约340万张从互联网随机获得的图片数据集,其中一些图片是通过数据增强生成的。训练和验证数据来自Google Images、Reddit、Kaggle和Imgur,由公司、Google SafeSearch和版主对其进行了安全或不安全的分类。
在将模型训练了5个轮次并在训练集和验证集上进行了评估以确定预测概率低于0.90的图片后,对策划的数据集进行了必要的修正,并在额外训练了8个轮次。接下来,我对模型进行了各种可能难以分类的情况进行了测试,并观察到它将棕色猫的皮毛误认为人体皮肤。为了提高准确性,我使用了来自Kaggle的15个附加数据集对模型进行了微调,然后在最后一个轮次中使用训练和测试数据的组合进行了训练。这使得训练和验证数据上的准确率达到了97%。
Safesearch过滤器不仅是社交媒体的良好工具,还可以用于过滤数据集。与稳定扩散安全检查器相比,该模型具有重要的优势-用户可以节省1.0 GB的RAM和磁盘空间。
pip install --upgrade torchvision
import torch, os
from torchvision import transforms
from PIL import Image
import urllib.request
import timm
image_path = "https://www.allaboutcats.ca/wp-content/uploads/sites/235/2022/03/shutterstock_320462102-2500-e1647917149997.jpg"
device = "cuda"
def preprocess_image(image_path):
# Define image pre-processing transforms
transform = transforms.Compose([
transforms.Resize(299),
transforms.CenterCrop(299),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
if image_path.startswith('http://') or image_path.startswith('https://'):
import requests
from io import BytesIO
response = requests.get(image_path)
img = Image.open(BytesIO(response.content)).convert('RGB')
else:
img = Image.open(image_path).convert('RGB')
img = transform(img).unsqueeze(0)
img = img.cuda() if device.lower() == "cuda" else img.cpu()
return img
def eval():
model = timm.create_model("hf_hub:FredZhang7/google-safesearch-mini-v2", pretrained=True)
model.to(device)
img = preprocess_image(image_path)
with torch.no_grad():
out = model(img)
_, predicted = torch.max(out.data, 1)
classes = {
0: 'nsfw_gore',
1: 'nsfw_suggestive',
2: 'safe'
}
print('\n\033[1;31m' + classes[predicted.item()] + '\033[0m' if predicted.item() != 2 else '\033[1;32m' + classes[predicted.item()] + '\033[0m\n')
if __name__ == '__main__':
eval()