diff --git a/app/api.py b/app/api.py index f86c60e..c8fd9bd 100644 --- a/app/api.py +++ b/app/api.py @@ -1,13 +1,15 @@ -from typing import Optional +import os +import sqlite3 +from typing import Annotated, Optional -from fastapi import FastAPI +from fastapi import Depends, FastAPI +from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, condecimal, conint, constr, root_validator +import app.database as database +from app.order_processor import FailedToProcessOrderError, process_order from app.types import Order, OrderSide, OrderType -from datetime import datetime -from app.stock_exchange import place_order - app = FastAPI() @@ -35,21 +37,33 @@ class CreateOrderResponseModel(Order): pass -@app.post( - "/orders", - status_code=201, - response_model=CreateOrderResponseModel, - response_model_by_alias=True, -) -async def create_order(model: CreateOrderModel): - place_order(model) - return CreateOrderResponseModel( - id=1, - created_at=datetime.now(), - type=model.type_, - side=model.side, - instrument=model.instrument, - limit_price=model.limit_price, - quantity=model.quantity +async def get_db(): + # Replace with proper db connection pooling implementation + db = database.connect(os.getenv("DB_CONNECTION")) + database.apply_database_schema(db) + try: + yield db + finally: + db.close() - ) + +@app.post("/orders", status_code=201) +async def create_order( + model: CreateOrderModel, db: Annotated[sqlite3.Connection, Depends(get_db)] +): + try: + order = process_order(db, model) + return CreateOrderResponseModel( + id=order.id_, + created_at=order.created_at, + type=order.type_, + side=order.side, + instrument=order.instrument, + limit_price=order.limit_price, + quantity=order.quantity, + ) + except FailedToProcessOrderError: + return JSONResponse( + content={"message": "Internal server error while placing the order"}, + status_code=500, + ) diff --git a/tests/api_test.py b/tests/api_test.py index 5f73189..84be98a 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1,11 +1,27 @@ -import pytest +def test_placing_an_order(client): + response = client.post( + "/orders", + json={ + "type": "limit", + "instrument": "US0378331005", + "limit_price": 100.00, + "quantity": 10, + "side": "buy", + }, + ) -def test_rudimentary_request(client): - response = client.post("/orders", json={ - "type": "limit", - "instrument": "US0378331005", - "limit_price": 100.00, - "quantity": 10, - "side": "buy" - }) - assert response.status_code == 201 + assert response.status_code == 201 or response.status_code == 500 + + +def test_create_order_with_invalid_request(client): + response = client.post( + "/orders", + json={ + "type": "unknown_type", + "instrument": "US0378331005", + "limit_price": 100.00, + "quantity": 10, + "side": "buy", + }, + ) + assert response.status_code == 422 diff --git a/tests/conftest.py b/tests/conftest.py index b625ad0..b7bb0ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,7 @@ from app.api import app @pytest.fixture def client() -> TestClient: + os.environ["DB_CONNECTION"] = ":memory:" return TestClient(app)