Fading Coder

One Final Commit for the Last Sprint


Saving and Restoring TensorFlow Models: Checkpoints, MetaGraphs, and SavedModel


Saving and restoring in TensorFlow depends on the API level you use. In graph-mode TensorFlow 1.x, checkpoints capture Variable values; graph structure can be reloaded from a MetaGraph. In TensorFlow 2.x/Keras, the recommended format is SavedModel (or H5 for Keras-only models), and object-based checkpoints via tf.train.Checkpoint are preferred over name-based ones.

TF 1.x (graph mode): checkpoints with tf.train.Saver

Minimal save/restore with the same graph definition:

import tensorflow as tf

# Graph construction
X = tf.placeholder(tf.float32, shape=[None, 3], name='X')
Y = tf.placeholder(tf.float32, shape=[None, 1], name='Y')
W = tf.Variable(tf.random_normal([3, 1]), name='W')
C = tf.Variable(tf.zeros([1]), name='C')
Y_pred = tf.add(tf.matmul(X, W), C, name='Y_pred')

loss = tf.reduce_mean(tf.square(Y_pred - Y))
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

saver = tf.train.Saver()  # saves all variables by default
ckpt_dir = '/tmp/linreg_ckpts'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(200):
        xb = [[1, 2, 3]]; yb = [[7]]
        sess.run(train_op, feed_dict={X: xb, Y: yb})
        if (step + 1) % 50 == 0:
            saver.save(sess, ckpt_dir + '/model.ckpt', global_step=step+1)

Restore into an identical graph and run inference:

import tensorflow as tf

# Same graph structure as during save
X = tf.placeholder(tf.float32, shape=[None, 3], name='X')
W = tf.Variable(tf.random_normal([3, 1]), name='W')
C = tf.Variable(tf.zeros([1]), name='C')
Y_pred = tf.add(tf.matmul(X, W), C, name='Y_pred')

saver = tf.train.Saver()
ckpt_dir = '/tmp/linreg_ckpts'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    latest = tf.train.latest_checkpoint(ckpt_dir)
    if latest:
        saver.restore(sess, latest)
    preds = sess.run('Y_pred:0', feed_dict={'X:0': [[5, 0, -1]]})

Loading graph structure from a .meta file

Saver writes a MetaGraph (.meta) that captures ops, collections, and savers. Import it, then restore variables:

import tensorflow as tf

meta_path = '/tmp/my_great_model.meta'
ckpt_path = '/tmp/my_great_model'

saver = tf.train.import_meta_graph(meta_path)
with tf.Session() as sess:
    saver.restore(sess, ckpt_path)
    g = tf.get_default_graph()
    # Fetch tensors/ops by name
    x_t = g.get_tensor_by_name('X:0')
    yhat_t = g.get_tensor_by_name('Y_pred:0')
    out = sess.run(yhat_t, feed_dict={x_t: [[0, 1, 2]]})

Inspecting checkpoints and remapping variables

When loading a checkpoint from a different model/namespace, inspect the variable names:

python -m tensorflow.python.tools.inspect_checkpoint \
  --file_name=/path/to/model.ckpt
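If you prefer to stay in Python, tf.train.list_variables (present in both 1.x and 2.x) returns the (name, shape) pairs stored under a checkpoint prefix. A small self-contained sketch — it writes a throwaway object-based checkpoint just to have something to list:

```python
import tempfile

import tensorflow as tf

# Write a small checkpoint so there is something to inspect
ckpt_dir = tempfile.mkdtemp()
ckpt = tf.train.Checkpoint(w=tf.Variable(tf.zeros([3, 1])))
path = ckpt.save(ckpt_dir + '/model.ckpt')

# List every (name, shape) pair stored under the checkpoint prefix
for name, shape in tf.train.list_variables(path):
    print(name, shape)
```

The names printed here are the checkpoint keys — exactly the strings you would use on the left-hand side of a var_list mapping.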

Map checkpoint variable names to your in-graph Variables:

import tensorflow as tf

# Build your graph/variables with explicit names
v_a = tf.get_variable('encoder/W', shape=[128, 64])
v_b = tf.get_variable('encoder/b', shape=[64])

# Map checkpoint names -> graph vars
mapping = {
    'enc_layer/weights': v_a,
    'enc_layer/biases': v_b,
}

loader = tf.train.Saver(var_list=mapping)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    loader.restore(sess, '/path/to/foreign_model.ckpt')

Fast in-memory rollback (no disk I/O)

Capture variable states in memory and restore later (useful for early stopping rollbacks):

import tensorflow as tf

g = tf.get_default_graph()
vars_ = g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [g.get_operation_by_name(v.op.name + '/Assign') for v in vars_]
feed_slots = [op.inputs[1] for op in assign_ops]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... train ...
    snapshot = sess.run(vars_)
    # ... train more ...
    # rollback
    sess.run(assign_ops, feed_dict={slot: val for slot, val in zip(feed_slots, snapshot)})

SavedModel (TF 1.x): full-graph export with signatures

Programmatic SavedModel export with explicit input/output signatures:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None], name='x')
w = tf.Variable(3.0, name='w')
b = tf.Variable(1.0, name='b')
y = tf.add(w * x, b, name='y')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    export_dir = './saved_model_v1'

    bldr = tf.saved_model.builder.SavedModelBuilder(export_dir)
    sig = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'x': tf.saved_model.utils.build_tensor_info(x)},
        outputs={'y': tf.saved_model.utils.build_tensor_info(y)},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
    )
    bldr.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig
        },
    )
    bldr.save()

Load the SavedModel and run predictions:

import tensorflow as tf

export_dir = './saved_model_v1'
sess = tf.Session()
meta = tf.saved_model.loader.load(
    sess, [tf.saved_model.tag_constants.SERVING], export_dir)
sig = meta.signature_def
x_name = sig['serving_default'].inputs['x'].name
y_name = sig['serving_default'].outputs['y'].name
x_t = sess.graph.get_tensor_by_name(x_name)
y_t = sess.graph.get_tensor_by_name(y_name)
print(sess.run(y_t, feed_dict={x_t: [2.0, 4.0]}))

Simple export in TF 1.x with tf.saved_model.simple_save:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    a = tf.placeholder(tf.float32, shape=[None, 2], name='a')
    k = tf.Variable([[2.0], [0.5]], name='k')
    b = tf.Variable([0.1], name='b')
    z = tf.add(tf.matmul(a, k), b, name='z')
    sess.run(tf.global_variables_initializer())
    tf.saved_model.simple_save(
        sess, './simple_sm', inputs={'a': a}, outputs={'z': z})

TF 1.x notes and tips

  • Prefix a relative path with "./" when restoring from the current working directory (e.g., saver.restore(sess, './model.ckpt')).
  • Initializing variables before restore is only needed when the checkpoint does not cover every variable in the graph; Saver.restore overwrites initialized values with checkpoint values.
  • Ensure variable/op names are unique and deterministic if you plan to import/export MetaGraphs.

TF 2.x and Keras: SavedModel / H5 and checkpoints

Save and load entire Keras models (architecture + weights + optimizer state):

import tensorflow as tf
from tensorflow import keras

model = keras.Sequential([
    keras.layers.Dense(64, activation='relu', input_shape=(32,)),
    keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Train ... then save
model.save('./models/my_model')         # SavedModel directory (recommended)
model.save('./models/my_model.h5')      # H5 file (Keras format)

# Load
m1 = keras.models.load_model('./models/my_model')
m2 = keras.models.load_model('./models/my_model.h5')

Weights-only save/restore:

model.save_weights('./ckpts/ckpt')
# ... later or in a new process
reconstructed = keras.models.clone_model(model)
reconstructed.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
reconstructed.load_weights('./ckpts/ckpt')

Checkpoint callback during training:

cp_path = './train_ckpts/weights.{epoch:03d}.ckpt'
# period= saves every N epochs; newer TF versions deprecate it in favor of save_freq
cb = tf.keras.callbacks.ModelCheckpoint(cp_path, save_weights_only=True, period=5)
model.fit(ds, epochs=50, callbacks=[cb])  # ds: a tf.data.Dataset of (x, y) batches
# Restore the most recent checkpoint
latest = tf.train.latest_checkpoint('./train_ckpts')
model.load_weights(latest)

Custom objects when saving/loading H5 models:

import tensorflow as tf
from tensorflow import keras

# A plain function is referenced by name in the H5 file
def my_metric(y_true, y_pred):
    return tf.reduce_mean(tf.abs(y_true - y_pred))

inp = keras.Input(shape=(10,))
out = keras.layers.Lambda(lambda t: tf.square(t))(inp)
net = keras.Model(inp, out)
net.compile(optimizer='adam', loss='mse', metrics=[my_metric])
net.save('./models/custom.h5')

restored = keras.models.load_model('./models/custom.h5',
                                   custom_objects={'my_metric': my_metric})

Object-based checkpoints with tf.train.Checkpoint (TF 2.x)

Object-based checkpoints track Python object dependencies and are resilient to code refactors:

import tensorflow as tf

class Toy(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.d1 = tf.keras.layers.Dense(8)
    def call(self, x):
        return self.d1(x)

model = Toy()
opt = tf.keras.optimizers.Adam(1e-3)
ckpt = tf.train.Checkpoint(optimizer=opt, model=model)
manager = tf.train.CheckpointManager(ckpt, './ob_ckpts', max_to_keep=3)

@tf.function
def train_step(x):
    with tf.GradientTape() as tape:
        y = model(x)
        loss = tf.reduce_mean(tf.square(y))
    grads = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))

# Attempt restore, then train and save
ckpt.restore(tf.train.latest_checkpoint('./ob_ckpts'))
for _ in range(100):
    train_step(tf.random.normal([32, 4]))
manager.save()

TF 1.x: naming and restoring by variable name

Give variables stable names to fetch by name after restoring a MetaGraph:

import tensorflow as tf

v = tf.get_variable('weights_1', shape=[5, 5])
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './byname/model')

# Later
saver = tf.train.import_meta_graph('./byname/model.meta')
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('./byname'))
    val = sess.run('weights_1:0')

MonitoredTrainingSession (TF 1.x)

MonitoredTrainingSession will handle checkpointing automatically when given a directory:

import tensorflow as tf

X = tf.placeholder(tf.float32, [None, 4])
Y = tf.layers.dense(X, 2)
loss = tf.reduce_mean(tf.square(Y))
train = tf.train.AdamOptimizer(1e-3).minimize(loss)

with tf.train.MonitoredTrainingSession(checkpoint_dir='./mchk') as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train, feed_dict={X: [[0, 1, 2, 3]]})

Common pitfalls

  • For TF 1.x, saving captures variable values; you must rebuild the same graph or import the MetaGraph before restoring.
  • Use tf.train.latest_checkpoint(dir) to pick the newest checkpoint in a directory.
  • Ensure unique, explicit names for variables/ops if exporting/importing graphs.
  • For cross-graph restores, supply var_list mapping to tf.train.Saver.
  • Prefer SavedModel for serving/export; prefer tf.train.Checkpoint for TF 2.x training workflows.
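One more guard worth knowing in TF 2.x: the status object returned by ckpt.restore can verify that the restore actually matched your objects, instead of silently restoring nothing. A minimal sketch:

```python
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()
v = tf.Variable(1.0)
ckpt = tf.train.Checkpoint(v=v)
path = ckpt.save(ckpt_dir + '/ck')

v.assign(99.0)  # diverge, then roll back from the checkpoint
status = ckpt.restore(path)
# Raises if any tracked object found no matching checkpoint entry
status.assert_existing_objects_matched()
print(float(v.numpy()))  # back to 1.0
```

Calling an assertion on the restore status turns a silent name/structure mismatch into an immediate error, which is much easier to debug than a model that trains from scratch because nothing was restored.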
