Skip to content

Commit baa67dc

Browse files
Op creation logic updated;now ops can be passed as parameters
1 parent 2925e97 commit baa67dc

File tree

1 file changed

+46
-25
lines changed

1 file changed

+46
-25
lines changed

ravop/core/__init__.py

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import numpy as np
44
from ravcom import dump_data, RavQueue, QUEUE_LOW_PRIORITY, QUEUE_HIGH_PRIORITY
5+
from ravcom import globals as g
56
from ravcom import ravdb
67

7-
from ravcom import globals as g
88
from ravop.enums import *
99

1010

@@ -23,6 +23,7 @@ def minus_one():
2323
def inf():
2424
return Scalar(np.inf)
2525

26+
2627
def pi():
2728
return Scalar(np.pi)
2829

@@ -75,14 +76,14 @@ def create(self, operator, inputs=None, outputs=None, **kwargs):
7576
outputs = json.dumps(outputs)
7677

7778
op = ravdb.create_op(name=kwargs.get("name", None),
78-
graph_id=g.graph_id,
79-
node_type=node_type,
80-
inputs=inputs,
81-
outputs=outputs,
82-
op_type=op_type,
83-
operator=operator,
84-
status=status,
85-
params=json.dumps(kwargs))
79+
graph_id=g.graph_id,
80+
node_type=node_type,
81+
inputs=inputs,
82+
outputs=outputs,
83+
op_type=op_type,
84+
operator=operator,
85+
status=status,
86+
params=json.dumps(kwargs))
8687
# Add op to queue
8788
if op.status != OpStatus.COMPUTED.value and op.status != OpStatus.FAILED.value:
8889
if g.graph_id is None:
@@ -332,6 +333,7 @@ def __str__(self):
332333
self._op_db.operator,
333334
self.output,
334335
self.status)
336+
335337
def __call__(self, *args, **kwargs):
336338
return self.output
337339

@@ -818,6 +820,7 @@ def sign(op1, **kwargs):
818820
def foreach(op, **kwargs):
819821
return __create_math_op2(op, Operators.FOREACH.value, **kwargs)
820822

823+
821824
# Data Preprocessing
822825

823826

@@ -829,15 +832,24 @@ def __create_math_op(op1, op2, operator, **kwargs):
829832
if op1 is None or op2 is None:
830833
raise Exception("Null Op")
831834

835+
params = dict()
836+
for key, value in kwargs.items():
837+
if isinstance(value, Op) or isinstance(value, Data) or isinstance(value, Scalar) or isinstance(value, Tensor):
838+
params[key] = value.id
839+
elif type(value).__name__ in ['int', 'float']:
840+
params[key] = Scalar(value)
841+
elif type(value).__name__ == 'str':
842+
params[key] = value
843+
832844
op = ravdb.create_op(name=kwargs.get('name', None),
833-
graph_id=g.graph_id,
834-
node_type=NodeTypes.MIDDLE.value,
835-
inputs=json.dumps([op1.id, op2.id]),
836-
outputs=json.dumps(None),
837-
op_type=OpTypes.BINARY.value,
838-
operator=operator,
839-
status=OpStatus.PENDING.value,
840-
params=json.dumps(kwargs))
845+
graph_id=g.graph_id,
846+
node_type=NodeTypes.MIDDLE.value,
847+
inputs=json.dumps([op1.id, op2.id]),
848+
outputs=json.dumps(None),
849+
op_type=OpTypes.BINARY.value,
850+
operator=operator,
851+
status=OpStatus.PENDING.value,
852+
params=json.dumps(params))
841853

842854
# Add op to queue
843855
if op.status != OpStatus.COMPUTED.value and op.status != OpStatus.FAILED.value:
@@ -855,15 +867,24 @@ def __create_math_op2(op1, operator, **kwargs):
855867
if op1 is None:
856868
raise Exception("Null Op")
857869

870+
params = dict()
871+
for key, value in kwargs.items():
872+
if isinstance(value, Op) or isinstance(value, Data) or isinstance(value, Scalar) or isinstance(value, Tensor):
873+
params[key] = value.id
874+
elif type(value).__name__ in ['int', 'float']:
875+
params[key] = Scalar(value)
876+
elif type(value).__name__ == 'str':
877+
params[key] = value
878+
858879
op = ravdb.create_op(name=kwargs.get('name', None),
859-
graph_id=g.graph_id,
860-
node_type=NodeTypes.MIDDLE.value,
861-
inputs=json.dumps([op1.id]),
862-
outputs=json.dumps(None),
863-
op_type=OpTypes.UNARY.value,
864-
operator=operator,
865-
status=OpStatus.PENDING.value,
866-
params=json.dumps(kwargs))
880+
graph_id=g.graph_id,
881+
node_type=NodeTypes.MIDDLE.value,
882+
inputs=json.dumps([op1.id]),
883+
outputs=json.dumps(None),
884+
op_type=OpTypes.UNARY.value,
885+
operator=operator,
886+
status=OpStatus.PENDING.value,
887+
params=json.dumps(params))
867888

868889
# Add op to queue
869890
if op.status != OpStatus.COMPUTED.value and op.status != OpStatus.FAILED.value:

0 commit comments

Comments
 (0)