-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
OPEN: FP integration #3
Changes from all commits
ae9f793
65bdbac
b67df55
6640216
0ebe43f
33920fa
326c793
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ | |
from __future__ import annotations | ||
|
||
import copy | ||
import math | ||
from abc import abstractmethod | ||
from dataclasses import dataclass | ||
from typing import Dict, Generic, Iterable, List, Optional, Type, TypeVar, Union | ||
|
@@ -234,6 +235,139 @@ def checkValue(cls, value: Union[int, Iterable[int]], ctxt: Optional[_NetworkCon | |
return True | ||
|
||
|
||
class FloatImmediate(Immediate[Union[float, Iterable[float]], _ImmediateType]): | ||
typeFraction: int #: int: Represents the number of bits reserved for the fraction part | ||
typeExponent: int #: int: Represents the number of bits reserved for the exponent part | ||
signed: bool #: bool: Represents whether the underlying float is signed or unsigned (should be removed) | ||
|
||
@_classproperty | ||
def typeExponentMax(cls) -> int: | ||
# In floating point, all 1 in exponent is reserved for special numbers (i.e. NaN or Inf) | ||
return 2**(cls.typeExponent) - 2 | ||
|
||
@_classproperty | ||
def typeExponentOffset(cls) -> int: | ||
# The offset added to the exponent | ||
return 2**(cls.typeExponent - 1) - 1 | ||
|
||
# ADEQUINO: This is a ugly workaround for FP, works for bfloat16 and fp32 because bfloat16 is a truncated fp32 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure I understand this comment. What about the code is a workaround? |
||
@classmethod | ||
def partialOrderUpcast(cls, otherCls: Type[Immediate]) -> bool: | ||
if issubclass(otherCls, FloatImmediate): | ||
return cls.typeFraction >= otherCls.typeFraction and cls.typeExponent >= otherCls.typeExponent | ||
else: | ||
return False | ||
|
||
@classmethod | ||
def checkValue(cls, value: Union[float, Iterable[float]], ctxt: Optional[_NetworkContext] = None): | ||
""" | ||
This method tries to manually cast standard python's standard immediate float precision values | ||
(64 bits) to an arbitrary FP representation and check if the new representation is close enough | ||
to the original value. | ||
""" | ||
_val_list = [] | ||
|
||
if isinstance(value, float): | ||
_val_list.append(value) | ||
elif isinstance(value, np.ndarray): | ||
_val_list = value.tolist() | ||
elif isinstance(value, Iterable): | ||
for i in value: | ||
_val_list.append(i) | ||
else: | ||
raise Exception("Immediate type not recognized.") | ||
|
||
for val in _val_list: | ||
# Zero (and subnormals, not implemented) are special cases | ||
if (val == 0): | ||
continue | ||
# Make the value positive | ||
if (val < 0): | ||
val = val * -1 | ||
|
||
# Separate Integer and Fraction of immediate | ||
fraction, integer = math.modf(val) | ||
|
||
# Binarylist for the mantissa | ||
binarylist = [] | ||
f = fraction | ||
|
||
# Fraction binarization, fails if nbits required > n bits mantissa. | ||
# If integer part of immediate is 0, we start counting mantissa bits after we find the first 1 bit. | ||
if (int(integer) > 0): | ||
for i in range(cls.typeFraction): | ||
f = f * 2 | ||
f, fint = math.modf(f) | ||
binarylist.append(str(int(fint))) | ||
if f == 0: | ||
break | ||
elif i == (cls.typeFraction - 1): | ||
return False | ||
else: | ||
flag = 0 | ||
count = cls.typeFraction + 1 | ||
while (count): | ||
f = f * 2 | ||
f, fint = math.modf(f) | ||
binarylist.append(str(int(fint))) | ||
if int(fint) == 1 and flag == 0: | ||
flag = 1 | ||
if f == 0: | ||
break | ||
if flag == 1: | ||
count = count - 1 | ||
if (count == 0): | ||
return False | ||
Comment on lines
+295
to
+320
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All of this float to string to list to int casting seems unnecessary to me. |
||
|
||
# Float exponent part | ||
# It's equal to the length of the integer part minus 1, if the integer part is not zero. | ||
# Otherwise, it's minus the number of 0 bits before the first 1 bit in the fraction representation + 1 | ||
exponent = 0 | ||
if (int(bin(int(integer))[2:]) == 0): | ||
for b in binarylist: | ||
exponent = exponent - 1 | ||
if b == '1': | ||
break | ||
else: | ||
exponent = len(str(bin(int(integer))[2:])) - 1 | ||
|
||
# Check if exponent is representable in n_exponent bits | ||
true_exponent = int(bin(cls.typeExponentOffset + exponent)[2:]) | ||
if (cls.typeExponentOffset + exponent) > cls.typeExponentMax or (cls.typeExponentOffset + exponent) < 0: | ||
return False | ||
|
||
# Append bits to head of mantissa, if integer part is not in scientific notion | ||
binarylist2 = [] | ||
if len(str(bin(int(integer))[2:])) > 1: | ||
for digit in str(bin(int(integer))[3:]): | ||
binarylist2.append((digit)) | ||
|
||
# If integer part is zero, trim the mantissa bits that have been used to calculate the exponent part | ||
if (int(integer) > 0): | ||
finalbinaryfraction = binarylist2 + binarylist | ||
else: | ||
finalbinaryfraction = binarylist | ||
while (finalbinaryfraction[0] == '0'): | ||
finalbinaryfraction.pop(0) | ||
finalbinaryfraction.pop(0) | ||
|
||
# Fix mantissa size | ||
if ((cls.typeFraction - len(finalbinaryfraction)) > 0): | ||
finalbinaryfraction += ['0'] * (cls.typeFraction - len(finalbinaryfraction)) | ||
if (len(finalbinaryfraction) > cls.typeFraction): | ||
finalbinaryfraction = finalbinaryfraction[:cls.typeFraction] | ||
|
||
# Check if the value in binary float represent the immediate value | ||
exponent_part = 2**exponent | ||
mantissa_part = 1 | ||
for (i, m) in enumerate(finalbinaryfraction): | ||
mantissa_part = mantissa_part + 2**(-(i + 1)) * int(m) | ||
if (exponent_part * mantissa_part != val): | ||
return False | ||
|
||
return True | ||
|
||
|
||
class Pointer(BaseType[Optional[str], _PointerType]): | ||
"""Represents a C Pointer type to an underlying BaseType data type | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,7 +25,7 @@ | |
|
||
from typing import Tuple, Type | ||
|
||
from Deeploy.AbstractDataTypes import IntegerImmediate | ||
from Deeploy.AbstractDataTypes import FloatImmediate, IntegerImmediate | ||
|
||
|
||
class int8_t(IntegerImmediate): | ||
|
@@ -76,10 +76,27 @@ class uint64_t(IntegerImmediate): | |
signed = False | ||
|
||
|
||
class bfloat16(FloatImmediate): | ||
typeName = "float16alt" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would suggest to keep the |
||
typeWidth = 16 | ||
typeFraction = 7 | ||
typeExponent = 8 | ||
signed = True | ||
|
||
|
||
class float32(FloatImmediate): | ||
typeName = "float" | ||
typeWidth = 32 | ||
typeFraction = 23 | ||
typeExponent = 8 | ||
signed = True | ||
|
||
|
||
SignedIntegerDataTypes: Tuple[Type[IntegerImmediate], ...] = (int8_t, int16_t, int32_t, int64_t) | ||
UnsignedIntegerDataTypes: Tuple[Type[IntegerImmediate], ...] = (uint8_t, uint16_t, uint32_t, uint64_t) | ||
IntegerDataTypes: Tuple[Type[IntegerImmediate], ...] = (sorted(( | ||
*SignedIntegerDataTypes, | ||
*UnsignedIntegerDataTypes, | ||
), | ||
key = lambda _type: _type.typeWidth)) | ||
FloatDataTypes: Tuple[Type[FloatImmediate], ...] = (bfloat16, float32) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# ---------------------------------------------------------------------- | ||
# | ||
# File: FloatAddTemplate.py | ||
# | ||
# Last edited: 15.12.2021 | ||
# | ||
# Copyright (C) 2021, ETH Zurich and University of Bologna. | ||
# | ||
# Author: Moritz Scherer, ETH Zurich | ||
# | ||
# ---------------------------------------------------------------------- | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the License); you may | ||
# not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an AS IS BASIS, WITHOUT | ||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Dict, List, Tuple | ||
|
||
from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation | ||
|
||
|
||
class _FloatAddTemplate(NodeTemplate): | ||
|
||
def alignToContext(self, ctxt: NetworkContext, | ||
operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: | ||
|
||
data_in_1 = ctxt.lookup(operatorRepresentation['data_in_1']) | ||
data_in_2 = ctxt.lookup(operatorRepresentation['data_in_2']) | ||
data_out = ctxt.lookup(operatorRepresentation['data_out']) | ||
|
||
input_1_offset = 0 | ||
if hasattr(data_in_1, "_signed") and hasattr(data_in_1, "nLevels"): | ||
input_1_offset = (data_in_1._signed == 0) * int(data_in_1.nLevels / 2) | ||
input_2_offset = 0 | ||
if hasattr(data_in_2, "_signed") and hasattr(data_in_2, "nLevels"): | ||
input_2_offset = (data_in_2._signed == 0) * int(data_in_2.nLevels / 2) | ||
output_offset = 0 | ||
if hasattr(data_out, "_signed") and hasattr(data_out, "nLevels"): | ||
output_offset = -(data_out._signed == 0) * int(data_out.nLevels // 2) | ||
|
||
operatorRepresentation['offset'] = input_1_offset + input_2_offset + output_offset | ||
|
||
return ctxt, operatorRepresentation, [] | ||
Comment on lines
+40
to
+52
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't quite understand the use of |
||
|
||
|
||
referenceTemplate = _FloatAddTemplate(""" | ||
// Add (Name: ${nodeName}, Op: ${nodeOp}) | ||
BEGIN_SINGLE_CORE | ||
for (uint32_t i=0;i<${size};i++){ | ||
${data_out}[i] = ${data_in_1}[i] + ${data_in_2}[i] + ${offset}; | ||
} | ||
END_SINGLE_CORE | ||
""") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is not needed, please remove it :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, float numbers are all signed after all, I just put the bool here to make sure it wouldn't conflict with anything else in the framework