1package forwarding
2
3import (
4	"context"
5	"fmt"
6	"os"
7	"sync"
8	"time"
9
10	"github.com/pkg/errors"
11
12	"github.com/golang/protobuf/ptypes"
13
14	"github.com/mutagen-io/mutagen/pkg/encoding"
15	"github.com/mutagen-io/mutagen/pkg/logging"
16	"github.com/mutagen-io/mutagen/pkg/mutagen"
17	"github.com/mutagen-io/mutagen/pkg/prompting"
18	"github.com/mutagen-io/mutagen/pkg/state"
19	"github.com/mutagen-io/mutagen/pkg/url"
20)
21
// autoReconnectInterval is how long the forwarding loop waits before it retries
// after a disconnection or after a reconnection attempt that didn't succeed.
const autoReconnectInterval = 15 * time.Second
27
// controller manages and executes a single session. It ties together the
// session record that is persisted at sessionPath, the in-memory forwarding
// state reported to clients, and the lifecycle of the forwarding loop (see the
// run method) that performs the actual work.
type controller struct {
	// logger is the controller logger.
	logger *logging.Logger
	// sessionPath is the path to the serialized session.
	sessionPath string
	// stateLock guards and tracks changes to the session member's Paused field
	// and the state member.
	stateLock *state.TrackingLock
	// session encodes the associated session metadata. It is considered static
	// and safe for concurrent access except for its Paused field, for which the
	// stateLock member should be held. It should be saved to disk any time it
	// is modified.
	session *Session
	// mergedSourceConfiguration is the source-specific configuration object
	// (computed from the core configuration and source-specific overrides). It
	// is considered static and safe for concurrent access. It is a derived
	// field and not saved to disk.
	mergedSourceConfiguration *Configuration
	// mergedDestinationConfiguration is the destination-specific configuration
	// object (computed from the core configuration and destination-specific
	// overrides). It is considered static and safe for concurrent access. It is
	// a derived field and not saved to disk.
	mergedDestinationConfiguration *Configuration
	// state represents the current forwarding state. It must only be read or
	// modified while holding stateLock.
	state *State
	// lifecycleLock guards setting of the disabled, cancel, and done members.
	// Access to these members is allowed for the forwarding loop without
	// holding the lock. Any code wishing to set these members should first
	// acquire the lock, then cancel the forwarding loop, and wait for it to
	// complete before making any such changes.
	lifecycleLock sync.Mutex
	// disabled indicates that no more changes to the forwarding loop lifecycle
	// are allowed (i.e. no more forwarding loops can be started for this
	// controller). This is used by terminate and shutdown. It should only be
	// set to true once any existing forwarding loop has been stopped.
	disabled bool
	// cancel cancels the forwarding loop execution context. It should be nil if
	// and only if there is no forwarding loop running.
	cancel context.CancelFunc
	// done will be closed by the current forwarding loop when it exits.
	done chan struct{}
}
71
72// newSession creates a new session and corresponding controller.
73func newSession(
74	ctx context.Context,
75	logger *logging.Logger,
76	tracker *state.Tracker,
77	identifier string,
78	source, destination *url.URL,
79	configuration, configurationSource, configurationDestination *Configuration,
80	name string,
81	labels map[string]string,
82	paused bool,
83	prompter string,
84) (*controller, error) {
85	// Update status.
86	prompting.Message(prompter, "Creating session...")
87
88	// Set the session version.
89	version := Version_Version1
90
91	// Compute the creation time and convert it to Protocol Buffers format.
92	creationTime := time.Now()
93	creationTimeProto, err := ptypes.TimestampProto(creationTime)
94	if err != nil {
95		return nil, errors.Wrap(err, "unable to convert creation time format")
96	}
97
98	// Compute merged endpoint configurations.
99	mergedSourceConfiguration := MergeConfigurations(configuration, configurationSource)
100	mergedDestinationConfiguration := MergeConfigurations(configuration, configurationDestination)
101
102	// If the session isn't being created paused, then try to connect to any
103	// endpoints not using the tunnel protocol. The tunnel protocol is the one
104	// case where we want to allow asynchronous connectivity (since it doesn't
105	// require user input but also isn't guaranteed to connect immediately). If
106	// we connect to endpoints here and don't hand them off to the runloop
107	// below, then defer their shutdown.
108	var sourceEndpoint, destinationEndpoint Endpoint
109	defer func() {
110		if sourceEndpoint != nil {
111			sourceEndpoint.Shutdown()
112			sourceEndpoint = nil
113		}
114		if destinationEndpoint != nil {
115			destinationEndpoint.Shutdown()
116			destinationEndpoint = nil
117		}
118	}()
119	if !paused && source.Protocol != url.Protocol_Tunnel {
120		logger.Info("Connecting to source endpoint")
121		sourceEndpoint, err = connect(
122			ctx,
123			logger.Sublogger("source"),
124			source,
125			prompter,
126			identifier,
127			version,
128			mergedSourceConfiguration,
129			true,
130		)
131		if err != nil {
132			logger.Info("Source connection failure:", err)
133			return nil, errors.Wrap(err, "unable to connect to source")
134		}
135	}
136	if !paused && destination.Protocol != url.Protocol_Tunnel {
137		logger.Info("Connecting to destination endpoint")
138		destinationEndpoint, err = connect(
139			ctx,
140			logger.Sublogger("destination"),
141			destination,
142			prompter,
143			identifier,
144			version,
145			mergedDestinationConfiguration,
146			false,
147		)
148		if err != nil {
149			logger.Info("Destination connection failure:", err)
150			return nil, errors.Wrap(err, "unable to connect to destination")
151		}
152	}
153
154	// Create the session.
155	session := &Session{
156		Identifier:               identifier,
157		Version:                  version,
158		CreationTime:             creationTimeProto,
159		CreatingVersionMajor:     mutagen.VersionMajor,
160		CreatingVersionMinor:     mutagen.VersionMinor,
161		CreatingVersionPatch:     mutagen.VersionPatch,
162		Source:                   source,
163		Destination:              destination,
164		Configuration:            configuration,
165		ConfigurationSource:      configurationSource,
166		ConfigurationDestination: configurationDestination,
167		Name:                     name,
168		Labels:                   labels,
169		Paused:                   paused,
170	}
171
172	// Compute the session path.
173	sessionPath, err := pathForSession(session.Identifier)
174	if err != nil {
175		return nil, errors.Wrap(err, "unable to compute session path")
176	}
177
178	// Save the session to disk.
179	if err := encoding.MarshalAndSaveProtobuf(sessionPath, session); err != nil {
180		return nil, errors.Wrap(err, "unable to save session")
181	}
182
183	// Create the controller.
184	controller := &controller{
185		logger:                         logger,
186		sessionPath:                    sessionPath,
187		stateLock:                      state.NewTrackingLock(tracker),
188		session:                        session,
189		mergedSourceConfiguration:      mergedSourceConfiguration,
190		mergedDestinationConfiguration: mergedDestinationConfiguration,
191		state: &State{
192			Session: session,
193		},
194	}
195
196	// If the session isn't being created pre-paused, then start a forwarding
197	// loop and mark the endpoints as handed off to that loop so that we don't
198	// defer their shutdown.
199	if !paused {
200		logger.Info("Starting forwarding loop")
201		ctx, cancel := context.WithCancel(context.Background())
202		controller.cancel = cancel
203		controller.done = make(chan struct{})
204		go controller.run(ctx, sourceEndpoint, destinationEndpoint)
205		sourceEndpoint = nil
206		destinationEndpoint = nil
207	}
208
209	// Success.
210	logger.Info("Session initialized")
211	return controller, nil
212}
213
214// loadSession loads an existing session and creates a corresponding controller.
215func loadSession(logger *logging.Logger, tracker *state.Tracker, identifier string) (*controller, error) {
216	// Compute the session path.
217	sessionPath, err := pathForSession(identifier)
218	if err != nil {
219		return nil, errors.Wrap(err, "unable to compute session path")
220	}
221
222	// Load and validate the session.
223	session := &Session{}
224	if err := encoding.LoadAndUnmarshalProtobuf(sessionPath, session); err != nil {
225		return nil, errors.Wrap(err, "unable to load session configuration")
226	}
227	if err := session.EnsureValid(); err != nil {
228		return nil, errors.Wrap(err, "invalid session found on disk")
229	}
230
231	// Create the controller.
232	controller := &controller{
233		logger:      logger,
234		sessionPath: sessionPath,
235		stateLock:   state.NewTrackingLock(tracker),
236		session:     session,
237		mergedSourceConfiguration: MergeConfigurations(
238			session.Configuration,
239			session.ConfigurationSource,
240		),
241		mergedDestinationConfiguration: MergeConfigurations(
242			session.Configuration,
243			session.ConfigurationDestination,
244		),
245		state: &State{
246			Session: session,
247		},
248	}
249
250	// If the session isn't marked as paused, start a forwarding loop.
251	if !session.Paused {
252		ctx, cancel := context.WithCancel(context.Background())
253		controller.cancel = cancel
254		controller.done = make(chan struct{})
255		go controller.run(ctx, nil, nil)
256	}
257
258	// Success.
259	logger.Info("Session loaded")
260	return controller, nil
261}
262
263// currentState creates a snapshot of the current session state.
264func (c *controller) currentState() *State {
265	// Lock the session state and defer its release. It's very important that we
266	// unlock without a notification here, otherwise we'd trigger an infinite
267	// cycle of list/notify.
268	c.stateLock.Lock()
269	defer c.stateLock.UnlockWithoutNotify()
270
271	// Perform a (pseudo) deep copy of the state.
272	return c.state.Copy()
273}
274
275// resume attempts to reconnect and resume the session if it isn't currently
276// connected and forwarding.
277func (c *controller) resume(ctx context.Context, prompter string) error {
278	// Update status.
279	prompting.Message(prompter, fmt.Sprintf("Resuming session %s...", c.session.Identifier))
280
281	// Lock the controller's lifecycle and defer its release.
282	c.lifecycleLock.Lock()
283	defer c.lifecycleLock.Unlock()
284
285	// Don't allow any resume operations if the controller is disabled.
286	if c.disabled {
287		return errors.New("controller disabled")
288	}
289
290	// Check if there's an existing forwarding loop (i.e. if the session is
291	// unpaused).
292	if c.cancel != nil {
293		// If there is an existing forwarding loop, check if it's already in a
294		// state that's considered "forwarding".
295		c.stateLock.Lock()
296		forwarding := c.state.Status >= Status_ForwardingConnections
297		c.stateLock.UnlockWithoutNotify()
298
299		// If we're already forwarding, then there's nothing we need to do. We
300		// don't even need to mark the session as unpaused because it can't be
301		// marked as paused if an existing forwarding loop is running (we
302		// enforce this invariant as part of the controller's logic).
303		if forwarding {
304			return nil
305		}
306
307		// Otherwise, cancel the existing forwarding loop and wait for it to
308		// finish.
309		//
310		// There's something of an efficiency race condition here, because the
311		// existing loop might succeed in connecting between the time we check
312		// and the time we cancel it. That could happen if an auto-reconnect
313		// succeeds or even if the loop was already passed connections and it's
314		// just hasn't updated its status yet. But the only danger here is
315		// basically wasting those connections, and the window is very small.
316		c.cancel()
317		<-c.done
318
319		// Nil out any lifecycle state.
320		c.cancel = nil
321		c.done = nil
322	}
323
324	// Mark the session as unpaused and save it to disk.
325	c.stateLock.Lock()
326	c.session.Paused = false
327	saveErr := encoding.MarshalAndSaveProtobuf(c.sessionPath, c.session)
328	c.stateLock.Unlock()
329
330	// Attempt to connect to source.
331	c.stateLock.Lock()
332	c.state.Status = Status_ConnectingSource
333	c.stateLock.Unlock()
334	source, sourceConnectErr := connect(
335		ctx,
336		c.logger.Sublogger("source"),
337		c.session.Source,
338		prompter,
339		c.session.Identifier,
340		c.session.Version,
341		c.mergedSourceConfiguration,
342		true,
343	)
344	c.stateLock.Lock()
345	c.state.SourceConnected = (source != nil)
346	c.stateLock.Unlock()
347
348	// Attempt to connect to destination.
349	c.stateLock.Lock()
350	c.state.Status = Status_ConnectingDestination
351	c.stateLock.Unlock()
352	destination, destinationConnectErr := connect(
353		ctx,
354		c.logger.Sublogger("destination"),
355		c.session.Destination,
356		prompter,
357		c.session.Identifier,
358		c.session.Version,
359		c.mergedDestinationConfiguration,
360		false,
361	)
362	c.stateLock.Lock()
363	c.state.DestinationConnected = (destination != nil)
364	c.stateLock.Unlock()
365
366	// Start the forwarding loop with what we have. Source or destination may
367	// have failed to connect (and be nil), but in any case that'll just make
368	// the run loop keep trying to connect.
369	ctx, cancel := context.WithCancel(context.Background())
370	c.cancel = cancel
371	c.done = make(chan struct{})
372	go c.run(ctx, source, destination)
373
374	// Report any errors. Since we always want to start a forwarding loop, even
375	// on partial or complete failure (since it might be able to auto-reconnect
376	// on its own), we wait until the end to report errors.
377	if saveErr != nil {
378		return errors.Wrap(saveErr, "unable to save session")
379	} else if sourceConnectErr != nil {
380		return errors.Wrap(sourceConnectErr, "unable to connect to source")
381	} else if destinationConnectErr != nil {
382		return errors.Wrap(destinationConnectErr, "unable to connect to destination")
383	}
384
385	// Success.
386	return nil
387}
388
// controllerHaltMode selects what happens to a session when it is halted.
type controllerHaltMode uint8

const (
	// controllerHaltModePause halts the session and marks it as paused.
	controllerHaltModePause controllerHaltMode = iota
	// controllerHaltModeShutdown halts the session without changing how it is
	// stored.
	controllerHaltModeShutdown
	// controllerHaltModeTerminate halts the session and then deletes it.
	controllerHaltModeTerminate
)

// description returns a human-readable description of a halt mode, phrased as
// an action in progress. It panics if the mode isn't one of the defined halt
// modes.
func (m controllerHaltMode) description() string {
	// Look the mode up in a table indexed by the mode's value.
	descriptions := [...]string{
		controllerHaltModePause:     "Pausing",
		controllerHaltModeShutdown:  "Shutting down",
		controllerHaltModeTerminate: "Terminating",
	}
	if int(m) < len(descriptions) && descriptions[m] != "" {
		return descriptions[m]
	}

	// Anything without a table entry is a programming error.
	panic("unhandled halt mode")
}
416
417// halt halts the session with the specified behavior.
418func (c *controller) halt(_ context.Context, mode controllerHaltMode, prompter string) error {
419	// Update status.
420	prompting.Message(prompter, fmt.Sprintf("%s session %s...", mode.description(), c.session.Identifier))
421
422	// Lock the controller's lifecycle and defer its release.
423	c.lifecycleLock.Lock()
424	defer c.lifecycleLock.Unlock()
425
426	// Don't allow any additional halt operations if the controller is disabled,
427	// because either this session is being terminated or the service is
428	// shutting down, and in either case there is no point in halting.
429	if c.disabled {
430		return errors.New("controller disabled")
431	}
432
433	// Kill any existing forwarding loop.
434	if c.cancel != nil {
435		// Cancel the forwarding loop and wait for it to finish.
436		c.cancel()
437		<-c.done
438
439		// Nil out any lifecycle state.
440		c.cancel = nil
441		c.done = nil
442	}
443
444	// Handle based on the halt mode.
445	if mode == controllerHaltModePause {
446		// Mark the session as paused and save it.
447		c.stateLock.Lock()
448		c.session.Paused = true
449		saveErr := encoding.MarshalAndSaveProtobuf(c.sessionPath, c.session)
450		c.stateLock.Unlock()
451		if saveErr != nil {
452			return errors.Wrap(saveErr, "unable to save session")
453		}
454	} else if mode == controllerHaltModeShutdown {
455		// Disable the controller.
456		c.disabled = true
457	} else if mode == controllerHaltModeTerminate {
458		// Disable the controller.
459		c.disabled = true
460
461		// Wipe the session information from disk.
462		sessionRemoveErr := os.Remove(c.sessionPath)
463		if sessionRemoveErr != nil {
464			return errors.Wrap(sessionRemoveErr, "unable to remove session from disk")
465		}
466	} else {
467		panic("invalid halt mode specified")
468	}
469
470	// Success.
471	return nil
472}
473
// run is the main runloop for the controller, managing connectivity and
// forwarding. It is meant to be started in its own Goroutine and only returns
// once ctx has been cancelled; on return it shuts down any endpoints it still
// holds, resets the controller state, and closes c.done.
//
// The source and destination arguments are endpoints that have already been
// connected by the caller. Either may be nil, in which case run connects it
// itself (with an empty prompter identifier) and retries every
// autoReconnectInterval until it succeeds or ctx is cancelled. run takes
// ownership of any non-nil endpoint passed to it.
func (c *controller) run(ctx context.Context, source, destination Endpoint) {
	// Defer resource and state cleanup.
	defer func() {
		// Shutdown any endpoints. These might be non-nil if the runloop was
		// cancelled while partially connected rather than after forwarding
		// failure.
		if source != nil {
			source.Shutdown()
		}
		if destination != nil {
			destination.Shutdown()
		}

		// Reset the state.
		c.stateLock.Lock()
		c.state = &State{
			Session: c.session,
		}
		c.stateLock.Unlock()

		// Signal completion.
		close(c.done)
	}()

	// Track the last time that forwarding failed. The zero value means that
	// forwarding hasn't failed yet in this runloop.
	var lastForwardingFailureTime time.Time

	// Loop until cancelled. Each iteration connects both endpoints, forwards
	// until something fails, and then tears the endpoints down again.
	for {
		// Loop until we're connected to both endpoints. We check for
		// cancellation after each connection attempt so that we don't waste
		// resources by trying another connect when the context has already
		// been cancelled. This is better than sentinel errors.
		for {
			// Ensure that source is connected.
			var sourceConnectErr error
			if source == nil {
				c.stateLock.Lock()
				c.state.Status = Status_ConnectingSource
				c.stateLock.Unlock()
				source, sourceConnectErr = connect(
					ctx,
					c.logger.Sublogger("source"),
					c.session.Source,
					"",
					c.session.Identifier,
					c.session.Version,
					c.mergedSourceConfiguration,
					true,
				)
			}
			c.stateLock.Lock()
			c.state.SourceConnected = (source != nil)
			if sourceConnectErr != nil {
				c.state.LastError = errors.Wrap(sourceConnectErr, "unable to connect to source").Error()
			}
			c.stateLock.Unlock()

			// Check for cancellation to avoid a spurious connection to
			// destination in case cancellation occurred while connecting to
			// source.
			select {
			case <-ctx.Done():
				return
			default:
			}

			// Ensure that destination is connected.
			var destinationConnectErr error
			if destination == nil {
				c.stateLock.Lock()
				c.state.Status = Status_ConnectingDestination
				c.stateLock.Unlock()
				destination, destinationConnectErr = connect(
					ctx,
					c.logger.Sublogger("destination"),
					c.session.Destination,
					"",
					c.session.Identifier,
					c.session.Version,
					c.mergedDestinationConfiguration,
					false,
				)
			}
			c.stateLock.Lock()
			c.state.DestinationConnected = (destination != nil)
			if destinationConnectErr != nil {
				c.state.LastError = errors.Wrap(destinationConnectErr, "unable to connect to destination").Error()
			}
			c.stateLock.Unlock()

			// If both endpoints are connected, we're done. We perform this
			// check here (rather than in the loop condition) because if we did
			// it in the loop condition we'd still need a check here to avoid a
			// sleep every time (even if already successfully connected).
			if source != nil && destination != nil {
				break
			}

			// If we failed to connect, wait and then retry. Watch for
			// cancellation in the mean time.
			select {
			case <-ctx.Done():
				return
			case <-time.After(autoReconnectInterval):
			}
		}

		// Grab transport error channels for each endpoint.
		sourceTransportErrors := source.TransportErrors()
		destinationTransportErrors := destination.TransportErrors()

		// Create a cancellable subcontext that we can use to manage shutdown.
		// It is cancelled either by cancellation of ctx or by the explicit
		// forceShutdown call below.
		shutdownCtx, forceShutdown := context.WithCancel(ctx)

		// Create a Goroutine that will shut down (and unblock) endpoints. This
		// is the only way to unblock forwarding on cancellation.
		shutdownComplete := make(chan struct{})
		go func() {
			<-shutdownCtx.Done()
			source.Shutdown()
			destination.Shutdown()
			close(shutdownComplete)
		}()

		// Perform forwarding in a background Goroutine and monitor for errors.
		// The channel is buffered so that the Goroutine can always deliver its
		// result and exit.
		forwardingErrors := make(chan error, 1)
		go func() {
			forwardingErrors <- c.forward(source, destination)
		}()

		// Wait for cancellation, an error from forwarding, or an error from
		// either transport.
		var sessionErr error
		var forwardingErrorReceived bool
		select {
		case <-ctx.Done():
			sessionErr = errors.New("session cancelled")
		case sessionErr = <-forwardingErrors:
			forwardingErrorReceived = true
		case err := <-sourceTransportErrors:
			sessionErr = fmt.Errorf("source transport failure: %w", err)
		case err := <-destinationTransportErrors:
			sessionErr = fmt.Errorf("destination transport failure: %w", err)
		}

		// Force shutdown, which may have already occurred due to cancellation.
		forceShutdown()

		// Wait for shutdown to complete.
		<-shutdownComplete

		// If the forwarding loop wasn't what unblocked our wait, then wait for
		// it to return a result so that we know it has exited. This isn't
		// strictly necessary with our current design, but it's cleaner and more
		// robust.
		if !forwardingErrorReceived {
			<-forwardingErrors
		}

		// Nil out endpoints to update our state. They have been shut down
		// above, so this also keeps the deferred cleanup from shutting them
		// down a second time.
		source = nil
		destination = nil

		// Reset the forwarding state, but propagate the error that caused
		// failure.
		c.stateLock.Lock()
		c.state = &State{
			Session:   c.session,
			LastError: sessionErr.Error(),
		}
		c.stateLock.Unlock()

		// When forwarding fails, we generally want to restart it as quickly as
		// possible. Thus, if it's been longer than our usual waiting period
		// since forwarding failed last, simply try to reconnect immediately
		// (though still check for cancellation). If it's been less than our
		// usual waiting period since forwarding failed last, then something is
		// probably wrong, so wait for our usual waiting period (while checking
		// and monitoring for cancellation).
		now := time.Now()
		if now.Sub(lastForwardingFailureTime) >= autoReconnectInterval {
			select {
			case <-ctx.Done():
				return
			default:
			}
		} else {
			select {
			case <-ctx.Done():
				return
			case <-time.After(autoReconnectInterval):
			}
		}
		lastForwardingFailureTime = now
	}
}
673
674// forward is the main forwarding loop for the controller.
675func (c *controller) forward(source, destination Endpoint) error {
676	// Create a context that we can use to regulate the lifecycle of forwarding
677	// Goroutines and defer its cancellation.
678	ctx, cancel := context.WithCancel(context.Background())
679	defer cancel()
680
681	// Clear any error state and update the status to forwarding. While we're at
682	// it, capture a pointer to the state instance that all forwarding
683	// Goroutines spawned by this loop will update. This state instance will be
684	// replaced once this loop returns, so those background Goroutines can
685	// continue to safely update it without any risk of updating a future loop's
686	// state object. The only penalty is that both state objects will share the
687	// same lock, but that's a negligible overhead.
688	var state *State
689	c.stateLock.Lock()
690	c.state.LastError = ""
691	c.state.Status = Status_ForwardingConnections
692	state = c.state
693	c.stateLock.Unlock()
694
695	// Accept and forward connections until there's an error.
696	for {
697		// Accept a connection from the source.
698		incoming, err := source.Open()
699		if err != nil {
700			return errors.Wrap(err, "unable to accept connection")
701		}
702
703		// Open the outgoing connection to which we should forward.
704		outgoing, err := destination.Open()
705		if err != nil {
706			incoming.Close()
707			return errors.Wrap(err, "unable to open forwarding connection")
708		}
709
710		// Increment the open and total connection counts.
711		c.stateLock.Lock()
712		state.OpenConnections++
713		state.TotalConnections++
714		c.stateLock.Unlock()
715
716		// Perform forwarding and update state in a background Goroutine.
717		go func() {
718			// Perform forwarding.
719			ForwardAndClose(ctx, incoming, outgoing)
720
721			// Decrement open connection counts.
722			c.stateLock.Lock()
723			state.OpenConnections--
724			c.stateLock.Unlock()
725		}()
726	}
727}
728