diff --git a/app/database.py b/app/database.py new file mode 100644 index 0000000..abf2fb6 --- /dev/null +++ b/app/database.py @@ -0,0 +1,60 @@ +import sqlite3 + +from app.types import Order + + +def connect(connection_string): + return sqlite3.connect( + connection_string, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES + ) + + +def apply_database_schema(connection): + connection.execute( + """ + CREATE TABLE IF NOT EXISTS orders ( + id integer primary key AUTOINCREMENT, + side varchar, + type varchar, + instrument char(12), + limit_price varchar, + quantity varchar, + created_at timestamp + ); + """ + ) + + +def insert_order(connection, order: Order): + result = connection.execute( + """ + INSERT INTO orders ( + side, + type, + instrument, + limit_price, + quantity, + created_at + ) VALUES ( + :side, + :type, + :instrument, + :limit_price, + :quantity, + datetime() + ) returning id, created_at + """, + { + "side": order.side.value, + "type": order.type_.value, + "instrument": order.instrument, + "limit_price": str(order.limit_price), + "quantity": str(order.quantity), + }, + ) + + returned_row = result.fetchone() + order.id_ = returned_row[0] + order.created_at = returned_row[1] + + return order diff --git a/tests/conftest.py b/tests/conftest.py index 2bd96c5..d8bba78 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,8 +13,12 @@ import pytest as pytest from starlette.testclient import TestClient from app.api import app - +import app.database as database @pytest.fixture def client() -> TestClient: return TestClient(app) + +@pytest.fixture +def db_connection(): + return database.connect("test.sqlite3") diff --git a/tests/database_test.py b/tests/database_test.py new file mode 100644 index 0000000..f9bf2c7 --- /dev/null +++ b/tests/database_test.py @@ -0,0 +1,32 @@ +from datetime import datetime + +import app.database as database +from app.types import Order, OrderSide, OrderType + + +def test_store_order(db_connection): + database.apply_database_schema(db_connection) + + order = Order( + id="", + created_at=datetime.now(), + side=OrderSide.BUY, + type=OrderType.LIMIT, + instrument="123456789123", + limit_price=123.00, + quantity=1, + ) + + database.insert_order(db_connection, order) + + record = db_connection.execute( + "SELECT * FROM orders where id = :id", {"id": order.id_} + ).fetchone() + + assert record[0] == order.id_ + assert record[1] == order.side.value + assert record[2] == order.type_.value + assert record[3] == order.instrument + assert record[4] == str(order.limit_price) + assert record[5] == str(order.quantity) + assert record[6] == order.created_at