diff --git a/fri/server/main.py b/fri/server/main.py index f94bc66..cf5c7ac 100644 --- a/fri/server/main.py +++ b/fri/server/main.py @@ -7,8 +7,80 @@ from pathlib import Path import json import platform +import re from flask_cors import CORS, cross_origin +# Input validation pattern for safe names (alphanumeric, dash, underscore, slash, dot, space) +SAFE_INPUT_PATTERN = re.compile(r'^[a-zA-Z0-9_\-/. ]+$') +# Pattern for filenames - no path separators or .. allowed +SAFE_FILENAME_PATTERN = re.compile(r'^[a-zA-Z0-9_\-. ]+$') + +def validate_input(value, field_name, required=False): + """Validate that input contains only safe characters.""" + if value is None: + if required: + raise ValueError(f"Missing required field: {field_name}") + return True + if not isinstance(value, str): + raise ValueError(f"Invalid {field_name}: must be a string") + if required and len(value) == 0: + raise ValueError(f"Missing required field: {field_name}") + if len(value) > 0 and not SAFE_INPUT_PATTERN.match(value): + raise ValueError(f"Invalid {field_name}: contains unsafe characters") + return True + +def validate_filename(value, field_name, required=False): + """Validate filename - no path separators or .. segments allowed.""" + if value is None: + if required: + raise ValueError(f"Missing required field: {field_name}") + return True + if not isinstance(value, str): + raise ValueError(f"Invalid {field_name}: must be a string") + if required and len(value) == 0: + raise ValueError(f"Missing required field: {field_name}") + # Reject path traversal attempts + if '..' in value: + raise ValueError(f"Invalid {field_name}: path traversal not allowed") + # Use basename to strip any path components + basename = os.path.basename(value) + if basename != value: + raise ValueError(f"Invalid {field_name}: must be a filename, not a path") + if len(value) > 0 and not SAFE_FILENAME_PATTERN.match(value): + raise ValueError(f"Invalid {field_name}: contains unsafe characters") + return True + +def validate_text_field(value, field_name, max_length=None): + """Validate text fields like PR title/body - allow more characters but check type/length.""" + if value is None: + return True + if not isinstance(value, str): + raise ValueError(f"Invalid {field_name}: must be a string") + if max_length and len(value) > max_length: + raise ValueError(f"Invalid {field_name}: too long (max {max_length} characters)") + return True + +def get_error_output(e): + """Extract error output from CalledProcessError, preferring stderr then output.""" + raw_output = None + if hasattr(e, 'stderr') and e.stderr: + raw_output = e.stderr + elif hasattr(e, 'output') and e.output: + raw_output = e.output + + if raw_output is None: + return "Command execution failed" + + if isinstance(raw_output, bytes): + try: + return raw_output.decode('utf-8', errors='replace') + except Exception: + return str(raw_output) + elif isinstance(raw_output, str): + return raw_output + else: + return str(raw_output) + cur_path = os.path.dirname(os.path.abspath(__file__)) concore_path = os.path.abspath(os.path.join(cur_path, '../../')) @@ -298,20 +370,34 @@ def clear(dir): def contribute(): try: data = request.json - PR_TITLE = data.get('title') - PR_BODY = data.get('desc') - AUTHOR_NAME = data.get('auth') - STUDY_NAME = data.get('study') - STUDY_NAME_PATH = data.get('path') - BRANCH_NAME = data.get('branch') + PR_TITLE = data.get('title') or '' + PR_BODY = data.get('desc') or '' + AUTHOR_NAME = data.get('auth') or '' + STUDY_NAME = data.get('study') or '' + STUDY_NAME_PATH = data.get('path') or '' + BRANCH_NAME = data.get('branch') or '' + + # Validate all user inputs to prevent command injection + # Strict validation for names/paths that go into command arguments + validate_input(STUDY_NAME, 'study', required=True) + validate_input(STUDY_NAME_PATH, 'path', required=True) + validate_input(AUTHOR_NAME, 'auth', required=True) + validate_input(BRANCH_NAME, 'branch', required=False) + + # For PR title/body, allow more characters but enforce type/length + validate_text_field(PR_TITLE, 'title', max_length=512) + validate_text_field(PR_BODY, 'desc', max_length=8192) + if(platform.uname()[0]=='Windows'): - proc=check_output(["contribute",STUDY_NAME,STUDY_NAME_PATH,AUTHOR_NAME,BRANCH_NAME,PR_TITLE,PR_BODY],cwd=concore_path,shell=True) + # Use cmd.exe /c to invoke contribute.bat on Windows + proc = subprocess.run(["cmd.exe", "/c", "contribute.bat", STUDY_NAME, STUDY_NAME_PATH, AUTHOR_NAME, BRANCH_NAME, PR_TITLE, PR_BODY], cwd=concore_path, check=True, capture_output=True, text=True) + output_string = proc.stdout else: if len(BRANCH_NAME)==0: proc = check_output([r"./contribute",STUDY_NAME,STUDY_NAME_PATH,AUTHOR_NAME],cwd=concore_path) else: proc = check_output([r"./contribute",STUDY_NAME,STUDY_NAME_PATH,AUTHOR_NAME,BRANCH_NAME,PR_TITLE,PR_BODY],cwd=concore_path) - output_string = proc.decode() + output_string = proc.decode() status=200 if output_string.find("/pulls/")!=-1: status=200 @@ -320,6 +406,11 @@ def contribute(): else: status=400 return jsonify({'message': output_string}),status + except ValueError as e: + return jsonify({'message': str(e)}), 400 + except subprocess.CalledProcessError as e: + output_string = get_error_output(e) + return jsonify({'message': output_string}), 501 except Exception as e: output_string = "Some Error occured.Please try after some time" status=501 @@ -365,18 +456,36 @@ def library(dir): dir_path = os.path.abspath(os.path.join(concore_path, dir_name)) filename = request.args.get('filename') library_path = request.args.get('path') - proc = 0 + + # Validate user inputs to prevent command injection + try: + # Use strict filename validation - no path separators or .. allowed + validate_filename(filename, 'filename', required=True) + validate_input(library_path, 'path', required=False) + except ValueError as e: + resp = jsonify({'message': str(e)}) + resp.status_code = 400 + return resp + if (library_path == None or library_path == ''): library_path = r"../tools" - if(platform.uname()[0]=='Windows'): - proc = subprocess.check_output([r"..\library", library_path, filename],shell=True, cwd=dir_path) - else: - proc = subprocess.check_output([r"../library", library_path, filename], cwd=dir_path) - if(proc != 0): - resp = jsonify({'message': proc.decode("utf-8")}) + try: + if(platform.uname()[0]=='Windows'): + # Use cmd.exe /c to invoke library.bat on Windows + result = subprocess.run(["cmd.exe", "/c", r"..\library.bat", library_path, filename], cwd=dir_path, check=True, capture_output=True, text=True) + proc = result.stdout + else: + proc = subprocess.check_output([r"../library", library_path, filename], cwd=dir_path) + proc = proc.decode("utf-8") + resp = jsonify({'message': proc}) resp.status_code = 201 return resp - else: + except subprocess.CalledProcessError as e: + error_output = get_error_output(e) + resp = jsonify({'message': f'Command execution failed: {error_output}'}) + resp.status_code = 500 + return resp + except Exception as e: resp = jsonify({'message': 'There is an Error'}) resp.status_code = 500 return resp