1// Copyright 2017 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package adapter
16
17import (
18	"context"
19
20	"google.golang.org/grpc"
21	"google.golang.org/grpc/metadata"
22)
23
24// chanServerStream implements grpc.ServerStream with a chanStream
25type chanServerStream struct {
26	headerc  chan<- metadata.MD
27	trailerc chan<- metadata.MD
28	grpc.Stream
29
30	headers []metadata.MD
31}
32
33func (ss *chanServerStream) SendHeader(md metadata.MD) error {
34	if ss.headerc == nil {
35		return errAlreadySentHeader
36	}
37	outmd := make(map[string][]string)
38	for _, h := range append(ss.headers, md) {
39		for k, v := range h {
40			outmd[k] = v
41		}
42	}
43	select {
44	case ss.headerc <- outmd:
45		ss.headerc = nil
46		ss.headers = nil
47		return nil
48	case <-ss.Context().Done():
49	}
50	return ss.Context().Err()
51}
52
53func (ss *chanServerStream) SetHeader(md metadata.MD) error {
54	if ss.headerc == nil {
55		return errAlreadySentHeader
56	}
57	ss.headers = append(ss.headers, md)
58	return nil
59}
60
61func (ss *chanServerStream) SetTrailer(md metadata.MD) {
62	ss.trailerc <- md
63}
64
65// chanClientStream implements grpc.ClientStream with a chanStream
66type chanClientStream struct {
67	headerc  <-chan metadata.MD
68	trailerc <-chan metadata.MD
69	*chanStream
70}
71
72func (cs *chanClientStream) Header() (metadata.MD, error) {
73	select {
74	case md := <-cs.headerc:
75		return md, nil
76	case <-cs.Context().Done():
77	}
78	return nil, cs.Context().Err()
79}
80
81func (cs *chanClientStream) Trailer() metadata.MD {
82	select {
83	case md := <-cs.trailerc:
84		return md
85	case <-cs.Context().Done():
86		return nil
87	}
88}
89
90func (cs *chanClientStream) CloseSend() error {
91	close(cs.chanStream.sendc)
92	return nil
93}
94
// chanStream implements grpc.Stream using channels. One value represents a
// single end of the stream: it receives on recvc, sends on sendc, and gives up
// waiting on either once ctx is done.
type chanStream struct {
	recvc  <-chan interface{} // inbound messages; read by RecvMsg
	sendc  chan<- interface{} // outbound messages; written by SendMsg
	ctx    context.Context    // returned by Context; SendMsg and RecvMsg stop waiting once it is done
	cancel context.CancelFunc // cancel function kept alongside ctx
}
102
103func (s *chanStream) Context() context.Context { return s.ctx }
104
105func (s *chanStream) SendMsg(m interface{}) error {
106	select {
107	case s.sendc <- m:
108		if err, ok := m.(error); ok {
109			return err
110		}
111		return nil
112	case <-s.ctx.Done():
113	}
114	return s.ctx.Err()
115}
116
117func (s *chanStream) RecvMsg(m interface{}) error {
118	v := m.(*interface{})
119	for {
120		select {
121		case msg, ok := <-s.recvc:
122			if !ok {
123				return grpc.ErrClientConnClosing
124			}
125			if err, ok := msg.(error); ok {
126				return err
127			}
128			*v = msg
129			return nil
130		case <-s.ctx.Done():
131		}
132		if len(s.recvc) == 0 {
133			// prioritize any pending recv messages over canceled context
134			break
135		}
136	}
137	return s.ctx.Err()
138}
139
140func newPipeStream(ctx context.Context, ssHandler func(chanServerStream) error) chanClientStream {
141	// ch1 is buffered so server can send error on close
142	ch1, ch2 := make(chan interface{}, 1), make(chan interface{})
143	headerc, trailerc := make(chan metadata.MD, 1), make(chan metadata.MD, 1)
144
145	cctx, ccancel := context.WithCancel(ctx)
146	cli := &chanStream{recvc: ch1, sendc: ch2, ctx: cctx, cancel: ccancel}
147	cs := chanClientStream{headerc, trailerc, cli}
148
149	sctx, scancel := context.WithCancel(ctx)
150	srv := &chanStream{recvc: ch2, sendc: ch1, ctx: sctx, cancel: scancel}
151	ss := chanServerStream{headerc, trailerc, srv, nil}
152
153	go func() {
154		if err := ssHandler(ss); err != nil {
155			select {
156			case srv.sendc <- err:
157			case <-sctx.Done():
158			case <-cctx.Done():
159			}
160		}
161		scancel()
162		ccancel()
163	}()
164	return cs
165}
166