MNIST机器学习入门

作者:杨润炜
日期:2016/10/16 21:30

这段时间一直在了解机器学习方面的知识,因为涉及高等数学(主要是微积分)、概率论、统计学、线性代数等数学知识,不得不回归到大学的课本上,这是一个周期很长的学习过程。经过这段时间的沉淀,总算能看懂一些机器学习方面的理论知识,便在极客学院的机器学习教程进行机器学习方面的研究,主要是利用tensorflow这个工具来开发机器学习的应用。下面是我对入门教程的一些源码,算是做个笔记,也可以分享给和我一样正在入门的小伙伴。
这是根据教程写下来的代码,其中还包括了模型的保存,方便接下来做模型的实际应用。

minist.py
  1. import tensorflow as tf
  2. from tensorflow.examples.tutorials.mnist import input_data
  3. mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
  4. print("Download Done!")
  5. x = tf.placeholder(tf.float32, [None, 784])
  6. # paras
  7. W = tf.Variable(tf.zeros([784, 10]))
  8. b = tf.Variable(tf.zeros([10]))
  9. y = tf.nn.softmax(tf.matmul(x, W) + b)
  10. y_ = tf.placeholder(tf.float32, [None, 10])
  11. # loss func
  12. cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
  13. train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
  14. # init
  15. init = tf.initialize_all_variables()
  16. sess = tf.Session()
  17. sess.run(init)
  18. # train
  19. for i in range(1000):
  20. batch_xs, batch_ys = mnist.train.next_batch(100)
  21. sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
  22. correct_prediction = tf.equal(tf.arg_max(y, 1), tf.arg_max(y_, 1))
  23. accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
  24. print("Accuarcy on Test-dataset: ", sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
  25. # save model
  26. saver = tf.train.Saver()
  27. save_path = saver.save(sess, "./model/minist_softmax.ckpt")
  28. print("Model saved in file: ", save_path)

运行上述代码便可得到模型。由于需要下载数据集,所以第一次运行会比较慢,需要耐心等待。

下面是应用模型对实际手写数字图片做识别的源码。你需要用ps做一些黑底白字的手写数字图片资源。

test_minist.py
  1. # -*- coding: UTF-8 -*-
  2. from PIL import Image
  3. from numpy import *
  4. import tensorflow as tf
  5. import sys
  6. if len(sys.argv) < 2 :
  7. print('argv must at least 2. you give '+str(len(sys.argv)))
  8. sys.exit()
  9. filename = sys.argv[1]
  10. im=Image.open(filename)
  11. img = array(im.resize((28, 28), Image.ANTIALIAS).convert("L"))
  12. # data = transpose(img.ravel())
  13. data = img.reshape([1, 784])
  14. # print(data)
  15. # xData = tf.Variable(data, name="x")
  16. x = tf.placeholder(tf.float32, [None, 784])
  17. W = tf.Variable(tf.zeros([784, 10]))
  18. b = tf.Variable(tf.zeros([10]))
  19. y = tf.nn.softmax(tf.matmul(x, W) + b)
  20. # y = tf.add(b, tf.matmul(x, W))
  21. saver = tf.train.Saver()
  22. init_op = tf.initialize_all_variables()
  23. with tf.Session() as sess:
  24. sess.run(init_op)
  25. save_path = "./model/minist_softmax.ckpt"
  26. saver.restore(sess, save_path)
  27. # print("Model restored.")
  28. predictions = sess.run(y, feed_dict={x: data})
  29. print(predictions[0]);
  30. # print(tf.arg_max(predictions[0], 1))

注:上述代码需要两个python库,Pillow和numpy,执行如下命令来安装。

  1. pip install Pillow numpy

接下来可以通过命令指定需要测试的图片,如:

  1. python test_minist.py ./img/test_1.jpg

结果是一个数组,值为1的索引就代表识别出来的数字了。

感谢您的阅读!
如果看完后有任何疑问,欢迎拍砖。
欢迎转载,转载请注明出处:http://www.yangrunwei.com/a/76.html
邮箱:glowrypauky@gmail.com
QQ: 892413924

上一篇:智能时代
下一篇:Git基本架构