1// Package gracenet provides a family of Listen functions that either open a 2// fresh connection or provide an inherited connection from when the process 3// was started. The behave like their counterparts in the net package, but 4// transparently provide support for graceful restarts without dropping 5// connections. This is provided in a systemd socket activation compatible form 6// to allow using socket activation. 7// 8// BUG: Doesn't handle closing of listeners. 9package gracenet 10 11import ( 12 "fmt" 13 "net" 14 "os" 15 "os/exec" 16 "strconv" 17 "strings" 18 "sync" 19) 20 21const ( 22 // Used to indicate a graceful restart in the new process. 23 envCountKey = "LISTEN_FDS" 24 envCountKeyPrefix = envCountKey + "=" 25) 26 27// In order to keep the working directory the same as when we started we record 28// it at startup. 29var originalWD, _ = os.Getwd() 30 31// Net provides the family of Listen functions and maintains the associated 32// state. Typically you will have only once instance of Net per application. 33type Net struct { 34 inherited []net.Listener 35 active []net.Listener 36 mutex sync.Mutex 37 inheritOnce sync.Once 38 39 // used in tests to override the default behavior of starting from fd 3. 40 fdStart int 41} 42 43func (n *Net) inherit() error { 44 var retErr error 45 n.inheritOnce.Do(func() { 46 n.mutex.Lock() 47 defer n.mutex.Unlock() 48 countStr := os.Getenv(envCountKey) 49 if countStr == "" { 50 return 51 } 52 count, err := strconv.Atoi(countStr) 53 if err != nil { 54 retErr = fmt.Errorf("found invalid count value: %s=%s", envCountKey, countStr) 55 return 56 } 57 58 // In tests this may be overridden. 59 fdStart := n.fdStart 60 if fdStart == 0 { 61 // In normal operations if we are inheriting, the listeners will begin at 62 // fd 3. 63 fdStart = 3 64 } 65 66 for i := fdStart; i < fdStart+count; i++ { 67 file := os.NewFile(uintptr(i), "listener") 68 l, err := net.FileListener(file) 69 if err != nil { 70 file.Close() 71 retErr = fmt.Errorf("error inheriting socket fd %d: %s", i, err) 72 return 73 } 74 if err := file.Close(); err != nil { 75 retErr = fmt.Errorf("error closing inherited socket fd %d: %s", i, err) 76 return 77 } 78 n.inherited = append(n.inherited, l) 79 } 80 }) 81 return retErr 82} 83 84// Listen announces on the local network address laddr. The network net must be 85// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket". It 86// returns an inherited net.Listener for the matching network and address, or 87// creates a new one using net.Listen. 88func (n *Net) Listen(nett, laddr string) (net.Listener, error) { 89 switch nett { 90 default: 91 return nil, net.UnknownNetworkError(nett) 92 case "tcp", "tcp4", "tcp6": 93 addr, err := net.ResolveTCPAddr(nett, laddr) 94 if err != nil { 95 return nil, err 96 } 97 return n.ListenTCP(nett, addr) 98 case "unix", "unixpacket", "invalid_unix_net_for_test": 99 addr, err := net.ResolveUnixAddr(nett, laddr) 100 if err != nil { 101 return nil, err 102 } 103 return n.ListenUnix(nett, addr) 104 } 105} 106 107// ListenTCP announces on the local network address laddr. The network net must 108// be: "tcp", "tcp4" or "tcp6". It returns an inherited net.Listener for the 109// matching network and address, or creates a new one using net.ListenTCP. 110func (n *Net) ListenTCP(nett string, laddr *net.TCPAddr) (*net.TCPListener, error) { 111 if err := n.inherit(); err != nil { 112 return nil, err 113 } 114 115 n.mutex.Lock() 116 defer n.mutex.Unlock() 117 118 // look for an inherited listener 119 for i, l := range n.inherited { 120 if l == nil { // we nil used inherited listeners 121 continue 122 } 123 if isSameAddr(l.Addr(), laddr) { 124 n.inherited[i] = nil 125 n.active = append(n.active, l) 126 return l.(*net.TCPListener), nil 127 } 128 } 129 130 // make a fresh listener 131 l, err := net.ListenTCP(nett, laddr) 132 if err != nil { 133 return nil, err 134 } 135 n.active = append(n.active, l) 136 return l, nil 137} 138 139// ListenUnix announces on the local network address laddr. The network net 140// must be a: "unix" or "unixpacket". It returns an inherited net.Listener for 141// the matching network and address, or creates a new one using net.ListenUnix. 142func (n *Net) ListenUnix(nett string, laddr *net.UnixAddr) (*net.UnixListener, error) { 143 if err := n.inherit(); err != nil { 144 return nil, err 145 } 146 147 n.mutex.Lock() 148 defer n.mutex.Unlock() 149 150 // look for an inherited listener 151 for i, l := range n.inherited { 152 if l == nil { // we nil used inherited listeners 153 continue 154 } 155 if isSameAddr(l.Addr(), laddr) { 156 n.inherited[i] = nil 157 n.active = append(n.active, l) 158 return l.(*net.UnixListener), nil 159 } 160 } 161 162 // make a fresh listener 163 l, err := net.ListenUnix(nett, laddr) 164 if err != nil { 165 return nil, err 166 } 167 n.active = append(n.active, l) 168 return l, nil 169} 170 171// activeListeners returns a snapshot copy of the active listeners. 172func (n *Net) activeListeners() ([]net.Listener, error) { 173 n.mutex.Lock() 174 defer n.mutex.Unlock() 175 ls := make([]net.Listener, len(n.active)) 176 copy(ls, n.active) 177 return ls, nil 178} 179 180func isSameAddr(a1, a2 net.Addr) bool { 181 if a1.Network() != a2.Network() { 182 return false 183 } 184 a1s := a1.String() 185 a2s := a2.String() 186 if a1s == a2s { 187 return true 188 } 189 190 // This allows for ipv6 vs ipv4 local addresses to compare as equal. This 191 // scenario is common when listening on localhost. 192 const ipv6prefix = "[::]" 193 a1s = strings.TrimPrefix(a1s, ipv6prefix) 194 a2s = strings.TrimPrefix(a2s, ipv6prefix) 195 const ipv4prefix = "0.0.0.0" 196 a1s = strings.TrimPrefix(a1s, ipv4prefix) 197 a2s = strings.TrimPrefix(a2s, ipv4prefix) 198 return a1s == a2s 199} 200 201// StartProcess starts a new process passing it the active listeners. It 202// doesn't fork, but starts a new process using the same environment and 203// arguments as when it was originally started. This allows for a newly 204// deployed binary to be started. It returns the pid of the newly started 205// process when successful. 206func (n *Net) StartProcess() (int, error) { 207 listeners, err := n.activeListeners() 208 if err != nil { 209 return 0, err 210 } 211 212 // Extract the fds from the listeners. 213 files := make([]*os.File, len(listeners)) 214 for i, l := range listeners { 215 files[i], err = l.(filer).File() 216 if err != nil { 217 return 0, err 218 } 219 defer files[i].Close() 220 } 221 222 // Use the original binary location. This works with symlinks such that if 223 // the file it points to has been changed we will use the updated symlink. 224 argv0, err := exec.LookPath(os.Args[0]) 225 if err != nil { 226 return 0, err 227 } 228 229 // Pass on the environment and replace the old count key with the new one. 230 var env []string 231 for _, v := range os.Environ() { 232 if !strings.HasPrefix(v, envCountKeyPrefix) { 233 env = append(env, v) 234 } 235 } 236 env = append(env, fmt.Sprintf("%s%d", envCountKeyPrefix, len(listeners))) 237 238 allFiles := append([]*os.File{os.Stdin, os.Stdout, os.Stderr}, files...) 239 process, err := os.StartProcess(argv0, os.Args, &os.ProcAttr{ 240 Dir: originalWD, 241 Env: env, 242 Files: allFiles, 243 }) 244 if err != nil { 245 return 0, err 246 } 247 return process.Pid, nil 248} 249 250type filer interface { 251 File() (*os.File, error) 252} 253