深度学习Matlab工具箱代码详解概览

最近研究了几天深度学习的Matlab工具箱代码,发现作者给出的源码中注释实在是少得可怜,为了方便大家阅读,特对代码进行了注释,与大家分享。

  在阅读Matlab工具箱代码之前,建议大家阅读几篇CNN方面的两篇经典材料.

(1)《Notes on Convolutional Neural Networks》,这篇文章是与Matlab工具箱代码配套的文献,不过文献中在下采样层也有两种训练参数,在工具箱中的下采样层并没有可训练参数,直接进行下采样操作。

(2)《CNN学习-薛开宇》,这是与《Notes on Convolutional Neural Networks》内容及其相似的一份中文PPT资料,对卷积神经网络的介绍也是通俗易懂。

(3)深度学习的Matlab工具箱Github下载地址:https://github.com/rasmusbergpalm/DeepLearnToolbox

接下来给出一个工具箱中CNN程序在Mnist数据库上的示例程序:

  1. %%=========================================================================

  2. % 主要功能:在mnist数据库上做实验,验证工具箱的有效性

  3. % 算法流程:1)载入训练样本和测试样本

  4. % 2)设置CNN参数,并进行训练

  5. % 3)进行检测cnntest()

  6. % 注意事项:1)由于直接将所有测试样本输入会导致内存溢出,故采用一次只测试一个训练样本的测试方法

  7. %%=========================================================================

  8. %%

  9. %%%%%%%%%%%%%%%%%%%%加载数据集%%%%%%%%%%%%%%%%%%%%

  10. load mnist_uint8;

  11. train_x = double(reshape(train_x',28,28,60000))/255;

  12. test_x = double(reshape(test_x',28,28,10000))/255;

  13. train_y = double(train_y');

  14. test_y = double(test_y');

  15.  

  16. %%

  17. %%=========================================================================

  18. %%%%%%%%%%%%%%%%%%%%设置卷积神经网络参数%%%%%%%%%%%%%%%%%%%%

  19. % 主要功能:训练一个6c-2s-12c-2s形式的卷积神经网络,预期性能如下:

  20. % 1)迭代一次需要200秒左右,错误率大约为11%

  21. % 2)迭代一百次后错误率大约为1.2%

  22. % 算法流程:1)构建神经网络并进行训练,以CNN结构体的形式保存

  23. % 2)用已知的训练样本进行测试

  24. % 注意事项:1)之前在测试的时候提示内存溢出,后来莫名其妙的又不溢出了,估计到了系统的内存临界值

  25. %%=========================================================================

  26. rand('state',0)

  27. cnn.layers = {

  28. struct('type', 'i') %输入层

  29. struct('type', 'c', 'outputmaps', 6, 'kernelsize', 5) %卷积层

  30. struct('type', 's', 'scale', 2) %下采样层

  31. struct('type', 'c', 'outputmaps', 12, 'kernelsize', 5) %卷积层

  32. struct('type', 's', 'scale', 2) %下采样层

  33. };

  34. cnn = cnnsetup(cnn, train_x, train_y);

  35. opts.alpha = 1;

  36. opts.batchsize = 50;

  37. opts.numepochs = 5;

  38. cnn = cnntrain(cnn, train_x, train_y, opts);

  39. save CNN_5 cnn;

  40.  

  41. load CNN_5;

  42. [er, bad] = cnntest(cnn, test_x, test_y);

  43. figure; plot(cnn.rL);

  44. assert(er<0.12, 'Too big error');

免责声明:信息仅供参考,不构成投资及交易建议。投资者据此操作,风险自担。
如果觉得文章对你有用,请随意赞赏收藏
相关推荐
相关下载
登录后评论
Copyright © 2019 宽客在线