tensorflow keras 查找中间tensor并构建局部子图

栏目: 数据库 · 发布时间: 7年前

内容简介:在Mask_RCNN项目的示例项目此方法可以读取层的输出,对于输出多于1个tensor的,可以指定get_layer("rpn_class").output[0:2]等确定。但是对于自定义层的中间变量,就没办法获得了,因此需要使用方法二。

在Mask_RCNN项目的示例项目 nucleus 中,stepbystep步骤里面,需要对网络模型的中间变量进行提取和可视化,常见方式有两种:

通过 get_layer方法:

outputs = [
    ("rpn_class", model.keras_model.get_layer("rpn_class").output),
    ("proposals", model.keras_model.get_layer("ROI").output)
    ]

此方法可以读取层的输出,对于输出多于1个tensor的,可以指定get_layer("rpn_class").output[0:2]等确定。

但是对于自定义层的中间变量,就没办法获得了,因此需要使用方法二。

通过 tensor.op.inputs 逐层向上查找

定义一个迭代函数,不断查找

def find_in_tensor(tensor,name,index=0):
    index += 1
    if index >20:
        return
    tensor_parent = tensor.op.inputs
    for each_ptensor in tensor_parent:
        #print(each_ptensor.name)
        if bool(re.fullmatch(name, each_ptensor.name)):
            print('find it!')
            return each_ptensor
        result = find_in_tensor(each_ptensor,name,index)
        if result is not None:
            return result

接着获得某层的输出,调用迭代函数,找到该tensor

pillar = model.keras_model.get_layer("ROI").output
nms_rois = find_in_tensor(pillar,'ROI_3/rpn_non_max_suppression/NonMaxSuppressionV2:0')
outputs.append(('NonMaxSuppression',nms_rois))

最后,调用kf.fuction构建局部图,并运行:

submodel = model.keras_model
outputs = OrderedDict(outputs)
if submodel.uses_learning_phase and not isinstance(K.learning_phase(), int):
    inputs += [K.learning_phase()]
kf = K.function(submodel.inputs, list(outputs.values()))
in_p,ou_p = next(train_generator)
output_all = kf(in_p)

此时打印outputs可以看到类似如下:

OrderedDict([('rpn_class',<tf.Tensor 'rpn_class_3/concat:0' shape=(?, ?, 2) dtype=float32>),
             ('proposals',<tf.Tensor 'ROI_3/packed_2:0' shape=(1, ?, ?) dtype=float32>),
             ('fpn_p2',<tf.Tensor 'fpn_p2_3/BiasAdd:0' shape=(?, 192, 192, 256) dtype=float32>),
             ('fpn_p3',<tf.Tensor 'fpn_p3_3/BiasAdd:0' shape=(?, 96, 96, 256) dtype=float32>),
             ('fpn_p4',<tf.Tensor 'fpn_p4_3/BiasAdd:0' shape=(?, 48, 48, 256) dtype=float32>),
             ('fpn_p6',<tf.Tensor 'fpn_p6_3/MaxPool:0' shape=(?, 12, 12, 256) dtype=float32>),
             ('NonMaxSuppression',<tf.Tensor 'ROI_3/rpn_non_max_suppression/NonMaxSuppressionV2:0' shape=(?,) dtype=int32>)])

大功告成~


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

More Eric Meyer on CSS (Voices That Matter)

More Eric Meyer on CSS (Voices That Matter)

Eric A. Meyer / New Riders Press / 2004-04-08 / USD 45.00

Ready to commit to using more CSS on your sites? If you are a hands-on learner who has been toying with CSS and want to experiment with real-world projects that will enable you to see how CSS......一起来看看 《More Eric Meyer on CSS (Voices That Matter)》 这本书的介绍吧!

RGB转16进制工具
RGB转16进制工具

RGB HEX 互转工具

MD5 加密
MD5 加密

MD5 加密工具

RGB CMYK 转换工具
RGB CMYK 转换工具

RGB CMYK 互转工具