#!/usr/bin/env python
"""
Script checking certificate validity.

If a client certificate is valid for less than the configured grace period,
regenerate it and send MQTT messages on a dedicated channel containing the
new certificate and key.
"""
import os
import sys
import subprocess
import logging
from datetime import datetime, timedelta, timezone
from configparser import ConfigParser
import ssl

from paho.mqtt.client import Client as MQTTClient

# Module-level logger so the helpers also work when this file is imported;
# basicConfig() in __main__ configures the root logger it propagates to.
logger = logging.getLogger(__name__)

LOG_LEVELS = {
    'DEBUG': logging.DEBUG,
    'INFO': logging.INFO,
    'WARNING': logging.WARNING,
    'ERROR': logging.ERROR,
    'CRITICAL': logging.CRITICAL,
}

CONFIG_FILE = 'settings.cfg'
# Format of the date printed by `openssl x509 -enddate`,
# e.g. "Jan  5 12:00:00 2025 GMT".
TIMEFORMAT = "%b %d %H:%M:%S %Y %Z"
CERT_MESSAGE_FMT = 'certificate: %s'
KEY_MESSAGE_FMT = 'key: %s'

# User certificate and key naming convention. %s is replaced by username.
# NOTE(review): this was documented as "only works for DER certificate format
# that Pico uses", but the files are named *.pem and are read as UTF-8 text in
# handle_renewal() - confirm which format the generator script writes.
CERT_FMT = '%s_crt.pem'
KEY_FMT = '%s_key.pem'

# Settings, populated from CONFIG_FILE by get_settings().
SSL_SET = None
SSL_CA_CERT = None
SSL_CLIENTS_FOLDER = None
SSL_SERVER_CERT = None
SSL_SERVER_KEY = None
SSL_CLIENT_CERT = None
SSL_CLIENT_KEY = None
MQTT_SERVER = None
MQTT_PORT = None
MQTT_KEEPALIVE = None
MQTT_TOPIC_CERT_RENEWAL = None
MQTT_LOG_FILE = None
MQTT_LOG_LEVEL = None
MQTT_CLIENT_USERNAME = None
GRACE_PERIOD = None
SSL_SCRIPTS_DIR = None


def check_cert_valid(username, grace_period, server=False):
    """
    Check how much longer the certificate of ``username`` stays valid.

    :param username: client name, i.e. a sub-folder of SSL_CLIENTS_FOLDER.
    :param grace_period: number of days before expiry from which on the
        certificate is considered due for renewal.
    :param server: check the server certificate instead of the client one.
    :return: 0 - valid for at least ``grace_period`` more days,
             1 - needs renewal (expires within ``grace_period`` days),
             2 - expired,
             -1 - validity could not be determined (missing file, openssl
                  failure or unparsable output).
    """
    cert_pathname, _ = get_cert_key_paths(username, server)
    if not cert_pathname:
        logger.error("Did not find certificate files for user %s. Wrong certificate format?", username)
        return -1

    logger.debug("Checking certificate for user %s", username)
    logger.debug("Certificate path: %s ", cert_pathname)
    expiry_date = _read_expiry_date(username, cert_pathname)
    if not expiry_date:
        return -1

    try:
        delta = check_time(expiry_date)
    except ValueError:
        logger.error('Invalid expiry date from openssl: %r', expiry_date)
        return -1
    logger.debug("Time delta for the validity of the certificate: %r", delta)

    # check_time() returns "now - expiry" (negative while still valid);
    # flip it so the comparisons below read as "time left".
    remaining = -delta
    if remaining <= timedelta(0):
        logger.debug("SSL certificate for user '%s' expired.", username)
        return 2
    if remaining >= timedelta(days=grace_period):
        logger.debug("SSL certificate for user '%s' does not need renewal", username)
        return 0
    logger.debug("SSL certificate for user '%s' needs renewing", username)
    return 1


def _read_expiry_date(username, cert_pathname):
    """Return the ``notAfter`` date string reported by openssl, or None on failure."""
    openssl_cmd_check = ['openssl', 'x509', '-enddate', '-noout', '-in', cert_pathname]
    try:
        proc = subprocess.run(
            openssl_cmd_check,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=False,
        )
    except OSError as exc:  # e.g. openssl binary not installed
        logger.error('Could not run openssl for user %s: %r', username, exc)
        return None

    if proc.stderr:
        logger.warning('Encountered error checking certificate expiration date for user %s: %r',
                       username, proc.stderr)
    if proc.returncode:
        logger.error('openssl return code non-zero: %d', proc.returncode)
    if not proc.stdout:
        return None

    logger.debug("Stdout from certificate check: %r", proc.stdout)
    # openssl prints a single line such as "notAfter=Jan  5 12:00:00 2025 GMT".
    text = proc.stdout.decode('utf-8', errors='replace').strip()
    _, separator, expiry_date = text.partition('=')
    expiry_date = expiry_date.strip()
    if not separator or not expiry_date:
        logger.error('Invalid data from openssl: %r', proc.stdout)
        return None
    logger.debug('Certificate expires on : %r', expiry_date)
    return expiry_date


def renew_certificate(username):
    """
    Recreate the certificate of ``username`` with the shell script in SSL_SCRIPTS_DIR.

    The script (``client_maker_erg.sh der <username>``) is expected to back up
    the old certificate - that behaviour lives in the script, not here.

    :return: the script's exit status (0 on success); -1 if the script could
        not be started at all.
    """
    ssl_script_path = os.path.join(SSL_SCRIPTS_DIR, 'client_maker_erg.sh')
    openssl_cmd_create = [ssl_script_path, 'der', username]
    try:
        proc = subprocess.run(
            openssl_cmd_create,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=False,
        )
    except OSError as exc:  # missing or non-executable script
        logger.error("Could not run %s for %s: %r", ssl_script_path, username, exc)
        return -1
    if proc.stderr:
        logger.warning("Error creating certificate for %s: %r", username, proc.stderr)
    return proc.returncode


def check_time(date_as_string):
    """
    Return ``now - expiry`` for an openssl ``notAfter`` date string.

    openssl reports certificate dates in GMT, so both sides of the subtraction
    are made timezone-aware UTC; the result is negative while the certificate
    is still valid.

    :raises ValueError: if ``date_as_string`` does not match TIMEFORMAT.
    """
    expiry = datetime.strptime(date_as_string, TIMEFORMAT).replace(tzinfo=timezone.utc)
    return datetime.now(timezone.utc) - expiry


def on_message_callback(client=None, userdata=None, message=None):
    """
    paho ``on_message`` handler for replies on the renewal topic.

    paho invokes it as ``on_message(client, userdata, message)``; the defaults
    keep a bare ``on_message_callback()`` call working. Only the topic and the
    payload size are logged because ``<topic>/#`` also echoes the private key
    this script publishes.
    """
    if message is not None:
        logger.debug("Received message on topic %s (%d bytes)", message.topic, len(message.payload))


def handle_renewal(username):
    """
    Publish the renewed certificate and key of ``username`` over MQTT.

    Both are sent as separate messages on ``<MQTT_TOPIC_CERT_RENEWAL>/<username>``.
    Failures are logged, never raised, so one user cannot abort main().

    TODO: wait for the client's OK/NOK reply on the renewal topic and react to
    it (log on NOK, otherwise log and restart mqtt); on_message_callback
    currently only logs incoming messages.
    """
    cert_pathname, key_pathname = get_cert_key_paths(username)
    if not cert_pathname:
        logger.warning("SSL certificate for user %s does not exist, will not renew!", username)
    if not key_pathname:
        logger.warning("SSL key for user %s does not exist, will not renew!", username)
    if not (cert_pathname and key_pathname):
        return

    try:
        with open(cert_pathname, 'r', encoding='utf-8') as f:
            cert_message = CERT_MESSAGE_FMT % f.read()
        with open(key_pathname, 'r', encoding='utf-8') as f:
            key_message = KEY_MESSAGE_FMT % f.read()
    except (OSError, ValueError) as exc:  # ValueError covers UnicodeDecodeError
        logger.error("Could not read cert/key pair for user %s: %r", username, exc)
        return

    renewal_topic_set = f"{MQTT_TOPIC_CERT_RENEWAL}/#"
    client_topic = f"{MQTT_TOPIC_CERT_RENEWAL}/{username}"

    client = MQTTClient()
    client.on_message = on_message_callback
    try:
        client.tls_set(
            ca_certs=SSL_CA_CERT,
            certfile=SSL_SERVER_CERT,
            keyfile=SSL_SERVER_KEY,
            cert_reqs=ssl.CERT_REQUIRED,
        )
        client.connect(host=MQTT_SERVER, port=MQTT_PORT, keepalive=MQTT_KEEPALIVE)
        # Run the network loop so queued packets are actually written out.
        client.loop_start()
        client.subscribe(renewal_topic_set)
        logger.debug("Sending user '%s' SSL certificate on channel: %s", username, client_topic)
        client.publish(client_topic, cert_message)
        logger.debug("Sending user '%s' SSL key on channel: %s", username, client_topic)
        client.publish(client_topic, key_message)
    except Exception as exc:  # boundary: log any TLS/network/paho failure and move on
        logger.error("Exception sending cert/key pair for user %s: %r", username, exc, exc_info=True)
    finally:
        # DISCONNECT is queued behind the publishes, so they are flushed first;
        # both calls are no-ops returning an error code when not connected.
        client.disconnect()
        client.loop_stop()


def get_settings(config_file):
    """
    Load settings from ``config_file`` into the module-level globals.

    Exits the process with status 1 when the ``[Certificates] ssl`` option
    is false.
    """
    cfg = ConfigParser(interpolation=None)
    cfg.read(config_file)

    global SSL_SET
    global SSL_CA_CERT
    global SSL_CLIENTS_FOLDER
    global SSL_SERVER_CERT
    global SSL_SERVER_KEY
    global SSL_CLIENT_CERT
    global SSL_CLIENT_KEY
    global MQTT_SERVER
    global MQTT_PORT
    global MQTT_KEEPALIVE
    global MQTT_TOPIC_CERT_RENEWAL
    global MQTT_LOG_FILE
    global MQTT_LOG_LEVEL
    global MQTT_CLIENT_USERNAME
    global GRACE_PERIOD
    global SSL_SCRIPTS_DIR

    # SSL certificates:
    SSL_SET = cfg.getboolean('Certificates', 'ssl')
    if SSL_SET:
        SSL_CA_CERT = cfg.get('Certificates', 'ca_crt')
        SSL_SERVER_CERT = cfg.get('Certificates', 'server_cert_file')
        SSL_SERVER_KEY = cfg.get('Certificates', 'server_key_file')
        SSL_CLIENTS_FOLDER = cfg.get('Certificates', 'clients_folder')
        SSL_CLIENT_CERT = cfg.get('Certificates', 'client_cert_file')
        SSL_CLIENT_KEY = cfg.get('Certificates', 'client_key_file')
        MQTT_CLIENT_USERNAME = cfg.get('Certificates', 'client_username')
        SSL_SCRIPTS_DIR = cfg.get('Certificates', 'scripts_dir')
        GRACE_PERIOD = cfg.getint('Certificates', 'grace_period')
    else:
        logger.error(
            "SSL_SET either not set or set to false in config file, exiting"
        )
        sys.exit(1)

    MQTT_SERVER = cfg.get('Mqtt', 'hostname')
    MQTT_PORT = cfg.getint('Mqtt', 'port')
    MQTT_KEEPALIVE = cfg.getint('Mqtt', 'keepalive')
    MQTT_TOPIC_CERT_RENEWAL = cfg.get('Mqtt', 'topic')
    MQTT_LOG_FILE = cfg.get('Mqtt', 'log_file')
    _mqtt_log_level = cfg.get('Mqtt', 'log_level')
    MQTT_LOG_LEVEL = LOG_LEVELS[_mqtt_log_level.upper()]


def get_usernames():
    """Return the names of the immediate sub-folders of SSL_CLIENTS_FOLDER ([] if it is missing)."""
    try:
        _, dirnames, _ = next(os.walk(SSL_CLIENTS_FOLDER))
    except StopIteration:  # os.walk yields nothing for a missing/unreadable folder
        logger.error("Clients folder %s not found or not readable", SSL_CLIENTS_FOLDER)
        return []
    return dirnames


def get_cert_key_paths(username, server=False):
    """
    Return a ``(cert_pathname, key_pathname)`` tuple for ``username``.

    Client files are looked up as ``<SSL_CLIENTS_FOLDER>/<username>/<CERT_FMT|KEY_FMT>``,
    which only works if the folder convention of the bash script generating
    the certificates is followed. With ``server=True`` the configured server
    certificate and key are returned instead. Each path is replaced by None
    when it is unset or the file does not exist.
    """
    if server:
        cert_pathname, key_pathname = SSL_SERVER_CERT, SSL_SERVER_KEY
    else:
        client_path = os.path.join(SSL_CLIENTS_FOLDER, username)
        cert_pathname = os.path.join(client_path, CERT_FMT % username)
        key_pathname = os.path.join(client_path, KEY_FMT % username)

    if not cert_pathname or not os.path.isfile(cert_pathname):
        cert_pathname = None
    if not key_pathname or not os.path.isfile(key_pathname):
        key_pathname = None
    return cert_pathname, key_pathname


def check_server_cert_valid():
    """Placeholder for the server-certificate check; not implemented yet."""
    pass


def main():
    """Check all user certificates and renew those about to expire."""
    # The server certificate is shared, so check it once rather than per user.
    check_server_cert_valid()
    for username in get_usernames():
        validity = check_cert_valid(username, GRACE_PERIOD)
        if validity == 2:
            logger.error(
                "SSL Certificate for user %s expired, manual installation of a new certificate required!",
                username)
        elif validity == 1:
            if renew_certificate(username) == 0:
                handle_renewal(username)
            else:
                logger.error("Renewing certificate for %s failed", username)


if __name__ == '__main__':
    get_settings(CONFIG_FILE)
    logging.basicConfig(filename=MQTT_LOG_FILE, level=MQTT_LOG_LEVEL)
    main()