From 8ea3608e3f15e70149318d0806ceff4586c56a43 Mon Sep 17 00:00:00 2001 From: robertsasu Date: Tue, 6 Feb 2024 14:58:11 +0200 Subject: [PATCH 01/10] start with refactoring some ugly if's. --- vmhost/contexts/asyncComposability.go | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/vmhost/contexts/asyncComposability.go b/vmhost/contexts/asyncComposability.go index 495aa30ff..6ec53922f 100644 --- a/vmhost/contexts/asyncComposability.go +++ b/vmhost/contexts/asyncComposability.go @@ -63,31 +63,34 @@ func (context *asyncContext) complete() error { return nil } + gasToAccumulate := context.gasAccumulated + notifyChildComplete := true currentCallID := context.GetCallID() - if context.callType == vm.AsynchronousCall { + switch context.callType { + case vm.AsynchronousCall: vmOutput := context.childResults - isCallbackComplete, _, err := context.callCallback(currentCallID, vmOutput, nil) + notifyChildComplete, _, err = context.callCallback(currentCallID, vmOutput, nil) if err != nil { return err } - if isCallbackComplete { - return context.NotifyChildIsComplete(currentCallID, 0) - } - } else if context.callType == vm.AsynchronousCallBack { + gasToAccumulate = 0 + case vm.AsynchronousCallBack: err = context.LoadParentContext() if err != nil { return err } - currentCallID := context.GetCallerCallID() - return context.NotifyChildIsComplete(currentCallID, context.gasAccumulated) - } else if context.callType == vm.DirectCall { + currentCallID = context.GetCallerCallID() + case vm.DirectCall: err = context.LoadParentContext() if err != nil { return err } + currentCallID = nil + } - return context.NotifyChildIsComplete(nil, context.gasAccumulated) + if notifyChildComplete { + return context.NotifyChildIsComplete(currentCallID, gasToAccumulate) } return nil From 33f21980d2aa0056f12952c3dfbef2c08bb8322c Mon Sep 17 00:00:00 2001 From: robertsasu Date: Wed, 7 Feb 2024 15:56:19 +0200 Subject: [PATCH 02/10] resolving todo's --- scenario/gasSchedules/gasSchedules.go | 1 - test/contracts/erc20/erc20.c | 1 - vmhost/contexts/async.go | 4 +- vmhost/contexts/asyncLocal.go | 37 +++---- vmhost/contexts/asyncParams.go | 136 ++------------------------ vmhost/hostCore/execution.go | 2 - vmhost/interface.go | 2 +- 7 files changed, 24 insertions(+), 159 deletions(-) diff --git a/scenario/gasSchedules/gasSchedules.go b/scenario/gasSchedules/gasSchedules.go index 807248821..178de9023 100644 --- a/scenario/gasSchedules/gasSchedules.go +++ b/scenario/gasSchedules/gasSchedules.go @@ -1,6 +1,5 @@ package gasschedules -// TODO: go:embed can be used after we upgrade to go 1.16 // import _ "embed" // //go:embed gasScheduleV1.toml diff --git a/test/contracts/erc20/erc20.c b/test/contracts/erc20/erc20.c index 5e3aab34c..0f49c4de6 100644 --- a/test/contracts/erc20/erc20.c +++ b/test/contracts/erc20/erc20.c @@ -44,7 +44,6 @@ void computeAllowanceKey(byte *destination, byte *from, byte* to) { // Note: in smart contract addresses, the first 10 bytes are all 0 // therefore we read from byte 10 onwards to provide more significant bytes // and to minimize the chance for collisions - // TODO: switching to a hash instead of a concatenation of addresses might make it safer for (int i = 0; i < 15; i++) { destination[1+i] = from[10+i]; } diff --git a/vmhost/contexts/async.go b/vmhost/contexts/async.go index b90e87a0c..f66f86c1a 100644 --- a/vmhost/contexts/async.go +++ b/vmhost/contexts/async.go @@ -183,7 +183,7 @@ func (context *asyncContext) PushState() { callbackData: context.callbackData, gasAccumulated: context.gasAccumulated, returnData: context.returnData, - asyncCallGroups: context.asyncCallGroups, // TODO matei-p use cloneCallGroups()? + asyncCallGroups: context.cloneCallGroups(), callType: context.callType, callbackAsyncInitiatorCallID: context.callbackAsyncInitiatorCallID, @@ -864,7 +864,7 @@ func (context *asyncContext) callCallback(callID []byte, vmOutput *vmcommon.VMOu } context.host.Metering().DisableRestoreGas() - isComplete, callbackVMOutput := loadedContext.ExecuteSyncCallbackAndFinishOutput(asyncCall, vmOutput, nil, gasAccumulated, err) + isComplete, callbackVMOutput := loadedContext.ExecuteLocalCallbackAndFinishOutput(asyncCall, vmOutput, nil, gasAccumulated, err) context.host.Metering().EnableRestoreGas() return isComplete, callbackVMOutput, nil } diff --git a/vmhost/contexts/asyncLocal.go b/vmhost/contexts/asyncLocal.go index 9b59bf6fe..491577378 100644 --- a/vmhost/contexts/asyncLocal.go +++ b/vmhost/contexts/asyncLocal.go @@ -35,7 +35,6 @@ func (context *asyncContext) executeAsyncLocalCalls() error { return nil } -// TODO split this method into smaller ones func (context *asyncContext) executeAsyncLocalCall(asyncCall *vmhost.AsyncCall) error { destinationCallInput, err := context.createContractCallInput(asyncCall) if err != nil { @@ -79,10 +78,11 @@ func (context *asyncContext) executeAsyncLocalCall(asyncCall *vmhost.AsyncCall) asyncCall.UpdateStatus(vmOutput.ReturnCode) if isComplete { + callbackGasRemaining := uint64(0) if asyncCall.HasCallback() { // Restore gas locked while still on the caller instance; otherwise, the // locked gas will appear to have been used twice by the caller instance. - isCallbackComplete, callbackVMOutput := context.ExecuteSyncCallbackAndFinishOutput(asyncCall, vmOutput, destinationCallInput, 0, err) + isCallbackComplete, callbackVMOutput := context.ExecuteLocalCallbackAndFinishOutput(asyncCall, vmOutput, destinationCallInput, 0, err) if callbackVMOutput == nil { return vmhost.ErrAsyncNoOutputFromCallback } @@ -90,33 +90,30 @@ func (context *asyncContext) executeAsyncLocalCall(asyncCall *vmhost.AsyncCall) context.host.CompleteLogEntriesWithCallType(callbackVMOutput, vmhost.AsyncCallbackString) if isCallbackComplete { - callbackGasRemaining := callbackVMOutput.GasRemaining + callbackGasRemaining = callbackVMOutput.GasRemaining callbackVMOutput.GasRemaining = 0 - return context.completeChild(asyncCall.CallID, callbackGasRemaining) } - } else { - return context.completeChild(asyncCall.CallID, 0) } + + return context.completeChild(asyncCall.CallID, callbackGasRemaining) } return nil } -// ExecuteSyncCallbackAndFinishOutput executes the callback and finishes the output -// TODO rename to executeLocalCallbackAndFinishOutput -func (context *asyncContext) ExecuteSyncCallbackAndFinishOutput( +// ExecuteLocalCallbackAndFinishOutput executes the callback and finishes the output +func (context *asyncContext) ExecuteLocalCallbackAndFinishOutput( asyncCall *vmhost.AsyncCall, vmOutput *vmcommon.VMOutput, _ *vmcommon.ContractCallInput, gasAccumulated uint64, err error) (bool, *vmcommon.VMOutput) { - callbackVMOutput, isComplete, _ := context.executeSyncCallback(asyncCall, vmOutput, gasAccumulated, err) + callbackVMOutput, isComplete, _ := context.executeLocalCallback(asyncCall, vmOutput, gasAccumulated, err) context.finishAsyncLocalCallbackExecution() return isComplete, callbackVMOutput } -// TODO rename to executeLocalCallback -func (context *asyncContext) executeSyncCallback( +func (context *asyncContext) executeLocalCallback( asyncCall *vmhost.AsyncCall, destinationVMOutput *vmcommon.VMOutput, gasAccumulated uint64, @@ -124,11 +121,11 @@ func (context *asyncContext) executeSyncCallback( ) (*vmcommon.VMOutput, bool, error) { callbackInput, err := context.createCallbackInput(asyncCall, destinationVMOutput, gasAccumulated, destinationErr) if err != nil { - logAsync.Trace("executeSyncCallback", "error", err) + logAsync.Trace("executeLocalCallback", "error", err) return nil, true, err } - logAsync.Trace("executeSyncCallback", + logAsync.Trace("executeLocalCallback", "caller", callbackInput.CallerAddr, "dest", callbackInput.RecipientAddr, "func", callbackInput.Function, @@ -183,7 +180,7 @@ func (context *asyncContext) executeSyncHalfOfBuiltinFunction(asyncCall *vmhost. if vmOutput.ReturnCode != vmcommon.Ok { asyncCall.Reject() if asyncCall.HasCallback() { - _, _, _ = context.executeSyncCallback(asyncCall, vmOutput, 0, err) + _, _, _ = context.executeLocalCallback(asyncCall, vmOutput, 0, err) context.finishAsyncLocalCallbackExecution() } } @@ -240,7 +237,6 @@ func (context *asyncContext) createContractCallInput(asyncCall *vmhost.AsyncCall return contractCallInput, nil } -// TODO function too large; refactor needed func (context *asyncContext) createCallbackInput( asyncCall *vmhost.AsyncCall, vmOutput *vmcommon.VMOutput, @@ -255,14 +251,12 @@ func (context *asyncContext) createCallbackInput( } arguments := context.getArgumentsForCallback(vmOutput, destinationErr) - returnWithError := false if destinationErr != nil || vmOutput.ReturnCode != vmcommon.Ok { returnWithError = true } callbackFunction := asyncCall.GetCallbackName() - dataLength := computeDataLengthFromArguments(callbackFunction, arguments) gasLimit, err := context.computeGasLimitForCallback(asyncCall, vmOutput, dataLength) if err != nil { @@ -270,9 +264,8 @@ func (context *asyncContext) createCallbackInput( } originalCaller := runtime.GetOriginalCallerAddress() - caller := context.address - lastTransferInfo := context.extractLastTransferWithoutData(caller, vmOutput) + lastTransferData := context.extractLastTransferWithoutData(caller, vmOutput) // Return to the sender SC, calling its specified callback method. contractCallInput := &vmcommon.ContractCallInput{ @@ -280,7 +273,7 @@ func (context *asyncContext) createCallbackInput( OriginalCallerAddr: originalCaller, CallerAddr: actualCallbackInitiator, Arguments: arguments, - CallValue: lastTransferInfo.callValue, + CallValue: lastTransferData.callValue, CallType: vm.AsynchronousCallBack, GasPrice: runtime.GetVMInput().GasPrice, GasProvided: gasLimit, @@ -289,7 +282,7 @@ func (context *asyncContext) createCallbackInput( OriginalTxHash: runtime.GetOriginalTxHash(), PrevTxHash: runtime.GetPrevTxHash(), ReturnCallAfterError: returnWithError, - ESDTTransfers: lastTransferInfo.lastESDTTransfers, + ESDTTransfers: lastTransferData.lastESDTTransfers, }, RecipientAddr: caller, Function: callbackFunction, diff --git a/vmhost/contexts/asyncParams.go b/vmhost/contexts/asyncParams.go index 15ec142e1..349c8fb20 100644 --- a/vmhost/contexts/asyncParams.go +++ b/vmhost/contexts/asyncParams.go @@ -8,31 +8,23 @@ import ( vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/multiversx/mx-chain-vm-common-go/txDataBuilder" "github.com/multiversx/mx-chain-vm-go/crypto" - "github.com/multiversx/mx-chain-vm-go/vmhost" ) -/* - Called to process OutputTransfers created by a - direct call (on dest) builtin function call by the VM -*/ +// AddAsyncArgumentsToOutputTransfers +// Called to process OutputTransfers created by a +// direct call (on dest) builtin function call by the VM func AddAsyncArgumentsToOutputTransfers( - output vmhost.OutputContext, - address []byte, asyncParams *vmcommon.AsyncArguments, callType vm.CallType, - vmOutput *vmcommon.VMOutput) error { + vmOutput *vmcommon.VMOutput, +) error { if asyncParams == nil { return nil } + for _, outAcc := range vmOutput.OutputAccounts { - // if !bytes.Equal(address, outAcc.Address) { - // continue - // } for t, outTransfer := range outAcc.OutputTransfers { - // if !bytes.Equal(address, outTransfer.SenderAddress) { - // continue - // } if outTransfer.CallType != callType { continue } @@ -84,122 +76,6 @@ func createDataFromAsyncParams( return callData.ToBytes(), nil } -/* - Called when a SCR for a callback is created outside the VM - (by createAsyncCallBackSCRFromVMOutput()) - This is the case - A) after an async call executed following a builtin function call, - B) other cases where processing the output trasnfers of a VMOutput did - not produce a SCR of type AsynchronousCallBack - TODO(check): function not used? -*/ -func AppendAsyncArgumentsToCallbackCallData( - hasher crypto.Hasher, - data []byte, - asyncArguments *vmcommon.AsyncArguments, - parseArgumentsFunc func(data string) ([][]byte, error)) ([]byte, error) { - - return appendAsyncParamsToCallData( - CreateCallbackAsyncParams(hasher, asyncArguments), - data, - false, - parseArgumentsFunc) -} - -/* - Called when a SCR is created from VMOutput in order to recompose - async data and call data into a transfer data ready for the SCR - (by preprocessOutTransferToSCR()) - TODO(check): function not used? -*/ -func AppendTransferAsyncDataToCallData( - callData []byte, - asyncData []byte, - parseArgumentsFunc func(data string) ([][]byte, error)) ([]byte, error) { - - var asyncParams [][]byte - if asyncData != nil { - asyncParams, _ = parseArgumentsFunc(string(asyncData)) - // string start with a @ so first parsed argument will be empty always - asyncParams = asyncParams[1:] - } else { - return callData, nil - } - - return appendAsyncParamsToCallData( - asyncParams, - callData, - true, - parseArgumentsFunc) -} - -func appendAsyncParamsToCallData( - asyncParams [][]byte, - data []byte, - hasFunction bool, - parseArgumentsFunc func(data string) ([][]byte, error)) ([]byte, error) { - - if data == nil { - return nil, nil - } - - args, err := parseArgumentsFunc(string(data)) - if err != nil { - return nil, err - } - - var functionName string - if hasFunction { - functionName = string(args[0]) - } - - // check if there is only one argument and that is 0 - if len(args) != 0 { - args = args[1:] - } - - callData := txDataBuilder.NewBuilder() - - if functionName != "" { - callData.Func(functionName) - } - - if len(args) != 0 { - for _, arg := range args { - callData.Bytes(arg) - } - } else { - if !hasFunction { - callData.Bytes([]byte{}) - } - } - - for _, asyncParam := range asyncParams { - callData.Bytes(asyncParam) - } - - return callData.ToBytes(), nil -} - -/* - Used by when a callback SCR is created - 1) after a failure of an async call - Async data is extracted (by extractAsyncCallParamsFromTxData()) and then - reappended to the new SCR's callback data (by reapendAsyncParamsToTxData()) - 2) from the last transfer (see useLastTransferAsAsyncCallBackWhenNeeded()) -*/ -func CreateCallbackAsyncParams(hasher crypto.Hasher, asyncParams *vmcommon.AsyncArguments) [][]byte { - if asyncParams == nil { - return nil - } - newAsyncParams := make([][]byte, 4) - newAsyncParams[0] = GenerateNewCallID(hasher, asyncParams.CallID, []byte{0}) - newAsyncParams[1] = asyncParams.CallID - newAsyncParams[2] = asyncParams.CallerCallID - newAsyncParams[3] = []byte{0} - return newAsyncParams -} - // GenerateNewCallID will generate a new call ID as byte slice func GenerateNewCallID(hasher crypto.Hasher, parentCallID []byte, suffix []byte) []byte { newCallID := append(parentCallID, suffix...) diff --git a/vmhost/hostCore/execution.go b/vmhost/hostCore/execution.go index 1eb04379e..f5d07bd69 100644 --- a/vmhost/hostCore/execution.go +++ b/vmhost/hostCore/execution.go @@ -429,8 +429,6 @@ func (host *vmHost) handleBuiltinFunctionCall(input *vmcommon.ContractCallInput) } err = contexts.AddAsyncArgumentsToOutputTransfers( - host.Output(), - input.RecipientAddr, input.AsyncArguments, vm.AsynchronousCall, builtinOutput) diff --git a/vmhost/interface.go b/vmhost/interface.go index 4348f2c65..e210e4434 100644 --- a/vmhost/interface.go +++ b/vmhost/interface.go @@ -399,7 +399,7 @@ type AsyncContext interface { GetAsyncCallByCallID(callID []byte) AsyncCallLocation LoadParentContextFromStackOrStorage() (AsyncContext, error) - ExecuteSyncCallbackAndFinishOutput( + ExecuteLocalCallbackAndFinishOutput( asyncCall *AsyncCall, vmOutput *vmcommon.VMOutput, destinationCallInput *vmcommon.ContractCallInput, From 76ddb2c7b60fdb518775b9f4c57d6b628e3f8b28 Mon Sep 17 00:00:00 2001 From: Laurentiu Ciobanu Date: Fri, 5 Apr 2024 11:29:47 +0300 Subject: [PATCH 03/10] refactor output in case of error for async callback and empty function name check --- vmhost/contexts/output.go | 38 +++++++++++++- vmhost/hostCore/execution.go | 96 +++++++++++++++++++----------------- 2 files changed, 89 insertions(+), 45 deletions(-) diff --git a/vmhost/contexts/output.go b/vmhost/contexts/output.go index 1c7e93207..82d6ab400 100644 --- a/vmhost/contexts/output.go +++ b/vmhost/contexts/output.go @@ -562,14 +562,50 @@ func (context *outputContext) DeployCode(input vmhost.CodeDeployInput) { context.codeUpdates[string(input.ContractAddress)] = empty } +// createVMOutputInCaseOfErrorOfAsyncCallback appends the deletion of the async context to the output +func (context *outputContext) createVMOutputInCaseOfErrorOfAsyncCallback(err error, returnCode vmcommon.ReturnCode, returnMessage string) *vmcommon.VMOutput { + async := context.host.Async() + metering := context.host.Metering() + + callId := async.GetCallbackAsyncInitiatorCallID() + + context.PushState() + + context.outputState = &vmcommon.VMOutput{ + GasRemaining: 0, + GasRefund: big.NewInt(0), + ReturnCode: returnCode, + ReturnMessage: returnMessage, + } + + err = async.DeleteFromCallID(callId) + logOutput.Trace("failed to delete Async Context", "callId", callId, "err", err) + + vmOutput := context.GetVMOutput() + context.PopSetActiveState() + + metering.UpdateGasStateOnFailure(vmOutput) + + return vmOutput +} + // CreateVMOutputInCaseOfError creates a new vmOutput with the given error set as return message. func (context *outputContext) CreateVMOutputInCaseOfError(err error) *vmcommon.VMOutput { runtime := context.host.Runtime() + async := context.host.Async() + metering := context.host.Metering() + + callType := runtime.GetVMInput().CallType + runtime.AddError(err, runtime.FunctionName()) returnCode := context.resolveReturnCodeFromError(err) returnMessage := context.resolveReturnMessageFromError(err) + if callType == vm.AsynchronousCallBack && async.HasCallback() { + return context.createVMOutputInCaseOfErrorOfAsyncCallback(err, returnCode, returnMessage) + } + vmOutput := &vmcommon.VMOutput{ GasRemaining: 0, GasRefund: big.NewInt(0), @@ -577,7 +613,7 @@ func (context *outputContext) CreateVMOutputInCaseOfError(err error) *vmcommon.V ReturnMessage: returnMessage, } - context.host.Metering().UpdateGasStateOnFailure(vmOutput) + metering.UpdateGasStateOnFailure(vmOutput) return vmOutput } diff --git a/vmhost/hostCore/execution.go b/vmhost/hostCore/execution.go index d2eff2b70..0b87a401c 100644 --- a/vmhost/hostCore/execution.go +++ b/vmhost/hostCore/execution.go @@ -1145,6 +1145,14 @@ func (host *vmHost) checkFinalGasAfterExit() error { return nil } +func (host *vmHost) checkValidFunctionName(name string) error { + if name == "" { + return executor.ErrInvalidFunction + } + + return nil +} + func (host *vmHost) callInitFunction() error { return host.callSCFunction(vmhost.InitFunctionName) } @@ -1154,12 +1162,18 @@ func (host *vmHost) callUpgradeFunction() error { } func (host *vmHost) callSCFunction(functionName string) error { + err := host.checkValidFunctionName(functionName) + if err != nil { + log.Trace("call SC method failed", "error", err, "src", "checkValidFunctionName") + return err + } + runtime := host.Runtime() if !runtime.HasFunction(functionName) { return executor.ErrFuncNotFound } - err := runtime.CallSCFunction(functionName) + err = runtime.CallSCFunction(functionName) if err != nil { err = host.handleBreakpointIfAny(err) } @@ -1236,12 +1250,6 @@ func (host *vmHost) callSCMethodAsynchronousCallBack() error { metering.UseGas(metering.GasLeft()) } - // TODO matei-p R2 Returning an error here will cause the VMOutput to be - // empty (due to CreateVMOutputInCaseOfError()). But in release 2 of - // Promises, CreateVMOutputInCaseOfError() should still contain storage - // deletions caused by AsyncContext cleanup, even if callbackErr != nil and - // was returned here. The storage deletions MUST be persisted in the data - // trie once R2 goes live. if !isCallComplete { return callbackErr } @@ -1263,47 +1271,47 @@ func (host *vmHost) callFunctionAndExecuteAsync() (bool, error) { runtime := host.Runtime() async := host.Async() - // TODO refactor this, and apply this condition in other places where a - // function is called - if runtime.FunctionName() != "" { - err := host.verifyAllowedFunctionCall() - if err != nil { - log.Trace("call SC method failed", "error", err, "src", "verifyAllowedFunctionCall") - return false, err - } + err := host.checkValidFunctionName(runtime.FunctionName()) + if err != nil { + log.Trace("call SC method failed", "error", err, "src", "checkValidFunctionName") + return false, err + } - functionName, err := runtime.FunctionNameChecked() - if err != nil { - log.Trace("call SC method failed", "error", err, "src", "FunctionNameChecked") - return false, err - } + err = host.verifyAllowedFunctionCall() + if err != nil { + log.Trace("call SC method failed", "error", err, "src", "verifyAllowedFunctionCall") + return false, err + } - err = runtime.CallSCFunction(functionName) - if err != nil { - err = host.handleBreakpointIfAny(err) - log.Trace("breakpoint detected and handled", "err", err) - } - if err == nil { - err = host.checkFinalGasAfterExit() - } - if err != nil { - log.Trace("call SC method failed", "error", err, "src", "sc function") - return true, err - } + functionName, err := runtime.FunctionNameChecked() + if err != nil { + log.Trace("call SC method failed", "error", err, "src", "FunctionNameChecked") + return false, err + } - err = async.Execute() - if err != nil { - log.Trace("call SC method failed", "error", err, "src", "async execution") - return false, err - } + err = runtime.CallSCFunction(functionName) + if err != nil { + err = host.handleBreakpointIfAny(err) + log.Trace("breakpoint detected and handled", "err", err) + } + if err == nil { + err = host.checkFinalGasAfterExit() + } + if err != nil { + log.Trace("call SC method failed", "error", err, "src", "sc function") + return true, err + } - if !async.IsComplete() || async.HasLegacyGroup() { - async.SetResults(host.Output().GetVMOutput()) - err = async.Save() - return false, err - } - } else { - return false, executor.ErrInvalidFunction + err = async.Execute() + if err != nil { + log.Trace("call SC method failed", "error", err, "src", "async execution") + return false, err + } + + if !async.IsComplete() || async.HasLegacyGroup() { + async.SetResults(host.Output().GetVMOutput()) + err = async.Save() + return false, err } return true, nil From c9ba5b7e8740b47da9483f9d2d1c7e97d2b9c724 Mon Sep 17 00:00:00 2001 From: Laurentiu Ciobanu Date: Fri, 5 Apr 2024 12:39:22 +0300 Subject: [PATCH 04/10] refactor output in case of error for async callback --- vmhost/contexts/output.go | 18 ++++++++++-------- vmhost/hosttest/execution_test.go | 2 ++ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/vmhost/contexts/output.go b/vmhost/contexts/output.go index 82d6ab400..480d1a72a 100644 --- a/vmhost/contexts/output.go +++ b/vmhost/contexts/output.go @@ -572,16 +572,19 @@ func (context *outputContext) createVMOutputInCaseOfErrorOfAsyncCallback(err err context.PushState() context.outputState = &vmcommon.VMOutput{ - GasRemaining: 0, - GasRefund: big.NewInt(0), - ReturnCode: returnCode, - ReturnMessage: returnMessage, + GasRemaining: 0, + GasRefund: big.NewInt(0), + ReturnCode: returnCode, + ReturnMessage: returnMessage, + OutputAccounts: make(map[string]*vmcommon.OutputAccount), } err = async.DeleteFromCallID(callId) - logOutput.Trace("failed to delete Async Context", "callId", callId, "err", err) + if err != nil { + logOutput.Trace("failed to delete Async Context", "callId", callId, "err", err) + } - vmOutput := context.GetVMOutput() + vmOutput := context.outputState // GetVMOutput updates metering context.PopSetActiveState() metering.UpdateGasStateOnFailure(vmOutput) @@ -592,7 +595,6 @@ func (context *outputContext) createVMOutputInCaseOfErrorOfAsyncCallback(err err // CreateVMOutputInCaseOfError creates a new vmOutput with the given error set as return message. func (context *outputContext) CreateVMOutputInCaseOfError(err error) *vmcommon.VMOutput { runtime := context.host.Runtime() - async := context.host.Async() metering := context.host.Metering() callType := runtime.GetVMInput().CallType @@ -602,7 +604,7 @@ func (context *outputContext) CreateVMOutputInCaseOfError(err error) *vmcommon.V returnCode := context.resolveReturnCodeFromError(err) returnMessage := context.resolveReturnMessageFromError(err) - if callType == vm.AsynchronousCallBack && async.HasCallback() { + if callType == vm.AsynchronousCallBack { return context.createVMOutputInCaseOfErrorOfAsyncCallback(err, returnCode, returnMessage) } diff --git a/vmhost/hosttest/execution_test.go b/vmhost/hosttest/execution_test.go index f2edcfdd0..018baaf2e 100644 --- a/vmhost/hosttest/execution_test.go +++ b/vmhost/hosttest/execution_test.go @@ -2905,6 +2905,8 @@ func TestExecution_AsyncCall_CallBackFails(t *testing.T) { } func TestExecution_AsyncCall_Promises_CallBackFails(t *testing.T) { + _ = logger.SetLogLevel("*:TRACE") + // same scenario as in TestExecution_AsyncCall_CallBackFails txHash := []byte("txhash..........................") test.BuildInstanceCallTest(t). From f2a60049be6c8ad43261be42525cf29690c667bc Mon Sep 17 00:00:00 2001 From: Laurentiu Ciobanu Date: Fri, 5 Apr 2024 12:39:52 +0300 Subject: [PATCH 05/10] refactor output in case of error for async callback --- vmhost/hosttest/execution_test.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/vmhost/hosttest/execution_test.go b/vmhost/hosttest/execution_test.go index 018baaf2e..f2edcfdd0 100644 --- a/vmhost/hosttest/execution_test.go +++ b/vmhost/hosttest/execution_test.go @@ -2905,8 +2905,6 @@ func TestExecution_AsyncCall_CallBackFails(t *testing.T) { } func TestExecution_AsyncCall_Promises_CallBackFails(t *testing.T) { - _ = logger.SetLogLevel("*:TRACE") - // same scenario as in TestExecution_AsyncCall_CallBackFails txHash := []byte("txhash..........................") test.BuildInstanceCallTest(t). From 3a2719cf41476cbb66e3747d902014849c0288de Mon Sep 17 00:00:00 2001 From: Laurentiu Ciobanu Date: Thu, 11 Apr 2024 18:30:16 +0300 Subject: [PATCH 06/10] Add AsyncV3 flag --- vmhost/flags.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vmhost/flags.go b/vmhost/flags.go index ad1a4cbfe..de4d175b2 100644 --- a/vmhost/flags.go +++ b/vmhost/flags.go @@ -3,6 +3,8 @@ package vmhost import "github.com/multiversx/mx-chain-core-go/core" const ( + // AsyncV3Flag defines the flag that activates async v3 + AsyncV3Flag core.EnableEpochFlag = "AsyncV3Flag" // MultiESDTTransferFixOnCallBackFlag defines the flag that activates the multi esdt transfer fix on callback MultiESDTTransferFixOnCallBackFlag core.EnableEpochFlag = "MultiESDTTransferFixOnCallBackFlag" // RemoveNonUpdatedStorageFlag defines the flag that activates the remove non updated storage fix From 6adf73689e1265630b3f8cab0ee8062f211e7df5 Mon Sep 17 00:00:00 2001 From: Laurentiu Ciobanu Date: Thu, 11 Apr 2024 18:30:39 +0300 Subject: [PATCH 07/10] Add test for output in case of error of async callback --- vmhost/contexts/async_test.go | 74 +++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/vmhost/contexts/async_test.go b/vmhost/contexts/async_test.go index 0093c7ca1..12144a237 100644 --- a/vmhost/contexts/async_test.go +++ b/vmhost/contexts/async_test.go @@ -2,6 +2,7 @@ package contexts import ( "errors" + "github.com/multiversx/mx-chain-core-go/core" "math/big" "testing" @@ -496,6 +497,79 @@ func TestAsyncContext_UpdateCurrentCallStatus(t *testing.T) { require.Equal(t, vmhost.AsyncCallRejected, asyncCall.Status) } +func TestAsyncContext_OutputInCaseOfErrorInCallback(t *testing.T) { + user := []byte("user") + contractA := []byte("contractA") + contractB := []byte("contractB") + + host, _ := initializeVMAndWasmerAsyncContext(t) + host.EnableEpochsHandlerField = &worldmock.EnableEpochsHandlerStub{ + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return flag == vmhost.AsyncV3Flag + }, + } + + async := makeAsyncContext(t, host, contractA) + host.Storage().SetAddress(contractA) + host.AsyncContext = async + + vmInput := &vmcommon.ContractCallInput{ + VMInput: vmcommon.VMInput{ + CallerAddr: user, + Arguments: [][]byte{{0}}, + CallType: vm.DirectCall, + }, + RecipientAddr: contractA, + } + host.Runtime().InitStateFromContractCallInput(vmInput) + + err := async.RegisterAsyncCall("", &vmhost.AsyncCall{ + Destination: contractB, + Data: []byte("function"), + }) + require.Nil(t, err) + + err = async.Save() + require.Nil(t, err) + + asyncCallId := async.GetCallID() + asyncStoragePrefix := host.Storage().GetVmProtectedPrefix(vmhost.AsyncDataPrefix) + asyncCallKey := vmhost.CustomStorageKey(string(asyncStoragePrefix), asyncCallId) + + data, _, _, _ := host.Storage().GetStorageUnmetered(asyncCallKey) + require.NotEqual(t, len(data), 0) + + vmInput = &vmcommon.ContractCallInput{ + VMInput: vmcommon.VMInput{ + CallerAddr: contractB, + Arguments: [][]byte{{0}}, + CallType: vm.AsynchronousCallBack, + }, + RecipientAddr: contractA, + } + host.Runtime().InitStateFromContractCallInput(vmInput) + + async.callbackAsyncInitiatorCallID = asyncCallId + async.callType = vmInput.CallType + err = async.LoadParentContext() + require.Nil(t, err) + + vmOutput := host.Output().CreateVMOutputInCaseOfError(vmhost.ErrNotEnoughGas) + outputAccount := vmOutput.OutputAccounts[string(contractA)] + + require.NotNil(t, outputAccount) + + storageUpdates := outputAccount.StorageUpdates + require.Equal(t, len(storageUpdates), 1) + + asyncContextDeletionUpdate := storageUpdates[string(asyncCallKey)] + require.NotNil(t, asyncContextDeletionUpdate) + require.Equal(t, len(asyncContextDeletionUpdate.Data), 0) + + data, _, _, _ = host.Storage().GetStorageUnmetered(asyncCallKey) + require.Equal(t, len(data), 0) +} + func TestAsyncContext_SendAsyncCallCrossShard(t *testing.T) { host, world := initializeVMAndWasmerAsyncContext(t) world.AcctMap.PutAccount(&worldmock.Account{ From a0e8e8472e699d572507f4b19533c69742940d57 Mon Sep 17 00:00:00 2001 From: Laurentiu Ciobanu Date: Thu, 11 Apr 2024 18:31:49 +0300 Subject: [PATCH 08/10] Remove redundant output stack operations and guard deleteion by AsyncV3 flag --- vmhost/contexts/output.go | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/vmhost/contexts/output.go b/vmhost/contexts/output.go index 480d1a72a..f84e5f7e3 100644 --- a/vmhost/contexts/output.go +++ b/vmhost/contexts/output.go @@ -569,8 +569,6 @@ func (context *outputContext) createVMOutputInCaseOfErrorOfAsyncCallback(err err callId := async.GetCallbackAsyncInitiatorCallID() - context.PushState() - context.outputState = &vmcommon.VMOutput{ GasRemaining: 0, GasRefund: big.NewInt(0), @@ -584,12 +582,9 @@ func (context *outputContext) createVMOutputInCaseOfErrorOfAsyncCallback(err err logOutput.Trace("failed to delete Async Context", "callId", callId, "err", err) } - vmOutput := context.outputState // GetVMOutput updates metering - context.PopSetActiveState() - - metering.UpdateGasStateOnFailure(vmOutput) + metering.UpdateGasStateOnFailure(context.outputState) - return vmOutput + return context.outputState } // CreateVMOutputInCaseOfError creates a new vmOutput with the given error set as return message. @@ -604,7 +599,7 @@ func (context *outputContext) CreateVMOutputInCaseOfError(err error) *vmcommon.V returnCode := context.resolveReturnCodeFromError(err) returnMessage := context.resolveReturnMessageFromError(err) - if callType == vm.AsynchronousCallBack { + if context.host.EnableEpochsHandler().IsFlagEnabled(vmhost.AsyncV3Flag) && callType == vm.AsynchronousCallBack { return context.createVMOutputInCaseOfErrorOfAsyncCallback(err, returnCode, returnMessage) } From bbee461aa02e2f6c03cb656e9e07801b59cc6771 Mon Sep 17 00:00:00 2001 From: Laurentiu Ciobanu Date: Thu, 11 Apr 2024 19:16:11 +0300 Subject: [PATCH 09/10] A bit of cleanup --- vmhost/contexts/output.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vmhost/contexts/output.go b/vmhost/contexts/output.go index f84e5f7e3..79ef6999c 100644 --- a/vmhost/contexts/output.go +++ b/vmhost/contexts/output.go @@ -563,7 +563,7 @@ func (context *outputContext) DeployCode(input vmhost.CodeDeployInput) { } // createVMOutputInCaseOfErrorOfAsyncCallback appends the deletion of the async context to the output -func (context *outputContext) createVMOutputInCaseOfErrorOfAsyncCallback(err error, returnCode vmcommon.ReturnCode, returnMessage string) *vmcommon.VMOutput { +func (context *outputContext) createVMOutputInCaseOfErrorOfAsyncCallback(returnCode vmcommon.ReturnCode, returnMessage string) *vmcommon.VMOutput { async := context.host.Async() metering := context.host.Metering() @@ -577,7 +577,7 @@ func (context *outputContext) createVMOutputInCaseOfErrorOfAsyncCallback(err err OutputAccounts: make(map[string]*vmcommon.OutputAccount), } - err = async.DeleteFromCallID(callId) + err := async.DeleteFromCallID(callId) if err != nil { logOutput.Trace("failed to delete Async Context", "callId", callId, "err", err) } @@ -600,7 +600,7 @@ func (context *outputContext) CreateVMOutputInCaseOfError(err error) *vmcommon.V returnMessage := context.resolveReturnMessageFromError(err) if context.host.EnableEpochsHandler().IsFlagEnabled(vmhost.AsyncV3Flag) && callType == vm.AsynchronousCallBack { - return context.createVMOutputInCaseOfErrorOfAsyncCallback(err, returnCode, returnMessage) + return context.createVMOutputInCaseOfErrorOfAsyncCallback(returnCode, returnMessage) } vmOutput := &vmcommon.VMOutput{ From c4854af3d79cd158ceba91e989d7f23c03f2b42b Mon Sep 17 00:00:00 2001 From: Laurentiu Ciobanu Date: Fri, 12 Apr 2024 12:22:48 +0300 Subject: [PATCH 10/10] AsyncV3Flag swap order --- vmhost/flags.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vmhost/flags.go b/vmhost/flags.go index de4d175b2..15bc2ab73 100644 --- a/vmhost/flags.go +++ b/vmhost/flags.go @@ -3,8 +3,6 @@ package vmhost import "github.com/multiversx/mx-chain-core-go/core" const ( - // AsyncV3Flag defines the flag that activates async v3 - AsyncV3Flag core.EnableEpochFlag = "AsyncV3Flag" // MultiESDTTransferFixOnCallBackFlag defines the flag that activates the multi esdt transfer fix on callback MultiESDTTransferFixOnCallBackFlag core.EnableEpochFlag = "MultiESDTTransferFixOnCallBackFlag" // RemoveNonUpdatedStorageFlag defines the flag that activates the remove non updated storage fix @@ -31,4 +29,6 @@ const ( FixOOGReturnCodeFlag core.EnableEpochFlag = "FixOOGReturnCodeFlag" // DynamicGasCostForDataTrieStorageLoadFlag defines the flag that activates the dynamic gas cost for data trie storage load DynamicGasCostForDataTrieStorageLoadFlag core.EnableEpochFlag = "DynamicGasCostForDataTrieStorageLoadFlag" + // AsyncV3Flag defines the flag that activates async v3 + AsyncV3Flag core.EnableEpochFlag = "AsyncV3Flag" )