Skip to content

Commit

Permalink
generic getters and setters
Browse files Browse the repository at this point in the history
  • Loading branch information
ACEsuit committed May 28, 2024
1 parent 939ebcc commit 99c6308
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 55 deletions.
123 changes: 70 additions & 53 deletions src/atomsbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,35 @@ export get_cell, AosSystem, SoaSystem, ChemicalElement
# an `Atom` is now just a `PState`, so we define
# accessors for the PState fields with canonical names.

symbol(::typeof(position)) = :𝐫
symbol(::typeof(velocity)) = :𝐯
symbol(::typeof(atomic_mass)) = :π‘š
symbol(::typeof(atomic_symbol)) = :𝑍

const _atom_syms = (𝐫 = position,
𝐯 = velocity,
π‘š = atomic_mass,
𝑍 = atomic_symbol)

position(atom::PState) = atom.𝐫
velocity(atom::PState) = atom.𝐯
atomic_mass(atom::PState) = atom.π‘š
atomic_symbol(atom::PState) = atom.𝑍 # this one I'm not sure about
# list of properties that DP knows about:
_list_of_properties = [
(:𝐫, :position),
(:𝐯, :velocity),
(:π‘š, :atomic_mass),
(:𝑍, :atomic_symbol),
(:π‘ž, :charge),
(:ΞΌ, :dipole),
(:𝐩, :momentum),
(:𝐸, :energy),
(:𝑀, :mass),
]

for (sym, name) in _list_of_properties
@eval $name(x::XState) = x.$sym
@eval symbol(::typeof($name)) = $(Meta.quot(sym))
end

atomic_number(atom::PState) = atomic_number(atomic_symbol(atom))

_post(a) = a
_post(sym::Symbol) = ChemicalElement(sym)
_post(p, a) = a
_post(::typeof(atomic_symbol), sym::Symbol) = ChemicalElement(sym)
_post(::typeof(atomic_symbol), z::Integer) = ChemicalElement(z)

"""
Generate an atom with the given properties.
"""
atom(at; properties = (position, atomic_mass, atomic_symbol)) =
PState((; [symbol(p) => _post(p(at)) for p in properties]...))
PState((; [symbol(p) => _post(p, p(at)) for p in properties]...))

# ---------------------------------------------------
# Array of Structs System
Expand Down Expand Up @@ -68,10 +72,10 @@ Base.length(at::AosSystem) = length(at.particles)
Base.getindex(at::AosSystem, i::Int) = at.particles[i]
Base.getindex(at::AosSystem, inds::AbstractVector) = at.particles[inds]

for f in (:position, :velocity, :atomic_mass, :atomic_symbol)
@eval $f(sys::AosSystem) = [ $f(x) for x in sys.particles ]
@eval $f(sys::AosSystem, i::Integer) = $f(sys.particles[i])
@eval $f(sys::AosSystem, inds::AbstractVector) = [$f(sys.particles[i]) for i in inds]
for (sym, name) in [ _list_of_properties; [(:ignore, :atomic_number) ] ]
@eval $name(sys::AosSystem) = [ $name(x) for x in sys.particles ]
@eval $name(sys::AosSystem, i::Integer) = $name(sys.particles[i])
@eval $name(sys::AosSystem, inds::AbstractVector) = [$name(sys.particles[i]) for i in inds]
end

# AtomsBase.
Expand All @@ -95,7 +99,7 @@ end
function SoaSystem(sys::AbstractSystem;
properties = (position, atomic_mass, atomic_symbol), )

arrays = (; [symbol(p) => _post.(p(sys)) for p in properties]... )
arrays = (; [symbol(p) => _post.(p, p(sys)) for p in properties]... )
cell = get_cell(sys)
D = AtomsBase.n_dimensions(cell)
return SoaSystem{D, typeof(cell), typeof(arrays)}(
Expand Down Expand Up @@ -130,13 +134,14 @@ end
Base.getindex(sys::SoaSystem, inds::AbstractVector{<: Integer}) =
[ sys[i] for i in inds ]

for f in (:position, :velocity, :atomic_mass, :atomic_symbol)
@eval $f(sys::SoaSystem) = getfield(sys.arrays, symbol($f))
@eval $f(sys::SoaSystem, i::Integer) = getfield(sys.arrays, symbol($f))[i]
@eval $f(sys::SoaSystem, inds::AbstractVector) = getfield(sys.arrays, symbol($f))[inds]

for (sym, name) in [ _list_of_properties; [(:ignore, :atomic_number) ] ]
@eval $name(sys::SoaSystem) = copy(sys.arrays.$sym)
@eval $name(sys::SoaSystem, i::Integer) = sys.arrays.$sym[i]
@eval $name(sys::SoaSystem, inds::AbstractVector) = sys.arrays.$sym[inds]
end

# AtomsBase.

get_cell(at::SoaSystem) = at.cell

for f in (:n_dimensions, :bounding_box, :boundary_conditions, :periodicity)
Expand All @@ -149,37 +154,49 @@ end
# ---------------------------------------------------------------
# Extension of the AtomsBase interface with setter functions

export set_position,
set_position!,
set_positions!,
set_bounding_box!

set_position(x::PState, 𝐫::SVector) = setproperty(x, :𝐫, 𝐫)

function set_position!(sys::AosSystem, i::Integer, 𝐫::SVector)
xi = sys.particles[i]
sys.particles[i] = set_position(xi, 𝐫)
return nothing
end

function set_positions!(sys::AosSystem, R::AbstractVector{<: SVector})
for i = 1:length(sys)
set_position!(sys, i, R[i])
for (sym, name) in _list_of_properties
set_name = Symbol("set_$name")
set_name_ip = Symbol("set_$(name)!")
set_names_ip = Symbol("set_$(name)s!")
@eval export $set_name, $set_name_ip, $set_names_ip
# e.g. set_position(x::PState, 𝐫) = setproperty(x, :𝐫, 𝐫)
@eval $set_name(x::PState, t) = setproperty(x, $(Meta.quot(sym)), t)
# e.g. set_position!(sys, i, 𝐫)
@eval begin
function $set_name_ip(sys::AosSystem, i::Integer, t)
xi = sys.particles[i]
sys.particles[i] = $set_name(xi, t)
return nothing
end
end
@eval begin
function $set_name_ip(sys::SoaSystem, i::Integer, t)
sys.arrays.$sym[i] = t
return nothing
end
end
# e.g. set_positions!(sys, X)
@eval begin
function $set_names_ip(sys::AosSystem, Rs)
for i = 1:length(sys)
$set_name_ip(sys, i, Rs[i])
end
end
return nothing
end
@eval begin
function $set_names_ip(sys::SoaSystem, R::AbstractVector)
copy!(sys.arrays.$sym, R)
return nothing
end
end
return nothing
end

# todo: set_atomic_number & co

function set_position!(sys::SoaSystem, i::Integer, 𝐫::SVector)
sys.arrays.𝐫[i] = 𝐫
return nothing
end

function set_positions!(sys::SoaSystem, R::AbstractVector{<: SVector})
copy!(sys.arrays.𝐫, R)
return nothing
end
# cell-level setter

export set_bounding_box!

function set_bounding_box!(sys::Union{SoaSystem, AosSystem}, bb)
cell = PCell(bb, periodicity(sys))
Expand Down
2 changes: 1 addition & 1 deletion src/states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import Base: *, +, -, zero, rand, randn, show, promote_rule, rtoldefault,

import LinearAlgebra: norm, promote_leaf_eltypes

export PState, VState, vstate_type
export PState, VState, vstate_type, setproperty

abstract type XState{NT <: NamedTuple} end

Expand Down
5 changes: 4 additions & 1 deletion test/test_atomsbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ display(x)
@test atomic_symbol(x) == x.𝑍
@test atomic_number(x) == 6


##
#convert an Atom

Expand All @@ -27,7 +28,9 @@ display(x)
@test x.𝐫 == position(x) == position(at)
@test x.π‘š == atomic_mass(x) == atomic_mass(at)
@test x.𝑍 == atomic_symbol(x) == atomic_symbol(at)

@test DP.symbol(position) == :𝐫
@test DP.symbol(atomic_mass) == :π‘š
@test DP.symbol(atomic_symbol) == :𝑍

##
# convert an entire system
Expand Down

0 comments on commit 99c6308

Please sign in to comment.