1package dlna
2
3import (
4	"crypto/md5"
5	"encoding/xml"
6	"errors"
7	"fmt"
8	"io"
9	"log"
10	"net"
11	"net/http"
12	"net/http/httptest"
13	"net/http/httputil"
14	"os"
15	"regexp"
16	"strconv"
17	"strings"
18
19	"github.com/anacrolix/dms/soap"
20	"github.com/anacrolix/dms/upnp"
21	"github.com/rclone/rclone/fs"
22)
23
// makeDefaultFriendlyName builds the friendly name the server advertises by
// default: "rclone", with " (<hostname>)" appended whenever os.Hostname
// succeeds.
func makeDefaultFriendlyName() string {
	const base = "rclone"
	host, err := os.Hostname()
	if err != nil {
		// No hostname available: fall back to the bare product name.
		return base
	}
	return base + " (" + host + ")"
}
34
35func makeDeviceUUID(unique string) string {
36	h := md5.New()
37	if _, err := io.WriteString(h, unique); err != nil {
38		log.Panicf("makeDeviceUUID write failed: %s", err)
39	}
40	buf := h.Sum(nil)
41	return upnp.FormatUUID(buf)
42}
43
// listInterfaces returns the network interfaces that are up, support
// multicast and report a positive MTU. If the interfaces cannot be enumerated
// the error is logged and an empty, non-nil slice is returned.
func listInterfaces() []net.Interface {
	all, err := net.Interfaces()
	if err != nil {
		log.Printf("list network interfaces: %v", err)
		return []net.Interface{}
	}

	const wanted = net.FlagUp | net.FlagMulticast
	var usable []net.Interface
	for _, iface := range all {
		if iface.MTU <= 0 || iface.Flags&wanted != wanted {
			continue
		}
		usable = append(usable, iface)
	}
	return usable
}
60
// didlLite wraps chardata in a <DIDL-Lite> root element that declares the dc,
// upnp, default DIDL-Lite and dlna XML namespaces. chardata is inserted
// verbatim (it is not escaped), so callers must pass well-formed XML.
func didlLite(chardata string) string {
	const (
		openTag = `<DIDL-Lite` +
			` xmlns:dc="http://purl.org/dc/elements/1.1/"` +
			` xmlns:upnp="urn:schemas-upnp-org:metadata-1-0/upnp/"` +
			` xmlns="urn:schemas-upnp-org:metadata-1-0/DIDL-Lite/"` +
			` xmlns:dlna="urn:schemas-dlna-org:metadata-1-0/">`
		closeTag = `</DIDL-Lite>`
	)
	return openTag + chardata + closeTag
}
70
// mustMarshalXML encodes value as XML indented with two spaces per level.
// Any marshalling error is treated as a programming error and panics (via
// log.Panicf) instead of being returned.
func mustMarshalXML(value interface{}) []byte {
	out, err := xml.MarshalIndent(value, "", "  ")
	if err == nil {
		return out
	}
	log.Panicf("mustMarshalXML failed to marshal %v: %s", value, err)
	return nil // unreachable: log.Panicf always panics
}
78
79// Marshal SOAP response arguments into a response XML snippet.
80func marshalSOAPResponse(sa upnp.SoapAction, args map[string]string) []byte {
81	soapArgs := make([]soap.Arg, 0, len(args))
82	for argName, value := range args {
83		soapArgs = append(soapArgs, soap.Arg{
84			XMLName: xml.Name{Local: argName},
85			Value:   value,
86		})
87	}
88	return []byte(fmt.Sprintf(`<u:%[1]sResponse xmlns:u="%[2]s">%[3]s</u:%[1]sResponse>`,
89		sa.Action, sa.ServiceURN.String(), mustMarshalXML(soapArgs)))
90}
91
92var serviceURNRegexp = regexp.MustCompile(`:service:(\w+):(\d+)$`)
93
94func parseServiceType(s string) (ret upnp.ServiceURN, err error) {
95	matches := serviceURNRegexp.FindStringSubmatch(s)
96	if matches == nil {
97		err = errors.New(s)
98		return
99	}
100	if len(matches) != 3 {
101		log.Panicf("Invalid serviceURNRegexp ?")
102	}
103	ret.Type = matches[1]
104	ret.Version, err = strconv.ParseUint(matches[2], 0, 0)
105	return
106}
107
108func parseActionHTTPHeader(s string) (ret upnp.SoapAction, err error) {
109	if s[0] != '"' || s[len(s)-1] != '"' {
110		return
111	}
112	s = s[1 : len(s)-1]
113	hashIndex := strings.LastIndex(s, "#")
114	if hashIndex == -1 {
115		return
116	}
117	ret.Action = s[hashIndex+1:]
118	ret.ServiceURN, err = parseServiceType(s[:hashIndex])
119	return
120}
121
122type loggingResponseWriter struct {
123	http.ResponseWriter
124	request   *http.Request
125	committed bool
126}
127
128func (lrw *loggingResponseWriter) logRequest(code int, err interface{}) {
129	// Choose appropriate log level based on response status code.
130	var level fs.LogLevel
131	if code < 400 && err == nil {
132		level = fs.LogLevelInfo
133	} else {
134		level = fs.LogLevelError
135	}
136
137	if err == nil {
138		err = ""
139	}
140
141	fs.LogPrintf(level, lrw.request.URL, "%s %s %d %s %s",
142		lrw.request.RemoteAddr, lrw.request.Method, code,
143		lrw.request.Header.Get("SOAPACTION"), err)
144}
145
146func (lrw *loggingResponseWriter) WriteHeader(code int) {
147	lrw.committed = true
148	lrw.logRequest(code, nil)
149	lrw.ResponseWriter.WriteHeader(code)
150}
151
152// HTTP handler that logs requests and any errors or panics.
153func logging(next http.Handler) http.Handler {
154	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
155		lrw := &loggingResponseWriter{ResponseWriter: w, request: r}
156		defer func() {
157			err := recover()
158			if err != nil {
159				if !lrw.committed {
160					lrw.logRequest(http.StatusInternalServerError, err)
161					http.Error(w, fmt.Sprint(err), http.StatusInternalServerError)
162				} else {
163					// Too late to send the error to client, but at least log it.
164					fs.Errorf(r.URL.Path, "Recovered panic: %v", err)
165				}
166			}
167		}()
168		next.ServeHTTP(lrw, r)
169	})
170}
171
172// HTTP handler that logs complete request and response bodies for debugging.
173// Error recovery and general request logging are left to logging().
174func traceLogging(next http.Handler) http.Handler {
175	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
176		dump, err := httputil.DumpRequest(r, true)
177		if err != nil {
178			serveError(nil, w, "error dumping request", err)
179			return
180		}
181		fs.Debugf(nil, "%s", dump)
182
183		recorder := httptest.NewRecorder()
184		next.ServeHTTP(recorder, r)
185
186		dump, err = httputil.DumpResponse(recorder.Result(), true)
187		if err != nil {
188			// log the error but ignore it
189			fs.Errorf(nil, "error dumping response: %v", err)
190		} else {
191			fs.Debugf(nil, "%s", dump)
192		}
193
194		// copy from recorder to the real response writer
195		for k, v := range recorder.Header() {
196			w.Header()[k] = v
197		}
198		w.WriteHeader(recorder.Code)
199		_, err = recorder.Body.WriteTo(w)
200		if err != nil {
201			// Network error
202			fs.Debugf(nil, "Error writing response: %v", err)
203		}
204	})
205}
206
// withHeader wraps next so that the response header name is set to value
// before next runs. Because next runs afterwards, it can still override or
// delete the header.
func withHeader(name string, value string, next http.Handler) http.Handler {
	setThenServe := func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		hdr.Set(name, value)
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(setThenServe)
}
214
215// serveError returns an http.StatusInternalServerError and logs the error
216func serveError(what interface{}, w http.ResponseWriter, text string, err error) {
217	err = fs.CountError(err)
218	fs.Errorf(what, "%s: %v", text, err)
219	http.Error(w, text+".", http.StatusInternalServerError)
220}
221
// splitExt splits path into (root, ext) such that root+ext == path.
// ext is the suffix starting at the last '.' of the final slash-separated
// element (the '.' included), or "" if that element has no '.'.
// It is like path.Ext, but also returns the remaining root.
func splitExt(path string) (string, string) {
	// Whichever of '.' or '/' occurs last decides: a '/' means the final
	// element contains no dot.
	i := strings.LastIndexAny(path, "./")
	if i < 0 || path[i] != '.' {
		return path, ""
	}
	return path[:i], path[i:]
}
232