Added fp32 to fp16 conversion function in Python to make haxpy example work
parent
347f0df32f
commit
e0541c41a1
|
@ -49,7 +49,7 @@ FILES = [
|
|||
"/src/clblast_cuda.cpp",
|
||||
"/src/pyclblast/src/pyclblast.pyx"
|
||||
]
|
||||
HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 291]
|
||||
HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 327]
|
||||
FOOTER_LINES = [98, 57, 112, 275, 6, 6, 6, 9, 2, 41, 56, 37]
|
||||
HEADER_LINES_DOC = 0
|
||||
FOOTER_LINES_DOC = 232
|
||||
|
|
|
@ -13,7 +13,8 @@ import pyclblast
|
|||
|
||||
# Settings for this sample
|
||||
dtype = 'float16'
|
||||
alpha = np.array(1.0).astype(dtype=dtype).item()
|
||||
alpha = 1.5
|
||||
alpha_fp16 = pyclblast.float32_to_float16(alpha)
|
||||
n = 4
|
||||
|
||||
print("# Setting up OpenCL")
|
||||
|
@ -31,7 +32,7 @@ clx.set(x)
|
|||
cly.set(y)
|
||||
|
||||
print("# Example level-1 operation: AXPY")
|
||||
pyclblast.axpy(queue, n, clx, cly, alpha=alpha)
|
||||
pyclblast.axpy(queue, n, clx, cly, alpha=alpha_fp16)
|
||||
queue.finish()
|
||||
print("# Result for vector y: %s" % cly.get())
|
||||
print("# Expected result: %s" % (alpha * x + y))
|
||||
|
|
|
@ -11,6 +11,8 @@
|
|||
#
|
||||
####################################################################################################
|
||||
|
||||
import binascii
|
||||
import struct
|
||||
import numpy as np
|
||||
import pyopencl as cl
|
||||
from pyopencl.array import Array
|
||||
|
@ -288,6 +290,40 @@ def check_matrix(a, name):
|
|||
def check_vector(a, name):
    """Validate the array `a` as a vector.

    Delegates to `check_array`, passing 1 as its second argument
    (presumably the expected number of dimensions -- confirm against
    `check_array`). `name` is forwarded unchanged.
    """
    check_array(a, 1, name)
|
||||
|
||||
####################################################################################################
|
||||
# Half-precision utility functions
|
||||
####################################################################################################
|
||||
|
||||
def float32_to_float16(float32):
    """Return the IEEE 754 half-precision bit pattern for a float value.

    The value is first packed as a single-precision float; its sign,
    exponent and mantissa fields are then re-encoded into the 16-bit
    half-precision layout (1 sign bit, 5 exponent bits, 10 mantissa bits).

    Args:
        float32: The number to convert. It must be packable with the
            struct 'f' format; any exception raised by struct.pack (for
            example for a non-numeric argument or a magnitude outside the
            single-precision range) propagates to the caller.

    Returns:
        An int in the range [0, 0xFFFF] holding the half-precision bits.

    Notes:
        * The mantissa is truncated to 10 bits, not rounded.
        * Finite magnitudes too large for half precision become infinity.
        * Magnitudes below the smallest normal half-precision value
          (2**-14) are flushed to a signed zero; subnormals are not
          produced.
        * NaN inputs stay NaN: the top mantissa bits are kept, and a
          non-zero payload is forced if truncation would clear it.
    """
    # Taken from https://gamedev.stackexchange.com/a/28756
    F16_EXPONENT_BITS = 0x1F
    F16_EXPONENT_SHIFT = 10
    F16_EXPONENT_BIAS = 15
    F16_MANTISSA_BITS = 0x3ff
    F16_MANTISSA_SHIFT = (23 - F16_EXPONENT_SHIFT)
    F16_MAX_EXPONENT = (F16_EXPONENT_BITS << F16_EXPONENT_SHIFT)
    F16_QUIET_NAN_BIT = 0x200  # most significant half-precision mantissa bit

    # Reinterpret the single-precision encoding as an unsigned 32-bit integer.
    f32 = struct.unpack('>I', struct.pack('>f', float32))[0]

    sign = (f32 >> 16) & 0x8000            # sign bit moved to bit 15
    exponent = ((f32 >> 23) & 0xff) - 127  # unbiased exponent
    mantissa = f32 & 0x007fffff

    if exponent == 128:
        # Infinity (mantissa == 0) or NaN (mantissa != 0).
        f16 = sign | F16_MAX_EXPONENT
        if mantissa:
            # Keep the high mantissa bits, as for normal numbers. Masking the
            # low bits instead would drop the quiet-NaN bit and turn a
            # default NaN into infinity.
            payload = (mantissa >> F16_MANTISSA_SHIFT) & F16_MANTISSA_BITS
            f16 |= payload or F16_QUIET_NAN_BIT
    elif exponent > 15:
        # Too large for half precision: saturate to infinity.
        f16 = sign | F16_MAX_EXPONENT
    elif exponent > -15:
        # Normal half-precision range: re-bias exponent, truncate mantissa.
        exponent += F16_EXPONENT_BIAS
        mantissa >>= F16_MANTISSA_SHIFT
        f16 = sign | exponent << F16_EXPONENT_SHIFT | mantissa
    else:
        # Too small for a normal half-precision value: flush to signed zero.
        f16 = sign
    return f16
|
||||
|
||||
####################################################################################################
|
||||
# Swap two vectors: SSWAP/DSWAP/CSWAP/ZSWAP/HSWAP
|
||||
|
|
Loading…
Reference in New Issue