F16 tests
This commit is contained in:
parent
7bd67a2cb9
commit
b2b4ca9c4d
3 changed files with 660 additions and 48 deletions
|
|
@ -148,7 +148,7 @@ namespace Crafter {
|
|||
}
|
||||
|
||||
|
||||
constexpr void operator+=(VectorF16<Len, Packing> b) const {
|
||||
constexpr void operator+=(VectorF16<Len, Packing> b) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
v = _mm_add_ph(v, b.v);
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
|
|
@ -158,7 +158,7 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr void operator-=(VectorF16<Len, Packing> b) const {
|
||||
constexpr void operator-=(VectorF16<Len, Packing> b) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
v = _mm_sub_ph(v, b.v);
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
|
|
@ -168,7 +168,7 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr void operator*=(VectorF16<Len, Packing> b) const {
|
||||
constexpr void operator*=(VectorF16<Len, Packing> b) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
v = _mm_mul_ph(v, b.v);
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
|
|
@ -178,7 +178,7 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr void operator/=(VectorF16<Len, Packing> b) const {
|
||||
constexpr void operator/=(VectorF16<Len, Packing> b) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
v = _mm_div_ph(v, b.v);
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
|
|
@ -188,48 +188,48 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr VectorF16<Len, Packing> operator+(_Float16 b) const {
|
||||
constexpr VectorF16<Len, Packing> operator+(_Float16 b) {
|
||||
VectorF16<Len, Packing> vB(b);
|
||||
return this + vB;
|
||||
return *this + vB;
|
||||
}
|
||||
|
||||
constexpr VectorF16<Len, Packing> operator-(_Float16 b) const {
|
||||
constexpr VectorF16<Len, Packing> operator-(_Float16 b) {
|
||||
VectorF16<Len, Packing> vB(b);
|
||||
return this - vB;
|
||||
return *this - vB;
|
||||
}
|
||||
|
||||
constexpr VectorF16<Len, Packing> operator*(_Float16 b) const {
|
||||
constexpr VectorF16<Len, Packing> operator*(_Float16 b) {
|
||||
VectorF16<Len, Packing> vB(b);
|
||||
return this * vB;
|
||||
return *this * vB;
|
||||
}
|
||||
|
||||
constexpr VectorF16<Len, Packing> operator/(_Float16 b) const {
|
||||
constexpr VectorF16<Len, Packing> operator/(_Float16 b) {
|
||||
VectorF16<Len, Packing> vB(b);
|
||||
return this / vB;
|
||||
return *this / vB;
|
||||
}
|
||||
|
||||
constexpr void operator+=(_Float16 b) const {
|
||||
constexpr void operator+=(_Float16 b) {
|
||||
VectorF16<Len, Packing> vB(b);
|
||||
this += vB;
|
||||
*this += vB;
|
||||
}
|
||||
|
||||
constexpr void operator-=(_Float16 b) const {
|
||||
constexpr void operator-=(_Float16 b) {
|
||||
VectorF16<Len, Packing> vB(b);
|
||||
this -= vB;
|
||||
*this -= vB;
|
||||
}
|
||||
|
||||
constexpr void operator*=(_Float16 b) const {
|
||||
constexpr void operator*=(_Float16 b) {
|
||||
VectorF16<Len, Packing> vB(b);
|
||||
this *= vB;
|
||||
*this *= vB;
|
||||
}
|
||||
|
||||
constexpr void operator/=(_Float16 b) const {
|
||||
constexpr void operator/=(_Float16 b) {
|
||||
VectorF16<Len, Packing> vB(b);
|
||||
this /= vB;
|
||||
*this /= vB;
|
||||
}
|
||||
|
||||
constexpr VectorF16<Len, Packing> operator-(){
|
||||
return Negate<GetAllTrue>();
|
||||
return Negate<GetAllTrue()>();
|
||||
}
|
||||
|
||||
constexpr bool operator==(VectorF16<Len, Packing> b) const {
|
||||
|
|
@ -281,28 +281,88 @@ namespace Crafter {
|
|||
}
|
||||
|
||||
constexpr VectorF16<Len, Packing> Cos() {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF16<Len, Packing>(_mm_cos_ph(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF16<Len, Packing>(_mm256_cos_ph(v));
|
||||
if constexpr (std::is_same_v<VectorType, __m128h>) {
|
||||
__m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v));
|
||||
wide = cos_f32x8(wide);
|
||||
return VectorF16<Len, Packing>(
|
||||
_mm_castsi128_ph(_mm256_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT)));
|
||||
|
||||
} else if constexpr (std::is_same_v<VectorType, __m256h>) {
|
||||
__m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v));
|
||||
wide = cos_f32x16(wide);
|
||||
return VectorF16<Len, Packing>(
|
||||
_mm256_castsi256_ph(_mm512_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT)));
|
||||
|
||||
} else {
|
||||
return VectorF16<Len, Packing>(_mm512_cos_ph(v));
|
||||
__m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v));
|
||||
__m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1);
|
||||
__m256i lo_ph = _mm512_cvtps_ph(cos_f32x16(_mm512_cvtph_ps(lo)), _MM_FROUND_TO_NEAREST_INT);
|
||||
__m256i hi_ph = _mm512_cvtps_ph(cos_f32x16(_mm512_cvtph_ps(hi)), _MM_FROUND_TO_NEAREST_INT);
|
||||
return VectorF16<Len, Packing>(
|
||||
_mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
constexpr VectorF16<Len, Packing> Sin() {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF16<Len, Packing>(_mm_sin_ph(v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF16<Len, Packing>(_mm256_sin_ph(v));
|
||||
if constexpr (std::is_same_v<VectorType, __m128h>) {
|
||||
__m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v));
|
||||
wide = sin_f32x8(wide);
|
||||
return VectorF16<Len, Packing>(_mm_castsi128_ph(_mm256_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT)));
|
||||
|
||||
} else if constexpr (std::is_same_v<VectorType, __m256h>) {
|
||||
__m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v));
|
||||
wide = sin_f32x16(wide);
|
||||
return VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm512_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT)));
|
||||
|
||||
} else {
|
||||
return VectorF16<Len, Packing>(_mm512_sin_ph(v));
|
||||
__m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v));
|
||||
__m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1);
|
||||
__m256i lo_ph = _mm512_cvtps_ph(sin_f32x16(_mm512_cvtph_ps(lo)), _MM_FROUND_TO_NEAREST_INT);
|
||||
__m256i hi_ph = _mm512_cvtps_ph(sin_f32x16(_mm512_cvtph_ps(hi)), _MM_FROUND_TO_NEAREST_INT);
|
||||
return VectorF16<Len, Packing>(_mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Computes sine and cosine of every lane in a single pass.
// Lanes are widened from half to single precision, evaluated with the
// sincos_f32x8 / sincos_f32x16 helpers, then narrowed back to half precision
// using round-to-nearest.
// Returns a tuple ordered as (sine, cosine).
std::tuple<VectorF16<Len, Packing>, VectorF16<Len, Packing>> SinCos() {
    using Vec = VectorF16<Len, Packing>;

    if constexpr (std::is_same_v<VectorType, __m128h>) {
        // 128-bit storage: the half lanes widen into one 256-bit float vector.
        __m256 sine, cosine;
        sincos_f32x8(_mm256_cvtph_ps(_mm_castph_si128(v)), sine, cosine);
        return {
            Vec(_mm_castsi128_ph(_mm256_cvtps_ph(sine, _MM_FROUND_TO_NEAREST_INT))),
            Vec(_mm_castsi128_ph(_mm256_cvtps_ph(cosine, _MM_FROUND_TO_NEAREST_INT)))
        };
    } else if constexpr (std::is_same_v<VectorType, __m256h>) {
        // 256-bit storage: the half lanes widen into one 512-bit float vector.
        __m512 sine, cosine;
        sincos_f32x16(_mm512_cvtph_ps(_mm256_castph_si256(v)), sine, cosine);
        return {
            Vec(_mm256_castsi256_ph(_mm512_cvtps_ph(sine, _MM_FROUND_TO_NEAREST_INT))),
            Vec(_mm256_castsi256_ph(_mm512_cvtps_ph(cosine, _MM_FROUND_TO_NEAREST_INT)))
        };
    } else {
        // Remaining case: v is handled as a 512-bit vector, processed as two
        // 256-bit halves because the widened floats need twice the space.
        const __m512i raw_bits = _mm512_castph_si512(v);

        __m512 sine_lo, cosine_lo, sine_hi, cosine_hi;
        sincos_f32x16(_mm512_cvtph_ps(_mm512_castsi512_si256(raw_bits)), sine_lo, cosine_lo);
        sincos_f32x16(_mm512_cvtph_ps(_mm512_extracti64x4_epi64(raw_bits, 1)), sine_hi, cosine_hi);

        // Narrow both float halves and stitch them back into one 512-bit half vector.
        const auto narrow_and_join = [](__m512 lower, __m512 upper) {
            const __m256i lower_ph = _mm512_cvtps_ph(lower, _MM_FROUND_TO_NEAREST_INT);
            const __m256i upper_ph = _mm512_cvtps_ph(upper, _MM_FROUND_TO_NEAREST_INT);
            return Vec(_mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lower_ph), upper_ph, 1)));
        };

        return {
            narrow_and_join(sine_lo, sine_hi),
            narrow_and_join(cosine_lo, cosine_hi)
        };
    }
}
|
||||
|
||||
template <std::array<bool, Len> values>
|
||||
constexpr std::array<std::uint16_t, Len> Negate() {
|
||||
std::array<std::uint16_t, Len> mask = GetShuffleMaskEpi32<values>();
|
||||
constexpr VectorF16<Len, Packing> Negate() {
|
||||
std::array<std::uint16_t, Len> mask = GetNegateMask<values>();
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF16<Len, Packing>(_mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(v), _mm_loadu_epi16(mask.data()))));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
|
|
@ -341,21 +401,21 @@ namespace Crafter {
|
|||
|
||||
static constexpr VectorF16<Len, Packing> MulitplyAdd(VectorF16<Len, Packing> a, VectorF16<Len, Packing> b, VectorF16<Len, Packing> add) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF16<Len, Packing>(_mm_fmadd_ph(a.v, b.v, add));
|
||||
return VectorF16<Len, Packing>(_mm_fmadd_ph(a.v, b.v, add.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
return VectorF16<Len, Packing>(_mm256_fmadd_ph(a.v, b.v, add));
|
||||
return VectorF16<Len, Packing>(_mm256_fmadd_ph(a.v, b.v, add.v));
|
||||
} else {
|
||||
return VectorF16<Len, Packing>(_mm512_fmadd_ph(a.v, b.v, add));
|
||||
return VectorF16<Len, Packing>(_mm512_fmadd_ph(a.v, b.v, add.v));
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr VectorF16<Len, Packing> MulitplySub(VectorF16<Len, Packing> a, VectorF16<Len, Packing> b, VectorF16<Len, Packing> sub) {
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF16<Len, Packing>(_mm_fmsub_ph(a.v, b.v, sub));
|
||||
return VectorF16<Len, Packing>(_mm_fmsub_ph(a.v, b.v, sub.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
return VectorF16<Len, Packing>(_mm256_fmsub_ph(a.v, b.v, sub));
|
||||
return VectorF16<Len, Packing>(_mm256_fmsub_ph(a.v, b.v, sub.v));
|
||||
} else {
|
||||
return VectorF16<Len, Packing>(_mm512_fmsub_ph(a.v, b.v, sub));
|
||||
return VectorF16<Len, Packing>(_mm512_fmsub_ph(a.v, b.v, sub.v));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1227,7 +1287,7 @@ namespace Crafter {
|
|||
return shuffleMask;
|
||||
}
|
||||
|
||||
consteval std::array<bool, Len> GetAllTrue() {
|
||||
static consteval std::array<bool, Len> GetAllTrue() {
|
||||
std::array<bool, Len> arr{};
|
||||
arr.fill(true);
|
||||
return arr;
|
||||
|
|
@ -1288,6 +1348,216 @@ namespace Crafter {
|
|||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
// ---- Constants for the single-precision sine/cosine helpers ----

// 2/pi: multiplying by this and rounding gives the number of quarter turns.
static constexpr float two_over_pi = 0.6366197723675814f;

// Two-term split of pi/2 for float arithmetic:
//   pi_over_2_hi is pi/2 rounded to the nearest float,
//   pi_over_2_lo is the remainder pi/2 - pi_over_2_hi,
// so hi + lo reproduces pi/2 far beyond float precision.
// Do not substitute the double-precision split (1.5707963267341256 /
// 6.0771e-11) here: written as a float the high part rounds to the value
// below, and that low part no longer matches it.
static constexpr float pi_over_2_hi = 1.5707963705062866f;
static constexpr float pi_over_2_lo = -4.37113882867379e-8f;

// Cos polynomial on [-pi/4, pi/4]: c0 + c2*r^2 + c4*r^4 + ... + c10*r^10
// (the index is the power of r the coefficient multiplies).
static constexpr float c0 = 1.0f;
static constexpr float c2 = -0.4999999642372f;
static constexpr float c4 = 0.0416666418707f;
static constexpr float c6 = -0.0013888397720f;
static constexpr float c8 = 0.0000248015873f;
static constexpr float c10 = -0.0000002752258f;

// Sin polynomial on [-pi/4, pi/4]: r * (1 + s1*r^2 + s3*r^4 + s5*r^6 + s7*r^8 + s9*r^10)
static constexpr float s1 = -0.1666666641831f;
static constexpr float s3 = 0.0083333293858f;
static constexpr float s5 = -0.0001984090955f;
static constexpr float s7 = 0.0000027526372f;
static constexpr float s9 = -0.0000000239013f;
|
||||
|
||||
// Reduce |x| into [-pi/4, pi/4], return reduced value and quadrant
|
||||
constexpr void range_reduce_f32x8(__m256 ax, __m256& r, __m256& r2, __m256i& q) {
|
||||
__m256 fq = _mm256_round_ps(_mm256_mul_ps(ax, _mm256_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
|
||||
q = _mm256_cvtps_epi32(fq);
|
||||
r = _mm256_sub_ps(ax, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_hi)));
|
||||
r = _mm256_sub_ps(r, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_lo)));
|
||||
r2 = _mm256_mul_ps(r, r);
|
||||
}
|
||||
|
||||
|
||||
// cos(x): use cos_poly when q even, sin_poly when q odd; negate if (q+1)&2
|
||||
constexpr __m256 cos_f32x8(__m256 x) {
|
||||
const __m256 sign_mask = _mm256_set1_ps(-0.0f);
|
||||
__m256 ax = _mm256_andnot_ps(sign_mask, x);
|
||||
|
||||
__m256 r, r2; __m256i q;
|
||||
range_reduce_f32x8(ax, r, r2, q);
|
||||
|
||||
__m256 cos_r, sin_r;
|
||||
sincos_poly_f32x8(r, r2, cos_r, sin_r);
|
||||
|
||||
__m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1));
|
||||
__m256 use_sin = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1)));
|
||||
__m256 result = _mm256_blendv_ps(cos_r, sin_r, use_sin);
|
||||
|
||||
__m256i need_neg = _mm256_and_si256(
|
||||
_mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2));
|
||||
__m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30));
|
||||
return _mm256_xor_ps(result, neg_mask);
|
||||
}
|
||||
|
||||
constexpr void sincos_poly_f32x8(__m256 r, __m256 r2, __m256& cos_r, __m256& sin_r) {
|
||||
cos_r = _mm256_fmadd_ps(_mm256_set1_ps(c10), r2, _mm256_set1_ps(c8));
|
||||
cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c6));
|
||||
cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c4));
|
||||
cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c2));
|
||||
cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c0));
|
||||
|
||||
sin_r = _mm256_fmadd_ps(_mm256_set1_ps(s9), r2, _mm256_set1_ps(s7));
|
||||
sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s5));
|
||||
sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s3));
|
||||
sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s1));
|
||||
sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(1.0f));
|
||||
sin_r = _mm256_mul_ps(sin_r, r);
|
||||
}
|
||||
|
||||
// sin(x): use sin_poly when q even, cos_poly when q odd; negate if q&2; respect input sign
|
||||
constexpr __m256 sin_f32x8(__m256 x) {
|
||||
const __m256 sign_mask = _mm256_set1_ps(-0.0f);
|
||||
__m256 x_sign = _mm256_and_ps(x, sign_mask);
|
||||
__m256 ax = _mm256_andnot_ps(sign_mask, x);
|
||||
|
||||
__m256 r, r2; __m256i q;
|
||||
range_reduce_f32x8(ax, r, r2, q);
|
||||
|
||||
__m256 cos_r, sin_r;
|
||||
sincos_poly_f32x8(r, r2, cos_r, sin_r);
|
||||
|
||||
__m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1));
|
||||
__m256 use_cos = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1)));
|
||||
__m256 result = _mm256_blendv_ps(sin_r, cos_r, use_cos);
|
||||
|
||||
__m256i need_neg = _mm256_and_si256(q, _mm256_set1_epi32(2));
|
||||
__m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30));
|
||||
result = _mm256_xor_ps(result, neg_mask);
|
||||
|
||||
// Apply original sign of x
|
||||
return _mm256_xor_ps(result, x_sign);
|
||||
}
|
||||
|
||||
// --- 512-bit helpers ---
|
||||
|
||||
constexpr void range_reduce_f32x16(__m512 ax, __m512& r, __m512& r2, __m512i& q) {
|
||||
__m512 fq = _mm512_roundscale_ps(_mm512_mul_ps(ax, _mm512_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
|
||||
q = _mm512_cvtps_epi32(fq);
|
||||
r = _mm512_sub_ps(ax, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_hi)));
|
||||
r = _mm512_sub_ps(r, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_lo)));
|
||||
r2 = _mm512_mul_ps(r, r);
|
||||
}
|
||||
|
||||
constexpr void sincos_poly_f32x16(__m512 r, __m512 r2, __m512& cos_r, __m512& sin_r) {
|
||||
cos_r = _mm512_fmadd_ps(_mm512_set1_ps(c10), r2, _mm512_set1_ps(c8));
|
||||
cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c6));
|
||||
cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c4));
|
||||
cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c2));
|
||||
cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c0));
|
||||
|
||||
sin_r = _mm512_fmadd_ps(_mm512_set1_ps(s9), r2, _mm512_set1_ps(s7));
|
||||
sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s5));
|
||||
sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s3));
|
||||
sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s1));
|
||||
sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(1.0f));
|
||||
sin_r = _mm512_mul_ps(sin_r, r);
|
||||
}
|
||||
|
||||
constexpr __m512 cos_f32x16(__m512 x) {
|
||||
__m512 ax = _mm512_abs_ps(x);
|
||||
|
||||
__m512 r, r2; __m512i q;
|
||||
range_reduce_f32x16(ax, r, r2, q);
|
||||
|
||||
__m512 cos_r, sin_r;
|
||||
sincos_poly_f32x16(r, r2, cos_r, sin_r);
|
||||
|
||||
__mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1));
|
||||
__m512 result = _mm512_mask_blend_ps(odd, cos_r, sin_r);
|
||||
|
||||
__m512i need_neg = _mm512_and_si512(
|
||||
_mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2));
|
||||
__m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30));
|
||||
return _mm512_xor_ps(result, neg_mask);
|
||||
}
|
||||
|
||||
constexpr __m512 sin_f32x16(__m512 x) {
|
||||
__m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f));
|
||||
__m512 ax = _mm512_abs_ps(x);
|
||||
|
||||
__m512 r, r2; __m512i q;
|
||||
range_reduce_f32x16(ax, r, r2, q);
|
||||
|
||||
__m512 cos_r, sin_r;
|
||||
sincos_poly_f32x16(r, r2, cos_r, sin_r);
|
||||
|
||||
__mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1));
|
||||
__m512 result = _mm512_mask_blend_ps(odd, sin_r, cos_r);
|
||||
|
||||
__m512i need_neg = _mm512_and_si512(q, _mm512_set1_epi32(2));
|
||||
__m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30));
|
||||
result = _mm512_xor_ps(result, neg_mask);
|
||||
|
||||
return _mm512_xor_ps(result, x_sign);
|
||||
}
|
||||
|
||||
// --- 256-bit sincos ---
|
||||
constexpr void sincos_f32x8(__m256 x, __m256& out_sin, __m256& out_cos) {
|
||||
const __m256 sign_mask = _mm256_set1_ps(-0.0f);
|
||||
__m256 x_sign = _mm256_and_ps(x, sign_mask);
|
||||
__m256 ax = _mm256_andnot_ps(sign_mask, x);
|
||||
|
||||
__m256 r, r2; __m256i q;
|
||||
range_reduce_f32x8(ax, r, r2, q);
|
||||
|
||||
__m256 cos_r, sin_r;
|
||||
sincos_poly_f32x8(r, r2, cos_r, sin_r);
|
||||
|
||||
__m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1));
|
||||
__m256 is_odd = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1)));
|
||||
|
||||
// cos: swap on odd, negate if (q+1)&2
|
||||
out_cos = _mm256_blendv_ps(cos_r, sin_r, is_odd);
|
||||
__m256i cos_neg = _mm256_and_si256(
|
||||
_mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2));
|
||||
out_cos = _mm256_xor_ps(out_cos,
|
||||
_mm256_castsi256_ps(_mm256_slli_epi32(cos_neg, 30)));
|
||||
|
||||
// sin: swap on odd, negate if q&2, apply input sign
|
||||
out_sin = _mm256_blendv_ps(sin_r, cos_r, is_odd);
|
||||
__m256i sin_neg = _mm256_and_si256(q, _mm256_set1_epi32(2));
|
||||
out_sin = _mm256_xor_ps(out_sin,
|
||||
_mm256_castsi256_ps(_mm256_slli_epi32(sin_neg, 30)));
|
||||
out_sin = _mm256_xor_ps(out_sin, x_sign);
|
||||
}
|
||||
|
||||
// --- 512-bit sincos ---
|
||||
constexpr void sincos_f32x16(__m512 x, __m512& out_sin, __m512& out_cos) {
|
||||
__m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f));
|
||||
__m512 ax = _mm512_abs_ps(x);
|
||||
|
||||
__m512 r, r2; __m512i q;
|
||||
range_reduce_f32x16(ax, r, r2, q);
|
||||
|
||||
__m512 cos_r, sin_r;
|
||||
sincos_poly_f32x16(r, r2, cos_r, sin_r);
|
||||
|
||||
__mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1));
|
||||
|
||||
// cos
|
||||
out_cos = _mm512_mask_blend_ps(odd, cos_r, sin_r);
|
||||
__m512i cos_neg = _mm512_and_si512(
|
||||
_mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2));
|
||||
out_cos = _mm512_xor_ps(out_cos,
|
||||
_mm512_castsi512_ps(_mm512_slli_epi32(cos_neg, 30)));
|
||||
|
||||
// sin
|
||||
out_sin = _mm512_mask_blend_ps(odd, sin_r, cos_r);
|
||||
__m512i sin_neg = _mm512_and_si512(q, _mm512_set1_epi32(2));
|
||||
out_sin = _mm512_xor_ps(out_sin,
|
||||
_mm512_castsi512_ps(_mm512_slli_epi32(sin_neg, 30)));
|
||||
out_sin = _mm512_xor_ps(out_sin, x_sign);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue