diff --git a/setup.py b/setup.py index 7e9cb34a3..0279f02df 100755 --- a/setup.py +++ b/setup.py @@ -26,36 +26,36 @@ Environmental variables: If building with system MuPDF (PYMUPDF_SETUP_MUPDF_BUILD is empty string): - + CFLAGS CXXFLAGS LDFLAGS Added to c, c++, and link commands. - + PYMUPDF_INCLUDES Colon-separated extra include paths. - + PYMUPDF_MUPDF_LIB Directory containing MuPDF libraries, (libmupdf.so, libmupdfcpp.so). - + PIPCL_SHOW_ENV If '0', we do not show environment variables on startup. - + PYMUPDF_SETUP_DEVENV Location of devenv.com on Windows. If unset we search for it - see wdev.py. if that fails we use just 'devenv.com'. PYMUPDF_SETUP_DUMMY If 1, we build dummy sdist and wheel with no files. - + PYMUPDF_SETUP_FLAVOUR Control building of separate wheels for PyMuPDF. - + Must be unset or a combination of 'p', 'b' and 'd'. - + Default is 'pbd'. - + 'p': Generated wheel contains PyMuPDF code. 'b': @@ -63,28 +63,28 @@ the Python version. 'd': Generated wheel contains includes and libraries for MuPDF. - + If 'p' is included, the generated wheel is called PyMuPDF. Otherwise if 'b' is included the generated wheel is called PyMuPDFb. Otherwise if 'd' is included the generated wheel is called PyMuPDFd. - + For example: - + 'pb': a `PyMuPDF` wheel with PyMuPDF runtime files and MuPDF runtime shared libraries. - + 'b': a `PyMuPDFb` wheel containing MuPDF runtime shared libraries. - + 'pbd' a `PyMuPDF` wheel with PyMuPDF runtime files and MuPDF runtime shared libraries, plus MuPDF build-time files (includes, *.lib files on Windows). - + 'd': a `PyMuPDFd` wheel containing MuPDF build-time files (includes, *.lib files on Windows). - + PYMUPDF_SETUP_LIBCLANG For internal testing. - + PYMUPDF_SETUP_MUPDF_BUILD If unset or '-', use internal hard-coded default MuPDF location. Otherwise overrides location of MuPDF when building PyMuPDF: @@ -99,19 +99,19 @@ Passed as arg to pipcl.git_get(). Otherwise: Location of mupdf directory. - + PYMUPDF_SETUP_MUPDF_BSYMBOLIC If '0' we do not link libmupdf.so with -Bsymbolic. - + PYMUPDF_SETUP_MUPDF_TESSERACT If '0' we build MuPDF without Tesseract. - + PYMUPDF_SETUP_MUPDF_BUILD_TYPE Unix only. Controls build type of MuPDF. Supported values are: debug memento release (default) - + PYMUPDF_SETUP_FAKE_NOGIL If '1' we (incorrectly) claim we are thread-safe. @@ -121,20 +121,20 @@ PYMUPDF_SETUP_MUPDF_REFCHECK_IF Should be preprocessor statement to enable MuPDF reference count checking. - + As of 2024-09-27, MuPDF default is `#ifndef NDEBUG`. PYMUPDF_SETUP_MUPDF_TRACE_IF Should be preprocessor statement to enable MuPDF runtime diagnostics in response to environment variables such as MUPDF_trace. - + As of 2024-09-27, MuPDF default is `#ifndef NDEBUG`. PYMUPDF_SETUP_MUPDF_THIRD If '0' and we are building on Linux with the system MuPDF (i.e. PYMUPDF_SETUP_MUPDF_BUILD=''), then don't link with `-lmupdf-third`. - + PYMUPDF_SETUP_MUPDF_VS_UPGRADE If '1' we run mupdf `scripts/mupdfwrap.py` with `--vs-upgrade 1` to help Windows builds work with Visual Studio versions newer than 2019. @@ -156,31 +156,31 @@ PYMUPDF_SETUP_MUPDF_OVERWRITE_CONFIG If '0' we do not overwrite MuPDF's include/mupdf/fitz/config.h with PyMuPDF's own configuration file, before building MuPDF. - + PYMUPDF_SETUP_MUPDF_REBUILD If 0 we do not (re)build mupdf. - + PYMUPDF_SETUP_PY_LIMITED_API If not '0', we build for current Python's stable ABI. - + PYMUPDF_SETUP_URL_WHEEL If set, we use an existing wheel instead of building a new wheel. - + If starts with `http://` or `https://`: If ends with '/', we append our wheel name and download. Otherwise we download directly. - + If starts with `file://`: If ends with '/' we look for a matching wheel name, `using pipcl.wheel_name_match()` to cope with differing platform tags, for example our `manylinux2014_x86_64` will match with an existing wheel with `manylinux2014_x86_64.manylinux_2_17_x86_64`. - + Any other prefix is an error. PYMUPDF_SETUP_SWIG If set, we use this instead of `swig`. - + WDEV_VS_YEAR If set, we use as Visual Studio year, for example '2019' or '2022'. @@ -257,7 +257,7 @@ def _fs_remove(path): os.remove( path) except Exception as e: pass - + if os.path.exists(path): # Try deleting `path` as a directory. Need to use # shutil.rmtree() callback to handle permission problems; see: @@ -268,7 +268,7 @@ def error_fn(fn, path, excinfo): os.chmod(path, stat.S_IWRITE) fn(path) shutil.rmtree( path, onerror=error_fn) - + assert not os.path.exists( path) @@ -301,7 +301,7 @@ def tar_check(path, mode='r:gz', prefix=None, remove=False): As tarfile.open(). prefix: If not None, we fail if tar file's is not . - + Returns the directory name (which will be if not None). ''' with tarfile.open( path, mode) as t: @@ -324,7 +324,7 @@ def tar_check(path, mode='r:gz', prefix=None, remove=False): def tar_extract(path, mode='r:gz', prefix=None, exists='raise'): ''' Extracts tar file into single local directory. - + We fail if items in tar file have different . path: @@ -338,7 +338,7 @@ def tar_extract(path, mode='r:gz', prefix=None, exists='raise'): 'raise': raise exception. 'remove': remove existing file/directory before extracting. 'return': return without extracting. - + Returns the directory name (which will be if not None, with '/' appended if not already present). ''' @@ -393,7 +393,7 @@ def git_info( directory): def git_patch(directory, patch, hard=False): ''' Applies string with `git patch` in . - + If is true we clean the tree with `git checkout .` and then apply the patch. @@ -427,7 +427,7 @@ def git_patch(directory, patch, hard=False): def get_mupdf_internal(out, location=None, local_tgz=None): ''' Gets MuPDF as either a .tgz or a local directory. - + Args: out: Either 'dir' (we return name of local directory containing mupdf) or 'tgz' (we return @@ -446,25 +446,25 @@ def get_mupdf_internal(out, location=None, local_tgz=None): `location_out` is `location` if not None, else the hard-coded default location. - + ''' log(f'get_mupdf_internal(): {out=} {location=}') assert out in ('dir', 'tgz') if location is None: location = f'https://mupdf.com/downloads/archive/mupdf-{version_mupdf}-source.tar.gz' #location = 'git:--branch master https://github.com/ArtifexSoftware/mupdf.git' - + if location == '': # Use system mupdf. return None, location - + local_dir = None if local_tgz: assert os.path.isfile(local_tgz) elif location.startswith( 'git:'): local_dir = 'mupdf-git' pipcl.git_get(local_dir, text=location, remote='https://github.com/ArtifexSoftware/mupdf.git') - + # Show sha of checkout. run( f'cd {local_dir} && git show --pretty=oneline|head -n 1', @@ -495,7 +495,7 @@ def get_mupdf_internal(out, location=None, local_tgz=None): else: assert os.path.isdir(location), f'Local MuPDF does not exist: {location=}' local_dir = location - + assert bool(local_dir) != bool(local_tgz) if out == 'dir': if not local_dir: @@ -522,14 +522,14 @@ def get_mupdf_internal(out, location=None, local_tgz=None): return os.path.abspath( local_tgz), location else: assert 0, f'Unrecognised {out=}' - - + + def get_mupdf_tgz(): ''' Creates .tgz file called containing MuPDF source, for inclusion in an sdist. - + What we do depends on environmental variable PYMUPDF_SETUP_MUPDF_TGZ; see docs at start of this file for details. @@ -577,11 +577,11 @@ def build(): pipcl.py `build_fn()` callback. ''' #pipcl.show_sysconfig() - + if PYMUPDF_SETUP_DUMMY == '1': log(f'{PYMUPDF_SETUP_DUMMY=} Building dummy wheel with no files.') return list() - + # Download MuPDF. # mupdf_local, mupdf_location = get_mupdf() @@ -592,12 +592,12 @@ def build(): build_type = os.environ.get( 'PYMUPDF_SETUP_MUPDF_BUILD_TYPE', 'release') assert build_type in ('debug', 'memento', 'release'), \ f'Unrecognised build_type={build_type!r}' - + overwrite_config = os.environ.get('PYMUPDF_SETUP_MUPDF_OVERWRITE_CONFIG', '1') == '1' - + PYMUPDF_SETUP_MUPDF_REFCHECK_IF = os.environ.get('PYMUPDF_SETUP_MUPDF_REFCHECK_IF') PYMUPDF_SETUP_MUPDF_TRACE_IF = os.environ.get('PYMUPDF_SETUP_MUPDF_TRACE_IF') - + # Build MuPDF shared libraries. # if windows: @@ -627,7 +627,7 @@ def build(): PYMUPDF_SETUP_FAKE_NOGIL, ) log( f'build(): mupdf_build_dir={mupdf_build_dir!r}') - + # Build `extra` module. # if 'p' in PYMUPDF_SETUP_FLAVOUR: @@ -640,24 +640,26 @@ def build(): else: log(f'Not building extension.') path_so_leaf = None - + # Generate list of (from, to) items to return to pipcl. What we add depends # on PYMUPDF_SETUP_FLAVOUR. # - ret = list() + ret = list() def add(flavour, from_, to_): assert flavour in 'pbd' if flavour in PYMUPDF_SETUP_FLAVOUR: ret.append((from_, to_)) - + to_dir = 'pymupdf/' to_dir_d = f'{to_dir}/mupdf-devel' - + # Add implementation files. add('p', f'{g_root}/src/__init__.py', to_dir) add('p', f'{g_root}/src/__main__.py', to_dir) add('p', f'{g_root}/src/pymupdf.py', to_dir) add('p', f'{g_root}/src/table.py', to_dir) + add('p', f'{g_root}/src/TableGridExtractorV4.py', to_dir) + add('p', f'{g_root}/src/table_grid_model_v4.onnx', to_dir) add('p', f'{g_root}/src/utils.py', to_dir) add('p', f'{g_root}/src/_wxcolors.py', to_dir) add('p', f'{g_root}/src/_apply_pages.py', to_dir) @@ -718,7 +720,7 @@ def add(flavour, from_, to_): assert header_abs.startswith(root) header_rel = header_abs[len(root)+1:] add('d', f'{header_abs}', f'{to_dir_d}/include/{header_rel}') - + # Add a .py file containing location of MuPDF. try: sha, comment, diff, branch = git_info(g_root) @@ -751,7 +753,7 @@ def int_or_0(text): text += f'fake_no_gil = {PYMUPDF_SETUP_FAKE_NOGIL=="1"!r}\n' log(f'_build.py is:\n{textwrap.indent(text, " ")}') add('p', text.encode(), f'{to_dir}/_build.py') - + # Add single README file. if 'p' in PYMUPDF_SETUP_FLAVOUR: add('p', f'{g_root}/README.md', '$dist-info/README.md') @@ -759,14 +761,14 @@ def int_or_0(text): add('b', f'{g_root}/READMEb.md', '$dist-info/README.md') elif 'd' in PYMUPDF_SETUP_FLAVOUR: add('d', f'{g_root}/READMEd.md', '$dist-info/README.md') - + return ret def env_add(env, name, value, sep=' ', prepend=False, verbose=False): ''' Appends/prepends `` to `env[name]`. - + If `name` is not in `env`, we use os.environ[name] if it exists. ''' v = env.get(name) @@ -794,7 +796,7 @@ def build_mupdf_windows( PYMUPDF_SETUP_MUPDF_TRACE_IF, PYMUPDF_SETUP_FAKE_NOGIL, ): - + assert mupdf_local mupdf_version_tuple = get_mupdf_version(mupdf_local) log(f'{overwrite_config=}') @@ -802,7 +804,7 @@ def build_mupdf_windows( wp = pipcl.wdev.WindowsPython() tesseract = '' if os.environ.get('PYMUPDF_SETUP_MUPDF_TESSERACT') == '0' else 'tesseract-' windows_build_tail = f'build\\shared-{tesseract}{build_type}' - + if overwrite_config: if mupdf_version_tuple >= (1, 28): # Tell mupdf build to use, for example, `/Build "ReleaseTofuCjkExt|x64"`. @@ -824,7 +826,7 @@ def build_mupdf_windows( with open(mupdf_config_h, 'w') as f: f.write(text) os.utime(mupdf_config_h, (st.st_atime, st.st_mtime)) - + if g_py_limited_api: windows_build_tail += f'-Py_LIMITED_API_{pipcl.current_py_limited_api()}' if PYMUPDF_SETUP_FAKE_NOGIL == '1': @@ -852,7 +854,7 @@ def build_mupdf_windows( command = f'cd "{mupdf_local}" && "{sys.executable}" ./scripts/mupdfwrap.py' if os.environ.get('PYMUPDF_SETUP_MUPDF_VS_UPGRADE') == '1': command += ' --vs-upgrade 1' - + # Would like to simply do f'... --devenv {shutil.quote(devenv)}', but # it looks like if `devenv` has spaces then `shutil.quote()` puts it # inside single quotes, which then appear to be ignored when run by @@ -879,7 +881,7 @@ def build_mupdf_windows( log( f'Building MuPDF by running: {command}') subprocess.run( command, shell=True, check=True) log( f'Finished building mupdf.') - + return windows_build_dir @@ -918,11 +920,11 @@ def build_mupdf_unix( Args: mupdf_local: Path of MuPDF directory or None if we are using system MuPDF. - + Returns the absolute path of build directory within MuPDF, e.g. `.../mupdf/build/pymupdf-shared-release`, or `None` if we are using the system MuPDF. - ''' + ''' if not mupdf_local: log( f'Using system mupdf.') return None @@ -937,7 +939,7 @@ def build_mupdf_unix( if openbsd or freebsd: env_add(env, 'CXX', 'c++', ' ') - + if darwin and os.environ.get('GITHUB_ACTIONS') == 'true': if os.environ.get('ImageOS') == 'macos13': # On Github macos13 we need to use Clang/LLVM (Homebrew) 15.0.7, @@ -959,7 +961,7 @@ def build_mupdf_unix( cxx = f'{cl15}/bin/clang++' env['CC'] = cc env['CXX'] = cxx - + # Show compiler versions. cc = env.get('CC', 'cc') cxx = env.get('CXX', 'c++') @@ -975,7 +977,7 @@ def build_mupdf_unix( env_add(env, 'XLIBS', archflags) mupdf_version_tuple = get_mupdf_version(mupdf_local) - + # We specify a build directory path containing 'pymupdf' so that we # coexist with non-PyMuPDF builds (because PyMuPDF builds have a # different config.h). @@ -1061,7 +1063,7 @@ def build_mupdf_unix( log( f'Building MuPDF by running: {command}') subprocess.run( command, shell=True, check=True) log( f'Finished building mupdf.') - + return unix_build_dir @@ -1088,7 +1090,7 @@ def _fs_update(text, path): if text != text0: with open( path, 'w') as f: f.write( text) - + def _build_extension( mupdf_local, mupdf_build_dir, build_type, g_py_limited_api): ''' @@ -1104,7 +1106,7 @@ def _build_extension( mupdf_local, mupdf_build_dir, build_type, g_py_limited_api f'{mupdf_local}/platform/c++/include', f'{mupdf_local}/include', ) - + log('Building PyMuPDF extension.') compile_extra_cpp = '' if darwin: @@ -1124,7 +1126,7 @@ def _build_extension( mupdf_local, mupdf_build_dir, build_type, g_py_limited_api f'{mupdf_build_dir}/libmupdf.so' f'{mupdf_build_dir}/libmupdfcpp.so' ] - + path_so_leaf = pipcl.build_extension( name = 'extra', path_i = f'{g_root}/src/extra.i', @@ -1144,7 +1146,7 @@ def _build_extension( mupdf_local, mupdf_build_dir, build_type, g_py_limited_api swig = PYMUPDF_SETUP_SWIG, nogil = (PYMUPDF_SETUP_FAKE_NOGIL=='1') ) - + return path_so_leaf @@ -1197,7 +1199,7 @@ def _extension_flags( mupdf_local, mupdf_build_dir, build_type): libraries = None if libpaths: libpaths = libpaths.split(':') - + if mupdf_local: includes = ( f'{mupdf_local}/include', @@ -1223,29 +1225,29 @@ def _extension_flags( mupdf_local, mupdf_build_dir, build_type): if cxxflags: compiler_extra += f' {cxxflags}' - return compiler_extra, linker_extra, includes, defines, optimise, debug, libpaths, libs, libraries, + return compiler_extra, linker_extra, includes, defines, optimise, debug, libpaths, libs, libraries, def clean(all_): pipcl.log(f'{all_=}') ret = list() ret.append(f'{g_root}/src/build') - + path_mupdf, _ = get_mupdf() - + # We remove mupdf directories directly with shutil.rmtree() instead of # returning them to pipcl, because pipcl will deliberately fail if asked to # remove things that are outside our checkout. shutil.rmtree(f'{path_mupdf}/platform/c++', ignore_errors=True) shutil.rmtree(f'{path_mupdf}/platform/python', ignore_errors=True) - + if all_: # Clean mupdf C library. shutil.rmtree(f'{path_mupdf}/build', ignore_errors=True) shutil.rmtree(f'{path_mupdf}/platform/win32', ignore_errors=True) shutil.rmtree(f'{path_mupdf}/platform/win32/Release', ignore_errors=True) shutil.rmtree(f'{path_mupdf}/platform/win32/x64', ignore_errors=True) - + pipcl.log(f'Returning: {ret=}') return ret @@ -1254,7 +1256,7 @@ def sdist(): ret = list() if PYMUPDF_SETUP_DUMMY == '1': return ret - + if PYMUPDF_SETUP_FLAVOUR == 'b': # Create a minimal sdist that will build/install a dummy PyMuPDFb. for p in ( @@ -1271,7 +1273,7 @@ def sdist(): ) ) return ret - + for p in pipcl.git_items( g_root): if p.startswith( ( @@ -1323,17 +1325,17 @@ def sdist(): version_b = '1.26.3' if os.path.exists(f'{g_root}/{g_pymupdfb_sdist_marker}'): - + # We are in a PyMuPDFb sdist. We specify a dummy package so that pip builds # from sdists work - pip's build using PyMuPDF's sdist will already create # the required binaries, but pip will still see `requires_dist` set to # 'PyMuPDFb', so will also download and build PyMuPDFb's sdist. # log(f'Specifying dummy PyMuPDFb wheel.') - + def get_requires_for_build_wheel(config_settings=None): return list() - + p = pipcl.Package( 'PyMuPDFb', version_b, @@ -1347,7 +1349,7 @@ def get_requires_for_build_wheel(config_settings=None): else: # A normal PyMuPDF package. - + with open( f'{g_root}/README.md', encoding='utf-8') as f: readme_p = f.read() @@ -1360,7 +1362,7 @@ def get_requires_for_build_wheel(config_settings=None): tag_python = None requires_dist = list() entry_points = None - + if 'p' in PYMUPDF_SETUP_FLAVOUR: version = version_p name = 'PyMuPDF' @@ -1387,7 +1389,7 @@ def get_requires_for_build_wheel(config_settings=None): tag_python = 'py3' else: assert 0, f'Unrecognised {PYMUPDF_SETUP_FLAVOUR=}.' - + if os.environ.get('PYODIDE_ROOT'): # We can't pip install pytest on pyodide, so specify it here. requires_dist.append('pytest') @@ -1410,13 +1412,13 @@ def get_requires_for_build_wheel(config_settings=None): ('Tracker, https://github.com/pymupdf/PyMuPDF/issues'), ('Changelog, https://pymupdf.readthedocs.io/en/latest/changes.html'), ], - + entry_points = entry_points, - + fn_build=build, fn_clean=clean, fn_sdist=sdist, - + tag_python=tag_python, py_limited_api=g_py_limited_api, @@ -1438,7 +1440,7 @@ def platform_release_tuple(): r = tuple(int(i) for i in r) log(f'platform_release_tuple() returning {r=}.') return r - + ret = list() libclang = os.environ.get('PYMUPDF_SETUP_LIBCLANG') if libclang: @@ -1514,7 +1516,7 @@ def build_wheel( shutil.copy2(in_path, out_path_temp) else: assert 0, f'Unrecognised prefix in {PYMUPDF_SETUP_URL_WHEEL=}.' - + log(f'Renaming from:\n {out_path_temp}\nto:\n {out_path}.') os.rename(out_path_temp, out_path) return os.path.basename(out_path) diff --git a/src/TableGridExtractorV4.py b/src/TableGridExtractorV4.py new file mode 100644 index 000000000..e70999c5d --- /dev/null +++ b/src/TableGridExtractorV4.py @@ -0,0 +1,562 @@ +""" +TableGridExtractorV4.py + +Loads exported GridModelV4 and ConnClassifier ONNX models and predicts +table grid structure (line positions + cell connectivity) from a cropped +table image. + +Post-processing pipeline +------------------------ +1. Resize input image to GridModelV4 input size. +2. Run GridModelV4 ONNX inference. + Outputs: h_on_logit, h_offset, v_on_logit, v_offset, feature_map +3. Decode anchor outputs -> normalized line positions using 1D NMS. + score = sigmoid(on_logit); line_pos = anchor + offset * anchor_step +4. Convert normalized line positions to input image pixel coordinates. +5. Optionally filter empty lines and snap to bbox gaps. +6. Run ConnClassifier ONNX inference on cell features. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Optional +from dataclasses import dataclass + +import numpy as np # pylint: disable=import-error +import onnxruntime as ort # pylint: disable=import-error + +import pymupdf + +# --------------------------------------------------------------------------- +# Inline GridPrediction / CellInfo (standalone mode) +# --------------------------------------------------------------------------- + + +@dataclass +class GridPrediction: # type: ignore[no-redef] + h_lines: list + v_lines: list + h_heatmap: object = None + v_heatmap: object = None + h_confidences: object = None + v_confidences: object = None + db_prob_map: object = None + h_on_prob: object = None + v_on_prob: object = None + h_lines_norm: object = None + v_lines_norm: object = None + h_cls: object = None + connectivity: object = None + + +@dataclass +class CellInfo: # type: ignore[no-redef] + bbox_idx: int + row_start: int + row_end: int + col_start: int + col_end: int + row: int = 0 + col: int = 0 + text: str = "" + + @property + def row_span(self) -> int: + return self.row_end - self.row_start + + @property + def col_span(self) -> int: + return self.col_end - self.col_start + + +# --------------------------------------------------------------------------- +# Anchor decode helper +# --------------------------------------------------------------------------- + + +def _decode_anchors( + on_logits: np.ndarray, + offsets: np.ndarray, + threshold: float, + nms_min_dist: float = 0.0, +) -> tuple: + """ + Decode anchor-based line predictions to normalized positions. + + offset is anchor_step-normalized: anchor + offset * anchor_step to decode. + score = sigmoid(on_logit). + """ + max_n = len(on_logits) + on_prob = (1.0 / (1.0 + np.exp(-on_logits.astype(np.float64)))).astype(np.float32) + anchors = np.linspace(0.0, 1.0, max_n, dtype=np.float32) + mask = on_prob >= threshold + + anchor_step = 1.0 / (max_n - 1) if max_n > 1 else 1.0 + candidates = (anchors + offsets * anchor_step)[mask] + scores = on_prob[mask] + + if nms_min_dist > 0.0 and len(candidates) > 0: + order = np.argsort(-scores) + suppressed = np.zeros(len(candidates), dtype=bool) + keep = [] + for i in order: + if suppressed[i]: + continue + keep.append(i) + for j in range(len(candidates)): + if not suppressed[j] and j != i: + if abs(float(candidates[j]) - float(candidates[i])) < nms_min_dist: + suppressed[j] = True + positions = np.sort(candidates[np.array(keep, dtype=np.int32)]) + else: + positions = np.sort(candidates) + + return positions, on_prob, mask + + +# --------------------------------------------------------------------------- +# Cell feature extractor +# --------------------------------------------------------------------------- + + +def _extract_cell_features( + feature_map: np.ndarray, + h_lines_norm: np.ndarray, + v_lines_norm: np.ndarray, +) -> np.ndarray: + """Pool feature_map regions defined by detected line positions.""" + C, H, W = feature_map.shape + N = len(h_lines_norm) + M = len(v_lines_norm) + + if N < 2 or M < 2: + return np.zeros((max(N - 1, 1), max(M - 1, 1), C), dtype=np.float32) + + ys = np.clip((h_lines_norm * H).astype(np.int32), 0, H) + xs = np.clip((v_lines_norm * W).astype(np.int32), 0, W) + + cell_feat = np.zeros((N - 1, M - 1, C), dtype=np.float32) + for i in range(N - 1): + y1 = ys[i] + y2 = max(y1 + 1, ys[i + 1]) + y2 = min(y2, H) + if y1 >= H: + continue + for j in range(M - 1): + x1 = xs[j] + x2 = max(x1 + 1, xs[j + 1]) + x2 = min(x2, W) + if x1 >= W: + continue + region = feature_map[:, y1:y2, x1:x2] + if region.size == 0: + continue + cell_feat[i, j] = region.mean(axis=(1, 2)) + + return cell_feat + + +# --------------------------------------------------------------------------- +# Main extractor class +# --------------------------------------------------------------------------- + + +class TableGridExtractorV4: + """ONNX-based table grid extractor using the V4 anchor-Transformer model.""" + + def __init__( + self, + grid_onnx_path, + conn_onnx_path=None, + h_on_threshold: float = 0.25, + v_on_threshold: float = 0.2, + conn_threshold: float = 0.2, + nms_min_dist: float = 0.02, + filter_empty_lines: bool = True, + snap_to_bbox_gaps: bool = False, + header_type: str = "1-Row", + merge_type: str = "BBox", + providers: Optional[list] = None, + ) -> None: + if providers is None: + providers = ["CPUExecutionProvider"] + + self.grid_onnx_path = Path(grid_onnx_path) + self.conn_onnx_path = Path(conn_onnx_path) if conn_onnx_path else None + self.h_on_threshold = h_on_threshold + self.v_on_threshold = v_on_threshold + self.conn_threshold = conn_threshold + self.nms_min_dist = nms_min_dist + self.filter_empty_lines = filter_empty_lines + self.snap_to_bbox_gaps = snap_to_bbox_gaps + self.header_type = header_type + self.merge_type = merge_type + + self._grid_sess = ort.InferenceSession( + str(self.grid_onnx_path), providers=providers + ) + self._conn_sess = ( + ort.InferenceSession(str(self.conn_onnx_path), providers=providers) + if self.conn_onnx_path and self.conn_onnx_path.exists() + else None + ) + + inp = self._grid_sess.get_inputs()[0] + self._input_h = int(inp.shape[2]) + self._input_w = int(inp.shape[3]) + + dummy = np.zeros((1, 3, self._input_h, self._input_w), dtype=np.float32) + outputs = self._grid_sess.run(None, {"image": dummy}) + self._max_h = int(outputs[0].shape[1]) + self._max_v = int(outputs[2].shape[1]) + + # ------------------------------------------------------------------ + # Post-processing helpers + # ------------------------------------------------------------------ + + @staticmethod + def _filter_empty_lines( + lines: list, + centers: list, + image_size: float, + ) -> list: + """ + Merge adjacent line pairs that have no bbox center between them. + CCL-style convergence. Leading/trailing empty gaps are removed. + """ + if not lines: + return lines + + current = sorted(lines) + changed = True + + while changed: + changed = False + if len(current) < 2: + break + result = [current[0]] + for i in range(1, len(current)): + lo_y, lo_score, lo_cls = result[-1] + hi_y, hi_score, hi_cls = current[i] + has_center = any(lo_y <= c <= hi_y for c in centers) + if not has_center: + if hi_score >= lo_score: + result[-1] = (hi_y, hi_score, hi_cls) + changed = True + else: + result.append(current[i]) + current = result + + return current + + @staticmethod + def _snap_lines_to_bbox_gaps( + h_lines: list, + bboxes_crop: np.ndarray, + snap_threshold: float = 0.0, + ) -> list: + """Snap h_lines crossing a bbox to nearest inter-row gap center.""" + if len(bboxes_crop) == 0 or len(h_lines) == 0: + return h_lines + + bottoms = np.sort(bboxes_crop[:, 3]) + tops = np.sort(bboxes_crop[:, 1]) + + gap_lines = [] + for bot in bottoms: + candidates = tops[tops > bot] + if len(candidates) > 0: + gap_lines.append((bot + float(candidates[0])) / 2.0) + + if not gap_lines: + return h_lines + + gap_arr = np.array( + sorted(set(round(g, 4) for g in gap_lines)), dtype=np.float32 + ) + used_gaps: set = set() + result = [] + for y in h_lines: + crosses = np.any((bboxes_crop[:, 1] < y) & (y < bboxes_crop[:, 3])) + if crosses: + dists = np.abs(gap_arr - y) + order = np.argsort(dists) + snapped = False + for idx in order: + nearest = float(gap_arr[idx]) + dist = float(dists[idx]) + if nearest in used_gaps: + continue + if snap_threshold <= 0.0 or dist <= snap_threshold: + result.append(nearest) + used_gaps.add(nearest) + snapped = True + break + if not snapped: + result.append(y) + else: + result.append(y) + return result + + # ------------------------------------------------------------------ + # Inference + # ------------------------------------------------------------------ + + def predict_grid(self, pixmap) -> GridPrediction: + orig_h, orig_w = pixmap.h, pixmap.w + pix_resized = pymupdf.Pixmap(pixmap, self._input_w, self._input_h) + if pix_resized.alpha: # If alpha channel was created, + pix_resized = pymupdf.Pixmap(pix_resized, 0) # remove it + img_rgb = np.frombuffer(pix_resized.samples, dtype=np.uint8).reshape( + pix_resized.height, pix_resized.width, pix_resized.n + ) + img = img_rgb.astype(np.float32) / 255.0 + inp = img.transpose(2, 0, 1)[np.newaxis] + + outputs = self._grid_sess.run(None, {"image": inp}) + h_on_logit = outputs[0][0] # (max_h,) + h_offset = outputs[1][0] # (max_h,) + v_on_logit = outputs[2][0] # (max_v,) + v_offset = outputs[3][0] # (max_v,) + feature_map = outputs[4][0] # (C, H', W') + + h_lines_norm, h_on_prob, h_mask = _decode_anchors( + h_on_logit, h_offset, self.h_on_threshold, self.nms_min_dist + ) + v_lines_norm, v_on_prob, _ = _decode_anchors( + v_on_logit, v_offset, self.v_on_threshold, self.nms_min_dist + ) + + h_cls = np.ones(len(h_lines_norm), dtype=np.int32) + h_lines = [float(y) * orig_h for y in h_lines_norm] + v_lines = [float(x) * orig_w for x in v_lines_norm] + + connectivity = None + if ( + self._conn_sess is not None + and len(h_lines_norm) >= 2 + and len(v_lines_norm) >= 2 + ): + cell_feat = _extract_cell_features(feature_map, h_lines_norm, v_lines_norm) + if np.isfinite(cell_feat).all(): + cell_inp = cell_feat.transpose(2, 0, 1)[np.newaxis].astype(np.float32) + conn_out = self._conn_sess.run(None, {"cell_features": cell_inp}) + conn_logits = conn_out[0][0] + conn_prob = 1.0 / (1.0 + np.exp(-conn_logits.astype(np.float64))) + connectivity = conn_prob.transpose(1, 2, 0).astype(np.float32) + + return GridPrediction( + h_lines=sorted(h_lines), + v_lines=sorted(v_lines), + h_on_prob=h_on_prob, + v_on_prob=v_on_prob, + h_lines_norm=h_lines_norm, + v_lines_norm=v_lines_norm, + h_cls=h_cls, + connectivity=connectivity, + ) + + def predict( + self, + pixmap, + bboxes=None, + texts=None, + span_threshold: float = 0.1, + ) -> tuple: + """ + Predict grid boundaries and optionally assign bboxes to grid cells. + + Parameters + ---------- + image_bgr : cropped table BGR image (any size) + bboxes : (N, 4) array-like of [x0, y0, x1, y1] in crop space. + If None or empty, only grid prediction is performed. + texts : list of text strings aligned with bboxes. + span_threshold : fractional overlap to trigger span expansion (default 0.1) + + Returns + ------- + (GridPrediction, list[CellInfo]) + CellInfo list is empty when bboxes is None or empty. + """ + grid = self.predict_grid(pixmap) + + if not bboxes: + return grid, [] + + img_rgb = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape( + pixmap.height, pixmap.width, pixmap.n + ) + + bboxes_arr = np.asarray(bboxes, dtype=np.float32) + crop_h = float(pixmap.h) + crop_w = float(pixmap.w) + + bboxes_crop = bboxes_arr + cx_list = sorted((float(b[0]) + float(b[2])) / 2.0 for b in bboxes_arr) + cy_list = sorted((float(b[1]) + float(b[3])) / 2.0 for b in bboxes_arr) + + max_h = len(grid.h_on_prob) + anchors = np.linspace(0.0, 1.0, max_h, dtype=np.float32) + orig_h = float(pixmap.h) + + h_tuples = [] + h_cls_list = ( + grid.h_cls.tolist() if grid.h_cls is not None else [1] * len(grid.h_lines) # pylint: disable=no-member + ) + for y, c in zip(grid.h_lines, h_cls_list): + y_norm = y / orig_h + idx = int(np.argmin(np.abs(anchors - y_norm))) + score = float(grid.h_on_prob[idx]) + h_tuples.append((y, score, c)) + + max_v = len(grid.v_on_prob) + anchors_v = np.linspace(0.0, 1.0, max_v, dtype=np.float32) + orig_w = float(pixmap.w) + + v_tuples = [] + for x in grid.v_lines: + x_norm = x / orig_w + idx = int(np.argmin(np.abs(anchors_v - x_norm))) + score = float(grid.v_on_prob[idx]) + v_tuples.append((x, score, 0)) + + if self.filter_empty_lines: + h_tuples = self._filter_empty_lines(h_tuples, cy_list, crop_h) + v_tuples = self._filter_empty_lines(v_tuples, cx_list, crop_w) + while h_tuples and not any(0.0 <= c <= h_tuples[0][0] for c in cy_list): + h_tuples = h_tuples[1:] + while v_tuples and not any(0.0 <= c <= v_tuples[0][0] for c in cx_list): + v_tuples = v_tuples[1:] + while h_tuples and not any(h_tuples[-1][0] <= c <= crop_h for c in cy_list): + h_tuples = h_tuples[:-1] + while v_tuples and not any(v_tuples[-1][0] <= c <= crop_w for c in cx_list): + v_tuples = v_tuples[:-1] + + filtered_h = [y for y, _, _ in h_tuples] + filtered_v = [x for x, _, _ in v_tuples] + filtered_h_cls = np.array([c for _, _, c in h_tuples], dtype=np.int32) + + if self.snap_to_bbox_gaps: + snapped = self._snap_lines_to_bbox_gaps(filtered_h, bboxes_crop) + h_tuples = [(sy, sc, c) for sy, (_, sc, c) in zip(snapped, h_tuples)] + filtered_h = [y for y, _, _ in h_tuples] + + filtered_h_norm = np.array([y / orig_h for y in filtered_h], dtype=np.float32) + filtered_v_norm = np.array([x / orig_w for x in filtered_v], dtype=np.float32) + + grid = GridPrediction( + h_lines=sorted(filtered_h), + v_lines=sorted(filtered_v), + h_on_prob=grid.h_on_prob, + v_on_prob=grid.v_on_prob, + h_lines_norm=filtered_h_norm, + v_lines_norm=filtered_v_norm, + h_cls=filtered_h_cls, + connectivity=grid.connectivity, + ) + + cells = self._post_process_grid( + bboxes_page=bboxes_arr, + grid=grid, + span_threshold=span_threshold, + ) + + if texts is not None: + for cell in cells: + if 0 <= cell.bbox_idx < len(texts): + cell.text = texts[cell.bbox_idx] + + return grid, cells + + # ------------------------------------------------------------------ + # Grid-based bbox assignment + # ------------------------------------------------------------------ + + def _post_process_grid( + self, + bboxes_page: np.ndarray, + grid: GridPrediction, + span_threshold: float, + ) -> list: + if grid.h_lines: + last_h = sorted(grid.h_lines)[-1] + row_edges = [0.0] + sorted(grid.h_lines) + [max(last_h * 2, last_h + 1)] + else: + row_edges = [0.0, 1.0] + + if grid.v_lines: + last_v = sorted(grid.v_lines)[-1] + col_edges = [0.0] + sorted(grid.v_lines) + [max(last_v * 2, last_v + 1)] + else: + col_edges = [0.0, 1.0] + + def find_cell_idx(pos, edges): + for i in range(len(edges) - 1): + if edges[i] <= pos < edges[i + 1]: + return i + return max(0, len(edges) - 2) + + results = [] + for i, bbox in enumerate(bboxes_page): + x0 = float(bbox[0]) + y0 = float(bbox[1]) + x1 = float(bbox[2]) + y1 = float(bbox[3]) + cx = (x0 + x1) / 2.0 + cy = (y0 + y1) / 2.0 + + base_row = find_cell_idx(cy, row_edges) + base_col = find_cell_idx(cx, col_edges) + row_start = base_row + row_end = base_row + 1 + col_start = base_col + col_end = base_col + 1 + + r = base_row + while r > 0: + h = row_edges[r] - row_edges[r - 1] + if h > 0 and (row_edges[r] - y0) / h > span_threshold: + row_start = r - 1 + r -= 1 + else: + break + r = base_row + while r < len(row_edges) - 2: + h = row_edges[r + 1] - row_edges[r] + if h > 0 and (y1 - row_edges[r + 1]) / h > span_threshold: + row_end = r + 2 + r += 1 + else: + break + c = base_col + while c > 0: + w = col_edges[c] - col_edges[c - 1] + if w > 0 and (col_edges[c] - x0) / w > span_threshold: + col_start = c - 1 + c -= 1 + else: + break + c = base_col + while c < len(col_edges) - 2: + w = col_edges[c + 1] - col_edges[c] + if w > 0 and (x1 - col_edges[c + 1]) / w > span_threshold: + col_end = c + 2 + c += 1 + else: + break + + results.append( + CellInfo( + bbox_idx=i, + row_start=row_start, + row_end=row_end, + col_start=col_start, + col_end=col_end, + row=base_row, + col=base_col, + ) + ) + + return results diff --git a/src/table.py b/src/table.py index c0f8f8981..39ce02f95 100644 --- a/src/table.py +++ b/src/table.py @@ -72,6 +72,7 @@ """ +import os import inspect import itertools import string @@ -80,6 +81,8 @@ from dataclasses import dataclass from operator import itemgetter import weakref +import pathlib + import pymupdf from pymupdf import mupdf @@ -89,6 +92,32 @@ # pylint: disable=no-name-in-module +# Optionally use the TGIF table grid finder. +# This replace fz_find_table_within_bounds. +USE_TGIF = os.getenv("USE_TGIF", "0") +EXTRACTOR_V4 = None # Keep pylint happy. +if USE_TGIF == "0": + if os.environ.get('PYMUPDF_LEGACY_TABLE_DIAGNOSTIC') != '0': + print("Using legacy table grid extraction.") +elif USE_TGIF == "1": + print("Using TGIFVx for table grid extraction.") + import pymupdf.tgif # pylint: disable=import-error +elif USE_TGIF == "4": + print("Using TGEV4 for table grid extraction.") + from pymupdf.TableGridExtractorV4 import TableGridExtractorV4 + + EXTRACTOR_V4 = TableGridExtractorV4( + grid_onnx_path=str(pathlib.Path(__file__).parent / "table_grid_model_v4.onnx"), + # conn_onnx_path=None, + # h_on_threshold=args.h_on_threshold, + # v_on_threshold=args.v_on_threshold, + # conn_threshold=args.conn_threshold, + # nms_min_dist=args.nms_min_dist, + # filter_empty_lines=not args.no_filter_empty, + ) +else: + raise Exception(f"Unrecognised {USE_TGIF=}, should be unset, '0', '1' or '4'.") + EDGES = [] # vector graphics from PyMuPDF CHARS = [] # text characters from PyMuPDF TEXTPAGE = None # textpage for cell text extraction @@ -161,9 +190,63 @@ def get_table_dict_from_rect(textpage, rect): return table_dict -def make_table_from_bbox(textpage, word_rects, rect): +def get_table_cells_from_rect_tgif1(page, word_rects, rect): + cells = [] + bound = mupdf.FzRect(*rect) + + try: + r, xpos, ypos = pymupdf.tgif.fz_visual_table_grid_finder(page, bound) # pylint: disable=no-member + x_count = int(xpos.m_internal.len) + x_values = [xpos.list(i).pos for i in range(x_count)] + y_count = int(ypos.m_internal.len) + y_values = [ypos.list(i).pos for i in range(y_count)] + if xpos.m_internal.max_uncertainty > 0 or ypos.m_internal.max_uncertainty > 0: + print(f"{page.number=}: grid with uncertainty for {bound=}") + except Exception: + return cells + for i in range(y_count - 1): + for j in range(x_count - 1): + cell = (x_values[j], y_values[i], x_values[j + 1], y_values[i + 1]) + cells.append(cell) + return cells + + +def get_table_cells_from_rect_tgif4(page, word_rects, rect): + """Use TableGridExtractorV4 to detect table structure.""" + pix = page.get_pixmap(clip=rect) # make Pixmap from passed-in rect + + # make transformation matrix from pixmap to rect coordinates + pclip = pymupdf.IRect(pix.irect) + matrix = pclip.torect(rect) # in case we want to change resolution + + pred = EXTRACTOR_V4.predict_grid(pix) # call GRID extractor + + h_lines = [pclip.y0, pclip.y1] # include top and bottom of the rect + # add predicted h lines + h_lines.extend(y + pclip.y0 for y in pred.h_lines) + h_lines = sorted(h_lines) + + v_lines = [pclip.x0, pclip.x1] # include left and right of the rect + # add predicted v lines + v_lines.extend(x + pclip.x0 for x in pred.v_lines) + v_lines = sorted(v_lines) + + # we now have the horizontal and vertical lines and make the cells + cells = [] + for i in range(len(h_lines) - 1): + for j in range(len(v_lines) - 1): + cell = pymupdf.Rect(v_lines[j], h_lines[i], v_lines[j + 1], h_lines[i + 1]) + cells.append(cell * matrix) + return cells + + +def make_table_from_bbox(page, textpage, word_rects, rect): """Detect table structure within a given rectangle.""" cells = [] # table cells as (x0,y0,x1,y1) tuples + if USE_TGIF == "1": + return get_table_cells_from_rect_tgif1(page, word_rects, rect) + elif USE_TGIF == "4": + return get_table_cells_from_rect_tgif4(page, word_rects, rect) # calls fz_find_table_within_bounds block = get_table_dict_from_rect(textpage, rect) @@ -2714,9 +2797,11 @@ def find_tables( if my_boxes: word_rects = [pymupdf.Rect(w[:4]) for w in TEXTPAGE.extractWORDS()] tp2 = page.get_textpage(flags=TABLE_DETECTOR_FLAGS) - for rect in my_boxes: - cells = make_table_from_bbox(tp2, word_rects, rect) # pylint: disable=E0606 - tbf.tables.append(Table(page, cells)) + for rect in my_boxes: + cells = make_table_from_bbox( + page, tp2, word_rects, rect + ) # pylint: disable=E0606 + tbf.tables.append(Table(page, cells)) except Exception as e: pymupdf.message("find_tables: exception occurred: %s" % str(e)) return None diff --git a/src/table_grid_model_v4.onnx b/src/table_grid_model_v4.onnx new file mode 100644 index 000000000..a40fa65ad Binary files /dev/null and b/src/table_grid_model_v4.onnx differ diff --git a/tests/test_4767.py b/tests/test_4767.py index 212becf27..2dc6967fb 100644 --- a/tests/test_4767.py +++ b/tests/test_4767.py @@ -15,6 +15,8 @@ def test_4767(): print('test_4767(): not running on Pyodide - cannot run child processes.') return + os.environ['PYMUPDF_LEGACY_TABLE_DIAGNOSTIC'] = '0' + if (1 and platform.system() == 'Windows' and os.environ.get('GITHUB_ACTIONS') == 'true' diff --git a/tests/test_general.py b/tests/test_general.py index 9a531faee..23d6bdfe5 100644 --- a/tests/test_general.py +++ b/tests/test_general.py @@ -1079,6 +1079,8 @@ def test_cli_out(): print('test_cli_out(): not running on Pyodide - cannot run child processes.') return + os.environ['PYMUPDF_LEGACY_TABLE_DIAGNOSTIC'] = '0' + import platform import re import subprocess @@ -1174,6 +1176,7 @@ def test_use_python_logging(): print('test_cli(): not running on Pyodide - cannot run child processes.') return + os.environ['PYMUPDF_LEGACY_TABLE_DIAGNOSTIC'] = '0' log_prefix = None if os.environ.get('PYMUPDF_USE_EXTRA') == '0': log_prefix = f'.+Using non-default setting from PYMUPDF_USE_EXTRA: \'0\'' diff --git a/tests/test_pylint.py b/tests/test_pylint.py index e460bbe4d..8c58f5880 100644 --- a/tests/test_pylint.py +++ b/tests/test_pylint.py @@ -115,6 +115,7 @@ def test_pylint(): directory = f'{root}/src' directory = directory.replace('/', os.sep) leafs = [ + 'TableGridExtractorV4.py', '__init__.py', '__main__.py', '_apply_pages.py',