diff --git a/LICENSE b/LICENSE index 07a2ad8da4d4..7dc7ec742c4c 100644 --- a/LICENSE +++ b/LICENSE @@ -360,3 +360,13 @@ Project page: https://github.com/SalesforceAIResearch/uni2ts License: https://github.com/SalesforceAIResearch/uni2ts/blob/main/LICENSE.txt -------------------------------------------------------------------------------- + +The following files include code modified from PatchTST project. + +./iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/* + +PatchTST is open source software licensed under the Apache License 2.0 +Project page: https://github.com/ibm-research/patchtst +License: https://github.com/ibm-research/patchtst/blob/main/LICENSE + +-------------------------------------------------------------------------------- diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index bf758a083d46..d54b2d6f947d 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -60,35 +60,36 @@ public class AINodeTestUtils { new AbstractMap.SimpleEntry<>( "moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")), new AbstractMap.SimpleEntry<>( - "toto", new FakeModelInfo("toto", "toto", "builtin", "active"))) + "toto", new FakeModelInfo("toto", "toto", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "patchtst_fm", new FakeModelInfo("patchtst_fm", "patchtst_fm", "builtin", "active"))) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); public static final Map BUILTIN_MODEL_MAP; static { - Map tmp = - Stream.of( - new AbstractMap.SimpleEntry<>( - "arima", new FakeModelInfo("arima", "sktime", "builtin", "active")), - new AbstractMap.SimpleEntry<>( - "holtwinters", new FakeModelInfo("holtwinters", "sktime", "builtin", "active")), - new AbstractMap.SimpleEntry<>( - "exponential_smoothing", - new FakeModelInfo("exponential_smoothing", "sktime", "builtin", "active")), - new AbstractMap.SimpleEntry<>( - "naive_forecaster", - new FakeModelInfo("naive_forecaster", "sktime", "builtin", "active")), - new AbstractMap.SimpleEntry<>( - "stl_forecaster", - new FakeModelInfo("stl_forecaster", "sktime", "builtin", "active")), - new AbstractMap.SimpleEntry<>( - "gaussian_hmm", - new FakeModelInfo("gaussian_hmm", "sktime", "builtin", "active")), - new AbstractMap.SimpleEntry<>( - "gmm_hmm", new FakeModelInfo("gmm_hmm", "sktime", "builtin", "active")), - new AbstractMap.SimpleEntry<>( - "stray", new FakeModelInfo("stray", "sktime", "builtin", "active"))) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + Map tmp = Stream.of( + new AbstractMap.SimpleEntry<>( + "arima", new FakeModelInfo("arima", "sktime", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "holtwinters", new FakeModelInfo("holtwinters", "sktime", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "exponential_smoothing", + new FakeModelInfo("exponential_smoothing", "sktime", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "naive_forecaster", + new FakeModelInfo("naive_forecaster", "sktime", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "stl_forecaster", + new FakeModelInfo("stl_forecaster", "sktime", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "gaussian_hmm", + new FakeModelInfo("gaussian_hmm", "sktime", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "gmm_hmm", new FakeModelInfo("gmm_hmm", "sktime", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "stray", new FakeModelInfo("stray", "sktime", "builtin", "active"))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); tmp.putAll(BUILTIN_LTSM_MAP); BUILTIN_MODEL_MAP = Collections.unmodifiableMap(tmp); } @@ -117,36 +118,35 @@ public static void concurrentInference( AtomicBoolean allPass = new AtomicBoolean(true); Thread[] threads = new Thread[threadCnt]; for (int i = 0; i < threadCnt; i++) { - threads[i] = - new Thread( - () -> { - try { - for (int j = 0; j < loop; j++) { - try (ResultSet resultSet = statement.executeQuery(sql)) { - int outputCnt = 0; - while (resultSet.next()) { - outputCnt++; - } - if (expectedOutputLength != outputCnt) { - allPass.set(false); - fail( - "Output count mismatch for SQL: " - + sql - + ". Expected: " - + expectedOutputLength - + ", but got: " - + outputCnt); - } - } catch (SQLException e) { - allPass.set(false); - fail(e.getMessage()); - } + threads[i] = new Thread( + () -> { + try { + for (int j = 0; j < loop; j++) { + try (ResultSet resultSet = statement.executeQuery(sql)) { + int outputCnt = 0; + while (resultSet.next()) { + outputCnt++; + } + if (expectedOutputLength != outputCnt) { + allPass.set(false); + fail( + "Output count mismatch for SQL: " + + sql + + ". Expected: " + + expectedOutputLength + + ", but got: " + + outputCnt); } - } catch (Exception e) { + } catch (SQLException e) { allPass.set(false); fail(e.getMessage()); } - }); + } + } catch (Exception e) { + allPass.set(false); + fail(e.getMessage()); + } + }); threads[i].start(); } for (Thread thread : threads) { @@ -164,8 +164,7 @@ public static void checkModelOnSpecifiedDevice(Statement statement, String model LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices); for (int retry = 0; retry < 200; retry++) { Set foundDevices = new HashSet<>(); - try (final ResultSet resultSet = - statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) { + try (final ResultSet resultSet = statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) { while (resultSet.next()) { String deviceId = resultSet.getString("DeviceId"); String loadedModelId = resultSet.getString("ModelId"); @@ -193,8 +192,7 @@ public static void checkModelNotOnSpecifiedDevice( LOGGER.info("Checking model: {} not on target devices: {}", modelId, targetDevices); for (int retry = 0; retry < 50; retry++) { Set foundDevices = new HashSet<>(); - try (final ResultSet resultSet = - statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) { + try (final ResultSet resultSet = statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) { while (resultSet.next()) { String deviceId = resultSet.getString("DeviceId"); String loadedModelId = resultSet.getString("ModelId"); @@ -215,16 +213,18 @@ public static void checkModelNotOnSpecifiedDevice( fail("Model " + modelId + " is still loaded on device " + device); } - private static final String[] WRITE_SQL_IN_TREE = - new String[] { - "CREATE DATABASE root.AI", - "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE", - }; + private static final String[] WRITE_SQL_IN_TREE = new String[] { + "CREATE DATABASE root.AI", + "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE", + }; - /** Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in tree. */ + /** + * Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of + * data in tree. + */ public static void prepareDataInTree() throws SQLException { prepareData(WRITE_SQL_IN_TREE); try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); @@ -238,7 +238,10 @@ public static void prepareDataInTree() throws SQLException { } } - /** Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in table. */ + /** + * Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data + * in table. + */ public static void prepareDataInTable() throws SQLException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index 642986c42d21..b9aee4fa5f16 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -30,6 +30,7 @@ def __init__( model_type: str = "", pipeline_cls: str = "", repo_id: str = "", + download_weights: bool = True, auto_map: Optional[Dict] = None, hub_mixin_cls: Optional[str] = None, transformers_registered: bool = False, @@ -40,9 +41,12 @@ def __init__( self.state = state self.pipeline_cls = pipeline_cls self.repo_id = repo_id - self.auto_map = auto_map + self.download_weights = download_weights + self.auto_map = auto_map # If exists, indicates it's a Transformers model self.hub_mixin_cls = hub_mixin_cls - self.transformers_registered = transformers_registered + self.transformers_registered = ( + transformers_registered # Internal flag: whether registered to Transformers + ) def __repr__(self): return ( @@ -173,4 +177,18 @@ def __repr__(self): }, transformers_registered=True, ), + "patchtst_fm": ModelInfo( + model_id="patchtst_fm", + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + model_type="patchtst_fm", + pipeline_cls="pipeline_patchtst_fm.PatchTSTFMPipeline", + repo_id="ibm-research/patchtst-fm-r1", + download_weights=False, + auto_map={ + "AutoConfig": "configuration_patchtst_fm.PatchTSTFMConfig", + "AutoModelForCausalLM": "modeling_patchtst_fm.PatchTSTFMForPrediction", + }, + transformers_registered=True, + ), } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index b7799df67a58..0cfb1d4cf41c 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -151,21 +151,28 @@ def _process_builtin_model_directory(self, model_dir: str, model_id: str): def _download_model_if_necessary() -> bool: """Returns: True if the model is existed or downloaded successfully, False otherwise.""" - repo_id = BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id].repo_id + model_info = BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id] + repo_id = model_info.repo_id weights_path = os.path.join(model_dir, MODEL_SAFETENSORS) config_path = os.path.join(model_dir, CONFIG_JSON) - if not os.path.exists(weights_path): - try: - hf_hub_download( - repo_id=repo_id, - filename=MODEL_SAFETENSORS, - local_dir=model_dir, - ) - except Exception as e: - logger.error( - f"Failed to download model weights from HuggingFace: {e}" - ) - return False + + if getattr(model_info, "download_weights", True): + if not os.path.exists(weights_path): + try: + hf_hub_download( + repo_id=repo_id, + filename=MODEL_SAFETENSORS, + local_dir=model_dir, + ) + except Exception as e: + logger.error( + f"Failed to download model weights from HuggingFace: {e}" + ) + return False + else: + logger.info( + f"Skipping weight download for {model_id} due to configuration." + ) if not os.path.exists(config_path): try: hf_hub_download( diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/__init__.py new file mode 100644 index 000000000000..b74423fa902c --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/basic.py b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/basic.py new file mode 100644 index 000000000000..0e1a79d6c1de --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/basic.py @@ -0,0 +1,362 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + + +from typing import Optional, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def make_attn_mask(query_pad: torch.Tensor, key_pad: torch.Tensor) -> torch.Tensor: + """ + Build an additive attention mask of shape (B, Q, K) from + query/key padding masks. + + Args: + query_pad: (B, Q) bool or 0/1 tensor. 1/True = padded query position. + key_pad: (B, K) bool or 0/1 tensor. 1/True = padded key position. + + Returns: + attn_mask: (B, Q, K) float tensor, where masked positions are -inf + and valid positions are 0.0 (for use with SDPA). + """ + # Ensure boolean + q_pad = query_pad.bool() # (B, Q) + k_pad = key_pad.bool() # (B, K) + + # A position (q, k) is invalid if *either* the query or key is padded + # Shape: (B, Q, K) + pad = q_pad.unsqueeze(-1) | k_pad.unsqueeze(-2) + + # Build float mask with -inf on padded positions, 0 elsewhere + attn_mask = torch.zeros_like(pad, dtype=torch.float32) + attn_mask.masked_fill_(pad, float("-inf")) + + return attn_mask + + +class MLP(nn.Module): + def __init__( + self, + in_dim, + out_dim, + hidden_dim=256, + num_hidden_layers=1, + dropout=0, + norm=False, + activation=nn.GELU(approximate="tanh"), + output_activation=nn.Identity(), + norm_layer=nn.LayerNorm, + ): + super().__init__() + layers = [] + layers.append(nn.Linear(in_dim, hidden_dim)) + # layers.append(norm_layer(hidden_dim) if norm else nn.Identity()) + layers.append(activation) + for _ in range(num_hidden_layers - 1): + layers.append(nn.Dropout(dropout)) + layers.append(norm_layer(hidden_dim) if norm else nn.Identity()) + layers.append(nn.Linear(hidden_dim, hidden_dim)) + layers.append(activation) + layers.append(nn.Dropout(dropout)) + layers.append(norm_layer(hidden_dim) if norm else nn.Identity()) + layers.append(nn.Linear(hidden_dim, out_dim)) + layers.append(output_activation) + self.layers = nn.Sequential(*layers) + # self.init_weights() + + def forward(self, x): + return self.layers(x) + + +class SwiGLU(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim=384, dropout=0): + super().__init__() + hidden_dim = round(hidden_dim * 2 / 3) + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.fc2 = nn.Linear(in_dim, hidden_dim) + self.fc3 = nn.Linear(hidden_dim, out_dim) + self.activation = nn.SiLU() + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.activation(self.fc2(x)) + return self.dropout(self.fc3(x)) + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: Type[nn.Module] = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward( + self, x: torch.Tensor, attn_mask: torch.Tensor | None = None + ) -> torch.Tensor: + if x.ndim == 3: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) # (B, num_heads, N, head_dim) + q, k = self.q_norm(q), self.k_norm(k) + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + attn_mask=attn_mask, + ) + x = x.transpose(1, 2).reshape(B, N, C) + elif x.ndim == 4: + B, M, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, M, N, 3, self.num_heads, self.head_dim) + .permute(3, 0, 4, 1, 2, 5) + ) + q, k, v = qkv.unbind(0) # (B, num_heads, M, N, head_dim) + q, k = self.q_norm(q), self.k_norm(k) + # print('q', q.shape, 'k', k.shape, 'v', v.shape, 'attn_mask', attn_mask.shape if attn_mask is not None else "None") + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + attn_mask=attn_mask.unsqueeze(1) if attn_mask is not None else None, + ) + x = x.permute(0, 2, 3, 1, 4).reshape(B, M, N, C) + else: + raise ValueError(f"Unsupported input dimension: {x.ndim}") + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossAttention(nn.Module): + def __init__( + self, + q_dim: int, # dim of x + kv_dim: Optional[int] = None, # dim of m (defaults to q_dim) + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: Type[nn.Module] = nn.LayerNorm, + ) -> None: + super().__init__() + kv_dim = kv_dim if kv_dim is not None else q_dim + assert q_dim % num_heads == 0, "q_dim must be divisible by num_heads" + + self.num_heads = num_heads + self.head_dim = q_dim // num_heads + + self.q = nn.Linear(q_dim, q_dim, bias=qkv_bias) + self.kv = nn.Linear( + kv_dim, 2 * q_dim, bias=qkv_bias + ) # produce k and v in the SAME head dim as q + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(q_dim, q_dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward( + self, + x: torch.Tensor, # (B, Nq, q_dim) + m: torch.Tensor, # (B, Nk, kv_dim) + attn_mask: Optional[ + torch.Tensor + ] = None, # broadcastable to (B, num_heads, Nq, Nk) or (Nq, Nk) + is_causal: bool = False, + ) -> torch.Tensor: + if x.ndim == 3: + B, Nq, Cq = x.shape + _, Nk, _ = m.shape + q = ( + self.q(x) + .reshape(B, Nq, self.num_heads, self.head_dim) + .permute(0, 2, 1, 3) + ) # (B, H, Nq, Hd) + kv = ( + self.kv(m) + .reshape(B, Nk, 2, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + k, v = kv.unbind(0) # (B, H, Nk, Hd) + q, k = self.q_norm(q), self.k_norm(k) + x = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0.0, + is_causal=is_causal, + ) # (B, H, Nq, Hd) + x = x.transpose(1, 2).reshape(B, Nq, Cq) # back to (B, Nq, q_dim) + elif x.ndim == 4: + B, M, Nq, Cq = x.shape + _, Nk, _ = m.shape + q = ( + self.q(x) + .reshape(B, M, Nq, self.num_heads, self.head_dim) + .permute(0, 3, 1, 2, 4) + ) # (B, H, M, Nq, Hd) + kv = ( + self.kv(m) + .reshape(B, Nk, 2, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + k, v = kv.unbind(0) # (B, H, Nk, Hd) + q, k = self.q_norm(q), self.k_norm(k) + x = F.scaled_dot_product_attention( + q, + k.unsqueeze(2), + v.unsqueeze(2), + attn_mask=attn_mask.unsqueeze(1) if attn_mask is not None else None, + dropout_p=self.attn_drop.p if self.training else 0.0, + is_causal=is_causal, + ) # (B, H, M, Nq, Hd) + x = x.permute(0, 2, 3, 1, 4).reshape(B, M, Nq, Cq) + else: + raise ValueError(f"Unsupported input dimension: {x.ndim}") + x = self.proj_drop(self.proj(x)) + return x + + +class TransformerBlock(nn.Module): + """ + A standard Transformer block. + """ + + def __init__( + self, + d_model, + num_heads, + mlp_ratio=4.0, + dropout=0.1, + norm_first=True, + norm_layer=nn.LayerNorm, + mlp_type="mlp", + ): + super().__init__() + self.norm_first = norm_first + self.norm1 = norm_layer(d_model, elementwise_affine=True, eps=1e-6) + self.attn = Attention( + d_model, num_heads, qkv_bias=True, attn_drop=dropout, proj_drop=dropout + ) + self.norm2 = norm_layer(d_model, elementwise_affine=True, eps=1e-6) + if mlp_type == "swiglu": + self.mlp = SwiGLU( + d_model, d_model, hidden_dim=int(mlp_ratio * d_model), dropout=dropout + ) + elif mlp_type == "mlp": + self.mlp = MLP( + in_dim=d_model, + out_dim=d_model, + hidden_dim=int(mlp_ratio * d_model), + dropout=dropout, + ) + else: + raise ValueError(f"Unsupported MLP type: {mlp_type}") + self.dropout = nn.Dropout(dropout) + + def forward(self, x, attn_mask=None): + if self.norm_first: + x = x + self.attn(self.norm1(x), attn_mask) + x = x + self.dropout(self.mlp(self.norm2(x))) + else: + x = self.norm1(x + self.attn(x, attn_mask)) + x = self.norm2(x + self.dropout(self.mlp(x))) + return x + + +class TransformerBlockCrossAttention(nn.Module): + def __init__( + self, + d_model, + num_heads, + d_cond=None, + mlp_ratio=4.0, + dropout=0.1, + norm_first=True, + norm_layer=nn.LayerNorm, + mlp_type="mlp", + ): + super().__init__() + d_cond = d_cond if d_cond is not None else d_model + self.norm_first = norm_first + self.norm1 = norm_layer(d_model, elementwise_affine=True, eps=1e-6) + self.attn = CrossAttention( + d_model, + d_cond, + num_heads, + qkv_bias=True, + attn_drop=dropout, + proj_drop=dropout, + ) + self.norm2 = norm_layer(d_model, elementwise_affine=True, eps=1e-6) + if mlp_type == "swiglu": + self.mlp = SwiGLU( + d_model, d_model, hidden_dim=int(mlp_ratio * d_model), dropout=dropout + ) + elif mlp_type == "mlp": + self.mlp = MLP( + in_dim=d_model, + out_dim=d_model, + hidden_dim=int(mlp_ratio * d_model), + dropout=dropout, + ) + else: + raise ValueError(f"Unsupported MLP type: {mlp_type}") + self.dropout = nn.Dropout(dropout) + + def forward(self, x, m, attn_mask=None): + if self.norm_first: + x = x + self.attn(self.norm1(x), m, attn_mask) + x = x + self.dropout(self.mlp(self.norm2(x))) + else: + x = self.norm1(x + self.attn(x, m, attn_mask)) + x = self.norm2(x + self.dropout(self.mlp(x))) + return x diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/configuration_patchtst_fm.py b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/configuration_patchtst_fm.py new file mode 100644 index 000000000000..d1d5be2bbb12 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/configuration_patchtst_fm.py @@ -0,0 +1,77 @@ +# Copyright contributors to the TSFM project +# + + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +"""PatchTST-FM model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +PATCHTSTFM_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class PatchTSTFMConfig(PretrainedConfig): + model_type = "patchtst_fm" + attribute_map = { + "hidden_size": "d_model", + "num_hidden_layers": "n_layer", + } + + # has_no_defaults_at_init = True + def __init__( + self, + context_length: int = 8192, + prediction_length: int = 64, + d_patch: int = 16, + d_model: int = 384, + n_head: int = 6, + n_layer: int = 6, + norm_first: bool = True, + pretrain_mask_ratio: float = 0.4, + pretrain_mask_cont: int = 8, + num_quantile: int = 99, + **kwargs, + ): + self.context_length = context_length + self.prediction_length = prediction_length + self.d_patch = d_patch + self.n_patch = int(context_length // d_patch) + self.d_model = d_model + self.n_head = n_head + self.n_layer = n_layer + self.norm_first = norm_first + self.pretrain_mask_ratio = pretrain_mask_ratio + self.pretrain_mask_cont = pretrain_mask_cont + self.num_quantile = num_quantile + + if num_quantile % 9 == 0: + quantiles = [ + i / (self.num_quantile + 1) for i in range(1, self.num_quantile + 1) + ] + else: + quantiles = [ + i / (self.num_quantile - 1) for i in range(1, self.num_quantile - 1) + ] + quantiles = [0.01] + quantiles + [0.99] + self.quantile_levels = quantiles + super().__init__(**kwargs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/modeling_patchtst_fm.py b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/modeling_patchtst_fm.py new file mode 100644 index 000000000000..934b1b4337ad --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/modeling_patchtst_fm.py @@ -0,0 +1,488 @@ +# Copyright contributors to the TSFM project +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +"""PatchTST-FM model implementation""" + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput, logging + +from .basic import ( + TransformerBlock, + make_attn_mask, +) +from .configuration_patchtst_fm import PatchTSTFMConfig +from .normalization import RevIN +from .tools import count_parameters + +logger = logging.get_logger(__name__) + + +class LearnedPositionalEmbedding(nn.Module): + def __init__(self, d_model, max_len=5000, type="add"): + super().__init__() + self.embedding = nn.Embedding(max_len, d_model) + self.type = type + + def forward(self, x): + positions = torch.arange(x.size(-2), device=x.device).unsqueeze(0) + pe = self.embedding(positions) + if x.ndim == 4: + pe = pe.unsqueeze(1) + if self.type == "add": + return x + pe + elif self.type == "mul": + return x * pe + else: + raise ValueError(f"Invalid type: {self.type}") + + +class ResidualBlock(nn.Module): + def __init__(self, d_in, d_out, d_hidden): + super().__init__() + + self.layer1 = nn.Linear(d_in, d_hidden) + self.layer2 = nn.Linear(d_hidden, d_out) + self.residual = nn.Linear(d_in, d_out) + self.activation = nn.Sigmoid() + + def forward(self, x): + return self.layer2(self.activation(self.layer1(x))) + self.residual(x) + + +class PatchTSTFMPreTrainedModel(PreTrainedModel): + # Weight initialization + config_class = PatchTSTFMConfig + base_model_prefix = "model" + main_input_name = "inputs" + supports_gradient_checkpointing = False + + +@dataclass +class PatchTSTFMModelOutput(ModelOutput): + loss_mask: torch.Tensor = None + normed_target: torch.Tensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + quantile_predictions: torch.FloatTensor = None + + +@dataclass +class PatchTSTFMPretrainingOutput(ModelOutput): + loss: torch.Tensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + quanitle_predictions: torch.Tensor = None + + +@dataclass +class PatchTSTFMPredictionOutput(ModelOutput): + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + quantile_predictions: torch.Tensor = None + + +class PatchTSTFMModel(PatchTSTFMPreTrainedModel): + def __init__(self, config: PatchTSTFMConfig): + super().__init__(config) + self.config = config + self.quantile_levels = config.quantile_levels + self.pos_embed = LearnedPositionalEmbedding( + d_model=config.d_model, max_len=config.n_patch, type="add" + ) + assert ( + config.d_model % config.n_head == 0 + ), "[QuantileDecoder] d_model must be divisible by n_head" + + self.blocks = nn.ModuleList( + [ + TransformerBlock( + config.d_model, + config.n_head, + mlp_ratio=4.0, + norm_first=True, + dropout=0.1, + ) + for _ in range(config.n_layer) + ] + ) + self.in_layer = ResidualBlock( + config.d_patch * 2, config.d_model, config.d_model + ) + self.out_layer = ResidualBlock( + config.d_model, config.d_patch * (config.num_quantile + 1), config.d_model + ) + + self.norm_fn = RevIN(dim=-1, std_min=1e-5, use_sinh=True) + + def model_summary(self): + s = "" + model_name = "PatchTST-FM" + s += f"{'=' * 5:<10} {model_name} {'=' * 5:>9}\n" + s += f"{'Transformer:':<20} {count_parameters(self.blocks)[0] / 1e6:>8.2f}M\n" + s += f"{'=' * 30}\n" + p = count_parameters(self) + s += f"{'Trainable:':<20} {p[1] / 1e6:>8.2f}M\n" + s += f"{'Frozen:':<20} {p[2] / 1e6:>8.2f}M\n" + s += f"{'Total:':<20} {p[0] / 1e6:>8.2f}M\n" + s += f"{'=' * 30}\n" + return s + + def forward( + self, + inputs: torch.Tensor, + pred_mask: torch.Tensor, + miss_mask: torch.Tensor, + pad_mask: torch.Tensor, + output_hidden_states: Optional[bool] = False, + return_loss: bool = True, + return_dict: Optional[bool] = None, + # **kwargs, + ) -> PatchTSTFMPretrainingOutput: + x = inputs.to(self.device) + pad_mask = pad_mask.to(self.device).bool() + pred_mask = pred_mask.to(self.device).bool() + miss_mask = miss_mask.to(self.device).bool() + if x.ndim > 2: + x = rearrange(x, "B N T -> (B N) T") + pad_mask = rearrange(pad_mask, "B N T -> (B N) T") + pred_mask = rearrange(pred_mask, "B N T -> (B N) T") + miss_mask = rearrange(miss_mask, "B N T -> (B N) T") + + B, T = x.shape + ts_mask = pred_mask | pad_mask | miss_mask + + x_target = self.norm_fn.fit_transform(x, mask=pred_mask | pad_mask | miss_mask) + x_input = torch.where(ts_mask, torch.zeros_like(x_target), x_target) + + x_patch = x_input.reshape(B, self.config.n_patch, self.config.d_patch) + mask_patch = ts_mask.reshape(B, self.config.n_patch, self.config.d_patch) + pad_patch_mask = ( + pad_mask.reshape(B, self.config.n_patch, self.config.d_patch) + .float() + .mean(dim=-1) + .gt(0.9) + ) + + q_pred, q_raw = self.decode( + x=x_patch, mask=mask_patch.float(), t_pad_mask=pad_patch_mask + ) + q_pred = q_pred.permute(0, 2, 3, 1) + + B, N, D, Q = q_pred.shape + q_pred = q_pred.reshape(B, N * D, Q) + + if output_hidden_states: + hidden_states = q_raw.reshape(B, N * D, Q) + else: + hidden_states = None + + # return here q_pred, loss_mask, and x_target + return PatchTSTFMModelOutput( + normed_target=x_target, + quantile_predictions=q_pred, + loss_mask=(pred_mask & ~pad_mask & ~miss_mask).float(), + hidden_states=hidden_states, + ) + + def decode(self, x, mask, t_pad_mask=None): + B, N, D = x.shape + # x = self.in_layer(torch.cat([x, t, 1 - mask], dim=-1)) + x = self.in_layer(torch.cat([x, 1 - mask], dim=-1)) + pad_attn_mask = make_attn_mask(t_pad_mask, t_pad_mask).unsqueeze(1) + + x = self.pos_embed(x) + for block in self.blocks: + x = block(x, pad_attn_mask) + x = self.out_layer(x) + q_raw = x.reshape( + B, N, self.config.num_quantile + 1, self.config.d_patch + ).permute(0, 2, 1, 3) + q = q_raw[:, 0, :, :].unsqueeze(1) + torch.cumsum( + F.softplus(q_raw[:, 1:, :, :]) / self.config.num_quantile, dim=1 + ) + return q, q_raw + + +class PatchTSTFMForPretraining(PatchTSTFMPreTrainedModel): + def __init__(self, config: PatchTSTFMConfig): + super().__init__(config) + + self.config = config + self.backbone = PatchTSTFMModel(config) + + # move all out_layer items here + + def forward( + self, + inputs: torch.Tensor, + pred_mask: torch.Tensor, + miss_mask: torch.Tensor, + pad_mask: torch.Tensor, + output_hidden_states: Optional[bool] = False, + return_loss: bool = True, + return_dict: Optional[bool] = None, + ) -> PatchTSTFMPretrainingOutput: + # move quantile logic here + + model_outputs = self.backbone( + inputs, + pred_mask=pred_mask, + miss_mask=miss_mask, + pad_mask=pad_mask, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + q_pred = model_outputs.quantile_predictions + x_target = model_outputs.normed_target + loss_mask = model_outputs.loss_mask + + if return_loss: + x_target = x_target.unsqueeze(-1) + quantiles = torch.tensor( + self.backbone.quantile_levels, device=x_target.device + ).view(1, 1, -1) + loss = 2 * torch.abs( + (x_target - q_pred) * ((x_target <= q_pred).float() - quantiles) + ) + loss = loss * loss_mask.unsqueeze(-1) + loss = loss.sum(dim=1) / torch.clamp( + loss_mask.sum(dim=1, keepdim=True), min=1 + ) + loss = loss.sum(dim=-1).mean() / math.sqrt(self.config.num_quantile) + else: + loss = None + + x_pred = q_pred.permute(0, 2, 1) + x_pred = self.backbone.norm_fn.inverse_transform(x_pred) + + return PatchTSTFMPretrainingOutput( + quantile_predictions=x_pred, + loss=loss, + hidden_states=model_outputs.hidden_states, + ) + + +class PatchTSTFMForPrediction(PatchTSTFMPreTrainedModel): + def __init__(self, config: PatchTSTFMConfig): + super().__init__(config) + + self.config = config + self.backbone = PatchTSTFMModel(config) + + def model_summary(self) -> str: + return self.backbone.model_summary() + + def forward( + self, + inputs: List[torch.Tensor] | torch.Tensor, + prediction_length: Optional[int] = None, + quantile_levels: Optional[List[float]] = None, + output_hidden_states: Optional[bool] = False, + return_loss: bool = True, + return_dict: Optional[bool] = None, + ): + forecast_len = ( + prediction_length if prediction_length else self.config.prediction_length + ) + + cl = self.config.context_length + ul = -1 + logger.info( + f"Context Len: {cl} | Forecast Len: {forecast_len} ", + ) + cl = [cl] * len(inputs) + fl = [ + max( + forecast_len, + ul, + self.config.d_patch * max(self.config.pretrain_mask_cont, 2), + ) + ] * len(inputs) + forecast_samples, hidden_states = self.forecast_single_step( + inputs, fl, context_len=cl, output_hidden_states=output_hidden_states + ) + forecast_samples = torch.stack(forecast_samples, dim=0)[:, :, :forecast_len] + + if quantile_levels is not None: + quantile_indices = [ + self.backbone.quantile_levels.index(q) for q in quantile_levels + ] + forecast_samples = forecast_samples[:, quantile_indices, :] + return PatchTSTFMPredictionOutput( + quantile_predictions=forecast_samples, hidden_states=hidden_states + ) + + def forecast_single_step( + self, + x: List[torch.Tensor], + forecast_len: List[int], + context_len: List[int], + output_hidden_states: Optional[bool] = False, + ): + """ + x: list of torch.Tensor of time series, can be of different lengths + """ + + inputs = [] + pad_mask = [] + pred_mask = [] + miss_mask = [] + ts_ends = [] + time_index = [] + sample_lengths = [] + + for x_i, c_i, f_i in zip(x, context_len, forecast_len): + c_i = min(x_i.shape[0] + f_i, c_i) + s_i = c_i - f_i + x_in = x_i[-s_i:] + pad_mask_i = torch.zeros_like(x_in) + miss_mask_i = torch.zeros_like(x_in) + x_in = torch.nan_to_num(x_in, nan=x_in.nanmean().item()) + pred_mask_i = torch.cat([torch.zeros_like(x_in), torch.ones(f_i)], dim=-1) + miss_mask_i = torch.cat([miss_mask_i, torch.zeros(f_i)], dim=-1) + pad_mask_i = torch.cat([pad_mask_i, torch.zeros(f_i)], dim=-1) + x_in = torch.cat([x_in, torch.ones(f_i) * x_in.nanmean().item()], dim=-1) + time_index_i = ( + torch.arange( + self.config.context_length - x_in.shape[-1] + 1, + self.config.context_length + 1, + ).float() + / self.config.context_length + ) + sample_len = x_in.shape[-1] + if sample_len == self.config.context_length: + inputs.append(x_in) + pred_mask.append(pred_mask_i) + pad_mask.append(pad_mask_i) + miss_mask.append(miss_mask_i) + time_index.append(time_index_i) + ts_ends.append(torch.tensor([0, sample_len]).float()) + sample_lengths.append(sample_len) + elif sample_len < self.config.context_length: # padding + left_pad = self.config.context_length - sample_len + inputs.append( + F.pad( + x_in, + (left_pad, 0), + mode="constant", + value=x_in.nanmean().item(), + ) + ) + pred_mask.append( + F.pad(pred_mask_i, (left_pad, 0), mode="constant", value=0.0) + ) + pad_mask.append( + F.pad(pad_mask_i, (left_pad, 0), mode="constant", value=1.0) + ) + miss_mask.append( + F.pad(miss_mask_i, (left_pad, 0), mode="constant", value=0.0) + ) + time_index.append( + F.pad(time_index_i, (left_pad, 0), mode="constant", value=-1) + ) + ts_ends.append(torch.tensor([left_pad, left_pad + sample_len]).float()) + sample_lengths.append(sample_len) + else: # subsample + inputs.append( + F.interpolate( + x_in.view(1, 1, -1), + size=self.config.context_length, + mode="nearest", + ).squeeze() + ) + pred_mask.append( + F.interpolate( + pred_mask_i.view(1, 1, -1), + size=self.config.context_length, + mode="nearest", + ).squeeze() + ) + pad_mask.append( + F.interpolate( + pad_mask_i.view(1, 1, -1), + size=self.config.context_length, + mode="nearest", + ).squeeze() + ) + miss_mask.append( + F.interpolate( + miss_mask_i.view(1, 1, -1), + size=self.config.context_length, + mode="nearest", + ).squeeze() + ) + time_index.append( + F.interpolate( + time_index_i.view(1, 1, -1), + size=self.config.context_length, + mode="nearest", + ).squeeze() + ) + ts_ends.append(torch.tensor([0, self.config.context_length]).float()) + sample_lengths.append(sample_len) + + inputs = torch.stack(inputs, dim=0) + pred_mask = torch.stack(pred_mask, dim=0) + pad_mask = torch.stack(pad_mask, dim=0) + miss_mask = torch.stack(miss_mask, dim=0) + time_index = torch.stack(time_index, dim=0) + ts_ends = torch.stack(ts_ends, dim=0) + + precision = ( + torch.bfloat16 + if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 + else torch.float16 + ) + device = ( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.mps.is_available() else "cpu" + ) + + with torch.autocast(device_type=device, dtype=precision, enabled=True): + model_output = self.backbone( + inputs=inputs, + pred_mask=pred_mask, + miss_mask=miss_mask, + pad_mask=pad_mask, + return_loss=False, + output_hidden_states=output_hidden_states, + ) + outputs = model_output.quantile_predictions + + outputs = outputs.permute(0, 2, 1) + outputs = self.backbone.norm_fn.inverse_transform(outputs) + + x_preds = [] + for i in range(outputs.shape[0]): + if sample_lengths[i] <= self.config.context_length: + x_pred = outputs[i][:, int(ts_ends[i][0]) : int(ts_ends[i][1])] + else: + x_pred = F.interpolate( + outputs[i].unsqueeze(1), size=sample_lengths[i], mode="linear" + ).squeeze(1) + x_preds.append(x_pred[:, -forecast_len[i] :]) + return x_preds, model_output.hidden_states diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/normalization.py b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/normalization.py new file mode 100644 index 000000000000..f9e6354a8521 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/normalization.py @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + + +import torch +import torch.nn as nn + + +class RevIN(nn.Module): + def __init__(self, dim=-1, std_min=1e-5, max_val=100, use_sinh=False): + super().__init__() + self.dim = dim + self.std_min = std_min + self.max_val = max_val + self.use_sinh = use_sinh + + def fit_transform(self, x, mask=None): + with torch.autocast(device_type="cuda", enabled=False): + self._get_statistics(x, mask) + return self.transform(x) + + def transform(self, x): + with torch.autocast(device_type="cuda", enabled=False): + x = (x - self.mean) / self.std + if self.use_sinh: + x = torch.asinh(x) + return x + + def inverse_transform(self, x): + with torch.autocast(device_type="cuda", enabled=False): + if self.use_sinh: + x = torch.sinh(x) + if x.ndim != self.mean.ndim: + x = x * self.std.unsqueeze(1) + self.mean.unsqueeze(1) + else: + x = x * self.std + self.mean + + return x + + def get_statistics(self): + return self.mean, self.std + + def _get_statistics(self, x, mask=None): + if mask is None: + self.mean = x.mean(dim=self.dim, keepdim=True) + std = x.std(dim=self.dim, keepdim=True) + self.std = torch.where(std > self.std_min, std, torch.ones_like(std)) + else: + mask = mask.bool() + unmask = (~mask).float() + count = unmask.sum(dim=self.dim, keepdim=True).clamp( + min=1 + ) # avoid division by zero + x_mean = (x * unmask).sum(dim=self.dim, keepdim=True) / count + x_std = (((x - x_mean) * unmask) ** 2).sum( + dim=self.dim, keepdim=True + ) / count + x_std = x_std.sqrt() + x_std = torch.where(x_std > self.std_min, x_std, torch.ones_like(x_std)) + self.mean = x_mean + self.std = x_std + + +class CausalRevIN(nn.Module): + def __init__(self, dim=-1, std_min=1e-5, max_val=100): + """ + Causal RevIN implementation to enable parallel predictions during training of FlowState + + :param eps: a value added for numerical stability + :param with_missing (bool): whether contiguous patch masking (CPM) is used or not, interpreting nans as missing values + """ + super().__init__() + self.dim = dim + self.std_min = std_min + self.max_val = max_val + + def fit_transform(self, x, mask=None): + self._get_statistics(x, mask) + return self.transform(x) + + def transform(self, x): + return torch.clamp( + (x - self.mean) / self.std, min=-self.max_val, max=self.max_val + ) + + def inverse_transform(self, x): + if x.ndim == 2: + return x * self.std + self.mean + elif x.ndim == 3: + return x * self.std.unsqueeze(-1) + self.mean.unsqueeze(-1) + else: + raise ValueError(f"Invalid input dimension: {x.shape}") + + def get_statistics(self): + return self.mean, self.std + + def _get_statistics(self, x, mask=None): + if mask is not None: + n = torch.cumsum(1 - mask.float(), dim=1) + n = torch.where(n == 0, 1.0, n) + else: + n = torch.arange(1, x.shape[1] + 1, device=x.device) + self.mean = (torch.cumsum(x, dim=1) / n).detach() + mask = 1 - mask.float() if mask is not None else 1 + self.std = torch.sqrt( + torch.cumsum(((x - self.mean) * mask) ** 2, 1) / n + ).detach() + self.std = torch.where( + self.std > self.std_min, self.std, torch.ones_like(self.std) + ) + + def set_statistics(self, mean, std): + self.mean = mean + self.std = std diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/pipeline_patchtst_fm.py b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/pipeline_patchtst_fm.py new file mode 100644 index 000000000000..4c988c870d7b --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/pipeline_patchtst_fm.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + + +import torch + +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline + + +class PatchTSTFMPipeline(ForecastPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, **model_kwargs) + + def preprocess(self, inputs, **infer_kwargs): + inputs = super().preprocess(inputs, **infer_kwargs) + for idx, item in enumerate(inputs): + # Model expects float32 + target_tensor = item["targets"].to(torch.float32) + + # Expand 1D tensor [length] to [batch=1, length] + if target_tensor.ndim == 1: + target_tensor = target_tensor.unsqueeze(0) + + item["targets"] = target_tensor + return inputs + + def forecast(self, inputs, **infer_kwargs) -> list[torch.Tensor]: + """ + Run the PatchTST-FM-R1 forward pass for each input in the batch. + The model expects a list of 1D tensors (one per variate) and returns + a PatchTSTFMPredictionOutput with a `quantile_predictions` attribute. + """ + forecasts = [] + for item in inputs: + targets = item["targets"] + pred_length = infer_kwargs.get("output_length", 96) + # Move to device and convert [n_variates, length] → list of 1D tensors + # as required by PatchTSTFMForPrediction.forward() + tensor = targets.to(self.device) + tensor_list = [tensor[i] for i in range(tensor.shape[0])] + with torch.no_grad(): + output = self.model(inputs=tensor_list, prediction_length=pred_length) + forecasts.append(output.quantile_predictions) + return forecasts + + def postprocess( + self, outputs: list[torch.Tensor], **infer_kwargs + ) -> list[torch.Tensor]: + """ + The IBM Model returns quantiles [batch, variates, prediction_length, quantiles]. + We reduce this to [variates, prediction_length] by taking the median or mean. + """ + final_outputs = [] + for output in outputs: + # Remove batch dimension if it is just a single batch + if output.ndim == 4: + output = output.squeeze(0) + + # Average out the quantiles to get a point forecast + point_forecast = output.mean(dim=-1) + final_outputs.append(point_forecast) + + return super().postprocess(final_outputs, **infer_kwargs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/tools.py b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/tools.py new file mode 100644 index 000000000000..9e16c7cb444e --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/tools.py @@ -0,0 +1,242 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + + +import os +import random +import time +from datetime import datetime + +import numpy as np +import pandas as pd +import torch + + +def seed_everything(seed: int = 42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def count_parameters(model): + total_params = sum(p.numel() for p in model.parameters()) + grad_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + no_grad_params = total_params - grad_params + return total_params, grad_params, no_grad_params + + +def to_hms(seconds): + h = int(seconds // 3600) + m = int((seconds % 3600) // 60) + s = int(seconds % 60) + return f"{h:02d}:{m:02d}:{s:02d}" + + +def hms_to_seconds(hms): + h, m, s = map(int, hms.split(":")) + return h * 3600 + m * 60 + s + + +def compute_remaining_time(start_time, current_step, max_steps): + current_time = time.time() + elapsed_time = current_time - start_time + remaining_steps = max_steps - current_step + remaining_time = elapsed_time * remaining_steps / current_step + second_per_step = elapsed_time / current_step + return ( + f"{to_hms(elapsed_time)}<{to_hms(remaining_time)} ({second_per_step:.2f}s/step)" + ) + + +class Timer: + def __init__(self, start_step: int = 0, max_step: int = 0): + self.start_time = time.time() + self.last_time = time.time() + self.last_step = start_step + self.max_step = max_step + + def __call__(self, step): + current_time = time.time() + delta_time = current_time - self.last_time + delta_step = max(step - self.last_step, 1) + remaining_step = self.max_step - step + remaining_time = delta_time * remaining_step / delta_step + elapsed_time = current_time - self.start_time + second_per_step = delta_time / delta_step + self.last_time = current_time + self.last_step = step + return f"{to_hms(elapsed_time)}<{to_hms(remaining_time)} ({second_per_step:.2f}s/step)" + + +class StandardScaler: + """ + A numpy implementation of StandardScaler that mimics sklearn's StandardScaler. + Standardizes features by removing the mean and scaling to unit variance. + """ + + def __init__(self, with_mean=True, with_std=True): + self.with_mean = with_mean + self.with_std = with_std + self.mean_ = None + self.scale_ = None + self.var_ = None + self.n_samples_seen_ = 0 + + def fit(self, X): + """ + Compute the mean and std to be used for later scaling. + + Parameters: + ----------- + X : array-like, shape [n_samples, n_features] + The data used to compute the mean and standard deviation. + + Returns: + -------- + self : object + Returns self. + """ + X = np.array(X, dtype=np.float64) + + if self.with_mean: + self.mean_ = np.mean(X, axis=0) + else: + self.mean_ = np.zeros(X.shape[1], dtype=np.float64) + + if self.with_std: + self.var_ = np.var(X, axis=0) + self.scale_ = np.sqrt(self.var_) + # Handle zeros in scale + self.scale_ = np.where(self.scale_ == 0, 1.0, self.scale_) + else: + self.var_ = np.ones(X.shape[1], dtype=np.float64) + self.scale_ = np.ones(X.shape[1], dtype=np.float64) + + self.n_samples_seen_ = X.shape[0] + + return self + + def transform(self, X): + """ + Perform standardization by centering and scaling. + + Parameters: + ----------- + X : array-like, shape [n_samples, n_features] + The data to standardize. + + Returns: + -------- + X_scaled : array-like, shape [n_samples, n_features] + Standardized data. + """ + X = np.array(X, dtype=np.float64) + + if self.with_mean: + X = X - self.mean_ + + if self.with_std: + X = X / self.scale_ + + return X + + def fit_transform(self, X): + """ + Fit to data, then transform it. + + Parameters: + ----------- + X : array-like, shape [n_samples, n_features] + The data to be transformed. + + Returns: + -------- + X_scaled : array-like, shape [n_samples, n_features] + Standardized data. + """ + return self.fit(X).transform(X) + + def inverse_transform(self, X): + """ + Scale back the data to the original representation. + + Parameters: + ----------- + X : array-like, shape [n_samples, n_features] + The data to inverse transform. + + Returns: + -------- + X_orig : array-like, shape [n_samples, n_features] + Data in original scale. + """ + X = np.array(X, dtype=np.float64) + + if self.with_std: + X = X * self.scale_ + + if self.with_mean: + X = X + self.mean_ + + return X + + +class CSVLogger: + """ + Simple CSV logger that stores training metrics and figure paths. + """ + + def __init__(self, log_dir: str): + """ + Initialize CSV logger. + + Parameters: + ----------- + log_dir : str + Directory to save CSV log files and figures + """ + + self.log_dir = f"{log_dir}/run_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}" + os.makedirs(self.log_dir, exist_ok=True) + + # Initialize dataframe to store scalar metrics + self.scalar_data = [] + self.scalar_file = os.path.join(self.log_dir, "scalars.csv") + + self.fig_dir = os.path.join(self.log_dir, "figures") + os.makedirs(self.fig_dir, exist_ok=True) + os.makedirs(f"{self.fig_dir}/TRAIN") + os.makedirs(f"{self.fig_dir}/VAL") + + def log_scalar(self, tag: str, value: float, step: int): + self.scalar_data.append( + {"timestamp": datetime.now(), "step": step, "tag": tag, "value": value} + ) + + def save(self): + # Save to CSV + df = pd.DataFrame(self.scalar_data) + df.to_csv(self.scalar_file, index=False) + + def log_figure(self, tag: str, figure, step: int): + # Create figures subdirectory + fig_path = os.path.join(self.fig_dir, f"{tag}_{step}.png") + figure.savefig(fig_path)