From 2781a7ed48a4aa75f2dd730eacf9f0919cee89b0 Mon Sep 17 00:00:00 2001 From: Frederich Munch Date: Sun, 17 Dec 2017 22:08:20 -0500 Subject: [PATCH] Fix poorly defined behavior when choosing certain function overloads. Previously the order of overload declarations would affect which was chosen. --- CMakeLists.txt | 1 + src/liboslcomp/ast.h | 11 - src/liboslcomp/typecheck.cpp | 309 +++++++++++++++++------ testsuite/function-overloads/a_fcn.h | 11 + testsuite/function-overloads/a_ivp.h | 11 + testsuite/function-overloads/b_nci.h | 12 + testsuite/function-overloads/b_vpf.h | 11 + testsuite/function-overloads/c_cnf.h | 11 + testsuite/function-overloads/c_vpi.h | 11 + testsuite/function-overloads/ref/out.txt | 46 ++++ testsuite/function-overloads/run.py | 15 ++ testsuite/function-overloads/test.osl | 120 +++++++++ 12 files changed, 477 insertions(+), 92 deletions(-) create mode 100644 testsuite/function-overloads/a_fcn.h create mode 100644 testsuite/function-overloads/a_ivp.h create mode 100644 testsuite/function-overloads/b_nci.h create mode 100644 testsuite/function-overloads/b_vpf.h create mode 100644 testsuite/function-overloads/c_cnf.h create mode 100644 testsuite/function-overloads/c_vpi.h create mode 100644 testsuite/function-overloads/ref/out.txt create mode 100755 testsuite/function-overloads/run.py create mode 100644 testsuite/function-overloads/test.osl diff --git a/CMakeLists.txt b/CMakeLists.txt index cc0097dd3..a5eec2db0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -248,6 +248,7 @@ TESTSUITE ( and-or-not-synonyms aastep arithmetic array array-derivs array-range error-dupes exit exponential fprintf function-earlyreturn function-simple function-outputelem + function-overloads geomath getattribute-camera getattribute-shader getsymbol-nonheap gettextureinfo group-outputs groupstring diff --git a/src/liboslcomp/ast.h b/src/liboslcomp/ast.h index a2e8f7a04..daac69596 100644 --- a/src/liboslcomp/ast.h +++ b/src/liboslcomp/ast.h @@ -889,17 +889,6 @@ class ASTfunction_call : public ASTNode } private: - /// Typecheck all polymorphic versions, return UNKNOWN if no match was - /// found, or a real type if there was a match. Also, upon matching, - /// re-jigger m_sym to point to the specific polymorphic match. - /// Allow arguments to be coerced (e.g., substituting a vector where - /// a point was expected, or a float where a color was expected) only - /// if coerceargs is true. For return values, allow spatial triples to - /// mutually match if 'equivreturn' is true, and allow any coercive - /// return type if 'expected' is TypeSpec() (i.e., unknown). - TypeSpec typecheck_all_poly (TypeSpec expected, bool coerceargs, - bool equivreturn); - /// Handle all the special cases for built-ins. This includes /// irregular patterns of which args are read vs written, special /// checks for printf- and texture-like, etc. diff --git a/src/liboslcomp/typecheck.cpp b/src/liboslcomp/typecheck.cpp index 0beedfd60..fd61a18bc 100644 --- a/src/liboslcomp/typecheck.cpp +++ b/src/liboslcomp/typecheck.cpp @@ -903,28 +903,229 @@ ASTNode::check_arglist (const char *funcname, ASTNode::ref arg, } +class CandidateFunctions { + enum { + kExactMatch = 100, + kIntegralToFP = 80, + kArrayMatch = 40, + kCoercable = 20, + kMatchAnything = 1, + kNoMatch = 0, + + // Additional rules that don't match C++ behaviour + kFPToIntegral = 60, // = kIntegralToFP to match c++ + kMatchReturn = kExactMatch, // = 0 to match c++ + kCoercehReturn = kCoercable, // = 0 to match c++ + }; + struct Candidate { + FunctionSymbol* sym; + TypeSpec rtype; + int ascore; + int rscore; + + Candidate(FunctionSymbol *s, TypeSpec rt, int as, int rs) : + sym(s), rtype(rt), ascore(as), rscore(rs) {} + + string_view name() const { return sym->name(); } + }; + typedef std::vector Candidates; + + OSLCompilerImpl* m_compiler; + Candidates m_candidates; + TypeSpec m_rval; + ASTNode::ref m_args; + size_t m_nargs; + + const char* scoreWildcard(int& argscore, size_t& fargs, const char* args) const { + while (fargs < m_nargs) { + argscore += kMatchAnything; + ++fargs; + } + return args + 1; + } -TypeSpec -ASTfunction_call::typecheck_all_poly (TypeSpec expected, bool coerceargs, - bool equivreturn) -{ - for (FunctionSymbol *poly = func(); poly; poly = poly->nextpoly()) { - const char *code = poly->argcodes().c_str(); + int addCandidate(FunctionSymbol* func) { int advance; - TypeSpec returntype = m_compiler->type_from_code (code, &advance); - code += advance; - if (check_arglist (m_name.c_str(), args(), code, coerceargs)) { - // Return types also must match if not coercible - if (expected == returntype || - (equivreturn && equivalent(expected,returntype)) || - expected == TypeSpec()) { - m_sym = poly; - return returntype; + const char *formals = func->argcodes().c_str(); + TypeSpec rtype = m_compiler->type_from_code (formals, &advance); + formals += advance; + + int argscore = 0; + size_t fargs = 0; + for (ASTNode::ref arg = m_args; *formals && arg; ++fargs, arg = arg->next()) { + switch (*formals) { + case '*': // Will match anything left + formals = scoreWildcard(argscore, fargs, formals); + ASSERT (*formals == 0); + continue; + + case '.': // Token/value pairs + if (arg->typespec().is_string() && arg->next()) { + formals = scoreWildcard(argscore, fargs, formals); + ASSERT (*formals == 0); + continue; + } + return kNoMatch; + + case '?': + if (formals[1] == '[' && formals[2] == ']') { + // Any array + formals += 3; + if (!arg->typespec().is_array()) + return kNoMatch; // wanted an array, didn't get one + argscore += kMatchAnything; + } else if (!arg->typespec().is_array()) { + formals += 1; // match anything + argscore += kMatchAnything; + } else + return kNoMatch; // wanted any scalar, got an array + continue; + + default: + break; } + // To many arguments for the function, done without a match. + if (fargs >= m_nargs) + return kNoMatch; + + TypeSpec argtype = arg->typespec(); + TypeSpec formaltype = m_compiler->type_from_code (formals, &advance); + formals += advance; + + if (argtype == formaltype) + argscore += kExactMatch; // ok, move on to next arg + else if (!argtype.is_closure() && argtype.is_scalarnum() && + !formaltype.is_closure() && formaltype.is_scalarnum()) + argscore += formaltype.is_int() ? kFPToIntegral : kIntegralToFP; + else if (formaltype.is_unsized_array() && argtype.is_sized_array() && + formaltype.elementtype() == argtype.elementtype()) { + // Allow a fixed-length array match to a formal array with + // unspecified length, if the element types are the same. + argscore += kArrayMatch; + } else if (assignable (formaltype, argtype)) + argscore += kCoercable; + else + return kNoMatch; } + + // Check any remaining arguments + switch (*formals) { + case '*': + case '.': + // Skip over the unused optional args + ++formals; + ++fargs; + case '\0': + if (fargs < m_nargs) + return 0; + break; + + default: + // TODO: Scoring default function arguments would go here + // Curently an unused formal argument, so no match at all. + return 0; + } + ASSERT (*formals == 0); + + int highscore = m_candidates.empty() ? 0 : m_candidates.front().ascore; + if (argscore < highscore) + return 0; + + + if (argscore == highscore) { + // Check for duplicate declarations + for (auto& candidate : m_candidates) { + if (candidate.sym->argcodes() == func->argcodes()) + return 0; + } + } else // clear any prior ambiguous matches + m_candidates.clear(); + + // append the latest high scoring function + m_candidates.emplace_back(func, rtype, argscore, rtype == m_rval ? + kMatchReturn : (equivalent(rtype, m_rval) ? kCoercehReturn : kNoMatch)); + + return argscore; } - return TypeSpec(); -} + +public: + CandidateFunctions(OSLCompilerImpl* compiler, TypeSpec rval, ASTNode::ref args, FunctionSymbol* func) : + m_compiler(compiler), m_rval(rval), m_args(args), m_nargs(0) { + + //std::cerr << "Matching " << func->name() << " formals='" << (rval.simpletype().basetype != TypeDesc::UNKNOWN ? compiler->code_from_type (rval) : " "); + for (ASTNode::ref arg = m_args; arg; arg = arg->next()) { + //std::cerr << compiler->code_from_type (arg->typespec()); + ++m_nargs; + } + //std::cerr << "'\n"; + + while (func) { + //int score = + addCandidate(func); + //std::cerr << '\t' << func->name() << " formals='" << func->argcodes().c_str() << "' " << score << ", " << (score ? m_candidates.back().rscore : 0) << "\n"; + func = func->nextpoly(); + } + } + + void reportError(ASTfunction_call* caller, string_view name) { + std::string actualargs; + for (ASTNode::ref arg = m_args; arg; arg = arg->next()) { + if (actualargs.length()) + actualargs += ", "; + actualargs += arg->typespec().string(); + } + caller->error ("No matching function call to '%s (%s)'", + name.c_str(), actualargs.c_str()); + } + + void reportAmbiguity(FunctionSymbol* sym) const { + int advance; + const char *formals = sym->argcodes().c_str(); + TypeSpec returntype = m_compiler->type_from_code (formals, &advance); + formals += advance; + + auto& errh = m_compiler->errhandler(); + if (ASTNode* decl = sym->node()) + errh.message("%s:%d ", decl->sourcefile(), decl->sourceline()); + + errh.message("candidate function:\n"); + errh.message("\t%s %s (%s)\n", + m_compiler->type_c_str(returntype), sym->name(), + m_compiler->typelist_from_code(formals).c_str()); + } + + std::pair best(ASTNode* caller, bool strict = 0) { + switch (m_candidates.size()) { + case 0: return { nullptr, TypeSpec() }; + case 1: return { m_candidates[0].sym, m_candidates[0].rtype }; + default: break; + } + + int ambiguity = 0; + std::pair c = { nullptr, -1 }; + for (auto& candidate : m_candidates) { + // re-score based on matching return value + if (candidate.rscore > c.second) + c = std::make_pair(&candidate, candidate.rscore); + else if (candidate.rscore == c.second) + ambiguity = candidate.rscore; + } + + if (ambiguity || strict) { + ASSERT (caller); + caller->warning( "call to '%s' is ambiguous", m_candidates[0].name()); + for (auto& candidate : m_candidates) { + if (candidate.rscore >= ambiguity) + reportAmbiguity(candidate.sym); + } + } + + ASSERT (c.first); + return {c.first->sym, c.first->rtype}; + } + + bool empty() const { return m_candidates.empty(); } +}; @@ -1191,47 +1392,10 @@ ASTfunction_call::typecheck (TypeSpec expected) return typecheck_struct_constructor (); } - bool match = false; - - // Look for an exact match, including expected return type - m_typespec = typecheck_all_poly (expected, false, false); - if (m_typespec != TypeSpec()) - match = true; - - // Now look for an exact match for arguments, but equivalent return type - m_typespec = typecheck_all_poly (expected, false, true); - if (m_typespec != TypeSpec()) - match = true; + CandidateFunctions candidates(m_compiler, expected, args(), func()); + std::tie(m_sym, m_typespec) = candidates.best(this); - // Now look for an exact match on args, but any return type - if (! match && expected != TypeSpec()) { - m_typespec = typecheck_all_poly (TypeSpec(), false, false); - if (m_typespec != TypeSpec()) - match = true; - } - - // Now look for a coercible match of args, exact march on return type - if (! match) { - m_typespec = typecheck_all_poly (expected, true, false); - if (m_typespec != TypeSpec()) - match = true; - } - - // Now look for a coercible match of args, equivalent march on return type - if (! match) { - m_typespec = typecheck_all_poly (expected, true, true); - if (m_typespec != TypeSpec()) - match = true; - } - - // All that failed, try for a coercible match on everything - if (! match && expected != TypeSpec()) { - m_typespec = typecheck_all_poly (TypeSpec(), true, false); - if (m_typespec != TypeSpec()) - match = true; - } - - if (match) { + if (m_sym != nullptr) { if (is_user_function()) { if (func()->number_of_returns() == 0 && ! func()->typespec().is_void()) { @@ -1245,35 +1409,18 @@ ASTfunction_call::typecheck (TypeSpec expected) return m_typespec; } + // Ambiguity has already been reported. + if (!candidates.empty()) + return TypeSpec(); + // Couldn't find any way to match any polymorphic version of the // function that we know about. OK, at least try for helpful error // message. - std::string choices (""); - for (FunctionSymbol *poly = func(); poly; poly = poly->nextpoly()) { - const char *code = poly->argcodes().c_str(); - int advance; - TypeSpec returntype = m_compiler->type_from_code (code, &advance); - code += advance; - if (choices.length()) - choices += "\n"; - choices += Strutil::format ("\t%s %s (%s)", - type_c_str(returntype), m_name.c_str(), - m_compiler->typelist_from_code(code).c_str()); - } + candidates.reportError(this, m_name); - std::string actualargs; - for (ASTNode::ref arg = args(); arg; arg = arg->next()) { - if (actualargs.length()) - actualargs += ", "; - actualargs += arg->typespec().string(); - } + for (FunctionSymbol *poly = func(); poly; poly = poly->nextpoly()) + candidates.reportAmbiguity(poly); - if (choices.size()) - error ("No matching function call to '%s (%s)'\n Candidates are:\n%s", - m_name.c_str(), actualargs.c_str(), choices.c_str()); - else - error ("No matching function call to '%s (%s)'", - m_name.c_str(), actualargs.c_str()); return TypeSpec(); } diff --git a/testsuite/function-overloads/a_fcn.h b/testsuite/function-overloads/a_fcn.h new file mode 100644 index 000000000..3156d9bce --- /dev/null +++ b/testsuite/function-overloads/a_fcn.h @@ -0,0 +1,11 @@ + +void testA(float a, float b, float c) { + printf("testA float\n"); +} +void testA(color a, float b, float c) { + printf("testA color\n"); +} +void testA(normal a, float b, float c) { + printf("testA normal\n"); +} + diff --git a/testsuite/function-overloads/a_ivp.h b/testsuite/function-overloads/a_ivp.h new file mode 100644 index 000000000..f572e5ec2 --- /dev/null +++ b/testsuite/function-overloads/a_ivp.h @@ -0,0 +1,11 @@ + +void testA(int a, float b, float c) { + printf("testA int\n"); +} +void testA(vector a, float b, float c) { + printf("testA vector\n"); +} +void testA(point a, float b, float c) { + printf("testA point\n"); +} + diff --git a/testsuite/function-overloads/b_nci.h b/testsuite/function-overloads/b_nci.h new file mode 100644 index 000000000..27aaf125f --- /dev/null +++ b/testsuite/function-overloads/b_nci.h @@ -0,0 +1,12 @@ + + +void testB(normal a, float b, float c) { + printf("testB normal\n"); +} +void testB(color a, float b, float c) { + printf("testB color\n"); +} +void testB(int a, float b, float c) { + printf("testB int\n"); +} + diff --git a/testsuite/function-overloads/b_vpf.h b/testsuite/function-overloads/b_vpf.h new file mode 100644 index 000000000..1bd8c6a7f --- /dev/null +++ b/testsuite/function-overloads/b_vpf.h @@ -0,0 +1,11 @@ + +void testB(vector a, float b, float c) { + printf("testB vector\n"); +} +void testB(point a, float b, float c) { + printf("testB point\n"); +} +void testB(float a, float b, float c) { + printf("testB float\n"); +} + diff --git a/testsuite/function-overloads/c_cnf.h b/testsuite/function-overloads/c_cnf.h new file mode 100644 index 000000000..68a1aaceb --- /dev/null +++ b/testsuite/function-overloads/c_cnf.h @@ -0,0 +1,11 @@ + + +void testC(color a, float b, float c) { + printf("testC color\n"); +} +void testC(normal a, float b, float c) { + printf("testC normal\n"); +} +void testC(float a, float b, float c) { + printf("testC float\n"); +} diff --git a/testsuite/function-overloads/c_vpi.h b/testsuite/function-overloads/c_vpi.h new file mode 100644 index 000000000..b34381010 --- /dev/null +++ b/testsuite/function-overloads/c_vpi.h @@ -0,0 +1,11 @@ + +void testC(vector a, float b, float c) { + printf("testC vector\n"); +} +void testC(point a, float b, float c) { + printf("testC point\n"); +} +void testC(int a, float b, float c) { + printf("testC int\n"); +} + diff --git a/testsuite/function-overloads/ref/out.txt b/testsuite/function-overloads/ref/out.txt new file mode 100644 index 000000000..71a7f1a25 --- /dev/null +++ b/testsuite/function-overloads/ref/out.txt @@ -0,0 +1,46 @@ +Compiled test.osl -> test.oso +testA int +testB int +testC int +testD int +testA int +testB int +testC int +testD int2 + +testA float +testB float +testC float +testD float +testA float +testB float +testC float +testD float + +testD int +testD int2 +testD float +testD vector +testD int +testD int2 +testD float +testD point +testD int +testD int2 +testD float +testD color +testD int +testD int2 +testD float +testD normal +testD int +testD int2 +testD float +testD int +testD int2 +testD float + +testE color +testE vector + + diff --git a/testsuite/function-overloads/run.py b/testsuite/function-overloads/run.py new file mode 100755 index 000000000..75c92e399 --- /dev/null +++ b/testsuite/function-overloads/run.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python + +realruntest = runtest + +def runtest (command, *args, **kwargs) : + passed = True + for arg in ("-DORDER_1 ", ""): + command = oslc(arg + "test.osl") + command += testshade("-g 1 1 test") + if realruntest(command, *args, **kwargs): + passed = False + return not passed + +command = "" + diff --git a/testsuite/function-overloads/test.osl b/testsuite/function-overloads/test.osl new file mode 100644 index 000000000..bf117aa9b --- /dev/null +++ b/testsuite/function-overloads/test.osl @@ -0,0 +1,120 @@ + +#ifdef ORDER_1 + #include "a_fcn.h" + #include "b_nci.h" + #include "c_vpi.h" + + #include "a_ivp.h" + #include "b_vpf.h" + #include "c_cnf.h" +#else + #include "a_ivp.h" + #include "b_vpf.h" + #include "c_cnf.h" + + #include "a_fcn.h" + #include "b_nci.h" + #include "c_vpi.h" +#endif + +int intval() { return 0; } +float floatval() { return 0.0; } + + +normal testD(normal a, float b, float c) { + printf("testD normal\n"); + return normal(0); +} +color testD(color a, float b, float c) { + printf("testD color\n"); + return color(1); +} +vector testD(vector a, float b, float c) { + printf("testD vector\n"); + return vector(2); +} +point testD(point a, float b, float c) { + printf("testD point\n"); + return point(3); +} +int testD(int a, float b, float c) { + printf("testD int\n"); + return 4; +} +float testD(float a, float b, float c) { + printf("testD float\n"); + return 5; +} + +// This would break C++ +int testD(int a, int b, float c) { + printf("testD int2\n"); + return 4; +} + +int testE(int a, float b, color s) { + printf("testE color\n"); + return 4; +} +int testE(float a, int b, vector v) { + printf("testE vector\n"); + return 4; +} + +shader test () +{ + { + testA(intval(), 1.0, 1.0); + testB(intval(), 1.0, 1.0); + testC(intval(), 1.0, 1.0); + testD(intval(), 1.0, 1.0); + testA(intval(), 1, 1); + testB(intval(), 1, 1); + testC(intval(), 1, 1); + testD(intval(), 1, 1); + printf("\n"); + } + + { + testA(floatval(), 1.0, 1.0); + testB(floatval(), 1.0, 1.0); + testC(floatval(), 1.0, 1.0); + testD(floatval(), 1.0, 1.0); + testA(floatval(), 1, 1); + testB(floatval(), 1, 1); + testC(floatval(), 1, 1); + testD(floatval(), 1, 1); + printf("\n"); + } + + { + vector v0 = testD(intval(), 1.0, 1.0); + vector v1 = testD(intval(), 1, 1); + vector v2 = testD(floatval(), 1, 1); + vector v3 = testD(vector(0), 1, 1); + point p0 = testD(intval(), 1.0, 1.0); + point p1 = testD(intval(), 1, 1); + point p2 = testD(floatval(), 1, 1); + point p3 = testD(point(0), 1, 1); + color c0 = testD(intval(), 1.0, 1.0); + color c1 = testD(intval(), 1, 1); + color c2 = testD(floatval(), 1, 1); + color c3 = testD(color(0), 1, 1); + normal n0 = testD(intval(), 1.0, 1.0); + normal n1 = testD(intval(), 1, 1); + normal n2 = testD(floatval(), 1, 1); + normal n3 = testD(normal(0), 1, 1); + int i0 = testD(intval(), 1.0, 1.0); + int i1 = testD(intval(), 1, 1); + int i2 = (int) testD(floatval(), 1, 1); + float f0 = testD(intval(), 1.0, 1.0); + float f1 = testD(intval(), 1, 1); + float f2 = testD(floatval(), 1, 1); + printf("\n"); + } + { + testE(1.0, 1, color(0)); + testE(1, 1.0, vector(0)); + printf("\n"); + } +}