Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ This project provides all the code necessary to deploy the APE classification mo

## Prerequisites

- Python 3.12
- Python 3.13

## Setup

Expand Down
26 changes: 13 additions & 13 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,21 @@ description = "API for deployment of the classification model for APE nomenclatu
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"fastapi>=0.115.12",
"mlflow>=2.21.2",
"nltk>=3.9.1",
"numpy>=2.2.4",
"fastapi>=0.138.1",
"mlflow>=3.14.0",
"nltk>=3.9.4",
"numpy>=2.5.0",
"omegaconf>=2.3.0",
"pandas>=2.2.3",
"pendulum>=3.0.0",
"pyarrow>=19.0.1",
"pydantic>=2.11.1",
"requests>=2.32.3",
"s3fs>=2025.3.2",
"torch>=2.6.0",
"pandas>=2.3.3",
"pendulum>=3.2.0",
"pyarrow>=24.0.0",
"pydantic>=2.13.4",
"requests>=2.34.2",
"s3fs>=2026.4.0",
"torch>=2.12.1",
"torchfasttext",
"tqdm>=4.67.1",
"unidecode>=1.3.8",
"tqdm>=4.68.3",
"unidecode>=1.4.0",
"uvicorn>=0.34.0",
]
authors = [
Expand Down
2 changes: 1 addition & 1 deletion setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ uv run pre-commit install
export MLFLOW_S3_ENDPOINT_URL="https://$AWS_S3_ENDPOINT"
export MLFLOW_TRACKING_URI=https://projet-ape-mlflow.user.lab.sspcloud.fr
export MLFLOW_MODEL_NAME=FastText-pytorch-2025
export MLFLOW_MODEL_VERSION="6"
export MLFLOW_MODEL_VERSION="12"
export API_USERNAME=username
export API_PASSWORD=password
export AUTH_API=False
2 changes: 1 addition & 1 deletion src/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def lifespan(app: FastAPI):
logger.info("🚀 Starting API lifespan")

app.state.model = load_model()
app.state.run_id = app.state.model.metadata.run_id
app.state.model_id = app.state.model.metadata.model_id

yield
logger.info("🛑 Shutting down API lifespan")
Expand Down
4 changes: 2 additions & 2 deletions src/api/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class OutputResponse(RootModel[Dict[str, Union[Prediction, float, str]]]):
"""
Contract for the output response of the API including:
- KV of PredictionResponse: normalized prediction responses generated by the model artifact
- MLversion: run_id as version of the ML model
- MLversion: model_id as version of the ML model

Expected flat structure after normalization:

Expand All @@ -22,7 +22,7 @@ class OutputResponse(RootModel[Dict[str, Union[Prediction, float, str]]]):
"2": Prediction,
...,
"IC": float, # required confidence score
"MLversion": str # required run_id as model version
"MLversion": str # required model_id as model version
}

Notes:
Expand Down
2 changes: 1 addition & 1 deletion src/api/routes/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,6 @@ async def predict(

output = request.app.state.model.predict(input_data, params=params_dict)
return [
OutputResponse({**out.model_dump(), "MLversion": request.app.state.run_id})
OutputResponse({**out.model_dump(), "MLversion": request.app.state.model_id})
for out in output
]
3 changes: 2 additions & 1 deletion src/utils/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@


def load_model():
# mlflow.set_tracking_uri(os.environ["MLFLOW_TRACKING_URI"])
model_uri = f"models:/{os.environ['MLFLOW_MODEL_NAME']}/{os.environ['MLFLOW_MODEL_VERSION']}"

# Step 1: Set the destination path for the model artifacts
dst_path = "/tmp/my_model"
dst_path = "/tmp/my_model/artifacts/pyfunc_model"

# Step 2: Download/extract the model here *without loading it yet*
mlflow.artifacts.download_artifacts(artifact_uri=model_uri, dst_path=dst_path)
Expand Down
1 change: 1 addition & 0 deletions src/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def configure_logging():
logging.StreamHandler(),
],
)
logging.getLogger("mlflow").setLevel(logging.ERROR)


def log_prediction(query: dict, response: OutputResponse, index: int = 0):
Expand Down
684 changes: 497 additions & 187 deletions uv.lock

Large diffs are not rendered by default.

Loading