-
Notifications
You must be signed in to change notification settings - Fork 92
Pyrtl floating point library #475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
94c4c68
Pyrtl floating point library
gaborszita 15ad996
Remove FloatWireVector
gaborszita e601040
Address comments
gaborszita 27e4871
Address comments
gaborszita 829f016
Merge remote-tracking branch 'origin/development' into pyrtlfloat
gaborszita 9b2772f
Fix test import error
gaborszita File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| from ._types import FloatingPointType, FPTypeProperties, PyrtlFloatConfig, RoundingMode | ||
| from .floatoperations import ( | ||
| BFloat16Operations, | ||
| Float16Operations, | ||
| Float32Operations, | ||
| Float64Operations, | ||
| FloatOperations, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| "FloatingPointType", | ||
| "FPTypeProperties", | ||
| "PyrtlFloatConfig", | ||
| "RoundingMode", | ||
| "FloatOperations", | ||
| "BFloat16Operations", | ||
| "Float16Operations", | ||
| "Float32Operations", | ||
| "Float64Operations", | ||
| ] |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
| import pyrtl | ||
|
|
||
| from ._types import FPTypeProperties | ||
|
|
||
|
|
||
| def get_sign(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: | ||
| """ | ||
| Returns the sign bit of floating point number. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: WireVector holding the sign bit. | ||
| """ | ||
| return wire[fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits] | ||
|
|
||
|
|
||
| def get_exponent(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: | ||
| """ | ||
| Returns the exponent bits of floating point number. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: WireVector holding the exponent bits. | ||
| """ | ||
| return wire[ | ||
| fp_prop.num_mantissa_bits : fp_prop.num_mantissa_bits | ||
| + fp_prop.num_exponent_bits | ||
| ] | ||
|
|
||
|
|
||
| def get_mantissa(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: | ||
| """ | ||
| Returns the mantissa bits of floating point number. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: WireVector holding the mantissa bits. | ||
| """ | ||
| return wire[: fp_prop.num_mantissa_bits] | ||
|
|
||
|
|
||
| def is_zero(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: | ||
|
gaborszita marked this conversation as resolved.
Outdated
|
||
| """ | ||
| Returns whether the floating point number is zero. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: 1-bit WireVector indicating whether the number is zero. | ||
| """ | ||
| return (get_mantissa(fp_prop, wire) == 0) & (get_exponent(fp_prop, wire) == 0) | ||
|
|
||
|
|
||
| def is_inf(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: | ||
| """ | ||
| Returns whether the floating point number is infinity. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: 1-bit WireVector indicating whether the number is infinity. | ||
| """ | ||
| return (get_mantissa(fp_prop, wire) == 0) & ( | ||
| get_exponent(fp_prop, wire) == (1 << fp_prop.num_exponent_bits) - 1 | ||
| ) | ||
|
|
||
|
|
||
| def is_denormalized( | ||
| fp_prop: FPTypeProperties, wire: pyrtl.WireVector | ||
| ) -> pyrtl.WireVector: | ||
| """ | ||
| Returns whether the floating point number is denormalized. | ||
|
gaborszita marked this conversation as resolved.
Outdated
|
||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: 1-bit WireVector indicating whether the number is denormalized. | ||
| """ | ||
| return (get_mantissa(fp_prop, wire) != 0) & (get_exponent(fp_prop, wire) == 0) | ||
|
|
||
|
|
||
| def is_nan(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: | ||
| """ | ||
| Returns whether the floating point number is NaN. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: 1-bit WireVector indicating whether the number is NaN. | ||
| """ | ||
| return (get_mantissa(fp_prop, wire) != 0) & ( | ||
| get_exponent(fp_prop, wire) == (1 << fp_prop.num_exponent_bits) - 1 | ||
| ) | ||
|
|
||
|
|
||
| def make_denormals_zero( | ||
| fp_prop: FPTypeProperties, wire: pyrtl.WireVector | ||
| ) -> pyrtl.WireVector: | ||
| """ | ||
| Returns zero if denormalized, else original number. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param wire: WireVector holding the floating point number. | ||
| :return: WireVector holding the resulting floating point number. | ||
| """ | ||
| out = pyrtl.WireVector( | ||
| bitwidth=fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits + 1 | ||
| ) | ||
| with pyrtl.conditional_assignment: | ||
| with get_exponent(fp_prop, wire) == 0: | ||
| out |= pyrtl.concat( | ||
| get_sign(fp_prop, wire), | ||
| get_exponent(fp_prop, wire), | ||
| pyrtl.Const(0, bitwidth=fp_prop.num_mantissa_bits), | ||
| ) | ||
| with pyrtl.otherwise: | ||
| out |= wire | ||
| return out | ||
|
|
||
|
|
||
| def make_inf( | ||
|
gaborszita marked this conversation as resolved.
Outdated
|
||
| fp_prop: FPTypeProperties, | ||
| exponent: pyrtl.WireVector, | ||
| mantissa: pyrtl.WireVector, | ||
| ) -> None: | ||
| """ | ||
| Sets the exponent and mantissa to represent infinity. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param exponent: WireVector to set the exponent bits. | ||
| :param mantissa: WireVector to set the mantissa bits. | ||
| """ | ||
| exponent |= (1 << fp_prop.num_exponent_bits) - 1 | ||
| mantissa |= 0 | ||
|
|
||
|
|
||
| def make_nan( | ||
| fp_prop: FPTypeProperties, | ||
| exponent: pyrtl.WireVector, | ||
| mantissa: pyrtl.WireVector, | ||
| ) -> None: | ||
| """ | ||
| Sets the exponent and mantissa to represent NaN. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param exponent: WireVector to set the exponent bits. | ||
| :param mantissa: WireVector to set the mantissa bits. | ||
| """ | ||
| exponent |= (1 << fp_prop.num_exponent_bits) - 1 | ||
| mantissa |= 1 << (fp_prop.num_mantissa_bits - 1) | ||
|
|
||
|
|
||
| def make_zero(exponent: pyrtl.WireVector, mantissa: pyrtl.WireVector) -> None: | ||
| """ | ||
| Sets the exponent and mantissa to represent zero. | ||
|
|
||
| :param exponent: WireVector to set the exponent bits. | ||
| :param mantissa: WireVector to set the mantissa bits. | ||
| """ | ||
| exponent |= 0 | ||
| mantissa |= 0 | ||
|
|
||
|
|
||
| def make_largest_finite_number( | ||
| fp_prop: FPTypeProperties, | ||
| exponent: pyrtl.WireVector, | ||
| mantissa: pyrtl.WireVector, | ||
| ) -> None: | ||
| """ | ||
| Sets the exponent and mantissa to represent the largest finite number. | ||
|
|
||
| :param fp_prop: Floating point type properties. | ||
| :param exponent: WireVector to set the exponent bits. | ||
| :param mantissa: WireVector to set the mantissa bits. | ||
| """ | ||
| exponent |= (1 << fp_prop.num_exponent_bits) - 2 | ||
| mantissa |= (1 << fp_prop.num_mantissa_bits) - 1 | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,197 @@ | ||
| import pyrtl | ||
|
|
||
| from ._float_utills import ( | ||
| get_exponent, | ||
| get_mantissa, | ||
| get_sign, | ||
| is_denormalized, | ||
| is_inf, | ||
| is_nan, | ||
| is_zero, | ||
| make_denormals_zero, | ||
| make_inf, | ||
| make_largest_finite_number, | ||
| make_nan, | ||
| make_zero, | ||
| ) | ||
| from ._types import PyrtlFloatConfig, RoundingMode | ||
|
|
||
|
|
||
| def mul( | ||
| config: PyrtlFloatConfig, | ||
| operand_a: pyrtl.WireVector, | ||
| operand_b: pyrtl.WireVector, | ||
| ) -> pyrtl.WireVector: | ||
| """ | ||
| Performs floating point multiplication of two WireVectors. | ||
|
|
||
| :param config: Configuration for the floating point type and rounding mode. | ||
| :param operand_a: The first floating point operand as a WireVector. | ||
| :param operand_b: The second floating point operand as a WireVector. | ||
| :return: The result of the multiplication as a WireVector. | ||
| """ | ||
| fp_type_props = config.fp_type_properties | ||
| rounding_mode = config.rounding_mode | ||
| num_exp_bits = fp_type_props.num_exponent_bits | ||
| num_mant_bits = fp_type_props.num_mantissa_bits | ||
|
|
||
| # Denormalized numbers are not supported, so we flush them to zero. | ||
| operands = (operand_a, operand_b) | ||
| operands_daz = tuple(make_denormals_zero(fp_type_props, op) for op in operands) | ||
|
gaborszita marked this conversation as resolved.
Outdated
|
||
|
|
||
| # Extract the sign and exponent of both operands. | ||
| signs = tuple(get_sign(fp_type_props, op) for op in operands_daz) | ||
| exponents = tuple(get_exponent(fp_type_props, op) for op in operands_daz) | ||
|
|
||
| result_sign = signs[0] ^ signs[1] | ||
|
|
||
| # IEEE-754 floating point numbers have a bias: | ||
| # https://en.wikipedia.org/wiki/Exponent_bias | ||
| # real_exponent = stored_exponent - bias, so stored_exponent = real + bias | ||
|
gaborszita marked this conversation as resolved.
Outdated
|
||
| # Therefore, stored_exponent_product = real_exponent_product + bias | ||
| # = (real_exponent_a + real_exponent_b) + bias | ||
| # = (stored_exponent_a - bias + stored_exponent_b - bias) + bias | ||
| # = stored_exponent_a + stored_exponent_b - bias | ||
| operand_exponent_sums = exponents[0] + exponents[1] | ||
| exponent_bias = 2 ** (fp_type_props.num_exponent_bits - 1) - 1 | ||
| product_exponent = operand_exponent_sums - pyrtl.Const(exponent_bias) | ||
|
|
||
| # Extract the mantissa of both operands and add the implicit leading 1. | ||
| mantissas = tuple( | ||
| pyrtl.concat(pyrtl.Const(1), get_mantissa(fp_type_props, op)) | ||
| for op in operands_daz | ||
| ) | ||
| product_mantissa = mantissas[0] * mantissas[1] | ||
|
|
||
| normalized_product_exponent = pyrtl.WireVector(bitwidth=num_exp_bits + 1) | ||
| normalized_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) | ||
|
|
||
| # We need to normalize (shift right) if the leading bit is 1. | ||
|
gaborszita marked this conversation as resolved.
Outdated
|
||
| # https://numeral-systems.com/ieee-754-multiply/ | ||
| need_to_normalize = product_mantissa[-1] | ||
|
|
||
| if rounding_mode == RoundingMode.RNE: | ||
| guard = pyrtl.WireVector(bitwidth=1) | ||
| sticky = pyrtl.WireVector(bitwidth=1) | ||
| last = pyrtl.WireVector(bitwidth=1) # Last bit of the mantissa before rounding. | ||
|
|
||
| # Assign the normalized mantissa, exponent, guard, sticky, and last bits | ||
|
gaborszita marked this conversation as resolved.
Outdated
|
||
| # based on whether normalization is needed. | ||
|
gaborszita marked this conversation as resolved.
Outdated
|
||
| with pyrtl.conditional_assignment: | ||
| with need_to_normalize: | ||
| normalized_product_mantissa |= product_mantissa[-num_mant_bits - 1 :] | ||
|
gaborszita marked this conversation as resolved.
Outdated
|
||
| normalized_product_exponent |= product_exponent + 1 | ||
| if rounding_mode == RoundingMode.RNE: | ||
| guard |= product_mantissa[-num_mant_bits - 2] | ||
| sticky |= product_mantissa[: -num_mant_bits - 2] != 0 | ||
| last |= product_mantissa[-num_mant_bits - 1] | ||
| with pyrtl.otherwise: | ||
| normalized_product_mantissa |= product_mantissa[-num_mant_bits - 2 : -1] | ||
| normalized_product_exponent |= product_exponent | ||
| if rounding_mode == RoundingMode.RNE: | ||
| guard |= product_mantissa[-num_mant_bits - 3] | ||
| sticky |= product_mantissa[: -num_mant_bits - 3] != 0 | ||
| last |= product_mantissa[-num_mant_bits - 2] | ||
|
|
||
| if rounding_mode == RoundingMode.RNE: | ||
| rounded_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) | ||
| rounded_product_exponent = pyrtl.WireVector(bitwidth=num_exp_bits + 1) | ||
| # Whether exponent was incremented due to rounding (for overflow check). | ||
| exponent_incremented = pyrtl.WireVector(bitwidth=1) | ||
| # If guard bit is not set, number is closer to smaller value: no round. | ||
| # If guard and sticky are set, round up. | ||
| # If guard is set but sticky is not, value is exactly halfway. | ||
| # Following round-to-nearest ties-to-even, round up if last bit is 1. | ||
| round_up = guard & (last | sticky) | ||
| with pyrtl.conditional_assignment: | ||
|
gaborszita marked this conversation as resolved.
Outdated
|
||
| with round_up: | ||
| with normalized_product_mantissa == (1 << num_mant_bits) - 1: | ||
| rounded_product_mantissa |= 0 | ||
| rounded_product_exponent |= normalized_product_exponent + 1 | ||
| exponent_incremented |= 1 | ||
| with pyrtl.otherwise: | ||
| rounded_product_mantissa |= normalized_product_mantissa + 1 | ||
| rounded_product_exponent |= normalized_product_exponent | ||
| exponent_incremented |= 0 | ||
| with pyrtl.otherwise: | ||
| rounded_product_mantissa |= normalized_product_mantissa | ||
| rounded_product_exponent |= normalized_product_exponent | ||
| exponent_incremented |= 0 | ||
|
|
||
| result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) | ||
| result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) | ||
|
|
||
| # Check whether operands are special: NaN, infinity, zero, or denormalized. | ||
| operand_nans = tuple(is_nan(fp_type_props, op) for op in operands_daz) | ||
| operand_infs = tuple(is_inf(fp_type_props, op) for op in operands_daz) | ||
| operand_zeros = tuple(is_zero(fp_type_props, op) for op in operands_daz) | ||
| operand_denorms = tuple(is_denormalized(fp_type_props, op) for op in operands_daz) | ||
|
gaborszita marked this conversation as resolved.
Outdated
|
||
|
|
||
| # We check for overflow and underflow by computing max and min exponent | ||
| # values of the sum of operands before rounding and normalization. | ||
|
gaborszita marked this conversation as resolved.
Outdated
|
||
| # These values depend on the operands. If the result requires | ||
| # normalization, the exponent is incremented by 1. Additionally, rounding | ||
| # may further increase the exponent. Therefore, we subtract these | ||
| # potential increments from the absolute maximum exponent, which is one | ||
| # less than the all-1s exponent (reserved for inf/NaN) plus bias. | ||
| # Similarly, we subtract these increments from the absolute minimum | ||
| # exponent, which is 1 plus the exponent bias. | ||
| sum_exponent_max_value = pyrtl.Const(2**num_exp_bits - 2 + exponent_bias) | ||
| sum_exponent_min_value = pyrtl.Const(1 + exponent_bias) | ||
| if rounding_mode == RoundingMode.RNE: | ||
| exponent_max_value = ( | ||
| sum_exponent_max_value - need_to_normalize - exponent_incremented | ||
|
gaborszita marked this conversation as resolved.
Outdated
|
||
| ) | ||
| exponent_min_value = ( | ||
| sum_exponent_min_value - need_to_normalize - exponent_incremented | ||
| ) | ||
| else: | ||
| exponent_max_value = sum_exponent_max_value - need_to_normalize | ||
| exponent_min_value = sum_exponent_min_value - need_to_normalize | ||
|
|
||
| # Assign the raw result's exponent and mantissa depending on whether RNE rounding | ||
| # is used. The calculated exponent WireVector has an extra bit due to the carry-out | ||
| # from addition, so we take only the lower num_exp_bits to remove this extra bit. | ||
| if rounding_mode == RoundingMode.RNE: | ||
| raw_result_exponent = rounded_product_exponent[:num_exp_bits] | ||
| raw_result_mantissa = rounded_product_mantissa | ||
| else: | ||
| raw_result_exponent = normalized_product_exponent[:num_exp_bits] | ||
| raw_result_mantissa = normalized_product_mantissa | ||
|
|
||
| with pyrtl.conditional_assignment: | ||
| # If either operand is NaN, or if one operand is infinity and the other is | ||
| # zero, the result is NaN. | ||
| with ( | ||
| operand_nans[0] | ||
| | operand_nans[1] | ||
| | (operand_infs[0] & operand_zeros[1]) | ||
| | (operand_zeros[0] & operand_infs[1]) | ||
| ): | ||
| make_nan(fp_type_props, result_exponent, result_mantissa) | ||
| # If either operand is infinity, the result is infinity. | ||
| with operand_infs[0] | operand_infs[1]: | ||
| make_inf(fp_type_props, result_exponent, result_mantissa) | ||
| # Detect overflow. | ||
| with operand_exponent_sums > exponent_max_value: | ||
| if rounding_mode == RoundingMode.RNE: | ||
| make_inf(fp_type_props, result_exponent, result_mantissa) | ||
| else: | ||
| make_largest_finite_number( | ||
| fp_type_props, result_exponent, result_mantissa | ||
| ) | ||
| # If either operand is zero, if underflow occurred, or if either operand is | ||
| # denormalized, the result is zero. | ||
| with ( | ||
| operand_zeros[0] | ||
| | operand_zeros[1] | ||
| | (operand_exponent_sums < exponent_min_value) | ||
| | operand_denorms[0] | ||
| | operand_denorms[1] | ||
| ): | ||
| make_zero(result_exponent, result_mantissa) | ||
| with pyrtl.otherwise: | ||
| result_exponent |= raw_result_exponent | ||
| result_mantissa |= raw_result_mantissa | ||
|
|
||
| return pyrtl.concat(result_sign, result_exponent, result_mantissa) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.