Implementing the Human Motion Diffusion Model for Text-to-Motion Generation

Dependencies and Setup

Install the required Python package:

pip install spacy

Dataset Preparation

Clone the HumanML3D dataset repository from GitHub. Move the HumanML3D directory into your project's dataset folder. Extract the texts.zip archive into a texts subdirectory within the dataset folder.
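
After cloning, the move and extraction steps can be scripted. Here is a minimal Python sketch, assuming the clone sits next to your project and that texts.zip is inside the moved HumanML3D directory (both paths are assumptions; adjust them to your layout):

import shutil
import zipfile
from pathlib import Path

cloned_repo = Path('HumanML3D')   # the cloned dataset repository (assumed location)
dataset_dir = Path('dataset')     # your project's dataset folder (assumed location)

# Move the HumanML3D data directory into the project's dataset folder.
target = dataset_dir / 'HumanML3D'
if not target.exists():
    shutil.move(str(cloned_repo / 'HumanML3D'), str(target))

# Extract texts.zip into a texts subdirectory, as described above.
with zipfile.ZipFile(target / 'texts.zip') as archive:
    archive.extractall(target / 'texts')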

Online Demo

A hosted API for the model is available on Replicate.

Inference Script for Batch Generation

The following script loads a pre-trained model and generates motion sequences from text prompts.

import os
import numpy as np
import torch
import shutil
from utils.fixseed import fixseed
from utils.parser_util import generate_args
from utils.model_util import create_model_and_diffusion, load_model_wo_clip
from utils import dist_util
from model.cfg_sampler import ClassifierFreeSampleModel
from data_loaders.get_data import get_dataset_loader
from data_loaders.humanml.scripts.motion_process import recover_from_ric
import data_loaders.humanml.utils.paramUtil as paramUtil
from data_loaders.humanml.utils.plot_script import plot_3d_motion
from data_loaders.tensors import collate

def generate_motion_batch():
    args = generate_args()
    fixseed(args.seed)
    
    # Configuration
    args.text_prompt = ''
    args.input_text = 'assets/example_text_prompts.txt'
    args.model_path = 'humanml_trans_enc_512/model000200000.pt'
    output_dir = args.output_dir
    
    dataset_name = args.dataset
    max_seq_len = 196 if dataset_name in ['kit', 'humanml'] else 60
    frame_rate = 12.5 if dataset_name == 'kit' else 20
    num_frames = min(max_seq_len, int(args.motion_length * frame_rate))
    
    dist_util.setup_dist(args.device)
    
    # Determine output directory
    if output_dir == '':
        model_dir = os.path.dirname(args.model_path)
        iter_tag = os.path.basename(args.model_path).replace('model', '').replace('.pt', '')
        output_dir = os.path.join(model_dir, f'samples_{iter_tag}_seed{args.seed}')
        if args.text_prompt:
            output_dir += '_' + args.text_prompt.replace(' ', '_').replace('.', '')
        elif args.input_text:
            fname = os.path.basename(args.input_text).replace('.txt', '').replace(' ', '_').replace('.', '')
            output_dir += f'_{fname}'
    
    # Load text prompts
    prompt_list = []
    if args.text_prompt:
        prompt_list = [args.text_prompt]
        args.num_samples = 1
    elif args.input_text:
        with open(args.input_text, 'r') as f:
            prompt_list = [line.strip() for line in f.readlines()]
        args.num_samples = len(prompt_list)
    elif args.action_name:
        prompt_list = [args.action_name]
        args.num_samples = 1
    elif args.action_file:
        with open(args.action_file, 'r') as f:
            prompt_list = [line.strip() for line in f.readlines()]
        args.num_samples = len(prompt_list)
    
    assert args.num_samples <= args.batch_size, 'Reduce num_samples or increase batch_size'
    args.batch_size = args.num_samples
    
    print('Loading dataset...')
    data_loader = get_dataset_loader(name=dataset_name,
                                     batch_size=args.batch_size,
                                     num_frames=max_seq_len,
                                     split='test',
                                     hml_mode='text_only')
    if dataset_name in ['kit', 'humanml']:
        data_loader.dataset.t2m_dataset.fixed_length = num_frames
    
    total_samples = args.num_samples * args.num_repetitions
    
    print('Creating model and diffusion...')
    network, diffusion_process = create_model_and_diffusion(args, data_loader)
    
    print(f'Loading checkpoint from {args.model_path}...')
    checkpoint = torch.load(args.model_path, map_location='cpu')
    load_model_wo_clip(network, checkpoint)
    
    if args.guidance_param != 1:
        # Wrap the network so sampling blends conditional and unconditional predictions (classifier-free guidance).
        network = ClassifierFreeSampleModel(network)
    network.to(dist_util.dev())
    network.eval()
    
    # Prepare model inputs
    if not any([args.input_text, args.text_prompt, args.action_file, args.action_name]):
        data_iter = iter(data_loader)
        _, model_inputs = next(data_iter)
    else:
        base_args = [{'inp': torch.zeros(num_frames), 'tokens': None, 'lengths': num_frames}] * args.num_samples
        is_text_to_motion = any([args.input_text, args.text_prompt])
        if is_text_to_motion:
            collate_args = [dict(arg, text=txt) for arg, txt in zip(base_args, prompt_list)]
        else:
            action_data = data_loader.dataset.action_name_to_action(prompt_list)
            collate_args = [dict(arg, action=act, action_text=txt) for arg, act, txt in zip(base_args, action_data, prompt_list)]
        _, model_inputs = collate(collate_args)
    
    all_generated_motions = []
    all_seq_lengths = []
    all_prompts_used = []
    
    for repetition in range(args.num_repetitions):
        print(f'Generating batch #{repetition}')
        
        if args.guidance_param != 1:
            model_inputs['y']['scale'] = torch.ones(args.batch_size, device=dist_util.dev()) * args.guidance_param
        
        # p_sample_loop runs the full reverse diffusion chain, from Gaussian noise to a clean motion batch.
        sampling_function = diffusion_process.p_sample_loop
        
        generated_sample = sampling_function(
            network,
            (args.batch_size, network.njoints, network.nfeats, max_seq_len),
            clip_denoised=False,
            model_kwargs=model_inputs,
            skip_timesteps=0,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )
        
        # Convert representation to 3D coordinates
        if network.data_rep == 'hml_vec':
            num_joints = 22 if generated_sample.shape[1] == 263 else 21
            generated_sample = data_loader.dataset.t2m_dataset.inv_transform(generated_sample.cpu().permute(0, 2, 3, 1)).float()
            generated_sample = recover_from_ric(generated_sample, num_joints)
            generated_sample = generated_sample.view(-1, *generated_sample.shape[2:]).permute(0, 2, 3, 1)
        
        pose_representation = 'xyz' if network.data_rep in ['xyz', 'hml_vec'] else network.data_rep
        mask_for_conversion = None if pose_representation == 'xyz' else model_inputs['y']['mask'].reshape(args.batch_size, num_frames).bool()
        generated_sample = network.rot2xyz(x=generated_sample, mask=mask_for_conversion, pose_rep=pose_representation, glob=True, translation=True,
                                           jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None,
                                           get_rotations_back=False)
        
        if args.unconstrained:
            all_prompts_used += ['unconstrained'] * args.num_samples
        else:
            key_for_text = 'text' if 'text' in model_inputs['y'] else 'action_text'
            all_prompts_used += model_inputs['y'][key_for_text]
        
        all_generated_motions.append(generated_sample.cpu().numpy())
        all_seq_lengths.append(model_inputs['y']['lengths'].cpu().numpy())
        
        print(f'Generated {len(all_generated_motions) * args.batch_size} samples')
    
    # Concatenate and trim results
    all_generated_motions = np.concatenate(all_generated_motions, axis=0)[:total_samples]
    all_prompts_used = all_prompts_used[:total_samples]
    all_seq_lengths = np.concatenate(all_seq_lengths, axis=0)[:total_samples]
    
    # Save outputs
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)
    
    np_output_path = os.path.join(output_dir, 'results.npy')
    print(f'Saving results to {np_output_path}')
    np.save(np_output_path,
            {'motion': all_generated_motions, 'text': all_prompts_used, 'lengths': all_seq_lengths,
             'num_samples': args.num_samples, 'num_repetitions': args.num_repetitions})
    
    with open(np_output_path.replace('.npy', '.txt'), 'w') as f:
        f.write('\n'.join(all_prompts_used))
    with open(np_output_path.replace('.npy', '_len.txt'), 'w') as f:
        f.write('\n'.join([str(l) for l in all_seq_lengths]))
    
    # Visualize motions
    print(f'Saving visualizations to {output_dir}')
    skeleton_structure = paramUtil.kit_kinematic_chain if dataset_name == 'kit' else paramUtil.t2m_kinematic_chain
    
    for sample_idx in range(args.num_samples):
        for rep_idx in range(args.num_repetitions):
            caption_text = all_prompts_used[rep_idx * args.batch_size + sample_idx]
            seq_len = all_seq_lengths[rep_idx * args.batch_size + sample_idx]
            motion_data = all_generated_motions[rep_idx * args.batch_size + sample_idx].transpose(2, 0, 1)[:seq_len]
            save_filename = f'sample{sample_idx:02d}_rep{rep_idx:02d}.mp4'
            save_path = os.path.join(output_dir, save_filename)
            plot_3d_motion(save_path, skeleton_structure, motion_data, dataset=dataset_name, title=caption_text, fps=frame_rate)
    
    abs_output_path = os.path.abspath(output_dir)
    print(f'[Complete] Results saved at {abs_output_path}')

if __name__ == '__main__':
    generate_motion_batch()
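
Since results.npy stores a pickled Python dictionary, it has to be loaded with allow_pickle=True and unwrapped with .item(). A minimal sketch for inspecting a run's output (the path is an assumption; use the directory the script prints on completion):

import numpy as np

# Adjust the path to the results.npy written by the script above.
results = np.load('results.npy', allow_pickle=True).item()

print(results['num_samples'], 'prompts x', results['num_repetitions'], 'repetitions')
print('motion array shape:', results['motion'].shape)  # (samples * repetitions, joints, 3, frames)

# 'lengths' holds the valid frame count for each generated sequence.
for prompt, length in zip(results['text'], results['lengths']):
    print(f'{length:4d} frames: {prompt}')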

Mesh Rendering from Stick Figures

This script converts generated stick figure animations into 3D mesh representations.

import argparse
import os
import shutil
from tqdm import tqdm
from visualize import vis_utils

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, required=True, help='Path to input stick figure MP4 file.')
    parser.add_argument('--cuda', type=lambda s: s.lower() not in ('0', 'false'), default=True,
                        help='Use GPU; a plain type=bool would treat any non-empty string, including "False", as True.')
    parser.add_argument('--device', type=int, default=0)
    config = parser.parse_args()
    
    assert config.input_path.endswith('.mp4')
    base_name = os.path.basename(config.input_path).replace('.mp4', '').replace('sample', '').replace('rep', '')
    sample_index, rep_index = [int(x) for x in base_name.split('_')]
    
    npy_file = os.path.join(os.path.dirname(config.input_path), 'results.npy')
    output_npy = config.input_path.replace('.mp4', '_smpl_params.npy')
    output_dir = config.input_path.replace('.mp4', '_obj')
    
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)
    
    # npy2obj loads the joints stored in results.npy for the chosen sample/repetition and builds per-frame meshes.
    converter = vis_utils.npy2obj(npy_file, sample_index, rep_index,
                                  device=config.device, cuda=config.cuda)
    
    print(f'Saving OBJ files to {os.path.abspath(output_dir)}')
    for frame in tqdm(range(converter.real_num_frames)):
        converter.save_obj(os.path.join(output_dir, f'frame{frame:03d}.obj'), frame)
    
    print(f'Saving SMPL parameters to {os.path.abspath(output_npy)}')
    converter.save_npy(output_npy)
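
Assuming the script above is saved as render_mesh.py (the filename is an assumption), point it at one of the stick-figure videos produced by the generation step; results.npy must sit in the same directory, since the sample and repetition indices are parsed from the video's file name:

python render_mesh.py --input_path output_dir/sample00_rep00.mp4

Here output_dir stands for the directory the generation script printed on completion.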

Modified Visualization for Windows Systems

The original visualization code clears each frame by assigning empty lists directly to ax.lines and ax.collections, which raises an error on the newer Matplotlib releases commonly installed on Windows. The adapted version below removes the artists one at a time instead.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D
from textwrap import wrap

def visualize_motion_sequence(output_file, skeleton, joint_positions, title_text, dataset_name, fig_dim=(3, 3), frame_rate=120, axis_scale=3):
    fig = plt.figure(figsize=fig_dim)
    ax = fig.add_subplot(111, projection='3d')
    plt.tight_layout()
    
    title_text = '\n'.join(wrap(title_text, 20))
    
    color_palette = ["#DD5A37", "#D69E00", "#B75A39", "#FF6D00", "#DDB50E"]
    
    def initialize_plot():
        ax.set_xlim3d([-axis_scale / 2, axis_scale / 2])
        ax.set_ylim3d([0, axis_scale])
        ax.set_zlim3d([-axis_scale / 3., axis_scale * 2 / 3.])
        fig.suptitle(title_text, fontsize=10)
        ax.grid(False)
        return fig,
    
    def update_frame(frame_idx):
        # Clear previous frame drawings
        while len(ax.lines) > 0:
            ax.lines[0].remove()
        while len(ax.collections) > 0:
            ax.collections[0].remove()
        
        ax.view_init(elev=120, azim=-90)
        ax.dist = 7.5
        
        # Draw skeleton for current frame
        for chain_idx, (bone_chain, color) in enumerate(zip(skeleton, color_palette)):
            line_width = 4.0 if chain_idx < 5 else 2.0
            ax.plot3D(joint_data[frame_idx, bone_chain, 0],
                      joint_data[frame_idx, bone_chain, 1],
                      joint_data[frame_idx, bone_chain, 2],
                      linewidth=line_width, color=color)
        
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])
        
        return fig,
    
    joint_data = joint_positions.copy().reshape(len(joint_positions), -1, 3)
    total_frames = joint_data.shape[0]
    
    anim = FuncAnimation(fig, update_frame, frames=total_frames, interval=1000 / frame_rate, repeat=False, init_func=initialize_plot)
    anim.save(output_file, fps=frame_rate)
    plt.close()

# Example usage
if __name__ == '__main__':
    skeleton_def = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
    loaded_data = np.load('motion_data.npz', allow_pickle=True)
    joint_pos = loaded_data['joints_3d'].item()['data']
    joint_pos /= 20  # Scale adjustment
    
    visualize_motion_sequence('output_animation.mp4', skeleton_def, joint_pos, 'A walking motion', 'humanml')
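
To connect this with the batch-generation output: each motion in results.npy is stored as (joints, 3, frames), so it has to be transposed to (frames, joints, 3) and trimmed to its valid length before plotting, exactly as the inference script does. A minimal sketch, assuming a HumanML3D result (the results.npy path is an assumption):

import numpy as np

t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]

results = np.load('results.npy', allow_pickle=True).item()
motion = results['motion'][0].transpose(2, 0, 1)[:results['lengths'][0]]  # (frames, 22, 3)

# HumanML3D motions run at 20 fps (see the inference script above).
visualize_motion_sequence('sample00.mp4', t2m_kinematic_chain, motion, results['text'][0], 'humanml', frame_rate=20)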

