# external imports
import random
import numpy as np
# from tf import transformations

# internal imports
from LandmarkMap import LandmarkMap
from geometryUtils import angle_between, get2dRotmatFromYaw


class Localizer(object):

    def __init__(self, mapFilename):
        self.landmarkMap = LandmarkMap(mapFilename)

    def localize(self, seenLandmarks):
        if len(self.landmarkMap.size) == 2:
            position = np.ones(3, dtype=float)
            position[:2] = self.landmarkMap.size / 2
            orientation = np.pi / 4
        else:
            position = np.array(self.landmarkMap.size / 2, dtype=float)
            # the identity (no-rotation) quaternion for the 3D case
            orientation = np.array([0, 0, 0, 1])

        # placeholder hypotheses: (probability, position, orientation)
        return [(0.33, position, orientation),
                (0.67, position / 5, 0)]

    def planeWorldLocalize(self, seenLandmarks):
        n = 2  # landmarks to sample per iteration (2 suffices in the plane)
        m = 1  # candidate map matches to consider per seen landmark
        k = 1  # iterations (with m = 1 there is only one candidate set)
        d = 1  # a model must exceed this many additional inliers to be kept
        print "saw", len(seenLandmarks), "landmarks"
        if len(seenLandmarks) < n:
            print 'not enough landmarks seen to localize'
            return None
        return self.ransacLocalize(seenLandmarks, getModelFromPlaneWorldMatches,
                                   isPlaneWorldInlier, computePlaneWorldError,
                                   n, m, k, d)

    def exact2dLocalize(self, seenLandmarks):
        usingLandmarks = seenLandmarks[:2]
        usingMatches = [(l, self.landmarkMap.getLandmarkMatches(l)[0])
                        for l in usingLandmarks]
        camPos, camYaw = getModelFromPlaneWorldMatches(usingMatches)
        return [(1, camPos, camYaw)]

    def ransacLocalize(self, seenLandmarks, getModelFromMatches, isInlier,
                       computeError, n, m, k, d):
        """
        The "data points" selected for ransac are matches between objects \
        in seenLandmarks and objects in the map.

        seenLandmarks: the collection of observed landmarks to fit to the map.
        getModelFromMatches: a function that computes a transform from a set \
        of point matches.
        isInlier: a function that determines whether a landmark should or \
        should not be considered an "inlier" given a model transform and a map.
        computeError: measures the error of a set of matches given a model.
        n: number of landmarks to assume are inliers from seenLandmarks. \
        n = 2 for the 2D and 2D+scale case, n = 4 in full 3D.
        m: number of best matches of seen landmarks to consider from the map.
        k: number of iterations before terminating.
        d: number of additional inliers (beyond the n sampled) that a \
        model must exceed for it to be considered a good fit
        """

        i = 0
        consideredMatches = []
        bestError = None
        bestModel = None
        while i < k:
            i += 1
            if len(consideredMatches) == 0:
                # randomly select n points from seenLandmarks
                nLandmarks = random.sample(seenLandmarks, n)
                # for each point, retrieve m possible matches from the map
                matchSets = [(l, self.landmarkMap.getLandmarkMatches(l, m))
                             for l in nLandmarks]
                # generate sets of possible matches, sorted with best set last
                tuples = [(j,) for j in range(m)]
                for _ in range(n-1):
                    tuples = [t + (j,) for j in range(m) for t in tuples]
                sortedTuples = sorted(tuples, key=sum, reverse=True)
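                # e.g. with n = 2 and m = 2 the index tuples are (0, 0),
                # (1, 0), (0, 1), (1, 1); sorted by sum in reverse order,
                # (0, 0) -- both best matches -- lands at the end of the
                # list, so the pop() below tries it first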

                for indexTuple in sortedTuples:
                    newMatchSet = [(l, ms[indexTuple[j]])
                                   for j, (l, ms) in enumerate(matchSets)]
                    consideredMatches.append(newMatchSet)

            # for each set of possible matches:
            testMatches = consideredMatches.pop()
            model = getModelFromMatches(testMatches)
            usedLandmarks = set(l for l, _ in testMatches)
            seenLandmarksToCheck = set(seenLandmarks) - usedLandmarks
            foundMatches = [(sl, isInlier(sl, model, self.landmarkMap))
                            for sl in seenLandmarksToCheck]
            foundMatches = [fm for fm in foundMatches if fm[1] is not False]
            if len(foundMatches) > d:
                inlierMatches = testMatches + foundMatches
                # TODO in the planar case, this won't change, so could remove
                # or replace with ICP
                # model = getModelFromMatches(inlierMatches)
                error = computeError(inlierMatches, *model)
                if bestError is None or error < bestError:
                    bestError = error
                    bestModel = model

        if bestModel is None:
            print("no good model found!")
            return None

        return [(1, bestModel[0], bestModel[1])]

    def qualityLocalize(self, seenLandmarks, getModelFromMatches):
        # get a quality score for the closest match of each seen landmark
        # do ransac, but ordered by the quality scores of the best matches,
        # 2nd-best matches, etc.
        pass


def isPlaneWorldInlier(seenLandmark, model, landmarkMap, nClosest=1, nBest=1):
    """ How do we decide what is and is not an inlier?
    it should be a function of the distances to the nearest point
    that is, both in space and in the distance between their feature vectors
    we could make the inlier status be:
    is the physically-closest landmark the same as the best match in the map?
    This is rigourous, maybe too rigourous.  Could be top-N closest and/or
    top-N matches?  So if the best match is in the top N closest, or if the

    closest is in the top N matches...or if at least one of the top N
    closest is in the top M matches...
    """
    # in this case, we assume a 2.5d camera model
    camPos, camYaw = model
    # transform seenLandmark into map space using the model
    mapSeenLandmark = seenLandmark.getLandmarkInWorldCoords(camPos, camYaw)
    mapSeenPos = mapSeenLandmark.position
    # check if the transformed landmark is an inlier
    closestLandmarks = landmarkMap.getClosestLandmarks(mapSeenPos, nClosest)
    bestMatches = landmarkMap.getLandmarkMatches(mapSeenLandmark, nBest)
    bestInClosest = [lm for lm in closestLandmarks if lm in bestMatches]
    if len(bestInClosest) == 0:
        # the seen landmark is not an inlier.
        return False
    else:
        # return the closest landmark in the top N matches
        return bestInClosest[0]
        # # TODO alternatively, we could return the best match in the closest n
        # return [lm for lm in bestMatches if lm in closestLandmarks][0]


def getModelFromPlaneWorldMatches(matches):
    """
    Given a pair of matches between landmarks in the image and in the map,
    finds the transform between the world coordinate frame and the camera's
    coordinate frame.  It assumes fixed roll and pitch, so it just returns the
    x and y coordinates, the height above the map, and the yaw.

    A few notes on notation:
    a and b represent the locations of the two landmarks
    c represents the position of the camera in world coordinates
    Ci represents the projection of the camera along z onto the world plane
      (that is, Ci.x = c.x, Ci.y = c.y, and Ci.z = 0)
    angle_xyz represents the angle at point y, between the lines from y to x
      and from y to z
    angle_vVsu represents the angle between vectors v and u
    x2y represents a line from point x to point y
    """
    (aCam, aMap), (bCam, bMap) = matches[:2]
    
    # compute robot's height
    a2bVecMap = bMap.position - aMap.position
    a2bDistMap = np.linalg.norm(a2bVecMap)
    b2aVecCam = aCam.position - bCam.position
    angle_abCi = angle_between(b2aVecCam, -bCam.position)
    angle_aCib = angle_between(aCam.position, bCam.position)
    a2CiDistMap = a2bDistMap * np.sin(angle_abCi) / np.sin(angle_aCib)
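    # in normalized image units the camera sits at height 1 above the world
    # plane, so the angle at a between a->Ci and a->c satisfies
    # tan(angle_Ciac) = 1 / |aCam.position|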
    angle_Ciac = np.arctan2(1, np.linalg.norm(aCam.position))
    camAltitude = a2CiDistMap * np.tan(angle_Ciac)

    # compute robot's yaw
    angle_a2bVsMapX = np.arctan2(a2bVecMap[1], a2bVecMap[0])
    a2bVecCam = bCam.position - aCam.position
    angle_a2bVsCamX = np.arctan2(a2bVecCam[1], a2bVecCam[0])
    camYaw = angle_a2bVsCamX - angle_a2bVsMapX

    # compute robot's x, y coordinates
    # rotate aCam's position by -camYaw (assuming +ve angles are ccw)
    rotmat = get2dRotmatFromYaw(-camYaw)
    aCamRot = aCam.getRotatedLandmark(rotmat)
    # scale aCam's normalized position by camAltitude (similar triangles) to
    # get the Ci-to-a offset in map coordinates, then negate to get a-to-Ci
    a2CiVecMap = -aCamRot.position * camAltitude
    CiPos = aMap.position + a2CiVecMap
    camPos = np.array([CiPos[0], CiPos[1], camAltitude])

    return camPos, camYaw


def computePlaneWorldError(matches, camPos, camOrientation):
    totalError = 0.0
    for seenLm, mapLm in matches:
        # transform seenLm into world coordinates
        worldSeenLm = seenLm.getLandmarkInWorldCoords(camPos, camOrientation)
        # compute the error between seenLmWorld and mapLm
        error = np.linalg.norm(worldSeenLm.position - mapLm.position)
        # TODO consider orientation error as well - how to scale this?
        # accumulate error
        totalError += error

    # return the mean positional error between matches
    return totalError / len(matches)
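

if __name__ == '__main__':
    # Self-contained sanity check of getModelFromPlaneWorldMatches.  This is
    # a sketch: _FakeLandmark is a hypothetical stand-in for the real landmark
    # class, providing only the position attribute and getRotatedLandmark
    # method that the function uses.
    class _FakeLandmark(object):
        def __init__(self, position):
            self.position = np.asarray(position, dtype=float)

        def getRotatedLandmark(self, rotmat):
            return _FakeLandmark(rotmat.dot(self.position))

    # camera at world (0, 0), height 1, zero yaw: landmarks on the plane at
    # (1, 0) and (0, 2) project to the same normalized image coordinates
    matches = [(_FakeLandmark([1.0, 0.0]), _FakeLandmark([1.0, 0.0])),
               (_FakeLandmark([0.0, 2.0]), _FakeLandmark([0.0, 2.0]))]
    camPos, camYaw = getModelFromPlaneWorldMatches(matches)
    print('camPos:', camPos, 'camYaw:', camYaw)  # expect ~[0, 0, 1] and ~0.0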
