#!/usr/bin/env python3
# encoding: utf-8
'''
./photosquare.py --demomode --on-nofaces=smartcrop --on-failsquare=grid pics-t/ out/
@author: rsu
'''

import sys, os, logging, math, re, shutil
from PIL import Image, ImageDraw, ImageOps
import dlib
import smartcrop

from pathlib import Path, PurePath
from argparse import ArgumentParser
from argparse import RawDescriptionHelpFormatter

# Nothing is exported: this module is meant to be run as a command-line script.
__all__ = []
# __version__ and __updated__ are reported by --version in main();
# __date__ is not read anywhere else in this file.
__version__ = '0.0.1'
__date__ = '2019-07-07'
__updated__ = '2019-07-20'

# Development switches: when DEBUG or TESTRUN is non-zero, main() re-raises
# unexpected exceptions instead of printing them. PROFILE is not read anywhere
# in this file.
DEBUG = 0
TESTRUN = 0
PROFILE = 0

# Module logger; main() sets the level via logging.basicConfig(--log-level).
_LOGGER = logging.getLogger(__name__)

class CLIError(Exception):
	'''Generic exception to raise and log different fatal errors.

	The message passed in is stored with an "E: " prefix in ``msg`` and is
	what ``str()`` returns; ``args`` holds the raw message.
	'''
	def __init__(self, msg):
		# Initialise the Exception base with the raw message so ``args`` and
		# the standard exception machinery see it.
		super().__init__(msg)
		self.msg = "E: %s" % msg

	def __str__(self):
		return self.msg

	def __unicode__(self):
		# Python 2 leftover; kept so existing callers keep working.
		return self.msg

class ArgFloatRange(object):
	'''Inclusive numeric range usable as an entry of an argparse ``choices`` list.

	argparse validates a value with ``value in choices``; list membership
	compares with ``==``, so ``__eq__`` answers whether the value lies in
	``[start, end]``.
	'''
	def __init__(self, start, end):
		self.start = start
		self.end = end

	def __eq__(self, other):
		try:
			return self.start <= other <= self.end
		except TypeError:
			# Not comparable with the bounds (e.g. a string): let Python fall
			# back to its default comparison instead of raising.
			return NotImplemented

	# Defining __eq__ already makes instances unhashable; state it explicitly.
	__hash__ = None

	def __repr__(self):
		# argparse lists the choices in its "invalid choice" error message;
		# show the range rather than an object address.
		return "{!r}-{!r}".format(self.start, self.end)

def get_images_dir( fpath):
	'''Recursively collect image files below ``fpath``.

	A file qualifies when its name ends (case-insensitively) in .jpg, .jpeg,
	.png, .webp, .tif, .tiff or .gif.

	:param fpath: directory walked with ``os.walk``.
	:returns: list of dicts with keys ``"fullpath"`` (walk root joined with the
	    file name) and ``"filename"`` (bare file name).
	'''
	# Escaped dot + fullmatch: the extension must really be the suffix, so names
	# such as "notes_png.txt" or "photo.jpg.bak" are not picked up.
	fpatt = re.compile( r'.*\.(?:jpe?g|png|webp|tiff?|gif)', re.IGNORECASE)
	allfiles = []
	for root, _dirs, filenames in os.walk( fpath):
		for f in filenames:
			if fpatt.fullmatch( f):
				allfiles.append( {"fullpath": os.path.join(root, f), "filename": f })
	_LOGGER.info( "get_images_dir: {} files from {}".format( len(allfiles), fpath))
	return allfiles

def detect_faces( detectors, image, mode="both"):
	'''Run every face detector on one image and merge their detections.

	:param detectors: callables invoked as ``detector(img, 1)``; each returned
	    item is either a ``dlib.mmod_rectangle`` (its ``.rect`` is used) or a
	    rectangle used as-is.
	:param image: dict with a ``"fullpath"`` key, as built by ``get_images_dir``.
	:param mode: currently unused; which detectors run is decided by the caller.
	:returns: list of face rectangles in detection order. A detection is dropped
	    when its IoU with any earlier detection (kept or dropped) exceeds 0.5.
	'''
	nimg = dlib.load_rgb_image( image['fullpath'])
	_LOGGER.debug( "detect_faces: loaded {}".format( image))
	seen = []   # every detection so far, duplicates included
	faces = []  # detections kept in the result
	for n, detector in enumerate( detectors):
		_LOGGER.debug("running detection #{}".format(n))
		dets = detector( nimg, 1)
		_LOGGER.debug("number of faces detected: {}".format(len(dets)))
		for i, d in enumerate(dets):
			rect = d.rect if isinstance( d, dlib.mmod_rectangle) else d
			_LOGGER.debug("face {}: Left: {} Top: {} Right: {} Bottom: {}".format(i, rect.left(), rect.top(), rect.right(), rect.bottom()))
			# Duplicate test against all earlier detections, done with plain
			# lists so the result keeps detection order and does not depend on
			# dlib rectangles being hashable.
			is_dup = any( get_iou( rect, r) > 0.5 for r in seen)
			seen.append( rect)
			if not is_dup:
				faces.append( rect)
	_LOGGER.debug("final number of faces detected: {} dups: {} cleaned: {}".format(len(seen), len(seen) - len(faces), len(faces)))

	return faces

def create_grid( img, detections, bgcolor="black", percadd = 0.2):
	'''Compose the detected faces of one image into a grid image.

	Each face rectangle is padded by ``percadd`` of its width/height on every
	side. All grid cells share the size of the largest padded face. For a
	non-square grid, the cell height is additionally scaled by cols/rows. Each
	cell is cut from ``img`` around the padded face's centre via
	``calc_rect_expand``.

	:param img: PIL image the faces were detected in.
	:param detections: dict whose ``"facesrects"`` entry holds the face
	    rectangles (objects with left/top/right/bottom/width/height methods).
	:param bgcolor: fill colour for grid space not covered by a face.
	:param percadd: padding fraction added on each side of every face.
	:returns: a new RGB PIL image containing the grid.
	:raises ValueError: if ``detections["facesrects"]`` is empty.
	'''
	faces = list( detections['facesrects'])
	if not faces:
		# The grid layout below divides by the column count, which would be 0.
		raise ValueError("create_grid: no face rectangles to arrange")
	cols = math.ceil( math.sqrt( len( faces)))
	rows = math.ceil( len( faces) / cols)
	_LOGGER.debug("create_grid: {}x{} grid".format( cols, rows))

	# Pad every face and track the largest padded size; it becomes the cell size.
	maxdims = [0, 0]
	allareas = []
	for rect in faces:
		w = math.ceil( rect.width()*percadd)
		h = math.ceil( rect.height()*percadd)
		allareas.append( [ rect.left()-w, rect.top()-h, rect.right()+w, rect.bottom()+h])
		maxdims[0] = max( maxdims[0], rect.width()+2*w)
		maxdims[1] = max( maxdims[1], rect.height()+2*h)

	# rows <= cols always holds here; stretching the cell height by cols/rows
	# makes the total grid height cols times the unstretched cell height.
	if cols != rows:
		maxdims[1] = math.ceil( maxdims[1] * cols/rows)
	_LOGGER.debug("create_grid maxdims:{} gridimg w:{} h:{}".format( maxdims, maxdims[0]*cols, maxdims[1]*rows))
	im_grid = Image.new('RGB', ( maxdims[0]*cols, maxdims[1]*rows), color=bgcolor)

	last = len( allareas) - 1
	for i, area in enumerate( allareas):
		col, row = i % cols, i // cols
		newdims = calc_rect_expand( [0, 0, img.size[0], img.size[1]], area, maxdims)
		_LOGGER.debug("create_grid newdims:{} c:{} r:{}".format( newdims, col, row))
		# In a non-square grid, a last face that would land in the second-to-last
		# column of the final row is placed in the last column instead.
		if cols != rows and i == last and col == cols-2 and row == rows-1:
			col = cols-1
		im_grid.paste( img.crop( newdims), ( col*maxdims[0], row*maxdims[1]))
	return im_grid

def calc_rect_expand( ri, rs, targetdims):
	'''Centre a box of size ``targetdims`` on ``rs`` and shift it into ``ri``.

	:param ri: bounding rectangle ``[left, top, right, bottom]`` the result
	    should stay inside (callers pass the whole image, ``[0, 0, w, h]``).
	:param rs: source rectangle ``[left, top, right, bottom]`` whose centre
	    becomes the centre of the result.
	:param targetdims: ``[width, height]`` of the result.
	:returns: ``[left, top, right, bottom]`` of size ``targetdims``. A box that
	    sticks out of ``ri`` on one side is moved back inside without resizing.
	    If the box is larger than ``ri`` on an axis, it still extends past one
	    side on that axis.
	'''
	sw = rs[2] - rs[0]
	sh = rs[3] - rs[1]
	scx = rs[0] + abs( sw/2)
	scy = rs[1] + abs( sh/2)
	half_w = targetdims[0]/2
	half_h = targetdims[1]/2
	cr = [ scx - half_w, scy - half_h, scx + half_w, scy + half_h]
	# Horizontal: shift right when past the left bound, else left when past the right bound.
	if cr[0] < ri[0]:
		cr[2] += ri[0] - cr[0]
		cr[0] = ri[0]
	elif cr[2] > ri[2]:
		cr[0] -= cr[2] - ri[2]
		cr[2] = ri[2]
	# Vertical: the same for the top/bottom bounds.
	if cr[1] < ri[1]:
		cr[3] += ri[1] - cr[1]
		cr[1] = ri[1]
	elif cr[3] > ri[3]:
		cr[1] -= cr[3] - ri[3]
		cr[3] = ri[3]
	return cr

def get_iou(r1, r2):
	"""
	Calculate the Intersection over Union (IoU) of two rectangles.

	Credit: Martin Thoma

	Parameters
	----------
	r1, r2 : rectangle-like
		Objects with ``left()``, ``top()``, ``right()`` and ``bottom()``
		methods (e.g. dlib rectangles). (left, top) is the top-left corner,
		(right, bottom) the bottom-right corner.

	Returns
	-------
	float
		in [0, 1]; 0.0 when the rectangles do not overlap.

	Raises
	------
	ValueError
		If either rectangle has zero or negative width or height.
	"""
	bb1 = { "x1": r1.left(), "y1": r1.top(), "x2": r1.right(), "y2": r1.bottom() }
	bb2 = { "x1": r2.left(), "y1": r2.top(), "x2": r2.right(), "y2": r2.bottom() }
	# Explicit check instead of assert so it also holds under `python -O`.
	for name, bb in (("r1", bb1), ("r2", bb2)):
		if not (bb['x1'] < bb['x2'] and bb['y1'] < bb['y2']):
			raise ValueError("get_iou: {} is empty or inverted: {}".format(name, bb))

	# determine the coordinates of the intersection rectangle
	x_left = max(bb1['x1'], bb2['x1'])
	y_top = max(bb1['y1'], bb2['y1'])
	x_right = min(bb1['x2'], bb2['x2'])
	y_bottom = min(bb1['y2'], bb2['y2'])

	if x_right < x_left or y_bottom < y_top:
		return 0.0

	# The intersection of two axis-aligned bounding boxes is always an
	# axis-aligned bounding box
	intersection_area = (x_right - x_left) * (y_bottom - y_top)

	bb1_area = (bb1['x2'] - bb1['x1']) * (bb1['y2'] - bb1['y1'])
	bb2_area = (bb2['x2'] - bb2['x1']) * (bb2['y2'] - bb2['y1'])

	# union = sum of both areas minus the doubly-counted intersection; it is
	# positive because both areas were checked to be positive above.
	return intersection_area / float(bb1_area + bb2_area - intersection_area)

def main(argv=None): # IGNORE:C0111
	'''Command line options and processing loop.

	For every non-square image in sourcedir, faces are detected with the
	configured dlib detectors. A crop of side min(width, height), centred on
	the bounding box of all faces, is then written to destdir.
	- When no face is found, the --on-nofaces strategy applies.
	- When the faces do not all fit inside that crop, --on-failsquare applies.
	- With --demomode, the face and crop boxes are drawn on the full image
	  instead of cropping.

	:param argv: extra arguments appended to ``sys.argv`` before parsing;
	    ``None`` parses ``sys.argv`` as-is.
	:returns: ``None`` on normal completion, 0 on KeyboardInterrupt and 2 on
	    an unexpected error (which is re-raised when DEBUG or TESTRUN is set).
	'''

	if argv is None:
		argv = sys.argv
	else:
		sys.argv.extend(argv)

	program_name = os.path.basename(sys.argv[0])
	program_version = "v%s" % __version__
	program_build_date = str(__updated__)
	program_version_message = '%%(prog)s %s (%s)' % (program_version, program_build_date)
	# The help description starts with the second line of the module docstring.
	program_shortdesc = __import__('__main__').__doc__.split("\n")[1]
	program_license = '''%s

	photosquare v%s.
	requires dlib and pillow

	copyright 2019 R.S.U. GPL v3. All rights reserved.
	uses the Intersection over Union algorithm implemented by Martin Thoma
	uses smartcrop.py (https://github.com/smartcrop/smartcrop.py)

USAGE
''' % (program_shortdesc, str(__version__))

	try:
		# Setup argument parser
		parser = ArgumentParser(description=program_license, formatter_class=RawDescriptionHelpFormatter)
		parser.add_argument('-V', '--version', action='version', version=program_version_message)
		parser.add_argument('--log-level', action='store', choices=['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'], default='INFO', help="set log level (default: %(default)s)")

		parser.add_argument('sourcedir', action='store', help='source images directory')
		parser.add_argument('destdir', action='store', help='destination directory')
		parser.add_argument('--clobber', action='store_true', help='overwrite existing destination files (default: %(default)s)')
		parser.add_argument('--facedetect-mode', action='store', required=False, choices=['both', 'hog', 'cnn'], default="both", help='mode of face detection (default: %(default)s)')
		parser.add_argument('--cnn-model', action='store', required=False, default="data/mmod_human_face_detector.dat", help='path to pre-trained cnn model (dl from http://dlib.net/files/mmod_human_face_detector.dat.bz2) (default: %(default)s)')
		parser.add_argument('--skip-squareness', action='store', type=float, metavar='0.0-0.2', choices=[ ArgFloatRange( 0.0, 0.2)], required=False, default=0.0, help='how square the original image has to be to be skipped (default: %(default)s)')
		parser.add_argument('--on-nofaces', action='store', required=False, choices=['skip', 'simple', 'smartcrop'], default="smartcrop", help='action if no face detected (default: %(default)s)')
		parser.add_argument('--on-failsquare', action='store', required=False, choices=['skip', 'ignore', 'grid'], default="grid", help='action if square crop impossible (default: %(default)s)')
		parser.add_argument('--grid-bgcolor', action='store', required=False, default="black", help='background color of image grid (default: %(default)s)')
		parser.add_argument('--copy-square', action='store_true', help='copy square images to destdir (default: %(default)s)')
		parser.add_argument('--demomode', action='store_true', help='don\'t crop, draw face / crop indicators instead (default: %(default)s)')

		# Process arguments
		args = parser.parse_args()
		logging.basicConfig(level=args.log_level)

		if os.path.isdir( args.sourcedir) and os.path.isdir( args.destdir):
			detectors = []
			if args.facedetect_mode == 'hog' or args.facedetect_mode == 'both':
				detectors.append( dlib.get_frontal_face_detector())
				_LOGGER.info("added frontal_face_detector")
			if args.facedetect_mode == 'cnn' or args.facedetect_mode == 'both':
				if os.path.isfile( args.cnn_model):
					detectors.append( dlib.cnn_face_detection_model_v1( args.cnn_model))
					_LOGGER.info("added cnn_face_detection_model_v1 (model: {})".format( args.cnn_model))
				else:
					_LOGGER.warning("cnn model {} not found, not using cnn_face_detection".format( args.cnn_model))
			if len( detectors) == 0:
				_LOGGER.warning("no face detectors available, using fallback mode ({})".format( args.on_nofaces))

			if args.on_nofaces == "smartcrop":
				scrop = smartcrop.SmartCrop()

			# Pass 1: pick the images that need work and run face detection on them.
			imgfiles = get_images_dir( args.sourcedir)
			alldets = []
			for imgfile in imgfiles:
				outfile = Path( args.destdir, imgfile['filename'])
				if args.clobber or not os.path.isfile( outfile):
					# Only the size is needed here; close the file right away.
					with Image.open( imgfile['fullpath']) as tmp:
						img_w, img_h = tmp.size
					if img_w == img_h or min(img_w,img_h) + max(img_h,img_w)*args.skip_squareness >= max(img_h,img_w):
						_LOGGER.info("{} is square, skipping (copy-square: {})".format( outfile, args.copy_square))
						if args.copy_square:
							shutil.copy( imgfile['fullpath'], args.destdir)
					else:
						_LOGGER.info("processing image {} ({}x{}) ratio: {:.5f} squareness: {:.2f}".format( imgfile['filename'], img_w, img_h, max(img_h,img_w)/min(img_w,img_h), min(img_w,img_h) + max(img_h,img_w)*args.skip_squareness-max(img_h,img_w)))
						alldets.append( {"image":imgfile, "facesrects": detect_faces( detectors, imgfile)})
				else:
					_LOGGER.info("{} exists, skipping".format( outfile))

			# Pass 2: compute the square crop for each image and write the result.
			for det in alldets:
				outfile = Path( args.destdir, det['image']['filename'])
				img = Image.open( det['image']['fullpath'])
				draw = ImageDraw.Draw( img)
				img_w, img_h = img.size
				targetsize = min( img_w, img_h)
				_LOGGER.debug("image: {} faces: {} out: {} imgsize:{}x{} target:{}".format( det['image'], det['facesrects'], Path( args.destdir, det['image']['filename']), img_w, img_h, targetsize))

				if len( det['facesrects']) > 0:
					# Bounding box [left, top, right, bottom] of all faces.
					maxfaces =  [ img_w-1, img_h-1, 0, 0]
					for i, rect in enumerate( det['facesrects']):
						maxfaces = [ min( maxfaces[0], rect.left()),
										min( maxfaces[1], rect.top()),
										max( maxfaces[2], rect.right()),
										max( maxfaces[3], rect.bottom())]
						if args.demomode:
							cx = rect.left()+( rect.width()/2)
							cy = rect.top()+( rect.height()/2)
							draw.line( [ cx, rect.top(), cx, rect.bottom()], fill="#0000ff")
							draw.line( [ rect.left(), cy, rect.right(), cy], fill="#0000ff")
							draw.rectangle([ rect.left(), rect.top(), rect.right(), rect.bottom() ], width = 2, outline="#5555ff")

					mcx = maxfaces[0] + (maxfaces[2]-maxfaces[0])/2
					mcy = maxfaces[1] + (maxfaces[3]-maxfaces[1])/2
					# Square of side targetsize centred on the faces, clamped to the image.
					finalrect = []
					tl = max( 0, (mcx - (targetsize/2)))
					tr = min( img_w, (mcx + (targetsize/2)))
					tt = max( 0, (mcy - (targetsize/2)))
					# Bottom edge is clamped to the image height; clamping to img_w made
					# portrait crops non-square when the faces were mid-image.
					tb = min( img_h, (mcy + (targetsize/2)))
					if img_w > img_h:
						if mcx+(targetsize/2) > img_w:
							tl = img_w - targetsize
							tr = img_w
						elif mcx-(targetsize/2) < 0:
							tl = 0
							tr = targetsize
						finalrect = [ tl,0,tr,img_h]
					else:
						if mcy+(targetsize/2) > img_h:
							tt = img_h - targetsize
							tb = img_h
						elif mcy-(targetsize/2) < 0:
							tt = 0
							tb = targetsize
						finalrect = [ 0,tt,img_w,tb]

					# The crop is only acceptable if it contains every detected face.
					crop_ok = True
					for i, rect in enumerate( det['facesrects']):
						if finalrect[0] > rect.left() or finalrect[1] > rect.top() or finalrect[2] < rect.right() or finalrect[3] < rect.bottom():
							_LOGGER.info("face {} outside crop area".format( i))
							crop_ok = False
							break

					if not crop_ok:
						# "skip" writes nothing.
						if args.on_failsquare == "ignore":
							_LOGGER.debug("cropping anyway")
							if args.demomode:
								# Demo mode never crops: mark the crop box instead.
								draw.rectangle( finalrect, width = 2, outline="#ff0044")
								img.save( outfile)
							else:
								img.crop( finalrect).save( outfile)
						elif args.on_failsquare == "grid":
							grid_img = create_grid( img, det, bgcolor=args.grid_bgcolor)
							grid_img.save( outfile)
					else:
						if args.demomode:
							draw.rectangle( maxfaces, width = 1, outline="#ffff00")
							draw.line( [ mcx, maxfaces[1], mcx, maxfaces[3]], fill="#ffff00")
							draw.line( [ maxfaces[0], mcy, maxfaces[2], mcy], fill="#ffff00")
							draw.rectangle( finalrect, width = 2, outline="#ff0044")
						else:
							img = img.crop( finalrect)
						img.save( outfile)
				else:
					_LOGGER.info("no faces detected in {} (action: {})".format( outfile, args.on_nofaces))
					if args.on_nofaces == "simple":
						_LOGGER.debug("using simple")
						# Landscape: centred square; portrait: top square.
						finalrect = []
						if img_w > img_h:
							finalrect = [ (img_w/2)-(targetsize/2), 0, (img_w/2)+(targetsize/2), img_h]
						else:
							finalrect = [ 0, 0, img_w, targetsize]
					elif args.on_nofaces == "smartcrop":
						_LOGGER.debug("using SmartCrop")
						result = scrop.crop( img, width=targetsize, height=targetsize)
						finalrect = [
							result['top_crop']['x'],
							result['top_crop']['y'],
							result['top_crop']['width'] + result['top_crop']['x'],
							result['top_crop']['height'] + result['top_crop']['y']
						]
						_LOGGER.debug("SmartCrop result: {}".format( finalrect))
					if not args.on_nofaces == "skip":
						if args.demomode:
							draw.rectangle( finalrect, width = 2, outline="#ff0044")
						else:
							img = img.crop( finalrect)
						img.save( outfile)

		else:
			_LOGGER.error( "path {} or {} not found".format( args.sourcedir, args.destdir))

	except KeyboardInterrupt:
		### handle keyboard interrupt ###
		return 0
	except Exception as e:
		if DEBUG or TESTRUN:
			raise
		indent = len(program_name) * " "
		sys.stderr.write(program_name + ": " + repr(e) + "\n")
		sys.stderr.write(indent + "  for help use --help")
		return 2

# Script entry point: the process exit status is main()'s return value
# (sys.exit(None) exits with status 0).
if __name__ == "__main__":
	sys.exit(main())
