"""
Depth from defocus, following Subbarao (1988),
'Parallel Depth Recovery by Changing Camera Parameters'.
Basic method described in the DFD (depth-from-defocus) book.
"""
import numpy as np
import scipy.fftpack as fft
import sys


def gaussian_window(img, sigma=0.1):
    """
    Separable Gaussian window with the same (rows, cols) shape as img.

    Along an axis of length N the profile is
    exp(-((i - N/2) / (sigma * N))**2) for i = 0..N-1, so sigma is the
    window width relative to the axis length.

    Parameters
    ----------
    img : array
        Only img.shape[0] (rows) and img.shape[1] (columns) are used.
    sigma : float
        Relative width of the window.

    Returns
    -------
    ndarray of shape (img.shape[0], img.shape[1])
    """
    x = np.arange(img.shape[1])
    y = np.arange(img.shape[0])
    xvals = np.exp(-((x - 0.5 * len(x))**2) / ((sigma * len(x))**2))
    yvals = np.exp(-((y - 0.5 * len(y))**2) / ((sigma * len(y))**2))
    # Row profile first, then column profile, so the result lines up with
    # img's (rows, cols) layout and can be multiplied with non-square images.
    return np.outer(yvals, xvals)

def shifted_psd(img):
    """
    Returns the centred (fftshift-ed) magnitude spectrum |F| of a
    Gaussian-windowed, zero-padded copy of img.

    Note: despite the function name, the magnitudes are not squared.

    Each axis is zero-padded by half its own length on both sides, so the
    padded array is about twice the original size along both axes. Keeping
    every second frequency sample then yields an array with img's shape.
    """
    # applying gaussian window
    w_img = img * gaussian_window(img)
    # Pad each axis by half of its own length; padding both axes by the row
    # count would break the shape round-trip below for non-square images.
    rows, cols = w_img.shape
    pad_w_img = np.pad(w_img, ((rows // 2, rows // 2), (cols // 2, cols // 2)), 'constant')

    f_img = fft.fftshift(fft.fft2(pad_w_img))

    magnitudes = np.absolute(f_img)
    # account for zero padding by only returning every second magnitude
    # (this relies on the ~2x padding above)
    return magnitudes[::2, ::2]

def get_weights(P):
    """
    Frequency weights -1 / (omega**2 + nu**2) on the centred frequency grid of P.

    omega are the fftshift-ed np.fft.fftfreq sample frequencies along axis 1
    (columns) and nu those along axis 0 (rows). The result has P's shape.
    The zero-frequency (DC) entry is -inf. Callers filter non-finite values,
    so the divide-by-zero warning is suppressed.
    """
    omega = np.fft.fftshift(np.fft.fftfreq(P.shape[1]))  # column frequencies
    nu = np.fft.fftshift(np.fft.fftfreq(P.shape[0]))     # row frequencies
    with np.errstate(divide='ignore'):
        weights = -1. / (omega[np.newaxis, :]**2 + nu[:, np.newaxis]**2)
    return weights

def get_omega(P):
    """Centred (fftshift-ed) FFT sample frequencies along axis 1 (columns) of P."""
    n_cols = P.shape[1]
    freqs = np.fft.fftfreq(n_cols)
    return np.fft.fftshift(freqs)

def get_fin_wd(im, im2, prms):
    """
    Frequency-weighted log ratio of the spectra of two images.

    Both images go through shifted_psd and then process_P (cut-off with
    prms['p_cutoff']). The first image's spectrum determines the weights.

    Returns
    -------
    (wd, fin_wd) : wd is get_weights(P1) * (log(P1) - log(P2)) per frequency;
        fin_wd is the 1-D array of the finite entries of wd. Entries zeroed by
        the cut-off give non-finite logs and end up excluded from fin_wd.
    """
    spectra = [process_P(shifted_psd(image), prms) for image in (im, im2)]
    log_ratio = np.log(spectra[0]) - np.log(spectra[1])
    wd = get_weights(spectra[0]) * log_ratio
    finite_mask = np.isfinite(wd)
    return wd, wd[finite_mask]


def process_P(P, prms):
    """
    Sets, in place, every entry of P below prms['p_cutoff'] to 0.

    Returns the same (mutated) array object.
    """
    below_cutoff = P < prms['p_cutoff']
    P[below_cutoff] = 0
    return P

def C(img1, img2, prms=None):
    '''
    Estimates the blur-difference constant of two images, using the shifted FFT.

    The finite weighted log-ratio values from get_fin_wd are reduced to the
    mean of the values strictly inside their 25th-75th percentile range. If
    that range is empty, the 25th percentile is used instead. The estimate is
    divided by 2*pi**2, then shifted by prms['addC'] and scaled by
    prms['multC'].

    Parameters
    ----------
    img1, img2 : arrays passed to get_fin_wd.
    prms : dict
        Required despite the None default. Needs 'p_cutoff', 'addC' and 'multC'.

    Returns
    -------
    float : the scaled constant, or NaN when no finite values are available.

    Raises
    ------
    TypeError : if prms is None.
    '''
    if prms is None:
        # Every step below indexes prms; fail with an explicit message.
        raise TypeError("C() requires prms with keys 'p_cutoff', 'addC' and 'multC'")
    _, fin_diff = get_fin_wd(img1, img2, prms)

    if fin_diff.size == 0:
        # Nothing finite to average (np.percentile cannot handle an empty array).
        c = np.nan
    else:
        # scipy version of Fourier transform requires this normalization
        norm = 2 * np.pi**2
        lo, hi = np.percentile(fin_diff, [25, 75])
        inner = fin_diff[(fin_diff > lo) & (fin_diff < hi)]
        if inner.size > 0:
            c = np.mean(inner) / norm
        else:
            # Fallback must be on the same (normalized) scale as the mean above.
            c = lo / norm
    # scaling C with parameters found by least square fitting
    c += prms['addC']
    c *= prms['multC']
    return c

def solve_for_sigma2(a, b, c):
    """
    Real roots of a*x**2 + b*x + c = 0.

    Returns
    -------
    list
        - a != 0 and d = b*b - 4ac >= 0: [(-b + sqrt(d)) / (2a), (-b - sqrt(d)) / (2a)]
          (both entries are equal when d == 0);
        - a != 0 and d < 0: [] (no real root);
        - a == 0 and b != 0: [-c / b] (the equation is linear);
        - a == b == 0: [] (no unique solution).
    """
    if a == 0:
        # Linear (or degenerate) equation: dividing by 2*a would fail.
        if b == 0:
            return []
        return [-c / float(b)]
    d = b*b - 4.*a*c
    if d < 0:
        return []
    root_d = np.sqrt(d)
    return [(-b + root_d) / (2.*a), (-b - root_d) / (2.*a)]

def distance(img1, img2, v1, v2, prms):
    """
    Candidate object distances from two images taken with sensor distances v1 and v2.

    Reads prms['f'], prms['aperture'] and prms['k'] (used as rho), and gets the
    blur constant from C(img1, img2, prms=prms). It then solves
    (alpha**2 - 1) s**2 + 2 alpha beta s + (beta**2 - C) = 0 for s, with
    alpha = v1 / v2 and beta = rho * aperture * v1 * (1/v2 - 1/v1). Each root s
    is mapped to 1 / (1/f - 1/v2 - s / (rho * aperture * v2)).

    Returns
    -------
    list : one distance per root; [nan, nan] when there is no real root.
    """
    focal = prms['f']
    aperture = prms['aperture']
    rho = prms['k']
    blur_const = C(img1, img2, prms=prms)

    alpha = v1 / float(v2)
    beta = rho * aperture * v1 * (1./v2 - 1./v1)
    roots = solve_for_sigma2(alpha**2 - 1, 2 * alpha * beta, beta**2 - blur_const)
    if not roots:
        roots = (np.nan, np.nan)

    def _to_distance(sigma2):
        return 1. / (1./focal - 1./v2 - sigma2/(rho*aperture*v2))

    return [_to_distance(s) for s in roots]

