diff --git a/python/matrix.py b/python/matrix.py index b968f0a0..7833f931 100644 --- a/python/matrix.py +++ b/python/matrix.py @@ -3,13 +3,43 @@ from . import number from .math_basics import is_Interval +def snappy_make_vector(entries, *, ring=None): + return SimpleVector(entries, ring) + +def snappy_make_matrix(entries, *, ring=None): + return SimpleMatrix(entries, ring) + +if _within_sage: + from sage.modules.free_module_element import vector as _sage_vector + from sage.matrix.constructor import matrix as _sage_matrix + + def sage_make_vector(entries, *, ring=None): + if ring is None: + return _sage_vector(entries) + else: + return _sage_vector(ring, entries) + + def sage_make_matrix(entries, *, ring=None): + if ring is None: + return _sage_matrix(entries) + else: + return _sage_matrix(ring, entries) + + make_vector = sage_make_vector + make_matrix = sage_make_matrix +else: + make_vector = snappy_make_vector + make_matrix = snappy_make_matrix class SimpleVector(number.SupportsMultiplicationByNumber): - def __init__(self, list_of_values): - self.data = list_of_values + def __init__(self, entries, ring=None): + if ring is None: + self.data = entries + else: + self.data = [ ring(e) for e in entries ] try: self.type = type(self.data[0]) - self.shape = (len(list_of_values),) + self.shape = (len(entries),) except IndexError: self.type = type(0) self.shape = (0,) @@ -484,4 +514,4 @@ def mat_solve(m, v, epsilon=0): # Return the last column # (11/7, -164/133, 46/133, 32/133) - return vector([ row[-1] for row in m1]) + return make_vector([ row[-1] for row in m1])