1/*
2 * Copyright (c) 2015, Yawning Angel <yawning at schwanenlied dot me>
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 *  * Redistributions of source code must retain the above copyright notice,
9 *    this list of conditions and the following disclaimer.
10 *
11 *  * Redistributions in binary form must reproduce the above copyright notice,
12 *    this list of conditions and the following disclaimer in the documentation
13 *    and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
19 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
20 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
21 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
22 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
23 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
24 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
25 * POSSIBILITY OF SUCH DAMAGE.
26 */
27
28package scramblesuit
29
30import (
31	"bytes"
32	"encoding/base32"
33	"encoding/json"
34	"errors"
35	"fmt"
36	"hash"
37	"io/ioutil"
38	"net"
39	"os"
40	"path"
41	"strconv"
42	"sync"
43	"time"
44
45	"gitlab.com/yawning/obfs4.git/common/csrand"
46)
47
// Parameters of the session-ticket store and the ticket client handshake.
const (
	// ticketFile is the file name of the JSON ticket store inside the state
	// directory passed to loadTicketStore.
	ticketFile = "scramblesuit_tickets.json"

	// Byte lengths of the two halves of a serialized ticket (key | ticket).
	ticketKeyLength = 32
	ticketLength    = 112

	// ticketLifetime is how many seconds after issue a ticket stays valid
	// (one week); isValid compares it against Unix time in seconds.
	ticketLifetime = 7 * 24 * 60 * 60

	// Bounds handed to csrand.IntRange when choosing the handshake pad length.
	ticketMinPadLength = 0
	ticketMaxPadLength = 1388
)
58
// errInvalidTicket is returned by newTicket when the raw ticket is not exactly
// ticketKeyLength+ticketLength bytes long.
var errInvalidTicket = errors.New("scramblesuit: invalid serialized ticket")
62
// ssTicketStore holds session tickets keyed by the peer address string
// (net.Addr.String()) and checkpoints them as JSON to filePath via serialize.
// When built by loadTicketStore, filePath is the state directory joined with
// ticketFile. The embedded mutex guards store: storeTicket and getTicket hold
// it for the whole of each access, including the on-disk checkpoint.
type ssTicketStore struct {
	sync.Mutex

	filePath string
	store    map[string]*ssTicket
}
69
// ssTicket is one stored session ticket. key and ticket are the two halves
// of the raw key | ticket blob accepted by newTicket. The ticket bytes are
// sent verbatim as T in the client handshake. issuedAt is a Unix time in
// seconds: set to the current time by newTicket, or restored from the store
// file. isValid measures ticketLifetime from it.
type ssTicket struct {
	key      [ticketKeyLength]byte
	ticket   [ticketLength]byte
	issuedAt int64
}
75
// ssTicketJSON is the on-disk JSON form of an ssTicket. KeyTicket is the
// standard base32 encoding of key | ticket, and IssuedAt is the Unix issue
// time in seconds. The struct tags (and the field order, which fixes the
// key order json.Marshal emits) define the ticket-store file format read
// back by loadTicketStore. Changing them breaks existing store files.
type ssTicketJSON struct {
	KeyTicket string `json:"key-ticket"`
	IssuedAt  int64  `json:"issuedAt"`
}
80
81func (t *ssTicket) isValid() bool {
82	return t.issuedAt+ticketLifetime > time.Now().Unix()
83}
84
85func newTicket(raw []byte) (*ssTicket, error) {
86	if len(raw) != ticketKeyLength+ticketLength {
87		return nil, errInvalidTicket
88	}
89	t := &ssTicket{issuedAt: time.Now().Unix()}
90	copy(t.key[:], raw[0:])
91	copy(t.ticket[:], raw[ticketKeyLength:])
92	return t, nil
93}
94
95func (s *ssTicketStore) storeTicket(addr net.Addr, rawT []byte) {
96	t, err := newTicket(rawT)
97	if err != nil {
98		// Silently ignore ticket store failures.
99		return
100	}
101
102	s.Lock()
103	defer s.Unlock()
104
105	// Add the ticket to the map, and checkpoint to disk.  Serialization errors
106	// are ignored because the handshake code will just use UniformDH if a
107	// ticket is not available.
108	s.store[addr.String()] = t
109	_ = s.serialize()
110}
111
112func (s *ssTicketStore) getTicket(addr net.Addr) (*ssTicket, error) {
113	aStr := addr.String()
114
115	s.Lock()
116	defer s.Unlock()
117
118	t, ok := s.store[aStr]
119	if ok && t != nil {
120		// Tickets are one use only, so remove tickets from the map, and
121		// checkpoint the map to disk.
122		delete(s.store, aStr)
123		err := s.serialize()
124		if !t.isValid() {
125			// Expired ticket, ignore it.
126			return nil, err
127		}
128		return t, err
129	}
130
131	// No ticket was found, that's fine.
132	return nil, nil
133}
134
135func (s *ssTicketStore) serialize() error {
136	encMap := make(map[string]*ssTicketJSON)
137	for k, v := range s.store {
138		kt := make([]byte, 0, ticketKeyLength+ticketLength)
139		kt = append(kt, v.key[:]...)
140		kt = append(kt, v.ticket[:]...)
141		ktStr := base32.StdEncoding.EncodeToString(kt)
142		jsonObj := &ssTicketJSON{KeyTicket: ktStr, IssuedAt: v.issuedAt}
143		encMap[k] = jsonObj
144	}
145	jsonStr, err := json.Marshal(encMap)
146	if err != nil {
147		return err
148	}
149	return ioutil.WriteFile(s.filePath, jsonStr, 0600)
150}
151
152func loadTicketStore(stateDir string) (*ssTicketStore, error) {
153	fPath := path.Join(stateDir, ticketFile)
154	s := &ssTicketStore{filePath: fPath}
155	s.store = make(map[string]*ssTicket)
156
157	f, err := ioutil.ReadFile(fPath)
158	if err != nil {
159		// No ticket store is fine.
160		if os.IsNotExist(err) {
161			return s, nil
162		}
163
164		// But a file read error is not.
165		return nil, err
166	}
167
168	encMap := make(map[string]*ssTicketJSON)
169	if err = json.Unmarshal(f, &encMap); err != nil {
170		return nil, fmt.Errorf("failed to load ticket store '%s': '%s'", fPath, err)
171	}
172	for k, v := range encMap {
173		raw, err := base32.StdEncoding.DecodeString(v.KeyTicket)
174		if err != nil || len(raw) != ticketKeyLength+ticketLength {
175			// Just silently skip corrupted tickets.
176			continue
177		}
178		t := &ssTicket{issuedAt: v.IssuedAt}
179		if !t.isValid() {
180			// Just ignore expired tickets.
181			continue
182		}
183		copy(t.key[:], raw[0:])
184		copy(t.ticket[:], raw[ticketKeyLength:])
185		s.store[k] = t
186	}
187	return s, nil
188}
189
// ssTicketClientHandshake carries the state generateHandshake needs to build
// a ticket-based client handshake. mac is the keyed hash used for both the
// mark M = MAC(T) and the trailing MAC. ticket supplies the bytes sent as T.
// padLen is the number of padding bytes requested from makePad for P, chosen
// by newTicketClientHandshake.
type ssTicketClientHandshake struct {
	mac    hash.Hash
	ticket *ssTicket
	padLen int
}
195
196func (hs *ssTicketClientHandshake) generateHandshake() ([]byte, error) {
197	var buf bytes.Buffer
198	hs.mac.Reset()
199
200	// The client handshake is T | P | M | MAC(T | P | M | E)
201	_, _ = hs.mac.Write(hs.ticket.ticket[:])
202	m := hs.mac.Sum(nil)[:macLength]
203	p, err := makePad(hs.padLen)
204	if err != nil {
205		return nil, err
206	}
207
208	// Write T, P, M.
209	buf.Write(hs.ticket.ticket[:])
210	buf.Write(p)
211	buf.Write(m)
212
213	// Calculate and write the MAC.
214	e := []byte(strconv.FormatInt(getEpochHour(), 10))
215	_, _ = hs.mac.Write(p)
216	_, _ = hs.mac.Write(m)
217	_, _ = hs.mac.Write(e)
218	buf.Write(hs.mac.Sum(nil)[:macLength])
219
220	hs.mac.Reset()
221	return buf.Bytes(), nil
222}
223
224func newTicketClientHandshake(mac hash.Hash, ticket *ssTicket) *ssTicketClientHandshake {
225	hs := &ssTicketClientHandshake{mac: mac, ticket: ticket}
226	hs.padLen = csrand.IntRange(ticketMinPadLength, ticketMaxPadLength)
227	return hs
228}
229