Skip to content

Add endpoint to enroll with provided CSR #120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 1.6.4
current_version = 1.7.0
commit = False
tag = False

Expand Down
2,443 changes: 1,239 additions & 1,204 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "rasenmaeher_api"
version = "1.6.4"
version = "1.7.0"
description = "python-rasenmaeher-api"
authors = [
"Aciid <703382+Aciid@users.noreply.github.com>",
Expand Down Expand Up @@ -82,7 +82,8 @@ multikeyjwt = "^1.0"
uvicorn = {version = "^0.20", extras = ["standard"]}
gunicorn = "^20.1"
pyopenssl = "^23.1"
libpvarki = { git="https://github.com/pvarki/python-libpvarki.git", tag="1.9.0"}
# Can't update to 2.0 before pydantic migration is done
libpvarki = { git="https://github.com/pvarki/python-libpvarki.git", tag="1.9.1"}
openapi-readme = "^0.2"
python-multipart = "^0.0.6"
aiohttp = ">=3.11.10,<4.0"
Expand Down
2 changes: 1 addition & 1 deletion src/rasenmaeher_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""python-rasenmaeher-api"""

__version__ = "1.6.4" # NOTE Use `bump2version --config-file patch` to bump versions correctly
__version__ = "1.7.0" # NOTE Use `bump2version --config-file patch` to bump versions correctly
17 changes: 13 additions & 4 deletions src/rasenmaeher_api/db/enrollments.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .errors import ForbiddenOperation, CallsignReserved, NotFound, Deleted, PoolInactive
from ..rmsettings import RMSettings
from .engine import EngineWrapper
from ..web.api.utils.csr_utils import verify_csr

LOGGER = logging.getLogger(__name__)
CODE_ALPHABET = string.ascii_uppercase + string.digits
Expand Down Expand Up @@ -51,13 +52,13 @@ async def by_pk_or_invitecode(cls, inval: Union[str, uuid.UUID], allow_deleted:
except ValueError:
return await cls.by_invitecode(str(inval), allow_deleted)

async def create_enrollment(self, callsign: str) -> "Enrollment":
async def create_enrollment(self, callsign: str, csr: Optional[str] = None) -> "Enrollment":
"""Create enrollment from this pool"""
if not self.active:
raise PoolInactive()
if self.deleted:
raise Deleted("Can't create enrollments on deleted pools")
return await Enrollment.create_for_callsign(callsign, self, self.extra)
return await Enrollment.create_for_callsign(callsign, self, self.extra, csr)

async def set_active(self, state: bool) -> "EnrollmentPool":
"""Set active and return refreshed object"""
Expand Down Expand Up @@ -168,6 +169,7 @@ class Enrollment(ORMBaseModel, table=True): # type: ignore[call-arg,misc]
)
state: int = Field(nullable=False, index=False, unique=False, default=EnrollmentState.PENDING)
extra: Dict[str, Any] = Field(sa_type=JSONB, nullable=False, sa_column_kwargs={"server_default": "{}"})
csr: Optional[str] = Field(default=None, nullable=True)

@classmethod
async def by_pk_or_callsign(cls, inval: Union[str, uuid.UUID]) -> "Enrollment":
Expand All @@ -180,7 +182,7 @@ async def by_pk_or_callsign(cls, inval: Union[str, uuid.UUID]) -> "Enrollment":
async def approve(self, approver: Person) -> Person:
"""Creates the person record, their certs etc"""
with EngineWrapper.get_session() as session:
person = await Person.create_with_cert(self.callsign, extra=self.extra)
person = await Person.create_with_cert(self.callsign, extra=self.extra, csrpem=self.csr)
self.state = EnrollmentState.APPROVED
self.decided_by = approver.pk
self.decided_on = datetime.datetime.now(datetime.UTC)
Expand Down Expand Up @@ -273,11 +275,17 @@ async def _generate_unused_code(cls) -> str:

@classmethod
async def create_for_callsign(
cls, callsign: str, pool: Optional[EnrollmentPool] = None, extra: Optional[Dict[str, Any]] = None
cls,
callsign: str,
pool: Optional[EnrollmentPool] = None,
extra: Optional[Dict[str, Any]] = None,
csr: Optional[str] = None,
) -> "Enrollment":
"""Create a new one with random code for the callsign"""
if callsign in RMSettings.singleton().valid_product_cns:
raise CallsignReserved("Using product CNs as callsigns is forbidden")
if csr and not verify_csr(csr, callsign):
raise CallsignReserved("CSR CN must match callsign")
with EngineWrapper.get_session() as session:
try:
await Enrollment.by_callsign(callsign)
Expand All @@ -294,6 +302,7 @@ async def create_for_callsign(
state=EnrollmentState.PENDING,
extra=extra,
pool=poolpk,
csr=csr,
)
session.add(obj)
session.commit()
Expand Down
19 changes: 15 additions & 4 deletions src/rasenmaeher_api/db/people.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ..rmsettings import RMSettings
from ..kchelpers import KCClient, KCUserData
from .engine import EngineWrapper
from ..web.api.utils.csr_utils import verify_csr

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -90,8 +91,12 @@ async def by_pk_or_callsign(cls, inval: Union[str, uuid.UUID], allow_deleted: bo
return await cls.by_callsign(str(inval), allow_deleted)

@classmethod
async def create_with_cert(cls, callsign: str, extra: Optional[Dict[str, Any]] = None) -> "Person":
async def create_with_cert(
cls, callsign: str, extra: Optional[Dict[str, Any]] = None, csrpem: Optional[str] = None
) -> "Person":
"""Create the cert etc and save the person"""
if csrpem and not verify_csr(csrpem, callsign):
raise CallsignReserved("CSR CN must match callsign")
cnf = RMSettings.singleton()
if callsign in cnf.valid_product_cns:
raise CallsignReserved("Using product CNs as callsigns is forbidden")
Expand All @@ -110,8 +115,11 @@ async def create_with_cert(cls, callsign: str, extra: Optional[Dict[str, Any]] =
newperson = Person(pk=puuid, callsign=callsign, certspath=str(certspath), extra=extra)
session.add(newperson)
session.commit()
ckp = await async_create_keypair(newperson.privkeyfile, newperson.pubkeyfile)
csrpem = await async_create_client_csr(ckp, newperson.csrfile, newperson.certsubject)
if csrpem:
newperson.csrfile.write_text(csrpem, encoding="utf-8")
else:
ckp = await async_create_keypair(newperson.privkeyfile, newperson.pubkeyfile)
csrpem = await async_create_client_csr(ckp, newperson.csrfile, newperson.certsubject)
certpem = (await sign_csr(csrpem)).replace("\\n", "\n")
newperson.certfile.write_text(certpem)
except Exception as exc:
Expand Down Expand Up @@ -150,7 +158,10 @@ async def create_pfx(self) -> Path:
def write_pfx() -> None:
"""Do the IO"""
nonlocal self
p12bytes = convert_pem_to_pkcs12(self.certfile, self.privkeyfile, self.callsign, None, self.callsign)
if self.privkeyfile.exists():
p12bytes = convert_pem_to_pkcs12(self.certfile, self.privkeyfile, self.callsign, None, self.callsign)
else:
p12bytes = convert_pem_to_pkcs12(self.certfile, None, self.callsign, None, self.callsign)
self.pfxfile.write_bytes(p12bytes)

await asyncio.get_event_loop().run_in_executor(None, write_pfx)
Expand Down
30 changes: 29 additions & 1 deletion src/rasenmaeher_api/web/api/enduserpfx/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,31 @@
LOGGER = logging.getLogger(__name__)


@router.get(f"/{{callsign}}_{RMSettings.singleton().deployment_name}.pem")
@router.get("/{callsign}.pem")
async def get_user_pem(
callsign: str,
person: Person = Depends(ValidUser(auto_error=True)),
) -> FileResponse:
"""Get the signed cert in PEM format (no keys)"""
deplosuffix = f"_{RMSettings.singleton().deployment_name}.pem"
if callsign.endswith(deplosuffix):
callsign = callsign[: -len(deplosuffix)]
if callsign.endswith(".pem"):
callsign = callsign[:-4]
LOGGER.debug("PEM: Called with callsign={}".format(callsign))
if person.callsign != callsign:
raise HTTPException(status_code=403, detail="Callsign must match authenticated user")
# Make sure the pfx exists, this is no-op if it does
await person.create_pfx()

return FileResponse(
path=person.certfile,
media_type="application/x-pem-file",
filename=f"{callsign}_{RMSettings.singleton().deployment_name}.pem",
)


@router.get(f"/{{callsign}}_{RMSettings.singleton().deployment_name}.pfx")
@router.get("/{callsign}.pfx")
@router.get("/{callsign}")
Expand All @@ -30,7 +55,10 @@ async def get_user_pfx(
callsign = callsign[: -len(deplosuffix)]
if callsign.endswith(".pfx"):
callsign = callsign[:-4]
LOGGER.debug("Called with callsign={}".format(callsign))
if callsign.endswith(".pem"):
LOGGER.debug("PFX: got .pem suffix, delegating")
return await get_user_pem(callsign, person)
LOGGER.debug("PFX: Called with callsign={}".format(callsign))
if person.callsign != callsign:
raise HTTPException(status_code=403, detail="Callsign must match authenticated user")
# Make sure the pfx exists, this is no-op if it does
Expand Down
4 changes: 3 additions & 1 deletion src/rasenmaeher_api/web/api/enrollment/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Schema for enrollment."""

from typing import List, Dict, Any
from typing import List, Dict, Any, Optional

from pydantic import BaseModel, Extra, Field

Expand Down Expand Up @@ -105,6 +105,7 @@ class EnrollmentInitIn(BaseModel): # pylint: disable=too-few-public-methods
"""Enrollment init in response schema"""

callsign: str = Field(description="Callsign to create enrollment for")
csr: Optional[str] = Field(description="CSR for mTLS key in PEM format", default=None)

class Config: # pylint: disable=too-few-public-methods
"""Example values for schema"""
Expand Down Expand Up @@ -421,6 +422,7 @@ class EnrollmentInviteCodeEnrollIn(BaseModel):

invite_code: str
callsign: str
csr: Optional[str] = Field(description="CSR for mTLS key in PEM format", default=None)

class Config: # pylint: disable=too-few-public-methods
"""Example values for schema"""
Expand Down
6 changes: 4 additions & 2 deletions src/rasenmaeher_api/web/api/enrollment/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ async def request_enrollment_init(

# TODO ADD POOL NAME CHECK

new_enrollment = await Enrollment.create_for_callsign(callsign=request_in.callsign, pool=None, extra={})
new_enrollment = await Enrollment.create_for_callsign(
callsign=request_in.callsign, pool=None, extra={}, csr=request_in.csr
)
# Create JWT token for user
claims = {"sub": request_in.callsign}
new_jwt = Issuer.singleton().issue(claims)
Expand Down Expand Up @@ -414,7 +416,7 @@ async def post_enroll_invite_code(
except NotFound:
pass

enrollment = await obj.create_enrollment(callsign=request_in.callsign)
enrollment = await obj.create_enrollment(callsign=request_in.callsign, csr=request_in.csr)

# Create JWT token for user
claims = {"sub": request_in.callsign}
Expand Down
5 changes: 4 additions & 1 deletion src/rasenmaeher_api/web/api/firstuser/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Schema for enrollment."""

from pydantic import BaseModel, Extra
from typing import Optional

from pydantic import BaseModel, Extra, Field


class FirstuserCheckCodeIn(BaseModel): # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -51,6 +53,7 @@ class FirstuserAddAdminIn(BaseModel): # pylint: disable=too-few-public-methods

# temp_admin_code: str
callsign: str
csr: Optional[str] = Field(default=None, description="CSR for mTLS key in PEM format")

class Config: # pylint: disable=too-few-public-methods
"""Example values for schema"""
Expand Down
4 changes: 3 additions & 1 deletion src/rasenmaeher_api/web/api/firstuser/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ async def post_admin_add(
await anon_user.assign_role(role="anon_admin")

# Create new admin user enrollment
enrollment = await Enrollment.create_for_callsign(callsign=request_in.callsign, pool=None, extra={})
enrollment = await Enrollment.create_for_callsign(
callsign=request_in.callsign, pool=None, extra={}, csr=request_in.csr
)

# Get the anon_admin 'user' that will be used to approve the new admin user
# and approve the user
Expand Down
8 changes: 4 additions & 4 deletions src/rasenmaeher_api/web/api/instructions/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import cast, Optional
import logging

from fastapi import Depends, APIRouter, Request
from fastapi import Depends, APIRouter, Request, HTTPException
from libpvarki.schemas.product import UserCRUDRequest, UserInstructionFragment


Expand Down Expand Up @@ -71,8 +71,8 @@ async def get_product_instructions(request: Request, product: str, language: str
endpoint_url = f"api/v1/instructions/{language}"
response = await post_to_product(product, endpoint_url, user.dict(), InstructionData)
if response is None:
LOGGER.error("post_to_product({}, {}): failed".format(product, endpoint_url))
# TODO: Raise a reasonable error instead
return None
_reason = f"Unable to get instructions for {product}"
LOGGER.error("{} : {}".format(request.url, _reason))
raise HTTPException(status_code=404, detail=_reason)
response = cast(InstructionData, response)
return response
2 changes: 1 addition & 1 deletion src/rasenmaeher_api/web/api/middleware/mtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class MTLSorJWT(HTTPBase): # pylint: disable=too-few-public-methods
"""Auth either by JWT or mTLS header"""

def __init__( # pylint: disable=R0913
def __init__( # pylint: disable=too-many-arguments
self,
*,
scheme: str = "header",
Expand Down
25 changes: 25 additions & 0 deletions src/rasenmaeher_api/web/api/utils/csr_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Utils for checking CSR:s etc"""

import logging

from cryptography import x509
from libadvian.binpackers import ensure_utf8

LOGGER = logging.getLogger(__name__)


# FIXME: This should be part of libpvarki


def verify_csr(csrpem: str, callsign: str) -> bool:
"""Verify CSR matches our rules for CN/DN for the given callsign"""
csr = x509.load_pem_x509_csr(ensure_utf8(csrpem))
dn = csr.subject.rfc4514_string()
LOGGER.debug("DN={} callsign={}".format(dn, callsign))
if f"CN={callsign}" not in dn:
LOGGER.warning("Callsign does not match CSR subject. DN={} callsign={}".format(dn, callsign))
return False
# TODO: check that keyusages in the CSR are fine
# crypto.X509Extension(b"keyUsage", True, b"digitalSignature,nonRepudiation,keyEncipherment"),
# crypto.X509Extension(b"extendedKeyUsage", True, b"clientAuth"),
return True
10 changes: 5 additions & 5 deletions tests/ptfpapi/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
aiohttp==3.8.6
cryptography==41.0.5
fastapi==0.104.1
libadvian==1.4.0
aiohttp==3.11.18
cryptography==41.0.7
fastapi==0.115.12
libadvian==1.7.0
pyOpenSSL==23.3.0
Brotli==1.1.0
libpvarki @ git+https://github.com/pvarki/python-libpvarki.git@1.7.0
libpvarki @ git+https://github.com/pvarki/python-libpvarki.git@1.9.1
Loading