import json

## Choose the calibration file: leave exactly one of the assignments below
## uncommented. A relative path is resolved against the current working
## directory, not the directory containing this script.
# fname = "calibration_scalar.json"
# fname = "calibration_linear.json"
fname = "calibration_quadratic.json"

## Basis-function tables, registered by name in `bases` so that the "basis"
## field of a calibration JSON file selects which one to use instead of it
## being hard-coded. Every basis function takes (m1, m2, a1z, a2z). Each
## larger table starts with all the entries of the previous one, so the
## leading terms are shared across tables.
basis_scalar = [
    lambda m1, m2, a1z, a2z: 1.0,  # constant term
]
basis_linear = [
    *basis_scalar,
    lambda m1, m2, a1z, a2z: m1,
    lambda m1, m2, a1z, a2z: m2,
]
basis_quadratic = [
    *basis_linear,
    lambda m1, m2, a1z, a2z: m1*m2,
    lambda m1, m2, a1z, a2z: m1**2,
    lambda m1, m2, a1z, a2z: m2**2,
]
bases = dict(
    scalar=basis_scalar,
    linear=basis_linear,
    quadratic=basis_quadratic,
)

# Load the calibration: "coeffs" holds the coefficients and "basis" names
# which entry of `bases` they multiply, position by position. JSON text is
# UTF-8, so the encoding is explicit rather than locale-dependent.
with open(fname, "r", encoding="utf-8") as calibration_file:
    calibration_info = json.load(calibration_file)

coeffs = calibration_info["coeffs"]
basis_name = calibration_info["basis"]
if basis_name not in bases:
    raise ValueError(
        f"{fname}: unknown basis {basis_name!r}; "
        f"expected one of {sorted(bases)}"
    )
basis_fns = bases[basis_name]

# f() pairs each coefficient with the basis function at the same position,
# so a count mismatch would silently drop terms and give a wrong factor.
# Reject such files up front instead.
if len(coeffs) != len(basis_fns):
    raise ValueError(
        f"{fname}: basis {basis_name!r} has {len(basis_fns)} terms but "
        f"{len(coeffs)} coefficients were given"
    )

def f(m1, m2, a1z, a2z):
    """
    Evaluate the correction factor at (m1, m2, a1z, a2z).

    The factor is the sum of coeffs[i] * basis_fns[i](m1, m2, a1z, a2z),
    taken over the coefficients and basis functions loaded from the
    calibration file. Whenever you evaluate VT(m1, m2, a1z, a2z), multiply
    the result by this factor:

        vt_corrected = VT(m1, m2, a1z, a2z) * f(m1, m2, a1z, a2z)

    All four arguments are forwarded to every basis function, even though
    none of the basis functions defined in this file read a1z or a2z.
    """
    point = (m1, m2, a1z, a2z)
    weighted_terms = [
        coeff * basis_fn(*point)
        for coeff, basis_fn in zip(coeffs, basis_fns)
    ]
    # Keep the builtin sum() so float accumulation matches the original
    # exactly (CPython 3.12+ uses compensated summation for floats).
    return sum(weighted_terms)
