1// +build !windows
2// +build !wasm
3// +build !confonly
4
5package domainsocket
6
import (
	"context"
	gotls "crypto/tls"
	"os"
	"strings"
	"time"

	"golang.org/x/sys/unix"

	"github.com/v2fly/v2ray-core/v4/common"
	"github.com/v2fly/v2ray-core/v4/common/net"
	"github.com/v2fly/v2ray-core/v4/transport/internet"
	"github.com/v2fly/v2ray-core/v4/transport/internet/tls"
)
20
21type Listener struct {
22	addr      *net.UnixAddr
23	ln        net.Listener
24	tlsConfig *gotls.Config
25	config    *Config
26	addConn   internet.ConnHandler
27	locker    *fileLocker
28}
29
30func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
31	settings := streamSettings.ProtocolSettings.(*Config)
32	addr, err := settings.GetUnixAddr()
33	if err != nil {
34		return nil, err
35	}
36
37	unixListener, err := net.ListenUnix("unix", addr)
38	if err != nil {
39		return nil, newError("failed to listen domain socket").Base(err).AtWarning()
40	}
41
42	ln := &Listener{
43		addr:    addr,
44		ln:      unixListener,
45		config:  settings,
46		addConn: handler,
47	}
48
49	if !settings.Abstract {
50		ln.locker = &fileLocker{
51			path: settings.Path + ".lock",
52		}
53		if err := ln.locker.Acquire(); err != nil {
54			unixListener.Close()
55			return nil, err
56		}
57	}
58
59	if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
60		ln.tlsConfig = config.GetTLSConfig()
61	}
62
63	go ln.run()
64
65	return ln, nil
66}
67
68func (ln *Listener) Addr() net.Addr {
69	return ln.addr
70}
71
72func (ln *Listener) Close() error {
73	if ln.locker != nil {
74		ln.locker.Release()
75	}
76	return ln.ln.Close()
77}
78
79func (ln *Listener) run() {
80	for {
81		conn, err := ln.ln.Accept()
82		if err != nil {
83			if strings.Contains(err.Error(), "closed") {
84				break
85			}
86			newError("failed to accepted raw connections").Base(err).AtWarning().WriteToLog()
87			continue
88		}
89
90		if ln.tlsConfig != nil {
91			conn = tls.Server(conn, ln.tlsConfig)
92		}
93
94		ln.addConn(internet.Connection(conn))
95	}
96}
97
98type fileLocker struct {
99	path string
100	file *os.File
101}
102
103func (fl *fileLocker) Acquire() error {
104	f, err := os.Create(fl.path)
105	if err != nil {
106		return err
107	}
108	if err := unix.Flock(int(f.Fd()), unix.LOCK_EX); err != nil {
109		f.Close()
110		return newError("failed to lock file: ", fl.path).Base(err)
111	}
112	fl.file = f
113	return nil
114}
115
116func (fl *fileLocker) Release() {
117	if err := unix.Flock(int(fl.file.Fd()), unix.LOCK_UN); err != nil {
118		newError("failed to unlock file: ", fl.path).Base(err).WriteToLog()
119	}
120	if err := fl.file.Close(); err != nil {
121		newError("failed to close file: ", fl.path).Base(err).WriteToLog()
122	}
123	if err := os.Remove(fl.path); err != nil {
124		newError("failed to remove file: ", fl.path).Base(err).WriteToLog()
125	}
126}
127
128func init() {
129	common.Must(internet.RegisterTransportListener(protocolName, Listen))
130}
131