Crafter.Math/interfaces/Crafter.Math-VectorF32.cppm
2026-03-31 11:52:13 +02:00

1331 lines
No EOL
75 KiB
C++
Executable file
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
Crafter®.Math
Copyright (C) 2026 Catcrafts®
catcrafts.net
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License version 3.0 as published by the Free Software Foundation;
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
module;
#ifdef __x86_64
#include <immintrin.h>
#endif
export module Crafter.Math:VectorF32;
import std;
import :Common;
#ifdef __AVX512FP16__
namespace Crafter {
export template <std::uint8_t Len, std::uint8_t Packing>
struct VectorF32 : public VectorBase<Len, Packing, float> {
template <std::uint8_t Len2, std::uint8_t Packing2>
friend struct VectorF32;
constexpr VectorF32() = default;
constexpr VectorF32(VectorBase<Len, Packing, float>::VectorType v) {
this->v = v;
}
constexpr VectorF32(const float* vB) {
Load(vB);
};
constexpr VectorF32(float val) {
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
this->v = _mm_set1_ps(val);
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
this->v = _mm256_set1_ps(val);
} else {
this->v = _mm512_set1_ps(val);
}
};
constexpr void Load(const float* vB) {
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
this->v = _mm_loadu_ps(vB);
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
this->v = _mm256_loadu_ps(vB);
} else {
this->v = _mm512_loadu_ps(vB);
}
}
constexpr void Store(float* vB) const {
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
_mm_storeu_ps(vB, this->v);
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
_mm256_storeu_ps(vB, this->v);
} else {
_mm512_storeu_ps(vB, this->v);
}
}
constexpr std::array<float, VectorBase<Len, Packing, float>::AlignmentElement> Store() const {
std::array<float, VectorBase<Len, Packing, float>::AlignmentElement> returnArray;
Store(returnArray.data());
return returnArray;
}
/// Converts to a VectorF32 with a different length and/or packing.
///
/// - Same Len (only packing / register width differs): the register is reused and
///   narrowed or widened with the zero-cost cast intrinsics. Widening casts leave the
///   added upper lanes undefined.
/// - Smaller target Len: delegates to ExtractLo<BLen>().
/// - Larger target Len: lanes are rearranged with a byte shuffle (128-bit source into a
///   128-bit target) or a 32-bit permute, using masks from VectorBase::GetExtractLoMask*.
///
/// NOTE(review): the widening masks come from VectorBase<BLen, Packing, float> (target
/// length, *source* packing), while the target register type comes from
/// VectorBase<BLen, BPacking, float>. Confirm this pairing is intended when
/// Packing != BPacking.
template <std::uint8_t BLen, std::uint8_t BPacking>
constexpr operator VectorF32<BLen, BPacking>() const {
    using Src = typename VectorBase<Len, Packing, float>::VectorType;
    using Dst = typename VectorBase<BLen, BPacking, float>::VectorType;
    if constexpr (Len == BLen) {
        // Each test pairs the source register type with the *target* register type.
        if constexpr (std::is_same_v<Src, __m256> && std::is_same_v<Dst, __m128>) {
            return VectorF32<BLen, BPacking>(_mm256_castps256_ps128(this->v));
        } else if constexpr (std::is_same_v<Src, __m512> && std::is_same_v<Dst, __m128>) {
            return VectorF32<BLen, BPacking>(_mm512_castps512_ps128(this->v));
        } else if constexpr (std::is_same_v<Src, __m512> && std::is_same_v<Dst, __m256>) {
            return VectorF32<BLen, BPacking>(_mm512_castps512_ps256(this->v));
        } else if constexpr (std::is_same_v<Src, __m128> && std::is_same_v<Dst, __m256>) {
            return VectorF32<BLen, BPacking>(_mm256_castps128_ps256(this->v));
        } else if constexpr (std::is_same_v<Src, __m128> && std::is_same_v<Dst, __m512>) {
            return VectorF32<BLen, BPacking>(_mm512_castps128_ps512(this->v));
        } else if constexpr (std::is_same_v<Src, __m256> && std::is_same_v<Dst, __m512>) {
            return VectorF32<BLen, BPacking>(_mm512_castps256_ps512(this->v));
        } else {
            // Identical register types: plain copy.
            return VectorF32<BLen, BPacking>(this->v);
        }
    } else if constexpr (BLen <= Len) {
        return this->template ExtractLo<BLen>();
    } else if constexpr (std::is_same_v<Dst, __m128>) {
        if constexpr (std::is_same_v<Src, __m128>) {
            constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask = VectorBase<Len, Packing, float>::template GetExtractLoMaskEpi8<BLen>();
            __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
            return VectorF32<BLen, BPacking>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec)));
        } else if constexpr (std::is_same_v<Src, __m256>) {
            constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
            __m256i permIdx = _mm256_loadu_epi32(permMask.data());
            __m256i result = _mm256_permutexvar_epi32(permIdx, _mm256_castps_si256(this->v));
            return VectorF32<BLen, BPacking>(_mm_castsi128_ps(_mm256_castsi256_si128(result)));
        } else {
            constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
            __m512i permIdx = _mm512_loadu_epi32(permMask.data());
            __m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v));
            return VectorF32<BLen, BPacking>(_mm_castsi128_ps(_mm512_castsi512_si128(result)));
        }
    } else if constexpr (std::is_same_v<Dst, __m256>) {
        if constexpr (std::is_same_v<Src, __m128>) {
            constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
            __m256i permIdx = _mm256_loadu_epi32(permMask.data());
            // Upper 128 bits of the widened source are undefined (cast intrinsic).
            __m256i result = _mm256_permutexvar_epi32(permIdx, _mm256_castsi128_si256(_mm_castps_si128(this->v)));
            return VectorF32<BLen, BPacking>(_mm256_castsi256_ps(result));
        } else if constexpr (std::is_same_v<Src, __m256>) {
            constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
            __m256i permIdx = _mm256_loadu_epi32(permMask.data());
            __m256i result = _mm256_permutexvar_epi32(permIdx, _mm256_castps_si256(this->v));
            return VectorF32<BLen, BPacking>(_mm256_castsi256_ps(result));
        } else {
            // 512-bit source into a 256-bit target: permute within the low 256 bits using
            // 256-bit index/permute operations so all operand types agree.
            constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
            __m256i permIdx = _mm256_loadu_epi32(permMask.data());
            __m256i result = _mm256_permutexvar_epi32(permIdx, _mm512_castsi512_si256(_mm512_castps_si512(this->v)));
            return VectorF32<BLen, BPacking>(_mm256_castsi256_ps(result));
        }
    } else {
        // 512-bit target: dispatch on the *source* register width.
        if constexpr (std::is_same_v<Src, __m128>) {
            constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
            __m512i permIdx = _mm512_loadu_epi32(permMask.data());
            __m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castsi128_si512(_mm_castps_si128(this->v)));
            return VectorF32<BLen, BPacking>(_mm512_castsi512_ps(result));
        } else if constexpr (std::is_same_v<Src, __m256>) {
            constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
            __m512i permIdx = _mm512_loadu_epi32(permMask.data());
            __m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castsi256_si512(_mm256_castps_si256(this->v)));
            return VectorF32<BLen, BPacking>(_mm512_castsi512_ps(result));
        } else {
            constexpr std::array<std::uint32_t, VectorBase<BLen, Packing, float>::AlignmentElement> permMask = VectorBase<BLen, Packing, float>::template GetExtractLoMaskEpi32<BLen>();
            __m512i permIdx = _mm512_loadu_epi32(permMask.data());
            __m512i result = _mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v));
            return VectorF32<BLen, BPacking>(_mm512_castsi512_ps(result));
        }
    }
}
/// Lane-wise arithmetic between two vectors of identical shape. Every register lane,
/// padding included, is processed. A new vector is returned and *this is untouched.
constexpr VectorF32<Len, Packing> operator+(VectorF32<Len, Packing> b) const {
    using Reg = typename VectorBase<Len, Packing, float>::VectorType;
    if constexpr (std::is_same_v<Reg, __m128>) {
        return _mm_add_ps(this->v, b.v);
    } else if constexpr (std::is_same_v<Reg, __m256>) {
        return _mm256_add_ps(this->v, b.v);
    } else {
        return _mm512_add_ps(this->v, b.v);
    }
}
constexpr VectorF32<Len, Packing> operator-(VectorF32<Len, Packing> b) const {
    using Reg = typename VectorBase<Len, Packing, float>::VectorType;
    if constexpr (std::is_same_v<Reg, __m128>) {
        return _mm_sub_ps(this->v, b.v);
    } else if constexpr (std::is_same_v<Reg, __m256>) {
        return _mm256_sub_ps(this->v, b.v);
    } else {
        return _mm512_sub_ps(this->v, b.v);
    }
}
constexpr VectorF32<Len, Packing> operator*(VectorF32<Len, Packing> b) const {
    using Reg = typename VectorBase<Len, Packing, float>::VectorType;
    if constexpr (std::is_same_v<Reg, __m128>) {
        return _mm_mul_ps(this->v, b.v);
    } else if constexpr (std::is_same_v<Reg, __m256>) {
        return _mm256_mul_ps(this->v, b.v);
    } else {
        return _mm512_mul_ps(this->v, b.v);
    }
}
constexpr VectorF32<Len, Packing> operator/(VectorF32<Len, Packing> b) const {
    using Reg = typename VectorBase<Len, Packing, float>::VectorType;
    if constexpr (std::is_same_v<Reg, __m128>) {
        return _mm_div_ps(this->v, b.v);
    } else if constexpr (std::is_same_v<Reg, __m256>) {
        return _mm256_div_ps(this->v, b.v);
    } else {
        return _mm512_div_ps(this->v, b.v);
    }
}
/// In-place lane-wise arithmetic. Each form evaluates the matching binary operator
/// and stores the resulting register back into *this.
constexpr void operator+=(VectorF32<Len, Packing> b) {
    this->v = (*this + b).v;
}
constexpr void operator-=(VectorF32<Len, Packing> b) {
    this->v = (*this - b).v;
}
constexpr void operator*=(VectorF32<Len, Packing> b) {
    this->v = (*this * b).v;
}
constexpr void operator/=(VectorF32<Len, Packing> b) {
    this->v = (*this / b).v;
}
/// Lane-wise arithmetic against a scalar. The scalar is broadcast to every lane first.
/// The binary forms do not modify *this and are therefore const, so they also work
/// on const vectors.
constexpr VectorF32<Len, Packing> operator+(float b) const {
    return *this + VectorF32<Len, Packing>(b);
}
constexpr VectorF32<Len, Packing> operator-(float b) const {
    return *this - VectorF32<Len, Packing>(b);
}
constexpr VectorF32<Len, Packing> operator*(float b) const {
    return *this * VectorF32<Len, Packing>(b);
}
constexpr VectorF32<Len, Packing> operator/(float b) const {
    return *this / VectorF32<Len, Packing>(b);
}
/// In-place scalar forms: broadcast `b`, then apply the vector compound operator.
constexpr void operator+=(float b) {
    *this += VectorF32<Len, Packing>(b);
}
constexpr void operator-=(float b) {
    *this -= VectorF32<Len, Packing>(b);
}
constexpr void operator*=(float b) {
    *this *= VectorF32<Len, Packing>(b);
}
constexpr void operator/=(float b) {
    *this /= VectorF32<Len, Packing>(b);
}
/// Unary minus: flips the sign of every lane selected by VectorBase::GetAllTrue(),
/// delegating the work to Negate<>.
constexpr VectorF32<Len, Packing> operator-(){
    constexpr auto everyLane = VectorBase<Len, Packing, float>::GetAllTrue();
    return this->template Negate<everyLane>();
}
/// Returns true when every logical lane compares equal. Logical lanes are the first
/// Len * Packing floats, capped at the register width.
/// Lanes that only pad the register are ignored. Their contents (e.g. whatever Load()
/// read past the logical end, or set1 broadcasts) must not make logically equal
/// vectors compare unequal.
/// The compare is ordered (_CMP_EQ_OQ), so a NaN in any logical lane makes the vectors
/// unequal.
constexpr bool operator==(VectorF32<Len, Packing> b) const {
    using Reg = typename VectorBase<Len, Packing, float>::VectorType;
    constexpr unsigned regLanes = sizeof(Reg) / sizeof(float);
    constexpr unsigned logicalLanes = static_cast<unsigned>(Len) * Packing;
    constexpr unsigned checkedLanes = logicalLanes < regLanes ? logicalLanes : regLanes;
    constexpr std::uint32_t laneMask = (std::uint32_t{1} << checkedLanes) - 1u;
    std::uint32_t equalLanes = 0;
    if constexpr (std::is_same_v<Reg, __m128>) {
        equalLanes = _mm_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ);
    } else if constexpr (std::is_same_v<Reg, __m256>) {
        equalLanes = _mm256_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ);
    } else {
        equalLanes = _mm512_cmp_ps_mask(this->v, b.v, _CMP_EQ_OQ);
    }
    return (equalLanes & laneMask) == laneMask;
}
/// Exact logical negation of operator== (NaN lanes therefore make vectors unequal).
constexpr bool operator!=(VectorF32<Len, Packing> b) const {
    return !(*this == b);
}
/// Narrows this vector to a VectorF32<ExtractLen, Packing>.
/// - Packing <= 1: the wanted lanes already sit at the bottom of the register, so it is
///   only truncated with a cast when the target register is narrower.
/// - Packing > 1: lanes are gathered first, with a byte shuffle for 128-bit registers or
///   a 32-bit cross-lane permute for 256/512-bit registers. Masks come from
///   VectorBase::GetExtractLoMask*; presumably they keep the first ExtractLen elements of
///   each packed group (confirm in VectorBase).
template<std::uint32_t ExtractLen>
constexpr VectorF32<ExtractLen, Packing> ExtractLo() const {
    using Base = VectorBase<Len, Packing, float>;
    using Src = typename Base::VectorType;
    using Dst = typename VectorBase<ExtractLen, Packing, float>::VectorType;
    using Out = VectorF32<ExtractLen, Packing>;
    if constexpr (Packing <= 1) {
        if constexpr (std::is_same_v<Src, __m256> && std::is_same_v<Dst, __m128>) {
            return Out(_mm256_castps256_ps128(this->v));
        } else if constexpr (std::is_same_v<Src, __m512> && std::is_same_v<Dst, __m128>) {
            return Out(_mm512_castps512_ps128(this->v));
        } else if constexpr (std::is_same_v<Src, __m512> && std::is_same_v<Dst, __m256>) {
            return Out(_mm512_castps512_ps256(this->v));
        } else {
            return Out(this->v);
        }
    } else if constexpr (std::is_same_v<Src, __m128>) {
        constexpr std::array<std::uint8_t, Base::Alignment> byteMask = Base::template GetExtractLoMaskEpi8<ExtractLen>();
        const __m128i gathered = _mm_shuffle_epi8(_mm_castps_si128(this->v), _mm_loadu_epi8(byteMask.data()));
        return Out(_mm_castsi128_ps(gathered));
    } else if constexpr (std::is_same_v<Src, __m256>) {
        constexpr std::array<std::uint32_t, Base::AlignmentElement> laneMask = Base::template GetExtractLoMaskEpi32<ExtractLen>();
        const __m256 gathered = _mm256_castsi256_ps(_mm256_permutexvar_epi32(_mm256_loadu_epi32(laneMask.data()), _mm256_castps_si256(this->v)));
        if constexpr (std::is_same_v<Dst, __m128>) {
            return Out(_mm256_castps256_ps128(gathered));
        } else {
            return Out(gathered);
        }
    } else {
        constexpr std::array<std::uint32_t, Base::AlignmentElement> laneMask = Base::template GetExtractLoMaskEpi32<ExtractLen>();
        const __m512 gathered = _mm512_castsi512_ps(_mm512_permutexvar_epi32(_mm512_loadu_epi32(laneMask.data()), _mm512_castps_si512(this->v)));
        if constexpr (std::is_same_v<Dst, __m128>) {
            return Out(_mm512_castps512_ps128(gathered));
        } else if constexpr (std::is_same_v<Dst, __m256>) {
            return Out(_mm512_castps512_ps256(gathered));
        } else {
            return Out(gathered);
        }
    }
}
/// Lane-wise cosine computed by VectorBase's cos_f32x4/x8/x16 kernel for the register
/// width. Does not modify *this, so it is const.
/// The register is copied into a local before the call, so the kernel's parameter form
/// never affects this object.
constexpr VectorF32<Len, Packing> Cos() const {
    auto x = this->v;
    if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
        return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::cos_f32x4(x));
    } else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
        return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::cos_f32x8(x));
    } else {
        return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::cos_f32x16(x));
    }
}
/// Lane-wise sine computed by VectorBase's sin_f32x4/x8/x16 kernel for the register
/// width. Does not modify *this, so it is const.
/// The register is copied into a local before the call, so the kernel's parameter form
/// never affects this object.
constexpr VectorF32<Len, Packing> Sin() const {
    auto x = this->v;
    if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
        return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::sin_f32x4(x));
    } else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
        return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::sin_f32x8(x));
    } else {
        return VectorF32<Len, Packing>(VectorBase<Len, Packing, float>::sin_f32x16(x));
    }
}
/// Computes sine and cosine of every lane in one pass with VectorBase's
/// sincos_f32x4/x8/x16 kernel, which writes its results through the `s` and `c`
/// out-parameters. Does not modify *this, so it is const.
/// @return {sin(v), cos(v)} in that order.
std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>> SinCos() const {
    auto x = this->v;
    if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
        __m128 s, c;
        VectorBase<Len, Packing, float>::sincos_f32x4(x, s, c);
        return {
            VectorF32<Len, Packing>(s),
            VectorF32<Len, Packing>(c)
        };
    } else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
        __m256 s, c;
        VectorBase<Len, Packing, float>::sincos_f32x8(x, s, c);
        return {
            VectorF32<Len, Packing>(s),
            VectorF32<Len, Packing>(c)
        };
    } else {
        __m512 s, c;
        VectorBase<Len, Packing, float>::sincos_f32x16(x, s, c);
        return {
            VectorF32<Len, Packing>(s),
            VectorF32<Len, Packing>(c)
        };
    }
}
/// Returns a copy with selected lanes sign-flipped. `values` picks the lanes. The
/// register is XOR-ed with the mask from VectorBase::GetNegateMask<values>(), which is
/// presumably -0.0f (sign bit only) in the selected lanes and 0 elsewhere; confirm in
/// VectorBase. Does not modify *this, so it is const.
template <std::array<bool, Len> values>
constexpr VectorF32<Len, Packing> Negate() const {
    std::array<float, VectorBase<Len, Packing, float>::AlignmentElement> mask = VectorBase<Len, Packing, float>::template GetNegateMask<values>();
    if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
        return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_xor_si128(_mm_castps_si128(this->v), _mm_loadu_epi32(mask.data()))));
    } else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
        return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(this->v), _mm256_loadu_epi32(mask.data()))));
    } else {
        return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(this->v), _mm512_loadu_epi32(mask.data()))));
    }
}
/// Fused multiply-add per lane: returns a * b + add with a single rounding step.
/// (The historical spelling of the name is kept because callers depend on it.)
static constexpr VectorF32<Len, Packing> MulitplyAdd(VectorF32<Len, Packing> a, VectorF32<Len, Packing> b, VectorF32<Len, Packing> add) {
    using Reg = typename VectorBase<Len, Packing, float>::VectorType;
    if constexpr (std::is_same_v<Reg, __m128>) {
        return _mm_fmadd_ps(a.v, b.v, add.v);
    } else if constexpr (std::is_same_v<Reg, __m256>) {
        return _mm256_fmadd_ps(a.v, b.v, add.v);
    } else {
        return _mm512_fmadd_ps(a.v, b.v, add.v);
    }
}
/// Fused multiply-subtract per lane: returns a * b - sub with a single rounding step.
/// (The historical spelling of the name is kept because callers depend on it.)
static constexpr VectorF32<Len, Packing> MulitplySub(VectorF32<Len, Packing> a, VectorF32<Len, Packing> b, VectorF32<Len, Packing> sub) {
    using Reg = typename VectorBase<Len, Packing, float>::VectorType;
    if constexpr (std::is_same_v<Reg, __m128>) {
        return _mm_fmsub_ps(a.v, b.v, sub.v);
    } else if constexpr (std::is_same_v<Reg, __m256>) {
        return _mm256_fmsub_ps(a.v, b.v, sub.v);
    } else {
        return _mm512_fmsub_ps(a.v, b.v, sub.v);
    }
}
// Cross product of packed 3-component vectors (Len == 3), computed as
//   a.yzx * b.zxy - a.zxy * b.yzx
// using rotations {1,2,0} (yzx) and {2,0,1} (zxy). This assumes the shuffle mask places
// element ShuffleValues[i] at position i. One multiply plus one fused multiply-subtract
// produce the result.
// NOTE(review): vpshufb (_mm_/_mm512_shuffle_epi8) only moves bytes within each 128-bit
// lane. If a packed vec3 straddles a 128-bit boundary (e.g. contiguous vec3s with
// Packing > 1 in 256/512-bit registers), the byte mask cannot express the rotation.
// Confirm GetShuffleMaskEpi8 accounts for this, or use Shuffle<>() (which falls back to
// permutexvar).
constexpr static VectorF32<Len, Packing> Cross(VectorF32<Len, Packing> a, VectorF32<Len, Packing> b) requires(Len == 3) {
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
// row1 = a.yzx, row4 = b.yzx
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask1 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{1,2,0}}>();
__m128i shuffleVec1 = _mm_loadu_epi8(shuffleMask1.data());
__m128 row1 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(a.v), shuffleVec1));
__m128 row4 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(b.v), shuffleVec1));
// row3 = a.zxy, row2 = b.zxy
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask3 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{2,0,1}}>();
__m128i shuffleVec3 = _mm_loadu_epi8(shuffleMask3.data());
__m128 row3 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(a.v), shuffleVec3));
__m128 row2 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(b.v), shuffleVec3));
// result = a.zxy * b.yzx; return a.yzx * b.zxy - result (register converts implicitly).
__m128 result = _mm_mul_ps(row3, row4);
return _mm_fmsub_ps(row1,row2,result);
} else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
// 256-bit path: the register is widened to 512 bits only to use _mm512_shuffle_epi8.
// The undefined upper half is discarded again by _mm512_castsi512_si256.
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask1 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{1,2,0}}>();
__m512i shuffleVec1 = _mm512_castsi256_si512(_mm256_loadu_epi8(shuffleMask1.data()));
__m256 row1 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(a.v)), shuffleVec1)));
__m256 row4 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(b.v)), shuffleVec1)));
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask3 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{2,0,1}}>();
__m512i shuffleVec3 = _mm512_castsi256_si512(_mm256_loadu_epi8(shuffleMask3.data()));
__m256 row3 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(a.v)), shuffleVec3)));
__m256 row2 = _mm256_castsi256_ps(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castps_si256(b.v)), shuffleVec3)));
__m256 result = _mm256_mul_ps(row3, row4);
return _mm256_fmsub_ps(row1,row2,result);
} else {
// 512-bit path: same scheme as the 128-bit path, on full 512-bit registers.
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask1 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{1,2,0}}>();
__m512i shuffleVec1 = _mm512_loadu_epi8(shuffleMask1.data());
__m512 row1 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(a.v), shuffleVec1));
__m512 row4 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(b.v), shuffleVec1));
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask3 = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<{{2,0,1}}>();
__m512i shuffleVec3 = _mm512_loadu_epi8(shuffleMask3.data());
__m512 row3 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(a.v), shuffleVec3));
__m512 row2 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(b.v), shuffleVec3));
__m512 result = _mm512_mul_ps(row3, row4);
return _mm512_fmsub_ps(row1,row2,result);
}
}
/// Rearranges lanes according to ShuffleValues. Presumably position i receives element
/// ShuffleValues[i], with VectorBase expanding the pattern across the whole register;
/// confirm in VectorBase.
/// The cheapest instruction VectorBase reports as able to express the pattern is used:
///   1. an immediate 32-bit shuffle (shuffle_epi32),
///   2. a byte shuffle within 128-bit lanes (shuffle_epi8),
///   3. otherwise a cross-lane 32-bit permute (permutexvar_epi32) for 256/512-bit
///      registers; a 128-bit register always uses a byte shuffle.
/// Does not modify *this, so it is const.
template <const std::array<std::uint8_t, Len> ShuffleValues>
constexpr VectorF32<Len, Packing> Shuffle() const {
    if constexpr(VectorBase<Len, Packing, float>::template CheckEpi32Shuffle<ShuffleValues>()) {
        constexpr std::uint8_t imm = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi32<ShuffleValues>();
        if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
            return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(this->v), imm)));
        } else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
            return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_shuffle_epi32(_mm256_castps_si256(this->v), imm)));
        } else {
            return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_shuffle_epi32(_mm512_castps_si512(this->v), imm)));
        }
    } else if constexpr(VectorBase<Len, Packing, float>::template CheckEpi8Shuffle<ShuffleValues>()){
        constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<ShuffleValues>();
        if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
            __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
            return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec)));
        } else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
            // Both the 256- and 512-bit byte shuffles work per 128-bit lane, so the native
            // 256-bit form gives the same result without widening to 512 bits.
            __m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data());
            return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(this->v), shuffleVec)));
        } else {
            __m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data());
            return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(this->v), shuffleVec)));
        }
    } else {
        if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
            constexpr std::array<std::uint8_t, VectorBase<Len, Packing, float>::Alignment> shuffleMask = VectorBase<Len, Packing, float>::template GetShuffleMaskEpi8<ShuffleValues>();
            __m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
            return VectorF32<Len, Packing>(_mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(this->v), shuffleVec)));
        } else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
            constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetPermuteMaskEpi32<ShuffleValues>();
            __m256i permIdx = _mm256_loadu_epi32(permMask.data());
            return VectorF32<Len, Packing>(_mm256_castsi256_ps(_mm256_permutexvar_epi32(permIdx, _mm256_castps_si256(this->v))));
        } else {
            constexpr std::array<std::uint32_t, VectorBase<Len, Packing, float>::AlignmentElement> permMask = VectorBase<Len, Packing, float>::template GetPermuteMaskEpi32<ShuffleValues>();
            __m512i permIdx = _mm512_loadu_epi32(permMask.data());
            return VectorF32<Len, Packing>(_mm512_castsi512_ps(_mm512_permutexvar_epi32(permIdx, _mm512_castps_si512(this->v))));
        }
    }
}
/// Normalises four registers of packed 4-component vectors at once. The requires-clause
/// demands Packing * Len == AlignmentElement, i.e. each register is fully used.
/// All lengths come from one LengthNoShuffle call and are inverted with a single
/// division. Each reciprocal is then broadcast across its vector's four lanes and
/// multiplied in.
/// Zero-length inputs are not guarded and yield inf/NaN lanes.
/// NOTE(review): LengthNoShuffle receives (A, C, B, D), while the broadcasts below read
/// lanes 0/1/2/3 (+4k) for A/B/C/D. This presumes LengthNoShuffle's output interleaving
/// compensates; confirm against its definition.
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C,
    VectorF32<Len, Packing> D
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
        VectorF32<1, 4> length = LengthNoShuffle(A, C, B, D);
        VectorF32<4, 1> invLength(_mm_div_ps(_mm_set1_ps(1.0f), length.v));
        VectorF32<4, 1> invA = invLength.template Shuffle<{{0,0,0,0}}>();
        VectorF32<4, 1> invB = invLength.template Shuffle<{{1,1,1,1}}>();
        VectorF32<4, 1> invC = invLength.template Shuffle<{{2,2,2,2}}>();
        VectorF32<4, 1> invD = invLength.template Shuffle<{{3,3,3,3}}>();
        return {
            _mm_mul_ps(A.v, invA.v),
            _mm_mul_ps(B.v, invB.v),
            _mm_mul_ps(C.v, invC.v),
            _mm_mul_ps(D.v, invD.v)
        };
    } else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
        VectorF32<1, 8> length = LengthNoShuffle(A, C, B, D);
        VectorF32<8, 1> invLength(_mm256_div_ps(_mm256_set1_ps(1.0f), length.v));
        VectorF32<8, 1> invA = invLength.template Shuffle<{{0,0,0,0,4,4,4,4}}>();
        VectorF32<8, 1> invB = invLength.template Shuffle<{{1,1,1,1,5,5,5,5}}>();
        VectorF32<8, 1> invC = invLength.template Shuffle<{{2,2,2,2,6,6,6,6}}>();
        VectorF32<8, 1> invD = invLength.template Shuffle<{{3,3,3,3,7,7,7,7}}>();
        return {
            _mm256_mul_ps(A.v, invA.v),
            _mm256_mul_ps(B.v, invB.v),
            _mm256_mul_ps(C.v, invC.v),
            _mm256_mul_ps(D.v, invD.v)
        };
    } else {
        VectorF32<1, 16> length = LengthNoShuffle(A, C, B, D);
        VectorF32<16, 1> invLength(_mm512_div_ps(_mm512_set1_ps(1.0f), length.v));
        VectorF32<16, 1> invA = invLength.template Shuffle<{{0,0,0,0,4,4,4,4,8,8,8,8,12,12,12,12}}>();
        VectorF32<16, 1> invB = invLength.template Shuffle<{{1,1,1,1,5,5,5,5,9,9,9,9,13,13,13,13}}>();
        VectorF32<16, 1> invC = invLength.template Shuffle<{{2,2,2,2,6,6,6,6,10,10,10,10,14,14,14,14}}>();
        VectorF32<16, 1> invD = invLength.template Shuffle<{{3,3,3,3,7,7,7,7,11,11,11,11,15,15,15,15}}>();
        return {
            VectorF32<Len, Packing>(_mm512_mul_ps(A.v, invA.v)),
            VectorF32<Len, Packing>(_mm512_mul_ps(B.v, invB.v)),
            VectorF32<Len, Packing>(_mm512_mul_ps(C.v, invC.v)),
            VectorF32<Len, Packing>(_mm512_mul_ps(D.v, invD.v)),
        };
    }
}
/// Normalises four single-vec3 registers (Len == 3, Packing == 1).
/// Length() supplies the four lengths, presumably |A|, |B|, |C|, |D| in lanes 0..3. One
/// division produces all reciprocals, and each is broadcast over every lane of the
/// matching register (padding lane included) before the multiply.
/// Zero-length inputs are not guarded.
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C,
    VectorF32<Len, Packing> D
) requires(Len == 3 && Packing == 1) {
    VectorF32<1, 4> lengths = Length(A, B, C, D);
    VectorF32<4, 1> reciprocal(_mm_div_ps(_mm_set1_ps(1.0f), lengths.v));
    return {
        _mm_mul_ps(A.v, reciprocal.template Shuffle<{{0,0,0,0}}>().v),
        _mm_mul_ps(B.v, reciprocal.template Shuffle<{{1,1,1,1}}>().v),
        _mm_mul_ps(C.v, reciprocal.template Shuffle<{{2,2,2,2}}>().v),
        _mm_mul_ps(D.v, reciprocal.template Shuffle<{{3,3,3,3}}>().v)
    };
}
/// Normalises four registers that each hold two contiguous vec3s (Len == 3, Packing == 2).
/// Length() yields eight lengths, presumably two per argument in order. Their reciprocals
/// are broadcast to the three lanes of each vec3; the trailing padding lanes take
/// element 0 from the zero-filled tail of the shuffle pattern.
/// Zero-length inputs are not guarded.
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C,
    VectorF32<Len, Packing> D
) requires(Len == 3 && Packing == 2) {
    VectorF32<1, 8> lengths = Length(A, B, C, D);
    VectorF32<8, 1> reciprocal(_mm256_div_ps(_mm256_set1_ps(1.0f), lengths.v));
    return {
        _mm256_mul_ps(A.v, reciprocal.template Shuffle<{{0,0,0, 1,1,1}}>().v),
        _mm256_mul_ps(B.v, reciprocal.template Shuffle<{{2,2,2, 3,3,3}}>().v),
        _mm256_mul_ps(C.v, reciprocal.template Shuffle<{{4,4,4, 5,5,5}}>().v),
        _mm256_mul_ps(D.v, reciprocal.template Shuffle<{{6,6,6, 7,7,7}}>().v)
    };
}
/// Normalises three registers that each hold five contiguous vec3s (Len == 3, Packing == 5,
/// i.e. 15 of the 16 lanes of a 512-bit register).
/// Length() yields fifteen lengths, presumably five per argument in order. The
/// reciprocals are computed in one division and broadcast to each vec3's three lanes.
/// Zero-length inputs are not guarded.
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
    VectorF32<Len, Packing> A,
    VectorF32<Len, Packing> B,
    VectorF32<Len, Packing> C
) requires(Len == 3 && Packing == 5) {
    VectorF32<1, 15> lengths = Length(A, B, C);
    VectorF32<15, 1> reciprocal(_mm512_div_ps(_mm512_set1_ps(1.0f), lengths.v));
    return {
        _mm512_mul_ps(A.v, reciprocal.template Shuffle<{{0,0,0, 1,1,1, 2,2,2, 3,3,3, 4,4,4}}>().v),
        _mm512_mul_ps(B.v, reciprocal.template Shuffle<{{5,5,5, 6,6,6, 7,7,7, 8,8,8, 9,9,9}}>().v),
        _mm512_mul_ps(C.v, reciprocal.template Shuffle<{{10,10,10, 11,11,11, 12,12,12, 13,13,13, 14,14,14}}>().v),
    };
}
constexpr static std::tuple<VectorF32<Len, Packing>, VectorF32<Len, Packing>> Normalize(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> B
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    // Normalizes every vec2 packed in A and B.
    // LengthNoShuffle leaves the magnitudes in per-128-bit-lane order
    // [A_2j A_2j+1 B_2j B_2j+1]; the shuffles below pick each vec2's reciprocal
    // out of that order and duplicate it over the vec2's two components.
    using Register = typename VectorBase<Len, Packing, float>::VectorType;
    if constexpr(std::is_same_v<Register, __m128>) {
        VectorF32<1, 4> magnitudes = LengthNoShuffle(A, B);
        VectorF32<4, 1> inverse(_mm_div_ps(_mm_set1_ps(1.0f), magnitudes.v));
        VectorF32<4, 1> scaleA = inverse.template Shuffle<{{0,0,1,1}}>();
        VectorF32<4, 1> scaleB = inverse.template Shuffle<{{2,2,3,3}}>();
        return {
            _mm_mul_ps(A.v, scaleA.v),
            _mm_mul_ps(B.v, scaleB.v),
        };
    } else if constexpr(std::is_same_v<Register, __m256>) {
        VectorF32<1, 8> magnitudes = LengthNoShuffle(A, B);
        VectorF32<8, 1> inverse(_mm256_div_ps(_mm256_set1_ps(1.0f), magnitudes.v));
        VectorF32<8, 1> scaleA = inverse.template Shuffle<{{0,0,1,1,4,4,5,5}}>();
        VectorF32<8, 1> scaleB = inverse.template Shuffle<{{2,2,3,3,6,6,7,7}}>();
        return {
            _mm256_mul_ps(A.v, scaleA.v),
            _mm256_mul_ps(B.v, scaleB.v),
        };
    } else {
        VectorF32<1, 16> magnitudes = LengthNoShuffle(A, B);
        VectorF32<16, 1> inverse(_mm512_div_ps(_mm512_set1_ps(1.0f), magnitudes.v));
        VectorF32<16, 1> scaleA = inverse.template Shuffle<{{0,0,1,1,4,4,5,5,8,8,9,9,12,12,13,13}}>();
        VectorF32<16, 1> scaleB = inverse.template Shuffle<{{2,2,3,3,6,6,7,7,10,10,11,11,14,14,15,15}}>();
        return {
            _mm512_mul_ps(A.v, scaleA.v),
            _mm512_mul_ps(B.v, scaleB.v),
        };
    }
}
constexpr static VectorF32<1, Packing*4> Length(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> B,
VectorF32<Len, Packing> C,
VectorF32<Len, Packing> D
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    // Euclidean length of every packed vec4: lane-wise square root of LengthSq,
    // so the result uses the same grouped order as Dot/LengthSq.
    using Register = typename VectorBase<Len, Packing, float>::VectorType;
    VectorF32<1, Packing*4> squared = LengthSq(A, B, C, D);
    if constexpr(std::is_same_v<Register, __m128>) {
        return VectorF32<1, Packing*4>(_mm_sqrt_ps(squared.v));
    } else if constexpr(std::is_same_v<Register, __m256>) {
        return VectorF32<1, Packing*4>(_mm256_sqrt_ps(squared.v));
    } else {
        return VectorF32<1, Packing*4>(_mm512_sqrt_ps(squared.v));
    }
}
constexpr static VectorF32<1, 4> Length(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> B,
VectorF32<Len, Packing> C,
VectorF32<Len, Packing> D
) requires(Len == 3 && Packing == 1) {
    // Magnitudes of four single vec3s, returned as [|A| |B| |C| |D|].
    return VectorF32<1, 4>(_mm_sqrt_ps(LengthSq(A, B, C, D).v));
}
constexpr static VectorF32<1, 8> Length(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> B,
VectorF32<Len, Packing> C,
VectorF32<Len, Packing> D
) requires(Len == 3 && Packing == 2) {
    // Magnitudes of eight vec3s (two per register), ordered [A0 A1 B0 B1 C0 C1 D0 D1].
    return VectorF32<1, 8>(_mm256_sqrt_ps(LengthSq(A, B, C, D).v));
}
constexpr static VectorF32<1, 15> Length(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> B,
VectorF32<Len, Packing> C
) requires(Len == 3 && Packing == 5) {
    // Magnitudes of fifteen vec3s (five per register), ordered [A0..A4 B0..B4 C0..C4 _].
    return VectorF32<1, 15>(_mm512_sqrt_ps(LengthSq(A, B, C).v));
}
constexpr static VectorF32<1, Packing*2> Length(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> C
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    // Magnitude of every packed vec2 in A and C: lane-wise square root of LengthSq.
    using Register = typename VectorBase<Len, Packing, float>::VectorType;
    VectorF32<1, Packing*2> squared = LengthSq(A, C);
    if constexpr(std::is_same_v<Register, __m128>) {
        return VectorF32<1, Packing*2>(_mm_sqrt_ps(squared.v));
    } else if constexpr(std::is_same_v<Register, __m256>) {
        return VectorF32<1, Packing*2>(_mm256_sqrt_ps(squared.v));
    } else {
        return VectorF32<1, Packing*2>(_mm512_sqrt_ps(squared.v));
    }
}
constexpr static VectorF32<1, Packing*4> LengthSq(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> B,
VectorF32<Len, Packing> C,
VectorF32<Len, Packing> D
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    // Squared length = dot product of each vector with itself.
    VectorF32<1, Packing*4> squared = Dot(A, A, B, B, C, C, D, D);
    return squared;
}
constexpr static VectorF32<1, 4> LengthSq(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> B,
VectorF32<Len, Packing> C,
VectorF32<Len, Packing> D
) requires(Len == 3 && Packing == 1) {
    // Squared lengths of four vec3s: each vector dotted with itself.
    VectorF32<1, 4> squared = Dot(A, A, B, B, C, C, D, D);
    return squared;
}
constexpr static VectorF32<1, 8> LengthSq(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> B,
VectorF32<Len, Packing> C,
VectorF32<Len, Packing> D
) requires(Len == 3 && Packing == 2) {
    // Squared lengths of eight vec3s: each vector dotted with itself.
    VectorF32<1, 8> squared = Dot(A, A, B, B, C, C, D, D);
    return squared;
}
constexpr static VectorF32<1, 15> LengthSq(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> B,
VectorF32<Len, Packing> C
) requires(Len == 3 && Packing == 5) {
    // Squared lengths of fifteen vec3s: each vector dotted with itself.
    VectorF32<1, 15> squared = Dot(A, A, B, B, C, C);
    return squared;
}
constexpr static VectorF32<1, Packing*2> LengthSq(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> C
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    // Squared lengths of every packed vec2: each vector dotted with itself.
    VectorF32<1, Packing*2> squared = Dot(A, A, C, C);
    return squared;
}
constexpr static VectorF32<1, Packing*4> Dot(
VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
VectorF32<Len, Packing> B0, VectorF32<Len, Packing> B1,
VectorF32<Len, Packing> C0, VectorF32<Len, Packing> C1,
VectorF32<Len, Packing> D0, VectorF32<Len, Packing> D1
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    // Dot products of the pairs (A0,A1), (B0,B1), (C0,C1), (D0,D1) for every packed vec4,
    // grouped by pair: [A_0..A_{P-1}, B_0..B_{P-1}, C_0.., D_0..].
    //
    // DotNoShuffle emits, per 128-bit lane j, [1st_j 3rd_j 2nd_j 4th_j] for the pairs it
    // receives; handing it B and C swapped makes every lane come out as [A_j B_j C_j D_j].
    using Register = typename VectorBase<Len, Packing, float>::VectorType;
    VectorF32<1, Packing*4> perLane = DotNoShuffle(A0, A1, C0, C1, B0, B1, D0, D1);
    if constexpr(std::is_same_v<Register, __m128>) {
        // A single lane is already [A B C D].
        return perLane;
    } else if constexpr(std::is_same_v<Register, __m256>) {
        // [A0 B0 C0 D0 | A1 B1 C1 D1] -> [A0 A1 B0 B1 C0 C1 D0 D1]
        VectorF32<8, 1> grouped = VectorF32<8, 1>(perLane.v).template Shuffle<{{
            0,4, 1,5, 2,6, 3,7
        }}>();
        return grouped.v;
    } else {
        // [A0 B0 C0 D0 | A1 .. | A2 .. | A3 ..] -> [A0..A3 B0..B3 C0..C3 D0..D3]
        VectorF32<16, 1> grouped = VectorF32<16, 1>(perLane.v).template Shuffle<{{
            0,4,8,12,
            1,5,9,13,
            2,6,10,14,
            3,7,11,15
        }}>();
        return grouped.v;
    }
}
constexpr static VectorF32<1, 4> Dot(
VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
VectorF32<Len, Packing> B0, VectorF32<Len, Packing> B1,
VectorF32<Len, Packing> C0, VectorF32<Len, Packing> C1,
VectorF32<Len, Packing> D0, VectorF32<Len, Packing> D1
) requires(Len == 3 && Packing == 1) {
    // Four vec3 dot products, one per register pair, returned as [A. B. C. D.].
    // Each input is [x y z _]; lane 3 is padding and never reaches the result.
    // Method: element-wise products, a partial 4x4 transpose (only the first three
    // columns are built), then the three columns are summed.
    const __m128 prodA = _mm_mul_ps(A0.v, A1.v); // a1 a2 a3 _
    const __m128 prodB = _mm_mul_ps(B0.v, B1.v); // b1 b2 b3 _
    const __m128 prodC = _mm_mul_ps(C0.v, C1.v); // c1 c2 c3 _
    const __m128 prodD = _mm_mul_ps(D0.v, D1.v); // d1 d2 d3 _

    const __m128 abLow  = _mm_unpacklo_ps(prodA, prodB); // a1 b1 a2 b2
    const __m128 abHigh = _mm_unpackhi_ps(prodA, prodB); // a3 b3 _  _
    const __m128 cdLow  = _mm_unpacklo_ps(prodC, prodD); // c1 d1 c2 d2
    const __m128 cdHigh = _mm_unpackhi_ps(prodC, prodD); // c3 d3 _  _

    const __m128 xs = _mm_movelh_ps(abLow, cdLow);   // a1 b1 c1 d1
    const __m128 ys = _mm_movehl_ps(cdLow, abLow);   // a2 b2 c2 d2
    const __m128 zs = _mm_movelh_ps(abHigh, cdHigh); // a3 b3 c3 d3

    return _mm_add_ps(_mm_add_ps(xs, ys), zs);
}
// Eight vec3 dot products at once. Each input register packs two vec3s as
// [x y z x' y' z' _ _]; the result is [A_0 A_1 B_0 B_1 C_0 C_1 D_0 D_1], where X_k is
// the dot product of the k-th vec3 of X0 with the k-th vec3 of X1.
// Input lanes 6-7 are padding: they only land in the high lane of the unpackhi results,
// which the final 0x20 lane permute never selects.
// Uses AVX2 (_mm256_permutevar8x32_ps, _mm256_unpack*_epi64, _mm256_permute2x128_si256).
constexpr static VectorF32<1, 8> Dot(
VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
VectorF32<Len, Packing> B0, VectorF32<Len, Packing> B1,
VectorF32<Len, Packing> C0, VectorF32<Len, Packing> C1,
VectorF32<Len, Packing> D0, VectorF32<Len, Packing> D1
) requires(Len == 3 && Packing == 2) {
// Each register: [X1 X2 X3 Y1 Y2 Y3 _ _]
// 4 pairs × 2 vectors each = 8 dot products → 1 x __m256
//
// After multiply:
// mulA = [a1 a2 a3 b1 b2 b3 _ _]
// mulB = [c1 c2 c3 d1 d2 d3 _ _]
// mulC = [e1 e2 e3 f1 f2 f3 _ _]
// mulD = [g1 g2 g3 h1 h2 h3 _ _]
//
// We need result = [a·, b·, c·, d·, e·, f·, g·, h·]
// where x· = x1+x2+x3
//
// Strategy: use permute to gather element 1s, 2s, 3s across all 8 vectors,
// then add.
//
// Gather indices (from the concatenated view of mulA|mulB|mulC|mulD):
// vec: a a a b b b _ _ c c c d d d _ _ e e e f f f _ _ g g g h h h _ _
// idx: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
//
// elem1 = [a1, b1, c1, d1, e1, f1, g1, h1] → indices [0, 3, 8, 11, 16, 19, 24, 27]
// elem2 = [a2, b2, c2, d2, e2, f2, g2, h2] → indices [1, 4, 9, 12, 17, 20, 25, 28]
// elem3 = [a3, b3, c3, d3, e3, f3, g3, h3] → indices [2, 5, 10, 13, 18, 21, 26, 29]
//
// AVX2 has no two-source 8x32 permute, so each register is first reordered on its
// own with vpermps (_mm256_permutevar8x32_ps); registers are then combined with
// 64-bit unpacks and a 128-bit lane permute (no blends are needed).
//
// Within each 256-bit register [X1 X2 X3 Y1 Y2 Y3 _ _]:
// elem1_local = [X1 Y1 ...] → gather from indices 0,3
// elem2_local = [X2 Y2 ...] → gather from indices 1,4
// elem3_local = [X3 Y3 ...] → gather from indices 2,5
//
// After permutevar8x32 on each mul register:
// From mulA: row1_part = [a1 b1 _ _ _ _ _ _]
// From mulB: row1_part = [c1 d1 _ _ _ _ _ _]
// From mulC: row1_part = [e1 f1 _ _ _ _ _ _]
// From mulD: row1_part = [g1 h1 _ _ _ _ _ _]
//
// Then combine with unpack/shuffle to get full rows.
__m256 mulA = _mm256_mul_ps(A0.v, A1.v); // a1 a2 a3 b1 b2 b3 _ _
__m256 mulB = _mm256_mul_ps(B0.v, B1.v); // c1 c2 c3 d1 d2 d3 _ _
__m256 mulC = _mm256_mul_ps(C0.v, C1.v); // e1 e2 e3 f1 f2 f3 _ _
__m256 mulD = _mm256_mul_ps(D0.v, D1.v); // g1 g2 g3 h1 h2 h3 _ _
// Permute each register to gather elements by position.
// For each register [X1 X2 X3 Y1 Y2 Y3 U U]:
// perm1: [X1 Y1 X2 Y2 X3 Y3 _ _] → indices {0,3,1,4,2,5,6,7}
__m256i permIdx = _mm256_setr_epi32(0, 3, 1, 4, 2, 5, 6, 7);
// After permute: [X1 Y1 X2 Y2 X3 Y3 _ _]
__m256 pA = _mm256_permutevar8x32_ps(mulA, permIdx); // a1 b1 a2 b2 a3 b3 _ _
__m256 pB = _mm256_permutevar8x32_ps(mulB, permIdx); // c1 d1 c2 d2 c3 d3 _ _
__m256 pC = _mm256_permutevar8x32_ps(mulC, permIdx); // e1 f1 e2 f2 e3 f3 _ _
__m256 pD = _mm256_permutevar8x32_ps(mulD, permIdx); // g1 h1 g2 h2 g3 h3 _ _
// Now combine pairs. Each pair contributes 4 consecutive results.
// pA has [a1 b1 a2 b2 a3 b3 _ _], pB has [c1 d1 c2 d2 c3 d3 _ _]
// We want:
// row1 = [a1 b1 c1 d1 | e1 f1 g1 h1]
// row2 = [a2 b2 c2 d2 | e2 f2 g2 h2]
// row3 = [a3 b3 c3 d3 | e3 f3 g3 h3]
//
// From pA: elements at [0,1] are elem1, [2,3] are elem2, [4,5] are elem3
// From pB: elements at [0,1] are elem1, [2,3] are elem2, [4,5] are elem3
//
// Use unpacklo_epi64 to interleave 64-bit chunks:
// unpacklo64(pA, pB) within 128-bit lanes:
// lo lane: pA[0:1]=a1,b1 | pB[0:1]=c1,d1 → [a1 b1 c1 d1]
// hi lane: pA[4:5]=a3,b3 | pB[4:5]=c3,d3 → [a3 b3 c3 d3]
// → [a1 b1 c1 d1 | a3 b3 c3 d3]
//
// unpackhi64(pA, pB) within 128-bit lanes:
// lo lane: pA[2:3]=a2,b2 | pB[2:3]=c2,d2 → [a2 b2 c2 d2]
// hi lane: pA[6:7]=_,_ | pB[6:7]=_,_ → garbage
// → [a2 b2 c2 d2 | _ _ _ _]
__m256i AB_lo = _mm256_unpacklo_epi64(
_mm256_castps_si256(pA), _mm256_castps_si256(pB)); // [a1 b1 c1 d1 | a3 b3 c3 d3]
__m256i AB_hi = _mm256_unpackhi_epi64(
_mm256_castps_si256(pA), _mm256_castps_si256(pB)); // [a2 b2 c2 d2 | _ _ _ _]
__m256i CD_lo = _mm256_unpacklo_epi64(
_mm256_castps_si256(pC), _mm256_castps_si256(pD)); // [e1 f1 g1 h1 | e3 f3 g3 h3]
__m256i CD_hi = _mm256_unpackhi_epi64(
_mm256_castps_si256(pC), _mm256_castps_si256(pD)); // [e2 f2 g2 h2 | _ _ _ _]
// row1 = [a1 b1 c1 d1 | e1 f1 g1 h1] → lo 128 of AB_lo, lo 128 of CD_lo
// row2 = [a2 b2 c2 d2 | e2 f2 g2 h2] → lo 128 of AB_hi, lo 128 of CD_hi
// row3 = [a3 b3 c3 d3 | e3 f3 g3 h3] → hi 128 of AB_lo, hi 128 of CD_lo
__m256 row1 = _mm256_castsi256_ps(_mm256_permute2x128_si256(AB_lo, CD_lo, 0x20)); // lo,lo
__m256 row2 = _mm256_castsi256_ps(_mm256_permute2x128_si256(AB_hi, CD_hi, 0x20)); // lo,lo
__m256 row3 = _mm256_castsi256_ps(_mm256_permute2x128_si256(AB_lo, CD_lo, 0x31)); // hi,hi
// Sum the three component products: (x + y) + z for each of the eight vectors.
row1 = _mm256_add_ps(row1, row2);
row1 = _mm256_add_ps(row1, row3);
return row1;
}
// Fifteen vec3 dot products at once. Each input register packs five vec3s as
// [x0 y0 z0 x1 y1 z1 ... x4 y4 z4 _]; the result is [A_0..A_4 B_0..B_4 C_0..C_4 ?],
// where X_k is the dot product of the k-th vec3 of X0 with the k-th vec3 of X1.
// Input lane 15 is never gathered. Result lane 15 is NOT a meaningful dot product:
// no mask covers it, so it is built from the index-0 entries of idx1/idx2/idx3.
// Uses AVX-512F (_mm512_permutexvar_ps, _mm512_mask_mov_ps).
constexpr static VectorF32<1, 15> Dot(
VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
VectorF32<Len, Packing> B0, VectorF32<Len, Packing> B1,
VectorF32<Len, Packing> C0, VectorF32<Len, Packing> C1
) requires(Len == 3 && Packing == 5) {
// __m512: Each register: [A1 A2 A3 B1 B2 B3 C1 C2 C3 D1 D2 D3 E1 E2 E3 _]
// 3 pairs × 5 vectors each = 15 dot products → fits in 1 x __m512 (slot 16 unused)
//
// After multiply of 3 pairs:
// mul0 = [a1 a2 a3 b1 b2 b3 c1 c2 c3 d1 d2 d3 e1 e2 e3 _]
// mul1 = [f1 f2 f3 g1 g2 g3 h1 h2 h3 i1 i2 i3 j1 j2 j3 _]
// mul2 = [k1 k2 k3 l1 l2 l3 m1 m2 m3 n1 n2 n3 o1 o2 o3 _]
//
// Result = [a· b· c· d· e· f· g· h· i· j· k· l· m· n· o· _]
//
// Strategy: for each mul register, gather element 1s, 2s, 3s with vpermps,
// then combine across registers.
//
// From mul0: 5 vectors at positions {0,1,2}, {3,4,5}, {6,7,8}, {9,10,11}, {12,13,14}
// elem1 = indices {0, 3, 6, 9, 12} → positions 0..4 of result
// elem2 = indices {1, 4, 7, 10, 13}
// elem3 = indices {2, 5, 8, 11, 14}
__m512 mul0 = _mm512_mul_ps(A0.v, A1.v);
__m512 mul1 = _mm512_mul_ps(B0.v, B1.v);
__m512 mul2 = _mm512_mul_ps(C0.v, C1.v);
// Gather elem1, elem2, elem3 from each mul register
// Each register has 5 vec3s: extract element 1,2,3 of each into consecutive positions
__m512i idx1 = _mm512_setr_epi32(0, 3, 6, 9, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
__m512i idx2 = _mm512_setr_epi32(1, 4, 7, 10, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
__m512i idx3 = _mm512_setr_epi32(2, 5, 8, 11, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
// From mul0 → results 0..4, from mul1 → results 5..9, from mul2 → results 10..14
// Gather from each, then combine.
__m512 e1_0 = _mm512_permutexvar_ps(idx1, mul0); // [a1 b1 c1 d1 e1 ...]
__m512 e2_0 = _mm512_permutexvar_ps(idx2, mul0); // [a2 b2 c2 d2 e2 ...]
__m512 e3_0 = _mm512_permutexvar_ps(idx3, mul0); // [a3 b3 c3 d3 e3 ...]
__m512 e1_1 = _mm512_permutexvar_ps(idx1, mul1); // [f1 g1 h1 i1 j1 ...]
__m512 e2_1 = _mm512_permutexvar_ps(idx2, mul1); // [f2 g2 h2 i2 j2 ...]
__m512 e3_1 = _mm512_permutexvar_ps(idx3, mul1); // [f3 g3 h3 i3 j3 ...]
__m512 e1_2 = _mm512_permutexvar_ps(idx1, mul2); // [k1 l1 m1 n1 o1 ...]
__m512 e2_2 = _mm512_permutexvar_ps(idx2, mul2); // [k2 l2 m2 n2 o2 ...]
__m512 e3_2 = _mm512_permutexvar_ps(idx3, mul2); // [k3 l3 m3 n3 o3 ...]
// Now combine: we need positions 0..4 from reg0, 5..9 from reg1, 10..14 from reg2
// Use masked moves to assemble the final row vectors.
// mask for positions 0-4: 0b0000000000011111 = 0x001F
// mask for positions 5-9: 0b0000001111100000 = 0x03E0
// mask for positions 10-14: 0b0111110000000000 = 0x7C00
// For reg1, its results are in positions 0..4 but need to go to 5..9.
// For reg2, its results are in positions 0..4 but need to go to 10..14.
// Use a different approach: permute reg1/reg2 results to their target positions.
// Shift reg1 results from slots 0..4 to slots 5..9
__m512i shiftIdx1 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0);
// Shift reg2 results from slots 0..4 to slots 10..14
__m512i shiftIdx2 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 0);
__m512 e1_1_shifted = _mm512_permutexvar_ps(shiftIdx1, e1_1);
__m512 e2_1_shifted = _mm512_permutexvar_ps(shiftIdx1, e2_1);
__m512 e3_1_shifted = _mm512_permutexvar_ps(shiftIdx1, e3_1);
__m512 e1_2_shifted = _mm512_permutexvar_ps(shiftIdx2, e1_2);
__m512 e2_2_shifted = _mm512_permutexvar_ps(shiftIdx2, e2_2);
__m512 e3_2_shifted = _mm512_permutexvar_ps(shiftIdx2, e3_2);
// Blend: take positions 0..4 from reg0, 5..9 from reg1, 10..14 from reg2
__mmask16 mask_5_9 = 0x03E0u; // bits 5-9
__mmask16 mask_10_14 = 0x7C00u; // bits 10-14
__m512 row1 = _mm512_mask_mov_ps(e1_0, mask_5_9, e1_1_shifted);
row1 = _mm512_mask_mov_ps(row1, mask_10_14, e1_2_shifted);
__m512 row2 = _mm512_mask_mov_ps(e2_0, mask_5_9, e2_1_shifted);
row2 = _mm512_mask_mov_ps(row2, mask_10_14, e2_2_shifted);
__m512 row3 = _mm512_mask_mov_ps(e3_0, mask_5_9, e3_1_shifted);
row3 = _mm512_mask_mov_ps(row3, mask_10_14, e3_2_shifted);
// Sum the three component products: (x + y) + z for each of the fifteen vectors.
row1 = _mm512_add_ps(row1, row2);
row1 = _mm512_add_ps(row1, row3);
return row1;
}
constexpr static VectorF32<1, Packing*2> Dot(
VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
VectorF32<Len, Packing> C0, VectorF32<Len, Packing> C1
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    // Dot products of the vec2 pairs (A0,A1) and (C0,C1), grouped as
    // [A_0..A_{P-1}, C_0..C_{P-1}].
    // DotNoShuffle yields, per 128-bit lane j, [A_2j A_2j+1 C_2j C_2j+1]; wider registers
    // need a regrouping shuffle so every A result precedes every C result.
    using Register = typename VectorBase<Len, Packing, float>::VectorType;
    VectorF32<1, Packing*2> perLane = DotNoShuffle(A0, A1, C0, C1);
    if constexpr(std::is_same_v<Register, __m128>) {
        // One lane is already [A0 A1 C0 C1].
        return perLane;
    } else if constexpr(std::is_same_v<Register, __m256>) {
        VectorF32<8, 1> grouped = VectorF32<8, 1>(perLane.v).template Shuffle<{{
            0,1, 4,5,
            2,3, 6,7,
        }}>();
        return grouped.v;
    } else {
        VectorF32<16, 1> grouped = VectorF32<16, 1>(perLane.v).template Shuffle<{{
            0,1, 4,5,
            8,9, 12,13,
            2,3, 6,7,
            10,11, 14,15
        }}>();
        return grouped.v;
    }
}
private:
constexpr static VectorF32<1, Packing*4> LengthNoShuffle(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> B,
VectorF32<Len, Packing> C,
VectorF32<Len, Packing> D
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    // Like Length(A, B, C, D) but without the regrouping shuffle: the magnitudes stay in
    // DotNoShuffle order, i.e. [A_j C_j B_j D_j] inside every 128-bit lane j.
    using Register = typename VectorBase<Len, Packing, float>::VectorType;
    VectorF32<1, Packing*4> squared = LengthSqNoShuffle(A, B, C, D);
    if constexpr(std::is_same_v<Register, __m128>) {
        return VectorF32<1, Packing*4>(_mm_sqrt_ps(squared.v));
    } else if constexpr(std::is_same_v<Register, __m256>) {
        return VectorF32<1, Packing*4>(_mm256_sqrt_ps(squared.v));
    } else {
        return VectorF32<1, Packing*4>(_mm512_sqrt_ps(squared.v));
    }
}
constexpr static VectorF32<1, Packing*2> LengthNoShuffle(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> C
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    // Like Length(A, C) but without the regrouping shuffle: the magnitudes stay in
    // DotNoShuffle order, i.e. [A_2j A_2j+1 C_2j C_2j+1] inside every 128-bit lane j.
    using Register = typename VectorBase<Len, Packing, float>::VectorType;
    VectorF32<1, Packing*2> squared = LengthSqNoShuffle(A, C);
    if constexpr(std::is_same_v<Register, __m128>) {
        return VectorF32<1, Packing*2>(_mm_sqrt_ps(squared.v));
    } else if constexpr(std::is_same_v<Register, __m256>) {
        return VectorF32<1, Packing*2>(_mm256_sqrt_ps(squared.v));
    } else {
        return VectorF32<1, Packing*2>(_mm512_sqrt_ps(squared.v));
    }
}
constexpr static VectorF32<1, Packing*4> LengthSqNoShuffle(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> B,
VectorF32<Len, Packing> C,
VectorF32<Len, Packing> D
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    // Squared lengths in DotNoShuffle's per-lane order (each vector dotted with itself).
    VectorF32<1, Packing*4> squared = DotNoShuffle(A, A, B, B, C, C, D, D);
    return squared;
}
constexpr static VectorF32<1, Packing*2> LengthSqNoShuffle(
VectorF32<Len, Packing> A,
VectorF32<Len, Packing> C
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    // Squared lengths in DotNoShuffle's per-lane order (each vector dotted with itself).
    VectorF32<1, Packing*2> squared = DotNoShuffle(A, A, C, C);
    return squared;
}
constexpr static VectorF32<1, Packing*4> DotNoShuffle(
VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
VectorF32<Len, Packing> B0, VectorF32<Len, Packing> B1,
VectorF32<Len, Packing> C0, VectorF32<Len, Packing> C1,
VectorF32<Len, Packing> D0, VectorF32<Len, Packing> D1
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    // Dot products of the vec4 pairs (A0,A1), (B0,B1), (C0,C1), (D0,D1) via a 4x4
    // transpose inside every 128-bit lane. No cross-lane regrouping is done: lane j of
    // the result is [A_j C_j B_j D_j] (B and C swapped), which callers reorder (see Dot).
    //
    // The interleaves use the float (_ps) unpacks. They move exactly the same 32-bit
    // elements as the _epi32 forms but keep every intermediate typed as a float vector;
    // storing an _epi32 result (__m128i/__m256i/__m512i) into a float vector variable is
    // an implicit int->float vector conversion that GCC rejects unless
    // -flax-vector-conversions is given.
    if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
        __m128 mulA = _mm_mul_ps(A0.v, A1.v);
        __m128 mulB = _mm_mul_ps(B0.v, B1.v);
        __m128 mulC = _mm_mul_ps(C0.v, C1.v);
        __m128 mulD = _mm_mul_ps(D0.v, D1.v);
        __m128 ab01 = _mm_unpacklo_ps(mulA, mulB); // a1 b1 a2 b2
        __m128 ab23 = _mm_unpackhi_ps(mulA, mulB); // a3 b3 a4 b4
        __m128 cd01 = _mm_unpacklo_ps(mulC, mulD); // c1 d1 c2 d2
        __m128 cd23 = _mm_unpackhi_ps(mulC, mulD); // c3 d3 c4 d4
        __m128 row1 = _mm_unpacklo_ps(ab01, cd01); // a1 c1 b1 d1
        __m128 row2 = _mm_unpackhi_ps(ab01, cd01); // a2 c2 b2 d2
        __m128 row3 = _mm_unpacklo_ps(ab23, cd23); // a3 c3 b3 d3
        __m128 row4 = _mm_unpackhi_ps(ab23, cd23); // a4 c4 b4 d4
        row1 = _mm_add_ps(row1, row2);
        row1 = _mm_add_ps(row1, row3);
        row1 = _mm_add_ps(row1, row4);
        return row1;
    } else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
        // Same transpose, applied independently to each 128-bit lane.
        __m256 mulA = _mm256_mul_ps(A0.v, A1.v);
        __m256 mulB = _mm256_mul_ps(B0.v, B1.v);
        __m256 mulC = _mm256_mul_ps(C0.v, C1.v);
        __m256 mulD = _mm256_mul_ps(D0.v, D1.v);
        __m256 ab01 = _mm256_unpacklo_ps(mulA, mulB); // per lane: a1 b1 a2 b2
        __m256 ab23 = _mm256_unpackhi_ps(mulA, mulB); // per lane: a3 b3 a4 b4
        __m256 cd01 = _mm256_unpacklo_ps(mulC, mulD); // per lane: c1 d1 c2 d2
        __m256 cd23 = _mm256_unpackhi_ps(mulC, mulD); // per lane: c3 d3 c4 d4
        __m256 row1 = _mm256_unpacklo_ps(ab01, cd01); // per lane: a1 c1 b1 d1
        __m256 row2 = _mm256_unpackhi_ps(ab01, cd01); // per lane: a2 c2 b2 d2
        __m256 row3 = _mm256_unpacklo_ps(ab23, cd23); // per lane: a3 c3 b3 d3
        __m256 row4 = _mm256_unpackhi_ps(ab23, cd23); // per lane: a4 c4 b4 d4
        row1 = _mm256_add_ps(row1, row2);
        row1 = _mm256_add_ps(row1, row3);
        row1 = _mm256_add_ps(row1, row4);
        return row1;
    } else {
        // Same transpose, applied independently to each of the four 128-bit lanes.
        __m512 mulA = _mm512_mul_ps(A0.v, A1.v);
        __m512 mulB = _mm512_mul_ps(B0.v, B1.v);
        __m512 mulC = _mm512_mul_ps(C0.v, C1.v);
        __m512 mulD = _mm512_mul_ps(D0.v, D1.v);
        __m512 ab01 = _mm512_unpacklo_ps(mulA, mulB); // per lane: a1 b1 a2 b2
        __m512 ab23 = _mm512_unpackhi_ps(mulA, mulB); // per lane: a3 b3 a4 b4
        __m512 cd01 = _mm512_unpacklo_ps(mulC, mulD); // per lane: c1 d1 c2 d2
        __m512 cd23 = _mm512_unpackhi_ps(mulC, mulD); // per lane: c3 d3 c4 d4
        __m512 row1 = _mm512_unpacklo_ps(ab01, cd01); // per lane: a1 c1 b1 d1
        __m512 row2 = _mm512_unpackhi_ps(ab01, cd01); // per lane: a2 c2 b2 d2
        __m512 row3 = _mm512_unpacklo_ps(ab23, cd23); // per lane: a3 c3 b3 d3
        __m512 row4 = _mm512_unpackhi_ps(ab23, cd23); // per lane: a4 c4 b4 d4
        row1 = _mm512_add_ps(row1, row2);
        row1 = _mm512_add_ps(row1, row3);
        row1 = _mm512_add_ps(row1, row4);
        return row1;
    }
}
constexpr static VectorF32<1, Packing*2> DotNoShuffle(
VectorF32<Len, Packing> A0, VectorF32<Len, Packing> A1,
VectorF32<Len, Packing> C0, VectorF32<Len, Packing> C1
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, float>::AlignmentElement) {
    // Dot products of the vec2 pairs (A0,A1) and (C0,C1). Each 128-bit lane j of an
    // input holds two vec2s [p.x p.y q.x q.y]; lane j of the result is
    // [A_2j A_2j+1 C_2j C_2j+1]. No cross-lane regrouping is done (see Dot).
    //
    // The interleaves use the float (_ps) unpacks: they move exactly the same 32-bit
    // elements as the _epi32 forms, but an _epi32 result passed to _mm*_add_ps needs an
    // implicit int->float vector conversion that GCC rejects unless
    // -flax-vector-conversions is given.
    if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
        __m128 mulA = _mm_mul_ps(A0.v, A1.v);            // a0x a0y a1x a1y
        __m128 mulC = _mm_mul_ps(C0.v, C1.v);            // c0x c0y c1x c1y
        __m128 firstHalf  = _mm_unpacklo_ps(mulA, mulC); // a0x c0x a0y c0y
        __m128 secondHalf = _mm_unpackhi_ps(mulA, mulC); // a1x c1x a1y c1y
        __m128 xs = _mm_unpacklo_ps(firstHalf, secondHalf); // a0x a1x c0x c1x
        __m128 ys = _mm_unpackhi_ps(firstHalf, secondHalf); // a0y a1y c0y c1y
        return _mm_add_ps(xs, ys);
    } else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
        // Same interleave, applied independently to each 128-bit lane.
        __m256 mulA = _mm256_mul_ps(A0.v, A1.v);
        __m256 mulC = _mm256_mul_ps(C0.v, C1.v);
        __m256 firstHalf  = _mm256_unpacklo_ps(mulA, mulC);
        __m256 secondHalf = _mm256_unpackhi_ps(mulA, mulC);
        __m256 xs = _mm256_unpacklo_ps(firstHalf, secondHalf);
        __m256 ys = _mm256_unpackhi_ps(firstHalf, secondHalf);
        return _mm256_add_ps(xs, ys);
    } else {
        // Same interleave, applied independently to each of the four 128-bit lanes.
        __m512 mulA = _mm512_mul_ps(A0.v, A1.v);
        __m512 mulC = _mm512_mul_ps(C0.v, C1.v);
        __m512 firstHalf  = _mm512_unpacklo_ps(mulA, mulC);
        __m512 secondHalf = _mm512_unpackhi_ps(mulA, mulC);
        __m512 xs = _mm512_unpacklo_ps(firstHalf, secondHalf);
        __m512 ys = _mm512_unpackhi_ps(firstHalf, secondHalf);
        return _mm512_add_ps(xs, ys);
    }
}
public:
template <std::array<bool, Len> ShuffleValues>
constexpr static VectorF32<Len, Packing> Blend(VectorF32<Len, Packing> a, VectorF32<Len, Packing> b) {
    // Per-lane select: lane i comes from b when bit i of the mask is set, otherwise from a.
    // The mask is produced by GetBlendMaskEpi32 from ShuffleValues (presumably bit i ==
    // ShuffleValues[i], replicated per packed vector — TODO confirm in VectorBase).
    //
    // The float blends below have exactly the same selection semantics as the
    // _epi32 blends, but stay in the float domain. For 256-bit registers, AVX vblendps
    // takes an 8-bit immediate, so no AVX-512VL mask-register form (and no
    // AVX512VL/AVX512BW feature check) is required.
    constexpr auto mask = VectorBase<Len, Packing, float>::template GetBlendMaskEpi32<ShuffleValues>();
    if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m128>) {
        return _mm_blend_ps(a.v, b.v, mask);
    } else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, float>::VectorType, __m256>) {
        return _mm256_blend_ps(a.v, b.v, mask);
    } else {
        return _mm512_mask_blend_ps(mask, a.v, b.v);
    }
}
constexpr static VectorF32<Len, Packing> Rotate(VectorF32<3, Packing> v, VectorF32<4, Packing> q) requires(Len == 3) {
    // Rotates v by quaternion q (presumably unit length, stored [x y z w]) using
    //   t  = 2 * cross(q.xyz, v)
    //   v' = v + w * t + cross(q.xyz, t)
    // where w is lane 3 of q, broadcast by the {3,3,3,3} shuffle.
    VectorF32<3, Packing> axis(q);
    VectorF32<Len, Packing> twiceCross = Cross(axis, v) * float(2);
    auto scalarPart = q.template Shuffle<{{3,3,3,3}}>();
    auto scaled = twiceCross * scalarPart;
    return v + scaled + Cross(axis, twiceCross);
}
constexpr static VectorF32<Len, Packing> RotatePivot(VectorF32<3, Packing> v, VectorF32<4, Packing> q, VectorF32<3, Packing> pivot) requires(Len == 3) {
    // Rotates v by quaternion q around pivot instead of the origin: translate so the
    // pivot is at the origin, rotate, translate back.
    //
    // Returns VectorF32<Len, Packing> (a vec3 with the caller's packing), like Rotate.
    // The rotation itself is delegated to Rotate so both functions convert q to its
    // vector part the same way (Rotate uses the VectorF32<3, Packing>(q) conversion
    // rather than reinterpreting q's raw register).
    return Rotate(v - pivot, q) + pivot;
}
// Builds a quaternion from half Euler angles.
// EulerHalf holds the three half angles (x/2, y/2, z/2); the axis convention
// (roll/pitch/yaw) is not visible here — TODO confirm with callers.
// With s = sin(EulerHalf) and c = cos(EulerHalf), and assuming that Blend<m>(a, b)
// takes lane i from b where m[i] is 1 (as _mm_blend_epi32 does), that Negate<m> flips
// the sign of the lanes where m is true, and that MulitplyAdd(a, b, c) computes a*b + c,
// the result lanes are:
//   lane 0: s0*c1*c2 - c0*s1*s2
//   lane 1: c0*s1*c2 + s0*c1*s2
//   lane 2: c0*c1*s2 - s0*s1*c2
//   lane 3: c0*c1*c2 + s0*s1*s2
// which is the common roll(x)/pitch(y)/yaw(z) conversion when lanes are [x y z w].
// Lane 3 of the converted sin/cos vectors is never read: every lane 3 comes from a
// broadcast of lane 0, 1 or 2.
constexpr static VectorF32<4, Packing> QuanternionFromEuler(VectorF32<3, Packing> EulerHalf) requires(Len == 4) {
std::tuple<VectorF32<3, Packing>, VectorF32<3, Packing>> sinCos = EulerHalf.SinCos();
VectorF32<4, Packing> sin = std::get<0>(sinCos);
VectorF32<4, Packing> cos = std::get<1>(sinCos);
// First product term: [s0 c0 c0 c0] * [c1 s1 c1 c1] * [c2 c2 s2 c2]
VectorF32<4, Packing> row1 = cos.template Shuffle<{{0,0,0,0}}>();
row1 = Blend<{{0,1,1,1}}>(sin, row1);
VectorF32<4, Packing> row2 = cos.template Shuffle<{{1,1,1,1}}>();
row2 = Blend<{{1,0,1,1}}>(sin, row2);
row1 *= row2;
VectorF32<4, Packing> row3 = cos.template Shuffle<{{2,2,2,2}}>();
row3 = Blend<{{1,1,0,1}}>(sin, row3);
row1 *= row3;
// Second product term: [c0 s0 s0 s0] * [s1 c1 s1 s1] * [-s2 s2 -c2 s2]
VectorF32<4, Packing> row4 = sin.template Shuffle<{{0,0,0,0}}>();
row4 = Blend<{{0,1,1,1}}>(cos, row4);
VectorF32<4, Packing> row5 = sin.template Shuffle<{{1,1,1,1}}>();
row5 = Blend<{{1,0,1,1}}>(cos, row5);
row4 *= row5;
VectorF32<4, Packing> row6 = sin.template Shuffle<{{2,2,2,2}}>();
row6 = Blend<{{1,1,0,1}}>(cos, row6);
// Lanes 0 and 2 subtract the second term; lanes 1 and 3 add it.
row6 = row6.template Negate<{{true,false,true,false}}>();
// row1 = row4 * row6 + row1 (second term folded into the first).
row1 = MulitplyAdd(row4, row6, row1);
return row1;
}
};
}
export template <std::uint8_t Len, std::uint8_t Packing>
struct std::formatter<Crafter::VectorF32<Len, Packing>> : std::formatter<std::string> {
    // Formats each packed vector as "{c0,c1,...}" and wraps the whole set in braces,
    // e.g. a VectorF32<3, 2> prints as "{{x0,y0,z0}{x1,y1,z1}}".
    //
    // The template parameters use std::uint8_t, the same type as Crafter::VectorF32's
    // own parameters. Deduction of a non-type template argument fails when the
    // parameter types differ ([temp.deduct.type]), so a std::uint32_t version of this
    // partial specialization would never match and std::format would fall back to the
    // disabled primary std::formatter.
    constexpr auto format(const Crafter::VectorF32<Len, Packing>& obj, format_context& ctx) const {
        std::array<float, Crafter::VectorF32<Len, Packing>::AlignmentElement> vec = obj.Store();
        std::string out = "{";
        for(std::uint32_t i = 0; i < Packing; i++) {
            out += "{";
            for(std::uint32_t i2 = 0; i2 < Len; i2++) {
                // Vectors are stored back to back, Len floats each.
                out += std::format("{}", static_cast<float>(vec[i * Len + i2]));
                if (i2 + 1 < Len) out += ",";
            }
            out += "}";
        }
        out += "}";
        return std::formatter<std::string>::format(out, ctx);
    }
};
#endif