1package core
2
3import (
4	"crypto/sha1"
5	"encoding/json"
6	"fmt"
7	"io"
8	"io/ioutil"
9	"net/url"
10	"os"
11	"path"
12	"path/filepath"
13	"runtime"
14	"strconv"
15	"strings"
16	"sync"
17	"sync/atomic"
18	"time"
19
20	"github.com/asaskevich/EventBus"
21	"github.com/remeh/sizedwaitgroup"
22)
23
// Stats holds the start/finish timestamps of a session together with its
// running counters. The counters are plain uint32 fields that the
// Increment* methods update through sync/atomic.
type Stats struct {
	// Wall-clock bounds; Duration reports FinishedAt minus StartedAt.
	StartedAt  time.Time `json:"startedAt"`
	FinishedAt time.Time `json:"finishedAt"`

	// Port counters.
	PortOpen   uint32 `json:"portOpen"`
	PortClosed uint32 `json:"portClosed"`

	// Request counters.
	RequestSuccessful uint32 `json:"requestSuccessful"`
	RequestFailed     uint32 `json:"requestFailed"`

	// Response code counters, one per status class.
	ResponseCode2xx uint32 `json:"responseCode2xx"`
	ResponseCode3xx uint32 `json:"responseCode3xx"`
	ResponseCode4xx uint32 `json:"responseCode4xx"`
	ResponseCode5xx uint32 `json:"responseCode5xx"`

	// Screenshot counters.
	ScreenshotSuccessful uint32 `json:"screenshotSuccessful"`
	ScreenshotFailed     uint32 `json:"screenshotFailed"`
}
38
39func (s *Stats) Duration() time.Duration {
40	return s.FinishedAt.Sub(s.StartedAt)
41}
42
43func (s *Stats) IncrementPortOpen() {
44	atomic.AddUint32(&s.PortOpen, 1)
45}
46
47func (s *Stats) IncrementPortClosed() {
48	atomic.AddUint32(&s.PortClosed, 1)
49}
50
51func (s *Stats) IncrementRequestSuccessful() {
52	atomic.AddUint32(&s.RequestSuccessful, 1)
53}
54
55func (s *Stats) IncrementRequestFailed() {
56	atomic.AddUint32(&s.RequestFailed, 1)
57}
58
59func (s *Stats) IncrementResponseCode2xx() {
60	atomic.AddUint32(&s.ResponseCode2xx, 1)
61}
62
63func (s *Stats) IncrementResponseCode3xx() {
64	atomic.AddUint32(&s.ResponseCode3xx, 1)
65}
66
67func (s *Stats) IncrementResponseCode4xx() {
68	atomic.AddUint32(&s.ResponseCode4xx, 1)
69}
70
71func (s *Stats) IncrementResponseCode5xx() {
72	atomic.AddUint32(&s.ResponseCode5xx, 1)
73}
74
75func (s *Stats) IncrementScreenshotSuccessful() {
76	atomic.AddUint32(&s.ScreenshotSuccessful, 1)
77}
78
79func (s *Stats) IncrementScreenshotFailed() {
80	atomic.AddUint32(&s.ScreenshotFailed, 1)
81}
82
83type Session struct {
84	sync.Mutex
85	Version                string                        `json:"version"`
86	Options                Options                       `json:"-"`
87	Out                    *Logger                       `json:"-"`
88	Stats                  *Stats                        `json:"stats"`
89	Pages                  map[string]*Page              `json:"pages"`
90	PageSimilarityClusters map[string][]string           `json:"pageSimilarityClusters"`
91	Ports                  []int                         `json:"-"`
92	EventBus               EventBus.Bus                  `json:"-"`
93	WaitGroup              sizedwaitgroup.SizedWaitGroup `json:"-"`
94}
95
96func (s *Session) Start() {
97	s.Pages = make(map[string]*Page)
98	s.PageSimilarityClusters = make(map[string][]string)
99	s.initStats()
100	s.initLogger()
101	s.initPorts()
102	s.initThreads()
103	s.initEventBus()
104	s.initWaitGroup()
105	s.initDirectories()
106}
107
108func (s *Session) End() {
109	s.Stats.FinishedAt = time.Now()
110}
111
112func (s *Session) AddPage(url string) (*Page, error) {
113	s.Lock()
114	defer s.Unlock()
115	if page, ok := s.Pages[url]; ok {
116		return page, nil
117	}
118
119	page, err := NewPage(url)
120	if err != nil {
121		return nil, err
122	}
123
124	s.Pages[url] = page
125	return page, nil
126}
127
128func (s *Session) GetPage(url string) *Page {
129	if page, ok := s.Pages[url]; ok {
130		return page
131	}
132	return nil
133}
134
135func (s *Session) GetPageByUUID(id string) *Page {
136	for _, page := range s.Pages {
137		if page.UUID == id {
138			return page
139		}
140	}
141	return nil
142}
143
144func (s *Session) initStats() {
145	if s.Stats != nil {
146		return
147	}
148	s.Stats = &Stats{
149		StartedAt: time.Now(),
150	}
151}
152
153func (s *Session) initPorts() {
154	var ports []int
155	switch *s.Options.Ports {
156	case "small":
157		ports = SmallPortList
158	case "", "medium", "default":
159		ports = MediumPortList
160	case "large":
161		ports = LargePortList
162	case "xlarge", "huge":
163		ports = XLargePortList
164	default:
165		for _, p := range strings.Split(*s.Options.Ports, ",") {
166			port, err := strconv.Atoi(strings.TrimSpace(p))
167			if err != nil {
168				s.Out.Fatal("Invalid port range given\n")
169				os.Exit(1)
170			}
171			if port < 1 || port > 65535 {
172				s.Out.Fatal("Invalid port given: %v\n", port)
173				os.Exit(1)
174			}
175			ports = append(ports, port)
176		}
177	}
178	s.Ports = ports
179}
180
181func (s *Session) initLogger() {
182	s.Out = &Logger{}
183	s.Out.SetDebug(*s.Options.Debug)
184	s.Out.SetSilent(*s.Options.Silent)
185}
186
187func (s *Session) initThreads() {
188	if *s.Options.Threads == 0 {
189		numCPUs := runtime.NumCPU()
190		s.Options.Threads = &numCPUs
191	}
192}
193
194func (s *Session) initEventBus() {
195	s.EventBus = EventBus.New()
196}
197
198func (s *Session) initWaitGroup() {
199	s.WaitGroup = sizedwaitgroup.New(*s.Options.Threads)
200}
201
202func (s *Session) initDirectories() {
203	for _, d := range []string{"headers", "html", "screenshots"} {
204		d = s.GetFilePath(d)
205		if _, err := os.Stat(d); os.IsNotExist(err) {
206			err = os.MkdirAll(d, 0755)
207			if err != nil {
208				s.Out.Fatal("Failed to create required directory %s\n", d)
209				os.Exit(1)
210			}
211		}
212	}
213}
214
215func (s *Session) BaseFilenameFromURL(stru string) string {
216	u, err := url.Parse(stru)
217	if err != nil {
218		return ""
219	}
220
221	h := sha1.New()
222	io.WriteString(h, u.Path)
223	io.WriteString(h, u.Fragment)
224
225	pathHash := fmt.Sprintf("%x", h.Sum(nil))[0:16]
226	host := strings.Replace(u.Host, ":", "__", 1)
227	filename := fmt.Sprintf("%s__%s__%s", u.Scheme, strings.Replace(host, ".", "_", -1), pathHash)
228	return strings.ToLower(filename)
229}
230
231func (s *Session) GetFilePath(p string) string {
232	return path.Join(*s.Options.OutDir, p)
233}
234
235func (s *Session) ReadFile(p string) ([]byte, error) {
236	content, err := ioutil.ReadFile(s.GetFilePath(p))
237	if err != nil {
238		return content, err
239	}
240	return content, nil
241}
242
243func (s *Session) ToJSON() string {
244	sessionJSON, _ := json.Marshal(s)
245	return string(sessionJSON)
246}
247
248func (s *Session) SaveToFile(filename string) error {
249	path := s.GetFilePath(filename)
250	err := ioutil.WriteFile(path, []byte(s.ToJSON()), 0644)
251	if err != nil {
252		return err
253	}
254
255	return nil
256}
257
258func (s *Session) Asset(name string) ([]byte, error) {
259	return Asset(name)
260}
261
262func NewSession() (*Session, error) {
263	var err error
264	var session Session
265
266	session.Version = Version
267
268	if session.Options, err = ParseOptions(); err != nil {
269		return nil, err
270	}
271
272	if *session.Options.ChromePath != "" {
273		if _, err := os.Stat(*session.Options.ChromePath); os.IsNotExist(err) {
274			return nil, fmt.Errorf("Chrome path %s does not exist", *session.Options.ChromePath)
275		}
276	}
277
278	if *session.Options.SessionPath != "" {
279		if _, err := os.Stat(*session.Options.SessionPath); os.IsNotExist(err) {
280			return nil, fmt.Errorf("Session path %s does not exist", *session.Options.SessionPath)
281		}
282	}
283
284	if *session.Options.TemplatePath != "" {
285		if _, err := os.Stat(*session.Options.TemplatePath); os.IsNotExist(err) {
286			return nil, fmt.Errorf("Template path %s does not exist", *session.Options.TemplatePath)
287		}
288	}
289
290	envOutPath := os.Getenv("AQUATONE_OUT_PATH")
291	if *session.Options.OutDir == "." && envOutPath != "" {
292		session.Options.OutDir = &envOutPath
293	}
294
295	outdir := filepath.Clean(*session.Options.OutDir)
296	session.Options.OutDir = &outdir
297
298	session.Version = Version
299	session.Start()
300
301	return &session, nil
302}
303