// Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages,
// Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and
// NuGet packages.
//
// Repository URL to install this package:
//
// Details
// Size: Mime:
#ifndef _TRT_FP16_H_
#define _TRT_FP16_H_

#include <cstdint>
#include <cstring>

#include <cublas_v2.h>

namespace fp16
{
// Host-side float <-> __half conversion helpers.
// Code added before equivalent code was available via cuda.
// This code needs to be removed when we ship for cuda-9.2.
//
// The non-template functions below are declared `inline` because this is a
// header: without `inline`, every translation unit that includes it would emit
// its own external definition, producing multiple-definition link errors.

//! Reinterpret the bytes of `u` as a value of type `T`.
//!
//! The first sizeof(T) bytes of `u` are copied with std::memcpy, which is
//! well-defined for trivially copyable types (dereferencing a
//! reinterpret_cast'ed pointer instead violates strict aliasing).
//! Requires sizeof(T) <= sizeof(U) so the copy never reads past `u`.
template<typename T, typename U> T bitwise_cast(U u)
{
    static_assert(sizeof(T) <= sizeof(U),
                  "fp16::bitwise_cast: destination type is larger than the source type");
    T t;
    std::memcpy(&t, &u, sizeof(T));
    return t;
}

//! Convert a single-precision float to half precision, rounding to nearest,
//! ties to even.
//!
//! - NaN of either sign returns the half bit pattern 0x7fff.
//! - Magnitudes whose float bit pattern exceeds 0x477fefff (the largest float
//!   that still rounds down to 65504) return infinity with the input's sign.
//! - Magnitudes at or below 2^-25 (bit pattern 0x33000000) return zero with
//!   the input's sign.
//! - Everything else becomes a normal or subnormal half.
inline __half __float2half(float f)
{
    const uint32_t x = bitwise_cast<uint32_t, float>(f);
    const uint32_t u = (x & 0x7fffffffu);  // magnitude bits (sign cleared)

    // Get rid of +NaN/-NaN case first.
    if (u > 0x7f800000u)
        return bitwise_cast<__half, uint16_t>(uint16_t(0x7fff));

    const uint16_t sign = static_cast<uint16_t>((x >> 16) & 0x8000u);

    // Overflow (including +Inf/-Inf): result is signed infinity.
    if (u > 0x477fefffu)
        return bitwise_cast<__half, uint16_t>(static_cast<uint16_t>(sign | 0x7c00u));

    // Underflow (including +0/-0): anything <= 2^-25 rounds to signed zero.
    if (u < 0x33000001u)
        return bitwise_cast<__half, uint16_t>(sign);

    uint32_t exponent = ((u >> 23) & 0xffu);
    uint32_t mantissa = (u & 0x7fffffu);

    // `shift` is the number of low mantissa bits dropped to fit the 10-bit
    // half mantissa field.
    uint32_t shift;
    if (exponent > 0x70)
    {
        // Normal half: rebias the exponent from 127 to 15 (127 - 15 == 0x70).
        shift = 13;
        exponent -= 0x70;
    }
    else
    {
        // Subnormal half: make the implicit leading 1 explicit and shift the
        // value into units of 2^-24. Given the checks above, exponent is in
        // [0x66, 0x70], so shift is in [14, 24].
        shift = 0x7e - exponent;
        exponent = 0;
        mantissa |= 0x800000u;
    }

    const uint32_t lsb    = (1u << shift);
    const uint32_t lsb_s1 = (lsb >> 1);
    const uint32_t lsb_m1 = (lsb - 1);

    // Round to nearest even.
    const uint32_t remainder = (mantissa & lsb_m1);
    mantissa >>= shift;
    if ( (remainder > lsb_s1) || ((remainder == lsb_s1) && (mantissa & 0x1)) )
    {
        ++mantissa;
        // The mantissa carried out of its 10 bits: bump the exponent instead.
        // For a subnormal this yields the smallest normal half.
        if (!(mantissa & 0x3ff))
        {
            ++exponent;
            mantissa = 0;
        }
    }

    return bitwise_cast<__half, uint16_t>(
        static_cast<uint16_t>(sign | uint16_t(exponent << 10) | uint16_t(mantissa)));
}

//! Convert a half-precision value to single precision.
//!
//! Finite values and infinities convert exactly. Every NaN input maps to the
//! float bit pattern 0x7fffffff, so the sign and payload are discarded.
inline float __half2float(__half h)
{
    const uint16_t x  = bitwise_cast<uint16_t, __half>(h);
    uint32_t sign     = ((x >> 15) & 1u);
    uint32_t exponent = ((x >> 10) & 0x1fu);
    // Align the 10-bit half mantissa with the top of the 23-bit float field.
    uint32_t mantissa = (static_cast<uint32_t>(x & 0x3ffu) << 13);

    if (exponent == 0x1f)
    {   /* NaN or Inf */
        if (mantissa != 0)
        {   // NaN
            sign     = 0;
            mantissa = 0x7fffff;
        }
        else // Inf
            mantissa = 0;
        exponent = 0xff;
    }
    else if (!exponent)
    {   /* Denorm or Zero */
        if (mantissa) {
            // Normalize: shift left until the leading 1 reaches the implicit
            // bit position (bit 23), decrementing the exponent once per step.
            unsigned int msb;
            exponent = 0x71;
            do
            {
                msb = (mantissa & 0x400000u);
                mantissa <<= 1; /* normalize */
                --exponent;
            }
            while (!msb);
            mantissa &= 0x7fffffu; /* 1.mantissa is implicit */
        }
        // Otherwise signed zero: exponent and mantissa both stay 0.
    }
    else
        exponent += 0x70;  // rebias the exponent from 15 to 127
    return bitwise_cast<float, uint32_t>( (sign<<31) | (exponent<<23) | mantissa );
}

}  // namespace fp16

#endif // _TRT_FP16_H_