Add implementation to persist orders in the db
This commit is contained in:
parent
58a5013744
commit
d427d0a184
60
app/database.py
Normal file
60
app/database.py
Normal file
@ -0,0 +1,60 @@
|
||||
import sqlite3
|
||||
|
||||
from app.types import Order
|
||||
|
||||
|
||||
def connect(connection_string):
|
||||
return sqlite3.connect(
|
||||
connection_string, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
|
||||
)
|
||||
|
||||
|
||||
def apply_database_schema(connection):
|
||||
connection.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS orders (
|
||||
id integer primary key AUTOINCREMENT,
|
||||
side varchar,
|
||||
type varchar,
|
||||
instrument char(12),
|
||||
limit_price varchar,
|
||||
quantity varchar,
|
||||
created_at timestamp
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def insert_order(connection, order: Order):
|
||||
result = connection.execute(
|
||||
"""
|
||||
INSERT INTO orders (
|
||||
side,
|
||||
type,
|
||||
instrument,
|
||||
limit_price,
|
||||
quantity,
|
||||
created_at
|
||||
) VALUES (
|
||||
:side,
|
||||
:type,
|
||||
:instrument,
|
||||
:limit_price,
|
||||
:quantity,
|
||||
datetime()
|
||||
) returning id, created_at
|
||||
""",
|
||||
{
|
||||
"side": order.side.value,
|
||||
"type": order.type_.value,
|
||||
"instrument": order.instrument,
|
||||
"limit_price": str(order.limit_price),
|
||||
"quantity": str(order.quantity),
|
||||
},
|
||||
)
|
||||
|
||||
returned_row = result.fetchone()
|
||||
order.id_ = returned_row[0]
|
||||
order.created_at = returned_row[1]
|
||||
|
||||
return order
|
@ -13,8 +13,12 @@ import pytest as pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from app.api import app
|
||||
|
||||
import app.database as database
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> TestClient:
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def db_connection():
|
||||
return database.connect("test.sqlite3")
|
||||
|
32
tests/database_test.py
Normal file
32
tests/database_test.py
Normal file
@ -0,0 +1,32 @@
|
||||
from datetime import datetime
|
||||
|
||||
import app.database as database
|
||||
from app.types import Order, OrderSide, OrderType
|
||||
|
||||
|
||||
def test_store_order(db_connection):
|
||||
database.apply_database_schema(db_connection)
|
||||
|
||||
order = Order(
|
||||
id="",
|
||||
created_at=datetime.now(),
|
||||
side=OrderSide.BUY,
|
||||
type=OrderType.LIMIT,
|
||||
instrument="123456789123",
|
||||
limit_price=123.00,
|
||||
quantity=1,
|
||||
)
|
||||
|
||||
database.insert_order(db_connection, order)
|
||||
|
||||
record = db_connection.execute(
|
||||
"SELECT * FROM orders where id = :id", {"id": order.id_}
|
||||
).fetchone()
|
||||
|
||||
assert record[0] == order.id_
|
||||
assert record[1] == order.side.value
|
||||
assert record[2] == order.type_.value
|
||||
assert record[3] == order.instrument
|
||||
assert record[4] == str(order.limit_price)
|
||||
assert record[5] == str(order.quantity)
|
||||
assert record[6] == order.created_at
|
Loading…
x
Reference in New Issue
Block a user