diff --git a/interfaces/Crafter.Math-Basic.cppm b/interfaces/Crafter.Math-Basic.cppm index 135d7ba..5f392a5 100755 --- a/interfaces/Crafter.Math-Basic.cppm +++ b/interfaces/Crafter.Math-Basic.cppm @@ -21,7 +21,6 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA export module Crafter.Math:Basic; import std; import :VectorF16; -import :VectorF32; namespace Crafter { template diff --git a/interfaces/Crafter.Math-Common.cppm b/interfaces/Crafter.Math-Common.cppm new file mode 100644 index 0000000..16041f4 --- /dev/null +++ b/interfaces/Crafter.Math-Common.cppm @@ -0,0 +1,409 @@ +module; +#ifdef __x86_64 +#include +#endif +export module Crafter.Math:Common; +import std; + +namespace Crafter { + export template + struct VectorF16; + + template + struct VectorBase { + template + friend struct VectorF16; + protected: + static consteval std::uint8_t GetAlingment() { + if(Len * Packing * sizeof(T) <= 16) { + return 16; + } else if(Len * Packing * sizeof(T) <= 32) { + return 32; + } else if(Len * Packing * sizeof(T) <= 64) { + return 64; + } + } + using VectorType = std::conditional_t< + (Len * Packing > 16), __m512h, + std::conditional_t<(Len * Packing > 8), __m256h, __m128h> + >; + + VectorType v; + + public: + + template + friend struct VectorBase; + + #ifdef __AVX512F__ + static constexpr std::uint8_t Max = 64; + #else + static constexpr std::uint8_t Max = 32; + #endif + static constexpr std::uint8_t MaxElement = Max/sizeof(T); + + static constexpr std::uint8_t AlignmentElement = GetAlingment()/sizeof(T); + static constexpr std::uint8_t Alignment = GetAlingment(); + static_assert(Len * Packing <= MaxElement, "Len * Packing exceeds MaxElement"); + + protected: + static constexpr std::uint8_t PerLane = 16/sizeof(T); + static consteval std::array GetAllTrue() { + std::array arr{}; + arr.fill(true); + return arr; + } + + template ShuffleValues> + static consteval bool CheckEpi32Shuffle() { + if constexpr (PerLane == 8) { + for(std::uint8_t i = 1; 
i < Len; i+=2) { + if(ShuffleValues[i-1] != ShuffleValues[i] - 1) { + return false; + } + } + } + for(std::uint8_t i = 0; i < Len; i++) { + for(std::uint8_t i2 = PerLane; i2 < Len; i2 += PerLane) { + if(ShuffleValues[i] != ShuffleValues[i2]) { + return false; + } + } + } + return true; + } + + template ShuffleValues> + static consteval bool CheckEpi8Shuffle() { + for(std::uint8_t i = 0; i < Len; i++) { + std::uint8_t lane = i / PerLane; + if(ShuffleValues[i] < lane * PerLane || ShuffleValues[i] > lane * PerLane + PerLane-1) { + return false; + } + } + return true; + } + + template ShuffleValues> + static consteval std::array GetShuffleMaskEpi8() { + std::array shuffleMask {{0}}; + for(std::uint8_t i2 = 0; i2 < Packing; i2++) { + for(std::uint8_t i = 0; i < Len; i++) { + shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T))] = ShuffleValues[i]*sizeof(T)+(i2*Len*sizeof(T)); + shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T)+1)] = ShuffleValues[i]*sizeof(T)+1+(i2*Len*sizeof(T)); + } + } + return shuffleMask; + } + + + template values> + static consteval std::array GetNegateMask() { + std::array mask{}; + + T high_bit = 0; + + if constexpr(sizeof(T) == 2) { + high_bit = std::bit_cast( + static_cast(1u << (std::numeric_limits::digits - 1)) + ); + } + + + + for (std::uint8_t i2 = 0; i2 < Packing; ++i2) { + for (std::uint8_t i = 0; i < Len; ++i) { + mask[i2 * Len + i] = values[i] ? 
high_bit : T(0); + } + } + + return mask; + } + + template + static constexpr std::array GetExtractLoMaskEpi8() { + std::array mask {{0}}; + for(std::uint8_t i2 = 0; i2 < Packing; i2++) { + for(std::uint8_t i = 0; i < ExtractLen; i++) { + mask[(i2*ExtractLen*sizeof(T))+(i*sizeof(T))] = i*sizeof(T)+(i2*Len*sizeof(T)); + mask[(i2*ExtractLen*sizeof(T))+(i*sizeof(T)+1)] = i*sizeof(T)+1+(i2*Len*sizeof(T)); + } + } + return mask; + } + + template + static consteval std::array GetExtractLoMaskEpi16() { + std::array mask{}; + for (std::uint16_t i2 = 0; i2 < Packing; i2++) { + for (std::uint16_t i = 0; i < ExtractLen; i++) { + mask[i2 * ExtractLen + i] = i + (i2 * Len); + } + } + return mask; + } + + template ShuffleValues> + static consteval std::uint8_t GetShuffleMaskEpi32() { + std::uint8_t mask = 0; + for(std::uint8_t i = 0; i < std::min(Len, std::uint8_t(8)); i+=2) { + mask = mask | (ShuffleValues[i] & 0b11) << i; + } + return mask; + } + + template ShuffleValues> + static consteval std::array GetPermuteMaskEpi16() { + std::array shuffleMask {{0}}; + for(std::uint8_t i2 = 0; i2 < Packing; i2++) { + for(std::uint8_t i = 0; i < Len; i++) { + shuffleMask[i2*Len+i] = ShuffleValues[i]+i2*Len; + } + } + return shuffleMask; + } + + template ShuffleValues> + static consteval std::uint8_t GetBlendMaskEpi16() requires (std::is_same_v){ + std::uint8_t mask = 0; + for (std::uint8_t i2 = 0; i2 < Packing; i2++) { + for (std::uint8_t i = 0; i < Len; i++) { + if (ShuffleValues[i]) { + mask |= (1u << (i2 * Len + i)); + } + } + } + return mask; + } + + template ShuffleValues> + static consteval std::uint16_t GetBlendMaskEpi16() requires (std::is_same_v){ + std::uint16_t mask = 0; + for (std::uint8_t i2 = 0; i2 < Packing; i2++) { + for (std::uint8_t i = 0; i < Len; i++) { + if (ShuffleValues[i]) { + mask |= (1u << (i2 * Len + i)); + } + } + } + return mask; + } + + template ShuffleValues> + static consteval std::uint32_t GetBlendMaskEpi16() requires (std::is_same_v){ + std::uint32_t mask 
= 0; + for (std::uint8_t i2 = 0; i2 < Packing; i2++) { + for (std::uint8_t i = 0; i < Len; i++) { + if (ShuffleValues[i]) { + mask |= (1u << (i2 * Len + i)); + } + } + } + return mask; + } + + static constexpr float two_over_pi = 0.6366197723675814f; + static constexpr float pi_over_2_hi = 1.5707855224609375f; // leading 17 bits of pi/2 (0x3fc90f80): fq * hi stays exact in float for |fq| < 128 + static constexpr float pi_over_2_lo = 1.0804333959e-05f; // pi/2 minus pi_over_2_hi, rounded to float (single-precision Cody-Waite tail) + + // Cos polynomial on [-pi/4, pi/4]: c0 + c2*r^2 + c4*r^4 + ... + static constexpr float c0 = 1.0f; + static constexpr float c2 = -0.4999999642372f; + static constexpr float c4 = 0.0416666418707f; + static constexpr float c6 = -0.0013888397720f; + static constexpr float c8 = 0.0000248015873f; + static constexpr float c10 = -0.0000002752258f; + + // Sin polynomial on [-pi/4, pi/4]: r * (1 + s1*r^2 + s3*r^4 + ...) + static constexpr float s1 = -0.1666666641831f; + static constexpr float s3 = 0.0083333293858f; + static constexpr float s5 = -0.0001984090955f; + static constexpr float s7 = 0.0000027526372f; + static constexpr float s9 = -0.0000000239013f; + + // Reduce |x| into [-pi/4, pi/4], return reduced value and quadrant + static constexpr void range_reduce_f32x8(__m256 ax, __m256& r, __m256& r2, __m256i& q) { + __m256 fq = _mm256_round_ps(_mm256_mul_ps(ax, _mm256_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + q = _mm256_cvtps_epi32(fq); + r = _mm256_sub_ps(ax, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_hi))); + r = _mm256_sub_ps(r, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_lo))); + r2 = _mm256_mul_ps(r, r); + } + + + static constexpr void sincos_poly_f32x8(__m256 r, __m256 r2, __m256& cos_r, __m256& sin_r) { + cos_r = _mm256_fmadd_ps(_mm256_set1_ps(c10), r2, _mm256_set1_ps(c8)); + cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c6)); + cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c4)); + cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c2)); + cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c0)); + + sin_r = _mm256_fmadd_ps(_mm256_set1_ps(s9), r2, 
_mm256_set1_ps(s7)); + sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s5)); + sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s3)); + sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s1)); + sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(1.0f)); + sin_r = _mm256_mul_ps(sin_r, r); + } + + // cos(x): use cos_poly when q even, sin_poly when q odd; negate if (q+1)&2 + static constexpr __m256 cos_f32x8(__m256 x) { + const __m256 sign_mask = _mm256_set1_ps(-0.0f); + __m256 ax = _mm256_andnot_ps(sign_mask, x); + + __m256 r, r2; __m256i q; + range_reduce_f32x8(ax, r, r2, q); + + __m256 cos_r, sin_r; + sincos_poly_f32x8(r, r2, cos_r, sin_r); + + __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1)); + __m256 use_sin = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1))); + __m256 result = _mm256_blendv_ps(cos_r, sin_r, use_sin); + + __m256i need_neg = _mm256_and_si256( + _mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2)); + __m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30)); + return _mm256_xor_ps(result, neg_mask); + } + + // sin(x): use sin_poly when q even, cos_poly when q odd; negate if q&2; respect input sign + static constexpr __m256 sin_f32x8(__m256 x) { + const __m256 sign_mask = _mm256_set1_ps(-0.0f); + __m256 x_sign = _mm256_and_ps(x, sign_mask); + __m256 ax = _mm256_andnot_ps(sign_mask, x); + + __m256 r, r2; __m256i q; + range_reduce_f32x8(ax, r, r2, q); + + __m256 cos_r, sin_r; + sincos_poly_f32x8(r, r2, cos_r, sin_r); + + __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1)); + __m256 use_cos = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1))); + __m256 result = _mm256_blendv_ps(sin_r, cos_r, use_cos); + + __m256i need_neg = _mm256_and_si256(q, _mm256_set1_epi32(2)); + __m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30)); + result = _mm256_xor_ps(result, neg_mask); + + // Apply original sign of x + return _mm256_xor_ps(result, x_sign); + } + + // // --- 512-bit 
helpers --- + + static constexpr void range_reduce_f32x16(__m512 ax, __m512& r, __m512& r2, __m512i& q) { + __m512 fq = _mm512_roundscale_ps(_mm512_mul_ps(ax, _mm512_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + q = _mm512_cvtps_epi32(fq); + r = _mm512_sub_ps(ax, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_hi))); + r = _mm512_sub_ps(r, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_lo))); + r2 = _mm512_mul_ps(r, r); + } + + static constexpr void sincos_poly_f32x16(__m512 r, __m512 r2, __m512& cos_r, __m512& sin_r) { + cos_r = _mm512_fmadd_ps(_mm512_set1_ps(c10), r2, _mm512_set1_ps(c8)); + cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c6)); + cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c4)); + cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c2)); + cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c0)); + + sin_r = _mm512_fmadd_ps(_mm512_set1_ps(s9), r2, _mm512_set1_ps(s7)); + sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s5)); + sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s3)); + sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s1)); + sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(1.0f)); + sin_r = _mm512_mul_ps(sin_r, r); + } + + static constexpr __m512 cos_f32x16(__m512 x) { + __m512 ax = _mm512_abs_ps(x); + + __m512 r, r2; __m512i q; + range_reduce_f32x16(ax, r, r2, q); + + __m512 cos_r, sin_r; + sincos_poly_f32x16(r, r2, cos_r, sin_r); + + __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1)); + __m512 result = _mm512_mask_blend_ps(odd, cos_r, sin_r); + + __m512i need_neg = _mm512_and_si512( + _mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2)); + __m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30)); + return _mm512_xor_ps(result, neg_mask); + } + + static constexpr __m512 sin_f32x16(__m512 x) { + __m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f)); + __m512 ax = _mm512_abs_ps(x); + + __m512 r, r2; __m512i q; + range_reduce_f32x16(ax, r, r2, q); + + __m512 cos_r, 
sin_r; + sincos_poly_f32x16(r, r2, cos_r, sin_r); + + __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1)); + __m512 result = _mm512_mask_blend_ps(odd, sin_r, cos_r); + + __m512i need_neg = _mm512_and_si512(q, _mm512_set1_epi32(2)); + __m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30)); + result = _mm512_xor_ps(result, neg_mask); + + return _mm512_xor_ps(result, x_sign); + } + + // --- 256-bit sincos --- + static constexpr void sincos_f32x8(__m256 x, __m256& out_sin, __m256& out_cos) { + const __m256 sign_mask = _mm256_set1_ps(-0.0f); + __m256 x_sign = _mm256_and_ps(x, sign_mask); + __m256 ax = _mm256_andnot_ps(sign_mask, x); + + __m256 r, r2; __m256i q; + range_reduce_f32x8(ax, r, r2, q); + + __m256 cos_r, sin_r; + sincos_poly_f32x8(r, r2, cos_r, sin_r); + + __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1)); + __m256 is_odd = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1))); + + // cos: swap on odd, negate if (q+1)&2 + out_cos = _mm256_blendv_ps(cos_r, sin_r, is_odd); + __m256i cos_neg = _mm256_and_si256(_mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2)); + out_cos = _mm256_xor_ps(out_cos, _mm256_castsi256_ps(_mm256_slli_epi32(cos_neg, 30))); + + // sin: swap on odd, negate if q&2, apply input sign + out_sin = _mm256_blendv_ps(sin_r, cos_r, is_odd); + __m256i sin_neg = _mm256_and_si256(q, _mm256_set1_epi32(2)); + out_sin = _mm256_xor_ps(out_sin,_mm256_castsi256_ps(_mm256_slli_epi32(sin_neg, 30))); + out_sin = _mm256_xor_ps(out_sin, x_sign); + } + + // --- 512-bit sincos --- + static constexpr void sincos_f32x16(__m512 x, __m512& out_sin, __m512& out_cos) { + __m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f)); + __m512 ax = _mm512_abs_ps(x); + + __m512 r, r2; __m512i q; + range_reduce_f32x16(ax, r, r2, q); + + __m512 cos_r, sin_r; + sincos_poly_f32x16(r, r2, cos_r, sin_r); + + __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1)); + + // cos + out_cos = _mm512_mask_blend_ps(odd, 
cos_r, sin_r); + __m512i cos_neg = _mm512_and_si512(_mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2)); + out_cos = _mm512_xor_ps(out_cos, _mm512_castsi512_ps(_mm512_slli_epi32(cos_neg, 30))); + + // sin + out_sin = _mm512_mask_blend_ps(odd, sin_r, cos_r); + __m512i sin_neg = _mm512_and_si512(q, _mm512_set1_epi32(2)); + out_sin = _mm512_xor_ps(out_sin, _mm512_castsi512_ps(_mm512_slli_epi32(sin_neg, 30))); + out_sin = _mm512_xor_ps(out_sin, x_sign); + } + }; +} \ No newline at end of file diff --git a/interfaces/Crafter.Math-VectorF16.cppm b/interfaces/Crafter.Math-VectorF16.cppm index ee3a207..fa43358 100755 --- a/interfaces/Crafter.Math-VectorF16.cppm +++ b/interfaces/Crafter.Math-VectorF16.cppm @@ -22,273 +22,125 @@ module; #endif export module Crafter.Math:VectorF16; import std; -import :Vector; +import :Common; #ifdef __AVX512FP16__ namespace Crafter { - export template - struct VectorF16 { - private: - static consteval std::uint8_t GetAlingment() { - if(Len * Packing <= 8) { - return 8; - } else if(Len * Packing <= 16) { - return 16; - } else if(Len * Packing <= 32) { - return 32; - } - } - using VectorType = std::conditional_t< - (Len * Packing > 16), __m512h, - std::conditional_t<(Len * Packing > 8), __m256h, __m128h> - >; - - VectorType v; - public: - static constexpr std::uint32_t MaxSize = 32; - static constexpr std::uint8_t Alignment = GetAlingment(); - static_assert(Len * Packing <= MaxSize, "Len * Packing exceeds MaxSize"); - private: - - template values> - static consteval std::array GetNegateMask() { - std::array mask{0}; - for(std::uint8_t i2 = 0; i2 < Packing; i2++) { - for(std::uint8_t i = 0; i < Len; i++) { - if(values[i]) { - mask[i2*Len+i] = 0b1000000000000000; - } else { - mask[i2*Len+i] = 0; - } - } - } - return mask; - } - - static consteval std::array GetNegateMaskAll() { - std::array mask{0}; - for(std::uint8_t i = 0; i < Packing*Len; i++) { - mask[i] = 0b1000000000000000; - } - return mask; - } - - template ShuffleValues> - 
static consteval bool GetShuffleMaskEpi32() { - std::uint8_t mask = 0; - for(std::uint8_t i = 0; i < std::min(Len, std::uint32_t(8)); i+=2) { - mask = mask | (ShuffleValues[i] & 0b11) << i; - } - return mask; - } - - template ShuffleValues> - static consteval std::array::Alignment> GetPermuteMaskEpi16() { - std::array::Alignment> shuffleMask {{0}}; - for(std::uint8_t i2 = 0; i2 < Packing; i2++) { - for(std::uint8_t i = 0; i < Len; i++) { - shuffleMask[i2*Len+i] = ShuffleValues[i]+i2*Len; - } - } - return shuffleMask; - } - - static consteval std::array GetAllTrue() { - std::array arr{}; - arr.fill(true); - return arr; - } - - template ShuffleValues> - static consteval bool CheckEpi32Shuffle() { - for(std::uint8_t i = 1; i < Len; i+=2) { - if(ShuffleValues[i-1] != ShuffleValues[i] - 1) { - return false; - } - } - for(std::uint8_t i = 0; i < Len; i++) { - for(std::uint8_t i2 = 0; i2 < Len; i2 += 8) { - if(ShuffleValues[i] != ShuffleValues[i2]) { - return false; - } - } - } - return true; - } - - template ShuffleValues> - static consteval bool CheckEpi8Shuffle() { - for(std::uint8_t i = 0; i < Len; i++) { - std::uint8_t lane = i / 8; - if(ShuffleValues[i] < lane * 8 || ShuffleValues[i] > lane * 8 + 7) { - return false; - } - } - return true; - } - - template ShuffleValues> - static consteval std::uint8_t GetBlendMaskEpi16() requires (std::is_same_v){ - std::uint8_t mask = 0; - for (std::uint8_t i2 = 0; i2 < Packing; i2++) { - for (std::uint8_t i = 0; i < Len; i++) { - if (ShuffleValues[i]) { - mask |= (1u << (i2 * Len + i)); - } - } - } - return mask; - } - - template ShuffleValues> - static consteval std::uint16_t GetBlendMaskEpi16() requires (std::is_same_v){ - std::uint16_t mask = 0; - for (std::uint8_t i2 = 0; i2 < Packing; i2++) { - for (std::uint8_t i = 0; i < Len; i++) { - if (ShuffleValues[i]) { - mask |= (1u << (i2 * Len + i)); - } - } - } - return mask; - } - - template ShuffleValues> - static consteval std::uint32_t GetBlendMaskEpi16() requires 
(std::is_same_v){ - std::uint32_t mask = 0; - for (std::uint8_t i2 = 0; i2 < Packing; i2++) { - for (std::uint8_t i = 0; i < Len; i++) { - if (ShuffleValues[i]) { - mask |= (1u << (i2 * Len + i)); - } - } - } - return mask; - } - - template ShuffleValues> - static consteval std::array::Alignment*2> GetShuffleMaskEpi8() { - std::array::Alignment*2> shuffleMask {{0}}; - for(std::uint8_t i2 = 0; i2 < Packing; i2++) { - for(std::uint8_t i = 0; i < Len; i++) { - shuffleMask[(i2*Len*2)+(i*2)] = ShuffleValues[i]*2+(i2*Len*2); - shuffleMask[(i2*Len*2)+(i*2+1)] = ShuffleValues[i]*2+1+(i2*Len*2); - } - } - return shuffleMask; - } - public: - template - friend class VectorF16; + export template + struct VectorF16 : public VectorBase { + template + friend struct VectorF16; constexpr VectorF16() = default; - constexpr VectorF16(VectorType v) : v(v) {} + constexpr VectorF16(VectorBase::VectorType v) { + this->v = v; + } constexpr VectorF16(const _Float16* vB) { Load(vB); }; constexpr VectorF16(_Float16 val) { - if constexpr(std::is_same_v) { - v = _mm_set1_ph(val); - } else if constexpr(std::is_same_v) { - v = _mm256_set1_ph(val); + if constexpr(std::is_same_v::VectorType, __m128h>) { + this->v = _mm_set1_ph(val); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + this->v = _mm256_set1_ph(val); } else { - v = _mm512_set1_ph(val); + this->v = _mm512_set1_ph(val); } }; constexpr void Load(const _Float16* vB) { - if constexpr(std::is_same_v) { - v = _mm_loadu_ph(vB); - } else if constexpr(std::is_same_v) { - v = _mm256_loadu_ph(vB); + if constexpr(std::is_same_v::VectorType, __m128h>) { + this->v = _mm_loadu_ph(vB); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + this->v = _mm256_loadu_ph(vB); } else { - v = _mm512_loadu_ph(vB); + this->v = _mm512_loadu_ph(vB); } } constexpr void Store(_Float16* vB) const { - if constexpr(std::is_same_v) { - _mm_storeu_ph(vB, v); - } else if constexpr(std::is_same_v) { - _mm256_storeu_ph(vB, v); + if 
constexpr(std::is_same_v::VectorType, __m128h>) { + _mm_storeu_ph(vB, this->v); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + _mm256_storeu_ph(vB, this->v); } else { - _mm512_storeu_ph(vB, v); + _mm512_storeu_ph(vB, this->v); } } - constexpr Vector<_Float16, Len*Packing, Alignment> Store() const { - Vector<_Float16, Len*Packing, Alignment> returnVec; - Store(returnVec.v); - return returnVec; + constexpr std::array<_Float16, VectorBase::AlignmentElement> Store() const { + std::array<_Float16, VectorBase::AlignmentElement> returnArray; + Store(returnArray.data()); + return returnArray; } - template + template constexpr operator VectorF16() const { if constexpr (Len == BLen) { - if constexpr(std::is_same_v && std::is_same_v::VectorType, __m128h>) { - return VectorF16(_mm256_castph256_ph128(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m128h>) { - return VectorF16(_mm512_castph512_ph128(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m256h>) { - return VectorF16(_mm512_castph512_ph256(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m256h>) { - return VectorF16(_mm256_castph128_ph256(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m512h>) { - return VectorF16(_mm512_castph128_ph512(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m512h>) { - return VectorF16(_mm512_castph256_ph512(v)); + if constexpr(std::is_same_v::VectorType, __m256h> && std::is_same_v::VectorType, __m128h>) { + return VectorF16(_mm256_castph256_ph128(this->v)); + } else if constexpr(std::is_same_v::VectorType, __m512h> && std::is_same_v::VectorType, __m128h>) { + return VectorF16(_mm512_castph512_ph128(this->v)); + } else if constexpr(std::is_same_v::VectorType, __m512h> && std::is_same_v::VectorType, __m256h>) { + return VectorF16(_mm512_castph512_ph256(this->v)); + } else if constexpr(std::is_same_v::VectorType, __m128h> && 
std::is_same_v::VectorType, __m256h>) { + return VectorF16(_mm256_castph128_ph256(this->v)); + } else if constexpr(std::is_same_v::VectorType, __m128h> && std::is_same_v::VectorType, __m512h>) { + return VectorF16(_mm512_castph128_ph512(this->v)); + } else if constexpr(std::is_same_v::VectorType, __m256h> && std::is_same_v::VectorType, __m512h>) { + return VectorF16(_mm512_castph256_ph512(this->v)); } else { - return VectorF16(v); + return VectorF16(this->v); } } else if constexpr (BLen <= Len) { return this->template ExtractLo(); } else { - if constexpr(std::is_same_v::VectorType, __m128h>) { - if constexpr(std::is_same_v) { - constexpr std::array::Alignment*2> shuffleMask = GetExtractLoMaskEpi8(); + if constexpr(std::is_same_v::VectorType, __m128h>) { + if constexpr(std::is_same_v::VectorType, __m128h>) { + constexpr std::array::Alignment> shuffleMask = VectorBase::template GetExtractLoMaskEpi8(); __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); - return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(v), shuffleVec))); - } else if constexpr(std::is_same_v) { - constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); + return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(this->v), shuffleVec))); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + constexpr std::array::AlignmentElement> permMask =VectorBase::template GetExtractLoMaskEpi16(); __m256i permIdx = _mm256_loadu_epi16(permMask.data()); - __m256i result = _mm256_permutexvar_epi16(permIdx, _mm_castph_si256(v)); + __m256i result = _mm256_permutexvar_epi16(permIdx, _mm_castph_si256(this->v)); return VectorF16(_mm_castsi128_ph(_mm256_castsi256_si128(result))); } else { - constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi16(); __m512i permIdx = _mm512_loadu_epi16(permMask.data()); - __m512i result = _mm512_permutexvar_epi16(permIdx, 
_mm512_castph_si512(v)); + __m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(this->v)); return VectorF16(_mm_castsi128_ph(_mm512_castsi512_si128(result))); } - } else if constexpr(std::is_same_v::VectorType, __m256h>) { - if constexpr(std::is_same_v) { - constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + if constexpr(std::is_same_v::VectorType, __m128h>) { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi16(); __m256i permIdx = _mm256_loadu_epi16(permMask.data()); - __m256i result = _mm256_permutexvar_epi16(permIdx, _mm256_castsi128_si256(_mm_castph_si128(v))); + __m256i result = _mm256_permutexvar_epi16(permIdx, _mm256_castsi128_si256(_mm_castph_si128(this->v))); return VectorF16(_mm256_castsi256_ph(result)); - } else if constexpr(std::is_same_v) { - constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi16(); __m256i permIdx = _mm256_loadu_epi16(permMask.data()); - __m256i result = _mm256_permutexvar_epi16(permIdx, _mm256_castph_si256(v)); + __m256i result = _mm256_permutexvar_epi16(permIdx, _mm256_castph_si256(this->v)); return VectorF16(_mm256_castsi256_ph(result)); } else { - constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi16(); __m256i permIdx = _mm512_loadu_epi16(permMask.data()); - __m256i result = _mm512_permutexvar_epi16(permIdx, _mm512_castsi512_si256(_mm512_castph_si512(v))); + __m256i result = _mm512_permutexvar_epi16(permIdx, _mm512_castsi512_si256(_mm512_castph_si512(this->v))); return VectorF16(_mm256_castsi256_ph(result)); } } else { - if constexpr(std::is_same_v) { - constexpr std::array::Alignment> permMask = 
GetExtractLoMaskEpi16(); + if constexpr(std::is_same_v::VectorType, __m128h>) { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi16(); __m512i permIdx = _mm512_loadu_epi16(permMask.data()); - __m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castsi128_si512(_mm_castph_si128(v))); + __m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castsi128_si512(_mm_castph_si128(this->v))); return VectorF16(_mm512_castsi512_ph(result)); - } else if constexpr(std::is_same_v) { - constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi16(); __m512i permIdx = _mm512_loadu_epi16(permMask.data()); - __m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castsi256_si512(_mm256_castph_si256(v))); + __m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castsi256_si512(_mm256_castph_si256(this->v))); return VectorF16(_mm512_castsi512_ph(result)); } else { - constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi16(); __m512i permIdx = _mm512_loadu_epi16(permMask.data()); - __m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(v)); + __m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(this->v)); return VectorF16(_mm512_castsi512_ph(result)); } } @@ -296,83 +148,83 @@ namespace Crafter { } constexpr VectorF16 operator+(VectorF16 b) const { - if constexpr(std::is_same_v) { - return VectorF16(_mm_add_ph(v, b.v)); - } else if constexpr(std::is_same_v) { - return VectorF16(_mm256_add_ph(v, b.v)); + if constexpr(std::is_same_v::VectorType, __m128h>) { + return VectorF16(_mm_add_ph(this->v, b.v)); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + return VectorF16(_mm256_add_ph(this->v, b.v)); } else { - return 
VectorF16(_mm512_add_ph(v, b.v)); + return VectorF16(_mm512_add_ph(this->v, b.v)); } } constexpr VectorF16 operator-(VectorF16 b) const { - if constexpr(std::is_same_v) { - return VectorF16(_mm_sub_ph(v, b.v)); - } else if constexpr(std::is_same_v) { - return VectorF16(_mm256_sub_ph(v, b.v)); + if constexpr(std::is_same_v::VectorType, __m128h>) { + return VectorF16(_mm_sub_ph(this->v, b.v)); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + return VectorF16(_mm256_sub_ph(this->v, b.v)); } else { - return VectorF16(_mm512_sub_ph(v, b.v)); + return VectorF16(_mm512_sub_ph(this->v, b.v)); } } constexpr VectorF16 operator*(VectorF16 b) const { - if constexpr(std::is_same_v) { - return VectorF16(_mm_mul_ph(v, b.v)); - } else if constexpr(std::is_same_v) { - return VectorF16(_mm256_mul_ph(v, b.v)); + if constexpr(std::is_same_v::VectorType, __m128h>) { + return VectorF16(_mm_mul_ph(this->v, b.v)); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + return VectorF16(_mm256_mul_ph(this->v, b.v)); } else { - return VectorF16(_mm512_mul_ph(v, b.v)); + return VectorF16(_mm512_mul_ph(this->v, b.v)); } } constexpr VectorF16 operator/(VectorF16 b) const { - if constexpr(std::is_same_v) { - return VectorF16(_mm_div_ph(v, b.v)); - } else if constexpr(std::is_same_v) { - return VectorF16(_mm256_div_ph(v, b.v)); + if constexpr(std::is_same_v::VectorType, __m128h>) { + return VectorF16(_mm_div_ph(this->v, b.v)); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + return VectorF16(_mm256_div_ph(this->v, b.v)); } else { - return VectorF16(_mm512_div_ph(v, b.v)); + return VectorF16(_mm512_div_ph(this->v, b.v)); } } constexpr void operator+=(VectorF16 b) { - if constexpr(std::is_same_v) { - v = _mm_add_ph(v, b.v); - } else if constexpr(std::is_same_v) { - v = _mm256_add_ph(v, b.v); + if constexpr(std::is_same_v::VectorType, __m128h>) { + this->v = _mm_add_ph(this->v, b.v); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + this->v = 
_mm256_add_ph(this->v, b.v); } else { - v = _mm512_add_ph(v, b.v); + this->v = _mm512_add_ph(this->v, b.v); } } constexpr void operator-=(VectorF16 b) { - if constexpr(std::is_same_v) { - v = _mm_sub_ph(v, b.v); - } else if constexpr(std::is_same_v) { - v = _mm256_sub_ph(v, b.v); + if constexpr(std::is_same_v::VectorType, __m128h>) { + this->v = _mm_sub_ph(this->v, b.v); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + this->v = _mm256_sub_ph(this->v, b.v); } else { - v = _mm512_sub_ph(v, b.v); + this->v = _mm512_sub_ph(this->v, b.v); } } constexpr void operator*=(VectorF16 b) { - if constexpr(std::is_same_v) { - v = _mm_mul_ph(v, b.v); - } else if constexpr(std::is_same_v) { - v = _mm256_mul_ph(v, b.v); + if constexpr(std::is_same_v::VectorType, __m128h>) { + this->v = _mm_mul_ph(this->v, b.v); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + this->v = _mm256_mul_ph(this->v, b.v); } else { - v = _mm512_mul_ph(v, b.v); + this->v = _mm512_mul_ph(this->v, b.v); } } constexpr void operator/=(VectorF16 b) { - if constexpr(std::is_same_v) { - v = _mm_div_ph(v, b.v); - } else if constexpr(std::is_same_v) { - v = _mm256_div_ph(v, b.v); + if constexpr(std::is_same_v::VectorType, __m128h>) { + this->v = _mm_div_ph(this->v, b.v); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + this->v = _mm256_div_ph(this->v, b.v); } else { - v = _mm512_div_ph(v, b.v); + this->v = _mm512_div_ph(this->v, b.v); } } @@ -417,109 +269,86 @@ namespace Crafter { } constexpr VectorF16 operator-(){ - return Negate(); + return Negate::GetAllTrue()>(); } constexpr bool operator==(VectorF16 b) const { - if constexpr(std::is_same_v) { - return _mm_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) == 255; - } else if constexpr(std::is_same_v) { - return _mm256_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) == 65535; + if constexpr(std::is_same_v::VectorType, __m128h>) { + return _mm_cmp_ph_mask(this->v, b.v, _CMP_EQ_OQ) == 255; + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + 
return _mm256_cmp_ph_mask(this->v, b.v, _CMP_EQ_OQ) == 65535; } else { - return _mm512_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) == 4294967295; + return _mm512_cmp_ph_mask(this->v, b.v, _CMP_EQ_OQ) == 4294967295; } } constexpr bool operator!=(VectorF16 b) const { - if constexpr(std::is_same_v) { - return _mm_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) != 255; - } else if constexpr(std::is_same_v) { - return _mm256_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) != 65535; + if constexpr(std::is_same_v::VectorType, __m128h>) { + return _mm_cmp_ph_mask(this->v, b.v, _CMP_EQ_OQ) != 255; + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + return _mm256_cmp_ph_mask(this->v, b.v, _CMP_EQ_OQ) != 65535; } else { - return _mm512_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) != 4294967295; + return _mm512_cmp_ph_mask(this->v, b.v, _CMP_EQ_OQ) != 4294967295; } } - template - static consteval std::array::Alignment*2> GetExtractLoMaskEpi8() { - std::array::Alignment*2> mask {{0}}; - for(std::uint8_t i2 = 0; i2 < Packing; i2++) { - for(std::uint8_t i = 0; i < ExtractLen; i++) { - mask[(i2*ExtractLen*2)+(i*2)] = i*2+(i2*Len*2); - mask[(i2*ExtractLen*2)+(i*2+1)] = i*2+1+(i2*Len*2); - } - } - return mask; - } - - template - static consteval std::array::Alignment> GetExtractLoMaskEpi16() { - std::array::Alignment> mask{}; - for (std::uint16_t i2 = 0; i2 < Packing; i2++) { - for (std::uint16_t i = 0; i < ExtractLen; i++) { - mask[i2 * ExtractLen + i] = i + (i2 * Len); - } - } - return mask; - } - template constexpr VectorF16 ExtractLo() const { if constexpr(Packing > 1) { - if constexpr(std::is_same_v) { - constexpr std::array::Alignment*2> shuffleMask = GetExtractLoMaskEpi8(); + if constexpr(std::is_same_v::VectorType, __m128h>) { + constexpr std::array::Alignment> shuffleMask = VectorBase::template GetExtractLoMaskEpi8(); __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); - return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(v), shuffleVec))); - } else if constexpr(std::is_same_v) { - constexpr 
std::array::Alignment> permMask = GetExtractLoMaskEpi16(); + return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(this->v), shuffleVec))); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi16(); __m256i permIdx = _mm256_loadu_epi16(permMask.data()); - __m256i result = _mm256_permutexvar_epi16(permIdx, _mm256_castph_si256(v)); - if constexpr(std::is_same_v::VectorType, __m128h>) { + __m256i result = _mm256_permutexvar_epi16(permIdx, _mm256_castph_si256(this->v)); + if constexpr(std::is_same_v::VectorType, __m128h>) { return VectorF16(_mm256_castph256_ph128(_mm256_castsi256_ph(result))); } else { return VectorF16(_mm256_castsi256_ph(result)); } } else { - constexpr std::array::Alignment> permMask = GetExtractLoMaskEpi16(); + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi16(); __m512i permIdx = _mm512_loadu_epi16(permMask.data()); - __m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(v)); - if constexpr(std::is_same_v::VectorType, __m128h>) { + __m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(this->v)); + if constexpr(std::is_same_v::VectorType, __m128h>) { return VectorF16(_mm512_castph512_ph128(_mm512_castsi512_ph(result))); - } else if constexpr(std::is_same_v::VectorType, __m256h>) { + } else if constexpr(std::is_same_v::VectorType, __m256h>) { return VectorF16(_mm512_castph512_ph256(_mm512_castsi512_ph(result))); } else { return VectorF16(_mm512_castsi512_ph(result)); } } } else { - if constexpr(std::is_same_v && std::is_same_v::VectorType, __m128h>) { - return VectorF16(_mm256_castph256_ph128(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m128h>) { - return VectorF16(_mm512_castph512_ph128(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m256h>) { - return VectorF16(_mm512_castph512_ph256(v)); + if 
constexpr(std::is_same_v::VectorType, __m256h> && std::is_same_v::VectorType, __m128h>) { + return VectorF16(_mm256_castph256_ph128(this->v)); + } else if constexpr(std::is_same_v::VectorType, __m512h> && std::is_same_v::VectorType, __m128h>) { + return VectorF16(_mm512_castph512_ph128(this->v)); + } else if constexpr(std::is_same_v::VectorType, __m512h> && std::is_same_v::VectorType, __m256h>) { + return VectorF16(_mm512_castph512_ph256(this->v)); } else { - return VectorF16(v); + return VectorF16(this->v); } } } constexpr void Normalize() requires(Packing == 1) { - if constexpr(std::is_same_v) { + if constexpr(std::is_same_v::VectorType, __m128h>) { _Float16 dot = LengthSq(); __m128h vec = _mm_set1_ph(dot); __m128h sqrt = _mm_sqrt_ph(vec); - v = _mm_div_ph(v, sqrt); - } else if constexpr(std::is_same_v) { + this->v = _mm_div_ph(this->v, sqrt); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { _Float16 dot = LengthSq(); __m256h vec = _mm256_set1_ph(dot); __m256h sqrt = _mm256_sqrt_ph(vec); - v = _mm256_div_ph(v, sqrt); + this->v = _mm256_div_ph(this->v, sqrt); } else { _Float16 dot = LengthSq(); __m512h vec = _mm512_set1_ph(dot); __m512h sqrt = _mm512_sqrt_ph(vec); - v = _mm512_div_ph(v, sqrt); + this->v = _mm512_div_ph(this->v, sqrt); } } @@ -533,74 +362,77 @@ namespace Crafter { } constexpr VectorF16 Cos() { - if constexpr (std::is_same_v) { - __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v)); - wide = cos_f32x8(wide); + if constexpr (std::is_same_v::VectorType, __m128h>) { + __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(this->v)); + wide = VectorBase::cos_f32x8(wide); return VectorF16( _mm_castsi128_ph(_mm256_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); - } else if constexpr (std::is_same_v) { - __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v)); - wide = cos_f32x16(wide); + } else if constexpr (std::is_same_v::VectorType, __m256h>) { + __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(this->v)); + wide = VectorBase::cos_f32x16(wide); return 
VectorF16( _mm256_castsi256_ph(_mm512_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); } else { - __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v)); - __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1); - __m256i lo_ph = _mm512_cvtps_ph(cos_f32x16(_mm512_cvtph_ps(lo)), _MM_FROUND_TO_NEAREST_INT); - __m256i hi_ph = _mm512_cvtps_ph(cos_f32x16(_mm512_cvtph_ps(hi)), _MM_FROUND_TO_NEAREST_INT); - return VectorF16( - _mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1))); + __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(this->v)); + __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(this->v), 1); + __m512i cosLo =VectorBase::cos_f32x16(_mm512_cvtph_ps(lo)); + __m512i cosHi =VectorBase::cos_f32x16(_mm512_cvtph_ps(hi)); + __m256i lo_ph = _mm512_cvtps_ph(cosLo, _MM_FROUND_TO_NEAREST_INT); + __m256i hi_ph = _mm512_cvtps_ph(cosHi, _MM_FROUND_TO_NEAREST_INT); + return VectorF16(_mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1))); } } constexpr VectorF16 Sin() { - if constexpr (std::is_same_v) { - __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v)); - wide = sin_f32x8(wide); + if constexpr (std::is_same_v::VectorType, __m128h>) { + __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(this->v)); + wide = VectorBase::sin_f32x8(wide); return VectorF16(_mm_castsi128_ph(_mm256_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); - } else if constexpr (std::is_same_v) { - __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v)); - wide = sin_f32x16(wide); + } else if constexpr (std::is_same_v::VectorType, __m256h>) { + __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(this->v)); + wide = VectorBase::sin_f32x16(wide); return VectorF16(_mm256_castsi256_ph(_mm512_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT))); } else { - __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v)); - __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1); - __m256i lo_ph = _mm512_cvtps_ph(sin_f32x16(_mm512_cvtph_ps(lo)), 
_MM_FROUND_TO_NEAREST_INT); - __m256i hi_ph = _mm512_cvtps_ph(sin_f32x16(_mm512_cvtph_ps(hi)), _MM_FROUND_TO_NEAREST_INT); + __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(this->v)); + __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(this->v), 1); + __m512i loSin = VectorBase::sin_f32x16(_mm512_cvtph_ps(lo)); + __m512i hiSin = VectorBase::sin_f32x16(_mm512_cvtph_ps(hi)); + __m256i lo_ph = _mm512_cvtps_ph(loSin, _MM_FROUND_TO_NEAREST_INT); + __m256i hi_ph = _mm512_cvtps_ph(hiSin, _MM_FROUND_TO_NEAREST_INT); return VectorF16(_mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1))); } } std::tuple, VectorF16> SinCos() { - if constexpr (std::is_same_v) { - __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(v)); + if constexpr (std::is_same_v::VectorType, __m128h>) { + __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(this->v)); __m256 s, c; - sincos_f32x8(wide, s, c); + VectorBase::sincos_f32x8(wide, s, c); return { VectorF16(_mm_castsi128_ph(_mm256_cvtps_ph(s, _MM_FROUND_TO_NEAREST_INT))), VectorF16(_mm_castsi128_ph(_mm256_cvtps_ph(c, _MM_FROUND_TO_NEAREST_INT))) }; - } else if constexpr (std::is_same_v) { - __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(v)); + } else if constexpr (std::is_same_v::VectorType, __m256h>) { + __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(this->v)); __m512 s, c; - sincos_f32x16(wide, s, c); + VectorBase::sincos_f32x16(wide, s, c); return { VectorF16(_mm256_castsi256_ph(_mm512_cvtps_ph(s, _MM_FROUND_TO_NEAREST_INT))), VectorF16(_mm256_castsi256_ph(_mm512_cvtps_ph(c, _MM_FROUND_TO_NEAREST_INT))) }; } else { - __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(v)); - __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(v), 1); + __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(this->v)); + __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(this->v), 1); __m512 s_lo, c_lo, s_hi, c_hi; - sincos_f32x16(_mm512_cvtph_ps(lo), s_lo, c_lo); - sincos_f32x16(_mm512_cvtph_ps(hi), 
s_hi, c_hi); + VectorBase::sincos_f32x16(_mm512_cvtph_ps(lo), s_lo, c_lo); + VectorBase::sincos_f32x16(_mm512_cvtph_ps(hi), s_hi, c_hi); auto pack = [](__m256i lo_ph, __m256i hi_ph) { return _mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1)); @@ -614,20 +446,20 @@ namespace Crafter { template values> constexpr VectorF16 Negate() { - std::array mask = GetNegateMask(); - if constexpr(std::is_same_v) { - return VectorF16(_mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(v), _mm_loadu_epi16(mask.data())))); - } else if constexpr(std::is_same_v) { - return VectorF16(_mm256_castsi256_ph(_mm256_xor_si256(_mm256_castph_si256(v), _mm256_loadu_epi16(mask.data())))); + std::array<_Float16, VectorBase::AlignmentElement> mask = VectorBase::template GetNegateMask(); + if constexpr(std::is_same_v::VectorType, __m128h>) { + return VectorF16(_mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(this->v), _mm_loadu_epi16(mask.data())))); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + return VectorF16(_mm256_castsi256_ph(_mm256_xor_si256(_mm256_castph_si256(this->v), _mm256_loadu_epi16(mask.data())))); } else { - return VectorF16(_mm512_castsi512_ph(_mm512_xor_si512(_mm512_castph_si512(v), _mm512_loadu_epi16(mask.data())))); + return VectorF16(_mm512_castsi512_ph(_mm512_xor_si512(_mm512_castph_si512(this->v), _mm512_loadu_epi16(mask.data())))); } } static constexpr VectorF16 MulitplyAdd(VectorF16 a, VectorF16 b, VectorF16 add) { - if constexpr(std::is_same_v) { + if constexpr(std::is_same_v::VectorType, __m128h>) { return VectorF16(_mm_fmadd_ph(a.v, b.v, add.v)); - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same_v::VectorType, __m256h>) { return VectorF16(_mm256_fmadd_ph(a.v, b.v, add.v)); } else { return VectorF16(_mm512_fmadd_ph(a.v, b.v, add.v)); @@ -635,9 +467,9 @@ namespace Crafter { } static constexpr VectorF16 MulitplySub(VectorF16 a, VectorF16 b, VectorF16 sub) { - if constexpr(std::is_same_v) { + if 
constexpr(std::is_same_v::VectorType, __m128h>) { return VectorF16(_mm_fmsub_ph(a.v, b.v, sub.v)); - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same_v::VectorType, __m256h>) { return VectorF16(_mm256_fmsub_ph(a.v, b.v, sub.v)); } else { return VectorF16(_mm512_fmsub_ph(a.v, b.v, sub.v)); @@ -645,26 +477,26 @@ namespace Crafter { } constexpr static VectorF16 Cross(VectorF16 a, VectorF16 b) requires(Len == 3) { - if constexpr(std::is_same_v) { - constexpr std::array::Alignment*2> shuffleMask1 = GetShuffleMaskEpi8<{{1,2,0}}>(); + if constexpr(std::is_same_v::VectorType, __m128h>) { + constexpr std::array::AlignmentElement*2> shuffleMask1 = VectorBase::template GetShuffleMaskEpi8<{{1,2,0}}>(); __m128i shuffleVec1 = _mm_loadu_epi8(shuffleMask1.data()); __m128h row1 = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(a.v), shuffleVec1)); __m128h row4 = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(b.v), shuffleVec1)); - constexpr std::array::Alignment*2> shuffleMask3 = GetShuffleMaskEpi8<{{2,0,1}}>(); + constexpr std::array::AlignmentElement*2> shuffleMask3 = VectorBase::template GetShuffleMaskEpi8<{{2,0,1}}>(); __m128i shuffleVec3 = _mm_loadu_epi8(shuffleMask3.data()); __m128h row3 = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(a.v), shuffleVec3)); __m128h row2 = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(b.v), shuffleVec3)); __m128h result = _mm_mul_ph(row3, row4); return _mm_fmsub_ph(row1,row2,result); - } else if constexpr (std::is_same_v) { - constexpr std::array::Alignment*2> shuffleMask1 = GetShuffleMaskEpi8<{{1,2,0}}>(); + } else if constexpr (std::is_same_v::VectorType, __m256h>) { + constexpr std::array::AlignmentElement*2> shuffleMask1 = VectorBase::template GetShuffleMaskEpi8<{{1,2,0}}>(); __m512i shuffleVec1 = _mm512_castsi256_si512(_mm256_loadu_epi8(shuffleMask1.data())); __m256h row1 = _mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(a.v)), shuffleVec1))); 
__m256h row4 = _mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(b.v)), shuffleVec1))); - constexpr std::array::Alignment*2> shuffleMask3 = GetShuffleMaskEpi8<{{2,0,1}}>(); + constexpr std::array::AlignmentElement*2> shuffleMask3 = VectorBase::template GetShuffleMaskEpi8<{{2,0,1}}>(); __m512i shuffleVec3 = _mm512_castsi256_si512(_mm256_loadu_epi8(shuffleMask3.data())); __m256h row3 = _mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(a.v)), shuffleVec3))); @@ -673,13 +505,13 @@ namespace Crafter { __m256h result = _mm256_mul_ph(row3, row4); return _mm256_fmsub_ph(row1,row2,result); } else { - constexpr std::array::Alignment*2> shuffleMask1 = GetShuffleMaskEpi8<{{1,2,0}}>(); + constexpr std::array::AlignmentElement*2> shuffleMask1 = VectorBase::template GetShuffleMaskEpi8<{{1,2,0}}>(); __m512i shuffleVec1 = _mm512_loadu_epi8(shuffleMask1.data()); __m512h row1 = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(a.v), shuffleVec1)); __m512h row4 = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(b.v), shuffleVec1)); - constexpr std::array::Alignment*2> shuffleMask3 = GetShuffleMaskEpi8<{{2,0,1}}>(); + constexpr std::array::AlignmentElement*2> shuffleMask3 = VectorBase::template GetShuffleMaskEpi8<{{2,0,1}}>(); __m512i shuffleVec3 = _mm512_loadu_epi8(shuffleMask3.data()); __m512h row3 = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(a.v), shuffleVec3)); @@ -691,11 +523,11 @@ namespace Crafter { } constexpr static _Float16 Dot(VectorF16 a, VectorF16 b) requires(Packing == 1) { - if constexpr(std::is_same_v) { + if constexpr(std::is_same_v::VectorType, __m128h>) { __m128h mul = _mm_mul_ph(a.v, b.v); return _mm_reduce_add_ph(mul); - } else if constexpr(std::is_same_v) { - static_assert(std::is_same_v, "a.v is NOT VectorType"); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + 
static_assert(std::is_same_v::VectorType>, "a.v is NOT VectorBase::VectorType"); __m256h mul = _mm256_mul_ph(a.v, b.v); return _mm256_reduce_add_ph(mul); } else { @@ -706,41 +538,40 @@ namespace Crafter { template ShuffleValues> constexpr VectorF16 Shuffle() { - if constexpr(CheckEpi32Shuffle()) { - if constexpr(std::is_same_v) { - return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi32(_mm_castph_si128(v), GetShuffleMaskEpi32()))); - } else if constexpr(std::is_same_v) { - return VectorF16(_mm256_castsi256_ph(_mm256_shuffle_epi32(_mm256_castph_si256(v), GetShuffleMaskEpi32()))); + if constexpr(VectorBase::template CheckEpi32Shuffle()) { + constexpr std::uint8_t imm = VectorBase::template GetShuffleMaskEpi32(); + if constexpr(std::is_same_v::VectorType, __m128h>) { + return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi32(_mm_castph_si128(this->v), imm))); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + return VectorF16(_mm256_castsi256_ph(_mm256_shuffle_epi32(_mm256_castph_si256(this->v), imm))); } else { - return VectorF16(_mm512_castsi512_ph(_mm512_shuffle_epi32(_mm512_castph_si512(v), GetShuffleMaskEpi32()))); + return VectorF16(_mm512_castsi512_ph(_mm512_shuffle_epi32(_mm512_castph_si512(this->v), imm))); } - } else if constexpr(CheckEpi8Shuffle()){ - if constexpr(std::is_same_v) { - constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); + } else if constexpr(VectorBase::template CheckEpi8Shuffle()){ + constexpr std::array::Alignment> shuffleMask = VectorBase::template GetShuffleMaskEpi8(); + if constexpr(std::is_same_v::VectorType, __m128h>) { __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); - return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(v), shuffleVec))); - } else if constexpr(std::is_same_v) { - constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); + return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(this->v), shuffleVec))); + } else if 
constexpr(std::is_same_v::VectorType, __m256h>) { __m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data()); - return VectorF16(_mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(v)), _mm512_castsi256_si512(shuffleVec))))); + return VectorF16(_mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(this->v)), _mm512_castsi256_si512(shuffleVec))))); } else { - constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); __m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data()); - return VectorF16(_mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(v), shuffleVec))); + return VectorF16(_mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(this->v), shuffleVec))); } } else { - if constexpr(std::is_same_v) { - constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); + if constexpr(std::is_same_v::VectorType, __m128h>) { + constexpr std::array::Alignment> shuffleMask = VectorBase::template GetShuffleMaskEpi8(); __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); - return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(v), shuffleVec))); - } else if constexpr(std::is_same_v) { - constexpr std::array::Alignment> permMask = GetPermuteMaskEpi16(); + return VectorF16(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(this->v), shuffleVec))); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetPermuteMaskEpi16(); __m256i permIdx = _mm256_loadu_epi16(permMask.data()); - return VectorF16(_mm256_castsi256_ph(_mm256_permutexvar_epi16(permIdx, _mm256_castph_si256(v)))); + return VectorF16(_mm256_castsi256_ph(_mm256_permutexvar_epi16(permIdx, _mm256_castph_si256(this->v)))); } else { - constexpr std::array::Alignment> permMask = GetPermuteMaskEpi16(); + constexpr std::array::AlignmentElement> permMask = VectorBase::template 
GetPermuteMaskEpi16(); __m512i permIdx = _mm512_loadu_epi16(permMask.data()); - return VectorF16(_mm512_castsi512_ph(_mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(v)))); + return VectorF16(_mm512_castsi512_ph(_mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(this->v)))); } } } @@ -750,8 +581,8 @@ namespace Crafter { VectorF16 C, VectorF16 E, VectorF16 G - ) requires(Len == 4 && Packing*Len == Alignment) { - if constexpr(std::is_same_v) { + ) requires(Len == 4 && Packing*Len == VectorBase::AlignmentElement) { + if constexpr(std::is_same_v::VectorType, __m128h>) { VectorF16<1, 8> lenght = LengthNoShuffle(A, E, C, G); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1}; __m128h one = _mm_loadu_ph(oneArr); @@ -768,7 +599,7 @@ namespace Crafter { _mm_mul_ph(E.v, fLenghtE.v), _mm_mul_ph(G.v, fLenghtG.v) }; - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same_v::VectorType, __m256h>) { VectorF16<1, 16> lenght = LengthNoShuffle(A, E, C, G); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m256h one = _mm256_loadu_ph(oneArr); @@ -810,8 +641,8 @@ namespace Crafter { constexpr static std::tuple, VectorF16> Normalize( VectorF16 A, VectorF16 E - ) requires(Len == 2 && Packing*Len == Alignment) { - if constexpr(std::is_same_v) { + ) requires(Len == 2 && Packing*Len == VectorBase::AlignmentElement) { + if constexpr(std::is_same_v::VectorType, __m128h>) { VectorF16<1, 8> lenght = LengthNoShuffle(A, E); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1}; __m128h one = _mm_loadu_ph(oneArr); @@ -824,7 +655,7 @@ namespace Crafter { _mm_mul_ph(A.v, fLenghtA.v), _mm_mul_ph(E.v, fLenghtE.v), }; - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same_v::VectorType, __m256h>) { VectorF16<1, 16> lenght = LengthNoShuffle(A, E); constexpr _Float16 oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m256h one = _mm256_loadu_ph(oneArr); @@ -858,11 +689,11 @@ namespace Crafter { VectorF16 C, VectorF16 
E, VectorF16 G - ) requires(Len == 4 && Packing*Len == Alignment) { + ) requires(Len == 4 && Packing*Len == VectorBase::AlignmentElement) { VectorF16<1, Packing*4> lenghtSq = LengthSq(A, C, E, G); - if constexpr(std::is_same_v) { + if constexpr(std::is_same_v::VectorType, __m128h>) { return VectorF16<1, Packing*4>(_mm_sqrt_ph(lenghtSq.v)); - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same_v::VectorType, __m256h>) { return VectorF16<1, Packing*4>(_mm256_sqrt_ph(lenghtSq.v)); } else { return VectorF16<1, Packing*4>(_mm512_sqrt_ph(lenghtSq.v)); @@ -872,11 +703,11 @@ namespace Crafter { constexpr static VectorF16<1, Packing*2> Length( VectorF16 A, VectorF16 E - ) requires(Len == 2 && Packing*Len == Alignment) { + ) requires(Len == 2 && Packing*Len == VectorBase::AlignmentElement) { VectorF16<1, Packing*2> lenghtSq = LengthSq(A, E); - if constexpr(std::is_same_v) { + if constexpr(std::is_same_v::VectorType, __m128h>) { return VectorF16<1, Packing*2>(_mm_sqrt_ph(lenghtSq.v)); - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same_v::VectorType, __m256h>) { return VectorF16<1, Packing*2>(_mm256_sqrt_ph(lenghtSq.v)); } else { return VectorF16<1, Packing*2>(_mm512_sqrt_ph(lenghtSq.v)); @@ -888,14 +719,14 @@ namespace Crafter { VectorF16 C, VectorF16 E, VectorF16 G - ) requires(Len == 4 && Packing*Len == Alignment) { + ) requires(Len == 4 && Packing*Len == VectorBase::AlignmentElement) { return Dot(A, A, C, C, E, E, G, G); } constexpr static VectorF16<1, Packing*2> LengthSq( VectorF16 A, VectorF16 E - ) requires(Len == 2 && Packing*Len == Alignment) { + ) requires(Len == 2 && Packing*Len == VectorBase::AlignmentElement) { return Dot(A, A, E, E); } @@ -904,10 +735,10 @@ namespace Crafter { VectorF16 C0, VectorF16 C1, VectorF16 E0, VectorF16 E1, VectorF16 G0, VectorF16 G1 - ) requires(Len == 4 && Packing*Len == Alignment) { - if constexpr(std::is_same_v) { + ) requires(Len == 4 && Packing*Len == VectorBase::AlignmentElement) { + if 
constexpr(std::is_same_v::VectorType, __m128h>) { return DotNoShuffle(A0, A1, E0, E1, C0, C1, G0, G1); - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same_v::VectorType, __m256h>) { VectorF16<16, 1> vec(DotNoShuffle(A0, A1, C0, C1, E0, E1, G0, G1).v); vec = vec.template Shuffle<{{ 0,1,8,9, @@ -938,10 +769,10 @@ namespace Crafter { constexpr static VectorF16<1, Packing*2> Dot( VectorF16 A0, VectorF16 A1, VectorF16 E0, VectorF16 E1 - ) requires(Len == 2 && Packing*Len == Alignment) { - if constexpr(std::is_same_v) { + ) requires(Len == 2 && Packing*Len == VectorBase::AlignmentElement) { + if constexpr(std::is_same_v::VectorType, __m128h>) { return DotNoShuffle(A0, A1, E0, E1); - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same_v::VectorType, __m256h>) { VectorF16<16, 1> vec(DotNoShuffle(A0, A1, E0, E1).v); vec = vec.template Shuffle<{{0,1,2,3,8,9,10,11,4,5,6,7,12,13,14,15}}>(); return vec.v; @@ -959,11 +790,11 @@ namespace Crafter { VectorF16 C, VectorF16 E, VectorF16 G - ) requires(Len == 4 && Packing*Len == Alignment) { + ) requires(Len == 4 && Packing*Len == VectorBase::AlignmentElement) { VectorF16<1, Packing*4> lenghtSq = LengthSqNoShuffle(A, C, E, G); - if constexpr(std::is_same_v) { + if constexpr(std::is_same_v::VectorType, __m128h>) { return VectorF16<1, Packing*4>(_mm_sqrt_ph(lenghtSq.v)); - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same_v::VectorType, __m256h>) { return VectorF16<1, Packing*4>(_mm256_sqrt_ph(lenghtSq.v)); } else { return VectorF16<1, Packing*4>(_mm512_sqrt_ph(lenghtSq.v)); @@ -973,11 +804,11 @@ namespace Crafter { constexpr static VectorF16<1, Packing*2> LengthNoShuffle( VectorF16 A, VectorF16 E - ) requires(Len == 2 && Packing*Len == Alignment) { + ) requires(Len == 2 && Packing*Len == VectorBase::AlignmentElement) { VectorF16<1, Packing*2> lenghtSq = LengthSqNoShuffle(A, E); - if constexpr(std::is_same_v) { + if constexpr(std::is_same_v::VectorType, __m128h>) { 
return VectorF16<1, Packing*2>(_mm_sqrt_ph(lenghtSq.v)); - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same_v::VectorType, __m256h>) { return VectorF16<1, Packing*2>(_mm256_sqrt_ph(lenghtSq.v)); } else { return VectorF16<1, Packing*2>(_mm512_sqrt_ph(lenghtSq.v)); @@ -989,14 +820,14 @@ namespace Crafter { VectorF16 C, VectorF16 E, VectorF16 G - ) requires(Len == 4 && Packing*Len == Alignment) { + ) requires(Len == 4 && Packing*Len == VectorBase::AlignmentElement) { return DotNoShuffle(A, A, C, C, E, E, G, G); } constexpr static VectorF16<1, Packing*2> LengthSqNoShuffle( VectorF16 A, VectorF16 E - ) requires(Len == 2 && Packing*Len == Alignment) { + ) requires(Len == 2 && Packing*Len == VectorBase::AlignmentElement) { return DotNoShuffle(A, A, E, E); } @@ -1006,8 +837,8 @@ namespace Crafter { VectorF16 C0, VectorF16 C1, VectorF16 E0, VectorF16 E1, VectorF16 G0, VectorF16 G1 - ) requires(Len == 4 && Packing*Len == Alignment) { - if constexpr(std::is_same_v) { + ) requires(Len == 4 && Packing*Len == VectorBase::AlignmentElement) { + if constexpr(std::is_same_v::VectorType, __m128h>) { __m128h mulA = _mm_mul_ph(A0.v, A1.v); __m128h mulC = _mm_mul_ph(C0.v, C1.v); @@ -1037,7 +868,7 @@ namespace Crafter { row1 = _mm_add_ph(row1, row4); return row1; - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same_v::VectorType, __m256h>) { __m256h mulA = _mm256_mul_ph(A0.v, A1.v); __m256h mulC = _mm256_mul_ph(C0.v, C1.v); @@ -1103,8 +934,8 @@ namespace Crafter { constexpr static VectorF16<1, Packing*2> DotNoShuffle( VectorF16 A0, VectorF16 A1, VectorF16 E0, VectorF16 E1 - ) requires(Len == 2 && Packing*Len == Alignment) { - if constexpr(std::is_same_v) { + ) requires(Len == 2 && Packing*Len == VectorBase::AlignmentElement) { + if constexpr(std::is_same_v::VectorType, __m128h>) { __m128h mulA = _mm_mul_ph(A0.v, A1.v); __m128h mulE = _mm_mul_ph(E0.v, E1.v); __m128i row12Temp1 = _mm_unpacklo_epi16(_mm_castph_si128(mulA), 
_mm_castph_si128(mulE)); // A1 E1 A2 E2 B1 F1 B2 F2 @@ -1118,7 +949,7 @@ namespace Crafter { __m128h row2 = _mm_castsi128_ph(_mm_unpackhi_epi16(row12Temp1, row12Temp2));// A2 B2 C2 D2 E2 F2 G2 H2 return _mm_add_ph(row1, row2); - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same_v::VectorType, __m256h>) { __m256h mulA = _mm256_mul_ph(A0.v, A1.v); __m256h mulE = _mm256_mul_ph(E0.v, E1.v); @@ -1156,17 +987,18 @@ namespace Crafter { template ShuffleValues> constexpr static VectorF16 Blend(VectorF16 a, VectorF16 b) { - if constexpr(std::is_same_v) { - return _mm_castsi128_ph(_mm_blend_epi16(_mm_castph_si128(a.v), _mm_castph_si128(b.v), GetBlendMaskEpi16())); - } else if constexpr(std::is_same_v) { + constexpr auto mask = VectorBase::template GetBlendMaskEpi16(); + if constexpr(std::is_same_v::VectorType, __m128h>) { + return _mm_castsi128_ph(_mm_blend_epi16(_mm_castph_si128(a.v), _mm_castph_si128(b.v), mask)); + } else if constexpr(std::is_same_v::VectorType, __m256h>) { #ifndef __AVX512BW__ #ifndef __AVX512VL__ static_assert(false, "No __AVX512BW__ and __AVX512VL__ support"); #endif #endif - return _mm256_castsi256_ph(_mm256_mask_blend_epi16(GetBlendMaskEpi16(), _mm256_castph_si256(a.v), _mm256_castph_si256(b.v))); + return _mm256_castsi256_ph(_mm256_mask_blend_epi16(mask, _mm256_castph_si256(a.v), _mm256_castph_si256(b.v))); } else { - return _mm512_castsi512_ph(_mm512_mask_blend_epi16(GetBlendMaskEpi16(), _mm512_castph_si512(a.v), _mm512_castph_si512(b.v))); + return _mm512_castsi512_ph(_mm512_mask_blend_epi16(mask, _mm512_castph_si512(a.v), _mm512_castph_si512(b.v))); } } @@ -1218,215 +1050,6 @@ namespace Crafter { return row1; } - static constexpr float two_over_pi = 0.6366197723675814f; - static constexpr float pi_over_2_hi = 1.5707963267341256f; - static constexpr float pi_over_2_lo = 6.077100506506192e-11f; - - // Cos polynomial on [-pi/4, pi/4]: c0 + c2*r^2 + c4*r^4 + ... 
- static constexpr float c0 = 1.0f; - static constexpr float c2 = -0.4999999642372f; - static constexpr float c4 = 0.0416666418707f; - static constexpr float c6 = -0.0013888397720f; - static constexpr float c8 = 0.0000248015873f; - static constexpr float c10 = -0.0000002752258f; - - // Sin polynomial on [-pi/4, pi/4]: r * (1 + s1*r^2 + s3*r^4 + ...) - static constexpr float s1 = -0.1666666641831f; - static constexpr float s3 = 0.0083333293858f; - static constexpr float s5 = -0.0001984090955f; - static constexpr float s7 = 0.0000027526372f; - static constexpr float s9 = -0.0000000239013f; - - // Reduce |x| into [-pi/4, pi/4], return reduced value and quadrant - constexpr void range_reduce_f32x8(__m256 ax, __m256& r, __m256& r2, __m256i& q) { - __m256 fq = _mm256_round_ps(_mm256_mul_ps(ax, _mm256_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); - q = _mm256_cvtps_epi32(fq); - r = _mm256_sub_ps(ax, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_hi))); - r = _mm256_sub_ps(r, _mm256_mul_ps(fq, _mm256_set1_ps(pi_over_2_lo))); - r2 = _mm256_mul_ps(r, r); - } - - - // cos(x): use cos_poly when q even, sin_poly when q odd; negate if (q+1)&2 - constexpr __m256 cos_f32x8(__m256 x) { - const __m256 sign_mask = _mm256_set1_ps(-0.0f); - __m256 ax = _mm256_andnot_ps(sign_mask, x); - - __m256 r, r2; __m256i q; - range_reduce_f32x8(ax, r, r2, q); - - __m256 cos_r, sin_r; - sincos_poly_f32x8(r, r2, cos_r, sin_r); - - __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1)); - __m256 use_sin = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1))); - __m256 result = _mm256_blendv_ps(cos_r, sin_r, use_sin); - - __m256i need_neg = _mm256_and_si256( - _mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2)); - __m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30)); - return _mm256_xor_ps(result, neg_mask); - } - - constexpr void sincos_poly_f32x8(__m256 r, __m256 r2, __m256& cos_r, __m256& sin_r) { - cos_r = 
_mm256_fmadd_ps(_mm256_set1_ps(c10), r2, _mm256_set1_ps(c8)); - cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c6)); - cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c4)); - cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c2)); - cos_r = _mm256_fmadd_ps(cos_r, r2, _mm256_set1_ps(c0)); - - sin_r = _mm256_fmadd_ps(_mm256_set1_ps(s9), r2, _mm256_set1_ps(s7)); - sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s5)); - sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s3)); - sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(s1)); - sin_r = _mm256_fmadd_ps(sin_r, r2, _mm256_set1_ps(1.0f)); - sin_r = _mm256_mul_ps(sin_r, r); - } - - // sin(x): use sin_poly when q even, cos_poly when q odd; negate if q&2; respect input sign - constexpr __m256 sin_f32x8(__m256 x) { - const __m256 sign_mask = _mm256_set1_ps(-0.0f); - __m256 x_sign = _mm256_and_ps(x, sign_mask); - __m256 ax = _mm256_andnot_ps(sign_mask, x); - - __m256 r, r2; __m256i q; - range_reduce_f32x8(ax, r, r2, q); - - __m256 cos_r, sin_r; - sincos_poly_f32x8(r, r2, cos_r, sin_r); - - __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1)); - __m256 use_cos = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1))); - __m256 result = _mm256_blendv_ps(sin_r, cos_r, use_cos); - - __m256i need_neg = _mm256_and_si256(q, _mm256_set1_epi32(2)); - __m256 neg_mask = _mm256_castsi256_ps(_mm256_slli_epi32(need_neg, 30)); - result = _mm256_xor_ps(result, neg_mask); - - // Apply original sign of x - return _mm256_xor_ps(result, x_sign); - } - - // --- 512-bit helpers --- - - constexpr void range_reduce_f32x16(__m512 ax, __m512& r, __m512& r2, __m512i& q) { - __m512 fq = _mm512_roundscale_ps(_mm512_mul_ps(ax, _mm512_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); - q = _mm512_cvtps_epi32(fq); - r = _mm512_sub_ps(ax, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_hi))); - r = _mm512_sub_ps(r, _mm512_mul_ps(fq, _mm512_set1_ps(pi_over_2_lo))); - r2 = _mm512_mul_ps(r, r); - } - - constexpr void 
sincos_poly_f32x16(__m512 r, __m512 r2, __m512& cos_r, __m512& sin_r) { - cos_r = _mm512_fmadd_ps(_mm512_set1_ps(c10), r2, _mm512_set1_ps(c8)); - cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c6)); - cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c4)); - cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c2)); - cos_r = _mm512_fmadd_ps(cos_r, r2, _mm512_set1_ps(c0)); - - sin_r = _mm512_fmadd_ps(_mm512_set1_ps(s9), r2, _mm512_set1_ps(s7)); - sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s5)); - sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s3)); - sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(s1)); - sin_r = _mm512_fmadd_ps(sin_r, r2, _mm512_set1_ps(1.0f)); - sin_r = _mm512_mul_ps(sin_r, r); - } - - constexpr __m512 cos_f32x16(__m512 x) { - __m512 ax = _mm512_abs_ps(x); - - __m512 r, r2; __m512i q; - range_reduce_f32x16(ax, r, r2, q); - - __m512 cos_r, sin_r; - sincos_poly_f32x16(r, r2, cos_r, sin_r); - - __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1)); - __m512 result = _mm512_mask_blend_ps(odd, cos_r, sin_r); - - __m512i need_neg = _mm512_and_si512( - _mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2)); - __m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30)); - return _mm512_xor_ps(result, neg_mask); - } - - constexpr __m512 sin_f32x16(__m512 x) { - __m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f)); - __m512 ax = _mm512_abs_ps(x); - - __m512 r, r2; __m512i q; - range_reduce_f32x16(ax, r, r2, q); - - __m512 cos_r, sin_r; - sincos_poly_f32x16(r, r2, cos_r, sin_r); - - __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1)); - __m512 result = _mm512_mask_blend_ps(odd, sin_r, cos_r); - - __m512i need_neg = _mm512_and_si512(q, _mm512_set1_epi32(2)); - __m512 neg_mask = _mm512_castsi512_ps(_mm512_slli_epi32(need_neg, 30)); - result = _mm512_xor_ps(result, neg_mask); - - return _mm512_xor_ps(result, x_sign); - } - - // --- 256-bit sincos --- - constexpr void sincos_f32x8(__m256 x, __m256& 
out_sin, __m256& out_cos) { - const __m256 sign_mask = _mm256_set1_ps(-0.0f); - __m256 x_sign = _mm256_and_ps(x, sign_mask); - __m256 ax = _mm256_andnot_ps(sign_mask, x); - - __m256 r, r2; __m256i q; - range_reduce_f32x8(ax, r, r2, q); - - __m256 cos_r, sin_r; - sincos_poly_f32x8(r, r2, cos_r, sin_r); - - __m256i odd = _mm256_and_si256(q, _mm256_set1_epi32(1)); - __m256 is_odd = _mm256_castsi256_ps(_mm256_cmpeq_epi32(odd, _mm256_set1_epi32(1))); - - // cos: swap on odd, negate if (q+1)&2 - out_cos = _mm256_blendv_ps(cos_r, sin_r, is_odd); - __m256i cos_neg = _mm256_and_si256( - _mm256_add_epi32(q, _mm256_set1_epi32(1)), _mm256_set1_epi32(2)); - out_cos = _mm256_xor_ps(out_cos, - _mm256_castsi256_ps(_mm256_slli_epi32(cos_neg, 30))); - - // sin: swap on odd, negate if q&2, apply input sign - out_sin = _mm256_blendv_ps(sin_r, cos_r, is_odd); - __m256i sin_neg = _mm256_and_si256(q, _mm256_set1_epi32(2)); - out_sin = _mm256_xor_ps(out_sin, - _mm256_castsi256_ps(_mm256_slli_epi32(sin_neg, 30))); - out_sin = _mm256_xor_ps(out_sin, x_sign); - } - - // --- 512-bit sincos --- - constexpr void sincos_f32x16(__m512 x, __m512& out_sin, __m512& out_cos) { - __m512 x_sign = _mm512_and_ps(x, _mm512_set1_ps(-0.0f)); - __m512 ax = _mm512_abs_ps(x); - - __m512 r, r2; __m512i q; - range_reduce_f32x16(ax, r, r2, q); - - __m512 cos_r, sin_r; - sincos_poly_f32x16(r, r2, cos_r, sin_r); - - __mmask16 odd = _mm512_test_epi32_mask(q, _mm512_set1_epi32(1)); - - // cos - out_cos = _mm512_mask_blend_ps(odd, cos_r, sin_r); - __m512i cos_neg = _mm512_and_si512( - _mm512_add_epi32(q, _mm512_set1_epi32(1)), _mm512_set1_epi32(2)); - out_cos = _mm512_xor_ps(out_cos, - _mm512_castsi512_ps(_mm512_slli_epi32(cos_neg, 30))); - - // sin - out_sin = _mm512_mask_blend_ps(odd, sin_r, cos_r); - __m512i sin_neg = _mm512_and_si512(q, _mm512_set1_epi32(2)); - out_sin = _mm512_xor_ps(out_sin, - _mm512_castsi512_ps(_mm512_slli_epi32(sin_neg, 30))); - out_sin = _mm512_xor_ps(out_sin, x_sign); - } }; } @@ -1434,12 
+1057,12 @@ namespace Crafter { export template struct std::formatter> : std::formatter { constexpr auto format(const Crafter::VectorF16& obj, format_context& ctx) const { - Crafter::Vector<_Float16, Len * Packing, 0> vec = obj.Store(); + std::array<_Float16, Crafter::VectorF16::AlignmentElement> vec = obj.Store(); std::string out = "{"; for(std::uint32_t i = 0; i < Packing; i++) { out += "{"; for(std::uint32_t i2 = 0; i2 < Len; i2++) { - out += std::format("{}", static_cast(vec.v[i * Len + i2])); + out += std::format("{}", static_cast(vec[i * Len + i2])); if (i2 + 1 < Len) out += ","; } out += "}"; diff --git a/interfaces/Crafter.Math-VectorF32.cppm b/interfaces/Crafter.Math-VectorF32.cppm index eba8d3d..2573524 100755 --- a/interfaces/Crafter.Math-VectorF32.cppm +++ b/interfaces/Crafter.Math-VectorF32.cppm @@ -18,14 +18,15 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ module; #ifdef __x86_64 -#include +#include #endif export module Crafter.Math:VectorF32; import std; import :Vector; +import :Common; namespace Crafter { - export template + export template struct VectorF32 { #ifdef __AVX512F__ static constexpr std::uint32_t MaxSize = 16; @@ -45,8 +46,6 @@ namespace Crafter { return 16; } static_assert(Len * Packing <= 16, "Len * Packing is larger than supported max size of 16"); - static_assert(Len * Packing <= 4, "Len * Packing is larger than supported packed size of 4"); - static_assert(Len * Packing * Repeats <= 16, "Len * Packing * Repeats is larger than supported max of 16"); #else if constexpr (Len * Packing <= 4) { return 4; @@ -55,17 +54,12 @@ namespace Crafter { return 8; } static_assert(Len * Packing <= 8, "Len * Packing is larger than supported max size of 8"); - static_assert(Len * Packing <= 4, "Len * Packing is larger than supported packed size of 4"); - static_assert(Len * Packing * Repeats <= 8, "Len * Packing * Repeats is larger than supported max of 8"); #endif } - static consteval std::uint32_t GetTotalSize() 
{ - return GetAlignment() * Repeats; - } using VectorType = std::conditional_t< - (GetTotalSize() == 16), __m512, - std::conditional_t<(GetTotalSize() == 8), __m256, __m128> + (Len * Packing > 8), __m512h, + std::conditional_t<(Len * Packing > 4), __m256h, __m128> >; VectorType v; @@ -107,91 +101,96 @@ namespace Crafter { } constexpr void Load(const _Float16* vB) { if constexpr(std::is_same_v) { - v = _mm_cvtph_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB))); + v = _mm_cvtps_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB))); } else if constexpr(std::is_same_v) { - v = _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB))); + v = _mm256_cvtps_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB))); } else { - v = _mm512_cvtph_ps(_mm256_loadu_si256(reinterpret_cast<__m256i const*>(vB))); + v = _mm512_cvtps_ps(_mm256_loadu_si256(reinterpret_cast<__m256i const*>(vB))); } } constexpr void Store(_Float16* vB) const { if constexpr(std::is_same_v) { - _mm_storeu_si128(_mm_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT), v); + _mm_storeu_si128(_mm_cvtps_ps(v, _MM_FROUND_TO_NEAREST_INT), v); } else if constexpr(std::is_same_v) { - _mm_storeu_si128(_mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT), v); + _mm_storeu_si128(_mm256_cvtps_ps(v, _MM_FROUND_TO_NEAREST_INT), v); } else { - _mm256_storeu_si256(_mm512_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT), v); + _mm256_storeu_si256(_mm512_cvtps_ps(v, _MM_FROUND_TO_NEAREST_INT), v); } } - template - constexpr Vector Store() const { - Vector returnVec; - Store(returnVec.v); - return returnVec; + constexpr std::array Store() const { + std::array returnArray; + Store(returnArray.data()); + return returnArray; } - template - constexpr operator VectorF32() const { - if constexpr(std::is_same_v && std::is_same_v::VectorType, __m128>) { - return VectorF32(_mm256_castps256_ps128(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m128>) { - return VectorF32(_mm512_castps512_ps128(v)); - } else if 
constexpr(std::is_same_v && std::is_same_v::VectorType, __m256>) { - return VectorF32(_mm512_castps512_ps256(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m256>) { - return VectorF32(_mm256_castps128_ps256(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m512>) { - return VectorF32(_mm512_castps128_ps512(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m512>) { - return VectorF32(_mm512_castps256_ps512(v)); + template + constexpr operator VectorF32() const { + if constexpr (Len == BLen) { + if constexpr(std::is_same_v && std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm256_castps256_ps128(v)); + } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm512_castps512_ps128(v)); + } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m256>) { + return VectorF32(_mm512_castps512_ps256(v)); + } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m256>) { + return VectorF32(_mm256_castps128_ps256(v)); + } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m512>) { + return VectorF32(_mm512_castps128_ps512(v)); + } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m512>) { + return VectorF32(_mm512_castps256_ps512(v)); + } else { + return VectorF32(v); + } + } else if constexpr (BLen <= Len) { + return this->template ExtractLo(); } else { - return VectorF32(v); + return VectorF32(v); } } - - constexpr VectorF32 operator+(VectorF32 b) const { - if constexpr(std::is_same_v) { - return VectorF32(_mm_add_ps(v, b.v)); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_add_ps(v, b.v)); + + constexpr VectorF32 operator+(VectorF32 b) const { + if constexpr(std::is_same_v) { + return VectorF32(_mm_add_ph(v, b.v)); + } else if constexpr(std::is_same_v) { + return VectorF32(_mm256_add_ph(v, b.v)); } else { - return VectorF32(_mm512_add_ps(v, b.v)); + return 
VectorF32(_mm512_add_ph(v, b.v)); } } - constexpr VectorF32 operator-(VectorF32 b) const { - if constexpr(std::is_same_v) { - return VectorF32(_mm_sub_ps(v, b.v)); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_sub_ps(v, b.v)); + constexpr VectorF32 operator-(VectorF32 b) const { + if constexpr(std::is_same_v) { + return VectorF32(_mm_sub_ph(v, b.v)); + } else if constexpr(std::is_same_v) { + return VectorF32(_mm256_sub_ph(v, b.v)); } else { - return VectorF32(_mm512_sub_ps(v, b.v)); + return VectorF32(_mm512_sub_ph(v, b.v)); } } - constexpr VectorF32 operator*(VectorF32 b) const { - if constexpr(std::is_same_v) { - return VectorF32(_mm_mul_ps(v, b.v)); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_mul_ps(v, b.v)); + constexpr VectorF32 operator*(VectorF32 b) const { + if constexpr(std::is_same_v) { + return VectorF32(_mm_mul_ph(v, b.v)); + } else if constexpr(std::is_same_v) { + return VectorF32(_mm256_mul_ph(v, b.v)); } else { - return VectorF32(_mm512_mul_ps(v, b.v)); + return VectorF32(_mm512_mul_ph(v, b.v)); } } - constexpr VectorF32 operator/(VectorF32 b) const { - if constexpr(std::is_same_v) { - return VectorF32(_mm_div_ps(v, b.v)); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_div_ps(v, b.v)); + constexpr VectorF32 operator/(VectorF32 b) const { + if constexpr(std::is_same_v) { + return VectorF32(_mm_div_ph(v, b.v)); + } else if constexpr(std::is_same_v) { + return VectorF32(_mm256_div_ph(v, b.v)); } else { - return VectorF32(_mm512_div_ps(v, b.v)); + return VectorF32(_mm512_div_ph(v, b.v)); } } - constexpr void operator+=(VectorF32 b) const { + constexpr void operator+=(VectorF32 b) { if constexpr(std::is_same_v) { v = _mm_add_ps(v, b.v); } else if constexpr(std::is_same_v) { @@ -201,7 +200,7 @@ namespace Crafter { } } - constexpr void operator-=(VectorF32 b) const { + constexpr void operator-=(VectorF32 b) { if constexpr(std::is_same_v) { v = _mm_sub_ps(v, b.v); } else if 
constexpr(std::is_same_v) { @@ -211,7 +210,7 @@ namespace Crafter { } } - constexpr void operator*=(VectorF32 b) const { + constexpr void operator*=(VectorF32 b) { if constexpr(std::is_same_v) { v = _mm_mul_ps(v, b.v); } else if constexpr(std::is_same_v) { @@ -221,7 +220,7 @@ namespace Crafter { } } - constexpr void operator/=(VectorF32 b) const { + constexpr void operator/=(VectorF32 b) { if constexpr(std::is_same_v) { v = _mm_div_ps(v, b.v); } else if constexpr(std::is_same_v) { @@ -231,60 +230,48 @@ namespace Crafter { } } - constexpr VectorF32 operator+(float b) const { - VectorF32 vB(b); - return this + vB; + constexpr VectorF32 operator+(float b) { + VectorF32 vB(b); + return *this + vB; } - constexpr VectorF32 operator-(float b) const { - VectorF32 vB(b); - return this - vB; + constexpr VectorF32 operator-(float b) { + VectorF32 vB(b); + return *this - vB; } - constexpr VectorF32 operator*(float b) const { - VectorF32 vB(b); - return this * vB; + constexpr VectorF32 operator*(float b) { + VectorF32 vB(b); + return *this * vB; } - constexpr VectorF32 operator/(float b) const { - VectorF32 vB(b); - return this / vB; + constexpr VectorF32 operator/(float b) { + VectorF32 vB(b); + return *this / vB; } - constexpr void operator+=(float b) const { - VectorF32 vB(b); - this += vB; + constexpr void operator+=(float b) { + VectorF32 vB(b); + *this += vB; } - constexpr void operator-=(float b) const { - VectorF32 vB(b); - this -= vB; + constexpr void operator-=(float b) { + VectorF32 vB(b); + *this -= vB; } - constexpr void operator*=(float b) const { - VectorF32 vB(b); - this *= vB; + constexpr void operator*=(float b) { + VectorF32 vB(b); + *this *= vB; } - constexpr void operator/=(float b) const { - VectorF32 vB(b); - this /= vB; + constexpr void operator/=(float b) { + VectorF32 vB(b); + *this /= vB; } - - constexpr VectorF32 operator-(){ - if constexpr(std::is_same_v) { - constexpr std::uint64_t mask[] 
{0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000}; - __m128i sign_mask = _mm_loadu_si128(reinterpret_cast(mask)); - return VectorF32(_mm_castsi128_ps(_mm_xor_si128(sign_mask, _mm_castps_si128(v)))); - } else if constexpr(std::is_same_v) { - constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000}; - __m256i sign_mask = _mm256_loadu_si256(reinterpret_cast(mask)); - return VectorF32(_mm256_castsi256_ps(_mm256_xor_si256(sign_mask, _mm256_castps_si256(v)))); - } else { - constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000}; - __m512i sign_mask = _mm512_loadu_si512(reinterpret_cast(mask)); - return VectorF32(_mm512_castsi512_ps(_mm512_xor_si512(sign_mask, _mm512_castps_si512(v)))); - } + + constexpr VectorF32 operator-(){ + return Negate()>(); } constexpr bool operator==(VectorF32 b) const { @@ -335,71 +322,47 @@ namespace Crafter { return Dot(*this, *this); } - constexpr VectorF32 Cos() requires(Len == 3) { - if constexpr(std::is_same_v) { - return VectorF32(_mm_cos_ps(v)); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_cos_ps(v)); + template ShuffleValues> + constexpr VectorF32 Shuffle() { + if 
constexpr(CheckEpi32Shuffle()) { + if constexpr(std::is_same_v) { + return VectorF32(_mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(v), GetShuffleMaskEpi32()))); + } else if constexpr(std::is_same_v) { + return VectorF32(_mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(v), GetShuffleMaskEpi32()))); + } else { + return VectorF32(_mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(v), GetShuffleMaskEpi32()))); + } + } else if constexpr(CheckEpi8Shuffle()){ + if constexpr(std::is_same_v) { + constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); + __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); + return VectorF32(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(v), shuffleVec))); + } else if constexpr(std::is_same_v) { + constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); + __m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data()); + return VectorF32(_mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(v)), _mm512_castsi256_si512(shuffleVec))))); + } else { + constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); + __m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data()); + return VectorF32(_mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(v), shuffleVec))); + } } else { - return VectorF32(_mm512_cos_ps(v)); + if constexpr(std::is_same_v) { + constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); + __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); + return VectorF32(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(v), shuffleVec))); + } else if constexpr(std::is_same_v) { + constexpr std::array::Alignment> permMask = GetPermuteMaskEpi32(); + __m256i permIdx = _mm256_loadu_epi16(permMask.data()); + return VectorF32(_mm256_castsi256_ps(_mm256_permutexvar_epi16(permIdx, _mm256_castps_si256(v)))); + } else { + constexpr std::array::Alignment> permMask = GetPermuteMaskEpi32(); + __m512i permIdx = 
_mm512_loadu_epi16(permMask.data()); + return VectorF32(_mm512_castsi512_ps(_mm512_permutexvar_epi16(permIdx, _mm512_castps_si512(v)))); + } } } - constexpr VectorF32 Sin() requires(Len == 3) { - if constexpr(std::is_same_v) { - return VectorF32(_mm_sin_ps(v)); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_sin_ps(v)); - } else { - return VectorF32(_mm512_sin_ps(v)); - } - } - - template - constexpr VectorF32 Shuffle() { - constexpr std::uint32_t val = - (A & 0x3) | - ((B & 0x3) << 2) | - ((C & 0x3) << 4) | - ((D & 0x3) << 6); - if constexpr(std::is_same_v) { - return VectorF32(_mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(v), val))); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(v), val))); - } else { - return VectorF32(_mm512_castsi512_ps(_mm512_shuffle_epi32(_mm_512castps_si512(v), val))); - } - } - - template < - std::uint8_t A0, std::uint8_t B0, std::uint8_t C0, std::uint8_t D0, - std::uint8_t A1, std::uint8_t B1, std::uint8_t C1, std::uint8_t D1 - > - constexpr VectorF32 Shuffle() requires(Repeats == 2) { - constexpr std::uint8_t shuffleMask[] { - A0,A0,A0,A0,B0,B0,B0,B0,C0,C0,C0,C0,D0,D0,D0,D0, - A1,A1,A1,A1,B1,B1,B1,B1,C1,C1,C1,C1,D1,D1,D1,D1, - }; - __m256 shuffleVec = _mm256_loadu_epi8(shuffleMask); - return VectorF32(_mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(v), shuffleVec))); - } - - template < - std::uint8_t A0, std::uint8_t B0, std::uint8_t C0, std::uint8_t D0, std::uint8_t E0, std::uint8_t F0, std::uint8_t G0, std::uint8_t H0, - std::uint8_t A1, std::uint8_t B1, std::uint8_t C1, std::uint8_t D1, std::uint8_t E1, std::uint8_t F1, std::uint8_t G1, std::uint8_t H1, - std::uint8_t A2, std::uint8_t B2, std::uint8_t C2, std::uint8_t D2, std::uint8_t E2, std::uint8_t F2, std::uint8_t G2, std::uint8_t H2, - std::uint8_t A3, std::uint8_t B3, std::uint8_t C3, std::uint8_t D3, std::uint8_t E3, std::uint8_t F3, std::uint8_t G3, std::uint8_t H3 - > 
- constexpr VectorF32 Shuffle() requires(Repeats == 4) { - constexpr std::uint8_t shuffleMask[] { - A0,A0,A0,A0,B0,B0,B0,B0,C0,C0,C0,C0,D0,D0,D0,D0, - A1,A1,A1,A1,B1,B1,B1,B1,C1,C1,C1,C1,D1,D1,D1,D1, - A2,A2,A2,A2,B2,B2,B2,B2,C2,C2,C2,C2,D2,D2,D2,D2, - A3,A3,A3,A3,B3,B3,B3,B3,C3,C3,C3,C3,D3,D3,D3,D3, - }; - __m512 shuffleVec = _mm512_loadu_epi8(shuffleMask); - return VectorF32(_mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(v), shuffleVec))); - } static constexpr VectorF32 MulitplyAdd(VectorF32 a, VectorF32 b, VectorF32 add) { if constexpr(std::is_same_v) { diff --git a/interfaces/Crafter.Math.cppm b/interfaces/Crafter.Math.cppm index bf48bbe..d599bf9 100644 --- a/interfaces/Crafter.Math.cppm +++ b/interfaces/Crafter.Math.cppm @@ -20,8 +20,6 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA export module Crafter.Math; export import :Basic; -export import :Vector; -export import :MatrixRowMajor; -export import :Intersection; +export import :Common; export import :VectorF16; -export import :VectorF32; \ No newline at end of file +// export import :VectorF32; \ No newline at end of file diff --git a/project.json b/project.json index b233dc6..cedefe2 100644 --- a/project.json +++ b/project.json @@ -4,13 +4,10 @@ { "name": "base", "interfaces": [ - "interfaces/Crafter.Math-Vector", "interfaces/Crafter.Math-Basic", - "interfaces/Crafter.Math-MatrixRowMajor", "interfaces/Crafter.Math", - "interfaces/Crafter.Math-Intersection", - "interfaces/Crafter.Math-VectorF16", - "interfaces/Crafter.Math-VectorF32" + "interfaces/Crafter.Math-Common", + "interfaces/Crafter.Math-VectorF16" ], "implementations": [] }, diff --git a/tests/Vector.cpp b/tests/Vector.cpp index d889bdb..a2a146b 100644 --- a/tests/Vector.cpp +++ b/tests/Vector.cpp @@ -45,17 +45,26 @@ consteval std::array GetCountReverse() { return result; } -template class VectorType, std::uint32_t MaxSize, std::uint32_t Len = 1, std::uint32_t Packing = 1> +template +constexpr 
std::array array_cast(const std::array& src) { + std::array dst{}; + for (std::size_t i = 0; i < N; ++i) { + dst[i] = static_cast(src[i]); + } + return dst; +} + +template class VectorType, std::uint32_t MaxSize, std::uint32_t Len = 1, std::uint32_t Packing = 1> std::string* TestAllCombinations() { if constexpr (Len > MaxSize) { return nullptr; } else if constexpr (Len * Packing > MaxSize) { return TestAllCombinations(); } else { - T floats[VectorType::Alignment]; - T floats1[VectorType::Alignment]; - T floats2[VectorType::Alignment]; - for (std::uint32_t i = 0; i < VectorType::Alignment; i++) { + T floats[VectorType::AlignmentElement]; + T floats1[VectorType::AlignmentElement]; + T floats2[VectorType::AlignmentElement]; + for (std::uint32_t i = 0; i < VectorType::AlignmentElement; i++) { floats[i] = static_cast(i+1); } for (std::uint32_t i = 0; i < Packing*Len; i++) { @@ -64,7 +73,7 @@ std::string* TestAllCombinations() { for (std::uint32_t i = 0; i < Packing*Len; i++) { floats2[i] = static_cast(i+1+Len); } - for (std::uint32_t i = Len*Packing; i < VectorType::Alignment; i++) { + for (std::uint32_t i = Len*Packing; i < VectorType::AlignmentElement; i++) { floats1[i] = 0; floats2[i] = 0; } @@ -81,10 +90,10 @@ std::string* TestAllCombinations() { if constexpr(total > 0 && (total & (total - 1)) == 0) { { VectorType vec(floats); - Vector::Alignment> stored = vec.Store(); + std::array::AlignmentElement> stored = vec.Store(); for (std::uint32_t i = 0; i < Len * Packing; i++) { - if (!FloatEquals(stored.v[i], floats[i])) { - return new std::string(std::format("Load/Store mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i]), (float)stored.v[i])); + if (!FloatEquals(stored[i], floats[i])) { + return new std::string(std::format("Load/Store mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i]), (float)stored[i])); } } } @@ -92,10 +101,10 @@ std::string* TestAllCombinations() { { VectorType vec(floats); vec 
= vec + vec; - Vector::Alignment> stored = vec.Store(); + std::array::AlignmentElement> stored = vec.Store(); for (std::uint32_t i = 0; i < Len * Packing; i++) { - if (!FloatEquals(stored.v[i], floats[i] + floats[i])) { - return new std::string(std::format("Add mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] + floats[i]), (float)stored.v[i])); + if (!FloatEquals(stored[i], floats[i] + floats[i])) { + return new std::string(std::format("Add mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] + floats[i]), (float)stored[i])); } } } @@ -103,10 +112,10 @@ std::string* TestAllCombinations() { { VectorType vec(floats); vec = vec - vec; - Vector::Alignment> stored = vec.Store(); + std::array::AlignmentElement> stored = vec.Store(); for (std::uint32_t i = 0; i < Len * Packing; i++) { - if (!FloatEquals(stored.v[i], T(0))) { - return new std::string(std::format("Subtract mismatch at Len={} Packing={}, Expected: 0, Got: {}", Len, Packing, (float)stored.v[i])); + if (!FloatEquals(stored[i], T(0))) { + return new std::string(std::format("Subtract mismatch at Len={} Packing={}, Expected: 0, Got: {}", Len, Packing, (float)stored[i])); } } } @@ -114,10 +123,10 @@ std::string* TestAllCombinations() { { VectorType vec(floats); vec = vec * vec; - Vector::Alignment> stored = vec.Store(); + std::array::AlignmentElement> stored = vec.Store(); for (std::uint32_t i = 0; i < Len * Packing; i++) { - if (!FloatEquals(stored.v[i], floats[i] * floats[i])) { - return new std::string(std::format("Multiply mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] * floats[i]), (float)stored.v[i])); + if (!FloatEquals(stored[i], floats[i] * floats[i])) { + return new std::string(std::format("Multiply mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] * floats[i]), (float)stored[i])); } } } @@ -125,10 +134,10 @@ std::string* TestAllCombinations() { { VectorType 
vec(floats); vec = vec / vec; - Vector::Alignment> stored = vec.Store(); + std::array::AlignmentElement> stored = vec.Store(); for (std::uint32_t i = 0; i < Len * Packing; i++) { - if (!FloatEquals(stored.v[i], T(1))) { - return new std::string(std::format("Divide mismatch at Len={} Packing={}, Expected: 1, Got: {}", Len, Packing, (float)stored.v[i])); + if (!FloatEquals(stored[i], T(1))) { + return new std::string(std::format("Divide mismatch at Len={} Packing={}, Expected: 1, Got: {}", Len, Packing, (float)stored[i])); } } } @@ -136,10 +145,10 @@ std::string* TestAllCombinations() { { VectorType vec(floats); vec = vec + T(2); - Vector::Alignment> stored = vec.Store(); + std::array::AlignmentElement> stored = vec.Store(); for (std::uint32_t i = 0; i < Len * Packing; i++) { - if (!FloatEquals(stored.v[i], floats[i] + T(2))) { - return new std::string(std::format("Scalar add mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] + T(2)), (float)stored.v[i])); + if (!FloatEquals(stored[i], floats[i] + T(2))) { + return new std::string(std::format("Scalar add mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] + T(2)), (float)stored[i])); } } } @@ -147,10 +156,10 @@ std::string* TestAllCombinations() { { VectorType vec(floats); vec = vec - T(2); - Vector::Alignment> stored = vec.Store(); + std::array::AlignmentElement> stored = vec.Store(); for (std::uint32_t i = 0; i < Len * Packing; i++) { - if (!FloatEquals(stored.v[i], floats[i] - T(2))) { - return new std::string(std::format("Scalar add mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] + T(2)), (float)stored.v[i])); + if (!FloatEquals(stored[i], floats[i] - T(2))) { + return new std::string(std::format("Scalar add mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] + T(2)), (float)stored[i])); } } } @@ -158,10 +167,10 @@ std::string* TestAllCombinations() { { VectorType 
vec(floats); vec = vec * T(2); - Vector::Alignment> stored = vec.Store(); + std::array::AlignmentElement> stored = vec.Store(); for (std::uint32_t i = 0; i < Len * Packing; i++) { - if (!FloatEquals(stored.v[i], floats[i] * T(2))) { - return new std::string(std::format("Scalar multiply mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] * T(2)), (float)stored.v[i])); + if (!FloatEquals(stored[i], floats[i] * T(2))) { + return new std::string(std::format("Scalar multiply mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] * T(2)), (float)stored[i])); } } } @@ -169,10 +178,10 @@ std::string* TestAllCombinations() { { VectorType vec(floats); vec = vec / T(2); - Vector::Alignment> stored = vec.Store(); + std::array::AlignmentElement> stored = vec.Store(); for (std::uint32_t i = 0; i < Len * Packing; i++) { - if (!FloatEquals(stored.v[i], floats[i] / T(2))) { - return new std::string(std::format("Scalar divide mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] * T(2)), (float)stored.v[i])); + if (!FloatEquals(stored[i], floats[i] / T(2))) { + return new std::string(std::format("Scalar divide mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(floats[i] * T(2)), (float)stored[i])); } } } @@ -216,10 +225,10 @@ std::string* TestAllCombinations() { { VectorType vec(floats); vec = -vec; - Vector::Alignment> result = vec.Store(); + std::array::AlignmentElement> result = vec.Store(); for (std::uint32_t i = 0; i < Len * Packing; i++) { - if (!FloatEquals(result.v[i], -floats[i])) { - return new std::string(std::format("Negate mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(-floats[i]), (float)result.v[i])); + if (!FloatEquals(result[i], -floats[i])) { + return new std::string(std::format("Negate mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)(-floats[i]), (float)result[i])); } } } @@ -228,12 
+237,12 @@ std::string* TestAllCombinations() { VectorType vecA(floats1); VectorType vecB(floats2); VectorType result = VectorType::template Blend()>(vecA, vecB); - Vector::Alignment> stored = result.Store(); + std::array::AlignmentElement> stored = result.Store(); for (std::uint32_t i = 0; i < Len; i++) { bool useB = (i % 2 == 0); T expected = useB ? floats2[i]: floats1[i]; - if (!FloatEquals(stored.v[i], expected)) { - return new std::string(std::format("Blend mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored.v[i])); + if (!FloatEquals(stored[i], expected)) { + return new std::string(std::format("Blend mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored[i])); } } } @@ -243,11 +252,11 @@ std::string* TestAllCombinations() { VectorType vecB(floats); VectorType vecAdd(floats); VectorType result = VectorType::MulitplyAdd(vecA, vecB, vecAdd); - Vector::Alignment> stored = result.Store(); + std::array::AlignmentElement> stored = result.Store(); for (std::uint32_t i = 0; i < Len; i++) { T expected = floats[i] * floats[i] + floats[i]; - if (!FloatEquals(stored.v[i], expected)) { - return new std::string(std::format("MulitplyAdd mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored.v[i])); + if (!FloatEquals(stored[i], expected)) { + return new std::string(std::format("MulitplyAdd mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored[i])); } } } @@ -257,11 +266,11 @@ std::string* TestAllCombinations() { VectorType vecB(floats); VectorType vecSub(floats); VectorType result = VectorType::MulitplySub(vecA, vecB, vecSub); - Vector::Alignment> stored = result.Store(); + std::array::AlignmentElement> stored = result.Store(); for (std::uint32_t i = 0; i < Len; i++) { T expected = floats[i] * floats[i] - floats[i]; - if 
(!FloatEquals(stored.v[i], expected)) { - return new std::string(std::format("MulitplySub mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored.v[i])); + if (!FloatEquals(stored[i], expected)) { + return new std::string(std::format("MulitplySub mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored[i])); } } } @@ -269,12 +278,12 @@ std::string* TestAllCombinations() { if constexpr(Len > 2){ VectorType vec(floats); VectorType result = vec.template ExtractLo(); - Vector::Alignment> stored = result.Store(); + std::array::AlignmentElement> stored = result.Store(); for(std::uint32_t i2 = 0; i2 < Packing; i2++){ for (std::uint32_t i = 0; i < Len-1; i++) { T expected = floats[i2*(Len)+i]; - if (!FloatEquals(stored.v[i2*(Len-1)+i], expected)) { - return new std::string(std::format("ExtractLo mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored.v[i2*(Len-1)+i])); + if (!FloatEquals(stored[i2*(Len-1)+i], expected)) { + return new std::string(std::format("ExtractLo mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored[i2*(Len-1)+i])); } } } @@ -283,12 +292,12 @@ std::string* TestAllCombinations() { { VectorType vec(floats); VectorType result = vec.Sin(); - Vector::Alignment> stored = result.Store(); + std::array::AlignmentElement> stored = result.Store(); for(std::uint32_t i2 = 0; i2 < Packing; i2++){ for (std::uint32_t i = 0; i < Len; i++) { T expected = (T)std::sin((float)floats[i2*Len+i]); - if (!FloatEquals(stored.v[i2*Len+i], expected)) { - return new std::string(std::format("Sin mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored.v[i2*(Len-1)+i])); + if (!FloatEquals(stored[i2*Len+i], expected)) { + return new std::string(std::format("Sin mismatch at Len={} Packing={}, Index={}, 
Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored[i2*(Len-1)+i])); } } } @@ -297,12 +306,12 @@ std::string* TestAllCombinations() { { VectorType vec(floats); VectorType result = vec.Cos(); - Vector::Alignment> stored = result.Store(); + std::array::AlignmentElement> stored = result.Store(); for(std::uint32_t i2 = 0; i2 < Packing; i2++){ for (std::uint32_t i = 0; i < Len; i++) { T expected = (T)std::cos((float)floats[i2*Len+i]); - if (!FloatEquals(stored.v[i2*Len+i], expected)) { - return new std::string(std::format("Cos mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored.v[i2*(Len-1)+i])); + if (!FloatEquals(stored[i2*Len+i], expected)) { + return new std::string(std::format("Cos mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored[i2*(Len-1)+i])); } } } @@ -311,18 +320,18 @@ std::string* TestAllCombinations() { { VectorType vec(floats); auto result = vec.SinCos(); - Vector::Alignment> storedSin = std::get<0>(result).Store(); - Vector::Alignment> storedCos = std::get<1>(result).Store(); + std::array::AlignmentElement> storedSin = std::get<0>(result).Store(); + std::array::AlignmentElement> storedCos = std::get<1>(result).Store(); for(std::uint32_t i2 = 0; i2 < Packing; i2++){ for (std::uint32_t i = 0; i < Len; i++) { T expected = (T)std::sin((float)floats[i2*Len+i]); - if (!FloatEquals(storedSin.v[i2*Len+i], expected)) { - return new std::string(std::format("SinCos sin mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)storedSin.v[i2*(Len-1)+i])); + if (!FloatEquals(storedSin[i2*Len+i], expected)) { + return new std::string(std::format("SinCos sin mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)storedSin[i2*(Len-1)+i])); } expected = (T)std::cos((float)floats[i2*Len+i]); - if (!FloatEquals(storedCos.v[i2*Len+i], 
expected)) { - return new std::string(std::format("SinCos cos mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)storedCos.v[i2*(Len-1)+i])); + if (!FloatEquals(storedCos[i2*Len+i], expected)) { + return new std::string(std::format("SinCos cos mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)storedCos[i2*(Len-1)+i])); } } } @@ -331,11 +340,11 @@ std::string* TestAllCombinations() { { VectorType vec(floats); VectorType result = vec.template Shuffle()>(); - Vector::Alignment> stored = result.Store(); + std::array::AlignmentElement> stored = result.Store(); for (std::uint32_t i = 0; i < Len; i++) { T expected = floats[Len - 1 - i]; - if (!FloatEquals(stored.v[i], expected)) { - return new std::string(std::format("Shuffle mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored.v[i])); + if (!FloatEquals(stored[i], expected)) { + return new std::string(std::format("Shuffle mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)expected, (float)stored[i])); } } } @@ -343,7 +352,7 @@ std::string* TestAllCombinations() { if constexpr(Packing == 1) { T expectedLengthSq = T(0); - for (std::uint32_t i = 0; i < VectorType::Alignment; i++) { + for (std::uint32_t i = 0; i < VectorType::AlignmentElement; i++) { expectedLengthSq += floats[i] * floats[i]; } @@ -387,13 +396,13 @@ std::string* TestAllCombinations() { VectorType vec1(floats1); VectorType vec2(floats2); VectorType result = VectorType::Cross(vec1, vec2); - Vector::Alignment> stored = result.Store(); - if (!FloatEquals(stored.v[0], T(-3)) || !FloatEquals(stored.v[1], T(6)) || !FloatEquals(stored.v[2], T(-3))) { - return new std::string(std::format("Cross mismatch at Len={} Packing={}, Expected: -3,6,-3, Got: {},{},{}", Len, Packing, (float)stored.v[0], (float)stored.v[1], (float)stored.v[2])); + 
std::array::AlignmentElement> stored = result.Store(); + if (!FloatEquals(stored[0], T(-3)) || !FloatEquals(stored[1], T(6)) || !FloatEquals(stored[2], T(-3))) { + return new std::string(std::format("Cross mismatch at Len={} Packing={}, Expected: -3,6,-3, Got: {},{},{}", Len, Packing, (float)stored[0], (float)stored[1], (float)stored[2])); } } - if constexpr(4 * Packing < VectorType<1, 1>::MaxSize) { - T qData[VectorType<4, Packing>::Alignment]; + if constexpr(4 * Packing < VectorType<1, 1>::MaxElement) { + T qData[VectorType<4, Packing>::AlignmentElement]; qData[0] = T(0); qData[1] = T(0); qData[2] = T(0); @@ -402,18 +411,18 @@ std::string* TestAllCombinations() { VectorType<3, Packing> vecV(floats); VectorType<4, Packing> vecQ(qData); VectorType<3, Packing> result = VectorType<3, Packing>::Rotate(vecV, vecQ); - Vector::Alignment> stored = result.Store(); + std::array::AlignmentElement> stored = result.Store(); for (std::uint32_t i = 0; i < 3; i++) { - if (!FloatEquals(stored.v[i], floats[i])) { - return new std::string(std::format("Rotate mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)floats[i], (float)stored.v[i])); + if (!FloatEquals(stored[i], floats[i])) { + return new std::string(std::format("Rotate mismatch at Len={} Packing={}, Index={}, Expected: {}, Got: {}", Len, Packing, i, (float)floats[i], (float)stored[i])); } } } } if constexpr(Len == 4) { - T eulerData[VectorType<3, Packing>::Alignment]; + T eulerData[VectorType<3, Packing>::AlignmentElement]; for(std::uint8_t i = 0; i < Packing; i++) { eulerData[i*3] = T(0.7853981); eulerData[i*3+1] = T(0.1243412); @@ -421,27 +430,27 @@ std::string* TestAllCombinations() { } VectorType<3, Packing> eulerVec(eulerData); VectorType<4, Packing> result = VectorType<4, Packing>::QuanternionFromEuler(eulerVec); - Vector::Alignment> stored = result.Store(); + std::array::AlignmentElement> stored = result.Store(); - if (!FloatEquals(stored.v[0], T(0.63720703)) || 
!FloatEquals(stored.v[1], T(0.30688477)) || - !FloatEquals(stored.v[2], T(0.14074707)) || !FloatEquals(stored.v[3], T(0.6933594))) { - return new std::string(std::format("QuanternionFromEuler mismatch at Len={} Packing={}, Expected: 0.63720703,0.30688477,0.14074707,0.6933594, Got: {},{},{},{}", Len, Packing, (float)stored.v[0], (float)stored.v[1], (float)stored.v[2], (float)stored.v[3])); + if (!FloatEquals(stored[0], T(0.63720703)) || !FloatEquals(stored[1], T(0.30688477)) || + !FloatEquals(stored[2], T(0.14074707)) || !FloatEquals(stored[3], T(0.6933594))) { + return new std::string(std::format("QuanternionFromEuler mismatch at Len={} Packing={}, Expected: 0.63720703,0.30688477,0.14074707,0.6933594, Got: {},{},{},{}", Len, Packing, (float)stored[0], (float)stored[1], (float)stored[2], (float)stored[3])); } } - if constexpr(Len == 2 && Packing*Len == VectorType::Alignment) { + if constexpr(Len == 2 && Packing*Len == VectorType::AlignmentElement) { { VectorType vecA(floats); VectorType vecE = vecA *2; VectorType<1, Packing*2> result = VectorType::Length(vecA, vecE); - Vector::Alignment> stored = result.Store(); + std::array::AlignmentElement> stored = result.Store(); - if (!FloatEquals(stored.v[0], expectedLength[0])) { - return new std::string(std::format("Length 2 vecA test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0], (float)stored.v[0])); + if (!FloatEquals(stored[0], expectedLength[0])) { + return new std::string(std::format("Length 2 vecA test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0], (float)stored[0])); } - if (!FloatEquals(stored.v[(Len*Packing)/2], expectedLength[0] * 2)) { - return new std::string(std::format("Length 2 vecE test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 2, (float)stored.v[(Len*Packing)/2])); + if (!FloatEquals(stored[(Len*Packing)/2], expectedLength[0] * 2)) { + return new 
std::string(std::format("Length 2 vecE test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 2, (float)stored[(Len*Packing)/2])); } } @@ -450,39 +459,39 @@ std::string* TestAllCombinations() { VectorType vecE = vecA * 2; auto result = VectorType::Normalize(vecA, vecE); VectorType<1, Packing*2> result2 = VectorType::Length(std::get<0>(result), std::get<1>(result)); - Vector::Alignment> stored = result2.Store(); + std::array::AlignmentElement> stored = result2.Store(); for(std::uint8_t i = 0; i < Len*Packing; i++) { - if (!FloatEquals(stored.v[i], T(1))) { - return new std::string(std::format("Normalize {} test failed at Len={} Packing={} Expected: {}, Got: {}", i, Len, Packing, 1, (float)stored.v[i])); + if (!FloatEquals(stored[i], T(1))) { + return new std::string(std::format("Normalize {} test failed at Len={} Packing={} Expected: {}, Got: {}", i, Len, Packing, 1, (float)stored[i])); } } } } - if constexpr(Len == 4 && Packing*Len == VectorType::Alignment) { + if constexpr(Len == 4 && Packing*Len == VectorType::AlignmentElement) { { VectorType vecA(floats); VectorType vecC = vecA * 2; VectorType vecE = vecA * 3; VectorType vecG = vecA * 4; VectorType<1, Packing*4> result = VectorType::Length(vecA, vecC, vecE, vecG); - Vector::Alignment> stored = result.Store(); + std::array::AlignmentElement> stored = result.Store(); - if (!FloatEquals(stored.v[0], expectedLength[0])) { - return new std::string(std::format("Length 4 vecA test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0], (float)stored.v[0])); + if (!FloatEquals(stored[0], expectedLength[0])) { + return new std::string(std::format("Length 4 vecA test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0], (float)stored[0])); } - if (!FloatEquals(stored.v[Packing], expectedLength[0] * 2)) { - return new std::string(std::format("Length 4 vecC test failed at Len={} Packing={} Expected: {}, Got: 
{}", Len, Packing, (float)expectedLength[0] * 2, (float)stored.v[Packing])); + if (!FloatEquals(stored[Packing], expectedLength[0] * 2)) { + return new std::string(std::format("Length 4 vecC test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 2, (float)stored[Packing])); } - if (!FloatEquals(stored.v[Packing*2], expectedLength[0] * 3)) { - return new std::string(std::format("Length 4 vecE test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 3, (float)stored.v[Packing*2])); + if (!FloatEquals(stored[Packing*2], expectedLength[0] * 3)) { + return new std::string(std::format("Length 4 vecE test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 3, (float)stored[Packing*2])); } - if (!FloatEquals(stored.v[Packing*3], expectedLength[0] * 4)) { - return new std::string(std::format("Length 4 vecG test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 4, (float)stored.v[Packing*3])); + if (!FloatEquals(stored[Packing*3], expectedLength[0] * 4)) { + return new std::string(std::format("Length 4 vecG test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 4, (float)stored[Packing*3])); } } @@ -493,11 +502,11 @@ std::string* TestAllCombinations() { VectorType vecG = vecA * 4; auto result = VectorType::Normalize(vecA, vecC, vecE, vecG); VectorType<1, Packing*4> result2 = VectorType::Length(std::get<0>(result), std::get<1>(result), std::get<2>(result), std::get<3>(result)); - Vector::Alignment> stored = result2.Store(); + std::array::AlignmentElement> stored = result2.Store(); for(std::uint8_t i = 0; i < Len*Packing; i++) { - if (!FloatEquals(stored.v[i], T(1))) { - return new std::string(std::format("Normalize {} test failed at Len={} Packing={} Expected: {}, Got: {}", i, Len, Packing, 1, (float)stored.v[i])); + if (!FloatEquals(stored[i], T(1))) { + return new 
std::string(std::format("Normalize {} test failed at Len={} Packing={} Expected: {}, Got: {}", i, Len, Packing, 1, (float)stored[i])); } } } @@ -509,7 +518,7 @@ std::string* TestAllCombinations() { extern "C" { std::string* RunTest() { - std::string* err = TestAllCombinations<_Float16, VectorF16, VectorF16<1, 1>::MaxSize>(); + std::string* err = TestAllCombinations<_Float16, VectorF16, VectorF16<1, 1>::MaxElement>(); if (err) { return err; }