Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 189 additions & 54 deletions smplx/body_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,20 @@
#
# Contact: ps-license@tuebingen.mpg.de

from typing import Optional, Dict, Union
import os
import os.path as osp

import pickle
from collections import namedtuple
from typing import Dict, Optional, Union

import numpy as np

import torch
import torch.nn as nn

from .lbs import (
lbs, vertices2landmarks, find_dynamic_lmk_idx_and_bcoords, blend_shapes)

from .lbs import blend_shapes, find_dynamic_lmk_idx_and_bcoords, lbs, vertices2landmarks
from .utils import Array, FLAMEOutput, MANOOutput, SMPLHOutput, SMPLOutput, SMPLXOutput, Struct, Tensor, find_joint_kin_chain, to_np, to_tensor
from .vertex_ids import vertex_ids as VERTEX_IDS
from .utils import (
Struct, to_np, to_tensor, Tensor, Array,
SMPLOutput,
SMPLHOutput,
SMPLXOutput,
MANOOutput,
FLAMEOutput,
find_joint_kin_chain)
from .vertex_joint_selector import VertexJointSelector
from collections import namedtuple

TensorOutput = namedtuple('TensorOutput',
['vertices', 'joints', 'betas', 'expression', 'global_orient', 'body_pose', 'left_hand_pose',
Expand Down Expand Up @@ -357,24 +346,50 @@ def forward(
'''
# If no shape and pose parameters are passed along, then use the
# ones from the module
global_orient = (global_orient if global_orient is not None else
self.global_orient)
body_pose = body_pose if body_pose is not None else self.body_pose
betas = betas if betas is not None else self.betas

bs = -1
for v in [betas, body_pose, global_orient, transl]:
if v is not None:
bs = max(bs, v.shape[0])
bs = self.batch_size if bs < 0 else bs # when no input was given
should_expand_bs = bs > 1 and self.batch_size == 1
should_cut_default_bs = bs < self.batch_size
assert not (should_expand_bs and should_cut_default_bs)

if global_orient is None:
if should_expand_bs:
global_orient = self.global_orient.expand(bs, *self.global_orient.shape[1:])
elif should_cut_default_bs:
global_orient = self.global_orient[:bs]
else:
global_orient = self.global_orient

if body_pose is None:
if should_expand_bs:
body_pose = self.body_pose.expand(bs, *self.body_pose.shape[1:])
elif should_cut_default_bs:
body_pose = self.body_pose[:bs]
else:
body_pose = self.body_pose

if betas is None:
betas = self.betas
if should_expand_bs and betas.shape[0] == 1:
betas = betas.expand(bs, *betas.shape[1:])
elif should_cut_default_bs and betas.shape[0] > bs:
betas = betas[:bs]

apply_trans = transl is not None or hasattr(self, 'transl')
if transl is None and hasattr(self, 'transl'):
transl = self.transl
if should_expand_bs:
transl = self.transl.expand(bs, *self.transl.shape[1:])
elif should_cut_default_bs:
transl = self.transl[:bs]
else:
transl = self.transl

full_pose = torch.cat([global_orient, body_pose], dim=1)

batch_size = max(betas.shape[0], global_orient.shape[0],
body_pose.shape[0])

if betas.shape[0] != batch_size:
num_repeats = int(batch_size / betas.shape[0])
betas = betas.expand(num_repeats, -1)

vertices, joints = lbs(betas, full_pose, self.v_template,
self.shapedirs, self.posedirs,
self.J_regressor, self.parents,
Expand All @@ -391,6 +406,7 @@ def forward(

output = SMPLOutput(vertices=vertices if return_verts else None,
global_orient=global_orient,
transl=transl,
body_pose=body_pose,
joints=joints,
betas=betas,
Expand Down Expand Up @@ -497,6 +513,7 @@ def forward(

output = SMPLOutput(vertices=vertices if return_verts else None,
global_orient=global_orient,
transl=transl,
body_pose=body_pose,
joints=joints,
betas=betas,
Expand Down Expand Up @@ -706,23 +723,65 @@ def forward(
pose2rot: bool = True,
**kwargs
) -> SMPLHOutput:
'''
'''

bs = -1
for v in [betas, global_orient, body_pose, left_hand_pose, right_hand_pose, transl]:
if v is not None:
bs = max(bs, v.shape[0])
bs = self.batch_size if bs < 0 else bs # when no input was given
should_expand_bs = bs > 1 and self.batch_size == 1
should_cut_default_bs = bs < self.batch_size
assert not (should_expand_bs and should_cut_default_bs)

# If no shape and pose parameters are passed along, then use the
# ones from the module
global_orient = (global_orient if global_orient is not None else
self.global_orient)
body_pose = body_pose if body_pose is not None else self.body_pose
betas = betas if betas is not None else self.betas
left_hand_pose = (left_hand_pose if left_hand_pose is not None else
self.left_hand_pose)
right_hand_pose = (right_hand_pose if right_hand_pose is not None else
self.right_hand_pose)
if global_orient is None:
if should_expand_bs:
global_orient = self.global_orient.expand(bs, *self.global_orient.shape[1:])
elif should_cut_default_bs:
global_orient = self.global_orient[:bs]
else:
global_orient = self.global_orient

if body_pose is None:
if should_expand_bs:
body_pose = self.body_pose.expand(bs, *self.body_pose.shape[1:])
elif should_cut_default_bs:
body_pose = self.body_pose[:bs]
else:
body_pose = self.body_pose

if betas is None:
if should_expand_bs:
betas = self.betas.expand(bs, *self.betas.shape[1:])
elif should_cut_default_bs:
betas = self.betas[:bs]
else:
betas = self.betas

if left_hand_pose is None:
if should_expand_bs:
left_hand_pose = self.left_hand_pose.expand(bs, *self.left_hand_pose.shape[1:])
elif should_cut_default_bs:
left_hand_pose = self.left_hand_pose[:bs]
else:
left_hand_pose = self.left_hand_pose

if right_hand_pose is None:
if should_expand_bs:
right_hand_pose = self.right_hand_pose.expand(bs, *self.right_hand_pose.shape[1:])
elif should_cut_default_bs:
right_hand_pose = self.right_hand_pose[:bs]
else:
right_hand_pose = self.right_hand_pose

apply_trans = transl is not None or hasattr(self, 'transl')
if transl is None:
if hasattr(self, 'transl'):
if transl is None and hasattr(self, 'transl'):
if should_expand_bs:
transl = self.transl.expand(bs, *self.transl.shape[1:])
elif should_cut_default_bs:
transl = self.transl[:bs]
else:
transl = self.transl

if self.use_pca:
Expand Down Expand Up @@ -754,6 +813,7 @@ def forward(
joints=joints,
betas=betas,
global_orient=global_orient,
transl=transl,
body_pose=body_pose,
left_hand_pose=left_hand_pose,
right_hand_pose=right_hand_pose,
Expand Down Expand Up @@ -883,6 +943,7 @@ def forward(
joints=joints,
betas=betas,
global_orient=global_orient,
transl=transl,
body_pose=body_pose,
left_hand_pose=left_hand_pose,
right_hand_pose=right_hand_pose,
Expand Down Expand Up @@ -1188,26 +1249,96 @@ def forward(
output: ModelOutput
A named tuple of type `ModelOutput`
'''
bs = -1
for v in [betas, global_orient, body_pose, left_hand_pose, right_hand_pose, transl, expression, jaw_pose, leye_pose, reye_pose]:
if v is not None:
bs = max(bs, v.shape[0])
bs = self.batch_size if bs < 0 else bs # when no input was given
should_expand_bs = bs > 1 and self.batch_size == 1
should_cut_default_bs = bs < self.batch_size
assert not (should_expand_bs and should_cut_default_bs)

# If no shape and pose parameters are passed along, then use the
# ones from the module
global_orient = (global_orient if global_orient is not None else
self.global_orient)
body_pose = body_pose if body_pose is not None else self.body_pose
betas = betas if betas is not None else self.betas
if global_orient is None:
if should_expand_bs:
global_orient = self.global_orient.expand(bs, *self.global_orient.shape[1:])
elif should_cut_default_bs:
global_orient = self.global_orient[:bs]
else:
global_orient = self.global_orient

left_hand_pose = (left_hand_pose if left_hand_pose is not None else
self.left_hand_pose)
right_hand_pose = (right_hand_pose if right_hand_pose is not None else
self.right_hand_pose)
jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose
leye_pose = leye_pose if leye_pose is not None else self.leye_pose
reye_pose = reye_pose if reye_pose is not None else self.reye_pose
expression = expression if expression is not None else self.expression
if body_pose is None:
if should_expand_bs:
body_pose = self.body_pose.expand(bs, *self.body_pose.shape[1:])
elif should_cut_default_bs:
body_pose = self.body_pose[:bs]
else:
body_pose = self.body_pose

if betas is None:
if should_expand_bs:
betas = self.betas.expand(bs, *self.betas.shape[1:])
elif should_cut_default_bs:
betas = self.betas[:bs]
else:
betas = self.betas

if left_hand_pose is None:
if should_expand_bs:
left_hand_pose = self.left_hand_pose.expand(bs, *self.left_hand_pose.shape[1:])
elif should_cut_default_bs:
left_hand_pose = self.left_hand_pose[:bs]
else:
left_hand_pose = self.left_hand_pose

if right_hand_pose is None:
if should_expand_bs:
right_hand_pose = self.right_hand_pose.expand(bs, *self.right_hand_pose.shape[1:])
elif should_cut_default_bs:
right_hand_pose = self.right_hand_pose[:bs]
else:
right_hand_pose = self.right_hand_pose

if expression is None:
if should_expand_bs:
expression = self.expression.expand(bs, *self.expression.shape[1:])
elif should_cut_default_bs:
expression = self.expression[:bs]
else:
expression = self.expression

if jaw_pose is None:
if should_expand_bs:
jaw_pose = self.jaw_pose.expand(bs, *self.jaw_pose.shape[1:])
elif should_cut_default_bs:
jaw_pose = self.jaw_pose[:bs]
else:
jaw_pose = self.jaw_pose

if leye_pose is None:
if should_expand_bs:
leye_pose = self.leye_pose.expand(bs, *self.leye_pose.shape[1:])
elif should_cut_default_bs:
leye_pose = self.leye_pose[:bs]
else:
leye_pose = self.leye_pose

if reye_pose is None:
if should_expand_bs:
reye_pose = self.reye_pose.expand(bs, *self.reye_pose.shape[1:])
elif should_cut_default_bs:
reye_pose = self.reye_pose[:bs]
else:
reye_pose = self.reye_pose

apply_trans = transl is not None or hasattr(self, 'transl')
if transl is None:
if hasattr(self, 'transl'):
if transl is None and hasattr(self, 'transl'):
if should_expand_bs:
transl = self.transl.expand(bs, *self.transl.shape[1:])
elif should_cut_default_bs:
transl = self.transl[:bs]
else:
transl = self.transl

if self.use_pca:
Expand Down Expand Up @@ -1496,7 +1627,7 @@ def forward(
left_hand_pose=left_hand_pose,
right_hand_pose=right_hand_pose,
jaw_pose=jaw_pose,
transl=transl if transl != None else Tensor(0),
transl=transl if transl is not None else Tensor(0),
full_pose=full_pose if return_full_pose else Tensor(0))

return output
Expand Down Expand Up @@ -1707,6 +1838,7 @@ def forward(
joints=joints if return_verts else None,
betas=betas,
global_orient=global_orient,
transl=transl,
hand_pose=hand_pose,
full_pose=full_pose if return_full_pose else None)

Expand Down Expand Up @@ -1773,6 +1905,7 @@ def forward(
joints=joints if return_verts else None,
betas=betas,
global_orient=global_orient,
transl=transl,
hand_pose=hand_pose,
full_pose=full_pose if return_full_pose else None)

Expand Down Expand Up @@ -2139,6 +2272,7 @@ def forward(
betas=betas,
expression=expression,
global_orient=global_orient,
transl=transl,
neck_pose=neck_pose,
jaw_pose=jaw_pose,
full_pose=full_pose if return_full_pose else None)
Expand Down Expand Up @@ -2288,6 +2422,7 @@ def forward(
betas=betas,
expression=expression,
global_orient=global_orient,
transl=transl,
neck_pose=neck_pose,
jaw_pose=jaw_pose,
full_pose=full_pose if return_full_pose else None)
Expand Down