1package otr3
2
3import "bytes"
4
// tlvHeaderLength is the size in bytes of a serialized TLV header: the 16-bit
// type field followed by the 16-bit length field (see tlv.serialize).
const tlvHeaderLength = 4

// TLV type codes. tlvHandlers is indexed directly by these values, and
// isSMPMessage treats tlvTypeSMP1..tlvTypeSMP1WithQuestion as a single
// contiguous range, so the SMP-related codes must stay numerically adjacent.
const (
	tlvTypePadding           = uint16(0x00)
	tlvTypeDisconnected      = uint16(0x01)
	tlvTypeSMP1              = uint16(0x02)
	tlvTypeSMP2              = uint16(0x03)
	tlvTypeSMP3              = uint16(0x04)
	tlvTypeSMP4              = uint16(0x05)
	tlvTypeSMPAbort          = uint16(0x06)
	tlvTypeSMP1WithQuestion  = uint16(0x07)
	tlvTypeExtraSymmetricKey = uint16(0x08)
)
18
// tlvHandler processes one received TLV on behalf of a Conversation, given the
// extra data-message context. NOTE(review): the returned *tlv, when non-nil,
// presumably is a TLV to send back to the peer — confirm against the
// Conversation.process*TLV implementations.
type tlvHandler func(*Conversation, tlv, dataMessageExtra) (*tlv, error)
20
21var tlvHandlers = make([]tlvHandler, 9)
22
23func initTLVHandlers() {
24	tlvHandlers[tlvTypePadding] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) {
25		return c.processPaddingTLV(t, x)
26	}
27	tlvHandlers[tlvTypeDisconnected] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) {
28		return c.processDisconnectedTLV(t, x)
29	}
30	tlvHandlers[tlvTypeSMP1] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) {
31		return c.processSMPTLV(t, x)
32	}
33	tlvHandlers[tlvTypeSMP2] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) {
34		return c.processSMPTLV(t, x)
35	}
36	tlvHandlers[tlvTypeSMP3] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) {
37		return c.processSMPTLV(t, x)
38	}
39	tlvHandlers[tlvTypeSMP4] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) {
40		return c.processSMPTLV(t, x)
41	}
42	tlvHandlers[tlvTypeSMPAbort] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) {
43		return c.processSMPTLV(t, x)
44	}
45	tlvHandlers[tlvTypeSMP1WithQuestion] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) {
46		return c.processSMPTLV(t, x)
47	}
48	tlvHandlers[tlvTypeExtraSymmetricKey] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) {
49		return c.processExtraSymmetricKeyTLV(t, x)
50	}
51}
52
53func messageHandlerForTLV(t tlv) (tlvHandler, error) {
54	if t.tlvType >= uint16(len(tlvHandlers)) {
55		return nil, newOtrError("unexpected TLV type")
56	}
57	return tlvHandlers[t.tlvType], nil
58}
59
// tlv is a single type/length/value record. serialize writes the three fields
// in declaration order and deserialize reads them back from the front of a
// byte slice.
type tlv struct {
	tlvType   uint16 // type code, compared against the tlvType* constants
	tlvLength uint16 // length field; serialize writes it as stored, unchecked against len(tlvValue)
	tlvValue  []byte // payload bytes; after deserialize it is a subslice of the input buffer
}
65
66func (c tlv) serialize() []byte {
67	out := appendShort([]byte{}, c.tlvType)
68	out = appendShort(out, c.tlvLength)
69	return append(out, c.tlvValue...)
70}
71
72func (c *tlv) deserialize(tlvsBytes []byte) error {
73	var ok bool
74	tlvsBytes, c.tlvType, ok = extractShort(tlvsBytes)
75	if !ok {
76		return newOtrError("wrong tlv type")
77	}
78	tlvsBytes, c.tlvLength, ok = extractShort(tlvsBytes)
79	if !ok {
80		return newOtrError("wrong tlv length")
81	}
82	if len(tlvsBytes) < int(c.tlvLength) {
83		return newOtrError("wrong tlv value")
84	}
85	c.tlvValue = tlvsBytes[:int(c.tlvLength)]
86	return nil
87}
88
89func (c tlv) isSMPMessage() bool {
90	return c.tlvType >= tlvTypeSMP1 && c.tlvType <= tlvTypeSMP1WithQuestion
91}
92
93func (c tlv) smpMessage() (smpMessage, bool) {
94	switch c.tlvType {
95	case tlvTypeSMP1:
96		return toSmpMessage1(c)
97	case tlvTypeSMP1WithQuestion:
98		return toSmpMessage1Q(c)
99	case tlvTypeSMP2:
100		return toSmpMessage2(c)
101	case tlvTypeSMP3:
102		return toSmpMessage3(c)
103	case tlvTypeSMP4:
104		return toSmpMessage4(c)
105	case tlvTypeSMPAbort:
106		return toSmpMessageAbort(c)
107	}
108
109	return nil, false
110}
111
112func toSmpMessage1(t tlv) (msg smp1Message, ok bool) {
113	_, mpis, ok := extractMPIs(t.tlvValue)
114	if !ok || len(mpis) < 6 {
115		return msg, false
116	}
117	msg.g2a = mpis[0]
118	msg.c2 = mpis[1]
119	msg.d2 = mpis[2]
120	msg.g3a = mpis[3]
121	msg.c3 = mpis[4]
122	msg.d3 = mpis[5]
123	return msg, true
124}
125
126func toSmpMessage1Q(t tlv) (msg smp1Message, ok bool) {
127	nulPos := bytes.IndexByte(t.tlvValue, 0)
128	if nulPos == -1 {
129		return msg, false
130	}
131	question := string(t.tlvValue[:nulPos])
132	t.tlvValue = t.tlvValue[(nulPos + 1):]
133	msg, ok = toSmpMessage1(t)
134	msg.hasQuestion = true
135	msg.question = question
136	return msg, ok
137}
138
139func toSmpMessage2(t tlv) (msg smp2Message, ok bool) {
140	_, mpis, ok := extractMPIs(t.tlvValue)
141	if !ok || len(mpis) < 11 {
142		return msg, false
143	}
144	msg.g2b = mpis[0]
145	msg.c2 = mpis[1]
146	msg.d2 = mpis[2]
147	msg.g3b = mpis[3]
148	msg.c3 = mpis[4]
149	msg.d3 = mpis[5]
150	msg.pb = mpis[6]
151	msg.qb = mpis[7]
152	msg.cp = mpis[8]
153	msg.d5 = mpis[9]
154	msg.d6 = mpis[10]
155	return msg, true
156}
157
158func toSmpMessage3(t tlv) (msg smp3Message, ok bool) {
159	_, mpis, ok := extractMPIs(t.tlvValue)
160	if !ok || len(mpis) < 8 {
161		return msg, false
162	}
163	msg.pa = mpis[0]
164	msg.qa = mpis[1]
165	msg.cp = mpis[2]
166	msg.d5 = mpis[3]
167	msg.d6 = mpis[4]
168	msg.ra = mpis[5]
169	msg.cr = mpis[6]
170	msg.d7 = mpis[7]
171	return msg, true
172}
173
174func toSmpMessage4(t tlv) (msg smp4Message, ok bool) {
175	_, mpis, ok := extractMPIs(t.tlvValue)
176	if !ok || len(mpis) < 3 {
177		return msg, false
178	}
179	msg.rb = mpis[0]
180	msg.cr = mpis[1]
181	msg.d7 = mpis[2]
182	return msg, true
183}
184
185func toSmpMessageAbort(t tlv) (msg smpMessageAbort, ok bool) {
186	return msg, true
187}
188