1# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: t -*-
2# vi: set ft=python sts=4 ts=4 sw=4 noet :
3
4# This file is part of Fail2Ban.
5#
6# Fail2Ban is free software; you can redistribute it and/or modify
7# it under the terms of the GNU General Public License as published by
8# the Free Software Foundation; either version 2 of the License, or
9# (at your option) any later version.
10#
11# Fail2Ban is distributed in the hope that it will be useful,
12# but WITHOUT ANY WARRANTY; without even the implied warranty of
13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14# GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License
17# along with Fail2Ban; if not, write to the Free Software
18# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
19
20__author__ = "Cyril Jaquier, Arturo 'Buanzo' Busleiman, Yaroslav Halchenko"
21__license__ = "GPL"
22
23import gc
24import locale
25import logging
26import os
27import re
28import sys
29import traceback
30
31from threading import Lock
32
33from .server.mytime import MyTime
34import importlib
35
# Optional libcap handle, used by prctl_set_th_name to set the OS-level thread
# name; stays None if ctypes or 'libcap.so.2' cannot be loaded:
try:
	import ctypes
	_libcap = ctypes.CDLL('libcap.so.2')
except:
	_libcap = None
41
42
PREFER_ENC = locale.getpreferredencoding()
# A preferred encoding starting with 'ANSI_' (e.g. 'ANSI_X3.4-1968', i.e. plain
# ASCII) usually means that no language is configured in the environment;
# try to find a better one:
if PREFER_ENC.startswith('ANSI_'): # pragma: no cover
	if (sys.stdout and sys.stdout.encoding is not None
			and not sys.stdout.encoding.startswith('ANSI_')):
		# the output stream reports a more specific encoding - use it:
		PREFER_ENC = sys.stdout.encoding
	elif not any(os.getenv(envVar) not in (None, "")
			for envVar in ('LANGUAGE', 'LC_ALL', 'LC_CTYPE', 'LANG')):
		# none of the locale variables is set - assume UTF-8:
		PREFER_ENC = 'UTF-8'
50
51# py-2.x: try to minimize influence of sporadic conversion errors on python 2.x,
52# caused by implicit converting of string/unicode (e. g. `str(u"\uFFFD")` produces an error
53# if default encoding is 'ascii');
if sys.version_info < (3,): # pragma: 3.x no cover
	# correct default (global system) encoding (mostly UTF-8):
	def __resetDefaultEncoding(encoding):
		"""Set the py2 default encoding to `encoding` if it is currently ASCII
		(and differs from PREFER_ENC).
		"""
		global PREFER_ENC
		ode = sys.getdefaultencoding().upper()
		if ode == 'ASCII' and ode != PREFER_ENC.upper():
			# setdefaultencoding is normally deleted after site initialized, so hack-in using load of sys-module:
			_sys = sys
			if not hasattr(_sys, "setdefaultencoding"):
				try:
					from imp import load_dynamic as __ldm
					_sys = __ldm('_sys', 'sys')
				except ImportError: # pragma: no cover - only if load_dynamic fails
					# py2 only: `reload` is a builtin there; python 2's `importlib`
					# has no `reload`, so `importlib.reload` would raise AttributeError.
					reload(sys) # noqa: F821 - py2 builtin
					_sys = sys
			if hasattr(_sys, "setdefaultencoding"):
				_sys.setdefaultencoding(encoding)
	# override to PREFER_ENC:
	__resetDefaultEncoding(PREFER_ENC)
	del __resetDefaultEncoding
74
75# todo: rewrite explicit (and implicit) str-conversions via encode/decode with IO-encoding (sys.stdout.encoding),
76# e. g. inside tags-replacement by command-actions, etc.
77
78#
# The following "uni_decode" and "uni_string" functions provide python-version
# independent conversion of any value to a string.
#
# Typical example illustrating the encoding/decoding issues:
#
#   [isinstance('', str), isinstance(b'', str), isinstance(u'', str)]
#   [True, True, False]; # -- python2
#   [True, False, True]; # -- python3
87#
if sys.version_info >= (3,): # pragma: 2.x no cover
	def uni_decode(x, enc=PREFER_ENC, errors='strict'):
		"""Return `x` as text: bytes are decoded using `enc`, any other value
		is returned unchanged.

		With the default errors='strict' a failing decode falls back to
		errors='replace'; with any other `errors` value the error is re-raised.
		"""
		if not isinstance(x, bytes):
			return x
		try:
			return x.decode(enc, errors)
		except (UnicodeDecodeError, UnicodeEncodeError): # pragma: no cover - unsure if reachable
			if errors != 'strict':
				raise
			return x.decode(enc, 'replace')
	def uni_string(x):
		"""Convert any value to text: bytes are decoded with PREFER_ENC
		(undecodable sequences replaced), anything else goes through str().
		"""
		if isinstance(x, bytes):
			return x.decode(PREFER_ENC, 'replace')
		return str(x)
else: # pragma: 3.x no cover
	# NOTE: this branch runs on python 2 only, where `unicode` is the text type
	# (a `str` there is a byte-string, so testing for `str` would invert the logic).
	def uni_decode(x, enc=PREFER_ENC, errors='strict'):
		"""py2 counterpart: unicode is encoded to a byte-string using `enc`,
		any other value is returned unchanged (same fallback rules as py3).
		"""
		try:
			if isinstance(x, unicode): # noqa: F821 - py2 builtin
				return x.encode(enc, errors)
			return x
		except (UnicodeDecodeError, UnicodeEncodeError): # pragma: no cover - unsure if reachable
			if errors != 'strict':
				raise
			return x.encode(enc, 'replace')
	if sys.getdefaultencoding().upper() != 'UTF-8': # pragma: no cover - utf-8 is default encoding now
		def uni_string(x):
			"""py2 counterpart: unicode is encoded with PREFER_ENC (errors
			replaced), anything else goes through str().
			"""
			if not isinstance(x, unicode): # noqa: F821 - py2 builtin
				return str(x)
			return x.encode(PREFER_ENC, 'replace')
	else:
		uni_string = str
119
120
121def _as_bool(val):
122	return bool(val) if not isinstance(val, str) \
123		else val.lower() in ('1', 'on', 'true', 'yes')
124
125
def formatExceptionInfo():
	"""Consistently format exception information

	Returns a ``(class name, message)`` tuple for the exception currently
	being handled, so it is meant to be called from within an ``except`` block.
	"""
	excType, excValue, _ = sys.exc_info()
	return (excType.__name__, uni_string(excValue))
130
131
132#
133# Following "traceback" functions are adopted from PyMVPA distributed
134# under MIT/Expat and copyright by PyMVPA developers (i.e. me and
135# Michael).  Hereby I re-license derivative work on these pieces under GPL
136# to stay in line with the main Fail2Ban license
137#
def mbasename(s):
	"""Return a short module name for file path `s`.

	A trailing '.py' is stripped; too common names ('base', '__init__')
	are prefixed with the name of their directory (e.g. 'server.__init__').
	"""
	head, tail = os.path.split(s)
	if tail.endswith('.py'):
		tail = tail[:-3]
	if tail not in ('base', '__init__'):
		return tail
	return '%s.%s' % (os.path.basename(head), tail)
149
150
class TraceBack(object):
	"""Callable producing a compact, single-line summary of the call stack.

	Customized traceback to be included in debug messages.
	"""

	def __init__(self, compress=False):
		"""Initialize TraceBack metric

		Parameters
		----------
		compress : bool
		  if True then prefix common with previous invocation gets
		  replaced with ...
		"""
		self.__compress = compress
		self.__prev = ""

	def __call__(self):
		"""Return the summary as ``module:line[,line...]>module:line...``.

		Frames of unittest and of logging/__init__.py are skipped; consecutive
		frames of the same module are merged into one entry.
		"""
		# [:-2] drops the frame of this method and the one of its direct caller:
		stack = traceback.extract_stack(limit=100)[:-2]
		entries = []
		for frame in stack:
			modName = mbasename(frame[0])
			if modName in ['unittest', 'logging.__init__'] \
					or os.path.dirname(frame[0]).endswith('/unittest'):
				continue
			entries.append([modName, str(frame[1])])

		# make it more concise - join line numbers of subsequent frames of the same module:
		merged = [entries[0]]
		for modName, lineNo in entries[1:]:
			last = merged[-1]
			if modName == last[0]:
				last[1] += ',' + lineNo
			else:
				merged.append([modName, lineNo])
		sftb = '>'.join('%s:%s' % (mbasename(modName), lineNo) for modName, lineNo in merged)

		if not self.__compress:
			return sftb
		# replace the part in common with the previous invocation (cut back to
		# an entry boundary) with "...":
		current = sftb
		commonPrefix = re.sub('>[^>]*$', '', os.path.commonprefix((self.__prev, sftb)))
		if commonPrefix != "":
			sftb = '...' + sftb[len(commonPrefix):]
		self.__prev = current
		return sftb
195
196
class FormatterWithTraceBack(logging.Formatter):
	"""Custom formatter which expands %(tb)s and %(tbc)s with tracebacks

	Both fields are set to the call-stack summary produced by TraceBack.
	The summary is compressed against the previous record (common leading
	part replaced by "...") only if the format contains ``%(tbc)s``.

	TODO: might need locking in case of compressed tracebacks
	"""
	def __init__(self, fmt, *args, **kwargs):
		# pass fmt positionally, so further positional arguments (datefmt, style, ...)
		# line up with logging.Formatter's signature instead of colliding with `fmt=`:
		logging.Formatter.__init__(self, fmt, *args, **kwargs)
		compress = fmt is not None and '%(tbc)s' in fmt
		self._tb = TraceBack(compress=compress)

	def format(self, record):
		record.tbc = record.tb = self._tb()
		return logging.Formatter.format(self, record)
210
211
# Module-wide switch (stored as an attribute of the `logging` module): if True,
# __stopOnIOError closes stderr and exits the process after a broken pipe.
logging.exitOnIOError = False
def __stopOnIOError(logSys=None, logHndlr=None): # pragma: no cover
	"""Silence logging after a broken pipe on its output stream.

	Removes the first handler of `logSys` (if it has any), turns
	`logHndlr.close` into a no-op, disables flushing for every
	logging.StreamHandler (class-level patch) and - if
	``logging.exitOnIOError`` is set - closes stderr and exits with status 0.
	"""
	if logSys and len(logSys.handlers):
		logSys.removeHandler(logSys.handlers[0])
	if logHndlr:
		logHndlr.close = lambda: None
	logging.StreamHandler.flush = lambda self: None
	#sys.excepthook = lambda *args: None
	if logging.exitOnIOError:
		try:
			sys.stderr.close()
		except:
			pass
		sys.exit(0)
226
# python 2 has no BrokenPipeError builtin; alias it to IOError there, so the
# `except (BrokenPipeError, IOError)` clauses of the safe-logging wrappers
# work on both versions:
try:
	BrokenPipeError = BrokenPipeError
except NameError: # pragma: 3.x no cover
	BrokenPipeError = IOError
231
# keep the original implementation, wrapped by __safeLog below:
__origLog = logging.Logger._log
def __safeLog(self, level, msg, args, **kwargs):
	"""Safe log inject to avoid possible errors by unsafe log-handlers,
	concat, str. conversion, representation fails, etc.

	Used to intrude exception-safe _log-method instead of _log-method
	of Logger class to be always safe by logging and to get more-info about.

	BrokenPipeError/IOError is re-raised (after silencing logging via
	__stopOnIOError if errno is 32). Any other exception is swallowed and
	replaced by up to two best-effort records at the same level: the error
	with the message, and the stringified args.

	See testSafeLogging test-case for more information. At least the errors
	covered in phase 3 seem to affect all known pypy/python versions
	so far.
	"""
	try:
		# if isEnabledFor(level) already called...
		__origLog(self, level, msg, args, **kwargs)
	except (BrokenPipeError, IOError) as e: # pragma: no cover
		# errno 32 is EPIPE on common platforms:
		if e.errno == 32: # closed / broken pipe
			__stopOnIOError(self)
		raise
	except Exception as e: # pragma: no cover - unreachable if log-handler safe in this python-version
		try:
			# note: the tuple (incl. the stringified original `args`) is built
			# before the loop rebinds `args` to each (msg, args) pair:
			for args in (
				("logging failed: %r on %s", (e, uni_string(msg))),
				("  args: %r", ([uni_string(a) for a in args],))
			):
				try:
					__origLog(self, level, *args)
				except: # pragma: no cover
					pass
		except: # pragma: no cover
			pass
# install the wrapper for all loggers:
logging.Logger._log = __safeLog
264
# keep the original implementation, wrapped by __safeLogFlush below:
__origLogFlush = logging.StreamHandler.flush
def __safeLogFlush(self):
	"""Safe flush inject stopping endless logging on closed streams (redirected pipe).

	On BrokenPipeError/IOError with errno 32 (broken pipe), logging is
	silenced via __stopOnIOError; the error is then re-raised (unless
	__stopOnIOError exits the process).
	"""
	try:
		__origLogFlush(self)
	except (BrokenPipeError, IOError) as e: # pragma: no cover
		if e.errno == 32: # closed / broken pipe
			__stopOnIOError(None, self)
		raise
# install the wrapper for all stream handlers:
logging.StreamHandler.flush = __safeLogFlush
276
def getLogger(name):
	"""Get logging.Logger instance with Fail2Ban logger name convention

	A dotted name (e.g. a module's ``__name__``) is mapped to
	"fail2ban.<last component>"; a name without dots is used as is.
	"""
	_, sep, lastPart = name.rpartition(".")
	if sep:
		name = "fail2ban.%s" % lastPart
	return logging.getLogger(name)
283
def str2LogLevel(value):
	"""Convert a log level to its integer value.

	Parameters
	----------
	value : int or str
	  an integer, a string of digits, or a level name (case-insensitive)
	  resolved as an attribute of the `logging` module (e.g. "debug")

	Returns
	-------
	int
	  the numeric log level

	Raises
	------
	ValueError
	  if `value` is neither numeric nor the name of an integer attribute
	  of the `logging` module
	"""
	try:
		if isinstance(value, int) or value.isdigit():
			ll = int(value)
		else:
			ll = getattr(logging, value.upper())
	except AttributeError:
		raise ValueError("Invalid log level %r" % value)
	# getattr may resolve to a non-level attribute of the logging module
	# (e.g. "basic_format" -> logging.BASIC_FORMAT, a string) - reject it:
	if not isinstance(ll, int):
		raise ValueError("Invalid log level %r" % value)
	return ll
293
def getVerbosityFormat(verbosity, fmt=' %(message)s', addtime=True, padding=True):
	"""Custom log format for the verbose runs

	Builds a log-record format string: `fmt` is prefixed with fields that
	depend on `verbosity` (> 1: thread/level, > 2: relative time/thread/name/
	level, > 3: additionally module/levelno/funcName). `addtime` adds the
	timestamp (not used for verbosity > 2). Without `padding`, width/precision
	specifiers of string fields are removed.
	"""
	if verbosity > 1: # pragma: no cover
		if verbosity > 3:
			fmt = ' | %(module)15.15s-%(levelno)-2d: %(funcName)-20.20s |' + fmt
		if verbosity > 2:
			prefix = ' +%(relativeCreated)5d %(thread)X %(name)-25.25s %(levelname)-5.5s'
		else:
			prefix = ' %(thread)X %(levelname)-5.5s'
			if addtime:
				prefix = ' %(asctime)-15s' + prefix
	else: # default (not verbose):
		prefix = "%(name)-24s[%(process)d]: %(levelname)-7s"
		if addtime:
			prefix = "%(asctime)s " + prefix
	fmt = prefix + fmt
	# remove padding if not needed (e.g. "%(name)-24s" -> "%(name)s"):
	if not padding:
		fmt = re.sub(r'(?<=\))-?\d+(?:\.\d+)?s', 's', fmt)
	return fmt
314
315
def excepthook(exctype, value, traceback):
	"""Except hook used to log unhandled exceptions to Fail2Ban log

	Meant as a ``sys.excepthook`` replacement: logs the exception (with its
	traceback) as critical to the "fail2ban" logger, then delegates to the
	default hook ``sys.__excepthook__``.
	"""
	# Pass the exception explicitly: when the interpreter invokes sys.excepthook,
	# no exception is being handled, so exc_info=True (i.e. sys.exc_info()) would
	# typically log no traceback. The explicit tuple works in both situations.
	getLogger("fail2ban").critical(
		"Unhandled exception in Fail2Ban:", exc_info=(exctype, value, traceback))
	return sys.__excepthook__(exctype, value, traceback)
322
def splitwords(s):
	"""Helper to split words on any comma, space, or a new line

	Returns empty list if input is empty (or None). Each word is stripped
	of surrounding whitespace; empty entries are filtered out.
	"""
	if not s:
		return []
	words = (w.strip() for w in re.split('[ ,\n]+', s))
	return [w for w in words if w]
332
# Dict-merge helpers. For python >= 3.5 they use dict unpacking; that code is
# compiled from a string, presumably so that older parsers (which reject the
# `{**x, **y}` syntax) can still load this module.
# _merge_dicts returns `x` itself if `y` is empty/None, otherwise a new dict in
# which the items of `y` override those of `x`; _merge_copy_dicts always
# returns a new dict.
if sys.version_info >= (3,5):
	eval(compile(r'''if 1:
	def _merge_dicts(x, y):
		"""Helper to merge dicts.
		"""
		if y:
			return {**x, **y}
		return x

	def _merge_copy_dicts(x, y):
		"""Helper to merge dicts to guarantee a copy result (r is never x).
		"""
		return {**x, **y}
	''', __file__, 'exec'))
else:
	def _merge_dicts(x, y):
		"""Helper to merge dicts.

		Returns `x` itself if `y` is empty/None, otherwise a copy of `x`
		updated with `y`.
		"""
		r = x
		if y:
			r = x.copy()
			r.update(y)
		return r
	def _merge_copy_dicts(x, y):
		"""Helper to merge dicts to guarantee a copy result (r is never x).
		"""
		r = x.copy()
		if y:
			r.update(y)
		return r
363
364#
# The following functions are used to parse options from a parameter (e.g. `name[p1=0, p2="..."][p3='...']`).
366#
367
# regex, to extract the name and the raw list of options: `name[...]`
# (group 1 - name, group 2 - everything from the first `[` up to the last `]`):
OPTION_CRE = re.compile(r"^([^\[]+)(?:\[(.*)\])?\s*$", re.DOTALL)
# regex, to iterate over single options in an option list, syntax:
# `action = act[p1="...", p2='...', p3=...]`, where p3=... must not contain `,` or `]`;
# since v0.10 the separator is extended with `]\s*[` to support multiple option groups, syntax:
# `action = act[p1=...][p2=...]`
# (group 1 - option name; groups 2/3/4 - double-quoted/single-quoted/bare value)
OPTION_EXTRACT_CRE = re.compile(
	r'([\w\-_\.]+)=(?:"([^"]*)"|\'([^\']*)\'|([^,\]]*))(?:,|\]\s*\[|$)', re.DOTALL)
# split by new-line considering possible new-lines within options [...]
# (no capturing groups, so findall returns the whole entries):
OPTION_SPLIT_CRE = re.compile(
	r'(?:[^\[\s]+(?:\s*\[\s*(?:[\w\-_\.]+=(?:"[^"]*"|\'[^\']*\'|[^,\]]*)\s*(?:,|\]\s*\[)?\s*)*\])?\s*|\S+)(?=\n\s*|\s+|$)', re.DOTALL)
379
def extractOptions(option):
	"""Split ``name[p1=v1, p2="v2"][p3='v3']`` into its name and options.

	Returns a ``(name, options)`` tuple, where options is a dict mapping
	the stripped option names to their stripped values (quotes removed).
	Returns ``(None, None)`` if `option` does not match OPTION_CRE.
	"""
	match = OPTION_CRE.match(option)
	if match is None:
		# TODO proper error handling
		return None, None
	option_name, optstr = match.groups()
	option_opts = {}
	if optstr:
		for optmatch in OPTION_EXTRACT_CRE.finditer(optstr):
			# exactly one of the alternative value groups (double-quoted,
			# single-quoted, bare) participates in the match:
			value = next(v for v in optmatch.group(2, 3, 4) if v is not None)
			option_opts[optmatch.group(1).strip()] = value.strip()
	return option_name, option_opts
394
def splitWithOptions(option):
	"""Split a multi-value string into single entries (list of strings),
	using OPTION_SPLIT_CRE (split by new-line, considering possible
	new-lines within options ``[...]``).
	"""
	entries = OPTION_SPLIT_CRE.findall(option)
	return entries
397
398#
# The following facilities are used for safe recursive interpolation of
# tags (<tag>) in tagged options.
401#
402
# max tag replacement count (considering tag X in tag Y repeat): how often a
# single tag may be replaced while resolving one tag's value before
# substituteRecursiveTags gives up with a ValueError (endless recursion guard):
MAX_TAG_REPLACE_COUNT = 25

# compiled RE for tag name (replacement name): matches `<name>`, where the name
# contains no spaces or angle brackets; group 1 is the name:
TAG_CRE = re.compile(r'<([^ <>]+)>')
408
def substituteRecursiveTags(inptags, conditional='',
	ignore=(), addrepl=None
):
	"""Sort out tag definitions within other tags.
	Since v.0.9.2 supports embedded interpolation (see test cases for examples).

	so:		becomes:
	a = 3		a = 3
	b = <a>_3	b = 3_3

	Parameters
	----------
	inptags : dict
		Dictionary of tags(keys) and their values.
	conditional : str
		If not empty, a tag `<X>` is first resolved using the entry
		`X?<conditional>`, falling back to `X`.
	ignore : iterable
		Tags that are neither substituted themselves nor replaced
		inside other values.
	addrepl : callable, optional
		Called with the tag name for tags not found in `inptags`; its
		result (if not None) is used as replacement.

	Tags that cannot be resolved are left in the values as they are.
	If `inptags` provides `getRawItem` (presumably a calling-map type),
	entries whose raw item is callable are not substituted here.

	Returns
	-------
	dict
		Dictionary of tags(keys) and their values, with tags
		within the values recursively replaced. `inptags` itself is
		returned if nothing was replaced, otherwise a modified copy
		(`inptags` is never modified).

	Raises
	------
	ValueError
		On self-referencing definitions, or if a tag gets replaced too
		often (see MAX_TAG_REPLACE_COUNT).
	"""
	#logSys = getLogger("fail2ban")
	tre_search = TAG_CRE.search
	tags = inptags
	# init:
	ignore = set(ignore)
	done = set()
	noRecRepl = hasattr(tags, "getRawItem")
	# repeat substitution while embedded-recursive (repFlag is True)
	# (per tag: replacement counts of the tags found in its value, kept across passes)
	repCounts = {}
	while True:
		repFlag = False
		# substitute each value:
		for tag in tags.keys():
			# ignore escaped or already done (or in ignore list):
			if tag in ignore or tag in done: continue
			# ignore replacing callable items from calling map - should be converted on demand only (by get):
			if noRecRepl and callable(tags.getRawItem(tag)): continue
			value = orgval = uni_string(tags[tag])
			# search and replace all tags within value, that can be interpolated using other tags:
			m = tre_search(value)
			rplc = repCounts.get(tag, {})
			#logSys.log(5, 'TAG: %s, value: %s' % (tag, value))
			while m:
				# found replacement tag:
				rtag = m.group(1)
				# don't replace tags that should be currently ignored (pre-replacement):
				if rtag in ignore:
					m = tre_search(value, m.end())
					continue
				#logSys.log(5, 'found: %s' % rtag)
				if rtag == tag or rplc.get(rtag, 1) > MAX_TAG_REPLACE_COUNT:
					# recursive definitions are bad
					#logSys.log(5, 'recursion fail tag: %s value: %s' % (tag, value) )
					raise ValueError(
						"properties contain self referencing definitions "
						"and cannot be resolved, fail tag: %s, found: %s in %s, value: %s" %
						(tag, rtag, rplc, value))
				repl = None
				if conditional:
					repl = tags.get(rtag + '?' + conditional)
				if repl is None:
					repl = tags.get(rtag)
					# try to find tag using additional replacement (callable):
					if repl is None and addrepl is not None:
						repl = addrepl(rtag)
				if repl is None:
					# Missing tags - just continue on searching after end of match
					# Missing tags are ok - cInfo can contain aInfo elements like <HOST> and valid shell
					# constructs like <STDIN>.
					m = tre_search(value, m.end())
					continue
				# if calling map - be sure we've string:
				if not isinstance(repl, str): repl = uni_string(repl)
				# replaces all occurrences of this tag at once:
				value = value.replace('<%s>' % rtag, repl)
				#logSys.log(5, 'value now: %s' % value)
				# increment reference count:
				rplc[rtag] = rplc.get(rtag, 0) + 1
				# the next match for replace (restart at the same position, the
				# replacement itself may contain tags):
				m = tre_search(value, m.start())
			#logSys.log(5, 'TAG: %s, newvalue: %s' % (tag, value))
			# was substituted?
			if orgval != value:
				# check still contains any tag - should be repeated (possible embedded-recursive substitution):
				if tre_search(value):
					repCounts[tag] = rplc
					repFlag = True
				# copy return tags dict to prevent modifying of inptags:
				if id(tags) == id(inptags):
					tags = inptags.copy()
				tags[tag] = value
			# no more sub tags (and no possible composite), add this tag to done set (just to be faster):
			if '<' not in value: done.add(tag)
		# stop interpolation, if no replacements anymore:
		if not repFlag:
			break
	return tags
506
507
if _libcap:
	def prctl_set_th_name(name):
		"""Helper to set real thread name (used for identification and diagnostic purposes).

		Side effect: name can be silently truncated to 15 bytes (16 bytes with NTS zero)

		Best-effort: any error (encoding, missing symbol, ...) is ignored.
		"""
		try:
			if sys.version_info >= (3,): # pragma: 2.x no cover
				name = name.encode()
			else: # pragma: 3.x no cover
				name = bytes(name)
			_libcap.prctl(15, name) # PR_SET_NAME = 15
		except Exception: # pragma: no cover
			# best-effort only, but don't swallow KeyboardInterrupt/SystemExit:
			pass
else: # pragma: no cover
	def prctl_set_th_name(name):
		"""No-op fallback, used if libcap could not be loaded."""
		pass
525
526
class BgService(object):
	"""Background servicing

	Prevents memory leak on some platforms/python versions,
	using forced GC in periodical intervals.

	Singleton: every ``BgService()`` call returns the same instance. Note that
	``__init__`` nevertheless runs on each such call and resets the service
	timer and the call counter.
	"""

	_mutex = Lock()
	_instance = None
	def __new__(cls):
		if not cls._instance:
			cls._instance = \
				super(BgService, cls).__new__(cls)
		return cls._instance

	def __init__(self):
		# time of the next allowed service (initially far in the past):
		self.__serviceTime = -0x7fffffff
		# minimal interval between two services (MyTime.time() units, presumably seconds):
		self.__periodTime = 30
		# number of service() calls between two services:
		self.__threshold = 100
		self.__count = self.__threshold
		if hasattr(gc, 'set_threshold'):
			# NOTE(review): per the `gc` docs, threshold0 == 0 disables automatic
			# generational collection (CPython), so the periodic service() calls
			# perform the collection; the hasattr guard presumably covers
			# interpreters without set_threshold.
			gc.set_threshold(0)
		# don't disable auto garbage, because of non-reference-counting python's (like pypy),
		# otherwise it may leak there on objects like unix-socket, etc.
		#gc.disable()

	def service(self, force=False, wait=False):
		"""Run gc.collect() if due.

		A collection is due when at least `__threshold` calls were made since
		the previous one and its period (`__periodTime`) has elapsed.

		Parameters
		----------
		force : bool
		  skip the call-counter and period checks done before locking
		wait : bool
		  block until the service lock is available, instead of returning
		  False immediately if another thread holds it

		Returns
		-------
		bool
		  True if a collection was performed.
		"""
		self.__count -= 1
		# avoid locking if next service time don't reached
		if not force and (self.__count > 0 or MyTime.time() < self.__serviceTime):
			return False
		# return immediately if mutex already locked (other thread in servicing):
		if not BgService._mutex.acquire(wait):
			return False
		try:
			# check again in lock (another thread may have serviced meanwhile):
			# NOTE(review): this re-check applies to forced calls too, so a forced
			# call within the period returns False - confirm this is intended.
			if MyTime.time() < self.__serviceTime:
				return False
			gc.collect()
			self.__serviceTime = MyTime.time() + self.__periodTime
			self.__count = self.__threshold
			return True
		finally:
			BgService._mutex.release()
572