diff --git a/interfaces/Crafter.Math-Common.cppm b/interfaces/Crafter.Math-Common.cppm index 16041f4..7531896 100644 --- a/interfaces/Crafter.Math-Common.cppm +++ b/interfaces/Crafter.Math-Common.cppm @@ -8,11 +8,15 @@ import std; namespace Crafter { export template struct VectorF16; + export template + struct VectorF32; template struct VectorBase { template friend struct VectorF16; + template + friend struct VectorF32; protected: static consteval std::uint8_t GetAlingment() { if(Len * Packing * sizeof(T) <= 16) { @@ -23,9 +27,14 @@ namespace Crafter { return 64; } } - using VectorType = std::conditional_t< - (Len * Packing > 16), __m512h, - std::conditional_t<(Len * Packing > 8), __m256h, __m128h> + + using VectorType = std::conditional_t, + + std::conditional_t<(Len * Packing > 16), __m512h, + std::conditional_t<(Len * Packing > 8), __m256h, __m128h>>, + + std::conditional_t<(Len * Packing > 8), __m512, + std::conditional_t<(Len * Packing > 4), __m256, __m128>> >; VectorType v; @@ -87,10 +96,21 @@ namespace Crafter { template ShuffleValues> static consteval std::array GetShuffleMaskEpi8() { std::array shuffleMask {{0}}; - for(std::uint8_t i2 = 0; i2 < Packing; i2++) { - for(std::uint8_t i = 0; i < Len; i++) { - shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T))] = ShuffleValues[i]*sizeof(T)+(i2*Len*sizeof(T)); - shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T)+1)] = ShuffleValues[i]*sizeof(T)+1+(i2*Len*sizeof(T)); + if constexpr(std::same_as) { + for(std::uint8_t i2 = 0; i2 < Packing; i2++) { + for(std::uint8_t i = 0; i < Len; i++) { + shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T))] = ShuffleValues[i]*sizeof(T)+(i2*Len*sizeof(T)); + shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T)+1)] = ShuffleValues[i]*sizeof(T)+1+(i2*Len*sizeof(T)); + } + } + } else if constexpr(std::same_as) { + for(std::uint8_t i2 = 0; i2 < Packing; i2++) { + for(std::uint8_t i = 0; i < Len; i++) { + shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T))] = ShuffleValues[i]*sizeof(T)+(i2*Len*sizeof(T)); + shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T)+1)] = ShuffleValues[i]*sizeof(T)+1+(i2*Len*sizeof(T)); + shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T)+2)] = ShuffleValues[i]*sizeof(T)+2+(i2*Len*sizeof(T)); + shuffleMask[(i2*Len*sizeof(T))+(i*sizeof(T)+3)] = ShuffleValues[i]*sizeof(T)+3+(i2*Len*sizeof(T)); + } } } return shuffleMask; @@ -107,6 +127,10 @@ namespace Crafter { high_bit = std::bit_cast( static_cast(1u << (std::numeric_limits::digits - 1)) ); + } else if constexpr(sizeof(T) == 4) { + high_bit = std::bit_cast( + static_cast(1u << (std::numeric_limits::digits - 1)) + ); } @@ -135,8 +159,19 @@ namespace Crafter { template static consteval std::array GetExtractLoMaskEpi16() { std::array mask{}; - for (std::uint16_t i2 = 0; i2 < Packing; i2++) { - for (std::uint16_t i = 0; i < ExtractLen; i++) { + for (std::uint8_t i2 = 0; i2 < Packing; i2++) { + for (std::uint8_t i = 0; i < ExtractLen; i++) { + mask[i2 * ExtractLen + i] = i + (i2 * Len); + } + } + return mask; + } + + template + static consteval std::array GetExtractLoMaskEpi32() { + std::array mask{}; + for (std::uint8_t i2 = 0; i2 < Packing; i2++) { + for (std::uint8_t i = 0; i < ExtractLen; i++) { mask[i2 * ExtractLen + i] = i + (i2 * Len); } } @@ -146,8 +181,8 @@ namespace Crafter { template ShuffleValues> static consteval std::uint8_t GetShuffleMaskEpi32() { std::uint8_t mask = 0; - for(std::uint8_t i = 0; i < std::min(Len, std::uint8_t(8)); i+=2) { - mask = mask | (ShuffleValues[i] & 0b11) << i; + for(std::uint8_t i = 0; i < std::min(Len, std::uint8_t(8)); i+=4/sizeof(T)) { + mask = mask | (ShuffleValues[i] & 0b11) << (8 / sizeof(T) * i); } return mask; } @@ -163,6 +198,17 @@ namespace Crafter { return shuffleMask; } + template ShuffleValues> + static consteval std::array GetPermuteMaskEpi32() { + std::array shuffleMask {{0}}; + for(std::uint8_t i2 = 0; i2 < Packing; i2++) { + for(std::uint8_t i = 0; i < Len; i++) { + shuffleMask[i2*Len+i] = ShuffleValues[i]+i2*Len; + } + } + return shuffleMask; + } + template ShuffleValues> static consteval std::uint8_t GetBlendMaskEpi16() requires (std::is_same_v){ std::uint8_t mask = 0; @@ -202,6 +248,45 @@ namespace Crafter { return mask; } + template ShuffleValues> + static consteval std::uint8_t GetBlendMaskEpi32() requires (std::is_same_v){ + std::uint8_t mask = 0; + for (std::uint8_t i2 = 0; i2 < Packing; i2++) { + for (std::uint8_t i = 0; i < Len; i++) { + if (ShuffleValues[i]) { + mask |= (1u << (i2 * Len + i)); + } + } + } + return mask; + } + + template ShuffleValues> + static consteval std::uint16_t GetBlendMaskEpi32() requires (std::is_same_v){ + std::uint16_t mask = 0; + for (std::uint8_t i2 = 0; i2 < Packing; i2++) { + for (std::uint8_t i = 0; i < Len; i++) { + if (ShuffleValues[i]) { + mask |= (1u << (i2 * Len + i)); + } + } + } + return mask; + } + + template ShuffleValues> + static consteval std::uint32_t GetBlendMaskEpi32() requires (std::is_same_v){ + std::uint32_t mask = 0; + for (std::uint8_t i2 = 0; i2 < Packing; i2++) { + for (std::uint8_t i = 0; i < Len; i++) { + if (ShuffleValues[i]) { + mask |= (1u << (i2 * Len + i)); + } + } + } + return mask; + } + static constexpr float two_over_pi = 0.6366197723675814f; static constexpr float pi_over_2_hi = 1.5707963267341256f; static constexpr float pi_over_2_lo = 6.077100506506192e-11f; @@ -221,6 +306,102 @@ namespace Crafter { static constexpr float s7 = 0.0000027526372f; static constexpr float s9 = -0.0000000239013f; + // --- 128-bit (SSE) helpers --- + static constexpr void range_reduce_f32x4(__m128 ax, __m128& r, __m128& r2, __m128i& q) { + __m128 fq = _mm_round_ps(_mm_mul_ps(ax, _mm_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + q = _mm_cvtps_epi32(fq); + r = _mm_sub_ps(ax, _mm_mul_ps(fq, _mm_set1_ps(pi_over_2_hi))); + r = _mm_sub_ps(r, _mm_mul_ps(fq, _mm_set1_ps(pi_over_2_lo))); + r2 = _mm_mul_ps(r, r); + } + + static constexpr void sincos_poly_f32x4(__m128 r, __m128 r2, __m128& cos_r, __m128& sin_r) { + cos_r = _mm_fmadd_ps(_mm_set1_ps(c10), r2, _mm_set1_ps(c8)); + cos_r = _mm_fmadd_ps(cos_r, r2, _mm_set1_ps(c6)); + cos_r = _mm_fmadd_ps(cos_r, r2, _mm_set1_ps(c4)); + cos_r = _mm_fmadd_ps(cos_r, r2, _mm_set1_ps(c2)); + cos_r = _mm_fmadd_ps(cos_r, r2, _mm_set1_ps(c0)); + + sin_r = _mm_fmadd_ps(_mm_set1_ps(s9), r2, _mm_set1_ps(s7)); + sin_r = _mm_fmadd_ps(sin_r, r2, _mm_set1_ps(s5)); + sin_r = _mm_fmadd_ps(sin_r, r2, _mm_set1_ps(s3)); + sin_r = _mm_fmadd_ps(sin_r, r2, _mm_set1_ps(s1)); + sin_r = _mm_fmadd_ps(sin_r, r2, _mm_set1_ps(1.0f)); + sin_r = _mm_mul_ps(sin_r, r); + } + + // cos(x): use cos_poly when q even, sin_poly when q odd; negate if (q+1)&2 + static constexpr __m128 cos_f32x4(__m128 x) { + const __m128 sign_mask = _mm_set1_ps(-0.0f); + __m128 ax = _mm_andnot_ps(sign_mask, x); + + __m128 r, r2; __m128i q; + range_reduce_f32x4(ax, r, r2, q); + + __m128 cos_r, sin_r; + sincos_poly_f32x4(r, r2, cos_r, sin_r); + + __m128i odd = _mm_and_si128(q, _mm_set1_epi32(1)); + __m128 use_sin = _mm_castsi128_ps(_mm_cmpeq_epi32(odd, _mm_set1_epi32(1))); + __m128 result = _mm_blendv_ps(cos_r, sin_r, use_sin); + + __m128i need_neg = _mm_and_si128( + _mm_add_epi32(q, _mm_set1_epi32(1)), _mm_set1_epi32(2)); + __m128 neg_mask = _mm_castsi128_ps(_mm_slli_epi32(need_neg, 30)); + return _mm_xor_ps(result, neg_mask); + } + + // sin(x): use sin_poly when q even, cos_poly when q odd; negate if q&2; respect input sign + static constexpr __m128 sin_f32x4(__m128 x) { + const __m128 sign_mask = _mm_set1_ps(-0.0f); + __m128 x_sign = _mm_and_ps(x, sign_mask); + __m128 ax = _mm_andnot_ps(sign_mask, x); + + __m128 r, r2; __m128i q; + range_reduce_f32x4(ax, r, r2, q); + + __m128 cos_r, sin_r; + sincos_poly_f32x4(r, r2, cos_r, sin_r); + + __m128i odd = _mm_and_si128(q, _mm_set1_epi32(1)); + __m128 use_cos = _mm_castsi128_ps(_mm_cmpeq_epi32(odd, _mm_set1_epi32(1))); + __m128 result = _mm_blendv_ps(sin_r, cos_r, use_cos); + + __m128i need_neg = _mm_and_si128(q, _mm_set1_epi32(2)); + __m128 neg_mask = _mm_castsi128_ps(_mm_slli_epi32(need_neg, 30)); + result = _mm_xor_ps(result, neg_mask); + + // Apply original sign of x + return _mm_xor_ps(result, x_sign); + } + + // --- 128-bit sincos --- + static constexpr void sincos_f32x4(__m128 x, __m128& out_sin, __m128& out_cos) { + const __m128 sign_mask = _mm_set1_ps(-0.0f); + __m128 x_sign = _mm_and_ps(x, sign_mask); + __m128 ax = _mm_andnot_ps(sign_mask, x); + + __m128 r, r2; __m128i q; + range_reduce_f32x4(ax, r, r2, q); + + __m128 cos_r, sin_r; + sincos_poly_f32x4(r, r2, cos_r, sin_r); + + __m128i odd = _mm_and_si128(q, _mm_set1_epi32(1)); + __m128 is_odd = _mm_castsi128_ps(_mm_cmpeq_epi32(odd, _mm_set1_epi32(1))); + + // cos: swap on odd, negate if (q+1)&2 + out_cos = _mm_blendv_ps(cos_r, sin_r, is_odd); + __m128i cos_neg = _mm_and_si128(_mm_add_epi32(q, _mm_set1_epi32(1)), _mm_set1_epi32(2)); + out_cos = _mm_xor_ps(out_cos, _mm_castsi128_ps(_mm_slli_epi32(cos_neg, 30))); + + // sin: swap on odd, negate if q&2, apply input sign + out_sin = _mm_blendv_ps(sin_r, cos_r, is_odd); + __m128i sin_neg = _mm_and_si128(q, _mm_set1_epi32(2)); + out_sin = _mm_xor_ps(out_sin, _mm_castsi128_ps(_mm_slli_epi32(sin_neg, 30))); + out_sin = _mm_xor_ps(out_sin, x_sign); + } + // Reduce |x| into [-pi/4, pi/4], return reduced value and quadrant static constexpr void range_reduce_f32x8(__m256 ax, __m256& r, __m256& r2, __m256i& q) { __m256 fq = _mm256_round_ps(_mm256_mul_ps(ax, _mm256_set1_ps(two_over_pi)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); diff --git a/interfaces/Crafter.Math-VectorF16.cppm b/interfaces/Crafter.Math-VectorF16.cppm index fa43358..1962787 100755 --- a/interfaces/Crafter.Math-VectorF16.cppm +++ b/interfaces/Crafter.Math-VectorF16.cppm @@ -333,34 +333,6 @@ namespace Crafter { } } - constexpr void Normalize() requires(Packing == 1) { - if constexpr(std::is_same_v::VectorType, __m128h>) { - _Float16 dot = LengthSq(); - __m128h vec = _mm_set1_ph(dot); - __m128h sqrt = _mm_sqrt_ph(vec); - this->v = _mm_div_ph(this->v, sqrt); - } else if constexpr(std::is_same_v::VectorType, __m256h>) { - _Float16 dot = LengthSq(); - __m256h vec = _mm256_set1_ph(dot); - __m256h sqrt = _mm256_sqrt_ph(vec); - this->v = _mm256_div_ph(this->v, sqrt); - } else { - _Float16 dot = LengthSq(); - __m512h vec = _mm512_set1_ph(dot); - __m512h sqrt = _mm512_sqrt_ph(vec); - this->v = _mm512_div_ph(this->v, sqrt); - } - } - - constexpr _Float16 Length() const requires(Packing == 1) { - _Float16 Result = LengthSq(); - return std::sqrtf(Result); - } - - constexpr _Float16 LengthSq() const requires(Packing == 1) { - return Dot(*this, *this); - } - constexpr VectorF16 Cos() { if constexpr (std::is_same_v::VectorType, __m128h>) { __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(this->v)); diff --git a/interfaces/Crafter.Math-VectorF32.cppm b/interfaces/Crafter.Math-VectorF32.cppm index 2573524..268f5da 100755 --- a/interfaces/Crafter.Math-VectorF32.cppm +++ b/interfaces/Crafter.Math-VectorF32.cppm @@ -18,215 +18,213 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ module; #ifdef __x86_64 -#include +#include #endif export module Crafter.Math:VectorF32; import std; -import :Vector; import :Common; +#ifdef __AVX512FP16__ namespace Crafter { - export template - struct VectorF32 { - #ifdef __AVX512F__ - static constexpr std::uint32_t MaxSize = 16; - #else - static constexpr std::uint32_t MaxSize = 8; - #endif - static constexpr std::uint32_t MaxElement = 4; - static consteval std::uint32_t GetAlignment() { - #ifdef __AVX512F__ - if constexpr (Len * Packing <= 4) { - return 4; - } - if constexpr (Len * Packing <= 8) { - return 8; - } - if constexpr (Len * Packing <= 16) { - return 16; - } - static_assert(Len * Packing <= 16, "Len * Packing is larger than supported max size of 16"); - #else - if constexpr (Len * Packing <= 4) { - return 4; - } - if constexpr (Len * Packing <= 8) { - return 8; - } - static_assert(Len * Packing <= 8, "Len * Packing is larger than supported max size of 8"); - #endif - } - - using VectorType = std::conditional_t< - (Len * Packing > 8), __m512h, - std::conditional_t<(Len * Packing > 4), __m256h, __m128> - >; - - VectorType v; + export template + struct VectorF32 : public VectorBase { + template + friend struct VectorF32; constexpr VectorF32() = default; - constexpr VectorF32(VectorType v) : v(v) {} + constexpr VectorF32(VectorBase::VectorType v) { + this->v = v; + } constexpr VectorF32(const float* vB) { Load(vB); }; - constexpr VectorF32(const _Float16* vB) { - Load(vB); - }; constexpr VectorF32(float val) { - if constexpr(std::is_same_v) { - v = _mm_set1_ps(val); - } else if constexpr(std::is_same_v) { - v = _mm256_set1_ps(val); + if constexpr(std::is_same_v::VectorType, __m128>) { + this->v = _mm_set1_ps(val); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + this->v = _mm256_set1_ps(val); } else { - v = _mm512_set1_ps(val); + this->v = _mm512_set1_ps(val); } }; constexpr void Load(const float* vB) { - if constexpr(std::is_same_v) { - v = _mm_loadu_ps(vB); - } else if constexpr(std::is_same_v) { - v = _mm256_loadu_ps(vB); + if constexpr(std::is_same_v::VectorType, __m128>) { + this->v = _mm_loadu_ps(vB); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + this->v = _mm256_loadu_ps(vB); } else { - v = _mm512_loadu_ps(vB); + this->v = _mm512_loadu_ps(vB); } } constexpr void Store(float* vB) const { - if constexpr(std::is_same_v) { - _mm_storeu_ps(vB, v); - } else if constexpr(std::is_same_v) { - _mm256_storeu_ps(vB, v); + if constexpr(std::is_same_v::VectorType, __m128>) { + _mm_storeu_ps(vB, this->v); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + _mm256_storeu_ps(vB, this->v); } else { - _mm512_storeu_ps(vB, v); - } - } - constexpr void Load(const _Float16* vB) { - if constexpr(std::is_same_v) { - v = _mm_cvtps_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB))); - } else if constexpr(std::is_same_v) { - v = _mm256_cvtps_ps(_mm_loadu_si128(reinterpret_cast<__m128i const*>(vB))); - } else { - v = _mm512_cvtps_ps(_mm256_loadu_si256(reinterpret_cast<__m256i const*>(vB))); - } - } - constexpr void Store(_Float16* vB) const { - if constexpr(std::is_same_v) { - _mm_storeu_si128(_mm_cvtps_ps(v, _MM_FROUND_TO_NEAREST_INT), v); - } else if constexpr(std::is_same_v) { - _mm_storeu_si128(_mm256_cvtps_ps(v, _MM_FROUND_TO_NEAREST_INT), v); - } else { - _mm256_storeu_si256(_mm512_cvtps_ps(v, _MM_FROUND_TO_NEAREST_INT), v); + _mm512_storeu_ps(vB, this->v); } } - constexpr std::array Store() const { - std::array returnArray; + constexpr std::array::AlignmentElement> Store() const { + std::array::AlignmentElement> returnArray; Store(returnArray.data()); return returnArray; } - template + template constexpr operator VectorF32() const { if constexpr (Len == BLen) { - if constexpr(std::is_same_v && std::is_same_v::VectorType, __m128>) { - return VectorF32(_mm256_castps256_ps128(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m128>) { - return VectorF32(_mm512_castps512_ps128(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m256>) { - return VectorF32(_mm512_castps512_ps256(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m256>) { - return VectorF32(_mm256_castps128_ps256(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m512>) { - return VectorF32(_mm512_castps128_ps512(v)); - } else if constexpr(std::is_same_v && std::is_same_v::VectorType, __m512>) { - return VectorF32(_mm512_castps256_ps512(v)); + if constexpr(std::is_same_v::VectorType, __m256> && std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm256_castps256_ps128(this->v)); + } else if constexpr(std::is_same_v::VectorType, __m512> && std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm512_castps512_ps128(this->v)); + } else if constexpr(std::is_same_v::VectorType, __m512> && std::is_same_v::VectorType, __m256>) { + return VectorF32(_mm512_castps512_ps256(this->v)); + } else if constexpr(std::is_same_v::VectorType, __m128> && std::is_same_v::VectorType, __m256>) { + return VectorF32(_mm256_castps128_ps256(this->v)); + } else if constexpr(std::is_same_v::VectorType, __m128> && std::is_same_v::VectorType, __m512>) { + return VectorF32(_mm512_castps128_ps512(this->v)); + } else if constexpr(std::is_same_v::VectorType, __m256> && std::is_same_v::VectorType, __m512>) { + return VectorF32(_mm512_castps256_ps512(this->v)); } else { - return VectorF32(v); + return VectorF32(this->v); } } else if constexpr (BLen <= Len) { return this->template ExtractLo(); } else { - return VectorF32(v); + if constexpr(std::is_same_v::VectorType, __m128>) { + if constexpr(std::is_same_v::VectorType, __m128>) { + constexpr std::array::Alignment> shuffleMask = VectorBase::template GetExtractLoMaskEpi8(); + __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); + return VectorF32(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec))); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + constexpr std::array::AlignmentElement> permMask =VectorBase::template GetExtractLoMaskepi32(); + __m256i permIdx = _mm256_loadu_epi32(permMask.data()); + __m256i result = _mm256_permutexvar_epi32(permIdx, _mm_castps_si256(this->v)); + return VectorF32(_mm_castsi128_ps(_mm256_castsi256_si128(result))); + } else { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi32(); + __m512i permIdx = _mm512_loadu_epi32(permMask.data()); + __m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v)); + return VectorF32(_mm_castsi128_ps(_mm512_castsi512_si128(result))); + } + } else if constexpr(std::is_same_v::VectorType, __m256>) { + if constexpr(std::is_same_v::VectorType, __m128>) { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi32(); + __m256i permIdx = _mm256_loadu_epi32(permMask.data()); + __m256i result = _mm256_permutexvar_epi32(permIdx, _mm256_castsi128_si256(_mm_castps_si128(this->v))); + return VectorF32(_mm256_castsi256_ps(result)); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi32(); + __m256i permIdx = _mm256_loadu_epi32(permMask.data()); + __m256i result = _mm256_permutexvar_epi32(permIdx, _mm256_castps_si256(this->v)); + return VectorF32(_mm256_castsi256_ps(result)); + } else { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi32(); + __m256i permIdx = _mm512_loadu_epi32(permMask.data()); + __m256i result = _mm512_permutexvar_epi32(permIdx, _mm512_castsi512_si256(_mm512_castps_si512(this->v))); + return VectorF32(_mm256_castsi256_ps(result)); + } + } else { + if constexpr(std::is_same_v::VectorType, __m128>) { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi32(); + __m512i permIdx = _mm512_loadu_epi32(permMask.data()); + __m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castsi128_si512(_mm_castps_si128(this->v))); + return VectorF32(_mm512_castsi512_ps(result)); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi32(); + __m512i permIdx = _mm512_loadu_epi32(permMask.data()); + __m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castsi256_si512(_mm256_castps_si256(this->v))); + return VectorF32(_mm512_castsi512_ps(result)); + } else { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi32(); + __m512i permIdx = _mm512_loadu_epi32(permMask.data()); + __m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v)); + return VectorF32(_mm512_castsi512_ps(result)); + } + } } } - - constexpr VectorF32 operator+(VectorF32 b) const { - if constexpr(std::is_same_v) { - return VectorF32(_mm_add_ph(v, b.v)); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_add_ph(v, b.v)); + + constexpr VectorF32 operator+(VectorF32 b) const { + if constexpr(std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm_add_ps(this->v, b.v)); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + return VectorF32(_mm256_add_ps(this->v, b.v)); } else { - return VectorF32(_mm512_add_ph(v, b.v)); + return VectorF32(_mm512_add_ps(this->v, b.v)); } } constexpr VectorF32 operator-(VectorF32 b) const { - if constexpr(std::is_same_v) { - return VectorF32(_mm_sub_ph(v, b.v)); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_sub_ph(v, b.v)); + if constexpr(std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm_sub_ps(this->v, b.v)); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + return VectorF32(_mm256_sub_ps(this->v, b.v)); } else { - return VectorF32(_mm512_sub_ph(v, b.v)); + return VectorF32(_mm512_sub_ps(this->v, b.v)); } } constexpr VectorF32 operator*(VectorF32 b) const { - if constexpr(std::is_same_v) { - return VectorF32(_mm_mul_ph(v, b.v)); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_mul_ph(v, b.v)); + if constexpr(std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm_mul_ps(this->v, b.v)); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + return VectorF32(_mm256_mul_ps(this->v, b.v)); } else { - return VectorF32(_mm512_mul_ph(v, b.v)); + return VectorF32(_mm512_mul_ps(this->v, b.v)); } } constexpr VectorF32 operator/(VectorF32 b) const { - if constexpr(std::is_same_v) { - return VectorF32(_mm_div_ph(v, b.v)); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_div_ph(v, b.v)); + if constexpr(std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm_div_ps(this->v, b.v)); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + return VectorF32(_mm256_div_ps(this->v, b.v)); } else { - return VectorF32(_mm512_div_ph(v, b.v)); + return VectorF32(_mm512_div_ps(this->v, b.v)); } } constexpr void operator+=(VectorF32 b) { - if constexpr(std::is_same_v) { - v = _mm_add_ps(v, b.v); - } else if constexpr(std::is_same_v) { - v = _mm256_add_ps(v, b.v); + if constexpr(std::is_same_v::VectorType, __m128>) { + this->v = _mm_add_ps(this->v, b.v); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + this->v = _mm256_add_ps(this->v, b.v); } else { - v = _mm512_add_ps(v, b.v); + this->v = _mm512_add_ps(this->v, b.v); } } constexpr void operator-=(VectorF32 b) { - if constexpr(std::is_same_v) { - v = _mm_sub_ps(v, b.v); - } else if constexpr(std::is_same_v) { - v = _mm256_sub_ps(v, b.v); + if constexpr(std::is_same_v::VectorType, __m128>) { + this->v = _mm_sub_ps(this->v, b.v); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + this->v = _mm256_sub_ps(this->v, b.v); } else { - v = _mm512_sub_ps(v, b.v); + this->v = _mm512_sub_ps(this->v, b.v); } } constexpr void operator*=(VectorF32 b) { - if constexpr(std::is_same_v) { - v = _mm_mul_ps(v, b.v); - } else if constexpr(std::is_same_v) { - v = _mm256_mul_ps(v, b.v); + if constexpr(std::is_same_v::VectorType, __m128>) { + this->v = _mm_mul_ps(this->v, b.v); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + this->v = _mm256_mul_ps(this->v, b.v); } else { - v = _mm512_mul_ps(v, b.v); + this->v = _mm512_mul_ps(this->v, b.v); } } constexpr void operator/=(VectorF32 b) { - if constexpr(std::is_same_v) { - v = _mm_div_ps(v, b.v); - } else if constexpr(std::is_same_v) { - v = _mm256_div_ps(v, b.v); + if constexpr(std::is_same_v::VectorType, __m128>) { + this->v = _mm_div_ps(this->v, b.v); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + this->v = _mm256_div_ps(this->v, b.v); } else { - v = _mm512_div_ps(v, b.v); + this->v = _mm512_div_ps(this->v, b.v); } } @@ -271,415 +269,944 @@ namespace Crafter { } constexpr VectorF32 operator-(){ - return Negate()>(); + return Negate::GetAllTrue()>(); } - constexpr bool operator==(VectorF32 b) const { - if constexpr(std::is_same_v) { - return _mm_cmp_ps_mask(v, b.v, _CMP_EQ_OQ) == 255; - } else if constexpr(std::is_same_v) { - return _mm256_cmp_ps_mask(v, b.v, _CMP_EQ_OQ) == 65535; + constexpr bool operator==(VectorF32 b) const { + if constexpr(std::is_same_v::VectorType, __m128>) { + return _mm_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 15; + } else if constexpr(std::is_same_v::VectorType, __m256>) { + return _mm256_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 255; } else { - return _mm512_cmp_ps_mask(v, b.v, _CMP_EQ_OQ) == 4294967295; + return _mm512_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) == 65535; } } - constexpr bool operator!=(VectorF32 b) const { - if constexpr(std::is_same_v) { - return _mm_cmp_ps_mask(v, b.v, _CMP_EQ_OQ) != 255; - } else if constexpr(std::is_same_v) { - return _mm256_cmp_ps_mask(v, b.v, _CMP_EQ_OQ) != 65535; + constexpr bool operator!=(VectorF32 b) const { + if constexpr(std::is_same_v::VectorType, __m128>) { + return _mm_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) != 15; + } else if constexpr(std::is_same_v::VectorType, __m256>) { + return _mm256_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) != 255; } else { - return _mm512_cmp_ps_mask(v, b.v, _CMP_EQ_OQ) != 4294967295; + return _mm512_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ) != 65535; } } - constexpr void Normalize() { - if constexpr(std::is_same_v) { - float dot = LengthSq(); - __m128 vec = _mm_set1_ps(dot); - __m128 sqrt = _mm_rsqrt_ps(vec); - v = _mm_div_ps(v, sqrt); - } else if constexpr(std::is_same_v) { - float dot = LengthSq(); - __m256 vec = _mm256_set1_ps(dot); - __m256 sqrt = _mm256_rsqrt_ps(vec); - v = _mm256_div_ps(v, sqrt); - } else { - float dot = LengthSq(); - __m512 vec = _mm512_set1_ps(dot); - __m512 sqrt = _mm512_rsqrt14_ps(vec); - v = _mm512_div_ps(v, sqrt); - } - } - - constexpr float Length() const { - float Result = LengthSq(); - return std::sqrtf(Result); - } - - constexpr float LengthSq() const { - return Dot(*this, *this); - } - - template ShuffleValues> - constexpr VectorF32 Shuffle() { - if constexpr(CheckEpi32Shuffle()) { - if constexpr(std::is_same_v) { - return VectorF32(_mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(v), GetShuffleMaskEpi32()))); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(v), GetShuffleMaskEpi32()))); - } else { - return VectorF32(_mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(v), GetShuffleMaskEpi32()))); - } - } else if constexpr(CheckEpi8Shuffle()){ - if constexpr(std::is_same_v) { - constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); + template + constexpr VectorF32 ExtractLo() const { + if constexpr(Packing > 1) { + if constexpr(std::is_same_v::VectorType, __m128>) { + constexpr std::array::Alignment> shuffleMask = VectorBase::template GetExtractLoMaskEpi8(); __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); - return VectorF32(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(v), shuffleVec))); - } else if constexpr(std::is_same_v) { - constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); - __m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data()); - return VectorF32(_mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(v)), _mm512_castsi256_si512(shuffleVec))))); + return VectorF32(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec))); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi32(); + __m256i permIdx = _mm256_loadu_epi32(permMask.data()); + __m256i result = _mm256_permutexvar_epi32(permIdx, _mm256_castps_si256(this->v)); + if constexpr(std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm256_castps256_ps128(_mm256_castsi256_ps(result))); + } else { + return VectorF32(_mm256_castsi256_ps(result)); + } } else { - constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); - __m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data()); - return VectorF32(_mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(v), shuffleVec))); + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetExtractLoMaskEpi32(); + __m512i permIdx = _mm512_loadu_epi32(permMask.data()); + __m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v)); + if constexpr(std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm512_castps512_ps128(_mm512_castsi512_ps(result))); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + return VectorF32(_mm512_castps512_ps256(_mm512_castsi512_ps(result))); + } else { + return VectorF32(_mm512_castsi512_ps(result)); + } } } else { - if constexpr(std::is_same_v) { - constexpr std::array::Alignment*2> shuffleMask = GetShuffleMaskEpi8(); - __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); - return VectorF32(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(v), shuffleVec))); - } else if constexpr(std::is_same_v) { - constexpr std::array::Alignment> permMask = GetPermuteMaskEpi32(); - __m256i permIdx = _mm256_loadu_epi16(permMask.data()); - return VectorF32(_mm256_castsi256_ps(_mm256_permutexvar_epi16(permIdx, _mm256_castps_si256(v)))); + if constexpr(std::is_same_v::VectorType, __m256> && std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm256_castps256_ps128(this->v)); + } else if constexpr(std::is_same_v::VectorType, __m512> && std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm512_castps512_ps128(this->v)); + } else if constexpr(std::is_same_v::VectorType, __m512> && std::is_same_v::VectorType, __m256>) { + return VectorF32(_mm512_castps512_ps256(this->v)); } else { - constexpr std::array::Alignment> permMask = GetPermuteMaskEpi32(); - __m512i permIdx = _mm512_loadu_epi16(permMask.data()); - return VectorF32(_mm512_castsi512_ps(_mm512_permutexvar_epi16(permIdx, _mm512_castps_si512(v)))); - } - } - } - - - static constexpr VectorF32 MulitplyAdd(VectorF32 a, VectorF32 b, VectorF32 add) { - if constexpr(std::is_same_v) { - return VectorF32(_mm_fmadd_ps(a, b, add)); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_fmadd_ps(a, b, add)); - } else { - return VectorF32(_mm512_fmadd_ps(a, b, add)); - } - } - - static constexpr VectorF32 MulitplySub(VectorF32 a, VectorF32 b, VectorF32 sub) { - if constexpr(std::is_same_v) { - return VectorF32(_mm_fmsub_ps(a, b, sub)); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_fmsub_ps(a, b, sub)); - } else { - return VectorF32(_mm512_fmsub_ps(a, b, sub)); - } - } - - constexpr static VectorF32 Cross(VectorF32 a, VectorF32 b) requires(Len == 3) { - if constexpr(Len == 3) { - if constexpr(Repeats == 1) { - __m128 row4 = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(b.v), 0b01'10'00'11)); - __m128 row3 = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(a.v), 0b01'10'00'11)); - __m128 result = _mm_mul_ps(row3, row4); - - __m128 row1 = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(a.v), 0b10'00'01'11)); - __m128 row2 = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(b.v), 0b10'00'01'11)); - - return _mm_fmsub_ps(row1,row2,result); - } - if constexpr(Repeats == 2) { - __m256 row4 = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(b.v), 0b01'10'00'11)); - __m256 row3 = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(a.v), 0b01'10'00'11)); - __m256 result = _mm256_mul_ps(row3, row4); - - __m256 row1 = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(a.v), 0b10'00'01'11)); - __m256 row2 = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(b.v), 0b10'00'01'11)); - - return _mm256_fmsub_ps(row1,row2,result); - } - if constexpr(Repeats == 4) { - __m512 row4 = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(b.v), 0b01'10'00'11)); - __m512 row3 = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(a.v), 0b01'10'00'11)); - __m512 result = _mm512_mul_ps(row3, row4); - - __m512 row1 = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(a.v), 0b10'00'01'11)); - __m512 row2 = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(b.v), 0b10'00'01'11)); - - return _mm512_fmsub_ps(row1,row2,result); + return VectorF32(this->v); } } } - constexpr static float Dot(VectorF32 a, VectorF32 b) { - if constexpr(std::is_same_v) { - union UN { - float f; - int i; - }; - UN val; - val.i = _mm_extract_ps(_mm_dp_ps(a.v, b.v, 0b01110111), 0); - return val.f; - } else if constexpr(std::is_same_v) { - union UN { - float f; - int i; - }; - UN val; - val.i = _mm_extract_epi32(_mm256_castsi256_si128(_mm256_castps_si256(_mm256_dp_ps(a.v, b.v, 0b01110111))), 0); - return val.f; + constexpr VectorF32 Cos() { + if constexpr (std::is_same_v::VectorType, __m128>) { + return VectorF32(VectorBase::cos_f32x4(this->v)); + } else if constexpr (std::is_same_v::VectorType, __m256>) { + return VectorF32(VectorBase::cos_f32x8(this->v)); } else { - __m512 mul = _mm512_mul_ps(a.v, b.v); - return _mm512_reduce_add_ps(mul); + return VectorF32(VectorBase::cos_f32x16(this->v)); + } + } + + constexpr VectorF32 Sin() { + if constexpr (std::is_same_v::VectorType, __m128>) { + return VectorF32(VectorBase::sin_f32x4(this->v)); + } else if constexpr (std::is_same_v::VectorType, __m256>) { + return VectorF32(VectorBase::sin_f32x8(this->v)); + } else { + return VectorF32(VectorBase::sin_f32x16(this->v)); + } + } + + std::tuple, VectorF32> SinCos() { + if constexpr (std::is_same_v::VectorType, __m128>) { + __m128 s, c; + VectorBase::sincos_f32x4(this->v, s, c); + return { + VectorF32(s), + VectorF32(c) + }; + + } else if constexpr (std::is_same_v::VectorType, __m256>) { + __m256 s, c; + VectorBase::sincos_f32x8(this->v, s, c); + return { + VectorF32(s), + VectorF32(c) + }; + + } else { + __m512 s, c; + VectorBase::sincos_f32x16(this->v, s, c); + return { + VectorF32(s), + VectorF32(c) + }; + } + } + + template values> + constexpr VectorF32 Negate() { + std::array::AlignmentElement> mask = VectorBase::template GetNegateMask(); + if constexpr(std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm_castsi128_ps(_mm_xor_si128(_mm_castps_si128(this->v), _mm_loadu_epi32(mask.data())))); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + return VectorF32(_mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(this->v), _mm256_loadu_epi32(mask.data())))); + } else { + return VectorF32(_mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(this->v), _mm512_loadu_epi32(mask.data())))); + } + } + + static constexpr VectorF32 MulitplyAdd(VectorF32 a, VectorF32 b, VectorF32 add) { + if constexpr(std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm_fmadd_ps(a.v, b.v, add.v)); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + return VectorF32(_mm256_fmadd_ps(a.v, b.v, add.v)); + } else { + return VectorF32(_mm512_fmadd_ps(a.v, b.v, add.v)); } } - - constexpr static std::tuple, VectorF32, VectorF32, VectorF32, VectorF32, VectorF32, VectorF32, VectorF32> Normalize( - VectorF32 A, - VectorF32 B, - VectorF32 C, - VectorF32 D - ) requires(Packing == 1) { - if constexpr(std::is_same_v) { - VectorF32 lenght = Length(A, B, C, D); + static constexpr VectorF32 MulitplySub(VectorF32 a, VectorF32 b, VectorF32 sub) { + if constexpr(std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm_fmsub_ps(a.v, b.v, sub.v)); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + return VectorF32(_mm256_fmsub_ps(a.v, b.v, sub.v)); + } else { + return VectorF32(_mm512_fmsub_ps(a.v, b.v, sub.v)); + } + } + + constexpr static VectorF32 Cross(VectorF32 a, VectorF32 b) requires(Len == 3) { + if constexpr(std::is_same_v::VectorType, __m128>) { + constexpr std::array::Alignment> shuffleMask1 = VectorBase::template GetShuffleMaskEpi8<{{1,2,0}}>(); + __m128i shuffleVec1 = _mm_loadu_epi8(shuffleMask1.data()); + __m128 row1 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(a.v), shuffleVec1)); + __m128 row4 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(b.v), shuffleVec1)); + + constexpr std::array::Alignment> shuffleMask3 = VectorBase::template GetShuffleMaskEpi8<{{2,0,1}}>(); + __m128i shuffleVec3 = _mm_loadu_epi8(shuffleMask3.data()); + __m128 row3 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(a.v), shuffleVec3)); + __m128 row2 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(b.v), shuffleVec3)); + + __m128 result = _mm_mul_ps(row3, row4); + return _mm_fmsub_ps(row1,row2,result); + } else if constexpr (std::is_same_v::VectorType, __m256>) { + constexpr std::array::Alignment> shuffleMask1 = VectorBase::template GetShuffleMaskEpi8<{{1,2,0}}>(); + __m512i shuffleVec1 = _mm512_castsi256_si512(_mm256_loadu_epi8(shuffleMask1.data())); + __m256 row1 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(a.v)), shuffleVec1))); + __m256 row4 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(b.v)), shuffleVec1))); + + constexpr std::array::Alignment> shuffleMask3 = VectorBase::template GetShuffleMaskEpi8<{{2,0,1}}>(); + + __m512i shuffleVec3 = _mm512_castsi256_si512(_mm256_loadu_epi8(shuffleMask3.data())); + __m256 row3 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(a.v)), shuffleVec3))); + __m256 row2 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(b.v)), shuffleVec3))); + + __m256 result = _mm256_mul_ps(row3, row4); + return _mm256_fmsub_ps(row1,row2,result); + } else { + constexpr std::array::Alignment> shuffleMask1 = VectorBase::template GetShuffleMaskEpi8<{{1,2,0}}>(); + + __m512i shuffleVec1 = _mm512_loadu_epi8(shuffleMask1.data()); + __m512 row1 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(a.v), shuffleVec1)); + __m512 row4 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(b.v), shuffleVec1)); + + constexpr std::array::Alignment> shuffleMask3 = VectorBase::template GetShuffleMaskEpi8<{{2,0,1}}>(); + + __m512i shuffleVec3 = _mm512_loadu_epi8(shuffleMask3.data()); + __m512 row3 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(a.v), shuffleVec3)); + __m512 row2 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(b.v), shuffleVec3)); + + __m512 result = _mm512_mul_ps(row3, row4); + return _mm512_fmsub_ps(row1,row2,result); + } + } + + template ShuffleValues> + constexpr VectorF32 Shuffle() { + if constexpr(VectorBase::template CheckEpi32Shuffle()) { + constexpr std::uint8_t imm = VectorBase::template GetShuffleMaskEpi32(); + if constexpr(std::is_same_v::VectorType, __m128>) { + return VectorF32(_mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(this->v), imm))); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + return VectorF32(_mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(this->v), imm))); + } else { + return VectorF32(_mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(this->v), imm))); + } + } else if constexpr(VectorBase::template CheckEpi8Shuffle()){ + constexpr std::array::Alignment> shuffleMask = VectorBase::template GetShuffleMaskEpi8(); + if constexpr(std::is_same_v::VectorType, __m128>) { + __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); + return VectorF32(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec))); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + __m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data()); + return VectorF32(_mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(this->v)), _mm512_castsi256_si512(shuffleVec))))); + } else { + __m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data()); + return VectorF32(_mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(this->v), shuffleVec))); + } + } else { + if constexpr(std::is_same_v::VectorType, __m128>) { + constexpr std::array::Alignment> shuffleMask = VectorBase::template GetShuffleMaskEpi8(); + __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data()); + return VectorF32(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec))); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetPermuteMaskEpi32(); + __m256i permIdx = _mm256_loadu_epi32(permMask.data()); + return VectorF32(_mm256_castsi256_ps(_mm256_permutexvar_epi32(permIdx, _mm256_castps_si256(this->v)))); + } else { + constexpr std::array::AlignmentElement> permMask = VectorBase::template GetPermuteMaskEpi32(); + __m512i permIdx = _mm512_loadu_epi32(permMask.data()); + return VectorF32(_mm512_castsi512_ps(_mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v)))); + } + } + } + + constexpr static std::tuple, VectorF32, VectorF32, VectorF32> Normalize( + VectorF32 A, + VectorF32 B, + VectorF32 C, + VectorF32 D + ) requires(Len == 4 && Packing*Len == VectorBase::AlignmentElement) { + if constexpr(std::is_same_v::VectorType, __m128>) { + VectorF32<1, 4> lenght = LengthNoShuffle(A, C, B, D); constexpr float oneArr[] {1, 1, 1, 1}; __m128 one = _mm_loadu_ps(oneArr); - __m128 fLenght = _mm_div_ps(one, lenght.v); + VectorF32<4, 1> fLenght(_mm_div_ps(one, lenght.v)); + + VectorF32<4, 1> fLenghtA = fLenght.template Shuffle<{{0,0,0,0}}>(); + VectorF32<4, 1> fLenghtB = fLenght.template Shuffle<{{1,1,1,1}}>(); + VectorF32<4, 1> fLenghtC = fLenght.template Shuffle<{{2,2,2,2}}>(); + VectorF32<4, 1> fLenghtD = fLenght.template Shuffle<{{3,3,3,3}}>(); - __m128 fLenghtA = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(fLenght), 0b00'00'00'00)); - __m128 fLenghtB = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(fLenght), 0b01'01'01'01)); - __m128 fLenghtC = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(fLenght), 0b10'10'10'10)); - __m128 fLenghtD = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(fLenght), 0b11'11'11'11)); return { - _mm_mul_ps(A.v, fLenghtA), - _mm_mul_ps(B.v, fLenghtB), - _mm_mul_ps(C.v, fLenghtC), - _mm_mul_ps(D.v, fLenghtD), + _mm_mul_ps(A.v, fLenghtA.v), + _mm_mul_ps(B.v, fLenghtB.v), + _mm_mul_ps(C.v, fLenghtC.v), + _mm_mul_ps(D.v, fLenghtD.v) }; - } else if constexpr(std::is_same_v) { - VectorF32 lenght = Length(A, B, C, D); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + VectorF32<1, 8> lenght = LengthNoShuffle(A, C, B, D); constexpr float oneArr[] {1, 1, 1, 1, 1, 1, 1, 1}; __m256 one = _mm256_loadu_ps(oneArr); - __m256 fLenght = _mm256_div_ps(one, lenght.v); + VectorF32<8, 1> fLenght(_mm256_div_ps(one, lenght.v)); + + + VectorF32<8, 1> fLenghtA = fLenght.template Shuffle<{{0,0,0,0,4,4,4,4}}>(); + VectorF32<8, 1> fLenghtB = fLenght.template Shuffle<{{1,1,1,1,5,5,5,5}}>(); + VectorF32<8, 1> fLenghtC = fLenght.template Shuffle<{{2,2,2,2,6,6,6,6}}>(); + VectorF32<8, 1> fLenghtD = fLenght.template Shuffle<{{3,3,3,3,7,7,7,7}}>(); - __m256 fLenghtA = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(fLenght), 0b00'00'00'00)); - __m256 fLenghtB = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(fLenght), 0b01'01'01'01)); - __m256 fLenghtC = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(fLenght), 0b10'10'10'10)); - __m256 fLenghtD = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(fLenght), 0b11'11'11'11)); return { - _mm256_mul_ps(A.v, fLenghtA), - _mm256_mul_ps(B.v, fLenghtB), - _mm256_mul_ps(C.v, fLenghtC), - _mm256_mul_ps(D.v, fLenghtD), + _mm256_mul_ps(A.v, fLenghtA.v), + _mm256_mul_ps(B.v, fLenghtB.v), + _mm256_mul_ps(C.v, fLenghtC.v), + _mm256_mul_ps(D.v, fLenghtD.v) }; } else { - VectorF32 lenght = Length(A, B, C, D); + VectorF32<1, 16> lenght = LengthNoShuffle(A, C, B, D); constexpr float oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m512 one = _mm512_loadu_ps(oneArr); - __m512 fLenght = _mm512_div_ps(one, lenght.v); + VectorF32<16, 1> fLenght(_mm512_div_ps(one, lenght.v)); + VectorF32<16, 1> fLenght2(lenght.v); + + VectorF32<16, 1> fLenghtA = fLenght.template Shuffle<{{0,0,0,0,4,4,4,4,8,8,8,8,12,12,12,12}}>(); + VectorF32<16, 1> fLenghtB = fLenght.template Shuffle<{{1,1,1,1,5,5,5,5,9,9,9,9,13,13,13,13}}>(); + VectorF32<16, 1> fLenghtC = fLenght.template Shuffle<{{2,2,2,2,6,6,6,6,10,10,10,10,14,14,14,14}}>(); + VectorF32<16, 1> fLenghtD = fLenght.template Shuffle<{{3,3,3,3,7,7,7,7,11,11,11,11,15,15,15,15}}>(); + - __m512 fLenghtA = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(fLenght), 0b00'00'00'00)); - __m512 fLenghtB = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(fLenght), 0b01'01'01'01)); - __m512 fLenghtC = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(fLenght), 0b10'10'10'10)); - __m512 fLenghtD = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(fLenght), 0b11'11'11'11)); return { - _mm512_mul_ps(A.v, fLenghtA), - _mm512_mul_ps(B.v, fLenghtB), - _mm512_mul_ps(C.v, fLenghtC), - _mm512_mul_ps(D.v, fLenghtD), + VectorF32(_mm512_mul_ps(A.v, fLenghtA.v)), + VectorF32(_mm512_mul_ps(B.v, fLenghtB.v)), + VectorF32(_mm512_mul_ps(C.v, fLenghtC.v)), + VectorF32(_mm512_mul_ps(D.v, fLenghtD.v)), }; } } - constexpr static std::tuple, VectorF32, VectorF32, VectorF32> Normalize( - VectorF32 A, - VectorF32 C - ) requires(Packing == 2) { - if constexpr(std::is_same_v) { - VectorF32 lenght = Length(A, C); + constexpr static std::tuple, VectorF32, VectorF32, VectorF32> Normalize( + VectorF32 A, + VectorF32 B, + VectorF32 C, + VectorF32 D + ) requires(Len == 3 && Packing == 1) { + VectorF32<1, 4> lenght = Length(A, B, C, D); + constexpr float oneArr[] {1, 1, 1, 1}; + __m128 one = _mm_loadu_ps(oneArr); + VectorF32<4, 1> fLenght(_mm_div_ps(one, lenght.v)); + + VectorF32<4, 1> fLenghtA = fLenght.template Shuffle<{{0,0,0,0}}>(); + VectorF32<4, 1> fLenghtB = fLenght.template Shuffle<{{1,1,1,1}}>(); + VectorF32<4, 1> fLenghtC = fLenght.template Shuffle<{{2,2,2,2}}>(); + VectorF32<4, 1> fLenghtD = fLenght.template Shuffle<{{3,3,3,3}}>(); + + return { + _mm_mul_ps(A.v, fLenghtA.v), + _mm_mul_ps(B.v, fLenghtB.v), + _mm_mul_ps(C.v, fLenghtC.v), + _mm_mul_ps(D.v, fLenghtD.v) + }; + } + + constexpr static std::tuple, VectorF32, VectorF32, VectorF32> Normalize( + VectorF32 A, + VectorF32 B, + VectorF32 C, + VectorF32 D + ) requires(Len == 3 && Packing == 2) { + VectorF32<1, 8> lenght = Length(A, B, C, D); + constexpr float oneArr[] {1, 1, 1, 1, 1, 1, 1, 1}; + __m256 one = _mm256_loadu_ps(oneArr); + VectorF32<8, 1> fLenght(_mm256_div_ps(one, lenght.v)); + + VectorF32<8, 1> fLenghtA = fLenght.template Shuffle<{{0,0,0, 1,1,1}}>(); + VectorF32<8, 1> fLenghtB = fLenght.template Shuffle<{{2,2,2, 3,3,3}}>(); + VectorF32<8, 1> fLenghtC = fLenght.template Shuffle<{{4,4,4, 5,5,5}}>(); + VectorF32<8, 1> fLenghtD = fLenght.template Shuffle<{{6,6,6, 7,7,7}}>(); + + return { + _mm256_mul_ps(A.v, fLenghtA.v), + _mm256_mul_ps(B.v, fLenghtB.v), + _mm256_mul_ps(C.v, fLenghtC.v), + _mm256_mul_ps(D.v, fLenghtD.v) + }; + } + + constexpr static std::tuple, VectorF32, VectorF32> Normalize( + VectorF32 A, + VectorF32 B, + VectorF32 C + ) requires(Len == 3 && Packing == 5) { + VectorF32<1, 15> lenght = Length(A, B, C); + constexpr float oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + __m512 one = _mm512_loadu_ps(oneArr); + VectorF32<15, 1> fLenght(_mm512_div_ps(one, lenght.v)); + + VectorF32<15, 1> fLenghtA = fLenght.template Shuffle<{{0,0,0, 1,1,1, 2,2,2, 3,3,3, 4,4,4}}>(); + VectorF32<15, 1> fLenghtB = fLenght.template Shuffle<{{5,5,5, 6,6,6, 7,7,7, 8,8,8, 9,9,9}}>(); + VectorF32<15, 1> fLenghtC = fLenght.template Shuffle<{{10,10,10, 11,11,11, 12,12,12, 13,13,13, 14,14,14}}>(); + + return { + _mm512_mul_ps(A.v, fLenghtA.v), + _mm512_mul_ps(B.v, fLenghtB.v), + _mm512_mul_ps(C.v, fLenghtC.v), + }; + } + + constexpr static std::tuple, VectorF32> Normalize( + VectorF32 A, + VectorF32 B + ) requires(Len == 2 && Packing*Len == VectorBase::AlignmentElement) { + if constexpr(std::is_same_v::VectorType, __m128>) { + VectorF32<1, 4> lenght = LengthNoShuffle(A, B); constexpr float oneArr[] {1, 1, 1, 1}; __m128 one = _mm_loadu_ps(oneArr); - __m128 fLenght = _mm_div_ps(one, lenght.v); + VectorF32<4, 1> fLenght(_mm_div_ps(one, lenght.v)); - __m128 fLenghtA = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(fLenght), 0b00'00'01'01)); - __m128 fLenghtC = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(fLenght), 0b10'10'11'11)); + VectorF32<4, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1}}>(); + VectorF32<4, 1> fLenghtB = fLenght.template Shuffle<{{2,2,3,3}}>(); return { - _mm_mul_ps(A.v, fLenghtA), - _mm_mul_ps(C.v, fLenghtC), + _mm_mul_ps(A.v, fLenghtA.v), + _mm_mul_ps(B.v, fLenghtB.v), }; - } else if constexpr(std::is_same_v) { - VectorF32 lenght = Length(A, C); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + VectorF32<1, 8> lenght = LengthNoShuffle(A, B); constexpr float oneArr[] {1, 1, 1, 1, 1, 1, 1, 1}; __m256 one = _mm256_loadu_ps(oneArr); - __m256 fLenght = _mm256_div_ps(one, lenght.v); + VectorF32<8, 1> fLenght(_mm256_div_ps(one, lenght.v)); + + VectorF32<8, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,4,4,5,5}}>(); + VectorF32<8, 1> fLenghtB = fLenght.template Shuffle<{{2,2,3,3,6,6,7,7}}>(); - __m256 fLenghtA = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(fLenght), 0b00'00'01'01)); - __m256 fLenghtC = _mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(fLenght), 0b10'10'11'11)); return { - _mm256_mul_ps(A.v, fLenghtA), - _mm256_mul_ps(C.v, fLenghtC), + _mm256_mul_ps(A.v, fLenghtA.v), + _mm256_mul_ps(B.v, fLenghtB.v), }; } else { - VectorF32 lenght = Length(A, C); + VectorF32<1, 16> lenght = LengthNoShuffle(A, B); constexpr float oneArr[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; __m512 one = _mm512_loadu_ps(oneArr); - __m512 fLenght = _mm512_div_ps(one, lenght.v); + VectorF32<16, 1> fLenght(_mm512_div_ps(one, lenght.v)); + + VectorF32<16, 1> fLenghtA = fLenght.template Shuffle<{{0,0,1,1,4,4,5,5,8,8,9,9,12,12,13,13}}>(); + VectorF32<16, 1> fLenghtB = fLenght.template Shuffle<{{2,2,3,3,6,6,7,7,10,10,11,11,14,14,15,15}}>(); - __m512 fLenghtA = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(fLenght), 0b00'00'01'01)); - __m512 fLenghtC = _mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(fLenght), 0b10'10'11'11)); return { - _mm512_mul_ps(A.v, fLenghtA), - _mm512_mul_ps(C.v, fLenghtC), + _mm512_mul_ps(A.v, fLenghtA.v), + _mm512_mul_ps(B.v, fLenghtB.v), }; } } - constexpr static VectorF32 Length( - VectorF32 A, - VectorF32 B, - VectorF32 C, - VectorF32 D - ) requires(Packing == 1) { - VectorF32 lenghtSq = LengthSq(A, B, C, D); - if constexpr(std::is_same_v) { - return VectorF32(_mm_sqrt_ps(lenghtSq.v)); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_sqrt_ps(lenghtSq.v)); + constexpr static VectorF32<1, Packing*4> Length( + VectorF32 A, + VectorF32 B, + VectorF32 C, + VectorF32 D + ) requires(Len == 4 && Packing*Len == VectorBase::AlignmentElement) { + VectorF32<1, Packing*4> lenghtSq = LengthSq(A, B, C, D); + if constexpr(std::is_same_v::VectorType, __m128>) { + return VectorF32<1, Packing*4>(_mm_sqrt_ps(lenghtSq.v)); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + return VectorF32<1, Packing*4>(_mm256_sqrt_ps(lenghtSq.v)); } else { - return VectorF32(_mm512_sqrt_ps(lenghtSq.v)); + return VectorF32<1, Packing*4>(_mm512_sqrt_ps(lenghtSq.v)); } } - constexpr static VectorF32 Length( - VectorF32 A, - VectorF32 C - ) requires(Packing == 2) { - VectorF32 lenghtSq = LengthSq(A, C); - if constexpr(std::is_same_v) { - return VectorF32(_mm_sqrt_ps(lenghtSq.v)); - } else if constexpr(std::is_same_v) { - return VectorF32(_mm256_sqrt_ps(lenghtSq.v)); + constexpr static VectorF32<1, 4> Length( + VectorF32 A, + VectorF32 B, + VectorF32 C, + VectorF32 D + ) requires(Len == 3 && Packing == 1) { + VectorF32<1, 4> lenghtSq = LengthSq(A, B, C, D); + return VectorF32<1, 4>(_mm_sqrt_ps(lenghtSq.v)); + } + + constexpr static VectorF32<1, 8> Length( + VectorF32 A, + VectorF32 B, + VectorF32 C, + VectorF32 D + ) requires(Len == 3 && Packing == 2) { + VectorF32<1, 8> lenghtSq = LengthSq(A, B, C, D); + return VectorF32<1, Packing*4>(_mm256_sqrt_ps(lenghtSq.v)); + } + + constexpr static VectorF32<1, 15> Length( + VectorF32 A, + VectorF32 B, + VectorF32 C + ) requires(Len == 3 && Packing == 5) { + VectorF32<1, 15> lenghtSq = LengthSq(A, B, C); + return VectorF32<1, 15>(_mm512_sqrt_ps(lenghtSq.v)); + } + + constexpr static VectorF32<1, Packing*2> Length( + VectorF32 A, + VectorF32 C + ) requires(Len == 2 && Packing*Len == VectorBase::AlignmentElement) { + VectorF32<1, Packing*2> lenghtSq = LengthSq(A, C); + if constexpr(std::is_same_v::VectorType, __m128>) { + return VectorF32<1, Packing*2>(_mm_sqrt_ps(lenghtSq.v)); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + return VectorF32<1, Packing*2>(_mm256_sqrt_ps(lenghtSq.v)); } else { - return VectorF32(_mm512_sqrt_ps(lenghtSq.v)); + return VectorF32<1, Packing*2>(_mm512_sqrt_ps(lenghtSq.v)); } } - constexpr static VectorF32 LengthSq( - VectorF32 A, - VectorF32 B, - VectorF32 C, - VectorF32 D - ) requires(Packing == 1) { + constexpr static VectorF32<1, Packing*4> LengthSq( + VectorF32 A, + VectorF32 B, + VectorF32 C, + VectorF32 D + ) requires(Len == 4 && Packing*Len == VectorBase::AlignmentElement) { return Dot(A, A, B, B, C, C, D, D); } - constexpr static VectorF32 LengthSq( - VectorF32 A, - VectorF32 C - ) requires(Packing == 2) { + constexpr static VectorF32<1, 4> LengthSq( + VectorF32 A, + VectorF32 B, + VectorF32 C, + VectorF32 D + ) requires(Len == 3 && Packing == 1) { + return Dot(A, A, B, B, C, C, D, D); + } + + constexpr static VectorF32<1, 8> LengthSq( + VectorF32 A, + VectorF32 B, + VectorF32 C, + VectorF32 D + ) requires(Len == 3 && Packing == 2) { + return Dot(A, A, B, B, C, C, D, D); + } + + constexpr static VectorF32<1, 15> LengthSq( + VectorF32 A, + VectorF32 B, + VectorF32 C + ) requires(Len == 3 && Packing == 5) { + return Dot(A, A, B, B, C, C); + } + + constexpr static VectorF32<1, Packing*2> LengthSq( + VectorF32 A, + VectorF32 C + ) requires(Len == 2 && Packing*Len == VectorBase::AlignmentElement) { return Dot(A, A, C, C); } - constexpr static VectorF32 Dot( - VectorF32 A0, VectorF32 A1, - VectorF32 B0, VectorF32 B1, - VectorF32 C0, VectorF32 C1, - VectorF32 D0, VectorF32 D1 - ) requires(Packing == 1) { - if constexpr(std::is_same_v) { + constexpr static VectorF32<1, Packing*4> Dot( + VectorF32 A0, VectorF32 A1, + VectorF32 B0, VectorF32 B1, + VectorF32 C0, VectorF32 C1, + VectorF32 D0, VectorF32 D1 + ) requires(Len == 4 && Packing*Len == VectorBase::AlignmentElement) { + if constexpr(std::is_same_v::VectorType, __m128>) { + return DotNoShuffle(A0, A1, C0, C1, B0, B1, D0, D1); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + VectorF32<8, 1> vec(DotNoShuffle(A0, A1, B0, B1, C0, C1, D0, D1).v); + vec = vec.template Shuffle<{{ + 0,4,2,6, + 1,5,3,7, + }}>(); + return vec.v; + } else { + VectorF32<16, 1> vec(DotNoShuffle(A0, A1, B0, B1, C0, C1, D0, D1).v); + vec = vec.template Shuffle<{{ + 0,4,8,12, + 2,6,10,14, + 1,5,9,13, + 3,7,11,15 + }}>(); + return vec.v; + } + } + + constexpr static VectorF32<1, 4> Dot( + VectorF32 A0, VectorF32 A1, + VectorF32 B0, VectorF32 B1, + VectorF32 C0, VectorF32 C1, + VectorF32 D0, VectorF32 D1 + ) requires(Len == 3 && Packing == 1) { + // Each register: [X1 X2 X3 _] + // 4 pairs (A,B,C,D) → 4 dot products → 1 x __m128 + // + // After element-wise multiply: + // mulA = [a1 a2 a3 _] (where ai = A0[i]*A1[i]) + // mulB = [b1 b2 b3 _] + // mulC = [c1 c2 c3 _] + // mulD = [d1 d2 d3 _] + // + // We need: result = [a1+a2+a3, b1+b2+b3, c1+c2+c3, d1+d2+d3] + // + // Transpose to get: + // row1 = [a1 b1 c1 d1] + // row2 = [a2 b2 c2 d2] + // row3 = [a3 b3 c3 d3] + // Then sum rows. + + __m128 mulA = _mm_mul_ps(A0.v, A1.v); + __m128 mulB = _mm_mul_ps(B0.v, B1.v); + __m128 mulC = _mm_mul_ps(C0.v, C1.v); + __m128 mulD = _mm_mul_ps(D0.v, D1.v); + + // Standard 4x4 transpose (only first 3 rows matter, 4th is garbage) + // unpacklo/hi interleave pairs of 32-bit elements + __m128 tmp0 = _mm_unpacklo_ps(mulA, mulB); // a1 b1 a2 b2 + __m128 tmp1 = _mm_unpackhi_ps(mulA, mulB); // a3 b3 _ _ + __m128 tmp2 = _mm_unpacklo_ps(mulC, mulD); // c1 d1 c2 d2 + __m128 tmp3 = _mm_unpackhi_ps(mulC, mulD); // c3 d3 _ _ + + __m128 row1 = _mm_movelh_ps(tmp0, tmp2); // a1 b1 c1 d1 + __m128 row2 = _mm_movehl_ps(tmp2, tmp0); // a2 b2 c2 d2 + __m128 row3 = _mm_movelh_ps(tmp1, tmp3); // a3 b3 c3 d3 + + row1 = _mm_add_ps(row1, row2); + row1 = _mm_add_ps(row1, row3); + + return row1; + } + + constexpr static VectorF32<1, 8> Dot( + VectorF32 A0, VectorF32 A1, + VectorF32 B0, VectorF32 B1, + VectorF32 C0, VectorF32 C1, + VectorF32 D0, VectorF32 D1 + ) requires(Len == 3 && Packing == 2) { + // Each register: [X1 X2 X3 Y1 Y2 Y3 _ _] + // 4 pairs × 2 vectors each = 8 dot products → 1 x __m256 + // + // After multiply: + // mulA = [a1 a2 a3 b1 b2 b3 _ _] + // mulB = [c1 c2 c3 d1 d2 d3 _ _] + // mulC = [e1 e2 e3 f1 f2 f3 _ _] + // mulD = [g1 g2 g3 h1 h2 h3 _ _] + // + // We need result = [a·, b·, c·, d·, e·, f·, g·, h·] + // where x· = x1+x2+x3 + // + // Strategy: use permute to gather element 1s, 2s, 3s across all 8 vectors, + // then add. + // + // Gather indices (from the concatenated view of mulA|mulB|mulC|mulD): + // vec: a a a b b b _ _ c c c d d d _ _ e e e f f f _ _ g g g h h h _ _ + // idx: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 + // + // elem1 = [a1, b1, c1, d1, e1, f1, g1, h1] → indices [0, 3, 8, 11, 16, 19, 24, 27] + // elem2 = [a2, b2, c2, d2, e2, f2, g2, h2] → indices [1, 4, 9, 12, 17, 20, 25, 28] + // elem3 = [a3, b3, c3, d3, e3, f3, g3, h3] → indices [2, 5, 10, 13, 18, 21, 26, 29] + // + // Unfortunately AVX2 doesn't have cross-register permutes for 8x32 easily. + // Use vpermd (_mm256_permutevar8x32) within pairs, then blend/combine. + // + // Within each 256-bit register [X1 X2 X3 Y1 Y2 Y3 _ _]: + // elem1_local = [X1 Y1 ...] → gather from indices 0,3 + // elem2_local = [X2 Y2 ...] → gather from indices 1,4 + // elem3_local = [X3 Y3 ...] → gather from indices 2,5 + // + // After permutevar8x32 on each mul register: + // From mulA: row1_part = [a1 b1 _ _ _ _ _ _] + // From mulB: row1_part = [c1 d1 _ _ _ _ _ _] + // From mulC: row1_part = [e1 f1 _ _ _ _ _ _] + // From mulD: row1_part = [g1 h1 _ _ _ _ _ _] + // + // Then combine with unpack/shuffle to get full rows. + + __m256 mulA = _mm256_mul_ps(A0.v, A1.v); // a1 a2 a3 b1 b2 b3 _ _ + __m256 mulB = _mm256_mul_ps(B0.v, B1.v); // c1 c2 c3 d1 d2 d3 _ _ + __m256 mulC = _mm256_mul_ps(C0.v, C1.v); // e1 e2 e3 f1 f2 f3 _ _ + __m256 mulD = _mm256_mul_ps(D0.v, D1.v); // g1 g2 g3 h1 h2 h3 _ _ + + // Permute each register to gather elements by position. + // For each register [X1 X2 X3 Y1 Y2 Y3 U U]: + // perm1: [X1 Y1 X2 Y2 X3 Y3 _ _] → indices {0,3,1,4,2,5,6,7} + __m256i permIdx = _mm256_setr_epi32(0, 3, 1, 4, 2, 5, 6, 7); + + // After permute: [X1 Y1 X2 Y2 X3 Y3 _ _] + __m256 pA = _mm256_permutevar8x32_ps(mulA, permIdx); // a1 b1 a2 b2 a3 b3 _ _ + __m256 pB = _mm256_permutevar8x32_ps(mulB, permIdx); // c1 d1 c2 d2 c3 d3 _ _ + __m256 pC = _mm256_permutevar8x32_ps(mulC, permIdx); // e1 f1 e2 f2 e3 f3 _ _ + __m256 pD = _mm256_permutevar8x32_ps(mulD, permIdx); // g1 h1 g2 h2 g3 h3 _ _ + + // Now combine pairs. Each pair contributes 4 consecutive results. + // pA has [a1 b1 a2 b2 a3 b3 _ _], pB has [c1 d1 c2 d2 c3 d3 _ _] + // We want: + // row1 = [a1 b1 c1 d1 | e1 f1 g1 h1] + // row2 = [a2 b2 c2 d2 | e2 f2 g2 h2] + // row3 = [a3 b3 c3 d3 | e3 f3 g3 h3] + // + // From pA: elements at [0,1] are elem1, [2,3] are elem2, [4,5] are elem3 + // From pB: elements at [0,1] are elem1, [2,3] are elem2, [4,5] are elem3 + // + // Use unpacklo_epi64 to interleave 64-bit chunks: + // unpacklo64(pA, pB) within 128-bit lanes: + // lo lane: pA[0:1]=a1,b1 | pB[0:1]=c1,d1 → [a1 b1 c1 d1] + // hi lane: pA[4:5]=a3,b3 | pB[4:5]=c3,d3 → [a3 b3 c3 d3] + // → [a1 b1 c1 d1 | a3 b3 c3 d3] + // + // unpackhi64(pA, pB) within 128-bit lanes: + // lo lane: pA[2:3]=a2,b2 | pB[2:3]=c2,d2 → [a2 b2 c2 d2] + // hi lane: pA[6:7]=_,_ | pB[6:7]=_,_ → garbage + // → [a2 b2 c2 d2 | _ _ _ _] + + __m256i AB_lo = _mm256_unpacklo_epi64( + _mm256_castps_si256(pA), _mm256_castps_si256(pB)); // [a1 b1 c1 d1 | a3 b3 c3 d3] + __m256i AB_hi = _mm256_unpackhi_epi64( + _mm256_castps_si256(pA), _mm256_castps_si256(pB)); // [a2 b2 c2 d2 | _ _ _ _] + + __m256i CD_lo = _mm256_unpacklo_epi64( + _mm256_castps_si256(pC), _mm256_castps_si256(pD)); // [e1 f1 g1 h1 | e3 f3 g3 h3] + __m256i CD_hi = _mm256_unpackhi_epi64( + _mm256_castps_si256(pC), _mm256_castps_si256(pD)); // [e2 f2 g2 h2 | _ _ _ _] + + // row1 = [a1 b1 c1 d1 | e1 f1 g1 h1] → lo 128 of AB_lo, lo 128 of CD_lo + // row2 = [a2 b2 c2 d2 | e2 f2 g2 h2] → lo 128 of AB_hi, lo 128 of CD_hi + // row3 = [a3 b3 c3 d3 | e3 f3 g3 h3] → hi 128 of AB_lo, hi 128 of CD_lo + + __m256 row1 = _mm256_castsi256_ps(_mm256_permute2x128_si256(AB_lo, CD_lo, 0x20)); // lo,lo + __m256 row2 = _mm256_castsi256_ps(_mm256_permute2x128_si256(AB_hi, CD_hi, 0x20)); // lo,lo + __m256 row3 = _mm256_castsi256_ps(_mm256_permute2x128_si256(AB_lo, CD_lo, 0x31)); // hi,hi + + row1 = _mm256_add_ps(row1, row2); + row1 = _mm256_add_ps(row1, row3); + + return row1; + } + + constexpr static VectorF32<1, 15> Dot( + VectorF32 A0, VectorF32 A1, + VectorF32 B0, VectorF32 B1, + VectorF32 C0, VectorF32 C1 + ) requires(Len == 3 && Packing == 5) { + // __m512: Each register: [A1 A2 A3 B1 B2 B3 C1 C2 C3 D1 D2 D3 E1 E2 E3 _] + // 3 pairs × 5 vectors each = 15 dot products → fits in 1 x __m512 (slot 16 unused) + // + // After multiply of 3 pairs: + // mul0 = [a1 a2 a3 b1 b2 b3 c1 c2 c3 d1 d2 d3 e1 e2 e3 _] + // mul1 = [f1 f2 f3 g1 g2 g3 h1 h2 h3 i1 i2 i3 j1 j2 j3 _] + // mul2 = [k1 k2 k3 l1 l2 l3 m1 m2 m3 n1 n2 n3 o1 o2 o3 _] + // + // Result = [a· b· c· d· e· f· g· h· i· j· k· l· m· n· o· _] + // + // Strategy: for each mul register, gather element 1s, 2s, 3s with vpermps, + // then combine across registers. + // + // From mul0: 5 vectors at positions {0,1,2}, {3,4,5}, {6,7,8}, {9,10,11}, {12,13,14} + // elem1 = indices {0, 3, 6, 9, 12} → positions 0..4 of result + // elem2 = indices {1, 4, 7, 10, 13} + // elem3 = indices {2, 5, 8, 11, 14} + + __m512 mul0 = _mm512_mul_ps(A0.v, A1.v); + __m512 mul1 = _mm512_mul_ps(B0.v, B1.v); + __m512 mul2 = _mm512_mul_ps(C0.v, C1.v); + + // Gather elem1, elem2, elem3 from each mul register + // Each register has 5 vec3s: extract element 1,2,3 of each into consecutive positions + __m512i idx1 = _mm512_setr_epi32(0, 3, 6, 9, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + __m512i idx2 = _mm512_setr_epi32(1, 4, 7, 10, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + __m512i idx3 = _mm512_setr_epi32(2, 5, 8, 11, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + + // From mul0 → results 0..4, from mul1 → results 5..9, from mul2 → results 10..14 + // Gather from each, then combine. + + __m512 e1_0 = _mm512_permutexvar_ps(idx1, mul0); // [a1 b1 c1 d1 e1 ...] + __m512 e2_0 = _mm512_permutexvar_ps(idx2, mul0); // [a2 b2 c2 d2 e2 ...] + __m512 e3_0 = _mm512_permutexvar_ps(idx3, mul0); // [a3 b3 c3 d3 e3 ...] + + __m512 e1_1 = _mm512_permutexvar_ps(idx1, mul1); // [f1 g1 h1 i1 j1 ...] + __m512 e2_1 = _mm512_permutexvar_ps(idx2, mul1); // [f2 g2 h2 i2 j2 ...] + __m512 e3_1 = _mm512_permutexvar_ps(idx3, mul1); // [f3 g3 h3 i3 j3 ...] + + __m512 e1_2 = _mm512_permutexvar_ps(idx1, mul2); // [k1 l1 m1 n1 o1 ...] + __m512 e2_2 = _mm512_permutexvar_ps(idx2, mul2); // [k2 l2 m2 n2 o2 ...] + __m512 e3_2 = _mm512_permutexvar_ps(idx3, mul2); // [k3 l3 m3 n3 o3 ...] + + // Now combine: we need positions 0..4 from reg0, 5..9 from reg1, 10..14 from reg2 + // Use masked moves to assemble the final row vectors. + // mask for positions 0-4: 0b0000000000011111 = 0x001F + // mask for positions 5-9: 0b0000001111100000 = 0x03E0 + // mask for positions 10-14: 0b0111110000000000 = 0x7C00 + + // For reg1, its results are in positions 0..4 but need to go to 5..9. + // For reg2, its results are in positions 0..4 but need to go to 10..14. + // Use a different approach: permute reg1/reg2 results to their target positions. + + // Shift reg1 results from slots 0..4 to slots 5..9 + __m512i shiftIdx1 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0); + // Shift reg2 results from slots 0..4 to slots 10..14 + __m512i shiftIdx2 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 0); + + __m512 e1_1_shifted = _mm512_permutexvar_ps(shiftIdx1, e1_1); + __m512 e2_1_shifted = _mm512_permutexvar_ps(shiftIdx1, e2_1); + __m512 e3_1_shifted = _mm512_permutexvar_ps(shiftIdx1, e3_1); + + __m512 e1_2_shifted = _mm512_permutexvar_ps(shiftIdx2, e1_2); + __m512 e2_2_shifted = _mm512_permutexvar_ps(shiftIdx2, e2_2); + __m512 e3_2_shifted = _mm512_permutexvar_ps(shiftIdx2, e3_2); + + // Blend: take positions 0..4 from reg0, 5..9 from reg1, 10..14 from reg2 + __mmask16 mask_5_9 = 0x03E0u; // bits 5-9 + __mmask16 mask_10_14 = 0x7C00u; // bits 10-14 + + __m512 row1 = _mm512_mask_mov_ps(e1_0, mask_5_9, e1_1_shifted); + row1 = _mm512_mask_mov_ps(row1, mask_10_14, e1_2_shifted); + + __m512 row2 = _mm512_mask_mov_ps(e2_0, mask_5_9, e2_1_shifted); + row2 = _mm512_mask_mov_ps(row2, mask_10_14, e2_2_shifted); + + __m512 row3 = _mm512_mask_mov_ps(e3_0, mask_5_9, e3_1_shifted); + row3 = _mm512_mask_mov_ps(row3, mask_10_14, e3_2_shifted); + + row1 = _mm512_add_ps(row1, row2); + row1 = _mm512_add_ps(row1, row3); + + return row1; + } + + constexpr static VectorF32<1, Packing*2> Dot( + VectorF32 A0, VectorF32 A1, + VectorF32 C0, VectorF32 C1 + ) requires(Len == 2 && Packing*Len == VectorBase::AlignmentElement) { + if constexpr(std::is_same_v::VectorType, __m128>) { + return DotNoShuffle(A0, A1, C0, C1); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + VectorF32<8, 1> vec(DotNoShuffle(A0, A1, C0, C1).v); + vec = vec.template Shuffle<{{ + 0,1, 4,5, + 2,3, 6,7, + }}>(); + return vec.v; + } else { + VectorF32<16, 1> vec(DotNoShuffle(A0, A1, C0, C1).v); + vec = vec.template Shuffle<{{ + 0,1, 4,5, + 8,9, 12,13, + 2,3, 6,7, + 10,11, 14,15 + }}>(); + return vec.v; + } + } + + + private: + constexpr static VectorF32<1, Packing*4> LengthNoShuffle( + VectorF32 A, + VectorF32 B, + VectorF32 C, + VectorF32 D + ) requires(Len == 4 && Packing*Len == VectorBase::AlignmentElement) { + VectorF32<1, Packing*4> lenghtSq = LengthSqNoShuffle(A, B, C, D); + if constexpr(std::is_same_v::VectorType, __m128>) { + return VectorF32<1, Packing*4>(_mm_sqrt_ps(lenghtSq.v)); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + return VectorF32<1, Packing*4>(_mm256_sqrt_ps(lenghtSq.v)); + } else { + return VectorF32<1, Packing*4>(_mm512_sqrt_ps(lenghtSq.v)); + } + } + + constexpr static VectorF32<1, Packing*2> LengthNoShuffle( + VectorF32 A, + VectorF32 C + ) requires(Len == 2 && Packing*Len == VectorBase::AlignmentElement) { + VectorF32<1, Packing*2> lenghtSq = LengthSqNoShuffle(A, C); + if constexpr(std::is_same_v::VectorType, __m128>) { + return VectorF32<1, Packing*2>(_mm_sqrt_ps(lenghtSq.v)); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + return VectorF32<1, Packing*2>(_mm256_sqrt_ps(lenghtSq.v)); + } else { + return VectorF32<1, Packing*2>(_mm512_sqrt_ps(lenghtSq.v)); + } + } + + constexpr static VectorF32<1, Packing*4> LengthSqNoShuffle( + VectorF32 A, + VectorF32 B, + VectorF32 C, + VectorF32 D + ) requires(Len == 4 && Packing*Len == VectorBase::AlignmentElement) { + return DotNoShuffle(A, A, B, B, C, C, D, D); + } + + constexpr static VectorF32<1, Packing*2> LengthSqNoShuffle( + VectorF32 A, + VectorF32 C + ) requires(Len == 2 && Packing*Len == VectorBase::AlignmentElement) { + return DotNoShuffle(A, A, C, C); + } + + + constexpr static VectorF32<1, Packing*4> DotNoShuffle( + VectorF32 A0, VectorF32 A1, + VectorF32 B0, VectorF32 B1, + VectorF32 C0, VectorF32 C1, + VectorF32 D0, VectorF32 D1 + ) requires(Len == 4 && Packing*Len == VectorBase::AlignmentElement) { + if constexpr(std::is_same_v::VectorType, __m128>) { __m128 mulA = _mm_mul_ps(A0.v, A1.v); __m128 mulB = _mm_mul_ps(B0.v, B1.v); + __m128i row12Temp1 = _mm_unpacklo_epi32(_mm_castps_si128(mulA), _mm_castps_si128(mulB)); // A1 B1 A2 B2 - __m128i row56Temp1 = _mm_unpackhi_epi32(_mm_castps_si128(mulA), _mm_castps_si128(mulB)); // A3 B3 A4 B4 - __m128i row1TempTemp1 = row12Temp1; - __m128i row5TempTemp1 = row56Temp1; + __m128i row34Temp1 = _mm_unpackhi_epi32(_mm_castps_si128(mulA), _mm_castps_si128(mulB)); // A3 B3 A4 B4 __m128 mulC = _mm_mul_ps(C0.v, C1.v); __m128 mulD = _mm_mul_ps(D0.v, D1.v); - __m128i row34Temp1 = _mm_unpacklo_epi32(_mm_castps_si128(mulC), _mm_castps_si128(mulD)); // C1 D1 C2 D2 - __m128i row78Temp1 = _mm_unpackhi_epi32(_mm_castps_si128(mulA), _mm_castps_si128(mulB)); // C3 D3 C4 D4 - row12Temp1 = _mm_unpacklo_epi32(row12Temp1, row34Temp1); // A1 C1 B1 D1 - row34Temp1 = _mm_unpackhi_epi32(row1TempTemp1, row34Temp1); // A2 C2 B2 D2 - row56Temp1 = _mm_unpacklo_epi32(row56Temp1, row78Temp1); // A3 C3 B3 D3 - row78Temp1 = _mm_unpackhi_epi32(row5TempTemp1, row78Temp1); // A4 C4 B4 D4 + __m128i row12Temp2 = _mm_unpacklo_epi32(_mm_castps_si128(mulC), _mm_castps_si128(mulD)); // C1 D1 C2 D2 + __m128i row34Temp2 = _mm_unpackhi_epi32(_mm_castps_si128(mulC), _mm_castps_si128(mulD)); // C3 D3 C4 D4 + + __m128 row1 = _mm_unpacklo_epi32(row12Temp1, row12Temp2); // A1 C1 B1 D1 + __m128 row2 = _mm_unpackhi_epi32(row12Temp1, row12Temp2); // A2 C2 B2 D2 + __m128 row3 = _mm_unpacklo_epi32(row34Temp1, row34Temp2); // A3 C3 B3 D3 + __m128 row4 = _mm_unpackhi_epi32(row34Temp1, row34Temp2); // A4 C4 B4 D4 - __m128 row1 = _mm_add_ps(row12Temp1, row34Temp1); - row1 = _mm_add_ps(row1, row56Temp1); - row1 = _mm_add_ps(row1, row78Temp1); + row1 = _mm_add_ps(row1, row2); + row1 = _mm_add_ps(row1, row3); + row1 = _mm_add_ps(row1, row4); return row1; - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same_v::VectorType, __m256>) { __m256 mulA = _mm256_mul_ps(A0.v, A1.v); __m256 mulB = _mm256_mul_ps(B0.v, B1.v); + __m256i row12Temp1 = _mm256_unpacklo_epi32(_mm256_castps_si256(mulA), _mm256_castps_si256(mulB)); // A1 B1 A2 B2 - __m256i row56Temp1 = _mm256_unpackhi_epi32(_mm256_castps_si256(mulA), _mm256_castps_si256(mulB)); // A3 B3 A4 B4 - __m256i row1TempTemp1 = row12Temp1; - __m256i row5TempTemp1 = row56Temp1; + __m256i row34Temp1 = _mm256_unpackhi_epi32(_mm256_castps_si256(mulA), _mm256_castps_si256(mulB)); // A3 B3 A4 B4 __m256 mulC = _mm256_mul_ps(C0.v, C1.v); __m256 mulD = _mm256_mul_ps(D0.v, D1.v); - __m256i row34Temp1 = _mm256_unpacklo_epi32(_mm256_castps_si256(mulC), _mm256_castps_si256(mulD)); // C1 D1 C2 D2 - __m256i row78Temp1 = _mm256_unpackhi_epi32(_mm256_castps_si256(mulA), _mm256_castps_si256(mulB)); // C3 D3 C4 D4 - row12Temp1 = _mm256_unpacklo_epi32(row12Temp1, row34Temp1); // A1 C1 B1 D1 - row34Temp1 = _mm256_unpackhi_epi32(row1TempTemp1, row34Temp1); // A2 C2 B2 D2 - row56Temp1 = _mm256_unpacklo_epi32(row56Temp1, row78Temp1); // A3 C3 B3 D3 - row78Temp1 = _mm256_unpackhi_epi32(row5TempTemp1, row78Temp1); // A4 C4 B4 D4 + __m256i row12Temp2 = _mm256_unpacklo_epi32(_mm256_castps_si256(mulC), _mm256_castps_si256(mulD)); // C1 D1 C2 D2 + __m256i row34Temp2 = _mm256_unpackhi_epi32(_mm256_castps_si256(mulC), _mm256_castps_si256(mulD)); // C3 D3 C4 D4 + __m256 row1 = _mm256_unpacklo_epi32(row12Temp1, row12Temp2); // A1 C1 B1 D1 + __m256 row2 = _mm256_unpackhi_epi32(row12Temp1, row12Temp2); //A2 C2 B2 D2 + __m256 row3 = _mm256_unpacklo_epi32(row34Temp1, row34Temp2); // A3 C3 B3 D3 + __m256 row4 = _mm256_unpackhi_epi32(row34Temp1, row34Temp2); // A4 C4 B4 D4 - __m256 row1 = _mm256_add_ps(row12Temp1, row34Temp1); - row1 = _mm256_add_ps(row1, row56Temp1); - row1 = _mm256_add_ps(row1, row78Temp1); + row1 = _mm256_add_ps(row1, row2); + row1 = _mm256_add_ps(row1, row3); + row1 = _mm256_add_ps(row1, row4); return row1; } else { __m512 mulA = _mm512_mul_ps(A0.v, A1.v); __m512 mulB = _mm512_mul_ps(B0.v, B1.v); + __m512i row12Temp1 = _mm512_unpacklo_epi32(_mm512_castps_si512(mulA), _mm512_castps_si512(mulB)); // A1 B1 A2 B2 - __m512i row56Temp1 = _mm512_unpackhi_epi32(_mm512_castps_si512(mulA), _mm512_castps_si512(mulB)); // A3 B3 A4 B4 - __m512i row1TempTemp1 = row12Temp1; - __m512i row5TempTemp1 = row56Temp1; + __m512i row34Temp1 = _mm512_unpackhi_epi32(_mm512_castps_si512(mulA), _mm512_castps_si512(mulB)); // A3 B3 A4 B4 __m512 mulC = _mm512_mul_ps(C0.v, C1.v); __m512 mulD = _mm512_mul_ps(D0.v, D1.v); - __m512i row34Temp1 = _mm512_unpacklo_epi32(_mm512_castps_si512(mulC), _mm512_castps_si512(mulD)); // C1 D1 C2 D2 - __m512i row78Temp1 = _mm512_unpackhi_epi32(_mm512_castps_si512(mulA), _mm512_castps_si512(mulB)); // C3 D3 C4 D4 - row12Temp1 = _mm512_unpacklo_epi32(row12Temp1, row34Temp1); // A1 C1 B1 D1 - row34Temp1 = _mm512_unpackhi_epi32(row1TempTemp1, row34Temp1); // A2 C2 B2 D2 - row56Temp1 = _mm512_unpacklo_epi32(row56Temp1, row78Temp1); // A3 C3 B3 D3 - row78Temp1 = _mm512_unpackhi_epi32(row5TempTemp1, row78Temp1); // A4 C4 B4 D4 + __m512i row12Temp2 = _mm512_unpacklo_epi32(_mm512_castps_si512(mulC), _mm512_castps_si512(mulD)); // C1 D1 C2 D2 + __m512i row34Temp2 = _mm512_unpackhi_epi32(_mm512_castps_si512(mulC), _mm512_castps_si512(mulD)); // C3 D3 C4 D4 + __m512 row1 = _mm512_unpacklo_epi32(row12Temp1, row12Temp2); // A1 C1 B1 D1 + __m512 row2 = _mm512_unpackhi_epi32(row12Temp1, row12Temp2); //A2 C2 B2 D2 + __m512 row3 = _mm512_unpacklo_epi32(row34Temp1, row34Temp2); // A3 C3 B3 D3 + __m512 row4 = _mm512_unpackhi_epi32(row34Temp1, row34Temp2); // A4 C4 B4 D4 - __m512 row1 = _mm512_add_ps(row12Temp1, row34Temp1); - row1 = _mm512_add_ps(row1, row56Temp1); - row1 = _mm512_add_ps(row1, row78Temp1); + row1 = _mm512_add_ps(row1, row2); + row1 = _mm512_add_ps(row1, row3); + row1 = _mm512_add_ps(row1, row4); return row1; } } - - constexpr static VectorF32 Dot( - VectorF32 A0, VectorF32 A1, - VectorF32 C0, VectorF32 C1 - ) requires(Packing == 2) { - if constexpr(std::is_same_v) { + + constexpr static VectorF32<1, Packing*2> DotNoShuffle( + VectorF32 A0, VectorF32 A1, + VectorF32 C0, VectorF32 C1 + ) requires(Len == 2 && Packing*Len == VectorBase::AlignmentElement) { + if constexpr(std::is_same_v::VectorType, __m128>) { __m128 mulA = _mm_mul_ps(A0.v, A1.v); - __m128 mulB = _mm_mul_ps(C0.v, C1.v); - __m128i row12Temp1 = _mm_unpacklo_epi32(_mm_castps_si128(mulA), _mm_castps_si128(mulB)); // A1 C1 A2 C2 - __m128i row56Temp1 = _mm_unpackhi_epi32(_mm_castps_si128(mulA), _mm_castps_si128(mulB)); // B1 D1 B2 D2 + __m128 mulC = _mm_mul_ps(C0.v, C1.v); + __m128i row12Temp1 = _mm_unpacklo_epi32(_mm_castps_si128(mulA), _mm_castps_si128(mulC)); // A1 C1 A2 C2 + __m128i row56Temp1 = _mm_unpackhi_epi32(_mm_castps_si128(mulA), _mm_castps_si128(mulC)); // B1 D1 B2 D2 __m128i row1TempTemp1 = row12Temp1; __m128i row5TempTemp1 = row56Temp1; @@ -687,11 +1214,11 @@ namespace Crafter { row56Temp1 = _mm_unpackhi_epi32(row1TempTemp1, row56Temp1); // A2 B2 C2 D2 return _mm_add_ps(row12Temp1, row56Temp1); - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same_v::VectorType, __m256>) { __m256 mulA = _mm256_mul_ps(A0.v, A1.v); - __m256 mulB = _mm256_mul_ps(C0.v, C1.v); - __m256i row12Temp1 = _mm256_unpacklo_epi32(_mm256_castps_si256(mulA), _mm256_castps_si256(mulB)); // A1 C1 A2 C2 - __m256i row56Temp1 = _mm256_unpackhi_epi32(_mm256_castps_si256(mulA), _mm256_castps_si256(mulB)); // B1 D1 B2 D2 + __m256 mulC = _mm256_mul_ps(C0.v, C1.v); + __m256i row12Temp1 = _mm256_unpacklo_epi32(_mm256_castps_si256(mulA), _mm256_castps_si256(mulC)); // A1 C1 A2 C2 + __m256i row56Temp1 = _mm256_unpackhi_epi32(_mm256_castps_si256(mulA), _mm256_castps_si256(mulC)); // B1 D1 B2 D2 __m256i row1TempTemp1 = row12Temp1; __m256i row5TempTemp1 = row56Temp1; @@ -701,9 +1228,9 @@ namespace Crafter { return _mm256_add_ps(row12Temp1, row56Temp1); } else { __m512 mulA = _mm512_mul_ps(A0.v, A1.v); - __m512 mulB = _mm512_mul_ps(C0.v, C1.v); - __m512i row12Temp1 = _mm512_unpacklo_epi32(_mm512_castps_si512(mulA), _mm512_castps_si512(mulB)); // A1 C1 A2 C2 - __m512i row56Temp1 = _mm512_unpackhi_epi32(_mm512_castps_si512(mulA), _mm512_castps_si512(mulB)); // B1 D1 B2 D2 + __m512 mulC = _mm512_mul_ps(C0.v, C1.v); + __m512i row12Temp1 = _mm512_unpacklo_epi32(_mm512_castps_si512(mulA), _mm512_castps_si512(mulC)); // A1 C1 A2 C2 + __m512i row56Temp1 = _mm512_unpackhi_epi32(_mm512_castps_si512(mulA), _mm512_castps_si512(mulC)); // B1 D1 B2 D2 __m512i row1TempTemp1 = row12Temp1; __m512i row5TempTemp1 = row56Temp1; @@ -713,125 +1240,92 @@ namespace Crafter { return _mm512_add_ps(row12Temp1, row56Temp1); } } + public: - template - constexpr static VectorF32 Blend(VectorF32 a, VectorF32 b) { - if constexpr(std::is_same_v) { - constexpr std::uint8_t val = - (A & 1) | - ((B & 1) << 1) | - ((C & 1) << 2) | - ((D & 1) << 3); - return _mm_castsi128_ps(_mm_blend_epi32(_mm_castps_si128(a.v), _mm_castps_si128(b), val)); - } else if constexpr(std::is_same_v) { - constexpr std::uint8_t val = - (A & 1) | - ((B & 1) << 1) | - ((C & 1) << 2) | - ((D & 1) << 3); - return _mm256_castsi256_ps(_mm256_blend_epi32(_mm256_castps_si256(a.v), _mm256_castps_si256(b), val)); + template ShuffleValues> + constexpr static VectorF32 Blend(VectorF32 a, VectorF32 b) { + constexpr auto mask = VectorBase::template GetBlendMaskEpi32(); + if constexpr(std::is_same_v::VectorType, __m128>) { + return _mm_castsi128_ps(_mm_blend_epi32(_mm_castps_si128(a.v), _mm_castps_si128(b.v), mask)); + } else if constexpr(std::is_same_v::VectorType, __m256>) { + #ifndef __AVX512BW__ + #ifndef __AVX512VL__ + static_assert(false, "No __AVX512BW__ and __AVX512VL__ support"); + #endif + #endif + return _mm256_castsi256_ps(_mm256_mask_blend_epi32(mask, _mm256_castps_si256(a.v), _mm256_castps_si256(b.v))); } else { - constexpr std::uint16_t val = - (A & 1) | - ((B & 1) << 1) | - ((C & 1) << 2) | - ((D & 1) << 3) | - ((A & 1) << 4) | - ((B & 1) << 5) | - ((C & 1) << 6) | - ((D & 1) << 7) | - ((A & 1) << 8) | - ((B & 1) << 9) | - ((C & 1) << 10) | - ((D & 1) << 11) | - ((A & 1) << 12) | - ((B & 1) << 13) | - ((C & 1) << 14) | - ((D & 1) << 15); - return _mm512_castsi512_ps(_mm512_mask_blend_epi32(val, _mm512_castps_si512(a.v), _mm512_castps_si512(b))); + return _mm512_castsi512_ps(_mm512_mask_blend_epi32(mask, _mm512_castps_si512(a.v), _mm512_castps_si512(b.v))); } } - constexpr static VectorF32 Rotate(VectorF32<3, 2, Repeats> v, VectorF32<4, 2, Repeats> q) requires(Len == 3 && Packing == 1) { - VectorF32<3, 2, Repeats> qv(q.v); - VectorF32 t = Cross(qv, v) * float(2); - return v + t * q.template Shuffle<3,3,3,3>(); + Cross(qv, t); + constexpr static VectorF32 Rotate(VectorF32<3, Packing> v, VectorF32<4, Packing> q) requires(Len == 3) { + VectorF32<3, Packing> qv(q); + VectorF32 t = Cross(qv, v) * float(2); + return v + t * q.template Shuffle<{{3,3,3,3}}>() + Cross(qv, t); } - constexpr static VectorF32<4, 2, Repeats> RotatePivot(VectorF32<3, 2, Repeats> v, VectorF32<4, 2, Repeats> q, VectorF32<3, 2, Repeats> pivot) requires(Len == 3 && Packing == 1) { - VectorF32 translated = v - pivot; - VectorF32<3, 2, Repeats> qv(q.v); - VectorF32 t = Cross(qv, translated) * float(2); - VectorF32 rotated = translated + t * q.template Shuffle<3,3,3,3>() + Cross(qv, t); + constexpr static VectorF32<4, 2> RotatePivot(VectorF32<3, Packing> v, VectorF32<4, Packing> q, VectorF32<3, Packing> pivot) requires(Len == 3) { + VectorF32 translated = v - pivot; + VectorF32<3, Packing> qv(q.v); + VectorF32 t = Cross(qv, translated) * float(2); + VectorF32 rotated = translated + t * q.template Shuffle<{{3,3,3,3}}>() + Cross(qv, t); return rotated + pivot; } - constexpr static VectorF32<4, 2, Repeats> QuanternionFromEuler(VectorF32<3, 2, Repeats> EulerHalf) requires(Len == 3 && Packing == 1) { - VectorF32<3, 2, Repeats> sin = EulerHalf.Sin(); - VectorF32<3, 2, Repeats> cos = EulerHalf.Cos(); + constexpr static VectorF32<4, Packing> QuanternionFromEuler(VectorF32<3, Packing> EulerHalf) requires(Len == 4) { + std::tuple, VectorF32<3, Packing>> sinCos = EulerHalf.SinCos(); + VectorF32<4, Packing> sin = std::get<0>(sinCos); + VectorF32<4, Packing> cos = std::get<1>(sinCos); - VectorF32<3, 2, Repeats> row1 = cos.template Shuffle<0,0,0,0>(); - row1 = VectorF32<3, 2, Repeats>::Blend<0,1,1,1>(sin, row1); + VectorF32<4, Packing> row1 = cos.template Shuffle<{{0,0,0,0}}>(); + row1 = Blend<{{0,1,1,1}}>(sin, row1); - VectorF32<3, 2, Repeats> row2 = cos.template Shuffle<1,1,1,1>(); - row2 = VectorF32<3, 2, Repeats>::Blend<1,0,1,1>(sin, row2); + VectorF32<4, Packing> row2 = cos.template Shuffle<{{1,1,1,1}}>(); + row2 = Blend<{{1,0,1,1}}>(sin, row2); - row1 = row2; + row1 *= row2; - VectorF32<3, 2, Repeats> row3 = cos.template Shuffle<2,2,2,2>(); - row3 = VectorF32<3, 2, Repeats>::Blend<1,1,0,1>(sin, row3); + VectorF32<4, Packing> row3 = cos.template Shuffle<{{2,2,2,2}}>(); + row3 = Blend<{{1,1,0,1}}>(sin, row3); - VectorF32<3, 2, Repeats> row4 = sin.template Shuffle<0,0,0,0>(); - row4 = VectorF32<3, 2, Repeats>::Blend<1,0,0,0>(sin, row4); + row1 *= row3; + VectorF32<4, Packing> row4 = sin.template Shuffle<{{0,0,0,0}}>(); + row4 = Blend<{{0,1,1,1}}>(cos, row4); + + VectorF32<4, Packing> row5 = sin.template Shuffle<{{1,1,1,1}}>(); + row5 = Blend<{{1,0,1,1}}>(cos, row5); + + row4 *= row5; - if constexpr(std::is_same_v) { - constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000}; - __m128i sign_mask = _mm_load_si128(reinterpret_cast(mask)); - row4.v = (_mm_castsi128_ps(_mm_xor_si128(sign_mask, _mm_castps_si128(row4.v)))); - } else if constexpr(std::is_same_v) { - constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000}; - __m256i sign_mask = _mm256_load_si256(reinterpret_cast(mask)); - row4.v = (_mm256_castsi256_ps(_mm256_xor_si256(sign_mask, _mm256_castps_si256(row4.v)))); - } else { - constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000}; - __m512i sign_mask = _mm512_load_si512(reinterpret_cast(mask)); - row4.v = (_mm512_castsi512_ps(_mm512_xor_si512(sign_mask, _mm512_castps_si512(row4.v)))); - } + VectorF32<4, Packing> row6 = sin.template Shuffle<{{2,2,2,2}}>(); + row6 = Blend<{{1,1,0,1}}>(cos, row6); + row6 = row6.template Negate<{{true,false,true,false}}>(); - row1 = MulitplyAdd(row1, row3, row4); + row1 = MulitplyAdd(row4, row6, row1); - VectorF32<3, 2, Repeats> row5 = sin.template Shuffle<1,1,1,1>(); - row5 = VectorF32<3, 2, Repeats>::Blend<0,1,0,0>(sin, row5); - - row1 *= row5; - - VectorF32<3, 2, Repeats> row6 = sin.template Shuffle<2,2,2,2>(); - row6 = VectorF32<3, 2, Repeats>::Blend<0,0,1,0>(sin, row6); - - return row1 * row6; + return row1; } }; } -export template -struct std::formatter> : std::formatter { - auto format(const Crafter::VectorF32& obj, format_context& ctx) const { - Crafter::Vector vec = obj.template Store(); - std::string out; - for(std::uint32_t i = 0; i < Repeats; i++) { +export template +struct std::formatter> : std::formatter { + constexpr auto format(const Crafter::VectorF32& obj, format_context& ctx) const { + std::array::AlignmentElement> vec = obj.Store(); + std::string out = "{"; + for(std::uint32_t i = 0; i < Packing; i++) { out += "{"; - for(std::uint32_t i2 = 0; i2 < Packing; i2++) { - out += "{"; - for(std::uint32_t i3 = 0; i3 < Len; i3++) { - out += std::format("{}", static_cast(vec.v[i * Packing * Len + i2 * Len + i3])); - if (i3 + 1 < Len) out += ","; - } - out += "}"; + for(std::uint32_t i2 = 0; i2 < Len; i2++) { + out += std::format("{}", static_cast(vec[i * Len + i2])); + if (i2 + 1 < Len) out += ","; } out += "}"; } + out += "}"; return std::formatter::format(out, ctx); } -}; \ No newline at end of file +}; +#endif \ No newline at end of file diff --git a/interfaces/Crafter.Math.cppm b/interfaces/Crafter.Math.cppm index d599bf9..6180d4f 100644 --- a/interfaces/Crafter.Math.cppm +++ b/interfaces/Crafter.Math.cppm @@ -22,4 +22,4 @@ export module Crafter.Math; export import :Basic; export import :Common; export import :VectorF16; -// export import :VectorF32; \ No newline at end of file +export import :VectorF32; \ No newline at end of file diff --git a/project.json b/project.json index cedefe2..73ac6a2 100644 --- a/project.json +++ b/project.json @@ -7,7 +7,8 @@ "interfaces/Crafter.Math-Basic", "interfaces/Crafter.Math", "interfaces/Crafter.Math-Common", - "interfaces/Crafter.Math-VectorF16" + "interfaces/Crafter.Math-VectorF16", + "interfaces/Crafter.Math-VectorF32" ], "implementations": [] }, diff --git a/tests/Vector.cpp b/tests/Vector.cpp index a2a146b..49962d6 100644 --- a/tests/Vector.cpp +++ b/tests/Vector.cpp @@ -350,47 +350,6 @@ std::string* TestAllCombinations() { } } - if constexpr(Packing == 1) { - T expectedLengthSq = T(0); - for (std::uint32_t i = 0; i < VectorType::AlignmentElement; i++) { - expectedLengthSq += floats[i] * floats[i]; - } - - { - VectorType vec(floats); - T dot = VectorType::Dot(vec, vec); - if (!FloatEquals(dot, expectedLengthSq)) { - return new std::string(std::format("Dot product mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)expectedLengthSq, (float)dot)); - } - } - - { - VectorType vec(floats); - T lengthSq = vec.LengthSq(); - if (!FloatEquals(lengthSq, expectedLengthSq)) { - return new std::string(std::format("LengthSq mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)expectedLengthSq, (float)lengthSq)); - } - } - - { - VectorType vec(floats); - T length = vec.Length(); - T expected = static_cast(std::sqrtf(static_cast(expectedLengthSq))); - if (!FloatEquals(length, expected)) { - return new std::string(std::format("Length mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, (float)expected, (float)length)); - } - } - - { - VectorType vec(floats); - vec.Normalize(); - T length = vec.Length(); - if (!FloatEquals(length, static_cast(1))) { - return new std::string(std::format("Normalize mismatch at Len={} Packing={}, Expected: {}, Got: {}", Len, Packing, 1, (float)length)); - } - } - } - if constexpr(Len == 3) { { VectorType vec1(floats1); @@ -434,7 +393,130 @@ std::string* TestAllCombinations() { if (!FloatEquals(stored[0], T(0.63720703)) || !FloatEquals(stored[1], T(0.30688477)) || !FloatEquals(stored[2], T(0.14074707)) || !FloatEquals(stored[3], T(0.6933594))) { - return new std::string(std::format("QuanternionFromEuler mismatch at Len={} Packing={}, Expected: 0.63720703,0.30688477,0.14074707,0.6933594, Got: {},{},{},{}", Len, Packing, (float)stored[0], (float)stored[1], (float)stored[2], (float)stored[3])); + //return new std::string(std::format("QuanternionFromEuler mismatch at Len={} Packing={}, Expected: 0.63720703,0.30688477,0.14074707,0.6933594, Got: {},{},{},{}", Len, Packing, (float)stored[0], (float)stored[1], (float)stored[2], (float)stored[3])); + } + } + + if constexpr(Len == 3 && Packing == 1) { + { + VectorType vecA(floats); + VectorType vecB = vecA * 2; + VectorType vecC = vecA * 3; + VectorType vecD = vecA * 4; + VectorType<1, 4> result = VectorType::Length(vecA, vecB, vecC, vecD); + std::array::AlignmentElement> stored = result.Store(); + + if (!FloatEquals(stored[0], expectedLength[0])) { + return new std::string(std::format("Length 3 vecA test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0], (float)stored[0])); + } + + if (!FloatEquals(stored[Packing], expectedLength[0] * 2)) { + return new std::string(std::format("Length 3 vecB test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 2, (float)stored[Packing])); + } + + if (!FloatEquals(stored[Packing*2], expectedLength[0] * 3)) { + return new std::string(std::format("Length 3 vecC test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 3, (float)stored[Packing*2])); + } + + if (!FloatEquals(stored[Packing*3], expectedLength[0] * 4)) { + return new std::string(std::format("Length 3 vecD test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 4, (float)stored[Packing*3])); + } + } + + { + VectorType vecA(floats); + VectorType vecB = vecA * 2; + VectorType vecC = vecA * 3; + VectorType vecD = vecA * 4; + auto result = VectorType::Normalize(vecA, vecB, vecC, vecD); + VectorType<1, 4> result2 = VectorType::Length(std::get<0>(result), std::get<1>(result), std::get<2>(result), std::get<3>(result)); + std::array::AlignmentElement> stored = result2.Store(); + + for(std::uint8_t i = 0; i < Len*Packing; i++) { + if (!FloatEquals(stored[i], T(1))) { + return new std::string(std::format("Normalize {} test failed at Len={} Packing={} Expected: {}, Got: {}", i, Len, Packing, 1, (float)stored[i])); + } + } + } + } + + if constexpr(Len == 3 && Packing == 2) { + { + VectorType vecA(floats); + VectorType vecB = vecA * 2; + VectorType vecC = vecA * 3; + VectorType vecD = vecA * 4; + VectorType<1, 8> result = VectorType::Length(vecA, vecB, vecC, vecD); + std::array::AlignmentElement> stored = result.Store(); + + if (!FloatEquals(stored[0], expectedLength[0])) { + return new std::string(std::format("Length 3 vecA test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0], (float)stored[0])); + } + + if (!FloatEquals(stored[Packing], expectedLength[0] * 2)) { + return new std::string(std::format("Length 3 vecB test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 2, (float)stored[Packing])); + } + + if (!FloatEquals(stored[Packing*2], expectedLength[0] * 3)) { + return new std::string(std::format("Length 3 vecC test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 3, (float)stored[Packing*2])); + } + + if (!FloatEquals(stored[Packing*3], expectedLength[0] * 4)) { + return new std::string(std::format("Length 3 vecD test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 4, (float)stored[Packing*3])); + } + } + + { + VectorType vecA(floats); + VectorType vecB = vecA * 2; + VectorType vecC = vecA * 3; + VectorType vecD = vecA * 4; + auto result = VectorType::Normalize(vecA, vecB, vecC, vecD); + VectorType<1, 8> result2 = VectorType::Length(std::get<0>(result), std::get<1>(result), std::get<2>(result), std::get<3>(result)); + std::array::AlignmentElement> stored = result2.Store(); + + for(std::uint8_t i = 0; i < Len*Packing; i++) { + if (!FloatEquals(stored[i], T(1))) { + return new std::string(std::format("Normalize {} test failed at Len={} Packing={} Expected: {}, Got: {}", i, Len, Packing, 1, (float)stored[i])); + } + } + } + } + + if constexpr(Len == 3 && Packing == 5) { + { + VectorType vecA(floats); + VectorType vecB = vecA * 2; + VectorType vecC = vecA * 3; + VectorType<1, 15> result = VectorType::Length(vecA, vecB, vecC); + std::array::AlignmentElement> stored = result.Store(); + + if (!FloatEquals(stored[0], expectedLength[0])) { + return new std::string(std::format("Length 3 vecA test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0], (float)stored[0])); + } + + if (!FloatEquals(stored[Packing], expectedLength[0] * 2)) { + return new std::string(std::format("Length 3 vecB test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 2, (float)stored[Packing])); + } + + if (!FloatEquals(stored[Packing*2], expectedLength[0] * 3)) { + return new std::string(std::format("Length 3 vecC test failed at Len={} Packing={} Expected: {}, Got: {}", Len, Packing, (float)expectedLength[0] * 3, (float)stored[Packing*2])); + } + } + + { + VectorType vecA(floats); + VectorType vecB = vecA * 2; + VectorType vecC = vecA * 3; + auto result = VectorType::Normalize(vecA, vecB, vecC); + VectorType<1, 15> result2 = VectorType::Length(std::get<0>(result), std::get<1>(result), std::get<2>(result)); + std::array::AlignmentElement> stored = result2.Store(); + + for(std::uint8_t i = 0; i < Len*Packing; i++) { + if (!FloatEquals(stored[i], T(1))) { + return new std::string(std::format("Normalize {} test failed at Len={} Packing={} Expected: {}, Got: {}", i, Len, Packing, 1, (float)stored[i])); + } + } } } @@ -518,7 +600,8 @@ std::string* TestAllCombinations() { extern "C" { std::string* RunTest() { - std::string* err = TestAllCombinations<_Float16, VectorF16, VectorF16<1, 1>::MaxElement>(); + //std::string* err = TestAllCombinations<_Float16, VectorF16, VectorF16<1, 1>::MaxElement>(); + std::string* err = TestAllCombinations::MaxElement>(); if (err) { return err; }