1/*
2Copyright 2014 SAP SE
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package protocol
18
import (
	"fmt"
	"log"
	"net"

	"github.com/SAP/go-hdb/internal/bufio"
)
25
// dir is the direction tag put in front of every logged fragment. Its String
// method renders true as "->" and false as "<-"; Sniffer.Go uses dir(true) for
// data read from the client and written to the database, and dir(false) for
// the opposite path.
type dir bool
27
// maxBinarySize is the maximum number of payload bytes of a binary part that
// streamBinary writes to the log; longer payloads are logged truncated.
const maxBinarySize = 128
31
// fragment is a piece of the protocol stream that can be read from a reader
// and written to a writer. streamFragment relays values of this type from one
// side of the proxy to the other.
type fragment interface {
	// read fills the fragment from rd.
	read(rd *bufio.Reader) error
	// write emits the fragment to wr.
	write(wr *bufio.Writer) error
}
36
37func (d dir) String() string {
38	if d {
39		return "->"
40	}
41	return "<-"
42}
43
// A Sniffer is a simple proxy for logging hdb protocol requests and responses.
// It sits between a client connection and a database connection, relays the
// traffic in both directions and logs what passes through.
type Sniffer struct {
	// conn is the client side connection handed to NewSniffer.
	conn net.Conn
	// dbAddr is the database address in "host:port" format.
	dbAddr string
	// dbConn is the database connection dialed by NewSniffer.
	dbConn net.Conn

	// client: buffered reader and writer on top of conn
	clRd *bufio.Reader
	clWr *bufio.Writer
	// database: buffered reader and writer on top of dbConn
	dbRd *bufio.Reader
	dbWr *bufio.Writer

	// Header values allocated once in NewSniffer and reused by stream for
	// every message, segment and part.
	mh *messageHeader
	sh *segmentHeader
	ph *partHeader

	// buf is the scratch buffer handed out by getBuffer; it is replaced by a
	// larger one whenever a request exceeds its capacity.
	buf []byte
}
63
64// NewSniffer creates a new sniffer instance. The conn parameter is the net.Conn connection, where the Sniffer
65// is listening for hdb protocol calls. The dbAddr is the hdb host port address in "host:port" format.
66func NewSniffer(conn net.Conn, dbAddr string) (*Sniffer, error) {
67	s := &Sniffer{
68		conn:   conn,
69		dbAddr: dbAddr,
70		clRd:   bufio.NewReader(conn),
71		clWr:   bufio.NewWriter(conn),
72		mh:     &messageHeader{},
73		sh:     &segmentHeader{},
74		ph:     &partHeader{},
75		buf:    make([]byte, 0),
76	}
77
78	dbConn, err := net.Dial("tcp", s.dbAddr)
79	if err != nil {
80		return nil, err
81	}
82
83	s.dbRd = bufio.NewReader(dbConn)
84	s.dbWr = bufio.NewWriter(dbConn)
85	s.dbConn = dbConn
86	return s, nil
87}
88
89func (s *Sniffer) getBuffer(size int) []byte {
90	if cap(s.buf) < size {
91		s.buf = make([]byte, size)
92	}
93	return s.buf[:size]
94}
95
96// Go starts the protocol request and response logging.
97func (s *Sniffer) Go() {
98	defer s.dbConn.Close()
99	defer s.conn.Close()
100
101	req := newInitRequest()
102	if err := s.streamFragment(dir(true), s.clRd, s.dbWr, req); err != nil {
103		return
104	}
105
106	rep := newInitReply()
107	if err := s.streamFragment(dir(false), s.dbRd, s.clWr, rep); err != nil {
108		return
109	}
110
111	for {
112		//up stream
113		if err := s.stream(dir(true), s.clRd, s.dbWr); err != nil {
114			return
115		}
116		//down stream
117		if err := s.stream(dir(false), s.dbRd, s.clWr); err != nil {
118			return
119		}
120	}
121}
122
123func (s *Sniffer) stream(d dir, from *bufio.Reader, to *bufio.Writer) error {
124
125	if err := s.streamFragment(d, from, to, s.mh); err != nil {
126		return err
127	}
128
129	size := int(s.mh.varPartLength)
130
131	for i := 0; i < int(s.mh.noOfSegm); i++ {
132
133		if err := s.streamFragment(d, from, to, s.sh); err != nil {
134			return err
135		}
136
137		size -= int(s.sh.segmentLength)
138
139		for j := 0; j < int(s.sh.noOfParts); j++ {
140
141			if err := s.streamFragment(d, from, to, s.ph); err != nil {
142				return err
143			}
144
145			// protocol error workaraound
146			padding := (size == 0) || (j != (int(s.sh.noOfParts) - 1))
147
148			if err := s.streamPart(d, from, to, s.ph, padding); err != nil {
149				return err
150			}
151		}
152	}
153	return to.Flush()
154}
155
156func (s *Sniffer) streamPart(d dir, from *bufio.Reader, to *bufio.Writer, ph *partHeader, padding bool) error {
157
158	switch ph.partKind {
159
160	default:
161		return s.streamBinary(d, from, to, int(ph.bufferLength), padding)
162	}
163}
164
165func (s *Sniffer) streamBinary(d dir, from *bufio.Reader, to *bufio.Writer, size int, padding bool) error {
166	var b []byte
167
168	//protocol error workaraound
169	if padding {
170		pad := padBytes(size)
171		b = s.getBuffer(size + pad)
172	} else {
173		b = s.getBuffer(size)
174	}
175
176	from.ReadFull(b)
177	err := from.GetError()
178	if err != nil {
179		log.Print(err)
180		return err
181	}
182
183	if size > maxBinarySize {
184		log.Printf("%s %v", d, b[:maxBinarySize])
185	} else {
186		log.Printf("%s %v", d, b[:size])
187	}
188	to.Write(b)
189	return nil
190}
191
192func (s *Sniffer) streamFragment(d dir, from *bufio.Reader, to *bufio.Writer, f fragment) error {
193	if err := f.read(from); err != nil {
194		log.Print(err)
195		return err
196	}
197	log.Printf("%s %s", d, f)
198	if err := f.write(to); err != nil {
199		log.Print(err)
200		return err
201	}
202	return nil
203}
204