Skip to content

Commit

Permalink
make from_str staticmethods into classmethods (#3429)
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Oct 27, 2023
1 parent 368f51c commit 9ae6121
Show file tree
Hide file tree
Showing 24 changed files with 206 additions and 220 deletions.
6 changes: 3 additions & 3 deletions dev_scripts/potcar_scrambler.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ def to_file(self, filename: str):
with zopen(filename, "wt") as f:
f.write(self.scrambled_potcars_str)

@staticmethod
def from_file(input_filename: str, output_filename: str | None = None):
@classmethod
def from_file(cls, input_filename: str, output_filename: str | None = None):
psp = Potcar.from_file(input_filename)
psp_scrambled = PotcarScrambler(psp)
psp_scrambled = cls(psp)
if output_filename:
psp_scrambled.to_file(output_filename)
return psp_scrambled
Expand Down
4 changes: 2 additions & 2 deletions pymatgen/analysis/reaction_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ def from_dict(cls, d):
def from_string(cls, *args, **kwargs):
return cls.from_str(*args, **kwargs)

@staticmethod
def from_str(rxn_str):
@classmethod
def from_str(cls, rxn_str):
"""
Generates a balanced reaction from a string. The reaction must
already be balanced.
Expand Down
16 changes: 8 additions & 8 deletions pymatgen/core/periodic_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,8 +1011,8 @@ def from_string(cls, *args, **kwargs):
"""Use from_str instead."""
return cls.from_str(*args, **kwargs)

@staticmethod
def from_str(species_string: str) -> Species:
@classmethod
def from_str(cls, species_string: str) -> Species:
"""Returns a Species from a string representation.
Args:
Expand Down Expand Up @@ -1051,10 +1051,10 @@ def from_str(species_string: str) -> Species:

# but we need either an oxidation state or a property
if oxi is None and properties == {}:
raise ValueError("Invalid Species String")
raise ValueError("Invalid species string")

return Species(sym, 0 if oxi is None else oxi, **properties)
raise ValueError("Invalid Species String")
return cls(sym, 0 if oxi is None else oxi, **properties)
raise ValueError("Invalid species string")

def __repr__(self):
return f"Species {self}"
Expand Down Expand Up @@ -1294,8 +1294,8 @@ def symbol(self) -> str:
def __deepcopy__(self, memo):
return DummySpecies(self.symbol, self._oxi_state)

@staticmethod
def from_str(species_string: str) -> DummySpecies:
@classmethod
def from_str(cls, species_string: str) -> DummySpecies:
"""Returns a Dummy from a string representation.
Args:
Expand All @@ -1320,7 +1320,7 @@ def from_str(species_string: str) -> DummySpecies:
if m.group(4): # has Spin property
tokens = m.group(4).split("=")
properties = {tokens[0]: float(tokens[1])}
return DummySpecies(sym, oxi, **properties)
return cls(sym, oxi, **properties)
raise ValueError("Invalid DummySpecies String")

def as_dict(self) -> dict:
Expand Down
24 changes: 12 additions & 12 deletions pymatgen/electronic_structure/boltztrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1924,8 +1924,8 @@ def parse_cond_and_hall(path_dir, doping_levels=None):
carrier_conc,
)

@staticmethod
def from_files(path_dir, dos_spin=1):
@classmethod
def from_files(cls, path_dir, dos_spin=1):
"""Get a BoltztrapAnalyzer object from a set of files.
Args:
Expand All @@ -1935,29 +1935,29 @@ def from_files(path_dir, dos_spin=1):
Returns:
a BoltztrapAnalyzer object
"""
run_type, warning, efermi, gap, doping_levels = BoltztrapAnalyzer.parse_outputtrans(path_dir)
run_type, warning, efermi, gap, doping_levels = cls.parse_outputtrans(path_dir)

vol = BoltztrapAnalyzer.parse_struct(path_dir)
vol = cls.parse_struct(path_dir)

intrans = BoltztrapAnalyzer.parse_intrans(path_dir)
intrans = cls.parse_intrans(path_dir)

if run_type == "BOLTZ":
dos, pdos = BoltztrapAnalyzer.parse_transdos(path_dir, efermi, dos_spin=dos_spin, trim_dos=False)
dos, pdos = cls.parse_transdos(path_dir, efermi, dos_spin=dos_spin, trim_dos=False)

*cond_and_hall, carrier_conc = BoltztrapAnalyzer.parse_cond_and_hall(path_dir, doping_levels)
*cond_and_hall, carrier_conc = cls.parse_cond_and_hall(path_dir, doping_levels)

return BoltztrapAnalyzer(gap, *cond_and_hall, intrans, dos, pdos, carrier_conc, vol, warning)
return cls(gap, *cond_and_hall, intrans, dos, pdos, carrier_conc, vol, warning)

if run_type == "DOS":
trim = intrans["dos_type"] == "HISTO"
dos, pdos = BoltztrapAnalyzer.parse_transdos(path_dir, efermi, dos_spin=dos_spin, trim_dos=trim)
dos, pdos = cls.parse_transdos(path_dir, efermi, dos_spin=dos_spin, trim_dos=trim)

return BoltztrapAnalyzer(gap=gap, dos=dos, dos_partial=pdos, warning=warning, vol=vol)
return cls(gap=gap, dos=dos, dos_partial=pdos, warning=warning, vol=vol)

if run_type == "BANDS":
bz_kpoints = np.loadtxt(f"{path_dir}/boltztrap_band.dat")[:, -3:]
bz_bands = np.loadtxt(f"{path_dir}/boltztrap_band.dat")[:, 1:-6]
return BoltztrapAnalyzer(bz_bands=bz_bands, bz_kpoints=bz_kpoints, warning=warning, vol=vol)
return cls(bz_bands=bz_bands, bz_kpoints=bz_kpoints, warning=warning, vol=vol)

if run_type == "FERMI":
if os.path.exists(f"{path_dir}/boltztrap_BZ.cube"):
Expand All @@ -1966,7 +1966,7 @@ def from_files(path_dir, dos_spin=1):
fs_data = read_cube_file(f"{path_dir}/fort.30")
else:
raise BoltztrapError("No data file found for fermi surface")
return BoltztrapAnalyzer(fermi_surface_data=fs_data)
return cls(fermi_surface_data=fs_data)

raise ValueError(f"{run_type=} not recognized!")

Expand Down
55 changes: 27 additions & 28 deletions pymatgen/io/adf.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,6 @@ def is_numeric(s) -> bool:
return True


def iterlines(s: str) -> Generator[str, None, None]:
r"""A generator form of s.split('\n') for reducing memory overhead.
Args:
s (str): A multi-line string.
Yields:
str: line
"""
prevnl = -1
while True:
nextnl = s.find("\n", prevnl + 1)
if nextnl < 0:
yield s[(prevnl + 1) :]
break
yield s[(prevnl + 1) : nextnl]
prevnl = nextnl


class AdfInputError(Exception):
"""The default error class for ADF."""

Expand Down Expand Up @@ -362,8 +343,8 @@ def from_dict(cls, d):
def from_string(cls, *args, **kwargs):
return cls.from_str(*args, **kwargs)

@staticmethod
def from_str(string):
@classmethod
def from_str(cls, string: str) -> AdfKey:
"""
Construct an AdfKey object from the string.
Expand Down Expand Up @@ -395,34 +376,52 @@ def is_float(s) -> bool:
el = string.split()
if len(el) > 1:
options = [s.split("=") for s in el[1:]] if string.find("=") != -1 else el[1:]
for i, op in enumerate(options):
for idx, op in enumerate(options): # type: ignore[var-annotated, arg-type]
if isinstance(op, list) and is_numeric(op[1]):
op[1] = float(op[1]) if is_float(op[1]) else int(op[1])
elif is_numeric(op):
options[i] = float(op) if is_float(op) else int(op)
options[idx] = float(op) if is_float(op) else int(op) # type: ignore[index]
else:
options = None
return AdfKey(el[0], options)
return cls(el[0], options)

if string.find("subend") != -1:
raise ValueError("Nested subkeys are not supported!")

def iterlines(s: str) -> Generator[str, None, None]:
r"""A generator form of s.split('\n') for reducing memory overhead.
Args:
s (str): A multi-line string.
Yields:
str: line
"""
prev_nl = -1
while True:
next_nl = s.find("\n", prev_nl + 1)
if next_nl < 0:
yield s[(prev_nl + 1) :]
break
yield s[(prev_nl + 1) : next_nl]
prev_nl = next_nl

key = None
for line in iterlines(string):
if line == "":
continue
el = line.strip().split()
if len(el) == 0:
continue
if el[0].upper() in AdfKey.block_keys:
if el[0].upper() in cls.block_keys:
if key is None:
key = AdfKey.from_str(line)
key = cls.from_str(line)
else:
return key
elif el[0].upper() == "END":
return key
return key # type: ignore[return-value]
elif key is not None:
key.add_subkey(AdfKey.from_str(line))
key.add_subkey(cls.from_str(line))

raise Exception("IncompleteKey: 'END' is missing!")

Expand Down
14 changes: 7 additions & 7 deletions pymatgen/io/babel.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,8 @@ def write_file(self, filename, file_format="xyz"):
mol = pybel.Molecule(self._ob_mol)
return mol.write(file_format, filename, overwrite=True)

@staticmethod
def from_file(filename, file_format="xyz", return_all_molecules=False):
@classmethod
def from_file(cls, filename, file_format="xyz", return_all_molecules=False):
"""
Uses OpenBabel to read a molecule from a file in all supported formats.
Expand All @@ -322,9 +322,9 @@ def from_file(filename, file_format="xyz", return_all_molecules=False):
"""
mols = pybel.readfile(str(file_format), str(filename))
if return_all_molecules:
return [BabelMolAdaptor(mol.OBMol) for mol in mols]
return [cls(mol.OBMol) for mol in mols]

return BabelMolAdaptor(next(mols).OBMol)
return cls(next(mols).OBMol)

@staticmethod
def from_molecule_graph(mol):
Expand All @@ -345,8 +345,8 @@ def from_string(cls, *args, **kwargs):
return cls.from_str(*args, **kwargs)

@needs_openbabel
@staticmethod
def from_str(string_data, file_format="xyz"):
@classmethod
def from_str(cls, string_data, file_format="xyz"):
"""
Uses OpenBabel to read a molecule from a string in all supported
formats.
Expand All @@ -359,4 +359,4 @@ def from_str(string_data, file_format="xyz"):
BabelMolAdaptor object
"""
mols = pybel.readstring(str(file_format), str(string_data))
return BabelMolAdaptor(mols.OBMol)
return cls(mols.OBMol)
6 changes: 3 additions & 3 deletions pymatgen/io/cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,8 @@ def is_magcif_incommensurate() -> bool:
def from_string(cls, *args, **kwargs):
return cls.from_str(*args, **kwargs)

@staticmethod
def from_str(cif_string: str, **kwargs) -> CifParser:
@classmethod
def from_str(cls, cif_string: str, **kwargs) -> CifParser:
"""
Creates a CifParser from a string.
Expand All @@ -379,7 +379,7 @@ def from_str(cif_string: str, **kwargs) -> CifParser:
CifParser
"""
stream = StringIO(cif_string)
return CifParser(stream, **kwargs)
return cls(stream, **kwargs)

def _sanitize_data(self, data):
"""
Expand Down
20 changes: 10 additions & 10 deletions pymatgen/io/cp2k/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,12 @@ def from_dict(cls, d):
def from_string(cls, *args, **kwargs):
return cls.from_str(*args, **kwargs)

@staticmethod
def from_str(s):
@classmethod
def from_str(cls, s):
"""
Initialize from a string.
Keywords must be labeled with strings. If the postprocessor finds
Keywords must be labeled with strings. If the post-processor finds
that the keywords is a number, then None is return (used by
the file reader).
Expand All @@ -183,7 +183,7 @@ def from_str(s):
args = s.split()
args = list(map(postprocessor if args[0].upper() != "ELEMENT" else str, args))
args[0] = str(args[0])
return Keyword(*args, units=units[0], description=description)
return cls(*args, units=units[0], description=description)

def verbosity(self, v):
"""Change the printing of this keyword's description."""
Expand Down Expand Up @@ -730,26 +730,26 @@ def _from_dict(cls, d):
.subsections,
)

@staticmethod
def from_file(file: str):
@classmethod
def from_file(cls, file: str):
"""Initialize from a file."""
with zopen(file, "rt") as f:
txt = preprocessor(f.read(), os.path.dirname(f.name))
return Cp2kInput.from_str(txt)
return cls.from_str(txt)

@classmethod
@np.deprecate(message="Use from_str instead")
def from_string(cls, *args, **kwargs):
return cls.from_str(*args, **kwargs)

@staticmethod
def from_str(s: str):
@classmethod
def from_str(cls, s: str):
"""Initialize from a string."""
lines = s.splitlines()
lines = [line.replace("\t", "") for line in lines]
lines = [line.strip() for line in lines]
lines = [line for line in lines if line]
return Cp2kInput.from_lines(lines)
return cls.from_lines(lines)

@classmethod
def from_lines(cls, lines: list | tuple):
Expand Down
12 changes: 6 additions & 6 deletions pymatgen/io/cssr.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def write_file(self, filename):
with zopen(filename, "wt") as f:
f.write(str(self) + "\n")

@staticmethod
def from_str(string):
@classmethod
def from_str(cls, string):
"""
Reads a string representation to a Cssr object.
Expand All @@ -79,10 +79,10 @@ def from_str(string):
if m:
sp.append(m.group(1))
coords.append([float(m.group(i)) for i in range(2, 5)])
return Cssr(Structure(latt, sp, coords))
return cls(Structure(latt, sp, coords))

@staticmethod
def from_file(filename):
@classmethod
def from_file(cls, filename):
"""
Reads a CSSR file to a Cssr object.
Expand All @@ -93,4 +93,4 @@ def from_file(filename):
Cssr object.
"""
with zopen(filename, "rt") as f:
return Cssr.from_str(f.read())
return cls.from_str(f.read())
Loading

0 comments on commit 9ae6121

Please sign in to comment.