diff --git a/Project.toml b/Project.toml index 1cc7c4dfb..77d88e76c 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,8 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +DecoratedParticles = "023d0394-cb16-4d2d-a5c7-724bed42bbb6" DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07" EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2" EquivariantTensors = "5e107534-7145-4f8f-b06f-47a52840c895" @@ -59,15 +61,17 @@ BenchmarkTools = "1.6.3" Bumper = "0.7" ChainRulesCore = "1" ChunkSplitters = "3.0" +ConcreteStructs = "0.2.3" +DecoratedParticles = "0.1.3" DynamicPolynomials = "0.6" EmpiricalPotentials = "0.2" -EquivariantTensors = "0.3" +EquivariantTensors = "0.4.3" ExtXYZ = "0.2.0" Folds = "0.2" -ForwardDiff = "0.10" -Interpolations = "0.15" -JSON = "0.21" -Lux = "1.25" +ForwardDiff = "0.10, 1" +Interpolations = "0.16" +JSON = "0.21, 1" +Lux = "1.21" LuxCore = "1" NamedTupleTools = "0.13, 0.14" NeighbourLists = "0.5" @@ -76,10 +80,10 @@ Optim = "1" Optimisers = "0.3.4, 0.4" OrderedCollections = "1" Polynomials4ML = "0.5" -PrettyTables = "1.3, 2.0" +PrettyTables = "1.3, 2" Reexport = "1" Roots = "2" -SparseArrays = "1.10" +SparseArrays = "1" SpheriCart = "0.2" StaticArrays = "1" StaticPolynomials = "1" @@ -87,7 +91,7 @@ StrideArrays = "0.1" Unitful = "1" WithAlloc = "0.1" YAML = "0.4" -Zygote = "0.6, 0.7" +Zygote = "0.7" julia = "1.11, 1.12" [extras] diff --git a/benchmark/benchmark_full_model.jl b/benchmark/benchmark_full_model.jl new file mode 100644 index 000000000..1a4c859f0 --- /dev/null +++ b/benchmark/benchmark_full_model.jl @@ -0,0 +1,198 @@ +# Benchmark: Full model (1+2+many body) with StackedCalculator +# Compares ACE CPU vs ETACE CPU vs ETACE GPU for energy and forces + +using ACEpotentials +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels + +import EquivariantTensors as ET +import AtomsCalculators +using StaticArrays, Lux, Random, LuxCore, LinearAlgebra +using AtomsBase, AtomsBuilder, Unitful +using BenchmarkTools +using Printf + +# GPU detection +dev = identity +has_cuda = false + +try + using CUDA + if CUDA.functional() + @info "Using CUDA" + CUDA.versioninfo() + global has_cuda = true + global dev = cu + else + @info "CUDA is not functional" + end +catch e + @info "Couldn't load CUDA: $e" +end + +if !has_cuda + @info "No GPU available. Using CPU only." +end + +rng = Random.MersenneTwister(1234) + +# Build models with E0s and pair potential enabled +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 8 +order = 2 +maxl = 4 + +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) + +# E0s for one-body +E0s = Dict(:Si => -158.54496821, :O => -2042.0330099956639) + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal, + pair_learnable = true, # Keep learnable for ET conversion + E0s = E0s) + +ps, st = Lux.setup(rng, model) + +# Create old ACE calculator (full model with E0s and pair) +ace_calc = M.ACEPotential(model, ps, st) + +# Convert to full ETACE with StackedCalculator +et_calc = ETM.convert2et_full(model, ps, st) + +rcut = maximum(a.rcut for a in model.pairbasis.rin0cuts) + +# Function to create system of given size +function make_system(n_repeat) + sys = AtomsBuilder.bulk(:Si, cubic=true) * n_repeat + rattle!(sys, 0.1u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + return sys +end + +# Benchmark configurations +configs = [ + (2, 2, 2), # 64 atoms + (3, 3, 2), # 144 atoms + (4, 4, 2), # 256 atoms + (4, 4, 4), # 512 atoms + (5, 5, 4), # 800 atoms +] + +println() +println("=" ^ 90) +println("BENCHMARK: Full Model (1+2+many body) - ACE vs ETACE StackedCalculator") +println("=" ^ 90) +println() + +# --- ENERGY BENCHMARK --- +println("### ENERGY ###") +println() + +if has_cuda + println("| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | ETACE GPU (ms) | CPU Speedup | GPU Speedup |") + println("|-------|---------|--------------|----------------|----------------|-------------|-------------|") +else + println("| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | CPU Speedup |") + println("|-------|---------|--------------|----------------|-------------|") +end + +for cfg in configs + sys = make_system(cfg) + natoms = length(sys) + + # Count edges + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # Warmup ACE + _ = AtomsCalculators.potential_energy(sys, ace_calc) + + # Warmup ETACE CPU + _ = AtomsCalculators.potential_energy(sys, et_calc) + + # Benchmark ACE CPU + t_ace = @belapsed AtomsCalculators.potential_energy($sys, $ace_calc) samples=5 evals=3 + t_ace_ms = t_ace * 1000 + + # Benchmark ETACE CPU + t_etace_cpu = @belapsed AtomsCalculators.potential_energy($sys, $et_calc) samples=5 evals=3 + t_etace_cpu_ms = t_etace_cpu * 1000 + + cpu_speedup = t_ace_ms / t_etace_cpu_ms + + if has_cuda + # For GPU we need to handle the StackedCalculator with GPU-capable models + # TODO: GPU version of StackedCalculator + t_etace_gpu_ms = NaN + gpu_speedup = NaN + + @printf("| %5d | %7d | %12.2f | %14.2f | %14s | %10.1fx | %10s |\n", + natoms, nedges, t_ace_ms, t_etace_cpu_ms, "N/A", cpu_speedup, "N/A") + else + @printf("| %5d | %7d | %12.2f | %14.2f | %10.1fx |\n", + natoms, nedges, t_ace_ms, t_etace_cpu_ms, cpu_speedup) + end +end + +println() + +# --- FORCES BENCHMARK --- +println("### FORCES ###") +println() + +if has_cuda + println("| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | ETACE GPU (ms) | CPU Speedup | GPU Speedup |") + println("|-------|---------|--------------|----------------|----------------|-------------|-------------|") +else + println("| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | CPU Speedup |") + println("|-------|---------|--------------|----------------|-------------|") +end + +for cfg in configs + sys = make_system(cfg) + natoms = length(sys) + + # Count edges + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # Warmup ACE + _ = AtomsCalculators.forces(sys, ace_calc) + + # Warmup ETACE CPU + _ = AtomsCalculators.forces(sys, et_calc) + + # Benchmark ACE CPU + t_ace = @belapsed AtomsCalculators.forces($sys, $ace_calc) samples=5 evals=3 + t_ace_ms = t_ace * 1000 + + # Benchmark ETACE CPU + t_etace_cpu = @belapsed AtomsCalculators.forces($sys, $et_calc) samples=5 evals=3 + t_etace_cpu_ms = t_etace_cpu * 1000 + + cpu_speedup = t_ace_ms / t_etace_cpu_ms + + if has_cuda + t_etace_gpu_ms = NaN + gpu_speedup = NaN + + @printf("| %5d | %7d | %12.2f | %14.2f | %14s | %10.1fx | %10s |\n", + natoms, nedges, t_ace_ms, t_etace_cpu_ms, "N/A", cpu_speedup, "N/A") + else + @printf("| %5d | %7d | %12.2f | %14.2f | %10.1fx |\n", + natoms, nedges, t_ace_ms, t_etace_cpu_ms, cpu_speedup) + end +end + +println() +println("Notes:") +println("- ACE CPU: Original ACEpotentials model (full: E0 + pair + many-body)") +println("- ETACE CPU: StackedCalculator with ETOneBody + ETPairModel + ETACE") +println("- CPU Speedup = ACE CPU / ETACE CPU") +println("- Graph construction time included in ETACE timings") diff --git a/benchmark/gpu_benchmark.jl b/benchmark/gpu_benchmark.jl new file mode 100644 index 000000000..e861be532 --- /dev/null +++ b/benchmark/gpu_benchmark.jl @@ -0,0 +1,317 @@ +# GPU Benchmark for ETACE Models +# Run with: julia --project=test benchmark/gpu_benchmark.jl +# +# Tests: ETOneBody, ETPairModel, ETACE (many-body), and combined full model + +using CUDA +using LuxCUDA + +using ACEpotentials +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels + +import EquivariantTensors as ET +using Lux, LuxCore, Random +using AtomsBase, AtomsBuilder, Unitful +using Printf + +println("CUDA available: ", CUDA.functional()) + +# Build model +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 8 +order = 2 +maxl = 4 + +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) + +E0s = Dict(:Si => -158.54496821, :O => -2042.0330099956639) + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal, + pair_learnable = true, + E0s = E0s) + +rng = Random.MersenneTwister(1234) +ps, st = Lux.setup(rng, model) + +rcut = 5.5 +NZ = 2 + +# ============================================================================ +# 1. ETACE (many-body only) +# ============================================================================ +et_model = ETM.convert2et(model) +et_ps, et_st = LuxCore.setup(rng, et_model) + +# Copy parameters +for i in 1:NZ, j in 1:NZ + idx = (i-1)*NZ + j + et_ps.rembed.post.W[:, :, idx] .= ps.rbasis.Wnlq[:, :, i, j] +end +for iz in 1:NZ + et_ps.readout.W[1, :, iz] .= ps.WB[:, iz] +end + +# ============================================================================ +# 2. ETPairModel +# ============================================================================ +et_pair = ETM.convertpair(model) +pair_ps, pair_st = LuxCore.setup(rng, et_pair) + +# Copy pair parameters +for i in 1:NZ, j in 1:NZ + idx = (i-1)*NZ + j + pair_ps.rembed.rbasis.post.W[:, :, idx] .= ps.pairbasis.Wnlq[:, :, i, j] +end +for iz in 1:NZ + pair_ps.readout.W[1, :, iz] .= ps.Wpair[:, iz] +end + +# ============================================================================ +# 3. ETOneBody +# ============================================================================ +zlist = ChemicalSpecies.((:Si, :O)) +E0_dict = Dict(z => E0s[Symbol(z)] for z in zlist) +et_onebody = ETM.one_body(E0_dict, x -> x.z) +onebody_ps, onebody_st = LuxCore.setup(rng, et_onebody) + +# GPU device +gdev = Lux.gpu_device() +println("GPU device: ", gdev) + +# Benchmark configurations +configs = [ + (2, 2, 2), # 64 atoms + (4, 4, 4), # 512 atoms + (5, 5, 4), # 800 atoms +] + +println() +println("="^80) +println("GPU BENCHMARK: ETACE Models (with P4ML v0.5.8)") +println("="^80) + +# ============================================================================ +# SECTION 1: Many-Body Only (ETACE) +# ============================================================================ +println() +println("### MANY-BODY ONLY (ETACE) - ENERGY ###") +println("| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup |") +println("|-------|---------|----------|----------|-------------|") + +for cfg in configs + sys = AtomsBuilder.bulk(:Si, cubic=true) * cfg + rattle!(sys, 0.1u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + natoms = length(sys) + + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # Warmup CPU + _ = et_model(G, et_ps, et_st) + + # CPU benchmark + t_cpu = @elapsed for _ in 1:10 + et_model(G, et_ps, et_st) + end + t_cpu_ms = (t_cpu / 10) * 1000 + + # GPU setup + G_gpu = gdev(G) + et_ps_gpu = gdev(et_ps) + et_st_gpu = gdev(et_st) + + # Warmup GPU + CUDA.@sync et_model(G_gpu, et_ps_gpu, et_st_gpu) + + # GPU benchmark + t_gpu = CUDA.@elapsed for _ in 1:10 + CUDA.@sync et_model(G_gpu, et_ps_gpu, et_st_gpu) + end + t_gpu_ms = (t_gpu / 10) * 1000 + + speedup = t_cpu_ms / t_gpu_ms + + @printf("| %5d | %7d | %8.2f | %8.2f | %10.1fx |\n", + natoms, nedges, t_cpu_ms, t_gpu_ms, speedup) +end + +println() +println("### MANY-BODY ONLY (ETACE) - FORCES ###") +println("| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup |") +println("|-------|---------|----------|----------|-------------|") + +for cfg in configs + sys = AtomsBuilder.bulk(:Si, cubic=true) * cfg + rattle!(sys, 0.1u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + natoms = length(sys) + + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # Warmup CPU + _ = ETM.site_grads(et_model, G, et_ps, et_st) + + # CPU benchmark + t_cpu = @elapsed for _ in 1:5 + ETM.site_grads(et_model, G, et_ps, et_st) + end + t_cpu_ms = (t_cpu / 5) * 1000 + + # GPU setup + G_gpu = gdev(G) + et_ps_gpu = gdev(et_ps) + et_st_gpu = gdev(et_st) + + # Warmup GPU + CUDA.@sync ETM.site_grads(et_model, G_gpu, et_ps_gpu, et_st_gpu) + + # GPU benchmark + t_gpu = CUDA.@elapsed for _ in 1:5 + CUDA.@sync ETM.site_grads(et_model, G_gpu, et_ps_gpu, et_st_gpu) + end + t_gpu_ms = (t_gpu / 5) * 1000 + + speedup = t_cpu_ms / t_gpu_ms + + @printf("| %5d | %7d | %8.2f | %8.2f | %10.1fx |\n", + natoms, nedges, t_cpu_ms, t_gpu_ms, speedup) +end + +# ============================================================================ +# SECTION 2: Full Model (E0 + Pair + Many-Body) +# ============================================================================ +println() +println("="^80) +println("### FULL MODEL (E0 + Pair + Many-Body) - ENERGY ###") +println("| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup |") +println("|-------|---------|----------|----------|-------------|") + +for cfg in configs + sys = AtomsBuilder.bulk(:Si, cubic=true) * cfg + rattle!(sys, 0.1u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + natoms = length(sys) + + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # CPU: evaluate all three models + function full_energy_cpu(G) + E_onebody, _ = et_onebody(G, onebody_ps, onebody_st) + E_pair, _ = et_pair(G, pair_ps, pair_st) + E_mb, _ = et_model(G, et_ps, et_st) + return sum(E_onebody) + sum(E_pair) + sum(E_mb) + end + + # Warmup CPU + _ = full_energy_cpu(G) + + # CPU benchmark + t_cpu = @elapsed for _ in 1:10 + full_energy_cpu(G) + end + t_cpu_ms = (t_cpu / 10) * 1000 + + # GPU setup - all models + G_gpu = gdev(G) + onebody_ps_gpu = gdev(onebody_ps) + onebody_st_gpu = gdev(onebody_st) + pair_ps_gpu = gdev(pair_ps) + pair_st_gpu = gdev(pair_st) + et_ps_gpu = gdev(et_ps) + et_st_gpu = gdev(et_st) + + function full_energy_gpu(G_gpu) + E_onebody, _ = et_onebody(G_gpu, onebody_ps_gpu, onebody_st_gpu) + E_pair, _ = et_pair(G_gpu, pair_ps_gpu, pair_st_gpu) + E_mb, _ = et_model(G_gpu, et_ps_gpu, et_st_gpu) + return sum(E_onebody) + sum(E_pair) + sum(E_mb) + end + + # Warmup GPU + CUDA.@sync full_energy_gpu(G_gpu) + + # GPU benchmark + t_gpu = CUDA.@elapsed for _ in 1:10 + CUDA.@sync full_energy_gpu(G_gpu) + end + t_gpu_ms = (t_gpu / 10) * 1000 + + speedup = t_cpu_ms / t_gpu_ms + + @printf("| %5d | %7d | %8.2f | %8.2f | %10.1fx |\n", + natoms, nedges, t_cpu_ms, t_gpu_ms, speedup) +end + +println() +println("### FULL MODEL (E0 + Pair + Many-Body) - FORCES ###") +println("| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup |") +println("|-------|---------|----------|----------|-------------|") + +for cfg in configs + sys = AtomsBuilder.bulk(:Si, cubic=true) * cfg + rattle!(sys, 0.1u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + natoms = length(sys) + + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # CPU: evaluate all three gradients + function full_grads_cpu(G) + ∂onebody = ETM.site_grads(et_onebody, G, onebody_ps, onebody_st) + ∂pair = ETM.site_grads(et_pair, G, pair_ps, pair_st) + ∂mb = ETM.site_grads(et_model, G, et_ps, et_st) + return (∂onebody, ∂pair, ∂mb) + end + + # Warmup CPU + _ = full_grads_cpu(G) + + # CPU benchmark + t_cpu = @elapsed for _ in 1:5 + full_grads_cpu(G) + end + t_cpu_ms = (t_cpu / 5) * 1000 + + # GPU setup - all models + G_gpu = gdev(G) + onebody_ps_gpu = gdev(onebody_ps) + onebody_st_gpu = gdev(onebody_st) + pair_ps_gpu = gdev(pair_ps) + pair_st_gpu = gdev(pair_st) + et_ps_gpu = gdev(et_ps) + et_st_gpu = gdev(et_st) + + function full_grads_gpu(G_gpu) + ∂onebody = ETM.site_grads(et_onebody, G_gpu, onebody_ps_gpu, onebody_st_gpu) + ∂pair = ETM.site_grads(et_pair, G_gpu, pair_ps_gpu, pair_st_gpu) + ∂mb = ETM.site_grads(et_model, G_gpu, et_ps_gpu, et_st_gpu) + return (∂onebody, ∂pair, ∂mb) + end + + # Warmup GPU + CUDA.@sync full_grads_gpu(G_gpu) + + # GPU benchmark + t_gpu = CUDA.@elapsed for _ in 1:5 + CUDA.@sync full_grads_gpu(G_gpu) + end + t_gpu_ms = (t_gpu / 5) * 1000 + + speedup = t_cpu_ms / t_gpu_ms + + @printf("| %5d | %7d | %8.2f | %8.2f | %10.1fx |\n", + natoms, nedges, t_cpu_ms, t_gpu_ms, speedup) +end + +println() diff --git a/docs/Project.toml b/docs/Project.toml index 1facd1b7c..8b19a1034 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -5,15 +5,19 @@ AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +EquivariantTensors = "5e107534-7145-4f8f-b06f-47a52840c895" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Polynomials4ML = "03c4bcba-a943-47e9-bfa1-b1661fc2974f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" diff --git a/docs/make.jl b/docs/make.jl index ed112b1e3..285a3c7b2 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -28,8 +28,14 @@ Literate.markdown(_tutorial_src * "/dataset_analysis.jl", Literate.markdown(_tutorial_src * "/descriptor.jl", _tutorial_out; documenter = true) -Literate.markdown(_tutorial_src * "/asp.jl", +Literate.markdown(_tutorial_src * "/asp.jl", _tutorial_out; documenter = true) + +# ETACE tutorial lives in examples/ to avoid duplication +_examples_src = joinpath(@__DIR__(), "..", "examples", "etmodels") +Literate.markdown(_examples_src * "/etace_tutorial.jl", + _tutorial_out; documenter = true) + # Literate.markdown(_tutorial_src * "/first_example_model.jl", # _tutorial_out; documenter = true) @@ -70,9 +76,10 @@ makedocs(; "literate_tutorials/basic_julia_workflow.md", "literate_tutorials/smoothness_priors.md", "literate_tutorials/dataset_analysis.md", - "tutorials/scripting.md", + "tutorials/scripting.md", "literate_tutorials/descriptor.md", "literate_tutorials/asp.md", + "literate_tutorials/etace_tutorial.md", ], "Additional Topics" => Any[ "gettingstarted/parallel-fitting.md", diff --git a/docs/plans/et_calculators_plan.md b/docs/plans/et_calculators_plan.md new file mode 100644 index 000000000..8522b0577 --- /dev/null +++ b/docs/plans/et_calculators_plan.md @@ -0,0 +1,231 @@ +# Plan: ETACE Calculator Interface and Training Support + +## Overview + +Create calculator wrappers and training assembly for the new ETACE backend, integrating with EquivariantTensors.jl. + +**Status**: ✅ Core implementation complete. GPU acceleration working. PR #313 under review. + +**Branch**: `jrk/etcalculators` (rebased on `acesuit/co/etback` including `co/etpair` merge) + +**PR**: https://github.com/ACEsuit/ACEpotentials.jl/pull/313 + +--- + +## Progress Summary + +| Phase | Description | Status | +|-------|-------------|--------| +| Phase 1 | ETACEPotential with AtomsCalculators interface | ✅ Complete | +| Phase 2 | WrappedSiteCalculator + StackedCalculator | ✅ Complete | +| Phase 3 | E0Model + PairModel | ✅ Complete (upstream ETOneBody, ETPairModel) | +| Phase 5 | Training assembly functions | ✅ Complete (many-body only) | +| Phase 6 | Full model integration | ✅ Complete | +| Benchmarks | CPU + GPU performance comparison | ✅ Complete | + +### Key Design Decision: Unified Architecture + +**All upstream ETACE-pattern models share the same interface:** + +| Method | ETACE | ETPairModel | ETOneBody | +|--------|-------|-------------|-----------| +| `model(G, ps, st)` | site energies | site energies | site energies | +| `site_grads(model, G, ps, st)` | edge gradients | edge gradients | zero gradients | +| `site_basis(model, G, ps, st)` | basis matrix | basis matrix | empty | +| `site_basis_jacobian(model, G, ps, st)` | (basis, jac) | (basis, jac) | (empty, empty) | + +This enables a **unified `WrappedSiteCalculator`** that works with all three model types directly. + +--- + +## Benchmark Results + +### GPU Benchmarks (Many-Body Only - ETACE) + +**Energy:** +| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup | +|-------|-------|----------|----------|-------------| +| 64 | 2146 | 3.38 | 0.54 | **6.3x** | +| 512 | 17176 | 27.77 | 0.66 | **41.9x** | +| 800 | 26868 | 37.12 | 0.78 | **47.6x** | + +**Forces:** +| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup | +|-------|-------|----------|----------|-------------| +| 64 | 2146 | 46.46 | 14.42 | **3.2x** | +| 512 | 17178 | 104.39 | 15.12 | **6.9x** | +| 800 | 26860 | 289.32 | 16.33 | **17.7x** | + +### GPU Benchmarks (Full Model - E0 + Pair + Many-Body) + +**Energy:** +| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup | +|-------|-------|----------|----------|-------------| +| 64 | 2140 | 3.40 | 0.94 | **3.6x** | +| 512 | 17166 | 31.18 | 0.95 | **32.9x** | +| 800 | 26858 | 45.16 | 1.24 | **36.4x** | + +**Forces:** +| Atoms | Edges | CPU (ms) | GPU (ms) | GPU Speedup | +|-------|-------|----------|----------|-------------| +| 64 | 2134 | 24.05 | 19.34 | **1.2x** | +| 512 | 17178 | ~110 | ~20 | **~5x** | +| 800 | 26860 | ~300 | ~22 | **~14x** | + +### CPU Benchmarks (ETACE vs Classic ACE) + +**Forces (Full Model):** +| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | ETACE Speedup | +|-------|-------|--------------|----------------|---------------| +| 64 | 2146 | 73.6 | 30.5 | **2.4x** | +| 256 | 8596 | 307.7 | 74.4 | **4.1x** | +| 800 | 26886 | 975.0 | 225.6 | **4.3x** | + +**Notes:** +- GPU forces require Polynomials4ML v0.5.8+ (bug fix Dec 29, 2024) +- GPU shows excellent scaling: larger systems see better speedups +- Full model GPU speedups are lower than many-body only due to graph construction overhead +- CPU forces are 2-4x faster with ETACE due to Zygote AD through ET graph + +--- + +## Architecture + +### Current Implementation (Complete) + +``` +StackedCalculator +├── WrappedSiteCalculator{ETOneBody} # One-body reference energies +├── WrappedSiteCalculator{ETPairModel} # Pair potential +└── WrappedSiteCalculator{ETACE} # Many-body ACE +``` + +### Core Components + +**WrappedSiteCalculator{M, PS, ST}** (`et_calculators.jl`) +- Unified wrapper for any ETACE-pattern model +- Provides AtomsCalculators interface (energy, forces, virial) +- Mutable to allow parameter updates during training + +**ETACEPotential** - Type alias for `WrappedSiteCalculator{ETACE, PS, ST}` + +**StackedCalculator{N, C}** (`stackedcalc.jl`) +- Combines multiple calculators by summing contributions +- Uses @generated functions for type-stable loop unrolling + +### Conversion Functions + +```julia +convert2et(model) # Many-body ACE → ETACE +convertpair(model) # Pair potential → ETPairModel +convert2et_full(model, ps, st) # Full model → StackedCalculator +``` + +### Training Assembly (Many-Body Only) + +```julia +length_basis(calc) # Total linear parameters +get_linear_parameters(calc) # Extract θ vector +set_linear_parameters!(calc, θ) # Set θ vector +potential_energy_basis(sys, calc) # Energy design row +energy_forces_virial_basis(sys, calc) # Full EFV design row +``` + +--- + +## Files + +### Source Files +- `src/et_models/et_ace.jl` - ETACE model implementation +- `src/et_models/et_pair.jl` - ETPairModel implementation +- `src/et_models/onebody.jl` - ETOneBody implementation +- `src/et_models/et_calculators.jl` - WrappedSiteCalculator, ETACEPotential, training assembly +- `src/et_models/stackedcalc.jl` - StackedCalculator with @generated +- `src/et_models/convert.jl` - Model conversion utilities +- `src/et_models/et_envbranch.jl` - EnvRBranchL for envelope × radial basis +- `src/et_models/et_models.jl` - Module includes and exports + +### Test Files +- `test/etmodels/test_etbackend.jl` - ETACE tests +- `test/etmodels/test_etpair.jl` - ETPairModel tests +- `test/etmodels/test_etonebody.jl` - ETOneBody tests + +### Benchmark Files +- `benchmark/gpu_benchmark.jl` - GPU energy/forces benchmarks +- `benchmark/benchmark_full_model.jl` - CPU comparison benchmarks + +--- + +## Outstanding Work + +### ~~1. Training Assembly for Pair Model~~ ✅ Complete +**Status**: Implemented in `et_calculators.jl` and `stackedcalc.jl` + +**What was done**: +- Added `ETPairPotential` type alias with full training assembly support +- Added `ETOneBodyPotential` type alias (returns empty arrays - no learnable params) +- Implemented `length_basis`, `energy_forces_virial_basis`, `potential_energy_basis`, `get_linear_parameters`, `set_linear_parameters!` for all calculator types +- Extended `StackedCalculator` to concatenate basis functions from all components +- Added `ACEfit.basis_size` dispatch for all calculator types + +### ~~2. ACEfit.assemble Dispatch Integration~~ ✅ Complete +**Status**: Works out-of-the-box after extending `length_basis` and `energy_forces_virial_basis` + +**What was done**: +- Added empty function declarations in `models/models.jl` for `length_basis`, `energy_forces_virial_basis`, `potential_energy_basis` +- ETModels now imports and extends these functions +- `ACEfit.feature_matrix(d::AtomsData, calc)` works with ETACE calculators +- `ACEfit.assemble(data, calc)` works with `StackedCalculator` + +### 3. Committee Support +**Priority**: Low +**Description**: Extend committee/uncertainty quantification to work with StackedCalculator. + +### 4. Basis Index Design Discussion +**Priority**: Needs Discussion +**Description**: Moderator raised concern about basis indices: + +> "I realized I made a mistake in the design of the basis interface. I'm returning the site energy basis but for each center-atom species, the basis occupies the same indices. We need to perform a transformation so that bases for different species occupy separate indices." + +**Current Implementation**: Species separation is handled at the **calculator level** in `energy_forces_virial_basis` using `p = (s-1) * nbasis + k`. Each species gets separate parameter indices. + +**Options**: +1. Keep current approach (calculator-level separation) +2. Move to site potential model level +3. Handle at WrappedSiteCalculator level + +Moderator wants discussion before making changes. + +--- + +## Dependencies + +- EquivariantTensors.jl >= 0.4.3 +- Polynomials4ML.jl >= 0.5.8 (for GPU forces) +- LuxCUDA (for GPU support, test dependency) + +--- + +## Test Status + +All tests pass: **946 passed, 1 broken** (known Julia 1.12 hash ordering issue) + +```bash +# Run ET model tests +julia --project=test -e 'using Pkg; Pkg.test("ACEpotentials"; test_args=["etmodels"])' + +# Run GPU benchmark +julia --project=test benchmark/gpu_benchmark.jl +``` + +--- + +## Notes + +- Virial formula: `V = -∑ ∂E/∂𝐫ij ⊗ 𝐫ij` +- GPU time scales sub-linearly with system size +- Forces speedup (CPU) larger than energy speedup due to Zygote AD efficiency +- StackedCalculator uses @generated functions for zero-overhead composition +- Upstream `ETOneBody` stores E0s in state (`st.E0s`) for float type flexibility +- All models use `VState` for edge gradients in `site_grads()` return +- `ETOneBody.site_grads()` returns `fill(VState(), length(edges))` for type stability (empty VState acts as additive identity) diff --git a/docs/src/all_exported.md b/docs/src/all_exported.md index 9e250d258..16dacd79e 100644 --- a/docs/src/all_exported.md +++ b/docs/src/all_exported.md @@ -3,13 +3,13 @@ ### Exported ```@autodocs -Modules = [ACEpotentials, ACEpotentials.Models, ACEpotentials.ACE1compat] +Modules = [ACEpotentials, ACEpotentials.Models, ACEpotentials.ACE1compat, ACEpotentials.ETModels] Private = false -``` +``` ### Not exported ```@autodocs -Modules = [ACEpotentials, ACEpotentials.Models, ACEpotentials.ACE1compat] +Modules = [ACEpotentials, ACEpotentials.Models, ACEpotentials.ACE1compat, ACEpotentials.ETModels] Public = false ``` diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md index 98ee4560c..23da5738e 100644 --- a/docs/src/tutorials/index.md +++ b/docs/src/tutorials/index.md @@ -1,11 +1,11 @@ # Tutorials Overview -* [Basic Julia Workflow](../literate_tutorials/basic_julia_workflow.md) : minimal example to fit a potential to an existing dataset using a Julia script +* [Basic Julia Workflow](../literate_tutorials/basic_julia_workflow.md) : minimal example to fit a potential to an existing dataset using a Julia script * [Basic Shell Workflow](scripting.md) : basic workflow for fitting via the command line * [Smoothness Priors](../literate_tutorials/smoothness_priors.md) : brief introduction to smoothness priors * [Basic Dataset Analysis](../literate_tutorials/dataset_analysis.md) : basic techniques to visualize training datasets and correlate such observations to the choice of geometric priors -* [Descriptors](../literate_tutorials/descriptor.md) : `ACEpotentials` can be used as descriptors of atomic environments or structures, which is described here. +* [Descriptors](../literate_tutorials/descriptor.md) : `ACEpotentials` can be used as descriptors of atomic environments or structures, which is described here. * [Sparse Solvers](../literate_tutorials/asp.md) : basic tutorial on using the `ASP` and `OMP` solvers. - +* [ETACE Models](../literate_tutorials/etace_tutorial.md) : using the EquivariantTensors backend for ACE models, including conversion from standard ACE models and creating ETACE models from scratch. diff --git a/examples/etmodels/etace_tutorial.jl b/examples/etmodels/etace_tutorial.jl new file mode 100644 index 000000000..dd5f31054 --- /dev/null +++ b/examples/etmodels/etace_tutorial.jl @@ -0,0 +1,348 @@ +# # ETACE Models Tutorial +# +# This tutorial demonstrates how to use the EquivariantTensors (ET) backend +# for ACE models in ACEpotentials.jl. The ET backend provides: +# - Graph-based evaluation (edge-centric computation) +# - Automatic differentiation via Zygote +# - GPU-ready architecture via KernelAbstractions +# - Lux.jl layer integration +# +# We cover two approaches: +# 1. **Converting from an existing ACE model** - The recommended approach +# 2. **Creating an ETACE model from scratch** - For advanced users +# + +## Load required packages +using ACEpotentials, StaticArrays, Lux, AtomsBase, AtomsBuilder, Unitful +using AtomsCalculators, Random, LinearAlgebra + +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels +import EquivariantTensors as ET +import Polynomials4ML as P4ML + +rng = Random.MersenneTwister(1234) + +# ============================================================================= +# Part 1: Converting from an Existing ACE Model (Recommended) +# ============================================================================= +# +# The simplest way to get an ETACE model is to convert from a standard ACE model. +# This approach ensures consistency with the familiar ACE model construction API. + +## Define model hyperparameters +elements = (:Si, :O) +order = 3 # correlation order (body-order = order + 1) +max_level = 10 # total polynomial degree +maxl = 6 # maximum angular momentum +rcut = 5.5 # cutoff radius in Angstrom + +## Create the standard ACE model +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = rcut)).(rin0cuts) + +## Note: pair_learnable=true is required for ET conversion +## (default uses splines which aren't yet supported by convert2et) +model = M.ace_model(; + elements = elements, + order = order, + Ytype = :solid, + level = M.TotalDegree(), + max_level = max_level, + maxl = maxl, + pair_maxn = max_level, + rin0cuts = rin0cuts, + E0s = Dict(:Si => -0.846, :O => -1.023), # reference energies + pair_learnable = true # required for ET conversion +) + +## Initialize parameters with Lux +ps, st = Lux.setup(rng, model) + +@info "Standard ACE model created" +@info " Number of basis functions: $(M.length_basis(model))" + +# ----------------------------------------------------------------------------- +# Method A: Convert full model (E0 + Pair + Many-body) to StackedCalculator +# ----------------------------------------------------------------------------- + +## convert2et_full creates a StackedCalculator combining: +## - ETOneBody (reference energies per species) +## - ETPairModel (pair potential) +## - ETACE (many-body ACE potential) + +et_calc_full = ETM.convert2et_full(model, ps, st; rng=rng) + +@info "Full conversion to StackedCalculator" +@info " Contains: ETOneBody + ETPairPotential + ETACEPotential" +@info " Total linear parameters: $(ETM.length_basis(et_calc_full))" + +# ----------------------------------------------------------------------------- +# Method B: Convert only the many-body ACE component +# ----------------------------------------------------------------------------- + +## convert2et creates just the ETACE model (many-body only, no E0 or pair) +et_ace = ETM.convert2et(model) +et_ace_ps, et_ace_st = Lux.setup(rng, et_ace) + +## Copy parameters from the original model +ETM.copy_ace_params!(et_ace_ps, ps, model) + +## Wrap in calculator for AtomsCalculators interface +et_ace_calc = ETM.ETACEPotential(et_ace, et_ace_ps, et_ace_st, rcut) + +@info "Many-body only conversion" +@info " ETACE basis size: $(ETM.length_basis(et_ace_calc))" + +# ----------------------------------------------------------------------------- +# Method C: Convert only the pair potential +# ----------------------------------------------------------------------------- + +## convertpair creates an ETPairModel +et_pair = ETM.convertpair(model) +et_pair_ps, et_pair_st = Lux.setup(rng, et_pair) + +## Copy parameters from the original model +ETM.copy_pair_params!(et_pair_ps, ps, model) + +## Wrap in calculator +et_pair_calc = ETM.ETPairPotential(et_pair, et_pair_ps, et_pair_st, rcut) + +@info "Pair potential only conversion" +@info " ETPairModel basis size: $(ETM.length_basis(et_pair_calc))" + + +# ============================================================================= +# Part 2: Using ETACE Calculators +# ============================================================================= + +## Create a test system +sys = AtomsBuilder.bulk(:Si) * (2, 2, 1) +rattle!(sys, 0.1u"Å") +AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + +@info "Test system: $(length(sys)) atoms" + +## Evaluate energy, forces, virial using AtomsCalculators interface +E = AtomsCalculators.potential_energy(sys, et_calc_full) +F = AtomsCalculators.forces(sys, et_calc_full) +V = AtomsCalculators.virial(sys, et_calc_full) + +@info "Energy evaluation with full ETACE calculator" +@info " Energy: $E" +@info " Max force magnitude: $(maximum(norm.(F)))" + +## Combined evaluation (more efficient) +efv = AtomsCalculators.energy_forces_virial(sys, et_calc_full) +@info " Combined EFV evaluation successful" + + +# ============================================================================= +# Part 3: Training Assembly (for Linear Fitting) +# ============================================================================= +# +# The ETACE calculators support training assembly functions for ACEfit integration. +# These compute the design matrix rows for linear least squares fitting. + +## Energy-only basis evaluation (fastest) +E_basis = ETM.potential_energy_basis(sys, et_ace_calc) +@info "Energy basis: $(length(E_basis)) components" + +## Full energy, forces, virial basis +efv_basis = ETM.energy_forces_virial_basis(sys, et_ace_calc) +@info "EFV basis shapes:" +@info " Energy: $(size(efv_basis.energy))" +@info " Forces: $(size(efv_basis.forces))" +@info " Virial: $(size(efv_basis.virial))" + +## Get/set linear parameters +params = ETM.get_linear_parameters(et_ace_calc) +@info "Linear parameters: $(length(params)) values" + +## Parameters can be updated for fitting: +## ETM.set_linear_parameters!(et_ace_calc, new_params) + + +# ============================================================================= +# Part 4: Creating an ETACE Model from Scratch (Advanced) +# ============================================================================= +# +# For advanced users who want direct control over the model architecture. +# This requires understanding the EquivariantTensors.jl API. + +## Define model parameters +scratch_elements = [:Si, :O] +scratch_maxn = 6 # number of radial basis functions +scratch_maxl = 4 # maximum angular momentum +scratch_order = 2 # correlation order +scratch_rcut = 5.5 # cutoff radius + +## Species information +zlist = ChemicalSpecies.(scratch_elements) +NZ = length(zlist) + +# ----------------------------------------------------------------------------- +# Build the radial embedding (Rnl) +# ----------------------------------------------------------------------------- + +## Radial specification (n, l pairs) +Rnl_spec = [(n=n, l=l) for n in 1:scratch_maxn for l in 0:scratch_maxl] + +## Distance transform: r -> transformed coordinate y +## Using standard Agnesi transform parameters +f_trans = let rcut = scratch_rcut + (x, st) -> begin + r = norm(x.𝐫) + ## Simple polynomial transform (normalized to [-1, 1]) + y = 1 - 2 * r / rcut + return y + end +end +trans = ET.NTtransformST(f_trans, NamedTuple()) + +## Envelope function: smooth cutoff +f_env = y -> (1 - y^2)^2 # quartic envelope + +## Polynomial basis (Chebyshev) +polys = P4ML.ChebBasis(scratch_maxn) +Penv = P4ML.wrapped_basis(Lux.BranchLayer( + polys, + Lux.WrappedFunction(y -> f_env.(y)), + fusion = Lux.WrappedFunction(Pe -> Pe[2] .* Pe[1]) +)) + +## Species-pair selector for radial weights +selector_ij = let zlist = tuple(zlist...) + xij -> ET.catcat2idx(zlist, xij.z0, xij.z1) +end + +## Linear layer: P(yij) -> W[(Zi, Zj)] * P(yij) +linl = ET.SelectLinL(scratch_maxn, length(Rnl_spec), NZ^2, selector_ij) + +## Complete radial embedding +rbasis = ET.EmbedDP(trans, Penv, linl) +rembed = ET.EdgeEmbed(rbasis) + +# ----------------------------------------------------------------------------- +# Build the angular embedding (Ylm) +# ----------------------------------------------------------------------------- + +## Spherical harmonics basis +ylm_basis = P4ML.real_sphericalharmonics(scratch_maxl) +Ylm_spec = P4ML.natural_indices(ylm_basis) + +## Angular embedding: edge direction -> spherical harmonics +ybasis = ET.EmbedDP( + ET.NTtransformST((x, st) -> x.𝐫, NamedTuple()), + ylm_basis +) +yembed = ET.EdgeEmbed(ybasis) + +# ----------------------------------------------------------------------------- +# Build the many-body basis (sparse ACE) +# ----------------------------------------------------------------------------- + +## Define the many-body specification +## This specifies which (n,l) combinations appear in each correlation +## For simplicity, use all 1-correlations up to given degree +mb_spec = [[(n=n, l=l)] for n in 1:scratch_maxn for l in 0:scratch_maxl] + +## Create sparse equivariant tensor (ACE basis) +mb_basis = ET.sparse_equivariant_tensor( + L = 0, # scalar (invariant) output + mb_spec = mb_spec, + Rnl_spec = Rnl_spec, + Ylm_spec = Ylm_spec, + basis = real # real-valued basis +) + +# ----------------------------------------------------------------------------- +# Build the readout layer +# ----------------------------------------------------------------------------- + +## Species selector for readout +selector_i = let zlist = zlist + x -> ET.cat2idx(zlist, x.z) +end + +## Readout: basis values -> site energies +readout = ET.SelectLinL( + mb_basis.lens[1], # input dimension (basis length) + 1, # output dimension (site energy) + NZ, # number of species categories + selector_i +) + +# ----------------------------------------------------------------------------- +# Assemble the ETACE model +# ----------------------------------------------------------------------------- + +scratch_etace = ETM.ETACE(rembed, yembed, mb_basis, readout) + +## Initialize with Lux +scratch_ps, scratch_st = Lux.setup(rng, scratch_etace) + +@info "ETACE model created from scratch" +@info " Radial basis size: $(length(Rnl_spec))" +@info " Angular basis size: $(length(Ylm_spec))" +@info " Many-body basis size: $(mb_basis.lens[1])" + +## Wrap in calculator +scratch_calc = ETM.ETACEPotential(scratch_etace, scratch_ps, scratch_st, scratch_rcut) + +## Test evaluation +E_scratch = AtomsCalculators.potential_energy(sys, scratch_calc) +@info "Scratch model energy: $E_scratch" + + +# ============================================================================= +# Part 5: Creating One-Body and Pair Models from Scratch +# ============================================================================= + +# ----------------------------------------------------------------------------- +# ETOneBody: Reference energies +# ----------------------------------------------------------------------------- + +## Define reference energies per species +E0_dict = Dict(ChemicalSpecies(:Si) => -0.846, + ChemicalSpecies(:O) => -1.023) + +## Category function extracts species from atom state +catfun = x -> x.z # x.z is the ChemicalSpecies + +## Create one-body model +et_onebody = ETM.one_body(E0_dict, catfun) +_, onebody_st = Lux.setup(rng, et_onebody) + +## Wrap in calculator (uses small cutoff since no neighbors needed) +onebody_calc = ETM.ETOneBodyPotential(et_onebody, nothing, onebody_st, 3.0) + +@info "ETOneBody model created" +@info " Reference energies: $E0_dict" + +E_onebody = AtomsCalculators.potential_energy(sys, onebody_calc) +@info " One-body energy for test system: $E_onebody" + + +# ============================================================================= +# Part 6: Combining Models with StackedCalculator +# ============================================================================= +# +# StackedCalculator combines multiple calculators by summing their contributions. + +## Stack our from-scratch models +combined_calc = ETM.StackedCalculator((onebody_calc, scratch_calc)) + +@info "StackedCalculator created" +@info " Components: ETOneBody + ETACE" +@info " Total basis size: $(ETM.length_basis(combined_calc))" + +## Evaluate combined model +E_combined = AtomsCalculators.potential_energy(sys, combined_calc) +@info " Combined energy: $E_combined" + +## Training assembly works on StackedCalculator too +efv_combined = ETM.energy_forces_virial_basis(sys, combined_calc) +@info " Combined EFV basis shapes: E=$(size(efv_combined.energy)), F=$(size(efv_combined.forces))" + +@info "Tutorial complete!" diff --git a/examples/modelbuilding/README.md b/examples/modelbuilding/README.md new file mode 100644 index 000000000..500ea672e --- /dev/null +++ b/examples/modelbuilding/README.md @@ -0,0 +1 @@ +lux_model.jl shows how to build a Lux model out of standard layers instead of the "canned" layers provided. This has disadvantages for both performance and flexibility (e.g. unclear how to get efficient jacobians) but can be useful for experimenting. The file is leftover from an extensive development and testing period, likely out of date, and needs to be cleaned up. \ No newline at end of file diff --git a/examples/modelbuilding/lux_model.jl b/examples/modelbuilding/lux_model.jl new file mode 100644 index 000000000..f4f68ee77 --- /dev/null +++ b/examples/modelbuilding/lux_model.jl @@ -0,0 +1,460 @@ +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..")) +using TestEnv; TestEnv.activate(); +Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "EquivariantTensors.jl")) +Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "Polynomials4ML.jl")) +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "DecoratedParticles")) + +## + +using ACEpotentials +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels + +# build a pure Lux Rnl basis compatible with LearnableRnlrzz +import EquivariantTensors as ET +import Polynomials4ML as P4ML +import DecoratedParticles as DP + +using StaticArrays, Lux +using AtomsBase, AtomsBuilder, Unitful, AtomsCalculators + +using Random, LuxCore, Test, LinearAlgebra, ACEbase +using Polynomials4ML.Testing: print_tf, println_slim +rng = Random.MersenneTwister(1234) + +Random.seed!(1234) + +## + +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 10 +order = 3 +maxl = 6 + +# modify rin0cuts to have same cutoff for all elements +# TODO: there is currently a bug with variable cutoffs +# (?is there? The radials seem fine? check again) +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) + + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal) + +ps, st = Lux.setup(rng, model) + +# Missing issues: +# Vref = 0 => this will not be tested +# pair potential will also not be tested + +# kill the pair basis for now +for s in model.pairbasis.splines + s.itp.itp.coefs[:] *= 0 +end + +## +# build the Rnl basis +# here we build it from the model.rbasis, so we can exactly match it +# but in the final implementation we will have to create it directly + +rbasis = model.rbasis +et_i2z = AtomsBase.ChemicalSpecies.(rbasis._i2z) +# et_rbasis = M._convert_Rnl_learnable(rbasis; zlist = et_i2z, +# rfun = x -> norm(x.𝐫) ) +et_rbasis = ETM._convert_Rnl_learnable(rbasis) + +# TODO: this is cheating, but this set can probably be generated quite +# easily as part of the construction of et_rbasis. +et_rspec = rbasis.spec + +## +# build the ybasis + +et_ybasis = ET.EmbedDP( ET.NTtransformST( (x, st) -> x.𝐫, NamedTuple()), + model.ybasis ) +et_yspec = P4ML.natural_indices(et_ybasis.basis) + +## +# combining the Rnl and Ylm basis we can build an embedding layer +et_embed = BranchLayer(; + Rnl = ET.EdgeEmbed( et_rbasis ), + Ylm = ET.EdgeEmbed( et_ybasis ) ) + +## +# now build the linear ACE layer + +# Convert AA_spec from (n,l,m) format to (n,l) format for mb_spec +AA_spec = model.tensor.meta["𝔸spec"] +et_mb_spec = unique([[(n=b.n, l=b.l) for b in bb] for bb in AA_spec]) + +et_mb_basis = ET.sparse_equivariant_tensor( + L = 0, # Invariant (scalar) output only + mb_spec = et_mb_spec, + Rnl_spec = et_rspec, + Ylm_spec = et_yspec, + basis = real + ) + +# et_acel = ET.SparseACElayer(et_mb_basis, (1,)) + +# ------------------------------------------------ +# readout layer : need to select which linear output to +# use based on the center atom species + +# CO: doing it this way is type unstable and causes problems in +# the GPU kernel generation. +# __zi = let zlist = (_i2z = et_i2z, ) +# x -> M._z2i(zlist, x.s) +# end +# +# et_readout = ET.SelectLinL( +# et_mb_basis.lens[1], # input dim +# 1, # output dim +# length(et_i2z), # num species +# __zi ) + + +et_readout_2 = let zlist = et_i2z + __zi = x -> ET.cat2idx(zlist, x.z) + ET.SelectLinL( + et_mb_basis.lens[1], # input dim + 1, # output dim + length(et_i2z), # num species + __zi ) +end + + +# finally build the full model from the two layers +# +# TODO: there is a huge problem here; the read-out layer needs to know +# about the center species; need to figure out how to pass that information +# through to the ace layer +# + +__sz(::Any) = nothing +__sz(A::AbstractArray) = size(A) +__sz(x::Tuple) = __sz.(x) +dbglayer(msg = ""; show=false) = WrappedFunction(x -> + begin + println("$msg : ", typeof(x), ", ", __sz(x)) + if show; display(x); end + return x + end ) + +et_basis = Lux.Chain(; + embed = et_embed, # embedding layer + ace = et_mb_basis, # ACE layer -> basis + unwrp = WrappedFunction(x -> x[1]), # unwrap the tuple + ) + +et_model = Lux.Chain( + L1 = Lux.BranchLayer(; + basis = et_basis, + nodes = WrappedFunction(G -> G.node_data), # pass node data through + ), + Ei = et_readout_2, + E = WrappedFunction(sum), # sum up to get a total energy + ) +et_ps, et_st = LuxCore.setup(MersenneTwister(1234), et_model) + +## +# generate a second ET model based on the implementation in ETM +et_model_2 = ETM.convert2et(model) +et_ps_2, et_st_2 = LuxCore.setup(MersenneTwister(1234), et_model_2) + +## +# fixup all the parameters to make sure they match +# the basis ordering appears to be identical, but it is not clear it really +# is because meta["mb_spec"] only gives the original ordering before basis +# construction ... +nnll = M.get_nnll_spec(model.tensor) +et_nnll = et_mb_basis.meta["mb_spec"] +et_nnll_2 = et_model_2.basis.meta["mb_spec"] +@show nnll == et_nnll == et_nnll_2 + +# but this is also identical ... +@show ( model.tensor.A2Bmaps[1] + == et_mb_basis.A2Bmaps[1] + == et_model_2.basis.A2Bmaps[1] ) + +# radial basis parameters for et_model +et_ps.L1.basis.embed.Rnl.post.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1] +et_ps.L1.basis.embed.Rnl.post.W[:, :, 2] = ps.rbasis.Wnlq[:, :, 1, 2] +et_ps.L1.basis.embed.Rnl.post.W[:, :, 3] = ps.rbasis.Wnlq[:, :, 2, 1] +et_ps.L1.basis.embed.Rnl.post.W[:, :, 4] = ps.rbasis.Wnlq[:, :, 2, 2] + +# radial basis parameters for et_model_2 +et_ps_2.rembed.post.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1] +et_ps_2.rembed.post.W[:, :, 2] = ps.rbasis.Wnlq[:, :, 1, 2] +et_ps_2.rembed.post.W[:, :, 3] = ps.rbasis.Wnlq[:, :, 2, 1] +et_ps_2.rembed.post.W[:, :, 4] = ps.rbasis.Wnlq[:, :, 2, 2] + +# many-body basis parameters; because the readout layer doesn't know about +# species yet we take a single parameter set; this needs to be fixed asap. +# ps.WB[:, 2] .= ps.WB[:, 1] + +et_ps.Ei.W[1, :, 1] .= ps.WB[:, 1] +et_ps.Ei.W[1, :, 2] .= ps.WB[:, 2] + +et_ps_2.readout.W[1, :, 1] .= ps.WB[:, 1] +et_ps_2.readout.W[1, :, 2] .= ps.WB[:, 2] + + +## + +# wrap the old ACE model into a calculator +calc_model = ACEpotentials.ACEPotential(model, ps, st) + +# we will also need to get the cutoff radius which we didn't track +# (Another TODO!!!) +rcut = maximum(a.rcut for a in model.pairbasis.rin0cuts) + +function rand_struct() + sys = AtomsBuilder.bulk(:Si) * (2,2,1) + rattle!(sys, 0.2u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + return sys +end + +function energy_new(sys, et_model) + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + return et_model(G, et_ps, et_st)[1] +end + +function energy_new_2(sys, et_model) + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + Ei, _ = et_model_2(G, et_ps_2, et_st_2) + return sum(Ei) +end + +## + +for ntest = 1:30 + sys = rand_struct() + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + E1 = AtomsCalculators.potential_energy(sys, calc_model) + E2 = energy_new(sys, et_model) + E3 = energy_new_2(sys, et_model_2) + print_tf( @test abs(ustrip(E1) - ustrip(E2)) < 1e-6 ) + print_tf( @test abs(ustrip(E1) - ustrip(E3)) < 1e-6 ) +end +println() + +## +# +# Zygote gradient +# +using Zygote, ForwardDiff + +sys = rand_struct() +G = ET.Atoms.interaction_graph(sys, rcut * u"Å") +∂G1 = Zygote.gradient(G -> et_model(G, et_ps, et_st)[1], G)[1] +∂G2a = Zygote.gradient(G -> sum(et_model_2(G, et_ps_2, et_st_2)[1]), G)[1] +∂G2b = ETM.site_grads(et_model_2, G, et_ps_2, et_st_2) + +@info("confirm consistency of three gradients") +println(@test all(∂G1.edge_data .≈ ∂G2a.edge_data .≈ ∂G2b.edge_data)) + +## + +# Rnl, _ = et_model_2.rembed(G, et_ps_2.rembed, et_st_2.rembed) + +## +# test gradient against ForwardDiff + +function grad_fd(G, model, ps, st) + function _replace_edges(X, Rmat) + Rsvec = [ SVector{3}(Rmat[:, i]) for i in 1:size(Rmat, 2) ] + new_edgedata = [ DP.PState(𝐫 = 𝐫, z0 = x.z0, z1 = x.z1, 𝐒 = x.𝐒) + for (𝐫, x) in zip(Rsvec, G.edge_data) ] + return ET.ETGraph( X.ii, X.jj, X.first, + X.node_data, new_edgedata, X.graph_data, + X.maxneigs ) + end + + function _energy(Rmat) + G_new = _replace_edges(G, Rmat) + return sum(model(G_new, ps, st)[1]) + end + + Rsvec = [ x.𝐫 for x in G.edge_data ] + Rmat = reinterpret(reshape, eltype(Rsvec[1]), Rsvec) + ∇E_fd = ForwardDiff.gradient(_energy, Rmat) + ∇E_svec = [ SVector{3}(∇E_fd[:, i]) for i in 1:size(∇E_fd, 2) ] + ∇E_edges = [ DP.VState(; 𝐫 = 𝐫) for 𝐫 in ∇E_svec ] + return ET.ETGraph( G.ii, G.jj, G.first, + G.node_data, ∇E_edges, G.graph_data, + G.maxneigs ) +end + +@info("confirm consistency of gradients with ForwardDiff") + +∇E_fd = grad_fd(G, et_model_2, et_ps_2, et_st_2) +println(@test all(∇E_fd.edge_data .≈ ∂G2b.edge_data)) + +## +# +# Jacobians cannot work yet - these need manual work or more infrastructure +# + +# sys = rand_struct() +G = ET.Atoms.interaction_graph(sys, rcut * u"Å") +nnodes = length(G.node_data) +iZ = et_model_2.readout.selector.(G.node_data) +WW = et_ps_2.readout.W + +𝔹1 = ETM.site_basis(et_model_2, G, et_ps_2, et_st_2) +𝔹2, ∂𝔹2 = ETM.site_basis_jacobian(et_model_2, G, et_ps_2, et_st_2) + +## + +@info("confirm correctness of site basis") + +println_slim(@test 𝔹1 ≈ 𝔹2) +Ei_a = [ dot(𝔹2[i, :], WW[1, :, iZ[i]]) for (i, iz) in enumerate(iZ) ] +Ei_b = et_model_2(G, et_ps_2, et_st_2)[1][:] +println(@test Ei_a ≈ Ei_b) + +## + +@info("Confirm correctness of Jacobian against gradient") +# compute the gradient from the jacobian by hand +# size(𝔹2) = (num_nodes, basis_len) +# size(∂𝔹2) = (num_edges, num_nodes, basislen) + +∇Ei2 = reduce( hcat, ∂𝔹2[:, i, :] * WW[1, :, iZ[i]] + for (i, iz) in enumerate(iZ) ) +∇Ei3 = reshape(∇Ei2, size(∇Ei2)..., 1) +∇E_𝔹_edges = ET.rev_reshape_embedding(∇Ei3, G)[:] +println(@test all(∇E_𝔹_edges .≈ ∂G2b.edge_data)) + + +## +# +# demo GPU evaluation +# + +@info("Checking GPU evaluation with Metal.jl") + +# TODO: replace Metal with generic GPU test +using Metal +dev = Metal.mtl + +sys = rand_struct() +G = ET.Atoms.interaction_graph(sys, rcut * u"Å") +G_32 = ET.float32(G) + +# move all data to the device +G_32_dev = dev(G_32) +ps_dev = dev(ET.float32(et_ps)) +st_dev = dev(ET.float32(et_st)) +ps_dev_2 = dev(ET.float32(et_ps_2)) +st_dev_2 = dev(ET.float32(et_st_2)) +ps_32_2 = ET.float32(et_ps_2) +st_32_2 = ET.float32(et_st_2) + +E1 = ustrip(AtomsCalculators.potential_energy(sys, calc_model)) +E2 = energy_new(sys, et_model) +E3 = et_model(G_32_dev, ps_dev, st_dev)[1] +E4 = sum(et_model_2(G_32_dev, ps_dev_2, st_dev_2)[1]) + +println_slim( @test abs(E1 - E2) < 1e-5 ) +println_slim( @test abs(E1 - E3) / (abs(E1) + abs(E3) + 1e-7) < 1e-5 ) +println_slim( @test abs(E1 - E4) / (abs(E1) + abs(E4) + 1e-7) < 1e-5 ) + +## +# gradients on GPU +# currently failing because somehow the transform is still +# accessing some Float64 values somewhere .... + +@info("Check Evaluation of gradient on GPU") +g1 = ETM.site_grads(et_model_2, G_32, ps_32_2, st_32_2) +g2_dev = ETM.site_grads(et_model_2, G_32_dev, ps_dev_2, st_dev_2) +∇1 = g1.edge_data +∇2 = Array(g2_dev.edge_data) +println_slim( @test all(∇1 .≈ ∇2) ) + +## + +@info("Basis evaluation on GPU") + +𝔹1 = ETM.site_basis(et_model_2, G_32, ps_32_2, st_32_2) +𝔹2_dev = ETM.site_basis(et_model_2, G_32_dev, ps_dev_2, st_dev_2) +𝔹2 = Array(𝔹2_dev) +println_slim( @test 𝔹1 ≈ 𝔹2 ) + + +@info("Basis jacobian evaluation on GPU") +𝔹1, ∂𝔹1 = ETM.site_basis_jacobian(et_model_2, G_32, ps_32_2, st_32_2) + +try + 𝔹2_dev, ∂𝔹2_dev = ETM.site_basis_jacobian(et_model_2, G_32_dev, ps_dev_2, st_dev_2) +catch + @warn("Basis jacobian evaluation on GPU still failing") +end + + +## +# leftover debugging snippets +# +# This was a huge problem, namely to get the differentiation through +# the radial transform to behave nicely. At the moment this is solved +# via a workaround by implementing 2 x NTtransformST _pb_ed functions. +# this should be revisited. +# +# ET.evaluate_ed(et_model_2.rembed, G, +# et_ps_2.rembed, et_st_2.rembed) + +# (R, dR), _ = ET.evaluate_ed(et_model_2.rembed, G_32, +# ps_32_2.rembed, st_32_2.rembed) + + +# ET.evaluate_ed(et_model_2.rembed, G_32_dev, +# ps_dev_2.rembed, st_dev_2.rembed) + + +## + +# This one is also quote interesting and could be posted on +# the julia discourse forum. Using reinterpret works +# fine on CPU but fail on GPU. +# + +#= +using DecoratedParticles, ForwardDiff, StaticArrays +using Metal +import EquivariantTensors as ET + +function graddp(f, v) + # _dp2svec(v) = v.𝐫 + # _svec2dp(sv) = VState(𝐫 = sv) + _dp2svec(v) = reinterpret(SVector{3, Float32}, v) + _dp2svec(v) = ET.DiffNT._nt2svec(v) + _svec2dp(sv) = ET.DiffNT._svec2nt(sv, v) + _fvec = _v -> f(_svec2dp(_v)) + g = ForwardDiff.gradient(_fvec, _dp2svec(v)) + return _svec2dp(g) + # return g +end + +function gradX(X, P) + function _gradx(x, p) + v0 = zero(vstate_type(x)) + graddp(_v -> sum((x + _v).𝐫 .* p), v0) + end + return _gradx.(X, P) +end + +X = [ VState(𝐫 = randn(SVector{3, Float32})) for _=1:10 ] +P = randn(SVector{3, Float32}, 10) +vsP = [ VState(𝐫 = p) for p in P ] +gradX(X, P) + +X_dev = mtl(X) +P_dev = mtl(P) +Array(gradX(X_dev, P_dev)) +=# \ No newline at end of file diff --git a/src/ACEpotentials.jl b/src/ACEpotentials.jl index 9688995ca..3f7ddb4eb 100644 --- a/src/ACEpotentials.jl +++ b/src/ACEpotentials.jl @@ -12,6 +12,10 @@ include("defaults.jl") include("models/models.jl") include("ace1_compat.jl") +# New ET backend based models +include("et_models/et_models.jl") + + # Fitting include("atoms_data.jl") include("fit_model.jl") @@ -37,7 +41,6 @@ import ACEpotentials.Models: algebraic_smoothness_prior, exp_smoothness_prior, gaussian_smoothness_prior, set_parameters!, - fast_evaluator, @committee, set_committee! import JSON diff --git a/src/ace1_compat.jl b/src/ace1_compat.jl index b6ddb8394..e75a87fb8 100644 --- a/src/ace1_compat.jl +++ b/src/ace1_compat.jl @@ -155,7 +155,7 @@ function _get_all_rcut(kwargs; _rcut = kwargs[:rcut]) end -function _rin0cuts_rcut(zlist, cutoffs::Dict, kwargs = nothing) +function _rin0cuts_rcut(zlist, cutoffs::AbstractDict, kwargs = nothing) function _get_r0(zi, zj) if kwargs == nothing return DefaultHypers.bond_len(zi, zj) diff --git a/src/analysis/dataset_analysis.jl b/src/analysis/dataset_analysis.jl index 0b1e9fe9d..4c4183124 100644 --- a/src/analysis/dataset_analysis.jl +++ b/src/analysis/dataset_analysis.jl @@ -6,7 +6,7 @@ using StaticArrays using AtomsBuilder using LinearAlgebra: norm, dot -function copy_zz_sym!(D::Dict) +function copy_zz_sym!(D::AbstractDict) _zz = collect(keys(D)) for z12 in _zz sym12 = Symbol.( ChemicalSpecies.(z12) ) diff --git a/src/analysis/potential_analysis.jl b/src/analysis/potential_analysis.jl index 21cbd78f2..82bc026a5 100644 --- a/src/analysis/potential_analysis.jl +++ b/src/analysis/potential_analysis.jl @@ -90,7 +90,7 @@ function trimer_energy(IP, r1, r2, θ, z0, z1, z2) - atom_energy(IP, z2) ) end -# function copy_zz_sym!(D::Dict) +# function copy_zz_sym!(D::AbstractDict) # _zz = collect(keys(D)) # for z12 in _zz # sym12 = Symbol.( ChemicalSpecies.(z12) ) diff --git a/src/atoms_data.jl b/src/atoms_data.jl index 0ec19bbeb..653b9de4b 100644 --- a/src/atoms_data.jl +++ b/src/atoms_data.jl @@ -351,12 +351,12 @@ function compute_errors(data::AbstractArray{AtomsData}, model; end -function print_errors_tables(config_errors::Dict) +function print_errors_tables(config_errors::AbstractDict) print_rmse_table(config_errors) print_mae_table(config_errors) end -function _print_err_tbl(D::Dict) +function _print_err_tbl(D::AbstractDict) header = ["Type", "E [meV]", "F [eV/A]", "V [meV]"] config_types = setdiff(collect(keys(D)), ["set",]) push!(config_types, "set") @@ -374,12 +374,12 @@ function _print_err_tbl(D::Dict) end -function print_rmse_table(config_errors::Dict; header=true) +function print_rmse_table(config_errors::AbstractDict; header=true) if header; (@info "RMSE Table"); end _print_err_tbl(config_errors["rmse"]) end -function print_mae_table(config_errors::Dict; header=true) +function print_mae_table(config_errors::AbstractDict; header=true) if header; (@info "MAE Table"); end _print_err_tbl(config_errors["mae"]) end diff --git a/src/et_models/convert.jl b/src/et_models/convert.jl new file mode 100644 index 000000000..b53ffcda9 --- /dev/null +++ b/src/et_models/convert.jl @@ -0,0 +1,304 @@ + +using StaticArrays +using Lux + +import EquivariantTensors as ET +import Polynomials4ML as P4ML + +import ACEpotentials.Models: LearnableRnlrzzBasis, PolyEnvelope2sX, + _i2z, GeneralizedAgnesiTransform, PolyEnvelope1sR + +using LinearAlgebra: norm, dot + + +function convert2et(model) + # TODO: add checks that the model we are importing is of the format + # that we can actually import and then raise errors if not. + # but since we might just drop this import functionality entirely it + # is not so clear we should waste our time on that. + + # extract species information from the ACE model + rbasis = model.rbasis + et_i2z = ChemicalSpecies.(rbasis._i2z) + + # ---------------------------- REMBED + # convert the radial basis + et_rbasis = _convert_Rnl_learnable(rbasis) + et_rspec = rbasis.spec + # convert the radial basis into an edge embedding layer which has some + # additional logic for handling the ETGraph input correctly + rembed = ET.EdgeEmbed( et_rbasis) + + # ---------------------------- YEMBED + # convert the angular basis + ybasis = model.ybasis + et_ybasis = ET.EmbedDP( ET.NTtransformST( (x, st) -> x.𝐫, NamedTuple()), + ybasis ) + et_yspec = P4ML.natural_indices(et_ybasis.basis) + yembed = ET.EdgeEmbed( et_ybasis) + + # ---------------------------- MANY-BODY BASIS + # Convert AA_spec from (n,l,m) format to (n,l) format for mb_spec + AA_spec = model.tensor.meta["𝔸spec"] + et_mb_spec = unique([[(n=b.n, l=b.l) for b in bb] for bb in AA_spec]) + + et_mb_basis = ET.sparse_equivariant_tensor( + L = 0, # Invariant (scalar) output only + mb_spec = et_mb_spec, + Rnl_spec = et_rspec, + Ylm_spec = et_yspec, + basis = real + ) + + # ---------------------------- READOUT LAYER + # readout layer : need to select which linear operator to apply + # based on the center atom species + selector = let zlist = et_i2z + x -> ET.cat2idx(zlist, x.z) + end + readout = ET.SelectLinL( + et_mb_basis.lens[1], # input dim (mb basis length) + 1, # output dim (only one site energy per atom) + length(et_i2z), # number of categories = num species + selector) + + # generate the model and return it + et_model = ETACE(rembed, yembed, et_mb_basis, readout) + return et_model +end + + + +# In ET we currently store an edge xij as a NamedTuple, e.g, +# xij = (𝐫ij = ..., zi = ..., zj = ...) +# The NTtransform is a wrapper for mapping xij -> y +# (in this case y = transformed distance) adding logic to enable +# differentiation through this operation. +# +# In ET.Atoms edges are of the form xij = (𝐫 = ..., s0 = ..., s1 = ...) + + +# build a pure Lux Rnl basis 100% compatible with LearnableRnlrzz + +function _convert_Rnl_learnable(basis; zlist = ChemicalSpecies.(basis._i2z), + rfun = x -> norm(x.𝐫) ) + + # number of species + NZ = length(zlist) + + selector = let zlist = tuple(zlist...) + xij -> ET.catcat2idx(zlist, xij.z0, xij.z1) + end + + # construct the transform to be a Lux layer that behaves a bit + # like a WrappedFunction, but with additional support for + # named-tuple or DP inputs + # + et_trans = _convert_agnesi(basis) + + # OLD VERSION - KEEP FOR DEBUGGING then remove + # et_trans = let transforms = basis.transforms + # ET.NTtransform( xij -> begin + # trans_ij = transforms[__z2i(xij.s0), __z2i(xij.s1)] + # return trans_ij(rfun(xij)) + # end ) + # end + + # the envelope is always a simple quartic y -> (1 - y^2)^2 + # otherwise make this transform fail. + # ( note the transforms is normalized to map to [-1, 1] + # y outside [-1, 1] maps to 1 or -1. ) + # this obviously needs to be relaxed if we want compatibility + # with older versions of the code + for env in basis.envelopes + @assert env isa PolyEnvelope2sX + @assert env.p1 == env.p2 == 2 + @assert env.x1 == -1 + @assert env.x2 == 1 + end + et_env = y -> (1 - y^2)^2 + # et_env = _convert_envelope(basis.envelopes) + + # the polynomial basis just stays the same + # but needs to be wrapped due to the envelope being applied + # + et_polys = basis.polys + Penv = P4ML.wrapped_basis( BranchLayer( + et_polys, # y -> P + WrappedFunction( y -> et_env.(y) ), # y -> fₑₙᵥ + fusion = WrappedFunction( Pe -> Pe[2] .* Pe[1] ) + ) ) + + # the linear layer transformation + # P(yij) -> W[(Zi, Zj)] * P(yij) + # with W[a] learnable weights + # + et_linl = ET.SelectLinL(length(et_polys), # indim + length(basis.spec), # outdim + NZ^2, # num (Zi,Zj) pairs + selector) + + et_rbasis = ET.EmbedDP(et_trans, Penv, et_linl) + return et_rbasis +end + + + +# important auxiliary function to convert the transforms + +function _agnesi_et_params(trans) + @assert trans.trans isa GeneralizedAgnesiTransform + a = trans.trans.a + pcut = trans.trans.p + pin = trans.trans.q + req = trans.trans.r0 + rin = trans.trans.rin + rcut = trans.rcut + + params = ET.agnesi_params(pcut, pin, rin, req, rcut) + @assert params.a ≈ a + + # ----- for debugging ----------- + # r = rin + rand() * (rcut - rin) + # y1 = trans(r) + # y2 = ET.eval_agnesi(r, params) + # @assert y1 ≈ y2 + # ------------------------------- + + # ----- for debugging ----------- + # DEBUG: convert to Float32, to see if that fixes the + # site_grads on GPU? + # @show params + # params_32 = ET.float32(params) + # @show params_32 + # ------------------------------- + + return params +end + + +function _convert_agnesi(rbasis::LearnableRnlrzzBasis) + transforms = rbasis.transforms + @assert transforms isa SMatrix + NZ = size(transforms, 1) + params = [] + for i = 1:NZ, j = i:NZ + push!(params, _agnesi_et_params(transforms[i,j])) + end + st = (zlist = ChemicalSpecies.(rbasis._i2z), + params = SVector{length(params)}(identity.(params)) ) + + f_agnesi = let + (x, st) -> begin + r = norm(x.𝐫) + idx = ET.catcat2idx_sym(st.zlist, x.z0, x.z1) + return ET.eval_agnesi(r, st.params[idx]) + end + end + + return ET.NTtransformST(f_agnesi, st) +end + + +function _convert_envelope(envelopes) + TENV = typeof(envelopes[1]) + for env in envelopes + @assert typeof(env) == TENV + end + + @show TENV + return _convert_env_TENV(TENV, envelopes) +end + +function _convert_env_TENV(::Type{<: PolyEnvelope2sX}, envelopes) + for env in envelopes + @assert env isa PolyEnvelope2sX + @assert env.p1 == env.p2 == 2 + @assert env.x1 == -1 + @assert env.x2 == 1 + end + return y -> (1 - y^2)^2 +end + +function _convert_env_TENV(::Type{<: PolyEnvelope1sR}, envelopes) + env1 = envelopes[1] + for env in envelopes + @assert env == env1 + end + f_env = (r, st) -> _eval_env_1sr(r, st.rcut, st.p) + refst = ( rcut = env1.rcut, p = env1.p ) + return ET.st_transform(f_env, refst) +end + +function _eval_env_1sr(r, rcut, p) + _1 = one(r) + s = r / rcut + return (s^(-p) - _1) * (_1 - s) * (s < _1) +end + +function _convert_pair_envelope(envelopes) + TENV = typeof(envelopes[1]) + for env in envelopes + @assert typeof(env) == TENV + end + env1 = envelopes[1] + @assert env1 isa PolyEnvelope1sR + for env in envelopes + @assert env == env1 + end + refst = ( rcut = env1.rcut, p = env1.p ) + f_env = ET.dp_transform( (x, st) -> _eval_env_1sr( norm(x.𝐫), st.rcut, st.p ), + refst ) + return f_env +end + + + +function convertpair(model) + + # extract radial basis information + basis = model.pairbasis + zlist = ChemicalSpecies.(basis._i2z) + NZ = length(zlist) + + # this construction is a little different from the Rnl basis for the + # many-body model because the envelope takes a different input + # and this makes life a little more complicated. + + # 1: polynomials without the envelope + # + dp_agnesi = _convert_agnesi(basis) + polys = basis.polys + selector2 = let zlist = zlist + xij -> ET.catcat2idx(zlist, xij.z0, xij.z1) + end + et_linl = ET.SelectLinL(length(polys), # indim + length(basis), # outdim + NZ^2, # num (Zi,Zj) pairs + selector2) + rbasis_1 = ET.EmbedDP(dp_agnesi, polys, et_linl) + + # 2: envelope + dp_envelope = _convert_pair_envelope(basis.envelopes) + # _env_r = _convert_envelope(basis.envelopes) + # dp_envelope = ET.dp_transform( (x, st) -> _env_r.f( norm(x.𝐫), st ), + # _env_r.refstate ) + + # 3. combine into the radial basis + rembed = ET.EdgeEmbed( EnvRBranchL(dp_envelope, rbasis_1) ) + + # 4. rembed provides the radial basis for the pair model, now we just + # need the readout layer which is similar to before. + selector1 = let zlist = zlist + x -> ET.cat2idx(zlist, x.z) + end + readout = ET.SelectLinL( + length(basis), + 1, # output dim (only one site energy per atom) + NZ, # number of categories = num species + selector1) + + et_pair = ETPairModel(rembed, readout) + + return et_pair +end diff --git a/src/et_models/et_ace.jl b/src/et_models/et_ace.jl new file mode 100644 index 000000000..df2446453 --- /dev/null +++ b/src/et_models/et_ace.jl @@ -0,0 +1,79 @@ + +import EquivariantTensors as ET +import Polynomials4ML as P4ML + +import LuxCore: AbstractLuxContainerLayer +import AtomsBase: ChemicalSpecies +using ConcreteStructs: @concrete +using LinearAlgebra: norm, dot + + +@concrete struct ETACE <: AbstractLuxContainerLayer{(:rembed, :yembed, :basis, :readout)} + rembed # radial embedding layer + yembed # angular embedding layer + basis # many-body basis layer + readout # selectlinl readout layer +end + + +(l::ETACE)(X::ET.ETGraph, ps, st) = _apply_etace(l, X, ps, st), st + + +function _apply_etace(l::ETACE, X::ET.ETGraph, ps, st) + # embed edges + Rnl, _ = l.rembed(X, ps.rembed, st.rembed) + Ylm, _ = l.yembed(X, ps.yembed, st.yembed) + + # many-body basis + (𝔹,), _ = l.basis((Rnl, Ylm), ps.basis, st.basis) + + # readout layer + φ, _ = l.readout((𝔹, X.node_data), ps.readout, st.readout) + + # TODO: return site energies or total energy? + # for THIS layer probably site energies, then write all + # the summation and differentiation in the calculator layer. + + return φ +end + +# ----------------------------------------------------------- + +import Zygote + +# +# At first glance this looks like we are computing ∂E / ∂ri but this is not +# actually true. Because E = ∑ Ei and by interpreting G as a list of edges +# we are differentiating E w.r.t. 𝐫ij which is the same is Ei w.r.t. 𝐫ij. +# + +function site_grads(l::ETACE, X::ET.ETGraph, ps, st) + ∂X = Zygote.gradient( X -> sum(_apply_etace(l, X, ps, st)), X)[1] + return ∂X +end + + +# ----------------------------------------------------------- +# basis and jacobian evaluation + + +function site_basis(l::ETACE, X::ET.ETGraph, ps, st) + # embed edges + Rnl, _ = l.rembed(X, ps.rembed, st.rembed) + Ylm, _ = l.yembed(X, ps.yembed, st.yembed) + + # many-body basis + 𝔹, _ = l.basis((Rnl, Ylm), ps.basis, st.basis) + + return 𝔹[1] +end + + +function site_basis_jacobian(l::ETACE, X::ET.ETGraph, ps, st) + (R, ∂R), _ = ET.evaluate_ed(l.rembed, X, ps.rembed, st.rembed) + (Y, ∂Y), _ = ET.evaluate_ed(l.yembed, X, ps.yembed, st.yembed) + # _jacobian_X for SparseACEbasis takes (basis, Rnl, Ylm, dRnl, dYlm, ps, st) + # Requires EquivariantTensors >= 0.4.2 + (𝔹,), (∂𝔹,) = ET._jacobian_X(l.basis, R, Y, ∂R, ∂Y, ps.basis, st.basis) + return 𝔹, ∂𝔹 +end diff --git a/src/et_models/et_calculators.jl b/src/et_models/et_calculators.jl new file mode 100644 index 000000000..ac1c36749 --- /dev/null +++ b/src/et_models/et_calculators.jl @@ -0,0 +1,784 @@ + +# Calculator interfaces for ETACE models +# Provides AtomsCalculators-compatible energy/forces/virial evaluation +# +# Architecture: +# - WrappedSiteCalculator: Unified wrapper for ETACE-pattern models (ETACE, ETPairModel, ETOneBody) +# - ETACEPotential: Type alias for WrappedSiteCalculator with ETACE model +# - StackedCalculator: Combines multiple calculators (see stackedcalc.jl) +# +# All wrapped models must implement the ETACE interface: +# model(G, ps, st) -> (site_energies, st) +# site_grads(model, G, ps, st) -> edge gradients +# +# See also: stackedcalc.jl for StackedCalculator (combines multiple calculators) + +import AtomsCalculators +import AtomsBase: AbstractSystem, ChemicalSpecies +import EquivariantTensors as ET +using DecoratedParticles: PState +using StaticArrays +using Unitful +using LinearAlgebra: norm + +# Import from parent Models module to extend these functions +import ..Models: length_basis, energy_forces_virial_basis, potential_energy_basis + + +# ============================================================================ +# WrappedSiteCalculator - Unified wrapper for ETACE-pattern models +# ============================================================================ + +""" + WrappedSiteCalculator{M, PS, ST} + +Wraps any ETACE-pattern model (ETACE, ETPairModel, ETOneBody) and provides +the AtomsCalculators interface. + +All wrapped models must implement the ETACE interface: +- `model(G, ps, st)` → `(site_energies, st)` +- `site_grads(model, G, ps, st)` → edge gradients + +Mutable to allow parameter updates during training. + +# Example +```julia +# With ETACE model +calc = WrappedSiteCalculator(et_model, ps, st, 5.5) + +# With ETOneBody (upstream) +et_onebody = ETM.one_body(Dict(:Si => -0.846), x -> x.z) +_, onebody_st = Lux.setup(rng, et_onebody) +calc = WrappedSiteCalculator(et_onebody, nothing, onebody_st, 3.0) + +E = potential_energy(sys, calc) +F = forces(sys, calc) +``` + +# Fields +- `model` - ETACE-pattern model (ETACE, ETPairModel, or ETOneBody) +- `ps` - Model parameters (can be `nothing` for ETOneBody) +- `st` - Model state +- `rcut::Float64` - Cutoff radius for graph construction (Å) +- `co_ps` - Optional committee parameters for uncertainty quantification +""" +mutable struct WrappedSiteCalculator{M, PS, ST} + model::M + ps::PS + st::ST + rcut::Float64 + co_ps::Any +end + +# Constructor without committee parameters +function WrappedSiteCalculator(model, ps, st, rcut::Real) + return WrappedSiteCalculator(model, ps, st, Float64(rcut), nothing) +end + +cutoff_radius(calc::WrappedSiteCalculator) = calc.rcut * u"Å" + +function _wrapped_energy(calc::WrappedSiteCalculator, sys::AbstractSystem) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + Ei, _ = calc.model(G, calc.ps, calc.st) + return sum(Ei) +end + +function _wrapped_forces(calc::WrappedSiteCalculator, sys::AbstractSystem) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + ∂G = site_grads(calc.model, G, calc.ps, calc.st) + # Handle empty edge case (e.g., ETOneBody with small cutoff) + if isempty(∂G.edge_data) + return zeros(SVector{3, Float64}, length(sys)) + end + # forces_from_edge_grads returns +∇E, negate for forces + return -ET.Atoms.forces_from_edge_grads(sys, G, ∂G.edge_data) +end + +# Compute virial tensor from edge gradients +function _compute_virial(G::ET.ETGraph, ∂G) + # V = -∑ (∂E/∂𝐫ij) ⊗ 𝐫ij + V = zeros(SMatrix{3,3,Float64,9}) + for (edge, ∂edge) in zip(G.edge_data, ∂G.edge_data) + V -= ∂edge.𝐫 * edge.𝐫' + end + return V +end + +function _wrapped_virial(calc::WrappedSiteCalculator, sys::AbstractSystem) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + ∂G = site_grads(calc.model, G, calc.ps, calc.st) + # Handle empty edge case + if isempty(∂G.edge_data) + return zeros(SMatrix{3,3,Float64,9}) + end + return _compute_virial(G, ∂G) +end + +function _wrapped_energy_forces_virial(calc::WrappedSiteCalculator, sys::AbstractSystem) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + + # Energy from site energies (call model directly - ETACE interface) + Ei, _ = calc.model(G, calc.ps, calc.st) + E = sum(Ei) + + # Forces and virial from edge gradients + ∂G = site_grads(calc.model, G, calc.ps, calc.st) + + # Handle empty edge case (e.g., ETOneBody with small cutoff) + if isempty(∂G.edge_data) + F = zeros(SVector{3, Float64}, length(sys)) + V = zeros(SMatrix{3,3,Float64,9}) + else + F = -ET.Atoms.forces_from_edge_grads(sys, G, ∂G.edge_data) + V = _compute_virial(G, ∂G) + end + + return (energy=E, forces=F, virial=V) +end + +# AtomsCalculators interface for WrappedSiteCalculator +AtomsCalculators.@generate_interface function AtomsCalculators.potential_energy( + sys::AbstractSystem, calc::WrappedSiteCalculator; kwargs...) + return _wrapped_energy(calc, sys) * u"eV" +end + +AtomsCalculators.@generate_interface function AtomsCalculators.forces( + sys::AbstractSystem, calc::WrappedSiteCalculator; kwargs...) + return _wrapped_forces(calc, sys) .* u"eV/Å" +end + +AtomsCalculators.@generate_interface function AtomsCalculators.virial( + sys::AbstractSystem, calc::WrappedSiteCalculator; kwargs...) + return _wrapped_virial(calc, sys) * u"eV" +end + +function AtomsCalculators.energy_forces_virial( + sys::AbstractSystem, calc::WrappedSiteCalculator; kwargs...) + efv = _wrapped_energy_forces_virial(calc, sys) + return ( + energy = efv.energy * u"eV", + forces = efv.forces .* u"eV/Å", + virial = efv.virial * u"eV" + ) +end + + +# ============================================================================ +# ETACEPotential - Type alias for WrappedSiteCalculator{ETACE} +# ============================================================================ + +""" + ETACEPotential + +AtomsCalculators-compatible calculator wrapping an ETACE model. +This is a type alias for `WrappedSiteCalculator{<:ETACE, PS, ST}`. + +Access underlying components via: +- `calc.model` - The ETACE model +- `calc.ps` - Model parameters +- `calc.st` - Model state +- `calc.rcut` - Cutoff radius in Ångström +- `calc.co_ps` - Committee parameters (optional) + +# Example +```julia +calc = ETACEPotential(et_model, ps, st, 5.5) +E = potential_energy(sys, calc) +``` +""" +const ETACEPotential{MOD<:ETACE, PS, ST} = WrappedSiteCalculator{MOD, PS, ST} + +# Constructor: creates WrappedSiteCalculator with ETACE model directly +function ETACEPotential(model::ETACE, ps, st, rcut::Real) + return WrappedSiteCalculator(model, ps, st, Float64(rcut)) +end + +# ============================================================================ +# Training Assembly Interface +# ============================================================================ +# +# These functions compute the basis values for linear least squares fitting. +# The linear parameters are the readout weights W[1, k, s] where: +# k = basis function index (1:nbasis) +# s = species index (1:nspecies) +# +# Total parameters: nbasis * nspecies +# +# Energy basis: E = ∑_i ∑_k W[k, species[i]] * 𝔹[i, k] +# Force basis: F_atom = -∑ edges ∂E/∂r_edge, computed per basis function +# Virial basis: V = -∑ edges (∂E/∂r_edge) ⊗ r_edge, computed per basis function + +# Accessor helpers for ETACEPotential (which is WrappedSiteCalculator{ETACE}) +_etace(calc::ETACEPotential) = calc.model # Underlying ETACE model (direct) +_ps(calc::ETACEPotential) = calc.ps # Model parameters +_st(calc::ETACEPotential) = calc.st # Model state + +""" + length_basis(calc::ETACEPotential) + +Return the number of linear parameters in the model (nbasis * nspecies). +""" +function length_basis(calc::ETACEPotential) + etace = _etace(calc) + nbasis = etace.readout.in_dim + nspecies = etace.readout.ncat + return nbasis * nspecies +end + +# ACEfit integration +import ACEfit +ACEfit.basis_size(calc::ETACEPotential) = length_basis(calc) + +""" + energy_forces_virial_basis(sys::AbstractSystem, calc::ETACEPotential) + +Compute the basis functions for energy, forces, and virial. +Returns a named tuple with: +- `energy::Vector{Float64}` - length = length_basis(calc) +- `forces::Matrix{SVector{3,Float64}}` - size = (natoms, length_basis) +- `virial::Vector{SMatrix{3,3,Float64}}` - length = length_basis(calc) + +The linear combination of basis values with parameters gives: + E = dot(energy, params) + F = forces * params + V = sum(params .* virial) +""" +function energy_forces_virial_basis(sys::AbstractSystem, calc::ETACEPotential) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + etace = _etace(calc) + + # Get basis and jacobian + # 𝔹: (nnodes, nbasis) - basis values per site (Float64) + # ∂𝔹: (maxneigs, nnodes, nbasis) - directional derivatives (VState objects) + 𝔹, ∂𝔹 = site_basis_jacobian(etace, G, _ps(calc), _st(calc)) + + natoms = length(sys) + nnodes = size(𝔹, 1) + nbasis = etace.readout.in_dim + nspecies = etace.readout.ncat + nparams = nbasis * nspecies + maxneigs = size(∂𝔹, 1) + + # Species indices for each node + iZ = etace.readout.selector.(G.node_data) + + # Initialize outputs + E_basis = zeros(nparams) + F_basis = zeros(SVector{3, Float64}, natoms, nparams) + V_basis = zeros(SMatrix{3, 3, Float64, 9}, nparams) + + # Pre-allocate work buffer for gradient (same element type as ∂𝔹) + # This avoids allocating a new matrix in each iteration + ∇Ei_buf = similar(∂𝔹, maxneigs, nnodes) + + # Pre-compute a zero element for masking (same type as ∂𝔹 elements) + zero_grad = zero(∂𝔹[1, 1, 1]) + + # Pre-compute edge vectors for virial (avoid repeated access) + edge_𝐫 = [edge.𝐫 for edge in G.edge_data] + + # Compute basis values for each parameter (k, s) pair + # Parameter index: p = (s-1) * nbasis + k + for s in 1:nspecies + for k in 1:nbasis + p = (s - 1) * nbasis + k + + # Energy basis: sum of 𝔹[i, k] for atoms of species s + for i in 1:nnodes + if iZ[i] == s + E_basis[p] += 𝔹[i, k] + end + end + + # Fill gradient buffer: ∇Ei[:, i] = ∂𝔹[:, i, k] if iZ[i] == s, else zeros + # This avoids allocating W_unit and doing matrix-vector multiply + for i in 1:nnodes + if iZ[i] == s + @views ∇Ei_buf[:, i] .= ∂𝔹[:, i, k] + else + @views ∇Ei_buf[:, i] .= Ref(zero_grad) + end + end + + # Reshape for rev_reshape_embedding (needs 3D array) - this is a view, no allocation + ∇Ei_3d = reshape(∇Ei_buf, maxneigs, nnodes, 1) + + # Convert to edge-indexed format with 3D vectors + ∇E_edges = ET.rev_reshape_embedding(∇Ei_3d, G)[:] + + # Convert edge gradients to atomic forces (negate for forces) + F_basis[:, p] = -ET.Atoms.forces_from_edge_grads(sys, G, ∇E_edges) + + # Compute virial: V = -∑ (∂E/∂𝐫ij) ⊗ 𝐫ij + V = zero(SMatrix{3, 3, Float64, 9}) + for (e, ∂edge) in enumerate(∇E_edges) + V -= ∂edge.𝐫 * edge_𝐫[e]' + end + V_basis[p] = V + end + end + + return ( + energy = E_basis * u"eV", + forces = F_basis .* u"eV/Å", + virial = V_basis * u"eV" + ) +end + +""" + potential_energy_basis(sys::AbstractSystem, calc::ETACEPotential) + +Compute only the energy basis (faster when forces/virial not needed). +""" +function potential_energy_basis(sys::AbstractSystem, calc::ETACEPotential) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + etace = _etace(calc) + + # Get basis values + 𝔹 = site_basis(etace, G, _ps(calc), _st(calc)) + + nbasis = etace.readout.in_dim + nspecies = etace.readout.ncat + nparams = nbasis * nspecies + + # Species indices for each node + iZ = etace.readout.selector.(G.node_data) + + # Compute energy basis + E_basis = zeros(nparams) + for s in 1:nspecies + for k in 1:nbasis + p = (s - 1) * nbasis + k + for i in 1:length(G.node_data) + if iZ[i] == s + E_basis[p] += 𝔹[i, k] + end + end + end + end + + return E_basis * u"eV" +end + +""" + get_linear_parameters(calc::ETACEPotential) + +Extract the linear parameters (readout weights) as a flat vector. +Parameters are ordered as: [W[1,:,1]; W[1,:,2]; ... ; W[1,:,nspecies]] +""" +function get_linear_parameters(calc::ETACEPotential) + return vec(_ps(calc).readout.W) +end + +""" + set_linear_parameters!(calc::ETACEPotential, θ::AbstractVector) + +Set the linear parameters (readout weights) from a flat vector. +""" +function set_linear_parameters!(calc::ETACEPotential, θ::AbstractVector) + etace = _etace(calc) + nbasis = etace.readout.in_dim + nspecies = etace.readout.ncat + @assert length(θ) == nbasis * nspecies + + # Reshape and copy into ps (WrappedSiteCalculator is mutable) + ps = _ps(calc) + new_W = reshape(θ, 1, nbasis, nspecies) + calc.ps = merge(ps, (readout = merge(ps.readout, (W = new_W,)),)) + return calc +end + + +# ============================================================================ +# ETPairPotential - Type alias for WrappedSiteCalculator{ETPairModel} +# ============================================================================ + +""" + ETPairPotential + +AtomsCalculators-compatible calculator wrapping an ETPairModel. +This is a type alias for `WrappedSiteCalculator{<:ETPairModel, PS, ST}`. + +Supports training assembly functions: +- `length_basis(calc)` - Total linear parameters +- `energy_forces_virial_basis(sys, calc)` - Full EFV design row +- `potential_energy_basis(sys, calc)` - Energy design row +- `get_linear_parameters(calc)` / `set_linear_parameters!(calc, θ)` + +# Example +```julia +et_pair = convertpair(model) +ps, st = Lux.setup(rng, et_pair) +calc = ETPairPotential(et_pair, ps, st, 5.5) +E = potential_energy(sys, calc) +``` +""" +const ETPairPotential{MOD<:ETPairModel, PS, ST} = WrappedSiteCalculator{MOD, PS, ST} + +function ETPairPotential(model::ETPairModel, ps, st, rcut::Real) + return WrappedSiteCalculator(model, ps, st, Float64(rcut)) +end + +# ============================================================================ +# ETPairPotential Training Assembly +# ============================================================================ + +# Accessor helpers +_pair(calc::ETPairPotential) = calc.model +_ps(calc::ETPairPotential) = calc.ps +_st(calc::ETPairPotential) = calc.st + +""" + length_basis(calc::ETPairPotential) + +Return the number of linear parameters in the pair model (nbasis * nspecies). +""" +function length_basis(calc::ETPairPotential) + pair = _pair(calc) + nbasis = pair.readout.in_dim + nspecies = pair.readout.ncat + return nbasis * nspecies +end + +ACEfit.basis_size(calc::ETPairPotential) = length_basis(calc) + +""" + energy_forces_virial_basis(sys::AbstractSystem, calc::ETPairPotential) + +Compute the basis functions for energy, forces, and virial for pair potential. +""" +function energy_forces_virial_basis(sys::AbstractSystem, calc::ETPairPotential) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + pair = _pair(calc) + + # Get basis and jacobian + 𝔹, ∂𝔹 = site_basis_jacobian(pair, G, _ps(calc), _st(calc)) + + natoms = length(sys) + nnodes = size(𝔹, 1) + nbasis = pair.readout.in_dim + nspecies = pair.readout.ncat + nparams = nbasis * nspecies + maxneigs = size(∂𝔹, 1) + + # Species indices for each node + iZ = pair.readout.selector.(G.node_data) + + # Initialize outputs + E_basis = zeros(nparams) + F_basis = zeros(SVector{3, Float64}, natoms, nparams) + V_basis = zeros(SMatrix{3, 3, Float64, 9}, nparams) + + # Pre-allocate work buffer + ∇Ei_buf = similar(∂𝔹, maxneigs, nnodes) + zero_grad = zero(∂𝔹[1, 1, 1]) + edge_𝐫 = [edge.𝐫 for edge in G.edge_data] + + # Compute basis values for each parameter (k, s) pair + for s in 1:nspecies + for k in 1:nbasis + p = (s - 1) * nbasis + k + + # Energy basis + for i in 1:nnodes + if iZ[i] == s + E_basis[p] += 𝔹[i, k] + end + end + + # Fill gradient buffer + for i in 1:nnodes + if iZ[i] == s + @views ∇Ei_buf[:, i] .= ∂𝔹[:, i, k] + else + @views ∇Ei_buf[:, i] .= Ref(zero_grad) + end + end + + # Convert to edge format and compute forces/virial + ∇Ei_3d = reshape(∇Ei_buf, maxneigs, nnodes, 1) + ∇E_edges = ET.rev_reshape_embedding(∇Ei_3d, G)[:] + F_basis[:, p] = -ET.Atoms.forces_from_edge_grads(sys, G, ∇E_edges) + + V = zero(SMatrix{3, 3, Float64, 9}) + for (e, ∂edge) in enumerate(∇E_edges) + V -= ∂edge.𝐫 * edge_𝐫[e]' + end + V_basis[p] = V + end + end + + return ( + energy = E_basis * u"eV", + forces = F_basis .* u"eV/Å", + virial = V_basis * u"eV" + ) +end + +""" + potential_energy_basis(sys::AbstractSystem, calc::ETPairPotential) + +Compute only the energy basis for pair potential. +""" +function potential_energy_basis(sys::AbstractSystem, calc::ETPairPotential) + G = ET.Atoms.interaction_graph(sys, calc.rcut * u"Å") + pair = _pair(calc) + + 𝔹 = site_basis(pair, G, _ps(calc), _st(calc)) + + nbasis = pair.readout.in_dim + nspecies = pair.readout.ncat + nparams = nbasis * nspecies + + iZ = pair.readout.selector.(G.node_data) + + E_basis = zeros(nparams) + for s in 1:nspecies + for k in 1:nbasis + p = (s - 1) * nbasis + k + for i in 1:length(G.node_data) + if iZ[i] == s + E_basis[p] += 𝔹[i, k] + end + end + end + end + + return E_basis * u"eV" +end + +""" + get_linear_parameters(calc::ETPairPotential) + +Extract the linear parameters (readout weights) as a flat vector. +""" +function get_linear_parameters(calc::ETPairPotential) + return vec(_ps(calc).readout.W) +end + +""" + set_linear_parameters!(calc::ETPairPotential, θ::AbstractVector) + +Set the linear parameters (readout weights) from a flat vector. +""" +function set_linear_parameters!(calc::ETPairPotential, θ::AbstractVector) + pair = _pair(calc) + nbasis = pair.readout.in_dim + nspecies = pair.readout.ncat + @assert length(θ) == nbasis * nspecies + + ps = _ps(calc) + new_W = reshape(θ, 1, nbasis, nspecies) + calc.ps = merge(ps, (readout = merge(ps.readout, (W = new_W,)),)) + return calc +end + + +# ============================================================================ +# ETOneBodyPotential - Type alias for WrappedSiteCalculator{ETOneBody} +# ============================================================================ + +""" + ETOneBodyPotential + +AtomsCalculators-compatible calculator wrapping an ETOneBody model. +This is a type alias for `WrappedSiteCalculator{<:ETOneBody, PS, ST}`. + +ETOneBody has no learnable parameters, so training assembly returns empty results: +- `length_basis(calc)` returns 0 +- `energy_forces_virial_basis(sys, calc)` returns empty arrays +- Forces and virial are always zero (energy only depends on atom types) + +# Example +```julia +et_onebody = one_body(Dict(:Si => -0.846), x -> x.z) +_, st = Lux.setup(rng, et_onebody) +calc = ETOneBodyPotential(et_onebody, nothing, st, 3.0) +E = potential_energy(sys, calc) +``` +""" +const ETOneBodyPotential{MOD<:ETOneBody, PS, ST} = WrappedSiteCalculator{MOD, PS, ST} + +function ETOneBodyPotential(model::ETOneBody, ps, st, rcut::Real) + return WrappedSiteCalculator(model, ps, st, Float64(rcut)) +end + +# ============================================================================ +# ETOneBodyPotential Training Assembly (empty - no learnable parameters) +# ============================================================================ + +_onebody(calc::ETOneBodyPotential) = calc.model +_ps(calc::ETOneBodyPotential) = calc.ps +_st(calc::ETOneBodyPotential) = calc.st + +""" + length_basis(calc::ETOneBodyPotential) + +Return 0 - ETOneBody has no learnable linear parameters. +""" +length_basis(calc::ETOneBodyPotential) = 0 + +ACEfit.basis_size(calc::ETOneBodyPotential) = 0 + +""" + energy_forces_virial_basis(sys::AbstractSystem, calc::ETOneBodyPotential) + +Return empty arrays - ETOneBody has no learnable parameters. +""" +function energy_forces_virial_basis(sys::AbstractSystem, calc::ETOneBodyPotential) + natoms = length(sys) + return ( + energy = zeros(0) * u"eV", + forces = zeros(SVector{3, Float64}, natoms, 0) .* u"eV/Å", + virial = zeros(SMatrix{3, 3, Float64, 9}, 0) * u"eV" + ) +end + +""" + potential_energy_basis(sys::AbstractSystem, calc::ETOneBodyPotential) + +Return empty array - ETOneBody has no learnable parameters. +""" +potential_energy_basis(sys::AbstractSystem, calc::ETOneBodyPotential) = zeros(0) * u"eV" + +""" + get_linear_parameters(calc::ETOneBodyPotential) + +Return empty vector - ETOneBody has no learnable parameters. +""" +get_linear_parameters(calc::ETOneBodyPotential) = Float64[] + +""" + set_linear_parameters!(calc::ETOneBodyPotential, θ::AbstractVector) + +No-op for ETOneBody (no learnable parameters). +""" +function set_linear_parameters!(calc::ETOneBodyPotential, θ::AbstractVector) + @assert length(θ) == 0 "ETOneBody has no learnable parameters" + return calc +end + + +# ============================================================================ +# Full Model Conversion +# ============================================================================ + +using Random: AbstractRNG, default_rng +using Lux: setup + +""" + convert2et_full(model, ps, st; rng=default_rng()) -> StackedCalculator + +Convert a complete ACE model (E0 + Pair + Many-body) to an ETACE-based +StackedCalculator. This creates a calculator that combines: +1. ETOneBody - reference energies per species +2. ETPairModel - pair potential +3. ETACE - many-body ACE potential + +The returned StackedCalculator is fully compatible with AtomsCalculators +and can be used for energy, forces, and virial evaluation. + +# Arguments +- `model`: ACE model (from ACEpotentials.Models) +- `ps`: Model parameters (from Lux.setup) +- `st`: Model state (from Lux.setup) +- `rng`: Random number generator (default: `default_rng()`) + +# Returns +- `StackedCalculator` combining ETOneBody, ETPairModel, and ETACE + +# Example +```julia +model = ace_model(elements=[:Si], order=3, totaldegree=8) +ps, st = Lux.setup(rng, model) +# ... fit model ... +calc = convert2et_full(model, ps, st) +E = potential_energy(sys, calc) +``` +""" +function convert2et_full(model, ps, st; rng::AbstractRNG=default_rng()) + # Extract cutoff radius from pair basis + rcut = maximum(a.rcut for a in model.pairbasis.rin0cuts) + + # 1. Convert E0/Vref to ETOneBody + E0s = model.Vref.E0 # Dict{Int, Float64} + zlist = ChemicalSpecies.(model.rbasis._i2z) + E0_dict = Dict(z => E0s[z.atomic_number] for z in zlist) + et_onebody = one_body(E0_dict, x -> x.z) + _, onebody_st = setup(rng, et_onebody) + # Use minimum cutoff for graph construction (ETOneBody needs no neighbors) + onebody_calc = WrappedSiteCalculator(et_onebody, nothing, onebody_st, 3.0) + + # 2. Convert pair potential to ETPairModel + et_pair = convertpair(model) + et_pair_ps, et_pair_st = setup(rng, et_pair) + copy_pair_params!(et_pair_ps, ps, model) + pair_calc = WrappedSiteCalculator(et_pair, et_pair_ps, et_pair_st, rcut) + + # 3. Convert many-body to ETACE + et_ace = convert2et(model) + et_ace_ps, et_ace_st = setup(rng, et_ace) + copy_ace_params!(et_ace_ps, ps, model) + ace_calc = WrappedSiteCalculator(et_ace, et_ace_ps, et_ace_st, rcut) + + # 4. Stack all components + return StackedCalculator((onebody_calc, pair_calc, ace_calc)) +end + + +# ============================================================================ +# Parameter Copying Utilities +# ============================================================================ + +""" + copy_ace_params!(et_ps, ps, model) + +Copy many-body (ACE) parameters from ACE model format to ETACE format. +""" +function copy_ace_params!(et_ps, ps, model) + NZ = length(model.rbasis._i2z) + + # Copy radial basis parameters (Wnlq) + # ACE format: Wnlq[:, :, iz, jz] for species pair (iz, jz) + # ETACE format: rembed.post.W[:, :, idx] where idx = (i-1)*NZ + j + # (post is the SelectLinL layer in EmbedDP) + for i in 1:NZ, j in 1:NZ + idx = (i-1)*NZ + j + et_ps.rembed.post.W[:, :, idx] .= ps.rbasis.Wnlq[:, :, i, j] + end + + # Copy readout (many-body) parameters + # ACE format: WB[:, s] for species s + # ETACE format: W[1, :, s] + for s in 1:NZ + et_ps.readout.W[1, :, s] .= ps.WB[:, s] + end +end + + +""" + copy_pair_params!(et_ps, ps, model) + +Copy pair potential parameters from ACE model format to ETPairModel format. +Based on parameter mapping from test/etmodels/test_etpair.jl. +""" +function copy_pair_params!(et_ps, ps, model) + NZ = length(model.pairbasis._i2z) + + # Copy pair radial basis parameters + # ACE format: pairbasis.Wnlq[:, :, i, j] for species pair (i, j) + # ETACE format: rembed.rbasis.post.W[:, :, idx] where idx = (i-1)*NZ + j + # (post is the SelectLinL layer in EmbedDP) + for i in 1:NZ, j in 1:NZ + idx = (i-1)*NZ + j + et_ps.rembed.rbasis.post.W[:, :, idx] .= ps.pairbasis.Wnlq[:, :, i, j] + end + + # Copy pair readout parameters + # ACE format: Wpair[:, s] for species s + # ETACE format: readout.W[1, :, s] + for s in 1:NZ + et_ps.readout.W[1, :, s] .= ps.Wpair[:, s] + end +end + diff --git a/src/et_models/et_envbranch.jl b/src/et_models/et_envbranch.jl new file mode 100644 index 000000000..31a5bb44d --- /dev/null +++ b/src/et_models/et_envbranch.jl @@ -0,0 +1,54 @@ + + +using ConcreteStructs +import Polynomials4ML: evaluate, evaluate_ed +import LuxCore: AbstractLuxContainerLayer +import ChainRulesCore: NoTangent, rrule, unthunk + +""" + struct EnvRBranchL + +An auxiliary layer that is basically a branch layer needed to build +radial bases, with additional evaluate_ed functionality, needed for +Jacobians. +""" +@concrete struct EnvRBranchL <: AbstractLuxContainerLayer{(:envelope, :rbasis)} + envelope + rbasis +end + +(l::EnvRBranchL)(X, ps, st) = _apply_envrbranchl(l, X, ps, st), st + +evaluate(l::EnvRBranchL, X, ps, st) = l(X, ps, st) + +function _apply_envrbranchl(l::EnvRBranchL, X, ps, st) + ee, _ = l.envelope(X, ps.envelope, st.envelope) + P, _ = l.rbasis(X, ps.rbasis, st.rbasis) + return ee .* P +end + +function evaluate_ed(l::EnvRBranchL, X, ps, st) + (ee, d_ee), _ = evaluate_ed(l.envelope, X, ps.envelope, st.envelope) + (P, d_P), _ = evaluate_ed(l.rbasis,X, ps.rbasis, st.rbasis) + + # product rule + pP = ee .* P + ∂_pP = d_ee .* P .+ ee .* d_P + + return (pP, ∂_pP), st +end + +function rrule(::typeof(_apply_envrbranchl), + l::EnvRBranchL, X, ps, st) + + (P, dP), st = evaluate_ed(l, X, ps, st) + + function _pb_embeddp(_∂P) + ∂P = unthunk(_∂P) + ∂X = dropdims( sum(∂P .* dP, dims = 2), dims = 2) + return NoTangent(), NoTangent(), ∂X, NoTangent(), NoTangent() + end + + return P, _pb_embeddp +end + diff --git a/src/et_models/et_models.jl b/src/et_models/et_models.jl new file mode 100644 index 000000000..9aedb182e --- /dev/null +++ b/src/et_models/et_models.jl @@ -0,0 +1,24 @@ + +module ETModels + +# utility layers : these should likely be moved into ET or be removed +# if more convenient implementations can be found. +# +include("et_envbranch.jl") + +# ET based ACE model components +include("et_ace.jl") +include("onebody.jl") +include("et_pair.jl") + +# converstion utilities: convert from 0.8 style ACE models to ET based models +include("convert.jl") + +# utilities to convert radial embeddings to splined versions +# for simplicity and performance and to freeze parameters +include("splinify.jl") + +include("et_calculators.jl") +include("stackedcalc.jl") + +end \ No newline at end of file diff --git a/src/et_models/et_pair.jl b/src/et_models/et_pair.jl new file mode 100644 index 000000000..419321d95 --- /dev/null +++ b/src/et_models/et_pair.jl @@ -0,0 +1,66 @@ +# +# This is a temporary model implementation needed due to the fact that +# ETACEModel has Rnl, Ylm hard-coded. In the future it could be tested +# whether the pair model could simply be taken as another ACE model +# with a single embedding rather than several, This would need generalization +# of a fair few methods in both ACEpotentials and EquivariantTensors. +# + + +import EquivariantTensors as ET +import Zygote +import LuxCore: AbstractLuxContainerLayer +using ConcreteStructs: @concrete + + +@concrete struct ETPairModel <: AbstractLuxContainerLayer{(:rembed, :readout)} + rembed # radial embedding layer = basis + readout # normally a selectlinl readout layer +end + + +(l::ETPairModel)(X::ET.ETGraph, ps, st) = _apply_etpairmodel(l, X, ps, st), st + + +function _apply_etpairmodel(l::ETPairModel, X::ET.ETGraph, ps, st) + # evaluate the basis + 𝔹 = site_basis(l, X, ps, st) + + # readout layer + φ, _ = l.readout((𝔹, X.node_data), ps.readout, st.readout) + + return φ +end + +# ----------------------------------------------------------- + + +function site_grads(l::ETPairModel, X::ET.ETGraph, ps, st) + ∂X = Zygote.gradient( X -> sum(_apply_etpairmodel(l, X, ps, st)), X)[1] + return ∂X +end + + +# ----------------------------------------------------------- +# basis and jacobian evaluation + + +function site_basis(l::ETPairModel, X::ET.ETGraph, ps, st) + # embed edges + Rnl, _ = l.rembed(X, ps.rembed, st.rembed) + + # the basis is obtain by summing over the neighbours of each node, + # which is just a sum over the first dimension of Rnl + 𝔹 = dropdims(sum(Rnl, dims=1), dims=1) + + return 𝔹 +end + + +function site_basis_jacobian(l::ETPairModel, X::ET.ETGraph, ps, st) + (R, ∂R), _ = ET.evaluate_ed(l.rembed, X, ps.rembed, st.rembed) + 𝔹 = dropdims(sum(R, dims=1), dims=1) + # ∂𝔹 == ∂R + return 𝔹, ∂R +end + diff --git a/src/et_models/onebody.jl b/src/et_models/onebody.jl new file mode 100644 index 000000000..986e4971d --- /dev/null +++ b/src/et_models/onebody.jl @@ -0,0 +1,76 @@ +# +# NOTES +# First draft of a one-body function that fits into the +# et_model interface. At the moment the E0s are forced to be +# stored as an SVector, to be accessed via the selector. +# But with a little bit of extra work we could add the option +# of making the E0s either constants or learnable parameters. +# + +using Random: AbstractRNG +import EquivariantTensors as ET +using DecoratedParticles: VState +using StaticArrays: SVector + + +""" + one_body(D::Dict, catfun) + +Create a one-body energy model that assigns to each atom an energy based on +a categorical variable that is extracted from the atom state via +`catfun`. The dictionary `D` contains category-value pairs. The one-body +energy assigned to an atom with state `x` is `D[catfun(x)]`. +""" +function one_body(D::Dict, catfun) + categories = SVector(collect(keys(D))...) + E0s = SVector(collect(values(D))...) + NZ = length(categories) + selector = x -> ET.cat2idx(categories, catfun(x)) + return ETOneBody{NZ, eltype(E0s), eltype(categories), typeof(selector)}( + E0s, categories, selector) +end + + +using StaticArrays: SVector +import LuxCore: AbstractLuxLayer, initialparameters, initialstates + + +struct ETOneBody{NZ, T, CAT, TSEL} <: AbstractLuxLayer + E0s::SVector{NZ, T} + categories::SVector{NZ, CAT} + selector::TSEL +end + +initialstates(rng::AbstractRNG, l::ETOneBody) = (; E0s = l.E0s) + + + +(l::ETOneBody)(x, ps, st) = _apply_onebody(l, x, st), st +(l::ETOneBody)(x) = _apply_onebody(l, x, (; E0s = l.E0s)) + +_apply_onebody(l::ETOneBody, X::ET.ETGraph, st) = + _apply_onebody(l, X.node_data, st) + +_apply_onebody(l::ETOneBody, X::AbstractVector, st) = + ___apply_onebody(l.selector, X, st.E0s) + +___apply_onebody(selector, X::AbstractVector, E0s) = + map(x -> E0s[selector(x)], X) + + +# ETOneBody energy only depends on atom types (categorical), not positions. +# Gradient w.r.t. positions is always zero. +# Return empty edge_data array since there are no position-dependent gradients. +function site_grads(l::ETOneBody, X::ET.ETGraph, ps, st) + return (; edge_data = VState[]) +end + +site_basis(l::ETOneBody, X::ET.ETGraph, ps, st) = + fill(zero(eltype(st.E0s)), (ET.nnodes(X), 0)) + +function site_basis_jacobian(l::ETOneBody, X::ET.ETGraph, ps, st) + 𝔹 = site_basis(l, X, ps, st) + ∂𝔹 = fill(VState(), (ET.maxneigs(X), ET.nnodes(X), 0)) + return 𝔹, ∂𝔹 +end + diff --git a/src/et_models/splinify.jl b/src/et_models/splinify.jl new file mode 100644 index 000000000..e36359cea --- /dev/null +++ b/src/et_models/splinify.jl @@ -0,0 +1,60 @@ + +import EquivariantTensors as ET +import Polynomials4ML as P4ML + +# These implementations of `splinify` expect a very specific structure of the +# pair potential basis. In principle it is possible to relax this +# considerably but it needs a little bit of thinking and planning/design +# work before just diving in. To be discussed when needed. + +function splinify(et_pair::ETPairModel, et_ps, et_st; + Nspl = 30) + + # polynomial basis taking y = y(r) as input + trans_y = et_pair.rembed.layer.rbasis.trans + polys_y = et_pair.rembed.layer.rbasis.basis + # weights for learnable radials + WW = et_ps.rembed.rbasis.post.W + # use P4ML to generate individual cubic splines + splines = [ + P4ML.splinify( y -> WW[:, :, i] * polys_y(y), -1.0, 1.0, Nspl ) + for i in 1:size(WW, 3) ] + # selects the correct spline based on the (Zi, Zj) pair + selector2 = et_pair.rembed.layer.rbasis.post.selector + # envelope multiplying the spline + envelope = et_pair.rembed.layer.envelope + + spl_rbasis = ET.trans_splines(trans_y, splines, selector2, envelope) + + return ETPairModel( ET.EdgeEmbed(spl_rbasis), et_pair.readout ) +end + + +function splinify(et_model::ETACE, et_ps, et_st; Nspl = 50) + + rembed = et_model.rembed.layer # radial embedding, edgeembed stripped + trans = rembed.trans # x -> y dp_transform + rpolys_env = rembed.basis # polynomials * envelope + polys_y = rpolys_env.l.layers.layer_1 # polynomial basis + yenv_func = rpolys_env.l.layers.layer_2.func # envelope function + + # envelope multiplying the spline, apply the transformation a second + # time until we figure out how to reuse it conveniently + trans_yenv = ET.dp_transform( + (x, st) -> yenv_func(trans.f(x, st)), + trans.refstate ) + # selects the correct spline based on the (Zi, Zj) pair + selector2 = rembed.post.selector + # generate the splines using P4ML + WW = et_ps.rembed.post.W + splines = [ + P4ML.splinify( y -> WW[:, :, i] * polys_y(y), -1.0, 1.0, Nspl ) + for i in 1:size(WW, 3) ] + + rembed_spl = ET.trans_splines(trans, splines, selector2, trans_yenv) + ace_spl = ETACE( ET.EdgeEmbed(rembed_spl), + et_model.yembed, + et_model.basis, + et_model.readout ) + return ace_spl +end \ No newline at end of file diff --git a/src/et_models/stackedcalc.jl b/src/et_models/stackedcalc.jl new file mode 100644 index 000000000..cdcb737b4 --- /dev/null +++ b/src/et_models/stackedcalc.jl @@ -0,0 +1,219 @@ + +# StackedCalculator - Combines multiple AtomsCalculators +# +# Generic utility for combining multiple calculators by summing their +# energy, forces, and virial contributions. +# +# Uses @generated functions with Base.Cartesian for efficient +# compile-time loop unrolling when the number of calculators is known. + +import AtomsCalculators +import AtomsBase: AbstractSystem +using StaticArrays +using Unitful +using Base.Cartesian: @nexprs, @ntuple, @ncall + +""" + StackedCalculator{N, C<:Tuple} + +Combines multiple AtomsCalculators by summing their energy, forces, and virial. +Each calculator in the tuple must implement the AtomsCalculators interface. + +This allows combining site-based calculators (via WrappedSiteCalculator) with +calculators that don't have site decompositions (e.g., Coulomb, dispersion). + +The implementation uses compile-time loop unrolling for efficiency when +the number of calculators is small and known at compile time. + +# Example +```julia +# Wrap site energy models +E0_calc = WrappedSiteCalculator(E0Model(Dict(:Si => -0.846))) +ace_calc = WrappedSiteCalculator(WrappedETACE(et_model, ps, st, 5.5)) + +# Stack them (could also add Coulomb, dispersion, etc.) +calc = StackedCalculator((E0_calc, ace_calc)) + +E = potential_energy(sys, calc) +F = forces(sys, calc) +``` + +# Fields +- `calcs::Tuple` - Tuple of N calculators implementing AtomsCalculators interface +""" +struct StackedCalculator{N, C<:Tuple} + calcs::C +end + +# Constructor that infers N from the tuple length +StackedCalculator(calcs::C) where {C<:Tuple} = StackedCalculator{length(C.parameters), C}(calcs) + +# Get maximum cutoff from all calculators (for informational purposes) +@generated function cutoff_radius(calc::StackedCalculator{N}) where {N} + quote + rcuts = @ntuple $N i -> ustrip(u"Å", cutoff_radius(calc.calcs[i])) + return maximum(rcuts) * u"Å" + end +end + +# ============================================================================ +# Efficient implementations using @generated for compile-time unrolling +# ============================================================================ + +@generated function _stacked_energy(sys::AbstractSystem, calc::StackedCalculator{N}) where {N} + quote + @nexprs $N i -> E_i = AtomsCalculators.potential_energy(sys, calc.calcs[i]) + return sum(@ntuple $N E) + end +end + +@generated function _stacked_forces(sys::AbstractSystem, calc::StackedCalculator{N}) where {N} + quote + @nexprs $N i -> F_i = AtomsCalculators.forces(sys, calc.calcs[i]) + return reduce(.+, @ntuple $N F) + end +end + +@generated function _stacked_virial(sys::AbstractSystem, calc::StackedCalculator{N}) where {N} + quote + @nexprs $N i -> V_i = AtomsCalculators.virial(sys, calc.calcs[i]) + return sum(@ntuple $N V) + end +end + +@generated function _stacked_efv(sys::AbstractSystem, calc::StackedCalculator{N}) where {N} + quote + @nexprs $N i -> efv_i = AtomsCalculators.energy_forces_virial(sys, calc.calcs[i]) + return ( + energy = sum(@ntuple $N i -> efv_i.energy), + forces = reduce(.+, @ntuple $N i -> efv_i.forces), + virial = sum(@ntuple $N i -> efv_i.virial) + ) + end +end + +# ============================================================================ +# AtomsCalculators interface +# ============================================================================ + +AtomsCalculators.@generate_interface function AtomsCalculators.potential_energy( + sys::AbstractSystem, calc::StackedCalculator; kwargs...) + return _stacked_energy(sys, calc) +end + +AtomsCalculators.@generate_interface function AtomsCalculators.forces( + sys::AbstractSystem, calc::StackedCalculator; kwargs...) + return _stacked_forces(sys, calc) +end + +AtomsCalculators.@generate_interface function AtomsCalculators.virial( + sys::AbstractSystem, calc::StackedCalculator; kwargs...) + return _stacked_virial(sys, calc) +end + +function AtomsCalculators.energy_forces_virial( + sys::AbstractSystem, calc::StackedCalculator; kwargs...) + return _stacked_efv(sys, calc) +end + +# ============================================================================ +# Training Assembly Interface for StackedCalculator +# ============================================================================ + +import ACEfit + +""" + length_basis(calc::StackedCalculator) + +Return total number of linear parameters across all stacked calculators. +""" +function length_basis(calc::StackedCalculator) + return sum(length_basis(c) for c in calc.calcs) +end + +ACEfit.basis_size(calc::StackedCalculator) = length_basis(calc) + +""" + energy_forces_virial_basis(sys::AbstractSystem, calc::StackedCalculator) + +Compute concatenated basis for all stacked calculators. +""" +function energy_forces_virial_basis(sys::AbstractSystem, calc::StackedCalculator) + # Collect basis from each calculator + results = [energy_forces_virial_basis(sys, c) for c in calc.calcs] + + natoms = length(sys) + + # Concatenate results - energy is Vector of Quantity{Float64} + E_basis = vcat([_strip_energy_units(r.energy) for r in results]...) + + # For forces, need to hcat the matrices + # Strip units element by element for matrices of SVectors with units + F_parts = [_strip_force_units(r.forces) for r in results] + F_basis = isempty(F_parts) ? zeros(SVector{3, Float64}, natoms, 0) : hcat(F_parts...) + + # Virial is Vector of SMatrix with units + V_basis = vcat([_strip_virial_units(r.virial) for r in results]...) + + return ( + energy = E_basis * u"eV", + forces = F_basis .* u"eV/Å", + virial = V_basis * u"eV" + ) +end + +# Helper to strip units from energy (Vector of Quantity{Float64}) +function _strip_energy_units(E) + return map(e -> ustrip(e), E) +end + +# Helper to strip units from force matrices (Matrix of SVector with units) +function _strip_force_units(F) + # F is Matrix{SVector{3, Quantity}} + # We need to strip the units from the inner SVectors + return map(f -> SVector{3, Float64}(ustrip.(f)), F) +end + +# Helper to strip units from virial (Vector of SMatrix with units) +function _strip_virial_units(V) + # V is Vector{SMatrix{3,3, Quantity}} + return map(v -> SMatrix{3, 3, Float64, 9}(ustrip.(v)), V) +end + +""" + potential_energy_basis(sys::AbstractSystem, calc::StackedCalculator) + +Compute concatenated energy basis for all stacked calculators. +""" +function potential_energy_basis(sys::AbstractSystem, calc::StackedCalculator) + results = [potential_energy_basis(sys, c) for c in calc.calcs] + E_basis = vcat([ustrip.(u"eV", r) for r in results]...) + return E_basis * u"eV" +end + +""" + get_linear_parameters(calc::StackedCalculator) + +Get concatenated linear parameters from all stacked calculators. +""" +function get_linear_parameters(calc::StackedCalculator) + return vcat([get_linear_parameters(c) for c in calc.calcs]...) +end + +""" + set_linear_parameters!(calc::StackedCalculator, θ::AbstractVector) + +Set linear parameters for all stacked calculators from concatenated vector. +""" +function set_linear_parameters!(calc::StackedCalculator, θ::AbstractVector) + offset = 0 + for c in calc.calcs + n = length_basis(c) + if n > 0 + set_linear_parameters!(c, θ[offset+1:offset+n]) + end + offset += n + end + @assert offset == length(θ) "Parameter count mismatch" + return calc +end diff --git a/src/json_interface.jl b/src/json_interface.jl index 5459782df..57768cbb0 100644 --- a/src/json_interface.jl +++ b/src/json_interface.jl @@ -10,7 +10,7 @@ using Optimisers: destructure recursive_dict2nt(x) = x -recursive_dict2nt(D::Dict) = (; +recursive_dict2nt(D::AbstractDict) = (; [ Symbol(key) => recursive_dict2nt(D[key]) for key in keys(D)]... ) function _sanitize_arg(arg) @@ -40,12 +40,12 @@ end # === make fits === """ - make_model(model_dict::Dict) + make_model(model_dict::AbstractDict) User-facing script to generate a model from a dictionary. See documentation for details. """ -function make_model(model_dict::Dict) +function make_model(model_dict::AbstractDict) if model_dict["model_name"] == "ACE1" model_nt = _sanitize_dict(model_dict) return ACE1compat.ace1_model(; model_nt...) @@ -55,7 +55,7 @@ function make_model(model_dict::Dict) end # chho: make this support other solvers -function make_solver(model, solver_dict::Dict, prior_dict::Dict) +function make_solver(model, solver_dict::AbstractDict, prior_dict::AbstractDict) # if no prior is specified, then use I as default, which is dumb if isempty(prior_dict) @@ -81,7 +81,7 @@ function make_solver(model, solver_dict::Dict, prior_dict::Dict) end # calles into functions defined in ACEpotentials.Models -function make_prior(model, prior_dict::Dict) +function make_prior(model, prior_dict::AbstractDict) return ACEpotentials.Models.make_prior(model, namedtuple(prior_dict)) end diff --git a/src/models/ace.jl b/src/models/ace.jl index 826f0ebbc..b7ebe0c62 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -276,7 +276,6 @@ function evaluate(model::ACEModel, # contract with params val = dot(B, (@view ps.WB[:, i_z0])) - # ------------------- # pair potential diff --git a/src/models/ace_heuristics.jl b/src/models/ace_heuristics.jl index b1432497b..bc58afa6f 100644 --- a/src/models/ace_heuristics.jl +++ b/src/models/ace_heuristics.jl @@ -102,9 +102,13 @@ end -_convert_E0s(E0s::Union{Dict, NamedTuple}) = E0s +_convert_E0s(E0s::Union{AbstractDict, NamedTuple}) = E0s _convert_E0s(E0s::Union{AbstractVector, Tuple}) = Dict(E0s...) -_convert_E0s(E0s) = error("E0s must be nothing, a NamedTuple, Dict or list of pairs") +_convert_E0s(E0s) = error( +""" +E0s must be nothing, a NamedTuple, Dict or list of pairs. +Instead it is a $(typeof(E0s)). +""") # E0s can be anything with (key, value) pairs _make_Vref_E0s(elements, E0s) = OneBody(_convert_E0s(E0s)) diff --git a/src/models/models.jl b/src/models/models.jl index 303ab5d3c..5bbeb88e2 100644 --- a/src/models/models.jl +++ b/src/models/models.jl @@ -18,7 +18,9 @@ import LuxCore: AbstractLuxLayer, initialparameters, initialstates -function length_basis end +function length_basis end +function energy_forces_virial_basis end +function potential_energy_basis end include("elements.jl") @@ -45,7 +47,7 @@ include("smoothness_priors.jl") include("utils.jl") -include("fasteval.jl") +# include("fasteval.jl") diff --git a/src/models/radial_envelopes.jl b/src/models/radial_envelopes.jl index 3174e2bdb..00c51aec7 100644 --- a/src/models/radial_envelopes.jl +++ b/src/models/radial_envelopes.jl @@ -4,14 +4,9 @@ abstract type AbstractEnvelope end struct PolyEnvelope1sR{T} rcut::T p::Int - # ------- - meta::Dict{String, Any} end -PolyEnvelope1sR(rcut, p) = - PolyEnvelope1sR(rcut, p, Dict{String, Any}()) - function evaluate(env::PolyEnvelope1sR, r::T, x::T) where T if r >= env.rcut return zero(T) diff --git a/src/models/radial_transforms.jl b/src/models/radial_transforms.jl index 930f27b8b..4a6cc371f 100644 --- a/src/models/radial_transforms.jl +++ b/src/models/radial_transforms.jl @@ -15,11 +15,11 @@ write_dict(T::GeneralizedAgnesiTransform) = Dict("__id__" => "ACEpotentials_GeneralizedAgnesiTransform", "r0" => T.r0, "p" => T.p, "q" => T.q, "a" => T.a, "rin" => T.rin) -GeneralizedAgnesiTransform(D::Dict) = +GeneralizedAgnesiTransform(D::AbstractDict) = GeneralizedAgnesiTransform(D["r0"], D["p"], D["q"], D["a"], D["rin"]) -read_dict(::Val{:ACEpotentials_GeneralizedAgnesiTransform}, D::Dict) = +read_dict(::Val{:ACEpotentials_GeneralizedAgnesiTransform}, D::AbstractDict) = GeneralizedAgnesiTransform(D) function evaluate(t::GeneralizedAgnesiTransform{T}, r::Number) where {T} @@ -98,7 +98,7 @@ write_dict(T::NormalizedTransform) = "yin" => T.yin, "ycut" => T.ycut, "rin" => T.rin, "rcut" => T.rcut ) -read_dict(::Val{:ACEpotentials_NormalizedTransform}, D::Dict) = +read_dict(::Val{:ACEpotentials_NormalizedTransform}, D::AbstractDict) = NormalizedTransform(read_dict(D["trans"]), D["yin"], D["ycut"], D["rin"], D["rcut"]) diff --git a/test/Project.toml b/test/Project.toml index 37e4d10ba..5113a70ff 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,11 +1,17 @@ [deps] ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e" +ACEfit = "ad31a8ef-59f5-4a01-b543-a85c2f73e95c" +ACEpotentials = "3b96b61c-0fcc-4693-95ed-1ef9f35fcc53" AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a" AtomsBuilder = "f5cc8831-eeb7-4288-8d9f-d6c1ddb77004" AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1" AtomsCalculatorsUtilities = "9855a07e-8816-4d1b-ac92-859c17475477" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +DecoratedParticles = "023d0394-cb16-4d2d-a5c7-724bed42bbb6" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" EmpiricalPotentials = "38527215-9240-4c91-a638-d4250620c9e2" +EquivariantTensors = "5e107534-7145-4f8f-b06f-47a52840c895" ExtXYZ = "352459e4-ddd7-4360-8937-99dcb397b478" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" @@ -13,7 +19,9 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Polynomials4ML = "03c4bcba-a943-47e9-bfa1-b1661fc2974f" @@ -22,8 +30,13 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestEnv = "1e6cf692-eddd-4d53-88a5-2d735e33781b" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[sources] +ACEpotentials = {path = ".."} + [compat] +EquivariantTensors = "0.4.3" StaticArrays = "1" diff --git a/test/benchmark_comparison.jl b/test/benchmark_comparison.jl new file mode 100644 index 000000000..5ad9fea45 --- /dev/null +++ b/test/benchmark_comparison.jl @@ -0,0 +1,172 @@ +using ACEpotentials +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels + +import EquivariantTensors as ET +import AtomsCalculators +using StaticArrays, Lux, Random, LuxCore, LinearAlgebra +using AtomsBase, AtomsBuilder, Unitful +using BenchmarkTools +using Printf + +# GPU detection (simplified from ET test utils) +dev = identity +has_cuda = false + +try + using CUDA + if CUDA.functional() + @info "Using CUDA" + CUDA.versioninfo() + global has_cuda = true + global dev = cu + else + @info "CUDA is not functional" + end +catch e + @info "Couldn't load CUDA: $e" +end + +if !has_cuda + @info "No GPU available. Using CPU only." +end + +rng = Random.MersenneTwister(1234) + +# Build models +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 8 +order = 2 +maxl = 4 + +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal) + +ps, st = Lux.setup(rng, model) + +# Kill the pair basis for fair comparison +for s in model.pairbasis.splines + s.itp.itp.coefs[:] *= 0 +end + +# Create old ACE calculator +ace_calc = M.ACEPotential(model, ps, st) + +# Convert to ETACE +et_model = ETM.convert2et(model) +et_ps, et_st = LuxCore.setup(rng, et_model) + +# Copy parameters +et_ps.rembed.post.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1] +et_ps.rembed.post.W[:, :, 2] = ps.rbasis.Wnlq[:, :, 1, 2] +et_ps.rembed.post.W[:, :, 3] = ps.rbasis.Wnlq[:, :, 2, 1] +et_ps.rembed.post.W[:, :, 4] = ps.rbasis.Wnlq[:, :, 2, 2] +et_ps.readout.W[1, :, 1] .= ps.WB[:, 1] +et_ps.readout.W[1, :, 2] .= ps.WB[:, 2] + +rcut = maximum(a.rcut for a in model.pairbasis.rin0cuts) + +# GPU setup +has_gpu = has_cuda +if has_gpu + et_ps_32 = ET.float32(et_ps) + et_st_32 = ET.float32(et_st) + et_ps_gpu = dev(et_ps_32) + et_st_gpu = dev(et_st_32) +end + +# Function to create system of given size +function make_system(n_repeat) + sys = AtomsBuilder.bulk(:Si, cubic=true) * n_repeat + rattle!(sys, 0.1u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + return sys +end + +# Benchmark configurations (tuple for bulk multiplication) +configs = [ + (1, 1, 1), # 8 atoms + (2, 1, 1), # 16 atoms + (2, 2, 2), # 64 atoms + (3, 3, 2), # 144 atoms + (4, 4, 2), # 256 atoms + (4, 4, 4), # 512 atoms + (5, 5, 4), # 800 atoms +] + +println() +println("=" ^ 85) +println("BENCHMARK: ACE (CPU) vs ETACE (CPU) vs ETACE (GPU)") +println("=" ^ 85) +println() + +# Header +if has_gpu + println("| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | ETACE GPU (ms) | CPU Speedup | GPU Speedup |") + println("|-------|---------|--------------|----------------|----------------|-------------|-------------|") +else + println("| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | CPU Speedup |") + println("|-------|---------|--------------|----------------|-------------|") +end + +for cfg in configs + sys = make_system(cfg) + natoms = length(sys) + + # Count edges + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # Warmup ACE + _ = AtomsCalculators.potential_energy(sys, ace_calc) + + # Warmup ETACE CPU + _ = sum(et_model(G, et_ps, et_st)[1]) + + # Benchmark ACE CPU + t_ace = @belapsed AtomsCalculators.potential_energy($sys, $ace_calc) samples=5 evals=3 + t_ace_ms = t_ace * 1000 + + # Benchmark ETACE CPU (graph construction NOT included for fair comparison with GPU) + t_etace_cpu = @belapsed sum($et_model($G, $et_ps, $et_st)[1]) samples=5 evals=3 + t_etace_cpu_ms = t_etace_cpu * 1000 + + cpu_speedup = t_ace_ms / t_etace_cpu_ms + + if has_gpu + # Convert graph to GPU (Float32) + G_32 = ET.float32(G) + G_gpu = dev(G_32) + + # Warmup GPU + _ = sum(et_model(G_gpu, et_ps_gpu, et_st_gpu)[1]) + + # Benchmark ETACE GPU (graph already on GPU) + t_etace_gpu = @belapsed sum($et_model($G_gpu, $et_ps_gpu, $et_st_gpu)[1]) samples=5 evals=3 + t_etace_gpu_ms = t_etace_gpu * 1000 + + gpu_speedup = t_ace_ms / t_etace_gpu_ms + + @printf("| %5d | %7d | %12.2f | %14.2f | %14.2f | %10.1fx | %10.1fx |\n", + natoms, nedges, t_ace_ms, t_etace_cpu_ms, t_etace_gpu_ms, cpu_speedup, gpu_speedup) + else + @printf("| %5d | %7d | %12.2f | %14.2f | %10.1fx |\n", + natoms, nedges, t_ace_ms, t_etace_cpu_ms, cpu_speedup) + end +end + +println() +println("Notes:") +println("- ACE CPU: Original ACEpotentials model (pair basis zeroed for fair comparison)") +println("- ETACE CPU: EquivariantTensors backend on CPU (Float64)") +println("- ETACE GPU: EquivariantTensors backend on GPU (Float32)") +println("- CPU Speedup = ACE CPU / ETACE CPU") +println("- GPU Speedup = ACE CPU / ETACE GPU") +println("- Graph construction time NOT included (currently CPU-only)") diff --git a/test/benchmark_forces.jl b/test/benchmark_forces.jl new file mode 100644 index 000000000..afef3a864 --- /dev/null +++ b/test/benchmark_forces.jl @@ -0,0 +1,122 @@ +using ACEpotentials +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels + +import EquivariantTensors as ET +import AtomsCalculators +using StaticArrays, Lux, Random, LuxCore, LinearAlgebra +using AtomsBase, AtomsBuilder, Unitful +using BenchmarkTools +using Printf + +rng = Random.MersenneTwister(1234) + +# Build models +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 8 +order = 2 +maxl = 4 + +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal) + +ps, st = Lux.setup(rng, model) + +# Kill the pair basis for fair comparison +for s in model.pairbasis.splines + s.itp.itp.coefs[:] *= 0 +end + +# Create old ACE calculator +ace_calc = M.ACEPotential(model, ps, st) + +# Convert to ETACE +et_model = ETM.convert2et(model) +et_ps, et_st = LuxCore.setup(rng, et_model) + +# Copy parameters +et_ps.rembed.post.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1] +et_ps.rembed.post.W[:, :, 2] = ps.rbasis.Wnlq[:, :, 1, 2] +et_ps.rembed.post.W[:, :, 3] = ps.rbasis.Wnlq[:, :, 2, 1] +et_ps.rembed.post.W[:, :, 4] = ps.rbasis.Wnlq[:, :, 2, 2] +et_ps.readout.W[1, :, 1] .= ps.WB[:, 1] +et_ps.readout.W[1, :, 2] .= ps.WB[:, 2] + +rcut = maximum(a.rcut for a in model.pairbasis.rin0cuts) + +# Function to create system of given size +function make_system(n_repeat) + sys = AtomsBuilder.bulk(:Si, cubic=true) * n_repeat + rattle!(sys, 0.1u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + return sys +end + +# ETACE forces function (CPU only for now) +function etace_forces(et_model, G, sys, et_ps, et_st) + ∂G = ETM.site_grads(et_model, G, et_ps, et_st) + return -ET.Atoms.forces_from_edge_grads(sys, G, ∂G.edge_data) +end + +# Benchmark configurations (tuple for bulk multiplication) +configs = [ + (1, 1, 1), # 8 atoms + (2, 1, 1), # 16 atoms + (2, 2, 2), # 64 atoms + (3, 3, 2), # 144 atoms + (4, 4, 2), # 256 atoms + (4, 4, 4), # 512 atoms + (5, 5, 4), # 800 atoms +] + +println() +println("=" ^ 70) +println("BENCHMARK: Forces - ACE (CPU) vs ETACE (CPU)") +println("=" ^ 70) +println() + +# Header +println("| Atoms | Edges | ACE CPU (ms) | ETACE CPU (ms) | CPU Speedup |") +println("|-------|---------|--------------|----------------|-------------|") + +for cfg in configs + sys = make_system(cfg) + natoms = length(sys) + + # Count edges + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + nedges = length(G.edge_data) + + # Warmup ACE + _ = AtomsCalculators.forces(sys, ace_calc) + + # Warmup ETACE CPU + _ = etace_forces(et_model, G, sys, et_ps, et_st) + + # Benchmark ACE CPU + t_ace = @belapsed AtomsCalculators.forces($sys, $ace_calc) samples=5 evals=3 + t_ace_ms = t_ace * 1000 + + # Benchmark ETACE CPU (graph construction NOT included for fair comparison) + t_etace_cpu = @belapsed etace_forces($et_model, $G, $sys, $et_ps, $et_st) samples=5 evals=3 + t_etace_cpu_ms = t_etace_cpu * 1000 + + cpu_speedup = t_ace_ms / t_etace_cpu_ms + + @printf("| %5d | %7d | %12.2f | %14.2f | %10.1fx |\n", + natoms, nedges, t_ace_ms, t_etace_cpu_ms, cpu_speedup) +end + +println() +println("Notes:") +println("- ACE CPU: Original ACEpotentials model (pair basis zeroed for fair comparison)") +println("- ETACE CPU: EquivariantTensors backend on CPU (Float64)") +println("- CPU Speedup = ACE CPU / ETACE CPU") +println("- Graph construction time NOT included") diff --git a/test/et_models/test_et_calculators.jl b/test/et_models/test_et_calculators.jl new file mode 100644 index 000000000..5b14a5b77 --- /dev/null +++ b/test/et_models/test_et_calculators.jl @@ -0,0 +1,776 @@ +# Tests for ETACEPotential calculator interface +# +# These tests verify: +# 1. Energy consistency between ETACE model and ETACEPotential +# 2. Force consistency against original ACE model +# 3. Virial consistency against original ACE model +# 4. AtomsCalculators interface compliance + +using Test, ACEbase, BenchmarkTools +using Polynomials4ML.Testing: print_tf, println_slim + +using ACEpotentials +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels + +import EquivariantTensors as ET +import AtomsCalculators + +using AtomsBase, AtomsBuilder, Unitful +using Random, LuxCore, StaticArrays, LinearAlgebra + +rng = Random.MersenneTwister(1234) +Random.seed!(1234) + +## +# Build an ETACE model for testing + +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 10 +order = 3 +maxl = 6 + +# Use same cutoff for all elements +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal) + +ps, st = LuxCore.setup(rng, model) + +# Kill pair basis for clarity (test only ACE part) +for s in model.pairbasis.splines + s.itp.itp.coefs[:] *= 0 +end + +# Convert to ETACE model +et_model = ETM.convert2et(model) +et_ps, et_st = LuxCore.setup(MersenneTwister(1234), et_model) + +# Match parameters using utility function +ETM.copy_ace_params!(et_ps, ps, model) + +# Get cutoff radius +rcut = maximum(a.rcut for a in model.pairbasis.rin0cuts) + +# Helper to generate random structures +function rand_struct() + sys = AtomsBuilder.bulk(:Si) * (2, 2, 1) + rattle!(sys, 0.2u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + return sys +end + +## + +@info("Testing ETACEPotential construction") + +# Create calculator from ETACE model +# ETACEPotential is now WrappedSiteCalculator{ETACE} (direct, no WrappedETACE) +et_calc = ETM.ETACEPotential(et_model, et_ps, et_st, rcut) + +# Access underlying ETACE directly via calc.model +@test et_calc.model === et_model +@test et_calc.rcut == rcut +@test et_calc.co_ps === nothing + +## + +@info("Testing energy consistency: ETACE model vs ETACEPotential") + +for ntest = 1:20 + local sys, G, E_model, E_calc + + sys = rand_struct() + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + + # Energy from direct model evaluation + Ei_model, _ = et_model(G, et_ps, et_st) + E_model = sum(Ei_model) + + # Energy from calculator + E_calc = AtomsCalculators.potential_energy(sys, et_calc) + + print_tf(@test abs(E_model - ustrip(E_calc)) < 1e-10) +end +println() + +## + +@info("Testing energy consistency: ETACE vs original ACE model") + +# Wrap original ACE model into calculator +calc_model = M.ACEPotential(model, ps, st) + +for ntest = 1:20 + local sys, E_old, E_new + + sys = rand_struct() + E_old = AtomsCalculators.potential_energy(sys, calc_model) + E_new = AtomsCalculators.potential_energy(sys, et_calc) + + print_tf(@test abs(ustrip(E_old) - ustrip(E_new)) < 1e-6) +end +println() + +## + +@info("Testing forces consistency: ETACE vs original ACE model") + +for ntest = 1:20 + local sys, F_old, F_new + + sys = rand_struct() + F_old = AtomsCalculators.forces(sys, calc_model) + F_new = AtomsCalculators.forces(sys, et_calc) + + # Compare force magnitudes (allow small numerical differences) + max_diff = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_old, F_new)) + print_tf(@test max_diff < 1e-6) +end +println() + +## + +@info("Testing virial consistency: ETACE vs original ACE model") + +for ntest = 1:20 + local sys, V_old, V_new + + sys = rand_struct() + efv_old = AtomsCalculators.energy_forces_virial(sys, calc_model) + efv_new = AtomsCalculators.energy_forces_virial(sys, et_calc) + + V_old = ustrip.(efv_old.virial) + V_new = ustrip.(efv_new.virial) + + # Compare virial tensors + print_tf(@test norm(V_old - V_new) / (norm(V_old) + 1e-10) < 1e-6) +end +println() + +## + +@info("Testing AtomsCalculators interface compliance") + +sys = rand_struct() + +# Test individual methods +E = AtomsCalculators.potential_energy(sys, et_calc) +F = AtomsCalculators.forces(sys, et_calc) +V = AtomsCalculators.virial(sys, et_calc) + +@test E isa typeof(1.0u"eV") +@test eltype(F) <: StaticArrays.SVector +@test V isa StaticArrays.SMatrix + +## + +@info("Testing combined energy_forces_virial efficiency") + +sys = rand_struct() + +# Combined evaluation +efv1 = AtomsCalculators.energy_forces_virial(sys, et_calc) + +# Separate evaluations +E = AtomsCalculators.potential_energy(sys, et_calc) +F = AtomsCalculators.forces(sys, et_calc) +V = AtomsCalculators.virial(sys, et_calc) + +@test ustrip(efv1.energy) ≈ ustrip(E) +@test all(ustrip.(efv1.forces) .≈ ustrip.(F)) +@test ustrip.(efv1.virial) ≈ ustrip.(V) + +## + +@info("Testing cutoff_radius function") + +@test ETM.cutoff_radius(et_calc) == rcut * u"Å" + +## + +@info("All Phase 1 tests passed!") + +# ============================================================================ +# Phase 2 Tests: WrappedSiteCalculator and StackedCalculator +# ============================================================================ + +@info("Testing Phase 2: WrappedSiteCalculator and StackedCalculator") + +## + +@info("Testing ETOneBody (upstream one-body model)") + +using Lux + +# Create ETOneBody model with reference energies (using upstream interface) +E0_Si = -0.846 +E0_O = -2.15 +et_onebody = ETM.one_body(Dict(:Si => E0_Si, :O => E0_O), x -> x.z) +_, onebody_st = Lux.setup(rng, et_onebody) + +# Test site energies via direct model call +sys = rand_struct() +G = ET.Atoms.interaction_graph(sys, rcut * u"Å") +Ei_E0, _ = et_onebody(G, nothing, onebody_st) + +# Count Si and O atoms +n_Si = count(node -> node.z == AtomsBase.ChemicalSpecies(:Si), G.node_data) +n_O = count(node -> node.z == AtomsBase.ChemicalSpecies(:O), G.node_data) +expected_E0 = n_Si * E0_Si + n_O * E0_O + +@test length(Ei_E0) == length(sys) +@test sum(Ei_E0) ≈ expected_E0 + +# Test site gradients (should be empty for constant energies) +# Returns NamedTuple with empty edge_data, matching ETACE/ETPairModel interface +∂G_E0 = ETM.site_grads(et_onebody, G, nothing, onebody_st) +@test isempty(∂G_E0.edge_data) + +## + +@info("Testing WrappedSiteCalculator with ETOneBody") + +# Wrap ETOneBody in a calculator (using new unified interface) +E0_calc = ETM.WrappedSiteCalculator(et_onebody, nothing, onebody_st, 3.0) +@test ustrip(u"Å", ETM.cutoff_radius(E0_calc)) == 3.0 # minimum cutoff + +# Test ETOneBody calculator energy +sys = rand_struct() +E_E0_calc = AtomsCalculators.potential_energy(sys, E0_calc) +G = ET.Atoms.interaction_graph(sys, 3.0 * u"Å") +n_Si = count(node -> node.z == AtomsBase.ChemicalSpecies(:Si), G.node_data) +n_O = count(node -> node.z == AtomsBase.ChemicalSpecies(:O), G.node_data) +expected_E = (n_Si * E0_Si + n_O * E0_O) * u"eV" +@test ustrip(E_E0_calc) ≈ ustrip(expected_E) + +# Test ETOneBody calculator forces (should be zero) +F_E0_calc = AtomsCalculators.forces(sys, E0_calc) +@test all(norm(ustrip.(f)) < 1e-14 for f in F_E0_calc) + +## + +@info("Testing WrappedSiteCalculator with ETACE") + +# Wrap ETACE model in a calculator (unified interface) +ace_site_calc = ETM.WrappedSiteCalculator(et_model, et_ps, et_st, rcut) +@test ustrip(u"Å", ETM.cutoff_radius(ace_site_calc)) == rcut + +# Test ETACE calculator matches ETACEPotential +sys = rand_struct() +E_ace_site = AtomsCalculators.potential_energy(sys, ace_site_calc) +E_ace_pot = AtomsCalculators.potential_energy(sys, et_calc) +@test ustrip(E_ace_site) ≈ ustrip(E_ace_pot) + +F_ace_site = AtomsCalculators.forces(sys, ace_site_calc) +F_ace_pot = AtomsCalculators.forces(sys, et_calc) +max_diff = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_ace_site, F_ace_pot)) +@test max_diff < 1e-10 + +## + +@info("Testing StackedCalculator construction") + +# Create stacked calculator with E0 + ACE (both wrapped) +stacked = ETM.StackedCalculator((E0_calc, ace_site_calc)) + +@test ustrip(u"Å", ETM.cutoff_radius(stacked)) == rcut +@test length(stacked.calcs) == 2 + +## + +@info("Testing StackedCalculator energy consistency") + +for ntest = 1:10 + local sys, E_stacked, E_separate + + sys = rand_struct() + + # Energy from stacked calculator + E_stacked = AtomsCalculators.potential_energy(sys, stacked) + + # Energy from separate evaluations + E_E0 = AtomsCalculators.potential_energy(sys, E0_calc) + E_ace = AtomsCalculators.potential_energy(sys, ace_site_calc) + E_separate = E_E0 + E_ace + + print_tf(@test ustrip(E_stacked) ≈ ustrip(E_separate)) +end +println() + +## + +@info("Testing StackedCalculator forces consistency") + +for ntest = 1:10 + local sys, F_stacked, F_ace_only, max_diff + + sys = rand_struct() + + # Forces from stacked calculator + F_stacked = AtomsCalculators.forces(sys, stacked) + + # Forces from ACE-only (E0 has zero forces) + F_ace_only = AtomsCalculators.forces(sys, et_calc) + + # Should be identical since E0 contributes zero forces + max_diff = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_stacked, F_ace_only)) + print_tf(@test max_diff < 1e-10) +end +println() + +## + +@info("Testing StackedCalculator virial consistency") + +for ntest = 1:10 + local sys, efv_stacked, efv_ace_only + + sys = rand_struct() + + efv_stacked = AtomsCalculators.energy_forces_virial(sys, stacked) + efv_ace_only = AtomsCalculators.energy_forces_virial(sys, et_calc) + + # Virial should match (E0 has zero virial) + V_stacked = ustrip.(efv_stacked.virial) + V_ace_only = ustrip.(efv_ace_only.virial) + + print_tf(@test norm(V_stacked - V_ace_only) / (norm(V_ace_only) + 1e-10) < 1e-10) +end +println() + +## + +@info("Testing StackedCalculator with ETOneBody only") + +# Create stacked calculator with just ETOneBody (E0_calc is WrappedSiteCalculator{ETOneBody}) +E0_only_stacked = ETM.StackedCalculator((E0_calc,)) + +sys = rand_struct() +E = AtomsCalculators.potential_energy(sys, E0_only_stacked) +F = AtomsCalculators.forces(sys, E0_only_stacked) + +# Energy should match E0_calc +E_direct = AtomsCalculators.potential_energy(sys, E0_calc) +@test ustrip(E) ≈ ustrip(E_direct) + +# Forces should be zero +@test all(norm(ustrip.(f)) < 1e-14 for f in F) + +## + +@info("All Phase 2 tests passed!") + +## ============================================================================ +## Phase 5: Training Assembly Tests +## ============================================================================ + +@info("Testing Phase 5: Training assembly functions") + +## + +@info("Testing length_basis") +nparams = ETM.length_basis(et_calc) +nbasis = et_model.readout.in_dim +nspecies = et_model.readout.ncat +@test nparams == nbasis * nspecies + +## + +@info("Testing get/set_linear_parameters round-trip") +θ_orig = ETM.get_linear_parameters(et_calc) +@test length(θ_orig) == nparams + +# Modify and restore +θ_test = randn(nparams) +ETM.set_linear_parameters!(et_calc, θ_test) +θ_check = ETM.get_linear_parameters(et_calc) +@test θ_check ≈ θ_test + +# Restore original +ETM.set_linear_parameters!(et_calc, θ_orig) +@test ETM.get_linear_parameters(et_calc) ≈ θ_orig + +## + +@info("Testing potential_energy_basis") +sys = rand_struct() +E_basis = ETM.potential_energy_basis(sys, et_calc) +@test length(E_basis) == nparams +@test eltype(ustrip.(E_basis)) <: Real + +## + +@info("Testing energy_forces_virial_basis") +efv_basis = ETM.energy_forces_virial_basis(sys, et_calc) +natoms = length(sys) + +@test length(efv_basis.energy) == nparams +@test size(efv_basis.forces) == (natoms, nparams) +@test length(efv_basis.virial) == nparams + +## + +@info("Testing linear combination gives correct energy") + +# E = dot(E_basis, θ) should match potential_energy +θ = ETM.get_linear_parameters(et_calc) +E_from_basis = dot(ustrip.(efv_basis.energy), θ) +E_direct = ustrip(u"eV", AtomsCalculators.potential_energy(sys, et_calc)) + +print_tf(@test E_from_basis ≈ E_direct rtol=1e-10) +println() + +## + +@info("Testing linear combination gives correct forces") + +# F = efv_basis.forces * θ should match forces +F_from_basis = efv_basis.forces * θ +F_direct = AtomsCalculators.forces(sys, et_calc) + +max_diff = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_from_basis, F_direct)) +print_tf(@test max_diff < 1e-10) +println() + +## + +@info("Testing linear combination gives correct virial") + +# V = sum(θ .* virial) should match virial +V_from_basis = sum(θ[i] * ustrip.(efv_basis.virial[i]) for i in 1:nparams) +V_direct = ustrip.(AtomsCalculators.virial(sys, et_calc)) + +virial_diff = maximum(abs.(V_from_basis - V_direct)) +print_tf(@test virial_diff < 1e-10) +println() + +## + +@info("Testing potential_energy_basis matches energy from efv_basis") +@test ustrip.(E_basis) ≈ ustrip.(efv_basis.energy) + +## + +@info("All Phase 5 basic tests passed!") + +## ============================================================================ +## Phase 5b: Extended Training Assembly Tests +## ============================================================================ + +@info("Testing Phase 5b: Extended training assembly tests") + +## + +@info("Testing ACEfit.basis_size integration") +import ACEfit +@test ACEfit.basis_size(et_calc) == ETM.length_basis(et_calc) + +## + +@info("Testing training assembly on multiple structures") + +# Generate multiple random structures +nstructs = 5 +test_systems = [rand_struct() for _ in 1:nstructs] + +all_ok = true +for (i, sys) in enumerate(test_systems) + local θ = ETM.get_linear_parameters(et_calc) + local efv_basis = ETM.energy_forces_virial_basis(sys, et_calc) + + # Check energy + local E_from_basis = dot(ustrip.(efv_basis.energy), θ) + local E_direct = ustrip(u"eV", AtomsCalculators.potential_energy(sys, et_calc)) + if !isapprox(E_from_basis, E_direct, rtol=1e-10) + @warn "Energy mismatch on structure $i" + all_ok = false + end + + # Check forces + local F_from_basis = efv_basis.forces * θ + local F_direct = AtomsCalculators.forces(sys, et_calc) + local max_diff = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_from_basis, F_direct)) + if max_diff >= 1e-10 + @warn "Force mismatch on structure $i: max_diff = $max_diff" + all_ok = false + end + + # Check virial + local V_from_basis = sum(θ[k] * ustrip.(efv_basis.virial[k]) for k in 1:length(θ)) + local V_direct = ustrip.(AtomsCalculators.virial(sys, et_calc)) + local virial_diff = maximum(abs.(V_from_basis - V_direct)) + if virial_diff >= 1e-10 + @warn "Virial mismatch on structure $i: max_diff = $virial_diff" + all_ok = false + end +end +print_tf(@test all_ok) +println() + +## + +@info("Testing multi-species parameter ordering") + +# Create structures with varying species compositions +# Pure Si +sys_si = AtomsBuilder.bulk(:Si) * (2, 2, 1) +rattle!(sys_si, 0.1u"Å") + +# Pure O (use Si lattice but with O atoms) +sys_o = AtomsBuilder.bulk(:Si) * (2, 2, 1) +rattle!(sys_o, 0.1u"Å") +AtomsBuilder.randz!(sys_o, [:O => 1.0]) + +# Mixed 50/50 +sys_mixed = AtomsBuilder.bulk(:Si) * (2, 2, 1) +rattle!(sys_mixed, 0.1u"Å") +AtomsBuilder.randz!(sys_mixed, [:Si => 0.5, :O => 0.5]) + +# Mixed 25/75 +sys_mixed2 = AtomsBuilder.bulk(:Si) * (2, 2, 1) +rattle!(sys_mixed2, 0.1u"Å") +AtomsBuilder.randz!(sys_mixed2, [:Si => 0.25, :O => 0.75]) + +species_test_systems = [sys_si, sys_o, sys_mixed, sys_mixed2] +species_labels = ["Pure Si", "Pure O", "50/50 Si:O", "25/75 Si:O"] + +all_species_ok = true +for (label, sys) in zip(species_labels, species_test_systems) + local θ = ETM.get_linear_parameters(et_calc) + local efv_basis = ETM.energy_forces_virial_basis(sys, et_calc) + + # Check energy consistency + local E_from_basis = dot(ustrip.(efv_basis.energy), θ) + local E_direct = ustrip(u"eV", AtomsCalculators.potential_energy(sys, et_calc)) + + if !isapprox(E_from_basis, E_direct, rtol=1e-10) + @warn "Energy mismatch for $label: basis=$E_from_basis, direct=$E_direct" + all_species_ok = false + end + + # Check forces + local F_from_basis = efv_basis.forces * θ + local F_direct = AtomsCalculators.forces(sys, et_calc) + local max_diff = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_from_basis, F_direct)) + + if max_diff >= 1e-10 + @warn "Force mismatch for $label: max_diff=$max_diff" + all_species_ok = false + end +end +print_tf(@test all_species_ok) +println() + +## + +@info("Testing species-specific basis contributions") + +# Verify that different species contribute to different parts of the basis +nbasis = et_model.readout.in_dim +nspecies = et_model.readout.ncat + +# For pure Si system, only Si basis should contribute +efv_si = ETM.energy_forces_virial_basis(sys_si, et_calc) +E_basis_si = ustrip.(efv_si.energy) + +# For pure O system, only O basis should contribute +efv_o = ETM.energy_forces_virial_basis(sys_o, et_calc) +E_basis_o = ustrip.(efv_o.energy) + +# Check that the patterns differ (different species activate different parameters) +# Si uses parameters 1:nbasis, O uses parameters (nbasis+1):(2*nbasis) +si_params = E_basis_si[1:nbasis] +o_params_for_si = E_basis_si[(nbasis+1):end] +o_params = E_basis_o[(nbasis+1):end] +si_params_for_o = E_basis_o[1:nbasis] + +# Pure Si should have zero contribution from O parameters +@test all(abs.(o_params_for_si) .< 1e-12) +# Pure O should have zero contribution from Si parameters +@test all(abs.(si_params_for_o) .< 1e-12) +# Pure Si should have nonzero Si parameters +@test any(abs.(si_params) .> 1e-12) +# Pure O should have nonzero O parameters +@test any(abs.(o_params) .> 1e-12) + +## + +@info("All Phase 5b extended tests passed!") + +## ============================================================================ +## Phase 5c: Training Assembly for ETPairPotential, ETOneBodyPotential, StackedCalculator +## ============================================================================ + +@info("Testing Phase 5c: Training assembly for pair, onebody, and stacked calculators") + +## + +@info("Testing ETOneBodyPotential training assembly (empty - no learnable params)") + +# Create ETOneBody calculator +E0s = model.Vref.E0 +zlist = ChemicalSpecies.(model.rbasis._i2z) +E0_dict = Dict(z => E0s[z.atomic_number] for z in zlist) +et_onebody = ETM.one_body(E0_dict, x -> x.z) +_, onebody_st = Lux.setup(rng, et_onebody) +onebody_calc = ETM.WrappedSiteCalculator(et_onebody, nothing, onebody_st, 3.0) + +# Test length_basis returns 0 +@test ETM.length_basis(onebody_calc) == 0 + +# Test energy_forces_virial_basis returns empty arrays +sys = rand_struct() +efv_onebody = ETM.energy_forces_virial_basis(sys, onebody_calc) +@test length(efv_onebody.energy) == 0 +@test size(efv_onebody.forces, 2) == 0 +@test length(efv_onebody.virial) == 0 + +# Test get/set_linear_parameters +@test length(ETM.get_linear_parameters(onebody_calc)) == 0 +ETM.set_linear_parameters!(onebody_calc, Float64[]) # Should not error + +# Test ACEfit.basis_size +@test ACEfit.basis_size(onebody_calc) == 0 + +## + +@info("Testing ETPairPotential training assembly") + +# Need a model with learnable pair basis for this test +# Create a new model with pair_learnable=true +elements_pair = (:Si, :O) +level_pair = M.TotalDegree() +max_level_pair = 10 +order_pair = 3 +maxl_pair = 4 + +rin0cuts_pair = M._default_rin0cuts(elements_pair) +rin0cuts_pair = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts_pair) + +model_pair = M.ace_model(; elements = elements_pair, order = order_pair, + Ytype = :solid, level = level_pair, max_level = max_level_pair, + maxl = maxl_pair, pair_maxn = max_level_pair, + rin0cuts = rin0cuts_pair, + pair_learnable = true, + init_WB = :glorot_normal, init_Wpair = :glorot_normal) +ps_pair, st_pair = Lux.setup(rng, model_pair) + +# Convert pair potential +et_pair = ETM.convertpair(model_pair) +et_pair_ps, et_pair_st = Lux.setup(rng, et_pair) + +# Copy pair parameters using utility function +ETM.copy_pair_params!(et_pair_ps, ps_pair, model_pair) + +rcut_pair = maximum(a.rcut for a in model_pair.pairbasis.rin0cuts) +pair_calc = ETM.ETPairPotential(et_pair, et_pair_ps, et_pair_st, rcut_pair) + +# Test length_basis +pair_nbasis = et_pair.readout.in_dim +pair_nspecies = et_pair.readout.ncat +@test ETM.length_basis(pair_calc) == pair_nbasis * pair_nspecies + +# Test energy_forces_virial_basis +sys_pair = rand_struct() # Uses Si/O system from earlier +efv_pair = ETM.energy_forces_virial_basis(sys_pair, pair_calc) +natoms_pair = length(sys_pair) +nparams_pair = ETM.length_basis(pair_calc) + +@test length(efv_pair.energy) == nparams_pair +@test size(efv_pair.forces) == (natoms_pair, nparams_pair) +@test length(efv_pair.virial) == nparams_pair + +# Test linear combination gives correct energy +θ_pair = ETM.get_linear_parameters(pair_calc) +E_from_pair_basis = dot(ustrip.(efv_pair.energy), θ_pair) +E_pair_direct = ustrip(u"eV", AtomsCalculators.potential_energy(sys_pair, pair_calc)) +print_tf(@test E_from_pair_basis ≈ E_pair_direct rtol=1e-10) +println() + +# Test get/set round-trip +θ_pair_test = randn(nparams_pair) +ETM.set_linear_parameters!(pair_calc, θ_pair_test) +@test ETM.get_linear_parameters(pair_calc) ≈ θ_pair_test +ETM.set_linear_parameters!(pair_calc, θ_pair) # Restore + +# Test ACEfit.basis_size +@test ACEfit.basis_size(pair_calc) == nparams_pair + +## + +@info("Testing StackedCalculator training assembly") + +# Create a StackedCalculator with E0 + Pair + ManyBody (using convert2et_full) +stacked_calc = ETM.convert2et_full(model_pair, ps_pair, st_pair) + +# Verify structure: 3 components (ETOneBody, ETPairModel, ETACE) +@test length(stacked_calc.calcs) == 3 + +# Test length_basis is sum of components +n_onebody = ETM.length_basis(stacked_calc.calcs[1]) +n_pair = ETM.length_basis(stacked_calc.calcs[2]) +n_ace = ETM.length_basis(stacked_calc.calcs[3]) +n_total = ETM.length_basis(stacked_calc) + +@test n_onebody == 0 # ETOneBody has no learnable params +@test n_pair > 0 +@test n_ace > 0 +@test n_total == n_onebody + n_pair + n_ace + +# Test energy_forces_virial_basis +sys_stacked = rand_struct() +efv_stacked = ETM.energy_forces_virial_basis(sys_stacked, stacked_calc) +natoms_stacked = length(sys_stacked) + +@test length(efv_stacked.energy) == n_total +@test size(efv_stacked.forces) == (natoms_stacked, n_total) +@test length(efv_stacked.virial) == n_total + +# Test linear combination gives correct energy +θ_stacked = ETM.get_linear_parameters(stacked_calc) +@test length(θ_stacked) == n_total +E_from_stacked_basis = dot(ustrip.(efv_stacked.energy), θ_stacked) +E_stacked_direct = ustrip(u"eV", AtomsCalculators.potential_energy(sys_stacked, stacked_calc)) +print_tf(@test E_from_stacked_basis ≈ E_stacked_direct rtol=1e-10) +println() + +# Test linear combination gives correct forces +F_from_stacked_basis = efv_stacked.forces * θ_stacked +F_stacked_direct = AtomsCalculators.forces(sys_stacked, stacked_calc) +max_diff_stacked_F = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_from_stacked_basis, F_stacked_direct)) +print_tf(@test max_diff_stacked_F < 1e-10) +println() + +# Test linear combination gives correct virial +V_from_stacked_basis = sum(θ_stacked[k] * ustrip.(efv_stacked.virial[k]) for k in 1:n_total) +V_stacked_direct = ustrip.(AtomsCalculators.virial(sys_stacked, stacked_calc)) +virial_diff_stacked = maximum(abs.(V_from_stacked_basis - V_stacked_direct)) +print_tf(@test virial_diff_stacked < 1e-10) +println() + +# Test get/set_linear_parameters round-trip +θ_stacked_orig = copy(θ_stacked) +θ_stacked_test = randn(n_total) +ETM.set_linear_parameters!(stacked_calc, θ_stacked_test) +θ_stacked_check = ETM.get_linear_parameters(stacked_calc) +@test θ_stacked_check ≈ θ_stacked_test +ETM.set_linear_parameters!(stacked_calc, θ_stacked_orig) # Restore + +# Test potential_energy_basis consistency +E_basis_stacked = ETM.potential_energy_basis(sys_stacked, stacked_calc) +@test length(E_basis_stacked) == n_total +@test ustrip.(E_basis_stacked) ≈ ustrip.(efv_stacked.energy) rtol=1e-10 + +# Test ACEfit.basis_size +@test ACEfit.basis_size(stacked_calc) == n_total + +## + +@info("All Phase 5c tests passed!") diff --git a/test/et_models/test_et_silicon.jl b/test/et_models/test_et_silicon.jl new file mode 100644 index 000000000..204a52f99 --- /dev/null +++ b/test/et_models/test_et_silicon.jl @@ -0,0 +1,231 @@ +# Integration test for ETACE calculators +# +# This test verifies that ETACE calculators produce identical results +# to the many-body part of ACE models (excluding pair potential). +# +# Note: convert2et only supports LearnableRnlrzzBasis (not SplineRnlrzzBasis), +# so we use ace_model() directly instead of ace1_model(). +# ETACE implements only the many-body basis, not the pair potential. + +using Test +using ACEpotentials +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels + +using ExtXYZ, AtomsBase, Unitful, StaticArrays +using AtomsCalculators +using LazyArtifacts +using LuxCore, Lux, Random, LinearAlgebra + +@info("ETACE Integration Test: Silicon dataset") + +## ----- setup ----- + +# Build model using ace_model (LearnableRnlrzzBasis, compatible with convert2et) +elements = (:Si,) +level = M.TotalDegree() +max_level = 12 +order = 3 +maxl = 6 + +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) + +rng = Random.MersenneTwister(1234) + +ace_model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal) + +ps, st = Lux.setup(rng, ace_model) + +# Create ACE calculator +model = M.ACEPotential(ace_model, ps, st) + +# Load dataset +data = ExtXYZ.load(artifact"Si_tiny_dataset" * "/Si_tiny.xyz") + +data_keys = [:energy_key => "dft_energy", + :force_key => "dft_force", + :virial_key => "dft_virial"] + +weights = Dict("default" => Dict("E"=>30.0, "F"=>1.0, "V"=>1.0), + "liq" => Dict("E"=>10.0, "F"=>0.66, "V"=>0.25)) + +## ----- Fit original ACE model ----- + +@info("Fitting original ACE model with QR solver") +acefit!(data, model; + data_keys..., + weights = weights, + solver = ACEfit.QR()) + +ace_err = ACEpotentials.compute_errors(data, model; data_keys..., weights=weights) +@info("Original ACE RMSE (set):", + E=ace_err["rmse"]["set"]["E"], + F=ace_err["rmse"]["set"]["F"], + V=ace_err["rmse"]["set"]["V"]) + +# Store for comparison +ace_rmse_E = ace_err["rmse"]["set"]["E"] +ace_rmse_F = ace_err["rmse"]["set"]["F"] +ace_rmse_V = ace_err["rmse"]["set"]["V"] + +## ----- Convert to ETACE and compare ----- + +@info("Converting to ETACE model") + +# Update ps from model after fitting +ps = model.ps +st = model.st + +# Convert to ETACE +et_model = ETM.convert2et(ace_model) +et_ps, et_st = LuxCore.setup(MersenneTwister(1234), et_model) + +# Copy radial basis parameters (single species case) +et_ps.rembed.post.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1] + +# Copy readout parameters +et_ps.readout.W[1, :, 1] .= ps.WB[:, 1] + +# Get cutoff +rcut = maximum(a.rcut for a in ace_model.pairbasis.rin0cuts) + +# Create ETACEPotential +et_calc = ETM.ETACEPotential(et_model, et_ps, et_st, rcut) + +# Create ACE model WITHOUT pair potential for fair comparison +# (ETACE only implements the many-body basis, not the pair potential) +ps_nopair = merge(ps, (Wpair = zeros(size(ps.Wpair)),)) +model_nopair = M.ACEPotential(ace_model, ps_nopair, st) + +## ----- Test energy consistency ----- + +@info("Testing energy consistency between ACE (no pair) and ETACE") + +# Skip isolated atom (index 1) - ETACE requires at least 2 atoms for graph construction +max_energy_diff = 0.0 +for (i, sys) in enumerate(data[2:min(11, length(data))]) + local E_ace = ustrip(u"eV", AtomsCalculators.potential_energy(sys, model_nopair)) + local E_etace = ustrip(u"eV", AtomsCalculators.potential_energy(sys, et_calc)) + local diff = abs(E_ace - E_etace) + global max_energy_diff = max(max_energy_diff, diff) +end + +@info("Max energy difference: $max_energy_diff eV") +@test max_energy_diff < 1e-10 +println("Energy consistency: OK (max_diff = $max_energy_diff eV)") + +## ----- Test forces consistency ----- + +@info("Testing forces consistency between ACE (no pair) and ETACE") + +max_force_diff = 0.0 +for (i, sys) in enumerate(data[2:min(10, length(data))]) + F_ace = AtomsCalculators.forces(sys, model_nopair) + F_etace = AtomsCalculators.forces(sys, et_calc) + for (f1, f2) in zip(F_ace, F_etace) + diff = norm(ustrip.(f1) - ustrip.(f2)) + global max_force_diff = max(max_force_diff, diff) + end +end + +@info("Max force difference: $max_force_diff eV/Å") +@test max_force_diff < 1e-10 +println("Forces consistency: OK (max_diff = $max_force_diff eV/Å)") + +## ----- Test virial consistency ----- + +@info("Testing virial consistency between ACE (no pair) and ETACE") + +max_virial_diff = 0.0 +for (i, sys) in enumerate(data[2:min(10, length(data))]) + V_ace = AtomsCalculators.virial(sys, model_nopair) + V_etace = AtomsCalculators.virial(sys, et_calc) + diff = maximum(abs.(ustrip.(V_ace) - ustrip.(V_etace))) + global max_virial_diff = max(max_virial_diff, diff) +end + +@info("Max virial difference: $max_virial_diff eV") +@test max_virial_diff < 1e-9 +println("Virial consistency: OK (max_diff = $max_virial_diff eV)") + +## ----- Test training basis assembly ----- + +@info("Testing training basis assembly") + +# Pick a test structure +sys = data[5] +natoms = length(sys) +nparams = ETM.length_basis(et_calc) + +# Get basis +efv_basis = ETM.energy_forces_virial_basis(sys, et_calc) + +# Verify shapes +@test length(efv_basis.energy) == nparams +@test size(efv_basis.forces) == (natoms, nparams) +@test length(efv_basis.virial) == nparams + +# Verify linear combination matches direct evaluation +θ = ETM.get_linear_parameters(et_calc) + +E_from_basis = dot(ustrip.(efv_basis.energy), θ) +E_direct = ustrip(u"eV", AtomsCalculators.potential_energy(sys, et_calc)) +@test isapprox(E_from_basis, E_direct, rtol=1e-10) + +F_from_basis = efv_basis.forces * θ +F_direct = AtomsCalculators.forces(sys, et_calc) +max_F_diff = maximum(norm(ustrip.(f1) - ustrip.(f2)) for (f1, f2) in zip(F_from_basis, F_direct)) +@test max_F_diff < 1e-10 + +V_from_basis = sum(θ[k] * ustrip.(efv_basis.virial[k]) for k in 1:nparams) +V_direct = ustrip.(AtomsCalculators.virial(sys, et_calc)) +max_V_diff = maximum(abs.(V_from_basis - V_direct)) +@test max_V_diff < 1e-9 + +println("Training basis assembly: OK") + +## ----- Test StackedCalculator with ETOneBody ----- + +@info("Testing StackedCalculator with ETOneBody") + +# Create ETOneBody model with arbitrary E0 value for testing (upstream interface) +E0s = Dict(:Si => -158.54496821) # Si symbol => E0 +et_onebody = ETM.one_body(E0s, x -> x.z) +_, onebody_st = Lux.setup(rng, et_onebody) +E0_calc = ETM.WrappedSiteCalculator(et_onebody, nothing, onebody_st, 3.0) + +# Create wrapped ETACE (unified interface) +ace_calc = ETM.WrappedSiteCalculator(et_model, et_ps, et_st, rcut) + +# Stack them +stacked = ETM.StackedCalculator((E0_calc, ace_calc)) + +# Test on a few structures (skip isolated atom) +for (i, sys) in enumerate(data[2:5]) + E_E0 = AtomsCalculators.potential_energy(sys, E0_calc) + E_ace = AtomsCalculators.potential_energy(sys, ace_calc) + E_stacked = AtomsCalculators.potential_energy(sys, stacked) + + expected = ustrip(E_E0) + ustrip(E_ace) + actual = ustrip(E_stacked) + + @test isapprox(expected, actual, rtol=1e-10) +end + +println("StackedCalculator: OK") + +## ----- Summary ----- + +@info("All ETACE integration tests passed!") +@info("Summary:") +@info(" - Energy matches ACE (many-body only) to < 1e-10 eV") +@info(" - Forces match ACE (many-body only) to < 1e-10 eV/Å") +@info(" - Virial matches ACE (many-body only) to < 1e-9 eV") +@info(" - Training basis assembly verified") +@info(" - StackedCalculator composition verified") +@info("Note: ETACE implements only the many-body basis, not the pair potential.") diff --git a/test/etmodels/test_etace.jl b/test/etmodels/test_etace.jl new file mode 100644 index 000000000..96dc3b87b --- /dev/null +++ b/test/etmodels/test_etace.jl @@ -0,0 +1,311 @@ +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using TestEnv; TestEnv.activate(); +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "Polynomials4ML.jl")) +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "DecoratedParticles")) + +## + +using ACEpotentials, StaticArrays, Lux, AtomsBase, AtomsBuilder, Unitful, + AtomsCalculators, Random, LuxCore, Test, LinearAlgebra, ACEbase + +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels +import EquivariantTensors as ET +import Polynomials4ML as P4ML +import DecoratedParticles as DP + +using Polynomials4ML.Testing: print_tf, println_slim + +rng = Random.MersenneTwister(1234) +Random.seed!(1234) + +## + +# Generate an ACE model in the v0.8 style but +# - with fixed rcut. (relaxes this requirement later!!) +# - remove E0s +# - remove pair potential + +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 10 +order = 3 +maxl = 6 +rcut = 5.5 + +# modify rin0cuts to have same cutoff for all elements +# TODO: there is currently a bug with variable cutoffs +# (?is there? The radials seem fine? check again) +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = rcut)).(rin0cuts) + + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal) + +ps, st = Lux.setup(rng, model) + +# wrap the old ACE model into a calculator +calc_model = ACEpotentials.ACEPotential(model, ps, st) + + +# Missing issues: +# Vref = 0 => this will not be tested +# pair potential will also not be tested + +# kill the pair basis for now +for s in model.pairbasis.splines + s.itp.itp.coefs[:] *= 0 +end + +## +# +# Convert the v0.8 model to an ET backend based model based on the +# implementation in ETM +# +et_model = ETM.convert2et(model) +et_ps, et_st = LuxCore.setup(MersenneTwister(1234), et_model) + +## +# fixup all the parameters to make sure they match +# the basis ordering appears to be identical, but it is not clear it really +# is because meta["mb_spec"] only gives the original ordering before basis +# construction ... something to look into. +nnll = M.get_nnll_spec(model.tensor) +et_nnll = et_model.basis.meta["mb_spec"] +@info("Check basis ordering") +println_slim(@test nnll == et_nnll) + +# but this is also identical ... +@info("Check symmetrization operator") +@show ( model.tensor.A2Bmaps[1] == et_model.basis.A2Bmaps[1] ) + +# radial basis parameters for et_model +et_ps.rembed.post.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1] +et_ps.rembed.post.W[:, :, 2] = ps.rbasis.Wnlq[:, :, 1, 2] +et_ps.rembed.post.W[:, :, 3] = ps.rbasis.Wnlq[:, :, 2, 1] +et_ps.rembed.post.W[:, :, 4] = ps.rbasis.Wnlq[:, :, 2, 2] + +# many-body basis parameters for et_model +et_ps.readout.W[1, :, 1] .= ps.WB[:, 1] +et_ps.readout.W[1, :, 2] .= ps.WB[:, 2] + +## + +# setup two splined ACE models + +spl_50 = ETM.splinify(et_model, et_ps, et_st; Nspl = 50) +ps_50, st_50 = Lux.setup(rng, spl_50) +ps_50.readout.W[:] .= et_ps.readout.W[:] + +spl_200 = ETM.splinify(et_model, et_ps, et_st; Nspl = 200) +ps_200, st_200 = Lux.setup(rng, spl_200) +ps_200.readout.W[:] .= et_ps.readout.W[:] + +## + + +function rand_struct() + sys = AtomsBuilder.bulk(:Si) * (2,2,1) + rattle!(sys, 0.2u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + return sys +end + +function energy_new(sys, _model, _ps, _st) + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + Ei, _ = _model(G, _ps, _st) + return sum(Ei) +end + +## + +Random.seed!(1234) +@info("Check total energies match") +for ntest = 1:30 + sys = rand_struct() + E1 = ustrip(AtomsCalculators.potential_energy(sys, calc_model)) + E2 = energy_new(sys, et_model, et_ps, et_st) + E3 = energy_new(sys, spl_50, ps_50, st_50) + E4 = energy_new(sys, spl_200, ps_200, st_200) + print_tf( @test abs(E1 - E2) < 1e-6 ) + print_tf( @test abs(E2 - E3) / (1+abs(E2)+abs(E3)) < 1e-2 ) + print_tf( @test abs(E2 - E4) / (1+abs(E2)+abs(E4)) < 1e-4 ) +end +println() + +## +# +# Zygote gradient +# +using Zygote, ForwardDiff + +sys = rand_struct() +G = ET.Atoms.interaction_graph(sys, rcut * u"Å") +∂G2a = Zygote.gradient(G -> sum(et_model(G, et_ps, et_st)[1]), G)[1] +∂G2b = ETM.site_grads(et_model, G, et_ps, et_st) + +∂G_50 = ETM.site_grads(spl_50, G, ps_50, st_50) +∂G_200 = ETM.site_grads(spl_200, G, ps_200, st_200) + +@info("confirm consistency of Zygote and site_grads") +println(@test all(∂G2a.edge_data .≈ ∂G2b.edge_data)) + +err_50 = maximum(norm.(∂G2b.edge_data - ∂G_50.edge_data) ./ (1 .+ norm.(∂G2b.edge_data) .+ norm.(∂G_50.edge_data))) +err_200 = maximum(norm.(∂G2b.edge_data - ∂G_200.edge_data) ./ (1 .+ norm.(∂G2b.edge_data) .+ norm.(∂G_200.edge_data))) +println_slim(@test err_50 < 1) +println_slim(@test err_200 < 0.01) + +## +# test gradient against ForwardDiff + +function grad_fd(G, model, ps, st) + function _replace_edges(X, Rmat) + Rsvec = [ SVector{3}(Rmat[:, i]) for i in 1:size(Rmat, 2) ] + new_edgedata = [ DP.PState(𝐫 = 𝐫, z0 = x.z0, z1 = x.z1, 𝐒 = x.𝐒) + for (𝐫, x) in zip(Rsvec, G.edge_data) ] + return ET.ETGraph( X.ii, X.jj, X.first, + X.node_data, new_edgedata, X.graph_data, + X.maxneigs ) + end + + function _energy(Rmat) + G_new = _replace_edges(G, Rmat) + return sum(model(G_new, ps, st)[1]) + end + + Rsvec = [ x.𝐫 for x in G.edge_data ] + Rmat = reinterpret(reshape, eltype(Rsvec[1]), Rsvec) + ∇E_fd = ForwardDiff.gradient(_energy, Rmat) + ∇E_svec = [ SVector{3}(∇E_fd[:, i]) for i in 1:size(∇E_fd, 2) ] + ∇E_edges = [ DP.VState(; 𝐫 = 𝐫) for 𝐫 in ∇E_svec ] + return ET.ETGraph( G.ii, G.jj, G.first, + G.node_data, ∇E_edges, G.graph_data, + G.maxneigs ) +end + +@info("confirm consistency of gradients with ForwardDiff") + +∇E_fd = grad_fd(G, et_model, et_ps, et_st) +println(@test all(∇E_fd.edge_data .≈ ∂G2b.edge_data)) + +## +# +# sys = rand_struct() +@info("Testing basis and jacobian") + +G = ET.Atoms.interaction_graph(sys, rcut * u"Å") +nnodes = length(G.node_data) +iZ = et_model.readout.selector.(G.node_data) +WW = et_ps.readout.W + +𝔹1 = ETM.site_basis(et_model, G, et_ps, et_st) +𝔹2, ∂𝔹2 = ETM.site_basis_jacobian(et_model, G, et_ps, et_st) + +## + +@info("confirm correctness of site basis") + +println_slim(@test 𝔹1 ≈ 𝔹2) +Ei_a = [ dot(𝔹2[i, :], WW[1, :, iZ[i]]) for (i, iz) in enumerate(iZ) ] +Ei_b = et_model(G, et_ps, et_st)[1][:] +println_slim(@test Ei_a ≈ Ei_b) + +## + +@info("splined site basis") +𝔹_200 = ETM.site_basis(spl_200, G, ps_200, st_200) +𝔹2_200, ∂𝔹2_200 = ETM.site_basis_jacobian(spl_200, G, ps_200, st_200) + +println_slim(@test 𝔹_200 ≈ 𝔹2_200 ) +println_slim(@test norm(𝔹1 - 𝔹_200, Inf) < 3e-3) +println_slim(@test maximum(norm.(∂𝔹2 - ∂𝔹2_200)) < 0.1) + +## + +@info("Confirm correctness of Jacobian against gradient") +# compute the gradient from the jacobian by hand +# size(𝔹2) = (num_nodes, basis_len) +# size(∂𝔹2) = (num_edges, num_nodes, basislen) + +∇Ei2 = reduce( hcat, ∂𝔹2[:, i, :] * WW[1, :, iZ[i]] + for (i, iz) in enumerate(iZ) ) +∇Ei3 = reshape(∇Ei2, size(∇Ei2)..., 1) +∇E_𝔹_edges = ET.rev_reshape_embedding(∇Ei3, G)[:] +println(@test all(∇E_𝔹_edges .≈ ∂G2b.edge_data)) + + +## +# +# demo GPU evaluation +# + +# +# turning off this test until we figure out how to do proper CI on GPUs? +# until then this just needs to be done manually and locally? + +#= + +@info("Checking GPU evaluation with Metal.jl") + +# TODO: replace Metal with generic GPU test +using Metal +dev = Metal.mtl + +sys = rand_struct() +G = ET.Atoms.interaction_graph(sys, rcut * u"Å") +G_32 = ET.float32(G) + +# move all data to the device +G_32_dev = dev(G_32) +ps_dev = dev(ET.float32(et_ps)) +st_dev = dev(ET.float32(et_st)) +ps_32 = ET.float32(et_ps) +st_32 = ET.float32(et_st) + +E1 = ustrip(AtomsCalculators.potential_energy(sys, calc_model)) +E4 = sum(et_model(G_32_dev, ps_dev, st_dev)[1]) +println_slim( @test abs(E1 - E4) / (abs(E1) + abs(E4) + 1e-7) < 1e-5 ) + +# Something still wrong evaluating the splines on GPU +# ps_50_dev = dev(ET.float32(ps_50)) +# st_50_dev = dev(ET.float32(st_50)) +# E5 = sum(spl_50(G_32_dev, ps_50_dev, st_50_dev)[1]) + +## +# gradients on GPU + +@info("Check Evaluation of gradient on GPU") +g1 = ETM.site_grads(et_model, G_32, ps_32, st_32) +g2_dev = ETM.site_grads(et_model, G_32_dev, ps_dev, st_dev) +∇1 = g1.edge_data +∇2 = Array(g2_dev.edge_data) +println_slim( @test all(∇1 .≈ ∇2) ) + +## + +@info("Basis evaluation on GPU") + +𝔹1 = ETM.site_basis(et_model, G_32, ps_32, st_32) +𝔹2_dev = ETM.site_basis(et_model, G_32_dev, ps_dev, st_dev) +𝔹2 = Array(𝔹2_dev) +println_slim( @test 𝔹1 ≈ 𝔹2 ) + +@info("Basis jacobian evaluation on GPU") +𝔹1, ∂𝔹1 = ETM.site_basis_jacobian(et_model, G_32, ps_32, st_32) +𝔹2_dev, ∂𝔹2_dev = ETM.site_basis_jacobian(et_model, G_32_dev, ps_dev, st_dev) + +𝔹2 = Array(𝔹2_dev) +∂𝔹2 = Array(∂𝔹2_dev) + +println_slim( @test 𝔹1 ≈ 𝔹2 ) +err_jac = norm.(∂𝔹1 - ∂𝔹2) ./ (norm.(∂𝔹1) + norm.(∂𝔹2) .+ 0.1) +println_slim( @test maximum(err_jac) < 1e-4 ) +@show maximum(err_jac) +@info("The jacobian error feels a bit large. This may need further investigation.") + +=# \ No newline at end of file diff --git a/test/etmodels/test_etonebody.jl b/test/etmodels/test_etonebody.jl new file mode 100644 index 000000000..8e7b18f5d --- /dev/null +++ b/test/etmodels/test_etonebody.jl @@ -0,0 +1,163 @@ +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using TestEnv; TestEnv.activate(); +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "EquivariantTensors.jl")) +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "Polynomials4ML.jl")) +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "DecoratedParticles")) + +## + +using ACEpotentials, StaticArrays, Lux, AtomsBase, AtomsBuilder, Unitful, + AtomsCalculators, Random, LuxCore, Test, LinearAlgebra, ACEbase + +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels +import EquivariantTensors as ET +import Polynomials4ML as P4ML +import DecoratedParticles as DP + +using Polynomials4ML.Testing: print_tf, println_slim + +rng = Random.MersenneTwister(1234) +Random.seed!(1234) + +## + +# Generate an ACE model in the v0.8 style but +# - with fixed rcut. (relaxes this requirement later!!) +# - remove E0s +# - remove pair potential + +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 10 +order = 3 +maxl = 6 + +# modify rin0cuts to have same cutoff for all elements +# TODO: there is currently a bug with variable cutoffs +# (?is there? The radials seem fine? check again) +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts) + +E0s_ref = Dict(:Si => randn(), :O => randn()) + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal, + E0s = E0s_ref + ) + +ps, st = Lux.setup(rng, model) + +## + +V0 = model.Vref +E0s_z = V0.E0 # Dict{Int, Float64} +et_E0s = Dict( ChemicalSpecies(key) => val for (key, val) in E0s_z ) + +# let block is only needed to avoid type instability +catfun = let + x -> x.z +end +et_V0 = ETM.one_body(et_E0s, catfun) +ps, st = Lux.setup(rng, et_V0) + +## + + + +function rand_struct() + sys = AtomsBuilder.bulk(:Si) * (2,2,2) + rattle!(sys, 0.2u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + return sys +end + +function site_Es_old(V0, sys) + return [ M.eval_site(V0, [], [], atomic_number(sys, i)) + for i in 1:length(sys) ] +end + +function site_Es_et(et_V0, sys, args...) + G = ET.Atoms.interaction_graph(sys, 5.0 * u"Å") + return et_V0(G, args...) +end + +## + +@info("Confirm correctness of site energies") + +for ntest = 1:30 + sys = rand_struct() + Es = site_Es_old(V0, sys) + et_Es_a = site_Es_et(et_V0, sys) + et_Es_b, _ = site_Es_et(et_V0, sys, ps, st) + print_tf( @test Es ≈ et_Es_a ≈ et_Es_b ) +end +println() + +## + +@info("Confirm correctness of gradient") + +sys = rand_struct() +G = ET.Atoms.interaction_graph(sys, 5.0 * u"Å") +∂G1 = ETM.site_grads(et_V0, G, ps, st) + +# ETOneBody returns NamedTuple with empty edge_data array since there are no +# position-dependent gradients (energy only depends on atom types, not positions) +println_slim(@test ∂G1 isa NamedTuple) +println_slim(@test haskey(∂G1, :edge_data)) +println_slim(@test isempty(∂G1.edge_data)) + +## + +@info("Confirm correctness of basis and basis jacobian") + +𝔹1 = ETM.site_basis(et_V0, G, ps, st) +𝔹2, ∂𝔹2 = ETM.site_basis_jacobian(et_V0, G, ps, st) + +println_slim(@test size(𝔹1) == size(𝔹2) == (length(sys), 0)) +println_slim(@test size(∂𝔹2) == (ET.maxneigs(G), length(sys), 0)) + +## + +# turn off during CI -- need to sort out CI for GPU tests +#= +@info("Check GPU evaluation") +using Metal +dev = Metal.mtl +ps_32 = ET.float32(ps) +st_32 = ET.float32(st) +ps_dev = dev(ps_32) +st_dev = dev(st_32) + +sys = rand_struct() +G = ET.Atoms.interaction_graph(sys, 5.0 * u"Å") +G_32 = ET.float32(G) +G_dev = dev(G_32) + +E1, st = et_V0(G_32, ps_32, st_32) +E2_dev, st_dev = et_V0(G_dev, ps_dev, st_dev) +E2 = Array(E2_dev) +# TODO: add E1 ≈ E2 test?? + +g1 = ETM.site_grads(et_V0, G_32, ps_32, st_32) +g2_dev = ETM.site_grads(et_V0, G_dev, ps_dev, st_dev) +g2 = Array(g2_dev) +println_slim(@test g1 == g2) + +b1 = ETM.site_basis(et_V0, G_32, ps_32, st_32) +b2_dev = ETM.site_basis(et_V0, G_dev, ps_dev, st_dev) +b2 = Array(b2_dev) +println_slim(@test b1 == b2) + +b1, ∂db1 = ETM.site_basis_jacobian(et_V0, G_32, ps_32, st_32) +b2_dev, ∂db2_dev = ETM.site_basis_jacobian(et_V0, G_dev, ps_dev, st_dev) +b2 = Array(b2_dev) +∂db2 = Array(∂db2_dev) +println_slim(@test b1 == b2) +println_slim(@test ∂db1 == ∂db2) +=# \ No newline at end of file diff --git a/test/etmodels/test_etpair.jl b/test/etmodels/test_etpair.jl new file mode 100644 index 000000000..3084dd90e --- /dev/null +++ b/test/etmodels/test_etpair.jl @@ -0,0 +1,227 @@ +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using TestEnv; TestEnv.activate(); +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "Polynomials4ML.jl")) +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "DecoratedParticles")) + +## + +using ACEpotentials, StaticArrays, Lux, AtomsBase, AtomsBuilder, Unitful, + AtomsCalculators, Random, LuxCore, Test, LinearAlgebra, ACEbase + +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels +import EquivariantTensors as ET +import Polynomials4ML as P4ML +import DecoratedParticles as DP + +using Polynomials4ML.Testing: print_tf, println_slim + +rng = Random.MersenneTwister(1234) +Random.seed!(1234) + +## + +# Generate an ACE model in the v0.8 style but +# - with fixed rcut. (relaxe this requirement later!!) +# get the pair potential component, compare with ETPairModel +# make pair_learnable = true to prevent splinification. + +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 10 +order = 3 +maxl = 6 +rcut = 5.5 + +# modify rin0cuts to have same cutoff for all elements +# TODO: there is currently a bug with variable cutoffs +# (?is there? The radials seem fine? check again) +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = rcut)).(rin0cuts) + + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal, + pair_learnable = true ) + +ps, st = Lux.setup(rng, model) + +# confirm that the E0s are all zero +@assert all( collect(values(model.Vref.E0)) .== 0 ) + +# set the many-body parameters to zero to isolate the pair potential +ps.WB[:] .= 0 + +## +# +# construct an ETPairModel that is consistent with `model` +# fixup the parameters to match the ACE model + +et_pair = ETM.convertpair(model) +et_ps, et_st = Lux.setup(rng, et_pair) + +# radial basis parameters for et_model_2 +et_ps.rembed.rbasis.post.W[:, :, 1] = ps.pairbasis.Wnlq[:, :, 1, 1] +et_ps.rembed.rbasis.post.W[:, :, 2] = ps.pairbasis.Wnlq[:, :, 1, 2] +et_ps.rembed.rbasis.post.W[:, :, 3] = ps.pairbasis.Wnlq[:, :, 2, 1] +et_ps.rembed.rbasis.post.W[:, :, 4] = ps.pairbasis.Wnlq[:, :, 2, 2] + +# many-body basis parameters for et_model_2 +et_ps.readout.W[1, :, 1] .= ps.Wpair[:, 1] +et_ps.readout.W[1, :, 2] .= ps.Wpair[:, 2] + +## +# +# make a splined version of the et_pair model +# + +spl_50 = ETM.splinify(et_pair, et_ps, et_st; Nspl = 50) +ps_50, st_50 = Lux.setup(rng, spl_50) + +spl_200 = ETM.splinify(et_pair, et_ps, et_st; Nspl = 200) +ps_200, st_200 = Lux.setup(rng, spl_200) + +# many-body basis parameters for et_model_2 +ps_50.readout.W[:] = et_ps.readout.W +ps_200.readout.W[:] = et_ps.readout.W + + +## +# +# test energy evaluations +# + +calc_model = ACEpotentials.ACEPotential(model, ps, st) + +function rand_struct() + sys = AtomsBuilder.bulk(:Si) * (2,2,1) + rattle!(sys, 0.2u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + return sys +end + +function energy_new(sys, et_model, ps, st) + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + Ei, _ = et_model(G, ps, st) + return sum(Ei) +end + +## + +Random.seed!(1234) +@info("Check total energies match") +for ntest = 1:30 + sys = rand_struct() + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + E1 = AtomsCalculators.potential_energy(sys, calc_model) + E2 = energy_new(sys, et_pair, et_ps, et_st) + E_50 = energy_new(sys, spl_50, ps_50, st_50) + E_200 = energy_new(sys, spl_200, ps_200, st_200) + print_tf( @test abs(ustrip(E1) - ustrip(E2)) < 1e-6 ) + print_tf( @test abs(ustrip(E2) - ustrip(E_50)) < 1e-2 ) + print_tf( @test abs(ustrip(E2) - ustrip(E_200)) < 1e-4 ) +end + +## + +@info("Check gradients and jacobians") + +sys = rand_struct() +G = ET.Atoms.interaction_graph(sys, rcut * u"Å") +nnodes = length(G.node_data) +iZ = et_pair.readout.selector.(G.node_data) +WW = et_ps.readout.W + +# gradient of model w.r.t. positions +∂G = ETM.site_grads(et_pair, G, et_ps, et_st) +∂G_200 = ETM.site_grads(spl_200, G, ps_200, st_200) + +# basis +𝔹1 = ETM.site_basis(et_pair, G, et_ps, et_st) +𝔹1_200 = ETM.site_basis(spl_200, G, ps_200, st_200) + +# basis jacobian +𝔹2, ∂𝔹2 = ETM.site_basis_jacobian(et_pair, G, et_ps, et_st) +𝔹2_200, ∂𝔹2_200 = ETM.site_basis_jacobian(spl_200, G, ps_200, st_200) + +println_slim(@test 𝔹1 ≈ 𝔹2) +println_slim(@test 𝔹1_200 ≈ 𝔹2_200) +println_slim(@test norm(𝔹1 - 𝔹1_200, Inf) < 1e-4) + +∇Ei2 = reduce( hcat, ∂𝔹2[:, i, :] * WW[1, :, iZ[i]] + for (i, iz) in enumerate(iZ) ) +∇Ei3 = reshape(∇Ei2, size(∇Ei2)..., 1) +∇E_𝔹_edges = ET.rev_reshape_embedding(∇Ei3, G)[:] +println_slim(@test all(∇E_𝔹_edges .≈ ∂G.edge_data)) + +# check error in site energy gradients for splines +println_slim(@test maximum(norm.(∂G.edge_data - ∂G_200.edge_data)) < 1e-3) +# check error in basis jacobian for splines +println_slim(@test maximum(norm.(∂𝔹2 - ∂𝔹2_200)) < 1e-3) + +## + + +# turn off during CI -- need to sort out CI for GPU tests + +#= + +@info("Check GPU evaluation") +using Metal +dev = Metal.mtl +ps_32 = ET.float32(et_ps) +st_32 = ET.float32(et_st) +ps_dev = dev(ps_32) +st_dev = dev(st_32) + +sys = rand_struct() +G = ET.Atoms.interaction_graph(sys, 5.0 * u"Å") +G_32 = ET.float32(G) +G_dev = dev(G_32) + +E1, st = et_pair(G_32, ps_32, st_32) +E2_dev, st_dev = et_pair(G_dev, ps_dev, st_dev) +E2 = Array(E2_dev) +println_slim(@test E1 ≈ E2) + +## + +@info(" .... with splines") +ps_50_32 = ET.float32(ps_50) +st_50_32 = ET.float32(st_50) +ps_50_dev = dev(ET.float32(ps_50)) +st_50_dev = dev(ET.float32(st_50)) +E3a, _ = spl_50(G_32, ps_50_32, st_50_32) +E3b_dev, _ = spl_50(G_dev, ps_50_dev, st_50_dev) +E3b = Array(E3b_dev) +println_slim(@test E3a ≈ E3b) + +## + +@info(" .... gradients on GPU") +g1 = ETM.site_grads(et_pair, G_32, ps_32, st_32) +g2_dev = ETM.site_grads(et_pair, G_dev, ps_dev, st_dev) +g2_edge = Array(g2_dev.edge_data) +println_slim(@test all(g1.edge_data .≈ g2_edge)) + +@info(" .... basis on GPU") +b1 = ETM.site_basis(et_pair, G_32, ps_32, st_32) +b2_dev = ETM.site_basis(et_pair, G_dev, ps_dev, st_dev) +b2 = Array(b2_dev) +println_slim(@test b1 ≈ b2) + +b1, ∂db1 = ETM.site_basis_jacobian(et_pair, G_32, ps_32, st_32) +b2_dev, ∂db2_dev = ETM.site_basis_jacobian(et_pair, G_dev, ps_dev, st_dev) +b2 = Array(b2_dev) +∂db2 = Array(∂db2_dev) +println_slim(@test b1 ≈ b2) +jacerr = norm.(∂db1 .- ∂db2) ./ (1 .+ norm.(∂db1) + norm.(∂db2)) +@show maximum(jacerr) +println_slim( @test maximum(jacerr) < 1e-4 ) + +## +=# \ No newline at end of file diff --git a/test/etmodels/test_splines.jl b/test/etmodels/test_splines.jl new file mode 100644 index 000000000..f5cff75b4 --- /dev/null +++ b/test/etmodels/test_splines.jl @@ -0,0 +1,152 @@ +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using TestEnv; TestEnv.activate(); +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "Polynomials4ML.jl")) +# Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "DecoratedParticles")) + +## + +using ACEpotentials, StaticArrays, Lux, AtomsBase, AtomsBuilder, Unitful, + AtomsCalculators, Random, LuxCore, Test, LinearAlgebra, ACEbase, + ForwardDiff + +M = ACEpotentials.Models +ETM = ACEpotentials.ETModels +import EquivariantTensors as ET +import Polynomials4ML as P4ML +import DecoratedParticles as DP + +using Polynomials4ML.Testing: print_tf, println_slim + +rng = Random.MersenneTwister(1234) +Random.seed!(1234) + +## + +# Generate an ACE model in the v0.8 style but +# - with fixed rcut. (relaxe this requirement later!!) +# get the pair potential component, compare with ETPairModel +# make pair_learnable = true to prevent splinification. + +elements = (:Si, :O) +level = M.TotalDegree() +max_level = 10 +order = 3 +maxl = 6 +rcut = 5.5 + +# modify rin0cuts to have same cutoff for all elements +# TODO: there is currently a bug with variable cutoffs +# (?is there? The radials seem fine? check again) +rin0cuts = M._default_rin0cuts(elements) +rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = rcut)).(rin0cuts) + + +model = M.ace_model(; elements = elements, order = order, + Ytype = :solid, level = level, max_level = max_level, + maxl = maxl, pair_maxn = max_level, + rin0cuts = rin0cuts, + init_WB = :glorot_normal, init_Wpair = :glorot_normal, + pair_learnable = true ) + +Random.seed!(1234) # new seed to make sure the tests are consistent +ps, st = Lux.setup(rng, model) + +# confirm that the E0s are all zero +@assert all( collect(values(model.Vref.E0)) .== 0 ) + +# set the many-body parameters to zero to isolate the pair potential +ps.WB[:] .= 0 + +## +# +# construct an ETPairModel that is consistent with `model` +# fixup the parameters to match the ACE model + +et_pair = ETM.convertpair(model) +et_ps, et_st = Lux.setup(rng, et_pair) + +# radial basis parameters for et_model_2 +et_ps.rembed.rbasis.post.W[:, :, 1] = ps.pairbasis.Wnlq[:, :, 1, 1] +et_ps.rembed.rbasis.post.W[:, :, 2] = ps.pairbasis.Wnlq[:, :, 1, 2] +et_ps.rembed.rbasis.post.W[:, :, 3] = ps.pairbasis.Wnlq[:, :, 2, 1] +et_ps.rembed.rbasis.post.W[:, :, 4] = ps.pairbasis.Wnlq[:, :, 2, 2] + +# many-body basis parameters for et_model_2 +et_ps.readout.W[1, :, 1] .= ps.Wpair[:, 1] +et_ps.readout.W[1, :, 2] .= ps.Wpair[:, 2] + +## +# convert the pair basis to a splined version + +# overkill spline accuracy to check errors +Nspl = 200 + +# polynomial basis taking y = y(r) as input +polys_y = et_pair.rembed.layer.rbasis.basis +# weights for cat-1 +WW = et_ps.rembed.rbasis.post.W +splines = [ + P4ML.splinify( y -> WW[:, :, i] * polys_y(y), -1.0, 1.0, Nspl ) + for i in 1:size(WW, 3) ] +states = [ P4ML._init_luxstate(spl) for spl in splines ] +selector2 = et_pair.rembed.layer.rbasis.post.selector +trans_y = et_pair.rembed.layer.rbasis.trans +envelope = et_pair.rembed.layer.envelope + +spl_rbasis = ET.trans_splines(trans_y, splines, selector2, envelope) +ps_spl, st_spl = LuxCore.setup(rng, spl_rbasis) + +poly_rbasis = et_pair.rembed.layer +ps_poly = et_ps.rembed +st_poly = et_st.rembed + +## + +function rand_X() + sys = AtomsBuilder.bulk(:Si) * (2,2,1) + rattle!(sys, 0.2u"Å") + AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5]) + G = ET.Atoms.interaction_graph(sys, rcut * u"Å") + return G.edge_data +end + + +## + +@info("Checking spline accuracy against polynomial basis") +Random.seed!(1234) # new seed to make sure the tests are ok. +for ntest = 1:30 + X = rand_X() + P1, _ = poly_rbasis(X, ps_poly, st_poly) + P2, _ = spl_rbasis(X, ps_spl, st_spl) + spl_err = abs.(P1 - P2) ./ (abs.(P1) .+ abs.(P2) .+ 1) + # @show maximum(spl_err) + print_tf(@test maximum(spl_err) < 1e-5) + + (P1a, dP1), _ = ET.evaluate_ed(poly_rbasis, X, ps_poly, st_poly) + (P2a, dP2), _ = ET.evaluate_ed(spl_rbasis, X, ps_spl, st_spl) + print_tf(@test P2 ≈ P2a) + dspl_err = norm.(dP1 - dP2) ./ (1 .+ abs.(P1) + abs.(P2)) + # @show maximum(dspl_err) + print_tf(@test maximum(dspl_err) < 1e-3) +end + +## + +@info("Checking machine precision derivative accuracy ") +# NOTE: This test should really be in ET and not here ... + +X = rand_X() +rand_u() = ( u = (@SVector randn(3)); DP.VState(𝐫 = u/norm(u)) ) +U = [ rand_u() for _ = 1:length(X) ] + +f(t) = spl_rbasis(X + t * U, ps_spl, st_spl)[1] +df0 = ForwardDiff.derivative(f, 0.0) + +(P2a, dP2), _ = ET.evaluate_ed(spl_rbasis, X, ps_spl, st_spl) +dp = [ dot(U[i], dP2[i, j]) for i in 1:length(U), j = 1:size(dP2, 2) ] +println_slim(@test df0 ≈ dp) + + +## diff --git a/test/models/test_committee.jl b/test/models/test_committee.jl index e8291e884..c45d1ea11 100644 --- a/test/models/test_committee.jl +++ b/test/models/test_committee.jl @@ -1,5 +1,5 @@ -using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) ## using Test, ACEbase, LinearAlgebra diff --git a/test/models/test_learnable_Rnl.jl b/test/models/test_learnable_Rnl.jl index 9a1e64c52..3ec6b5a54 100644 --- a/test/models/test_learnable_Rnl.jl +++ b/test/models/test_learnable_Rnl.jl @@ -1,14 +1,18 @@ +using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) +using TestEnv; TestEnv.activate(); +Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "EquivariantTensors.jl")) +Pkg.develop(url = joinpath(@__DIR__(), "..", "..", "..", "Polynomials4ML.jl")) - -# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", "..")) -# using TestEnv; TestEnv.activate(); +## using ACEpotentials M = ACEpotentials.Models +import EquivariantTensors as ET -using Random, LuxCore, Test, LinearAlgebra, ACEbase -using Polynomials4ML.Testing: print_tf +using Random, LuxCore, Test, LinearAlgebra, ACEbase +using AtomsBase, StaticArrays +using Polynomials4ML.Testing: print_tf, println_slim rng = Random.MersenneTwister(1234) Random.seed!(1234) @@ -83,3 +87,60 @@ Rnl_spl, ∇Rnl_spl = M.evaluate_ed(basis_spl, r, Zi, Zj, ps_spl, st_spl) println_slim(@test norm(Rnl - Rnl_spl, Inf) < 1e-4 ) println_slim(@test norm(∇Rnl - ∇Rnl_spl, Inf) < 1e-2 ) +## +# +# test the conversion to a Lux style Rnl basis +# +et_rbasis = M._convert_Rnl_learnable(basis) +et_ps, et_st = LuxCore.setup(Random.default_rng(), et_rbasis) + +et_ps.connection.W[:, :, 1] = ps.Wnlq[:, :, 1, 1] +et_ps.connection.W[:, :, 2] = ps.Wnlq[:, :, 1, 2] +et_ps.connection.W[:, :, 3] = ps.Wnlq[:, :, 2, 1] +et_ps.connection.W[:, :, 4] = ps.Wnlq[:, :, 2, 2] + +for ntest = 1:50 + global ps, st, et_ps, et_st + r = 2.0 + 5 * rand() + Zi = rand(basis._i2z) + Zj = rand(basis._i2z) + xij = ( 𝐫 = SA[r,0.0,0.0], s0 = ChemicalSpecies(Zi), s1 = ChemicalSpecies(Zj) ) + R1 = basis(r, Zi, Zj, ps, st) + R2 = et_rbasis( xij, et_ps, et_st)[1] + print_tf(@test R1 ≈ R2) +end + +# batched test +for ntest = 1:10 + z0 = rand(basis._i2z) + xx = [ (𝐫 = SA[2.0 + 2 * rand(), 0.0, 0.0], + s0 = ChemicalSpecies(z0), + s1 = ChemicalSpecies(rand(basis._i2z))) for _ in 1:30 ] + rr = [ x.𝐫[1] for x in xx ] + Zjs = [ atomic_number(x.s1) for x in xx ] + R1 = M.evaluate_batched(basis, rr, z0, Zjs, ps, st) + R2 = et_rbasis( xx, et_ps, et_st)[1] + print_tf(@test R1 ≈ R2) +end + +## +# run on GPU +using Metal +dev = Metal.mtl + +et_rbasis = M._convert_Rnl_learnable(basis) +et_ps, et_st = LuxCore.setup(Random.default_rng(), et_rbasis) + +z0 = rand(basis._i2z) +xx = [ (𝐫 = SA[2.0 + 2 * rand(), 0.0, 0.0], + s0 = ChemicalSpecies(z0), + s1 = ChemicalSpecies(rand(basis._i2z))) for _ in 1:1000 ] + +xx_dev = dev(ET.float32.(xx)) +ps_dev = dev(ET.float32(et_ps)) +st_dev = dev(ET.float32(et_st)) + +R1 = et_rbasis(xx, et_ps, et_st)[1] +R2 = et_rbasis(xx_dev, ps_dev, st_dev)[1] +println_slim(@test Matrix(R2) ≈ R1) + diff --git a/test/models/test_radial_transforms.jl b/test/models/test_radial_transforms.jl index 02e1480a8..4f3757604 100644 --- a/test/models/test_radial_transforms.jl +++ b/test/models/test_radial_transforms.jl @@ -57,3 +57,6 @@ for trans in [trans_2_2, trans_2_4, trans_1_3] println_slim( @test ACEpotentials.Models.test_normalized_transform(trans_2_2) ) end +## + + diff --git a/test/runtests.jl b/test/runtests.jl index afe91e8b8..af4b1d070 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,6 +16,12 @@ using ACEpotentials, Test, LazyArtifacts # make sure miscellaneous and weird bugs @testset "Weird bugs" begin include("test_bugs.jl") end + # new ET backend tests + @testset "ET ACE" begin include("etmodels/test_etace.jl") end + @testset "ET OneBody" begin include("etmodels/test_etonebody.jl") end + @testset "ET Pair" begin include("etmodels/test_etpair.jl") end + @testset "ET Calculators" begin include("et_models/test_et_calculators.jl") end + # ACE1 compatibility tests # TODO: these tests need to be revived either by creating a JSON # of test data, or by updating ACE1/ACE1x/JuLIP to be compatible. diff --git a/test/test_bugs.jl b/test/test_bugs.jl index daaa26529..719964561 100644 --- a/test/test_bugs.jl +++ b/test/test_bugs.jl @@ -5,6 +5,7 @@ using ACEpotentials.ACE1compat: ace1_model using ACEpotentials.Models: ACEPotential, potential_energy using AtomsBuilder using Unitful +using ACEbase.Testing: println_slim @info(" ============== Testing for ACEpotentials #208 ================") @info(" On Julia 1.9 some energy computations were inconsistent. ") @@ -32,18 +33,23 @@ E_per_at = [ energy_per_at(model, i) for i = 1:10 ] maxdiff = maximum(abs(E_per_at[i] - E_per_at[j]) for i = 1:10, j = 1:10 ) @show maxdiff -# NOTE: Known failure on Julia 1.12 due to hash algorithm changes +# NOTE: +# Known failure on Julia 1.12 due to hash algorithm changes # Julia 1.12 introduces a new hash algorithm that affects Dict iteration order. # This causes basis function ordering issues during model construction, resulting # in catastrophic energy errors (~28 eV instead of <1e-9 eV). # TODO: Requires investigation of upstream EquivariantTensors package and/or -# comprehensive refactoring to use OrderedDict throughout the basis construction pipeline. +# comprehensive refactoring to use OrderedDict throughout the basis construction +# pipeline. +# The test does not fail on Julia 1.12, locally on Macbook, so make sure it +# passes CI on Linux before removing the exception here. if VERSION >= v"1.12" @test_broken ustrip(u"eV", maxdiff) < 1e-9 else @test ustrip(u"eV", maxdiff) < 1e-9 end + @info(" ============================================================") @info(" ============== Testing for no Eref bug ====================")