-
Notifications
You must be signed in to change notification settings - Fork 1
/
arith.h
44 lines (36 loc) · 1.13 KB
/
arith.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// Basic math functions
#pragma once
#include "cutil.h"
#include <algorithm>
#include <cmath>
namespace mandelbrot {
using std::abs;
using std::max;
// Floating point helpers, stamped out once per scalar type S by ARITH(S, FMA).
// Every helper is declared __host__ __device__ static inline:
//   sqr(x)        x * x
//   half(x)       S(0.5) * x
//   twice(x)      x + x
//   inv(x)        1 / x
//   bound(x)      abs(x)
//   fma(x, y, s)  FMA(x, y, s), i.e. a fused multiply-add x * y + s
//
// FMA names the compiler builtin whose precision matches S.  The float
// instantiation must not go through the double-precision __builtin_fma: that
// would promote the operands to double, round the result to double, and then
// round a second time when converting back to float, which can differ from a
// correctly rounded single-precision fused multiply-add.
#define ARITH(S, FMA) \
  __host__ __device__ static inline S sqr(S x) { return x * x; } \
  __host__ __device__ static inline S half(S x) { return S(0.5) * x; } \
  __host__ __device__ static inline S twice(S x) { return x + x; } \
  __host__ __device__ static inline S inv(S x) { return 1 / x; } \
  __host__ __device__ static inline S bound(S x) { return abs(x); } \
  __host__ __device__ static inline S fma(const S x, const S y, const S s) { return FMA(x, y, s); }
ARITH(float, __builtin_fmaf)
ARITH(double, __builtin_fma)
#undef ARITH
// Integer overload to make generated code happy.
// Doubles by addition, matching the floating point twice(): left-shifting a
// negative signed value is undefined behavior before C++20, whereas x + x is
// well defined for every x whose double fits in int.
__host__ __device__ static inline int twice(int x) { return x + x; }
// relu(x): the larger of x and zero, where zero is T(0).
// max is called unqualified, so it resolves through the surrounding
// using-declaration or argument-dependent lookup on T.
template<class T> static inline T relu(const T& x) {
  const T zero = T(0);
  return max(x, zero);
}
// Binary exponent of x as reported by frexp; the mantissa frexp returns is
// discarded.  For finite nonzero x this is the e with x = m * 2^e and
// 0.5 <= |m| < 1 (frexp's contract); frexp stores 0 for x == 0.
static inline int exponent(const double x) {
  int power;
  static_cast<void>(frexp(x, &power));
  return power;
}
// Divide a by b and check, via slow_assert, that nothing was lost: the
// quotient multiplied back by b must equal a.  Returns the quotient.
// NOTE(review): a / b is evaluated before any check, so b is presumably
// required to be nonzero -- confirm against callers.
template<class I> static inline I exact_div(const I a, const I b) {
  const I quotient = a / b;
  slow_assert(quotient * b == a);
  return quotient;
}
} // namespace mandelbrot