# --- new file: pyrtl/rtllib/pyrtlfloat/__init__.py ---
from ._types import FloatingPointType, FPTypeProperties, PyrtlFloatConfig, RoundingMode
from .floatoperations import (
    BFloat16Operations,
    Float16Operations,
    Float32Operations,
    Float64Operations,
    FloatOperations,
)

# Public API of the pyrtlfloat package.
__all__ = [
    "FloatingPointType",
    "FPTypeProperties",
    "PyrtlFloatConfig",
    "RoundingMode",
    "FloatOperations",
    "BFloat16Operations",
    "Float16Operations",
    "Float32Operations",
    "Float64Operations",
]

# --- new file: pyrtl/rtllib/pyrtlfloat/_add_sub.py ---
import pyrtl

from ._float_utils import (
    _fp_wire_struct,
    _RawResult,
    _RawResultGRS,
    _round_rne,
    check_kinds,
    make_denormals_zero,
    make_inf,
    make_largest_finite_number,
    make_nan,
    make_zero,
)
from ._types import FPTypeProperties, PyrtlFloatConfig, RoundingMode


def add(
    config: PyrtlFloatConfig,
    operand_a: pyrtl.WireVector,
    operand_b: pyrtl.WireVector,
) -> pyrtl.WireVector:
    """
    Performs floating point addition of two WireVectors.

    Pipeline: flush denormals to zero, sort operands by magnitude, align
    mantissas, compute both a sum and a difference datapath, select the
    correct one based on the operand signs, round, and finally patch in
    special-case results (NaN/inf/zero/overflow/underflow).

    :param config: Configuration for the floating point type and rounding mode.
    :param operand_a: The first floating point operand as a WireVector.
    :param operand_b: The second floating point operand as a WireVector.
    :return: The result of the addition as a WireVector.
    """
    fp_type_props = config.fp_type_properties
    rounding_mode = config.rounding_mode
    num_exp_bits = fp_type_props.num_exponent_bits
    num_mant_bits = fp_type_props.num_mantissa_bits
    FP = _fp_wire_struct(num_exp_bits, num_mant_bits)

    # Denormalized numbers are not supported, so we flush them to zero.
    operands = tuple(
        make_denormals_zero(fp_type_props, op) for op in (operand_a, operand_b)
    )
    fps = _sort_operands(FP, operands)
    del operands

    # Align mantissas and compute both the addition and subtraction results.
    # Both datapaths are always built; sign-based selection happens later.
    smaller_mantissa_shifted_grs, larger_mantissa_extended = _align_mantissa(fps)
    sum_result, sum_carry = _add_operands(
        fps[1].exponent, smaller_mantissa_shifted_grs, larger_mantissa_extended
    )
    difference_result, num_leading_zeros = _sub_operands(
        num_mant_bits,
        fps[1].exponent,
        smaller_mantissa_shifted_grs,
        larger_mantissa_extended,
    )
    del smaller_mantissa_shifted_grs, larger_mantissa_extended

    # Select the correct result based on operand signs, then round if needed.
    raw_result, rounding_exponent_incremented = _select_and_round(
        fp_type_props,
        fps,
        sum_result,
        difference_result,
        rounding_mode,
    )
    del sum_result, difference_result

    return _handle_special_cases(
        FP,
        fp_type_props,
        fps,
        raw_result,
        sum_carry,
        num_leading_zeros,
        rounding_mode,
        rounding_exponent_incremented,
    )


def sub(
    config: PyrtlFloatConfig,
    operand_a: pyrtl.WireVector,
    operand_b: pyrtl.WireVector,
) -> pyrtl.WireVector:
    """
    Performs floating point subtraction of two WireVectors.

    Implemented as a - b == a + (-b): the sign bit of ``operand_b`` is
    flipped with an XOR mask and the result is fed to :func:`add`.

    :param config: Configuration for the floating point type and rounding mode.
    :param operand_a: The first floating point operand as a WireVector.
    :param operand_b: The second floating point operand as a WireVector.
    :return: The result of the subtraction as a WireVector.
    """
    num_exp_bits = config.fp_type_properties.num_exponent_bits
    num_mant_bits = config.fp_type_properties.num_mantissa_bits
    # XOR with a mask whose only set bit is the sign (MSB) negates the operand.
    operand_b_negated = operand_b ^ pyrtl.concat(
        pyrtl.Const(1, bitwidth=1),
        pyrtl.Const(0, bitwidth=num_exp_bits + num_mant_bits),
    )
    return add(config, operand_a, operand_b_negated)


def _sort_operands(
    FP,
    operands: tuple,
) -> tuple:
    """
    Sorts operands by absolute value.

    :param FP: The FP wire_struct class for the current floating point type.
    :param operands: Tuple of two operand WireVectors with denormals flushed to zero.
    :return: Tuple of (smaller_fp, larger_fp) as FP wire_struct instances.
    """
    total_bits = operands[0].bitwidth
    sorted_operands = [pyrtl.WireVector(bitwidth=total_bits) for _ in range(2)]
    with pyrtl.conditional_assignment:
        # Compare the lower (total_bits - 1) bits, which excludes the sign bit,
        # to determine the operand with the smaller absolute value. Comparing
        # the raw exponent-and-mantissa bit pattern as an unsigned integer
        # orders IEEE 754 magnitudes correctly.
        with operands[0][: total_bits - 1] < operands[1][: total_bits - 1]:
            sorted_operands[0] |= operands[0]
            sorted_operands[1] |= operands[1]
        with pyrtl.otherwise:
            sorted_operands[0] |= operands[1]
            sorted_operands[1] |= operands[0]
    return tuple(FP(FP=op) for op in sorted_operands)
def _align_mantissa(
    fps: tuple,
) -> tuple[pyrtl.WireVector, pyrtl.WireVector]:
    """
    Aligns the smaller mantissa to the larger operand's exponent and computes
    the guard, round, and sticky (GRS) bits for RNE rounding.

    :param fps: Tuple of (smaller_fp, larger_fp) as FP wire_struct instances.
    :return: Tuple of (smaller_mantissa_shifted_grs, larger_mantissa_extended).
    """
    num_mant_bits = fps[0].mantissa.bitwidth
    # Prepend the implicit leading 1 of a normalized IEEE 754 significand.
    mantissas_with_leading_1 = tuple(
        pyrtl.concat(pyrtl.Const(1), fp.mantissa) for fp in fps
    )

    # Align mantissas by shifting the smaller one to match the larger's exponent.
    # Shifting the mantissa right by one divides the value by two, while adding
    # one to the exponent multiplies the value by two. Doing both simultaneously
    # preserves the value while matching the operands' exponents for addition.
    shift_amount = fps[1].exponent - fps[0].exponent
    smaller_mantissa_shifted = pyrtl.shift_right_logical(
        mantissas_with_leading_1[0], shift_amount
    )

    # RNE rounding uses the guard, round, and sticky bits.
    # When shifting the smaller mantissa to the right, some bits are shifted out.
    # The most significant bit to the right of the mantissa after shifting becomes
    # the guard bit, the next bit becomes the round bit, and any remaining bits
    # are ORed together to form the sticky bit.
    # https://drilian.com/posts/2023.01.10-floating-point-numbers-and-rounding/
    grs = pyrtl.WireVector(bitwidth=3)
    with pyrtl.conditional_assignment:
        # If the smaller mantissa is shifted by 2 or more, the two most
        # significant bits shifted out are the guard and round bits, and the
        # sticky bit is the OR of all remaining bits.
        with shift_amount >= 2:
            guard_and_round = pyrtl.shift_right_logical(
                mantissas_with_leading_1[0], shift_amount - 2
            )[:2]
            # Mask with the least significant (shift_amount - 2) bits set to 1
            # NOTE(review): the mask constant is only num_mant_bits wide while
            # the significand (with leading 1) is num_mant_bits + 1 wide; when
            # shift_amount - 2 reaches num_mant_bits the shifted 1 falls off the
            # mask — confirm the sticky bit is still correct for exponent
            # differences larger than the significand width.
            mask = (
                pyrtl.shift_left_logical(
                    pyrtl.Const(1, bitwidth=num_mant_bits),
                    shift_amount - 2,
                )
                - 1
            )
            sticky = (mantissas_with_leading_1[0] & mask) != 0
            grs |= pyrtl.concat(guard_and_round, sticky)
        # If the smaller mantissa is shifted by 1, the single bit shifted out
        # is the guard bit; the round bit and sticky bit are both 0.
        with shift_amount == 1:
            grs |= pyrtl.concat(
                mantissas_with_leading_1[0][0], pyrtl.Const(0, bitwidth=2)
            )
        # If not shifted, guard, round, and sticky bits are all 0.
        with pyrtl.otherwise:
            grs |= 0

    # Concatenate the shifted smaller mantissa with the GRS bits, and extend
    # the larger mantissa with three zeros to align it with the smaller mantissa.
    smaller_mantissa_shifted_grs = pyrtl.concat(smaller_mantissa_shifted, grs)
    larger_mantissa_extended = pyrtl.concat(
        mantissas_with_leading_1[1], pyrtl.Const(0, bitwidth=3)
    )
    return smaller_mantissa_shifted_grs, larger_mantissa_extended


def _select_and_round(
    fp_type_props: FPTypeProperties,
    fps: tuple,
    sum_result: _RawResultGRS,
    diff_result: _RawResultGRS,
    rounding_mode: RoundingMode,
) -> tuple:
    """
    Selects the addition or subtraction result based on operand signs, then
    applies RNE rounding if configured.

    :param fp_type_props: Floating point type properties.
    :param fps: Tuple of (smaller_fp, larger_fp) as FP wire_struct instances.
    :param sum_result: _RawResultGRS from the addition operation.
    :param diff_result: _RawResultGRS from the subtraction operation.
    :param rounding_mode: The rounding mode to apply.
    :return: Tuple of (_RawResult, rounding_exponent_incremented). The second
        element is None for RTZ rounding mode.
    """
    num_exp_bits = fp_type_props.num_exponent_bits
    num_mant_bits = fp_type_props.num_mantissa_bits

    raw_result = _RawResult(
        exponent=pyrtl.WireVector(bitwidth=num_exp_bits),
        mantissa=pyrtl.WireVector(bitwidth=num_mant_bits),
    )
    if rounding_mode == RoundingMode.RNE:
        raw_grs = pyrtl.WireVector(bitwidth=3)

    # Determine whether we need to add or subtract the operands.
    with pyrtl.conditional_assignment:
        # If the operands have the same sign, we perform addition.
        # For example, (+a) + (+b) or (-a) + (-b).
        with fps[0].sign == fps[1].sign:
            raw_result.exponent |= sum_result.exponent
            raw_result.mantissa |= sum_result.mantissa
            if rounding_mode == RoundingMode.RNE:
                raw_grs |= sum_result.grs
        # If the operands have different signs, we perform subtraction.
        # For example, (+a) + (-b) or (-a) + (+b).
        with pyrtl.otherwise:
            raw_result.exponent |= diff_result.exponent
            raw_result.mantissa |= diff_result.mantissa
            if rounding_mode == RoundingMode.RNE:
                raw_grs |= diff_result.grs

    if rounding_mode == RoundingMode.RNE:
        return _round_rne(raw_result, raw_grs)
    # No additional rounding logic needed for RTZ rounding mode
    return raw_result, None
def _handle_special_cases(
    FP,
    fp_type_props: FPTypeProperties,
    fps: tuple,
    raw_result: _RawResult,
    sum_carry: pyrtl.WireVector,
    num_leading_zeros: pyrtl.WireVector,
    rounding_mode: RoundingMode,
    rounding_exponent_incremented,
):
    """
    Handles special cases: NaN, infinity, zero, overflow, and underflow.

    The branches of the conditional_assignment below are evaluated in priority
    order, so NaN inputs win over infinities, which win over zeros, which win
    over overflow/underflow, which win over the normal-path result.

    :param FP: The FP wire_struct class for the current floating point type.
    :param fp_type_props: Floating point type properties.
    :param fps: Tuple of (smaller_fp, larger_fp) as FP wire_struct instances.
    :param raw_result: Pre-rounding result as a _RawResult.
    :param sum_carry: Carry bit from the addition operation.
    :param num_leading_zeros: Leading zero count from the subtraction normalization.
    :param rounding_mode: The rounding mode being used.
    :param rounding_exponent_incremented: Whether rounding incremented the exponent
        (None for RTZ mode).
    :return: The final FP wire_struct result.
    """
    num_exp_bits = fp_type_props.num_exponent_bits

    operand_kinds = tuple(check_kinds(fp) for fp in fps)

    # Pre-compute special value constants for use inside conditional_assignment.
    final_result = FP(sign=None, exponent=None, mantissa=None)
    nan_exp, nan_mant = make_nan(fp_type_props)
    inf_exp, inf_mant = make_inf(fp_type_props)
    zero_exp, zero_mant = make_zero(fp_type_props)
    largest_exp, largest_mant = make_largest_finite_number(fp_type_props)

    # Check for overflow on addition.
    # We check for overflow by calculating the max value of the larger
    # operand's exponent. This value can vary depending on the operands.
    # If there was a carry out from the addition, the result exponent is
    # incremented by 1. Additionally, if rounding causes the exponent to
    # increment, we need to account for that as well. Therefore, we
    # subtract these increments from the absolute maximum exponent, which
    # is one less than the all-1s exponent (reserved for infinity/NaN).
    # However, instead of performing subtractions in hardware, we use
    # conditional assignments to determine the appropriate maximum exponent
    # value. Since we are merely selecting among three possible values,
    # instantiating a subtractor would be overkill.
    base_exponent_max_value = 2**num_exp_bits - 2
    exponent_max_value = pyrtl.WireVector(bitwidth=num_exp_bits)
    if rounding_mode == RoundingMode.RNE:
        with pyrtl.conditional_assignment:
            # Both the carry and the rounding step bumped the exponent.
            with rounding_exponent_incremented & sum_carry:
                exponent_max_value |= base_exponent_max_value - 2
            # Exactly one of the two bumped the exponent.
            with rounding_exponent_incremented | sum_carry:
                exponent_max_value |= base_exponent_max_value - 1
            with pyrtl.otherwise:
                exponent_max_value |= base_exponent_max_value
    else:
        with pyrtl.conditional_assignment:
            with sum_carry:
                exponent_max_value |= base_exponent_max_value - 1
            with pyrtl.otherwise:
                exponent_max_value |= base_exponent_max_value

    # Check for underflow on subtraction.
    # We check for underflow by computing the min value of the larger
    # operand's exponent. As with overflow, this value can vary depending
    # on the operands. We subtract the number of leading zeros from the
    # larger exponent to obtain the subtraction exponent. Additionally,
    # if rounding causes the exponent to increment, we need to account
    # for that. Therefore, we add the number of leading zeros and
    # subtract the rounding increment from the absolute minimum exponent,
    # which is one greater than the all-0s exponent (reserved for
    # zero and denormals). Similarly to overflow checking, we use a
    # conditional assignment to account for the rounding increment,
    # avoiding the need to instantiate an adder.
    base_exponent_min_value = pyrtl.WireVector(bitwidth=num_exp_bits)
    if rounding_mode == RoundingMode.RNE:
        with pyrtl.conditional_assignment:
            with rounding_exponent_incremented:
                base_exponent_min_value |= 0
            with pyrtl.otherwise:
                base_exponent_min_value |= 1
    else:
        base_exponent_min_value <<= 1
    exponent_min_value = num_leading_zeros + base_exponent_min_value

    with pyrtl.conditional_assignment:
        # If either operand is NaN, or if both operands are infinities with
        # opposite signs, the result is NaN.
        with (
            operand_kinds[0].is_nan
            | operand_kinds[1].is_nan
            | (
                operand_kinds[0].is_inf
                & operand_kinds[1].is_inf
                & (fps[1].sign != fps[0].sign)
            )
        ):
            final_result.sign |= fps[1].sign
            final_result.exponent |= nan_exp
            final_result.mantissa |= nan_mant

        # If either operand is infinity, result is infinity with that sign.
        # Note fps[1] (the larger magnitude) carries the sign in both branches:
        # if either operand is infinite, the infinity sorted as the larger one.
        with operand_kinds[0].is_inf:
            final_result.sign |= fps[1].sign
            final_result.exponent |= inf_exp
            final_result.mantissa |= inf_mant
        with operand_kinds[1].is_inf:
            final_result.sign |= fps[1].sign
            final_result.exponent |= inf_exp
            final_result.mantissa |= inf_mant

        # If operands are equal in magnitude but opposite in sign, the result is +0.
        with (
            (fps[0].mantissa == fps[1].mantissa)
            & (fps[0].exponent == fps[1].exponent)
            & (fps[1].sign != fps[0].sign)
        ):
            final_result.sign |= 0
            final_result.exponent |= zero_exp
            final_result.mantissa |= zero_mant

        # If either operand is zero, the result is the other operand.
        with operand_kinds[0].is_zero:
            final_result.sign |= fps[1].sign
            final_result.mantissa |= fps[1].mantissa
            final_result.exponent |= fps[1].exponent
        with operand_kinds[1].is_zero:
            final_result.sign |= fps[0].sign
            final_result.mantissa |= fps[0].mantissa
            final_result.exponent |= fps[0].exponent

        # Checks if an addition was performed and the result overflowed.
        with (fps[0].sign == fps[1].sign) & (fps[1].exponent > exponent_max_value):
            final_result.sign |= fps[1].sign
            # IEEE 754 Section 7.4: On overflow, RNE rounds to infinity,
            # while truncation rounds to the largest finite number.
            if rounding_mode == RoundingMode.RNE:
                final_result.exponent |= inf_exp
                final_result.mantissa |= inf_mant
            else:
                final_result.exponent |= largest_exp
                final_result.mantissa |= largest_mant

        # Checks if a subtraction was performed and the result underflowed.
        with (fps[0].sign != fps[1].sign) & (fps[1].exponent < exponent_min_value):
            final_result.sign |= fps[1].sign
            final_result.exponent |= zero_exp
            final_result.mantissa |= zero_mant
        # Otherwise no special cases apply: this is the common case.
        with pyrtl.otherwise:
            final_result.sign |= fps[1].sign
            final_result.exponent |= raw_result.exponent
            final_result.mantissa |= raw_result.mantissa

    return final_result
def _add_operands(
    larger_operand_exponent: pyrtl.WireVector,
    smaller_mantissa_shifted_grs: pyrtl.WireVector,
    larger_mantissa_extended: pyrtl.WireVector,
) -> tuple[_RawResultGRS, pyrtl.WireVector]:
    """
    Helper function for performing addition of two floating point mantissas.

    :param larger_operand_exponent: Exponent of the larger operand.
    :param smaller_mantissa_shifted_grs: Mantissa of the smaller operand
        shifted to align with the larger operand and concatenated with GRS.
    :param larger_mantissa_extended: Larger mantissa with three zeros.
    :return: Tuple of (_RawResultGRS, carry bit).
    """
    sum_mantissa_grs = pyrtl.WireVector()
    sum_mantissa_grs <<= larger_mantissa_extended + smaller_mantissa_shifted_grs
    # The MSB of the (widened) sum is the carry out of the mantissa addition.
    sum_carry = sum_mantissa_grs[-1]
    # Pick the correct bits for the mantissa and GRS based on carry out:
    # on carry, everything shifts down by one (divide by two) to renormalize.
    sum_mantissa = pyrtl.select(sum_carry, sum_mantissa_grs[4:], sum_mantissa_grs[3:-1])
    sum_grs = pyrtl.select(
        sum_carry,
        # The bit shifted out of the round position folds into sticky.
        pyrtl.concat(sum_mantissa_grs[2:4], sum_mantissa_grs[:2] != 0),
        sum_mantissa_grs[:3],
    )
    # Increment the exponent if there was a carry out.
    sum_exponent = pyrtl.select(
        sum_carry, larger_operand_exponent + 1, larger_operand_exponent
    )
    return _RawResultGRS(sum_exponent, sum_mantissa, sum_grs), sum_carry


def _sub_operands(
    num_mant_bits: int,
    larger_operand_exponent: pyrtl.WireVector,
    smaller_mantissa_shifted_grs: pyrtl.WireVector,
    larger_mantissa_extended: pyrtl.WireVector,
) -> tuple[_RawResultGRS, pyrtl.WireVector]:
    """
    Helper function for performing subtraction of two floating point mantissas.

    :param num_mant_bits: Number of mantissa bits.
    :param larger_operand_exponent: Exponent of the larger operand.
    :param smaller_mantissa_shifted_grs: Mantissa of the smaller operand
        shifted to align with the larger operand and concatenated with GRS.
    :param larger_mantissa_extended: Larger mantissa with three zeros.
    :return: Tuple of (_RawResultGRS, num leading zeros).
    """

    # Priority encoder that counts the number of leading zeros in a WireVector.
    # NOTE(review): if none of the scanned bits is set, the conditional
    # default leaves `out` at 0 (i.e. "no leading zeros"); presumably safe
    # because exact cancellation is handled as a special case by the caller —
    # confirm against _handle_special_cases.
    def leading_zero_priority_encoder(wire: pyrtl.WireVector, length: int):
        out = pyrtl.WireVector(
            bitwidth=pyrtl.infer_val_and_bitwidth(length - 1).bitwidth
        )
        with pyrtl.conditional_assignment:
            # Scan from the MSB downward; the first set bit wins.
            for i in range(wire.bitwidth - 1, wire.bitwidth - length - 1, -1):
                with wire[i]:
                    out |= wire.bitwidth - i - 1
        return out

    difference_mantissa_grs = pyrtl.WireVector(bitwidth=num_mant_bits + 4)
    difference_mantissa_grs <<= larger_mantissa_extended - smaller_mantissa_shifted_grs
    # Normalize result by shifting left until leading 1 is in position.
    num_leading_zeros = leading_zero_priority_encoder(
        difference_mantissa_grs, num_mant_bits + 1
    )
    difference_mantissa_grs_shifted = pyrtl.shift_left_logical(
        difference_mantissa_grs, num_leading_zeros
    )
    difference_mantissa = difference_mantissa_grs_shifted[3:]
    difference_grs = difference_mantissa_grs_shifted[:3]
    # Adjust the exponent by subtracting the number of leading zeros.
    difference_exponent = larger_operand_exponent - num_leading_zeros
    return _RawResultGRS(
        difference_exponent, difference_mantissa, difference_grs
    ), num_leading_zeros
# --- new file: pyrtl/rtllib/pyrtlfloat/_float_utils.py ---
import pyrtl

from ._types import FPTypeProperties


def _fp_wire_struct(num_exp_bits, num_mant_bits):
    """Creates a wire_struct class for an IEEE 754 floating point number.

    The returned class has three fields: sign (1 bit), exponent, and mantissa.

    :param num_exp_bits: Number of exponent bits.
    :param num_mant_bits: Number of mantissa bits.
    :return: A wire_struct class with sign, exponent, and mantissa fields.
    """

    # The class is created per call because the field widths are parameters.
    @pyrtl.wire_struct
    class FP:
        sign: 1
        exponent: num_exp_bits
        mantissa: num_mant_bits

    return FP


@pyrtl.wire_struct
class _GRS:
    """Guard, round, and sticky bits used for RNE rounding."""

    guard: 1
    round: 1
    sticky: 1


@pyrtl.wire_struct
class _FPKinds:
    """Bits indicating the kind of a floating-point number."""

    is_nan: 1
    is_inf: 1
    is_zero: 1
    is_denormalized: 1


class _RawResult:
    """Groups the exponent and mantissa WireVectors of a result."""

    def __init__(
        self,
        exponent: pyrtl.WireVector,
        mantissa: pyrtl.WireVector,
    ):
        self.exponent = exponent
        self.mantissa = mantissa


class _RawResultGRS(_RawResult):
    """Groups the exponent, mantissa, and GRS WireVectors of a result."""

    def __init__(
        self,
        exponent: pyrtl.WireVector,
        mantissa: pyrtl.WireVector,
        grs: pyrtl.WireVector,
    ):
        super().__init__(exponent, mantissa)
        self.grs = grs


def check_kinds(fp) -> _FPKinds:
    """
    Returns an _FPKinds wire struct indicating the kind of the given floating point
    number.

    Classification follows the IEEE 754 encoding: an all-ones exponent means
    NaN (nonzero mantissa) or infinity (zero mantissa); an all-zeros exponent
    means zero (zero mantissa) or a denormalized value (nonzero mantissa).

    :param fp: FP wire_struct instance.
    :return: _FPKinds instance.
    """
    kinds = _FPKinds(is_nan=None, is_inf=None, is_zero=None, is_denormalized=None)
    max_exp = (1 << fp.exponent.bitwidth) - 1
    all_ones_exp = fp.exponent == max_exp
    zero_exp = fp.exponent == 0
    zero_mant = fp.mantissa == 0
    kinds.is_nan <<= all_ones_exp & ~zero_mant
    kinds.is_inf <<= all_ones_exp & zero_mant
    kinds.is_zero <<= zero_exp & zero_mant
    kinds.is_denormalized <<= zero_exp & ~zero_mant
    return kinds


def make_denormals_zero(
    fp_prop: FPTypeProperties, wire: pyrtl.WireVector
) -> pyrtl.WireVector:
    """
    Returns zero if denormalized, else original number.
    https://en.wikipedia.org/wiki/Subnormal_number

    The sign bit is preserved, so a flushed denormal keeps its sign
    (a negative denormal becomes -0).

    :param fp_prop: Floating point type properties.
    :param wire: WireVector holding the floating point number.
    :return: WireVector holding the resulting floating point number.
    """
    FP = _fp_wire_struct(fp_prop.num_exponent_bits, fp_prop.num_mantissa_bits)
    fp = FP(FP=wire)
    out = pyrtl.WireVector(
        bitwidth=fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits + 1
    )
    with pyrtl.conditional_assignment:
        # A zero exponent covers both true zeros and denormals; zeroing the
        # mantissa is a no-op for true zeros and flushes denormals.
        with fp.exponent == 0:
            out |= pyrtl.concat(
                fp.sign,
                fp.exponent,
                pyrtl.Const(0, bitwidth=fp_prop.num_mantissa_bits),
            )
        with pyrtl.otherwise:
            out |= wire
    return out
def make_inf(fp_props: FPTypeProperties) -> tuple:
    """
    Returns (exponent, mantissa) WireVectors representing infinity.

    IEEE 754 encodes infinity as an all-ones exponent with a zero mantissa.

    :param fp_props: Floating point type properties.
    :return: Tuple of (exponent, mantissa) WireVectors.
    """
    exp_bits = fp_props.num_exponent_bits
    mant_bits = fp_props.num_mantissa_bits
    all_ones_exponent = (1 << exp_bits) - 1
    return (
        pyrtl.Const(all_ones_exponent, bitwidth=exp_bits),
        pyrtl.Const(0, bitwidth=mant_bits),
    )


def make_nan(fp_props: FPTypeProperties) -> tuple:
    """
    Returns (exponent, mantissa) WireVectors representing NaN.

    Encoded as an all-ones exponent with only the most significant mantissa
    bit set (a quiet NaN pattern).

    :param fp_props: Floating point type properties.
    :return: Tuple of (exponent, mantissa) WireVectors.
    """
    exp_bits = fp_props.num_exponent_bits
    mant_bits = fp_props.num_mantissa_bits
    all_ones_exponent = (1 << exp_bits) - 1
    quiet_nan_mantissa = 1 << (mant_bits - 1)
    return (
        pyrtl.Const(all_ones_exponent, bitwidth=exp_bits),
        pyrtl.Const(quiet_nan_mantissa, bitwidth=mant_bits),
    )


def make_zero(fp_props: FPTypeProperties) -> tuple:
    """
    Returns (exponent, mantissa) WireVectors representing zero.

    Both fields are all zeros; the caller supplies the sign bit.

    :param fp_props: Floating point type properties.
    :return: Tuple of (exponent, mantissa) WireVectors.
    """
    return (
        pyrtl.Const(0, bitwidth=fp_props.num_exponent_bits),
        pyrtl.Const(0, bitwidth=fp_props.num_mantissa_bits),
    )


def make_largest_finite_number(fp_props: FPTypeProperties) -> tuple:
    """
    Returns (exponent, mantissa) WireVectors representing the largest finite number.

    Encoded as the largest non-reserved exponent (all ones minus one) with an
    all-ones mantissa.

    :param fp_props: Floating point type properties.
    :return: Tuple of (exponent, mantissa) WireVectors.
    """
    exp_bits = fp_props.num_exponent_bits
    mant_bits = fp_props.num_mantissa_bits
    largest_exponent = (1 << exp_bits) - 2
    all_ones_mantissa = (1 << mant_bits) - 1
    return (
        pyrtl.Const(largest_exponent, bitwidth=exp_bits),
        pyrtl.Const(all_ones_mantissa, bitwidth=mant_bits),
    )
def _round_rne(
    raw_result: _RawResult,
    raw_grs: pyrtl.WireVector,
) -> tuple:
    """
    Round the floating point result using round to nearest, ties to even (RNE).

    Uses the GRS bits to determine if the result needs to be rounded up.

    :param raw_result: Pre-rounding result as a _RawResult.
    :param raw_grs: GRS bits of the raw result before rounding (guard=MSB, sticky=LSB).
    :return: Tuple of (rounded _RawResult, rounding_exponent_incremented).
    """
    num_mant_bits = raw_result.mantissa.bitwidth
    num_exp_bits = raw_result.exponent.bitwidth
    grs = _GRS(_GRS=raw_grs)
    # Least significant mantissa bit: decides the direction of a tie.
    last = raw_result.mantissa[0]
    # If guard bit is not set, number is closer to smaller value: no round up.
    # If guard bit is set and round or sticky is set, round up.
    # If guard bit is set but round and sticky are not set, value is exactly
    # halfway. Following round-to-nearest ties-to-even, round up if last bit
    # of mantissa is 1 (to make it even); otherwise do not round up.
    # https://drilian.com/posts/2023.01.10-floating-point-numbers-and-rounding/
    round_up = grs.guard & (last | grs.round | grs.sticky)
    rounded = _RawResult(
        exponent=pyrtl.WireVector(bitwidth=num_exp_bits),
        mantissa=pyrtl.WireVector(bitwidth=num_mant_bits),
    )
    # Whether exponent was incremented due to rounding (for overflow check).
    rounding_exponent_incremented = pyrtl.WireVector(bitwidth=1)
    with pyrtl.conditional_assignment:
        with round_up:
            # If rounding causes a mantissa overflow, we need to increment the
            # exponent: an all-ones mantissa plus one wraps to 0 with the
            # implicit leading 1 moving up one binade.
            with raw_result.mantissa == (1 << num_mant_bits) - 1:
                rounded.mantissa |= 0
                rounded.exponent |= raw_result.exponent + 1
                rounding_exponent_incremented |= 1
            with pyrtl.otherwise:
                rounded.mantissa |= raw_result.mantissa + 1
                rounded.exponent |= raw_result.exponent
                rounding_exponent_incremented |= 0
        with pyrtl.otherwise:
            rounded.mantissa |= raw_result.mantissa
            rounded.exponent |= raw_result.exponent
            rounding_exponent_incremented |= 0
    return rounded, rounding_exponent_incremented
# --- new file: pyrtl/rtllib/pyrtlfloat/_multiplication.py ---
import pyrtl

from ._float_utils import (
    _fp_wire_struct,
    _RawResult,
    _round_rne,
    check_kinds,
    make_denormals_zero,
    make_inf,
    make_largest_finite_number,
    make_nan,
    make_zero,
)
from ._types import FPTypeProperties, PyrtlFloatConfig, RoundingMode


def mul(
    config: PyrtlFloatConfig,
    operand_a: pyrtl.WireVector,
    operand_b: pyrtl.WireVector,
) -> pyrtl.WireVector:
    """
    Performs floating point multiplication of two WireVectors.

    Pipeline: flush denormals to zero, multiply significands and sum
    exponents, normalize and round the product, then patch in special-case
    results (NaN/inf/zero/overflow/underflow).

    :param config: Configuration for the floating point type and rounding mode.
    :param operand_a: The first floating point operand as a WireVector.
    :param operand_b: The second floating point operand as a WireVector.
    :return: The result of the multiplication as a WireVector.
    """
    fp_type_props = config.fp_type_properties
    rounding_mode = config.rounding_mode
    num_exp_bits = fp_type_props.num_exponent_bits
    num_mant_bits = fp_type_props.num_mantissa_bits

    # Denormalized numbers are not supported, so we flush them to zero.
    operands = tuple(
        make_denormals_zero(fp_type_props, op) for op in (operand_a, operand_b)
    )

    # Extract the sign, exponent, and mantissa of both operands.
    FP = _fp_wire_struct(num_exp_bits, num_mant_bits)
    fps = tuple(FP(FP=op) for op in operands)
    del operands

    # The product's sign is the XOR of the operand signs.
    result_sign = fps[0].sign ^ fps[1].sign

    # Compute the product exponent and mantissa.
    operand_exponent_sums, product_exponent, product_mantissa = _multiply(
        fps,
        num_exp_bits,
    )

    # Normalize the product and perform rounding.
    raw_result, need_to_normalize, exponent_incremented = _normalize_and_round(
        product_exponent,
        product_mantissa,
        fp_type_props,
        rounding_mode,
    )
    del product_mantissa, product_exponent

    return _handle_special_cases(
        FP,
        fp_type_props,
        fps,
        result_sign,
        raw_result,
        operand_exponent_sums,
        need_to_normalize,
        exponent_incremented,
        rounding_mode,
    )


def _multiply(
    fps: tuple,
    num_exp_bits: int,
) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]:
    """
    Computes the sum of operand exponents, the product exponent, and the raw
    product mantissa.

    :param fps: Tuple of FP wire_struct instances for the two operands.
    :param num_exp_bits: Number of exponent bits.
    :return: Tuple of (operand_exponent_sums, product_exponent,
        product_mantissa).
    """
    # IEEE-754 floating point numbers have a bias:
    # https://en.wikipedia.org/wiki/Exponent_bias
    # stored_exponent = real_exponent + bias
    # The sum of the stored exponents of the operands is (real0 + bias) + (real1 + bias)
    # = real0 + real1 + 2*bias.
    # Subtracting bias gives the stored exponent of the product: real0 + real1 + bias.
    operand_exponent_sums = fps[0].exponent + fps[1].exponent
    exponent_bias = 2 ** (num_exp_bits - 1) - 1
    product_exponent = operand_exponent_sums - pyrtl.Const(exponent_bias)

    # Extract the mantissa of both operands and add the implicit leading 1.
    mantissas = tuple(pyrtl.concat(pyrtl.Const(1), fp.mantissa) for fp in fps)
    product_mantissa = mantissas[0] * mantissas[1]

    return operand_exponent_sums, product_exponent, product_mantissa


def _normalize_and_round(
    product_exponent: pyrtl.WireVector,
    product_mantissa: pyrtl.WireVector,
    fp_type_props: FPTypeProperties,
    rounding_mode: RoundingMode,
) -> tuple:
    """
    Normalizes the product mantissa and applies rounding if configured.

    :param product_exponent: The product exponent (sum of operand exponents
        minus bias).
    :param product_mantissa: Raw product of the two mantissas (with implicit 1s).
    :param fp_type_props: Floating point type properties.
    :param rounding_mode: The rounding mode to apply.
    :return: Tuple of (_RawResult, need_to_normalize, exponent_incremented).
        exponent_incremented is None for RTZ rounding mode.
    """
    num_exp_bits = fp_type_props.num_exponent_bits
    num_mant_bits = fp_type_props.num_mantissa_bits
    # We're multiplying two numbers that both have the form 1.<mantissa> in
    # binary. The product's binary point sits just after its second-most
    # significant bit, giving the form ab.cdef... where each letter is one bit.
    #
    # Either a (the MSB) or b (the bit after the MSB) must be 1 because the
    # product of two 1.<mantissa> numbers is always in [1, 4): the smallest
    # is 1.0... * 1.0... = 01.0... and the largest is 1.111... * 1.111...
    # = 11.111... Therefore, there must be one or two bits to the left of the
    # binary point.
    #
    # A properly normalized result must have the form 1.<mantissa>, so:
    # - If a (the MSB) is 1, the product is 1.bcdef... so we increment the
    #   exponent to reinterpret the result as 1.bcdef... and the mantissa is
    #   correct as is.
    # - If a is 0, the product is 0b.cdef... = 01.cdef... (b must be 1 since a
    #   is 0), so we shift the mantissa left by 1 to reinterpret the result
    #   as 1.cdef... and the exponent is correct as is.
    pyrtl.rtl_assert(
        product_mantissa[-1] | product_mantissa[-2],
        AssertionError("product mantissa MSB or the bit after the MSB must be 1"),
    )
    need_to_normalize = product_mantissa[-1]
    aligned_mantissa = pyrtl.WireVector(bitwidth=product_mantissa.bitwidth)
    normalized_product_exponent = pyrtl.WireVector(bitwidth=product_exponent.bitwidth)
    with pyrtl.conditional_assignment:
        with need_to_normalize:
            aligned_mantissa |= product_mantissa
            normalized_product_exponent |= product_exponent + 1
        with pyrtl.otherwise:
            # Shift left by one (drop the zero MSB, append a zero LSB).
            aligned_mantissa |= pyrtl.concat(
                product_mantissa[:-1], pyrtl.Const(0, bitwidth=1)
            )
            normalized_product_exponent |= product_exponent

    # Strip the implicit leading 1 to get the stored mantissa bits. The slice
    # keeps the top num_mant_bits + 1 bits; assignment into the narrower wire
    # drops the MSB, which is the implicit leading 1.
    normalized_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits)
    normalized_product_mantissa <<= aligned_mantissa[-num_mant_bits - 1 :]

    if rounding_mode == RoundingMode.RNE:
        # Extract guard, round, and sticky bits for rounding.
        # https://drilian.com/posts/2023.01.10-floating-point-numbers-and-rounding/
        guard = aligned_mantissa[-num_mant_bits - 2]
        round_bit = aligned_mantissa[-num_mant_bits - 3]
        sticky = aligned_mantissa[: -num_mant_bits - 3] != 0

        raw_product = _RawResult(
            exponent=normalized_product_exponent,
            mantissa=normalized_product_mantissa,
        )
        raw_grs = pyrtl.concat(guard, round_bit, sticky)
        rounded_product, exponent_incremented = _round_rne(raw_product, raw_grs)
        # Truncate the widened exponent back to the stored exponent width;
        # out-of-range values are caught by the caller's overflow check.
        raw_result = _RawResult(
            exponent=rounded_product.exponent[:num_exp_bits],
            mantissa=rounded_product.mantissa,
        )
        return raw_result, need_to_normalize, exponent_incremented
    raw_result = _RawResult(
        exponent=normalized_product_exponent[:num_exp_bits],
        mantissa=normalized_product_mantissa,
    )
    return raw_result, need_to_normalize, None
+ pyrtl.rtl_assert( + product_mantissa[-1] | product_mantissa[-2], + AssertionError("product mantissa MSB or the bit after the MSB must be 1"), + ) + need_to_normalize = product_mantissa[-1] + aligned_mantissa = pyrtl.WireVector(bitwidth=product_mantissa.bitwidth) + normalized_product_exponent = pyrtl.WireVector(bitwidth=product_exponent.bitwidth) + with pyrtl.conditional_assignment: + with need_to_normalize: + aligned_mantissa |= product_mantissa + normalized_product_exponent |= product_exponent + 1 + with pyrtl.otherwise: + aligned_mantissa |= pyrtl.concat( + product_mantissa[:-1], pyrtl.Const(0, bitwidth=1) + ) + normalized_product_exponent |= product_exponent + + # Strip the implicit leading 1 to get the stored mantissa bits. + normalized_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + normalized_product_mantissa <<= aligned_mantissa[-num_mant_bits - 1 :] + + if rounding_mode == RoundingMode.RNE: + # Extract guard, round, and sticky bits for rounding. + # https://drilian.com/posts/2023.01.10-floating-point-numbers-and-rounding/ + guard = aligned_mantissa[-num_mant_bits - 2] + round_bit = aligned_mantissa[-num_mant_bits - 3] + sticky = aligned_mantissa[: -num_mant_bits - 3] != 0 + + raw_product = _RawResult( + exponent=normalized_product_exponent, + mantissa=normalized_product_mantissa, + ) + raw_grs = pyrtl.concat(guard, round_bit, sticky) + rounded_product, exponent_incremented = _round_rne(raw_product, raw_grs) + raw_result = _RawResult( + exponent=rounded_product.exponent[:num_exp_bits], + mantissa=rounded_product.mantissa, + ) + return raw_result, need_to_normalize, exponent_incremented + raw_result = _RawResult( + exponent=normalized_product_exponent[:num_exp_bits], + mantissa=normalized_product_mantissa, + ) + return raw_result, need_to_normalize, None + + +def _handle_special_cases( + FP, + fp_type_props: FPTypeProperties, + fps: tuple, + result_sign: pyrtl.WireVector, + raw_result: _RawResult, + operand_exponent_sums: pyrtl.WireVector, + 
need_to_normalize: pyrtl.WireVector, + exponent_incremented, + rounding_mode: RoundingMode, +): + """ + Handles special cases: NaN, infinity, zero, overflow, and underflow. + + :param FP: The FP wire_struct class for the current floating point type. + :param fp_type_props: Floating point type properties. + :param fps: Tuple of FP wire_struct instances for the two operands. + :param result_sign: Sign bit of the result. + :param raw_result: Normalized (and possibly rounded) result as a _RawResult. + :param operand_exponent_sums: Sum of the two operand exponents. + :param need_to_normalize: Whether the product mantissa required normalization. + :param exponent_incremented: Whether rounding incremented the exponent + (None for RTZ mode). + :param rounding_mode: The rounding mode being used. + :return: The final FP wire_struct result. + """ + num_exp_bits = fp_type_props.num_exponent_bits + exponent_bias = 2 ** (num_exp_bits - 1) - 1 + + # Check whether operands are special: NaN, infinity, zero, or denormalized. + operand_kinds = tuple(check_kinds(fp) for fp in fps) + + # Pre-compute special value constants for use inside conditional_assignment. + result = FP(sign=None, exponent=None, mantissa=None) + result.sign <<= result_sign + nan_exp, nan_mant = make_nan(fp_type_props) + inf_exp, inf_mant = make_inf(fp_type_props) + zero_exp, zero_mant = make_zero(fp_type_props) + largest_exp, largest_mant = make_largest_finite_number(fp_type_props) + + # We check for overflow and underflow by computing max and min exponent + # values of the sum of operands' exponent before rounding and normalization. + # These values depend on the operands. If the result requires + # normalization, the exponent is incremented by 1. Additionally, rounding + # may further increase the exponent. Therefore, we subtract these + # potential increments from the base maximum exponent, which is one + # less than the all-1s exponent (reserved for inf/NaN) plus bias. 
+ # Similarly, we subtract these increments from the base minimum + # exponent, which is 1 plus the exponent bias. However, instead of performing + # subtractions in hardware, we use conditional assignments to determine the + # appropriate maximum and minimum exponent values. Since we are merely + # selecting among three possible values, instantiating a subtractor + # would be overkill. + base_exponent_max_value = 2**num_exp_bits - 2 + exponent_bias + base_exponent_min_value = 1 + exponent_bias + exponent_max_value = pyrtl.WireVector(bitwidth=operand_exponent_sums.bitwidth) + exponent_min_value = pyrtl.WireVector(bitwidth=operand_exponent_sums.bitwidth) + if rounding_mode == RoundingMode.RNE: + with pyrtl.conditional_assignment: + with exponent_incremented & need_to_normalize: + exponent_max_value |= base_exponent_max_value - 2 + exponent_min_value |= base_exponent_min_value - 2 + with exponent_incremented | need_to_normalize: + exponent_max_value |= base_exponent_max_value - 1 + exponent_min_value |= base_exponent_min_value - 1 + with pyrtl.otherwise: + exponent_max_value |= base_exponent_max_value + exponent_min_value |= base_exponent_min_value + else: + with pyrtl.conditional_assignment: + with need_to_normalize: + exponent_max_value |= base_exponent_max_value - 1 + exponent_min_value |= base_exponent_min_value - 1 + with pyrtl.otherwise: + exponent_max_value |= base_exponent_max_value + exponent_min_value |= base_exponent_min_value + + with pyrtl.conditional_assignment: + # If either operand is NaN, or if one operand is infinity and the other is + # zero, the result is NaN. + with ( + operand_kinds[0].is_nan + | operand_kinds[1].is_nan + | (operand_kinds[0].is_inf & operand_kinds[1].is_zero) + | (operand_kinds[0].is_zero & operand_kinds[1].is_inf) + ): + result.exponent |= nan_exp + result.mantissa |= nan_mant + # If either operand is infinity, the result is infinity. 
+ with operand_kinds[0].is_inf | operand_kinds[1].is_inf: + result.exponent |= inf_exp + result.mantissa |= inf_mant + # Detect overflow. + with operand_exponent_sums > exponent_max_value: + if rounding_mode == RoundingMode.RNE: + result.exponent |= inf_exp + result.mantissa |= inf_mant + else: + result.exponent |= largest_exp + result.mantissa |= largest_mant + # If either operand is zero, if underflow occurred, or if either operand is + # denormalized, the result is zero. + with ( + operand_kinds[0].is_zero + | operand_kinds[1].is_zero + | (operand_exponent_sums < exponent_min_value) + | operand_kinds[0].is_denormalized + | operand_kinds[1].is_denormalized + ): + result.exponent |= zero_exp + result.mantissa |= zero_mant + # Otherwise no special cases apply: this is the common case. + with pyrtl.otherwise: + result.exponent |= raw_result.exponent + result.mantissa |= raw_result.mantissa + + return result diff --git a/pyrtl/rtllib/pyrtlfloat/_types.py b/pyrtl/rtllib/pyrtlfloat/_types.py new file mode 100644 index 00000000..a18df565 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/_types.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass +from enum import Enum + + +class RoundingMode(Enum): + """ + Enum representing different rounding modes. + + Attributes: + RTZ (int): Round towards zero (truncate). + RNE (int): Round to nearest, ties to even (default mode). + """ + + RTZ = 1 + RNE = 2 + + +@dataclass(frozen=True) +class FPTypeProperties: + """ + Data class representing properties of a floating-point type. + + Attributes: + num_exponent_bits (int): Number of bits used for the exponent. + num_mantissa_bits (int): Number of bits used for the mantissa. + """ + + num_exponent_bits: int + num_mantissa_bits: int + + +class FloatingPointType(Enum): + """ + Enum representing different floating-point types. + + Attributes: + BFLOAT16 (FPTypeProperties): BFloat16 type properties. + FLOAT16 (FPTypeProperties): Float16 type properties. 
+ FLOAT32 (FPTypeProperties): Float32 type properties. + FLOAT64 (FPTypeProperties): Float64 type properties. + """ + + BFLOAT16 = FPTypeProperties(num_exponent_bits=8, num_mantissa_bits=7) + FLOAT16 = FPTypeProperties(num_exponent_bits=5, num_mantissa_bits=10) + FLOAT32 = FPTypeProperties(num_exponent_bits=8, num_mantissa_bits=23) + FLOAT64 = FPTypeProperties(num_exponent_bits=11, num_mantissa_bits=52) + + +@dataclass(frozen=True) +class PyrtlFloatConfig: + """ + Data class representing the configuration for PyrtlFloat operations (floating point + type properties and rounding mode). + + Attributes: + fp_type_properties (FPTypeProperties): Properties of the floating-point type. + rounding_mode (RoundingMode): Rounding mode to be used. + """ + + fp_type_properties: FPTypeProperties + rounding_mode: RoundingMode diff --git a/pyrtl/rtllib/pyrtlfloat/floatoperations.py b/pyrtl/rtllib/pyrtlfloat/floatoperations.py new file mode 100644 index 00000000..c3bdf98b --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/floatoperations.py @@ -0,0 +1,180 @@ +import pyrtl + +from ._add_sub import add, sub +from ._multiplication import mul +from ._types import FloatingPointType, PyrtlFloatConfig, RoundingMode + + +def _validate_operand_bitwidths( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, +) -> None: + """Validate that operand bitwidths match the floating point config.""" + fp_props = config.fp_type_properties + expected_bitwidth = fp_props.num_exponent_bits + fp_props.num_mantissa_bits + 1 + if operand_a.bitwidth != expected_bitwidth: + msg = ( + f"operand_a bitwidth {operand_a.bitwidth} does not match expected " + f"bitwidth {expected_bitwidth} for floating point type" + ) + raise pyrtl.PyrtlError(msg) + if operand_b.bitwidth != expected_bitwidth: + msg = ( + f"operand_b bitwidth {operand_b.bitwidth} does not match expected " + f"bitwidth {expected_bitwidth} for floating point type" + ) + raise pyrtl.PyrtlError(msg) + + +class 
FloatOperations:
+    """
+    Floating point operations (mul, add, sub) on WireVectors, parameterized by
+    an explicit PyrtlFloatConfig passed to each call.
+    """
+
+    # The rounding mode used for typed floating-point operations.
+    # To change it, set this variable to the desired RoundingMode value.
+    default_rounding_mode = RoundingMode.RNE
+
+    @staticmethod
+    def mul(
+        config: PyrtlFloatConfig,
+        operand_a: pyrtl.WireVector,
+        operand_b: pyrtl.WireVector,
+    ) -> pyrtl.WireVector:
+        """
+        Performs floating point multiplication of two WireVectors. The bitwidth of
+        the operands must be num_exponent_bits + num_mantissa_bits + 1, where
+        num_exponent_bits and num_mantissa_bits are defined in the config.
+
+        :param config: Configuration for the floating point type and rounding mode.
+        :param operand_a: The first floating point operand as a WireVector.
+        :param operand_b: The second floating point operand as a WireVector.
+        :return: The result of the multiplication as a WireVector.
+        :raises PyrtlError: If operand bitwidths don't match config.
+        """
+        _validate_operand_bitwidths(config, operand_a, operand_b)
+        return mul(config, operand_a, operand_b)
+
+    @staticmethod
+    def add(
+        config: PyrtlFloatConfig,
+        operand_a: pyrtl.WireVector,
+        operand_b: pyrtl.WireVector,
+    ) -> pyrtl.WireVector:
+        """
+        Performs floating point addition of two WireVectors. The bitwidth of
+        the operands must be num_exponent_bits + num_mantissa_bits + 1, where
+        num_exponent_bits and num_mantissa_bits are defined in the config.
+
+        :param config: Configuration for the floating point type and rounding mode.
+        :param operand_a: The first floating point operand as a WireVector.
+        :param operand_b: The second floating point operand as a WireVector.
+        :return: The result of the addition as a WireVector.
+        :raises PyrtlError: If operand bitwidths don't match config.
+ """ + _validate_operand_bitwidths(config, operand_a, operand_b) + return add(config, operand_a, operand_b) + + @staticmethod + def sub( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + """ + Performs floating point subtraction of two WireVectors. The bitwidth of + the operands must be num_exponent_bits + num_mantissa_bits + 1, where + num_exponent_bits and num_mantissa_bits are defined in the config. + + :param config: Configuration for the floating point type and rounding mode. + :param operand_a: The first floating point operand as a WireVector. + :param operand_b: The second floating point operand as a WireVector. + :return: The result of the subtraction as a WireVector. + :raises PyrtlError: If operand bitwidths don't match config. + """ + _validate_operand_bitwidths(config, operand_a, operand_b) + return sub(config, operand_a, operand_b) + + +class _BaseTypedFloatOperations: + _fp_type: FloatingPointType = None + + @classmethod + def mul( + cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector + ) -> pyrtl.WireVector: + """ + Performs floating point multiplication of two WireVectors. The bitwidth of + the operands must match the bitwidth of the floating point type of this class. + + :param operand_a: The first floating point operand as a WireVector. + :param operand_b: The second floating point operand as a WireVector. + :return: The result of the multiplication as a WireVector. + """ + return FloatOperations.mul(cls._get_config(), operand_a, operand_b) + + @classmethod + def add( + cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector + ) -> pyrtl.WireVector: + """ + Performs floating point addition of two WireVectors. The bitwidth of + the operands must match the bitwidth of the floating point type of this class. + + :param operand_a: The first floating point operand as a WireVector. + :param operand_b: The second floating point operand as a WireVector. 
+ :return: The result of the addition as a WireVector. + """ + return FloatOperations.add(cls._get_config(), operand_a, operand_b) + + @classmethod + def sub( + cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector + ) -> pyrtl.WireVector: + """ + Performs floating point subtraction of two WireVectors. The bitwidth of + the operands must match the bitwidth of the floating point type of this class. + + :param operand_a: The first floating point operand as a WireVector. + :param operand_b: The second floating point operand as a WireVector. + :return: The result of the subtraction as a WireVector. + """ + return FloatOperations.sub(cls._get_config(), operand_a, operand_b) + + @classmethod + def _get_config(cls) -> PyrtlFloatConfig: + return PyrtlFloatConfig( + cls._fp_type.value, FloatOperations.default_rounding_mode + ) + + +class BFloat16Operations(_BaseTypedFloatOperations): + """ + Operations for BFloat16 floating point type. + """ + + _fp_type = FloatingPointType.BFLOAT16 + + +class Float16Operations(_BaseTypedFloatOperations): + """ + Operations for Float16 floating point type. + """ + + _fp_type = FloatingPointType.FLOAT16 + + +class Float32Operations(_BaseTypedFloatOperations): + """ + Operations for Float32 floating point type. + """ + + _fp_type = FloatingPointType.FLOAT32 + + +class Float64Operations(_BaseTypedFloatOperations): + """ + Operations for Float64 floating point type. 
+ """ + + _fp_type = FloatingPointType.FLOAT64 diff --git a/tests/rtllib/pyrtlfloat/float16_test_utils.py b/tests/rtllib/pyrtlfloat/float16_test_utils.py new file mode 100644 index 00000000..43542a52 --- /dev/null +++ b/tests/rtllib/pyrtlfloat/float16_test_utils.py @@ -0,0 +1,57 @@ +"""Shared Float16 test utilities: constants, encoding helpers, and assertions.""" + +import unittest + +# IEEE 754 Float16 special values +FLOAT16_POS_ZERO = 0x0000 +FLOAT16_NEG_ZERO = 0x8000 +FLOAT16_POS_INF = 0x7C00 +FLOAT16_NEG_INF = 0xFC00 +# Quiet NaN: https://en.wikipedia.org/wiki/NaN#Quiet_NaN +# Encoding: https://en.wikipedia.org/wiki/NaN#Encoding +FLOAT16_NAN = 0x7E00 +FLOAT16_ONE = 0x3C00 # 1.0 +FLOAT16_NEG_ONE = 0xBC00 # -1.0 +FLOAT16_TWO = 0x4000 # 2.0 +FLOAT16_NEG_TWO = 0xC000 # -2.0 +FLOAT16_THREE = 0x4200 # 3.0 +FLOAT16_HALF = 0x3800 # 0.5 +FLOAT16_ONE_POINT_FIVE = 0x3E00 # 1.5 +FLOAT16_LARGEST_NORMAL = 0x7BFF # Largest normal number (~65504) +FLOAT16_DENORMALIZED = 0x0001 # Smallest denormalized number + + +def float16_parts(sign, exp, mant): + """Construct Float16 from sign, exponent, and mantissa.""" + assert sign in (0, 1), f"sign must be 0 or 1, got {sign}" + assert 0 <= exp <= 31, f"exponent must be in [0, 31], got {exp}" + assert 0 <= mant <= 1023, f"mantissa must be in [0, 1023], got {mant}" + return (sign << 15) | (exp << 10) | mant + + +def decode_float16(bits): + """Decode Float16 bits to (sign, exponent, mantissa).""" + assert 0 <= bits <= 0xFFFF, f"bits must be a 16-bit value, got {bits:#06x}" + return (bits >> 15) & 1, (bits >> 10) & 0x1F, bits & 0x3FF + + +def is_nan(bits): + """Check if Float16 bits represent NaN.""" + _, exp, mant = decode_float16(bits) + return exp == 0x1F and mant != 0 + + +def assertFloat16Equal(test_case: unittest.TestCase, sim, output_name, expected): + """Assert that a simulated Float16 output matches expected, with decoded info.""" + actual = sim.inspect(output_name) + if actual != expected: + actual_sign, actual_exponent, 
actual_mantissa = decode_float16(actual) + expected_sign, expected_exponent, expected_mantissa = decode_float16( + expected + ) + test_case.fail( + f"{output_name}: expected {expected:#06x} sign: {expected_sign}, " + f"exponent: {expected_exponent}, mantissa: {expected_mantissa};\n" + f"got {actual:#06x} sign: {actual_sign}, exponent: {actual_exponent}, " + f"mantissa: {actual_mantissa}" + ) diff --git a/tests/rtllib/pyrtlfloat/test_add_sub.py b/tests/rtllib/pyrtlfloat/test_add_sub.py new file mode 100644 index 00000000..3c8a961e --- /dev/null +++ b/tests/rtllib/pyrtlfloat/test_add_sub.py @@ -0,0 +1,428 @@ +import unittest + +import pyrtl +from pyrtl.rtllib.pyrtlfloat import Float16Operations, FloatOperations, RoundingMode + +from .float16_test_utils import ( + FLOAT16_DENORMALIZED, + FLOAT16_LARGEST_NORMAL, + FLOAT16_NAN, + FLOAT16_NEG_INF, + FLOAT16_NEG_ONE, + FLOAT16_NEG_ZERO, + FLOAT16_ONE, + FLOAT16_POS_INF, + FLOAT16_POS_ZERO, + FLOAT16_THREE, + FLOAT16_TWO, + assertFloat16Equal, + float16_parts, + is_nan, +) + +# Additional Float16 constants used only in add/sub tests +FLOAT16_QUARTER = 0x3400 # 0.25 +FLOAT16_HALF = 0x3800 # 0.5 +FLOAT16_ONE_POINT_FIVE = 0x3E00 # 1.5 +FLOAT16_ONE_POINT_TWOFIVE = 0x3D00 # 1.25 +FLOAT16_SMALLEST_NORMAL = 0x0400 # Smallest normal number (2^-14) + + +class TestAddition(unittest.TestCase): + """Tests for Float16 addition operations.""" + + def setUp(self): + pyrtl.reset_working_block() + self.a = pyrtl.Input(bitwidth=16, name="a") + self.b = pyrtl.Input(bitwidth=16, name="b") + FloatOperations.default_rounding_mode = RoundingMode.RNE + result_rne = pyrtl.Output(name="result_rne") + result_rne <<= Float16Operations.add(self.a, self.b) + FloatOperations.default_rounding_mode = RoundingMode.RTZ + result_rtz = pyrtl.Output(name="result_rtz") + result_rtz <<= Float16Operations.add(self.a, self.b) + self.sim = pyrtl.Simulation() + + def assertFloat16Equal(self, output_name, expected): + assertFloat16Equal(self, self.sim, output_name, 
expected) + + ############################ + # Normal cases. + + def test_add_one_plus_two(self): + """Test 1.0 + 2.0 = 3.0""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_TWO}) + self.assertFloat16Equal("result_rne", FLOAT16_THREE) + self.assertFloat16Equal("result_rtz", FLOAT16_THREE) + + def test_add_one_plus_half(self): + """Test 1.0 + 0.5 = 1.5 (no rounding, GRS=000)""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_HALF}) + self.assertFloat16Equal("result_rne", FLOAT16_ONE_POINT_FIVE) + self.assertFloat16Equal("result_rtz", FLOAT16_ONE_POINT_FIVE) + + def test_add_one_plus_quarter(self): + """Test 1.0 + 0.25 = 1.25 (no rounding, shift by 2)""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_QUARTER}) + self.assertFloat16Equal("result_rne", FLOAT16_ONE_POINT_TWOFIVE) + self.assertFloat16Equal("result_rtz", FLOAT16_ONE_POINT_TWOFIVE) + + def test_add_half_plus_half(self): + """Test 0.5 + 0.5 = 1.0""" + self.sim.step({"a": FLOAT16_HALF, "b": FLOAT16_HALF}) + self.assertFloat16Equal("result_rne", FLOAT16_ONE) + self.assertFloat16Equal("result_rtz", FLOAT16_ONE) + + def test_add_with_carry(self): + """Test 1.5 + 1.5 = 3.0 (carry propagates to exponent)""" + self.sim.step({"a": FLOAT16_ONE_POINT_FIVE, "b": FLOAT16_ONE_POINT_FIVE}) + self.assertFloat16Equal("result_rne", FLOAT16_THREE) + self.assertFloat16Equal("result_rtz", FLOAT16_THREE) + + def test_add_opposite_signs_equal_magnitude(self): + """Test 1.0 + (-1.0) = 0""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NEG_ONE}) + self.assertFloat16Equal("result_rne", FLOAT16_POS_ZERO) + self.assertFloat16Equal("result_rtz", FLOAT16_POS_ZERO) + + ############################ + # Rounding tests. + + def test_rounding_g1_r0_s0_lsb0_tie_truncates(self): + """Test G=1, R=0, S=0, LSB=0: tie, RNE truncates to even. + + a = 1.0 (exp=15, mant=0) + b = 0.5 * (1 + 1/1024) = exp=14, mant=1 + Shift b by 1: G=1 (bit 0 of original), R=0, S=0 + Sum mantissa LSB = 0, so RNE truncates. + Both RNE and RTZ produce same result. 
+ """ + a = float16_parts(0, 15, 0) # 1.0 + b = float16_parts(0, 14, 1) # 0.5 * (1 + 1/1024) + expected = float16_parts(0, 15, 512) # 1.5 + self.sim.step({"a": a, "b": b}) + self.assertFloat16Equal("result_rne", expected) + self.assertFloat16Equal("result_rtz", expected) + + def test_rounding_g1_r0_s0_lsb1_tie_rounds_up(self): + """Test G=1, R=0, S=0, LSB=1: tie, RNE rounds up to even. + + a = 1.0 * (1 + 1/1024) = exp=15, mant=1 + b = 0.5 * (1 + 1/1024) = exp=14, mant=1 + Shift b by 1: G=1, R=0, S=0 + Sum mantissa = 1.1000000001, LSB = 1 + RNE: round up to make LSB even -> mant = 514 + RTZ: truncate -> mant = 513 + """ + a = float16_parts(0, 15, 1) # 1.0 * (1 + 1/1024) + b = float16_parts(0, 14, 1) # 0.5 * (1 + 1/1024) + expected_rne = float16_parts(0, 15, 514) + expected_rtz = float16_parts(0, 15, 513) + self.sim.step({"a": a, "b": b}) + self.assertFloat16Equal("result_rne", expected_rne) + self.assertFloat16Equal("result_rtz", expected_rtz) + + def test_rounding_g1_r1_s0_rounds_up(self): + """Test G=1, R=1, S=0: greater than half ULP, RNE rounds up. + + a = 1.0 (exp=15, mant=0) + b = 0.25 * (1 + 3/1024) = exp=13, mant=3 + Shift b by 2: G=1 (bit 1), R=1 (bit 0), S=0 + RNE: round up + RTZ: truncate + """ + a = float16_parts(0, 15, 0) # 1.0 + b = float16_parts(0, 13, 3) # 0.25 * (1 + 3/1024) + expected_rne = float16_parts(0, 15, 257) + expected_rtz = float16_parts(0, 15, 256) + self.sim.step({"a": a, "b": b}) + self.assertFloat16Equal("result_rne", expected_rne) + self.assertFloat16Equal("result_rtz", expected_rtz) + + def test_rounding_g1_r0_s1_rounds_up(self): + """Test G=1, R=0, S=1: greater than half ULP, RNE rounds up. 
+ + a = 1.0 (exp=15, mant=0) + b = 0.125 * (1 + 5/1024) = exp=12, mant=5 (binary: 101) + Shift b by 3: G=1 (bit 2), R=0 (bit 1), S=1 (bit 0) + RNE: round up + RTZ: truncate + """ + a = float16_parts(0, 15, 0) # 1.0 + b = float16_parts(0, 12, 5) # 0.125 * (1 + 5/1024) + expected_rne = float16_parts(0, 15, 129) + expected_rtz = float16_parts(0, 15, 128) + self.sim.step({"a": a, "b": b}) + self.assertFloat16Equal("result_rne", expected_rne) + self.assertFloat16Equal("result_rtz", expected_rtz) + + def test_rounding_g0_r1_s1_truncates(self): + """Test G=0, R=1, S=1: less than half ULP, RNE truncates. + + a = 1.0 (exp=15, mant=0) + b = 0.125 * (1 + 3/1024) = exp=12, mant=3 (binary: 011) + Shift b by 3: G=0 (bit 2), R=1 (bit 1), S=1 (bit 0) + Both RNE and RTZ truncate. + """ + a = float16_parts(0, 15, 0) # 1.0 + b = float16_parts(0, 12, 3) # 0.125 * (1 + 3/1024) + expected = float16_parts(0, 15, 128) + self.sim.step({"a": a, "b": b}) + self.assertFloat16Equal("result_rne", expected) + self.assertFloat16Equal("result_rtz", expected) + + ############################ + # Rounding with carry tests. + + def test_carry_g1_r0_s0_lsb1_tie_rounds_up(self): + """Test carry with G=1, R=0, S=0, LSB=1: tie rounds up. + + a = 1.1111111110 (exp=15, mant=1022) + b = 1.0000000001 (exp=15, mant=1) + Sum = 10.1111111111 -> normalize to 1.01111111111 + After normalization: G=1 (shifted out bit), R=0, S=0 + Result mantissa = 0111111111 = 511, LSB = 1 + RNE: tie, LSB=1 -> round up to 512 + RTZ: truncate to 511 + """ + a = float16_parts(0, 15, 1022) + b = float16_parts(0, 15, 1) + expected_rne = float16_parts(0, 16, 512) # Rounded up + expected_rtz = float16_parts(0, 16, 511) # Truncated + self.sim.step({"a": a, "b": b}) + self.assertFloat16Equal("result_rne", expected_rne) + self.assertFloat16Equal("result_rtz", expected_rtz) + + def test_carry_g1_r0_s0_lsb0_tie_truncates(self): + """Test carry with G=1, R=0, S=0, LSB=0: tie truncates. 
+ + a = 1.1111111100 (exp=15, mant=1020) + b = 1.0000000001 (exp=15, mant=1) + After normalization and carry handling: + Result mantissa = 510, LSB = 0 + Both RNE and RTZ truncate. + """ + a = float16_parts(0, 15, 1020) + b = float16_parts(0, 15, 1) + expected = float16_parts(0, 16, 510) + self.sim.step({"a": a, "b": b}) + self.assertFloat16Equal("result_rne", expected) + self.assertFloat16Equal("result_rtz", expected) + + ############################ + # Edge cases. + + def test_add_zero_to_number(self): + """Test x + 0 = x""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_POS_ZERO}) + self.assertFloat16Equal("result_rne", FLOAT16_ONE) + self.assertFloat16Equal("result_rtz", FLOAT16_ONE) + + def test_add_negative_zero_to_number(self): + """Test x + (-0) = x""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NEG_ZERO}) + self.assertFloat16Equal("result_rne", FLOAT16_ONE) + self.assertFloat16Equal("result_rtz", FLOAT16_ONE) + + def test_add_infinity_to_number(self): + """Test x + inf = inf""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_POS_INF}) + self.assertFloat16Equal("result_rne", FLOAT16_POS_INF) + self.assertFloat16Equal("result_rtz", FLOAT16_POS_INF) + + def test_add_negative_infinity_to_number(self): + """Test x + (-inf) = -inf""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NEG_INF}) + self.assertFloat16Equal("result_rne", FLOAT16_NEG_INF) + self.assertFloat16Equal("result_rtz", FLOAT16_NEG_INF) + + def test_add_infinity_minus_infinity_is_nan(self): + """Test inf + (-inf) = NaN""" + self.sim.step({"a": FLOAT16_POS_INF, "b": FLOAT16_NEG_INF}) + self.assertTrue(is_nan(self.sim.inspect("result_rne"))) + self.assertTrue(is_nan(self.sim.inspect("result_rtz"))) + + def test_add_nan_propagates(self): + """Test x + NaN = NaN""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NAN}) + self.assertTrue(is_nan(self.sim.inspect("result_rne"))) + self.assertTrue(is_nan(self.sim.inspect("result_rtz"))) + + def test_add_denormalized_flushed_to_zero(self): + """Test 
that denormalized inputs are flushed to zero.""" + self.sim.step({"a": FLOAT16_POS_ZERO, "b": FLOAT16_DENORMALIZED}) + self.assertFloat16Equal("result_rne", FLOAT16_POS_ZERO) + self.assertFloat16Equal("result_rtz", FLOAT16_POS_ZERO) + + ############################ + # Overflow tests. + + def test_overflow_rne_produces_infinity(self): + """Test that overflow produces infinity with RNE.""" + self.sim.step({"a": FLOAT16_LARGEST_NORMAL, "b": FLOAT16_LARGEST_NORMAL}) + self.assertFloat16Equal("result_rne", FLOAT16_POS_INF) + + def test_overflow_rtz_produces_largest_finite(self): + """Test that overflow produces largest finite with RTZ.""" + self.sim.step({"a": FLOAT16_LARGEST_NORMAL, "b": FLOAT16_LARGEST_NORMAL}) + self.assertFloat16Equal("result_rtz", FLOAT16_LARGEST_NORMAL) + + +class TestSubtraction(unittest.TestCase): + """Tests for Float16 subtraction operations.""" + + def setUp(self): + pyrtl.reset_working_block() + self.a = pyrtl.Input(bitwidth=16, name="a") + self.b = pyrtl.Input(bitwidth=16, name="b") + FloatOperations.default_rounding_mode = RoundingMode.RNE + result_rne = pyrtl.Output(name="result_rne") + result_rne <<= Float16Operations.sub(self.a, self.b) + FloatOperations.default_rounding_mode = RoundingMode.RTZ + result_rtz = pyrtl.Output(name="result_rtz") + result_rtz <<= Float16Operations.sub(self.a, self.b) + self.sim = pyrtl.Simulation() + + def assertFloat16Equal(self, output_name, expected): + assertFloat16Equal(self, self.sim, output_name, expected) + + ############################ + # Normal cases. 
+ + def test_sub_three_minus_one(self): + """Test 3.0 - 1.0 = 2.0""" + self.sim.step({"a": FLOAT16_THREE, "b": FLOAT16_ONE}) + self.assertFloat16Equal("result_rne", FLOAT16_TWO) + self.assertFloat16Equal("result_rtz", FLOAT16_TWO) + + def test_sub_one_point_five_minus_half(self): + """Test 1.5 - 0.5 = 1.0""" + self.sim.step({"a": FLOAT16_ONE_POINT_FIVE, "b": FLOAT16_HALF}) + self.assertFloat16Equal("result_rne", FLOAT16_ONE) + self.assertFloat16Equal("result_rtz", FLOAT16_ONE) + + def test_sub_equal_numbers(self): + """Test 1.0 - 1.0 = 0.0""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_ONE}) + self.assertFloat16Equal("result_rne", FLOAT16_POS_ZERO) + self.assertFloat16Equal("result_rtz", FLOAT16_POS_ZERO) + + def test_sub_from_zero(self): + """Test 0 - 1.0 = -1.0""" + self.sim.step({"a": FLOAT16_POS_ZERO, "b": FLOAT16_ONE}) + self.assertFloat16Equal("result_rne", FLOAT16_NEG_ONE) + self.assertFloat16Equal("result_rtz", FLOAT16_NEG_ONE) + + def test_sub_two_minus_half(self): + """Test 2.0 - 0.5 = 1.5""" + self.sim.step({"a": FLOAT16_TWO, "b": FLOAT16_HALF}) + self.assertFloat16Equal("result_rne", FLOAT16_ONE_POINT_FIVE) + self.assertFloat16Equal("result_rtz", FLOAT16_ONE_POINT_FIVE) + + def test_sub_double_negative(self): + """Test x - (-y) = x + y: 1.0 - (-1.0) = 2.0""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NEG_ONE}) + self.assertFloat16Equal("result_rne", FLOAT16_TWO) + self.assertFloat16Equal("result_rtz", FLOAT16_TWO) + + ############################ + # Rounding tests. + + def test_sub_exact_no_rounding(self): + """Test 1.5 - 0.25 = 1.25 (no rounding needed, exact result). + + a = 1.5 (exp=15, mant=512) + b = 0.25 (exp=13, mant=0) + Result fits exactly in mantissa, GRS=000. + Both RNE and RTZ produce same result. 
+ """ + self.sim.step({"a": FLOAT16_ONE_POINT_FIVE, "b": FLOAT16_QUARTER}) + self.assertFloat16Equal("result_rne", FLOAT16_ONE_POINT_TWOFIVE) + self.assertFloat16Equal("result_rtz", FLOAT16_ONE_POINT_TWOFIVE) + + def test_sub_rne_rounds_up(self): + """Test subtraction where RNE rounds up and RTZ truncates. + + a = 1.0 * (1 + 512/1024) = exp=15, mant=512 (1.5) + b = 0.25 * (1 + 2/1024) = exp=13, mant=2 + Shift b by 2 -> LSB=1, G=1, R=0, S=0 + RNE: rounds up to 256 + RTZ: truncates to 255 + """ + a = float16_parts(0, 15, 512) + b = float16_parts(0, 13, 2) + expected_rne = float16_parts(0, 15, 256) + expected_rtz = float16_parts(0, 15, 255) + self.sim.step({"a": a, "b": b}) + self.assertFloat16Equal("result_rne", expected_rne) + self.assertFloat16Equal("result_rtz", expected_rtz) + + ############################ + # Edge cases. + + def test_sub_zero_from_number(self): + """Test x - 0 = x""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_POS_ZERO}) + self.assertFloat16Equal("result_rne", FLOAT16_ONE) + self.assertFloat16Equal("result_rtz", FLOAT16_ONE) + + def test_sub_pos_zero_minus_pos_zero(self): + """Test +0 - +0 = +0""" + self.sim.step({"a": FLOAT16_POS_ZERO, "b": FLOAT16_POS_ZERO}) + self.assertFloat16Equal("result_rne", FLOAT16_POS_ZERO) + self.assertFloat16Equal("result_rtz", FLOAT16_POS_ZERO) + + def test_sub_neg_zero_minus_neg_zero(self): + """Test -0 - -0 = +0""" + self.sim.step({"a": FLOAT16_NEG_ZERO, "b": FLOAT16_NEG_ZERO}) + self.assertFloat16Equal("result_rne", FLOAT16_POS_ZERO) + self.assertFloat16Equal("result_rtz", FLOAT16_POS_ZERO) + + def test_sub_infinity_from_number(self): + """Test x - inf = -inf""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_POS_INF}) + self.assertFloat16Equal("result_rne", FLOAT16_NEG_INF) + self.assertFloat16Equal("result_rtz", FLOAT16_NEG_INF) + + def test_sub_infinity_from_infinity_is_nan(self): + """Test inf - inf = NaN""" + self.sim.step({"a": FLOAT16_POS_INF, "b": FLOAT16_POS_INF}) + 
self.assertTrue(is_nan(self.sim.inspect("result_rne"))) + self.assertTrue(is_nan(self.sim.inspect("result_rtz"))) + + def test_sub_neg_infinity_from_pos_infinity(self): + """Test inf - (-inf) = inf""" + self.sim.step({"a": FLOAT16_POS_INF, "b": FLOAT16_NEG_INF}) + self.assertFloat16Equal("result_rne", FLOAT16_POS_INF) + self.assertFloat16Equal("result_rtz", FLOAT16_POS_INF) + + def test_sub_nan_propagates(self): + """Test x - NaN = NaN""" + self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NAN}) + self.assertTrue(is_nan(self.sim.inspect("result_rne"))) + self.assertTrue(is_nan(self.sim.inspect("result_rtz"))) + + def test_sub_denormalized_flushed_to_zero(self): + """Test that denormalized operands are flushed to zero.""" + self.sim.step({"a": FLOAT16_DENORMALIZED, "b": FLOAT16_POS_ZERO}) + self.assertFloat16Equal("result_rne", FLOAT16_POS_ZERO) + self.assertFloat16Equal("result_rtz", FLOAT16_POS_ZERO) + + ############################ + # Overflow tests. + + def test_overflow_by_subtracting_negative(self): + """Test overflow when subtracting large negative: large - (-large). 
+ + RNE: overflow produces infinity + RTZ: overflow produces largest finite + """ + a = FLOAT16_LARGEST_NORMAL # Large positive + b = FLOAT16_LARGEST_NORMAL | 0x8000 # Same magnitude, negative + self.sim.step({"a": a, "b": b}) + self.assertFloat16Equal("result_rne", FLOAT16_POS_INF) + self.assertFloat16Equal("result_rtz", FLOAT16_LARGEST_NORMAL) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rtllib/pyrtlfloat/test_multiplication.py b/tests/rtllib/pyrtlfloat/test_multiplication.py new file mode 100644 index 00000000..bd433514 --- /dev/null +++ b/tests/rtllib/pyrtlfloat/test_multiplication.py @@ -0,0 +1,248 @@ +import unittest + +import pyrtl +from pyrtl.rtllib.pyrtlfloat import Float16Operations, FloatOperations, RoundingMode + +from .float16_test_utils import ( + FLOAT16_DENORMALIZED, + FLOAT16_HALF, + FLOAT16_LARGEST_NORMAL, + FLOAT16_NAN, + FLOAT16_NEG_INF, + FLOAT16_NEG_ONE, + FLOAT16_NEG_TWO, + FLOAT16_NEG_ZERO, + FLOAT16_ONE, + FLOAT16_ONE_POINT_FIVE, + FLOAT16_POS_INF, + FLOAT16_POS_ZERO, + FLOAT16_THREE, + FLOAT16_TWO, + assertFloat16Equal, + float16_parts, + is_nan, +) + + +class TestMultiplication(unittest.TestCase): + """Tests for Float16 multiplication operations.""" + + def setUp(self): + pyrtl.reset_working_block() + self.a = pyrtl.Input(bitwidth=16, name="a") + self.b = pyrtl.Input(bitwidth=16, name="b") + FloatOperations.default_rounding_mode = RoundingMode.RNE + result_rne = pyrtl.Output(name="result_rne") + result_rne <<= Float16Operations.mul(self.a, self.b) + FloatOperations.default_rounding_mode = RoundingMode.RTZ + result_rtz = pyrtl.Output(name="result_rtz") + result_rtz <<= Float16Operations.mul(self.a, self.b) + self.sim = pyrtl.Simulation() + + def assertFloat16Equal(self, output_name, expected): + assertFloat16Equal(self, self.sim, output_name, expected) + + ############################ + # Normal cases. 
+
+    def test_mul_half_times_two(self):
+        """Test 0.5 * 2.0 = 1.0"""
+        self.sim.step({"a": FLOAT16_HALF, "b": FLOAT16_TWO})
+        self.assertFloat16Equal("result_rne", FLOAT16_ONE)
+        self.assertFloat16Equal("result_rtz", FLOAT16_ONE)
+
+    def test_mul_one_point_five_times_two(self):
+        """Test 1.5 * 2.0 = 3.0"""
+        self.sim.step({"a": FLOAT16_ONE_POINT_FIVE, "b": FLOAT16_TWO})
+        self.assertFloat16Equal("result_rne", FLOAT16_THREE)
+        self.assertFloat16Equal("result_rtz", FLOAT16_THREE)
+
+    def test_mul_opposite_signs(self):
+        """Test -1.0 * 2.0 = -2.0"""
+        self.sim.step({"a": FLOAT16_NEG_ONE, "b": FLOAT16_TWO})
+        self.assertFloat16Equal("result_rne", FLOAT16_NEG_TWO)
+        self.assertFloat16Equal("result_rtz", FLOAT16_NEG_TWO)
+
+    def test_mul_both_negative(self):
+        """Test -1.0 * -2.0 = 2.0"""
+        self.sim.step({"a": FLOAT16_NEG_ONE, "b": FLOAT16_NEG_TWO})
+        self.assertFloat16Equal("result_rne", FLOAT16_TWO)
+        self.assertFloat16Equal("result_rtz", FLOAT16_TWO)
+
+    def test_mul_one_point_five_times_one_point_five(self):
+        """Test 1.5 * 1.5 = 2.25"""
+        self.sim.step({"a": FLOAT16_ONE_POINT_FIVE, "b": FLOAT16_ONE_POINT_FIVE})
+        expected = 0x4080  # 2.25 in float16
+        self.assertFloat16Equal("result_rne", expected)
+        self.assertFloat16Equal("result_rtz", expected)
+
+    ############################
+    # Rounding tests.
+    # Legend: Guard (G) = first bit below the kept mantissa, Last (L) = LSB
+    # of the kept mantissa, Sticky (S) = OR of all bits below the guard.
+
+    def test_rounding_g0_s0_truncates(self):
+        """Test Guard=0, Sticky=0: RNE truncates (exact result, no rounding needed).
+
+        a = 1.0 (exp=15, mant=0)
+        b = 1.0 (exp=15, mant=0)
+        Product mantissa has Guard=0, Sticky=0.
+        Both RNE and RTZ produce the same result.
+        """
+        a = float16_parts(0, 15, 0)  # 1.0
+        b = float16_parts(0, 15, 0)  # 1.0
+        expected = float16_parts(0, 15, 0)  # 1.0
+        self.sim.step({"a": a, "b": b})
+        self.assertFloat16Equal("result_rne", expected)
+        self.assertFloat16Equal("result_rtz", expected)
+
+    def test_rounding_g0_s1_truncates(self):
+        """Test Guard=0, Sticky=1: RNE truncates (less than half ULP).
+
+        a = 1.0 * (1 + 1/1024) = exp=15, mant=1
+        b = 1.0 * (1 + 1/1024) = exp=15, mant=1
+        Product has Guard=0, Sticky=1, Last=0.
+        Both RNE and RTZ truncate.
+        """
+        a = float16_parts(0, 15, 1)  # 1.0 * (1 + 1/1024)
+        b = float16_parts(0, 15, 1)  # 1.0 * (1 + 1/1024)
+        expected = float16_parts(0, 15, 2)
+        self.sim.step({"a": a, "b": b})
+        self.assertFloat16Equal("result_rne", expected)
+        self.assertFloat16Equal("result_rtz", expected)
+
+    def test_rounding_g1_l0_s0_tie_truncates(self):
+        """Test Guard=1, Last=0, Sticky=0: tie, RNE truncates (LSB already even).
+
+        a = 1.0 * (1 + 2/1024) = exp=15, mant=2
+        b = 1.0 * (1 + 256/1024) = exp=15, mant=256
+        Product has Guard=1, Sticky=0, Last=0.
+        RNE: LSB is 0 (even), so truncate.
+        Both RNE and RTZ truncate.
+        """
+        a = float16_parts(0, 15, 2)  # 1.0 * (1 + 2/1024)
+        b = float16_parts(0, 15, 256)  # 1.0 * (1 + 256/1024)
+        expected = float16_parts(0, 15, 258)
+        self.sim.step({"a": a, "b": b})
+        self.assertFloat16Equal("result_rne", expected)
+        self.assertFloat16Equal("result_rtz", expected)
+
+    def test_rounding_g1_l1_s0_tie_rounds_up(self):
+        """Test Guard=1, Last=1, Sticky=0: tie, RNE rounds up (make LSB even).
+
+        a = 1.0 * (1 + 1/1024) = exp=15, mant=1
+        b = 1.0 * (1 + 512/1024) = exp=15, mant=512 (1.5)
+        Product has Guard=1, Sticky=0, Last=1.
+        RNE: LSB is 1 (odd), so round up to make it even.
+        RTZ: truncates
+        """
+        a = float16_parts(0, 15, 1)  # 1.0 * (1 + 1/1024)
+        b = float16_parts(0, 15, 512)  # 1.0 * (1 + 512/1024)
+        expected_rne = float16_parts(0, 15, 514)
+        expected_rtz = float16_parts(0, 15, 513)
+        self.sim.step({"a": a, "b": b})
+        self.assertFloat16Equal("result_rne", expected_rne)
+        self.assertFloat16Equal("result_rtz", expected_rtz)
+
+    def test_rounding_g1_l0_s1_rounds_up(self):
+        """Test Guard=1, Last=0, Sticky=1: greater than half ULP, RNE rounds up.
+
+        a = 1.0 * (1 + 1/1024) = exp=15, mant=1
+        b = 1.0 * (1 + 513/1024) = exp=15, mant=513
+        Product has Guard=1, Sticky=1, Last=0.
+        Greater than half ULP, so round up.
+        RNE: rounds up
+        RTZ: truncates
+        """
+        a = float16_parts(0, 15, 1)  # 1.0 * (1 + 1/1024)
+        b = float16_parts(0, 15, 513)  # 1.0 * (1 + 513/1024)
+        expected_rne = float16_parts(0, 15, 515)
+        expected_rtz = float16_parts(0, 15, 514)
+        self.sim.step({"a": a, "b": b})
+        self.assertFloat16Equal("result_rne", expected_rne)
+        self.assertFloat16Equal("result_rtz", expected_rtz)
+
+    ############################
+    # Edge cases.
+
+    def test_mul_by_zero(self):
+        """Test x * 0 = 0"""
+        self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_POS_ZERO})
+        self.assertFloat16Equal("result_rne", FLOAT16_POS_ZERO)
+        self.assertFloat16Equal("result_rtz", FLOAT16_POS_ZERO)
+
+    def test_mul_by_negative_zero(self):
+        """Test x * (-0) = -0"""
+        self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NEG_ZERO})
+        self.assertFloat16Equal("result_rne", FLOAT16_NEG_ZERO)
+        self.assertFloat16Equal("result_rtz", FLOAT16_NEG_ZERO)
+
+    def test_mul_infinity_by_number(self):
+        """Test inf * x = inf"""
+        self.sim.step({"a": FLOAT16_POS_INF, "b": FLOAT16_TWO})
+        self.assertFloat16Equal("result_rne", FLOAT16_POS_INF)
+        self.assertFloat16Equal("result_rtz", FLOAT16_POS_INF)
+
+    def test_mul_neg_infinity_by_number(self):
+        """Test -inf * x = -inf"""
+        self.sim.step({"a": FLOAT16_NEG_INF, "b": FLOAT16_TWO})
+        self.assertFloat16Equal("result_rne", FLOAT16_NEG_INF)
+        self.assertFloat16Equal("result_rtz", FLOAT16_NEG_INF)
+
+    def test_mul_infinity_by_zero_is_nan(self):
+        """Test inf * 0 = NaN"""
+        self.sim.step({"a": FLOAT16_POS_INF, "b": FLOAT16_POS_ZERO})
+        self.assertTrue(is_nan(self.sim.inspect("result_rne")))
+        self.assertTrue(is_nan(self.sim.inspect("result_rtz")))
+
+    def test_mul_infinity_by_infinity(self):
+        """Test inf * inf = inf"""
+        self.sim.step({"a": FLOAT16_POS_INF, "b": FLOAT16_POS_INF})
+        self.assertFloat16Equal("result_rne", FLOAT16_POS_INF)
+        self.assertFloat16Equal("result_rtz", FLOAT16_POS_INF)
+
+    def test_mul_pos_infinity_by_neg_infinity(self):
+        """Test inf * (-inf) = -inf"""
+        self.sim.step({"a": FLOAT16_POS_INF, "b": FLOAT16_NEG_INF})
+        self.assertFloat16Equal("result_rne", FLOAT16_NEG_INF)
+        self.assertFloat16Equal("result_rtz", FLOAT16_NEG_INF)
+
+    def test_mul_nan_propagates(self):
+        """Test x * NaN = NaN"""
+        self.sim.step({"a": FLOAT16_ONE, "b": FLOAT16_NAN})
+        self.assertTrue(is_nan(self.sim.inspect("result_rne")))
+        self.assertTrue(is_nan(self.sim.inspect("result_rtz")))
+
+    def test_mul_denormalized_flushed_to_zero(self):
+        """Test that denormalized operands are flushed to zero."""
+        self.sim.step({"a": FLOAT16_DENORMALIZED, "b": FLOAT16_ONE})
+        self.assertFloat16Equal("result_rne", FLOAT16_POS_ZERO)
+        self.assertFloat16Equal("result_rtz", FLOAT16_POS_ZERO)
+
+    ############################
+    # Overflow tests.
+
+    def test_overflow_rne_produces_infinity(self):
+        """Test that overflow produces infinity with RNE."""
+        self.sim.step({"a": FLOAT16_LARGEST_NORMAL, "b": FLOAT16_TWO})
+        self.assertFloat16Equal("result_rne", FLOAT16_POS_INF)
+
+    def test_overflow_rtz_produces_largest_finite(self):
+        """Test that overflow produces largest finite with RTZ."""
+        self.sim.step({"a": FLOAT16_LARGEST_NORMAL, "b": FLOAT16_TWO})
+        self.assertFloat16Equal("result_rtz", FLOAT16_LARGEST_NORMAL)
+
+    def test_negative_overflow_rne_produces_neg_infinity(self):
+        """Test that negative overflow produces -infinity with RNE."""
+        # 0x8000 sets bit 15, the sign bit of the 16-bit encoding.
+        neg_largest = FLOAT16_LARGEST_NORMAL | 0x8000
+        self.sim.step({"a": neg_largest, "b": FLOAT16_TWO})
+        self.assertFloat16Equal("result_rne", FLOAT16_NEG_INF)
+
+    def test_negative_overflow_rtz_produces_neg_largest_finite(self):
+        """Test that negative overflow produces -largest finite with RTZ."""
+        # 0x8000 sets bit 15, the sign bit of the 16-bit encoding.
+        neg_largest = FLOAT16_LARGEST_NORMAL | 0x8000
+        self.sim.step({"a": neg_largest, "b": FLOAT16_TWO})
+        expected = FLOAT16_LARGEST_NORMAL | 0x8000
+        self.assertFloat16Equal("result_rtz", expected)
+
+
+if __name__ == "__main__":
+    unittest.main()