#!/usr/bin/python3

#	cve-manager : CVE management tool
#	Copyright (C) 2017-2025 Alexey Appolonov
#
#	This program is free software: you can redistribute it and/or modify
#	it under the terms of the GNU General Public License as published by
#	the Free Software Foundation, either version 3 of the License, or
#	(at your option) any later version.
#
#	This program is distributed in the hope that it will be useful,
#	but WITHOUT ANY WARRANTY; without even the implied warranty of
#	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#	GNU General Public License for more details.
#
#	You should have received a copy of the GNU General Public License
#	along with this program.  If not, see <http://www.gnu.org/licenses/>.

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

import argparse
import sys
import requests
from collections     import defaultdict
from cve_manager.url import ParseURL, URL_API_PATH, URL_SPEC, FMT_GITHUB, \
	FMT_SOURCEFG, FMT_PYTHON

DESCRIPTION = 'Collect homepage URLs of the specified projects'
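# Supported hosts: host id -> (API root, key paths within the returned JSON
# document that may hold the project's homepage URL)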
HOSTS = {
	FMT_GITHUB: ('api.github.com/repos',
		('homepage',)),
	FMT_SOURCEFG: ('sourceforge.net/rest/p',
		('external_homepage', 'moved_to_url',)),
	FMT_PYTHON: ('pypi.python.org/pypi',
		('info/home_page',)),
	}

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Parsing the arguments

argparser = argparse.ArgumentParser(description=DESCRIPTION)
argparser.add_argument(
	'-u', '--urls',
	metavar='URLs', type=str, nargs='+', required=True,
	help='URLs that specify the projects'
	)
argparser.add_argument(
	'-l', '--lim',
	metavar='CONSECUTIVE_FAILS_LIMIT', type=int, default=-1,
	help='Maximum number of consecutive failed requests to a host that is '
	'tolerated before that host is excluded from further requests '
	'(non-positive values disable the limit)'
	)
args = argparser.parse_args()
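
# Example invocation (the script name is a placeholder):
#	./homepages.py --urls https://github.com/<owner>/<project> --lim 3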

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# If requests to some supported host consecutively fail the specified number
# of times, that host is excluded from subsequent processing

count_consecutive_failed_requests = {k: 0 for k in HOSTS}
blocked_hosts = {k: False for k in HOSTS}


def BlockHostIfLimitReached(host_id):
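	# Register one more consecutive failure for the host; once the --lim
	# threshold is exceeded, the host is blocked for further requests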

	count_consecutive_failed_requests[host_id] += 1

	if args.lim > 0 and count_consecutive_failed_requests[host_id] > args.lim:
		blocked_hosts[host_id] = True


def ResetCounterIfHostNotBlocked(host_id):
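	# A successful request resets the failure counter; blocked hosts stay
	# blocked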

	if not blocked_hosts[host_id]:
		count_consecutive_failed_requests[host_id] = 0

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

def GetProjectApiPath(url):
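	# Returns (host_id, project_api_path, params, warning); an empty
	# project_api_path means the URL is unsupported or could not be parsed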

	parsed_url = ParseURL(url)
	if not parsed_url:
		return -1, '', '', ''

	host_id = parsed_url.get(URL_SPEC)
	if host_id is None:
		return -1, '', '', ''

	host_api_path, params = HOSTS.get(host_id, (None, None))
	if not host_api_path:
		return -1, '', '', ''

	project_path = parsed_url.get(URL_API_PATH)
	if not project_path:
		return -1, '', '', f'Can\'t get the path to form the query, URL: "{url}"'

	project_api_path = f'https://{host_api_path}' + \
		('/' if project_path[0] != '/' else '') + project_path

	return host_id, project_api_path, params, ''

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

def ParseUrls(urls):
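	# Groups the URLs by (host_id, project_api_path, params) so that every
	# API endpoint is queried only once; unparsable URLs go into res_err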

	res = defaultdict(set)
	res_err = set()
	warn = set()

	for url in urls:

		host_id, project_api_path, params, w = GetProjectApiPath(url)
		if not project_api_path:
			res_err.add(url)
			if w:
				warn.add(w)
			continue

		res[(host_id, project_api_path, params)].add(url)

	return res, res_err, warn

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

def GetHomePage(host_id, api_path, param, extracted_urls):
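	# Queries api_path, extracts the homepage addressed by 'param' from the
	# JSON response and, when that homepage lives on another supported host,
	# follows it recursively; returns (discovered URLs, error message or None)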

	try:
		resp = requests.get(api_path, timeout=(3.05, 15))
		if resp.status_code != 200:
			err = f'Response {resp.status_code} from {api_path}'
			if resp.status_code == 404:
				# '!' marks a permanent failure (the project is gone)
				return {'!'}, err
			BlockHostIfLimitReached(host_id)
			# '?' marks a failure that may be transient
			return {'?'}, err
		doc = resp.json()
	except Exception as exc:
		err = f"Can't get JSON doc, {exc}"
		return set(), err

	# Walk the slash-separated key path, e.g. 'info/home_page' stands for
	# doc['info']['home_page']
	res = doc
	for k in param.split('/'):
		res = res.get(k) if isinstance(res, dict) else None
		if not res:
			err = f'Can\'t get value of the "{k}" key of the JSON document' \
				if res is None else None
			return set(), err
	home_page = res

	# If there was a redirect, save the URL to which the redirect occurred
	# ('html_url' is present in GitHub API responses; .get() is safe elsewhere)
	if resp.history and all(r.is_redirect for r in resp.history):
		redirect_url = doc.get('html_url')
		if redirect_url:
			extracted_urls.add(redirect_url)

	# Prevent an endless loop
	if home_page in extracted_urls:
		return set(), None

	# Don't need a homepage that resolves back to the same project's API path
	new_host_id, new_api_path, new_params, err = GetProjectApiPath(home_page)
	if err or api_path == new_api_path:
		return set(), err if err else None

	extracted_urls.add(home_page)
	ResetCounterIfHostNotBlocked(host_id)

	if not new_api_path:
		return extracted_urls, None

	extracted_urls_next = set()

	for new_param in new_params:
		next_urls, err = \
			GetHomePage(new_host_id, new_api_path, new_param, extracted_urls)
		if err:
			return extracted_urls, err
		# Accumulate the results so that URLs found for earlier parameters
		# are not overwritten by later ones
		extracted_urls_next |= next_urls

	return extracted_urls | extracted_urls_next, None

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

if __name__ == '__main__':
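	# Group the input URLs per API endpoint, collect the homepage URLs for
	# each of them, then print every initial URL with its related URLs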

	parsed_urls, parsed_with_err_urls, warn = ParseUrls(args.urls)

	res_urls = defaultdict(set)
	skipped = set()

	for coordinates, initial_urls in parsed_urls.items():
		host_id, api_path, params = coordinates
		for param in params:
			if blocked_hosts[host_id]:
				skipped |= initial_urls
				continue
			related_urls, err = GetHomePage(host_id, api_path, param, set())
			if err:
				warn.add(err + ', URL(s): ' + ', '.join(initial_urls))
			if related_urls:
				for url in initial_urls:
					res_urls[url] |= related_urls

	# Mark URLs without results: '?' if the host was blocked and the URL was
	# skipped, '-' if no homepage was found, '!' if the URL was unparsable
	for initial_urls in parsed_urls.values():
		for url in initial_urls:
			if not res_urls.get(url):
				res_urls[url].add('?' if url in skipped else '-')

	for url in parsed_with_err_urls:
		res_urls[url] = {'!'}

	for url in args.urls:
		related_urls = res_urls.get(url)
		if related_urls:
			print(f'{url} >> {" ".join(sorted(related_urls))}')

	for w in warn:
		print(f'[WARNING: {w}]')

	sys.exit(0)
