1// Copyright (C) 2015 Audrius Butkevicius and Contributors (see the CONTRIBUTORS file).
2
3package protocol
4
5import (
6	"errors"
7	"fmt"
8	"io"
9)
10
// Protocol-level constants.
const (
	// magic is stored in every header built by WriteMessage and compared by
	// ReadMessage, which rejects any frame whose header carries a different
	// value ("magic mismatch").
	magic        = 0x9E79BC40
	// ProtocolName is the protocol identifier string. It is not referenced in
	// this file. NOTE(review): presumably used for protocol negotiation by
	// callers; confirm against its users.
	ProtocolName = "bep-relay"
)
15
// Predefined Response values, each built positionally from a numeric code and
// a message string. They can be sent with WriteMessage, which accepts
// Response values.
// NOTE(review): these are mutable package-level variables; callers are
// presumably expected to treat them as read-only. Confirm.
var (
	// ResponseSuccess carries code 0, "success".
	ResponseSuccess           = Response{0, "success"}
	// ResponseNotFound carries code 1, "not found".
	ResponseNotFound          = Response{1, "not found"}
	// ResponseAlreadyConnected carries code 2, "already connected".
	ResponseAlreadyConnected  = Response{2, "already connected"}
	// ResponseUnexpectedMessage carries code 100, "unexpected message".
	ResponseUnexpectedMessage = Response{100, "unexpected message"}
)
22
23func WriteMessage(w io.Writer, message interface{}) error {
24	header := header{
25		magic: magic,
26	}
27
28	var payload []byte
29	var err error
30
31	switch msg := message.(type) {
32	case Ping:
33		payload, err = msg.MarshalXDR()
34		header.messageType = messageTypePing
35	case Pong:
36		payload, err = msg.MarshalXDR()
37		header.messageType = messageTypePong
38	case JoinRelayRequest:
39		payload, err = msg.MarshalXDR()
40		header.messageType = messageTypeJoinRelayRequest
41	case JoinSessionRequest:
42		payload, err = msg.MarshalXDR()
43		header.messageType = messageTypeJoinSessionRequest
44	case Response:
45		payload, err = msg.MarshalXDR()
46		header.messageType = messageTypeResponse
47	case ConnectRequest:
48		payload, err = msg.MarshalXDR()
49		header.messageType = messageTypeConnectRequest
50	case SessionInvitation:
51		payload, err = msg.MarshalXDR()
52		header.messageType = messageTypeSessionInvitation
53	case RelayFull:
54		payload, err = msg.MarshalXDR()
55		header.messageType = messageTypeRelayFull
56	default:
57		err = errors.New("unknown message type")
58	}
59
60	if err != nil {
61		return err
62	}
63
64	header.messageLength = int32(len(payload))
65
66	headerpayload, err := header.MarshalXDR()
67	if err != nil {
68		return err
69	}
70
71	_, err = w.Write(append(headerpayload, payload...))
72	return err
73}
74
75func ReadMessage(r io.Reader) (interface{}, error) {
76	var header header
77
78	buf := make([]byte, header.XDRSize())
79	if _, err := io.ReadFull(r, buf); err != nil {
80		return nil, err
81	}
82
83	if err := header.UnmarshalXDR(buf); err != nil {
84		return nil, err
85	}
86
87	if header.magic != magic {
88		return nil, errors.New("magic mismatch")
89	}
90	if header.messageLength < 0 || header.messageLength > 1024 {
91		return nil, fmt.Errorf("bad length (%d)", header.messageLength)
92	}
93
94	buf = make([]byte, int(header.messageLength))
95	if _, err := io.ReadFull(r, buf); err != nil {
96		return nil, err
97	}
98
99	switch header.messageType {
100	case messageTypePing:
101		var msg Ping
102		err := msg.UnmarshalXDR(buf)
103		return msg, err
104	case messageTypePong:
105		var msg Pong
106		err := msg.UnmarshalXDR(buf)
107		return msg, err
108	case messageTypeJoinRelayRequest:
109		var msg JoinRelayRequest
110		err := msg.UnmarshalXDR(buf)
111		return msg, err
112	case messageTypeJoinSessionRequest:
113		var msg JoinSessionRequest
114		err := msg.UnmarshalXDR(buf)
115		return msg, err
116	case messageTypeResponse:
117		var msg Response
118		err := msg.UnmarshalXDR(buf)
119		return msg, err
120	case messageTypeConnectRequest:
121		var msg ConnectRequest
122		err := msg.UnmarshalXDR(buf)
123		return msg, err
124	case messageTypeSessionInvitation:
125		var msg SessionInvitation
126		err := msg.UnmarshalXDR(buf)
127		return msg, err
128	case messageTypeRelayFull:
129		var msg RelayFull
130		err := msg.UnmarshalXDR(buf)
131		return msg, err
132	}
133
134	return nil, errors.New("unknown message type")
135}
136