Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 18 additions & 25 deletions GridKit/CommonMath.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,23 @@ namespace GridKit
{
namespace Math
{
/**
* @brief Smoothing scale shared by CommonMath primitives
*
* Used by @ref sigmoid, @ref ramp, and functions composed from them to set
* the width of smooth transitions.
*
* @tparam RealT - real data type
*/
template <typename RealT>
Comment thread
lukelowry marked this conversation as resolved.
inline constexpr RealT MU = 240.0;

/**
* @brief Scaled sigmoid activation function
*
* @note The sigmoid constant (mu) value is chosen to balance accuracy
* and finite derivatives. Large values more closely approximate a step
* function, but lead to inf or NaN derivatives.
* function, but can make the transition numerically stiff.
*
* @tparam ScalarT - scalar data type
*
Expand All @@ -25,9 +36,8 @@ namespace GridKit
template <class ScalarT>
__attribute__((always_inline)) inline ScalarT sigmoid(const ScalarT x)
{
using RealT = typename GridKit::ScalarTraits<ScalarT>::RealT;
static constexpr RealT MU = 240.0;
return ONE<RealT> / (ONE<RealT> + std::exp(-MU * x));
using RealT = typename GridKit::ScalarTraits<ScalarT>::RealT;
return HALF<RealT> * (ONE<RealT> + std::tanh(HALF<RealT> * MU<RealT> * x));
}

/**
Expand All @@ -44,12 +54,11 @@ namespace GridKit
template <class ScalarT>
__attribute__((always_inline)) inline ScalarT ramp(const ScalarT x)
{
using RealT = typename GridKit::ScalarTraits<ScalarT>::RealT;
static constexpr RealT MU = 240.0;
using RealT = typename GridKit::ScalarTraits<ScalarT>::RealT;

ScalarT z = MU * x;
ScalarT a = std::abs(z);
return (HALF<RealT> * (z + a) + std::log1p(std::exp(-a))) / MU;
RealT mu = MU<RealT>;
ScalarT a = std::abs(mu * x);
return HALF<RealT> * (x + a / mu) + std::log1p(std::exp(-a)) / mu;
}

/**
Expand Down Expand Up @@ -224,22 +233,6 @@ namespace GridKit
return height / (upper - lower) * (ramp(x - lower) - ramp(x - upper));
}

/**
* @brief Derivative of the scaled sigmoid activation function
* (i.e., approximation to the delta dirac function)
*
* @tparam ScalarT - scalar data type
*
* @param[in] x - expected to be of order 1
* @return value of the sigmoid function
*/
template <class ScalarT>
__attribute__((always_inline)) inline ScalarT dsigmoid(const ScalarT x)
{
using RealT = typename GridKit::ScalarTraits<ScalarT>::RealT;
return FOUR<RealT> * sigmoid(x) * (ONE<RealT> - sigmoid(x));
}

/**
* @brief Smooth above-limit indicator
*
Expand Down
Loading
Loading