From d4f0f2f9bf47fa5ee9b93d76d0e6b010d7a31e7a Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Tue, 10 Feb 2026 08:33:26 -0800 Subject: [PATCH 1/5] added _utils to examples --- examples/_utils.py | 49 +++++++++++++++++++++++++++++++++++++++++ examples/capetown.py | 44 ++++-------------------------------- examples/guanajuato.py | 46 ++++---------------------------------- examples/los_angeles.py | 46 ++++---------------------------------- examples/playground.py | 36 +++--------------------------- examples/rio.py | 44 ++++-------------------------------- examples/trinidad.py | 46 ++++---------------------------------- 7 files changed, 72 insertions(+), 239 deletions(-) create mode 100644 examples/_utils.py diff --git a/examples/_utils.py b/examples/_utils.py new file mode 100644 index 0000000..191c415 --- /dev/null +++ b/examples/_utils.py @@ -0,0 +1,49 @@ +_MAJOR_WATER = {'river', 'canal'} +_MINOR_WATER = {'stream', 'drain', 'ditch'} + + +def print_controls(): + print("\nControls:") + print(" W/S/A/D or Arrow keys: Move camera") + print(" Q/E or Page Up/Down: Move up/down") + print(" I/J/K/L: Look around") + print(" +/-: Adjust movement speed") + print(" G: Cycle overlay layers") + print(" O: Place observer (for viewshed)") + print(" V: Toggle viewshed (teal glow)") + print(" [/]: Adjust observer height") + print(" T: Toggle shadows") + print(" C: Cycle colormap") + print(" U: Toggle tile overlay") + print(" F: Screenshot") + print(" H: Toggle help overlay") + print(" X: Exit\n") + + +def classify_water_features(water_data): + """Split water GeoJSON features into (major, minor, body) lists.""" + major = [] + minor = [] + body = [] + for f in water_data.get('features', []): + ww = (f.get('properties') or {}).get('waterway', '') + nat = (f.get('properties') or {}).get('natural', '') + if ww in _MAJOR_WATER: + major.append(f) + elif ww in _MINOR_WATER: + minor.append(f) + elif nat == 'water': + body.append(f) + else: + minor.append(f) + return major, minor, body + + +def scale_building_heights(bldg_data, elev_scale=0.025, default_height_m=8.0): + """Scale MS building heights in-place to match terrain elevation scale.""" + for feat in bldg_data.get("features", []): + props = feat.get("properties", {}) + h = props.get("height", -1) + if not isinstance(h, (int, float)) or h <= 0: + h = default_height_m + props["height"] = h * elev_scale diff --git a/examples/capetown.py b/examples/capetown.py index f226eb3..9d01e6e 100644 --- a/examples/capetown.py +++ b/examples/capetown.py @@ -21,10 +21,7 @@ from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water, fetch_firms import rtxpy - -# Water feature classification -_MAJOR_WATER = {'river', 'canal'} -_MINOR_WATER = {'stream', 'drain', 'ditch'} +from _utils import print_controls, classify_water_features, scale_building_heights # Cape Town bounding box (lon_min, lat_min, lon_max, lat_max) BOUNDS = (18.3, -34.2, 18.7, -33.8) @@ -64,21 +61,7 @@ def load_terrain(): # Load terrain data (downloads if needed) terrain = load_terrain() - print("\nControls:") - print(" W/S/A/D or Arrow keys: Move camera") - print(" Q/E or Page Up/Down: Move up/down") - print(" I/J/K/L: Look around") - print(" +/-: Adjust movement speed") - print(" G: Cycle overlay layers") - print(" O: Place observer (for viewshed)") - print(" V: Toggle viewshed (teal glow)") - print(" [/]: Adjust observer height") - print(" T: Toggle shadows") - print(" C: Cycle colormap") - print(" U: Toggle tile overlay") - print(" F: Screenshot") - print(" H: Toggle help overlay") - print(" X: Exit\n") + print_controls() # Build Dataset with derived layers print("Building Dataset with terrain analysis layers...") @@ -102,15 +85,9 @@ def load_terrain(): cache_path=bldg_cache, ) - # Scale building heights to match the 0.025× terrain elevation. elev_scale = 0.025 default_height_m = 8.0 - for feat in bldg_data.get("features", []): - props = feat.get("properties", {}) - h = props.get("height", -1) - if not isinstance(h, (int, float)) or h <= 0: - h = default_height_m - props["height"] = h * elev_scale + scale_building_heights(bldg_data, elev_scale, default_height_m) mesh_cache_path = Path(__file__).parent / "capetown_buildings_mesh.npz" with warnings.catch_warnings(): @@ -184,20 +161,7 @@ def load_terrain(): cache_path=water_cache, ) - major_features = [] - minor_features = [] - body_features = [] - for f in water_data.get('features', []): - ww = (f.get('properties') or {}).get('waterway', '') - nat = (f.get('properties') or {}).get('natural', '') - if ww in _MAJOR_WATER: - major_features.append(f) - elif ww in _MINOR_WATER: - minor_features.append(f) - elif nat == 'water': - body_features.append(f) - else: - minor_features.append(f) + major_features, minor_features, body_features = classify_water_features(water_data) if major_features: major_fc = {"type": "FeatureCollection", "features": major_features} diff --git a/examples/guanajuato.py b/examples/guanajuato.py index 5ace5c9..32c5241 100644 --- a/examples/guanajuato.py +++ b/examples/guanajuato.py @@ -26,10 +26,7 @@ # Import rtxpy to register the .rtx accessor from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water, fetch_firms import rtxpy - -# Water feature classification -_MAJOR_WATER = {'river', 'canal'} -_MINOR_WATER = {'stream', 'drain', 'ditch'} +from _utils import print_controls, classify_water_features, scale_building_heights def load_terrain(): @@ -69,21 +66,7 @@ def load_terrain(): # Load terrain data (downloads if needed) terrain = load_terrain() - print("\nControls:") - print(" W/S/A/D or Arrow keys: Move camera") - print(" Q/E or Page Up/Down: Move up/down") - print(" I/J/K/L: Look around") - print(" +/-: Adjust movement speed") - print(" G: Cycle overlay layers") - print(" O: Place observer (for viewshed)") - print(" V: Toggle viewshed (teal glow)") - print(" [/]: Adjust observer height") - print(" T: Toggle shadows") - print(" C: Cycle colormap") - print(" U: Toggle tile overlay") - print(" F: Screenshot") - print(" H: Toggle help overlay") - print(" X: Exit\n") + print_controls() # Build Dataset with derived layers print("Building Dataset with terrain analysis layers...") @@ -107,17 +90,9 @@ def load_terrain(): cache_path=bldg_cache, ) - # Scale building heights to match the 0.025× terrain elevation. - # MS data has height in metres (-1 = unknown); replace unknowns - # with a reasonable default and apply the same scale factor. elev_scale = 0.025 default_height_m = 8.0 - for feat in bldg_data.get("features", []): - props = feat.get("properties", {}) - h = props.get("height", -1) - if not isinstance(h, (int, float)) or h <= 0: - h = default_height_m - props["height"] = h * elev_scale + scale_building_heights(bldg_data, elev_scale, default_height_m) mesh_cache_path = Path(__file__).parent / "guanajuato_buildings_mesh.npz" with warnings.catch_warnings(): @@ -191,20 +166,7 @@ def load_terrain(): cache_path=water_cache, ) - major_features = [] - minor_features = [] - body_features = [] - for f in water_data.get('features', []): - ww = (f.get('properties') or {}).get('waterway', '') - nat = (f.get('properties') or {}).get('natural', '') - if ww in _MAJOR_WATER: - major_features.append(f) - elif ww in _MINOR_WATER: - minor_features.append(f) - elif nat == 'water': - body_features.append(f) - else: - minor_features.append(f) + major_features, minor_features, body_features = classify_water_features(water_data) if major_features: major_fc = {"type": "FeatureCollection", "features": major_features} diff --git a/examples/los_angeles.py b/examples/los_angeles.py index 179ad0e..97410e8 100644 --- a/examples/los_angeles.py +++ b/examples/los_angeles.py @@ -24,10 +24,7 @@ from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water, fetch_firms import rtxpy - -# Water feature classification -_MAJOR_WATER = {'river', 'canal'} -_MINOR_WATER = {'stream', 'drain', 'ditch'} +from _utils import print_controls, classify_water_features, scale_building_heights # Los Angeles bounding box (WGS84) # Focused area covering DTLA, Echo Park, Silver Lake, Griffith Park, @@ -71,21 +68,7 @@ def load_terrain(): # Load terrain data (downloads if needed) terrain = load_terrain() - print("\nControls:") - print(" W/S/A/D or Arrow keys: Move camera") - print(" Q/E or Page Up/Down: Move up/down") - print(" I/J/K/L: Look around") - print(" +/-: Adjust movement speed") - print(" G: Cycle overlay layers") - print(" O: Place observer (for viewshed)") - print(" V: Toggle viewshed (teal glow)") - print(" [/]: Adjust observer height") - print(" T: Toggle shadows") - print(" C: Cycle colormap") - print(" U: Toggle tile overlay") - print(" F: Screenshot") - print(" H: Toggle help overlay") - print(" X: Exit\n") + print_controls() # Build Dataset with derived layers print("Building Dataset with terrain analysis layers...") @@ -109,17 +92,9 @@ def load_terrain(): cache_path=bldg_cache, ) - # Scale building heights to match the 0.5× terrain elevation. - # MS data has height in metres (-1 = unknown); replace unknowns - # with a reasonable default and apply the same scale factor. elev_scale = 0.5 default_height_m = 8.0 - for feat in bldg_data.get("features", []): - props = feat.get("properties", {}) - h = props.get("height", -1) - if not isinstance(h, (int, float)) or h <= 0: - h = default_height_m - props["height"] = h * elev_scale + scale_building_heights(bldg_data, elev_scale, default_height_m) mesh_cache_path = Path(__file__).parent / "los_angeles_buildings_mesh.npz" with warnings.catch_warnings(): @@ -193,20 +168,7 @@ def load_terrain(): cache_path=water_cache, ) - major_features = [] - minor_features = [] - body_features = [] - for f in water_data.get('features', []): - ww = (f.get('properties') or {}).get('waterway', '') - nat = (f.get('properties') or {}).get('natural', '') - if ww in _MAJOR_WATER: - major_features.append(f) - elif ww in _MINOR_WATER: - minor_features.append(f) - elif nat == 'water': - body_features.append(f) - else: - minor_features.append(f) + major_features, minor_features, body_features = classify_water_features(water_data) if major_features: major_fc = {"type": "FeatureCollection", "features": major_features} diff --git a/examples/playground.py b/examples/playground.py index 8708bc3..7d452e6 100644 --- a/examples/playground.py +++ b/examples/playground.py @@ -26,10 +26,7 @@ # Import rtxpy to register the .rtx accessor from rtxpy import fetch_dem, fetch_roads, fetch_water import rtxpy - -# Water feature classification -_MAJOR_WATER = {'river', 'canal'} -_MINOR_WATER = {'stream', 'drain', 'ditch'} +from _utils import print_controls, classify_water_features def load_terrain(): @@ -69,21 +66,7 @@ def load_terrain(): # Load terrain data (downloads if needed) terrain = load_terrain() - print("\nControls:") - print(" W/S/A/D or Arrow keys: Move camera") - print(" Q/E or Page Up/Down: Move up/down") - print(" I/J/K/L: Look around") - print(" +/-: Adjust movement speed") - print(" G: Cycle overlay layers") - print(" O: Place observer (for viewshed)") - print(" V: Toggle viewshed (teal glow)") - print(" [/]: Adjust observer height") - print(" T: Toggle shadows") - print(" C: Cycle colormap") - print(" U: Toggle tile overlay") - print(" F: Screenshot") - print(" H: Toggle help overlay") - print(" X: Exit\n") + print_controls() # Build Dataset with derived layers print("Building Dataset with terrain analysis layers...") @@ -236,20 +219,7 @@ def load_terrain(): cache_path=water_cache, ) - major_features = [] - minor_features = [] - body_features = [] - for f in water_data.get('features', []): - ww = (f.get('properties') or {}).get('waterway', '') - nat = (f.get('properties') or {}).get('natural', '') - if ww in _MAJOR_WATER: - major_features.append(f) - elif ww in _MINOR_WATER: - minor_features.append(f) - elif nat == 'water': - body_features.append(f) - else: - minor_features.append(f) + major_features, minor_features, body_features = classify_water_features(water_data) if major_features: major_fc = {"type": "FeatureCollection", "features": major_features} diff --git a/examples/rio.py b/examples/rio.py index 630bdd6..de52d0e 100644 --- a/examples/rio.py +++ b/examples/rio.py @@ -21,10 +21,7 @@ from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water import rtxpy - -# Water feature classification -_MAJOR_WATER = {'river', 'canal'} -_MINOR_WATER = {'stream', 'drain', 'ditch'} +from _utils import print_controls, classify_water_features, scale_building_heights # Rio de Janeiro bounding box (WGS84) # Covers the city from Barra da Tijuca in the west to Ilha do Governador @@ -67,21 +64,7 @@ def load_terrain(): # Load terrain data (downloads if needed) terrain = load_terrain() - print("\nControls:") - print(" W/S/A/D or Arrow keys: Move camera") - print(" Q/E or Page Up/Down: Move up/down") - print(" I/J/K/L: Look around") - print(" +/-: Adjust movement speed") - print(" G: Cycle overlay layers") - print(" O: Place observer (for viewshed)") - print(" V: Toggle viewshed (teal glow)") - print(" [/]: Adjust observer height") - print(" T: Toggle shadows") - print(" C: Cycle colormap") - print(" U: Toggle tile overlay") - print(" F: Screenshot") - print(" H: Toggle help overlay") - print(" X: Exit\n") + print_controls() # Build Dataset with derived layers print("Building Dataset with terrain analysis layers...") @@ -105,15 +88,9 @@ def load_terrain(): cache_path=bldg_cache, ) - # Scale building heights to match the 0.025× terrain elevation. elev_scale = 0.025 default_height_m = 8.0 - for feat in bldg_data.get("features", []): - props = feat.get("properties", {}) - h = props.get("height", -1) - if not isinstance(h, (int, float)) or h <= 0: - h = default_height_m - props["height"] = h * elev_scale + scale_building_heights(bldg_data, elev_scale, default_height_m) mesh_cache_path = Path(__file__).parent / "rio_buildings_mesh.npz" with warnings.catch_warnings(): @@ -187,20 +164,7 @@ def load_terrain(): cache_path=water_cache, ) - major_features = [] - minor_features = [] - body_features = [] - for f in water_data.get('features', []): - ww = (f.get('properties') or {}).get('waterway', '') - nat = (f.get('properties') or {}).get('natural', '') - if ww in _MAJOR_WATER: - major_features.append(f) - elif ww in _MINOR_WATER: - minor_features.append(f) - elif nat == 'water': - body_features.append(f) - else: - minor_features.append(f) + major_features, minor_features, body_features = classify_water_features(water_data) if major_features: major_fc = {"type": "FeatureCollection", "features": major_features} diff --git a/examples/trinidad.py b/examples/trinidad.py index 8209573..839ada1 100644 --- a/examples/trinidad.py +++ b/examples/trinidad.py @@ -22,10 +22,7 @@ from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water, fetch_firms import rtxpy - -# Water feature classification -_MAJOR_WATER = {'river', 'canal'} -_MINOR_WATER = {'stream', 'drain', 'ditch'} +from _utils import print_controls, classify_water_features, scale_building_heights def load_terrain(): @@ -62,21 +59,7 @@ def load_terrain(): # Load terrain data (downloads if needed) terrain = load_terrain() - print("\nControls:") - print(" W/S/A/D or Arrow keys: Move camera") - print(" Q/E or Page Up/Down: Move up/down") - print(" I/J/K/L: Look around") - print(" +/-: Adjust movement speed") - print(" G: Cycle overlay layers") - print(" O: Place observer (for viewshed)") - print(" V: Toggle viewshed (teal glow)") - print(" [/]: Adjust observer height") - print(" T: Toggle shadows") - print(" C: Cycle colormap") - print(" U: Toggle tile overlay") - print(" F: Screenshot") - print(" H: Toggle help overlay") - print(" X: Exit\n") + print_controls() # Build Dataset with derived layers print("Building Dataset with terrain analysis layers...") @@ -100,17 +83,9 @@ def load_terrain(): cache_path=bldg_cache, ) - # Scale building heights to match the 0.025× terrain elevation. - # MS data has height in metres (-1 = unknown); replace unknowns - # with a reasonable default and apply the same scale factor. elev_scale = 0.025 default_height_m = 8.0 - for feat in bldg_data.get("features", []): - props = feat.get("properties", {}) - h = props.get("height", -1) - if not isinstance(h, (int, float)) or h <= 0: - h = default_height_m - props["height"] = h * elev_scale + scale_building_heights(bldg_data, elev_scale, default_height_m) mesh_cache_path = Path(__file__).parent / "trinidad_buildings_mesh.npz" with warnings.catch_warnings(): @@ -184,20 +159,7 @@ def load_terrain(): cache_path=water_cache, ) - major_features = [] - minor_features = [] - body_features = [] - for f in water_data.get('features', []): - ww = (f.get('properties') or {}).get('waterway', '') - nat = (f.get('properties') or {}).get('natural', '') - if ww in _MAJOR_WATER: - major_features.append(f) - elif ww in _MINOR_WATER: - minor_features.append(f) - elif nat == 'water': - body_features.append(f) - else: - minor_features.append(f) + major_features, minor_features, body_features = classify_water_features(water_data) if major_features: major_fc = {"type": "FeatureCollection", "features": major_features} From 49e70ce4316158a69adf2acd45ad0e51ed538c20 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Tue, 10 Feb 2026 10:56:18 -0800 Subject: [PATCH 2/5] refactoring a bit --- examples/_utils.py | 49 ----------- examples/capetown.py | 175 +++++++----------------------------- examples/guanajuato.py | 184 ++++++++------------------------------ examples/los_angeles.py | 173 +++++++----------------------------- examples/playground.py | 108 +++++------------------ examples/rio.py | 154 ++++++-------------------------- examples/trinidad.py | 185 ++++++++------------------------------ rtxpy/accessor.py | 190 +++++++++++++++++++++++++++++++++++++++- rtxpy/engine.py | 162 +++++++++++++++++++++++++++++++++- 9 files changed, 536 insertions(+), 844 deletions(-) delete mode 100644 examples/_utils.py diff --git a/examples/_utils.py b/examples/_utils.py deleted file mode 100644 index 191c415..0000000 --- a/examples/_utils.py +++ /dev/null @@ -1,49 +0,0 @@ -_MAJOR_WATER = {'river', 'canal'} -_MINOR_WATER = {'stream', 'drain', 'ditch'} - - -def print_controls(): - print("\nControls:") - print(" W/S/A/D or Arrow keys: Move camera") - print(" Q/E or Page Up/Down: Move up/down") - print(" I/J/K/L: Look around") - print(" +/-: Adjust movement speed") - print(" G: Cycle overlay layers") - print(" O: Place observer (for viewshed)") - print(" V: Toggle viewshed (teal glow)") - print(" [/]: Adjust observer height") - print(" T: Toggle shadows") - print(" C: Cycle colormap") - print(" U: Toggle tile overlay") - print(" F: Screenshot") - print(" H: Toggle help overlay") - print(" X: Exit\n") - - -def classify_water_features(water_data): - """Split water GeoJSON features into (major, minor, body) lists.""" - major = [] - minor = [] - body = [] - for f in water_data.get('features', []): - ww = (f.get('properties') or {}).get('waterway', '') - nat = (f.get('properties') or {}).get('natural', '') - if ww in _MAJOR_WATER: - major.append(f) - elif ww in _MINOR_WATER: - minor.append(f) - elif nat == 'water': - body.append(f) - else: - minor.append(f) - return major, minor, body - - -def scale_building_heights(bldg_data, elev_scale=0.025, default_height_m=8.0): - """Scale MS building heights in-place to match terrain elevation scale.""" - for feat in bldg_data.get("features", []): - props = feat.get("properties", {}) - h = props.get("height", -1) - if not isinstance(h, (int, float)) or h <= 0: - h = default_height_m - props["height"] = h * elev_scale diff --git a/examples/capetown.py b/examples/capetown.py index 9d01e6e..df1afcf 100644 --- a/examples/capetown.py +++ b/examples/capetown.py @@ -11,29 +11,27 @@ pip install rtxpy[all] matplotlib xarray rioxarray requests pyproj Pillow """ +import warnings + import numpy as np import xarray as xr from xrspatial import slope, aspect, quantile from pathlib import Path -import warnings - from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water, fetch_firms import rtxpy -from _utils import print_controls, classify_water_features, scale_building_heights # Cape Town bounding box (lon_min, lat_min, lon_max, lat_max) BOUNDS = (18.3, -34.2, 18.7, -33.8) +CACHE = Path(__file__).parent def load_terrain(): """Load Cape Town terrain data, downloading if necessary.""" - dem_path = Path(__file__).parent / "capetown_dem.tif" - terrain = fetch_dem( bounds=BOUNDS, - output_path=dem_path, + output_path=CACHE / "capetown_dem.tif", source='copernicus', crs='EPSG:32734', # UTM zone 34S ) @@ -58,11 +56,8 @@ def load_terrain(): if __name__ == "__main__": - # Load terrain data (downloads if needed) terrain = load_terrain() - print_controls() - # Build Dataset with derived layers print("Building Dataset with terrain analysis layers...") ds = xr.Dataset({ @@ -79,154 +74,48 @@ def load_terrain(): # --- Microsoft Global Building Footprints -------------------------------- try: - bldg_cache = Path(__file__).parent / "capetown_buildings.geojson" - bldg_data = fetch_buildings( - bounds=BOUNDS, - cache_path=bldg_cache, - ) - - elev_scale = 0.025 - default_height_m = 8.0 - scale_building_heights(bldg_data, elev_scale, default_height_m) - - mesh_cache_path = Path(__file__).parent / "capetown_buildings_mesh.npz" - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - bldg_info = ds.rtx.place_geojson( - bldg_data, - z='elevation', - height=default_height_m * elev_scale, - height_field='height', - geometry_id='building', - densify=False, - merge=True, - extrude=True, - mesh_cache=mesh_cache_path, - ) - print(f"Placed {bldg_info['geometries']} building footprint geometries") - except ImportError as e: + bldg_data = fetch_buildings(bounds=BOUNDS, cache_path=CACHE / "capetown_buildings.geojson") + info = ds.rtx.place_buildings(bldg_data, z='elevation', elev_scale=0.025, + mesh_cache=CACHE / "capetown_buildings_mesh.npz") + print(f"Placed {info['geometries']} building geometries") + except Exception as e: print(f"Skipping buildings: {e}") # --- OpenStreetMap roads ------------------------------------------------ try: - # Major roads: motorways, trunk, primary, secondary - major_cache = Path(__file__).parent / "capetown_roads_major.geojson" - major_roads = fetch_roads( - bounds=BOUNDS, - road_type='major', - cache_path=major_cache, - ) - if major_roads.get('features'): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - info = ds.rtx.place_geojson( - major_roads, z='elevation', height=1, - label_field='name', geometry_id='road_major', - color=(0.10, 0.10, 0.10), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "capetown_roads_major_mesh.npz", - ) - print(f"Placed {info['geometries']} major road geometries") - - # Minor roads: tertiary, residential, service - minor_cache = Path(__file__).parent / "capetown_roads_minor.geojson" - minor_roads = fetch_roads( - bounds=BOUNDS, - road_type='minor', - cache_path=minor_cache, - ) - if minor_roads.get('features'): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - info = ds.rtx.place_geojson( - minor_roads, z='elevation', height=1, - label_field='name', geometry_id='road_minor', - color=(0.55, 0.55, 0.55), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "capetown_roads_minor_mesh.npz", - ) - print(f"Placed {info['geometries']} minor road geometries") - - except ImportError as e: + for rt, gid, clr in [('major', 'road_major', (0.10, 0.10, 0.10)), + ('minor', 'road_minor', (0.55, 0.55, 0.55))]: + data = fetch_roads(bounds=BOUNDS, road_type=rt, + cache_path=CACHE / f"capetown_roads_{rt}.geojson") + info = ds.rtx.place_roads(data, z='elevation', geometry_id=gid, color=clr, + mesh_cache=CACHE / f"capetown_roads_{rt}_mesh.npz") + print(f"Placed {info['geometries']} {rt} road geometries") + except Exception as e: print(f"Skipping roads: {e}") # --- OpenStreetMap water features --------------------------------------- try: - water_cache = Path(__file__).parent / "capetown_water.geojson" - water_data = fetch_water( - bounds=BOUNDS, - water_type='all', - cache_path=water_cache, - ) - - major_features, minor_features, body_features = classify_water_features(water_data) - - if major_features: - major_fc = {"type": "FeatureCollection", "features": major_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - major_info = ds.rtx.place_geojson( - major_fc, z='elevation', height=0, - label_field='name', geometry_id='water_major', - color=(0.40, 0.70, 0.95, 2.25), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "capetown_water_major_mesh.npz", - ) - print(f"Placed {major_info['geometries']} major water features (rivers, canals)") - - if minor_features: - minor_fc = {"type": "FeatureCollection", "features": minor_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - minor_info = ds.rtx.place_geojson( - minor_fc, z='elevation', height=0, - label_field='name', geometry_id='water_minor', - color=(0.50, 0.75, 0.98, 2.25), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "capetown_water_minor_mesh.npz", - ) - print(f"Placed {minor_info['geometries']} minor water features (streams, drains)") - - if body_features: - body_fc = {"type": "FeatureCollection", "features": body_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - body_info = ds.rtx.place_geojson( - body_fc, z='elevation', height=0.5, - label_field='name', geometry_id='water_body', - color=(0.35, 0.55, 0.88, 2.25), - extrude=True, - merge=True, - mesh_cache=Path(__file__).parent / "capetown_water_body_mesh.npz", - ) - print(f"Placed {body_info['geometries']} water bodies (lakes, ponds)") - - except ImportError as e: - print(f"Skipping water features: {e}") + water_data = fetch_water(bounds=BOUNDS, water_type='all', + cache_path=CACHE / "capetown_water.geojson") + results = ds.rtx.place_water(water_data, z='elevation', + mesh_cache_prefix=CACHE / "capetown_water") + for cat, info in results.items(): + print(f"Placed {info['geometries']} {cat} water features") + except Exception as e: + print(f"Skipping water: {e}") - # --- NASA FIRMS fire detections (LANDSAT 30 m, last 7 days) ----------- + # --- NASA FIRMS fire detections (last 7 days) --------------------------- try: - fire_cache = Path(__file__).parent / "capetown_fires.geojson" - fire_data = fetch_firms( - bounds=BOUNDS, - date_span='7d', - cache_path=fire_cache, - crs='EPSG:32734', - ) + fire_data = fetch_firms(bounds=BOUNDS, date_span='7d', + cache_path=CACHE / "capetown_fires.geojson", + crs='EPSG:32734') if fire_data.get('features'): - elev_scale = 0.025 with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="place_geojson called before") fire_info = ds.rtx.place_geojson( - fire_data, z='elevation', height=20 * elev_scale, - geometry_id='fire', - color=(1.0, 0.25, 0.0, 3.0), - extrude=True, - merge=True, + fire_data, z='elevation', height=20 * 0.025, + geometry_id='fire', color=(1.0, 0.25, 0.0, 3.0), + extrude=True, merge=True, ) print(f"Placed {fire_info['geometries']} fire detection footprints") else: diff --git a/examples/guanajuato.py b/examples/guanajuato.py index 32c5241..be06db2 100644 --- a/examples/guanajuato.py +++ b/examples/guanajuato.py @@ -1,6 +1,6 @@ -"""Interactive playground for the Guanajuato–San Miguel de Allende highlands. +"""Interactive playground for the Guanajuato-San Miguel de Allende highlands. -Explore the terrain of central Mexico's Bajío region using GPU-accelerated +Explore the terrain of central Mexico's Bajio region using GPU-accelerated ray tracing. The area covers the Sierra de Santa Rosa northwest of Guanajuato city, the colonial town of San Miguel de Allende to the east, and the rugged canyon country in between. @@ -26,18 +26,19 @@ # Import rtxpy to register the .rtx accessor from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water, fetch_firms import rtxpy -from _utils import print_controls, classify_water_features, scale_building_heights + +BOUNDS = (-101.50, 20.70, -100.50, 21.30) +CRS = 'EPSG:32614' +CACHE = Path(__file__).parent def load_terrain(): """Load Guanajuato terrain data, downloading if necessary.""" - dem_path = Path(__file__).parent / "guanajuato_dem.tif" - terrain = fetch_dem( - bounds=(-101.50, 20.70, -100.50, 21.30), - output_path=dem_path, + bounds=BOUNDS, + output_path=CACHE / "guanajuato_dem.tif", source='copernicus', - crs='EPSG:32614', + crs=CRS, ) # Mask nodata / water pixels @@ -63,11 +64,8 @@ def load_terrain(): if __name__ == "__main__": - # Load terrain data (downloads if needed) terrain = load_terrain() - print_controls() - # Build Dataset with derived layers print("Building Dataset with terrain analysis layers...") ds = xr.Dataset({ @@ -84,154 +82,48 @@ def load_terrain(): # --- Microsoft Global Building Footprints -------------------------------- try: - bldg_cache = Path(__file__).parent / "guanajuato_buildings.geojson" - bldg_data = fetch_buildings( - bounds=(-101.50, 20.70, -100.50, 21.30), - cache_path=bldg_cache, - ) - - elev_scale = 0.025 - default_height_m = 8.0 - scale_building_heights(bldg_data, elev_scale, default_height_m) - - mesh_cache_path = Path(__file__).parent / "guanajuato_buildings_mesh.npz" - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - bldg_info = ds.rtx.place_geojson( - bldg_data, - z='elevation', - height=default_height_m * elev_scale, - height_field='height', - geometry_id='building', - densify=False, - merge=True, - extrude=True, - mesh_cache=mesh_cache_path, - ) - print(f"Placed {bldg_info['geometries']} building footprint geometries") - except ImportError as e: + bldg_data = fetch_buildings(bounds=BOUNDS, cache_path=CACHE / "guanajuato_buildings.geojson") + info = ds.rtx.place_buildings(bldg_data, z='elevation', elev_scale=0.025, + mesh_cache=CACHE / "guanajuato_buildings_mesh.npz") + print(f"Placed {info['geometries']} building geometries") + except Exception as e: print(f"Skipping buildings: {e}") # --- OpenStreetMap roads ------------------------------------------------ try: - # Major roads: motorways, trunk, primary, secondary - major_cache = Path(__file__).parent / "guanajuato_roads_major.geojson" - major_roads = fetch_roads( - bounds=(-101.50, 20.70, -100.50, 21.30), - road_type='major', - cache_path=major_cache, - ) - if major_roads.get('features'): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - info = ds.rtx.place_geojson( - major_roads, z='elevation', height=1, - label_field='name', geometry_id='road_major', - color=(0.10, 0.10, 0.10), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "guanajuato_roads_major_mesh.npz", - ) - print(f"Placed {info['geometries']} major road geometries") - - # Minor roads: tertiary, residential, service - minor_cache = Path(__file__).parent / "guanajuato_roads_minor.geojson" - minor_roads = fetch_roads( - bounds=(-101.50, 20.70, -100.50, 21.30), - road_type='minor', - cache_path=minor_cache, - ) - if minor_roads.get('features'): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - info = ds.rtx.place_geojson( - minor_roads, z='elevation', height=1, - label_field='name', geometry_id='road_minor', - color=(0.55, 0.55, 0.55), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "guanajuato_roads_minor_mesh.npz", - ) - print(f"Placed {info['geometries']} minor road geometries") - - except ImportError as e: + for rt, gid, clr in [('major', 'road_major', (0.10, 0.10, 0.10)), + ('minor', 'road_minor', (0.55, 0.55, 0.55))]: + data = fetch_roads(bounds=BOUNDS, road_type=rt, + cache_path=CACHE / f"guanajuato_roads_{rt}.geojson") + info = ds.rtx.place_roads(data, z='elevation', geometry_id=gid, color=clr, + mesh_cache=CACHE / f"guanajuato_roads_{rt}_mesh.npz") + print(f"Placed {info['geometries']} {rt} road geometries") + except Exception as e: print(f"Skipping roads: {e}") # --- OpenStreetMap water features --------------------------------------- try: - water_cache = Path(__file__).parent / "guanajuato_water.geojson" - water_data = fetch_water( - bounds=(-101.50, 20.70, -100.50, 21.30), - water_type='all', - cache_path=water_cache, - ) - - major_features, minor_features, body_features = classify_water_features(water_data) - - if major_features: - major_fc = {"type": "FeatureCollection", "features": major_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - major_info = ds.rtx.place_geojson( - major_fc, z='elevation', height=0, - label_field='name', geometry_id='water_major', - color=(0.40, 0.70, 0.95, 2.25), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "guanajuato_water_major_mesh.npz", - ) - print(f"Placed {major_info['geometries']} major water features (rivers, canals)") - - if minor_features: - minor_fc = {"type": "FeatureCollection", "features": minor_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - minor_info = ds.rtx.place_geojson( - minor_fc, z='elevation', height=0, - label_field='name', geometry_id='water_minor', - color=(0.50, 0.75, 0.98, 2.25), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "guanajuato_water_minor_mesh.npz", - ) - print(f"Placed {minor_info['geometries']} minor water features (streams, drains)") - - if body_features: - body_fc = {"type": "FeatureCollection", "features": body_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - body_info = ds.rtx.place_geojson( - body_fc, z='elevation', height=0.5, - label_field='name', geometry_id='water_body', - color=(0.35, 0.55, 0.88, 2.25), - extrude=True, - merge=True, - mesh_cache=Path(__file__).parent / "guanajuato_water_body_mesh.npz", - ) - print(f"Placed {body_info['geometries']} water bodies (lakes, ponds)") - - except ImportError as e: - print(f"Skipping water features: {e}") + water_data = fetch_water(bounds=BOUNDS, water_type='all', + cache_path=CACHE / "guanajuato_water.geojson") + results = ds.rtx.place_water(water_data, z='elevation', + mesh_cache_prefix=CACHE / "guanajuato_water") + for cat, info in results.items(): + print(f"Placed {info['geometries']} {cat} water features") + except Exception as e: + print(f"Skipping water: {e}") - # --- NASA FIRMS fire detections (LANDSAT 30 m, last 7 days) ----------- + # --- NASA FIRMS fire detections (last 7 days) --------------------------- try: - fire_cache = Path(__file__).parent / "guanajuato_fires.geojson" - fire_data = fetch_firms( - bounds=(-101.50, 20.70, -100.50, 21.30), - date_span='7d', - cache_path=fire_cache, - crs='EPSG:32614', - ) + fire_data = fetch_firms(bounds=BOUNDS, date_span='7d', + cache_path=CACHE / "guanajuato_fires.geojson", + crs=CRS) if fire_data.get('features'): - elev_scale = 0.025 with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="place_geojson called before") fire_info = ds.rtx.place_geojson( - fire_data, z='elevation', height=20 * elev_scale, - geometry_id='fire', - color=(1.0, 0.25, 0.0, 3.0), - extrude=True, - merge=True, + fire_data, z='elevation', height=20 * 0.025, + geometry_id='fire', color=(1.0, 0.25, 0.0, 3.0), + extrude=True, merge=True, ) print(f"Placed {fire_info['geometries']} fire detection footprints") else: @@ -243,7 +135,7 @@ def load_terrain(): wind = None try: from rtxpy import fetch_wind - wind = fetch_wind((-101.50, 20.70, -100.50, 21.30), grid_size=15) + wind = fetch_wind(BOUNDS, grid_size=15) except Exception as e: print(f"Skipping wind: {e}") diff --git a/examples/los_angeles.py b/examples/los_angeles.py index 97410e8..ac18ff0 100644 --- a/examples/los_angeles.py +++ b/examples/los_angeles.py @@ -24,22 +24,20 @@ from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water, fetch_firms import rtxpy -from _utils import print_controls, classify_water_features, scale_building_heights # Los Angeles bounding box (WGS84) # Focused area covering DTLA, Echo Park, Silver Lake, Griffith Park, -# Hollywood Hills, and the Hollywood Sign (~8 km × 8 km at 1 m resolution). +# Hollywood Hills, and the Hollywood Sign (~8 km x 8 km at 1 m resolution). BOUNDS = (-118.32, 34.04, -118.23, 34.12) CRS = 'EPSG:32611' # UTM zone 11N +CACHE = Path(__file__).parent def load_terrain(): """Load Los Angeles terrain data, downloading if necessary.""" - dem_path = Path(__file__).parent / "los_angeles_dem.tif" - terrain = fetch_dem( bounds=BOUNDS, - output_path=dem_path, + output_path=CACHE / "los_angeles_dem.tif", source='usgs_1m', crs=CRS, ) @@ -65,11 +63,8 @@ def load_terrain(): if __name__ == "__main__": - # Load terrain data (downloads if needed) terrain = load_terrain() - print_controls() - # Build Dataset with derived layers print("Building Dataset with terrain analysis layers...") ds = xr.Dataset({ @@ -86,154 +81,48 @@ def load_terrain(): # --- Microsoft Global Building Footprints -------------------------------- try: - bldg_cache = Path(__file__).parent / "los_angeles_buildings.geojson" - bldg_data = fetch_buildings( - bounds=BOUNDS, - cache_path=bldg_cache, - ) - - elev_scale = 0.5 - default_height_m = 8.0 - scale_building_heights(bldg_data, elev_scale, default_height_m) - - mesh_cache_path = Path(__file__).parent / "los_angeles_buildings_mesh.npz" - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - bldg_info = ds.rtx.place_geojson( - bldg_data, - z='elevation', - height=default_height_m * elev_scale, - height_field='height', - geometry_id='building', - densify=False, - merge=True, - extrude=True, - mesh_cache=mesh_cache_path, - ) - print(f"Placed {bldg_info['geometries']} building footprint geometries") - except ImportError as e: + bldg_data = fetch_buildings(bounds=BOUNDS, cache_path=CACHE / "los_angeles_buildings.geojson") + info = ds.rtx.place_buildings(bldg_data, z='elevation', elev_scale=0.5, + mesh_cache=CACHE / "los_angeles_buildings_mesh.npz") + print(f"Placed {info['geometries']} building geometries") + except Exception as e: print(f"Skipping buildings: {e}") # --- OpenStreetMap roads ------------------------------------------------ try: - # Major roads: motorways, trunk, primary, secondary - major_cache = Path(__file__).parent / "los_angeles_roads_major.geojson" - major_roads = fetch_roads( - bounds=BOUNDS, - road_type='major', - cache_path=major_cache, - ) - if major_roads.get('features'): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - info = ds.rtx.place_geojson( - major_roads, z='elevation', height=1, - label_field='name', geometry_id='road_major', - color=(0.10, 0.10, 0.10), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "los_angeles_roads_major_mesh.npz", - ) - print(f"Placed {info['geometries']} major road geometries") - - # Minor roads: tertiary, residential, service - minor_cache = Path(__file__).parent / "los_angeles_roads_minor.geojson" - minor_roads = fetch_roads( - bounds=BOUNDS, - road_type='minor', - cache_path=minor_cache, - ) - if minor_roads.get('features'): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - info = ds.rtx.place_geojson( - minor_roads, z='elevation', height=1, - label_field='name', geometry_id='road_minor', - color=(0.55, 0.55, 0.55), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "los_angeles_roads_minor_mesh.npz", - ) - print(f"Placed {info['geometries']} minor road geometries") - - except ImportError as e: + for rt, gid, clr in [('major', 'road_major', (0.10, 0.10, 0.10)), + ('minor', 'road_minor', (0.55, 0.55, 0.55))]: + data = fetch_roads(bounds=BOUNDS, road_type=rt, + cache_path=CACHE / f"los_angeles_roads_{rt}.geojson") + info = ds.rtx.place_roads(data, z='elevation', geometry_id=gid, color=clr, + mesh_cache=CACHE / f"los_angeles_roads_{rt}_mesh.npz") + print(f"Placed {info['geometries']} {rt} road geometries") + except Exception as e: print(f"Skipping roads: {e}") # --- OpenStreetMap water features --------------------------------------- try: - water_cache = Path(__file__).parent / "los_angeles_water.geojson" - water_data = fetch_water( - bounds=BOUNDS, - water_type='all', - cache_path=water_cache, - ) - - major_features, minor_features, body_features = classify_water_features(water_data) - - if major_features: - major_fc = {"type": "FeatureCollection", "features": major_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - major_info = ds.rtx.place_geojson( - major_fc, z='elevation', height=0, - label_field='name', geometry_id='water_major', - color=(0.40, 0.70, 0.95, 2.25), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "los_angeles_water_major_mesh.npz", - ) - print(f"Placed {major_info['geometries']} major water features (rivers, canals)") - - if minor_features: - minor_fc = {"type": "FeatureCollection", "features": minor_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - minor_info = ds.rtx.place_geojson( - minor_fc, z='elevation', height=0, - label_field='name', geometry_id='water_minor', - color=(0.50, 0.75, 0.98, 2.25), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "los_angeles_water_minor_mesh.npz", - ) - print(f"Placed {minor_info['geometries']} minor water features (streams, drains)") - - if body_features: - body_fc = {"type": "FeatureCollection", "features": body_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - body_info = ds.rtx.place_geojson( - body_fc, z='elevation', height=0.5, - label_field='name', geometry_id='water_body', - color=(0.35, 0.55, 0.88, 2.25), - extrude=True, - merge=True, - mesh_cache=Path(__file__).parent / "los_angeles_water_body_mesh.npz", - ) - print(f"Placed {body_info['geometries']} water bodies (lakes, reservoirs)") - - except ImportError as e: - print(f"Skipping water features: {e}") + water_data = fetch_water(bounds=BOUNDS, water_type='all', + cache_path=CACHE / "los_angeles_water.geojson") + results = ds.rtx.place_water(water_data, z='elevation', + mesh_cache_prefix=CACHE / "los_angeles_water") + for cat, info in results.items(): + print(f"Placed {info['geometries']} {cat} water features") + except Exception as e: + print(f"Skipping water: {e}") - # --- NASA FIRMS fire detections (LANDSAT 30 m, last 7 days) ----------- + # --- NASA FIRMS fire detections (last 7 days) --------------------------- try: - fire_cache = Path(__file__).parent / "los_angeles_fires.geojson" - fire_data = fetch_firms( - bounds=BOUNDS, - date_span='7d', - cache_path=fire_cache, - crs=CRS, - ) + fire_data = fetch_firms(bounds=BOUNDS, date_span='7d', + cache_path=CACHE / "los_angeles_fires.geojson", + crs=CRS) if fire_data.get('features'): - elev_scale = 0.5 with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="place_geojson called before") fire_info = ds.rtx.place_geojson( - fire_data, z='elevation', height=20 * elev_scale, - geometry_id='fire', - color=(1.0, 0.25, 0.0, 3.0), - extrude=True, - merge=True, + fire_data, z='elevation', height=20 * 0.5, + geometry_id='fire', color=(1.0, 0.25, 0.0, 3.0), + extrude=True, merge=True, ) print(f"Placed {fire_info['geometries']} fire detection footprints") else: diff --git a/examples/playground.py b/examples/playground.py index 7d452e6..87e2c5a 100644 --- a/examples/playground.py +++ b/examples/playground.py @@ -26,18 +26,19 @@ # Import rtxpy to register the .rtx accessor from rtxpy import fetch_dem, fetch_roads, fetch_water import rtxpy -from _utils import print_controls, classify_water_features + +BOUNDS = (-122.3, 42.8, -121.9, 43.0) +CRS = 'EPSG:5070' +CACHE = Path(__file__).parent def load_terrain(): """Load Crater Lake terrain data, downloading if necessary.""" - dem_path = Path(__file__).parent / "crater_lake_national_park.tif" - terrain = fetch_dem( - bounds=(-122.3, 42.8, -121.9, 43.0), - output_path=dem_path, + bounds=BOUNDS, + output_path=CACHE / "crater_lake_national_park.tif", source='srtm', - crs='EPSG:5070', + crs=CRS, ) # Subsample for faster interactive performance @@ -66,8 +67,6 @@ def load_terrain(): # Load terrain data (downloads if needed) terrain = load_terrain() - print_controls() - # Build Dataset with derived layers print("Building Dataset with terrain analysis layers...") ds = xr.Dataset({ @@ -167,9 +166,6 @@ def load_terrain(): ], } - # Place GeoJSON on the elevation layer's RTX scene. - # Pixel spacing is 1.0 (pixel-coord mode) which matches explore()'s - # default, so the warning is expected — suppress it. with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="place_geojson called before") info = ds.rtx.place_geojson( @@ -184,90 +180,30 @@ def load_terrain(): # --- OpenStreetMap roads ------------------------------------------------ try: - roads_cache = Path(__file__).parent / "crater_lake_roads.geojson" - roads_data = fetch_roads( - bounds=(-122.3, 42.8, -121.9, 43.0), - road_type='all', - crs='EPSG:5070', - cache_path=roads_cache, - ) - roads_mesh = Path(__file__).parent / "crater_lake_roads_mesh.npz" - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - roads_info = ds.rtx.place_geojson( - roads_data, - z='elevation', - height=5.0, - label_field='name', - geometry_id='roads', - merge=True, - mesh_cache=roads_mesh, - ) - print(f"Placed {roads_info['geometries']} road geometries") - except ImportError as e: + roads_data = fetch_roads(bounds=BOUNDS, road_type='all', crs=CRS, + cache_path=CACHE / "crater_lake_roads.geojson") + info = ds.rtx.place_roads(roads_data, z='elevation', geometry_id='roads', + height=5, mesh_cache=CACHE / "crater_lake_roads_mesh.npz") + print(f"Placed {info['geometries']} road geometries") + except Exception as e: print(f"Skipping roads: {e}") # --- OpenStreetMap water features --------------------------------------- - # Split into major (rivers, canals → wider tubes) and minor (streams, - # drains, ditches → thinner tubes), each in a distinct blue tone. try: - water_cache = Path(__file__).parent / "crater_lake_water.geojson" - water_data = fetch_water( - bounds=(-122.3, 42.8, -121.9, 43.0), - water_type='all', - crs='EPSG:5070', - cache_path=water_cache, - ) - - major_features, minor_features, body_features = classify_water_features(water_data) - - if major_features: - major_fc = {"type": "FeatureCollection", "features": major_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - major_info = ds.rtx.place_geojson( - major_fc, z='elevation', height=10.0, - label_field='name', geometry_id='water_major', - color=(0.40, 0.70, 0.95, 2.25), - merge=True, - mesh_cache=Path(__file__).parent / "crater_lake_water_major_mesh.npz", - ) - print(f"Placed {major_info['geometries']} major water features (rivers, canals)") - - if minor_features: - minor_fc = {"type": "FeatureCollection", "features": minor_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - minor_info = ds.rtx.place_geojson( - minor_fc, z='elevation', height=4.0, - label_field='name', geometry_id='water_minor', - color=(0.50, 0.75, 0.98, 2.25), - merge=True, - mesh_cache=Path(__file__).parent / "crater_lake_water_minor_mesh.npz", - ) - print(f"Placed {minor_info['geometries']} minor water features (streams, drains)") - - if body_features: - body_fc = {"type": "FeatureCollection", "features": body_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - body_info = ds.rtx.place_geojson( - body_fc, z='elevation', height=6.0, - label_field='name', geometry_id='water_body', - color=(0.35, 0.55, 0.88, 2.25), - merge=True, - mesh_cache=Path(__file__).parent / "crater_lake_water_body_mesh.npz", - ) - print(f"Placed {body_info['geometries']} water bodies (lakes, ponds)") - - except ImportError as e: - print(f"Skipping water features: {e}") + water_data = fetch_water(bounds=BOUNDS, water_type='all', crs=CRS, + cache_path=CACHE / "crater_lake_water.geojson") + results = ds.rtx.place_water(water_data, z='elevation', + mesh_cache_prefix=CACHE / "crater_lake_water") + for cat, info in results.items(): + print(f"Placed {info['geometries']} {cat} water features") + except Exception as e: + print(f"Skipping water: {e}") # --- Wind data -------------------------------------------------------- wind = None try: from rtxpy import fetch_wind - wind = fetch_wind((-122.3, 42.8, -121.9, 43.0), grid_size=15) + wind = fetch_wind(BOUNDS, grid_size=15) # Crater Lake is smaller — more particles, faster, shorter lives # so they cover the field instead of clumping wind['n_particles'] = 15000 diff --git a/examples/rio.py b/examples/rio.py index de52d0e..a301c1c 100644 --- a/examples/rio.py +++ b/examples/rio.py @@ -11,32 +11,30 @@ pip install rtxpy[all] matplotlib xarray rioxarray requests pyproj Pillow """ +import warnings + import numpy as np import xarray as xr from xrspatial import slope, aspect, quantile from pathlib import Path -import warnings - from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water import rtxpy -from _utils import print_controls, classify_water_features, scale_building_heights # Rio de Janeiro bounding box (WGS84) # Covers the city from Barra da Tijuca in the west to Ilha do Governador # in the east, including Sugarloaf, Corcovado, and Tijuca Forest. BOUNDS = (-43.42, -23.08, -43.10, -22.84) CRS = 'EPSG:32723' # UTM zone 23S +CACHE = Path(__file__).parent def load_terrain(): """Load Rio de Janeiro terrain data, downloading if necessary.""" - dem_path = Path(__file__).parent / "rio_dem.tif" - terrain = fetch_dem( bounds=BOUNDS, - output_path=dem_path, + output_path=CACHE / "rio_dem.tif", source='copernicus', crs=CRS, ) @@ -61,11 +59,8 @@ def load_terrain(): if __name__ == "__main__": - # Load terrain data (downloads if needed) terrain = load_terrain() - print_controls() - # Build Dataset with derived layers print("Building Dataset with terrain analysis layers...") ds = xr.Dataset({ @@ -82,134 +77,35 @@ def load_terrain(): # --- Microsoft Global Building Footprints -------------------------------- try: - bldg_cache = Path(__file__).parent / "rio_buildings.geojson" - bldg_data = fetch_buildings( - bounds=BOUNDS, - cache_path=bldg_cache, - ) - - elev_scale = 0.025 - default_height_m = 8.0 - scale_building_heights(bldg_data, elev_scale, default_height_m) - - mesh_cache_path = Path(__file__).parent / "rio_buildings_mesh.npz" - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - bldg_info = ds.rtx.place_geojson( - bldg_data, - z='elevation', - height=default_height_m * elev_scale, - height_field='height', - geometry_id='building', - densify=False, - merge=True, - extrude=True, - mesh_cache=mesh_cache_path, - ) - print(f"Placed {bldg_info['geometries']} building footprint geometries") - except ImportError as e: + bldg_data = fetch_buildings(bounds=BOUNDS, cache_path=CACHE / "rio_buildings.geojson") + info = ds.rtx.place_buildings(bldg_data, z='elevation', elev_scale=0.025, + mesh_cache=CACHE / "rio_buildings_mesh.npz") + print(f"Placed {info['geometries']} building geometries") + except Exception as e: print(f"Skipping buildings: {e}") # --- OpenStreetMap roads ------------------------------------------------ try: - # Major roads: motorways, trunk, primary, secondary - major_cache = Path(__file__).parent / "rio_roads_major.geojson" - major_roads = fetch_roads( - bounds=BOUNDS, - road_type='major', - cache_path=major_cache, - ) - if major_roads.get('features'): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - info = ds.rtx.place_geojson( - major_roads, z='elevation', height=1, - label_field='name', geometry_id='road_major', - color=(0.10, 0.10, 0.10), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "rio_roads_major_mesh.npz", - ) - print(f"Placed {info['geometries']} major road geometries") - - # Minor roads: tertiary, residential, service - minor_cache = Path(__file__).parent / "rio_roads_minor.geojson" - minor_roads = fetch_roads( - bounds=BOUNDS, - road_type='minor', - cache_path=minor_cache, - ) - if minor_roads.get('features'): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - info = ds.rtx.place_geojson( - minor_roads, z='elevation', height=1, - label_field='name', geometry_id='road_minor', - color=(0.55, 0.55, 0.55), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "rio_roads_minor_mesh.npz", - ) - print(f"Placed {info['geometries']} minor road geometries") - - except ImportError as e: + for rt, gid, clr in [('major', 'road_major', (0.10, 0.10, 0.10)), + ('minor', 'road_minor', (0.55, 0.55, 0.55))]: + data = fetch_roads(bounds=BOUNDS, road_type=rt, + cache_path=CACHE / f"rio_roads_{rt}.geojson") + info = ds.rtx.place_roads(data, z='elevation', geometry_id=gid, color=clr, + mesh_cache=CACHE / f"rio_roads_{rt}_mesh.npz") + print(f"Placed {info['geometries']} {rt} road geometries") + except Exception as e: print(f"Skipping roads: {e}") # --- OpenStreetMap water features --------------------------------------- try: - water_cache = Path(__file__).parent / "rio_water.geojson" - water_data = fetch_water( - bounds=BOUNDS, - water_type='all', - cache_path=water_cache, - ) - - major_features, minor_features, body_features = classify_water_features(water_data) - - if major_features: - major_fc = {"type": "FeatureCollection", "features": major_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - major_info = ds.rtx.place_geojson( - major_fc, z='elevation', height=0, - label_field='name', geometry_id='water_major', - color=(0.40, 0.70, 0.95, 2.25), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "rio_water_major_mesh.npz", - ) - print(f"Placed {major_info['geometries']} major water features (rivers, canals)") - - if minor_features: - minor_fc = {"type": "FeatureCollection", "features": minor_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - minor_info = ds.rtx.place_geojson( - minor_fc, z='elevation', height=0, - label_field='name', geometry_id='water_minor', - color=(0.50, 0.75, 0.98, 2.25), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "rio_water_minor_mesh.npz", - ) - print(f"Placed {minor_info['geometries']} minor water features (streams, drains)") - - if body_features: - body_fc = {"type": "FeatureCollection", "features": body_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - body_info = ds.rtx.place_geojson( - body_fc, z='elevation', height=0.5, - label_field='name', geometry_id='water_body', - color=(0.35, 0.55, 0.88, 2.25), - extrude=True, - merge=True, - mesh_cache=Path(__file__).parent / "rio_water_body_mesh.npz", - ) - print(f"Placed {body_info['geometries']} water bodies (lakes, ponds)") - - except ImportError as e: - print(f"Skipping water features: {e}") + water_data = fetch_water(bounds=BOUNDS, water_type='all', + cache_path=CACHE / "rio_water.geojson") + results = ds.rtx.place_water(water_data, z='elevation', + mesh_cache_prefix=CACHE / "rio_water") + for cat, info in results.items(): + print(f"Placed {info['geometries']} {cat} water features") + except Exception as e: + print(f"Skipping water: {e}") # --- Wind data -------------------------------------------------------- wind = None diff --git a/examples/trinidad.py b/examples/trinidad.py index 839ada1..0441cfd 100644 --- a/examples/trinidad.py +++ b/examples/trinidad.py @@ -11,29 +11,29 @@ pip install rtxpy[all] matplotlib xarray rioxarray requests pyproj Pillow """ +import warnings + import numpy as np import xarray as xr from xrspatial import slope, aspect, quantile from pathlib import Path -# Import rtxpy to register the .rtx accessor -import warnings - from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water, fetch_firms import rtxpy -from _utils import print_controls, classify_water_features, scale_building_heights + +BOUNDS = (-61.95, 10.04, -60.44, 11.40) +CRS = 'EPSG:32620' +CACHE = Path(__file__).parent def load_terrain(): """Load Trinidad & Tobago terrain data, downloading if necessary.""" - dem_path = Path(__file__).parent / "trinidad_tobago_dem.tif" - terrain = fetch_dem( - bounds=(-61.95, 10.04, -60.44, 11.40), - output_path=dem_path, + bounds=BOUNDS, + output_path=CACHE / "trinidad_tobago_dem.tif", source='copernicus', - crs='EPSG:32620', + crs=CRS, ) # Scale down elevation for visualization (optional) @@ -56,11 +56,8 @@ def load_terrain(): if __name__ == "__main__": - # Load terrain data (downloads if needed) terrain = load_terrain() - print_controls() - # Build Dataset with derived layers print("Building Dataset with terrain analysis layers...") ds = xr.Dataset({ @@ -77,154 +74,48 @@ def load_terrain(): # --- Microsoft Global Building Footprints -------------------------------- try: - bldg_cache = Path(__file__).parent / "trinidad_buildings.geojson" - bldg_data = fetch_buildings( - bounds=(-61.95, 10.04, -60.44, 11.40), - cache_path=bldg_cache, - ) - - elev_scale = 0.025 - default_height_m = 8.0 - scale_building_heights(bldg_data, elev_scale, default_height_m) - - mesh_cache_path = Path(__file__).parent / "trinidad_buildings_mesh.npz" - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - bldg_info = ds.rtx.place_geojson( - bldg_data, - z='elevation', - height=default_height_m * elev_scale, - height_field='height', - geometry_id='building', - densify=False, - merge=True, - extrude=True, - mesh_cache=mesh_cache_path, - ) - print(f"Placed {bldg_info['geometries']} building footprint geometries") - except ImportError as e: + bldg_data = fetch_buildings(bounds=BOUNDS, cache_path=CACHE / "trinidad_buildings.geojson") + info = ds.rtx.place_buildings(bldg_data, z='elevation', elev_scale=0.025, + mesh_cache=CACHE / "trinidad_buildings_mesh.npz") + print(f"Placed {info['geometries']} building geometries") + except Exception as e: print(f"Skipping buildings: {e}") # --- OpenStreetMap roads ------------------------------------------------ try: - # Major roads: motorways, trunk, primary, secondary - major_cache = Path(__file__).parent / "trinidad_roads_major.geojson" - major_roads = fetch_roads( - bounds=(-61.95, 10.04, -60.44, 11.40), - road_type='major', - cache_path=major_cache, - ) - if major_roads.get('features'): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - info = ds.rtx.place_geojson( - major_roads, z='elevation', height=1, - label_field='name', geometry_id='road_major', - color=(0.10, 0.10, 0.10), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "trinidad_roads_major_mesh.npz", - ) - print(f"Placed {info['geometries']} major road geometries") - - # Minor roads: tertiary, residential, service - minor_cache = Path(__file__).parent / "trinidad_roads_minor.geojson" - minor_roads = fetch_roads( - bounds=(-61.95, 10.04, -60.44, 11.40), - road_type='minor', - cache_path=minor_cache, - ) - if minor_roads.get('features'): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - info = ds.rtx.place_geojson( - minor_roads, z='elevation', height=1, - label_field='name', geometry_id='road_minor', - color=(0.55, 0.55, 0.55), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "trinidad_roads_minor_mesh.npz", - ) - print(f"Placed {info['geometries']} minor road geometries") - - except ImportError as e: + for rt, gid, clr in [('major', 'road_major', (0.10, 0.10, 0.10)), + ('minor', 'road_minor', (0.55, 0.55, 0.55))]: + data = fetch_roads(bounds=BOUNDS, road_type=rt, + cache_path=CACHE / f"trinidad_roads_{rt}.geojson") + info = ds.rtx.place_roads(data, z='elevation', geometry_id=gid, color=clr, + mesh_cache=CACHE / f"trinidad_roads_{rt}_mesh.npz") + print(f"Placed {info['geometries']} {rt} road geometries") + except Exception as e: print(f"Skipping roads: {e}") # --- OpenStreetMap water features --------------------------------------- try: - water_cache = Path(__file__).parent / "trinidad_water.geojson" - water_data = fetch_water( - bounds=(-61.95, 10.04, -60.44, 11.40), - water_type='all', - cache_path=water_cache, - ) - - major_features, minor_features, body_features = classify_water_features(water_data) - - if major_features: - major_fc = {"type": "FeatureCollection", "features": major_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - major_info = ds.rtx.place_geojson( - major_fc, z='elevation', height=0, - label_field='name', geometry_id='water_major', - color=(0.40, 0.70, 0.95, 2.25), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "trinidad_water_major_mesh.npz", - ) - print(f"Placed {major_info['geometries']} major water features (rivers, canals)") - - if minor_features: - minor_fc = {"type": "FeatureCollection", "features": minor_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - minor_info = ds.rtx.place_geojson( - minor_fc, z='elevation', height=0, - label_field='name', geometry_id='water_minor', - color=(0.50, 0.75, 0.98, 2.25), - densify=False, - merge=True, - mesh_cache=Path(__file__).parent / "trinidad_water_minor_mesh.npz", - ) - print(f"Placed {minor_info['geometries']} minor water features (streams, drains)") - - if body_features: - body_fc = {"type": "FeatureCollection", "features": body_features} - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - body_info = ds.rtx.place_geojson( - body_fc, z='elevation', height=0.5, - label_field='name', geometry_id='water_body', - color=(0.35, 0.55, 0.88, 2.25), - extrude=True, - merge=True, - mesh_cache=Path(__file__).parent / "trinidad_water_body_mesh.npz", - ) - print(f"Placed {body_info['geometries']} water bodies (lakes, ponds)") - - except ImportError as e: - print(f"Skipping water features: {e}") + water_data = fetch_water(bounds=BOUNDS, water_type='all', + cache_path=CACHE / "trinidad_water.geojson") + results = ds.rtx.place_water(water_data, z='elevation', + mesh_cache_prefix=CACHE / "trinidad_water") + for cat, info in results.items(): + print(f"Placed {info['geometries']} {cat} water features") + except Exception as e: + print(f"Skipping water: {e}") - # --- NASA FIRMS fire detections (LANDSAT 30 m, last 7 days) ----------- + # --- NASA FIRMS fire detections (last 7 days) --------------------------- try: - fire_cache = Path(__file__).parent / "trinidad_fires.geojson" - fire_data = fetch_firms( - bounds=(-61.95, 10.04, -60.44, 11.40), - date_span='7d', - cache_path=fire_cache, - crs='EPSG:32620', - ) + fire_data = fetch_firms(bounds=BOUNDS, date_span='7d', + cache_path=CACHE / "trinidad_fires.geojson", + crs=CRS) if fire_data.get('features'): - elev_scale = 0.025 with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="place_geojson called before") fire_info = ds.rtx.place_geojson( - fire_data, z='elevation', height=20 * elev_scale, - geometry_id='fire', - color=(1.0, 0.25, 0.0, 3.0), - extrude=True, - merge=True, + fire_data, z='elevation', height=20 * 0.025, + geometry_id='fire', color=(1.0, 0.25, 0.0, 3.0), + extrude=True, merge=True, ) print(f"Placed {fire_info['geometries']} fire detection footprints") else: @@ -236,7 +127,7 @@ def load_terrain(): wind = None try: from rtxpy import fetch_wind - wind = fetch_wind((-61.95, 10.04, -60.44, 11.40), grid_size=15) + wind = fetch_wind(BOUNDS, grid_size=15) except Exception as e: print(f"Skipping wind: {e}") diff --git a/rtxpy/accessor.py b/rtxpy/accessor.py index bb514ec..f24594e 100644 --- a/rtxpy/accessor.py +++ b/rtxpy/accessor.py @@ -1124,6 +1124,166 @@ def place_geojson(self, geojson, height=10.0, 'geometry_ids': geometry_ids, } + def place_buildings(self, geojson, elev_scale=0.025, default_height_m=8.0, + mesh_cache=None): + """Place building footprints as extruded 3D geometry on terrain. + + Scales building heights from the GeoJSON ``height`` property (metres) + by *elev_scale* to match terrain visualisation scaling. Features + without a valid height get *default_height_m*. + + Parameters + ---------- + geojson : dict + GeoJSON FeatureCollection of building footprint polygons + (e.g. from :func:`rtxpy.fetch_buildings`). + elev_scale : float, optional + Factor applied to real-world heights so they match the scaled + terrain. Default is 0.025. + default_height_m : float, optional + Height in metres used when a feature has no ``height`` property. + Default is 8.0. + mesh_cache : str or Path, optional + Path to an ``.npz`` file for caching the merged mesh. + + Returns + ------- + dict + ``{'features': int, 'geometries': int, 'geometry_ids': list}`` + """ + import warnings + for feat in geojson.get("features", []): + props = feat.get("properties", {}) + h = props.get("height", -1) + if not isinstance(h, (int, float)) or h <= 0: + h = default_height_m + props["height"] = h * elev_scale + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="place_geojson called before") + return self.place_geojson( + geojson, + height=default_height_m * elev_scale, + height_field='height', + geometry_id='building', + densify=False, + merge=True, + extrude=True, + mesh_cache=mesh_cache, + ) + + def place_roads(self, geojson, geometry_id='road', color=None, + height=1, mesh_cache=None): + """Place road LineStrings as flat merged ribbon geometry on terrain. + + Parameters + ---------- + geojson : dict + GeoJSON FeatureCollection of road LineStrings (e.g. from + :func:`rtxpy.fetch_roads`). + geometry_id : str, optional + Geometry layer name. Use ``'road_major'`` / ``'road_minor'`` + when placing separate road classes. Default is ``'road'``. + color : tuple, optional + RGB or RGBA colour. Default is dark grey ``(0.30, 0.30, 0.30)``. + height : float, optional + Ribbon height above the terrain surface. Default is 1. + mesh_cache : str or Path, optional + Path to an ``.npz`` file for caching the merged mesh. + + Returns + ------- + dict + ``{'features': int, 'geometries': int, 'geometry_ids': list}`` + """ + import warnings + if not geojson.get('features'): + return {'features': 0, 'geometries': 0, 'geometry_ids': []} + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="place_geojson called before") + return self.place_geojson( + geojson, + height=height, + label_field='name', + geometry_id=geometry_id, + color=color or (0.30, 0.30, 0.30), + densify=False, + merge=True, + mesh_cache=mesh_cache, + ) + + def place_water(self, geojson, body_height=0.5, mesh_cache_prefix=None): + """Classify and place water features as coloured geometry on terrain. + + Splits the GeoJSON into three categories based on the ``waterway`` + and ``natural`` properties: + + * **major** — rivers, canals (bright blue ribbons) + * **minor** — streams, drains, ditches (pale blue ribbons) + * **body** — natural water polygons (extruded blue-grey) + + Parameters + ---------- + geojson : dict + GeoJSON FeatureCollection of water features (e.g. from + :func:`rtxpy.fetch_water` with ``water_type='all'``). + body_height : float, optional + Extrusion height for water body polygons. Default is 0.5. + mesh_cache_prefix : str or Path, optional + Base path for mesh cache files. Three files are created: + ``{prefix}_major_mesh.npz``, ``{prefix}_minor_mesh.npz``, + ``{prefix}_body_mesh.npz``. + + Returns + ------- + dict + ``{'major': info, 'minor': info, 'body': info}`` where each + *info* is the dict returned by :meth:`place_geojson`, keyed + only for categories that had features. + """ + import warnings + _MAJOR = {'river', 'canal'} + _MINOR = {'stream', 'drain', 'ditch'} + major, minor, body = [], [], [] + for f in geojson.get('features', []): + ww = (f.get('properties') or {}).get('waterway', '') + nat = (f.get('properties') or {}).get('natural', '') + if ww in _MAJOR: + major.append(f) + elif ww in _MINOR: + minor.append(f) + elif nat == 'water': + body.append(f) + else: + minor.append(f) + results = {} + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="place_geojson called before") + if major: + mc = f"{mesh_cache_prefix}_major_mesh.npz" if mesh_cache_prefix else None + results['major'] = self.place_geojson( + {"type": "FeatureCollection", "features": major}, + height=0, label_field='name', geometry_id='water_major', + color=(0.40, 0.70, 0.95, 2.25), + densify=False, merge=True, mesh_cache=mc, + ) + if minor: + mc = f"{mesh_cache_prefix}_minor_mesh.npz" if mesh_cache_prefix else None + results['minor'] = self.place_geojson( + {"type": "FeatureCollection", "features": minor}, + height=0, label_field='name', geometry_id='water_minor', + color=(0.50, 0.75, 0.98, 2.25), + densify=False, merge=True, mesh_cache=mc, + ) + if body: + mc = f"{mesh_cache_prefix}_body_mesh.npz" if mesh_cache_prefix else None + results['body'] = self.place_geojson( + {"type": "FeatureCollection", "features": body}, + height=body_height, label_field='name', geometry_id='water_body', + color=(0.35, 0.55, 0.88, 2.25), + extrude=True, merge=True, mesh_cache=mc, + ) + return results + def triangulate(self, geometry_id='terrain', scale=1.0, pixel_spacing_x=1.0, pixel_spacing_y=1.0): """Triangulate the terrain and add it to the scene. @@ -1455,7 +1615,7 @@ def explore(self, width=800, height=600, render_scale=0.5, start_position=None, look_at=None, key_repeat_interval=0.05, pixel_spacing_x=None, pixel_spacing_y=None, mesh_type='tin', color_stretch='linear', title=None, - subsample=1, wind_data=None): + subsample=1, wind_data=None, terrain_loader=None): """Launch an interactive terrain viewer with keyboard controls. Opens a matplotlib window for terrain exploration with keyboard @@ -1579,6 +1739,7 @@ def explore(self, width=800, height=600, render_scale=0.5, subsample=subsample, wind_data=wind_data, accessor=self, + terrain_loader=terrain_loader, ) def memory_usage(self): @@ -1832,6 +1993,33 @@ def place_geojson(self, geojson, z=None, **kwargs): terrain_da = self._get_terrain_da(z) return terrain_da.rtx.place_geojson(geojson, **kwargs) + def place_buildings(self, geojson, z=None, **kwargs): + """Place building footprints. Delegates to DataArray accessor.""" + if z is None: + z = self._z_var + if z is None: + raise ValueError("z must be specified (no prior terrain variable set)") + terrain_da = self._get_terrain_da(z) + return terrain_da.rtx.place_buildings(geojson, **kwargs) + + def place_roads(self, geojson, z=None, **kwargs): + """Place road LineStrings. Delegates to DataArray accessor.""" + if z is None: + z = self._z_var + if z is None: + raise ValueError("z must be specified (no prior terrain variable set)") + terrain_da = self._get_terrain_da(z) + return terrain_da.rtx.place_roads(geojson, **kwargs) + + def place_water(self, geojson, z=None, **kwargs): + """Classify and place water features. Delegates to DataArray accessor.""" + if z is None: + z = self._z_var + if z is None: + raise ValueError("z must be specified (no prior terrain variable set)") + terrain_da = self._get_terrain_da(z) + return terrain_da.rtx.place_water(geojson, **kwargs) + def explore(self, z, width=800, height=600, render_scale=0.5, start_position=None, look_at=None, key_repeat_interval=0.05, pixel_spacing_x=None, pixel_spacing_y=None, diff --git a/rtxpy/engine.py b/rtxpy/engine.py index a9d5685..86ba633 100644 --- a/rtxpy/engine.py +++ b/rtxpy/engine.py @@ -307,6 +307,22 @@ def __init__(self, raster, width: int = 800, height: int = 600, self._mouse_last_x = None self._mouse_last_y = None + # Dynamic terrain loading (zarr streaming) + self._terrain_loader = None # callback: (lon, lat) → xr.DataArray + self._coord_origin_x = 0.0 # lon of pixel (0,0) in current window + self._coord_origin_y = 0.0 # lat of pixel (0,0) + self._coord_step_x = 1.0 # lon step per pixel + self._coord_step_y = -1.0 # lat step per pixel (negative = southward) + self._reload_cooldown = 2.0 # min seconds between reloads + self._last_reload_time = 0.0 + + # Derive coordinate metadata from raster coords if available + if hasattr(raster, 'x') and hasattr(raster, 'y') and len(raster.x) > 1: + self._coord_origin_x = float(raster.x.values[0]) + self._coord_origin_y = float(raster.y.values[0]) + self._coord_step_x = float(raster.x.values[1] - raster.x.values[0]) + self._coord_step_y = float(raster.y.values[1] - raster.y.values[0]) + # Get terrain info H, W = raster.shape terrain_data = raster.data @@ -1608,6 +1624,147 @@ def _sync_drone_from_pos(self, pos): self._viewshed_cache = None self._calculate_viewshed(quiet=True) + def _check_terrain_reload(self): + """Check if camera is near terrain edge and reload a new window if needed.""" + if self._terrain_loader is None: + return + + now = time.time() + if now - self._last_reload_time < self._reload_cooldown: + return + + if self.position is None: + return + + H, W = self.terrain_shape + cam_col = self.position[0] / self.pixel_spacing_x + cam_row = self.position[1] / self.pixel_spacing_y + + # Check if camera is within 20% of any edge + margin_x = W * 0.2 + margin_y = H * 0.2 + near_edge = (cam_col < margin_x or cam_col > W - margin_x or + cam_row < margin_y or cam_row > H - margin_y) + if not near_edge: + return + + # Compute camera lon/lat from world position + cam_lon = self._coord_origin_x + cam_col * self._coord_step_x + cam_lat = self._coord_origin_y + cam_row * self._coord_step_y + + # Call the terrain loader + new_raster = self._terrain_loader(cam_lon, cam_lat) + if new_raster is None: + self._last_reload_time = now + return + + cam_z = self.position[2] + + # Extract coordinate metadata from new raster + new_origin_x = float(new_raster.x.values[0]) + new_origin_y = float(new_raster.y.values[0]) + new_step_x = float(new_raster.x.values[1] - new_raster.x.values[0]) + new_step_y = float(new_raster.y.values[1] - new_raster.y.values[0]) + + # Compute camera position in new window's pixel space + new_col = (cam_lon - new_origin_x) / new_step_x + new_row = (cam_lat - new_origin_y) / new_step_y + + # Replace rasters + self._base_raster = new_raster + self.raster = new_raster + + # Update coordinate tracking + self._coord_origin_x = new_origin_x + self._coord_origin_y = new_origin_y + self._coord_step_x = new_step_x + self._coord_step_y = new_step_y + + # Recompute terrain stats + new_H, new_W = new_raster.shape + self.terrain_shape = (new_H, new_W) + + terrain_data = new_raster.data + if hasattr(terrain_data, 'get'): + terrain_np = terrain_data.get() + else: + terrain_np = np.asarray(terrain_data) + + ve = self.vertical_exaggeration + self.elev_min = float(np.nanmin(terrain_np)) * ve + self.elev_max = float(np.nanmax(terrain_np)) * ve + self.elev_mean = float(np.nanmean(terrain_np)) * ve + + # Rebuild water mask + floor_val = float(np.nanmin(terrain_np)) + floor_max = float(np.nanmax(terrain_np)) + eps = (floor_max - floor_val) * 1e-4 if floor_max > floor_val else 1e-6 + self._water_mask = (terrain_np <= floor_val + eps) | np.isnan(terrain_np) + + land_pixels = terrain_np[~self._water_mask] + if land_pixels.size > 0: + self._land_color_range = (float(np.nanmin(land_pixels)) * ve, + float(np.nanmax(land_pixels)) * ve) + + # Clear terrain mesh cache (old window geometry is stale) + self._terrain_mesh_cache.clear() + self._baked_mesh_cache.clear() + + # Rebuild terrain mesh + from . import mesh as mesh_mod + + H, W = new_H, new_W + if self.mesh_type == 'voxel': + num_verts = H * W * 8 + num_tris = H * W * 12 + vertices = np.zeros(num_verts * 3, dtype=np.float32) + indices = np.zeros(num_tris * 3, dtype=np.int32) + base_elev = float(np.nanmin(terrain_np)) + mesh_mod.voxelate_terrain(vertices, indices, new_raster, scale=1.0, + base_elevation=base_elev) + else: + num_verts = H * W + num_tris = (H - 1) * (W - 1) * 2 + vertices = np.zeros(num_verts * 3, dtype=np.float32) + indices = np.zeros(num_tris * 3, dtype=np.int32) + mesh_mod.triangulate_terrain(vertices, indices, new_raster, scale=1.0) + + # Scale x,y to world units + if self.pixel_spacing_x != 1.0 or self.pixel_spacing_y != 1.0: + vertices[0::3] *= self.pixel_spacing_x + vertices[1::3] *= self.pixel_spacing_y + + # Apply vertical exaggeration + if ve != 1.0: + vertices[2::3] *= ve + + # Cache the new mesh + cache_key = (self.subsample_factor, self.mesh_type) + base_verts = vertices.copy() + if ve != 1.0: + base_verts[2::3] /= ve + self._terrain_mesh_cache[cache_key] = (base_verts, indices.copy(), terrain_np.copy()) + + # Replace terrain geometry + if self.rtx is not None: + self.rtx.add_geometry('terrain', vertices, indices) + + # Reposition camera in new window + self.position = np.array([ + new_col * self.pixel_spacing_x, + new_row * self.pixel_spacing_y, + cam_z + ], dtype=float) + + # Refresh minimap + self._compute_minimap_background() + if self._minimap_im is not None: + self._minimap_im.set_data(self._minimap_background) + + self._last_reload_time = time.time() + print(f"Terrain reloaded: center ({cam_lon:.4f}, {cam_lat:.4f}), " + f"window {new_W}x{new_H}") + def _tick(self): """Continuous render loop — process held keys and redraw (called by timer).""" if not self.running: @@ -1686,6 +1843,7 @@ def _tick(self): if self._drone_mode == 'fpv' and self._observer_drone_placed: self._sync_drone_from_pos(self.position) + self._check_terrain_reload() self._update_frame() def _cycle_terrain_layer(self): @@ -2880,7 +3038,8 @@ def explore(raster, width: int = 800, height: int = 600, baked_meshes=None, subsample: int = 1, wind_data=None, - accessor=None): + accessor=None, + terrain_loader=None): """ Launch an interactive terrain viewer. @@ -2987,6 +3146,7 @@ def explore(raster, width: int = 800, height: int = 600, viewer._geometry_colors_builder = geometry_colors_builder viewer._baked_meshes = baked_meshes or {} viewer._accessor = accessor + viewer._terrain_loader = terrain_loader viewer.color_stretch = color_stretch if color_stretch in viewer._color_stretches: viewer._color_stretch_idx = viewer._color_stretches.index(color_stretch) From 531d130c5eafe86ef2821bac480cda691b058511 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Mon, 16 Feb 2026 06:17:21 -0800 Subject: [PATCH 3/5] add GTFS transit route and realtime vehicle support - fetch_gtfs() in remote_data.py: auto-discovers feeds via Mobility Database catalog (ranked by IoU), parses GTFS ZIP (shapes, stops, routes), supports direct URL and local file inputs - place_gtfs() in accessor.py: groups routes by route_color so each transit line renders in its official color (e.g. MTA subway trunk lines), with per-mode width multipliers and stop markers - GTFS-RT realtime overlay in engine.py: background poll thread for protobuf vehicle positions, colored dots projected to screen space, toggled with Shift+B - NaN guard in geojson.py _densify_on_terrain for edge-of-raster coords - Added gtfs to NYC example features list --- examples/new_york_city.py | 17 + rtxpy/__init__.py | 20 +- rtxpy/accessor.py | 802 ++++++- rtxpy/engine.py | 4394 +++++++++++++++++++++++++++++-------- rtxpy/geojson.py | 137 +- rtxpy/quickstart.py | 351 +++ rtxpy/remote_data.py | 1594 +++++++++++++- 7 files changed, 6279 insertions(+), 1036 deletions(-) create mode 100644 examples/new_york_city.py create mode 100644 rtxpy/quickstart.py diff --git a/examples/new_york_city.py b/examples/new_york_city.py new file mode 100644 index 0000000..8c0da7f --- /dev/null +++ b/examples/new_york_city.py @@ -0,0 +1,17 @@ +"""New York City — GPU-accelerated terrain exploration.""" +import sys +from rtxpy import quickstart + +tour = None +if '--tour' in sys.argv: + idx = sys.argv.index('--tour') + tour = sys.argv[idx + 1] if idx + 1 < len(sys.argv) else 'nyc_tour.py' + +quickstart( + name='nyc', + bounds=(-74.26, 40.49, -73.70, 40.92), + crs='EPSG:32618', + features=['buildings', 'roads', 'water', 'fire', 'restaurant_grades', 'gtfs'], + ao_samples=1, + tour=tour, +) diff --git a/rtxpy/__init__.py b/rtxpy/__init__.py index f1560cb..7f90c95 100644 --- a/rtxpy/__init__.py +++ b/rtxpy/__init__.py @@ -5,6 +5,7 @@ get_device_properties, list_devices, get_current_device, + get_capabilities, ) from .mesh import ( triangulate_terrain, @@ -15,6 +16,11 @@ make_transform, make_transforms_on_terrain, ) +from .mesh_store import ( + save_meshes_to_zarr, + load_meshes_from_zarr, + chunks_for_pixel_window, +) from .analysis import viewshed, hillshade, render, flyover, view from .engine import explore @@ -22,7 +28,19 @@ # Optional convenience — network helpers with lazy dependency checks try: - from .remote_data import fetch_dem, fetch_osm, fetch_buildings, fetch_roads, fetch_water, fetch_wind, fetch_firms + from .remote_data import fetch_dem, fetch_osm, fetch_buildings, fetch_roads, fetch_water, fetch_wind, fetch_firms, fetch_places, fetch_infrastructure, fetch_land_use, fetch_restaurant_grades, fetch_gtfs +except ImportError: + pass + +# Jupyter notebook viewer +try: + from .notebook import JupyterViewer +except ImportError: + pass + +# One-call launcher +try: + from .quickstart import quickstart except ImportError: pass diff --git a/rtxpy/accessor.py b/rtxpy/accessor.py index f24594e..b711d91 100644 --- a/rtxpy/accessor.py +++ b/rtxpy/accessor.py @@ -27,9 +27,9 @@ class RTXAccessor: def __init__(self, xarray_obj): self._obj = xarray_obj self._rtx_instance = None - # Track pixel spacing for coordinate conversion (set by triangulate/place_mesh) - self._pixel_spacing_x = 1.0 - self._pixel_spacing_y = 1.0 + # Auto-compute pixel spacing from DataArray coordinates + from .analysis._common import _compute_pixel_spacing + self._pixel_spacing_x, self._pixel_spacing_y = _compute_pixel_spacing(xarray_obj) # Per-geometry solid color overrides: {geometry_id: (r, g, b)} self._geometry_colors = {} self._geometry_colors_dirty = True @@ -73,7 +73,13 @@ def to_cupy(self): import cupy if isinstance(self._obj.data, cupy.ndarray): return self._obj - return self._obj.copy(data=cupy.asarray(self._obj.data)) + from numba import cuda as numba_cuda + data = np.ascontiguousarray(self._obj.data) + pinned = numba_cuda.pinned_array(data.shape, dtype=data.dtype) + pinned[:] = data + gpu_data = cupy.asarray(pinned) + del pinned + return self._obj.copy(data=gpu_data) def viewshed(self, x, y, observer_elev=0, target_elev=0, rtx=None): """Compute viewshed from observer point. @@ -117,7 +123,7 @@ def viewshed(self, x, y, observer_elev=0, target_elev=0, rtx=None): self._obj, x=x, y=y, observer_elev=observer_elev, target_elev=target_elev, - rtx=rtx + rtx=rtx, ) def hillshade(self, shadows=False, azimuth=225, angle_altitude=25, @@ -371,6 +377,139 @@ def _build_geometry_colors_gpu(self): self._geometry_colors_dirty = False return self._geometry_colors_gpu + def save_meshes(self, zarr_path): + """Save all baked mesh geometries to a zarr store. + + Writes mesh data into a ``meshes/`` group inside the zarr store, + spatially partitioned to match the DEM elevation chunks. + Requires that geometries have been placed via ``place_geojson()``, + ``place_buildings()``, ``place_roads()``, or ``place_water()`` + with ``merge=True`` first. + + Parameters + ---------- + zarr_path : str or Path + Path to an existing zarr store (the same one used for DEM data). + + See Also + -------- + load_meshes : Load meshes back from a zarr store. + """ + from .mesh_store import save_meshes_to_zarr + import zarr as _zarr + + if not self._baked_meshes: + raise ValueError("No baked meshes to save. Place features first " + "(place_buildings, place_roads, etc.).") + + # Determine elevation chunks from the zarr store + store = _zarr.open(str(zarr_path), mode='r', use_consolidated=False) + elev = store['elevation'] + elev_shape = elev.shape + # zarr 3: chunk shape from metadata + if hasattr(elev, 'chunks'): + elev_chunks = elev.chunks + else: + elev_chunks = elev.metadata.chunk_grid.chunk_shape + store = None + + # Separate triangle meshes (3-tuple) from curves (4-tuple) + meshes = {} + curves = {} + for gid, baked in self._baked_meshes.items(): + if len(baked) == 4: + # curve geometry: (verts, widths, indices, base_z) + curves[gid] = (baked[0], baked[1], baked[2]) + else: + meshes[gid] = (baked[0], baked[1]) # (vertices, indices) + + save_meshes_to_zarr( + zarr_path, + meshes=meshes, + colors=self._geometry_colors, + pixel_spacing=(self._pixel_spacing_x, self._pixel_spacing_y), + elevation_shape=elev_shape, + elevation_chunks=elev_chunks, + curves=curves, + ) + + def load_meshes(self, zarr_path, chunks=None): + """Load mesh geometries from a zarr store and add them to the scene. + + Parameters + ---------- + zarr_path : str or Path + Path to a zarr store containing a ``meshes/`` group. + chunks : list of (int, int) or None + Specific chunk indices ``[(row, col), ...]`` to load. + ``None`` loads all chunks (full scene). + + See Also + -------- + save_meshes : Save meshes to a zarr store. + """ + from .mesh_store import load_meshes_from_zarr + + meshes, colors, meta, curves = load_meshes_from_zarr( + zarr_path, chunks=chunks) + + psx, psy = meta['pixel_spacing'] + self._pixel_spacing_x = psx + self._pixel_spacing_y = psy + + # Get terrain data for base_z computation + terrain_data_np = self._obj.data + if hasattr(terrain_data_np, 'get'): + terrain_data_np = terrain_data_np.get() + else: + terrain_data_np = np.asarray(terrain_data_np) + H, W = terrain_data_np.shape + + loaded = [] + for gid, (verts, indices) in meshes.items(): + if len(indices) == 0: + continue + self._rtx.add_geometry(gid, verts, indices) + self._geometry_colors[gid] = colors.get(gid, (0.6, 0.6, 0.6)) + + # Compute base_z for VE rescaling using bilinear interpolation + # to match the terrain triangle mesh surface + vx = verts[0::3] + vy = verts[1::3] + base_z = self._bilinear_terrain_z( + terrain_data_np, vx, vy, psx, psy) + self._baked_meshes[gid] = (verts.copy(), indices.copy(), base_z) + loaded.append(gid) + + # Load curve geometries (roads, water) + loaded_curves = [] + for gid, (verts, widths, indices) in curves.items(): + if len(indices) == 0: + continue + self._rtx.add_curve_geometry(gid, verts, widths, indices) + self._geometry_colors[gid] = colors.get(gid, (0.6, 0.6, 0.6)) + + vx = verts[0::3] + vy = verts[1::3] + base_z = self._bilinear_terrain_z( + terrain_data_np, vx, vy, psx, psy) + # 4-tuple signals curve geometry to the re-snap logic + self._baked_meshes[gid] = (verts.copy(), widths.copy(), + indices.copy(), base_z) + loaded_curves.append(gid) + + self._geometry_colors_dirty = True + n_tris = sum(len(meshes[g][1]) // 3 for g in loaded) + n_segs = sum(len(curves[g][2]) for g in loaded_curves) + parts = [] + if loaded: + parts.append(f"{n_tris:,} triangles") + if loaded_curves: + parts.append(f"{n_segs:,} curve segments") + total = len(loaded) + len(loaded_curves) + print(f"Loaded {total} mesh geometries ({', '.join(parts)}) " + f"from {zarr_path}") + def trace(self, rays, hits, num_rays, primitive_ids=None, instance_ids=None): """Trace rays against the current acceleration structure. @@ -413,7 +552,9 @@ def render(self, camera_position, look_at, fov=60.0, up=(0, 0, 1), shadows=True, ambient=0.15, fog_density=0.0, fog_color=(0.7, 0.8, 0.9), colormap='terrain', color_range=None, output_path=None, alpha=False, - vertical_exaggeration=None, rtx=None): + vertical_exaggeration=None, rtx=None, + ao_samples=0, ao_radius=None, gi_bounces=1, sun_angle=0.0, + aperture=0.0, focal_distance=0.0): """Render terrain with a perspective camera for movie-quality visualization. Uses OptiX ray tracing to render terrain with realistic lighting, shadows, @@ -466,6 +607,25 @@ def render(self, camera_position, look_at, fov=60.0, up=(0, 0, 1), rtx : RTX, optional Existing RTX instance to reuse. If None, uses the accessor's cached RTX instance. + ao_samples : int, optional + Number of ambient occlusion rays per pixel. 0 disables AO. + Higher values produce smoother results. Default is 0. + ao_radius : float, optional + Maximum AO ray distance. If None, auto-computes from scene extent + (~5% of diagonal). Default is None. + sun_angle : float, optional + Sun cone half-angle in degrees for soft shadows. 0 gives hard + shadows. Typical values: 0.25 (realistic), 1.5 (artistic). + Only effective when used with progressive accumulation + (frame_seed > 0). Default is 0.0. + aperture : float, optional + Lens aperture radius for depth of field. 0 disables DOF. + Requires frame_seed > 0 for progressive accumulation. + Default is 0.0. + focal_distance : float, optional + Distance to the focal plane. Objects at this distance are sharp. + If 0, auto-computes from camera-to-lookat distance. + Default is 0.0. Returns ------- @@ -503,6 +663,14 @@ def render(self, camera_position, look_at, fov=60.0, up=(0, 0, 1), alpha=alpha, vertical_exaggeration=vertical_exaggeration, rtx=rtx, # User can override, but default None creates fresh instance + pixel_spacing_x=self._pixel_spacing_x, + pixel_spacing_y=self._pixel_spacing_y, + ao_samples=ao_samples, + ao_radius=ao_radius, + gi_bounces=gi_bounces, + sun_angle=sun_angle, + aperture=aperture, + focal_distance=focal_distance, ) def place_mesh(self, mesh_source, positions, geometry_id=None, scale=1.0, @@ -655,6 +823,44 @@ def place_mesh(self, mesh_source, positions, geometry_id=None, scale=1.0, (1.00, 0.40, 0.60), # pink ] + @staticmethod + def _bilinear_terrain_z(terrain, vx, vy, psx, psy): + """Sample terrain Z at world positions using bilinear interpolation. + + Supports both numpy and cupy arrays — the array module is chosen + automatically based on the type of ``terrain``. + """ + if has_cupy: + import cupy as _cp + if isinstance(terrain, _cp.ndarray): + xp = _cp + else: + xp = np + else: + xp = np + H, W = terrain.shape + cols = vx / psx + rows = vy / psy + cols = xp.clip(cols, 0, W - 1) + rows = xp.clip(rows, 0, H - 1) + x0 = xp.clip(xp.floor(cols).astype(xp.int32), 0, max(W - 2, 0)) + y0 = xp.clip(xp.floor(rows).astype(xp.int32), 0, max(H - 2, 0)) + fx = cols - x0 + fy = rows - y0 + z00 = terrain[y0, x0].astype(xp.float32) + z10 = terrain[y0, xp.minimum(x0 + 1, W - 1)].astype(xp.float32) + z01 = terrain[xp.minimum(y0 + 1, H - 1), x0].astype(xp.float32) + z11 = terrain[xp.minimum(y0 + 1, H - 1), + xp.minimum(x0 + 1, W - 1)].astype(xp.float32) + z00 = xp.where(xp.isnan(z00), 0.0, z00) + z10 = xp.where(xp.isnan(z10), 0.0, z10) + z01 = xp.where(xp.isnan(z01), 0.0, z01) + z11 = xp.where(xp.isnan(z11), 0.0, z11) + return (z00 * (1 - fx) * (1 - fy) + + z10 * fx * (1 - fy) + + z01 * (1 - fx) * fy + + z11 * fx * fy) + def place_geojson(self, geojson, height=10.0, label_field=None, height_field=None, fill_mesh=None, fill_spacing=None, fill_scale=1.0, @@ -751,6 +957,7 @@ def place_geojson(self, geojson, height=10.0, _geojson_to_world_coords, _build_transformer, _make_marker_cube, _make_marker_orb, _densify_on_terrain, _linestring_to_tube_mesh, _linestring_to_ribbon_mesh, + _linestring_to_curve_data, _polygon_to_curve_data, _polygon_to_tube_mesh, _polygon_to_ribbon_mesh, _extrude_polygon, _scatter_in_polygon, ) @@ -785,15 +992,12 @@ def place_geojson(self, geojson, height=10.0, terrain_data_np = terrain_data_np.get() else: terrain_data_np = np.asarray(terrain_data_np) - _H, _W = terrain_data_np.shape _psx = self._pixel_spacing_x _psy = self._pixel_spacing_y _vx = merged_v[0::3] _vy = merged_v[1::3] - _ix = np.clip(np.round(_vx / _psx).astype(int), 0, _W - 1) - _iy = np.clip(np.round(_vy / _psy).astype(int), 0, _H - 1) - _base_z = terrain_data_np[_iy, _ix].astype(np.float32) - _base_z = np.where(np.isnan(_base_z), 0.0, _base_z) + _base_z = self._bilinear_terrain_z( + terrain_data_np, _vx, _vy, _psx, _psy) self._baked_meshes[geometry_id] = (merged_v.copy(), merged_idx.copy(), _base_z) @@ -871,6 +1075,12 @@ def place_geojson(self, geojson, height=10.0, _merge_vert_offset = 0 _merge_color = None + # Curve merge accumulators (for round curve tubes) + _curve_verts = [] + _curve_widths = [] + _curve_indices = [] + _curve_vert_offset = 0 + for feat_i, (geom, props) in enumerate(features): if geom is None or not geom.get("coordinates"): continue @@ -964,27 +1174,43 @@ def place_geojson(self, geojson, height=10.0, wc = _densify_on_terrain( wc, terrain_data, psx, psy, step=step) if use_ribbon: - v, idx = _linestring_to_ribbon_mesh( + curve_result = _linestring_to_curve_data( wc, width=width, hover=ribbon_hover, ) + if curve_result is None: + continue + cv, cw, ci = curve_result + use_color = color if color is not None else feat_color + if merge: + _curve_indices.append(ci + _curve_vert_offset) + _curve_verts.append(cv) + _curve_widths.append(cw) + _curve_vert_offset += len(cv) // 3 + if _merge_color is None: + _merge_color = use_color + else: + self._rtx.add_curve_geometry( + gid, cv, cw, ci) + geometry_ids.append(gid) + self._geometry_colors[gid] = use_color else: v, idx = _linestring_to_tube_mesh( wc, radius=feat_height * 0.1, hover=feat_height * 0.15, ) - if len(v) == 0: - continue - use_color = color if color is not None else feat_color - if merge: - _merge_indices.append(idx + _merge_vert_offset) - _merge_verts.append(v) - _merge_vert_offset += len(v) // 3 - if _merge_color is None: - _merge_color = use_color - else: - self._rtx.add_geometry(gid, v, idx) - geometry_ids.append(gid) - self._geometry_colors[gid] = use_color + if len(v) == 0: + continue + use_color = color if color is not None else feat_color + if merge: + _merge_indices.append(idx + _merge_vert_offset) + _merge_verts.append(v) + _merge_vert_offset += len(v) // 3 + if _merge_color is None: + _merge_color = use_color + else: + self._rtx.add_geometry(gid, v, idx) + geometry_ids.append(gid) + self._geometry_colors[gid] = use_color geom_counter += 1 elif ptype == "Polygon": @@ -1017,6 +1243,19 @@ def place_geojson(self, geojson, height=10.0, if extrude: v, idx = _extrude_polygon( rings_world, feat_height) + if len(v) > 0: + use_color = color if color is not None else feat_color + if merge: + _merge_indices.append(idx + _merge_vert_offset) + _merge_verts.append(v) + _merge_vert_offset += len(v) // 3 + if _merge_color is None: + _merge_color = use_color + else: + self._rtx.add_geometry(gid, v, idx) + geometry_ids.append(gid) + self._geometry_colors[gid] = use_color + geom_counter += 1 else: if densify is not False: step = 1.0 if densify is True else float(densify) @@ -1028,28 +1267,44 @@ def place_geojson(self, geojson, height=10.0, else: dense_rings = rings_world if use_ribbon: - v, idx = _polygon_to_ribbon_mesh( + curve_result = _polygon_to_curve_data( dense_rings, width=width, hover=ribbon_hover, ) + if curve_result is not None: + cv, cw, ci = curve_result + use_color = color if color is not None else feat_color + if merge: + _curve_indices.append(ci + _curve_vert_offset) + _curve_verts.append(cv) + _curve_widths.append(cw) + _curve_vert_offset += len(cv) // 3 + if _merge_color is None: + _merge_color = use_color + else: + self._rtx.add_curve_geometry( + gid, cv, cw, ci) + geometry_ids.append(gid) + self._geometry_colors[gid] = use_color + geom_counter += 1 else: v, idx = _polygon_to_tube_mesh( dense_rings, radius=feat_height * 0.1, hover=feat_height * 0.15, ) - if len(v) > 0: - use_color = color if color is not None else feat_color - if merge: - _merge_indices.append(idx + _merge_vert_offset) - _merge_verts.append(v) - _merge_vert_offset += len(v) // 3 - if _merge_color is None: - _merge_color = use_color - else: - self._rtx.add_geometry(gid, v, idx) - geometry_ids.append(gid) - self._geometry_colors[gid] = use_color - geom_counter += 1 + if len(v) > 0: + use_color = color if color is not None else feat_color + if merge: + _merge_indices.append(idx + _merge_vert_offset) + _merge_verts.append(v) + _merge_vert_offset += len(v) // 3 + if _merge_color is None: + _merge_color = use_color + else: + self._rtx.add_geometry(gid, v, idx) + geometry_ids.append(gid) + self._geometry_colors[gid] = use_color + geom_counter += 1 # Scatter mesh inside polygon if scatter_verts is not None and ring_pixel_coords is not None: @@ -1075,7 +1330,7 @@ def place_geojson(self, geojson, height=10.0, self._geometry_colors[fill_gid] = feat_color geom_counter += 1 - # Flush merged geometry as a single GAS + # Flush merged triangle geometry as a single GAS if merge and _merge_verts: merged_v = np.concatenate(_merge_verts) merged_idx = np.concatenate(_merge_indices).astype(np.int32) @@ -1108,6 +1363,32 @@ def place_geojson(self, geojson, height=10.0, f"({len(merged_v)//3} verts, " f"{len(merged_idx)//3} tris)") + # Flush merged curve geometry as a single curve GAS + if merge and _curve_verts: + merged_cv = np.concatenate(_curve_verts) + merged_cw = np.concatenate(_curve_widths) + merged_ci = np.concatenate(_curve_indices).astype(np.int32) + curve_gid = geometry_id if not _merge_verts else f"{geometry_id}_curves" + self._rtx.add_curve_geometry( + curve_gid, merged_cv, merged_cw, merged_ci) + geometry_ids.append(curve_gid) + if color is not None: + self._geometry_colors[curve_gid] = color + elif _merge_color is not None: + self._geometry_colors[curve_gid] = _merge_color + # Store for VE rescaling / resolution re-snapping + _vx = merged_cv[0::3] + _vy = merged_cv[1::3] + _ix = np.clip(np.round(_vx / psx).astype(int), 0, W - 1) + _iy = np.clip(np.round(_vy / psy).astype(int), 0, H - 1) + _base_z = terrain_data[_iy, _ix].astype(np.float32) + _base_z = np.where(np.isnan(_base_z), 0.0, _base_z) + # 4-tuple signals curve geometry to the re-snap logic + self._baked_meshes[curve_gid] = (merged_cv.copy(), + merged_cw.copy(), + merged_ci.copy(), + _base_z) + if oob_counter[0] > 0: warnings.warn( f"{oob_counter[0]} GeoJSON coordinate(s) outside raster " @@ -1124,7 +1405,7 @@ def place_geojson(self, geojson, height=10.0, 'geometry_ids': geometry_ids, } - def place_buildings(self, geojson, elev_scale=0.025, default_height_m=8.0, + def place_buildings(self, geojson, elev_scale=None, default_height_m=8.0, mesh_cache=None): """Place building footprints as extruded 3D geometry on terrain. @@ -1139,12 +1420,17 @@ def place_buildings(self, geojson, elev_scale=0.025, default_height_m=8.0, (e.g. from :func:`rtxpy.fetch_buildings`). elev_scale : float, optional Factor applied to real-world heights so they match the scaled - terrain. Default is 0.025. + terrain. When *None* (default), auto-computed so that + ``default_height_m`` buildings are roughly 2× the pixel + spacing — ensuring they are always visible at the terrain's + native resolution. default_height_m : float, optional Height in metres used when a feature has no ``height`` property. Default is 8.0. mesh_cache : str or Path, optional Path to an ``.npz`` file for caching the merged mesh. + The cache is automatically invalidated when ``elev_scale`` + or ``default_height_m`` change. Returns ------- @@ -1152,6 +1438,39 @@ def place_buildings(self, geojson, elev_scale=0.025, default_height_m=8.0, ``{'features': int, 'geometries': int, 'geometry_ids': list}`` """ import warnings + import json + from pathlib import Path as _CachePath + + if elev_scale is None: + avg_spacing = (self._pixel_spacing_x + self._pixel_spacing_y) / 2 + # Buildings should be ~2× pixel spacing for visibility + target_height = avg_spacing * 1 + elev_scale = max(1.0, target_height / default_height_m) + + # Invalidate mesh cache if height parameters changed + if mesh_cache is not None: + cache_p = _CachePath(mesh_cache) + meta_p = cache_p.with_suffix('.meta.json') + current_meta = { + 'elev_scale': elev_scale, + 'default_height_m': default_height_m, + } + if cache_p.exists(): + stale = True + if meta_p.exists(): + try: + old_meta = json.loads(meta_p.read_text()) + if old_meta == current_meta: + stale = False + except (json.JSONDecodeError, KeyError): + pass + if stale: + cache_p.unlink() + print(f"Invalidated stale mesh cache: {cache_p.name}") + # Write/update meta alongside cache + meta_p.parent.mkdir(parents=True, exist_ok=True) + meta_p.write_text(json.dumps(current_meta)) + for feat in geojson.get("features", []): props = feat.get("properties", {}) h = props.get("height", -1) @@ -1172,7 +1491,7 @@ def place_buildings(self, geojson, elev_scale=0.025, default_height_m=8.0, ) def place_roads(self, geojson, geometry_id='road', color=None, - height=1, mesh_cache=None): + height=3, mesh_cache=None): """Place road LineStrings as flat merged ribbon geometry on terrain. Parameters @@ -1206,7 +1525,7 @@ def place_roads(self, geojson, geometry_id='road', color=None, label_field='name', geometry_id=geometry_id, color=color or (0.30, 0.30, 0.30), - densify=False, + densify=True, merge=True, mesh_cache=mesh_cache, ) @@ -1217,9 +1536,13 @@ def place_water(self, geojson, body_height=0.5, mesh_cache_prefix=None): Splits the GeoJSON into three categories based on the ``waterway`` and ``natural`` properties: - * **major** — rivers, canals (bright blue ribbons) - * **minor** — streams, drains, ditches (pale blue ribbons) - * **body** — natural water polygons (extruded blue-grey) + * **major** — rivers, canals (bright blue curves) + * **minor** — streams, drains, ditches (pale blue curves) + * **body** — natural water polygons (densified extruded fill) + + All categories are densified at pixel resolution so geometry + follows the terrain surface smoothly. Water body polygons use + extruded geometry (walls + cap) for visible fill. Parameters ---------- @@ -1227,7 +1550,7 @@ def place_water(self, geojson, body_height=0.5, mesh_cache_prefix=None): GeoJSON FeatureCollection of water features (e.g. from :func:`rtxpy.fetch_water` with ``water_type='all'``). body_height : float, optional - Extrusion height for water body polygons. Default is 0.5. + Height offset for water body outline curves. Default is 0.5. mesh_cache_prefix : str or Path, optional Base path for mesh cache files. Three files are created: ``{prefix}_major_mesh.npz``, ``{prefix}_minor_mesh.npz``, @@ -1264,7 +1587,7 @@ def place_water(self, geojson, body_height=0.5, mesh_cache_prefix=None): {"type": "FeatureCollection", "features": major}, height=0, label_field='name', geometry_id='water_major', color=(0.40, 0.70, 0.95, 2.25), - densify=False, merge=True, mesh_cache=mc, + densify=True, merge=True, mesh_cache=mc, ) if minor: mc = f"{mesh_cache_prefix}_minor_mesh.npz" if mesh_cache_prefix else None @@ -1272,20 +1595,201 @@ def place_water(self, geojson, body_height=0.5, mesh_cache_prefix=None): {"type": "FeatureCollection", "features": minor}, height=0, label_field='name', geometry_id='water_minor', color=(0.50, 0.75, 0.98, 2.25), - densify=False, merge=True, mesh_cache=mc, + densify=True, merge=True, mesh_cache=mc, ) if body: mc = f"{mesh_cache_prefix}_body_mesh.npz" if mesh_cache_prefix else None results['body'] = self.place_geojson( {"type": "FeatureCollection", "features": body}, - height=body_height, label_field='name', geometry_id='water_body', + height=body_height, label_field='name', + geometry_id='water_body', color=(0.35, 0.55, 0.88, 2.25), - extrude=True, merge=True, mesh_cache=mc, + densify=True, extrude=True, merge=True, mesh_cache=mc, ) return results + def place_gtfs(self, gtfs_data, stop_height=8.0, route_width=None, + show_routes=True, show_stops=True, + mesh_cache_prefix=None): + """Classify and place GTFS transit features as coloured geometry. + + Routes are sub-grouped by their GTFS ``route_color`` so that + each trunk line renders in its official colour (e.g. the NYC + subway 1/2/3 in red, A/C/E in blue). When ``route_color`` is + absent, a default colour per mode category is used. + + Mode categories (from ``route_type``): + + * **tram** (0, 5) + * **subway** (1) + * **rail** (2) + * **bus** (3) + * **ferry** (4) + * **other** (6, 7, …) + + Parameters + ---------- + gtfs_data : dict + Result from :func:`rtxpy.fetch_gtfs` with ``'routes'``, + ``'stops'``, and ``'metadata'`` keys. + stop_height : float, optional + Height of stop marker geometry. Default 8.0. + route_width : float, optional + Base width for route curves. Default is ``pixel_spacing * 0.15``. + show_routes : bool + Whether to place route curves. Default ``True``. + show_stops : bool + Whether to place stop points. Default ``True``. + mesh_cache_prefix : str or Path, optional + Base path for mesh cache files. + + Returns + ------- + dict + Nested dict keyed by category, then by colour group label. + """ + import re as _re + import warnings + from .remote_data import _gtfs_route_type_name + + _CATEGORIES = { + 'tram': {'route_types': {0, 5}, 'color': (0.85, 0.20, 0.20), 'width_mult': 1.0}, + 'subway': {'route_types': {1}, 'color': (0.20, 0.45, 0.90), 'width_mult': 1.2}, + 'rail': {'route_types': {2}, 'color': (0.55, 0.25, 0.70), 'width_mult': 1.5}, + 'bus': {'route_types': {3}, 'color': (0.95, 0.65, 0.10), 'width_mult': 0.7}, + 'ferry': {'route_types': {4}, 'color': (0.20, 0.75, 0.85), 'width_mult': 1.3}, + 'other': {'route_types': {6, 7}, 'color': (0.70, 0.70, 0.70), 'width_mult': 0.8}, + } + + def _hex_to_rgb(hex_str): + """Convert '0088FF' or '#0088FF' to (r, g, b) floats.""" + h = hex_str.lstrip('#').strip() + if len(h) != 6: + return None + try: + return (int(h[0:2], 16) / 255.0, + int(h[2:4], 16) / 255.0, + int(h[4:6], 16) / 255.0) + except ValueError: + return None + + def _safe_id(s): + """Sanitise a string for use as a geometry_id component.""" + return _re.sub(r'[^a-zA-Z0-9]', '_', s).strip('_').lower() + + if route_width is None: + route_width = abs(self._pixel_spacing_x) * 0.15 + + # ----- classify routes by (category, route_color) -------------------- + # Key: (cat, hex_color_upper) -> list of features + color_groups = {} # {(cat, hex): [features]} + color_names = {} # {(cat, hex): set of route_short_name} + + if show_routes: + for f in gtfs_data.get('routes', {}).get('features', []): + props = f.get('properties') or {} + rt = props.get('route_type', 3) + try: + rt = int(rt) + except (ValueError, TypeError): + rt = 3 + cat = _gtfs_route_type_name(rt) + if cat not in _CATEGORIES: + cat = 'other' + + # Use route_color from feed if available + rc = (props.get('route_color') or '').strip().lstrip('#') + if len(rc) != 6: + rc = '' + key = (cat, rc.upper() if rc else '') + color_groups.setdefault(key, []).append(f) + sn = (props.get('route_short_name') or '').strip() + if sn: + color_names.setdefault(key, set()).add(sn) + + # ----- classify stops by (category, route_color) --------------------- + stop_groups = {} # {(cat, hex): [stop features]} + if show_stops: + for f in gtfs_data.get('stops', {}).get('features', []): + props = f.get('properties') or {} + rts = props.get('route_types', []) + if rts: + try: + rt = int(rts[0]) + except (ValueError, TypeError): + rt = 3 + cat = _gtfs_route_type_name(rt) + else: + cat = 'bus' + if cat not in _CATEGORIES: + cat = 'other' + # Use first route_color from stop's served routes + rcs = props.get('route_colors', []) + hex_c = rcs[0].upper() if rcs else '' + stop_groups.setdefault((cat, hex_c), []).append(f) + + # ----- place geometry per colour group ------------------------------- + results = {} + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="place_geojson called before") + + # Collect all (cat, hex) keys across routes and stops + all_keys = set(color_groups.keys()) | set(stop_groups.keys()) + + for cat, hex_c in sorted(all_keys): + info = _CATEGORIES.get(cat) + if info is None: + info = _CATEGORIES['other'] + + routes = color_groups.get((cat, hex_c), []) + stops = stop_groups.get((cat, hex_c), []) + if not routes and not stops: + continue + + # Determine colour: prefer feed's route_color, else default + rgb = _hex_to_rgb(hex_c) if hex_c else None + if rgb is None: + rgb = info['color'] + + # Build a human-readable label for this colour group + names = color_names.get((cat, hex_c), set()) + if names: + label = '/'.join(sorted(names)) + elif hex_c: + label = hex_c.lower() + else: + label = 'default' + + gid_suffix = _safe_id(label) if label != 'default' else cat + cat_result = {} + + if routes: + gid = f'gtfs_{cat}_{gid_suffix}_route' + mc = f"{mesh_cache_prefix}_{gid}_mesh.npz" if mesh_cache_prefix else None + w = route_width * info['width_mult'] + cat_result['routes'] = self.place_geojson( + {"type": "FeatureCollection", "features": routes}, + height=0, geometry_id=gid, + width=w, color=(rgb[0], rgb[1], rgb[2], 1.4), + densify=True, merge=True, mesh_cache=mc, + ) + + if stops: + gid = f'gtfs_{cat}_{gid_suffix}_stop' + mc = f"{mesh_cache_prefix}_{gid}_mesh.npz" if mesh_cache_prefix else None + cat_result['stops'] = self.place_geojson( + {"type": "FeatureCollection", "features": stops}, + height=stop_height, geometry_id=gid, + color=(rgb[0], rgb[1], rgb[2], 1.6), + merge=True, mesh_cache=mc, + ) + + results.setdefault(cat, {})[label] = cat_result + + return results + def triangulate(self, geometry_id='terrain', scale=1.0, - pixel_spacing_x=1.0, pixel_spacing_y=1.0): + pixel_spacing_x=None, pixel_spacing_y=None): """Triangulate the terrain and add it to the scene. Creates a triangle mesh from the raster elevation data and adds it @@ -1298,9 +1802,11 @@ def triangulate(self, geometry_id='terrain', scale=1.0, scale : float, optional Scale factor for elevation values. Default is 1.0. pixel_spacing_x : float, optional - X spacing between pixels in world units. Default is 1.0 (pixel coords). + X spacing between pixels in world units. If None, uses the + auto-computed value from the DataArray coordinates. pixel_spacing_y : float, optional - Y spacing between pixels in world units. Default is 1.0 (pixel coords). + Y spacing between pixels in world units. If None, uses the + auto-computed value from the DataArray coordinates. Returns ------- @@ -1319,6 +1825,12 @@ def triangulate(self, geometry_id='terrain', scale=1.0, from .mesh import triangulate_terrain import numpy as np + # Fall back to auto-computed spacing from __init__ + if pixel_spacing_x is None: + pixel_spacing_x = self._pixel_spacing_x + if pixel_spacing_y is None: + pixel_spacing_y = self._pixel_spacing_y + H, W = self._obj.shape # Allocate buffers @@ -1347,7 +1859,7 @@ def triangulate(self, geometry_id='terrain', scale=1.0, return vertices, indices def voxelate(self, geometry_id='terrain', scale=1.0, base_elevation=None, - pixel_spacing_x=1.0, pixel_spacing_y=1.0): + pixel_spacing_x=None, pixel_spacing_y=None): """Voxelate the terrain into box-columns and add to the scene. Creates a voxelized mesh where each raster cell becomes a rectangular @@ -1363,9 +1875,11 @@ def voxelate(self, geometry_id='terrain', scale=1.0, base_elevation=None, Z coordinate for the bottom of all columns. If None, uses min(terrain) * scale so all columns have visible height. pixel_spacing_x : float, optional - X spacing between pixels in world units. Default is 1.0. + X spacing between pixels in world units. If None, uses the + auto-computed value from the DataArray coordinates. pixel_spacing_y : float, optional - Y spacing between pixels in world units. Default is 1.0. + Y spacing between pixels in world units. If None, uses the + auto-computed value from the DataArray coordinates. Returns ------- @@ -1375,6 +1889,12 @@ def voxelate(self, geometry_id='terrain', scale=1.0, base_elevation=None, from .mesh import voxelate_terrain import numpy as np + # Fall back to auto-computed spacing from __init__ + if pixel_spacing_x is None: + pixel_spacing_x = self._pixel_spacing_x + if pixel_spacing_y is None: + pixel_spacing_y = self._pixel_spacing_y + H, W = self._obj.shape # Auto-compute base elevation from terrain minimum @@ -1414,6 +1934,75 @@ def voxelate(self, geometry_id='terrain', scale=1.0, base_elevation=None, return vertices, indices + def heightfield(self, geometry_id='terrain', scale=1.0, + pixel_spacing_x=None, pixel_spacing_y=None, + tile_size=32): + """Add the terrain as a heightfield using custom intersection. + + Instead of materializing an explicit triangle mesh, this uploads + the raw elevation grid and uses an OptiX custom intersection + program that ray-marches through the heightfield at trace time. + This dramatically reduces GPU memory and provides smooth bilinear + normals (no triangle faceting). + + Parameters + ---------- + geometry_id : str, optional + ID for the terrain geometry. Default is 'terrain'. + scale : float, optional + Scale factor for elevation values. Default is 1.0. + pixel_spacing_x : float, optional + X spacing between pixels in world units. If None, uses the + auto-computed value from the DataArray coordinates. + pixel_spacing_y : float, optional + Y spacing between pixels in world units. If None, uses the + auto-computed value from the DataArray coordinates. + tile_size : int, optional + Number of cells per tile dimension for AABB grouping. + Smaller tiles = tighter BVH but more AABBs. Default is 32. + + Returns + ------- + numpy.ndarray + The 2-D elevation array that was uploaded to the GPU. + """ + import numpy as np + + if pixel_spacing_x is None: + pixel_spacing_x = self._pixel_spacing_x + if pixel_spacing_y is None: + pixel_spacing_y = self._pixel_spacing_y + + H, W = self._obj.shape + + # Get elevation data + terrain_data = self._obj.data if hasattr(self._obj, 'data') else self._obj + if has_cupy: + import cupy + if isinstance(terrain_data, cupy.ndarray): + elev = terrain_data.get().astype(np.float32) + else: + elev = np.asarray(terrain_data, dtype=np.float32) + else: + elev = np.asarray(terrain_data, dtype=np.float32) + + if scale != 1.0: + elev = elev * scale + + self._rtx.add_heightfield_geometry( + geometry_id, elev, H, W, + spacing_x=pixel_spacing_x, + spacing_y=pixel_spacing_y, + ve=1.0, + tile_size=tile_size, + ) + + self._pixel_spacing_x = pixel_spacing_x + self._pixel_spacing_y = pixel_spacing_y + self._terrain_mesh_type = 'heightfield' + + return elev + def flyover(self, output_path, duration=30.0, fps=10.0, orbit_scale=0.6, altitude_offset=500.0, fov=60.0, fov_range=None, width=1280, height=720, sun_azimuth=225, sun_altitude=35, @@ -1594,7 +2183,7 @@ def place_tiles(self, url='osm', zoom=None): ---------- url : str, optional Provider name or custom URL template with {z}/{x}/{y} placeholders. - Built-in providers: 'osm', 'satellite', 'topo'. + Built-in providers: 'osm', 'satellite'. Default is 'osm' (OpenStreetMap). zoom : int, optional Tile zoom level (0–19). If ``None``, defaults to 13. @@ -1614,12 +2203,15 @@ def place_tiles(self, url='osm', zoom=None): def explore(self, width=800, height=600, render_scale=0.5, start_position=None, look_at=None, key_repeat_interval=0.05, pixel_spacing_x=None, pixel_spacing_y=None, - mesh_type='tin', color_stretch='linear', title=None, - subsample=1, wind_data=None, terrain_loader=None): + mesh_type='heightfield', color_stretch='linear', title=None, + subsample=1, wind_data=None, gtfs_data=None, + terrain_loader=None, + scene_zarr=None, ao_samples=0, gi_bounces=1, denoise=False, + repl=False, tour=None): """Launch an interactive terrain viewer with keyboard controls. - Opens a matplotlib window for terrain exploration with keyboard - controls. Uses matplotlib's event system - no additional dependencies. + Opens a GLFW window for terrain exploration with keyboard + controls. Uses GLFW for input and ModernGL for display. Any meshes placed with place_mesh() will be visible in the scene. Use the G key to cycle through geometry layer information. @@ -1654,6 +2246,15 @@ def explore(self, width=800, height=600, render_scale=0.5, Initial terrain subsample factor (1, 2, 4, 8). Full-resolution data is preserved; press Shift+R / R to change at runtime. Default is 1 (full resolution). + ao_samples : int, optional + If > 0, enable ambient occlusion on launch with progressive + accumulation (1 sample per frame). Press 0 to toggle at runtime. + Default is 0 (disabled). + repl : bool, optional + If True, start an interactive Python REPL in a background + thread. The REPL exposes a ``v`` (ViewerProxy) object for + running analysis and updating the display while the viewer + is running. Default is False. Controls -------- @@ -1702,10 +2303,14 @@ def explore(self, width=800, height=600, render_scale=0.5, spacing_y = pixel_spacing_y if pixel_spacing_y is not None else self._pixel_spacing_y # Rebuild terrain geometry if mesh_type doesn't match current state - current_mesh_type = getattr(self, '_terrain_mesh_type', 'tin') + current_mesh_type = getattr(self, '_terrain_mesh_type', 'heightfield') if mesh_type != current_mesh_type and 'terrain' in (self._rtx.list_geometries() or []): self._rtx.remove_geometry('terrain') - if mesh_type == 'voxel': + if mesh_type == 'heightfield': + self.heightfield(geometry_id='terrain', + pixel_spacing_x=spacing_x, + pixel_spacing_y=spacing_y) + elif mesh_type == 'voxel': self.voxelate(geometry_id='terrain', pixel_spacing_x=spacing_x, pixel_spacing_y=spacing_y) @@ -1719,7 +2324,7 @@ def explore(self, width=800, height=600, render_scale=0.5, if self._geometry_colors: geometry_colors_builder = self._build_geometry_colors_gpu - _explore( + return _explore( self._obj, width=width, height=height, @@ -1738,8 +2343,15 @@ def explore(self, width=800, height=600, render_scale=0.5, baked_meshes=self._baked_meshes if self._baked_meshes else None, subsample=subsample, wind_data=wind_data, + gtfs_data=gtfs_data, accessor=self, terrain_loader=terrain_loader, + scene_zarr=scene_zarr, + ao_samples=ao_samples, + gi_bounces=gi_bounces, + denoise=denoise, + repl=repl, + tour=tour, ) def memory_usage(self): @@ -2020,11 +2632,40 @@ def place_water(self, geojson, z=None, **kwargs): terrain_da = self._get_terrain_da(z) return terrain_da.rtx.place_water(geojson, **kwargs) + def place_gtfs(self, gtfs_data, z=None, **kwargs): + """Classify and place GTFS transit features. Delegates to DataArray accessor.""" + if z is None: + z = self._z_var + if z is None: + raise ValueError("z must be specified (no prior terrain variable set)") + terrain_da = self._get_terrain_da(z) + return terrain_da.rtx.place_gtfs(gtfs_data, **kwargs) + + def save_meshes(self, zarr_path, z=None): + """Save baked meshes to zarr. Delegates to DataArray accessor.""" + if z is None: + z = self._z_var + if z is None: + z = list(self._obj.data_vars)[0] + terrain_da = self._get_terrain_da(z) + return terrain_da.rtx.save_meshes(zarr_path) + + def load_meshes(self, zarr_path, chunks=None, z=None): + """Load meshes from zarr. Delegates to DataArray accessor.""" + if z is None: + z = self._z_var + if z is None: + z = list(self._obj.data_vars)[0] + terrain_da = self._get_terrain_da(z) + return terrain_da.rtx.load_meshes(zarr_path, chunks=chunks) + def explore(self, z, width=800, height=600, render_scale=0.5, start_position=None, look_at=None, key_repeat_interval=0.05, pixel_spacing_x=None, pixel_spacing_y=None, - mesh_type='tin', color_stretch='linear', title=None, - subsample=1, wind_data=None): + mesh_type='heightfield', color_stretch='linear', title=None, + subsample=1, wind_data=None, gtfs_data=None, + scene_zarr=None, + ao_samples=0, gi_bounces=1, denoise=False, repl=False, tour=None): """Launch an interactive terrain viewer with Dataset variables as overlay layers cycled with the G key. @@ -2057,17 +2698,27 @@ def explore(self, z, width=800, height=600, render_scale=0.5, Y spacing between pixels in world units. Default is 1.0. mesh_type : str, optional Mesh generation method: 'tin' or 'voxel'. Default is 'tin'. + repl : bool, optional + If True, start an interactive Python REPL alongside the + viewer. Default is False. + tour : list of dict or str, optional + If provided, automatically play a camera tour after the + viewer launches. Can be a list of keyframe dicts or a + path to a ``.py`` file defining a ``tour`` variable. + Implies ``repl=True``. Examples -------- >>> ds.rtx.explore(z='elevation') + >>> ds.rtx.explore(z='elevation', repl=True) """ from .engine import explore as _explore terrain_da = self._get_terrain_da(z) - spacing_x = pixel_spacing_x if pixel_spacing_x is not None else self._pixel_spacing_x - spacing_y = pixel_spacing_y if pixel_spacing_y is not None else self._pixel_spacing_y + # Use explicit values, else fall back to the terrain DataArray's auto-computed spacing + spacing_x = pixel_spacing_x if pixel_spacing_x is not None else terrain_da.rtx._pixel_spacing_x + spacing_y = pixel_spacing_y if pixel_spacing_y is not None else terrain_da.rtx._pixel_spacing_y # NOTE: terrain mesh is built by the engine at the correct # (possibly subsampled) resolution — not here at full res. @@ -2093,7 +2744,7 @@ def explore(self, z, width=800, height=600, render_scale=0.5, if terrain_da.rtx._geometry_colors: geometry_colors_builder = terrain_da.rtx._build_geometry_colors_gpu - _explore( + return _explore( terrain_da, width=width, height=height, @@ -2113,7 +2764,14 @@ def explore(self, z, width=800, height=600, render_scale=0.5, baked_meshes=terrain_da.rtx._baked_meshes if terrain_da.rtx._baked_meshes else None, subsample=subsample, wind_data=wind_data, + gtfs_data=gtfs_data, accessor=terrain_da.rtx, + scene_zarr=scene_zarr, + ao_samples=ao_samples, + gi_bounces=gi_bounces, + denoise=denoise, + repl=repl, + tour=tour, ) def place_tiles(self, url='osm', z=None, zoom=None): @@ -2128,7 +2786,7 @@ def place_tiles(self, url='osm', z=None, zoom=None): ---------- url : str, optional Provider name or custom URL template with {z}/{x}/{y} placeholders. - Built-in providers: 'osm', 'satellite', 'topo'. + Built-in providers: 'osm', 'satellite'. Default is 'osm' (OpenStreetMap). z : str, optional Name of the Dataset variable to use as the spatial reference diff --git a/rtxpy/engine.py b/rtxpy/engine.py index 86ba633..9efb197 100644 --- a/rtxpy/engine.py +++ b/rtxpy/engine.py @@ -1,12 +1,20 @@ -"""Interactive terrain viewer using matplotlib for display. +"""Interactive terrain viewer using GLFW + ModernGL for display. This module provides a simple game-engine-like render loop for exploring terrain interactively with keyboard controls. -Uses matplotlib for display (no additional dependencies). +Uses GLFW for windowing/input and ModernGL for GPU texture display. """ +import os +import queue +import threading import time import numpy as np + +# On WSL2 the hardware GLX drivers often segfault. Force Mesa software +# rendering for the display path (CUDA still handles the ray tracing). +if 'microsoft' in os.uname().release.lower(): + os.environ.setdefault('LIBGL_ALWAYS_SOFTWARE', '1') from typing import Optional, Tuple from .rtx import RTX, has_cupy @@ -15,12 +23,944 @@ import cupy as cp +# --------------------------------------------------------------------------- +# OpenGL shaders for fullscreen textured quad +# --------------------------------------------------------------------------- +_QUAD_VERT = """ +#version 330 +in vec2 in_pos; +in vec2 in_uv; +out vec2 v_uv; +void main() { + gl_Position = vec4(in_pos, 0.0, 1.0); + v_uv = in_uv; +} +""" + +_QUAD_FRAG = """ +#version 330 +uniform sampler2D frame; +in vec2 v_uv; +out vec4 fragColor; +void main() { + fragColor = vec4(texture(frame, v_uv).rgb, 1.0); +} +""" + + +def _glfw_to_key(glfw_key, mods): + """Translate a GLFW key code + modifiers to the string format used by + _handle_key_press / _handle_key_release. + + Returns (raw_key, key_lower) matching the old matplotlib convention: + - raw_key preserves case (uppercase if SHIFT held for letters) + - key_lower is always lowercase + """ + import glfw + + _SPECIAL = { + glfw.KEY_UP: 'up', glfw.KEY_DOWN: 'down', + glfw.KEY_LEFT: 'left', glfw.KEY_RIGHT: 'right', + glfw.KEY_PAGE_UP: 'pageup', glfw.KEY_PAGE_DOWN: 'pagedown', + glfw.KEY_ESCAPE: 'escape', + glfw.KEY_EQUAL: '=', glfw.KEY_MINUS: '-', + glfw.KEY_COMMA: ',', glfw.KEY_PERIOD: '.', + glfw.KEY_LEFT_BRACKET: '[', glfw.KEY_RIGHT_BRACKET: ']', + glfw.KEY_SEMICOLON: ';', glfw.KEY_APOSTROPHE: "'", + } + + if glfw_key in _SPECIAL: + raw = _SPECIAL[glfw_key] + # SHIFT variants for special keys + if mods & glfw.MOD_SHIFT: + if raw == '=': + raw = '+' + elif raw == '-': + raw = '_' # unlikely to be used, keep '-' behaviour + elif raw == ';': + raw = ':' + elif raw == "'": + raw = '"' + return raw, raw.lower() + + # Letter keys A-Z + if glfw.KEY_A <= glfw_key <= glfw.KEY_Z: + lower = chr(glfw_key - glfw.KEY_A + ord('a')) + if mods & glfw.MOD_SHIFT: + return lower.upper(), lower + return lower, lower + + # Digit keys 0-9 + if glfw.KEY_0 <= glfw_key <= glfw.KEY_9: + digit = chr(glfw_key - glfw.KEY_0 + ord('0')) + return digit, digit + + return '', '' + + +# --------------------------------------------------------------------------- +# Multi-observer system — up to 8 independent observers with drone/tour +# --------------------------------------------------------------------------- + +OBSERVER_COLORS = [ + (1.0, 0.2, 0.2), # 1: red + (0.2, 0.6, 1.0), # 2: blue + (0.2, 1.0, 0.3), # 3: green + (1.0, 0.8, 0.1), # 4: yellow + (1.0, 0.4, 0.0), # 5: orange + (0.8, 0.2, 1.0), # 6: purple + (0.0, 1.0, 0.9), # 7: cyan + (1.0, 0.5, 0.7), # 8: pink +] + + +class Observer: + """State for a single observer slot (1-8).""" + + __slots__ = ( + 'slot', 'position', 'observer_elev', 'drone_mode', 'drone_placed', + 'yaw', 'pitch', 'saved_camera', 'tour_thread', 'tour_stop', + 'viewshed_enabled', 'viewshed_cache', + ) + + def __init__(self, slot, position, observer_elev=0.05): + self.slot = slot + self.position = position # (x, y) world coords + self.observer_elev = observer_elev + self.drone_mode = 'off' # 'off' | '3rd' | 'fpv' + self.drone_placed = False + self.yaw = 0.0 + self.pitch = 0.0 + self.saved_camera = None # (position, yaw, pitch) + self.tour_thread = None + self.tour_stop = threading.Event() + self.viewshed_enabled = False + self.viewshed_cache = None + + @property + def color(self): + return OBSERVER_COLORS[(self.slot - 1) % len(OBSERVER_COLORS)] + + def geometry_id(self, part_idx): + """Unique geometry ID for a drone sub-mesh, e.g. '_observer3_2'.""" + return f'_observer{self.slot}_{part_idx}' + + def is_touring(self): + return (self.tour_thread is not None and self.tour_thread.is_alive()) + + def stop_tour(self): + self.tour_stop.set() + if self.tour_thread is not None: + self.tour_thread.join(timeout=2.0) + self.tour_thread = None + self.tour_stop.clear() + + +def _bilinear_terrain_z(terrain, vx, vy, psx, psy): + """Sample terrain Z at world positions using bilinear interpolation. + + This matches the interpolation used by the triangle mesh surface, + preventing Z mismatches between placed meshes and the rendered terrain. + + Supports both numpy and cupy arrays — the array module is chosen + automatically based on the type of ``terrain``. + """ + if has_cupy and isinstance(terrain, cp.ndarray): + xp = cp + else: + xp = np + H, W = terrain.shape + cols = vx / psx + rows = vy / psy + cols = xp.clip(cols, 0, W - 1) + rows = xp.clip(rows, 0, H - 1) + x0 = xp.clip(xp.floor(cols).astype(xp.int32), 0, max(W - 2, 0)) + y0 = xp.clip(xp.floor(rows).astype(xp.int32), 0, max(H - 2, 0)) + fx = cols - x0 + fy = rows - y0 + z00 = terrain[y0, x0].astype(xp.float32) + z10 = terrain[y0, xp.minimum(x0 + 1, W - 1)].astype(xp.float32) + z01 = terrain[xp.minimum(y0 + 1, H - 1), x0].astype(xp.float32) + z11 = terrain[xp.minimum(y0 + 1, H - 1), + xp.minimum(x0 + 1, W - 1)].astype(xp.float32) + z00 = xp.where(xp.isnan(z00), 0.0, z00) + z10 = xp.where(xp.isnan(z10), 0.0, z10) + z01 = xp.where(xp.isnan(z01), 0.0, z01) + z11 = xp.where(xp.isnan(z11), 0.0, z11) + return (z00 * (1 - fx) * (1 - fy) + + z10 * fx * (1 - fy) + + z01 * (1 - fx) * fy + + z11 * fx * fy) + + +class _MeshChunkManager: + """Dynamically loads/unloads mesh chunks based on camera position. + + Manages chunk lifecycle: reads per-chunk mesh data from a zarr store, + caches it in memory, and merges visible chunks per geometry ID into + the RTX scene. Only nearby chunks (within ``radius`` of the camera) + are kept in the scene; the rest are removed. + """ + + def __init__(self, zarr_path, psx, psy): + import zarr as _zarr + store = _zarr.open(str(zarr_path), mode='r', use_consolidated=False) + mg = store['meshes'] + + self._elev_shape = tuple(mg.attrs['elevation_shape']) + self._elev_chunks = tuple(mg.attrs['elevation_chunks']) + self._chunk_h, self._chunk_w = self._elev_chunks + self._psx = psx + self._psy = psy + self._n_chunk_rows = (self._elev_shape[0] + self._chunk_h - 1) // self._chunk_h + self._n_chunk_cols = (self._elev_shape[1] + self._chunk_w - 1) // self._chunk_w + + # Per-gid colors from zarr attrs + self._colors = {} + self._gids = [] + for gid in mg: + gg = mg[gid] + if hasattr(gg, 'attrs'): + self._colors[gid] = tuple(gg.attrs.get('color', (0.6, 0.6, 0.6))) + self._gids.append(gid) + + # Cache: (cr, cc) -> {gid: (verts, indices)} or None if empty + self._cache = {} + self._visible = set() + self._active_gids = set() # gids currently in the RTX scene + self.radius = 2 + self._zarr_path = zarr_path + + def _load_chunk(self, cr, cc): + """Load a single chunk from zarr into cache.""" + if (cr, cc) in self._cache: + return + from .mesh_store import load_meshes_from_zarr + meshes, _, _, curves = load_meshes_from_zarr( + self._zarr_path, chunks=[(cr, cc)]) + # Merge curves into the same dict with a marker + combined = {} + for gid, data in meshes.items(): + combined[gid] = data # (verts, indices) + for gid, data in curves.items(): + combined[gid] = data # (verts, widths, indices) + self._cache[(cr, cc)] = combined + + def update(self, cam_x, cam_y, viewer): + """Called per tick. Returns True if meshes changed.""" + # Camera world pos -> chunk coord + cc_cam = int(cam_x / self._psx) // self._chunk_w + cr_cam = int(cam_y / self._psy) // self._chunk_h + + # Compute visible ring clamped to grid + cr0 = max(cr_cam - self.radius, 0) + cr1 = min(cr_cam + self.radius, self._n_chunk_rows - 1) + cc0 = max(cc_cam - self.radius, 0) + cc1 = min(cc_cam + self.radius, self._n_chunk_cols - 1) + + new_visible = set() + for cr in range(cr0, cr1 + 1): + for cc in range(cc0, cc1 + 1): + new_visible.add((cr, cc)) + + if new_visible == self._visible: + return False + + self._visible = new_visible + + # Load any uncached chunks + for cr, cc in new_visible: + self._load_chunk(cr, cc) + + # Merge visible chunks per gid + merged = {} + for gid in self._gids: + all_verts = [] + all_widths = [] + all_indices = [] + vert_offset = 0 + is_curve = False + for cr, cc in sorted(new_visible): + chunk_data = self._cache.get((cr, cc), {}) + if gid not in chunk_data: + continue + data = chunk_data[gid] + if len(data) == 3: + # Curve geometry: (verts, widths, indices) + verts, widths, indices = data + is_curve = True + if len(indices) == 0: + continue + all_widths.append(widths) + else: + verts, indices = data + if len(indices) == 0: + continue + all_indices.append(indices + vert_offset) + all_verts.append(verts) + vert_offset += len(verts) // 3 + if all_verts: + if is_curve: + merged[gid] = (np.concatenate(all_verts), + np.concatenate(all_widths), + np.concatenate(all_indices)) + else: + merged[gid] = (np.concatenate(all_verts), + np.concatenate(all_indices)) + + # Remove gids no longer present + rtx = viewer.rtx + accessor = viewer._accessor + for gid in list(self._active_gids): + if gid not in merged: + rtx.remove_geometry(gid) + if accessor is not None: + accessor._baked_meshes.pop(gid, None) + accessor._geometry_colors.pop(gid, None) + self._active_gids.discard(gid) + + # Get current (possibly subsampled) terrain data + terrain_np = viewer.raster.data + if hasattr(terrain_np, 'get'): + terrain_np = terrain_np.get() + else: + terrain_np = np.asarray(terrain_np) + H, W = terrain_np.shape + ve = viewer.vertical_exaggeration + + # Get full-res terrain for computing original base_z + base_terrain = viewer._base_raster.data + if hasattr(base_terrain, 'get'): + base_terrain_np = base_terrain.get() + else: + base_terrain_np = np.asarray(base_terrain) + base_psx = viewer._base_pixel_spacing_x + base_psy = viewer._base_pixel_spacing_y + + # Upload terrain to GPU once (use cached if available) + gpu_terrain = None + gpu_base_terrain = None + if has_cupy: + if viewer._gpu_terrain is None: + viewer._gpu_terrain = cp.asarray(terrain_np) + gpu_terrain = viewer._gpu_terrain + if viewer._gpu_base_terrain is None: + viewer._gpu_base_terrain = cp.asarray(base_terrain_np) + gpu_base_terrain = viewer._gpu_base_terrain + + # Add/update merged gids + for gid, data in merged.items(): + is_curve = len(data) == 3 + if is_curve: + verts, widths, indices = data + else: + verts, indices = data + + # Re-snap Z coordinates to current terrain surface + VE. + # Meshes from zarr have Z computed from the full-res terrain. + # When terrain is subsampled, the rendered surface differs from + # the full-res values, so we re-anchor each vertex's height + # offset onto the current terrain using bilinear interpolation. + n_verts = len(verts) // 3 + use_gpu = (gpu_terrain is not None + and gpu_base_terrain is not None + and n_verts > 1000) + + if use_gpu: + vx = cp.asarray(verts[0::3]) + vy = cp.asarray(verts[1::3]) + vz_stored = cp.asarray(verts[2::3]) + + orig_base_z_gpu = _bilinear_terrain_z( + gpu_base_terrain, vx, vy, base_psx, base_psy) + z_offset = vz_stored - orig_base_z_gpu + + new_base_z = _bilinear_terrain_z( + gpu_terrain, vx, vy, + viewer.pixel_spacing_x, viewer.pixel_spacing_y) + + updated_verts_gpu = cp.asarray(verts.copy()) + updated_verts_gpu[2::3] = (new_base_z + z_offset) * ve + + if is_curve: + rtx.add_curve_geometry( + gid, updated_verts_gpu, + cp.asarray(widths), cp.asarray(indices)) + else: + rtx.add_geometry(gid, updated_verts_gpu, cp.asarray(indices)) + self._active_gids.add(gid) + + if accessor is not None: + accessor._geometry_colors[gid] = self._colors.get(gid, (0.6, 0.6, 0.6)) + orig_base_z_np = orig_base_z_gpu.get() + if is_curve: + accessor._baked_meshes[gid] = ( + verts.copy(), widths.copy(), indices.copy(), orig_base_z_np) + else: + accessor._baked_meshes[gid] = (verts.copy(), indices.copy(), orig_base_z_np) + else: + vx = verts[0::3] + vy = verts[1::3] + vz_stored = verts[2::3].copy() + + orig_base_z = _bilinear_terrain_z( + base_terrain_np, vx, vy, base_psx, base_psy) + z_offset = vz_stored - orig_base_z + + new_base_z = _bilinear_terrain_z( + terrain_np, vx, vy, + viewer.pixel_spacing_x, viewer.pixel_spacing_y) + + updated_verts = verts.copy() + updated_verts[2::3] = (new_base_z + z_offset) * ve + + if is_curve: + rtx.add_curve_geometry(gid, updated_verts, widths, indices) + else: + rtx.add_geometry(gid, updated_verts, indices) + self._active_gids.add(gid) + + if accessor is not None: + accessor._geometry_colors[gid] = self._colors.get(gid, (0.6, 0.6, 0.6)) + if is_curve: + accessor._baked_meshes[gid] = ( + verts.copy(), widths.copy(), indices.copy(), orig_base_z) + else: + accessor._baked_meshes[gid] = (verts.copy(), indices.copy(), orig_base_z) + + if accessor is not None: + accessor._geometry_colors_dirty = True + + # Refresh viewer geometry tracking (same pattern as FIRMS toggle) + viewer._all_geometries = rtx.list_geometries() + groups = set() + for g in viewer._all_geometries: + parts = g.rsplit('_', 1) + if len(parts) == 2 and parts[1].isdigit(): + base = parts[0] + else: + base = g + if base != 'terrain': + groups.add(base) + viewer._geometry_layer_order = ['none', 'all'] + sorted(groups) + + # Apply current visibility mode + layer_idx = viewer._geometry_layer_idx + if layer_idx < len(viewer._geometry_layer_order): + layer_name = viewer._geometry_layer_order[layer_idx] + else: + layer_name = 'none' + viewer._geometry_layer_idx = 0 + + for geom_id in viewer._all_geometries: + if geom_id == 'terrain': + continue + if layer_name == 'none': + rtx.set_geometry_visible(geom_id, False) + elif layer_name == 'all': + rtx.set_geometry_visible(geom_id, True) + else: + parts = geom_id.rsplit('_', 1) + base_name = parts[0] if len(parts) == 2 and parts[1].isdigit() else geom_id + visible = (base_name == layer_name or geom_id == layer_name) + rtx.set_geometry_visible(geom_id, visible) + + n_tris = 0 + n_segs = 0 + for g in merged: + if len(merged[g]) == 3: + n_segs += len(merged[g][2]) + else: + n_tris += len(merged[g][1]) // 3 + parts = [] + if n_tris > 0: + parts.append(f"{n_tris:,} triangles") + if n_segs > 0: + parts.append(f"{n_segs:,} curve segments") + print(f"Mesh chunks: loaded {len(new_visible)} chunks, " + f"{len(merged)} geometries ({', '.join(parts)})") + return True + + +class ViewerProxy: + """Thread-safe handle to the running InteractiveViewer. + + Exposed as ``v`` (and ``viewer``) in the REPL started by + ``explore(repl=True)``. Methods push callables onto a queue that + the main GLFW thread drains each tick, so OptiX calls always + happen on the correct thread. + """ + + def __init__(self, viewer: 'InteractiveViewer'): + self._viewer = viewer + self._q = viewer._command_queue + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _submit(self, fn): + """Push *fn* to the render thread and block until it completes. + + ``fn(viewer)`` is called on the main thread. Returns whatever + ``fn`` returns. + """ + result = [None] + error = [None] + event = threading.Event() + + def _wrapped(viewer): + try: + result[0] = fn(viewer) + except Exception as exc: + error[0] = exc + finally: + event.set() + + self._q.put(_wrapped) + event.wait(timeout=30) + if error[0] is not None: + raise error[0] + return result[0] + + def _submit_fire_and_forget(self, fn): + """Push *fn* without blocking the caller.""" + self._q.put(fn) + + # ------------------------------------------------------------------ + # Read-only state (safe to access from any thread) + # ------------------------------------------------------------------ + + @property + def raster(self): + """The current terrain DataArray (may be subsampled).""" + return self._viewer.raster + + @property + def base_raster(self): + """The full-resolution terrain DataArray.""" + return self._viewer._base_raster + + @property + def position(self): + """Current camera position ``(x, y, z)`` as a numpy array copy.""" + return self._viewer.position.copy() + + @property + def yaw(self): + return self._viewer.yaw + + @property + def pitch(self): + return self._viewer.pitch + + @property + def shadows(self): + return self._viewer.shadows + + @shadows.setter + def shadows(self, value): + def fn(v): + v.shadows = bool(value) + v._update_frame() + self._submit(fn) + + @property + def colormap(self): + return self._viewer.colormap + + @colormap.setter + def colormap(self, value): + self.set_colormap(value) + + @property + def vertical_exaggeration(self): + return self._viewer.vertical_exaggeration + + @property + def overlay_names(self): + """List of available overlay layer names.""" + return list(self._viewer._overlay_layers.keys()) + + # ------------------------------------------------------------------ + # Terrain analysis (run on main thread → display result) + # ------------------------------------------------------------------ + + def hillshade(self, **kwargs): + """Compute hillshade on the current terrain and show it.""" + def fn(v): + acc = v._accessor + if acc is None: + print("No accessor — cannot compute hillshade") + return + data = acc.hillshade(**kwargs) + _add_overlay(v, 'hillshade', data.data) + self._submit(fn) + + def viewshed(self, x, y, observer_elev=2, **kwargs): + """Compute viewshed and show it as an overlay.""" + def fn(v): + acc = v._accessor + if acc is None: + print("No accessor — cannot compute viewshed") + return + data = acc.viewshed(x=x, y=y, observer_elev=observer_elev, + **kwargs) + _add_overlay(v, 'viewshed', data.data) + self._submit(fn) + + def slope(self, **kwargs): + """Compute slope and show it as an overlay.""" + def fn(v): + acc = v._accessor + if acc is None: + print("No accessor — cannot compute slope") + return + data = acc.slope(**kwargs) + _add_overlay(v, 'slope', data.data) + self._submit(fn) + + def aspect(self, **kwargs): + """Compute aspect and show it as an overlay.""" + def fn(v): + acc = v._accessor + if acc is None: + print("No accessor — cannot compute aspect") + return + data = acc.aspect(**kwargs) + _add_overlay(v, 'aspect', data.data) + self._submit(fn) + + # ------------------------------------------------------------------ + # Layer management + # ------------------------------------------------------------------ + + def add_layer(self, name, data): + """Add (or replace) a named overlay layer and switch to it. + + Parameters + ---------- + name : str + Layer name shown when cycling with G. + data : array-like + 2-D array (numpy or cupy) matching the terrain shape. + """ + def fn(v): + _add_overlay(v, name, data) + self._submit(fn) + + def remove_layer(self, name): + """Remove an overlay layer by name.""" + def fn(v): + if name in v._overlay_layers: + del v._overlay_layers[name] + if name in v._base_overlay_layers: + del v._base_overlay_layers[name] + v._overlay_names = list(v._overlay_layers.keys()) + v._terrain_layer_order = ( + ['elevation'] + list(v._overlay_names)) + # Reset to elevation if we removed the active layer + if v._terrain_layer_idx >= len(v._terrain_layer_order): + v._terrain_layer_idx = 0 + v._active_overlay_data = None + v._update_frame() + print(f"Removed layer: {name}") + self._submit(fn) + + def show_layer(self, name): + """Switch the terrain coloring to a named layer (or 'elevation').""" + def fn(v): + if name == 'elevation': + v._active_color_data = None + v._active_overlay_data = None + v._terrain_layer_idx = 0 + v._update_frame() + print("Terrain: elevation") + return + if name not in v._overlay_layers: + print(f"Unknown layer: {name}. " + f"Available: {list(v._overlay_layers.keys())}") + return + idx = v._terrain_layer_order.index(name) + v._terrain_layer_idx = idx + v._active_color_data = None + v._active_overlay_data = v._overlay_layers[name] + v._update_frame() + print(f"Terrain: {name}") + self._submit(fn) + + # ------------------------------------------------------------------ + # Display settings + # ------------------------------------------------------------------ + + def set_colormap(self, cmap): + """Change the active colormap by name (e.g. 'terrain', 'viridis').""" + def fn(v): + v.colormap = cmap + v._update_frame() + print(f"Colormap: {cmap}") + self._submit(fn) + + def set_color_stretch(self, stretch): + """Set the color stretch ('linear', 'sqrt', 'cbrt', 'log').""" + def fn(v): + if stretch in v._color_stretches: + v.color_stretch = stretch + v._color_stretch_idx = v._color_stretches.index(stretch) + v._update_frame() + print(f"Color stretch: {stretch}") + else: + print(f"Unknown stretch: {stretch}. " + f"Options: {v._color_stretches}") + self._submit(fn) + + # ------------------------------------------------------------------ + # Misc + # ------------------------------------------------------------------ + + def screenshot(self): + """Save a screenshot. Uses the viewer's built-in logic.""" + def fn(v): + v._save_screenshot() + self._submit(fn) + + def run(self, fn): + """Execute an arbitrary callable on the render thread. + + ``fn(viewer)`` receives the ``InteractiveViewer`` instance. + Use this for anything not covered by a convenience method. + """ + return self._submit(fn) + + # ------------------------------------------------------------------ + # Tour / scripting + # ------------------------------------------------------------------ + + def mark(self): + """Capture current camera state as a keyframe dict. + + Fly to desired positions manually, call ``v.mark()`` to record + them, then assemble the results into a tour keyframe list. + """ + from .tour import mark_camera + kf = mark_camera(self) + import pprint + pprint.pprint(kf) + return kf + + def tour(self, keyframes, fps=30, record=False, output_dir='.', + loop=False): + """Play a scripted camera tour. + + Parameters + ---------- + keyframes : list of dict or str + List of keyframe dicts, or path to a ``.py`` file that + defines a ``tour`` variable containing the keyframe list. + fps : int + Target playback framerate. + record : bool + Save each frame as a PNG for video assembly. + output_dir : str or Path + Directory for recorded frames. + loop : bool + Repeat the tour indefinitely until the viewer closes. + """ + from .tour import play_tour + if isinstance(keyframes, str): + ns = {} + with open(keyframes) as f: + exec(f.read(), ns) + keyframes = ns['tour'] + if 'loop' in ns: + loop = ns['loop'] + play_tour(self, keyframes, fps=fps, record=record, + output_dir=output_dir, loop=loop) + + def show_geometry(self, name): + """Show only a specific geometry group (or ``'all'`` / ``'none'``). + + Parameters + ---------- + name : str + Geometry group name, ``'all'`` to show everything, or + ``'none'`` to hide all non-terrain geometries. + """ + def fn(v): + if name == 'all': + for gid in v._all_geometries: + v.rtx.set_geometry_visible(gid, True) + print("Geometry: all") + elif name == 'none': + for gid in v._all_geometries: + if gid != 'terrain': + v.rtx.set_geometry_visible(gid, False) + print("Geometry: none") + else: + visible_count = 0 + for gid in v._all_geometries: + parts = gid.rsplit('_', 1) + base = parts[0] if len(parts) == 2 and parts[1].isdigit() else gid + visible = (base == name or gid == name or gid == 'terrain') + v.rtx.set_geometry_visible(gid, visible) + if visible: + visible_count += 1 + print(f"Geometry: {name} ({visible_count} visible)") + v._update_frame() + self._submit(fn) + + # ------------------------------------------------------------------ + # Multi-observer API + # ------------------------------------------------------------------ + + def place_observer(self, slot, x=None, y=None): + """Create or move an observer in *slot* (1-8). + + Defaults to the current camera position if *x*/*y* are omitted. + """ + def fn(v): + # Observer is defined at module level in this file + if slot not in v._observers: + obs = Observer(slot, position=None, + observer_elev=v.viewshed_observer_elev) + v._observers[slot] = obs + obs = v._observers[slot] + v._place_observer_at(obs, x=x, y=y) + v._active_observer = slot + self._submit(fn) + + def remove_observer(self, slot): + """Remove an observer from *slot*.""" + def fn(v): + v._clear_observer_slot(slot) + self._submit(fn) + + def remove_all_observers(self): + """Kill all observers — stop tours, exit drone modes, remove all.""" + def fn(v): + v._clear_all_observers() + self._submit(fn) + + def select_observer(self, slot): + """Select an observer slot for keyboard control.""" + def fn(v): + if slot in v._observers: + v._active_observer = slot + print(f"Observer {slot}: selected") + else: + print(f"Observer {slot} does not exist") + self._submit(fn) + + def observer_tour(self, slot, keyframes, fps=30, loop=False): + """Run a tour on an observer's drone. + + Parameters + ---------- + slot : int + Observer slot (1-8). Auto-created at first keyframe if needed. + keyframes : list of dict or str + Keyframe list, or path to a ``.py`` file containing a ``tour`` + variable. + fps : int + Target playback framerate. + loop : bool + Repeat indefinitely until stopped. + """ + import threading as _threading + from .tour import play_observer_tour + + if isinstance(keyframes, str): + ns = {} + with open(keyframes) as f: + exec(f.read(), ns) + keyframes = ns['tour'] + if 'loop' in ns: + loop = ns['loop'] + + # Auto-create observer at first keyframe position if needed + first_pos = None + for kf in keyframes: + if 'position' in kf: + first_pos = kf['position'] + break + + def _setup(v): + # Observer is defined at module level in this file + if slot in v._observers: + obs = v._observers[slot] + obs.stop_tour() + else: + obs = Observer(slot, position=None, + observer_elev=v.viewshed_observer_elev) + v._observers[slot] = obs + if obs.position is None and first_pos is not None: + v._place_observer_at(obs, x=first_pos[0], y=first_pos[1]) + elif obs.position is None: + v._place_observer_at(obs) + + self._submit(_setup) + + def _tour_thread(): + play_observer_tour(self, slot, keyframes, fps=fps, loop=loop) + + obs = self._viewer._observers.get(slot) + if obs is not None: + obs.tour_stop.clear() + t = _threading.Thread(target=_tour_thread, daemon=True) + obs.tour_thread = t + t.start() + + def stop_observer_tour(self, slot): + """Stop a running tour on observer *slot*.""" + obs = self._viewer._observers.get(slot) + if obs is not None: + obs.stop_tour() + print(f"Observer {slot} tour stopped") + + def observer_position(self, slot): + """Get an observer's current (x, y, z) position.""" + obs = self._viewer._observers.get(slot) + if obs is None: + return None + ox, oy = obs.position + tz = self._viewer._get_terrain_z(ox, oy) + return (ox, oy, tz + obs.observer_elev) + + def __repr__(self): + v = self._viewer + layers = ', '.join(v._overlay_layers.keys()) or '(none)' + obs_info = '' + if v._observers: + obs_info = f", observers={list(v._observers.keys())}" + return (f"ViewerProxy(layers=[{layers}], " + f"colormap={v.colormap!r}, " + f"shadows={v.shadows}{obs_info})") + + +def _add_overlay(viewer, name, data): + """Add or replace an overlay layer on *viewer* and switch to it. + + Must be called on the main (render) thread. + """ + viewer._overlay_layers[name] = data + viewer._base_overlay_layers[name] = data + viewer._overlay_names = list(viewer._overlay_layers.keys()) + viewer._terrain_layer_order = ( + ['elevation'] + list(viewer._overlay_names)) + idx = viewer._terrain_layer_order.index(name) + viewer._terrain_layer_idx = idx + viewer._active_color_data = None + viewer._active_overlay_data = data + viewer._update_frame() + print(f"Terrain: {name}") + + class InteractiveViewer: """ - Interactive terrain viewer using matplotlib. + Interactive terrain viewer using GLFW + ModernGL. Provides keyboard-controlled camera for exploring ray-traced terrain. - Uses matplotlib's event system for input handling. + Uses GLFW for windowing/input and ModernGL for GPU texture display. Controls -------- @@ -36,7 +976,7 @@ class InteractiveViewer: - +/=: Increase speed - -: Decrease speed - G: Cycle terrain color (elevation → overlays) - - U: Cycle basemap (none → satellite → osm → topo) + - U: Cycle basemap (none → satellite → osm) - N: Cycle geometry layer (none → all → groups) - P: Jump to previous geometry in current group - ,/.: Decrease/increase overlay alpha (transparency) @@ -51,9 +991,13 @@ class InteractiveViewer: - B: Toggle mesh type (TIN / voxel) - Y: Cycle color stretch (linear, sqrt, cbrt, log) - T: Toggle shadows + - 0: Toggle ambient occlusion (progressive) + - Shift+G: Cycle GI bounces (1→2→3→1) + - Shift+D: Toggle OptiX AI Denoiser - C: Cycle colormap - Shift+F: Fetch/toggle FIRMS fire layer (7d LANDSAT 30m) - Shift+W: Toggle wind particle animation + - Shift+B: Toggle GTFS-RT realtime vehicle overlay - F: Save screenshot - M: Toggle minimap overlay - H: Toggle help overlay @@ -69,7 +1013,7 @@ def __init__(self, raster, width: int = 800, height: int = 600, render_scale: float = 0.5, key_repeat_interval: float = 0.05, rtx: 'RTX' = None, pixel_spacing_x: float = 1.0, pixel_spacing_y: float = 1.0, - mesh_type: str = 'tin', + mesh_type: str = 'heightfield', overlay_layers: dict = None, title: str = None, subsample: int = 1): @@ -128,6 +1072,16 @@ def __init__(self, raster, width: int = 800, height: int = 600, self.subsample_factor = max(1, int(subsample)) self._terrain_mesh_cache = {} # (factor, mesh_type) -> (verts_base, indices, terrain_np) self._baked_mesh_cache = {} # (factor, geom_id) -> (scaled_v, orig_idx) + self._chunk_manager = None # set by explore() when scene_zarr provided + + # GPU terrain cache for accelerated mesh Z re-snapping + self._gpu_terrain = None # CuPy array of current (subsampled) terrain + self._gpu_base_terrain = None # CuPy array of full-res terrain (stable) + + # Async readback: non-blocking stream + pinned host buffer + self._readback_stream = cp.cuda.Stream(non_blocking=True) + self._pinned_mem = None + self._pinned_frame = None # Apply initial subsample to the working raster if self.subsample_factor > 1: @@ -165,7 +1119,7 @@ def __init__(self, raster, width: int = 800, height: int = 600, self._terrain_layer_idx = 0 # Independent basemap cycling (U key) - self._basemap_options = ['none', 'satellite', 'osm', 'topo'] + self._basemap_options = ['none', 'satellite', 'osm'] self._basemap_idx = 0 # Title / name for display @@ -226,6 +1180,14 @@ def __init__(self, raster, width: int = 800, height: int = 600, # Rendering settings self.fov = 60.0 + self._time_presets = [ + ('Morning', 135.0, 25.0), + ('Midday', 180.0, 65.0), + ('Afternoon', 225.0, 35.0), + ('Golden Hour', 270.0, 12.0), + ('Sunset', 280.0, 3.0), + ] + self._time_preset_idx = 2 # Afternoon (default) self.sun_azimuth = 225.0 self.sun_altitude = 35.0 self.shadows = True @@ -235,6 +1197,27 @@ def __init__(self, raster, width: int = 800, height: int = 600, self.colormap_idx = 0 self.color_stretch = 'linear' + # Ambient occlusion state + self.ao_enabled = False + self.ao_radius = None # auto-computed from scene extent + self.gi_intensity = 2.0 # GI bounce intensity multiplier + self.gi_bounces = 1 # Number of GI bounces (1=single, 2-3=multi) + self._ao_samples_per_frame = 4 # AO rays per pixel per frame + self._ao_max_frames = 32 # stop accumulating after this many frames + self._ao_frame_count = 0 + self._d_ao_accum = None # GPU accumulation buffer (H, W, 3) float32 + self._prev_cam_state = None # (position_tuple, yaw, pitch, fov) for dirty detection + + # Denoiser state + self.denoise_enabled = False + self._prev_cam_for_flow = None # (pos, forward, right, up, aspect, fov_scale) from prev frame + self._d_flow = None # (H, W, 2) float32 motion vectors + + # Depth of field state + self.dof_enabled = False + self._dof_aperture = 20.0 # lens radius in scene units + self._dof_focal_distance = 1000.0 # focal plane distance (= look_at distance) + # Tile overlay settings self._tile_service = None self._tiles_enabled = False @@ -249,13 +1232,10 @@ def __init__(self, raster, width: int = 800, height: int = 600, self._viewshed_coverage = 0.0 # Percentage of terrain visible self._viewshed_recalc_interval = 0.4 # Seconds between dynamic recalcs self._last_viewshed_time = 0.0 # Timestamp of last viewshed calc - self._observer_position = None # Fixed observer position (x, y) in terrain coords - self._observer_drone_parts = None # List of (verts, idxs, (r,g,b)) per sub-mesh - self._observer_drone_placed = False # Whether drone geometry is in the scene - self._drone_mode = 'off' # 'off' | '3rd' | 'fpv' - self._saved_camera = None # (position, yaw, pitch) before entering drone mode - self._drone_yaw = 0.0 # Drone heading (for 3rd-person flight) - self._drone_pitch = 0.0 # Drone pitch (for 3rd-person flight) + # Multi-observer system (up to 8 independent observers) + self._observers = {} # dict[int, Observer] — slot 1-8 + self._active_observer = None # int (slot 1-8) or None + self._shared_drone_parts = None # loaded once from drone.glb, shared by all # State self.running = False @@ -265,16 +1245,16 @@ def __init__(self, raster, width: int = 800, height: int = 600, self._last_title = None self._last_subtitle = None - # Minimap state (initialized in run() via _compute_minimap_background/_create_minimap) - self._minimap_ax = None - self._minimap_im = None - self._minimap_camera_dot = None - self._minimap_direction_line = None - self._minimap_fov_wedge = None - self._minimap_observer_dot = None + # Minimap state (initialized in run() via _compute_minimap_background) self._minimap_background = None self._minimap_scale_x = 1.0 self._minimap_scale_y = 1.0 + self._minimap_has_tiles = False + self._minimap_rect = None # (x0, y0, w, h) in frame coords + self._drone_glow = False + + # Help text cache (pre-rendered RGBA numpy array via PIL) + self._help_text_rgba = None # FIRMS fire layer state self._accessor = None # RTX accessor for place_geojson @@ -288,19 +1268,50 @@ def __init__(self, raster, width: int = 800, height: int = 600, self._wind_v_px = None # (H, W) V component in pixels/tick self._wind_particles = None # (N, 2) particle positions in pixel coords (row, col) self._wind_ages = None # (N,) age in ticks - self._wind_max_age = 120 # Max lifetime before respawn - self._wind_n_particles = 6000 + self._wind_max_age = 80 # Max lifetime before respawn + self._wind_n_particles = 10000 self._wind_trail_len = 20 # Number of trail positions to keep self._wind_trails = None # (N, trail_len, 2) ring buffer of past positions - self._wind_speed_mult = 50.0 # Velocity exaggeration for visibility + self._wind_speed_mult = 250.0 # Velocity exaggeration for visibility self._wind_min_depth = 0.0 # Min camera distance to render (set in _init_wind) - self._wind_dot_radius = 3 # Radius of each particle dot in screen pixels - self._wind_alpha = 0.035 # Per-pixel alpha for particle dots + self._wind_dot_radius = 2 # Radius of each particle dot in screen pixels + self._wind_alpha = 0.055 # Per-pixel alpha for particle dots + self._wind_min_visible_age = 6 # Ticks before particle becomes visible (builds trail first) + self._wind_terrain_np = None # Cached CPU terrain for wind Z lookup + + # GTFS-RT realtime vehicle overlay state + self._gtfs_rt_url = None + self._gtfs_rt_enabled = False + self._gtfs_rt_vehicles = None # (positions, bearings, colors) tuple + self._gtfs_rt_poll_interval = 15.0 + self._gtfs_rt_thread = None # daemon Thread + self._gtfs_rt_stop = threading.Event() + self._gtfs_rt_lock = threading.Lock() + self._gtfs_rt_route_colors = {} # {route_id: (r,g,b)} + self._gtfs_rt_dot_radius = 4 # Screen pixels per vehicle dot + self._gtfs_rt_alpha = 0.85 # Dot alpha # Held keys tracking for smooth simultaneous input self._held_keys = set() - self._tick_interval = int(key_repeat_interval * 1000) # Convert to ms for timer - self._timer = None + + # GLFW window handle (set in run()) + self._glfw_window = None + self._display_frame = None + self._render_needed = True # Flag: something changed, need to re-render + + # REPL command queue — background REPL thread pushes callables, + # main loop drains and executes them on the render thread. + self._command_queue = queue.Queue() + self._repl = False + + # FPS tracking + self._fps_counter = 0 + self._fps_last_time = 0.0 + self._fps_display = 0.0 + + # Delta-time for frame-rate-independent movement + self._last_tick_time = 0.0 # set in run() + self._dt_scale = 1.0 # multiplier: actual_dt / reference_dt(0.05) # Mouse drag state for slippy-map panning self._mouse_dragging = False @@ -323,19 +1334,6 @@ def __init__(self, raster, width: int = 800, height: int = 600, self._coord_step_x = float(raster.x.values[1] - raster.x.values[0]) self._coord_step_y = float(raster.y.values[1] - raster.y.values[0]) - # Get terrain info - H, W = raster.shape - terrain_data = raster.data - if hasattr(terrain_data, 'get'): - terrain_np = terrain_data.get() - else: - terrain_np = np.asarray(terrain_data) - - self.terrain_shape = (H, W) - self.elev_min = float(np.nanmin(terrain_np)) - self.elev_max = float(np.nanmax(terrain_np)) - self.elev_mean = float(np.nanmean(terrain_np)) - # Build water mask from *full-resolution* base raster (not subsampled) # so it can be applied to full-resolution overlay layers. base_data = self._base_raster.data @@ -343,11 +1341,50 @@ def __init__(self, raster, width: int = 800, height: int = 600, base_np = base_data.get() else: base_np = np.asarray(base_data) + + # Detect ocean-fill: global DEMs (Copernicus, SRTM) fill ocean with + # exactly 0.0 instead of NaN/nodata. Replace with NaN so the render + # kernel ocean water shader activates over true ocean areas. + ocean_fill = (base_np == 0.0) & ~np.isnan(base_np) + n_ocean_fill = int(ocean_fill.sum()) + if n_ocean_fill > base_np.size * 0.01: + base_np[ocean_fill] = np.nan # local copy for water_mask below + # Create a copy of the raster data with NaN-marked ocean + if hasattr(self._base_raster.data, 'get'): # cupy + new_data = self._base_raster.data.copy() + new_data[cp.asarray(ocean_fill)] = cp.nan + else: + new_data = self._base_raster.data.copy() + new_data[ocean_fill] = np.nan + self._base_raster = self._base_raster.copy(data=new_data) + # Re-derive working raster from updated base + if self.subsample_factor > 1: + f = self.subsample_factor + self.raster = self._base_raster.isel({ + self._base_raster.dims[0]: slice(None, None, f), + self._base_raster.dims[1]: slice(None, None, f) + }) + else: + self.raster = self._base_raster + floor_val = float(np.nanmin(base_np)) floor_max = float(np.nanmax(base_np)) eps = (floor_max - floor_val) * 1e-4 if floor_max > floor_val else 1e-6 self._water_mask = (base_np <= floor_val + eps) | np.isnan(base_np) + # Get terrain info (after ocean-fill → NaN replacement) + H, W = self.raster.shape + terrain_data = self.raster.data + if hasattr(terrain_data, 'get'): + terrain_np = terrain_data.get() + else: + terrain_np = np.asarray(terrain_data) + + self.terrain_shape = (H, W) + self.elev_min = float(np.nanmin(terrain_np)) + self.elev_max = float(np.nanmax(terrain_np)) + self.elev_mean = float(np.nanmean(terrain_np)) + # Compute land-only elevation range for coloring (excludes water) land_pixels = base_np[~self._water_mask] if land_pixels.size > 0: @@ -386,32 +1423,43 @@ def __init__(self, raster, width: int = 800, height: int = 600, # units), producing wrong results when pixel_spacing != 1. if rtx is not None and not rtx.has_geometry('terrain'): from . import mesh as mesh_mod - if mesh_type == 'voxel': - nv = H * W * 8 - nt = H * W * 12 - verts = np.zeros(nv * 3, dtype=np.float32) - idxs = np.zeros(nt * 3, dtype=np.int32) - base_elev = float(np.nanmin(terrain_np)) - mesh_mod.voxelate_terrain(verts, idxs, raster, scale=1.0, - base_elevation=base_elev) + if mesh_type == 'heightfield': + rtx.add_heightfield_geometry( + 'terrain', terrain_np, H, W, + spacing_x=self.pixel_spacing_x, + spacing_y=self.pixel_spacing_y, + ve=1.0, + ) + cache_key = (self.subsample_factor, mesh_type) + self._terrain_mesh_cache[cache_key] = ( + None, None, terrain_np.copy(), + ) else: - nv = H * W - nt = (H - 1) * (W - 1) * 2 - verts = np.zeros(nv * 3, dtype=np.float32) - idxs = np.zeros(nt * 3, dtype=np.int32) - mesh_mod.triangulate_terrain(verts, idxs, raster, scale=1.0) - - if self.pixel_spacing_x != 1.0 or self.pixel_spacing_y != 1.0: - verts[0::3] *= self.pixel_spacing_x - verts[1::3] *= self.pixel_spacing_y - - # Cache the initial terrain mesh at scale=1.0 (before VE) - cache_key = (self.subsample_factor, mesh_type) - self._terrain_mesh_cache[cache_key] = ( - verts.copy(), idxs.copy(), terrain_np.copy(), - ) + if mesh_type == 'voxel': + nv = H * W * 8 + nt = H * W * 12 + verts = np.zeros(nv * 3, dtype=np.float32) + idxs = np.zeros(nt * 3, dtype=np.int32) + base_elev = float(np.nanmin(terrain_np)) + mesh_mod.voxelate_terrain(verts, idxs, raster, scale=1.0, + base_elevation=base_elev) + else: + nv = H * W + nt = (H - 1) * (W - 1) * 2 + verts = np.zeros(nv * 3, dtype=np.float32) + idxs = np.zeros(nt * 3, dtype=np.int32) + mesh_mod.triangulate_terrain(verts, idxs, raster, scale=1.0) + + if self.pixel_spacing_x != 1.0 or self.pixel_spacing_y != 1.0: + verts[0::3] *= self.pixel_spacing_x + verts[1::3] *= self.pixel_spacing_y + + cache_key = (self.subsample_factor, mesh_type) + self._terrain_mesh_cache[cache_key] = ( + verts.copy(), idxs.copy(), terrain_np.copy(), + ) - rtx.add_geometry('terrain', verts, idxs) + rtx.add_geometry('terrain', verts, idxs) def _get_front(self): """Get the forward direction vector.""" @@ -480,15 +1528,28 @@ def _build_title(self): if not self.shadows: parts.append('no shadows') + # Ambient occlusion + if self.ao_enabled: + total = self._ao_frame_count * self._ao_samples_per_frame + cap = self._ao_max_frames * self._ao_samples_per_frame + parts.append(f'AO {total}/{cap}') + + # Denoiser + if self.denoise_enabled: + parts.append('DENOISE') + # Wind if self._wind_enabled: parts.append('wind') - # Drone mode - if self._drone_mode == 'fpv': - parts.append('DRONE FPV') - elif self._drone_mode == '3rd': - parts.append('DRONE 3RD') + # Active observer drone mode + active_obs = (self._observers.get(self._active_observer) + if self._active_observer else None) + if active_obs is not None: + if active_obs.drone_mode == 'fpv': + parts.append(f'OBS{active_obs.slot} FPV') + elif active_obs.drone_mode == '3rd': + parts.append(f'OBS{active_obs.slot} 3RD') return ' \u2502 '.join(parts) @@ -567,6 +1628,23 @@ def _compute_minimap_background(self): rgba[:, :, :3] = np.clip(rgba[:, :, :3], 0, 1) + # Blend satellite imagery if tile service has fetched tiles + if (self._tile_service is not None + and getattr(self._tile_service, '_fetched', None)): + cpu_tex = getattr(self._tile_service, '_rgb_texture', None) + if cpu_tex is not None and cpu_tex.shape[0] == H and cpu_tex.shape[1] == W: + y_idx_t = np.linspace(0, H - 1, new_h).astype(int) + x_idx_t = np.linspace(0, W - 1, new_w).astype(int) + sat_small = cpu_tex[np.ix_(y_idx_t, x_idx_t)] # (new_h, new_w, 3) + # Only blend where satellite has actual data (not all-black) + has_coverage = sat_small.max(axis=2) > 0.01 + blended = np.zeros_like(rgba[:, :, :3]) + for c in range(3): + blended[:, :, c] = sat_small[:, :, c] * 0.7 + rgba[:, :, c] * 0.3 + for c in range(3): + rgba[:, :, c] = np.where(has_coverage, blended[:, :, c], rgba[:, :, c]) + self._minimap_has_tiles = True + self._minimap_background = rgba self._minimap_scale_x = new_w / W self._minimap_scale_y = new_h / H @@ -598,6 +1676,7 @@ def _rebuild_at_resolution(self, factor): sub = base self.raster = sub + self._wind_terrain_np = None # invalidate cached terrain H, W = sub.shape self.terrain_shape = (H, W) @@ -609,48 +1688,75 @@ def _rebuild_at_resolution(self, factor): ve = self.vertical_exaggeration cache_key = (factor, self.mesh_type) - if cache_key in self._terrain_mesh_cache: - # Cache hit — reuse pre-built mesh (stored at scale=1.0) - verts_base, indices, terrain_np = self._terrain_mesh_cache[cache_key] - vertices = verts_base.copy() - if ve != 1.0: - vertices[2::3] *= ve - else: - # Cache miss — build mesh at scale=1.0 and cache it - terrain_data = sub.data - if hasattr(terrain_data, 'get'): - terrain_np = terrain_data.get() + if self.mesh_type == 'heightfield': + # Heightfield path: no triangle mesh needed + if cache_key in self._terrain_mesh_cache: + _, _, terrain_np = self._terrain_mesh_cache[cache_key] else: - terrain_np = np.asarray(terrain_data) + terrain_data = sub.data + if hasattr(terrain_data, 'get'): + terrain_np = terrain_data.get() + else: + terrain_np = np.asarray(terrain_data) + self._terrain_mesh_cache[cache_key] = ( + None, None, terrain_np.copy(), + ) - if self.mesh_type == 'voxel': - num_verts = H * W * 8 - num_tris = H * W * 12 - vertices = np.zeros(num_verts * 3, dtype=np.float32) - indices = np.zeros(num_tris * 3, dtype=np.int32) - base_elev = float(np.nanmin(terrain_np)) - mesh_mod.voxelate_terrain(vertices, indices, sub, scale=1.0, - base_elevation=base_elev) + if self.rtx is not None: + self.rtx.add_heightfield_geometry( + 'terrain', terrain_np, H, W, + spacing_x=self.pixel_spacing_x, + spacing_y=self.pixel_spacing_y, + ve=ve, + ) + else: + if cache_key in self._terrain_mesh_cache: + # Cache hit — reuse pre-built mesh (stored at scale=1.0) + verts_base, indices, terrain_np = self._terrain_mesh_cache[cache_key] + vertices = verts_base.copy() + if ve != 1.0: + vertices[2::3] *= ve else: - num_verts = H * W - num_tris = (H - 1) * (W - 1) * 2 - vertices = np.zeros(num_verts * 3, dtype=np.float32) - indices = np.zeros(num_tris * 3, dtype=np.int32) - mesh_mod.triangulate_terrain(vertices, indices, sub, scale=1.0) - - # Scale x,y to world units - if self.pixel_spacing_x != 1.0 or self.pixel_spacing_y != 1.0: - vertices[0::3] *= self.pixel_spacing_x - vertices[1::3] *= self.pixel_spacing_y + # Cache miss — build mesh at scale=1.0 and cache it + terrain_data = sub.data + if hasattr(terrain_data, 'get'): + terrain_np = terrain_data.get() + else: + terrain_np = np.asarray(terrain_data) + + if self.mesh_type == 'voxel': + num_verts = H * W * 8 + num_tris = H * W * 12 + vertices = np.zeros(num_verts * 3, dtype=np.float32) + indices = np.zeros(num_tris * 3, dtype=np.int32) + base_elev = float(np.nanmin(terrain_np)) + mesh_mod.voxelate_terrain(vertices, indices, sub, scale=1.0, + base_elevation=base_elev) + else: + num_verts = H * W + num_tris = (H - 1) * (W - 1) * 2 + vertices = np.zeros(num_verts * 3, dtype=np.float32) + indices = np.zeros(num_tris * 3, dtype=np.int32) + mesh_mod.triangulate_terrain(vertices, indices, sub, scale=1.0) + + # Scale x,y to world units + if self.pixel_spacing_x != 1.0 or self.pixel_spacing_y != 1.0: + vertices[0::3] *= self.pixel_spacing_x + vertices[1::3] *= self.pixel_spacing_y + + # Store in cache (scale=1.0, x/y already scaled) + self._terrain_mesh_cache[cache_key] = ( + vertices.copy(), indices.copy(), terrain_np.copy() + ) - # Store in cache (scale=1.0, x/y already scaled) - self._terrain_mesh_cache[cache_key] = ( - vertices.copy(), indices.copy(), terrain_np.copy() - ) + # Apply VE to this copy + if ve != 1.0: + vertices[2::3] *= ve - # Apply VE to this copy - if ve != 1.0: - vertices[2::3] *= ve + # 4. Replace terrain geometry (add_geometry overwrites existing key + # in-place, preserving dict insertion order and instance IDs) + if self.rtx is not None: + self.rtx.add_geometry('terrain', vertices, indices) self.elev_min = float(np.nanmin(terrain_np)) * ve self.elev_max = float(np.nanmax(terrain_np)) * ve @@ -664,11 +1770,6 @@ def _rebuild_at_resolution(self, factor): self._land_color_range = (float(np.nanmin(land_pixels)) * ve, float(np.nanmax(land_pixels)) * ve) - # 4. Replace terrain geometry (add_geometry overwrites existing key - # in-place, preserving dict insertion order and instance IDs) - if self.rtx is not None: - self.rtx.add_geometry('terrain', vertices, indices) - # 5. Subsample overlay layers if self._base_overlay_layers: self._overlay_layers = {} @@ -687,64 +1788,131 @@ def _rebuild_at_resolution(self, factor): if terrain_name != 'elevation' and terrain_name in self._overlay_layers: self._active_overlay_data = self._overlay_layers[terrain_name] - # 6. Re-snap placed meshes to new terrain surface + # 6. Invalidate chunk manager cache (meshes need new Z coords) + if self._chunk_manager is not None: + # Clear chunk cache and baked mesh entries for chunk-loaded geometries + for gid in list(self._chunk_manager._active_gids): + if hasattr(self, '_baked_meshes'): + self._baked_meshes.pop(gid, None) + if self._accessor is not None: + self._accessor._baked_meshes.pop(gid, None) + self._chunk_manager._cache.clear() + self._chunk_manager._visible.clear() + self._chunk_manager._active_gids.clear() + # Force immediate reload at new resolution + if hasattr(self, 'position'): + self._chunk_manager.update(self.position[0], self.position[1], self) + + # 7. Re-snap placed meshes to new terrain surface + # Invalidate GPU terrain cache (terrain changed) and upload once + self._gpu_terrain = None if self.rtx is not None: + gpu_terrain = None + if has_cupy: + gpu_terrain = cp.asarray(terrain_np) + self._gpu_terrain = gpu_terrain for geom_id in self.rtx.list_geometries(): if geom_id == 'terrain': continue # Baked meshes — re-snap Z to new terrain surface + VE if hasattr(self, '_baked_meshes') and geom_id in self._baked_meshes: + baked = self._baked_meshes[geom_id] + is_curve = (len(baked) == 4) baked_key = (factor, geom_id) if baked_key in self._baked_mesh_cache: - scaled_v, orig_idx = self._baked_mesh_cache[baked_key] + cached = self._baked_mesh_cache[baked_key] + if is_curve: + scaled_v, orig_w, orig_idx = cached + self.rtx.add_curve_geometry( + geom_id, scaled_v, orig_w, orig_idx) + else: + scaled_v, orig_idx = cached + self.rtx.add_geometry(geom_id, scaled_v, orig_idx) else: - baked = self._baked_meshes[geom_id] - if len(baked) == 3: + if is_curve: + orig_v, orig_w, orig_idx, orig_base_z = baked + elif len(baked) == 3: orig_v, orig_idx, orig_base_z = baked else: orig_v, orig_idx = baked orig_base_z = None - scaled_v = orig_v.copy() - if orig_base_z is not None: - # Sample new terrain Z at each vertex position - vx = orig_v[0::3] - vy = orig_v[1::3] - px = vx / self.pixel_spacing_x - py = vy / self.pixel_spacing_y - ix = np.clip(np.round(px).astype(int), 0, W - 1) - iy = np.clip(np.round(py).astype(int), 0, H - 1) - new_base_z = terrain_np[iy, ix].astype(np.float32) - new_base_z = np.where(np.isnan(new_base_z), 0.0, new_base_z) - z_offset = orig_v[2::3] - orig_base_z - scaled_v[2::3] = (new_base_z + z_offset) * ve + + n_verts = len(orig_v) // 3 + use_gpu = (gpu_terrain is not None + and orig_base_z is not None + and n_verts > 1000) + + if use_gpu: + vx = cp.asarray(orig_v[0::3]) + vy = cp.asarray(orig_v[1::3]) + new_base_z = _bilinear_terrain_z( + gpu_terrain, vx, vy, + self.pixel_spacing_x, self.pixel_spacing_y) + z_offset = cp.asarray(orig_v[2::3]) - cp.asarray(orig_base_z) + new_z = (new_base_z + z_offset) * ve + scaled_v_gpu = cp.asarray(orig_v.copy()) + scaled_v_gpu[2::3] = new_z + if is_curve: + self._baked_mesh_cache[baked_key] = ( + scaled_v_gpu.get().copy(), orig_w, orig_idx) + self.rtx.add_curve_geometry( + geom_id, scaled_v_gpu, + cp.asarray(orig_w), + cp.asarray(orig_idx)) + else: + self._baked_mesh_cache[baked_key] = ( + scaled_v_gpu.get().copy(), orig_idx) + self.rtx.add_geometry(geom_id, scaled_v_gpu, + cp.asarray(orig_idx)) else: - scaled_v[2::3] *= ve - self._baked_mesh_cache[baked_key] = (scaled_v.copy(), orig_idx) - self.rtx.add_geometry(geom_id, scaled_v, orig_idx) + scaled_v = orig_v.copy() + if orig_base_z is not None: + vx = orig_v[0::3] + vy = orig_v[1::3] + new_base_z = _bilinear_terrain_z( + terrain_np, vx, vy, + self.pixel_spacing_x, self.pixel_spacing_y) + z_offset = orig_v[2::3] - orig_base_z + scaled_v[2::3] = (new_base_z + z_offset) * ve + else: + scaled_v[2::3] *= ve + if is_curve: + self._baked_mesh_cache[baked_key] = ( + scaled_v.copy(), orig_w, orig_idx) + self.rtx.add_curve_geometry( + geom_id, scaled_v, orig_w, orig_idx) + else: + self._baked_mesh_cache[baked_key] = ( + scaled_v.copy(), orig_idx) + self.rtx.add_geometry(geom_id, scaled_v, orig_idx) continue # Instanced meshes — update transform Z from terrain transform = self.rtx.get_geometry_transform(geom_id) if transform is None: continue wx, wy = transform[3], transform[7] - px = wx / self.pixel_spacing_x - py = wy / self.pixel_spacing_y - ix = int(np.clip(px, 0, W - 1)) - iy = int(np.clip(py, 0, H - 1)) - z = float(terrain_np[iy, ix]) * ve + z = float(_bilinear_terrain_z( + terrain_np, + np.array([wx], dtype=np.float32), + np.array([wy], dtype=np.float32), + self.pixel_spacing_x, self.pixel_spacing_y)[0]) * ve transform[11] = z self.rtx.update_transform(geom_id, transform) - # 7. Recompute minimap + # 8. Re-snap all observer drones to new terrain + for obs in self._observers.values(): + if obs.drone_placed and obs.position is not None: + self._update_observer_drone_for(obs) + + # 9. Recompute minimap self._compute_minimap_background() - if self._minimap_im is not None: - self._minimap_im.set_data(self._minimap_background) - # Update minimap axes limits for new background size - mm_h, mm_w = self._minimap_background.shape[:2] - self._minimap_im.set_extent([-0.5, mm_w - 0.5, mm_h - 0.5, -0.5]) - # 8. Clear viewshed cache (no longer matches terrain) + # 10. Clear viewshed cache (no longer matches terrain) self._viewshed_cache = None + for obs in self._observers.values(): + obs.viewshed_cache = None + if obs.viewshed_enabled: + obs.viewshed_enabled = False if self.viewshed_enabled: self.viewshed_enabled = False print(" Viewshed disabled (terrain changed). Press V to recalculate.") @@ -768,48 +1936,70 @@ def _rebuild_vertical_exaggeration(self, ve): # Use cached mesh if available, otherwise build and cache cache_key = (self.subsample_factor, self.mesh_type) - if cache_key in self._terrain_mesh_cache: - verts_base, indices, terrain_np = self._terrain_mesh_cache[cache_key] - vertices = verts_base.copy() - if ve != 1.0: - vertices[2::3] *= ve - else: - terrain_data = self.raster.data - if hasattr(terrain_data, 'get'): - terrain_np = terrain_data.get() + if self.mesh_type == 'heightfield': + # Heightfield path: rebuild GAS with new VE + if cache_key in self._terrain_mesh_cache: + _, _, terrain_np = self._terrain_mesh_cache[cache_key] else: - terrain_np = np.asarray(terrain_data) + terrain_data = self.raster.data + if hasattr(terrain_data, 'get'): + terrain_np = terrain_data.get() + else: + terrain_np = np.asarray(terrain_data) + self._terrain_mesh_cache[cache_key] = ( + None, None, terrain_np.copy(), + ) - if self.mesh_type == 'voxel': - nv = H * W * 8 - nt = H * W * 12 - vertices = np.zeros(nv * 3, dtype=np.float32) - indices = np.zeros(nt * 3, dtype=np.int32) - base_elev = float(np.nanmin(terrain_np)) - mesh_mod.voxelate_terrain(vertices, indices, self.raster, - scale=1.0, base_elevation=base_elev) + if self.rtx is not None: + self.rtx.add_heightfield_geometry( + 'terrain', terrain_np, H, W, + spacing_x=self.pixel_spacing_x, + spacing_y=self.pixel_spacing_y, + ve=ve, + ) + else: + if cache_key in self._terrain_mesh_cache: + verts_base, indices, terrain_np = self._terrain_mesh_cache[cache_key] + vertices = verts_base.copy() + if ve != 1.0: + vertices[2::3] *= ve else: - nv = H * W - nt = (H - 1) * (W - 1) * 2 - vertices = np.zeros(nv * 3, dtype=np.float32) - indices = np.zeros(nt * 3, dtype=np.int32) - mesh_mod.triangulate_terrain(vertices, indices, self.raster, - scale=1.0) - - if self.pixel_spacing_x != 1.0 or self.pixel_spacing_y != 1.0: - vertices[0::3] *= self.pixel_spacing_x - vertices[1::3] *= self.pixel_spacing_y - - self._terrain_mesh_cache[cache_key] = ( - vertices.copy(), indices.copy(), terrain_np.copy() - ) + terrain_data = self.raster.data + if hasattr(terrain_data, 'get'): + terrain_np = terrain_data.get() + else: + terrain_np = np.asarray(terrain_data) + + if self.mesh_type == 'voxel': + nv = H * W * 8 + nt = H * W * 12 + vertices = np.zeros(nv * 3, dtype=np.float32) + indices = np.zeros(nt * 3, dtype=np.int32) + base_elev = float(np.nanmin(terrain_np)) + mesh_mod.voxelate_terrain(vertices, indices, self.raster, + scale=1.0, base_elevation=base_elev) + else: + nv = H * W + nt = (H - 1) * (W - 1) * 2 + vertices = np.zeros(nv * 3, dtype=np.float32) + indices = np.zeros(nt * 3, dtype=np.int32) + mesh_mod.triangulate_terrain(vertices, indices, self.raster, + scale=1.0) + + if self.pixel_spacing_x != 1.0 or self.pixel_spacing_y != 1.0: + vertices[0::3] *= self.pixel_spacing_x + vertices[1::3] *= self.pixel_spacing_y + + self._terrain_mesh_cache[cache_key] = ( + vertices.copy(), indices.copy(), terrain_np.copy() + ) - if ve != 1.0: - vertices[2::3] *= ve + if ve != 1.0: + vertices[2::3] *= ve - # Replace terrain geometry (preserves dict insertion order) - if self.rtx is not None: - self.rtx.add_geometry('terrain', vertices, indices) + # Replace terrain geometry (preserves dict insertion order) + if self.rtx is not None: + self.rtx.add_geometry('terrain', vertices, indices) # Update elevation stats (scaled) self.elev_min = float(np.nanmin(terrain_np)) * ve @@ -825,50 +2015,93 @@ def _rebuild_vertical_exaggeration(self, ve): float(np.nanmax(land_pixels)) * ve) # Re-snap placed meshes to scaled terrain + # Invalidate GPU terrain cache (VE changed terrain Z) and upload once + self._gpu_terrain = None if self.rtx is not None: + gpu_terrain = None + if has_cupy: + gpu_terrain = cp.asarray(terrain_np) + self._gpu_terrain = gpu_terrain for geom_id in self.rtx.list_geometries(): if geom_id == 'terrain': continue - # Baked meshes (merged buildings) — re-snap Z to terrain + VE + # Baked meshes (merged buildings/curves) — re-snap Z to terrain + VE if hasattr(self, '_baked_meshes') and geom_id in self._baked_meshes: baked = self._baked_meshes[geom_id] - if len(baked) == 3: + is_curve = (len(baked) == 4) + if is_curve: + orig_v, orig_w, orig_idx, orig_base_z = baked + elif len(baked) == 3: orig_v, orig_idx, orig_base_z = baked else: orig_v, orig_idx = baked orig_base_z = None - scaled_v = orig_v.copy() - if orig_base_z is not None: - # Sample current terrain Z at each vertex position - vx = orig_v[0::3] - vy = orig_v[1::3] - px = vx / self.pixel_spacing_x - py = vy / self.pixel_spacing_y - ix = np.clip(np.round(px).astype(int), 0, W - 1) - iy = np.clip(np.round(py).astype(int), 0, H - 1) - cur_base_z = terrain_np[iy, ix].astype(np.float32) - cur_base_z = np.where(np.isnan(cur_base_z), 0.0, cur_base_z) - z_offset = orig_v[2::3] - orig_base_z - scaled_v[2::3] = (cur_base_z + z_offset) * ve + + n_verts = len(orig_v) // 3 + use_gpu = (gpu_terrain is not None + and orig_base_z is not None + and n_verts > 1000) + + if use_gpu: + vx = cp.asarray(orig_v[0::3]) + vy = cp.asarray(orig_v[1::3]) + cur_base_z = _bilinear_terrain_z( + gpu_terrain, vx, vy, + self.pixel_spacing_x, self.pixel_spacing_y) + z_offset = cp.asarray(orig_v[2::3]) - cp.asarray(orig_base_z) + new_z = (cur_base_z + z_offset) * ve + scaled_v_gpu = cp.asarray(orig_v.copy()) + scaled_v_gpu[2::3] = new_z + if is_curve: + self.rtx.add_curve_geometry( + geom_id, scaled_v_gpu, + cp.asarray(orig_w), + cp.asarray(orig_idx)) + else: + self.rtx.add_geometry(geom_id, scaled_v_gpu, + cp.asarray(orig_idx)) else: - scaled_v[2::3] *= ve - self.rtx.add_geometry(geom_id, scaled_v, orig_idx) + scaled_v = orig_v.copy() + if orig_base_z is not None: + vx = orig_v[0::3] + vy = orig_v[1::3] + cur_base_z = _bilinear_terrain_z( + terrain_np, vx, vy, + self.pixel_spacing_x, self.pixel_spacing_y) + z_offset = orig_v[2::3] - orig_base_z + scaled_v[2::3] = (cur_base_z + z_offset) * ve + else: + scaled_v[2::3] *= ve + if is_curve: + self.rtx.add_curve_geometry( + geom_id, scaled_v, orig_w, orig_idx) + else: + self.rtx.add_geometry(geom_id, scaled_v, orig_idx) continue # Instanced meshes — update transform Z from terrain transform = self.rtx.get_geometry_transform(geom_id) if transform is None: continue wx, wy = transform[3], transform[7] - px = wx / self.pixel_spacing_x - py = wy / self.pixel_spacing_y - ix = int(np.clip(px, 0, W - 1)) - iy = int(np.clip(py, 0, H - 1)) - z = float(terrain_np[iy, ix]) * ve + z = float(_bilinear_terrain_z( + terrain_np, + np.array([wx], dtype=np.float32), + np.array([wy], dtype=np.float32), + self.pixel_spacing_x, self.pixel_spacing_y)[0]) * ve transform[11] = z self.rtx.update_transform(geom_id, transform) + # Re-snap all observer drones to updated terrain + for obs in self._observers.values(): + if obs.drone_placed and obs.position is not None: + self._update_observer_drone_for(obs) + # Clear viewshed cache self._viewshed_cache = None + for obs in self._observers.values(): + obs.viewshed_cache = None + if obs.viewshed_enabled: + obs.viewshed_enabled = False if self.viewshed_enabled: self.viewshed_enabled = False print(" Viewshed disabled (terrain changed). Press V to recalculate.") @@ -876,128 +2109,322 @@ def _rebuild_vertical_exaggeration(self, ve): print(f"Vertical exaggeration: {ve:.2f}x") self._update_frame() - def _create_minimap(self): - """Create the minimap inset axes and persistent artists.""" - if self._minimap_background is None: - return - - # Create inset axes in bottom-right corner (~20% of figure width) - mm_h, mm_w = self._minimap_background.shape[:2] - aspect = mm_h / mm_w - ax_width = 0.2 - ax_height = ax_width * aspect * (self.width / self.height) - # Clamp height so it doesn't get too tall - ax_height = min(ax_height, 0.35) - margin = 0.02 - self._minimap_ax = self.fig.add_axes( - [1 - ax_width - margin, margin, ax_width, ax_height] - ) - self._minimap_ax.set_xticks([]) - self._minimap_ax.set_yticks([]) - for spine in self._minimap_ax.spines.values(): - spine.set_edgecolor('#555555') - spine.set_linewidth(0.6) - - # Display RGBA background (origin='upper' so row 0 is top = +Y) - self._minimap_ax.set_facecolor('#0a0c14') - self._minimap_im = self._minimap_ax.imshow( - self._minimap_background, - aspect='auto', origin='upper' - ) + def _project_corner_to_terrain(self, nx, ny, cam_pos, forward, right, + up_cam, fov_scale, aspect, terrain_z): + """Project an NDC screen corner onto the terrain z-plane. - # FOV wedge (filled semi-transparent cone showing visible area) - from matplotlib.patches import Polygon - self._minimap_fov_wedge = Polygon( - [[0, 0]], closed=True, facecolor='red', alpha=0.25, - edgecolor='red', linewidth=0.8, zorder=3 - ) - self._minimap_ax.add_patch(self._minimap_fov_wedge) + Parameters + ---------- + nx, ny : float + Normalised device coords (-1..1), where (-1,-1) = bottom-left. + cam_pos : ndarray (3,) + Camera position in world space (already VE-scaled Z). + forward, right, up_cam : ndarray (3,) + Camera basis vectors. + fov_scale : float + tan(fov/2). + aspect : float + Width / height. + terrain_z : float + Z plane to intersect (mean_elev * VE). - # Direction line (bright, with arrowhead effect via thicker line) - self._minimap_direction_line, = self._minimap_ax.plot( - [], [], color='#ff4444', linewidth=2.0, solid_capstyle='round', zorder=4 - ) + Returns + ------- + (world_x, world_y) or None if ray doesn't hit ground. + """ + ray_dir = forward + nx * fov_scale * aspect * right + ny * fov_scale * up_cam + norm = np.linalg.norm(ray_dir) + if norm < 1e-8: + return None + ray_dir /= norm - # Camera position dot (white-outlined red) - self._minimap_camera_dot = self._minimap_ax.scatter( - [], [], c='red', s=30, zorder=5, edgecolors='white', linewidths=0.8 - ) + # Intersect with z = terrain_z plane + if abs(ray_dir[2]) < 1e-8: + # Ray parallel to ground — project far forward + t = 1e5 + else: + t = (terrain_z - cam_pos[2]) / ray_dir[2] + + if t < 0: + # Looking up past horizon — project far forward along horizontal + horiz = np.array([ray_dir[0], ray_dir[1], 0.0]) + hn = np.linalg.norm(horiz) + if hn < 1e-8: + return None + horiz /= hn + far_dist = max(self.terrain_shape[0] * self.pixel_spacing_y, + self.terrain_shape[1] * self.pixel_spacing_x) + return (cam_pos[0] + horiz[0] * far_dist, + cam_pos[1] + horiz[1] * far_dist) + + hit = cam_pos + ray_dir * t + return (float(hit[0]), float(hit[1])) + + def _blit_minimap_on_frame(self, img): + """Composite minimap overlay onto the rendered frame (numpy blit). + + Draws the minimap background with rounded corners and drop shadow, + terrain footprint quad, camera dot, direction line, and observer dots + directly onto the frame array in the bottom-right. - # Observer dot (magenta star) - self._minimap_observer_dot = self._minimap_ax.scatter( - [], [], c='magenta', s=50, marker='*', zorder=6, - edgecolors='white', linewidths=0.3 - ) + Parameters + ---------- + img : ndarray, shape (H, W, 3), float32 0-1 + Rendered frame to composite onto. Modified in-place. + """ + if self._minimap_background is None or not self.show_minimap: + return - self._minimap_ax.set_visible(self.show_minimap) + # Lazy re-check: pick up satellite tiles once they arrive + if (not self._minimap_has_tiles + and self._tile_service is not None + and getattr(self._tile_service, '_fetched', None)): + self._compute_minimap_background() + + mm_bg = self._minimap_background # (mm_h, mm_w, 4) RGBA float32 + mm_h, mm_w = mm_bg.shape[:2] + fh, fw = img.shape[:2] + + # Size the minimap to ~20% of frame width + target_w = max(40, int(fw * 0.2)) + scale = target_w / mm_w + target_h = max(20, int(mm_h * scale)) + target_w = min(target_w, fw - 8) + target_h = min(target_h, fh - 8) + + # Nearest-neighbour resize + y_idx = np.linspace(0, mm_h - 1, target_h).astype(int) + x_idx = np.linspace(0, mm_w - 1, target_w).astype(int) + bg_resized = mm_bg[np.ix_(y_idx, x_idx)].copy() # (th, tw, 4) + + # --- Rounded corner mask --- + corner_radius = min(8, target_h // 4, target_w // 4) + if corner_radius > 1: + mask = np.ones((target_h, target_w), dtype=np.float32) + yy = np.arange(target_h)[:, None] + xx = np.arange(target_w)[None, :] + # Four corners: (cy, cx) of the inscribed circle center + corners = [ + (corner_radius, corner_radius), # top-left + (corner_radius, target_w - 1 - corner_radius), # top-right + (target_h - 1 - corner_radius, corner_radius), # bottom-left + (target_h - 1 - corner_radius, target_w - 1 - corner_radius), # bottom-right + ] + for cy, cx in corners: + # Select the corner quadrant + if cy <= corner_radius: + row_sel = yy < corner_radius + else: + row_sel = yy > target_h - 1 - corner_radius + if cx <= corner_radius: + col_sel = xx < corner_radius + else: + col_sel = xx > target_w - 1 - corner_radius + in_corner = row_sel & col_sel + dist_sq = (yy - cy) ** 2 + (xx - cx) ** 2 + outside_circle = dist_sq > corner_radius ** 2 + mask = np.where(in_corner & outside_circle, 0.0, mask) + bg_resized[:, :, 3] *= mask + + # Placement: bottom-right with 6px margin + margin = 6 + y0 = fh - target_h - margin + x0 = fw - target_w - margin + + # --- Drop shadow (dark rounded rect offset by 2px) --- + shadow_off = 2 + sy0 = y0 + shadow_off + sx0 = x0 + shadow_off + sy1 = min(sy0 + target_h, fh) + sx1 = min(sx0 + target_w, fw) + sh = sy1 - sy0 + sw = sx1 - sx0 + if sh > 0 and sw > 0: + shadow_alpha = 0.35 + if corner_radius > 1: + shadow_mask = mask[:sh, :sw] * shadow_alpha + else: + shadow_mask = np.full((sh, sw), shadow_alpha, dtype=np.float32) + shadow_region = img[sy0:sy1, sx0:sx1] + shadow_region[:] = shadow_region * (1 - shadow_mask[:, :, None]) - def _update_minimap(self): - """Update minimap artists with current camera/observer state.""" - if self._minimap_ax is None: - return + # Alpha-composite background onto frame + alpha = bg_resized[:, :, 3:4] + rgb = bg_resized[:, :, :3] + region = img[y0:y0+target_h, x0:x0+target_w] + region[:] = region * (1 - alpha) + rgb * alpha - self._minimap_ax.set_visible(self.show_minimap) - if not self.show_minimap: - return + # Store minimap rect for click-to-teleport + self._minimap_rect = (x0, y0, target_w, target_h) + # --- Terrain footprint (visible area quad) --- H, W = self.terrain_shape - - # Convert camera world position to minimap pixel coords - # World coords: x = col * pixel_spacing_x, y = row * pixel_spacing_y - # Pixel indices: col = x / pixel_spacing_x, row = y / pixel_spacing_y - # Minimap coords: mx = col * scale_x, my = row * scale_y cam_col = self.position[0] / self.pixel_spacing_x cam_row = self.position[1] / self.pixel_spacing_y + # Minimap local coords + lx = cam_col / W * target_w + ly = cam_row / H * target_h - mx = cam_col * self._minimap_scale_x - # Flip Y: minimap origin='upper', so row 0 is displayed at top - # In world coords, +Y is increasing row. With origin='upper', - # imshow row 0 is at top, so minimap y = row * scale_y directly. - my = cam_row * self._minimap_scale_y - - # Update camera dot - self._minimap_camera_dot.set_offsets([[mx, my]]) - - # Direction line length in minimap pixels - mm_h, mm_w = self._minimap_background.shape[:2] - line_len = max(mm_h, mm_w) * 0.12 - - # Yaw: 0 = +X (right on minimap), 90 = +Y (down on minimap with origin='upper') + ve = self.vertical_exaggeration + terrain_z = self._get_terrain_z(self.position[0], self.position[1]) * ve + + # Camera basis in VE-scaled space + pos_ve = np.array([self.position[0], self.position[1], + self.position[2] * ve], dtype=np.float32) + look_ve = np.array([self.position[0] + self._get_front()[0] * 1000, + self.position[1] + self._get_front()[1] * 1000, + (self.position[2] + self._get_front()[2] * 1000) * ve], + dtype=np.float32) + # Simple basis from yaw/pitch yaw_rad = np.radians(self.yaw) - dx = np.cos(yaw_rad) * line_len - dy = np.sin(yaw_rad) * line_len # +Y in world = +row = down in minimap - - self._minimap_direction_line.set_data([mx, mx + dx], [my, my + dy]) - - # FOV wedge (filled triangle from camera through left/right edges) - half_fov = np.radians(self.fov / 2) - fov_len = line_len * 0.8 - - left_angle = yaw_rad - half_fov - right_angle = yaw_rad + half_fov - - lx = np.cos(left_angle) * fov_len - ly = np.sin(left_angle) * fov_len - rx = np.cos(right_angle) * fov_len - ry = np.sin(right_angle) * fov_len - - self._minimap_fov_wedge.set_xy([ - [mx, my], - [mx + lx, my + ly], - [mx + rx, my + ry], - ]) - - # Observer dot - if self._observer_position is not None: - obs_x, obs_y = self._observer_position - obs_col = obs_x / self.pixel_spacing_x - obs_row = obs_y / self.pixel_spacing_y - omx = obs_col * self._minimap_scale_x - omy = obs_row * self._minimap_scale_y - self._minimap_observer_dot.set_offsets([[omx, omy]]) - self._minimap_observer_dot.set_visible(True) + pitch_rad = np.radians(self.pitch) + forward = np.array([ + np.cos(yaw_rad) * np.cos(pitch_rad), + np.sin(yaw_rad) * np.cos(pitch_rad), + np.sin(pitch_rad), + ], dtype=np.float32) + world_up = np.array([0, 0, 1], dtype=np.float32) + right = np.cross(world_up, forward) + rn = np.linalg.norm(right) + if rn > 1e-8: + right /= rn else: - self._minimap_observer_dot.set_visible(False) + right = np.array([1, 0, 0], dtype=np.float32) + up_cam = np.cross(forward, right) + + fov_scale = np.tan(np.radians(self.fov) / 2.0) + aspect = self.render_width / max(1, self.render_height) + + # Project 4 screen corners onto terrain z-plane + ndc_corners = [(-1, -1), (1, -1), (1, 1), (-1, 1)] # BL, BR, TR, TL + mm_corners = [] + for nx, ny in ndc_corners: + hit = self._project_corner_to_terrain( + nx, ny, pos_ve, forward, right, up_cam, fov_scale, aspect, terrain_z) + if hit is None: + mm_corners = [] + break + # Convert world XY to minimap-local coords + mcol = hit[0] / self.pixel_spacing_x / W * target_w + mrow = hit[1] / self.pixel_spacing_y / H * target_h + mm_corners.append((mcol, mrow)) + + if len(mm_corners) == 4: + pts = np.array(mm_corners) # (4, 2) + # Fill as two triangles (BL-BR-TR and BL-TR-TL) + tri1 = pts[[0, 1, 2]] + tri2 = pts[[0, 2, 3]] + fill_color = np.array([0.9, 0.9, 0.9]) + self._fill_triangle(img, tri1, x0, y0, target_w, target_h, + color=fill_color, alpha_val=0.12) + self._fill_triangle(img, tri2, x0, y0, target_w, target_h, + color=fill_color, alpha_val=0.12) + # Outline edges + edge_color = np.array([0.8, 0.8, 0.8]) + for i in range(4): + j = (i + 1) % 4 + self._draw_line(img, pts[i, 0], pts[i, 1], + pts[j, 0], pts[j, 1], + x0, y0, target_w, target_h, + color=edge_color, thickness=1) + + # Direction line (2px wide red) + line_len = max(target_h, target_w) * 0.12 + ex = lx + np.cos(yaw_rad) * line_len + ey = ly + np.sin(yaw_rad) * line_len + self._draw_line(img, lx, ly, ex, ey, x0, y0, target_w, target_h, + color=np.array([1.0, 0.27, 0.27]), thickness=2) + + # Camera dot (red circle, r=3) + self._draw_dot(img, lx, ly, x0, y0, target_w, target_h, + color=np.array([1.0, 0.0, 0.0]), radius=3) + + # Observer dots — colored per-slot, active gets larger radius + for slot, obs in self._observers.items(): + if obs.position is None: + continue + obs_x, obs_y = obs.position + obs_lx = (obs_x / self.pixel_spacing_x) / W * target_w + obs_ly = (obs_y / self.pixel_spacing_y) / H * target_h + r = 4 if slot == self._active_observer else 2 + self._draw_dot(img, obs_lx, obs_ly, x0, y0, target_w, target_h, + color=np.array(obs.color), radius=r) + + @staticmethod + def _draw_dot(img, lx, ly, x0, y0, tw, th, color, radius=3): + """Draw a filled circle at minimap-local (lx, ly) onto frame.""" + fh, fw = img.shape[:2] + cx = int(round(lx)) + x0 + cy = int(round(ly)) + y0 + # Clip to minimap rect intersected with frame + clip_x0, clip_y0 = max(0, x0), max(0, y0) + clip_x1, clip_y1 = min(fw, x0 + tw), min(fh, y0 + th) + for dy in range(-radius, radius + 1): + for dx in range(-radius, radius + 1): + if dx*dx + dy*dy <= radius*radius: + px, py = cx + dx, cy + dy + if clip_x0 <= px < clip_x1 and clip_y0 <= py < clip_y1: + img[py, px, :] = color + + @staticmethod + def _draw_line(img, x1, y1, x2, y2, x0, y0, tw, th, color, thickness=1): + """Draw a line from (x1,y1) to (x2,y2) in minimap-local coords.""" + fh, fw = img.shape[:2] + # Clip to minimap rect intersected with frame + clip_x0, clip_y0 = max(0, x0), max(0, y0) + clip_x1, clip_y1 = min(fw, x0 + tw), min(fh, y0 + th) + steps = max(2, int(np.sqrt((x2-x1)**2 + (y2-y1)**2) * 2)) + for i in range(steps + 1): + t = i / steps + px = int(round(x1 + (x2-x1)*t)) + x0 + py = int(round(y1 + (y2-y1)*t)) + y0 + for d in range(-(thickness//2), thickness//2 + 1): + for e in range(-(thickness//2), thickness//2 + 1): + ppx, ppy = px + d, py + e + if clip_x0 <= ppx < clip_x1 and clip_y0 <= ppy < clip_y1: + img[ppy, ppx, :] = color + + @staticmethod + def _fill_triangle(img, tri, x0, y0, tw, th, color, alpha_val=0.25): + """Rasterize a filled triangle onto the frame with alpha blending.""" + fh, fw = img.shape[:2] + # Bounding box in frame coords, clipped to minimap rect + pts_x = tri[:, 0] + x0 + pts_y = tri[:, 1] + y0 + clip_x0, clip_y0 = max(0, x0), max(0, y0) + clip_x1, clip_y1 = min(fw - 1, x0 + tw - 1), min(fh - 1, y0 + th - 1) + min_x = max(clip_x0, int(np.floor(pts_x.min()))) + max_x = min(clip_x1, int(np.ceil(pts_x.max()))) + min_y = max(clip_y0, int(np.floor(pts_y.min()))) + max_y = min(clip_y1, int(np.ceil(pts_y.max()))) + + # Vectorised point-in-triangle using barycentric coords + v0 = tri[2] - tri[0] + v1 = tri[1] - tri[0] + d00 = v0[0]*v0[0] + v0[1]*v0[1] + d01 = v0[0]*v1[0] + v0[1]*v1[1] + d11 = v1[0]*v1[0] + v1[1]*v1[1] + denom = d00*d11 - d01*d01 + if abs(denom) < 1e-12: + return + + ys = np.arange(min_y, max_y + 1) + xs = np.arange(min_x, max_x + 1) + if len(ys) == 0 or len(xs) == 0: + return + gx, gy = np.meshgrid(xs, ys) + v2x = gx - (tri[0, 0] + x0) + v2y = gy - (tri[0, 1] + y0) + d20 = v2x*v0[0] + v2y*v0[1] + d21 = v2x*v1[0] + v2y*v1[1] + u = (d11*d20 - d01*d21) / denom + v = (d00*d21 - d01*d20) / denom + inside = (u >= 0) & (v >= 0) & (u + v <= 1) + + if inside.any(): + iy = gy[inside] + ix = gx[inside] + img[iy, ix, :] = img[iy, ix, :] * (1 - alpha_val) + color * alpha_val # ------------------------------------------------------------------ # Wind particle animation @@ -1127,6 +2554,8 @@ def _init_wind(self, wind_data): self._wind_dot_radius = int(wind_data['dot_radius']) if 'alpha' in wind_data: self._wind_alpha = float(wind_data['alpha']) + if 'min_visible_age' in wind_data: + self._wind_min_visible_age = int(wind_data['min_visible_age']) from .tiles import _build_latlon_grids raster = self._base_raster @@ -1175,7 +2604,9 @@ def _init_wind(self, wind_data): else: terrain_np = np.asarray(terrain_data) # Gradient in row/col directions (units: elevation per pixel) - grad_row, grad_col = np.gradient(terrain_np.astype(np.float32)) + # NaN-fill so ocean/water pixels have zero slope influence and + # particles flow purely by wind over water. + grad_row, grad_col = np.gradient(np.nan_to_num(terrain_np.astype(np.float32), nan=0.0)) # Downslope force = -gradient (pushes particles toward lower elevation) # Scale relative to wind speed so slope matters but doesn't dominate slope_scale = dt * sm * 0.15 @@ -1220,8 +2651,9 @@ def _update_wind_particles(self): # Bilinear sample wind at particle positions rows = pts[:, 0] cols = pts[:, 1] - r0 = np.clip(np.floor(rows).astype(int), 0, H - 2) - c0 = np.clip(np.floor(cols).astype(int), 0, W - 2) + # Replace NaN with 0 before int cast to avoid RuntimeWarning + r0 = np.clip(np.floor(np.nan_to_num(rows, nan=0.0)).astype(int), 0, H - 2) + c0 = np.clip(np.floor(np.nan_to_num(cols, nan=0.0)).astype(int), 0, W - 2) fr = rows - r0 fc = cols - c0 @@ -1270,15 +2702,17 @@ def _update_wind_particles(self): u_val += slope_u * dampen v_val += slope_v * dampen - # Advect - pts[:, 0] += v_val # row - pts[:, 1] += u_val # col + # Advect (scale by dt so wind speed is frame-rate independent) + s = self._dt_scale + pts[:, 0] += v_val * s # row + pts[:, 1] += u_val * s # col # Age particles self._wind_ages += 1 - # Respawn out-of-bounds or aged-out particles - oob = (pts[:, 0] < 0) | (pts[:, 0] >= H) | (pts[:, 1] < 0) | (pts[:, 1] >= W) + # Respawn out-of-bounds, NaN, or aged-out particles + nan_pos = np.isnan(pts[:, 0]) | np.isnan(pts[:, 1]) + oob = nan_pos | (pts[:, 0] < 0) | (pts[:, 0] >= H) | (pts[:, 1] < 0) | (pts[:, 1] >= W) old = self._wind_ages >= self._wind_lifetimes respawn = oob | old @@ -1297,9 +2731,8 @@ def _update_wind_particles(self): def _draw_wind_on_frame(self, img): """Project wind particles to screen space and draw on rendered frame. - Uses the same pinhole camera model as the ray tracer to project - 3D particle world positions onto 2D screen pixels, then draws - semi-transparent white dots with short trails. + Fully vectorised: projects all trail positions in one batch, then + splats them with a single ``np.add.at`` call per stamp offset. Parameters ---------- @@ -1313,7 +2746,8 @@ def _draw_wind_on_frame(self, img): import math sh, sw = img.shape[:2] - pts = self._wind_particles # (N, 2) — (row, col) in base-raster pixel coords + N = self._wind_particles.shape[0] + trail_len = self._wind_trail_len # Camera basis matching the ray tracer cam_pos = self.position @@ -1322,126 +2756,343 @@ def _draw_wind_on_frame(self, img): tuple(cam_pos), tuple(look_at), (0, 0, 1), ) fov_scale = math.tan(math.radians(self.fov) / 2.0) - aspect = sw / sh + aspect_ratio = sw / sh - # Get terrain elevation for Z coordinate - terrain_data = self.raster.data - if hasattr(terrain_data, 'get'): - terrain_np = terrain_data.get() - else: - terrain_np = np.asarray(terrain_data) + # Cached CPU terrain — avoid GPU→CPU copy every frame + if self._wind_terrain_np is None: + terrain_data = self.raster.data + if hasattr(terrain_data, 'get'): + self._wind_terrain_np = terrain_data.get() + else: + self._wind_terrain_np = np.asarray(terrain_data) + terrain_np = self._wind_terrain_np tH, tW = terrain_np.shape - # Convert particle pixel coords to world coords - # Account for subsample factor: particles are in base raster coords f = self.subsample_factor psx = self._base_pixel_spacing_x psy = self._base_pixel_spacing_y + ve = self.vertical_exaggeration + min_depth = self._wind_min_depth - def _project_points(rows, cols): - """Project (row, col) in base-raster space to screen (sx, sy).""" - # Sample terrain Z at subsampled resolution - sr = np.clip((rows / f).astype(int), 0, tH - 1) - sc = np.clip((cols / f).astype(int), 0, tW - 1) - z_vals = terrain_np[sr, sc] * self.vertical_exaggeration - # Hover particles well above terrain so they're clearly visible - z_vals = z_vals + 3.0 - - # World coordinates - wx = cols * psx - wy = rows * psy - wz = z_vals - - # Camera-relative position - dx = wx - cam_pos[0] - dy = wy - cam_pos[1] - dz = wz - cam_pos[2] - - # Project onto camera basis - depth = dx * forward[0] + dy * forward[1] + dz * forward[2] - u_cam = dx * right[0] + dy * right[1] + dz * right[2] - v_cam = dx * cam_up[0] + dy * cam_up[1] + dz * cam_up[2] - - # Skip particles behind camera or too close - valid = depth > self._wind_min_depth - - # NDC to screen - u_ndc = np.where(valid, u_cam / (depth * fov_scale * aspect + 1e-10), -2) - v_ndc = np.where(valid, v_cam / (depth * fov_scale + 1e-10), -2) - - sx = ((u_ndc + 1.0) * 0.5 * sw).astype(int) - sy = (((1.0 - v_ndc)) * 0.5 * sh).astype(int) - - # Clip to screen bounds - on_screen = valid & (sx >= 0) & (sx < sw) & (sy >= 0) & (sy < sh) - return sx, sy, on_screen, depth - - # Precompute circular stamp offsets for fat dots - r = self._wind_dot_radius - offsets = [] - for dy in range(-r, r + 1): - for dx in range(-r, r + 1): - dist_sq = dx * dx + dy * dy - if dist_sq <= r * r: - # Smooth circular falloff: 1 at centre, 0 at edge - falloff = 1.0 - (dist_sq / (r * r)) ** 0.5 - offsets.append((dx, dy, falloff)) - - # Draw trails — transparent lime, very subtle fat blobs + # --- Batch all trail positions into one flat array --- + # trails shape: (N, trail_len, 2) → (N * trail_len, 2) + all_pts = self._wind_trails.reshape(-1, 2) # (N*T, 2) + rows_all = all_pts[:, 0] + cols_all = all_pts[:, 1] + + # --- Single batched projection --- + sr = np.clip(np.nan_to_num(rows_all / f, nan=0.0).astype(np.int32), 0, tH - 1) + sc = np.clip(np.nan_to_num(cols_all / f, nan=0.0).astype(np.int32), 0, tW - 1) + z_vals = np.nan_to_num(terrain_np[sr, sc], nan=0.0) * ve + 3.0 + + wx = cols_all * psx + wy = rows_all * psy + + dx = wx - cam_pos[0] + dy = wy - cam_pos[1] + dz = z_vals - cam_pos[2] + + depth = dx * forward[0] + dy * forward[1] + dz * forward[2] + valid = depth > min_depth + + inv_depth = np.where(valid, 1.0 / (depth + 1e-10), 0.0) + u_cam = dx * right[0] + dy * right[1] + dz * right[2] + v_cam = dx * cam_up[0] + dy * cam_up[1] + dz * cam_up[2] + u_ndc = u_cam * inv_depth / (fov_scale * aspect_ratio) + v_ndc = v_cam * inv_depth / fov_scale + + sx_all = np.nan_to_num(((u_ndc + 1.0) * 0.5 * sw), nan=-1.0).astype(np.int32) + sy_all = np.nan_to_num(((1.0 - v_ndc) * 0.5 * sh), nan=-1.0).astype(np.int32) + + on_screen = valid & (sx_all >= 0) & (sx_all < sw) & (sy_all >= 0) & (sy_all < sh) + + # --- Build per-point alpha (fade-in, fade-out, trail decay) --- + # tile ages/lifetimes to match (N*T,) layout + ages = self._wind_ages # (N,) + lifetimes = self._wind_lifetimes # (N,) + + # Trail index for each point: 0=head, 1=prev, ... + trail_idx = np.tile(np.arange(trail_len, dtype=np.float32), N) # (N*T,) + # Particle must be at least trail_idx ticks old + ages_rep = np.repeat(ages, trail_len) # (N*T,) + lifetimes_rep = np.repeat(lifetimes, trail_len) # (N*T,) + age_ok = ages_rep > trail_idx + + # Fade in/out over particle lifetime + # Dead zone: invisible for first _wind_min_visible_age ticks while + # the particle silently builds a trail, then fade in over 10 ticks. + # This eliminates the "twinkle" of a single dot appearing. + mva = self._wind_min_visible_age + fade_in = np.clip((ages_rep - mva) / 10.0, 0, 1) + fade_out = np.clip((lifetimes_rep - ages_rep) / 20.0, 0, 1) + # Trail decay: head=1.0, tail→0.0 + trail_fade = 1.0 - (trail_idx / trail_len) + + alpha = self._wind_alpha * fade_in * fade_out * trail_fade + + # Final mask: on screen, old enough, positive alpha + mask = on_screen & age_ok & (alpha > 1e-6) + if not mask.any(): + return img + + sx_m = sx_all[mask] + sy_m = sy_all[mask] + alpha_m = alpha[mask].astype(np.float32) + + # --- Splat with stamp offsets using np.add.at --- color = np.array([0.3, 0.9, 0.8], dtype=np.float32) + r = self._wind_dot_radius + for offy in range(-r, r + 1): + for offx in range(-r, r + 1): + dist_sq = offx * offx + offy * offy + if dist_sq > r * r: + continue + falloff = 1.0 - (dist_sq / (r * r)) ** 0.5 + + px = sx_m + offx + py = sy_m + offy + ok = (px >= 0) & (px < sw) & (py >= 0) & (py < sh) + if not ok.any(): + continue + + contribution = alpha_m[ok] * falloff + for c in range(3): + np.add.at(img[:, :, c], (py[ok], px[ok]), contribution * color[c]) + + np.clip(img, 0, 1, out=img) + return img + + # ------------------------------------------------------------------ + # GTFS-RT realtime vehicle overlay + # ------------------------------------------------------------------ + + def _init_gtfs_rt(self, realtime_url, route_colors=None): + """Initialize GTFS-RT realtime vehicle polling. + + Parameters + ---------- + realtime_url : str + URL to a GTFS-Realtime VehiclePositions protobuf feed. + route_colors : dict, optional + ``{route_id: (r, g, b)}`` mapping. If not provided, all + vehicles render in white. + """ + self._gtfs_rt_url = realtime_url + if route_colors: + self._gtfs_rt_route_colors = route_colors + print(f"GTFS-RT feed configured: {realtime_url}") + print(" Press Shift+B to toggle realtime vehicle overlay.") + + def _toggle_gtfs_rt(self): + """Toggle GTFS-RT realtime vehicle overlay on/off.""" + if self._gtfs_rt_url is None: + print("No GTFS-RT feed configured. Pass realtime_url in gtfs_data metadata.") + return + self._gtfs_rt_enabled = not self._gtfs_rt_enabled + if self._gtfs_rt_enabled: + if self._gtfs_rt_thread is None or not self._gtfs_rt_thread.is_alive(): + self._gtfs_rt_stop.clear() + self._gtfs_rt_thread = threading.Thread( + target=self._gtfs_rt_poll_loop, daemon=True) + self._gtfs_rt_thread.start() + print("GTFS-RT vehicles: ON") + else: + self._gtfs_rt_stop.set() + print("GTFS-RT vehicles: OFF") + self._update_frame() + + def _gtfs_rt_poll_loop(self): + """Background thread: poll GTFS-RT feed at regular intervals.""" + import requests + + while not self._gtfs_rt_stop.is_set(): + try: + resp = requests.get(self._gtfs_rt_url, timeout=30) + resp.raise_for_status() + self._parse_gtfs_rt_response(resp.content) + self._render_needed = True + except Exception as e: + print(f"GTFS-RT poll error: {e}") + + self._gtfs_rt_stop.wait(self._gtfs_rt_poll_interval) + + def _parse_gtfs_rt_response(self, data): + """Parse GTFS-RT protobuf VehiclePositions into numpy arrays.""" + try: + from google.transit import gtfs_realtime_pb2 + except ImportError: + print("gtfs-realtime-bindings required for GTFS-RT. " + "Install with: pip install gtfs-realtime-bindings") + self._gtfs_rt_stop.set() + self._gtfs_rt_enabled = False + return + + feed = gtfs_realtime_pb2.FeedMessage() + feed.ParseFromString(data) + + positions = [] + bearings = [] + colors = [] + + for entity in feed.entity: + if not entity.HasField('vehicle'): + continue + vp = entity.vehicle + if not vp.HasField('position'): + continue + pos = vp.position + lat = pos.latitude + lon = pos.longitude + bearing = pos.bearing if pos.bearing else 0.0 + + # Determine color from route + route_id = vp.trip.route_id if vp.HasField('trip') else '' + color = self._gtfs_rt_route_colors.get(route_id, (1.0, 1.0, 1.0)) + + positions.append((lon, lat)) + bearings.append(bearing) + colors.append(color) + + if positions: + with self._gtfs_rt_lock: + self._gtfs_rt_vehicles = ( + np.array(positions, dtype=np.float64), + np.array(bearings, dtype=np.float32), + np.array(colors, dtype=np.float32), + ) + + def _draw_gtfs_rt_on_frame(self, img): + """Draw GTFS-RT vehicle positions as colored dots on the frame.""" + with self._gtfs_rt_lock: + if self._gtfs_rt_vehicles is None: + return + positions, bearings, colors = self._gtfs_rt_vehicles + + if len(positions) == 0: + return + + # Convert lon/lat to world coordinates (pixel space) + da = self.raster + y_coords = da.coords[da.dims[-2]].values + x_coords = da.coords[da.dims[-1]].values + + # lon/lat → pixel coords + px_x = (positions[:, 0] - x_coords[0]) / (x_coords[-1] - x_coords[0]) * (len(x_coords) - 1) + px_y = (positions[:, 1] - y_coords[0]) / (y_coords[-1] - y_coords[0]) * (len(y_coords) - 1) + + # World coords (match terrain mesh coordinate system) + wx = px_x * abs(self.pixel_spacing_x) + wy = px_y * abs(self.pixel_spacing_y) + + # Sample terrain Z for each vehicle (nearest neighbor) + H, W = da.shape[-2:] + ix = np.clip(np.round(px_x).astype(int), 0, W - 1) + iy = np.clip(np.round(px_y).astype(int), 0, H - 1) - ages = self._wind_ages # (N,) + terrain_np = self._wind_terrain_np + if terrain_np is None: + try: + import cupy + terrain_np = cupy.asnumpy(da.values) + except Exception: + terrain_np = np.asarray(da.values) + self._wind_terrain_np = terrain_np + + wz = terrain_np[iy, ix].astype(np.float64) * self.vertical_exaggeration + # Replace NaN with 0 + wz = np.where(np.isfinite(wz), wz, 0.0) + + # Project to screen space + world = np.stack([wx, wy, wz], axis=-1) # (N, 3) + cam_pos = np.array(self.position, dtype=np.float64) + cam_fwd = np.array(self._camera_forward(), dtype=np.float64) + cam_right = np.array(self._camera_right(), dtype=np.float64) + cam_up = np.array(self._camera_up(), dtype=np.float64) + + rel = world - cam_pos # (N, 3) + depth = rel @ cam_fwd + behind = depth <= 0.1 + depth[behind] = 1.0 # avoid division by zero + + fov_rad = np.radians(self.fov) + sh, sw = img.shape[:2] + f = sw / (2.0 * np.tan(fov_rad / 2.0)) - for t in range(self._wind_trail_len - 1, -1, -1): - # Don't draw trail points that haven't had time to separate - # (particle must be at least t ticks old for trail slot t) - trail_pts = self._wind_trails[:, t, :] - age_ok = ages > t - sx, sy, on_screen, depth = _project_points(trail_pts[:, 0], trail_pts[:, 1]) + sx = (rel @ cam_right) * f / depth + sw / 2.0 + sy = (rel @ cam_up) * f / depth + sh / 2.0 + # Flip Y (screen Y is top-down) + sy = sh - 1 - sy - mask = on_screen & age_ok - if not mask.any(): - continue + # Filter to on-screen, not behind camera + valid = (~behind) & (sx >= -10) & (sx < sw + 10) & (sy >= -10) & (sy < sh + 10) + if not valid.any(): + return - # Smooth fade-in over the first 15 ticks, fade-out over last 30 - masked_ages = ages[mask] - masked_lifetimes = self._wind_lifetimes[mask] - fade_in = np.clip(masked_ages / 15.0, 0, 1) - fade_out = np.clip((masked_lifetimes - masked_ages) / 30.0, 0, 1) - fade_in = fade_in * fade_out - - sx_m = sx[mask] - sy_m = sy[mask] - - for dx, dy, falloff in offsets: - px = sx_m + dx - py = sy_m + dy - valid = (px >= 0) & (px < sw) & (py >= 0) & (py < sh) - if not valid.any(): - continue - pxv = px[valid] - pyv = py[valid] - pixel_alpha = self._wind_alpha * falloff * fade_in[valid] - img[pyv, pxv, :] = np.clip( - img[pyv, pxv, :] + np.expand_dims(pixel_alpha, 1) * color, - 0, 1, - ) + sx = sx[valid].astype(np.int32) + sy = sy[valid].astype(np.int32) + vc = colors[valid] + r = self._gtfs_rt_dot_radius + alpha = self._gtfs_rt_alpha + + # Splat colored dots + for i in range(len(sx)): + x0 = max(0, sx[i] - r) + x1 = min(sw, sx[i] + r + 1) + y0 = max(0, sy[i] - r) + y1 = min(sh, sy[i] + r + 1) + if x0 >= x1 or y0 >= y1: + continue + # Circular mask + yy, xx = np.mgrid[y0:y1, x0:x1] + dist_sq = (xx - sx[i]) ** 2 + (yy - sy[i]) ** 2 + mask = dist_sq <= r * r + falloff = np.where(mask, 1.0 - np.sqrt(dist_sq[mask].astype(float)) / r, 0.0) + c = vc[i] + for ch in range(3): + patch = img[y0:y1, x0:x1, ch] + patch[mask] = patch[mask] * (1.0 - alpha * falloff) + c[ch] * alpha * falloff return img - def _handle_key_press(self, event): - """Handle key press - add to held keys or handle instant actions.""" - raw_key = event.key if event.key else '' - key = raw_key.lower() + def _cleanup_gtfs_rt(self): + """Stop the GTFS-RT poll thread.""" + if self._gtfs_rt_thread is not None: + self._gtfs_rt_stop.set() + self._gtfs_rt_thread.join(timeout=2.0) + self._gtfs_rt_thread = None + + def _handle_key_press(self, raw_key, key): + """Handle key press - add to held keys or handle instant actions. + + Parameters + ---------- + raw_key : str + Key with original case (uppercase if SHIFT held). + key : str + Lowercase version of the key. + """ # Drone mode cycle: Shift+O (before other keys) if raw_key == 'O': - self._cycle_drone_mode() + obs = self._observers.get(self._active_observer) if self._active_observer else None + if obs is None: + print("No observer selected. Press 1-8 first.") + else: + self._cycle_drone_mode_for(obs) return - # Snap camera to drone: Shift+V + # Snap camera to active observer: Shift+V if raw_key == 'V': - self._snap_to_drone() + obs = self._observers.get(self._active_observer) if self._active_observer else None + if obs is None: + print("No observer selected. Press 1-8 first.") + else: + self._snap_to_observer(obs) + return + + # Kill all observers: Shift+K + if raw_key == 'K': + self._clear_all_observers() return # FIRMS fire layer: Shift+F (before 'f' screenshot) @@ -1454,6 +3105,52 @@ def _handle_key_press(self, event): self._toggle_wind() return + # GTFS-RT realtime vehicle toggle: Shift+B + if raw_key == 'B': + self._toggle_gtfs_rt() + return + + # Denoiser toggle: Shift+D (before movement keys capture 'd') + if raw_key == 'D': + self.denoise_enabled = not self.denoise_enabled + self._d_ao_accum = None + self._ao_frame_count = 0 + self._prev_cam_state = None + self._prev_cam_for_flow = None + print(f"Denoiser: {'ON' if self.denoise_enabled else 'OFF'}") + self._update_frame() + return + + # GI bounces cycle: Shift+G (1 → 2 → 3 → 1) + if raw_key == 'G': + self.gi_bounces = self.gi_bounces % 3 + 1 + self._d_ao_accum = None + self._ao_frame_count = 0 + self._prev_cam_state = None + print(f"GI bounces: {self.gi_bounces}") + self._update_frame() + return + + # Drone glow toggle: Shift+L + if raw_key == 'L': + self._drone_glow = not self._drone_glow + self._apply_drone_glow() + print(f"Drone glow: {'ON' if self._drone_glow else 'OFF'}") + return + + # Time-of-day cycle: Shift+T (before 't' shadows toggle) + if raw_key == 'T': + self._time_preset_idx = (self._time_preset_idx + 1) % len(self._time_presets) + name, az, alt = self._time_presets[self._time_preset_idx] + self.sun_azimuth = az + self.sun_altitude = alt + self._d_ao_accum = None + self._ao_frame_count = 0 + self._prev_cam_state = None + print(f"Time of day: {name} (az={az:.0f}, alt={alt:.0f})") + self._update_frame() + return + # Movement/look keys are tracked as held movement_keys = {'w', 's', 'a', 'd', 'up', 'down', 'left', 'right', 'q', 'e', 'pageup', 'pagedown', 'i', 'j', 'k', 'l'} @@ -1500,10 +3197,17 @@ def _handle_key_press(self, event): self.show_minimap = not self.show_minimap self._update_frame() - # Viewshed controls + # Observer slot selection: 1-8 + elif key in ('1', '2', '3', '4', '5', '6', '7', '8'): + self._select_or_create_observer(int(key)) + + # Move active observer to camera position elif key == 'o': - if self._drone_mode == 'off': - self._place_observer() + obs = self._observers.get(self._active_observer) if self._active_observer else None + if obs is None: + print("No observer selected. Press 1-8 to create one.") + elif obs.drone_mode == 'off': + self._place_observer_at(obs) elif key == 'v': self._toggle_viewshed() elif key == '[': @@ -1533,13 +3237,14 @@ def _handle_key_press(self, event): print(f"Color stretch: {self.color_stretch}") self._update_frame() - # Toggle mesh type (tin ↔ voxel) + # Cycle mesh type (tin → voxel → heightfield → tin) elif key == 'b': - self.mesh_type = 'voxel' if self.mesh_type == 'tin' else 'tin' + cycle = {'tin': 'voxel', 'voxel': 'heightfield', 'heightfield': 'tin'} + self.mesh_type = cycle.get(self.mesh_type, 'tin') self._rebuild_vertical_exaggeration(self.vertical_exaggeration) print(f"Mesh type: {self.mesh_type}") - # Basemap cycling: U = cycle none → satellite → osm → topo → none + # Basemap cycling: U = cycle none → satellite → osm → none elif key == 'u': self._cycle_basemap() @@ -1564,32 +3269,92 @@ def _handle_key_press(self, event): if new_ve != self.vertical_exaggeration: self._rebuild_vertical_exaggeration(new_ve) + # Ambient occlusion toggle + elif key == '0': + self.ao_enabled = not self.ao_enabled + self._d_ao_accum = None + self._ao_frame_count = 0 + self._prev_cam_state = None + print(f"Ambient Occlusion: {'ON' if self.ao_enabled else 'OFF'}") + self._update_frame() + + # Depth of field toggle + elif key == '9': + self.dof_enabled = not self.dof_enabled + # Reset accumulation so DOF takes effect immediately + self._d_ao_accum = None + self._ao_frame_count = 0 + self._prev_cam_state = None + print(f"Depth of Field: {'ON' if self.dof_enabled else 'OFF'}") + self._update_frame() + + # DOF aperture: ; = decrease, ' = increase + elif key == ';': + self._dof_aperture = max(1.0, self._dof_aperture * 0.7) + self._d_ao_accum = None + self._ao_frame_count = 0 + self._prev_cam_state = None + print(f"DOF aperture: {self._dof_aperture:.1f}") + self._update_frame() + elif key == "'": + self._dof_aperture = min(200.0, self._dof_aperture * 1.4) + self._d_ao_accum = None + self._ao_frame_count = 0 + self._prev_cam_state = None + print(f"DOF aperture: {self._dof_aperture:.1f}") + self._update_frame() + + # DOF focal distance: : = decrease, " = increase + elif key == ':': + self._dof_focal_distance = max(10.0, self._dof_focal_distance * 0.7) + self._d_ao_accum = None + self._ao_frame_count = 0 + self._prev_cam_state = None + print(f"DOF focal distance: {self._dof_focal_distance:.0f}") + self._update_frame() + elif key == '"': + self._dof_focal_distance = min(10000.0, self._dof_focal_distance * 1.4) + self._d_ao_accum = None + self._ao_frame_count = 0 + self._prev_cam_state = None + print(f"DOF focal distance: {self._dof_focal_distance:.0f}") + self._update_frame() + # Exit elif key in ('escape', 'x'): self.running = False - if self._timer is not None: - self._timer.stop() - import matplotlib.pyplot as plt - plt.close(self.fig) - - def _handle_key_release(self, event): - """Handle key release - remove from held keys.""" - key = event.key.lower() if event.key else '' + + def _handle_key_release(self, key): + """Handle key release - remove from held keys. + + Parameters + ---------- + key : str + Lowercase key name. + """ self._held_keys.discard(key) - def _get_drone_front(self): - """Get forward direction for drone flight (uses drone yaw/pitch).""" - yaw_rad = np.radians(self._drone_yaw) - pitch_rad = np.radians(self._drone_pitch) + @staticmethod + def _get_drone_front_for(obs): + """Get forward direction for drone flight (uses observer yaw/pitch).""" + yaw_rad = np.radians(obs.yaw) + pitch_rad = np.radians(obs.pitch) return np.array([ np.cos(yaw_rad) * np.cos(pitch_rad), np.sin(yaw_rad) * np.cos(pitch_rad), np.sin(pitch_rad) ], dtype=np.float32) - def _get_drone_right(self): + @staticmethod + def _get_drone_right_for(obs): """Get right direction for drone flight.""" - front = self._get_drone_front() + yaw_rad = np.radians(obs.yaw) + pitch_rad = np.radians(obs.pitch) + front = np.array([ + np.cos(yaw_rad) * np.cos(pitch_rad), + np.sin(yaw_rad) * np.cos(pitch_rad), + np.sin(pitch_rad) + ], dtype=np.float32) world_up = np.array([0, 0, 1], dtype=np.float32) right = np.cross(world_up, front) return right / (np.linalg.norm(right) + 1e-8) @@ -1606,22 +3371,22 @@ def _clamp_drone_pos(self, pos): pos[2] = terrain_z return pos - def _sync_drone_from_pos(self, pos): - """Update observer position and drone mesh from a 3D position.""" + def _sync_drone_from_pos_for(self, obs, pos): + """Update an observer's position and drone mesh from a 3D position.""" pos = self._clamp_drone_pos(pos) - self._observer_position = (float(pos[0]), float(pos[1])) - self.viewshed_observer_elev = float(pos[2]) - self._get_terrain_z( + obs.position = (float(pos[0]), float(pos[1])) + obs.observer_elev = float(pos[2]) - self._get_terrain_z( pos[0], pos[1]) - if self.viewshed_observer_elev < 0: - self.viewshed_observer_elev = 0.0 - self._update_observer_drone() + if obs.observer_elev < 0: + obs.observer_elev = 0.0 + self._update_observer_drone_for(obs) # Dynamically recalculate viewshed as the drone moves (throttled) - if self.viewshed_enabled: + if obs.viewshed_enabled: now = time.monotonic() if now - self._last_viewshed_time >= self._viewshed_recalc_interval: self._last_viewshed_time = now - self._viewshed_cache = None + obs.viewshed_cache = None self._calculate_viewshed(quiet=True) def _check_terrain_reload(self): @@ -1673,6 +3438,7 @@ def _check_terrain_reload(self): # Replace rasters self._base_raster = new_raster self.raster = new_raster + self._wind_terrain_np = None # invalidate cached terrain # Update coordinate tracking self._coord_origin_x = new_origin_x @@ -1690,6 +3456,19 @@ def _check_terrain_reload(self): else: terrain_np = np.asarray(terrain_data) + # Detect ocean-fill (0-valued pixels) and replace with NaN + ocean_fill = (terrain_np == 0.0) & ~np.isnan(terrain_np) + if ocean_fill.sum() > terrain_np.size * 0.01: + terrain_np[ocean_fill] = np.nan + if hasattr(new_raster.data, 'get'): + new_data = new_raster.data.copy() + new_data[cp.asarray(ocean_fill)] = cp.nan + else: + new_data = new_raster.data.copy() + new_data[ocean_fill] = np.nan + self._base_raster = new_raster.copy(data=new_data) + self.raster = self._base_raster + ve = self.vertical_exaggeration self.elev_min = float(np.nanmin(terrain_np)) * ve self.elev_max = float(np.nanmax(terrain_np)) * ve @@ -1714,40 +3493,51 @@ def _check_terrain_reload(self): from . import mesh as mesh_mod H, W = new_H, new_W - if self.mesh_type == 'voxel': - num_verts = H * W * 8 - num_tris = H * W * 12 - vertices = np.zeros(num_verts * 3, dtype=np.float32) - indices = np.zeros(num_tris * 3, dtype=np.int32) - base_elev = float(np.nanmin(terrain_np)) - mesh_mod.voxelate_terrain(vertices, indices, new_raster, scale=1.0, - base_elevation=base_elev) - else: - num_verts = H * W - num_tris = (H - 1) * (W - 1) * 2 - vertices = np.zeros(num_verts * 3, dtype=np.float32) - indices = np.zeros(num_tris * 3, dtype=np.int32) - mesh_mod.triangulate_terrain(vertices, indices, new_raster, scale=1.0) - - # Scale x,y to world units - if self.pixel_spacing_x != 1.0 or self.pixel_spacing_y != 1.0: - vertices[0::3] *= self.pixel_spacing_x - vertices[1::3] *= self.pixel_spacing_y - - # Apply vertical exaggeration - if ve != 1.0: - vertices[2::3] *= ve - - # Cache the new mesh cache_key = (self.subsample_factor, self.mesh_type) - base_verts = vertices.copy() - if ve != 1.0: - base_verts[2::3] /= ve - self._terrain_mesh_cache[cache_key] = (base_verts, indices.copy(), terrain_np.copy()) - # Replace terrain geometry - if self.rtx is not None: - self.rtx.add_geometry('terrain', vertices, indices) + if self.mesh_type == 'heightfield': + if self.rtx is not None: + self.rtx.add_heightfield_geometry( + 'terrain', terrain_np, H, W, + spacing_x=self.pixel_spacing_x, + spacing_y=self.pixel_spacing_y, + ve=ve, + ) + self._terrain_mesh_cache[cache_key] = (None, None, terrain_np.copy()) + else: + if self.mesh_type == 'voxel': + num_verts = H * W * 8 + num_tris = H * W * 12 + vertices = np.zeros(num_verts * 3, dtype=np.float32) + indices = np.zeros(num_tris * 3, dtype=np.int32) + base_elev = float(np.nanmin(terrain_np)) + mesh_mod.voxelate_terrain(vertices, indices, new_raster, scale=1.0, + base_elevation=base_elev) + else: + num_verts = H * W + num_tris = (H - 1) * (W - 1) * 2 + vertices = np.zeros(num_verts * 3, dtype=np.float32) + indices = np.zeros(num_tris * 3, dtype=np.int32) + mesh_mod.triangulate_terrain(vertices, indices, new_raster, scale=1.0) + + # Scale x,y to world units + if self.pixel_spacing_x != 1.0 or self.pixel_spacing_y != 1.0: + vertices[0::3] *= self.pixel_spacing_x + vertices[1::3] *= self.pixel_spacing_y + + # Apply vertical exaggeration + if ve != 1.0: + vertices[2::3] *= ve + + # Cache the new mesh + base_verts = vertices.copy() + if ve != 1.0: + base_verts[2::3] /= ve + self._terrain_mesh_cache[cache_key] = (base_verts, indices.copy(), terrain_np.copy()) + + # Replace terrain geometry + if self.rtx is not None: + self.rtx.add_geometry('terrain', vertices, indices) # Reposition camera in new window self.position = np.array([ @@ -1758,10 +3548,9 @@ def _check_terrain_reload(self): # Refresh minimap self._compute_minimap_background() - if self._minimap_im is not None: - self._minimap_im.set_data(self._minimap_background) self._last_reload_time = time.time() + self._render_needed = True print(f"Terrain reloaded: center ({cam_lon:.4f}, {cam_lat:.4f}), " f"window {new_W}x{new_H}") @@ -1770,43 +3559,58 @@ def _tick(self): if not self.running: return + # Delta-time: scale movement relative to the old 20 Hz reference rate + now = time.monotonic() + dt = now - self._last_tick_time + self._last_tick_time = now + # Clamp to avoid huge jumps (e.g. after a stall or first frame) + dt = min(dt, 0.1) + dt_scale = dt / 0.05 # 0.05 = 1/20 Hz reference + # Process held movement / look keys if self._held_keys: - if self._drone_mode == '3rd' and self._observer_drone_placed: + speed = self.move_speed * dt_scale + look = self.look_speed * dt_scale + + # Get active observer (if any) + active_obs = (self._observers.get(self._active_observer) + if self._active_observer else None) + + if (active_obs is not None and active_obs.drone_mode == '3rd' + and active_obs.drone_placed): # --- 3rd-person: WASD/IJKL fly the drone, camera stays --- - front = self._get_drone_front() - right = self._get_drone_right() + front = self._get_drone_front_for(active_obs) + right = self._get_drone_right_for(active_obs) - # Current drone 3D position - obs_x, obs_y = self._observer_position + obs_x, obs_y = active_obs.position terrain_z = self._get_terrain_z(obs_x, obs_y) drone_pos = np.array([obs_x, obs_y, - terrain_z + self.viewshed_observer_elev], + terrain_z + active_obs.observer_elev], dtype=float) if 'w' in self._held_keys or 'up' in self._held_keys: - drone_pos += front * self.move_speed + drone_pos += front * speed if 's' in self._held_keys or 'down' in self._held_keys: - drone_pos -= front * self.move_speed + drone_pos -= front * speed if 'a' in self._held_keys or 'left' in self._held_keys: - drone_pos -= right * self.move_speed + drone_pos -= right * speed if 'd' in self._held_keys or 'right' in self._held_keys: - drone_pos += right * self.move_speed + drone_pos += right * speed if 'q' in self._held_keys or 'pageup' in self._held_keys: - drone_pos[2] += self.move_speed + drone_pos[2] += speed if 'e' in self._held_keys or 'pagedown' in self._held_keys: - drone_pos[2] -= self.move_speed + drone_pos[2] -= speed if 'i' in self._held_keys: - self._drone_pitch = min(89, self._drone_pitch + self.look_speed) + active_obs.pitch = min(89, active_obs.pitch + look) if 'k' in self._held_keys: - self._drone_pitch = max(-89, self._drone_pitch - self.look_speed) + active_obs.pitch = max(-89, active_obs.pitch - look) if 'j' in self._held_keys: - self._drone_yaw -= self.look_speed + active_obs.yaw -= look if 'l' in self._held_keys: - self._drone_yaw += self.look_speed + active_obs.yaw += look - self._sync_drone_from_pos(drone_pos) + self._sync_drone_from_pos_for(active_obs, drone_pos) else: # --- Normal / FPV: WASD moves camera --- @@ -1814,37 +3618,58 @@ def _tick(self): right = self._get_right() if 'w' in self._held_keys or 'up' in self._held_keys: - self.position += front * self.move_speed + self.position += front * speed if 's' in self._held_keys or 'down' in self._held_keys: - self.position -= front * self.move_speed + self.position -= front * speed if 'a' in self._held_keys or 'left' in self._held_keys: - self.position -= right * self.move_speed + self.position -= right * speed if 'd' in self._held_keys or 'right' in self._held_keys: - self.position += right * self.move_speed + self.position += right * speed if 'q' in self._held_keys or 'pageup' in self._held_keys: cam_up = np.cross(front, right) cam_up /= (np.linalg.norm(cam_up) + 1e-8) - self.position += cam_up * self.move_speed + self.position += cam_up * speed if 'e' in self._held_keys or 'pagedown' in self._held_keys: cam_up = np.cross(front, right) cam_up /= (np.linalg.norm(cam_up) + 1e-8) - self.position -= cam_up * self.move_speed + self.position -= cam_up * speed if 'i' in self._held_keys: - self.pitch = min(89, self.pitch + self.look_speed) + self.pitch = min(89, self.pitch + look) if 'k' in self._held_keys: - self.pitch = max(-89, self.pitch - self.look_speed) + self.pitch = max(-89, self.pitch - look) if 'j' in self._held_keys: - self.yaw -= self.look_speed + self.yaw -= look if 'l' in self._held_keys: - self.yaw += self.look_speed + self.yaw += look # In FPV, sync drone to camera - if self._drone_mode == 'fpv' and self._observer_drone_placed: - self._sync_drone_from_pos(self.position) + if (active_obs is not None and active_obs.drone_mode == 'fpv' + and active_obs.drone_placed): + self._sync_drone_from_pos_for(active_obs, self.position) + + self._render_needed = True + self._dt_scale = dt_scale self._check_terrain_reload() - self._update_frame() + if self._chunk_manager is not None: + if self._chunk_manager.update(self.position[0], self.position[1], self): + self._geometry_colors_builder = self._accessor._build_geometry_colors_gpu + self._render_needed = True + # AO: keep accumulating samples when camera is stationary + if (self.ao_enabled and not self._held_keys + and not self._mouse_dragging + and self._ao_frame_count < self._ao_max_frames): + self._render_needed = True + + if self._render_needed: + self._update_frame() + self._render_needed = False + elif self._wind_enabled and self._wind_particles is not None and self._pinned_frame is not None: + # Wind is on but camera didn't move — skip the expensive ray + # trace and just re-advect particles + re-composite overlays. + self._update_wind_particles() + self._composite_overlays() def _cycle_terrain_layer(self): """Cycle terrain color: elevation → overlay1 → overlay2 → ... → elevation. @@ -1871,7 +3696,7 @@ def _cycle_terrain_layer(self): self._update_frame() def _cycle_basemap(self): - """Cycle basemap: none → satellite → osm → topo → none. + """Cycle basemap: none → satellite → osm → none. Auto-creates XYZTileService on-the-fly if needed. """ @@ -2094,129 +3919,150 @@ def _load_drone_parts(self): parts.append((verts.flatten(), faces.flatten(), color)) return parts - def _update_observer_drone(self): - """Place or update the drone mesh at the observer position.""" - if self._observer_position is None or self.rtx is None: + def _update_observer_drone_for(self, obs): + """Place or update the drone mesh at an observer's position.""" + if obs.position is None or self.rtx is None: return from .mesh import make_transform - # Lazy-load drone parts once - if self._observer_drone_parts is None: - self._observer_drone_parts = self._load_drone_parts() - if not self._observer_drone_parts: + # Lazy-load drone parts once (shared across all observers) + if self._shared_drone_parts is None: + self._shared_drone_parts = self._load_drone_parts() + if not self._shared_drone_parts: return - obs_x, obs_y = self._observer_position + obs_x, obs_y = obs.position terrain_z = self._get_terrain_z(obs_x, obs_y) - obs_z = terrain_z + self.viewshed_observer_elev + obs_z = terrain_z + obs.observer_elev # Scale drone to ~0.05× pixel_spacing so it's visible but not huge drone_scale = 0.0125 * max(self.pixel_spacing_x, self.pixel_spacing_y) transform = make_transform(x=obs_x, y=obs_y, z=obs_z, scale=drone_scale) - for i, (verts, idxs, color) in enumerate(self._observer_drone_parts): - gid = f'_observer_{i}' - if self._observer_drone_placed: + # Tint base colors toward observer slot color + slot_color = obs.color + for i, (verts, idxs, base_color) in enumerate(self._shared_drone_parts): + gid = obs.geometry_id(i) + # Mix: 50% base + 50% slot tint + tinted = tuple(0.5 * base_color[c] + 0.5 * slot_color[c] + for c in range(3)) + if obs.drone_placed: self.rtx.update_transform(gid, transform) else: self.rtx.add_geometry(gid, verts, idxs, transform=transform) # Set geometry colors (needs the accessor's color dict) - if not self._observer_drone_placed: + if not obs.drone_placed: builder = getattr(self, '_geometry_colors_builder', None) if builder is not None: - # Access the accessor's _geometry_colors dict via the builder's __self__ acc = getattr(builder, '__self__', None) if acc is not None and hasattr(acc, '_geometry_colors'): - for i, (_, _, color) in enumerate(self._observer_drone_parts): - acc._geometry_colors[f'_observer_{i}'] = color + for i, (_, _, base_color) in enumerate(self._shared_drone_parts): + if self._drone_glow: + color = (*slot_color, 1.8) + else: + color = tuple(0.5 * base_color[c] + 0.5 * slot_color[c] + for c in range(3)) + acc._geometry_colors[obs.geometry_id(i)] = color acc._geometry_colors_dirty = True - self._observer_drone_placed = True + obs.drone_placed = True - def _set_drone_visibility(self, visible): - """Show or hide all drone sub-mesh geometries.""" - if self._observer_drone_placed and self.rtx is not None: - for i in range(len(self._observer_drone_parts or [])): - self.rtx.set_geometry_visible(f'_observer_{i}', visible) + def _apply_drone_glow(self): + """Toggle emissive glow on/off for all placed drone geometries.""" + builder = getattr(self, '_geometry_colors_builder', None) + if builder is None: + return + acc = getattr(builder, '__self__', None) + if acc is None or not hasattr(acc, '_geometry_colors'): + return + parts = self._shared_drone_parts + if not parts: + return + changed = False + for obs in self._observers.values(): + if not obs.drone_placed: + continue + slot_color = obs.color + for i, (_, _, base_color) in enumerate(parts): + gid = obs.geometry_id(i) + if self._drone_glow: + acc._geometry_colors[gid] = (*slot_color, 1.8) + else: + acc._geometry_colors[gid] = tuple( + 0.5 * base_color[c] + 0.5 * slot_color[c] + for c in range(3)) + changed = True + if changed: + acc._geometry_colors_dirty = True + self._update_frame() - def _cycle_drone_mode(self): - """Cycle drone control mode: off → 3rd person → FPV → off (Shift+O). + def _set_drone_visibility_for(self, obs, visible): + """Show or hide all drone sub-mesh geometries for an observer.""" + if obs.drone_placed and self.rtx is not None: + for i in range(len(self._shared_drone_parts or [])): + self.rtx.set_geometry_visible(obs.geometry_id(i), visible) - off: - Normal camera control. Drone is visible at observer position. - 3rd person: - Camera stays fixed. WASD/IJKL fly the drone. Watch it move. - FPV: - Camera = drone. WASD flies both. Drone mesh hidden. - """ - if self._observer_position is None: - print("No observer placed. Press O first.") + def _cycle_drone_mode_for(self, obs): + """Cycle drone mode for observer: off → 3rd person → FPV → off.""" + if obs.position is None: + print(f"Observer {obs.slot} has no position.") return - if self._drone_mode == 'off': + if obs.drone_mode == 'off': # --- Enter 3rd person --- - # Save camera so we can restore on full exit - self._saved_camera = ( + obs.saved_camera = ( self.position.copy(), float(self.yaw), float(self.pitch), ) - # Initialise drone heading from camera yaw - self._drone_yaw = float(self.yaw) - self._drone_pitch = 0.0 - self._drone_mode = '3rd' - print("Drone 3RD PERSON: ON (WASD flies drone, Shift+O → FPV)") + obs.yaw = float(self.yaw) + obs.pitch = 0.0 + obs.drone_mode = '3rd' + print(f"Observer {obs.slot} DRONE 3RD PERSON: ON") - elif self._drone_mode == '3rd': + elif obs.drone_mode == '3rd': # --- 3rd person → FPV --- - obs_x, obs_y = self._observer_position + obs_x, obs_y = obs.position terrain_z = self._get_terrain_z(obs_x, obs_y) - obs_z = terrain_z + self.viewshed_observer_elev + obs_z = terrain_z + obs.observer_elev self.position = np.array([obs_x, obs_y, obs_z], dtype=float) - self.yaw = self._drone_yaw - self.pitch = self._drone_pitch - # Hide drone mesh (you are the drone) - self._set_drone_visibility(False) - self._drone_mode = 'fpv' - print("Drone FPV: ON (WASD flies camera+drone, Shift+O → exit)") + self.yaw = obs.yaw + self.pitch = obs.pitch + self._set_drone_visibility_for(obs, False) + obs.drone_mode = 'fpv' + print(f"Observer {obs.slot} DRONE FPV: ON") else: # --- FPV → off --- - # Sync final drone position from camera - self._sync_drone_from_pos(self.position) - # Show drone mesh again - self._set_drone_visibility(True) - # Restore saved external camera - if self._saved_camera is not None: - self.position = self._saved_camera[0] - self.yaw = self._saved_camera[1] - self.pitch = self._saved_camera[2] - self._saved_camera = None - self._drone_mode = 'off' - print("Drone mode: OFF") + self._sync_drone_from_pos_for(obs, self.position) + self._set_drone_visibility_for(obs, True) + if obs.saved_camera is not None: + self.position = obs.saved_camera[0] + self.yaw = obs.saved_camera[1] + self.pitch = obs.saved_camera[2] + obs.saved_camera = None + obs.drone_mode = 'off' + print(f"Observer {obs.slot} DRONE: OFF") self._update_frame() - def _snap_to_drone(self): - """Snap external camera to look at the drone from nearby (Shift+V).""" - if self._observer_position is None: - print("No observer placed. Press O first.") + def _snap_to_observer(self, obs): + """Snap external camera to look at an observer's drone from nearby.""" + if obs.position is None: + print(f"Observer {obs.slot} has no position.") return - if self._drone_mode == 'fpv': - # Already in FPV — nothing to snap to + if obs.drone_mode == 'fpv': return - obs_x, obs_y = self._observer_position + obs_x, obs_y = obs.position terrain_z = self._get_terrain_z(obs_x, obs_y) - obs_z = terrain_z + self.viewshed_observer_elev + obs_z = terrain_z + obs.observer_elev - # Place camera a short distance behind and above the drone spacing = max(self.pixel_spacing_x, self.pixel_spacing_y) - offset = spacing * 8.0 # 8 pixels back - # Direction from drone to current camera (keep the viewing angle) + offset = spacing * 8.0 dx = self.position[0] - obs_x dy = self.position[1] - obs_y dist_xy = np.sqrt(dx * dx + dy * dy) @@ -2224,16 +4070,14 @@ def _snap_to_drone(self): dx /= dist_xy dy /= dist_xy else: - # Camera is right on top of drone — pick an arbitrary direction dx, dy = 1.0, 0.0 self.position = np.array([ obs_x + dx * offset, obs_y + dy * offset, - obs_z + spacing * 3.0, # a bit above + obs_z + spacing * 3.0, ], dtype=float) - # Point camera at the drone to_drone = np.array([obs_x - self.position[0], obs_y - self.position[1], obs_z - self.position[2]]) @@ -2241,48 +4085,195 @@ def _snap_to_drone(self): self.yaw = float(np.degrees(np.arctan2(to_drone[1], to_drone[0]))) self.pitch = float(np.degrees(np.arcsin(np.clip(to_drone[2], -1, 1)))) - print(f"Snapped to drone at ({obs_x:.0f}, {obs_y:.0f})") + print(f"Snapped to observer {obs.slot} at ({obs_x:.0f}, {obs_y:.0f})") self._update_frame() - def _place_observer(self): - """Place a viewshed observer at the current camera position on terrain. + def _place_observer_at(self, obs, x=None, y=None): + """Move an observer to a position (defaults to camera XY). - The observer is placed at the camera's x,y location, projected onto - the terrain surface. This becomes the fixed point for viewshed analysis. - Observer position is stored in world coordinates (same as camera). + Parameters + ---------- + obs : Observer + The observer to position. + x, y : float, optional + World coordinates. If None, use current camera position. """ H, W = self.terrain_shape + cam_x = x if x is not None else self.position[0] + cam_y = y if y is not None else self.position[1] - # Use camera position (in world coordinates if pixel_spacing != 1.0) - cam_x = self.position[0] - cam_y = self.position[1] - - # Compute terrain bounds in world coordinates max_x = (W - 1) * self.pixel_spacing_x max_y = (H - 1) * self.pixel_spacing_y - # Clamp to terrain bounds (in world coordinates) obs_x = float(np.clip(cam_x, 0, max_x)) obs_y = float(np.clip(cam_y, 0, max_y)) - # Store in world coordinates - self._observer_position = (obs_x, obs_y) + obs.position = (obs_x, obs_y) + self._update_observer_drone_for(obs) - # Place/update drone mesh at observer position - self._update_observer_drone() + print(f"Observer {obs.slot} placed at ({obs_x:.0f}, {obs_y:.0f})") - # Also compute pixel indices for display - px_x = int(obs_x / self.pixel_spacing_x) - px_y = int(obs_y / self.pixel_spacing_y) + if obs.viewshed_enabled: + self._calculate_viewshed(quiet=True) - print(f"Observer placed at world ({obs_x:.0f}, {obs_y:.0f}), pixel ({px_x}, {px_y})") - print(f" Height: {self.viewshed_observer_elev:.3f} above terrain") - print(f" Press V to toggle viewshed, [/] to adjust height") + self._update_frame() - # If viewshed is already enabled, recalculate - if self.viewshed_enabled: - self._calculate_viewshed() + def _select_or_create_observer(self, slot): + """Handle number key 1-8: select/create/deselect observer slot.""" + if self._active_observer == slot: + # Deselect — exit FPV first if active + obs = self._observers.get(slot) + if obs is not None and obs.drone_mode == 'fpv': + self._exit_fpv_for(obs) + self._active_observer = None + self.viewshed_enabled = False + self._viewshed_cache = None + print(f"Observer {slot}: deselected") + self._update_frame() + return + + # If switching away from an FPV observer, exit FPV first + if self._active_observer is not None: + prev_obs = self._observers.get(self._active_observer) + if prev_obs is not None and prev_obs.drone_mode == 'fpv': + self._exit_fpv_for(prev_obs) + + if slot in self._observers: + # Select existing — auto-enter FPV + self._active_observer = slot + obs = self._observers[slot] + # Sync viewer-level viewshed from this observer + self.viewshed_enabled = obs.viewshed_enabled + self._viewshed_cache = obs.viewshed_cache + # Enter FPV: save camera, snap to observer, hide drone + obs.saved_camera = ( + self.position.copy(), + float(self.yaw), + float(self.pitch), + ) + obs_x, obs_y = obs.position + terrain_z = self._get_terrain_z(obs_x, obs_y) + obs_z = terrain_z + obs.observer_elev + self.position = np.array([obs_x, obs_y, obs_z], dtype=float) + self.yaw = obs.yaw + self.pitch = obs.pitch + self._set_drone_visibility_for(obs, False) + obs.drone_mode = 'fpv' + print(f"Observer {slot}: FPV at ({obs.position[0]:.0f}, {obs.position[1]:.0f})") + else: + # Create new just in front of camera, matching altitude and angle + front = self._get_front() + spacing = max(self.pixel_spacing_x, self.pixel_spacing_y) + offset = spacing * 3 # A few pixels in front + obs_x = self.position[0] + front[0] * offset + obs_y = self.position[1] + front[1] * offset + # Clamp to terrain bounds + H, W = self.terrain_shape + obs_x = float(np.clip(obs_x, 0, (W - 1) * self.pixel_spacing_x)) + obs_y = float(np.clip(obs_y, 0, (H - 1) * self.pixel_spacing_y)) + terrain_z = self._get_terrain_z(obs_x, obs_y) + cam_elev = max(0.0, self.position[2] - terrain_z) + obs = Observer(slot, position=(obs_x, obs_y), + observer_elev=cam_elev) + obs.yaw = self.yaw + obs.pitch = self.pitch + self._observers[slot] = obs + self._active_observer = slot + self._update_observer_drone_for(obs) + print(f"Observer {slot} placed at ({obs_x:.0f}, {obs_y:.0f}), " + f"h={cam_elev:.3f}, yaw={self.yaw:.0f}, pitch={self.pitch:.0f}") + if obs.viewshed_enabled: + self._calculate_viewshed(quiet=True) + self._update_frame() + return + + self._update_frame() + + def _exit_fpv_for(self, obs): + """Exit FPV mode for an observer, restoring camera.""" + if obs.drone_mode != 'fpv': + return + self._sync_drone_from_pos_for(obs, self.position) + self._set_drone_visibility_for(obs, True) + if obs.saved_camera is not None: + self.position = obs.saved_camera[0] + self.yaw = obs.saved_camera[1] + self.pitch = obs.saved_camera[2] + obs.saved_camera = None + obs.drone_mode = 'off' + + def _clear_observer_slot(self, slot): + """Remove a single observer and its geometry.""" + obs = self._observers.get(slot) + if obs is None: + return + + # Stop tour if running + obs.stop_tour() + + # Exit drone mode (restore camera if FPV) + if obs.drone_mode != 'off': + if obs.drone_mode == 'fpv': + self._set_drone_visibility_for(obs, True) + if obs.saved_camera is not None: + self.position = obs.saved_camera[0] + self.yaw = obs.saved_camera[1] + self.pitch = obs.saved_camera[2] + obs.saved_camera = None + obs.drone_mode = 'off' + + # Remove drone geometry + if obs.drone_placed and self.rtx is not None: + n = len(self._shared_drone_parts) if self._shared_drone_parts else 0 + builder = getattr(self, '_geometry_colors_builder', None) + acc = getattr(builder, '__self__', None) if builder else None + for i in range(n): + gid = obs.geometry_id(i) + self.rtx.remove_geometry(gid) + if acc is not None and hasattr(acc, '_geometry_colors'): + acc._geometry_colors.pop(gid, None) + if acc is not None and hasattr(acc, '_geometry_colors_dirty'): + acc._geometry_colors_dirty = True + obs.drone_placed = False + + del self._observers[slot] + if self._active_observer == slot: + self._active_observer = None + + print(f"Observer {slot} removed") + self._update_frame() + + def _clear_all_observers(self): + """Kill all observers — stop tours, exit drone modes, remove geometry.""" + # Find if any observer is in FPV and restore camera + for obs in self._observers.values(): + if obs.drone_mode == 'fpv' and obs.saved_camera is not None: + self.position = obs.saved_camera[0] + self.yaw = obs.saved_camera[1] + self.pitch = obs.saved_camera[2] + break # Only one can be in FPV at a time + + for slot in list(self._observers.keys()): + obs = self._observers[slot] + obs.stop_tour() + # Remove drone geometry + if obs.drone_placed and self.rtx is not None: + n = len(self._shared_drone_parts) if self._shared_drone_parts else 0 + builder = getattr(self, '_geometry_colors_builder', None) + acc = getattr(builder, '__self__', None) if builder else None + for i in range(n): + gid = obs.geometry_id(i) + self.rtx.remove_geometry(gid) + if acc is not None and hasattr(acc, '_geometry_colors'): + acc._geometry_colors.pop(gid, None) + if acc is not None and hasattr(acc, '_geometry_colors_dirty'): + acc._geometry_colors_dirty = True + self._observers.clear() + self._active_observer = None + self.viewshed_enabled = False + self._viewshed_cache = None + print("All observers removed") self._update_frame() def _calculate_viewshed(self, quiet=False): @@ -2299,13 +4290,20 @@ def _calculate_viewshed(self, quiet=False): """ from .analysis.viewshed import _viewshed_rt - if self._observer_position is None: + # Get observer position: from _calculate_viewshed_for compat bridge, + # or from the active observer + obs_pos = getattr(self, '_observer_position_compat', None) + if obs_pos is None: + # Try active observer + obs = self._observers.get(self._active_observer) if self._active_observer else None + if obs is not None: + obs_pos = obs.position + if obs_pos is None: if not quiet: - print("No observer placed. Press O to place an observer first.") + print("No observer placed. Press 1-8 to create one.") return None - # Observer position is in world coordinates - world_x, world_y = self._observer_position + world_x, world_y = obs_pos H, W = self.terrain_shape # Convert world coords to pixel indices @@ -2349,12 +4347,13 @@ def _calculate_viewshed(self, quiet=False): elif not entry.visible: rtx.set_geometry_visible(geom_id, True) - # Always hide the drone so it doesn't block its own viewshed - if self._observer_drone_placed and self._observer_drone_parts: - for i in range(len(self._observer_drone_parts)): - gid = f'_observer_{i}' - saved_visibility[gid] = True - rtx.set_geometry_visible(gid, False) + # Hide all observer drones so they don't block viewshed + for obs in self._observers.values(): + if obs.drone_placed and self._shared_drone_parts: + for i in range(len(self._shared_drone_parts)): + gid = obs.geometry_id(i) + saved_visibility[gid] = True + rtx.set_geometry_visible(gid, False) def _enable_structures(): """Callback: make structures visible for occlusion trace.""" @@ -2488,81 +4487,69 @@ def _apply_viewshed_overlay(self, img): return result.astype(np.uint8) def _toggle_viewshed(self): - """Toggle viewshed overlay on/off.""" - if self._observer_position is None: - print("No observer placed. Press O to place an observer first.") + """Toggle viewshed overlay on/off for the active observer.""" + obs = self._observers.get(self._active_observer) if self._active_observer else None + if obs is None: + print("No observer selected. Press 1-8 to select/create one.") return - self.viewshed_enabled = not self.viewshed_enabled + obs.viewshed_enabled = not obs.viewshed_enabled - if self.viewshed_enabled: + if obs.viewshed_enabled: print("Calculating viewshed...") - viewshed = self._calculate_viewshed() + # Temporarily set position/elev for _calculate_viewshed + viewshed = self._calculate_viewshed_for(obs) if viewshed is None: - self.viewshed_enabled = False + obs.viewshed_enabled = False print("Viewshed: OFF (calculation failed)") else: + self.viewshed_enabled = True + self._viewshed_cache = obs.viewshed_cache print(f"Viewshed: ON ({self._viewshed_coverage:.1f}% coverage)") - # Debug: verify viewshed cache - if self._viewshed_cache is not None: - print(f" Viewshed cache shape: {self._viewshed_cache.shape}") - else: - print(" WARNING: Viewshed cache is None!") else: print("Viewshed: OFF") + self.viewshed_enabled = False + self._viewshed_cache = None self._update_frame() - def _clear_observer(self): - """Clear the placed observer and viewshed.""" - # Exit drone mode if active (restore external camera) - if self._drone_mode != 'off': - if self._drone_mode == 'fpv': - self._set_drone_visibility(True) - if self._saved_camera is not None: - self.position = self._saved_camera[0] - self.yaw = self._saved_camera[1] - self.pitch = self._saved_camera[2] - self._saved_camera = None - self._drone_mode = 'off' - - self._observer_position = None - self._viewshed_cache = None - self.viewshed_enabled = False - - # Remove all drone sub-mesh geometries from scene - if self._observer_drone_placed and self.rtx is not None: - n = len(self._observer_drone_parts) if self._observer_drone_parts else 0 - builder = getattr(self, '_geometry_colors_builder', None) - acc = getattr(builder, '__self__', None) if builder else None - for i in range(n): - gid = f'_observer_{i}' - self.rtx.remove_geometry(gid) - if acc is not None and hasattr(acc, '_geometry_colors'): - acc._geometry_colors.pop(gid, None) - if acc is not None and hasattr(acc, '_geometry_colors_dirty'): - acc._geometry_colors_dirty = True - self._observer_drone_placed = False - - print("Observer cleared") - self._update_frame() + def _calculate_viewshed_for(self, obs, quiet=False): + """Calculate viewshed using an observer's position/elevation.""" + # Temporarily bridge to existing _calculate_viewshed by setting compat state + old_pos = getattr(self, '_observer_position_compat', None) + old_elev = self.viewshed_observer_elev + self._observer_position_compat = obs.position + self.viewshed_observer_elev = obs.observer_elev + result = self._calculate_viewshed(quiet=quiet) + obs.viewshed_cache = self._viewshed_cache + self._observer_position_compat = old_pos + self.viewshed_observer_elev = old_elev + return result def _adjust_observer_elevation(self, delta): - """Adjust observer elevation for viewshed calculation.""" - self.viewshed_observer_elev = max(0, self.viewshed_observer_elev + delta) - print(f"Observer height: {self.viewshed_observer_elev:.3f}") + """Adjust active observer's elevation.""" + obs = self._observers.get(self._active_observer) if self._active_observer else None + if obs is None: + print("No observer selected. Press 1-8 first.") + return + + obs.observer_elev = max(0, obs.observer_elev + delta) + print(f"Observer {obs.slot} height: {obs.observer_elev:.3f}") - # Update drone position to match new elevation - self._update_observer_drone() + self._update_observer_drone_for(obs) - # Clear cache and recalculate viewshed if enabled - if self.viewshed_enabled and self._observer_position is not None: - self._viewshed_cache = None # Clear cache to force recalculation - self._calculate_viewshed() + if obs.viewshed_enabled: + obs.viewshed_cache = None + self._calculate_viewshed_for(obs) + self._viewshed_cache = obs.viewshed_cache self._update_frame() def _save_screenshot(self): - """Save current view as PNG image.""" + """Save current view as PNG image. + + When AO is enabled, renders multiple accumulated frames for + high-quality output with smooth AA, soft shadows, AO, and DOF. + """ import datetime timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"rtxpy_screenshot_{timestamp}.png" @@ -2570,8 +4557,10 @@ def _save_screenshot(self): # Pass viewshed data directly to render if enabled viewshed_data = None observer_pos = None - if self._observer_position is not None: - observer_pos = self._observer_position + active_obs = (self._observers.get(self._active_observer) + if self._active_observer else None) + if active_obs is not None and active_obs.position is not None: + observer_pos = active_obs.position if self.viewshed_enabled and self._viewshed_cache is not None: viewshed_data = self._viewshed_cache @@ -2589,10 +4578,10 @@ def _save_screenshot(self): if builder is not None: geometry_colors = builder() - # Render at full resolution for screenshot from .analysis import render as render_func - img = render_func( - self.raster, + + # Common render kwargs + render_kwargs = dict( camera_position=tuple(self.position), look_at=tuple(self._get_look_at()), fov=self.fov, @@ -2617,16 +4606,63 @@ def _save_screenshot(self): geometry_colors=geometry_colors, ) + # Accumulated multi-frame screenshot when AO is enabled + num_frames = 64 if self.ao_enabled else 1 + + if num_frames > 1: + import cupy + from .analysis.render import _bloom, _tone_map_aces, _render_buffers + print(f"Rendering {num_frames} accumulated frames...", end='', flush=True) + + # DOF params + if self.dof_enabled: + dof_aperture = self._dof_aperture + dof_focal = self._dof_focal_distance + else: + dof_aperture = 0.0 + dof_focal = 0.0 + + d_accum = None + for i in range(num_frames): + frame_seed = i + 1 + d_frame = render_func( + self.raster, + **render_kwargs, + ao_samples=self._ao_samples_per_frame, + ao_radius=self.ao_radius, + ao_seed=i, + gi_intensity=self.gi_intensity, + gi_bounces=self.gi_bounces, + frame_seed=frame_seed, + sun_angle=1.5, + aperture=dof_aperture, + focal_distance=dof_focal, + bloom=False, + tone_map=False, + _return_gpu=True, + ) + if d_accum is None: + d_accum = d_frame.astype(cupy.float32) + else: + d_accum += d_frame + d_accum /= num_frames + + # Apply bloom and tone mapping once to the averaged result + bufs = _render_buffers + if bufs.bloom_temp is not None: + _bloom(d_accum, bufs.bloom_temp, bufs.bloom_scratch) + _tone_map_aces(d_accum) + + img = cupy.asnumpy(d_accum) + print(" done") + else: + img = render_func(self.raster, **render_kwargs) + # Convert from float [0-1] to uint8 [0-255] img_uint8 = (np.clip(img, 0, 1) * 255).astype(np.uint8) - # Save using PIL or matplotlib - try: - from PIL import Image - Image.fromarray(img_uint8).save(filename) - except ImportError: - import matplotlib.pyplot as plt - plt.imsave(filename, img) + from PIL import Image + Image.fromarray(img_uint8).save(filename) print(f"Screenshot saved: {filename}") @@ -2637,8 +4673,10 @@ def _render_frame(self): # Always show observer orb when placed; viewshed overlay only when enabled viewshed_data = None observer_pos = None - if self._observer_position is not None: - observer_pos = self._observer_position + active_obs = (self._observers.get(self._active_observer) + if self._active_observer else None) + if active_obs is not None and active_obs.position is not None: + observer_pos = active_obs.position if self.viewshed_enabled: if self._viewshed_cache is not None: viewshed_data = self._viewshed_cache @@ -2663,7 +4701,28 @@ def _render_frame(self): if builder is not None: geometry_colors = builder() - img = render( + # AO parameters: multiple samples per frame for smooth early results, + # with progressive accumulation across frames for further refinement + ao_samples = self._ao_samples_per_frame if self.ao_enabled else 0 + ao_seed = self._ao_frame_count if self.ao_enabled else 0 + + # When progressive accumulation is active, pass frame seed for AA + soft shadows + DOF + frame_seed = self._ao_frame_count + 1 if self.ao_enabled else 0 + + # Depth of field (requires progressive accumulation via AO) + if self.dof_enabled and self.ao_enabled: + dof_aperture = self._dof_aperture + dof_focal = self._dof_focal_distance + else: + dof_aperture = 0.0 + dof_focal = 0.0 + + # When progressive AO accumulation or denoising is active, defer + # bloom and tone mapping until after averaging / denoising. Both + # are non-linear operations that must act on the clean signal. + defer_post = self.ao_enabled or self.denoise_enabled + + d_output = render( self.raster, camera_position=tuple(self.position), look_at=tuple(self._get_look_at()), @@ -2689,95 +4748,278 @@ def _render_frame(self): overlay_data=self._active_overlay_data, overlay_alpha=self._overlay_alpha, geometry_colors=geometry_colors, + ao_samples=ao_samples, + ao_radius=self.ao_radius, + ao_seed=ao_seed, + gi_intensity=self.gi_intensity, + gi_bounces=self.gi_bounces, + frame_seed=frame_seed, + sun_angle=1.5, + aperture=dof_aperture, + focal_distance=dof_focal, + bloom=not defer_post, + tone_map=not defer_post, + _return_gpu=True, ) - return img + return d_output def _update_frame(self): - """Render and display a new frame.""" - img = self._render_frame() + """Full render: GPU ray trace → D2H copy → overlays → display.""" + # Sync previous frame's async D2H copy (no-op on first frame) + self._readback_stream.synchronize() + + # GPU render — returns cupy array (no D2H copy) + d_output = self._render_frame() self.frame_count += 1 - # Wind particle overlay + # Progressive AO accumulation + if self.ao_enabled: + from .analysis.render import _bloom, _tone_map_aces, _render_buffers + + # Check if camera moved — compare current state to previous + cam_state = (tuple(self.position), self.yaw, self.pitch, self.fov) + if self._prev_cam_state != cam_state: + # Camera moved: reset accumulation + self._d_ao_accum = None + self._ao_frame_count = 0 + self._prev_cam_state = cam_state + + # Accumulate + if self._d_ao_accum is None or self._d_ao_accum.shape != d_output.shape: + self._d_ao_accum = d_output.copy() + else: + self._d_ao_accum += d_output + self._ao_frame_count += 1 + + # Average the accumulated frames + d_display = self._d_ao_accum / self._ao_frame_count + else: + d_display = d_output + + # Deferred post-processing: denoise → bloom → tone map. + # These are deferred when AO or denoiser is active so they + # operate on the clean / averaged signal. + defer_post = self.ao_enabled or self.denoise_enabled + if defer_post: + if not self.ao_enabled: + from .analysis.render import _bloom, _tone_map_aces, _render_buffers + + if self.denoise_enabled: + from .rtx import denoise as _denoise + from .analysis.render import ( + _compute_camera_basis, _render_buffers as _bufs, + compute_flow, + ) + h, w = self.render_height, self.render_width + d_normals = _bufs.primary_hits.reshape(h, w, 4)[:, :, 1:4].copy() + ve = self.vertical_exaggeration + pos = self.position + look = self._get_look_at() + scaled_pos = (pos[0], pos[1], pos[2] * ve) + scaled_look = (look[0], look[1], look[2] * ve) + forward, right, cam_up = _compute_camera_basis( + scaled_pos, scaled_look, (0, 0, 1)) + + # Compute flow vectors for temporal denoising + d_flow = None + aspect = w / h + fov_scale = np.tan(np.radians(self.fov) / 2.0) + if self._prev_cam_for_flow is not None: + prev_pos, prev_fwd, prev_right, prev_up, prev_aspect, prev_fov_scale = self._prev_cam_for_flow + # Allocate / resize flow buffer + if self._d_flow is None or self._d_flow.shape != (h, w, 2): + self._d_flow = cp.zeros((h, w, 2), dtype=cp.float32) + d_prev_pos = cp.asarray(np.array(prev_pos, dtype=np.float32)) + d_prev_fwd = cp.asarray(np.array(prev_fwd, dtype=np.float32)) + d_prev_right = cp.asarray(np.array(prev_right, dtype=np.float32)) + d_prev_up = cp.asarray(np.array(prev_up, dtype=np.float32)) + compute_flow( + self._d_flow, _bufs.primary_rays, _bufs.primary_hits, + w, h, + d_prev_pos, d_prev_fwd, d_prev_right, d_prev_up, + prev_aspect, prev_fov_scale, + ) + d_flow = self._d_flow + + self._prev_cam_for_flow = ( + scaled_pos, tuple(forward), tuple(right), tuple(cam_up), + aspect, fov_scale, + ) + + _denoise(d_display, d_normals, w, h, right, cam_up, forward, + albedo=_bufs.albedo, flow=d_flow) + + bufs = _render_buffers + if bufs.bloom_temp is not None: + _bloom(d_display, bufs.bloom_temp, bufs.bloom_scratch) + _tone_map_aces(d_display) + + # Allocate pinned host buffer lazily (or on shape change) + if self._pinned_frame is None or self._pinned_frame.shape != d_display.shape: + self._pinned_mem = cp.cuda.alloc_pinned_memory(d_display.nbytes) + self._pinned_frame = np.frombuffer( + self._pinned_mem, dtype=np.float32, count=d_display.size + ).reshape(d_display.shape) + + # Start async D2H copy on non-blocking stream + d_display.get(out=self._pinned_frame, stream=self._readback_stream) + + # CPU work while DMA runs if self._wind_enabled and self._wind_particles is not None: self._update_wind_particles() - img = img.copy() # Don't modify cached render - self._draw_wind_on_frame(img) - # Update the image - self.im.set_data(img) + # Wait for DMA to complete + self._readback_stream.synchronize() - # Update suptitle with map info (only when changed) - title = self._build_title() - if title != self._last_title: - self._suptitle.set_text(title) - self._last_title = title + # Composite overlays on top of the ray-traced base frame + self._composite_overlays() - # Update subtitle with camera / observer info (only when changed) + def _composite_overlays(self): + """Composite CPU overlays (wind, minimap, help) onto the base frame. + + Can be called without re-ray-tracing to animate wind cheaply. + """ + # FPS tracking + self._fps_counter += 1 + now = time.monotonic() + elapsed = now - self._fps_last_time + if elapsed >= 1.0: + self._fps_display = self._fps_counter / elapsed + self._fps_counter = 0 + self._fps_last_time = now + + # Build window title + title = self._build_title() pos = self.position - sub = f"Pos: ({pos[0]:.0f}, {pos[1]:.0f}, {pos[2]:.0f}) Speed: {self.move_speed:.0f}" - if self._observer_position is not None: - obs_x, obs_y = self._observer_position - sub += f" \u2502 Observer: ({obs_x:.0f}, {obs_y:.0f}) h={self.viewshed_observer_elev:.3f}" + fps = self._fps_display + sub = f"{fps:.0f} FPS Pos: ({pos[0]:.0f}, {pos[1]:.0f}, {pos[2]:.0f}) Speed: {self.move_speed:.0f}" + if self._observers: + obs_parts = [] + for slot in sorted(self._observers): + obs = self._observers[slot] + marker = '*' if slot == self._active_observer else '' + mode = '' + if obs.drone_mode != 'off': + mode = f' {obs.drone_mode.upper()}' + if obs.is_touring(): + mode += ' TOUR' + obs_parts.append(f"{slot}{marker}{mode}") + sub += f" \u2502 Obs: [{' '.join(obs_parts)}]" + active_obs = (self._observers.get(self._active_observer) + if self._active_observer else None) + if active_obs is not None: + sub += f" h={active_obs.observer_elev:.3f}" if self.viewshed_enabled: sub += f" Coverage: {self._viewshed_coverage:.1f}%" - if sub != self._last_subtitle: - self.ax.set_title(sub, fontsize=8, color='#aaaaaa', pad=2) - self._last_subtitle = sub - # Update help text - if self.show_help: - self.help_text.set_visible(True) + combined = f"{title} | {sub}" + if combined != self._last_title: + self._last_title = combined + if self._glfw_window is not None: + import glfw + glfw.set_window_title(self._glfw_window, combined) + + # Build display frame (copy if we need overlays, else use pinned directly) + needs_overlay = ( + (self._wind_enabled and self._wind_particles is not None) + or (self._gtfs_rt_enabled and self._gtfs_rt_vehicles is not None) + or self.show_minimap + or self.show_help + ) + if needs_overlay: + img = self._pinned_frame.copy() else: - self.help_text.set_visible(False) + img = self._pinned_frame + + # Wind overlay + if self._wind_enabled and self._wind_particles is not None: + self._draw_wind_on_frame(img) + + # GTFS-RT vehicle overlay + if self._gtfs_rt_enabled and self._gtfs_rt_vehicles is not None: + self._draw_gtfs_rt_on_frame(img) - self._update_minimap() + # Minimap overlay + self._blit_minimap_on_frame(img) - self.fig.canvas.draw_idle() - self.fig.canvas.flush_events() + # Help text overlay + if self.show_help and self._help_text_rgba is not None: + self._blit_help_on_frame(img) - def _handle_scroll(self, event): - """Handle mouse scroll wheel for zoom.""" - if event.step > 0: + self._display_frame = img + self._frame_dirty = True + + def _handle_scroll(self, yoffset): + """Handle mouse scroll wheel for zoom. + + Parameters + ---------- + yoffset : float + Scroll amount (positive = scroll up = zoom in). + """ + if yoffset > 0: self.fov = max(20, self.fov - 3) else: self.fov = min(120, self.fov + 3) print(f"FOV: {self.fov:.0f}") self._update_frame() - def _handle_mouse_press(self, event): - """Start drag on left-click inside the main axes.""" - if event.inaxes != self.ax: - return - if event.button == 1: + def _handle_mouse_press(self, button, xpos, ypos): + """Start drag on left-click, or teleport if click is on minimap. + + Parameters + ---------- + button : int + Mouse button (0 = left, 1 = right, 2 = middle). + xpos, ypos : float + Cursor position in window pixels. + """ + if button == 0: # left click + # Check for minimap click-to-teleport + if self._minimap_rect is not None and self.show_minimap: + mx0, my0, mw, mh = self._minimap_rect + # Convert window coords to frame (render) coords + frame_x = xpos * self.render_width / max(1, self.width) + frame_y = ypos * self.render_height / max(1, self.height) + if (mx0 <= frame_x < mx0 + mw and my0 <= frame_y < my0 + mh): + # Convert minimap-local → terrain pixel → world XY + local_x = frame_x - mx0 + local_y = frame_y - my0 + H, W = self.terrain_shape + terrain_col = local_x / mw * W + terrain_row = local_y / mh * H + world_x = terrain_col * self.pixel_spacing_x + world_y = terrain_row * self.pixel_spacing_y + self.position[0] = world_x + self.position[1] = world_y + self._update_frame() + return + self._mouse_dragging = True - self._mouse_last_x = event.x - self._mouse_last_y = event.y + self._mouse_last_x = xpos + self._mouse_last_y = ypos - def _handle_mouse_release(self, event): - """End drag on any button release.""" + def _handle_mouse_release(self, button): + """End drag on button release.""" self._mouse_dragging = False - def _handle_mouse_motion(self, event): - """Pan camera on mouse drag (slippy-map style).""" + def _handle_mouse_motion(self, xpos, ypos): + """Pan camera on mouse drag (slippy-map style). + + Parameters + ---------- + xpos, ypos : float + Cursor position in screen pixels. + """ if not self._mouse_dragging or self._mouse_last_x is None: return - if event.x is None or event.y is None: - self._mouse_dragging = False - return - # Cancel drag if mouse left the axes (missed release event) - if event.inaxes != self.ax: - self._mouse_dragging = False - return - # Cancel drag if no button is held (missed release event) - if hasattr(event, 'button') and event.button is None: - self._mouse_dragging = False - return - dx = event.x - self._mouse_last_x - dy = event.y - self._mouse_last_y - self._mouse_last_x = event.x - self._mouse_last_y = event.y + dx = xpos - self._mouse_last_x + # GLFW Y is top-down; invert so dragging up → positive dy + dy = -(ypos - self._mouse_last_y) + self._mouse_last_x = xpos + self._mouse_last_y = ypos if dx == 0 and dy == 0: return @@ -2799,12 +5041,215 @@ def _handle_mouse_motion(self, event): front_horiz = np.array([0, 1, 0], dtype=np.float32) # Scene follows cursor: drag right → camera left - # Drag forward (mouse moves up, dy > 0) → camera moves forward self.position -= right * dx * sensitivity self.position -= front_horiz * dy * sensitivity self._update_frame() + def _render_help_text(self): + """Pre-render help text to an RGBA numpy array using PIL. + + Called once at startup; the result is cached in self._help_text_rgba. + Two-column layout with styled section headers and key highlighting. + """ + # Two columns of (section_title, [(key, description), ...]) + col_left = [ + ("MOVEMENT", [ + ("W/S/A/D", "Move fwd / back / left / right"), + ("Arrows", "Move fwd / back / left / right"), + ("Q / E", "Move up / down"), + ("I/J/K/L", "Look up / left / down / right"), + ("Drag", "Pan (slippy-map)"), + ("Scroll", "Zoom (FOV)"), + ("+ / -", "Speed up / down"), + ]), + ("TERRAIN", [ + ("G", "Cycle terrain layer"), + ("U", "Cycle basemap"), + ("C", "Cycle colormap"), + ("Y", "Cycle color stretch"), + (", / .", "Overlay alpha"), + ("R / Shift+R", "Resolution down / up"), + ("Z / Shift+Z", "Vert. exag. down / up"), + ("B", "Toggle TIN / Voxel"), + ("T", "Toggle shadows"), + ("Shift+T", "Cycle time of day"), + ]), + ("DATA LAYERS", [ + ("Shift+F", "FIRMS fire (7d)"), + ("Shift+W", "Toggle wind"), + ]), + ] + col_right = [ + ("RENDERING", [ + ("0", "Toggle ambient occlusion"), + ("Shift+G", "Cycle GI bounces (1-3)"), + ("Shift+D", "Toggle AI denoiser"), + ("9", "Toggle depth of field"), + ("; / '", "DOF aperture down / up"), + (": / \"", "DOF focal dist. down / up"), + ]), + ("GEOMETRY", [ + ("N", "Cycle geometry layer"), + ("P", "Prev geometry in group"), + ]), + ("OBSERVERS", [ + ("1-8", "Select / create observer"), + ("O", "Move observer to camera"), + ("Shift+O", "Drone mode (3rd / FPV)"), + ("Shift+V", "Snap camera to observer"), + ("Shift+K", "Kill all observers"), + ("V", "Toggle viewshed"), + ("[ / ]", "Observer height down / up"), + ]), + ("OTHER", [ + ("F", "Screenshot"), + ("M", "Toggle minimap"), + ("H", "Toggle this help"), + ("X / Esc", "Exit"), + ]), + ] + + try: + from PIL import Image, ImageDraw, ImageFont + + font_size = 12 + header_size = 13 + mono_path = "/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf" + bold_path = "/usr/share/fonts/truetype/dejavu/DejaVuSansMono-Bold.ttf" + sans_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf" + sans_bold_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf" + try: + font = ImageFont.truetype(sans_path, font_size) + font_key = ImageFont.truetype(mono_path, font_size) + font_header = ImageFont.truetype(sans_bold_path, header_size) + except (OSError, IOError): + font = ImageFont.load_default() + font_key = font + font_header = font + + line_h = font_size + 5 + header_h = header_size + 8 + section_gap = 6 + key_col_w = 105 # width reserved for keys + desc_col_w = 195 # width for descriptions + col_w = key_col_w + desc_col_w + col_gap = 20 + pad_x = 14 + pad_y = 12 + corner_r = 10 + + # Colors + bg_color = (15, 18, 24, 210) # dark blue-black, 82% opaque + header_color = (180, 210, 255, 255) # light blue + key_color = (255, 200, 100, 245) # warm amber + desc_color = (210, 215, 225, 220) # soft white + separator_color = (80, 90, 110, 120) # subtle line + accent_color = (90, 140, 220, 180) # blue accent for header underline + + def _column_height(sections): + h = 0 + for i, (title, items) in enumerate(sections): + if i > 0: + h += section_gap + h += header_h + 3 # header + underline space + h += len(items) * line_h + return h + + left_h = _column_height(col_left) + right_h = _column_height(col_right) + content_h = max(left_h, right_h) + footer_h = header_h + section_gap # space for "Press H to close" + img_w = pad_x * 2 + col_w * 2 + col_gap + img_h = pad_y * 2 + content_h + footer_h + + # Create with transparent background, then draw rounded rect + img = Image.new('RGBA', (img_w, img_h), (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + + # Rounded rectangle background + draw.rounded_rectangle( + [0, 0, img_w - 1, img_h - 1], + radius=corner_r, fill=bg_color, + ) + + # Subtle border + draw.rounded_rectangle( + [0, 0, img_w - 1, img_h - 1], + radius=corner_r, outline=(60, 70, 90, 140), width=1, + ) + + def _draw_column(sections, x_start, y_start): + y = y_start + for si, (title, items) in enumerate(sections): + if si > 0: + y += section_gap + # Section header + draw.text((x_start, y), title, fill=header_color, + font=font_header) + # Accent underline + underline_y = y + header_size + 2 + draw.line( + [(x_start, underline_y), + (x_start + col_w - 10, underline_y)], + fill=accent_color, width=1) + y = underline_y + 3 + + # Key-description rows + for key_text, desc_text in items: + draw.text((x_start, y), key_text, + fill=key_color, font=font_key) + draw.text((x_start + key_col_w, y), desc_text, + fill=desc_color, font=font) + y += line_h + + _draw_column(col_left, pad_x, pad_y) + _draw_column(col_right, pad_x + col_w + col_gap, pad_y) + + # Vertical separator between columns + sep_x = pad_x + col_w + col_gap // 2 + draw.line( + [(sep_x, pad_y + 4), (sep_x, pad_y + content_h - 4)], + fill=separator_color, width=1) + + # Bold "Press H to close" footer, centered + footer_text = "Press H to close" + bbox = font_header.getbbox(footer_text) + fw = bbox[2] - bbox[0] + footer_x = (img_w - fw) // 2 + footer_y = pad_y + content_h + section_gap + draw.text((footer_x, footer_y), footer_text, + fill=header_color, font=font_header) + + self._help_text_rgba = np.array(img, dtype=np.float32) / 255.0 + except ImportError: + self._help_text_rgba = None + + def _blit_help_on_frame(self, img): + """Alpha-composite cached help text onto the rendered frame. + + Parameters + ---------- + img : ndarray, shape (H, W, 3), float32 0-1 + Rendered frame. Modified in-place. + """ + if self._help_text_rgba is None: + return + ht = self._help_text_rgba + hh, hw = ht.shape[:2] + fh, fw = img.shape[:2] + # Top-left with small margin + margin = 8 + # Clamp to frame size + bh = min(hh, fh - margin) + bw = min(hw, fw - margin) + if bh <= 0 or bw <= 0: + return + alpha = ht[:bh, :bw, 3:4] + rgb = ht[:bh, :bw, :3] + region = img[margin:margin+bh, margin:margin+bw] + region[:] = region * (1 - alpha) + rgb * alpha + def run(self, start_position: Optional[Tuple[float, float, float]] = None, look_at: Optional[Tuple[float, float, float]] = None): """ @@ -2818,53 +5263,8 @@ def run(self, start_position: Optional[Tuple[float, float, float]] = None, look_at : tuple, optional Initial look-at point. If None, looks toward terrain center. """ - import matplotlib - import matplotlib.pyplot as plt - - # Check if we're in a Jupyter notebook and need to switch backends - current_backend = matplotlib.get_backend().lower() - in_notebook = False - try: - from IPython import get_ipython - ipy = get_ipython() - if ipy is not None and 'IPKernelApp' in ipy.config: - in_notebook = True - except (ImportError, AttributeError): - pass - - # Warn if using a non-interactive backend - non_interactive_backends = ['agg', 'module://matplotlib_inline.backend_inline', 'inline'] - if any(nb in current_backend for nb in non_interactive_backends): - if in_notebook: - print("\n" + "="*70) - print("WARNING: Matplotlib is using a non-interactive backend.") - print("Keyboard controls will NOT work with the inline backend.") - print("\nTo fix, run this BEFORE calling explore():") - print(" %matplotlib qt") - print(" OR") - print(" %matplotlib tk") - print(" OR (if ipympl is installed):") - print(" %matplotlib widget") - print("\nThen restart the kernel and run the notebook again.") - print("="*70 + "\n") - else: - print("WARNING: Non-interactive matplotlib backend detected.") - print("Keyboard controls may not work.") - - # Disable default matplotlib keybindings that conflict with our controls - for key in ['s', 'q', 'l', 'k', 'a', 'w', 'e', 'c', 'h', 't']: - if key in plt.rcParams.get('keymap.save', []): - plt.rcParams['keymap.save'].remove(key) - if key in plt.rcParams.get('keymap.quit', []): - plt.rcParams['keymap.quit'].remove(key) - if key in plt.rcParams.get('keymap.xscale', []): - plt.rcParams['keymap.xscale'].remove(key) - if key in plt.rcParams.get('keymap.yscale', []): - plt.rcParams['keymap.yscale'].remove(key) - # Clear all default keymaps to avoid conflicts - for param in list(plt.rcParams.keys()): - if param.startswith('keymap.'): - plt.rcParams[param] = [] + import glfw + import moderngl H, W = self.terrain_shape @@ -2901,102 +5301,100 @@ def run(self, start_position: Optional[Tuple[float, float, float]] = None, self.yaw = np.degrees(np.arctan2(direction[1], direction[0])) self.pitch = np.degrees(np.arcsin(np.clip(direction[2], -1, 1))) - # Create figure (suppress the matplotlib navigation toolbar) - old_toolbar = plt.rcParams.get('toolbar', 'toolbar2') - plt.rcParams['toolbar'] = 'None' - plt.ion() # Interactive mode - self.fig, self.ax = plt.subplots(1, 1, figsize=(self.width/100, self.height/100), dpi=100) - plt.rcParams['toolbar'] = old_toolbar - self.fig.patch.set_facecolor('black') - self.ax.set_facecolor('black') - self.ax.axis('off') - self.fig.subplots_adjust(left=0, right=1, top=0.92, bottom=0) - - # Main title (map info, updated each frame) - self._suptitle = self.fig.suptitle( - self._build_title(), fontsize=11, color='white', - fontweight='bold', y=0.98, - ) + # --- GLFW window creation --- + if not glfw.init(): + raise RuntimeError("Failed to initialise GLFW") + + glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, 3) + glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, 3) + glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE) + glfw.window_hint(glfw.OPENGL_FORWARD_COMPAT, True) - # Render initial frame - img = self._render_frame() - self.im = self.ax.imshow(img, aspect='auto') - - # Help text — vertical list down the left side - help_str = ( - "MOVEMENT\n" - " W/S/A/D Move forward/back/left/right\n" - " Arrows Move forward/back/left/right\n" - " Q / E Move up / down\n" - " I/J/K/L Look up/left/down/right\n" - " Drag Pan (slippy-map)\n" - " Scroll Zoom (FOV)\n" - " + / - Speed up / down\n" - "\n" - "TERRAIN\n" - " G Cycle terrain layer\n" - " U Cycle basemap\n" - " C Cycle colormap\n" - " Y Cycle color stretch\n" - " , / . Overlay alpha down / up\n" - " R Decrease resolution\n" - " Shift+R Increase resolution\n" - " Z Decrease vert. exag.\n" - " Shift+Z Increase vert. exag.\n" - " B Toggle TIN / Voxel\n" - " T Toggle shadows\n" - "\n" - "GEOMETRY\n" - " N Cycle geometry layer\n" - " P Prev geometry in group\n" - "\n" - "DRONE / VIEWSHED\n" - " O Place observer\n" - " Shift+O Drone mode (3rd/FPV)\n" - " Shift+V Snap camera to drone\n" - " V Toggle viewshed\n" - " [ / ] Observer height down / up\n" - "\n" - "DATA LAYERS\n" - " Shift+F FIRMS fire (7d)\n" - " Shift+W Toggle wind\n" - "\n" - "OTHER\n" - " F Screenshot\n" - " M Toggle minimap\n" - " H Toggle this help\n" - " X / Esc Exit" + window = glfw.create_window( + self.width, self.height, + f'rtxpy \u2014 {self._title}', None, None, ) - self.help_text = self.ax.text( - 0.01, 0.98, help_str, - transform=self.ax.transAxes, - fontsize=11, - color='white', - alpha=0.9, - verticalalignment='top', - fontfamily='monospace', - bbox=dict(boxstyle='round,pad=0.6', facecolor='black', alpha=0.6) + if not window: + glfw.terminate() + raise RuntimeError("Failed to create GLFW window") + + glfw.make_context_current(window) + glfw.swap_interval(0) # No VSync — render as fast as GPU allows + + self._glfw_window = window + + # --- ModernGL context + fullscreen quad --- + ctx = moderngl.create_context() + prog = ctx.program(vertex_shader=_QUAD_VERT, fragment_shader=_QUAD_FRAG) + + # Fullscreen quad: position (x, y) + UV (u, v) + # V is flipped: v=1 at bottom of screen maps to row 0 (top of image), + # because OpenGL textures start at the bottom but numpy row 0 is top. + quad_data = np.array([ + # x, y, u, v + -1.0, -1.0, 0.0, 1.0, # bottom-left → top of image + 1.0, -1.0, 1.0, 1.0, # bottom-right → top of image + -1.0, 1.0, 0.0, 0.0, # top-left → bottom of image + 1.0, 1.0, 1.0, 0.0, # top-right → bottom of image + ], dtype='f4') + vbo = ctx.buffer(quad_data.tobytes()) + vao = ctx.simple_vertex_array(prog, vbo, 'in_pos', 'in_uv') + + # Frame texture — sized to render resolution, updated every frame + frame_tex = ctx.texture( + (self.render_width, self.render_height), 3, dtype='f4', ) + frame_tex.filter = (moderngl.LINEAR, moderngl.LINEAR) - # Initialize minimap - self._compute_minimap_background() - self._create_minimap() + # --- Pre-render help text overlay --- + self._render_help_text() - # Connect event handlers - self.fig.canvas.mpl_connect('key_press_event', self._handle_key_press) - self.fig.canvas.mpl_connect('key_release_event', self._handle_key_release) - self.fig.canvas.mpl_connect('scroll_event', self._handle_scroll) - self.fig.canvas.mpl_connect('button_press_event', self._handle_mouse_press) - self.fig.canvas.mpl_connect('button_release_event', self._handle_mouse_release) - self.fig.canvas.mpl_connect('motion_notify_event', self._handle_mouse_motion) + # --- Initialize minimap --- + self._compute_minimap_background() - # Set up timer for smooth key repeat - self._timer = self.fig.canvas.new_timer(interval=self._tick_interval) - self._timer.add_callback(self._tick) - self._timer.start() + # --- GLFW callbacks --- + viewer = self # closure reference - # Window title bar - self.fig.canvas.manager.set_window_title(f'rtxpy \u2014 {self._title}') + def _key_cb(_win, glfw_key, _scancode, action, mods): + raw_key, key_lower = _glfw_to_key(glfw_key, mods) + if not raw_key: + return + if action == glfw.PRESS: + viewer._handle_key_press(raw_key, key_lower) + elif action == glfw.RELEASE: + viewer._handle_key_release(key_lower) + + def _scroll_cb(_win, _xoffset, yoffset): + viewer._handle_scroll(yoffset) + + def _mouse_btn_cb(_win, button, action, _mods): + xpos, ypos = glfw.get_cursor_pos(_win) + if action == glfw.PRESS: + viewer._handle_mouse_press(button, xpos, ypos) + elif action == glfw.RELEASE: + viewer._handle_mouse_release(button) + + def _cursor_cb(_win, xpos, ypos): + viewer._handle_mouse_motion(xpos, ypos) + + def _framebuffer_size_cb(_win, fb_width, fb_height): + if fb_width <= 0 or fb_height <= 0: + return # minimised + viewer.width = fb_width + viewer.height = fb_height + viewer.render_width = int(fb_width * viewer.render_scale) + viewer.render_height = int(fb_height * viewer.render_scale) + ctx.viewport = (0, 0, fb_width, fb_height) + # Invalidate pinned buffer so it's re-allocated at new size + viewer._pinned_frame = None + viewer._pinned_mem = None + viewer._render_needed = True + + glfw.set_key_callback(window, _key_cb) + glfw.set_scroll_callback(window, _scroll_cb) + glfw.set_mouse_button_callback(window, _mouse_btn_cb) + glfw.set_cursor_pos_callback(window, _cursor_cb) + glfw.set_framebuffer_size_callback(window, _framebuffer_size_cb) print(f"\nInteractive Viewer Started") print(f" Window: {self.width}x{self.height}") @@ -3005,20 +5403,128 @@ def run(self, start_position: Optional[Tuple[float, float, float]] = None, print(f"\nPress H for controls, X or Esc to exit\n") self.running = True - self._update_frame() + self._display_frame = None + self._frame_dirty = False + self._render_needed = True # Ensure first frame renders + self._fps_counter = 0 + self._fps_last_time = time.monotonic() + self._last_tick_time = time.monotonic() + + # Render the initial frame so the window isn't blank + self._tick() + + # --- REPL thread --- + if self._repl: + proxy = ViewerProxy(self) + + def _run_repl(): + # Auto-play tour if one was provided + if getattr(self, '_tour', None) is not None: + import time as _time + _time.sleep(0.5) # let first frames render + try: + proxy.tour(self._tour) + except Exception as exc: + print(f"Tour error: {exc}") + + banner = ( + "\nrtxpy interactive REPL\n" + "Use `v` (the viewer proxy) to interact with the scene.\n" + "Examples:\n" + " v.hillshade(shadows=True)\n" + " v.viewshed(x=500, y=300)\n" + " v.add_layer('slope', slope(v.raster).data)\n" + " v.set_colormap('viridis')\n" + " v.shadows = False\n" + "Type exit() or close the window to quit.\n" + ) + ns = { + 'v': proxy, + 'viewer': proxy, + 'np': np, + } + try: + import xarray + ns['xr'] = xarray + except ImportError: + pass + try: + from IPython.terminal.embed import InteractiveShellEmbed + shell = InteractiveShellEmbed( + banner1=banner, user_ns=ns, exit_msg='') + shell() + except ImportError: + import code + code.interact(banner=banner, local=ns) + # When REPL exits, close the viewer window + self.running = False - # Keep window open until closed - plt.show(block=True) + repl_thread = threading.Thread( + target=_run_repl, daemon=True, name='rtxpy-repl') + repl_thread.start() - # Clean up timer - if self._timer is not None: - self._timer.stop() - self._timer = None + # --- Main loop --- + try: + while not glfw.window_should_close(window) and self.running: + self._tick() + + # Drain REPL command queue (thread-safe) + while True: + try: + cmd = self._command_queue.get_nowait() + except queue.Empty: + break + try: + cmd(self) + except Exception as exc: + import traceback + traceback.print_exc() + self._render_needed = True + + # Upload frame to texture and render only when dirty + if self._frame_dirty and self._display_frame is not None: + tex_w, tex_h = frame_tex.size + fh, fw = self._display_frame.shape[:2] + if fw != tex_w or fh != tex_h: + frame_tex.release() + frame_tex = ctx.texture((fw, fh), 3, dtype='f4') + frame_tex.filter = (moderngl.LINEAR, moderngl.LINEAR) + frame_tex.write(self._display_frame) + frame_tex.use() + ctx.clear() + vao.render(moderngl.TRIANGLE_STRIP) + glfw.swap_buffers(window) + self._frame_dirty = False + + glfw.poll_events() + + # Idle: yield CPU when nothing is happening (no movement, + # no pending render). Keeps input polling responsive at + # ~120 Hz while avoiding a busy-wait spin. + if not self._held_keys and not self._mouse_dragging: + time.sleep(0.008) + finally: + # --- Cleanup --- + frame_tex.release() + vbo.release() + vao.release() + prog.release() + ctx.release() + glfw.destroy_window(window) + glfw.terminate() + self._glfw_window = None + # Reset terminal state (GLFW can hide cursor / alter termios) + import sys + sys.stdout.write('\033[?25h') # show cursor + sys.stdout.flush() # Clean up tile service if self._tile_service is not None: self._tile_service.shutdown() + # Clean up GTFS-RT thread + self._cleanup_gtfs_rt() + print(f"Viewer closed after {self.frame_count} frames") @@ -3029,7 +5535,7 @@ def explore(raster, width: int = 800, height: int = 600, key_repeat_interval: float = 0.05, rtx: 'RTX' = None, pixel_spacing_x: float = 1.0, pixel_spacing_y: float = 1.0, - mesh_type: str = 'tin', + mesh_type: str = 'heightfield', overlay_layers: dict = None, color_stretch: str = 'linear', title: str = None, @@ -3038,12 +5544,19 @@ def explore(raster, width: int = 800, height: int = 600, baked_meshes=None, subsample: int = 1, wind_data=None, + gtfs_data=None, accessor=None, - terrain_loader=None): + terrain_loader=None, + scene_zarr=None, + ao_samples: int = 0, + gi_bounces: int = 1, + denoise: bool = False, + repl: bool = False, + tour=None): """ Launch an interactive terrain viewer. - Uses matplotlib for display - no additional dependencies required. + Uses GLFW + ModernGL for display. Keyboard controls allow flying through the terrain. Parameters @@ -3075,12 +5588,31 @@ def explore(raster, width: int = 800, height: int = 600, Y spacing between pixels in world units. Default 1.0. mesh_type : str, optional Mesh generation method: 'tin' or 'voxel'. Default is 'tin'. + scene_zarr : str or Path, optional + Path to a zarr store with a ``meshes/`` group. When provided, + mesh chunks are loaded dynamically based on camera position + instead of loading the full scene upfront. accessor : RTXAccessor, optional RTX accessor instance for on-demand data fetching (e.g. FIRMS fire layer via Shift+F). wind_data : dict, optional Wind data from ``fetch_wind()``. If provided, Shift+W toggles wind particle animation. + gtfs_data : dict, optional + GTFS data from ``fetch_gtfs()``. If the metadata contains a + ``realtime_url``, Shift+B toggles realtime vehicle positions. + ao_samples : int, optional + If > 0, enable ambient occlusion on launch with progressive + accumulation (1 sample per frame). Press 0 to toggle at runtime. + Default is 0 (disabled). + denoise : bool, optional + If True, enable the OptiX AI Denoiser on launch. Press Shift+D + to toggle at runtime. Default is False. + tour : list of dict or str, optional + If provided, automatically play a camera tour after the viewer + launches. Can be a list of keyframe dicts or a path to a + ``.py`` file that defines a ``tour`` variable. Implies + ``repl=True`` — the REPL starts after the tour finishes. Controls -------- @@ -3096,7 +5628,7 @@ def explore(raster, width: int = 800, height: int = 600, - +/=: Increase speed - -: Decrease speed - G: Cycle terrain color (elevation → overlays) - - U: Cycle basemap (none → satellite → osm → topo) + - U: Cycle basemap (none → satellite → osm) - N: Cycle geometry layer (none → all → groups) - P: Jump to previous geometry in current group - ,/.: Decrease/increase overlay alpha (transparency) @@ -3111,9 +5643,13 @@ def explore(raster, width: int = 800, height: int = 600, - B: Toggle mesh type (TIN / voxel) - Y: Cycle color stretch (linear, sqrt, cbrt, log) - T: Toggle shadows + - 0: Toggle ambient occlusion (progressive) + - Shift+G: Cycle GI bounces (1→2→3→1) + - Shift+D: Toggle OptiX AI Denoiser - C: Cycle colormap - Shift+F: Fetch/toggle FIRMS fire layer (7d LANDSAT 30m) - Shift+W: Toggle wind particle animation + - Shift+B: Toggle GTFS-RT realtime vehicle overlay - F: Save screenshot - M: Toggle minimap overlay - H: Toggle help overlay @@ -3129,7 +5665,15 @@ def explore(raster, width: int = 800, height: int = 600, >>> # Or via accessor >>> dem.rtx.explore() """ - viewer = InteractiveViewer( + # Auto-detect Jupyter and use widget-based viewer + from .notebook import _detect_jupyter + if _detect_jupyter(): + from .notebook import JupyterViewer + ViewerClass = JupyterViewer + else: + ViewerClass = InteractiveViewer + + viewer = ViewerClass( raster, width=width, height=height, @@ -3147,6 +5691,9 @@ def explore(raster, width: int = 800, height: int = 600, viewer._baked_meshes = baked_meshes or {} viewer._accessor = accessor viewer._terrain_loader = terrain_loader + if scene_zarr is not None: + viewer._chunk_manager = _MeshChunkManager( + scene_zarr, pixel_spacing_x, pixel_spacing_y) viewer.color_stretch = color_stretch if color_stretch in viewer._color_stretches: viewer._color_stretch_idx = viewer._color_stretches.index(color_stretch) @@ -3163,6 +5710,41 @@ def explore(raster, width: int = 800, height: int = 600, if wind_data is not None: viewer._init_wind(wind_data) + # GTFS-RT initialization + if gtfs_data is not None: + rt_url = (gtfs_data.get('metadata') or {}).get('realtime_url') + if rt_url: + # Build route_id -> (r,g,b) colour map from route features + rc_map = {} + for f in gtfs_data.get('routes', {}).get('features', []): + props = f.get('properties') or {} + rc = (props.get('route_color') or '').strip().lstrip('#') + rname = props.get('route_short_name', '') + if len(rc) == 6: + try: + rgb = (int(rc[0:2], 16) / 255.0, + int(rc[2:4], 16) / 255.0, + int(rc[4:6], 16) / 255.0) + # Key by route_short_name since GTFS-RT uses route_id + # which may differ; store both for best chance of matching + if rname: + rc_map[rname] = rgb + rid = props.get('route_id', '') + if rid: + rc_map[rid] = rgb + except ValueError: + pass + viewer._init_gtfs_rt(rt_url, route_colors=rc_map) + + # Ambient occlusion initialization + if ao_samples > 0: + viewer.ao_enabled = True + viewer.gi_bounces = gi_bounces + + # Denoiser initialization + if denoise: + viewer.denoise_enabled = True + # Initial state: everything off except elevation viewer._tiles_enabled = False viewer._basemap_idx = 0 # 'none' @@ -3171,4 +5753,8 @@ def explore(raster, width: int = 800, height: int = 600, for geom_id in viewer._all_geometries: if geom_id != 'terrain': rtx.set_geometry_visible(geom_id, False) - viewer.run(start_position=start_position, look_at=look_at) + if tour is not None: + repl = True + viewer._repl = repl + viewer._tour = tour + return viewer.run(start_position=start_position, look_at=look_at) diff --git a/rtxpy/geojson.py b/rtxpy/geojson.py index d63d516..0dde892 100644 --- a/rtxpy/geojson.py +++ b/rtxpy/geojson.py @@ -204,7 +204,10 @@ def _geojson_to_world_coords(coords_lonlat, raster, terrain_data, psx, psy, cols = (x_crs - x_coords[0]) / dx rows = (y_crs - y_coords[0]) / dy - # Clip to raster extent + # Track out-of-bounds coords (for warning) but keep true positions. + # Only clamp the sampling indices used for terrain-Z lookup so roads + # and other line features extend naturally past the DEM edge instead + # of being warped toward the boundary. out_of_bounds = ( (cols < -0.5) | (cols > W - 0.5) | (rows < -0.5) | (rows > H - 0.5) @@ -212,14 +215,16 @@ def _geojson_to_world_coords(coords_lonlat, raster, terrain_data, psx, psy, n_oob = int(np.sum(out_of_bounds)) if np.any(out_of_bounds) else 0 if n_oob > 0 and oob_counter is not None: oob_counter[0] += n_oob - cols = np.clip(cols, 0, W - 1) - rows = np.clip(rows, 0, H - 1) + + # Clamp only for Z sampling — keep original cols/rows for world XY + cols_z = np.clip(cols, 0, W - 1) + rows_z = np.clip(rows, 0, H - 1) # --- Sample terrain Z (bilinear to match triangle mesh surface) ------ - x0 = np.clip(np.floor(cols).astype(int), 0, W - 2) - y0 = np.clip(np.floor(rows).astype(int), 0, H - 2) - fx = cols - x0 - fy = rows - y0 + x0 = np.clip(np.floor(cols_z).astype(int), 0, W - 2) + y0 = np.clip(np.floor(rows_z).astype(int), 0, H - 2) + fx = cols_z - x0 + fy = rows_z - y0 z00 = terrain_data[y0, x0].astype(np.float64) z10 = terrain_data[y0, x0 + 1].astype(np.float64) z01 = terrain_data[y0 + 1, x0].astype(np.float64) @@ -467,6 +472,9 @@ def _densify_on_terrain(world_coords, terrain_data, psx, psy, step=1.0): dy_px = (p1[1] - p0[1]) / psy dist_px = np.sqrt(dx_px ** 2 + dy_px ** 2) + if not np.isfinite(dist_px): + result.append(p1) + continue n_sub = max(1, int(np.ceil(dist_px / step))) for j in range(1, n_sub + 1): @@ -656,6 +664,114 @@ def _linestring_to_ribbon_mesh(world_coords, width=1.0, hover=0.2, return verts.ravel(), tris.ravel() +def _linestring_to_curve_data(world_coords, width=1.0, hover=0.2, + closed=False): + """Convert a polyline to round quadratic B-spline curve tube data. + + Produces control points, per-vertex widths (radii), and segment + indices for OptiX ``ROUND_QUADRATIC_BSPLINE`` curve primitives. + + Parameters + ---------- + world_coords : np.ndarray + (N, 3) array of world-space positions along the line. + width : float + Curve tube radius. + hover : float + Small Z offset above terrain to avoid z-fighting. + closed : bool + If True, connect last point back to first. + + Returns + ------- + tuple or None + ``(vertices, widths, indices)`` where vertices is flat float32, + widths is flat float32, and indices is flat int32. + Returns None if the input has fewer than 2 points. + """ + pts = np.asarray(world_coords, dtype=np.float32).copy() + N = len(pts) + if N < 2: + return None + + if closed and N > 2: + if not np.allclose(pts[0], pts[-1], atol=1e-3): + pts = np.vstack([pts, pts[0:1]]) + N = len(pts) + + # Pad 2-point lines to 3 points (minimum for 1 segment) + if N == 2: + mid = (pts[0] + pts[1]) / 2.0 + pts = np.array([pts[0], mid, pts[1]], dtype=np.float32) + N = 3 + + pts[:, 2] += hover + + # Phantom endpoints: quadratic B-splines don't interpolate their + # first/last control points — the curve starts at midpoint(cp0, cp1) + # and ends at midpoint(cpN-2, cpN-1). Duplicating the endpoints + # makes the curve reach the actual road positions, closing gaps at + # intersections where multiple roads meet. + if closed and N > 3: + # Wrap around: add the penultimate point before start and the + # second point after end so the B-spline loops seamlessly + pts = np.vstack([pts[-2:-1], pts, pts[1:2]]) + else: + # Open curve: duplicate first and last control points + pts = np.vstack([pts[0:1], pts, pts[-1:]]) + N = len(pts) + + num_segments = N - 2 + + vertices = pts.ravel() # N*3 floats + widths = np.full(N, width, dtype=np.float32) # constant radius + indices = np.arange(num_segments, dtype=np.int32) + + return vertices, widths, indices + + +def _polygon_to_curve_data(rings_world_coords, width=1.0, hover=0.2): + """Build curve tube data for a polygon's exterior and interior rings. + + Parameters + ---------- + rings_world_coords : list of np.ndarray + Each element is an (N, 3) array for one ring. + width : float + Curve tube radius. + hover : float + Small Z offset above terrain. + + Returns + ------- + tuple or None + ``(vertices, widths, indices)`` concatenated across all rings, + or None if no valid rings. + """ + all_verts = [] + all_widths = [] + all_indices = [] + vert_offset = 0 + + for ring_coords in rings_world_coords: + result = _linestring_to_curve_data( + ring_coords, width=width, hover=hover, closed=True, + ) + if result is None: + continue + v, w, idx = result + all_verts.append(v) + all_widths.append(w) + all_indices.append(idx + vert_offset) + vert_offset += len(v) // 3 + + if not all_verts: + return None + + return (np.concatenate(all_verts), np.concatenate(all_widths), + np.concatenate(all_indices)) + + def _polygon_to_ribbon_mesh(rings_world_coords, width=1.0, hover=0.2): """Build ribbon meshes for a polygon's exterior and interior rings. @@ -890,9 +1006,14 @@ def _extrude_polygon(rings_world_coords, height): if N < 3: return np.empty(0, dtype=np.float32), np.empty(0, dtype=np.int32) + # Flatten base to a single elevation so rooftops aren't jagged. + # Use the minimum Z so no part of the building floats above terrain. + base_z = ring[:, 2].min() + ring[:, 2] = base_z + # Vertices: bottom ring (0..N-1), top ring (N..2N-1) top = ring.copy() - top[:, 2] += height + top[:, 2] = base_z + height verts = np.empty(2 * N * 3, dtype=np.float32) verts[: N * 3] = ring.ravel() diff --git a/rtxpy/quickstart.py b/rtxpy/quickstart.py new file mode 100644 index 0000000..fca71a2 --- /dev/null +++ b/rtxpy/quickstart.py @@ -0,0 +1,351 @@ +"""One-call launcher: DEM fetch -> analysis layers -> feature placement -> explore().""" + +import warnings +from pathlib import Path + + +def quickstart( + name, + bounds, + crs, + source='copernicus', + features=None, + tiles='satellite', + tile_zoom=None, + wind=True, + cache_dir=None, + **explore_kwargs, +): + """Fetch terrain, place features, and launch the interactive viewer. + + Parameters + ---------- + name : str + Location name used to derive the zarr filename + (``{name}_dem.zarr``) and GeoJSON cache filenames. + bounds : tuple of float + (west, south, east, north) in WGS84 degrees. + crs : str + EPSG code for the target projection (e.g. ``'EPSG:32620'``). + source : str + DEM source: ``'copernicus'``, ``'usgs_10m'``, ``'srtm'``. + Default ``'copernicus'``. + features : list or dict, optional + Features to place on the terrain. List form uses defaults:: + + features=['buildings', 'roads', 'water', 'fire'] + + Dict form allows per-feature overrides:: + + features={'buildings': {'elev_scale': 0.33}, + 'fire': {'region': 'southeast_asia'}} + + Supported keys: ``'buildings'``, ``'roads'``, ``'water'``, + ``'fire'``, ``'places'``, ``'infrastructure'``, ``'land_use'``, + ``'restaurant_grades'``, ``'gtfs'``. + tiles : str or None + Tile provider: ``'satellite'``, ``'osm'``, or ``None`` to skip. + Default ``'satellite'``. + tile_zoom : int, optional + Tile zoom level override. ``None`` uses the provider default. + wind : bool + Fetch live wind data from Open-Meteo. Default ``True``. + cache_dir : str or Path, optional + Directory for the zarr store and GeoJSON caches. Defaults to + the current working directory. + **explore_kwargs + Forwarded to ``ds.rtx.explore()``. Defaults:: + + width=2048, height=1600, render_scale=0.5, + color_stretch='cbrt', subsample=1, repl=True + """ + import numpy as np + import xarray as xr + from xrspatial import slope, aspect, quantile + + # -- paths ---------------------------------------------------------------- + if cache_dir is None: + cache_dir = Path.cwd() + else: + cache_dir = Path(cache_dir) + zarr_path = cache_dir / f"{name}_dem.zarr" + + # -- DEM ------------------------------------------------------------------ + from .remote_data import fetch_dem as _fetch_dem + + terrain = _fetch_dem(bounds=bounds, output_path=zarr_path, + source=source, crs=crs) + terrain.data = np.ascontiguousarray(terrain.data) + terrain = terrain.rtx.to_cupy() + + # -- Dataset with analysis layers ----------------------------------------- + print("Building Dataset with terrain analysis layers...") + ds = xr.Dataset({ + 'elevation': terrain.rename(None), + 'slope': slope(terrain), + 'aspect': aspect(terrain), + 'quantile': quantile(terrain), + }) + + # -- tiles ---------------------------------------------------------------- + if tiles: + print(f"Loading {tiles} tiles...") + ds.rtx.place_tiles(tiles, z='elevation', zoom=tile_zoom) + + # -- features ------------------------------------------------------------- + feat = _parse_features(features) + _TEMPORAL = {'fire'} + cacheable = {k: v for k, v in feat.items() if k not in _TEMPORAL} + temporal = {k: v for k, v in feat.items() if k in _TEMPORAL} + + # Check for mesh cache in zarr + has_cache = False + if cacheable: + try: + import zarr as _zarr + store = _zarr.open(str(zarr_path), mode='r', + use_consolidated=False) + has_cache = ('meshes' in store + and len(list(store['meshes'])) > 0) + del store + except Exception: + pass + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + message="place_geojson called before") + if has_cache: + ds.rtx.load_meshes(zarr_path) + elif cacheable: + _place_features(ds, cacheable, name, bounds, crs, cache_dir) + try: + ds.rtx.save_meshes(zarr_path) + except Exception as e: + print(f"Could not save mesh cache: {e}") + + # Temporal features (always fresh, not cached in zarr) + if temporal: + _place_features(ds, temporal, name, bounds, crs, cache_dir) + + # -- wind ----------------------------------------------------------------- + wind_data = None + if wind: + try: + from .remote_data import fetch_wind as _fetch_wind + wind_data = _fetch_wind(bounds, grid_size=15) + except Exception as e: + print(f"Skipping wind: {e}") + + # -- explore -------------------------------------------------------------- + defaults = dict( + width=2048, height=1600, render_scale=0.5, + color_stretch='cbrt', subsample=1, repl=True, + ) + defaults.update(explore_kwargs) + + # -- gtfs realtime ---------------------------------------------------------- + gtfs_data = ds.attrs.pop('_gtfs_data', None) + + print("\nLaunching explore...\n") + ds.rtx.explore(z='elevation', scene_zarr=zarr_path, + wind_data=wind_data, gtfs_data=gtfs_data, **defaults) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _parse_features(features): + """Normalize *features* to ``{key: {opts}}`` dict.""" + if features is None: + return {} + if isinstance(features, (list, tuple)): + return {f: {} for f in features} + out = {} + for key, val in features.items(): + if val is True: + out[key] = {} + elif isinstance(val, dict): + out[key] = val + elif val is False or val is None: + continue + else: + out[key] = {} + return out + + +def _place_features(ds, features, name, bounds, crs, cache_dir): + """Place all requested features, catching errors per feature.""" + cache_dir = Path(cache_dir) + for key, opts in features.items(): + handler = _FEATURE_HANDLERS.get(key) + if handler is None: + print(f"Unknown feature: {key!r}") + continue + try: + handler(ds, opts, name, bounds, crs, cache_dir) + except Exception as e: + print(f"Skipping {key}: {e}") + + +# -- individual feature handlers ---------------------------------------------- + +def _place_buildings(ds, opts, name, bounds, crs, cache_dir): + from .remote_data import fetch_buildings + src = opts.get('source', 'overture') + data = fetch_buildings(bounds=bounds, source=src, + cache_path=cache_dir / f"{name}_buildings.geojson") + place_kw = {} + if 'elev_scale' in opts: + place_kw['elev_scale'] = opts['elev_scale'] + info = ds.rtx.place_buildings(data, z='elevation', **place_kw) + print(f"Placed {info['geometries']} building geometries") + + +def _place_roads(ds, opts, name, bounds, crs, cache_dir): + from .remote_data import fetch_roads + for rt, gid, clr in [('major', 'road_major', (0.10, 0.10, 0.10)), + ('minor', 'road_minor', (0.55, 0.55, 0.55))]: + data = fetch_roads(bounds=bounds, road_type=rt, source='overture', + cache_path=cache_dir / f"{name}_roads_{rt}.geojson") + info = ds.rtx.place_roads(data, z='elevation', + geometry_id=gid, color=clr) + print(f"Placed {info['geometries']} {rt} road geometries") + + +def _place_water(ds, opts, name, bounds, crs, cache_dir): + from .remote_data import fetch_water + src = opts.get('source', 'overture') + wt = opts.get('water_type', 'all') + data = fetch_water(bounds=bounds, water_type=wt, source=src, + cache_path=cache_dir / f"{name}_water.geojson") + results = ds.rtx.place_water(data, z='elevation') + for cat, info in results.items(): + print(f"Placed {info['geometries']} {cat} water features") + + +def _place_fire(ds, opts, name, bounds, crs, cache_dir): + from .remote_data import fetch_firms + span = opts.get('date_span', '7d') + region = opts.get('region', None) + data = fetch_firms(bounds=bounds, date_span=span, region=region, + cache_path=cache_dir / f"{name}_fires.geojson", + crs=crs) + if data.get('features'): + info = ds.rtx.place_geojson( + data, z='elevation', height=20, + geometry_id='fire', color=(1.0, 0.25, 0.0, 3.0), + extrude=True, merge=True, + ) + print(f"Placed {info['geometries']} fire detection footprints") + else: + print("No fire detections in the last 7 days") + + +def _place_places(ds, opts, name, bounds, crs, cache_dir): + from .remote_data import fetch_places + category = opts.get('category', 'eat_and_drink') + data = fetch_places(bounds=bounds, category=category, + cache_path=cache_dir / f"{name}_places.geojson", + crs=crs) + if data.get('features'): + info = ds.rtx.place_geojson( + data, z='elevation', height=8, + geometry_id='places', color=(1.0, 0.8, 0.0), + merge=True, + ) + print(f"Placed {info['geometries']} place markers") + + +def _place_infrastructure(ds, opts, name, bounds, crs, cache_dir): + from .remote_data import fetch_infrastructure + itype = opts.get('infra_type', 'communication') + data = fetch_infrastructure(bounds=bounds, infra_type=itype, + cache_path=cache_dir / f"{name}_infra.geojson", + crs=crs) + if data.get('features'): + info = ds.rtx.place_geojson( + data, z='elevation', height=30, + geometry_id='infrastructure', color=(0.8, 0.2, 0.2), + merge=True, + ) + print(f"Placed {info['geometries']} infrastructure features") + + +def _place_land_use(ds, opts, name, bounds, crs, cache_dir): + from .remote_data import fetch_land_use + lt = opts.get('land_type', 'park') + data = fetch_land_use(bounds=bounds, land_type=lt, + cache_path=cache_dir / f"{name}_land_use.geojson", + crs=crs) + if data.get('features'): + info = ds.rtx.place_geojson( + data, z='elevation', height=2, + geometry_id='land_use', color=(0.2, 0.7, 0.3, 0.5), + extrude=True, merge=True, + ) + print(f"Placed {info['geometries']} park polygons") + + +def _place_restaurant_grades(ds, opts, name, bounds, crs, cache_dir): + from .remote_data import fetch_restaurant_grades + data = fetch_restaurant_grades( + bounds=bounds, + cache_path=cache_dir / f"{name}_restaurants.geojson", + ) + for grade, gid, clr in [ + ('A', 'grade_a', (0.20, 0.78, 0.40)), + ('B', 'grade_b', (0.95, 0.75, 0.10)), + ('C', 'grade_c', (0.90, 0.22, 0.20)), + ]: + subset = { + "type": "FeatureCollection", + "features": [f for f in data['features'] + if f['properties'].get('grade') == grade], + } + if subset['features']: + info = ds.rtx.place_geojson( + subset, z='elevation', height=15, + geometry_id=gid, color=clr, merge=True, + ) + print(f"Placed {info['geometries']} grade {grade} restaurants") + + +def _place_gtfs(ds, opts, name, bounds, crs, cache_dir): + from .remote_data import fetch_gtfs + data = fetch_gtfs(bounds=bounds, + feed_url=opts.get('feed_url'), + gtfs_path=opts.get('gtfs_path'), + route_types=opts.get('route_types'), + cache_path=cache_dir / f"{name}_gtfs.json", + crs=crs, + realtime_url=opts.get('realtime_url')) + if data['metadata']['n_routes'] > 0 or data['metadata']['n_stops'] > 0: + results = ds.rtx.place_gtfs(data, z='elevation', + stop_height=opts.get('stop_height', 8.0), + route_width=opts.get('route_width')) + for cat, groups in results.items(): + for label, info in groups.items(): + parts = [] + if 'routes' in info: + parts.append(f"{info['routes']['geometries']} routes") + if 'stops' in info: + parts.append(f"{info['stops']['geometries']} stops") + print(f"Placed {cat} [{label}]: {', '.join(parts)}") + # Stash gtfs_data for realtime overlay in explore() + ds.attrs['_gtfs_data'] = data + else: + print("No GTFS routes/stops found in bounds") + + +_FEATURE_HANDLERS = { + 'buildings': _place_buildings, + 'roads': _place_roads, + 'water': _place_water, + 'fire': _place_fire, + 'places': _place_places, + 'infrastructure': _place_infrastructure, + 'land_use': _place_land_use, + 'restaurant_grades': _place_restaurant_grades, + 'gtfs': _place_gtfs, +} diff --git a/rtxpy/remote_data.py b/rtxpy/remote_data.py index 97473cd..1581ec5 100644 --- a/rtxpy/remote_data.py +++ b/rtxpy/remote_data.py @@ -1,7 +1,8 @@ """Download remote geospatial data: DEM tiles, OSM features, buildings, roads, water, and fire. -Supports Copernicus GLO-30, USGS SRTM 1-arc-second, and USGS 3DEP -1-meter DEM sources, OpenStreetMap vector features via osmnx, +Supports Copernicus GLO-30, USGS SRTM 1-arc-second, USGS 3DEP +1/3-arc-second (10 m), and USGS 3DEP 1-meter DEM sources, +OpenStreetMap vector features via osmnx, Microsoft Global ML Building Footprints, convenience wrappers for roads and water features, and NASA FIRMS fire detection footprints. @@ -11,10 +12,75 @@ package is missing. """ +import json import math import re from pathlib import Path +# --------------------------------------------------------------------------- +# Overture Maps Foundation (optional, requires duckdb) +# --------------------------------------------------------------------------- + +_OVERTURE_RELEASE = '2026-01-21.0' +_OVERTURE_S3 = f's3://overturemaps-us-west-2/release/{_OVERTURE_RELEASE}' + +_OVERTURE_MAJOR_CLASSES = {'motorway', 'trunk', 'primary', 'secondary'} +_OVERTURE_MINOR_CLASSES = {'tertiary', 'residential', 'living_street', + 'unclassified', 'service'} + + +def _query_overture(bounds, theme, type_name, columns, release=None): + """Query Overture Maps GeoParquet on S3 via DuckDB. + + Parameters + ---------- + bounds : tuple of float + (west, south, east, north) in WGS84 degrees. + theme : str + Overture theme (e.g. ``'buildings'``, ``'transportation'``). + type_name : str + Overture type within theme (e.g. ``'building'``, ``'segment'``). + columns : list of str + Columns to select from the parquet dataset. + release : str, optional + Overture release version. Defaults to ``_OVERTURE_RELEASE``. + + Returns + ------- + pandas.DataFrame + One row per feature with requested columns plus ``geometry_json``. + """ + try: + import duckdb + except ImportError: + raise ImportError( + "duckdb is required for Overture Maps data. " + "Install with: pip install duckdb" + ) + + conn = duckdb.connect() + conn.execute("INSTALL spatial; LOAD spatial;") + conn.execute("INSTALL httpfs; LOAD httpfs;") + conn.execute("SET s3_region='us-west-2';") + + release = release or _OVERTURE_RELEASE + s3_path = (f's3://overturemaps-us-west-2/release/{release}' + f'/theme={theme}/type={type_name}/*') + + west, south, east, north = bounds + col_str = ', '.join(columns) + + query = f""" + SELECT {col_str}, ST_AsGeoJSON(geometry) AS geometry_json + FROM read_parquet('{s3_path}', filename=true, hive_partitioning=1) + WHERE bbox.xmin > {west} AND bbox.xmax < {east} + AND bbox.ymin > {south} AND bbox.ymax < {north} + """ + + result = conn.execute(query).fetchdf() + conn.close() + return result + def _compute_srtm_tiles(bounds): """Return list of (tile_name, url) for USGS SRTM 1-arc-second tiles. @@ -44,6 +110,35 @@ def _compute_srtm_tiles(bounds): return tiles +def _compute_usgs_10m_tiles(bounds): + """Return list of (tile_name, url) for USGS 3DEP 1/3-arc-second tiles. + + Same grid naming as SRTM (``n43w122`` covers lat [42, 43], + lon [-122, -121]) but hosted under the ``/13/`` path prefix + with filename prefix ``USGS_13_``. + """ + west, south, east, north = bounds + base_url = ( + "https://prd-tnm.s3.amazonaws.com" + "/StagedProducts/Elevation/13/TIFF/current" + ) + + lat_min = math.ceil(south) + lat_max = math.ceil(north) + lon_min = math.floor(west) + lon_max = math.floor(east) + + tiles = [] + for lat in range(lat_min, lat_max + 1): + for lon in range(lon_min, lon_max + 1): + ns = "n" if lat >= 0 else "s" + ew = "w" if lon < 0 else "e" + tile_name = f"{ns}{abs(lat):02d}{ew}{abs(lon):03d}" + url = f"{base_url}/{tile_name}/USGS_13_{tile_name}.tif" + tiles.append((tile_name, url)) + return tiles + + def _compute_copernicus_tiles(bounds): """Return list of (tile_name, url) for Copernicus GLO-30 tiles. @@ -167,6 +262,72 @@ def _download_tile(url, tile_path): f.write(chunk) +def _save_zarr(da, output_path): + """Save a DataArray to a CF-encoded zarr store with int16 compression.""" + import numpy as np + import xarray as xr + from zarr.codecs import BloscCodec, BloscShuffle + + H, W = da.shape[-2], da.shape[-1] + + # Build dataset with elevation + spatial_ref + ds = da.to_dataset(name='elevation') + + # Drop scalar coord variables left by rioxarray (e.g. 'band') + for coord in list(ds.coords): + if coord not in ('x', 'y') and coord not in ds.dims: + ds = ds.drop_vars(coord) + + # Add spatial_ref variable with CRS metadata + crs_wkt = da.rio.crs.to_wkt() if da.rio.crs else "" + transform = da.rio.transform() + geo_transform = (f"{transform.c} {transform.a} {transform.b} " + f"{transform.f} {transform.d} {transform.e}") + ds['spatial_ref'] = xr.DataArray( + np.int32(0), + attrs={ + 'crs_wkt': crs_wkt, + 'GeoTransform': geo_transform, + }, + ) + + # Drop attrs that conflict with our CF encoding or spatial_ref variable + for attr in ('grid_mapping', 'spatial_ref', + 'scale_factor', 'add_offset', '_FillValue'): + ds['elevation'].attrs.pop(attr, None) + + encoding = { + 'elevation': { + 'dtype': 'int16', + 'scale_factor': np.float64(0.1), + 'add_offset': np.float64(0.0), + '_FillValue': np.int16(-9999), + 'compressors': BloscCodec(cname='zstd', clevel=6, + shuffle=BloscShuffle.bitshuffle), + 'chunks': (min(2048, H), min(2048, W)), + }, + } + + ds.to_zarr(str(output_path), mode='w', encoding=encoding) + + +def _load_zarr(output_path): + """Load a CF-encoded zarr store and return a float DataArray with CRS.""" + import xarray as xr + import rioxarray # noqa: F401 — needed for .rio accessor + + ds = xr.open_zarr(str(output_path)) + da = ds['elevation'] + + # Attach CRS from spatial_ref variable + if 'spatial_ref' in ds: + crs_wkt = ds['spatial_ref'].attrs.get('crs_wkt', '') + if crs_wkt: + da = da.rio.write_crs(crs_wkt) + + return da + + def _merge_clip_reproject(tile_paths, bounds, crs, output_path): """Merge tile arrays, clip to bounds, optionally reproject, and save.""" try: @@ -203,7 +364,13 @@ def _merge_clip_reproject(tile_paths, bounds, crs, output_path): if crs is not None: merged = merged.rio.reproject(crs) - merged.rio.to_raster(str(output_path)) + # Dispatch by output format + output_path = Path(output_path) + if output_path.suffix == '.zarr': + _save_zarr(merged, output_path) + else: + merged.rio.to_raster(str(output_path)) + return merged @@ -215,12 +382,17 @@ def fetch_dem(bounds, output_path, source="copernicus", crs=None, cache_dir=None bounds : tuple of float (west, south, east, north) in WGS84 degrees. output_path : str or Path - Where to save the final merged/clipped/reprojected GeoTIFF. - If the file already exists, loads and returns it directly. + Where to save the final merged/clipped/reprojected DEM. + Use ``.zarr`` for a chunked, CF-encoded zarr store (int16 + + scale_factor=0.1, Blosc zstd compression) or ``.tif`` for + GeoTIFF. If the path already exists, loads and returns it + directly. source : str ``'copernicus'`` for Copernicus GLO-30 (30 m), ``'srtm'`` for - USGS 1-arc-second (~30 m), or ``'usgs_1m'`` for USGS 3DEP - 1-meter lidar DEM (US coverage only, ~30 MB per 10 km tile). + USGS 1-arc-second (~30 m), ``'usgs_10m'`` for USGS 3DEP + 1/3-arc-second (~10 m, US coverage), or ``'usgs_1m'`` for + USGS 3DEP 1-meter lidar DEM (US coverage only, ~30 MB per + 10 km tile). crs : str, optional Target CRS for reprojection (e.g. ``'EPSG:32620'``). ``None`` keeps the native CRS. @@ -242,7 +414,11 @@ def fetch_dem(bounds, output_path, source="copernicus", crs=None, cache_dir=None output_path = Path(output_path) - if output_path.exists(): + # Cache hit: zarr stores are directories, tif files are regular files + if output_path.suffix == '.zarr' and output_path.is_dir(): + print(f"Using cached DEM: {output_path.name}") + return _load_zarr(output_path) + elif output_path.suffix != '.zarr' and output_path.exists(): print(f"Using cached DEM: {output_path.name}") return rxr.open_rasterio(str(output_path), masked=True).squeeze() @@ -259,12 +435,15 @@ def fetch_dem(bounds, output_path, source="copernicus", crs=None, cache_dir=None elif source == "copernicus": tiles = _compute_copernicus_tiles(bounds) ext_prefix = "" + elif source == "usgs_10m": + tiles = _compute_usgs_10m_tiles(bounds) + ext_prefix = "USGS_13_" elif source == "usgs_1m": tiles = _query_usgs_1m_tiles(bounds) ext_prefix = "" else: raise ValueError( - f"Unknown source {source!r}; use 'copernicus', 'srtm', or 'usgs_1m'" + f"Unknown source {source!r}; use 'copernicus', 'srtm', 'usgs_10m', or 'usgs_1m'" ) print(f"Downloading {len(tiles)} {source} tile(s)...") @@ -446,14 +625,81 @@ def _feature_in_bounds(feature, west, south, east, north): return False -def fetch_buildings(bounds, cache_path=None, crs=None, cache_dir=None): - """Download Microsoft Global Building Footprints for a bounding box. +def _fetch_buildings_overture(bounds, cache_path=None, crs=None): + """Fetch building footprints from Overture Maps via DuckDB.""" + print("Querying Overture Maps buildings...") + df = _query_overture(bounds, 'buildings', 'building', + ['height', 'num_floors', 'names']) + + import pandas as pd + + features = [] + for _, row in df.iterrows(): + geom = json.loads(row['geometry_json']) + + # Height: prefer explicit height, fall back to num_floors * 3m + height = row.get('height') + if pd.isna(height): + nf = row.get('num_floors') + if not pd.isna(nf): + height = float(nf) * 3.0 + else: + height = -1.0 + + features.append({ + "type": "Feature", + "geometry": geom, + "properties": { + "height": float(height), + "confidence": 1.0, + }, + }) + + print(f" Found {len(features)} buildings from Overture Maps") + + # Reproject if requested + if crs is not None and features: + try: + import geopandas as gpd + except ImportError: + raise ImportError( + "geopandas is required for CRS reprojection. " + "Install with: pip install geopandas" + ) + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + gdf = gdf.to_crs(crs) + print(f" Reprojected to {crs}") + geojson = json.loads(gdf.to_json()) + else: + geojson = {"type": "FeatureCollection", "features": features} + + # Cache result + if cache_path is not None: + cache_path = Path(cache_path) + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, "w") as f: + json.dump(geojson, f) + print(f" Cached to {cache_path}") + + return geojson + + +def fetch_buildings(bounds, cache_path=None, crs=None, cache_dir=None, + source='microsoft'): + """Download building footprints for a bounding box. - Uses the `dataset-links.csv - `_ index to - find level-9 quadkey partitions that overlap *bounds*, downloads the - compressed GeoJSONL files, filters features to the bounding box, and - returns a standard GeoJSON FeatureCollection. + Supports two data sources: + + - ``'microsoft'`` (default) — Microsoft Global ML Building Footprints. + Uses the `dataset-links.csv + `_ index to + find level-9 quadkey partitions that overlap *bounds*, downloads the + compressed GeoJSONL files, filters features to the bounding box, and + returns a standard GeoJSON FeatureCollection. + - ``'overture'`` — Overture Maps Foundation buildings. Queries the + Overture GeoParquet dataset on S3 via DuckDB (requires ``duckdb``). + Provides deduplicated footprints with richer attributes (height, + number of floors) aggregated from OSM, Microsoft, Meta, and Esri. Parameters ---------- @@ -469,6 +715,9 @@ def fetch_buildings(bounds, cache_path=None, crs=None, cache_dir=None): cache_dir : str or Path, optional Directory for caching the dataset-links.csv index and downloaded partition files. Defaults to ``~/.cache/rtxpy/buildings``. + Only used with ``source='microsoft'``. + source : str + Data source: ``'microsoft'`` (default) or ``'overture'``. Returns ------- @@ -484,8 +733,6 @@ def fetch_buildings(bounds, cache_path=None, crs=None, cache_dir=None): ... crs='EPSG:32620') >>> dem.rtx.place_geojson(bldgs, height=8.0) """ - import json - if cache_path is not None: cache_path = Path(cache_path) if cache_path.exists(): @@ -493,6 +740,15 @@ def fetch_buildings(bounds, cache_path=None, crs=None, cache_dir=None): with open(cache_path) as f: return json.load(f) + source = source.lower() + if source == 'overture': + return _fetch_buildings_overture(bounds, cache_path=cache_path, + crs=crs) + elif source != 'microsoft': + raise ValueError( + f"Unknown source {source!r}; use 'microsoft' or 'overture'" + ) + if cache_dir is None: cache_dir = Path.home() / ".cache" / "rtxpy" / "buildings" else: @@ -618,8 +874,98 @@ def fetch_buildings(bounds, cache_path=None, crs=None, cache_dir=None): ] -def fetch_roads(bounds, road_type="all", cache_path=None, crs=None): - """Download road data from OpenStreetMap for a bounding box. +def _fetch_roads_overture(bounds, road_type="all", cache_path=None, crs=None): + """Fetch road data from Overture Maps via DuckDB.""" + print("Querying Overture Maps roads...") + df = _query_overture(bounds, 'transportation', 'segment', + ['class', 'subtype', 'names', 'road_surface']) + + # Filter to road subtypes only (exclude rail, water) + if 'subtype' in df.columns: + df = df[df['subtype'] == 'road'] + + # Filter by road_type using Overture class values + road_type = road_type.lower() + if road_type == 'major': + allowed = _OVERTURE_MAJOR_CLASSES + elif road_type == 'minor': + allowed = _OVERTURE_MINOR_CLASSES + elif road_type == 'all': + allowed = _OVERTURE_MAJOR_CLASSES | _OVERTURE_MINOR_CLASSES + else: + raise ValueError( + f"Unknown road_type {road_type!r}; use 'major', 'minor', or 'all'" + ) + + if 'class' in df.columns: + df = df[df['class'].isin(allowed)] + + import pandas as pd + + features = [] + for _, row in df.iterrows(): + geom = json.loads(row['geometry_json']) + + # Extract name from Overture names struct (may be dict or None) + names = row.get('names') + name = None + if isinstance(names, dict): + name = names.get('primary', None) + + # Map Overture class to OSM-style highway tag for compatibility + road_class = row.get('class', '') + if pd.isna(road_class): + road_class = None + + features.append({ + "type": "Feature", + "geometry": geom, + "properties": { + "name": name, + "highway": road_class if road_class else None, + }, + }) + + print(f" Found {len(features)} road segments from Overture Maps") + + # Reproject if requested + if crs is not None and features: + try: + import geopandas as gpd + except ImportError: + raise ImportError( + "geopandas is required for CRS reprojection. " + "Install with: pip install geopandas" + ) + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + gdf = gdf.to_crs(crs) + print(f" Reprojected to {crs}") + geojson = json.loads(gdf.to_json()) + else: + geojson = {"type": "FeatureCollection", "features": features} + + # Cache result + if cache_path is not None: + cache_path = Path(cache_path) + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, "w") as f: + json.dump(geojson, f) + print(f" Cached to {cache_path}") + + return geojson + + +def fetch_roads(bounds, road_type="all", cache_path=None, crs=None, + source='osm'): + """Download road data for a bounding box. + + Supports two data sources: + + - ``'osm'`` (default) — OpenStreetMap via osmnx. + - ``'overture'`` — Overture Maps Foundation transportation data. + Queries the Overture GeoParquet dataset on S3 via DuckDB (requires + ``duckdb``). Provides deduplicated road segments with + classification and surface attributes. Parameters ---------- @@ -637,6 +983,8 @@ def fetch_roads(bounds, road_type="all", cache_path=None, crs=None): crs : str, optional Target CRS for reprojection (e.g. ``'EPSG:32620'``). ``None`` keeps WGS84. + source : str + Data source: ``'osm'`` (default) or ``'overture'``. Returns ------- @@ -650,6 +998,15 @@ def fetch_roads(bounds, road_type="all", cache_path=None, crs=None): ... road_type='major', crs='EPSG:5070') >>> dem.rtx.place_geojson(roads, height=3.0, label_field='name') """ + source = source.lower() + if source == 'overture': + return _fetch_roads_overture(bounds, road_type=road_type, + cache_path=cache_path, crs=crs) + elif source != 'osm': + raise ValueError( + f"Unknown source {source!r}; use 'osm' or 'overture'" + ) + road_type = road_type.lower() if road_type == "major": values = _MAJOR_ROAD_VALUES @@ -669,9 +1026,109 @@ def fetch_roads(bounds, road_type="all", cache_path=None, crs=None): # OSM tags for water features _WATERWAY_VALUES = ["river", "stream", "canal", "drain", "ditch"] +# Overture water subtype groupings +_OVERTURE_WATERWAY_SUBTYPES = {'river', 'canal', 'stream'} +_OVERTURE_WATERWAY_CLASSES = {'drain', 'ditch'} +_OVERTURE_WATERBODY_SUBTYPES = {'lake', 'pond', 'reservoir', 'ocean'} + + +def _fetch_water_overture(bounds, water_type="all", cache_path=None, crs=None): + """Fetch water features from Overture Maps via DuckDB.""" + print("Querying Overture Maps water...") + df = _query_overture(bounds, 'base', 'water', + ['subtype', 'class', 'names']) + + import pandas as pd + + water_type = water_type.lower() + + # Filter by water_type + if water_type == 'waterway': + mask = (df['subtype'].isin(_OVERTURE_WATERWAY_SUBTYPES) | + df['class'].isin(_OVERTURE_WATERWAY_CLASSES)) + df = df[mask] + elif water_type == 'waterbody': + df = df[df['subtype'].isin(_OVERTURE_WATERBODY_SUBTYPES)] + elif water_type != 'all': + raise ValueError( + f"Unknown water_type {water_type!r}; " + "use 'waterway', 'waterbody', or 'all'" + ) + + features = [] + for _, row in df.iterrows(): + geom = json.loads(row['geometry_json']) + + # Extract name from Overture names struct + names = row.get('names') + name = None + if isinstance(names, dict): + name = names.get('primary', None) + + # Map Overture subtype/class → OSM-style properties for place_water() + subtype = row.get('subtype', '') + cls = row.get('class', '') + if pd.isna(subtype): + subtype = '' + if pd.isna(cls): + cls = '' + + props = {"name": name} + if subtype in ('river', 'canal'): + props['waterway'] = subtype + elif subtype == 'stream' or cls in ('drain', 'ditch'): + props['waterway'] = cls if cls in ('drain', 'ditch') else 'stream' + elif subtype in ('lake', 'pond', 'reservoir', 'ocean'): + props['natural'] = 'water' + else: + # Fallback: treat as minor waterway + props['waterway'] = 'stream' + + features.append({ + "type": "Feature", + "geometry": geom, + "properties": props, + }) + + print(f" Found {len(features)} water features from Overture Maps") + + # Reproject if requested + if crs is not None and features: + try: + import geopandas as gpd + except ImportError: + raise ImportError( + "geopandas is required for CRS reprojection. " + "Install with: pip install geopandas" + ) + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + gdf = gdf.to_crs(crs) + print(f" Reprojected to {crs}") + geojson = json.loads(gdf.to_json()) + else: + geojson = {"type": "FeatureCollection", "features": features} + + # Cache result + if cache_path is not None: + cache_path = Path(cache_path) + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, "w") as f: + json.dump(geojson, f) + print(f" Cached to {cache_path}") + + return geojson + -def fetch_water(bounds, water_type="all", cache_path=None, crs=None): - """Download water / waterway features from OpenStreetMap. +def fetch_water(bounds, water_type="all", cache_path=None, crs=None, + source='osm'): + """Download water / waterway features. + + Supports two data sources: + + - ``'osm'`` (default) — OpenStreetMap via osmnx. + - ``'overture'`` — Overture Maps Foundation water data. + Queries the Overture GeoParquet dataset on S3 via DuckDB (requires + ``duckdb``). Parameters ---------- @@ -681,8 +1138,7 @@ def fetch_water(bounds, water_type="all", cache_path=None, crs=None): Which features to include: - ``'waterway'`` — linear features: rivers, streams, canals, etc. - - ``'waterbody'`` — area features: lakes, reservoirs, ponds - (``natural=water``). + - ``'waterbody'`` — area features: lakes, reservoirs, ponds. - ``'all'`` (default) — both waterways and waterbodies. cache_path : str or Path, optional Path to cache the result as GeoJSON. If the file already @@ -690,6 +1146,8 @@ def fetch_water(bounds, water_type="all", cache_path=None, crs=None): crs : str, optional Target CRS for reprojection (e.g. ``'EPSG:32620'``). ``None`` keeps WGS84. + source : str + Data source: ``'osm'`` (default) or ``'overture'``. Returns ------- @@ -703,6 +1161,22 @@ def fetch_water(bounds, water_type="all", cache_path=None, crs=None): ... water_type='waterway', crs='EPSG:32620') >>> dem.rtx.place_geojson(rivers, height=2.0, label_field='name') """ + if cache_path is not None: + cache_path = Path(cache_path) + if cache_path.exists(): + print(f"Using cached water data: {cache_path.name}") + with open(cache_path) as f: + return json.load(f) + + source = source.lower() + if source == 'overture': + return _fetch_water_overture(bounds, water_type=water_type, + cache_path=cache_path, crs=crs) + elif source != 'osm': + raise ValueError( + f"Unknown source {source!r}; use 'osm' or 'overture'" + ) + water_type = water_type.lower() if water_type == "waterway": tags = {"waterway": _WATERWAY_VALUES} @@ -720,50 +1194,547 @@ def fetch_water(bounds, water_type="all", cache_path=None, crs=None): # --------------------------------------------------------------------------- -# Wind data (Open-Meteo) +# Overture Maps: places, infrastructure, land use # --------------------------------------------------------------------------- -_OPEN_METEO_URL = "https://api.open-meteo.com/v1/forecast" - -def fetch_wind(bounds, grid_size=20): - """Fetch current wind data from Open-Meteo for a bounding box. +def fetch_places(bounds, category=None, cache_path=None, crs=None): + """Download point-of-interest data from Overture Maps. - Queries the Open-Meteo forecast API for 10 m wind speed and - direction on a regular lat/lon grid, then decomposes into U/V - components suitable for particle animation. + Returns Point geometries from the Overture ``places/place`` dataset. Parameters ---------- bounds : tuple of float (west, south, east, north) in WGS84 degrees. - grid_size : int - Number of grid points along each axis (default 20). - Total API points = grid_size². Open-Meteo allows up to - ~1 000 points per request. + category : str or list of str, optional + Filter on Overture ``categories.primary`` (e.g. ``'eat_and_drink'``, + ``'education'``, ``['hospital', 'school']``). ``None`` returns all. + cache_path : str or Path, optional + Path to cache the result as GeoJSON. If the file already + exists, loads and returns it directly. + crs : str, optional + Target CRS for reprojection. ``None`` keeps WGS84. Returns ------- dict - ``'u'`` : ndarray (ny, nx) — east–west wind component (m/s). - ``'v'`` : ndarray (ny, nx) — north–south wind component (m/s). - ``'speed'`` : ndarray (ny, nx) — wind speed (m/s). - ``'direction'`` : ndarray (ny, nx) — meteorological direction - (degrees, where wind is coming *from*). - ``'lats'`` : ndarray (ny,) — latitude values. - ``'lons'`` : ndarray (nx,) — longitude values. + GeoJSON FeatureCollection with Point geometries. Each feature + has ``name``, ``category``, and ``confidence`` properties. Examples -------- - >>> from rtxpy import fetch_wind - >>> wind = fetch_wind((-43.42, -23.08, -43.10, -22.84)) - >>> wind['u'].shape - (20, 20) + >>> from rtxpy import fetch_places + >>> restaurants = fetch_places((-61.55, 10.62, -61.48, 10.69), + ... category='eat_and_drink') """ - try: - import requests - except ImportError: - raise ImportError( + if cache_path is not None: + cache_path = Path(cache_path) + if cache_path.exists(): + print(f"Using cached places data: {cache_path.name}") + with open(cache_path) as f: + return json.load(f) + + print("Querying Overture Maps places...") + df = _query_overture(bounds, 'places', 'place', + ['names', 'categories', 'confidence']) + + import pandas as pd + + # Filter by category + if category is not None: + if isinstance(category, str): + category = [category] + category_set = set(category) + + def _matches_category(cats): + if isinstance(cats, dict): + return cats.get('primary', '') in category_set + return False + + df = df[df['categories'].apply(_matches_category)] + + features = [] + for _, row in df.iterrows(): + geom = json.loads(row['geometry_json']) + + names = row.get('names') + name = None + if isinstance(names, dict): + name = names.get('primary', None) + + cats = row.get('categories') + cat = None + if isinstance(cats, dict): + cat = cats.get('primary', None) + + conf = row.get('confidence') + if pd.isna(conf): + conf = -1.0 + + features.append({ + "type": "Feature", + "geometry": geom, + "properties": { + "name": name, + "category": cat, + "confidence": float(conf), + }, + }) + + print(f" Found {len(features)} places from Overture Maps") + + # Reproject if requested + if crs is not None and features: + try: + import geopandas as gpd + except ImportError: + raise ImportError( + "geopandas is required for CRS reprojection. " + "Install with: pip install geopandas" + ) + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + gdf = gdf.to_crs(crs) + print(f" Reprojected to {crs}") + geojson = json.loads(gdf.to_json()) + else: + geojson = {"type": "FeatureCollection", "features": features} + + # Cache result + if cache_path is not None: + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, "w") as f: + json.dump(geojson, f) + print(f" Cached to {cache_path}") + + return geojson + + +def fetch_infrastructure(bounds, infra_type='all', cache_path=None, crs=None): + """Download infrastructure features from Overture Maps. + + Returns features from the Overture ``base/infrastructure`` dataset. + + Parameters + ---------- + bounds : tuple of float + (west, south, east, north) in WGS84 degrees. + infra_type : str + Filter on Overture ``subtype``: + + - ``'communication'`` — cell towers, antennas + - ``'power'`` — power lines, substations + - ``'bridge'`` — bridges + - ``'tower'`` — towers + - ``'transit'`` — transit stations + - ``'airport'`` — airports, runways, helipads + - ``'all'`` (default) — everything + cache_path : str or Path, optional + Path to cache the result as GeoJSON. + crs : str, optional + Target CRS for reprojection. ``None`` keeps WGS84. + + Returns + ------- + dict + GeoJSON FeatureCollection. Each feature has ``name``, + ``subtype``, ``class``, and ``height`` (metres, -1 if unknown) + properties. + + Examples + -------- + >>> from rtxpy import fetch_infrastructure + >>> towers = fetch_infrastructure((-61.55, 10.62, -61.48, 10.69), + ... infra_type='communication') + """ + if cache_path is not None: + cache_path = Path(cache_path) + if cache_path.exists(): + print(f"Using cached infrastructure data: {cache_path.name}") + with open(cache_path) as f: + return json.load(f) + + print("Querying Overture Maps infrastructure...") + df = _query_overture(bounds, 'base', 'infrastructure', + ['subtype', 'class', 'names', 'height']) + + import pandas as pd + + infra_type = infra_type.lower() + valid_types = {'communication', 'power', 'bridge', 'tower', + 'transit', 'airport'} + if infra_type != 'all': + if infra_type not in valid_types: + raise ValueError( + f"Unknown infra_type {infra_type!r}; use one of " + f"{sorted(valid_types)} or 'all'" + ) + if 'subtype' in df.columns: + df = df[df['subtype'] == infra_type] + + features = [] + for _, row in df.iterrows(): + geom = json.loads(row['geometry_json']) + + names = row.get('names') + name = None + if isinstance(names, dict): + name = names.get('primary', None) + + subtype = row.get('subtype', None) + if pd.isna(subtype): + subtype = None + cls = row.get('class', None) + if pd.isna(cls): + cls = None + + height = row.get('height') + if pd.isna(height): + height = -1.0 + + features.append({ + "type": "Feature", + "geometry": geom, + "properties": { + "name": name, + "subtype": subtype, + "class": cls, + "height": float(height), + }, + }) + + print(f" Found {len(features)} infrastructure features from Overture Maps") + + # Reproject if requested + if crs is not None and features: + try: + import geopandas as gpd + except ImportError: + raise ImportError( + "geopandas is required for CRS reprojection. " + "Install with: pip install geopandas" + ) + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + gdf = gdf.to_crs(crs) + print(f" Reprojected to {crs}") + geojson = json.loads(gdf.to_json()) + else: + geojson = {"type": "FeatureCollection", "features": features} + + # Cache result + if cache_path is not None: + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, "w") as f: + json.dump(geojson, f) + print(f" Cached to {cache_path}") + + return geojson + + +def fetch_land_use(bounds, land_type='all', cache_path=None, crs=None): + """Download land use/land cover features from Overture Maps. + + Returns features from the Overture ``base/land_use`` dataset. + + Parameters + ---------- + bounds : tuple of float + (west, south, east, north) in WGS84 degrees. + land_type : str + Filter on Overture ``subtype``: + + - ``'residential'``, ``'park'``, ``'agriculture'``, + ``'education'``, ``'military'``, ``'protected'``, + ``'developed'``, ``'recreation'`` + - ``'all'`` (default) — everything + cache_path : str or Path, optional + Path to cache the result as GeoJSON. + crs : str, optional + Target CRS for reprojection. ``None`` keeps WGS84. + + Returns + ------- + dict + GeoJSON FeatureCollection. Each feature has ``name``, + ``subtype``, and ``class`` properties. + + Examples + -------- + >>> from rtxpy import fetch_land_use + >>> parks = fetch_land_use((-61.55, 10.62, -61.48, 10.69), + ... land_type='park') + """ + if cache_path is not None: + cache_path = Path(cache_path) + if cache_path.exists(): + print(f"Using cached land use data: {cache_path.name}") + with open(cache_path) as f: + return json.load(f) + + print("Querying Overture Maps land use...") + df = _query_overture(bounds, 'base', 'land_use', + ['subtype', 'class', 'names']) + + import pandas as pd + + land_type = land_type.lower() + valid_types = {'residential', 'park', 'agriculture', 'education', + 'military', 'protected', 'developed', 'recreation'} + if land_type != 'all': + if land_type not in valid_types: + raise ValueError( + f"Unknown land_type {land_type!r}; use one of " + f"{sorted(valid_types)} or 'all'" + ) + if 'subtype' in df.columns: + df = df[df['subtype'] == land_type] + + features = [] + for _, row in df.iterrows(): + geom = json.loads(row['geometry_json']) + + names = row.get('names') + name = None + if isinstance(names, dict): + name = names.get('primary', None) + + subtype = row.get('subtype', None) + if pd.isna(subtype): + subtype = None + cls = row.get('class', None) + if pd.isna(cls): + cls = None + + features.append({ + "type": "Feature", + "geometry": geom, + "properties": { + "name": name, + "subtype": subtype, + "class": cls, + }, + }) + + print(f" Found {len(features)} land use features from Overture Maps") + + # Reproject if requested + if crs is not None and features: + try: + import geopandas as gpd + except ImportError: + raise ImportError( + "geopandas is required for CRS reprojection. " + "Install with: pip install geopandas" + ) + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + gdf = gdf.to_crs(crs) + print(f" Reprojected to {crs}") + geojson = json.loads(gdf.to_json()) + else: + geojson = {"type": "FeatureCollection", "features": features} + + # Cache result + if cache_path is not None: + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, "w") as f: + json.dump(geojson, f) + print(f" Cached to {cache_path}") + + return geojson + + +# --------------------------------------------------------------------------- +# NYC Open Data: restaurant inspection grades +# --------------------------------------------------------------------------- + +_NYC_RESTAURANT_URL = "https://data.cityofnewyork.us/resource/43nn-pn8j.json" + + +def fetch_restaurant_grades(bounds, cache_path=None, crs=None): + """Download NYC restaurant health inspection grades. + + Queries the DOHMH Restaurant Inspection Results dataset on NYC Open + Data and returns one Point per restaurant with its most recent + letter grade. + + Parameters + ---------- + bounds : tuple of float + (west, south, east, north) in WGS84 degrees. + cache_path : str or Path, optional + Path to cache the result as GeoJSON. If the file already + exists, loads and returns it directly. + crs : str, optional + Target CRS for reprojection. ``None`` keeps WGS84. + + Returns + ------- + dict + GeoJSON FeatureCollection with Point geometries. Each feature + has ``name``, ``cuisine``, ``grade`` (A/B/C), and ``score`` + (lower is better) properties. + + Examples + -------- + >>> from rtxpy import fetch_restaurant_grades + >>> grades = fetch_restaurant_grades((-74.02, 40.70, -73.97, 40.75)) + >>> len(grades['features']) + 1234 + """ + try: + import requests + except ImportError: + raise ImportError( + "requests is required for fetch_restaurant_grades(). " + "Install with: pip install requests" + ) + + if cache_path is not None: + cache_path = Path(cache_path) + if cache_path.exists(): + print(f"Using cached restaurant grades: {cache_path.name}") + with open(cache_path) as f: + return json.load(f) + + west, south, east, north = bounds + print("Fetching NYC restaurant inspection grades...") + + # Paginated fetch — SODA API caps at 50000 rows per request + all_rows = [] + offset = 0 + page_size = 50000 + while True: + params = { + "$select": ("camis, dba, cuisine_description, " + "latitude, longitude, score, grade, grade_date"), + "$where": (f"latitude between {south} and {north} " + f"AND longitude between {west} and {east} " + f"AND grade IS NOT NULL AND latitude > 0"), + "$order": "camis, grade_date DESC", + "$limit": page_size, + "$offset": offset, + } + resp = requests.get(_NYC_RESTAURANT_URL, params=params, timeout=60) + resp.raise_for_status() + rows = resp.json() + all_rows.extend(rows) + if len(rows) < page_size: + break + offset += page_size + + print(f" Received {len(all_rows)} inspection rows") + + # Deduplicate: keep first row per camis (latest grade_date due to ordering) + seen = set() + features = [] + for row in all_rows: + camis = row.get('camis') + if camis in seen: + continue + seen.add(camis) + + lat = float(row.get('latitude', 0)) + lon = float(row.get('longitude', 0)) + if lat == 0 or lon == 0: + continue + + grade = row.get('grade', '') + if grade not in ('A', 'B', 'C'): + continue + + try: + score = int(row.get('score', -1)) + except (ValueError, TypeError): + score = -1 + + features.append({ + "type": "Feature", + "geometry": { + "type": "Point", + "coordinates": [lon, lat], + }, + "properties": { + "name": row.get('dba', ''), + "cuisine": row.get('cuisine_description', ''), + "grade": grade, + "score": score, + }, + }) + + print(f" {len(features)} unique restaurants with grades A/B/C") + + # Reproject if requested + if crs is not None and features: + try: + import geopandas as gpd + except ImportError: + raise ImportError( + "geopandas is required for CRS reprojection. " + "Install with: pip install geopandas" + ) + gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326") + gdf = gdf.to_crs(crs) + print(f" Reprojected to {crs}") + geojson = json.loads(gdf.to_json()) + else: + geojson = {"type": "FeatureCollection", "features": features} + + # Cache result + if cache_path is not None: + cache_path = Path(cache_path) + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, "w") as f: + json.dump(geojson, f) + print(f" Cached to {cache_path}") + + return geojson + + +# --------------------------------------------------------------------------- +# Wind data (Open-Meteo) +# --------------------------------------------------------------------------- + +_OPEN_METEO_URL = "https://api.open-meteo.com/v1/forecast" + + +def fetch_wind(bounds, grid_size=20): + """Fetch current wind data from Open-Meteo for a bounding box. + + Queries the Open-Meteo forecast API for 10 m wind speed and + direction on a regular lat/lon grid, then decomposes into U/V + components suitable for particle animation. + + Parameters + ---------- + bounds : tuple of float + (west, south, east, north) in WGS84 degrees. + grid_size : int + Number of grid points along each axis (default 20). + Total API points = grid_size². Open-Meteo allows up to + ~1 000 points per request. + + Returns + ------- + dict + ``'u'`` : ndarray (ny, nx) — east–west wind component (m/s). + ``'v'`` : ndarray (ny, nx) — north–south wind component (m/s). + ``'speed'`` : ndarray (ny, nx) — wind speed (m/s). + ``'direction'`` : ndarray (ny, nx) — meteorological direction + (degrees, where wind is coming *from*). + ``'lats'`` : ndarray (ny,) — latitude values. + ``'lons'`` : ndarray (nx,) — longitude values. + + Examples + -------- + >>> from rtxpy import fetch_wind + >>> wind = fetch_wind((-43.42, -23.08, -43.10, -22.84)) + >>> wind['u'].shape + (20, 20) + """ + try: + import requests + except ImportError: + raise ImportError( "requests is required for fetch_wind(). " "Install it with: pip install requests" ) @@ -1067,3 +2038,524 @@ def fetch_firms(bounds, date_span="24h", region=None, cache_path=None, print(f" Cached to {cache_path}") return geojson + + +# --------------------------------------------------------------------------- +# GTFS Transit Feeds +# --------------------------------------------------------------------------- + +_GTFS_ROUTE_TYPE_NAMES = { + 0: 'tram', 1: 'subway', 2: 'rail', 3: 'bus', 4: 'ferry', + 5: 'tram', 6: 'gondola', 7: 'funicular', + # Extended route types (hundreds) + 100: 'rail', 200: 'rail', 400: 'subway', 700: 'bus', 900: 'tram', + 1000: 'ferry', 1100: 'bus', 1300: 'gondola', 1400: 'funicular', +} + + +def _gtfs_route_type_name(route_type): + """Map GTFS route_type int to a human-readable category name.""" + rt = int(route_type) + if rt in _GTFS_ROUTE_TYPE_NAMES: + return _GTFS_ROUTE_TYPE_NAMES[rt] + # Extended types (>= 100): use hundreds group + if rt >= 100: + hundreds = (rt // 100) * 100 + return _GTFS_ROUTE_TYPE_NAMES.get(hundreds, 'other') + return 'other' + + +def _discover_gtfs_feeds(bounds, cache_dir): + """Query the Mobility Database CSV catalog to find GTFS feeds overlapping bounds. + + Returns list of dicts with keys: feed_id, provider, feed_url, bbox. + """ + import requests + import csv + import io + + cache_dir = Path(cache_dir) + cache_dir.mkdir(parents=True, exist_ok=True) + catalog_path = cache_dir / 'mobility_database_catalogs.csv' + + # Download catalog if not cached (refresh weekly) + import time + if not catalog_path.exists() or (time.time() - catalog_path.stat().st_mtime > 7 * 86400): + print(" Downloading Mobility Database catalog...") + url = "https://bit.ly/catalogs-csv" + resp = requests.get(url, timeout=60, allow_redirects=True) + resp.raise_for_status() + catalog_path.write_bytes(resp.content) + print(f" Catalog cached to {catalog_path}") + + # Parse catalog and find overlapping feeds + west, south, east, north = bounds + matches = [] + with open(catalog_path, 'r', encoding='utf-8') as f: + reader = csv.DictReader(f) + for row in reader: + # Only GTFS (not GTFS-RT) + data_type = row.get('data_type', '') + if data_type != 'gtfs': + continue + status = row.get('status', '') + if status not in ('', 'active'): + continue + # Check bounding box overlap + try: + feed_west = float(row.get('location.bounding_box.minimum_longitude', '')) + feed_south = float(row.get('location.bounding_box.minimum_latitude', '')) + feed_east = float(row.get('location.bounding_box.maximum_longitude', '')) + feed_north = float(row.get('location.bounding_box.maximum_latitude', '')) + except (ValueError, TypeError): + continue + # Check overlap + if feed_east < west or feed_west > east or feed_north < south or feed_south > north: + continue + feed_url = row.get('urls.latest', '') or row.get('urls.direct_download', '') + if not feed_url: + continue + # Compute IoU (intersection over union) for ranking + # This prefers feeds whose bbox tightly matches the query + ow = max(0, min(east, feed_east) - max(west, feed_west)) + oh = max(0, min(north, feed_north) - max(south, feed_south)) + intersection = ow * oh + query_area = (east - west) * (north - south) + feed_area = max(1e-12, (feed_east - feed_west) * (feed_north - feed_south)) + union = query_area + feed_area - intersection + iou = intersection / union if union > 0 else 0 + provider = row.get('provider', 'Unknown') + feed_id = row.get('mdb_source_id', '') + matches.append({ + 'feed_id': feed_id, + 'provider': provider, + 'feed_url': feed_url, + 'bbox': (feed_west, feed_south, feed_east, feed_north), + 'overlap': intersection, + 'iou': iou, + }) + + # Sort by IoU (best spatial match first) + matches.sort(key=lambda m: m['iou'], reverse=True) + return matches + + +def _parse_gtfs_zip(zip_path_or_bytes, bounds, route_types=None, + include_stops=True): + """Parse a GTFS ZIP file and return routes/stops as GeoJSON. + + Parameters + ---------- + zip_path_or_bytes : str, Path, or bytes + Path to GTFS ZIP or raw bytes. + bounds : tuple + (west, south, east, north) for spatial filtering. + route_types : list of int, optional + Filter to these GTFS route_type values. + include_stops : bool + Whether to also extract stops. + + Returns + ------- + dict + ``{'routes': FeatureCollection, 'stops': FeatureCollection, + 'metadata': {...}}`` + """ + import zipfile + import io + + try: + import pandas as pd + except ImportError: + raise ImportError( + "pandas is required for GTFS parsing. " + "Install with: pip install pandas" + ) + + west, south, east, north = bounds + + # Open ZIP + if isinstance(zip_path_or_bytes, (str, Path)): + zf = zipfile.ZipFile(zip_path_or_bytes, 'r') + else: + zf = zipfile.ZipFile(io.BytesIO(zip_path_or_bytes), 'r') + + zip_names = set(zf.namelist()) + + def _read_csv(name): + if name not in zip_names: + return None + with zf.open(name) as f: + return pd.read_csv(f, dtype=str, keep_default_na=False) + + routes_df = _read_csv('routes.txt') + trips_df = _read_csv('trips.txt') + shapes_df = _read_csv('shapes.txt') + stops_df = _read_csv('stops.txt') + stop_times_df = _read_csv('stop_times.txt') + + zf.close() + + if routes_df is None: + raise ValueError("GTFS ZIP missing routes.txt") + + # Normalise route_type to int + if 'route_type' in routes_df.columns: + routes_df['route_type'] = pd.to_numeric(routes_df['route_type'], + errors='coerce').fillna(3).astype(int) + + # Filter route types if requested + if route_types is not None: + routes_df = routes_df[routes_df['route_type'].isin(route_types)] + + # Build route info lookup + route_info = {} + for _, r in routes_df.iterrows(): + rid = r.get('route_id', '') + route_info[rid] = { + 'route_type': int(r.get('route_type', 3)), + 'route_color': r.get('route_color', ''), + 'route_short_name': r.get('route_short_name', ''), + 'route_long_name': r.get('route_long_name', ''), + } + + # --- Build route LineStrings --- + route_features = [] + + if shapes_df is not None and len(shapes_df) > 0: + # Join shapes -> trips -> routes + shapes_df['shape_pt_lat'] = pd.to_numeric(shapes_df['shape_pt_lat'], + errors='coerce') + shapes_df['shape_pt_lon'] = pd.to_numeric(shapes_df['shape_pt_lon'], + errors='coerce') + shapes_df['shape_pt_sequence'] = pd.to_numeric( + shapes_df['shape_pt_sequence'], errors='coerce') + shapes_df = shapes_df.dropna(subset=['shape_pt_lat', 'shape_pt_lon', + 'shape_pt_sequence']) + + # Get shape_id -> route_id mapping via trips + shape_route = {} + if trips_df is not None: + for _, t in trips_df.iterrows(): + sid = t.get('shape_id', '') + rid = t.get('route_id', '') + if sid and rid and rid in route_info: + if sid not in shape_route: + shape_route[sid] = rid + + # Build LineStrings from shapes + for shape_id, group in shapes_df.groupby('shape_id'): + group = group.sort_values('shape_pt_sequence') + coords = list(zip(group['shape_pt_lon'].values, + group['shape_pt_lat'].values)) + if len(coords) < 2: + continue + + # Spatial filter: check if any point in bounds + lons = group['shape_pt_lon'].values + lats = group['shape_pt_lat'].values + if lons.max() < west or lons.min() > east: + continue + if lats.max() < south or lats.min() > north: + continue + + rid = shape_route.get(shape_id, '') + props = dict(route_info.get(rid, {})) + props['route_id'] = rid + props['shape_id'] = shape_id + props.setdefault('route_type', 3) + + route_features.append({ + 'type': 'Feature', + 'geometry': {'type': 'LineString', 'coordinates': coords}, + 'properties': props, + }) + + elif stop_times_df is not None and stops_df is not None: + # Fallback: build shapes from stop sequences + print(" No shapes.txt -- building routes from stop_times.txt") + stops_df['stop_lat'] = pd.to_numeric(stops_df['stop_lat'], + errors='coerce') + stops_df['stop_lon'] = pd.to_numeric(stops_df['stop_lon'], + errors='coerce') + stop_coords = {} + for _, s in stops_df.iterrows(): + sid = s.get('stop_id', '') + lat = s['stop_lat'] + lon = s['stop_lon'] + if pd.notna(lat) and pd.notna(lon): + stop_coords[sid] = (lon, lat) + + if trips_df is not None: + trip_route = {} + for _, t in trips_df.iterrows(): + trip_route[t.get('trip_id', '')] = t.get('route_id', '') + + stop_times_df['stop_sequence'] = pd.to_numeric( + stop_times_df['stop_sequence'], errors='coerce') + + # Group by trip, build linestring per trip, deduplicate by route + seen_routes = set() + for trip_id, group in stop_times_df.groupby('trip_id'): + rid = trip_route.get(trip_id, '') + if rid in seen_routes or rid not in route_info: + continue + seen_routes.add(rid) + group = group.sort_values('stop_sequence') + coords = [] + for _, st in group.iterrows(): + c = stop_coords.get(st.get('stop_id', '')) + if c: + coords.append(c) + if len(coords) < 2: + continue + # Spatial filter + lons = [c[0] for c in coords] + lats = [c[1] for c in coords] + if max(lons) < west or min(lons) > east: + continue + if max(lats) < south or min(lats) > north: + continue + props = dict(route_info.get(rid, {})) + props['shape_id'] = f'trip_{trip_id}' + route_features.append({ + 'type': 'Feature', + 'geometry': {'type': 'LineString', 'coordinates': coords}, + 'properties': props, + }) + + # --- Build stop Points --- + stop_features = [] + if include_stops and stops_df is not None: + stops_df['stop_lat'] = pd.to_numeric( + stops_df.get('stop_lat', pd.Series(dtype=float)), errors='coerce') + stops_df['stop_lon'] = pd.to_numeric( + stops_df.get('stop_lon', pd.Series(dtype=float)), errors='coerce') + stops_df = stops_df.dropna(subset=['stop_lat', 'stop_lon']) + + # Determine which route types and colours serve each stop + stop_route_types = {} + stop_route_colors = {} # stop_id -> set of route_color hex strings + if stop_times_df is not None and trips_df is not None: + trip_route_info = {} # trip_id -> (route_type, route_color) + for _, t in trips_df.iterrows(): + rid = t.get('route_id', '') + info = route_info.get(rid) + if info: + trip_route_info[t.get('trip_id', '')] = ( + info['route_type'], info.get('route_color', '')) + for _, st in stop_times_df.iterrows(): + sid = st.get('stop_id', '') + tid = st.get('trip_id', '') + ri = trip_route_info.get(tid) + if ri is not None: + stop_route_types.setdefault(sid, set()).add(ri[0]) + rc = ri[1].strip().lstrip('#') + if len(rc) == 6: + stop_route_colors.setdefault(sid, set()).add(rc.upper()) + + for _, s in stops_df.iterrows(): + lat, lon = s['stop_lat'], s['stop_lon'] + if lon < west or lon > east or lat < south or lat > north: + continue + sid = s.get('stop_id', '') + rts = sorted(stop_route_types.get(sid, set())) + # Filter stops to only those serving requested route types + if route_types is not None and rts: + if not any(rt in route_types for rt in rts): + continue + rcs = sorted(stop_route_colors.get(sid, set())) + stop_features.append({ + 'type': 'Feature', + 'geometry': {'type': 'Point', 'coordinates': [lon, lat]}, + 'properties': { + 'stop_name': s.get('stop_name', ''), + 'stop_id': sid, + 'route_types': rts, + 'route_colors': rcs, + }, + }) + + # Metadata + rt_counts = {} + for f in route_features: + rt = f['properties'].get('route_type', 3) + name = _gtfs_route_type_name(rt) + rt_counts[name] = rt_counts.get(name, 0) + 1 + + metadata = { + 'n_routes': len(route_features), + 'n_stops': len(stop_features), + 'route_type_counts': rt_counts, + } + + return { + 'routes': {'type': 'FeatureCollection', 'features': route_features}, + 'stops': {'type': 'FeatureCollection', 'features': stop_features}, + 'metadata': metadata, + } + + +def fetch_gtfs(bounds, source='auto', feed_url=None, gtfs_path=None, + route_types=None, cache_path=None, crs=None, + include_stops=True, realtime_url=None): + """Download and parse GTFS transit data for a bounding box. + + Discovers a GTFS feed from the Mobility Database, or uses a + user-provided feed URL or local ZIP file. Returns route shapes + as LineStrings and stops as Points. + + Parameters + ---------- + bounds : tuple of float + (west, south, east, north) in WGS84 degrees. + source : str + ``'auto'`` (default) discovers a feed via the Mobility Database + catalog. Ignored when ``feed_url`` or ``gtfs_path`` is given. + feed_url : str, optional + Direct URL to a GTFS ZIP file. + gtfs_path : str or Path, optional + Path to a local GTFS ZIP file. + route_types : list of int, optional + Filter to specific GTFS route types (0=tram, 1=subway, 2=rail, + 3=bus, 4=ferry, etc.). ``None`` returns all. + cache_path : str or Path, optional + Path to cache the parsed result as JSON. If the file already + exists, loads and returns it directly. + crs : str, optional + Target CRS for reprojection (e.g. ``'EPSG:32618'``). + ``None`` keeps WGS84. + include_stops : bool + Whether to include stop points. Default ``True``. + realtime_url : str, optional + URL to a GTFS-Realtime VehiclePositions protobuf feed. + Stored in the returned metadata for use by ``explore()``. + + Returns + ------- + dict + ``{'routes': FeatureCollection, 'stops': FeatureCollection, + 'metadata': {...}}``. + + Examples + -------- + >>> from rtxpy import fetch_gtfs + >>> gtfs = fetch_gtfs((-74.05, 40.68, -73.90, 40.82)) + >>> dem.rtx.place_gtfs(gtfs) + >>> dem.rtx.explore() + """ + try: + import requests + except ImportError: + raise ImportError( + "requests is required for fetch_gtfs(). " + "Install with: pip install requests" + ) + + # Check cache + if cache_path is not None: + cache_path = Path(cache_path) + if cache_path.exists(): + print(f"Using cached GTFS data: {cache_path.name}") + with open(cache_path) as f: + return json.load(f) + + cache_dir = Path.home() / '.cache' / 'rtxpy' / 'gtfs' + cache_dir.mkdir(parents=True, exist_ok=True) + + # Determine ZIP source + zip_bytes = None + feed_name = None + + if gtfs_path is not None: + # Local file + gtfs_path = Path(gtfs_path) + print(f"Loading GTFS from {gtfs_path}") + feed_name = gtfs_path.stem + result = _parse_gtfs_zip(gtfs_path, bounds, + route_types=route_types, + include_stops=include_stops) + + elif feed_url is not None: + # Direct URL + print(f"Downloading GTFS feed: {feed_url}") + resp = requests.get(feed_url, timeout=120) + resp.raise_for_status() + zip_bytes = resp.content + feed_name = feed_url.split('/')[-1].replace('.zip', '') + # Cache the ZIP + zip_cache = cache_dir / f'{feed_name}.zip' + zip_cache.write_bytes(zip_bytes) + result = _parse_gtfs_zip(zip_bytes, bounds, + route_types=route_types, + include_stops=include_stops) + + else: + # Auto-discover from Mobility Database + print(f"Discovering GTFS feeds for bounds {bounds}...") + feeds = _discover_gtfs_feeds(bounds, cache_dir) + if not feeds: + print(" No GTFS feeds found for this area.") + return { + 'routes': {'type': 'FeatureCollection', 'features': []}, + 'stops': {'type': 'FeatureCollection', 'features': []}, + 'metadata': {'n_routes': 0, 'n_stops': 0, + 'route_type_counts': {}}, + } + + best = feeds[0] + feed_name = best['provider'] + print(f" Best match: {best['provider']} (id={best['feed_id']})") + print(f" Downloading {best['feed_url']}...") + + resp = requests.get(best['feed_url'], timeout=120) + resp.raise_for_status() + zip_bytes = resp.content + + # Cache the ZIP + safe_name = re.sub(r'[^\w\-.]', '_', best['provider']) + zip_cache = cache_dir / f'{safe_name}_{best["feed_id"]}.zip' + zip_cache.write_bytes(zip_bytes) + print(f" ZIP cached to {zip_cache}") + + result = _parse_gtfs_zip(zip_bytes, bounds, + route_types=route_types, + include_stops=include_stops) + + # Add feed name and realtime URL to metadata + result['metadata']['feed_name'] = feed_name or '' + if realtime_url: + result['metadata']['realtime_url'] = realtime_url + + meta = result['metadata'] + print(f" {meta['n_routes']} route shapes, {meta['n_stops']} stops") + if meta['route_type_counts']: + parts = [f"{v} {k}" for k, v in sorted(meta['route_type_counts'].items())] + print(f" Types: {', '.join(parts)}") + + # Reproject if requested + if crs is not None: + try: + import geopandas as gpd + except ImportError: + raise ImportError( + "geopandas is required for CRS reprojection. " + "Install with: pip install geopandas" + ) + for key in ('routes', 'stops'): + fc = result[key] + if fc['features']: + gdf = gpd.GeoDataFrame.from_features(fc['features'], + crs="EPSG:4326") + gdf = gdf.to_crs(crs) + result[key] = json.loads(gdf.to_json()) + print(f" Reprojected to {crs}") + + # Cache result + if cache_path is not None: + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, 'w') as f: + json.dump(result, f) + print(f" Cached to {cache_path}") + + return result From 70c4acd5094c91bf5b852e7f635d78abd6a500ed Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Wed, 18 Feb 2026 17:17:51 -0800 Subject: [PATCH 4/5] way too big of a commit --- .github/workflows/gpu-test.yml | 10 +- INSTALL.md | 202 ++++ README.md | 228 ++--- conda-recipe/README.md | 2 +- conda-recipe/bld.bat | 104 +- conda-recipe/build.sh | 12 +- conda-recipe/conda_build_config.yaml | 5 + conda-recipe/meta.yaml | 28 +- cuda/common.h | 11 + cuda/kernel.cu | 283 +++++- examples/capetown.py | 154 +-- examples/generate_playground_gif.py | 3 - examples/guanajuato.py | 162 +-- examples/los_angeles.py | 162 +-- examples/playground.py | 102 +- examples/rio.py | 138 +-- examples/trinidad.py | 157 +-- pyproject.toml | 8 +- rtxpy/accessor.py | 8 +- rtxpy/analysis/_common.py | 94 +- rtxpy/analysis/hillshade.py | 15 +- rtxpy/analysis/render.py | 1381 +++++++++++++++++++++++--- rtxpy/analysis/slope_aspect.py | 24 +- rtxpy/analysis/viewshed.py | 9 +- rtxpy/engine.py | 261 ++++- rtxpy/kernel.ptx | 715 ++++++++++++- rtxpy/mesh_store.py | 314 ++++++ rtxpy/notebook.py | 407 ++++++++ rtxpy/quickstart.py | 88 +- rtxpy/rtx.py | 1270 ++++++++++++++++++++++- rtxpy/tiles.py | 3 +- rtxpy/tour.py | 403 ++++++++ 32 files changed, 5460 insertions(+), 1303 deletions(-) create mode 100644 INSTALL.md create mode 100644 rtxpy/mesh_store.py create mode 100644 rtxpy/notebook.py create mode 100644 rtxpy/tour.py diff --git a/.github/workflows/gpu-test.yml b/.github/workflows/gpu-test.yml index 37aeb53..55c1751 100644 --- a/.github/workflows/gpu-test.yml +++ b/.github/workflows/gpu-test.yml @@ -124,17 +124,11 @@ jobs: echo "=== PTX compiled successfully ===" head -15 rtxpy/kernel.ptx - - name: Install otk-pyoptix from source + - name: Install pyoptix-contrib shell: bash -el {0} run: | echo "Using OptiX from: ${OptiX_INSTALL_DIR}" - - # Clone and install otk-pyoptix - git clone --depth 1 https://github.com/NVIDIA/otk-pyoptix.git /tmp/otk-pyoptix - cd /tmp/otk-pyoptix/optix - - # Install with OptiX path set - pip install . + pip install pyoptix-contrib - name: Install rtxpy with CUDA dependencies shell: bash -el {0} diff --git a/INSTALL.md b/INSTALL.md new file mode 100644 index 0000000..84d97cf --- /dev/null +++ b/INSTALL.md @@ -0,0 +1,202 @@ +# Installation Guide + +RTXpy requires an NVIDIA GPU and CUDA drivers. There are three ways to install it, listed from easiest to most flexible. + +## Prerequisites + +- **NVIDIA GPU**: Maxwell architecture or newer (GTX 900+ / RTX series) +- **NVIDIA driver**: 455.28+ (Linux) or 456.71+ (Windows) +- **CUDA**: 12.x or newer +- **OS**: Linux (x86_64) or Windows 10/11 + +Verify your GPU and driver: + +```bash +nvidia-smi +``` + +## Option 1: Conda (recommended) + +The makepath Anaconda channel provides pre-built packages for Python 3.10–3.13 +with the OptiX bindings bundled — no SDK download required. + +### Core install + +```bash +conda create -n rtxpy python=3.12 -y +conda activate rtxpy +conda install -c makepath -c conda-forge rtxpy +``` + +This installs rtxpy with its core dependencies (numpy, numba, cupy, zarr, +pyoptix-contrib). Verify it works: + +```python +from rtxpy import RTX +r = RTX() # should succeed without errors +``` + +### Full install (examples + interactive viewer) + +To run the examples (e.g. `examples/trinidad.py`, `examples/playground.py`) +you need the analysis, viewer, and data-fetching dependencies: + +```bash +conda install -c conda-forge \ + xarray rioxarray xarray-spatial \ + pyproj pillow pyglfw moderngl scipy \ + "duckdb<1.4" requests matplotlib +``` + +> **Package name gotchas on conda-forge:** +> - Python GLFW bindings are `pyglfw`, not `glfw` (that's the C library only) +> - xrspatial is `xarray-spatial`, not `xrspatial` +> - Use `duckdb<1.4` — versions 1.4+ have a regression with Overture Maps queries + +### Windows + +Windows builds use the system CUDA Toolkit rather than conda CUDA packages: + +```bash +conda create -n rtxpy python=3.12 -y +conda activate rtxpy +conda install -c makepath -c conda-forge rtxpy +conda install -c conda-forge xarray rioxarray xarray-spatial ^ + pyproj pillow pyglfw moderngl scipy "duckdb<1.4" requests matplotlib +``` + +Ensure the CUDA Toolkit 12.x is installed and `nvcc` is on your PATH. + +## Option 2: Pip + Conda hybrid (development) + +Use this when you want an editable install from the repo with the latest code. +Conda provides GPU packages (cupy, numba) that pip can't build, while pip +handles everything else. + +### Step 1: GPU foundation via conda + +```bash +conda create -n rtxpy-dev python=3.12 -y +conda activate rtxpy-dev +conda install -c conda-forge cupy numba zarr +``` + +### Step 2: OptiX Python bindings + +Install the OptiX SDK headers (needed for the build) and the Python bindings: + +```bash +# Get the OptiX SDK headers (no NVIDIA account required) +git clone --depth 1 https://github.com/NVIDIA/optix-dev.git /tmp/optix-dev + +# Build and install the Python bindings from PyPI +CMAKE_PREFIX_PATH=/tmp/optix-dev \ + pip install pyoptix-contrib +``` + +> **Warning:** There is a *different* `optix` package on PyPI (optical system +> design library). Do **not** run `pip install optix` — it is unrelated to +> NVIDIA OptiX. The correct package is `pyoptix-contrib`. + +### Step 3: Install rtxpy + +```bash +# From the repo root (editable install with all extras) +pip install -e ".[all]" + +# Additional deps for the example scripts +pip install rioxarray xarray-spatial "duckdb<1.4" matplotlib requests +``` + +Verify: + +```bash +python -c "from rtxpy import RTX; RTX(); print('OK')" +``` + +## Option 3: Build the conda package locally + +Build from the conda recipe in the repo. Useful for creating packages for +internal distribution or testing recipe changes. + +```bash +# Install conda-build if needed +conda install -n base conda-build + +# Build for a specific Python version +conda-build conda-recipe --python 3.12 -c conda-forge --no-test + +# Install the locally built package +conda create -n rtxpy-local python=3.12 -y +conda activate rtxpy-local +conda install -c local -c conda-forge rtxpy +``` + +The build script automatically clones the OptiX SDK headers and installs +pyoptix-contrib during the build process. + +## Running the examples + +Once installed, try the interactive examples: + +```bash +# Crater Lake (smaller, good for first run) +python examples/playground.py + +# Trinidad & Tobago coastal resilience analysis +python examples/trinidad.py + +# Los Angeles +python examples/los_angeles.py +``` + +Examples download terrain and vector data on first run (cached for subsequent +runs). Press `H` in the viewer for keyboard controls. + +## Troubleshooting + +### `ModuleNotFoundError: No module named 'optix'` + +The NVIDIA OptiX Python bindings are missing. If using conda from makepath, +they should be bundled. If using pip, see [Step 2](#step-2-optix-python-bindings) +above. + +### `Could NOT find OptiX (missing: OptiX_ROOT_DIR)` + +When building pyoptix-contrib from source, set the path to the OptiX SDK headers: + +```bash +CMAKE_PREFIX_PATH=/path/to/optix-dev pip install ... +``` + +### `No module named 'glfw'` (after conda install glfw) + +The conda-forge `glfw` package is the C library. Install the Python bindings: + +```bash +conda install -c conda-forge pyglfw +``` + +### DuckDB errors when fetching Overture Maps data + +DuckDB 1.4.x has a regression with S3/httpfs. Pin to an older version: + +```bash +conda install -c conda-forge "duckdb<1.4" +# or +pip install "duckdb<1.4" +``` + +### `ImportError` for rtxpy picks up local directory + +When running from the repo root, Python may import the local `rtxpy/` directory +instead of the installed package. Either `cd` to a different directory or use +an editable install (`pip install -e .`). + +### cupy / numba installation via pip fails + +These packages have complex native dependencies. Install them via conda: + +```bash +conda install -c conda-forge cupy numba +``` diff --git a/README.md b/README.md index 7f55ef8..a8750fe 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # RTXpy -Ray tracing using CUDA, accessible from Python. +GPU-accelerated terrain analysis for the xarray ecosystem. Compute hillshade, viewshed, slope — get DataArrays back. Build a Dataset, then explore it interactively in 3D. Built-in data fetching (DEM, buildings, roads, water, fire, wind) makes it easy to go from a bounding box to a full scene. ![Crater Lake Viewshed Demo](examples/images/playground_demo.gif) @@ -8,189 +8,163 @@ Ray tracing using CUDA, accessible from Python. ## Quick Start -```python -import rtxpy # registers the .rtx xarray accessor -import rioxarray - -# Load a GeoTIFF DEM as an xarray DataArray -dem = rioxarray.open_rasterio('elevation.tif').squeeze() +Fetch terrain, analyze, and explore — all from a bounding box: -# Move data to GPU +```python +from rtxpy import fetch_dem +import rtxpy + +# Download 30m terrain (cached after first run) +dem = fetch_dem( + bounds=(-122.3, 42.8, -121.9, 43.0), + output_path='crater_lake.zarr', + source='copernicus', +) dem = dem.rtx.to_cupy() -# Compute hillshade with ray-traced shadows +# Analysis results are standard xarray DataArrays hillshade = dem.rtx.hillshade(shadows=True) - -# Compute viewshed from an observer location (pixel coordinates) viewshed = dem.rtx.viewshed(x=500, y=300, observer_elev=2) -# Launch interactive 3D terrain explorer +# Interactive 3D terrain exploration dem.rtx.explore() ``` -## Prerequisites - -- NVIDIA GPU with RTX support (Maxwell architecture or newer) -- NVIDIA driver version: - - 456.71 or newer for Windows - - 455.28 or newer for Linux -- OptiX SDK 7.6+ (set `OptiX_INSTALL_DIR` environment variable) -- CUDA 12.x+ - -## Installation - -I included some extra deps. here like rioxarray so the examples can load geotiffs +Build a Dataset with multiple layers, then explore them together: -**Note:** The conda-forge version is currently outdated so please use makepath conda channel...will fix soon. +```python +import xarray as xr +from xrspatial import slope, aspect +from rtxpy import fetch_dem, fetch_buildings, fetch_roads -### Linux w/ conda -```bash -conda install -c conda-forge cupy rioxarray matplotlib requests jupyter makepath::rtxpy -``` +bounds = (-122.3, 42.8, -121.9, 43.0) +dem = fetch_dem(bounds, 'terrain.zarr', source='srtm', crs='EPSG:5070') +dem = dem.rtx.to_cupy() -### Windows w/ conda -```bash -conda install -c conda-forge cupy rioxarray matplotlib requests jupyter nvidia::cudatoolkit makepath::rtxpy -``` +ds = xr.Dataset({ + 'elevation': dem, + 'slope': slope(dem), + 'aspect': aspect(dem), +}) -## Build from Source +# Fetch and place vector features +roads = fetch_roads(bounds, crs='EPSG:5070') +ds.rtx.place_roads(roads, z='elevation') -First, install the OptiX Python bindings (otk-pyoptix): +bldgs = fetch_buildings(bounds, source='overture', crs='EPSG:5070') +ds.rtx.place_buildings(bldgs, z='elevation') -```bash -export OptiX_INSTALL_DIR=/path/to/OptiX-SDK -pip install otk-pyoptix +# G cycles layers, N toggles geometry, U drapes satellite tiles +ds.rtx.explore(z='elevation', mesh_type='voxel') ``` -Then install rtxpy: - -```bash -pip install rtxpy -``` +## Prerequisites -## Installation from source +- **NVIDIA GPU**: Maxwell architecture or newer (GTX 900+ / RTX series) +- **NVIDIA driver**: 455.28+ (Linux) or 456.71+ (Windows) +- **CUDA**: 12.x or newer -To install RTXpy from source: +See [INSTALL.md](INSTALL.md) for detailed instructions and troubleshooting. -```bash -export OptiX_INSTALL_DIR=/path/to/OptiX-SDK -pip install otk-pyoptix -pip install -ve . -``` +## Installation -To run tests: +### Conda (recommended) ```bash -pip install -ve .[tests] -pytest -v rtxpy/tests +conda create -n rtxpy python=3.12 -y +conda activate rtxpy +conda install -c makepath -c conda-forge rtxpy + +# Additional deps for examples and interactive viewer +conda install -c conda-forge \ + xarray rioxarray xarray-spatial \ + pyproj pillow pyglfw moderngl scipy \ + "duckdb<1.4" requests matplotlib ``` -## Building kernel.ptx from source - -If you need to rebuild the PTX kernel (e.g., for a different GPU architecture or OptiX version): +### Pip + Conda hybrid (from source) ```bash -# Detect your GPU's compute capability (e.g., 75 for Turing, 86 for Ampere) -GPU_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | tr -d '.') - -# Compile for your GPU architecture -nvcc -ptx -o rtxpy/kernel.ptx cuda/kernel.cu \ - -arch=sm_${GPU_ARCH} \ - -I/path/to/OptiX-SDK/include \ - -I cuda \ - --use_fast_math +# GPU foundation via conda +conda create -n rtxpy-dev python=3.12 -y +conda activate rtxpy-dev +conda install -c conda-forge cupy numba zarr + +# OptiX SDK headers (needed for pyoptix-contrib build and PTX compilation) +git clone --depth 1 https://github.com/NVIDIA/optix-dev.git /tmp/optix-dev +CMAKE_PREFIX_PATH=/tmp/optix-dev \ + pip install pyoptix-contrib + +# Install rtxpy (editable) +pip install -e ".[all]" ``` -The CUDA source files are in the `cuda/` directory. - -## Building with Conda - -The easiest way to build rtxpy with all dependencies is using the included conda recipe: +### Development ```bash -# Install conda-build if not already installed -conda install conda-build - -# Build the package (auto-detects GPU architecture) -conda build conda-recipe - -# Or specify GPU architecture explicitly -GPU_ARCH=86 conda build conda-recipe # For RTX 30xx/A100 - -# Install the built package -conda install --use-local rtxpy +# After completing the pip + conda hybrid setup above: +pip install -e ".[tests]" +pytest -v rtxpy/tests ``` -The conda build automatically: -1. Clones OptiX SDK headers from NVIDIA/optix-dev -2. Detects your GPU architecture (or uses `GPU_ARCH` env var) -3. Compiles the PTX kernel for your GPU -4. Builds and installs otk-pyoptix -5. Installs rtxpy +## Features -You can also specify the OptiX version: -```bash -OPTIX_VERSION=7.7.0 conda build conda-recipe # Requires driver 530.41+ -OPTIX_VERSION=8.0.0 conda build conda-recipe # Requires driver 535+ -``` - -See `conda-recipe/README.md` for detailed documentation, GPU architecture reference, and troubleshooting. +- **Analysis arrays** — `hillshade()`, `slope()`, `aspect()`, `viewshed()` return xarray DataArrays that fit into your existing Dataset +- **Data fetching** — `fetch_dem()`, `fetch_buildings()`, `fetch_roads()`, `fetch_water()`, `fetch_wind()`, `fetch_firms()` — go from a bounding box to real data with automatic caching +- **3D feature placement** — extrude buildings, drape roads, scatter custom meshes on terrain +- **Interactive viewer** — `explore()` renders your Dataset in 3D with keyboard/mouse controls, satellite tiles, wind particles, and real-time viewshed +- **Rendering** — perspective camera with shadows, fog, AO, depth-of-field, colormaps for static images and GIF animations +- **Mesh I/O** — load GLB/OBJ/STL, save/load zarr scenes, export STL -## WSL2 Support +## Documentation -To get OptiX working on WSL2, follow the instructions from the NVIDIA forums: -https://forums.developer.nvidia.com/t/problem-running-optix-7-6-in-wsl/239355/8 +- **[Getting Started](docs/getting-started.md)** — installation, prerequisites, first example, how the accessor works +- **[User Guide](docs/user-guide.md)** — task-oriented workflows for analysis, placement, rendering, and the interactive viewer +- **[API Reference](docs/api-reference.md)** — complete method signatures, parameters, and return values +- **[Examples](docs/examples.md)** — annotated walkthrough, quick recipes, and example scripts -Summary: -1. Install WSL 2 and enable CUDA -2. Download and extract the Linux display driver (e.g., `NVIDIA-Linux-x86_64-590.44.01.run`) -3. Extract with `./NVIDIA-Linux-x86_64-XXX.XX.run -x` -4. Copy the following files to `C:/Windows/System32/lxss/lib`: - - `libnvoptix.so.XXX.00` (rename to `libnvoptix.so.1`) - - `libnvidia-rtcore.so.XXX.00` (keep original name) - - `libnvidia-ptxjitcompiler.so.XXX.00` (rename to `libnvidia-ptxjitcompiler.so.1`) -5. Add `/usr/lib/wsl/lib` to your `LD_LIBRARY_PATH` -6. Reset WSL cache with `wsl --shutdown` from PowerShell +## Low-Level API -## Usage +For custom ray tracing without the xarray accessor: ```python import numpy as np from rtxpy import RTX -# Create RTX instance rtx = RTX() -# Define geometry (vertices and triangle indices) verts = np.float32([0,0,0, 1,0,0, 0,1,0, 1,1,0]) triangles = np.int32([0,1,2, 2,1,3]) - -# Build acceleration structure rtx.build(0, verts, triangles) -# Define rays: [ox, oy, oz, tmin, dx, dy, dz, tmax] rays = np.float32([0.33, 0.33, 100, 0, 0, 0, -1, 1000]) hits = np.float32([0, 0, 0, 0]) - -# Trace rays rtx.trace(rays, hits, 1) -# hits contains: [t, nx, ny, nz] -# t = distance to hit point (-1 if miss) -# nx, ny, nz = surface normal at hit point print(hits) # [100.0, 0.0, 0.0, 1.0] ``` -For GPU-resident data, use CuPy arrays for better performance: +## Building the PTX Kernel -```python -import cupy +```bash +GPU_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | tr -d '.') +nvcc -ptx -o rtxpy/kernel.ptx cuda/kernel.cu \ + -arch=sm_${GPU_ARCH} \ + -I/path/to/OptiX-SDK/include \ + -I cuda \ + --use_fast_math +``` -verts = cupy.float32([0,0,0, 1,0,0, 0,1,0, 1,1,0]) -triangles = cupy.int32([0,1,2, 2,1,3]) -rays = cupy.float32([0.33, 0.33, 100, 0, 0, 0, -1, 1000]) -hits = cupy.float32([0, 0, 0, 0]) +## Building with Conda -rtx.build(0, verts, triangles) -rtx.trace(rays, hits, 1) +```bash +conda install conda-build +conda build conda-recipe +conda install --use-local rtxpy ``` + +Auto-detects GPU architecture, downloads OptiX headers, compiles PTX, and installs everything. Override with `GPU_ARCH=86` or `OPTIX_VERSION=8.0.0`. See `conda-recipe/README.md` for details. + +## WSL2 Support + +See [Getting Started — WSL2 Setup](docs/getting-started.md#wsl2-setup) for instructions on getting OptiX working on WSL2. diff --git a/conda-recipe/README.md b/conda-recipe/README.md index 2858b34..b6f7a57 100644 --- a/conda-recipe/README.md +++ b/conda-recipe/README.md @@ -71,7 +71,7 @@ conda install --use-local rtxpy 1. **Clones OptiX SDK headers** from NVIDIA/optix-dev (v7.7.0 by default) 2. **Detects GPU architecture** or uses the specified `GPU_ARCH` 3. **Compiles kernel.cu to PTX** for the target architecture -4. **Installs otk-pyoptix** from NVIDIA's repository +4. **Installs pyoptix-contrib** (OptiX Python bindings) 5. **Installs rtxpy** with the compiled PTX kernel ## GPU Architecture Reference diff --git a/conda-recipe/bld.bat b/conda-recipe/bld.bat index aeaf523..3f23139 100644 --- a/conda-recipe/bld.bat +++ b/conda-recipe/bld.bat @@ -76,111 +76,17 @@ echo PTX compiled successfully to rtxpy\kernel.ptx echo. :: --------------------------------------------------------------------------- -:: Step 3: Install otk-pyoptix from source +:: Step 3: Install pyoptix-contrib :: --------------------------------------------------------------------------- -echo [3/4] Installing otk-pyoptix... +echo [3/4] Installing pyoptix-contrib... echo ---------------------------------------- -set OTK_PYOPTIX_DIR=%SRC_DIR%\otk-pyoptix - -echo Cloning otk-pyoptix repository... -git clone --depth 1 https://github.com/NVIDIA/otk-pyoptix.git "%OTK_PYOPTIX_DIR%" -if errorlevel 1 ( - echo ERROR: Failed to clone otk-pyoptix - exit /b 1 -) - -:: Verify cmake is available (installed via conda) -where cmake >nul 2>&1 -if errorlevel 1 ( - echo ERROR: cmake not found. Ensure cmake is in build requirements. - exit /b 1 -) -echo Found cmake at: -where cmake - -:: Verify C++ compiler is available (conda-build should set up VS environment) -where cl >nul 2>&1 -if errorlevel 1 ( - echo. - echo ERROR: C++ compiler ^(cl.exe^) not found. - echo. - echo Please ensure Visual Studio Build Tools are installed and activated. - echo You can install them from: https://visualstudio.microsoft.com/visual-cpp-build-tools/ - echo. - echo If already installed, run this build from a "Developer Command Prompt" - echo or run vcvars64.bat before building. - echo. - exit /b 1 -) -echo Found C++ compiler at: -where cl - -:: Pre-clone pybind11 without submodules to avoid FetchContent submodule update failures -echo Pre-cloning pybind11 to avoid submodule issues... -set "PYBIND11_DIR=%SRC_DIR%\pybind11-src" -git clone --depth 1 --branch v2.13.6 https://github.com/pybind/pybind11.git "%PYBIND11_DIR%" -if errorlevel 1 ( - echo ERROR: Failed to clone pybind11 - exit /b 1 -) - -:: Tell CMake to use our pre-cloned pybind11 instead of fetching -set "FETCHCONTENT_SOURCE_DIR_PYBIND11=%PYBIND11_DIR%" -echo Using pre-cloned pybind11 at %PYBIND11_DIR% - -pushd "%OTK_PYOPTIX_DIR%\optix" - -:: Patch CMakeLists.txt to use our pre-cloned pybind11 and skip submodule updates -echo Patching CMakeLists.txt to use local pybind11... - -:: Convert backslashes to forward slashes for CMake -set "PYBIND11_DIR_CMAKE=%PYBIND11_DIR:\=/%" - -:: Prepend the FETCHCONTENT_SOURCE_DIR_PYBIND11 setting to CMakeLists.txt -( - echo set^(FETCHCONTENT_SOURCE_DIR_PYBIND11 "!PYBIND11_DIR_CMAKE!" CACHE PATH "pybind11 source" FORCE^) - type CMakeLists.txt -) > "%SRC_DIR%\CMakeLists_new.txt" -move /y "%SRC_DIR%\CMakeLists_new.txt" CMakeLists.txt >nul - -echo Patched CMakeLists.txt - first 2 lines: -powershell -Command "Get-Content CMakeLists.txt -Head 2" - -:: Set OptiX path for cmake/pip build process (exactly like run_gpu_test.bat) -set "OPTIX_PATH=%OptiX_INSTALL_DIR%" -set "CMAKE_PREFIX_PATH=%OptiX_INSTALL_DIR%;%CMAKE_PREFIX_PATH%" - -:: Clear conda-build injected CMAKE variables that break the build -set CMAKE_GENERATOR= -set CMAKE_GENERATOR_PLATFORM= -set CMAKE_GENERATOR_TOOLSET= - -:: Pre-install build dependencies so we can use --no-build-isolation -echo Installing build dependencies... -"%PYTHON%" -m pip install setuptools wheel - -echo Building with OptiX_INSTALL_DIR=%OptiX_INSTALL_DIR% -echo FETCHCONTENT_SOURCE_DIR_PYBIND11=!FETCHCONTENT_SOURCE_DIR_PYBIND11! - -:: Pass pybind11 source dir to CMake via CMAKE_ARGS (used by scikit-build and setuptools) -set "CMAKE_ARGS=-DFETCHCONTENT_SOURCE_DIR_PYBIND11=!PYBIND11_DIR!" - -:: Use --no-build-isolation so environment variables are visible to CMake -"%PYTHON%" -m pip install . -v --no-build-isolation +"%PYTHON%" -m pip install pyoptix-contrib --no-build-isolation if errorlevel 1 ( - echo. - echo ERROR: Failed to install otk-pyoptix - echo. - echo If the error mentions OptiX not found, try setting manually: - echo set OptiX_INSTALL_DIR=%OptiX_INSTALL_DIR% - echo set OPTIX_PATH=%OptiX_INSTALL_DIR% - echo. - popd + echo ERROR: Failed to install pyoptix-contrib exit /b 1 ) -popd -echo otk-pyoptix installed successfully +echo pyoptix-contrib installed successfully echo. :: --------------------------------------------------------------------------- diff --git a/conda-recipe/build.sh b/conda-recipe/build.sh index 4448fd8..14875fb 100755 --- a/conda-recipe/build.sh +++ b/conda-recipe/build.sh @@ -52,16 +52,10 @@ echo "PTX compiled successfully:" head -15 "${SRC_DIR}/rtxpy/kernel.ptx" # --------------------------------------------------------------------------- -# Step 3: Install otk-pyoptix from source +# Step 3: Install pyoptix-contrib # --------------------------------------------------------------------------- -echo "=== Installing otk-pyoptix ===" -OTK_PYOPTIX_DIR="${SRC_DIR}/otk-pyoptix" - -git clone --depth 1 https://github.com/NVIDIA/otk-pyoptix.git "${OTK_PYOPTIX_DIR}" -cd "${OTK_PYOPTIX_DIR}/optix" - -# Install otk-pyoptix -${PYTHON} -m pip install . --no-deps --no-build-isolation -vv +echo "=== Installing pyoptix-contrib ===" +${PYTHON} -m pip install pyoptix-contrib --no-deps --no-build-isolation -vv # --------------------------------------------------------------------------- # Step 4: Install rtxpy diff --git a/conda-recipe/conda_build_config.yaml b/conda-recipe/conda_build_config.yaml index 71c5741..1b12c29 100644 --- a/conda-recipe/conda_build_config.yaml +++ b/conda-recipe/conda_build_config.yaml @@ -9,6 +9,11 @@ python: - "3.12" - "3.13" +# CUDA major versions to build variants for +cuda_version: + - "12" + - "13" + # NumPy version (needed to avoid conda-build warning) numpy: - "1.26" diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index 753fc98..11cdf40 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -1,5 +1,5 @@ {% set name = "rtxpy" %} -{% set version = "0.0.5" %} +{% set version = "0.0.6" %} package: name: {{ name|lower }} @@ -22,9 +22,9 @@ requirements: - {{ compiler('cxx') }} # [linux] - conda-forge::cmake - conda-forge::git # [linux] - - cuda-nvcc >=12 # [linux] - - cuda-cudart-dev >=12 # [linux] - - cuda-nvrtc-dev >=12 # [linux] + - cuda-nvcc {{ cuda_version }}.* # [linux] + - cuda-cudart-dev {{ cuda_version }}.* # [linux] + - cuda-nvrtc-dev {{ cuda_version }}.* # [linux] # Windows uses system CUDA Toolkit and Visual Studio Build Tools host: @@ -35,9 +35,9 @@ requirements: - conda-forge::cmake - conda-forge::git # [win] # Linux CUDA packages - - cuda-version >=12 # [linux] - - cuda-cudart-dev >=12 # [linux] - - cuda-nvrtc-dev >=12 # [linux] + - cuda-version {{ cuda_version }}.* # [linux] + - cuda-cudart-dev {{ cuda_version }}.* # [linux] + - cuda-nvrtc-dev {{ cuda_version }}.* # [linux] run: - python >=3.10 @@ -45,6 +45,18 @@ requirements: - numpy >=2.0,<3 # [py>=313] - numba >=0.56 - cupy >=12.0 + - zarr >=2.0 + - xarray + - rioxarray + - xarray-spatial + - pyproj + - pillow + - pyglfw + - moderngl + - scipy + - duckdb <1.4 + - requests + - matplotlib - cuda-version >=12 # [linux] - __cuda # [linux] @@ -64,7 +76,7 @@ about: summary: Ray tracing using CUDA accessible from Python description: | RTXpy provides GPU-accelerated ray-triangle intersection using - NVIDIA's OptiX ray tracing engine via the otk-pyoptix Python bindings. + NVIDIA's OptiX ray tracing engine via the pyoptix-contrib Python bindings. dev_url: https://github.com/makepath/rtxpy extra: diff --git a/cuda/common.h b/cuda/common.h index ea02225..348d57c 100644 --- a/cuda/common.h +++ b/cuda/common.h @@ -55,4 +55,15 @@ struct Params Hit* hits; int* primitive_ids; // Optional: triangle index per ray (-1 for miss) int* instance_ids; // Optional: geometry/instance index per ray (-1 for miss) + unsigned int ray_flags; // OptixRayFlags (e.g. CULL_BACK_FACING, TERMINATE_ON_FIRST_HIT) + // --- heightfield fields (offset 48) --- + float* heightfield_data; // device pointer to H×W float32 elevation array + int hf_width; // W (columns) + int hf_height; // H (rows) + float hf_spacing_x; // world-space pixel spacing X + float hf_spacing_y; // world-space pixel spacing Y + float hf_ve; // vertical exaggeration + int hf_tile_size; // tile dimension (e.g. 32) + int hf_num_tiles_x; // number of tiles in X direction + int _pad0; // padding to 88 bytes }; diff --git a/cuda/kernel.cu b/cuda/kernel.cu index daa5555..fbd6a83 100644 --- a/cuda/kernel.cu +++ b/cuda/kernel.cu @@ -26,6 +26,7 @@ extern "C" __global__ void __raygen__main() unsigned int t, nx, ny, nz, prim_id, inst_id; Ray ray = params.rays[linear_idx]; optixTrace( + OPTIX_PAYLOAD_TYPE_ID_0, params.handle, ray.origin, ray.dir, @@ -33,7 +34,7 @@ extern "C" __global__ void __raygen__main() ray.tmax, 0.0f, OptixVisibilityMask( 1 ), - OPTIX_RAY_FLAG_CULL_BACK_FACING_TRIANGLES, + params.ray_flags, RAY_TYPE_RADIANCE, RAY_TYPE_COUNT, RAY_TYPE_RADIANCE, @@ -94,17 +95,27 @@ __device__ float3 cross(const float3& a, const float3& b) { extern "C" __global__ void __closesthit__chit() { const unsigned int t = optixGetRayTmax(); - - OptixTraversableHandle gas = optixGetGASTraversableHandle(); unsigned int primIdx = optixGetPrimitiveIndex(); - unsigned int sbtIdx = optixGetSbtGASIndex(); - float time = optixGetRayTime(); - float3 data[3]; - optixGetTriangleVertexData(gas, primIdx, sbtIdx, time, data); - float3 AB = data[1] - data[0]; - float3 AC = data[2] - data[0]; - float3 n = normalize(cross(AB, AC)); + float3 n; + if (optixIsTriangleHit()) { + float3 data[3]; +#if OPTIX_VERSION >= 90100 + // Parameterless overload: works for both regular and cluster GAS + optixGetTriangleVertexData(data); +#else + OptixTraversableHandle gas = optixGetGASTraversableHandle(); + unsigned int sbtIdx = optixGetSbtGASIndex(); + float time = optixGetRayTime(); + optixGetTriangleVertexData(gas, primIdx, sbtIdx, time, data); +#endif + float3 AB = data[1] - data[0]; + float3 AC = data[2] - data[0]; + n = normalize(cross(AB, AC)); + } else { + // Round curve tube: use face-up normal for terrain roads/rivers + n = make_float3(0.0f, 0.0f, 1.0f); + } // Set the hit data optixSetPayload_0(float_as_int(t)); @@ -114,3 +125,255 @@ extern "C" __global__ void __closesthit__chit() optixSetPayload_4(primIdx); // primitive/triangle index optixSetPayload_5(optixGetInstanceId()); // instance/geometry index } + + +// --------------------------------------------------------------------------- +// Heightfield custom intersection program +// --------------------------------------------------------------------------- + +// Helper: Moller-Trumbore ray-triangle intersection (single-sided) +__device__ bool ray_triangle( + const float3& orig, const float3& dir, + const float3& v0, const float3& v1, const float3& v2, + float& t_out) +{ + const float3 e1 = v1 - v0; + const float3 e2 = v2 - v0; + const float3 pvec = cross(dir, e2); + const float det = dot(e1, pvec); + if (det < 1e-8f) return false; // back-face or parallel + const float inv_det = 1.0f / det; + const float3 tvec = orig - v0; + const float u = dot(tvec, pvec) * inv_det; + if (u < 0.0f || u > 1.0f) return false; + const float3 qvec = cross(tvec, e1); + const float v = dot(dir, qvec) * inv_det; + if (v < 0.0f || u + v > 1.0f) return false; + t_out = dot(e2, qvec) * inv_det; + return t_out > 0.0f; +} + +// Bit-level NaN check — survives --use_fast_math (isnan gets optimized away) +__device__ __forceinline__ bool is_nan_safe(float f) +{ + unsigned int bits = __float_as_uint(f); + return ((bits & 0x7F800000u) == 0x7F800000u) && ((bits & 0x007FFFFFu) != 0u); +} + +// Fetch heightfield elevation with VE applied, NaN → 0. +// NaN cells become flat z=0 terrain; the render kernel detects them by +// checking the original elevation array and applies an ocean water shader. +__device__ __forceinline__ float hf_z(int r, int c) +{ + float z = params.heightfield_data[r * params.hf_width + c]; + if (is_nan_safe(z)) z = 0.0f; + return z * params.hf_ve; +} + +extern "C" __global__ void __intersection__heightfield() +{ + const int prim_idx = optixGetPrimitiveIndex(); + const float3 ray_o = optixGetObjectRayOrigin(); + const float3 ray_d = optixGetObjectRayDirection(); + const float ray_tmin = optixGetRayTmin(); + const float ray_tmax = optixGetRayTmax(); + + const float sx = params.hf_spacing_x; + const float sy = params.hf_spacing_y; + const int tile_size = params.hf_tile_size; + const int W = params.hf_width; + const int H = params.hf_height; + + // Tile grid coordinates from primitive index + const int tile_col = prim_idx % params.hf_num_tiles_x; + const int tile_row = prim_idx / params.hf_num_tiles_x; + + // Cell range for this tile (in grid cells, clamped to DEM extent) + const int cell_col0 = tile_col * tile_size; + const int cell_row0 = tile_row * tile_size; + const int cell_col1 = min(cell_col0 + tile_size, W - 1); + const int cell_row1 = min(cell_row0 + tile_size, H - 1); + + if (cell_col0 >= cell_col1 || cell_row0 >= cell_row1) + return; + + // Tile AABB in world space + const float tile_x0 = cell_col0 * sx; + const float tile_y0 = cell_row0 * sy; + const float tile_x1 = cell_col1 * sx; + const float tile_y1 = cell_row1 * sy; + + // Compute ray entry into tile AABB (XY only, Z handled per cell) + float t_enter = ray_tmin; + float t_exit = ray_tmax; + + // Clamp ray to tile XY bounds + if (fabsf(ray_d.x) > 1e-8f) { + float t0 = (tile_x0 - ray_o.x) / ray_d.x; + float t1 = (tile_x1 - ray_o.x) / ray_d.x; + if (t0 > t1) { float tmp = t0; t0 = t1; t1 = tmp; } + t_enter = fmaxf(t_enter, t0); + t_exit = fminf(t_exit, t1); + } else { + if (ray_o.x < tile_x0 || ray_o.x > tile_x1) return; + } + if (fabsf(ray_d.y) > 1e-8f) { + float t0 = (tile_y0 - ray_o.y) / ray_d.y; + float t1 = (tile_y1 - ray_o.y) / ray_d.y; + if (t0 > t1) { float tmp = t0; t0 = t1; t1 = tmp; } + t_enter = fmaxf(t_enter, t0); + t_exit = fminf(t_exit, t1); + } else { + if (ray_o.y < tile_y0 || ray_o.y > tile_y1) return; + } + + if (t_enter > t_exit) return; + + // Entry point in world space + float3 p = make_float3( + ray_o.x + ray_d.x * t_enter, + ray_o.y + ray_d.y * t_enter, + ray_o.z + ray_d.z * t_enter); + + // Convert to grid coordinates (fractional) + float gx = p.x / sx; + float gy = p.y / sy; + + // Current cell + int cx = (int)floorf(gx); + int cy = (int)floorf(gy); + cx = max(cx, cell_col0); + cx = min(cx, cell_col1 - 1); + cy = max(cy, cell_row0); + cy = min(cy, cell_row1 - 1); + + // DDA step direction + int step_x = (ray_d.x >= 0.0f) ? 1 : -1; + int step_y = (ray_d.y >= 0.0f) ? 1 : -1; + + // DDA t-deltas (world-space t per grid cell) + float dt_x = (fabsf(ray_d.x) > 1e-8f) ? fabsf(sx / ray_d.x) : 1e30f; + float dt_y = (fabsf(ray_d.y) > 1e-8f) ? fabsf(sy / ray_d.y) : 1e30f; + + // Next cell boundary t values + float next_t_x, next_t_y; + if (fabsf(ray_d.x) > 1e-8f) { + float boundary_x = (ray_d.x >= 0.0f) ? (cx + 1) * sx : cx * sx; + next_t_x = (boundary_x - ray_o.x) / ray_d.x; + } else { + next_t_x = 1e30f; + } + if (fabsf(ray_d.y) > 1e-8f) { + float boundary_y = (ray_d.y >= 0.0f) ? (cy + 1) * sy : cy * sy; + next_t_y = (boundary_y - ray_o.y) / ray_d.y; + } else { + next_t_y = 1e30f; + } + + // DDA loop through cells within this tile + float best_t = ray_tmax; + float best_nx = 0.0f, best_ny = 0.0f; + bool found = false; + + for (int iter = 0; iter < tile_size * tile_size * 2 + 4; iter++) { + if (cx < cell_col0 || cx >= cell_col1 || + cy < cell_row0 || cy >= cell_row1) + break; + + // Grid cell corners (row=cy, col=cx) + // v00 = (cx, cy), v10 = (cx+1, cy), v01 = (cx, cy+1), v11 = (cx+1, cy+1) + const float z00 = hf_z(cy, cx ); + const float z10 = hf_z(cy, cx + 1); + const float z01 = hf_z(cy + 1, cx ); + const float z11 = hf_z(cy + 1, cx + 1); + + const float3 v00 = make_float3(cx * sx, cy * sy, z00); + const float3 v10 = make_float3((cx + 1) * sx, cy * sy, z10); + const float3 v01 = make_float3(cx * sx, (cy + 1) * sy, z01); + const float3 v11 = make_float3((cx + 1) * sx, (cy + 1) * sy, z11); + + // Test two triangles per cell (same winding as triangulate_terrain): + // Triangle 0: v00, v10, v01 (lower-left) + // Triangle 1: v10, v11, v01 (upper-right) + float t_hit; + if (ray_triangle(ray_o, ray_d, v00, v10, v01, t_hit)) { + if (t_hit >= ray_tmin && t_hit < best_t) { + best_t = t_hit; + found = true; + // Compute bilinear normal at hit point + float3 hp = make_float3( + ray_o.x + ray_d.x * t_hit, + ray_o.y + ray_d.y * t_hit, + ray_o.z + ray_d.z * t_hit); + float u = (hp.x / sx) - cx; + float v = (hp.y / sy) - cy; + u = fmaxf(0.0f, fminf(1.0f, u)); + v = fmaxf(0.0f, fminf(1.0f, v)); + float dz_dx = ((1.0f - v) * (z10 - z00) + v * (z11 - z01)) / sx; + float dz_dy = ((1.0f - u) * (z01 - z00) + u * (z11 - z10)) / sy; + float3 n = normalize(make_float3(-dz_dx, -dz_dy, 1.0f)); + best_nx = n.x; + best_ny = n.y; + } + } + if (ray_triangle(ray_o, ray_d, v10, v11, v01, t_hit)) { + if (t_hit >= ray_tmin && t_hit < best_t) { + best_t = t_hit; + found = true; + float3 hp = make_float3( + ray_o.x + ray_d.x * t_hit, + ray_o.y + ray_d.y * t_hit, + ray_o.z + ray_d.z * t_hit); + float u = (hp.x / sx) - cx; + float v = (hp.y / sy) - cy; + u = fmaxf(0.0f, fminf(1.0f, u)); + v = fmaxf(0.0f, fminf(1.0f, v)); + float dz_dx = ((1.0f - v) * (z10 - z00) + v * (z11 - z01)) / sx; + float dz_dy = ((1.0f - u) * (z01 - z00) + u * (z11 - z10)) / sy; + float3 n = normalize(make_float3(-dz_dx, -dz_dy, 1.0f)); + best_nx = n.x; + best_ny = n.y; + } + } + + // Early exit: if we found a hit in this cell, the DDA guarantees + // we won't find a closer one in later cells (front-to-back order) + if (found) break; + + // Step to next cell + if (next_t_x < next_t_y) { + cx += step_x; + next_t_x += dt_x; + } else { + cy += step_y; + next_t_y += dt_y; + } + } + + if (found) { + // Pack normal components as attributes + unsigned int a0 = float_as_int(best_nx); + unsigned int a1 = float_as_int(best_ny); + optixReportIntersection(best_t, 0, a0, a1); + } +} + + +extern "C" __global__ void __closesthit__heightfield() +{ + const float t = optixGetRayTmax(); + + // Reconstruct normal from attributes packed by IS program + const float nx = int_as_float(optixGetAttribute_0()); + const float ny = int_as_float(optixGetAttribute_1()); + float nz_sq = 1.0f - nx * nx - ny * ny; + if (nz_sq < 0.0f) nz_sq = 0.0f; + const float nz = sqrtf(nz_sq); + + optixSetPayload_0(float_as_int(t)); + optixSetPayload_1(float_as_int(nx)); + optixSetPayload_2(float_as_int(ny)); + optixSetPayload_3(float_as_int(nz)); + optixSetPayload_4(optixGetPrimitiveIndex()); + optixSetPayload_5(optixGetInstanceId()); +} diff --git a/examples/capetown.py b/examples/capetown.py index df1afcf..9d23fd8 100644 --- a/examples/capetown.py +++ b/examples/capetown.py @@ -1,145 +1,9 @@ -"""Interactive playground for Cape Town, South Africa. - -Explore the terrain of Cape Town using GPU-accelerated ray tracing. -Elevation data is sourced from the Copernicus GLO-30 DEM (30 m). - -Builds an xr.Dataset with elevation, slope, aspect, and quantile layers. -Press G to cycle between layers. Satellite tiles are draped on the terrain -automatically — press U to toggle tile overlay on/off. - -Requirements: - pip install rtxpy[all] matplotlib xarray rioxarray requests pyproj Pillow -""" - -import warnings - -import numpy as np -import xarray as xr - -from xrspatial import slope, aspect, quantile -from pathlib import Path - -from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water, fetch_firms -import rtxpy - -# Cape Town bounding box (lon_min, lat_min, lon_max, lat_max) -BOUNDS = (18.3, -34.2, 18.7, -33.8) -CACHE = Path(__file__).parent - - -def load_terrain(): - """Load Cape Town terrain data, downloading if necessary.""" - terrain = fetch_dem( - bounds=BOUNDS, - output_path=CACHE / "capetown_dem.tif", - source='copernicus', - crs='EPSG:32734', # UTM zone 34S - ) - - # Scale down elevation for visualization (optional) - terrain.data = terrain.data * 0.025 - - # Ensure contiguous array before GPU transfer - terrain.data = np.ascontiguousarray(terrain.data) - - # Get stats before GPU transfer (nanmin/nanmax to skip NaN ocean pixels) - elev_min = float(np.nanmin(terrain.data)) - elev_max = float(np.nanmax(terrain.data)) - - # Convert to cupy for GPU processing using the accessor - terrain = terrain.rtx.to_cupy() - - print(f"Terrain loaded: {terrain.shape}, elevation range: " - f"{elev_min:.0f}m to {elev_max:.0f}m (scaled)") - - return terrain - - -if __name__ == "__main__": - terrain = load_terrain() - - # Build Dataset with derived layers - print("Building Dataset with terrain analysis layers...") - ds = xr.Dataset({ - 'elevation': terrain.rename(None), - 'slope': slope(terrain), - 'aspect': aspect(terrain), - 'quantile': quantile(terrain), - }) - print(ds) - - # Drape satellite tiles on terrain (reprojected to match DEM CRS) - print("Loading satellite tiles...") - ds.rtx.place_tiles('satellite', z='elevation') - - # --- Microsoft Global Building Footprints -------------------------------- - try: - bldg_data = fetch_buildings(bounds=BOUNDS, cache_path=CACHE / "capetown_buildings.geojson") - info = ds.rtx.place_buildings(bldg_data, z='elevation', elev_scale=0.025, - mesh_cache=CACHE / "capetown_buildings_mesh.npz") - print(f"Placed {info['geometries']} building geometries") - except Exception as e: - print(f"Skipping buildings: {e}") - - # --- OpenStreetMap roads ------------------------------------------------ - try: - for rt, gid, clr in [('major', 'road_major', (0.10, 0.10, 0.10)), - ('minor', 'road_minor', (0.55, 0.55, 0.55))]: - data = fetch_roads(bounds=BOUNDS, road_type=rt, - cache_path=CACHE / f"capetown_roads_{rt}.geojson") - info = ds.rtx.place_roads(data, z='elevation', geometry_id=gid, color=clr, - mesh_cache=CACHE / f"capetown_roads_{rt}_mesh.npz") - print(f"Placed {info['geometries']} {rt} road geometries") - except Exception as e: - print(f"Skipping roads: {e}") - - # --- OpenStreetMap water features --------------------------------------- - try: - water_data = fetch_water(bounds=BOUNDS, water_type='all', - cache_path=CACHE / "capetown_water.geojson") - results = ds.rtx.place_water(water_data, z='elevation', - mesh_cache_prefix=CACHE / "capetown_water") - for cat, info in results.items(): - print(f"Placed {info['geometries']} {cat} water features") - except Exception as e: - print(f"Skipping water: {e}") - - # --- NASA FIRMS fire detections (last 7 days) --------------------------- - try: - fire_data = fetch_firms(bounds=BOUNDS, date_span='7d', - cache_path=CACHE / "capetown_fires.geojson", - crs='EPSG:32734') - if fire_data.get('features'): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - fire_info = ds.rtx.place_geojson( - fire_data, z='elevation', height=20 * 0.025, - geometry_id='fire', color=(1.0, 0.25, 0.0, 3.0), - extrude=True, merge=True, - ) - print(f"Placed {fire_info['geometries']} fire detection footprints") - else: - print("No fire detections in the last 7 days") - except Exception as e: - print(f"Skipping fire layer: {e}") - - # --- Wind data -------------------------------------------------------- - wind = None - try: - from rtxpy import fetch_wind - wind = fetch_wind(BOUNDS, grid_size=15) - except Exception as e: - print(f"Skipping wind: {e}") - - print("\nLaunching explore (press G to cycle layers, Shift+W for wind)...\n") - ds.rtx.explore( - z='elevation', - width=2048, - height=1600, - render_scale=0.5, - color_stretch='cbrt', - subsample=4, - wind_data=wind, - ) - - print("Done") +"""Cape Town — GPU-accelerated terrain exploration.""" +from rtxpy import quickstart + +quickstart( + name='capetown', + bounds=(18.3, -34.2, 18.7, -33.8), + crs='EPSG:32734', + features=['buildings', 'roads', 'water', 'fire'], +) diff --git a/examples/generate_playground_gif.py b/examples/generate_playground_gif.py index e5d04df..051f497 100644 --- a/examples/generate_playground_gif.py +++ b/examples/generate_playground_gif.py @@ -33,9 +33,6 @@ def load_terrain(): crop = 20 terrain = terrain[crop:-crop, crop:-crop] - # Scale down elevation for visualization - terrain.data = terrain.data * 0.2 - # Ensure contiguous array before GPU transfer terrain.data = np.ascontiguousarray(terrain.data) diff --git a/examples/guanajuato.py b/examples/guanajuato.py index be06db2..561a82d 100644 --- a/examples/guanajuato.py +++ b/examples/guanajuato.py @@ -1,153 +1,9 @@ -"""Interactive playground for the Guanajuato-San Miguel de Allende highlands. - -Explore the terrain of central Mexico's Bajio region using GPU-accelerated -ray tracing. The area covers the Sierra de Santa Rosa northwest of Guanajuato -city, the colonial town of San Miguel de Allende to the east, and the -rugged canyon country in between. - -Elevation data is sourced from the Copernicus GLO-30 DEM (30 m). - -Builds an xr.Dataset with elevation, slope, aspect, and quantile layers. -Press G to cycle between layers. Satellite tiles are draped on the terrain -automatically — press U to toggle tile overlay on/off. - -Requirements: - pip install rtxpy[all] matplotlib xarray rioxarray requests pyproj Pillow -""" - -import warnings - -import numpy as np -import xarray as xr - -from xrspatial import slope, aspect, quantile -from pathlib import Path - -# Import rtxpy to register the .rtx accessor -from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water, fetch_firms -import rtxpy - -BOUNDS = (-101.50, 20.70, -100.50, 21.30) -CRS = 'EPSG:32614' -CACHE = Path(__file__).parent - - -def load_terrain(): - """Load Guanajuato terrain data, downloading if necessary.""" - terrain = fetch_dem( - bounds=BOUNDS, - output_path=CACHE / "guanajuato_dem.tif", - source='copernicus', - crs=CRS, - ) - - # Mask nodata / water pixels - terrain = terrain.where(terrain > 0) - - # Scale down elevation for visualization - terrain.data = terrain.data * 0.025 - - # Ensure contiguous array before GPU transfer - terrain.data = np.ascontiguousarray(terrain.data) - - # Get stats before GPU transfer - elev_min = float(np.nanmin(terrain.data)) - elev_max = float(np.nanmax(terrain.data)) - - # Convert to cupy for GPU processing using the accessor - terrain = terrain.rtx.to_cupy() - - print(f"Terrain loaded: {terrain.shape}, elevation range: " - f"{elev_min:.0f}m to {elev_max:.0f}m (scaled)") - - return terrain - - -if __name__ == "__main__": - terrain = load_terrain() - - # Build Dataset with derived layers - print("Building Dataset with terrain analysis layers...") - ds = xr.Dataset({ - 'elevation': terrain.rename(None), - 'slope': slope(terrain), - 'aspect': aspect(terrain), - 'quantile': quantile(terrain), - }) - print(ds) - - # Drape satellite tiles on terrain (reprojected to match DEM CRS) - print("Loading satellite tiles...") - ds.rtx.place_tiles('satellite', z='elevation') - - # --- Microsoft Global Building Footprints -------------------------------- - try: - bldg_data = fetch_buildings(bounds=BOUNDS, cache_path=CACHE / "guanajuato_buildings.geojson") - info = ds.rtx.place_buildings(bldg_data, z='elevation', elev_scale=0.025, - mesh_cache=CACHE / "guanajuato_buildings_mesh.npz") - print(f"Placed {info['geometries']} building geometries") - except Exception as e: - print(f"Skipping buildings: {e}") - - # --- OpenStreetMap roads ------------------------------------------------ - try: - for rt, gid, clr in [('major', 'road_major', (0.10, 0.10, 0.10)), - ('minor', 'road_minor', (0.55, 0.55, 0.55))]: - data = fetch_roads(bounds=BOUNDS, road_type=rt, - cache_path=CACHE / f"guanajuato_roads_{rt}.geojson") - info = ds.rtx.place_roads(data, z='elevation', geometry_id=gid, color=clr, - mesh_cache=CACHE / f"guanajuato_roads_{rt}_mesh.npz") - print(f"Placed {info['geometries']} {rt} road geometries") - except Exception as e: - print(f"Skipping roads: {e}") - - # --- OpenStreetMap water features --------------------------------------- - try: - water_data = fetch_water(bounds=BOUNDS, water_type='all', - cache_path=CACHE / "guanajuato_water.geojson") - results = ds.rtx.place_water(water_data, z='elevation', - mesh_cache_prefix=CACHE / "guanajuato_water") - for cat, info in results.items(): - print(f"Placed {info['geometries']} {cat} water features") - except Exception as e: - print(f"Skipping water: {e}") - - # --- NASA FIRMS fire detections (last 7 days) --------------------------- - try: - fire_data = fetch_firms(bounds=BOUNDS, date_span='7d', - cache_path=CACHE / "guanajuato_fires.geojson", - crs=CRS) - if fire_data.get('features'): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - fire_info = ds.rtx.place_geojson( - fire_data, z='elevation', height=20 * 0.025, - geometry_id='fire', color=(1.0, 0.25, 0.0, 3.0), - extrude=True, merge=True, - ) - print(f"Placed {fire_info['geometries']} fire detection footprints") - else: - print("No fire detections in the last 7 days") - except Exception as e: - print(f"Skipping fire layer: {e}") - - # --- Wind data -------------------------------------------------------- - wind = None - try: - from rtxpy import fetch_wind - wind = fetch_wind(BOUNDS, grid_size=15) - except Exception as e: - print(f"Skipping wind: {e}") - - print("\nLaunching explore (press G to cycle layers, Shift+W for wind)...\n") - ds.rtx.explore( - z='elevation', - width=2048, - height=1600, - render_scale=0.5, - color_stretch='cbrt', - subsample=4, - wind_data=wind, - ) - - print("Done") +"""Guanajuato — GPU-accelerated terrain exploration.""" +from rtxpy import quickstart + +quickstart( + name='guanajuato', + bounds=(-101.50, 20.70, -100.50, 21.30), + crs='EPSG:32614', + features=['buildings', 'roads', 'water', 'fire'], +) diff --git a/examples/los_angeles.py b/examples/los_angeles.py index ac18ff0..4e5cf6a 100644 --- a/examples/los_angeles.py +++ b/examples/los_angeles.py @@ -1,152 +1,10 @@ -"""Interactive playground for Los Angeles, California. - -Explore the terrain of Los Angeles using GPU-accelerated ray tracing. -The area covers downtown LA, Echo Park, Silver Lake, Elysian Park, -Griffith Park, the Hollywood Hills, and the Hollywood Sign. - -Elevation data is sourced from USGS 3DEP 1-meter lidar DEM. - -Builds an xr.Dataset with elevation, slope, aspect, and quantile layers. -Press G to cycle between layers. Satellite tiles are draped on the terrain -automatically — press U to toggle tile overlay on/off. - -Requirements: - pip install rtxpy[all] matplotlib xarray rioxarray requests pyproj Pillow -""" - -import warnings - -import numpy as np -import xarray as xr - -from xrspatial import slope, aspect, quantile -from pathlib import Path - -from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water, fetch_firms -import rtxpy - -# Los Angeles bounding box (WGS84) -# Focused area covering DTLA, Echo Park, Silver Lake, Griffith Park, -# Hollywood Hills, and the Hollywood Sign (~8 km x 8 km at 1 m resolution). -BOUNDS = (-118.32, 34.04, -118.23, 34.12) -CRS = 'EPSG:32611' # UTM zone 11N -CACHE = Path(__file__).parent - - -def load_terrain(): - """Load Los Angeles terrain data, downloading if necessary.""" - terrain = fetch_dem( - bounds=BOUNDS, - output_path=CACHE / "los_angeles_dem.tif", - source='usgs_1m', - crs=CRS, - ) - - # Scale elevation for visualization (1 m pixels need less reduction - # than 30 m Copernicus data to keep a similar visual slope ratio) - terrain.data = terrain.data * 0.5 - - # Ensure contiguous array before GPU transfer - terrain.data = np.ascontiguousarray(terrain.data) - - # Get stats before GPU transfer - elev_min = float(np.nanmin(terrain.data)) - elev_max = float(np.nanmax(terrain.data)) - - # Convert to cupy for GPU processing using the accessor - terrain = terrain.rtx.to_cupy() - - print(f"Terrain loaded: {terrain.shape}, elevation range: " - f"{elev_min:.0f}m to {elev_max:.0f}m (scaled)") - - return terrain - - -if __name__ == "__main__": - terrain = load_terrain() - - # Build Dataset with derived layers - print("Building Dataset with terrain analysis layers...") - ds = xr.Dataset({ - 'elevation': terrain.rename(None), - 'slope': slope(terrain), - 'aspect': aspect(terrain), - 'quantile': quantile(terrain), - }) - print(ds) - - # Drape satellite tiles on terrain (reprojected to match DEM CRS) - print("Loading satellite tiles...") - ds.rtx.place_tiles('satellite', z='elevation', zoom=15) - - # --- Microsoft Global Building Footprints -------------------------------- - try: - bldg_data = fetch_buildings(bounds=BOUNDS, cache_path=CACHE / "los_angeles_buildings.geojson") - info = ds.rtx.place_buildings(bldg_data, z='elevation', elev_scale=0.5, - mesh_cache=CACHE / "los_angeles_buildings_mesh.npz") - print(f"Placed {info['geometries']} building geometries") - except Exception as e: - print(f"Skipping buildings: {e}") - - # --- OpenStreetMap roads ------------------------------------------------ - try: - for rt, gid, clr in [('major', 'road_major', (0.10, 0.10, 0.10)), - ('minor', 'road_minor', (0.55, 0.55, 0.55))]: - data = fetch_roads(bounds=BOUNDS, road_type=rt, - cache_path=CACHE / f"los_angeles_roads_{rt}.geojson") - info = ds.rtx.place_roads(data, z='elevation', geometry_id=gid, color=clr, - mesh_cache=CACHE / f"los_angeles_roads_{rt}_mesh.npz") - print(f"Placed {info['geometries']} {rt} road geometries") - except Exception as e: - print(f"Skipping roads: {e}") - - # --- OpenStreetMap water features --------------------------------------- - try: - water_data = fetch_water(bounds=BOUNDS, water_type='all', - cache_path=CACHE / "los_angeles_water.geojson") - results = ds.rtx.place_water(water_data, z='elevation', - mesh_cache_prefix=CACHE / "los_angeles_water") - for cat, info in results.items(): - print(f"Placed {info['geometries']} {cat} water features") - except Exception as e: - print(f"Skipping water: {e}") - - # --- NASA FIRMS fire detections (last 7 days) --------------------------- - try: - fire_data = fetch_firms(bounds=BOUNDS, date_span='7d', - cache_path=CACHE / "los_angeles_fires.geojson", - crs=CRS) - if fire_data.get('features'): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - fire_info = ds.rtx.place_geojson( - fire_data, z='elevation', height=20 * 0.5, - geometry_id='fire', color=(1.0, 0.25, 0.0, 3.0), - extrude=True, merge=True, - ) - print(f"Placed {fire_info['geometries']} fire detection footprints") - else: - print("No fire detections in the last 7 days") - except Exception as e: - print(f"Skipping fire layer: {e}") - - # --- Wind data -------------------------------------------------------- - wind = None - try: - from rtxpy import fetch_wind - wind = fetch_wind(BOUNDS, grid_size=15) - except Exception as e: - print(f"Skipping wind: {e}") - - print("\nLaunching explore (press G to cycle layers, Shift+W for wind)...\n") - ds.rtx.explore( - z='elevation', - width=2048, - height=1600, - render_scale=0.5, - color_stretch='cbrt', - subsample=4, - wind_data=wind, - ) - - print("Done") +"""Los Angeles — GPU-accelerated terrain exploration.""" +from rtxpy import quickstart + +quickstart( + name='los_angeles', + bounds=(-118.52, 33.85, -117.25, 34.23), + crs='EPSG:32611', + source='usgs_10m', + features=['buildings', 'roads', 'water', 'fire'], +) diff --git a/examples/playground.py b/examples/playground.py index 87e2c5a..76101a7 100644 --- a/examples/playground.py +++ b/examples/playground.py @@ -12,7 +12,7 @@ to cycle and jump between geometry layers. Requirements: - pip install rtxpy[all] matplotlib xarray rioxarray requests pyproj Pillow osmnx + pip install rtxpy[all] matplotlib xarray rioxarray requests pyproj Pillow osmnx duckdb """ import warnings @@ -24,7 +24,7 @@ from xrspatial import slope, aspect, quantile # Import rtxpy to register the .rtx accessor -from rtxpy import fetch_dem, fetch_roads, fetch_water +from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water import rtxpy BOUNDS = (-122.3, 42.8, -121.9, 43.0) @@ -36,7 +36,7 @@ def load_terrain(): """Load Crater Lake terrain data, downloading if necessary.""" terrain = fetch_dem( bounds=BOUNDS, - output_path=CACHE / "crater_lake_national_park.tif", + output_path=CACHE / "crater_lake_national_park.zarr", source='srtm', crs=CRS, ) @@ -44,9 +44,6 @@ def load_terrain(): # Subsample for faster interactive performance terrain = terrain[::2, ::2] - # Scale down elevation for visualization (optional) - terrain.data = terrain.data * 0.1 - # Ensure contiguous array before GPU transfer terrain.data = np.ascontiguousarray(terrain.data) @@ -58,7 +55,7 @@ def load_terrain(): terrain = terrain.rtx.to_cupy() print(f"Terrain loaded: {terrain.shape}, elevation range: " - f"{elev_min:.0f}m to {elev_max:.0f}m (scaled)") + f"{elev_min:.0f}m to {elev_max:.0f}m") return terrain @@ -166,38 +163,65 @@ def load_terrain(): ], } - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - info = ds.rtx.place_geojson( - crater_lake_geojson, - z='elevation', - height=15.0, - label_field='name', - geometry_id='landmark', - ) - print(f"Placed {info['geometries']} GeoJSON geometries: " - f"{', '.join(info['geometry_ids'])}") - - # --- OpenStreetMap roads ------------------------------------------------ - try: - roads_data = fetch_roads(bounds=BOUNDS, road_type='all', crs=CRS, - cache_path=CACHE / "crater_lake_roads.geojson") - info = ds.rtx.place_roads(roads_data, z='elevation', geometry_id='roads', - height=5, mesh_cache=CACHE / "crater_lake_roads_mesh.npz") - print(f"Placed {info['geometries']} road geometries") - except Exception as e: - print(f"Skipping roads: {e}") + ZARR = CACHE / "crater_lake_national_park.zarr" - # --- OpenStreetMap water features --------------------------------------- + # Ensure meshes exist in zarr (first run places + saves them) try: - water_data = fetch_water(bounds=BOUNDS, water_type='all', crs=CRS, - cache_path=CACHE / "crater_lake_water.geojson") - results = ds.rtx.place_water(water_data, z='elevation', - mesh_cache_prefix=CACHE / "crater_lake_water") - for cat, info in results.items(): - print(f"Placed {info['geometries']} {cat} water features") - except Exception as e: - print(f"Skipping water: {e}") + import zarr as _zarr + _zarr.open(str(ZARR), mode='r', use_consolidated=False)['meshes'] + except (KeyError, FileNotFoundError): + # First run — place features from sources and bake into zarr + # --- GeoJSON landmarks and outlines ---------------------------------- + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="place_geojson called before") + info = ds.rtx.place_geojson( + crater_lake_geojson, + z='elevation', + height=15.0, + label_field='name', + geometry_id='landmark', + ) + print(f"Placed {info['geometries']} GeoJSON geometries: " + f"{', '.join(info['geometry_ids'])}") + + # --- Overture Maps buildings ----------------------------------------- + try: + bldgs_data = fetch_buildings(bounds=BOUNDS, source='overture', crs=CRS, + cache_path=CACHE / "crater_lake_buildings_overture.geojson") + if bldgs_data['features']: + info = ds.rtx.place_buildings(bldgs_data, z='elevation') + print(f"Placed {info['geometries']} Overture buildings") + except ImportError: + print("Skipping Overture buildings (pip install duckdb)") + except Exception as e: + print(f"Skipping Overture buildings: {e}") + + # --- Overture Maps roads -------------------------------------------- + try: + roads_data = fetch_roads(bounds=BOUNDS, road_type='all', + source='overture', crs=CRS, + cache_path=CACHE / "crater_lake_roads_overture.geojson") + if roads_data['features']: + info = ds.rtx.place_roads(roads_data, z='elevation', + geometry_id='roads', height=5) + print(f"Placed {info['geometries']} Overture road geometries") + except ImportError: + print("Skipping Overture roads (pip install duckdb)") + except Exception as e: + print(f"Skipping Overture roads: {e}") + + # --- Water features (Overture Maps) --------------------------------- + try: + water_data = fetch_water(bounds=BOUNDS, water_type='all', crs=CRS, + cache_path=CACHE / "crater_lake_water.geojson") + results = ds.rtx.place_water(water_data, z='elevation') + for cat, info in results.items(): + print(f"Placed {info['geometries']} {cat} water features") + except Exception as e: + print(f"Skipping water: {e}") + + # Save all placed meshes into the DEM zarr + ds.rtx.save_meshes(ZARR, z='elevation') # --- Wind data -------------------------------------------------------- wind = None @@ -208,18 +232,20 @@ def load_terrain(): # so they cover the field instead of clumping wind['n_particles'] = 15000 wind['max_age'] = 120 - wind['speed_mult'] = 80.0 + wind['speed_mult'] = 400.0 except Exception as e: print(f"Skipping wind: {e}") print("\nLaunching explore (press G to cycle layers, Shift+W for wind)...\n") ds.rtx.explore( z='elevation', + scene_zarr=ZARR, mesh_type='voxel', width=1024, height=768, render_scale=0.5, wind_data=wind, + repl=True, ) print("Done") diff --git a/examples/rio.py b/examples/rio.py index a301c1c..c901cda 100644 --- a/examples/rio.py +++ b/examples/rio.py @@ -1,129 +1,9 @@ -"""Interactive playground for Rio de Janeiro, Brazil. - -Explore the terrain of Rio de Janeiro using GPU-accelerated ray tracing. -Elevation data is sourced from the Copernicus GLO-30 DEM (30 m). - -Builds an xr.Dataset with elevation, slope, aspect, and quantile layers. -Press G to cycle between layers. Satellite tiles are draped on the terrain -automatically — press U to toggle tile overlay on/off. - -Requirements: - pip install rtxpy[all] matplotlib xarray rioxarray requests pyproj Pillow -""" - -import warnings - -import numpy as np -import xarray as xr - -from xrspatial import slope, aspect, quantile -from pathlib import Path - -from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water -import rtxpy - -# Rio de Janeiro bounding box (WGS84) -# Covers the city from Barra da Tijuca in the west to Ilha do Governador -# in the east, including Sugarloaf, Corcovado, and Tijuca Forest. -BOUNDS = (-43.42, -23.08, -43.10, -22.84) -CRS = 'EPSG:32723' # UTM zone 23S -CACHE = Path(__file__).parent - - -def load_terrain(): - """Load Rio de Janeiro terrain data, downloading if necessary.""" - terrain = fetch_dem( - bounds=BOUNDS, - output_path=CACHE / "rio_dem.tif", - source='copernicus', - crs=CRS, - ) - - # Scale down elevation for visualization (optional) - terrain.data = terrain.data * 0.025 - - # Ensure contiguous array before GPU transfer - terrain.data = np.ascontiguousarray(terrain.data) - - # Get stats before GPU transfer - elev_min = float(np.nanmin(terrain.data)) - elev_max = float(np.nanmax(terrain.data)) - - # Convert to cupy for GPU processing using the accessor - terrain = terrain.rtx.to_cupy() - - print(f"Terrain loaded: {terrain.shape}, elevation range: " - f"{elev_min:.0f}m to {elev_max:.0f}m (scaled)") - - return terrain - - -if __name__ == "__main__": - terrain = load_terrain() - - # Build Dataset with derived layers - print("Building Dataset with terrain analysis layers...") - ds = xr.Dataset({ - 'elevation': terrain.rename(None), - 'slope': slope(terrain), - 'aspect': aspect(terrain), - 'quantile': quantile(terrain), - }) - print(ds) - - # Drape satellite tiles on terrain (reprojected to match DEM CRS) - print("Loading satellite tiles...") - ds.rtx.place_tiles('satellite', z='elevation') - - # --- Microsoft Global Building Footprints -------------------------------- - try: - bldg_data = fetch_buildings(bounds=BOUNDS, cache_path=CACHE / "rio_buildings.geojson") - info = ds.rtx.place_buildings(bldg_data, z='elevation', elev_scale=0.025, - mesh_cache=CACHE / "rio_buildings_mesh.npz") - print(f"Placed {info['geometries']} building geometries") - except Exception as e: - print(f"Skipping buildings: {e}") - - # --- OpenStreetMap roads ------------------------------------------------ - try: - for rt, gid, clr in [('major', 'road_major', (0.10, 0.10, 0.10)), - ('minor', 'road_minor', (0.55, 0.55, 0.55))]: - data = fetch_roads(bounds=BOUNDS, road_type=rt, - cache_path=CACHE / f"rio_roads_{rt}.geojson") - info = ds.rtx.place_roads(data, z='elevation', geometry_id=gid, color=clr, - mesh_cache=CACHE / f"rio_roads_{rt}_mesh.npz") - print(f"Placed {info['geometries']} {rt} road geometries") - except Exception as e: - print(f"Skipping roads: {e}") - - # --- OpenStreetMap water features --------------------------------------- - try: - water_data = fetch_water(bounds=BOUNDS, water_type='all', - cache_path=CACHE / "rio_water.geojson") - results = ds.rtx.place_water(water_data, z='elevation', - mesh_cache_prefix=CACHE / "rio_water") - for cat, info in results.items(): - print(f"Placed {info['geometries']} {cat} water features") - except Exception as e: - print(f"Skipping water: {e}") - - # --- Wind data -------------------------------------------------------- - wind = None - try: - from rtxpy import fetch_wind - wind = fetch_wind(BOUNDS, grid_size=15) - except Exception as e: - print(f"Skipping wind: {e}") - - print("\nLaunching explore (press G to cycle layers, Shift+W for wind)...\n") - ds.rtx.explore( - z='elevation', - width=2048, - height=1600, - render_scale=0.5, - color_stretch='cbrt', - subsample=4, - wind_data=wind, - ) - - print("Done") +"""Rio de Janeiro — GPU-accelerated terrain exploration.""" +from rtxpy import quickstart + +quickstart( + name='rio', + bounds=(-43.42, -23.08, -43.10, -22.84), + crs='EPSG:32723', + features=['buildings', 'roads', 'water'], +) diff --git a/examples/trinidad.py b/examples/trinidad.py index 0441cfd..18d72df 100644 --- a/examples/trinidad.py +++ b/examples/trinidad.py @@ -1,145 +1,12 @@ -"""Interactive playground for Trinidad and Tobago. - -Explore the terrain of Trinidad and Tobago using GPU-accelerated ray tracing. -Elevation data is sourced from the Copernicus GLO-30 DEM (30 m). - -Builds an xr.Dataset with elevation, slope, aspect, and quantile layers. -Press G to cycle between layers. Satellite tiles are draped on the terrain -automatically — press U to toggle tile overlay on/off. - -Requirements: - pip install rtxpy[all] matplotlib xarray rioxarray requests pyproj Pillow -""" - -import warnings - -import numpy as np -import xarray as xr - -from xrspatial import slope, aspect, quantile -from pathlib import Path - -from rtxpy import fetch_dem, fetch_buildings, fetch_roads, fetch_water, fetch_firms -import rtxpy - -BOUNDS = (-61.95, 10.04, -60.44, 11.40) -CRS = 'EPSG:32620' -CACHE = Path(__file__).parent - - -def load_terrain(): - """Load Trinidad & Tobago terrain data, downloading if necessary.""" - terrain = fetch_dem( - bounds=BOUNDS, - output_path=CACHE / "trinidad_tobago_dem.tif", - source='copernicus', - crs=CRS, - ) - - # Scale down elevation for visualization (optional) - terrain.data = terrain.data * 0.025 - - # Ensure contiguous array before GPU transfer - terrain.data = np.ascontiguousarray(terrain.data) - - # Get stats before GPU transfer (nanmin/nanmax to skip NaN ocean pixels) - elev_min = float(np.nanmin(terrain.data)) - elev_max = float(np.nanmax(terrain.data)) - - # Convert to cupy for GPU processing using the accessor - terrain = terrain.rtx.to_cupy() - - print(f"Terrain loaded: {terrain.shape}, elevation range: " - f"{elev_min:.0f}m to {elev_max:.0f}m (scaled)") - - return terrain - - -if __name__ == "__main__": - terrain = load_terrain() - - # Build Dataset with derived layers - print("Building Dataset with terrain analysis layers...") - ds = xr.Dataset({ - 'elevation': terrain.rename(None), - 'slope': slope(terrain), - 'aspect': aspect(terrain), - 'quantile': quantile(terrain), - }) - print(ds) - - # Drape satellite tiles on terrain (reprojected to match DEM CRS) - print("Loading satellite tiles...") - ds.rtx.place_tiles('satellite', z='elevation') - - # --- Microsoft Global Building Footprints -------------------------------- - try: - bldg_data = fetch_buildings(bounds=BOUNDS, cache_path=CACHE / "trinidad_buildings.geojson") - info = ds.rtx.place_buildings(bldg_data, z='elevation', elev_scale=0.025, - mesh_cache=CACHE / "trinidad_buildings_mesh.npz") - print(f"Placed {info['geometries']} building geometries") - except Exception as e: - print(f"Skipping buildings: {e}") - - # --- OpenStreetMap roads ------------------------------------------------ - try: - for rt, gid, clr in [('major', 'road_major', (0.10, 0.10, 0.10)), - ('minor', 'road_minor', (0.55, 0.55, 0.55))]: - data = fetch_roads(bounds=BOUNDS, road_type=rt, - cache_path=CACHE / f"trinidad_roads_{rt}.geojson") - info = ds.rtx.place_roads(data, z='elevation', geometry_id=gid, color=clr, - mesh_cache=CACHE / f"trinidad_roads_{rt}_mesh.npz") - print(f"Placed {info['geometries']} {rt} road geometries") - except Exception as e: - print(f"Skipping roads: {e}") - - # --- OpenStreetMap water features --------------------------------------- - try: - water_data = fetch_water(bounds=BOUNDS, water_type='all', - cache_path=CACHE / "trinidad_water.geojson") - results = ds.rtx.place_water(water_data, z='elevation', - mesh_cache_prefix=CACHE / "trinidad_water") - for cat, info in results.items(): - print(f"Placed {info['geometries']} {cat} water features") - except Exception as e: - print(f"Skipping water: {e}") - - # --- NASA FIRMS fire detections (last 7 days) --------------------------- - try: - fire_data = fetch_firms(bounds=BOUNDS, date_span='7d', - cache_path=CACHE / "trinidad_fires.geojson", - crs=CRS) - if fire_data.get('features'): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="place_geojson called before") - fire_info = ds.rtx.place_geojson( - fire_data, z='elevation', height=20 * 0.025, - geometry_id='fire', color=(1.0, 0.25, 0.0, 3.0), - extrude=True, merge=True, - ) - print(f"Placed {fire_info['geometries']} fire detection footprints") - else: - print("No fire detections in the last 7 days") - except Exception as e: - print(f"Skipping fire layer: {e}") - - # --- Wind data -------------------------------------------------------- - wind = None - try: - from rtxpy import fetch_wind - wind = fetch_wind(BOUNDS, grid_size=15) - except Exception as e: - print(f"Skipping wind: {e}") - - print("\nLaunching explore (press G to cycle layers, Shift+W for wind)...\n") - ds.rtx.explore( - z='elevation', - width=2048, - height=1600, - render_scale=0.5, - color_stretch='cbrt', - subsample=4, - wind_data=wind, - ) - - print("Done") +"""Trinidad and Tobago — GPU-accelerated terrain exploration.""" +from rtxpy import quickstart + +quickstart( + name='trinidad', + bounds=(-61.95, 10.04, -60.44, 11.40), + crs='EPSG:32620', + features=['buildings', 'roads', 'water', 'fire', 'places', + 'infrastructure', 'land_use'], + ao_samples=1, + denoise=True, +) diff --git a/pyproject.toml b/pyproject.toml index 873451c..086fb3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,17 +29,19 @@ classifiers = [ dependencies = [ "numpy>=1.21", "numba>=0.56", + "zarr>=2.0", + "pyoptix-contrib", ] [project.optional-dependencies] # GPU dependencies - cupy provides CUDA context management # Install cupy via: conda install -c conda-forge cupy -# Note: otk-pyoptix must be installed separately from NVIDIA -# See: https://github.com/NVIDIA/otk-pyoptix tests = ["pytest"] analysis = ["xarray"] tiles = ["Pillow", "pyproj"] -all = ["xarray", "scipy", "Pillow", "pyproj"] +viewer = ["glfw", "moderngl", "Pillow"] +notebook = ["ipywidgets>=8.0", "ipyevents>=2.0", "Pillow"] +all = ["xarray", "scipy", "Pillow", "pyproj", "glfw", "moderngl", "ipywidgets>=8.0", "ipyevents>=2.0"] [project.urls] Homepage = "https://github.com/makepath/rtxpy" diff --git a/rtxpy/accessor.py b/rtxpy/accessor.py index b711d91..bc88b2f 100644 --- a/rtxpy/accessor.py +++ b/rtxpy/accessor.py @@ -1853,8 +1853,9 @@ def triangulate(self, geometry_id='terrain', scale=1.0, self._pixel_spacing_y = pixel_spacing_y self._terrain_mesh_type = 'tin' - # Add to scene - self._rtx.add_geometry(geometry_id, vertices, indices) + # Add to scene (pass grid dims for cluster-accelerated BVH) + self._rtx.add_geometry(geometry_id, vertices, indices, + grid_dims=(H, W)) return vertices, indices @@ -2663,6 +2664,7 @@ def explore(self, z, width=800, height=600, render_scale=0.5, start_position=None, look_at=None, key_repeat_interval=0.05, pixel_spacing_x=None, pixel_spacing_y=None, mesh_type='heightfield', color_stretch='linear', title=None, + subtitle=None, legend=None, subsample=1, wind_data=None, gtfs_data=None, scene_zarr=None, ao_samples=0, gi_bounces=1, denoise=False, repl=False, tour=None): @@ -2759,6 +2761,8 @@ def explore(self, z, width=800, height=600, render_scale=0.5, overlay_layers=overlay_layers, color_stretch=color_stretch, title=title, + subtitle=subtitle, + legend=legend, tile_service=getattr(self, '_tile_service', None), geometry_colors_builder=geometry_colors_builder, baked_meshes=terrain_da.rtx._baked_meshes if terrain_da.rtx._baked_meshes else None, diff --git a/rtxpy/analysis/_common.py b/rtxpy/analysis/_common.py index d25cf8a..9bb72a9 100644 --- a/rtxpy/analysis/_common.py +++ b/rtxpy/analysis/_common.py @@ -14,6 +14,48 @@ import cupy +def _compute_pixel_spacing(da): + """Derive real-world pixel spacing from a DataArray's x/y coordinates. + + Parameters + ---------- + da : xarray.DataArray + Raster with 'x' and 'y' coordinate arrays. + + Returns + ------- + tuple of float + (pixel_spacing_x, pixel_spacing_y) in the CRS's linear units. + Falls back to (1.0, 1.0) when coordinates are missing, too short, + or the CRS uses geographic (degree) units. + """ + try: + x = da.coords['x'].values + y = da.coords['y'].values + except (KeyError, AttributeError): + return (1.0, 1.0) + + if len(x) < 2 or len(y) < 2: + return (1.0, 1.0) + + # Guard against geographic CRS (degrees) — spacing in degrees is not + # meaningful as a metric distance, so fall back to pixel coords. + try: + crs = da.rio.crs + if crs is not None and crs.is_geographic: + return (1.0, 1.0) + except Exception: + pass # rioxarray not available or no CRS set — proceed with diffs + + psx = float(abs(x[1] - x[0])) + psy = float(abs(y[1] - y[0])) + + if psx == 0 or psy == 0: + return (1.0, 1.0) + + return (psx, psy) + + @cuda.jit def _generate_primary_rays_kernel(data, x_coords, y_coords, H, W): """GPU kernel for generating orthographic camera rays looking straight down. @@ -42,7 +84,8 @@ def _generate_primary_rays_kernel(data, x_coords, y_coords, H, W): data[i, j, 7] = np.inf # t_max -def generate_primary_rays(rays, x_coords, y_coords, H, W): +def generate_primary_rays(rays, x_coords, y_coords, H, W, + pixel_spacing_x=1.0, pixel_spacing_y=1.0): """Generate orthographic camera rays for terrain intersection. Parameters @@ -57,6 +100,10 @@ def generate_primary_rays(rays, x_coords, y_coords, H, W): Height of the raster. W : int Width of the raster. + pixel_spacing_x : float, optional + World-space spacing per pixel in X. Default 1.0. + pixel_spacing_y : float, optional + World-space spacing per pixel in Y. Default 1.0. Returns ------- @@ -80,17 +127,25 @@ def generate_primary_rays(rays, x_coords, y_coords, H, W): if not y_coords.flags.writeable: y_coords = y_coords.copy() _generate_primary_rays_kernel[griddim, blockdim](rays, x_coords, y_coords, H, W) + + # Scale ray origins from pixel space to world space + if pixel_spacing_x != 1.0 or pixel_spacing_y != 1.0: + rays[:, :, 0] *= pixel_spacing_x + rays[:, :, 1] *= pixel_spacing_y + return 0 -def prepare_mesh(raster, rtx=None, mesh_type='tin'): +def prepare_mesh(raster, rtx=None, mesh_type='heightfield', + pixel_spacing_x=1.0, pixel_spacing_y=1.0): """Prepare a triangle mesh from raster data and build the RTX acceleration structure. This function handles the common pattern of: 1. Creating or reusing an RTX instance 2. Checking if the mesh needs rebuilding (via hash comparison) 3. Triangulating or voxelating the terrain - 4. Building the GAS (Geometry Acceleration Structure) + 4. Scaling X/Y to world coordinates using pixel_spacing + 5. Building the GAS (Geometry Acceleration Structure) Parameters ---------- @@ -100,6 +155,10 @@ def prepare_mesh(raster, rtx=None, mesh_type='tin'): Existing RTX instance to reuse. If None, a new instance is created. mesh_type : str, optional Mesh generation method: 'tin' or 'voxel'. Default is 'tin'. + pixel_spacing_x : float, optional + World-space spacing per pixel in X. Default 1.0. + pixel_spacing_y : float, optional + World-space spacing per pixel in Y. Default 1.0. Returns ------- @@ -111,7 +170,7 @@ def prepare_mesh(raster, rtx=None, mesh_type='tin'): ValueError If mesh generation or GAS building fails. """ - valid_types = ('tin', 'voxel') + valid_types = ('tin', 'voxel', 'heightfield') if mesh_type not in valid_types: raise ValueError( f"Invalid mesh_type '{mesh_type}'. Must be one of: {valid_types}" @@ -122,8 +181,26 @@ def prepare_mesh(raster, rtx=None, mesh_type='tin'): H, W = raster.shape - # Include mesh_type in hash so switching types triggers rebuild - datahash = np.uint64(hash(str(raster.data.get()) + mesh_type) % (1 << 64)) + if mesh_type == 'heightfield': + # Heightfield path: upload raw elevation grid, no triangle mesh + terrain_data = raster.data + if hasattr(terrain_data, 'get'): + elev_np = terrain_data.get().astype(np.float32) + else: + elev_np = np.asarray(terrain_data, dtype=np.float32) + + res = rtx.add_heightfield_geometry( + 'terrain', elev_np, H, W, + spacing_x=pixel_spacing_x, + spacing_y=pixel_spacing_y, + ) + if res: + raise ValueError(f"Failed to build heightfield GAS. Error code: {res}") + return rtx + + # Include mesh_type and pixel_spacing in hash so changes trigger rebuild + hash_str = str(raster.data.get()) + mesh_type + f'{pixel_spacing_x},{pixel_spacing_y}' + datahash = np.uint64(hash(hash_str) % (1 << 64)) optixhash = np.uint64(rtx.getHash()) if optixhash != datahash: @@ -146,6 +223,11 @@ def prepare_mesh(raster, rtx=None, mesh_type='tin'): if res: raise ValueError(f"Failed to generate mesh from terrain. Error code: {res}") + # Scale vertex X/Y from pixel indices to world coordinates + if pixel_spacing_x != 1.0 or pixel_spacing_y != 1.0: + verts[0::3] *= pixel_spacing_x + verts[1::3] *= pixel_spacing_y + res = rtx.build(datahash, verts, triangles) if res: raise ValueError(f"OptiX failed to build GAS with error code: {res}") diff --git a/rtxpy/analysis/hillshade.py b/rtxpy/analysis/hillshade.py index bffa410..a65657e 100644 --- a/rtxpy/analysis/hillshade.py +++ b/rtxpy/analysis/hillshade.py @@ -10,7 +10,7 @@ from typing import Optional from .._cuda_utils import calc_dims, add, mul, dot, float3, make_float3, invert -from ._common import generate_primary_rays, prepare_mesh +from ._common import generate_primary_rays, prepare_mesh, _compute_pixel_spacing from ..rtx import RTX, has_cupy if has_cupy: @@ -140,7 +140,8 @@ def _shade_lambert(hits, normals, output, H, W, sunDir, castShadows): return 0 -def _hillshade_rt(raster, optix, shadows, azimuth, angle_altitude, name): +def _hillshade_rt(raster, optix, shadows, azimuth, angle_altitude, name, + pixel_spacing_x=1.0, pixel_spacing_y=1.0): """Internal function to perform hillshade ray tracing.""" xr = _lazy_import_xarray() @@ -156,7 +157,9 @@ def _hillshade_rt(raster, optix, shadows, azimuth, angle_altitude, name): y_coords = cupy.array(raster.indexes.get('y').values) x_coords = cupy.array(raster.indexes.get('x').values) - generate_primary_rays(d_rays, x_coords, y_coords, H, W) + generate_primary_rays(d_rays, x_coords, y_coords, H, W, + pixel_spacing_x=pixel_spacing_x, + pixel_spacing_y=pixel_spacing_y) device = cupy.cuda.Device(0) device.synchronize() optix.trace(d_rays, d_hits, W * H) @@ -264,8 +267,10 @@ def hillshade(raster, "Additional overhead will be incurred from CPU-GPU transfers." ) - optix = prepare_mesh(raster, rtx) + psx, psy = _compute_pixel_spacing(raster) + optix = prepare_mesh(raster, rtx, pixel_spacing_x=psx, pixel_spacing_y=psy) return _hillshade_rt( raster, optix, shadows=shadows, azimuth=azimuth, - angle_altitude=angle_altitude, name=name + angle_altitude=angle_altitude, name=name, + pixel_spacing_x=psx, pixel_spacing_y=psy, ) diff --git a/rtxpy/analysis/render.py b/rtxpy/analysis/render.py index 8194a05..10f5400 100644 --- a/rtxpy/analysis/render.py +++ b/rtxpy/analysis/render.py @@ -152,25 +152,129 @@ def _normalize(v): return v +@cuda.jit(device=True) +def _compute_physical_sky(ray_dx, ray_dy, ray_dz, sun_dir): + """Compute physical sky color from a ray direction and sun position. + + Returns (r, g, b) tuple in linear HDR space (may exceed 1.0 near sun). + """ + # Elevation angle of ray (0 = horizon, 1 = zenith, negative = below) + ray_elev = ray_dz # z component of unit direction = sin(elevation) + if ray_elev < 0.0: + ray_elev = 0.0 + + # Zenith -> horizon interpolation with quadratic falloff + horizon_blend = 1.0 - ray_elev + horizon_blend = horizon_blend * horizon_blend + + # Sun glow near the sun direction + sun_dot = ray_dx * sun_dir[0] + ray_dy * sun_dir[1] + ray_dz * sun_dir[2] + if sun_dot < 0.0: + sun_dot = 0.0 + + # Broad halo around sun + sun_glow = sun_dot * sun_dot + sun_glow = sun_glow * sun_glow # ^4 + sun_glow = sun_glow * 0.4 + + # Sun altitude affects brightness and warmth + sun_elev = sun_dir[2] + if sun_elev < 0.0: + sun_elev = 0.0 + brightness = 0.5 + 0.5 * sun_elev + + # Zenith: deep blue + zen_r = 0.15 * brightness + zen_g = 0.25 * brightness + zen_b = 0.55 * brightness + + # Horizon: pale warm white (warmer at low sun angles) + warmth = 1.0 - sun_elev + hor_r = (0.70 + 0.20 * warmth) * brightness + hor_g = (0.75 + 0.05 * warmth) * brightness + hor_b = (0.85 - 0.15 * warmth) * brightness + + # Blend zenith -> horizon + sun glow + sr = zen_r * (1.0 - horizon_blend) + hor_r * horizon_blend + sun_glow * 1.0 + sg = zen_g * (1.0 - horizon_blend) + hor_g * horizon_blend + sun_glow * 0.9 + sb = zen_b * (1.0 - horizon_blend) + hor_b * horizon_blend + sun_glow * 0.6 + + return sr, sg, sb + + @cuda.jit -def _generate_perspective_rays_kernel(rays, width, height, camera_pos, forward, right, up, fov_scale): +def _generate_perspective_rays_kernel(rays, width, height, camera_pos, forward, right, up, + fov_scale, jitter_seed, aperture, focal_distance): """GPU kernel to generate perspective camera rays. Uses pinhole camera model: ray_dir = forward + u*right + v*up where u and v are in normalized device coordinates scaled by FOV. + When jitter_seed > 0, adds sub-pixel random offset for anti-aliasing. + When aperture > 0, applies thin-lens depth-of-field (requires jitter_seed > 0). """ px, py = cuda.grid(2) if px < width and py < height: - # Convert pixel to normalized device coordinates (-1 to 1) + idx = py * width + px aspect = width / height - u = (2.0 * (px + 0.5) / width - 1.0) * aspect * fov_scale - v = (1.0 - 2.0 * (py + 0.5) / height) * fov_scale - # Compute ray direction + if jitter_seed > 0: + # Hash-based RNG (same pattern as AO kernel) + h = np.uint32(idx * np.uint32(1337) + jitter_seed) + h = (h ^ (h >> np.uint32(16))) * np.uint32(2654435761) + h = (h ^ (h >> np.uint32(16))) * np.uint32(2246822519) + h = h ^ (h >> np.uint32(16)) + jx = float(h & np.uint32(0xFFFF)) / 65535.0 - 0.5 # [-0.5, 0.5] + h = (h * np.uint32(1103515245) + np.uint32(12345)) + h = h ^ (h >> np.uint32(16)) + jy = float(h & np.uint32(0xFFFF)) / 65535.0 - 0.5 + else: + jx = 0.0 + jy = 0.0 + + u = (2.0 * (px + 0.5 + jx) / width - 1.0) * aspect * fov_scale + v = (1.0 - 2.0 * (py + 0.5 + jy) / height) * fov_scale + + # Compute ray direction (unnormalized) dir_x = forward[0] + u * right[0] + v * up[0] dir_y = forward[1] + u * right[1] + v * up[1] dir_z = forward[2] + u * right[2] + v * up[2] + # Origin defaults to camera position + ox = camera_pos[0] + oy = camera_pos[1] + oz = camera_pos[2] + + # Thin-lens depth of field + if aperture > 0.0 and jitter_seed > 0: + # Focal point on the focal plane (perpendicular to forward) + fp_x = camera_pos[0] + focal_distance * dir_x + fp_y = camera_pos[1] + focal_distance * dir_y + fp_z = camera_pos[2] + focal_distance * dir_z + + # Two more random numbers for lens disk sampling + h = (h * np.uint32(1103515245) + np.uint32(12345)) + h = h ^ (h >> np.uint32(16)) + lr1 = float(h & np.uint32(0xFFFF)) / 65535.0 + h = (h * np.uint32(1103515245) + np.uint32(12345)) + h = h ^ (h >> np.uint32(16)) + lr2 = float(h & np.uint32(0xFFFF)) / 65535.0 + + # Uniform disk sampling + lens_r = aperture * math.sqrt(lr1) + lens_phi = 2.0 * math.pi * lr2 + lens_dx = lens_r * math.cos(lens_phi) + lens_dy = lens_r * math.sin(lens_phi) + + # Offset origin on lens disk (in camera right/up plane) + ox += lens_dx * right[0] + lens_dy * up[0] + oy += lens_dx * right[1] + lens_dy * up[1] + oz += lens_dx * right[2] + lens_dy * up[2] + + # New direction: from offset origin to focal point + dir_x = fp_x - ox + dir_y = fp_y - oy + dir_z = fp_z - oz + # Normalize direction length = math.sqrt(dir_x * dir_x + dir_y * dir_y + dir_z * dir_z) dir_x /= length @@ -178,10 +282,9 @@ def _generate_perspective_rays_kernel(rays, width, height, camera_pos, forward, dir_z /= length # Store ray (origin + direction) - idx = py * width + px - rays[idx, 0] = camera_pos[0] - rays[idx, 1] = camera_pos[1] - rays[idx, 2] = camera_pos[2] + rays[idx, 0] = ox + rays[idx, 1] = oy + rays[idx, 2] = oz rays[idx, 3] = 1e-3 # t_min rays[idx, 4] = dir_x rays[idx, 5] = dir_y @@ -189,7 +292,8 @@ def _generate_perspective_rays_kernel(rays, width, height, camera_pos, forward, rays[idx, 7] = np.inf # t_max -def _generate_perspective_rays(rays, width, height, camera_pos, forward, right, up, fov): +def _generate_perspective_rays(rays, width, height, camera_pos, forward, right, up, fov, + jitter_seed=np.uint32(0), aperture=0.0, focal_distance=0.0): """Generate perspective camera rays. Parameters @@ -210,6 +314,13 @@ def _generate_perspective_rays(rays, width, height, camera_pos, forward, right, Camera up vector (3,). fov : float Vertical field of view in degrees. + jitter_seed : np.uint32, optional + When > 0, adds sub-pixel jitter for anti-aliasing. Default is 0 (no jitter). + aperture : float, optional + Lens aperture radius for depth of field. 0 disables DOF. Default is 0.0. + focal_distance : float, optional + Distance to the focal plane. Objects at this distance are sharp. + Default is 0.0 (no DOF). """ fov_scale = math.tan(math.radians(fov) / 2.0) @@ -219,13 +330,19 @@ def _generate_perspective_rays(rays, width, height, camera_pos, forward, right, blockspergrid = (blockspergrid_x, blockspergrid_y) _generate_perspective_rays_kernel[blockspergrid, threadsperblock]( - rays, width, height, camera_pos, forward, right, up, fov_scale + rays, width, height, camera_pos, forward, right, up, fov_scale, + jitter_seed, np.float32(aperture), np.float32(focal_distance) ) @cuda.jit -def _generate_shadow_rays_from_hits_kernel(shadow_rays, primary_rays, hits, num_rays, sun_dir): - """GPU kernel to generate shadow rays from primary hit points toward the sun.""" +def _generate_shadow_rays_from_hits_kernel(shadow_rays, primary_rays, hits, num_rays, sun_dir, + sun_angle_rad, shadow_seed): + """GPU kernel to generate shadow rays from primary hit points toward the sun. + + When shadow_seed > 0 and sun_angle_rad > 0, jitters the shadow ray direction + within a cone around sun_dir for soft shadows (finite-size light source). + """ idx = cuda.grid(1) if idx < num_rays: t = hits[idx, 0] @@ -262,13 +379,65 @@ def _generate_shadow_rays_from_hits_kernel(shadow_rays, primary_rays, hits, num_ origin_y = hit_y + ny * offset origin_z = hit_z + nz * offset + # Compute shadow direction (possibly jittered for soft shadows) + if shadow_seed > 0 and sun_angle_rad > 0.0: + # Hash-based RNG (same pattern as AO kernel) + h = np.uint32(idx * np.uint32(2719) + shadow_seed) + h = (h ^ (h >> np.uint32(16))) * np.uint32(2654435761) + h = (h ^ (h >> np.uint32(16))) * np.uint32(2246822519) + h = h ^ (h >> np.uint32(16)) + r1 = float(h & np.uint32(0xFFFF)) / 65535.0 + h = (h * np.uint32(1103515245) + np.uint32(12345)) + h = h ^ (h >> np.uint32(16)) + r2 = float(h & np.uint32(0xFFFF)) / 65535.0 + + # Uniform disk -> cone deflection + cone_r = sun_angle_rad * math.sqrt(r1) + cone_phi = 2.0 * math.pi * r2 + dx_local = cone_r * math.cos(cone_phi) + dy_local = cone_r * math.sin(cone_phi) + + # Build tangent frame from sun_dir + sx = sun_dir[0] + sy = sun_dir[1] + sz = sun_dir[2] + if abs(sx) < 0.9: + tx = 0.0 + ty = -sz + tz = sy + else: + tx = sz + ty = 0.0 + tz = -sx + t_len = math.sqrt(tx * tx + ty * ty + tz * tz) + if t_len > 1e-8: + tx /= t_len + ty /= t_len + tz /= t_len + bx = sy * tz - sz * ty + by = sz * tx - sx * tz + bz = sx * ty - sy * tx + + # Perturbed direction + sdx = sx + dx_local * tx + dy_local * bx + sdy = sy + dx_local * ty + dy_local * by + sdz = sz + dx_local * tz + dy_local * bz + s_len = math.sqrt(sdx * sdx + sdy * sdy + sdz * sdz) + sdx /= s_len + sdy /= s_len + sdz /= s_len + else: + sdx = sun_dir[0] + sdy = sun_dir[1] + sdz = sun_dir[2] + shadow_rays[idx, 0] = origin_x shadow_rays[idx, 1] = origin_y shadow_rays[idx, 2] = origin_z shadow_rays[idx, 3] = 1e-3 # t_min - shadow_rays[idx, 4] = sun_dir[0] - shadow_rays[idx, 5] = sun_dir[1] - shadow_rays[idx, 6] = sun_dir[2] + shadow_rays[idx, 4] = sdx + shadow_rays[idx, 5] = sdy + shadow_rays[idx, 6] = sdz shadow_rays[idx, 7] = np.inf # t_max else: # No hit - shadow ray should not trace @@ -282,19 +451,367 @@ def _generate_shadow_rays_from_hits_kernel(shadow_rays, primary_rays, hits, num_ shadow_rays[idx, 7] = 0 # t_max = 0 means no trace -def _generate_shadow_rays_from_hits(shadow_rays, primary_rays, hits, num_rays, sun_dir): +def _generate_shadow_rays_from_hits(shadow_rays, primary_rays, hits, num_rays, sun_dir, + sun_angle_rad=0.0, shadow_seed=np.uint32(0)): """Generate shadow rays from primary ray hit points toward the sun.""" threadsperblock = 256 blockspergrid = (num_rays + threadsperblock - 1) // threadsperblock _generate_shadow_rays_from_hits_kernel[blockspergrid, threadsperblock]( - shadow_rays, primary_rays, hits, num_rays, sun_dir + shadow_rays, primary_rays, hits, num_rays, sun_dir, sun_angle_rad, shadow_seed + ) + + +@cuda.jit +def _generate_ao_rays_kernel(ao_rays, primary_rays, hits, num_rays, ao_radius, seed): + """GPU kernel to generate ambient occlusion rays from primary hit points. + + For each pixel with a hit, generates a cosine-weighted random direction + on the hemisphere around the surface normal, with tmax limited to ao_radius. + Uses a simple hash-based RNG seeded per pixel. + """ + idx = cuda.grid(1) + if idx < num_rays: + t = hits[idx, 0] + + if t > 0: + # Get normal at hit point + nx = hits[idx, 1] + ny = hits[idx, 2] + nz = hits[idx, 3] + + # Flip normal if facing away from ray + ray_dx = primary_rays[idx, 4] + ray_dy = primary_rays[idx, 5] + ray_dz = primary_rays[idx, 6] + + dot_nd = nx * ray_dx + ny * ray_dy + nz * ray_dz + if dot_nd > 0: + nx = -nx + ny = -ny + nz = -nz + + # Compute hit point + ox = primary_rays[idx, 0] + oy = primary_rays[idx, 1] + oz = primary_rays[idx, 2] + + hit_x = ox + t * ray_dx + hit_y = oy + t * ray_dy + hit_z = oz + t * ray_dz + + # Offset along normal to avoid self-intersection + offset = 1e-3 + origin_x = hit_x + nx * offset + origin_y = hit_y + ny * offset + origin_z = hit_z + nz * offset + + # Hash-based RNG: two uniform randoms from pixel index + seed + h = np.uint32(idx * np.uint32(1337) + seed) + h = (h ^ (h >> np.uint32(16))) * np.uint32(2654435761) + h = (h ^ (h >> np.uint32(16))) * np.uint32(2246822519) + h = h ^ (h >> np.uint32(16)) + r1 = float(h & np.uint32(0xFFFF)) / 65535.0 + + h = (h * np.uint32(1103515245) + np.uint32(12345)) + h = h ^ (h >> np.uint32(16)) + r2 = float(h & np.uint32(0xFFFF)) / 65535.0 + + # Cosine-weighted hemisphere sample in local coords + # r1 = cos^2(theta), so cos_theta = sqrt(r1) + cos_theta = math.sqrt(r1) + sin_theta = math.sqrt(1.0 - r1) + phi = 2.0 * math.pi * r2 + local_x = sin_theta * math.cos(phi) + local_y = sin_theta * math.sin(phi) + local_z = cos_theta + + # Build tangent frame from normal + # Choose a vector not parallel to normal + if abs(nx) < 0.9: + tx = 0.0 + ty = -nz + tz = ny + else: + tx = nz + ty = 0.0 + tz = -nx + # Normalize tangent + t_len = math.sqrt(tx * tx + ty * ty + tz * tz) + if t_len > 1e-8: + tx /= t_len + ty /= t_len + tz /= t_len + + # Bitangent = normal x tangent + bx = ny * tz - nz * ty + by = nz * tx - nx * tz + bz = nx * ty - ny * tx + + # Transform local -> world + dir_x = local_x * tx + local_y * bx + local_z * nx + dir_y = local_x * ty + local_y * by + local_z * ny + dir_z = local_x * tz + local_y * bz + local_z * nz + + # Normalize direction (should be unit already but be safe) + d_len = math.sqrt(dir_x * dir_x + dir_y * dir_y + dir_z * dir_z) + if d_len > 1e-8: + dir_x /= d_len + dir_y /= d_len + dir_z /= d_len + + ao_rays[idx, 0] = origin_x + ao_rays[idx, 1] = origin_y + ao_rays[idx, 2] = origin_z + ao_rays[idx, 3] = 1e-3 # t_min + ao_rays[idx, 4] = dir_x + ao_rays[idx, 5] = dir_y + ao_rays[idx, 6] = dir_z + ao_rays[idx, 7] = ao_radius # t_max + else: + # No hit - AO ray should not trace + ao_rays[idx, 0] = 0 + ao_rays[idx, 1] = 0 + ao_rays[idx, 2] = 0 + ao_rays[idx, 3] = 0 + ao_rays[idx, 4] = 0 + ao_rays[idx, 5] = 0 + ao_rays[idx, 6] = 1 + ao_rays[idx, 7] = 0 # t_max = 0 means no trace + + +def _generate_ao_rays(ao_rays, primary_rays, hits, num_rays, ao_radius, seed): + """Generate ambient occlusion rays from primary ray hit points.""" + threadsperblock = 256 + blockspergrid = (num_rays + threadsperblock - 1) // threadsperblock + + _generate_ao_rays_kernel[blockspergrid, threadsperblock]( + ao_rays, primary_rays, hits, num_rays, ao_radius, np.uint32(seed) + ) + + +@cuda.jit +def _accumulate_ao_kernel(ao_factor, ao_hits, num_rays, ao_samples): + """GPU kernel to accumulate AO results: subtract 1/ao_samples for each hit.""" + idx = cuda.grid(1) + if idx < num_rays: + t = ao_hits[idx, 0] + if t > 0: + ao_factor[idx] -= 1.0 / ao_samples + + +def _accumulate_ao(ao_factor, ao_hits, num_rays, ao_samples): + """Accumulate AO hit results into the ao_factor buffer.""" + threadsperblock = 256 + blockspergrid = (num_rays + threadsperblock - 1) // threadsperblock + + _accumulate_ao_kernel[blockspergrid, threadsperblock]( + ao_factor, ao_hits, num_rays, ao_samples + ) + + +@cuda.jit +def _accumulate_gi_kernel(gi_color, ao_rays, ao_hits, num_rays, ao_samples, + vertical_exaggeration, elev_min, elev_range, + color_lut, color_stretch, gi_throughput): + """GPU kernel to accumulate multi-bounce diffuse GI from AO hit points. + + For each AO ray that hit a surface, looks up the surface color at the hit + point (via elevation -> colormap LUT) and accumulates it into gi_color, + weighted by the current path throughput. Then updates throughput by + multiplying with the hit surface albedo (Lambertian BRDF). + + Bounce 0 with throughput=(1,1,1) is identical to single-bounce GI. + Subsequent bounces are naturally attenuated by accumulated albedo product. + """ + idx = cuda.grid(1) + if idx < num_rays: + t = ao_hits[idx, 0] + if t > 0: + # Hit position Z -> elevation + hit_z = ao_rays[idx, 2] + t * ao_rays[idx, 6] + elevation = hit_z / vertical_exaggeration + + # Normalize elevation to [0, 1] for colormap lookup + if elev_range > 0: + elev_norm = (elevation - elev_min) / elev_range + else: + elev_norm = 0.5 + + if elev_norm < 0.0: + elev_norm = 0.0 + elif elev_norm > 1.0: + elev_norm = 1.0 + + # Apply nonlinear stretch: 0=linear, 1=cbrt, 2=log, 3=sqrt + if color_stretch == 1: + elev_norm = math.pow(elev_norm, 1.0 / 3.0) + elif color_stretch == 2: + elev_norm = math.log(1.0 + elev_norm * 9.0) / math.log(10.0) + elif color_stretch == 3: + elev_norm = math.sqrt(elev_norm) + + # Color lookup + lut_idx = int(elev_norm * 255) + if lut_idx > 255: + lut_idx = 255 + if lut_idx < 0: + lut_idx = 0 + + hit_r = color_lut[lut_idx, 0] + hit_g = color_lut[lut_idx, 1] + hit_b = color_lut[lut_idx, 2] + + # Accumulate weighted by path throughput + gi_color[idx, 0] += gi_throughput[idx, 0] * hit_r / ao_samples + gi_color[idx, 1] += gi_throughput[idx, 1] * hit_g / ao_samples + gi_color[idx, 2] += gi_throughput[idx, 2] * hit_b / ao_samples + + # Update throughput: Lambertian BRDF = albedo/pi, + # cosine-weighted pdf = cos/pi -> cancel to albedo + gi_throughput[idx, 0] *= hit_r + gi_throughput[idx, 1] *= hit_g + gi_throughput[idx, 2] *= hit_b + + +def _accumulate_gi(gi_color, ao_rays, ao_hits, num_rays, ao_samples, + vertical_exaggeration, elev_min, elev_range, + color_lut, color_stretch, gi_throughput): + """Accumulate multi-bounce diffuse GI from AO hit points.""" + threadsperblock = 256 + blockspergrid = (num_rays + threadsperblock - 1) // threadsperblock + + _accumulate_gi_kernel[blockspergrid, threadsperblock]( + gi_color, ao_rays, ao_hits, num_rays, ao_samples, + np.float32(vertical_exaggeration), np.float32(elev_min), + np.float32(elev_range), color_lut, np.int32(color_stretch), + gi_throughput + ) + + +@cuda.jit +def _generate_reflection_rays_kernel(reflection_rays, primary_rays, primary_hits, + instance_ids, geometry_colors, num_rays, + elevation_data, pixel_spacing_x, pixel_spacing_y): + """GPU kernel to generate reflection rays for water surfaces. + + For water pixels (geometry_colors alpha >= 2.0) and NaN ocean terrain, + computes mirror reflection direction R = D - 2(D·N)N from the primary hit + point. Non-water pixels get t_max = 0 (no trace). + """ + idx = cuda.grid(1) + if idx < num_rays: + t = primary_hits[idx, 0] + + # Default: no trace + is_water = False + is_ocean = False + if t > 0: + inst_id = instance_ids[idx] + if inst_id >= 0 and inst_id < geometry_colors.shape[0]: + gc_alpha = geometry_colors[inst_id, 3] + if gc_alpha >= 2.0: + is_water = True + + # Check if terrain hit is NaN ocean + if not is_water: + ray_dx = primary_rays[idx, 4] + ray_dy = primary_rays[idx, 5] + ray_dz = primary_rays[idx, 6] + ox = primary_rays[idx, 0] + oy = primary_rays[idx, 1] + hit_x = ox + t * ray_dx + hit_y = oy + t * ray_dy + ey = int(hit_y / pixel_spacing_y + 0.5) + ex = int(hit_x / pixel_spacing_x + 0.5) + if 0 <= ey < elevation_data.shape[0] and 0 <= ex < elevation_data.shape[1]: + if math.isnan(elevation_data[ey, ex]): + is_water = True + is_ocean = True + + if is_water: + # Get normal at hit point + nx = primary_hits[idx, 1] + ny = primary_hits[idx, 2] + nz = primary_hits[idx, 3] + + # Ocean: force flat water normal + if is_ocean: + nx = 0.0 + ny = 0.0 + nz = 1.0 + + # Flip normal if facing away from ray + ray_dx = primary_rays[idx, 4] + ray_dy = primary_rays[idx, 5] + ray_dz = primary_rays[idx, 6] + + dot_nd = nx * ray_dx + ny * ray_dy + nz * ray_dz + if dot_nd > 0: + nx = -nx + ny = -ny + nz = -nz + dot_nd = -dot_nd + + # Compute hit point + ox = primary_rays[idx, 0] + oy = primary_rays[idx, 1] + oz = primary_rays[idx, 2] + + hit_x = ox + t * ray_dx + hit_y = oy + t * ray_dy + hit_z = oz + t * ray_dz + + # Reflection direction: R = D - 2(D·N)N + ref_dx = ray_dx - 2.0 * dot_nd * nx + ref_dy = ray_dy - 2.0 * dot_nd * ny + ref_dz = ray_dz - 2.0 * dot_nd * nz + + # Normalize + r_len = math.sqrt(ref_dx * ref_dx + ref_dy * ref_dy + ref_dz * ref_dz) + if r_len > 1e-8: + ref_dx /= r_len + ref_dy /= r_len + ref_dz /= r_len + + # Offset origin along normal to avoid self-intersection + offset = 1e-2 + reflection_rays[idx, 0] = hit_x + nx * offset + reflection_rays[idx, 1] = hit_y + ny * offset + reflection_rays[idx, 2] = hit_z + nz * offset + reflection_rays[idx, 3] = 1e-3 # t_min + reflection_rays[idx, 4] = ref_dx + reflection_rays[idx, 5] = ref_dy + reflection_rays[idx, 6] = ref_dz + reflection_rays[idx, 7] = np.inf # t_max + else: + # Not water — no trace needed + reflection_rays[idx, 0] = 0.0 + reflection_rays[idx, 1] = 0.0 + reflection_rays[idx, 2] = 0.0 + reflection_rays[idx, 3] = 0.0 + reflection_rays[idx, 4] = 0.0 + reflection_rays[idx, 5] = 0.0 + reflection_rays[idx, 6] = 1.0 + reflection_rays[idx, 7] = 0.0 # t_max = 0 -> no trace + + +def _generate_reflection_rays(reflection_rays, primary_rays, primary_hits, + instance_ids, geometry_colors, num_rays, + elevation_data, pixel_spacing_x, pixel_spacing_y): + """Generate reflection rays for water surfaces and NaN ocean terrain.""" + threadsperblock = 256 + blockspergrid = (num_rays + threadsperblock - 1) // threadsperblock + _generate_reflection_rays_kernel[blockspergrid, threadsperblock]( + reflection_rays, primary_rays, primary_hits, + instance_ids, geometry_colors, num_rays, + elevation_data, np.float32(pixel_spacing_x), np.float32(pixel_spacing_y) ) @cuda.jit def _shade_terrain_kernel( - output, primary_rays, primary_hits, shadow_hits, + output, albedo_out, primary_rays, primary_hits, shadow_hits, elevation_data, color_lut, num_rays, width, height, sun_dir, ambient, cast_shadows, fog_density, fog_color_r, fog_color_g, fog_color_b, @@ -306,7 +823,9 @@ def _shade_terrain_kernel( color_stretch, rgb_texture, overlay_data, overlay_alpha, overlay_min, overlay_range, - instance_ids, geometry_colors + instance_ids, geometry_colors, + ao_factor, gi_color, gi_intensity, + reflection_hits, reflection_rays ): """GPU kernel for terrain shading with lighting, shadows, fog, colormapping, and viewshed.""" idx = cuda.grid(1) @@ -373,90 +892,113 @@ def _shade_terrain_kernel( elev_h = elevation_data.shape[0] elev_w = elevation_data.shape[1] - # RGB texture mode: real texture has shape > 1, dummy is (1,1,3) - tex_h = rgb_texture.shape[0] - tex_w = rgb_texture.shape[1] - - if tex_h > 1: - # Sample RGB directly from tile texture - if elev_y >= 0 and elev_y < tex_h and elev_x >= 0 and elev_x < tex_w: - base_r = rgb_texture[elev_y, elev_x, 0] - base_g = rgb_texture[elev_y, elev_x, 1] - base_b = rgb_texture[elev_y, elev_x, 2] - else: - base_r = 0.3 - base_g = 0.3 - base_b = 0.3 + # Check for NaN ocean terrain + is_nan_ocean = False + if 0 <= elev_y < elev_h and 0 <= elev_x < elev_w: + if math.isnan(elevation_data[elev_y, elev_x]): + is_nan_ocean = True + + if is_nan_ocean: + # Ocean water — deep blue base, flat normal, water shader + is_water = True + water_specular = 0.12 + base_r = 0.06 + base_g = 0.12 + base_b = 0.22 + nx = 0.0 + ny = 0.0 + nz = 1.0 else: - if elev_y >= 0 and elev_y < elev_h and elev_x >= 0 and elev_x < elev_w: - elevation = elevation_data[elev_y, elev_x] + # RGB texture mode: real texture has shape > 1, dummy is (1,1,3) + tex_h = rgb_texture.shape[0] + tex_w = rgb_texture.shape[1] + + if tex_h > 1: + # Sample RGB directly from tile texture + if elev_y >= 0 and elev_y < tex_h and elev_x >= 0 and elev_x < tex_w: + base_r = rgb_texture[elev_y, elev_x, 0] + base_g = rgb_texture[elev_y, elev_x, 1] + base_b = rgb_texture[elev_y, elev_x, 2] + else: + base_r = 0.3 + base_g = 0.3 + base_b = 0.3 else: - elevation = hit_z + if elev_y >= 0 and elev_y < elev_h and elev_x >= 0 and elev_x < elev_w: + elevation = elevation_data[elev_y, elev_x] + else: + elevation = hit_z - # Normalize elevation to [0, 1] for colormap lookup - if elev_range > 0: - elev_norm = (elevation - elev_min) / elev_range - else: - elev_norm = 0.5 - - if elev_norm < 0: - elev_norm = 0.0 - elif elev_norm > 1: - elev_norm = 1.0 - - # Apply nonlinear stretch: 0=linear, 1=cbrt, 2=log, 3=sqrt - if color_stretch == 1: - elev_norm = math.pow(elev_norm, 1.0 / 3.0) - elif color_stretch == 2: - elev_norm = math.log(1.0 + elev_norm * 9.0) / math.log(10.0) - elif color_stretch == 3: - elev_norm = math.sqrt(elev_norm) - - # Color lookup - lut_idx = int(elev_norm * 255) - if lut_idx > 255: - lut_idx = 255 - if lut_idx < 0: - lut_idx = 0 - - base_r = color_lut[lut_idx, 0] - base_g = color_lut[lut_idx, 1] - base_b = color_lut[lut_idx, 2] - - # Overlay blending: transparent scalar layer on top of base - ov_h = overlay_data.shape[0] - ov_w = overlay_data.shape[1] - if ov_h > 1 and overlay_alpha > 0.0: - if elev_y >= 0 and elev_y < ov_h and elev_x >= 0 and elev_x < ov_w: - ov_val = overlay_data[elev_y, elev_x] - if not math.isnan(ov_val): - if overlay_range > 0: - ov_norm = (ov_val - overlay_min) / overlay_range - else: - ov_norm = 0.5 - if ov_norm < 0: - ov_norm = 0.0 - elif ov_norm > 1: - ov_norm = 1.0 - # Apply same color stretch - if color_stretch == 1: - ov_norm = math.pow(ov_norm, 1.0 / 3.0) - elif color_stretch == 2: - ov_norm = math.log(1.0 + ov_norm * 9.0) / math.log(10.0) - elif color_stretch == 3: - ov_norm = math.sqrt(ov_norm) - ov_idx = int(ov_norm * 255) - if ov_idx > 255: - ov_idx = 255 - if ov_idx < 0: - ov_idx = 0 - ov_r = color_lut[ov_idx, 0] - ov_g = color_lut[ov_idx, 1] - ov_b = color_lut[ov_idx, 2] - a = overlay_alpha - base_r = base_r * (1.0 - a) + ov_r * a - base_g = base_g * (1.0 - a) + ov_g * a - base_b = base_b * (1.0 - a) + ov_b * a + # Normalize elevation to [0, 1] for colormap lookup + if elev_range > 0: + elev_norm = (elevation - elev_min) / elev_range + else: + elev_norm = 0.5 + + if elev_norm < 0: + elev_norm = 0.0 + elif elev_norm > 1: + elev_norm = 1.0 + + # Apply nonlinear stretch: 0=linear, 1=cbrt, 2=log, 3=sqrt + if color_stretch == 1: + elev_norm = math.pow(elev_norm, 1.0 / 3.0) + elif color_stretch == 2: + elev_norm = math.log(1.0 + elev_norm * 9.0) / math.log(10.0) + elif color_stretch == 3: + elev_norm = math.sqrt(elev_norm) + + # Color lookup + lut_idx = int(elev_norm * 255) + if lut_idx > 255: + lut_idx = 255 + if lut_idx < 0: + lut_idx = 0 + + base_r = color_lut[lut_idx, 0] + base_g = color_lut[lut_idx, 1] + base_b = color_lut[lut_idx, 2] + + # Overlay blending: transparent scalar layer on top of base + ov_h = overlay_data.shape[0] + ov_w = overlay_data.shape[1] + if ov_h > 1 and overlay_alpha > 0.0: + if elev_y >= 0 and elev_y < ov_h and elev_x >= 0 and elev_x < ov_w: + ov_val = overlay_data[elev_y, elev_x] + if not math.isnan(ov_val): + if overlay_range > 0: + ov_norm = (ov_val - overlay_min) / overlay_range + else: + ov_norm = 0.5 + if ov_norm < 0: + ov_norm = 0.0 + elif ov_norm > 1: + ov_norm = 1.0 + # Apply same color stretch + if color_stretch == 1: + ov_norm = math.pow(ov_norm, 1.0 / 3.0) + elif color_stretch == 2: + ov_norm = math.log(1.0 + ov_norm * 9.0) / math.log(10.0) + elif color_stretch == 3: + ov_norm = math.sqrt(ov_norm) + ov_idx = int(ov_norm * 255) + if ov_idx > 255: + ov_idx = 255 + if ov_idx < 0: + ov_idx = 0 + ov_r = color_lut[ov_idx, 0] + ov_g = color_lut[ov_idx, 1] + ov_b = color_lut[ov_idx, 2] + a = overlay_alpha + base_r = base_r * (1.0 - a) + ov_r * a + base_g = base_g * (1.0 - a) + ov_g * a + base_b = base_b * (1.0 - a) + ov_b * a + + # Write albedo (material color before lighting) for denoiser guide + if albedo_out.shape[0] > 1: + albedo_out[py, px, 0] = base_r + albedo_out[py, px, 1] = base_g + albedo_out[py, px, 2] = base_b # Lambertian shading cos_theta = nx * sun_dir[0] + ny * sun_dir[1] + nz * sun_dir[2] @@ -473,21 +1015,117 @@ def _shade_terrain_kernel( # Final lighting diffuse = cos_theta * shadow_factor lighting = ambient + (1.0 - ambient) * diffuse + # Apply ambient occlusion + lighting *= ao_factor[idx] # Emissive glow: raise the lighting floor if emissive > 0.0: if lighting < emissive: lighting = emissive - color_r = base_r * lighting - color_g = base_g * lighting - color_b = base_b * lighting + color_r = base_r * lighting + base_r * gi_color[idx, 0] * gi_intensity + color_g = base_g * lighting + base_g * gi_color[idx, 1] * gi_intensity + color_b = base_b * lighting + base_b * gi_color[idx, 2] * gi_intensity - # Water shader: specular highlight + Fresnel rim + # Water shader: reflections + specular highlight + Fresnel if is_water: - # Blinn-Phong specular: H = normalize(L + V) + # Procedural wave normals for shimmer + wx = hit_x * 0.5 + wy = hit_y * 0.5 + h1 = (math.sin(wx * 1.1 + wy * 0.7) * 0.4 + + math.sin(wx * 2.3 - wy * 1.9) * 0.3) + h2 = (math.sin(wx * 0.8 - wy * 1.3) * 0.4 + + math.sin(wx * 1.7 + wy * 2.1) * 0.3) + wave_strength = 0.015 + nx += h1 * wave_strength + ny += h2 * wave_strength + n_len = math.sqrt(nx * nx + ny * ny + nz * nz) + nx /= n_len + ny /= n_len + nz /= n_len + + # View direction vx = -ray_dx vy = -ray_dy vz = -ray_dz + + # Fresnel: more reflective at grazing angles + n_dot_v = abs(nx * vx + ny * vy + nz * vz) + fresnel = 0.3 + 0.7 * (1.0 - n_dot_v) + + # Compute reflection color from traced reflection rays + refl_t = reflection_hits[idx, 0] + if refl_t > 0: + # Reflection hit terrain — shade with simple colormap + diffuse + refl_hx = reflection_rays[idx, 0] + refl_t * reflection_rays[idx, 4] + refl_hy = reflection_rays[idx, 1] + refl_t * reflection_rays[idx, 5] + refl_hz = reflection_rays[idx, 2] + refl_t * reflection_rays[idx, 6] + + # Look up elevation at reflected hit point + refl_ey = int(refl_hy / pixel_spacing_y + 0.5) + refl_ex = int(refl_hx / pixel_spacing_x + 0.5) + elev_h = elevation_data.shape[0] + elev_w = elevation_data.shape[1] + + # Check for RGB texture first + tex_h = rgb_texture.shape[0] + tex_w = rgb_texture.shape[1] + + if tex_h > 1 and refl_ey >= 0 and refl_ey < tex_h and refl_ex >= 0 and refl_ex < tex_w: + refl_r = rgb_texture[refl_ey, refl_ex, 0] + refl_g = rgb_texture[refl_ey, refl_ex, 1] + refl_b = rgb_texture[refl_ey, refl_ex, 2] + elif refl_ey >= 0 and refl_ey < elev_h and refl_ex >= 0 and refl_ex < elev_w: + refl_elev = elevation_data[refl_ey, refl_ex] + if elev_range > 0: + refl_norm = (refl_elev - elev_min) / elev_range + else: + refl_norm = 0.5 + if refl_norm < 0.0: + refl_norm = 0.0 + elif refl_norm > 1.0: + refl_norm = 1.0 + refl_lut = int(refl_norm * 255) + if refl_lut > 255: + refl_lut = 255 + if refl_lut < 0: + refl_lut = 0 + refl_r = color_lut[refl_lut, 0] + refl_g = color_lut[refl_lut, 1] + refl_b = color_lut[refl_lut, 2] + else: + refl_r = 0.3 + refl_g = 0.3 + refl_b = 0.3 + + # Simple diffuse lighting on reflected surface + refl_nx = reflection_hits[idx, 1] + refl_ny = reflection_hits[idx, 2] + refl_nz = reflection_hits[idx, 3] + refl_cos = refl_nx * sun_dir[0] + refl_ny * sun_dir[1] + refl_nz * sun_dir[2] + if refl_cos < 0.0: + refl_cos = -refl_cos + refl_light = ambient + (1.0 - ambient) * refl_cos + refl_r *= refl_light + refl_g *= refl_light + refl_b *= refl_light + else: + # Reflection miss -> sky + ref_dx = reflection_rays[idx, 4] + ref_dy = reflection_rays[idx, 5] + ref_dz = reflection_rays[idx, 6] + if sky_color_r < 0: + refl_r, refl_g, refl_b = _compute_physical_sky(ref_dx, ref_dy, ref_dz, sun_dir) + else: + refl_r = sky_color_r + refl_g = sky_color_g + refl_b = sky_color_b + + # Blend base water color with reflection using Fresnel + color_r = color_r * (1.0 - fresnel) + refl_r * fresnel + color_g = color_g * (1.0 - fresnel) + refl_g * fresnel + color_b = color_b * (1.0 - fresnel) + refl_b * fresnel + + # Blinn-Phong specular: H = normalize(L + V) hx = sun_dir[0] + vx hy = sun_dir[1] + vy hz = sun_dir[2] + vz @@ -508,14 +1146,10 @@ def _shade_terrain_kernel( spec = spec * spec # ^64 spec *= water_specular * shadow_factor - # Fresnel-like darkening at steep view angles - n_dot_v = abs(nx * vx + ny * vy + nz * vz) - fresnel = 0.3 + 0.7 * (1.0 - n_dot_v) - - # Darken base color at steep angles, add specular - color_r = color_r * (0.7 + 0.3 * fresnel) + spec - color_g = color_g * (0.7 + 0.3 * fresnel) + spec - color_b = color_b * (0.7 + 0.3 * fresnel) + spec * 0.9 + # Add specular on top + color_r += spec + color_g += spec + color_b += spec * 0.9 # Observer marker removed — drone mesh is placed as scene geometry @@ -540,14 +1174,6 @@ def _shade_terrain_kernel( color_g = color_g * (1.0 - alpha) + 0.9 * alpha color_b = color_b * (1.0 - alpha) + 0.85 * alpha - # Clamp - if color_r > 1.0: - color_r = 1.0 - if color_g > 1.0: - color_g = 1.0 - if color_b > 1.0: - color_b = 1.0 - # Fog if fog_density > 0: fog_amount = 1.0 - math.exp(-fog_density * t) @@ -561,18 +1187,276 @@ def _shade_terrain_kernel( if alpha_channel: output[py, px, 3] = 1.0 else: + # Miss - black albedo (sky has no material) + if albedo_out.shape[0] > 1: + albedo_out[py, px, 0] = 0.0 + albedo_out[py, px, 1] = 0.0 + albedo_out[py, px, 2] = 0.0 # Miss - sky color - output[py, px, 0] = sky_color_r - output[py, px, 1] = sky_color_g - output[py, px, 2] = sky_color_b + if sky_color_r < 0: + # Physical sky via shared device function + ray_dx = primary_rays[idx, 4] + ray_dy = primary_rays[idx, 5] + ray_dz = primary_rays[idx, 6] + + sr, sg, sb = _compute_physical_sky(ray_dx, ray_dy, ray_dz, sun_dir) + if sr > 1.0: + sr = 1.0 + if sg > 1.0: + sg = 1.0 + if sb > 1.0: + sb = 1.0 + + output[py, px, 0] = sr + output[py, px, 1] = sg + output[py, px, 2] = sb + else: + output[py, px, 0] = sky_color_r + output[py, px, 1] = sky_color_g + output[py, px, 2] = sky_color_b if alpha_channel: output[py, px, 3] = 0.0 +@cuda.jit +def _tone_map_aces_kernel(output, height, width, num_channels): + """Apply ACES filmic tone mapping in-place (Stephen Hill approximation).""" + idx = cuda.grid(1) + if idx < height * width: + py = idx // width + px = idx % width + for c in range(num_channels): + x = output[py, px, c] + # ACES filmic: (x*(2.51*x+0.03)) / (x*(2.43*x+0.59)+0.14) + output[py, px, c] = (x * (2.51 * x + 0.03)) / (x * (2.43 * x + 0.59) + 0.14) + + +def _tone_map_aces(output): + """Apply ACES filmic tone mapping to GPU output buffer in-place.""" + height, width, num_channels = output.shape + num_pixels = height * width + threadsperblock = 256 + blockspergrid = (num_pixels + threadsperblock - 1) // threadsperblock + _tone_map_aces_kernel[blockspergrid, threadsperblock](output, height, width, num_channels) + + +@cuda.jit +def _edge_outline_kernel(output, instance_ids, height, width, + edge_strength, edge_r, edge_g, edge_b): + """Darken pixels at boundaries between different instance_ids.""" + idx = cuda.grid(1) + if idx >= height * width: + return + my_id = instance_ids[idx] + if my_id < 0: + return + py = idx // width + px = idx % width + is_edge = False + # Check 4 cardinal neighbors + if py > 0 and instance_ids[idx - width] != my_id: + is_edge = True + elif py < height - 1 and instance_ids[idx + width] != my_id: + is_edge = True + elif px > 0 and instance_ids[idx - 1] != my_id: + is_edge = True + elif px < width - 1 and instance_ids[idx + 1] != my_id: + is_edge = True + if is_edge: + inv = 1.0 - edge_strength + output[py, px, 0] = output[py, px, 0] * inv + edge_r * edge_strength + output[py, px, 1] = output[py, px, 1] * inv + edge_g * edge_strength + output[py, px, 2] = output[py, px, 2] * inv + edge_b * edge_strength + + +def _edge_outline(output, instance_ids, edge_strength=0.6, + edge_color=(0.05, 0.05, 0.05)): + """Apply screen-space edge detection on instance_id boundaries.""" + height, width, _ = output.shape + num_pixels = height * width + threadsperblock = 256 + blockspergrid = (num_pixels + threadsperblock - 1) // threadsperblock + _edge_outline_kernel[blockspergrid, threadsperblock]( + output, instance_ids, height, width, + edge_strength, *edge_color) + + +@cuda.jit +def _compute_flow_kernel(flow_out, primary_rays, primary_hits, + width, height, + prev_pos, prev_forward, prev_right, prev_up, + aspect, fov_scale): + """Compute per-pixel screen-space motion vectors by reprojecting hits + through the previous frame's camera.""" + idx = cuda.grid(1) + if idx >= width * height: + return + py = idx // width + px = idx % width + t = primary_hits[idx, 0] + if t <= 0: + flow_out[py, px, 0] = 0.0 + flow_out[py, px, 1] = 0.0 + return + # 3D hit point from current ray + hx = primary_rays[idx, 0] + t * primary_rays[idx, 4] + hy = primary_rays[idx, 1] + t * primary_rays[idx, 5] + hz = primary_rays[idx, 2] + t * primary_rays[idx, 6] + # Reproject through previous camera + ox = hx - prev_pos[0] + oy = hy - prev_pos[1] + oz = hz - prev_pos[2] + depth = ox * prev_forward[0] + oy * prev_forward[1] + oz * prev_forward[2] + if depth <= 1e-6: + flow_out[py, px, 0] = 0.0 + flow_out[py, px, 1] = 0.0 + return + u = (ox * prev_right[0] + oy * prev_right[1] + oz * prev_right[2]) / (depth * aspect * fov_scale) + v = (ox * prev_up[0] + oy * prev_up[1] + oz * prev_up[2]) / (depth * fov_scale) + prev_px = (u + 1.0) * width / 2.0 - 0.5 + prev_py = (1.0 - v) * height / 2.0 - 0.5 + flow_out[py, px, 0] = prev_px - px + flow_out[py, px, 1] = prev_py - py + + +def compute_flow(flow_out, primary_rays, primary_hits, width, height, + prev_pos, prev_forward, prev_right, prev_up, + aspect, fov_scale): + """Compute screen-space flow vectors for temporal denoising. + + Parameters + ---------- + flow_out : cupy.ndarray + (height, width, 2) float32 output — per-pixel (dx, dy) in pixels. + primary_rays, primary_hits : cupy.ndarray + Ray buffers from current frame. + prev_pos, prev_forward, prev_right, prev_up : cupy.ndarray + Previous frame camera basis vectors (device arrays, shape (3,)). + aspect : float + Aspect ratio (width / height). + fov_scale : float + tan(fov_radians / 2). + """ + num_rays = width * height + threadsperblock = 256 + blockspergrid = (num_rays + threadsperblock - 1) // threadsperblock + _compute_flow_kernel[blockspergrid, threadsperblock]( + flow_out, primary_rays, primary_hits, width, height, + prev_pos, prev_forward, prev_right, prev_up, + aspect, fov_scale + ) + + +@cuda.jit +def _bloom_threshold_kernel(bright, output, height, width, threshold): + """Extract pixels brighter than threshold into a separate buffer.""" + idx = cuda.grid(1) + if idx < height * width: + py = idx // width + px = idx % width + r = output[py, px, 0] + g = output[py, px, 1] + b = output[py, px, 2] + lum = 0.2126 * r + 0.7152 * g + 0.0722 * b + if lum > threshold: + scale = (lum - threshold) / lum + bright[py, px, 0] = r * scale + bright[py, px, 1] = g * scale + bright[py, px, 2] = b * scale + else: + bright[py, px, 0] = 0.0 + bright[py, px, 1] = 0.0 + bright[py, px, 2] = 0.0 + + +@cuda.jit +def _bloom_blur_kernel(dst, src, height, width, radius, horizontal): + """Separable Gaussian blur (approximate with linear weights).""" + idx = cuda.grid(1) + if idx < height * width: + py = idx // width + px = idx % width + + acc_r = 0.0 + acc_g = 0.0 + acc_b = 0.0 + weight_sum = 0.0 + + for i in range(-radius, radius + 1): + if horizontal: + sx = px + i + sy = py + else: + sx = px + sy = py + i + + if sx >= 0 and sx < width and sy >= 0 and sy < height: + # Gaussian weight: exp(-0.5 * (i/sigma)^2), sigma ≈ radius/2.5 + sigma = float(radius) / 2.5 + w = math.exp(-0.5 * (float(i) / sigma) * (float(i) / sigma)) + acc_r += src[sy, sx, 0] * w + acc_g += src[sy, sx, 1] * w + acc_b += src[sy, sx, 2] * w + weight_sum += w + + if weight_sum > 0: + dst[py, px, 0] = acc_r / weight_sum + dst[py, px, 1] = acc_g / weight_sum + dst[py, px, 2] = acc_b / weight_sum + + +@cuda.jit +def _bloom_composite_kernel(output, bloom, height, width, intensity): + """Additively blend bloom buffer into output.""" + idx = cuda.grid(1) + if idx < height * width: + py = idx // width + px = idx % width + output[py, px, 0] += bloom[py, px, 0] * intensity + output[py, px, 1] += bloom[py, px, 1] * intensity + output[py, px, 2] += bloom[py, px, 2] * intensity + + +def _bloom(output, temp, scratch, threshold=0.7, radius=12, intensity=0.35): + """Apply bloom post-process: threshold -> blur -> composite.""" + height, width = output.shape[0], output.shape[1] + num_pixels = height * width + threadsperblock = 256 + blockspergrid = (num_pixels + threadsperblock - 1) // threadsperblock + + # Extract bright pixels into temp + _bloom_threshold_kernel[blockspergrid, threadsperblock]( + temp, output, height, width, np.float32(threshold) + ) + + # Horizontal blur: temp -> scratch + _bloom_blur_kernel[blockspergrid, threadsperblock]( + scratch, temp, height, width, np.int32(radius), True + ) + + # Vertical blur: scratch -> temp + _bloom_blur_kernel[blockspergrid, threadsperblock]( + temp, scratch, height, width, np.int32(radius), False + ) + + # Composite: add bloom back into output + _bloom_composite_kernel[blockspergrid, threadsperblock]( + output, temp, height, width, np.float32(intensity) + ) + + # Lazy singletons for dummy GPU arrays (avoid per-frame allocations) _DUMMY_1x1 = None _DUMMY_1x1x3 = None _DUMMY_1x4 = None +_DUMMY_AO_ONES = None # (num_rays,) all-ones buffer for disabled AO +_DUMMY_AO_SIZE = 0 +_DUMMY_GI_COLOR = None # (num_rays, 3) all-zero for no GI +_DUMMY_GI_SIZE = 0 +_DUMMY_REFL_HITS = None # (num_rays, 4) all-zero for no reflections +_DUMMY_REFL_RAYS = None +_DUMMY_REFL_SIZE = 0 +_DUMMY_ALBEDO = None # (1, 1, 3) placeholder when albedo not captured def _shade_terrain( @@ -585,11 +1469,14 @@ def _shade_terrain( observer_x=-1e30, observer_y=-1e30, pixel_spacing_x=1.0, pixel_spacing_y=1.0, color_stretch=0, - sky_color=(0.0, 0.0, 0.0), + sky_color=(-1.0, 0.0, 0.0), rgb_texture=None, overlay_data=None, overlay_alpha=0.5, overlay_min=0.0, overlay_range=1.0, instance_ids=None, geometry_colors=None, + ao_factor=None, gi_color=None, gi_intensity=2.0, + reflection_hits=None, reflection_rays=None, + albedo_out=None, ): """Apply terrain shading with all effects.""" threadsperblock = 256 @@ -624,8 +1511,41 @@ def _shade_terrain( _DUMMY_1x4 = cupy.zeros((1, 4), dtype=np.float32) geometry_colors = _DUMMY_1x4 + # Handle AO factor - cached all-ones when disabled + global _DUMMY_AO_ONES, _DUMMY_AO_SIZE + if ao_factor is None: + if _DUMMY_AO_ONES is None or _DUMMY_AO_SIZE != num_rays: + _DUMMY_AO_ONES = cupy.ones(num_rays, dtype=np.float32) + _DUMMY_AO_SIZE = num_rays + ao_factor = _DUMMY_AO_ONES + + # Handle GI color - cached all-zeros when disabled + global _DUMMY_GI_COLOR, _DUMMY_GI_SIZE + if gi_color is None: + if _DUMMY_GI_COLOR is None or _DUMMY_GI_SIZE != num_rays: + _DUMMY_GI_COLOR = cupy.zeros((num_rays, 3), dtype=np.float32) + _DUMMY_GI_SIZE = num_rays + gi_color = _DUMMY_GI_COLOR + + # Handle reflection buffers - dummy zero arrays when no reflections + global _DUMMY_REFL_HITS, _DUMMY_REFL_RAYS, _DUMMY_REFL_SIZE + if reflection_hits is None: + if _DUMMY_REFL_HITS is None or _DUMMY_REFL_SIZE != num_rays: + _DUMMY_REFL_HITS = cupy.zeros((num_rays, 4), dtype=np.float32) + _DUMMY_REFL_RAYS = cupy.zeros((num_rays, 8), dtype=np.float32) + _DUMMY_REFL_SIZE = num_rays + reflection_hits = _DUMMY_REFL_HITS + reflection_rays = _DUMMY_REFL_RAYS + + # Handle albedo output - dummy (1,1,3) when not capturing + global _DUMMY_ALBEDO + if albedo_out is None: + if _DUMMY_ALBEDO is None: + _DUMMY_ALBEDO = cupy.zeros((1, 1, 3), dtype=np.float32) + albedo_out = _DUMMY_ALBEDO + _shade_terrain_kernel[blockspergrid, threadsperblock]( - output, primary_rays, primary_hits, shadow_hits, + output, albedo_out, primary_rays, primary_hits, shadow_hits, elevation_data, color_lut, num_rays, width, height, sun_dir, ambient, cast_shadows, fog_density, fog_color[0], fog_color[1], fog_color[2], @@ -637,7 +1557,9 @@ def _shade_terrain( color_stretch, rgb_texture, overlay_data, overlay_alpha, overlay_min, overlay_range, - instance_ids, geometry_colors + instance_ids, geometry_colors, + ao_factor, gi_color, np.float32(gi_intensity), + reflection_hits, reflection_rays ) @@ -674,24 +1596,55 @@ def __init__(self): self.shadow_rays = None self.shadow_hits = None self.output = None + self.albedo = None self.instance_ids = None - - def get(self, width, height, shadows, alpha, need_instance_ids): + self.ao_rays = None + self.ao_hits = None + self.gi_color = None + self.gi_throughput = None + self.reflection_rays = None + self.reflection_hits = None + self.bloom_temp = None + self.bloom_scratch = None + + def get(self, width, height, shadows, alpha, need_instance_ids, ao=False): num_rays = width * height num_channels = 4 if alpha else 3 - key = (width, height, shadows, alpha) + key = (width, height, shadows, alpha, ao) if key != self._key: self.primary_rays = cupy.empty((num_rays, 8), dtype=np.float32) self.primary_hits = cupy.empty((num_rays, 4), dtype=np.float32) self.shadow_rays = cupy.empty((num_rays, 8), dtype=np.float32) self.shadow_hits = cupy.empty((num_rays, 4), dtype=np.float32) self.output = cupy.zeros((height, width, num_channels), dtype=np.float32) + self.albedo = cupy.zeros((height, width, 3), dtype=np.float32) self.instance_ids = cupy.full(num_rays, -1, dtype=cupy.int32) + if ao: + self.ao_rays = cupy.empty((num_rays, 8), dtype=np.float32) + self.ao_hits = cupy.empty((num_rays, 4), dtype=np.float32) + self.gi_color = cupy.zeros((num_rays, 3), dtype=np.float32) + self.gi_throughput = cupy.ones((num_rays, 3), dtype=np.float32) + else: + self.ao_rays = None + self.ao_hits = None + self.gi_color = None + self.gi_throughput = None + if need_instance_ids: + self.reflection_rays = cupy.empty((num_rays, 8), dtype=np.float32) + self.reflection_hits = cupy.empty((num_rays, 4), dtype=np.float32) + else: + self.reflection_rays = None + self.reflection_hits = None + self.bloom_temp = cupy.zeros((height, width, 3), dtype=np.float32) + self.bloom_scratch = cupy.zeros((height, width, 3), dtype=np.float32) self._key = key else: self.output.fill(0) + self.albedo.fill(0) if need_instance_ids: self.instance_ids.fill(-1) + if self.gi_color is not None: + self.gi_color.fill(0) return self @@ -725,14 +1678,30 @@ def render( observer_position: Optional[Tuple[float, float]] = None, pixel_spacing_x: float = 1.0, pixel_spacing_y: float = 1.0, - mesh_type: str = 'tin', + mesh_type: str = 'heightfield', color_data=None, color_stretch: str = 'linear', - sky_color: Tuple[float, float, float] = (0.0, 0.0, 0.0), + sky_color: Optional[Tuple[float, float, float]] = None, rgb_texture=None, overlay_data=None, overlay_alpha: float = 0.5, geometry_colors=None, + ao_samples: int = 0, + ao_radius: Optional[float] = None, + ao_seed: int = 0, + gi_intensity: float = 2.0, + gi_bounces: int = 1, + frame_seed: int = 0, + sun_angle: float = 0.0, + aperture: float = 0.0, + focal_distance: float = 0.0, + tone_map: bool = True, + bloom: bool = True, + denoise: bool = False, + edge_lines: bool = True, + edge_strength: float = 0.6, + edge_color: Tuple[float, float, float] = (0.05, 0.05, 0.05), + _return_gpu: bool = False, ) -> np.ndarray: """Render terrain with a perspective camera for movie-quality visualization. @@ -835,8 +1804,8 @@ def render( elev_range_orig = elev_max_orig - elev_min_orig if vertical_exaggeration is None: - # Auto-compute: scale so relief is ~20% of horizontal extent - horizontal_extent = max(H, W) + # Auto-compute: scale so relief is ~20% of horizontal extent (in world units) + horizontal_extent = max(H * pixel_spacing_y, W * pixel_spacing_x) if elev_range_orig > 0: vertical_exaggeration = (horizontal_extent * 0.2) / elev_range_orig else: @@ -856,10 +1825,14 @@ def render( # Create a temporary raster with scaled elevations scaled_raster = raster.copy(data=scaled_elevation) # Don't reuse rtx when scaling - need fresh mesh - optix = prepare_mesh(scaled_raster, rtx=None, mesh_type=mesh_type) + optix = prepare_mesh(scaled_raster, rtx=None, mesh_type=mesh_type, + pixel_spacing_x=pixel_spacing_x, + pixel_spacing_y=pixel_spacing_y) else: scaled_raster = raster - optix = prepare_mesh(raster, rtx, mesh_type=mesh_type) + optix = prepare_mesh(raster, rtx, mesh_type=mesh_type, + pixel_spacing_x=pixel_spacing_x, + pixel_spacing_y=pixel_spacing_y) # Scale camera position and look_at z coordinates scaled_camera_position = ( @@ -912,23 +1885,41 @@ def render( elev_max = float(cupy.nanmax(colormap_data)) elev_range = elev_max - elev_min + # Color stretch mode: string -> int for CUDA kernel (needed early for GI in AO loop) + _stretch_modes = {'linear': 0, 'cbrt': 1, 'log': 2, 'sqrt': 3} + stretch_int = _stretch_modes.get(color_stretch, 0) + + # Detect NaN ocean terrain (needs reflection buffers even without geometry) + has_nan_ocean = bool(cupy.any(cupy.isnan(elevation_data))) + # Allocate (or reuse) buffers bufs = _render_buffers.get(width, height, shadows, alpha, - geometry_colors is not None) + geometry_colors is not None or has_nan_ocean, + ao=ao_samples > 0) d_primary_rays = bufs.primary_rays d_primary_hits = bufs.primary_hits d_shadow_rays = bufs.shadow_rays d_shadow_hits = bufs.shadow_hits d_output = bufs.output - device = cupy.cuda.Device(0) + # Compute derived seeds for AA and soft shadows from frame_seed + jitter_seed = np.uint32(frame_seed * 3 + 1) if frame_seed > 0 else np.uint32(0) + shadow_seed = np.uint32(frame_seed * 3 + 2) if frame_seed > 0 else np.uint32(0) + sun_angle_rad = math.radians(sun_angle) if sun_angle > 0 else 0.0 + + # Auto-compute focal distance from camera-to-lookat if not specified + if aperture > 0 and focal_distance <= 0: + dx = scaled_look_at[0] - scaled_camera_position[0] + dy = scaled_look_at[1] - scaled_camera_position[1] + dz = scaled_look_at[2] - scaled_camera_position[2] + focal_distance = math.sqrt(dx * dx + dy * dy + dz * dz) # Step 1: Generate perspective rays _generate_perspective_rays( d_primary_rays, width, height, - d_camera_pos, d_forward, d_right, d_up, fov + d_camera_pos, d_forward, d_right, d_up, fov, + jitter_seed=jitter_seed, aperture=aperture, focal_distance=focal_distance ) - device.synchronize() # Step 2: Trace primary rays (with instance_ids if geometry_colors provided) d_instance_ids = bufs.instance_ids @@ -940,14 +1931,80 @@ def render( # Step 3: Generate and trace shadow rays (if enabled) if shadows: _generate_shadow_rays_from_hits( - d_shadow_rays, d_primary_rays, d_primary_hits, num_rays, d_sun_dir + d_shadow_rays, d_primary_rays, d_primary_hits, num_rays, d_sun_dir, + sun_angle_rad=sun_angle_rad, shadow_seed=shadow_seed ) - device.synchronize() - optix.trace(d_shadow_rays, d_shadow_hits, num_rays) + optix.trace(d_shadow_rays, d_shadow_hits, num_rays, + ray_flags=RTX.RAY_FLAG_OCCLUSION) else: # Fill shadow hits with -1 (no shadow) d_shadow_hits.fill(-1) + # Step 3b: Ambient occlusion pass + d_ao_factor = None + if ao_samples > 0: + d_ao_rays = bufs.ao_rays + d_ao_hits = bufs.ao_hits + + # Auto-compute AO radius from scene extent if not specified + if ao_radius is None: + H_raster, W_raster = raster.shape + diagonal = math.sqrt((H_raster * pixel_spacing_y) ** 2 + + (W_raster * pixel_spacing_x) ** 2) + ao_radius = diagonal * 0.05 + + d_ao_factor = cupy.ones(num_rays, dtype=np.float32) + d_gi_color = bufs.gi_color + d_gi_throughput = bufs.gi_throughput + + for s in range(ao_samples): + sample_seed = ao_seed * ao_samples + s + d_gi_throughput.fill(1.0) # reset per-sample path throughput + + # Bounce 0 (existing flow) + _generate_ao_rays(d_ao_rays, d_primary_rays, d_primary_hits, + num_rays, ao_radius, sample_seed) + optix.trace(d_ao_rays, d_ao_hits, num_rays, + ray_flags=RTX.RAY_FLAG_OCCLUSION) + _accumulate_ao(d_ao_factor, d_ao_hits, num_rays, ao_samples) + _accumulate_gi(d_gi_color, d_ao_rays, d_ao_hits, num_rays, + ao_samples, vertical_exaggeration, + elev_min, elev_range, d_color_lut, stretch_int, + d_gi_throughput) + + # Additional bounces + for bounce in range(1, gi_bounces): + bounce_seed = sample_seed * 7919 + bounce * 6271 + # In-place: new AO rays from previous hit points + _generate_ao_rays(d_ao_rays, d_ao_rays, d_ao_hits, + num_rays, ao_radius, bounce_seed) + optix.trace(d_ao_rays, d_ao_hits, num_rays, + ray_flags=RTX.RAY_FLAG_OCCLUSION) + _accumulate_gi(d_gi_color, d_ao_rays, d_ao_hits, num_rays, + ao_samples, vertical_exaggeration, + elev_min, elev_range, d_color_lut, stretch_int, + d_gi_throughput) + + # Step 3c: Reflection rays for water surfaces and NaN ocean + d_reflection_hits = None + d_reflection_rays = None + if bufs.reflection_rays is not None: + d_reflection_rays = bufs.reflection_rays + d_reflection_hits = bufs.reflection_hits + # Use real geometry_colors or dummy for the kernel + gc = geometry_colors + if gc is None: + global _DUMMY_1x4 + if _DUMMY_1x4 is None: + _DUMMY_1x4 = cupy.zeros((1, 4), dtype=np.float32) + gc = _DUMMY_1x4 + _generate_reflection_rays( + d_reflection_rays, d_primary_rays, d_primary_hits, + d_instance_ids, gc, num_rays, + colormap_data, pixel_spacing_x, pixel_spacing_y + ) + optix.trace(d_reflection_rays, d_reflection_hits, num_rays) + # Prepare viewshed data if provided d_viewshed = None if viewshed_data is not None: @@ -966,10 +2023,6 @@ def render( obs_x = float(observer_position[0]) if observer_position else -1e30 obs_y = float(observer_position[1]) if observer_position else -1e30 - # Color stretch mode: string -> int for CUDA kernel - _stretch_modes = {'linear': 0, 'cbrt': 1, 'log': 2, 'sqrt': 3} - stretch_int = _stretch_modes.get(color_stretch, 0) - # Prepare overlay data for transparent blending d_overlay = None ov_min = 0.0 @@ -994,13 +2047,41 @@ def render( obs_x, obs_y, pixel_spacing_x, pixel_spacing_y, stretch_int, - sky_color=sky_color, + sky_color=(-1.0, 0.0, 0.0) if sky_color is None else sky_color, rgb_texture=rgb_texture, overlay_data=d_overlay, overlay_alpha=overlay_alpha, overlay_min=ov_min, overlay_range=ov_range, instance_ids=d_instance_ids, geometry_colors=geometry_colors, + ao_factor=d_ao_factor, + gi_color=bufs.gi_color if ao_samples > 0 else None, + gi_intensity=gi_intensity, + reflection_hits=d_reflection_hits, reflection_rays=d_reflection_rays, + albedo_out=bufs.albedo, ) - device.synchronize() + + # AI denoiser (after shading, before bloom/tone mapping) + if denoise: + from ..rtx import denoise as _denoise + d_normals = d_primary_hits.reshape(height, width, 4)[:, :, 1:4].copy() + _denoise(d_output, d_normals, width, height, right, cam_up, forward, + albedo=bufs.albedo) + + # Edge outlines on placed geometry (after denoise, before bloom) + if edge_lines and geometry_colors is not None: + _edge_outline(d_output, d_instance_ids, edge_strength, edge_color) + + # Bloom post-process (before tone mapping so ACES compresses bloom gracefully) + if bloom: + _bloom(d_output, bufs.bloom_temp, bufs.bloom_scratch) + + # Tone mapping (ACES filmic curve) + if tone_map: + _tone_map_aces(d_output) + + cupy.cuda.Stream.null.synchronize() + + if _return_gpu: + return d_output # Transfer to CPU output = cupy.asnumpy(d_output) diff --git a/rtxpy/analysis/slope_aspect.py b/rtxpy/analysis/slope_aspect.py index 8ab492b..ac3ce31 100644 --- a/rtxpy/analysis/slope_aspect.py +++ b/rtxpy/analysis/slope_aspect.py @@ -10,7 +10,7 @@ import numpy as np from .._cuda_utils import calc_dims -from ._common import generate_primary_rays, prepare_mesh +from ._common import generate_primary_rays, prepare_mesh, _compute_pixel_spacing from ..rtx import RTX, has_cupy if has_cupy: @@ -92,7 +92,7 @@ def _calc_aspect_kernel(hits, output, H, W): output[i, j] = np.float32(angle) -def _slope_rt(raster, optix): +def _slope_rt(raster, optix, pixel_spacing_x=1.0, pixel_spacing_y=1.0): """Internal: trace primary rays and compute slope.""" xr = _lazy_import_xarray() @@ -105,7 +105,9 @@ def _slope_rt(raster, optix): y_coords = cupy.array(raster.indexes.get('y').values) x_coords = cupy.array(raster.indexes.get('x').values) - generate_primary_rays(d_rays, x_coords, y_coords, H, W) + generate_primary_rays(d_rays, x_coords, y_coords, H, W, + pixel_spacing_x=pixel_spacing_x, + pixel_spacing_y=pixel_spacing_y) cupy.cuda.Device(0).synchronize() optix.trace(d_rays, d_hits, W * H) @@ -134,7 +136,7 @@ def _slope_rt(raster, optix): ) -def _aspect_rt(raster, optix): +def _aspect_rt(raster, optix, pixel_spacing_x=1.0, pixel_spacing_y=1.0): """Internal: trace primary rays and compute aspect.""" xr = _lazy_import_xarray() @@ -147,7 +149,9 @@ def _aspect_rt(raster, optix): y_coords = cupy.array(raster.indexes.get('y').values) x_coords = cupy.array(raster.indexes.get('x').values) - generate_primary_rays(d_rays, x_coords, y_coords, H, W) + generate_primary_rays(d_rays, x_coords, y_coords, H, W, + pixel_spacing_x=pixel_spacing_x, + pixel_spacing_y=pixel_spacing_y) cupy.cuda.Device(0).synchronize() optix.trace(d_rays, d_hits, W * H) @@ -204,8 +208,9 @@ def slope(raster, rtx: RTX = None): "Additional overhead will be incurred from CPU-GPU transfers." ) - optix = prepare_mesh(raster, rtx) - return _slope_rt(raster, optix) + psx, psy = _compute_pixel_spacing(raster) + optix = prepare_mesh(raster, rtx, pixel_spacing_x=psx, pixel_spacing_y=psy) + return _slope_rt(raster, optix, pixel_spacing_x=psx, pixel_spacing_y=psy) def aspect(raster, rtx: RTX = None): @@ -240,5 +245,6 @@ def aspect(raster, rtx: RTX = None): "Additional overhead will be incurred from CPU-GPU transfers." ) - optix = prepare_mesh(raster, rtx) - return _aspect_rt(raster, optix) + psx, psy = _compute_pixel_spacing(raster) + optix = prepare_mesh(raster, rtx, pixel_spacing_x=psx, pixel_spacing_y=psy) + return _aspect_rt(raster, optix, pixel_spacing_x=psx, pixel_spacing_y=psy) diff --git a/rtxpy/analysis/viewshed.py b/rtxpy/analysis/viewshed.py index 32f7acf..2a6a1a8 100644 --- a/rtxpy/analysis/viewshed.py +++ b/rtxpy/analysis/viewshed.py @@ -11,7 +11,7 @@ from typing import Union from .._cuda_utils import calc_dims, add, diff, mul, dot, float3, make_float3, invert -from ._common import generate_primary_rays, prepare_mesh +from ._common import generate_primary_rays, prepare_mesh, _compute_pixel_spacing from ..rtx import RTX, has_cupy if has_cupy: @@ -264,11 +264,14 @@ def viewshed(raster, if not isinstance(raster.data, cupy.ndarray): raise ValueError("raster.data must be a cupy array") + psx, psy = _compute_pixel_spacing(raster) + # If an RTX with existing geometries is provided (multi-GAS scene), # use it directly so viewshed rays are occluded by all scene geometry. # Only build a terrain-only mesh when no RTX is given. if rtx is not None and rtx.get_geometry_count() > 0: optix = rtx else: - optix = prepare_mesh(raster, rtx) - return _viewshed_rt(raster, optix, x, y, observer_elev, target_elev) + optix = prepare_mesh(raster, rtx, pixel_spacing_x=psx, pixel_spacing_y=psy) + return _viewshed_rt(raster, optix, x, y, observer_elev, target_elev, + pixel_spacing_x=psx, pixel_spacing_y=psy) diff --git a/rtxpy/engine.py b/rtxpy/engine.py index 9efb197..4ce3ba4 100644 --- a/rtxpy/engine.py +++ b/rtxpy/engine.py @@ -976,7 +976,7 @@ class InteractiveViewer: - +/=: Increase speed - -: Decrease speed - G: Cycle terrain color (elevation → overlays) - - U: Cycle basemap (none → satellite → osm) + - U/Shift+U: Cycle basemap forward/backward (none → satellite → osm) - N: Cycle geometry layer (none → all → groups) - P: Jump to previous geometry in current group - ,/.: Decrease/increase overlay alpha (transparency) @@ -1016,6 +1016,8 @@ def __init__(self, raster, width: int = 800, height: int = 600, mesh_type: str = 'heightfield', overlay_layers: dict = None, title: str = None, + subtitle: str = None, + legend: dict = None, subsample: int = 1): """ Initialize the interactive viewer. @@ -1256,6 +1258,12 @@ def __init__(self, raster, width: int = 800, height: int = 600, # Help text cache (pre-rendered RGBA numpy array via PIL) self._help_text_rgba = None + # Title / subtitle overlay (pre-rendered RGBA numpy array via PIL) + self._subtitle = subtitle + self._legend_config = legend + self._title_overlay_rgba = None + self._legend_rgba = None + # FIRMS fire layer state self._accessor = None # RTX accessor for place_geojson self._firms_loaded = False # Whether fire data has been fetched @@ -1459,7 +1467,8 @@ def __init__(self, raster, width: int = 800, height: int = 600, verts.copy(), idxs.copy(), terrain_np.copy(), ) - rtx.add_geometry('terrain', verts, idxs) + rtx.add_geometry('terrain', verts, idxs, + grid_dims=(H, W)) def _get_front(self): """Get the forward direction vector.""" @@ -1756,7 +1765,9 @@ def _rebuild_at_resolution(self, factor): # 4. Replace terrain geometry (add_geometry overwrites existing key # in-place, preserving dict insertion order and instance IDs) if self.rtx is not None: - self.rtx.add_geometry('terrain', vertices, indices) + gd = (H, W) if self.mesh_type != 'voxel' else None + self.rtx.add_geometry('terrain', vertices, indices, + grid_dims=gd) self.elev_min = float(np.nanmin(terrain_np)) * ve self.elev_max = float(np.nanmax(terrain_np)) * ve @@ -1999,7 +2010,9 @@ def _rebuild_vertical_exaggeration(self, ve): # Replace terrain geometry (preserves dict insertion order) if self.rtx is not None: - self.rtx.add_geometry('terrain', vertices, indices) + gd = (H, W) if self.mesh_type != 'voxel' else None + self.rtx.add_geometry('terrain', vertices, indices, + grid_dims=gd) # Update elevation stats (scaled) self.elev_min = float(np.nanmin(terrain_np)) * ve @@ -3244,9 +3257,11 @@ def _handle_key_press(self, raw_key, key): self._rebuild_vertical_exaggeration(self.vertical_exaggeration) print(f"Mesh type: {self.mesh_type}") - # Basemap cycling: U = cycle none → satellite → osm → none + # Basemap cycling: U = cycle forward, Shift+U = cycle backward elif key == 'u': self._cycle_basemap() + elif key == 'U': + self._cycle_basemap(reverse=True) # Overlay alpha: , = decrease, . = increase elif key == ',': @@ -3537,7 +3552,9 @@ def _check_terrain_reload(self): # Replace terrain geometry if self.rtx is not None: - self.rtx.add_geometry('terrain', vertices, indices) + gd = (H, W) if self.mesh_type != 'voxel' else None + self.rtx.add_geometry('terrain', vertices, indices, + grid_dims=gd) # Reposition camera in new window self.position = np.array([ @@ -3695,12 +3712,13 @@ def _cycle_terrain_layer(self): self._update_frame() - def _cycle_basemap(self): + def _cycle_basemap(self, reverse=False): """Cycle basemap: none → satellite → osm → none. Auto-creates XYZTileService on-the-fly if needed. """ - self._basemap_idx = (self._basemap_idx + 1) % len(self._basemap_options) + step = -1 if reverse else 1 + self._basemap_idx = (self._basemap_idx + step) % len(self._basemap_options) provider = self._basemap_options[self._basemap_idx] if provider == 'none': @@ -4757,6 +4775,8 @@ def _render_frame(self): sun_angle=1.5, aperture=dof_aperture, focal_distance=dof_focal, + edge_strength=0.2, + edge_color=(0.15, 0.13, 0.10), bloom=not defer_post, tone_map=not defer_post, _return_gpu=True, @@ -4925,6 +4945,8 @@ def _composite_overlays(self): (self._wind_enabled and self._wind_particles is not None) or (self._gtfs_rt_enabled and self._gtfs_rt_vehicles is not None) or self.show_minimap + or self._title_overlay_rgba is not None + or self._legend_rgba is not None or self.show_help ) if needs_overlay: @@ -4943,6 +4965,13 @@ def _composite_overlays(self): # Minimap overlay self._blit_minimap_on_frame(img) + # Title overlay (hidden when help is shown — they overlap top-left) + if not self.show_help: + self._blit_title_on_frame(img) + + # Legend overlay (always visible) + self._blit_legend_on_frame(img) + # Help text overlay if self.show_help and self._help_text_rgba is not None: self._blit_help_on_frame(img) @@ -5046,6 +5075,212 @@ def _handle_mouse_motion(self, xpos, ypos): self._update_frame() + # ------------------------------------------------------------------ + # Title & legend overlays + # ------------------------------------------------------------------ + + def _render_title_overlay(self): + """Pre-render title + subtitle to an RGBA numpy array using PIL. + + Called once at startup; cached in ``self._title_overlay_rgba``. + Skipped when no title is set. + """ + if not self._title or self._title == 'rtxpy': + return + try: + from PIL import Image, ImageDraw, ImageFont + + bold_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf" + sans_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf" + try: + font_title = ImageFont.truetype(bold_path, 22) + font_sub = ImageFont.truetype(sans_path, 11) + except (OSError, IOError): + font_title = ImageFont.load_default() + font_sub = font_title + + pad_x, pad_y = 14, 10 + corner_r = 10 + bg_color = (15, 18, 24, 200) + + # Measure title + title_bbox = font_title.getbbox(self._title) + title_w = title_bbox[2] - title_bbox[0] + title_h = title_bbox[3] - title_bbox[1] + + # Measure subtitle + sub_h = 0 + sub_w = 0 + if self._subtitle: + sub_bbox = font_sub.getbbox(self._subtitle) + sub_w = sub_bbox[2] - sub_bbox[0] + sub_h = sub_bbox[3] - sub_bbox[1] + 4 # 4px gap + + img_w = pad_x * 2 + max(title_w, sub_w) + img_h = pad_y * 2 + title_h + sub_h + + img = Image.new('RGBA', (img_w, img_h), (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.rounded_rectangle( + [0, 0, img_w - 1, img_h - 1], + radius=corner_r, fill=bg_color, + ) + draw.rounded_rectangle( + [0, 0, img_w - 1, img_h - 1], + radius=corner_r, outline=(60, 70, 90, 140), width=1, + ) + + # Title + draw.text((pad_x, pad_y), self._title, + fill=(255, 255, 255, 255), font=font_title) + + # Subtitle + if self._subtitle: + draw.text((pad_x, pad_y + title_h + 4), self._subtitle, + fill=(170, 180, 195, 220), font=font_sub) + + self._title_overlay_rgba = np.array(img, dtype=np.float32) / 255.0 + except ImportError: + pass + + def _render_legend_overlay(self): + """Pre-render legend to an RGBA numpy array using PIL. + + Called once at startup; cached in ``self._legend_rgba``. + Expects ``self._legend_config`` to be a dict with 'title' and + 'entries' keys. + """ + if not self._legend_config: + return + entries = self._legend_config.get('entries', []) + if not entries: + return + try: + from PIL import Image, ImageDraw, ImageFont + + bold_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf" + sans_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf" + try: + font_header = ImageFont.truetype(bold_path, 13) + font_label = ImageFont.truetype(sans_path, 12) + except (OSError, IOError): + font_header = ImageFont.load_default() + font_label = font_header + + pad_x, pad_y = 14, 10 + corner_r = 10 + swatch_size = 12 + swatch_gap = 8 # gap between swatch and label + line_h = 18 + bg_color = (15, 18, 24, 200) + header_color = (180, 210, 255, 255) + label_color = (210, 215, 225, 220) + + legend_title = self._legend_config.get('title', '') + + # Measure widths + max_label_w = 0 + for label, _color in entries: + bbox = font_label.getbbox(label) + max_label_w = max(max_label_w, bbox[2] - bbox[0]) + + content_w = swatch_size + swatch_gap + max_label_w + if legend_title: + hdr_bbox = font_header.getbbox(legend_title) + content_w = max(content_w, hdr_bbox[2] - hdr_bbox[0]) + + img_w = pad_x * 2 + content_w + header_h = 0 + if legend_title: + header_h = 13 + 8 # font size + gap below header + img_h = pad_y * 2 + header_h + len(entries) * line_h + + img = Image.new('RGBA', (img_w, img_h), (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + draw.rounded_rectangle( + [0, 0, img_w - 1, img_h - 1], + radius=corner_r, fill=bg_color, + ) + draw.rounded_rectangle( + [0, 0, img_w - 1, img_h - 1], + radius=corner_r, outline=(60, 70, 90, 140), width=1, + ) + + y = pad_y + if legend_title: + draw.text((pad_x, y), legend_title, + fill=header_color, font=font_header) + y += header_h + + for label, color_rgb in entries: + # Colour swatch + r8 = int(min(1.0, color_rgb[0]) * 255) + g8 = int(min(1.0, color_rgb[1]) * 255) + b8 = int(min(1.0, color_rgb[2]) * 255) + sx = pad_x + sy = y + (line_h - swatch_size) // 2 + draw.rectangle( + [sx, sy, sx + swatch_size - 1, sy + swatch_size - 1], + fill=(r8, g8, b8, 255), + ) + # Label + draw.text((pad_x + swatch_size + swatch_gap, y + 1), + label, fill=label_color, font=font_label) + y += line_h + + self._legend_rgba = np.array(img, dtype=np.float32) / 255.0 + except ImportError: + pass + + def _blit_title_on_frame(self, img): + """Alpha-composite cached title overlay onto the rendered frame. + + Parameters + ---------- + img : ndarray, shape (H, W, 3), float32 0-1 + Rendered frame. Modified in-place. + """ + if self._title_overlay_rgba is None: + return + ov = self._title_overlay_rgba + oh, ow = ov.shape[:2] + fh, fw = img.shape[:2] + margin = 12 + bh = min(oh, fh - margin) + bw = min(ow, fw - margin) + if bh <= 0 or bw <= 0: + return + alpha = ov[:bh, :bw, 3:4] + rgb = ov[:bh, :bw, :3] + region = img[margin:margin+bh, margin:margin+bw] + region[:] = region * (1 - alpha) + rgb * alpha + + def _blit_legend_on_frame(self, img): + """Alpha-composite cached legend overlay onto the rendered frame. + + Parameters + ---------- + img : ndarray, shape (H, W, 3), float32 0-1 + Rendered frame. Modified in-place. + """ + if self._legend_rgba is None: + return + ov = self._legend_rgba + oh, ow = ov.shape[:2] + fh, fw = img.shape[:2] + margin = 12 + y0 = fh - margin - oh + if y0 < 0: + y0 = 0 + bh = min(oh, fh - y0) + bw = min(ow, fw - margin) + if bh <= 0 or bw <= 0: + return + alpha = ov[:bh, :bw, 3:4] + rgb = ov[:bh, :bw, :3] + region = img[y0:y0+bh, margin:margin+bw] + region[:] = region * (1 - alpha) + rgb * alpha + def _render_help_text(self): """Pre-render help text to an RGBA numpy array using PIL. @@ -5065,7 +5300,7 @@ def _render_help_text(self): ]), ("TERRAIN", [ ("G", "Cycle terrain layer"), - ("U", "Cycle basemap"), + ("U / Shift+U", "Cycle basemap fwd / back"), ("C", "Cycle colormap"), ("Y", "Cycle color stretch"), (", / .", "Overlay alpha"), @@ -5348,6 +5583,8 @@ def run(self, start_position: Optional[Tuple[float, float, float]] = None, # --- Pre-render help text overlay --- self._render_help_text() + self._render_title_overlay() + self._render_legend_overlay() # --- Initialize minimap --- self._compute_minimap_background() @@ -5539,6 +5776,8 @@ def explore(raster, width: int = 800, height: int = 600, overlay_layers: dict = None, color_stretch: str = 'linear', title: str = None, + subtitle: str = None, + legend: dict = None, tile_service=None, geometry_colors_builder=None, baked_meshes=None, @@ -5628,7 +5867,7 @@ def explore(raster, width: int = 800, height: int = 600, - +/=: Increase speed - -: Decrease speed - G: Cycle terrain color (elevation → overlays) - - U: Cycle basemap (none → satellite → osm) + - U/Shift+U: Cycle basemap forward/backward (none → satellite → osm) - N: Cycle geometry layer (none → all → groups) - P: Jump to previous geometry in current group - ,/.: Decrease/increase overlay alpha (transparency) @@ -5685,6 +5924,8 @@ def explore(raster, width: int = 800, height: int = 600, mesh_type=mesh_type, overlay_layers=overlay_layers, title=title, + subtitle=subtitle, + legend=legend, subsample=subsample, ) viewer._geometry_colors_builder = geometry_colors_builder diff --git a/rtxpy/kernel.ptx b/rtxpy/kernel.ptx index 63d728e..5e0e610 100644 --- a/rtxpy/kernel.ptx +++ b/rtxpy/kernel.ptx @@ -11,7 +11,7 @@ .address_size 64 // .globl __raygen__main -.const .align 8 .b8 params[40]; +.const .align 8 .b8 params[88]; .visible .entry __raygen__main() { @@ -52,13 +52,13 @@ ld.global.f32 %f6, [%rd8+24]; ld.global.f32 %f8, [%rd8+28]; ld.const.u64 %rd4, [params]; + ld.const.u32 %r72, [params+40]; mov.f32 %f9, 0f00000000; - mov.u32 %r72, 16; mov.u32 %r74, 1; mov.u32 %r76, 6; mov.u32 %r108, 0; // begin inline asm - call(%r38,%r39,%r40,%r41,%r42,%r43,%r44,%r45,%r46,%r47,%r48,%r49,%r50,%r51,%r52,%r53,%r54,%r55,%r56,%r57,%r58,%r59,%r60,%r61,%r62,%r63,%r64,%r65,%r66,%r67,%r68,%r69),_optix_trace_typed_32,(%r108,%rd4,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%r74,%r72,%r108,%r74,%r108,%r76,%r111,%r112,%r113,%r114,%r115,%r116,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108); + call(%r38,%r39,%r40,%r41,%r42,%r43,%r44,%r45,%r46,%r47,%r48,%r49,%r50,%r51,%r52,%r53,%r54,%r55,%r56,%r57,%r58,%r59,%r60,%r61,%r62,%r63,%r64,%r65,%r66,%r67,%r68,%r69),_optix_trace_typed_32,(%r74,%rd4,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%r74,%r72,%r108,%r74,%r108,%r76,%r111,%r112,%r113,%r114,%r115,%r116,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108,%r108); // end inline asm ld.const.u64 %rd9, [params+16]; cvta.to.global.u64 %rd10, %rd9; @@ -130,85 +130,692 @@ $L__BB0_4: // .globl __closesthit__chit .visible .entry __closesthit__chit() { - .reg .f32 %f<38>; - .reg .b32 %r<19>; - .reg .b64 %rd<3>; + .reg .pred %p<3>; + .reg .f32 %f<36>; + .reg .b32 %r<31>; // begin inline asm - call (%f1), _optix_get_ray_tmax, (); + call (%f2), _optix_get_ray_tmax, (); // end inline asm - cvt.rzi.u32.f32 %r18, %f1; // begin inline asm - call (%rd1), _optix_get_gas_traversable_handle, (); + call (%r8), _optix_read_primitive_idx, (); // end inline asm // begin inline asm - call (%r1), _optix_read_primitive_idx, (); + call (%r9), _optix_get_hit_kind, (); // end inline asm + setp.eq.s32 %p1, %r9, 254; + @%p1 bra $L__BB2_2; + // begin inline asm - call (%r2), _optix_read_sbt_gas_idx, (); + call (%r10), _optix_get_hit_kind, (); // end inline asm + setp.ne.s32 %p2, %r10, 255; + mov.u32 %r30, 1065353216; + mov.u32 %r28, 0; + mov.u32 %r29, %r28; + @%p2 bra $L__BB2_3; + +$L__BB2_2: // begin inline asm - call (%f2), _optix_get_ray_time, (); + call (%f3, %f4, %f5, %f6, %f7, %f8, %f9, %f10, %f11), _optix_get_triangle_vertex_data_current_hit, (); // end inline asm + sub.ftz.f32 %f12, %f6, %f3; + sub.ftz.f32 %f13, %f7, %f4; + sub.ftz.f32 %f14, %f8, %f5; + sub.ftz.f32 %f15, %f9, %f3; + sub.ftz.f32 %f16, %f10, %f4; + sub.ftz.f32 %f17, %f11, %f5; + mul.ftz.f32 %f18, %f13, %f17; + mul.ftz.f32 %f19, %f14, %f16; + sub.ftz.f32 %f20, %f18, %f19; + mul.ftz.f32 %f21, %f12, %f17; + mul.ftz.f32 %f22, %f14, %f15; + sub.ftz.f32 %f23, %f21, %f22; + mul.ftz.f32 %f24, %f12, %f16; + mul.ftz.f32 %f25, %f13, %f15; + sub.ftz.f32 %f26, %f24, %f25; + mul.ftz.f32 %f27, %f23, %f23; + fma.rn.ftz.f32 %f28, %f20, %f20, %f27; + fma.rn.ftz.f32 %f29, %f26, %f26, %f28; + rsqrt.approx.ftz.f32 %f30, %f29; + mul.ftz.f32 %f31, %f30, %f20; + mul.ftz.f32 %f32, %f23, %f30; + neg.ftz.f32 %f33, %f32; + mul.ftz.f32 %f34, %f30, %f26; + mov.b32 %r28, %f31; + mov.b32 %r29, %f33; + mov.b32 %r30, %f34; + +$L__BB2_3: + cvt.rzi.ftz.u32.f32 %r27, %f2; + cvt.rn.f32.u32 %f35, %r27; + mov.b32 %r15, %f35; + mov.u32 %r14, 0; // begin inline asm - call (%f3, %f4, %f5, %f6, %f7, %f8, %f9, %f10, %f11), _optix_get_triangle_vertex_data, (%rd1, %r1, %r2, %f2); + call _optix_set_payload, (%r14, %r15); // end inline asm - sub.f32 %f13, %f6, %f3; - sub.f32 %f14, %f7, %f4; - sub.f32 %f15, %f8, %f5; - sub.f32 %f16, %f9, %f3; - sub.f32 %f17, %f10, %f4; - sub.f32 %f18, %f11, %f5; - mul.f32 %f19, %f14, %f18; - mul.f32 %f20, %f15, %f17; - sub.f32 %f21, %f19, %f20; - mul.f32 %f22, %f13, %f18; - mul.f32 %f23, %f15, %f16; - sub.f32 %f24, %f22, %f23; - mul.f32 %f25, %f13, %f17; - mul.f32 %f26, %f14, %f16; - sub.f32 %f27, %f25, %f26; - mul.f32 %f28, %f24, %f24; - fma.rn.f32 %f29, %f21, %f21, %f28; - fma.rn.f32 %f30, %f27, %f27, %f29; - sqrt.rn.f32 %f31, %f30; - rcp.rn.f32 %f32, %f31; - mul.f32 %f33, %f32, %f21; - mul.f32 %f34, %f24, %f32; - neg.f32 %f35, %f34; - mul.f32 %f36, %f32, %f27; - cvt.rn.f32.u32 %f37, %r18; - mov.b32 %r6, %f37; - mov.u32 %r5, 0; + mov.u32 %r16, 1; // begin inline asm - call _optix_set_payload, (%r5, %r6); + call _optix_set_payload, (%r16, %r28); // end inline asm - mov.b32 %r8, %f33; - mov.u32 %r7, 1; + mov.u32 %r18, 2; // begin inline asm - call _optix_set_payload, (%r7, %r8); + call _optix_set_payload, (%r18, %r29); + // end inline asm + mov.u32 %r20, 3; + // begin inline asm + call _optix_set_payload, (%r20, %r30); + // end inline asm + mov.u32 %r22, 4; + // begin inline asm + call _optix_set_payload, (%r22, %r8); + // end inline asm + // begin inline asm + call (%r24), _optix_read_instance_id, (); + // end inline asm + mov.u32 %r25, 5; + // begin inline asm + call _optix_set_payload, (%r25, %r24); + // end inline asm + ret; + +} + // .globl __intersection__heightfield +.visible .entry __intersection__heightfield() +{ + .reg .pred %p<76>; + .reg .b16 %rs<27>; + .reg .f32 %f<325>; + .reg .b32 %r<74>; + .reg .b64 %rd<7>; + + + // begin inline asm + call (%r21), _optix_read_primitive_idx, (); + // end inline asm + // begin inline asm + call (%f106), _optix_get_object_ray_origin_x, (); + // end inline asm + // begin inline asm + call (%f107), _optix_get_object_ray_origin_y, (); // end inline asm - mov.b32 %r10, %f35; - mov.u32 %r9, 2; + // begin inline asm + call (%f108), _optix_get_object_ray_origin_z, (); + // end inline asm + // begin inline asm + call (%f109), _optix_get_object_ray_direction_x, (); + // end inline asm + // begin inline asm + call (%f110), _optix_get_object_ray_direction_y, (); + // end inline asm + // begin inline asm + call (%f111), _optix_get_object_ray_direction_z, (); + // end inline asm + // begin inline asm + call (%f112), _optix_get_ray_tmin, (); + // end inline asm + // begin inline asm + call (%f322), _optix_get_ray_tmax, (); + // end inline asm + ld.const.v2.f32 {%f114, %f115}, [params+64]; + ld.const.v2.u32 {%r22, %r23}, [params+56]; + ld.const.u32 %r25, [params+80]; + div.s32 %r26, %r21, %r25; + mul.lo.s32 %r27, %r26, %r25; + sub.s32 %r28, %r21, %r27; + ld.const.u32 %r2, [params+76]; + mul.lo.s32 %r3, %r28, %r2; + mul.lo.s32 %r4, %r26, %r2; + add.s32 %r29, %r3, %r2; + add.s32 %r30, %r22, -1; + min.s32 %r5, %r29, %r30; + add.s32 %r31, %r4, %r2; + add.s32 %r32, %r23, -1; + min.s32 %r6, %r31, %r32; + setp.ge.s32 %p2, %r3, %r5; + setp.ge.s32 %p3, %r4, %r6; + or.pred %p4, %p2, %p3; + @%p4 bra $L__BB3_40; + + cvt.rn.f32.s32 %f116, %r3; + mul.ftz.f32 %f11, %f114, %f116; + cvt.rn.f32.s32 %f117, %r4; + mul.ftz.f32 %f12, %f115, %f117; + cvt.rn.f32.s32 %f118, %r5; + mul.ftz.f32 %f13, %f114, %f118; + cvt.rn.f32.s32 %f119, %r6; + mul.ftz.f32 %f14, %f115, %f119; + abs.ftz.f32 %f15, %f109; + setp.gt.ftz.f32 %p5, %f15, 0f322BCC77; + @%p5 bra $L__BB3_3; + bra.uni $L__BB3_2; + +$L__BB3_3: + sub.ftz.f32 %f120, %f11, %f106; + div.approx.ftz.f32 %f121, %f120, %f109; + sub.ftz.f32 %f122, %f13, %f106; + div.approx.ftz.f32 %f123, %f122, %f109; + setp.gt.ftz.f32 %p9, %f121, %f123; + selp.f32 %f124, %f123, %f121, %p9; + selp.f32 %f125, %f121, %f123, %p9; + max.ftz.f32 %f298, %f112, %f124; + min.ftz.f32 %f299, %f322, %f125; + bra.uni $L__BB3_4; + +$L__BB3_2: + setp.lt.ftz.f32 %p6, %f106, %f11; + setp.gt.ftz.f32 %p7, %f106, %f13; + or.pred %p8, %p6, %p7; + mov.f32 %f298, %f112; + mov.f32 %f299, %f322; + @%p8 bra $L__BB3_40; + +$L__BB3_4: + abs.ftz.f32 %f20, %f110; + setp.gt.ftz.f32 %p10, %f20, 0f322BCC77; + @%p10 bra $L__BB3_6; + bra.uni $L__BB3_5; + +$L__BB3_6: + sub.ftz.f32 %f126, %f12, %f107; + div.approx.ftz.f32 %f127, %f126, %f110; + sub.ftz.f32 %f128, %f14, %f107; + div.approx.ftz.f32 %f129, %f128, %f110; + setp.gt.ftz.f32 %p14, %f127, %f129; + selp.f32 %f130, %f129, %f127, %p14; + selp.f32 %f131, %f127, %f129, %p14; + max.ftz.f32 %f298, %f298, %f130; + min.ftz.f32 %f299, %f299, %f131; + bra.uni $L__BB3_7; + +$L__BB3_5: + setp.lt.ftz.f32 %p11, %f107, %f12; + setp.gt.ftz.f32 %p12, %f107, %f14; + or.pred %p13, %p11, %p12; + @%p13 bra $L__BB3_40; + +$L__BB3_7: + setp.gt.ftz.f32 %p15, %f298, %f299; + @%p15 bra $L__BB3_40; + + setp.leu.ftz.f32 %p16, %f15, 0f322BCC77; + fma.rn.ftz.f32 %f133, %f109, %f298, %f106; + fma.rn.ftz.f32 %f134, %f110, %f298, %f107; + div.approx.ftz.f32 %f135, %f133, %f114; + div.approx.ftz.f32 %f136, %f134, %f115; + cvt.rmi.ftz.f32.f32 %f137, %f135; + cvt.rzi.ftz.s32.f32 %r33, %f137; + cvt.rmi.ftz.f32.f32 %f138, %f136; + cvt.rzi.ftz.s32.f32 %r34, %f138; + max.s32 %r35, %r33, %r3; + add.s32 %r36, %r5, -1; + min.s32 %r71, %r35, %r36; + max.s32 %r37, %r34, %r4; + add.s32 %r38, %r6, -1; + min.s32 %r70, %r37, %r38; + mov.f32 %f303, 0f7149F2CA; + mov.f32 %f302, %f303; + @%p16 bra $L__BB3_10; + + div.approx.ftz.f32 %f139, %f114, %f109; + abs.ftz.f32 %f302, %f139; + +$L__BB3_10: + setp.leu.ftz.f32 %p17, %f20, 0f322BCC77; + @%p17 bra $L__BB3_12; + + div.approx.ftz.f32 %f141, %f115, %f110; + abs.ftz.f32 %f303, %f141; + +$L__BB3_12: + mov.f32 %f309, 0f7149F2CA; + mov.f32 %f310, %f309; + @%p16 bra $L__BB3_14; + + setp.ge.ftz.f32 %p19, %f109, 0f00000000; + selp.u32 %r39, 1, 0, %p19; + add.s32 %r40, %r71, %r39; + cvt.rn.f32.s32 %f143, %r40; + mul.ftz.f32 %f144, %f114, %f143; + sub.ftz.f32 %f145, %f144, %f106; + div.approx.ftz.f32 %f310, %f145, %f109; + +$L__BB3_14: + @%p17 bra $L__BB3_16; + + setp.ge.ftz.f32 %p21, %f110, 0f00000000; + selp.u32 %r41, 1, 0, %p21; + add.s32 %r42, %r70, %r41; + cvt.rn.f32.s32 %f147, %r42; + mul.ftz.f32 %f148, %f115, %f147; + sub.ftz.f32 %f149, %f148, %f107; + div.approx.ftz.f32 %f309, %f149, %f110; + +$L__BB3_16: + mul.lo.s32 %r43, %r2, %r2; + shl.b32 %r9, %r43, 1; + setp.lt.s32 %p23, %r9, -3; + mov.pred %p22, -1; + mov.f32 %f323, 0f00000000; + mov.f32 %f324, %f323; + mov.pred %p75, %p22; + @%p23 bra $L__BB3_38; + + ld.const.u64 %rd2, [params+48]; + cvta.to.global.u64 %rd1, %rd2; + ld.const.f32 %f33, [params+72]; + add.s32 %r10, %r9, 3; + setp.ge.ftz.f32 %p24, %f110, 0f00000000; + mov.f32 %f324, 0f00000000; + selp.b32 %r11, 1, -1, %p24; + setp.ge.ftz.f32 %p25, %f109, 0f00000000; + selp.b32 %r12, 1, -1, %p25; + mov.u32 %r69, 0; + mov.f32 %f323, %f324; + +$L__BB3_18: + setp.ge.s32 %p27, %r71, %r5; + setp.lt.s32 %p28, %r71, %r3; + or.pred %p29, %p28, %p27; + setp.lt.s32 %p30, %r70, %r4; + or.pred %p31, %p29, %p30; + setp.ge.s32 %p32, %r70, %r6; + or.pred %p33, %p32, %p31; + mov.pred %p75, %p22; + @%p33 bra $L__BB3_38; + + mul.lo.s32 %r45, %r22, %r70; + add.s32 %r46, %r45, %r71; + mul.wide.s32 %rd3, %r46, 4; + add.s64 %rd4, %rd1, %rd3; + ld.global.f32 %f155, [%rd4]; + mov.b32 %r47, %f155; + and.b32 %r48, %r47, 2139095040; + setp.ne.s32 %p34, %r48, 2139095040; + and.b32 %r49, %r47, 8388607; + setp.eq.s32 %p35, %r49, 0; + or.pred %p36, %p35, %p34; + selp.f32 %f156, %f155, 0f00000000, %p36; + mul.ftz.f32 %f40, %f33, %f156; + ld.global.f32 %f157, [%rd4+4]; + mov.b32 %r50, %f157; + and.b32 %r51, %r50, 2139095040; + setp.ne.s32 %p37, %r51, 2139095040; + and.b32 %r52, %r50, 8388607; + setp.eq.s32 %p38, %r52, 0; + or.pred %p39, %p38, %p37; + selp.f32 %f158, %f157, 0f00000000, %p39; + mul.ftz.f32 %f41, %f33, %f158; + add.s32 %r53, %r45, %r22; + add.s32 %r54, %r53, %r71; + mul.wide.s32 %rd5, %r54, 4; + add.s64 %rd6, %rd1, %rd5; + ld.global.f32 %f159, [%rd6]; + mov.b32 %r55, %f159; + and.b32 %r56, %r55, 2139095040; + setp.ne.s32 %p40, %r56, 2139095040; + and.b32 %r57, %r55, 8388607; + setp.eq.s32 %p41, %r57, 0; + or.pred %p42, %p41, %p40; + selp.f32 %f160, %f159, 0f00000000, %p42; + mul.ftz.f32 %f42, %f33, %f160; + ld.global.f32 %f161, [%rd6+4]; + mov.b32 %r58, %f161; + and.b32 %r59, %r58, 2139095040; + setp.ne.s32 %p43, %r59, 2139095040; + and.b32 %r60, %r58, 8388607; + setp.eq.s32 %p44, %r60, 0; + or.pred %p45, %p44, %p43; + selp.f32 %f162, %f161, 0f00000000, %p45; + mul.ftz.f32 %f43, %f33, %f162; + cvt.rn.f32.s32 %f44, %r71; + mul.ftz.f32 %f45, %f114, %f44; + cvt.rn.f32.s32 %f46, %r70; + mul.ftz.f32 %f47, %f115, %f46; + add.s32 %r61, %r71, 1; + cvt.rn.f32.s32 %f163, %r61; + mul.ftz.f32 %f48, %f114, %f163; + add.s32 %r62, %r70, 1; + cvt.rn.f32.s32 %f164, %r62; + mul.ftz.f32 %f165, %f115, %f164; + sub.ftz.f32 %f49, %f48, %f45; + sub.ftz.f32 %f50, %f47, %f47; + sub.ftz.f32 %f51, %f41, %f40; + sub.ftz.f32 %f52, %f45, %f45; + sub.ftz.f32 %f53, %f165, %f47; + sub.ftz.f32 %f54, %f42, %f40; + mul.ftz.f32 %f166, %f110, %f54; + mul.ftz.f32 %f55, %f111, %f53; + sub.ftz.f32 %f56, %f166, %f55; + mul.ftz.f32 %f167, %f109, %f54; + mul.ftz.f32 %f168, %f111, %f52; + sub.ftz.f32 %f169, %f167, %f168; + mul.ftz.f32 %f57, %f109, %f53; + mul.ftz.f32 %f170, %f110, %f52; + sub.ftz.f32 %f58, %f57, %f170; + neg.ftz.f32 %f59, %f169; + mul.ftz.f32 %f171, %f49, %f56; + mul.ftz.f32 %f172, %f50, %f169; + sub.ftz.f32 %f173, %f171, %f172; + fma.rn.ftz.f32 %f60, %f58, %f51, %f173; + setp.lt.ftz.f32 %p46, %f60, 0f322BCC77; + mov.u16 %rs23, 0; + @%p46 bra $L__BB3_23; + + mov.u16 %rs23, 0; + cvt.rn.f32.s32 %f279, %r70; + mul.ftz.f32 %f278, %f115, %f279; + cvt.rn.f32.s32 %f277, %r71; + mul.ftz.f32 %f276, %f114, %f277; + rcp.approx.ftz.f32 %f61, %f60; + sub.ftz.f32 %f62, %f106, %f276; + sub.ftz.f32 %f63, %f107, %f278; + mul.ftz.f32 %f174, %f63, %f59; + fma.rn.ftz.f32 %f175, %f62, %f56, %f174; + sub.ftz.f32 %f64, %f108, %f40; + fma.rn.ftz.f32 %f176, %f58, %f64, %f175; + mul.ftz.f32 %f65, %f176, %f61; + setp.lt.ftz.f32 %p47, %f65, 0f00000000; + setp.gt.ftz.f32 %p48, %f65, 0f3F800000; + or.pred %p49, %p47, %p48; + @%p49 bra $L__BB3_23; + + mov.u16 %rs23, 0; + cvt.rn.f32.s32 %f282, %r70; + mul.ftz.f32 %f281, %f115, %f282; + sub.ftz.f32 %f280, %f281, %f281; + mul.ftz.f32 %f177, %f63, %f51; + mul.ftz.f32 %f178, %f280, %f64; + sub.ftz.f32 %f66, %f177, %f178; + mul.ftz.f32 %f179, %f49, %f64; + mul.ftz.f32 %f180, %f62, %f51; + sub.ftz.f32 %f181, %f180, %f179; + mul.ftz.f32 %f182, %f63, %f49; + mul.ftz.f32 %f183, %f62, %f280; + sub.ftz.f32 %f67, %f183, %f182; + neg.ftz.f32 %f68, %f181; + mul.ftz.f32 %f184, %f109, %f66; + mul.ftz.f32 %f185, %f110, %f181; + sub.ftz.f32 %f186, %f184, %f185; + fma.rn.ftz.f32 %f187, %f111, %f67, %f186; + mul.ftz.f32 %f188, %f187, %f61; + setp.lt.ftz.f32 %p50, %f188, 0f00000000; + add.ftz.f32 %f189, %f65, %f188; + setp.gt.ftz.f32 %p51, %f189, 0f3F800000; + or.pred %p52, %p50, %p51; + @%p52 bra $L__BB3_23; + + cvt.rn.f32.s32 %f285, %r71; + mul.ftz.f32 %f284, %f114, %f285; + sub.ftz.f32 %f283, %f284, %f284; + mul.ftz.f32 %f190, %f53, %f68; + fma.rn.ftz.f32 %f191, %f283, %f66, %f190; + fma.rn.ftz.f32 %f192, %f67, %f54, %f191; + mul.ftz.f32 %f316, %f192, %f61; + setp.gt.ftz.f32 %p53, %f316, 0f00000000; + selp.u16 %rs23, 1, 0, %p53; + +$L__BB3_23: + mov.u16 %rs24, 0; + setp.eq.s16 %p54, %rs23, 0; + @%p54 bra $L__BB3_26; + + mov.u16 %rs24, 0; + setp.ltu.ftz.f32 %p55, %f316, %f112; + setp.geu.ftz.f32 %p56, %f316, %f322; + or.pred %p57, %p55, %p56; + @%p57 bra $L__BB3_26; + + cvt.rn.f32.s32 %f287, %r70; + cvt.rn.f32.s32 %f286, %r71; + fma.rn.ftz.f32 %f193, %f109, %f316, %f106; + fma.rn.ftz.f32 %f194, %f110, %f316, %f107; + div.approx.ftz.f32 %f195, %f193, %f114; + sub.ftz.f32 %f196, %f195, %f286; + div.approx.ftz.f32 %f197, %f194, %f115; + sub.ftz.f32 %f198, %f197, %f287; + mov.f32 %f199, 0f3F800000; + min.ftz.f32 %f200, %f199, %f196; + mov.f32 %f201, 0f00000000; + max.ftz.f32 %f202, %f201, %f200; + min.ftz.f32 %f203, %f199, %f198; + max.ftz.f32 %f204, %f201, %f203; + sub.ftz.f32 %f205, %f199, %f204; + mul.ftz.f32 %f206, %f51, %f205; + sub.ftz.f32 %f207, %f43, %f42; + fma.rn.ftz.f32 %f208, %f207, %f204, %f206; + div.approx.ftz.f32 %f209, %f208, %f114; + sub.ftz.f32 %f210, %f199, %f202; + mul.ftz.f32 %f211, %f54, %f210; + sub.ftz.f32 %f212, %f43, %f41; + fma.rn.ftz.f32 %f213, %f212, %f202, %f211; + div.approx.ftz.f32 %f214, %f213, %f115; + mul.ftz.f32 %f215, %f214, %f214; + fma.rn.ftz.f32 %f216, %f209, %f209, %f215; + add.ftz.f32 %f217, %f216, 0f3F800000; + rsqrt.approx.ftz.f32 %f218, %f217; + mul.ftz.f32 %f219, %f209, %f218; + neg.ftz.f32 %f323, %f219; + mul.ftz.f32 %f220, %f214, %f218; + neg.ftz.f32 %f324, %f220; + mov.u16 %rs24, 1; + mov.f32 %f322, %f316; + +$L__BB3_26: + add.s32 %r67, %r71, 1; + cvt.rn.f32.s32 %f291, %r67; + mul.ftz.f32 %f290, %f114, %f291; + cvt.rn.f32.s32 %f289, %r71; + mul.ftz.f32 %f288, %f114, %f289; + sub.ftz.f32 %f76, %f42, %f41; + mul.ftz.f32 %f221, %f110, %f76; + sub.ftz.f32 %f77, %f221, %f55; + mul.ftz.f32 %f222, %f109, %f76; + sub.ftz.f32 %f78, %f288, %f290; + mul.ftz.f32 %f223, %f111, %f78; + sub.ftz.f32 %f224, %f222, %f223; + mul.ftz.f32 %f225, %f110, %f78; + sub.ftz.f32 %f79, %f57, %f225; + neg.ftz.f32 %f80, %f224; + sub.ftz.f32 %f81, %f290, %f290; + mul.ftz.f32 %f226, %f81, %f77; + mul.ftz.f32 %f227, %f53, %f224; + sub.ftz.f32 %f228, %f226, %f227; + sub.ftz.f32 %f82, %f43, %f41; + fma.rn.ftz.f32 %f83, %f79, %f82, %f228; + setp.lt.ftz.f32 %p58, %f83, 0f322BCC77; + mov.u16 %rs25, 0; + @%p58 bra $L__BB3_30; + + mov.u16 %rs25, 0; + add.s32 %r68, %r71, 1; + cvt.rn.f32.s32 %f295, %r68; + mul.ftz.f32 %f294, %f114, %f295; + cvt.rn.f32.s32 %f293, %r70; + mul.ftz.f32 %f292, %f115, %f293; + rcp.approx.ftz.f32 %f84, %f83; + sub.ftz.f32 %f85, %f106, %f294; + sub.ftz.f32 %f86, %f107, %f292; + mul.ftz.f32 %f229, %f86, %f80; + fma.rn.ftz.f32 %f230, %f85, %f77, %f229; + sub.ftz.f32 %f87, %f108, %f41; + fma.rn.ftz.f32 %f231, %f79, %f87, %f230; + mul.ftz.f32 %f88, %f231, %f84; + setp.lt.ftz.f32 %p59, %f88, 0f00000000; + setp.gt.ftz.f32 %p60, %f88, 0f3F800000; + or.pred %p61, %p59, %p60; + @%p61 bra $L__BB3_30; + + mov.u16 %rs25, 0; + mul.ftz.f32 %f232, %f86, %f82; + mul.ftz.f32 %f233, %f53, %f87; + sub.ftz.f32 %f89, %f232, %f233; + mul.ftz.f32 %f234, %f81, %f87; + mul.ftz.f32 %f235, %f85, %f82; + sub.ftz.f32 %f236, %f235, %f234; + mul.ftz.f32 %f237, %f86, %f81; + mul.ftz.f32 %f238, %f85, %f53; + sub.ftz.f32 %f90, %f238, %f237; + neg.ftz.f32 %f91, %f236; + mul.ftz.f32 %f239, %f109, %f89; + mul.ftz.f32 %f240, %f110, %f236; + sub.ftz.f32 %f241, %f239, %f240; + fma.rn.ftz.f32 %f242, %f111, %f90, %f241; + mul.ftz.f32 %f243, %f84, %f242; + setp.lt.ftz.f32 %p62, %f243, 0f00000000; + add.ftz.f32 %f244, %f88, %f243; + setp.gt.ftz.f32 %p63, %f244, 0f3F800000; + or.pred %p64, %p62, %p63; + @%p64 bra $L__BB3_30; + + mul.ftz.f32 %f245, %f53, %f91; + fma.rn.ftz.f32 %f246, %f78, %f89, %f245; + fma.rn.ftz.f32 %f247, %f90, %f76, %f246; + mul.ftz.f32 %f316, %f84, %f247; + setp.gt.ftz.f32 %p65, %f316, 0f00000000; + selp.u16 %rs25, 1, 0, %p65; + +$L__BB3_30: + setp.eq.s16 %p66, %rs25, 0; + @%p66 bra $L__BB3_33; + + setp.ltu.ftz.f32 %p67, %f316, %f112; + setp.geu.ftz.f32 %p68, %f316, %f322; + or.pred %p69, %p67, %p68; + @%p69 bra $L__BB3_33; + + cvt.rn.f32.s32 %f297, %r70; + cvt.rn.f32.s32 %f296, %r71; + fma.rn.ftz.f32 %f248, %f109, %f316, %f106; + fma.rn.ftz.f32 %f249, %f110, %f316, %f107; + div.approx.ftz.f32 %f250, %f248, %f114; + sub.ftz.f32 %f251, %f250, %f296; + div.approx.ftz.f32 %f252, %f249, %f115; + sub.ftz.f32 %f253, %f252, %f297; + mov.f32 %f254, 0f3F800000; + min.ftz.f32 %f255, %f254, %f251; + mov.f32 %f256, 0f00000000; + max.ftz.f32 %f257, %f256, %f255; + min.ftz.f32 %f258, %f254, %f253; + max.ftz.f32 %f259, %f256, %f258; + sub.ftz.f32 %f260, %f254, %f259; + mul.ftz.f32 %f261, %f51, %f260; + sub.ftz.f32 %f262, %f43, %f42; + fma.rn.ftz.f32 %f263, %f262, %f259, %f261; + div.approx.ftz.f32 %f264, %f263, %f114; + sub.ftz.f32 %f265, %f254, %f257; + mul.ftz.f32 %f266, %f54, %f265; + fma.rn.ftz.f32 %f267, %f82, %f257, %f266; + div.approx.ftz.f32 %f268, %f267, %f115; + mul.ftz.f32 %f269, %f268, %f268; + fma.rn.ftz.f32 %f270, %f264, %f264, %f269; + add.ftz.f32 %f271, %f270, 0f3F800000; + rsqrt.approx.ftz.f32 %f272, %f271; + mul.ftz.f32 %f273, %f264, %f272; + neg.ftz.f32 %f323, %f273; + mul.ftz.f32 %f274, %f268, %f272; + neg.ftz.f32 %f324, %f274; + mov.u16 %rs24, 1; + mov.f32 %f322, %f316; + +$L__BB3_33: + setp.ne.s16 %p71, %rs24, 0; + mov.pred %p75, 0; + @%p71 bra $L__BB3_38; + + setp.lt.ftz.f32 %p72, %f310, %f309; + @%p72 bra $L__BB3_36; + bra.uni $L__BB3_35; + +$L__BB3_36: + add.s32 %r71, %r71, %r12; + add.ftz.f32 %f310, %f302, %f310; + bra.uni $L__BB3_37; + +$L__BB3_35: + add.s32 %r70, %r70, %r11; + add.ftz.f32 %f309, %f303, %f309; + +$L__BB3_37: + add.s32 %r20, %r69, 1; + setp.lt.s32 %p74, %r69, %r10; + mov.u32 %r69, %r20; + mov.pred %p75, %p22; + @%p74 bra $L__BB3_18; + +$L__BB3_38: + @%p75 bra $L__BB3_40; + + mov.b32 %r65, %f323; + mov.b32 %r66, %f324; + mov.u32 %r64, 0; + // begin inline asm + call (%r63), _optix_report_intersection_2, (%f322, %r64, %r65, %r66); + // end inline asm + +$L__BB3_40: + ret; + +} + // .globl __closesthit__heightfield +.visible .entry __closesthit__heightfield() +{ + .reg .pred %p<2>; + .reg .f32 %f<11>; + .reg .b32 %r<17>; + + + // begin inline asm + call (%f1), _optix_get_ray_tmax, (); + // end inline asm + // begin inline asm + call (%r1), _optix_get_attribute_0, (); + // end inline asm + mov.b32 %f2, %r1; + // begin inline asm + call (%r2), _optix_get_attribute_1, (); + // end inline asm + mov.b32 %f3, %r2; + mul.ftz.f32 %f4, %f2, %f2; + mov.f32 %f5, 0f3F800000; + sub.ftz.f32 %f6, %f5, %f4; + mul.ftz.f32 %f7, %f3, %f3; + sub.ftz.f32 %f8, %f6, %f7; + setp.lt.ftz.f32 %p1, %f8, 0f00000000; + selp.f32 %f9, 0f00000000, %f8, %p1; + sqrt.approx.ftz.f32 %f10, %f9; + mov.b32 %r4, %f1; + mov.u32 %r3, 0; + // begin inline asm + call _optix_set_payload, (%r3, %r4); + // end inline asm + mov.u32 %r5, 1; + // begin inline asm + call _optix_set_payload, (%r5, %r1); + // end inline asm + mov.u32 %r7, 2; + // begin inline asm + call _optix_set_payload, (%r7, %r2); + // end inline asm + mov.b32 %r10, %f10; + mov.u32 %r9, 3; // begin inline asm call _optix_set_payload, (%r9, %r10); // end inline asm - mov.b32 %r12, %f36; - mov.u32 %r11, 3; // begin inline asm - call _optix_set_payload, (%r11, %r12); + call (%r11), _optix_read_primitive_idx, (); // end inline asm - mov.u32 %r13, 4; + mov.u32 %r12, 4; // begin inline asm - call _optix_set_payload, (%r13, %r1); + call _optix_set_payload, (%r12, %r11); // end inline asm // begin inline asm - call (%r15), _optix_read_instance_id, (); + call (%r14), _optix_read_instance_id, (); // end inline asm - mov.u32 %r16, 5; + mov.u32 %r15, 5; // begin inline asm - call _optix_set_payload, (%r16, %r15); + call _optix_set_payload, (%r15, %r14); // end inline asm ret; diff --git a/rtxpy/mesh_store.py b/rtxpy/mesh_store.py new file mode 100644 index 0000000..d2fd772 --- /dev/null +++ b/rtxpy/mesh_store.py @@ -0,0 +1,314 @@ +"""Zarr-based mesh storage spatially partitioned to match DEM chunks. + +Stores mesh data (vertices + indices) inside a zarr store alongside the DEM, +partitioned into spatial chunks that align with the elevation grid chunks. +This makes the zarr a self-contained scene file and enables chunk-level mesh +loading — load only the meshes for the terrain region you're viewing. + +Supports both triangle meshes and curve geometries (B-spline tubes used for +roads and water features). + +Zarr layout:: + + scene.zarr/ + elevation/... # existing DEM + meshes/ # mesh group + .zattrs → {pixel_spacing, elevation_shape, elevation_chunks} + building/ + .zattrs → {color: [0.6, 0.6, 0.6, 1.0]} + 0_0/vertices, indices # chunk (row=0, col=0) + ... + road/ + .zattrs → {color: ..., type: "curve"} + 0_0/vertices, widths, indices + ... +""" + +import numpy as np +import zarr + + +def chunks_for_pixel_window(yi0, yi1, xi0, xi1, chunk_h, chunk_w): + """Map a pixel-coordinate window to overlapping chunk indices. + + Parameters + ---------- + yi0, yi1 : int + Row pixel range [yi0, yi1). + xi0, xi1 : int + Column pixel range [xi0, xi1). + chunk_h, chunk_w : int + Chunk size in pixels (rows, cols). + + Returns + ------- + list of (int, int) + List of (chunk_row, chunk_col) tuples that overlap the window. + """ + cr0 = max(yi0 // chunk_h, 0) + cr1 = (yi1 - 1) // chunk_h # inclusive + cc0 = max(xi0 // chunk_w, 0) + cc1 = (xi1 - 1) // chunk_w # inclusive + return [(cr, cc) for cr in range(cr0, cr1 + 1) + for cc in range(cc0, cc1 + 1)] + + +def save_meshes_to_zarr(zarr_path, meshes, colors, pixel_spacing, + elevation_shape, elevation_chunks, + curves=None): + """Save mesh geometries into a zarr store, spatially partitioned by chunk. + + Parameters + ---------- + zarr_path : str or Path + Path to an existing zarr store (opened in ``r+`` mode). + meshes : dict + ``{geometry_id: (vertices_flat, indices_flat)}`` where vertices is + float32 (x, y, z, x, y, z, ...) and indices is int32. + colors : dict + ``{geometry_id: (r, g, b) or (r, g, b, a)}``. + pixel_spacing : tuple of float + ``(pixel_spacing_x, pixel_spacing_y)`` — world-units per pixel. + elevation_shape : tuple of int + ``(H, W)`` of the DEM grid. + elevation_chunks : tuple of int + ``(chunk_h, chunk_w)`` of the DEM chunking. + curves : dict, optional + ``{geometry_id: (vertices_flat, widths_flat, indices_flat)}`` for + B-spline curve tube geometries (roads, water features). + """ + store = zarr.open(str(zarr_path), mode='r+') + + # Create or overwrite meshes group + mg = store.create_group('meshes', overwrite=True) + mg.attrs['pixel_spacing'] = list(pixel_spacing) + mg.attrs['elevation_shape'] = list(elevation_shape) + mg.attrs['elevation_chunks'] = list(elevation_chunks) + + psx, psy = pixel_spacing + chunk_h, chunk_w = elevation_chunks + + for gid, (verts, indices) in meshes.items(): + verts = np.asarray(verts, dtype=np.float32) + indices = np.asarray(indices, dtype=np.int32) + + # Color attribute + gg = mg.create_group(gid) + c = colors.get(gid, (0.6, 0.6, 0.6)) + gg.attrs['color'] = list(c) + + if len(indices) == 0: + continue + + # Compute triangle centroids in pixel coords + n_tris = len(indices) // 3 + tri_idx = indices.reshape(n_tris, 3) + verts_xyz = verts.reshape(-1, 3) + + # Centroid of each triangle + cx = (verts_xyz[tri_idx[:, 0], 0] + + verts_xyz[tri_idx[:, 1], 0] + + verts_xyz[tri_idx[:, 2], 0]) / 3.0 + cy = (verts_xyz[tri_idx[:, 0], 1] + + verts_xyz[tri_idx[:, 1], 1] + + verts_xyz[tri_idx[:, 2], 1]) / 3.0 + + # Convert world coords → pixel → chunk indices + px = cx / psx + py = cy / psy + chunk_rows = np.clip(py.astype(np.int64) // chunk_h, 0, + (elevation_shape[0] - 1) // chunk_h) + chunk_cols = np.clip(px.astype(np.int64) // chunk_w, 0, + (elevation_shape[1] - 1) // chunk_w) + + # Unique chunks + chunk_keys = chunk_rows * 100000 + chunk_cols # perfect hash for sane sizes + unique_keys = np.unique(chunk_keys) + + for uk in unique_keys: + cr = int(uk // 100000) + cc = int(uk % 100000) + mask = chunk_keys == uk + + # Extract triangles for this chunk + local_tri_idx = tri_idx[mask] # (n_local, 3) + + # Remap vertex indices — only keep referenced vertices + used_verts = np.unique(local_tri_idx.ravel()) + remap = np.empty(len(verts_xyz), dtype=np.int32) + remap[used_verts] = np.arange(len(used_verts), dtype=np.int32) + local_verts = verts_xyz[used_verts].ravel().astype(np.float32) + local_indices = remap[local_tri_idx.ravel()].astype(np.int32) + + chunk_grp = gg.create_group(f'{cr}_{cc}') + chunk_grp.create_array( + 'vertices', data=local_verts, + chunks=(len(local_verts),), + ) + chunk_grp.create_array( + 'indices', data=local_indices, + chunks=(len(local_indices),), + ) + + # --- Curve geometries (roads, water) --- + if curves: + for gid, (verts, widths, indices) in curves.items(): + verts = np.asarray(verts, dtype=np.float32) + widths = np.asarray(widths, dtype=np.float32) + indices = np.asarray(indices, dtype=np.int32) + + gg = mg.create_group(gid) + c = colors.get(gid, (0.6, 0.6, 0.6)) + gg.attrs['color'] = list(c) + gg.attrs['type'] = 'curve' + + if len(indices) == 0: + continue + + # Partition curve segments by control point centroid + verts_xyz = verts.reshape(-1, 3) + # Each segment starts at indices[i], uses 3 consecutive control + # points (quadratic B-spline). Use the middle control point as + # the spatial key. + mid_idx = np.clip(indices + 1, 0, len(verts_xyz) - 1) + cx = verts_xyz[mid_idx, 0] + cy = verts_xyz[mid_idx, 1] + + px = cx / psx + py = cy / psy + chunk_rows = np.clip(py.astype(np.int64) // chunk_h, 0, + (elevation_shape[0] - 1) // chunk_h) + chunk_cols = np.clip(px.astype(np.int64) // chunk_w, 0, + (elevation_shape[1] - 1) // chunk_w) + + chunk_keys = chunk_rows * 100000 + chunk_cols + unique_keys = np.unique(chunk_keys) + + for uk in unique_keys: + cr = int(uk // 100000) + cc = int(uk % 100000) + mask = chunk_keys == uk + + local_seg_idx = indices[mask] + # Each segment uses 3 control points: [i, i+1, i+2] + used = set() + for si in local_seg_idx: + used.update(range(si, min(si + 3, len(verts_xyz)))) + used_verts = np.array(sorted(used), dtype=np.int32) + remap = np.empty(len(verts_xyz), dtype=np.int32) + remap[used_verts] = np.arange(len(used_verts), dtype=np.int32) + + local_verts = verts_xyz[used_verts].ravel().astype(np.float32) + local_widths = widths[used_verts].astype(np.float32) + local_indices = remap[local_seg_idx].astype(np.int32) + + chunk_grp = gg.create_group(f'{cr}_{cc}') + chunk_grp.create_array( + 'vertices', data=local_verts, + chunks=(len(local_verts),), + ) + chunk_grp.create_array( + 'widths', data=local_widths, + chunks=(len(local_widths),), + ) + chunk_grp.create_array( + 'indices', data=local_indices, + chunks=(len(local_indices),), + ) + + total = len(meshes) + (len(curves) if curves else 0) + print(f"Saved {total} mesh geometries to {zarr_path}/meshes/") + + +def load_meshes_from_zarr(zarr_path, chunks=None): + """Load mesh geometries from a zarr store. + + Parameters + ---------- + zarr_path : str or Path + Path to the zarr store. + chunks : list of (int, int) or None + Specific chunk indices ``[(row, col), ...]`` to load. ``None`` loads + all chunks (full scene). + + Returns + ------- + meshes : dict + ``{geometry_id: (vertices_flat, indices_flat)}`` for triangle meshes. + colors : dict + ``{geometry_id: tuple}``. + meta : dict + ``{pixel_spacing, elevation_shape, elevation_chunks}`` from attrs. + curves : dict + ``{geometry_id: (vertices_flat, widths_flat, indices_flat)}`` for + curve geometries. Empty dict if none stored. + """ + store = zarr.open(str(zarr_path), mode='r', use_consolidated=False) + mg = store['meshes'] + + meta = { + 'pixel_spacing': tuple(mg.attrs['pixel_spacing']), + 'elevation_shape': tuple(mg.attrs['elevation_shape']), + 'elevation_chunks': tuple(mg.attrs['elevation_chunks']), + } + + # Build set of allowed chunk keys for filtering + chunk_set = None + if chunks is not None: + chunk_set = {f'{cr}_{cc}' for cr, cc in chunks} + + meshes = {} + curves = {} + colors = {} + + for gid in mg: + gg = mg[gid] + if not hasattr(gg, 'attrs'): + continue + colors[gid] = tuple(gg.attrs.get('color', (0.6, 0.6, 0.6))) + is_curve = gg.attrs.get('type', '') == 'curve' + + all_verts = [] + all_widths = [] + all_indices = [] + vert_offset = 0 + + # Iterate sub-groups (chunk keys like "0_0", "0_1", ...) + for key in sorted(gg): + if key in ('vertices', 'indices', 'widths'): + continue # skip if somehow present at geometry level + if chunk_set is not None and key not in chunk_set: + continue + + cg = gg[key] + v = np.array(cg['vertices'], dtype=np.float32) + idx = np.array(cg['indices'], dtype=np.int32) + + # Offset indices by accumulated vertex count + idx = idx + vert_offset + all_verts.append(v) + all_indices.append(idx) + vert_offset += len(v) // 3 + + if is_curve and 'widths' in cg: + all_widths.append(np.array(cg['widths'], dtype=np.float32)) + + if all_verts: + cat_verts = np.concatenate(all_verts) + cat_indices = np.concatenate(all_indices) + if is_curve and all_widths: + curves[gid] = (cat_verts, + np.concatenate(all_widths), + cat_indices) + else: + meshes[gid] = (cat_verts, cat_indices) + else: + if is_curve: + curves[gid] = (np.empty(0, dtype=np.float32), + np.empty(0, dtype=np.float32), + np.empty(0, dtype=np.int32)) + else: + meshes[gid] = (np.empty(0, dtype=np.float32), + np.empty(0, dtype=np.int32)) + + return meshes, colors, meta, curves diff --git a/rtxpy/notebook.py b/rtxpy/notebook.py new file mode 100644 index 0000000..c19cb1b --- /dev/null +++ b/rtxpy/notebook.py @@ -0,0 +1,407 @@ +"""Jupyter notebook integration for the interactive terrain viewer. + +Provides ``JupyterViewer``, a subclass of ``InteractiveViewer`` that +renders frames into an ``ipywidgets.Image`` widget with mouse/keyboard +input via ``ipyevents``. All GPU rendering, camera logic, overlays, +and keyboard shortcuts are inherited unchanged. +""" + +import io +import queue +import threading +import time +from typing import Optional, Tuple + +import numpy as np + +from .engine import InteractiveViewer + + +# --------------------------------------------------------------------------- +# Jupyter environment detection +# --------------------------------------------------------------------------- + +def _detect_jupyter() -> bool: + """Return True if running inside a Jupyter kernel.""" + try: + from IPython import get_ipython + ip = get_ipython() + if ip is None: + return False + return ip.__class__.__module__.startswith('ipykernel') + except ImportError: + return False + + +# --------------------------------------------------------------------------- +# Browser key → rtxpy key mapping +# --------------------------------------------------------------------------- + +_BROWSER_SPECIAL_KEYS = { + 'ArrowUp': 'up', + 'ArrowDown': 'down', + 'ArrowLeft': 'left', + 'ArrowRight': 'right', + 'PageUp': 'pageup', + 'PageDown': 'pagedown', + 'Escape': 'escape', + 'Equal': '=', + 'Minus': '-', + 'Comma': ',', + 'Period': '.', + 'BracketLeft': '[', + 'BracketRight': ']', + 'Semicolon': ';', + 'Quote': "'", +} + +# Shift variants for special keys (mirrors _glfw_to_key) +_SHIFT_SPECIAL = { + '=': '+', + '-': '_', + ';': ':', + "'": '"', +} + + +def _map_browser_key(event: dict) -> Tuple[str, str]: + """Convert a browser keyboard event to (raw_key, key_lower). + + Returns ('', '') for unmapped keys. + """ + key = event.get('key', '') + code = event.get('code', '') + shift = event.get('shiftKey', False) + + # Special keys (arrows, page up/down, etc.) + if code in _BROWSER_SPECIAL_KEYS: + raw = _BROWSER_SPECIAL_KEYS[code] + if shift and raw in _SHIFT_SPECIAL: + raw = _SHIFT_SPECIAL[raw] + return raw, raw.lower() + + # Letter keys + if len(key) == 1 and key.isalpha(): + lower = key.lower() + if shift: + return lower.upper(), lower + return lower, lower + + # Digit keys + if len(key) == 1 and key.isdigit(): + return key, key + + # Direct single-char keys (+, -, etc.) + if len(key) == 1: + return key, key.lower() + + return '', '' + + +# --------------------------------------------------------------------------- +# JavaScript for keyboard/scroll isolation +# --------------------------------------------------------------------------- + +# Injected into the notebook output to prevent keyboard events from +# reaching the notebook's own shortcut handlers (cell navigation, +# command-mode shortcuts like A/B/C/D/H/M/X, arrow scrolling, etc.). +# Also prevents scroll-wheel from scrolling the notebook page. +# +# When the widget is focused (blue border), all keyboard and wheel +# events are captured exclusively by the viewer. +_KEYBOARD_CAPTURE_JS = """ + +""" + + +# --------------------------------------------------------------------------- +# JupyterViewer +# --------------------------------------------------------------------------- + +class JupyterViewer(InteractiveViewer): + """Interactive terrain viewer for Jupyter notebooks. + + Inherits all rendering, camera, and input logic from + ``InteractiveViewer``. Overrides ``run()`` to display frames in + an ``ipywidgets.Image`` widget instead of a GLFW window. + """ + + def _render_help_text(self): + """Render help text, scaling to fit the render resolution.""" + super()._render_help_text() + if self._help_text_rgba is None: + return + ht = self._help_text_rgba + hh, hw = ht.shape[:2] + max_w = self.render_width - 16 + max_h = self.render_height - 16 + if hw > max_w or hh > max_h: + scale = min(max_w / hw, max_h / hh) + new_w = max(1, int(hw * scale)) + new_h = max(1, int(hh * scale)) + from PIL import Image + img = Image.fromarray((ht * 255).astype(np.uint8), 'RGBA') + img = img.resize((new_w, new_h), Image.LANCZOS) + self._help_text_rgba = np.array(img, dtype=np.float32) / 255.0 + + def run(self, start_position: Optional[Tuple[float, float, float]] = None, + look_at: Optional[Tuple[float, float, float]] = None): + """Start the viewer and return an interactive widget. + + Parameters + ---------- + start_position : tuple, optional + Starting camera position (x, y, z). + look_at : tuple, optional + Initial look-at point. + + Returns + ------- + ipywidgets.Image + The widget displaying rendered frames. The widget has a + ``_viewer`` attribute pointing back to this viewer and a + ``stop()`` method to shut down the render thread. + """ + import ipywidgets as widgets + from ipyevents import Event + from IPython.display import display, HTML + + H, W = self.terrain_shape + world_W = W * self.pixel_spacing_x + world_H = H * self.pixel_spacing_y + world_diag = np.sqrt(world_W**2 + world_H**2) + + if self.move_speed is None: + self.move_speed = world_diag * 0.01 + + if start_position is None: + start_position = ( + world_W / 2, + world_H * 1.05, + self.elev_max + world_diag * 0.08, + ) + + self.position = np.array(start_position, dtype=np.float32) + + if look_at is not None: + direction = np.array(look_at) - self.position + direction = direction / (np.linalg.norm(direction) + 1e-8) + self.yaw = np.degrees(np.arctan2(direction[1], direction[0])) + self.pitch = np.degrees(np.arcsin(np.clip(direction[2], -1, 1))) + else: + center = np.array([world_W / 2, world_H / 2, self.elev_mean]) + direction = center - self.position + direction = direction / (np.linalg.norm(direction) + 1e-8) + self.yaw = np.degrees(np.arctan2(direction[1], direction[0])) + self.pitch = np.degrees(np.arcsin(np.clip(direction[2], -1, 1))) + + # Pre-render overlays (help text scales to fit render resolution) + self._render_help_text() + self._compute_minimap_background() + + # --- Widget setup --- + self._widget = widgets.Image( + format='jpeg', + width=self.width, + height=self.height, + ) + self._widget.add_class('rtxpy-viewer') + self._widget._viewer = self + self._widget.stop = self.stop + + # --- Input event handling --- + self._input_queue = queue.Queue(maxsize=200) + + event_handler = Event( + source=self._widget, + watched_events=[ + 'keydown', 'keyup', + 'mousedown', 'mouseup', 'mousemove', + 'wheel', + ], + prevent_default_action=True, + wait=8, + ) + event_handler.on_dom_event(self._handle_dom_event) + self._event_handler = event_handler + + # --- State --- + self.running = True + self._display_frame = None + self._frame_dirty = False + self._render_needed = True + self._fps_counter = 0 + self._fps_last_time = time.monotonic() + self._last_tick_time = time.monotonic() + + # Render first frame synchronously so widget isn't blank + self._tick() + self._push_frame() + + # --- Background render thread --- + self._render_thread = threading.Thread( + target=self._jupyter_render_loop, + daemon=True, + name='rtxpy-jupyter-render', + ) + self._render_thread.start() + + print(f"rtxpy Jupyter viewer started ({self.width}x{self.height})") + print("Click the image to focus (blue border), then use keyboard/mouse.") + print("Press H for controls. Call widget.stop() to exit.") + + # Display widget + keyboard isolation JavaScript + display(self._widget) + display(HTML(_KEYBOARD_CAPTURE_JS)) + return self._widget + + def stop(self): + """Stop the render thread and release resources.""" + self.running = False + if hasattr(self, '_render_thread') and self._render_thread.is_alive(): + self._render_thread.join(timeout=2.0) + if self._tile_service is not None: + self._tile_service.shutdown() + + # --- DOM event handler (runs in Jupyter comm thread) --- + + def _handle_dom_event(self, event): + """Queue a browser DOM event for processing by the render thread.""" + try: + self._input_queue.put_nowait(event) + except queue.Full: + pass # drop event if queue is full + + # --- Render loop (background thread) --- + + def _jupyter_render_loop(self): + """Background thread: process input, tick, push frames.""" + target_period = 1.0 / 15 # ~15 FPS display rate + + while self.running: + loop_start = time.monotonic() + + # Drain input queue + while True: + try: + event = self._input_queue.get_nowait() + except queue.Empty: + break + self._dispatch_dom_event(event) + + # Process REPL command queue + while True: + try: + cmd = self._command_queue.get_nowait() + except queue.Empty: + break + try: + cmd(self) + except Exception: + import traceback + traceback.print_exc() + self._render_needed = True + + # Tick (movement, rendering) + self._tick() + + # Push frame to widget if dirty + if self._frame_dirty and self._display_frame is not None: + self._push_frame() + self._frame_dirty = False + + # Sleep to maintain target frame rate + elapsed = time.monotonic() - loop_start + sleep_time = target_period - elapsed + if sleep_time > 0: + time.sleep(sleep_time) + elif not self._held_keys and not self._mouse_dragging: + time.sleep(0.008) + + def _push_frame(self): + """Encode the current display frame as JPEG and update the widget.""" + frame = self._display_frame + if frame is None: + return + + from PIL import Image + + img_uint8 = (np.clip(frame, 0, 1) * 255).astype(np.uint8) + buf = io.BytesIO() + Image.fromarray(img_uint8).save(buf, format='JPEG', quality=85) + self._widget.value = buf.getvalue() + + # --- DOM event dispatch --- + + def _dispatch_dom_event(self, event): + """Route a browser DOM event to the appropriate handler.""" + etype = event.get('type', '') + + if etype == 'keydown': + raw_key, key_lower = _map_browser_key(event) + if raw_key: + self._handle_key_press(raw_key, key_lower) + + elif etype == 'keyup': + raw_key, key_lower = _map_browser_key(event) + if key_lower: + self._handle_key_release(key_lower) + + elif etype == 'mousedown': + button = event.get('button', 0) + x = event.get('offsetX', 0) + y = event.get('offsetY', 0) + self._handle_mouse_press(button, x, y) + + elif etype == 'mouseup': + button = event.get('button', 0) + self._handle_mouse_release(button) + + elif etype == 'mousemove': + x = event.get('offsetX', 0) + y = event.get('offsetY', 0) + self._handle_mouse_motion(x, y) + + elif etype == 'wheel': + dy = event.get('deltaY', 0) + # Browser deltaY: positive = scroll down + if dy != 0: + self._handle_scroll(-1 if dy > 0 else 1) diff --git a/rtxpy/quickstart.py b/rtxpy/quickstart.py index fca71a2..088129e 100644 --- a/rtxpy/quickstart.py +++ b/rtxpy/quickstart.py @@ -98,15 +98,19 @@ def quickstart( cacheable = {k: v for k, v in feat.items() if k not in _TEMPORAL} temporal = {k: v for k, v in feat.items() if k in _TEMPORAL} - # Check for mesh cache in zarr + # Check for mesh cache in zarr and which features it contains has_cache = False + cached_features = set() if cacheable: try: import zarr as _zarr store = _zarr.open(str(zarr_path), mode='r', use_consolidated=False) - has_cache = ('meshes' in store - and len(list(store['meshes'])) > 0) + if 'meshes' in store and len(list(store['meshes'])) > 0: + has_cache = True + # Read which feature keys were stored with this cache + cached_features = set( + store['meshes'].attrs.get('feature_keys', [])) del store except Exception: pass @@ -116,10 +120,40 @@ def quickstart( message="place_geojson called before") if has_cache: ds.rtx.load_meshes(zarr_path) + # Place any cacheable features missing from the cache. + # If the cache pre-dates feature_keys tracking, we only + # know which features were loaded by checking the scene's + # geometry IDs against known feature prefixes. + if cached_features: + missing = {k: v for k, v in cacheable.items() + if k not in cached_features} + else: + # Legacy cache without feature_keys metadata — + # check which features actually have geometries loaded + loaded_gids = set(ds.rtx._get_terrain_da( + list(ds.data_vars)[0]).rtx._rtx.list_geometries()) + missing = {k: v for k, v in cacheable.items() + if not _has_geometry_for_feature(k, loaded_gids)} + if missing: + names = ', '.join(missing) + print(f"Placing features not in cache: {names}") + _place_features(ds, missing, name, bounds, crs, cache_dir) + try: + ds.rtx.save_meshes(zarr_path) + _save_feature_keys(zarr_path, cacheable.keys()) + except Exception as e: + print(f"Could not update mesh cache: {e}") + elif not cached_features: + # Legacy cache is complete — stamp it with feature_keys + try: + _save_feature_keys(zarr_path, cacheable.keys()) + except Exception: + pass elif cacheable: _place_features(ds, cacheable, name, bounds, crs, cache_dir) try: ds.rtx.save_meshes(zarr_path) + _save_feature_keys(zarr_path, cacheable.keys()) except Exception as e: print(f"Could not save mesh cache: {e}") @@ -127,6 +161,25 @@ def quickstart( if temporal: _place_features(ds, temporal, name, bounds, crs, cache_dir) + # -- gtfs realtime -------------------------------------------------------- + # When loading from cache, _place_gtfs() is never called, so _gtfs_data + # is not set. Fetch the GTFS data directly for the realtime overlay. + gtfs_data = ds.attrs.pop('_gtfs_data', None) + if gtfs_data is None and 'gtfs' in feat: + gtfs_opts = feat['gtfs'] + try: + from .remote_data import fetch_gtfs as _fetch_gtfs + gtfs_data = _fetch_gtfs( + bounds=bounds, + feed_url=gtfs_opts.get('feed_url'), + gtfs_path=gtfs_opts.get('gtfs_path'), + route_types=gtfs_opts.get('route_types'), + cache_path=cache_dir / f"{name}_gtfs.json", + crs=crs, + realtime_url=gtfs_opts.get('realtime_url')) + except Exception as e: + print(f"Skipping GTFS realtime: {e}") + # -- wind ----------------------------------------------------------------- wind_data = None if wind: @@ -143,9 +196,6 @@ def quickstart( ) defaults.update(explore_kwargs) - # -- gtfs realtime ---------------------------------------------------------- - gtfs_data = ds.attrs.pop('_gtfs_data', None) - print("\nLaunching explore...\n") ds.rtx.explore(z='elevation', scene_zarr=zarr_path, wind_data=wind_data, gtfs_data=gtfs_data, **defaults) @@ -155,6 +205,32 @@ def quickstart( # Internal helpers # --------------------------------------------------------------------------- +def _has_geometry_for_feature(feature_key, loaded_gids): + """Check whether any geometry ID in the scene matches *feature_key*.""" + _PREFIXES = { + 'buildings': ('building_',), + 'roads': ('road_major', 'road_minor'), + 'water': ('water_',), + 'places': ('places',), + 'infrastructure': ('infrastructure',), + 'land_use': ('land_use',), + 'restaurant_grades': ('grade_a', 'grade_b', 'grade_c'), + 'gtfs': ('gtfs_',), + } + prefixes = _PREFIXES.get(feature_key, (feature_key,)) + return any( + any(gid == p or gid.startswith(p) for p in prefixes) + for gid in loaded_gids + ) + + +def _save_feature_keys(zarr_path, keys): + """Store feature keys in the zarr meshes group attributes.""" + import zarr as _zarr + store = _zarr.open(str(zarr_path), mode='r+', use_consolidated=False) + store['meshes'].attrs['feature_keys'] = sorted(keys) + + def _parse_features(features): """Normalize *features* to ``{key: {opts}}`` dict.""" if features is None: diff --git a/rtxpy/rtx.py b/rtxpy/rtx.py index b04f0d2..c56da70 100644 --- a/rtxpy/rtx.py +++ b/rtxpy/rtx.py @@ -2,7 +2,7 @@ RTXpy - Ray tracing using NVIDIA OptiX, accessible from Python. This module provides GPU-accelerated ray-triangle intersection using -NVIDIA's OptiX ray tracing engine via the otk-pyoptix Python bindings. +NVIDIA's OptiX ray tracing engine via the pyoptix-contrib Python bindings. """ import os @@ -14,6 +14,7 @@ # CRITICAL: cupy must be imported before optix for proper CUDA context sharing import cupy has_cupy = True +cupy.cuda.set_pinned_memory_allocator(cupy.cuda.PinnedMemoryPool().malloc) import optix @@ -39,6 +40,8 @@ class _GASEntry: visible: bool = True num_vertices: int = 0 num_triangles: int = 0 + is_curve: bool = False # True for round curve tube GAS + is_heightfield: bool = False # True for heightfield custom primitive GAS # ----------------------------------------------------------------------------- @@ -62,15 +65,46 @@ def __init__(self): self.raygen_pg = None self.miss_pg = None self.hit_pg = None + self.curve_hit_pg = None # Hit group for round curve tubes + self.curve_module = None # Built-in IS module for curves + self.heightfield_hit_pg = None # Hit group for heightfield custom primitives self.sbt = None # Device memory for params (shared, overwritten before each trace) self.d_params = None + # Capability / version info (populated during _init_optix) + self.capabilities = None + + # Denoiser state + self.denoiser = None + self.d_denoiser_state = None + self.d_denoiser_scratch = None + self.d_denoiser_normals = None + self.d_denoiser_output = None + self.d_denoiser_albedo = None + self.d_denoiser_flow = None + self._denoiser_temporal = False + self.denoiser_width = 0 + self.denoiser_height = 0 + self._denoiser_failed = False + self.initialized = False def cleanup(self): """Release all OptiX and CUDA resources.""" + # Destroy denoiser + if self.denoiser is not None: + self.denoiser.destroy() + self.denoiser = None + self.d_denoiser_state = None + self.d_denoiser_scratch = None + self.d_denoiser_normals = None + self.d_denoiser_output = None + self.denoiser_width = 0 + self.denoiser_height = 0 + self._denoiser_failed = False + # Reset device tracking self.device_id = None @@ -80,6 +114,9 @@ def cleanup(self): # OptiX objects are automatically cleaned up by Python GC self.sbt = None self.pipeline = None + self.heightfield_hit_pg = None + self.curve_hit_pg = None + self.curve_module = None self.hit_pg = None self.miss_pg = None self.raygen_pg = None @@ -120,6 +157,16 @@ def __init__(self): self.instances_buffer = None self.single_gas_mode = True # False when multi-GAS active + # Heightfield state + self.heightfield_data = None # GPU buffer (cupy) for elevation array + self.hf_width = 0 + self.hf_height = 0 + self.hf_spacing_x = 0.0 + self.hf_spacing_y = 0.0 + self.hf_ve = 1.0 + self.hf_tile_size = 32 + self.hf_num_tiles_x = 0 + # Device buffers for CPU->GPU transfers (per-instance) self.d_rays = None self.d_rays_size = 0 @@ -140,6 +187,16 @@ def clear(self): self.gas_buffer = None self.current_hash = 0xFFFFFFFFFFFFFFFF + # Clear heightfield state + self.heightfield_data = None + self.hf_width = 0 + self.hf_height = 0 + self.hf_spacing_x = 0.0 + self.hf_spacing_y = 0.0 + self.hf_ve = 1.0 + self.hf_tile_size = 32 + self.hf_num_tiles_x = 0 + # Reset to single-GAS mode self.single_gas_mode = True @@ -251,6 +308,121 @@ def get_current_device() -> Optional[int]: return _state.device_id if _state.initialized else None +def _detect_capabilities(context) -> dict: + """Detect OptiX and hardware capabilities after context creation.""" + optix_version = optix.version() # (major, minor, micro) + + rtcore_version = context.getProperty( + optix.DEVICE_PROPERTY_RTCORE_VERSION + ) + + # CUDA driver version (e.g. 12080 → 12.8) + driver_version_int = cupy.cuda.runtime.driverGetVersion() + cuda_major = driver_version_int // 1000 + cuda_minor = (driver_version_int % 1000) // 10 + + # GPU compute capability + dev = cupy.cuda.Device() + props = cupy.cuda.runtime.getDeviceProperties(dev.id) + cc_major = props['major'] + cc_minor = props['minor'] + + # NVIDIA driver version from nvidia-smi (best-effort) + nvidia_driver = 'unknown' + try: + import subprocess + result = subprocess.run( + ['nvidia-smi', '--query-gpu=driver_version', + '--format=csv,noheader', '-i', str(dev.id)], + capture_output=True, text=True, timeout=5, + ) + if result.returncode == 0: + nvidia_driver = result.stdout.strip().split('\n')[0] + except Exception: + pass + + # Feature flags — prefer runtime device property queries (OptiX 9.1+) + optix_major = optix_version[0] if isinstance(optix_version, tuple) else 0 + has_optix9 = optix_major >= 9 + is_blackwell = cc_major >= 10 # sm_100+ = Blackwell + + # Query runtime cluster/coopvec support when OptiX 9+ + has_clusters = False + has_cooperative_vectors = False + cluster_limits = {} + if has_optix9 and hasattr(optix, 'DEVICE_PROPERTY_CLUSTER_ACCEL'): + try: + cluster_flags = context.getProperty( + optix.DEVICE_PROPERTY_CLUSTER_ACCEL) + has_clusters = bool( + cluster_flags + & int(optix.DEVICE_PROPERTY_CLUSTER_ACCEL_FLAG_STANDARD)) + if has_clusters: + cluster_limits = { + 'max_cluster_vertices': context.getProperty( + optix.DEVICE_PROPERTY_LIMIT_MAX_CLUSTER_VERTICES), + 'max_cluster_triangles': context.getProperty( + optix.DEVICE_PROPERTY_LIMIT_MAX_CLUSTER_TRIANGLES), + 'max_clusters_per_gas': context.getProperty( + optix.DEVICE_PROPERTY_LIMIT_MAX_CLUSTERS_PER_GAS), + } + except Exception: + pass + if has_optix9 and hasattr(optix, 'DEVICE_PROPERTY_COOP_VEC'): + try: + coop_flags = context.getProperty(optix.DEVICE_PROPERTY_COOP_VEC) + has_cooperative_vectors = bool( + coop_flags + & int(optix.DEVICE_PROPERTY_COOP_VEC_FLAG_STANDARD)) + except Exception: + pass + + return { + 'optix_version': optix_version, + 'optix_version_str': '.'.join(str(x) for x in optix_version) + if isinstance(optix_version, tuple) else str(optix_version), + 'rtcore_version': rtcore_version, + 'cuda_driver': f'{cuda_major}.{cuda_minor}', + 'nvidia_driver': nvidia_driver, + 'compute_capability': (cc_major, cc_minor), + 'gpu_name': (props['name'].decode('utf-8') + if isinstance(props['name'], bytes) else props['name']), + # Feature flags (runtime-detected) + 'has_clusters': has_clusters, + 'has_cooperative_vectors': has_cooperative_vectors, + 'has_hw_linear_curves': is_blackwell, + 'has_rocaps_curves': has_optix9, + 'has_round_quadratic_bspline': True, # always (OptiX 7.4+) + # Cluster limits (only populated when has_clusters=True) + **cluster_limits, + } + + +def get_capabilities() -> Optional[dict]: + """ + Get OptiX and hardware capability information. + + Returns None if RTX has not been initialized yet. Otherwise returns + a dict with keys: + + - ``optix_version``: tuple (major, minor, micro) + - ``optix_version_str``: e.g. ``'7.7.0'`` + - ``rtcore_version``: RT Core generation (e.g. 20 = 2nd gen) + - ``cuda_driver``: CUDA runtime driver version string + - ``nvidia_driver``: NVIDIA display driver version string + - ``compute_capability``: tuple (major, minor) + - ``gpu_name``: GPU device name string + - ``has_clusters``: OptiX 9+ cluster/mega-geometry BVH + - ``has_cooperative_vectors``: OptiX 9+ Tensor Core access (Blackwell) + - ``has_hw_linear_curves``: hardware-accelerated linear curves (Blackwell) + - ``has_rocaps_curves``: software Rocaps curve intersector (OptiX 9+) + - ``has_round_quadratic_bspline``: round curve tubes (always True) + """ + if not _state.initialized or _state.capabilities is None: + return None + return dict(_state.capabilities) + + # ----------------------------------------------------------------------------- # PTX loading # ----------------------------------------------------------------------------- @@ -328,23 +500,49 @@ def _init_optix(device: Optional[int] = None): ) ) + # Detect capabilities now that context exists + _state.capabilities = _detect_capabilities(_state.context) + caps = _state.capabilities + print(f"OptiX {caps['optix_version_str']} | " + f"RT Core {caps['rtcore_version']} | " + f"{caps['gpu_name']} (sm_{caps['compute_capability'][0]}" + f"{caps['compute_capability'][1]}) | " + f"Driver {caps['nvidia_driver']}") + # Load PTX and create module ptx_data = _load_ptx_file("kernel.ptx") + # Payload semantics: raygen reads after trace, CH+MS write, AH/IS unused + _sem = (int(optix.PAYLOAD_SEMANTICS_TRACE_CALLER_READ) + | int(optix.PAYLOAD_SEMANTICS_CH_WRITE) + | int(optix.PAYLOAD_SEMANTICS_MS_WRITE)) + payload_type = optix.PayloadType(payloadSemantics=[_sem] * 6) + module_options = optix.ModuleCompileOptions( maxRegisterCount=optix.COMPILE_DEFAULT_MAX_REGISTER_COUNT, optLevel=optix.COMPILE_OPTIMIZATION_DEFAULT, debugLevel=optix.COMPILE_DEBUG_LEVEL_MINIMAL, + payloadTypes=[payload_type], ) - pipeline_options = optix.PipelineCompileOptions( + _pco_kwargs = dict( usesMotionBlur=False, traversableGraphFlags=optix.TRAVERSABLE_GRAPH_FLAG_ALLOW_ANY, numPayloadValues=6, # t, nx, ny, nz, primitive_id, instance_id numAttributeValues=2, exceptionFlags=optix.EXCEPTION_FLAG_NONE, pipelineLaunchParamsVariableName="params", + usesPrimitiveTypeFlags=( + optix.PRIMITIVE_TYPE_FLAGS_TRIANGLE + | optix.PRIMITIVE_TYPE_FLAGS_CUSTOM + | (optix.PRIMITIVE_TYPE_FLAGS_ROUND_QUADRATIC_BSPLINE_ROCAPS + if _state.capabilities.get('has_rocaps_curves') + else optix.PRIMITIVE_TYPE_FLAGS_ROUND_QUADRATIC_BSPLINE) + ), ) + if _state.capabilities.get('has_clusters'): + _pco_kwargs['allowClusteredGeometry'] = 1 + pipeline_options = optix.PipelineCompileOptions(**_pco_kwargs) _state.module, log = _state.context.moduleCreate( module_options, @@ -375,7 +573,7 @@ def _init_optix(device: Optional[int] = None): ) _state.miss_pg = _state.miss_pg[0] - # Hit group (closest hit only) + # Hit group (closest hit only — triangles) hit_desc = optix.ProgramGroupDesc() hit_desc.hitgroupModuleCH = _state.module hit_desc.hitgroupEntryFunctionNameCH = "__closesthit__chit" @@ -385,12 +583,52 @@ def _init_optix(device: Optional[int] = None): ) _state.hit_pg = _state.hit_pg[0] + # Built-in IS module for curves (Rocaps if available, else standard) + _curve_prim_type = ( + optix.PRIMITIVE_TYPE_ROUND_QUADRATIC_BSPLINE_ROCAPS + if _state.capabilities.get('has_rocaps_curves') + else optix.PRIMITIVE_TYPE_ROUND_QUADRATIC_BSPLINE + ) + _curve_is_options = optix.BuiltinISOptions( + builtinISModuleType=_curve_prim_type, + usesMotionBlur=False, + ) + _state.curve_module = _state.context.builtinISModuleGet( + module_options, + pipeline_options, + _curve_is_options, + ) + + # Hit group for curves (same closest-hit, built-in IS for intersection) + curve_hit_desc = optix.ProgramGroupDesc() + curve_hit_desc.hitgroupModuleCH = _state.module + curve_hit_desc.hitgroupEntryFunctionNameCH = "__closesthit__chit" + curve_hit_desc.hitgroupModuleIS = _state.curve_module + _state.curve_hit_pg, log = _state.context.programGroupCreate( + [curve_hit_desc], + pg_options, + ) + _state.curve_hit_pg = _state.curve_hit_pg[0] + + # Hit group for heightfield custom primitives (custom IS + dedicated CH) + hf_hit_desc = optix.ProgramGroupDesc() + hf_hit_desc.hitgroupModuleCH = _state.module + hf_hit_desc.hitgroupEntryFunctionNameCH = "__closesthit__heightfield" + hf_hit_desc.hitgroupModuleIS = _state.module + hf_hit_desc.hitgroupEntryFunctionNameIS = "__intersection__heightfield" + _state.heightfield_hit_pg, log = _state.context.programGroupCreate( + [hf_hit_desc], + pg_options, + ) + _state.heightfield_hit_pg = _state.heightfield_hit_pg[0] + # Create pipeline link_options = optix.PipelineLinkOptions( maxTraceDepth=1, ) - program_groups = [_state.raygen_pg, _state.miss_pg, _state.hit_pg] + program_groups = [_state.raygen_pg, _state.miss_pg, _state.hit_pg, + _state.curve_hit_pg, _state.heightfield_hit_pg] _state.pipeline = _state.context.pipelineCreate( pipeline_options, link_options, @@ -420,8 +658,8 @@ def _init_optix(device: Optional[int] = None): # Create shader binding table _create_sbt() - # Allocate params buffer (24 bytes: handle(8) + rays_ptr(8) + hits_ptr(8)) - _state.d_params = cupy.zeros(40, dtype=cupy.uint8) # 5 pointers * 8 bytes + # Allocate params buffer: 48 (existing) + 40 (heightfield fields) = 88 + _state.d_params = cupy.zeros(88, dtype=cupy.uint8) _state.initialized = True atexit.register(_cleanup_at_exit) @@ -445,10 +683,19 @@ def _create_sbt(): optix.sbtRecordPackHeader(_state.miss_pg, miss_record) d_miss = cupy.array(np.frombuffer(miss_record, dtype=np.uint8)) - # Pack hit group record + # Pack hit group records: [0] = triangles, [1] = curves, [2] = heightfield hit_record = bytearray(header_size) optix.sbtRecordPackHeader(_state.hit_pg, hit_record) - d_hit = cupy.array(np.frombuffer(hit_record, dtype=np.uint8)) + + curve_hit_record = bytearray(header_size) + optix.sbtRecordPackHeader(_state.curve_hit_pg, curve_hit_record) + + hf_hit_record = bytearray(header_size) + optix.sbtRecordPackHeader(_state.heightfield_hit_pg, hf_hit_record) + + # Concatenate all hit records into a single buffer + hit_all = bytearray(hit_record) + bytearray(curve_hit_record) + bytearray(hf_hit_record) + d_hit = cupy.array(np.frombuffer(hit_all, dtype=np.uint8)) _state.sbt = optix.ShaderBindingTable( raygenRecord=d_raygen.data.ptr, @@ -457,7 +704,7 @@ def _create_sbt(): missRecordCount=1, hitgroupRecordBase=d_hit.data.ptr, hitgroupRecordStrideInBytes=header_size, - hitgroupRecordCount=1, + hitgroupRecordCount=3, ) # Keep references to prevent garbage collection @@ -567,6 +814,549 @@ def _build_gas_for_geometry(vertices, indices): return gas_handle, gas_buffer +def _build_gas_clustered(vertices, indices, grid_H, grid_W): + """ + Build a GAS via OptiX 9 Cluster Acceleration Structures (CLAS). + + The terrain grid is partitioned into spatial blocks of up to BLOCK×BLOCK + cells. Each block becomes one CLAS (cluster). All clusters are then + assembled into a single GAS. + + Args: + vertices: Vertex buffer (Nx3 float32, flattened) — already on GPU or host. + indices: Index buffer (Mx3 int32, flattened). + grid_H: Number of vertex rows in the terrain grid. + grid_W: Number of vertex columns in the terrain grid. + + Returns: + Tuple of (gas_handle, gas_buffer) or (0, None) on error. + """ + global _state + + if not _state.initialized: + _init_optix() + + d_vertices = (vertices if isinstance(vertices, cupy.ndarray) + else cupy.asarray(vertices, dtype=cupy.float32)) + d_indices = (indices if isinstance(indices, cupy.ndarray) + else cupy.asarray(indices, dtype=cupy.int32)) + + num_vertices = d_vertices.size // 3 + num_triangles = d_indices.size // 3 + if num_vertices == 0 or num_triangles == 0: + return 0, None + + # -- Partition grid into spatial blocks -------------------------------- + max_cluster_v = min( + _state.capabilities.get('max_cluster_vertices', 256), 256) + max_cluster_t = min( + _state.capabilities.get('max_cluster_triangles', 256), 256) + + # Largest BLOCK so that (BLOCK+1)^2 ≤ max_verts AND 2*BLOCK^2 ≤ max_tris + import math + block_v = int(math.isqrt(max_cluster_v)) - 1 # vertices side + block_t = int(math.isqrt(max_cluster_t // 2)) # triangles side + BLOCK = max(1, min(block_v, block_t)) # cells per side + + cell_rows = grid_H - 1 + cell_cols = grid_W - 1 + blocks_r = (cell_rows + BLOCK - 1) // BLOCK + blocks_c = (cell_cols + BLOCK - 1) // BLOCK + num_clusters = blocks_r * blocks_c + + if num_clusters == 0: + return _build_gas_for_geometry(vertices, indices) + + # -- Build per-cluster Args structs on host, then upload --------------- + # TrianglesArgs layout (72 bytes — see optix_types.h): + # 0: clusterId u32 + # 4: clusterFlags u32 + # 8: packed bitfield u32 (triCount:9|vertCount:9|truncBits:6|idxFmt:4|ommFmt:4) + # 12: basePrimitiveInfo u32 (sbtIndex:24|reserved:5|primFlags:3) + # 16: indexStride u16 + # 18: vertexStride u16 + # 20: primInfoStride u16 + # 22: ommIdxStride u16 + # 24: indexBuffer u64 + # 32: vertexBuffer u64 + # 40: primitiveInfoBuffer u64 + # 48: opacityMicromapArray u64 + # 56: opacityMicromapIdxBuf u64 + # 64: instBBoxLimit u64 + ARGS_SIZE = 72 + + # We'll build small index sub-buffers per cluster (re-indexed) + # and collect them all into one large GPU buffer. + args_host = np.zeros(num_clusters * ARGS_SIZE, dtype=np.uint8) + h_indices = (d_indices.get() if isinstance(d_indices, cupy.ndarray) + else np.asarray(d_indices, dtype=np.int32)) + h_indices = h_indices.reshape(-1, 3) + + # Per-cluster re-indexed index arrays (host) + cluster_index_arrays = [] + max_tri_per_cluster = 0 + max_vert_per_cluster = 0 + + for br in range(blocks_r): + for bc in range(blocks_c): + cid = br * blocks_c + bc + r0 = br * BLOCK + c0 = bc * BLOCK + r1 = min(r0 + BLOCK, cell_rows) + c1 = min(c0 + BLOCK, cell_cols) + bH = r1 - r0 # cells in this block + bW = c1 - c0 + + # Vertex range for this block: rows [r0..r1] × cols [c0..c1] + v_min = r0 * grid_W + c0 + v_rows = bH + 1 + v_cols = bW + 1 + v_count = v_rows * v_cols + + # Gather triangle indices for this block + tri_list = [] + for lr in range(bH): + for lc in range(bW): + gr = r0 + lr + gc = c0 + lc + tri_idx = (gr * cell_cols + gc) * 2 + tri_list.append(tri_idx) + tri_list.append(tri_idx + 1) + tri_count = len(tri_list) + + # Re-index: map global vertex ids → local [0..v_count) + local_indices = np.empty(tri_count * 3, dtype=np.int32) + for i, ti in enumerate(tri_list): + for k in range(3): + gv = h_indices[ti, k] + # Map from global (row*W+col) to local block coords + gv_row = gv // grid_W - r0 + gv_col = gv % grid_W - c0 + local_indices[i * 3 + k] = gv_row * v_cols + gv_col + + cluster_index_arrays.append(local_indices) + max_tri_per_cluster = max(max_tri_per_cluster, tri_count) + max_vert_per_cluster = max(max_vert_per_cluster, v_count) + + # Pack the Args struct for this cluster + off = cid * ARGS_SIZE + # clusterId + struct.pack_into('= c1 or r0 >= r1: + # Degenerate tile — set a zero-volume AABB + base = tile_idx * 6 + aabbs[base:base + 6] = 0.0 + continue + + # Extract elevation tile (include +1 for cell corners) + tile_elev = elev_np[r0:r1 + 1, c0:c1 + 1] + valid = tile_elev[~np.isnan(tile_elev)] + if valid.size == 0: + z_min = 0.0 + z_max = 0.0 + else: + z_min = float(valid.min()) + z_max = float(valid.max()) + + z_min *= ve + z_max *= ve + + base = tile_idx * 6 + aabbs[base + 0] = c0 * spacing_x # min_x + aabbs[base + 1] = r0 * spacing_y # min_y + aabbs[base + 2] = z_min - eps # min_z + aabbs[base + 3] = c1 * spacing_x # max_x + aabbs[base + 4] = r1 * spacing_y # max_y + aabbs[base + 5] = z_max + eps # max_z + + d_aabbs = cupy.asarray(aabbs) + + # Build custom primitive GAS + build_input = optix.BuildInputCustomPrimitiveArray() + build_input.aabbBuffers = [d_aabbs.data.ptr] + build_input.numPrimitives = num_tiles + build_input.strideInBytes = 24 # 6 floats + build_input.flags = [optix.GEOMETRY_FLAG_DISABLE_ANYHIT] + build_input.numSbtRecords = 1 + + accel_options = optix.AccelBuildOptions( + buildFlags=optix.BUILD_FLAG_ALLOW_COMPACTION, + operation=optix.BUILD_OPERATION_BUILD, + ) + + buffer_sizes = _state.context.accelComputeMemoryUsage( + [accel_options], + [build_input], + ) + + d_temp = cupy.zeros(buffer_sizes.tempSizeInBytes, dtype=cupy.uint8) + gas_buffer = cupy.zeros(buffer_sizes.outputSizeInBytes, dtype=cupy.uint8) + compacted_size_buffer = cupy.zeros(1, dtype=cupy.uint64) + + gas_handle = _state.context.accelBuild( + 0, + [accel_options], + [build_input], + d_temp.data.ptr, + buffer_sizes.tempSizeInBytes, + gas_buffer.data.ptr, + buffer_sizes.outputSizeInBytes, + [optix.AccelEmitDesc(compacted_size_buffer.data.ptr, + optix.PROPERTY_TYPE_COMPACTED_SIZE)], + ) + + cupy.cuda.Stream.null.synchronize() + + compacted_size = int(compacted_size_buffer[0]) + if compacted_size < gas_buffer.nbytes: + compacted_buffer = cupy.zeros(compacted_size, dtype=cupy.uint8) + gas_handle = _state.context.accelCompact( + 0, + gas_handle, + compacted_buffer.data.ptr, + compacted_size, + ) + gas_buffer = compacted_buffer + + return gas_handle, gas_buffer, d_elevation, num_tiles_x, num_tiles_y + + def _build_ias(geom_state: _GeometryState): """ Build an Instance Acceleration Structure (IAS) from all GAS entries. @@ -613,8 +1403,14 @@ def _build_ias(geom_state: _GeometryState): # Pack instanceId (4 bytes) struct.pack_into('I', instances_data, offset + 48, i) - # Pack sbtOffset (4 bytes) - all use same hit group (SBT index 0) - struct.pack_into('I', instances_data, offset + 52, 0) + # Pack sbtOffset (4 bytes) - 0 for triangles, 1 for curves, 2 for heightfield + if entry.is_heightfield: + sbt_offset = 2 + elif entry.is_curve: + sbt_offset = 1 + else: + sbt_offset = 0 + struct.pack_into('I', instances_data, offset + 52, sbt_offset) # Pack visibilityMask (4 bytes) - 0xFF = visible, 0x00 = hidden mask = 0xFF if entry.visible else 0x00 @@ -795,7 +1591,7 @@ def _build_accel(geom_state: _GeometryState, hash_value: int, vertices, indices) # ----------------------------------------------------------------------------- def _trace_rays(geom_state: _GeometryState, rays, hits, num_rays: int, - primitive_ids=None, instance_ids=None) -> int: + primitive_ids=None, instance_ids=None, ray_flags=None) -> int: """ Trace rays against the acceleration structure in the given geometry state. @@ -812,6 +1608,9 @@ def _trace_rays(geom_state: _GeometryState, rays, hits, num_rays: int, instance_ids: Optional output buffer (Nx1 int32) for geometry/instance indices. -1 indicates a miss. Useful in multi-GAS mode to identify which geometry was hit. + ray_flags: Optional OptiX ray flags (unsigned int). Default is + OPTIX_RAY_FLAG_CULL_BACK_FACING_TRIANGLES (0x10). + Use RTX.RAY_FLAG_OCCLUSION for shadow/AO queries. Returns: 0 on success, non-zero on error @@ -895,14 +1694,50 @@ def _trace_rays(geom_state: _GeometryState, rays, hits, num_rays: int, inst_ids_on_host = True d_inst_ids_ptr = d_inst_ids.data.ptr - # Pack params: handle(8) + rays_ptr(8) + hits_ptr(8) + prim_ids_ptr(8) + inst_ids_ptr(8) + # Default ray flags: cull back-facing triangles + if ray_flags is None: + ray_flags = 0x10 # OPTIX_RAY_FLAG_CULL_BACK_FACING_TRIANGLES + + # Pack params: 48 bytes (existing) + 48 bytes (heightfield) = 96 bytes + # Heightfield fields: ptr(8) + width(4) + height(4) + sx(4) + sy(4) + ve(4) + tile(4) + ntx(4) + pad(4) = 40+8 pad? + # Actually: Q(8) + ii(8) + ff(8) + f(4) + ii(8) + i_pad(4) = 40 → need pad to 48 + hf_data_ptr = 0 + hf_w = 0 + hf_h = 0 + hf_sx = 0.0 + hf_sy = 0.0 + hf_ve = 1.0 + hf_tile = 0 + hf_ntx = 0 + + if geom_state.heightfield_data is not None: + hf_data_ptr = geom_state.heightfield_data.data.ptr + hf_w = geom_state.hf_width + hf_h = geom_state.hf_height + hf_sx = geom_state.hf_spacing_x + hf_sy = geom_state.hf_spacing_y + hf_ve = geom_state.hf_ve + hf_tile = geom_state.hf_tile_size + hf_ntx = geom_state.hf_num_tiles_x + params_data = struct.pack( - 'QQQQQ', - trace_handle, - d_rays.data.ptr, - d_hits.data.ptr, - d_prim_ids_ptr, - d_inst_ids_ptr, + 'QQQQQIIQiifffiii', + trace_handle, # 8 + d_rays.data.ptr, # 8 + d_hits.data.ptr, # 8 + d_prim_ids_ptr, # 8 + d_inst_ids_ptr, # 8 + ray_flags, # 4 + 0, # 4 padding (existing) + hf_data_ptr, # 8 + hf_w, # 4 + hf_h, # 4 + hf_sx, # 4 + hf_sy, # 4 + hf_ve, # 4 + hf_tile, # 4 + hf_ntx, # 4 + 0, # 4 padding ) _state.d_params[:] = cupy.frombuffer(np.frombuffer(params_data, dtype=np.uint8), dtype=cupy.uint8) @@ -911,7 +1746,7 @@ def _trace_rays(geom_state: _GeometryState, rays, hits, num_rays: int, _state.pipeline, 0, # stream _state.d_params.data.ptr, - 40, # sizeof(Params): 5 pointers * 8 bytes + 88, # sizeof(Params) _state.sbt, num_rays, # width 1, # height @@ -1014,7 +1849,15 @@ def getHash(self) -> int: """ return self._geom_state.current_hash - def trace(self, rays, hits, numRays: int, primitive_ids=None, instance_ids=None) -> int: + # OptiX ray flag constants + RAY_FLAG_NONE = 0x00 + RAY_FLAG_CULL_BACK_FACING = 0x10 # OPTIX_RAY_FLAG_CULL_BACK_FACING_TRIANGLES + RAY_FLAG_TERMINATE_ON_FIRST_HIT = 0x04 # OPTIX_RAY_FLAG_TERMINATE_ON_FIRST_HIT + # Combined flag for shadow/AO occlusion queries (early out + backface cull) + RAY_FLAG_OCCLUSION = 0x10 | 0x04 + + def trace(self, rays, hits, numRays: int, primitive_ids=None, instance_ids=None, + ray_flags=None) -> int: """ Trace rays against the current acceleration structure. @@ -1032,18 +1875,23 @@ def trace(self, rays, hits, numRays: int, primitive_ids=None, instance_ids=None) instance_ids: Optional output buffer (numRays x int32) for geometry/instance indices. Will contain the instance ID of the hit geometry, or -1 for misses. Useful in multi-GAS mode to identify which geometry was hit. + ray_flags: Optional OptiX ray flags (unsigned int). Default is + RAY_FLAG_CULL_BACK_FACING. Use RAY_FLAG_OCCLUSION for + shadow/AO queries to enable early termination. Returns: 0 on success, non-zero on error """ - return _trace_rays(self._geom_state, rays, hits, numRays, primitive_ids, instance_ids) + return _trace_rays(self._geom_state, rays, hits, numRays, primitive_ids, instance_ids, + ray_flags=ray_flags) # ------------------------------------------------------------------------- # Multi-GAS API # ------------------------------------------------------------------------- def add_geometry(self, geometry_id: str, vertices, indices, - transform: Optional[List[float]] = None) -> int: + transform: Optional[List[float]] = None, + grid_dims: Optional[tuple] = None) -> int: """ Add a geometry (GAS) to the scene with an optional transform. @@ -1057,6 +1905,9 @@ def add_geometry(self, geometry_id: str, vertices, indices, transform: Optional 12-float list representing a 3x4 row-major affine transform matrix. Defaults to identity. Format: [Xx, Xy, Xz, Tx, Yx, Yy, Yz, Ty, Zx, Zy, Zz, Tz] + grid_dims: Optional (H, W) grid dimensions for cluster-accelerated + builds. When provided and OptiX 9+ clusters are + available, uses the CLAS pipeline for faster BVH builds. Returns: 0 on success, non-zero on error @@ -1089,7 +1940,16 @@ def add_geometry(self, geometry_id: str, vertices, indices, return 0 # Build the GAS for this geometry - gas_handle, gas_buffer = _build_gas_for_geometry(vertices, indices) + use_clusters = ( + grid_dims is not None + and _state.capabilities + and _state.capabilities.get('has_clusters') + ) + if use_clusters: + gas_handle, gas_buffer = _build_gas_clustered( + vertices, indices, grid_dims[0], grid_dims[1]) + else: + gas_handle, gas_buffer = _build_gas_for_geometry(vertices, indices) if gas_handle == 0: return -1 @@ -1106,8 +1966,9 @@ def add_geometry(self, geometry_id: str, vertices, indices, return -1 # Compute vertex/triangle counts from input arrays - num_vertices = len(np.asarray(vertices).ravel()) // 3 - num_triangles = len(np.asarray(indices).ravel()) // 3 + num_vertices = len(vertices_for_hash.ravel()) // 3 + indices_np = indices.get() if isinstance(indices, cupy.ndarray) else np.asarray(indices) + num_triangles = len(indices_np.ravel()) // 3 # Create or update the GAS entry self._geom_state.gas_entries[geometry_id] = _GASEntry( @@ -1125,6 +1986,172 @@ def add_geometry(self, geometry_id: str, vertices, indices, return 0 + def add_curve_geometry(self, geometry_id: str, vertices, widths, + indices, + transform: Optional[List[float]] = None) -> int: + """ + Add round quadratic B-spline curve tubes to the scene. + + This enables multi-GAS mode. Curve GAS entries use a separate + hit group with the built-in curve IS module. + + Args: + geometry_id: Unique identifier for this geometry + vertices: Control point positions (N*3 float32, flattened) + widths: Per-control-point radii (N float32) + indices: Segment start indices (M int32, one per segment) + transform: Optional 12-float 3x4 row-major affine transform. + + Returns: + 0 on success, non-zero on error + """ + global _state + + if not _state.initialized: + _init_optix() + + # Switch to multi-GAS mode if currently in single-GAS mode + if self._geom_state.single_gas_mode: + self._geom_state.gas_handle = 0 + self._geom_state.gas_buffer = None + self._geom_state.current_hash = 0xFFFFFFFFFFFFFFFF + self._geom_state.single_gas_mode = False + + # Compute hash to skip GAS rebuild when vertices haven't changed + if isinstance(vertices, cupy.ndarray): + vertices_for_hash = vertices.get() + else: + vertices_for_hash = np.asarray(vertices) + vertices_hash = hash(vertices_for_hash.tobytes()) + + existing = self._geom_state.gas_entries.get(geometry_id) + if existing is not None and existing.vertices_hash == vertices_hash: + if transform is not None: + existing.transform = list(transform) + self._geom_state.ias_dirty = True + return 0 + + # Compute segment count from indices + indices_np = indices.get() if isinstance(indices, cupy.ndarray) else np.asarray(indices) + num_segments = len(indices_np) + + gas_handle, gas_buffer = _build_gas_for_curves( + vertices, widths, indices, num_segments) + if gas_handle == 0: + return -1 + + if transform is None: + transform = [ + 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, + ] + else: + transform = list(transform) + if len(transform) != 12: + return -1 + + num_vertices = len(vertices_for_hash.ravel()) // 3 + + self._geom_state.gas_entries[geometry_id] = _GASEntry( + gas_id=geometry_id, + gas_handle=gas_handle, + gas_buffer=gas_buffer, + vertices_hash=vertices_hash, + transform=transform, + num_vertices=num_vertices, + num_triangles=0, + is_curve=True, + ) + + self._geom_state.ias_dirty = True + return 0 + + def add_heightfield_geometry(self, geometry_id: str, elevation, + H: int, W: int, + spacing_x: float, spacing_y: float, + ve: float = 1.0, + tile_size: int = 32) -> int: + """ + Add a heightfield terrain as a custom-primitive GAS. + + The terrain is represented as a set of tiled AABBs. A custom + intersection program ray-marches through the grid at trace time, + never materializing an explicit triangle mesh. This dramatically + reduces GPU memory for large terrains and provides smooth bilinear + normals. + + Args: + geometry_id: Unique identifier (typically 'terrain'). + elevation: 2-D array (H, W) of float32 elevation values (numpy or cupy). + H: Number of rows. + W: Number of columns. + spacing_x: World-space pixel spacing in X. + spacing_y: World-space pixel spacing in Y. + ve: Vertical exaggeration. Default 1.0. + tile_size: Tile dimension for AABB grouping. Default 32. + + Returns: + 0 on success, non-zero on error. + """ + global _state + + if not _state.initialized: + _init_optix() + + # Switch to multi-GAS mode + if self._geom_state.single_gas_mode: + self._geom_state.gas_handle = 0 + self._geom_state.gas_buffer = None + self._geom_state.current_hash = 0xFFFFFFFFFFFFFFFF + self._geom_state.single_gas_mode = False + + # Get elevation as numpy + if hasattr(elevation, 'get'): + elev_np = elevation.get() + else: + elev_np = np.asarray(elevation, dtype=np.float32) + + gas_handle, gas_buffer, d_elevation, num_tiles_x, num_tiles_y = \ + _build_gas_for_heightfield(elev_np, H, W, spacing_x, spacing_y, ve, tile_size) + + if gas_handle == 0: + return -1 + + # Store heightfield metadata on geometry state for params packing + self._geom_state.heightfield_data = d_elevation + self._geom_state.hf_width = W + self._geom_state.hf_height = H + self._geom_state.hf_spacing_x = spacing_x + self._geom_state.hf_spacing_y = spacing_y + self._geom_state.hf_ve = ve + self._geom_state.hf_tile_size = tile_size + self._geom_state.hf_num_tiles_x = num_tiles_x + + # Identity transform + transform = [ + 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, + ] + + # Compute hash for cache invalidation + vertices_hash = hash(elev_np.tobytes()) + + self._geom_state.gas_entries[geometry_id] = _GASEntry( + gas_id=geometry_id, + gas_handle=gas_handle, + gas_buffer=gas_buffer, + vertices_hash=vertices_hash, + transform=transform, + num_vertices=0, + num_triangles=0, + is_heightfield=True, + ) + + self._geom_state.ias_dirty = True + return 0 + def remove_geometry(self, geometry_id: str) -> int: """ Remove a geometry from the scene. @@ -1300,3 +2327,194 @@ def memory_usage(self) -> dict: 'ray_buffers_bytes': ray_buffers_bytes, 'total_bytes': total_bytes, } + + +# ----------------------------------------------------------------------------- +# OptiX AI Denoiser +# ----------------------------------------------------------------------------- + +def _ensure_denoiser(width, height, temporal=False): + """Create or reconfigure the OptiX AI denoiser for the given dimensions. + + Parameters + ---------- + temporal : bool + If True, use DENOISER_MODEL_KIND_TEMPORAL (requires flow vectors). + If False, use DENOISER_MODEL_KIND_HDR (spatial only). + + Returns True if the denoiser is ready, False if unavailable. + """ + global _state + + if _state._denoiser_failed: + return False + + if not _state.initialized: + _init_optix() + + # Recreate denoiser if mode changed or not yet created + need_create = _state.denoiser is None + if not need_create and _state._denoiser_temporal != temporal: + _state.denoiser = None + _state.denoiser_width = 0 + _state.denoiser_height = 0 + need_create = True + + if need_create: + opts = optix.DenoiserOptions() + opts.guideNormal = 1 + opts.guideAlbedo = 1 + model = (optix.DENOISER_MODEL_KIND_TEMPORAL if temporal + else optix.DENOISER_MODEL_KIND_HDR) + try: + _state.denoiser = _state.context.denoiserCreate(model, opts) + except RuntimeError: + import warnings + warnings.warn( + "OptiX AI Denoiser unavailable (missing nvoptix.bin " + "weights file). Denoising will be skipped.", + RuntimeWarning) + _state._denoiser_failed = True + return False + _state._denoiser_temporal = temporal + + if _state.denoiser_width != width or _state.denoiser_height != height: + sizes = _state.denoiser.computeMemoryResources(width, height) + _state.d_denoiser_state = cupy.empty( + sizes.stateSizeInBytes, dtype=cupy.uint8) + _state.d_denoiser_scratch = cupy.empty( + sizes.withoutOverlapScratchSizeInBytes, dtype=cupy.uint8) + _state.denoiser.setup( + 0, # stream + width, height, + _state.d_denoiser_state.data.ptr, sizes.stateSizeInBytes, + _state.d_denoiser_scratch.data.ptr, + sizes.withoutOverlapScratchSizeInBytes) + _state.d_denoiser_normals = cupy.empty( + (height, width, 3), dtype=cupy.float32) + _state.d_denoiser_output = cupy.empty( + (height, width, 3), dtype=cupy.float32) + _state.d_denoiser_albedo = cupy.empty( + (height, width, 3), dtype=cupy.float32) + if temporal: + _state.d_denoiser_flow = cupy.zeros( + (height, width, 3), dtype=cupy.float32) + _state.denoiser_width = width + _state.denoiser_height = height + + return True + + +def denoise(d_color, d_normals, width, height, cam_right, cam_up, cam_forward, + albedo=None, flow=None): + """Apply the OptiX AI Denoiser to a noisy HDR image. + + Parameters + ---------- + d_color : cupy.ndarray + (height, width, 3) float32 HDR color buffer. Modified in-place + with denoised result. + d_normals : cupy.ndarray + (height, width, 3) float32 world-space hit normals. + width, height : int + Image dimensions. + cam_right, cam_up, cam_forward : array-like + Camera basis vectors (3,) for transforming normals to camera space. + albedo : cupy.ndarray, optional + (height, width, 3) float32 albedo guide (material color before lighting). + flow : cupy.ndarray, optional + (height, width, 2) float32 screen-space motion vectors (pixels). + If provided, temporal denoising is used. + """ + global _state + temporal = flow is not None + if not _ensure_denoiser(width, height, temporal=temporal): + return + + # Transform world-space normals to camera space via matrix multiply. + # Column matrix: columns = right, up, forward. + d_basis = cupy.asarray( + np.stack([np.asarray(cam_right, dtype=np.float32), + np.asarray(cam_up, dtype=np.float32), + np.asarray(cam_forward, dtype=np.float32)], axis=1), + dtype=cupy.float32) # (3, 3) + flat_normals = d_normals.reshape(-1, 3) + _state.d_denoiser_normals.reshape(-1, 3)[:] = flat_normals @ d_basis + + row_stride_3 = width * 3 * 4 # 3 float32 × 4 bytes + pixel_stride_3 = 3 * 4 + + color_image = optix.Image2D() + color_image.data = d_color.data.ptr + color_image.width = width + color_image.height = height + color_image.rowStrideInBytes = row_stride_3 + color_image.pixelStrideInBytes = pixel_stride_3 + color_image.format = optix.PIXEL_FORMAT_FLOAT3 + + output_image = optix.Image2D() + output_image.data = _state.d_denoiser_output.data.ptr + output_image.width = width + output_image.height = height + output_image.rowStrideInBytes = row_stride_3 + output_image.pixelStrideInBytes = pixel_stride_3 + output_image.format = optix.PIXEL_FORMAT_FLOAT3 + + normal_image = optix.Image2D() + normal_image.data = _state.d_denoiser_normals.data.ptr + normal_image.width = width + normal_image.height = height + normal_image.rowStrideInBytes = row_stride_3 + normal_image.pixelStrideInBytes = pixel_stride_3 + normal_image.format = optix.PIXEL_FORMAT_FLOAT3 + + layer = optix.DenoiserLayer() + layer.input = color_image + layer.output = output_image + + guide = optix.DenoiserGuideLayer() + guide.normal = normal_image + + # Albedo guide + if albedo is not None: + _state.d_denoiser_albedo[:] = albedo + albedo_image = optix.Image2D() + albedo_image.data = _state.d_denoiser_albedo.data.ptr + albedo_image.width = width + albedo_image.height = height + albedo_image.rowStrideInBytes = row_stride_3 + albedo_image.pixelStrideInBytes = pixel_stride_3 + albedo_image.format = optix.PIXEL_FORMAT_FLOAT3 + guide.albedo = albedo_image + + # Flow guide (temporal denoising) + if temporal: + # Flow is (H, W, 2) — copy into padded (H, W, 3) buffer for FLOAT3 format + _state.d_denoiser_flow[:, :, :2] = flow + flow_image = optix.Image2D() + flow_image.data = _state.d_denoiser_flow.data.ptr + flow_image.width = width + flow_image.height = height + flow_image.rowStrideInBytes = row_stride_3 + flow_image.pixelStrideInBytes = pixel_stride_3 + flow_image.format = optix.PIXEL_FORMAT_FLOAT3 + guide.flow = flow_image + + params = optix.DenoiserParams() + params.blendFactor = 0.0 + + _state.denoiser.invoke( + 0, # stream + params, + _state.d_denoiser_state.data.ptr, + _state.d_denoiser_state.nbytes, + guide, + layer, + 1, # numLayers + 0, 0, # inputOffsetX, inputOffsetY + _state.d_denoiser_scratch.data.ptr, + _state.d_denoiser_scratch.nbytes, + ) + + # Copy denoised result back into the input buffer + d_color[:] = _state.d_denoiser_output diff --git a/rtxpy/tiles.py b/rtxpy/tiles.py index 18c04c3..58caf1a 100644 --- a/rtxpy/tiles.py +++ b/rtxpy/tiles.py @@ -1,6 +1,6 @@ """XYZ map tile fetching and compositing service. -Downloads map tiles (satellite, street map, topo, etc.) from XYZ tile servers +Downloads map tiles (satellite, street map, etc.) from XYZ tile servers and composites them into an RGB texture that matches the terrain grid dimensions. Each tile is reprojected to the raster's native CRS so the imagery aligns with the terrain regardless of projection. Tiles stream in the background via a @@ -29,7 +29,6 @@ TILE_PROVIDERS = { 'osm': 'https://tile.openstreetmap.org/{z}/{x}/{y}.png', 'satellite': 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}', - 'topo': 'https://tile.opentopomap.org/{z}/{x}/{y}.png', } diff --git a/rtxpy/tour.py b/rtxpy/tour.py new file mode 100644 index 0000000..c871832 --- /dev/null +++ b/rtxpy/tour.py @@ -0,0 +1,403 @@ +"""Tour playback engine for the explore() interactive viewer. + +Defines keyframe-based camera tours with smooth interpolation and +optional frame recording for video assembly. + +Typical usage from the REPL:: + + tour = [ + {'time': 0, 'position': [100, 200, 50], 'yaw': 90, 'pitch': -20}, + {'time': 5, 'position': [300, 200, 80], 'yaw': 120, 'pitch': -30}, + {'time': 10, 'position': [300, 400, 60], 'yaw': 180, 'pitch': -25}, + ] + v.tour(tour) + v.tour(tour, record=True, output_dir='frames/') +""" + +import time +from pathlib import Path + +import numpy as np + + +# --------------------------------------------------------------------------- +# Easing functions +# --------------------------------------------------------------------------- + +def ease_linear(t): + """Linear interpolation (no easing).""" + return t + + +def ease_in_out(t): + """Smoothstep — gentle acceleration and deceleration.""" + return t * t * (3 - 2 * t) + + +def ease_in(t): + """Quadratic ease-in — slow start.""" + return t * t + + +def ease_out(t): + """Quadratic ease-out — slow finish.""" + return 1 - (1 - t) * (1 - t) + + +_EASING = { + 'linear': ease_linear, + 'ease_in_out': ease_in_out, + 'ease_in': ease_in, + 'ease_out': ease_out, +} + + +# --------------------------------------------------------------------------- +# Interpolation helpers +# --------------------------------------------------------------------------- + +def _lerp(a, b, t): + """Linear interpolation between scalars or arrays.""" + return a + (b - a) * t + + +def _lerp_angle(a, b, t): + """Interpolate angles (degrees) via the shortest arc.""" + diff = (b - a) % 360 + if diff > 180: + diff -= 360 + return a + diff * t + + +# --------------------------------------------------------------------------- +# Camera state capture +# --------------------------------------------------------------------------- + +def mark_camera(proxy): + """Capture the current camera state as a keyframe dict (without time). + + Returns a dict with ``position``, ``yaw``, ``pitch``, and ``fov``. + """ + return { + 'position': proxy.position.tolist(), + 'yaw': float(proxy.yaw), + 'pitch': float(proxy.pitch), + 'fov': proxy.run(lambda v: v.fov), + } + + +# --------------------------------------------------------------------------- +# Tour playback +# --------------------------------------------------------------------------- + +def play_tour(proxy, keyframes, fps=30, record=False, output_dir='.', + loop=False): + """Play a camera tour through the viewer. + + Parameters + ---------- + proxy : ViewerProxy + The ``v`` handle from ``explore(repl=True)``. + keyframes : list of dict + Each dict may contain: + + - ``time`` (float, required) — seconds from tour start. + - ``position`` (list[3]) — camera position ``[x, y, z]``. + - ``yaw``, ``pitch``, ``fov`` (float) — camera orientation/FOV. + - ``layer`` (str) — switch terrain layer. + - ``colormap`` (str) — switch colormap. + - ``geometry`` (str) — show only this geometry group. + - ``shadows`` (bool) — toggle shadows. + - ``screenshot`` (bool) — take a screenshot at this keyframe. + - ``ease`` (str) — easing function for interpolation arriving + at this keyframe (default ``'ease_in_out'``). + + Camera fields are interpolated between keyframes. Action + fields trigger once when the keyframe time is crossed. + fps : int + Target playback framerate (default 30). + record : bool + If True, save a frame after each interpolated step. + output_dir : str or Path + Directory for recorded frames (``frame_0001.png``, ...). + loop : bool + If True, repeat the tour indefinitely until the viewer closes. + """ + if not keyframes: + print("Tour: no keyframes") + return + + keyframes = sorted(keyframes, key=lambda k: k['time']) + duration = keyframes[-1]['time'] + dt = 1.0 / fps + + if record: + out = Path(output_dir) + out.mkdir(parents=True, exist_ok=True) + + cam_fields = ('position', 'yaw', 'pitch', 'fov') + + loop_label = " (looping)" if loop else "" + print(f"Tour: {len(keyframes)} keyframes, {duration:.1f}s " + f"@ {fps} fps{' (recording)' if record else ''}{loop_label}") + + frame_num = 0 + lap = 0 + + while True: + # Reset action triggers each lap + action_fired = [False] * len(keyframes) + t_start = time.monotonic() + t_tour = 0.0 + + while t_tour <= duration + 1e-9: + # Check if viewer is still alive + if not proxy._viewer.running: + if frame_num and record: + print(f"Tour stopped. {frame_num} frames saved " + f"to {output_dir}") + return + + # --- Interpolate camera state --- + cam_state = {} + for field in cam_fields: + defined = [(kf['time'], kf[field], + kf.get('ease', 'ease_in_out')) + for kf in keyframes if field in kf] + if not defined: + continue + + if t_tour <= defined[0][0]: + cam_state[field] = defined[0][1] + continue + if t_tour >= defined[-1][0]: + cam_state[field] = defined[-1][1] + continue + + for i in range(len(defined) - 1): + t0, v0, _ = defined[i] + t1, v1, ease_name = defined[i + 1] + if t0 <= t_tour <= t1: + raw_t = ((t_tour - t0) / (t1 - t0) + if t1 > t0 else 1.0) + ease_fn = _EASING.get(ease_name, ease_in_out) + t_eased = ease_fn(raw_t) + if field == 'position': + v0 = np.asarray(v0, dtype=np.float64) + v1 = np.asarray(v1, dtype=np.float64) + cam_state[field] = _lerp(v0, v1, t_eased) + elif field == 'yaw': + cam_state[field] = _lerp_angle( + v0, v1, t_eased) + else: + cam_state[field] = _lerp(v0, v1, t_eased) + break + + # Apply camera state on the render thread + if cam_state: + snapshot = dict(cam_state) + + def _apply(v, s=snapshot): + if 'position' in s: + v.position[:] = s['position'] + if 'yaw' in s: + v.yaw = s['yaw'] + if 'pitch' in s: + v.pitch = s['pitch'] + if 'fov' in s: + v.fov = s['fov'] + v._update_frame() + + proxy.run(_apply) + + # --- Fire action triggers --- + for i, kf in enumerate(keyframes): + if action_fired[i]: + continue + if t_tour >= kf['time']: + action_fired[i] = True + if 'layer' in kf: + proxy.show_layer(kf['layer']) + if 'colormap' in kf: + proxy.set_colormap(kf['colormap']) + if 'geometry' in kf: + proxy.show_geometry(kf['geometry']) + if 'shadows' in kf: + proxy.shadows = kf['shadows'] + if kf.get('screenshot'): + proxy.screenshot() + + # --- Record frame --- + if record: + frame_num += 1 + fname = out / f"frame_{frame_num:05d}.png" + + def _save(v, path=str(fname)): + from PIL import Image + frame = v._pinned_frame + if frame is not None: + rgb = np.clip(frame[:, :, :3] * 255, 0, 255 + ).astype(np.uint8) + img = Image.fromarray(rgb) + img.save(path) + + proxy.run(_save) + + # --- Timing --- + t_tour += dt + t_elapsed = time.monotonic() - t_start + sleep_time = t_tour - t_elapsed + if sleep_time > 0: + time.sleep(sleep_time) + + lap += 1 + if not loop: + break + + print(f"Tour complete. {frame_num} frames" + + (f" saved to {output_dir}" if record else "")) + + +# --------------------------------------------------------------------------- +# Interpolation helper (shared by camera tour and observer tour) +# --------------------------------------------------------------------------- + +def _interpolate_fields(keyframes, t_tour, fields): + """Interpolate keyframe fields at time *t_tour*. + + Returns a dict of interpolated values for each field found in + *keyframes*. + """ + state = {} + for field in fields: + defined = [(kf['time'], kf[field], + kf.get('ease', 'ease_in_out')) + for kf in keyframes if field in kf] + if not defined: + continue + + if t_tour <= defined[0][0]: + state[field] = defined[0][1] + continue + if t_tour >= defined[-1][0]: + state[field] = defined[-1][1] + continue + + for i in range(len(defined) - 1): + t0, v0, _ = defined[i] + t1, v1, ease_name = defined[i + 1] + if t0 <= t_tour <= t1: + raw_t = ((t_tour - t0) / (t1 - t0) + if t1 > t0 else 1.0) + ease_fn = _EASING.get(ease_name, ease_in_out) + t_eased = ease_fn(raw_t) + if field == 'position': + v0 = np.asarray(v0, dtype=np.float64) + v1 = np.asarray(v1, dtype=np.float64) + state[field] = _lerp(v0, v1, t_eased) + elif field == 'yaw': + state[field] = _lerp_angle(v0, v1, t_eased) + else: + state[field] = _lerp(v0, v1, t_eased) + break + return state + + +# --------------------------------------------------------------------------- +# Observer tour playback +# --------------------------------------------------------------------------- + +def play_observer_tour(proxy, slot, keyframes, fps=30, loop=False): + """Animate an observer drone along a keyframe path. + + Runs in a daemon thread. The observer's ``tour_stop`` event is + checked each tick for cooperative cancellation. + + Parameters + ---------- + proxy : ViewerProxy + The ``v`` handle from ``explore(repl=True)``. + slot : int + Observer slot (1-8). + keyframes : list of dict + Each dict may contain ``time``, ``position`` (x,y,z list), + ``yaw``, ``pitch``, ``observer_elev``. Only ``position`` + and ``time`` are required. + fps : int + Target playback framerate. + loop : bool + Repeat indefinitely until stopped. + """ + if not keyframes: + print(f"Observer {slot} tour: no keyframes") + return + + keyframes = sorted(keyframes, key=lambda k: k['time']) + duration = keyframes[-1]['time'] + dt = 1.0 / fps + + tour_fields = ('position', 'yaw', 'pitch', 'observer_elev') + + loop_label = " (looping)" if loop else "" + print(f"Observer {slot} tour: {len(keyframes)} keyframes, " + f"{duration:.1f}s @ {fps} fps{loop_label}") + + while True: + t_start = time.monotonic() + t_tour = 0.0 + + while t_tour <= duration + 1e-9: + # Check cancellation + obs = proxy._viewer._observers.get(slot) + if obs is None or obs.tour_stop.is_set(): + print(f"Observer {slot} tour stopped") + return + if not proxy._viewer.running: + return + + state = _interpolate_fields(keyframes, t_tour, tour_fields) + + if state: + snapshot = dict(state) + snap_slot = slot + + def _apply(v, s=snapshot, sl=snap_slot): + o = v._observers.get(sl) + if o is None: + return + if 'position' in s: + pos = s['position'] + o.position = (float(pos[0]), float(pos[1])) + if len(pos) > 2: + terrain_z = v._get_terrain_z(pos[0], pos[1]) + o.observer_elev = max(0.0, + float(pos[2]) - terrain_z) + if 'observer_elev' in s: + o.observer_elev = float(s['observer_elev']) + if 'yaw' in s: + o.yaw = float(s['yaw']) + if 'pitch' in s: + o.pitch = float(s['pitch']) + v._update_observer_drone_for(o) + # If this observer is in FPV and active, camera follows + if (o.drone_mode == 'fpv' + and v._active_observer == sl): + ox, oy = o.position + tz = v._get_terrain_z(ox, oy) + v.position[:] = [ox, oy, tz + o.observer_elev] + v.yaw = o.yaw + v.pitch = o.pitch + v._render_needed = True + + proxy._submit_fire_and_forget(_apply) + + t_tour += dt + t_elapsed = time.monotonic() - t_start + sleep_time = t_tour - t_elapsed + if sleep_time > 0: + time.sleep(sleep_time) + + if not loop: + break + + print(f"Observer {slot} tour complete") From c66b35c8d85bf5129c86dcc0c1574199829d8daf Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Wed, 18 Feb 2026 17:28:10 -0800 Subject: [PATCH 5/5] fixed terminal bug --- rtxpy/engine.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/rtxpy/engine.py b/rtxpy/engine.py index 4ce3ba4..c643d50 100644 --- a/rtxpy/engine.py +++ b/rtxpy/engine.py @@ -5536,6 +5536,15 @@ def run(self, start_position: Optional[Tuple[float, float, float]] = None, self.yaw = np.degrees(np.arctan2(direction[1], direction[0])) self.pitch = np.degrees(np.arcsin(np.clip(direction[2], -1, 1))) + # Save terminal state before GLFW (it can alter termios) + import sys + _saved_termios = None + try: + import termios + _saved_termios = termios.tcgetattr(sys.stdin.fileno()) + except (ImportError, termios.error, ValueError): + pass + # --- GLFW window creation --- if not glfw.init(): raise RuntimeError("Failed to initialise GLFW") @@ -5750,8 +5759,13 @@ def _run_repl(): glfw.destroy_window(window) glfw.terminate() self._glfw_window = None - # Reset terminal state (GLFW can hide cursor / alter termios) - import sys + # Restore terminal state (GLFW can disable echo / alter termios) + if _saved_termios is not None: + try: + import termios + termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW, _saved_termios) + except (ImportError, termios.error, ValueError): + pass sys.stdout.write('\033[?25h') # show cursor sys.stdout.flush()