more F16 math

This commit is contained in:
Jorijn van der Graaf 2026-03-22 03:51:09 +01:00
commit 1544e92391
2 changed files with 306 additions and 135 deletions

View file

@ -57,36 +57,34 @@ namespace Crafter {
constexpr VectorF16() = default;
constexpr VectorF16(VectorType v) : v(v) {}
template <std::uint32_t VLen, std::uint32_t VAlign>
constexpr VectorF16(const Vector<_Float16, VLen, VAlign>* vA) requires(VAlign != 0 || VLen >= GetTotalSize()) {
constexpr VectorF16(const _Float16* vB) {
Load(vB);
};
constexpr VectorF16(_Float16 val) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
v = _mm_loadu_ph(vA->v);
v = _mm_set1_ph(val);
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
v = _mm256_loadu_ph(vA->v);
v = _mm256_set1_ph(val);
} else {
v = _mm512_loadu_ph(vA->v);
v = _mm512_set1_ph(val);
}
};
template <std::uint32_t VLen, std::uint32_t VAlign>
constexpr void Load(const Vector<_Float16, VLen, VAlign>* vA) {
constexpr void Load(const _Float16* vB) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
v = _mm_loadu_ph(vA->v);
v = _mm_loadu_ph(vB);
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
v = _mm256_loadu_ph(vA->v);
v = _mm256_loadu_ph(vB);
} else {
v = _mm512_loadu_ph(vA->v);
v = _mm512_loadu_ph(vB);
}
}
template <std::uint32_t VLen, std::uint32_t VAlign>
constexpr void Store(Vector<_Float16, VLen, VAlign>* vA) const {
constexpr void Store(const _Float16* vB) const {
if constexpr(std::is_same_v<VectorType, __m128h>) {
_mm_storeu_ph(vA->v, v);
_mm_storeu_ph(vB, v);
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
_mm256_storeu_ph(vA->v, v);
_mm256_storeu_ph(vB, v);
} else {
_mm512_storeu_ph(vA->v, v);
_mm512_storeu_ph(vB, v);
}
}
@ -157,7 +155,7 @@ namespace Crafter {
}
constexpr VectorF16<Len, Packing, Repeats> operator+=(VectorF16<Len, Packing, Repeats> b) const {
constexpr void operator+=(VectorF16<Len, Packing, Repeats> b) const {
if constexpr(std::is_same_v<VectorType, __m128h>) {
v = _mm_add_ph(v, b.v);
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
@ -167,7 +165,7 @@ namespace Crafter {
}
}
constexpr VectorF16<Len, Packing, Repeats> operator-=(VectorF16<Len, Packing, Repeats> b) const {
constexpr void operator-=(VectorF16<Len, Packing, Repeats> b) const {
if constexpr(std::is_same_v<VectorType, __m128h>) {
v = _mm_sub_ph(v, b.v);
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
@ -177,7 +175,7 @@ namespace Crafter {
}
}
constexpr VectorF16<Len, Packing, Repeats> operator*=(VectorF16<Len, Packing, Repeats> b) const {
constexpr void operator*=(VectorF16<Len, Packing, Repeats> b) const {
if constexpr(std::is_same_v<VectorType, __m128h>) {
v = _mm_mul_ph(v, b.v);
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
@ -187,7 +185,7 @@ namespace Crafter {
}
}
constexpr VectorF16<Len, Packing, Repeats> operator/=(VectorF16<Len, Packing, Repeats> b) const {
constexpr void operator/=(VectorF16<Len, Packing, Repeats> b) const {
if constexpr(std::is_same_v<VectorType, __m128h>) {
v = _mm_div_ph(v, b.v);
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
@ -197,6 +195,46 @@ namespace Crafter {
}
}
constexpr VectorF16<Len, Packing, Repeats> operator+(_Float16 b) const {
VectorF16<Len, Packing, Repeats> vB(b);
return this + vB;
}
constexpr VectorF16<Len, Packing, Repeats> operator-(_Float16 b) const {
VectorF16<Len, Packing, Repeats> vB(b);
return this - vB;
}
constexpr VectorF16<Len, Packing, Repeats> operator*(_Float16 b) const {
VectorF16<Len, Packing, Repeats> vB(b);
return this * vB;
}
constexpr VectorF16<Len, Packing, Repeats> operator/(_Float16 b) const {
VectorF16<Len, Packing, Repeats> vB(b);
return this / vB;
}
constexpr void operator+=(_Float16 b) const {
VectorF16<Len, Packing, Repeats> vB(b);
this += vB;
}
constexpr void operator-=(_Float16 b) const {
VectorF16<Len, Packing, Repeats> vB(b);
this -= vB;
}
constexpr void operator*=(_Float16 b) const {
VectorF16<Len, Packing, Repeats> vB(b);
this *= vB;
}
constexpr void operator/=(_Float16 b) const {
VectorF16<Len, Packing, Repeats> vB(b);
this /= vB;
}
constexpr VectorF16<Len, Packing, Repeats> operator-(){
if constexpr(std::is_same_v<VectorType, __m128h>) {
alignas(16) constexpr std::uint64_t mask[] {0b1000000000000000100000000000000010000000000000001000000000000000, 0b1000000000000000100000000000000010000000000000001000000000000000};
@ -223,8 +261,7 @@ namespace Crafter {
}
}
template <typename BT, std::uint32_t Blen, std::uint32_t BAlignment>
constexpr bool operator!=(Vector<BT, Blen, BAlignment> b) const {
constexpr bool operator!=(VectorF16<Len, Packing, Repeats> b) const {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return _mm_cmp_ph_mask(v, b.v, _CMP_EQ_OQ) != 255;
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
@ -253,7 +290,6 @@ namespace Crafter {
}
}
constexpr _Float16 Length() const {
_Float16 Result = LengthSq();
return std::sqrtf(Result);
@ -263,6 +299,122 @@ namespace Crafter {
return Dot(*this, *this);
}
constexpr VectorF16<Len, Packing, Repeats> Cos() requires(Len == 3) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing, Repeats>(_mm_cos_ph(v));
} else if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing, Repeats>(_mm256_cos_ph(v));
} else {
return VectorF16<Len, Packing, Repeats>(_mm512_cos_ph(v));
}
}
constexpr VectorF16<Len, Packing, Repeats> Sin() requires(Len == 3) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing, Repeats>(_mm_sin_ph(v));
} else if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing, Repeats>(_mm256_sin_ph(v));
} else {
return VectorF16<Len, Packing, Repeats>(_mm512_sin_ph(v));
}
}
template <std::uint8_t A, std::uint8_t B, std::uint8_t C, std::uint8_t D, std::uint8_t E, std::uint8_t F, std::uint8_t G, std::uint8_t H>
constexpr VectorF16<Len, Packing, Repeats> Shuffle() {
if constexpr(A == B-1 && C == D-1 && E == F-1 && G == H-1) {
constexpr std::uint32_t val =
(A & 0x3) |
((B & 0x3) << 2) |
((C & 0x3) << 4) |
((D & 0x3) << 6) |
((E & 0x3) << 8) |
((F & 0x3) << 10) |
((G & 0x3) << 12) |
((H & 0x3) << 14);
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing, Repeats>(_mm_castsi128_ph(_mm_shuffle_epi32(_mm_castph_si128(v), val)));
} else if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing, Repeats>(_mm256_castsi256_ph(_mm256_shuffle_epi32(_mm256_castph_si256(v), val)));
} else {
return VectorF16<Len, Packing, Repeats>(_mm512_castsi512_ph(_mm512_shuffle_epi32(_mm_512castph_si512(v), val)));
}
} else {
if constexpr(std::is_same_v<VectorType, __m128h>) {
constexpr std::uint8_t shuffleMask[] {
A,A,B,B,C,C,D,D,E,E,F,F,G,G,H,H
};
__m128h shuffleVec = _mm_loadu_epi8(shuffleMask);
return VectorF16<Len, Packing, Repeats>(_mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(v), shuffleVec)));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
constexpr std::uint8_t shuffleMask[] {
A,A,B,B,C,C,D,D,E,E,F,F,G,G,H,H,
A+16,A+16,B+16,B+16,C+16,C+16,D+16,D+16,E+16,E+16,F+16,F+16,G+16,G+16,H+16,H+16,
};
__m256h shuffleVec = _mm256_loadu_epi8(shuffleMask);
return VectorF16<Len, Packing, Repeats>(_mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(v), shuffleVec)));
} else {
constexpr std::uint8_t shuffleMask[] {
A,A,B,B,C,C,D,D,E,E,F,F,G,G,H,H,
A+16,A+16,B+16,B+16,C+16,C+16,D+16,D+16,E+16,E+16,F+16,F+16,G+16,G+16,H+16,H+16,
A+32,A+32,B+32,B+32,C+32,C+32,D+32,D+32,E+32,E+32,F+32,F+32,G+32,G+32,H+32,H+32,
A+48,A+48,B+48,B+48,C+48,C+48,D+48,D+48,E+48,E+48,F+48,F+48,G+48,G+48,H+48,H+48,
};
__m512h shuffleVec = _mm512_loadu_epi8(shuffleMask);
return VectorF16<Len, Packing, Repeats>(_mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(v), shuffleVec)));
}
}
}
template <
std::uint8_t A0, std::uint8_t B0, std::uint8_t C0, std::uint8_t D0, std::uint8_t E0, std::uint8_t F0, std::uint8_t G0, std::uint8_t H0,
std::uint8_t A1, std::uint8_t B1, std::uint8_t C1, std::uint8_t D1, std::uint8_t E1, std::uint8_t F1, std::uint8_t G1, std::uint8_t H1
>
constexpr VectorF16<Len, Packing, Repeats> Shuffle() requires(Repeats == 2) {
constexpr std::uint8_t shuffleMask[] {
A0,A0,B0,B0,C0,C0,D0,D0,E0,E0,F0,F0,G0,G0,H0,H0,
A1,A1,B1,B1,C1,C1,D1,D1,E1,E1,F1,F1,G1,G1,H1,H1,
};
__m256h shuffleVec = _mm256_loadu_epi8(shuffleMask);
return VectorF16<Len, Packing, Repeats>(_mm256_castsi256_ph(_mm256_shuffle_epi8(_mm256_castph_si256(v), shuffleVec)));
}
template <
std::uint8_t A0, std::uint8_t B0, std::uint8_t C0, std::uint8_t D0, std::uint8_t E0, std::uint8_t F0, std::uint8_t G0, std::uint8_t H0,
std::uint8_t A1, std::uint8_t B1, std::uint8_t C1, std::uint8_t D1, std::uint8_t E1, std::uint8_t F1, std::uint8_t G1, std::uint8_t H1,
std::uint8_t A2, std::uint8_t B2, std::uint8_t C2, std::uint8_t D2, std::uint8_t E2, std::uint8_t F2, std::uint8_t G2, std::uint8_t H2,
std::uint8_t A3, std::uint8_t B3, std::uint8_t C3, std::uint8_t D3, std::uint8_t E3, std::uint8_t F3, std::uint8_t G3, std::uint8_t H3
>
constexpr VectorF16<Len, Packing, Repeats> Shuffle() requires(Repeats == 4) {
constexpr std::uint8_t shuffleMask[] {
A0,A0,B0,B0,C0,C0,D0,D0,E0,E0,F0,F0,G0,G0,H0,H0,
A1,A1,B1,B1,C1,C1,D1,D1,E1,E1,F1,F1,G1,G1,H1,H1,
A2,A2,B2,B2,C2,C2,D2,D2,E2,E2,F2,F2,G2,G2,H2,H2,
A3,A3,B3,B3,C3,C3,D3,D3,E3,E3,F3,F3,G3,G3,H3,H3,
};
__m512h shuffleVec = _mm512_loadu_epi8(shuffleMask);
return VectorF16<Len, Packing, Repeats>(_mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(v), shuffleVec)));
}
static constexpr VectorF16<Len, Packing, Repeats> MulitplyAdd(VectorF16<Len, Packing, Repeats> a, VectorF16<Len, Packing, Repeats> b, VectorF16<Len, Packing, Repeats> add) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing, Repeats>(_mm_fmadd_ph(a, b, add));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
return VectorF16<Len, Packing, Repeats>(_mm256_fmadd_ph(a, b, add));
} else {
return VectorF16<Len, Packing, Repeats>(_mm512_fmadd_ph(a, b, add));
}
}
static constexpr VectorF16<Len, Packing, Repeats> MulitplySub(VectorF16<Len, Packing, Repeats> a, VectorF16<Len, Packing, Repeats> b, VectorF16<Len, Packing, Repeats> sub) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
return VectorF16<Len, Packing, Repeats>(_mm_fmsub_ph(a, b, sub));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
return VectorF16<Len, Packing, Repeats>(_mm256_fmsub_ph(a, b, sub));
} else {
return VectorF16<Len, Packing, Repeats>(_mm512_fmsub_ph(a, b, sub));
}
}
constexpr static VectorF16<Len, Packing, Repeats> Cross(VectorF16<Len, Packing, Repeats> a, VectorF16<Len, Packing, Repeats> b) requires(Len == 3 && Packing == 2) {
if constexpr(Len == 3) {
if constexpr(Repeats == 1) {
@ -342,6 +494,7 @@ namespace Crafter {
return _mm512_reduce_add_ph(mul);
}
}
constexpr static std::tuple<VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>> Normalize(
VectorF16<Len, Packing, Repeats> A,
@ -1304,88 +1457,106 @@ namespace Crafter {
return _mm512_add_ph(row1, row2);
}
}
template <std::uint8_t A, std::uint8_t B, std::uint8_t C, std::uint8_t D, std::uint8_t E, std::uint8_t F, std::uint8_t G, std::uint8_t H>
constexpr static VectorF16<Len, Packing, Repeats> Blend(VectorF16<Len, Packing, Repeats> a, VectorF16<Len, Packing, Repeats> b) {
if constexpr(std::is_same_v<VectorType, __m128h>) {
constexpr std::uint8_t val =
(A & 1) |
((B & 1) << 1) |
((C & 1) << 2) |
((D & 1) << 3) |
((E & 1) << 4) |
((F & 1) << 5) |
((G & 1) << 6) |
((H & 1) << 7);
return _mm_castsi128_ph(_mm_blend_epi16(_mm_castph_si128(a.v), _mm_castph_si128(b), val));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
constexpr std::uint8_t val =
(A & 1) |
((B & 1) << 1) |
((C & 1) << 2) |
((D & 1) << 3) |
((E & 1) << 4) |
((F & 1) << 5) |
((G & 1) << 6) |
((H & 1) << 7);
return _mm256_castsi256_ph(_mm256_blend_epi16(_mm256_castph_si256(a.v), _mm256_castph_si256(b), val));
} else {
constexpr std::uint8_t byte =
(A & 1) |
((B & 1) << 1) |
((C & 1) << 2) |
((D & 1) << 3) |
((E & 1) << 4) |
((F & 1) << 5) |
((G & 1) << 6) |
((H & 1) << 7);
// constexpr static VectorF16<Len, Packing, Repeats> Rotate(VectorF16<3, Packing, Repeats> v, VectorF16<4, Packing, Repeats> q) requires(Len == 3) {
// Vector<T, 3, 0> qv(q.x, q.y, q.z);
// Vector<T, 3, 0> t = Vector<T, 3, Aligment>::Cross(qv, v) * T(2);
// return v + t * q.w + Vector<T, 3, Aligment>::Cross(qv, t);
// }
constexpr std::uint32_t val = byte * 0x01010101u;
return _mm512_castsi512_ph(_mm512_mask_blend_epi16(val, _mm512_castph_si512(a.v), _mm512_castph_si512(b)));
}
}
// template <typename AT, std::uint32_t AAlignment, typename BT, std::uint32_t BAlignment, typename PT, std::uint32_t PAlignment>
// constexpr static Vector<T, 3, Aligment> RotatePivot(Vector<AT, 3, AAlignment> v, Vector<BT, 4, BAlignment> q, Vector<PT, 3, PAlignment> pivot) requires(Len == 3) {
// Vector<T, 3, 0> translated = v - pivot;
// Vector<T, 3, 0> qv(q.x, q.y, q.z);
// Vector<T, 3, 0> t = Cross(qv, translated) * T(2);
// Vector<T, 3, 0> rotated = translated + t * q.w +Cross(qv, t);
// return rotated + pivot;
// }
constexpr static VectorF16<Len, Packing, Repeats> Rotate(VectorF16<3, 2, Repeats> v, VectorF16<4, 2, Repeats> q) requires(Len == 3 && Packing == 2) {
VectorF16<3, 2, Repeats> qv(q.v);
VectorF16<Len, Packing, Repeats> t = Cross(qv, v) * _Float16(2);
return v + t * q.template Shuffle<3,3,3,3,3,3,3,3>(); + Cross(qv, t);
}
// template <typename AT, std::uint32_t AAlignment, typename BT, std::uint32_t BAlignment, typename CT, std::uint32_t CAlignment>
// constexpr static Vector<T, 4, Aligment> QuanternionFromBasis(Vector<AT, 3, AAlignment> right, Vector<BT, 3, BAlignment> up, Vector<CT, 3, CAlignment> forward) requires(Len == 4) {
// T m00 = right.x;
// T m01 = up.x;
// T m02 = forward.x;
constexpr static VectorF16<4, 2, Repeats> RotatePivot(VectorF16<3, 2, Repeats> v, VectorF16<4, 2, Repeats> q, VectorF16<3, 2, Repeats> pivot) requires(Len == 3 && Packing == 2) {
VectorF16<Len, Packing, Repeats> translated = v - pivot;
VectorF16<3, 2, Repeats> qv(q.v);
VectorF16<Len, Packing, Repeats> t = Cross(qv, translated) * _Float16(2);
VectorF16<Len, Packing, Repeats> rotated = translated + t * q.template Shuffle<3,3,3,3,3,3,3,3>() + Cross(qv, t);
return rotated + pivot;
}
// T m10 = right.y;
// T m11 = up.y;
// T m12 = forward.y;
constexpr static VectorF16<4, 2, Repeats> QuanternionFromEuler(VectorF16<3, 2, Repeats> EulerHalf) requires(Len == 3 && Packing == 2) {
VectorF16<3, 2, Repeats> sin = EulerHalf.Sin();
VectorF16<3, 2, Repeats> cos = EulerHalf.Cos();
// T m20 = right.z;
// T m21 = up.z;
// T m22 = forward.z;
VectorF16<3, 2, Repeats> row1 = cos.template Shuffle<0,0,0,0,4,4,4,4>();
row1 = VectorF16<3, 2, Repeats>::Blend<0,1,1,1, 0,1,1,1>(sin, row1);
// T trace = m00 + m11 + m22;
VectorF16<3, 2, Repeats> row2 = cos.template Shuffle<1,1,1,1,5,5,5,5>();
row2 = VectorF16<3, 2, Repeats>::Blend<1,0,1,1, 1,0,1,1>(sin, row2);
// Vector<T, 4, Aligment> q;
row1 = row2;
// if (trace > std::numeric_limits<T>::epsilon()) {
// T s = std::sqrt(trace + T(1)) * T(2);
// q.w = T(0.25) * s;
// q.x = (m21 - m12) / s;
// q.y = (m02 - m20) / s;
// q.z = (m10 - m01) / s;
// }
// else if ((m00 > m11) && (m00 > m22)) {
// T s = std::sqrt(T(1) + m00 - m11 - m22) * T(2);
// q.w = (m21 - m12) / s;
// q.x = T(0.25) * s;
// q.y = (m01 + m10) / s;
// q.z = (m02 + m20) / s;
// }
// else if (m11 > m22) {
// T s = std::sqrt(T(1) + m11 - m00 - m22) * T(2);
// q.w = (m02 - m20) / s;
// q.x = (m01 + m10) / s;
// q.y = T(0.25) * s;
// q.z = (m12 + m21) / s;
// }
// else {
// T s = std::sqrt(T(1) + m22 - m00 - m11) * T(2);
// q.w = (m10 - m01) / s;
// q.x = (m02 + m20) / s;
// q.y = (m12 + m21) / s;
// q.z = T(0.25) * s;
// }
VectorF16<3, 2, Repeats> row3 = cos.template Shuffle<2,2,2,2,6,6,6,6>();
row3 = VectorF16<3, 2, Repeats>::Blend<1,1,0,1, 1,1,0,1>(sin, row3);
// q.Normalize();
// return q;
// }
VectorF16<3, 2, Repeats> row4 = sin.template Shuffle<0,0,0,0,4,4,4,4>();
row4 = VectorF16<3, 2, Repeats>::Blend<1,0,0,0, 1,0,0,0>(sin, row4);
// constexpr static Vector<T, 4, Aligment> QuanternionFromEuler(T roll, T pitch, T yaw) {
// T cr = std::cos(roll * 0.5);
// T sr = std::sin(roll * 0.5);
// T cp = std::cos(pitch * 0.5);
// T sp = std::sin(pitch * 0.5);
// T cy = std::cos(yaw * 0.5);
// T sy = std::sin(yaw * 0.5);
// return Vector<T, 4, Aligment>(
// sr * cp * cy - cr * sp * sy,
// cr * sp * cy + sr * cp * sy,
// cr * cp * sy - sr * sp * cy,
// cr * cp * cy + sr * sp * sy
// );
// }
if constexpr(std::is_same_v<VectorType, __m128h>) {
constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000};
__m128i sign_mask = _mm_load_si128(reinterpret_cast<const __m128i*>(mask));
row4.v = (_mm_castsi128_ph(_mm_xor_si128(sign_mask, _mm_castph_si128(row4.v))));
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000};
__m256i sign_mask = _mm256_load_si256(reinterpret_cast<const __m256i*>(mask));
row4.v = (_mm256_castsi256_ph(_mm256_xor_si256(sign_mask, _mm256_castph_si256(row4.v))));
} else {
constexpr std::uint64_t mask[] {0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000, 0b1000000000000000000000000000000010000000000000000000000000000000};
__m512i sign_mask = _mm512_load_si512(reinterpret_cast<const __m256i*>(mask));
row4.v = (_mm512_castsi512_ph(_mm512_xor_si512(sign_mask, _mm512_castph_si512(row4.v))));
}
row1 = MulitplyAdd(row1, row3, row4);
VectorF16<3, 2, Repeats> row5 = sin.template Shuffle<1,1,1,1,5,5,5,5>();
row5 = VectorF16<3, 2, Repeats>::Blend<0,1,0,0, 0,1,0,0>(sin, row5);
row1 *= row5;
VectorF16<3, 2, Repeats> row6 = sin.template Shuffle<2,2,2,2,6,6,6,6>();
row6 = VectorF16<3, 2, Repeats>::Blend<0,0,1,0, 0,0,1,0>(sin, row6);
return row1 * row6;
}
};
}

View file

@ -30,38 +30,38 @@ int main() {
// std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end-start) << std::endl;
// std::println("{}", vfC);
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dist(0, 100);
// std::random_device rd;
// std::mt19937 gen(rd());
// std::uniform_real_distribution<float> dist(0, 100);
Vector<_Float16, 32, 32> vA;
for(std::uint32_t i = 0; i < 32; i++) {
vA.v[i] = dist(gen);
}
// Vector<_Float16, 32, 32> vA;
// for(std::uint32_t i = 0; i < 32; i++) {
// vA.v[i] = dist(gen);
// }
std::string log;
std::chrono::duration<double> totalVector(0);
std::tuple<VectorF16<4, 2, 4>, VectorF16<4, 2, 4>, VectorF16<4, 2, 4>, VectorF16<4, 2, 4>> vfA {VectorF16<4, 2, 4>(&vA), VectorF16<4, 2, 4>(&vA), VectorF16<4, 2, 4>(&vA), VectorF16<4, 2, 4>(&vA)};
for(std::uint32_t i = 0; i < 1000000; i++) {
auto start = std::chrono::high_resolution_clock::now();
vfA = VectorF16<4, 2, 4>::Normalize(std::get<0>(vfA), std::get<1>(vfA), std::get<2>(vfA), std::get<3>(vfA));
auto end = std::chrono::high_resolution_clock::now();
totalVector += end-start;
}
// std::string log;
// std::chrono::duration<double> totalVector(0);
// std::tuple<VectorF16<4, 2, 4>, VectorF16<4, 2, 4>, VectorF16<4, 2, 4>, VectorF16<4, 2, 4>> vfA {VectorF16<4, 2, 4>(&vA), VectorF16<4, 2, 4>(&vA), VectorF16<4, 2, 4>(&vA), VectorF16<4, 2, 4>(&vA)};
// for(std::uint32_t i = 0; i < 1000000; i++) {
// auto start = std::chrono::high_resolution_clock::now();
// vfA = VectorF16<4, 2, 4>::Normalize(std::get<0>(vfA), std::get<1>(vfA), std::get<2>(vfA), std::get<3>(vfA));
// auto end = std::chrono::high_resolution_clock::now();
// totalVector += end-start;
// }
std::chrono::duration<double> totalScalar(0);
Vector<_Float16, 4, 4> vB;
for(std::uint32_t i = 0; i < 4; i++) {
vB.v[i] = dist(gen);
}
for(std::uint32_t i = 0; i < 1000000; i++) {
auto start2 = std::chrono::high_resolution_clock::now();
vB.Normalize();
auto end2 = std::chrono::high_resolution_clock::now();
totalScalar += end2-start2;
}
// std::chrono::duration<double> totalScalar(0);
// Vector<_Float16, 4, 4> vB;
// for(std::uint32_t i = 0; i < 4; i++) {
// vB.v[i] = dist(gen);
// }
// for(std::uint32_t i = 0; i < 1000000; i++) {
// auto start2 = std::chrono::high_resolution_clock::now();
// vB.Normalize();
// auto end2 = std::chrono::high_resolution_clock::now();
// totalScalar += end2-start2;
// }
std::println("{} {} {} {}", std::get<0>(vfA), std::get<1>(vfA), std::get<2>(vfA), std::get<3>(vfA));
std::println("{}", vB);
std::println("Vector: {}, Scalar: {}", std::chrono::duration_cast<std::chrono::milliseconds>(totalVector), std::chrono::duration_cast<std::chrono::milliseconds>(totalScalar*8));
// std::println("{} {} {} {}", std::get<0>(vfA), std::get<1>(vfA), std::get<2>(vfA), std::get<3>(vfA));
// std::println("{}", vB);
// std::println("Vector: {}, Scalar: {}", std::chrono::duration_cast<std::chrono::milliseconds>(totalVector), std::chrono::duration_cast<std::chrono::milliseconds>(totalScalar*8));
}