137 lines
5.6 KiB
Python
137 lines
5.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
import datetime
|
|
import logging
|
|
from multiprocessing.context import TimeoutError
|
|
from multiprocessing.pool import ThreadPool
|
|
from threading import Lock
|
|
from typing import Union, Callable, TypeVar, Generic, Set, List
|
|
|
|
from backend.models import Recorder
|
|
from backend.tools.simple_state_checker import check_capture_agent_state, ping_capture_agent
|
|
|
|
logger = logging.getLogger("mal.cron.recorder_state")

# Module-level job registry. StateChecker keeps its own per-instance lock and
# job set, so these two names are not used by the class below.
# NOTE(review): presumably still referenced by other modules -- confirm before removing.
recorder_jobs_lock = Lock()
recorder_jobs = set()

# Default number of worker threads used for one round of checker calls.
NUM_THREADS = 8

T = TypeVar('T')


class StateChecker(Generic[T]):
    """Registry of objects whose state is checked by one or more checker functions.

    Objects are registered with :meth:`add_object_to_state_check`. A call to
    :meth:`check_object_state` runs every checker function against every
    registered object in a thread pool and merges the outcome into a dict keyed
    by object name. The latest outcome per name (plus the one before it) is
    kept in ``self.state_results``.

    A checker function is called with a single registered object and must
    return an indexable result where ``[0]`` is truthy on success, ``[1]`` is a
    message and ``[2]`` is the key of the object in the state dict (the
    object's ``name``).
    """

    def __init__(self, state_checker_func: Union[Callable, List[Callable]], type_to_check: T, type_name=None,
                 threads=NUM_THREADS):
        """Create a checker.

        Args:
            state_checker_func: A single checker callable or a list of them;
                a list is executed in order.
            type_to_check: The class of the objects being checked. If it
                provides ``get_by_identifier``, objects may also be added or
                removed by integer id.
            type_name: Name used in log messages; defaults to
                ``type_to_check.__name__``.
            threads: Size of the thread pool used per checker function.
        """
        self.num_threads = threads
        self.lock = Lock()  # guards self.jobs
        self.jobs: Set[T] = set()
        self.checker_func = state_checker_func
        self.checker_type = type_to_check
        self.update_state_lock = Lock()  # guards self.state_results
        self.state_results = {}
        self.type_name = type_name if type_name is not None else self.checker_type.__name__

    @staticmethod
    def _object_name(obj) -> str:
        """Return ``obj.name`` when present, otherwise ``str(obj)``."""
        return obj.name if hasattr(obj, 'name') else str(obj)

    def add_object_to_state_check(self, object_to_check: Union[int, T]):
        """Register an object (or the object with the given integer id) for state checks.

        Nothing is registered, and a message is logged, when an id is given but
        the checked type has no ``get_by_identifier``, or when the lookup
        yields ``None`` / ``None`` is passed directly.
        """
        identifier = None
        if isinstance(object_to_check, int):
            if not hasattr(self.checker_type, 'get_by_identifier'):
                logger.error(
                    'Can\'t add object to state check, as >get_by_identifier< not defined on checker_type ({})!'.format(
                        str(self.checker_type)))
                return
            # Keep the id around: after the lookup the variable holds the
            # object (or None), and the warning below must report the id.
            identifier = object_to_check
            object_to_check = self.checker_type.get_by_identifier(identifier)
        if object_to_check is None:
            logger.warning(
                "Could not add object ({}) to state check, as specified >id ({})< could not be found / object is None".format(
                    self.type_name, identifier))
            return
        name = self._object_name(object_to_check)
        with self.lock:
            logger.debug("Adding {} to object ({}) to state check".format(self.type_name, name))
            self.jobs.add(object_to_check)

    def remove_recorder_from_state_check(self, object_to_check: Union[int, T]):
        """Unregister an object (or the object with the given integer id).

        Raises:
            KeyError: If the object is not currently registered. The job lock
                is released before the exception propagates.
        """
        if isinstance(object_to_check, int):
            if not hasattr(self.checker_type, 'get_by_identifier'):
                logger.error(
                    'Can\'t remove object from state check, as >get_by_identifier< not defined on checker_type ({})!'.format(
                        str(self.checker_type)))
                return
            object_to_check = self.checker_type.get_by_identifier(object_to_check)
        if object_to_check is None:
            logger.warning(
                "Could not remove object ({}) from state check, as specified id could not be found / object is None".format(
                    self.type_name))
            return
        name = self._object_name(object_to_check)
        # `with` guarantees the lock is released even when remove() raises.
        with self.lock:
            logger.debug("Removing {} from object ({}) to state check".format(self.type_name, name))
            self.jobs.remove(object_to_check)

    @staticmethod
    def _merge_result(result, object_states: dict):
        """Fold one checker result ``(ok, msg, key, ...)`` into ``object_states``."""
        ok, msg, key = result[0], result[1], result[2]
        state = object_states[key]
        if ok:
            # Drop the initial placeholder so it does not end up in the joined message.
            if state.get('msg', "") == "unknown state!":
                del state['msg']
            # NOTE(review): a success sets state_ok to True even if an earlier
            # checker reported a failure for the same object, and a later
            # failure (else branch) only replaces the message -- confirm that
            # "any checker succeeded" is the intended meaning of state_ok.
            object_states[key] = {
                'msg': ", ".join([s for s in [state.get('msg', None), msg] if s]),
                'state_ok': True}
        else:
            state['msg'] = msg

    def execute_checker_func(self, func, jobs: List[T], object_states: dict, timeout=12) -> dict:
        """Run ``func`` on every job in a thread pool and merge the results.

        Args:
            func: Checker callable, invoked as ``func(job)``.
            jobs: Objects to check.
            object_states: State dict keyed by object name; updated in place.
            timeout: Seconds to wait for each individual result.

        Returns:
            The same ``object_states`` dict. A job whose result does not arrive
            within ``timeout`` is logged and leaves its entry untouched; the
            results of all other jobs are still merged.
        """
        with ThreadPool(self.num_threads) as pool:
            pending = [pool.apply_async(func, (job,)) for job in jobs]
            for async_result in pending:
                # Collect per result so that one slow job cannot discard the
                # results of the jobs that did finish.
                try:
                    result = async_result.get(timeout=timeout)
                except TimeoutError as e:
                    logger.error("Timeout while performing state check func! {}".format(e))
                    continue
                self._merge_result(result, object_states)

        return object_states

    def check_object_state(self) -> dict:
        """Check all registered objects and return ``{name: {'state_ok': bool, 'msg': str}}``.

        Returns an empty dict when nothing is registered. The result is also
        recorded in ``self.state_results`` via :meth:`update_state_dict`.
        """
        logger.info("checking object ({}) state...".format(self.type_name))
        # Work on a snapshot so the lock is not held while the checks run.
        with self.lock:
            jobs = list(self.jobs)

        if not jobs:
            logger.info("No objects ({}) to check... returning".format(self.type_name))
            return {}
        logger.info("checking state of {} recorders".format(len(jobs)))

        object_states = {self._object_name(j): {'state_ok': False, 'msg': 'unknown state!'} for j in jobs}

        if isinstance(self.checker_func, list):
            checker_funcs = self.checker_func
        else:
            checker_funcs = [self.checker_func]
        for checker in checker_funcs:
            self.execute_checker_func(checker, jobs, object_states)

        self.update_state_dict(object_states)

        return object_states

    def update_state_dict(self, object_states: dict):
        """Store ``object_states`` in ``self.state_results`` with a UTC time stamp.

        Every stored entry is a copy of the given state plus ``'time_stamp'``.
        When an entry for the same name already exists, its ``state_ok``,
        ``msg`` and ``time_stamp`` are kept under ``'previous'``.
        """
        time_stamp = datetime.datetime.now(datetime.timezone.utc).strftime("%d.%m.%Y - %H:%M:%S %Z")
        with self.update_state_lock:
            for name, state in object_states.items():
                entry = {**state, 'time_stamp': time_stamp}
                if name in self.state_results:
                    last = self.state_results[name]
                    entry['previous'] = {'state_ok': last['state_ok'],
                                         'msg': last.get('msg', None),
                                         'time_stamp': last.get('time_stamp', None)}
                self.state_results[name] = entry

    def get_current_state(self):
        """Run a fresh check via :meth:`check_object_state` and return its result."""
        return self.check_object_state()
|
|
|
|
|
|
# Shared checker instance for Recorder objects; both checker functions are run
# in the listed order on every check.
recorder_checker = StateChecker(
    state_checker_func=[check_capture_agent_state, ping_capture_agent],
    type_to_check=Recorder,
)
|