1package tcp
2
3import (
4	"bytes"
5	"fmt"
6	"net"
7	"strings"
8	"sync"
9	"time"
10
11	log "github.com/schollz/logger"
12	"github.com/schollz/pake/v3"
13
14	"github.com/schollz/croc/v9/src/comm"
15	"github.com/schollz/croc/v9/src/crypt"
16	"github.com/schollz/croc/v9/src/models"
17)
18
19type server struct {
20	host       string
21	port       string
22	debugLevel string
23	banner     string
24	password   string
25	rooms      roomMap
26}
27
28type roomInfo struct {
29	first  *comm.Comm
30	second *comm.Comm
31	opened time.Time
32	full   bool
33}
34
35type roomMap struct {
36	rooms map[string]roomInfo
37	sync.Mutex
38}
39
// pingRoom is the sentinel room name that clientCommunication reports for a
// "ping" health check. Such connections are answered with "pong" and closed
// rather than being joined to a room.
const pingRoom = "pinglkasjdlfjsaldjf"

// timeToRoomDeletion is the interval at which the relay's background loop
// looks for rooms old enough to delete. It is a variable rather than a
// constant, presumably so tests can shorten it.
var timeToRoomDeletion = time.Minute * 10
43
44// Run starts a tcp listener, run async
45func Run(debugLevel, host, port, password string, banner ...string) (err error) {
46	s := new(server)
47	s.host = host
48	s.port = port
49	s.password = password
50	s.debugLevel = debugLevel
51	if len(banner) > 0 {
52		s.banner = banner[0]
53	}
54	return s.start()
55}
56
57func (s *server) start() (err error) {
58	log.SetLevel(s.debugLevel)
59	log.Debugf("starting with password '%s'", s.password)
60	s.rooms.Lock()
61	s.rooms.rooms = make(map[string]roomInfo)
62	s.rooms.Unlock()
63
64	// delete old rooms
65	go func() {
66		for {
67			time.Sleep(timeToRoomDeletion)
68			var roomsToDelete []string
69			s.rooms.Lock()
70			for room := range s.rooms.rooms {
71				if time.Since(s.rooms.rooms[room].opened) > 3*time.Hour {
72					roomsToDelete = append(roomsToDelete, room)
73				}
74			}
75			s.rooms.Unlock()
76
77			for _, room := range roomsToDelete {
78				s.deleteRoom(room)
79			}
80		}
81	}()
82
83	err = s.run()
84	if err != nil {
85		log.Error(err)
86	}
87	return
88}
89
90func (s *server) run() (err error) {
91	network := "tcp"
92	addr := net.JoinHostPort(s.host, s.port)
93	if s.host != "" {
94		ip := net.ParseIP(s.host)
95		if ip == nil {
96			tcpIP, err := net.ResolveIPAddr("ip", s.host)
97			if err != nil {
98				return err
99			}
100			ip = tcpIP.IP
101		}
102		addr = net.JoinHostPort(ip.String(), s.port)
103		if s.host != "" {
104			if ip.To4() != nil {
105				network = "tcp4"
106			} else {
107				network = "tcp6"
108			}
109		}
110	}
111	addr = strings.Replace(addr, "127.0.0.1", "0.0.0.0", 1)
112	log.Infof("starting TCP server on " + addr)
113	server, err := net.Listen(network, addr)
114	if err != nil {
115		return fmt.Errorf("error listening on %s: %w", addr, err)
116	}
117	defer server.Close()
118	// spawn a new goroutine whenever a client connects
119	for {
120		connection, err := server.Accept()
121		if err != nil {
122			return fmt.Errorf("problem accepting connection: %w", err)
123		}
124		log.Debugf("client %s connected", connection.RemoteAddr().String())
125		go func(port string, connection net.Conn) {
126			c := comm.New(connection)
127			room, errCommunication := s.clientCommunication(port, c)
128			log.Debugf("room: %+v", room)
129			log.Debugf("err: %+v", errCommunication)
130			if errCommunication != nil {
131				log.Debugf("relay-%s: %s", connection.RemoteAddr().String(), errCommunication.Error())
132				connection.Close()
133				return
134			}
135			if room == pingRoom {
136				log.Debugf("got ping")
137				connection.Close()
138				return
139			}
140			for {
141				// check connection
142				log.Debugf("checking connection of room %s for %+v", room, c)
143				deleteIt := false
144				s.rooms.Lock()
145				if _, ok := s.rooms.rooms[room]; !ok {
146					log.Debug("room is gone")
147					s.rooms.Unlock()
148					return
149				}
150				log.Debugf("room: %+v", s.rooms.rooms[room])
151				if s.rooms.rooms[room].first != nil && s.rooms.rooms[room].second != nil {
152					log.Debug("rooms ready")
153					s.rooms.Unlock()
154					break
155				} else {
156					if s.rooms.rooms[room].first != nil {
157						errSend := s.rooms.rooms[room].first.Send([]byte{1})
158						if errSend != nil {
159							log.Debug(errSend)
160							deleteIt = true
161						}
162					}
163				}
164				s.rooms.Unlock()
165				if deleteIt {
166					s.deleteRoom(room)
167					break
168				}
169				time.Sleep(1 * time.Second)
170			}
171		}(s.port, connection)
172	}
173}
174
// weakKey is the fixed, publicly known PAKE password that both the relay
// (clientCommunication) and clients (ConnectToTCPServer) pass to
// pake.InitCurve. Because it is not secret, it only bootstraps an encrypted
// channel. Access to the relay is gated by the relay password that is
// exchanged over that channel.
var weakKey = []byte{0x01, 0x02, 0x03}
176
177func (s *server) clientCommunication(port string, c *comm.Comm) (room string, err error) {
178	// establish secure password with PAKE for communication with relay
179	B, err := pake.InitCurve(weakKey, 1, "siec")
180	if err != nil {
181		return
182	}
183	Abytes, err := c.Receive()
184	if err != nil {
185		return
186	}
187	log.Debugf("Abytes: %s", Abytes)
188	if bytes.Equal(Abytes, []byte("ping")) {
189		room = pingRoom
190		log.Debug("sending back pong")
191		c.Send([]byte("pong"))
192		return
193	}
194	err = B.Update(Abytes)
195	if err != nil {
196		return
197	}
198	err = c.Send(B.Bytes())
199	if err != nil {
200		return
201	}
202	strongKey, err := B.SessionKey()
203	if err != nil {
204		return
205	}
206	log.Debugf("strongkey: %x", strongKey)
207
208	// receive salt
209	salt, err := c.Receive()
210	if err != nil {
211		return
212	}
213	strongKeyForEncryption, _, err := crypt.New(strongKey, salt)
214	if err != nil {
215		return
216	}
217
218	log.Debugf("waiting for password")
219	passwordBytesEnc, err := c.Receive()
220	if err != nil {
221		return
222	}
223	passwordBytes, err := crypt.Decrypt(passwordBytesEnc, strongKeyForEncryption)
224	if err != nil {
225		return
226	}
227	if strings.TrimSpace(string(passwordBytes)) != s.password {
228		err = fmt.Errorf("bad password")
229		enc, _ := crypt.Decrypt([]byte(err.Error()), strongKeyForEncryption)
230		if err := c.Send(enc); err != nil {
231			return "", fmt.Errorf("send error: %w", err)
232		}
233		return
234	}
235
236	// send ok to tell client they are connected
237	banner := s.banner
238	if len(banner) == 0 {
239		banner = "ok"
240	}
241	log.Debugf("sending '%s'", banner)
242	bSend, err := crypt.Encrypt([]byte(banner+"|||"+c.Connection().RemoteAddr().String()), strongKeyForEncryption)
243	if err != nil {
244		return
245	}
246	err = c.Send(bSend)
247	if err != nil {
248		return
249	}
250
251	// wait for client to tell me which room they want
252	log.Debug("waiting for answer")
253	enc, err := c.Receive()
254	if err != nil {
255		return
256	}
257	roomBytes, err := crypt.Decrypt(enc, strongKeyForEncryption)
258	if err != nil {
259		return
260	}
261	room = string(roomBytes)
262
263	s.rooms.Lock()
264	// create the room if it is new
265	if _, ok := s.rooms.rooms[room]; !ok {
266		s.rooms.rooms[room] = roomInfo{
267			first:  c,
268			opened: time.Now(),
269		}
270		s.rooms.Unlock()
271		// tell the client that they got the room
272
273		bSend, err = crypt.Encrypt([]byte("ok"), strongKeyForEncryption)
274		if err != nil {
275			return
276		}
277		err = c.Send(bSend)
278		if err != nil {
279			log.Error(err)
280			s.deleteRoom(room)
281			return
282		}
283		log.Debugf("room %s has 1", room)
284		return
285	}
286	if s.rooms.rooms[room].full {
287		s.rooms.Unlock()
288		bSend, err = crypt.Encrypt([]byte("room full"), strongKeyForEncryption)
289		if err != nil {
290			return
291		}
292		err = c.Send(bSend)
293		if err != nil {
294			log.Error(err)
295			return
296		}
297		return
298	}
299	log.Debugf("room %s has 2", room)
300	s.rooms.rooms[room] = roomInfo{
301		first:  s.rooms.rooms[room].first,
302		second: c,
303		opened: s.rooms.rooms[room].opened,
304		full:   true,
305	}
306	otherConnection := s.rooms.rooms[room].first
307	s.rooms.Unlock()
308
309	// second connection is the sender, time to staple connections
310	var wg sync.WaitGroup
311	wg.Add(1)
312
313	// start piping
314	go func(com1, com2 *comm.Comm, wg *sync.WaitGroup) {
315		log.Debug("starting pipes")
316		pipe(com1.Connection(), com2.Connection())
317		wg.Done()
318		log.Debug("done piping")
319	}(otherConnection, c, &wg)
320
321	// tell the sender everything is ready
322	bSend, err = crypt.Encrypt([]byte("ok"), strongKeyForEncryption)
323	if err != nil {
324		return
325	}
326	err = c.Send(bSend)
327	if err != nil {
328		s.deleteRoom(room)
329		return
330	}
331	wg.Wait()
332
333	// delete room
334	s.deleteRoom(room)
335	return
336}
337
338func (s *server) deleteRoom(room string) {
339	s.rooms.Lock()
340	defer s.rooms.Unlock()
341	if _, ok := s.rooms.rooms[room]; !ok {
342		return
343	}
344	log.Debugf("deleting room: %s", room)
345	if s.rooms.rooms[room].first != nil {
346		s.rooms.rooms[room].first.Close()
347	}
348	if s.rooms.rooms[room].second != nil {
349		s.rooms.rooms[room].second.Close()
350	}
351	s.rooms.rooms[room] = roomInfo{first: nil, second: nil}
352	delete(s.rooms.rooms, room)
353
354}
355
356// chanFromConn creates a channel from a Conn object, and sends everything it
357//  Read()s from the socket to the channel.
358func chanFromConn(conn net.Conn) chan []byte {
359	c := make(chan []byte, 1)
360	if err := conn.SetReadDeadline(time.Now().Add(3 * time.Hour)); err != nil {
361		log.Warnf("can't set read deadline: %v", err)
362	}
363
364	go func() {
365		b := make([]byte, models.TCP_BUFFER_SIZE)
366		for {
367			n, err := conn.Read(b)
368			if n > 0 {
369				res := make([]byte, n)
370				// Copy the buffer so it doesn't get changed while read by the recipient.
371				copy(res, b[:n])
372				c <- res
373			}
374			if err != nil {
375				log.Debug(err)
376				c <- nil
377				break
378			}
379		}
380		log.Debug("exiting")
381	}()
382
383	return c
384}
385
386// pipe creates a full-duplex pipe between the two sockets and
387// transfers data from one to the other.
388func pipe(conn1 net.Conn, conn2 net.Conn) {
389	chan1 := chanFromConn(conn1)
390	chan2 := chanFromConn(conn2)
391
392	for {
393		select {
394		case b1 := <-chan1:
395			if b1 == nil {
396				return
397			}
398			if _, err := conn2.Write(b1); err != nil {
399				log.Errorf("write error on channel 1: %v", err)
400			}
401
402		case b2 := <-chan2:
403			if b2 == nil {
404				return
405			}
406			if _, err := conn1.Write(b2); err != nil {
407				log.Errorf("write error on channel 2: %v", err)
408			}
409		}
410	}
411}
412
413func PingServer(address string) (err error) {
414	log.Debugf("pinging %s", address)
415	c, err := comm.NewConnection(address, 300*time.Millisecond)
416	if err != nil {
417		log.Debug(err)
418		return
419	}
420	err = c.Send([]byte("ping"))
421	if err != nil {
422		log.Debug(err)
423		return
424	}
425	b, err := c.Receive()
426	if err != nil {
427		log.Debug(err)
428		return
429	}
430	if bytes.Equal(b, []byte("pong")) {
431		return nil
432	}
433	return fmt.Errorf("no pong")
434}
435
436// ConnectToTCPServer will initiate a new connection
437// to the specified address, room with optional time limit
438func ConnectToTCPServer(address, password, room string, timelimit ...time.Duration) (c *comm.Comm, banner string, ipaddr string, err error) {
439	if len(timelimit) > 0 {
440		c, err = comm.NewConnection(address, timelimit[0])
441	} else {
442		c, err = comm.NewConnection(address)
443	}
444	if err != nil {
445		log.Debug(err)
446		return
447	}
448
449	// get PAKE connection with server to establish strong key to transfer info
450	A, err := pake.InitCurve(weakKey, 0, "siec")
451	if err != nil {
452		log.Debug(err)
453		return
454	}
455	err = c.Send(A.Bytes())
456	if err != nil {
457		log.Debug(err)
458		return
459	}
460	Bbytes, err := c.Receive()
461	if err != nil {
462		log.Debug(err)
463		return
464	}
465	err = A.Update(Bbytes)
466	if err != nil {
467		log.Debug(err)
468		return
469	}
470	strongKey, err := A.SessionKey()
471	if err != nil {
472		log.Debug(err)
473		return
474	}
475	log.Debugf("strong key: %x", strongKey)
476
477	strongKeyForEncryption, salt, err := crypt.New(strongKey, nil)
478	if err != nil {
479		log.Debug(err)
480		return
481	}
482	// send salt
483	err = c.Send(salt)
484	if err != nil {
485		log.Debug(err)
486		return
487	}
488
489	log.Debug("sending password")
490	bSend, err := crypt.Encrypt([]byte(password), strongKeyForEncryption)
491	if err != nil {
492		log.Debug(err)
493		return
494	}
495	err = c.Send(bSend)
496	if err != nil {
497		log.Debug(err)
498		return
499	}
500	log.Debug("waiting for first ok")
501	enc, err := c.Receive()
502	if err != nil {
503		log.Debug(err)
504		return
505	}
506	data, err := crypt.Decrypt(enc, strongKeyForEncryption)
507	if err != nil {
508		log.Debug(err)
509		return
510	}
511	if !strings.Contains(string(data), "|||") {
512		err = fmt.Errorf("bad response: %s", string(data))
513		log.Debug(err)
514		return
515	}
516	banner = strings.Split(string(data), "|||")[0]
517	ipaddr = strings.Split(string(data), "|||")[1]
518	log.Debug("sending room")
519	bSend, err = crypt.Encrypt([]byte(room), strongKeyForEncryption)
520	if err != nil {
521		log.Debug(err)
522		return
523	}
524	err = c.Send(bSend)
525	if err != nil {
526		log.Debug(err)
527		return
528	}
529	log.Debug("waiting for room confirmation")
530	enc, err = c.Receive()
531	if err != nil {
532		log.Debug(err)
533		return
534	}
535	data, err = crypt.Decrypt(enc, strongKeyForEncryption)
536	if err != nil {
537		log.Debug(err)
538		return
539	}
540	if !bytes.Equal(data, []byte("ok")) {
541		err = fmt.Errorf("got bad response: %s", data)
542		log.Debug(err)
543		return
544	}
545	log.Debug("all set")
546	return
547}
548