11package patcher
22
33import (
4+ "fmt"
5+ "reflect"
6+
47 "github.com/expr-lang/expr/ast"
8+ "github.com/expr-lang/expr/builtin"
59 "github.com/expr-lang/expr/conf"
610)
711
8- type Operator struct {
9- Operators conf.OperatorsTable
10- Types conf.TypesTable
11- Functions conf.FunctionTable
12+ type OperatorOverride struct {
13+ Operator string // Operator token to override.
14+ Overrides []string // List of function names to override operator with.
15+ Types conf.TypesTable // Env types.
16+ Functions conf.FunctionsTable // Env functions.
1217}
1318
14- func (p * Operator ) Visit (node * ast.Node ) {
19+ func (p * OperatorOverride ) Visit (node * ast.Node ) {
1520 binaryNode , ok := (* node ).(* ast.BinaryNode )
1621 if ! ok {
1722 return
1823 }
1924
20- fns , ok := p .Operators [binaryNode .Operator ]
21- if ! ok {
25+ if binaryNode .Operator != p .Operator {
2226 return
2327 }
2428
2529 leftType := binaryNode .Left .Type ()
2630 rightType := binaryNode .Right .Type ()
2731
28- ret , fn , ok := conf .FindSuitableOperatorOverload (fns , p . Types , p . Functions , leftType , rightType )
32+ ret , fn , ok := p .FindSuitableOperatorOverload (leftType , rightType )
2933 if ok {
3034 newNode := & ast.CallNode {
3135 Callee : & ast.IdentifierNode {Value : fn },
@@ -35,3 +39,97 @@ func (p *Operator) Visit(node *ast.Node) {
3539 ast .Patch (node , newNode )
3640 }
3741}
42+
43+ func (p * OperatorOverride ) FindSuitableOperatorOverload (l , r reflect.Type ) (reflect.Type , string , bool ) {
44+ t , fn , ok := p .findSuitableOperatorOverloadInFunctions (l , r )
45+ if ! ok {
46+ t , fn , ok = p .findSuitableOperatorOverloadInTypes (l , r )
47+ }
48+ return t , fn , ok
49+ }
50+
51+ func (p * OperatorOverride ) findSuitableOperatorOverloadInTypes (l , r reflect.Type ) (reflect.Type , string , bool ) {
52+ for _ , fn := range p .Overrides {
53+ fnType , ok := p .Types [fn ]
54+ if ! ok {
55+ continue
56+ }
57+ firstInIndex := 0
58+ if fnType .Method {
59+ firstInIndex = 1 // As first argument to method is receiver.
60+ }
61+ ret , done := checkTypeSuits (fnType .Type , l , r , firstInIndex )
62+ if done {
63+ return ret , fn , true
64+ }
65+ }
66+ return nil , "" , false
67+ }
68+
69+ func (p * OperatorOverride ) findSuitableOperatorOverloadInFunctions (l , r reflect.Type ) (reflect.Type , string , bool ) {
70+ for _ , fn := range p .Overrides {
71+ fnType , ok := p .Functions [fn ]
72+ if ! ok {
73+ continue
74+ }
75+ firstInIndex := 0
76+ for _ , overload := range fnType .Types {
77+ ret , done := checkTypeSuits (overload , l , r , firstInIndex )
78+ if done {
79+ return ret , fn , true
80+ }
81+ }
82+ }
83+ return nil , "" , false
84+ }
85+
86+ func checkTypeSuits (t reflect.Type , l reflect.Type , r reflect.Type , firstInIndex int ) (reflect.Type , bool ) {
87+ firstArgType := t .In (firstInIndex )
88+ secondArgType := t .In (firstInIndex + 1 )
89+
90+ firstArgumentFit := l == firstArgType || (firstArgType .Kind () == reflect .Interface && (l == nil || l .Implements (firstArgType )))
91+ secondArgumentFit := r == secondArgType || (secondArgType .Kind () == reflect .Interface && (r == nil || r .Implements (secondArgType )))
92+ if firstArgumentFit && secondArgumentFit {
93+ return t .Out (0 ), true
94+ }
95+ return nil , false
96+ }
97+
98+ func (p * OperatorOverride ) Check () {
99+ for _ , fn := range p .Overrides {
100+ fnType , foundType := p .Types [fn ]
101+ fnFunc , foundFunc := p .Functions [fn ]
102+ if ! foundFunc && (! foundType || fnType .Type .Kind () != reflect .Func ) {
103+ panic (fmt .Errorf ("function %s for %s operator does not exist in the environment" , fn , p .Operator ))
104+ }
105+
106+ if foundType {
107+ checkType (fnType , fn , p .Operator )
108+ }
109+
110+ if foundFunc {
111+ checkFunc (fnFunc , fn , p .Operator )
112+ }
113+ }
114+ }
115+
116+ func checkType (fnType conf.Tag , fn string , operator string ) {
117+ requiredNumIn := 2
118+ if fnType .Method {
119+ requiredNumIn = 3 // As first argument of method is receiver.
120+ }
121+ if fnType .Type .NumIn () != requiredNumIn || fnType .Type .NumOut () != 1 {
122+ panic (fmt .Errorf ("function %s for %s operator does not have a correct signature" , fn , operator ))
123+ }
124+ }
125+
126+ func checkFunc (fn * builtin.Function , name string , operator string ) {
127+ if len (fn .Types ) == 0 {
128+ panic (fmt .Errorf ("function %s for %s operator misses types" , name , operator ))
129+ }
130+ for _ , t := range fn .Types {
131+ if t .NumIn () != 2 || t .NumOut () != 1 {
132+ panic (fmt .Errorf ("function %s for %s operator does not have a correct signature" , name , operator ))
133+ }
134+ }
135+ }
0 commit comments