import json

## Uncomment the one you'd like to use.
# fname = "calibration_scalar.json"
# fname = "calibration_linear.json"
fname = "calibration_quadratic.json"

## Defining the basis functions in a lookup table, so we can use the
## "basis" field in the JSON files to figure out which one to use,
## rather than have to hard-code anything.
##
## Every basis function takes the same four arguments so they can be
## called interchangeably.  Each larger basis extends the previous one,
## so the order of the entries fixes which coefficient multiplies which
## term.
basis_scalar = [
    lambda m1, m2, a1z, a2z: 1.0,
]
basis_linear = basis_scalar + [
    lambda m1, m2, a1z, a2z: m1,
    lambda m1, m2, a1z, a2z: m2,
]
basis_quadratic = basis_linear + [
    lambda m1, m2, a1z, a2z: m1 * m2,
    lambda m1, m2, a1z, a2z: m1 ** 2,
    lambda m1, m2, a1z, a2z: m2 ** 2,
]

## Maps the "basis" string stored in a calibration file to the list of
## basis functions it refers to.
bases = {
    "scalar": basis_scalar,
    "linear": basis_linear,
    "quadratic": basis_quadratic,
}

## Load the calibration once, at import time.
with open(fname, "r", encoding="utf-8") as calibration_file:
    calibration_info = json.load(calibration_file)

coeffs = calibration_info["coeffs"]
## An unrecognised "basis" name raises KeyError here.
basis_fns = bases[calibration_info["basis"]]

## zip() in f() stops at the shorter sequence, so a mismatch between the
## number of coefficients and the number of basis functions would
## silently drop terms.  Fail loudly at load time instead.
if len(coeffs) != len(basis_fns):
    raise ValueError(
        "{!r}: basis {!r} needs {} coefficient(s), but 'coeffs' has {}".format(
            fname, calibration_info["basis"], len(basis_fns), len(coeffs)
        )
    )


def f(m1, m2, a1z, a2z):
    """
    This is the correction factor function.

    Returns the sum over the loaded calibration of
    ``coeff * basis_fn(m1, m2, a1z, a2z)``, pairing each entry of
    ``coeffs`` with the basis function at the same position in
    ``basis_fns``.

    None of the basis functions defined in this file depend on ``a1z``
    or ``a2z``; they are accepted so that ``f`` has the same signature
    as ``VT``.

    Now anytime you evaluate VT(m1, m2, a1z, a2z), be sure to multiply
    by f(m1, m2, a1z, a2z), so

        vt_corrected = VT(m1, m2, a1z, a2z) * f(m1, m2, a1z, a2z)
    """
    return sum(
        c * g(m1, m2, a1z, a2z)
        for c, g in zip(coeffs, basis_fns)
    )