#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2025 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

import os
import argparse
from datetime           import datetime
from re                 import search as re_search
from ax.datetime        import DateTimeNow
from ax.filesystem      import PrepareDir
from cve_backup.helpers import MySQLDump, MySQLRestore
from cve_manager.desc   import BACKUP
from cve_manager.common import NewArgParser, Init
from cve_manager.conf   import DB_CON_SEC, COMMON_SEC
from cve_manager.db     import DB

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Parsing the arguments

# The tool-specific options are declared on a plain ArgumentParser, which is
# then passed to NewArgParser as its base (ptype='m') to build the final
# parser that is used for parsing the command line.
argparser = argparse.ArgumentParser(description=BACKUP)

# -b [FILE_NAME]: make a backup; a bare "-b" yields '' (auto-generated name)
argparser.add_argument(
	'-b', '--backup',
	type=str,
	nargs='?',
	const='',
	metavar='FILE_NAME',
	help=(
		'Make a new backup of a VUL DB and save it into a file with a given '
		'name or save it in cve-manager home dir with an automatically '
		'generated name'
	),
)

# -s N_BACKUPS: limit on the number of kept backup files (used with "-b")
argparser.add_argument(
	'-s', '--store',
	type=int,
	metavar='N_BACKUPS',
	help=(
		'Max number of stored backup files (older backup files will be '
		'removed, this parameter is used only with the "--backup" flag and '
		'must be greater than 0)'
	),
)

# -r [FILE_NAME]: restore; a bare "-r" yields '' (use the newest backup)
argparser.add_argument(
	'-r', '--restore',
	type=str,
	nargs='?',
	const='',
	metavar='FILE_NAME',
	help=(
		'Restore VUL DB from a given *.sql backup file or use a latest one '
		'from cve-manager home dir'
	),
)

argparser = NewArgParser(base=argparser, ptype='m')
args = argparser.parse_args()

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Generate an absolute path for a backup file

def GenFilePath(cve_manager_home):
	"""Return a path for a new auto-named backup file in cve_manager_home.

	The file name is 'bak<DateTimeNow()>.sql'. If a file with that name
	already exists, the name is regenerated until DateTimeNow() yields a
	value that is not taken yet.

	:param cve_manager_home: dir to put the backup file into
	:return: os.path.join(cve_manager_home, <generated file name>)
	"""

	import time

	while True:
		file_path = os.path.join(cve_manager_home, f'bak{DateTimeNow()}.sql')
		if not os.path.exists(file_path):
			return file_path
		# The name only changes when DateTimeNow() does (presumably once a
		# second, judging by the name pattern BackupFiles expects), so pause
		# briefly instead of busy-spinning the CPU until then
		time.sleep(0.1)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Form an absolute path for a file with a given name located either in the
# current dir or in a specified dir (optionally requiring that the file exists
# or that its dir is writable)

def CustomFilePath(file_name, cve_manager_home, check_existence=False,
		check_sub_dir_wperm=False):
	"""Resolve a relative file name against the current dir or the home dir.

	Candidates are tried in order: <current dir>/<file_name>, then
	<cve_manager_home>/<file_name>; the first one that passes every enabled
	check is returned.

	:param file_name: relative file name (may contain sub dirs)
	:param cve_manager_home: second dir to look in
	:param check_existence: require the candidate to be an existing regular
	                        file
	:param check_sub_dir_wperm: require the dir that would contain the
	                            candidate file to be writable
	:return: the first matching candidate path, or file_name unchanged if no
	         candidate matches
	"""

	for base_dir in (os.getcwd(), cve_manager_home):
		abs_file_path = os.path.join(base_dir, file_name)
		if check_existence and not os.path.isfile(abs_file_path):
			continue
		# Checking the file's own parent dir rather than base_dir, because
		# file_name may include sub dirs (e.g. "old/db.sql"); BackupDB checks
		# this same dir for write access before dumping
		if check_sub_dir_wperm and \
				not os.access(os.path.dirname(abs_file_path), os.W_OK):
			continue
		return abs_file_path

	return file_name

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Form a list of (<file_path>, <time_stamp>) pairs for backup files stored in a
# specified dir; the list is sorted by time stamps in ascending order (oldest
# backup first)

def BackupFiles(cve_manager_home):
	"""Return (path, time stamp) pairs for backup files found in a dir.

	A backup file is a regular file whose name looks like
	'bak<YYYY>-<MM>-<DD>_<hh>:<mm>:<ss>.sql'. Names that match the pattern
	but do not encode a valid date/time (e.g. month 13) are skipped instead
	of aborting the whole listing with a ValueError.

	:param cve_manager_home: dir to scan (non-recursively)
	:return: list of (os.path.join(cve_manager_home, name), datetime) pairs
	         sorted from the oldest backup to the newest one
	"""

	backups = []

	for name in os.listdir(cve_manager_home):
		if not re_search(r'^bak\d+-\d+-\d+_\d+:\d+:\d+\.sql$', name):
			continue
		path = os.path.join(cve_manager_home, name)
		if not os.path.isfile(path):
			continue
		try:
			time_stamp = datetime.strptime(name, 'bak%Y-%m-%d_%H:%M:%S.sql')
		except ValueError:
			# Matches the pattern but isn't a real date/time
			continue
		backups.append((path, time_stamp))

	# Sorting by the time stamp; the path breaks ties so that the result does
	# not depend on the order in which os.listdir returns entries
	return sorted(backups, key=lambda pair: (pair[1], pair[0]))

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Backing up the DB to an SQL file

def BackupDB(printer, cve_manager_home):
	"""Dump the configured MySQL database into an SQL file.

	Uses the module-level ``mysql_config`` (assigned in the ``__main__``
	section) and the parsed command-line ``args``. The target file is the
	"--backup FILE_NAME" value (an absolute path, or a name resolved against
	the current dir and then cve_manager_home) or, for a bare "--backup", an
	auto-named file in cve_manager_home.

	:param printer: printer object created by ``Init``
	:param cve_manager_home: cve-manager home dir; with "--store" the
	                         backups found there by BackupFiles are pruned
	:return: False if the DB connection check fails or the target dir isn't
	         writable, otherwise the value of ``printer.BiStatus`` for the
	         dump result
	"""

	printer.LineBegin(f'Backuping "{mysql_config["database"]}" database')

	# Checking the DB connection
	db = DB(printer)
	if not db.Connect(mysql_config) or not db.Disconnect():
		return False

	# Removing excessive backups: keep at most (args.store - 1) old ones so
	# that, together with the backup being made, at most args.store remain.
	# The count is clamped at zero: a negative slice bound such as
	# backups[:-2] would otherwise still delete files when there are fewer
	# backups than the limit.
	if args.store:
		backups = BackupFiles(cve_manager_home)
		n_excessive = len(backups) - args.store + 1
		for path, _ in backups[:max(n_excessive, 0)]:
			os.remove(path)
			printer.LineAddExtra(f'"{path}" file was removed')

	# If a file name for a backup file is given
	if args.backup:
		file_path = args.backup if os.path.isabs(args.backup) else \
			CustomFilePath(args.backup, cve_manager_home, check_sub_dir_wperm=True)
	else:
		file_path = GenFilePath(cve_manager_home)

	printer.LineCat(f' to "{file_path}"')

	# Checking a write permission
	file_subdir = os.path.dirname(file_path)
	if not os.access(file_subdir, os.W_OK):
		printer.Err('Permission denied')
		return False

	err = MySQLDump(file_path, mysql_config)

	# A falsy err is treated as success
	return printer.BiStatus(not err, err)

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Restoring the DB from an SQL file

def RestoreDB(printer, cve_manager_home):
	"""Recreate the configured MySQL database and load an SQL dump into it.

	The dump is taken from the "--restore FILE_NAME" value (an absolute path,
	or a name looked up as an existing file in the current dir and then in
	cve_manager_home) or, if no name is given, from the newest backup that
	BackupFiles finds in cve_manager_home.

	:param printer: printer object created by ``Init``
	:param cve_manager_home: cve-manager home dir
	:return: False if no source file is found or the DB can't be recreated,
	         otherwise the value of ``printer.BiStatus`` for the restore
	"""

	db_name = mysql_config['database']
	printer.LineBegin(f'Restoring "{db_name}" database')

	source = args.restore
	if not source:
		# No file name given - falling back to the newest backup
		found = BackupFiles(cve_manager_home)
		if len(found) == 0:
			printer.Err('Can\'t find any backup files')
			return False
		source = found[-1][0]
	elif not source.startswith('/'):
		source = CustomFilePath(source, cve_manager_home, check_existence=True)

	printer.LineCat(f' from "{source}"')

	if os.path.isfile(source):
		# The DB is recreated (Connect with recreate_db=True) before the dump
		# is loaded into it
		db = DB(printer)
		if db.Connect(mysql_config, recreate_db=True) and db.Disconnect():
			err = MySQLRestore(source, mysql_config)
			return printer.BiStatus(not err, err)
		return False

	printer.Err(f'Can\'t find "{source}" file')
	return False

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':

	# sys.exit instead of the site-provided exit(), which is meant for the
	# interactive shell and is missing when Python runs with -S
	import sys

	# An action is required; "--store" is accepted only together with
	# "--backup" and must be positive
	no_action = args.backup is None and args.restore is None
	bad_store = args.store is not None and \
		(args.backup is None or args.store < 1)
	if no_action or bad_store:
		argparser.print_help()
		sys.exit(1)

	# Initialising the printer and reading the configuration file;
	# mysql_config stays module-level because BackupDB/RestoreDB read it
	printer, conf = Init(args)
	mysql_config, common_params = conf.Get([DB_CON_SEC, COMMON_SEC])

	# Preparing the home dir
	cve_manager_home, err = PrepareDir(common_params['download'])
	if not cve_manager_home:
		printer.Err(err)
		sys.exit(1)

	# Backing up or restoring the DB ("--backup" takes precedence if both
	# flags are given)
	if args.backup is not None:
		ok = BackupDB(printer, cve_manager_home)
	else:
		ok = RestoreDB(printer, cve_manager_home)

	sys.exit(0 if ok else 1)
