1074 lines
No EOL
66 KiB
C++
Executable file
1074 lines
No EOL
66 KiB
C++
Executable file
/*
|
|
Crafter®.Math
|
|
Copyright (C) 2026 Catcrafts®
|
|
catcrafts.net
|
|
|
|
This library is free software; you can redistribute it and/or
|
|
modify it under the terms of the GNU Lesser General Public
|
|
License version 3.0 as published by the Free Software Foundation;
|
|
|
|
This library is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
Lesser General Public License for more details.
|
|
|
|
You should have received a copy of the GNU Lesser General Public
|
|
License along with this library; if not, write to the Free Software
|
|
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
|
*/
|
|
module;
|
|
#ifdef __x86_64
|
|
#include <immintrin.h>
|
|
#endif
|
|
export module Crafter.Math:VectorF16;
|
|
import std;
|
|
import :Common;
|
|
|
|
#ifdef __AVX512FP16__
|
|
namespace Crafter {
|
|
export template <std::uint8_t Len, std::uint8_t Packing>
|
|
struct VectorF16 : public VectorBase<Len, Packing, _Float16> {
|
|
template <std::uint8_t Len2, std::uint8_t Packing2>
|
|
friend struct VectorF16;
|
|
|
|
constexpr VectorF16() = default;
|
|
constexpr VectorF16(VectorBase<Len, Packing, _Float16>::VectorType v) {
|
|
this->v = v;
|
|
}
|
|
constexpr VectorF16(const _Float16* vB) {
|
|
Load(vB);
|
|
};
|
|
constexpr VectorF16(_Float16 val) {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
this->v = _mm_set1_ph(val);
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
this->v = _mm256_set1_ph(val);
|
|
} else {
|
|
this->v = _mm512_set1_ph(val);
|
|
}
|
|
};
|
|
constexpr void Load(const _Float16* vB) {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
this->v = _mm_loadu_ph(vB);
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
this->v = _mm256_loadu_ph(vB);
|
|
} else {
|
|
this->v = _mm512_loadu_ph(vB);
|
|
}
|
|
}
|
|
constexpr void Store(_Float16* vB) const {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
_mm_storeu_ph(vB, this->v);
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
_mm256_storeu_ph(vB, this->v);
|
|
} else {
|
|
_mm512_storeu_ph(vB, this->v);
|
|
}
|
|
}
|
|
|
|
// Copies all AlignmentElement lanes of the register into a std::array and
// returns it by value (delegates to Store(_Float16*)).
constexpr std::array<_Float16, VectorBase<Len, Packing, _Float16>::AlignmentElement> Store() const {
    using Lanes = std::array<_Float16, VectorBase<Len, Packing, _Float16>::AlignmentElement>;
    Lanes lanes;
    this->Store(lanes.data());
    return lanes;
}
|
|
|
|
template <std::uint8_t BLen, std::uint8_t BPacking>
|
|
constexpr operator VectorF16<BLen, BPacking>() const {
|
|
if constexpr (Len == BLen) {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h> && std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
return VectorF16<BLen, BPacking>(_mm256_castph256_ph128(this->v));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m512h> && std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
return VectorF16<BLen, BPacking>(_mm512_castph512_ph128(this->v));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m512h> && std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
return VectorF16<BLen, BPacking>(_mm512_castph512_ph256(this->v));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h> && std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
return VectorF16<BLen, BPacking>(_mm256_castph128_ph256(this->v));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h> && std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m512h>) {
|
|
return VectorF16<BLen, BPacking>(_mm512_castph128_ph512(this->v));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h> && std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m512h>) {
|
|
return VectorF16<BLen, BPacking>(_mm512_castph256_ph512(this->v));
|
|
} else {
|
|
return VectorF16<BLen, BPacking>(this->v);
|
|
}
|
|
} else if constexpr (BLen <= Len) {
|
|
return this->template ExtractLo<BLen>();
|
|
} else {
|
|
if constexpr(std::is_same_v<typename VectorBase<BLen, BPacking, _Float16>::VectorType, __m128h>) {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, _Float16>::Alignment> shuffleMask = VectorBase<Len, Packing, _Float16>::template GetExtractLoMaskEpi8<BLen>();
|
|
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
|
|
return VectorF16<BLen, BPacking>(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(this->v), shuffleVec)));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
constexpr std::array<std::uint16_t, VectorBase<Len, Packing, _Float16>::AlignmentElement> permMask =VectorBase<Len, Packing, _Float16>::template GetExtractLoMaskEpi16<BLen>();
|
|
__m256i permIdx = _mm256_loadu_epi16(permMask.data());
|
|
__m256i result = _mm256_permutexvar_epi16(permIdx, _mm_castph_si256(this->v));
|
|
return VectorF16<BLen, BPacking>(_mm_castsi128_ph(_mm256_castsi256_si128(result)));
|
|
} else {
|
|
constexpr std::array<std::uint16_t, VectorBase<Len, Packing, _Float16>::AlignmentElement> permMask = VectorBase<Len, Packing, _Float16>::template GetExtractLoMaskEpi16<BLen>();
|
|
__m512i permIdx = _mm512_loadu_epi16(permMask.data());
|
|
__m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(this->v));
|
|
return VectorF16<BLen, BPacking>(_mm_castsi128_ph(_mm512_castsi512_si128(result)));
|
|
}
|
|
} else if constexpr(std::is_same_v<typename VectorBase<BLen, BPacking, _Float16>::VectorType, __m256h>) {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
constexpr std::array<std::uint16_t,VectorBase<BLen, Packing, _Float16>::AlignmentElement> permMask = VectorBase<BLen, Packing, _Float16>::template GetExtractLoMaskEpi16<BLen>();
|
|
__m256i permIdx = _mm256_loadu_epi16(permMask.data());
|
|
__m256i result = _mm256_permutexvar_epi16(permIdx, _mm256_castsi128_si256(_mm_castph_si128(this->v)));
|
|
return VectorF16<BLen, BPacking>(_mm256_castsi256_ph(result));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
constexpr std::array<std::uint16_t,VectorBase<BLen, Packing, _Float16>::AlignmentElement> permMask = VectorBase<BLen, Packing, _Float16>::template GetExtractLoMaskEpi16<BLen>();
|
|
__m256i permIdx = _mm256_loadu_epi16(permMask.data());
|
|
__m256i result = _mm256_permutexvar_epi16(permIdx, _mm256_castph_si256(this->v));
|
|
return VectorF16<BLen, BPacking>(_mm256_castsi256_ph(result));
|
|
} else {
|
|
constexpr std::array<std::uint16_t,VectorBase<BLen, Packing, _Float16>::AlignmentElement> permMask = VectorBase<BLen, Packing, _Float16>::template GetExtractLoMaskEpi16<BLen>();
|
|
__m256i permIdx = _mm512_loadu_epi16(permMask.data());
|
|
__m256i result = _mm512_permutexvar_epi16(permIdx, _mm512_castsi512_si256(_mm512_castph_si512(this->v)));
|
|
return VectorF16<BLen, BPacking>(_mm256_castsi256_ph(result));
|
|
}
|
|
} else {
|
|
if constexpr(std::is_same_v<typename VectorBase<BLen, BPacking, _Float16>::VectorType, __m128h>) {
|
|
constexpr std::array<std::uint16_t,VectorBase<BLen, Packing, _Float16>::AlignmentElement> permMask = VectorBase<BLen, Packing, _Float16>::template GetExtractLoMaskEpi16<BLen>();
|
|
__m512i permIdx = _mm512_loadu_epi16(permMask.data());
|
|
__m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castsi128_si512(_mm_castph_si128(this->v)));
|
|
return VectorF16<BLen, BPacking>(_mm512_castsi512_ph(result));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
constexpr std::array<std::uint16_t,VectorBase<BLen, Packing, _Float16>::AlignmentElement> permMask = VectorBase<BLen, Packing, _Float16>::template GetExtractLoMaskEpi16<BLen>();
|
|
__m512i permIdx = _mm512_loadu_epi16(permMask.data());
|
|
__m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castsi256_si512(_mm256_castph_si256(this->v)));
|
|
return VectorF16<BLen, BPacking>(_mm512_castsi512_ph(result));
|
|
} else {
|
|
constexpr std::array<std::uint16_t,VectorBase<BLen, Packing, _Float16>::AlignmentElement> permMask = VectorBase<BLen, Packing, _Float16>::template GetExtractLoMaskEpi16<BLen>();
|
|
__m512i permIdx = _mm512_loadu_epi16(permMask.data());
|
|
__m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(this->v));
|
|
return VectorF16<BLen, BPacking>(_mm512_castsi512_ph(result));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Lane-wise half-precision addition of two vectors of the same shape.
constexpr VectorF16<Len, Packing> operator+(VectorF16<Len, Packing> rhs) const {
    using Register = typename VectorBase<Len, Packing, _Float16>::VectorType;
    if constexpr (std::is_same_v<Register, __m128h>) {
        return VectorF16<Len, Packing>(_mm_add_ph(this->v, rhs.v));
    } else if constexpr (std::is_same_v<Register, __m256h>) {
        return VectorF16<Len, Packing>(_mm256_add_ph(this->v, rhs.v));
    } else {
        return VectorF16<Len, Packing>(_mm512_add_ph(this->v, rhs.v));
    }
}

// Lane-wise half-precision subtraction (this - rhs).
constexpr VectorF16<Len, Packing> operator-(VectorF16<Len, Packing> rhs) const {
    using Register = typename VectorBase<Len, Packing, _Float16>::VectorType;
    if constexpr (std::is_same_v<Register, __m128h>) {
        return VectorF16<Len, Packing>(_mm_sub_ph(this->v, rhs.v));
    } else if constexpr (std::is_same_v<Register, __m256h>) {
        return VectorF16<Len, Packing>(_mm256_sub_ph(this->v, rhs.v));
    } else {
        return VectorF16<Len, Packing>(_mm512_sub_ph(this->v, rhs.v));
    }
}

// Lane-wise half-precision multiplication.
constexpr VectorF16<Len, Packing> operator*(VectorF16<Len, Packing> rhs) const {
    using Register = typename VectorBase<Len, Packing, _Float16>::VectorType;
    if constexpr (std::is_same_v<Register, __m128h>) {
        return VectorF16<Len, Packing>(_mm_mul_ph(this->v, rhs.v));
    } else if constexpr (std::is_same_v<Register, __m256h>) {
        return VectorF16<Len, Packing>(_mm256_mul_ph(this->v, rhs.v));
    } else {
        return VectorF16<Len, Packing>(_mm512_mul_ph(this->v, rhs.v));
    }
}

// Lane-wise half-precision division (this / rhs); IEEE semantics, no zero check.
constexpr VectorF16<Len, Packing> operator/(VectorF16<Len, Packing> rhs) const {
    using Register = typename VectorBase<Len, Packing, _Float16>::VectorType;
    if constexpr (std::is_same_v<Register, __m128h>) {
        return VectorF16<Len, Packing>(_mm_div_ph(this->v, rhs.v));
    } else if constexpr (std::is_same_v<Register, __m256h>) {
        return VectorF16<Len, Packing>(_mm256_div_ph(this->v, rhs.v));
    } else {
        return VectorF16<Len, Packing>(_mm512_div_ph(this->v, rhs.v));
    }
}
|
|
|
|
|
|
constexpr void operator+=(VectorF16<Len, Packing> b) {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
this->v = _mm_add_ph(this->v, b.v);
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
this->v = _mm256_add_ph(this->v, b.v);
|
|
} else {
|
|
this->v = _mm512_add_ph(this->v, b.v);
|
|
}
|
|
}
|
|
|
|
constexpr void operator-=(VectorF16<Len, Packing> b) {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
this->v = _mm_sub_ph(this->v, b.v);
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
this->v = _mm256_sub_ph(this->v, b.v);
|
|
} else {
|
|
this->v = _mm512_sub_ph(this->v, b.v);
|
|
}
|
|
}
|
|
|
|
constexpr void operator*=(VectorF16<Len, Packing> b) {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
this->v = _mm_mul_ph(this->v, b.v);
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
this->v = _mm256_mul_ph(this->v, b.v);
|
|
} else {
|
|
this->v = _mm512_mul_ph(this->v, b.v);
|
|
}
|
|
}
|
|
|
|
constexpr void operator/=(VectorF16<Len, Packing> b) {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
this->v = _mm_div_ph(this->v, b.v);
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
this->v = _mm256_div_ph(this->v, b.v);
|
|
} else {
|
|
this->v = _mm512_div_ph(this->v, b.v);
|
|
}
|
|
}
|
|
|
|
// Vector-with-scalar arithmetic: `b` is broadcast to every lane and the
// corresponding vector operator is applied.
// Fixed: these were non-const, so `constVec * 2.0f16` did not compile even
// though none of them modifies *this. Adding const is backward compatible.
constexpr VectorF16<Len, Packing> operator+(_Float16 b) const {
    return *this + VectorF16<Len, Packing>(b);
}

constexpr VectorF16<Len, Packing> operator-(_Float16 b) const {
    return *this - VectorF16<Len, Packing>(b);
}

constexpr VectorF16<Len, Packing> operator*(_Float16 b) const {
    return *this * VectorF16<Len, Packing>(b);
}

constexpr VectorF16<Len, Packing> operator/(_Float16 b) const {
    return *this / VectorF16<Len, Packing>(b);
}
|
|
|
|
constexpr void operator+=(_Float16 b) {
|
|
VectorF16<Len, Packing> vB(b);
|
|
*this += vB;
|
|
}
|
|
|
|
constexpr void operator-=(_Float16 b) {
|
|
VectorF16<Len, Packing> vB(b);
|
|
*this -= vB;
|
|
}
|
|
|
|
constexpr void operator*=(_Float16 b) {
|
|
VectorF16<Len, Packing> vB(b);
|
|
*this *= vB;
|
|
}
|
|
|
|
constexpr void operator/=(_Float16 b) {
|
|
VectorF16<Len, Packing> vB(b);
|
|
*this /= vB;
|
|
}
|
|
|
|
// Unary minus: negates every element by passing VectorBase's all-true lane
// selection to Negate().
constexpr VectorF16<Len, Packing> operator-(){
    return this->template Negate<VectorBase<Len, Packing, _Float16>::GetAllTrue()>();
}
|
|
|
|
constexpr bool operator==(VectorF16<Len, Packing> b) const {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
return _mm_cmp_ph_mask(this->v, b.v, _CMP_EQ_OQ) == 255;
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
return _mm256_cmp_ph_mask(this->v, b.v, _CMP_EQ_OQ) == 65535;
|
|
} else {
|
|
return _mm512_cmp_ph_mask(this->v, b.v, _CMP_EQ_OQ) == 4294967295;
|
|
}
|
|
}
|
|
|
|
constexpr bool operator!=(VectorF16<Len, Packing> b) const {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
return _mm_cmp_ph_mask(this->v, b.v, _CMP_EQ_OQ) != 255;
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
return _mm256_cmp_ph_mask(this->v, b.v, _CMP_EQ_OQ) != 65535;
|
|
} else {
|
|
return _mm512_cmp_ph_mask(this->v, b.v, _CMP_EQ_OQ) != 4294967295;
|
|
}
|
|
}
|
|
|
|
template<std::uint32_t ExtractLen>
|
|
constexpr VectorF16<ExtractLen, Packing> ExtractLo() const {
|
|
if constexpr(Packing > 1) {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
constexpr std::array<std::uint8_t,VectorBase<Len, Packing, _Float16>::Alignment> shuffleMask = VectorBase<Len, Packing, _Float16>::template GetExtractLoMaskEpi8<ExtractLen>();
|
|
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
|
|
return VectorF16<ExtractLen, Packing>(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(this->v), shuffleVec)));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
constexpr std::array<std::uint16_t,VectorBase<Len, Packing, _Float16>::AlignmentElement> permMask = VectorBase<Len, Packing, _Float16>::template GetExtractLoMaskEpi16<ExtractLen>();
|
|
__m256i permIdx = _mm256_loadu_epi16(permMask.data());
|
|
__m256i result = _mm256_permutexvar_epi16(permIdx, _mm256_castph_si256(this->v));
|
|
if constexpr(std::is_same_v<typename VectorBase<ExtractLen, Packing, _Float16>::VectorType, __m128h>) {
|
|
return VectorF16<ExtractLen, Packing>(_mm256_castph256_ph128(_mm256_castsi256_ph(result)));
|
|
} else {
|
|
return VectorF16<ExtractLen, Packing>(_mm256_castsi256_ph(result));
|
|
}
|
|
} else {
|
|
constexpr std::array<std::uint16_t, VectorBase<Len, Packing, _Float16>::AlignmentElement> permMask = VectorBase<Len, Packing, _Float16>::template GetExtractLoMaskEpi16<ExtractLen>();
|
|
__m512i permIdx = _mm512_loadu_epi16(permMask.data());
|
|
__m512i result = _mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(this->v));
|
|
if constexpr(std::is_same_v<typename VectorBase<ExtractLen, Packing, _Float16>::VectorType, __m128h>) {
|
|
return VectorF16<ExtractLen, Packing>(_mm512_castph512_ph128(_mm512_castsi512_ph(result)));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<ExtractLen, Packing, _Float16>::VectorType, __m256h>) {
|
|
return VectorF16<ExtractLen, Packing>(_mm512_castph512_ph256(_mm512_castsi512_ph(result)));
|
|
} else {
|
|
return VectorF16<ExtractLen, Packing>(_mm512_castsi512_ph(result));
|
|
}
|
|
}
|
|
} else {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h> && std::is_same_v<typename VectorBase<ExtractLen, Packing, _Float16>::VectorType, __m128h>) {
|
|
return VectorF16<ExtractLen, Packing>(_mm256_castph256_ph128(this->v));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m512h> && std::is_same_v<typename VectorBase<ExtractLen, Packing, _Float16>::VectorType, __m128h>) {
|
|
return VectorF16<ExtractLen, Packing>(_mm512_castph512_ph128(this->v));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m512h> && std::is_same_v<typename VectorBase<ExtractLen, Packing, _Float16>::VectorType, __m256h>) {
|
|
return VectorF16<ExtractLen, Packing>(_mm512_castph512_ph256(this->v));
|
|
} else {
|
|
return VectorF16<ExtractLen, Packing>(this->v);
|
|
}
|
|
}
|
|
}
|
|
|
|
constexpr void Normalize() requires(Packing == 1) {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
_Float16 dot = LengthSq();
|
|
__m128h vec = _mm_set1_ph(dot);
|
|
__m128h sqrt = _mm_sqrt_ph(vec);
|
|
this->v = _mm_div_ph(this->v, sqrt);
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
_Float16 dot = LengthSq();
|
|
__m256h vec = _mm256_set1_ph(dot);
|
|
__m256h sqrt = _mm256_sqrt_ph(vec);
|
|
this->v = _mm256_div_ph(this->v, sqrt);
|
|
} else {
|
|
_Float16 dot = LengthSq();
|
|
__m512h vec = _mm512_set1_ph(dot);
|
|
__m512h sqrt = _mm512_sqrt_ph(vec);
|
|
this->v = _mm512_div_ph(this->v, sqrt);
|
|
}
|
|
}
|
|
|
|
// Euclidean length: the square root of LengthSq(), evaluated in single
// precision and rounded back to _Float16.
// Fixed: `std::sqrtf` is missing from some standard library versions
// (libstdc++, GCC bug 79700). The float overload of `std::sqrt` computes the
// same value portably, and the narrowing back to _Float16 is now explicit.
constexpr _Float16 Length() const requires(Packing == 1) {
    const float lengthSquared = static_cast<float>(LengthSq());
    return static_cast<_Float16>(std::sqrt(lengthSquared));
}
|
|
|
|
// Squared Euclidean length: the dot product of the vector with itself.
constexpr _Float16 LengthSq() const requires(Packing == 1) {
    const VectorF16<Len, Packing>& self = *this;
    return VectorF16<Len, Packing>::Dot(self, self);
}
|
|
|
|
// Lane-wise cosine. Lanes are widened to float, evaluated with
// VectorBase::cos_f32x8 / cos_f32x16, and rounded back to half precision
// (round-to-nearest). A 512-bit register is processed as two 16-lane halves.
// Fixed: in the 512-bit branch the cos_f32x16 results were stored in __m512i,
// which neither matches the float-vector result nor is accepted by
// _mm512_cvtps_ph. They are now __m512. Also marked const: *this is not modified.
constexpr VectorF16<Len, Packing> Cos() const {
    if constexpr (std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
        __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(this->v));
        wide = VectorBase<Len, Packing, _Float16>::cos_f32x8(wide);
        return VectorF16<Len, Packing>(_mm_castsi128_ph(_mm256_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT)));
    } else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
        __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(this->v));
        wide = VectorBase<Len, Packing, _Float16>::cos_f32x16(wide);
        return VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm512_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT)));
    } else {
        const __m512i packed = _mm512_castph_si512(this->v);
        const __m256i lo = _mm512_castsi512_si256(packed);
        const __m256i hi = _mm512_extracti64x4_epi64(packed, 1);
        const __m512 cosLo = VectorBase<Len, Packing, _Float16>::cos_f32x16(_mm512_cvtph_ps(lo));
        const __m512 cosHi = VectorBase<Len, Packing, _Float16>::cos_f32x16(_mm512_cvtph_ps(hi));
        const __m256i loPh = _mm512_cvtps_ph(cosLo, _MM_FROUND_TO_NEAREST_INT);
        const __m256i hiPh = _mm512_cvtps_ph(cosHi, _MM_FROUND_TO_NEAREST_INT);
        return VectorF16<Len, Packing>(_mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(loPh), hiPh, 1)));
    }
}
|
|
|
|
// Lane-wise sine. Lanes are widened to float, evaluated with
// VectorBase::sin_f32x8 / sin_f32x16, and rounded back to half precision
// (round-to-nearest). A 512-bit register is processed as two 16-lane halves.
// Fixed: in the 512-bit branch the sin_f32x16 results were stored in __m512i,
// which neither matches the float-vector result nor is accepted by
// _mm512_cvtps_ph. They are now __m512. Also marked const: *this is not modified.
constexpr VectorF16<Len, Packing> Sin() const {
    if constexpr (std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
        __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(this->v));
        wide = VectorBase<Len, Packing, _Float16>::sin_f32x8(wide);
        return VectorF16<Len, Packing>(_mm_castsi128_ph(_mm256_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT)));
    } else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
        __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(this->v));
        wide = VectorBase<Len, Packing, _Float16>::sin_f32x16(wide);
        return VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm512_cvtps_ph(wide, _MM_FROUND_TO_NEAREST_INT)));
    } else {
        const __m512i packed = _mm512_castph_si512(this->v);
        const __m256i lo = _mm512_castsi512_si256(packed);
        const __m256i hi = _mm512_extracti64x4_epi64(packed, 1);
        const __m512 loSin = VectorBase<Len, Packing, _Float16>::sin_f32x16(_mm512_cvtph_ps(lo));
        const __m512 hiSin = VectorBase<Len, Packing, _Float16>::sin_f32x16(_mm512_cvtph_ps(hi));
        const __m256i loPh = _mm512_cvtps_ph(loSin, _MM_FROUND_TO_NEAREST_INT);
        const __m256i hiPh = _mm512_cvtps_ph(hiSin, _MM_FROUND_TO_NEAREST_INT);
        return VectorF16<Len, Packing>(_mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(loPh), hiPh, 1)));
    }
}
|
|
|
|
// Computes lane-wise sine and cosine together and returns them as {sin, cos}.
// Lanes are widened to float, evaluated with VectorBase::sincos_f32x8 /
// sincos_f32x16, and rounded back to half precision (round-to-nearest).
// A 512-bit register is processed as two 16-lane halves.
// Improved: marked const (it never modifies *this) and constexpr, consistent
// with Sin() and Cos().
constexpr std::tuple<VectorF16<Len, Packing>, VectorF16<Len, Packing>> SinCos() const {
    if constexpr (std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
        __m256 wide = _mm256_cvtph_ps(_mm_castph_si128(this->v));
        __m256 s, c;
        VectorBase<Len, Packing, _Float16>::sincos_f32x8(wide, s, c);
        return {
            VectorF16<Len, Packing>(_mm_castsi128_ph(_mm256_cvtps_ph(s, _MM_FROUND_TO_NEAREST_INT))),
            VectorF16<Len, Packing>(_mm_castsi128_ph(_mm256_cvtps_ph(c, _MM_FROUND_TO_NEAREST_INT)))
        };
    } else if constexpr (std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
        __m512 wide = _mm512_cvtph_ps(_mm256_castph_si256(this->v));
        __m512 s, c;
        VectorBase<Len, Packing, _Float16>::sincos_f32x16(wide, s, c);
        return {
            VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm512_cvtps_ph(s, _MM_FROUND_TO_NEAREST_INT))),
            VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm512_cvtps_ph(c, _MM_FROUND_TO_NEAREST_INT)))
        };
    } else {
        __m256i lo = _mm512_castsi512_si256(_mm512_castph_si512(this->v));
        __m256i hi = _mm512_extracti64x4_epi64(_mm512_castph_si512(this->v), 1);

        __m512 s_lo, c_lo, s_hi, c_hi;
        VectorBase<Len, Packing, _Float16>::sincos_f32x16(_mm512_cvtph_ps(lo), s_lo, c_lo);
        VectorBase<Len, Packing, _Float16>::sincos_f32x16(_mm512_cvtph_ps(hi), s_hi, c_hi);

        // Reassembles two 16-lane half-precision halves into one 512-bit register.
        auto pack = [](__m256i lo_ph, __m256i hi_ph) {
            return _mm512_castsi512_ph(_mm512_inserti64x4(_mm512_castsi256_si512(lo_ph), hi_ph, 1));
        };
        return {
            VectorF16<Len, Packing>(pack(_mm512_cvtps_ph(s_lo, _MM_FROUND_TO_NEAREST_INT), _mm512_cvtps_ph(s_hi, _MM_FROUND_TO_NEAREST_INT))),
            VectorF16<Len, Packing>(pack(_mm512_cvtps_ph(c_lo, _MM_FROUND_TO_NEAREST_INT), _mm512_cvtps_ph(c_hi, _MM_FROUND_TO_NEAREST_INT)))
        };
    }
}
|
|
|
|
template <std::array<bool, Len> values>
|
|
constexpr VectorF16<Len, Packing> Negate() {
|
|
std::array<_Float16, VectorBase<Len, Packing, _Float16>::AlignmentElement> mask = VectorBase<Len, Packing, _Float16>::template GetNegateMask<values>();
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
return VectorF16<Len, Packing>(_mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(this->v), _mm_loadu_epi16(mask.data()))));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
return VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm256_xor_si256(_mm256_castph_si256(this->v), _mm256_loadu_epi16(mask.data()))));
|
|
} else {
|
|
return VectorF16<Len, Packing>(_mm512_castsi512_ph(_mm512_xor_si512(_mm512_castph_si512(this->v), _mm512_loadu_epi16(mask.data()))));
|
|
}
|
|
}
|
|
|
|
// Fused multiply-add per lane: a * b + add with a single rounding (vfmadd*ph).
// The misspelled name is kept because callers depend on it.
static constexpr VectorF16<Len, Packing> MulitplyAdd(VectorF16<Len, Packing> a, VectorF16<Len, Packing> b, VectorF16<Len, Packing> add) {
    using Register = typename VectorBase<Len, Packing, _Float16>::VectorType;
    if constexpr (std::is_same_v<Register, __m128h>) {
        const __m128h fused = _mm_fmadd_ph(a.v, b.v, add.v);
        return VectorF16<Len, Packing>(fused);
    } else if constexpr (std::is_same_v<Register, __m256h>) {
        const __m256h fused = _mm256_fmadd_ph(a.v, b.v, add.v);
        return VectorF16<Len, Packing>(fused);
    } else {
        const __m512h fused = _mm512_fmadd_ph(a.v, b.v, add.v);
        return VectorF16<Len, Packing>(fused);
    }
}
|
|
|
|
// Fused multiply-subtract per lane: a * b - sub with a single rounding (vfmsub*ph).
// The misspelled name is kept because callers depend on it.
static constexpr VectorF16<Len, Packing> MulitplySub(VectorF16<Len, Packing> a, VectorF16<Len, Packing> b, VectorF16<Len, Packing> sub) {
    using Register = typename VectorBase<Len, Packing, _Float16>::VectorType;
    if constexpr (std::is_same_v<Register, __m128h>) {
        const __m128h fused = _mm_fmsub_ph(a.v, b.v, sub.v);
        return VectorF16<Len, Packing>(fused);
    } else if constexpr (std::is_same_v<Register, __m256h>) {
        const __m256h fused = _mm256_fmsub_ph(a.v, b.v, sub.v);
        return VectorF16<Len, Packing>(fused);
    } else {
        const __m512h fused = _mm512_fmsub_ph(a.v, b.v, sub.v);
        return VectorF16<Len, Packing>(fused);
    }
}
|
|
|
|
// 3-component cross product a x b, computed as
//   a.yzx * b.zxy - a.zxy * b.yzx
// The yzx / zxy lane orders are produced with byte shuffles whose masks come
// from VectorBase::GetShuffleMaskEpi8. The final step is a fused multiply-subtract.
// Improved: the 256-bit path previously widened to 512-bit registers (with
// undefined upper halves) just to call _mm512_shuffle_epi8. vpshufb works per
// 128-bit lane at every width, so _mm256_shuffle_epi8 gives identical low
// 256 bits without 512-bit ops. The masks, identical in every branch, are hoisted.
constexpr static VectorF16<Len, Packing> Cross(VectorF16<Len, Packing> a, VectorF16<Len, Packing> b) requires(Len == 3) {
    using Base = VectorBase<Len, Packing, _Float16>;
    using Register = typename Base::VectorType;
    constexpr std::array<std::uint8_t, Base::AlignmentElement * 2> maskYZX = Base::template GetShuffleMaskEpi8<{{1,2,0}}>();
    constexpr std::array<std::uint8_t, Base::AlignmentElement * 2> maskZXY = Base::template GetShuffleMaskEpi8<{{2,0,1}}>();

    if constexpr (std::is_same_v<Register, __m128h>) {
        const __m128i idxYZX = _mm_loadu_epi8(maskYZX.data());
        const __m128i idxZXY = _mm_loadu_epi8(maskZXY.data());
        const __m128h aYZX = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(a.v), idxYZX));
        const __m128h bYZX = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(b.v), idxYZX));
        const __m128h aZXY = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(a.v), idxZXY));
        const __m128h bZXY = _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(b.v), idxZXY));
        const __m128h rhsTerm = _mm_mul_ph(aZXY, bYZX);
        return VectorF16<Len, Packing>(_mm_fmsub_ph(aYZX, bZXY, rhsTerm));
    } else if constexpr (std::is_same_v<Register, __m256h>) {
        const __m256i idxYZX = _mm256_loadu_epi8(maskYZX.data());
        const __m256i idxZXY = _mm256_loadu_epi8(maskZXY.data());
        const __m256h aYZX = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(a.v), idxYZX));
        const __m256h bYZX = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(b.v), idxYZX));
        const __m256h aZXY = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(a.v), idxZXY));
        const __m256h bZXY = _mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(b.v), idxZXY));
        const __m256h rhsTerm = _mm256_mul_ph(aZXY, bYZX);
        return VectorF16<Len, Packing>(_mm256_fmsub_ph(aYZX, bZXY, rhsTerm));
    } else {
        const __m512i idxYZX = _mm512_loadu_epi8(maskYZX.data());
        const __m512i idxZXY = _mm512_loadu_epi8(maskZXY.data());
        const __m512h aYZX = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(a.v), idxYZX));
        const __m512h bYZX = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(b.v), idxYZX));
        const __m512h aZXY = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(a.v), idxZXY));
        const __m512h bZXY = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(b.v), idxZXY));
        const __m512h rhsTerm = _mm512_mul_ph(aZXY, bYZX);
        return VectorF16<Len, Packing>(_mm512_fmsub_ph(aYZX, bZXY, rhsTerm));
    }
}
|
|
|
|
// Dot product: lane-wise half-precision product of `a` and `b`, then a
// horizontal add over every lane of the register (via _mm*_reduce_add_ph).
constexpr static _Float16 Dot(VectorF16<Len, Packing> a, VectorF16<Len, Packing> b) requires(Packing == 1) {
    using Register = typename VectorBase<Len, Packing, _Float16>::VectorType;
    if constexpr (std::is_same_v<Register, __m128h>) {
        return _mm_reduce_add_ph(_mm_mul_ph(a.v, b.v));
    } else if constexpr (std::is_same_v<Register, __m256h>) {
        static_assert(std::is_same_v<decltype(a.v), typename VectorBase<Len, Packing, _Float16>::VectorType>, "a.v is NOT VectorBase<Len, Packing, _Float16>::VectorType");
        return _mm256_reduce_add_ph(_mm256_mul_ph(a.v, b.v));
    } else {
        return _mm512_reduce_add_ph(_mm512_mul_ph(a.v, b.v));
    }
}
|
|
|
|
template <const std::array<std::uint8_t, Len> ShuffleValues>
|
|
constexpr VectorF16<Len, Packing> Shuffle() {
|
|
if constexpr(VectorBase<Len, Packing, _Float16>::template CheckEpi32Shuffle<ShuffleValues>()) {
|
|
constexpr std::uint8_t imm = VectorBase<Len, Packing, _Float16>::template GetShuffleMaskEpi32<ShuffleValues>();
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
return VectorF16<Len, Packing>(_mm_castsi128_ph(_mm_shuffle_epi32(_mm_castph_si128(this->v), imm)));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
return VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm256_shuffle_epi32(_mm256_castph_si256(this->v), imm)));
|
|
} else {
|
|
return VectorF16<Len, Packing>(_mm512_castsi512_ph(_mm512_shuffle_epi32(_mm512_castph_si512(this->v), imm)));
|
|
}
|
|
} else if constexpr(VectorBase<Len, Packing, _Float16>::template CheckEpi8Shuffle<ShuffleValues>()){
|
|
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, _Float16>::Alignment> shuffleMask = VectorBase<Len, Packing, _Float16>::template GetShuffleMaskEpi8<ShuffleValues>();
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
|
|
return VectorF16<Len, Packing>(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(this->v), shuffleVec)));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
__m256i shuffleVec = _mm256_loadu_epi8(shuffleMask.data());
|
|
return VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_shuffle_epi8(_mm512_castsi256_si512(_mm256_castph_si256(this->v)), _mm512_castsi256_si512(shuffleVec)))));
|
|
} else {
|
|
__m512i shuffleVec = _mm512_loadu_epi8(shuffleMask.data());
|
|
return VectorF16<Len, Packing>(_mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(this->v), shuffleVec)));
|
|
}
|
|
} else {
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
constexpr std::array<std::uint8_t, VectorBase<Len, Packing, _Float16>::Alignment> shuffleMask = VectorBase<Len, Packing, _Float16>::template GetShuffleMaskEpi8<ShuffleValues>();
|
|
__m128i shuffleVec = _mm_loadu_epi8(shuffleMask.data());
|
|
return VectorF16<Len, Packing>(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(this->v), shuffleVec)));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
constexpr std::array<std::uint16_t, VectorBase<Len, Packing, _Float16>::AlignmentElement> permMask = VectorBase<Len, Packing, _Float16>::template GetPermuteMaskEpi16<ShuffleValues>();
|
|
__m256i permIdx = _mm256_loadu_epi16(permMask.data());
|
|
return VectorF16<Len, Packing>(_mm256_castsi256_ph(_mm256_permutexvar_epi16(permIdx, _mm256_castph_si256(this->v))));
|
|
} else {
|
|
constexpr std::array<std::uint16_t, VectorBase<Len, Packing, _Float16>::AlignmentElement> permMask = VectorBase<Len, Packing, _Float16>::template GetPermuteMaskEpi16<ShuffleValues>();
|
|
__m512i permIdx = _mm512_loadu_epi16(permMask.data());
|
|
return VectorF16<Len, Packing>(_mm512_castsi512_ph(_mm512_permutexvar_epi16(permIdx, _mm512_castph_si512(this->v))));
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Normalizes every packed 4-component sub-vector held in A, C, E and G.
///
/// Each argument register carries `Packing` independent 4-vectors. Every
/// sub-vector is multiplied by the reciprocal of its own Euclidean length.
/// Returns the scaled registers in the order {A, C, E, G}.
/// NOTE(review): a zero-length sub-vector gives 1/0 = inf and then 0*inf =
/// NaN lanes; callers are expected to avoid degenerate input.
constexpr static std::tuple<VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>, VectorF16<Len, Packing>> Normalize(
    VectorF16<Len, Packing> A,
    VectorF16<Len, Packing> C,
    VectorF16<Len, Packing> E,
    VectorF16<Len, Packing> G
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, _Float16>::AlignmentElement) {
    // C and E are passed to LengthNoShuffle deliberately swapped.
    // DotNoShuffle writes its results per 128-bit lane as
    // [p1 p1 p3 p3 p2 p2 p4 p4], where pN is the N-th operand pair.
    // The swap therefore lays out each lane as [A A C C E E G G], with two
    // packed sub-vectors per input per lane. The broadcast masks below
    // depend on that layout.
    if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
        VectorF16<1, 8> length = LengthNoShuffle(A, E, C, G);
        VectorF16<8, 1> invLength(_mm_div_ph(_mm_set1_ph(_Float16(1)), length.v));

        VectorF16<8, 1> invLengthA = invLength.template Shuffle<{{0,0,0,0,1,1,1,1}}>();
        VectorF16<8, 1> invLengthC = invLength.template Shuffle<{{2,2,2,2,3,3,3,3}}>();
        VectorF16<8, 1> invLengthE = invLength.template Shuffle<{{4,4,4,4,5,5,5,5}}>();
        VectorF16<8, 1> invLengthG = invLength.template Shuffle<{{6,6,6,6,7,7,7,7}}>();

        return {
            VectorF16<Len, Packing>(_mm_mul_ph(A.v, invLengthA.v)),
            VectorF16<Len, Packing>(_mm_mul_ph(C.v, invLengthC.v)),
            VectorF16<Len, Packing>(_mm_mul_ph(E.v, invLengthE.v)),
            VectorF16<Len, Packing>(_mm_mul_ph(G.v, invLengthG.v))
        };
    } else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
        VectorF16<1, 16> length = LengthNoShuffle(A, E, C, G);
        VectorF16<16, 1> invLength(_mm256_div_ph(_mm256_set1_ph(_Float16(1)), length.v));

        // Same per-lane pattern as the 128-bit case, repeated for lane 1 (+8).
        VectorF16<16, 1> invLengthA = invLength.template Shuffle<{{0,0,0,0,1,1,1,1,8,8,8,8,9,9,9,9}}>();
        VectorF16<16, 1> invLengthC = invLength.template Shuffle<{{2,2,2,2,3,3,3,3,10,10,10,10,11,11,11,11}}>();
        VectorF16<16, 1> invLengthE = invLength.template Shuffle<{{4,4,4,4,5,5,5,5,12,12,12,12,13,13,13,13}}>();
        VectorF16<16, 1> invLengthG = invLength.template Shuffle<{{6,6,6,6,7,7,7,7,14,14,14,14,15,15,15,15}}>();

        return {
            VectorF16<Len, Packing>(_mm256_mul_ph(A.v, invLengthA.v)),
            VectorF16<Len, Packing>(_mm256_mul_ph(C.v, invLengthC.v)),
            VectorF16<Len, Packing>(_mm256_mul_ph(E.v, invLengthE.v)),
            VectorF16<Len, Packing>(_mm256_mul_ph(G.v, invLengthG.v))
        };
    } else {
        VectorF16<1, 32> length = LengthNoShuffle(A, E, C, G);
        VectorF16<32, 1> invLength(_mm512_div_ph(_mm512_set1_ph(_Float16(1)), length.v));

        // Same per-lane pattern as the 128-bit case, repeated for lanes 1..3.
        VectorF16<32, 1> invLengthA = invLength.template Shuffle<{{0,0,0,0,1,1,1,1,8,8,8,8,9,9,9,9,16,16,16,16,17,17,17,17,24,24,24,24,25,25,25,25}}>();
        VectorF16<32, 1> invLengthC = invLength.template Shuffle<{{2,2,2,2,3,3,3,3,10,10,10,10,11,11,11,11,18,18,18,18,19,19,19,19,26,26,26,26,27,27,27,27}}>();
        VectorF16<32, 1> invLengthE = invLength.template Shuffle<{{4,4,4,4,5,5,5,5,12,12,12,12,13,13,13,13,20,20,20,20,21,21,21,21,28,28,28,28,29,29,29,29}}>();
        VectorF16<32, 1> invLengthG = invLength.template Shuffle<{{6,6,6,6,7,7,7,7,14,14,14,14,15,15,15,15,22,22,22,22,23,23,23,23,30,30,30,30,31,31,31,31}}>();

        return {
            VectorF16<Len, Packing>(_mm512_mul_ph(A.v, invLengthA.v)),
            VectorF16<Len, Packing>(_mm512_mul_ph(C.v, invLengthC.v)),
            VectorF16<Len, Packing>(_mm512_mul_ph(E.v, invLengthE.v)),
            VectorF16<Len, Packing>(_mm512_mul_ph(G.v, invLengthG.v))
        };
    }
}
|
|
|
|
constexpr static std::tuple<VectorF16<Len, Packing>, VectorF16<Len, Packing>> Normalize(
    VectorF16<Len, Packing> A,
    VectorF16<Len, Packing> E
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, _Float16>::AlignmentElement) {
    // Scales every packed 2-component sub-vector in A and E by the
    // reciprocal of its own Euclidean length and returns {A, E}.
    // NOTE(review): zero-length sub-vectors produce inf/NaN lanes.
    //
    // LengthNoShuffle leaves, per 128-bit lane, four A lengths followed by
    // four E lengths. Each mask below duplicates one length across the two
    // components of its sub-vector.
    using Reg = typename VectorBase<Len, Packing, _Float16>::VectorType;
    if constexpr(std::is_same_v<Reg, __m128h>) {
        VectorF16<1, 8> norms = LengthNoShuffle(A, E);
        VectorF16<8, 1> reciprocal(_mm_div_ph(_mm_set1_ph(_Float16(1)), norms.v));
        VectorF16<8, 1> scaleA = reciprocal.template Shuffle<{{0,0,1,1,2,2,3,3}}>();
        VectorF16<8, 1> scaleE = reciprocal.template Shuffle<{{4,4,5,5,6,6,7,7}}>();
        return {
            VectorF16<Len, Packing>(_mm_mul_ph(A.v, scaleA.v)),
            VectorF16<Len, Packing>(_mm_mul_ph(E.v, scaleE.v)),
        };
    } else if constexpr(std::is_same_v<Reg, __m256h>) {
        VectorF16<1, 16> norms = LengthNoShuffle(A, E);
        VectorF16<16, 1> reciprocal(_mm256_div_ph(_mm256_set1_ph(_Float16(1)), norms.v));
        VectorF16<16, 1> scaleA = reciprocal.template Shuffle<{{0,0,1,1,2,2,3,3,8,8,9,9,10,10,11,11}}>();
        VectorF16<16, 1> scaleE = reciprocal.template Shuffle<{{4,4,5,5,6,6,7,7,12,12,13,13,14,14,15,15}}>();
        return {
            VectorF16<Len, Packing>(_mm256_mul_ph(A.v, scaleA.v)),
            VectorF16<Len, Packing>(_mm256_mul_ph(E.v, scaleE.v)),
        };
    } else {
        VectorF16<1, 32> norms = LengthNoShuffle(A, E);
        VectorF16<32, 1> reciprocal(_mm512_div_ph(_mm512_set1_ph(_Float16(1)), norms.v));
        VectorF16<32, 1> scaleA = reciprocal.template Shuffle<{{0,0,1,1,2,2,3,3,8,8,9,9,10,10,11,11,16,16,17,17,18,18,19,19,24,24,25,25,26,26,27,27}}>();
        VectorF16<32, 1> scaleE = reciprocal.template Shuffle<{{4,4,5,5,6,6,7,7,12,12,13,13,14,14,15,15,20,20,21,21,22,22,23,23,28,28,29,29,30,30,31,31}}>();
        return {
            VectorF16<Len, Packing>(_mm512_mul_ph(A.v, scaleA.v)),
            VectorF16<Len, Packing>(_mm512_mul_ph(E.v, scaleE.v)),
        };
    }
}
|
|
|
|
constexpr static VectorF16<1, Packing*4> Length(
    VectorF16<Len, Packing> A,
    VectorF16<Len, Packing> C,
    VectorF16<Len, Packing> E,
    VectorF16<Len, Packing> G
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, _Float16>::AlignmentElement) {
    // Euclidean length of every packed 4-vector in A, C, E and G: the lane-wise
    // square root of LengthSq, in the same lane layout LengthSq returns.
    using Reg = typename VectorBase<Len, Packing, _Float16>::VectorType;
    VectorF16<1, Packing*4> squared = LengthSq(A, C, E, G);
    if constexpr(std::is_same_v<Reg, __m256h>) {
        return VectorF16<1, Packing*4>(_mm256_sqrt_ph(squared.v));
    } else if constexpr(std::is_same_v<Reg, __m128h>) {
        return VectorF16<1, Packing*4>(_mm_sqrt_ph(squared.v));
    } else {
        return VectorF16<1, Packing*4>(_mm512_sqrt_ph(squared.v));
    }
}
|
|
|
|
constexpr static VectorF16<1, Packing*2> Length(
    VectorF16<Len, Packing> A,
    VectorF16<Len, Packing> E
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, _Float16>::AlignmentElement) {
    // Euclidean length of every packed 2-vector in A and E: the lane-wise
    // square root of LengthSq, in the same lane layout LengthSq returns.
    using Reg = typename VectorBase<Len, Packing, _Float16>::VectorType;
    VectorF16<1, Packing*2> squared = LengthSq(A, E);
    if constexpr(std::is_same_v<Reg, __m256h>) {
        return VectorF16<1, Packing*2>(_mm256_sqrt_ph(squared.v));
    } else if constexpr(std::is_same_v<Reg, __m128h>) {
        return VectorF16<1, Packing*2>(_mm_sqrt_ph(squared.v));
    } else {
        return VectorF16<1, Packing*2>(_mm512_sqrt_ph(squared.v));
    }
}
|
|
|
|
constexpr static VectorF16<1, Packing*4> LengthSq(
    VectorF16<Len, Packing> A,
    VectorF16<Len, Packing> C,
    VectorF16<Len, Packing> E,
    VectorF16<Len, Packing> G
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, _Float16>::AlignmentElement) {
    // Squared Euclidean length of every packed 4-vector: each input dotted
    // with itself via the batched Dot, so the lane layout is Dot's.
    VectorF16<1, Packing*4> selfDot = Dot(A, A, C, C, E, E, G, G);
    return selfDot;
}
|
|
|
|
constexpr static VectorF16<1, Packing*2> LengthSq(
    VectorF16<Len, Packing> A,
    VectorF16<Len, Packing> E
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, _Float16>::AlignmentElement) {
    // Squared Euclidean length of every packed 2-vector in A and E: each input
    // dotted with itself via the batched Dot, so the lane layout is Dot's.
    VectorF16<1, Packing*2> selfDot = Dot(A, A, E, E);
    return selfDot;
}
|
|
|
|
/// Batched dot products for four pairs of packed 4-vector registers:
/// (A0·A1), (C0·C1), (E0·E1) and (G0·G1), each evaluated per packed sub-vector.
///
/// Result lanes are grouped per pair: the `Packing` dot products of the A
/// pair come first, followed by those of the C, E and G pairs.
constexpr static VectorF16<1, Packing*4> Dot(
    VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
    VectorF16<Len, Packing> C0, VectorF16<Len, Packing> C1,
    VectorF16<Len, Packing> E0, VectorF16<Len, Packing> E1,
    VectorF16<Len, Packing> G0, VectorF16<Len, Packing> G1
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, _Float16>::AlignmentElement) {
    // DotNoShuffle writes, per 128-bit lane L, the raw layout
    //   [p1 p1 p3 p3 p2 p2 p4 p4]   (pN = N-th operand pair)
    // i.e. A results at 8L+0/1, E at 8L+2/3, C at 8L+4/5, G at 8L+6/7.
    if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
        // A single lane: swapping the C and E pairs yields [A A C C E E G G]
        // directly, with no shuffle needed.
        return DotNoShuffle(A0, A1, E0, E1, C0, C1, G0, G1);
    } else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
        VectorF16<16, 1> vec(DotNoShuffle(A0, A1, C0, C1, E0, E1, G0, G1).v);
        vec = vec.template Shuffle<{{
            0,1,8,9,     // A
            4,5,12,13,   // C
            2,3,10,11,   // E
            6,7,14,15    // G
        }}>();
        return vec.v;
    } else {
        VectorF16<32, 1> vec(DotNoShuffle(A0, A1, C0, C1, E0, E1, G0, G1).v);
        vec = vec.template Shuffle<{{
            0,1,8,9,16,17,24,25,     // A: 8L+0, 8L+1
            4,5,12,13,20,21,28,29,   // C: 8L+4, 8L+5
            2,3,10,11,18,19,26,27,   // E: 8L+2, 8L+3 (lane 3 is 26,27 — not 24,25, which are A)
            6,7,14,15,22,23,30,31    // G: 8L+6, 8L+7
        }}>();
        return vec.v;
    }
}
|
|
|
|
constexpr static VectorF16<1, Packing*2> Dot(
    VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
    VectorF16<Len, Packing> E0, VectorF16<Len, Packing> E1
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, _Float16>::AlignmentElement) {
    // Batched dot products (A0·A1) and (E0·E1) over packed 2-vectors.
    // The result holds all of the A pair's products first, then all of E's.
    //
    // DotNoShuffle emits, per 128-bit lane, four A products followed by four
    // E products. Registers wider than 128 bits need those lane-local groups
    // gathered together.
    using Reg = typename VectorBase<Len, Packing, _Float16>::VectorType;
    VectorF16<1, Packing*2> raw = DotNoShuffle(A0, A1, E0, E1);
    if constexpr(std::is_same_v<Reg, __m128h>) {
        return raw;
    } else if constexpr(std::is_same_v<Reg, __m256h>) {
        VectorF16<16, 1> interleaved(raw.v);
        VectorF16<16, 1> grouped = interleaved.template Shuffle<{{0,1,2,3,8,9,10,11,4,5,6,7,12,13,14,15}}>();
        return grouped.v;
    } else {
        VectorF16<32, 1> interleaved(raw.v);
        VectorF16<32, 1> grouped = interleaved.template Shuffle<{{0,1,2,3,8,9,10,11,16,17,18,19,24,25,26,27,4,5,6,7,12,13,14,15,20,21,22,23,28,29,30,31}}>();
        return grouped.v;
    }
}
|
|
|
|
|
|
private:
|
|
constexpr static VectorF16<1, Packing*4> LengthNoShuffle(
    VectorF16<Len, Packing> A,
    VectorF16<Len, Packing> C,
    VectorF16<Len, Packing> E,
    VectorF16<Len, Packing> G
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, _Float16>::AlignmentElement) {
    // Euclidean lengths of every packed 4-vector in A, C, E and G, left in
    // DotNoShuffle's lane-interleaved order (no regrouping shuffle).
    // Normalize uses this and broadcasts the results with its own masks.
    using Reg = typename VectorBase<Len, Packing, _Float16>::VectorType;
    VectorF16<1, Packing*4> squared = LengthSqNoShuffle(A, C, E, G);
    if constexpr(std::is_same_v<Reg, __m256h>) {
        return VectorF16<1, Packing*4>(_mm256_sqrt_ph(squared.v));
    } else if constexpr(std::is_same_v<Reg, __m128h>) {
        return VectorF16<1, Packing*4>(_mm_sqrt_ph(squared.v));
    } else {
        return VectorF16<1, Packing*4>(_mm512_sqrt_ph(squared.v));
    }
}
|
|
|
|
constexpr static VectorF16<1, Packing*2> LengthNoShuffle(
    VectorF16<Len, Packing> A,
    VectorF16<Len, Packing> E
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, _Float16>::AlignmentElement) {
    // Euclidean lengths of every packed 2-vector in A and E, left in
    // DotNoShuffle's lane-interleaved order (no regrouping shuffle).
    // Normalize uses this and broadcasts the results with its own masks.
    using Reg = typename VectorBase<Len, Packing, _Float16>::VectorType;
    VectorF16<1, Packing*2> squared = LengthSqNoShuffle(A, E);
    if constexpr(std::is_same_v<Reg, __m256h>) {
        return VectorF16<1, Packing*2>(_mm256_sqrt_ph(squared.v));
    } else if constexpr(std::is_same_v<Reg, __m128h>) {
        return VectorF16<1, Packing*2>(_mm_sqrt_ph(squared.v));
    } else {
        return VectorF16<1, Packing*2>(_mm512_sqrt_ph(squared.v));
    }
}
|
|
|
|
constexpr static VectorF16<1, Packing*4> LengthSqNoShuffle(
    VectorF16<Len, Packing> A,
    VectorF16<Len, Packing> C,
    VectorF16<Len, Packing> E,
    VectorF16<Len, Packing> G
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, _Float16>::AlignmentElement) {
    // Squared lengths of every packed 4-vector (each input dotted with itself),
    // in DotNoShuffle's un-regrouped lane layout.
    VectorF16<1, Packing*4> selfDot = DotNoShuffle(A, A, C, C, E, E, G, G);
    return selfDot;
}
|
|
|
|
constexpr static VectorF16<1, Packing*2> LengthSqNoShuffle(
    VectorF16<Len, Packing> A,
    VectorF16<Len, Packing> E
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, _Float16>::AlignmentElement) {
    // Squared lengths of every packed 2-vector (each input dotted with itself),
    // in DotNoShuffle's un-regrouped lane layout.
    VectorF16<1, Packing*2> selfDot = DotNoShuffle(A, A, E, E);
    return selfDot;
}
|
|
|
|
|
|
constexpr static VectorF16<1, Packing*4> DotNoShuffle(
    VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
    VectorF16<Len, Packing> C0, VectorF16<Len, Packing> C1,
    VectorF16<Len, Packing> E0, VectorF16<Len, Packing> E1,
    VectorF16<Len, Packing> G0, VectorF16<Len, Packing> G1
) requires(Len == 4 && Packing*Len == VectorBase<Len, Packing, _Float16>::AlignmentElement) {
    // Four batched dot products (A0·A1, C0·C1, E0·E1, G0·G1) over packed
    // 4-vectors, WITHOUT regrouping the results.
    //
    // Each 128-bit lane of an operand holds two packed 4-vectors. The four
    // lane-wise products are transposed with three rounds of 16-bit
    // unpacklo/unpackhi, so that component j of every sub-vector ends up in
    // the same register. The four component registers are then summed.
    //
    // Per-lane notation in the comments: the product registers hold
    // sub-vectors A|B, C|D, E|F and G|H; the digit is the component index.
    // All unpacks are lane-local, so every 128-bit lane of the result reads
    //   [A B E F C D G H]
    // i.e. products of pair 1, pair 3, pair 2, pair 4. Callers regroup or
    // broadcast from there.
    if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
        __m128i prodA = _mm_castph_si128(_mm_mul_ph(A0.v, A1.v)); // A1 A2 A3 A4 B1 B2 B3 B4
        __m128i prodC = _mm_castph_si128(_mm_mul_ph(C0.v, C1.v)); // C1 C2 C3 C4 D1 D2 D3 D4
        __m128i prodE = _mm_castph_si128(_mm_mul_ph(E0.v, E1.v)); // E1 E2 E3 E4 F1 F2 F3 F4
        __m128i prodG = _mm_castph_si128(_mm_mul_ph(G0.v, G1.v)); // G1 G2 G3 G4 H1 H2 H3 H4

        __m128i acLo = _mm_unpacklo_epi16(prodA, prodC); // A1 C1 A2 C2 A3 C3 A4 C4
        __m128i acHi = _mm_unpackhi_epi16(prodA, prodC); // B1 D1 B2 D2 B3 D3 B4 D4
        __m128i egLo = _mm_unpacklo_epi16(prodE, prodG); // E1 G1 E2 G2 E3 G3 E4 G4
        __m128i egHi = _mm_unpackhi_epi16(prodE, prodG); // F1 H1 F2 H2 F3 H3 F4 H4

        __m128i aecg12 = _mm_unpacklo_epi16(acLo, egLo); // A1 E1 C1 G1 A2 E2 C2 G2
        __m128i aecg34 = _mm_unpackhi_epi16(acLo, egLo); // A3 E3 C3 G3 A4 E4 C4 G4
        __m128i bfdh12 = _mm_unpacklo_epi16(acHi, egHi); // B1 F1 D1 H1 B2 F2 D2 H2
        __m128i bfdh34 = _mm_unpackhi_epi16(acHi, egHi); // B3 F3 D3 H3 B4 F4 D4 H4

        __m128h comp1 = _mm_castsi128_ph(_mm_unpacklo_epi16(aecg12, bfdh12)); // A1 B1 E1 F1 C1 D1 G1 H1
        __m128h comp2 = _mm_castsi128_ph(_mm_unpackhi_epi16(aecg12, bfdh12)); // A2 B2 E2 F2 C2 D2 G2 H2
        __m128h comp3 = _mm_castsi128_ph(_mm_unpacklo_epi16(aecg34, bfdh34)); // A3 B3 E3 F3 C3 D3 G3 H3
        __m128h comp4 = _mm_castsi128_ph(_mm_unpackhi_epi16(aecg34, bfdh34)); // A4 B4 E4 F4 C4 D4 G4 H4

        return _mm_add_ph(_mm_add_ph(_mm_add_ph(comp1, comp2), comp3), comp4);
    } else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
        __m256i prodA = _mm256_castph_si256(_mm256_mul_ph(A0.v, A1.v));
        __m256i prodC = _mm256_castph_si256(_mm256_mul_ph(C0.v, C1.v));
        __m256i prodE = _mm256_castph_si256(_mm256_mul_ph(E0.v, E1.v));
        __m256i prodG = _mm256_castph_si256(_mm256_mul_ph(G0.v, G1.v));

        __m256i acLo = _mm256_unpacklo_epi16(prodA, prodC); // per lane: A1 C1 A2 C2 A3 C3 A4 C4
        __m256i acHi = _mm256_unpackhi_epi16(prodA, prodC); // per lane: B1 D1 B2 D2 B3 D3 B4 D4
        __m256i egLo = _mm256_unpacklo_epi16(prodE, prodG); // per lane: E1 G1 E2 G2 E3 G3 E4 G4
        __m256i egHi = _mm256_unpackhi_epi16(prodE, prodG); // per lane: F1 H1 F2 H2 F3 H3 F4 H4

        __m256i aecg12 = _mm256_unpacklo_epi16(acLo, egLo);
        __m256i aecg34 = _mm256_unpackhi_epi16(acLo, egLo);
        __m256i bfdh12 = _mm256_unpacklo_epi16(acHi, egHi);
        __m256i bfdh34 = _mm256_unpackhi_epi16(acHi, egHi);

        __m256h comp1 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(aecg12, bfdh12));
        __m256h comp2 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(aecg12, bfdh12));
        __m256h comp3 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(aecg34, bfdh34));
        __m256h comp4 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(aecg34, bfdh34));

        return _mm256_add_ph(_mm256_add_ph(_mm256_add_ph(comp1, comp2), comp3), comp4);
    } else {
        __m512i prodA = _mm512_castph_si512(_mm512_mul_ph(A0.v, A1.v));
        __m512i prodC = _mm512_castph_si512(_mm512_mul_ph(C0.v, C1.v));
        __m512i prodE = _mm512_castph_si512(_mm512_mul_ph(E0.v, E1.v));
        __m512i prodG = _mm512_castph_si512(_mm512_mul_ph(G0.v, G1.v));

        __m512i acLo = _mm512_unpacklo_epi16(prodA, prodC);
        __m512i acHi = _mm512_unpackhi_epi16(prodA, prodC);
        __m512i egLo = _mm512_unpacklo_epi16(prodE, prodG);
        __m512i egHi = _mm512_unpackhi_epi16(prodE, prodG);

        __m512i aecg12 = _mm512_unpacklo_epi16(acLo, egLo);
        __m512i aecg34 = _mm512_unpackhi_epi16(acLo, egLo);
        __m512i bfdh12 = _mm512_unpacklo_epi16(acHi, egHi);
        __m512i bfdh34 = _mm512_unpackhi_epi16(acHi, egHi);

        __m512h comp1 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(aecg12, bfdh12));
        __m512h comp2 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(aecg12, bfdh12));
        __m512h comp3 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(aecg34, bfdh34));
        __m512h comp4 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(aecg34, bfdh34));

        return _mm512_add_ph(_mm512_add_ph(_mm512_add_ph(comp1, comp2), comp3), comp4);
    }
}
|
|
|
|
constexpr static VectorF16<1, Packing*2> DotNoShuffle(
    VectorF16<Len, Packing> A0, VectorF16<Len, Packing> A1,
    VectorF16<Len, Packing> E0, VectorF16<Len, Packing> E1
) requires(Len == 2 && Packing*Len == VectorBase<Len, Packing, _Float16>::AlignmentElement) {
    // Two batched dot products (A0·A1, E0·E1) over packed 2-vectors, WITHOUT
    // regrouping the results.
    //
    // Each 128-bit lane of an operand holds four packed 2-vectors. In the
    // comments the A product's sub-vectors are A|B|C|D and the E product's
    // are E|F|G|H; the digit is the component index. Three lane-local rounds
    // of 16-bit unpacks gather component 1 and component 2 into separate
    // registers, which are then added. Every 128-bit lane of the result reads
    //   [A B C D E F G H]
    // i.e. four products of the first pair, then four of the second.
    if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
        __m128i prodA = _mm_castph_si128(_mm_mul_ph(A0.v, A1.v)); // A1 A2 B1 B2 C1 C2 D1 D2
        __m128i prodE = _mm_castph_si128(_mm_mul_ph(E0.v, E1.v)); // E1 E2 F1 F2 G1 G2 H1 H2

        __m128i lo = _mm_unpacklo_epi16(prodA, prodE); // A1 E1 A2 E2 B1 F1 B2 F2
        __m128i hi = _mm_unpackhi_epi16(prodA, prodE); // C1 G1 C2 G2 D1 H1 D2 H2

        __m128i aceg = _mm_unpacklo_epi16(lo, hi); // A1 C1 E1 G1 A2 C2 E2 G2
        __m128i bdfh = _mm_unpackhi_epi16(lo, hi); // B1 D1 F1 H1 B2 D2 F2 H2

        __m128h comp1 = _mm_castsi128_ph(_mm_unpacklo_epi16(aceg, bdfh)); // A1 B1 C1 D1 E1 F1 G1 H1
        __m128h comp2 = _mm_castsi128_ph(_mm_unpackhi_epi16(aceg, bdfh)); // A2 B2 C2 D2 E2 F2 G2 H2

        return _mm_add_ph(comp1, comp2);
    } else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
        __m256i prodA = _mm256_castph_si256(_mm256_mul_ph(A0.v, A1.v));
        __m256i prodE = _mm256_castph_si256(_mm256_mul_ph(E0.v, E1.v));

        __m256i lo = _mm256_unpacklo_epi16(prodA, prodE);
        __m256i hi = _mm256_unpackhi_epi16(prodA, prodE);

        __m256i aceg = _mm256_unpacklo_epi16(lo, hi);
        __m256i bdfh = _mm256_unpackhi_epi16(lo, hi);

        __m256h comp1 = _mm256_castsi256_ph(_mm256_unpacklo_epi16(aceg, bdfh));
        __m256h comp2 = _mm256_castsi256_ph(_mm256_unpackhi_epi16(aceg, bdfh));

        __m256h sum = _mm256_add_ph(comp1, comp2);
        return sum;
    } else {
        __m512i prodA = _mm512_castph_si512(_mm512_mul_ph(A0.v, A1.v));
        __m512i prodE = _mm512_castph_si512(_mm512_mul_ph(E0.v, E1.v));

        __m512i lo = _mm512_unpacklo_epi16(prodA, prodE);
        __m512i hi = _mm512_unpackhi_epi16(prodA, prodE);

        __m512i aceg = _mm512_unpacklo_epi16(lo, hi);
        __m512i bdfh = _mm512_unpackhi_epi16(lo, hi);

        __m512h comp1 = _mm512_castsi512_ph(_mm512_unpacklo_epi16(aceg, bdfh));
        __m512h comp2 = _mm512_castsi512_ph(_mm512_unpackhi_epi16(aceg, bdfh));

        __m512h sum = _mm512_add_ph(comp1, comp2);
        return sum;
    }
}
|
|
public:
|
|
|
|
template <std::array<bool, Len> ShuffleValues>
|
|
constexpr static VectorF16<Len, Packing> Blend(VectorF16<Len, Packing> a, VectorF16<Len, Packing> b) {
|
|
constexpr auto mask = VectorBase<Len, Packing, _Float16>::template GetBlendMaskEpi16<ShuffleValues>();
|
|
if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m128h>) {
|
|
return _mm_castsi128_ph(_mm_blend_epi16(_mm_castph_si128(a.v), _mm_castph_si128(b.v), mask));
|
|
} else if constexpr(std::is_same_v<typename VectorBase<Len, Packing, _Float16>::VectorType, __m256h>) {
|
|
#ifndef __AVX512BW__
|
|
#ifndef __AVX512VL__
|
|
static_assert(false, "No __AVX512BW__ and __AVX512VL__ support");
|
|
#endif
|
|
#endif
|
|
return _mm256_castsi256_ph(_mm256_mask_blend_epi16(mask, _mm256_castph_si256(a.v), _mm256_castph_si256(b.v)));
|
|
} else {
|
|
return _mm512_castsi512_ph(_mm512_mask_blend_epi16(mask, _mm512_castph_si512(a.v), _mm512_castph_si512(b.v)));
|
|
}
|
|
}
|
|
|
|
constexpr static VectorF16<Len, Packing> Rotate(VectorF16<3, Packing> v, VectorF16<4, Packing> q) requires(Len == 3) {
    // Rotates v by quaternion q, with lane 3 of q as the scalar part w and
    // the conversion of q to a 3-vector (presumably its xyz lanes) as the
    // axis part u. Uses the expansion
    //   t  = 2 * cross(u, v)
    //   v' = v + w * t + cross(u, t)
    // which is only a pure rotation when q is a unit quaternion (assumed).
    VectorF16<3, Packing> axis(q);
    VectorF16<Len, Packing> twiceCross = Cross(axis, v) * _Float16(2);
    return v + twiceCross * q.template Shuffle<{{3,3,3,3}}>() + Cross(axis, twiceCross);
}
|
|
|
|
/// Rotates `v` about `pivot` by quaternion `q`.
///
/// Translates so the pivot sits at the origin, applies the same
/// v + w*t + cross(u, t) expansion as Rotate (t = 2*cross(u, v), with u taken
/// from q's register and w from lane 3), then translates back.
/// The return type is VectorF16<Len, Packing> (a 3-vector, Len == 3), matching
/// Rotate. It was previously hard-coded to VectorF16<4, 2>, which mismatched
/// the computed value for any Packing.
constexpr static VectorF16<Len, Packing> RotatePivot(VectorF16<3, Packing> v, VectorF16<4, Packing> q, VectorF16<3, Packing> pivot) requires(Len == 3) {
    VectorF16<Len, Packing> translated = v - pivot;
    VectorF16<3, Packing> qv(q.v);
    VectorF16<Len, Packing> t = Cross(qv, translated) * _Float16(2);
    VectorF16<Len, Packing> rotated = translated + t * q.template Shuffle<{{3,3,3,3}}>() + Cross(qv, t);
    return rotated + pivot;
}
|
|
|
|
constexpr static VectorF16<4, Packing> QuanternionFromEuler(VectorF16<3, Packing> EulerHalf) requires(Len == 4) {
    // Builds quaternion(s) laid out (x, y, z, w) from HALF Euler angles.
    // With s = sin(EulerHalf) and c = cos(EulerHalf) it evaluates
    //   x = sx*cy*cz - cx*sy*sz
    //   y = cx*sy*cz + sx*cy*sz
    //   z = cx*cy*sz - sx*sy*cz
    //   w = cx*cy*cz + sx*sy*sz
    // as one three-factor product plus one multiply-add. This assumes
    // Blend takes its second argument where the flag is 1, Negate flips the
    // flagged lanes, and MulitplyAdd(a, b, c) == a*b + c.
    std::tuple<VectorF16<3, Packing>, VectorF16<3, Packing>> sinCos = EulerHalf.SinCos();
    VectorF16<4, Packing> s = std::get<0>(sinCos);
    VectorF16<4, Packing> c = std::get<1>(sinCos);

    // First term of every lane: (sx, cx, cx, cx) * (cy, sy, cy, cy) * (cz, cz, sz, cz)
    VectorF16<4, Packing> first = Blend<{{0,1,1,1}}>(s, c.template Shuffle<{{0,0,0,0}}>());
    VectorF16<4, Packing> factorY = Blend<{{1,0,1,1}}>(s, c.template Shuffle<{{1,1,1,1}}>());
    first *= factorY;
    VectorF16<4, Packing> factorZ = Blend<{{1,1,0,1}}>(s, c.template Shuffle<{{2,2,2,2}}>());
    first *= factorZ;

    // Second term: (cx, sx, sx, sx) * (sy, cy, sy, sy) * (-sz, sz, -cz, sz)
    VectorF16<4, Packing> second = Blend<{{0,1,1,1}}>(c, s.template Shuffle<{{0,0,0,0}}>());
    VectorF16<4, Packing> secondY = Blend<{{1,0,1,1}}>(c, s.template Shuffle<{{1,1,1,1}}>());
    second *= secondY;
    VectorF16<4, Packing> secondZ = Blend<{{1,1,0,1}}>(c, s.template Shuffle<{{2,2,2,2}}>());
    secondZ = secondZ.template Negate<{{true,false,true,false}}>();

    VectorF16<4, Packing> quat = MulitplyAdd(second, secondZ, first);
    return quat;
}
|
|
};
|
|
}
|
|
|
|
|
|
// std::format support for Crafter::VectorF16.
// Prints each packed sub-vector as {x,y,...}, all wrapped in one outer pair
// of braces, e.g. "{{1,2,3}{4,5,6}}". Values are widened to float for
// printing.
//
// The non-type parameters are std::uint8_t to match VectorF16's own
// declaration. With std::uint32_t the partial specialization could never be
// deduced: [temp.deduct.type] makes deduction fail when a non-type argument's
// type differs from the template parameter's type.
export template <std::uint8_t Len, std::uint8_t Packing>
struct std::formatter<Crafter::VectorF16<Len, Packing>> : std::formatter<std::string> {
    constexpr auto format(const Crafter::VectorF16<Len, Packing>& obj, format_context& ctx) const {
        std::array<_Float16, Crafter::VectorF16<Len, Packing>::AlignmentElement> vec = obj.Store();
        std::string out = "{";
        for(std::uint32_t i = 0; i < Packing; i++) {
            out += '{';
            for(std::uint32_t i2 = 0; i2 < Len; i2++) {
                // Append in place instead of building a temporary string per element.
                std::format_to(std::back_inserter(out), "{}", static_cast<float>(vec[i * Len + i2]));
                if (i2 + 1 < Len) out += ',';
            }
            out += '}';
        }
        out += '}';
        return std::formatter<std::string>::format(out, ctx);
    }
};
|
|
#endif |