diff --git a/sortinghat/cmd/init.py b/sortinghat/cmd/init.py index 5e77d776e..ba25e633e 100644 --- a/sortinghat/cmd/init.py +++ b/sortinghat/cmd/init.py @@ -25,7 +25,7 @@ import argparse -from ..command import Command, CMD_SUCCESS, CMD_FAILURE +from ..command import Command, CMD_SUCCESS from ..exceptions import DatabaseError, LoadError from ..db.database import Database from ..db.model import Country @@ -92,11 +92,11 @@ def initialize(self, name): self.__load_countries(db) except DatabaseError as e: self.error(str(e)) - return CMD_FAILURE + return e.code except LoadError as e: Database.drop(user, password, name, host, port) self.error(str(e)) - return CMD_FAILURE + return e.code return CMD_SUCCESS diff --git a/tests/test_cmd_init.py b/tests/test_cmd_init.py index 863b3e364..4a4448fec 100644 --- a/tests/test_cmd_init.py +++ b/tests/test_cmd_init.py @@ -33,9 +33,10 @@ sys.path.insert(0, '..') from sortinghat import api -from sortinghat.command import CMD_SUCCESS, CMD_FAILURE +from sortinghat.command import CMD_SUCCESS from sortinghat.cmd.init import Init from sortinghat.db.database import Database +from sortinghat.exceptions import CODE_DATABASE_ERROR from tests.config import DB_USER, DB_PASSWORD, DB_HOST, DB_PORT @@ -96,7 +97,7 @@ def test_connection_error(self): cmd = Init(**kwargs) code = cmd.run(self.name) - self.assertEqual(code, CMD_FAILURE) + self.assertEqual(code, CODE_DATABASE_ERROR) with warnings.catch_warnings(record=True): output = sys.stderr.getvalue().strip() @@ -110,7 +111,7 @@ def test_existing_db_error(self): self.assertEqual(code1, CMD_SUCCESS) code2 = self.cmd.run(self.name) - self.assertEqual(code2, CMD_FAILURE) + self.assertEqual(code2, CODE_DATABASE_ERROR) # Context added to catch deprecation warnings raised on Python 3 with warnings.catch_warnings(record=True): @@ -147,7 +148,7 @@ def test_connection_error(self): cmd = Init(**kwargs) code = cmd.initialize(self.name) - self.assertEqual(code, CMD_FAILURE) + self.assertEqual(code, CODE_DATABASE_ERROR) # Context added to catch deprecation warnings raised on Python 3 with warnings.catch_warnings(record=True): @@ -162,7 +163,7 @@ def test_existing_db_error(self): self.assertEqual(code1, CMD_SUCCESS) code2 = self.cmd.initialize(self.name) - self.assertEqual(code2, CMD_FAILURE) + self.assertEqual(code2, CODE_DATABASE_ERROR) # Context added to catch deprecation warnings raised on Python 3 with warnings.catch_warnings(record=True):