f32, u32, and const

Some time ago, I wrote “floats, bits, and constant expressions” about converting a floating point number into its representative ones and zeros as a C++ constant expression – constructing the IEEE 754 representation without being able to examine the bits directly.

I’ve been playing around with Rust recently, and rewrote that conversion code as a bit of a learning exercise for myself, with a thoroughly contrived set of constraints: using integer and single-precision floating point math, at compile time, without unsafe blocks, while using as few unstable features as possible.

I’ve included the listing below, for your bemusement and/or head-shaking, and you can play with the code in the Rust Playground and on rust.godbolt.org.

// Jonathan Adamczewski 2020-05-12
//
// Constructing the bit-representation of an IEEE 754 single precision floating 
// point number, using integer and single-precision floating point math, at 
// compile time, in rust, without unsafe blocks, while using as few unstable 
// features as I can.
//
// or "What if this silly C++ thing https://brnz.org/hbr/?p=1518 but in Rust?"


// Q. Why? What is this good for?
// A. To the best of my knowledge, this code serves no useful purpose. 
//    But I did learn a thing or two while writing it :)


// This is needed to be able to perform floating point operations in a const 
// function:
#![feature(const_fn)]


// bits_transmute(): Returns the bits representing a floating point value, by
//                   way of std::mem::transmute()
//
// For completeness (and validation), and to make clear the fundamentally 
// unnecessary nature of the exercise :D - here's a short, straightforward, 
// library-based version. But it needs the const_transmute feature and an 
// unsafe block.
#![feature(const_transmute)]
const fn bits_transmute(f: f32) -> u32 {
  // SAFETY: f32 and u32 are both 32 bits wide and every bit pattern is a
  // valid u32, so reinterpreting the float's storage as an integer is sound.
  let raw_bits: u32 = unsafe { std::mem::transmute::<f32, u32>(f) };
  raw_bits
}



// get_if_u32(predicate, if_true, if_false):
//   Returns if_true when predicate is true, otherwise if_false.
//
// `if` and `match` are not usable in const functions here (not without
// #![feature(const_if_match)]), so this is a branch-free select for u32s:
// the predicate is widened into an all-ones or all-zeros mask, and the mask
// keeps exactly one of the two inputs.
const fn get_if_u32(predicate: bool, if_true: u32, if_false: u32) -> u32 {
  // true:  0 - 1 wraps around to 0xffff_ffff
  // false: 0 - 0 stays 0x0000_0000
  let select_mask = 0_u32.wrapping_sub(predicate as u32);
  (select_mask & if_true) | (!select_mask & if_false)
}

// get_if_f32(predicate, if_true, if_false):
//   Returns if_true when predicate is true, otherwise if_false.
//
// A branch-free select for f32s: each input is multiplied by a 1.0 or 0.0
// weight and the two products are summed.
//
// Because the unselected input is still multiplied by 0.0, a NaN or infinite
// if_true or if_false makes the result NaN (inf * 0 is NaN), which is not
// ideal. Signed zero is not reliably preserved either: a selected -0.0 comes
// back as +0.0 when, for example, the other input is positive. I don't know
// of a better way to implement this function within the arbitrary
// limitations of this silly little side quest.
const fn get_if_f32(predicate: bool, if_true: f32, if_false: f32) -> f32 {
  // can't cast bool straight to f32 - but bool -> integer -> f32 works
  let keep_true = (predicate as u32) as f32;
  // exactly 0.0 or 1.0, the complement of keep_true
  let keep_false = 1_f32 - keep_true;
  let weighted_true = if_true * keep_true;
  let weighted_false = if_false * keep_false;
  weighted_true + weighted_false
}


// bits(): Returns the bits representing a floating point value.
//
// Reconstructs the IEEE 754 single precision encoding of f using only
// comparisons, integer math and f32 math - no transmute, no unsafe, and no
// branches (every statement below runs for every input).
//
// Special cases:
//   +0.0 -> 0x0000_0000, -0.0 -> 0x8000_0000
//   +inf -> 0x7f80_0000, -inf -> 0xff80_0000
//   NaN  -> 0x7fc0_0000 (this function makes no attempt to recover a NaN's
//           sign or payload; every NaN maps to this one value)
const fn bits(f: f32) -> u32 {
  // the result value, initialized to a NaN value that will otherwise not be
  // produced by this function.
  let mut r = 0xffff_ffff;

  // These floating point operations (and others) cause the following error:
  //     only int, `bool` and `char` operations are stable in const fn
  // hence #![feature(const_fn)] at the top of the file

  // Identify special cases
  let is_zero    = f == 0_f32;
  let is_inf     = f == f32::INFINITY;
  let is_neg_inf = f == f32::NEG_INFINITY;
  let is_nan     = f != f;

  // -0.0 compares equal to 0.0, so the two zeros have to be told apart
  // arithmetically: 1/+0 is +inf and 1/-0 is -inf. (For non-zero f the
  // quotient is irrelevant - it is masked off by is_zero.)
  let is_neg_zero = is_zero & ((1_f32 / f) < 0_f32);

  // Writing this as !(is_zero || is_inf || ...) causes the following error:
  //     Loops and conditional expressions are not stable in const fn
  // so instead write this as type conversions and bitwise operations
  //
  // "normalish" here means that f is a normal or subnormal value
  let is_normalish = 0 == ((is_zero as u32) | (is_inf as u32) |
                        (is_neg_inf as u32) | (is_nan as u32));

  // set the result value for each of the special cases
  // if (is_zero) { r = is_neg_zero ? 0x8000_0000 : 0; }
  r = get_if_u32(is_zero,    (is_neg_zero as u32) << 31, r);
  r = get_if_u32(is_inf,     0x7f80_0000, r); // if (is_inf)     { r = 0x7f80_0000; }
  r = get_if_u32(is_neg_inf, 0xff80_0000, r); // if (is_neg_inf) { r = 0xff80_0000; }
  r = get_if_u32(is_nan,     0x7fc0_0000, r); // if (is_nan)     { r = 0x7fc0_0000; }

  // It was tempting at this point to try setting f to a "normalish" placeholder
  // value so that special cases do not have to be handled in the code that
  // follows, like so:
  // f = get_if_f32(is_normalish, f, 1_f32);
  //
  // Unfortunately, get_if_f32() returns NaN if either input is NaN or infinite.
  // Instead of switching the value, we work around the non-normalish cases
  // later.
  //
  // (This whole function is branch-free, so all of it is executed regardless of
  // the input value)

  // extract the sign bit
  let sign_bit  = get_if_u32(f < 0_f32,  1, 0);

  // compute the absolute value of f
  let mut abs_f = get_if_f32(f < 0_f32, -f, f);


  // This part is a little complicated. The algorithm is functionally the same
  // as the C++ version linked from the top of the file.
  //
  // Because of the various contrived constraints on this problem, we compute
  // the exponent and significand, rather than extract the bits directly.
  //
  // The idea is this:
  // Every finite single precision floating point number can be represented as
  // a series of (at most) 24 significant digits as a 128.149 fixed point
  // number (128: 126 exponent values >= 0, plus one for the implicit leading 1,
  // plus one more so that the binary point falls on a power-of-two boundary :)
  // 149: 126 negative exponent values, plus 23 for the bits of precision in the
  // significand.)
  //
  // If we are able to scale the number such that all of the precision bits fall
  // in the upper-most 64 bits of that fixed-point representation (while
  // tracking our effective manipulation of the exponent), we can then
  // predictably and simply scale that computed value back to a range that can
  // be converted safely to a u64, count the leading zeros to determine the
  // exact exponent, and then shift the result into position for the final u32
  // representation.

  // Start with the largest possible exponent - subsequent steps will reduce
  // this number as appropriate
  let mut exponent: u32 = 254;
  {
    // Hex float literals are really nice. I miss them.

    // The threshold is 2^87 (think: 64+23 bits) to ensure that the number will
    // be large enough that, when scaled down by 2^64, all the precision will
    // fit nicely in a u64
    const THRESHOLD: f32 = 154742504910672534362390528_f32; // 0x1p87f == 2^87

    // The scaling factor is 2^41 (think: 64-23 bits) to ensure that a number
    // between 2^87 and 2^64 will not overflow in a single scaling step.
    const SCALE_UP: f32 = 2199023255552_f32; // 0x1p41f == 2^41

    // Because loops are not available (no #![feature(const_loops)], and 'if' is
    // not available (no #![feature(const_if_match)]), perform repeated branch-
    // free conditional multiplication of abs_f.

    // use a macro, because why not :D It's the most compact, simplest option I
    // could find.
    macro_rules! maybe_scale {
      () => {{
        // care is needed: if abs_f is above the threshold, multiplying by 2^41
        // will cause it to overflow (INFINITY) which will cause get_if_f32() to
        // return NaN, which will destroy the value in abs_f. So compute a safe
        // scaling factor for each iteration.
        //
        // Roughly equivalent to :
        // if (abs_f < THRESHOLD) {
        //   exponent -= 41;
        //   abs_f *= SCALE_UP;
        // }
        let scale = get_if_f32(abs_f < THRESHOLD, SCALE_UP,      1_f32);
        exponent  = get_if_u32(abs_f < THRESHOLD, exponent - 41, exponent);
        abs_f     = get_if_f32(abs_f < THRESHOLD, abs_f * scale, abs_f);
      }}
    }
    // 41 bits per iteration means up to 246 bits shifted.
    // Even the smallest subnormal value will end up in the desired range.
    // (After all six steps exponent is at least 254 - 246 = 8, so none of the
    // `exponent - 41` subtractions can underflow.)
    maybe_scale!();  maybe_scale!();  maybe_scale!();
    maybe_scale!();  maybe_scale!();  maybe_scale!();
  }

  // Now that we know that abs_f is in the desired range (2^87 <= abs_f < 2^128)
  // scale it down to be in the range (2^23 <= _ < 2^64), and convert without
  // loss of precision to u64.
  const INV_2_64: f32 = 5.42101086242752217003726400434970855712890625e-20_f32; // 0x1p-64f == 2^-64
  let a = (abs_f * INV_2_64) as u64;

  // Count the leading zeros.
  // (C++ doesn't provide a compile-time constant function for this. It's nice
  // that rust does :)
  let mut lz = a.leading_zeros();

  // if the number isn't normalish, lz is meaningless: we stomp it with
  // something that will not cause problems in the computation that follows -
  // the result of which is meaningless, and will be ignored in the end for
  // non-normalish values.
  lz = get_if_u32(!is_normalish, 0, lz); // if (!is_normalish) { lz = 0; }

  {
    // This step accounts for subnormal numbers, where there are more leading
    // zeros than can be accounted for in a valid exponent value, and leading
    // zeros that must remain in the final significand.
    //
    // If lz < exponent, reduce exponent to its final correct value - lz will be
    // used to remove all of the leading zeros.
    //
    // Otherwise, clamp exponent to zero, and adjust lz to ensure that the
    // correct number of bits will remain (after multiplying by 2^41 six times -
    // 2^246 - there are 7 leading zeros ahead of the original subnormal's
    // computed significand of 0.sss...)
    //
    // The following is roughly equivalent to:
    // if (lz < exponent) {
    //   exponent = exponent - lz;
    // } else {
    //   exponent = 0;
    //   lz = 7;
    // }

    // we're about to mess with lz and exponent - compute and store the relative
    // value of the two
    let lz_is_less_than_exponent = lz < exponent;

    // Order matters: lz is clamped to 7 first, so `exponent - lz` below is
    // evaluated with exponent >= 8 > lz and cannot underflow even when its
    // result is discarded.
    lz       = get_if_u32(!lz_is_less_than_exponent, 7,             lz);
    exponent = get_if_u32( lz_is_less_than_exponent, exponent - lz, 0);
  }

  // compute the final significand.
  // + 1 shifts away a leading 1-bit for normal, and 0-bit for subnormal values
  // Shifts are done in u64 (that leading bit is shifted into the void), then
  // the resulting bits are shifted back to their final resting place.
  let significand = ((a << (lz + 1)) >> (64 - 23)) as u32;

  // combine the bits
  let computed_bits = (sign_bit << 31) | (exponent << 23) | significand;

  // return the normalish result, or the non-normalish result, as appropriate
  get_if_u32(is_normalish, computed_bits, r)
}


// Compile-time validation - able to be examined in rust.godbolt.org output.
// Each BITS_* static is computed by bits(); the matching TBITS_* static is the
// transmute-based reference from bits_transmute().
pub static BITS_BIGNUM: u32 = bits(f32::MAX);
pub static TBITS_BIGNUM: u32 = bits_transmute(f32::MAX);
// 7.0064923217e-46 is just above 2^-150 (half the smallest subnormal), so the
// literal rounds to the smallest positive subnormal, 2^-149.
pub static BITS_LOWER_THAN_MIN: u32 = bits(7.0064923217e-46_f32);
pub static TBITS_LOWER_THAN_MIN: u32 = bits_transmute(7.0064923217e-46_f32);
pub static BITS_ZERO: u32 = bits(0.0f32);
pub static TBITS_ZERO: u32 = bits_transmute(0.0f32);
pub static BITS_ONE: u32 = bits(1.0f32);
pub static TBITS_ONE: u32 = bits_transmute(1.0f32);
pub static BITS_NEG_ONE: u32 = bits(-1.0f32);
pub static TBITS_NEG_ONE: u32 = bits_transmute(-1.0f32);
pub static BITS_INF: u32 = bits(f32::INFINITY);
pub static TBITS_INF: u32 = bits_transmute(f32::INFINITY);
pub static BITS_NEG_INF: u32 = bits(f32::NEG_INFINITY);
pub static TBITS_NEG_INF: u32 = bits_transmute(f32::NEG_INFINITY);
pub static BITS_NAN: u32 = bits(f32::NAN);
pub static TBITS_NAN: u32 = bits_transmute(f32::NAN);
pub static BITS_COMPUTED_NAN: u32 = bits(f32::INFINITY / f32::INFINITY);
pub static TBITS_COMPUTED_NAN: u32 = bits_transmute(f32::INFINITY / f32::INFINITY);


// Run-time validation of many more values.
//
// Walks the whole u32 space in even steps, reinterprets each pattern as an f32,
// and prints (reference, computed) in hex for every value where bits()
// disagrees with the actual representation. Two differences are tolerated:
// any NaN reported as the canonical 0x7fc0_0000, and -0.0 reported as 0.
fn main() {
  let count = 9_876_543_u32; // number of values to test
  let step = (u32::MAX / count) as usize;
  for v in (0..=u32::MAX).step_by(step) {
      // reference: f32::from_bits is the safe equivalent of
      // transmute::<u32, f32>
      let f = f32::from_bits(v);

      // compute
      let c = bits(f);

      // validation
      if c != v &&
         !(f.is_nan() && c == 0x7fc0_0000) && // nans
         !(v == 0x8000_0000 && c == 0) { // negative 0
          println!("{:x?} {:x?}", v, c);
      }
  }
}