我一直在尝试通过 `lambdify`,在 SymPy 方程中使用 Astropy 单位。以下是我为这个问题提出的解决方案:
from sympy import pi, Symbol, lambdify, sympify, symbols, solve
import astropy.units as u
from astropy.constants import G as astroG
# Symbols for Kepler's third law:  G*(Ms + Mp) / (4*pi**2) = a**3 / P**2
#   Ms - mass of the central star
#   Mp - mass of the planet
#   G  - gravitational constant
#   a  - semi-major axis
#   P  - orbital period
Ms, Mp, G, a, P = symbols('Ms Mp G a P')

# Solve the law for the semi-major axis; SymPy may return several roots.
solutions = solve(G * (Ms + Mp) / (4 * pi**2) - a**3 / P**2, a)
# Inputs for the lambdified solutions, given as astropy Quantities.
# The keyword names must match the SymPy symbol names, because the
# generated function is called with **values.
values = dict(
    P=2.1 * u.yr,
    Ms=1.4 * u.M_sun,
    Mp=0.5 * u.M_jup,
    G=astroG,
)
evalsolutions = []
# Parameters of each generated function. It is called with keywords below,
# so every symbol name must also be a key in `values`.
reqsymbols = [Ms, Mp, P, G]
for sol in solutions:
    # Turn the symbolic root into a numpy-backed function; its inputs are
    # astropy Quantities, so units propagate through the arithmetic.
    f = lambdify(
        reqsymbols,
        sol,
        modules=["numpy"],
    )
    res = f(**values)
    # Keep only real-valued roots. dtype.kind == "c" matches every complex
    # dtype, whereas comparing to the string "complex" only matches complex128.
    if res.dtype.kind != "c":
        evalsolutions.append(res.si.decompose())
print(evalsolutions)
输出结果为:
[<Quantity 2.74467861e+11 m>]
旧解决方案
这是我最初发布的另一个解决方案。它更复杂,事后看来有些大材小用,但它确实展示了如何利用 `lambdify` 的 `modules` 参数,为生成的函数提供自定义函数(下例中的 `unit` 和 `const`)。
from sympy import pi, Symbol, lambdify, sympify, symbols, solve
import astropy.units as u
from astropy.constants import G as astroG
Ms = Symbol('Ms')  # mass of the central star (a plain number, scaled by Munit)
Mp = Symbol('Mp')  # mass of the planet (a plain number, scaled by Munit)
# Units and constants are undefined SymPy function applications such as
# unit(msol) and const(G). The names "unit" and "const" are resolved only when
# the expression is lambdified, through the `modules` mapping.
Munit = sympify('unit(msol)')
G = sympify('const(G)')  # gravitational constant
a = Symbol('a')  # semi-major axis
P = Symbol('P')  # orbital period (a plain number, scaled by Punit)
# Build the period unit the same way as Munit. The previous
# Symbol('unit(yr)') was an ordinary symbol whose *name* merely looked like a
# call, and it only worked because the code printer emitted the name verbatim.
Punit = sympify('unit(yr)')
solutions = solve(
    (G * ((Ms + Mp) * Munit) / (4 * pi**2)) - (a**3/(P * Punit)**2),
    a
)
# Registries for the placeholder functions: each key is the argument name
# used inside unit(...) / const(...) in the SymPy expression.
UNITS = dict(
    msol=u.M_sun,
    au=u.au,
    yr=u.yr,
)
CONSTANTS = dict(G=astroG)
def unitfunc(key):
    """Return 1.0 times the unit registered in UNITS under *key*.

    Keys missing from UNITS fall back to 1, so the result is then 1.0.
    """
    unit = UNITS[key] if key in UNITS else 1
    return 1.0 * unit
def constfunc(key):
    """Return the constant registered in CONSTANTS under *key*, or 1 if absent."""
    if key in CONSTANTS:
        return CONSTANTS[key]
    return 1
# Numeric values at which to evaluate the solutions. Units are not attached
# here; they enter the expression through the unit(...) placeholders.
values = {
    "P": 2.1,
    "Ms": 1.4,
    "Mp": (0.5 * u.M_jup.to("M_sun")),  # 0.5 Jupiter masses expressed in M_sun
}
# Bind each placeholder argument (msol, au, yr, G) to its own name as a
# string, so the generated code ends up calling e.g. unitfunc("yr").
values.update({key: key for key in [*UNITS, *CONSTANTS]})

evalsolutions = []
reqsymbols = [Ms, Mp, P]
reqsymbols.extend(symbols(key) for key in [*UNITS, *CONSTANTS])
for sol in solutions:
    # Map the undefined functions unit()/const() to Python callables; anything
    # else falls through to numpy.
    f = lambdify(
        reqsymbols,
        sol,
        modules=[{"unit": unitfunc, "const": constfunc}, "numpy"],
    )
    res = f(**values)
    # Keep only real-valued roots. dtype.kind == "c" matches every complex
    # dtype, whereas comparing to the string "complex" only matches complex128.
    if res.dtype.kind != "c":
        evalsolutions.append(res.si.decompose())
print(evalsolutions)