package forwarding

import (
	"context"
	"fmt"
	"os"
	"sync"
	"time"

	"github.com/pkg/errors"

	"github.com/golang/protobuf/ptypes"

	"github.com/mutagen-io/mutagen/pkg/encoding"
	"github.com/mutagen-io/mutagen/pkg/logging"
	"github.com/mutagen-io/mutagen/pkg/mutagen"
	"github.com/mutagen-io/mutagen/pkg/prompting"
	"github.com/mutagen-io/mutagen/pkg/state"
	"github.com/mutagen-io/mutagen/pkg/url"
)

const (
	// autoReconnectInterval is the period of time to wait before attempting an
	// automatic reconnect after disconnection or a failed reconnect. The
	// forwarding loop also uses it as the minimum spacing between two
	// consecutive forwarding restarts when failures happen in quick
	// succession.
	autoReconnectInterval = 15 * time.Second
)

// controller manages and executes a single forwarding session. It owns the
// session's on-disk record, its observable state, and the lifecycle of at most
// one forwarding loop at a time.
type controller struct {
	// logger is the controller logger.
	logger *logging.Logger
	// sessionPath is the path to the serialized session.
	sessionPath string
	// stateLock guards and tracks changes to the session member's Paused field
	// and the state member.
	stateLock *state.TrackingLock
	// session encodes the associated session metadata. It is considered static
	// and safe for concurrent access except for its Paused field, for which the
	// stateLock member should be held. It should be saved to disk any time it
	// is modified.
	session *Session
	// mergedSourceConfiguration is the source-specific configuration object
	// (computed from the core configuration and source-specific overrides). It
	// is considered static and safe for concurrent access. It is a derived
	// field and not saved to disk.
	mergedSourceConfiguration *Configuration
	// mergedDestinationConfiguration is the destination-specific configuration
	// object (computed from the core configuration and destination-specific
	// overrides). It is considered static and safe for concurrent access. It is
	// a derived field and not saved to disk.
	mergedDestinationConfiguration *Configuration
	// state represents the current forwarding state. It should only be
	// accessed with stateLock held. Individual fields are updated in place
	// while connecting and forwarding, and the whole value is replaced with a
	// fresh State whenever a forwarding attempt fails or a forwarding loop
	// exits.
	state *State
	// lifecycleLock guards setting of the disabled, cancel, and done members.
	// Access to these members is allowed for the forwarding loop without
	// holding the lock. Any code wishing to set these members should first
	// acquire the lock, then cancel the forwarding loop, and wait for it to
	// complete before making any such changes.
	lifecycleLock sync.Mutex
	// disabled indicates that no more changes to the forwarding loop lifecycle
	// are allowed (i.e. no more forwarding loops can be started for this
	// controller). This is used by terminate and shutdown. It should only be
	// set to true once any existing forwarding loop has been stopped.
	disabled bool
	// cancel cancels the forwarding loop execution context. It should be nil if
	// and only if there is no forwarding loop running.
	cancel context.CancelFunc
	// done will be closed by the current forwarding loop when it exits.
	done chan struct{}
}

// newSession creates a new session and corresponding controller.
//
// The session record is saved to disk before the controller is returned.
//
// Unless paused is set, newSession first connects synchronously (using ctx and
// prompter) to each endpoint whose protocol is not url.Protocol_Tunnel. It
// returns an error if either of those connections fails. Otherwise it starts a
// forwarding loop that takes over the connected endpoints. Tunnel endpoints are
// left for the loop to connect asynchronously.
//
// Any endpoint connected here but not handed off to a forwarding loop (for
// example because a later step failed) is shut down before returning.
func newSession(
	ctx context.Context,
	logger *logging.Logger,
	tracker *state.Tracker,
	identifier string,
	source, destination *url.URL,
	configuration, configurationSource, configurationDestination *Configuration,
	name string,
	labels map[string]string,
	paused bool,
	prompter string,
) (*controller, error) {
	// Update status.
	prompting.Message(prompter, "Creating session...")

	// Set the session version.
	version := Version_Version1

	// Compute the creation time and convert it to Protocol Buffers format.
	creationTime := time.Now()
	creationTimeProto, err := ptypes.TimestampProto(creationTime)
	if err != nil {
		return nil, errors.Wrap(err, "unable to convert creation time format")
	}

	// Compute merged endpoint configurations.
	mergedSourceConfiguration := MergeConfigurations(configuration, configurationSource)
	mergedDestinationConfiguration := MergeConfigurations(configuration, configurationDestination)

	// If the session isn't being created paused, then try to connect to any
	// endpoints not using the tunnel protocol. The tunnel protocol is the one
	// case where we want to allow asynchronous connectivity (since it doesn't
	// require user input but also isn't guaranteed to connect immediately). If
	// we connect to endpoints here and don't hand them off to the runloop
	// below, then defer their shutdown. The deferred function only shuts down
	// endpoints that are still non-nil; the handoff below nils them out so
	// that ownership passes to the forwarding loop.
	var sourceEndpoint, destinationEndpoint Endpoint
	defer func() {
		if sourceEndpoint != nil {
			sourceEndpoint.Shutdown()
			sourceEndpoint = nil
		}
		if destinationEndpoint != nil {
			destinationEndpoint.Shutdown()
			destinationEndpoint = nil
		}
	}()
	if !paused && source.Protocol != url.Protocol_Tunnel {
		logger.Info("Connecting to source endpoint")
		sourceEndpoint, err = connect(
			ctx,
			logger.Sublogger("source"),
			source,
			prompter,
			identifier,
			version,
			mergedSourceConfiguration,
			true,
		)
		if err != nil {
			logger.Info("Source connection failure:", err)
			return nil, errors.Wrap(err, "unable to connect to source")
		}
	}
	if !paused && destination.Protocol != url.Protocol_Tunnel {
		logger.Info("Connecting to destination endpoint")
		destinationEndpoint, err = connect(
			ctx,
			logger.Sublogger("destination"),
			destination,
			prompter,
			identifier,
			version,
			mergedDestinationConfiguration,
			false,
		)
		if err != nil {
			logger.Info("Destination connection failure:", err)
			return nil, errors.Wrap(err, "unable to connect to destination")
		}
	}

	// Create the session.
	session := &Session{
		Identifier:               identifier,
		Version:                  version,
		CreationTime:             creationTimeProto,
		CreatingVersionMajor:     mutagen.VersionMajor,
		CreatingVersionMinor:     mutagen.VersionMinor,
		CreatingVersionPatch:     mutagen.VersionPatch,
		Source:                   source,
		Destination:              destination,
		Configuration:            configuration,
		ConfigurationSource:      configurationSource,
		ConfigurationDestination: configurationDestination,
		Name:                     name,
		Labels:                   labels,
		Paused:                   paused,
	}

	// Compute the session path.
	sessionPath, err := pathForSession(session.Identifier)
	if err != nil {
		return nil, errors.Wrap(err, "unable to compute session path")
	}

	// Save the session to disk.
	if err := encoding.MarshalAndSaveProtobuf(sessionPath, session); err != nil {
		return nil, errors.Wrap(err, "unable to save session")
	}

	// Create the controller.
	controller := &controller{
		logger:                         logger,
		sessionPath:                    sessionPath,
		stateLock:                      state.NewTrackingLock(tracker),
		session:                        session,
		mergedSourceConfiguration:      mergedSourceConfiguration,
		mergedDestinationConfiguration: mergedDestinationConfiguration,
		state: &State{
			Session: session,
		},
	}

	// If the session isn't being created pre-paused, then start a forwarding
	// loop and mark the endpoints as handed off to that loop so that we don't
	// defer their shutdown. The loop's context derives from
	// context.Background (not ctx) so that the loop outlives this call; it is
	// stopped only through controller.cancel.
	if !paused {
		logger.Info("Starting forwarding loop")
		ctx, cancel := context.WithCancel(context.Background())
		controller.cancel = cancel
		controller.done = make(chan struct{})
		go controller.run(ctx, sourceEndpoint, destinationEndpoint)
		sourceEndpoint = nil
		destinationEndpoint = nil
	}

	// Success.
	logger.Info("Session initialized")
	return controller, nil
}

// loadSession loads an existing session from disk, validates it, and creates a
// corresponding controller. If the loaded session is not marked as paused, a
// forwarding loop is started with no pre-connected endpoints, so all
// connection attempts happen asynchronously inside that loop.
func loadSession(logger *logging.Logger, tracker *state.Tracker, identifier string) (*controller, error) {
	// Compute the session path.
	sessionPath, err := pathForSession(identifier)
	if err != nil {
		return nil, errors.Wrap(err, "unable to compute session path")
	}

	// Load and validate the session.
	session := &Session{}
	if err := encoding.LoadAndUnmarshalProtobuf(sessionPath, session); err != nil {
		return nil, errors.Wrap(err, "unable to load session configuration")
	}
	if err := session.EnsureValid(); err != nil {
		return nil, errors.Wrap(err, "invalid session found on disk")
	}

	// Create the controller. The merged configurations are derived fields, so
	// they are recomputed from the loaded session rather than read from disk.
	controller := &controller{
		logger:      logger,
		sessionPath: sessionPath,
		stateLock:   state.NewTrackingLock(tracker),
		session:     session,
		mergedSourceConfiguration: MergeConfigurations(
			session.Configuration,
			session.ConfigurationSource,
		),
		mergedDestinationConfiguration: MergeConfigurations(
			session.Configuration,
			session.ConfigurationDestination,
		),
		state: &State{
			Session: session,
		},
	}

	// If the session isn't marked as paused, start a forwarding loop. It is
	// given nil endpoints, so it performs every connection attempt itself
	// (without a prompter). The loop's context derives from
	// context.Background so that it is stopped only through controller.cancel.
	if !session.Paused {
		ctx, cancel := context.WithCancel(context.Background())
		controller.cancel = cancel
		controller.done = make(chan struct{})
		go controller.run(ctx, nil, nil)
	}

	// Success.
	logger.Info("Session loaded")
	return controller, nil
}

// currentState creates a snapshot of the current session state. The returned
// value comes from State.Copy, so callers may read it without holding
// stateLock.
func (c *controller) currentState() *State {
	// Lock the session state and defer its release. It's very important that we
	// unlock without a notification here, otherwise we'd trigger an infinite
	// cycle of list/notify.
	c.stateLock.Lock()
	defer c.stateLock.UnlockWithoutNotify()

	// Perform a (pseudo) deep copy of the state.
	return c.state.Copy()
}

// resume attempts to reconnect and resume the session if it isn't currently
// connected and forwarding.
//
// If the controller is disabled, resume returns an error without doing
// anything. If a forwarding loop is already running and reports a forwarding
// status, resume returns nil without doing anything.
//
// Otherwise resume proceeds as follows:
//   - It stops any existing loop.
//   - It marks the session as unpaused and saves it to disk.
//   - It attempts a synchronous connection to each endpoint, using ctx and
//     prompter. Unlike newSession, no exception is made for tunnel endpoints.
//   - It always starts a new forwarding loop with whichever endpoints
//     connected.
//
// The first error in the order save, source connection, destination
// connection is returned; the loop is started regardless.
func (c *controller) resume(ctx context.Context, prompter string) error {
	// Update status.
	prompting.Message(prompter, fmt.Sprintf("Resuming session %s...", c.session.Identifier))

	// Lock the controller's lifecycle and defer its release.
	c.lifecycleLock.Lock()
	defer c.lifecycleLock.Unlock()

	// Don't allow any resume operations if the controller is disabled.
	if c.disabled {
		return errors.New("controller disabled")
	}

	// Check if there's an existing forwarding loop (i.e. if the session is
	// unpaused).
	if c.cancel != nil {
		// If there is an existing forwarding loop, check if it's already in a
		// state that's considered "forwarding". NOTE(review): this relies on
		// the Status enum placing all forwarding states at or above
		// Status_ForwardingConnections; confirm against its definition.
		c.stateLock.Lock()
		forwarding := c.state.Status >= Status_ForwardingConnections
		c.stateLock.UnlockWithoutNotify()

		// If we're already forwarding, then there's nothing we need to do. We
		// don't even need to mark the session as unpaused because it can't be
		// marked as paused if an existing forwarding loop is running (we
		// enforce this invariant as part of the controller's logic).
		if forwarding {
			return nil
		}

		// Otherwise, cancel the existing forwarding loop and wait for it to
		// finish.
		//
		// There's something of an efficiency race condition here, because the
		// existing loop might succeed in connecting between the time we check
		// and the time we cancel it. That could happen if an auto-reconnect
		// succeeds or even if the loop was already passed connections and it
		// just hasn't updated its status yet. But the only danger here is
		// basically wasting those connections, and the window is very small.
		c.cancel()
		<-c.done

		// Nil out any lifecycle state.
		c.cancel = nil
		c.done = nil
	}

	// Mark the session as unpaused and save it to disk. A save failure is
	// recorded but doesn't stop the resume; it is reported at the end.
	c.stateLock.Lock()
	c.session.Paused = false
	saveErr := encoding.MarshalAndSaveProtobuf(c.sessionPath, c.session)
	c.stateLock.Unlock()

	// Attempt to connect to source.
	c.stateLock.Lock()
	c.state.Status = Status_ConnectingSource
	c.stateLock.Unlock()
	source, sourceConnectErr := connect(
		ctx,
		c.logger.Sublogger("source"),
		c.session.Source,
		prompter,
		c.session.Identifier,
		c.session.Version,
		c.mergedSourceConfiguration,
		true,
	)
	c.stateLock.Lock()
	c.state.SourceConnected = (source != nil)
	c.stateLock.Unlock()

	// Attempt to connect to destination. This happens even if the source
	// connection failed.
	c.stateLock.Lock()
	c.state.Status = Status_ConnectingDestination
	c.stateLock.Unlock()
	destination, destinationConnectErr := connect(
		ctx,
		c.logger.Sublogger("destination"),
		c.session.Destination,
		prompter,
		c.session.Identifier,
		c.session.Version,
		c.mergedDestinationConfiguration,
		false,
	)
	c.stateLock.Lock()
	c.state.DestinationConnected = (destination != nil)
	c.stateLock.Unlock()

	// Start the forwarding loop with what we have. Source or destination may
	// have failed to connect (and be nil), but in any case that'll just make
	// the run loop keep trying to connect. Note that this reassigns ctx: the
	// loop's context derives from context.Background rather than the caller's
	// context so that the loop outlives this call.
	ctx, cancel := context.WithCancel(context.Background())
	c.cancel = cancel
	c.done = make(chan struct{})
	go c.run(ctx, source, destination)

	// Report any errors. Since we always want to start a forwarding loop, even
	// on partial or complete failure (since it might be able to auto-reconnect
	// on its own), we wait until the end to report errors.
	if saveErr != nil {
		return errors.Wrap(saveErr, "unable to save session")
	} else if sourceConnectErr != nil {
		return errors.Wrap(sourceConnectErr, "unable to connect to source")
	} else if destinationConnectErr != nil {
		return errors.Wrap(destinationConnectErr, "unable to connect to destination")
	}

	// Success.
	return nil
}

// controllerHaltMode represents the behavior to use when halting a session.
type controllerHaltMode uint8

const (
	// controllerHaltModePause indicates that a session should be halted and
	// marked as paused.
	controllerHaltModePause controllerHaltMode = iota
	// controllerHaltModeShutdown indicates that a session should be halted.
	controllerHaltModeShutdown
	// controllerHaltModeTerminate indicates that a session should be halted and
	// then deleted.
	controllerHaltModeTerminate
)

// description returns a human-readable description of a halt mode, for use in
// status messages. It panics on an unrecognized mode.
func (m controllerHaltMode) description() string {
	switch m {
	case controllerHaltModePause:
		return "Pausing"
	case controllerHaltModeShutdown:
		return "Shutting down"
	case controllerHaltModeTerminate:
		return "Terminating"
	default:
		panic("unhandled halt mode")
	}
}

// halt halts the session with the specified behavior.
//
// If the controller is already disabled, halt returns an error without doing
// anything. Otherwise any running forwarding loop is cancelled and waited for
// first. Then, depending on mode:
//   - controllerHaltModePause marks the session as paused and saves it to disk.
//   - controllerHaltModeShutdown disables the controller.
//   - controllerHaltModeTerminate disables the controller and removes the
//     session file from disk.
//
// The context argument is currently unused.
func (c *controller) halt(_ context.Context, mode controllerHaltMode, prompter string) error {
	// Update status. Note that description panics on an unrecognized mode, so
	// an invalid mode never reaches the final else branch below.
	prompting.Message(prompter, fmt.Sprintf("%s session %s...", mode.description(), c.session.Identifier))

	// Lock the controller's lifecycle and defer its release.
	c.lifecycleLock.Lock()
	defer c.lifecycleLock.Unlock()

	// Don't allow any additional halt operations if the controller is disabled,
	// because either this session is being terminated or the service is
	// shutting down, and in either case there is no point in halting.
	if c.disabled {
		return errors.New("controller disabled")
	}

	// Kill any existing forwarding loop.
	if c.cancel != nil {
		// Cancel the forwarding loop and wait for it to finish.
		c.cancel()
		<-c.done

		// Nil out any lifecycle state.
		c.cancel = nil
		c.done = nil
	}

	// Handle based on the halt mode.
	if mode == controllerHaltModePause {
		// Mark the session as paused and save it.
		c.stateLock.Lock()
		c.session.Paused = true
		saveErr := encoding.MarshalAndSaveProtobuf(c.sessionPath, c.session)
		c.stateLock.Unlock()
		if saveErr != nil {
			return errors.Wrap(saveErr, "unable to save session")
		}
	} else if mode == controllerHaltModeShutdown {
		// Disable the controller.
		c.disabled = true
	} else if mode == controllerHaltModeTerminate {
		// Disable the controller. It stays disabled even if the removal below
		// fails.
		c.disabled = true

		// Wipe the session information from disk.
		sessionRemoveErr := os.Remove(c.sessionPath)
		if sessionRemoveErr != nil {
			return errors.Wrap(sessionRemoveErr, "unable to remove session from disk")
		}
	} else {
		panic("invalid halt mode specified")
	}

	// Success.
	return nil
}

// run is the main runloop for the controller, managing connectivity and
// forwarding.
//
// It takes over the provided endpoints. Either may be nil, in which case run
// connects it itself without a prompter. run then loops until ctx is
// cancelled:
//  1. Connect both endpoints, retrying every autoReconnectInterval.
//  2. Forward until forwarding fails, either transport reports an error, or ctx
//     is cancelled.
//  3. Shut the endpoints down and record the failure in LastError.
//  4. Start over, waiting first if the previous failure was recent.
//
// On exit it shuts down any endpoints still held, resets c.state, and closes
// c.done.
func (c *controller) run(ctx context.Context, source, destination Endpoint) {
	// Defer resource and state cleanup.
	defer func() {
		// Shutdown any endpoints. These might be non-nil if the runloop was
		// cancelled while partially connected rather than after forwarding
		// failure.
		if source != nil {
			source.Shutdown()
		}
		if destination != nil {
			destination.Shutdown()
		}

		// Reset the state.
		c.stateLock.Lock()
		c.state = &State{
			Session: c.session,
		}
		c.stateLock.Unlock()

		// Signal completion.
		close(c.done)
	}()

	// Track the last time that forwarding failed. The zero value means the
	// first failure always qualifies for an immediate retry.
	var lastForwardingFailureTime time.Time

	// Loop until cancelled.
	for {
		// Loop until we're connected to both endpoints. We do a non-blocking
		// check for cancellation on each reconnect error so that we don't waste
		// resources by trying another connect when the context has been
		// cancelled (it'll be wasteful). This is better than sentinel errors.
		// An endpoint that is already connected is kept across retries; only
		// the missing one is reconnected.
		for {
			// Ensure that source is connected.
			var sourceConnectErr error
			if source == nil {
				c.stateLock.Lock()
				c.state.Status = Status_ConnectingSource
				c.stateLock.Unlock()
				source, sourceConnectErr = connect(
					ctx,
					c.logger.Sublogger("source"),
					c.session.Source,
					"",
					c.session.Identifier,
					c.session.Version,
					c.mergedSourceConfiguration,
					true,
				)
			}
			c.stateLock.Lock()
			c.state.SourceConnected = (source != nil)
			if sourceConnectErr != nil {
				c.state.LastError = errors.Wrap(sourceConnectErr, "unable to connect to source").Error()
			}
			c.stateLock.Unlock()

			// Check for cancellation to avoid a spurious connection to
			// destination in case cancellation occurred while connecting to
			// source.
			select {
			case <-ctx.Done():
				return
			default:
			}

			// Ensure that destination is connected.
			var destinationConnectErr error
			if destination == nil {
				c.stateLock.Lock()
				c.state.Status = Status_ConnectingDestination
				c.stateLock.Unlock()
				destination, destinationConnectErr = connect(
					ctx,
					c.logger.Sublogger("destination"),
					c.session.Destination,
					"",
					c.session.Identifier,
					c.session.Version,
					c.mergedDestinationConfiguration,
					false,
				)
			}
			c.stateLock.Lock()
			c.state.DestinationConnected = (destination != nil)
			if destinationConnectErr != nil {
				c.state.LastError = errors.Wrap(destinationConnectErr, "unable to connect to destination").Error()
			}
			c.stateLock.Unlock()

			// If both endpoints are connected, we're done. We perform this
			// check here (rather than in the loop condition) because if we did
			// it in the loop condition we'd still need a check here to avoid a
			// sleep every time (even if already successfully connected).
			if source != nil && destination != nil {
				break
			}

			// If we failed to connect, wait and then retry. Watch for
			// cancellation in the meantime.
			select {
			case <-ctx.Done():
				return
			case <-time.After(autoReconnectInterval):
			}
		}

		// Grab transport error channels for each endpoint.
		sourceTransportErrors := source.TransportErrors()
		destinationTransportErrors := destination.TransportErrors()

		// Create a cancellable subcontext that we can use to manage shutdown.
		shutdownCtx, forceShutdown := context.WithCancel(ctx)

		// Create a Goroutine that will shut down (and unblock) endpoints. This
		// is the only way to unblock forwarding on cancellation. It reads the
		// source and destination variables through its closure; that is safe
		// because they are only nil'd out below after shutdownComplete closes.
		shutdownComplete := make(chan struct{})
		go func() {
			<-shutdownCtx.Done()
			source.Shutdown()
			destination.Shutdown()
			close(shutdownComplete)
		}()

		// Perform forwarding in a background Goroutine and monitor for errors.
		// The channel is buffered so that the Goroutine can always deliver its
		// result and exit. Like the shutdown Goroutine, it reads source and
		// destination through its closure; they are only nil'd out below after
		// its result has been received.
		forwardingErrors := make(chan error, 1)
		go func() {
			forwardingErrors <- c.forward(source, destination)
		}()

		// Wait for cancellation, an error from forwarding, or an error from
		// either transport.
		var sessionErr error
		var forwardingErrorReceived bool
		select {
		case <-ctx.Done():
			sessionErr = errors.New("session cancelled")
		case sessionErr = <-forwardingErrors:
			forwardingErrorReceived = true
		case err := <-sourceTransportErrors:
			sessionErr = fmt.Errorf("source transport failure: %w", err)
		case err := <-destinationTransportErrors:
			sessionErr = fmt.Errorf("destination transport failure: %w", err)
		}

		// Force shutdown, which may have already occurred due to cancellation.
		forceShutdown()

		// Wait for shutdown to complete.
		<-shutdownComplete

		// If the forwarding loop wasn't what unblocked our wait, then wait for
		// it to return a result so that we know it has exited. This isn't
		// strictly necessary with our current design, but it's cleaner and more
		// robust.
		if !forwardingErrorReceived {
			<-forwardingErrors
		}

		// Nil out endpoints to update our state. Both have already been shut
		// down by the shutdown Goroutine, so the deferred cleanup must not
		// shut them down again.
		source = nil
		destination = nil

		// Reset the forwarding state, but propagate the error that caused
		// failure. sessionErr is non-nil here: every select branch above sets
		// a non-nil error, and forward only ever returns non-nil errors.
		c.stateLock.Lock()
		c.state = &State{
			Session:   c.session,
			LastError: sessionErr.Error(),
		}
		c.stateLock.Unlock()

		// When forwarding fails, we generally want to restart it as quickly as
		// possible. Thus, if it's been longer than our usual waiting period
		// since forwarding failed last, simply try to reconnect immediately
		// (though still check for cancellation). If it's been less than our
		// usual waiting period since forwarding failed last, then something is
		// probably wrong, so wait for our usual waiting period (while checking
		// and monitoring for cancellation).
		now := time.Now()
		if now.Sub(lastForwardingFailureTime) >= autoReconnectInterval {
			select {
			case <-ctx.Done():
				return
			default:
			}
		} else {
			select {
			case <-ctx.Done():
				return
			case <-time.After(autoReconnectInterval):
			}
		}
		lastForwardingFailureTime = now
	}
}

// forward is the main forwarding loop for the controller.
//
// It repeatedly accepts a connection via source.Open, opens a matching
// connection via destination.Open, and forwards between the pair in a
// background Goroutine using ForwardAndClose. It tracks open and total
// connection counts in the state as it goes. It returns only when one of the
// Open calls fails, so the returned error is always non-nil. Returning cancels
// the context passed to every forwarding Goroutine it spawned.
func (c *controller) forward(source, destination Endpoint) error {
	// Create a context that we can use to regulate the lifecycle of forwarding
	// Goroutines and defer its cancellation. NOTE(review): this presumes
	// ForwardAndClose stops forwarding once its context is cancelled; confirm
	// against its implementation.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Clear any error state and update the status to forwarding. While we're at
	// it, capture a pointer to the state instance that all forwarding
	// Goroutines spawned by this loop will update. This state instance will be
	// replaced once this loop returns, so those background Goroutines can
	// continue to safely update it without any risk of updating a future loop's
	// state object. The only penalty is that both state objects will share the
	// same lock, but that's a negligible overhead.
	var state *State
	c.stateLock.Lock()
	c.state.LastError = ""
	c.state.Status = Status_ForwardingConnections
	state = c.state
	c.stateLock.Unlock()

	// Accept and forward connections until there's an error.
	for {
		// Accept a connection from the source.
		incoming, err := source.Open()
		if err != nil {
			return errors.Wrap(err, "unable to accept connection")
		}

		// Open the outgoing connection to which we should forward. If that
		// fails, close the already-accepted incoming connection so it isn't
		// leaked.
		outgoing, err := destination.Open()
		if err != nil {
			incoming.Close()
			return errors.Wrap(err, "unable to open forwarding connection")
		}

		// Increment the open and total connection counts.
		c.stateLock.Lock()
		state.OpenConnections++
		state.TotalConnections++
		c.stateLock.Unlock()

		// Perform forwarding and update state in a background Goroutine.
		// incoming and outgoing are declared anew on each iteration, so each
		// Goroutine captures its own pair.
		go func() {
			// Perform forwarding.
			ForwardAndClose(ctx, incoming, outgoing)

			// Decrement open connection counts.
			c.stateLock.Lock()
			state.OpenConnections--
			c.stateLock.Unlock()
		}()
	}
}