Bridging numerical relativity and automatic differentiation using JAX

What is automatic differentiation?

Automatic differentiation is a technique for computing the derivative of a function extremely efficiently and with exact numerical precision, two claims that cannot be made for numerical and symbolic differentiation, which are as of now the predominant methods for computing derivatives in scientific computing.

Symbolic differentiation involves the automated manipulation of symbols, respecting the rules of algebra and mathematics, to derive exact expressions for derivatives. This form of differentiation usually requires a CAS (Computer Algebra System) that has knowledge of how to perform differentiation by hand using the rules of calculus and algebra. While symbolic differentiation does provide exact results by virtue of producing exact expressions for derivatives, it is the most computationally expensive form of differentiation, leaving it lacking in the domains of speed and efficiency. Examples of symbolic differentiation software include Mathematica, SageMath, and SymPy.

Numerical differentiation calculates derivatives by using the method of finite differences to approximate the limit definition of the derivative from first principles. While it is faster than symbolic differentiation, it is plagued by accuracy and stability issues stemming from floating point arithmetic. Common pain points include numerical instability due to dividing by extremely small step sizes, which causes expressions to blow up; intermediate floating point round-off errors that accumulate across expressions, leading to divergence from the true derivative; and the limits of floating point precision and storage in modern-day classical computing.
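To see these pain points concretely, here is a minimal sketch (using plain NumPy, with an arbitrarily chosen test function) of how the forward-difference approximation first improves and then degrades as the step size h shrinks, once round-off error overtakes truncation error:

import numpy as np

def f(x):
    return np.sin(x)

x0 = 1.0
exact = np.cos(x0)  # the true derivative of sin at x0

# shrinking h first reduces truncation error, then floating point round-off dominates
for h in [1e-2, 1e-6, 1e-10, 1e-14]:
    approx = (f(x0 + h) - f(x0)) / h  # first-order forward difference
    print(f"h = {h:.0e}, error = {abs(approx - exact):.2e}")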

Automatic differentiation circumvents all of these issues by tracing out the operations that make up the target function, constructing a DAG (directed acyclic graph) for it, and computing gradients for each variable by traversing the graph backwards via the chain rule. This is essentially how deep neural networks are optimized, via something called backpropagation, a form of automatic differentiation. Autodiff is at the heart of modern machine learning and deep learning libraries like TensorFlow, PyTorch, and JAX. PyTorch's torch.autograd module, for example, provides methods to run autodiff on PyTorch tensors, which is used to train neural networks to minimize prediction loss (error) and improve the accuracy of their outputs.
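As a minimal sketch of what this looks like in practice (using JAX, which is introduced properly below; the function here is an arbitrary example): jax.grad takes a Python function and returns a new function that evaluates its exact derivative, with no symbolic expression and no finite-difference step size involved.

import jax

def f(x):
    return x**3 + 2.0 * x

df = jax.grad(f)  # df(x) computes 3x^2 + 2 via autodiff
print(df(2.0))    # prints 14.0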

For a visual explanation of automatic differentiation, check out: https://www.youtube.com/watch?v=wG_nF1awSSY&t=2s

To implement automatic differentiation yourself, check out The spelled-out intro to neural networks and backpropagation: building micrograd, a tutorial on building micrograd, a tiny autodiff engine, by the brilliant Andrej Karpathy. The tutorial also doubles as an excellent hands-on introduction to the world of modern machine/deep learning and frameworks like TensorFlow and PyTorch, because you're essentially building a smaller version of them.

What is JAX?

JAX is a high-performance machine learning and scientific computing library for multilinear algebra, automatic differentiation, and writing performant numerical code that can run on CPUs as well as accelerators such as GPUs and TPUs for significant speedups.

In a nutshell, JAX is NumPy with autodiff (automatic differentiation) support. I say this because JAX's jax.numpy API has a near one-to-one correspondence with NumPy's, which makes JAX close to a drop-in replacement for NumPy.
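A quick sketch of that correspondence (the arrays here are arbitrary examples). One caveat worth flagging: JAX arrays are immutable, so NumPy-style in-place assignment is replaced by the functional .at[...].set(...) idiom.

import numpy as np
import jax.numpy as jnp

a_np = np.linspace(0.0, 1.0, 5)
a_jx = jnp.linspace(0.0, 1.0, 5)   # same call, same result

a_np[0] = 42.0                     # in-place mutation works in NumPy
a_jx = a_jx.at[0].set(42.0)        # JAX arrays are immutable, so we get a new array back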

Compared to other machine learning libraries like TensorFlow and PyTorch, which come packed with practically everything you need to do machine learning of any sort at scale, JAX embraces the "simple is beautiful" philosophy by being leaner and simpler in its design, reflected in its smaller size (JAX is about 9 MB, while libraries like PyTorch and TensorFlow exceed 1 GB when installed with CUDA support). JAX can be further augmented by libraries like Flax and Equinox to do things like construct and train neural networks, nominally putting it on par with the bigger libraries.

What does this have to do with General Relativity?

Firstly, GR is formulated in the language of tensors and multilinear algebra. A tensor is a multidimensional array of numbers that acts as a generalization of vectors and matrices to higher dimensions. Like vectors and matrices, we can do calculus on tensors, which is a large part of the calculations we do in GR (computing Christoffel symbols, curvature tensors, etc.).

Tensors are also the language of neural networks and deep learning. Input and output data, and parameters such as a neural network's weights and biases, are stored as tensors to allow flexibility in representing data of many different shapes and sizes. We also need to be able to compute derivatives of tensors and do calculus on them in order to train neural networks. As one might gather, this means there is a natural intersection between GR and deep learning in the language they use to represent and manipulate data. Of course, there are conceptual differences that need to be kept in mind when trying to make sense of this synergy. For instance, the mathematics of general relativity (and tensor calculus in general) defines an indexed quantity as a tensor based on its transformation properties, whereas any indexed quantity (e.g., the Christoffel symbols) qualifies as a "tensor" in deep learning. However, such details are technicalities that can be accounted for by exercising adequate care.

The rapid developments in the field of AI over the past few decades have led to the creation of efficient and accurate methods and tools for tensor algebra and calculus. Automatic differentiation and JAX are just two of many such methods and tools, but prominent ones due to their potential applications in physics. However, I have noticed a gaping lack of utilization of these tools and methods, built and refined for AI/ML, in solving computational problems and improving computations done in modern-day physics. I partly attribute this to the novelty of the tools themselves, and I'm sure that given enough time and effort, they will see widespread adoption across the many subfields of physics.

In fact, such adoption has already begun, with automatic differentiation finding applications in fluid dynamics and differential equation solving. From my search of the web and arXiv, however, the field of general relativity, and specifically numerical relativity, which deals with numerically computing quantities and solving the equations of Einstein's general theory of relativity, has seen little to no such adoption, and I believe it is time to change that.

What are we doing?

In this exercise, I will use JAX to demonstrate the power of autodiff by exploring its potential and applicability to the field of numerical relativity.

Given a metric and some coordinates, we will explore how to compute derivatives of the metric tensor (Christoffel symbols) and other relevant tensors and quantities, such as the Riemann and Ricci tensors, the Ricci scalar curvature and the Kretschmann invariant, in order to finally compute the Einstein tensor and the stress-energy-momentum tensor.

We will make use of modern Python features such as type hinting, introduced in Python 3.5 and steadily refined in later versions, which allows us to explicitly specify the data types of the variables we're using in our code for better readability and type safety, and decorators, which are higher-order functions that modify/augment the behavior of the functions given to them as inputs to transform said functions' inputs/outputs for a specific purpose, among other things.

Note: we will be working exclusively in SI units and not natural units of any sort for ease of interpreting calculations when deriving numerical values of quantities.

Setup

We begin by importing dependencies from the typing module for type hinting support and importing jax plus its numpy variant jax.numpy.

We configure JAX to force all floating point numbers into 64-bit precision for higher accuracy in our results.

We also define a utility decorator, close_to_zero, that rounds off values in our tensors whose magnitude falls below a certain arbitrarily chosen tolerance, in order to suppress floating point errors that compound across intermediate arithmetic steps.

from functools import wraps
from typing import Callable

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

TOLERANCE = 1e-8

def close_to_zero(func):
    """Decorator that flushes near-zero entries of a returned array to exactly 0.0."""
    @wraps(func)  # preserve the wrapped function's name and docstring
    def wrapper(*args, **kwargs) -> jnp.ndarray:
        result: jnp.ndarray = func(*args, **kwargs)

        return jnp.where(jnp.abs(result) < TOLERANCE, 0.0, result)

    return wrapper
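As a quick sanity check of the decorator (noisy_identity is a hypothetical example function, not used elsewhere), values with magnitude below TOLERANCE get flushed to exactly zero:

@close_to_zero
def noisy_identity(x: jnp.ndarray) -> jnp.ndarray:
    return x

print(noisy_identity(jnp.array([1.0, 1e-12, -3e-9])))  # [1. 0. 0.]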

Defining common metrics

We define some common metrics that are used as trivial examples for the rest of the tutorial.

# while the Minkowski metric is a constant tensor that is not coordinate dependent,
# we still need to take in some "dummy" coordinates in order to make this function play nice with JAX's autodiff mechanism
def minkowski_metric(coordinates: jnp.ndarray) -> jnp.ndarray:
    """Returns the Minkowski metric in float64 precision with the (-1, 1, 1, 1) metric signature"""
    return jnp.diag(jnp.array([-1, 1, 1, 1], dtype=jnp.float64))

# this is the flat 3D Euclidean metric written in spherical polar coordinates (r, theta, phi)
@close_to_zero
def spherical_polar_metric(coordinates: jnp.ndarray) -> jnp.ndarray:
    r, theta, phi = coordinates
    return jnp.diag(jnp.array([1, r**2, r**2 * jnp.sin(theta)**2], dtype=jnp.float64))
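As a quick check (at an arbitrarily chosen point), each metric function returns the metric as a square matrix evaluated at the given coordinates, which is exactly the kind of array-valued function jax.jacfwd will differentiate below:

point = jnp.array([2.0, jnp.pi / 4, 0.0], dtype=jnp.float64)
print(spherical_polar_metric(point))
# diag(1, r^2, r^2 sin^2(theta)) = diag(1, 4, 2) at r = 2, theta = pi/4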

The Christoffel symbols

Given a metric $g_{ij}$, the Christoffel symbols of the second kind (also called the affine connection, or the connection coefficients) are defined as derivatives of the metric contracted with the inverse metric tensor $g^{ij}$:

$$\Gamma^j_{kl} = \frac{1}{2} g^{jm} \left( \frac{\partial g_{mk}}{\partial x^l} + \frac{\partial g_{lm}}{\partial x^k} - \frac{\partial g_{kl}}{\partial x^m} \right)$$

Note: This definition and subsequent ones are taken from Mathematical Methods for Students of Physics and Related Fields by Sadri Hassani, which uses all Roman indices. We will do the same throughout this tutorial, even when dealing with spacetime quantities and not just Cartesian ones, because Roman indices are convenient to specify in code. Greek symbols are non-trivial when it comes to their underlying Unicode representation, and that's a complexity I want to avoid for now.

The Christoffel symbols represent the gravitational field in a given spacetime: they are built from derivatives of the metric, which itself plays a role analogous to the gravitational potential in Newtonian gravitation.

Now, we define a Python function to compute the Christoffel symbols. The function takes in the coordinates at which to evaluate the Christoffel symbols and the metric function to compute them for, and its implementation is broadly as follows:

  1. We evaluate the metric function to get the metric at the given coordinates and compute the inverse metric tensor along the way using jnp.linalg.inv.

  2. We use jax.jacfwd to compute the "Jacobian" of the metric, i.e., its partial derivatives with respect to the given coordinates, using forward-mode automatic differentiation. This is $\frac{\partial g_{kl}}{\partial x^m}$, also written $g_{kl,m}$. I write "Jacobian" in quotes because, rigorously speaking, the Jacobian is defined as the matrix of partial derivatives of a vector-valued function with respect to its inputs, whereas what we have is a tensor-valued function in the form of the metric. I am not aware of a standard name for the higher-rank tensor of derivatives of a tensor-valued function, hence my loose usage of the word "Jacobian" here.

  3. We use jnp.einsum, an extremely convenient and performant subroutine that manipulates computational tensors and other indexed objects using standard index notation, to compute the Christoffel symbols according to the equation given above. (Both building blocks are demonstrated in the sketch after this list.)
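Before putting them together, here is a small sketch of both building blocks in isolation (the inputs are arbitrary examples). jax.jacfwd appends one axis per input coordinate, and jnp.einsum expresses permutations and contractions directly in index notation:

point = jnp.array([2.0, jnp.pi / 4, 0.0], dtype=jnp.float64)
jac = jax.jacfwd(spherical_polar_metric)(point)
print(jac.shape)  # (3, 3, 3): jac[k, l, m] = dg_kl / dx^m

A = jnp.arange(9.0).reshape(3, 3)
print(jnp.einsum('ij -> ji', A))         # transpose, written as an index permutation
print(jnp.einsum('ij, jk -> ik', A, A))  # matrix product, written as a contraction over j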

@close_to_zero # this ensures that any values of our Christoffel symbols are rounded off if they're close to 0
def christoffel_symbols(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    # evaluate the metric tensor at the coordinates
    g = metric(coordinates)
    # compute the inverse metric tensor
    g_inv = jnp.linalg.inv(g)
    # obtain and evaluate the "Jacobian" of the metric at the coordinates: jacobian[k, l, m] = dg_kl/dx^m
    jacobian = jax.jacfwd(metric)(coordinates)

    # permute the derivative index to build dg_mk/dx^l + dg_lm/dx^k - dg_kl/dx^m, then contract with the inverse metric
    return 0.5 * jnp.einsum(
        'jm, klm -> jkl',
        g_inv,
        jnp.einsum('klm -> mkl', jacobian) + jnp.einsum('klm -> lmk', jacobian) - jacobian,
    )

The Torsion Tensor

The torsion tensor is defined as the antisymmetric part of a general affine connection $\Gamma^l_{hk}$:

$$\Gamma^l_{hk} - \Gamma^l_{kh}$$

If the torsion tensor vanishes in one coordinate system, then it vanishes in all coordinate systems (the zero tensor is zero in all coordinate systems). Therefore, the torsion tensor of a general affine connection is zero if and only if the connection is symmetric, i.e.,

$$\Gamma^l_{hk} = \Gamma^l_{kh}$$

Since we are dealing only with Christoffel symbols of the second kind, which form a unique, symmetric affine connection derived from the metric tensor, all the torsion tensors we compute in this notebook should be zero, and the following subroutine will be used to verify that.

@close_to_zero
def torsion_tensor(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    christoffels = christoffel_symbols(coordinates, metric)

    # antisymmetrize over the two lower indices
    return christoffels - jnp.einsum('ijk -> ikj', christoffels)

The Riemann Curvature Tensor and the Kretschmann Invariant

The Riemann tensor $R^j_{klm}$ encodes the intrinsic curvature of the Riemannian (or semi-Riemannian) manifold produced by any given metric. It is defined in terms of derivatives of the Christoffel symbols $\Gamma^j_{kl}$ as:

$$R^j_{klm} = \partial_m \Gamma^j_{kl} - \partial_l \Gamma^j_{km} + \Gamma^j_{rm} \Gamma^r_{kl} - \Gamma^j_{rl} \Gamma^r_{km}$$

We define a Python function that:

  1. Uses the previous christoffel_symbols function to obtain the Christoffel symbols for a given metric and set of coordinates.
  2. Computes the "Jacobian" of the Christoffel symbols to obtain $\Gamma^j_{kl,m}$.
  3. Manipulates this "Jacobian" tensor and products of the Christoffel symbols using jnp.einsum to obtain the Riemann tensor.

We also define a Python function to compute the Kretschmann invariant from the Riemann tensor, a scalar that is used to look for true physical singularities (gravitational singularities) in certain manifolds independent of the choice of coordinates:

$$R^{jklm} R_{jklm}$$

We will use the Kretschmann invariant later on in this tutorial, with the Schwarzschild metric as a case study, to verify that the Riemann tensor implementation below is correct.

@close_to_zero
def riemann_tensor(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    christoffels = christoffel_symbols(coordinates, metric)
    # jacobian[j, k, l, m] = dGamma^j_kl / dx^m
    jacobian = jax.jacfwd(christoffel_symbols)(coordinates, metric)

    return (
        jacobian
        - jnp.einsum('jklm -> jkml', jacobian)
        + jnp.einsum('jrm, rkl -> jklm', christoffels, christoffels)
        - jnp.einsum('jrl, rkm -> jklm', christoffels, christoffels)
    )

@close_to_zero
def kretschmann_invariant(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    riemann = riemann_tensor(coordinates, metric)

    g = metric(coordinates)
    g_inv = jnp.linalg.inv(g)

    riemann_upper = jnp.einsum('pj, qk, rl, ijkl -> ipqr', g_inv, g_inv, g_inv, riemann) # raise the three covariant indices with three inverse metric tensors
    riemann_lower = jnp.einsum('pi, ijkl -> pjkl', g, riemann) # lower the contravariant index with one metric tensor

    return jnp.einsum('ijkl, ijkl ->', riemann_upper, riemann_lower)

The Ricci tensor and the Ricci scalar curvature

The Ricci tensor $R_{kl}$ is another curvature-related quantity, defined as the trace component of the Riemann tensor $R^j_{klm}$. The Ricci tensor is obtained by contracting the only contravariant index of the Riemann tensor with its last covariant index:

$$R_{kl} = R^j_{klj}$$

Physically, the Ricci tensor encodes information about how volumes change in the presence of tidal forces.

By raising one of the Ricci tensor's indices and contracting, we obtain the Ricci scalar curvature

$$R = R^l_l = g^{kl} R_{kl}$$

We implement Python subroutines to do this computationally using jnp.einsum again to manipulate indices and perform contractions.

@close_to_zero
def ricci_tensor(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    riemann = riemann_tensor(coordinates, metric)

    return jnp.einsum('jklj -> kl', riemann) # contracting the first and last indices

@close_to_zero
def ricci_scalar(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    g = metric(coordinates)
    g_inv = jnp.linalg.inv(g)
    ricci = ricci_tensor(coordinates, metric)

    return jnp.einsum('kl, kl -> ', g_inv, ricci) # trace of the Ricci tensor with one index raised

The Einstein tensor and the stress-energy-momentum tensor

The Einstein tensor, the crown jewel of the general theory of relativity, encodes the part of a spacetime manifold's curvature that is directly tied to its mass-energy content (it is the trace-reversed Ricci tensor), and is defined in terms of the Ricci tensor, the metric tensor, and the Ricci scalar curvature as:

$$G_{ij} \equiv R_{ij} - \frac{1}{2} g_{ij} R$$

It forms the left-hand side of the Einstein Field Equations (EFEs), a set of 16 coupled partial differential equations (of which only 10 are independent, since both sides are symmetric tensors) that relate the curvature of a spacetime manifold to the mass-energy content in it:

$$G_{ij} = \frac{8 \pi G}{c^4} T_{ij}$$

The right-hand side $T_{ij}$ is the stress-energy-momentum tensor, which encodes information about all the mass-energy present in a spacetime manifold.

We write Python functions that call the subroutines implemented before to trivially compute the Einstein tensor and the stress-energy-momentum tensor using the equations we just described.

@close_to_zero
def einstein_tensor(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    g = metric(coordinates)
    rt = ricci_tensor(coordinates, metric)
    rs = ricci_scalar(coordinates, metric)

    return rt - 0.5 * g * rs

@close_to_zero
def stress_energy_momentum_tensor(coordinates: jnp.ndarray, metric: Callable[[jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
    G = einstein_tensor(coordinates, metric)

    # Einstein gravitational constant kappa = 8 pi G / c^4, in SI units
    kappa = (8 * jnp.pi * 6.67e-11) / ((299792458)**4)

    return G / kappa

Case study 1: flat space in spherical polar coordinates

Using all of the Python functions we have defined before, let's perform calculations in float64 precision for the flat 3D Euclidean metric, given in spherical polar coordinates $(r, \theta, \phi)$ as:

$$g_{ij} = \text{diag}(1, r^2, r^2 \sin^2(\theta))$$

We use the following arbitrarily chosen coordinate values for the calculations:

$r = 5$
$\theta = \pi/3$
$\phi = \pi/2$

coordinates = jnp.array([5, jnp.pi/3, jnp.pi/2], dtype=jnp.float64)
metric = spherical_polar_metric

print(f"Christoffel symbols: {christoffel_symbols(coordinates, metric)}")
print(f"Torsion tensor: {torsion_tensor(coordinates, metric)}")
print(f"Riemann tensor: {riemann_tensor(coordinates, metric)}")
print(f"Ricci tensor: {ricci_tensor(coordinates, metric)}")
print(f"Ricci scalar: {ricci_scalar(coordinates, metric)}")
print(f"Einstein tensor: {einstein_tensor(coordinates, metric)}")
print(f"Stress-energy-momentum tensor: {stress_energy_momentum_tensor(coordinates, metric)}")
print(f"Kretschmann invariant: {kretschmann_invariant(coordinates, metric)}")

Running this, we get

Christoffel symbols: [[[ 0.          0.          0.        ]
  [ 0.         -5.          0.        ]
  [ 0.          0.         -3.75      ]]

 [[ 0.          0.2         0.        ]
  [ 0.2         0.          0.        ]
  [ 0.          0.         -0.4330127 ]]

 [[ 0.          0.          0.2       ]
  [ 0.          0.          0.57735027]
  [ 0.2         0.57735027  0.        ]]]
Torsion tensor: [[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]]
Riemann tensor: [[[[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]]


 [[[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]]


 [[[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]]]
Ricci tensor: [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Ricci scalar: 0.0
Einstein tensor: [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Stress-energy-momentum tensor: [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Kretschmann invariant: 0.0

We obtain all of the Christoffel symbols for this metric. Furthermore, as expected, the Riemann tensor, Ricci tensor, Ricci scalar, Einstein tensor, and stress-energy-momentum tensor all vanish, since this metric describes flat space with no curvature and no mass-energy content. The torsion tensor is also zero, as expected from the symmetric nature of the Christoffel symbols.

Case study 2: the Schwarzschild metric

Now we come to a more interesting case study, the Schwarzschild metric, which describes the spacetime of an uncharged, non-rotating, spherically symmetric body in vacuum. The Schwarzschild metric in traditional spherical polar spacetime coordinates $(t, r, \theta, \phi)$ is given as:

$$g_{ij} = \text{diag}\left( -\left(1 - \frac{r_s}{r}\right) c^2, \left(1 - \frac{r_s}{r}\right)^{-1}, r^2, r^2 \sin^2(\theta) \right)$$

where $r_s$ is the Schwarzschild radius of the massive body, a scale factor related to its mass $M$ by:

$$r_s = \frac{2GM}{c^2}$$

We do exactly the same as in the previous case study, computing all the quantities in float64 precision for a body of ~4.3 million solar masses, using the following arbitrarily chosen spherical polar spacetime coordinates. (Note that for this mass, $r_s \approx 1.27 \times 10^{10}$ m, so the chosen $r = 3000$ m actually lies deep inside the event horizon; the curvature invariants we compute are well-defined there regardless.)

$t = 3600$

$r = 3000$

$\theta = \pi/3$

$\phi = \pi/2$

G = 6.67e-11
c = 299792458.0

M = 4.297e+6 * 1.989e+30 # 4.3 million solar masses, mass of Sgr A*

# schwarzschild radius
rs = (2 * G * M) / c**2

@close_to_zero
def schwarzschild_metric(coordinates: jnp.ndarray) -> jnp.ndarray:
    t, r, theta, phi = coordinates

    return jnp.diag(jnp.array([-(1 - (rs / r)) * c**2, 1/(1 - (rs/r)), r**2, r**2 * jnp.sin(theta)**2], dtype=jnp.float64))

coordinates = jnp.array([3600, 3000, jnp.pi/3, jnp.pi/2], dtype=jnp.float64)
metric = schwarzschild_metric

print(f"Christoffel symbols: {christoffel_symbols(coordinates, metric)}")
print(f"Torsion tensor: {torsion_tensor(coordinates, metric)}")
print(f"Riemann tensor: {riemann_tensor(coordinates, metric)}")
print(f"Ricci tensor: {ricci_tensor(coordinates, metric)}")
print(f"Ricci scalar: {ricci_scalar(coordinates, metric)}")
print(f"Einstein tensor: {einstein_tensor(coordinates, metric)}")
print(f"Stress-energy-momentum tensor: {stress_energy_momentum_tensor(coordinates, metric)}")
print(f"Kretschmann invariant: {kretschmann_invariant(coordinates, metric)}")

Running this outputs

Christoffel symbols: [[[ 0.00000000e+00 -1.66666706e-04  0.00000000e+00  0.00000000e+00]
  [-1.66666706e-04  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

 [[-2.67840757e+26  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  1.66666706e-04  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  1.26857006e+10  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  9.51427546e+09]]

 [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  3.33333333e-04  0.00000000e+00]
  [ 0.00000000e+00  3.33333333e-04  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 -4.33012702e-01]]

 [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  3.33333333e-04]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  5.77350269e-01]
  [ 0.00000000e+00  3.33333333e-04  5.77350269e-01  0.00000000e+00]]]
Torsion tensor: [[[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]]
Riemann tensor: [[[[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  1.11111137e-07  0.00000000e+00  0.00000000e+00]
   [-1.11111137e-07  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  2.11428394e+06  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [-2.11428394e+06  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.58571295e+06]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [-1.58571295e+06  0.00000000e+00  0.00000000e+00  0.00000000e+00]]]


 [[[ 0.00000000e+00  1.78560505e+23  0.00000000e+00  0.00000000e+00]
   [-1.78560505e+23  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  2.11428394e+06  0.00000000e+00]
   [ 0.00000000e+00 -2.11428394e+06  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.58571295e+06]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 -1.90734863e-06]
   [ 0.00000000e+00 -1.58571295e+06  1.90734863e-06  0.00000000e+00]]]


 [[[ 0.00000000e+00  0.00000000e+00 -8.92802525e+22  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 8.92802525e+22  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  5.55555687e-08  0.00000000e+00]
   [ 0.00000000e+00 -5.55555687e-08  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00 -3.17142590e+06]
   [ 0.00000000e+00  0.00000000e+00  3.17142590e+06  0.00000000e+00]]]


 [[[ 0.00000000e+00  0.00000000e+00  0.00000000e+00 -8.92802525e+22]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 8.92802525e+22  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  5.55555687e-08]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00 -5.55555687e-08  0.00000000e+00  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  4.22856787e+06]
   [ 0.00000000e+00  0.00000000e+00 -4.22856787e+06  0.00000000e+00]]

  [[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]]]
Ricci tensor: [[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
Ricci scalar: 0.0
Einstein tensor: [[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
Stress-energy-momentum tensor: [[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
Kretschmann invariant: 2.649005370647907

Like before, we've obtained our Christoffel symbols for the Schwarzschild metric, and the torsion tensor is zero as expected.

Furthermore, since the Schwarzschild metric is a vacuum solution to the Einstein Field Equations, the Ricci tensor, Ricci scalar, Einstein tensor, and stress-energy-momentum tensor are all zero.

However, observe that the Riemann tensor is not zero. In fact, some of its components are quite large in magnitude. We also obtain a value for the Kretschmann invariant from the Riemann tensor's contraction. To verify this, and thereby check whether the components of our calculated Riemann tensor are correct, we can compute the Kretschmann invariant directly using a relatively simple closed-form scalar formula:

$$R^{jklm} R_{jklm} = \frac{48 G^2 M^2}{c^4 r^6}$$

G = 6.67e-11
M = 4.297e+6 * 1.989e+30 # 4.3 million solar masses, mass of Sgr A*
c = 299792458.0

r = 3000

kr = (48 * G**2 * M**2) / (c**4 * r**6)

print("Kretschmann invariant is", kr)

Running this, we get the output

Kretschmann invariant is 2.649005370647906

This matches the result we obtained by contracting all the indices of the Riemann tensor to about 15 significant figures!

Conclusion

In this exploration, we've witnessed first hand the power of automatic differentiation, which let us compute derivatives of all our quantities to machine precision in floating point.

We saw the effectiveness and simplicity of JAX's jax.jacfwd for computing Jacobians/derivatives of tensors via forward-mode automatic differentiation.

We've also seen the benefits of JAX's jnp.einsum, with its ability to do index manipulations and perform calculations at a high level, allowing one to carry out tensor calculus operations in a convenient and efficient manner. The language of tensors and tensor calculus is pervasive in other fields of theoretical physics, such as covariant electromagnetism and quantum field theory. At the very least, it would be an interesting exercise to explore a similar approach leveraging modern AI/ML libraries and frameworks to perform tensor computations for calculations arising in these fields.

JAX further allows us to parallelize all of this code and run it on accelerators such as GPUs and TPUs with no changes to the original code, in order to speed up heavy numerical computations and simulations requiring parallelism, speed, and efficiency. We have not explored that aspect specifically in this demonstration, and all computations have been done on the CPU since none of our use cases required parallelism or accelerators, but the capabilities exist and can be called upon if a specific use case demands them.
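As one small sketch of what that looks like (reusing the flat-space metric and subroutines from above): jax.vmap maps our unchanged christoffel_symbols function over a whole batch of coordinate points at once, and the same batched function would run on a GPU or TPU if one were available.

points = jnp.stack([
    jnp.array([5.0, jnp.pi / 3, jnp.pi / 2]),
    jnp.array([7.0, jnp.pi / 4, jnp.pi / 6]),
])

# close over the metric and batch over the leading axis of the coordinates
batched_christoffels = jax.vmap(lambda c: christoffel_symbols(c, spherical_polar_metric))
print(batched_christoffels(points).shape)  # (2, 3, 3, 3): one set of symbols per point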

The explanations, scripts, and results of this exploration are also conveniently documented in this Jupyter notebook hosted on Google Colab, which you can use to reproduce and build upon the ideas presented here.

Future work

Some ideas and extensions in this general direction that are worth exploring:

  • Perform the same calculations for other prominent metrics such as the Kerr and Kerr-Newman metrics.
  • Write functions to compute the Weyl tensor and the Weyl invariant, other important curvature-related quantities.
  • Provide alternative but equivalent PyTorch and TensorFlow implementations of the JAX implementation here.
  • Use the @jax.jit decorator on relevant subroutines to JIT-compile and speed up computations (see the sketch after this list).
  • Look into jax.jacrev for reverse-mode automatic differentiation and try combining it with jax.jacfwd for optimal speed and accuracy.
  • Configure this code to run on GPUs and other accelerators.
  • Train a neural network to parameterize the metric tensor and solve an optimization problem to recover the metric tensor components from data.
  • Compute the Christoffel symbols for a metric and solve the geodesic equation for the equations of motion using automatic-differentiation-powered differential equation solvers such as diffrax.
  • Extend this methodology to other aspects and calculations of numerical relativity, specifically those tackled by popular general relativity Python libraries such as EinsteinPy.
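As a small sketch of the JIT idea from the list above (applied to the subroutines defined in this tutorial): since the metric argument is a Python function rather than an array, it has to be marked static so that jax.jit can trace and compile the computation around it.

# compile christoffel_symbols with XLA; argument 1 (the metric function) is static
christoffel_symbols_fast = jax.jit(christoffel_symbols, static_argnums=1)

coords = jnp.array([5.0, jnp.pi / 3, jnp.pi / 2])
print(christoffel_symbols_fast(coords, spherical_polar_metric))  # same result, compiled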


This content originally appeared on DEV Community and was authored by Baalateja Kataru

