From e319cfcb8e9450b84c2f9a0227370e91e7589d20 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 3 Mar 2026 11:57:58 -0500 Subject: [PATCH] Add the ability to use the Prism parser Previously, this relied on Ripper exclusively. This commit adds the ability to use the Prism parser instead. This enables power_assert to handle slightly more syntax, and to not get confused on lines that are lacking column information and have multiple of the same token. It is also faster, when examining run times of the ident extraction methods. I have left Ripper as the default parser, as that would require more of a discussion going forward. But switching would yield the above benefits, as well as allowing this to run on other Rubies besides just CRuby. --- lib/power_assert/context.rb | 2 +- lib/power_assert/parser.rb | 308 ++++++++++++++++++++++-------------- test/test_helper.rb | 23 ++- 3 files changed, 204 insertions(+), 129 deletions(-) diff --git a/lib/power_assert/context.rb b/lib/power_assert/context.rb index 8d68196..cddc442 100644 --- a/lib/power_assert/context.rb +++ b/lib/power_assert/context.rb @@ -30,7 +30,7 @@ def initialize(assertion_proc_or_source, assertion_method, source_binding) line ||= File.open(path) {|fp| fp.each_line.drop(lineno - 1).first } end if line - @parser = Parser.new(line, path, lineno, @assertion_proc.binding, assertion_method.to_s, @assertion_proc) + @parser = Parser::RipperParser.new(line, path, lineno, @assertion_proc.binding, assertion_method.to_s, @assertion_proc) end end end diff --git a/lib/power_assert/parser.rb b/lib/power_assert/parser.rb index c8ad988..99b1697 100644 --- a/lib/power_assert/parser.rb +++ b/lib/power_assert/parser.rb @@ -1,9 +1,10 @@ -require 'ripper' - module PowerAssert class Parser Ident = Struct.new(:type, :name, :column) + class Branch < Array + end + attr_reader :line, :path, :lineno, :binding def initialize(line, path, lineno, binding, assertion_method_name = nil, assertion_proc = nil) @@ -18,7 +19,7 @@ def 
initialize(line, path, lineno, binding, assertion_method_name = nil, asserti end def idents - @idents ||= extract_idents(Ripper.sexp(@line_for_parsing)) + @idents ||= [] end def call_paths @@ -52,134 +53,197 @@ def slice_expression(str) str end - class Branch < Array - end + class RipperParser < Parser + def idents + @idents ||= + begin + require 'ripper' + extract_idents(Ripper.sexp(@line_for_parsing)) + end + end - AND_OR_OPS = %i(and or && ||) - - # - # Returns idents as graph structure. - # - # +--c--b--+ - # extract_idents(Ripper.sexp('a&.b(c).d')) #=> a--+ +--d - # +--------+ - # - def extract_idents(sexp) - case sexp - in [:arg_paren | :assoc_splat | :fcall | :hash | :method_add_block | :string_literal | :return, s, *] - extract_idents(s) - in [:assign | :massign, _, s] - extract_idents(s) - in [:opassign, _, [_, op_name, [_, op_column]], s] - extract_idents(s) + [Ident[:method, op_name.sub(/=\z/, ''), op_column]] - in [:dyna_symbol, [Symbol, *] => s] - # s can be [:string_content, [..]] while parsing an expression like { "a": 1 } - extract_idents(s) - in [:dyna_symbol, ss] - ss.flat_map {|s| extract_idents(s) } - in [:assoclist_from_args | :bare_assoc_hash | :paren | :string_embexpr | :regexp_literal | :xstring_literal, ss, *] - ss.flat_map {|s| extract_idents(s) } - in [:command, s0, s1] - [s1, s0].flat_map {|s| extract_idents(s) } - in [:assoc_new | :dot2 | :dot3 | :string_content, *ss] - ss.flat_map {|s| extract_idents(s) } - in [:unary, mid, s] - handle_columnless_ident([], mid, extract_idents(s)) - in [:binary, s0, op, s1] if AND_OR_OPS.include?(op) - extract_idents(s0) + [Branch[extract_idents(s1), []]] - in [:binary, s0, op, s1] - handle_columnless_ident(extract_idents(s0), op, extract_idents(s1)) - in [:call, recv, [op_sym, op_name, _], method] - with_safe_op = ((op_sym == :@op and op_name == '&.') or op_sym == :"&.") - if method == :call - handle_columnless_ident(extract_idents(recv), :call, [], with_safe_op) + private + + AND_OR_OPS = %i(and or && 
||) + + # + # Returns idents as graph structure. + # + # +--c--b--+ + # extract_idents(Ripper.sexp('a&.b(c).d')) #=> a--+ +--d + # +--------+ + # + def extract_idents(sexp) + case sexp + in [:arg_paren | :assoc_splat | :fcall | :hash | :method_add_block | :string_literal | :return, s, *] + extract_idents(s) + in [:assign | :massign, _, s] + extract_idents(s) + in [:opassign, _, [_, op_name, [_, op_column]], s] + extract_idents(s) + [Ident[:method, op_name.sub(/=\z/, ''), op_column]] + in [:dyna_symbol, [Symbol, *] => s] + # s can be [:string_content, [..]] while parsing an expression like { "a": 1 } + extract_idents(s) + in [:dyna_symbol, ss] + ss.flat_map {|s| extract_idents(s) } + in [:assoclist_from_args | :bare_assoc_hash | :paren | :string_embexpr | :regexp_literal | :xstring_literal, ss, *] + ss.flat_map {|s| extract_idents(s) } + in [:command, s0, s1] + [s1, s0].flat_map {|s| extract_idents(s) } + in [:assoc_new | :dot2 | :dot3 | :string_content, *ss] + ss.flat_map {|s| extract_idents(s) } + in [:unary, mid, s] + handle_columnless_ident([], mid, extract_idents(s)) + in [:binary, s0, op, s1] if AND_OR_OPS.include?(op) + extract_idents(s0) + [Branch[extract_idents(s1), []]] + in [:binary, s0, op, s1] + handle_columnless_ident(extract_idents(s0), op, extract_idents(s1)) + in [:call, recv, [op_sym, op_name, _], method] + with_safe_op = ((op_sym == :@op and op_name == '&.') or op_sym == :"&.") + if method == :call + handle_columnless_ident(extract_idents(recv), :call, [], with_safe_op) + else + extract_idents(recv) + (with_safe_op ? [Branch[extract_idents(method), []]] : extract_idents(method)) + end + in [:array, ss] + ss ? ss.flat_map {|s| extract_idents(s) } : [] + in [:command_call, s0, _, s1, s2] + [s0, s2, s1].flat_map {|s| extract_idents(s) } + in [:aref, s0, s1] + handle_columnless_ident(extract_idents(s0), :[], extract_idents(s1)) + in [:method_add_arg, s0, s1] + case extract_idents(s0) + in [] + # idents(s0) may be empty(e.g. 
->{}.()) + extract_idents(s1) + in [*is0, Branch[is1, []]] + # Safe navigation operator is used. See :call clause also. + is0 + [Branch[extract_idents(s1) + is1, []]] + in [*is, i] + is + extract_idents(s1) + [i] + end + in [:args_add_block, [:args_add_star, ss0, *ss1], _] + (ss0 + ss1).flat_map {|s| extract_idents(s) } + in [:args_add_block, ss, _] + ss.flat_map {|s| extract_idents(s) } + in [:vcall, [:@ident, name, [_, column]]] + [Ident[@proc_local_variables.include?(name) ? :ref : :method, name, column]] + in [:vcall, _] + [] + in [:program, [[:method_add_block, [:method_add_arg, [:fcall, [:@ident | :@const, ^@assertion_method_name, _]], _], [:brace_block | :do_block, _, ss]]]] + ss.flat_map {|s| extract_idents(s) } + in [:program, [s, *]] + extract_idents(s) + in [:ifop, s0, s1, s2] + [*extract_idents(s0), Branch[extract_idents(s1), extract_idents(s2)]] + in [:if | :unless, s0, ss0, [_, ss1]] + [*extract_idents(s0), Branch[ss0.flat_map {|s| extract_idents(s) }, ss1.flat_map {|s| extract_idents(s) }]] + in [:if | :unless, s0, ss0, _] + [*extract_idents(s0), Branch[ss0.flat_map {|s| extract_idents(s) }, []]] + in [:if_mod | :unless_mod, s0, s1] + [*extract_idents(s0), Branch[extract_idents(s1), []]] + in [:var_ref | :var_field, [:@kw, 'self', [_, column]]] + [Ident[:ref, 'self', column]] + in [:var_ref | :var_field, [:@ident | :@const | :@cvar | :@ivar | :@gvar, ref_name, [_, column]]] + [Ident[:ref, ref_name, column]] + in [:var_ref | :var_field, _] + [] + in [:@ident | :@const | :@op, method_name, [_, column]] + [Ident[:method, method_name, column]] else - extract_idents(recv) + (with_safe_op ? [Branch[extract_idents(method), []]] : extract_idents(method)) + [] end - in [:array, ss] - ss ? 
ss.flat_map {|s| extract_idents(s) } : [] - in [:command_call, s0, _, s1, s2] - [s0, s2, s1].flat_map {|s| extract_idents(s) } - in [:aref, s0, s1] - handle_columnless_ident(extract_idents(s0), :[], extract_idents(s1)) - in [:method_add_arg, s0, s1] - case extract_idents(s0) - in [] - # idents(s0) may be empty(e.g. ->{}.()) - extract_idents(s1) - in [*is0, Branch[is1, []]] - # Safe navigation operator is used. See :call clause also. - is0 + [Branch[extract_idents(s1) + is1, []]] - in [*is, i] - is + extract_idents(s1) + [i] + end + + def str_indices(str, re, offset, limit) + idx = str.index(re, offset) + if idx and idx <= limit + [idx, *str_indices(str, re, idx + 1, limit)] + else + [] end - in [:args_add_block, [:args_add_star, ss0, *ss1], _] - (ss0 + ss1).flat_map {|s| extract_idents(s) } - in [:args_add_block, ss, _] - ss.flat_map {|s| extract_idents(s) } - in [:vcall, [:@ident, name, [_, column]]] - [Ident[@proc_local_variables.include?(name) ? :ref : :method, name, column]] - in [:vcall, _] - [] - in [:program, [[:method_add_block, [:method_add_arg, [:fcall, [:@ident | :@const, ^@assertion_method_name, _]], _], [:brace_block | :do_block, _, ss]]]] - ss.flat_map {|s| extract_idents(s) } - in [:program, [s, *]] - extract_idents(s) - in [:ifop, s0, s1, s2] - [*extract_idents(s0), Branch[extract_idents(s1), extract_idents(s2)]] - in [:if | :unless, s0, ss0, [_, ss1]] - [*extract_idents(s0), Branch[ss0.flat_map {|s| extract_idents(s) }, ss1.flat_map {|s| extract_idents(s) }]] - in [:if | :unless, s0, ss0, _] - [*extract_idents(s0), Branch[ss0.flat_map {|s| extract_idents(s) }, []]] - in [:if_mod | :unless_mod, s0, s1] - [*extract_idents(s0), Branch[extract_idents(s1), []]] - in [:var_ref | :var_field, [:@kw, 'self', [_, column]]] - [Ident[:ref, 'self', column]] - in [:var_ref | :var_field, [:@ident | :@const | :@cvar | :@ivar | :@gvar, ref_name, [_, column]]] - [Ident[:ref, ref_name, column]] - in [:var_ref | :var_field, _] - [] - in [:@ident | :@const | :@op, 
method_name, [_, column]] - [Ident[:method, method_name, column]] - else - [] end - end - def str_indices(str, re, offset, limit) - idx = str.index(re, offset) - if idx and idx <= limit - [idx, *str_indices(str, re, idx + 1, limit)] - else - [] + MID2SRCTXT = { + :[] => '[', + :+@ => '+', + :-@ => '-', + :call => '(' + } + + def handle_columnless_ident(left_idents, mid, right_idents, with_safe_op = false) + left_max = left_idents.flatten.max_by(&:column) + right_min = right_idents.flatten.min_by(&:column) + bg = left_max ? left_max.column + left_max.name.length : 0 + ed = right_min ? right_min.column - 1 : @line_for_parsing.length - 1 + mname = mid.to_s + srctxt = MID2SRCTXT[mid] || mname + re = / + #{'\b' if /\A\w/ =~ srctxt} + #{Regexp.escape(srctxt)} + #{'\b' if /\w\z/ =~ srctxt} + /x + indices = str_indices(@line_for_parsing, re, bg, ed) + if indices.length == 1 or !(right_idents.empty? and left_idents.empty?) + ident = Ident[:method, mname, right_idents.empty? ? indices.first : indices.last] + left_idents + right_idents + (with_safe_op ? [Branch[[ident], []]] : [ident]) + else + left_idents + right_idents + end end end - MID2SRCTXT = { - :[] => '[', - :+@ => '+', - :-@ => '-', - :call => '(' - } - - def handle_columnless_ident(left_idents, mid, right_idents, with_safe_op = false) - left_max = left_idents.flatten.max_by(&:column) - right_min = right_idents.flatten.min_by(&:column) - bg = left_max ? left_max.column + left_max.name.length : 0 - ed = right_min ? right_min.column - 1 : @line_for_parsing.length - 1 - mname = mid.to_s - srctxt = MID2SRCTXT[mid] || mname - re = / - #{'\b' if /\A\w/ =~ srctxt} - #{Regexp.escape(srctxt)} - #{'\b' if /\w\z/ =~ srctxt} - /x - indices = str_indices(@line_for_parsing, re, bg, ed) - if indices.length == 1 or !(right_idents.empty? and left_idents.empty?) - ident = Ident[:method, mname, right_idents.empty? ? indices.first : indices.last] - left_idents + right_idents + (with_safe_op ? 
[Branch[[ident], []]] : [ident]) - else - left_idents + right_idents + class PrismParser < Parser + def idents + @idents ||= + begin + require 'prism' + result = Prism.parse(@line_for_parsing, scopes: [@proc_local_variables.map(&:to_sym)]) + result.success? ? extract_idents(result.value) : [] + end + end + + private + + def extract_idents(node) + case node&.type + when nil + [] + when :program_node + if (statement = node.statements.body.first).is_a?(Prism::CallNode) && (statement.name == @assertion_method_name&.to_sym) && (block = statement.block).is_a?(Prism::BlockNode) && (body = block.body).is_a?(Prism::StatementsNode) + extract_idents(body) + else + extract_idents(statement) + end + when :call_node + message_ident = + if (node.name == :[] && !node.call_operator_loc) + Ident[:method, "[]", node.opening_loc.start_column] + elsif (message_loc = node.message_loc) + Ident[:method, node.name.name, message_loc.start_column] + else + Ident[:method, "call", node.opening_loc.start_column] + end + + if node.safe_navigation? 
+ extract_idents(node.receiver) + [Branch[extract_idents(node.arguments) + [message_ident], []]] + else + extract_idents(node.receiver) + extract_idents(node.arguments) + [message_ident] + end + when :call_operator_write_node, :class_variable_operator_write_node, :constant_operator_write_node, :constant_path_operator_write_node, :global_variable_operator_write_node, :index_operator_write_node, :instance_variable_operator_write_node, :local_variable_operator_write_node + binary_operator_loc = node.binary_operator_loc.chop + extract_idents(node.value) + [Ident[:method, binary_operator_loc.slice, binary_operator_loc.start_column]] + when :if_node + extract_idents(node.predicate) + [Branch[extract_idents(node.statements), extract_idents(node.subsequent)]] + when :unless_node + extract_idents(node.predicate) + [Branch[extract_idents(node.statements), extract_idents(node.else_clause)]] + when :and_node, :or_node + extract_idents(node.left) + [Branch[extract_idents(node.right), []]] + when :class_variable_read_node, :constant_read_node, :global_variable_read_node, :instance_variable_read_node, :local_variable_read_node, :self_node + [Ident[:ref, node.slice, node.start_column]] + else + node.compact_child_nodes.flat_map { |child| extract_idents(child) } + end end end diff --git a/test/test_helper.rb b/test/test_helper.rb index 5cc39ef..a81c0c7 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -14,7 +14,6 @@ require 'test/unit' require 'power_assert' -require 'ripper' module PowerAssertTestHelper class << self @@ -32,12 +31,24 @@ def t(msg='', &blk) private + PrismParser = ::PowerAssert.const_get(:Parser)::PrismParser + RipperParser = ::PowerAssert.const_get(:Parser)::RipperParser + PARSER_CLASSES = RUBY_VERSION >= '3.3.0' ? 
[PrismParser, RipperParser] : [RipperParser] + def _test_parser((expected_idents, expected_paths, source)) - parser = ::PowerAssert.const_get(:Parser).new(source, '', 1, -> { var = nil; -> { var } }.().binding, 'assertion_message') - idents = parser.idents - assert_equal expected_idents, map_recursive(idents, &:to_a), source - if expected_paths - assert_equal expected_paths, map_recursive(parser.call_paths, &:name), source + PARSER_CLASSES.each do |parser_class| + parser = parser_class.new(source, '', 1, -> { var = nil; -> { var } }.().binding, 'assertion_message') + idents = parser.idents + + if expected_idents.empty? && parser_class == PrismParser + # Allow PrismParser to handle more syntax than RipperParser. + else + assert_equal expected_idents, map_recursive(idents, &:to_a), source + end + + if expected_paths + assert_equal expected_paths, map_recursive(parser.call_paths, &:name), source + end end end