Skip to content
Snippets Groups Projects
Commit 08c52fda authored by James Graham's avatar James Graham
Browse files

Tidied functional forms code

parent ab22ad42
Branches
Tags
No related merge requests found
...@@ -115,31 +115,29 @@ class BondSet: ...@@ -115,31 +115,29 @@ class BondSet:
except AttributeError: except AttributeError:
self._default_fc = False self._default_fc = False
# Setup default functional forms
functional_forms = FunctionalForms() functional_forms = FunctionalForms()
self._functional_forms = {} if self._default_fc:
default_forms = ["MartiniDefaultLength", "MartiniDefaultAngle", "MartiniDefaultDihedral"]
else:
default_forms = ["Harmonic", "CosHarmonic", "Harmonic"]
self._functional_forms = [None, None]
self._functional_forms.extend(map(lambda x: functional_forms[x], default_forms))
try: try:
self._functional_forms[2] = functional_forms[options.length_form] self._functional_forms[2] = functional_forms[options.length_form]
except AttributeError: except AttributeError:
if self._default_fc: pass
self._functional_forms[2] = functional_forms.MartiniDefaultLength
else:
self._functional_forms[2] = functional_forms.Harmonic
try: try:
self._functional_forms[3] = functional_forms[options.angle_form] self._functional_forms[3] = functional_forms[options.angle_form]
except AttributeError: except AttributeError:
if self._default_fc: pass
self._functional_forms[3] = functional_forms.MartiniDefaultAngle
else:
self._functional_forms[3] = functional_forms.CosHarmonic
try: try:
self._functional_forms[4] = functional_forms[options.dihedral_form] self._functional_forms[4] = functional_forms[options.dihedral_form]
except AttributeError: except AttributeError:
if self._default_fc: pass
self._functional_forms[4] = functional_forms.MartiniDefaultDihedral
else:
self._functional_forms[4] = functional_forms.Harmonic
with CFG(filename) as cfg: with CFG(filename) as cfg:
for mol in cfg: for mol in cfg:
...@@ -150,6 +148,7 @@ class BondSet: ...@@ -150,6 +148,7 @@ class BondSet:
for atomlist in mol: for atomlist in mol:
try: try:
# TODO consider best way to override default func form # TODO consider best way to override default func form
# On per bond, or per type basis
func_form = functional_forms[atomlist[-1]] func_form = functional_forms[atomlist[-1]]
except AttributeError: except AttributeError:
func_form = self._functional_forms[len(atomlist)] func_form = self._functional_forms[len(atomlist)]
......
...@@ -5,10 +5,20 @@ from pycgtool.util import SimpleEnum ...@@ -5,10 +5,20 @@ from pycgtool.util import SimpleEnum
class FunctionalForms(object): class FunctionalForms(object):
"""
Class holding list of all defined functional forms for Boltzmann Inversion.
Creating an instance causes the Enum of functional forms to be updated with
all new subclasses of FunctionalForm. These may then be accessed by name,
either as attributes or using square brackets.
"""
FormsEnum = SimpleEnum.enum("FormsEnum") FormsEnum = SimpleEnum.enum("FormsEnum")
@classmethod @classmethod
def refresh(cls): def _refresh(cls):
"""
Update the functional forms Enum to include all new subclasses of FunctionalForm.
"""
enum_dict = cls.FormsEnum.as_dict() enum_dict = cls.FormsEnum.as_dict()
for subclass in FunctionalForm.__subclasses__(): for subclass in FunctionalForm.__subclasses__():
name = subclass.__name__ name = subclass.__name__
...@@ -19,6 +29,7 @@ class FunctionalForms(object): ...@@ -19,6 +29,7 @@ class FunctionalForms(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self._kwargs = kwargs self._kwargs = kwargs
type(self)._refresh()
def __getattr__(self, item): def __getattr__(self, item):
return type(self).FormsEnum[item].value return type(self).FormsEnum[item].value
...@@ -37,9 +48,23 @@ class FunctionalForms(object): ...@@ -37,9 +48,23 @@ class FunctionalForms(object):
class FunctionalForm(object, metaclass=abc.ABCMeta): class FunctionalForm(object, metaclass=abc.ABCMeta):
"""
Parent class of any functional form used in Boltzmann Inversion to convert variance to a force constant.
New functional forms must define a static __call__ method.
"""
@staticmethod @staticmethod
@abc.abstractstaticmethod @abc.abstractstaticmethod
def __call__(mean, var, temp): def __call__(mean, var, temp):
"""
Calculate force constant.
Abstract static method to be defined by all functional forms.
:param mean: Mean of internal coordinate distribution
:param var: Variance of internal coordinate distribution
:param temp: Temperature of simulation
:return: Calculated force constant
"""
pass pass
...@@ -73,7 +98,3 @@ class MartiniDefaultDihedral(FunctionalForm): ...@@ -73,7 +98,3 @@ class MartiniDefaultDihedral(FunctionalForm):
@staticmethod @staticmethod
def __call__(mean, var, temp): def __call__(mean, var, temp):
return 50. return 50.
FunctionalForms.refresh()
...@@ -16,7 +16,6 @@ class FunctionalFormTest(unittest.TestCase): ...@@ -16,7 +16,6 @@ class FunctionalFormTest(unittest.TestCase):
@staticmethod @staticmethod
def __call__(mean, var, temp): def __call__(mean, var, temp):
return "TestResult" return "TestResult"
FunctionalForms.refresh()
funcs = FunctionalForms() funcs = FunctionalForms()
self.assertIn("TestFunc", funcs) self.assertIn("TestFunc", funcs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment