1package middleware
2
3import (
4	"net/http"
5	"net/http/httptest"
6	"testing"
7
8	"github.com/go-chi/chi"
9)
10
11func TestXRealIP(t *testing.T) {
12	req, _ := http.NewRequest("GET", "/", nil)
13	req.Header.Add("X-Real-IP", "100.100.100.100")
14	w := httptest.NewRecorder()
15
16	r := chi.NewRouter()
17	r.Use(RealIP)
18
19	realIP := ""
20	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
21		realIP = r.RemoteAddr
22		w.Write([]byte("Hello World"))
23	})
24	r.ServeHTTP(w, req)
25
26	if w.Code != 200 {
27		t.Fatal("Response Code should be 200")
28	}
29
30	if realIP != "100.100.100.100" {
31		t.Fatal("Test get real IP error.")
32	}
33}
34
35func TestXForwardForIP(t *testing.T) {
36	req, _ := http.NewRequest("GET", "/", nil)
37	req.Header.Add("X-Forwarded-For", "100.100.100.100")
38	w := httptest.NewRecorder()
39
40	r := chi.NewRouter()
41	r.Use(RealIP)
42
43	realIP := ""
44	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
45		realIP = r.RemoteAddr
46		w.Write([]byte("Hello World"))
47	})
48	r.ServeHTTP(w, req)
49
50	if w.Code != 200 {
51		t.Fatal("Response Code should be 200")
52	}
53
54	if realIP != "100.100.100.100" {
55		t.Fatal("Test get real IP error.")
56	}
57}
58