"""All kinds of extra math functionality."""

import math
import numbers

import vizmat


##########################
###                    ###
###  helper functions  ###
###                    ###
##########################


# Easing curves understood by transform(); each maps a normalized time to a
# remapped time.
_EASING_CURVES = {
    # linear; has no effect
    'linear': lambda s: s,
    # slow in, slow out
    'siso': lambda s: (1.0 - math.cos(math.pi * s)) * 0.5,
    # fast in, slow out
    'fiso': lambda s: math.sin(((1.0 - s) ** 2 + 1) * math.pi / 2.0) ** 2,
}


def transform(t, func):
    """Remap the normalized time <t> through the easing curve named <func>.

    Supported curves:
        'linear' -- identity, returns <t> unchanged
        'siso'   -- slow in, slow out: (1 - cos(pi * t)) / 2
        'fiso'   -- fast in, slow out: sin(((1 - t)^2 + 1) * pi / 2)^2

    For t in [0, 1] every curve maps 0 to 0 and 1 to 1 (up to floating-point
    rounding).  An unknown <func> raises KeyError.
    """
    # Only the requested curve is evaluated, so 'linear' accepts any <t>,
    # even values math.cos/math.sin would reject.
    return _EASING_CURVES[func](t)

def lerp(t, val1=0, val2=1, func='linear'):
    """Interpolate from <val1> to <val2> at time <t>.

    <t> is first remapped through the easing curve <func> (see transform()),
    and the result is val1 + (val2 - val1) * transform(t, func).

    <val1> and <val2> must support + and - with each other and * with a
    number (e.g. plain numbers or Vector instances).

    Raises Exception, naming both values and their types, when the
    interpolation cannot be computed (including an unknown <func>).
    """
    try:
        return val1 + (val2 - val1) * transform(t, func)
    except Exception:
        # Catch only ordinary errors: a bare except would also turn
        # KeyboardInterrupt/SystemExit into this message.
        raise Exception('Cannot interpolate between %s %s and %s %s.' % (str(val1), val1.__class__, str(val2), val2.__class__))

def clamp(val, l=0.0, h=1.0):
    """Limit <val> to the closed range [l, h].

    The lower bound is applied first and the upper bound second, so when
    l > h the result is always <h>.
    """
    atLeastLow = val if val > l else l
    return h if h < atLeastLow else atLeastLow

def mapScale(val, fromAB, toXY):
    """Map <val> linearly from the [a, b] scale to the [x, y] scale.

    <fromAB> is the pair (a, b) and <toXY> the pair (x, y); a maps to x and
    b maps to y.  Values outside [a, b] are extrapolated, not clamped.
    Raises ZeroDivisionError when a == b.
    """
    a, b = fromAB
    x, y = toXY
    # float() prevents Python 2 floor division when every argument is an int.
    return x + (val - a) / float(b - a) * (y - x)

def regularToVizAngle(a):
    """Convert the standard angle <a> (radians) to a 'viz' angle in degrees.

    The result is 90 - degrees(a), wrapped with % 360, so a standard angle
    of pi/2 becomes 0 and the rotational direction is reversed.
    """
    shifted = 90.0 - (math.degrees(a))
    return (shifted + 360.0) % 360.0

def vizToRegularAngle(a):
    """Convert the 'viz' angle <a> (degrees) to a standard angle in radians.

    Inverse of regularToVizAngle: radians(90 - a), wrapped with % (2 * pi).
    """
    fullTurn = 2 * math.pi
    return (math.radians(90.0 - a) + fullTurn) % fullTurn

def absmax(*args):
    """Return the argument with the largest absolute value.

    Ties in absolute value go to the larger value, so absmax(-3, 3) == 3.
    Raises ValueError when called without arguments.
    """
    magnitude, value = max((abs(arg), arg) for arg in args)
    return value

def lookToEuler(x, y, z):
    """Convert the direction vector (<x>, <y>, <z>) to Euler angles.

    The direction is first turned into a quaternion with vizmat.LookToQuat,
    which is then converted with vizmat.QuatToEuler.
    """
    rotation = vizmat.LookToQuat(x, y, z)
    return vizmat.QuatToEuler(rotation)


# Unused; appears to be left over from an earlier projection formula.
#APERTURE = 32

# Screen aspect ratio (4:3, presumably width / height).  The space <-> screen
# projection helpers below apply it to the vertical coordinate.
ASPECT_RATIO = 4.0 / 3.0

def spaceToScreenCoords(point, fov):
    """Project a camera-space point onto the image plane of a camera.

    <point> is an (x, y, z) triple, z being the distance in front of the
    camera; <fov> is an angle in degrees.  Returns (rx, ry, rz) with

    rx = x / (z * tan(fov))                 horizontal offset from the centre
    ry = y / (z * tan(fov)) * ASPECT_RATIO  vertical offset from the centre
    rz = z                                  distance from the image plane

    so rx == 1 for a point lying exactly <fov> degrees off the view axis
    horizontally.  This is the inverse of screenToSpaceCoords.  Points at or
    behind the camera (z <= 0) are pushed 'infinitely' far out by a very
    large scale factor instead of raising.
    """
    # Unpack explicitly: tuple parameters in a signature are Python 2 only
    # (PEP 3113).  Positional callers are unaffected.
    x, y, z = point
    rz = z
    if rz <= 0:
        f = 1000000.0  # 'infinite'
    else:
        # Perspective divide: the scale shrinks with distance.  The
        # parentheses matter; 1.0 / tan(fov) * rz would grow with distance.
        f = 1.0 / (math.tan(math.radians(fov)) * rz)
    # Applied for both branches; previously the z <= 0 case returned None.
    rx = x * f
    ry = y * f * ASPECT_RATIO
    return rx, ry, rz

def screenToSpaceCoords(point, fov):
    """Back-project image-plane coordinates into camera space.

    <point> is an (x, y, z) triple: x and y are image-plane offsets, z is the
    distance from the image plane; <fov> is an angle in degrees.  Returns
    (rx, ry, rz) with

    rx = x * z * tan(fov)
    ry = y * z * tan(fov) / ASPECT_RATIO
    rz = z
    """
    # Unpack explicitly: tuple parameters in a signature are Python 2 only
    # (PEP 3113).  Positional callers are unaffected.
    x, y, z = point
    rz = z
    f = math.tan(math.radians(fov)) * rz
    rx = x * f
    ry = y * f / ASPECT_RATIO
    return rx, ry, rz


#######################
###                 ###
###  Bezier curves  ###
###                 ###
#######################

def factorial(x):
    """Return x! for a non-negative integer <x> (0! == 1).

    Computed iteratively, so large arguments do not hit the recursion limit.
    Raises ValueError for negative <x>, which previously recursed forever.
    """
    if x < 0:
        raise ValueError('factorial() is not defined for negative values: %r' % (x,))
    result = 1
    while x > 1:
        result *= x
        x -= 1
    return result

def comb(n, k):
    """Return the binomial coefficient 'n over k' for non-negative integers.

    Returns 0 when k < 0 or k > n.  Uses the multiplicative formula with
    integer floor division, so the result is an exact integer on both
    Python 2 and 3.  (n! / (k! (n-k)!) with '/' is a float on Python 3.)
    """
    if k < 0 or k > n:
        return 0
    k = min(k, n - k)  # C(n, k) == C(n, n - k); fewer iterations
    result = 1
    for i in range(1, k + 1):
        # Invariant: result == C(n - k + i - 1, i - 1) before this step,
        # so the product is divisible by i and // is exact.
        result = result * (n - k + i) // i
    return result

def bincoeff(i, n, t):
    """Return the i-th Bernstein basis polynomial of degree <n> at <t>.

    b(i, n)(t) = C(n, i) * (1 - t)^(n - i) * t^i.  This is the weight of the
    i-th control point of a degree-n Bezier curve at parameter <t>.
    """
    weight = int(comb(n, i))
    weight = weight * (1 - t) ** (n - i)
    return weight * t ** i

def bezierPoint(cp, t):
    """Return the point at parameter <t> on the Bezier curve with control points <cp>.

    Evaluates the Bernstein form: the sum over i of bincoeff(i, n, t) * cp[i],
    where n = len(cp) - 1.  Control points are equal-length sequences of
    numbers; the result is a plain list.  Raises ValueError if <cp> is empty.
    """
    if len(cp) == 0:
        raise ValueError('Cannot evaluate a Bezier curve without control points.')
    n = len(cp) - 1
    # Left-to-right accumulation, the same order as the former reduce();
    # an explicit loop because reduce is not a builtin on Python 3.
    point = vecMul(cp[0], bincoeff(0, n, t))
    for i in range(1, n + 1):
        point = vecAdd(point, vecMul(cp[i], bincoeff(i, n, t)))
    return point

def bezierPointCasteljau(cp, t):
    """Return the point at parameter <t> on the Bezier curve with control
    points <cp>, using De Casteljau's algorithm.

    Each round replaces every pair of neighbouring points p1, p2 with
    p1 + (p2 - p1) * t until a single point remains.  The points must
    therefore support +, - and * with <t> as operators.  Numbers and Vector
    instances work; plain lists do not.  Raises ValueError if <cp> is empty.
    """
    if len(cp) == 0:
        # Guard: an empty list never reduces to a single point.
        raise ValueError('Cannot evaluate a Bezier curve without control points.')
    points = list(cp)
    while len(points) > 1:
        points = [p1 + (p2 - p1) * t for p1, p2 in zip(points, points[1:])]
    return points[0]
    
def bezierLength(cp, n=20):
    """Approximate the length of the Bezier curve with control points <cp>.

    The curve is sampled at the n + 1 parameters t = 0, 1/n, ..., 1 and the
    lengths of the n chords between consecutive samples are summed.  This
    underestimates the true length; the estimate approaches it as <n> grows.
    Raises ValueError if n < 1.
    """
    if n < 1:
        raise ValueError('bezierLength needs n >= 1, got %r.' % (n,))
    samples = [bezierPoint(cp, i / float(n)) for i in range(n + 1)]
    # Sum all n chords; the previous loop stopped one chord short.
    return sum((vecDist(p1, p2) for p1, p2 in zip(samples, samples[1:])), 0.0)
        
def bezier3(start, end, dir):
    """Return the three control points of a quadratic Bezier curve.

    The curve leaves <start> heading along <dir> and ends at <end>.  The
    middle control point lies on the ray from <start> along <dir>, at the
    distance |end - start| / (2 * cos(angle)) where it is equally far from
    <start> and <end>.  Here angle is the angle between <dir> and
    end - start.

    Warning! As that angle approaches 90 degrees the distance grows without
    bound (ZeroDivisionError if the cosine is exactly 0).  Beyond 90 degrees
    the control point lands behind <start>, i.e. the curve cannot turn by
    180 degrees or more.  ZeroDivisionError is also raised when
    <start> == <end> or <dir> has zero length.
    """
    chord = vecSub(end, start)
    chordLength = vecLength(chord)
    chordDir = vecNorm(chord)
    heading = vecNorm(dir)
    cosAngle = vecMul(chordDir, heading)
    reach = chordLength / (cosAngle * 2.0)
    middle = vecAdd(start, vecMul(heading, reach))
    return [start, middle, end]

def bezier4(start, end, startDir, endDir=None):
    """Return the four control points of a cubic Bezier curve from <start> to <end>.

    The curve leaves <start> along <startDir> and arrives at <end> along
    <endDir>.  When <endDir> is omitted it is <startDir> mirrored against the
    chord end - start (via vecMirror).  Each inner control point sits a
    quarter of the chord length from its end point.
    """
    chord = vecSub(end, start)
    if endDir is None:
        endDir = vecMirror(startDir, chord)
    startUnit, endUnit = vecNorm(startDir), vecNorm(endDir)
    handle = vecLength(chord) * 0.25
    outHandle = vecMul(startUnit, handle)
    inHandle = vecMul(endUnit, handle)
    return [start, vecAdd(start, outHandle), vecSub(end, inHandle), end]


#####################
###               ###
###  VECTOR MATH  ###
###               ###
#####################


# functions (faster)

def vecNeg(v):
    """Return a new list holding every component of <v> negated."""
    negated = []
    for component in v:
        negated.append(-component)
    return negated

def vecAdd(v1, v2):
    """Return the component-wise sum of <v1> and <v2> as a new list.

    Raises Exception when the vectors differ in dimension.
    """
    if len(v1) != len(v2):
        raise Exception('Cannot add %d-dimensional vector to %d-dimensional vector.' % (len(v1), len(v2)))
    return [left + right for left, right in zip(v1, v2)]

def vecSub(v1, v2):
    """Subtract <v2> from <v1> component-wise; returns a new list.

    Raises Exception when the vectors differ in dimension.
    """
    if len(v1) != len(v2):
        raise Exception('Cannot substract %d-dimensional vector from %d-dimensional vector' % (len(v2), len(v1)))
    return [left - right for left, right in zip(v1, v2)]
    
def vecMul(v1, v2):
    """Multiply vector <v1> by <v2>.

    If <v2> is a number (anything registered as numbers.Number: int, float,
    Python 2 long, Fraction, ...) the result is a new list with every
    component of <v1> scaled by it.  If <v2> is a list or tuple the result is
    the dot product of <v1> and <v2>.

    Raises Exception when the two vectors differ in dimension, and TypeError
    for any other kind of <v2> instead of silently returning None.
    """
    if isinstance(v2, numbers.Number):
        return [component * v2 for component in v1]
    if isinstance(v2, (list, tuple)):
        if len(v1) != len(v2):
            raise Exception('Cannot multiply %d-dimensional vector with %d-dimensional vector.' % (len(v1), len(v2)))
        return sum([a * b for a, b in zip(v1, v2)])
    raise TypeError('Cannot multiply a vector by %s.' % (type(v2).__name__,))
        
def vecDiv(v, scalar):
    """Divide every component of <v> by the number <scalar>; returns a new list.

    Any numbers.Number is accepted as <scalar>.  Components are multiplied by
    1.0 / scalar, so a zero <scalar> raises ZeroDivisionError.  Raises
    Exception for a non-numeric <scalar>.
    """
    if not isinstance(scalar, numbers.Number):
        raise Exception('Cannot divide a vector by a non-scalar value.')
    inverse = 1.0 / scalar
    return [component * inverse for component in v]

def vecLength(v):
    """Return the Euclidean length (2-norm) of <v> as a float."""
    squares = (component ** 2 for component in v)
    return math.sqrt(sum(squares))

def vecDist(v1, v2):
    """Return the Euclidean distance between the points <v1> and <v2>.

    Raises Exception when the points differ in dimension.
    """
    if len(v1) != len(v2):
        raise Exception('Cannot compute distanceTo between %d-dimensional vector and %d-dimensional vector.' % (len(v1), len(v2)))
    difference = vecSub(v1, v2)
    return vecLength(difference)
    
def vecNorm(v):
    """Return <v> scaled to unit length, as a new list.

    Raises ZeroDivisionError for a zero-length vector.
    """
    length = vecLength(v)
    return vecDiv(v, length)

def vecCenter(v, p1, p2):
    """Return the centroid (mean of the corners) of the triangle <v>, <p1>, <p2>.

    Raises Exception unless all three points have the same dimension.
    """
    if not len(v) == len(p1) == len(p2):
        raise Exception('All vectors should be of the same length.')
    total = vecAdd(v, vecAdd(p1, p2))
    return vecDiv(total, 3.0)
    
def vecProjectToLine(v, a, b, segment=False):
    """Return the point closest to <v> on the line through <a> and <b>.

    The projection is a + u * (b - a) with
        u = ((v - a) . (b - a)) / |b - a|^2.
    By default the line is infinite and u is not limited.  Pass segment=True
    to clamp u to [0, 1] and get the closest point on the segment [a, b].
    If <a> and <b> coincide, <a> itself is returned.

    Raises Exception when the three points differ in dimension.
    """
    if not len(v) == len(a) == len(b):
        raise Exception('Cannot intersect %d-dimensional vector with line trough vectors with dimension %d and %d.' % (len(v), len(a), len(b)))
    d = [eb - ea for ea, eb in zip(a, b)]
    # Divide by the squared length of b - a in any dimension; the length
    # raised to the dimension is only correct in 2D.
    dd = sum([ed * ed for ed in d])
    if dd == 0:
        return a
    # float() prevents Python 2 floor division for integer input.
    u = sum([(es - ea) * ed for ea, ed, es in zip(a, d, v)]) / float(dd)
    if segment:
        u = min(max(u, 0.0), 1.0)
    return [ea + u * ed for ea, ed in zip(a, d)]

def vecAngleBetween(v1, v2, units='radians'):
    """Return the unsigned angle between <v1> and <v2>.

    The angle is in [0, pi] radians, or in degrees when units == 'degrees'.
    Raises ZeroDivisionError if either vector has zero length, and Exception
    when the dimensions differ.
    """
    if len(v1) != len(v2):
        raise Exception('Cannot compute angle between %d-dimensional vector and %d-dimensional vector.' % (len(v1), len(v2)))
    cosine = vecMul(vecNorm(v1), vecNorm(v2))
    # Rounding can push the dot product of two unit vectors slightly outside
    # [-1, 1] (e.g. for parallel vectors), where math.acos raises ValueError.
    # Comparisons (rather than min/max) let a NaN pass through unchanged.
    if cosine > 1.0:
        cosine = 1.0
    elif cosine < -1.0:
        cosine = -1.0
    angle = math.acos(cosine)
    if units == 'degrees':
        return math.degrees(angle)
    return angle

def vec2FromAngle(angle=None, vizAngle=None):
    """Return the unit 2D vector [cos(a), sin(a)] for a given direction.

    Give either <angle>, a standard angle in radians, or <vizAngle>, a 'viz'
    angle in degrees that is converted with vizToRegularAngle.  When both
    are given, <angle> wins.  Raises TypeError (a subclass of Exception)
    when neither is given.
    """
    if angle is not None:
        a = angle
    elif vizAngle is not None:
        a = vizToRegularAngle(vizAngle)
    else:
        raise TypeError('vec2FromAngle() needs either angle or vizAngle.')
    return [math.cos(a), math.sin(a)]

def vec2Angle(v, vizAngle=False):
    """Return the absolute direction of the 2D vector <v>.

    By default this is the standard angle of <v> in radians (counter-clockwise
    from the +x axis with y up, wrapped to [-pi, pi)), computed with
    vec2AngleBetween(..., relative=True).  With a truthy <vizAngle> it is
    passed through regularToVizAngle and returned in 'viz' degrees.
    """
    angle = vec2AngleBetween(v, [1.0, 0.0], relative=True)
    return regularToVizAngle(angle) if vizAngle else angle

def vec2AngleBetween(v1, v2, units='radians', relative=False):
    """Angle between the 2D vectors <v1> and <v2>.

    relative=False: the unsigned angle in [0, pi] (via vecAngleBetween).
    relative=True:  a signed angle wrapped to [-pi, pi), computed as
        atan2(*v2) - atan2(*v1).  Unpacking passes the components as
        atan2(x, y), which measures each direction from the +y axis towards
        +x.  The result is positive when <v2> lies clockwise of <v1> in a
        frame with +x right and +y up.

    Returns degrees when units == 'degrees', radians otherwise.  Raises
    Exception when the vectors differ in dimension.
    """
    if len(v1) != len(v2):
        raise Exception('Cannot compute angle between %d-dimensional vector and %d-dimensional vector.' % (len(v1), len(v2)))

    if not relative:
        # angle is always positive
        angle = vecAngleBetween(v1, v2)
    else:
        # direction matters; fold the difference into -pi <= angle < pi
        twoPi = math.pi * 2
        angle = math.atan2(*v2) - math.atan2(*v1)
        angle = (angle + math.pi) % twoPi - math.pi

    return math.degrees(angle) if units == 'degrees' else angle

def vec3Cross(v1, v2):
    """Return the cross product v1 x v2 of two 3D vectors as a new list.

    Both arguments must unpack into exactly three components (ValueError
    otherwise).
    """
    ax, ay, az = v1
    bx, by, bz = v2
    cx = ay * bz - az * by
    cy = az * bx - ax * bz
    cz = ax * by - ay * bx
    return [cx, cy, cz]

def vecMirror(v, other):
    """Mirror the direction of <v> against the line spanned by <other>.

    Returns a normalized vector: v / |v| reflected in the axis <other>.
    Supports 2D and 3D vectors.  In 3D, when <v> is parallel to a non-zero
    <other> the mirror plane is undefined.  The reflection is then <v>
    itself, and its normalization is returned instead of dividing by zero.

    Raises Exception for mismatched or unsupported dimensions, and
    ZeroDivisionError when <v> or <other> has zero length.
    """
    if len(v) != len(other):
        raise Exception('Cannot mirror %d-dimensional vector to %d-dimensional vector.' % (len(v), len(other)))
    # derive a normal n of <other> lying in the plane of <v> and <other>
    if len(v) == 2:
        [x, y] = other
        n = [-y, x]
    elif len(v) == 3:
        h = vec3Cross(v, other)
        if not any(h) and any(other):
            # v is parallel to other: reflecting across other keeps it.
            return vecNorm(v)
        n = vec3Cross(h, other)
    else:
        raise Exception('Cannot mirror %d-dimensional vectors.' % (len(v),))
    nv = vecNorm(v)
    nn = vecNorm(n)
    # reflect nv in the hyperplane orthogonal to nn: nv - 2 (nv . nn) nn
    return vecAdd(nv, vecMul(nn, vecMul(vecNeg(nv), nn) * 2))


# classes (syntax goodness)

class Vector(list):
    """Wraps vector functionality in a class for nice syntax with operators and such.

    A Vector is a list of numbers.  Operators that produce a vector return a
    new instance of the same class:

        v + w, v - w   component-wise (a plain list on the left works too)
        -v             negation
        v * s, s * v   scaling by a number s
        v * w          dot product with a list/tuple w (returns a number)
        v / s          division by a number s (classic and true division)

    The augmented forms (+=, -=, *=, /=) bind the name to a new object; they
    do not modify the vector in place.
    """

    def __init__(self, data=None, size=None):
        """Create a vector from the sequence <data>.

        Without <data> the vector holds <size> zeros (none if <size> is None).
        With <data>, <size> defaults to len(data); an explicit <size> that
        differs from len(data) raises Exception.
        """
        if data is None:
            list.__init__(self, [0] * (size or 0))
        elif size is None or len(data) == size:
            list.__init__(self, data)
        else:
            raise Exception('%s cannot be used for %d-dimensional data.' % (self.__class__.__name__, len(data)))

    def set(self, other):
        """Copy the components of <other> into this vector, in place.

        Only happens when <other> is of exactly the same class; any other
        argument is silently ignored.
        """
        if self.__class__ is other.__class__:
            list.__init__(self, other)

    def __neg__(self):
        """Negation, overloads the unary - operator."""
        return self.__class__(vecNeg(self))

    def __add__(self, other):
        """Add two vectors, overloads + operator."""
        v = vecAdd(self, other)
        return self.__class__(v, len(v))

    def __radd__(self, other):
        """<other> + self for a plain list <other>.

        Without this, list + Vector would concatenate the two lists.
        """
        v = vecAdd(other, self)
        return self.__class__(v, len(v))

    def __iadd__(self, other):
        return self + other

    def __sub__(self, other):
        """Subtract <other> from self, overloads - operator."""
        v = vecSub(self, other)
        return self.__class__(v, len(v))

    def __rsub__(self, other):
        """<other> - self for a plain list <other>."""
        v = vecSub(other, self)
        return self.__class__(v, len(v))

    def __isub__(self, other):
        return self - other

    def __mul__(self, other):
        """Scale by a number (returns a vector) or take the dot product with
        a list/tuple (returns a number); overloads * operator.

        Returns NotImplemented for anything else, so Python raises TypeError
        instead of the product silently being None.
        """
        if isinstance(other, numbers.Number):
            return self.__class__([component * other for component in self])
        if isinstance(other, (list, tuple)):
            return vecMul(self, list(other))
        return NotImplemented

    # Scaling and the dot product both commute.  Without this, number * vector
    # falls back to list repetition (2 * [1, 2] == [1, 2, 1, 2]).
    __rmul__ = __mul__

    def __imul__(self, other):
        return self * other

    def __div__(self, scalar):
        """Divide <self> by <scalar>, overloads / operator."""
        return self.__class__(vecDiv(self, scalar))

    # Python 3, and Python 2 modules using 'from __future__ import division',
    # dispatch / to __truediv__ rather than __div__.
    __truediv__ = __div__

    def __idiv__(self, scalar):
        return self / scalar

    __itruediv__ = __idiv__

    def length(self):
        """Euclidean length of this vector."""
        return vecLength(self)

    def normalized(self):
        """Return a unit-length copy of this vector."""
        return self.__class__(vecNorm(self))

    def normalize(self):
        """Scale this vector to unit length, in place."""
        # Slice assignment mutates this list; 'self /= ...' would only rebind
        # the local name and leave the vector unchanged.
        self[:] = vecNorm(self)

    def distanceTo(self, other):
        """Euclidean distance from self to the point <other>."""
        return vecDist(self, other)

    def projectToLine(self, a, b):
        """Return the projection of self onto the line through <a> and <b>,
        as computed by vecProjectToLine."""
        v = vecProjectToLine(self, a, b)
        return self.__class__(v)

    def center(self, p1, p2):
        """Return center of triangle spanned by <self>, <p1> and <p2>."""
        v = vecCenter(self, p1, p2)
        return self.__class__(v)

    def angleWith(self, other, units='radians'):
        """Unsigned angle between self and <other> (see vecAngleBetween)."""
        return vecAngleBetween(self, other, units)


class Vec2(Vector):
    """Two-dimensional vector.

    Construction:
        Vec2()                       zero vector [0, 0]
        Vec2(x, y)                   from components
        Vec2([x, y]) / Vec2((x, y))  from a list (or Vector) or tuple
        Vec2(angle=a)                unit vector at standard angle a (radians)
        Vec2(vizAngle=d)             unit vector at 'viz' angle d (degrees)

    NOTE(review): Vec2(x) with a number <x> but no <y> silently yields the
    zero vector; kept for compatibility.
    """

    def __init__(self, x=None, y=None, angle=None, vizAngle=None):
        data = None
        if x is None:
            if angle is not None:
                data = vec2FromAngle(angle=angle)
            elif vizAngle is not None:
                data = vec2FromAngle(vizAngle=vizAngle)
        elif isinstance(x, (list, tuple)):
            # Accept tuples as well as lists/Vectors; a tuple used to fall
            # through to the zero vector.
            data = list(x)
        elif y is not None:
            data = [x, y]
        Vector.__init__(self, data, size=2)

    def polarCoords(self, vizAngle=False):
        """Return polar coordinates as (angle, length); see angle()."""
        return self.angle(vizAngle), self.length()

    def angle(self, vizAngle=False):
        """Return the absolute direction angle (see vec2Angle)."""
        return vec2Angle(self, vizAngle)

    def angleWith(self, other, units='radians', relative=False):
        """Angle between self and <other>, optionally signed (see vec2AngleBetween)."""
        return vec2AngleBetween(self, other, units, relative)

        
class Vec3(Vector):
    """Three-dimensional vector.

    Construct with Vec3(x, y, z), or Vec3([x, y, z]) / Vec3((x, y, z)) from a
    list (or Vector) or tuple.  Without arguments, or when <y> or <z> is
    missing for a numeric <x>, the result is the zero vector [0, 0, 0].
    """

    def __init__(self, x=None, y=None, z=None):
        data = None
        if x is not None:
            if isinstance(x, (list, tuple)):
                # Accept tuples as well as lists/Vectors; a tuple used to
                # fall through to the zero vector.
                data = list(x)
            elif y is not None and z is not None:
                data = [x, y, z]
        Vector.__init__(self, data, size=3)

    def __xor__(self, other):
        """Cross product of two vectors, overloads ^ operator.

        Mind the precedence: ^ binds more loosely than + and -, so write
        (a ^ b) + c rather than a ^ b + c.
        """
        return Vec3(vec3Cross(self, other))
        

######################
###                ###
###  Polygon math  ###
###                ###
######################


def pointInPoly(x, y, poly):
    """Return True if the point (<x>, <y>) lies inside the polygon <poly>.

    <poly> is a sequence of (x, y) vertices; the closing edge from the last
    vertex back to the first is implied.  Uses the even-odd (ray casting)
    rule: every edge crossed by a ray from the point towards +x toggles the
    result.  An empty polygon contains nothing.
    """
    if len(poly) == 0:
        return False
    inside = False
    # Start with the implied closing edge (last vertex -> first vertex).
    x1, y1 = poly[-1]
    for x2, y2 in poly:
        # The half-open height test (min, max] keeps vertices on the ray's
        # height from being double-counted and skips horizontal edges.
        if min(y1, y2) < y <= max(y1, y2) and x <= max(x1, x2):
            # y1 != y2 is guaranteed here; float() avoids Python 2 floor
            # division when every coordinate is an int.
            if x1 == x2 or x <= (y - y1) * (x2 - x1) / float(y2 - y1) + x1:
                inside = not inside
        x1, y1 = x2, y2
    return inside



##############
###        ###
###  Main  ###
###        ###
##############

if __name__ == '__main__':
    # Quick smoke demo: scalar interpolation, vector addition, easing curve.
    print(lerp(0.5, 1.5, 1))
    print(Vec3(2, 5) + Vec3())
    print([(t / 10.0, transform(t / 10.0, 'fiso')) for t in range(10)])
