1package readline
2
3import (
4	"bufio"
5	"bytes"
6	"encoding/binary"
7	"fmt"
8	"io"
9	"net"
10	"os"
11	"sync"
12	"sync/atomic"
13)
14
// MsgType identifies the kind of a Message exchanged between RemoteCli and
// RemoteSvr. It is sent on the wire as a big-endian int16 right after the
// frame length.
type MsgType int16

// Message types. The numeric values are what goes over the wire, so the
// order of this list must not change.
const (
	T_DATA         MsgType = iota // terminal data, sent in either direction
	T_WIDTH                       // not sent or handled by RemoteCli/RemoteSvr
	T_WIDTH_REPORT                // client -> server: 2-byte big-endian screen width
	T_ISTTY_REPORT                // client -> server: 2-byte big-endian flag, non-zero = terminal
	T_RAW                         // server -> client: enter raw mode
	T_ERAW                        // server -> client: exit raw mode
	T_EOF                         // client -> server: no more input
)
26
27type RemoteSvr struct {
28	eof           int32
29	closed        int32
30	width         int32
31	reciveChan    chan struct{}
32	writeChan     chan *writeCtx
33	conn          net.Conn
34	isTerminal    bool
35	funcWidthChan func()
36	stopChan      chan struct{}
37
38	dataBufM sync.Mutex
39	dataBuf  bytes.Buffer
40}
41
42type writeReply struct {
43	n   int
44	err error
45}
46
47type writeCtx struct {
48	msg   *Message
49	reply chan *writeReply
50}
51
52func newWriteCtx(msg *Message) *writeCtx {
53	return &writeCtx{
54		msg:   msg,
55		reply: make(chan *writeReply),
56	}
57}
58
59func NewRemoteSvr(conn net.Conn) (*RemoteSvr, error) {
60	rs := &RemoteSvr{
61		width:      -1,
62		conn:       conn,
63		writeChan:  make(chan *writeCtx),
64		reciveChan: make(chan struct{}),
65		stopChan:   make(chan struct{}),
66	}
67	buf := bufio.NewReader(rs.conn)
68
69	if err := rs.init(buf); err != nil {
70		return nil, err
71	}
72
73	go rs.readLoop(buf)
74	go rs.writeLoop()
75	return rs, nil
76}
77
78func (r *RemoteSvr) init(buf *bufio.Reader) error {
79	m, err := ReadMessage(buf)
80	if err != nil {
81		return err
82	}
83	// receive isTerminal
84	if m.Type != T_ISTTY_REPORT {
85		return fmt.Errorf("unexpected init message")
86	}
87	r.GotIsTerminal(m.Data)
88
89	// receive width
90	m, err = ReadMessage(buf)
91	if err != nil {
92		return err
93	}
94	if m.Type != T_WIDTH_REPORT {
95		return fmt.Errorf("unexpected init message")
96	}
97	r.GotReportWidth(m.Data)
98
99	return nil
100}
101
102func (r *RemoteSvr) HandleConfig(cfg *Config) {
103	cfg.Stderr = r
104	cfg.Stdout = r
105	cfg.Stdin = r
106	cfg.FuncExitRaw = r.ExitRawMode
107	cfg.FuncIsTerminal = r.IsTerminal
108	cfg.FuncMakeRaw = r.EnterRawMode
109	cfg.FuncExitRaw = r.ExitRawMode
110	cfg.FuncGetWidth = r.GetWidth
111	cfg.FuncOnWidthChanged = func(f func()) {
112		r.funcWidthChan = f
113	}
114}
115
116func (r *RemoteSvr) IsTerminal() bool {
117	return r.isTerminal
118}
119
120func (r *RemoteSvr) checkEOF() error {
121	if atomic.LoadInt32(&r.eof) == 1 {
122		return io.EOF
123	}
124	return nil
125}
126
127func (r *RemoteSvr) Read(b []byte) (int, error) {
128	r.dataBufM.Lock()
129	n, err := r.dataBuf.Read(b)
130	r.dataBufM.Unlock()
131	if n == 0 {
132		if err := r.checkEOF(); err != nil {
133			return 0, err
134		}
135	}
136
137	if n == 0 && err == io.EOF {
138		<-r.reciveChan
139		r.dataBufM.Lock()
140		n, err = r.dataBuf.Read(b)
141		r.dataBufM.Unlock()
142	}
143	if n == 0 {
144		if err := r.checkEOF(); err != nil {
145			return 0, err
146		}
147	}
148
149	return n, err
150}
151
152func (r *RemoteSvr) writeMsg(m *Message) error {
153	ctx := newWriteCtx(m)
154	r.writeChan <- ctx
155	reply := <-ctx.reply
156	return reply.err
157}
158
159func (r *RemoteSvr) Write(b []byte) (int, error) {
160	ctx := newWriteCtx(NewMessage(T_DATA, b))
161	r.writeChan <- ctx
162	reply := <-ctx.reply
163	return reply.n, reply.err
164}
165
166func (r *RemoteSvr) EnterRawMode() error {
167	return r.writeMsg(NewMessage(T_RAW, nil))
168}
169
170func (r *RemoteSvr) ExitRawMode() error {
171	return r.writeMsg(NewMessage(T_ERAW, nil))
172}
173
174func (r *RemoteSvr) writeLoop() {
175	defer r.Close()
176
177loop:
178	for {
179		select {
180		case ctx, ok := <-r.writeChan:
181			if !ok {
182				break
183			}
184			n, err := ctx.msg.WriteTo(r.conn)
185			ctx.reply <- &writeReply{n, err}
186		case <-r.stopChan:
187			break loop
188		}
189	}
190}
191
192func (r *RemoteSvr) Close() error {
193	if atomic.CompareAndSwapInt32(&r.closed, 0, 1) {
194		close(r.stopChan)
195		r.conn.Close()
196	}
197	return nil
198}
199
200func (r *RemoteSvr) readLoop(buf *bufio.Reader) {
201	defer r.Close()
202	for {
203		m, err := ReadMessage(buf)
204		if err != nil {
205			break
206		}
207		switch m.Type {
208		case T_EOF:
209			atomic.StoreInt32(&r.eof, 1)
210			select {
211			case r.reciveChan <- struct{}{}:
212			default:
213			}
214		case T_DATA:
215			r.dataBufM.Lock()
216			r.dataBuf.Write(m.Data)
217			r.dataBufM.Unlock()
218			select {
219			case r.reciveChan <- struct{}{}:
220			default:
221			}
222		case T_WIDTH_REPORT:
223			r.GotReportWidth(m.Data)
224		case T_ISTTY_REPORT:
225			r.GotIsTerminal(m.Data)
226		}
227	}
228}
229
230func (r *RemoteSvr) GotIsTerminal(data []byte) {
231	if binary.BigEndian.Uint16(data) == 0 {
232		r.isTerminal = false
233	} else {
234		r.isTerminal = true
235	}
236}
237
238func (r *RemoteSvr) GotReportWidth(data []byte) {
239	atomic.StoreInt32(&r.width, int32(binary.BigEndian.Uint16(data)))
240	if r.funcWidthChan != nil {
241		r.funcWidthChan()
242	}
243}
244
245func (r *RemoteSvr) GetWidth() int {
246	return int(atomic.LoadInt32(&r.width))
247}
248
249// -----------------------------------------------------------------------------
250
251type Message struct {
252	Type MsgType
253	Data []byte
254}
255
256func ReadMessage(r io.Reader) (*Message, error) {
257	m := new(Message)
258	var length int32
259	if err := binary.Read(r, binary.BigEndian, &length); err != nil {
260		return nil, err
261	}
262	if err := binary.Read(r, binary.BigEndian, &m.Type); err != nil {
263		return nil, err
264	}
265	m.Data = make([]byte, int(length)-2)
266	if _, err := io.ReadFull(r, m.Data); err != nil {
267		return nil, err
268	}
269	return m, nil
270}
271
272func NewMessage(t MsgType, data []byte) *Message {
273	return &Message{t, data}
274}
275
276func (m *Message) WriteTo(w io.Writer) (int, error) {
277	buf := bytes.NewBuffer(make([]byte, 0, len(m.Data)+2+4))
278	binary.Write(buf, binary.BigEndian, int32(len(m.Data)+2))
279	binary.Write(buf, binary.BigEndian, m.Type)
280	buf.Write(m.Data)
281	n, err := buf.WriteTo(w)
282	return int(n), err
283}
284
285// -----------------------------------------------------------------------------
286
287type RemoteCli struct {
288	conn        net.Conn
289	raw         RawMode
290	receiveChan chan struct{}
291	inited      int32
292	isTerminal  *bool
293
294	data  bytes.Buffer
295	dataM sync.Mutex
296}
297
298func NewRemoteCli(conn net.Conn) (*RemoteCli, error) {
299	r := &RemoteCli{
300		conn:        conn,
301		receiveChan: make(chan struct{}),
302	}
303	return r, nil
304}
305
306func (r *RemoteCli) MarkIsTerminal(is bool) {
307	r.isTerminal = &is
308}
309
310func (r *RemoteCli) init() error {
311	if !atomic.CompareAndSwapInt32(&r.inited, 0, 1) {
312		return nil
313	}
314
315	if err := r.reportIsTerminal(); err != nil {
316		return err
317	}
318
319	if err := r.reportWidth(); err != nil {
320		return err
321	}
322
323	// register sig for width changed
324	DefaultOnWidthChanged(func() {
325		r.reportWidth()
326	})
327	return nil
328}
329
330func (r *RemoteCli) writeMsg(m *Message) error {
331	r.dataM.Lock()
332	_, err := m.WriteTo(r.conn)
333	r.dataM.Unlock()
334	return err
335}
336
337func (r *RemoteCli) Write(b []byte) (int, error) {
338	m := NewMessage(T_DATA, b)
339	r.dataM.Lock()
340	_, err := m.WriteTo(r.conn)
341	r.dataM.Unlock()
342	return len(b), err
343}
344
345func (r *RemoteCli) reportWidth() error {
346	screenWidth := GetScreenWidth()
347	data := make([]byte, 2)
348	binary.BigEndian.PutUint16(data, uint16(screenWidth))
349	msg := NewMessage(T_WIDTH_REPORT, data)
350
351	if err := r.writeMsg(msg); err != nil {
352		return err
353	}
354	return nil
355}
356
357func (r *RemoteCli) reportIsTerminal() error {
358	var isTerminal bool
359	if r.isTerminal != nil {
360		isTerminal = *r.isTerminal
361	} else {
362		isTerminal = DefaultIsTerminal()
363	}
364	data := make([]byte, 2)
365	if isTerminal {
366		binary.BigEndian.PutUint16(data, 1)
367	} else {
368		binary.BigEndian.PutUint16(data, 0)
369	}
370	msg := NewMessage(T_ISTTY_REPORT, data)
371	if err := r.writeMsg(msg); err != nil {
372		return err
373	}
374	return nil
375}
376
377func (r *RemoteCli) readLoop() {
378	buf := bufio.NewReader(r.conn)
379	for {
380		msg, err := ReadMessage(buf)
381		if err != nil {
382			break
383		}
384		switch msg.Type {
385		case T_ERAW:
386			r.raw.Exit()
387		case T_RAW:
388			r.raw.Enter()
389		case T_DATA:
390			os.Stdout.Write(msg.Data)
391		}
392	}
393}
394
395func (r *RemoteCli) ServeBy(source io.Reader) error {
396	if err := r.init(); err != nil {
397		return err
398	}
399
400	go func() {
401		defer r.Close()
402		for {
403			n, _ := io.Copy(r, source)
404			if n == 0 {
405				break
406			}
407		}
408	}()
409	defer r.raw.Exit()
410	r.readLoop()
411	return nil
412}
413
414func (r *RemoteCli) Close() {
415	r.writeMsg(NewMessage(T_EOF, nil))
416}
417
418func (r *RemoteCli) Serve() error {
419	return r.ServeBy(os.Stdin)
420}
421
422func ListenRemote(n, addr string, cfg *Config, h func(*Instance), onListen ...func(net.Listener) error) error {
423	ln, err := net.Listen(n, addr)
424	if err != nil {
425		return err
426	}
427	if len(onListen) > 0 {
428		if err := onListen[0](ln); err != nil {
429			return err
430		}
431	}
432	for {
433		conn, err := ln.Accept()
434		if err != nil {
435			break
436		}
437		go func() {
438			defer conn.Close()
439			rl, err := HandleConn(*cfg, conn)
440			if err != nil {
441				return
442			}
443			h(rl)
444		}()
445	}
446	return nil
447}
448
449func HandleConn(cfg Config, conn net.Conn) (*Instance, error) {
450	r, err := NewRemoteSvr(conn)
451	if err != nil {
452		return nil, err
453	}
454	r.HandleConfig(&cfg)
455
456	rl, err := NewEx(&cfg)
457	if err != nil {
458		return nil, err
459	}
460	return rl, nil
461}
462
463func DialRemote(n, addr string) error {
464	conn, err := net.Dial(n, addr)
465	if err != nil {
466		return err
467	}
468	defer conn.Close()
469
470	cli, err := NewRemoteCli(conn)
471	if err != nil {
472		return err
473	}
474	return cli.Serve()
475}
476