From f1fbbe0fafae9695826802a1750be3b200d91fb1 Mon Sep 17 00:00:00 2001 From: Jorijn van der Graaf Date: Thu, 19 Mar 2026 05:53:17 +0100 Subject: [PATCH] more F16 math --- interfaces/Crafter.Math-Vector.cppm | 4 +- interfaces/Crafter.Math-VectorF16.cppm | 58 ++++++++++++++++++++--- interfaces/main.cpp | 64 ++++++++++++-------------- 3 files changed, 83 insertions(+), 43 deletions(-) diff --git a/interfaces/Crafter.Math-Vector.cppm b/interfaces/Crafter.Math-Vector.cppm index 4934e46..d8618cf 100755 --- a/interfaces/Crafter.Math-Vector.cppm +++ b/interfaces/Crafter.Math-Vector.cppm @@ -318,7 +318,7 @@ namespace Crafter { T fLength = Length(); fLength = 1.0f / fLength; - + for(std::uint32_t i = 0; i < Len; i++) { this->v[i] *= fLength; } @@ -480,7 +480,7 @@ template struct std::formatter> : std::formatter { auto format(const Crafter::Vector& obj, format_context& ctx) const { return std::formatter::format(std::format("{{{}, {}, {}, {}}}", - obj.x, obj.y, obj.z, obj.w + (float)obj.x, (float)obj.y, (float)obj.z, (float)obj.w ), ctx); } }; diff --git a/interfaces/Crafter.Math-VectorF16.cppm b/interfaces/Crafter.Math-VectorF16.cppm index b1f9e9a..a3f1a96 100755 --- a/interfaces/Crafter.Math-VectorF16.cppm +++ b/interfaces/Crafter.Math-VectorF16.cppm @@ -580,7 +580,7 @@ namespace Crafter { } } - constexpr static std::tuple, VectorF16, VectorF16, VectorF16, VectorF16, VectorF16, VectorF16, VectorF16> Normalize( + constexpr static std::tuple, VectorF16, VectorF16, VectorF16> Normalize( VectorF16 A, VectorF16 C, VectorF16 E, @@ -705,15 +705,15 @@ namespace Crafter { __m512h fLenghtG = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecG)); return { - _mm512_mul_ph(A.v, fLenghtA), - _mm512_mul_ph(C.v, fLenghtC), - _mm512_mul_ph(E.v, fLenghtE), - _mm512_mul_ph(G.v, fLenghtG), + VectorF16(_mm512_mul_ph(A.v, fLenghtA)), + VectorF16(_mm512_mul_ph(C.v, fLenghtC)), + VectorF16(_mm512_mul_ph(E.v, fLenghtE)), + VectorF16(_mm512_mul_ph(G.v, fLenghtG)), }; } } - constexpr static std::tuple, VectorF16, VectorF16, VectorF16, VectorF16, VectorF16, VectorF16, VectorF16> Normalize( + constexpr static std::tuple, VectorF16> Normalize( VectorF16 A, VectorF16 E ) requires(Packing == 4) { @@ -932,6 +932,36 @@ namespace Crafter { } } + constexpr static VectorF16 Length( + VectorF16 A, + VectorF16 C, + VectorF16 E, + VectorF16 G + ) requires(Packing == 2) { + VectorF16 lenghtSq = LengthSq(A, C, E, G); + if constexpr(std::is_same_v) { + return VectorF16(_mm_sqrt_ph(lenghtSq.v)); + } else if constexpr(std::is_same_v) { + return VectorF16(_mm256_sqrt_ph(lenghtSq.v)); + } else { + return VectorF16(_mm512_sqrt_ph(lenghtSq.v)); + } + } + + constexpr static VectorF16 Length( + VectorF16 A, + VectorF16 E + ) requires(Packing == 2) { + VectorF16 lenghtSq = LengthSq(A, E); + if constexpr(std::is_same_v) { + return VectorF16(_mm_sqrt_ph(lenghtSq.v)); + } else if constexpr(std::is_same_v) { + return VectorF16(_mm256_sqrt_ph(lenghtSq.v)); + } else { + return VectorF16(_mm512_sqrt_ph(lenghtSq.v)); + } + } + constexpr static VectorF16 LengthSq( VectorF16 A, VectorF16 B, @@ -945,6 +975,22 @@ namespace Crafter { return Dot(A, A, B, B, C, C, D, D, E, E, F, F, G, G, H, H); } + constexpr static VectorF16 LengthSq( + VectorF16 A, + VectorF16 C, + VectorF16 E, + VectorF16 G + ) requires(Packing == 2) { + return Dot(A, A, C, C, E, E, G, G); + } + + constexpr static VectorF16 LengthSq( + VectorF16 A, + VectorF16 E + ) requires(Packing == 4) { + return Dot(A, A, E, E); + } + constexpr static VectorF16 Dot( VectorF16 A0, VectorF16 A1, VectorF16 B0, VectorF16 B1, diff --git a/interfaces/main.cpp b/interfaces/main.cpp index c9ae69b..0a36443 100644 --- a/interfaces/main.cpp +++ b/interfaces/main.cpp @@ -30,44 +30,38 @@ int main() { // std::cout << std::chrono::duration_cast(end-start) << std::endl; // std::println("{}", vfC); - + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dist(0, 100); - Vector<_Float16, 1326, 32> vA; - // for(std::uint32_t i = 0; i < 2; i++) { - // vA.v[i] = i; - // } - // for(std::uint32_t i = 2; i < 4; i++) { - // vA.v[i] = i-2; - // } - // for(std::uint32_t i = 4; i < 6; i++) { - // vA.v[i] = i-4; - // } - // for(std::uint32_t i = 6; i < 8; i++) { - // vA.v[i] = i-6; - // } - for(std::uint32_t i = 0; i < 8; i++) { - vA.v[i] = i; + Vector<_Float16, 32, 32> vA; + for(std::uint32_t i = 0; i < 32; i++) { + vA.v[i] = dist(gen); } - for(std::uint32_t i = 8; i < 16; i++) { - vA.v[i] = i-8; - } - for(std::uint32_t i = 16; i < 24; i++) { - vA.v[i] = i-16; - } - for(std::uint32_t i = 24; i < 32; i++) { - } - VectorF16<8, 1, 4> vfA(&vA); - std::tuple, VectorF16<8, 1, 4>, VectorF16<8, 1, 4>, VectorF16<8, 1, 4>, VectorF16<8, 1, 4>, VectorF16<8, 1, 4>, VectorF16<8, 1, 4>, VectorF16<8, 1, 4>> dot = VectorF16<8, 1, 4>::Normalize(vfA, vfA, vfA, vfA, vfA, vfA, vfA, vfA); - std::println("{}", std::get<0>(dot)); - Vector vB; - for(std::uint32_t i = 0; i < 8; i++) { - vB.v[i] = i; - } - vB.Normalize(); std::string log; - for(std::uint32_t i = 0; i < 8; i++) { - log += std::format("{} ", (float)vB.v[i]); + std::chrono::duration totalVector(0); + std::tuple, VectorF16<4, 2, 4>, VectorF16<4, 2, 4>, VectorF16<4, 2, 4>> vfA {VectorF16<4, 2, 4>(&vA), VectorF16<4, 2, 4>(&vA), VectorF16<4, 2, 4>(&vA), VectorF16<4, 2, 4>(&vA)}; + for(std::uint32_t i = 0; i < 1000000; i++) { + auto start = std::chrono::high_resolution_clock::now(); + vfA = VectorF16<4, 2, 4>::Normalize(std::get<0>(vfA), std::get<1>(vfA), std::get<2>(vfA), std::get<3>(vfA)); + auto end = std::chrono::high_resolution_clock::now(); + totalVector += end-start; } - std::println("{{{}}}", log); + + std::chrono::duration totalScalar(0); + Vector<_Float16, 4, 4> vB; + for(std::uint32_t i = 0; i < 4; i++) { + vB.v[i] = dist(gen); + } + for(std::uint32_t i = 0; i < 1000000; i++) { + auto start2 = std::chrono::high_resolution_clock::now(); + vB.Normalize(); + auto end2 = std::chrono::high_resolution_clock::now(); + totalScalar += end2-start2; + } + + std::println("{} {} {} {}", std::get<0>(vfA), std::get<1>(vfA), std::get<2>(vfA), std::get<3>(vfA)); + std::println("{}", vB); + std::println("Vector: {}, Scalar: {}", std::chrono::duration_cast(totalVector), std::chrono::duration_cast(totalScalar*8)); } \ No newline at end of file