import json
import math
import random
from os import listdir
from os.path import basename, isfile, join

import numpy as np
from scipy.spatial.transform import Rotation as R
from scipy.spatial.transform import Slerp
10
11
class MotionCaptureDataMultiClip(object):
    """A collection of DeepMimic-style motion clips with equalised frame counts.

    Every clip is a JSON document with a ``"Frames"`` list.  Each frame is a
    flat list of 44 numbers: the keyframe duration, the root position (3
    values) and the joint rotations (quaternions stored scalar-first, plus
    single values for knees and elbows); see ``quatlist_to_quatlists`` for the
    exact layout.

    After ``Load`` all clips have the same number of frames, obtained either by
    randomly dropping frames (downsampling, the default) or by inserting
    interpolated frames (upsampling).
    """

    def __init__(self):
        self.Reset()

    def Reset(self):
        """Forget all loaded clips and restore the initial bookkeeping."""
        self._motion_data = {}  # clip index -> parsed JSON document
        # math.inf instead of a magic 9999 so that any real clip length
        # (however long) lowers the minimum.
        self._num_frames_min = math.inf
        self._num_frames_max = -1
        self._num_clips = -1
        self._names = []  # clip file names, in load order
        # Defined here so getNumFrames() does not raise before Load().
        self._downsample = True

    def Load(self, path, downsample=True):
        """Load one clip file, or every file of a directory, and equalise lengths.

        Args:
            path: a JSON clip file, or a directory whose regular files are all
                clips (loaded in sorted path order).
            downsample: if True (the default, matching the previous hard-coded
                behaviour) every clip is reduced to the shortest clip's frame
                count; otherwise every clip is upsampled to the longest one.

        Raises:
            ValueError: if ``path`` is a directory containing no files.

        NOTE(review): clip indices restart at 0 on every call while names keep
        being appended; call Reset() before loading a second time.
        """
        if isfile(path):
            files = [path]
        else:
            files = sorted(join(path, f) for f in listdir(path) if isfile(join(path, f)))
        if not files:
            raise ValueError("No motion clip files found in {}".format(path))
        for i, file_path in enumerate(files):
            with open(file_path, 'r', encoding='utf-8') as f:
                # basename() is portable; splitting on '/' breaks on Windows paths.
                self._names.append(basename(file_path))
                self._motion_data[i] = json.load(f)
            n_frames = len(self._motion_data[i]['Frames'])
            self._num_frames_min = min(n_frames, self._num_frames_min)
            self._num_frames_max = max(n_frames, self._num_frames_max)
        self._num_clips = len(self._motion_data)
        self._downsample = downsample
        if self._downsample:
            self.downsampleClips()
            self._num_frames_max = self._num_frames_min
        else:
            self.upsampleClips()
            self._num_frames_min = self._num_frames_max

    def getNumFrames(self):
        """Return the common frame count of the loaded clips."""
        if self._downsample:
            return self._num_frames_min
        return self._num_frames_max

    def getKeyFrameDuration(self, id=0):
        """Return the duration stored in the first frame of clip ``id``.

        The parameter keeps its historical name ``id`` (shadowing the builtin)
        so keyword callers keep working.
        """
        return self._motion_data[id]['Frames'][0][0]

    def getCycleTime(self):
        """Return clip 0's keyframe duration times (number of frames - 1)."""
        keyFrameDuration = self.getKeyFrameDuration()
        return keyFrameDuration * (self.getNumFrames() - 1)

    def calcCycleCount(self, simTime, cycleTime):
        """Return how many whole cycles of length ``cycleTime`` fit in ``simTime``."""
        return math.floor(simTime / cycleTime)

    def computeCycleOffset(self, id=0):
        """Return (and store in ``self._cycleOffset``) the root displacement
        ``[dx, dy, dz]`` between the first and the last frame of clip ``id``."""
        lastFrame = self.getNumFrames() - 1
        frameData = self._motion_data[id]['Frames'][0]
        frameDataNext = self._motion_data[id]['Frames'][lastFrame]
        # Frame entries 1..3 are the root position.
        self._cycleOffset = [frameDataNext[k] - frameData[k] for k in (1, 2, 3)]
        return self._cycleOffset

    def getNumClips(self):
        """Return the number of loaded clips (-1 before any Load)."""
        return self._num_clips

    def downsampleClips(self):
        """Randomly drop frames so every clip has ``self._num_frames_min`` frames.

        Kept frames are drawn uniformly without replacement and stay in their
        original temporal order; frame durations are left unchanged.
        """
        for i in range(self._num_clips):
            frames = self._motion_data[i]['Frames']
            n_frames = len(frames)
            if n_frames != self._num_frames_min:
                keep = sorted(random.sample(range(n_frames), self._num_frames_min))
                self._motion_data[i]['Frames'] = np.array(frames)[keep].tolist()

    def upsampleClips(self):
        """Insert interpolated frames until every clip has ``self._num_frames_max`` frames."""
        print("Max number of frames: ", self._num_frames_max)
        for i in range(self._num_clips):
            keyframe_duration = self.getKeyFrameDuration(i)
            n_frames = len(self._motion_data[i]['Frames'])
            # Integer-indexed times: np.arange with a float stop/step can return
            # one element too many (e.g. 3 frames of 0.1 s -> 4 time stamps).
            old_times = np.arange(n_frames) * keyframe_duration
            # Each pass fills at most one midpoint per gap, so long gaps in
            # frame count may need several passes.
            while len(old_times) < self._num_frames_max:
                new_times, new_vals = self.slerpSingleClip(self._motion_data[i]['Frames'], old_times)
                self._motion_data[i]['Frames'] = new_vals
                old_times = new_times

    def slerpSingleClip(self, clip, key_times):
        """Insert up to ``_num_frames_max - len(clip)`` interpolated frames into ``clip``.

        Candidate times are the midpoints between consecutive key times; a
        random subset of them (as many as are still needed) is used.  Quaternion
        joints are slerped with SciPy; the single-value joints (knees, elbows)
        and the root position take the mean of the two neighbouring key frames.

        Args:
            clip: list of flat DeepMimic frames.
            key_times: strictly increasing time stamp of every frame in ``clip``.

        Returns:
            ``(times, frames)``: the merged, sorted time stamps as a NumPy array
            and the corresponding flat frames.

        Raises:
            ValueError: if ``key_times`` and ``clip`` differ in length.
        """
        key_times = np.asarray(key_times, dtype=float)
        # Split each frame into (duration, root position, per-joint rotations).
        # Plain lists are used because the rows are ragged: np.asarray on them
        # raises ValueError on NumPy >= 1.24.
        split_clip = self.quatlist_to_quatlists(clip)
        t = [frame[0] for frame in split_clip]
        root_pos = [frame[1] for frame in split_clip]
        key_rots = [frame[2] for frame in split_clip]
        n_frames = len(key_rots)
        if len(key_times) != n_frames:
            raise ValueError("Expected {} key times, got {}".format(n_frames, len(key_times)))
        needed_frames = self._num_frames_max - n_frames
        inter_times = self.calc_inter_times(key_times)
        inter_times = sorted(random.sample(inter_times, min(len(inter_times), needed_frames)))

        # Index of the key frame just before each inserted time; the frame
        # just after it is lower + 1 because midpoints lie strictly between.
        lower = np.searchsorted(key_times, inter_times, side='left') - 1

        inter_joint = []  # per joint: one interpolated rotation per inserted time
        for j in range(len(key_rots[0])):
            quats = [rot[j] for rot in key_rots]
            if len(quats[0]) == 4:
                slerp = Slerp(key_times, R.from_quat(quats))
                interp_rots = slerp(inter_times).as_quat().tolist()
            else:
                interp_rots = [[(quats[lb + 1][0] + quats[lb][0]) / 2] for lb in lower]
            inter_joint.append(interp_rots)
        # Regroup from per-joint lists to per-frame lists of joint rotations.
        inter_joint = [list(frame_rots) for frame_rots in zip(*inter_joint)]

        rots_by_time = dict(zip(key_times, key_rots))
        rots_by_time.update(zip(inter_times, inter_joint))
        pos_by_time = dict(zip(key_times, root_pos))
        pos_by_time.update(zip(inter_times, self.calc_inter_root_pos(root_pos, key_times, inter_times)))

        ord_keys = sorted(rots_by_time)
        ord_rots = [rots_by_time[k] for k in ord_keys]
        ord_root_pos = [pos_by_time[k] for k in ord_keys]
        new_clip = self.quatlists_to_quatlist(t, ord_root_pos, ord_rots)
        return np.array(ord_keys), new_clip

    def quatlist_to_quatlists(self, clip):
        """Split flat frames into ``[duration, root_pos, joint_rotations]``.

        Quaternion joints are converted to SciPy's scalar-last order; knee and
        elbow values are wrapped in one-element lists.
        """
        new_clips = []
        for c in clip:
            t = c[0]
            root_pos = c[1:4]
            root_rotation = c[4:8]
            chest_rotation = c[8:12]
            neck_rotation = c[12:16]
            right_hip_rotation = c[16:20]
            right_knee_rotation = c[20]
            right_ankle_rotation = c[21:25]
            right_shoulder_rotation = c[25:29]
            right_elbow_rotation = c[29]
            left_hip_rotation = c[30:34]
            left_knee_rotation = c[34]
            left_ankle_rotation = c[35:39]
            left_shoulder_rotation = c[39:43]
            left_elbow_rotation = c[43]
            to_scipy = self.deepmimic_to_scipy_quaternion
            new_clips.append([
                t,
                root_pos,
                [
                    to_scipy(root_rotation),
                    to_scipy(chest_rotation),
                    to_scipy(neck_rotation),
                    to_scipy(right_hip_rotation),
                    [right_knee_rotation],
                    to_scipy(right_ankle_rotation),
                    to_scipy(right_shoulder_rotation),
                    [right_elbow_rotation],
                    to_scipy(left_hip_rotation),
                    [left_knee_rotation],
                    to_scipy(left_ankle_rotation),
                    to_scipy(left_shoulder_rotation),
                    [left_elbow_rotation],
                ],
            ])
        return new_clips

    def deepmimic_to_scipy_quaternion(self, quat):
        """Reorder ``(w, x, y, z)`` into SciPy's scalar-last ``(x, y, z, w)`` list.

        ``list()`` makes tuples and arrays work too (``+`` on an ndarray would
        silently add element-wise instead of concatenating).
        """
        return list(quat[1:]) + [quat[0]]

    def scipy_to_deepmimic_quaternion(self, quat):
        """Reorder SciPy's ``(x, y, z, w)`` into a scalar-first ``(w, x, y, z)`` list."""
        return [quat[-1]] + list(quat[:-1])

    def calc_inter_times(self, times, method="intermediate"):
        """Return the midpoint between every pair of consecutive ``times``.

        Raises:
            ValueError: for an unknown ``method`` (previously returned None).
        """
        if method != "intermediate":
            raise ValueError("Unknown interpolation-time method: {}".format(method))
        return [prev + (cur - prev) / 2 for prev, cur in zip(times, times[1:])]

    def calc_inter_root_pos(self, root_pos, times, inter_times):
        """Return the root position for each inserted time as the mean of the
        positions at the two key times surrounding it.

        Raises:
            ValueError: if an inserted time is outside ``times`` or does not
                sit alone between two consecutive key times.
        """
        key_list = np.asarray(times).tolist()
        all_times = sorted([*key_list, *inter_times])
        inter_root_pos = []
        for inter_t in inter_times:
            pos = all_times.index(inter_t)
            if not 0 < pos < len(all_times) - 1:
                raise ValueError("Interpolation time {} is outside the key times".format(inter_t))
            low_index = key_list.index(all_times[pos - 1])
            up_index = key_list.index(all_times[pos + 1])
            if up_index != low_index + 1:
                raise ValueError(
                    "Interpolation time {} is not alone between two key times".format(inter_t))
            mean_pos = (np.array(root_pos[up_index]) + np.array(root_pos[low_index])) / 2
            inter_root_pos.append(mean_pos.tolist())
        return inter_root_pos

    def quatlists_to_quatlist(self, t, ord_root_pos, ord_rots):
        """Flatten ``(root_pos, joint_rotations)`` pairs back into DeepMimic frames.

        Every output frame gets the duration ``t[0]``.  NOTE(review): durations
        are not rescaled after frames are inserted; confirm this is intended.
        """
        delta_t = t[0]
        new_quats = self.merge_quaternions(ord_rots)
        return [[delta_t, *pos, *rot] for pos, rot in zip(ord_root_pos, new_quats)]

    def merge_quaternions(self, rotations):
        """Flatten per-frame joint lists, converting quaternions back to scalar-first."""
        quats = []
        for rot in rotations:
            flat = []
            for r in rot:
                if len(r) == 4:
                    r = self.scipy_to_deepmimic_quaternion(r)
                flat.extend(r)
            quats.append(flat)
        return quats