您正在使用IE低版浏览器,为了您的雷峰网账号安全和更好的产品体验,强烈建议使用更快更安全的浏览器
此为临时链接,仅用于文章预览,将在时失效
人工智能 正文
发私信给AI研习社
发送

1

TensorFlow极速入门

本文作者: AI研习社 2017-02-11 18:25
导语:目前,深度学习已经广泛应用于各个领域,很多童鞋想要一探究竟,这里抛砖引玉的介绍下最火的深度学习开源框架tensorflow。

雷锋网按:本文原载于Qunar技术沙龙,原作者已授权雷锋网发布。作者孟晓龙,2016年加入Qunar,目前在去哪儿网机票事业部担任算法工程师。热衷于深度学习技术的探索,对新事物有着强烈的好奇心。

一、前言

目前,深度学习已经广泛应用于各个领域,比如图像识别,图形定位与检测,语音识别,机器翻译等等,对于这个神奇的领域,很多童鞋想要一探究竟,这里抛砖引玉的简单介绍下最火的深度学习开源框架 tensorflow。本教程不是 cookbook,所以不会将所有的东西都事无巨细的讲到,所有的示例都将使用 python。

那么本篇教程会讲到什么?首先是一些基础概念,包括计算图,graph 与 session,基础数据结构,Variable,placeholder 与 feed_dict 以及使用它们时需要注意的点。最后给出了在 tensorflow 中建立一个机器学习模型步骤,并用一个手写数字识别的例子进行演示。

1、tensorflow是什么?

tensorflow 是 google 开源的机器学习工具,在2015年11月其实现正式开源,开源协议Apache 2.0。

下图是 query 词频时序图,从中可以看出 tensorflow 的火爆程度。

TensorFlow极速入门

2、 why tensorflow?

Tensorflow 拥有易用的 python 接口,而且可以部署在一台或多台 cpu , gpu 上,兼容多个平台,包括但不限于 安卓/windows/linux 等等平台上,而且拥有 tensorboard这种可视化工具,可以使用 checkpoint 进行实验管理,得益于图计算,它可以进行自动微分计算,拥有庞大的社区,而且很多优秀的项目已经使用 tensorflow 进行开发了。

3、 易用的tensorflow工具

如果不想去研究 tensorflow 繁杂的API,仅想快速的实现些什么,可以使用其他高层工具。比如 tf.contrib.learn,tf.contrib.slim,Keras 等,它们都提供了高层封装。这里是 tflearn 的样例集(github链接  https://github.com/tflearn/tflearn/tree/master/examples)。

4、 tensorflow安装

目前 tensorflow 的安装已经十分方便,有兴趣可以参考官方文档 (https://www.tensorflow.org/get_started/os_setup)。

二、 tensorflow基础

实际上编写tensorflow可以总结为两步.

(1)组装一个graph;

(2)使用session去执行graph中的operation。

因此我们从 graph 与 session 说起。

1、 graph与session

(1)计算图

Tensorflow 是基于计算图的框架,因此理解 graph 与 session 显得尤为重要。不过在讲解 graph 与 session 之前首先介绍下什么是计算图。假设我们有这样一个需要计算的表达式。该表达式包括了两个加法与一个乘法,为了更好讲述引入中间变量c与d。由此该表达式可以表示为:

TensorFlow极速入门

当需要计算e时就需要计算c与d,而计算c就需要计算a与b,计算d需要计算b。这样就形成了依赖关系。这种有向无环图就叫做计算图,因为对于图中的每一个节点其微分都很容易得出,因此应用链式法则求得一个复杂的表达式的导数就成为可能,所以它会应用在类似tensorflow这种需要应用反向传播算法的框架中。

(2)概念说明

下面是 graph , session , operation , tensor 四个概念的简介。

Tensor:类型化的多维数组,图的边;

Operation:执行计算的单元,图的节点;

Graph:一张有边与点的图,其表示了需要进行计算的任务;

Session:称之为会话的上下文,用于执行图。

Graph仅仅定义了所有 operation 与 tensor 流向,没有进行任何计算。而session根据 graph 的定义分配资源,计算 operation,得出结果。既然是图就会有点与边,在图计算中 operation 就是点而 tensor 就是边。Operation 可以是加减乘除等数学运算,也可以是各种各样的优化算法。每个 operation 都会有零个或多个输入,零个或多个输出。 tensor 就是其输入与输出,其可以表示一维二维多维向量或者常量。而且除了Variables指向的 tensor 外所有的 tensor 在流入下一个节点后都不再保存。

(3)举例

下面首先定义一个图(其实没有必要,tensorflow会默认定义一个),并做一些计算。

import  tensorflow as tf

graph  = tf.Graph()

with  graph.as_default():

    foo = tf.Variable(3,name='foo')

    bar = tf.Variable(2,name='bar')

    result = foo + bar

    initialize =  tf.global_variables_initializer()

这段代码,首先会载入tensorflow,定义一个graph类,并在这张图上定义了foo与bar的两个变量,最后对这个值求和,并初始化所有变量。其中,Variable是定义变量并赋予初值。让我们看下result(下方代码)。后面是输出,可以看到并没有输出实际的结果,由此可见在定义图的时候其实没有进行任何实际的计算。

print(result)  #Tensor("add:0", shape=(), dtype=int32)

下面定义一个session,并进行真正的计算。

with  tf.Session(graph=graph) as sess:

    sess.run(initialize)

    res = sess.run(result)

   print(res)  # 5

这段代码中,定义了session,并在session中执行了真正的初始化,并且求得result的值并打印出来。可以看到,在session中产生了真正的计算,得出值为5。

下图是该graph在tensorboard中的显示。这张图整体是一个graph,其中foo,bar,add这些节点都是operation,而foo和bar与add连接边的就是tensor。当session运行result时,实际就是求得add这个operation流出的tensor值,那么add的所有上游节点都会进行计算,如果图中有非add上游节点(本例中没有)那么该节点将不会进行计算,这也是图计算的优势之一。

TensorFlow极速入门

2、数据结构

Tensorflow的数据结构有着rank,shape,data types的概念,下面来分别讲解。

(1)rank

Rank一般是指数据的维度,其与线性代数中的rank不是一个概念。其常用rank举例如下。

TensorFlow极速入门

(2)shape

Shape指tensor每个维度数据的个数,可以用python的list/tuple表示。下图表示了rank,shape的关系。

TensorFlow极速入门

(3)data type

Data type,是指单个数据的类型。常用DT_FLOAT,也就是32位的浮点数。下图表示了所有的types。

TensorFlow极速入门

3、 Variables

(1)介绍

当训练模型时,需要使用Variables保存与更新参数。Variables会保存在内存当中,所有tensor一旦拥有Variables的指向就不会在session中丢失。其必须明确的初始化而且可以通过Saver保存到磁盘上。Variables可以通过Variables初始化。

weights  = tf.Variable(tf.random_normal([784, 200], stddev=0.35),name="weights")

biases  = tf.Variable(tf.zeros([200]), name="biases")

其中,tf.random_normal是随机生成一个正态分布的tensor,其shape是第一个参数,stddev是其标准差。tf.zeros是生成一个全零的tensor。之后将这个tensor的值赋值给Variable。

(2)初始化

实际在其初始化过程中做了很多的操作,比如初始化空间,赋初值(等价于tf.assign),并把Variable添加到graph中等操作。注意在计算前需要初始化所有的Variable。一般会在定义graph时定义global_variables_initializer,其会在session运算时初始化所有变量。

直接调用global_variables_initializer会初始化所有的Variable,如果仅想初始化部分Variable可以调用tf.variables_initializer。

Init_ab  = tf.variables_initializer([a,b],name=”init_ab”)

Variables可以通过eval显示其值,也可以通过assign进行赋值。Variables支持很多数学运算,具体可以参照官方文档 (https://www.tensorflow.org/api_docs/python/math_ops/)。

(3)Variables与constant的区别

值得注意的是Variables与constant的区别。Constant一般是常量,可以被赋值给Variables,constant保存在graph中,如果graph重复载入那么constant也会重复载入,其非常浪费资源,如非必要尽量不使用其保存大量数据。而Variables在每个session中都是单独保存的,甚至可以单独存在一个参数服务器上。可以通过代码观察到constant实际是保存在graph中,具体如下。

const  = tf.constant(1.0,name="constant")

print(tf.get_default_graph().as_graph_def())

这里第二行是打印出图的定义,其输出如下。

node {

  name: "constant"

  op: "Const"

  attr {

    key: "dtype"

    value {

      type: DT_FLOAT

    }

  }

  attr {

    key: "value"

    value {

      tensor {

        dtype: DT_FLOAT

        tensor_shape {

        }

        float_val: 1.0

      }

    }

  }

}

versions {

  producer: 17

}

(4)命名

另外一个值得注意的地方是尽量每一个变量都明确的命名,这样易于管理命令空间,而且在导入模型的时候不会造成不同模型之间的命名冲突,这样就可以在一张graph中容纳很多个模型。

4、 placeholders与feed_dict

当我们定义一张graph时,有时候并不知道需要计算的值,比如模型的输入数据,其只有在训练与预测时才会有值。这时就需要placeholder与feed_dict的帮助。

定义一个placeholder,可以使用tf.placeholder(dtype,shape=None,name=None)函数。

foo =  tf.placeholder(tf.int32,shape=[1],name='foo')

bar = tf.constant(2,name='bar')

result = foo + bar

with tf.Session() as sess:

    print(sess.run(result))

在上面的代码中,会抛出错误(InvalidArgumentError),因为计算result需要foo的具体值,而在代码中并没有给出。这时候需要将实际值赋给foo。最后一行修改如下:

print(sess.run(result,{foo:[3]}))

其中最后的dict就是一个feed_dict,一般会使用python读入一些值后传入,当使用minbatch的情况下,每次输入的值都不同。

三、mnist识别实例

介绍了一些tensorflow基础后,我们用一个完整的例子将这些串起来。

首先,需要下载数据集,mnist数据可以在Yann LeCun's website( http://yann.lecun.com/exdb/mnist/ )下载到,也可以通过如下两行代码得到。

from  tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/",  one_hot=True)

该数据集中一共有55000个样本,其中50000用于训练,5000用于验证。每个样本分为X与y两部分,其中X如下图所示,是28*28的图像,在使用时需要拉伸成784维的向量。

TensorFlow极速入门

整体的X可以表示为。

TensorFlow极速入门


y为X真实的类别,其数据可以看做如下图的形式。因此,问题可以看成一个10分类的问题。

TensorFlow极速入门

而本次演示所使用的模型为逻辑回归,其可以表示为

TensorFlow极速入门

用图形可以表示为下图,具体原理这里不再阐述,更多细节参考 该链接 (http://tech.meituan.com/intro_to_logistic_regression.html)。

TensorFlow极速入门

那么 let's coding。

当使用tensorflow进行graph构建时,大体可以分为五部分:

   1、 为 输入X与 输出y 定义placeholder;

    2、定义权重W;

    3、定义模型结构;

    4、定义损失函数;

    5、定义优化算法。

首先导入需要的包,定义X与y的placeholder以及 W,b 的 Variables。其中None表示任意维度,一般是min-batch的 batch size。而 W 定义是 shape 为784,10,rank为2的Variable,b是shape为10,rank为1的Variable。

import tensorflow as tf

x = tf.placeholder(tf.float32,  [None, 784])

y_ = tf.placeholder(tf.float32,  [None, 10])

W = tf.Variable(tf.zeros([784,  10]))

b = tf.Variable(tf.zeros([10]))

之后是定义模型。x与W矩阵乘法后与b求和,经过softmax得到y。

y = tf.nn.softmax(tf.matmul(x,  W) + b)

求逻辑回归的损失函数,这里使用了cross entropy,其公式可以表示为:

TensorFlow极速入门

这里的 cross entropy 取了均值。定义了学习步长为0.5,使用了梯度下降算法(GradientDescentOptimizer)最小化损失函数。不要忘记初始化 Variables。

cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))

train_step =  tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

init =  tf.global_variables_initializer()

最后,我们的 graph 至此定义完毕,下面就可以进行真正的计算,包括初始化变量,输入数据,并计算损失函数与利用优化算法更新参数。

with tf.Session() as sess:

    sess.run(init)

    for i in range(1000):

        batch_xs, batch_ys =  mnist.train.next_batch(100)

        sess.run(train_step, feed_dict={x:  batch_xs, y_: batch_ys})

其中,迭代了1000次,每次输入了100个样本。mnist.train.next_batch 就是生成下一个 batch 的数据,这里知道它在干什么就可以。那么训练结果如何呢,需要进行评估。这里使用单纯的正确率,正确率是用取最大值索引是否相等的方式,因为正确的 label 最大值为1,而预测的 label 最大值为最大概率。

correct_prediction =  tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction,  tf.float32))

print(sess.run(accuracy,  feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

至此,我们开发了一个简单的手写数字识别模型。

总结

总结全文,我们首先介绍了 graph 与 session,并解释了基础数据结构,讲解了一些Variable需要注意的地方并介绍了 placeholders 与 feed_dict 。最终以一个手写数字识别的实例将这些点串起来,希望可以给想要入门的你一丢丢的帮助。雷锋网

雷峰网版权文章,未经授权禁止转载。详情见转载须知

TensorFlow极速入门

分享:
相关文章

编辑

聚焦数据科学,连接 AI 开发者。更多精彩内容,请访问:yanxishe.com
当月热门文章
最新文章
请填写申请人资料
姓名
电话
邮箱
微信号
作品链接
个人简介
为了您的账户安全,请验证邮箱
您的邮箱还未验证,完成可获20积分哟!
请验证您的邮箱
立即验证
完善账号信息
您的账号已经绑定,现在您可以设置密码以方便用邮箱登录
立即设置 以后再说
Baidu
map