diff --git a/infrastructure/modules/lambda/lambda.tf b/infrastructure/modules/lambda/lambda.tf index 889cd6d2a..6485be675 100644 --- a/infrastructure/modules/lambda/lambda.tf +++ b/infrastructure/modules/lambda/lambda.tf @@ -17,14 +17,15 @@ resource "aws_lambda_function" "eligibility_signposting_lambda" { environment { variables = { - PERSON_TABLE_NAME = var.eligibility_status_table_name, - RULES_BUCKET_NAME = var.eligibility_rules_bucket_name, - KINESIS_AUDIT_STREAM_TO_S3 = var.kinesis_audit_stream_to_s3_name - ENV = var.environment - LOG_LEVEL = var.log_level - ENABLE_XRAY_PATCHING = var.enable_xray_patching - API_DOMAIN_NAME = var.api_domain_name - HASHING_SECRET_NAME = var.hashing_secret_name + PERSON_TABLE_NAME = var.eligibility_status_table_name, + RULES_BUCKET_NAME = var.eligibility_rules_bucket_name, + CONSUMER_MAPPING_BUCKET_NAME = var.eligibility_consumer_mappings_bucket_name, + KINESIS_AUDIT_STREAM_TO_S3 = var.kinesis_audit_stream_to_s3_name + ENV = var.environment + LOG_LEVEL = var.log_level + ENABLE_XRAY_PATCHING = var.enable_xray_patching + API_DOMAIN_NAME = var.api_domain_name + HASHING_SECRET_NAME = var.hashing_secret_name } } diff --git a/infrastructure/modules/lambda/variables.tf b/infrastructure/modules/lambda/variables.tf index 85b639862..6f238e149 100644 --- a/infrastructure/modules/lambda/variables.tf +++ b/infrastructure/modules/lambda/variables.tf @@ -44,6 +44,11 @@ variable "eligibility_rules_bucket_name" { type = string } +variable "eligibility_consumer_mappings_bucket_name" { + description = "consumer mappings bucket name" + type = string +} + variable "eligibility_status_table_name" { description = "eligibility datastore table name" type = string diff --git a/infrastructure/stacks/api-layer/iam_policies.tf b/infrastructure/stacks/api-layer/iam_policies.tf index 0fd67e453..28eb0ac2c 100644 --- a/infrastructure/stacks/api-layer/iam_policies.tf +++ b/infrastructure/stacks/api-layer/iam_policies.tf @@ -104,6 +104,60 @@ data "aws_iam_policy_document" "rules_s3_bucket_policy" { } } +# Policy doc for S3 Consumer Mappings bucket +data "aws_iam_policy_document" "s3_consumer_mapping_bucket_policy" { + statement { + sid = "AllowSSLRequestsOnly" + actions = [ + "s3:GetObject", + "s3:ListBucket", + ] + resources = [ + module.s3_consumer_mappings_bucket.storage_bucket_arn, + "${module.s3_consumer_mappings_bucket.storage_bucket_arn}/*", + ] + condition { + test = "Bool" + values = ["true"] + variable = "aws:SecureTransport" + } + } +} + +# ensure only secure transport is allowed + +resource "aws_s3_bucket_policy" "consumer_mapping_s3_bucket" { + bucket = module.s3_consumer_mappings_bucket.storage_bucket_id + policy = data.aws_iam_policy_document.consumer_mapping_s3_bucket_policy.json +} + +data "aws_iam_policy_document" "consumer_mapping_s3_bucket_policy" { + statement { + sid = "AllowSslRequestsOnly" + actions = [ + "s3:*", + ] + effect = "Deny" + resources = [ + module.s3_consumer_mappings_bucket.storage_bucket_arn, + "${module.s3_consumer_mappings_bucket.storage_bucket_arn}/*", + ] + principals { + type = "*" + identifiers = ["*"] + } + condition { + test = "Bool" + values = [ + "false", + ] + + variable = "aws:SecureTransport" + } + } +} + +# audit bucket resource "aws_s3_bucket_policy" "audit_s3_bucket" { bucket = module.s3_audit_bucket.storage_bucket_id policy = data.aws_iam_policy_document.audit_s3_bucket_policy.json @@ -136,12 +190,18 @@ data "aws_iam_policy_document" "audit_s3_bucket_policy" { } # Attach s3 read policy to Lambda role -resource "aws_iam_role_policy" "lambda_s3_read_policy" { +resource "aws_iam_role_policy" "lambda_s3_rules_read_policy" { name = "S3ReadAccess" role = aws_iam_role.eligibility_lambda_role.id policy = data.aws_iam_policy_document.s3_rules_bucket_policy.json } +resource "aws_iam_role_policy" "lambda_s3_mapping_read_policy" { + name = "S3ConsumerMappingReadAccess" + role = aws_iam_role.eligibility_lambda_role.id + policy = data.aws_iam_policy_document.s3_consumer_mapping_bucket_policy.json +} + # Attach s3 write policy to kinesis firehose role resource "aws_iam_role_policy" "kinesis_firehose_s3_write_policy" { name = "S3WriteAccess" @@ -290,6 +350,38 @@ resource "aws_kms_key_policy" "s3_rules_kms_key" { policy = data.aws_iam_policy_document.s3_rules_kms_key_policy.json } +data "aws_iam_policy_document" "s3_consumer_mapping_kms_key_policy" { + #checkov:skip=CKV_AWS_111: Root user needs full KMS key management + #checkov:skip=CKV_AWS_356: Root user needs full KMS key management + #checkov:skip=CKV_AWS_109: Root user needs full KMS key management + statement { + sid = "EnableIamUserPermissions" + effect = "Allow" + principals { + type = "AWS" + identifiers = ["arn:aws:iam::${data.aws_caller_identity.current.account_id}:root"] + } + actions = ["kms:*"] + resources = ["*"] + } + + statement { + sid = "AllowLambdaDecrypt" + effect = "Allow" + principals { + type = "AWS" + identifiers = [aws_iam_role.eligibility_lambda_role.arn] + } + actions = ["kms:Decrypt"] + resources = ["*"] + } +} + +resource "aws_kms_key_policy" "s3_consumer_mapping_kms_key" { + key_id = module.s3_consumer_mappings_bucket.storage_bucket_kms_key_id + policy = data.aws_iam_policy_document.s3_consumer_mapping_kms_key_policy.json +} + resource "aws_iam_role_policy" "splunk_firehose_policy" { #checkov:skip=CKV_AWS_290: Firehose requires write access to dynamic log streams without static constraints #checkov:skip=CKV_AWS_355: Firehose logging requires wildcard resource for CloudWatch log groups/streams diff --git a/infrastructure/stacks/api-layer/lambda.tf b/infrastructure/stacks/api-layer/lambda.tf index 9b31fee49..f87c36588 100644 --- a/infrastructure/stacks/api-layer/lambda.tf +++ b/infrastructure/stacks/api-layer/lambda.tf @@ -11,27 +11,28 @@ data "aws_subnet" "private_subnets" { } module "eligibility_signposting_lambda_function" { - source = "../../modules/lambda" - eligibility_lambda_role_arn = aws_iam_role.eligibility_lambda_role.arn - eligibility_lambda_role_name = aws_iam_role.eligibility_lambda_role.name - workspace = local.workspace - environment = var.environment - runtime = "python3.13" - lambda_func_name = "${terraform.workspace == "default" ? "" : "${terraform.workspace}-"}eligibility_signposting_api" + source = "../../modules/lambda" + eligibility_lambda_role_arn = aws_iam_role.eligibility_lambda_role.arn + eligibility_lambda_role_name = aws_iam_role.eligibility_lambda_role.name + workspace = local.workspace + environment = var.environment + runtime = "python3.13" + lambda_func_name = "${terraform.workspace == "default" ? "" : "${terraform.workspace}-"}eligibility_signposting_api" security_group_ids = [data.aws_security_group.main_sg.id] - vpc_intra_subnets = [for v in data.aws_subnet.private_subnets : v.id] - file_name = "../../../dist/lambda.zip" - handler = "eligibility_signposting_api.app.lambda_handler" - eligibility_rules_bucket_name = module.s3_rules_bucket.storage_bucket_name - eligibility_status_table_name = module.eligibility_status_table.table_name - kinesis_audit_stream_to_s3_name = module.eligibility_audit_firehose_delivery_stream.firehose_stream_name - hashing_secret_name = module.secrets_manager.aws_hashing_secret_name - lambda_insights_extension_version = 38 - log_level = "INFO" - enable_xray_patching = "true" - stack_name = local.stack_name - provisioned_concurrency_count = 5 - api_domain_name = local.api_domain_name + vpc_intra_subnets = [for v in data.aws_subnet.private_subnets : v.id] + file_name = "../../../dist/lambda.zip" + handler = "eligibility_signposting_api.app.lambda_handler" + eligibility_rules_bucket_name = module.s3_rules_bucket.storage_bucket_name + eligibility_consumer_mappings_bucket_name = module.s3_consumer_mappings_bucket.storage_bucket_name + eligibility_status_table_name = module.eligibility_status_table.table_name + kinesis_audit_stream_to_s3_name = module.eligibility_audit_firehose_delivery_stream.firehose_stream_name + hashing_secret_name = module.secrets_manager.aws_hashing_secret_name + lambda_insights_extension_version = 38 + log_level = "INFO" + enable_xray_patching = "true" + stack_name = local.stack_name + provisioned_concurrency_count = 5 + api_domain_name = local.api_domain_name } # ----------------------------------------------------------------------------- diff --git a/infrastructure/stacks/api-layer/s3_buckets.tf b/infrastructure/stacks/api-layer/s3_buckets.tf index 1a94f7284..276e71354 100644 --- a/infrastructure/stacks/api-layer/s3_buckets.tf +++ b/infrastructure/stacks/api-layer/s3_buckets.tf @@ -7,6 +7,15 @@ module "s3_rules_bucket" { workspace = terraform.workspace } +module "s3_consumer_mappings_bucket" { + source = "../../modules/s3" + bucket_name = "eli-consumer-map" + environment = var.environment + project_name = var.project_name + stack_name = local.stack_name + workspace = terraform.workspace +} + module "s3_audit_bucket" { source = "../../modules/s3" bucket_name = "eli-audit" diff --git a/src/eligibility_signposting_api/common/api_error_response.py b/src/eligibility_signposting_api/common/api_error_response.py index 40c1ddcdd..cb1006584 100644 --- a/src/eligibility_signposting_api/common/api_error_response.py +++ b/src/eligibility_signposting_api/common/api_error_response.py @@ -135,3 +135,11 @@ def log_and_generate_response( fhir_error_code=FHIRSpineErrorCode.ACCESS_DENIED, fhir_display_message="Access has been denied to process this request.", ) + +CONSUMER_ID_NOT_PROVIDED_ERROR = APIErrorResponse( + status_code=HTTPStatus.FORBIDDEN, + fhir_issue_code=FHIRIssueCode.FORBIDDEN, + fhir_issue_severity=FHIRIssueSeverity.ERROR, + fhir_error_code=FHIRSpineErrorCode.ACCESS_DENIED, + fhir_display_message="Access has been denied to process this request.", +) diff --git a/src/eligibility_signposting_api/common/request_validator.py b/src/eligibility_signposting_api/common/request_validator.py index cd213287a..796b4239a 100644 --- a/src/eligibility_signposting_api/common/request_validator.py +++ b/src/eligibility_signposting_api/common/request_validator.py @@ -7,12 +7,13 @@ from flask.typing import ResponseReturnValue from eligibility_signposting_api.common.api_error_response import ( + CONSUMER_ID_NOT_PROVIDED_ERROR, INVALID_CATEGORY_ERROR, INVALID_CONDITION_FORMAT_ERROR, INVALID_INCLUDE_ACTIONS_ERROR, NHS_NUMBER_MISMATCH_ERROR, ) -from eligibility_signposting_api.config.constants import NHS_NUMBER_HEADER +from eligibility_signposting_api.config.constants import CONSUMER_ID, NHS_NUMBER_HEADER logger = logging.getLogger(__name__) @@ -56,6 +57,13 @@ def validate_request_params() -> Callable: def decorator(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs) -> ResponseReturnValue: # noqa:ANN002,ANN003 + consumer_id = request.headers.get(CONSUMER_ID) + if not consumer_id: + message = "You are not authorised to request" + return CONSUMER_ID_NOT_PROVIDED_ERROR.log_and_generate_response( + log_message=message, diagnostics=message + ) + path_nhs_number = str(kwargs.get("nhs_number")) header_nhs_no = str(request.headers.get(NHS_NUMBER_HEADER)) diff --git a/src/eligibility_signposting_api/config/config.py b/src/eligibility_signposting_api/config/config.py index 6be1840aa..52f3111cc 100644 --- a/src/eligibility_signposting_api/config/config.py +++ b/src/eligibility_signposting_api/config/config.py @@ -22,6 +22,7 @@ def config() -> dict[str, Any]: person_table_name = TableName(os.getenv("PERSON_TABLE_NAME", "test_eligibility_datastore")) rules_bucket_name = BucketName(os.getenv("RULES_BUCKET_NAME", "test-rules-bucket")) + consumer_mapping_bucket_name = BucketName(os.getenv("CONSUMER_MAPPING_BUCKET_NAME", "test-consumer-mapping-bucket")) audit_bucket_name = BucketName(os.getenv("AUDIT_BUCKET_NAME", "test-audit-bucket")) hashing_secret_name = HashSecretName(os.getenv("HASHING_SECRET_NAME", "test_secret")) aws_default_region = AwsRegion(os.getenv("AWS_DEFAULT_REGION", "eu-west-1")) @@ -41,6 +42,7 @@ def config() -> dict[str, Any]: "s3_endpoint": None, "rules_bucket_name": rules_bucket_name, "audit_bucket_name": audit_bucket_name, + "consumer_mapping_bucket_name": consumer_mapping_bucket_name, "firehose_endpoint": None, "kinesis_audit_stream_to_s3": kinesis_audit_stream_to_s3, "enable_xray_patching": enable_xray_patching, @@ -59,6 +61,7 @@ def config() -> dict[str, Any]: "s3_endpoint": URL(os.getenv("S3_ENDPOINT", local_stack_endpoint)), "rules_bucket_name": rules_bucket_name, "audit_bucket_name": audit_bucket_name, + "consumer_mapping_bucket_name": consumer_mapping_bucket_name, "firehose_endpoint": URL(os.getenv("FIREHOSE_ENDPOINT", local_stack_endpoint)), "kinesis_audit_stream_to_s3": kinesis_audit_stream_to_s3, "enable_xray_patching": enable_xray_patching, diff --git a/src/eligibility_signposting_api/config/constants.py b/src/eligibility_signposting_api/config/constants.py index 3aa45fd35..bdc49e307 100644 --- a/src/eligibility_signposting_api/config/constants.py +++ b/src/eligibility_signposting_api/config/constants.py @@ -3,4 +3,5 @@ URL_PREFIX = "patient-check" RULE_STOP_DEFAULT = False NHS_NUMBER_HEADER = "nhs-login-nhs-number" +CONSUMER_ID = "consumer-id" ALLOWED_CONDITIONS = Literal["COVID", "FLU", "MMR", "RSV"] diff --git a/src/eligibility_signposting_api/model/consumer_mapping.py b/src/eligibility_signposting_api/model/consumer_mapping.py new file mode 100644 index 000000000..046aa9fee --- /dev/null +++ b/src/eligibility_signposting_api/model/consumer_mapping.py @@ -0,0 +1,17 @@ +from typing import NewType + +from pydantic import BaseModel, Field, RootModel + +from eligibility_signposting_api.model.campaign_config import CampaignID + +ConsumerId = NewType("ConsumerId", str) + + +class ConsumerCampaign(BaseModel): + campaign: CampaignID = Field(alias="Campaign") + description: str | None = Field(default=None, alias="Description") + + +class ConsumerMapping(RootModel[dict[ConsumerId, list[ConsumerCampaign]]]): + def get(self, key: ConsumerId, default: list[ConsumerCampaign] | None = None) -> list[ConsumerCampaign] | None: + return self.root.get(key, default) diff --git a/src/eligibility_signposting_api/repos/consumer_mapping_repo.py b/src/eligibility_signposting_api/repos/consumer_mapping_repo.py new file mode 100644 index 000000000..583acb4d5 --- /dev/null +++ b/src/eligibility_signposting_api/repos/consumer_mapping_repo.py @@ -0,0 +1,43 @@ +import json +from typing import Annotated, NewType + +from botocore.client import BaseClient +from wireup import Inject, service + +from eligibility_signposting_api.model.campaign_config import CampaignID +from eligibility_signposting_api.model.consumer_mapping import ConsumerId, ConsumerMapping + +BucketName = NewType("BucketName", str) + + +@service +class ConsumerMappingRepo: + """Repository class for Campaign Rules, which we can use to calculate a person's eligibility for vaccination. + + These rules are stored as JSON files in AWS S3.""" + + def __init__( + self, + s3_client: Annotated[BaseClient, Inject(qualifier="s3")], + bucket_name: Annotated[BucketName, Inject(param="consumer_mapping_bucket_name")], + ) -> None: + super().__init__() + self.s3_client = s3_client + self.bucket_name = bucket_name + + def get_permitted_campaign_ids(self, consumer_id: ConsumerId) -> list[CampaignID] | None: + objects = self.s3_client.list_objects(Bucket=self.bucket_name).get("Contents") + + if not objects: + return None + + consumer_mappings_obj = objects[0] + response = self.s3_client.get_object(Bucket=self.bucket_name, Key=consumer_mappings_obj["Key"]) + body = response["Body"].read() + + mapping_result = ConsumerMapping.model_validate(json.loads(body)).get(consumer_id) + + if mapping_result is None: + return None + + return [item.campaign for item in mapping_result] diff --git a/src/eligibility_signposting_api/services/eligibility_services.py b/src/eligibility_signposting_api/services/eligibility_services.py index 79934e174..13b701d61 100644 --- a/src/eligibility_signposting_api/services/eligibility_services.py +++ b/src/eligibility_signposting_api/services/eligibility_services.py @@ -3,7 +3,10 @@ from wireup import service from eligibility_signposting_api.model import eligibility_status +from eligibility_signposting_api.model.campaign_config import CampaignConfig +from eligibility_signposting_api.model.consumer_mapping import ConsumerId from eligibility_signposting_api.repos import CampaignRepo, NotFoundError, PersonRepo +from eligibility_signposting_api.repos.consumer_mapping_repo import ConsumerMappingRepo from eligibility_signposting_api.services.calculators import eligibility_calculator as calculator logger = logging.getLogger(__name__) @@ -23,12 +26,14 @@ def __init__( self, person_repo: PersonRepo, campaign_repo: CampaignRepo, + consumer_mapping_repo: ConsumerMappingRepo, calculator_factory: calculator.EligibilityCalculatorFactory, ) -> None: super().__init__() self.person_repo = person_repo self.campaign_repo = campaign_repo self.calculator_factory = calculator_factory + self.consumer_mapping = consumer_mapping_repo def get_eligibility_status( self, @@ -36,16 +41,33 @@ def get_eligibility_status( include_actions: str, conditions: list[str], category: str, + consumer_id: str, ) -> eligibility_status.EligibilityStatus: """Calculate a person's eligibility for vaccination given an NHS number.""" if nhs_number: try: person_data = self.person_repo.get_eligibility_data(nhs_number) - campaign_configs = list(self.campaign_repo.get_campaign_configs()) except NotFoundError as e: raise UnknownPersonError from e else: - calc: calculator.EligibilityCalculator = self.calculator_factory.get(person_data, campaign_configs) + campaign_configs: list[CampaignConfig] = list(self.campaign_repo.get_campaign_configs()) + permitted_campaign_configs = self.__collect_permitted_campaign_configs( + campaign_configs, ConsumerId(consumer_id) + ) + calc: calculator.EligibilityCalculator = self.calculator_factory.get( + person_data, permitted_campaign_configs + ) return calc.get_eligibility_status(include_actions, conditions, category) raise UnknownPersonError # pragma: no cover + + def __collect_permitted_campaign_configs( + self, campaign_configs: list[CampaignConfig], consumer_id: ConsumerId + ) -> list[CampaignConfig]: + permitted_campaign_ids = self.consumer_mapping.get_permitted_campaign_ids(ConsumerId(consumer_id)) + if permitted_campaign_ids: + permitted_campaign_configs: list[CampaignConfig] = [ + campaign for campaign in campaign_configs if campaign.id in permitted_campaign_ids + ] + return permitted_campaign_configs + return [] diff --git a/src/eligibility_signposting_api/views/eligibility.py b/src/eligibility_signposting_api/views/eligibility.py index eb2b706ea..b935678f6 100644 --- a/src/eligibility_signposting_api/views/eligibility.py +++ b/src/eligibility_signposting_api/views/eligibility.py @@ -11,9 +11,12 @@ from eligibility_signposting_api.audit.audit_context import AuditContext from eligibility_signposting_api.audit.audit_service import AuditService -from eligibility_signposting_api.common.api_error_response import NHS_NUMBER_NOT_FOUND_ERROR +from eligibility_signposting_api.common.api_error_response import ( + NHS_NUMBER_NOT_FOUND_ERROR, +) from eligibility_signposting_api.common.request_validator import validate_request_params -from eligibility_signposting_api.config.constants import URL_PREFIX +from eligibility_signposting_api.config.constants import CONSUMER_ID, URL_PREFIX +from eligibility_signposting_api.model.consumer_mapping import ConsumerId from eligibility_signposting_api.model.eligibility_status import Condition, EligibilityStatus, NHSNumber, Status from eligibility_signposting_api.services import EligibilityService, UnknownPersonError from eligibility_signposting_api.views.response_model import eligibility_response @@ -47,13 +50,17 @@ def check_eligibility( nhs_number: NHSNumber, eligibility_service: Injected[EligibilityService], audit_service: Injected[AuditService] ) -> ResponseReturnValue: logger.info("checking nhs_number %r in %r", nhs_number, eligibility_service, extra={"nhs_number": nhs_number}) + + query_params = _get_or_default_query_params() + consumer_id = _get_consumer_id_from_headers() + try: - query_params = get_or_default_query_params() eligibility_status = eligibility_service.get_eligibility_status( nhs_number, query_params["includeActions"], query_params["conditions"], query_params["category"], + consumer_id, ) except UnknownPersonError: return handle_unknown_person_error(nhs_number) @@ -63,7 +70,14 @@ def check_eligibility( return make_response(response.model_dump(by_alias=True, mode="json", exclude_none=True), HTTPStatus.OK) -def get_or_default_query_params() -> dict[str, Any]: +def _get_consumer_id_from_headers() -> ConsumerId: + """ + @validate_request_params() ensures the consumer ID is never null at this stage. + """ + return ConsumerId(request.headers.get(CONSUMER_ID, "")) + + +def _get_or_default_query_params() -> dict[str, Any]: default_query_params = {"category": "ALL", "conditions": ["ALL"], "includeActions": "Y"} if not request.args: diff --git a/tests/fixtures/builders/model/rule.py b/tests/fixtures/builders/model/rule.py index bf62de900..2793ea032 100644 --- a/tests/fixtures/builders/model/rule.py +++ b/tests/fixtures/builders/model/rule.py @@ -93,7 +93,7 @@ class IterationFactory(ModelFactory[Iteration]): class RawCampaignConfigFactory(ModelFactory[CampaignConfig]): iterations = Use(IterationFactory.batch, size=2) - + id = "42-hi5tch-hi5kers-gu5ide-t2o-t3he-gal6axy" start_date = Use(past_date) end_date = Use(future_date) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 4af1223a2..840af8b04 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -21,6 +21,7 @@ from eligibility_signposting_api.model.campaign_config import ( AvailableAction, CampaignConfig, + CampaignID, EndDate, RuleCode, RuleEntry, @@ -30,6 +31,7 @@ StartDate, StatusText, ) +from eligibility_signposting_api.model.consumer_mapping import ConsumerCampaign, ConsumerId, ConsumerMapping from eligibility_signposting_api.processors.hashing_service import HashingService, HashSecretName from eligibility_signposting_api.repos import SecretRepo from eligibility_signposting_api.repos.campaign_repo import BucketName @@ -661,6 +663,14 @@ def rules_bucket(s3_client: BaseClient) -> Generator[BucketName]: s3_client.delete_bucket(Bucket=bucket_name) +@pytest.fixture(scope="session") +def consumer_mapping_bucket(s3_client: BaseClient) -> Generator[BucketName]: + bucket_name = BucketName(os.getenv("CONSUMER_MAPPING_BUCKET_NAME", "test-consumer-mapping-bucket")) + s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": AWS_REGION}) + yield bucket_name + s3_client.delete_bucket(Bucket=bucket_name) + + @pytest.fixture(scope="session") def audit_bucket(s3_client: BaseClient) -> Generator[BucketName]: bucket_name = BucketName(os.getenv("AUDIT_BUCKET_NAME", "test-audit-bucket")) @@ -690,7 +700,7 @@ def firehose_delivery_stream(firehose_client: BaseClient, audit_bucket: BucketNa @pytest.fixture(scope="class") -def campaign_config(s3_client: BaseClient, rules_bucket: BucketName) -> Generator[CampaignConfig]: +def rsv_campaign_config(s3_client: BaseClient, rules_bucket: BucketName) -> Generator[CampaignConfig]: campaign: CampaignConfig = rule.CampaignConfigFactory.build( target="RSV", iterations=[ @@ -1101,6 +1111,285 @@ def campaign_config_with_missing_descriptions_missing_rule_text( s3_client.delete_object(Bucket=rules_bucket, Key=f"{campaign.name}.json") +@pytest.fixture +def campaign_configs(request, s3_client: BaseClient, rules_bucket: BucketName) -> Generator[list[CampaignConfig]]: + """Create and upload multiple campaign configs to S3, then clean up after tests.""" + campaigns, campaign_data_keys = [], [] # noqa: F841 + + raw = getattr( + request, "param", [("RSV", "RSV_campaign_id"), ("COVID", "COVID_campaign_id"), ("FLU", "FLU_campaign_id")] + ) + + targets = [] + campaign_id = [] + status = [] + + for t, _id, *rest in raw: + targets.append(t) + campaign_id.append(_id) + status.append(rest[0] if rest else None) + + for i in range(len(targets)): + campaign: CampaignConfig = rule.CampaignConfigFactory.build( + name=f"campaign_{i}", + id=campaign_id[i], + target=targets[i], + type="V", + iterations=[ + rule.IterationFactory.build( + iteration_rules=[ + rule.PostcodeSuppressionRuleFactory.build(type=RuleType.filter), + rule.PersonAgeSuppressionRuleFactory.build(), + rule.PersonAgeSuppressionRuleFactory.build(name="Exclude 76 rolling", description=""), + ], + iteration_cohorts=[ + rule.IterationCohortFactory.build( + cohort_label="cohort1", + cohort_group="cohort_group1", + positive_description="", + negative_description="", + ) + ], + status_text=None, + ) + ], + ) + + if status[i] == "inactive": + campaign.iterations[0].iteration_date = datetime.datetime.now(tz=datetime.UTC) + datetime.timedelta(days=7) + + campaign_data = {"CampaignConfig": campaign.model_dump(by_alias=True)} + key = f"{campaign.name}.json" + s3_client.put_object( + Bucket=rules_bucket, Key=key, Body=json.dumps(campaign_data), ContentType="application/json" + ) + campaign_id.append(campaign) + campaign_data_keys.append(key) + + yield campaign_id + + for key in campaign_data_keys: + s3_client.delete_object(Bucket=rules_bucket, Key=key) + + +@pytest.fixture(scope="class") +def consumer_id() -> ConsumerId: + return ConsumerId("23-mic7heal-jor6don") + + +def create_and_put_consumer_mapping_in_s3( + campaign_config: CampaignConfig, consumer_id: str, consumer_mapping_bucket, s3_client +) -> ConsumerMapping: + consumer_mapping = ConsumerMapping.model_validate({}) + campaign_entry = ConsumerCampaign(Campaign=campaign_config.id, Description="Test description for campaign mapping") + + consumer_mapping.root[ConsumerId(consumer_id)] = [campaign_entry] + consumer_mapping_data = consumer_mapping.model_dump(by_alias=True) + s3_client.put_object( + Bucket=consumer_mapping_bucket, + Key="consumer_mapping.json", + Body=json.dumps(consumer_mapping_data), + ContentType="application/json", + ) + return consumer_mapping + + +@pytest.fixture(scope="class") +def consumer_mapped_to_campaign_having_invalid_tokens( + s3_client: BaseClient, + consumer_mapping_bucket: BucketName, + campaign_config_with_invalid_tokens: CampaignConfig, + consumer_id: ConsumerId, +) -> Generator[ConsumerMapping]: + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_invalid_tokens, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture(scope="class") +def consumer_mapped_to_campaign_having_tokens( + s3_client: BaseClient, + consumer_mapping_bucket: BucketName, + campaign_config_with_tokens: CampaignConfig, + consumer_id: ConsumerId, +) -> Generator[ConsumerMapping]: + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_tokens, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture(scope="class") +def consumer_mapped_to_rsv_campaign( + s3_client: BaseClient, + consumer_mapping_bucket: BucketName, + rsv_campaign_config: CampaignConfig, + consumer_id: ConsumerId, +) -> Generator[ConsumerMapping]: + consumer_mapping = create_and_put_consumer_mapping_in_s3( + rsv_campaign_config, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture(scope="class") +def consumer_mapped_to_campaign_having_and_rule( + s3_client: BaseClient, + consumer_mapping_bucket: BucketName, + campaign_config_with_and_rule: CampaignConfig, + consumer_id: ConsumerId, +) -> Generator[ConsumerMapping]: + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_and_rule, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture +def consumer_mapped_to_campaign_missing_descriptions_and_rule_text( + s3_client: BaseClient, + consumer_mapping_bucket: ConsumerMapping, + campaign_config_with_missing_descriptions_missing_rule_text: CampaignConfig, + consumer_id: ConsumerId, +): + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_missing_descriptions_missing_rule_text, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture +def consumer_mapped_to_campaign_having_rules_with_rule_code( + s3_client: BaseClient, + consumer_mapping_bucket: ConsumerMapping, + campaign_config_with_rules_having_rule_code: CampaignConfig, + consumer_id: ConsumerId, +): + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_rules_having_rule_code, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture +def consumer_mapped_to_campaign_having_rules_with_rule_mapper( + s3_client: BaseClient, + consumer_mapping_bucket: ConsumerMapping, + campaign_config_with_rules_having_rule_mapper: CampaignConfig, + consumer_id: ConsumerId, +): + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_rules_having_rule_mapper, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture +def consumer_mapped_to_campaign_having_only_virtual_cohort( + s3_client: BaseClient, + consumer_mapping_bucket: ConsumerMapping, + campaign_config_with_virtual_cohort: CampaignConfig, + consumer_id: ConsumerId, +): + consumer_mapping = create_and_put_consumer_mapping_in_s3( + campaign_config_with_virtual_cohort, consumer_id, consumer_mapping_bucket, s3_client + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture +def consumer_mapped_to_campaign_having_inactive_iteration_config( + s3_client: BaseClient, + consumer_mapping_bucket: ConsumerMapping, + inactive_iteration_config: list[CampaignConfig], + consumer_id: ConsumerId, +): + mapping = ConsumerMapping.model_validate({}) + mapping.root[consumer_id] = [ + ConsumerCampaign(Campaign=cc.id, Description=f"Description for {cc.id}") for cc in inactive_iteration_config + ] + + s3_client.put_object( + Bucket=consumer_mapping_bucket, + Key="consumer_mapping.json", + Body=json.dumps(mapping.model_dump(by_alias=True)), + ContentType="application/json", + ) + yield mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture(scope="class") +def consumer_mapped_to_multiple_campaign_configs( + multiple_campaign_configs: list[CampaignConfig], + consumer_id: ConsumerId, + s3_client: BaseClient, + consumer_mapping_bucket: BucketName, +) -> Generator[ConsumerMapping]: + mapping = ConsumerMapping.model_validate({}) + mapping.root[consumer_id] = [ + ConsumerCampaign(Campaign=cc.id, Description=f"Description for {cc.id}") for cc in multiple_campaign_configs + ] + + s3_client.put_object( + Bucket=consumer_mapping_bucket, + Key="consumer_mapping.json", + Body=json.dumps(mapping.model_dump(by_alias=True)), + ContentType="application/json", + ) + yield mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture +def consumer_mappings( + request, s3_client: BaseClient, consumer_mapping_bucket: BucketName +) -> Generator[ConsumerMapping]: + consumer_mapping = ConsumerMapping.model_validate(getattr(request, "param", {})) + consumer_mapping_data = consumer_mapping.model_dump(by_alias=True) + s3_client.put_object( + Bucket=consumer_mapping_bucket, + Key="consumer_mapping.json", + Body=json.dumps(consumer_mapping_data), + ContentType="application/json", + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + +@pytest.fixture(scope="class") +def consumer_mapped_to_with_various_targets( + s3_client: BaseClient, consumer_mapping_bucket: BucketName +) -> Generator[ConsumerMapping]: + consumer_mapping = ConsumerMapping.model_validate({}) + + consumer_mapping.root[ConsumerId("23-mic7heal-jor6don")] = [ + ConsumerCampaign(Campaign=CampaignID("campaign_start_date")), + ConsumerCampaign(Campaign=CampaignID("campaign_start_date_plus_one_day")), + ConsumerCampaign(Campaign=CampaignID("campaign_today")), + ConsumerCampaign(Campaign=CampaignID("campaign_tomorrow")), + ] + + consumer_mapping_data = consumer_mapping.model_dump(by_alias=True) + s3_client.put_object( + Bucket=consumer_mapping_bucket, + Key="consumer_mapping.json", + Body=json.dumps(consumer_mapping_data), + ContentType="application/json", + ) + yield consumer_mapping + s3_client.delete_object(Bucket=consumer_mapping_bucket, Key="consumer_mapping.json") + + # If you put StubSecretRepo in a separate module, import it instead class StubSecretRepo(SecretRepo): # def __init__(self, current: str = AWS_CURRENT_SECRET, previous: str = AWS_PREVIOUS_SECRET): diff --git a/tests/integration/in_process/test_eligibility_endpoint.py b/tests/integration/in_process/test_eligibility_endpoint.py index 4e5cdbfb8..3f6310534 100644 --- a/tests/integration/in_process/test_eligibility_endpoint.py +++ b/tests/integration/in_process/test_eligibility_endpoint.py @@ -1,5 +1,7 @@ +import json from http import HTTPStatus +import pytest from botocore.client import BaseClient from brunns.matchers.data import json_matching as is_json_that from brunns.matchers.werkzeug import is_werkzeug_response as is_response @@ -7,16 +9,20 @@ from hamcrest import ( assert_that, contains_exactly, + contains_inanyorder, equal_to, has_entries, has_entry, has_key, ) +from eligibility_signposting_api.config.constants import CONSUMER_ID from eligibility_signposting_api.model.campaign_config import CampaignConfig +from eligibility_signposting_api.model.consumer_mapping import ConsumerId, ConsumerMapping from eligibility_signposting_api.model.eligibility_status import ( NHSNumber, ) +from eligibility_signposting_api.repos.campaign_repo import BucketName class TestBaseLine: @@ -24,11 +30,12 @@ def test_nhs_number_given( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_mapped_to_rsv_campaign: ConsumerMapping, # noqa: ARG002 secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person)} + headers = {"nhs-login-nhs-number": str(persisted_person), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_person}", headers=headers) @@ -57,7 +64,6 @@ def test_no_nhs_number_given_but_header_given( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG002 ): # Given headers = {"nhs-login-nhs-number": str(persisted_person)} @@ -79,10 +85,12 @@ def test_not_base_eligible( self, client: FlaskClient, persisted_person_no_cohorts: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_mapped_to_rsv_campaign: ConsumerMapping, # noqa: ARG002 + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person_no_cohorts)} + headers = {"nhs-login-nhs-number": str(persisted_person_no_cohorts), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_person_no_cohorts}?includeActions=Y", headers=headers) @@ -123,10 +131,12 @@ def test_not_eligible_by_rule( self, client: FlaskClient, persisted_person_pc_sw19: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_mapped_to_rsv_campaign: ConsumerMapping, # noqa: ARG002 + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person_pc_sw19)} + headers = {"nhs-login-nhs-number": str(persisted_person_pc_sw19), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_person_pc_sw19}?includeActions=Y", headers=headers) @@ -167,10 +177,12 @@ def test_not_actionable_and_check_response_when_no_rule_code_given( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_mapped_to_rsv_campaign: ConsumerMapping, # noqa: ARG002 + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person)} + headers = {"nhs-login-nhs-number": str(persisted_person), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_person}?includeActions=Y", headers=headers) @@ -217,9 +229,11 @@ def test_actionable( self, client: FlaskClient, persisted_77yo_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_mapped_to_rsv_campaign: ConsumerMapping, # noqa: ARG002 + secretsmanager_client: BaseClient, # noqa: ARG002 ): - headers = {"nhs-login-nhs-number": str(persisted_77yo_person)} + headers = {"nhs-login-nhs-number": str(persisted_77yo_person), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_77yo_person}?includeActions=Y", headers=headers) @@ -268,10 +282,12 @@ def test_actionable_with_and_rule( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config_with_and_rule: CampaignConfig, # noqa: ARG002 + consumer_id: ConsumerId, + consumer_mapped_to_campaign_having_and_rule: ConsumerMapping, # noqa: ARG002 + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person)} + headers = {"nhs-login-nhs-number": str(persisted_person), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_person}?includeActions=Y", headers=headers) @@ -322,10 +338,12 @@ def test_not_eligible_by_rule_when_only_virtual_cohort_is_present( self, client: FlaskClient, persisted_person_pc_sw19: NHSNumber, - campaign_config_with_virtual_cohort: CampaignConfig, # noqa: ARG002 + consumer_mapped_to_campaign_having_only_virtual_cohort: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person_pc_sw19)} + headers = {"nhs-login-nhs-number": str(persisted_person_pc_sw19), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_person_pc_sw19}?includeActions=Y", headers=headers) @@ -366,10 +384,12 @@ def test_not_actionable_when_only_virtual_cohort_is_present( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config_with_virtual_cohort: CampaignConfig, # noqa: ARG002 + consumer_mapped_to_campaign_having_only_virtual_cohort: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person)} + headers = {"nhs-login-nhs-number": str(persisted_person), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_person}?includeActions=Y", headers=headers) @@ -416,10 +436,12 @@ def test_actionable_when_only_virtual_cohort_is_present( self, client: FlaskClient, persisted_77yo_person: NHSNumber, - campaign_config_with_virtual_cohort: CampaignConfig, # noqa: ARG002 + consumer_mapped_to_campaign_having_only_virtual_cohort: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_77yo_person)} + headers = {"nhs-login-nhs-number": str(persisted_77yo_person), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_77yo_person}?includeActions=Y", headers=headers) @@ -470,10 +492,12 @@ def test_not_base_eligible( self, client: FlaskClient, persisted_person_no_cohorts: NHSNumber, - campaign_config_with_missing_descriptions_missing_rule_text: CampaignConfig, # noqa: ARG002 + consumer_mapped_to_campaign_missing_descriptions_and_rule_text: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person_no_cohorts)} + headers = {"nhs-login-nhs-number": str(persisted_person_no_cohorts), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_person_no_cohorts}?includeActions=Y", headers=headers) @@ -508,10 +532,12 @@ def test_not_eligible_by_rule( self, client: FlaskClient, persisted_person_pc_sw19: NHSNumber, - campaign_config_with_missing_descriptions_missing_rule_text: CampaignConfig, # noqa: ARG002 + consumer_mapped_to_campaign_missing_descriptions_and_rule_text: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person_pc_sw19)} + headers = {"nhs-login-nhs-number": str(persisted_person_pc_sw19), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_person_pc_sw19}?includeActions=Y", headers=headers) @@ -546,10 +572,12 @@ def test_not_actionable( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config_with_missing_descriptions_missing_rule_text: CampaignConfig, # noqa: ARG002 + consumer_mapped_to_campaign_missing_descriptions_and_rule_text: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person)} + headers = {"nhs-login-nhs-number": str(persisted_person), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_person}?includeActions=Y", headers=headers) @@ -590,10 +618,12 @@ def test_actionable( self, client: FlaskClient, persisted_77yo_person: NHSNumber, - campaign_config_with_missing_descriptions_missing_rule_text: CampaignConfig, # noqa: ARG002 + consumer_mapped_to_campaign_missing_descriptions_and_rule_text: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_77yo_person)} + headers = {"nhs-login-nhs-number": str(persisted_77yo_person), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_77yo_person}?includeActions=Y", headers=headers) @@ -636,10 +666,12 @@ def test_actionable_no_actions( self, client: FlaskClient, persisted_77yo_person: NHSNumber, - campaign_config_with_missing_descriptions_missing_rule_text: CampaignConfig, # noqa: ARG002 + consumer_mapped_to_campaign_missing_descriptions_and_rule_text: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_77yo_person)} + headers = {"nhs-login-nhs-number": str(persisted_77yo_person), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_77yo_person}?includeActions=N", headers=headers) @@ -710,10 +742,12 @@ def test_not_actionable_and_check_response_when_rule_mapper_is_absent_but_rule_c self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config_with_rules_having_rule_code: CampaignConfig, # noqa: ARG002 + consumer_mapped_to_campaign_having_rules_with_rule_code: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person)} + headers = {"nhs-login-nhs-number": str(persisted_person), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_person}?includeActions=Y", headers=headers) @@ -760,10 +794,12 @@ def test_not_actionable_and_check_response_when_rule_mapper_is_given( self, client: FlaskClient, persisted_person: NHSNumber, - campaign_config_with_rules_having_rule_mapper: CampaignConfig, # noqa: ARG002 + consumer_mapped_to_campaign_having_rules_with_rule_mapper: ConsumerMapping, # noqa: ARG002 + consumer_id: ConsumerId, + secretsmanager_client: BaseClient, # noqa: ARG002 ): # Given - headers = {"nhs-login-nhs-number": str(persisted_person)} + headers = {"nhs-login-nhs-number": str(persisted_person), CONSUMER_ID: consumer_id} # When response = client.get(f"/patient-check/{persisted_person}?includeActions=Y", headers=headers) @@ -805,3 +841,326 @@ def test_not_actionable_and_check_response_when_rule_mapper_is_given( ) ), ) + + @pytest.mark.parametrize( + ( + "campaign_configs", + "consumer_mappings", + "consumer_id", + "requested_conditions", + "requested_category", + "expected_targets", + ), + [ + # ============================================================ + # Group 1: Consumer is mapped, campaign exists in S3, requesting + # ============================================================ + # 1.1 Consumer is mapped; multiple active campaigns exist; requesting ALL + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + { + "consumer-id": [ + {"Campaign": "RSV_campaign_id"}, + {"Campaign": "COVID_campaign_id"}, + ] + }, + "consumer-id", + "ALL", + "VACCINATIONS", + ["RSV", "COVID"], + ), + # 1.2 Consumer is mapped; requested single campaign exists and is mapped + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + { + "consumer-id": [ + {"Campaign": "RSV_campaign_id"}, + {"Campaign": "COVID_campaign_id"}, + ] + }, + "consumer-id", + "RSV", + "VACCINATIONS", + ["RSV"], + ), + # 1.3 Consumer is mapped; requested multiple campaigns exist and are mapped + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + { + "consumer-id": [ + {"Campaign": "RSV_campaign_id"}, + {"Campaign": "COVID_campaign_id"}, + ] + }, + "consumer-id", + "RSV,COVID", + "VACCINATIONS", + ["RSV", "COVID"], + ), + # ============================================================ + # Group 2: Consumer is mapped, campaign does NOT exist in S3 + # ============================================================ + # 2.1 Consumer is mapped; requested campaign exists in S3 but not mapped + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + { + "consumer-id": [ + {"Campaign": "RSV_campaign_id"}, + {"Campaign": "COVID_campaign_id"}, + ] + }, + "consumer-id", + "FLU", + "VACCINATIONS", + [], + ), + # 2.2 Consumer is mapped, but none of the mapped campaigns exist in S3 + ( + [ + ("MMR", "MMR_campaign_id"), + ], + { + "consumer-id": [ + {"Campaign": "RSV_campaign_id"}, + {"Campaign": "COVID_campaign_id"}, + ] + }, + "consumer-id", + "ALL", + "VACCINATIONS", + [], + ), + # 2.3 Consumer is mapped; requested mapped campaign does not exist in S3 + ( + [ + ("MMR", "MMR_campaign_id"), + ], + { + "consumer-id": [ + {"Campaign": "RSV_campaign_id"}, + {"Campaign": "COVID_campaign_id"}, + ] + }, + "consumer-id", + "RSV", + "VACCINATIONS", + [], + ), + # ============================================================ + # Group 3: Consumer is NOT mapped, campaign exists in S3 + # ============================================================ + # 3.1 Consumer not mapped; requesting ALL + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + { + "consumer-id": [ + {"Campaign": "RSV_campaign_id"}, + {"Campaign": "COVID_campaign_id"}, + ] + }, + "another-consumer-id", + "ALL", + "VACCINATIONS", + [], + ), + # 3.2 Consumer not mapped; requesting specific campaign + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + { + "consumer-id": [ + {"Campaign": "RSV_campaign_id"}, + {"Campaign": "COVID_campaign_id"}, + ] + }, + "another-consumer-id", + "RSV", + "VACCINATIONS", + [], + ), + # ============================================================ + # Group 4: Consumer NOT mapped, campaign does NOT exist in S3 + # ============================================================ + # 4.1 Consumer mapped; requested campaign does not exist + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + { + "consumer-id": [ + {"Campaign": "RSV_campaign_id"}, + {"Campaign": "COVID_campaign_id"}, + ] + }, + "consumer-id", + "HPV", + "VACCINATIONS", + [], + ), + # 4.2 No consumer mappings; requesting ALL + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + {}, + "consumer-id", + "ALL", + "VACCINATIONS", + [], + ), + # 4.3 No consumer mappings; requesting specific campaign + ( + [ + ("RSV", "RSV_campaign_id"), + ("COVID", "COVID_campaign_id"), + ("FLU", "FLU_campaign_id"), + ], + {}, + "consumer-id", + "RSV", + "VACCINATIONS", + [], + ), + ], + indirect=["campaign_configs", "consumer_mappings"], + ) + def test_valid_response_when_consumer_has_a_valid_campaign_config_mapping( # noqa: PLR0913 + self, + client: FlaskClient, + persisted_person: NHSNumber, + secretsmanager_client: BaseClient, # noqa: ARG002 + campaign_configs: CampaignConfig, # noqa: ARG002 + consumer_mappings: ConsumerMapping, # noqa: ARG002 + consumer_id: str, + requested_conditions: str, + requested_category: str, + expected_targets: list[str], + ): + # Given + headers = {"nhs-login-nhs-number": str(persisted_person), CONSUMER_ID: consumer_id} + + # When + response = client.get( + f"/patient-check/{persisted_person}?includeActions=Y&category={requested_category}&conditions={requested_conditions}", + headers=headers, + ) + + assert_that( + response, + is_response() + .with_status_code(HTTPStatus.OK) + .and_text( + is_json_that( + has_entry( + "processedSuggestions", + # This ensures ONLY these items exist, no extras like FLU + contains_inanyorder(*[has_entry("condition", i) for i in expected_targets]), + ) + ) + ), + ) + + @pytest.mark.parametrize( + ("consumer_id", "expected_campaign_id"), + [ + # Consumer is mapped only to RSV_campaign_id_1 + ("consumer-id-1", "RSV_campaign_id_1"), + # Consumer is mapped only to RSV_campaign_id_2 + ("consumer-id-2", "RSV_campaign_id_2"), + # Edge-case : Consumer-id-3a is mapped to multiple active campaigns, so only one taken. + ("consumer-id-3a", "RSV_campaign_id_3"), + # Edge-case : Consumer-id-3b is mapped to multiple active campaigns, so only one taken. + ("consumer-id-3b", "RSV_campaign_id_3"), + # Edge-case : Consumer is mapped to inactive inactive_RSV_campaign_id_5 and active RSV_campaign_id_6 + ("consumer-id-4", "RSV_campaign_id_6"), + # Edge-case : Consumer is mapped only to inactive RSV_campaign_id_5 + ("consumer-id-5", None), + ], + ) + @pytest.mark.parametrize( + ("campaign_configs", "consumer_mappings", "requested_conditions", "requested_category"), + [ + ( + [ + ("RSV", "RSV_campaign_id_1"), + ("RSV", "RSV_campaign_id_2"), + ("RSV", "RSV_campaign_id_3"), + ("RSV", "RSV_campaign_id_4"), + ("RSV", "inactive_RSV_campaign_id_5", "inactive"), # inactive iteration + ("RSV", "RSV_campaign_id_6"), + ], + { + "consumer-id-1": [{"Campaign": "RSV_campaign_id_1"}], + "consumer-id-2": [{"Campaign": "RSV_campaign_id_2"}], + "consumer-id-3a": [{"Campaign": "RSV_campaign_id_3"}, {"Campaign": "RSV_campaign_id_4"}], + "consumer-id-3b": [{"Campaign": "RSV_campaign_id_4"}, {"Campaign": "RSV_campaign_id_3"}], + "consumer-id-4": [{"Campaign": "inactive_RSV_campaign_id_5"}, {"Campaign": "RSV_campaign_id_6"}], + "consumer-id-5": [{"Campaign": "inactive_RSV_campaign_id_5"}], + }, + "RSV", + "VACCINATIONS", + ) + ], + indirect=["campaign_configs", "consumer_mappings"], + ) + def test_if_correct_campaign_is_chosen_for_the_consumer_if_there_exists_multiple_campaign_per_target( # noqa : PLR0913 + self, + client: FlaskClient, + persisted_person: NHSNumber, + secretsmanager_client: BaseClient, # noqa: ARG002 + audit_bucket: BucketName, + s3_client: BaseClient, + campaign_configs: CampaignConfig, # noqa: ARG002 + consumer_mappings: ConsumerMapping, # noqa: ARG002 + consumer_id: str, + requested_conditions: str, + requested_category: str, + expected_campaign_id: list[str], + ): + # Given + headers = {"nhs-login-nhs-number": str(persisted_person), CONSUMER_ID: consumer_id} + + # When + client.get( + f"/patient-check/{persisted_person}?includeActions=Y&category={requested_category}&conditions={requested_conditions}", + headers=headers, + ) + + objects = s3_client.list_objects_v2(Bucket=audit_bucket).get("Contents", []) + object_keys = [obj["Key"] for obj in objects] + latest_key = sorted(object_keys)[-1] + audit_data = json.loads(s3_client.get_object(Bucket=audit_bucket, Key=latest_key)["Body"].read()) + + # Then + if expected_campaign_id is not None: + assert_that(len(audit_data["response"]["condition"]), equal_to(1)) + assert_that(audit_data["response"]["condition"][0].get("campaignId"), equal_to(expected_campaign_id)) + else: + assert_that(len(audit_data["response"]["condition"]), equal_to(0)) diff --git a/tests/integration/lambda/test_app_running_as_lambda.py b/tests/integration/lambda/test_app_running_as_lambda.py index 9c473696a..5ee7c2fdb 100644 --- a/tests/integration/lambda/test_app_running_as_lambda.py +++ b/tests/integration/lambda/test_app_running_as_lambda.py @@ -23,7 +23,9 @@ ) from yarl import URL +from eligibility_signposting_api.config.constants import CONSUMER_ID from eligibility_signposting_api.model.campaign_config import CampaignConfig +from eligibility_signposting_api.model.consumer_mapping import ConsumerId, ConsumerMapping from eligibility_signposting_api.model.eligibility_status import NHSNumber from eligibility_signposting_api.repos.campaign_repo import BucketName @@ -34,7 +36,8 @@ def test_install_and_call_lambda_flask( lambda_client: BaseClient, flask_function: str, persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG001 + consumer_mapped_to_rsv_campaign: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, ): """Given lambda installed into localstack, run it via boto3 lambda client""" # Given @@ -49,6 +52,7 @@ def test_install_and_call_lambda_flask( "accept": "application/json", "content-type": "application/json", "nhs-login-nhs-number": str(persisted_person), + CONSUMER_ID: consumer_id, }, "pathParameters": {"id": str(persisted_person)}, "requestContext": { @@ -85,7 +89,8 @@ def test_install_and_call_lambda_flask( def test_install_and_call_flask_lambda_over_http( persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG001 + consumer_mapped_to_rsv_campaign: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, api_gateway_endpoint: URL, ): """Given api-gateway and lambda installed into localstack, run it via http""" @@ -94,7 +99,7 @@ def test_install_and_call_flask_lambda_over_http( invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person}" response = httpx.get( invoke_url, - headers={"nhs-login-nhs-number": str(persisted_person)}, + headers={"nhs-login-nhs-number": str(persisted_person), CONSUMER_ID: consumer_id}, timeout=10, ) @@ -105,12 +110,14 @@ def test_install_and_call_flask_lambda_over_http( ) -def test_install_and_call_flask_lambda_with_unknown_nhs_number( +def test_install_and_call_flask_lambda_with_unknown_nhs_number( # noqa: PLR0913 flask_function: str, persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa: ARG001 + consumer_mapped_to_rsv_campaign: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, logs_client: BaseClient, api_gateway_endpoint: URL, + secretsmanager_client: BaseClient, # noqa: ARG001 ): """Given lambda installed into localstack, run it via http, with a nonexistent NHS number specified""" # Given @@ -120,7 +127,7 @@ def test_install_and_call_flask_lambda_with_unknown_nhs_number( invoke_url = f"{api_gateway_endpoint}/patient-check/{nhs_number}" response = httpx.get( invoke_url, - headers={"nhs-login-nhs-number": str(nhs_number)}, + headers={"nhs-login-nhs-number": str(nhs_number), CONSUMER_ID: consumer_id}, timeout=10, ) @@ -181,7 +188,9 @@ def get_log_messages(flask_function: str, logs_client: BaseClient) -> list[str]: def test_given_nhs_number_in_path_matches_with_nhs_number_in_headers_and_check_if_audited( # noqa: PLR0913 lambda_client: BaseClient, # noqa:ARG001 persisted_person: NHSNumber, - campaign_config: CampaignConfig, + rsv_campaign_config: CampaignConfig, + consumer_mapped_to_rsv_campaign: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, s3_client: BaseClient, audit_bucket: BucketName, api_gateway_endpoint: URL, @@ -195,6 +204,7 @@ def test_given_nhs_number_in_path_matches_with_nhs_number_in_headers_and_check_i invoke_url, headers={ "nhs-login-nhs-number": str(persisted_person), + CONSUMER_ID: consumer_id, "x_request_id": "x_request_id", "x_correlation_id": "x_correlation_id", "nhsd_end_user_organisation_ods": "nhsd_end_user_organisation_ods", @@ -226,13 +236,13 @@ def test_given_nhs_number_in_path_matches_with_nhs_number_in_headers_and_check_i expected_conditions = [ { - "campaignId": campaign_config.id, - "campaignVersion": campaign_config.version, - "iterationId": campaign_config.iterations[0].id, - "iterationVersion": campaign_config.iterations[0].version, - "conditionName": campaign_config.target, + "campaignId": rsv_campaign_config.id, + "campaignVersion": rsv_campaign_config.version, + "iterationId": rsv_campaign_config.iterations[0].id, + "iterationVersion": rsv_campaign_config.iterations[0].version, + "conditionName": rsv_campaign_config.target, "status": "not_actionable", - "statusText": f"You should have the {campaign_config.target} vaccine", + "statusText": f"You should have the {rsv_campaign_config.target} vaccine", "eligibilityCohorts": [{"cohortCode": "cohort1", "cohortStatus": "not_actionable"}], "eligibilityCohortGroups": [ { @@ -277,7 +287,6 @@ def test_given_nhs_number_in_path_matches_with_nhs_number_in_headers_and_check_i def test_given_nhs_number_in_path_does_not_match_with_nhs_number_in_headers_results_in_error_response( lambda_client: BaseClient, # noqa:ARG001 persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa:ARG001 api_gateway_endpoint: URL, ): # Given @@ -285,7 +294,7 @@ def test_given_nhs_number_in_path_does_not_match_with_nhs_number_in_headers_resu invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person}" response = httpx.get( invoke_url, - headers={"nhs-login-nhs-number": f"123{persisted_person!s}"}, + headers={"nhs-login-nhs-number": f"123{persisted_person!s}", "consumer-id": "test_consumer_id"}, timeout=10, ) @@ -324,7 +333,6 @@ def test_given_nhs_number_in_path_does_not_match_with_nhs_number_in_headers_resu def test_given_nhs_number_not_present_in_headers_results_in_error_response( lambda_client: BaseClient, # noqa:ARG001 persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa:ARG001 api_gateway_endpoint: URL, ): # Given @@ -332,6 +340,7 @@ def test_given_nhs_number_not_present_in_headers_results_in_error_response( invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person}" response = httpx.get( invoke_url, + headers={"consumer-id": "test_consumer_id"}, timeout=10, ) @@ -370,7 +379,8 @@ def test_given_nhs_number_not_present_in_headers_results_in_error_response( def test_validation_of_query_params_when_all_are_valid( lambda_client: BaseClient, # noqa:ARG001 persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa:ARG001 + consumer_mapped_to_rsv_campaign: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, api_gateway_endpoint: URL, ): # Given @@ -378,7 +388,7 @@ def test_validation_of_query_params_when_all_are_valid( invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person}" response = httpx.get( invoke_url, - headers={"nhs-login-nhs-number": persisted_person}, + headers={"nhs-login-nhs-number": persisted_person, CONSUMER_ID: consumer_id}, params={"category": "VACCINATIONS", "conditions": "COVID19", "includeActions": "N"}, timeout=10, ) @@ -390,7 +400,6 @@ def test_validation_of_query_params_when_all_are_valid( def test_validation_of_query_params_when_invalid_conditions_is_specified( lambda_client: BaseClient, # noqa:ARG001 persisted_person: NHSNumber, - campaign_config: CampaignConfig, # noqa:ARG001 api_gateway_endpoint: URL, ): # Given @@ -398,7 +407,7 @@ def test_validation_of_query_params_when_invalid_conditions_is_specified( invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person}" response = httpx.get( invoke_url, - headers={"nhs-login-nhs-number": persisted_person}, + headers={"nhs-login-nhs-number": persisted_person, "consumer-id": "test_consumer_id"}, params={"category": "ALL", "conditions": "23-097"}, timeout=10, ) @@ -411,15 +420,19 @@ def test_given_person_has_unique_status_for_different_conditions_with_audit( # lambda_client: BaseClient, # noqa:ARG001 persisted_person_all_cohorts: NHSNumber, multiple_campaign_configs: list[CampaignConfig], + consumer_mapped_to_multiple_campaign_configs: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, s3_client: BaseClient, audit_bucket: BucketName, api_gateway_endpoint: URL, + secretsmanager_client: BaseClient, # noqa: ARG001 ): invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person_all_cohorts}" response = httpx.get( invoke_url, headers={ "nhs-login-nhs-number": str(persisted_person_all_cohorts), + CONSUMER_ID: consumer_id, "x_request_id": "x_request_id", "x_correlation_id": "x_correlation_id", "nhsd_end_user_organisation_ods": "nhsd_end_user_organisation_ods", @@ -553,7 +566,8 @@ def test_given_person_has_unique_status_for_different_conditions_with_audit( # def test_no_active_iteration_returns_empty_processed_suggestions( lambda_client: BaseClient, # noqa:ARG001 persisted_person_all_cohorts: NHSNumber, - inactive_iteration_config: list[CampaignConfig], # noqa:ARG001 + consumer_mapped_to_campaign_having_inactive_iteration_config: ConsumerMapping, # noqa:ARG001 + consumer_id: ConsumerId, api_gateway_endpoint: URL, ): invoke_url = f"{api_gateway_endpoint}/patient-check/{persisted_person_all_cohorts}" @@ -561,6 +575,7 @@ def test_no_active_iteration_returns_empty_processed_suggestions( invoke_url, headers={ "nhs-login-nhs-number": str(persisted_person_all_cohorts), + CONSUMER_ID: consumer_id, "x_request_id": "x_request_id", "x_correlation_id": "x_correlation_id", "nhsd_end_user_organisation_ods": "nhsd_end_user_organisation_ods", @@ -589,7 +604,8 @@ def test_no_active_iteration_returns_empty_processed_suggestions( def test_token_formatting_in_eligibility_response_and_audit( # noqa: PLR0913 lambda_client: BaseClient, # noqa:ARG001 person_with_all_data: NHSNumber, - campaign_config_with_tokens: CampaignConfig, # noqa:ARG001 + consumer_mapped_to_campaign_having_tokens: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, s3_client: BaseClient, audit_bucket: BucketName, api_gateway_endpoint: URL, @@ -599,7 +615,7 @@ def test_token_formatting_in_eligibility_response_and_audit( # noqa: PLR0913 invoke_url = f"{api_gateway_endpoint}/patient-check/{person_with_all_data}" response = httpx.get( invoke_url, - headers={"nhs-login-nhs-number": str(person_with_all_data)}, + headers={"nhs-login-nhs-number": str(person_with_all_data), CONSUMER_ID: consumer_id}, timeout=10, ) @@ -639,7 +655,8 @@ def test_token_formatting_in_eligibility_response_and_audit( # noqa: PLR0913 def test_incorrect_token_causes_internal_server_error( # noqa: PLR0913 lambda_client: BaseClient, # noqa:ARG001 person_with_all_data: NHSNumber, - campaign_config_with_invalid_tokens: CampaignConfig, # noqa:ARG001 + consumer_mapped_to_campaign_having_invalid_tokens: ConsumerMapping, # noqa: ARG001 + consumer_id: ConsumerId, s3_client: BaseClient, audit_bucket: BucketName, api_gateway_endpoint: URL, @@ -649,7 +666,7 @@ def test_incorrect_token_causes_internal_server_error( # noqa: PLR0913 invoke_url = f"{api_gateway_endpoint}/patient-check/{person_with_all_data}" response = httpx.get( invoke_url, - headers={"nhs-login-nhs-number": str(person_with_all_data)}, + headers={"nhs-login-nhs-number": str(person_with_all_data), CONSUMER_ID: consumer_id}, timeout=10, ) diff --git a/tests/test_data/test_consumer_mapping/test_consumer_mapping_config.json b/tests/test_data/test_consumer_mapping/test_consumer_mapping_config.json new file mode 100644 index 000000000..29127b19c --- /dev/null +++ b/tests/test_data/test_consumer_mapping/test_consumer_mapping_config.json @@ -0,0 +1,23 @@ + +{ + "consumer-id-123": [ + { + "Campaign": "RSV_campaign_id", + "Description": "RSV Ongoing for My Vaccines" + }, + { + "Campaign": "COVID_campaign_id", + "Description": "COVID Ongoing for My Vaccines" + } + ], + "consumer-id-456": [ + { + "Campaign": "RSV_campaign_id_NBS", + "Description": "RSV Ongoing for NBS" + }, + { + "Campaign": "COVID_campaign_id_NBS", + "Description": "RSV Ongoing for NBS" + } + ] +} diff --git a/tests/unit/common/test_request_validator.py b/tests/unit/common/test_request_validator.py index 7de1c776a..c19977dc9 100644 --- a/tests/unit/common/test_request_validator.py +++ b/tests/unit/common/test_request_validator.py @@ -48,7 +48,7 @@ def test_validate_request_params_success(self, app, caplog): with app.test_request_context( "/dummy?id=1234567890", - headers={"nhs-login-nhs-number": "1234567890"}, + headers={"nhs-login-nhs-number": "1234567890", "consumer-id": "test_consumer_id"}, method="GET", ): with caplog.at_level(logging.INFO): @@ -66,7 +66,7 @@ def test_validate_request_params_nhs_mismatch(self, app, caplog): with app.test_request_context( "/dummy?id=1234567890", - headers={"nhs-login-nhs-number": "0987654321"}, + headers={"nhs-login-nhs-number": "0987654321", "consumer-id": "test_consumer_id"}, method="GET", ): with caplog.at_level(logging.INFO): @@ -84,6 +84,58 @@ def test_validate_request_params_nhs_mismatch(self, app, caplog): assert issue["diagnostics"] == "You are not authorised to request information for the supplied NHS Number" assert response.headers["Content-Type"] == "application/fhir+json" + def test_validate_request_params_consumer_id_present(self, app, caplog): + mock_api = MagicMock(return_value="ok") + + decorator = request_validator.validate_request_params() + dummy_route = decorator(mock_api) + + with ( + app.test_request_context( + "/dummy?id=1234567890", + headers={ + "consumer-id": "some-consumer", + "nhs-login-nhs-number": "1234567890", + }, + method="GET", + ), + caplog.at_level(logging.INFO), + ): + response = dummy_route(nhs_number=request.args.get("id")) + + mock_api.assert_called_once() + assert response == "ok" + assert not any(record.levelname == "ERROR" for record in caplog.records) + + def test_validate_request_params_missing_consumer_id(self, app, caplog): + mock_api = MagicMock() + + decorator = request_validator.validate_request_params() + dummy_route = decorator(mock_api) + + with ( + app.test_request_context( + "/dummy?id=1234567890", + headers={"nhs-login-nhs-number": "1234567890"}, # no consumer ID + method="GET", + ), + caplog.at_level(logging.ERROR), + ): + response = dummy_route(nhs_number=request.args.get("id")) + + mock_api.assert_not_called() + + assert response is not None + assert response.status_code == HTTPStatus.FORBIDDEN + response_json = response.json + + issue = response_json["issue"][0] + assert issue["code"] == "forbidden" + assert issue["details"]["coding"][0]["code"] == "ACCESS_DENIED" + assert issue["details"]["coding"][0]["display"] == "Access has been denied to process this request." + assert issue["diagnostics"] == "You are not authorised to request" + assert response.headers["Content-Type"] == "application/fhir+json" + class TestValidateQueryParameters: @pytest.mark.parametrize( diff --git a/tests/unit/repos/test_consumer_mapping_repo.py b/tests/unit/repos/test_consumer_mapping_repo.py new file mode 100644 index 000000000..057042e27 --- /dev/null +++ b/tests/unit/repos/test_consumer_mapping_repo.py @@ -0,0 +1,62 @@ +import json +from unittest.mock import MagicMock + +import pytest + +from eligibility_signposting_api.model.consumer_mapping import ConsumerId +from eligibility_signposting_api.repos.consumer_mapping_repo import BucketName, ConsumerMappingRepo + + +class TestConsumerMappingRepo: + @pytest.fixture + def mock_s3_client(self): + return MagicMock() + + @pytest.fixture + def repo(self, mock_s3_client): + return ConsumerMappingRepo(s3_client=mock_s3_client, bucket_name=BucketName("test-bucket")) + + def test_get_permitted_campaign_ids_success(self, repo, mock_s3_client): + # Given + consumer_id = "user-123" + + # The expected output is just the IDs + expected_campaign_ids = ["flu-2024", "covid-2024"] + + # The mocked S3 data must match the new schema (objects with description) + mapping_data = { + consumer_id: [ + {"Campaign": "flu-2024", "Description": "Flu Shot Description"}, + {"Campaign": "covid-2024", "Description": "Covid Shot Description"}, + ] + } + + mock_s3_client.list_objects.return_value = {"Contents": [{"Key": "mappings.json"}]} + + body_json = json.dumps(mapping_data).encode("utf-8") + mock_s3_client.get_object.return_value = {"Body": MagicMock(read=lambda: body_json)} + + # When + result = repo.get_permitted_campaign_ids(ConsumerId(consumer_id)) + + # Then + assert result == expected_campaign_ids + mock_s3_client.list_objects.assert_called_once_with(Bucket="test-bucket") + mock_s3_client.get_object.assert_called_once_with(Bucket="test-bucket", Key="mappings.json") + + def test_get_permitted_campaign_ids_returns_none_when_missing(self, repo, mock_s3_client): + """ + Setup data where the consumer_id doesn't exist + We must still use the valid schema (dicts inside the list) to pass Pydantic validation + """ + valid_schema_data = {"other-user": [{"Campaign": "camp-1", "Description": "Some description"}]} + + mock_s3_client.list_objects.return_value = {"Contents": [{"Key": "mappings.json"}]} + body_json = json.dumps(valid_schema_data).encode("utf-8") + mock_s3_client.get_object.return_value = {"Body": MagicMock(read=lambda: body_json)} + + # When + result = repo.get_permitted_campaign_ids(ConsumerId("missing-user")) + + # Then + assert result is None diff --git a/tests/unit/services/test_eligibility_services.py b/tests/unit/services/test_eligibility_services.py index 504888f12..3d3b787cd 100644 --- a/tests/unit/services/test_eligibility_services.py +++ b/tests/unit/services/test_eligibility_services.py @@ -3,23 +3,43 @@ import pytest from hamcrest import assert_that, empty +from eligibility_signposting_api.model.campaign_config import CampaignConfig, CampaignID from eligibility_signposting_api.model.eligibility_status import NHSNumber from eligibility_signposting_api.repos import CampaignRepo, NotFoundError, PersonRepo +from eligibility_signposting_api.repos.consumer_mapping_repo import ConsumerMappingRepo from eligibility_signposting_api.services import EligibilityService, UnknownPersonError from eligibility_signposting_api.services.calculators.eligibility_calculator import EligibilityCalculatorFactory from tests.fixtures.matchers.eligibility import is_eligibility_status +@pytest.fixture +def mock_repos(): + return { + "person": MagicMock(spec=PersonRepo), + "campaign": MagicMock(spec=CampaignRepo), + "consumer": MagicMock(spec=ConsumerMappingRepo), + "factory": MagicMock(spec=EligibilityCalculatorFactory), + } + + +@pytest.fixture +def service(mock_repos): + return EligibilityService( + mock_repos["person"], mock_repos["campaign"], mock_repos["consumer"], mock_repos["factory"] + ) + + def test_eligibility_service_returns_from_repo(): # Given person_repo = MagicMock(spec=PersonRepo) campaign_repo = MagicMock(spec=CampaignRepo) + consumer_mapping_repo = MagicMock(spec=ConsumerMappingRepo) person_repo.get_eligibility = MagicMock(return_value=[]) - service = EligibilityService(person_repo, campaign_repo, EligibilityCalculatorFactory()) + service = EligibilityService(person_repo, campaign_repo, consumer_mapping_repo, EligibilityCalculatorFactory()) # When actual = service.get_eligibility_status( - NHSNumber("1234567890"), include_actions="Y", conditions=["ALL"], category="ALL" + NHSNumber("1234567890"), include_actions="Y", conditions=["ALL"], category="ALL", consumer_id="test_consumer_id" ) # Then @@ -30,9 +50,53 @@ def test_eligibility_service_for_nonexistent_nhs_number(): # Given person_repo = MagicMock(spec=PersonRepo) campaign_repo = MagicMock(spec=CampaignRepo) + consumer_mapping_repo = MagicMock(spec=ConsumerMappingRepo) person_repo.get_eligibility_data = MagicMock(side_effect=NotFoundError) - service = EligibilityService(person_repo, campaign_repo, EligibilityCalculatorFactory()) + service = EligibilityService(person_repo, campaign_repo, consumer_mapping_repo, EligibilityCalculatorFactory()) # When with pytest.raises(UnknownPersonError): - service.get_eligibility_status(NHSNumber("1234567890"), include_actions="Y", conditions=["ALL"], category="ALL") + service.get_eligibility_status( + NHSNumber("1234567890"), + include_actions="Y", + conditions=["ALL"], + category="ALL", + consumer_id="test_consumer_id", + ) + + +def test_get_eligibility_status_filters_permitted_campaigns(service, mock_repos): + """Tests that ONLY permitted campaigns reach the calculator factory.""" + # Given + nhs_number = NHSNumber("1234567890") + person_data = {"age": 65, "vulnerable": True} + mock_repos["person"].get_eligibility_data.return_value = person_data + + # Available campaigns in system + camp_a = MagicMock(spec=CampaignConfig, id=CampaignID("CAMP_A")) + camp_b = MagicMock(spec=CampaignConfig, id=CampaignID("CAMP_B")) + mock_repos["campaign"].get_campaign_configs.return_value = [camp_a, camp_b] + + # Consumer is only permitted to see CAMP_B + mock_repos["consumer"].get_permitted_campaign_ids.return_value = [CampaignID("CAMP_B")] + + # Mock calculator behavior + mock_calc = MagicMock() + mock_repos["factory"].get.return_value = mock_calc + mock_calc.get_eligibility_status.return_value = "eligible_result" + + # When + result = service.get_eligibility_status(nhs_number, "Y", ["FLU"], "G1", "consumer_xyz") + + # Then + # Verify the factory was called ONLY with camp_b + mock_repos["factory"].get.assert_called_once_with(person_data, [camp_b]) + assert result == "eligible_result" + + +def test_raises_unknown_person_error_on_repo_not_found(service, mock_repos): + """Tests that NotFoundError from repo is translated to UnknownPersonError.""" + mock_repos["person"].get_eligibility_data.side_effect = NotFoundError + + with pytest.raises(UnknownPersonError): + service.get_eligibility_status(NHSNumber("999"), "Y", [], "", "any") diff --git a/tests/unit/views/test_eligibility.py b/tests/unit/views/test_eligibility.py index 5c323a7b2..cecd26c38 100644 --- a/tests/unit/views/test_eligibility.py +++ b/tests/unit/views/test_eligibility.py @@ -29,10 +29,10 @@ ) from eligibility_signposting_api.services import EligibilityService, UnknownPersonError from eligibility_signposting_api.views.eligibility import ( + _get_or_default_query_params, build_actions, build_eligibility_cohorts, build_suitability_results, - get_or_default_query_params, ) from eligibility_signposting_api.views.response_model import eligibility_response from tests.fixtures.builders.model.eligibility import ( @@ -60,6 +60,7 @@ def get_eligibility_status( _include_actions: str, _conditions: list[str], _category: str, + _consumer_id: str, ) -> EligibilityStatus: return EligibilityStatusFactory.build() @@ -74,6 +75,7 @@ def get_eligibility_status( _include_actions: str, _conditions: list[str], _category: str, + _consumer_id: str, ) -> EligibilityStatus: raise UnknownPersonError @@ -100,7 +102,7 @@ def test_security_headers_present_on_successful_response(app: Flask, client: Fla get_app_container(app).override.service(AuditService, new=FakeAuditService()), ): # When - headers = {"nhs-login-nhs-number": "9876543210"} + headers = {"nhs-login-nhs-number": "9876543210", "Consumer-Id": "test_consumer_id"} response = client.get("/patient-check/9876543210", headers=headers) # Then @@ -128,7 +130,7 @@ def test_security_headers_present_on_error_response(app: Flask, client: FlaskCli get_app_container(app).override.service(AuditService, new=FakeAuditService()), ): # When - headers = {"nhs-login-nhs-number": "9876543210"} + headers = {"nhs-login-nhs-number": "9876543210", "consumer-id": "test_customer_id"} response = client.get("/patient-check/9876543210", headers=headers) # Then @@ -177,7 +179,7 @@ def test_nhs_number_given(app: Flask, client: FlaskClient): get_app_container(app).override.service(AuditService, new=FakeAuditService()), ): # Given - headers = {"nhs-login-nhs-number": str(12345)} + headers = {"nhs-login-nhs-number": str(12345), "consumer-id": "test_customer_id"} # When response = client.get("/patient-check/12345", headers=headers) @@ -190,7 +192,7 @@ def test_no_nhs_number_given(app: Flask, client: FlaskClient): # Given with get_app_container(app).override.service(EligibilityService, new=FakeUnknownPersonEligibilityService()): # Given - headers = {"nhs-login-nhs-number": str(12345)} + headers = {"nhs-login-nhs-number": str(12345), "consumer-id": "test_customer_id"} # When response = client.get("/patient-check/", headers=headers) @@ -229,7 +231,7 @@ def test_no_nhs_number_given(app: Flask, client: FlaskClient): def test_unexpected_error(app: Flask, client: FlaskClient): # Given - headers = {"nhs-login-nhs-number": str(12345)} + headers = {"nhs-login-nhs-number": str(12345), "consumer-id": "test_customer_id"} with get_app_container(app).override.service(EligibilityService, new=FakeUnexpectedErrorEligibilityService()): response = client.get("/patient-check/12345", headers=headers) @@ -439,7 +441,9 @@ def test_excludes_nulls_via_build_response(client: FlaskClient): return_value=mocked_response, ), ): - response = client.get("/patient-check/12345", headers={"nhs-login-nhs-number": str(12345)}) + response = client.get( + "/patient-check/12345", headers={"nhs-login-nhs-number": str(12345), "consumer-id": "test_customer_id"} + ) assert response.status_code == HTTPStatus.OK payload = json.loads(response.data) @@ -491,7 +495,9 @@ def test_build_response_include_values_that_are_not_null(client: FlaskClient): return_value=mocked_response, ), ): - response = client.get("/patient-check/12345", headers={"nhs-login-nhs-number": str(12345)}) + response = client.get( + "/patient-check/12345", headers={"nhs-login-nhs-number": str(12345), "consumer-id": "test_customer_id"} + ) assert response.status_code == HTTPStatus.OK payload = json.loads(response.data) @@ -507,7 +513,7 @@ def test_build_response_include_values_that_are_not_null(client: FlaskClient): def test_get_or_default_query_params_with_no_args(app: Flask): with app.test_request_context("/patient-check"): - result = get_or_default_query_params() + result = _get_or_default_query_params() expected = {"category": "ALL", "conditions": ["ALL"], "includeActions": "Y"} @@ -516,7 +522,7 @@ def test_get_or_default_query_params_with_no_args(app: Flask): def test_get_or_default_query_params_with_all_args(app: Flask): with app.test_request_context("/patient-check?includeActions=Y&category=VACCINATIONS&conditions=FLU"): - result = get_or_default_query_params() + result = _get_or_default_query_params() expected = {"includeActions": "Y", "category": "VACCINATIONS", "conditions": ["FLU"]} @@ -525,7 +531,7 @@ def test_get_or_default_query_params_with_all_args(app: Flask): def test_get_or_default_query_params_with_partial_args(app: Flask): with app.test_request_context("/patient-check?includeActions=N"): - result = get_or_default_query_params() + result = _get_or_default_query_params() expected = {"includeActions": "N", "category": "ALL", "conditions": ["ALL"]} @@ -534,13 +540,13 @@ def test_get_or_default_query_params_with_partial_args(app: Flask): def test_get_or_default_query_params_with_lowercase_y(app: Flask): with app.test_request_context("/patient-check?includeActions=y"): - result = get_or_default_query_params() + result = _get_or_default_query_params() assert_that(result["includeActions"], is_("Y")) def test_get_or_default_query_params_missing_include_actions(app: Flask): with app.test_request_context("/patient-check?category=SCREENING&conditions=COVID19,FLU"): - result = get_or_default_query_params() + result = _get_or_default_query_params() expected = {"includeActions": "Y", "category": "SCREENING", "conditions": ["COVID19", "FLU"]} @@ -581,3 +587,30 @@ def test_status_endpoint(app: Flask, client: FlaskClient): ) ), ) + + +def test_consumer_id_is_passed_to_service(app: Flask, client: FlaskClient): + """ + Verifies that the consumer ID from the header is actually passed + to the eligibility service call. + """ + # Given + mock_service = MagicMock(spec=EligibilityService) + mock_service.get_eligibility_status.return_value = EligibilityStatusFactory.build() + + with ( + get_app_container(app).override.service(EligibilityService, new=mock_service), + get_app_container(app).override.service(AuditService, new=FakeAuditService()), + ): + headers = {"nhs-login-nhs-number": "1234567890", "Consumer-Id": "specific_consumer_123"} + + # When + client.get("/patient-check/1234567890", headers=headers) + + # Then + # Verify the 5th positional argument or the keyword argument 'consumer_id' + mock_service.get_eligibility_status.assert_called_once() + args, _kwargs = mock_service.get_eligibility_status.call_args + + # Check that 'specific_consumer_123' was the consumer_id passed + assert args[4] == "specific_consumer_123"