pasp.wlearn

 1import numpy as np
 2
 3def learn(P, D: np.ndarray, A: np.ndarray = None, niters: int = 30, alg: str = "em",
 4          lr: float = 0.001, batch: int = None, smoothing: float = 1e-4, lstable_sat: bool = True,
 5          display: str = "loglikelihood"):
 6  if alg == "fixpoint": alg = "em"
 7  # If batch is not given, set batch to the size of the dataset.
 8  if batch is None: batch = len(D)
 9  # Prepare training tensors.
10  for N in P.NR: N.prepare_train(batch)
11  for N in P.NA: N.prepare_train(batch)
12  # Check if data dimensions all match.
13  def assert_dims(N, e: int):
14    if N.learnable and ((o := len(N.data[0].train)) != e):
15      raise ValueError(f"Training data dimensions do not match!\n  Expected {n}, got {o}.")
16  n = len(D)
17  for N in P.NR: assert_dims(N, n)
18  for N in P.NA: assert_dims(N, n)
19
20  # Check if D is an np.ndarray or list.
21  if not (issubclass(t := type(D), list) or issubclass(t, np.ndarray)):
22    raise TypeError(f"Expected dataset of type list or numpy.ndarray, got {t}!")
23
24  # Batch mode.
25  if A is None:
26    if type(D) is np.ndarray: data = D if np.issubdtype(D.dtype, bytes) else D.astype(bytes)
27    else: data = D
28    from learn import learn_batch as clearn_batch
29    P.train()
30    clearn_batch(P, data, niters = niters, alg = alg, lr = lr, batch = batch,
31                 lstable_sat = lstable_sat, display = display, smoothing = smoothing)
32    P.eval()
33    return
34
35  # Non-batch mode.
36  if type(A) is not np.ndarray: atoms = np.array(A, dtype = bytes)
37  else: atoms = A if np.issubdtype(A.dtype, bytes) else A.astype(bytes)
38  if type(D) is not np.ndarray: data = np.array(D, dtype = np.uint8)
39  else: data = D if np.issubdtype(D.dtype, np.uint8) else D.astype(np.uint8)
40
41  obs, obs_counts = np.unique(data, axis = 0, return_counts = True)
42  from learn import learn as clearn
43  P.train()
44  clearn(P, obs, obs_counts, atoms, niters = niters, alg = alg, lr = lr, lstable_sat = lstable_sat)
45  P.eval()
def learn( P, D: numpy.ndarray, A: numpy.ndarray = None, niters: int = 30, alg: str = 'em', lr: float = 0.001, batch: int = None, smoothing: float = 0.0001, lstable_sat: bool = True, display: str = 'loglikelihood'):
 4def learn(P, D: np.ndarray, A: np.ndarray = None, niters: int = 30, alg: str = "em",
 5          lr: float = 0.001, batch: int = None, smoothing: float = 1e-4, lstable_sat: bool = True,
 6          display: str = "loglikelihood"):
 7  if alg == "fixpoint": alg = "em"
 8  # If batch is not given, set batch to the size of the dataset.
 9  if batch is None: batch = len(D)
10  # Prepare training tensors.
11  for N in P.NR: N.prepare_train(batch)
12  for N in P.NA: N.prepare_train(batch)
13  # Check if data dimensions all match.
14  def assert_dims(N, e: int):
15    if N.learnable and ((o := len(N.data[0].train)) != e):
16      raise ValueError(f"Training data dimensions do not match!\n  Expected {n}, got {o}.")
17  n = len(D)
18  for N in P.NR: assert_dims(N, n)
19  for N in P.NA: assert_dims(N, n)
20
21  # Check if D is an np.ndarray or list.
22  if not (issubclass(t := type(D), list) or issubclass(t, np.ndarray)):
23    raise TypeError(f"Expected dataset of type list or numpy.ndarray, got {t}!")
24
25  # Batch mode.
26  if A is None:
27    if type(D) is np.ndarray: data = D if np.issubdtype(D.dtype, bytes) else D.astype(bytes)
28    else: data = D
29    from learn import learn_batch as clearn_batch
30    P.train()
31    clearn_batch(P, data, niters = niters, alg = alg, lr = lr, batch = batch,
32                 lstable_sat = lstable_sat, display = display, smoothing = smoothing)
33    P.eval()
34    return
35
36  # Non-batch mode.
37  if type(A) is not np.ndarray: atoms = np.array(A, dtype = bytes)
38  else: atoms = A if np.issubdtype(A.dtype, bytes) else A.astype(bytes)
39  if type(D) is not np.ndarray: data = np.array(D, dtype = np.uint8)
40  else: data = D if np.issubdtype(D.dtype, np.uint8) else D.astype(np.uint8)
41
42  obs, obs_counts = np.unique(data, axis = 0, return_counts = True)
43  from learn import learn as clearn
44  P.train()
45  clearn(P, obs, obs_counts, atoms, niters = niters, alg = alg, lr = lr, lstable_sat = lstable_sat)
46  P.eval()