Saving and Restoring TensorFlow Models: Checkpoints, MetaGraphs, and SavedModel
Saving and restoring in TensorFlow depends on the API level you use. In graph-mode TensorFlow 1.x, checkpoints capture Variable values; graph structure can be reloaded from a MetaGraph. In TensorFlow 2.x/Keras, the recommended format is SavedModel (or H5 for Keras-only models), and object-based checkpoints via tf.train.Checkpoint are preferred over name-based ones.
TF 1.x (graph mode): checkpoints with tf.train.Saver
Minimal save/restore with the same graph definition:
import tensorflow as tf

# Build a tiny linear-regression graph: Y_pred = X @ W + C.
X = tf.placeholder(tf.float32, shape=[None, 3], name='X')
Y = tf.placeholder(tf.float32, shape=[None, 1], name='Y')
W = tf.Variable(tf.random_normal([3, 1]), name='W')
C = tf.Variable(tf.zeros([1]), name='C')
Y_pred = tf.add(tf.matmul(X, W), C, name='Y_pred')
loss = tf.reduce_mean(tf.square(Y_pred - Y))
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

saver = tf.train.Saver()  # saves all variables by default
ckpt_dir = '/tmp/linreg_ckpts'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(200):
        xb = [[1, 2, 3]]
        yb = [[7]]
        sess.run(train_op, feed_dict={X: xb, Y: yb})
        # Periodically checkpoint; global_step suffixes the checkpoint filename
        # (e.g. model.ckpt-50), so successive saves do not clobber each other.
        if (step + 1) % 50 == 0:
            saver.save(sess, ckpt_dir + '/model.ckpt', global_step=step + 1)
Restore into an identical graph and run inference:
import tensorflow as tf

# Rebuild the SAME variable structure used at save time; Saver matches
# checkpoint entries to variables by name ('W', 'C').
X = tf.placeholder(tf.float32, shape=[None, 3], name='X')
W = tf.Variable(tf.random_normal([3, 1]), name='W')
C = tf.Variable(tf.zeros([1]), name='C')
Y_pred = tf.add(tf.matmul(X, W), C, name='Y_pred')

saver = tf.train.Saver()
ckpt_dir = '/tmp/linreg_ckpts'

with tf.Session() as sess:
    # Initialize first so the graph runs even if no checkpoint exists yet;
    # restore() then overwrites these values with the checkpointed ones.
    sess.run(tf.global_variables_initializer())
    latest = tf.train.latest_checkpoint(ckpt_dir)
    if latest:
        saver.restore(sess, latest)
    # Fetch and feed by tensor name ('<op_name>:0').
    preds = sess.run('Y_pred:0', feed_dict={'X:0': [[5, 0, -1]]})
Loading graph structure from a .meta file
Saver writes a MetaGraph (.meta) that captures ops, collections, and savers. Import it, then restore variables:
import tensorflow as tf

meta_path = '/tmp/my_great_model.meta'
ckpt_path = '/tmp/my_great_model'

# import_meta_graph rebuilds the graph (ops, collections) from the .meta file
# and returns a Saver that can restore the checkpointed variable values.
saver = tf.train.import_meta_graph(meta_path)

with tf.Session() as sess:
    saver.restore(sess, ckpt_path)
    g = tf.get_default_graph()
    # Fetch tensors/ops by name
    x_t = g.get_tensor_by_name('X:0')
    yhat_t = g.get_tensor_by_name('Y_pred:0')
    out = sess.run(yhat_t, feed_dict={x_t: [[0, 1, 2]]})
Inspecting checkpoints and remapping variables
When loading a checkpoint from a different model/namespace, inspect the variable names:
# Print the variable names (and shapes) stored in a checkpoint file.
python -m tensorflow.python.tools.inspect_checkpoint \
--file_name=/path/to/model.ckpt
Map checkpoint variable names to your in-graph Variables:
import tensorflow as tf

# Build your graph/variables with explicit names.
v_a = tf.get_variable('encoder/W', shape=[128, 64])
v_b = tf.get_variable('encoder/b', shape=[64])

# Map checkpoint names -> graph vars. The dict keys are the names as stored
# in the foreign checkpoint; the values are the in-graph Variable objects.
mapping = {
    'enc_layer/weights': v_a,
    'enc_layer/biases': v_b,
}
loader = tf.train.Saver(var_list=mapping)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    loader.restore(sess, '/path/to/foreign_model.ckpt')
Fast in-memory rollback (no disk IO)
Capture variable states in memory and restore later (useful for early stopping rollbacks):
import tensorflow as tf

g = tf.get_default_graph()
vars_ = g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
# Each Variable owns an '<name>/Assign' op; its second input is the value
# tensor, which we can feed directly to write arbitrary values back.
assign_ops = [g.get_operation_by_name(v.op.name + '/Assign') for v in vars_]
feed_slots = [op.inputs[1] for op in assign_ops]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... train ...
    snapshot = sess.run(vars_)  # capture current variable values in memory
    # ... train more ...
    # rollback: feed the snapshot through the assign ops (no disk IO)
    sess.run(assign_ops,
             feed_dict={slot: val for slot, val in zip(feed_slots, snapshot)})
SavedModel (TF 1.x): full-graph export with signatures
Programmatic SavedModel export with explicit input/output signatures:
import tensorflow as tf

# Scalar affine model: y = w * x + b, applied elementwise over a 1-D batch.
x = tf.placeholder(tf.float32, shape=[None], name='x')
w = tf.Variable(3.0, name='w')
b = tf.Variable(1.0, name='b')
y = tf.add(w * x, b, name='y')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    export_dir = './saved_model_v1'
    bldr = tf.saved_model.builder.SavedModelBuilder(export_dir)
    # Declare the serving signature: named input/output TensorInfos plus the
    # method name serving systems use to identify a prediction signature.
    sig = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'x': tf.saved_model.utils.build_tensor_info(x)},
        outputs={'y': tf.saved_model.utils.build_tensor_info(y)},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
    )
    bldr.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig
        },
    )
    bldr.save()
Load the SavedModel and run predictions:
import tensorflow as tf

export_dir = './saved_model_v1'

# Use a context manager so the session is always closed (the original leaked it).
with tf.Session() as sess:
    meta = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    sig = meta.signature_def
    # Resolve the concrete tensor names recorded in the default signature,
    # then look the tensors up in the freshly loaded graph.
    x_name = sig['serving_default'].inputs['x'].name
    y_name = sig['serving_default'].outputs['y'].name
    x_t = sess.graph.get_tensor_by_name(x_name)
    y_t = sess.graph.get_tensor_by_name(y_name)
    print(sess.run(y_t, feed_dict={x_t: [2.0, 4.0]}))
Simple export in TF 1.x with tf.saved_model.simple_save:
import tensorflow as tf

# Build in a fresh graph so exported op names do not collide with any
# previously constructed ops in the default graph.
with tf.Graph().as_default(), tf.Session() as sess:
    a = tf.placeholder(tf.float32, shape=[None, 2], name='a')
    k = tf.Variable([[2.0], [0.5]], name='k')
    b = tf.Variable([0.1], name='b')
    z = tf.add(tf.matmul(a, k), b, name='z')
    sess.run(tf.global_variables_initializer())
    # One call writes a SavedModel with a default serving signature.
    tf.saved_model.simple_save(
        sess, './simple_sm', inputs={'a': a}, outputs={'z': z})
TF 1.x notes and tips
- Prefix a relative path with "./" when restoring from the current working directory (e.g., saver.restore(sess, './model.ckpt')).
- Initializing variables before restore is optional in this setup; Saver will overwrite initialized values with checkpoint values.
- Ensure variable/op names are unique and deterministic if you plan to import/export MetaGraphs.
TF 2.x and Keras: SavedModel / H5 and checkpoints
Save and load entire Keras models (architecture + weights + optimizer state):
import tensorflow as tf
from tensorflow import keras

# Small two-layer classifier; input_shape lets Keras build weights immediately.
model = keras.Sequential([
keras.layers.Dense(64, activation='relu', input_shape=(32,)),
keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
# Train ... then save
# Both formats capture architecture + weights + optimizer state.
model.save('./models/my_model') # SavedModel directory (recommended)
model.save('./models/my_model.h5') # H5 file (Keras format)
# Load
m1 = keras.models.load_model('./models/my_model')
m2 = keras.models.load_model('./models/my_model.h5')
Weights-only save/restore:
# Weights only: no architecture or optimizer config is written.
model.save_weights('./ckpts/ckpt')
# ... later or in a new process
# clone_model rebuilds the same architecture; load_weights then fills it in.
reconstructed = keras.models.clone_model(model)
reconstructed.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
reconstructed.load_weights('./ckpts/ckpt')
Checkpoint callback during training:
# The {epoch:03d} placeholder is filled in at save time (weights.005.ckpt, ...).
cp_path = './train_ckpts/weights.{epoch:03d}.ckpt'
# NOTE(review): `period` (save every N epochs) is deprecated in newer TF in
# favor of `save_freq` — confirm against the installed TF version.
cb = tf.keras.callbacks.ModelCheckpoint(cp_path, save_weights_only=True, period=5)
model.fit(ds, epochs=50, callbacks=[cb])
# Restore most recent
latest = tf.train.latest_checkpoint('./train_ckpts')
model.load_weights(latest)
Custom objects when saving/loading H5 models:
import tensorflow as tf
from tensorflow import keras

@tf.function
def my_metric(y_true, y_pred):
    # Mean absolute error.
    return tf.reduce_mean(tf.abs(y_true - y_pred))

inp = keras.Input(shape=(10,))
out = keras.layers.Lambda(lambda t: tf.square(t))(inp)
net = keras.Model(inp, out)
net.compile(optimizer='adam', loss='mse', metrics=[my_metric])
net.save('./models/custom.h5')

# H5 files record custom objects by NAME only, so the implementation must be
# supplied again at load time via custom_objects.
restored = keras.models.load_model('./models/custom.h5',
                                   custom_objects={'my_metric': my_metric})
Object-based checkpoints with tf.train.Checkpoint (TF 2.x)
Object-based checkpoints track Python object dependencies and are resilient to code refactors:
import tensorflow as tf


class Toy(tf.keras.Model):
    """Minimal model: a single Dense layer."""

    def __init__(self):
        super().__init__()
        self.d1 = tf.keras.layers.Dense(8)

    def call(self, x):
        return self.d1(x)


model = Toy()
opt = tf.keras.optimizers.Adam(1e-3)

# The Checkpoint tracks OBJECTS (model, optimizer) rather than variable names,
# so refactoring variable names does not break restores.
ckpt = tf.train.Checkpoint(optimizer=opt, model=model)
manager = tf.train.CheckpointManager(ckpt, './ob_ckpts', max_to_keep=3)


@tf.function
def train_step(x):
    with tf.GradientTape() as tape:
        y = model(x)
        loss = tf.reduce_mean(tf.square(y))
    # tape.gradient must be called after exiting the tape context.
    grads = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))


# Attempt restore, then train and save
# latest_checkpoint returns None on a cold start; restore(None) is a no-op.
ckpt.restore(tf.train.latest_checkpoint('./ob_ckpts'))
for _ in range(100):
    train_step(tf.random.normal([32, 4]))
manager.save()
TF 1.x: naming and restoring by variable name
Give variables stable names to fetch by name after restoring a MetaGraph:
import tensorflow as tf

v = tf.get_variable('weights_1', shape=[5, 5])
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './byname/model')

# Later (typically in a new process / fresh graph): import the graph
# structure from the .meta file, then restore the variable values.
saver = tf.train.import_meta_graph('./byname/model.meta')
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('./byname'))
    # The stable variable name lets us fetch it by '<name>:0'.
    val = sess.run('weights_1:0')
MonitoredTrainingSession (TF 1.x)
MonitoredTrainingSession will handle checkpointing automatically when given a directory:
import tensorflow as tf

X = tf.placeholder(tf.float32, [None, 4])
Y = tf.layers.dense(X, 2)
loss = tf.reduce_mean(tf.square(Y))
# A global_step is required for StopAtStepHook (and is what the automatic
# checkpointing records); minimize() increments it every step.
global_step = tf.train.get_or_create_global_step()
train = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)

# Without a stop hook, should_stop() never becomes True and the loop below
# runs forever; StopAtStepHook gives it a termination condition.
hooks = [tf.train.StopAtStepHook(last_step=1000)]
with tf.train.MonitoredTrainingSession(checkpoint_dir='./mchk',
                                       hooks=hooks) as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train, feed_dict={X: [[0, 1, 2, 3]]})
Common pitfalls
- For TF 1.x, saving captures variable values; you must rebuild the same graph or import the MetaGraph before restoring.
- Use tf.train.latest_checkpoint(dir) to pick the newest checkpoint in a directory.
- Ensure unique, explicit names for variables/ops if exporting/importing graphs.
- For cross-graph restores, supply var_list mapping to tf.train.Saver.
- Prefer SavedModel for serving/export; prefer tf.train.Checkpoint for TF 2.x training workflows.