""" Symbolic versions of the DICOM orientation mathematics.

Notes on the SPM orientation machinery.

There are symbolic versions of the code in ``spm_dicom_convert``,
``write_volume`` subfunction, around line 509 in the version I have (SPM8, late
2009 vintage).
"""
9
10import numpy as np
11
12import sympy
13from sympy import Matrix, Symbol, symbols, zeros, ones, eye
14
15
# The code below is general (independent of SPM's code)
def numbered_matrix(nrows, ncols, symbol_prefix):
    """Return an ``nrows`` by ``ncols`` sympy Matrix of indexed symbols.

    The element at 0-based position ``(i, j)`` is the symbol named
    ``symbol_prefix + '_{<i+1><j+1>}'``, e.g. ``F_{12}`` for row 0, column 1
    with prefix ``'F'``.

    When either dimension is larger than 9, the two 1-based indices are
    separated by a comma (``F_{1,11}``). Without a separator, elements such
    as (1, 11) and (11, 1) would both be named ``F_{111}`` and collapse into
    the same sympy Symbol. Matrices up to 9 x 9 keep the separator-free
    names.
    """
    index_fmt = '_{%d,%d}' if max(nrows, ncols) > 9 else '_{%d%d}'
    return Matrix(nrows, ncols, lambda i, j: Symbol(
        symbol_prefix + index_fmt % (i + 1, j + 1)))
20
21
def numbered_vector(nrows, symbol_prefix):
    """Return an ``nrows`` x 1 sympy column Matrix of indexed symbols.

    Row ``k`` (0-based) holds the symbol named ``symbol_prefix + '_{<k+1>}'``.
    For example, ``numbered_vector(3, 'n')`` gives ``n_{1}``, ``n_{2}``,
    ``n_{3}``.
    """
    entries = [Symbol(symbol_prefix + '_{%d}' % (row + 1))
               for row in range(nrows)]
    return Matrix(nrows, 1, entries)
25
26
# Matrix that premultiplies homogeneous voxel coordinates to turn 0-based
# indices into 1-based ones (adds 1 to each of the first three entries).
one_based = Matrix([[1, 0, 0, 1],
                    [0, 1, 0, 1],
                    [0, 0, 1, 1],
                    [0, 0, 0, 1]])
# Premultiplying permutation that swaps the first two (row / column) indices.
row_col_swap = Matrix([[0, 1, 0, 0],
                       [1, 0, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])
34
# various working matrices and symbols
# 3x2 matrix of symbols F_{ij}: the two in-plane direction columns
# (presumably the DICOM ImageOrientationPatient vectors).
orient_pat = numbered_matrix(3, 2, 'F')
# Third direction column n_{1..3} (presumably the slice normal).
orient_cross = numbered_vector(3, 'n')
# Unknown third affine column k_{1..3}, solved for in the multi-slice case.
missing_r_col = numbered_vector(3, 'k')
# Patient positions of the first (T^1) and last (T^N) slices.
pos_pat_0 = numbered_vector(3, 'T^1')
pos_pat_N = numbered_vector(3, 'T^N')
# Row and column pixel spacing.
pixel_spacing = symbols((r'\Delta{r}', r'\Delta{c}'))
# Number of slices.
NZ = Symbol('N')
slice_spacing = Symbol(r'\Delta{s}')

# In-plane direction columns scaled by the pixel spacings. ``sympy.diag``
# keeps this a pure sympy product instead of relying on sympy's handling of
# a numpy object array built by ``np.diag``.
R3 = orient_pat * sympy.diag(*pixel_spacing)
# The same two columns as homogeneous direction vectors (last row zero).
R = zeros(4, 2)
R[:3, :] = R3
48
49# The following is specific to the SPM algorithm.
# Homogeneous voxel coordinate (1, 1, 1): SPM's 1-based first voxel.
x1 = ones(4, 1)
# Homogeneous patient coordinate of the first voxel: (T^1, 1).
y1 = ones(4, 1)
y1[:3, :] = pos_pat_0

# Voxel-space side of the system A * to_inv == inv_lhs. Columns are the first
# voxel, a placeholder (a, b, c, d) that ``spm_full_matrix`` replaces with a
# second voxel vector, and the i and j axis directions (homogeneous 0).
to_inv = zeros(4, 4)
to_inv[:, 0] = x1
to_inv[:, 1] = symbols('a b c d')
to_inv[0, 2] = 1
to_inv[1, 3] = 1
# Patient-space side: the first voxel's position, a placeholder (e, f, g, h)
# replaced in ``spm_full_matrix``, and the scaled in-plane directions ``R``.
inv_lhs = zeros(4, 4)
inv_lhs[:, 0] = y1
inv_lhs[:, 1] = symbols('e f g h')
inv_lhs[:, 2:] = R
63
64
def spm_full_matrix(x2, y2):
    """Solve for the 4x4 voxel-to-patient affine the way SPM does.

    Copies of the module-level ``inv_lhs`` and ``to_inv`` matrices have their
    placeholder second column replaced by ``y2`` and ``x2`` respectively.
    The affine ``A`` satisfying ``A * rhs == lhs`` is then
    ``lhs * rhs.inv()``.

    Parameters
    ----------
    x2 : sympy Matrix with 4 entries
        Homogeneous voxel-space vector.
    y2 : sympy Matrix with 4 entries
        Homogeneous patient-space image of ``x2``.

    Returns
    -------
    sympy Matrix, shape (4, 4)
    """
    lhs = inv_lhs.copy()
    lhs[:, 1] = y2
    rhs = to_inv.copy()
    rhs[:, 1] = x2
    return lhs * rhs.inv()
71
72
73# single slice case
# Full 3x3 direction matrix: the in-plane columns followed by orient_cross.
orient = orient_pat.row_join(orient_cross)
# A step of one voxel along the slice axis (homogeneous direction, last 0)...
x2_ss = Matrix((0, 0, 1, 0))
# ...maps to one slice spacing along the third direction column.
y2_ss = (orient * Matrix((0, 0, slice_spacing))).col_join(zeros(1, 1))
A_ss = spm_full_matrix(x2_ss, y2_ss)
81
82# many slice case
# 1-based voxel (1, 1, N), the first voxel of the last slice...
x2_ms = Matrix((1, 1, NZ, 1))
# ...lies at patient position T^N (homogeneous).
y2_ms = pos_pat_N.col_join(ones(1, 1))
A_ms = spm_full_matrix(x2_ms, y2_ms)
87
88# End of SPM algorithm
89
90# Rather simpler derivation from DICOM affine formulae - see
91# dicom_orientation.rst
92
93# single slice case
single_aff = eye(4)
rot = orient
# Scale each direction column by its voxel spacing (the \Delta{r},
# \Delta{c} and \Delta{s} symbols). ``sympy.diag`` keeps this a pure sympy
# product rather than multiplying by a numpy object array from ``np.diag``.
rot_scale = rot * sympy.diag(*(pixel_spacing + (slice_spacing,)))
single_aff[:3, :3] = rot_scale
# Translation is the position of the first slice.
single_aff[:3, 3] = pos_pat_0
99
# For the multi-slice case, we have the patient positions of the first and
# last slices (``pos_pat_0`` and ``pos_pat_N``).  This gives us the third
# column of the affine, because ``pos_pat_N = aff * [[0, 0, NZ - 1, 1]].T``
multi_aff = eye(4)
multi_aff[:3, :2] = R3
# 0-based homogeneous voxel coordinate of the first voxel of the last slice.
trans_z_N = Matrix((0, 0, NZ - 1, 1))
# The third column is unknown; fill it with the symbols k_{1..3}.
multi_aff[:3, 2] = missing_r_col
multi_aff[:3, 3] = pos_pat_0
# Predicted last-slice position; setting it equal to pos_pat_N gives three
# linear equations k * (N - 1) + T^1 - T^N = 0.
est_pos_pat_N = multi_aff * trans_z_N
eqns = tuple(est_pos_pat_N[:3, 0] - pos_pat_N)
# ``solved`` is used below as a substitution mapping, so this relies on
# ``sympy.solve`` returning a {k_i: expression} dict for this linear system.
solved = sympy.solve(eqns, tuple(missing_r_col))
# Full-slice indexing copies the matrix, so ``multi_aff`` keeps the k symbols.
multi_aff_solved = multi_aff[:, :]
multi_aff_solved[:3, 2] = multi_aff_solved[:3, 2].subs(solved)
113
# Check that SPM gave us the same result.  SPM's affines take 1-based voxel
# indices; post-multiplying by ``one_based`` makes them take 0-based indices
# like the DICOM-derived affines.  Explicit raises (rather than ``assert``)
# keep this check active when Python runs with ``-O``.
A_ms_0based = A_ms * one_based
A_ms_0based.simplify()
A_ss_0based = A_ss * one_based
A_ss_0based.simplify()
if single_aff != A_ss_0based:
    raise AssertionError(
        'single-slice SPM affine does not match the DICOM derivation')
if multi_aff_solved != A_ms_0based:
    raise AssertionError(
        'multi-slice SPM affine does not match the DICOM derivation')
121
122# Now, trying to work out Z from slice affines
# Slice i uses the single-slice affine; slice j is reached from it by a
# translation of ``d`` voxels along the third (slice) axis.
# NOTE: the symbol name 'd' also appears as a placeholder in ``to_inv``; the
# two never occur in the same expression here.
A_i = single_aff
NZT = Symbol('d')
nz_trans = Matrix([[1, 0, 0, 0],
                   [0, 1, 0, 0],
                   [0, 0, 1, NZT],
                   [0, 0, 0, 1]])
A_j = A_i * nz_trans
# Translation columns (image positions) of the two slice affines.
IPP_i = A_i[:3, 3]
IPP_j = A_j[:3, 3]
130
# SPM does it with the inner product of the vectors: project the slice
# position onto ``orient_cross``; the result is a 1x1 Matrix.
spm_z = IPP_j.T * orient_cross
# Simplifies the mutable Matrix in place (return value unused).
spm_z.simplify()

# We can also do it with a sum and division, but then we'd get undefined
# behavior when orient_cross sums to zero.
ipp_sum_div = sum(IPP_j) / sum(orient_cross)
ipp_sum_div = sympy.simplify(ipp_sum_div)
139
140
141# Dump out the formulae here to latex for the RST docs
def my_latex(expr):
    r"""Return the LaTeX for ``expr`` without surrounding ``$`` delimiters.

    The original implementation always dropped the first and last
    characters, presumably written for sympy versions whose ``latex()``
    wrapped output in ``$...$``.  Current ``sympy.latex`` returns plain LaTeX
    by default, so unconditional slicing chops real content (for a Matrix,
    the leading backslash of ``\left[`` and the closing ``]``).  Delimiters
    are therefore stripped only when actually present.
    """
    S = sympy.latex(expr)
    if len(S) >= 2 and S.startswith('$') and S.endswith('$'):
        S = S[1:-1]
    return S
145
146
# Emit each derived formula as an indented LaTeX line for the RST docs.
print('Latex stuff')
print('   R = %s' % my_latex(to_inv))
print('   ')
print('   L = %s' % my_latex(inv_lhs))
print('')
print('   0B = %s' % my_latex(one_based))
print('')
print('   %s' % my_latex(solved))
print('')
print('   A_{multi} = %s' % my_latex(multi_aff_solved))
print('   ')
print('   A_{single} = %s' % my_latex(single_aff))
print('')
print(r'   \left(\begin{smallmatrix}T^N\\1\end{smallmatrix}\right) = A %s'
      % my_latex(trans_z_N))
print('')
print('   A_j = A_{single} %s' % my_latex(nz_trans))
print('')
print('   T^j = %s' % my_latex(IPP_j))
print('')
print(r'   T^j \cdot \mathbf{c} = %s' % my_latex(spm_z))
167