1package steam
2
3import (
4	"bytes"
5	"compress/gzip"
6	"crypto/rand"
7	"encoding/binary"
8	"fmt"
9	"hash/crc32"
10	"io/ioutil"
11	"net"
12	"sync"
13	"sync/atomic"
14	"time"
15
16	"github.com/Philipp15b/go-steam/cryptoutil"
17	"github.com/Philipp15b/go-steam/netutil"
18	. "github.com/Philipp15b/go-steam/protocol"
19	. "github.com/Philipp15b/go-steam/protocol/protobuf"
20	. "github.com/Philipp15b/go-steam/protocol/steamlang"
21	. "github.com/Philipp15b/go-steam/steamid"
22)
23
// Client represents a client to the Steam network.
// Always poll events from the channel returned by Events(), or receiving
// messages will stop (emitting an event blocks while the channel is full).
// All access, unless otherwise noted, should be threadsafe.
//
// When a FatalErrorEvent is emitted, the connection is automatically closed.
// The same client can be used to reconnect. Other errors don't have any effect.
type Client struct {
	// these need to be 64 bit aligned for sync/atomic on 32bit
	// sessionId is only 32 bits wide, so the blank uint32 pads it out and
	// steamId / currentJobId start at 8-byte offsets (the first word of an
	// allocated struct is guaranteed to be 64-bit aligned).
	// sessionId and steamId are read atomically by SessionId/SteamId and
	// stamped onto outgoing client messages by Write; currentJobId is the
	// counter incremented by GetNextJobId.
	sessionId    int32
	_            uint32
	steamId      uint64
	currentJobId uint64

	// Subsystems. NewClient constructs each one and registers it as a
	// PacketHandler.
	Auth          *Auth
	Social        *Social
	Web           *Web
	Notifications *Notifications
	Trading       *Trading
	GC            *GameCoordinator

	// events is the channel returned by Events and written by Emit.
	events chan interface{}
	// handlers receive every incoming packet (see handlePacket); the slice is
	// guarded by handlersMutex.
	handlers      []PacketHandler
	handlersMutex sync.RWMutex

	// tempSessionKey holds the session key generated while answering a
	// ChannelEncryptRequest until the ChannelEncryptResult installs it on conn.
	tempSessionKey []byte

	// NOTE(review): not referenced in this file section; presumably consumed by
	// connection setup elsewhere — TODO confirm.
	ConnectionTimeout time.Duration

	// Disconnect also stops heartbeat while holding mutex.
	mutex sync.RWMutex // guarding conn and writeChan
	// conn is the active connection (nil while disconnected); writeChan is the
	// send queue drained by writeLoop and closed by Disconnect.
	conn      connection
	writeChan chan IMsg
	// writeBuf is a serialization scratch buffer allocated by NewClient.
	writeBuf *bytes.Buffer
	// heartbeat is the ticker created by heartbeatLoop; Disconnect stops it.
	heartbeat *time.Ticker
}
58
59type PacketHandler interface {
60	HandlePacket(*Packet)
61}
62
63func NewClient() *Client {
64	client := &Client{
65		events:   make(chan interface{}, 3),
66		writeBuf: new(bytes.Buffer),
67	}
68	client.Auth = &Auth{client: client}
69	client.RegisterPacketHandler(client.Auth)
70	client.Social = newSocial(client)
71	client.RegisterPacketHandler(client.Social)
72	client.Web = &Web{client: client}
73	client.RegisterPacketHandler(client.Web)
74	client.Notifications = newNotifications(client)
75	client.RegisterPacketHandler(client.Notifications)
76	client.Trading = &Trading{client: client}
77	client.RegisterPacketHandler(client.Trading)
78	client.GC = newGC(client)
79	client.RegisterPacketHandler(client.GC)
80	return client
81}
82
83// Get the event channel. By convention all events are pointers, except for errors.
84// It is never closed.
85func (c *Client) Events() <-chan interface{} {
86	return c.events
87}
88
89func (c *Client) Emit(event interface{}) {
90	c.events <- event
91}
92
93// Emits a FatalErrorEvent formatted with fmt.Errorf and disconnects.
94func (c *Client) Fatalf(format string, a ...interface{}) {
95	c.Emit(FatalErrorEvent(fmt.Errorf(format, a...)))
96	c.Disconnect()
97}
98
99// Emits an error formatted with fmt.Errorf.
100func (c *Client) Errorf(format string, a ...interface{}) {
101	c.Emit(fmt.Errorf(format, a...))
102}
103
104// Registers a PacketHandler that receives all incoming packets.
105func (c *Client) RegisterPacketHandler(handler PacketHandler) {
106	c.handlersMutex.Lock()
107	defer c.handlersMutex.Unlock()
108	c.handlers = append(c.handlers, handler)
109}
110
111func (c *Client) GetNextJobId() JobId {
112	return JobId(atomic.AddUint64(&c.currentJobId, 1))
113}
114
115func (c *Client) SteamId() SteamId {
116	return SteamId(atomic.LoadUint64(&c.steamId))
117}
118
119func (c *Client) SessionId() int32 {
120	return atomic.LoadInt32(&c.sessionId)
121}
122
123func (c *Client) Connected() bool {
124	c.mutex.RLock()
125	defer c.mutex.RUnlock()
126	return c.conn != nil
127}
128
129// Connects to a random Steam server and returns its address.
130// If this client is already connected, it is disconnected first.
131// This method tries to use an address from the Steam Directory and falls
132// back to the built-in server list if the Steam Directory can't be reached.
133// If you want to connect to a specific server, use `ConnectTo`.
134func (c *Client) Connect() *netutil.PortAddr {
135	var server *netutil.PortAddr
136
137	// try to initialize the directory cache
138	if !steamDirectoryCache.IsInitialized() {
139		_ = steamDirectoryCache.Initialize()
140	}
141	if steamDirectoryCache.IsInitialized() {
142		server = steamDirectoryCache.GetRandomCM()
143	} else {
144		server = GetRandomCM()
145	}
146
147	c.ConnectTo(server)
148	return server
149}
150
151// Connects to a specific server.
152// You may want to use one of the `GetRandom*CM()` functions in this package.
153// If this client is already connected, it is disconnected first.
154func (c *Client) ConnectTo(addr *netutil.PortAddr) {
155	c.ConnectToBind(addr, nil)
156}
157
158// Connects to a specific server, and binds to a specified local IP
159// If this client is already connected, it is disconnected first.
160func (c *Client) ConnectToBind(addr *netutil.PortAddr, local *net.TCPAddr) {
161	c.Disconnect()
162
163	conn, err := dialTCP(local, addr.ToTCPAddr())
164	if err != nil {
165		c.Fatalf("Connect failed: %v", err)
166		return
167	}
168	c.conn = conn
169	c.writeChan = make(chan IMsg, 5)
170
171	go c.readLoop()
172	go c.writeLoop()
173}
174
175func (c *Client) Disconnect() {
176	c.mutex.Lock()
177	defer c.mutex.Unlock()
178
179	if c.conn == nil {
180		return
181	}
182
183	c.conn.Close()
184	c.conn = nil
185	if c.heartbeat != nil {
186		c.heartbeat.Stop()
187	}
188	close(c.writeChan)
189	c.Emit(&DisconnectedEvent{})
190
191}
192
193// Adds a message to the send queue. Modifications to the given message after
194// writing are not allowed (possible race conditions).
195//
196// Writes to this client when not connected are ignored.
197func (c *Client) Write(msg IMsg) {
198	if cm, ok := msg.(IClientMsg); ok {
199		cm.SetSessionId(c.SessionId())
200		cm.SetSteamId(c.SteamId())
201	}
202	c.mutex.RLock()
203	defer c.mutex.RUnlock()
204	if c.conn == nil {
205		return
206	}
207	c.writeChan <- msg
208}
209
210func (c *Client) readLoop() {
211	for {
212		// This *should* be atomic on most platforms, but the Go spec doesn't guarantee it
213		c.mutex.RLock()
214		conn := c.conn
215		c.mutex.RUnlock()
216		if conn == nil {
217			return
218		}
219		packet, err := conn.Read()
220
221		if err != nil {
222			c.Fatalf("Error reading from the connection: %v", err)
223			return
224		}
225		c.handlePacket(packet)
226	}
227}
228
229func (c *Client) writeLoop() {
230	for {
231		c.mutex.RLock()
232		conn := c.conn
233		c.mutex.RUnlock()
234		if conn == nil {
235			return
236		}
237
238		msg, ok := <-c.writeChan
239		if !ok {
240			return
241		}
242
243		err := msg.Serialize(c.writeBuf)
244		if err != nil {
245			c.writeBuf.Reset()
246			c.Fatalf("Error serializing message %v: %v", msg, err)
247			return
248		}
249
250		err = conn.Write(c.writeBuf.Bytes())
251
252		c.writeBuf.Reset()
253
254		if err != nil {
255			c.Fatalf("Error writing message %v: %v", msg, err)
256			return
257		}
258	}
259}
260
261func (c *Client) heartbeatLoop(seconds time.Duration) {
262	if c.heartbeat != nil {
263		c.heartbeat.Stop()
264	}
265	c.heartbeat = time.NewTicker(seconds * time.Second)
266	for {
267		_, ok := <-c.heartbeat.C
268		if !ok {
269			break
270		}
271		c.Write(NewClientMsgProtobuf(EMsg_ClientHeartBeat, new(CMsgClientHeartBeat)))
272	}
273	c.heartbeat = nil
274}
275
276func (c *Client) handlePacket(packet *Packet) {
277	switch packet.EMsg {
278	case EMsg_ChannelEncryptRequest:
279		c.handleChannelEncryptRequest(packet)
280	case EMsg_ChannelEncryptResult:
281		c.handleChannelEncryptResult(packet)
282	case EMsg_Multi:
283		c.handleMulti(packet)
284	case EMsg_ClientCMList:
285		c.handleClientCMList(packet)
286	}
287
288	c.handlersMutex.RLock()
289	defer c.handlersMutex.RUnlock()
290	for _, handler := range c.handlers {
291		handler.HandlePacket(packet)
292	}
293}
294
295func (c *Client) handleChannelEncryptRequest(packet *Packet) {
296	body := NewMsgChannelEncryptRequest()
297	packet.ReadMsg(body)
298
299	if body.Universe != EUniverse_Public {
300		c.Fatalf("Invalid univserse %v!", body.Universe)
301	}
302
303	c.tempSessionKey = make([]byte, 32)
304	rand.Read(c.tempSessionKey)
305	encryptedKey := cryptoutil.RSAEncrypt(GetPublicKey(EUniverse_Public), c.tempSessionKey)
306
307	payload := new(bytes.Buffer)
308	payload.Write(encryptedKey)
309	binary.Write(payload, binary.LittleEndian, crc32.ChecksumIEEE(encryptedKey))
310	payload.WriteByte(0)
311	payload.WriteByte(0)
312	payload.WriteByte(0)
313	payload.WriteByte(0)
314
315	c.Write(NewMsg(NewMsgChannelEncryptResponse(), payload.Bytes()))
316}
317
318func (c *Client) handleChannelEncryptResult(packet *Packet) {
319	body := NewMsgChannelEncryptResult()
320	packet.ReadMsg(body)
321
322	if body.Result != EResult_OK {
323		c.Fatalf("Encryption failed: %v", body.Result)
324		return
325	}
326	c.conn.SetEncryptionKey(c.tempSessionKey)
327	c.tempSessionKey = nil
328
329	c.Emit(&ConnectedEvent{})
330}
331
332func (c *Client) handleMulti(packet *Packet) {
333	body := new(CMsgMulti)
334	packet.ReadProtoMsg(body)
335
336	payload := body.GetMessageBody()
337
338	if body.GetSizeUnzipped() > 0 {
339		r, err := gzip.NewReader(bytes.NewReader(payload))
340		if err != nil {
341			c.Errorf("handleMulti: Error while decompressing: %v", err)
342			return
343		}
344
345		payload, err = ioutil.ReadAll(r)
346		if err != nil {
347			c.Errorf("handleMulti: Error while decompressing: %v", err)
348			return
349		}
350	}
351
352	pr := bytes.NewReader(payload)
353	for pr.Len() > 0 {
354		var length uint32
355		binary.Read(pr, binary.LittleEndian, &length)
356		packetData := make([]byte, length)
357		pr.Read(packetData)
358		p, err := NewPacket(packetData)
359		if err != nil {
360			c.Errorf("Error reading packet in Multi msg %v: %v", packet, err)
361			continue
362		}
363		c.handlePacket(p)
364	}
365}
366
367func (c *Client) handleClientCMList(packet *Packet) {
368	body := new(CMsgClientCMList)
369	packet.ReadProtoMsg(body)
370
371	l := make([]*netutil.PortAddr, 0)
372	for i, ip := range body.GetCmAddresses() {
373		l = append(l, &netutil.PortAddr{
374			readIp(ip),
375			uint16(body.GetCmPorts()[i]),
376		})
377	}
378
379	c.Emit(&ClientCMListEvent{l})
380}
381
// readIp unpacks an IPv4 address stored in a uint32, most significant byte
// first, into a 4-byte net.IP (for example, 0x7f000001 becomes 127.0.0.1).
func readIp(ip uint32) net.IP {
	addr := make(net.IP, net.IPv4len)
	binary.BigEndian.PutUint32(addr, ip)
	return addr
}
390