Added fp32 to fp16 conversion function in Python to make haxpy example work

pull/348/head
Cedric Nugteren 2019-01-23 19:52:01 +01:00
parent 347f0df32f
commit e0541c41a1
3 changed files with 40 additions and 3 deletions

View File

@ -49,7 +49,7 @@ FILES = [
"/src/clblast_cuda.cpp",
"/src/pyclblast/src/pyclblast.pyx"
]
HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 291]
HEADER_LINES = [124, 21, 128, 24, 29, 45, 29, 66, 40, 96, 21, 327]
FOOTER_LINES = [98, 57, 112, 275, 6, 6, 6, 9, 2, 41, 56, 37]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 232

View File

@ -13,7 +13,8 @@ import pyclblast
# Settings for this sample
dtype = 'float16'
alpha = np.array(1.0).astype(dtype=dtype).item()
alpha = 1.5
alpha_fp16 = pyclblast.float32_to_float16(alpha)
n = 4
print("# Setting up OpenCL")
@ -31,7 +32,7 @@ clx.set(x)
cly.set(y)
print("# Example level-1 operation: AXPY")
pyclblast.axpy(queue, n, clx, cly, alpha=alpha)
pyclblast.axpy(queue, n, clx, cly, alpha=alpha_fp16)
queue.finish()
print("# Result for vector y: %s" % cly.get())
print("# Expected result: %s" % (alpha * x + y))

View File

@ -11,6 +11,8 @@
#
####################################################################################################
import binascii
import struct
import numpy as np
import pyopencl as cl
from pyopencl.array import Array
@ -288,6 +290,40 @@ def check_matrix(a, name):
def check_vector(a, name):
check_array(a, 1, name)
####################################################################################################
# Half-precision utility functions
####################################################################################################
def float32_to_float16(float32):
# Taken from https://gamedev.stackexchange.com/a/28756
F16_EXPONENT_BITS = 0x1F
F16_EXPONENT_SHIFT = 10
F16_EXPONENT_BIAS = 15
F16_MANTISSA_BITS = 0x3ff
F16_MANTISSA_SHIFT = (23 - F16_EXPONENT_SHIFT)
F16_MAX_EXPONENT = (F16_EXPONENT_BITS << F16_EXPONENT_SHIFT)
a = struct.pack('>f', float32)
b = binascii.hexlify(a)
f32 = int(b, 16)
sign = (f32 >> 16) & 0x8000
exponent = ((f32 >> 23) & 0xff) - 127
mantissa = f32 & 0x007fffff
if exponent == 128:
f16 = sign | F16_MAX_EXPONENT
if mantissa:
f16 |= (mantissa & F16_MANTISSA_BITS)
elif exponent > 15:
f16 = sign | F16_MAX_EXPONENT
elif exponent > -15:
exponent += F16_EXPONENT_BIAS
mantissa >>= F16_MANTISSA_SHIFT
f16 = sign | exponent << F16_EXPONENT_SHIFT | mantissa
else:
f16 = sign
return f16
####################################################################################################
# Swap two vectors: SSWAP/DSWAP/CSWAP/ZSWAP/HSWAP