保存和恢复模型

栏目: 编程工具 · 发布时间: 6年前

内容简介:上述代码将创建一个 TensorFlow 检查点文件集合,这些文件在每个周期结束时更新:创建一个未经训练的全新模型。仅通过权重恢复模型时,您必须有一个与原始模型架构相同的模型。由于模型架构相同,因此我们可以分享权重(尽管是不同的模型实例)。现在,重新构建一个未经训练的全新模型,并用测试集对其进行评估。未训练模型的表现有很大的偶然性(准确率约为 10%):

模型进度可在训练期间和之后保存。这意味着,您可以从上次暂停的地方继续训练模型,避免训练时间过长。此外,可以保存意味着您可以分享模型,而他人可以对您的工作成果进行再创作。发布研究模型和相关技术时,大部分机器学习从业者会分享以下内容:

  • 用于创建模型的代码,以及
  • 模型的训练权重或参数

分享此类数据有助于他人了解模型的工作原理并尝试使用新数据自行尝试模型。

注意:请谨慎使用不可信的代码 - TensorFlow 模型就是代码。有关详情,请参阅安全地使用 TensorFlow。

选项

您可以通过多种不同的方法保存 TensorFlow 模型,具体取决于您使用的 API。本指南使用的是 tf.keras,它是一种用于在 TensorFlow 中构建和训练模型的高阶 API。要了解其他方法,请参阅 TensorFlow 保存和恢复指南或在 Eager 中保存。

设置

安装和导入

安装并导入 TensorFlow 和依赖项:

In [1]:

!pip install -q h5py pyyaml
复制代码

获取示例数据集

我们将使用 MNIST 数据集训练模型,以演示如何保存权重。要加快演示运行速度,请仅使用前 1000 个样本:

In [2]:

from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf
from tensorflow import keras

tf.__version__
复制代码

Out[2]:

'1.13.1'复制代码

In [3]:

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
复制代码

定义模型

我们来构建一个简单的模型,以演示如何保存和加载权重。

In [4]:

# Returns a short sequential model
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation=tf.nn.softmax)
  ])

  model.compile(optimizer=tf.keras.optimizers.Adam(),
                loss=tf.keras.losses.sparse_categorical_crossentropy,
                metrics=['accuracy'])

  return model

# Create a basic model instance
model = create_model()
model.summary()
复制代码
WARNING:tensorflow:From e:\program files\python37\lib\site-packages\tensorflow\python\ops\resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From e:\program files\python37\lib\site-packages\tensorflow\python\keras\layers\core.py:143: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 512)               401920    
_________________________________________________________________
dropout (Dropout)            (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
复制代码

在训练期间保存检查点

主要用例是,在训练期间或训练结束时自动保存检查点。这样一来,您便可以使用经过训练的模型,而无需重新训练该模型,或从上次暂停的地方继续训练,以防训练过程中断。

tf.keras.callbacks.ModelCheckpoint 是执行此任务的回调。该回调需要几个参数来配置检查点。

检查点回调用法

训练模型,并将 ModelCheckpoint 回调传递给该模型:

In [5]:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create checkpoint callback
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

model = create_model()

model.fit(train_images, train_labels,  epochs = 10,
          validation_data = (test_images,test_labels),
          callbacks = [cp_callback])  # pass callback to training
复制代码
Train on 1000 samples, validate on 1000 samples
Epoch 1/10
 864/1000 [========================>.....] - ETA: 0s - loss: 1.2590 - acc: 0.6354
Epoch 00001: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
WARNING:tensorflow:From e:\program files\python37\lib\site-packages\tensorflow\python\keras\engine\network.py:1436: update_checkpoint_state (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.train.CheckpointManager to manage checkpoints rather than manually editing the Checkpoint proto.
1000/1000 [==============================] - 1s 791us/sample - loss: 1.1675 - acc: 0.6650 - val_loss: 0.7683 - val_acc: 0.7550
Epoch 2/10
 896/1000 [=========================>....] - ETA: 0s - loss: 0.4623 - acc: 0.8750
Epoch 00002: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 351us/sample - loss: 0.4515 - acc: 0.8750 - val_loss: 0.5316 - val_acc: 0.8340
Epoch 3/10
 800/1000 [=======================>......] - ETA: 0s - loss: 0.2790 - acc: 0.9287
Epoch 00003: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 358us/sample - loss: 0.2834 - acc: 0.9270 - val_loss: 0.4607 - val_acc: 0.8520
Epoch 4/10
 928/1000 [==========================>...] - ETA: 0s - loss: 0.2077 - acc: 0.9515
Epoch 00004: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 339us/sample - loss: 0.2046 - acc: 0.9530 - val_loss: 0.4370 - val_acc: 0.8540
Epoch 5/10
 896/1000 [=========================>....] - ETA: 0s - loss: 0.1578 - acc: 0.9710
Epoch 00005: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 350us/sample - loss: 0.1526 - acc: 0.9720 - val_loss: 0.4047 - val_acc: 0.8670
Epoch 6/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.1055 - acc: 0.9815
Epoch 00006: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 380us/sample - loss: 0.1062 - acc: 0.9830 - val_loss: 0.4201 - val_acc: 0.8560
Epoch 7/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.0826 - acc: 0.9850
Epoch 00007: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 351us/sample - loss: 0.0824 - acc: 0.9850 - val_loss: 0.4168 - val_acc: 0.8660
Epoch 8/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.0662 - acc: 0.9919
Epoch 00008: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 357us/sample - loss: 0.0655 - acc: 0.9910 - val_loss: 0.4021 - val_acc: 0.8700
Epoch 9/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.0495 - acc: 0.9954
Epoch 00009: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 358us/sample - loss: 0.0491 - acc: 0.9950 - val_loss: 0.4168 - val_acc: 0.8640
Epoch 10/10
 896/1000 [=========================>....] - ETA: 0s - loss: 0.0401 - acc: 1.0000
Epoch 00010: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 354us/sample - loss: 0.0397 - acc: 1.0000 - val_loss: 0.4091 - val_acc: 0.8770
复制代码

Out[5]:

<tensorflow.python.keras.callbacks.History at 0x1403b5f8>复制代码

上述代码将创建一个 TensorFlow 检查点文件集合,这些文件在每个周期结束时更新:

In [7]:

!dir {checkpoint_dir}
复制代码
驱动器 C 中的卷没有标签。
 卷的序列号是 CE2F-63AD

 C:\Users\Administrator\JupyterProject\training_1 的目录

2019/04/28  11:23    <DIR>          .
2019/04/28  11:23    <DIR>          ..
2019/04/28  11:23                71 checkpoint
2019/04/28  11:23         1,631,508 cp.ckpt.data-00000-of-00001
2019/04/28  11:23               648 cp.ckpt.index
               3 个文件      1,632,227 字节
               2 个目录 23,484,948,480 可用字节
复制代码

创建一个未经训练的全新模型。仅通过权重恢复模型时,您必须有一个与原始模型架构相同的模型。由于模型架构相同,因此我们可以分享权重(尽管是不同的模型实例)。

现在,重新构建一个未经训练的全新模型,并用测试集对其进行评估。未训练模型的表现有很大的偶然性(准确率约为 10%):

In [8]:

model = create_model()

loss, acc = model.evaluate(test_images, test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))
复制代码
1000/1000 [==============================] - 0s 81us/sample - loss: 2.3694 - acc: 0.0610
Untrained model, accuracy:  6.10%
复制代码

然后从检查点加载权重,并重新评估:

In [9]:

model.load_weights(checkpoint_path)
loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
复制代码
1000/1000 [==============================] - 0s 46us/sample - loss: 0.4091 - acc: 0.8770
Restored model, accuracy: 87.70%
复制代码

检查点回调选项

该回调提供了多个选项,用于为生成的检查点提供独一无二的名称,以及调整检查点创建频率。

训练一个新模型,每隔 5 个周期保存一次检查点并设置唯一名称:

In [10]:

# include the epoch in the file name. (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1, save_weights_only=True,
    # Save weights, every 5-epochs.
    period=5)

model = create_model()
model.fit(train_images, train_labels,
          epochs = 50, callbacks = [cp_callback],
          validation_data = (test_images,test_labels),
          verbose=0)
复制代码
Epoch 00005: saving model to training_2/cp-0005.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00010: saving model to training_2/cp-0010.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00015: saving model to training_2/cp-0015.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00020: saving model to training_2/cp-0020.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00025: saving model to training_2/cp-0025.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00030: saving model to training_2/cp-0030.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00035: saving model to training_2/cp-0035.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00040: saving model to training_2/cp-0040.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00045: saving model to training_2/cp-0045.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

Epoch 00050: saving model to training_2/cp-0050.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
复制代码

Out[10]:

<tensorflow.python.keras.callbacks.History at 0x16054e48>复制代码

现在,看一下生成的检查点并选择最新的检查点:

In [11]:

! dir {checkpoint_dir}
复制代码
驱动器 C 中的卷没有标签。
 卷的序列号是 CE2F-63AD

 C:\Users\Administrator\JupyterProject\training_2 的目录

2019/04/28  11:24    <DIR>          .
2019/04/28  11:24    <DIR>          ..
2019/04/28  11:24                81 checkpoint
2019/04/28  11:24         1,631,508 cp-0005.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0005.ckpt.index
2019/04/28  11:24         1,631,508 cp-0010.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0010.ckpt.index
2019/04/28  11:24         1,631,508 cp-0015.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0015.ckpt.index
2019/04/28  11:24         1,631,508 cp-0020.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0020.ckpt.index
2019/04/28  11:24         1,631,508 cp-0025.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0025.ckpt.index
2019/04/28  11:24         1,631,508 cp-0030.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0030.ckpt.index
2019/04/28  11:24         1,631,508 cp-0035.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0035.ckpt.index
2019/04/28  11:24         1,631,508 cp-0040.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0040.ckpt.index
2019/04/28  11:24         1,631,508 cp-0045.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0045.ckpt.index
2019/04/28  11:24         1,631,508 cp-0050.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0050.ckpt.index
              21 个文件     16,321,641 字节
               2 个目录 23,468,404,736 可用字节
复制代码

In [13]:

latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
复制代码

Out[13]:

'training_2\\cp-0050.ckpt'复制代码

注意:默认的 TensorFlow 格式仅保存最近的 5 个检查点。

要进行测试,请重置模型并加载最新的检查点:

In [14]:

model = create_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
复制代码
1000/1000 [==============================] - 0s 86us/sample - loss: 0.4830 - acc: 0.8770
Restored model, accuracy: 87.70%
复制代码

这些是什么文件?

上述代码将权重存储在检查点格式的文件集合中,这些文件仅包含经过训练的权重(采用二进制格式)。检查点包括:

包含模型权重的一个或多个分片。

如果您仅在一台机器上训练模型,则您将有 1 个后缀为 .data-00000-of-00001 的分片

手动保存权重

在上文中,您了解了如何将权重加载到模型中。

手动保存权重的方法同样也很简单,只需使用 Model.save_weights 方法即可。

In [15]:

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
复制代码
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000018D9D080>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 88us/sample - loss: 0.4830 - acc: 0.8770
Restored model, accuracy: 87.70%
复制代码

保存整个模型

整个模型可以保存到一个文件中,其中包含权重值、模型配置乃至优化器配置。这样,您就可以为模型设置检查点,并稍后从完全相同的状态继续训练,而无需访问原始代码。

在 Keras 中保存完全可正常使用的模型非常有用,您可以在 TensorFlow.js 中加载它们,然后在网络浏览器中训练和运行它们。

Keras 使用 HDF5 标准提供基本的保存格式。对于我们来说,可将保存的模型视为一个二进制 blob。

In [16]:

model = create_model()

model.fit(train_images, train_labels, epochs=5)

# Save entire model to a HDF5 file
model.save('my_model.h5')
复制代码
Epoch 1/5
1000/1000 [==============================] - 0s 322us/sample - loss: 1.1511 - acc: 0.6830
Epoch 2/5
1000/1000 [==============================] - 0s 235us/sample - loss: 0.4189 - acc: 0.8840s - loss: 0.4545 - acc: 0.8
Epoch 3/5
1000/1000 [==============================] - 0s 235us/sample - loss: 0.2864 - acc: 0.9230
Epoch 4/5
1000/1000 [==============================] - 0s 233us/sample - loss: 0.2147 - acc: 0.9410
Epoch 5/5
1000/1000 [==============================] - 0s 224us/sample - loss: 0.1642 - acc: 0.9660
复制代码

现在,从该文件重新创建模型:

In [17]:

# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('my_model.h5')
new_model.summary()
复制代码
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_14 (Dense)             (None, 512)               401920    
_________________________________________________________________
dropout_7 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_15 (Dense)             (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
复制代码

检查其准确率:

In [18]:

loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
复制代码
1000/1000 [==============================] - 0s 99us/sample - loss: 0.4258 - acc: 0.8530
Restored model, accuracy: 85.30%
复制代码

此技巧可保存以下所有内容:

  • 权重值
  • 模型配置(架构)
  • 优化器配置

Keras 通过检查架构来保存模型。目前,它无法保存 TensorFlow 优化器(来自 tf.train)。使用此类优化器时,您需要在加载模型后对其进行重新编译,使优化器的状态变松散。

后续学习计划¶

这些就是使用 tf.keras 保存和加载模型的快速指南。

  • tf.keras 指南详细介绍了如何使用 tf.keras 保存和加载模型。

  • 请参阅在 Eager 中保存,了解如何在 Eager Execution 期间保存模型。

  • 保存和恢复指南介绍了有关 TensorFlow 保存的低阶详细信息。


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Head First Web Design

Head First Web Design

Ethan Watrall、Jeff Siarto / O’Reilly Media, Inc. / 2009-01-02 / USD 49.99

Want to know how to make your pages look beautiful, communicate your message effectively, guide visitors through your website with ease, and get everything approved by the accessibility and usability ......一起来看看 《Head First Web Design》 这本书的介绍吧!

HTML 压缩/解压工具
HTML 压缩/解压工具

在线压缩/解压 HTML 代码

MD5 加密
MD5 加密

MD5 加密工具

XML、JSON 在线转换
XML、JSON 在线转换

在线XML、JSON转换工具