more F16 math
This commit is contained in:
parent
c54ff6228c
commit
f1fbbe0faf
3 changed files with 82 additions and 42 deletions
|
|
@ -318,7 +318,7 @@ namespace Crafter {
|
|||
T fLength = Length();
|
||||
|
||||
fLength = 1.0f / fLength;
|
||||
|
||||
|
||||
for(std::uint32_t i = 0; i < Len; i++) {
|
||||
this->v[i] *= fLength;
|
||||
}
|
||||
|
|
@ -480,7 +480,7 @@ template <typename T, std::uint32_t Aligment>
|
|||
struct std::formatter<Crafter::Vector<T, 4, Aligment>> : std::formatter<std::string> {
|
||||
auto format(const Crafter::Vector<T, 4, Aligment>& obj, format_context& ctx) const {
|
||||
return std::formatter<std::string>::format(std::format("{{{}, {}, {}, {}}}",
|
||||
obj.x, obj.y, obj.z, obj.w
|
||||
(float)obj.x, (float)obj.y, (float)obj.z, (float)obj.w
|
||||
), ctx);
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -580,7 +580,7 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr static std::tuple<VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>> Normalize(
|
||||
constexpr static std::tuple<VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>> Normalize(
|
||||
VectorF16<Len, Packing, Repeats> A,
|
||||
VectorF16<Len, Packing, Repeats> C,
|
||||
VectorF16<Len, Packing, Repeats> E,
|
||||
|
|
@ -705,15 +705,15 @@ namespace Crafter {
|
|||
__m512h fLenghtG = _mm512_castsi512_ph(_mm512_shuffle_epi8(_mm512_castph_si512(fLenght), shuffleVecG));
|
||||
|
||||
return {
|
||||
_mm512_mul_ph(A.v, fLenghtA),
|
||||
_mm512_mul_ph(C.v, fLenghtC),
|
||||
_mm512_mul_ph(E.v, fLenghtE),
|
||||
_mm512_mul_ph(G.v, fLenghtG),
|
||||
VectorF16<Len, Packing, Repeats>(_mm512_mul_ph(A.v, fLenghtA)),
|
||||
VectorF16<Len, Packing, Repeats>(_mm512_mul_ph(C.v, fLenghtC)),
|
||||
VectorF16<Len, Packing, Repeats>(_mm512_mul_ph(E.v, fLenghtE)),
|
||||
VectorF16<Len, Packing, Repeats>(_mm512_mul_ph(G.v, fLenghtG)),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
constexpr static std::tuple<VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>> Normalize(
|
||||
constexpr static std::tuple<VectorF16<Len, Packing, Repeats>, VectorF16<Len, Packing, Repeats>> Normalize(
|
||||
VectorF16<Len, Packing, Repeats> A,
|
||||
VectorF16<Len, Packing, Repeats> E
|
||||
) requires(Packing == 4) {
|
||||
|
|
@ -932,6 +932,36 @@ namespace Crafter {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr static VectorF16<Len, Packing, Repeats> Length(
|
||||
VectorF16<Len, Packing, Repeats> A,
|
||||
VectorF16<Len, Packing, Repeats> C,
|
||||
VectorF16<Len, Packing, Repeats> E,
|
||||
VectorF16<Len, Packing, Repeats> G
|
||||
) requires(Packing == 2) {
|
||||
VectorF16<Len, Packing, Repeats> lenghtSq = LengthSq(A, C, E, G);
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF16<Len, Packing, Repeats>(_mm_sqrt_ph(lenghtSq.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
return VectorF16<Len, Packing, Repeats>(_mm256_sqrt_ph(lenghtSq.v));
|
||||
} else {
|
||||
return VectorF16<Len, Packing, Repeats>(_mm512_sqrt_ph(lenghtSq.v));
|
||||
}
|
||||
}
|
||||
|
||||
constexpr static VectorF16<Len, Packing, Repeats> Length(
|
||||
VectorF16<Len, Packing, Repeats> A,
|
||||
VectorF16<Len, Packing, Repeats> E
|
||||
) requires(Packing == 2) {
|
||||
VectorF16<Len, Packing, Repeats> lenghtSq = LengthSq(A, E);
|
||||
if constexpr(std::is_same_v<VectorType, __m128h>) {
|
||||
return VectorF16<Len, Packing, Repeats>(_mm_sqrt_ph(lenghtSq.v));
|
||||
} else if constexpr(std::is_same_v<VectorType, __m256h>) {
|
||||
return VectorF16<Len, Packing, Repeats>(_mm256_sqrt_ph(lenghtSq.v));
|
||||
} else {
|
||||
return VectorF16<Len, Packing, Repeats>(_mm512_sqrt_ph(lenghtSq.v));
|
||||
}
|
||||
}
|
||||
|
||||
constexpr static VectorF16<Len, Packing, Repeats> LengthSq(
|
||||
VectorF16<Len, Packing, Repeats> A,
|
||||
VectorF16<Len, Packing, Repeats> B,
|
||||
|
|
@ -945,6 +975,22 @@ namespace Crafter {
|
|||
return Dot(A, A, B, B, C, C, D, D, E, E, F, F, G, G, H, H);
|
||||
}
|
||||
|
||||
constexpr static VectorF16<Len, Packing, Repeats> LengthSq(
|
||||
VectorF16<Len, Packing, Repeats> A,
|
||||
VectorF16<Len, Packing, Repeats> C,
|
||||
VectorF16<Len, Packing, Repeats> E,
|
||||
VectorF16<Len, Packing, Repeats> G
|
||||
) requires(Packing == 2) {
|
||||
return Dot(A, A, C, C, E, E, G, G);
|
||||
}
|
||||
|
||||
constexpr static VectorF16<Len, Packing, Repeats> LengthSq(
|
||||
VectorF16<Len, Packing, Repeats> A,
|
||||
VectorF16<Len, Packing, Repeats> E
|
||||
) requires(Packing == 4) {
|
||||
return Dot(A, A, E, E);
|
||||
}
|
||||
|
||||
constexpr static VectorF16<Len, Packing, Repeats> Dot(
|
||||
VectorF16<Len, Packing, Repeats> A0, VectorF16<Len, Packing, Repeats> A1,
|
||||
VectorF16<Len, Packing, Repeats> B0, VectorF16<Len, Packing, Repeats> B1,
|
||||
|
|
|
|||
|
|
@ -30,44 +30,38 @@ int main() {
|
|||
// std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end-start) << std::endl;
|
||||
// std::println("{}", vfC);
|
||||
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_real_distribution<float> dist(0, 100);
|
||||
|
||||
Vector<_Float16, 1326, 32> vA;
|
||||
// for(std::uint32_t i = 0; i < 2; i++) {
|
||||
// vA.v[i] = i;
|
||||
// }
|
||||
// for(std::uint32_t i = 2; i < 4; i++) {
|
||||
// vA.v[i] = i-2;
|
||||
// }
|
||||
// for(std::uint32_t i = 4; i < 6; i++) {
|
||||
// vA.v[i] = i-4;
|
||||
// }
|
||||
// for(std::uint32_t i = 6; i < 8; i++) {
|
||||
// vA.v[i] = i-6;
|
||||
// }
|
||||
for(std::uint32_t i = 0; i < 8; i++) {
|
||||
vA.v[i] = i;
|
||||
Vector<_Float16, 32, 32> vA;
|
||||
for(std::uint32_t i = 0; i < 32; i++) {
|
||||
vA.v[i] = dist(gen);
|
||||
}
|
||||
for(std::uint32_t i = 8; i < 16; i++) {
|
||||
vA.v[i] = i-8;
|
||||
}
|
||||
for(std::uint32_t i = 16; i < 24; i++) {
|
||||
vA.v[i] = i-16;
|
||||
}
|
||||
for(std::uint32_t i = 24; i < 32; i++) {
|
||||
}
|
||||
VectorF16<8, 1, 4> vfA(&vA);
|
||||
std::tuple<VectorF16<8, 1, 4>, VectorF16<8, 1, 4>, VectorF16<8, 1, 4>, VectorF16<8, 1, 4>, VectorF16<8, 1, 4>, VectorF16<8, 1, 4>, VectorF16<8, 1, 4>, VectorF16<8, 1, 4>> dot = VectorF16<8, 1, 4>::Normalize(vfA, vfA, vfA, vfA, vfA, vfA, vfA, vfA);
|
||||
std::println("{}", std::get<0>(dot));
|
||||
|
||||
Vector<float, 8, 8> vB;
|
||||
for(std::uint32_t i = 0; i < 8; i++) {
|
||||
vB.v[i] = i;
|
||||
}
|
||||
vB.Normalize();
|
||||
std::string log;
|
||||
for(std::uint32_t i = 0; i < 8; i++) {
|
||||
log += std::format("{} ", (float)vB.v[i]);
|
||||
std::chrono::duration<double> totalVector(0);
|
||||
std::tuple<VectorF16<4, 2, 4>, VectorF16<4, 2, 4>, VectorF16<4, 2, 4>, VectorF16<4, 2, 4>> vfA {VectorF16<4, 2, 4>(&vA), VectorF16<4, 2, 4>(&vA), VectorF16<4, 2, 4>(&vA), VectorF16<4, 2, 4>(&vA)};
|
||||
for(std::uint32_t i = 0; i < 1000000; i++) {
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
vfA = VectorF16<4, 2, 4>::Normalize(std::get<0>(vfA), std::get<1>(vfA), std::get<2>(vfA), std::get<3>(vfA));
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
totalVector += end-start;
|
||||
}
|
||||
std::println("{{{}}}", log);
|
||||
|
||||
std::chrono::duration<double> totalScalar(0);
|
||||
Vector<_Float16, 4, 4> vB;
|
||||
for(std::uint32_t i = 0; i < 4; i++) {
|
||||
vB.v[i] = dist(gen);
|
||||
}
|
||||
for(std::uint32_t i = 0; i < 1000000; i++) {
|
||||
auto start2 = std::chrono::high_resolution_clock::now();
|
||||
vB.Normalize();
|
||||
auto end2 = std::chrono::high_resolution_clock::now();
|
||||
totalScalar += end2-start2;
|
||||
}
|
||||
|
||||
std::println("{} {} {} {}", std::get<0>(vfA), std::get<1>(vfA), std::get<2>(vfA), std::get<3>(vfA));
|
||||
std::println("{}", vB);
|
||||
std::println("Vector: {}, Scalar: {}", std::chrono::duration_cast<std::chrono::milliseconds>(totalVector), std::chrono::duration_cast<std::chrono::milliseconds>(totalScalar*8));
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue