diff --git a/smplx/body_models.py b/smplx/body_models.py index f21afec..354a522 100644 --- a/smplx/body_models.py +++ b/smplx/body_models.py @@ -14,31 +14,20 @@ # # Contact: ps-license@tuebingen.mpg.de -from typing import Optional, Dict, Union import os import os.path as osp - import pickle +from collections import namedtuple +from typing import Dict, Optional, Union import numpy as np - import torch import torch.nn as nn -from .lbs import ( - lbs, vertices2landmarks, find_dynamic_lmk_idx_and_bcoords, blend_shapes) - +from .lbs import blend_shapes, find_dynamic_lmk_idx_and_bcoords, lbs, vertices2landmarks +from .utils import Array, FLAMEOutput, MANOOutput, SMPLHOutput, SMPLOutput, SMPLXOutput, Struct, Tensor, find_joint_kin_chain, to_np, to_tensor from .vertex_ids import vertex_ids as VERTEX_IDS -from .utils import ( - Struct, to_np, to_tensor, Tensor, Array, - SMPLOutput, - SMPLHOutput, - SMPLXOutput, - MANOOutput, - FLAMEOutput, - find_joint_kin_chain) from .vertex_joint_selector import VertexJointSelector -from collections import namedtuple TensorOutput = namedtuple('TensorOutput', ['vertices', 'joints', 'betas', 'expression', 'global_orient', 'body_pose', 'left_hand_pose', @@ -357,24 +346,50 @@ def forward( ''' # If no shape and pose parameters are passed along, then use the # ones from the module - global_orient = (global_orient if global_orient is not None else - self.global_orient) - body_pose = body_pose if body_pose is not None else self.body_pose - betas = betas if betas is not None else self.betas + + bs = -1 + for v in [betas, body_pose, global_orient, transl]: + if v is not None: + bs = max(bs, v.shape[0]) + bs = self.batch_size if bs < 0 else bs # when no input was given + should_expand_bs = bs > 1 and self.batch_size == 1 + should_cut_default_bs = bs < self.batch_size + assert not (should_expand_bs and should_cut_default_bs) + + if global_orient is None: + if should_expand_bs: + global_orient = self.global_orient.expand(bs, *self.global_orient.shape[1:]) + elif should_cut_default_bs: + global_orient = self.global_orient[:bs] + else: + global_orient = self.global_orient + + if body_pose is None: + if should_expand_bs: + body_pose = self.body_pose.expand(bs, *self.body_pose.shape[1:]) + elif should_cut_default_bs: + body_pose = self.body_pose[:bs] + else: + body_pose = self.body_pose + + if betas is None: + betas = self.betas + if should_expand_bs and betas.shape[0] == 1: + betas = betas.expand(bs, *betas.shape[1:]) + elif should_cut_default_bs and betas.shape[0] > bs: + betas = betas[:bs] apply_trans = transl is not None or hasattr(self, 'transl') if transl is None and hasattr(self, 'transl'): - transl = self.transl + if should_expand_bs: + transl = self.transl.expand(bs, *self.transl.shape[1:]) + elif should_cut_default_bs: + transl = self.transl[:bs] + else: + transl = self.transl full_pose = torch.cat([global_orient, body_pose], dim=1) - batch_size = max(betas.shape[0], global_orient.shape[0], - body_pose.shape[0]) - - if betas.shape[0] != batch_size: - num_repeats = int(batch_size / betas.shape[0]) - betas = betas.expand(num_repeats, -1) - vertices, joints = lbs(betas, full_pose, self.v_template, self.shapedirs, self.posedirs, self.J_regressor, self.parents, @@ -391,6 +406,7 @@ def forward( output = SMPLOutput(vertices=vertices if return_verts else None, global_orient=global_orient, + transl=transl, body_pose=body_pose, joints=joints, betas=betas, @@ -497,6 +513,7 @@ def forward( output = SMPLOutput(vertices=vertices if return_verts else None, global_orient=global_orient, + transl=transl, body_pose=body_pose, joints=joints, betas=betas, @@ -706,23 +723,65 @@ def forward( pose2rot: bool = True, **kwargs ) -> SMPLHOutput: - ''' - ''' + + bs = -1 + for v in [betas, global_orient, body_pose, left_hand_pose, right_hand_pose, transl]: + if v is not None: + bs = max(bs, v.shape[0]) + bs = self.batch_size if bs < 0 else bs # when no input was given + should_expand_bs = bs > 1 and self.batch_size == 1 + should_cut_default_bs = bs < self.batch_size + assert not (should_expand_bs and should_cut_default_bs) # If no shape and pose parameters are passed along, then use the # ones from the module - global_orient = (global_orient if global_orient is not None else - self.global_orient) - body_pose = body_pose if body_pose is not None else self.body_pose - betas = betas if betas is not None else self.betas - left_hand_pose = (left_hand_pose if left_hand_pose is not None else - self.left_hand_pose) - right_hand_pose = (right_hand_pose if right_hand_pose is not None else - self.right_hand_pose) + if global_orient is None: + if should_expand_bs: + global_orient = self.global_orient.expand(bs, *self.global_orient.shape[1:]) + elif should_cut_default_bs: + global_orient = self.global_orient[:bs] + else: + global_orient = self.global_orient + + if body_pose is None: + if should_expand_bs: + body_pose = self.body_pose.expand(bs, *self.body_pose.shape[1:]) + elif should_cut_default_bs: + body_pose = self.body_pose[:bs] + else: + body_pose = self.body_pose + + if betas is None: + if should_expand_bs: + betas = self.betas.expand(bs, *self.betas.shape[1:]) + elif should_cut_default_bs: + betas = self.betas[:bs] + else: + betas = self.betas + + if left_hand_pose is None: + if should_expand_bs: + left_hand_pose = self.left_hand_pose.expand(bs, *self.left_hand_pose.shape[1:]) + elif should_cut_default_bs: + left_hand_pose = self.left_hand_pose[:bs] + else: + left_hand_pose = self.left_hand_pose + + if right_hand_pose is None: + if should_expand_bs: + right_hand_pose = self.right_hand_pose.expand(bs, *self.right_hand_pose.shape[1:]) + elif should_cut_default_bs: + right_hand_pose = self.right_hand_pose[:bs] + else: + right_hand_pose = self.right_hand_pose apply_trans = transl is not None or hasattr(self, 'transl') - if transl is None: - if hasattr(self, 'transl'): + if transl is None and hasattr(self, 'transl'): + if should_expand_bs: + transl = self.transl.expand(bs, *self.transl.shape[1:]) + elif should_cut_default_bs: + transl = self.transl[:bs] + else: transl = self.transl if self.use_pca: @@ -754,6 +813,7 @@ def forward( joints=joints, betas=betas, global_orient=global_orient, + transl=transl, body_pose=body_pose, left_hand_pose=left_hand_pose, right_hand_pose=right_hand_pose, @@ -883,6 +943,7 @@ def forward( joints=joints, betas=betas, global_orient=global_orient, + transl=transl, body_pose=body_pose, left_hand_pose=left_hand_pose, right_hand_pose=right_hand_pose, @@ -1188,26 +1249,96 @@ def forward( output: ModelOutput A named tuple of type `ModelOutput` ''' + bs = -1 + for v in [betas, global_orient, body_pose, left_hand_pose, right_hand_pose, transl, expression, jaw_pose, leye_pose, reye_pose]: + if v is not None: + bs = max(bs, v.shape[0]) + bs = self.batch_size if bs < 0 else bs # when no input was given + should_expand_bs = bs > 1 and self.batch_size == 1 + should_cut_default_bs = bs < self.batch_size + assert not (should_expand_bs and should_cut_default_bs) # If no shape and pose parameters are passed along, then use the # ones from the module - global_orient = (global_orient if global_orient is not None else - self.global_orient) - body_pose = body_pose if body_pose is not None else self.body_pose - betas = betas if betas is not None else self.betas + if global_orient is None: + if should_expand_bs: + global_orient = self.global_orient.expand(bs, *self.global_orient.shape[1:]) + elif should_cut_default_bs: + global_orient = self.global_orient[:bs] + else: + global_orient = self.global_orient - left_hand_pose = (left_hand_pose if left_hand_pose is not None else - self.left_hand_pose) - right_hand_pose = (right_hand_pose if right_hand_pose is not None else - self.right_hand_pose) - jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose - leye_pose = leye_pose if leye_pose is not None else self.leye_pose - reye_pose = reye_pose if reye_pose is not None else self.reye_pose - expression = expression if expression is not None else self.expression + if body_pose is None: + if should_expand_bs: + body_pose = self.body_pose.expand(bs, *self.body_pose.shape[1:]) + elif should_cut_default_bs: + body_pose = self.body_pose[:bs] + else: + body_pose = self.body_pose + + if betas is None: + if should_expand_bs: + betas = self.betas.expand(bs, *self.betas.shape[1:]) + elif should_cut_default_bs: + betas = self.betas[:bs] + else: + betas = self.betas + + if left_hand_pose is None: + if should_expand_bs: + left_hand_pose = self.left_hand_pose.expand(bs, *self.left_hand_pose.shape[1:]) + elif should_cut_default_bs: + left_hand_pose = self.left_hand_pose[:bs] + else: + left_hand_pose = self.left_hand_pose + + if right_hand_pose is None: + if should_expand_bs: + right_hand_pose = self.right_hand_pose.expand(bs, *self.right_hand_pose.shape[1:]) + elif should_cut_default_bs: + right_hand_pose = self.right_hand_pose[:bs] + else: + right_hand_pose = self.right_hand_pose + + if expression is None: + if should_expand_bs: + expression = self.expression.expand(bs, *self.expression.shape[1:]) + elif should_cut_default_bs: + expression = self.expression[:bs] + else: + expression = self.expression + + if jaw_pose is None: + if should_expand_bs: + jaw_pose = self.jaw_pose.expand(bs, *self.jaw_pose.shape[1:]) + elif should_cut_default_bs: + jaw_pose = self.jaw_pose[:bs] + else: + jaw_pose = self.jaw_pose + + if leye_pose is None: + if should_expand_bs: + leye_pose = self.leye_pose.expand(bs, *self.leye_pose.shape[1:]) + elif should_cut_default_bs: + leye_pose = self.leye_pose[:bs] + else: + leye_pose = self.leye_pose + + if reye_pose is None: + if should_expand_bs: + reye_pose = self.reye_pose.expand(bs, *self.reye_pose.shape[1:]) + elif should_cut_default_bs: + reye_pose = self.reye_pose[:bs] + else: + reye_pose = self.reye_pose apply_trans = transl is not None or hasattr(self, 'transl') - if transl is None: - if hasattr(self, 'transl'): + if transl is None and hasattr(self, 'transl'): + if should_expand_bs: + transl = self.transl.expand(bs, *self.transl.shape[1:]) + elif should_cut_default_bs: + transl = self.transl[:bs] + else: transl = self.transl if self.use_pca: @@ -1496,7 +1627,7 @@ def forward( left_hand_pose=left_hand_pose, right_hand_pose=right_hand_pose, jaw_pose=jaw_pose, - transl=transl if transl != None else Tensor(0), + transl=transl if transl is not None else Tensor(0), full_pose=full_pose if return_full_pose else Tensor(0)) return output @@ -1707,6 +1838,7 @@ def forward( joints=joints if return_verts else None, betas=betas, global_orient=global_orient, + transl=transl, hand_pose=hand_pose, full_pose=full_pose if return_full_pose else None) @@ -1773,6 +1905,7 @@ def forward( joints=joints if return_verts else None, betas=betas, global_orient=global_orient, + transl=transl, hand_pose=hand_pose, full_pose=full_pose if return_full_pose else None) @@ -2139,6 +2272,7 @@ def forward( betas=betas, expression=expression, global_orient=global_orient, + transl=transl, neck_pose=neck_pose, jaw_pose=jaw_pose, full_pose=full_pose if return_full_pose else None) @@ -2288,6 +2422,7 @@ def forward( betas=betas, expression=expression, global_orient=global_orient, + transl=transl, neck_pose=neck_pose, jaw_pose=jaw_pose, full_pose=full_pose if return_full_pose else None)