改进曼德隆常数代码

计算科学 Python 表现
2021-12-24 12:49:12

我正在学习和提高我的 Python 技能。

我在 Python 中做了一个关于Mandelung constant的程序。但是,我有一个问题。使用以下总和计算 Mangelung 常数:

Vtotal=i,j,k=Li,j,k0LV(i,j,k)=e4πϵ0M

或者

M=i,j,k=Li,j,k0L(1)i+j+ki2+j2+k2

当我运行它时,它需要很长时间,当我输入一个巨大的数字时。所以,我需要改进我的代码,以更快地运行它。有人可以向我解释一种方法吗?(导入其他库,使用其他东西)

我做的代码:

import time

start_time = time.time()
L = int(input("Put the number of L:")) # size of the lattice
L = L+1 # this is for the vector (0,0,0)
# n = 0 # number of atoms
M = 0 # Madelung constant
for i in range(-L,L+1):
    for j in range(-L,L+1):
        for k in range(-L,L+1):
            # n += 1 #counter for number of atoms
            if i == j == k == 0: # doesn't count the origin (0,0,0)
                continue
            r = (i**2 + j**2 + k**2)**(-0.5)
            if (i + j + k) % 2 == 1: # odd number
                r *= -1
            M += r
print ("Mandelung Constant is::", M)
print("It takes %s seconds" % (time.time() - start_time))

当我放一个L=300, 需要超过 7 分钟。这就是我试图改进它的原因。

1个回答

正如@Richard 所提到的,Python 中的循环很慢。我想到了两个解决方案:

  • 使用 NumPy 并对操作进行矢量化。这将以将中间数组存储在内存中为代价来加快计算速度。

  • 将您的计算包装在一个函数中,并使用 Numba 的 jit 装饰器(神奇地)加速您的计算。这意味着有时会重构代码。


使用 NumPy 功能,您可以重写为具有以下内容:

import numpy as np
import time


start_time = time.time()
L = 300
i = np.array(range(-L, L + 1), dtype=np.float)
I, J, K = np.meshgrid(i, i, i)

M = (-1)**(I + J + K)/np.sqrt(I**2 + J**2 + K**2)
M[(I == 0)*(J == 0)*(K == 0)] = 0
M = np.sum(M)
print ("Madelung Constant is::", M)
print("It takes %s seconds" % (time.time() - start_time))

这给出了一个结果

Madelung Constant is:: -1.7456432959005515
It takes 17.750624418258667 seconds

和....相比

Madelung Constant is:: -1.745643295911936
It takes 310.843811750412 seconds

从你的代码。

你也可以使用Numba.

import time
from numba import jit


@jit
def sum_madelung(L):
    M = 0 # Madelung constant
    for i in range(-L , L + 1):
        for j in range(-L, L + 1):
            for k in range(-L, L + 1):
                if i == j == k == 0: # doesn't count the origin (0,0,0)
                    continue
                r = (i**2 + j**2 + k**2)**(-0.5)
                if (i + j + k) % 2 == 1: # odd number
                    r *= -1
                M += r
    return M

start_time = time.time()
M = sum_madelung(300)
print ("Madelung Constant is::", M)
print("It takes %s seconds" % (time.time() - start_time))

结果

Madelung Constant is:: -1.7456432959126562
It takes 2.8461902141571045 seconds