Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,36 @@ object LicenseManagerFactory extends LazyLogging {
_strategy.get().getOrElse(new NopRefreshStrategy())

/** Replace the cached strategy. Initializes the new strategy before caching. Use for license
* upgrade/downgrade at runtime. Uses CAS loop to ensure atomic replacement.
* upgrade/downgrade at runtime. Shuts down the old strategy before replacing.
*/
def setStrategy(strategy: LicenseRefreshStrategy): Unit = {
strategy.initialize()
var old: Option[LicenseRefreshStrategy] = None
var updated = false
while (!updated) {
val current = _strategy.get()
updated = _strategy.compareAndSet(current, Some(strategy))
if (_strategy.compareAndSet(current, Some(strategy))) {
old = current
updated = true
}
}
old.foreach(_.shutdown())
}

/** Shutdown the cached strategy's background resources (if any) and clear the cache. Called
* during process shutdown to stop the refresh scheduler.
*/
def shutdown(): Unit = {
var old: Option[LicenseRefreshStrategy] = None
var updated = false
while (!updated) {
val current = _strategy.get()
if (_strategy.compareAndSet(current, None)) {
old = current
updated = true
}
}
old.foreach(_.shutdown())
}

/** Reset cached strategy (for testing). Uses CAS to ensure atomic clear. */
Expand All @@ -73,43 +94,50 @@ object LicenseManagerFactory extends LazyLogging {
else Some(LicenseMode.Driver)
}

/** Resolve strategy via SPI, initialize it, and cache it. Uses CAS to ensure only one strategy is
* created in concurrent scenarios.
/** Resolve strategy via SPI, initialize it, and cache it. Synchronized to prevent concurrent
* creation of duplicate strategies (which would leak resources like Akka schedulers).
*/
private def resolveStrategy(config: Config): LicenseRefreshStrategy =
_strategy.get() match {
case Some(s) => s
case None =>
val mode = resolveMode(config)
val loader = ServiceLoader.load(classOf[LicenseManagerSpi])
val spis = loader.iterator().asScala.toSeq.sortBy(_.priority)
val strategy = spis.headOption
.map { spi =>
try {
val s = spi.createStrategy(config, mode)
s.initialize()
logger.info(
s"License strategy initialized: ${s.getClass.getSimpleName} " +
s"(mode=${mode.getOrElse("default")}, type=${s.licenseManager.licenseType})"
)
s
} catch {
case e: Exception =>
logger.error(
s"Failed to create license strategy from ${spi.getClass.getName}: ${e.getMessage}",
e
)
val fallback = new NopRefreshStrategy()
fallback.initialize()
fallback
}
}
.getOrElse {
val fallback = new NopRefreshStrategy()
fallback.initialize()
fallback
synchronized {
// Double-check after acquiring lock
_strategy.get() match {
case Some(s) => s
case None =>
val mode = resolveMode(config)
val loader = ServiceLoader.load(classOf[LicenseManagerSpi])
val spis = loader.iterator().asScala.toSeq.sortBy(_.priority)
val strategy = spis.headOption
.map { spi =>
try {
val s = spi.createStrategy(config, mode)
s.initialize()
logger.info(
s"License strategy initialized: ${s.getClass.getSimpleName} " +
s"(mode=${mode.getOrElse("default")}, type=${s.licenseManager.licenseType})"
)
s
} catch {
case e: Exception =>
logger.error(
s"Failed to create license strategy from ${spi.getClass.getName}: ${e.getMessage}",
e
)
val fallback = new NopRefreshStrategy()
fallback.initialize()
fallback
}
}
.getOrElse {
val fallback = new NopRefreshStrategy()
fallback.initialize()
fallback
}
_strategy.set(Some(strategy))
strategy
}
_strategy.compareAndSet(None, Some(strategy))
_strategy.get().getOrElse(strategy)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,12 @@ trait LicenseRefreshStrategy {

/** Access the current LicenseManager for SHOW LICENSE and feature/quota checks. */
def licenseManager: LicenseManager

/** Telemetry collector for this strategy. Extensions update counters via this reference. Default
* returns `TelemetryCollector.Noop` (zero counters, never updated).
*/
def telemetryCollector: TelemetryCollector = TelemetryCollector.Noop

/** Shutdown background resources (scheduler, etc.). Default is no-op. */
def shutdown(): Unit = ()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright 2025 SOFTNETWORK
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package app.softnetwork.elastic.licensing

import java.util.concurrent.atomic.AtomicLong

/** Snapshot of runtime telemetry for inclusion in refresh requests. All counters are cumulative
* since server startup (stateless -- reset on restart).
*
* Backend note: the backend receives cumulative totals per instance_id. To compute per-interval
* deltas, the backend stores the previous snapshot and subtracts. When queries_total drops below
* the previous value for the same instance_id, this indicates a process restart -- start a new
* session, don't compute negative delta.
*/
case class TelemetryData(
queriesTotal: Long = 0,
joinsTotal: Long = 0,
mvsActive: Int = 0,
clustersConnected: Int = 0
)

/** Mutable telemetry collector with atomic counters.
*
* Accessible (read-only snapshot) via `LicenseRefreshStrategy.telemetryCollector.collect()`.
* Extensions (e.g. CoreDqlExtension) update counters via increment/set methods through
* `LicenseManagerFactory.currentStrategy.telemetryCollector`.
*
* Thread-safe: counters use `AtomicLong`, gauges use `@volatile`.
*/
class TelemetryCollector {

private val _queriesTotal = new AtomicLong(0L)
private val _joinsTotal = new AtomicLong(0L)
@volatile private var _mvsActive: Int = 0
@volatile private var _clustersConnected: Int = 0

// --- Write methods (called by extensions) ---

def incrementQueries(): Unit = { val _ = _queriesTotal.incrementAndGet() }

def incrementJoins(): Unit = { val _ = _joinsTotal.incrementAndGet() }

def setMvsActive(count: Int): Unit = { _mvsActive = count }

def setClustersConnected(count: Int): Unit = { _clustersConnected = count }

// --- Read method (called by AutoRefreshStrategy.doScheduleRefresh) ---

def collect(): TelemetryData = TelemetryData(
queriesTotal = _queriesTotal.get(),
joinsTotal = _joinsTotal.get(),
mvsActive = _mvsActive,
clustersConnected = _clustersConnected
)
}

object TelemetryCollector {

/** Default collector returning zero-valued data. Used when telemetry is disabled or no runtime
* wires a real collector. Write methods are no-ops to prevent accidental mutation of the shared
* singleton.
*/
val Noop: TelemetryCollector = new TelemetryCollector {
override def incrementQueries(): Unit = ()
override def incrementJoins(): Unit = ()
override def setMvsActive(count: Int): Unit = ()
override def setClustersConnected(count: Int): Unit = ()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright 2025 SOFTNETWORK
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package app.softnetwork.elastic.licensing

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class TelemetryCollectorSpec extends AnyFlatSpec with Matchers {

// --- Noop collector ---

"TelemetryCollector.Noop" should "return zero-valued TelemetryData" in {
val data = TelemetryCollector.Noop.collect()
data shouldBe TelemetryData(0, 0, 0, 0)
}

// --- incrementQueries ---

"TelemetryCollector" should "increment queries counter atomically" in {
val collector = new TelemetryCollector
collector.incrementQueries()
collector.incrementQueries()
collector.incrementQueries()
collector.collect().queriesTotal shouldBe 3
}

// --- incrementJoins ---

it should "increment joins counter atomically" in {
val collector = new TelemetryCollector
collector.incrementJoins()
collector.incrementJoins()
collector.collect().joinsTotal shouldBe 2
}

// --- setMvsActive ---

it should "set MVs active gauge" in {
val collector = new TelemetryCollector
collector.setMvsActive(5)
collector.collect().mvsActive shouldBe 5
collector.setMvsActive(3)
collector.collect().mvsActive shouldBe 3
}

// --- setClustersConnected ---

it should "set clusters connected gauge" in {
val collector = new TelemetryCollector
collector.setClustersConnected(2)
collector.collect().clustersConnected shouldBe 2
}

// --- collect returns consistent snapshot ---

it should "return a consistent snapshot combining all counters" in {
val collector = new TelemetryCollector
collector.incrementQueries()
collector.incrementQueries()
collector.incrementJoins()
collector.setMvsActive(4)
collector.setClustersConnected(1)

val data = collector.collect()
data.queriesTotal shouldBe 2
data.joinsTotal shouldBe 1
data.mvsActive shouldBe 4
data.clustersConnected shouldBe 1
}

// --- concurrent access ---

it should "handle concurrent increments safely" in {
val collector = new TelemetryCollector
val threads = (1 to 10).map { _ =>
new Thread(() => {
(1 to 1000).foreach { _ =>
collector.incrementQueries()
collector.incrementJoins()
}
})
}
threads.foreach(_.start())
threads.foreach(_.join())

val data = collector.collect()
data.queriesTotal shouldBe 10000
data.joinsTotal shouldBe 10000
}
}
Loading