-
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathidentify_lora.py
More file actions
129 lines (108 loc) · 4.78 KB
/
identify_lora.py
File metadata and controls
129 lines (108 loc) · 4.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import torch
import argparse
from safetensors.torch import load_file
import json
import re
def _read_safetensors_metadata(model_path):
    """Return the free-form metadata stored in a safetensors file header.

    The file starts with an 8-byte little-endian header length followed by a
    JSON header. Tensor descriptions sit at the top level of that header;
    user metadata (e.g. the ``ss_*`` training settings reported by
    ``identify_lora_type``) sits under the ``__metadata__`` key.

    Args:
        model_path: Path to a ``.safetensors`` file.

    Returns:
        dict: The ``__metadata__`` mapping, or an empty dict if there is none.

    Raises:
        OSError: If the file cannot be read.
        ValueError: If the header is not valid UTF-8 JSON.
    """
    with open(model_path, 'rb') as f:
        header_size = int.from_bytes(f.read(8), byteorder='little')
        header = json.loads(f.read(header_size))
    if not isinstance(header, dict):
        return {}
    metadata = header.get('__metadata__')
    return metadata if isinstance(metadata, dict) else {}


def identify_lora_type(model_path):
    """
    Identify the type of LoRA/LyCORIS model by examining its structure.

    Progress information (metadata, key count, sample keys) is printed to
    stdout while the file is analysed.

    Args:
        model_path: Path to the LoRA model file (.safetensors or .pt)

    Returns:
        str: A human-readable summary of the identified LoRA type(s), or an
        explanatory message ("Error loading model: ...", "Empty model state
        dict", ...) when the file cannot be classified.
    """
    print(f"Analyzing model: {model_path}")
    # Only safetensors files carry a metadata header; default to empty so the
    # ss_* lookups below are safe for .pt checkpoints too.
    metadata = {}
    is_safetensors = model_path.lower().endswith('.safetensors')
    try:
        if is_safetensors:
            state_dict = load_file(model_path)
            # Metadata is best-effort: a failure here must not abort analysis.
            try:
                metadata = _read_safetensors_metadata(model_path)
            except (OSError, ValueError) as e:
                print(f"Could not extract metadata: {e}")
            else:
                if metadata:
                    print("\nMetadata found:")
                    for key, value in metadata.items():
                        print(f" {key}: {value}")
        else:
            # NOTE(review): torch.load unpickles arbitrary Python objects; only
            # use it on trusted files (or pass weights_only=True where supported).
            state_dict = torch.load(model_path, map_location='cpu')
    except Exception as e:
        return f"Error loading model: {e}"

    # A .pt file may hold something other than a plain state dict (e.g. a module).
    if not isinstance(state_dict, dict):
        return (
            "Unsupported checkpoint contents: expected a dict of tensors, "
            f"got {type(state_dict).__name__}"
        )

    keys = list(state_dict.keys())
    if not keys:
        return "Empty model state dict"
    print(f"\nFound {len(keys)} keys in the state dict")

    print("\nSample keys (up to 5):")
    for key in keys[:5]:
        value = state_dict[key]
        # .pt checkpoints may contain non-tensor entries that have no .shape.
        shape = getattr(value, 'shape', type(value).__name__)
        print(f" {key}: {shape}")

    # Key-name patterns for the different LoRA/LyCORIS families.
    lora_patterns = {
        # PEFT-style lora_A/lora_B and kohya-style lora_down/lora_up pairs.
        "standard_lora": (r"\.lora_(?:[AB]|up|down)\.weight$", "Standard LoRA"),
        "loha": (r"\.hada_[wt]1.*|\.hada_[wt]2.*", "LoHa (LoCon Hada)"),
        "lokr": (r"\.lokr_[wt]1.*|\.lokr_[wt]2.*", "LoKr"),
        "ia3": (r"\.ia3_.*", "IA³"),
        "dylora": (r"\.dya_|\.dyb_", "DyLoRA"),
        "dora": (r"\.d_", "DoRA (Weight-Decomposed LoRA)"),
        "glora": (r"\.gloraw_a|\.gloraw_b", "GLoRA"),
        "norms": (r"\.lora_magnitude|\.lora_norm1|\.lora_norm2", "LoRA with norms"),
        "diag": (r"\.diag_", "Diagonal adapters"),
        "full": (r"\.full_", "Full adapters"),
        "ema": (r"\.ema_", "EMA entries"),
    }

    # pattern_key -> (display name, number of matching keys); insertion order
    # follows lora_patterns so the report order is stable.
    found_types = {}
    for pattern_key, (pattern, name) in lora_patterns.items():
        count = sum(1 for key in keys if re.search(pattern, key))
        if count:
            found_types[pattern_key] = (name, count)

    if not found_types:
        # Fall back to loose substring heuristics for unrecognised naming.
        lowered = [key.lower() for key in keys]
        if any("lora" in key for key in lowered):
            return "Appears to be a LoRA-like model with non-standard naming"
        if any("adapter" in key for key in lowered):
            return "Appears to be an adapter model with non-standard naming"
        return "Unknown model type - not recognized as a standard LoRA or LyCORIS variant"

    lines = ["Identified LoRA types:"]
    for name, count in found_types.values():
        lines.append(f"- {name}: {count} matching entries")

    if "dora" in found_types:
        # DoRA pairs ordinary LoRA weights with extra delta weights; reuse the
        # counts already computed for those two patterns.
        lora_count = found_types.get("standard_lora", ("", 0))[1]
        d_count = found_types["dora"][1]
        lines.append(f" DoRA structure: {lora_count} LoRA weights, {d_count} delta weights")

    if "ss_network_alpha" in metadata:
        lines.append(f"Network alpha: {metadata['ss_network_alpha']}")
    if "ss_network_dim" in metadata:
        lines.append(f"Network dimension: {metadata['ss_network_dim']}")

    return "\n".join(lines)
def main(argv=None):
    """Command-line entry point: classify one LoRA/LyCORIS file and print the result.

    Args:
        argv: Optional argument list; ``None`` lets argparse read ``sys.argv``.
            Exposed so the CLI can be driven programmatically.

    Raises:
        SystemExit: With the error message (printed to stderr, exit status 1)
            when the model file does not exist.
    """
    parser = argparse.ArgumentParser(description="Identify LoRA and LyCORIS model types")
    parser.add_argument("model_path", help="Path to the model file (.safetensors or .pt)")
    args = parser.parse_args(argv)

    if not os.path.exists(args.model_path):
        # A non-zero exit status lets shell callers detect the failure.
        raise SystemExit(f"Error: Model file {args.model_path} not found")

    result = identify_lora_type(args.model_path)
    print("\n" + "=" * 40)
    print(result)
    print("=" * 40)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()