Skip to content
Merged
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 75 additions & 3 deletions lib/Support/Check.cpp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Do we want to add a test case of the correct error message output?

We could use FileCheck to ensure the hex values or whatnot are being formatted as expected

Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
//===----------------------------------------------------------------------===//

#include "Support/Check.h"
#include "Support/Pipeline.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"
#include <cmath>
#include <sstream>

constexpr uint16_t Float16BitSign = 0x8000;
constexpr uint16_t Float16BitExp = 0x7c00;
Expand Down Expand Up @@ -267,30 +269,100 @@ static bool testBufferFloatULP(offloadtest::Buffer *B1, offloadtest::Buffer *B2,
return false;
}

template <typename T>
static std::string bitPatternAsHex64(const T &Val,
offloadtest::Rule ComparisonRule) {
static_assert(sizeof(T) <= sizeof(uint64_t), "Type too large for Hex64");

std::ostringstream Oss;
if (ComparisonRule == offloadtest::Rule::BufferExact)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use llvm::raw_svector_ostream to avoid the include of ? Or is not able to handle the conversion of std::hex?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The std library is the only thing aware of std::hex, unfortunately none of the other llvm ostreams are compatible with std::hex / hexfloat.

Oss << std::hex << Val;
else
Oss << std::hexfloat << Val;
return Oss.str();
}

static const std::string getBufferStr(offloadtest::Buffer *B,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Up to you, but I think it would be cleaner to use templates for this instead of a macro.
I asked co-pilot to re-write using templates.

Also includes a simplification of the logic of the for loop formatting the output string.

Edit: After writing this comment I see the same pattern used elsewhere in Pipeline.cpp. The templated version still seems cleaner IMO. And I do see the same pattern with ENUM_CASE in pipeline.h as well, but all of those cases have very simple logic in the macro. I'll still leave the decision up to you.

Something like this:

template <typename T>
std::string formatBuffer(offloadtest::Buffer* B, offloadtest::Rule rule) {
  llvm::MutableArrayRef<T> arr(reinterpret_cast<T*>(B->Data.get()), B->Size / sizeof(T));
  if (arr.empty()) return "";

  std::string result = "[ " + bitPatternAsHex64(arr[0], rule);
  for (size_t i = 1; i < arr.size(); ++i)
    result += ", " + bitPatternAsHex64(arr[i], rule);
  result += " ]";
  return result;
}

static const std::string getBufferStr(offloadtest::Buffer* B, offloadtest::Rule rule) {
  using DF = offloadtest::DataFormat;
  switch (B->Format) {
    case DF::Hex8:    return formatBuffer<llvm::yaml::Hex8>(B, rule);
    case DF::Hex16:   return formatBuffer<llvm::yaml::Hex16>(B, rule);
    case DF::Hex32:   return formatBuffer<llvm::yaml::Hex32>(B, rule);
    case DF::Hex64:   return formatBuffer<llvm::yaml::Hex64>(B, rule);
    case DF::UInt16:  return formatBuffer<uint16_t>(B, rule);
    case DF::UInt32:  return formatBuffer<uint32_t>(B, rule);
    case DF::UInt64:  return formatBuffer<uint64_t>(B, rule);
    case DF::Int16:   return formatBuffer<int16_t>(B, rule);
    case DF::Int32:   return formatBuffer<int32_t>(B, rule);
    case DF::Int64:   return formatBuffer<int64_t>(B, rule);
    case DF::Float16: return formatBuffer<llvm::yaml::Hex16>(B, rule); // assuming no native float16
    case DF::Float32: return formatBuffer<float>(B, rule);
    case DF::Float64: return formatBuffer<double>(B, rule);
    case DF::Bool:    return formatBuffer<uint32_t>(B, rule); // Because sizeof(bool) is 1 but HLSL represents a bool using 4 bytes.
    default:          return "UHO SCOOBY";
  }
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would +1 on the templated version fwiw

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed Pipeline.cpp and Check.cpp, I think I should leave pipeline.h as is.

offloadtest::Rule ComparisonRule) {
std::string ret = "";
switch (B->Format) {
#define DATA_CASE(Enum, Type) \
case offloadtest::DataFormat::Enum: { \
const llvm::MutableArrayRef<Type> Arr( \
reinterpret_cast<Type *>(B->Data.get()), B->Size / sizeof(Type)); \
if (Arr.size() == 0) \
return ""; \
if (Arr.size() == 1) \
return "[ " + bitPatternAsHex64(Arr[0], ComparisonRule) + " ]"; \
ret += " [ " + bitPatternAsHex64(Arr[0], ComparisonRule); \
for (unsigned int i = 1; i < Arr.size(); i++) \
ret += ", " + bitPatternAsHex64(Arr[i], ComparisonRule); \
ret += " ]"; \
break; \
}
DATA_CASE(Hex8, llvm::yaml::Hex8)
DATA_CASE(Hex16, llvm::yaml::Hex16)
DATA_CASE(Hex32, llvm::yaml::Hex32)
DATA_CASE(Hex64, llvm::yaml::Hex64)
DATA_CASE(UInt16, uint16_t)
DATA_CASE(UInt32, uint32_t)
DATA_CASE(UInt64, uint64_t)
DATA_CASE(Int16, int16_t)
DATA_CASE(Int32, int32_t)
DATA_CASE(Int64, int64_t)
DATA_CASE(Float16, llvm::yaml::Hex16)
DATA_CASE(Float32, float)
DATA_CASE(Float64, double)
DATA_CASE(Bool, uint32_t) // Because sizeof(bool) is 1 but HLSL represents a
// bool using 4 bytes.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think it would be helpful to have a default string with something like "getBufferStr: Unrecognized DataFormat" to make it quicker to debug in the future if a format type was added and we forgot to handle it?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewing together with @inbelic he reminded me that llvm defaults to ensure all cases are covered. So I think NOT having a default is actually preferred?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll leave it without a default for now.

}
return ret;
}

llvm::Error verifyResult(offloadtest::Result R) {
llvm::SmallString<256> Str;
llvm::raw_svector_ostream OS(Str);
OS << "Test failed: " << R.Name << "\n";

switch (R.ComparisonRule) {
case offloadtest::Rule::BufferExact: {
if (testBufferExact(R.ActualPtr, R.ExpectedPtr))
return llvm::Error::success();
OS << "Comparison Rule: BufferExact\n";
break;
}
case offloadtest::Rule::BufferFloatULP: {
if (testBufferFloatULP(R.ActualPtr, R.ExpectedPtr, R.ULPT, R.DM))
return llvm::Error::success();
OS << "Comparison Rule: BufferFloatULP\nULP: " << R.ULPT << "\n";
break;
}
case offloadtest::Rule::BufferFloatEpsilon: {
if (testBufferFloatEpsilon(R.ActualPtr, R.ExpectedPtr, R.Epsilon, R.DM))
return llvm::Error::success();
OS << "Comparison Rule: BufferFloatEpsilon\nEpsilon: " << R.Epsilon << "\n";
break;
}
}
llvm::SmallString<256> Str;
llvm::raw_svector_ostream OS(Str);
OS << "Test failed: " << R.Name << "\nExpected:\n";

OS << "Expected:\n";
llvm::yaml::Output YAMLOS(OS);
YAMLOS << *R.ExpectedPtr;
OS << "Got:\n";
YAMLOS << *R.ActualPtr;

// Now print exact hex64 representations of each element of the
// actual and expected buffers.

const std::string ExpectedBufferStr =
getBufferStr(R.ExpectedPtr, R.ComparisonRule);
const std::string ActualBufferStr =
getBufferStr(R.ActualPtr, R.ComparisonRule);

OS << "Full Hex 64bit representation of Expected Buffer Values:\n"
<< ExpectedBufferStr << "\n";
OS << "Full Hex 64bit representation of Actual Buffer Values:\n"
<< ActualBufferStr << "\n";

return llvm::createStringError(Str.c_str());
}