#!/usr/bin/env python

# external modules
import argparse
import yaml
import cv2
import numpy as np

# internal modules
import mapGenerator
import constants as c
from LandmarkMap import LandmarkMap
from Localizer import Localizer
from Visualizer import Visualizer
from geometryUtils import *

class SimConfig(object):
    """Simulator parameters loaded from a YAML config file.

    Attributes (all but ``is2d`` are read from the same-named YAML keys):
        is2d: always True (hard-coded, not read from the file).
        width, height: size of the view window used by the simulator.
        minDepth, maxDepth: depth bounds from the config.
        fx, fy: presumably camera focal lengths -- TODO confirm.
    """

    def __init__(self, simConfFilename):
        """Load the configuration from the YAML file *simConfFilename*.

        Raises:
            KeyError: if the parsed mapping lacks one of the required keys
                ('width', 'height', 'minDepth', 'maxDepth', 'fx', 'fy').
        """
        with open(simConfFilename, 'r') as yamlConfFile:
            # safe_load: the config is plain data. yaml.load without an
            # explicit Loader can construct arbitrary objects from untrusted
            # input and is a TypeError on PyYAML >= 6.
            yamlConfig = yaml.safe_load(yamlConfFile)
        # The file is closed as soon as it has been parsed.
        self.is2d = True
        self.width = yamlConfig['width']
        self.height = yamlConfig['height']
        self.minDepth = yamlConfig['minDepth']
        self.maxDepth = yamlConfig['maxDepth']
        self.fx = yamlConfig['fx']
        self.fy = yamlConfig['fy']

            
class Simulator(object):
    """Drive simulated localization runs over a landmark map.

    Holds the landmark map, the simulator configuration and a Localizer,
    all built from the files passed to the constructor.
    """

    def __init__(self, mapFilename, simConfFilename):
        """Build the map and localizer from *mapFilename* and the simulator
        parameters from *simConfFilename*."""
        self.landmarkMap = LandmarkMap(mapFilename)
        self.simConfig = SimConfig(simConfFilename)
        self.localizer = Localizer(mapFilename)

    def generatePoseWithObjects(self):
        """Sample a random pose and collect the landmarks visible from it.

        Returns:
            (pos, ornt, seenLandmarks), where:
            ``pos`` is the sampled position. A 2D sample is extended to 3D
            with a fixed depth of 1.
            ``ornt`` is the orientation from ``mapGenerator.generatePose``.
            ``seenLandmarks`` holds the map landmarks, converted via
            ``getLandmarkInCamCoords``, whose first two camera coordinates
            lie strictly inside a ``width`` x ``height`` window centred on
            the origin.
        """
        # generate a random pose: position + orientation
        pos, ornt = mapGenerator.generatePose(self.landmarkMap.size,
                                              self.landmarkMap.getNDims())

        # TODO The probability of an object appearing shouldn't just be flat.
        # It should be inversely proportional to the size of the object and
        # proportional to the scale.

        # TODO generate a depth - for now this has a default value
        depth = 1
        if len(pos) == 2:
            # Lift a 2D sample into 3D before projecting the landmarks; this
            # is the single place where the depth is incorporated into pos.
            pos = np.array([pos[0], pos[1], depth])

        # Half-extents of the view window. Use float division so an odd
        # width/height is not truncated by Python 2 integer division.
        xRad = self.simConfig.width / 2.0
        yRad = self.simConfig.height / 2.0

        # find all of the objects visible from that pose
        seenLandmarks = []
        for lm in self.landmarkMap.landmarks:
            camLm = lm.getLandmarkInCamCoords(pos, ornt)
            camX, camY = camLm.position[0], camLm.position[1]
            if -xRad < camX < xRad and -yRad < camY < yRad:
                seenLandmarks.append(camLm)
        return pos, ornt, seenLandmarks

    def runSimulation(self):
        """Run a single simulated localization.

        Samples a pose with ``generatePoseWithObjects`` and passes the
        visible landmarks to ``self.localizer.planeWorldLocalize``.

        Returns:
            (pos, ornt, estimatedPoses, seenLandmarks): the true pose, the
            localizer's result, and the landmarks it was given.
        """
        pos, ornt, seenLandmarks = self.generatePoseWithObjects()
        # Other Localizer entry points (localize, exact2dLocalize) can be
        # swapped in here.
        estimatedPoses = self.localizer.planeWorldLocalize(seenLandmarks)
        return pos, ornt, estimatedPoses, seenLandmarks


def getSimulatorArgs(argv=None):
    """Parse the simulator's command-line arguments.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        An ``argparse.Namespace`` with ``mapFilename``, ``simConfig``
        (default ``c.SIM_CONFIG_FILE``) and ``n`` (int, default 1).
    """
    parser = argparse.ArgumentParser()
    # The help texts keep their original leading/embedded whitespace;
    # argparse's default HelpFormatter collapses whitespace runs on output.
    parser.add_argument(
        'mapFilename',
        help='    The yaml file containing the description of the map.')
    parser.add_argument(
        '--simConfig', default=c.SIM_CONFIG_FILE,
        help=('    The name of the config file that defines the parameters of the '
              '    simulator.'))
    parser.add_argument(
        '-n', type=int, default=1,
        help='    The number of experiments to run.')
    return parser.parse_args(argv)


def main():
    """Run ``args.n`` simulations and visualize each one that produced
    pose estimates, waiting for a key press between runs."""
    args = getSimulatorArgs()
    # create the simulator (it also constructs its own Localizer)
    simulator = Simulator(args.mapFilename, args.simConfig)
    visualizer = Visualizer(args.mapFilename, simulator.simConfig.width,
                            simulator.simConfig.height)
    # TODO handle 3D case
    for _ in range(args.n):
        truePos, trueYaw, estPoses, seenLandmarks = simulator.runSimulation()
        # Skip runs without estimates; an empty result would otherwise make
        # the zip(*...) unpacking below raise ValueError.
        if estPoses is None or len(estPoses) == 0:
            continue
        # A single pre-formatted argument prints the same text under
        # Python 2's print statement and Python 3's print function.
        print("True pose: %s %s" % (truePos, trueYaw))
        estProbs, estPoss, estYaws = zip(*estPoses)
        print("Estimated poses: %s" % (estPoses,))
        vis = visualizer.renderSimulation(truePos, trueYaw, estProbs,
                                          estPoss, estYaws, seenLandmarks)
        cv2.imshow('visualization', vis)
        # Block until a key is pressed before starting the next run.
        cv2.waitKey(0)


if __name__ == "__main__":
    main()
