#!/usr/bin/env python3
# vim: noexpandtab shiftwidth=8 softtabstop=0


###############################################################################
#
# Helper script to try out our DEF CON CTF services
# Based on our dc2020q version (2020 online qualifier)
#
# Order of the Overflow
# https://oooverflow.io/
#
###############################################################################


# Docker image tag template: "%s" is filled with the service_name from info.yml
# (the interaction image uses the same tag with "-interaction" appended).
IMAGE_FMT = "dc2020q:%s"

# Default tolerance for the short-read sanity check in test_service(); a service
# can override it with the shortread_allowed_diff key of its info.yml.
# A negative value (here or in info.yml) disables the check.
SHORTREAD_ALLOWED_DIFF = 2  # You can have this number of remaining processes (override: shortread_allowed_diff: -1)

import argparse
import concurrent.futures
import json
import logging
import os
import re
import shlex
import socket
import subprocess
import sys
import time
import traceback
import urllib.request

import yaml

# Logging: one "OOO" logger, DEBUG by default (the --log-level option can change
# it later).  coloredlogs is optional and only used when it is installed.
logging.basicConfig()
_LOG = logging.getLogger("OOO")
_LOG.setLevel("DEBUG")
try:
	import coloredlogs
	coloredlogs.install(logger=_LOG, level=_LOG.level)
except ImportError:
	pass

# Directory expected to hold info.yml, service/ and interaction/ (--use-cwd
# overrides it).  Made absolute because __file__ can be a bare relative name,
# in which case dirname() is "" and a path formatted as "<service_dir>/service"
# would silently become "/service".
service_dir = os.path.dirname(os.path.abspath(__file__))

dsystem = os.system  # But see cmdline options
def system_without_stdout(cmd):
	"""Run cmd through the shell while keeping its stdout off the terminal.

	Quiet stand-in for os.system: stdout is captured and only reported (in a
	warning) when the command finishes with a non-zero return code; stderr is
	not redirected.  Returns the return code of the process.
	"""
	proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
	captured = proc.communicate()[0]
	status = proc.returncode
	if status:
		_LOG.warning("Command %s failed (%d). Stdout was: %s", cmd, status, captured)
	return status


def get_healthcheck_info():
	"""Extract the healthcheck settings from service_conf (the parsed info.yml).

	Looks for exactly one healthcheck_<protocol> key, plus the optional
	healthcheck_tcp_send key.

	Returns None when no healthcheck_<protocol> key is present, otherwise a
	(protocol, regex, tcp_send) tuple where:
	  - protocol is the key name without its "healthcheck_" prefix,
	  - regex is the compiled expected-reply pattern (a bytes pattern, or an
	    empty str pattern when the configured value is empty),
	  - tcp_send is the ASCII-encoded healthcheck_tcp_send value, or None.
	"""
	prefix = 'healthcheck_'
	send_key = 'healthcheck_tcp_send'

	all_keys = [key for key in service_conf if key.startswith(prefix)]

	payload = None
	if send_key in all_keys:
		payload = service_conf[send_key].encode('ascii', 'strict')

	check_keys = [key for key in all_keys if key != send_key]
	if len(check_keys) == 0:
		return None
	assert len(check_keys) == 1, "More than one healthcheck_xxx line?!?"

	(check_key,) = check_keys
	expected = service_conf[check_key]
	# TODO (kept from the original): does the ASCII-encoded pattern match reality?
	pattern = expected.encode('ascii', 'strict') if expected else ""
	return check_key[len(prefix):], re.compile(pattern), payload

def simulate_healthcheck(protocol, regex, tcp_send, host, port):
	"""Probe host:port the way a healthcheck would and match the reply against regex.

	protocol: "tcp" or "http"; any other protocol is not simulated.
	regex: compiled pattern; an empty pattern means "just try connecting".
	tcp_send: optional bytes sent right after connecting (tcp only, must be
		None for http).
	host, port: endpoint to probe; port must be an int.

	Returns True when the check passed, False when it failed (the reply did
	not match, or any exception was raised while probing, e.g. a refused
	connection or a timeout) and None when the protocol has no simulation.
	"""
	_LOG.info("Simulating a %s healthcheck %s:%d -> regex %s", protocol, host, port, repr(regex))
	if protocol not in ("tcp","http"):
		_LOG.warning("TODO: missing %s healthcheck simulation", protocol)
		return None
	try:
		if protocol == 'http':
			assert tcp_send is None
			with urllib.request.urlopen('http://{}:{}/'.format(host,port), timeout=5) as u:
				if u.getcode() != 200:
					_LOG.critical('Got %d %s [!= 200] for %s (info: %s)',
							u.getcode(), u.reason, u.geturl(), u.info())
				else:
					_LOG.debug('Got %d %s for %s',
							u.getcode(), u.reason, u.geturl())
				rdata = u.read()
		else:
			with socket.create_connection((host,port), timeout=5) as c:
				c.settimeout(5)
				if tcp_send is not None:
					_LOG.debug("Sending %s ...", tcp_send.decode('ascii','backslashreplace'))
					c.sendall(tcp_send)
				if regex.pattern:  # Empty healthcheck_tcp => just try connecting
					rdata = c.recv(1024)  # TODO: loop over received lines instead
					_LOG.debug("TCP healthcheck received: %s", rdata.decode('ascii','backslashreplace'))
		if regex.pattern:  # Empty healthcheck_tcp => just try connecting
			rdata_msgstr = rdata.decode('ascii','backslashreplace')
			m = regex.search(rdata)
			if m:
				_LOG.debug("Matched: %s", str(m))
			else:
				_LOG.error("Simulated healthcheck failed -- received %s (didn't match %s)", rdata_msgstr, repr(regex))
				return False
		_LOG.info("Simulated healthcheck passed, good!")
		return True
	except Exception as e:
		_LOG.critical("Got an exception while simulating a healthcheck on (%s:%d) -> %s %s", host, port, type(e), str(e))
		# An error while probing is a failed healthcheck: report it as False
		# instead of falling through to None, which means "not simulated".
		return False



def build_service():
	"""Build the service docker image (tagged image_tag) from <service_dir>/service.

	Only logs a warning when there is no service/Dockerfile.  When the
	copy_flag_using_build_arg key is set in info.yml, the flag is handed to the
	build as the THE_FLAG build argument.

	Raises AssertionError if docker build exits with a non-zero status.
	"""
	context_dir = os.path.join(service_dir, "service")
	if not os.path.exists(os.path.join(context_dir, "Dockerfile")):
		_LOG.warning("no dockerfile found for service...")
		return

	_LOG.info("Building service image...")
	build_args = ""
	if service_conf.get('copy_flag_using_build_arg'):
		# shlex.quote keeps a flag containing quotes or other shell metacharacters intact
		build_args = "--build-arg THE_FLAG=%s" % shlex.quote(str(service_conf["flag"]))
	# Same directory as the existence check above (os.path.join also copes with an empty service_dir)
	cmd = "docker build %s -t %s %s" % (build_args, shlex.quote(image_tag), shlex.quote(context_dir))
	# Deliberately not inside an assert statement: python -O strips asserts,
	# which would silently skip the build itself.
	if dsystem(cmd) != 0:
		raise AssertionError("service docker image build failed")

def build_interactions():
	"""Build the interaction docker image (tagged interaction_image_tag) from <service_dir>/interaction.

	Only logs a warning when there is no interaction/Dockerfile.
	Raises AssertionError if docker build exits with a non-zero status.
	"""
	context_dir = os.path.join(service_dir, "interaction")
	if not os.path.exists(os.path.join(context_dir, "Dockerfile")):
		_LOG.warning("no dockerfile found for interactions...")
		return

	_LOG.info("Building interaction image...")
	# Same directory as the existence check above (os.path.join also copes with an empty service_dir)
	cmd = "docker build -t %s %s" % (shlex.quote(interaction_image_tag), shlex.quote(context_dir))
	# Deliberately not inside an assert statement: python -O strips asserts,
	# which would silently skip the build itself.
	if dsystem(cmd) != 0:
		raise AssertionError("interaction docker image build failed")

def _start_container():
	"""Start a fresh detached, auto-removed (--rm) service container named container_tag.

	Any previous container with that name is stopped first.
	Raises AssertionError if docker run exits with a non-zero status.
	"""
	_stop_container()
	cmd = "docker run --name %s --rm -d %s" % (shlex.quote(container_tag), shlex.quote(image_tag))
	# Deliberately not inside an assert statement: python -O strips asserts,
	# which would mean the container is never started.
	if dsystem(cmd) != 0:
		raise AssertionError("service container failed to start")
def _stop_container():
	"""Best-effort kill and removal of the container named container_tag.

	Docker's output is discarded and the result of each command is not checked.
	"""
	for template in (
		"docker kill %s 2>/dev/null >/dev/null | true",
		"docker rm %s 2>/dev/null >/dev/null | true",
	):
		dsystem(template % container_tag)

def launch_service():
	"""(Re)start the service container and return its (ip_address, port).

	Both values are strings read from `docker inspect`: the container's
	IPAddress on the "bridge" network and the port number of the first
	ExposedPorts entry.
	"""
	_LOG.debug("starting container")
	_start_container()
	time.sleep(1)  # presumably gives the container a moment to come up before inspecting it

	inspect_out = subprocess.check_output("docker inspect %s" % container_tag, shell=True)
	container_info = json.loads(inspect_out.decode('utf-8'))[0]

	ip_address = container_info['NetworkSettings']['Networks']['bridge']['IPAddress']
	exposed = list(container_info['Config']['ExposedPorts'].keys())
	port = exposed[0].split("/")[0]  # entries look like "<port>/<proto>"; keep the port part

	_LOG.debug("network endpoint: %s:%s", ip_address, port)
	return ip_address, port

def test_service(ip_address, port, exploit=True, host_net=False):
	"""Run the interaction scripts, a stress test, the short-read check and a simulated healthcheck.

	ip_address, port: endpoint of the service under test (handed to every script).
	exploit: when False, scripts whose basename starts with "exploit" are skipped.
	host_net: run the interaction container with --net=host.

	Exits the process (sys.exit) when an interaction script fails, and raises
	AssertionError when an exploit does not print the flag or when the
	short-read check fails.  Once started, the interaction container is killed
	even when one of the checks fails.
	"""
	_LOG.info("Testing container...")

	_LOG.info("launching interaction container")
	interaction_container = subprocess.check_output("docker run %s --rm -i -d %s" % ("--net=host" if host_net else "", interaction_image_tag), shell=True).strip().decode('utf-8')

	def _run_interaction(script):
		"""Run one script inside the interaction container against the service."""
		if os.path.basename(script).startswith("exploit") and not exploit:
			_LOG.info("skipping exploit checks")
			return
		_LOG.info("launching %s", script)

		try:
			output = subprocess.check_output("docker exec %s %s %s %s" % (interaction_container, script, ip_address, port), shell=True,
					stderr=None if _LOG.getEffectiveLevel() < logging.WARNING else subprocess.STDOUT)
			_LOG.debug("%s -out-> %s", script, output)
		except subprocess.CalledProcessError as e:
			_LOG.critical("ERROR: %s returned %d -- full cmd: %s output: %s", script, e.returncode, e.cmd, e.output)
			sys.exit(e.returncode)

		if os.path.basename(script).startswith("exploit"):
			_LOG.info("... checking for the flag")
			# Explicit raise (not an assert statement) so the check survives python -O
			if service_conf['flag'].encode('utf-8') not in output:
				raise AssertionError("exploit %s did not print the flag" % script)

	def _count_service_procs():
		"""Number of lines `ps aux` prints inside the service container."""
		return len(subprocess.check_output("docker exec %s ps aux" % container_tag, shell=True).splitlines())

	try:
		_LOG.info("launching interaction scripts")
		interaction_files = service_conf['interactions']
		for f in interaction_files:
			_run_interaction(f)

		_LOG.info("STRESS TEST TIME")
		n = 2
		old_level = _LOG.level
		while n <= service_conf['concurrent_connections']:
			_LOG.info("stress testing with %d concurrent connections!", n)
			# Quieten the per-script logging while many scripts run at once
			_LOG.setLevel(max(logging.WARNING, old_level))
			try:
				with concurrent.futures.ThreadPoolExecutor(max_workers=n) as pool:
					results = pool.map(_run_interaction, (interaction_files*n)[:n])
				try:
					# Consuming the results re-raises whatever a worker raised
					for _ in results:
						pass
				except Exception as e:
					_LOG.error('One iteration returns an exception: %s', e)
					_LOG.error(traceback.format_exc())
					sys.exit(1)
			finally:
				# Restore the level even when this round aborts
				_LOG.setLevel(old_level)

			n *= 2

		_LOG.info("SHORT-READ SANITY CHECK")
		allowed = service_conf.get('shortread_allowed_diff', SHORTREAD_ALLOWED_DIFF)
		if SHORTREAD_ALLOWED_DIFF >= 0 and allowed >= 0:
			start_num_procs = _count_service_procs()
			# 128 connections that each send a single newline and then close.
			# Run outside of an assert statement: python -O strips asserts,
			# which would skip the connections altogether.
			short_reads = 'docker run --rm ubuntu bash -ec "for i in {1..128}; do echo > /dev/tcp/%s/%s; done"' % (ip_address, port)
			if os.system(short_reads) != 0:
				raise AssertionError("could not make the short-read connections to %s:%s" % (ip_address, port))
			_LOG.info("waiting for service to clean up after short reads")
			time.sleep(15)
			final_num_procs = _count_service_procs()
			# NOTE(review): the strict '<' lets the process count grow by at most
			# allowed-1 -- confirm this is the intended tolerance.
			if not final_num_procs < (start_num_procs + allowed):
				raise AssertionError("your service did not clean up after short reads -- starting procs = {sp} final={fp}".format(sp=start_num_procs, fp=final_num_procs))
		else:
			_LOG.info("The short-read test is disabled")
	finally:
		# Also reached on sys.exit / AssertionError above, so that a failed
		# check does not leave the interaction container running.
		_LOG.info("stopping interaction container")
		dsystem("docker kill %s" % interaction_container)

	hck = get_healthcheck_info()
	if hck is not None:
		protocol, regex, tcp_send = hck
		simulate_healthcheck(protocol, regex, tcp_send, ip_address, int(port))



def list_public_files():
	"""Print and sanity-check the public_files listed in service_conf.

	Every entry must exist and be a regular file that is not a symlink
	(AssertionError otherwise).  Returns "" after printing a notice when no
	public files are configured, None otherwise.
	"""
	public_files = service_conf.get('public_files')
	if not public_files:
		for line in ("", "", "^^^ \033[36m No Public Files Found \033[0m", "", ""):
			print(line)
		return ""

	_LOG.info("Looking at public files...")
	for path in public_files:
		print("Public file: %s <-- %s" % (os.path.basename(path), path))
		assert os.path.exists(path), "Public file not found: {} -- remember that all public files must be pre-built and checked into git".format(path)
		assert os.path.isfile(path), "Only regular files for the public: {}".format(path)
		assert not os.path.islink(path), "No symlinks for the public: {}".format(path)



# Command-line entry point.
#   (no command)                     build, launch, test, list public files, then offer to keep the container up
#   build                            build both images and list the public files
#   test                             launch the service container and test it
#   test ip port [noexploit]         test a service that is already running at ip:port
#   launch                           build and run the service until ENTER is pressed
#   list_public_files                only check the public files
if __name__ == '__main__':
	parser = argparse.ArgumentParser()
	parser.add_argument("--log-level", metavar='LVL', help="WARNING will also sink docker output. Default: DEBUG")
	parser.add_argument("--use-cwd", action="store_true", help="Use CWD instead of script location for service directory")
	parser.add_argument("--force-color", action="store_true", help="Force color even if not on a TTY")
	parser.add_argument("cmds", metavar='CMD...', nargs=argparse.REMAINDER, help="nothing (or single steps: build, test, launch)")

	args = parser.parse_args()
	if args.force_color:
		# coloredlogs is an optional dependency: warn instead of dying with a
		# NameError when it is not installed.
		try:
			import coloredlogs
		except ImportError:
			_LOG.warning("--force-color ignored: the coloredlogs module is not installed")
		else:
			coloredlogs.install(logger=_LOG, level=_LOG.level, isatty=True)
	if args.log_level:
		_LOG.setLevel(args.log_level)
	if _LOG.getEffectiveLevel() >= logging.WARNING:
		# At WARNING and above, also keep the docker commands' stdout off the terminal
		dsystem = system_without_stdout

	if args.use_cwd:
		service_dir = os.getcwd()

	_LOG.info("USING YAML: %s/info.yml", service_dir)
	if not os.path.exists(os.path.join(service_dir, "info.yml")):
		_LOG.critical("info.yml not found -- either:")
		_LOG.critical("  - copy me to the service directory, or")
		_LOG.critical("  - cd service_directory; %s --use-cwd ...", os.path.abspath(__file__))
		sys.exit(2)
	with open(os.path.join(service_dir, "info.yml")) as yf:
		service_conf = yaml.safe_load(yf)
	service_name = service_conf['service_name']
	_LOG.info("SERVICE ID: %s", service_name)

	# Names shared (as module globals) by the build/launch/test helpers
	image_tag = IMAGE_FMT % service_name
	interaction_image_tag = IMAGE_FMT % service_name + '-interaction'
	container_tag = "running-%s" % service_name

	def _wait_while_running(ip_address, port):
		"""Print how to reach the running service, then block until ENTER is pressed."""
		print("")
		print("SERVICE RUNNING AT: %s %s" % (ip_address, port))
		print("nc %s %s" % (ip_address, port))
		print("%s test %s %s" % (sys.argv[0], ip_address, port))
		print("%s:%s" % (ip_address, port))
		print("")
		print("Press ENTER to stop the container")
		input()

	assert not any(('--' in c) for c in args.cmds)   # XXX: we should really rewrite this thing
	sys.argv = [sys.argv[0]] + args.cmds
	arg = sys.argv[1] if len(sys.argv) >= 2 else ""
	if arg == 'list_public_files':
		list_public_files()
	elif arg == 'build':
		build_service()
		build_interactions()
		list_public_files()
	elif arg == 'test':
		# test
		# test ip port [noexploit]
		if len(sys.argv) == 2:
			try:
				_ip_address, _port = launch_service()
				test_service(_ip_address, _port)
			finally:
				# Like the other modes: don't leave our container behind when a test fails
				_stop_container()
		elif len(sys.argv) in (4,5):
			port = int(sys.argv[3])
			test_exploits = not((len(sys.argv)>=5) and (sys.argv[4] == 'noexploit'))
			test_service(sys.argv[2], port, exploit=test_exploits)
			# NOTE(review): this also stops a local container_tag container (e.g. one
			# started by "launch") after a successful test -- confirm it is intended.
			_stop_container()
		else:
			print("Usage: ... test", file=sys.stderr)
			print("Usage: ... test ip port [noexploit]", file=sys.stderr)
			sys.exit(1)
	elif arg == 'launch':
		build_service()
		try:
			_ip_address, _port = launch_service()
			_wait_while_running(_ip_address, _port)
		finally:
			_stop_container()
	else:
		assert len(sys.argv) == 1, "Unknown command '{}', try --help".format(sys.argv[1])
		# Stay None until the container is actually up, so the cleanup below
		# cannot hit an unbound name (and hide the real error) when the build
		# or the launch fails.
		_ip_address = _port = None
		try:
			build_service()
			build_interactions()
			_ip_address, _port = launch_service()
			test_service(_ip_address, _port)
			list_public_files()
		finally:
			try:
				if _ip_address is not None:
					cont = input('Keep the container running so you can try solving the challenge? [y/n] ')
					if cont and cont.lower() in ('yes','y'):
						_wait_while_running(_ip_address, _port)
			finally:
				# Runs even if input() fails (e.g. EOF on a non-interactive stdin)
				_stop_container()
