TensorFlow 调用预训练好的模型—— Python 实现

栏目: Python · 发布时间: 6年前

内容简介:获取更多精彩,请关注「seniusen」!
  • TensorFlow 预训练好的模型被保存为以下四个文件
TensorFlow 调用预训练好的模型—— Python 实现
  • data 文件是训练好的参数值,meta 文件是定义的神经网络图,checkpoint 文件是所有模型的保存路径,如下所示,为简单起见只保留了一个模型。
model_checkpoint_path: "/home/senius/python/c_python/test/model-40"
all_model_checkpoint_paths: "/home/senius/python/c_python/test/model-40"
复制代码

2. 导入模型图、参数值和相关变量

import tensorflow as tf
import numpy as np

sess = tf.Session()
X = None # input
yhat = None # output

def load_model():
    """
        Loading the pre-trained model and parameters.
    """
    global X, yhat
    modelpath = r'/home/senius/python/c_python/test/'
    saver = tf.train.import_meta_graph(modelpath + 'model-40.meta')
    saver.restore(sess, tf.train.latest_checkpoint(modelpath))
    graph = tf.get_default_graph()
    X = graph.get_tensor_by_name("X:0")
    yhat = graph.get_tensor_by_name("tanh:0")
    print('Successfully load the pre-trained model!')

复制代码
  • 通过 saver.restore 我们可以得到预训练的所有参数值,然后再通过 graph.get_tensor_by_name 得到模型的输入张量和我们想要的输出张量。

3. 运行前向传播过程得到预测值

def predict(txtdata):
    """
        Convert data to Numpy array which has a shape of (-1, 41, 41, 41 3).
        Test a single example.
        Arg:
                txtdata: Array in C.
        Returns:
            Three coordinates of a face normal.
    """
    global X, yhat

    data = np.array(txtdata)
    data = data.reshape(-1, 41, 41, 41, 3)
    output = sess.run(yhat, feed_dict={X: data})  # (-1, 3)
    output = output.reshape(-1, 1)
    ret = output.tolist()
    return ret

复制代码
  • 通过 feed_dict 喂入测试数据,然后 run 输出的张量我们就可以得到预测值。

4. 测试

load_model()
testdata = np.fromfile('/home/senius/python/c_python/test/04t30t00.npy', dtype=np.float32)
testdata = testdata.reshape(-1, 41, 41, 41, 3) # (150, 41, 41, 41, 3)
testdata = testdata[0:2, ...] # the first two examples
txtdata = testdata.tolist()
output = predict(txtdata)
print(output)
#  [[-0.13345889747142792], [0.5858198404312134], [-0.7211828231811523], 
# [-0.03778800368309021], [0.9978875517845154], [0.06522832065820694]]
复制代码
  • 本例输入是一个三维网格模型处理后的 [41, 41, 41, 3] 的数据,输出一个表面法向量坐标 (x, y, z)。

获取更多精彩,请关注「seniusen」!

TensorFlow 调用预训练好的模型—— Python 实现

以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

The Elements of Statistical Learning

The Elements of Statistical Learning

Trevor Hastie、Robert Tibshirani、Jerome Friedman / Springer / 2009-10-1 / GBP 62.99

During the past decade there has been an explosion in computation and information technology. With it have come vast amounts of data in a variety of fields such as medicine, biology, finance, and mark......一起来看看 《The Elements of Statistical Learning》 这本书的介绍吧!

JSON 在线解析
JSON 在线解析

在线 JSON 格式化工具

UNIX 时间戳转换
UNIX 时间戳转换

UNIX 时间戳转换

HSV CMYK 转换工具
HSV CMYK 转换工具

HSV CMYK互换工具