#!/usr/bin/env python

import cv2
import os
import argparse
import pdb
import numpy as np


# "Infinite" component weight: marks unassigned / merged-away components.
C_INF = np.inf
# White and black rendering colours.
ONE_PIXEL_COLOUR = (255,) * 3
NO_PATCH_COLOUR = (0,) * 3
# Random starting colour for the patch palette, drawn once at import time.
INIT_COLOUR = (np.random.random(3) * 255).astype(int)


def make3dGrayscaleCopy(image):
    """Return a 3-channel uint8 copy of a 2-D image.

    Every channel of the result holds the same values as `image` (cast to
    uint8), giving an array of shape (rows, cols, 3) suitable for colour
    drawing.
    """
    rows, cols = image.shape[0], image.shape[1]
    stacked = np.zeros((rows, cols, 3), dtype=np.uint8)
    # Broadcast the single plane across the channel axis in one assignment.
    stacked[...] = image[:, :, np.newaxis]
    return stacked


def incrementColour(colour):
    """Return the palette colour that follows `colour`.

    Each channel is advanced by its own odd step (541, 293, 433) and wrapped
    modulo 256; since the steps are odd, each channel cycles with period 256.
    """
    steps = np.array([541, 293, 433])
    return (np.array(colour) + steps) % 256


def runExperiment(imageFilename, processAndRender, numScales=4, splitColours=False):
    """Run `processAndRender` on an image at several scales and display it.

    The image is contrast-equalised with CLAHE. For each of `numScales`
    scales (1, 1/2, 1/4, ...) it is downsampled, passed to
    `processAndRender`, and the result is resized back to full size and
    shown next to the full-size input. Each frame waits for a key press.

    Args:
        imageFilename: path of the image to load.
        processAndRender: callable taking one single-channel uint8 image and
            returning an image to display.
        numScales: number of power-of-two scales to run.
        splitColours: if True, load the image in colour and process each
            channel separately; otherwise load it as grayscale.

    Raises:
        IOError: if the image cannot be read.
        ValueError: if the user presses Escape (used to abort the run).
    """
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    if splitColours:
        raw = cv2.imread(imageFilename)
    else:
        raw = cv2.imread(imageFilename, 0)
    # cv2.imread signals failure by returning None rather than raising.
    if raw is None:
        raise IOError("could not read image: " + str(imageFilename))

    if splitColours:
        channels = [clahe.apply(raw[:, :, c]) for c in range(raw.shape[2])]
    else:
        channels = [clahe.apply(raw)]

    for channel in channels:
        fullHeight, fullWidth = channel.shape[0], channel.shape[1]
        for i in range(numScales):
            print("current scale: 1/(2^" + str(i) + ")")
            factor = 2 ** i
            # Integer division keeps the size integral on Python 2 and 3;
            # clamp to 1 so small images never get a zero-sized target.
            scaledSize = (max(fullWidth // factor, 1),
                          max(fullHeight // factor, 1))
            resizedImage = cv2.resize(channel, scaledSize)
            resultsImage = processAndRender(resizedImage)
            resizedResultsImage = cv2.resize(resultsImage,
                                             (fullWidth, fullHeight))
            cv2.imshow('raw image', channel)
            cv2.imshow('detected objects', resizedResultsImage)
            # Keep only the low byte: some backends set extra modifier bits.
            key = cv2.waitKey(0) & 0xFF
            if key == 27:
                raise ValueError("user pressed escape")


def findAndShowMSER(img):
    """Detect MSER regions in `img` and outline their convex hulls in green.

    Args:
        img: 8-bit image. A single-channel input is converted to 3 channels
            for drawing so that the green outlines are visible.

    Returns:
        A 3-channel copy of the image with the hull outlines drawn on it.
    """
    if img.ndim == 2:
        # Drawing on a single-channel canvas uses only the first colour
        # component (0 here), so the "green" hulls would be drawn black.
        vis = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    else:
        vis = img.copy()

    if hasattr(cv2, 'MSER_create'):
        # OpenCV >= 3: factory function; detectRegions -> (regions, bboxes).
        regions, _ = cv2.MSER_create().detectRegions(img)
    else:
        # OpenCV 2.4 API.
        regions = cv2.MSER().detect(img, None)

    hulls = [cv2.convexHull(p.reshape(-1, 1, 2)) for p in regions]
    # TODO: optionally keep only the outermost hulls (drop any hull whose
    # points all lie inside another hull, e.g. via cv2.pointPolygonTest).
    cv2.polylines(vis, hulls, 1, (0, 255, 0))
    return vis

            
# salient patches
    

def findAndShowSalientPatches(image, beta, k):
    """Segment `image` into salient patches and return a colour rendering.

    Runs salientPatchProposal(image, beta, k) and passes the resulting
    label image to renderSalientPatches.
    """
    return renderSalientPatches(salientPatchProposal(image, beta, k))


def _buildGridEdges(image):
    # Build the edge list of a 2-D image's pixel graph: each pixel is joined
    # to its 8 neighbours (every undirected edge listed once) with weight
    # 1 - |I_i - I_j|. Entries are ((yj, xj), (yi, xi), weight).
    height, width = image.shape[0], image.shape[1]
    edges = []
    for xi in range(width):
        for yi in range(height):
            iVal = image[yi, xi]
            for xj, yj in ((xi - 1, yi - 1), (xi - 1, yi),
                           (xi, yi - 1), (xi + 1, yi - 1)):
                if xj >= 0 and yj >= 0 and xj < width:
                    # could experiment with other weight functions
                    weight = 1 - abs(iVal - image[yj, xj])
                    edges.append(((yj, xj), (yi, xi), weight))
    return edges


def _mergeComponents(shape, edges, k):
    # Greedily merge pixels into components, visiting edges by ascending
    # weight. Returns (componentImage, maxComponentWeights): label 0 marks a
    # pixel never merged; a label merged into another has weight C_INF.
    #
    # NOTE(review): unmerged pixels (label 0) use Int = C_INF, so an edge
    # between two such pixels always passes the merge test, and weights are
    # similarities (1 - |diff|) visited least-similar first. Felzenszwalb &
    # Huttenlocher use Int = 0 for singletons and dissimilarity weights.
    # Confirm this deviation is intended.
    componentImage = np.zeros(shape, dtype=int)
    maxComponentWeights = [C_INF]
    componentSizes = [1]
    # Union-find over labels: merging two components re-points one root
    # instead of rewriting every pixel of the image (O(pixels) per merge,
    # quadratic overall). Pixel labels are resolved to roots at the end,
    # giving the same labels the per-merge relabel produced.
    parent = [0]

    def find(label):
        root = int(label)
        while parent[root] != root:
            parent[root] = parent[parent[root]]  # path halving
            root = parent[root]
        return root

    for iPos, jPos, weight in sorted(edges, key=lambda edge: edge[2]):
        Ci = find(componentImage[iPos])
        Cj = find(componentImage[jPos])
        if Ci == Cj and Ci != 0:
            continue  # already in the same component
        intCi = maxComponentWeights[Ci] + (k / componentSizes[Ci])
        intCj = maxComponentWeights[Cj] + (k / componentSizes[Cj])
        # compute MInt of the two components; merge only below it
        if not (weight < min(intCi, intCj)):
            continue
        if Ci == 0 and Cj == 0:
            newLabel = len(maxComponentWeights)
            parent.append(newLabel)
            componentImage[iPos] = newLabel
            componentImage[jPos] = newLabel
            maxComponentWeights.append(weight)
            componentSizes.append(2)
        elif Cj == 0:
            componentImage[jPos] = Ci
            maxComponentWeights[Ci] = max(maxComponentWeights[Ci], weight)
            componentSizes[Ci] += 1
        elif Ci == 0:
            componentImage[iPos] = Cj
            maxComponentWeights[Cj] = max(maxComponentWeights[Cj], weight)
            componentSizes[Cj] += 1
        else:
            parent[Cj] = Ci
            maxComponentWeights[Ci] = max(maxComponentWeights[Ci],
                                          maxComponentWeights[Cj],
                                          weight)
            # wipe out the old component's weight
            maxComponentWeights[Cj] = C_INF
            componentSizes[Ci] += componentSizes[Cj]

    roots = np.array([find(label) for label in range(len(parent))], dtype=int)
    return roots[componentImage], maxComponentWeights


def salientPatchProposal(image, beta, k, numScales=3):
    """Segment `image` into components and keep only the salient ones.

    Pixels are grouped by a greedy graph-based merge over 8-connected
    neighbours. Each component is then scored by how much its mean intensity
    differs from the mean of the surrounding pixels inside boxes centred on
    its bounding box and enlarged by 1.3**j for j in range(numScales).
    Components whose summed contrast is below `beta` are discarded.

    Args:
        image: 2-D image; values are divided by 255.
        beta: saliency threshold.
        k: merge-threshold constant; larger values allow more merging.
        numScales: number of surrounding-box scales summed into the score.

    Returns:
        Integer label image of the same shape: 0 = pixel never merged,
        -1 = component rejected as non-salient, >0 = surviving component.
    """
    # scale the image to 0 to 1
    image = image.astype(float) / 255

    # I think this will not handle regions of texture very well.
    edges = _buildGridEdges(image)

    # Scale k with the image area so the merge threshold behaves comparably
    # across the pyramid scales. NOTE(review): 241920 is presumably the pixel
    # count of the resolution k was tuned at - confirm.
    k = image.size * float(k) / 241920
    componentImage, maxComponentWeights = _mergeComponents(image.shape,
                                                           edges, k)

    for Ci in range(len(maxComponentWeights)):
        if not (maxComponentWeights[Ci] < C_INF):
            continue  # label 0, or a component merged into another
        mask = componentImage == Ci
        CiMean = image[mask].mean()
        yMin, yMax, xMin, xMax = findBoundingBox(componentImage, Ci)
        bboxHeight = yMax - yMin
        bboxWidth = xMax - xMin
        # Floor division reproduces the original Python 2 integer centre.
        centerY = yMin + bboxHeight // 2
        centerX = xMin + bboxWidth // 2
        saliency = 0
        for j in range(numScales):
            # compute the scaled bounding box; NumPy requires integer slice
            # bounds, so truncate toward zero.
            scale = 1.3 ** j
            y0 = max(int(centerY - bboxHeight * scale), 0)
            y1 = min(int(centerY + bboxHeight * scale), image.shape[0])
            x0 = max(int(centerX - bboxWidth * scale), 0)
            x1 = min(int(centerX + bboxWidth * scale), image.shape[1])
            # Mean of the box excluding the component's own pixels. Masking
            # stays correct when the box does not fully contain the component
            # and never divides by zero when nothing surrounds it.
            surround = image[y0:y1, x0:x1][~mask[y0:y1, x0:x1]]
            if surround.size:
                # saliency = sum over bounding boxes at different scales
                saliency += abs(CiMean - surround.mean())
        if saliency < beta:
            # remove component from components
            componentImage[mask] = -1

    return componentImage


def renderSalientPatches(componentImage):
    """Render a 2-D component-label image as a colour image.

    Labels -1 and 0 (no patch) are drawn as NO_PATCH_COLOUR. Every other
    label gets its own colour. The first such label met in row-major order
    gets INIT_COLOUR, and each later new label gets incrementColour() of the
    previous one.

    Args:
        componentImage: 2-D integer label array.

    Returns:
        uint8 array of shape componentImage.shape + (3,).
    """
    print('rendering...')
    labels, firstIndex, inverse = np.unique(componentImage,
                                            return_index=True,
                                            return_inverse=True)
    palette = np.zeros((len(labels), 3), dtype=np.uint8)
    nextColour = INIT_COLOUR.copy()
    # Hand out colours in order of first (row-major) appearance, exactly as
    # a pixel-by-pixel scan would, but iterate only over distinct labels.
    for idx in np.argsort(firstIndex):
        if labels[idx] in (-1, 0):
            palette[idx] = NO_PATCH_COLOUR
        else:
            palette[idx] = nextColour
            nextColour = incrementColour(nextColour)
    # reshape: np.unique's inverse is 1-D in some NumPy versions.
    dispImage = palette[inverse.reshape(componentImage.shape)]
    print('done rendering')
    return dispImage


def findBoundingBox(array, value):
    """Return the bounding box of the cells of a 2-D `array` equal to `value`.

    Returns:
        (yMin, yMax, xMin, xMax): inclusive row and column extents.

    Raises:
        ValueError: if `value` does not occur in `array`.
    """
    rows, cols = np.where(array == value)
    if rows.size == 0:
        raise ValueError("value %r not found in array" % (value,))
    # Row extents come from the row indices (previously yMax wrongly used
    # the column indices).
    return rows.min(), rows.max(), cols.min(), cols.max()


# Canny edge / contour functions (still used by runContoursAcrossThresholds).
def findAndShowContours(image, cannyThreshold1, cannyThreshold2):
    """Draw the convex hulls of Canny-edge contours over a grayscale image.

    Args:
        image: single-channel 8-bit image.
        cannyThreshold1, cannyThreshold2: hysteresis thresholds for cv2.Canny.

    Returns:
        A 3-channel copy of `image` with each contour's convex hull drawn in
        a random colour.
    """
    edgeImage = cv2.Canny(image, cannyThreshold1, cannyThreshold2)
    # findContours returns (contours, hierarchy) in OpenCV 2.4 and 4.x but
    # (image, contours, hierarchy) in 3.x; contours is always second-to-last.
    contours = cv2.findContours(edgeImage, cv2.RETR_TREE,
                                cv2.CHAIN_APPROX_SIMPLE)[-2]
    dispImage = make3dGrayscaleCopy(image)
    for contour in contours:
        hull = cv2.convexHull(contour)
        # Plain Python ints: some OpenCV bindings reject NumPy colour values.
        colour = tuple(int(c) for c in np.random.random(3) * 255)
        cv2.drawContours(dispImage, [hull], 0, colour, 1)
    return dispImage


def runContoursAcrossThresholds(imageFilename, stepSize):
    """Sweep both Canny thresholds over range(0, 255, stepSize).

    For every threshold pair, runs the contour experiment on `imageFilename`
    (2 scales, each colour channel separately).
    """
    for i in range(0, 255, stepSize):
        for j in range(0, 255, stepSize):
            print('threshold 1: %s threshold 2: %s' % (i, j))
            # Use the imageFilename parameter (previously a global leaked in
            # from __main__) and bind the thresholds as lambda defaults.
            runExperiment(imageFilename,
                          lambda x, t1=i, t2=j: findAndShowContours(x, t1, t2),
                          numScales=2, splitColours=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group()
    group.add_argument('-i', '--image')
    group.add_argument('-d', '--dir')
    args = parser.parse_args()

    if args.image:
        imageFilenames = [args.image]
    elif args.dir:
        # Sorted for a deterministic order; sub-directories are skipped.
        imageFilenames = [os.path.join(args.dir, fn)
                          for fn in sorted(os.listdir(args.dir))
                          if os.path.isfile(os.path.join(args.dir, fn))]
    else:
        # parser.error prints the message and exits (the old bare `exit`
        # was never called, so execution silently continued).
        parser.error("Either an image or a dir must be specified.")

    for ifn in imageFilenames:
        # # Canny contours
        runContoursAcrossThresholds(ifn, stepSize=32)

        # Salient patch proposal
        # beta = 0.1
        # k = 100
        # runExperiment(ifn, lambda x: findAndShowSalientPatches(x, beta, k),
        #               numScales=4)

        runExperiment(ifn, findAndShowMSER, numScales=1, splitColours=True)
