diff --git a/backend/flaskr/__init__.py b/backend/flaskr/__init__.py index 531034738..eef4a19d8 100644 --- a/backend/flaskr/__init__.py +++ b/backend/flaskr/__init__.py @@ -11,7 +11,12 @@ def create_app(test_config=None): # create and configure the app app = Flask(__name__) - setup_db(app) + + if test_config is None: + setup_db(app) + else: + database_path = test_config.get('SQLALCHEMY_DATABASE_URI') + setup_db(app, database_path=database_path) """ @TODO: Set up CORS. Allow '*' for origins. Delete the sample route after completing the TODOs diff --git a/backend/test_flaskr.py b/backend/test_flaskr.py index 16f9c5dd6..7a9b6b783 100644 --- a/backend/test_flaskr.py +++ b/backend/test_flaskr.py @@ -12,18 +12,15 @@ class TriviaTestCase(unittest.TestCase): def setUp(self): """Define test variables and initialize app.""" - self.app = create_app() - self.client = self.app.test_client self.database_name = "trivia_test" self.database_path = "postgres://{}/{}".format('localhost:5432', self.database_name) - setup_db(self.app, self.database_path) - - # binds the app to the current context - with self.app.app_context(): - self.db = SQLAlchemy() - self.db.init_app(self.app) - # create all tables - self.db.create_all() + + self.app = create_app({ + "SQLALCHEMY_DATABASE_URI": self.database_path + }) + + self.client = self.app.test_client + def tearDown(self): """Executed after reach test"""