Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from app.database_models.user_model import UserModel
from app.database_models.refresh_token_model import RefreshTokenModel
from app.database_models.confirm_token_model import ConfirmTokenModel
from app.database_models.password_reset_token_model import PasswordResetTokenModel

# Open API file path
open_api_file_name = "makeascene.openapi.yaml"
Expand Down Expand Up @@ -79,7 +80,7 @@ def setup_database(app: Flask):
table_obj.create(db.engine)
app.logger.info(f"Created table: {table_name}")
except Exception as e:
app.logger.error(f"Error creating table {table_name}")
app.logger.error(f"Error creating table {table_name} with error {e}")
else:
app.logger.info(f"Table {table_name} already exists")

Expand All @@ -102,7 +103,8 @@ def setup_services(app: Flask):
app.user_service = UserService(storage_unit_of_work.user_repo, storage_unit_of_work.email_repo,
storage_unit_of_work.confirm_token_repo)
app.auth_service = AuthService(storage_unit_of_work.user_repo, storage_unit_of_work.refresh_token_repo,
storage_unit_of_work.confirm_token_repo)
storage_unit_of_work.confirm_token_repo, storage_unit_of_work.email_repo,
storage_unit_of_work.password_reset_token_repo)
app.image_service = ImageService(storage_unit_of_work.image_storage,
base_url=os.environ.get("BASE_URL", "http://127.0.0.1:5000"))
app.google_oauth_service = GoogleOauthService(storage_unit_of_work.user_repo,
Expand Down
31 changes: 31 additions & 0 deletions app/database_models/password_reset_token_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from datetime import datetime, timezone, timedelta

from sqlalchemy import Integer, String, DateTime, Boolean, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.extensions import db


class PasswordResetTokenModel(db.Model):
__tablename__ = "password_reset_tokens"

id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
token_hash: Mapped[str] = mapped_column(String(64), nullable=False, unique=True, index=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc)
)
expires_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=lambda: datetime.now(timezone.utc) + timedelta(minutes=15)
)
revoked: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)

user = relationship(
"UserModel", back_populates="password_reset_tokens"
)
3 changes: 3 additions & 0 deletions app/database_models/user_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,6 @@ class UserModel(db.Model):
confirm_tokens = relationship(
"ConfirmTokenModel", back_populates="user", cascade="all, delete-orphan"
)
password_reset_tokens = relationship(
"PasswordResetTokenModel", back_populates="user", cascade="all, delete-orphan"
)
12 changes: 12 additions & 0 deletions app/domain_models/password_reset_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dataclasses import dataclass
from datetime import datetime


@dataclass
class PasswordResetToken:
id: int
user_id: int
hashed_token: str
created_at: datetime
expires_at: datetime
revoked: bool
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Protocol

from app.domain_models.password_reset_token import PasswordResetToken
from app.domain_models.user import User


class PasswordResetTokenRepoProtocol(Protocol):
def __init__(self, session): ...

def create(self, hashed_token: str, user: User) -> bool: ...

def get_by_token_hash(self, hashed_token: str) -> PasswordResetToken | None: ...

def get_by_user(self, user: User) -> list[PasswordResetToken]: ...

def revoke(self, password_reset_token: PasswordResetToken) -> bool: ...
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ def create(self, hashed_token: str, user: User) -> bool: ...

def get_by_token_hash(self, hashed_token: str) -> RefreshToken | None: ...

def get_by_user(self, user: User) -> list[RefreshToken]: ...

def revoke(self, refresh_token: RefreshToken) -> bool: ...
44 changes: 44 additions & 0 deletions app/repositories/storage/sql_password_reset_token_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

from app.database_models.password_reset_token_model import PasswordResetTokenModel
from app.domain_models.password_reset_token import PasswordResetToken
from app.domain_models.user import User


class SqlPasswordResetTokenRepo:
def __init__(self, session):
self.session = session

def create(self, hashed_token: str, user: User) -> bool:
if self.get_by_token_hash(hashed_token):
return False
db_object = PasswordResetTokenModel(user_id=user.id, token_hash=hashed_token)
self.session.add(db_object)
self.session.commit()
return True

def get_by_token_hash(self, hashed_token: str) -> PasswordResetToken | None:
db_object = self.session.query(PasswordResetTokenModel).filter_by(token_hash=hashed_token).first()
if not db_object:
return None
return PasswordResetToken(
id=db_object.id,
user_id=db_object.user_id,
hashed_token=db_object.token_hash,
created_at=db_object.created_at,
expires_at=db_object.expires_at,
revoked=db_object.revoked
)

def get_by_user(self, user: User) -> list[PasswordResetToken]:
db_list = self.session.query(PasswordResetTokenModel).filter_by(user_id=user.id).all()
return [PasswordResetToken(id=db_object.id, user_id=db_object.user_id, hashed_token=db_object.token_hash,
created_at=db_object.created_at, expires_at=db_object.expires_at,
revoked=db_object.revoked) for db_object in db_list]

def revoke(self, password_reset_token: PasswordResetToken) -> bool:
db_object = self.session.get(PasswordResetTokenModel, password_reset_token.id)
if not db_object:
return False
db_object.revoked = True
self.session.commit()
return True
6 changes: 6 additions & 0 deletions app/repositories/storage/sql_refresh_token_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def get_by_token_hash(self, hashed_token: str) -> RefreshToken | None:
revoked=db_object.revoked
)

def get_by_user(self, user: User) -> list[RefreshToken]:
db_list = self.session.query(RefreshTokenModel).filter_by(user_id=user.id).all()
return [RefreshToken(id=db_object.id, user_id=db_object.user_id, hashed_token=db_object.token_hash,
created_at=db_object.created_at, expires_at=db_object.expires_at,
revoked=db_object.revoked) for db_object in db_list]

def revoke(self, refresh_token: RefreshToken) -> bool:
db_object = self.session.get(RefreshTokenModel, refresh_token.id)
if not db_object:
Expand Down
6 changes: 4 additions & 2 deletions app/repositories/storage/sql_user_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ def get_user(self, user_id: int) -> User | None:
db_user = self.session.get(UserModel, user_id)
if not db_user:
return None
return User(id=db_user.id, hashed_password=db_user.password, oauth=db_user.oauth_method, email=db_user.email)
return User(id=db_user.id, hashed_password=db_user.password, oauth=db_user.oauth_method, email=db_user.email,
confirmed=db_user.confirmed)

def get_user_by_email(self, email: str) -> User | None:
db_user = self.session.query(UserModel).filter_by(email=email).first()
if not db_user:
return None
return User(id=db_user.id, hashed_password=db_user.password, oauth=db_user.oauth_method, email=db_user.email, confirmed=db_user.confirmed)
return User(id=db_user.id, hashed_password=db_user.password, oauth=db_user.oauth_method, email=db_user.email,
confirmed=db_user.confirmed)

def create_user(self, email: str, password: str, oauth="local") -> bool:
if self.get_user_by_email(email):
Expand Down
3 changes: 3 additions & 0 deletions app/repositories/units_of_work/deploy_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from app.repositories.interfaces.external.email_protocol import EmailProtocol
from app.repositories.interfaces.storage.confirm_token_repo_protocol import ConfirmTokenRepoProtocol
from app.repositories.interfaces.storage.image_storage_protocol import ImageStorageProtocol
from app.repositories.interfaces.storage.password_reset_token_repo_protocol import PasswordResetTokenRepoProtocol
from app.repositories.interfaces.storage.refresh_token_repo_protocol import RefreshTokenRepoProtocol
from app.repositories.interfaces.storage.user_repo_protocol import UserRepoProtocol
from app.repositories.storage.minio_image_storage import MinioImageStorage
from app.repositories.storage.sql_confirm_token_repo import SqlConfirmTokenRepo
from app.repositories.storage.sql_password_reset_token_repo import SqlPasswordResetTokenRepo
from app.repositories.storage.sql_refresh_token_repo import SqlRefreshTokenRepo
from app.repositories.storage.sql_user_repo import SqlUserRepo

Expand All @@ -21,3 +23,4 @@ def __init__(self):
self.refresh_token_repo: RefreshTokenRepoProtocol = SqlRefreshTokenRepo(db.session)
self.email_repo: EmailProtocol = ResendEmailRepo()
self.confirm_token_repo: ConfirmTokenRepoProtocol = SqlConfirmTokenRepo(db.session)
self.password_reset_token_repo: PasswordResetTokenRepoProtocol = SqlPasswordResetTokenRepo(db.session)
3 changes: 3 additions & 0 deletions app/repositories/units_of_work/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from app.repositories.interfaces.external.email_protocol import EmailProtocol
from app.repositories.interfaces.storage.confirm_token_repo_protocol import ConfirmTokenRepoProtocol
from app.repositories.interfaces.storage.image_storage_protocol import ImageStorageProtocol
from app.repositories.interfaces.storage.password_reset_token_repo_protocol import PasswordResetTokenRepoProtocol
from app.repositories.interfaces.storage.refresh_token_repo_protocol import RefreshTokenRepoProtocol
from app.repositories.interfaces.storage.user_repo_protocol import UserRepoProtocol
from app.repositories.storage.mem_image_storage import InMemoryImageStorage
from app.repositories.storage.sql_confirm_token_repo import SqlConfirmTokenRepo
from app.repositories.storage.sql_password_reset_token_repo import SqlPasswordResetTokenRepo
from app.repositories.storage.sql_refresh_token_repo import SqlRefreshTokenRepo
from app.repositories.storage.sql_user_repo import SqlUserRepo

Expand All @@ -22,3 +24,4 @@ def __init__(self):
self.refresh_token_repo: RefreshTokenRepoProtocol = SqlRefreshTokenRepo(db.session)
self.email_repo: EmailProtocol = ResendEmailRepo()
self.confirm_token_repo: ConfirmTokenRepoProtocol = SqlConfirmTokenRepo(db.session)
self.password_reset_token_repo: PasswordResetTokenRepoProtocol = SqlPasswordResetTokenRepo(db.session)
71 changes: 61 additions & 10 deletions app/routes/auth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from authlib.integrations.base_client import MismatchingStateError
from flask import Blueprint, request, current_app, url_for

from app import validate
from app import validate, login_required
from app.domain_models.user import User
from app.services.password_service import PasswordService

auth = Blueprint('auth', __name__)

Expand Down Expand Up @@ -34,6 +36,15 @@ def login():
}, 201


@auth.route("/logout", methods=["POST"])
@validate
@login_required
def logout(user: User):
if current_app.auth_service.logout(user):
return {"message": "Logout successful."}, 200
return {}, 500


@auth.route("/oauth/redirect", methods=["GET"])
@validate
def redirect_to():
Expand Down Expand Up @@ -62,32 +73,72 @@ def google_callback():
if not user_info:
return {"error": "Failed to fetch user info from Google"}, 400

res = current_app.google_oauth_service.authenticate_user(user_info)
if not res:
exchange = current_app.google_oauth_service.authenticate_user(user_info)
if not exchange:
return {"error": "Failed to authenticate user"}, 400
token, refresh_token = res

return {
"message": "Login successful. Use token for authentication.",
"exchange": exchange
}, 201


@auth.route("/oauth/exchange", methods=["POST"])
@validate
def exchange_token():
data = request.get_json()
exchange = data.get("exchange")
res = current_app.google_oauth_service.exchange(exchange)
if not res:
return {
"error": "unauthorized",
"message": "Invalid exchange token."
}, 401
token, refresh_token = res
return {
"message": "Login successful.",
"access_token": token,
"refresh_token": refresh_token
}, 201
}, 200


@auth.route("/email/confirm", methods=["GET"])
@validate
def confirm_password():
def confirm_email():
params = request.args
token = params.get("token")
if current_app.auth_service.confirm_email(token):
return {"message": "Email confirmed."}, 200
return {"error": "unauthorized", "message": "Invalid credentials."}, 401


@auth.route("/email/resend", methods=["GET"])
@auth.route("/email/resend", methods=["POST"])
@validate
def resend_mail():
params = request.args
email = params.get("email")
data = request.get_json()
email = (data.get("email") or "").strip().replace(" ", "")
current_app.user_service.resend_confirmation_email(email)
return {"message": "Email sent if an unconfirmed user with that email exists."}, 200


@auth.route("/password/request-reset", methods=["POST"])
@validate
def request_reset_password():
data = request.get_json()
email = (data.get("email") or "").strip().replace(" ", "")
if current_app.auth_service.request_password_reset(email):
return {"message": "Password reset email sent if an account with that email exists."}, 200
return {}, 500


@auth.route("/password/reset", methods=["POST"])
@validate
def reset_password():
data = request.get_json()
token = data.get("token")
new_password = data.get("new_password")
if current_app.auth_service.reset_password(token, new_password):
return {"message": "Password reset successful."}, 200
return {
"error": "unauthorized",
"message": "Invalid token."
}, 401
Loading
Loading