// Crafter.Math:Common - module partition that declares VectorBase.
//
// VectorBase wraps a single x86 SIMD register (`v`) and provides, as protected
// static members:
//   * size/alignment constants derived from Len, Packing and sizeof(T),
//   * consteval builders for shuffle, permute, blend, negate and extract masks,
//   * float sin / cos / sincos kernels for 128-, 256- and 512-bit registers.
// VectorF32 (and VectorF16 when __AVX512FP16__ is defined) are declared as
// friends of VectorBase.
//
// NOTE(review): in this copy of the file the angle-bracket template
// parameter/argument lists appear to be missing (for example
// `template struct VectorBase`, `std::array GetAllTrue()`, the operand of
// `#include`), and the original line breaks appear to be lost, so each `//`
// comment now swallows whatever follows it on the same physical line. The code
// text below is kept exactly as found; restore it from the original before
// compiling.
module; #ifdef __x86_64 #include #endif export module Crafter.Math:Common; import std; namespace Crafter { #ifdef __AVX512FP16__ export template struct VectorF16; #endif export template struct VectorF32; /* VectorBase: wraps one intrinsic register `v` and carries the compile-time constants and mask builders below; VectorF32 (and VectorF16 when __AVX512FP16__ is defined) are declared as friends. */ template struct VectorBase { #ifdef __AVX512FP16__ template friend struct VectorF16; #endif template friend struct VectorF32; protected: /* Returns 16, 32 or 64: the first of those byte sizes that Len * Packing * sizeof(T) fits into. No return statement is reached when the total exceeds 64 bytes. */ static consteval std::uint8_t GetAlingment() { if(Len * Packing * sizeof(T) <= 16) { return 16; } else if(Len * Packing * sizeof(T) <= 32) { return 32; } else if(Len * Packing * sizeof(T) <= 64) { return 64; } } /* Register type for `v`. The first alternative picks among __m128h, __m256h and __m512h (thresholds Len * Packing > 8 and > 16); the second picks among __m128, __m256 and __m512 (thresholds > 4 and > 8). The selecting condition is not visible in this copy. */ using VectorType = std::conditional_t, std::conditional_t<(Len * Packing > 16), __m512h, std::conditional_t<(Len * Packing > 8), __m256h, __m128h>>, std::conditional_t<(Len * Packing > 8), __m512, std::conditional_t<(Len * Packing > 4), __m256, __m128>> >; /* The wrapped register. */ VectorType v; public: template friend struct VectorBase; /* Max is 64 when __AVX512F__ is defined and 32 otherwise; it is used below as a byte count (divided by sizeof(T)). */ #ifdef __AVX512F__ static constexpr std::uint8_t Max = 64; #else static constexpr std::uint8_t Max = 32; #endif /* MaxElement: Max expressed in elements of T. AlignmentElement and Alignment: GetAlingment() expressed in elements and in bytes. */ static constexpr std::uint8_t MaxElement = Max/sizeof(T); static constexpr std::uint8_t AlignmentElement = GetAlingment()/sizeof(T); static constexpr std::uint8_t Alignment = GetAlingment(); static_assert(Len * Packing <= MaxElement, "Len * Packing exceeds MaxElement"); protected: /* Number of T elements in 16 bytes. */ static constexpr std::uint8_t PerLane = 16/sizeof(T); /* Returns an array with every entry set to true. */ static consteval std::array GetAllTrue() { std::array arr{}; arr.fill(true); return arr; } /* Compile-time check on ShuffleValues. When PerLane == 8, every odd index i must satisfy ShuffleValues[i-1] == ShuffleValues[i] - 1. Then every ShuffleValues[i] must equal ShuffleValues[i2] for each i2 = PerLane, 2*PerLane, ... below Len. NOTE(review): the second test compares every entry against the lane-start entries only - confirm this is the intended "same pattern in each lane" test. */ template ShuffleValues> static consteval bool CheckEpi32Shuffle() { if constexpr (PerLane == 8) { for(std::uint8_t i = 1; i < Len; i+=2) { if(ShuffleValues[i-1] != ShuffleValues[i] - 1) { return false; } } } for(std::uint8_t i = 0; i < Len; i++) { for(std::uint8_t i2 = PerLane; i2 < Len; i2 += PerLane) { if(ShuffleValues[i] != ShuffleValues[i2]) { return false; } } } return true; } /* Returns true when every ShuffleValues[i] lies in the same PerLane-sized group of indices as i itself, i.e. in [lane * PerLane, lane * PerLane + PerLane - 1] with lane = i / PerLane. */ template ShuffleValues> static consteval bool CheckEpi8Shuffle() { for(std::uint8_t i = 0; i < Len; i++) { std::uint8_t lane = i / PerLane; if(ShuffleValues[i] < lane * PerLane || ShuffleValues[i] > lane * 
PerLane + PerLane-1) { return false; } } return true; } /* Expands the element indices in ShuffleValues into byte indices. For each of the Packing groups (byte offset i2*Len*sizeof(T)) the first branch writes 2 consecutive bytes per element and the second branch writes 4. The branch conditions and the array size are not visible in this copy. */ template ShuffleValues> static consteval std::array GetShuffleMaskEpi8() { std::array shuffleMask {{0}}; if constexpr(std::same_as) { for(std::uint8_t i2 = 0; i2 < Packing; i2++) { for(std::uint8_t i = 0; i < Len; i++) { shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T))] = ShuffleValues[i]*sizeof(T)+(i2*Len*sizeof(T)); shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T)+1)] = ShuffleValues[i]*sizeof(T)+1+(i2*Len*sizeof(T)); } } } else if constexpr(std::same_as) { for(std::uint8_t i2 = 0; i2 < Packing; i2++) { for(std::uint8_t i = 0; i < Len; i++) { shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T))] = ShuffleValues[i]*sizeof(T)+(i2*Len*sizeof(T)); shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T)+1)] = ShuffleValues[i]*sizeof(T)+1+(i2*Len*sizeof(T)); shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T)+2)] = ShuffleValues[i]*sizeof(T)+2+(i2*Len*sizeof(T)); shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T)+3)] = ShuffleValues[i]*sizeof(T)+3+(i2*Len*sizeof(T)); } } } return shuffleMask; } /* Builds a per-element mask: slot i2*Len+i holds high_bit where values[i] is true and T(0) otherwise. high_bit is bit_cast from an integer that has only its top bit set; the integer types used for the 2-byte and 4-byte cases are not visible in this copy. */ template values> static consteval std::array GetNegateMask() { std::array mask{}; T high_bit = 0; if constexpr(sizeof(T) == 2) { high_bit = std::bit_cast( static_cast(1u << (std::numeric_limits::digits - 1)) ); } else if constexpr(sizeof(T) == 4) { high_bit = std::bit_cast( static_cast(1u << (std::numeric_limits::digits - 1)) ); } for (std::uint8_t i2 = 0; i2 < Packing; ++i2) { for (std::uint8_t i = 0; i < Len; ++i) { mask[i2 * Len + i] = values[i] ? 
high_bit : T(0); } } return mask; } /* Byte mask that gathers the first ExtractLen elements of every packed group: output bytes start at i2*ExtractLen*sizeof(T), source bytes at i2*Len*sizeof(T). Two bytes are written per element. NOTE(review): this builder is constexpr while its siblings are consteval, and it writes only two bytes per element - confirm it is meant for 2-byte T only. */ template static constexpr std::array GetExtractLoMaskEpi8() { std::array mask {{0}}; for(std::uint8_t i2 = 0; i2 < Packing; i2++) { for(std::uint8_t i = 0; i < ExtractLen; i++) { mask[(i2*ExtractLen*sizeof(T))+(i*sizeof(T))] = i*sizeof(T)+(i2*Len*sizeof(T)); mask[(i2*ExtractLen*sizeof(T))+(i*sizeof(T)+1)] = i*sizeof(T)+1+(i2*Len*sizeof(T)); } } return mask; } /* Element-index mask: output slot i2*ExtractLen+i receives source index i + i2*Len. */ template static consteval std::array GetExtractLoMaskEpi16() { std::array mask{}; for (std::uint8_t i2 = 0; i2 < Packing; i2++) { for (std::uint8_t i = 0; i < ExtractLen; i++) { mask[i2 * ExtractLen + i] = i + (i2 * Len); } } return mask; } /* Same index layout as GetExtractLoMaskEpi16; the element type of the returned array is not visible in this copy. */ template static consteval std::array GetExtractLoMaskEpi32() { std::array mask{}; for (std::uint8_t i2 = 0; i2 < Packing; i2++) { for (std::uint8_t i = 0; i < ExtractLen; i++) { mask[i2 * ExtractLen + i] = i + (i2 * Len); } } return mask; } /* Packs the low two bits of ShuffleValues[i] into an 8-bit value, visiting i = 0, 4/sizeof(T), ... below min(Len, 8) and shifting each field left by 8 / sizeof(T) * i. NOTE(review): for larger i the shift exceeds 7 and those bits are dropped by the 8-bit result - confirm that is intended. */ template ShuffleValues> static consteval std::uint8_t GetShuffleMaskEpi32() { std::uint8_t mask = 0; for(std::uint8_t i = 0; i < std::min(Len, std::uint8_t(8)); i+=4/sizeof(T)) { mask = mask | (ShuffleValues[i] & 0b11) << (8 / sizeof(T) * i); } return mask; } /* Permute indices for all packed groups: slot i2*Len+i holds ShuffleValues[i] + i2*Len. */ template ShuffleValues> static consteval std::array GetPermuteMaskEpi16() { std::array shuffleMask {{0}}; for(std::uint8_t i2 = 0; i2 < Packing; i2++) { for(std::uint8_t i = 0; i < Len; i++) { shuffleMask[i2*Len+i] = ShuffleValues[i]+i2*Len; } } return shuffleMask; } /* Same layout as GetPermuteMaskEpi16; the element type of the returned array is not visible in this copy. */ template ShuffleValues> static consteval std::array GetPermuteMaskEpi32() { std::array shuffleMask {{0}}; for(std::uint8_t i2 = 0; i2 < Packing; i2++) { for(std::uint8_t i = 0; i < Len; i++) { shuffleMask[i2*Len+i] = ShuffleValues[i]+i2*Len; } } return shuffleMask; } /* GetBlendMaskEpi16 overloads: bit i2*Len+i of the result is set when ShuffleValues[i] is true. The three overloads return 8-, 16- and 32-bit masks and are selected by a requires-clause whose operands are not visible in this copy. */ template ShuffleValues> static consteval std::uint8_t GetBlendMaskEpi16() requires (std::is_same_v){ std::uint8_t mask = 0; for (std::uint8_t i2 = 0; i2 < Packing; i2++) { for (std::uint8_t i = 0; i < Len; i++) { if (ShuffleValues[i]) { mask |= (1u << (i2 * Len + i)); } } } return mask; } template ShuffleValues> static 
consteval std::uint16_t GetBlendMaskEpi16() requires (std::is_same_v){ std::uint16_t mask = 0; for (std::uint8_t i2 = 0; i2 < Packing; i2++) { for (std::uint8_t i = 0; i < Len; i++) { if (ShuffleValues[i]) { mask |= (1u << (i2 * Len + i)); } } } return mask; } template ShuffleValues> static consteval std::uint32_t GetBlendMaskEpi16() requires (std::is_same_v){ std::uint32_t mask = 0; for (std::uint8_t i2 = 0; i2 < Packing; i2++) { for (std::uint8_t i = 0; i < Len; i++) { if (ShuffleValues[i]) { mask |= (1u << (i2 * Len + i)); } } } return mask; } /* GetBlendMaskEpi32 overloads: same bit layout as GetBlendMaskEpi16 (bit i2*Len+i set when ShuffleValues[i] is true), again in 8-, 16- and 32-bit variants chosen by a requires-clause. */ template ShuffleValues> static consteval std::uint8_t GetBlendMaskEpi32() requires (std::is_same_v){ std::uint8_t mask = 0; for (std::uint8_t i2 = 0; i2 < Packing; i2++) { for (std::uint8_t i = 0; i < Len; i++) { if (ShuffleValues[i]) { mask |= (1u << (i2 * Len + i)); } } } return mask; } template ShuffleValues> static consteval std::uint16_t GetBlendMaskEpi32() requires (std::is_same_v){ std::uint16_t mask = 0; for (std::uint8_t i2 = 0; i2 < Packing; i2++) { for (std::uint8_t i = 0; i < Len; i++) { if (ShuffleValues[i]) { mask |= (1u << (i2 * Len + i)); } } } return mask; } template ShuffleValues> static consteval std::uint32_t GetBlendMaskEpi32() requires (std::is_same_v){ std::uint32_t mask = 0; for (std::uint8_t i2 = 0; i2 < Packing; i2++) { for (std::uint8_t i = 0; i < Len; i++) { if (ShuffleValues[i]) { mask |= (1u << (i2 * Len + i)); } } } return mask; } /* Constants for the range_reduce_* helpers, which compute fq = round(ax * two_over_pi) and then r = ax - fq*pi_over_2_hi - fq*pi_over_2_lo. NOTE(review): the literals end in f, so each one is rounded to float; the digits of pi_over_2_hi and pi_over_2_lo go beyond float precision and resemble a double-precision split of pi/2 - confirm the two-step subtraction still gives the intended accuracy. */ static constexpr float two_over_pi = 0.6366197723675814f; static constexpr float pi_over_2_hi = 1.5707963267341256f; static constexpr float pi_over_2_lo = 6.077100506506192e-11f; // Cos polynomial on [-pi/4, pi/4]: c0 + c2*r^2 + c4*r^4 + ... 
static constexpr float c0 = 1.0f; static constexpr float c2 = -0.4999999642372f; static constexpr float c4 = 0.0416666418707f; static constexpr float c6 = -0.0013888397720f; static constexpr float c8 = 0.0000248015873f; static constexpr float c10 = -0.0000002752258f; // Sin polynomial on [-pi/4, pi/4]: r * (1 + s1*r^2 + s3*r^4 + ...) static constexpr float s1 = -0.1666666641831f; static constexpr float s3 = 0.0083333293858f; static constexpr float s5 = -0.0001984090955f; static constexpr float s7 = 0.0000027526372f; static constexpr float s9 = -0.0000000239013f; // --- 128-bit (SSE) helpers --- /* range_reduce_f32x4: fq = ax * two_over_pi rounded to the nearest integer; q = fq converted to int32; r = ax - fq*pi_over_2_hi - fq*pi_over_2_lo; r2 = r*r. Results are written through the reference parameters. */ static constexpr void range_reduce_f32x4(__m128 ax, __m128& r, __m128& r2, __m128i& q) { __m128 fq = _mm_round_ps(_mm_mul_ps(ax, _mm_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); q = _mm_cvtps_epi32(fq); r = _mm_sub_ps(ax, _mm_mul_ps(fq, _mm_set1_ps(pi_over_2_hi))); r = _mm_sub_ps(r, _mm_mul_ps(fq, _mm_set1_ps(pi_over_2_lo))); r2 = _mm_mul_ps(r, r); } /* sincos_poly_f32x4: evaluates the cosine polynomial (c10..c0) and the sine polynomial (s9..s1, then 1) in r2 by Horner steps using _mm_fmadd_ps; the sine result is finally multiplied by r. */ static constexpr void sincos_poly_f32x4(__m128 r, __m128 r2, __m128& cos_r, __m128& sin_r) { cos_r = _mm_fmadd_ps(_mm_set1_ps(c10), r2, _mm_set1_ps(c8)); cos_r = _mm_fmadd_ps(cos_r, r2, _mm_set1_ps(c6)); cos_r = _mm_fmadd_ps(cos_r, r2, _mm_set1_ps(c4)); cos_r = _mm_fmadd_ps(cos_r, r2, _mm_set1_ps(c2)); cos_r = _mm_fmadd_ps(cos_r, r2, _mm_set1_ps(c0)); sin_r = _mm_fmadd_ps(_mm_set1_ps(s9), r2, _mm_set1_ps(s7)); sin_r = _mm_fmadd_ps(sin_r, r2, _mm_set1_ps(s5)); sin_r = _mm_fmadd_ps(sin_r, r2, _mm_set1_ps(s3)); sin_r = _mm_fmadd_ps(sin_r, r2, _mm_set1_ps(s1)); sin_r = _mm_fmadd_ps(sin_r, r2, _mm_set1_ps(1.0f)); sin_r = _mm_mul_ps(sin_r, r); } // cos(x): use cos_poly when q even, sin_poly when q odd; negate if (q+1)&2 static constexpr __m128 cos_f32x4(__m128 x) { const __m128 sign_mask = _mm_set1_ps(-0.0f); /* ax = x with the sign bit cleared */ __m128 ax = _mm_andnot_ps(sign_mask, x); __m128 r, r2; __m128i q; range_reduce_f32x4(ax, r, r2, q); __m128 cos_r, sin_r; sincos_poly_f32x4(r, r2, cos_r, sin_r); __m128i odd = _mm_and_si128(q, 
_mm_set1_epi32(1)); __m128 use_sin = _mm_castsi128_ps(_mm_cmpeq_epi32(odd, _mm_set1_epi32(1))); __m128 result = _mm_blendv_ps(cos_r, sin_r, use_sin); __m128i need_neg = _mm_and_si128( _mm_add_epi32(q, _mm_set1_epi32(1)), _mm_set1_epi32(2)); /* need_neg is 0 or 2; shifting left by 30 moves that bit to bit 31, so the XOR below flips the sign of the selected lanes */ __m128 neg_mask = _mm_castsi128_ps(_mm_slli_epi32(need_neg, 30)); return _mm_xor_ps(result, neg_mask); } // sin(x): use sin_poly when q even, cos_poly when q odd; negate if q&2; respect input sign static constexpr __m128 sin_f32x4(__m128 x) { const __m128 sign_mask = _mm_set1_ps(-0.0f); /* x_sign keeps only the sign bit of x; it is XORed back into the result at the end */ __m128 x_sign = _mm_and_ps(x, sign_mask); __m128 ax = _mm_andnot_ps(sign_mask, x); __m128 r, r2; __m128i q; range_reduce_f32x4(ax, r, r2, q); __m128 cos_r, sin_r; sincos_poly_f32x4(r, r2, cos_r, sin_r); __m128i odd = _mm_and_si128(q, _mm_set1_epi32(1)); __m128 use_cos = _mm_castsi128_ps(_mm_cmpeq_epi32(odd, _mm_set1_epi32(1))); __m128 result = _mm_blendv_ps(sin_r, cos_r, use_cos); __m128i need_neg = _mm_and_si128(q, _mm_set1_epi32(2)); __m128 neg_mask = _mm_castsi128_ps(_mm_slli_epi32(need_neg, 30)); result = _mm_xor_ps(result, neg_mask); // Apply original sign of x return _mm_xor_ps(result, x_sign); } // --- 128-bit sincos --- /* sincos_f32x4: performs one range reduction and one polynomial evaluation, then derives both outputs with the same selection and sign rules as cos_f32x4 and sin_f32x4. */ static constexpr void sincos_f32x4(__m128 x, __m128& out_sin, __m128& out_cos) { const __m128 sign_mask = _mm_set1_ps(-0.0f); __m128 x_sign = _mm_and_ps(x, sign_mask); __m128 ax = _mm_andnot_ps(sign_mask, x); __m128 r, r2; __m128i q; range_reduce_f32x4(ax, r, r2, q); __m128 cos_r, sin_r; sincos_poly_f32x4(r, r2, cos_r, sin_r); __m128i odd = _mm_and_si128(q, _mm_set1_epi32(1)); __m128 is_odd = _mm_castsi128_ps(_mm_cmpeq_epi32(odd, _mm_set1_epi32(1))); // cos: swap on odd, negate if (q+1)&2 out_cos = _mm_blendv_ps(cos_r, sin_r, is_odd); __m128i cos_neg = _mm_and_si128(_mm_add_epi32(q, _mm_set1_epi32(1)), _mm_set1_epi32(2)); out_cos = _mm_xor_ps(out_cos, _mm_castsi128_ps(_mm_slli_epi32(cos_neg, 30))); // sin: swap on odd, negate if q&2, apply input sign out_sin = _mm_blendv_ps(sin_r, cos_r, is_odd); __m128i sin_neg = 
_mm_and_si128(q, _mm_set1_epi32(2)); out_sin = _mm_xor_ps(out_sin, _mm_castsi128_ps(_mm_slli_epi32(sin_neg, 30))); out_sin = _mm_xor_ps(out_sin, x_sign); } // Reduce |x| into [-pi/4, pi/4], return reduced value and quadrant /* 256-bit counterpart of range_reduce_f32x4: same steps using the _mm256_ intrinsics. */ static constexpr void range_reduce_f32x8(__m256 ax, __m256& r, __m256& r2, __m256i& q) { __m256 fq = _mm256_round_ps(_mm256_mul_ps(ax, _mm256_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); q = _mm256_cvtps_epi32(fq); r = _mm256_sub_ps(ax, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_hi))); r = _mm256_sub_ps(r, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_lo))); r2 = _mm256_mul_ps(r, r); } /* 256-bit counterpart of sincos_poly_f32x4: same coefficients and Horner order using _mm256_fmadd_ps. */ static constexpr void sincos_poly_f32x8(__m256 r, __m256 r2, __m256& cos_r, __m256& sin_r) { cos_r = _mm256_fmadd_ps(_mm256_set1_ps(c10), r2, _mm256_set1_ps(c8)); cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c6)); cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c4)); cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c2)); cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c0)); sin_r = _mm256_fmadd_ps(_mm256_set1_ps(s9), r2, _mm256_set1_ps(s7)); sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s5)); sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s3)); sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s1)); sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(1.0f)); sin_r = _mm256_mul_ps(sin_r, r); } // cos(x): use cos_poly when q even, sin_poly when q odd; negate if (q+1)&2 static constexpr __m256 cos_f32x8(__m256 x) { const __m256 sign_mask = _mm256_set1_ps(-0.0f); __m256 ax = _mm256_andnot_ps(sign_mask, x); __m256 r, r2; __m256i q; range_reduce_f32x8(ax, r, r2, q); __m256 cos_r, sin_r; sincos_poly_f32x8(r, r2, cos_r, sin_r); __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1)); __m256 use_sin = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1))); __m256 result = _mm256_blendv_ps(cos_r, sin_r, use_sin); __m256i need_neg = _mm256_and_si256( _mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2)); __m256 
neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30)); return _mm256_xor_ps(result, neg_mask); } // sin(x): use sin_poly when q even, cos_poly when q odd; negate if q&2; respect input sign static constexpr __m256 sin_f32x8(__m256 x) { const __m256 sign_mask = _mm256_set1_ps(-0.0f); __m256 x_sign = _mm256_and_ps(x, sign_mask); __m256 ax = _mm256_andnot_ps(sign_mask, x); __m256 r, r2; __m256i q; range_reduce_f32x8(ax, r, r2, q); __m256 cos_r, sin_r; sincos_poly_f32x8(r, r2, cos_r, sin_r); __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1)); __m256 use_cos = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1))); __m256 result = _mm256_blendv_ps(sin_r, cos_r, use_cos); __m256i need_neg = _mm256_and_si256(q, _mm256_set1_epi32(2)); __m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30)); result = _mm256_xor_ps(result, neg_mask); // Apply original sign of x return _mm256_xor_ps(result, x_sign); } // // --- 512-bit helpers --- /* NOTE(review): no #ifdef __AVX512F__ guard is visible around the 512-bit helpers, unlike the definition of Max - confirm that is intended. Per the Intel Intrinsics Guide, _mm512_and_ps and _mm512_xor_ps (used further down) are listed under AVX512DQ rather than AVX512F - verify the target flags. */ /* range_reduce_f32x16: same steps as the narrower versions, with _mm512_roundscale_ps doing the rounding. */ static constexpr void range_reduce_f32x16(__m512 ax, __m512& r, __m512& r2, __m512i& q) { __m512 fq = _mm512_roundscale_ps(_mm512_mul_ps(ax, _mm512_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); q = _mm512_cvtps_epi32(fq); r = _mm512_sub_ps(ax, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_hi))); r = _mm512_sub_ps(r, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_lo))); r2 = _mm512_mul_ps(r, r); } /* 512-bit counterpart of sincos_poly_f32x4 using _mm512_fmadd_ps. */ static constexpr void sincos_poly_f32x16(__m512 r, __m512 r2, __m512& cos_r, __m512& sin_r) { cos_r = _mm512_fmadd_ps(_mm512_set1_ps(c10), r2, _mm512_set1_ps(c8)); cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c6)); cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c4)); cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c2)); cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c0)); sin_r = _mm512_fmadd_ps(_mm512_set1_ps(s9), r2, _mm512_set1_ps(s7)); sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s5)); sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s3)); sin_r = _mm512_fmadd_ps(sin_r, r2, 
_mm512_set1_ps(s1)); sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(1.0f)); sin_r = _mm512_mul_ps(sin_r, r); } /* cos_f32x16: same rule as cos_f32x4 (cosine polynomial for even q, sine polynomial for odd q, sign flipped when (q+1)&2), with the odd test expressed as a __mmask16 and the selection done by _mm512_mask_blend_ps. */ static constexpr __m512 cos_f32x16(__m512 x) { __m512 ax = _mm512_abs_ps(x); __m512 r, r2; __m512i q; range_reduce_f32x16(ax, r, r2, q); __m512 cos_r, sin_r; sincos_poly_f32x16(r, r2, cos_r, sin_r); __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1)); __m512 result = _mm512_mask_blend_ps(odd, cos_r, sin_r); __m512i need_neg = _mm512_and_si512( _mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2)); __m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30)); return _mm512_xor_ps(result, neg_mask); } /* sin_f32x16: same rule as sin_f32x4 (sine polynomial for even q, cosine polynomial for odd q, sign flipped when q&2, then the sign bit of x XORed in). */ static constexpr __m512 sin_f32x16(__m512 x) { __m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f)); __m512 ax = _mm512_abs_ps(x); __m512 r, r2; __m512i q; range_reduce_f32x16(ax, r, r2, q); __m512 cos_r, sin_r; sincos_poly_f32x16(r, r2, cos_r, sin_r); __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1)); __m512 result = _mm512_mask_blend_ps(odd, sin_r, cos_r); __m512i need_neg = _mm512_and_si512(q, _mm512_set1_epi32(2)); __m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30)); result = _mm512_xor_ps(result, neg_mask); return _mm512_xor_ps(result, x_sign); } // --- 256-bit sincos --- /* sincos_f32x8: 256-bit counterpart of sincos_f32x4; one reduction and one polynomial evaluation feed both outputs. */ static constexpr void sincos_f32x8(__m256 x, __m256& out_sin, __m256& out_cos) { const __m256 sign_mask = _mm256_set1_ps(-0.0f); __m256 x_sign = _mm256_and_ps(x, sign_mask); __m256 ax = _mm256_andnot_ps(sign_mask, x); __m256 r, r2; __m256i q; range_reduce_f32x8(ax, r, r2, q); __m256 cos_r, sin_r; sincos_poly_f32x8(r, r2, cos_r, sin_r); __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1)); __m256 is_odd = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1))); // cos: swap on odd, negate if (q+1)&2 out_cos = _mm256_blendv_ps(cos_r, sin_r, is_odd); __m256i cos_neg = _mm256_and_si256(_mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2)); out_cos = _mm256_xor_ps(out_cos, 
_mm256_castsi256_ps(_mm256_slli_epi32(cos_neg, 30))); // sin: swap on odd, negate if q&2, apply input sign out_sin = _mm256_blendv_ps(sin_r, cos_r, is_odd); __m256i sin_neg = _mm256_and_si256(q, _mm256_set1_epi32(2)); out_sin = _mm256_xor_ps(out_sin,_mm256_castsi256_ps(_mm256_slli_epi32(sin_neg, 30))); out_sin = _mm256_xor_ps(out_sin, x_sign); } // --- 512-bit sincos --- /* sincos_f32x16: 512-bit counterpart of sincos_f32x4, using a __mmask16 for the odd-quadrant selection. */ static constexpr void sincos_f32x16(__m512 x, __m512& out_sin, __m512& out_cos) { __m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f)); __m512 ax = _mm512_abs_ps(x); __m512 r, r2; __m512i q; range_reduce_f32x16(ax, r, r2, q); __m512 cos_r, sin_r; sincos_poly_f32x16(r, r2, cos_r, sin_r); __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1)); // cos out_cos = _mm512_mask_blend_ps(odd, cos_r, sin_r); __m512i cos_neg = _mm512_and_si512(_mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2)); out_cos = _mm512_xor_ps(out_cos, _mm512_castsi512_ps(_mm512_slli_epi32(cos_neg, 30))); // sin out_sin = _mm512_mask_blend_ps(odd, sin_r, cos_r); __m512i sin_neg = _mm512_and_si512(q, _mm512_set1_epi32(2)); out_sin = _mm512_xor_ps(out_sin, _mm512_castsi512_ps(_mm512_slli_epi32(sin_neg, 30))); out_sin = _mm512_xor_ps(out_sin, x_sign); } }; }