관리-도구
편집 파일: __init__.py
import asyncio import base64 import errno import functools import hashlib import itertools import logging import os import pwd import re import shutil import signal import stat import subprocess as _subprocess import time import urllib.request from asyncio import Future from collections import OrderedDict, deque from collections.abc import Generator, Iterable from contextlib import ExitStack, contextmanager, suppress from datetime import timedelta from enum import Enum from fcntl import LOCK_EX, LOCK_NB, LOCK_UN, flock from functools import wraps from itertools import islice from pathlib import Path from tempfile import NamedTemporaryFile from typing import ( Any, Awaitable, Callable, Dict, FrozenSet, List, Tuple, TypeVar, ) import async_lru import distro import psutil from ._shutil import rmtree F = TypeVar("F", bound=Callable) logger = logging.getLogger(__name__) USER_IDENTITY_FIELD = "user_id" USER_IDENTITY_HEADERS = ( "User-Agent", "Accept-Language", "Accept-Encoding", "Connection", "DNT", ) _MIN_UID = -1 BACKUP_EXTENSION = ".i360bak" _SYSTEMD_BOOTED_DIR = Path("/run/systemd/system") _CL_SOLO_EDITION_FILE = "/etc/cloudlinux-edition-solo" AV_PID_PATH = Path("/var/run/imunify-antivirus.pid") IM360_NON_RESIDENT_PID_PATH = Path("/var/run/imunify360-agent.pid") IM360_RESIDENT_PID_PATH = Path("/var/run/imunify360.pid") class Scope(Enum): AV = "AV only" AV_IM360 = "AV and IM360" IM360 = "IM360 only" @functools.lru_cache(maxsize=1) def is_systemd_boot(): """Return True if /run/systemd/system folder exists: [sd_booted] (https://www.freedesktop.org/software/systemd/man/sd_booted.html) """ return ( _SYSTEMD_BOOTED_DIR.exists() and _SYSTEMD_BOOTED_DIR.is_dir() and not _SYSTEMD_BOOTED_DIR.is_symlink() ) @contextmanager def timeit(action, logger_=None, log=None): """ :param str: action name to log :param logging.Logger: logger you want action name and timing to be logged with :param func: log function to use (`log` has preference over `logger_`) """ assert logger_ or log start = time.monotonic() yield stop = time.monotonic() (log or logger_.debug)("%s took %.2f second(s)", action, stop - start) def timefun(logger_=logger, action=None, log=None): def decorator(fun): @functools.wraps(fun) async def wrapper(*args, **kwargs): with timeit(action or fun.__name__, logger_=logger_, log=log): return await fun(*args, **kwargs) return wrapper return decorator class sync: """ the same timefun decorator variation but without async/await """ @staticmethod def timefun(logger_=logger, action=None, log=None): """ :param logging.Logger: logger you want action name and timing to be logged with :param str: action name to log """ def decorator(fun): @functools.wraps(fun) def wrapper(*args, **kwargs): with timeit(action or fun.__name__, logger_=logger_, log=log): return fun(*args, **kwargs) return wrapper return decorator async def run( command, stdin=None, stdout=_subprocess.PIPE, stderr=_subprocess.PIPE, shell=False, input=None, **kwargs, ) -> Tuple[int, bytes, bytes]: """Asynchronous command executor. Returns a tuple (exit_code, stdout_data, stderr_data).""" if input is not None: if stdin is not None: # pragma: no cover raise ValueError("stdin and input arguments may not both be used.") stdin = _subprocess.PIPE if shell: assert isinstance(command, str) command = [command] create_subprocess = asyncio.create_subprocess_shell else: assert isinstance(command, (list, tuple)) create_subprocess = asyncio.create_subprocess_exec # type: ignore proc = await retry_on( BlockingIOError, max_tries=2, on_error=await_for(seconds=1) )( create_subprocess )( # type: ignore *command, stdin=stdin, stdout=stdout, stderr=stderr, start_new_session=True, **kwargs, ) out, err = await proc.communicate(input) exit_code = await proc.wait() logger.debug( "run(%s, stdin=%s, shell=%s) = %s", command, stdin, shell, (exit_code, out, err), ) return exit_code, out, err def run_coro(coro, *, loop=None, timeout=None): """Run coroutine from a blocking code (outside the event loop). Coroutine will be wrapped in Task. """ if loop is None: for _ in range(2): try: loop = asyncio.get_event_loop() except RuntimeError: # no loop in the main thread pass else: if not loop.is_closed(): break asyncio.set_event_loop(asyncio.new_event_loop()) return loop.run_until_complete( asyncio.wait_for( coro if isinstance(coro, asyncio.Future) else asyncio.Task(coro), timeout=timeout, ) ) class CheckRunError(_subprocess.CalledProcessError): def __str__(self): _MESSAGE = ( "Command {cmd!r} returned non-zero code {returncode},\n" "\t\tStdout: {output},\n" "\t\tStderr: {error}\n" ) return _MESSAGE.format( cmd=self.cmd, returncode=self.returncode, output=self.output.decode() or None, error=self.stderr.decode() or None, ) async def check_run(command, raise_exc=CheckRunError, **kwargs) -> bytes: """ Asynchronous command executor. Returns output as bytestring. """ returncode, out, err = await run(command, **kwargs) if returncode != 0: raise raise_exc(returncode, command, out, err) return out async def check_exit_code(command, raise_exc=CheckRunError) -> None: """ Asynchronous command executor. Raises raise_exc if exit code is nonzero. Stdin, stdout and stderr of command are connected to /dev/null. """ code, _, _ = await run( command, stdin=_subprocess.DEVNULL, stdout=_subprocess.DEVNULL, stderr=_subprocess.DEVNULL, ) if code != 0: raise raise_exc(code, command) async def safe_run(command, check_returncode=True, **kwargs) -> str: """Safe run command. Returns stdout as string or empty string on error""" try: rc, out, err = await run(command, **kwargs) except OSError: logger.warning("Command %s failed with OSError", command) return "" if check_returncode and rc != 0: logger.warning( "Command %s failed with exit code %s: %s", command, rc, err ) return "" try: result = out.strip().decode() except UnicodeDecodeError: logger.warning("Command %s returned non-utf8 output", command) return "" return result def plainold_lazy_init(decorated_f): """non asyncio vesion of lazy init""" placeholder = None def wrapper(): nonlocal placeholder if placeholder is None: placeholder = decorated_f() return placeholder return wrapper class PeriodicCheck: """ Invoke a callback with a certain period and return cached result in between. Raising an exception from the callback does not affect the next check schedule. """ def __init__(self, cb_coro, check_every_n_seconds): self._cb_coro = cb_coro self._check_every_n_seconds = check_every_n_seconds self._last_check_timestamp = time.monotonic() - check_every_n_seconds self._last_check_result = None self._lock = plainold_lazy_init(asyncio.Lock) async def __call__(self, *args, **kwargs): async with self._lock(): delta = time.monotonic() - self._last_check_timestamp if delta >= self._check_every_n_seconds: logger.debug( "Timeout %d seconds has expired, doing the check: %s", self._check_every_n_seconds, self._cb_coro, ) self._last_check_timestamp = time.monotonic() self._last_check_result = await self._cb_coro(*args, **kwargs) return self._last_check_result def cache_result(nsec): def decorate(coro): return PeriodicCheck(coro, nsec) return decorate class RecurringCheckStop(Exception): """ raised by coroutine to stop recurring_check loop """ pass async def wait_for_period(period, **period_kwargs): try: if callable(period): await asyncio.sleep(period(**period_kwargs)) else: await asyncio.sleep(period) return False except asyncio.CancelledError: return True async def should_stop_after_period_passed(check, period, **period_kwargs): return ( await wait_for_period(period, **period_kwargs) if period and check else False ) def recurring_check( period, consecutive_err_limit=10, check_period_first=False, **period_kwargs ): """ run decorated corotine in a loop every :period: seconds. If more then consecutive_err_limit error occured, exit loop. :param period: :param consecutive_err_limit: :param check_period_first: default false :return: """ def decorator(fun): @wraps(fun) async def wrapped(*args, **kwargs): consecutive_err_cnt = 0 while True: if await should_stop_after_period_passed( check_period_first, period, **period_kwargs ): break try: await fun(*args, **kwargs) except (asyncio.CancelledError, RecurringCheckStop): break except Exception as exc: consecutive_err_cnt += 1 if consecutive_err_cnt > consecutive_err_limit: logger.exception( "Error count exceeded limit,exiting check loop" ) break if isinstance(exc, _subprocess.CalledProcessError): logger.exception( "Failed to run %s (%s). stdout=%s, stderr=%s", exc.cmd, exc.returncode, exc.output, exc.stderr, ) else: logger.exception("Error executing %s", fun) else: consecutive_err_cnt = 0 if await should_stop_after_period_passed( not check_period_first, period, **period_kwargs ): break return wrapped return decorator def atomic_rewrite( filename, data, /, # ^^ positional-only for backward compatibility *, backup=True, uid=None, gid=None, allow_empty_content=True, permissions=None, ) -> bool: """Atomically rewrites *filename* with given *data*. If *filename*'s content is *data* already, do nothing. If both *uid* and *gid* are given then resulting file is chowned to given user id and group id. Skip rewrite with empty content if *allow_empty_content* is False. Chmod to given access *permissions* else preserve *filename* 's permissions. Return True if *filename* file was updated, False otherwise """ if isinstance(data, str): data = data.encode() with suppress(FileNotFoundError): with open(filename, "rb") as file: old_content = file.read(len(data) + 1) if old_content == data: return False if not allow_empty_content and not data: logger.error("empty content: %r for file: %s", data, filename) return False if backup: if isinstance(backup, (str, os.PathLike)): backup_filename = backup else: backup_filename = os.fspath(filename) + BACKUP_EXTENSION shutil.copy(filename, backup_filename) if permissions is None: # get filename's access permissions try: permissions = stat.S_IMODE(os.stat(filename).st_mode) except FileNotFoundError: # input file doesn't exists # derive permissions from umask current_umask = os.umask(0) # can't get it without setting os.umask(current_umask) permissions = 0o666 & ~current_umask dirpath, basename = os.path.split(filename) if not Path(dirpath).exists(): raise FileNotFoundError(f"Parent dir is missing: {dirpath!r}") with ExitStack() as stack: with NamedTemporaryFile( mode="wb", dir=dirpath, suffix=".i360edit", prefix=basename + "_", buffering=0, delete=False, ) as tf: def cleanup(): with suppress(FileNotFoundError): os.remove(tf.name) stack.callback(cleanup) # clean it up in case of any error tf.write(data) tf.flush() if uid is not None and gid is not None: os.chown(tf.fileno(), uid, gid) # note: NamedTemporaryFile always sets 0b600 os.chmod(tf.fileno(), permissions) # avoid partial/empty data on crash os.fsync(tf.fileno()) os.rename(tf.name, filename) stack.pop_all() # success, don't call cleanup # no attempt to ensure that filename is written to disk # (dir is not fsync-ed) return True @functools.lru_cache(1) def os_release_and_version(): try: return Path("/etc/system-release").read_text().rstrip() except OSError: return None def os_version(release_and_version=None) -> str: """Return os version, if can't get it raise ValueError""" rv = release_and_version or os_release_and_version() if rv: match = re.search(r"\s*(\d+\.\d+\S*)(\s|$)", rv) if match: return match.group(1) else: os_release_and_version.cache_clear() raise ValueError("Can't discover os version from %r" % rv) class OsReleaseInfo: ETC_OS_RELEASE = "/etc/os-release" DEBIAN = frozenset(("debian",)) RHEL_FEDORA_CENTOS = frozenset(("rhel", "fedora", "centos")) UNKNOWN = frozenset(("unknown",)) dict_ = None @classmethod def dict_from_file(cls, dict_): with open(cls.ETC_OS_RELEASE) as f: for line in f: try: k, v = line.rstrip().split("=") dict_[k] = v.strip('"') except ValueError: pass if "ID_LIKE" in dict_: dict_["ID_LIKE"] = frozenset(dict_["ID_LIKE"].split()) else: # https://www.freedesktop.org/software/systemd/man/os-release.html#ID= dict_["ID_LIKE"] = frozenset((dict_.get("ID", "linux"),)) @classmethod def to_dict(cls) -> Dict[str, Any]: if cls.dict_ is None: dict_: Dict[str, Any] = dict() if os.path.exists(cls.ETC_OS_RELEASE): cls.dict_from_file(dict_) else: # centos and cl 6 does not have /etc/os-release file # this will need to move to distro package in python 3.8 d = distro.linux_distribution() if d and d[0]: osid = d[0].lower().split()[0] if osid == "red" and "Red Hat Enterprise Linux" in d[0]: osid = "rhel" dict_["ID"] = osid dict_["PRETTY_NAME"] = "{} {} ({})".format( d[0], d[1], d[2] ) if osid in ("cloudlinux", "centos", "rhel"): dict_["ID_LIKE"] = cls.RHEL_FEDORA_CENTOS elif osid in ("ubuntu", "debian"): dict_["ID_LIKE"] = cls.DEBIAN else: dict_["ID_LIKE"] = cls.UNKNOWN else: dict_["ID"] = "unknown" dict_["ID_LIKE"] = cls.UNKNOWN dict_["PRETTY_NAME"] = "unknown" cls.dict_ = dict_ return cls.dict_ @classmethod def id_like(cls) -> FrozenSet[str]: return cls.to_dict()["ID_LIKE"] @classmethod def pretty_name(cls) -> str: return cls.to_dict()["PRETTY_NAME"] @classmethod def get_os(cls) -> str: """ :return: OS name, like centos, ubuntu, debian, cloudlinux, redhat in lower case """ return cls.to_dict().get("ID", "unknown") @classmethod def is_rhel(cls): return cls.get_os() == "rhel" @classmethod def is_centos(cls): return cls.get_os() == "centos" @classmethod def is_ubuntu(cls): return cls.get_os() == "ubuntu" @classmethod def is_cloudlinux(cls): return cls.get_os() in ("cloudlinux", "cloudlinuxserver") @classmethod def is_cloudlinux_solo(cls): return os.path.exists(_CL_SOLO_EDITION_FILE) @classmethod def is_debian(cls): return cls.get_os() == "debian" @classmethod def is_oracle_linux(cls): return cls.get_os() == "ol" @classmethod def is_almalinux(cls): return cls.get_os() == "almalinux" @classmethod def is_rockylinux(cls): return cls.get_os() == "rocky" def file_hash( filename: str, hash_func=hashlib.md5, chunksize: int = 4096 ) -> str: """Return hash of the file `filename`, reading it in chunks. * filename is a path to a file; * hash_func is a function that returns hash object (one of hashlib.md5 etc); * chunksize is a size of chunks to read, in bytes. """ return file_hash_and_size(filename, hash_func, chunksize)[0] def file_hash_and_size( filename: str, hash_func, chunksize: int = 4096, ) -> Tuple[str, int]: """Calculate hash and size of the file `filename`, reading it in chunks. * filename is a path to a file; * hash_func is a function that returns hash object (one of hashlib.md5 etc); * chunksize is a size of chunks to read, in bytes. Return tuple(hash, file size).""" hash_ = hash_func() size = 0 with open(filename, "rb") as f: while True: chunk = f.read(chunksize) if not chunk: break hash_.update(chunk) size += len(chunk) return hash_.hexdigest(), size def _parse_name_value(varname, defs_line): """Given login.defs line, return *varname*'s value.""" name, value = defs_line.split() # no end of line comments if varname != name: raise ValueError("Expected {varname!r}, got {name!r}".format(**vars())) return value def get_min_uid(): global _MIN_UID if _MIN_UID == -1: _MIN_UID, _ = _get_max_min_uid() return _MIN_UID def _get_max_min_uid(path="/etc/login.defs"): """Get UID_MIN, UID_MAX from the login.defs file specified as *path*. On error, return default for the current OS values. """ uid_min, uid_max = 1000, 60000 # default for centos 7, ubunty 16.04 centos = OsReleaseInfo.id_like() & OsReleaseInfo.RHEL_FEDORA_CENTOS with suppress(ValueError): if centos and os_version().startswith("6"): uid_min, uid_max = 500, 60000 # default for centos 6 try: with open(path) as file: for line in file: if line.startswith("UID_MIN"): uid_min = int(_parse_name_value("UID_MIN", line)) if line.startswith("UID_MAX"): uid_max = int(_parse_name_value("UID_MAX", line)) except (OSError, ValueError): # use default pass return uid_min, uid_max def get_non_system_users( excludes=("imunify360-captcha", "imunify360-webshield") ): """ :param excludes: users to exclude in results :return: list: list of pwd.struct_passwd objects representing users """ uid_min, uid_max = _get_max_min_uid() return [ entry for entry in pwd.getpwall() if uid_min <= entry.pw_uid <= uid_max and entry.pw_name not in excludes ] def get_system_user_names(): """ :return: list: list of str with system user names """ uid_min, _ = _get_max_min_uid() return [ entry.pw_name for entry in pwd.getpwall() if uid_min >= entry.pw_uid ] @functools.lru_cache() def is_system_user(uid: int): uid_min, uid_max = _get_max_min_uid() return uid < uid_min async_lru_cache = functools.partial( async_lru.alru_cache, maxsize=100, # set tot true because of backward compatibility with previous # implementation of async_lru_cache typed=True, ) def append_with_newline(filename, data): with open(filename, "r+") as f: # ensure we have eol at the end of file # returns poiner position 0 if file is empty last_char_pos = f.seek(0, 2) if last_char_pos != 0: f.seek(last_char_pos - 1) if f.read(1) != "\n": f.write("\n") f.write(data) if not data.endswith("\n"): f.write("\n") def append_with_newline_bytes(filename: os.PathLike, data: bytes) -> None: """Append *data* to *filename* making sure there is \n at the end.""" with open(filename, "r+b") as f: # ensure we have eol at the end of file # returns poiner position 0 if file is empty last_char_pos = f.seek(0, 2) if last_char_pos != 0: f.seek(last_char_pos - 1) if f.read(1) != b"\n": f.write(b"\n") f.write(data) if not data.endswith(b"\n"): f.write(b"\n") def ensure_line_in_file(filename, line): """Add *line* to *filename* if it is not present in the file Returns: True if the file was changed, False otherwise. """ changed = False with open(filename, "r") as f: if not any(_line.strip() == line for _line in f): changed = True if changed: append_with_newline(filename, line) return changed def ensure_line_in_file_bytes(filename: os.PathLike, line: bytes) -> bool: """Add *line* to *filename* if it is not present in the file. Returns: True if the file was changed, False otherwise. """ changed = False with open(filename, "rb") as f: if not any(_line.strip() == line for _line in f): changed = True if changed: append_with_newline_bytes(filename, line) return changed def remove_line_from_file(filename, line): basedir = os.path.dirname(filename) with open(filename, "r") as sf, NamedTemporaryFile( mode="w", dir=basedir, delete=False ) as tf: for _line in sf: if _line.strip() != line: tf.write(_line) os.rename(tf.name, filename) class FileLock: """ Simple context manager to enable UNIX-specific file locking with flock system call """ _TIMEOUT = 10 # Default timeout to wait for lock def __init__(self, path, timeout=_TIMEOUT): self.path = path self.locked = False self.file = open(path, "w") self.timeout = timeout async def __aenter__(self): start = time.time() while True: try: # Trying to perform file lock flock(self.file, LOCK_EX | LOCK_NB) self.locked = True return self # Resource temporarily unavailable except (OSError, IOError) as ex: if ex.errno != errno.EAGAIN: raise # if did not succeed # to lock file within a given timeout # perform operation without it elif self.timeout < time.time() - start: logger.warning( "Failed to lock file %s. Timeout exceeded.", self.path ) break # Return control to event loop and wait await asyncio.sleep(1) async def __aexit__(self, exc_type, exc_val, exc_tb): # If successfully locked file at entering context # release it if self.locked: flock(self.file, LOCK_UN) self.locked = False self.file.close() def user_identity(attackers_ip, source, fields=USER_IDENTITY_HEADERS): try: # TODO: change after migtration to python3.8 # dicts in python3.5 do not keep order, # that's why we sort items to get the same hash for the same source uid_data = [attackers_ip] uid_data.extend( str(value) for field, value in sorted(source.items()) if field in fields ) # ModSecurity has no capability to create sha256 hashes # using sha1 instead hash_alg = hashlib.sha1() hash_alg.update("".join(uid_data).encode("utf8", "surrogateescape")) return hash_alg.hexdigest() except (ValueError, UnicodeEncodeError) as e: logger.error( "Generation of user identity hash failed, invalid data: %s", e ) return None def is_root_user(): return os.getuid() == 0 @contextmanager def run_with_umask(mask: int): current_mask = os.umask(mask) try: yield finally: os.umask(current_mask) def get_abspath_from_user_dir(username: str, relpath="") -> Path: """ Returns user's home dir if `relpath` is not specified. Otherwise, returns absolute path of `relpath` build from `username`'s home dir :raise ValueError: when user home dir is not exists """ if not isinstance(username, str): raise ValueError("Invalid type for %s, should be str!" % username) if os.sep in username: raise ValueError("Invalid username") try: pw = pwd.getpwnam(username) except KeyError: raise ValueError("User {!r} doesn't exist".format(username)) abs_path = os.path.join(pw.pw_dir, relpath) return Path(abs_path) def does_path_belong_to_user(path: str, username: str) -> bool: status = False try: user_home = get_abspath_from_user_dir(username) Path(path).relative_to(user_home) status = True except ValueError as e: logger.warning(str(e)) return status def get_path_owner(path): if not os.path.abspath(path): raise ValueError("Path %s should be absolute!" % path) while True: if os.path.exists(path): try: return pwd.getpwuid(os.stat(path).st_uid).pw_name except KeyError: return str(os.stat(path).st_uid) path = os.path.dirname(path) def split_for_chunk(iterable: Iterable, chunk_size: int = 500) -> Generator: """ Generator that splits iterable on N-parts by chunk_size items in each chunk >>> list(split_for_chunk([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], chunk_size=2)) [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]] :param iterable: :param int chunk_size: :return: generator: """ i = iter(iterable) piece = list(islice(i, chunk_size)) while piece: yield piece piece = list(islice(i, chunk_size)) def freeze(d): if isinstance(d, dict): return frozenset((key, freeze(value)) for key, value in d.items()) elif isinstance(d, list): return tuple(freeze(value) for value in d) return d class Singleton(type): """ Metaclass for creating only one instance of class, when providing the same arguments. """ _instances = {} def __call__(cls, *args, **kwargs): key = (cls, freeze(args), freeze(kwargs)) if not cls._instances.get(key): cls._instances[key] = super(Singleton, cls).__call__( *args, **kwargs ) return cls._instances[key] @functools.lru_cache(maxsize=10) def get_external_ip(): """ :return str: server's external IP address """ with urllib.request.urlopen("https://api.ipify.org", timeout=2) as r: return r.read().decode() def get_kernel_module_parameter(module_name, parameter): """ Reads parameter of kernel module from /sys/module/{module_name}/parameters/{parameter} :return str: value of the parameter """ _MOD_PAR_PATH = "/sys/module/{mod}/parameters/{parameter}" param_file = _MOD_PAR_PATH.format(mod=module_name, parameter=parameter) if not os.path.exists(param_file): raise ValueError( "Cannot find parameter %s for module %s" % (parameter, module_name) ) with open(param_file, "r") as p: value = p.read().strip() return value def dict_deep_update(dst, src, allow_overwrite=True) -> bool: """Performs deep update of dict dst with values from src. Does not overwrite subdicts in dst blindly with new dicts in src, but does a deep update of (sub)dict content recursively""" updated = False for k, v in src.items(): if isinstance(v, dict): if k not in dst or not v: dst[k] = v updated = True else: updated = dict_deep_update(dst[k], v) else: assert ( k not in dst or allow_overwrite ), f"{k} already exists in {dst}" dst[k] = v updated = True return updated class TimedCache: def __init__(self, expiration, maxsize=100): assert isinstance(expiration, timedelta) self.expiration = expiration self.maxsize = maxsize self.cache = OrderedDict() self._locks = {} def _collect(self): """Clear cache from expired values""" tmp_cache = OrderedDict() for key in self.cache: value, added_at = self.cache[key] if (time.time() - added_at) < self.expiration.total_seconds(): tmp_cache[key] = value, added_at self.cache = tmp_cache def cache_clear(self): self.cache = OrderedDict() self._locks = {} def _make_key(self, args, kwargs): """ Generate key from call arguments :param args: call positional args :param kwargs: call keyword args :return: """ seed = args if kwargs: kw = sorted(kwargs.items()) seed += tuple(kw) return hash(seed) def __call__(self, func: F) -> F: """ Use it to cache calls to decorated function @TimedCache(expiration=timedelta(minutes=10)) async def func(*args, **kwargs): pass :param func: decorated function :return: NOTE: is not thread safe. """ @wraps(func) async def wrapper_async(*args, **kwargs): key = self._make_key(args, kwargs) lock = self._locks.get(key) if lock is None: lock = self._locks[key] = asyncio.Lock() while True: try: await asyncio.wait_for( lock.acquire(), self.expiration.total_seconds() ) break except asyncio.TimeoutError: # if TimeoutError occurred it means that we not able to # acquire lock, and if it the same lock which we try to # acquire just create a new one, otherwise it already # recreated and we should repeat the attempt to acquire # a lock if lock is self._locks[key]: lock = self._locks[key] = asyncio.Lock() else: lock = self._locks[key] try: self._collect() try: result, _ = self.cache[key] except KeyError: if len(self.cache) >= self.maxsize: self.cache.popitem(last=False) result = await func(*args, **kwargs) self.cache[key] = result, time.time() finally: lock.release() return result @wraps(func) def wrapper_sync(*args, **kwargs): self._collect() key = self._make_key(args, kwargs) try: result, _ = self.cache[key] except KeyError: if len(self.cache) >= self.maxsize: self.cache.popitem(last=False) result = func(*args, **kwargs) self.cache[key] = result, time.time() return result wrapper = ( wrapper_async if asyncio.iscoroutinefunction(func) else wrapper_sync ) wrapper.cache_clear = self.cache_clear # type: ignore return wrapper # type: ignore timed_cache = TimedCache def fail_agent_service(): """ Send SIGUSR2 to os.getpid() to shutdown agent process by signal (implies exit code -12). Agent will do failover restart then thanks to systemd (or chkservd) if it needs. """ os.kill(os.getpid(), signal.SIGUSR2) async def run_cmd_and_log(cmd, log_file_mask, **popen_kwargs): """ Runs command and log it's output to the log file :param cmd: :param log_file_mask: :return: str path of log file """ live_log = log_file_mask.replace("*", str(os.getpid())) with open(live_log, "w") as live_log_fp: popen_kwargs.update( dict( stdin=asyncio.subprocess.DEVNULL, stdout=live_log_fp, stderr=live_log_fp, start_new_session=True, ) ) logger.debug("Popen(%r, %r)", cmd, popen_kwargs) proc = await asyncio.subprocess.create_subprocess_shell( cmd, **popen_kwargs ) with open(live_log + ".pid", "w") as pf: pf.write( "{:d}\t{}\n".format( proc.pid, psutil.Process(proc.pid).create_time().hex() ) ) return live_log # fix AttributeError: 'NoneType' object has no attribute '_PENDING' on exit # https://github.com/python/asyncio/issues/423#issuecomment-268882753 class Task(asyncio.Task): def __del__(self): if self._state == "PENDING" and self._log_destroy_pending: context = { "task": self, "message": "Task was destroyed but it is pending!", } if self._source_traceback: context["source_traceback"] = self._source_traceback self._loop.call_exception_handler(context) try: Future.__del__(self) except AttributeError: name = getattr(self._coro, "__qualname__", None) or getattr( self._coro, "__name__", None ) code = getattr(self._coro, "gi_code", None) or getattr( self._coro, "cr_code", None ) frame = getattr(self._coro, "gi_frame", None) or getattr( self._coro, "cr_frame", None ) filename = code.co_filename lineno = (frame and frame.f_lineno) or code.co_firstlineno print( "!> Finalizer error in {}() {} at {} line {}".format( name, self._state, filename, lineno ) ) def await_for(seconds): """Return async callback which waits for *seconds*. Usage: @retry_on(Error, on_error=await_for(seconds=PAUSE_INTERVAL), timeout=T) async def coro(): 'here's something that may raise Error.' """ async def pause(*args): return await asyncio.sleep(seconds) return pause def retry_on( exception, on_error=None, max_tries=None, timeout=None, silent=False, log=None, ): """ Retry the function call on exception (or exceptions, if given in tuple) at most *max_tries*. Await *on_error* (if set) for each exception. If *timeout* is set, stop all attempts in *timeout* seconds. If *silent* is set to True - don't raise exceptions after max tries or timeout. """ if not any([max_tries, timeout]): raise ValueError("Set any of max_tries, timeout") def decorator(func): @functools.wraps(func) async def wrapper_async(*args, **kwargs): if timeout: end_time = time.monotonic() + timeout for i in ( itertools.count(1) if not max_tries else range(1, max_tries + 1) ): try: if timeout: remaining_time = end_time - time.monotonic() if remaining_time > 0: return await asyncio.wait_for( func(*args, **kwargs), timeout=remaining_time ) else: if not silent: raise asyncio.TimeoutError elif log: log.error( "Timeout exceeded when calling %s", func ) else: return await func(*args, **kwargs) except (asyncio.TimeoutError, asyncio.CancelledError): raise except exception as exc: if i == max_tries: if not silent: raise elif log: log.error( "Max tries exceeded when calling %s with" " error %s", func, exc, ) if on_error is not None: await on_error(exc, i) @functools.wraps(func) def wrapper_sync(*args, **kwargs): if timeout: end_time = time.monotonic() + timeout for i in ( itertools.count(1) if not max_tries else range(1, max_tries + 1) ): try: if timeout: remaining_time = end_time - time.monotonic() if remaining_time > 0: return func(*args, **kwargs) else: if not silent: raise TimeoutError elif log: log.error( "Timeout exceeded when calling %s", func ) else: return func(*args, **kwargs) except exception as exc: if i == max_tries: if not silent: raise elif log: log.error( "Max tries exceeded when calling %s with" " error %s", func, exc, ) if on_error is not None: on_error(exc, i) if asyncio.iscoroutinefunction(func): return wrapper_async else: return wrapper_sync return decorator def stub_unexpected_error(func): """If func throws an exception it is catched, converted to a string and returned as a result of a call.""" @functools.wraps(func) async def wrapper_async(*args, **kwargs): try: return await func(*args, **kwargs) except Exception as e: # noqa return repr(e) @functools.wraps(func) def wrapper_sync(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: # noqa return repr(e) return wrapper_async if asyncio.iscoroutinefunction(func) else wrapper_sync def log_error_and_ignore(exception=Exception, log_handler=None): """A decorator that logs uncaught exceptions ignoring them otherwise. CancelledError is not handled. """ if log_handler is None: log_handler = logger.exception def decorator(coro): @functools.wraps(coro) async def wrapper_async(*args, **kwargs): try: return await coro(*args, **kwargs) except asyncio.CancelledError: raise except exception as e: log_handler( "Ignoring exception from %s: %s", getattr(coro, "__qualname__", "coro"), e, ) @functools.wraps(coro) def wrapper_sync(*args, **kwargs): try: return coro(*args, **kwargs) except exception as e: log_handler( "Ignoring exception from %s: %s", getattr(coro, "__qualname__", "coro"), e, ) if asyncio.iscoroutinefunction(coro): return wrapper_async else: return wrapper_sync return decorator def abort_agent_on(exception, abort=fail_agent_service): """Abort the agent service on *exception*.""" def decorator(coro): @functools.wraps(coro) async def wrapper(*args, **kwargs): try: return await coro(*args, **kwargs) except exception as e: logger.exception(e) # do not silently stop the current task but abort() return wrapper return decorator def snake_case(string): """PascalCase to snake_case""" return re.sub("([a-z])([A-Z])", r"\1_\2", string).lower() CHUNK_SIZE_SQL_QUERY = 200 def get_results_iterable_expression( expr, iterable, *args, exec_expr_with_empty_iter=False ): """ Get iterator over results of sql expression expr. Given iterable will be split for chunks and we will return iterator containing results of all split queries. Useful for sql selects with in_() in order to avoid too many sql variables error. If exec_expr_with_empty_iter is True and iterable is None(empty) we will process expression once, passing here chunk=None expr(None, *args) :param expr: :param iterable: :param exec_expr_with_empty_iter: if iterable is None(empty) process given expression once, passing here chunk=None expr(None, *args) :return: """ if not iterable and exec_expr_with_empty_iter: chunks = [None] else: chunks = split_for_chunk(iterable, chunk_size=CHUNK_SIZE_SQL_QUERY) from defence360agent.model import instance with instance.db.transaction(): for chunk in chunks: yield from expr(chunk, *args) def execute_iterable_expression( expr, iterable, *args, chunk_size=CHUNK_SIZE_SQL_QUERY ): """ Get number of results of sql expression expr. Given iterable will be split for chunks and we will return number of results of all split queries. Useful for sql delete with in_() in order to avoid too many sql variables error """ changed = 0 from defence360agent.model import instance with instance.db.transaction(): for chunk in split_for_chunk(iterable, chunk_size=chunk_size): changed += expr(chunk, *args).execute() return changed def encode_filename(file): return os.fsencode(file.replace("\n", "\\n")) + b"\n" def decode_filename(file): return os.fsdecode(file)[:-1].replace("\\n", "\n") def base64_encode_filename(path: Path) -> bytes: return base64.b64encode(os.fsencode(path)) def base64_decode_filename(b64name: bytes) -> Path: return Path(os.fsdecode(base64.b64decode(b64name))) def getpwnam(username): """ Like pwd.getpwnam(username) but returns None instead of raising KeyError. """ try: result = pwd.getpwnam(username) except KeyError: result = None return result def clip(value, low, high): """ Put the specified `value` inside the [`low`, `high`] interval. """ return max(min(value, high), low) def create_task_and_log_exceptions( loop, coro: Callable[..., Awaitable], *args, **kwargs ): """ Use this function in plugin initialization instead of loop.create_task to be able to see the exceptions from the specified coroutine. """ def _log_exception(task): if not task.cancelled() and task.exception() is not None: loop.call_exception_handler( { "message": ( "Unhandled exception during plugin initialization!" ), "exception": task.exception(), "task": task, } ) new_task = loop.create_task(coro(*args, **kwargs)) new_task.add_done_callback(_log_exception) return new_task def make_coro(function): """ Create coroutine from regular function Useful to pass functions to APIs requiring coroutines Note: coroutine will still block event loop in main thread. For most blocking functions, run_in_executor should be considered instead :param function: :return: coroutine running function """ async def coro(*args, **kwargs): return function(*args, **kwargs) return coro COPY_TO_MODSEC_MAXTRIES = 5 _MODSEC_COPY_FAILURE_TIMEOUT = 5 async def log_failed_to_copy_to_modsec(exc, i): if i == COPY_TO_MODSEC_MAXTRIES: log = logger.error else: log = logger.warning log( "Failed to copy data%s to modsec ruleset dir %r, try: %s", f" ({fn})" if (fn := getattr(exc, "filename", None)) else "", exc, i, ) await asyncio.sleep(_MODSEC_COPY_FAILURE_TIMEOUT) @functools.lru_cache(maxsize=1) def is_centos6_or_cloudlinux6(): with suppress(ValueError): if ( OsReleaseInfo.is_centos() or OsReleaseInfo.is_cloudlinux() ) and os_version().startswith("6"): return True return False async def readlines_from_cmd_output( cmd: List[str], *, err_buf_size=100, **popen_kwargs ): """ Start *cmd*, yield its stdout line by line [b'\n'] If *cmd* return nonzero exit status, raise CheckRunError with the last *err_buf_size* lines from stderr. """ async def read_pipe_into(pipe, buf): async for line in pipe: buf.append(line) err_buf = deque(maxlen=err_buf_size) # keep a few last lines proc = await asyncio.create_subprocess_exec( *cmd, start_new_session=True, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, **popen_kwargs, ) try: # note: read data from stderr to avoid deadlock # if stderr pipe buffer is full asyncio.create_task(read_pipe_into(proc.stderr, err_buf)) async for line in proc.stdout: # type: ignore yield line finally: returncode = await proc.wait() if returncode != 0: raise CheckRunError(returncode, cmd, b"", b"".join(err_buf)) async def finally_happened(predicate_coro, *args, max_tries=2, delay=5): """ Retry *predicate_coro(*args)* until it becomes true, but no more than *max_tries* attempts. Sleep for *delay* seconds before the next *predicate_coro()* call. Return whether the predicate became true. """ for attempt in range(1, max_tries + 1): result = await predicate_coro(*args) if not result and attempt < max_tries: await asyncio.sleep(delay) continue return result async def nice_iterator(iterable, chunk_size=10_000): """Yield to the event loop every *chunk_size* iterations.""" # for chunks in zip(*[iter(iterable)]*chunk_size): # yield from chunks # -> SyntaxError: 'yield from' inside async function for i, item in enumerate(iterable, start=1): yield item if (i % chunk_size) == 0: await asyncio.sleep(0) class LazyLock: """ Descriptor object to share async Lock between client objects. Used in order to achieve lazy evaluation of the lock and share state between it's clients. Using asyncio.Lock in client code directly: >>> class Foo: >>> lock = asyncio.Lock() leads to an unclear error ([Errno 9] Bad file descriptor), when trying to move this Lock during demonization process. """ def __init__(self): self._lock = None def __get__(self, instance, owner): if not self._lock: self._lock = asyncio.Lock() return self._lock def _get_system_package_version(regex, output): if m := re.search(regex, output): return m.group(1).strip() @functools.lru_cache(maxsize=1) def _get_cmd_n_regex(): if OsReleaseInfo.is_ubuntu() or OsReleaseInfo.is_debian(): return (["dpkg-query", "-l"], r"(?m)^ii\s+{}\s+(\S+).*") else: return (["rpm", "-q"], r"{}-([\d\.]*-\d*)") class FirewallDisabledException(Exception): """Exception in case of using firewall api, when it's disabled""" def check_disabled_firewall(func): @wraps(func) async def wrapper(*args, **kwargs): if os.path.exists("/var/imunify360/firewall_disabled"): raise FirewallDisabledException( "Not available in the current build" ) return await func(*args, **kwargs) return wrapper async def system_packages_info(packages) -> dict: """ Retrieves the version of the specified system packages using a command and regex specific to the current system. Parameters: packages (Set[str]): A set of package names to retrieve version for. Returns: A dictionary mapping package names to their corresponding version strings, or None if the package is not installed or version information cannot be retrieved. """ cmd, version_regexp = _get_cmd_n_regex() output = await safe_run_with_timeout( cmd + list(packages), timeout=30, check_returncode=False ) return { package_name: _get_system_package_version( version_regexp.format(package_name), output ) for package_name in packages } async def safe_run_with_timeout(command, timeout, log=logger.error, **kwargs): try: return await asyncio.wait_for( safe_run(command, **kwargs), timeout=timeout ) except asyncio.TimeoutError: log("Command %s failed: Timeout occurred", command) return "" def batched(iterable, n: int): # backported from Python 3.12, except it yields a list instead of a tuple # https://docs.python.org/3.12/library/itertools.html#itertools.batched # # batched('ABCDEFG', 3) → ABC DEF G if n < 1: raise ValueError("n must be at least one") it = iter(iterable) while batch := list(islice(it, n)): yield batch def batched_dict(d: Dict[Any, Any], n: int): for batch in batched(d, n): yield {k: d[k] for k in batch} @functools.lru_cache(maxsize=1) def is_cloudways(): try: hostname = _subprocess.check_output( ["hostname", "-f"], text=True ).strip() _is_cloudways = hostname.endswith( (".cloudwaysapps.com", ".cloudwaysstagingapps.com") ) if not _is_cloudways and Path("/usr/local/sbin/apm").exists(): result = _subprocess.check_output( ["/usr/local/sbin/apm", "info"], text=True ) if "Cloudways" in result: _is_cloudways = True return _is_cloudways except Exception as e: logger.error("Error while checking environment: %s", e) return False