|
11 | 11 | NonCommutativeLogicalCriterionCombination,
|
12 | 12 | )
|
13 | 13 | from execution_engine.omop.criterion.combination.temporal import (
|
| 14 | + FixedWindowTemporalIndicatorCombination, |
| 15 | + PersonalWindowTemporalIndicatorCombination, |
14 | 16 | TemporalIndicatorCombination,
|
15 | 17 | )
|
16 | 18 |
|
@@ -226,6 +228,7 @@ def criterion_attr(attr: str) -> str | None:
|
226 | 228 | node_data["data"]["start_time"] = node.start_time
|
227 | 229 | node_data["data"]["end_time"] = node.end_time
|
228 | 230 | node_data["data"]["interval_type"] = node.interval_type
|
| 231 | + node_data["data"]["interval_criterion"] = repr(node.interval_criterion) |
229 | 232 |
|
230 | 233 | if hasattr(node, "count_min"):
|
231 | 234 | node_data["data"]["count_min"] = node.count_min
|
@@ -387,136 +390,101 @@ def conjunction_from_combination(
|
387 | 390 | # The problem is that we need a non-criterion sink node of the intervention and population in order
|
388 | 391 | # to store the results to the database without the criterion_id (as the result of the whole
|
389 | 392 | # intervention or population of this population/intervention pair).
|
390 |
| - assert ( |
| 393 | + |
| 394 | + if ( |
391 | 395 | comb.operator.operator
|
392 |
| - == LogicalCriterionCombination.Operator.AND |
393 |
| - ), ( |
394 |
| - f"Invalid operator {str(comb.operator)} for root node. " |
395 |
| - f"Expected {LogicalCriterionCombination.Operator.AND}" |
396 |
| - ) |
| 396 | + != LogicalCriterionCombination.Operator.AND |
| 397 | + ): |
| 398 | + raise AssertionError( |
| 399 | + f"Invalid operator {comb.operator} for root node. Expected AND." |
| 400 | + ) |
397 | 401 | return logic.NonSimplifiableAnd
|
398 |
| - elif isinstance(comb, NonCommutativeLogicalCriterionCombination): |
| 402 | + |
| 403 | + # Handle non-commutative combinations. |
| 404 | + if isinstance(comb, NonCommutativeLogicalCriterionCombination): |
399 | 405 | return logic.ConditionalFilter
|
400 |
| - elif comb.operator.operator == LogicalCriterionCombination.Operator.NOT: |
401 |
| - return logic.Not |
402 |
| - elif comb.operator.operator == LogicalCriterionCombination.Operator.AND: |
403 |
| - return logic.And |
404 |
| - elif comb.operator.operator == LogicalCriterionCombination.Operator.OR: |
405 |
| - return logic.Or |
406 |
| - elif ( |
407 |
| - comb.operator.operator |
408 |
| - == LogicalCriterionCombination.Operator.ALL_OR_NONE |
409 |
| - ): |
410 |
| - return logic.AllOrNone |
411 |
| - elif ( |
412 |
| - comb.operator.operator |
413 |
| - == LogicalCriterionCombination.Operator.AT_LEAST |
414 |
| - ): |
415 |
| - if comb.operator.threshold is None: |
416 |
| - raise ValueError( |
417 |
| - f"Threshold must be set for operator {comb.operator.operator}" |
418 |
| - ) |
419 |
| - return lambda *args, category: logic.MinCount( |
420 |
| - *args, threshold=comb.operator.threshold, category=category # type: ignore |
421 |
| - ) |
422 |
| - elif ( |
423 |
| - comb.operator.operator |
424 |
| - == LogicalCriterionCombination.Operator.AT_MOST |
425 |
| - ): |
426 |
| - if comb.operator.threshold is None: |
427 |
| - raise ValueError( |
428 |
| - f"Threshold must be set for operator {comb.operator.operator}" |
429 |
| - ) |
430 |
| - return lambda *args, category: logic.MaxCount( |
431 |
| - *args, threshold=comb.operator.threshold, category=category # type: ignore |
432 |
| - ) |
433 |
| - elif ( |
434 |
| - comb.operator.operator |
435 |
| - == LogicalCriterionCombination.Operator.EXACTLY |
436 |
| - ): |
| 406 | + |
| 407 | + op = comb.operator.operator |
| 408 | + |
| 409 | + # Mapping of simple logical operators. |
| 410 | + simple_ops = { |
| 411 | + LogicalCriterionCombination.Operator.NOT: logic.Not, |
| 412 | + LogicalCriterionCombination.Operator.AND: logic.And, |
| 413 | + LogicalCriterionCombination.Operator.OR: logic.Or, |
| 414 | + LogicalCriterionCombination.Operator.ALL_OR_NONE: logic.AllOrNone, |
| 415 | + } |
| 416 | + if op in simple_ops: |
| 417 | + return simple_ops[op] |
| 418 | + |
| 419 | + # Mapping of count-based operators. |
| 420 | + count_ops = { |
| 421 | + LogicalCriterionCombination.Operator.AT_LEAST: logic.MinCount, |
| 422 | + LogicalCriterionCombination.Operator.AT_MOST: logic.MaxCount, |
| 423 | + LogicalCriterionCombination.Operator.EXACTLY: logic.ExactCount, |
| 424 | + } |
| 425 | + if op in count_ops: |
437 | 426 | if comb.operator.threshold is None:
|
438 | 427 | raise ValueError(
|
439 | 428 | f"Threshold must be set for operator {comb.operator.operator}"
|
440 | 429 | )
|
441 |
| - return lambda *args, category: logic.ExactCount( |
442 |
| - *args, threshold=comb.operator.threshold, category=category # type: ignore |
443 |
| - ) |
444 |
| - else: |
445 |
| - raise NotImplementedError( |
446 |
| - f'Operator "{str(comb.operator)}" not implemented' |
| 430 | + return lambda *args, category: count_ops[op]( |
| 431 | + *args, threshold=comb.operator.threshold, category=category |
447 | 432 | )
|
448 | 433 |
|
| 434 | + raise NotImplementedError(f'Operator "{comb.operator}" not implemented') |
| 435 | + |
| 436 | + ################################################################################### |
449 | 437 | elif isinstance(comb, TemporalIndicatorCombination):
|
450 | 438 |
|
451 |
| - tcomb: TemporalIndicatorCombination = comb |
452 |
| - interval_criterion: logic.Expr | logic.Symbol | None |
| 439 | + interval_criterion: logic.BaseExpr | None = None |
| 440 | + start_time = None |
| 441 | + end_time = None |
| 442 | + interval_type = None |
453 | 443 |
|
454 |
| - if isinstance(tcomb.interval_criterion, CriterionCombination): |
455 |
| - interval_criterion = _traverse(tcomb.interval_criterion) |
456 |
| - elif isinstance(tcomb.interval_criterion, Criterion): |
457 |
| - interval_criterion = logic.Symbol(tcomb.interval_criterion) |
458 |
| - elif tcomb.interval_criterion is None: |
459 |
| - interval_criterion = None |
460 |
| - else: |
461 |
| - raise ValueError( |
462 |
| - f"Invalid interval criterion type: {type(tcomb.interval_criterion)}" |
463 |
| - ) |
| 444 | + if isinstance(comb, PersonalWindowTemporalIndicatorCombination): |
464 | 445 |
|
465 |
| - if ( |
466 |
| - tcomb.operator.operator |
467 |
| - == TemporalIndicatorCombination.Operator.AT_LEAST |
468 |
| - ): |
469 |
| - if tcomb.operator.threshold is None: |
| 446 | + if isinstance(comb.interval_criterion, CriterionCombination): |
| 447 | + interval_criterion = _traverse(comb.interval_criterion) |
| 448 | + elif isinstance(comb.interval_criterion, Criterion): |
| 449 | + interval_criterion = logic.Symbol(comb.interval_criterion) |
| 450 | + else: |
470 | 451 | raise ValueError(
|
471 |
| - f"Threshold must be set for operator {tcomb.operator.operator}" |
| 452 | + f"Invalid interval criterion type: {type(comb.interval_criterion)}" |
472 | 453 | )
|
473 |
| - return lambda *args, category: logic.TemporalMinCount( |
474 |
| - *args, |
475 |
| - threshold=tcomb.operator.threshold, |
476 |
| - category=category, |
477 |
| - start_time=tcomb.start_time, |
478 |
| - end_time=tcomb.end_time, |
479 |
| - interval_type=tcomb.interval_type, # type: ignore |
480 |
| - interval_criterion=interval_criterion, |
481 |
| - ) |
482 |
| - elif ( |
483 |
| - tcomb.operator.operator |
484 |
| - == TemporalIndicatorCombination.Operator.AT_MOST |
485 |
| - ): |
486 |
| - if tcomb.operator.threshold is None: |
487 |
| - raise ValueError( |
488 |
| - f"Threshold must be set for operator {tcomb.operator.operator}" |
489 |
| - ) |
490 |
| - return lambda *args, category: logic.TemporalMaxCount( |
491 |
| - *args, |
492 |
| - threshold=tcomb.operator.threshold, |
493 |
| - category=category, # type: ignore |
494 |
| - start_time=tcomb.start_time, |
495 |
| - end_time=tcomb.end_time, |
496 |
| - interval_type=tcomb.interval_type, # type: ignore |
497 |
| - interval_criterion=interval_criterion, |
498 |
| - ) |
499 |
| - elif ( |
500 |
| - tcomb.operator.operator |
501 |
| - == TemporalIndicatorCombination.Operator.EXACTLY |
502 |
| - ): |
503 |
| - if tcomb.operator.threshold is None: |
504 |
| - raise ValueError( |
505 |
| - f"Threshold must be set for operator {tcomb.operator.operator}" |
506 |
| - ) |
507 |
| - return lambda *args, category: logic.TemporalExactCount( |
508 |
| - *args, |
509 |
| - threshold=tcomb.operator.threshold, |
510 |
| - category=category, |
511 |
| - start_time=tcomb.start_time, |
512 |
| - end_time=tcomb.end_time, |
513 |
| - interval_type=tcomb.interval_type, |
514 |
| - interval_criterion=interval_criterion, |
| 454 | + |
| 455 | + elif isinstance(comb, FixedWindowTemporalIndicatorCombination): |
| 456 | + start_time = comb.start_time |
| 457 | + end_time = comb.end_time |
| 458 | + interval_type = comb.interval_type |
| 459 | + |
| 460 | + # Ensure a threshold is set. |
| 461 | + if comb.operator.threshold is None: |
| 462 | + raise ValueError( |
| 463 | + f"Threshold must be set for operator {comb.operator.operator}" |
515 | 464 | )
|
516 |
| - else: |
| 465 | + |
| 466 | + # Map the operator to the corresponding logic function. |
| 467 | + op_map = { |
| 468 | + TemporalIndicatorCombination.Operator.AT_LEAST: logic.TemporalMinCount, |
| 469 | + TemporalIndicatorCombination.Operator.AT_MOST: logic.TemporalMaxCount, |
| 470 | + TemporalIndicatorCombination.Operator.EXACTLY: logic.TemporalExactCount, |
| 471 | + } |
| 472 | + op_func = op_map.get(comb.operator.operator, None) |
| 473 | + if op_func is None: |
517 | 474 | raise NotImplementedError(
|
518 |
| - f'Operator "{str(tcomb.operator)}" not implemented' |
| 475 | + f'Operator "{str(comb.operator)}" not implemented' |
519 | 476 | )
|
| 477 | + |
| 478 | + return lambda *args, category: op_func( |
| 479 | + *args, |
| 480 | + threshold=comb.operator.threshold, |
| 481 | + category=category, |
| 482 | + start_time=start_time, |
| 483 | + end_time=end_time, |
| 484 | + interval_type=interval_type, |
| 485 | + interval_criterion=interval_criterion, |
| 486 | + ) |
| 487 | + |
520 | 488 | else:
|
521 | 489 | raise ValueError(f"Invalid combination type: {type(comb)}")
|
522 | 490 |
|
|
0 commit comments