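PyTorch Installation and a LeNet Example
October 10 · Submitted by 望北海
(1) Installing PyTorch
1. Install PyTorch. The official installation page is https://pytorch.org/get-started/locally/. On that page, choose your options (for example the CUDA version, and whether to install with pip). Then run the command the site generates in a terminal.
2. Install torchvision, following the PyTorch website: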
condainstalltorchvision
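To confirm the installation worked, a quick check like the one below can help. It is a minimal sketch that assumes only that `torch` imports successfully. It prints the installed version and whether PyTorch can see a CUDA GPU. `torchvision` exposes the same `__version__` attribute if you want to check it too.

```python
import torch

# Print the installed version to confirm the package imports correctly
print("torch:", torch.__version__)

# True only if a CUDA-enabled build is installed and a GPU is visible
print("CUDA available:", torch.cuda.is_available())
```

If the last line prints `False` on a machine with an NVIDIA GPU, the CPU-only build was probably installed. Re-run the command from the install page with a CUDA option selected.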
3. Look up what each PyTorch submodule does in the official documentation:
https://pytorch.org/docs/stable/index.html
(2) PyTorch demo
1. Dimension order of a PyTorch Tensor: [batch, channel, height, width]
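Image libraries such as PIL and NumPy usually store an image as (height, width, channel). The sketch below uses a made-up all-zero 32x32 RGB array. It shows how to reorder that array into PyTorch's (channel, height, width) layout and then add the batch dimension.

```python
import numpy as np
import torch

# A made-up 32x32 RGB image in the usual image-library layout (H, W, C)
hwc = np.zeros((32, 32, 3), dtype=np.float32)

# Reorder the axes to PyTorch's (C, H, W) layout
chw = torch.from_numpy(hwc).permute(2, 0, 1)

# Add a leading batch dimension -> (N, C, H, W)
nchw = chw.unsqueeze(0)

print(chw.shape)   # torch.Size([3, 32, 32])
print(nchw.shape)  # torch.Size([1, 3, 32, 32])
```

In the LeNet example, `transforms.ToTensor()` performs the (H, W, C) to (C, H, W) step. `torch.unsqueeze` adds the batch dimension in predict.py.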
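(3) Example: building the LeNet network
Code walkthrough
File structure of the LeNet project
The project has three files: model.py defines the network model, train.py trains it, and predict.py runs prediction on a new image.
Network model file: model.py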
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)    # 3: depth of the input feature map, 16: number of kernels (filters), kernel size 5x5
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))    # input(3, 32, 32); N = (W - F + 2P) / S + 1 = (32 - 5 + 0) / 1 + 1 = 28 -> output(16, 28, 28)
        x = self.pool1(x)            # output(16, 14, 14)
        x = F.relu(self.conv2(x))    # output(32, 10, 10)
        x = self.pool2(x)            # output(32, 5, 5)
        x = x.view(-1, 32 * 5 * 5)   # flatten -> output(32*5*5)
        x = F.relu(self.fc1(x))      # output(120)
        x = F.relu(self.fc2(x))      # output(84)
        x = self.fc3(x)              # output(10)
        return x
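The comment on `conv1` applies the convolution output-size formula N = (W - F + 2P) / S + 1. The helper below is a rough sketch; `conv_output_size` is a hypothetical name, not part of PyTorch. It applies the formula layer by layer. A 2x2 max pool with stride 2 follows the same formula with F = 2, S = 2 and no padding. Tracing the sizes this way shows why `fc1` expects 32 * 5 * 5 = 800 inputs.

```python
def conv_output_size(w, f, p=0, s=1):
    """Output width N = (W - F + 2P) / S + 1 for a square input of width w (floor division)."""
    return (w - f + 2 * p) // s + 1


# Trace LeNet's spatial size for a 32x32 input
w = 32
w = conv_output_size(w, 5)        # conv1: 5x5 kernel, stride 1 -> 28
w = conv_output_size(w, 2, s=2)   # pool1: 2x2 window, stride 2 -> 14
w = conv_output_size(w, 5)        # conv2: 5x5 kernel, stride 1 -> 10
w = conv_output_size(w, 2, s=2)   # pool2: 2x2 window, stride 2 -> 5

print(w, 32 * w * w)  # 5 800
```

If you change the input resolution or a kernel size, re-running this trace tells you the new `in_features` value for `fc1`.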
Training file: train.py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms


def main():
    transform = transforms.Compose(
        [transforms.ToTensor(),  # converts a PIL Image or numpy.ndarray to a tensor, i.e. (H, W, C) -> (C, H, W), scaling pixel values to [0, 1]
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # normalizes: input[channel] = (input[channel] - mean[channel]) / std[channel]

    # with download=True, the CIFAR-10 dataset is downloaded automatically into the local ./data folder
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=False, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=36,
                                              shuffle=True, num_workers=0)

    valset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                          download=False, transform=transform)
    val_loader = torch.utils.data.DataLoader(valset, batch_size=5000,
                                             shuffle=False, num_workers=0)
    val_data_iter = iter(val_loader)
    val_image, val_label = next(val_data_iter)

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # display a single image
    def imshow(img):
        img = img / 2 + 0.5  # undo the normalization
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)))  # (C, H, W) -> (H, W, C) for matplotlib

    # print labels
    print(' '.join('%5s' % classes[val_label[j]] for j in range(4)))
    # show images
    fig = plt.figure(figsize=(7, 5))  # set up the canvas
    for idx in range(4):
        ax = fig.add_subplot(2, 2, idx + 1, xticks=[], yticks=[])
        imshow(val_image[idx])  # show the idx-th image
        ax.set_title(classes[val_label[idx]])  # use the class name as the image title
    plt.show()

    net = LeNet()
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    for epoch in range(5):  # loop over the dataset multiple times
        running_loss = 0.0
        for step, data in enumerate(trainloader, start=0):
            # data is a list of [inputs, labels]
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if step % 500 == 499:  # print every 500 mini-batches
                with torch.no_grad():
                    outputs = net(val_image)  # [batch, 10]
                    predict_y = torch.max(outputs, dim=1)[1]
                    accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)

                    print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                          (epoch + 1, step + 1, running_loss / 500, accuracy))
                    running_loss = 0.0

    print('Finished Training')

    save_path = './Lenet.pth'
    torch.save(net.state_dict(), save_path)


if __name__ == '__main__':
    main()
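`Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` and the `img / 2 + 0.5` line in `imshow` are inverses of each other. The sketch below checks this on a few made-up pixel values in the [0, 1] range that `ToTensor()` produces. It does the arithmetic by hand with plain tensors rather than calling the transform itself.

```python
import torch

mean, std = 0.5, 0.5

# Made-up pixel values in [0, 1], as produced by ToTensor()
x = torch.tensor([0.0, 0.25, 0.5, 1.0])

# What Normalize computes per channel: (x - mean) / std, mapping [0, 1] to [-1, 1]
normalized = (x - mean) / std

# The inverse used in imshow(): img / 2 + 0.5 maps [-1, 1] back to [0, 1]
restored = normalized / 2 + 0.5

print(normalized.tolist())          # [-1.0, -0.5, 0.0, 1.0]
print(torch.allclose(restored, x))  # True
```

matplotlib expects float images in [0, 1]. Without this inverse step, the normalized values in [-1, 1] would be clipped and the displayed images would look wrong.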
Prediction file: predict.py
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet


def main():
    transform = transforms.Compose(
        [transforms.Resize((32, 32)),  # resize the image to 32x32
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LeNet()
    net.load_state_dict(torch.load('Lenet.pth'))

    im = Image.open('img.png').convert('RGB')  # PNGs may carry an alpha channel; the network expects 3 channels
    im = transform(im)               # [C, H, W]
    im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]

    with torch.no_grad():
        outputs = net(im)
        predict = torch.max(outputs, dim=1)[1].item()  # index of the highest score
    print(classes[predict])


if __name__ == '__main__':
    main()
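predict.py prints only the top class. If you also want a confidence score, one option is to apply a softmax over the 10 outputs to turn them into probabilities. This is a sketch, not part of the original script. The logits below are made up and stand in for `outputs = net(im)`.

```python
import torch

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Made-up logits with shape [N=1, 10], standing in for outputs = net(im)
outputs = torch.tensor([[0.1, 0.2, 0.3, 2.5, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]])

with torch.no_grad():
    probs = torch.softmax(outputs, dim=1)  # each row now sums to 1
    conf, idx = torch.max(probs, dim=1)    # highest probability and its class index

print(classes[idx.item()], round(conf.item(), 3))
```

Softmax preserves the ordering of the logits, so the predicted class is the same as in predict.py. Only the extra probability value is new.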