Skip to content

Commit

Permalink
added cleaner method using re to escape underscores, added cleaner te…
Browse files Browse the repository at this point in the history
…st to assert underscores are escaped
  • Loading branch information
Dekermanjian committed Sep 14, 2024
1 parent 67383c5 commit 4156804
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 24 deletions.
22 changes: 4 additions & 18 deletions pymc/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.


import re

from functools import partial

from pytensor.compile import SharedVariable
Expand Down Expand Up @@ -301,22 +303,6 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle):

def _format_underscore(variable: str) -> str:
"""
formats variables with underscores in its name by prefixing underscores by '\\'
---
Params:
variable: The string representation of the variable in the model
Escapes all unescaped underscores in the variable name for LaTeX representation.
"""
if "_" not in variable:
return variable

inds = [i for i, ltr in enumerate(variable) if ltr == "_"]
var_len_original = len(variable)
var_len = None
for ind in inds:
if var_len:
if var_len != var_len_original:
ind = ind + (var_len - var_len_original)
if variable[ind - 1 : ind] != "\\":
variable = variable[:ind] + "\\" + variable[ind:]
var_len = len(variable)
return variable
return re.sub(r"(?<!\\)_", r"\\_", variable)
8 changes: 2 additions & 6 deletions tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

import numpy as np

Expand Down Expand Up @@ -335,8 +334,5 @@ def test_latex_escaped_underscore(self):
"""
model = self.simple_model()
model_str = model.str_repr(formatting="latex")
underscores = re.finditer(r"_", model_str)
for match in underscores:
if match:
start = match.span(0)[0] - 1
assert model_str[start : start + 1] == "\\"
assert "\\_" in model_str
assert "_" not in model_str.replace("\\_", "")

0 comments on commit 4156804

Please sign in to comment.