diff --git a/vmhost/vmhooks/baseOps.go b/vmhost/vmhooks/baseOps.go index c2e1f2aa8..190ed575b 100644 --- a/vmhost/vmhooks/baseOps.go +++ b/vmhost/vmhooks/baseOps.go @@ -319,40 +319,6 @@ func (context *VMHooksImpl) GetBlockHash(nonce int64, resultOffset executor.MemP return 0 } -func getESDTDataFromBlockchainHook( - context *VMHooksImpl, - addressOffset executor.MemPtr, - tokenIDOffset executor.MemPtr, - tokenIDLen executor.MemLength, - nonce int64, -) (*esdt.ESDigitalToken, error) { - metering := context.GetMeteringContext() - blockchain := context.GetBlockchainContext() - - gasToUse := metering.GasSchedule().BaseOpsAPICost.GetExternalBalance - err := metering.UseGasBounded(gasToUse) - if err != nil { - return nil, err - } - - address, err := context.MemLoad(addressOffset, vmhost.AddressLen) - if err != nil { - return nil, err - } - - tokenID, err := context.MemLoad(tokenIDOffset, tokenIDLen) - if err != nil { - return nil, err - } - - esdtToken, err := blockchain.GetESDTToken(address, tokenID, uint64(nonce)) - if err != nil { - return nil, err - } - - return esdtToken, nil -} - // GetESDTBalance VMHooks implementation. // @autogenerate(VMHooks) func (context *VMHooksImpl) GetESDTBalance( @@ -362,22 +328,16 @@ func (context *VMHooksImpl) GetESDTBalance( nonce int64, resultOffset executor.MemPtr, ) int32 { - metering := context.GetMeteringContext() - metering.StartGasTracing(getESDTBalanceName) - - esdtData, err := getESDTDataFromBlockchainHook(context, addressOffset, tokenIDOffset, tokenIDLen, nonce) - - if err != nil { - context.FailExecution(err) - return -1 - } - err = context.MemStore(resultOffset, esdtData.Value.Bytes()) - if err != nil { - context.FailExecution(err) - return -1 - } + return context.withESDTData(addressOffset, tokenIDOffset, tokenIDLen, nonce, getESDTBalanceName, + func(goContext *VMHooksImpl, esdtData *esdt.ESDigitalToken) int32 { + err := goContext.MemStore(resultOffset, esdtData.Value.Bytes()) + if err != nil { + goContext.FailExecution(err) + return -1 + } - return int32(len(esdtData.Value.Bytes())) + return int32(len(esdtData.Value.Bytes())) + }) } // GetESDTNFTNameLength VMHooks implementation. @@ -388,21 +348,15 @@ func (context *VMHooksImpl) GetESDTNFTNameLength( tokenIDLen executor.MemLength, nonce int64, ) int32 { - metering := context.GetMeteringContext() - metering.StartGasTracing(getESDTNFTNameLengthName) - - esdtData, err := getESDTDataFromBlockchainHook(context, addressOffset, tokenIDOffset, tokenIDLen, nonce) - - if err != nil { - context.FailExecution(err) - return -1 - } - if esdtData == nil || esdtData.TokenMetaData == nil { - FailExecution(context.GetVMHost(), vmhost.ErrNilESDTData) - return 0 - } + return context.withESDTData(addressOffset, tokenIDOffset, tokenIDLen, nonce, getESDTNFTNameLengthName, + func(goContext *VMHooksImpl, esdtData *esdt.ESDigitalToken) int32 { + if esdtData == nil || esdtData.TokenMetaData == nil { + FailExecution(goContext.GetVMHost(), vmhost.ErrNilESDTData) + return 0 + } - return int32(len(esdtData.TokenMetaData.Name)) + return int32(len(esdtData.TokenMetaData.Name)) + }) } // GetESDTNFTAttributeLength VMHooks implementation. @@ -413,21 +367,15 @@ func (context *VMHooksImpl) GetESDTNFTAttributeLength( tokenIDLen executor.MemLength, nonce int64, ) int32 { - metering := context.GetMeteringContext() - metering.StartGasTracing(getESDTNFTAttributeLengthName) - - esdtData, err := getESDTDataFromBlockchainHook(context, addressOffset, tokenIDOffset, tokenIDLen, nonce) - - if err != nil { - context.FailExecution(err) - return -1 - } - if esdtData == nil || esdtData.TokenMetaData == nil { - FailExecution(context.GetVMHost(), vmhost.ErrNilESDTData) - return 0 - } + return context.withESDTData(addressOffset, tokenIDOffset, tokenIDLen, nonce, getESDTNFTAttributeLengthName, + func(goContext *VMHooksImpl, esdtData *esdt.ESDigitalToken) int32 { + if esdtData == nil || esdtData.TokenMetaData == nil { + FailExecution(goContext.GetVMHost(), vmhost.ErrNilESDTData) + return 0 + } - return int32(len(esdtData.TokenMetaData.Attributes)) + return int32(len(esdtData.TokenMetaData.Attributes)) + }) } // GetESDTNFTURILength VMHooks implementation. @@ -438,24 +386,18 @@ func (context *VMHooksImpl) GetESDTNFTURILength( tokenIDLen executor.MemLength, nonce int64, ) int32 { - metering := context.GetMeteringContext() - metering.StartGasTracing(getESDTNFTURILengthName) - - esdtData, err := getESDTDataFromBlockchainHook(context, addressOffset, tokenIDOffset, tokenIDLen, nonce) - - if err != nil { - context.FailExecution(err) - return -1 - } - if esdtData == nil || esdtData.TokenMetaData == nil { - FailExecution(context.GetVMHost(), vmhost.ErrNilESDTData) - return 0 - } - if len(esdtData.TokenMetaData.URIs) == 0 { - return 0 - } + return context.withESDTData(addressOffset, tokenIDOffset, tokenIDLen, nonce, getESDTNFTURILengthName, + func(goContext *VMHooksImpl, esdtData *esdt.ESDigitalToken) int32 { + if esdtData == nil || esdtData.TokenMetaData == nil { + FailExecution(goContext.GetVMHost(), vmhost.ErrNilESDTData) + return 0 + } + if len(esdtData.TokenMetaData.URIs) == 0 { + return 0 + } - return int32(len(esdtData.TokenMetaData.URIs[0])) + return int32(len(esdtData.TokenMetaData.URIs[0])) + }) } // GetESDTTokenData VMHooks implementation. @@ -474,60 +416,54 @@ func (context *VMHooksImpl) GetESDTTokenData( royaltiesHandle int32, urisOffset executor.MemPtr, ) int32 { - managedType := context.GetManagedTypesContext() - metering := context.GetMeteringContext() - metering.StartGasTracing(getESDTTokenDataName) + return context.withESDTData(addressOffset, tokenIDOffset, tokenIDLen, nonce, getESDTTokenDataName, + func(goContext *VMHooksImpl, esdtData *esdt.ESDigitalToken) int32 { + managedType := goContext.GetManagedTypesContext() - esdtData, err := getESDTDataFromBlockchainHook(context, addressOffset, tokenIDOffset, tokenIDLen, nonce) + value := managedType.GetBigIntOrCreate(valueHandle) + value.Set(esdtData.Value) - if err != nil { - context.FailExecution(err) - return -1 - } - - value := managedType.GetBigIntOrCreate(valueHandle) - value.Set(esdtData.Value) - - err = context.MemStore(propertiesOffset, esdtData.Properties) - if err != nil { - context.FailExecution(err) - return -1 - } - - if esdtData.TokenMetaData != nil { - err = context.MemStore(hashOffset, esdtData.TokenMetaData.Hash) - if err != nil { - context.FailExecution(err) - return -1 - } - err = context.MemStore(nameOffset, esdtData.TokenMetaData.Name) - if err != nil { - context.FailExecution(err) - return -1 - } - err = context.MemStore(attributesOffset, esdtData.TokenMetaData.Attributes) - if err != nil { - context.FailExecution(err) - return -1 - } - err = context.MemStore(creatorOffset, esdtData.TokenMetaData.Creator) - if err != nil { - context.FailExecution(err) - return -1 - } - - royalties := managedType.GetBigIntOrCreate(royaltiesHandle) - royalties.SetUint64(uint64(esdtData.TokenMetaData.Royalties)) - - if len(esdtData.TokenMetaData.URIs) > 0 { - err = context.MemStore(urisOffset, esdtData.TokenMetaData.URIs[0]) + err := goContext.MemStore(propertiesOffset, esdtData.Properties) if err != nil { - context.FailExecution(err) + goContext.FailExecution(err) return -1 } - } - } - return int32(len(esdtData.Value.Bytes())) + + if esdtData.TokenMetaData != nil { + err = goContext.MemStore(hashOffset, esdtData.TokenMetaData.Hash) + if err != nil { + goContext.FailExecution(err) + return -1 + } + err = goContext.MemStore(nameOffset, esdtData.TokenMetaData.Name) + if err != nil { + goContext.FailExecution(err) + return -1 + } + err = goContext.MemStore(attributesOffset, esdtData.TokenMetaData.Attributes) + if err != nil { + goContext.FailExecution(err) + return -1 + } + err = goContext.MemStore(creatorOffset, esdtData.TokenMetaData.Creator) + if err != nil { + goContext.FailExecution(err) + return -1 + } + + royalties := managedType.GetBigIntOrCreate(royaltiesHandle) + royalties.SetUint64(uint64(esdtData.TokenMetaData.Royalties)) + + if len(esdtData.TokenMetaData.URIs) > 0 { + err = goContext.MemStore(urisOffset, esdtData.TokenMetaData.URIs[0]) + if err != nil { + goContext.FailExecution(err) + return -1 + } + } + } + return int32(len(esdtData.Value.Bytes())) + }) } // GetESDTLocalRoles VMHooks implementation. diff --git a/vmhost/vmhooks/cryptoei.go b/vmhost/vmhooks/cryptoei.go index f325375c1..58a55ca2f 100644 --- a/vmhost/vmhooks/cryptoei.go +++ b/vmhost/vmhooks/cryptoei.go @@ -92,42 +92,17 @@ func (context *VMHooksImpl) Sha256( // ManagedSha256 VMHooks implementation. // @autogenerate(VMHooks) func (context *VMHooksImpl) ManagedSha256(inputHandle, outputHandle int32) int32 { - managedType := context.GetManagedTypesContext() crypto := context.GetCryptoContext() - enableEpochsHandler := context.host.EnableEpochsHandler() metering := context.GetMeteringContext() - err := metering.UseGasBoundedAndAddTracedGas(sha256Name, metering.GasSchedule().CryptoAPICost.SHA256) - if err != nil { - context.FailExecution(err) - return 1 - } - - inputBytes, err := managedType.GetBytes(inputHandle) - if err != nil { - context.FailExecution(err) - return 1 - } - - err = managedType.ConsumeGasForBytes(inputBytes) - if err != nil { - context.FailExecution(err) - return 1 - } - - resultBytes, err := crypto.Sha256(inputBytes) - if err != nil { - if enableEpochsHandler.IsFlagEnabled(vmhost.MaskInternalDependenciesErrorsFlag) { - err = vmhost.ErrSha256Hash - } - - context.FailExecution(err) - return 1 - } - - managedType.SetBytes(outputHandle, resultBytes) - - return 0 + return context.managedHash( + inputHandle, + outputHandle, + sha256Name, + metering.GasSchedule().CryptoAPICost.SHA256, + crypto.Sha256, + vmhost.ErrSha256Hash, + ) } // Keccak256 VMHooks implementation. @@ -173,42 +148,17 @@ func (context *VMHooksImpl) Keccak256(dataOffset executor.MemPtr, length executo // ManagedKeccak256 VMHooks implementation. // @autogenerate(VMHooks) func (context *VMHooksImpl) ManagedKeccak256(inputHandle, outputHandle int32) int32 { - managedType := context.GetManagedTypesContext() crypto := context.GetCryptoContext() - enableEpochsHandler := context.host.EnableEpochsHandler() metering := context.GetMeteringContext() - err := metering.UseGasBoundedAndAddTracedGas(keccak256Name, metering.GasSchedule().CryptoAPICost.Keccak256) - if err != nil { - context.FailExecution(err) - return 1 - } - - inputBytes, err := managedType.GetBytes(inputHandle) - if err != nil { - context.FailExecution(err) - return 1 - } - - err = managedType.ConsumeGasForBytes(inputBytes) - if err != nil { - context.FailExecution(err) - return 1 - } - - resultBytes, err := crypto.Keccak256(inputBytes) - if err != nil { - if enableEpochsHandler.IsFlagEnabled(vmhost.MaskInternalDependenciesErrorsFlag) { - err = vmhost.ErrKeccak256Hash - } - - context.FailExecution(err) - return 1 - } - - managedType.SetBytes(outputHandle, resultBytes) - - return 0 + return context.managedHash( + inputHandle, + outputHandle, + keccak256Name, + metering.GasSchedule().CryptoAPICost.Keccak256, + crypto.Keccak256, + vmhost.ErrKeccak256Hash, + ) } // Ripemd160 VMHooks implementation. @@ -254,48 +204,17 @@ func (context *VMHooksImpl) Ripemd160(dataOffset executor.MemPtr, length executo // ManagedRipemd160 VMHooks implementation. // @autogenerate(VMHooks) func (context *VMHooksImpl) ManagedRipemd160(inputHandle int32, outputHandle int32) int32 { - host := context.GetVMHost() - return ManagedRipemd160WithHost(host, inputHandle, outputHandle) -} - -// ManagedRipemd160WithHost VMHooks implementation. -func ManagedRipemd160WithHost(host vmhost.VMHost, inputHandle int32, outputHandle int32) int32 { - metering := host.Metering() - managedType := host.ManagedTypes() - crypto := host.Crypto() - enableEpochsHandler := host.EnableEpochsHandler() - - err := metering.UseGasBoundedAndAddTracedGas(ripemd160Name, metering.GasSchedule().CryptoAPICost.Ripemd160) - if err != nil { - FailExecution(host, err) - return 1 - } - - inputBytes, err := managedType.GetBytes(inputHandle) - if err != nil { - FailExecution(host, err) - return 1 - } - - err = managedType.ConsumeGasForBytes(inputBytes) - if err != nil { - FailExecution(host, err) - return 1 - } - - result, err := crypto.Ripemd160(inputBytes) - if err != nil { - if enableEpochsHandler.IsFlagEnabled(vmhost.MaskInternalDependenciesErrorsFlag) { - err = vmhost.ErrRipemd160Hash - } - - FailExecution(host, err) - return 1 - } - - managedType.SetBytes(outputHandle, result) + crypto := context.GetCryptoContext() + metering := context.GetMeteringContext() - return 0 + return context.managedHash( + inputHandle, + outputHandle, + ripemd160Name, + metering.GasSchedule().CryptoAPICost.Ripemd160, + crypto.Ripemd160, + vmhost.ErrRipemd160Hash, + ) } // VerifyBLS VMHooks implementation. @@ -363,8 +282,14 @@ func (context *VMHooksImpl) ManagedVerifyBLS( messageHandle int32, sigHandle int32, ) int32 { - host := context.GetVMHost() - return ManagedVerifyBLSWithHost(host, keyHandle, messageHandle, sigHandle, verifyBLSName) + crypto := context.GetCryptoContext() + return context.managedVerifyWithOperands(verifyBLSName, vmhost.ErrBlsVerify, func() error { + keyBytes, msgBytes, sigBytes, err := context.getSignatureOperands(keyHandle, messageHandle, sigHandle) + if err != nil { + return err + } + return crypto.VerifyBLS(keyBytes, msgBytes, sigBytes) + }) } func useGasForCryptoVerify( @@ -390,89 +315,6 @@ func useGasForCryptoVerify( return metering.UseGasBounded(gasToUse) } -// ManagedVerifyBLSWithHost VMHooks implementation. -func ManagedVerifyBLSWithHost( - host vmhost.VMHost, - keyHandle int32, - messageHandle int32, - sigHandle int32, - sigVerificationType string, -) int32 { - runtime := host.Runtime() - metering := host.Metering() - managedType := host.ManagedTypes() - crypto := host.Crypto() - enableEpochsHandler := host.EnableEpochsHandler() - err := useGasForCryptoVerify(metering, sigVerificationType) - if err != nil && runtime.UseGasBoundedShouldFailExecution() { - FailExecution(host, err) - return 1 - } - - keyBytes, err := managedType.GetBytes(keyHandle) - if err != nil { - FailExecution(host, err) - return 1 - } - - err = managedType.ConsumeGasForBytes(keyBytes) - if err != nil { - FailExecution(host, err) - return 1 - } - - msgBytes, err := managedType.GetBytes(messageHandle) - if err != nil { - FailExecution(host, err) - return 1 - } - - err = managedType.ConsumeGasForBytes(msgBytes) - if err != nil { - FailExecution(host, err) - return 1 - } - - sigBytes, err := managedType.GetBytes(sigHandle) - if err != nil { - FailExecution(host, err) - return 1 - } - - err = managedType.ConsumeGasForBytes(sigBytes) - if err != nil { - FailExecution(host, err) - return 1 - } - - invalidSigErr := vmhost.ErrInvalidArgument - switch sigVerificationType { - case verifyBLSName: - invalidSigErr = crypto.VerifyBLS(keyBytes, msgBytes, sigBytes) - case verifyBLSSignatureShare: - invalidSigErr = crypto.VerifySignatureShare(keyBytes, msgBytes, sigBytes) - case verifyBLSAggregatedSignature: - var pubKeyBytes [][]byte - pubKeyBytes, _, invalidSigErr = managedType.ReadManagedVecOfManagedBuffers(keyHandle) - if invalidSigErr != nil { - break - } - - invalidSigErr = crypto.VerifyAggregatedSig(pubKeyBytes, msgBytes, sigBytes) - } - - if invalidSigErr != nil { - if enableEpochsHandler.IsFlagEnabled(vmhost.MaskInternalDependenciesErrorsFlag) { - invalidSigErr = vmhost.ErrBlsVerify - } - - FailExecution(host, invalidSigErr) - return -1 - } - - return 0 -} - // VerifyEd25519 VMHooks implementation. // @autogenerate(VMHooks) func (context *VMHooksImpl) VerifyEd25519( @@ -536,75 +378,14 @@ func (context *VMHooksImpl) VerifyEd25519( func (context *VMHooksImpl) ManagedVerifyEd25519( keyHandle, messageHandle, sigHandle int32, ) int32 { - host := context.GetVMHost() - return ManagedVerifyEd25519WithHost(host, keyHandle, messageHandle, sigHandle) -} - -// ManagedVerifyEd25519WithHost VMHooks implementation. -func ManagedVerifyEd25519WithHost( - host vmhost.VMHost, - keyHandle, messageHandle, sigHandle int32, -) int32 { - metering := host.Metering() - managedType := host.ManagedTypes() - enableEpochsHandler := host.EnableEpochsHandler() - crypto := host.Crypto() - metering.StartGasTracing(verifyEd25519Name) - - gasToUse := metering.GasSchedule().CryptoAPICost.VerifyEd25519 - err := metering.UseGasBounded(gasToUse) - if err != nil { - FailExecution(host, err) - return 1 - } - - keyBytes, err := managedType.GetBytes(keyHandle) - if err != nil { - FailExecution(host, err) - return 1 - } - - err = managedType.ConsumeGasForBytes(keyBytes) - if err != nil { - FailExecution(host, err) - return 1 - } - - msgBytes, err := managedType.GetBytes(messageHandle) - if err != nil { - FailExecution(host, err) - return 1 - } - - err = managedType.ConsumeGasForBytes(msgBytes) - if err != nil { - FailExecution(host, err) - return 1 - } - - sigBytes, err := managedType.GetBytes(sigHandle) - if err != nil { - FailExecution(host, err) - return 1 - } - - err = managedType.ConsumeGasForBytes(sigBytes) - if err != nil { - FailExecution(host, err) - return 1 - } - - invalidSigErr := crypto.VerifyEd25519(keyBytes, msgBytes, sigBytes) - if invalidSigErr != nil { - if enableEpochsHandler.IsFlagEnabled(vmhost.MaskInternalDependenciesErrorsFlag) { - invalidSigErr = vmhost.ErrEd25519Verify + crypto := context.GetCryptoContext() + return context.managedVerifyWithOperands(verifyEd25519Name, vmhost.ErrEd25519Verify, func() error { + keyBytes, msgBytes, sigBytes, err := context.getSignatureOperands(keyHandle, messageHandle, sigHandle) + if err != nil { + return err } - - FailExecution(host, invalidSigErr) - return -1 - } - - return 0 + return crypto.VerifyEd25519(keyBytes, msgBytes, sigBytes) + }) } // VerifyCustomSecp256k1 VMHooks implementation. @@ -688,89 +469,14 @@ func (context *VMHooksImpl) ManagedVerifyCustomSecp256k1( keyHandle, messageHandle, sigHandle int32, hashType int32, ) int32 { - host := context.GetVMHost() - return ManagedVerifyCustomSecp256k1WithHost( - host, - keyHandle, - messageHandle, - sigHandle, - hashType, - verifyCustomSecp256k1Name) -} - -// ManagedVerifyCustomSecp256k1WithHost VMHooks implementation. -func ManagedVerifyCustomSecp256k1WithHost( - host vmhost.VMHost, - keyHandle, messageHandle, sigHandle int32, - hashType int32, - verifyCryptoFunc string, -) int32 { - runtime := host.Runtime() - enableEpochsHandler := host.EnableEpochsHandler() - metering := host.Metering() - managedType := host.ManagedTypes() - crypto := host.Crypto() - - err := useGasForCryptoVerify(metering, verifyCryptoFunc) - if err != nil && runtime.UseGasBoundedShouldFailExecution() { - FailExecution(host, err) - return 1 - } - - keyBytes, err := managedType.GetBytes(keyHandle) - if err != nil { - FailExecution(host, err) - return 1 - } - - err = managedType.ConsumeGasForBytes(keyBytes) - if err != nil { - FailExecution(host, err) - return 1 - } - - msgBytes, err := managedType.GetBytes(messageHandle) - if err != nil { - FailExecution(host, err) - return 1 - } - - err = managedType.ConsumeGasForBytes(msgBytes) - if err != nil { - FailExecution(host, err) - return 1 - } - - sigBytes, err := managedType.GetBytes(sigHandle) - if err != nil { - FailExecution(host, err) - return 1 - } - - err = managedType.ConsumeGasForBytes(sigBytes) - if err != nil { - FailExecution(host, err) - return 1 - } - - invalidSigErr := vmhost.ErrInvalidArgument - switch verifyCryptoFunc { - case verifyCustomSecp256k1Name: - invalidSigErr = crypto.VerifySecp256k1(keyBytes, msgBytes, sigBytes, uint8(hashType)) - case verifySecp256R1Signature: - invalidSigErr = crypto.VerifySecp256r1(keyBytes, msgBytes, sigBytes) - } - - if invalidSigErr != nil { - if enableEpochsHandler.IsFlagEnabled(vmhost.MaskInternalDependenciesErrorsFlag) { - invalidSigErr = vmhost.ErrSecp256k1Verify + crypto := context.GetCryptoContext() + return context.managedVerifyWithOperands(verifyCustomSecp256k1Name, vmhost.ErrSecp256k1Verify, func() error { + keyBytes, msgBytes, sigBytes, err := context.getSignatureOperands(keyHandle, messageHandle, sigHandle) + if err != nil { + return err } - - FailExecution(host, invalidSigErr) - return -1 - } - - return 0 + return crypto.VerifySecp256k1(keyBytes, msgBytes, sigBytes, uint8(hashType)) + }) } // VerifySecp256k1 VMHooks implementation. @@ -797,22 +503,11 @@ func (context *VMHooksImpl) VerifySecp256k1( func (context *VMHooksImpl) ManagedVerifySecp256k1( keyHandle, messageHandle, sigHandle int32, ) int32 { - host := context.GetVMHost() - return ManagedVerifySecp256k1WithHost(host, keyHandle, messageHandle, sigHandle) -} - -// ManagedVerifySecp256k1WithHost VMHooks implementation. -func ManagedVerifySecp256k1WithHost( - host vmhost.VMHost, - keyHandle, messageHandle, sigHandle int32, -) int32 { - return ManagedVerifyCustomSecp256k1WithHost( - host, + return context.ManagedVerifyCustomSecp256k1( keyHandle, messageHandle, sigHandle, int32(secp256.ECDSADoubleSha256), - verifyCustomSecp256k1Name, ) } @@ -2093,14 +1788,14 @@ func (context *VMHooksImpl) EllipticCurveGetValues(ecHandle int32, fieldOrderHan func (context *VMHooksImpl) ManagedVerifySecp256r1( keyHandle, messageHandle, sigHandle int32, ) int32 { - host := context.GetVMHost() - return ManagedVerifyCustomSecp256k1WithHost( - host, - keyHandle, - messageHandle, - sigHandle, - 0, - verifySecp256R1Signature) + crypto := context.GetCryptoContext() + return context.managedVerifyWithOperands(verifySecp256R1Signature, vmhost.ErrSecp256k1Verify, func() error { + keyBytes, msgBytes, sigBytes, err := context.getSignatureOperands(keyHandle, messageHandle, sigHandle) + if err != nil { + return err + } + return crypto.VerifySecp256r1(keyBytes, msgBytes, sigBytes) + }) } // ManagedVerifyBLSSignatureShare VMHooks implementation. @@ -2110,8 +1805,14 @@ func (context *VMHooksImpl) ManagedVerifyBLSSignatureShare( messageHandle int32, sigHandle int32, ) int32 { - host := context.GetVMHost() - return ManagedVerifyBLSWithHost(host, keyHandle, messageHandle, sigHandle, verifyBLSSignatureShare) + crypto := context.GetCryptoContext() + return context.managedVerifyWithOperands(verifyBLSSignatureShare, vmhost.ErrBlsVerify, func() error { + keyBytes, msgBytes, sigBytes, err := context.getSignatureOperands(keyHandle, messageHandle, sigHandle) + if err != nil { + return err + } + return crypto.VerifySignatureShare(keyBytes, msgBytes, sigBytes) + }) } // ManagedVerifyBLSAggregatedSignature VMHooks implementation. @@ -2121,6 +1822,20 @@ func (context *VMHooksImpl) ManagedVerifyBLSAggregatedSignature( messageHandle int32, sigHandle int32, ) int32 { - host := context.GetVMHost() - return ManagedVerifyBLSWithHost(host, keyHandle, messageHandle, sigHandle, verifyBLSAggregatedSignature) + crypto := context.GetCryptoContext() + managedType := context.GetManagedTypesContext() + + return context.managedVerifyWithOperands(verifyBLSAggregatedSignature, vmhost.ErrBlsVerify, func() error { + pubKeyBytes, _, err := managedType.ReadManagedVecOfManagedBuffers(keyHandle) + if err != nil { + return err + } + + _, msgBytes, sigBytes, err := context.getSignatureOperands(0, messageHandle, sigHandle) + if err != nil { + return err + } + + return crypto.VerifyAggregatedSig(pubKeyBytes, msgBytes, sigBytes) + }) } diff --git a/vmhost/vmhooks/helpers.go b/vmhost/vmhooks/helpers.go new file mode 100644 index 000000000..c5026d0c3 --- /dev/null +++ b/vmhost/vmhooks/helpers.go @@ -0,0 +1,176 @@ +package vmhooks + +import ( + "github.com/multiversx/mx-chain-vm-go/vmhost" +) + +import ( + "github.com/multiversx/mx-chain-core-go/data/esdt" + "github.com/multiversx/mx-chain-vm-go/executor" + "github.com/multiversx/mx-chain-vm-go/vmhost" +) + +// This file will contain helper functions to reduce boilerplate and duplication in the other files. + +type hashFunc func(data []byte) ([]byte, error) + +func (context *VMHooksImpl) managedHash( + inputHandle int32, + outputHandle int32, + traceName string, + gasCost uint64, + hf hashFunc, + failError error, +) int32 { + host := context.GetVMHost() + metering := host.Metering() + managedType := host.ManagedTypes() + enableEpochsHandler := host.EnableEpochsHandler() + + err := metering.UseGasBoundedAndAddTracedGas(traceName, gasCost) + if err != nil { + FailExecution(host, err) + return 1 + } + + inputBytes, err := managedType.GetBytes(inputHandle) + if err != nil { + FailExecution(host, err) + return 1 + } + + err = managedType.ConsumeGasForBytes(inputBytes) + if err != nil { + FailExecution(host, err) + return 1 + } + + resultBytes, err := hf(inputBytes) + if err != nil { + if enableEpochsHandler.IsFlagEnabled(vmhost.MaskInternalDependenciesErrorsFlag) { + err = failError + } + + FailExecution(host, err) + return 1 + } + + managedType.SetBytes(outputHandle, resultBytes) + + return 0 +} + +type esdtDataHandler func(context *VMHooksImpl, esdtData *esdt.ESDigitalToken) int32 + +func (context *VMHooksImpl) withESDTData( + addressOffset executor.MemPtr, + tokenIDOffset executor.MemPtr, + tokenIDLen executor.MemLength, + nonce int64, + traceName string, + handler esdtDataHandler, +) int32 { + metering := context.GetMeteringContext() + metering.StartGasTracing(traceName) + + esdtData, err := context.GetESDTDataFromBlockchainHook(addressOffset, tokenIDOffset, tokenIDLen, nonce) + + if err != nil { + context.FailExecution(err) + return -1 + } + + return handler(context, esdtData) +} + +func (context *VMHooksImpl) GetESDTDataFromBlockchainHook( + addressOffset executor.MemPtr, + tokenIDOffset executor.MemPtr, + tokenIDLen executor.MemLength, + nonce int64, +) (*esdt.ESDigitalToken, error) { + metering := context.GetMeteringContext() + blockchain := context.GetBlockchainContext() + + gasToUse := metering.GasSchedule().BaseOpsAPICost.GetExternalBalance + err := metering.UseGasBounded(gasToUse) + if err != nil { + return nil, err + } + + address, err := context.MemLoad(addressOffset, vmhost.AddressLen) + if err != nil { + return nil, err + } + + tokenID, err := context.MemLoad(tokenIDOffset, tokenIDLen) + if err != nil { + return nil, err + } + + esdtToken, err := blockchain.GetESDTToken(address, tokenID, uint64(nonce)) + if err != nil { + return nil, err + } + + return esdtToken, nil +} + +func (context *VMHooksImpl) getSignatureOperands(keyHandle, messageHandle, sigHandle int32) ([]byte, []byte, []byte, error) { + managedType := context.GetManagedTypesContext() + + keyBytes, err := managedType.GetBytes(keyHandle) + if err != nil { + return nil, nil, nil, err + } + if err = managedType.ConsumeGasForBytes(keyBytes); err != nil { + return nil, nil, nil, err + } + + msgBytes, err := managedType.GetBytes(messageHandle) + if err != nil { + return nil, nil, nil, err + } + if err = managedType.ConsumeGasForBytes(msgBytes); err != nil { + return nil, nil, nil, err + } + + sigBytes, err := managedType.GetBytes(sigHandle) + if err != nil { + return nil, nil, nil, err + } + if err = managedType.ConsumeGasForBytes(sigBytes); err != nil { + return nil, nil, nil, err + } + + return keyBytes, msgBytes, sigBytes, nil +} + +func (context *VMHooksImpl) managedVerifyWithOperands( + sigVerificationType string, + failError error, + verify func() error, +) int32 { + host := context.GetVMHost() + runtime := host.Runtime() + metering := host.Metering() + enableEpochsHandler := host.EnableEpochsHandler() + + err := useGasForCryptoVerify(metering, sigVerificationType) + if err != nil && runtime.UseGasBoundedShouldFailExecution() { + FailExecution(host, err) + return 1 + } + + invalidSigErr := verify() + if invalidSigErr != nil { + if enableEpochsHandler.IsFlagEnabled(vmhost.MaskInternalDependenciesErrorsFlag) { + invalidSigErr = failError + } + + FailExecution(host, invalidSigErr) + return -1 + } + + return 0 +} diff --git a/vmhost/vmhooks/helpers_test.go b/vmhost/vmhooks/helpers_test.go new file mode 100644 index 000000000..ab54e5958 --- /dev/null +++ b/vmhost/vmhooks/helpers_test.go @@ -0,0 +1,208 @@ +package vmhooks + +import ( + "errors" + "testing" + + "github.com/multiversx/mx-chain-core-go/data/esdt" + "github.com/multiversx/mx-chain-vm-go/vmhost" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// mockVMHost is a mock for the VMHost interface +type mockVMHost struct { + mock.Mock + crypto vmhost.VMCrypto + managedType vmhost.ManagedTypesContext + metering vmhost.MeteringContext + blockchain vmhost.BlockchainContext + enableEpochsH vmhost.EnableEpochsHandler + runtime vmhost.RuntimeContext +} + +func (m *mockVMHost) Crypto() vmhost.VMCrypto { return m.crypto } +func (m *mockVMHost) ManagedTypes() vmhost.ManagedTypesContext { return m.managedType } +func (m *mockVMHost) Metering() vmhost.MeteringContext { return m.metering } +func (m *mockVMHost) Blockchain() vmhost.BlockchainContext { return m.blockchain } +func (m *mockVMHost) EnableEpochsHandler() vmhost.EnableEpochsHandler { return m.enableEpochsH } +func (m *mockVMHost) Runtime() vmhost.RuntimeContext { return m.runtime } +func (m *mockVMHost) FailExecution(err error) { m.Called(err) } +func (m *mockVMHost) GetGasSchedule() vmhost.GasSchedule { return nil } +func (m *mockVMHost) AreInSameShard(address1, address2 []byte) bool { return true } +func (m *mockVMHost) IsBuiltinFunctionName(name string) bool { return false } +func (m *mockVMHost) GetTxContext() vmhost.TxContext { return nil } +func (m *mockVMHost) GetLogEntries() []*vmhost.LogEntry { return nil } +func (m *mockVMHost) CompleteLogEntriesWithCallType(output *vmcommon.VMOutput, callType string) {} + + +// mockCryptoHook is a mock for the VMCrypto interface +type mockCryptoHook struct { + mock.Mock +} + +func (m *mockCryptoHook) Sha256(data []byte) ([]byte, error) { + args := m.Called(data) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]byte), args.Error(1) +} +func (m *mockCryptoHook) Keccak256(p []byte) ([]byte, error) { + args := m.Called(p) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]byte), args.Error(1) +} + +func (m *mockCryptoHook) Ripemd160(p []byte) ([]byte, error) { + args := m.Called(p) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]byte), args.Error(1) +} +// ... other crypto functions + +// mockManagedTypesContext is a mock for the ManagedTypesContext interface +type mockManagedTypesContext struct { + mock.Mock +} + +func (m *mockManagedTypesContext) GetBytes(handle int32) ([]byte, error) { + args := m.Called(handle) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]byte), args.Error(1) +} + +func (m *mockManagedTypesContext) ConsumeGasForBytes(data []byte) error { + args := m.Called(data) + return args.Error(0) +} + +func (m *mockManagedTypesContext) SetBytes(handle int32, data []byte) { + m.Called(handle, data) +} + +// mockMetering is a mock for the MeteringContext interface +type mockMetering struct { + mock.Mock +} + +func (m *mockMetering) UseGasBoundedAndAddTracedGas(name string, gas uint64) error { + args := m.Called(name, gas) + return args.Error(0) +} +func (m *mockMetering) GasLeft() uint64 { + return 0 +} +func (m *mockMetering) UseGas(gas uint64) error { + return nil +} + +// mockEnableEpochsHandler is a mock for the EnableEpochsHandler interface +type mockEnableEpochsHandler struct { + mock.Mock +} + +func (m *mockEnableEpochsHandler) IsFlagEnabled(flag vmhost.EpochFlag) bool { + return true +} + +func TestVMHooksImpl_managedHash(t *testing.T) { + t.Parallel() + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + cryptoHook := &mockCryptoHook{} + managedType := &mockManagedTypesContext{} + metering := &mockMetering{} + enableEpochsH := &mockEnableEpochsHandler{} + host := &mockVMHost{ + crypto: cryptoHook, + managedType: managedType, + metering: metering, + enableEpochsH: enableEpochsH, + } + context := &VMHooksImpl{host: host} + + inputBytes := []byte("input") + outputBytes := []byte("output") + + managedType.On("GetBytes", int32(1)).Return(inputBytes, nil) + managedType.On("ConsumeGasForBytes", inputBytes).Return(nil) + managedType.On("SetBytes", int32(2), outputBytes).Return() + cryptoHook.On("Sha256", inputBytes).Return(outputBytes, nil) + metering.On("UseGasBoundedAndAddTracedGas", "sha256", uint64(100)).Return(nil) + + result := context.managedHash(1, 2, "sha256", 100, cryptoHook.Sha256, vmhost.ErrSha256Hash) + + assert.Equal(t, int32(0), result) + managedType.AssertExpectations(t) + cryptoHook.AssertExpectations(t) + metering.AssertExpectations(t) + }) + + t.Run("should fail on get bytes", func(t *testing.T) { + t.Parallel() + + cryptoHook := &mockCryptoHook{} + managedType := &mockManagedTypesContext{} + metering := &mockMetering{} + enableEpochsH := &mockEnableEpochsHandler{} + host := &mockVMHost{ + crypto: cryptoHook, + managedType: managedType, + metering: metering, + enableEpochsH: enableEpochsH, + } + context := &VMHooksImpl{host: host} + + err := errors.New("err") + managedType.On("GetBytes", int32(1)).Return(nil, err) + metering.On("UseGasBoundedAndAddTracedGas", "sha256", uint64(100)).Return(nil) + host.On("FailExecution", err).Return() + + result := context.managedHash(1, 2, "sha256", 100, cryptoHook.Sha256, vmhost.ErrSha256Hash) + + assert.Equal(t, int32(1), result) + host.AssertExpectations(t) + }) +} + +func TestVMHooksImpl_getSignatureOperands(t *testing.T) { + t.Parallel() + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + managedType := &mockManagedTypesContext{} + host := &mockVMHost{ + managedType: managedType, + } + context := &VMHooksImpl{host: host} + + keyBytes := []byte("key") + msgBytes := []byte("msg") + sigBytes := []byte("sig") + + managedType.On("GetBytes", int32(1)).Return(keyBytes, nil) + managedType.On("ConsumeGasForBytes", keyBytes).Return(nil) + managedType.On("GetBytes", int32(2)).Return(msgBytes, nil) + managedType.On("ConsumeGasForBytes", msgBytes).Return(nil) + managedType.On("GetBytes", int32(3)).Return(sigBytes, nil) + managedType.On("ConsumeGasForBytes", sigBytes).Return(nil) + + k, m, s, err := context.getSignatureOperands(1, 2, 3) + + assert.Nil(t, err) + assert.Equal(t, keyBytes, k) + assert.Equal(t, msgBytes, m) + assert.Equal(t, sigBytes, s) + managedType.AssertExpectations(t) + }) +}