Implementing the Human Motion Diffusion Model for Text-to-Motion Generation
Dependencies and Setup
Install the required Python package:
pip install spacy
Dataset Preparation
Clone the HumanML3D dataset repository from GitHub. Move the HumanML3D directory into your project's dataset folder. Extract the texts.zip archive into a texts subdirectory within the dataset folder.
Online Demo
A hosted API for the model is available on Replicate.
Inference Script for Batch Generation
The following script loads a pre-trained model and generates motion sequences from text prompts.
import os
import numpy as np
import torch
import shutil
from utils.fixseed import fixseed
from utils.parser_util import generate_args
from utils.model_util import create_model_and_diffusion, load_model_wo_clip
from utils import dist_util
from model.cfg_sampler import ClassifierFreeSampleModel
from data_loaders.get_data import get_dataset_loader
from data_loaders.humanml.scripts.motion_process import recover_from_ric
import data_loaders.humanml.utils.paramUtil as paramUtil
from data_loaders.humanml.utils.plot_script import plot_3d_motion
from data_loaders.tensors import collate
def generate_motion_batch():
    """Sample motion sequences from a pre-trained text-to-motion diffusion
    model, save them as a .npy archive plus text files, and render one
    stick-figure .mp4 per (sample, repetition) pair.

    All configuration comes from the CLI via ``generate_args()``; outputs go
    to ``args.output_dir``, or to a directory derived from the checkpoint
    path when that is empty.
    """
    args = generate_args()
    fixseed(args.seed)  # seed RNGs so sampling is reproducible per --seed

    # Configuration.
    # NOTE(review): these three assignments override whatever was passed on
    # the command line for text_prompt / input_text / model_path.
    args.text_prompt = ''
    args.input_text = 'assets/example_text_prompts.txt'
    args.model_path = 'humanml_trans_enc_512/model000200000.pt'
    output_dir = args.output_dir
    dataset_name = args.dataset
    # HumanML3D/KIT sequences are capped at 196 frames; other datasets at 60.
    max_seq_len = 196 if dataset_name in ['kit', 'humanml'] else 60
    # KIT motions are sampled at 12.5 fps; 20 fps otherwise.
    frame_rate = 12.5 if dataset_name == 'kit' else 20
    num_frames = min(max_seq_len, int(args.motion_length * frame_rate))
    dist_util.setup_dist(args.device)

    # Determine output directory: when --output_dir was not given, derive one
    # from the checkpoint name, the seed, and the prompt source.
    if output_dir == '':
        model_dir = os.path.dirname(args.model_path)
        iter_tag = os.path.basename(args.model_path).replace('model', '').replace('.pt', '')
        output_dir = os.path.join(model_dir, f'samples_{iter_tag}_seed{args.seed}')
        if args.text_prompt:
            output_dir += '_' + args.text_prompt.replace(' ', '_').replace('.', '')
        elif args.input_text:
            fname = os.path.basename(args.input_text).replace('.txt', '').replace(' ', '_').replace('.', '')
            output_dir += f'_{fname}'

    # Load text prompts: one generated sample per prompt line / action.
    # NOTE(review): args.text_prompt is forced to '' above, so only the
    # input_text branch appears reachable here — confirm intent.
    prompt_list = []
    if args.text_prompt:
        prompt_list = [args.text_prompt]
        args.num_samples = 1
    elif args.input_text:
        with open(args.input_text, 'r') as f:
            prompt_list = [line.strip() for line in f.readlines()]
        args.num_samples = len(prompt_list)
    elif args.action_name:
        prompt_list = [args.action_name]
        args.num_samples = 1
    elif args.action_file:
        with open(args.action_file, 'r') as f:
            prompt_list = [line.strip() for line in f.readlines()]
        args.num_samples = len(prompt_list)

    # Everything is generated in a single batch, so the request must fit.
    assert args.num_samples <= args.batch_size, 'Reduce num_samples or increase batch_size'
    args.batch_size = args.num_samples

    print('Loading dataset...')
    # hml_mode='text_only': the loader is used for normalization stats and
    # text conditioning, not for iterating real motion clips.
    data_loader = get_dataset_loader(name=dataset_name,
                                     batch_size=args.batch_size,
                                     num_frames=max_seq_len,
                                     split='test',
                                     hml_mode='text_only')
    if dataset_name in ['kit', 'humanml']:
        # Force the T2M dataset to emit fixed-length sequences.
        data_loader.dataset.t2m_dataset.fixed_length = num_frames
    total_samples = args.num_samples * args.num_repetitions

    print('Creating model and diffusion...')
    network, diffusion_process = create_model_and_diffusion(args, data_loader)
    print(f'Loading checkpoint from {args.model_path}...')
    checkpoint = torch.load(args.model_path, map_location='cpu')
    load_model_wo_clip(network, checkpoint)
    if args.guidance_param != 1:
        # Wrap the model so sampling applies classifier-free guidance.
        network = ClassifierFreeSampleModel(network)
    network.to(dist_util.dev())
    network.eval()  # disable dropout etc. for sampling

    # Prepare model inputs (the conditioning dict, under key 'y' after collate).
    if not any([args.input_text, args.text_prompt, args.action_file, args.action_name]):
        # No prompts given: borrow conditioning from a real test batch.
        data_iter = iter(data_loader)
        _, model_inputs = next(data_iter)
    else:
        # Shared template dicts are safe here: dict(arg, ...) below copies them.
        base_args = [{'inp': torch.zeros(num_frames), 'tokens': None, 'lengths': num_frames}] * args.num_samples
        is_text_to_motion = any([args.input_text, args.text_prompt])
        if is_text_to_motion:
            collate_args = [dict(arg, text=txt) for arg, txt in zip(base_args, prompt_list)]
        else:
            # Action-to-motion: map action names to action tensors.
            action_data = data_loader.dataset.action_name_to_action(prompt_list)
            collate_args = [dict(arg, action=act, action_text=txt) for arg, act, txt in zip(base_args, action_data, prompt_list)]
        _, model_inputs = collate(collate_args)

    all_generated_motions = []
    all_seq_lengths = []
    all_prompts_used = []

    for repetition in range(args.num_repetitions):
        print(f'Generating batch #{repetition}')
        if args.guidance_param != 1:
            # Per-sample guidance scale consumed by ClassifierFreeSampleModel.
            model_inputs['y']['scale'] = torch.ones(args.batch_size, device=dist_util.dev()) * args.guidance_param
        sampling_function = diffusion_process.p_sample_loop
        generated_sample = sampling_function(
            network,
            (args.batch_size, network.njoints, network.nfeats, max_seq_len),
            clip_denoised=False,
            model_kwargs=model_inputs,
            skip_timesteps=0,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )

        # Convert representation to 3D coordinates
        if network.data_rep == 'hml_vec':
            # 263 feature dims -> 22 joints, else 21 — presumably the
            # HumanML3D vs. KIT skeletons; confirm against the dataset spec.
            num_joints = 22 if generated_sample.shape[1] == 263 else 21
            # De-normalize, then recover joint positions from the relative
            # (RIC) representation.
            generated_sample = data_loader.dataset.t2m_dataset.inv_transform(generated_sample.cpu().permute(0, 2, 3, 1)).float()
            generated_sample = recover_from_ric(generated_sample, num_joints)
            generated_sample = generated_sample.view(-1, *generated_sample.shape[2:]).permute(0, 2, 3, 1)

        pose_representation = 'xyz' if network.data_rep in ['xyz', 'hml_vec'] else network.data_rep
        mask_for_conversion = None if pose_representation == 'xyz' else model_inputs['y']['mask'].reshape(args.batch_size, num_frames).bool()
        # NOTE(review): for pose_rep 'xyz' this call is presumably a
        # pass-through; for rotation representations it converts to SMPL
        # joint positions — confirm against rot2xyz.
        generated_sample = network.rot2xyz(x=generated_sample, mask=mask_for_conversion, pose_rep=pose_representation, glob=True, translation=True,
                                           jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None,
                                           get_rotations_back=False)

        if args.unconstrained:
            all_prompts_used += ['unconstrained'] * args.num_samples
        else:
            key_for_text = 'text' if 'text' in model_inputs['y'] else 'action_text'
            all_prompts_used += model_inputs['y'][key_for_text]

        all_generated_motions.append(generated_sample.cpu().numpy())
        all_seq_lengths.append(model_inputs['y']['lengths'].cpu().numpy())
        print(f'Generated {len(all_generated_motions) * args.batch_size} samples')

    # Concatenate and trim results to exactly num_samples * num_repetitions.
    all_generated_motions = np.concatenate(all_generated_motions, axis=0)[:total_samples]
    all_prompts_used = all_prompts_used[:total_samples]
    all_seq_lengths = np.concatenate(all_seq_lengths, axis=0)[:total_samples]

    # Save outputs: a fresh directory each run (previous results are removed).
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)
    np_output_path = os.path.join(output_dir, 'results.npy')
    print(f'Saving results to {np_output_path}')
    np.save(np_output_path,
            {'motion': all_generated_motions, 'text': all_prompts_used, 'lengths': all_seq_lengths,
             'num_samples': args.num_samples, 'num_repetitions': args.num_repetitions})
    with open(np_output_path.replace('.npy', '.txt'), 'w') as f:
        f.write('\n'.join(all_prompts_used))
    with open(np_output_path.replace('.npy', '_len.txt'), 'w') as f:
        f.write('\n'.join([str(l) for l in all_seq_lengths]))

    # Visualize motions: one .mp4 per (sample, repetition) pair.
    print(f'Saving visualizations to {output_dir}')
    skeleton_structure = paramUtil.kit_kinematic_chain if dataset_name == 'kit' else paramUtil.t2m_kinematic_chain
    for sample_idx in range(args.num_samples):
        for rep_idx in range(args.num_repetitions):
            # Results are laid out repetition-major: index = rep * batch + sample.
            caption_text = all_prompts_used[rep_idx * args.batch_size + sample_idx]
            seq_len = all_seq_lengths[rep_idx * args.batch_size + sample_idx]
            # Per-sample (joints, feats, frames) -> (frames, joints, feats),
            # trimmed to the generated sequence length.
            motion_data = all_generated_motions[rep_idx * args.batch_size + sample_idx].transpose(2, 0, 1)[:seq_len]
            save_filename = f'sample{sample_idx:02d}_rep{rep_idx:02d}.mp4'
            save_path = os.path.join(output_dir, save_filename)
            plot_3d_motion(save_path, skeleton_structure, motion_data, dataset=dataset_name, title=caption_text, fps=frame_rate)

    abs_output_path = os.path.abspath(output_dir)
    print(f'[Complete] Results saved at {abs_output_path}')
# Script entry point: run batch generation when executed directly.
if __name__ == '__main__':
    generate_motion_batch()
Mesh Rendering from Stick Figures
This script converts generated stick figure animations into 3D mesh representations.
import argparse
import os
import shutil
from tqdm import tqdm
from visualize import vis_utils
def str2bool(value):
    """Parse a command-line boolean flag value.

    Accepts actual bools and the usual string spellings ('true'/'false',
    '1'/'0', 'yes'/'no', case-insensitive). Raises ArgumentTypeError for
    anything else so argparse reports a clean usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', '1', 'yes', 'y'):
        return True
    if lowered in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, required=True, help='Path to input stick figure MP4 file.')
    # BUG FIX: `type=bool` is broken in argparse — bool("False") is True, so
    # `--cuda False` silently enabled CUDA. str2bool parses the value properly.
    parser.add_argument('--cuda', type=str2bool, default=True)
    parser.add_argument('--device', type=int, default=0)
    config = parser.parse_args()

    assert config.input_path.endswith('.mp4')
    # File names look like sampleXX_repYY.mp4 -> recover the two indices.
    base_name = os.path.basename(config.input_path).replace('.mp4', '').replace('sample', '').replace('rep', '')
    sample_index, rep_index = [int(x) for x in base_name.split('_')]

    # results.npy (written by the generation script) sits next to the video.
    npy_file = os.path.join(os.path.dirname(config.input_path), 'results.npy')
    output_npy = config.input_path.replace('.mp4', '_smpl_params.npy')
    output_dir = config.input_path.replace('.mp4', '_obj')

    # Start from a clean output directory each run.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    converter = vis_utils.npy2obj(npy_file, sample_index, rep_index,
                                  device=config.device, cuda=config.cuda)

    print(f'Saving OBJ files to {os.path.abspath(output_dir)}')
    for frame in tqdm(range(converter.real_num_frames)):
        converter.save_obj(os.path.join(output_dir, f'frame{frame:03d}.obj'), frame)

    print(f'Saving SMPL parameters to {os.path.abspath(output_npy)}')
    converter.save_npy(output_npy)
Modified Visualization for Windows Systems
The original visualization code may not display properly on Windows. Here is an adapted version.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D
from textwrap import wrap
def visualize_motion_sequence(output_file, skeleton, joint_positions, title_text, dataset_name, fig_dim=(3, 3), frame_rate=120, axis_scale=3):
    """Render a 3D stick-figure animation of a joint trajectory to a video.

    Args:
        output_file: path of the video file to write.
        skeleton: kinematic chains, each a list of joint indices to connect.
        joint_positions: per-frame joint coordinates; each frame is reshaped
            to (num_joints, 3).
        title_text: caption shown above the plot, word-wrapped to 20 chars.
        dataset_name: unused here; kept for interface compatibility.
        fig_dim: matplotlib figure size in inches.
        frame_rate: frames per second for both playback timing and encoding.
        axis_scale: controls the fixed axis limits of the 3D view.
    """
    figure = plt.figure(figsize=fig_dim)
    axes3d = figure.add_subplot(111, projection='3d')
    plt.tight_layout()
    caption = '\n'.join(wrap(title_text, 20))
    chain_colors = ["#DD5A37", "#D69E00", "#B75A39", "#FF6D00", "#DDB50E"]

    # One (num_joints, 3) slab per frame.
    frames_xyz = joint_positions.copy().reshape(len(joint_positions), -1, 3)
    frame_count = frames_xyz.shape[0]

    def setup():
        # Fixed axis limits so the camera does not jump between frames.
        axes3d.set_xlim3d([-axis_scale / 2, axis_scale / 2])
        axes3d.set_ylim3d([0, axis_scale])
        axes3d.set_zlim3d([-axis_scale / 3., axis_scale * 2 / 3.])
        figure.suptitle(caption, fontsize=10)
        axes3d.grid(False)
        return figure,

    def draw(frame_no):
        # Drop everything drawn for the previous frame before redrawing.
        while len(axes3d.lines) > 0:
            axes3d.lines[0].remove()
        while len(axes3d.collections) > 0:
            axes3d.collections[0].remove()
        axes3d.view_init(elev=120, azim=-90)
        # NOTE(review): Axes3D.dist was removed in newer Matplotlib releases;
        # this assignment may have no effect there — confirm on the target
        # Matplotlib version.
        axes3d.dist = 7.5
        # One polyline per kinematic chain; the first five chains are thicker.
        for chain_no, (chain, chain_color) in enumerate(zip(skeleton, chain_colors)):
            stroke = 4.0 if chain_no < 5 else 2.0
            axes3d.plot3D(frames_xyz[frame_no, chain, 0],
                          frames_xyz[frame_no, chain, 1],
                          frames_xyz[frame_no, chain, 2],
                          linewidth=stroke, color=chain_color)
        plt.axis('off')
        axes3d.set_xticklabels([])
        axes3d.set_yticklabels([])
        axes3d.set_zticklabels([])
        return figure,

    animation = FuncAnimation(figure, draw, frames=frame_count,
                              interval=1000 / frame_rate, repeat=False,
                              init_func=setup)
    animation.save(output_file, fps=frame_rate)
    plt.close()
# Example usage: load saved joint positions and render them as a video.
if __name__ == '__main__':
    kinematic_chains = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
    archive = np.load('motion_data.npz', allow_pickle=True)
    positions = archive['joints_3d'].item()['data']
    positions /= 20  # Scale adjustment
    visualize_motion_sequence('output_animation.mp4', kinematic_chains, positions, 'A walking motion', 'humanml')