pasp.grammar

  1import pathlib, enum, math, collections.abc, os
  2import lark, lark.reconstruct, clingo, numpy
  3from numpy import ascontiguousarray as contiguous
  4from .program import ProbFact, Query, VarQuery, ProbRule, Program, CredalFact, unique_fact, \
  5  Semantics, Data
  6from .program import AnnotatedDisjunction, NeuralRule, NeuralAD, unique_pgrule_id
  7
  8def read(*files: str, G: lark.Lark = None, from_str: bool = False, start = "plp") -> lark.Tree:
  9  "Read all `files` and parse them with grammar `G`, returning a single `lark.Tree`."
 10  if G is None:
 11    try:
 12      with open(pathlib.Path(__file__).resolve().parent.joinpath("grammar.lark"), "r") as f:
 13        G = lark.Lark(f, start = start)
 14    except Exception as ex:
 15      raise ex
 16  T = None
 17  if from_str:
 18    try: return G.parse("\n".join(files))
 19    except Exception as ex: raise ex
 20  for fname in files:
 21    # For now, dump files entirely into memory (this isn't too much of a problem, since PLPs are
 22    # usually small). In the future, consider streaming batches of text instead for large files.
 23    try:
 24      with open(fname, "r") as f:
 25        text = f.read()
 26        if T is None: T = G.parse(text)
 27        else:
 28          U = G.parse(text)
 29          T.children.extend(u for u in U.children if u not in T.children)
 30    except Exception as ex:
 31      raise ex
 32  assert T is not None, "No file read."
 33  return T
 34
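# Illustrative usage sketch (not part of the module; the PLP text and file
# names are hypothetical, and the module is assumed importable as pasp.grammar):
#
#   from pasp import grammar
#   T = grammar.read("0.5::a. b :- a.", from_str = True)  # parse a source string
#   T = grammar.read("base.plp", "extra.plp")             # merge several files into one tree
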
 35def getnths(X: iter, i: int) -> iter: return (x[i] for x in X)
 36def find(X: iter, v, d = None):
 37  try:
 38    i = X.index(v)
 39    return i, X[i]
 40  except ValueError: return -1, d
 41def push(L: list, X: iter):
 42  "Pushes X to L. If X is a list, then push all elements of X to L instead of the list object itself."
 43  if isinstance(X, list): L.extend(X)
 44  else: L.append(X)
 45
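# Illustrative behavior sketch (not part of the module): push flattens one level
# of lists, e.g. L = [1]; push(L, [2, 3]); push(L, 4) leaves L == [1, 2, 3, 4].
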
 46def lit2atom(x: str) -> str: return x[4:] if x[:4] == "not " else x
 47
 48class PreparsingTransformer(lark.Transformer):
 49  def __init__(self):
 50    super().__init__()
 51    self.consts = {}
 52    self.includes = set()
 53  def __default__(self, _, __, ___): return lark.visitors.Discard
 54  def SEMANTICS_OPT_LOGIC(self, O): return str(O)
 55  def SEMANTICS_OPT_PROB(self, _): return lark.visitors.Discard
 56  def semantics(self, S): return S[0] if len(S) > 0 else lark.visitors.Discard
 57  def WORD(self, W): return str(W)
 58  def ID(self, I): return int(I)
 59  def constdef(self, C):
 60    self.consts[C[0]] = C[1]
 61    return lark.visitors.Discard
 62  "Verify which logic semantic should be used and record constant definitions."
 63  def plp(self, S):
 64    return S[0] if len(S) > 0 else None, self.consts, self.includes
 65  def include(self, P):
 66    self.includes.update(map(lambda x: os.path.abspath(str(x)), P))
 67    return lark.visitors.Discard
 68
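# Illustrative sketch (not part of the module): a preparsing pass over a parsed
# tree T returns the chosen logic semantics (or None), the recorded constant
# definitions, and the absolute paths of included files, mirroring how
# _flatten_includes uses it below.
#
#   sem, consts, includes = PreparsingTransformer().transform(T)
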
 69class StableTransformer(lark.Transformer):
 70  class Pack(tuple):
 71    @staticmethod
 72    def __new__(cls, tp: str, r: str = None, v = None, sc: dict = {}):
 73      return super(StableTransformer.Pack, cls).__new__(cls, (tp, str(v) if r is None else r, \
 74                                                              r if v is None else v, sc))
 75    def __str__(self): return self[1]
 76    def __repr__(self): return f"<{self[0]}: {self.__str__()}>"
 77
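  # Illustrative note: a Pack is a 4-tuple (type, representation, value, scope);
  # e.g. StableTransformer.pack("ID", val = 3) yields ("ID", "3", 3, {}), and
  # str() of that Pack returns its representation "3".
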
 78  def __init__(self, _, consts: dict = {}, scope: dict = None):
 79    super().__init__()
 80    self.sem = Semantics.STABLE
 81    self.torch_scope = {} if scope is None else scope
 82    self.n_prules = 0
 83    self.consts = consts
 84    self.varquery_id = 0
 85
 86  @staticmethod
 87  def pack(t: str, rep: str = None, val = None, scope: dict = {}) -> tuple[str, str, str, dict]:
 88    return StableTransformer.Pack(t, rep, val, scope)
 89
 90  @staticmethod
 91  def join_scope(A: list) -> dict: return dict((y, None) for S in A for y in S[3])
 92
 93  @staticmethod
 94  def find_data_pred(D: dict, body: list, which: str, name: str) -> list:
 95    t = None
 96    for d in body:
 97      if d[2][1] in D:
 98        t = D[d[2][1]]
 99        break
100    if t is None: raise ValueError(f"Neural {which} {name} must contain a data predicate!")
101    return t
102
103  @staticmethod
104  def check_data(D: list):
105    "Checks if all data have same first dimension size."
106    n, m = -1, -1
107    for X in D.values():
108      if n < 0: n = X[0].test.shape[0]
109      if (X[0].train is not None) and (m < 0): m = X[0].train.shape[0]
110      for x in X:
111        if x.test.shape[0] != n:
112          raise ValueError("Test data must have same number of instances!")
113        if (x.train is not None) and (x.train.shape[0] != m):
114          raise ValueError("Train data must have same number of instances!")
115
116  @staticmethod
117  def cont_head_sym(name: str, T: list, O: list, V: list = None):
118    g, S = None, None
119    if V is None:
120      if O is None: S = [f"{name}({t.arg})" for t in T]
121      else: S = [f"{name}({t.arg}, {o})" for t in T for o in O]
122    else:
123      if O is None: S = [f"{name}({t.arg}, {v})" for t in T for v in V]
124      else: S = [f"{name}({t.arg}, {v}, {o})" for t in T for o in O for v in V]
125    g = (clingo.parse_term(s)._rep for s in S)
126    return contiguous(tuple(g), dtype=numpy.uint64), S
127
128  @staticmethod
129  def register_nrule(TNR: list, NR: list, D: list):
130    for name, inp, O, net, body, rep, learnable, params in TNR:
131      t = StableTransformer.find_data_pred(D, body, "rule", name)
132      # Ground rules.
133      H, _ = StableTransformer.cont_head_sym(name, t, O)
134      B, S = None, None
135      if len(body) > 1:
136        # B and S do not depend on the number of outcomes |O|, only on |t| and |body|.
137        body_no_data = [b for b in body if b[2][1] != t[0].name]
138        B = contiguous(tuple(clingo.parse_term(f"{b[2][1]}({t[i].arg})" if len(b[3]) > 0 \
139                                                else lit2atom(b[1]))._rep for i in range(len(t)) \
140                              for b in body_no_data), dtype = numpy.uint64)
141        S = contiguous(tuple(b[2][0] for i in range(len(t)) for b in body_no_data), dtype = bool)
142      NR.append(NeuralRule(H, B, S, name, net, rep, t, learnable, params, O))
143
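  # Illustrative note: for a neural rule grounded over the data declarations t,
  # H holds len(t) ground heads (times |O| when outcomes are given), while B and
  # S hold len(t)*len(body_no_data) body symbols and their respective signs.
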
144  @staticmethod
145  def register_nad(TNA: list, NA: list, D: list):
146    for name, inp, vals, O, net, body, rep, learnable, params in TNA:
147      t = StableTransformer.find_data_pred(D, body, "AD", name)
148      # Ground rules.
149      V = list(vals.keys())
150      H, H_s = StableTransformer.cont_head_sym(name, t, O, V)
151      B, S = None, None
152      if len(body) > 1:
153        body_no_data = [b for b in body if b[2][1] != t[0].name]
154        # B and S do not depend on the number of values |V| or outcomes |O|, only on |t| and |body|.
155        B = contiguous(tuple(clingo.parse_term(f"{b[2][1]}({t[i].arg})" if len(b[3]) > 0 \
156                                                else lit2atom(b[1]))._rep for i in range(len(t)) \
157                              for b in body_no_data), dtype = numpy.uint64)
158        S = contiguous(tuple(b[2][0] for i in range(len(t)) for b in body_no_data), dtype = bool)
159      NA.append(NeuralAD(H, B, S, name, V, net, rep, t, learnable, params, O, H_s))
160
161  def __default__(self, _, __, ___): return lark.visitors.Discard
162
163  # Components which are directly translated to clingo.
164  def CMP_OP(self, o): return self.pack("CMP_OP", str(o))
165  def aggr(self, A): return self.pack("aggr", "".join(str(x) for x in A))
166  def raggr(self, A): return self.pack("raggr", "".join(str(x) for x in A))
167  def caggr(self, A): return self.pack("caggr", "".join(str(x) for x in A))
168
169  # Terminals.
170  def UND(self, u): return self.pack("UND", str(u))
171  def WORD(self, c): return self.pack("WORD", str(c))
172  def NEG(self, n): return self.pack("NEG", str(n))
173  def VAR(self, v):
174    x = str(v); X = {v: None}
175    return self.pack("VAR", x, scope = X)
176  def ID(self, i): return self.pack("ID", val = int(i))
177  def OP(self, o): return self.pack("OP", str(o))
178  def REAL(self, r): return self.pack("REAL", val = float(r))
179  def frac(self, f): return self.pack("frac", val = f[0][2]/f[1][2])
180  def prob(self, p): return self.pack("prob", val = p[0][2])
181  def SHARED(self, s): return self.pack("SHARED",  str(s))
182  def LEARN(self, l): return self.pack("LEARN", str(l))
183  def CONST(self, c): return self.pack("CONST", str(c))
184  def BOOL(self, b): return self.pack("BOOL", v := b.lower(), v != "false")
185  def NULL(self, n): return self.pack("NULL", None, None)
186
187  # Path.
188  def path(self, p): return self.pack(p[0].type, p[0].value)
189
190  # Set.
191  def set(self, S):
192    if S[0][0] == "interval":
193      a, b = S[0][2]
194      if isinstance(a, str):
195        if a not in self.consts: raise KeyError(f"Constant {a} is undefined!")
196        a = self.consts[a]
197      if isinstance(b, str):
198        if b not in self.consts: raise KeyError(f"Constant {b} is undefined!")
199        b = self.consts[b]
200      M = dict((str(i), None) for i in range(a, b+1))
201    else:
202      M = dict((x[1], None) for x in S)
203      if len(M) != len(S): raise ValueError("set must contain only unique constants!")
204    return self.pack("set", f"{{{','.join(x for x in M.keys())}}}", M)
205
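  # Illustrative note: an interval set such as {0..2} packs to the
  # representation "{0,1,2}" with value {"0": None, "1": None, "2": None}.
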
206  # Intervals.
207  def interval(self, I): return self.pack("interval", f"{I[0][2]}..{I[1][2]}", (I[0][2], I[1][2]))
208
209  # Predicates.
210  def pred(self, P, replace_semicolons = False):
211    name = P[0][1]
212    rep = f"{name}({', '.join(getnths(P[1:], 1))})"
213    return self.pack("pred", rep.replace(";", ",") if replace_semicolons else rep, name, self.join_scope(P))
214  def grpred(self, P): return self.pred(P)
215  def query_pred(self, P): return self.pred(P, replace_semicolons = True)
216
217  # Literals.
218  def lit(self, P):
219    s = P[0][0] != "NEG"
220    return self.pack("lit", " ".join(getnths(P, 1)), (s, P[0][2] if s else P[1][2]), self.join_scope(P))
221  def grlit(self, P): return self.lit(P)
222
223  # Binary operations.
224  def bop(self, B) -> str: return self.pack("bop", " ".join(getnths(B, 1)))
225
226  # Facts.
227  def fact(self, F):
228    f = f"{''.join(getnths(F, 1))}"
229    # Facts are always grounded.
230    return self.pack("fact", f + ".", f)
231  def pfact(self, PF):
232    p, f = PF[0][2], PF[1][1]
233    return self.pack("pfact", "", ProbFact(p, f))
234  def cfact(self, CF):
235    l, u, f = CF[0][2], CF[1][2], CF[2][1]
236    return self.pack("cfact", "", CredalFact(l, u, f))
237  def lpfact(self, PF):
238    if PF[0][0] == "prob": p, f = PF[0][2], PF[1][1]
239    else: p, f = 0.5, PF[0][1]
240    return self.pack("pfact", "", ProbFact(p, f, learnable = True))
241
242  # Heads.
243  def head(self, H): return self.pack("head", ", ".join(getnths(H, 1)), H, self.join_scope(H))
244  def ohead(self, H): return self.pack("head", H[0][1], H[0][2], H[0][3])
245  # Body.
246  def body(self, B): return self.pack("body", ", ".join(getnths(B, 1)), B, self.join_scope(B))
247
248  # Rules.
249  def rule(self, R): return self.pack("rule", " :- ".join(getnths(R, 1)) + ".")
250  def prule(self, R):
251    l = "LEARN" in getnths(R, 0)
252    s = "SHARED" in getnths(R, 0)
253    h, b = R[-2], R[-1]
254    o = f"{h[1]} :- {b[1]}"
255    p = R[0][2] if R[0][0] == "prob" else 0.5
256    S = self.join_scope(R)
257    if len(S) == 0:
258      pr = ProbRule(p, o, is_prop = True, learnable = l)
259      self.n_prules += 1
260      return self.pack("prule", pr.prop_f, pr)
261    # Invariant: len(b) > 0, otherwise the rule is unsafe.
262    name = h[2]
263    # hscope is guaranteed to preserve insertion order by Python's dict semantics.
264    hscope = h[3]
265    body_preds = [x for x in b[2] if x[0] != "bop"]
266    h_s = ", ".join(hscope) + ", " if len(hscope) > 0 else ""
267    b_s = ", ".join(map(lambda x: f"1, {x[1]}" if x[2][0] else f"0, {x[1][4:]}", body_preds))
268    # If parameters are shared, then we require a special ID.
269    upr = -1 if not (s and l) else unique_pgrule_id()
270    # The number of body arguments is doubled, since we store each subgoal's sign and symbol.
271    rid = self.n_prules; self.n_prules += 1
272    u = f"{name}(@unify({rid}, {name}, {int(l)}, {upr}, {len(hscope)}, {2*len(body_preds)}, {h_s}{b_s})) :- {b[1]}."
273    return self.pack("prule", "", ProbRule(p, o, is_prop = False, unify = u, learnable = l,
274                                           sharing = s))
275
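  # Illustrative worked example (hypothetical rule): for the first-order
  # probabilistic rule 0.3::f(X) :- g(X), not h(X). with rule id 0, neither
  # learnable nor shared, the generated unification rule u is
  #   f(@unify(0, f, 0, -1, 1, 4, X, 1, g(X), 0, h(X))) :- g(X), not h(X).
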
276  # Annotated disjunction head.
277  def ad_head(self, H):
278    P, F = [], []
279    for i in range(0, len(H), 2):
280      P.append(H[i][2])
281      F.append(H[i+1][1])
282    return self.pack("ad_head", F, P, self.join_scope(H))
283  # Learnable annotated disjunction head.
284  def lad_head(self, H: list):
285    P, F = [], []
286    i, o, j = 0, 0, 0
287    last = None
288    while i < len(H):
289      a = H[i]
290      if a[0] == "prob":
291        P.append(a[2])
292        F.append(H[i+1][1])
293        i += 2
294      else:
295        P.append(-1)
296        F.append(a[1])
297        i += 1; o += 1
298        last = j
299      j += 1
300    if o > 0:
301      P_s = sum(P)+o
302      # If probs were not explicitly given, assume maximum uncertainty and set to uniform.
303      s = round((1.0-P_s)/o, ndigits = 15)
304      ts = P_s+s*(o-1)
305      for i, p in enumerate(P):
306        if i == last: P[i] = 1.0-ts
307        elif p < 0: P[i] = s
308    return self.pack("lad_head", F, P, self.join_scope(H))
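  # Illustrative worked example: for a learnable AD head where only the first
  # probability is given, P starts as [0.2, -1, -1] with o = 2 missing entries;
  # then P_s = (0.2 - 1 - 1) + 2 = 0.2, s = round(0.8/2, 15) = 0.4 and
  # ts = 0.2 + 0.4*(2-1) = 0.6, so the last missing entry gets 1 - 0.6 = 0.4 and
  # the other missing entries get s, yielding P = [0.2, 0.4, 0.4].
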
309  # Annotated disjunctions.
310  def ad(self, AD):
311    P, F, learnable = AD[0][2], AD[0][1], AD[0][0] == "lad_head"
312    if not math.isclose(s := sum(P), 1.0):
313      P.append(1-s)
314      F.append(unique_fact())
315    return self.pack("ad", "", AnnotatedDisjunction(P, F, learnable), AD[0][3])
316  def adr(self, AD):
317    raise NotImplementedError
318
319  def py_func_args(self, A): return "args", A[0][2]
320  def py_func_kwargs(self, A): return "kwargs", (A[0][2], A[1][2])
321  def py_func_call(self, A):
322    args = [a for k, a in A[1:] if k == "args"]
323    kwargs = dict([a for k, a in A[1:] if k == "kwargs"])
324    f = A[0][2]
325    if f not in self.torch_scope:
326      raise ValueError(f"No data definition {f} found! Either define it in a Python "
327                       "block or specify a file or URL to read from.")
328    return self.pack("py_func_call", "", self.torch_scope[f](*args, **kwargs))
329
330  def _data2tensor(self, D):
331    tp, path_or_data = D[0][0], D[0][2]
332    import pandas, numpy
333    # Is an external file or URL.
334    if tp != "py_func_call": path_or_data = pandas.read_csv(path_or_data, dtype = numpy.float32)
335    else: # is data.
336      try:
337        import torch
338      except ModuleNotFoundError:
339        raise ModuleNotFoundError("PyTorch not found! PyTorch must be installed for neural rules "
340                                  "and neural ADs.")
341      if not issubclass(type(path_or_data), torch.Tensor): path_or_data = torch.tensor(path_or_data)
342    return path_or_data
343
344  # Test set special predicate.
345  def test(self, T): return self.pack("test", "", self._data2tensor(T))
346  # Train set special predicate.
347  def train(self, R): return self.pack("train", "", self._data2tensor(R))
348
349  # Data special predicate.
350  def data(self, D):
351    name, arg = D[0][1], D[1][1]
352    test, train = D[2][2], D[3][2] if len(D) > 3 else None
353    return self.pack("data", f"{name}({arg}).", Data(name, arg, test, train))
354
355  # Python block.
356  def python(self, T):
357    exec("import torch\n\n" + T[0].value, self.torch_scope)
358    return self.pack("python", "")
359
360  # Local hubconf repo.
361  def LOCAL_NET(self, L): return self.pack("LOCAL_NET", str(L))
362  # GitHub hubconf repo.
363  def GITHUB(self, H): return self.pack("GITHUB", str(H))
364  # Python function.
365  def PY_FUNC(self, P): return self.pack("PY_FUNC", str(P))
366
367  # Hub network.
368  def hub(self, H):
369    # Function name or entrypoint.
370    func = H[0][1]
371    # Network is coming from a Torch block.
372    if len(H) == 1:
373      if func not in self.torch_scope:
374        raise ValueError(f"No network definition {func} found! Either define it in a Python "
375                         "block or specify a PyTorch Hub model (local or from GitHub).")
376      N = self.torch_scope[func]()
377      rep = f"@{func}"
378    # Network is coming from PyTorch Hub.
379    else:
380      try:
381        import torch
382      except ModuleNotFoundError:
383        raise ModuleNotFoundError("PyTorch not found! PyTorch must be installed for neural rules "
384                                  "and neural ADs.")
385      path, source = H[1][1], "github" if H[1][0] == "GITHUB" else "local"
386      N = torch.hub.load(path, func, source = source, trust_repo = "check")
387      rep = f"@{func} on \"{path}\" at \"{source}\""
388    return self.pack("hub", "", (N, rep))
389
390  # Optimizer parameters.
391  def params(self, P):
392    return self.pack("params", "", {P[i][1]: v[2] if isinstance((v := P[i+1]), self.Pack) else str(v) for i in range(0, len(P), 2)})
393
394  # Neural rule.
395  def nrule(self, A):
396    learnable = A[0][0] == "LEARN"
397    name = A[1][1]
398    inp = A[2][1]
399    offset = 3
400    outcomes = None
401    # Has more than one outcome within the neural network.
402    if A[offset][0] == "set":
403      outcomes = list(A[offset][2].keys())
404      offset += 1
405    net, hub_repr = A[offset][2]
406    if A[offset+1][0] == "params":
407      params = A[offset+1][2]
408      body = A[offset+2:]
409    else:
410      params = {}
411      body = A[offset+1:]
412    scope = self.join_scope(A)
413
414    if len(scope) != 1:  raise ValueError(f"Neural rule {name} is not grounded!")
415    if inp not in scope: raise ValueError(f"Neural rule {name} is unsafe!")
416
417    rep = f"{A[0][1]}::{name}({inp}{'' if outcomes is None else f'; {A[offset-1][1]}'}) as {hub_repr} :- {', '.join(getnths(body, 1))}."
418    return self.pack("nrule", "", (name, inp, outcomes, net, body, rep, learnable, params))
419
420  # Neural annotated disjunction.
421  def nad(self, A):
422    learnable = A[0][0] == "LEARN"
423    name = A[1][1]
424    inp = A[2][1]
425    vals = A[3][2]
426    outcomes = None
427    offset = 4
428    if A[offset][0] == "set":
429      outcomes = list(A[offset][2].keys())
430      offset += 1
431    net, hub_repr = A[offset][2]
432    if A[offset+1][0] == "params":
433      params = A[offset+1][2]
434      body = A[offset+2:]
435    else:
436      params = {}
437      body = A[offset+1:]
438    scope = self.join_scope(A)
439
440    if len(scope) != 1:  raise ValueError(f"Neural annotated disjunction {name} is not grounded!")
441    if inp not in scope: raise ValueError(f"Neural annotated disjunction {name} is unsafe!")
442
443    rep = f"{A[0][1]}::{name}({inp}, {A[3][1]}{'' if outcomes is None else f'; {A[offset-1][1]}'}) as {hub_repr} :- {', '.join(getnths(body, 1))}."
444    return self.pack("nad", "", (name, inp, vals, outcomes, net, body, rep, learnable, params))
445
446  # Constraint.
447  def constraint(self, C): return self.pack("constraint", f":- {C[0][1]}.")
448
449  # Query elements.
450  def qelement(self, E):
451    return self.pack("qelement", " ".join(getnths(E, 1)), scope = self.join_scope(E))
452  # Interpretations.
453  def interp(self, I):
454    return self.pack("interp", "", getnths(I, 1), scope = self.join_scope(I))
455  # Queries.
456  def query(self, Q):
457    Sc = self.join_scope(Q)
458    if len(Sc) > 0:
459      P = self.pack("varquery", "", VarQuery(self.varquery_id, list(Q[0][2]), \
460                                             list(Q[1][2]) if len(Q) > 1 else [], \
461                                             semantics = self.sem))
462      self.varquery_id += 1
463      return P
464    return self.pack("query", "", Query(Q[0][2], Q[1][2] if len(Q) > 1 else [], semantics = self.sem))
465
466  # Constant definition.
467  def constdef(self, C): return self.pack("constdef", f"#const {C[0][1]} = {C[1][1]}.")
468
469  @staticmethod
470  def path2obs(path: str):
471    import pandas, numpy
472    data = pandas.read_csv(path, dtype = int)
473    return lambda: (data.values, data.columns.values.tolist())
474
475  # Learning directive.
476  def learn(self, L):
477    A = {str(L[i]): str(v) if isinstance(v := L[i+1], lark.Token) else v[2] for i in range(1, len(L), 2)}
478    data = self.torch_scope[L[0][1]] if L[0][0] == "PY_FUNC" else StableTransformer.path2obs(L[0][1])
479    return self.pack("directive", "", ("learn", data, A))
480
481  # Include directive.
482  def include(self, F): return lark.visitors.Discard
483
484  def exact_inf(self, I): return ("inference", "exact", tuple())
485  def aseo_inf(self, I): return ("inference", "aseo", (I[0][2],))
486  def inference(self, I): return self.pack("directive", "", I[0])
487
488  # Semantics directive and options.
489  def SEMANTICS_OPT_LOGIC(self, _): return lark.visitors.Discard
490  def SEMANTICS_OPT_PROB(self, O): return str(O)
491  def semantics(self, S):
492    return self.pack("directive", "", ("psemantics", {"psemantics": S[0]})) if len(S) > 0 else \
493      lark.visitors.Discard
494
495  # Probabilistic Logic Program.
496  def plp(self, C) -> Program:
497    # Logic Program.
498    P  = []
499    # Probabilistic Facts.
500    PF = []
501    # Probabilistic Rules.
502    PR = []
503    # Queries.
504    Q  = []
505    # VarQueries.
506    VQ = []
507    # Credal Facts.
508    CF = []
509    # Annotated Disjunction.
510    AD = []
511    # Neural arguments and data.
512    TNR, TNA = [], []
513    D = {}
514    # Actual neural rules and neural ADs.
515    NR, NA = [], []
516    # Directives.
517    directives = {"inference": ("exact", tuple())}
518    # Mapping.
519    M = {"pfact": PF, "prule": PR, "query": Q, "varquery": VQ, "cfact": CF, "ad": AD, "nrule": TNR,
520         "nad": TNA}
521    for t, L, O, _ in C:
522      if len(L) > 0: push(P, L)
523      if t in M: push(M[t], O)
524      if t == "data":
525        if O.name in D: D[O.name].append(O)
526        else: D[O.name] = [O]
527      if t == "directive": directives[O[0]] = tup if len(tup := O[1:]) > 1 else tup[0]
528    # Deal with ungrounded probabilistic rules.
529    for r in PR:
530      if r.is_prop: PF.append(r.prop_pf)
531    self.check_data(D)
532    self.register_nrule(TNR, NR, D)
533    self.register_nad(TNA, NA, D)
534    return Program("\n".join(P), PF, PR, Q, VQ, CF, AD, NR, NA, semantics = self.sem, \
535                   directives = directives)
536
537class PartialTransformer(StableTransformer):
538  def __init__(self, sem: str, consts: dict = {}):
539    super().__init__(sem, consts)
540    self.PT = set()
541    if sem == "lstable":
542      self.sem = Semantics.LSTABLE
543    elif sem == "smproblog":
544      self.sem = Semantics.SMPROBLOG
545    else:
546      self.sem = Semantics.PARTIAL
547    self.o_tree = None
548
549  @staticmethod
550  def has_binop(x: str): return ("=" in x) or ("<" in x) or (">" in x)
551
552  def fact(self, F):
553    T = super().fact(F)
554    self.PT.add(T[2])
555    return T
556
557  def pfact(self, PF):
558    T = super().pfact(PF)
559    self.PT.add(T[2].f)
560    return T
561
562  def cfact(self, CF):
563    T = super().cfact(CF)
564    self.PT.add(T[2].f)
565    return T
566
567  def rule(self, R):
568    b1 = ", ".join(map(lambda x: x[1] if x[2][0] else f"not _{x[1][4:]}", R[1][2]))
569    b2 = ", ".join(map(lambda x: f"_{x[1]}" if x[2][0] or PartialTransformer.has_binop(x) else x[1], R[1][2]))
570    h1, h2 = R[0][1], ", ".join(map(lambda x: f"_{x[1]}", R[0][2]))
571    for h in R[0][2]: self.PT.add(h[1])
572    # for x in r[1][3]:
573      # if not PartialTransformer.has_binop(x): self.PT.add(x[4:] if x[:4] == "not " else x)
574    return self.pack("rule", [f"{h1} :- {b1}.", f"{h2} :- {b2}."])
575
576  def prule(self, R):
577    l = "LEARN" in getnths(R, 0)
578    s = "SHARED" in getnths(R, 0)
579    p = R[0][2] if R[0][0] == "prob" else 0.5
580    h, b = R[-2], R[-1]
581    tr_negs = lambda x: x[1] if x[2][0] else f"not _{x[1][4:]}"
582    tr_pos  = lambda x: f"_{x[1]}" if x[2][0] or PartialTransformer.has_binop(x) else x[1]
583    b1 = ", ".join(map(tr_negs, b[2]))
584    b2 = ", ".join(map(tr_pos, b[2]))
585    o1, o2 = f"{h[1]} :- {b1}", f"_{h[1]} :- {b2}"
586    self.PT.add(h[1])
587    uid = unique_fact()
588    S = self.join_scope(R)
589    if len(S) == 0:
590      pr1, pr2 = ProbRule(p, o1, ufact = uid, learnable = l), ProbRule(p, o2, ufact = uid)
591      self.n_prules += 2
592      return self.pack("prule", [pr1.prop_f, pr2.prop_f], [pr1, pr2])
593    # Invariant: len(b) > 0, otherwise the rule is unsafe.
594    name = h[2]
595    hscope = h[3]
596    body_preds = [x for x in b[2] if x[0] != "bop"]
597    h_s = ", ".join(hscope) + ", " if len(hscope) > 0 else ""
598    b1_s = ", ".join(map(lambda x: f"1, {x[1]}" if x[2][0] else f"0, _{x[1][4:]}", body_preds))
599    # If parameters are shared, then we require a special ID.
600    upr = -1 if not(s and l) else unique_pgrule_id()
601    # Let the grounder deal with the _f rule.
602    rid = self.n_prules; self.n_prules += 1
603    u1 = f"{name}(@unify({rid}, {name}, {int(l)}, {upr}, {len(hscope)}, {2*len(body_preds)}, {h_s}{b1_s})) :- {b1}."
604    return self.pack("prule", "", ProbRule(p, o1, is_prop = False, unify = u1, learnable = l))
605
606  def plp(self, C: list[tuple]) -> Program:
607    # Logic Program.
608    P  = []
609    # Probabilistic Facts.
610    PF = []
611    # Probabilistic Rules.
612    PR = []
613    # Queries.
614    Q  = []
615    # Variable queries.
616    VQ = []
617    # Credal Facts.
618    CF = []
619    # Annotated Disjunction.
620    AD = []
621    # Neural arguments and data.
622    TNR, TNA = [], []
623    D = {}
624    # Neural rules and ADs.
625    NR, NA = [], []
626    # Directives.
627    directives = {"inference": ("exact", tuple())}
628    # Mapping.
629    M = {"pfact": PF, "prule": PR, "query": Q, "varquery": VQ, "cfact": CF, "ad": AD, "nrule": TNR,
630         "nad": TNA}
631    for t, L, O, _ in C:
632      if len(L) > 0: push(P, L)
633      if t in M: push(M[t], O)
634      if t == "prule" and isinstance(O, collections.abc.Iterable) and O[0].is_prop:
635        PF.append(O[0].prop_pf)
636      if t == "directive": directives[O[0]] = tup if len(tup := O[1:]) > 1 else tup[0]
637    P.extend(f"_{x} :- {x}." for x in self.PT)
638    self.check_data(D)
639    self.register_nrule(TNR, NR, D)
640    self.register_nad(TNA, NA, D)
641    return Program("\n".join(P), PF, PR, Q, VQ, CF, AD, NR, NA, semantics = self.sem, \
642                   stable_p = self.stable_p, directives = directives)
643
644  def transform(self, tree):
645    self.o_tree = tree
646    self.stable_p = StableTransformer(self.sem).transform(tree)
647    return super().transform(tree)
648
649def _flatten_includes(*files: str, G: lark.Lark = None, from_str: bool = False) -> tuple:
650  transf = PreparsingTransformer()
651  T = read(*files, G=G, from_str=from_str)
652  sem, consts, includes = transf.transform(T)
653  _files = set() if from_str else set(map(os.path.abspath, set(files)))
654  to_parse = includes.difference(_files)
655  _files.update(includes)
656  while len(to_parse) > 0:
657    _T = read(*to_parse, G=G, from_str=False)
658    _sem, _consts, includes = transf.transform(_T)
659    if _sem is not None: sem = _sem
660    consts.update(_consts)
661    T.children.extend(u for u in _T.children if u not in T.children)
662    to_parse = includes.difference(_files)
663    _files.update(includes)
664  return sem, consts, T
665
666def parse(*files: str, G: lark.Lark = None, from_str: bool = False, semantics: str = "stable") -> Program:
667  """Either parses `streams` as blocks of text containing the PLP when `from_str = True`, or
668  interprets `streams` as filenames to be read and parsed into a `Program`."""
669  if semantics not in parse.trans_map:
670    raise ValueError("semantics not supported (must be 'stable', 'partial', 'lstable' or 'smproblog')!")
671  sem, consts,  T = _flatten_includes(*files, G=G, from_str=from_str)
672  if sem is not None: semantics = sem
673  return parse.trans_map[semantics](semantics, consts).transform(T)
674parse.trans_map = {}
675parse.trans_map["stable"] = StableTransformer
676parse.trans_map["lstable"] = PartialTransformer
677parse.trans_map["partial"] = PartialTransformer
678parse.trans_map["smproblog"] = PartialTransformer
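
# Illustrative usage sketch (not part of the module; the PLP text and file name
# are hypothetical, and the module is assumed importable as pasp.grammar):
#
#   from pasp import grammar
#   P = grammar.parse("0.5::a. b :- a.", from_str = True)
#   P = grammar.parse("example.plp", semantics = "lstable")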
def read( *files: str, G: lark.lark.Lark = None, from_str: bool = False, start='plp') -> lark.tree.Tree:
 9def read(*files: str, G: lark.Lark = None, from_str: bool = False, start = "plp") -> lark.Tree:
10  "Read all `files` and parse them with grammar `G`, returning a single `lark.Tree`."
11  if G is None:
12    try:
13      with open(pathlib.Path(__file__).resolve().parent.joinpath("grammar.lark"), "r") as f:
14        G = lark.Lark(f, start = start)
15    except Exception as ex:
16      raise ex
17  T = None
18  if from_str:
19    try: return G.parse("\n".join(files))
20    except Exception as ex: raise ex
21  for fname in files:
22    # For now, dump files entirely into memory (this isn't too much of a problem, since PLPs are
23    # usually small). In the future, consider streaming batches of text instead for large files.
24    try:
25      with open(fname, "r") as f:
26        text = f.read()
27        if T is None: T = G.parse(text)
28        else:
29          U = G.parse(text)
30          T.children.extend(u for u in U.children if u not in T.children)
31    except Exception as ex:
32      raise ex
33  assert T is not None, "No file read."
34  return T

Read all files and parse them with grammar G, returning a single lark.Tree.

def getnths(X: <built-in function iter>, i: int) -> <built-in function iter>:
36def getnths(X: iter, i: int) -> iter: return (x[i] for x in X)
def find(X: <built-in function iter>, v, d=None):
37def find(X: iter, v, d = None):
38  try:
39    i = X.index(v)
40    return i, X[i]
41  except ValueError: return -1, d
def push(L: list, X: <built-in function iter>):
42def push(L: list, X: iter):
43  "Pushes X to L. If X is a list, then push all elements of X to L instead of the list object itself."
44  if isinstance(X, list): L.extend(X)
45  else: L.append(X)

Pushes X to L. If X is a list, then push all elements of X to L instead of the list object itself.

def lit2atom(x: str) -> str:
47def lit2atom(x: str) -> str: return x[4:] if x[:4] == "not " else x
class PreparsingTransformer(lark.visitors._Decoratable, abc.ABC, typing.Generic[~_Leaf_T, ~_Return_T]):
49class PreparsingTransformer(lark.Transformer):
50  def __init__(self):
51    super().__init__()
52    self.consts = {}
53    self.includes = set()
54  def __default__(self, _, __, ___): return lark.visitors.Discard
55  def SEMANTICS_OPT_LOGIC(self, O): return str(O)
56  def SEMANTICS_OPT_PROB(self, _): return lark.visitors.Discard
57  def semantics(self, S): return S[0] if len(S) > 0 else lark.visitors.Discard
58  def WORD(self, W): return str(W)
59  def ID(self, I): return int(I)
60  def constdef(self, C):
61    self.consts[C[0]] = C[1]
62    return lark.visitors.Discard
63  "Verify which logic semantic should be used and record constant definitions."
64  def plp(self, S):
65    return S[0] if len(S) > 0 else None, self.consts, self.includes
66  def include(self, P):
67    self.includes.update(map(lambda x: os.path.abspath(str(x)), P))
68    return lark.visitors.Discard

Transformers work bottom-up (or depth-first), starting with visiting the leaves and working their way up until ending at the root of the tree.

For each node visited, the transformer will call the appropriate method (callbacks), according to the node's data, and use the returned value to replace the node, thereby creating a new tree structure.

Transformers can be used to implement map & reduce patterns. Because nodes are reduced from leaf to root, at any point the callbacks may assume the children have already been transformed (if applicable).

If the transformer cannot find a method with the right name, it will instead call __default__, which by default creates a copy of the node.

To discard a node, return Discard (lark.visitors.Discard).

Transformer can do anything Visitor can do, but because it reconstructs the tree, it is slightly less efficient.

A transformer without methods essentially performs a non-memoized partial deepcopy.

All these classes implement the transformer interface:

  • Transformer - Recursively transforms the tree. This is the one you probably want.
  • Transformer_InPlace - Non-recursive. Changes the tree in-place instead of returning new instances
  • Transformer_InPlaceRecursive - Recursive. Changes the tree in-place instead of returning new instances

Parameters: visit_tokens (bool, optional): Should the transformer visit tokens in addition to rules. Setting this to False is slightly faster. Defaults to True. (For processing ignored tokens, use the lexer_callbacks options)

consts
includes
def SEMANTICS_OPT_LOGIC(self, O):
55  def SEMANTICS_OPT_LOGIC(self, O): return str(O)
def SEMANTICS_OPT_PROB(self, _):
56  def SEMANTICS_OPT_PROB(self, _): return lark.visitors.Discard
def semantics(self, S):
57  def semantics(self, S): return S[0] if len(S) > 0 else lark.visitors.Discard
def WORD(self, W):
58  def WORD(self, W): return str(W)
def ID(self, I):
59  def ID(self, I): return int(I)
def constdef(self, C):
60  def constdef(self, C):
61    self.consts[C[0]] = C[1]
62    return lark.visitors.Discard
def plp(self, S):
64  def plp(self, S):
65    return S[0] if len(S) > 0 else None, self.consts, self.includes
def include(self, P):
66  def include(self, P):
67    self.includes.update(map(lambda x: os.path.abspath(str(x)), P))
68    return lark.visitors.Discard
class StableTransformer(lark.visitors._Decoratable, abc.ABC, typing.Generic[~_Leaf_T, ~_Return_T]):
 70class StableTransformer(lark.Transformer):
 71  class Pack(tuple):
 72    @staticmethod
 73    def __new__(cls, tp: str, r: str = None, v = None, sc: dict = {}):
 74      return super(StableTransformer.Pack, cls).__new__(cls, (tp, str(v) if r is None else r, \
 75                                                              r if v is None else v, sc))
 76    def __str__(self): return self[1]
 77    def __repr__(self): return f"<{self[0]}: {self.__str__()}>"
 78
 79  def __init__(self, _, consts: dict = {}, scope: dict = None):
 80    super().__init__()
 81    self.sem = Semantics.STABLE
 82    self.torch_scope = {} if scope is None else scope
 83    self.n_prules = 0
 84    self.consts = consts
 85    self.varquery_id = 0
 86
 87  @staticmethod
 88  def pack(t: str, rep: str = None, val = None, scope: dict = {}) -> tuple[str, str, str, dict]:
 89    return StableTransformer.Pack(t, rep, val, scope)
 90
 91  @staticmethod
 92  def join_scope(A: list) -> dict: return dict((y, None) for S in A for y in S[3])
 93
 94  @staticmethod
 95  def find_data_pred(D: dict, body: list, which: str, name: str) -> list:
 96    t = None
 97    for d in body:
 98      if d[2][1] in D:
 99        t = D[d[2][1]]
100        break
101    if t is None: raise ValueError(f"Neural {which} {name} must contain a data predicate!")
102    return t
103
104  @staticmethod
105  def check_data(D: list):
106    "Checks if all data have same first dimension size."
107    n, m = -1, -1
108    for X in D.values():
109      if n < 0: n = X[0].test.shape[0]
110      if (X[0].train is not None) and (m < 0): m = X[0].train.shape[0]
111      for x in X:
112        if x.test.shape[0] != n:
113          raise ValueError("Test data must have same number of instances!")
114        if (x.train is not None) and (x.train.shape[0] != m):
115          raise ValueError("Train data must have same number of instances!")
116
117  @staticmethod
118  def cont_head_sym(name: str, T: list, O: list, V: list = None):
119    g, S = None, None
120    if V is None:
121      if O is None: S = [f"{name}({t.arg})" for t in T]
122      else: S = [f"{name}({t.arg}, {o})" for t in T for o in O]
123    else:
124      if O is None: S = [f"{name}({t.arg}, {v})" for t in T for v in V]
125      else: S = [f"{name}({t.arg}, {v}, {o})" for t in T for o in O for v in V]
126    g = (clingo.parse_term(s)._rep for s in S)
127    return contiguous(tuple(g), dtype=numpy.uint64), S
128
129  @staticmethod
130  def register_nrule(TNR: list, NR: list, D: list):
131    for name, inp, O, net, body, rep, learnable, params in TNR:
132      t = StableTransformer.find_data_pred(D, body, "rule", name)
133      # Ground rules.
134      H, _ = StableTransformer.cont_head_sym(name, t, O)
135      B, S = None, None
136      if len(body) > 1:
137        # B and S do not depend on the number of outcomes |O|, only on |t| and |body|.
138        body_no_data = [b for b in body if b[2][1] != t[0].name]
139        B = contiguous(tuple(clingo.parse_term(f"{b[2][1]}({t[i].arg})" if len(b[3]) > 0 \
140                                                else lit2atom(b[1]))._rep for i in range(len(t)) \
141                              for b in body_no_data), dtype = numpy.uint64)
142        S = contiguous(tuple(b[2][0] for i in range(len(t)) for b in body_no_data), dtype = bool)
143      NR.append(NeuralRule(H, B, S, name, net, rep, t, learnable, params, O))
144
145  @staticmethod
146  def register_nad(TNA: list, NA: list, D: list):
147    for name, inp, vals, O, net, body, rep, learnable, params in TNA:
148      t = StableTransformer.find_data_pred(D, body, "AD", name)
149      # Ground rules.
150      V = list(vals.keys())
151      H, H_s = StableTransformer.cont_head_sym(name, t, O, V)
152      B, S = None, None
153      if len(body) > 1:
154        body_no_data = [b for b in body if b[2][1] != t[0].name]
155        # B and S do not depend on the number of values |V| or outcomes |O|, only on |t| and |body|.
156        B = contiguous(tuple(clingo.parse_term(f"{b[2][1]}({t[i].arg})" if len(b[3]) > 0 \
157                                                else lit2atom(b[1]))._rep for i in range(len(t)) \
158                              for b in body_no_data), dtype = numpy.uint64)
159        S = contiguous(tuple(b[2][0] for i in range(len(t)) for b in body_no_data), dtype = bool)
160      NA.append(NeuralAD(H, B, S, name, V, net, rep, t, learnable, params, O, H_s))
161
162  def __default__(self, _, __, ___): return lark.visitors.Discard
163
164  # Components which are directly translated to clingo.
165  def CMP_OP(self, o): return self.pack("CMP_OP", str(o))
166  def aggr(self, A): return self.pack("aggr", "".join(str(x) for x in A))
167  def raggr(self, A): return self.pack("raggr", "".join(str(x) for x in A))
168  def caggr(self, A): return self.pack("caggr", "".join(str(x) for x in A))
169
170  # Terminals.
171  def UND(self, u): return self.pack("UND", str(u))
172  def WORD(self, c): return self.pack("WORD", str(c))
173  def NEG(self, n): return self.pack("NEG", str(n))
174  def VAR(self, v):
175    x = str(v); X = {v: None}
176    return self.pack("VAR", x, scope = X)
177  def ID(self, i): return self.pack("ID", val = int(i))
178  def OP(self, o): return self.pack("OP", str(o))
179  def REAL(self, r): return self.pack("REAL", val = float(r))
180  def frac(self, f): return self.pack("frac", val = f[0][2]/f[1][2])
181  def prob(self, p): return self.pack("prob", val = p[0][2])
182  def SHARED(self, s): return self.pack("SHARED",  str(s))
183  def LEARN(self, l): return self.pack("LEARN", str(l))
184  def CONST(self, c): return self.pack("CONST", str(c))
185  def BOOL(self, b): return self.pack("BOOL", v := b.lower(), v != "false")
186  def NULL(self, n): return self.pack("NULL", None, None)
187
188  # Path.
189  def path(self, p): return self.pack(p[0].type, p[0].value)
190
191  # Set.
192  def set(self, S):
193    if S[0][0] == "interval":
194      a, b = S[0][2]
195      if isinstance(a, str):
196        if a not in self.consts: raise KeyError(f"Constant {a} is undefined!")
197        a = self.consts[a]
198      if isinstance(b, str):
199        if b not in self.consts: raise KeyError(f"Constant {b} is undefined!")
200        b = self.consts[b]
201      M = dict((str(i), None) for i in range(a, b+1))
202    else:
203      M = dict((x[1], None) for x in S)
204      if len(M) != len(S): raise ValueError("set must contain only unique constants!")
205    return self.pack("set", f"{{{','.join(x for x in M.keys())}}}", M)
206
207  # Intervals.
208  def interval(self, I): return self.pack("interval", f"{I[0][2]}..{I[1][2]}", (I[0][2], I[1][2]))
209
210  # Predicates.
211  def pred(self, P, replace_semicolons = False):
212    name = P[0][1]
213    rep = f"{name}({', '.join(getnths(P[1:], 1))})"
214    return self.pack("pred", rep.replace(";", ",") if replace_semicolons else rep, name, self.join_scope(P))
215  def grpred(self, P): return self.pred(P)
216  def query_pred(self, P): return self.pred(P, replace_semicolons = True)
217
218  # Literals.
219  def lit(self, P):
220    s = P[0][0] != "NEG"
221    return self.pack("lit", " ".join(getnths(P, 1)), (s, P[0][2] if s else P[1][2]), self.join_scope(P))
222  def grlit(self, P): return self.lit(P)
223
224  # Binary operations.
225  def bop(self, B) -> str: return self.pack("bop", " ".join(getnths(B, 1)))
226
227  # Facts.
228  def fact(self, F):
229    f = f"{''.join(getnths(F, 1))}"
230    # Facts are always grounded.
231    return self.pack("fact", f + ".", f)
232  def pfact(self, PF):
233    p, f = PF[0][2], PF[1][1]
234    return self.pack("pfact", "", ProbFact(p, f))
235  def cfact(self, CF):
236    l, u, f = CF[0][2], CF[1][2], CF[2][1]
237    return self.pack("cfact", "", CredalFact(l, u, f))
238  def lpfact(self, PF):
239    if PF[0][0] == "prob": p, f = PF[0][2], PF[1][1]
240    else: p, f = 0.5, PF[0][1]
241    return self.pack("pfact", "", ProbFact(p, f, learnable = True))
242
243  # Heads.
244  def head(self, H): return self.pack("head", ", ".join(getnths(H, 1)), H, self.join_scope(H))
245  def ohead(self, H): return self.pack("head", H[0][1], H[0][2], H[0][3])
246  # Body.
247  def body(self, B): return self.pack("body", ", ".join(getnths(B, 1)), B, self.join_scope(B))
248
249  # Rules.
250  def rule(self, R): return self.pack("rule", " :- ".join(getnths(R, 1)) + ".")
251  def prule(self, R):
252    l = "LEARN" in getnths(R, 0)
253    s = "SHARED" in getnths(R, 0)
254    h, b = R[-2], R[-1]
255    o = f"{h[1]} :- {b[1]}"
256    p = R[0][2] if R[0][0] == "prob" else 0.5
257    S = self.join_scope(R)
258    if len(S) == 0:
259      pr = ProbRule(p, o, is_prop = True, learnable = l)
260      self.n_prules += 1
261      return self.pack("prule", pr.prop_f, pr)
262    # Invariant: len(b) > 0, otherwise the rule is unsafe.
263    name = h[2]
264    # hscope is guaranteed to be ordered by Python dict's definition.
265    hscope = h[3]
266    body_preds = [x for x in b[2] if x[0] != "bop"]
267    h_s = ", ".join(hscope) + ", " if len(hscope) > 0 else ""
268    b_s = ", ".join(map(lambda x: f"1, {x[1]}" if x[2][0] else f"0, {x[1][4:]}", body_preds))
269    # If parameters are shared, then we require a special ID.
270    upr = -1 if not (s and l) else unique_pgrule_id()
271    # The number of body arguments is twice as we need to store the sugoal's sign and symbol.
272    rid = self.n_prules; self.n_prules += 1
273    u = f"{name}(@unify({rid}, {name}, {int(l)}, {upr}, {len(hscope)}, {2*len(body_preds)}, {h_s}{b_s})) :- {b[1]}."
274    return self.pack("prule", "", ProbRule(p, o, is_prop = False, unify = u, learnable = l,
275                                           sharing = s))
276
277  # Annotated disjunction head.
278  def ad_head(self, H):
279    P, F = [], []
280    for i in range(0, len(H), 2):
281      P.append(H[i][2])
282      F.append(H[i+1][1])
283    return self.pack("ad_head", F, P, self.join_scope(H))
284  # Learnable annotated disjunction head.
285  def lad_head(self, H: list):
286    P, F = [], []
287    i, o, j = 0, 0, 0
288    last = None
289    while i < len(H):
290      a = H[i]
291      if a[0] == "prob":
292        P.append(a[2])
293        F.append(H[i+1][1])
294        i += 2
295      else:
296        P.append(-1)
297        F.append(a[1])
298        i += 1; o += 1
299        last = j
300      j += 1
301    if o > 0:
302      P_s = sum(P)+o
303      # If probs were not explicitly given, assume maximum uncertainty and set to uniform.
304      s = round((1.0-P_s)/o, ndigits = 15)
305      ts = P_s+s*(o-1)
306      for i, p in enumerate(P):
307        if i == last: P[i] = 1.0-ts
308        elif p < 0: P[i] = s
309    return self.pack("lad_head", F, P, self.join_scope(H))
310  # Annotated disjunctions.
311  def ad(self, AD):
312    P, F, learnable = AD[0][2], AD[0][1], AD[0][0] == "lad_head"
313    if not math.isclose(s := sum(P), 1.0):
314      P.append(1-s)
315      F.append(unique_fact())
316    return self.pack("ad", "", AnnotatedDisjunction(P, F, learnable), AD[0][3])
317  def adr(self, AD):
318    raise NotImplementedError
319
320  def py_func_args(self, A): return "args", A[0][2]
321  def py_func_kwargs(self, A): return "kwargs", (A[0][2], A[1][2])
322  def py_func_call(self, A):
323    args = [a for k, a in A[1:] if k == "args"]
324    kwargs = dict([a for k, a in A[1:] if k == "kwargs"])
325    f = A[0][2]
326    if f not in self.torch_scope:
327      raise ValueError(f"No data definition {f} found! Either define it in a Python "
328                       "block or specify a file or URL to read from.")
329    return self.pack("py_func_call", "", self.torch_scope[f](*args, **kwargs))
330
331  def _data2tensor(self, D):
332    tp, path_or_data = D[0][0], D[0][2]
333    import pandas, numpy
334    # Is an external file or URL.
335    if tp != "py_func_call": data = pandas.read_csv(path_or_data, dtype = numpy.float32)
336    else: # is data.
337      try:
338        import torch
339      except ModuleNotFoundError:
340        raise ModuleNotFoundError("PyTorch not found! PyTorch must be installed for neural rules "
341                                  "and neural ADs.")
342      if not issubclass(type(path_or_data), torch.Tensor): path_or_data = torch.tensor(path_or_data)
343    return path_or_data
344
345  # Test set special predicate.
346  def test(self, T): return self.pack("test", "", self._data2tensor(T))
347  # Train set special predicate.
348  def train(self, R): return self.pack("train", "", self._data2tensor(R))
349
350  # Data special predicate.
351  def data(self, D):
352    name, arg = D[0][1], D[1][1]
353    test, train = D[2][2], D[3][2] if len(D) > 3 else None
354    return self.pack("data", f"{name}({arg}).", Data(name, arg, test, train))
355
356  # Python block.
357  def python(self, T):
358    exec("import torch\n\n" + T[0].value, self.torch_scope)
359    return self.pack("python", "")
360
361  # Local hubconf repo.
362  def LOCAL_NET(self, L): return self.pack("LOCAL_NET", str(L))
363  # GitHub hubconf repo.
364  def GITHUB(self, H): return self.pack("GITHUB", str(H))
365  # Python function.
366  def PY_FUNC(self, P): return self.pack("PY_FUNC", str(P))
367
368  # Hub network.
369  def hub(self, H):
370    # Function name or entrypoint.
371    func = H[0][1]
372    # Network is coming from a Torch block.
373    if len(H) == 1:
374      if func not in self.torch_scope:
375        raise ValueError(f"No network definition {func} found! Either define it in a Python"
376                         "block or specify a PyTorch Hub model (local or from GitHub).")
377      N = self.torch_scope[func]()
378      rep = f"@{func}"
379    # Network is coming from PyTorch Hub.
380    else:
381      try:
382        import torch
383      except ModuleNotFoundError:
384        raise ModuleNotFoundError("PyTorch not found! PyTorch must be installed for neural rules "
385                                  "and neural ADs.")
386      path, source = H[1][1], "github" if H[1][0] == "GITHUB" else "local"
387      N = torch.hub.load(path, func, source = source, trust_repo = "check")
388      rep = f"@{func} on \"{path}\" at \"{source}\""
389    return self.pack("hub", "", (N, rep))
390
391  # Optimizer parameters.
392  def params(self, P):
393    return self.pack("params", "", {P[i][1]: v[2] if isinstance((v := P[i+1]), self.Pack) else str(v) for i in range(0, len(P), 2)})
394
395  # Neural rule.
396  def nrule(self, A):
397    learnable = A[0][0] == "LEARN"
398    name = A[1][1]
399    inp = A[2][1]
400    offset = 3
401    outcomes = None
402    # Has more than one outcome within the neural network.
403    if A[offset][0] == "set":
404      outcomes = list(A[offset][2].keys())
405      offset += 1
406    net, hub_repr = A[offset][2]
407    if A[offset+1][0] == "params":
408      params = A[offset+1][2]
409      body = A[offset+2:]
410    else:
411      params = {}
412      body = A[offset+1:]
413    scope = self.join_scope(A)
414
415    if len(scope) != 1:  raise ValueError(f"Neural rule {name} is not grounded!")
416    if inp not in scope: raise ValueError(f"Neural rule {name} is unsafe!")
417
418    rep = f"{A[0][1]}::{name}({inp}{'' if outcomes is None else f'; {A[offset-1][1]}'}) as {hub_repr} :- {', '.join(getnths(body, 1))}."
419    return self.pack("nrule", "", (name, inp, outcomes, net, body, rep, learnable, params))
420
421  # Neural annotated disjunction.
422  def nad(self, A):
423    learnable = A[0][0] == "LEARN"
424    name = A[1][1]
425    inp = A[2][1]
426    vals = A[3][2]
427    outcomes = None
428    offset = 4
429    if A[offset][0] == "set":
430      outcomes = list(A[offset][2].keys())
431      offset += 1
432    net, hub_repr = A[offset][2]
433    if A[offset+1][0] == "params":
434      params = A[offset+1][2]
435      body = A[offset+2:]
436    else:
437      params = {}
438      body = A[offset+1:]
439    scope = self.join_scope(A)
440
441    if len(scope) != 1:  raise ValueError(f"Neural annotated disjunction {name} is not grounded!")
442    if inp not in scope: raise ValueError(f"Neural annotated disjunction {name} is unsafe!")
443
444    rep = f"{A[0][1]}::{name}({inp}, {A[3][1]}{'' if outcomes is None else f'; {A[offset-1][1]}'}) as {hub_repr} :- {', '.join(getnths(body, 1))}."
445    return self.pack("nad", "", (name, inp, vals, outcomes, net, body, rep, learnable, params))
446
447  # Constraint.
448  def constraint(self, C): return self.pack("constraint", f":- {C[0][1]}.")
449
450  # Query elements.
451  def qelement(self, E):
452    return self.pack("qelement", " ".join(getnths(E, 1)), scope = self.join_scope(E))
453  # Interpretations.
454  def interp(self, I):
455    return self.pack("interp", "", getnths(I, 1), scope = self.join_scope(I))
456  # Queries.
457  def query(self, Q):
458    Sc = self.join_scope(Q)
459    if len(Sc) > 0:
460      P = self.pack("varquery", "", VarQuery(self.varquery_id, list(Q[0][2]), \
461                                             list(Q[1][2]) if len(Q) > 1 else [], \
462                                             semantics = self.sem))
463      self.varquery_id += 1
464      return P
465    return self.pack("query", "", Query(Q[0][2], Q[1][2] if len(Q) > 1 else [], semantics = self.sem))
466
467  # Constant definition.
468  def constdef(self, C): return self.pack("constdef", f"#const {C[0][1]} = {C[1][1]}.")
469
470  @staticmethod
471  def path2obs(path: str):
472    import pandas, numpy
473    data = pandas.read_csv(path, dtype = int)
474    return lambda: (data.values, data.columns.values.tolist())
475
476  # Learning directive.
477  def learn(self, L):
478    A = {str(L[i]): str(v) if isinstance(v := L[i+1], lark.Token) else v[2] for i in range(1, len(L), 2)}
479    data = self.torch_scope[L[0][1]] if L[0][0] == "PY_FUNC" else StableTransformer.path2obs(L[0][1])
480    return self.pack("directive", "", ("learn", data, A))
481
482  # Include directive.
483  def include(self, F): return lark.visitors.Discard
484
485  def exact_inf(self, I): return ("inference", "exact", tuple())
486  def aseo_inf(self, I): return ("inference", "aseo", (I[0][2],))
487  def inference(self, I): return self.pack("directive", "", I[0])
488
489  # Semantics directive and options.
490  def SEMANTICS_OPT_LOGIC(self, _): return lark.visitors.Discard
491  def SEMANTICS_OPT_PROB(self, O): return str(O)
492  def semantics(self, S):
493    return self.pack("directive", "", ("psemantics", {"psemantics": S[0]})) if len(S) > 0 else \
494      lark.visitors.Discard
495
496  # Probabilistic Logic Program.
497  def plp(self, C) -> Program:
498    # Logic Program.
499    P  = []
500    # Probabilistic Facts.
501    PF = []
502    # Probabilistic Rules.
503    PR = []
504    # Queries.
505    Q  = []
506    # VarQueries.
507    VQ = []
508    # Credal Facts.
509    CF = []
510    # Annotated Disjunction.
511    AD = []
512    # Neural arguments and data.
513    TNR, TNA = [], []
514    D = {}
515    # Actual neural rules and neural ADs.
516    NR, NA = [], []
517    # Directives.
518    directives = {"inference": ("exact", tuple())}
519    # Mapping.
520    M = {"pfact": PF, "prule": PR, "query": Q, "varquery": VQ, "cfact": CF, "ad": AD, "nrule": TNR,
521         "nad": TNA}
522    for t, L, O, _ in C:
523      if len(L) > 0: push(P, L)
524      if t in M: push(M[t], O)
525      if t == "data":
526        if O.name in D: D[O.name].append(O)
527        else: D[O.name] = [O]
528      if t == "directive": directives[O[0]] = tup if len(tup := O[1:]) > 1 else tup[0]
529    # Deal with ungrounded probabilistic rules.
530    for r in PR:
531      if r.is_prop: PF.append(r.prop_pf)
532    self.check_data(D)
533    self.register_nrule(TNR, NR, D)
534    self.register_nad(TNA, NA, D)
535    return Program("\n".join(P), PF, PR, Q, VQ, CF, AD, NR, NA, semantics = self.sem, \
536                   directives = directives)

Transformers work bottom-up (or depth-first), starting with visiting the leaves and working their way up until ending at the root of the tree.

For each node visited, the transformer will call the appropriate method (callbacks), according to the node's data, and use the returned value to replace the node, thereby creating a new tree structure.

Transformers can be used to implement map & reduce patterns. Because nodes are reduced from leaf to root, at any point the callbacks may assume the children have already been transformed (if applicable).

If the transformer cannot find a method with the right name, it will instead call __default__, which by default creates a copy of the node.

To discard a node, return Discard (lark.visitors.Discard).

Transformer can do anything Visitor can do, but because it reconstructs the tree, it is slightly less efficient.

A transformer without methods essentially performs a non-memoized partial deepcopy.

All these classes implement the transformer interface:

  • Transformer - Recursively transforms the tree. This is the one you probably want.
  • Transformer_InPlace - Non-recursive. Changes the tree in-place instead of returning new instances
  • Transformer_InPlaceRecursive - Recursive. Changes the tree in-place instead of returning new instances

Parameters: visit_tokens (bool, optional): Should the transformer visit tokens in addition to rules. Setting this to False is slightly faster. Defaults to True. (For processing ignored tokens, use the lexer_callbacks options)

StableTransformer(_, consts: dict = {}, scope: dict = None)
79  def __init__(self, _, consts: dict = {}, scope: dict = None):
80    super().__init__()
81    self.sem = Semantics.STABLE
82    self.torch_scope = {} if scope is None else scope
83    self.n_prules = 0
84    self.consts = consts
85    self.varquery_id = 0
sem
torch_scope
n_prules
consts
varquery_id
@staticmethod
def pack( t: str, rep: str = None, val=None, scope: dict = {}) -> tuple[str, str, str, dict]:
87  @staticmethod
88  def pack(t: str, rep: str = None, val = None, scope: dict = {}) -> tuple[str, str, str, dict]:
89    return StableTransformer.Pack(t, rep, val, scope)
@staticmethod
def join_scope(A: list) -> dict:
91  @staticmethod
92  def join_scope(A: list) -> dict: return dict((y, None) for S in A for y in S[3])
@staticmethod
def find_data_pred(D: dict, body: list, which: str, name: str) -> list:
 94  @staticmethod
 95  def find_data_pred(D: dict, body: list, which: str, name: str) -> list:
 96    t = None
 97    for d in body:
 98      if d[2][1] in D:
 99        t = D[d[2][1]]
100        break
101    if t is None: raise ValueError(f"Neural {which} {name} must contain a data predicate!")
102    return t
@staticmethod
def check_data(D: list):
104  @staticmethod
105  def check_data(D: list):
106    "Checks if all data have same first dimension size."
107    n, m = -1, -1
108    for X in D.values():
109      if n < 0: n = X[0].test.shape[0]
110      if (X[0].train is not None) and (m < 0): m = X[0].train.shape[0]
111      for x in X:
112        if x.test.shape[0] != n:
113          raise ValueError("Test data must have same number of instances!")
114        if (x.train is not None) and (x.train.shape[0] != m):
115          raise ValueError("Train data must have same number of instances!")

Checks that all data have the same first dimension size, i.e. the same number of instances.
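
Sketch of the invariant this enforces (shapes are illustrative): every data predicate must describe the same set of instances, so the first dimension of all test arrays (and, when present, all train arrays) must agree.

  import numpy
  image_test = numpy.zeros((10000, 784))   # 10000 test instances
  label_test = numpy.zeros((10000, 1))     # OK: same first dimension
  other_test = numpy.zeros((5000, 784))    # mixing this in raises ValueError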

@staticmethod
def cont_head_sym(name: str, T: list, O: list, V: list = None):
117  @staticmethod
118  def cont_head_sym(name: str, T: list, O: list, V: list = None):
119    g, S = None, None
120    if V is None:
121      if O is None: S = [f"{name}({t.arg})" for t in T]
122      else: S = [f"{name}({t.arg}, {o})" for t in T for o in O]
123    else:
124      if O is None: S = [f"{name}({t.arg}, {v})" for t in T for v in V]
125      else: S = [f"{name}({t.arg}, {v}, {o})" for t in T for o in O for v in V]
126    g = (clingo.parse_term(s)._rep for s in S)
127    return contiguous(tuple(g), dtype=numpy.uint64), S
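
As an illustration of the grounding step above (the atoms are made up), each head symbol is parsed with clingo and stored through its internal numeric representation (_rep), packed into a contiguous uint64 array:

  import clingo, numpy
  atoms = ["digit(x0, 0)", "digit(x0, 1)"]
  H = numpy.ascontiguousarray([clingo.parse_term(a)._rep for a in atoms],
                              dtype = numpy.uint64)
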
@staticmethod
def register_nrule(TNR: list, NR: list, D: list):
129  @staticmethod
130  def register_nrule(TNR: list, NR: list, D: list):
131    for name, inp, O, net, body, rep, learnable, params in TNR:
132      t = StableTransformer.find_data_pred(D, body, "rule", name)
133      # Ground rules.
134      H, _ = StableTransformer.cont_head_sym(name, t, O)
135      B, S = None, None
136      if len(body) > 1:
137        # B and S do not depend on the number of outcomes |O|, only on |t| and |body|.
138        body_no_data = [b for b in body if b[2][1] != t[0].name]
139        B = contiguous(tuple(clingo.parse_term(f"{b[2][1]}({t[i].arg})" if len(b[3]) > 0 \
140                                                else lit2atom(b[1]))._rep for i in range(len(t)) \
141                              for b in body_no_data), dtype = numpy.uint64)
142        S = contiguous(tuple(b[2][0] for i in range(len(t)) for b in body_no_data), dtype = bool)
143      NR.append(NeuralRule(H, B, S, name, net, rep, t, learnable, params, O))
@staticmethod
def register_nad(TNA: list, NA: list, D: list):
145  @staticmethod
146  def register_nad(TNA: list, NA: list, D: list):
147    for name, inp, vals, O, net, body, rep, learnable, params in TNA:
148      t = StableTransformer.find_data_pred(D, body, "AD", name)
149      # Ground rules.
150      V = list(vals.keys())
151      H, H_s = StableTransformer.cont_head_sym(name, t, O, V)
152      B, S = None, None
153      if len(body) > 1:
154        body_no_data = [b for b in body if b[2][1] != t[0].name]
155        # B and S do not depend on the number of values |V| or outcomes |O|, only on |t| and |body|.
156        B = contiguous(tuple(clingo.parse_term(f"{b[2][1]}({t[i].arg})" if len(b[3]) > 0 \
157                                                else lit2atom(b[1]))._rep for i in range(len(t)) \
158                              for b in body_no_data), dtype = numpy.uint64)
159        S = contiguous(tuple(b[2][0] for i in range(len(t)) for b in body_no_data), dtype = bool)
160      NA.append(NeuralAD(H, B, S, name, V, net, rep, t, learnable, params, O, H_s))
def CMP_OP(self, o):
165  def CMP_OP(self, o): return self.pack("CMP_OP", str(o))
def aggr(self, A):
166  def aggr(self, A): return self.pack("aggr", "".join(str(x) for x in A))
def raggr(self, A):
167  def raggr(self, A): return self.pack("raggr", "".join(str(x) for x in A))
def caggr(self, A):
168  def caggr(self, A): return self.pack("caggr", "".join(str(x) for x in A))
def UND(self, u):
171  def UND(self, u): return self.pack("UND", str(u))
def WORD(self, c):
172  def WORD(self, c): return self.pack("WORD", str(c))
def NEG(self, n):
173  def NEG(self, n): return self.pack("NEG", str(n))
def VAR(self, v):
174  def VAR(self, v):
175    x = str(v); X = {v: None}
176    return self.pack("VAR", x, scope = X)
def ID(self, i):
177  def ID(self, i): return self.pack("ID", val = int(i))
def OP(self, o):
178  def OP(self, o): return self.pack("OP", str(o))
def REAL(self, r):
179  def REAL(self, r): return self.pack("REAL", val = float(r))
def frac(self, f):
180  def frac(self, f): return self.pack("frac", val = f[0][2]/f[1][2])
def prob(self, p):
181  def prob(self, p): return self.pack("prob", val = p[0][2])
def SHARED(self, s):
182  def SHARED(self, s): return self.pack("SHARED",  str(s))
def LEARN(self, l):
183  def LEARN(self, l): return self.pack("LEARN", str(l))
def CONST(self, c):
184  def CONST(self, c): return self.pack("CONST", str(c))
def BOOL(self, b):
185  def BOOL(self, b): return self.pack("BOOL", v := b.lower(), v != "false")
def NULL(self, n):
186  def NULL(self, n): return self.pack("NULL", None, None)
def path(self, p):
189  def path(self, p): return self.pack(p[0].type, p[0].value)
def set(self, S):
192  def set(self, S):
193    if S[0][0] == "interval":
194      a, b = S[0][2]
195      if isinstance(a, str):
196        if a not in self.consts: raise KeyError(f"Constant {a} is undefined!")
197        a = self.consts[a]
198      if isinstance(b, str):
199        if b not in self.consts: raise KeyError(f"Constant {b} is undefined!")
200        b = self.consts[b]
201      M = dict((str(i), None) for i in range(a, b+1))
202    else:
203      M = dict((x[1], None) for x in S)
204      if len(M) != len(S): raise ValueError("set must contain only unique constants!")
205    return self.pack("set", f"{{{','.join(x for x in M.keys())}}}", M)
def interval(self, I):
208  def interval(self, I): return self.pack("interval", f"{I[0][2]}..{I[1][2]}", (I[0][2], I[1][2]))
def pred(self, P, replace_semicolons=False):
211  def pred(self, P, replace_semicolons = False):
212    name = P[0][1]
213    rep = f"{name}({', '.join(getnths(P[1:], 1))})"
214    return self.pack("pred", rep.replace(";", ",") if replace_semicolons else rep, name, self.join_scope(P))
def grpred(self, P):
215  def grpred(self, P): return self.pred(P)
def query_pred(self, P):
216  def query_pred(self, P): return self.pred(P, replace_semicolons = True)
def lit(self, P):
219  def lit(self, P):
220    s = P[0][0] != "NEG"
221    return self.pack("lit", " ".join(getnths(P, 1)), (s, P[0][2] if s else P[1][2]), self.join_scope(P))
def grlit(self, P):
222  def grlit(self, P): return self.lit(P)
def bop(self, B) -> str:
225  def bop(self, B) -> str: return self.pack("bop", " ".join(getnths(B, 1)))
def fact(self, F):
228  def fact(self, F):
229    f = f"{''.join(getnths(F, 1))}"
230    # Facts are always grounded.
231    return self.pack("fact", f + ".", f)
def pfact(self, PF):
232  def pfact(self, PF):
233    p, f = PF[0][2], PF[1][1]
234    return self.pack("pfact", "", ProbFact(p, f))
def cfact(self, CF):
235  def cfact(self, CF):
236    l, u, f = CF[0][2], CF[1][2], CF[2][1]
237    return self.pack("cfact", "", CredalFact(l, u, f))
def lpfact(self, PF):
238  def lpfact(self, PF):
239    if PF[0][0] == "prob": p, f = PF[0][2], PF[1][1]
240    else: p, f = 0.5, PF[0][1]
241    return self.pack("pfact", "", ProbFact(p, f, learnable = True))
def head(self, H):
244  def head(self, H): return self.pack("head", ", ".join(getnths(H, 1)), H, self.join_scope(H))
def ohead(self, H):
245  def ohead(self, H): return self.pack("head", H[0][1], H[0][2], H[0][3])
def body(self, B):
247  def body(self, B): return self.pack("body", ", ".join(getnths(B, 1)), B, self.join_scope(B))
def rule(self, R):
250  def rule(self, R): return self.pack("rule", " :- ".join(getnths(R, 1)) + ".")
def prule(self, R):
251  def prule(self, R):
252    l = "LEARN" in getnths(R, 0)
253    s = "SHARED" in getnths(R, 0)
254    h, b = R[-2], R[-1]
255    o = f"{h[1]} :- {b[1]}"
256    p = R[0][2] if R[0][0] == "prob" else 0.5
257    S = self.join_scope(R)
258    if len(S) == 0:
259      pr = ProbRule(p, o, is_prop = True, learnable = l)
260      self.n_prules += 1
261      return self.pack("prule", pr.prop_f, pr)
262    # Invariant: len(b) > 0, otherwise the rule is unsafe.
263    name = h[2]
264    # hscope is guaranteed to be ordered by Python dict's definition.
265    hscope = h[3]
266    body_preds = [x for x in b[2] if x[0] != "bop"]
267    h_s = ", ".join(hscope) + ", " if len(hscope) > 0 else ""
268    b_s = ", ".join(map(lambda x: f"1, {x[1]}" if x[2][0] else f"0, {x[1][4:]}", body_preds))
269    # If parameters are shared, then we require a special ID.
270    upr = -1 if not (s and l) else unique_pgrule_id()
271    # The number of body arguments is doubled, since we store each subgoal's sign and symbol.
272    rid = self.n_prules; self.n_prules += 1
273    u = f"{name}(@unify({rid}, {name}, {int(l)}, {upr}, {len(hscope)}, {2*len(body_preds)}, {h_s}{b_s})) :- {b[1]}."
274    return self.pack("prule", "", ProbRule(p, o, is_prop = False, unify = u, learnable = l,
275                                           sharing = s))
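
For illustration (a trace reconstructed from the code above, assuming ProbLog-style syntax and that this is the first probabilistic rule parsed, so rid = 0), a non-ground rule such as

  0.3::f(X) :- g(X).

has scope {X}, is therefore not propositional, and is rewritten into the grounder-facing rule

  f(@unify(0, f, 0, -1, 1, 2, X, 1, g(X))) :- g(X).

whose @unify arguments are the rule id, head name, learnable flag, shared-parameter id, number of head variables, twice the number of body subgoals, the head variables, and each subgoal's sign and symbol; grounding is thus delegated to the @unify external function.
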
def ad_head(self, H):
278  def ad_head(self, H):
279    P, F = [], []
280    for i in range(0, len(H), 2):
281      P.append(H[i][2])
282      F.append(H[i+1][1])
283    return self.pack("ad_head", F, P, self.join_scope(H))
def lad_head(self, H: list):
285  def lad_head(self, H: list):
286    P, F = [], []
287    i, o, j = 0, 0, 0
288    last = None
289    while i < len(H):
290      a = H[i]
291      if a[0] == "prob":
292        P.append(a[2])
293        F.append(H[i+1][1])
294        i += 2
295      else:
296        P.append(-1)
297        F.append(a[1])
298        i += 1; o += 1
299        last = j
300      j += 1
301    if o > 0:
302      P_s = sum(P)+o
303      # If probs were not explicitly given, assume maximum uncertainty and set to uniform.
304      s = round((1.0-P_s)/o, ndigits = 15)
305      ts = P_s+s*(o-1)
306      for i, p in enumerate(P):
307        if i == last: P[i] = 1.0-ts
308        elif p < 0: P[i] = s
309    return self.pack("lad_head", F, P, self.join_scope(H))
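
Worked example of the fill-in above (a sketch of the arithmetic, not library API): for a learnable annotated disjunction with one head annotated with 0.2 and two unannotated heads, the missing probabilities split the remaining mass uniformly.

  P = [0.2, -1, -1]                        # -1 marks a missing probability
  o = sum(1 for p in P if p < 0)           # 2 missing entries
  P_s = sum(P) + o                         # 0.2, the sum of the given probs
  s = round((1.0 - P_s)/o, ndigits = 15)   # 0.4 assigned to each missing entry
  ts = P_s + s*(o - 1)                     # 0.6; the last entry gets 1.0 - ts
  # P becomes [0.2, 0.4, 0.4], which sums to 1 even under rounding.
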
def ad(self, AD):
311  def ad(self, AD):
312    P, F, learnable = AD[0][2], AD[0][1], AD[0][0] == "lad_head"
313    if not math.isclose(s := sum(P), 1.0):
314      P.append(1-s)
315      F.append(unique_fact())
316    return self.pack("ad", "", AnnotatedDisjunction(P, F, learnable), AD[0][3])
def adr(self, AD):
317  def adr(self, AD):
318    raise NotImplementedError
def py_func_args(self, A):
320  def py_func_args(self, A): return "args", A[0][2]
def py_func_kwargs(self, A):
321  def py_func_kwargs(self, A): return "kwargs", (A[0][2], A[1][2])
def py_func_call(self, A):
322  def py_func_call(self, A):
323    args = [a for k, a in A[1:] if k == "args"]
324    kwargs = dict([a for k, a in A[1:] if k == "kwargs"])
325    f = A[0][2]
326    if f not in self.torch_scope:
327      raise ValueError(f"No data definition {f} found! Either define it in a Python "
328                       "block or specify a file or URL to read from.")
329    return self.pack("py_func_call", "", self.torch_scope[f](*args, **kwargs))
def test(self, T):
346  def test(self, T): return self.pack("test", "", self._data2tensor(T))
def train(self, R):
348  def train(self, R): return self.pack("train", "", self._data2tensor(R))
def data(self, D):
351  def data(self, D):
352    name, arg = D[0][1], D[1][1]
353    test, train = D[2][2], D[3][2] if len(D) > 3 else None
354    return self.pack("data", f"{name}({arg}).", Data(name, arg, test, train))
def python(self, T):
357  def python(self, T):
358    exec("import torch\n\n" + T[0].value, self.torch_scope)
359    return self.pack("python", "")
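
Sketch of a data-producing function such a Python block might define (names and shapes are illustrative): the block is exec'd with torch pre-imported, so its definitions land in torch_scope and can later be referenced by name from the program (see py_func_call above and hub below).

  def synthetic_test():
    import torch
    return torch.rand(100, 4)   # 100 instances with 4 features each
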
def LOCAL_NET(self, L):
362  def LOCAL_NET(self, L): return self.pack("LOCAL_NET", str(L))
def GITHUB(self, H):
364  def GITHUB(self, H): return self.pack("GITHUB", str(H))
def PY_FUNC(self, P):
366  def PY_FUNC(self, P): return self.pack("PY_FUNC", str(P))
def hub(self, H):
369  def hub(self, H):
370    # Function name or entrypoint.
371    func = H[0][1]
372    # Network is coming from a Torch block.
373    if len(H) == 1:
374      if func not in self.torch_scope:
375        raise ValueError(f"No network definition {func} found! Either define it in a Python"
376                         "block or specify a PyTorch Hub model (local or from GitHub).")
377      N = self.torch_scope[func]()
378      rep = f"@{func}"
379    # Network is coming from PyTorch Hub.
380    else:
381      try:
382        import torch
383      except ModuleNotFoundError:
384        raise ModuleNotFoundError("PyTorch not found! PyTorch must be installed for neural rules "
385                                  "and neural ADs.")
386      path, source = H[1][1], "github" if H[1][0] == "GITHUB" else "local"
387      N = torch.hub.load(path, func, source = source, trust_repo = "check")
388      rep = f"@{func} on \"{path}\" at \"{source}\""
389    return self.pack("hub", "", (N, rep))
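
Sketch of the two sources hub() accepts (all names are illustrative). From a Python block, the name must resolve to a zero-argument constructor in torch_scope; from PyTorch Hub, the model is loaded directly:

  import torch

  # 1. Defined in a Python block and referenced as @classifier:
  def classifier():
    return torch.nn.Sequential(torch.nn.Linear(784, 10), torch.nn.Softmax(dim = 1))

  # 2. Pulled from PyTorch Hub (GitHub source):
  N = torch.hub.load("pytorch/vision", "resnet18", source = "github",
                     trust_repo = "check")
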
def params(self, P):
392  def params(self, P):
393    return self.pack("params", "", {P[i][1]: v[2] if isinstance((v := P[i+1]), self.Pack) else str(v) for i in range(0, len(P), 2)})
def nrule(self, A):
396  def nrule(self, A):
397    learnable = A[0][0] == "LEARN"
398    name = A[1][1]
399    inp = A[2][1]
400    offset = 3
401    outcomes = None
402    # Has more than one outcome within the neural network.
403    if A[offset][0] == "set":
404      outcomes = list(A[offset][2].keys())
405      offset += 1
406    net, hub_repr = A[offset][2]
407    if A[offset+1][0] == "params":
408      params = A[offset+1][2]
409      body = A[offset+2:]
410    else:
411      params = {}
412      body = A[offset+1:]
413    scope = self.join_scope(A)
414
415    if len(scope) != 1:  raise ValueError(f"Neural rule {name} is not grounded!")
416    if inp not in scope: raise ValueError(f"Neural rule {name} is unsafe!")
417
418    rep = f"{A[0][1]}::{name}({inp}{'' if outcomes is None else f'; {A[offset-1][1]}'}) as {hub_repr} :- {', '.join(getnths(body, 1))}."
419    return self.pack("nrule", "", (name, inp, outcomes, net, body, rep, learnable, params))
def nad(self, A):
422  def nad(self, A):
423    learnable = A[0][0] == "LEARN"
424    name = A[1][1]
425    inp = A[2][1]
426    vals = A[3][2]
427    outcomes = None
428    offset = 4
429    if A[offset][0] == "set":
430      outcomes = list(A[offset][2].keys())
431      offset += 1
432    net, hub_repr = A[offset][2]
433    if A[offset+1][0] == "params":
434      params = A[offset+1][2]
435      body = A[offset+2:]
436    else:
437      params = {}
438      body = A[offset+1:]
439    scope = self.join_scope(A)
440
441    if len(scope) != 1:  raise ValueError(f"Neural annotated disjunction {name} is not grounded!")
442    if inp not in scope: raise ValueError(f"Neural annotated disjunction {name} is unsafe!")
443
444    rep = f"{A[0][1]}::{name}({inp}, {A[3][1]}{'' if outcomes is None else f'; {A[offset-1][1]}'}) as {hub_repr} :- {', '.join(getnths(body, 1))}."
445    return self.pack("nad", "", (name, inp, vals, outcomes, net, body, rep, learnable, params))
def constraint(self, C):
448  def constraint(self, C): return self.pack("constraint", f":- {C[0][1]}.")
def qelement(self, E):
451  def qelement(self, E):
452    return self.pack("qelement", " ".join(getnths(E, 1)), scope = self.join_scope(E))
def interp(self, I):
454  def interp(self, I):
455    return self.pack("interp", "", getnths(I, 1), scope = self.join_scope(I))
def query(self, Q):
457  def query(self, Q):
458    Sc = self.join_scope(Q)
459    if len(Sc) > 0:
460      P = self.pack("varquery", "", VarQuery(self.varquery_id, list(Q[0][2]), \
461                                             list(Q[1][2]) if len(Q) > 1 else [], \
462                                             semantics = self.sem))
463      self.varquery_id += 1
464      return P
465    return self.pack("query", "", Query(Q[0][2], Q[1][2] if len(Q) > 1 else [], semantics = self.sem))
def constdef(self, C):
468  def constdef(self, C): return self.pack("constdef", f"#const {C[0][1]} = {C[1][1]}.")
@staticmethod
def path2obs(path: str):
470  @staticmethod
471  def path2obs(path: str):
472    import pandas, numpy
473    data = pandas.read_csv(path, dtype = int)
474    return lambda: (data.values, data.columns.values.tolist())
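
Hypothetical usage ("obs.csv" is an assumed file whose integer columns are named after the observed atoms):

  get_obs = StableTransformer.path2obs("obs.csv")
  values, atoms = get_obs()   # integer matrix and list of column names
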
def learn(self, L):
477  def learn(self, L):
478    A = {str(L[i]): str(v) if isinstance(v := L[i+1], lark.Token) else v[2] for i in range(1, len(L), 2)}
479    data = self.torch_scope[L[0][1]] if L[0][0] == "PY_FUNC" else StableTransformer.path2obs(L[0][1])
480    return self.pack("directive", "", ("learn", data, A))
def include(self, F):
483  def include(self, F): return lark.visitors.Discard
def exact_inf(self, I):
485  def exact_inf(self, I): return ("inference", "exact", tuple())
def aseo_inf(self, I):
486  def aseo_inf(self, I): return ("inference", "aseo", (I[0][2],))
def inference(self, I):
487  def inference(self, I): return self.pack("directive", "", I[0])
def SEMANTICS_OPT_LOGIC(self, _):
490  def SEMANTICS_OPT_LOGIC(self, _): return lark.visitors.Discard
def SEMANTICS_OPT_PROB(self, O):
491  def SEMANTICS_OPT_PROB(self, O): return str(O)
def semantics(self, S):
492  def semantics(self, S):
493    return self.pack("directive", "", ("psemantics", {"psemantics": S[0]})) if len(S) > 0 else \
494      lark.visitors.Discard
def plp(self, C) -> pasp.program.Program:
497  def plp(self, C) -> Program:
498    # Logic Program.
499    P  = []
500    # Probabilistic Facts.
501    PF = []
502    # Probabilistic Rules.
503    PR = []
504    # Queries.
505    Q  = []
506    # VarQueries.
507    VQ = []
508    # Credal Facts.
509    CF = []
510    # Annotated Disjunction.
511    AD = []
512    # Neural arguments and data.
513    TNR, TNA = [], []
514    D = {}
515    # Actual neural rules and neural ADs.
516    NR, NA = [], []
517    # Directives.
518    directives = {"inference": ("exact", tuple())}
519    # Mapping.
520    M = {"pfact": PF, "prule": PR, "query": Q, "varquery": VQ, "cfact": CF, "ad": AD, "nrule": TNR,
521         "nad": TNA}
522    for t, L, O, _ in C:
523      if len(L) > 0: push(P, L)
524      if t in M: push(M[t], O)
525      if t == "data":
526        if O.name in D: D[O.name].append(O)
527        else: D[O.name] = [O]
528      if t == "directive": directives[O[0]] = tup if len(tup := O[1:]) > 1 else tup[0]
529    # Deal with ungrounded probabilistic rules.
530    for r in PR:
531      if r.is_prop: PF.append(r.prop_pf)
532    self.check_data(D)
533    self.register_nrule(TNR, NR, D)
534    self.register_nad(TNA, NA, D)
535    return Program("\n".join(P), PF, PR, Q, VQ, CF, AD, NR, NA, semantics = self.sem, \
536                   directives = directives)
class StableTransformer.Pack(builtins.tuple):
71  class Pack(tuple):
72    @staticmethod
73    def __new__(cls, tp: str, r: str = None, v = None, sc: dict = {}):
74      return super(StableTransformer.Pack, cls).__new__(cls, (tp, str(v) if r is None else r, \
75                                                              r if v is None else v, sc))
76    def __str__(self): return self[1]
77    def __repr__(self): return f"<{self[0]}: {self.__str__()}>"

A lightweight immutable 4-tuple produced by the transformer callbacks, carrying a parsed node as (type, representation, value, scope); str() returns the textual representation and repr() prefixes it with the node type.
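
Illustration of the convention (the values are made up): a Pack is indexed as (node type, textual representation, semantic value, variable scope).

  from pasp.grammar import StableTransformer
  p = StableTransformer.pack("pred", "age(X)", "age", {"X": None})
  print(p)        # age(X)          -- __str__ yields the representation
  print(repr(p))  # <pred: age(X)>  -- __repr__ tags it with the node type
  print(p[2])     # age             -- semantic value (here, the predicate name)
  print(p[3])     # {'X': None}     -- scope of variables occurring in the node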

class PartialTransformer(lark.visitors._Decoratable, abc.ABC, typing.Generic[~_Leaf_T, ~_Return_T]):
538class PartialTransformer(StableTransformer):
539  def __init__(self, sem: str, consts: dict = {}):
540    super().__init__(sem, consts)
541    self.PT = set()
542    if sem == "lstable":
543      self.sem = Semantics.LSTABLE
544    elif sem == "smproblog":
545      self.sem = Semantics.SMPROBLOG
546    else:
547      self.sem = Semantics.PARTIAL
548    self.o_tree = None
549
550  @staticmethod
551  def has_binop(x: str): return ("=" in x) or ("<" in x) or (">" in x)
552
553  def fact(self, F):
554    T = super().fact(F)
555    self.PT.add(T[2])
556    return T
557
558  def pfact(self, PF):
559    T = super().pfact(PF)
560    self.PT.add(T[2].f)
561    return T
562
563  def cfact(self, CF):
564    T = super().cfact(CF)
565    self.PT.add(T[2].f)
566    return T
567
568  def rule(self, R):
569    b1 = ", ".join(map(lambda x: x[1] if x[2][0] else f"not _{x[1][4:]}", R[1][2]))
570    b2 = ", ".join(map(lambda x: f"_{x[1]}" if x[2][0] or PartialTransformer.has_binop(x) else x[1], R[1][2]))
571    h1, h2 = R[0][1], ", ".join(map(lambda x: f"_{x[1]}", R[0][2]))
572    for h in R[0][2]: self.PT.add(h[1])
573    # for x in r[1][3]:
574      # if not PartialTransformer.has_binop(x): self.PT.add(x[4:] if x[:4] == "not " else x)
575    return self.pack("rule", [f"{h1} :- {b1}.", f"{h2} :- {b2}."])
576
577  def prule(self, R):
578    l = "LEARN" in getnths(R, 0)
579    s = "SHARED" in getnths(R, 0)
580    p = R[0][2] if R[0][0] == "prob" else 0.5
581    h, b = R[-2], R[-1]
582    tr_negs = lambda x: x[1] if x[2][0] else f"not _{x[1][4:]}"
583    tr_pos  = lambda x: f"_{x[1]}" if x[2][0] or PartialTransformer.has_binop(x) else x[1]
584    b1 = ", ".join(map(tr_negs, b[2]))
585    b2 = ", ".join(map(tr_pos, b[2]))
586    o1, o2 = f"{h[1]} :- {b1}", f"_{h[1]} :- {b2}"
587    self.PT.add(h[1])
588    uid = unique_fact()
589    S = self.join_scope(R)
590    if len(S) == 0:
591      pr1, pr2 = ProbRule(p, o1, ufact = uid, learnable = l), ProbRule(p, o2, ufact = uid)
592      self.n_prules += 2
593      return self.pack("prule", [pr1.prop_f, pr2.prop_f], [pr1, pr2])
594    # Invariant: len(b) > 0, otherwise the rule is unsafe.
595    name = h[2]
596    hscope = h[3]
597    body_preds = [x for x in b[2] if x[0] != "bop"]
598    h_s = ", ".join(hscope) + ", " if len(hscope) > 0 else ""
599    b1_s = ", ".join(map(lambda x: f"1, {x[1]}" if x[2][0] else f"0, _{x[1][4:]}", body_preds))
600    # If parameters are shared, then we require a special ID.
601    upr = -1 if not(s and l) else unique_pgrule_id()
602    # Let the grounder deal with the _f rule.
603    rid = self.n_prules; self.n_prules += 1
604    u1 = f"{name}(@unify({rid}, {name}, {int(l)}, {upr}, {len(hscope)}, {2*len(body_preds)}, {h_s}{b1_s})) :- {b1}."
605    return self.pack("prule", "", ProbRule(p, o1, is_prop = False, unify = u1, learnable = l))
606
607  def plp(self, C: list[tuple]) -> Program:
608    # Logic Program.
609    P  = []
610    # Probabilistic Facts.
611    PF = []
612    # Probabilistic Rules.
613    PR = []
614    # Queries.
615    Q  = []
616    # Variable queries.
617    VQ = []
618    # Credal Facts.
619    CF = []
620    # Annotated Disjunction.
621    AD = []
622    # Neural arguments and data.
623    TNR, TNA = [], []
624    D = {}
625    # Neural rules and ADs.
626    NR, NA = [], []
627    # Directives.
628    directives = {"inference": ("exact", tuple())}
629    # Mapping.
630    M = {"pfact": PF, "prule": PR, "query": Q, "varquery": VQ, "cfact": CF, "ad": AD, "nrule": TNR,
631         "nad": TNA}
632    for t, L, O, _ in C:
633      if len(L) > 0: push(P, L)
634      if t in M: push(M[t], O)
635      if t == "prule" and isinstance(O, collections.abc.Iterable) and O[0].is_prop:
636        PF.append(O[0].prop_pf)
637      if t == "directive": directives[O[0]] = tup if len(tup := O[1:]) > 1 else tup[0]
638    P.extend(f"_{x} :- {x}." for x in self.PT)
639    self.check_data(D)
640    self.register_nrule(TNR, NR, D)
641    self.register_nad(TNA, NA, D)
642    return Program("\n".join(P), PF, PR, Q, VQ, CF, AD, NR, NA, semantics = self.sem, \
643                   stable_p = self.stable_p, directives = directives)
644
645  def transform(self, tree):
646    self.o_tree = tree
647    self.stable_p = StableTransformer(self.sem).transform(tree)
648    return super().transform(tree)

PartialTransformer inherits the lark.Transformer interface; see the notes on bottom-up transformation under StableTransformer above.

PartialTransformer(sem: str, consts: dict = {})
539  def __init__(self, sem: str, consts: dict = {}):
540    super().__init__(sem, consts)
541    self.PT = set()
542    if sem == "lstable":
543      self.sem = Semantics.LSTABLE
544    elif sem == "smproblog":
545      self.sem = Semantics.SMPROBLOG
546    else:
547      self.sem = Semantics.PARTIAL
548    self.o_tree = None
PT
o_tree
@staticmethod
def has_binop(x: str):
550  @staticmethod
551  def has_binop(x: str): return ("=" in x) or ("<" in x) or (">" in x)
def fact(self, F):
553  def fact(self, F):
554    T = super().fact(F)
555    self.PT.add(T[2])
556    return T
def pfact(self, PF):
558  def pfact(self, PF):
559    T = super().pfact(PF)
560    self.PT.add(T[2].f)
561    return T
def cfact(self, CF):
563  def cfact(self, CF):
564    T = super().cfact(CF)
565    self.PT.add(T[2].f)
566    return T
def rule(self, R):
568  def rule(self, R):
569    b1 = ", ".join(map(lambda x: x[1] if x[2][0] else f"not _{x[1][4:]}", R[1][2]))
570    b2 = ", ".join(map(lambda x: f"_{x[1]}" if x[2][0] or PartialTransformer.has_binop(x) else x[1], R[1][2]))
571    h1, h2 = R[0][1], ", ".join(map(lambda x: f"_{x[1]}", R[0][2]))
572    for h in R[0][2]: self.PT.add(h[1])
573    # for x in r[1][3]:
574      # if not PartialTransformer.has_binop(x): self.PT.add(x[4:] if x[:4] == "not " else x)
575    return self.pack("rule", [f"{h1} :- {b1}.", f"{h2} :- {b2}."])
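
As an illustration (not from the source), under this partial-semantics translation a rule

  p :- q, not r.

is emitted twice,

  p :- q, not _r.
  _p :- _q, not r.

and plp() later appends `_p :- p.` for every head collected in self.PT.
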
def prule(self, R):
577  def prule(self, R):
578    l = "LEARN" in getnths(R, 0)
579    s = "SHARED" in getnths(R, 0)
580    p = R[0][2] if R[0][0] == "prob" else 0.5
581    h, b = R[-2], R[-1]
582    tr_negs = lambda x: x[1] if x[2][0] else f"not _{x[1][4:]}"
583    tr_pos  = lambda x: f"_{x[1]}" if x[2][0] or PartialTransformer.has_binop(x) else x[1]
584    b1 = ", ".join(map(tr_negs, b[2]))
585    b2 = ", ".join(map(tr_pos, b[2]))
586    o1, o2 = f"{h[1]} :- {b1}", f"_{h[1]} :- {b2}"
587    self.PT.add(h[1])
588    uid = unique_fact()
589    S = self.join_scope(R)
590    if len(S) == 0:
591      pr1, pr2 = ProbRule(p, o1, ufact = uid, learnable = l), ProbRule(p, o2, ufact = uid)
592      self.n_prules += 2
593      return self.pack("prule", [pr1.prop_f, pr2.prop_f], [pr1, pr2])
594    # Invariant: len(b) > 0, otherwise the rule is unsafe.
595    name = h[2]
596    hscope = h[3]
597    body_preds = [x for x in b[2] if x[0] != "bop"]
598    h_s = ", ".join(hscope) + ", " if len(hscope) > 0 else ""
599    b1_s = ", ".join(map(lambda x: f"1, {x[1]}" if x[2][0] else f"0, _{x[1][4:]}", body_preds))
600    # If parameters are shared, then we require a special ID.
601    upr = -1 if not(s and l) else unique_pgrule_id()
602    # Let the grounder deal with the _f rule.
603    rid = self.n_prules; self.n_prules += 1
604    u1 = f"{name}(@unify({rid}, {name}, {int(l)}, {upr}, {len(hscope)}, {2*len(body_preds)}, {h_s}{b1_s})) :- {b1}."
605    return self.pack("prule", "", ProbRule(p, o1, is_prop = False, unify = u1, learnable = l))
def plp(self, C: list[tuple]) -> pasp.program.Program:
607  def plp(self, C: list[tuple]) -> Program:
608    # Logic Program.
609    P  = []
610    # Probabilistic Facts.
611    PF = []
612    # Probabilistic Rules.
613    PR = []
614    # Queries.
615    Q  = []
616    # Variable queries.
617    VQ = []
618    # Credal Facts.
619    CF = []
620    # Annotated Disjunction.
621    AD = []
622    # Neural arguments and data.
623    TNR, TNA = [], []
624    D = {}
625    # Neural rules and ADs.
626    NR, NA = [], []
627    # Directives.
628    directives = {"inference": ("exact", tuple())}
629    # Mapping.
630    M = {"pfact": PF, "prule": PR, "query": Q, "varquery": VQ, "cfact": CF, "ad": AD, "nrule": TNR,
631         "nad": TNA}
632    for t, L, O, _ in C:
633      if len(L) > 0: push(P, L)
634      if t in M: push(M[t], O)
635      if t == "prule" and isinstance(O, collections.abc.Iterable) and O[0].is_prop:
636        PF.append(O[0].prop_pf)
637      if t == "directive": directives[O[0]] = tup if len(tup := O[1:]) > 1 else tup[0]
638    P.extend(f"_{x} :- {x}." for x in self.PT)
639    self.check_data(D)
640    self.register_nrule(TNR, NR, D)
641    self.register_nad(TNA, NA, D)
642    return Program("\n".join(P), PF, PR, Q, VQ, CF, AD, NR, NA, semantics = self.sem, \
643                   stable_p = self.stable_p, directives = directives)
def transform(self, tree):
645  def transform(self, tree):
646    self.o_tree = tree
647    self.stable_p = StableTransformer(self.sem).transform(tree)
648    return super().transform(tree)

Transform the given tree, and return the final result

def parse(*files: str, G: lark.lark.Lark = None, from_str: bool = False, semantics: str = 'stable') -> pasp.program.Program:
667def parse(*files: str, G: lark.Lark = None, from_str: bool = False, semantics: str = "stable") -> Program:
668  """Either parses `streams` as blocks of text containing the PLP when `from_str = True`, or
669  interprets `streams` as filenames to be read and parsed into a `Program`."""
670  if semantics not in parse.trans_map:
671    raise ValueError("semantics not supported (must either be 'stable', 'partial' or 'lstable')!")
672  sem, consts,  T = _flatten_includes(*files, G=G, from_str=from_str)
673  if sem is not None: semantics = sem
674  return parse.trans_map[semantics](semantics, consts).transform(T)

Either parses streams as blocks of text containing the PLP when from_str = True, or interprets streams as filenames to be read and parsed into a Program.
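
Hypothetical usage (file names and program text are illustrative):

  import pasp.grammar
  P = pasp.grammar.parse("example.plp")                        # from a file
  P = pasp.grammar.parse("0.5::a. b :- a.", from_str = True)   # from a string
  P = pasp.grammar.parse("example.plp", semantics = "lstable") # L-stable semantics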