# make_checkpoint.py -- convert a .pb to a TF session dump
#
# Copyright (C) 2018, Hiromu Yakura <[email protected]>.
# Copyright (C) 2017, Nicholas Carlini <[email protected]>.
#
# This program is licenced under the BSD 2-Clause licence,
# contained in the LICENCE file in this directory.
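
# Typical invocation, assuming the default layout encoded in main() below
# (a DeepSpeech checkout next to this script and the frozen graph under models/):
#
#     python make_checkpoint.py --model models/output_graph.pb \
#         --audio DeepSpeech/data/smoke_test/LDC93S1.wav \
#         --out models/session_dump
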
from __future__ import print_function

import argparse
import os
import sys

from tensorflow.core.framework.graph_pb2 import *

import numpy as np
import tensorflow as tf

sys.path.append(os.path.join(os.path.dirname(__file__), 'DeepSpeech'))

from util.audio import audiofile_to_input_vector
from util.text import ctc_label_dense_to_sparse

# Okay, so this is ugly. We don't want DeepSpeech to crash
# when we haven't built the language model.
# So we're just going to monkeypatch TF and make it a no-op.
# Sue me.
tf.load_op_library = lambda x: x

import DeepSpeech


def make_checkpoint(model_path, audio_path, save_path):
    graph_def = GraphDef()
    loaded = graph_def.ParseFromString(open(model_path, 'rb').read())

    with tf.Graph().as_default() as graph:
        new_input = tf.placeholder(tf.float32, [None, None, None],
                                   name='new_input')

        # Load the saved .pb into the current graph to let us grab
        # access to the weights.
        logits, = tf.import_graph_def(
            graph_def,
            input_map={'input_node:0': new_input},
            return_elements=['logits:0'],
            name='newname',
            op_dict=None,
            producer_op_list=None
        )
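
        # Everything imported above now lives under the 'newname/' scope that
        # tf.import_graph_def added, which is why the copy loop further down
        # addresses the frozen weights as 'newname/' + var.name.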

        # Now let's dump these weights into a new copy of the network.
        with tf.Session(graph=graph) as sess:
            # Sample sentence, to make sure we've done it right
            mfcc = audiofile_to_input_vector(audio_path, 26, 9)

            # Okay, so this is ugly again.
            # We just want it to not crash.
            tf.app.flags.FLAGS.alphabet_config_path = \
                os.path.join(os.path.dirname(__file__), 'DeepSpeech/data/alphabet.txt')
            DeepSpeech.initialize_globals()
            logits2 = DeepSpeech.BiRNN(new_input, [len(mfcc)], [0]*10)

            # Here's where all the work happens. Copy the variables
            # over from the .pb to the session object.
            for var in tf.global_variables():
                sess.run(var.assign(sess.run('newname/'+var.name)))
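
            # Note: this assumes the variable names produced by DeepSpeech.BiRNN
            # match the tensors frozen into the .pb; a name missing from the
            # imported graph makes the sess.run lookup fail loudly rather than
            # silently skipping that weight.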

            # Test to make sure we did it right.
            res = (sess.run(logits, {new_input: [mfcc],
                                     'newname/input_lengths:0': [len(mfcc)]}).flatten())
            res2 = (sess.run(logits2, {new_input: [mfcc]})).flatten()

            print('This value should be small', np.sum(np.abs(res - res2)))

            # And finally save the constructed session.
            saver = tf.train.Saver()
            saver.save(sess, save_path)
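

# Loading the dump back is the standard TF1 Saver round-trip; a minimal sketch,
# assuming the consumer rebuilds the same DeepSpeech graph and uses the default
# --out path from main() below:
#
#     saver = tf.train.Saver()
#     with tf.Session() as sess:
#         saver.restore(sess, 'models/session_dump')
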
def main():
    get_path = lambda x: os.path.join(os.path.dirname(__file__), x)

    parser = argparse.ArgumentParser(description=None)
    parser.add_argument('--model', type=str, default=get_path('models/output_graph.pb'),
                        help='Input TensorFlow graph file taken from DeepSpeech')
    parser.add_argument('--audio', type=str,
                        default=get_path('DeepSpeech/data/smoke_test/LDC93S1.wav'),
                        help='Sample audio file for testing')
    parser.add_argument('--out', type=str, default=get_path('models/session_dump'),
                        help='Path for saving the constructed session')
    args = parser.parse_args()

    make_checkpoint(args.model, args.audio, args.out)


if __name__ == '__main__':
    main()