1package otr3 2 3import "bytes" 4 5const tlvHeaderLength = 4 6 7const ( 8 tlvTypePadding = uint16(0x00) 9 tlvTypeDisconnected = uint16(0x01) 10 tlvTypeSMP1 = uint16(0x02) 11 tlvTypeSMP2 = uint16(0x03) 12 tlvTypeSMP3 = uint16(0x04) 13 tlvTypeSMP4 = uint16(0x05) 14 tlvTypeSMPAbort = uint16(0x06) 15 tlvTypeSMP1WithQuestion = uint16(0x07) 16 tlvTypeExtraSymmetricKey = uint16(0x08) 17) 18 19type tlvHandler func(*Conversation, tlv, dataMessageExtra) (*tlv, error) 20 21var tlvHandlers = make([]tlvHandler, 9) 22 23func initTLVHandlers() { 24 tlvHandlers[tlvTypePadding] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) { 25 return c.processPaddingTLV(t, x) 26 } 27 tlvHandlers[tlvTypeDisconnected] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) { 28 return c.processDisconnectedTLV(t, x) 29 } 30 tlvHandlers[tlvTypeSMP1] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) { 31 return c.processSMPTLV(t, x) 32 } 33 tlvHandlers[tlvTypeSMP2] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) { 34 return c.processSMPTLV(t, x) 35 } 36 tlvHandlers[tlvTypeSMP3] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) { 37 return c.processSMPTLV(t, x) 38 } 39 tlvHandlers[tlvTypeSMP4] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) { 40 return c.processSMPTLV(t, x) 41 } 42 tlvHandlers[tlvTypeSMPAbort] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) { 43 return c.processSMPTLV(t, x) 44 } 45 tlvHandlers[tlvTypeSMP1WithQuestion] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) { 46 return c.processSMPTLV(t, x) 47 } 48 tlvHandlers[tlvTypeExtraSymmetricKey] = func(c *Conversation, t tlv, x dataMessageExtra) (*tlv, error) { 49 return c.processExtraSymmetricKeyTLV(t, x) 50 } 51} 52 53func messageHandlerForTLV(t tlv) (tlvHandler, error) { 54 if t.tlvType >= uint16(len(tlvHandlers)) { 55 return nil, newOtrError("unexpected TLV type") 56 } 57 return tlvHandlers[t.tlvType], nil 58} 59 60type tlv struct { 61 tlvType uint16 62 tlvLength uint16 63 tlvValue []byte 64} 65 66func (c tlv) serialize() []byte { 67 out := appendShort([]byte{}, c.tlvType) 68 out = appendShort(out, c.tlvLength) 69 return append(out, c.tlvValue...) 70} 71 72func (c *tlv) deserialize(tlvsBytes []byte) error { 73 var ok bool 74 tlvsBytes, c.tlvType, ok = extractShort(tlvsBytes) 75 if !ok { 76 return newOtrError("wrong tlv type") 77 } 78 tlvsBytes, c.tlvLength, ok = extractShort(tlvsBytes) 79 if !ok { 80 return newOtrError("wrong tlv length") 81 } 82 if len(tlvsBytes) < int(c.tlvLength) { 83 return newOtrError("wrong tlv value") 84 } 85 c.tlvValue = tlvsBytes[:int(c.tlvLength)] 86 return nil 87} 88 89func (c tlv) isSMPMessage() bool { 90 return c.tlvType >= tlvTypeSMP1 && c.tlvType <= tlvTypeSMP1WithQuestion 91} 92 93func (c tlv) smpMessage() (smpMessage, bool) { 94 switch c.tlvType { 95 case tlvTypeSMP1: 96 return toSmpMessage1(c) 97 case tlvTypeSMP1WithQuestion: 98 return toSmpMessage1Q(c) 99 case tlvTypeSMP2: 100 return toSmpMessage2(c) 101 case tlvTypeSMP3: 102 return toSmpMessage3(c) 103 case tlvTypeSMP4: 104 return toSmpMessage4(c) 105 case tlvTypeSMPAbort: 106 return toSmpMessageAbort(c) 107 } 108 109 return nil, false 110} 111 112func toSmpMessage1(t tlv) (msg smp1Message, ok bool) { 113 _, mpis, ok := extractMPIs(t.tlvValue) 114 if !ok || len(mpis) < 6 { 115 return msg, false 116 } 117 msg.g2a = mpis[0] 118 msg.c2 = mpis[1] 119 msg.d2 = mpis[2] 120 msg.g3a = mpis[3] 121 msg.c3 = mpis[4] 122 msg.d3 = mpis[5] 123 return msg, true 124} 125 126func toSmpMessage1Q(t tlv) (msg smp1Message, ok bool) { 127 nulPos := bytes.IndexByte(t.tlvValue, 0) 128 if nulPos == -1 { 129 return msg, false 130 } 131 question := string(t.tlvValue[:nulPos]) 132 t.tlvValue = t.tlvValue[(nulPos + 1):] 133 msg, ok = toSmpMessage1(t) 134 msg.hasQuestion = true 135 msg.question = question 136 return msg, ok 137} 138 139func toSmpMessage2(t tlv) (msg smp2Message, ok bool) { 140 _, mpis, ok := extractMPIs(t.tlvValue) 141 if !ok || len(mpis) < 11 { 142 return msg, false 143 } 144 msg.g2b = mpis[0] 145 msg.c2 = mpis[1] 146 msg.d2 = mpis[2] 147 msg.g3b = mpis[3] 148 msg.c3 = mpis[4] 149 msg.d3 = mpis[5] 150 msg.pb = mpis[6] 151 msg.qb = mpis[7] 152 msg.cp = mpis[8] 153 msg.d5 = mpis[9] 154 msg.d6 = mpis[10] 155 return msg, true 156} 157 158func toSmpMessage3(t tlv) (msg smp3Message, ok bool) { 159 _, mpis, ok := extractMPIs(t.tlvValue) 160 if !ok || len(mpis) < 8 { 161 return msg, false 162 } 163 msg.pa = mpis[0] 164 msg.qa = mpis[1] 165 msg.cp = mpis[2] 166 msg.d5 = mpis[3] 167 msg.d6 = mpis[4] 168 msg.ra = mpis[5] 169 msg.cr = mpis[6] 170 msg.d7 = mpis[7] 171 return msg, true 172} 173 174func toSmpMessage4(t tlv) (msg smp4Message, ok bool) { 175 _, mpis, ok := extractMPIs(t.tlvValue) 176 if !ok || len(mpis) < 3 { 177 return msg, false 178 } 179 msg.rb = mpis[0] 180 msg.cr = mpis[1] 181 msg.d7 = mpis[2] 182 return msg, true 183} 184 185func toSmpMessageAbort(t tlv) (msg smpMessageAbort, ok bool) { 186 return msg, true 187} 188