Speeding Up Your Python Codes 1000x!!!
Introduction¶
Welcome to “Code Optimization: Speeding Up Your Python Codes 1000x”!
In the next hours, we will learn tricks to unlock Python performance and increase a sample code’s speed by a factor of 1000! This lab is based on a Hack Arizona tutorial CK gave in 2024. The original lab can be found here.
Here is our plan:
Understand the problem: we will start by outlining the $n$-body problem and how to use the leapfrog algorithm to solve it numerically.
Benchmark and profile: we will learn how to measure performance and identify bottlenecks using tools like timeit and cProfile.
Optimize the code: step through a series of optimizations:
Use list comprehensions and reduce operation counts.
Replace slow operations with faster ones.
Use high-performance libraries such as NumPy.
Explore lower precision where applicable.
Use Google JAX for just-in-time compilation and GPU acceleration.
Apply the integrator: finally, we will use our optimized code to simulate the mesmerizing figure-8 self-gravitating orbit.
By the end of this lab, you will see that Python isn’t just great for rapid prototyping. It can also deliver serious performance when optimized right. Let’s dive in!
The $n$-body Problem and the Leapfrog Algorithm¶
The $n$-body problem involves simulating the motion of many bodies interacting with each other through (usually) gravitational forces.
This problem is physically interesting, as it models the solar system, galaxy dynamics, or even the whole cosmic structure formation!
This problem is also computationally interesting. Each body experiences a force from every other body, so the cost of a direct calculation grows as $n^2$, making the simulation computationally intensive. It is ideal for exploring optimization techniques.
The leapfrog algorithm is a simple, robust numerical integrator that alternates between updating velocities (“kicks”) and positions (“drifts”). It is popular because it conserves energy and momentum well over long simulation periods.
We need to solve Newton’s second law of motion for a collection of bodies. The forces between bodies are computed using Newton’s law of universal gravitation.
Gravitational Force/Acceleration Equation¶
For any two bodies labeled by $i$ and $j$, the direct gravitational force exerted on body $i$ by body $j$ is given by:

$$\mathbf{F}_{ij} = -\frac{G\, m_i m_j\, (\mathbf{r}_i - \mathbf{r}_j)}{|\mathbf{r}_i - \mathbf{r}_j|^3}$$

where:

$G$ is the gravitational constant.
$m_i$ and $m_j$ are the masses of the bodies.
$\mathbf{r}_i$ and $\mathbf{r}_j$ are their position vectors.

Summing the contributions from all other bodies gives the net force (and thus the acceleration) on body $i$:

$$\mathbf{F}_i = \sum_{j \neq i} \mathbf{F}_{ij}$$

Given Newton’s law $\mathbf{F}_i = m_i \mathbf{a}_i$, the “acceleration” applied on body $i$ caused by all other bodies is:

$$\mathbf{a}_i = -G \sum_{j \neq i} \frac{m_j\, (\mathbf{r}_i - \mathbf{r}_j)}{|\mathbf{r}_i - \mathbf{r}_j|^3}$$

We choose the right units so that $G = 1$.
Here is a pure Python implementation of the gravitational acceleration:
def acc1(m, r):
n = len(m)
a = []
for i in range(n):
axi, ayi, azi = 0, 0, 0
for j in range(n):
if j != i:
xi, yi, zi = r[i]
xj, yj, zj = r[j]
axij = - m[j] * (xi - xj) / ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)
ayij = - m[j] * (yi - yj) / ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)
azij = - m[j] * (zi - zj) / ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)
axi += axij
ayi += ayij
azi += azij
a.append((axi, ayi, azi))
    return a

We may try using this function to evaluate the gravitational force between two particles, each with mass $m = 1$, at a distance of 2. We expect the “acceleration” to point along the line of separation, toward each other.
m = [1.0, 1.0]
r0 = [
(+1.0, 0.0, 0.0),
(-1.0, 0.0, 0.0),
]
a0ref = [
(-0.25, 0.0, 0.0),
(+0.25, 0.0, 0.0),
]

a0 = acc1(m, r0)

a0

assert a0 == a0ref

Leapfrog Integrator¶
Our simulation solves Newton’s second law of motion, $\mathbf{F} = m \mathbf{a}$, numerically. The equation tells us that the acceleration of a body is determined by the net force acting on it. In a system where forces are conservative, this law guarantees the conservation of energy and momentum. These are critical properties when simulating long-term dynamics.
However, many standard integration methods tend to drift over time, gradually losing these conservation properties. The leapfrog algorithm is a symplectic integrator designed to address this issue. Its staggered (or “leapfrogging”) updates of velocity and position help preserve energy and momentum much more effectively over extended simulations.
Half-step velocity update (Kick): start by updating the velocity by half a time step using the current acceleration $\mathbf{a}_0 = \mathbf{a}(\mathbf{r}_0)$:

$$\mathbf{v}_{1/2} = \mathbf{v}_0 + \frac{\Delta t}{2}\,\mathbf{a}_0$$

Full-step position update (Drift): update the position for a full time step using the half-stepped velocity:

$$\mathbf{r}_1 = \mathbf{r}_0 + \Delta t\,\mathbf{v}_{1/2}$$

Half-step velocity update (Kick): finally, update the velocity by another half time step using the new acceleration $\mathbf{a}_1 = \mathbf{a}(\mathbf{r}_1)$ computed from the updated positions:

$$\mathbf{v}_1 = \mathbf{v}_{1/2} + \frac{\Delta t}{2}\,\mathbf{a}_1$$
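As a warm-up (not part of the original lab), the kick-drift-kick pattern above can be sketched for a 1D harmonic oscillator with acceleration $a = -x$, where the conserved energy $E = \tfrac{1}{2}(v^2 + x^2)$ is easy to monitor. The function name `leapfrog_sho` is made up for this illustration:

```python
def leapfrog_sho(x, v, dt, nsteps):
    """Kick-drift-kick leapfrog for a = -x (unit-frequency oscillator)."""
    for _ in range(nsteps):
        v += 0.5 * dt * (-x)   # half kick with the current acceleration
        x += dt * v            # full drift with the half-stepped velocity
        v += 0.5 * dt * (-x)   # half kick with the new acceleration
    return x, v

x, v = 1.0, 0.0
e0 = 0.5 * (v * v + x * x)
x, v = leapfrog_sho(x, v, 0.01, 10_000)
e1 = 0.5 * (v * v + x * x)
print(abs(e1 - e0))  # bounded at the dt**2 level, not drifting
```

Even after 10,000 steps the energy error stays bounded, which is exactly the symplectic behavior described above.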
Here is a pure Python implementation of the leapfrog integrator:
def leapfrog1(m, r0, v0, dt, acc=acc1):
n = len(m)
# vh = v0 + 0.5 * dt * a0
a0 = acc(m, r0)
vh = []
for i in range(n):
a0xi, a0yi, a0zi = a0[i]
v0xi, v0yi, v0zi = v0[i]
vhxi = v0xi + 0.5 * dt * a0xi
vhyi = v0yi + 0.5 * dt * a0yi
vhzi = v0zi + 0.5 * dt * a0zi
vh.append((vhxi, vhyi, vhzi))
# r1 = r0 + dt * vh
r1 = []
for i in range(n):
vhxi, vhyi, vhzi = vh[i]
x0i, y0i, z0i = r0[i]
x1i = x0i + dt * vhxi
y1i = y0i + dt * vhyi
z1i = z0i + dt * vhzi
r1.append((x1i, y1i, z1i))
# v1 = vh + 0.5 * dt * a1
a1 = acc(m, r1)
v1 = []
for i in range(n):
a1xi, a1yi, a1zi = a1[i]
vhxi, vhyi, vhzi = vh[i]
v1xi = vhxi + 0.5 * dt * a1xi
v1yi = vhyi + 0.5 * dt * a1yi
v1zi = vhzi + 0.5 * dt * a1zi
v1.append((v1xi, v1yi, v1zi))
    return r1, v1

We may try using this function to evolve the two particles under gravitational force, with some initial velocities.
v0 = [
(0.0, +0.5, 0.0),
(0.0, -0.5, 0.0),
]
v1ref = [
(-0.024984345739803245, +0.4993750014648409, 0.0),
(+0.024984345739803245, -0.4993750014648409, 0.0),
]

r1, v1 = leapfrog1(m, r0, v0, 0.1)

v1

assert v1 == v1ref

Test the Integrator¶
We test the integrator by evolving the two particles for a time $T = 2\pi$ using $N = 64$ steps.
from math import pi
N = 64
T = 2 * pi
dt = T / N
R = [r0]
V = [v0]
for _ in range(N):
r, v = leapfrog1(m, R[-1], V[-1], dt)
R.append(r)
    V.append(v)

Although our numerical scheme is implemented in pure Python, it is
still handy to use numpy to slice through the data ...
import numpy as np
R = np.array(R)
X = R[:,:,0]
Y = R[:,:,1]

and plot the result...
from matplotlib import pyplot as plt
plt.plot(X, Y, '.-')
plt.gca().set_aspect('equal')

In addition, we may create an animation.
from matplotlib.animation import ArtistAnimation
from IPython.display import HTML
from tqdm import tqdm
def animate(X, Y, ntail=10):
fig, ax = plt.subplots(1, 1, figsize=(5,5))
ax.set_xlabel('x')
ax.set_ylabel('y')
frames = []
for i in tqdm(range(len(X))):
b,e = max(0, i-ntail), i+1
ax.set_prop_cycle(None)
f = ax.plot(X[b:e,:], Y[b:e,:], '.-', animated=True)
frames.append(f)
anim = ArtistAnimation(fig, frames, interval=50)
plt.close()
    return anim

anim = animate(X, Y)
HTML(anim.to_html5_video()) # display animation
# anim.save('orbits.mp4') # save animation

Benchmarking and Profiling Your Code¶
Before diving into optimizations, it is essential to understand where our code spends most of its time. By benchmarking and profiling, we can pinpoint performance bottlenecks and measure improvements after optimization.
Benchmarking vs. Profiling¶
Benchmarking: measures the overall runtime of your code. Tools like Python’s timeit module run your code multiple times to provide an accurate average runtime. This helps in comparing the performance before and after optimizations.

Profiling: provides detailed insights into which parts of your code are consuming the most time. For example, cProfile generates reports showing function call times and frequencies. Note that cProfile typically runs the code once, so its focus is on identifying hotspots rather than providing averaged timings.
Quick Benchmark Example¶
timeit is a module designed for benchmarking by executing code
multiple times.
It is excellent for obtaining reliable runtime measurements.
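Outside of a notebook (where the %timeit magic is used below), the timeit module can be called directly. This is a small sketch, not part of the original lab; `square_all` is a made-up example function:

```python
import timeit

# Time a small expression; `number` controls how many executions are timed.
t1 = timeit.timeit("sum(x * x for x in range(1000))", number=1000)
print(f"{t1 / 1000 * 1e6:.1f} us per loop")

# timeit also accepts a callable directly, which avoids quoting code as strings.
def square_all(n):
    return [x * x for x in range(n)]

t2 = timeit.timeit(lambda: square_all(1000), number=1000)
print(f"{t2 / 1000 * 1e6:.1f} us per loop")
```

Dividing the total time by `number` gives the average time per call, which is what %timeit reports.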
n = 1000
m = np.random.lognormal(size=n).tolist()
r0 = np.random.normal(size=(n, 3)).tolist()
v0 = np.random.normal(size=(n, 3)).tolist()

%timeit r1, v1 = leapfrog1(m, r0, v0, dt)

It takes about 1.31 seconds on my laptop to run a single step of leapfrog for $n = 1000$ bodies. Your mileage may vary. But this is pretty slow by today’s standards.
Quick Profiling Example¶
Here is a snippet using cProfile to profile our leapfrog stepper:
import cProfile
cProfile.run("r1, v1 = leapfrog1(m, r0, v0, dt)")

From cProfile’s result, the most used function is method 'append' of 'list' objects.
This is not surprising given we’ve been using for-loops and append,
e.g.,
v1 = []
for i in range(n):
...
    v1.append(...)

in both acc1() and leapfrog1().
This is neither pythonic nor efficient.
Optimization 1: Use List Comprehension Over For Loop¶
Python’s list comprehensions are a concise way to create lists by
iterating over an iterable and applying an expression---all in a
single, compact line.
Internally, they are optimized in C, making them generally faster
than a standard Python for-loop.
iterable = range(10)
# Standard python for-loop
l1 = []
for x in iterable:
l1.append(x * 2)
# Using a list comprehension
l2 = [x * 2 for x in iterable]
# Compare results
print(l1)
print(l2)

We may use list comprehension to rewrite our leapfrog algorithm:
# O1. Use List Comprehension Over For Loop
def leapfrog2(m, r0, v0, dt, acc=acc1):
n = len(m)
# vh = v0 + 0.5 * dt * a0
a0 = acc(m, r0)
vh = [[v0[i][k] + 0.5 * dt * a0[i][k]
for k in range(3)]
for i in range(n)]
# r1 = r0 + dt * vh
# TODO: use list comprehension to rewrite the remaining of our leapfrog algorithm
# v1 = vh + 0.5 * dt * a1
# TODO: use list comprehension to rewrite the remaining of our leapfrog algorithm
    # TODO: return ...

Begin of reference code¶
# O1. Use List Comprehension Over For Loop
def leapfrog2(m, r0, v0, dt, acc=acc1):
n = len(m)
# vh = v0 + 0.5 * dt * a0
a0 = acc(m, r0)
vh = [[v0xi + 0.5 * dt * a0xi
for v0xi, a0xi in zip(v0i, a0i)]
for v0i, a0i in zip(v0, a0 )]
# r1 = r0 + dt * vh
r1 = [[x0i + dt * vhxi
for x0i, vhxi in zip(r0i, vhi)]
for r0i, vhi in zip(r0, vh )]
# v1 = vh + 0.5 * dt * a1
a1 = acc(m, r1)
v1 = [[vhxi + 0.5 * dt * a1xi
for vhxi, a1xi in zip(vhi, a1i)]
for vhi, a1i, in zip(vh, a1 )]
    return r1, v1

End of reference code¶
%timeit r1, v1 = leapfrog2(m, r0, v0, dt)

cProfile.run("r1, v1 = leapfrog2(m, r0, v0, dt)")

Depending on your Python version, you may see a slight performance increase from timeit and a decrease in the number of calls to append. It takes about 1.29 seconds on my laptop to run a single step of leapfrog for $n = 1000$ bodies.
Optimization 2: Reduce Operation Count¶
Reducing the operation count means cutting down on unnecessary calculations and redundant function calls. When operations are executed millions of times---as in inner loops or simulations---even small optimizations can yield significant speed improvements.
For example, precomputing constant values or combining multiple arithmetic steps into one reduces repetitive work. In the context of the $n$-body problem, calculate invariant quantities once outside of critical loops instead of recalculating them every time. This streamlined approach not only speeds up your code but also helps enable further optimizations like vectorization and JIT compilation.
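As a toy illustration of hoisting an invariant out of a loop (a sketch, not part of the original lab; `slow` and `fast` are made-up names), compare computing a constant factor inside versus outside the loop:

```python
from math import sqrt

def slow(values, dt):
    out = []
    for v in values:
        # 0.5 * dt * sqrt(2.0) is loop-invariant but recomputed every pass
        out.append(0.5 * dt * sqrt(2.0) * v)
    return out

def fast(values, dt):
    c = 0.5 * dt * sqrt(2.0)  # invariant computed once, outside the loop
    return [c * v for v in values]

vals = list(range(1000))
assert slow(vals, 0.1) == fast(vals, 0.1)  # identical results, less work
```

The two functions perform the same arithmetic in the same order, so the results match exactly; only the number of operations differs.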
Recall the benchmark using leapfrog2() with acc1().
%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc1)We noticed that the computation of $\mathbf{r}_{ij}^3 = |\mathbf{r}_i
\mathbf{r}j|^3\mathbf{x}{ij}$. Instead of recomputing it, we can simply cache it in a variable
rrr.
# O2a. "Cache" r^3
def acc2(m, r):
n = len(m)
a = []
for i in range(n):
axi, ayi, azi = 0, 0, 0
for j in range(n):
if j != i:
xi, yi, zi = r[i]
xj, yj, zj = r[j]
# TODO: "Cache" r^3 below
axij = - m[j] * (xi - xj) / ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)
ayij = - m[j] * (yi - yj) / ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)
azij = - m[j] * (zi - zj) / ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)
axi += axij
ayi += ayij
azi += azij
a.append((axi, ayi, azi))
    return a

Begin of reference code¶
# O2a. "Cache" r^3
def acc2(m, r):
n = len(m)
a = []
for i in range(n):
axi, ayi, azi = 0, 0, 0
for j in range(n):
if j != i:
xi, yi, zi = r[i]
xj, yj, zj = r[j]
# "Cache" r^3
rrr = ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)
axij = - m[j] * (xi - xj) / rrr
ayij = - m[j] * (yi - yj) / rrr
azij = - m[j] * (zi - zj) / rrr
axi += axij
ayi += ayij
azi += azij
a.append((axi, ayi, azi))
    return a

End of reference code¶
%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc2)

This reduces the benchmark time by about 45% already!
But we don’t have to stop here.
The different components of $\mathbf{r}_i - \mathbf{r}_j$ can be cached as dx, dy, dz, too.
# O2b. "Cache" the components of dr = r_ij = ri - rj
def acc3(m, r):
n = len(m)
a = []
for i in range(n):
axi, ayi, azi = 0, 0, 0
for j in range(n):
if j != i:
xi, yi, zi = r[i]
xj, yj, zj = r[j]
# TODO: "Cache" the components of dr = r_ij = ri - rj
rrr = ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)
axij = - m[j] * (xi - xj) / rrr
ayij = - m[j] * (yi - yj) / rrr
azij = - m[j] * (zi - zj) / rrr
axi += axij
ayi += ayij
azi += azij
a.append((axi, ayi, azi))
    return a

Begin of reference code¶
# O2b. "Cache" the components of dr = r_ij = ri - rj
def acc3(m, r):
n = len(m)
a = []
for i in range(n):
axi, ayi, azi = 0, 0, 0
for j in range(n):
if j != i:
xi, yi, zi = r[i]
xj, yj, zj = r[j]
# "Cache" the components of dr = r_ij = ri - rj
dx = xi - xj
dy = yi - yj
dz = zi - zj
rrr = (dx**2 + dy**2 + dz**2)**(3/2)
axij = - m[j] * dx / rrr
ayij = - m[j] * dy / rrr
azij = - m[j] * dz / rrr
axi += axij
ayi += ayij
azi += azij
a.append((axi, ayi, azi))
    return a

End of reference code¶
%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc3)

Similarly, we can cache $-m_j / r_{ij}^3$.
# O2c. "cache" -m_j / r_{ij}^3
def acc4(m, r):
n = len(m)
a = []
for i in range(n):
axi, ayi, azi = 0, 0, 0
for j in range(n):
if j != i:
xi, yi, zi = r[i]
xj, yj, zj = r[j]
dx = xi - xj
dy = yi - yj
dz = zi - zj
rrr = (dx**2 + dy**2 + dz**2)**(3/2)
# TODO: "cache" -m_j / r_{ij}^3
axij = - m[j] * dx / rrr
ayij = - m[j] * dy / rrr
azij = - m[j] * dz / rrr
axi += axij
ayi += ayij
azi += azij
a.append((axi, ayi, azi))
    return a

Begin of reference code¶
# O2c. "cache" -m_j / r_{ij}^3
def acc4(m, r):
n = len(m)
a = []
for i in range(n):
axi, ayi, azi = 0, 0, 0
for j in range(n):
if j != i:
xi, yi, zi = r[i]
xj, yj, zj = r[j]
dx = xi - xj
dy = yi - yj
dz = zi - zj
rrr = (dx**2 + dy**2 + dz**2)**(3/2)
f = - m[j] / rrr # "cache" -m_j / r_{ij}^3
axi += f * dx
ayi += f * dy
azi += f * dz
a.append((axi, ayi, azi))
    return a

End of reference code¶
%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc4)

Finally, we notice the symmetry $\mathbf{F}_{ij} = -\mathbf{F}_{ji}$. In principle, we only need to compute the interaction for $j < i$. However, this requires us to pre-allocate a list-of-lists. Just creating the list-of-lists and using it to keep track of the acceleration actually increases the benchmark time.
# O2d. Use list-of-list to keep track of the acceleration
def acc5(m, r):
n = len(m)
a = [] # TODO: create a list-of-list
for i in range(n):
axi, ayi, azi = 0, 0, 0
for j in range(n):
if j != i:
xi, yi, zi = r[i]
xj, yj, zj = r[j]
dx = xi - xj
dy = yi - yj
dz = zi - zj
rrr = (dx**2 + dy**2 + dz**2)**(3/2)
f = - m[j] / rrr
# TODO: Use the list-of-list to keep track of the acceleration
axi += f * dx
ayi += f * dy
azi += f * dz
a.append((axi, ayi, azi))
    return a

Begin of reference code¶
# O2d. Use list-of-list to keep track of the acceleration
def acc5(m, r):
n = len(m)
    a = [[0.0] * 3 for _ in range(n)]  # create a list-of-lists (a comprehension avoids n references to one shared inner list)
for i in range(n):
for j in range(n):
if j != i:
xi, yi, zi = r[i]
xj, yj, zj = r[j]
dx = xi - xj
dy = yi - yj
dz = zi - zj
rrr = (dx**2 + dy**2 + dz**2)**(3/2)
f = - m[j] / rrr
# Use the list-of-list to keep track of the acceleration
a[i][0] += f * dx
a[i][1] += f * dy
a[i][2] += f * dz
    return a

End of reference code¶
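One pitfall worth noting when pre-allocating: `[[0]*3]*n` replicates references to a single inner list, so updating one row would silently update every row. A list comprehension creates independent rows. A quick demonstration (not part of the original lab):

```python
n = 3
aliased = [[0] * 3] * n                       # n references to the SAME inner list
independent = [[0.0] * 3 for _ in range(n)]   # n distinct inner lists

aliased[0][0] = 1.0
independent[0][0] = 1.0

print(aliased)      # every "row" changed, because they are all one list
print(independent)  # only the first row changed
```

This is why the accumulator should be built with a comprehension rather than list multiplication.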
%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc5)

But once we have the list-of-lists, we may change the upper bound of the inner loop to cut the computation almost in half. This is the fastest code so far.
# O2e. Take advantage of the symmetry of a_ij
def acc6(m, r):
n = len(m)
    a = [[0.0] * 3 for _ in range(n)]
for i in range(n):
for j in range(n): # TODO: adjust the upper bound of the inner loop
if j != i:
xi, yi, zi = r[i]
xj, yj, zj = r[j]
dx = xi - xj
dy = yi - yj
dz = zi - zj
rrr = (dx**2 + dy**2 + dz**2)**(3/2)
f = - m[j] / rrr
a[i][0] += f * dx
a[i][1] += f * dy
a[i][2] += f * dz
# TODO: Account the acceleration to the j-th body.
    return a

Begin of reference code¶
# O2e. Take advantage of the symmetry of a_ij
def acc6(m, r):
n = len(m)
    a = [[0.0] * 3 for _ in range(n)]
for i in range(n):
for j in range(i): # adjust the upper bound of the inner loop
xi, yi, zi = r[i]
xj, yj, zj = r[j]
dx = xi - xj
dy = yi - yj
dz = zi - zj
rrr = (dx**2 + dy**2 + dz**2)**(3/2)
fi = m[i] / rrr
fj = - m[j] / rrr
a[i][0] += fj * dx
a[i][1] += fj * dy
a[i][2] += fj * dz
# Account the acceleration to the j-th body.
a[j][0] += fi * dx
a[j][1] += fi * dy
a[j][2] += fi * dz
    return a

End of reference code¶
%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc6)

However, acc?() is not the only function we can optimize to reduce
operation count.
If we study leapfrog?() carefully, the acceleration a in the
second “kick” calculation can actually be reused by the first “kick”
of the next step.
This requires modifying the function prototype a bit.
# O2f. Reuse computation
def leapfrog3(m, r0, v0, a0, dt, acc=acc6):
n = len(m)
# vh = v0 + 0.5 * dt * a0
a0 = acc(m, r0) # TODO: we comment this out, and reuse the acceleration computed in the second "kick" of the previous step
vh = [[v0xi + 0.5 * dt * a0xi
for v0xi, a0xi in zip(v0i, a0i)]
for v0i, a0i in zip(v0, a0 )]
# r1 = r0 + dt * vh
r1 = [[x0i + dt * vhxi
for x0i, vhxi in zip(r0i, vhi)]
for r0i, vhi in zip(r0, vh )]
# v1 = vh + 0.5 * dt * a1
a1 = acc(m, r1)
v1 = [[vhxi + 0.5 * dt * a1xi
for vhxi, a1xi in zip(vhi, a1i)]
for vhi, a1i, in zip(vh, a1 )]
    return r1, v1 # TODO: to reuse the acceleration computed in the second "kick", let's return it.

Begin of reference code¶
# O2f. Reuse computation
def leapfrog3(m, r0, v0, a0, dt, acc=acc6):
n = len(m)
# vh = v0 + 0.5 * dt * a0
# a0 = acc(m, r0) <--- we comment this out, and reuse the acceleration computed in the second "kick" of the previous step
vh = [[v0xi + 0.5 * dt * a0xi
for v0xi, a0xi in zip(v0i, a0i)]
for v0i, a0i in zip(v0, a0 )]
# r1 = r0 + dt * vh
r1 = [[x0i + dt * vhxi
for x0i, vhxi in zip(r0i, vhi)]
for r0i, vhi in zip(r0, vh )]
# v1 = vh + 0.5 * dt * a1
a1 = acc(m, r1)
v1 = [[vhxi + 0.5 * dt * a1xi
for vhxi, a1xi in zip(vhi, a1i)]
for vhi, a1i, in zip(vh, a1 )]
    return r1, v1, a1 # <--- to reuse the acceleration computed in the second "kick", let's return it.

End of reference code¶
# Precompute `a0` so we can pass it to `leapfrog3()`
a0 = acc4(m, r0)

%timeit r1, v1, a1 = leapfrog3(m, r0, v0, a0, dt, acc=acc4)

%timeit r1, v1, a1 = leapfrog3(m, r0, v0, a0, dt, acc=acc6)

Remarkably, it takes about 241 ms on my laptop to run a single step of leapfrog for $n = 1000$ bodies. This is almost a 92% reduction in benchmark time!!!
Optimization 3: Use Fast Operations¶
Certain operations in Python can be surprisingly slow.
For example, the power operator (**) often calls C’s pow() function,
which for non-integer or variable exponents is typically implemented
as:

$$x^y = \exp(y \ln x)$$

This method involves calculating a logarithm, a multiplication, and an exponential---operations that are much slower than simple multiplication.
For instance, if you need to square a number, it’s much faster to
write x*x rather than x**2.
When the exponent is a known small integer, manually multiplying the
base is usually the quickest route.
By choosing fast operations over more generic ones, you can shave off microseconds in code that runs millions of times, ultimately contributing a measurable performance boost.
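A quick micro-benchmark makes the point concrete. This is a sketch (not part of the original lab), and the exact ratio varies by Python version and hardware, so no particular speedup is asserted here:

```python
import timeit

# Time the generic power operator against plain multiplication.
t_pow = timeit.timeit("x**2", setup="x = 1.5", number=1_000_000)
t_mul = timeit.timeit("x*x", setup="x = 1.5", number=1_000_000)
print(f"x**2: {t_pow*1e3:.1f} ms   x*x: {t_mul*1e3:.1f} ms")

# Both forms give identical results; only the speed differs.
assert 1.5**2 == 1.5 * 1.5
```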
# O3. Use fast operations
def acc7(m, r): # this is the same as acc4()
n = len(m)
a = []
for i in range(n):
axi, ayi, azi = 0, 0, 0
for j in range(n):
if j != i:
xi, yi, zi = r[i]
xj, yj, zj = r[j]
dx = xi - xj
dy = yi - yj
dz = zi - zj
rrr = (dx**2 + dy**2 + dz**2)**(3/2) # TODO: replace dx**2 by dx*dx etc
f = - m[j] / rrr
axi += f * dx
ayi += f * dy
azi += f * dz
a.append((axi, ayi, azi))
    return a

Begin of reference code¶
# O3. Use fast operations
def acc7(m, r): # this is a modification of acc4()
n = len(m)
a = []
for i in range(n):
axi, ayi, azi = 0, 0, 0
for j in range(n):
if j != i:
xi, yi, zi = r[i]
xj, yj, zj = r[j]
dx = xi - xj
dy = yi - yj
dz = zi - zj
rrr = (dx*dx + dy*dy + dz*dz)**(3/2) # replace dx**2 by dx*dx etc
f = - m[j] / rrr
axi += f * dx
ayi += f * dy
azi += f * dz
a.append((axi, ayi, azi))
    return a

End of reference code¶
%timeit r1, v1, a1 = leapfrog3(m, r0, v0, a0, dt, acc=acc7)

# O3. Use fast operations
def acc8(m, r): # this is the same as acc6()
n = len(m)
    a = [[0.0] * 3 for _ in range(n)]
for i in range(n):
for j in range(i):
xi, yi, zi = r[i]
xj, yj, zj = r[j]
dx = xi - xj
dy = yi - yj
dz = zi - zj
rrr = (dx**2 + dy**2 + dz**2)**(3/2) # TODO: replace dx**2 by dx*dx etc
fi = m[i] / rrr
fj = - m[j] / rrr
a[i][0] += fj * dx
a[i][1] += fj * dy
a[i][2] += fj * dz
a[j][0] += fi * dx
a[j][1] += fi * dy
a[j][2] += fi * dz
    return a

Begin of reference code¶
# O3. Use fast operations
def acc8(m, r): # this is a modification of acc6()
n = len(m)
    a = [[0.0] * 3 for _ in range(n)]
for i in range(n):
for j in range(i):
xi, yi, zi = r[i]
xj, yj, zj = r[j]
dx = xi - xj
dy = yi - yj
dz = zi - zj
rrr = (dx*dx + dy*dy + dz*dz)**(3/2) # replace dx**2 by dx*dx etc
fi = m[i] / rrr
fj = - m[j] / rrr
a[i][0] += fj * dx
a[i][1] += fj * dy
a[i][2] += fj * dz
a[j][0] += fi * dx
a[j][1] += fi * dy
a[j][2] += fi * dz
    return a

End of reference code¶
%timeit r1, v1, a1 = leapfrog3(m, r0, v0, a0, dt, acc=acc8)

Surprisingly, although acc7() does not take advantage of the symmetry $\mathbf{F}_{ij} = -\mathbf{F}_{ji}$, it is now faster than acc6(). On the other hand, acc8() is still slightly faster than acc7().
Without external libraries, we’ve cut benchmark time by nearly 93%---a 14x speedup over our original $n$-body implementation. While impressive even compared to compiled languages, this is just the beginning. By leveraging high-performance libraries, we can overcome Python’s inherent slowness and move closer to our goal of a 1000x speedup.
Optimization 4: Use High Performance Libraries¶
Leveraging high-performance libraries like NumPy is key to accelerating Python code. NumPy’s vectorized operations, implemented in C, allow you to perform complex computations on large arrays far more efficiently than native Python loops. This means you can offload heavy calculations to optimized, low-level routines and achieve significant speedups.
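The key NumPy idea used in the reference solution below is broadcasting: inserting size-1 axes so that all pairwise differences are computed at once, with no Python loops. A small demonstration (not part of the original lab):

```python
import numpy as np

r = np.array([[+1.0, 0.0, 0.0],
              [-1.0, 0.0, 0.0]])

# r[:,None,:] has shape (n,1,3) and r[None,:,:] has shape (1,n,3);
# broadcasting expands both to (n,n,3), so dr[i,j] = r[i] - r[j].
dr = r[:, None, :] - r[None, :, :]
print(dr.shape)   # (2, 2, 3)
print(dr[0, 1])   # [2. 0. 0.] -- the separation between bodies 0 and 1
```

One vectorized expression replaces the double loop over i and j entirely.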
# O4. Use NumPy
def acc9(m, r): # this is the same as acc7()
n = len(m)
# TODO: rewrite the following code using numpy
a = []
for i in range(n):
axi, ayi, azi = 0, 0, 0
for j in range(n):
if j != i:
xi, yi, zi = r[i]
xj, yj, zj = r[j]
dx = xi - xj
dy = yi - yj
dz = zi - zj
rrr = (dx*dx + dy*dy + dz*dz)**(3/2)
f = - m[j] / rrr
axi += f * dx
ayi += f * dy
azi += f * dz
a.append((axi, ayi, azi))
    return a

Begin of reference code¶
# O4. Use NumPy
def acc9(m, r): # this is a modification of acc7(); the pairwise symmetry exploited in acc8() is hard to express with numpy
# idx: i j i j
# v v v v
dr = r[:,None,:] - r[None,:,:]
rr = np.sum(dr * dr, axis=-1) # sum over vector components
# Ensure rr is non-zero
rr = np.maximum(rr, 1e-24)
# idx: i j
# v v
f = -m[None,:] / rr**(3/2)
a = np.sum(f[:,:,None] * dr, axis=1) # sum over j
    return a

End of reference code¶
We may do the same for leapfrog:
# O4. Use NumPy
def leapfrog4(m, r0, v0, a0, dt, acc=acc6):
# TODO: rewrite the following code using numpy
n = len(m)
# vh = v0 + 0.5 * dt * a0
vh = [[v0xi + 0.5 * dt * a0xi
for v0xi, a0xi in zip(v0i, a0i)]
for v0i, a0i in zip(v0, a0 )]
# r1 = r0 + dt * vh
r1 = [[x0i + dt * vhxi
for x0i, vhxi in zip(r0i, vhi)]
for r0i, vhi in zip(r0, vh )]
# v1 = vh + 0.5 * dt * a1
a1 = acc(m, r1)
v1 = [[vhxi + 0.5 * dt * a1xi
for vhxi, a1xi in zip(vhi, a1i)]
for vhi, a1i, in zip(vh, a1 )]
    return r1, v1, a1

Begin of reference code¶
# O4. Use NumPy
def leapfrog4(m, r0, v0, a0, dt, acc=acc9):
ht = 0.5 * dt
vh = v0 + ht * a0
r1 = r0 + dt * vh
a1 = acc(m, r1)
v1 = vh + ht * a1
    return r1, v1, a1

End of reference code¶
m = np.array(m)
r0 = np.array(r0)
v0 = np.array(v0)
a0 = np.array(a0)

%timeit r1, v1, a1 = leapfrog4(m, r0, v0, a0, dt, acc=acc9)

cProfile.run("r1, v1, a1 = leapfrog4(m, r0, v0, a0, dt, acc=acc9)")

This runs at 38.3 ms on my laptop. It cuts the benchmark time by 99% (which is no longer a good indicator) and reaches an 82x speedup!
Optimization 5: Use Lower Precision¶
Using lower precision arithmetic can reduce memory usage and potentially speed up calculations. However, the benefits depend heavily on your hardware. For instance, on Intel x86 platforms, single precision may actually be slower than double precision due to conversion overhead and how the hardware is optimized. Always test and profile on your target system before switching to lower precision.
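A quick check of the memory and precision trade-off (a sketch, not from the lab):

```python
import numpy as np

x64 = np.random.normal(size=(1000, 3))   # float64 by default
x32 = x64.astype(np.single)              # float32, i.e. "single" precision

print(x64.nbytes, x32.nbytes)  # the float32 copy uses half the memory
# float32 resolves only ~7 significant decimal digits:
print(np.float32(1.0) + np.float32(1e-8) == np.float32(1.0))  # True
```

Whether the halved memory traffic translates into speed is hardware-dependent, which is why the lab recommends benchmarking rather than assuming.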
# O5. Use Lower Precision
# TODO: cast the arrays `m`, `r0`, `v0`, and `a0` to single precision.

Begin of reference code¶
m = np.array(m, dtype=np.single)
r0 = np.array(r0, dtype=np.single)
v0 = np.array(v0, dtype=np.single)
a0 = np.array(a0, dtype=np.single)

End of reference code¶
%timeit r1, v1, a1 = leapfrog4(m, r0, v0, a0, dt, acc=acc9)

Optimization 6: Google JAX¶
Google JAX is a high-performance library that extends NumPy with
automatic differentiation and just-in-time (JIT) compilation.
It transforms Python functions into highly optimized machine code
using XLA, which can dramatically accelerate numerical computations.
One of JAX’s standout features is its seamless support for GPUs.
Depending on your hardware, running your code on a GPU with JAX can
lead to even greater speedups compared to CPU execution.
Additionally, JAX’s vectorization tools, like vmap, let you apply
functions over arrays efficiently without explicit Python loops.
By replacing NumPy with JAX, we can harness the power of hardware
acceleration and achieve the ambitious 1000x speedup!
from jax import numpy as jnp

# O6. Use JAX
m = np.array(m)
r0 = np.array(r0)
v0 = np.array(v0)
a0 = np.array(a0)
def acc10(m, r): # this is the same as acc9()
# TODO: replace `np` by `jnp`.
dr = r[:,None,:] - r[None,:,:]
rr = np.sum(dr * dr, axis=-1)
rr = np.maximum(rr, 1e-24)
f = -m[None,:] / rr**(3/2)
a = np.sum(f[:,:,None] * dr, axis=1)
    return a

Begin of reference code¶
# O6. Use JAX
m = jnp.array(m)
r0 = jnp.array(r0)
v0 = jnp.array(v0)
a0 = jnp.array(a0)
def acc10(m, r):
dr = r[:,None,:] - r[None,:,:]
rr = jnp.sum(dr * dr, axis=-1)
rr = jnp.maximum(rr, 1e-24)
f = -m[None,:] / rr**(3/2)
a = jnp.sum(f[:,:,None] * dr, axis=1)
    return a

End of reference code¶
%timeit r1, v1, a1 = leapfrog4(m, r0, v0, a0, dt, acc=acc10)

cProfile.run("r1, v1, a1 = leapfrog4(m, r0, v0, a0, dt, acc=acc10)")

The timeit results on my laptop show that the slowest run took ~11 times longer than the fastest, with an average execution time of 3 ms. This variability is likely due to the initialization of the JAX library and potential data transfer overhead. Nevertheless, these results are astonishing---achieving a 1023x speedup over our original implementation!
Optimization 7: Just-in-Time Compilation¶
Just-in-Time (JIT) compilation converts Python code into optimized machine code at runtime. This means that hot functions can be compiled on the fly, eliminating Python’s overhead and allowing them to run at speeds much closer to those of native C code.
JAX offers powerful JIT capabilities through its jax.jit
decorator.
When applied, it compiles your numerical functions into efficient,
low-level code, which can be further accelerated on GPUs.
This results in dramatic performance improvements, exceeding our
target of a 1000x speedup.
from jax import jit

# O7. JIT the leapfrog stepper
# TODO: Use the `@jit` decorator to compile your code

Begin of reference code¶
# O7. JIT the leapfrog stepper
@jit
def leapfrog5(m, r0, v0, a0, dt):
    return leapfrog4(m, r0, v0, a0, dt, acc=acc10)

End of reference code¶
%timeit r1, v1, a1 = leapfrog5(m, r0, v0, a0, dt)

cProfile.run("r1, v1, a1 = leapfrog5(m, r0, v0, a0, dt)")

Our final results are truly remarkable.
By applying JAX’s JIT compilation, we compiled away all overhead
(see the output of cProfile) and achieved a 1710x speedup over the
original implementation.
This demonstrates how powerful the optimization techniques we have
introduced are at improving Python’s performance.
Applying the $n$-Body Integrator¶
With our optimized $n$-body integrator in hand, we can now explore its applications. This efficient Python code lets us simulate various dynamical systems---from celestial mechanics to particle interactions---with different initial conditions. By reading in data files that specify masses, positions, and velocities, we can easily tailor simulations to real-world scenarios.
def integrate(fname, T, N):
# Load data
ic = np.genfromtxt(fname)
# Setup initial conditions
m = jnp.array(ic[:,0])
R = [jnp.array(ic[:,1:4])]
V = [jnp.array(ic[:,4:7])]
A = [acc10(m, R[-1])]
# Main loop
T, dt = jnp.linspace(0, T, N+1, retstep=True)
for _ in tqdm(range(N)):
r, v, a = leapfrog5(m, R[-1], V[-1], A[-1], dt)
R.append(r)
V.append(v)
A.append(a)
# Return results
    return T, jnp.array(R), jnp.array(V), jnp.array(A)

The figure-8 orbit is a fascinating solution to the three-body problem where three equal-mass bodies follow a single, intertwined figure-8 path. Each body moves along the same curve, perfectly choreographed so that they never collide, yet their gravitational interactions keep them in a stable, periodic dance.
We first download the initial condition from GitHub if the file doesn’t exist:
! [ -f ic/figure-8.tsv ] || wget -P ic https://raw.githubusercontent.com/ua-2025q3-astr501-513/ua-2025q3-astr501-513.github.io/refs/heads/main/501/05/ic/figure-8.tsv

# We then run the integrate() code
T, R, V, A = integrate("ic/figure-8.tsv", 2.5, 250)

X = R[:,:,0]
Y = R[:,:,1]
plt.plot(X, Y, '-')

anim = animate(X, Y)
HTML(anim.to_html5_video()) # display animation
# anim.save('orbits.mp4') # save animation

# HANDSON: Try out different initial conditions!
Conclusion and Discussion¶
In this workshop, we’ve journeyed from a straightforward, pure Python implementation of the $n$-body simulation to a highly optimized version that achieves a 1710x speedup. We started by cleaning up our code with list comprehensions and cutting out unnecessary operations, then moved on to swapping slow functions for faster ones and leveraging high-performance libraries like NumPy.
The real game-changer came with Google JAX.
Its just-in-time compilation and GPU support pushed our code another
order of magnitude in performance.
These optimizations show that Python, when used right, isn’t just for
quick-and-dirty prototypes.
It can also deliver serious speed and efficiency.
This is especially important in hackathons and rapid prototyping environments. You can quickly iterate on your ideas and still end up with code that’s robust enough for real-world applications. Python truly bridges the gap between ease of development and high performance, letting you have your cake and eat it too.