1// Copyright 2020 The go-ethereum Authors
2// This file is part of the go-ethereum library.
3//
4// The go-ethereum library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU Lesser General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8//
9// The go-ethereum library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU Lesser General Public License for more details.
13//
14// You should have received a copy of the GNU Lesser General Public License
15// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
16
17package ethtest
18
19import (
20	"crypto/ecdsa"
21	"fmt"
22
23	"github.com/ethereum/go-ethereum/eth/protocols/eth"
24	"github.com/ethereum/go-ethereum/p2p"
25	"github.com/ethereum/go-ethereum/p2p/rlpx"
26	"github.com/ethereum/go-ethereum/rlp"
27)
28
// Message is implemented by every packet type that Conn.Read can return and
// that Conn.Write accepts. Error also implements it, which is how read
// failures are reported in-band.
type Message interface {
	// Code returns the message code of the packet. Conn.Write sends it as the
	// code of the outgoing message and Conn.Read compares the incoming code
	// against it to pick the type to decode into.
	Code() int
}
32
// Error wraps an error so that it can be handed back wherever a Message is
// expected. It satisfies both the error interface and Message.
type Error struct {
	err error
}

// Unwrap returns the wrapped error.
func (e *Error) Unwrap() error {
	return e.err
}

// Error returns the message of the wrapped error.
func (e *Error) Error() string {
	return e.err.Error()
}

// Code returns -1, a value that none of the packet types in this file use as
// their message code.
func (e *Error) Code() int {
	return -1
}

// String returns the same text as Error.
func (e *Error) String() string {
	return e.err.Error()
}

// errorf formats its arguments with fmt.Errorf and wraps the result in an
// *Error.
func errorf(format string, args ...interface{}) *Error {
	wrapped := fmt.Errorf(format, args...)
	return &Error{err: wrapped}
}
45
46// Hello is the RLP structure of the protocol handshake.
47type Hello struct {
48	Version    uint64
49	Name       string
50	Caps       []p2p.Cap
51	ListenPort uint64
52	ID         []byte // secp256k1 public key
53
54	// Ignore additional fields (for forward compatibility).
55	Rest []rlp.RawValue `rlp:"tail"`
56}
57
58func (h Hello) Code() int { return 0x00 }
59
60// Disconnect is the RLP structure for a disconnect message.
61type Disconnect struct {
62	Reason p2p.DiscReason
63}
64
65func (d Disconnect) Code() int { return 0x01 }
66
// Ping is the ping message. It has no fields.
type Ping struct{}

// Code returns the message code of the ping message (0x02).
func (Ping) Code() int {
	return 0x02
}
70
// Pong is the pong message. It has no fields.
type Pong struct{}

// Code returns the message code of the pong message (0x03).
func (Pong) Code() int {
	return 0x03
}
74
75// Status is the network packet for the status message for eth/64 and later.
76type Status eth.StatusPacket
77
78func (s Status) Code() int { return 16 }
79
80// NewBlockHashes is the network packet for the block announcements.
81type NewBlockHashes eth.NewBlockHashesPacket
82
83func (nbh NewBlockHashes) Code() int { return 17 }
84
85type Transactions eth.TransactionsPacket
86
87func (t Transactions) Code() int { return 18 }
88
89// GetBlockHeaders represents a block header query.
90type GetBlockHeaders eth.GetBlockHeadersPacket
91
92func (g GetBlockHeaders) Code() int { return 19 }
93
94type BlockHeaders eth.BlockHeadersPacket
95
96func (bh BlockHeaders) Code() int { return 20 }
97
98// GetBlockBodies represents a GetBlockBodies request
99type GetBlockBodies eth.GetBlockBodiesPacket
100
101func (gbb GetBlockBodies) Code() int { return 21 }
102
103// BlockBodies is the network packet for block content distribution.
104type BlockBodies eth.BlockBodiesPacket
105
106func (bb BlockBodies) Code() int { return 22 }
107
108// NewBlock is the network packet for the block propagation message.
109type NewBlock eth.NewBlockPacket
110
111func (nb NewBlock) Code() int { return 23 }
112
113// NewPooledTransactionHashes is the network packet for the tx hash propagation message.
114type NewPooledTransactionHashes eth.NewPooledTransactionHashesPacket
115
116func (nb NewPooledTransactionHashes) Code() int { return 24 }
117
118type GetPooledTransactions eth.GetPooledTransactionsPacket
119
120func (gpt GetPooledTransactions) Code() int { return 25 }
121
122type PooledTransactions eth.PooledTransactionsPacket
123
124func (pt PooledTransactions) Code() int { return 26 }
125
// Conn represents an individual connection with a peer.
//
// It embeds *rlpx.Conn. The Read and Write methods declared on Conn in this
// file shadow the embedded ones and add RLP decoding/encoding of the message
// payload on top of them.
type Conn struct {
	*rlpx.Conn
	// ourKey is an ECDSA private key; presumably the key identifying our side
	// of the connection — it is not used in this file, confirm against callers.
	ourKey                 *ecdsa.PrivateKey
	// NOTE(review): the two version fields are neither set nor read in this
	// file; by name they look like the protocol version agreed with the peer
	// and the highest version we support — confirm against callers.
	negotiatedProtoVersion uint
	ourHighestProtoVersion uint
	// caps is a list of p2p capabilities; presumably the ones advertised for
	// this connection — TODO confirm.
	caps                   []p2p.Cap
}
134
135// Read reads an eth packet from the connection.
136func (c *Conn) Read() Message {
137	code, rawData, _, err := c.Conn.Read()
138	if err != nil {
139		return errorf("could not read from connection: %v", err)
140	}
141
142	var msg Message
143	switch int(code) {
144	case (Hello{}).Code():
145		msg = new(Hello)
146	case (Ping{}).Code():
147		msg = new(Ping)
148	case (Pong{}).Code():
149		msg = new(Pong)
150	case (Disconnect{}).Code():
151		msg = new(Disconnect)
152	case (Status{}).Code():
153		msg = new(Status)
154	case (GetBlockHeaders{}).Code():
155		msg = new(GetBlockHeaders)
156	case (BlockHeaders{}).Code():
157		msg = new(BlockHeaders)
158	case (GetBlockBodies{}).Code():
159		msg = new(GetBlockBodies)
160	case (BlockBodies{}).Code():
161		msg = new(BlockBodies)
162	case (NewBlock{}).Code():
163		msg = new(NewBlock)
164	case (NewBlockHashes{}).Code():
165		msg = new(NewBlockHashes)
166	case (Transactions{}).Code():
167		msg = new(Transactions)
168	case (NewPooledTransactionHashes{}).Code():
169		msg = new(NewPooledTransactionHashes)
170	case (GetPooledTransactions{}.Code()):
171		msg = new(GetPooledTransactions)
172	case (PooledTransactions{}.Code()):
173		msg = new(PooledTransactions)
174	default:
175		return errorf("invalid message code: %d", code)
176	}
177	// if message is devp2p, decode here
178	if err := rlp.DecodeBytes(rawData, msg); err != nil {
179		return errorf("could not rlp decode message: %v", err)
180	}
181	return msg
182}
183
184// Read66 reads an eth66 packet from the connection.
185func (c *Conn) Read66() (uint64, Message) {
186	code, rawData, _, err := c.Conn.Read()
187	if err != nil {
188		return 0, errorf("could not read from connection: %v", err)
189	}
190
191	var msg Message
192	switch int(code) {
193	case (Hello{}).Code():
194		msg = new(Hello)
195	case (Ping{}).Code():
196		msg = new(Ping)
197	case (Pong{}).Code():
198		msg = new(Pong)
199	case (Disconnect{}).Code():
200		msg = new(Disconnect)
201	case (Status{}).Code():
202		msg = new(Status)
203	case (GetBlockHeaders{}).Code():
204		ethMsg := new(eth.GetBlockHeadersPacket66)
205		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
206			return 0, errorf("could not rlp decode message: %v", err)
207		}
208		return ethMsg.RequestId, GetBlockHeaders(*ethMsg.GetBlockHeadersPacket)
209	case (BlockHeaders{}).Code():
210		ethMsg := new(eth.BlockHeadersPacket66)
211		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
212			return 0, errorf("could not rlp decode message: %v", err)
213		}
214		return ethMsg.RequestId, BlockHeaders(ethMsg.BlockHeadersPacket)
215	case (GetBlockBodies{}).Code():
216		ethMsg := new(eth.GetBlockBodiesPacket66)
217		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
218			return 0, errorf("could not rlp decode message: %v", err)
219		}
220		return ethMsg.RequestId, GetBlockBodies(ethMsg.GetBlockBodiesPacket)
221	case (BlockBodies{}).Code():
222		ethMsg := new(eth.BlockBodiesPacket66)
223		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
224			return 0, errorf("could not rlp decode message: %v", err)
225		}
226		return ethMsg.RequestId, BlockBodies(ethMsg.BlockBodiesPacket)
227	case (NewBlock{}).Code():
228		msg = new(NewBlock)
229	case (NewBlockHashes{}).Code():
230		msg = new(NewBlockHashes)
231	case (Transactions{}).Code():
232		msg = new(Transactions)
233	case (NewPooledTransactionHashes{}).Code():
234		msg = new(NewPooledTransactionHashes)
235	case (GetPooledTransactions{}.Code()):
236		ethMsg := new(eth.GetPooledTransactionsPacket66)
237		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
238			return 0, errorf("could not rlp decode message: %v", err)
239		}
240		return ethMsg.RequestId, GetPooledTransactions(ethMsg.GetPooledTransactionsPacket)
241	case (PooledTransactions{}.Code()):
242		ethMsg := new(eth.PooledTransactionsPacket66)
243		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
244			return 0, errorf("could not rlp decode message: %v", err)
245		}
246		return ethMsg.RequestId, PooledTransactions(ethMsg.PooledTransactionsPacket)
247	default:
248		msg = errorf("invalid message code: %d", code)
249	}
250
251	if msg != nil {
252		if err := rlp.DecodeBytes(rawData, msg); err != nil {
253			return 0, errorf("could not rlp decode message: %v", err)
254		}
255		return 0, msg
256	}
257	return 0, errorf("invalid message: %s", string(rawData))
258}
259
260// Write writes a eth packet to the connection.
261func (c *Conn) Write(msg Message) error {
262	// check if message is eth protocol message
263	var (
264		payload []byte
265		err     error
266	)
267	payload, err = rlp.EncodeToBytes(msg)
268	if err != nil {
269		return err
270	}
271	_, err = c.Conn.Write(uint64(msg.Code()), payload)
272	return err
273}
274
275// Write66 writes an eth66 packet to the connection.
276func (c *Conn) Write66(req eth.Packet, code int) error {
277	payload, err := rlp.EncodeToBytes(req)
278	if err != nil {
279		return err
280	}
281	_, err = c.Conn.Write(uint64(code), payload)
282	return err
283}
284