#include <float.h>
#include <math.h>
#include <stdlib.h>
#include <stdio.h>

/* returns the correctly rounded to float result of 
   the mathematical expression a*b+c. 
   Assumes round-to-nearest, a≥0, b≥0, c≥0, all finite. */
float myfma(float a, float b, float c) 
{
  double p1 = a * (double) b; 
  double p2 = c;
  double s1 = p1 + p2;
  if (p1 < p2)
    {
      double tmp = p1;
      p1 = p2;
      p2 = tmp;
    }
  double r1 = s1 - p1 - p2; /* fma = s1 + r1 */

  float f1 = s1;            

  if (r1 == 0) return f1;

  double t = s1 - f1;

  if (t == 0) return f1; /* definitely not a halfway point */

  double dir = copysign(1.0 / 0.0, t);
  float f2 = nextafterf(f1, dir);
  double ulp = f2 - (double) f1;
  
#if 0
  printf("f1:%a\nf2:%a\np1:%a\np2:%a\ns1:%a\nt :%a\nr1:%a\nul:%a\n", 
          f1,    f2,    p1,    p2,    s1,    t,     r1,    ulp);
#endif
  
  if (fabs(t) > fabs(0.75 * ulp))
    return f2;

  if (fabs(t) < fabs(0.25 * ulp))
    return f1;

  double r = t - 0.5 * ulp + r1;

  if (r > 0 && t > 0 || r < 0 && t < 0)
    return f2;

  if (r > 0 && t < 0 || r < 0 && t > 0)
    return f1;

  /* Round to even. */
  return f1 + 0.5 * ulp;
}
  
/* Returns a random value with up to 20 significant bits (0xFFFFF mask)
   shifted left by a random amount in [0, shift_mask].  The two rand()
   calls are sequenced explicitly: in an expression such as
   (rand() & m) << (rand() & s) their order is unspecified in C, which
   would make the generated test sequence compiler-dependent.  */
static long long random_shifted(int shift_mask)
{
  long long mantissa = rand() & 0xFFFFF;
  int shift = rand() & shift_mask;
  return mantissa << shift;
}

/* Randomized self-test: compares myfma against the exact integer value
   of a*b+c rounded once to float.  All operands are non-negative
   integers small enough that a, b and c are exact floats and a*b+c
   (< 2^55) is exact in long long.

   Usage: prog [iterations]
   Without an argument the test runs forever, printing every mismatch.
   With a positive count it stops after that many trials, prints a
   summary, and exits with EXIT_FAILURE if any mismatch was found.  */
int main(int argc, char *argv[])
{
  unsigned long long limit = 0;   /* 0 means run forever */
  if (argc > 1)
    {
      char *end;
      limit = strtoull(argv[1], &end, 10);
      if (end == argv[1] || *end != '\0' || limit == 0)
        {
          fprintf(stderr, "usage: %s [iterations>0]\n", argv[0]);
          return EXIT_FAILURE;
        }
    }

  unsigned long long failures = 0;
  for (unsigned long long i = 0; limit == 0 || i < limit; i++)
    {
      int a = (int) random_shifted(7);
      int b = (int) random_shifted(7);
      long long c = random_shifted(31);

      /* Exact in long long, then rounded once to float. */
      float truefma = (float) ((long long) a * b + c);
      float r = myfma(a, b, c);
      if (r != truefma)
        {
          failures++;
          printf("bad: %d %d %lld\ntr:%a\nmy:%a\n\n", a, b, c, truefma, r);
        }
    }

  printf("%llu mismatches in %llu trials\n", failures, limit);
  return failures ? EXIT_FAILURE : EXIT_SUCCESS;
}
