轻松学Pytorch-使用ResNet50实现图像分类

栏目: IT技术 · 发布时间: 3年前

Hello大家好,这篇文章给大家详细介绍一下pytorch中最重要的组件torchvision,它包含了常见的数据集、模型架构与预训练模型权重文件、常见图像变换、计算机视觉任务训练。可以是说是pytorch中非常有用的模型迁移学习神器。本文将会介绍如何使用torchvison的预训练模型ResNet50实现图像分类。

模型

Torchvision.models包里面包含了常见的各种基础模型架构,主要包括:

AlexNet VGG ResNet SqueezeNet DenseNet Inception v3 GoogLeNet ShuffleNet v2 MobileNet v2 ResNeXt Wide ResNet MNASNet

这里我选择了ResNet50,基于ImageNet训练的基础网络来实现图像分类, 网络模型下载与加载如下:

model = torchvision.models.resnet50(pretrained=True).eval().cuda()

tf = transforms.Compose([

transforms.Resize(256),

transforms.CenterCrop(224),

transforms.ToTensor(),

transforms.Normalize(

mean=[0.485, 0.456, 0.406],

std=[0.229, 0.224, 0.225]

)])

使用模型实现图像分类

这里首先需要加载ImageNet的分类标签,目的是最后显示分类的文本标签时候使用。然后对输入图像完成预处理,使用ResNet50模型实现分类预测,对预测结果解析之后,显示标签文本,完整的代码演示如下:

 1with open('imagenet_classes.txt') as f:
2 labels = [line.strip() for line in f.readlines()]
3
4src = cv.imread("D:/images/space_shuttle.jpg") # aeroplane.jpg
5image = cv.resize(src, (224, 224))
6image = np.float32(image) / 255.0
7image[:,:,] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406))
8image[:,:,] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225))
9image = image.transpose((2, 0, 1))
10input_x = torch.from_numpy(image).unsqueeze(0)
11print(input_x.size())
12pred = model(input_x.cuda())
13pred_index = torch.argmax(pred, 1).cpu().detach().numpy()
14print(pred_index)
15print("current predict class name : %s"%labels[pred_index[0]])
16cv.putText(src, labels[pred_index[0]], (50, 50), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
17cv.imshow("input", src)
18cv.waitKey(0)
19cv.destroyAllWindows()

运行结果如下:

轻松学Pytorch-使用ResNet50实现图像分类

转ONNX支持

在torchvision中的模型基本上都可以转换为ONNX格式,而且被OpenCV DNN模块所支持,所以,很方便的可以对torchvision自带的模型转为ONNX,实现OpenCV DNN的调用,首先转为ONNX模型,直接使用torch.onnx.export即可转换(还不知道怎么转,快点看前面的例子)。转换之后使用OpenCV DNN调用的代码如下:

 1with open('imagenet_classes.txt') as f:
2 labels = [line.strip() for line in f.readlines()]
3net = cv.dnn.readNetFromONNX("resnet.onnx")
4src = cv.imread("D:/images/messi.jpg") # aeroplane.jpg
5image = cv.resize(src, (224, 224))
6image = np.float32(image) / 255.0
7image[:, :, ] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406))
8image[:, :, ] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225))
9blob = cv.dnn.blobFromImage(image, 1.0, (224, 224), (0, 0, 0), False)
10net.setInput(blob)
11probs = net.forward()
12index = np.argmax(probs)
13cv.putText(src, labels[index], (50, 50), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
14cv.imshow("input", src)
15cv.waitKey(0)
16cv.destroyAllWindows()

运行结果见上图,这里就不再贴了。

✄------------------------------------------------

看到这里,说明你喜欢这篇文章,请点击「 在看 」或顺手「 转发 」「 点赞 」。

欢迎微信搜索「 panchuangxx 」,添加小编 磐小小仙 微信,每日朋友圈更新一篇高质量推文(无广告),为您提供更多精彩内容。

▼       扫描二维码添加小编   ▼    ▼  

轻松学Pytorch-使用ResNet50实现图像分类


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

大思维:集体智慧如何改变我们的世界

大思维:集体智慧如何改变我们的世界

杰夫·摩根 / 郭莉玲、尹玮琦、徐强 / 中信出版集团股份有限公司 / 2018-8-1 / CNY 65.00

智能时代,我们如何与机器互联,利用技术来让我们变得更聪明?为什么智能技术不会自动导致智能结果呢?线上线下群体如何协作?社会、政府或管理系统如何解决复杂的问题?本书从哲学、计算机科学和生物学等领域收集见解,揭示了如何引导组织和社会充分利用人脑和数字技术进行大规模思考,从而提高整个集体的智力水平,以解决我们时代的巨大挑战。是英国社会创新之父的洞见之作,解析企业、群体、社会如何明智决策、协作进化。一起来看看 《大思维:集体智慧如何改变我们的世界》 这本书的介绍吧!

URL 编码/解码
URL 编码/解码

URL 编码/解码

SHA 加密
SHA 加密

SHA 加密工具

HEX CMYK 转换工具
HEX CMYK 转换工具

HEX CMYK 互转工具