Accelerating Python scientific computing code with Numba

Installing Numba on Windows 10

For installation, see the official Numba website.

Installation on Windows 10 (with an NVIDIA GPU); cudatoolkit is NVIDIA's CUDA toolkit for its GPUs:

conda install numba
conda install cudatoolkit 

Check the installation:

C:\>python
Python 3.8.5 (default, Sep  3 2020, 21:29:08) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import numba
>>> numba.__version__
'0.51.2'


 C:\>numba
numba: error: the following arguments are required: filename

(base) C:\>numba --sysinfo
System info:
--------------------------------------------------------------------------------
__CUDA Information__
CUDA Device Initialized                       : True
CUDA Driver Version                           : 11000
CUDA Detect Output:
Found 1 CUDA devices
id 0     b'GeForce GTX 1650'                              [SUPPORTED]
                      compute capability: 7.5
                           pci device id: 0
                              pci bus id: 1
Summary:
        1/1 devices are supported

Speed test

Reference: "Numba: make Python a hundred times faster" (original title: numba,让python速度提升百倍)

# -*- coding: utf-8 -*-

from numba import njit,float64
import datetime

# Without Numba
def f_1(n):
    x = 123.45
    y = 678.9
    for i in range(1,n):
        temp = x * y / i

    return temp

# With Numba (compiled lazily on the first call)
@njit
def f_2(n):
    x = 123.45
    y = 678.9
    for i in range(1,n):
        temp = x * y / i

    return temp


# With Numba and an explicit signature (compiled when the function is defined)
@njit(float64(float64))
def f_3(n):
    x = 123.45
    y = 678.9
    for i in range(1,n):
        temp = x * y / i

    return temp


start = datetime.datetime.now()
r1 = f_1(100000000)
# total_seconds() covers the whole interval; .microseconds is only the sub-second part
elapsed_1 = (datetime.datetime.now() - start).total_seconds() * 1e6  # microseconds

start = datetime.datetime.now()
r2 = f_2(100000000)
elapsed_2 = (datetime.datetime.now() - start).total_seconds() * 1e6  # microseconds

start = datetime.datetime.now()
r3 = f_3(100000000)
elapsed_3 = (datetime.datetime.now() - start).total_seconds() * 1e6  # microseconds

print("Function 1 elapsed microseconds:", elapsed_1)
print("Function 2 elapsed microseconds:", elapsed_2)
print("Function 3 elapsed microseconds:", elapsed_3)
print("Function 1 elapsed / function 2 elapsed:", elapsed_1 / elapsed_2)

print('r1,r2,r3:', r1, r2, r3)
'''Output (recorded from timedelta.microseconds, i.e. the sub-second part only):
Function 1 elapsed microseconds: 250370
Function 2 elapsed microseconds: 34907
Function 3 elapsed microseconds: 0
Function 1 elapsed / function 2 elapsed: 7.17248689374624
r1,r2,r3: 0.0008381020583810206 0.0008381020583810206 0.0008381020583810206
'''

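For numerical code, Numba gives a clear speedup, and the measured time drops even further when the argument types are specified. Part of that difference comes from when compilation happens: with an explicit signature the function is compiled at definition time, whereas the first call of a plain @njit function includes the JIT compilation. Always use the njit decorator to force nopython-mode compilation, and specify the data types.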

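When the Python code has to deal with None values or other non-numerical logic, keep that logic in a plain Python function and put only the numerical work in the Numba function: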

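from numba import njit

def func_without_numba():
    # Call the function that contains the Numba code
    try:
        func_with_numba()
    except Exception:
        # Put the exception-handling code here
        pass

@njit  # pass a signature or options here if needed
def func_with_numba():
    # Numerical code that needs no exception handling
    pass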
© Licensed under CC BY-NC-SA 4.0

