预训练图像分类模型预测
此为 datawhale 的公开教程
教程地址:github
1. 调用 pytorch 中 model 加载模型
model = models.resnet18(pretrained=True)
model = model.eval()
model = model.to(device)
Notes:
- model.eval() 通常在对模型进行验证时需要设置的。此设置的目的是:在模型中有BatchNormal以及Dropout层时,取消BN和Dropout层的效果,以达到所有数据都进行测试的效果。
- BatchNormal 以及 Dropout 均为正则化手段,一定程度上可以处理过拟合的情况
- model.to(device) 是将模型转移到指定的设备上
2. 图像预处理以及测试图片
from torchvision import transforms
from PIL import Image
test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
img_pil = Image.open(img_path)
Notes:
- tansformams 可以对图像进行一系列处理包括,设置图片大小,将图片对象转为 Tensor 对象等
- 使用 Image 打开图片
3. 调用摄像头获取图像(视频)
import cv2
import time# 获取摄像头,传入0表示获取系统默认摄像头
cap = cv2.VideoCapture(1)# 打开cap
cap.open(0)# 无限循环,直到break被触发
while cap.isOpened():# 获取画面success, frame = cap.read()if not success:print('Error')break## !!!处理帧函数frame = process_frame(frame)# 展示处理后的三通道图像cv2.imshow('my_window',frame)if cv2.waitKey(1) in [ord('q'),27]: # 按键盘上的q或esc退出(在英文输入法下)break# 关闭摄像头
cap.release()# 关闭图像窗口
cv2.destroyAllWindows()
Notes:
- 使用 cv2 库调用摄像头