Estimator:一种可极大地简化机器学习编程的高阶 TensorFlow API

栏目: 编程工具 · 发布时间: 5年前

今天的内容要向大家介绍 Estimator (https://tensorflow.google.cn/api_docs/python/tf/estimator?hl=zh-CN) - 一种可极大地简化机器学习编程的高阶 TensorFlow API Estimator 会封装下列操作:

  • 训练

  • 评估

  • 预测

  • 导出以供使用

您可以使用我们提供的预创建的 Estimator,也可以编写自定义 Estimator。所有 Estimator(无论是预创建的还是自定义)都是基于 tf.estimator.Estimator  类的类 (https://tensorflow.google.cn/api_docs/python/tf/estimator/Estimator?hl=zh-CN)

注意:TensorFlow 还包含一个已弃用的 Estimator 类  tf.contrib.learn.Estimator (https://tensorflow.google.cn/api_docs/python/tf/contrib/learn/Estimator?hl=zh-CN) ,您不应该使用此类。

Estimator 的优势

Estimator 具有下列优势:

  • 您可以在本地主机上或分布式多服务器环境中运行基于 Estimator 的模型,而无需更改模型。此外,您可以在 CPU、GPU 或 TPU 上运行基于 Estimator 的模型,而无需重新编码模型。

  • Estimator 简化了在模型开发者之间共享实现的过程。

  • 您可以使用高级直观代码开发先进的模型。简言之,采用 Estimator 创建模型通常比采用低阶 TensorFlow API 更简单。

  • Estimator 本身在 tf.layers 之上构建而成 (https://tensorflow.google.cn/api_docs/python/tf/layers?hl=zh-CN) ,可以简化自定义过程。

  • Estimator 会为您构建图。

  • Estimator 提供安全的分布式训练循环,可以控制如何以及何时:

    • 构建图

    • 初始化变量

    • 开始排队

    • 处理异常

    • 创建检查点文件并从故障中恢复

    • 保存 TensorBoard 的摘要

使用 Estimator 编写应用时,您必须将数据输入管道从模型中分离出来。这种分离简化了不同数据集的实验流程。

预创建的 Estimator

借助预创建的 Estimator,您能够在比基本 TensorFlow API 高级很多的概念层面上进行操作。由于 Estimator 会为您处理所有 “ 管道工作 ”,因此您不必再为创建计算图或会话而操心。也就是说,预创建的 Estimator 会为您创建和管理 Graph (https://tensorflow.google.cn/api_docs/python/tf/Graph?hl=zh-CN) 和  Session (https://tensorflow.google.cn/api_docs/python/tf/Session?hl=zh-CN) 对象。此外,借助预创建的 Estimator,您只需稍微更改下代码,就可以尝试不同的模型架构。例如, DNNClassifier 是一个预创建的 Estimator 类 (https://tensorflow.google.cn/api_docs/python/tf/estimator/DNNClassifier?hl=zh-CN) ,它根据密集的前馈神经网络训练分类模型。

预创建的 Estimator 程序的结构

依赖预创建的 Estimator 的 TensorFlow 程序通常包含下列四个步骤:

编写一个或多个数据集导入函数。

例如,您可以创建一个函数来导入训练集,并创建另一个函数来导入测试集。每个数据集导入函数都必须返回两个对象 :

  • 一个字典,其中键是特征名称,值是包含相应特征数据的张量(或 SparseTensor)

  • 一个包含一个或多个标签的张量

例如,以下代码展示了输入函数的基本框架:

def input_fn(dataset):

...  # manipulate dataset, extracting the feature dict and the label

return feature_dict, label

(要了解完整的详细信息,请参阅 导入数据 (https://tensorflow.google.cn/guide/datasets?hl=zh-CN) 。)

定义特征列。

每个 tf.feature_column (https://tensorflow.google.cn/api_docs/python/tf/feature_column?hl=zh-CN) 都标识了特征名称、特征类型和任何输入预处理操作。例如,以下代码段创建了三个存储整数或浮点数据的特征列。前两个特征列仅标识了特征的名称和类型。第三个特征列还指定了一个 lambda,该程序将调用此 lambda 来调节原始数据:

# Define three numeric feature columns.

population = tf.feature_column.numeric_column('population')

crime_rate = tf.feature_column.numeric_column('crime_rate')

median_education = tf.feature_column.numeric_column('median_education',

normalizer_fn=lambda x: x - global_education_mean)

实例化相关的预创建的 Estimator。

例如,下面是对名为 LinearClassifier 的预创建 Estimator 进行实例化的示例代码:

# Instantiate an estimator, passing the feature columns.

estimator = tf.estimator.LinearClassifier(

feature_columns=[population, crime_rate, median_education],

)

调用训练、评估或推理方法。

例如,所有 Estimator 都提供训练模型的 train 方法。

# my_training_set is the function created in Step 1

estimator.train(input_fn=my_training_set, steps=2000)

预创建的 Estimator 的优势

预创建的 Estimator 会编码最佳做法,从而具有下列优势:

  • 确定计算图不同部分的运行位置以及在单台机器或多台机器上实现策略的最佳做法。

  • 事件(汇总)编写和普遍有用的汇总的最佳做法。

如果您不使用预创建的 Estimator,则必须自行实现上述功能。

自定义 Estimator

每个 Estimator(无论是预创建还是自定义)的核心都是其模型函数,这是一种为训练、评估和预测构建图的方法。如果您使用预创建的 Estimator,则有人已经实现了模型函数。如果您使用自定义 Estimator,则必须自行编写模型函数。 随附文档 介绍了如何编写模型函数 https://tensorflow.google.cn/guide/custom_estimators?hl=zh-CN)

推荐的工作流程

我们推荐以下工作流程:

  1. 假设存在合适的预创建的 Estimator,使用它构建第一个模型并使用其结果确定基准。

  2. 使用此预创建的 Estimator 构建和测试整体管道,包括数据的完整性和可靠性。

  3. 如果存在其他合适的预创建的 Estimator,则运行实验来确定哪个预创建的 Estimator 效果最好。

  4. 可以通过构建自定义 Estimator 进一步改进模型。

从 Keras 模型创建 Estimator

您可以将现有的 Keras 模型转换为 Estimator。这样做之后,Keras 模型就可以利用 Estimator 的优势,例如分布式训练。调用

tf.keras.estimator.model_to_estimator (https://tensorflow.google.cn/api_docs/python/tf/keras/estimator/model_to_estimator?hl=zh-CN) ,如下例所示:

# Instantiate a Keras inception v3 model.

keras_inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights=None)

# Compile model with the optimizer, loss, and metrics you'd like to train with.

keras_inception_v3.compile(optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),

loss='categorical_crossentropy',

metric='accuracy')

# Create an Estimator from the compiled Keras model. Note the initial model

# state of the keras model is preserved in the created Estimator.

est_inception_v3 = tf.keras.estimator.model_to_estimator(keras_model=keras_inception_v3)

# Treat the derived Estimator as you would with any other Estimator.

# First, recover the input name(s) of Keras model, so we can use them as the

# feature column name(s) of the Estimator input function:

keras_inception_v3.input_names  # print out: ['input_1']

# Once we have the input name(s), we can create the input function, for example,

# for input(s) in the format of numpy ndarray:

train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(

x={"input_1": train_data},

y=train_labels,

num_epochs=1,

shuffle=False)

# To train, we call Estimator's train function:

est_inception_v3.train(input_fn=train_input_fn, steps=2000)

请注意,Keras Estimator 的特征列名称和标签来自经过编译的对应 Keras 模型。例如,上面的 train_input_fn 的输入键名称可以从 keras_inception_v3.input_names 获得;同样,预测的输出名称可以从 keras_inception_v3.output_names 获得。

要了解详情,请参阅 tf.keras.estimator.model_to_estimator  的文档 (https://tensorflow.google.cn/api_docs/python/tf/keras/estimator/model_to_estimator?hl=zh-CN)

Estimator:一种可极大地简化机器学习编程的高阶 TensorFlow API


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

那些让文案绝望的文案

那些让文案绝望的文案

小马宋 / 北京联合出版公司 / 2015-10 / 45

什么文案60年前就在使用互联网思维? 什么文案让一辆小车在崇尚大车的国度畅销不衰? 什么文案让做文案的人产生“既生瑜何生亮”的绝望? 没错,它是甲壳虫。 远在上世纪五六十年代,这些文案让这辆不起眼的小车畅销不衰。 它的文案风趣而又言之凿凿,它的文案机智而又无可辩驳。 它充满自黑精神,善于借势时事热点,懂得乖巧卖萌,也是天生的段子手。 为了让国内读者一睹这一......一起来看看 《那些让文案绝望的文案》 这本书的介绍吧!

RGB转16进制工具
RGB转16进制工具

RGB HEX 互转工具

XML、JSON 在线转换
XML、JSON 在线转换

在线XML、JSON转换工具

HSV CMYK 转换工具
HSV CMYK 转换工具

HSV CMYK互换工具